Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/instructlab/training/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
"LoraOptions",
"QuantizeDataType",
"TorchrunArgs",
"TrainerCallback",
"TrainingArgs",
"TrainingContext",
"run_training",
"FSDPOptions",
"ShardingStrategies",
Expand All @@ -17,6 +19,7 @@
import instructlab.training.logger # Disable package logging by default

# Local
from .callbacks import TrainerCallback, TrainingContext
from .config import (
DataProcessArgs,
DeepSpeedOffloadStrategy,
Expand Down
17 changes: 16 additions & 1 deletion src/instructlab/training/batch_loss_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,14 @@ class BatchLossManager:
- Computing average losses for logging
"""

def __init__(self, model, accelerator, world_size: int, local_rank: int):
def __init__(
self,
model,
accelerator,
world_size: int,
local_rank: int,
callback_manager=None,
):
"""
Initialize the BatchLossManager.

Expand All @@ -57,12 +64,14 @@ def __init__(self, model, accelerator, world_size: int, local_rank: int):
accelerator: The accelerator instance for distributed training
world_size: Number of distributed processes
local_rank: Local rank of the current process
callback_manager: Optional CallbackManager for lifecycle hooks
"""
self.model: Model = model
self.accelerator: Accelerator = accelerator
self.world_size: int = world_size
self.local_rank: int = local_rank
self.torch_device = torch.device("cuda", local_rank)
self.callback_manager = callback_manager

def process_batch(
self,
Expand Down Expand Up @@ -111,6 +120,9 @@ def process_batch(
batch_total_samples += micro_batch_size
batch_total_length += total_length

if self.callback_manager:
self.callback_manager.fire("on_before_forward")

# prepare model inputs
model_inputs = self._prepare_model_inputs(mb)

Expand All @@ -126,6 +138,9 @@ def process_batch(

self.accelerator.backward(scaled_loss)

if self.callback_manager:
self.callback_manager.fire("on_after_backward")

# accumulate losses
grad_accum_steps += 1
accumulated_loss += raw_losses.main_loss
Expand Down
280 changes: 280 additions & 0 deletions src/instructlab/training/callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
# SPDX-License-Identifier: Apache-2.0

"""
Callback system for training lifecycle hooks.

Provides async, fire-and-forget callbacks that observe training events
without blocking the training loop or propagating exceptions.
"""

# Standard
from dataclasses import dataclass, field
from typing import Any
import asyncio
import base64
import copy
import inspect
import json
import logging
import textwrap
import threading

logger = logging.getLogger("instructlab.training")

HOOK_NAMES = [
"on_train_begin",
"on_epoch_begin",
"on_step_begin",
"on_before_forward",
"on_after_backward",
"on_pre_optimizer_step",
"on_optimizer_step",
"on_log",
"on_evaluate",
"on_save",
"on_step_end",
"on_epoch_end",
"on_train_end",
]


@dataclass
class TrainingContext:
"""Mutable training state maintained by the training loop.

The CallbackManager snapshots this before dispatching to callbacks,
so callback authors receive an effectively read-only view.
"""

hook_name: str = ""

step: int = 0
epoch: int = 0
total_samples: int = 0
total_tokens: int = 0

loss: float | None = None
learning_rate: float | None = None
grad_norm: float | None = None
elapsed_time: float | None = None
overall_throughput: float | None = None
cuda_mem_allocated: float | None = None

batch_metrics: dict[str, Any] = field(default_factory=dict)
val_metrics: dict[str, Any] = field(default_factory=dict)
checkpoint_path: str | None = None

output_dir: str = ""
model_name_or_path: str = ""
max_epochs: int = 0
world_size: int = 1
is_local_process_zero: bool = True
is_world_process_zero: bool = True


class TrainerCallback:
"""Base class for training callbacks. Subclass and override hooks you need.

Callbacks fire on ALL distributed ranks. Use context.is_world_process_zero
or context.is_local_process_zero to gate rank-specific side effects
(logging, saving, external API calls).
"""

def on_train_begin(self, context: TrainingContext) -> None:
pass

def on_epoch_begin(self, context: TrainingContext) -> None:
pass

def on_step_begin(self, context: TrainingContext) -> None:
pass

def on_before_forward(self, context: TrainingContext) -> None:
pass

def on_after_backward(self, context: TrainingContext) -> None:
pass

def on_pre_optimizer_step(self, context: TrainingContext) -> None:
pass

def on_optimizer_step(self, context: TrainingContext) -> None:
pass

def on_log(self, context: TrainingContext) -> None:
pass

def on_evaluate(self, context: TrainingContext) -> None:
pass

def on_save(self, context: TrainingContext) -> None:
pass

def on_step_end(self, context: TrainingContext) -> None:
pass

def on_epoch_end(self, context: TrainingContext) -> None:
pass

def on_train_end(self, context: TrainingContext) -> None:
pass


class CallbackManager:
"""Dispatches lifecycle hooks to registered TrainerCallback instances."""

def __init__(self):
self._callbacks: list[TrainerCallback] = []
self.context = TrainingContext()

self._loop = asyncio.new_event_loop()
self._thread = threading.Thread(target=self._run_event_loop, daemon=True)
self._thread.start()

def _run_event_loop(self):
asyncio.set_event_loop(self._loop)
self._loop.run_forever()

def add_callback(self, callback: TrainerCallback) -> None:
if not isinstance(callback, TrainerCallback):
raise TypeError(
f"Expected a TrainerCallback instance, got "
f"{type(callback).__name__}. "
f"Pass an instance, not a class: callbacks=[MyCallback()]"
)
self._callbacks.append(callback)

def remove_callback(self, callback_or_type) -> None:
if isinstance(callback_or_type, type):
self._callbacks = [
cb for cb in self._callbacks if not isinstance(cb, callback_or_type)
]
else:
self._callbacks = [
cb for cb in self._callbacks if cb is not callback_or_type
]

def fire(self, hook_name: str, **kwargs) -> None:
if not self.has_callbacks(hook_name):
return

snapshot = copy.copy(self.context)
snapshot.hook_name = hook_name
snapshot.batch_metrics = dict(snapshot.batch_metrics)
snapshot.val_metrics = dict(snapshot.val_metrics)
_valid_fields = {f.name for f in snapshot.__dataclass_fields__.values()}
for key, value in kwargs.items():
if key not in _valid_fields:
raise ValueError(
f"Unknown TrainingContext field: '{key}'. Valid fields: {sorted(_valid_fields)}"
)
setattr(snapshot, key, value)

for callback in self._callbacks:
method = getattr(callback, hook_name)
if getattr(type(callback), hook_name) is getattr(
TrainerCallback, hook_name
):
continue
future = asyncio.run_coroutine_threadsafe(
self._safe_invoke(method, snapshot), self._loop
)
if hook_name == "on_train_end":
try:
future.result(timeout=10)
except TimeoutError:
logger.warning(
"Callback %s.%s timed out during on_train_end (10s limit).",
type(callback).__name__,
hook_name,
)
except Exception:
pass

async def _safe_invoke(self, method, context: TrainingContext) -> None:
try:
result = method(context)
if asyncio.iscoroutine(result):
await result
except Exception:
logger.exception(
"Callback %s.%s raised an exception (hook=%s, step=%d). "
"This exception is suppressed and will not affect training.",
type(method.__self__).__name__
if hasattr(method, "__self__")
else repr(method),
method.__name__,
context.hook_name,
context.step,
)

def has_callbacks(self, hook_name: str) -> bool:
base_method = getattr(TrainerCallback, hook_name)
return any(
getattr(type(cb), hook_name) is not base_method for cb in self._callbacks
)

def close(self) -> None:
"""Shut down the background event loop and thread."""
self._loop.call_soon_threadsafe(self._loop.stop)
self._thread.join(timeout=5)
self._loop.close()


def serialize_callback(callback: TrainerCallback) -> str:
"""Serialize a TrainerCallback subclass to a base64 string.

Callbacks must be self-contained classes with zero-argument constructors.
Any imports needed inside hooks should be inline (inside the method body),
not at module level.
"""
cls = type(callback)
try:
cls()
except TypeError as e:
raise TypeError(
f"Callback {cls.__name__} must have a zero-argument constructor "
f"to be serializable across the torchrun subprocess boundary: {e}"
) from e
source = inspect.getsource(cls)
source = textwrap.dedent(source)
return base64.b64encode(source.encode("utf-8")).decode("ascii")


def deserialize_callback(encoded: str) -> TrainerCallback:
"""Reconstruct a TrainerCallback instance from a base64-encoded class source."""
source = base64.b64decode(encoded).decode("utf-8")
namespace: dict[str, Any] = {
"TrainerCallback": TrainerCallback,
"TrainingContext": TrainingContext,
}
exec(source, namespace) # noqa: S102
classes = [
v
for v in namespace.values()
if isinstance(v, type)
and issubclass(v, TrainerCallback)
and v is not TrainerCallback
]
if len(classes) != 1:
raise ValueError(
f"Expected exactly one TrainerCallback subclass, "
f"got {len(classes)}. Source:\n{source}"
)
return classes[0]()


def serialize_callbacks_for_cli(
callbacks: list[TrainerCallback],
) -> str:
"""Serialize a list of callbacks to a base64 string for CLI transport."""
serialized = [serialize_callback(cb) for cb in callbacks]
return base64.b64encode(json.dumps(serialized).encode("utf-8")).decode("ascii")


def deserialize_callbacks_from_cli(
encoded: str,
) -> list[TrainerCallback]:
"""Reconstruct TrainerCallback instances from a CLI-transported base64 string."""
decoded = json.loads(base64.b64decode(encoded).decode("utf-8"))
return [deserialize_callback(s) for s in decoded]
Comment thread
coderabbitai[bot] marked this conversation as resolved.
8 changes: 7 additions & 1 deletion src/instructlab/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ class TrainingArgs(BaseModel):
"""

# disable the protected namespace for the model_config field
model_config = ConfigDict(protected_namespaces=())
model_config = ConfigDict(protected_namespaces=(), arbitrary_types_allowed=True)

# Either the name of a HuggingFace model or a path to a model saved in HuggingFace format.
model_path: str
Expand Down Expand Up @@ -374,6 +374,12 @@ class TrainingArgs(BaseModel):
),
)

callbacks: list | None = Field(
default=None,
exclude=True,
description="List of TrainerCallback instances for training lifecycle hooks.",
)

@model_validator(mode="after")
def validate_validation_config(self):
if not 0.0 <= self.validation_split < 1.0:
Expand Down
Loading