Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
28a5605
move logging utils
mikasenghaas Mar 7, 2026
602177b
update agent start cmd
mikasenghaas Mar 7, 2026
4b94f8a
adjust wording + remove task management prompt
mikasenghaas Mar 7, 2026
996362b
fix event loop blocking: offload serialization to thread, skip log st…
mikasenghaas Mar 7, 2026
6116061
fix: offload rollout serialization to thread to unblock event loop
mikasenghaas Mar 7, 2026
003744e
compact info logs
mikasenghaas Mar 8, 2026
46bb77f
ruff
mikasenghaas Mar 8, 2026
17be841
less verbose start log
mikasenghaas Mar 8, 2026
c73e3c8
use constant timeout
mikasenghaas Mar 8, 2026
69882bf
30s timeout
mikasenghaas Mar 8, 2026
f6469ba
bring back task management prompt
mikasenghaas Mar 8, 2026
c67d4b2
do not redundantly log model abort error + streaming error
mikasenghaas Mar 8, 2026
3dde6a9
minor logging improvements
mikasenghaas Mar 8, 2026
e5f439e
raise agent error
mikasenghaas Mar 9, 2026
db39779
raise sandbox error if agent bg job fails
mikasenghaas Mar 9, 2026
0f46882
Add EventLoopBlockingDetector to capture stack traces when event loop…
mikasenghaas Mar 9, 2026
c717430
Revert "Add EventLoopBlockingDetector to capture stack traces when ev…
mikasenghaas Mar 9, 2026
a6c5bed
add task sys prompt
mikasenghaas Mar 9, 2026
a41609e
handle bg job polling errors
mikasenghaas Mar 9, 2026
a88a216
30min default timeouts
mikasenghaas Mar 9, 2026
617ad93
pipe agent failure
mikasenghaas Mar 9, 2026
ac201f9
http prob for tunnel liveness
mikasenghaas Mar 9, 2026
ffa2f9c
ruff
mikasenghaas Mar 9, 2026
fab96b5
Merge branch 'main' into opencode-envs
mikasenghaas Mar 11, 2026
60a2194
dont re-raise agent error, but log error
mikasenghaas Mar 11, 2026
7542021
update hybrid math rubric constructor
mikasenghaas Mar 11, 2026
be2253b
allow offline difficulty filtering
mikasenghaas Mar 11, 2026
00c1a0f
fix default model
mikasenghaas Mar 11, 2026
c7857f4
collect agent logs
mikasenghaas Mar 11, 2026
b8c1791
use warning log
mikasenghaas Mar 11, 2026
7c9b59b
agent per-request timeout 30min->1h
mikasenghaas Mar 12, 2026
8347df4
remove http health probe again
mikasenghaas Mar 12, 2026
d729f1d
default timeouts to 1h
mikasenghaas Mar 12, 2026
67ea4bf
remove tunnel lock (not needed bc no concurrent health monitor)
mikasenghaas Mar 12, 2026
53c831b
simplify poll_job
mikasenghaas Mar 12, 2026
b149c96
use cls logger
mikasenghaas Mar 12, 2026
5dbf8a2
oops bring back tunnel lock
mikasenghaas Mar 12, 2026
cfb150e
add file upload helpers
mikasenghaas Mar 12, 2026
556b164
allow rubric cleanup
mikasenghaas Mar 12, 2026
bed4ed0
sandbox scoring in hybrid math rubric
mikasenghaas Mar 12, 2026
ee1cc99
do not upload bundle
mikasenghaas Mar 12, 2026
bb4b86c
align local and remote math verify
mikasenghaas Mar 12, 2026
e0a6713
add read_file
mikasenghaas Mar 12, 2026
7996d8b
remote math rubric
mikasenghaas Mar 12, 2026
724383d
remove script
mikasenghaas Mar 12, 2026
4f46a4e
ind retries for scoring
mikasenghaas Mar 12, 2026
f607f4a
catch vf.errors to populate state to allow retries
mikasenghaas Mar 12, 2026
302f0d0
scoring retry tests and state copy for retry isolation
mikasenghaas Mar 12, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions tests/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,116 @@ async def test_error_in_state_after_max_retries_exhausted(
assert "InfraError" == error_info["error"]


class RetryScoringRubric(Rubric):
"""Rubric with a reward function that fails first N times with configurable error type."""

def __init__(self, fail_count: int, error_type: type = vf.InfraError, **kwargs):
self.fail_count = fail_count
self.error_type = error_type
self.scoring_call_count = 0

def failing_reward(completion, answer, **kw):
self.scoring_call_count += 1
if self.scoring_call_count <= self.fail_count:
raise self.error_type(
f"Simulated scoring failure {self.scoring_call_count}/{self.fail_count}"
)
return 1.0

super().__init__(funcs=[failing_reward], **kwargs)


class TestScoringRetry:
"""Test cases for scoring retry functionality."""

@pytest.mark.asyncio
async def test_scoring_retry_after_retryable_error(self, mock_client, make_input):
"""Scoring retries on InfraError, succeeds after failures."""
dataset = Dataset.from_dict({"question": ["test"], "answer": ["test"]})
rubric = RetryScoringRubric(fail_count=2)
env = SimpleEnvironment(
dataset=dataset, parser=Parser(), rubric=rubric, score_rollouts=True
)

inputs = [make_input()]
outputs = await env.generate(
inputs, client=mock_client, model="test-model", max_scoring_retries=3
)

assert rubric.scoring_call_count == 3 # 2 failures + 1 success
assert outputs["outputs"][0].get("error") is None
assert outputs["outputs"][0]["reward"] == 1.0

@pytest.mark.asyncio
async def test_scoring_no_retry_after_non_retryable_error(
self, mock_client, make_input
):
"""Non-retryable error type is NOT retried during scoring."""
dataset = Dataset.from_dict({"question": ["test"], "answer": ["test"]})
rubric = RetryScoringRubric(fail_count=10, error_type=vf.ToolError)
env = SimpleEnvironment(
dataset=dataset, parser=Parser(), rubric=rubric, score_rollouts=True
)

inputs = [make_input()]
outputs = await env.generate(
inputs, client=mock_client, model="test-model", max_scoring_retries=3
)

assert rubric.scoring_call_count == 1 # No retries for non-retryable error
# ToolError is stored in state but not retried; reward falls back to 0.0
assert outputs["outputs"][0]["reward"] == 0.0

@pytest.mark.asyncio
async def test_scoring_error_after_retries_exhausted(
self, mock_client, make_input
):
"""Error persists after all scoring retries exhausted."""
dataset = Dataset.from_dict({"question": ["test"], "answer": ["test"]})
rubric = RetryScoringRubric(fail_count=10)
env = SimpleEnvironment(
dataset=dataset, parser=Parser(), rubric=rubric, score_rollouts=True
)

inputs = [make_input()]
outputs = await env.generate(
inputs, client=mock_client, model="test-model", max_scoring_retries=2
)

assert rubric.scoring_call_count == 3 # 1 initial + 2 retries
assert outputs["outputs"][0].get("error") is not None
assert outputs["outputs"][0]["error"]["error"] == "InfraError"

@pytest.mark.asyncio
async def test_independent_rollout_and_scoring_retries(
self, mock_client, make_input
):
"""Rollout and scoring use independent retry counts."""
dataset = Dataset.from_dict({"question": ["test"], "answer": ["test"]})
rubric = RetryScoringRubric(fail_count=1)
env = RetryCounterEnv(
fail_count=1,
dataset=dataset,
parser=Parser(),
rubric=rubric,
score_rollouts=True,
)

inputs = [make_input()]
outputs = await env.generate(
inputs,
client=mock_client,
model="test-model",
max_rollout_retries=2,
max_scoring_retries=2,
)

assert env.call_counts[0] == 2 # 1 failure + 1 success for rollout
assert rubric.scoring_call_count == 2 # 1 failure + 1 success for scoring
assert outputs["outputs"][0].get("error") is None
assert outputs["outputs"][0]["reward"] == 1.0


class TestEmptyModelResponseErrors:
"""Test cases for empty and invalid model response error handling."""

Expand Down
3 changes: 3 additions & 0 deletions tests/test_environment_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ async def run_group(
sampling_args,
max_retries,
state_columns,
**kwargs,
):
assert isinstance(client_config, ClientConfig)
self.client_urls_per_group.append(str(client_config.api_base_url))
Expand Down Expand Up @@ -424,6 +425,7 @@ async def run_group(
sampling_args,
max_retries,
state_columns,
**kwargs,
):
assert isinstance(client_config, ClientConfig)
self.client_url = str(client_config.api_base_url)
Expand Down Expand Up @@ -483,6 +485,7 @@ async def run_rollout(
sampling_args,
max_retries,
state_columns,
**kwargs,
):
assert isinstance(client_config, ClientConfig)
self.client_url = str(client_config.api_base_url)
Expand Down
10 changes: 9 additions & 1 deletion verifiers/envs/env_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,11 +278,15 @@ async def run_rollout( # type: ignore[override]
max_retries: int = 0,
state_columns: list[str] | None = None,
env_client: EnvClient | None = None,
max_rollout_retries: int | None = None,
max_scoring_retries: int | None = None,
) -> vf.RolloutOutput:
env = self.get_env_for_task(input["task"])
env_client = env_client or env.env_client or self.env_client
return await env.run_rollout(
input, client, model, sampling_args, max_retries, state_columns, env_client
input, client, model, sampling_args, max_retries, state_columns, env_client,
max_rollout_retries=max_rollout_retries,
max_scoring_retries=max_scoring_retries,
)

@final
Expand All @@ -295,6 +299,8 @@ async def run_group( # type: ignore[override]
max_retries: int = 0,
state_columns: list[str] | None = None,
env_client: EnvClient | None = None,
max_rollout_retries: int | None = None,
max_scoring_retries: int | None = None,
) -> list[vf.RolloutOutput]:
env = self.get_env_for_task(group_inputs[0]["task"])
env_client = env_client or env.env_client or self.env_client
Expand All @@ -306,6 +312,8 @@ async def run_group( # type: ignore[override]
max_retries,
state_columns,
env_client,
max_rollout_retries=max_rollout_retries,
max_scoring_retries=max_scoring_retries,
)

@final
Expand Down
55 changes: 49 additions & 6 deletions verifiers/envs/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,14 @@
_MESSAGE_TYPE_UNSET = object()


def _copy_state_for_scoring(state: "State") -> "State":
"""Shallow-copy a state, duplicating only the mutable objects that scoring writes to."""
copied = State(state)
copied["timing"] = dict(state.get("timing", {}))
copied["trajectory"] = [dict(t) for t in state.get("trajectory", [])]
return copied


class Environment(ABC):
"""
Base class for all environments.
Expand Down Expand Up @@ -710,9 +718,14 @@ async def run_rollout(
max_retries: int = 0,
state_columns: list[str] | None = None,
env_client: EnvClient | None = None,
max_rollout_retries: int | None = None,
max_scoring_retries: int | None = None,
) -> RolloutOutput:
"""Generate and, optionally, score a rollout."""

effective_rollout_retries = max_rollout_retries if max_rollout_retries is not None else max_retries
effective_scoring_retries = max_scoring_retries if max_scoring_retries is not None else max_retries

resolved_client_config: ClientConfig | None = None
if isinstance(client, ClientConfig):
resolved_client_config = resolve_client_config(client)
Expand All @@ -730,26 +743,33 @@ async def run_rollout(
sampling_args,
max_retries,
state_columns,
max_rollout_retries=effective_rollout_retries,
max_scoring_retries=effective_scoring_retries,
)

resolved_client = resolve_client(client)

async def run_rollout_attempt() -> State:
state = await self.rollout(
return await self.rollout(
input,
resolved_client,
model,
sampling_args,
)

state = await maybe_retry(run_rollout_attempt, max_retries=effective_rollout_retries)()
rollout_state = _copy_state_for_scoring(state)

async def run_scoring_attempt() -> State:
state = _copy_state_for_scoring(rollout_state)
if self.score_rollouts:
await self.rubric.score_rollout(state)
else:
await self.rubric.dummy_score_rollout(state)

await self.rubric.cleanup(state)
return state

state = await maybe_retry(run_rollout_attempt, max_retries=max_retries)()
state = await maybe_retry(run_scoring_attempt, max_retries=effective_scoring_retries)()
output = state_to_output(state, state_columns or [])
return output

Expand All @@ -763,10 +783,15 @@ async def run_group(
max_retries: int = 0,
state_columns: list[str] | None = None,
env_client: EnvClient | None = None,
max_rollout_retries: int | None = None,
max_scoring_retries: int | None = None,
**kwargs,
) -> list[RolloutOutput]:
"""Generate and, optionally, score one group."""

effective_rollout_retries = max_rollout_retries if max_rollout_retries is not None else max_retries
effective_scoring_retries = max_scoring_retries if max_scoring_retries is not None else max_retries

resolved_client_config: ClientConfig | None = None
if isinstance(client, ClientConfig):
resolved_client_config = resolve_client_config(client)
Expand All @@ -784,11 +809,13 @@ async def run_group(
sampling_args,
max_retries,
state_columns,
max_rollout_retries=effective_rollout_retries,
max_scoring_retries=effective_scoring_retries,
)

resolved_client = resolve_client(client)

async def run_group_attempt() -> list[State]:
async def run_group_rollout_attempt() -> list[State]:
rollout_tasks = [
self.rollout(
input,
Expand All @@ -798,15 +825,25 @@ async def run_group_attempt() -> list[State]:
)
for input in group_inputs
]
group_states = await asyncio.gather(*rollout_tasks)
return await asyncio.gather(*rollout_tasks)

group_states = await maybe_retry(run_group_rollout_attempt, max_retries=effective_rollout_retries)()
rollout_group_states = [_copy_state_for_scoring(s) for s in group_states]

async def run_group_scoring_attempt() -> list[State]:
group_states = [_copy_state_for_scoring(s) for s in rollout_group_states]

if self.score_rollouts:
await self.rubric.score_group(group_states)
else:
await self.rubric.dummy_score_group(group_states)

for state in group_states:
await self.rubric.cleanup(state)

return group_states

group_states = await maybe_retry(run_group_attempt, max_retries=max_retries)()
group_states = await maybe_retry(run_group_scoring_attempt, max_retries=effective_scoring_retries)()
outputs = [
state_to_output(state, state_columns or []) for state in group_states
]
Expand All @@ -826,6 +863,8 @@ async def generate(
hf_hub_dataset_name: str | None = None,
independent_scoring: bool = False,
max_retries: int = 0,
max_rollout_retries: int | None = None,
max_scoring_retries: int | None = None,
on_start: StartCallback | None = None,
on_progress: ProgressCallback | list[ProgressCallback] | None = None,
on_log: LogCallback | None = None,
Expand Down Expand Up @@ -1021,6 +1060,8 @@ def get_client_for_group() -> Client | ClientConfig:
sampling_args,
max_retries=max_retries,
state_columns=state_columns,
max_rollout_retries=max_rollout_retries,
max_scoring_retries=max_scoring_retries,
),
),
)
Expand All @@ -1047,6 +1088,8 @@ def get_client_for_group() -> Client | ClientConfig:
sampling_args,
max_retries=max_retries,
state_columns=state_columns,
max_rollout_retries=max_rollout_retries,
max_scoring_retries=max_scoring_retries,
),
),
)
Expand Down
Loading