28 changes: 21 additions & 7 deletions tests/pytorch/test_fused_optimizer.py
@@ -407,6 +407,20 @@ def test_bf16_exp_avg_sq(self):
            master_atol=2e-3,
        )

    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 is not supported")
    def test_bf16_exp_avg_and_exp_avg_sq(self):
        self.gen_precision_aware_test(
            use_fp8_params=False,
            param_dtype=torch.bfloat16,
            use_master_weights=True,
            master_weight_dtype=torch.float32,
            grad_dtype=torch.float32,
            exp_avg_dtype=torch.bfloat16,
            exp_avg_sq_dtype=torch.bfloat16,
            master_rtol=2e-3,
            master_atol=2e-3,
        )
Comment on lines +410 to +422
Contributor
Consider adding a test for capturable mode (CUDA Graphs) with BF16 momentums, since the PR description mentions "Enable CUDA Graphs for BF16 momentums" as a key feature. The current test only covers non-capturable mode.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment on lines +410 to +422
Contributor
The test only covers non-capturable mode. Add a test for capturable mode with BF16 momentums, since the PR enables CUDA Graphs for this case.

Note: Check that gen_precision_aware_test supports a capturable parameter, or create a separate test method; a rough sketch is given below.
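
A minimal sketch of what such a capturable-mode test could look like, assuming gen_precision_aware_test is extended with a hypothetical capturable flag that builds FusedAdam with capturable=True and replays the optimizer step under CUDA graph capture (the helper may need that support added first):

    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 is not supported")
    def test_bf16_exp_avg_and_exp_avg_sq_capturable(self):
        # Hypothetical: assumes gen_precision_aware_test accepts a `capturable`
        # keyword that constructs FusedAdam(..., capturable=True) and runs the
        # update step inside a captured/replayed CUDA graph.
        self.gen_precision_aware_test(
            use_fp8_params=False,
            param_dtype=torch.bfloat16,
            use_master_weights=True,
            master_weight_dtype=torch.float32,
            grad_dtype=torch.float32,
            exp_avg_dtype=torch.bfloat16,
            exp_avg_sq_dtype=torch.bfloat16,
            capturable=True,  # hypothetical keyword, not yet in the helper
            master_rtol=2e-3,
            master_atol=2e-3,
        )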


    @pytest.mark.skipif(not is_bf16_available(), reason="bf16 is not supported")
    @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
    def test_fp8_exp_avg_sq(self):
@@ -553,7 +567,7 @@ def forward(self, x):
        return y


class AdamTest:
class TestAdamTest:

    def setup_method(self, *, seed: int = 0) -> None:
        torch.manual_seed(seed)
@@ -569,8 +583,8 @@ def setup_method(self, *, seed: int = 0) -> None:
    def test_grad_scaler(self):
        params_ = [p for p in self.model_.parameters() if p.requires_grad]
        optimizer_ = te.optimizers.FusedAdam(params_, lr=self.lr, capturable=False)
        scaler = torch.cuda.amp.GradScaler(enabled=True)
        scaler_ = torch.cuda.amp.GradScaler(enabled=True)
        scaler = torch.amp.GradScaler("cuda", enabled=True)
        scaler_ = torch.amp.GradScaler("cuda", enabled=True)

        for i in range(100):
            x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last)
@@ -620,8 +634,8 @@ def test_grad_scaler(self):
    def test_grad_scaler_capturable(self):
        params_ = [p for p in self.model_.parameters() if p.requires_grad]
        optimizer_ = te.optimizers.FusedAdam(params_, lr=self.lr, capturable=True)
        scaler = torch.cuda.amp.GradScaler(enabled=True)
        scaler_ = torch.cuda.amp.GradScaler(enabled=True)
        scaler = torch.amp.GradScaler("cuda", enabled=True)
        scaler_ = torch.amp.GradScaler("cuda", enabled=True)

        for i in range(100):
            x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last)
@@ -678,8 +692,8 @@ def test_grad_scaler_capturable_master(self):
        optimizer_ = te.optimizers.FusedAdam(
            params_, lr=self.lr, capturable=True, master_weights=master_weights
        )
        scaler = torch.cuda.amp.GradScaler(enabled=True)
        scaler_ = torch.cuda.amp.GradScaler(enabled=True)
        scaler = torch.amp.GradScaler("cuda", enabled=True)
        scaler_ = torch.amp.GradScaler("cuda", enabled=True)

        for i in range(100):
            x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last)
15 changes: 15 additions & 0 deletions transformer_engine/common/common.h
@@ -717,6 +717,21 @@ struct TypeInfo {
NVTE_ERROR("Invalid type."); \
}

#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP32_BF16(dtype, type, ...) \
  switch (dtype) { \
    using namespace transformer_engine; \
    case DType::kFloat32: { \
      using type = float; \
      { __VA_ARGS__ } \
    } break; \
    case DType::kBFloat16: { \
      using type = bf16; \
      { __VA_ARGS__ } \
    } break; \
    default: \
      NVTE_ERROR("Invalid type, expected Float32 or BFloat16."); \
  }

// Add a pack_size argument to select the packed type for FP4
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY(dtype, pack_size, type, ...) \
  switch (dtype) { \