
Add NVTE_KEEP_BACKWARD_UNQUANTIZED#2644

Open
zianglih wants to merge 23 commits into NVIDIA:main from zianglih:keep-bwd

Conversation


@zianglih zianglih commented Feb 3, 2026

Description


Add an NVTE_KEEP_BACKWARD_UNQUANTIZED env var for quantized fprop + high precision wgrad & dgrad.
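As a rough usage sketch (module, sizes, and recipe below are illustrative; note that, per the review discussion, the flag is ignored or should error out for delayed scaling recipes):

import os
os.environ["NVTE_KEEP_BACKWARD_UNQUANTIZED"] = "1"  # quantized fprop, high-precision dgrad/wgrad

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling  # any non-delayed-scaling recipe

layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)
inp = torch.randn(512, 1024, dtype=torch.bfloat16, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=Float8CurrentScaling()):
    out = layer(inp)   # forward GEMM runs quantized
out.sum().backward()   # backward uses the saved high-precision tensors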

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes


greptile-apps bot commented Feb 3, 2026

Greptile Overview

Greptile Summary

This PR adds support for quantized forward pass with high-precision backward pass via the NVTE_KEEP_BACKWARD_UNQUANTIZED environment variable. When enabled, the feature performs FP8-quantized forward computation but uses high-precision (unquantized) tensors for weight gradient and data gradient computation in the backward pass.

Key Changes:

  • Added FP8GlobalStateManager.keep_backward_unquantized() method that reads the env var and automatically returns False for delayed scaling recipes
  • Modified Linear, GroupedLinear, and LayerNormLinear modules to save high-precision tensors instead of quantized ones when the flag is enabled
  • Updated backward pass logic to disable FP8 quantization and use high-precision saved tensors for dgrad and wgrad computation
  • Disabled Userbuffers communication optimizations in backward pass when this feature is active
  • Added assertion in LayerNormMLP to prevent usage (not yet implemented for this module)
  • Propagated the flag through all fused operations

Trade-offs:

  • Increased memory usage (stores high-precision tensors instead of FP8)
  • Potentially improved gradient accuracy
  • Feature is automatically disabled for delayed scaling recipes to avoid conflicts

Confidence Score: 4/5

  • This PR is generally safe to merge with minor concerns about memory usage and documentation
  • The implementation is systematic and correctly propagates the flag through all affected modules. The guard against delayed scaling recipes prevents assertion failures. However, there's increased memory usage that users should be aware of, and LayerNormMLP is not yet supported
  • Pay attention to transformer_engine/pytorch/module/layernorm_linear.py for memory usage implications

Important Files Changed

  • transformer_engine/pytorch/quantization.py: Added keep_backward_unquantized() method to check the NVTE_KEEP_BACKWARD_UNQUANTIZED env var; correctly returns False for delayed scaling recipes
  • transformer_engine/pytorch/module/layernorm_mlp.py: Added assertion to block usage of NVTE_KEEP_BACKWARD_UNQUANTIZED in LayerNormMLP (not yet implemented)
  • transformer_engine/pytorch/module/linear.py: Implements high-precision backward by saving original tensors, disabling FP8 quantizers in backward, and using unquantized weights for dgrad/wgrad
  • transformer_engine/pytorch/module/layernorm_linear.py: Stores both quantized and high-precision layernorm output when the flag is set, disables FP8 quantization in the backward pass
  • transformer_engine/pytorch/ops/basic/basic_linear.py: Adds keep_backward_unquantized parameter to the forward function, conditionally saves high-precision tensors instead of quantized ones

Sequence Diagram

sequenceDiagram
    participant User
    participant FP8GlobalStateManager
    participant Linear
    participant Forward
    participant Backward
    
    User->>FP8GlobalStateManager: Set NVTE_KEEP_BACKWARD_UNQUANTIZED=1
    User->>Linear: forward(input, weight)
    Linear->>FP8GlobalStateManager: keep_backward_unquantized()
    FP8GlobalStateManager->>FP8GlobalStateManager: Check recipe.delayed()
    alt Delayed Scaling Recipe
        FP8GlobalStateManager-->>Linear: False (ignore flag)
    else Other Recipe
        FP8GlobalStateManager-->>Linear: True (use high precision)
    end
    
    alt keep_backward_unquantized=True
        Linear->>Forward: Quantize input for FP8 forward
        Forward->>Forward: Compute: y = x_fp8 @ w_fp8
        Forward->>Forward: Save x_high_precision, w_high_precision
        Forward->>Forward: Set ctx.with_quantized_compute=False
        Forward-->>Linear: y, saved tensors
    else keep_backward_unquantized=False
        Linear->>Forward: Quantize input for FP8 forward
        Forward->>Forward: Compute: y = x_fp8 @ w_fp8
        Forward->>Forward: Save x_fp8, w_fp8
        Forward->>Forward: Set ctx.with_quantized_compute=True
        Forward-->>Linear: y, saved tensors
    end
    
    User->>Linear: backward(grad_output)
    Linear->>Backward: grad_output
    
    alt keep_backward_unquantized=True
        Backward->>Backward: Load x_high_precision, w_high_precision
        Backward->>Backward: Disable quantizers (set to None)
        Backward->>Backward: dgrad = grad_out @ w_high_precision (high precision)
        Backward->>Backward: wgrad = x_high_precision.T @ grad_out (high precision)
    else keep_backward_unquantized=False
        Backward->>Backward: Load x_fp8, w_fp8
        Backward->>Backward: Quantize grad_output to FP8
        Backward->>Backward: dgrad = grad_out_fp8 @ w_fp8 (FP8)
        Backward->>Backward: wgrad = x_fp8.T @ grad_out_fp8 (FP8)
    end
    
    Backward-->>Linear: grad_input, grad_weight
    Linear-->>User: gradients
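Outside the diagram, the same dispatch can be summarized as a small torch.autograd.Function sketch; the quantize helper below only emulates FP8 rounding and is not the PR's actual quantization path:

import torch

def quantize(t):
    # Stand-in for FP8 quantization: round-trip through float8 to emulate the precision loss.
    return t.to(torch.float8_e4m3fn).to(t.dtype)

class _LinearSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight, keep_backward_unquantized):
        out = quantize(inp) @ quantize(weight)  # "quantized" forward GEMM
        if keep_backward_unquantized:
            ctx.save_for_backward(inp, weight)  # keep high-precision copies for backward
        else:
            ctx.save_for_backward(quantize(inp), quantize(weight))  # keep low-precision copies
        ctx.with_quantized_compute = not keep_backward_unquantized
        return out

    @staticmethod
    def backward(ctx, grad_out):
        saved_inp, saved_weight = ctx.saved_tensors
        if ctx.with_quantized_compute:
            grad_out = quantize(grad_out)
        dgrad = grad_out @ saved_weight.t()  # input gradient
        wgrad = saved_inp.t() @ grad_out     # weight gradient
        return dgrad, wgrad, None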

@greptile-apps greptile-apps bot left a comment

6 files reviewed, no comments

@greptile-apps greptile-apps bot left a comment

6 files reviewed, no comments


zianglih commented Feb 3, 2026

I'll work on potential unit test breakage.

@greptile-apps greptile-apps bot left a comment

5 files reviewed, no comments

@greptile-apps greptile-apps bot left a comment

5 files reviewed, 4 comments

@greptile-apps greptile-apps bot left a comment

4 files reviewed, 1 comment

@greptile-apps greptile-apps bot left a comment

5 files reviewed, 1 comment

@greptile-apps greptile-apps bot left a comment

5 files reviewed, 2 comments

ln_out_return = None
if return_layernorm_output or return_layernorm_output_gathered:
    ln_out_return = ln_out
ln_out_hp = ln_out if keep_backward_unquantized else None
Contributor

storing both ln_out (quantized) and ln_out_hp (high precision) doubles the memory footprint for this activation

verify this memory overhead is acceptable for your target models, especially during training with large batch sizes or long sequences

@greptile-apps greptile-apps bot left a comment

4 files reviewed, no comments

@greptile-apps greptile-apps bot left a comment

5 files reviewed, no comments

@greptile-apps greptile-apps bot left a comment

5 files reviewed, no comments

not ctx.use_bias
and not ctx.requires_wgrad
and ctx.grad_output_quantizer is not None
and use_fp8_bwd
Collaborator

same comment as above

recipe = cls.get_fp8_recipe()
if recipe is not None and recipe.delayed():
    # Ignore NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used
    return False
Collaborator

Maybe it's better to assert an error for delayed scaling? Okay with both.

Collaborator

I agree. If the user specifies an unsupported combination, I think it's better to fail loudly than to secretly disobey their instructions.
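For reference, a minimal sketch of the fail-loudly variant being suggested, mirroring the snippet quoted above (the exact exception type and message are illustrative):

recipe = cls.get_fp8_recipe()
if recipe is not None and recipe.delayed():
    raise RuntimeError(
        "NVTE_KEEP_BACKWARD_UNQUANTIZED is not supported with delayed scaling recipes"
    )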

- if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias):
+ if (
+     ctx.fp8
+     and not ctx.keep_backward_unquantized
Collaborator

same comment

# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage
- if ctx.grad_output_quantizer is not None:
+ if ctx.grad_output_quantizer is not None and use_fp8_bwd:
Collaborator

this seems redundant too if we skip quant in grad_output_preprocess

@greptile-apps greptile-apps bot left a comment

5 files reviewed, no comments

@greptile-apps greptile-apps bot left a comment

5 files reviewed, 1 comment

Comment on lines 423 to 424

# Prepare GEMM input
Contributor

recomputing activation_func(fc1_out, None, **act_params) adds compute overhead for activations like GELU

consider storing high-precision act_out during forward pass when this feature is enabled to avoid redundant computation (trade memory for compute)

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
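A rough sketch of the memory-for-compute alternative described above; the names (act_out, fc1_out, act_params) follow the comment, and the surrounding forward/backward structure is an assumption:

# Forward (sketch): stash the high-precision activation when the flag is set.
ctx.act_out_hp = act_out if keep_backward_unquantized else None

# Backward (sketch): reuse it if available, otherwise recompute as today.
act_out = ctx.act_out_hp
if act_out is None:
    act_out = activation_func(fc1_out, None, **act_params)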

@greptile-apps greptile-apps bot left a comment

1 file reviewed, 1 comment


greptile-apps bot commented Feb 4, 2026

Additional Comments (1)

transformer_engine/pytorch/module/layernorm_mlp.py
Incorrect instance check

In _LayerNormMLP.backward, this block checks isinstance(ctx.fc1_weight_quantizer, QuantizedTensorStorage) and then calls ctx.fc1_weight.update_usage(...).

QuantizedTensorStorage is a tensor storage type, not a quantizer; this condition will never be true, so usage for ctx.fc1_weight won’t be updated when it should be (FP8 backward + quantized weight path). This looks like a typo for checking the weight (or QuantizedTensorStorage on ctx.fc1_weight) and can break backward that relies on correct usage flags.

Also appears in: none found in this diff.
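A sketch of the check the comment says was likely intended; the isinstance target changes from the quantizer to the weight tensor itself, and the update_usage arguments below are assumptions:

# Check the weight tensor, not its quantizer, before updating usage flags.
if isinstance(ctx.fc1_weight, QuantizedTensorStorage):
    ctx.fc1_weight.update_usage(rowwise_usage=True, columnwise_usage=True)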

@greptile-apps greptile-apps bot left a comment

5 files reviewed, 1 comment


zianglih commented Feb 4, 2026

Some nvfuser tests are failing:

=================================================================== short test summary info ===================================================================
FAILED tests/pytorch/test_sanity.py::test_sanity_amp_and_nvfuser[True-small-None-dtype1] - RuntimeError: /root/TransformerEngine/transformer_engine/common/gemm/cublaslt_gemm.cu:764 in function cublas_gemm: Assertion failed: status != CUBLAS_STAT...
FAILED tests/pytorch/test_sanity.py::test_sanity_amp_and_nvfuser[True-small-None-dtype2] - RuntimeError: /root/TransformerEngine/transformer_engine/common/gemm/cublaslt_gemm.cu:764 in function cublas_gemm: Assertion failed: status != CUBLAS_STAT...
FAILED tests/pytorch/test_sanity.py::test_sanity_amp_and_nvfuser[False-small-None-dtype1] - RuntimeError: /root/TransformerEngine/transformer_engine/common/gemm/cublaslt_gemm.cu:764 in function cublas_gemm: Assertion failed: status != CUBLAS_STAT...
FAILED tests/pytorch/test_sanity.py::test_sanity_amp_and_nvfuser[False-small-None-dtype2] - RuntimeError: /root/TransformerEngine/transformer_engine/common/gemm/cublaslt_gemm.cu:764 in function cublas_gemm: Assertion failed: status != CUBLAS_STAT...
================================================ 4 failed, 12918 passed, 16523 skipped, 20 warnings in 40.71s =================================================

@greptile-apps greptile-apps bot left a comment

5 files reviewed, no comments

@timmoon10 timmoon10 left a comment

This feature is reasonably straightforward, although I have some design suggestions to make it more general. Also, we should add some unit tests to make sure this works as expected.


return cls.HIGH_PRECISION_INIT_VAL

@classmethod
def keep_backward_unquantized(cls) -> bool:
Collaborator

I would prefer this option to live in Recipe rather than FP8GlobalStateManager. FP8GlobalStateManager is for state that changes very frequently (e.g. when entering or exiting a te.autocast), while Recipe has configs that persist throughout training. Exposing the option in Recipe also makes it easier to configure programmatically rather than with an obscure envvar.
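A sketch of what the recipe-based configuration could look like (the field name is hypothetical and the recipe class is just an example; layer and inp are as in the usage sketch in the PR description above):

from transformer_engine.common.recipe import Float8CurrentScaling
import transformer_engine.pytorch as te

recipe = Float8CurrentScaling()
recipe.keep_backward_unquantized = True  # hypothetical recipe field replacing the env var

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = layer(inp)

# Modules would then read the flag from the active recipe, e.g.:
# keep_backward_unquantized = getattr(fp8_recipe, "keep_backward_unquantized", False)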

return cls.HIGH_PRECISION_INIT_VAL

@classmethod
def keep_backward_unquantized(cls) -> bool:
Collaborator

This option name is specific to this workflow and doesn't generalize well. How about we break this up into two options, quantize_forward and quantize_backward (sketched after the list below)? We have the following cases:

  • quantize_forward=True, quantize_backward=True: Equivalent to quantized case. In the future we might be able to replace FP8GlobalStateManager.FP8_ENABLED with FP8GlobalStateManager.QUANTIZE_FORWARD or FP8GlobalStateManager.QUANTIZE_BACKWARD.
  • quantize_forward=False, quantize_backward=False: Equivalent to unquantized case.
  • quantize_forward=True, quantize_backward=False: Your desired workflow.
  • quantize_forward=False, quantize_backward=True: We can error out in this case, but who knows whether someone in the future might want this.
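A minimal sketch of the two-flag design described in this list; quantize_forward and quantize_backward are the proposed (not yet existing) options:

def resolve_quantization(quantize_forward: bool, quantize_backward: bool) -> tuple[bool, bool]:
    # quantize_forward=True,  quantize_backward=True  -> today's fully quantized path
    # quantize_forward=False, quantize_backward=False -> fully unquantized path
    # quantize_forward=True,  quantize_backward=False -> this PR's workflow
    if quantize_backward and not quantize_forward:
        raise ValueError("quantize_backward=True with quantize_forward=False is not supported yet")
    return quantize_forward, quantize_backward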

Comment on lines 448 to +450
ctx.fp8 = fp8
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.keep_backward_unquantized = keep_backward_unquantized
Collaborator

If the backward pass has unquantized compute, does it need to know that the forward pass was quantized? If possible, it would be nice to keep all the changes confined here, where we configure the autograd context.

Suggested change
- ctx.fp8 = fp8
- ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
- ctx.keep_backward_unquantized = keep_backward_unquantized
+ ctx.fp8 = fp8 and not keep_backward_unquantized
+ ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None

