diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py
index f269e21b8c..2208ff720c 100644
--- a/transformer_engine/pytorch/distributed.py
+++ b/transformer_engine/pytorch/distributed.py
@@ -1100,7 +1100,10 @@ def _start_all_gather_fp8_blockwise(
 
     # Fall back to high-precision all-gather if FP8 is not supported
     if not quantizer.is_quantizable(inp) or quantizer.block_scaling_dim != 1:
-        out = torch.empty(out_shape, dtype=dtype, device=device)
+        warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
+        if isinstance(inp, QuantizedTensorStorage):
+            inp = inp.dequantize()  # Dequantize if needed
+        out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
         torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
         out = quantizer(out)
         return out, None
@@ -1338,10 +1341,13 @@ def _all_gather_nvfp4(
         and quantizer is not None
         and not quantizer.is_quantizable(inp)
     ):
+        warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
+        if isinstance(inp, QuantizedTensorStorage):
+            inp = inp.dequantize()  # Dequantize if needed
         out = torch.empty(
             out_shape,
-            dtype=dtype,
-            device=device,
+            dtype=inp.dtype,
+            device=inp.device,
             memory_format=torch.contiguous_format,
         )
         torch.distributed.all_gather_into_tensor(out, inp, group=process_group)
@@ -1505,10 +1511,13 @@ def _all_gather_mxfp8(
         and quantizer is not None
         and not quantizer.is_quantizable(inp)
     ):
+        warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
+        if isinstance(inp, QuantizedTensorStorage):
+            inp = inp.dequantize()  # Dequantize if needed
         out = torch.empty(
             out_shape,
-            dtype=dtype,
-            device=device,
+            dtype=inp.dtype,
+            device=inp.device,
             memory_format=torch.contiguous_format,
         )
         torch.distributed.all_gather_into_tensor(out, inp, group=process_group)
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index 702916696b..f79bc91c0a 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -29,7 +29,6 @@
 from ..quantization import FP8GlobalStateManager
 from ..utils import (
     assert_dim_for_fp8_exec,
-    assert_dim_for_all_gather,
     cast_if_needed,
     clear_tensor_data,
     divide,
@@ -158,7 +157,6 @@ def forward(
         inputmat = inp
         if fp8:
             assert_dim_for_fp8_exec(inputmat, weight)
-            assert_dim_for_all_gather(inputmat, with_input_all_gather, input_quantizer)
 
         # Cast for native AMP
         nvtx_range_push(f"{nvtx_label}.norm_input_cast")
diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py
index bec6744518..d9f046aa38 100644
--- a/transformer_engine/pytorch/module/layernorm_mlp.py
+++ b/transformer_engine/pytorch/module/layernorm_mlp.py
@@ -40,7 +40,6 @@
     init_method_constant,
     cast_if_needed,
     assert_dim_for_fp8_exec,
-    assert_dim_for_all_gather,
     clear_tensor_data,
     requires_grad,
     needs_quantized_gemm,
@@ -331,7 +330,6 @@ def _forward(
         inputmat = inp.view((-1, in_features))
         if fp8:
             assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight)
-            assert_dim_for_all_gather(inputmat, sequence_parallel, fc1_input_quantizer)
 
         activation_func = _act_func(
             activation, FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 23ad8cacb0..d7283cc047 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -34,7 +34,6 @@
     requires_grad,
     needs_quantized_gemm,
     assert_dim_for_fp8_exec,
-    assert_dim_for_all_gather,
     nvtx_range_pop,
     nvtx_range_push,
     get_nvtx_range_context,
@@ -175,7 +174,6 @@ def forward(
         own_quantized_input = False
         if fp8:
             assert_dim_for_fp8_exec(inputmat, weight)
-            assert_dim_for_all_gather(inputmat, with_input_all_gather_nccl, input_quantizer)
             if save_original_input:
                 assert not isinstance(
                     input_quantizer, Float8Quantizer
diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py
index 47af9fabe1..0a74c75edd 100644
--- a/transformer_engine/pytorch/utils.py
+++ b/transformer_engine/pytorch/utils.py
@@ -447,16 +447,6 @@ def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None:
         )
 
 
-def assert_dim_for_all_gather(
-    tensor: torch.Tensor, with_all_gather: bool, quantizer: Quantizer
-) -> None:
-    """Assert that tensor dimensions are supported for all-gather"""
-    if with_all_gather:
-        assert quantizer.is_quantizable(tensor), (
-            "All-gather requires quantizable tensor for quantizer " + quantizer.__class__.__name__
-        )
-
-
 def is_bf16_compatible() -> bool:
     """Replaces torch.cuda.is_bf16_compatible() with an explicit check on device
     compute capability to enforce sm_80 or higher.
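
For context, the fallback path this patch adds to `_start_all_gather_fp8_blockwise`, `_all_gather_nvfp4`, and `_all_gather_mxfp8` boils down to the sketch below. This is a simplified illustration, not the actual Transformer Engine code: the helper name `_all_gather_with_fallback` is made up, and the quantizer interface (`is_quantizable`, calling the quantizer to quantize a tensor) and `QuantizedTensorStorage.dequantize()` are assumed only as far as the diff itself shows them.

    # Simplified sketch (assumed names) of the high-precision fallback: when the
    # quantizer cannot handle the input, warn, dequantize any quantized storage,
    # all-gather in the input's own dtype/device, then re-quantize the result.
    import warnings

    import torch
    import torch.distributed as dist


    def _all_gather_with_fallback(inp, out_shape, quantizer, process_group):
        """All-gather `inp` into a tensor of shape `out_shape`, quantizing when possible."""
        if quantizer is None or not quantizer.is_quantizable(inp):
            warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
            if hasattr(inp, "dequantize"):  # e.g. a QuantizedTensorStorage
                inp = inp.dequantize()
            out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
            dist.all_gather_into_tensor(out, inp, group=process_group)
            return quantizer(out) if quantizer is not None else out
        # Quantized all-gather path (unchanged by this patch) would go here.
        raise NotImplementedError

The key behavioral change mirrored here is that the gathered output now takes its dtype and device from the (possibly dequantized) input rather than from the caller-supplied `dtype`/`device`, and the unquantizable case emits a warning instead of tripping the removed `assert_dim_for_all_gather` check.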