test: don't hard-code bf16=True on devices that lack bf16 support #6036
behroozazarkhalili wants to merge 5 commits
Pull request overview
This PR updates the test suite to avoid hard-coding bf16=True on machines that cannot support it (CPU and pre-Ampere CUDA GPUs), aligning test behavior with transformers.TrainingArguments bf16 validation so tests skip/adjust instead of failing during argument validation.
Changes:
- Add is_bf16_supported() to tests/testing_utils.py based on is_torch_bf16_gpu_available() or is_torch_xla_available().
- Use bf16=is_bf16_supported() in the opportunistic GRPO vLLM colocate test to avoid invalid bf16 settings on unsupported devices.
- Skip invariant FA2 configs requiring bf16 when bf16 isn't supported, preventing validation-time failures.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| tests/testing_utils.py | Introduces is_bf16_supported() helper to centralize bf16 capability checks for tests. |
| tests/test_grpo_trainer.py | Switches a GRPO vLLM test to conditionally enable bf16 only when supported. |
| tests/invariant/test_invariant.py | Skips FA2 invariant members on non-bf16-capable devices to avoid TrainingArguments validation errors. |
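For readers following along, here is a minimal sketch of what such a helper could look like, assuming it only combines the two transformers probes named in the overview. The ImportError fallback is an illustration-only addition so the snippet still runs where transformers is not installed; the actual body in tests/testing_utils.py may differ.

```python
try:
    # The two probes named in the overview above.
    from transformers.utils import is_torch_bf16_gpu_available, is_torch_xla_available
except ImportError:
    # Illustration-only fallback (an assumption, not part of the PR):
    # without transformers we report "no bf16" so the sketch stays runnable.
    def is_torch_bf16_gpu_available() -> bool:
        return False

    def is_torch_xla_available() -> bool:
        return False


def is_bf16_supported() -> bool:
    """Return True when either probe reports a bf16-capable device."""
    return bool(is_torch_bf16_gpu_available() or is_torch_xla_available())
```

A test could then pass bf16=is_bf16_supported() wherever bf16 is only an optimization, so the flag is enabled only on hardware that passes the same check.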
The default bf16=True (#3515) makes test cases raise "Your setup doesn't support bf16/gpu" on a CPU or a pre-Ampere GPU (e.g. T4). Add an is_bf16_supported() helper, and use bf16=is_bf16_supported() at the opportunistic GRPO site. The flash-attention-2 sites, which genuinely require bf16, are already guarded by is_ampere_or_newer; the invariant FA2 members now skip on non-bf16 devices instead of erroring. Resolves #3616
-    bf16=True,  # Use bfloat16 to reduce memory
+    bf16=is_bf16_supported(),  # bfloat16 to reduce memory, when the device supports it
wouldn't it be easier to just remove bf16=True?
Done in 2a2d3b5 — went with explicit bf16=False rather than deleting the line, since bf16 defaults to None and _BaseConfig.__post_init__ resolves it to not fp16 = True, so a bare delete would keep it True. (args.bf16 is unused on this construction-only vLLM path anyway.)
…ity helper

Removing the line would resolve bf16 back to True (base_config default), so set it explicitly. args.bf16 is unused on this construction-only vLLM path.
Good call, you're right that … One detail on why I went with …:

self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16

It doesn't actually change what this test exercises either way: I kept …
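The default-resolution behaviour discussed in this thread can be reproduced in isolation. The sketch below uses a hypothetical SketchConfig dataclass (not the real _BaseConfig) that carries only the two relevant fields and the expression quoted in the thread; it is meant to show why deleting the bf16 line and setting bf16=False are not equivalent.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchConfig:
    """Hypothetical stand-in for the base config; only the bf16/fp16 fields."""

    fp16: bool = False
    bf16: Optional[bool] = None  # None means "not set by the caller"

    def __post_init__(self) -> None:
        # Same expression as the one quoted in the thread.
        self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16


omitted = SketchConfig()             # bf16 left unset, so it is derived from fp16
explicit = SketchConfig(bf16=False)  # an explicit value is kept as-is
with_fp16 = SketchConfig(fp16=True)  # fp16 enabled, so bf16 resolves to False

print(omitted.bf16, explicit.bf16, with_fp16.bf16)  # True False False
```

Under this assumption, only the explicit bf16=False form keeps bf16 off when fp16 is also off.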
What

Test cases that hard-code bf16=True raise "Your setup doesn't support bf16/gpu" when run on a CPU or a pre-Ampere GPU (e.g. T4), because the default bf16=True was introduced in #3515 and transformers' TrainingArguments validation rejects bf16=True on such devices.

How

- Add is_bf16_supported() to tests/testing_utils.py, mirroring the exact condition TrainingArguments accepts for bf16=True (is_torch_bf16_gpu_available() or is_torch_xla_available()), as a sibling to the existing is_ampere_or_newer().
- Use bf16=is_bf16_supported() at the one opportunistic GRPO site (test_vlm_processor_vllm_colocate_mode, where bf16 is a memory optimization, not a flash-attention-2 requirement).
- The flash-attention-2 sites that genuinely require bf16 (test_train_padding_free in SFT/DPO, test_vlm_training in GRPO) are already guarded by @skipif(not is_ampere_or_newer() and torch_device != "xpu"), so they are left unchanged.
- The invariant FA2 members (sft_fa2, sft_fa2_padfree) require bf16 for the bf16-only FA2 kernels, so they now skip on a non-bf16 device instead of erroring.

This follows the approach suggested in the issue (a small is_bf16_supported() helper threaded through the affected tests).

Verification

- … is_bf16_supported() returns False, matching the device condition TrainingArguments validates.
- … bf16=True); they now skip with a clear reason.
- ruff check and ruff format --check pass on the changed files.

Resolves #3616
Note
Low Risk
Test-only changes with no production training logic; risk is limited to CI behavior on CPU or older GPUs.
Overview
Adds is_bf16_supported() in tests/testing_utils.py (aligned with transformers' TrainingArguments checks via is_torch_bf16_gpu_available() and is_torch_xla_available()), so tests can detect CPU / pre-Ampere GPUs where bf16=True fails validation.

Invariant FA2 cases (sft_fa2, sft_fa2_padfree) still require bf16 for the kernels; test_invariant now skips when the config has bf16=True and the helper returns false, instead of erroring during CLI training.

In test_vlm_processor_vllm_colocate_mode, GRPOConfig bf16 is set to False (replacing True) because the test only exercises trainer setup, not train(), and True breaks on non-bf16 hardware.

Reviewed by Cursor Bugbot for commit 2a2d3b5.
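The skip condition described for test_invariant (config has bf16=True and the helper returns false) can be sketched roughly as follows. This is an illustration under assumptions, not the code from the PR: the dict-based configs, the sft_plain entry, and the skip_if_bf16_unsupported and outcome names are all hypothetical, and unittest.SkipTest is used only to keep the snippet within the standard library.

```python
import unittest


def skip_if_bf16_unsupported(config: dict, bf16_supported: bool) -> None:
    """Raise SkipTest when the config needs bf16 but the device lacks it."""
    if config.get("bf16") is True and not bf16_supported:
        raise unittest.SkipTest("config requires bf16, which this device does not support")


# Hypothetical configs: only the sft_fa2 name comes from the PR text,
# sft_plain is made up to show a member that does not need bf16.
CONFIGS = {
    "sft_fa2": {"bf16": True},
    "sft_plain": {"bf16": False},
}


def outcome(name: str, bf16_supported: bool) -> str:
    """Report whether a member would run or be skipped on a given device."""
    try:
        skip_if_bf16_unsupported(CONFIGS[name], bf16_supported)
    except unittest.SkipTest:
        return "skip"
    return "run"


print(outcome("sft_fa2", bf16_supported=False))    # skip
print(outcome("sft_fa2", bf16_supported=True))     # run
print(outcome("sft_plain", bf16_supported=False))  # run
```

Checking the flag before training starts is what turns a validation-time error into a skip with an explicit reason.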