-
Notifications
You must be signed in to change notification settings - Fork 2.9k
fix(planners): allow BuiltInPlanner subclasses to override process_planning_response #4141
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
d09d67d
20a089d
c945c04
8d57323
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,11 +14,17 @@ | |
|
|
||
| """Unit tests for NL planning logic.""" | ||
|
|
||
| from typing import List | ||
| from typing import Optional | ||
| from unittest.mock import MagicMock | ||
| from unittest.mock import patch | ||
|
|
||
| from google.adk.agents.callback_context import CallbackContext | ||
| from google.adk.agents.llm_agent import Agent | ||
| from google.adk.flows.llm_flows._nl_planning import request_processor | ||
| from google.adk.flows.llm_flows._nl_planning import response_processor | ||
| from google.adk.models.llm_request import LlmRequest | ||
| from google.adk.models.llm_response import LlmResponse | ||
| from google.adk.planners.built_in_planner import BuiltInPlanner | ||
| from google.adk.planners.plan_re_act_planner import PlanReActPlanner | ||
| from google.genai import types | ||
|
|
@@ -126,3 +132,89 @@ async def test_remove_thought_from_request_with_thoughts(): | |
| for content in llm_request.contents | ||
| for part in content.parts or [] | ||
| ) | ||
|
|
||
|
|
||
| class OverriddenBuiltInPlanner(BuiltInPlanner): | ||
| """Subclass that overrides process_planning_response.""" | ||
|
|
||
| def __init__(self, *, thinking_config: types.ThinkingConfig): | ||
| super().__init__(thinking_config=thinking_config) | ||
| self.process_planning_response_called = False | ||
| self.received_parts = None | ||
|
|
||
| def process_planning_response( | ||
| self, | ||
| callback_context: CallbackContext, | ||
| response_parts: List[types.Part], | ||
| ) -> Optional[List[types.Part]]: | ||
| self.process_planning_response_called = True | ||
| self.received_parts = response_parts | ||
| return response_parts | ||
|
|
||
|
|
||
| class NonOverriddenBuiltInPlanner(BuiltInPlanner): | ||
| """Subclass that does NOT override process_planning_response.""" | ||
|
|
||
| pass | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_overridden_subclass_process_planning_response_called(): | ||
| """Test that subclasses overriding process_planning_response have it called. | ||
|
|
||
| Regression test for issue #4133. | ||
| """ | ||
| planner = OverriddenBuiltInPlanner(thinking_config=types.ThinkingConfig()) | ||
| agent = Agent(name='test_agent', planner=planner) | ||
| invocation_context = await testing_utils.create_invocation_context( | ||
| agent=agent, user_content='test message' | ||
| ) | ||
|
|
||
| response_parts = [ | ||
| types.Part(text='thinking...', thought=True), | ||
| types.Part(text='Here is my response'), | ||
| ] | ||
| llm_response = LlmResponse( | ||
| content=types.Content(role='model', parts=response_parts) | ||
| ) | ||
|
|
||
| async for _ in response_processor.run_async(invocation_context, llm_response): | ||
| pass | ||
|
|
||
| assert planner.process_planning_response_called | ||
| assert planner.received_parts == response_parts | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| @pytest.mark.parametrize( | ||
| 'planner_class', | ||
| [BuiltInPlanner, NonOverriddenBuiltInPlanner], | ||
| ids=['base_class', 'non_overridden_subclass'], | ||
| ) | ||
| async def test_process_planning_response_not_called_without_override( | ||
| planner_class, | ||
| ): | ||
| """Test that process_planning_response is not called for base or non-overridden subclasses.""" | ||
| planner = planner_class(thinking_config=types.ThinkingConfig()) | ||
| agent = Agent(name='test_agent', planner=planner) | ||
| invocation_context = await testing_utils.create_invocation_context( | ||
| agent=agent, user_content='test message' | ||
| ) | ||
|
|
||
| response_parts = [ | ||
| types.Part(text='thinking...', thought=True), | ||
| types.Part(text='Here is my response'), | ||
| ] | ||
| llm_response = LlmResponse( | ||
| content=types.Content(role='model', parts=response_parts) | ||
| ) | ||
|
|
||
| with patch.object( | ||
| BuiltInPlanner, | ||
| 'process_planning_response', | ||
| ) as mock_method: | ||
|
Comment on lines
+212
to
+215
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current test patches with patch.object(planner, 'process_planning_response') as mock_method: |
||
| async for _ in response_processor.run_async( | ||
| invocation_context, llm_response | ||
| ): | ||
| pass | ||
| mock_method.assert_not_called() | ||
maru0804 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Uh oh!
There was an error while loading. Please reload this page.