diff --git a/pyproject.toml b/pyproject.toml index 7b7247c..c1d51c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dependencies = [ "langchain>=0.3.7", "openai>=1.58.1", "pydantic>=2.9.2", - "og-test-v2-x402==0.0.11" + "og-x402==0.0.1.dev2" ] [project.scripts] diff --git a/requirements.txt b/requirements.txt index df03caa..480424d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,4 @@ requests>=2.32.3 langchain>=0.3.7 openai>=1.58.1 pydantic>=2.9.2 -og-test-v2-x402==0.0.11 \ No newline at end of file +og-x402==0.0.1.dev2 \ No newline at end of file diff --git a/src/opengradient/agents/__init__.py b/src/opengradient/agents/__init__.py index 082f706..aa6a35d 100644 --- a/src/opengradient/agents/__init__.py +++ b/src/opengradient/agents/__init__.py @@ -6,15 +6,22 @@ into existing applications and agent frameworks. """ +from ..client.llm import LLM from ..types import TEE_LLM, x402SettlementMode from .og_langchain import * def langchain_adapter( - private_key: str, - model_cid: TEE_LLM, + private_key: str | None = None, + model_cid: TEE_LLM | str | None = None, + model: TEE_LLM | str | None = None, max_tokens: int = 300, + temperature: float = 0.0, x402_settlement_mode: x402SettlementMode = x402SettlementMode.BATCH_HASHED, + client: LLM | None = None, + rpc_url: str | None = None, + tee_registry_address: str | None = None, + llm_server_url: str | None = None, ) -> OpenGradientChatModel: """ Returns an OpenGradient LLM that implements LangChain's LLM interface @@ -22,9 +29,14 @@ def langchain_adapter( """ return OpenGradientChatModel( private_key=private_key, - model_cid=model_cid, + client=client, + model_cid=model_cid or model, max_tokens=max_tokens, + temperature=temperature, x402_settlement_mode=x402_settlement_mode, + rpc_url=rpc_url, + tee_registry_address=tee_registry_address, + llm_server_url=llm_server_url, ) diff --git a/src/opengradient/agents/og_langchain.py b/src/opengradient/agents/og_langchain.py index 4f238a5..bcb02f4 100644 --- a/src/opengradient/agents/og_langchain.py +++ b/src/opengradient/agents/og_langchain.py @@ -1,29 +1,34 @@ # mypy: ignore-errors import asyncio import json -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from enum import Enum +from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Iterator, List, Optional, Sequence, Union, cast -from langchain_core.callbacks.manager import CallbackManagerForLLMRun +from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun from langchain_core.language_models.base import LanguageModelInput from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import ( AIMessage, + AIMessageChunk, BaseMessage, + ChatMessage, HumanMessage, SystemMessage, ToolCall, ) -from langchain_core.messages.tool import ToolMessage +from langchain_core.messages.tool import ToolCallChunk, ToolMessage from langchain_core.outputs import ( ChatGeneration, + ChatGenerationChunk, ChatResult, ) from langchain_core.runnables import Runnable from langchain_core.tools import BaseTool +from langchain_core.utils.function_calling import convert_to_openai_tool from pydantic import PrivateAttr from ..client.llm import LLM -from ..types import TEE_LLM, x402SettlementMode +from ..types import StreamChunk, TEE_LLM, TextGenerationOutput, x402SettlementMode __all__ = ["OpenGradientChatModel"] @@ -47,7 +52,29 @@ def _extract_content(content: Any) -> str: return str(content) if content else "" -def _parse_tool_call(tool_call: Dict) -> ToolCall: +def _parse_tool_args(raw_args: Any) -> Dict[str, Any]: + if isinstance(raw_args, dict): + return raw_args + if raw_args is None or raw_args == "": + return {} + if isinstance(raw_args, str): + try: + parsed = json.loads(raw_args) + return parsed if isinstance(parsed, dict) else {} + except json.JSONDecodeError: + return {} + return {} + + +def _serialize_tool_args(raw_args: Any) -> str: + if raw_args is None: + return "{}" + if isinstance(raw_args, str): + return raw_args + return json.dumps(raw_args) + + +def _parse_tool_call(tool_call: Dict[str, Any]) -> ToolCall: """Parse a tool call from the API response. Handles both flat format {"id", "name", "arguments"} and @@ -58,86 +85,191 @@ def _parse_tool_call(tool_call: Dict) -> ToolCall: return ToolCall( id=tool_call.get("id", ""), name=func["name"], - args=json.loads(func.get("arguments", "{}")), + args=_parse_tool_args(func.get("arguments")), ) return ToolCall( id=tool_call.get("id", ""), name=tool_call["name"], - args=json.loads(tool_call.get("arguments", "{}")), + args=_parse_tool_args(tool_call.get("arguments")), + ) + + +def _parse_tool_call_chunk(tool_call: Dict[str, Any], default_index: int) -> ToolCallChunk: + if "function" in tool_call: + func = tool_call.get("function", {}) + name = func.get("name") + raw_args = func.get("arguments") + else: + name = tool_call.get("name") + raw_args = tool_call.get("arguments") + + args: Optional[str] + if raw_args is None: + args = None + elif isinstance(raw_args, str): + args = raw_args + else: + args = json.dumps(raw_args) + + return ToolCallChunk( + id=tool_call.get("id"), + index=tool_call.get("index", default_index), + name=name, + args=args, ) +def _run_coro_sync(coro_factory: Callable[[], Awaitable[Any]]) -> Any: + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(coro_factory()) + + raise RuntimeError( + "Synchronous LangChain calls cannot run inside an active event loop for this adapter. " + "Use `ainvoke`/`astream` instead of `invoke`/`stream`." + ) + + +def _validate_model_string(model: Union[TEE_LLM, str]) -> Union[TEE_LLM, str]: + if isinstance(model, Enum): + model_str = str(model.value) + else: + model_str = str(model) + if "/" not in model_str: + raise ValueError( + f"Unsupported model value '{model_str}'. " + "Expected provider/model format (for example: 'openai/gpt-5')." + ) + return model + + class OpenGradientChatModel(BaseChatModel): """OpenGradient adapter class for LangChain chat model""" - model_cid: str + model_cid: Union[TEE_LLM, str] max_tokens: int = 300 - x402_settlement_mode: Optional[str] = x402SettlementMode.BATCH_HASHED + temperature: float = 0.0 + x402_settlement_mode: x402SettlementMode = x402SettlementMode.BATCH_HASHED _llm: LLM = PrivateAttr() + _owns_client: bool = PrivateAttr(default=False) _tools: List[Dict] = PrivateAttr(default_factory=list) + _tool_choice: Optional[str] = PrivateAttr(default=None) def __init__( self, - private_key: str, - model_cid: TEE_LLM, + private_key: Optional[str] = None, + model_cid: Optional[Union[TEE_LLM, str]] = None, + model: Optional[Union[TEE_LLM, str]] = None, max_tokens: int = 300, - x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.BATCH_HASHED, + temperature: float = 0.0, + x402_settlement_mode: x402SettlementMode = x402SettlementMode.BATCH_HASHED, + client: Optional[LLM] = None, + rpc_url: Optional[str] = None, + tee_registry_address: Optional[str] = None, + llm_server_url: Optional[str] = None, **kwargs, ): + resolved_model_cid = model_cid or model + if resolved_model_cid is None: + raise ValueError("model_cid (or model) is required.") + resolved_model_cid = _validate_model_string(resolved_model_cid) super().__init__( - model_cid=model_cid, + model_cid=resolved_model_cid, max_tokens=max_tokens, + temperature=temperature, x402_settlement_mode=x402_settlement_mode, **kwargs, ) - self._llm = LLM(private_key=private_key) + + if client is not None: + self._llm = client + self._owns_client = False + return + + if not private_key: + raise ValueError("private_key is required when client is not provided.") + + llm_kwargs: Dict[str, Any] = {} + if rpc_url is not None: + llm_kwargs["rpc_url"] = rpc_url + if tee_registry_address is not None: + llm_kwargs["tee_registry_address"] = tee_registry_address + if llm_server_url is not None: + llm_kwargs["llm_server_url"] = llm_server_url + + self._llm = LLM(private_key=private_key, **llm_kwargs) + self._owns_client = True @property def _llm_type(self) -> str: return "opengradient" + async def aclose(self) -> None: + if self._owns_client: + await self._llm.close() + + def close(self) -> None: + if self._owns_client: + _run_coro_sync(self._llm.close) + def bind_tools( self, tools: Sequence[ Union[Dict[str, Any], type, Callable, BaseTool] # noqa: UP006 ], + *, + tool_choice: Optional[str] = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: """Bind tools to the model.""" - tool_dicts: List[Dict] = [] + strict = kwargs.get("strict") + self._tools = [convert_to_openai_tool(tool, strict=strict) for tool in tools] + self._tool_choice = tool_choice or kwargs.get("tool_choice") - for tool in tools: - if isinstance(tool, BaseTool): - tool_dicts.append( - { - "type": "function", - "function": { - "name": tool.name, - "description": tool.description, - "parameters": ( - tool.args_schema.model_json_schema() - if hasattr(tool, "args_schema") and tool.args_schema is not None - else {} - ), - }, - } - ) - else: - tool_dicts.append(tool) + return self - self._tools = tool_dicts + @staticmethod + def _stream_chunk_to_generation(chunk: StreamChunk) -> ChatGenerationChunk: + choice = chunk.choices[0] if chunk.choices else None + delta = choice.delta if choice else None - return self + usage = None + if chunk.usage is not None: + usage = { + "input_tokens": chunk.usage.prompt_tokens, + "output_tokens": chunk.usage.completion_tokens, + "total_tokens": chunk.usage.total_tokens, + } - def _generate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - sdk_messages = [] + tool_call_chunks: List[ToolCallChunk] = [] + if delta and delta.tool_calls: + for index, tool_call in enumerate(delta.tool_calls): + tool_call_chunks.append(_parse_tool_call_chunk(tool_call, index)) + + message_chunk = AIMessageChunk( + content=_extract_content(delta.content if delta else ""), + tool_call_chunks=tool_call_chunks, + usage_metadata=usage, + ) + + generation_info: Dict[str, Any] = {} + if choice and choice.finish_reason is not None: + generation_info["finish_reason"] = choice.finish_reason + + for key in ["tee_signature", "tee_timestamp", "tee_id", "tee_endpoint", "tee_payment_address"]: + value = getattr(chunk, key, None) + if value is not None: + generation_info[key] = value + + return ChatGenerationChunk( + message=message_chunk, + generation_info=generation_info or None, + ) + + def _convert_messages_to_sdk(self, messages: List[BaseMessage]) -> List[Dict[str, Any]]: + sdk_messages: List[Dict[str, Any]] = [] for message in messages: if isinstance(message, SystemMessage): sdk_messages.append({"role": "system", "content": _extract_content(message.content)}) @@ -148,9 +280,12 @@ def _generate( if message.tool_calls: msg["tool_calls"] = [ { - "id": call["id"], + "id": call.get("id", ""), "type": "function", - "function": {"name": call["name"], "arguments": json.dumps(call["args"])}, + "function": { + "name": call["name"], + "arguments": _serialize_tool_args(call.get("args")), + }, } for call in message.tool_calls ] @@ -163,33 +298,125 @@ def _generate( "tool_call_id": message.tool_call_id, } ) + elif isinstance(message, ChatMessage): + sdk_messages.append({"role": message.role, "content": _extract_content(message.content)}) else: raise ValueError(f"Unexpected message type: {message}") + return sdk_messages - chat_output = asyncio.run( - self._llm.chat( - model=self.model_cid, - messages=sdk_messages, - stop_sequence=stop, - max_tokens=self.max_tokens, - tools=self._tools, - x402_settlement_mode=self.x402_settlement_mode, - ) - ) + def _build_chat_kwargs(self, sdk_messages: List[Dict[str, Any]], stop: Optional[List[str]], stream: bool, **kwargs: Any) -> Dict[str, Any]: + x402_settlement_mode = kwargs.get("x402_settlement_mode", self.x402_settlement_mode) + if isinstance(x402_settlement_mode, str): + x402_settlement_mode = x402SettlementMode(x402_settlement_mode) + model = kwargs.get("model", self.model_cid) + model = _validate_model_string(model) + return { + "model": model, + "messages": sdk_messages, + "stop_sequence": stop, + "max_tokens": kwargs.get("max_tokens", self.max_tokens), + "temperature": kwargs.get("temperature", self.temperature), + "tools": kwargs.get("tools", self._tools), + "tool_choice": kwargs.get("tool_choice", self._tool_choice), + "x402_settlement_mode": x402_settlement_mode, + "stream": stream, + } + + @staticmethod + def _build_chat_result(chat_output: TextGenerationOutput) -> ChatResult: finish_reason = chat_output.finish_reason or "" chat_response = chat_output.chat_output or {} + response_content = _extract_content(chat_response.get("content", "")) if chat_response.get("tool_calls"): tool_calls = [_parse_tool_call(tc) for tc in chat_response["tool_calls"]] - ai_message = AIMessage(content="", tool_calls=tool_calls) + ai_message = AIMessage(content=response_content, tool_calls=tool_calls) + else: + ai_message = AIMessage(content=response_content) + + generation_info = {"finish_reason": finish_reason} if finish_reason else {} + return ChatResult(generations=[ChatGeneration(message=ai_message, generation_info=generation_info)]) + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + sdk_messages = self._convert_messages_to_sdk(messages) + chat_kwargs = self._build_chat_kwargs(sdk_messages, stop, stream=False, **kwargs) + chat_output = _run_coro_sync(lambda: self._llm.chat(**chat_kwargs)) + if not isinstance(chat_output, TextGenerationOutput): + raise RuntimeError("Expected non-streaming chat output but received streaming generator.") + return self._build_chat_result(chat_output) + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + sdk_messages = self._convert_messages_to_sdk(messages) + chat_kwargs = self._build_chat_kwargs(sdk_messages, stop, stream=False, **kwargs) + chat_output = await self._llm.chat(**chat_kwargs) + if not isinstance(chat_output, TextGenerationOutput): + raise RuntimeError("Expected non-streaming chat output but received streaming generator.") + return self._build_chat_result(chat_output) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + sdk_messages = self._convert_messages_to_sdk(messages) + chat_kwargs = self._build_chat_kwargs(sdk_messages, stop, stream=True, **kwargs) + try: + asyncio.get_running_loop() + except RuntimeError: + pass else: - ai_message = AIMessage(content=_extract_content(chat_response.get("content", ""))) + raise RuntimeError( + "Synchronous stream cannot run inside an active event loop for this adapter. " + "Use `astream` instead." + ) + + loop = asyncio.new_event_loop() + try: + stream = loop.run_until_complete(self._llm.chat(**chat_kwargs)) + stream_iter = cast(AsyncIterator[StreamChunk], stream) + + while True: + try: + chunk = loop.run_until_complete(stream_iter.__anext__()) + except StopAsyncIteration: + break + yield self._stream_chunk_to_generation(chunk) + finally: + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.close() - return ChatResult(generations=[ChatGeneration(message=ai_message, generation_info={"finish_reason": finish_reason})]) + async def _astream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + sdk_messages = self._convert_messages_to_sdk(messages) + chat_kwargs = self._build_chat_kwargs(sdk_messages, stop, stream=True, **kwargs) + stream = await self._llm.chat(**chat_kwargs) + async for chunk in cast(AsyncIterator[StreamChunk], stream): + yield self._stream_chunk_to_generation(chunk) @property def _identifying_params(self) -> Dict[str, Any]: return { "model_name": self.model_cid, + "temperature": self.temperature, + "max_tokens": self.max_tokens, } diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py index a345caa..078f826 100644 --- a/src/opengradient/client/llm.py +++ b/src/opengradient/client/llm.py @@ -3,22 +3,24 @@ import json import logging import ssl +import threading from dataclasses import dataclass -from typing import AsyncGenerator, Dict, List, Optional, Union +from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Optional, TypeVar, Union from eth_account import Account from eth_account.account import LocalAccount -from x402v2 import x402Client as x402Clientv2 -from x402v2.http.clients import x402HttpxClient as x402HttpxClientv2 -from x402v2.mechanisms.evm import EthAccountSigner as EthAccountSignerv2 -from x402v2.mechanisms.evm.exact.register import register_exact_evm_client as register_exact_evm_clientv2 -from x402v2.mechanisms.evm.upto.register import register_upto_evm_client as register_upto_evm_clientv2 +from x402 import x402Client +from x402.http.clients import x402HttpxClient +from x402.mechanisms.evm import EthAccountSigner +from x402.mechanisms.evm.exact.register import register_exact_evm_client +from x402.mechanisms.evm.upto.register import register_upto_evm_client from ..types import TEE_LLM, StreamChoice, StreamChunk, StreamDelta, TextGenerationOutput, x402SettlementMode from .opg_token import Permit2ApprovalResult, ensure_opg_approval from .tee_registry import TEERegistry, build_ssl_context_from_der logger = logging.getLogger(__name__) +T = TypeVar("T") DEFAULT_RPC_URL = "https://ogevmdevnet.opengradient.ai" DEFAULT_TEE_REGISTRY_ADDRESS = "0x4e72238852f3c918f4E4e57AeC9280dDB0c80248" @@ -30,6 +32,9 @@ _CHAT_ENDPOINT = "/v1/chat/completions" _COMPLETION_ENDPOINT = "/v1/completions" _REQUEST_TIMEOUT = 60 +_TEE_SIGNATURE_HEADER = "X-TEE-Signature" +_TEE_TIMESTAMP_HEADER = "X-TEE-Timestamp" +_TEE_ID_HEADER = "X-TEE-ID" @dataclass @@ -91,14 +96,64 @@ def __init__( ssl_ctx = build_ssl_context_from_der(tls_cert_der) if tls_cert_der else None self._tls_verify: Union[ssl.SSLContext, bool] = ssl_ctx if ssl_ctx else True + self._reset_lock = threading.Lock() - # x402 client and signer - signer = EthAccountSignerv2(self._wallet_account) - self._x402_client = x402Clientv2() - register_exact_evm_clientv2(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK]) - register_upto_evm_clientv2(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK]) + # x402 client/signer/http stack + self._init_x402_stack() + + def _init_x402_stack(self) -> None: + """Initialize x402 signer/client/http stack.""" + signer = EthAccountSigner(self._wallet_account) + self._x402_client = x402Client() + register_exact_evm_client(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK]) + register_upto_evm_client(self._x402_client, signer, networks=[BASE_TESTNET_NETWORK]) # httpx.AsyncClient subclass - construction is sync, connections open lazily - self._http_client = x402HttpxClientv2(self._x402_client, verify=self._tls_verify) + self._http_client = x402HttpxClient(self._x402_client, verify=self._tls_verify) + + async def _reset_x402_stack(self) -> None: + """Reset x402 state and underlying HTTP client.""" + with self._reset_lock: + old_http_client = self._http_client + self._init_x402_stack() + + try: + await old_http_client.aclose() + except Exception: + logger.debug("Failed to close previous x402 HTTP client during reset.", exc_info=True) + + @staticmethod + def _is_invalid_payment_required_error(exc: Exception) -> bool: + """Detect the known stale-session x402 failure mode.""" + visited: set[int] = set() + current: Optional[BaseException] = exc + + while current is not None and id(current) not in visited: + visited.add(id(current)) + msg = str(current).lower() + if "invalid payment required response" in msg: + return True + current = current.__cause__ or current.__context__ + return False + + async def _retry_once_on_invalid_payment_required( + self, + operation_name: str, + call: Callable[[], Awaitable[T]], + ) -> T: + """Retry once after resetting x402 state for recoverable payment errors.""" + try: + return await call() + except Exception as first_error: + if not self._is_invalid_payment_required_error(first_error): + raise + + logger.warning( + "Recoverable x402 payment error during %s; resetting x402 client and retrying once: %s", + operation_name, + first_error, + ) + await self._reset_x402_stack() + return await call() # ── TEE resolution ────────────────────────────────────────────────── @@ -169,6 +224,16 @@ def _tee_metadata(self) -> Dict: tee_payment_address=self._tee_payment_address, ) + @staticmethod + def _extract_tee_headers(response: Any) -> Dict[str, Optional[str]]: + """Extract TEE proof metadata from HTTP headers.""" + headers = getattr(response, "headers", {}) or {} + return { + "tee_signature": headers.get(_TEE_SIGNATURE_HEADER), + "tee_timestamp": headers.get(_TEE_TIMESTAMP_HEADER), + "tee_id": headers.get(_TEE_ID_HEADER), + } + # ── Public API ────────────────────────────────────────────────────── def ensure_opg_approval(self, opg_amount: float) -> Permit2ApprovalResult: @@ -239,7 +304,7 @@ async def completion( if stop_sequence: payload["stop"] = stop_sequence - try: + async def _request() -> TextGenerationOutput: response = await self._http_client.post( self._tee_endpoint + _COMPLETION_ENDPOINT, json=payload, @@ -249,13 +314,20 @@ async def completion( response.raise_for_status() raw_body = await response.aread() result = json.loads(raw_body.decode()) + tee_headers = self._extract_tee_headers(response) + metadata = self._tee_metadata() + if tee_headers.get("tee_id"): + metadata["tee_id"] = tee_headers["tee_id"] return TextGenerationOutput( transaction_hash="external", completion_output=result.get("completion"), - tee_signature=result.get("tee_signature"), - tee_timestamp=result.get("tee_timestamp"), - **self._tee_metadata(), + tee_signature=result.get("tee_signature") or tee_headers.get("tee_signature"), + tee_timestamp=result.get("tee_timestamp") or tee_headers.get("tee_timestamp"), + **metadata, ) + + try: + return await self._retry_once_on_invalid_payment_required("completion", _request) except RuntimeError: raise except Exception as e: @@ -326,7 +398,7 @@ async def _chat_request(self, params: _ChatParams, messages: List[Dict]) -> Text headers = self._headers(params.x402_settlement_mode) payload = self._chat_payload(params, messages) - try: + async def _request() -> TextGenerationOutput: response = await self._http_client.post( self._tee_endpoint + _CHAT_ENDPOINT, json=payload, @@ -336,6 +408,10 @@ async def _chat_request(self, params: _ChatParams, messages: List[Dict]) -> Text response.raise_for_status() raw_body = await response.aread() result = json.loads(raw_body.decode()) + tee_headers = self._extract_tee_headers(response) + metadata = self._tee_metadata() + if tee_headers.get("tee_id"): + metadata["tee_id"] = tee_headers["tee_id"] choices = result.get("choices") if not choices: @@ -352,10 +428,13 @@ async def _chat_request(self, params: _ChatParams, messages: List[Dict]) -> Text transaction_hash="external", finish_reason=choices[0].get("finish_reason"), chat_output=message, - tee_signature=result.get("tee_signature"), - tee_timestamp=result.get("tee_timestamp"), - **self._tee_metadata(), + tee_signature=result.get("tee_signature") or tee_headers.get("tee_signature"), + tee_timestamp=result.get("tee_timestamp") or tee_headers.get("tee_timestamp"), + **metadata, ) + + try: + return await self._retry_once_on_invalid_payment_required("chat", _request) except RuntimeError: raise except Exception as e: @@ -391,17 +470,36 @@ async def _chat_stream(self, params: _ChatParams, messages: List[Dict]) -> Async headers = self._headers(params.x402_settlement_mode) payload = self._chat_payload(params, messages, stream=True) - async with self._http_client.stream( - "POST", - self._tee_endpoint + _CHAT_ENDPOINT, - json=payload, - headers=headers, - timeout=_REQUEST_TIMEOUT, - ) as response: - async for chunk in self._parse_sse_response(response): - yield chunk + retried = False + while True: + try: + async with self._http_client.stream( + "POST", + self._tee_endpoint + _CHAT_ENDPOINT, + json=payload, + headers=headers, + timeout=_REQUEST_TIMEOUT, + ) as response: + tee_headers = self._extract_tee_headers(response) + async for chunk in self._parse_sse_response(response, tee_headers=tee_headers): + yield chunk + return + except Exception as e: + if (not retried) and self._is_invalid_payment_required_error(e): + retried = True + logger.warning( + "Recoverable x402 payment error during stream; resetting x402 client and retrying once: %s", + e, + ) + await self._reset_x402_stack() + continue + raise - async def _parse_sse_response(self, response) -> AsyncGenerator[StreamChunk, None]: + async def _parse_sse_response( + self, + response, + tee_headers: Optional[Dict[str, Optional[str]]] = None, + ) -> AsyncGenerator[StreamChunk, None]: """Parse an SSE response stream into StreamChunk objects.""" status_code = getattr(response, "status_code", None) if status_code is not None and status_code >= 400: @@ -442,4 +540,8 @@ async def _parse_sse_response(self, response) -> AsyncGenerator[StreamChunk, Non chunk.tee_id = self._tee_id chunk.tee_endpoint = self._tee_endpoint chunk.tee_payment_address = self._tee_payment_address + if tee_headers: + chunk.tee_signature = chunk.tee_signature or tee_headers.get("tee_signature") + chunk.tee_timestamp = chunk.tee_timestamp or tee_headers.get("tee_timestamp") + chunk.tee_id = tee_headers.get("tee_id") or chunk.tee_id yield chunk diff --git a/src/opengradient/client/opg_token.py b/src/opengradient/client/opg_token.py index b19c1de..87d3867 100644 --- a/src/opengradient/client/opg_token.py +++ b/src/opengradient/client/opg_token.py @@ -5,7 +5,7 @@ from eth_account.account import LocalAccount from web3 import Web3 -from x402v2.mechanisms.evm.constants import PERMIT2_ADDRESS +from x402.mechanisms.evm.constants import PERMIT2_ADDRESS BASE_OPG_ADDRESS = "0x240b09731D96979f50B2C649C9CE10FcF9C7987F" BASE_SEPOLIA_RPC = "https://sepolia.base.org" diff --git a/tests/langchain_adapter_test.py b/tests/langchain_adapter_test.py index e651ab4..1747c1d 100644 --- a/tests/langchain_adapter_test.py +++ b/tests/langchain_adapter_test.py @@ -1,3 +1,4 @@ +import asyncio import json import os import sys @@ -11,7 +12,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), "..")) from src.opengradient.agents.og_langchain import OpenGradientChatModel, _extract_content, _parse_tool_call -from src.opengradient.types import TEE_LLM, TextGenerationOutput, x402SettlementMode +from src.opengradient.types import StreamChoice, StreamChunk, StreamDelta, TEE_LLM, TextGenerationOutput, x402SettlementMode @pytest.fixture @@ -52,9 +53,24 @@ def test_initialization_custom_settlement_mode(self, mock_llm_client): ) assert model.x402_settlement_mode == x402SettlementMode.PRIVATE + def test_initialization_with_existing_client(self): + with patch("src.opengradient.agents.og_langchain.LLM") as MockLLM: + existing_client = MagicMock() + model = OpenGradientChatModel(private_key=None, client=existing_client, model_cid=TEE_LLM.GPT_5) + assert model._llm is existing_client + MockLLM.assert_not_called() + + def test_initialization_without_private_key_or_client_raises(self): + with pytest.raises(ValueError, match="private_key is required"): + OpenGradientChatModel(private_key=None, model_cid=TEE_LLM.GPT_5) + + def test_initialization_with_invalid_model_string_raises(self): + with pytest.raises(ValueError, match="provider/model format"): + OpenGradientChatModel(private_key="0x" + "a" * 64, model_cid="gpt-5") + def test_identifying_params(self, model): """Test _identifying_params returns model name.""" - assert model._identifying_params == {"model_name": TEE_LLM.GPT_5} + assert model._identifying_params == {"model_name": TEE_LLM.GPT_5, "temperature": 0.0, "max_tokens": 300} class TestGenerate: @@ -156,6 +172,24 @@ def test_empty_chat_output(self, model, mock_llm_client): assert result.generations[0].message.content == "" + def test_generate_with_invalid_model_kwarg_raises(self, model): + with pytest.raises(ValueError, match="provider/model format"): + model._generate([HumanMessage(content="Hi")], model="gpt-5") + + def test_sync_generate_inside_running_loop_raises(self, model): + async def run_test(): + with pytest.raises(RuntimeError, match="Use `ainvoke`/`astream`"): + model._generate([HumanMessage(content="Hi")]) + + asyncio.run(run_test()) + + def test_sync_stream_inside_running_loop_raises(self, model): + async def run_test(): + with pytest.raises(RuntimeError, match="Use `astream`"): + next(model._stream([HumanMessage(content="Hi")])) + + asyncio.run(run_test()) + class TestMessageConversion: def test_converts_all_message_types(self, model, mock_llm_client): @@ -215,8 +249,11 @@ def test_passes_correct_params_to_client(self, model, mock_llm_client): messages=[{"role": "user", "content": "Hi"}], stop_sequence=["END"], max_tokens=300, + temperature=0.0, tools=[], + tool_choice=None, x402_settlement_mode=x402SettlementMode.BATCH_HASHED, + stream=False, ) @@ -306,3 +343,77 @@ def test_nested_function_format(self): assert tc["name"] == "bar" assert tc["args"] == {"y": 2} assert tc["id"] == "2" + + +class TestAsyncPaths: + def test_agenerate(self, model, mock_llm_client): + mock_llm_client.chat.return_value = TextGenerationOutput( + transaction_hash="external", + finish_reason="stop", + chat_output={"role": "assistant", "content": "Hello async!"}, + ) + + result = asyncio.run(model._agenerate([HumanMessage(content="Hi")])) + assert result.generations[0].message.content == "Hello async!" + + def test_ainvoke(self, model, mock_llm_client): + mock_llm_client.chat.return_value = TextGenerationOutput( + transaction_hash="external", + finish_reason="stop", + chat_output={"role": "assistant", "content": "pong"}, + ) + + message = asyncio.run(model.ainvoke([HumanMessage(content="ping")])) + assert message.content == "pong" + + def test_astream(self, model, mock_llm_client): + async def stream(): + yield StreamChunk( + choices=[StreamChoice(delta=StreamDelta(role="assistant", content="Hel"), index=0)], + model="gpt-5", + ) + yield StreamChunk( + choices=[StreamChoice(delta=StreamDelta(content="lo"), index=0, finish_reason="stop")], + model="gpt-5", + is_final=True, + ) + + mock_llm_client.chat.return_value = stream() + + async def collect_chunks(): + return [chunk async for chunk in model.astream([HumanMessage(content="Hi")])] + + chunks = asyncio.run(collect_chunks()) + output_text = "".join(chunk.content for chunk in chunks if chunk.content) + assert output_text == "Hello" + + def test_astream_tool_call_chunk(self, model, mock_llm_client): + async def stream(): + yield StreamChunk( + choices=[ + StreamChoice( + delta=StreamDelta( + tool_calls=[ + { + "id": "call_1", + "type": "function", + "function": {"name": "search", "arguments": '{"q":"test"}'}, + } + ] + ), + index=0, + finish_reason="tool_calls", + ) + ], + model="gpt-5", + is_final=True, + ) + + mock_llm_client.chat.return_value = stream() + + async def collect_chunks(): + return [chunk async for chunk in model.astream([HumanMessage(content="Hi")])] + + chunks = asyncio.run(collect_chunks()) + assert chunks[0].tool_call_chunks[0]["id"] == "call_1" + assert chunks[0].tool_call_chunks[0]["name"] == "search" diff --git a/tests/llm_test.py b/tests/llm_test.py index 3f068f3..62c54b8 100644 --- a/tests/llm_test.py +++ b/tests/llm_test.py @@ -1,13 +1,13 @@ """Tests for LLM class. -Construction patches the x402 boundary (x402HttpxClientv2, EthAccountSignerv2, etc.) +Construction patches the x402 boundary (x402HttpxClient, EthAccountSigner, etc.) so LLM builds normally — no test-only constructor params, no mocking of private methods. """ import json from contextlib import asynccontextmanager from typing import List -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import httpx import pytest @@ -19,25 +19,27 @@ class FakeHTTPClient: - """Stands in for x402HttpxClientv2. + """Stands in for x402HttpxClient. Configured per-test with set_response / set_stream_response, then - injected via the x402HttpxClientv2 patch so LLM's normal __init__ + injected via the x402HttpxClient patch so LLM's normal __init__ assigns it to self._http_client. """ def __init__(self, *_args, **_kwargs): self._response_status: int = 200 self._response_body: bytes = b"{}" + self._response_headers: dict = {} self._post_calls: List[dict] = [] self._stream_response = None - def set_response(self, status_code: int, body: dict) -> None: + def set_response(self, status_code: int, body: dict, headers: dict | None = None) -> None: self._response_status = status_code self._response_body = json.dumps(body).encode() + self._response_headers = headers or {} - def set_stream_response(self, status_code: int, chunks: List[bytes]) -> None: - self._stream_response = _FakeStreamResponse(status_code, chunks) + def set_stream_response(self, status_code: int, chunks: List[bytes], headers: dict | None = None) -> None: + self._stream_response = _FakeStreamResponse(status_code, chunks, headers=headers) @property def post_calls(self) -> List[dict]: @@ -45,7 +47,7 @@ def post_calls(self) -> List[dict]: async def post(self, url: str, *, json=None, headers=None, timeout=None) -> "_FakeResponse": self._post_calls.append({"url": url, "json": json, "headers": headers, "timeout": timeout}) - resp = _FakeResponse(self._response_status, self._response_body) + resp = _FakeResponse(self._response_status, self._response_body, headers=self._response_headers) if self._response_status >= 400: resp.raise_for_status = MagicMock(side_effect=httpx.HTTPStatusError("error", request=MagicMock(), response=MagicMock())) return resp @@ -60,9 +62,10 @@ async def aclose(self): class _FakeResponse: - def __init__(self, status_code: int, body: bytes): + def __init__(self, status_code: int, body: bytes, headers: dict | None = None): self.status_code = status_code self._body = body + self.headers = headers or {} def raise_for_status(self): pass @@ -72,9 +75,10 @@ async def aread(self) -> bytes: class _FakeStreamResponse: - def __init__(self, status_code: int, chunks: List[bytes]): + def __init__(self, status_code: int, chunks: List[bytes], headers: dict | None = None): self.status_code = status_code self._chunks = chunks + self.headers = headers or {} async def aiter_raw(self): for chunk in self._chunks: @@ -90,11 +94,11 @@ async def aread(self) -> bytes: # so LLM.__init__ runs its real code but gets our FakeHTTPClient. _PATCHES = { - "x402_httpx": "src.opengradient.client.llm.x402HttpxClientv2", - "x402_client": "src.opengradient.client.llm.x402Clientv2", - "signer": "src.opengradient.client.llm.EthAccountSignerv2", - "register_exact": "src.opengradient.client.llm.register_exact_evm_clientv2", - "register_upto": "src.opengradient.client.llm.register_upto_evm_clientv2", + "x402_httpx": "src.opengradient.client.llm.x402HttpxClient", + "x402_client": "src.opengradient.client.llm.x402Client", + "signer": "src.opengradient.client.llm.EthAccountSigner", + "register_exact": "src.opengradient.client.llm.register_exact_evm_client", + "register_upto": "src.opengradient.client.llm.register_upto_evm_client", } @@ -152,6 +156,24 @@ async def test_returns_completion_output(self, fake_http): assert result.tee_id == "test-tee-id" assert result.tee_payment_address == "0xTestPayment" + async def test_tee_metadata_falls_back_to_headers(self, fake_http): + fake_http.set_response( + 200, + {"completion": "ok"}, + headers={ + "X-TEE-Signature": "sig-from-header", + "X-TEE-Timestamp": "2026-03-13T00:00:00Z", + "X-TEE-ID": "tee-id-from-header", + }, + ) + llm = _make_llm() + + result = await llm.completion(model=TEE_LLM.GPT_5, prompt="Hi") + + assert result.tee_signature == "sig-from-header" + assert result.tee_timestamp == "2026-03-13T00:00:00Z" + assert result.tee_id == "tee-id-from-header" + async def test_sends_correct_payload(self, fake_http): fake_http.set_response(200, {"completion": "ok"}) llm = _make_llm() @@ -236,6 +258,29 @@ async def test_returns_chat_output(self, fake_http): assert result.finish_reason == "stop" assert result.tee_signature == "sig-xyz" + async def test_chat_tee_metadata_falls_back_to_headers(self, fake_http): + fake_http.set_response( + 200, + { + "choices": [{"message": {"role": "assistant", "content": "Hi there!"}, "finish_reason": "stop"}], + }, + headers={ + "X-TEE-Signature": "sig-from-header", + "X-TEE-Timestamp": "2026-03-13T00:00:00Z", + "X-TEE-ID": "tee-id-from-header", + }, + ) + llm = _make_llm() + + result = await llm.chat( + model=TEE_LLM.GPT_5, + messages=[{"role": "user", "content": "Hello"}], + ) + + assert result.tee_signature == "sig-from-header" + assert result.tee_timestamp == "2026-03-13T00:00:00Z" + assert result.tee_id == "tee-id-from-header" + async def test_flattens_content_blocks(self, fake_http): fake_http.set_response( 200, @@ -361,6 +406,32 @@ async def test_http_error_raises_opengradient_error(self, fake_http): with pytest.raises(RuntimeError, match="TEE LLM chat failed"): await llm.chat(model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hi"}]) + async def test_retries_once_on_invalid_payment_required(self, fake_http): + fake_http.set_response( + 200, + { + "choices": [{"message": {"role": "assistant", "content": "recovered"}, "finish_reason": "stop"}], + }, + ) + llm = _make_llm() + llm._reset_x402_stack = AsyncMock(return_value=None) + original_post = llm._http_client.post + attempts = {"count": 0} + + async def flaky_post(*args, **kwargs): + attempts["count"] += 1 + if attempts["count"] == 1: + raise RuntimeError("Failed to handle payment: Invalid payment required response") + return await original_post(*args, **kwargs) + + llm._http_client.post = flaky_post + + result = await llm.chat(model=TEE_LLM.GPT_5, messages=[{"role": "user", "content": "Hi"}]) + + assert result.chat_output["content"] == "recovered" + assert attempts["count"] == 2 + llm._reset_x402_stack.assert_awaited_once() + # ── Streaming tests ────────────────────────────────────────────────── @@ -427,6 +498,33 @@ async def test_stream_sets_tee_metadata_on_final_chunk(self, fake_http): assert final.tee_id == "test-tee-id" assert final.tee_payment_address == "0xTestPayment" + async def test_stream_tee_signature_timestamp_fallback_to_headers(self, fake_http): + fake_http.set_stream_response( + 200, + [ + b'data: {"model":"gpt-5","choices":[{"index":0,"delta":{"content":"done"},"finish_reason":"stop"}]}\n\n', + b"data: [DONE]\n\n", + ], + headers={ + "X-TEE-Signature": "sig-from-header", + "X-TEE-Timestamp": "2026-03-13T00:00:00Z", + "X-TEE-ID": "tee-id-from-header", + }, + ) + llm = _make_llm() + + gen = await llm.chat( + model=TEE_LLM.GPT_5, + messages=[{"role": "user", "content": "Hi"}], + stream=True, + ) + chunks = [chunk async for chunk in gen] + + final = chunks[-1] + assert final.tee_signature == "sig-from-header" + assert final.tee_timestamp == "2026-03-13T00:00:00Z" + assert final.tee_id == "tee-id-from-header" + async def test_stream_error_raises(self, fake_http): fake_http.set_stream_response(500, [b"Internal Server Error"]) llm = _make_llm() @@ -469,6 +567,40 @@ async def test_tools_with_stream_falls_back_to_single_chunk(self, fake_http): assert chunks[0].choices[0].delta.tool_calls == [{"id": "tc1"}] assert chunks[0].choices[0].finish_reason == "tool_calls" + async def test_stream_retries_once_on_invalid_payment_required(self, fake_http): + fake_http.set_stream_response( + 200, + [ + b'data: {"model":"gpt-5","choices":[{"index":0,"delta":{"content":"ok"},"finish_reason":"stop"}]}\n\n', + b"data: [DONE]\n\n", + ], + ) + llm = _make_llm() + llm._reset_x402_stack = AsyncMock(return_value=None) + original_stream = llm._http_client.stream + attempts = {"count": 0} + + @asynccontextmanager + async def flaky_stream(*args, **kwargs): + attempts["count"] += 1 + if attempts["count"] == 1: + raise RuntimeError("Failed to handle payment: Invalid payment required response") + async with original_stream(*args, **kwargs) as response: + yield response + + llm._http_client.stream = flaky_stream + + gen = await llm.chat( + model=TEE_LLM.GPT_5, + messages=[{"role": "user", "content": "Hi"}], + stream=True, + ) + chunks = [chunk async for chunk in gen] + + assert attempts["count"] == 2 + assert chunks[-1].choices[0].delta.content == "ok" + llm._reset_x402_stack.assert_awaited_once() + # ── ensure_opg_approval tests ──────────────────────────────────────── diff --git a/uv.lock b/uv.lock index bb26279..faee483 100644 --- a/uv.lock +++ b/uv.lock @@ -1609,16 +1609,16 @@ wheels = [ ] [[package]] -name = "og-test-v2-x402" -version = "0.0.9" +name = "og-x402" +version = "0.0.1.dev1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pydantic" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a0/de/fd26c297113c483f62f3a5ee5fc535e81f9413edc68d1bf9d2db4ba62dd4/og_test_v2_x402-0.0.9.tar.gz", hash = "sha256:f5353be907c7224371214d40ec8dc125ee0633e3dbd9deadf6e43c904a7a9328", size = 892006, upload-time = "2026-02-17T16:16:34.028Z" } +sdist = { url = "https://files.pythonhosted.org/packages/6c/22/ffce1dac3875bd4d85cc4387eedf3be1d41ced92b29fbff22c5d0d2af257/og_x402-0.0.1.dev1.tar.gz", hash = "sha256:cb40451e988c1e3376e48c37a4cc5fd5058ca7834200da51076bd30e1ad83d66", size = 899655, upload-time = "2026-03-16T17:26:05.564Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/54/da/bd0f6670d9d577b10d13a9a68371b8fd40e1b24fd30dd690b2aa497eea81/og_test_v2_x402-0.0.9-py3-none-any.whl", hash = "sha256:80257701e8a1909ec5fba434482aa2cdcd9de2a7868b99cff70cf1763c8a53b0", size = 945014, upload-time = "2026-02-17T16:16:32.107Z" }, + { url = "https://files.pythonhosted.org/packages/1c/56/0d5eee73c45cfb409b9f87f3d555199c4e4d5a2c07b74ab312363cc6d644/og_x402-0.0.1.dev1-py3-none-any.whl", hash = "sha256:ffa09250097b25fad5c58105dd65e258468890d95be45ed9ff9f1ee6a7195c69", size = 952333, upload-time = "2026-03-16T17:26:03.325Z" }, ] [[package]] @@ -1642,7 +1642,7 @@ wheels = [ [[package]] name = "opengradient" -version = "0.7.3" +version = "0.8.0" source = { editable = "." } dependencies = [ { name = "click" }, @@ -1650,7 +1650,7 @@ dependencies = [ { name = "firebase-rest-api" }, { name = "langchain" }, { name = "numpy" }, - { name = "og-test-v2-x402" }, + { name = "og-x402" }, { name = "openai" }, { name = "pydantic" }, { name = "requests" }, @@ -1664,7 +1664,7 @@ requires-dist = [ { name = "firebase-rest-api", specifier = ">=1.11.0" }, { name = "langchain", specifier = ">=0.3.7" }, { name = "numpy", specifier = ">=1.26.4" }, - { name = "og-test-v2-x402", specifier = "==0.0.9" }, + { name = "og-x402", specifier = "==0.0.1.dev1" }, { name = "openai", specifier = ">=1.58.1" }, { name = "pydantic", specifier = ">=2.9.2" }, { name = "requests", specifier = ">=2.32.3" },