From d118f17c0e51f72daa05eed089cf74f8b585c370 Mon Sep 17 00:00:00 2001 From: Hari Haran Rathinakumar Date: Tue, 23 Jun 2026 13:55:59 +0100 Subject: [PATCH 1/3] feat: add class-based callback system for training lifecycle hooks --- src/instructlab/training/__init__.py | 3 + .../training/batch_loss_manager.py | 17 +- src/instructlab/training/callbacks.py | 280 +++++++++++ src/instructlab/training/config.py | 8 +- src/instructlab/training/main_ds.py | 110 ++++- tests/unit/test_callbacks.py | 459 ++++++++++++++++++ 6 files changed, 874 insertions(+), 3 deletions(-) create mode 100644 src/instructlab/training/callbacks.py create mode 100644 tests/unit/test_callbacks.py diff --git a/src/instructlab/training/__init__.py b/src/instructlab/training/__init__.py index 136d1384..fcb0d754 100644 --- a/src/instructlab/training/__init__.py +++ b/src/instructlab/training/__init__.py @@ -5,7 +5,9 @@ "LoraOptions", "QuantizeDataType", "TorchrunArgs", + "TrainerCallback", "TrainingArgs", + "TrainingContext", "run_training", "FSDPOptions", "ShardingStrategies", @@ -17,6 +19,7 @@ import instructlab.training.logger # Disable package logging by default # Local +from .callbacks import TrainerCallback, TrainingContext from .config import ( DataProcessArgs, DeepSpeedOffloadStrategy, diff --git a/src/instructlab/training/batch_loss_manager.py b/src/instructlab/training/batch_loss_manager.py index 46e2af30..3d7975c2 100644 --- a/src/instructlab/training/batch_loss_manager.py +++ b/src/instructlab/training/batch_loss_manager.py @@ -48,7 +48,14 @@ class BatchLossManager: - Computing average losses for logging """ - def __init__(self, model, accelerator, world_size: int, local_rank: int): + def __init__( + self, + model, + accelerator, + world_size: int, + local_rank: int, + callback_manager=None, + ): """ Initialize the BatchLossManager. @@ -57,12 +64,14 @@ def __init__(self, model, accelerator, world_size: int, local_rank: int): accelerator: The accelerator instance for distributed training world_size: Number of distributed processes local_rank: Local rank of the current process + callback_manager: Optional CallbackManager for lifecycle hooks """ self.model: Model = model self.accelerator: Accelerator = accelerator self.world_size: int = world_size self.local_rank: int = local_rank self.torch_device = torch.device("cuda", local_rank) + self.callback_manager = callback_manager def process_batch( self, @@ -111,6 +120,9 @@ def process_batch( batch_total_samples += micro_batch_size batch_total_length += total_length + if self.callback_manager: + self.callback_manager.fire("on_before_forward") + # prepare model inputs model_inputs = self._prepare_model_inputs(mb) @@ -126,6 +138,9 @@ def process_batch( self.accelerator.backward(scaled_loss) + if self.callback_manager: + self.callback_manager.fire("on_after_backward") + # accumulate losses grad_accum_steps += 1 accumulated_loss += raw_losses.main_loss diff --git a/src/instructlab/training/callbacks.py b/src/instructlab/training/callbacks.py new file mode 100644 index 00000000..af17012a --- /dev/null +++ b/src/instructlab/training/callbacks.py @@ -0,0 +1,280 @@ +# SPDX-License-Identifier: Apache-2.0 + +""" +Callback system for training lifecycle hooks. + +Provides async, fire-and-forget callbacks that observe training events +without blocking the training loop or propagating exceptions. +""" + +# Standard +from dataclasses import dataclass, field +from typing import Any +import asyncio +import base64 +import copy +import inspect +import json +import logging +import textwrap +import threading + +logger = logging.getLogger("instructlab.training") + +HOOK_NAMES = [ + "on_train_begin", + "on_epoch_begin", + "on_step_begin", + "on_before_forward", + "on_after_backward", + "on_pre_optimizer_step", + "on_optimizer_step", + "on_log", + "on_evaluate", + "on_save", + "on_step_end", + "on_epoch_end", + "on_train_end", +] + + +@dataclass +class TrainingContext: + """Mutable training state maintained by the training loop. + + The CallbackManager snapshots this before dispatching to callbacks, + so callback authors receive an effectively read-only view. + """ + + hook_name: str = "" + + step: int = 0 + epoch: int = 0 + total_samples: int = 0 + total_tokens: int = 0 + + loss: float | None = None + learning_rate: float | None = None + grad_norm: float | None = None + elapsed_time: float | None = None + overall_throughput: float | None = None + cuda_mem_allocated: float | None = None + + batch_metrics: dict[str, Any] = field(default_factory=dict) + val_metrics: dict[str, Any] = field(default_factory=dict) + checkpoint_path: str | None = None + + output_dir: str = "" + model_name_or_path: str = "" + max_epochs: int = 0 + world_size: int = 1 + is_local_process_zero: bool = True + is_world_process_zero: bool = True + + +class TrainerCallback: + """Base class for training callbacks. Subclass and override hooks you need. + + Callbacks fire on ALL distributed ranks. Use context.is_world_process_zero + or context.is_local_process_zero to gate rank-specific side effects + (logging, saving, external API calls). + """ + + def on_train_begin(self, context: TrainingContext) -> None: + pass + + def on_epoch_begin(self, context: TrainingContext) -> None: + pass + + def on_step_begin(self, context: TrainingContext) -> None: + pass + + def on_before_forward(self, context: TrainingContext) -> None: + pass + + def on_after_backward(self, context: TrainingContext) -> None: + pass + + def on_pre_optimizer_step(self, context: TrainingContext) -> None: + pass + + def on_optimizer_step(self, context: TrainingContext) -> None: + pass + + def on_log(self, context: TrainingContext) -> None: + pass + + def on_evaluate(self, context: TrainingContext) -> None: + pass + + def on_save(self, context: TrainingContext) -> None: + pass + + def on_step_end(self, context: TrainingContext) -> None: + pass + + def on_epoch_end(self, context: TrainingContext) -> None: + pass + + def on_train_end(self, context: TrainingContext) -> None: + pass + + +class CallbackManager: + """Dispatches lifecycle hooks to registered TrainerCallback instances.""" + + def __init__(self): + self._callbacks: list[TrainerCallback] = [] + self.context = TrainingContext() + + self._loop = asyncio.new_event_loop() + self._thread = threading.Thread(target=self._run_event_loop, daemon=True) + self._thread.start() + + def _run_event_loop(self): + asyncio.set_event_loop(self._loop) + self._loop.run_forever() + + def add_callback(self, callback: TrainerCallback) -> None: + if not isinstance(callback, TrainerCallback): + raise TypeError( + f"Expected a TrainerCallback instance, got " + f"{type(callback).__name__}. " + f"Pass an instance, not a class: callbacks=[MyCallback()]" + ) + self._callbacks.append(callback) + + def remove_callback(self, callback_or_type) -> None: + if isinstance(callback_or_type, type): + self._callbacks = [ + cb for cb in self._callbacks if not isinstance(cb, callback_or_type) + ] + else: + self._callbacks = [ + cb for cb in self._callbacks if cb is not callback_or_type + ] + + def fire(self, hook_name: str, **kwargs) -> None: + if not self.has_callbacks(hook_name): + return + + snapshot = copy.copy(self.context) + snapshot.hook_name = hook_name + snapshot.batch_metrics = dict(snapshot.batch_metrics) + snapshot.val_metrics = dict(snapshot.val_metrics) + _valid_fields = {f.name for f in snapshot.__dataclass_fields__.values()} + for key, value in kwargs.items(): + if key not in _valid_fields: + raise ValueError( + f"Unknown TrainingContext field: '{key}'. Valid fields: {sorted(_valid_fields)}" + ) + setattr(snapshot, key, value) + + for callback in self._callbacks: + method = getattr(callback, hook_name) + if getattr(type(callback), hook_name) is getattr( + TrainerCallback, hook_name + ): + continue + future = asyncio.run_coroutine_threadsafe( + self._safe_invoke(method, snapshot), self._loop + ) + if hook_name == "on_train_end": + try: + future.result(timeout=10) + except TimeoutError: + logger.warning( + "Callback %s.%s timed out during on_train_end (10s limit).", + type(callback).__name__, + hook_name, + ) + except Exception: + pass + + async def _safe_invoke(self, method, context: TrainingContext) -> None: + try: + result = method(context) + if asyncio.iscoroutine(result): + await result + except Exception: + logger.exception( + "Callback %s.%s raised an exception (hook=%s, step=%d). " + "This exception is suppressed and will not affect training.", + type(method.__self__).__name__ + if hasattr(method, "__self__") + else repr(method), + method.__name__, + context.hook_name, + context.step, + ) + + def has_callbacks(self, hook_name: str) -> bool: + base_method = getattr(TrainerCallback, hook_name) + return any( + getattr(type(cb), hook_name) is not base_method for cb in self._callbacks + ) + + def close(self) -> None: + """Shut down the background event loop and thread.""" + self._loop.call_soon_threadsafe(self._loop.stop) + self._thread.join(timeout=5) + self._loop.close() + + +def serialize_callback(callback: TrainerCallback) -> str: + """Serialize a TrainerCallback subclass to a base64 string. + + Callbacks must be self-contained classes with zero-argument constructors. + Any imports needed inside hooks should be inline (inside the method body), + not at module level. + """ + cls = type(callback) + try: + cls() + except TypeError as e: + raise TypeError( + f"Callback {cls.__name__} must have a zero-argument constructor " + f"to be serializable across the torchrun subprocess boundary: {e}" + ) from e + source = inspect.getsource(cls) + source = textwrap.dedent(source) + return base64.b64encode(source.encode("utf-8")).decode("ascii") + + +def deserialize_callback(encoded: str) -> TrainerCallback: + """Reconstruct a TrainerCallback instance from a base64-encoded class source.""" + source = base64.b64decode(encoded).decode("utf-8") + namespace: dict[str, Any] = { + "TrainerCallback": TrainerCallback, + "TrainingContext": TrainingContext, + } + exec(source, namespace) # noqa: S102 + classes = [ + v + for v in namespace.values() + if isinstance(v, type) + and issubclass(v, TrainerCallback) + and v is not TrainerCallback + ] + if len(classes) != 1: + raise ValueError( + f"Expected exactly one TrainerCallback subclass, " + f"got {len(classes)}. Source:\n{source}" + ) + return classes[0]() + + +def serialize_callbacks_for_cli( + callbacks: list[TrainerCallback], +) -> str: + """Serialize a list of callbacks to a base64 string for CLI transport.""" + serialized = [serialize_callback(cb) for cb in callbacks] + return base64.b64encode(json.dumps(serialized).encode("utf-8")).decode("ascii") + + +def deserialize_callbacks_from_cli( + encoded: str, +) -> list[TrainerCallback]: + """Reconstruct TrainerCallback instances from a CLI-transported base64 string.""" + decoded = json.loads(base64.b64decode(encoded).decode("utf-8")) + return [deserialize_callback(s) for s in decoded] diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py index 8811ead0..ba190496 100644 --- a/src/instructlab/training/config.py +++ b/src/instructlab/training/config.py @@ -191,7 +191,7 @@ class TrainingArgs(BaseModel): """ # disable the protected namespace for the model_config field - model_config = ConfigDict(protected_namespaces=()) + model_config = ConfigDict(protected_namespaces=(), arbitrary_types_allowed=True) # Either the name of a HuggingFace model or a path to a model saved in HuggingFace format. model_path: str @@ -374,6 +374,12 @@ class TrainingArgs(BaseModel): ), ) + callbacks: list | None = Field( + default=None, + exclude=True, + description="List of TrainerCallback instances for training lifecycle hooks.", + ) + @model_validator(mode="after") def validate_validation_config(self): if not 0.0 <= self.validation_split < 1.0: diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 08ee3f02..6d7278a1 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -174,6 +174,7 @@ def train( val_data_loader=None, validation_frequency=None, on_demand_checkpointing: bool = False, + callback_manager=None, ): model.train() @@ -228,13 +229,25 @@ def _save_and_exit(checkpoint_location: str) -> None: global_grad_norm = None # Initialize the batch loss manager - batch_loss_manager = BatchLossManager(model, accelerator, world_size, local_rank) + batch_loss_manager = BatchLossManager( + model, accelerator, world_size, local_rank, callback_manager=callback_manager + ) + + if callback_manager: + callback_manager.context.step = global_step + callback_manager.context.total_samples = samples_seen + callback_manager.fire("on_train_begin") # Blast through batches for epoch in range(args.current_epoch, args.num_epochs): # set the epoch for correct sampling accelerator.train_loader.sampler.set_epoch(epoch) num_epoch_steps = len(accelerator.train_loader) + + if callback_manager: + callback_manager.context.epoch = epoch + callback_manager.fire("on_epoch_begin") + if local_rank == 0: inner_pb = tqdm(range(num_epoch_steps), desc=f"Epoch {epoch}") @@ -248,6 +261,10 @@ def _save_and_exit(checkpoint_location: str) -> None: continue start = time.time() + if callback_manager: + callback_manager.context.step = global_step + callback_manager.fire("on_step_begin") + # Process the batch using the BatchLossManager. # When on-demand checkpointing is enabled, pass a callback so # the check runs after every minibatch backward rather than @@ -265,24 +282,40 @@ def _save_and_exit(checkpoint_location: str) -> None: # exact resumption. if batch_metrics.interrupted: _save_and_exit("during minibatch processing") + if callback_manager: + callback_manager.fire("on_train_end") + callback_manager.close() return if on_demand_checkpointing and check_checkpoint_requested(): _save_and_exit("before optimizer step") + if callback_manager: + callback_manager.fire("on_train_end") + callback_manager.close() return base_logger.info( f"Epoch: {epoch}, Step: {global_step}, Rank: {dist.get_rank()}, loss = {avg_loss_across_ranks:.6f}, grad_accum_steps = {batch_metrics.grad_accum_steps}" ) + if callback_manager: + callback_manager.context.loss = float(avg_loss_across_ranks) + callback_manager.fire("on_pre_optimizer_step") + # Take optimizer step after all minibatches accelerator.take_optimizer_step() + if callback_manager: + callback_manager.fire("on_optimizer_step") + # Update samples seen after the optimizer step has been applied samples_seen += batch_metrics.total_samples if on_demand_checkpointing and check_checkpoint_requested(): _save_and_exit("after optimizer step") + if callback_manager: + callback_manager.fire("on_train_end") + callback_manager.close() return if local_rank == 0: @@ -328,6 +361,23 @@ def _save_and_exit(checkpoint_location: str) -> None: extra={"step": global_step}, ) + if callback_manager: + callback_manager.context.learning_rate = current_lr + callback_manager.context.grad_norm = global_grad_norm + callback_manager.context.elapsed_time = elapsed_time + callback_manager.context.overall_throughput = overall_throughput + callback_manager.context.cuda_mem_allocated = cuda_mem_allocated + callback_manager.context.total_samples = samples_seen + callback_manager.context.total_tokens = batch_metrics.total_length + callback_manager.context.batch_metrics = { + "total_samples": batch_metrics.total_samples, + "total_length": batch_metrics.total_length, + "num_loss_counted_tokens": batch_metrics.num_loss_counted_tokens, + "grad_accum_steps": batch_metrics.grad_accum_steps, + "num_minibatches": batch_metrics.num_minibatches, + } + callback_manager.fire("on_log") + # Compute validation loss if it's time to validate if ( val_data_loader is not None @@ -343,6 +393,9 @@ def _save_and_exit(checkpoint_location: str) -> None: val_metrics, extra={"step": global_step}, ) + if callback_manager and val_metrics: + callback_manager.context.val_metrics = dict(val_metrics) + callback_manager.fire("on_evaluate") if args.save_samples > 0 and (samples_seen % args.save_samples == 0): base_logger.debug(f"Saving checkpoint at step {global_step}") @@ -357,11 +410,18 @@ def _save_and_exit(checkpoint_location: str) -> None: ) base_logger.debug("RANK (%d) waiting at post-save barrier.", local_rank) dist.barrier() + if callback_manager: + callback_manager.fire("on_save", checkpoint_path=args.output_dir) global_step += 1 if local_rank == 0: inner_pb.update(1) torch.cuda.empty_cache() + + if callback_manager: + callback_manager.context.step = global_step + callback_manager.fire("on_step_end") + if args.checkpoint_at_epoch: base_logger.debug(f"Saving checkpoint at epoch {epoch}") save_checkpoint( @@ -377,6 +437,11 @@ def _save_and_exit(checkpoint_location: str) -> None: ) base_logger.debug("RANK (%d) waiting at post-save barrier.", local_rank) dist.barrier() + if callback_manager: + callback_manager.fire("on_save", checkpoint_path=args.output_dir) + + if callback_manager: + callback_manager.fire("on_epoch_end") if args.save_last: save_hf_format_accelerate( @@ -387,6 +452,12 @@ def _save_and_exit(checkpoint_location: str) -> None: samples_seen, is_lora=bool(args.lora_r), ) + if callback_manager: + callback_manager.fire("on_save", checkpoint_path=args.output_dir) + + if callback_manager: + callback_manager.fire("on_train_end") + callback_manager.close() # This function makes an effort to stick to a default value from torch library, @@ -616,6 +687,28 @@ def main(args): load_latest_full_state(args=args, accelerator=accelerator) + # Deserialize callbacks if passed via CLI + callback_manager = None + if getattr(args, "callbacks", None): + # First Party + from instructlab.training.callbacks import ( + CallbackManager, + deserialize_callbacks_from_cli, + ) + + callback_manager = CallbackManager() + for cb in deserialize_callbacks_from_cli(args.callbacks): + callback_manager.add_callback(cb) + + callback_manager.context.output_dir = args.output_dir + callback_manager.context.model_name_or_path = args.model_name_or_path + callback_manager.context.max_epochs = args.num_epochs + callback_manager.context.world_size = int(os.environ.get("WORLD_SIZE", "1")) + callback_manager.context.is_local_process_zero = ( + int(os.environ["LOCAL_RANK"]) == 0 + ) + callback_manager.context.is_world_process_zero = dist.get_rank() == 0 + train( args, model=m, @@ -623,6 +716,7 @@ def main(args): val_data_loader=val_loader, validation_frequency=validation_frequency, on_demand_checkpointing=getattr(args, "on_demand_checkpointing", False), + callback_manager=callback_manager, ) dist.barrier() @@ -863,6 +957,14 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: if train_args.on_demand_checkpointing: command.append("--on_demand_checkpointing") + if train_args.callbacks: + # First Party + from instructlab.training.callbacks import serialize_callbacks_for_cli + + command.append( + f"--callbacks={serialize_callbacks_for_cli(train_args.callbacks)}" + ) + logger.info("Running training command as subprocess: %s", " ".join(command)) # --- On-demand checkpointing: install signal handlers in the parent --- @@ -1245,6 +1347,12 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: default=None, help="How often to evaluate validation loss (in training steps). Required when validation_split > 0.", ) + parser.add_argument( + "--callbacks", + type=str, + default=None, + help="Base64-encoded serialized callbacks (internal use, set via TrainingArgs).", + ) args = parser.parse_args() if args.validation_split > 0.0 and ( diff --git a/tests/unit/test_callbacks.py b/tests/unit/test_callbacks.py new file mode 100644 index 00000000..27ee4433 --- /dev/null +++ b/tests/unit/test_callbacks.py @@ -0,0 +1,459 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the callback system.""" + +# Standard +import time + +# Third Party +import pytest + +# First Party +from instructlab.training.callbacks import ( + HOOK_NAMES, + CallbackManager, + TrainerCallback, + TrainingContext, + deserialize_callbacks_from_cli, + serialize_callbacks_for_cli, +) + + +@pytest.fixture +def mgr(): + """Create a CallbackManager and close it after the test.""" + m = CallbackManager() + yield m + m.close() + + +class TestTrainingContext: + def test_defaults(self): + ctx = TrainingContext() + assert ctx.step == 0 + assert ctx.epoch == 0 + assert ctx.loss is None + assert ctx.batch_metrics == {} + assert ctx.val_metrics == {} + assert ctx.is_world_process_zero is True + + def test_field_assignment(self): + ctx = TrainingContext(step=10, epoch=2, loss=0.5) + assert ctx.step == 10 + assert ctx.epoch == 2 + assert ctx.loss == 0.5 + + +class TestTrainerCallback: + def test_all_hooks_are_noop(self): + cb = TrainerCallback() + ctx = TrainingContext() + for hook in HOOK_NAMES: + getattr(cb, hook)(ctx) + + def test_subclass_override(self): + class MyCallback(TrainerCallback): + def __init__(self): + self.called = False + + def on_train_begin(self, context): + self.called = True + + cb = MyCallback() + cb.on_train_begin(TrainingContext()) + assert cb.called + + +class TestCallbackManager: + def test_fire_dispatches(self, mgr): + results = [] + + class Recorder(TrainerCallback): + def on_step_begin(self, context): + results.append(("on_step_begin", context.step)) + + mgr.add_callback(Recorder()) + mgr.context.step = 5 + mgr.fire("on_step_begin") + time.sleep(0.1) + assert results == [("on_step_begin", 5)] + + def test_fire_skips_non_overridden(self, mgr): + results = [] + + class Partial(TrainerCallback): + def on_log(self, context): + results.append("on_log") + + mgr.add_callback(Partial()) + mgr.fire("on_step_begin") + mgr.fire("on_log") + time.sleep(0.1) + assert results == ["on_log"] + + def test_has_callbacks(self, mgr): + class OnlyLog(TrainerCallback): + def on_log(self, context): + pass + + mgr.add_callback(OnlyLog()) + assert mgr.has_callbacks("on_log") is True + assert mgr.has_callbacks("on_save") is False + + def test_snapshot_isolation(self, mgr): + captured = [] + + class Capture(TrainerCallback): + def on_step_begin(self, context): + captured.append(context.step) + + mgr.add_callback(Capture()) + mgr.context.step = 1 + mgr.fire("on_step_begin") + mgr.context.step = 999 + time.sleep(0.1) + assert captured == [1] + + def test_exception_isolation(self, mgr): + class Broken(TrainerCallback): + def on_train_begin(self, context): + raise RuntimeError("boom") + + mgr.add_callback(Broken()) + mgr.fire("on_train_begin") + time.sleep(0.1) + + def test_multiple_callbacks(self, mgr): + results = [] + + class A(TrainerCallback): + def on_save(self, context): + results.append("A") + + class B(TrainerCallback): + def on_save(self, context): + results.append("B") + + mgr.add_callback(A()) + mgr.add_callback(B()) + mgr.fire("on_save") + time.sleep(0.1) + assert sorted(results) == ["A", "B"] + + def test_kwargs_set_on_snapshot(self, mgr): + captured = [] + + class SaveCb(TrainerCallback): + def on_save(self, context): + captured.append(context.checkpoint_path) + + mgr.add_callback(SaveCb()) + mgr.fire("on_save", checkpoint_path="/tmp/ckpt") + time.sleep(0.1) + assert captured == ["/tmp/ckpt"] + + def test_add_callback_type_error(self, mgr): + with pytest.raises(TypeError, match="TrainerCallback instance"): + mgr.add_callback("not a callback") + + def test_remove_callback_by_instance(self, mgr): + class Dummy(TrainerCallback): + def on_log(self, context): + pass + + cb = Dummy() + mgr.add_callback(cb) + assert mgr.has_callbacks("on_log") + mgr.remove_callback(cb) + assert not mgr.has_callbacks("on_log") + + def test_remove_callback_by_type(self, mgr): + class Dummy(TrainerCallback): + def on_log(self, context): + pass + + mgr.add_callback(Dummy()) + mgr.remove_callback(Dummy) + assert not mgr.has_callbacks("on_log") + + def test_fire_all_ranks(self, mgr): + results = [] + + class RankCb(TrainerCallback): + def on_log(self, context): + results.append(context.is_world_process_zero) + + mgr.add_callback(RankCb()) + mgr.context.is_world_process_zero = False + mgr.fire("on_log") + time.sleep(0.1) + assert results == [False] + + def test_on_train_end_blocks(self, mgr): + called = [] + + class SlowCb(TrainerCallback): + def on_train_end(self, context): + called.append(True) + + mgr.add_callback(SlowCb()) + mgr.fire("on_train_end") + assert called == [True] + + def test_fire_invalid_kwarg_raises(self, mgr): + class Dummy(TrainerCallback): + def on_save(self, context): + pass + + mgr.add_callback(Dummy()) + with pytest.raises(ValueError, match="Unknown TrainingContext field"): + mgr.fire("on_save", nonexistent_field="bad") + + def test_close(self): + m = CallbackManager() + assert m._thread.is_alive() + m.close() + assert not m._thread.is_alive() + + def test_empty_manager_no_callbacks(self, mgr): + assert mgr.has_callbacks("on_log") is False + mgr.fire("on_log") + + def test_hook_name_set_on_snapshot(self, mgr): + captured = [] + + class HookNameCb(TrainerCallback): + def on_step_begin(self, context): + captured.append(context.hook_name) + + mgr.add_callback(HookNameCb()) + mgr.fire("on_step_begin") + time.sleep(0.1) + assert captured == ["on_step_begin"] + + def test_dict_fields_snapshot_isolation(self, mgr): + captured = [] + + class MetricsCb(TrainerCallback): + def on_log(self, context): + captured.append(context.batch_metrics) + + mgr.add_callback(MetricsCb()) + mgr.context.batch_metrics = {"loss": 1.0} + mgr.fire("on_log") + mgr.context.batch_metrics["loss"] = 999.0 + time.sleep(0.1) + assert captured == [{"loss": 1.0}] + + +class TestAllRanksAndUserAPI: + """Tests that callbacks fire on all ranks and users can configure rank behavior.""" + + def test_fires_on_non_zero_rank(self, mgr): + results = [] + + class AllRankCb(TrainerCallback): + def on_step_begin(self, context): + results.append(context.is_world_process_zero) + + mgr.add_callback(AllRankCb()) + mgr.context.is_world_process_zero = False + mgr.context.is_local_process_zero = False + mgr.fire("on_step_begin") + time.sleep(0.1) + assert results == [False] + + def test_fires_on_rank_zero(self, mgr): + results = [] + + class AllRankCb(TrainerCallback): + def on_step_begin(self, context): + results.append(context.is_world_process_zero) + + mgr.add_callback(AllRankCb()) + mgr.context.is_world_process_zero = True + mgr.fire("on_step_begin") + time.sleep(0.1) + assert results == [True] + + def test_user_can_gate_on_rank_zero(self, mgr): + results = [] + + class RankGatedCb(TrainerCallback): + def on_log(self, context): + if context.is_world_process_zero: + results.append("logged") + + mgr.add_callback(RankGatedCb()) + + mgr.context.is_world_process_zero = False + mgr.fire("on_log") + time.sleep(0.1) + assert results == [] + + mgr.context.is_world_process_zero = True + mgr.fire("on_log") + time.sleep(0.1) + assert results == ["logged"] + + def test_user_can_gate_on_local_rank_zero(self, mgr): + results = [] + + class LocalRankCb(TrainerCallback): + def on_save(self, context): + if context.is_local_process_zero: + results.append("local_rank_0") + + mgr.add_callback(LocalRankCb()) + + mgr.context.is_local_process_zero = False + mgr.fire("on_save") + time.sleep(0.1) + assert results == [] + + mgr.context.is_local_process_zero = True + mgr.fire("on_save") + time.sleep(0.1) + assert results == ["local_rank_0"] + + def test_both_rank_flags_exposed(self, mgr): + captured = {} + + class RankInfoCb(TrainerCallback): + def on_train_begin(self, context): + captured["world"] = context.is_world_process_zero + captured["local"] = context.is_local_process_zero + + mgr.add_callback(RankInfoCb()) + mgr.context.is_world_process_zero = False + mgr.context.is_local_process_zero = True + mgr.fire("on_train_begin") + time.sleep(0.1) + assert captured == {"world": False, "local": True} + + def test_all_13_hooks_fire(self, mgr): + fired = [] + + class AllHooksCb(TrainerCallback): + def on_train_begin(self, context): + fired.append("on_train_begin") + + def on_epoch_begin(self, context): + fired.append("on_epoch_begin") + + def on_step_begin(self, context): + fired.append("on_step_begin") + + def on_before_forward(self, context): + fired.append("on_before_forward") + + def on_after_backward(self, context): + fired.append("on_after_backward") + + def on_pre_optimizer_step(self, context): + fired.append("on_pre_optimizer_step") + + def on_optimizer_step(self, context): + fired.append("on_optimizer_step") + + def on_log(self, context): + fired.append("on_log") + + def on_evaluate(self, context): + fired.append("on_evaluate") + + def on_save(self, context): + fired.append("on_save") + + def on_step_end(self, context): + fired.append("on_step_end") + + def on_epoch_end(self, context): + fired.append("on_epoch_end") + + def on_train_end(self, context): + fired.append("on_train_end") + + mgr.add_callback(AllHooksCb()) + for hook in HOOK_NAMES: + mgr.fire(hook) + if hook != "on_train_end": + time.sleep(0.05) + assert sorted(fired) == sorted(HOOK_NAMES) + + def test_public_import_from_package(self): + from instructlab.training import TrainerCallback, TrainingContext + + assert TrainerCallback is not None + assert TrainingContext is not None + + cb = TrainerCallback() + ctx = TrainingContext() + cb.on_log(ctx) + + def test_training_args_accepts_callbacks(self): + from instructlab.training import TrainingArgs + + assert "callbacks" in TrainingArgs.model_fields + + def test_context_has_training_config_fields(self): + ctx = TrainingContext( + output_dir="/tmp/output", + model_name_or_path="my-model", + max_epochs=3, + world_size=4, + ) + assert ctx.output_dir == "/tmp/output" + assert ctx.model_name_or_path == "my-model" + assert ctx.max_epochs == 3 + assert ctx.world_size == 4 + + +class TestSerialization: + def test_round_trip(self): + class TestCallback(TrainerCallback): + def on_log(self, context): + pass + + callbacks = [TestCallback()] + encoded = serialize_callbacks_for_cli(callbacks) + restored = deserialize_callbacks_from_cli(encoded) + assert len(restored) == 1 + assert isinstance(restored[0], TrainerCallback) + assert type(restored[0]).__name__ == "TestCallback" + + def test_round_trip_preserves_behavior(self): + class Adder(TrainerCallback): + def on_log(self, context): + context.loss = 42.0 + + encoded = serialize_callbacks_for_cli([Adder()]) + restored = deserialize_callbacks_from_cli(encoded) + ctx = TrainingContext() + restored[0].on_log(ctx) + assert ctx.loss == 42.0 + + def test_multiple_callbacks_round_trip(self): + class First(TrainerCallback): + def on_save(self, context): + pass + + class Second(TrainerCallback): + def on_log(self, context): + pass + + encoded = serialize_callbacks_for_cli([First(), Second()]) + restored = deserialize_callbacks_from_cli(encoded) + assert len(restored) == 2 + assert type(restored[0]).__name__ == "First" + assert type(restored[1]).__name__ == "Second" + + def test_non_zero_arg_constructor_raises(self): + class BadCallback(TrainerCallback): + def __init__(self, url): + self.url = url + + def on_log(self, context): + pass + + with pytest.raises(TypeError, match="zero-argument constructor"): + serialize_callbacks_for_cli([BadCallback("http://example.com")]) From 9181b5c3500cae5d5d334810110ff6d1edefe310 Mon Sep 17 00:00:00 2001 From: Hari Haran Rathinakumar Date: Wed, 24 Jun 2026 08:48:36 +0100 Subject: [PATCH 2/3] fix: ensure CallbackManager is closed in test_close --- tests/unit/test_callbacks.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_callbacks.py b/tests/unit/test_callbacks.py index 27ee4433..ffe91100 100644 --- a/tests/unit/test_callbacks.py +++ b/tests/unit/test_callbacks.py @@ -210,8 +210,10 @@ def on_save(self, context): def test_close(self): m = CallbackManager() - assert m._thread.is_alive() - m.close() + try: + assert m._thread.is_alive() + finally: + m.close() assert not m._thread.is_alive() def test_empty_manager_no_callbacks(self, mgr): From 396cd7025950646eb3369c4d560297ea5749652f Mon Sep 17 00:00:00 2001 From: Hari Haran Rathinakumar Date: Thu, 25 Jun 2026 08:42:27 +0100 Subject: [PATCH 3/3] lint fix --- tests/unit/test_callbacks.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unit/test_callbacks.py b/tests/unit/test_callbacks.py index ffe91100..3ed0bc98 100644 --- a/tests/unit/test_callbacks.py +++ b/tests/unit/test_callbacks.py @@ -383,6 +383,7 @@ def on_train_end(self, context): assert sorted(fired) == sorted(HOOK_NAMES) def test_public_import_from_package(self): + # First Party from instructlab.training import TrainerCallback, TrainingContext assert TrainerCallback is not None @@ -393,6 +394,7 @@ def test_public_import_from_package(self): cb.on_log(ctx) def test_training_args_accepts_callbacks(self): + # First Party from instructlab.training import TrainingArgs assert "callbacks" in TrainingArgs.model_fields