transformer_engine/pytorch/distributed.py (14 additions, 5 deletions)

@@ -1100,7 +1100,10 @@ def _start_all_gather_fp8_blockwise(

     # Fall back to high-precision all-gather if FP8 is not supported
     if not quantizer.is_quantizable(inp) or quantizer.block_scaling_dim != 1:
-        out = torch.empty(out_shape, dtype=dtype, device=device)
+        warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
+        if isinstance(inp, QuantizedTensorStorage):
+            inp = inp.dequantize()  # Dequantize if needed
Review comment (Contributor): The dequantize() method defaults to dtype=torch.float32. Consider whether this is always appropriate for the fallback path, especially when the original tensor might have been in a different precision (e.g., bfloat16).

+        out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
         torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
         out = quantizer(out)
         return out, None
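
One way to act on the review comment, sketched below under two assumptions: that dequantize() accepts a dtype keyword (the comment states it defaults to torch.float32), and that the dtype/device values used by the removed torch.empty(...) line are still in scope. The helper name is invented for this sketch, and QuantizedTensorStorage is the class already imported by distributed.py; this is not the PR's code.

import warnings
import torch

def _gather_fallback_high_precision(inp, quantizer, out_shape, dtype, device, process_group):
    # Hypothetical helper (name made up for this sketch): the fallback branch of
    # _start_all_gather_fp8_blockwise, but dequantizing to the caller-requested
    # dtype instead of relying on dequantize()'s float32 default.
    warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
    if isinstance(inp, QuantizedTensorStorage):  # class imported in distributed.py
        inp = inp.dequantize(dtype=dtype)  # e.g. bfloat16 rather than the float32 default
    out = torch.empty(out_shape, dtype=inp.dtype, device=device)
    torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
    return quantizer(out), None

This keeps the gathered buffer, and the input to quantizer(out), in the original activation precision, so the high-precision gather moves bf16 rather than fp32 bytes when the model runs in bfloat16.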

@@ -1338,10 +1341,13 @@ def _all_gather_nvfp4(
         and quantizer is not None
         and not quantizer.is_quantizable(inp)
     ):
+        warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
+        if isinstance(inp, QuantizedTensorStorage):
+            inp = inp.dequantize()  # Dequantize if needed
Review comment (Contributor): Same consideration as in _start_all_gather_fp8_blockwise: the dequantize() method defaults to dtype=torch.float32, which may not match the original tensor's precision.

         out = torch.empty(
             out_shape,
-            dtype=dtype,
-            device=device,
+            dtype=inp.dtype,
+            device=inp.device,
             memory_format=torch.contiguous_format,
         )
         torch.distributed.all_gather_into_tensor(out, inp, group=process_group)

@@ -1505,10 +1511,13 @@ def _all_gather_mxfp8(
         and quantizer is not None
         and not quantizer.is_quantizable(inp)
     ):
+        warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
+        if isinstance(inp, QuantizedTensorStorage):
+            inp = inp.dequantize()  # Dequantize if needed
Review comment (Contributor): Same consideration as in _start_all_gather_fp8_blockwise: the dequantize() method defaults to dtype=torch.float32, which may not match the original tensor's precision.

         out = torch.empty(
             out_shape,
-            dtype=dtype,
-            device=device,
+            dtype=inp.dtype,
+            device=inp.device,
             memory_format=torch.contiguous_format,
         )
         torch.distributed.all_gather_into_tensor(out, inp, group=process_group)
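
The dtype-preserving adjustment sketched after the _start_all_gather_fp8_blockwise hunk would apply to the _all_gather_nvfp4 and _all_gather_mxfp8 fallbacks as well; only the surrounding torch.empty(..., memory_format=torch.contiguous_format) allocation differs.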

transformer_engine/pytorch/module/layernorm_linear.py (0 additions, 2 deletions)

@@ -29,7 +29,6 @@
 from ..quantization import FP8GlobalStateManager
 from ..utils import (
     assert_dim_for_fp8_exec,
-    assert_dim_for_all_gather,
     cast_if_needed,
     clear_tensor_data,
     divide,

@@ -158,7 +157,6 @@ def forward(
         inputmat = inp
         if fp8:
             assert_dim_for_fp8_exec(inputmat, weight)
-            assert_dim_for_all_gather(inputmat, with_input_all_gather, input_quantizer)
 
         # Cast for native AMP
         nvtx_range_push(f"{nvtx_label}.norm_input_cast")

transformer_engine/pytorch/module/layernorm_mlp.py (0 additions, 2 deletions)

@@ -40,7 +40,6 @@
     init_method_constant,
     cast_if_needed,
     assert_dim_for_fp8_exec,
-    assert_dim_for_all_gather,
     clear_tensor_data,
     requires_grad,
     needs_quantized_gemm,

@@ -331,7 +330,6 @@ def _forward(
         inputmat = inp.view((-1, in_features))
         if fp8:
             assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight)
-            assert_dim_for_all_gather(inputmat, sequence_parallel, fc1_input_quantizer)
 
         activation_func = _act_func(
             activation, FP8GlobalStateManager.get_fp8_recipe() if fp8 else None

transformer_engine/pytorch/module/linear.py (0 additions, 2 deletions)

@@ -34,7 +34,6 @@
     requires_grad,
     needs_quantized_gemm,
     assert_dim_for_fp8_exec,
-    assert_dim_for_all_gather,
     nvtx_range_pop,
     nvtx_range_push,
     get_nvtx_range_context,

@@ -175,7 +174,6 @@ def forward(
         own_quantized_input = False
         if fp8:
             assert_dim_for_fp8_exec(inputmat, weight)
-            assert_dim_for_all_gather(inputmat, with_input_all_gather_nccl, input_quantizer)
             if save_original_input:
                 assert not isinstance(
                     input_quantizer, Float8Quantizer

transformer_engine/pytorch/utils.py (0 additions, 10 deletions)

@@ -447,16 +447,6 @@ def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None:
         )
 
 
-def assert_dim_for_all_gather(
-    tensor: torch.Tensor, with_all_gather: bool, quantizer: Quantizer
-) -> None:
-    """Assert that tensor dimensions are supported for all-gather"""
-    if with_all_gather:
-        assert quantizer.is_quantizable(tensor), (
-            "All-gather requires quantizable tensor for quantizer " + quantizer.__class__.__name__
-        )
-
-
 def is_bf16_compatible() -> bool:
     """Replaces torch.cuda.is_bf16_compatible() with an explicit
     check on device compute capability to enforce sm_80 or higher.
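
Net effect visible in this diff: the hard assert_dim_for_all_gather checks at the Linear, LayerNormLinear, and LayerNormMLP entry points are removed, and non-quantizable inputs instead take the warn-and-gather-in-high-precision path added in distributed.py, with the result re-quantized after the gather.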