Fix FP8 block scaling with sequence parallel #2637
base: main
Changes from all commits
3ba991b
390b2e1
637ba0f
9fe572b
@@ -1100,7 +1100,10 @@ def _start_all_gather_fp8_blockwise(
     # Fall back to high-precision all-gather if FP8 is not supported
     if not quantizer.is_quantizable(inp) or quantizer.block_scaling_dim != 1:
-        out = torch.empty(out_shape, dtype=dtype, device=device)
+        warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
+        if isinstance(inp, QuantizedTensorStorage):
+            inp = inp.dequantize()  # Dequantize if needed
+        out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
         torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
         out = quantizer(out)
         return out, None
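For readers following the change, here is a condensed, self-contained sketch of the fallback path this hunk introduces. The helper name `_gather_high_precision_fallback` is hypothetical, the `hasattr` check stands in for the `isinstance(inp, QuantizedTensorStorage)` test so the sketch does not depend on Transformer Engine internals, and an already-initialized process group is assumed.

```python
import warnings

import torch
import torch.distributed


def _gather_high_precision_fallback(inp, out_shape, quantizer, process_group):
    """Dequantize, all-gather in high precision, then re-quantize the result."""
    warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
    if hasattr(inp, "dequantize"):  # stands in for isinstance(inp, QuantizedTensorStorage)
        inp = inp.dequantize()      # back to a plain torch.Tensor (float32 by default)
    # Allocate the gather buffer to match the (possibly dequantized) input.
    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
    torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
    # Quantize the gathered tensor so callers still receive a quantized result.
    return quantizer(out), None
```

The key point is that the gather buffer is sized from `inp` after dequantization, and quantization happens once on the gathered result rather than per shard.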
@@ -1338,10 +1341,13 @@ def _all_gather_nvfp4(
         and quantizer is not None
         and not quantizer.is_quantizable(inp)
     ):
+        warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
+        if isinstance(inp, QuantizedTensorStorage):
+            inp = inp.dequantize()  # Dequantize if needed
         out = torch.empty(
             out_shape,
-            dtype=dtype,
-            device=device,
+            dtype=inp.dtype,
+            device=inp.device,
             memory_format=torch.contiguous_format,
         )
         torch.distributed.all_gather_into_tensor(out, inp, group=process_group)

Contributor (inline comment on the `inp.dequantize()` line): Same consideration as in …
@@ -1505,10 +1511,13 @@ def _all_gather_mxfp8(
         and quantizer is not None
         and not quantizer.is_quantizable(inp)
     ):
+        warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
+        if isinstance(inp, QuantizedTensorStorage):
+            inp = inp.dequantize()  # Dequantize if needed
         out = torch.empty(
             out_shape,
-            dtype=dtype,
-            device=device,
+            dtype=inp.dtype,
+            device=inp.device,
             memory_format=torch.contiguous_format,
         )
         torch.distributed.all_gather_into_tensor(out, inp, group=process_group)

Contributor (inline comment on the `inp.dequantize()` line): Same consideration as in …
Review comment: The `dequantize()` method defaults to `dtype=torch.float32`. Consider whether this is always appropriate for the fallback path, especially when the original tensor might have been in a different precision (e.g., bfloat16).
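One way to address this concern, sketched against the names already in scope in these functions (`inp`, `dtype`, `device`, `out_shape`, `process_group`, `QuantizedTensorStorage`), is to cast the dequantized tensor back to the caller's high-precision dtype rather than gathering in float32. This is only an illustration of the reviewer's point, not the PR's code, and it assumes `dtype` and `device` still describe the intended high-precision output at this point in the function.

```python
if isinstance(inp, QuantizedTensorStorage):
    # dequantize() produces float32 by default; cast back so a bf16 model
    # does not pay for an fp32 all-gather.
    inp = inp.dequantize().to(dtype=dtype, device=device)
out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
torch.distributed.all_gather_into_tensor(out, inp, group=process_group)
```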