fix(sft): drop examples with no trainable tokens after max_length truncation#6025
fix(sft): drop examples with no trainable tokens after max_length truncation#60250xadvait wants to merge 1 commit into
Conversation
…runcation When all completion or assistant tokens of an example lie beyond the max_length truncation boundary (e.g. the prompt alone reaches max_length), every label is set to -100 after truncation in the data collator. Such examples contribute nothing to the loss, and a batch made entirely of them silently trains with loss 0.0 and grad_norm nan. Drop these examples at dataset preparation time, warn with the dropped count, and raise an actionable error when no example survives. Fixes huggingface#3927
|
Thanks @0xadvait, nice diagnosis; the silent zero-loss from post-tokenization truncation is a real footgun. My hesitation: I'd rather fix it from the other end. The
Truncation can stay in the collator (moving it fights padding_free/packing), so the filter still slices with the same window, but reads one column and carries none of the mask-precedence logic. Since that's bigger than your diff, I'd split it: a first PR moving label construction into prep (and slimming the collator), then the drop-and-warn filter on top; almost trivial at that point. Happy to help with the first part. WDYT? |
|
Agreed on the drift risk; the filter already had to mirror the I took a pass through what part 1 touches, and it's more drop-in than I expected:
For the residual duplication in part 2 (the truncation window), a tiny shared helper for the Happy to draft part 1 along these lines as a separate PR and then strip this one down to the trivial labels filter on top -- or if you'd rather take part 1, I'll rebase once it's in. Converting this to draft meanwhile. |
|
Yes happy to receive a separate pr for 1. |
|
Part 1 is up: #6037. Once it lands I'll strip this PR down to the labels-column filter using the shared truncation window. |
What does this PR do?
assistant_only_loss=True(or completion-only training) combined withmax_lengthtruncation can silently train on nothing.The all-zero
assistant_maskscheck intokenize_fnruns before truncation, but truncation happens later, inDataCollatorForLanguageModeling. An example whose assistant/completion tokens all lie beyondmax_length(e.g. the prompt alone reachesmax_length) passes the check, then gets every label set to -100 by the collator. Since fully-masked micro-batches return a graph-connected zero loss (for DDP/FSDP safety), a batch made of such examples trains with no error and no warning, on current main:(repro: 8 conversational examples with a ~200-token user turn and a short assistant reply,
max_length=64,assistant_only_loss=True;num_tokenskeeps counting while nothing is learned)When only some examples are affected, they silently dilute training; the failure surfaces, if ever, as a confusing downstream error.
The fix
This implements the approach discussed in #3930 -- drop the affected examples at dataset preparation time and log how many were dropped -- without the dataset-wide trainable-token counting that stalled that PR:
max_lengthis set and packing is disabled, a batchedfilterkeeps only examples with at least one trainable token within the truncation boundaryassistant_masksis always applied by the collator when present, whilecompletion_maskis only applied whencompletion_only_lossresolves toTrue-- the filter mirrors exactly that, so examples are never dropped based on a mask the collator won't usetruncation_mode="keep_start"and"keep_end"are handled with the same slices the collator usesValueError(previously: silent zero-loss training, or a cryptic downstream failure)IterableDataset, the same filter applies lazily; the counting/warning applies to map-style datasets wherelen()is availableNo new dependencies, and a single extra batched
filterpass (num_proc-parallel) that only runs when a mask column is present and truncation is active.Tests
test_assistant_only_loss_drops_examples_with_no_trainable_tokens_after_truncation: 2 of 3 examples dropped, warning emitted, training proceeds on the survivortest_assistant_only_loss_raises_if_no_example_has_trainable_tokens: clearValueErrorwhen nothing survivestest_completion_only_loss_drops_examples_with_no_trainable_tokens_after_truncation: same protection for prompt-completion datasets viacompletion_masktest_truncation_keep_end_keeps_examples_with_trailing_trainable_tokens: guards against over-dropping -- withkeep_end, trailing completions survive and nothing is droppedThe first three fail without the fix; the last passes before and after.
Fixes #3927
Before submitting
AI writing disclosure
We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.
Who can review?
@qgallouedec (proposed the drop-and-warn design in #3930)
Note
Medium Risk
Changes dataset preparation for a common SFT configuration (max_length + assistant/completion-only loss), which can shrink training data and fail fast when all examples are invalid; logic must stay aligned with collator truncation to avoid incorrect drops.
Overview
Fixes silent zero-loss SFT when
max_lengthtruncation cuts off every assistant/completion token after tokenization checks already passed (#3927).SFTTrainernow runs a post-tokenizationfilter(only whenmax_lengthis set, packing is off, andassistant_masks/ applicablecompletion_maskcolumns exist). It keeps examples that still have at least one trainable token inside the same truncation window the collator uses (keep_startvskeep_end), matching which masks actually affect labels. Partial drops log a warning with counts; if everything would be dropped, it raises a clearValueErrorinstead of training on all-100labels.Tests cover assistant-only and completion-only drops, the all-dropped error path, and that
truncation_mode="keep_end"does not over-drop trailing completions.Reviewed by Cursor Bugbot for commit ea5dd34. Bugbot is set up for automated code reviews on this repo. Configure here.