diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index 504e81c92..f44a11d30 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -496,8 +496,8 @@ async def stream( yield self.format_chunk({"chunk_type": "message_stop", "data": choice.finish_reason}) - if hasattr(chunk, "usage"): - yield self.format_chunk({"chunk_type": "metadata", "data": chunk.usage}) + if hasattr(chunk, "data") and hasattr(chunk.data, "usage") and chunk.data.usage: + yield self.format_chunk({"chunk_type": "metadata", "data": chunk.data.usage}) except Exception as e: if "rate" in str(e).lower() or "429" in str(e): diff --git a/tests/strands/models/test_mistral.py b/tests/strands/models/test_mistral.py index ad74bae89..57189748e 100644 --- a/tests/strands/models/test_mistral.py +++ b/tests/strands/models/test_mistral.py @@ -451,9 +451,9 @@ async def test_stream(mistral_client, model, agenerator, alist, captured_warning delta=unittest.mock.Mock(content="test stream", tool_calls=None), finish_reason="end_turn", ) - ] + ], + usage=mock_usage, ), - usage=mock_usage, ) mistral_client.chat.stream_async = unittest.mock.AsyncMock(return_value=agenerator([mock_event])) @@ -476,6 +476,30 @@ async def test_stream(mistral_client, model, agenerator, alist, captured_warning assert len(captured_warnings) == 0 +@pytest.mark.asyncio +async def test_stream_no_usage(mistral_client, model, agenerator, alist): + mock_event = unittest.mock.Mock( + data=unittest.mock.Mock( + choices=[ + unittest.mock.Mock( + delta=unittest.mock.Mock(content="test stream", tool_calls=None), + finish_reason="end_turn", + ) + ], + usage=None, + ), + ) + + mistral_client.chat.stream_async = unittest.mock.AsyncMock(return_value=agenerator([mock_event])) + + messages = [{"role": "user", "content": [{"text": "test"}]}] + response = model.stream(messages, None, None) + + # Should complete without error and not yield a metadata chunk + chunks = await alist(response) + assert not any("metadata" in c for c in chunks if isinstance(c, dict)) + + @pytest.mark.asyncio async def test_tool_choice_not_supported_warns(mistral_client, model, agenerator, alist, captured_warnings): tool_choice = {"auto": {}} @@ -492,9 +516,9 @@ async def test_tool_choice_not_supported_warns(mistral_client, model, agenerator delta=unittest.mock.Mock(content="test stream", tool_calls=None), finish_reason="end_turn", ) - ] + ], + usage=mock_usage, ), - usage=mock_usage, ) mistral_client.chat.stream_async = unittest.mock.AsyncMock(return_value=agenerator([mock_event])) diff --git a/tests_integ/models/test_model_mistral.py b/tests_integ/models/test_model_mistral.py index 3b13e5911..83f6af499 100644 --- a/tests_integ/models/test_model_mistral.py +++ b/tests_integ/models/test_model_mistral.py @@ -106,6 +106,11 @@ async def test_agent_stream_async(agent): assert all(string in text for string in ["12:00", "sunny"]) + assert result.metrics.accumulated_usage is not None + assert result.metrics.accumulated_usage["inputTokens"] > 0 + assert result.metrics.accumulated_usage["outputTokens"] > 0 + assert result.metrics.accumulated_usage["totalTokens"] > 0 + def test_agent_structured_output(non_streaming_agent, weather): tru_weather = non_streaming_agent.structured_output(type(weather), "The time is 12:00 and the weather is sunny")