Add inplace=False to liger_cross_entropy to fix upstream gradient corruption (#272) #1251

Open
lollinng wants to merge 2 commits into linkedin:main from lollinng:fix/cross-entropy-inplace-272

Conversation

lollinng commented Jun 4, 2026

Problem (#272)

liger_cross_entropy stores its backward gradient in-place into its input tensor via a Triton tl.store. Triton writes don't bump PyTorch's tensor version counter, so autograd's in-place-correctness check never fires. When the input is the output of an upstream op (e.g. a softmax) that saved it for its own backward, that tensor is silently overwritten and the upstream op computes wrong gradients with no error — the exact scenario in #272.
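The mechanism can be reproduced on CPU without Triton. The sketch below is not the code from this PR: it assumes that a write through `.data` is a fair stand-in for the kernel's `tl.store`, since both change the storage without touching the version counter that autograd checks. The names `softmax_input_grad` and `tracked_inplace_raises` are hypothetical helpers for illustration.

```python
import torch

torch.manual_seed(0)
base = torch.randn(4, 8)
grad_out = torch.randn(4, 8)


def softmax_input_grad(overwrite: bool):
    """Backprop through softmax, optionally clobbering its saved output first."""
    x = base.clone().requires_grad_(True)
    probs = torch.softmax(x, dim=-1)  # softmax saves `probs` for its backward
    version_before = probs._version
    if overwrite:
        # Stand-in for the Triton tl.store (an assumption of this sketch):
        # `.data` shares storage with `probs` but not its version counter.
        probs.data.zero_()
    version_unchanged = probs._version == version_before
    probs.backward(grad_out)  # no error is raised either way
    return x.grad, version_unchanged


clean_grad, _ = softmax_input_grad(overwrite=False)
corrupt_grad, version_unchanged = softmax_input_grad(overwrite=True)


def tracked_inplace_raises() -> bool:
    """A tracked in-place op on the same tensor is caught by autograd."""
    x = base.clone().requires_grad_(True)
    probs = torch.softmax(x, dim=-1)
    probs.mul_(0.5)  # goes through autograd, so the version counter is bumped
    try:
        probs.backward(grad_out)
    except RuntimeError:
        return True
    return False
```

Here `corrupt_grad` differs from `clean_grad` even though `version_unchanged` is true and backward ran without complaint, whereas the tracked `mul_` path fails loudly. That contrast is why the untracked write is dangerous.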

Fix

Thread an inplace flag (default True, so existing behavior and the memory savings are unchanged) through liger_cross_entropy → LigerCrossEntropyFunction → cross_entropy_forward. When inplace=False, the gradient is computed into a clone of the input, leaving the caller's tensor intact.

One subtlety worth calling out: the clone happens inside autograd.Function.forward (grad disabled), so the clone reports requires_grad=False. Since the kernel only writes the gradient when HAS_GRADIENTS is true, I capture requires_grad from the original input before cloning — otherwise the kernel would skip the gradient write entirely.
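A toy version of this pattern makes the subtlety visible. It is a minimal sketch under stated assumptions, not the liger implementation: `ToyStoredGradLoss` and the `observed` dict are hypothetical, the loss is simply `sum(x**2)`, and a `.data` write stands in for the kernel's gradient store.

```python
import torch

observed = {}  # records what forward sees; illustration only


class ToyStoredGradLoss(torch.autograd.Function):
    """Hypothetical toy: loss = sum(x**2), gradient 2*x precomputed in forward."""

    @staticmethod
    def forward(ctx, x, inplace=True):
        # Read the flag from the caller's tensor before any cloning.
        has_gradients = x.requires_grad
        buf = x if inplace else x.clone()
        observed["grad_enabled"] = torch.is_grad_enabled()
        observed["input_requires_grad"] = x.requires_grad
        observed["buffer_requires_grad"] = buf.requires_grad
        loss = (x ** 2).sum()
        if has_gradients:
            # Raw write standing in for the kernel's gradient store:
            # buf now holds d(loss)/dx = 2 * x.
            buf.data.mul_(2.0)
        ctx.save_for_backward(buf.detach())
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        (stored_grad,) = ctx.saved_tensors
        # One gradient per forward argument; `inplace` gets None.
        return grad_output * stored_grad, None


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
x_before = x.detach().clone()
loss = ToyStoredGradLoss.apply(x, False)  # the inplace=False path
loss.backward()
```

With `inplace=False`, `observed["buffer_requires_grad"]` is false while `observed["input_requires_grad"]` is true. Gating the write on the clone's flag would therefore skip it, while gating on the value captured from the original input keeps the gradient correct and leaves `x` equal to `x_before`.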

Verification (NVIDIA T4)

Ran the issue's reproducer (softmax(_p) → cross_entropy → backward) on a GPU, comparing the gradient w.r.t. the pre-softmax input against F.cross_entropy:

reference (F.cross_entropy) _p.grad[0]: [ 0.00168,  0.00287,  0.00029,  0.00546, -0.01822,  0.00242,  0.00231,  0.00319]
liger default (in-place)    _p.grad[0]: [ 2.1e-05,  3.5e-05,  6.8e-06,  6.7e-05,  0.01325,  3.0e-05,  2.8e-05,  3.9e-05]  -> matches ref: False  (the bug)
liger inplace=False         _p.grad[0]: [ 0.00168,  0.00287,  0.00029,  0.00546, -0.01822,  0.00242,  0.00231,  0.00319]  -> matches ref: True
loss matches ref: True

Added test/transformers/test_cross_entropy.py::test_cross_entropy_inplace_does_not_corrupt_upstream_grad, which asserts the default path corrupts the upstream gradient (reproducing the bug) and that inplace=False matches the F.cross_entropy reference.

Fixes #272

lollinng and others added 2 commits June 5, 2026 01:19
… grads (linkedin#272)

liger_cross_entropy stores its backward gradient in-place into the forward
input via a Triton tl.store. Triton writes do not bump PyTorch's tensor
version counter, so autograd's in-place-correctness check never fires. If an
upstream op (e.g. softmax) saved that same tensor for its own backward, it is
silently overwritten and computes wrong gradients with no error (linkedin#272).

Thread an inplace flag (default True, preserving the current memory-saving
behavior) through liger_cross_entropy -> LigerCrossEntropyFunction ->
cross_entropy_forward. When inplace=False, operate on a clone of the input so
the caller's tensor is preserved and upstream gradients stay correct.

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

The kernel only writes the gradient when HAS_GRADIENTS is true, which was read
from _input.requires_grad at kernel-launch time. With inplace=False the clone
happens inside the autograd Function (grad disabled), so the clone reported
requires_grad=False and the kernel skipped the gradient write, leaving logit
values in the returned buffer. Capture requires_grad before the clone.

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

Development

Successfully merging this pull request may close these issues.

In-place operations in triton kernel might result in incorrect gradient calculations