Supports eagle3 training for Gemma3 27B and Gemma4 26B. #553

Open
pyc96 wants to merge 5 commits into sgl-project:main from pyc96:gemma-upstream

pyc96 commented May 1, 2026

Collaborator

Motivation

This PR adds Eagle3 training support for Gemma3 27B and Gemma4 26B. Other Gemma3/4 models should work as well, but they have not been verified.

Modifications

Besides the new models, this PR also adds the following features:

  • reuse_target_lm_head: reuse the target model's LM head when the flag is true
  • use_aux_norm: add additional norm layers before the fc layer

Gemma4 requires transformers v5+.

Related Issues

Accuracy Test

Benchmark & Profiling

Checklist

pyc96 changed the title from "Gemma upstream" to "Supports eagle3 training for Gemma3 27B and Gemma4 26B." May 1, 2026

gemini-code-assist (bot) left a comment

Contributor

Code Review

This pull request introduces support for Gemma 3 and Gemma 4 models within the Eagle3 framework, including new configurations, training scripts, and a dedicated gemma-4 chat template. Key architectural improvements include a fast path for models where draft and target vocab sizes match, the ability to reuse and freeze the target model's LM head, and an improved weight initialization strategy for stable training. The training script now supports multiple data paths and directory resolution. Feedback focuses on preventing race conditions in distributed output directory creation, improving error handling for mismatched tool lists, and adhering to PEP-8 import standards.

Comment thread scripts/train_eagle3.py Outdated
Comment on lines +814 to +815
run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
args.output_dir = os.path.join(args.output_dir, run_timestamp)

high

Generating run_timestamp independently on each rank can produce different output directories across processes if the ranks read the clock on opposite sides of a second boundary during initialization. This would break distributed training and checkpoint saving. The timestamp should be generated on rank 0 and broadcast to all other ranks.

Suggested change
- run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- args.output_dir = os.path.join(args.output_dir, run_timestamp)
+ run_timestamp = [datetime.now().strftime("%Y%m%d_%H%M%S") if dist.get_rank() == 0 else None]
+ dist.broadcast_object_list(run_timestamp, src=0)
+ args.output_dir = os.path.join(args.output_dir, run_timestamp[0])
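
The failure mode described in this comment can be illustrated without a real process group. The following is a minimal single-process simulation, not the training script's code: `simulated_broadcast` is a hypothetical stand-in for `dist.broadcast_object_list`, the helper names are invented for illustration, and the per-rank clock readings are made up so that two ranks land on opposite sides of a second boundary.

```python
import os
from datetime import datetime, timedelta

FMT = "%Y%m%d_%H%M%S"


def independent_dirs(base, rank_clocks):
    """Each rank formats its own clock reading (the behavior flagged above)."""
    return [os.path.join(base, t.strftime(FMT)) for t in rank_clocks]


def simulated_broadcast(obj_list, src_value):
    """Hypothetical stand-in for dist.broadcast_object_list: after the call,
    the list holds the value supplied by the source rank."""
    obj_list[0] = src_value


def broadcast_dirs(base, rank_clocks):
    """Rank 0 formats the timestamp once; every other rank receives it."""
    rank0_stamp = rank_clocks[0].strftime(FMT)
    dirs = []
    for rank, _clock in enumerate(rank_clocks):
        holder = [rank0_stamp if rank == 0 else None]
        simulated_broadcast(holder, rank0_stamp)
        dirs.append(os.path.join(base, holder[0]))
    return dirs


# Invented clocks: rank 1 reads the time 2 ms after rank 0,
# just past a second boundary.
t0 = datetime(2026, 5, 1, 23, 8, 59, 999000)
clocks = [t0, t0 + timedelta(milliseconds=2)]

naive = independent_dirs("out", clocks)
shared = broadcast_dirs("out", clocks)
print(len(set(naive)), len(set(shared)))
```

In this simulation the independent approach yields two distinct directories while the broadcast approach yields one, which is the property the suggested change aims for.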

Comment thread specforge/data/preprocessing.py Outdated
Comment on lines +159 to +160
if tools is None or len(tools) != len(conversations):
    tools = [[] for _ in range(len(conversations))]

medium

Silently replacing the tools list with empty lists when the length doesn't match conversations can hide data preparation bugs. It is safer to raise a ValueError if an explicitly provided tools list has an incorrect length.

Suggested change
- if tools is None or len(tools) != len(conversations):
-     tools = [[] for _ in range(len(conversations))]
+ if tools is None:
+     tools = [[] for _ in range(len(conversations))]
+ elif len(tools) != len(conversations):
+     raise ValueError(f"Length of tools ({len(tools)}) does not match length of conversations ({len(conversations)})")
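
As a self-contained illustration of the suggested behavior, the sketch below wraps the same check in a small function. The name `normalize_tools` and the sample conversations are hypothetical and do not come from specforge/data/preprocessing.py; only the branching logic mirrors the suggestion.

```python
def normalize_tools(conversations, tools=None):
    """Return one tool list per conversation, or raise on a length mismatch."""
    if tools is None:
        # No tools supplied: default to an independent empty list per conversation.
        return [[] for _ in range(len(conversations))]
    if len(tools) != len(conversations):
        # An explicitly provided list with the wrong length is treated as a bug.
        raise ValueError(
            f"Length of tools ({len(tools)}) does not match "
            f"length of conversations ({len(conversations)})"
        )
    return tools


# Hypothetical sample data: two single-turn conversations.
conversations = [
    [{"role": "user", "content": "hi"}],
    [{"role": "user", "content": "bye"}],
]

defaults = normalize_tools(conversations)

try:
    normalize_tools(conversations, tools=[[]])
    mismatch_error = None
except ValueError as exc:
    mismatch_error = str(exc)

print(defaults, mismatch_error)
```

With this shape, a missing tools argument still works, while a mismatched one fails loudly instead of being replaced silently.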

Comment thread scripts/train_eagle3.py Outdated
# transformers v5 mutating rope_scaling/rope_parameters and other
# fields in model.config during save_pretrained.
if getattr(args, "draft_model_config", None):
    import json

medium

Per PEP 8, imports should be placed at the top of the file. Moving import json to module level improves readability and follows standard Python practice.

References
  1. Imports should be at the top of the file, after any module comments and docstrings, and before module globals and constants.

pyc96 force-pushed the gemma-upstream branch 2 times, most recently from 0074bd7 to c9910b2 May 1, 2026 23:08
pyc96 force-pushed the gemma-upstream branch from c9910b2 to a79a5f5 May 1, 2026 23:24
pyc96 marked this pull request as ready for review May 1, 2026 23:27
Bump MAX_FUSED_SIZE to 262208 to fit Gemma3/4 vocab
Gemma3 27B and Gemma4 26B have a vocabulary size of 262144, which makes
triton.next_power_of_2 round up to 262144 (==2^18). The previous limit
of 131072 caused _calculate_settings() to raise RuntimeError before the
log-softmax loss kernel could launch, preventing Eagle3 training on
these targets.
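
The arithmetic in this commit message can be checked with a short pure-Python sketch. It is an approximation under stated assumptions: `next_power_of_2` here is a plain-Python equivalent rather than the Triton function, and `calculate_settings` is a hypothetical stand-in that only mirrors the described "raise RuntimeError when the block size exceeds the limit" behavior. The numbers 262144, 131072, and 262208 are taken from the commit message.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1 (pure-Python equivalent)."""
    return 1 << (n - 1).bit_length()


def calculate_settings(n: int, max_fused_size: int) -> int:
    """Hypothetical stand-in for the size check described in the commit message."""
    block_size = next_power_of_2(n)
    if block_size > max_fused_size:
        raise RuntimeError(
            f"block size {block_size} exceeds limit {max_fused_size}"
        )
    return block_size


VOCAB = 262144  # vocabulary size quoted in the commit message

# Previous limit: the check fails before any kernel could launch.
try:
    calculate_settings(VOCAB, 131072)
    old_limit_ok = True
except RuntimeError:
    old_limit_ok = False

# New limit from the commit title: the same size passes.
new_block = calculate_settings(VOCAB, 262208)
print(old_limit_ok, new_block)
```

Under this sketch, 262144 is already a power of two, so it passes the new limit unchanged. Any size above 262144 would round up to 524288 and would still exceed a limit of 262208.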