Skip to content

test: don't hard-code bf16=True on devices that lack bf16 support#6036

Open
behroozazarkhalili wants to merge 5 commits into
mainfrom
fix/3616-bf16-cpu-tests
Open

test: don't hard-code bf16=True on devices that lack bf16 support#6036
behroozazarkhalili wants to merge 5 commits into
mainfrom
fix/3616-bf16-cpu-tests

Conversation

@behroozazarkhalili

@behroozazarkhalili behroozazarkhalili commented Jun 12, 2026

Copy link
Copy Markdown
Collaborator

What

Test cases that hard-code bf16=True raise "Your setup doesn't support bf16/gpu" when run on a CPU or a pre-Ampere GPU (e.g. T4), because the default bf16=True was introduced in #3515 and transformers' TrainingArguments validation rejects bf16=True on such devices.

How

  • Add is_bf16_supported() to tests/testing_utils.py, mirroring the exact condition TrainingArguments accepts for bf16=True (is_torch_bf16_gpu_available() or is_torch_xla_available()), as a sibling to the existing is_ampere_or_newer().
  • Use bf16=is_bf16_supported() at the one opportunistic GRPO site (test_vlm_processor_vllm_colocate_mode, where bf16 is a memory optimization, not a flash-attention-2 requirement).
  • The flash-attention-2 sites that genuinely require bf16 (test_train_padding_free in SFT/DPO, test_vlm_training in GRPO) are already guarded by @skipif(not is_ampere_or_newer() and torch_device != "xpu"), so they are left unchanged.
  • The invariant FA2 members (sft_fa2, sft_fa2_padfree) require bf16 for the bf16-only FA2 kernels, so they now skip on a non-bf16 device instead of erroring.

This follows the approach suggested in the issue (a small is_bf16_supported() helper threaded through the affected tests).

Verification

  • On CPU, is_bf16_supported() returns False, matching the device condition TrainingArguments validates.
  • The invariant FA2 members previously errored on CPU (bf16=True); they now skip with a clear reason.
  • ruff check and ruff format --check pass on the changed files.

Resolves #3616


Note

Low Risk
Test-only changes with no production training logic; risk is limited to CI behavior on CPU or older GPUs.

Overview
Adds is_bf16_supported() in tests/testing_utils.py (aligned with transformers’ TrainingArguments checks via is_torch_bf16_gpu_available() and is_torch_xla_available()), so tests can detect CPU / pre-Ampere GPUs where bf16=True fails validation.

Invariant FA2 cases (sft_fa2, sft_fa2_padfree) still require bf16 for the kernels; test_invariant now skips when the config has bf16=True and the helper returns false, instead of erroring during CLI training.

In test_vlm_processor_vllm_colocate_mode, GRPOConfig bf16 is set to False (replacing True) because the test only exercises trainer setup, not train(), and True breaks on non-bf16 hardware.

Reviewed by Cursor Bugbot for commit 2a2d3b5. Bugbot is set up for automated code reviews on this repo. Configure here.

@bot-ci-comment

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR updates the test suite to avoid hard-coding bf16=True on machines that cannot support it (CPU and pre-Ampere CUDA GPUs), aligning test behavior with transformers.TrainingArguments bf16 validation so tests skip/adjust instead of failing during argument validation.

Changes:

  • Add is_bf16_supported() to tests/testing_utils.py based on is_torch_bf16_gpu_available() or is_torch_xla_available().
  • Use bf16=is_bf16_supported() in the opportunistic GRPO vLLM colocate test to avoid invalid bf16 settings on unsupported devices.
  • Skip invariant FA2 configs requiring bf16 when bf16 isn’t supported, preventing validation-time failures.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.

File Description
tests/testing_utils.py Introduces is_bf16_supported() helper to centralize bf16 capability checks for tests.
tests/test_grpo_trainer.py Switches a GRPO vLLM test to conditionally enable bf16 only when supported.
tests/invariant/test_invariant.py Skips FA2 invariant members on non-bf16-capable devices to avoid TrainingArguments validation errors.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

The default bf16=True (#3515) makes test cases raise "Your setup doesn't
support bf16/gpu" on a CPU or a pre-Ampere GPU (e.g. T4). Add an
is_bf16_supported() helper, and use bf16=is_bf16_supported() at the
opportunistic GRPO site. The flash-attention-2 sites, which genuinely
require bf16, are already guarded by is_ampere_or_newer; the invariant
FA2 members now skip on non-bf16 devices instead of erroring.

Resolves #3616
Comment thread tests/test_grpo_trainer.py Outdated
Comment on lines +3198 to +3199
bf16=True, # Use bfloat16 to reduce memory
bf16=is_bf16_supported(), # bfloat16 to reduce memory, when the device supports it

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wouldn't it be easier to just remove bf16=True?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 2a2d3b5 — went with explicit bf16=False rather than deleting the line, since bf16 defaults to None and _BaseConfig.__post_init__ resolves it to not fp16 = True, so a bare delete would keep it True. (args.bf16 is unused on this construction-only vLLM path anyway.)

…ity helper

Removing the line would resolve bf16 back to True (base_config default), so set it explicitly. args.bf16 is unused on this construction-only vLLM path.
@behroozazarkhalili

Copy link
Copy Markdown
Collaborator Author

Good call, you're right that is_bf16_supported() is overkill here. I pushed a simpler version.

One detail on why I went with bf16=False rather than just deleting the line: bf16 defaults to None and is resolved in _BaseConfig.__post_init__ as

self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16

fp16 is unset in this test, so dropping the line entirely resolves bf16 back to True (not False), which is the original value. bf16=False makes it explicit and device-safe instead.

It doesn't actually change what this test exercises either way: test_vlm_processor_vllm_colocate_mode is construction-only (it never calls train()), and the only place args.bf16 is read is the use_transformers_paged branch, which this use_vllm=True test doesn't take. The 4-bit weights are already loaded with bnb_4bit_compute_dtype=torch.bfloat16 regardless, so the flag has no effect on the colocate path. bf16=False just avoids the hard-coded-True-on-an-unsupported-device footgun the PR is about.

I kept is_bf16_supported() for tests/invariant/test_invariant.py, where it gates a real skip for the FA2/bf16 configs.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Latest default config of bf16=True breaks test cases when run on CPU.

3 participants