diff --git a/packages/optimization/src/ldai_optimization/__init__.py b/packages/optimization/src/ldai_optimization/__init__.py index 87401b3..61c1030 100644 --- a/packages/optimization/src/ldai_optimization/__init__.py +++ b/packages/optimization/src/ldai_optimization/__init__.py @@ -3,6 +3,8 @@ This package will provide helpers to run selected tools against the LaunchDarkly API from SDK-based workflows. """ +from ldai.tracker import TokenUsage + from ldai_optimization.client import OptimizationClient from ldai_optimization.dataclasses import ( AIJudgeCallConfig, @@ -13,6 +15,7 @@ OptimizationJudge, OptimizationJudgeContext, OptimizationOptions, + OptimizationResponse, ToolDefinition, ) from ldai_optimization.ld_api_client import LDApiError @@ -31,5 +34,7 @@ 'OptimizationJudge', 'OptimizationJudgeContext', 'OptimizationOptions', + 'OptimizationResponse', + 'TokenUsage', 'ToolDefinition', ] diff --git a/packages/optimization/src/ldai_optimization/client.py b/packages/optimization/src/ldai_optimization/client.py index 8f0b287..6d74481 100644 --- a/packages/optimization/src/ldai_optimization/client.py +++ b/packages/optimization/src/ldai_optimization/client.py @@ -5,6 +5,7 @@ import logging import os import random +import time import uuid from typing import Any, Dict, List, Literal, Optional, Union @@ -23,6 +24,7 @@ OptimizationJudge, OptimizationJudgeContext, OptimizationOptions, + OptimizationResponse, ToolDefinition, ) from ldai_optimization.ld_api_client import ( @@ -37,12 +39,9 @@ ) from ldai_optimization.util import ( await_if_needed, - create_evaluation_tool, - create_variation_tool, extract_json_from_response, - handle_evaluation_tool_call, - handle_variation_tool_call, interpolate_variables, + restore_variable_placeholders, ) logger = logging.getLogger(__name__) @@ -62,6 +61,25 @@ def _strip_provider_prefix(model: str) -> str: return model.split(".", 1)[-1] +def _compute_validation_count(pool_size: int) -> int: + """Compute how many validation samples to run 
after a candidate passes in chaos mode. + + Scales with the size of the available input/variable pool so that larger + option sets receive proportionally more validation coverage, capped at 5. + The floor of 2 ensures at least a minimal cross-check even for small pools. + + :param pool_size: Total number of distinct choices in the sampling pool + (user_input_options count when provided, otherwise variable_choices count). + :return: Number of validation samples to run (between 2 and 5 inclusive). + """ + return min(5, max(2, pool_size // 4)) + + +# Maximum number of attempts for variation generation. Transient empty or +# unparseable responses from the LLM are retried up to this many times before +# the variation step is treated as a failure. +_MAX_VARIATION_RETRIES = 3 + # Maps SDK status strings to the API status/activity values expected by # agent_optimization_result records. Defined at module level to avoid # allocating the dict on every on_status_update invocation. @@ -70,6 +88,7 @@ def _strip_provider_prefix(model: str) -> str: "generating": {"status": "RUNNING", "activity": "GENERATING"}, "evaluating": {"status": "RUNNING", "activity": "EVALUATING"}, "generating variation": {"status": "RUNNING", "activity": "GENERATING_VARIATION"}, + "validating": {"status": "RUNNING", "activity": "VALIDATING"}, "turn completed": {"status": "RUNNING", "activity": "COMPLETED"}, "success": {"status": "PASSED", "activity": "COMPLETED"}, "failure": {"status": "FAILED", "activity": "COMPLETED"}, @@ -183,6 +202,7 @@ def _safe_status_update( "generating", "evaluating", "generating variation", + "validating", "turn completed", "success", "failure", @@ -299,39 +319,6 @@ def _parse_judge_response( ) return JudgeResult(score=0.0, rationale=None) - def _builtin_judge_tool_handlers(self) -> Dict[str, Any]: - """ - Build the dict of built-in tool name → handler passed to handle_judge_call. 
- - Each handler accepts the tool-call arguments dict produced by the LLM and - returns a JSON string so the caller can forward it back to the model or use - it directly as the judge response. - - :return: Mapping of built-in tool names to their handler callables - """ - return { - create_evaluation_tool().name: handle_evaluation_tool_call, - } - - def _builtin_agent_tool_handlers(self, is_variation: bool) -> Dict[str, Any]: - """ - Build the dict of built-in tool name → handler passed to handle_agent_call. - - For regular agent turns this is empty — the config only contains user-defined - tools from the LD flag. For variation-generation turns the variation structured - output tool is included so the caller can distinguish it from user tools and - route the LLM tool call back to the framework. - - :param is_variation: True when called for a variation-generation turn - :return: Mapping of built-in tool names to their handler callables - """ - if is_variation: - return { - create_variation_tool( - self._options.model_choices - ).name: handle_variation_tool_call, - } - return {} async def _call_judges( self, @@ -518,8 +505,7 @@ async def _evaluate_config_judge( if msg.role == "system": system_parts.append( msg.content - + " Use the structured output tool to format your response." - " You should always return a JSON object with a score and rationale." + + " Return your response as a JSON object with 'score' and 'rationale' fields." 
) elif msg.role == "user": user_parts.append(msg.content) @@ -574,14 +560,12 @@ async def _evaluate_config_judge( if agent_tools: tools = list(agent_tools) + tools - # Add structured output tool for score and rationale - tools.append(create_evaluation_tool()) - + tool_params = {"tools": [t.to_dict() for t in tools]} if tools else {} judge_call_config = AIJudgeCallConfig( key=judge_key, model=ModelConfig( name=model_name, - parameters={**model_params, "tools": [t.to_dict() for t in tools]}, + parameters={**model_params, **tool_params}, ), instructions=instructions, messages=updated_messages, @@ -592,10 +576,13 @@ async def _evaluate_config_judge( variables=variables or {}, ) + _judge_start = time.monotonic() result = self._options.handle_judge_call( - judge_key, judge_call_config, judge_ctx, self._builtin_judge_tool_handlers() + judge_key, judge_call_config, judge_ctx ) - judge_response_str = await await_if_needed(result) + judge_response: OptimizationResponse = await await_if_needed(result) + judge_duration_ms = (time.monotonic() - _judge_start) * 1000 + judge_response_str = judge_response.output logger.debug( "[Iteration %d] -> Judge response (%s): %s", @@ -606,13 +593,14 @@ async def _evaluate_config_judge( # Parse judge response — expect structured JSON output judge_identifier = optimization_judge.judge_key or judge_key - return self._parse_judge_response( + judge_result = self._parse_judge_response( judge_response_str, judge_key, judge_identifier, iteration, clamp_score=False, ) + return dataclasses.replace(judge_result, duration_ms=judge_duration_ms, usage=judge_response.usage) async def _evaluate_acceptance_judge( self, @@ -670,7 +658,7 @@ async def _evaluate_acceptance_judge( "A score of 0.0-0.3 means that it does not match well at all. 
" "You can return any value between 0.0 and 1.0.\n" "You should also provide a rationale for your score.\n" - "You should call the structured output tool to format your response.\n\n" + "Return your response as a JSON object with 'score' and 'rationale' fields.\n\n" 'Example: {"score": 0.8, "rationale": "The response matches the acceptance statement well."}' ) @@ -689,10 +677,8 @@ async def _evaluate_acceptance_judge( "Assume that previous feedback will have addressed bad tool call results from prior iterations." ) - # Prepend agent tools so the judge can invoke them for verification if needed - tools: List[ToolDefinition] = list(resolved_agent_tools) + [ - create_evaluation_tool() - ] + # Agent tools are passed through so the judge can invoke them for verification + tools: List[ToolDefinition] = list(resolved_agent_tools) judge_user_input = f"Here is the response to evaluate: {completion_response}" if expected_response is not None: @@ -702,11 +688,12 @@ async def _evaluate_acceptance_judge( "how closely it matches the expected response. Factor both into your score." 
) + tool_params = {"tools": [t.to_dict() for t in tools]} if tools else {} judge_call_config = AIJudgeCallConfig( key=judge_key, model=ModelConfig( name=self._options.judge_model, - parameters={"tools": [t.to_dict() for t in tools]}, + parameters=tool_params, ), instructions=instructions, messages=[ @@ -720,22 +707,26 @@ async def _evaluate_acceptance_judge( variables=resolved_variables, ) + _judge_start = time.monotonic() result = self._options.handle_judge_call( - judge_key, judge_call_config, judge_ctx, self._builtin_judge_tool_handlers() + judge_key, judge_call_config, judge_ctx ) - judge_response = await await_if_needed(result) + judge_response: OptimizationResponse = await await_if_needed(result) + judge_duration_ms = (time.monotonic() - _judge_start) * 1000 + judge_response_str = judge_response.output logger.debug( "[Iteration %d] -> Judge response (%s): %s", iteration, judge_key, - judge_response, + judge_response_str, ) # Parse judge response — expect structured JSON output with score and rationale - return self._parse_judge_response( - judge_response, judge_key, judge_key, iteration, clamp_score=True + judge_result = self._parse_judge_response( + judge_response_str, judge_key, judge_key, iteration, clamp_score=True ) + return dataclasses.replace(judge_result, duration_ms=judge_duration_ms, usage=judge_response.usage) async def _get_agent_config( self, agent_key: str, context: Context @@ -1060,6 +1051,18 @@ def _apply_new_variation_response( ) self._current_instructions = response_data["current_instructions"] + + # Post-process: replace any leaked variable values back to {{key}} form. + # This is a deterministic safety net for when the LLM ignores the prompt + # instructions and hardcodes a concrete value (e.g. "user-123") instead + # of the placeholder ("{{user_id}}"). 
+ self._current_instructions, placeholder_warnings = restore_variable_placeholders( + self._current_instructions, + self._options.variable_choices, + ) + for msg in placeholder_warnings: + logger.warning("[Iteration %d] -> %s", iteration, msg) + self._current_parameters = response_data["current_parameters"] # Update model — it should always be provided since it's required in the schema @@ -1159,15 +1162,12 @@ async def _generate_new_variation( flat_history = [prev_ctx.copy_without_history() for prev_ctx in self._history] # Create context for variation generation — low temperature for deterministic output. - # The variation tool is placed in current_parameters["tools"] so it surfaces through - # AIAgentConfig.model.parameters like any other tool, rather than as a separate field. variation_ctx = OptimizationContext( scores={}, completion_response="", current_instructions=instructions, current_parameters={ "temperature": 0.1, - "tools": [create_variation_tool(self._options.model_choices).to_dict()], }, current_variables=variables, current_model=self._current_model, @@ -1177,17 +1177,36 @@ async def _generate_new_variation( ) # Call handle_agent_call to generate new variation; expects a JSON string - # matching the structured output schema (current_instructions, current_parameters, model) - result = self._options.handle_agent_call( - self._agent_key, - self._build_agent_config_for_context(variation_ctx), - variation_ctx, - self._builtin_agent_tool_handlers(is_variation=True), - ) - response_str = await await_if_needed(result) + # matching the structured output schema (current_instructions, current_parameters, model). + # Retry up to _MAX_VARIATION_RETRIES times to handle transient empty or unparseable + # responses (e.g. when the agent SDK returns the LLM's post-tool-call empty text + # instead of the tool result). 
+ agent_config = self._build_agent_config_for_context(variation_ctx) + response_data = None + response_str = "" + for attempt in range(1, _MAX_VARIATION_RETRIES + 1): + result = self._options.handle_agent_call( + self._agent_key, + agent_config, + variation_ctx, + ) + variation_response: OptimizationResponse = await await_if_needed(result) + response_str = variation_response.output + try: + response_data = extract_json_from_response(response_str) + break + except ValueError: + if attempt == _MAX_VARIATION_RETRIES: + raise + logger.warning( + "[Iteration %d] -> Variation response empty or unparseable " + "(attempt %d/%d), retrying...", + iteration, + attempt, + _MAX_VARIATION_RETRIES, + ) - # Extract and update current state from the parsed response - response_data = extract_json_from_response(response_str) + assert response_data is not None # loop always raises or breaks with data return self._apply_new_variation_response( response_data, variation_ctx, response_str, iteration ) @@ -1296,6 +1315,7 @@ def _persist_and_forward( "generating", "evaluating", "generating variation", + "validating", "turn completed", "success", "failure", @@ -1320,6 +1340,28 @@ def _persist_and_forward( "scores": {k: v.to_json() for k, v in snapshot.scores.items()}, "user_input": snapshot.user_input, } + if snapshot.duration_ms is not None: + payload["generation_latency"] = snapshot.duration_ms + if snapshot.usage is not None: + payload["generation_tokens"] = { + "total": snapshot.usage.total, + "input": snapshot.usage.input, + "output": snapshot.usage.output, + } + eval_latencies = { + k: v.duration_ms + for k, v in snapshot.scores.items() + if v.duration_ms is not None + } + if eval_latencies: + payload["evaluation_latencies"] = eval_latencies + eval_tokens = { + k: {"total": v.usage.total, "input": v.usage.input, "output": v.usage.output} + for k, v in snapshot.scores.items() + if v.usage is not None + } + if eval_tokens: + payload["evaluation_tokens"] = eval_tokens 
api_client.post_agent_optimization_result(project_key, optimization_id, payload) if options.on_status_update: @@ -1412,13 +1454,15 @@ async def _execute_agent_turn( optimize_context.current_model, ) try: + _agent_start = time.monotonic() result = self._options.handle_agent_call( self._agent_key, self._build_agent_config_for_context(optimize_context), optimize_context, - self._builtin_agent_tool_handlers(is_variation=False), ) - completion_response = await await_if_needed(result) + agent_response: OptimizationResponse = await await_if_needed(result) + agent_duration_ms = (time.monotonic() - _agent_start) * 1000 + completion_response = agent_response.output logger.debug( "[Iteration %d] -> Agent response: %.300s%s", iteration, @@ -1448,6 +1492,8 @@ async def _execute_agent_turn( optimize_context, completion_response=completion_response, scores=scores, + duration_ms=agent_duration_ms, + usage=agent_response.usage, ) def _evaluate_response(self, optimize_context: OptimizationContext) -> bool: @@ -1527,6 +1573,149 @@ def _handle_failure( ) return optimize_context + async def _run_validation_phase( + self, + passing_context: OptimizationContext, + iteration: int, + ) -> "tuple[bool, OptimizationContext]": + """Run additional evaluations against distinct random samples to confirm a passing candidate. + + Mirrors the sampling logic of _run_optimization: each validation turn selects + a user_input from user_input_options (when provided) AND a variables dict from + variable_choices independently. The validation count and distinctness guarantee + are driven by whichever pool is larger — user_input_options when present, + otherwise variable_choices — ensuring validation turns use inputs the passing + turn did not. + + If all samples pass, the caller should proceed to _handle_success. If any + sample fails, the caller should treat the result as a normal failed attempt + and generate a new variation. 
+ + Validation turns are numbered sequentially in logs (iteration + 1, + 2, …) + for readability, but this numbering is internal only — the caller's iteration + counter is never advanced by this method so validation samples do not consume + the attempt budget. + + :param passing_context: The OptimizationContext from the turn that just passed. + :param iteration: The iteration number of the passing turn; used as the + base for validation log line numbering only. + :return: Tuple of (all_passed, last_context). + """ + options = self._options + + # Determine the primary axis of distinctness and the pool size. + # user_input_options drives the count when present; otherwise variable_choices does. + # In either case, both user_input and variables are selected per-sample just as + # they are in the main optimization loop. + if options.user_input_options: + primary_pool: List[str] = options.user_input_options + passing_input: Optional[str] = passing_context.user_input + remaining_inputs: List[str] = [ + inp for inp in primary_pool if inp != passing_input + ] + pool_size = len(primary_pool) + else: + var_pool: List[Dict[str, Any]] = options.variable_choices + passing_vars: Dict[str, Any] = passing_context.current_variables + remaining_vars: List[Dict[str, Any]] = [ + v for v in var_pool if v != passing_vars + ] + pool_size = len(var_pool) + + validation_count = _compute_validation_count(pool_size) + # Cap to the number of distinct remaining items, but never below 1. + # When the pool is exhausted (e.g. only one variable choice), sample + # with replacement from the full pool so at least one validation run + # always executes. 
+ if options.user_input_options: + available = len(remaining_inputs) + else: + available = len(remaining_vars) + + allow_repeats = available == 0 + if allow_repeats: + validation_count = 1 + else: + validation_count = min(validation_count, available) + + logger.info( + "[Iteration %d] -> Candidate passed — entering validation phase (%d sample(s)%s)", + iteration, + validation_count, + ", repeated draw" if allow_repeats else "", + ) + self._safe_status_update("validating", passing_context, iteration) + + # Sample primary items, falling back to the full pool when no distinct + # items remain so the minimum-1 floor is always satisfied. + if options.user_input_options: + source_inputs = primary_pool if allow_repeats else remaining_inputs + sampled_inputs: List[str] = random.sample(source_inputs, validation_count) + else: + source_vars = var_pool if allow_repeats else remaining_vars + sampled_vars: List[Dict[str, Any]] = random.sample(source_vars, validation_count) + + last_ctx = passing_context + for i in range(validation_count): + val_iter = iteration + i + 1 + if options.user_input_options: + user_input: Optional[str] = sampled_inputs[i] + variables: Dict[str, Any] = random.choice(options.variable_choices) + else: + user_input = None + variables = sampled_vars[i] + + logger.info( + "[Validation %d/%d] -> Running sample (iteration=%d)", + i + 1, + validation_count, + val_iter, + ) + + val_ctx = self._create_optimization_context( + iteration=val_iter, + user_input=user_input, + variables=variables, + ) + self._safe_status_update("generating", val_ctx, val_iter) + val_ctx = await self._execute_agent_turn(val_ctx, val_iter) + + if options.on_turn is not None: + try: + sample_passed = options.on_turn(val_ctx) + except Exception: + logger.exception( + "[Validation %d/%d] -> on_turn evaluation failed", i + 1, validation_count + ) + sample_passed = False + else: + sample_passed = self._evaluate_response(val_ctx) + + last_ctx = val_ctx + + if not sample_passed: + logger.info( 
+ "[Validation %d/%d] -> FAILED (iteration=%d) — candidate rejected", + i + 1, + validation_count, + val_iter, + ) + return False, last_ctx + + logger.debug( + "[Validation %d/%d] -> passed (iteration=%d)", + i + 1, + validation_count, + val_iter, + ) + + logger.info( + "[Iteration %d] -> All %d validation sample(s) passed — candidate confirmed", + iteration, + validation_count, + ) + return True, last_ctx + async def _run_optimization( self, agent_config: AIAgentConfig, options: OptimizationOptions ) -> Any: @@ -1594,62 +1783,75 @@ async def _run_optimization( ) on_turn_result = False - if on_turn_result: + initial_passed = on_turn_result + if initial_passed: logger.info( "[Iteration %d] -> on_turn returned True — turn passed", iteration, ) - return self._handle_success(optimize_context, iteration) + else: + # Auto-path: judge scores determine pass/fail via _evaluate_response + initial_passed = self._evaluate_response(optimize_context) + if initial_passed: + logger.info( + "[Iteration %d] -> All judges passed — turn succeeded", + iteration, + ) + if initial_passed: + all_valid, last_ctx = await self._run_validation_phase( + optimize_context, iteration + ) + if all_valid: + return self._handle_success(last_ctx, iteration) + # Validation failed — treat as a normal failed attempt logger.info( - "[Iteration %d] -> on_turn returned False — turn failed (attempt %d/%d)", + "[Iteration %d] -> Validation failed — generating new variation (attempt %d/%d)", iteration, iteration, self._options.max_attempts, ) if iteration >= self._options.max_attempts: - return self._handle_failure(optimize_context, iteration) - self._history.append(optimize_context) + return self._handle_failure(last_ctx, iteration) + self._history.append(last_ctx) try: await self._generate_new_variation( - iteration, optimize_context.current_variables + iteration, last_ctx.current_variables ) except Exception: logger.exception( "[Iteration %d] -> variation generation failed", iteration ) - return 
self._handle_failure(optimize_context, iteration) - self._safe_status_update("turn completed", optimize_context, iteration) + return self._handle_failure(last_ctx, iteration) + self._safe_status_update("turn completed", last_ctx, iteration) continue + + # Initial turn failed + if self._options.on_turn is not None: + logger.info( + "[Iteration %d] -> on_turn returned False — turn failed (attempt %d/%d)", + iteration, + iteration, + self._options.max_attempts, + ) else: - # Auto-path: judge scores determine pass/fail via _evaluate_response - passes = self._evaluate_response(optimize_context) - if passes: - logger.info( - "[Iteration %d] -> All judges passed — turn succeeded", - iteration, - ) - return self._handle_success(optimize_context, iteration) - else: - logger.info( - "[Iteration %d] -> One or more judges failed (attempt %d/%d) — generating new variation", - iteration, - iteration, - self._options.max_attempts, - ) - if iteration >= self._options.max_attempts: - return self._handle_failure(optimize_context, iteration) - self._history.append(optimize_context) - try: - await self._generate_new_variation( - iteration, optimize_context.current_variables - ) - except Exception: - logger.exception( - "[Iteration %d] -> variation generation failed", iteration - ) - return self._handle_failure(optimize_context, iteration) - self._safe_status_update( - "turn completed", optimize_context, iteration - ) - continue + logger.info( + "[Iteration %d] -> One or more judges failed (attempt %d/%d) — generating new variation", + iteration, + iteration, + self._options.max_attempts, + ) + if iteration >= self._options.max_attempts: + return self._handle_failure(optimize_context, iteration) + self._history.append(optimize_context) + try: + await self._generate_new_variation( + iteration, optimize_context.current_variables + ) + except Exception: + logger.exception( + "[Iteration %d] -> variation generation failed", iteration + ) + return self._handle_failure(optimize_context, 
iteration) + self._safe_status_update("turn completed", optimize_context, iteration) + continue diff --git a/packages/optimization/src/ldai_optimization/dataclasses.py b/packages/optimization/src/ldai_optimization/dataclasses.py index fdca939..02b6a7d 100644 --- a/packages/optimization/src/ldai_optimization/dataclasses.py +++ b/packages/optimization/src/ldai_optimization/dataclasses.py @@ -17,15 +17,31 @@ from ldai import AIAgentConfig from ldai.models import LDMessage, ModelConfig +from ldai.tracker import TokenUsage from ldclient import Context +@dataclass +class OptimizationResponse: + """The return value for both ``handle_agent_call`` and ``handle_judge_call`` callbacks. + + :param output: The text output produced by the LLM. + :param usage: Optional token usage for this call. Set fields to 0 or omit entirely + if token tracking is not available for the framework being used. + """ + + output: str + usage: Optional[TokenUsage] = None + + @dataclass class JudgeResult: """Result from a judge evaluation.""" score: float rationale: Optional[str] = None + duration_ms: Optional[float] = None + usage: Optional[TokenUsage] = None def to_json(self) -> Dict[str, Any]: """ @@ -33,10 +49,18 @@ def to_json(self) -> Dict[str, Any]: :return: Dictionary representation of the judge result that can be serialized with json.dumps() """ - return { + result: Dict[str, Any] = { "score": self.score, "rationale": self.rationale, + "duration_ms": self.duration_ms, } + if self.usage is not None: + result["usage"] = { + "total": self.usage.total, + "input": self.usage.input, + "output": self.usage.output, + } + return result @dataclass @@ -152,6 +176,8 @@ class OptimizationContext: default_factory=list ) # previous context items iteration: int = 0 # current iteration number + duration_ms: Optional[float] = None # wall-clock time for the agent call in milliseconds + usage: Optional[TokenUsage] = None # token usage reported by the agent for this iteration def copy_without_history(self) -> 
OptimizationContext: """ @@ -169,6 +195,8 @@ def copy_without_history(self) -> OptimizationContext: user_input=self.user_input, history=(), # Empty history to keep it flat iteration=self.iteration, + duration_ms=self.duration_ms, + usage=self.usage, ) def to_json(self) -> Dict[str, Any]: @@ -183,7 +211,7 @@ def to_json(self) -> Dict[str, Any]: history_list = [ctx.to_json() for ctx in self.history] - return { + result: Dict[str, Any] = { "scores": scores_dict, "completion_response": self.completion_response, "current_instructions": self.current_instructions, @@ -193,7 +221,15 @@ def to_json(self) -> Dict[str, Any]: "current_variables": self.current_variables, "history": history_list, "iteration": self.iteration, + "duration_ms": self.duration_ms, } + if self.usage is not None: + result["usage"] = { + "total": self.usage.total, + "input": self.usage.input, + "output": self.usage.output, + } + return result @dataclass @@ -209,12 +245,12 @@ class OptimizationJudgeContext: # Placed here so all referenced types (OptimizationContext, AIJudgeCallConfig, # OptimizationJudgeContext) are already defined above. 
HandleAgentCall = Union[ - Callable[[str, AIAgentConfig, OptimizationContext, Dict[str, Callable[..., Any]]], str], - Callable[[str, AIAgentConfig, OptimizationContext, Dict[str, Callable[..., Any]]], Awaitable[str]], + Callable[[str, AIAgentConfig, OptimizationContext], OptimizationResponse], + Callable[[str, AIAgentConfig, OptimizationContext], Awaitable[OptimizationResponse]], ] HandleJudgeCall = Union[ - Callable[[str, AIJudgeCallConfig, OptimizationJudgeContext, Dict[str, Callable[..., Any]]], str], - Callable[[str, AIJudgeCallConfig, OptimizationJudgeContext, Dict[str, Callable[..., Any]]], Awaitable[str]], + Callable[[str, AIJudgeCallConfig, OptimizationJudgeContext], OptimizationResponse], + Callable[[str, AIJudgeCallConfig, OptimizationJudgeContext], Awaitable[OptimizationResponse]], ] _StatusLiteral = Literal[ @@ -222,6 +258,7 @@ class OptimizationJudgeContext: "generating", "evaluating", "generating variation", + "validating", "turn completed", "success", "failure", diff --git a/packages/optimization/src/ldai_optimization/ld_api_client.py b/packages/optimization/src/ldai_optimization/ld_api_client.py index 8a457cc..34f5921 100644 --- a/packages/optimization/src/ldai_optimization/ld_api_client.py +++ b/packages/optimization/src/ldai_optimization/ld_api_client.py @@ -99,9 +99,7 @@ class OptimizationResultPayload(_OptimizationResultPayloadRequired, total=False) """Typed payload for a single agent_optimization_result POST request. Required fields are always sent. Optional fields are omitted when not - available. Fields that require separate tracking instrumentation - (variation, generation_tokens, evaluation_tokens, generation_latency, - evaluation_latencies) are deferred. + available. created_variation_key is only present on the final result record of a successful run, populated once a winning variation is committed to LD. 
@@ -109,6 +107,10 @@ class OptimizationResultPayload(_OptimizationResultPayloadRequired, total=False) user_input: Optional[str] created_variation_key: str + generation_latency: float + generation_tokens: Dict[str, int] + evaluation_latencies: Dict[str, float] + evaluation_tokens: Dict[str, Dict[str, int]] # --------------------------------------------------------------------------- diff --git a/packages/optimization/src/ldai_optimization/prompts.py b/packages/optimization/src/ldai_optimization/prompts.py index 556b661..c8631c5 100644 --- a/packages/optimization/src/ldai_optimization/prompts.py +++ b/packages/optimization/src/ldai_optimization/prompts.py @@ -108,6 +108,7 @@ def build_new_variation_prompt( history, current_model, current_instructions, current_parameters ), variation_prompt_feedback(history, judges), + variation_prompt_overfit_warning(history), variation_prompt_improvement_instructions( history, model_choices, variable_choices, initial_instructions ), @@ -133,6 +134,10 @@ def variation_prompt_preamble() -> str: "If the original instructions were to use a placeholder like {{id}}, " "you should keep the placeholder in the new instructions, not replace it with the actual value. " "This is the case for all parameterized values (all parameters should appear in each new variation).", + "IMPORTANT: placeholder names are fixed identifiers (e.g. {{user_id}}, {{trip_purpose}}) — " + "never substitute the runtime value of a variable in place of its name. " + "For example, if the variable key is 'user_id' and its current value is 'user-125', " + "the placeholder MUST be written as {{user_id}}, NOT {{user-125}}.", "Pay particular attention to the instructions regarding tools and the rules for variables.", ] ) @@ -260,6 +265,55 @@ def variation_prompt_feedback( return "\n".join(lines) +def variation_prompt_overfit_warning(history: List[OptimizationContext]) -> str: + """ + Overfitting warning section of the variation prompt. 
+ + Combines a general reminder to write generalizable instructions with + specific values from the most recent iteration so the LLM knows exactly + what concrete values to avoid embedding literally. Returns an empty string + when there is no history (first attempt, no feedback to overfit to). + + :param history: All previous OptimizationContexts, oldest first. + :return: Overfitting warning block, or empty string if history is empty. + """ + if not history: + return "" + + recent = history[-1] + + lines = [ + "## *** OVERFITTING WARNING ***", + "Do NOT hardcode specific values from the evaluation feedback into the instructions.", + "The configuration must generalise to all possible inputs, not just the ones seen so far.", + "Write instructions that treat the values below as examples of a broader class of inputs,", + "not as literals to match.", + "", + "The following specific values appeared in the most recent iteration " + "— do not embed them literally:", + ] + + if recent.user_input: + lines.append(f'- User input: "{recent.user_input}"') + + if recent.current_variables: + for k, v in recent.current_variables.items(): + lines.append(f' - placeholder {{{{{k}}}}}, current value: "{v}"') + lines.append( + " (These are the placeholder NAMES mapped to their current VALUES" + " — never use a value as a placeholder name)" + ) + + lines += [ + "", + "If you find yourself writing instructions that only work for the exact values above,", + "step back and generalise: what rule, pattern, or intent do those values represent?", + "Write instructions that satisfy that rule for any valid input.", + ] + + return "\n".join(lines) + + def variation_prompt_improvement_instructions( history: List[OptimizationContext], model_choices: List[str], @@ -284,40 +338,51 @@ def variation_prompt_improvement_instructions( ] ) - # Collect unique variable keys across all variable_choices entries - variable_keys: set = set() + # Build a per-variable table: key → sorted list of unique example 
values + # collected across all variable_choices entries. + examples: Dict[str, List[str]] = {} for choice in variable_choices: - variable_keys.update(choice.keys()) - placeholder_list = ", ".join(f"{{{{{k}}}}}" for k in sorted(variable_keys)) + for k, v in choice.items(): + examples.setdefault(k, []) + sv = str(v) + if sv not in examples[k]: + examples[k].append(sv) + + table_lines = [ + "## Prompt Variables:", + "These are the ONLY valid placeholder names. " + "Use them exactly as shown (case-sensitive) with {{...}} syntax:", + "", + ] + for k in sorted(examples.keys()): + vals = ", ".join(f'"{v}"' for v in examples[k]) + table_lines.append(f" - {{{{{k}}}}} (example values: {vals})") + + # Build concrete bad/good counterexamples using the actual keys and values + # so the LLM cannot mistake a runtime value for a placeholder name. + first_key = sorted(examples.keys())[0] if examples else "variable_name" + first_val = examples[first_key][0] if examples.get(first_key) else "some-value" + table_lines += [ + "", + "IMPORTANT: The names above are the KEYS — they are the placeholder names.", + "The values listed are only runtime examples that will be substituted at call time.", + "NEVER use a runtime value as a placeholder name.", + f'BAD: "...{{{{...{first_val}...}}}}..." ' + f'— "{first_val}" is a runtime value, not a placeholder name', + f'GOOD: "...{{{{{first_key}}}}}..." 
' + f'— "{first_key}" is the correct placeholder name', + ] variable_instructions = "\n".join( - [ - "## Prompt Variables:", - "These variables are substituted into the instructions at call time using {{variable_name}} syntax.", - "Rules:", - "- If the {{variable_name}} placeholder is not present in the current instructions, " - "you should include it where logically appropriate.", + table_lines + + [ + "", + "If a placeholder is not present in the current instructions, " + "include it where logically appropriate.", "Here are the original instructions so that you can see how the " "placeholders are used and which are available:", "\nSTART:" "\n" + initial_instructions + "\n", "\nEND OF ORIGINAL INSTRUCTIONS\n", - "The following prompt variables are available and are the only " - f"variables that should be used: {placeholder_list}", - "Here is an example of a good response if an {{id}} placeholder is available: " - "'Select records matching id {{id}}'", - "Here is an example of a bad response if an {{id}} placeholder is available: " - "'Select records matching id 1232'", - "Here is an example of a good response if a {{resource_id}} and {{resource_type}} " - "placeholder are available: " - "'Select records matching id {{resource_id}} and type {{resource_type}}'", - "Here is an example of a bad response if a {{resource_id}} and {{resource_type}} " - "placeholder are available: " - "'Select records matching id 1232 and type {{resource_type}}'", - "Here is another example of a bad response if a {{resource_id}} and {{resource_type}} " - "placeholder are available: " - "'Select records matching id {{resource_id}} and type {{resource-123}}'", - "The above example is incorrect because the resource-123 is not a valid variable name.", - "To fix the above example, you would instead use {{resource_type}} and {{resource_id}}", ] ) @@ -362,9 +427,6 @@ def variation_prompt_improvement_instructions( " }", "}", "", - "Always call the return_improved_configuration tool to format the 
response.", - "Return the response as-is from the return_improved_configuration tool,", - "do not modify it in any way.", ] ) diff --git a/packages/optimization/src/ldai_optimization/util.py b/packages/optimization/src/ldai_optimization/util.py index 2882c87..0f901d5 100644 --- a/packages/optimization/src/ldai_optimization/util.py +++ b/packages/optimization/src/ldai_optimization/util.py @@ -4,7 +4,7 @@ import json import logging import re -from typing import Any, Awaitable, Dict, List, Optional, Union +from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union from ldai_optimization.dataclasses import ToolDefinition @@ -68,7 +68,92 @@ def replace(match: re.Match) -> str: key = match.group(1).strip() return str(variables[key]) if key in variables else match.group(0) - return re.sub(r"\{\{(\w+)\}\}", replace, text) + return re.sub(r"\{\{([\w-]+)\}\}", replace, text) + + +def restore_variable_placeholders( + text: str, + variable_choices: List[Dict[str, Any]], + min_value_length: int = 3, +) -> Tuple[str, List[str]]: + """ + Scan ``text`` for leaked variable values and restore them to ``{{key}}`` form. + + This is the deterministic inverse of :func:`interpolate_variables`. It acts + as a post-processing safety net after variation generation: when the LLM + hardcodes a concrete variable value (e.g. ``user-123``) instead of writing + the placeholder (``{{user_id}}``), this function replaces the value back so + subsequent iterations receive correctly templated instructions. + + Values are matched with boundary guards so that a value like ``user-123`` + inside a longer token like ``user-1234`` is not substituted. Multi-line + values are handled identically to single-line ones — ``re.escape`` produces + a literal pattern and the lookbehind/lookahead only inspect the character + immediately adjacent to the match boundary. + + Values shorter than ``min_value_length`` characters are skipped because + short strings (e.g. 
``"en"``, ``"US"``) are too likely to appear + coincidentally in unrelated prose. + + :param text: The generated instruction string to clean. + :param variable_choices: All possible variable dicts, used to build the + reverse value→key map. When the same value appears under multiple keys + the first key encountered wins. + :param min_value_length: Minimum character length a value must have before + it is considered for replacement. Defaults to 3. + :return: A tuple of ``(cleaned_text, warnings)`` where ``warnings`` is a + list of human-readable strings describing each replacement made. + """ + # Build reverse map: string(value) → key. Longest values first so that + # a longer value like "user-123-admin" is replaced before the shorter + # "user-123" substring, preventing partial-match corruption. + value_to_key: Dict[str, str] = {} + for choice in variable_choices: + for key, value in choice.items(): + str_value = str(value) + if str_value not in value_to_key: + value_to_key[str_value] = key + + sorted_entries = sorted(value_to_key.items(), key=lambda kv: len(kv[0]), reverse=True) + + warnings: List[str] = [] + for value, key in sorted_entries: + if len(value) < min_value_length: + continue + placeholder = f"{{{{{key}}}}}" + # Skip if the placeholder is already present — nothing to fix. + if placeholder in text and value not in text: + continue + + total_count = 0 + + # Pass 1: replace {{value}} forms — the LLM used the runtime value as + # if it were a placeholder key (e.g. {{user-125}} instead of {{user_id}}). + # This must run before the boundary-guarded pass so that the bare value + # inside the braces is consumed here rather than matched by pass 2, + # which would otherwise leave the surrounding braces and produce + # {{{{user_id}}}}. 
+ brace_pattern = r'\{\{' + re.escape(value) + r'\}\}' + new_text, brace_count = re.subn(brace_pattern, placeholder, text, flags=re.DOTALL) + if brace_count: + text = new_text + total_count += brace_count + + # Pass 2: replace bare value occurrences with a boundary guard so that + # "user-123" inside "user-1234" is not substituted. + pattern = r'(? OptimizationOptions: if handle_agent_call is None: - handle_agent_call = AsyncMock(return_value="The capital of France is Paris.") + handle_agent_call = AsyncMock(return_value=OptimizationResponse(output="The capital of France is Paris.")) if handle_judge_call is None: - handle_judge_call = AsyncMock(return_value=JUDGE_PASS_RESPONSE) + handle_judge_call = AsyncMock(return_value=OptimizationResponse(output=JUDGE_PASS_RESPONSE)) if judges is None: judges = { "accuracy": OptimizationJudge( @@ -244,7 +248,7 @@ def test_passes_at_exact_threshold(self): assert self.client._evaluate_response(ctx) is True def test_no_judges_always_passes(self): - options = _make_options(judges=None, handle_agent_call=AsyncMock(return_value="x")) + options = _make_options(judges=None, handle_agent_call=AsyncMock(return_value=OptimizationResponse(output="x"))) # Need on_turn to satisfy validation — inject directly options_with_on_turn = OptimizationOptions( context_choices=[LD_CONTEXT], @@ -252,8 +256,8 @@ def test_no_judges_always_passes(self): model_choices=["gpt-4o"], judge_model="gpt-4o", variable_choices=[{}], - handle_agent_call=AsyncMock(return_value="x"), - handle_judge_call=AsyncMock(return_value=JUDGE_PASS_RESPONSE), + handle_agent_call=AsyncMock(return_value=OptimizationResponse(output="x")), + handle_judge_call=AsyncMock(return_value=OptimizationResponse(output=JUDGE_PASS_RESPONSE)), judges={"j": OptimizationJudge(threshold=1.0, acceptance_statement="x")}, on_turn=lambda ctx: True, ) @@ -292,48 +296,6 @@ def test_multiple_judges_all_passing(self): assert self.client._evaluate_response(ctx) is True -# 
--------------------------------------------------------------------------- -# _builtin_judge_tool_handlers / _builtin_agent_tool_handlers -# --------------------------------------------------------------------------- - - -class TestBuiltinToolHandlers: - def setup_method(self): - self.client = _make_client() - self.client._options = _make_options() - - def test_judge_handlers_contains_evaluation_tool(self): - handlers = self.client._builtin_judge_tool_handlers() - assert create_evaluation_tool().name in handlers - - def test_judge_handler_returns_json(self): - handlers = self.client._builtin_judge_tool_handlers() - result = handlers[create_evaluation_tool().name](score=0.7, rationale="ok") - data = json.loads(result) - assert data["score"] == 0.7 - - def test_agent_handlers_empty_for_regular_turn(self): - handlers = self.client._builtin_agent_tool_handlers(is_variation=False) - assert handlers == {} - - def test_agent_handlers_contains_variation_tool_for_variation_turn(self): - handlers = self.client._builtin_agent_tool_handlers(is_variation=True) - expected_name = create_variation_tool(self.client._options.model_choices).name - assert expected_name in handlers - - def test_variation_handler_returns_valid_json(self): - handlers = self.client._builtin_agent_tool_handlers(is_variation=True) - name = create_variation_tool(self.client._options.model_choices).name - result = handlers[name]( - current_instructions="New instructions.", - current_parameters={"temperature": 0.3}, - model="gpt-4o", - ) - data = json.loads(result) - assert data["current_instructions"] == "New instructions." 
- assert data["model"] == "gpt-4o" - - # --------------------------------------------------------------------------- # _evaluate_acceptance_judge # --------------------------------------------------------------------------- @@ -346,7 +308,7 @@ def setup_method(self): self.client._agent_key = "test-agent" self.client._agent_config = agent_config self.client._initialize_class_members_from_config(agent_config) - self.handle_judge_call = AsyncMock(return_value=JUDGE_PASS_RESPONSE) + self.handle_judge_call = AsyncMock(return_value=OptimizationResponse(output=JUDGE_PASS_RESPONSE)) self.client._options = _make_options(handle_judge_call=self.handle_judge_call) async def test_returns_parsed_score_and_rationale(self): @@ -377,11 +339,10 @@ async def test_handle_judge_call_receives_correct_key_and_config(self): user_input="What time is it?", ) call_args = self.handle_judge_call.call_args - key, config, ctx, handlers = call_args.args + key, config, ctx = call_args.args assert key == "relevance" assert isinstance(config, AIJudgeCallConfig) assert isinstance(ctx, OptimizationJudgeContext) - assert create_evaluation_tool().name in handlers async def test_messages_has_system_and_user_turns(self): judge = OptimizationJudge( @@ -395,7 +356,7 @@ async def test_messages_has_system_and_user_turns(self): reasoning_history="", user_input="What colour is the sky?", ) - _, config, _, _ = self.handle_judge_call.call_args.args + _, config, _ = self.handle_judge_call.call_args.args roles = [m.role for m in config.messages] assert roles == ["system", "user"] @@ -411,7 +372,7 @@ async def test_messages_system_content_matches_instructions(self): reasoning_history="", user_input="Is Paris in France?", ) - _, config, _, _ = self.handle_judge_call.call_args.args + _, config, _ = self.handle_judge_call.call_args.args system_msg = next(m for m in config.messages if m.role == "system") assert system_msg.content == config.instructions @@ -427,7 +388,7 @@ async def 
test_messages_user_content_matches_context_user_input(self): reasoning_history="", user_input="Capital of France?", ) - _, config, ctx, _ = self.handle_judge_call.call_args.args + _, config, ctx = self.handle_judge_call.call_args.args user_msg = next(m for m in config.messages if m.role == "user") assert user_msg.content == ctx.user_input @@ -443,10 +404,11 @@ async def test_acceptance_statement_in_instructions(self): user_input="Tell me about Paris.", ) call_args = self.handle_judge_call.call_args - _, config, _, _ = call_args.args + _, config, _ = call_args.args assert statement in config.instructions - async def test_evaluation_tool_in_config_parameters(self): + async def test_no_structured_output_tool_in_judge_config(self): + """Structured output tool must not be injected — judges return plain JSON.""" judge = OptimizationJudge(threshold=0.8, acceptance_statement="Be brief.") await self.client._evaluate_acceptance_judge( judge_key="brevity", @@ -457,12 +419,11 @@ async def test_evaluation_tool_in_config_parameters(self): user_input="Is Paris in France?", ) call_args = self.handle_judge_call.call_args - _, config, _, _ = call_args.args + _, config, _ = call_args.args tools = config.model.get_parameter("tools") or [] - tool_names = [t["name"] for t in tools] - assert create_evaluation_tool().name in tool_names + assert tools == [] - async def test_agent_tools_prepended_to_config_tools(self): + async def test_agent_tools_included_in_config_tools(self): agent_tool = ToolDefinition( name="lookup", description="Lookup data", input_schema={} ) @@ -477,11 +438,10 @@ async def test_agent_tools_prepended_to_config_tools(self): agent_tools=[agent_tool], ) call_args = self.handle_judge_call.call_args - _, config, _, _ = call_args.args + _, config, _ = call_args.args tools = config.model.get_parameter("tools") or [] tool_names = [t["name"] for t in tools] - assert "lookup" in tool_names - assert tool_names.index("lookup") < tool_names.index(create_evaluation_tool().name) + 
assert tool_names == ["lookup"] async def test_variables_in_context(self): judge = OptimizationJudge(threshold=0.8, acceptance_statement="Be accurate.") @@ -496,7 +456,7 @@ async def test_variables_in_context(self): variables=variables, ) call_args = self.handle_judge_call.call_args - _, _, ctx, _ = call_args.args + _, _, ctx = call_args.args assert ctx.variables == variables async def test_returns_zero_score_on_missing_acceptance_statement(self): @@ -513,7 +473,7 @@ async def test_returns_zero_score_on_missing_acceptance_statement(self): self.handle_judge_call.assert_not_called() async def test_returns_zero_score_on_parse_failure(self): - self.handle_judge_call.return_value = "not json at all" + self.handle_judge_call.return_value = OptimizationResponse(output="not json at all") judge = OptimizationJudge(threshold=0.8, acceptance_statement="Be clear.") result = await self.client._evaluate_acceptance_judge( judge_key="clarity", @@ -539,7 +499,7 @@ def setup_method(self): self.client._agent_key = "test-agent" self.client._agent_config = agent_config self.client._initialize_class_members_from_config(agent_config) - self.handle_judge_call = AsyncMock(return_value=JUDGE_PASS_RESPONSE) + self.handle_judge_call = AsyncMock(return_value=OptimizationResponse(output=JUDGE_PASS_RESPONSE)) self.client._options = _make_options(handle_judge_call=self.handle_judge_call) def _make_judge_config(self, enabled: bool = True) -> AIJudgeConfig: @@ -565,7 +525,7 @@ async def test_calls_handle_judge_call_with_correct_config_type(self): user_input="What is X?", ) call_args = self.handle_judge_call.call_args - key, config, ctx, handlers = call_args.args + key, config, ctx = call_args.args assert key == "quality" assert isinstance(config, AIJudgeCallConfig) assert "You are an evaluator." 
in config.instructions @@ -582,7 +542,7 @@ async def test_messages_has_system_and_user_turns(self): reasoning_history="", user_input="What is X?", ) - _, config, _, _ = self.handle_judge_call.call_args.args + _, config, _ = self.handle_judge_call.call_args.args roles = [m.role for m in config.messages] assert roles == ["system", "user"] @@ -597,7 +557,7 @@ async def test_messages_system_content_matches_instructions(self): reasoning_history="", user_input="What is X?", ) - _, config, _, _ = self.handle_judge_call.call_args.args + _, config, _ = self.handle_judge_call.call_args.args system_msg = next(m for m in config.messages if m.role == "system") assert system_msg.content == config.instructions @@ -612,7 +572,7 @@ async def test_messages_user_content_matches_context_user_input(self): reasoning_history="", user_input="What is X?", ) - _, config, ctx, _ = self.handle_judge_call.call_args.args + _, config, ctx = self.handle_judge_call.call_args.args user_msg = next(m for m in config.messages if m.role == "user") assert user_msg.content == ctx.user_input @@ -627,7 +587,7 @@ async def test_messages_user_content_contains_ld_user_message(self): reasoning_history="", user_input="What is X?", ) - _, config, _, _ = self.handle_judge_call.call_args.args + _, config, _ = self.handle_judge_call.call_args.args user_msg = next(m for m in config.messages if m.role == "user") assert "Evaluate this response." 
in user_msg.content @@ -684,7 +644,7 @@ async def test_template_variables_merged_into_judge_config_call(self): assert "message_history" in passed_vars assert "response_to_evaluate" in passed_vars - async def test_agent_tools_prepended_before_evaluation_tool(self): + async def test_agent_tools_included_without_evaluation_tool(self): self.mock_ldai.judge_config.return_value = self._make_judge_config() agent_tool = ToolDefinition(name="search", description="Search", input_schema={}) judge = OptimizationJudge(threshold=0.8, judge_key="ld-judge-key") @@ -697,11 +657,10 @@ async def test_agent_tools_prepended_before_evaluation_tool(self): user_input="Q?", agent_tools=[agent_tool], ) - _, config, _, _ = self.handle_judge_call.call_args.args + _, config, _ = self.handle_judge_call.call_args.args tools = config.model.get_parameter("tools") or [] names = [t["name"] for t in tools] - assert "search" in names - assert names.index("search") < names.index(create_evaluation_tool().name) + assert names == ["search"] # --------------------------------------------------------------------------- @@ -712,8 +671,8 @@ async def test_agent_tools_prepended_before_evaluation_tool(self): class TestExecuteAgentTurn: def setup_method(self): self.agent_response = "Paris is the capital of France." 
- self.handle_agent_call = AsyncMock(return_value=self.agent_response) - self.handle_judge_call = AsyncMock(return_value=JUDGE_PASS_RESPONSE) + self.handle_agent_call = AsyncMock(return_value=OptimizationResponse(output=self.agent_response)) + self.handle_judge_call = AsyncMock(return_value=OptimizationResponse(output=JUDGE_PASS_RESPONSE)) self.client = _make_client() agent_config = _make_agent_config() self.client._agent_key = "test-agent" @@ -740,11 +699,10 @@ async def test_calls_handle_agent_call_with_config_and_context(self): ctx = self._make_context() await self.client._execute_agent_turn(ctx, iteration=1) self.handle_agent_call.assert_called_once() - key, config, passed_ctx, handlers = self.handle_agent_call.call_args.args + key, config, passed_ctx = self.handle_agent_call.call_args.args assert key == "test-agent" assert isinstance(config, AIAgentConfig) assert passed_ctx is ctx - assert handlers == {} async def test_completion_response_stored_in_returned_context(self): ctx = self._make_context() @@ -760,7 +718,7 @@ async def test_judge_scores_stored_in_returned_context(self): async def test_variables_interpolated_into_agent_config_instructions(self): ctx = self._make_context() await self.client._execute_agent_turn(ctx, iteration=1) - _, config, _, _ = self.handle_agent_call.call_args.args + _, config, _ = self.handle_agent_call.call_args.args assert "{{language}}" not in config.instructions assert "English" in config.instructions @@ -778,7 +736,7 @@ async def test_raises_on_agent_call_failure(self): class TestGenerateNewVariation: def setup_method(self): - self.handle_agent_call = AsyncMock(return_value=VARIATION_RESPONSE) + self.handle_agent_call = AsyncMock(return_value=OptimizationResponse(output=VARIATION_RESPONSE)) self.client = _make_client() agent_config = _make_agent_config() self.client._agent_key = "test-agent" @@ -799,18 +757,17 @@ async def test_updates_current_model(self): await self.client._generate_new_variation(iteration=1, variables={}) 
assert self.client._current_model == "gpt-4o" - async def test_variation_tool_in_agent_config(self): + async def test_no_structured_output_tool_in_variation_config(self): + """Variation turn must not inject the structured-output tool — prompts use plain JSON.""" await self.client._generate_new_variation(iteration=1, variables={}) - _, config, _, _ = self.handle_agent_call.call_args.args + _, config, _ = self.handle_agent_call.call_args.args tools = config.model.get_parameter("tools") or [] - tool_names = [t["name"] for t in tools] - assert create_variation_tool(self.client._options.model_choices).name in tool_names + assert tools == [] - async def test_builtin_handlers_passed_for_variation(self): + async def test_variation_call_uses_three_arg_signature(self): + """handle_agent_call receives exactly (key, config, context) — no tools arg.""" await self.client._generate_new_variation(iteration=1, variables={}) - _, _, _, handlers = self.handle_agent_call.call_args.args - expected_name = create_variation_tool(self.client._options.model_choices).name - assert expected_name in handlers + assert len(self.handle_agent_call.call_args.args) == 3 async def test_model_not_updated_when_not_in_model_choices(self): bad_response = json.dumps({ @@ -818,11 +775,42 @@ async def test_model_not_updated_when_not_in_model_choices(self): "current_parameters": {}, "model": "some-unknown-model", }) - self.handle_agent_call.return_value = bad_response + self.handle_agent_call.return_value = OptimizationResponse(output=bad_response) original_model = self.client._current_model await self.client._generate_new_variation(iteration=1, variables={}) assert self.client._current_model == original_model + async def test_retries_on_empty_response_and_succeeds(self): + """First attempt returns empty string; second returns valid JSON — succeeds.""" + self.handle_agent_call.side_effect = [ + OptimizationResponse(output=""), # attempt 1: empty + OptimizationResponse(output=VARIATION_RESPONSE), # attempt 2: 
valid + ] + await self.client._generate_new_variation(iteration=1, variables={}) + assert self.client._current_instructions == "You are an improved assistant." + assert self.handle_agent_call.call_count == 2 + + async def test_retries_on_unparseable_response_and_succeeds(self): + """First attempt returns non-JSON text; second returns valid JSON — succeeds.""" + self.handle_agent_call.side_effect = [ + OptimizationResponse(output="Sorry, I cannot do that."), # attempt 1: not JSON + OptimizationResponse(output=VARIATION_RESPONSE), # attempt 2: valid + ] + await self.client._generate_new_variation(iteration=1, variables={}) + assert self.client._current_instructions == "You are an improved assistant." + assert self.handle_agent_call.call_count == 2 + + async def test_raises_after_max_retries_exhausted(self): + """All three attempts return empty strings — ValueError is raised.""" + self.handle_agent_call.side_effect = [ + OptimizationResponse(output=""), + OptimizationResponse(output=""), + OptimizationResponse(output=""), + ] + with pytest.raises(ValueError, match="Failed to parse structured output"): + await self.client._generate_new_variation(iteration=1, variables={}) + assert self.handle_agent_call.call_count == 3 + # --------------------------------------------------------------------------- # Full optimization loop @@ -834,8 +822,8 @@ def setup_method(self): self.mock_ldai = _make_ldai_client() async def test_succeeds_on_first_attempt_when_judge_passes(self): - handle_agent_call = AsyncMock(return_value="The capital of France is Paris.") - handle_judge_call = AsyncMock(return_value=JUDGE_PASS_RESPONSE) + handle_agent_call = AsyncMock(return_value=OptimizationResponse(output="The capital of France is Paris.")) + handle_judge_call = AsyncMock(return_value=OptimizationResponse(output=JUDGE_PASS_RESPONSE)) client = _make_client(self.mock_ldai) options = _make_options( handle_agent_call=handle_agent_call, @@ -843,16 +831,22 @@ async def 
test_succeeds_on_first_attempt_when_judge_passes(self): ) result = await client.optimize_from_options("test-agent", options) assert result.scores["accuracy"].score == 1.0 - handle_agent_call.assert_called_once() + # 1 initial agent call + 1 validation sample (repeated draw — only 1 variable choice) + assert handle_agent_call.call_count == 2 async def test_generates_variation_when_judge_fails(self): agent_responses = [ - "Bad answer.", - VARIATION_RESPONSE, # variation generation - "Better answer.", + OptimizationResponse(output="Bad answer."), + OptimizationResponse(output=VARIATION_RESPONSE), # variation generation + OptimizationResponse(output="Better answer."), + OptimizationResponse(output="Better answer."), # 1 validation sample (repeated draw — only 1 variable choice) ] handle_agent_call = AsyncMock(side_effect=agent_responses) - judge_responses = [JUDGE_FAIL_RESPONSE, JUDGE_PASS_RESPONSE] + judge_responses = [ + OptimizationResponse(output=JUDGE_FAIL_RESPONSE), + OptimizationResponse(output=JUDGE_PASS_RESPONSE), + OptimizationResponse(output=JUDGE_PASS_RESPONSE), + ] handle_judge_call = AsyncMock(side_effect=judge_responses) client = _make_client(self.mock_ldai) options = _make_options( @@ -862,19 +856,20 @@ async def test_generates_variation_when_judge_fails(self): ) result = await client.optimize_from_options("test-agent", options) assert result.scores["accuracy"].score == 1.0 - assert handle_agent_call.call_count == 3 # 1 agent + 1 variation + 1 agent + # 1 agent + 1 variation + 1 agent + 1 validation sample + assert handle_agent_call.call_count == 4 async def test_returns_last_context_after_max_attempts(self): # The max_attempts guard fires before variation on the final iteration, # so only iterations 1 and 2 produce a variation call. 
handle_agent_call = AsyncMock(side_effect=[ - "Bad answer.", # iteration 1: agent - VARIATION_RESPONSE, # iteration 1: variation - "Still bad.", # iteration 2: agent - VARIATION_RESPONSE, # iteration 2: variation - "Still bad.", # iteration 3: agent (max_attempts reached — no variation) + OptimizationResponse(output="Bad answer."), # iteration 1: agent + OptimizationResponse(output=VARIATION_RESPONSE), # iteration 1: variation + OptimizationResponse(output="Still bad."), # iteration 2: agent + OptimizationResponse(output=VARIATION_RESPONSE), # iteration 2: variation + OptimizationResponse(output="Still bad."), # iteration 3: agent (max_attempts reached — no variation) ]) - handle_judge_call = AsyncMock(return_value=JUDGE_FAIL_RESPONSE) + handle_judge_call = AsyncMock(return_value=OptimizationResponse(output=JUDGE_FAIL_RESPONSE)) client = _make_client(self.mock_ldai) options = _make_options( handle_agent_call=handle_agent_call, @@ -886,8 +881,8 @@ async def test_returns_last_context_after_max_attempts(self): async def test_on_passing_result_called_on_success(self): on_passing = MagicMock() - handle_agent_call = AsyncMock(return_value="Great answer.") - handle_judge_call = AsyncMock(return_value=JUDGE_PASS_RESPONSE) + handle_agent_call = AsyncMock(return_value=OptimizationResponse(output="Great answer.")) + handle_judge_call = AsyncMock(return_value=OptimizationResponse(output=JUDGE_PASS_RESPONSE)) client = _make_client(self.mock_ldai) options = _make_options( handle_agent_call=handle_agent_call, @@ -900,11 +895,11 @@ async def test_on_passing_result_called_on_success(self): async def test_on_failing_result_called_on_max_attempts(self): on_failing = MagicMock() handle_agent_call = AsyncMock(side_effect=[ - "Bad.", # iteration 1: agent - VARIATION_RESPONSE, # iteration 1: variation - "Still bad.", # iteration 2: agent (max_attempts reached — no variation) + OptimizationResponse(output="Bad."), # iteration 1: agent + OptimizationResponse(output=VARIATION_RESPONSE), # 
iteration 1: variation + OptimizationResponse(output="Still bad."), # iteration 2: agent (max_attempts reached — no variation) ]) - handle_judge_call = AsyncMock(return_value=JUDGE_FAIL_RESPONSE) + handle_judge_call = AsyncMock(return_value=OptimizationResponse(output=JUDGE_FAIL_RESPONSE)) client = _make_client(self.mock_ldai) options = _make_options( handle_agent_call=handle_agent_call, @@ -916,8 +911,8 @@ async def test_on_failing_result_called_on_max_attempts(self): on_failing.assert_called_once() async def test_on_turn_manual_path_success(self): - handle_agent_call = AsyncMock(return_value="Answer.") - handle_judge_call = AsyncMock(return_value=JUDGE_PASS_RESPONSE) + handle_agent_call = AsyncMock(return_value=OptimizationResponse(output="Answer.")) + handle_judge_call = AsyncMock(return_value=OptimizationResponse(output=JUDGE_PASS_RESPONSE)) client = _make_client(self.mock_ldai) options = OptimizationOptions( context_choices=[LD_CONTEXT], @@ -935,8 +930,8 @@ async def test_on_turn_manual_path_success(self): async def test_status_update_callback_called_at_each_stage(self): statuses = [] - handle_agent_call = AsyncMock(return_value="Good answer.") - handle_judge_call = AsyncMock(return_value=JUDGE_PASS_RESPONSE) + handle_agent_call = AsyncMock(return_value=OptimizationResponse(output="Good answer.")) + handle_judge_call = AsyncMock(return_value=OptimizationResponse(output=JUDGE_PASS_RESPONSE)) client = _make_client(self.mock_ldai) options = _make_options( handle_agent_call=handle_agent_call, @@ -950,6 +945,250 @@ async def test_status_update_callback_called_at_each_stage(self): assert "success" in statuses +# --------------------------------------------------------------------------- +# _compute_validation_count +# --------------------------------------------------------------------------- + + +class TestComputeValidationCount: + def test_pool_of_10_returns_2(self): + assert _compute_validation_count(10) == 2 + + def test_pool_of_20_returns_5(self): + assert 
_compute_validation_count(20) == 5 + + def test_pool_of_16_returns_4(self): + assert _compute_validation_count(16) == 4 + + def test_small_pool_floors_at_2(self): + assert _compute_validation_count(1) == 2 + assert _compute_validation_count(3) == 2 + + def test_large_pool_caps_at_5(self): + assert _compute_validation_count(100) == 5 + + def test_pool_of_8_returns_2(self): + assert _compute_validation_count(8) == 2 + + +# --------------------------------------------------------------------------- +# Validation phase (chaos mode) +# --------------------------------------------------------------------------- + +# Helper: build OptimizationOptions with multiple variable choices so the +# validation phase has a non-empty distinct pool to sample from. +def _make_multi_options( + *, + variable_count: int = 8, + user_input_options=None, + on_turn=None, + handle_agent_call=None, + handle_judge_call=None, + on_passing_result=None, + max_attempts: int = 5, +) -> OptimizationOptions: + if handle_agent_call is None: + handle_agent_call = AsyncMock(return_value=OptimizationResponse(output="answer")) + if handle_judge_call is None: + handle_judge_call = AsyncMock(return_value=OptimizationResponse(output=JUDGE_PASS_RESPONSE)) + judges = None if on_turn is not None else { + "acc": OptimizationJudge(threshold=0.8, acceptance_statement="Be accurate.") + } + return OptimizationOptions( + context_choices=[LD_CONTEXT], + max_attempts=max_attempts, + model_choices=["gpt-4o"], + judge_model="gpt-4o", + variable_choices=[{"x": i} for i in range(variable_count)], + user_input_options=user_input_options, + handle_agent_call=handle_agent_call, + handle_judge_call=handle_judge_call, + judges=judges, + on_turn=on_turn, + on_passing_result=on_passing_result, + ) + + +class TestValidationPhase: + def setup_method(self): + self.mock_ldai = _make_ldai_client() + + def _make_client(self) -> OptimizationClient: + return _make_client(self.mock_ldai) + + async def 
test_on_passing_result_fires_only_after_all_validation_passes(self): + """on_passing_result must not fire until all validation samples pass.""" + on_passing = MagicMock() + client = self._make_client() + # 8 variable_choices → validation_count = 2; all judges always pass + opts = _make_multi_options(on_passing_result=on_passing) + await client.optimize_from_options("test-agent", opts) + on_passing.assert_called_once() + + async def test_validation_runs_additional_agent_calls(self): + """With 8 variable choices, validation runs 2 extra agent calls after the initial pass.""" + call_count = [0] + + async def counting_agent(key, config, ctx): + call_count[0] += 1 + return OptimizationResponse(output="answer") + + client = self._make_client() + opts = _make_multi_options(handle_agent_call=counting_agent) + await client.optimize_from_options("test-agent", opts) + # 1 initial pass + 2 validation samples + assert call_count[0] == 3 + + async def test_validation_failure_suppresses_on_passing_result_then_retries(self): + """When a validation sample fails, on_passing_result is not fired and the loop retries.""" + turn_calls = [0] + + def on_turn(ctx): + turn_calls[0] += 1 + # call 1: initial pass, call 2: first validation FAIL, everything else passes + return turn_calls[0] != 2 + + on_passing = MagicMock() + client = self._make_client() + opts = _make_multi_options( + on_turn=on_turn, + # 8 items → validation_count = 2 + variable_count=8, + handle_agent_call=AsyncMock(side_effect=[ + OptimizationResponse(output="iter1"), # initial turn (passes) + OptimizationResponse(output="val_iter2"), # validation sample 1 (fails) + OptimizationResponse(output=VARIATION_RESPONSE), # variation generation + OptimizationResponse(output="iter3"), # new attempt initial (passes) + OptimizationResponse(output="val_iter4"), # new validation sample 1 (passes) + OptimizationResponse(output="val_iter5"), # new validation sample 2 (passes) + ]), + on_passing_result=on_passing, + max_attempts=3, + ) + 
result = await client.optimize_from_options("test-agent", opts) + # Eventually succeeds after one failed validation cycle + on_passing.assert_called_once() + assert result is not None + + async def test_validation_does_not_reuse_passing_turn_variable(self): + """The variable set used in the initial passing turn must not appear in validation.""" + seen_variables = [] + + async def capture_agent(key, config, ctx): + seen_variables.append(ctx.current_variables) + return OptimizationResponse(output="answer") + + client = self._make_client() + opts = _make_multi_options(handle_agent_call=capture_agent, variable_count=8) + await client.optimize_from_options("test-agent", opts) + + # First call is the initial passing turn + initial_vars = seen_variables[0] + # Remaining calls are validation samples — none should match the initial + for val_vars in seen_variables[1:]: + assert val_vars != initial_vars, ( + f"Validation reused the passing turn's variables: {initial_vars}" + ) + + async def test_validation_uses_user_input_options_as_pool_when_provided(self): + """When user_input_options is provided, validation samples from that pool.""" + seen_inputs = [] + + async def capture_agent(key, config, ctx): + seen_inputs.append(ctx.user_input) + return OptimizationResponse(output="answer") + + client = self._make_client() + user_inputs = [f"question {i}" for i in range(8)] + opts = _make_multi_options( + handle_agent_call=capture_agent, + user_input_options=user_inputs, + ) + await client.optimize_from_options("test-agent", opts) + + # Initial input is at index 0; all validation inputs must be different + initial_input = seen_inputs[0] + for val_input in seen_inputs[1:]: + assert val_input != initial_input, ( + f"Validation reused the passing turn's user_input: {initial_input}" + ) + + async def test_pool_exhaustion_caps_validation_at_available_distinct_items(self): + """When fewer distinct items remain than validation_count, all available ones are used.""" + call_count = [0] + + 
async def counting_agent(key, config, ctx): + call_count[0] += 1 + return OptimizationResponse(output="answer") + + client = self._make_client() + # 3 variable choices → _compute_validation_count(3) = 2, but only 2 remain after + # excluding the passing item, so validation_count is still 2 (min of 2 and 2) + opts = _make_multi_options(handle_agent_call=counting_agent, variable_count=3) + await client.optimize_from_options("test-agent", opts) + # 1 initial + 2 validation (uses all remaining distinct items) + assert call_count[0] == 3 + + async def test_single_variable_choice_falls_back_to_repeated_draw(self): + """With only 1 variable choice validation still runs 1 sample (repeated draw).""" + call_count = [0] + + async def counting_agent(key, config, ctx): + call_count[0] += 1 + return OptimizationResponse(output="answer") + + client = self._make_client() + opts = _make_multi_options(handle_agent_call=counting_agent, variable_count=1) + await client.optimize_from_options("test-agent", opts) + # 1 initial pass + 1 validation sample (repeated draw from the only item) + assert call_count[0] == 2 + + async def test_validation_does_not_consume_attempt_budget(self): + """Validation samples must not count against max_attempts. + + With max_attempts=2 and 8 variable choices (validation_count=2), a failed + validation on attempt 1 should still leave a full attempt 2 available. + Without the fix, iteration would be inflated to 3 after validation, which + exceeds max_attempts=2 and would trigger _handle_failure prematurely. 
+ """ + turn_calls = [0] + + def on_turn(ctx): + turn_calls[0] += 1 + # attempt 1 passes initial, validation sample 1 fails + # attempt 2 passes initial and all validation + return turn_calls[0] != 2 + + on_passing = MagicMock() + client = self._make_client() + opts = _make_multi_options( + on_turn=on_turn, + variable_count=8, + handle_agent_call=AsyncMock(side_effect=[ + OptimizationResponse(output="iter1"), # attempt 1 initial (passes) + OptimizationResponse(output="val_iter"), # validation sample 1 (fails) + OptimizationResponse(output=VARIATION_RESPONSE), # variation generation + OptimizationResponse(output="iter2"), # attempt 2 initial (passes) + OptimizationResponse(output="val_iter3"), # validation sample 1 (passes) + OptimizationResponse(output="val_iter4"), # validation sample 2 (passes) + ]), + on_passing_result=on_passing, + max_attempts=2, + ) + result = await client.optimize_from_options("test-agent", opts) + on_passing.assert_called_once() + assert result is not None + + async def test_validating_status_emitted(self): + """The 'validating' status must be emitted when entering the validation phase.""" + statuses = [] + client = self._make_client() + opts = _make_multi_options() + opts.on_status_update = lambda s, ctx: statuses.append(s) + await client.optimize_from_options("test-agent", opts) + assert "validating" in statuses + + # --------------------------------------------------------------------------- # Variation prompt — acceptance criteria section # --------------------------------------------------------------------------- @@ -1011,6 +1250,350 @@ def test_section_appears_in_full_prompt(self): assert "ACCEPTANCE CRITERIA" in prompt +# --------------------------------------------------------------------------- +# Variation prompt — overfitting warning section +# --------------------------------------------------------------------------- + + +class TestVariationPromptOverfitWarning: + def _make_ctx(self, user_input=None, variables=None, 
iteration=1): + return OptimizationContext( + iteration=iteration, + current_instructions=AGENT_INSTRUCTIONS, + current_parameters={}, + current_model="gpt-4o", + current_variables=variables or {}, + user_input=user_input, + completion_response=None, + scores={}, + ) + + def test_returns_empty_string_with_no_history(self): + assert variation_prompt_overfit_warning([]) == "" + + def test_contains_general_overfitting_reminder(self): + ctx = self._make_ctx(user_input="What is 2+2?") + section = variation_prompt_overfit_warning([ctx]) + assert "OVERFITTING" in section.upper() + assert "generalise" in section.lower() or "generalize" in section.lower() or "generaliz" in section.lower() or "general" in section.lower() + + def test_includes_recent_user_input(self): + ctx = self._make_ctx(user_input="What is the capital of France?") + section = variation_prompt_overfit_warning([ctx]) + assert "What is the capital of France?" in section + + def test_includes_recent_variables_as_structured_breakdown(self): + ctx = self._make_ctx(variables={"language": "English", "tone": "formal"}) + section = variation_prompt_overfit_warning([ctx]) + # Keys (placeholder names) and values must both appear + assert "{{language}}" in section + assert '"English"' in section + assert "{{tone}}" in section + assert '"formal"' in section + + def test_variables_section_labels_name_vs_value(self): + ctx = self._make_ctx(variables={"user_id": "user-125"}) + section = variation_prompt_overfit_warning([ctx]) + assert "{{user_id}}" in section + assert '"user-125"' in section + assert "placeholder" in section.lower() + assert "value" in section.lower() + # Must NOT render as a raw Python dict + assert "{'user_id': 'user-125'}" not in section + + def test_uses_most_recent_history_entry(self): + ctx_old = self._make_ctx(user_input="old question", iteration=1) + ctx_new = self._make_ctx(user_input="new question", iteration=2) + section = variation_prompt_overfit_warning([ctx_old, ctx_new]) + assert "new 
question" in section + assert "old question" not in section + + def test_omits_user_input_line_when_none(self): + ctx = self._make_ctx(user_input=None, variables={"lang": "en"}) + section = variation_prompt_overfit_warning([ctx]) + assert "User input" not in section + assert "lang" in section + + def test_omits_variables_line_when_empty(self): + ctx = self._make_ctx(user_input="hello", variables={}) + section = variation_prompt_overfit_warning([ctx]) + assert "Variables" not in section + assert "hello" in section + + def test_warning_appears_in_full_prompt_when_history_present(self): + ctx = self._make_ctx(user_input="test question", variables={"k": "v"}) + prompt = build_new_variation_prompt( + history=[ctx], + judges=None, + current_model="gpt-4o", + current_instructions=AGENT_INSTRUCTIONS, + current_parameters={}, + model_choices=["gpt-4o"], + variable_choices=[{"k": "v"}], + initial_instructions=AGENT_INSTRUCTIONS, + ) + assert "OVERFITTING" in prompt.upper() + assert "test question" in prompt + + def test_warning_absent_from_full_prompt_when_no_history(self): + prompt = build_new_variation_prompt( + history=[], + judges=None, + current_model="gpt-4o", + current_instructions=AGENT_INSTRUCTIONS, + current_parameters={}, + model_choices=["gpt-4o"], + variable_choices=[{"k": "v"}], + initial_instructions=AGENT_INSTRUCTIONS, + ) + assert "OVERFITTING" not in prompt.upper() + + +# --------------------------------------------------------------------------- +# Variation prompt — preamble key-vs-value note +# --------------------------------------------------------------------------- + + +class TestVariationPromptPreamble: + def test_contains_key_vs_value_important_note(self): + preamble = variation_prompt_preamble() + assert "IMPORTANT" in preamble + assert "placeholder" in preamble.lower() + assert "value" in preamble.lower() + + def test_never_use_value_as_placeholder_name(self): + preamble = variation_prompt_preamble() + assert "never" in preamble.lower() + + +# 
--------------------------------------------------------------------------- +# Variation prompt — placeholder table +# --------------------------------------------------------------------------- + + +class TestVariationPromptPlaceholderTable: + _variable_choices = [ + {"user_id": "user-123", "trip_purpose": "business"}, + {"user_id": "user-125", "trip_purpose": "personal"}, + ] + + def _section(self, variable_choices=None, history=None): + return variation_prompt_improvement_instructions( + history=history or [], + model_choices=["gpt-4o"], + variable_choices=variable_choices or self._variable_choices, + initial_instructions=AGENT_INSTRUCTIONS, + ) + + def test_placeholder_names_appear_in_table(self): + section = self._section() + assert "{{user_id}}" in section + assert "{{trip_purpose}}" in section + + def test_example_values_appear_alongside_keys(self): + section = self._section() + assert '"user-123"' in section or '"user-125"' in section + assert '"business"' in section or '"personal"' in section + + def test_keys_and_values_clearly_separated(self): + section = self._section() + assert "example values" in section.lower() + + def test_bad_good_counterexamples_use_actual_values(self): + section = self._section() + # The bad example must reference a runtime value, good example the key + assert "BAD" in section + assert "GOOD" in section + # At least one of the real values should appear in the bad example + assert "user-123" in section or "user-125" in section \ + or "business" in section or "personal" in section + + def test_raw_placeholder_list_not_used(self): + # The old format was a comma-separated list like "{{trip_purpose}}, {{user_id}}" + # The new format is a structured table; confirm no bare comma-list + section = self._section() + assert "{{trip_purpose}}, {{user_id}}" not in section + assert "{{user_id}}, {{trip_purpose}}" not in section + + def test_single_variable_choice(self): + section = self._section(variable_choices=[{"lang": "en"}]) + assert 
"{{lang}}" in section + assert '"en"' in section + + def test_table_appears_in_full_prompt(self): + prompt = build_new_variation_prompt( + history=[], + judges=None, + current_model="gpt-4o", + current_instructions=AGENT_INSTRUCTIONS, + current_parameters={}, + model_choices=["gpt-4o"], + variable_choices=self._variable_choices, + initial_instructions=AGENT_INSTRUCTIONS, + ) + assert "{{user_id}}" in prompt + assert "{{trip_purpose}}" in prompt + assert "example values" in prompt.lower() + + +# --------------------------------------------------------------------------- +# interpolate_variables — hyphenated key support +# --------------------------------------------------------------------------- + + +class TestInterpolateVariables: + def test_substitutes_standard_underscore_key(self): + result = interpolate_variables("Hello {{user_id}}", {"user_id": "abc"}) + assert result == "Hello abc" + + def test_substitutes_hyphenated_key(self): + result = interpolate_variables("Hello {{user-id}}", {"user-id": "abc"}) + assert result == "Hello abc" + + def test_leaves_unknown_placeholder_unchanged(self): + result = interpolate_variables("Hello {{unknown}}", {"user_id": "abc"}) + assert result == "Hello {{unknown}}" + + def test_leaves_unknown_hyphenated_placeholder_unchanged(self): + result = interpolate_variables("Hello {{bad-125}}", {"user_id": "abc"}) + assert result == "Hello {{bad-125}}" + + def test_mixed_keys_in_same_string(self): + result = interpolate_variables( + "{{user-id}} and {{trip_purpose}}", + {"user-id": "u-1", "trip_purpose": "leisure"}, + ) + assert result == "u-1 and leisure" + + def test_empty_variables_leaves_text_unchanged(self): + result = interpolate_variables("{{foo}} bar", {}) + assert result == "{{foo}} bar" + + +# --------------------------------------------------------------------------- +# restore_variable_placeholders +# --------------------------------------------------------------------------- + + +class TestRestoreVariablePlaceholders: + 
_CHOICES = [{"user_id": "user-123", "trip_purpose": "business"}] + + def test_replaces_hardcoded_id_value(self): + text = "Use the user ID user-123 to look up preferences." + result, warnings = restore_variable_placeholders(text, self._CHOICES) + assert "{{user_id}}" in result + assert "user-123" not in result + assert len(warnings) == 1 + assert "user-123" in warnings[0] + assert "{{user_id}}" in warnings[0] + + def test_replaces_multiline_value_verbatim(self): + multiline_value = "line one\nline two\nline three" + choices = [{"body_text": multiline_value}] + text = f"Instructions:\n{multiline_value}\nEnd." + result, warnings = restore_variable_placeholders(text, choices) + assert "{{body_text}}" in result + assert multiline_value not in result + assert len(warnings) == 1 + + def test_skips_value_shorter_than_min_length(self): + choices = [{"lang": "en"}] # "en" is only 2 chars + text = "Use language en for this request." + result, warnings = restore_variable_placeholders(text, choices, min_value_length=3) + assert result == text + assert warnings == [] + + def test_does_not_partially_match_longer_token(self): + """'user-123' must not be replaced inside 'user-1234'.""" + text = "Contact user-1234 for help." + result, warnings = restore_variable_placeholders(text, self._CHOICES) + assert "user-1234" in result + assert warnings == [] + + def test_replaces_multiple_variables(self): + text = "User user-123 is on a business trip." + result, warnings = restore_variable_placeholders(text, self._CHOICES) + assert "{{user_id}}" in result + assert "{{trip_purpose}}" in result + assert "user-123" not in result + assert "business" not in result + assert len(warnings) == 2 + + def test_leaves_correct_placeholder_unchanged(self): + text = "User {{user_id}} is on a {{trip_purpose}} trip." 
+ result, warnings = restore_variable_placeholders(text, self._CHOICES) + assert result == text + assert warnings == [] + + def test_replaces_multiple_occurrences_of_same_value(self): + text = "user-123 and user-123 are duplicates." + result, warnings = restore_variable_placeholders(text, self._CHOICES) + assert result == "{{user_id}} and {{user_id}} are duplicates." + assert "2 occurrence(s)" in warnings[0] + + def test_longer_value_replaced_before_shorter_substring(self): + """When one value is a prefix of another, the longer one is replaced first.""" + choices = [{"full_id": "user-123-admin", "short_id": "user-123"}] + text = "Admin is user-123-admin, regular is user-123." + result, warnings = restore_variable_placeholders(text, choices) + assert "{{full_id}}" in result + assert "{{short_id}}" in result + assert "user-123-admin" not in result + # The shorter value should not have corrupted the longer replacement + assert result.count("{{full_id}}") == 1 + assert result.count("{{short_id}}") == 1 + + def test_replaces_brace_wrapped_value_without_double_bracketing(self): + """{{user-123}} must become {{user_id}}, not {{{{user_id}}}}.""" + text = "Fetch preferences for user {{user-123}}." + result, warnings = restore_variable_placeholders(text, self._CHOICES) + assert result == "Fetch preferences for user {{user_id}}." + assert len(warnings) == 1 + + def test_empty_variable_choices_returns_text_unchanged(self): + text = "Some instructions here." + result, warnings = restore_variable_placeholders(text, []) + assert result == text + assert warnings == [] + + def test_warning_message_format(self): + text = "Handle user user-123 carefully." 
+ _, warnings = restore_variable_placeholders(text, self._CHOICES) + assert any("user-123" in w for w in warnings) + assert any("{{user_id}}" in w for w in warnings) + + async def test_apply_variation_response_calls_restore_and_logs_warning(self): + """_apply_new_variation_response must restore leaked values and log warnings.""" + leaked_instructions = "You serve user user-123 on a business trip." + variation_response = json.dumps({ + "current_instructions": leaked_instructions, + "current_parameters": {}, + "model": "gpt-4o", + }) + handle_agent_call = AsyncMock(return_value=OptimizationResponse(output=variation_response)) + client = _make_client() + agent_config = _make_agent_config() + client._agent_key = "test-agent" + client._agent_config = agent_config + client._initial_instructions = AGENT_INSTRUCTIONS + client._initialize_class_members_from_config(agent_config) + client._options = _make_options( + handle_agent_call=handle_agent_call, + variable_choices=[{"user_id": "user-123", "trip_purpose": "business"}], + ) + + with patch("ldai_optimization.client.logger") as mock_logger: + await client._generate_new_variation(iteration=1, variables={}) + warning_calls = [ + call for call in mock_logger.warning.call_args_list + if "user-123" in str(call) or "business" in str(call) + ] + assert len(warning_calls) >= 1 + + assert "{{user_id}}" in client._current_instructions + assert "user-123" not in client._current_instructions + + # --------------------------------------------------------------------------- # _build_options_from_config helpers # --------------------------------------------------------------------------- @@ -1035,8 +1618,8 @@ def _make_from_config_options(**overrides: Any) -> OptimizationFromConfigOptions defaults: Dict[str, Any] = dict( project_key="my-project", context_choices=[LD_CONTEXT], - handle_agent_call=AsyncMock(return_value="The answer is 4."), - handle_judge_call=AsyncMock(return_value=JUDGE_PASS_RESPONSE), + 
handle_agent_call=AsyncMock(return_value=OptimizationResponse(output="The answer is 4.")), + handle_judge_call=AsyncMock(return_value=OptimizationResponse(output=JUDGE_PASS_RESPONSE)), ) defaults.update(overrides) return OptimizationFromConfigOptions(**defaults) @@ -1171,8 +1754,8 @@ def test_model_with_multiple_dots_only_prefix_stripped(self): assert result.judge_model == "claude-opus-4.6" def test_callbacks_forwarded_from_options(self): - handle_agent = AsyncMock(return_value="ok") - handle_judge = AsyncMock(return_value=JUDGE_PASS_RESPONSE) + handle_agent = AsyncMock(return_value=OptimizationResponse(output="ok")) + handle_judge = AsyncMock(return_value=OptimizationResponse(output=JUDGE_PASS_RESPONSE)) options = _make_from_config_options( handle_agent_call=handle_agent, handle_judge_call=handle_judge, @@ -1424,8 +2007,8 @@ def _make(self, **overrides) -> GroundTruthOptimizationOptions: max_attempts=3, model_choices=["gpt-4o"], judge_model="gpt-4o", - handle_agent_call=AsyncMock(return_value="ans"), - handle_judge_call=AsyncMock(return_value=JUDGE_PASS_RESPONSE), + handle_agent_call=AsyncMock(return_value=OptimizationResponse(output="ans")), + handle_judge_call=AsyncMock(return_value=OptimizationResponse(output=JUDGE_PASS_RESPONSE)), judges={ "acc": OptimizationJudge(threshold=0.8, acceptance_statement="Be accurate.") }, @@ -1473,8 +2056,8 @@ def _make_gt_options(**overrides) -> GroundTruthOptimizationOptions: max_attempts=3, model_choices=["gpt-4o", "gpt-4o-mini"], judge_model="gpt-4o", - handle_agent_call=AsyncMock(return_value="The answer is correct."), - handle_judge_call=AsyncMock(return_value=JUDGE_PASS_RESPONSE), + handle_agent_call=AsyncMock(return_value=OptimizationResponse(output="The answer is correct.")), + handle_judge_call=AsyncMock(return_value=OptimizationResponse(output=JUDGE_PASS_RESPONSE)), judges={ "acc": OptimizationJudge(threshold=0.8, acceptance_statement="Be accurate.") }, @@ -1508,7 +2091,7 @@ async def 
test_each_context_has_correct_user_input(self): async def test_completion_response_set_on_each_context(self): client = self._make_client() - opts = _make_gt_options(handle_agent_call=AsyncMock(return_value="42")) + opts = _make_gt_options(handle_agent_call=AsyncMock(return_value=OptimizationResponse(output="42"))) results = await client.optimize_from_ground_truth_options("test-agent", opts) for ctx in results: assert ctx.completion_response == "42" @@ -1531,7 +2114,7 @@ async def test_on_failing_result_called_when_max_attempts_exceeded(self): client = self._make_client() failing_calls = [] opts = _make_gt_options( - handle_judge_call=AsyncMock(return_value=JUDGE_FAIL_RESPONSE), + handle_judge_call=AsyncMock(return_value=OptimizationResponse(output=JUDGE_FAIL_RESPONSE)), max_attempts=2, on_failing_result=lambda ctx: failing_calls.append(ctx), ) @@ -1552,14 +2135,16 @@ async def side_effect(*args, **kwargs): nonlocal call_count resp = judge_responses[call_count] call_count += 1 - return resp + return OptimizationResponse(output=resp) opts = _make_gt_options( handle_judge_call=side_effect, handle_agent_call=AsyncMock(side_effect=[ - "ans1", "ans2", # attempt 1 samples - VARIATION_RESPONSE, # variation generation - "ans3", "ans4", # attempt 2 samples + OptimizationResponse(output="ans1"), + OptimizationResponse(output="ans2"), # attempt 1 samples + OptimizationResponse(output=VARIATION_RESPONSE), # variation generation + OptimizationResponse(output="ans3"), + OptimizationResponse(output="ans4"), # attempt 2 samples ]), max_attempts=3, ) @@ -1587,9 +2172,9 @@ def bad_callback(ctx): async def test_variables_from_samples_used_per_evaluation(self): client = self._make_client() received_contexts = [] - async def capture_agent_call(key, config, ctx, tools): + async def capture_agent_call(key, config, ctx): received_contexts.append(ctx) - return "response" + return OptimizationResponse(output="response") opts = _make_gt_options( ground_truth_responses=[ @@ -1609,9 +2194,9 @@ 
async def test_model_falls_back_to_first_model_choice_when_agent_config_has_no_m client = _make_client(mock_ldai) observed_models = [] - async def capture(key, config, ctx, tools): + async def capture(key, config, ctx): observed_models.append(config.model.name if config.model else None) - return "answer" + return OptimizationResponse(output="answer") opts = _make_gt_options( handle_agent_call=capture, @@ -1651,9 +2236,9 @@ def setup_method(self): async def test_expected_response_included_in_acceptance_judge_user_message(self): captured_configs = [] - async def capture_judge_call(key, config, ctx, tools): + async def capture_judge_call(key, config, ctx): captured_configs.append(config) - return JUDGE_PASS_RESPONSE + return OptimizationResponse(output=JUDGE_PASS_RESPONSE) self.client._options = _make_options( judges={ @@ -1673,9 +2258,9 @@ async def capture_judge_call(key, config, ctx, tools): async def test_expected_response_in_acceptance_judge_user_message(self): captured_configs = [] - async def capture_judge_call(key, config, ctx, tools): + async def capture_judge_call(key, config, ctx): captured_configs.append(config) - return JUDGE_PASS_RESPONSE + return OptimizationResponse(output=JUDGE_PASS_RESPONSE) self.client._options = _make_options( judges={ @@ -1698,9 +2283,9 @@ async def capture_judge_call(key, config, ctx, tools): async def test_no_expected_response_leaves_judge_messages_unchanged(self): captured_configs = [] - async def capture_judge_call(key, config, ctx, tools): + async def capture_judge_call(key, config, ctx): captured_configs.append(config) - return JUDGE_PASS_RESPONSE + return OptimizationResponse(output=JUDGE_PASS_RESPONSE) self.client._options = _make_options( judges={ @@ -1799,8 +2384,8 @@ async def test_optimize_from_config_dispatches_to_gt_run(self): with patch("ldai_optimization.client.LDApiClient", return_value=mock_api): options = _make_from_config_options( - handle_agent_call=AsyncMock(return_value="correct answer"), - 
handle_judge_call=AsyncMock(return_value=JUDGE_PASS_RESPONSE), + handle_agent_call=AsyncMock(return_value=OptimizationResponse(output="correct answer")), + handle_judge_call=AsyncMock(return_value=OptimizationResponse(output=JUDGE_PASS_RESPONSE)), ) result = await client.optimize_from_config("my-gt-opt", options)