Add support for z-loss in pre-training #3211

Open

gagika wants to merge 1 commit into main from agagik-z-loss

Conversation

@gagika (Collaborator) commented Feb 21, 2026

Description

This PR implements Z-loss to improve numerical stability and prevent runaway logits. Alongside techniques like QK-normalization and logit soft-capping, it is a key mechanism for stabilizing low-precision (BF16/FP8) training.
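For reference, the z-loss term is the squared log of the softmax partition function, scaled by z_loss_multiplier. The JAX sketch below only illustrates that math; the function name and signature are hypothetical and not the max_utils.cross_entropy_with_logits API this PR actually uses.

```python
# Minimal sketch of a PaLM-style z-loss penalty. The name and signature are
# hypothetical; this PR reuses the existing max_utils.cross_entropy_with_logits.
import jax
import jax.numpy as jnp

def cross_entropy_with_z_loss(logits, one_hot_targets, z_loss_multiplier=0.0):
  """Per-token cross-entropy plus z_loss_multiplier * log(Z)^2, with Z = sum(exp(logits))."""
  log_z = jax.nn.logsumexp(logits, axis=-1)          # log of the partition function Z
  log_probs = logits - log_z[..., None]              # numerically stable log-softmax
  ce = -jnp.sum(one_hot_targets * log_probs, axis=-1)
  z_loss = z_loss_multiplier * jnp.square(log_z)     # pulls log(Z) toward 0, preventing runaway logits
  return ce + z_loss, z_loss
```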

Key Changes:

  • Configuration: Added a z_loss_multiplier parameter to types.py (defaults to 0.0).
  • Integration: Wired the existing Z-loss utility (max_utils.cross_entropy_with_logits) into the standard training loop (loss_fn) and the vocabulary tiling path (vocab_tiling_linen_loss); a hypothetical wiring sketch follows this list.
  • Logging: Normalized Z-loss is now exported to the aux dictionary and logged to TensorBoard as learning/z_loss.
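
As a rough picture of the wiring described above (reusing the sketch from the Description; `loss_fn_sketch`, `weights`, and the aux key name are illustrative, not the exact MaxText code):

```python
# Hypothetical loss_fn wiring: only z_loss_multiplier corresponds to the new config
# field; all other names are illustrative. Reuses cross_entropy_with_z_loss above.
import jax.numpy as jnp

def loss_fn_sketch(logits, one_hot_targets, weights, z_loss_multiplier):
  per_token_loss, per_token_z_loss = cross_entropy_with_z_loss(
      logits, one_hot_targets, z_loss_multiplier)
  denom = jnp.maximum(jnp.sum(weights), 1.0)         # number of non-padding tokens
  loss = jnp.sum(per_token_loss * weights) / denom
  # Normalized z-loss is surfaced in aux so it can be logged (learning/z_loss).
  aux = {"z_loss": jnp.sum(per_token_z_loss * weights) / denom}
  return loss, aux
```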

Tests

  • Math Verification: Added test_cross_entropy_with_z_loss in max_utils_test.py to verify the penalty calculation is mathematically correct (an illustrative sketch of this check follows this list).
  • Integration/Tiling: Added test_vocab_tiling_gradient_with_z_loss in tiling_test.py to ensure loss and gradients match exactly between standard and vocabulary-tiled computations when Z-loss is enabled.
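
A hypothetical shape for the math-verification check, reusing the sketch above; the values, tolerance, and reference computation are illustrative, not the contents of max_utils_test.py:

```python
# Illustrative check: the returned z-loss should equal multiplier * logsumexp(logits)^2.
import jax
import jax.numpy as jnp
import numpy as np

def test_cross_entropy_with_z_loss_sketch():
  logits = jnp.array([[2.0, -1.0, 0.5]])
  targets = jax.nn.one_hot(jnp.array([0]), 3)
  multiplier = 1e-3
  _, z_loss = cross_entropy_with_z_loss(logits, targets, multiplier)
  expected = multiplier * np.square(np.asarray(jax.nn.logsumexp(logits, axis=-1)))
  np.testing.assert_allclose(np.asarray(z_loss), expected, rtol=1e-6)
```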

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have added necessary comments to my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

codecov bot commented Feb 21, 2026

Codecov Report

❌ Patch coverage is 70.37037% with 8 lines in your changes missing coverage. Please review.

Files with missing lines                   Patch %   Lines
src/maxtext/trainers/pre_train/train.py    45.45%    6 Missing ⚠️
src/MaxText/vocabulary_tiling.py           87.50%    2 Missing ⚠️


@github-actions

🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

🤖 I'm sorry @gagika, but I was unable to process your request. Please see the logs for more details.
