85 changes: 85 additions & 0 deletions examples/puzzletron/README.md
@@ -130,6 +130,91 @@ hf auth login

A 30% GPU memory reduction leads to a nearly 5% regression in the token_accuracy_top_10 metric (0.898 vs. 0.942).

## Bypass Distillation (Local Knowledge Distillation)

Bypass distillation (also called Blockwise Local Distillation or BLD) is an **optional** pipeline stage that trains alternative transformer block configurations using per-block knowledge distillation from the teacher model. It significantly improves the quality of aggressively compressed models by producing better "puzzle pieces" for the MIP solver.

### When to use bypass

Bypass distillation is only necessary for **aggressive compression**. For mild pruning (e.g., reducing FFN intermediate size by less than 25%), weight-initialization-based pruning alone usually produces good results. Use bypass when:

- **Heavy FFN pruning**: the target `intermediate_size` is ≤ 1/8 of the teacher's width.
For example, on Llama-3.1-8B (teacher `intermediate_size=14336`), run bypass for sizes ≤ 1792.
For milder reductions (e.g., to 3072 = ~21%), bypass improves quality but may not be essential.
- **KV head compression**: `num_key_value_heads` is significantly reduced
  (e.g., from 8 to 2 or fewer). The AverageKV initialization provides a good starting point,
  but bypass distillation recovers additional accuracy.
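The rule of thumb above can be captured in a short helper (a sketch only; `bypass_recommended` is a hypothetical function, not part of the pipeline):

```python
def bypass_recommended(teacher_ffn: int, target_ffn: int,
                       teacher_kv_heads: int, target_kv_heads: int) -> bool:
    """Heuristic from the guidance above: recommend bypass distillation for
    heavy FFN pruning (target width <= 1/8 of the teacher's) or heavy KV
    head compression (down to 2 heads or fewer)."""
    heavy_ffn = target_ffn <= teacher_ffn / 8
    heavy_kv = target_kv_heads < teacher_kv_heads and target_kv_heads <= 2
    return heavy_ffn or heavy_kv

# Llama-3.1-8B teacher: intermediate_size=14336, num_key_value_heads=8
print(bypass_recommended(14336, 1792, 8, 8))   # True: 1792 == 14336 / 8
print(bypass_recommended(14336, 3072, 8, 8))   # False: milder reduction
print(bypass_recommended(14336, 14336, 8, 2))  # True: 8 -> 2 KV heads
```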

### Time cost

Bypass distillation is a full training loop — plan for several hours per configuration when
using ~1B training tokens on H100 GPUs. Total time scales with `len(bypass.configs) × training_tokens`.
This is comparable to lightweight fine-tuning.
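As a back-of-the-envelope estimate (the throughput default below is an illustrative assumption, not a measured number; substitute your own tokens-per-hour figure):

```python
def estimated_hours(num_configs: int, tokens_per_config: float,
                    tokens_per_hour: float = 2.5e8) -> float:
    """Wall-clock estimate: total time scales with
    len(bypass.configs) * training_tokens. The default throughput of
    2.5e8 tokens/hour is an assumed placeholder for an H100 node."""
    return num_configs * tokens_per_config / tokens_per_hour

print(estimated_hours(2, 1e9))  # 8.0 hours for two 1B-token configs
```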

### Sequential execution

Each entry in `bypass.configs` trains **sequentially** (one config at a time). There is no
parallelism across configurations — if you have 3 configs, they run one after the other within
a single pipeline invocation. Distribute across different jobs if time is a constraint.
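One way to distribute the work is to split `bypass.configs` across separate config files and launch each as its own job (a sketch; the file names are illustrative):

```yaml
# job_a.yaml — first configuration only
bypass:
  configs:
    - model_config_overrides:
        ffn:
          - intermediate_size: 1792
        attention:
          - num_key_value_heads: 8
      keys_to_learn: subblock_ffn
---
# job_b.yaml — remaining configuration(s)
bypass:
  configs:
    - model_config_overrides:
        ffn:
          - intermediate_size: 3584
        attention:
          - num_key_value_heads: 8
      keys_to_learn: subblock_ffn
```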

### Configuration

Add a `bypass` section to your config YAML (or include `bypass/defaults.yaml` via Hydra defaults).
Key parameters:

| Parameter | Description | Default |
|---|---|---|
| `training.learning_rate` | Initial learning rate | `1e-4` |
| `training.training_tokens` | Total training tokens per config | `1e+9` (1B) |
| `training.micro_batch_size` | Batch size per step | `2` |
| `data.block_size` | Sequence length | `512` |
| `model_factory.gqa_init_mode` | KV head init strategy (`AverageKV`, `RandomKV`) | `AverageKV` |
| `model_factory.mlp_init_mode` | FFN init strategy (`Truncate`, `PruneByActivationsLog`) | `Truncate` |
| `model_factory.keys_to_learn` | Which params to train (`subblock_ffn`, `subblock_attention`, `entire_block`) | computed |
| `configs` | List of configurations to train sequentially | — |
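With the defaults above, the optimizer step count per configuration follows directly from the token budget (a sketch assuming `grad_accumulation_steps: 1` and a single data-parallel rank; real counts divide by the data-parallel world size):

```python
# Step-count arithmetic implied by the default hyperparameters.
training_tokens = 1e9        # training.training_tokens
block_size = 512             # data.block_size
micro_batch_size = 2         # training.micro_batch_size

tokens_per_step = block_size * micro_batch_size  # 1024 tokens per step
steps = int(training_tokens // tokens_per_step)
print(steps)  # 976562
```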

### Training multiple configurations

Use `bypass.configs` to train multiple block configurations in a single run. Each entry
overrides `model.model_config_overrides` and optionally `model_factory.keys_to_learn`:

```yaml
bypass:
training:
training_tokens: 1e+9 # ~1B tokens per config
configs:
- model_config_overrides:
ffn:
- intermediate_size: 1792 # ~1/8 of 14336 — bypass strongly recommended
attention:
- num_key_value_heads: 8
keys_to_learn: subblock_ffn
- model_config_overrides:
ffn:
- intermediate_size: 3584 # ~1/4 of 14336 — bypass optional but helpful
attention:
- num_key_value_heads: 8
keys_to_learn: subblock_ffn
```

Trained checkpoints are automatically symlinked into `$PUZZLE_DIR/ckpts/`, where the replacement
library builder picks them up in the next pipeline stage.

### Weights & Biases logging

Enable W&B to track per-block distillation loss and validation metrics during training:

```yaml
bypass:
wandb_log: true
wandb:
project: my-puzzletron-project
entity: my-org
```

W&B logs iteration number, token count, learning rate, and per-block loss at each log interval.
If `wandb` is not installed, logging is silently disabled and training continues normally.

## Re-run MIP Search with different constraints

If you want to try different constraints without re-running the expensive pruning and scoring steps, use the `--mip-only` flag.
@@ -2,7 +2,7 @@ defaults:
- pruning: ffn_pruning
- scoring: ../validate_solutions_defaults
- realize_model: ../validate_solutions_defaults
- bypass:
- bypass: defaults # comment out to run without bypass
- override hydra/hydra_logging: disabled
- _self_

@@ -0,0 +1,130 @@
# @package bypass
# Bypass Distillation Configuration
# This config defines parameters for blockwise local distillation (BLD),
# which trains alternative transformer block configurations using per-block
# knowledge distillation from a teacher model.

# Runtime Configuration
dtype: "bf16" # Model precision: bf16 for efficiency, fp32 for stability
seed: 42 # Random seed for reproducibility

# Experiment Tracking
experiment_id: # Unique identifier for this experiment. Will be dynamically set
experiment_dir: # Directory for this experiment. Will be dynamically set
iter_num: 1 # Current iteration number
step_num: 1 # Current step number within iteration
token_count: 0 # Token count tracker (auto-updated during training)

# Data Configuration
data:
data_column: "messages"
block_size: 512 # Sequence length (tokens per sample)
bos_rate: 0.5
fim_rate: 0
fim_spm_rate: 0
source_datasets_to_discard: []
load_from_disk: true # Load preprocessed data from disk or from stream
keep_in_memory: false
val_dataset_name: valid
max_eval_samples: 4
eval_samples_per_process: null # Samples per GPU during distributed eval (auto if null)
shuffle_train_data_seed: ${random_int:0,9999} # Seed for shuffling train data

# Training Configuration
training:
learning_rate: 1e-4 # Initial learning rate (1e-4 = 0.0001)
training_tokens: 1e+4 # Total training tokens (10K tokens - sanity check)
micro_batch_size: 2
val_micro_batch_size: 1
warmup_ratio: 0.05
warmup_steps: ${warmup_steps:${.training_tokens},${..data.block_size},${.micro_batch_size},${.warmup_ratio}} # Auto-calculated warmup steps
min_lr_factor: 1e-5
grad_accumulation_steps: 1
  skip_first_batches: 0 # For debugging, or to skip a few batches that cause crashes or optimization issues.
weight_decay: 0.1
decay_lr: true
beta1: 0.9
beta2: 0.95
use_grad_scaling: false
grad_clip: 1.0
grad_clip_type: norm
clipping_count: 0
log_interval: 5
eval_interval: 5

# Model Loading Configuration
resume_checkpoint_path: null # Path to resume training from checkpoint
find_last_ckpt_for_resume: true # Auto-resume by finding the last checkpoint
parameter_count: null
init_checkpoint_path: null # Path to initialize weights from

model:
student_weights_dtype: "bf16" # Student model weight precision

model_overrides:
delete_old_checkpoints: true # Clean up old checkpoints to save disk space
save_interval_seconds: 12900 # Save checkpoint every ~3.5 hours
save_interval: 1e+9 # Save checkpoint every 1B steps (effectively disabled)
save_checkpoint_when_done: true # Save final checkpoint when training completes

# Architecture modifications for student model
model_config_overrides:
ffn:
- intermediate_size:
no_op: # Disable FFN entirely (true/false)
attention:
- num_key_value_heads: # Number of kv-heads (for GQA)
no_op: # Disable attention entirely (true/false)

# Model Factory Configuration - Controls student model creation and initialization
model_factory:
factory: bypass_factory_fn # Unified factory supporting all layer types
  block_loss_func: normalized_mse_loss # Loss function for comparing teacher/student blocks. Options: normalized_mse_loss, vectorwise_normalized_mse_loss, batched_normalized_mse_loss
gqa_init_mode: AverageKV # How to initialize K/V heads in GQA. All options here: GQAInitMode
mlp_init_mode: Truncate # MLP initialization. All options here: MlpInitMode
mlp_init_config: # Configuration for MLP initialization (if needed)
activations_log_dir: null # Directory with activation statistics (required for PruneByActivationsLog)
linear_init_mode: FromTeacher # How to initialize linear layers: FromTeacher, Random, etc.
submodule_for_loss_calculation: null # Specific submodule for loss calc.
keys_to_learn: null # What parameters to train. Either "entire_block", or specific submodules. Computed dynamically.

# Validation Configuration
disable_initial_validate: false
validate_teacher_model: true
validate_student_model: true
disable_validation: false # Enable validation to exercise all code paths
best_val_loss: 1e+9 # Track best validation loss achieved

# Performance Optimization
compile: false # Use PyTorch compilation
disable_fa2: false # Disable Flash Attention 2 (false = use FA2 if available)
teacher_model_load_on_cpu: false

# Checkpoint Management
save_checkpoint_before_training: false # Save initial checkpoint before training
disable_checkpoint_save: false # Disable all checkpoint saving
save_best_ckpt: true # Save checkpoint when validation improves
kill_after_first_save: false # Exit after first checkpoint save (for testing)
realize_best_or_latest: "best"

wandb_log: false
wandb:
project:
entity:

# Multiple bypass configurations to train sequentially.
# Each entry overrides model.model_config_overrides and optionally model_factory.keys_to_learn.
# If empty or absent, a single run uses the settings above.
configs:
- model_config_overrides:
ffn:
- intermediate_size: 3072
attention:
- num_key_value_heads: 8
keys_to_learn: subblock_ffn
- model_config_overrides:
ffn:
- intermediate_size: 5888
attention:
- num_key_value_heads: 8
keys_to_learn: subblock_ffn
9 changes: 6 additions & 3 deletions examples/puzzletron/main.py
@@ -41,6 +41,7 @@
import modelopt.torch.puzzletron.mip.sweep as sweep
import modelopt.torch.utils.distributed as dist
from modelopt.torch.puzzletron.nas.plugins.puzzletron_nas_plugin import PuzzletronModel
from modelopt.torch.puzzletron.nas.plugins.puzzletron_nas_plugin import _total_steps
from modelopt.torch.puzzletron.tools.hydra_utils import (
initialize_hydra_config_for_dir,
register_hydra_resolvers,
@@ -74,7 +75,6 @@ def run_full_puzzletron(hydra_config_path: str):
Args:
config_path: Path to the YAML configuration file
"""
mprint("Puzzletron Progress 1/8: starting puzzletron pipeline")
dist.setup(timeout=timedelta(10))

# Register Hydra custom resolvers (needed for config resolution)
@@ -84,12 +84,15 @@ def run_full_puzzletron(hydra_config_path: str):
hydra_config_dir = str(hydra_config_path.parent)
hydra_config_name = hydra_config_path.stem

# Load hydra config
# Load hydra config to determine total step count (bypass adds one step)
hydra_cfg = initialize_hydra_config_for_dir(
config_dir=hydra_config_dir,
config_name=hydra_config_name,
overrides=[],
)
N = _total_steps(hydra_cfg)

mprint(f"Puzzletron Progress 1/{N}: starting puzzletron pipeline")

# Convert model (convert from HF to DeciLM, score pruning activations,
# prune the model and save pruned checkpoints)
@@ -120,7 +123,7 @@ def run_full_puzzletron(hydra_config_path: str):
)

dist.cleanup()
mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)")
mprint(f"Puzzletron Progress {N}/{N}: puzzletron pipeline completed (multi-gpu)")


def run_mip_only(hydra_config_path: str):
@@ -160,6 +160,19 @@ def layer_name_predicates(num_layers: int) -> Dict[str, re.Pattern]:
"""
raise NotImplementedError

@staticmethod
def pruning_mixins() -> Dict[str, Any]:
"""Return available pruning mixins for bypass distillation.

Override in subclasses to provide model-specific pruning mixins, e.g.
``{"kv_heads": KVHeadsPruningMixIn(...), "experts_removal": ExpertRemovalPruningMixIn(...)}``.

Returns an empty dict by default so that descriptors that do not need
model-specific weight-slicing (e.g. Llama with standard FFN truncation)
can rely on the generic ``create_child_state_dict`` fallback path.
"""
return {}

@staticmethod
def uses_autocast() -> bool:
"""Whether this model supports torch.autocast.
@@ -34,6 +34,10 @@
ExpertRemovalLayerDescriptor,
ExpertRemovalPruningMixIn,
)
from modelopt.torch.puzzletron.pruning.kv_heads_pruning_mixin import (
KVHeadsLayerDescriptor,
KVHeadsPruningMixIn,
)
from modelopt.torch.puzzletron.pruning.pruning_mixin import PruningMixIn


@@ -52,6 +56,15 @@ def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]:
return matches


@dataclass
class NemotronHKVHeadsLayerDescriptor(KVHeadsLayerDescriptor):
o_proj_name: str = "mixer.o_proj"
attn_prefix_name: str = "backbone.layers.{layer_idx}.mixer"
qkvo_weight_names: List[str] = field(
default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
)


@dataclass
class NemotronHExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor):
target_name: str = "mixer.gate"
@@ -253,4 +266,5 @@ def build_attention_predicates() -> Dict[str, re.Pattern]:
def pruning_mixins() -> Dict[str, PruningMixIn]:
return {
"experts_removal": ExpertRemovalPruningMixIn(NemotronHExpertRemovalLayerDescriptor()),
"kv_heads": KVHeadsPruningMixIn(NemotronHKVHeadsLayerDescriptor()),
}
22 changes: 22 additions & 0 deletions modelopt/torch/puzzletron/bypass_distillation/__init__.py
@@ -0,0 +1,22 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Bypass distillation (blockwise local distillation) for the PUZZLE framework.

This module implements Stage 1 of the PUZZLE pipeline: training alternative transformer
block configurations using per-block knowledge distillation from a teacher model.
"""

from .training_loop import launch_bypass_distillation