
Fix exp2f_rcp to properly handle nan and 0xFE cases #2647

Merged
ptrendx merged 1 commit into NVIDIA:main from kainzhong:fix/fix_exp2f_rcp
Feb 6, 2026

@kainzhong
Collaborator

Description

Fix the implementation of exp2f_rcp

Fixes #2408

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

Our previous implementation of exp2f_rcp was not accurate. After some discussion, we decided on the following behavior:

  • If biased_exp is NaN (0xFF), return NaN.
  • Otherwise, return the reciprocal using the fast-math trick:
    • If biased_exp == 254 (2^127), return the hardcoded result directly. This is the only case where mantissa bits are needed to represent the reciprocal, 2^-127.
    • Otherwise, use a bit-shifting trick to obtain the float reciprocal.
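
The three cases listed above can be sketched on the host roughly as follows. This is a minimal illustration, not the merged code: the names exp2f_rcp_sketch and bits_to_float are hypothetical, e8m0_t is assumed to be a plain 8-bit integer, and the bit patterns are the ones quoted in the review summary on this PR.

```cpp
#include <cstdint>
#include <cstring>

// Assumption: an E8M0 scale is stored as one byte holding a biased exponent.
using e8m0_t = std::uint8_t;

// Hypothetical helper: reinterpret 32 bits as a float through memcpy.
inline float bits_to_float(std::uint32_t bits) {
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Sketch of the reciprocal 1 / 2^(biased_exp - 127).
inline float exp2f_rcp_sketch(e8m0_t biased_exp) {
  if (biased_exp == 0xFF) {
    return bits_to_float(0x7fffffffu);  // NaN encoding maps to a quiet NaN
  }
  if (biased_exp == 0xFE) {
    // 2^-127 is subnormal in binary32: exponent field 0, top mantissa bit set.
    return bits_to_float(0x00400000u);
  }
  // Exponent field (254 - biased_exp) encodes 2^(127 - biased_exp).
  return bits_to_float(static_cast<std::uint32_t>(254 - biased_exp) << 23);
}
```

For example, exp2f_rcp_sketch(127) yields 1.0f and exp2f_rcp_sketch(128) yields 0.5f, since the scale values are 2^0 and 2^1 respectively.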

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
@ptrendx
Member

ptrendx commented Feb 4, 2026

/te-ci

@kainzhong kainzhong marked this pull request as ready for review February 4, 2026 21:08
@greptile-apps
Contributor

greptile-apps bot commented Feb 4, 2026

Greptile Overview

Greptile Summary

This PR updates the exp2f_rcp helper (device in transformer_engine/common/util/ptx.cuh and host reference in tests/cpp/test_common.h) to correctly handle two edge encodings from E8M0 exponents: 0xFF (NaN) and 0xFE (the only case where the reciprocal needs a mantissa bit to represent 2^-127). The normal path still uses the fast exponent-bit trick via __int_as_float((254 - biased_exp) << 23).
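
To see why the two encodings need special handling, it may help to look at the raw bit patterns the exponent-bit trick would produce if it were applied to every input. The helper below is hypothetical and does the arithmetic in unsigned 32-bit integers so the shift is well-defined; it is an illustration of the expression quoted above, not a claim about what the earlier implementation returned.

```cpp
#include <cstdint>

// Hypothetical helper: bit pattern of (254 - biased_exp) << 23 with no special cases.
constexpr std::uint32_t naive_rcp_bits(std::uint8_t biased_exp) {
  return static_cast<std::uint32_t>(254 - static_cast<int>(biased_exp)) << 23;
}

// Ordinary inputs land on the expected power of two.
static_assert(naive_rcp_bits(127) == 0x3f800000u, "scale 2^0 gives 1.0f");
static_assert(naive_rcp_bits(0) == 0x7f000000u, "scale 2^-127 gives 2^127");

// 0xFE: the exponent field becomes 0 with an empty mantissa, which is +0.0f.
// The true reciprocal 2^-127 is the subnormal 0x00400000.
static_assert(naive_rcp_bits(0xFE) == 0x00000000u, "zero instead of 2^-127");

// 0xFF: the subtraction wraps, giving the bit pattern of -infinity rather than NaN.
static_assert(naive_rcp_bits(0xFF) == 0xff800000u, "not a NaN pattern");
```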

The main blocking issues are unrelated to the new math cases: both files still rely on reinterpret_cast-based type punning (float->uint32_t, uint16_t->uint8_t, and int32_t->float), which is undefined behavior under C++ strict-aliasing and can miscompile in optimized builds. These helpers are used broadly in quantization/scaling codepaths and as test references, so they should be made well-defined (e.g., memcpy/bit_cast, and masking for the uint16_t case).
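
A well-defined version of the float-to-integer reinterpretation could look like the sketch below. The function names float_bits and biased_exponent are hypothetical and do not come from the repository; the sketch assumes a 32-bit IEEE-754 float.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical helper: read the bit pattern of a float without aliasing violations.
inline std::uint32_t float_bits(float value) {
  static_assert(sizeof(std::uint32_t) == sizeof(float), "expects a 32-bit float");
  std::uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));
  return bits;
}

// Extract the 8-bit biased exponent field (bits 23..30) of a binary32 value.
inline std::uint8_t biased_exponent(float value) {
  return static_cast<std::uint8_t>((float_bits(value) >> 23) & 0xFFu);
}
```

Compilers typically lower a fixed-size memcpy like this to a register move. With C++20 available, std::bit_cast<std::uint32_t>(value) expresses the same thing, and device code can use the CUDA intrinsics such as __float_as_uint and __int_as_float.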

Confidence Score: 3/5

  • This PR is close to mergeable but has a couple of correctness footguns that should be fixed first.
  • The exp2f_rcp edge-case logic itself is straightforward and consistent, but the surrounding helper code in both modified files uses strict-aliasing-violating reinterpret_cast type punning that can lead to wrong-code under optimization, including in test reference computations.
  • transformer_engine/common/util/ptx.cuh and tests/cpp/test_common.h (type punning sites in float_to_e8m0 and exp2f_rcp).

Important Files Changed

Filename | Overview
transformer_engine/common/util/ptx.cuh | Updates device exp2f_rcp to special-case biased_exp 0xFF (NaN) and 0xFE (reciprocal needing mantissa bit). Logic looks consistent, but file still contains UB-prone type punning (e.g., reinterpret_cast in float_to_e8m0) that the PR touches nearby and may impact optimized builds.
tests/cpp/test_common.h | Updates host test helper exp2f_rcp to match new NaN/0xFE behavior. However, exp2f_rcp uses reinterpret_cast<int32_t*>->float and float_to_e8m0 uses reinterpret_cast<float*>->uint32_t, both strict-aliasing UB that can break under optimization.

Sequence Diagram

sequenceDiagram
  participant K as Kernel/Device code
  participant P as ptx.cuh::exp2f_rcp
  participant F as __int_as_float
  participant T as Host tests
  participant H as test_common.h::exp2f_rcp

  K->>P: exp2f_rcp(biased_exp: e8m0_t)
  alt biased_exp == 0xFF (NaN)
    P->>F: __int_as_float(0x7fffffff)
    F-->>P: qNaN float
  else biased_exp == 0xFE
    P->>F: __int_as_float(0x00400000)
    F-->>P: 2^-127 (subnormal)
  else other
    P->>F: __int_as_float((254 - biased_exp) << 23)
    F-->>P: 2^(127 - biased_exp)
  end
  P-->>K: float reciprocal

  T->>H: exp2f_rcp(biased_exp: fp8e8m0)
  alt biased_exp == 0xFF
    H-->>T: reinterpret bits 0x7fffffff (qNaN)
  else biased_exp == 0xFE
    H-->>T: reinterpret bits 0x00400000 (2^-127)
  else other
    H-->>T: reinterpret bits (254 - biased_exp) << 23
  end

Contributor

@greptile-apps greptile-apps bot left a comment

2 files reviewed, 2 comments

@greptile-apps
Contributor

greptile-apps bot commented Feb 4, 2026

Additional Comments (2)

transformer_engine/common/util/ptx.cuh
Undefined behavior type punning
float_to_e8m0 returns *reinterpret_cast<e8m0_t *>(&out) (ptx.cuh:354) and later reads val via *reinterpret_cast<uint32_t *>(&val) (ptx.cuh:367). Both violate C++ strict-aliasing rules and can miscompile under optimization (especially with -O2/-O3 -fstrict-aliasing). Prefer a well-defined bit cast (memcpy/std::bit_cast) and for the uint16_t out case simply return static_cast<e8m0_t>(out & 0xFF); to avoid aliasing/alignment issues.
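
The masking suggestion for the uint16_t case can be sketched as below. The variable name out follows the comment above; the function name low_byte is hypothetical, and e8m0_t is assumed to be a one-byte unsigned type.

```cpp
#include <cstdint>

// Assumption: E8M0 is stored as a single unsigned byte.
using e8m0_t = std::uint8_t;

// Keep the low 8 bits of a 16-bit intermediate by value.
// No pointer cast is involved, so the result does not depend on byte order.
constexpr e8m0_t low_byte(std::uint16_t out) {
  return static_cast<e8m0_t>(out & 0xFFu);
}
```

Because the conversion works on the value rather than on its storage, it behaves the same on host and device and is usable in constant expressions.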


tests/cpp/test_common.h
Strict-aliasing UB in tests
float_to_e8m0 reads val with *reinterpret_cast<uint32_t*>(&val) (test_common.h:417) and exp2f_rcp converts int_val to float via *reinterpret_cast<float*>(&int_val) (test_common.h:436). These are undefined behavior under C++ strict-aliasing and can produce wrong results in optimized test builds; use std::memcpy/std::bit_cast for the bit reinterpretation instead so the reference implementation remains reliable.
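
One way to gain confidence in a host reference like this is to compare it against an independent computation for every encoding, which is cheap because there are only 256 of them. The sketch below is an assumption-laden illustration: rcp_candidate is a hypothetical stand-in for the helper under test, and std::ldexp serves as the independent reference for 2^(127 - biased_exp).

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>

// Hypothetical candidate under test, mirroring the three cases described in this PR.
inline float rcp_candidate(std::uint8_t biased_exp) {
  std::uint32_t bits = 0x7fffffffu;  // default covers 0xFF: quiet NaN
  if (biased_exp == 0xFE) {
    bits = 0x00400000u;  // 2^-127 as a subnormal
  } else if (biased_exp != 0xFF) {
    bits = static_cast<std::uint32_t>(254 - biased_exp) << 23;
  }
  float result;
  std::memcpy(&result, &bits, sizeof(result));  // well-defined bit reinterpretation
  return result;
}

// Check every non-NaN encoding against ldexp, then check the NaN encoding.
inline bool matches_ldexp_reference() {
  for (int b = 0; b <= 0xFE; ++b) {
    const float expected = std::ldexp(1.0f, 127 - b);
    if (rcp_candidate(static_cast<std::uint8_t>(b)) != expected) {
      return false;
    }
  }
  return std::isnan(rcp_candidate(0xFF));
}
```

The comparison is exact because every expected value is a power of two representable in binary32, including the subnormal 2^-127. It assumes the host build does not flush subnormals to zero.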

@ptrendx ptrendx merged commit 71971e3 into NVIDIA:main Feb 6, 2026
34 of 42 checks passed

Development

Successfully merging this pull request may close these issues.

Found that exp2f_rcp has a precision-handling issue when dealing with 0xfe value

2 participants