Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
88cf1b2
Rebased to main
Oleg-Goncharov Jan 21, 2026
ac23f06
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 21, 2026
44ec5ba
Merge branch 'main' into pr_mxfp8_grouped_kernel
Oleg-Goncharov Jan 21, 2026
99f1f63
Fixed the year to 2026
Oleg-Goncharov Jan 21, 2026
7415138
Added compilation guards
Oleg-Goncharov Jan 21, 2026
adacda9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 21, 2026
39bb24f
Added BWD pass
Oleg-Goncharov Jan 22, 2026
02c05a6
Merge branch 'main' into pr_mxfp8_grouped_kernel
Oleg-Goncharov Jan 22, 2026
452651a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 22, 2026
9da18bf
Merge branch 'main' into pr_mxfp8_grouped_kernel
vthumbe1503 Jan 23, 2026
e8beb1e
Added dbias and dact tests. Refactoring.
Oleg-Goncharov Jan 23, 2026
b3f8468
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 23, 2026
1235167
Added grouped MXFP8 DACT and ACT API and tests
Oleg-Goncharov Jan 24, 2026
34b9dfd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 24, 2026
6dd3814
Fixed a typo
Oleg-Goncharov Jan 24, 2026
c20c9d4
Fixes per the review
Oleg-Goncharov Jan 26, 2026
82e9c77
Merge branch 'main' into pr_mxfp8_grouped_kernel
Oleg-Goncharov Jan 26, 2026
65afe16
More fixes from the review
Oleg-Goncharov Jan 26, 2026
fc0f9e9
Merge branch 'main' into pr_mxfp8_grouped_kernel
Oleg-Goncharov Jan 26, 2026
e01865f
Added fused preswizzling to the kernel
Oleg-Goncharov Jan 28, 2026
bf07d9d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 28, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/cpp/operator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ add_executable(test_operator
test_cast_mxfp8_gated_swiglu.cu
test_qdq.cu
test_cast_mxfp8.cu
test_cast_mxfp8_grouped.cu
test_cast_nvfp4_transpose.cu
test_cast_float8blockwise.cu
test_dequantize_mxfp8.cu
Expand Down
777 changes: 777 additions & 0 deletions tests/cpp/operator/test_cast_mxfp8_grouped.cu

Large diffs are not rendered by default.

73 changes: 73 additions & 0 deletions transformer_engine/common/activation/gelu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,35 @@ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, stream);
}

void nvte_group_gelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_gelu);
using namespace transformer_engine;
constexpr bool IS_ACT = true;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, gelu<fp32, fp32>>(input, output, nullptr,
stream);
}

void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dgelu);
using namespace transformer_engine;
dact_fn<fp32, Empty, dgelu<fp32, fp32>>(grad, input, output, stream);
}

void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_dgelu);
using namespace transformer_engine;
NVTETensor dbias = nullptr;
NVTETensor workspace = nullptr;

constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;

dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
grad, input, output, dbias, workspace, nullptr, stream);
}

void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
Expand All @@ -33,6 +55,20 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati
input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor activation_input,
NVTEGroupedTensor output, NVTETensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias_dgelu);
using namespace transformer_engine;

constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;

dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dgelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_geglu);
using namespace transformer_engine;
Expand All @@ -54,13 +90,36 @@ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream)
act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, stream);
}

void nvte_group_qgelu(const NVTEGroupedTensor input, NVTEGroupedTensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_group_qgelu);
using namespace transformer_engine;
constexpr bool IS_ACT = true;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, qgelu<fp32, fp32>>(input, output, nullptr,
stream);
}

void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dqgelu);
using namespace transformer_engine;
dact_fn<fp32, Empty, dqgelu<fp32, fp32>>(grad, input, output, stream);
}

void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_dqgelu);
using namespace transformer_engine;
NVTETensor dbias = nullptr;
NVTETensor workspace = nullptr;

constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;

dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
grad, input, output, dbias, workspace, nullptr, stream);
}

void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
Expand All @@ -74,6 +133,20 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat
input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor activation_input,
NVTEGroupedTensor output, NVTETensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias_dqgelu);
using namespace transformer_engine;

constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;

dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dqgelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_qgeglu);
using namespace transformer_engine;
Expand Down
73 changes: 73 additions & 0 deletions transformer_engine/common/activation/relu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,35 @@ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
act_fn<fp32, Empty, relu<fp32, fp32>>(input, output, stream);
}

void nvte_group_relu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_relu);
using namespace transformer_engine;
constexpr bool IS_ACT = true;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, relu<fp32, fp32>>(input, output, nullptr,
stream);
}

void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_drelu);
using namespace transformer_engine;
dact_fn<fp32, Empty, drelu<fp32, fp32>>(grad, input, output, stream);
}

void nvte_group_drelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_drelu);
using namespace transformer_engine;
NVTETensor dbias = nullptr;
NVTETensor workspace = nullptr;

constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;

dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, drelu<fp32, fp32>>(
grad, input, output, dbias, workspace, nullptr, stream);
}

void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
Expand All @@ -33,6 +55,20 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati
input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor activation_input,
NVTEGroupedTensor output, NVTETensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias_drelu);
using namespace transformer_engine;

constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;

dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, drelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_reglu);
using namespace transformer_engine;
Expand All @@ -54,13 +90,36 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream)
act_fn<fp32, Empty, srelu<fp32, fp32>>(input, output, stream);
}

void nvte_group_srelu(const NVTEGroupedTensor input, NVTEGroupedTensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_group_srelu);
using namespace transformer_engine;
constexpr bool IS_ACT = true;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, srelu<fp32, fp32>>(input, output, nullptr,
stream);
}

void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dsrelu);
using namespace transformer_engine;
dact_fn<fp32, Empty, dsrelu<fp32, fp32>>(grad, input, output, stream);
}

void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_dsrelu);
using namespace transformer_engine;
NVTETensor dbias = nullptr;
NVTETensor workspace = nullptr;

constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;

dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsrelu<fp32, fp32>>(
grad, input, output, dbias, workspace, nullptr, stream);
}

void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
Expand All @@ -74,6 +133,20 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat
input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor activation_input,
NVTEGroupedTensor output, NVTETensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias_dsrelu);
using namespace transformer_engine;

constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;

dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsrelu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_sreglu);
using namespace transformer_engine;
Expand Down
36 changes: 36 additions & 0 deletions transformer_engine/common/activation/swiglu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,35 @@ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, stream);
}

void nvte_group_silu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_silu);
using namespace transformer_engine;
constexpr bool IS_ACT = true;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, silu<fp32, fp32>>(input, output, nullptr,
stream);
}

void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dsilu);
using namespace transformer_engine;
dact_fn<fp32, Empty, dsilu<fp32, fp32>>(grad, input, output, stream);
}

void nvte_group_dsilu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_dsilu);
using namespace transformer_engine;
NVTETensor dbias = nullptr;
NVTETensor workspace = nullptr;

constexpr bool IS_DBIAS = false;
constexpr bool IS_DACT = true;

dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsilu<fp32, fp32>>(
grad, input, output, dbias, workspace, nullptr, stream);
}

void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input,
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream) {
Expand All @@ -33,6 +55,20 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati
input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input,
const NVTEGroupedTensor activation_input,
NVTEGroupedTensor output, NVTETensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias_dsilu);
using namespace transformer_engine;

constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = true;

dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, dsilu<fp32, fp32>>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_swiglu);
using namespace transformer_engine;
Expand Down
22 changes: 22 additions & 0 deletions transformer_engine/common/cast/cast.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea
dispatch::quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, nullptr, stream);
}

void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize);
using namespace transformer_engine;

constexpr bool IS_ACT = false;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, output, nullptr, stream);
}

void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
cudaStream_t stream) {
NVTE_API_CALL(nvte_quantize_noop);
Expand Down Expand Up @@ -60,6 +69,19 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d
input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias);
using namespace transformer_engine;

constexpr bool IS_DBIAS = true;
constexpr bool IS_DACT = false;
constexpr const NVTEGroupedTensor activation_input = nullptr;

dispatch::group_quantize_bwd_helper<IS_DBIAS, IS_DACT, Empty, nullptr>(
input, activation_input, output, dbias, workspace, nullptr, stream);
}

void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_dequantize);
using namespace transformer_engine;
Expand Down
Loading
Loading