Conversation
Greptile Overview

Greptile Summary

This PR adds support for a quantized forward pass with a high-precision backward pass via the NVTE_KEEP_BACKWARD_UNQUANTIZED environment variable.

Key Changes:
Trade-offs:
Confidence Score: 4/5
Important Files Changed
Sequence Diagram

sequenceDiagram
    participant User
    participant FP8GlobalStateManager
    participant Linear
    participant Forward
    participant Backward
    User->>FP8GlobalStateManager: Set NVTE_KEEP_BACKWARD_UNQUANTIZED=1
    User->>Linear: forward(input, weight)
    Linear->>FP8GlobalStateManager: keep_backward_unquantized()
    FP8GlobalStateManager->>FP8GlobalStateManager: Check recipe.delayed()
    alt Delayed Scaling Recipe
        FP8GlobalStateManager-->>Linear: False (ignore flag)
    else Other Recipe
        FP8GlobalStateManager-->>Linear: True (use high precision)
    end
    alt keep_backward_unquantized=True
        Linear->>Forward: Quantize input for FP8 forward
        Forward->>Forward: Compute: y = x_fp8 @ w_fp8
        Forward->>Forward: Save x_high_precision, w_high_precision
        Forward->>Forward: Set ctx.with_quantized_compute=False
        Forward-->>Linear: y, saved tensors
    else keep_backward_unquantized=False
        Linear->>Forward: Quantize input for FP8 forward
        Forward->>Forward: Compute: y = x_fp8 @ w_fp8
        Forward->>Forward: Save x_fp8, w_fp8
        Forward->>Forward: Set ctx.with_quantized_compute=True
        Forward-->>Linear: y, saved tensors
    end
    User->>Linear: backward(grad_output)
    Linear->>Backward: grad_output
    alt keep_backward_unquantized=True
        Backward->>Backward: Load x_high_precision, w_high_precision
        Backward->>Backward: Disable quantizers (set to None)
        Backward->>Backward: dgrad = grad_out @ w_high_precision (high precision)
        Backward->>Backward: wgrad = x_high_precision.T @ grad_out (high precision)
    else keep_backward_unquantized=False
        Backward->>Backward: Load x_fp8, w_fp8
        Backward->>Backward: Quantize grad_output to FP8
        Backward->>Backward: dgrad = grad_out_fp8 @ w_fp8 (FP8)
        Backward->>Backward: wgrad = x_fp8.T @ grad_out_fp8 (FP8)
    end
    Backward-->>Linear: grad_input, grad_weight
    Linear-->>User: gradients
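The flow above, condensed into a toy PyTorch sketch (this is not the Transformer Engine implementation; fake_quantize is a stand-in for real FP8 quantization and the module is a bare matmul):

```python
import torch

def fake_quantize(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for FP8 quantization: round-trip through float8 when available
    if hasattr(torch, "float8_e4m3fn"):
        return t.to(torch.float8_e4m3fn).to(t.dtype)
    return t

class QuantFwdLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w, keep_backward_unquantized):
        x_q, w_q = fake_quantize(x), fake_quantize(w)
        y = x_q @ w_q.t()                      # quantized forward GEMM
        if keep_backward_unquantized:
            ctx.save_for_backward(x, w)        # keep high-precision copies
        else:
            ctx.save_for_backward(x_q, w_q)    # keep quantized copies
        ctx.with_quantized_compute = not keep_backward_unquantized
        return y

    @staticmethod
    def backward(ctx, grad_out):
        x_saved, w_saved = ctx.saved_tensors
        if ctx.with_quantized_compute:
            grad_out = fake_quantize(grad_out)  # grad_output only quantized on the FP8 path
        dgrad = grad_out @ w_saved              # dgrad GEMM
        wgrad = grad_out.t() @ x_saved          # wgrad GEMM
        return dgrad, wgrad, None

x = torch.randn(8, 16, requires_grad=True)
w = torch.randn(32, 16, requires_grad=True)
y = QuantFwdLinear.apply(x, w, True)            # quantized fwd, high-precision bwd
y.sum().backward()
```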
I'll work on potential unit test breakage.
ln_out_return = None
if return_layernorm_output or return_layernorm_output_gathered:
    ln_out_return = ln_out
ln_out_hp = ln_out if keep_backward_unquantized else None
Storing both ln_out (quantized) and ln_out_hp (high precision) doubles the memory footprint for this activation. Verify this memory overhead is acceptable for your target models, especially during training with large batch sizes or long sequences.
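For a sense of scale, a back-of-the-envelope estimate (the shapes are made up, and the high-precision copy is assumed to be BF16 at 2 bytes per element):

```python
batch, seq, hidden = 8, 4096, 8192
extra_bytes = batch * seq * hidden * 2   # one extra ln_out_hp copy per layer
print(f"~{extra_bytes / 2**20:.0f} MiB of extra activation memory per layer")
```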
not ctx.use_bias
and not ctx.requires_wgrad
and ctx.grad_output_quantizer is not None
and use_fp8_bwd
recipe = cls.get_fp8_recipe()
if recipe is not None and recipe.delayed():
    # Ignore NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used
    return False
Maybe it's better to raise an error for delayed scaling? I'm okay with either approach.
I agree. If the user specifies an unsupported combination, I think it's better to fail loudly than to secretly disobey their instructions.
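A minimal sketch of the fail-loudly behavior discussed here, using a toy stand-in class (the get_fp8_recipe, recipe.delayed, and HIGH_PRECISION_INIT_VAL names are taken from the hunks in this PR; everything else is illustrative, not the real FP8GlobalStateManager):

```python
import os

class FP8StateSketch:
    """Toy stand-in for FP8GlobalStateManager, only to illustrate erroring out."""

    # Cached value of the env var, mirroring HIGH_PRECISION_INIT_VAL in the PR
    HIGH_PRECISION_INIT_VAL = os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1"

    @classmethod
    def get_fp8_recipe(cls):
        return None  # placeholder; the real manager returns the active recipe

    @classmethod
    def keep_backward_unquantized(cls) -> bool:
        recipe = cls.get_fp8_recipe()
        if cls.HIGH_PRECISION_INIT_VAL and recipe is not None and recipe.delayed():
            # Fail loudly instead of silently ignoring the user's request
            raise RuntimeError(
                "NVTE_KEEP_BACKWARD_UNQUANTIZED=1 is not supported with delayed scaling"
            )
        return cls.HIGH_PRECISION_INIT_VAL
```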
if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias):
if (
    ctx.fp8
    and not ctx.keep_backward_unquantized
# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage
if ctx.grad_output_quantizer is not None:
if ctx.grad_output_quantizer is not None and use_fp8_bwd:
This seems redundant too if we skip quantization in grad_output_preprocess.
# Prepare GEMM input
Recomputing activation_func(fc1_out, None, **act_params) adds compute overhead for activations like GELU. Consider storing the high-precision act_out during the forward pass when this feature is enabled to avoid redundant computation (trade memory for compute).
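A rough sketch of the two options, with F.gelu standing in for activation_func and made-up shapes (not the PR's code):

```python
import torch
import torch.nn.functional as F

fc1_out = torch.randn(8, 2048, 4096)     # illustrative FC1 output

# Option A (what the comment describes): recompute the activation in backward.
def wgrad_input_recompute(fc1_out):
    return F.gelu(fc1_out)               # extra GELU FLOPs on every backward pass

# Option B (suggested): keep the high-precision activation from forward.
act_out_saved = F.gelu(fc1_out)          # computed once, held until backward
def wgrad_input_stored():
    return act_out_saved                 # no recompute; costs one extra activation tensor
```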
Some nvfuser tests are failing:
timmoon10 left a comment:
This feature is reasonably straightforward, although I have some design suggestions to make it more general. Also, we should add some unit tests to make sure this works as expected.
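A sketch of what such a unit test could look like, assuming te.Linear and te.fp8_autocast from transformer_engine.pytorch and FP8-capable hardware; the tolerances and shapes are illustrative, and the point at which the env var is read is assumed to be before module use:

```python
import os
import torch

os.environ["NVTE_KEEP_BACKWARD_UNQUANTIZED"] = "1"  # assumed to be read before use
import transformer_engine.pytorch as te

def test_high_precision_backward():
    torch.manual_seed(0)
    layer = te.Linear(128, 128, bias=False).cuda()
    ref = torch.nn.Linear(128, 128, bias=False).cuda()
    with torch.no_grad():
        ref.weight.copy_(layer.weight)

    x = torch.randn(32, 128, device="cuda", requires_grad=True)
    x_ref = x.detach().clone().requires_grad_()

    with te.fp8_autocast(enabled=True):  # forward still runs quantized
        out = layer(x)
    out.sum().backward()                 # dgrad/wgrad should run unquantized

    ref(x_ref).sum().backward()

    # With the high-precision backward, grads should track the reference closely
    torch.testing.assert_close(x.grad, x_ref.grad, rtol=1e-2, atol=1e-2)
```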
return cls.HIGH_PRECISION_INIT_VAL

@classmethod
def keep_backward_unquantized(cls) -> bool:
I would prefer this option to live in Recipe rather than FP8GlobalStateManager. FP8GlobalStateManager is for state that changes very frequently (e.g. when entering or exiting a te.autocast), while Recipe has configs that persist throughout training. Exposing the option in Recipe also makes it easier to configure programmatically rather than with an obscure envvar.
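A hypothetical sketch of what that could look like; the class and field below are made up for illustration and are not the existing Recipe API:

```python
from dataclasses import dataclass

@dataclass
class RecipeSketch:
    """Toy stand-in for a recipe object; real recipes carry many more fields."""

    keep_backward_unquantized: bool = False  # would replace the env-var lookup

# Configured programmatically instead of via NVTE_KEEP_BACKWARD_UNQUANTIZED
recipe = RecipeSketch(keep_backward_unquantized=True)
```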
This option name is specific to this workflow and doesn't generalize well. How about we break this up into two options: quantize_forward and quantize_backward? We have the following cases:

- quantize_forward=True, quantize_backward=True: Equivalent to the quantized case. In the future we might be able to replace FP8GlobalStateManager.FP8_ENABLED with FP8GlobalStateManager.QUANTIZE_FORWARD or FP8GlobalStateManager.QUANTIZE_BACKWARD.
- quantize_forward=False, quantize_backward=False: Equivalent to the unquantized case.
- quantize_forward=True, quantize_backward=False: Your desired workflow.
- quantize_forward=False, quantize_backward=True: We can error out in this case, but who knows if someone in the future might want this.
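A small sketch of the four combinations, erroring out on the unsupported one; the flag names follow the suggestion above, while the function itself is illustrative:

```python
def resolve_quantization(quantize_forward: bool, quantize_backward: bool):
    # quantize_forward=True,  quantize_backward=True  -> fully quantized (current FP8 path)
    # quantize_forward=False, quantize_backward=False -> fully unquantized
    # quantize_forward=True,  quantize_backward=False -> this PR's workflow
    # quantize_forward=False, quantize_backward=True  -> no known use case yet
    if not quantize_forward and quantize_backward:
        raise ValueError("quantize_backward=True requires quantize_forward=True")
    return quantize_forward, quantize_backward

print(resolve_quantization(True, False))  # (True, False)
```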
ctx.fp8 = fp8
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
ctx.keep_backward_unquantized = keep_backward_unquantized
If the backward pass has unquantized compute, does it need to know that the forward pass was quantized? If possible, it would be nice to keep all the changes confined here where we configure the autograd context.
Suggested change (replacing the three lines above):

ctx.fp8 = fp8 and not keep_backward_unquantized
ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
Description
Add an NVTE_KEEP_BACKWARD_UNQUANTIZED env var for quantized fprop + high precision wgrad & dgrad.
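A minimal usage sketch, assuming te.Linear and te.fp8_autocast from transformer_engine.pytorch, FP8-capable hardware, and a non-delayed-scaling recipe (the flag is ignored under delayed scaling); shapes are illustrative:

```python
import os
os.environ["NVTE_KEEP_BACKWARD_UNQUANTIZED"] = "1"  # assumed to be read before module use

import torch
import transformer_engine.pytorch as te

layer = te.Linear(1024, 1024).cuda()
x = torch.randn(16, 1024, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True):  # forward GEMM runs quantized
    y = layer(x)

y.sum().backward()                   # dgrad/wgrad run in high precision
```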