diff --git a/crates/adaptive/src/intercepts.rs b/crates/adaptive/src/intercepts.rs index 6641cd74..9e3a39da 100644 --- a/crates/adaptive/src/intercepts.rs +++ b/crates/adaptive/src/intercepts.rs @@ -128,7 +128,7 @@ pub(crate) fn create_tool_execution_intercept_with_mode( let name = name.to_string(); Box::pin(async move { let Some(cohort_key) = resolve_warm_first_cohort_key(&name, &mode, &cache) else { - return next(args).await; + return next(args).await.map(Into::into); }; match resolve_warm_first_role(®istry, cohort_key.clone()).await { @@ -136,7 +136,7 @@ pub(crate) fn create_tool_execution_intercept_with_mode( let result = next(args).await; gate.release(); cleanup_cohort_gate(®istry, &cohort_key, &gate).await; - result + result.map(Into::into) } WarmFirstRole::Follower(gate) => { let _ = tokio::time::timeout( @@ -144,10 +144,19 @@ pub(crate) fn create_tool_execution_intercept_with_mode( gate.wait_for_release(), ) .await; - next(args).await + next(args).await.map(Into::into) } } - }) as Pin> + Send>> + }) + as Pin< + Box< + dyn Future< + Output = FlowResult< + nemo_relay::api::tool::ToolExecutionInterceptOutcome, + >, + > + Send, + >, + > }) } diff --git a/crates/adaptive/tests/unit/intercepts_tests.rs b/crates/adaptive/tests/unit/intercepts_tests.rs index 98fd1641..971d71c6 100644 --- a/crates/adaptive/tests/unit/intercepts_tests.rs +++ b/crates/adaptive/tests/unit/intercepts_tests.rs @@ -102,7 +102,7 @@ async fn test_tool_intercept_calls_next() { let result = intercept("test", json!({"input": 1}), next).await; assert!(result.is_ok()); - assert_eq!(result.unwrap(), json!({"result": "ok"})); + assert_eq!(result.unwrap(), json!({"result": "ok"}).into()); } #[tokio::test] @@ -125,7 +125,7 @@ async fn test_tool_intercept_with_populated_cache() { // Should not panic and should return next's result let result = intercept("test", json!({"tool_input": "data"}), next).await; assert!(result.is_ok()); - assert_eq!(result.unwrap(), json!({"from_next": true})); + assert_eq!(result.unwrap(), json!({"from_next": true}).into()); } #[tokio::test] @@ -147,7 +147,7 @@ async fn test_tool_intercept_passes_args_to_next() { let input = json!({"tool_arg": "value", "count": 42}); let result = intercept("test", input.clone(), next).await; assert!(result.is_ok()); - assert_eq!(result.unwrap(), input); + assert_eq!(result.unwrap(), input.into()); } #[test] @@ -325,8 +325,8 @@ async fn test_schedule_mode_intercept_waits_for_primer_before_running_follower() tokio::task::yield_now().await; let follower = tokio::spawn(intercept("search", json!({"call": 2}), next.clone())); - assert_eq!(primer.await.unwrap().unwrap(), json!({"call": 1})); - assert_eq!(follower.await.unwrap().unwrap(), json!({"call": 2})); + assert_eq!(primer.await.unwrap().unwrap(), json!({"call": 1}).into()); + assert_eq!(follower.await.unwrap().unwrap(), json!({"call": 2}).into()); assert_eq!(next_order.load(Ordering::SeqCst), 2); reset_scope_stack(); diff --git a/crates/adaptive/tests/unit/runtime_features_tests.rs b/crates/adaptive/tests/unit/runtime_features_tests.rs index cdae1a8d..2a45418e 100644 --- a/crates/adaptive/tests/unit/runtime_features_tests.rs +++ b/crates/adaptive/tests/unit/runtime_features_tests.rs @@ -206,14 +206,22 @@ fn assert_llm_stream_execution_intercept_absent(name: &str) { fn assert_tool_execution_intercept_registered(name: &str) { assert_already_registered( - register_tool_execution_intercept(name, i32::MAX, Arc::new(|_name, args, next| next(args))), + register_tool_execution_intercept( + name, + i32::MAX, + Arc::new(|_name, args, next| Box::pin(async move { next(args).await.map(Into::into) })), + ), name, ); } fn assert_tool_execution_intercept_absent(name: &str) { - register_tool_execution_intercept(name, i32::MAX, Arc::new(|_name, args, next| next(args))) - .unwrap(); + register_tool_execution_intercept( + name, + i32::MAX, + Arc::new(|_name, args, next| Box::pin(async move { next(args).await.map(Into::into) })), + ) + .unwrap(); deregister_tool_execution_intercept(name).unwrap(); } @@ -753,7 +761,7 @@ async fn registration_context_registers_all_supported_callback_types() { ctx.register_tool_execution_intercept( "adaptive_test_tool", 8, - Arc::new(|_name, args, _next| Box::pin(async move { Ok(args) })), + Arc::new(|_name, args, _next| Box::pin(async move { Ok(args.into()) })), ) .unwrap(); diff --git a/crates/core/src/api/registry.rs b/crates/core/src/api/registry.rs index c7989c7a..998096ee 100644 --- a/crates/core/src/api/registry.rs +++ b/crates/core/src/api/registry.rs @@ -508,7 +508,9 @@ global_intercept_registry_api!( ); global_execution_registry_api!( /// Register a global tool execution intercept. - /// Execution intercepts can wrap or replace the tool callback. + /// Execution intercepts can wrap or replace the tool callback. Each + /// callback returns a canonical tool execution outcome, while its + /// continuation resolves to the raw downstream result JSON. register_tool_execution_intercept, /// Deregister a global tool execution intercept. deregister_tool_execution_intercept, @@ -616,7 +618,8 @@ scope_intercept_registry_api!( scope_execution_registry_api!( /// Register a scope-local tool execution intercept. /// Execution intercepts can wrap or replace the tool callback inside the - /// owning scope. + /// owning scope. Each callback returns a canonical tool execution outcome, + /// while its continuation resolves to the raw downstream result JSON. scope_register_tool_execution_intercept, /// Deregister a scope-local tool execution intercept. scope_deregister_tool_execution_intercept, diff --git a/crates/core/src/api/runtime/callbacks.rs b/crates/core/src/api/runtime/callbacks.rs index 58e5f851..d3790fbb 100644 --- a/crates/core/src/api/runtime/callbacks.rs +++ b/crates/core/src/api/runtime/callbacks.rs @@ -16,6 +16,7 @@ use tokio_stream::Stream; use crate::api::event::Event; use crate::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; +use crate::api::tool::ToolExecutionInterceptOutcome; use crate::codec::request::AnnotatedLlmRequest; use crate::error::Result; use crate::json::Json; @@ -81,7 +82,9 @@ pub type ToolInterceptFn = Arc Result + Send + Sync> /// chain. /// /// # Returns -/// A future resolving to the tool result JSON. +/// A future resolving to the downstream tool result JSON. Pending marks from +/// downstream intercepts are retained by the runtime and are not exposed +/// through this continuation. /// /// # Errors /// The future resolves to an error when the remaining execution chain fails. @@ -98,13 +101,25 @@ pub type ToolExecutionNextFn = /// - Third argument: Continuation for the remaining execution chain. /// /// # Returns -/// A future resolving to the tool result JSON. +/// A future resolving to the canonical tool execution outcome, containing the +/// tool result and any pending lifecycle marks produced by this intercept. /// /// # Errors /// The future resolves to an error when the intercept or remaining execution /// chain fails. pub type ToolExecutionFn = Arc< - dyn Fn(&str, Json, ToolExecutionNextFn) -> Pin> + Send>> + dyn Fn( + &str, + Json, + ToolExecutionNextFn, + ) -> Pin> + Send>> + + Send + + Sync, +>; + +/// Internal continuation carrying both a tool result and accumulated marks. +pub(crate) type ToolExecutionOutcomeNextFn = Arc< + dyn Fn(Json) -> Pin> + Send>> + Send + Sync, >; diff --git a/crates/core/src/api/runtime/state.rs b/crates/core/src/api/runtime/state.rs index 5011cc71..e5c13475 100644 --- a/crates/core/src/api/runtime/state.rs +++ b/crates/core/src/api/runtime/state.rs @@ -10,7 +10,8 @@ use std::any::Any; use std::collections::HashMap; -use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, Mutex}; use crate::api::event::{ BaseEvent, CategoryProfile, Event, EventCategory, MarkEvent, ScopeCategory, ScopeEvent, @@ -23,12 +24,14 @@ use crate::api::runtime::callbacks::{ EventSubscriberFn, LlmConditionalFn, LlmExecutionFn, LlmExecutionNextFn, LlmRequestInterceptFn, LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionFn, LlmStreamExecutionNextFn, LlmStreamExecutionRegistryRefs, ToolConditionalFn, ToolExecutionFn, ToolExecutionNextFn, - ToolInterceptFn, ToolSanitizeFn, + ToolExecutionOutcomeNextFn, ToolInterceptFn, ToolSanitizeFn, }; use crate::api::runtime::subscriber_dispatcher; use crate::api::scope::{CreateScopeHandleParams, EndScopeHandleParams, ScopeHandle, ScopeType}; use crate::api::tool::ToolHandle; -use crate::api::tool::{CreateToolHandleParams, EndToolHandleParams}; +use crate::api::tool::{ + CreateToolHandleParams, EndToolHandleParams, ToolExecutionInterceptOutcome, +}; use crate::codec::request::AnnotatedLlmRequest; use crate::codec::response::AnnotatedLlmResponse; use crate::context::registries::{ @@ -821,22 +824,68 @@ impl NemoRelayContextState { /// from the active scope stack. /// /// # Returns - /// A composed [`ToolExecutionNextFn`] that wraps `default_fn` in every - /// matching execution intercept. + /// A composed [`ToolExecutionOutcomeNextFn`] that wraps `default_fn` in + /// every matching execution intercept. pub(crate) fn tool_build_execution_chain( &self, name: &str, default_fn: ToolExecutionNextFn, scope_locals: &[&SortedRegistry>], - ) -> ToolExecutionNextFn { + ) -> ToolExecutionOutcomeNextFn { let matching = merge_execution_intercept_callables(&self.tool_execution_intercepts, scope_locals); - let mut next = default_fn; + let mut next: ToolExecutionOutcomeNextFn = Arc::new(move |args| { + let default_fn = default_fn.clone(); + Box::pin(async move { + default_fn(args) + .await + .map(ToolExecutionInterceptOutcome::new) + }) + }); let name = name.to_string(); for (callable, _) in matching.into_iter().rev() { let current_next = next.clone(); let current_name = name.clone(); - next = Arc::new(move |args| callable(¤t_name, args, current_next.clone())); + next = Arc::new(move |args| { + let callable = callable.clone(); + let current_name = current_name.clone(); + let next_sequence = Arc::new(AtomicUsize::new(0)); + let downstream_marks = Arc::new(Mutex::new(Vec::new())); + let raw_next: ToolExecutionNextFn = { + let current_next = current_next.clone(); + let next_sequence = next_sequence.clone(); + let downstream_marks = downstream_marks.clone(); + Arc::new(move |args| { + let sequence = next_sequence.fetch_add(1, Ordering::Relaxed); + let current_next = current_next.clone(); + let downstream_marks = downstream_marks.clone(); + Box::pin(async move { + let outcome = current_next(args).await?; + downstream_marks + .lock() + .expect("tool pending mark accumulator lock poisoned") + .push((sequence, outcome.pending_marks)); + Ok(outcome.result) + }) + }) + }; + Box::pin(async move { + let mut outcome = callable(¤t_name, args, raw_next).await?; + let mut downstream_batches = std::mem::take( + &mut *downstream_marks + .lock() + .expect("tool pending mark accumulator lock poisoned"), + ); + downstream_batches.sort_by_key(|(sequence, _)| *sequence); + let mut marks = downstream_batches + .into_iter() + .flat_map(|(_, marks)| marks) + .collect::>(); + marks.append(&mut outcome.pending_marks); + outcome.pending_marks = marks; + Ok(outcome) + }) + }); } next } diff --git a/crates/core/src/api/tool.rs b/crates/core/src/api/tool.rs index 8e651289..2755f32d 100644 --- a/crates/core/src/api/tool.rs +++ b/crates/core/src/api/tool.rs @@ -3,10 +3,11 @@ use serde_json::json; +use crate::api::event::{BaseEvent, Event, MarkEvent, PendingMarkSpec}; use crate::api::runtime::NemoRelayContextState; -use crate::api::runtime::ToolExecutionNextFn; use crate::api::runtime::current_scope_stack; use crate::api::runtime::global_context; +use crate::api::runtime::{EventSubscriberFn, ToolExecutionNextFn}; use crate::api::scope::event; use crate::api::scope::{EmitMarkEventParams, ScopeHandle}; use crate::api::shared::{ @@ -15,12 +16,12 @@ use crate::api::shared::{ }; use crate::error::{FlowError, Result}; use crate::json::Json; -use chrono::{DateTime, Utc}; +use chrono::{DateTime, TimeDelta, Utc}; use serde::{Deserialize, Serialize}; use typed_builder::TypedBuilder; use uuid::Uuid; -pub use nemo_relay_types::api::tool::ToolAttributes; +pub use nemo_relay_types::api::tool::{ToolAttributes, ToolExecutionInterceptOutcome}; /// Runtime-owned handle identifying an active or completed tool call. #[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder)] @@ -204,6 +205,13 @@ pub struct ToolCallEndParams<'a> { /// Sanitize-request guardrails affect only the emitted start-event payload, not /// the caller-owned `args` value. pub fn tool_call(params: ToolCallParams<'_>) -> Result { + let (handle, _) = tool_call_with_subscriber_snapshot(params)?; + Ok(handle) +} + +fn tool_call_with_subscriber_snapshot( + params: ToolCallParams<'_>, +) -> Result<(ToolHandle, Vec)> { ensure_runtime_owner()?; let parent_uuid = resolve_parent_uuid(params.parent); let (entries, subscribers) = { @@ -245,7 +253,7 @@ pub fn tool_call(params: ToolCallParams<'_>) -> Result { (handle, event) }; NemoRelayContextState::emit_event(&event, &subscribers); - Ok(handle) + Ok((handle, subscribers)) } /// Finish a manual tool lifecycle span. @@ -275,6 +283,14 @@ pub fn tool_call(params: ToolCallParams<'_>) -> Result { /// Sanitize-response guardrails affect only the emitted end-event payload, not /// the caller-owned `result` value. pub fn tool_call_end(params: ToolCallEndParams<'_>) -> Result<()> { + tool_call_end_with_pending_marks(params, Vec::new(), None) +} + +fn tool_call_end_with_pending_marks( + params: ToolCallEndParams<'_>, + pending_marks: Vec, + lifecycle_subscribers: Option<&[EventSubscriberFn]>, +) -> Result<()> { ensure_runtime_owner()?; let (entries, subscribers) = { let scope_stack = current_scope_stack(); @@ -282,8 +298,11 @@ pub fn tool_call_end(params: ToolCallEndParams<'_>) -> Result<()> { let scope_locals = scope_guard.collect_scope_local_registries(|registries| { ®istries.tool_sanitize_response_guardrails }); - let scope_subscribers = scope_guard.collect_scope_local_subscribers(); - let subscribers = snapshot_event_subscribers(scope_subscribers)?; + let subscribers = if lifecycle_subscribers.is_some() { + Vec::new() + } else { + snapshot_event_subscribers(scope_guard.collect_scope_local_subscribers())? + }; let context = global_context(); let state = context .read() @@ -291,6 +310,7 @@ pub fn tool_call_end(params: ToolCallEndParams<'_>) -> Result<()> { let entries = state.tool_sanitize_response_entries(&scope_locals); (entries, subscribers) }; + let subscribers = lifecycle_subscribers.unwrap_or(&subscribers); let sanitized_result = NemoRelayContextState::tool_sanitize_response_snapshot_chain( ¶ms.handle.name, params.result, @@ -315,25 +335,46 @@ pub fn tool_call_end(params: ToolCallEndParams<'_>) -> Result<()> { .build(), ) }; - NemoRelayContextState::emit_event(&event, &subscribers); + let marks = pending_marks + .into_iter() + .enumerate() + .map(|(index, mark)| { + let timestamp = *event.timestamp() + + TimeDelta::microseconds(i64::try_from(index).unwrap_or_default() + 1); + Event::Mark(MarkEvent::new( + BaseEvent::builder() + .name(mark.name) + .parent_uuid(params.handle.uuid) + .timestamp(timestamp) + .data_opt(mark.data) + .metadata_opt(mark.metadata) + .build(), + mark.category, + mark.category_profile, + )) + }) + .collect::>(); + NemoRelayContextState::emit_event(&event, subscribers); + for mark in marks { + NemoRelayContextState::emit_event(&mark, subscribers); + } Ok(()) } -fn emit_tool_end_without_output(handle: &ToolHandle, metadata: Option) -> Result<()> { +fn emit_tool_end_without_output( + handle: &ToolHandle, + metadata: Option, + lifecycle_subscribers: &[EventSubscriberFn], +) -> Result<()> { ensure_runtime_owner()?; - let (event, subscribers) = { - let scope_stack = current_scope_stack(); - let scope_guard = scope_stack.read().expect("scope stack lock poisoned"); - let scope_subscribers = scope_guard.collect_scope_local_subscribers(); - let subscribers = snapshot_event_subscribers(scope_subscribers)?; + let event = { let context = global_context(); let state = context .read() .map_err(|error| FlowError::Internal(error.to_string()))?; - let event = state.end_tool_handle(handle, handle.data.clone(), metadata); - (event, subscribers) + state.end_tool_handle(handle, handle.data.clone(), metadata) }; - NemoRelayContextState::emit_event(&event, &subscribers); + NemoRelayContextState::emit_event(&event, lifecycle_subscribers); Ok(()) } @@ -439,7 +480,7 @@ pub async fn tool_call_execute(params: ToolCallExecuteParams) -> Result { &intercept_entries, )?; - let handle = tool_call( + let (handle, lifecycle_subscribers) = tool_call_with_subscriber_snapshot( ToolCallParams::builder() .name(name.as_str()) .args(intercepted_args.clone()) @@ -463,22 +504,28 @@ pub async fn tool_call_execute(params: ToolCallExecuteParams) -> Result { }; match execution(intercepted_args).await { - Ok(result) => { + Ok(outcome) => { + let ToolExecutionInterceptOutcome { + result, + pending_marks, + } = outcome; let end_metadata = metadata_with_otel_status(metadata, "OK", None); - tool_call_end( + tool_call_end_with_pending_marks( ToolCallEndParams::builder() .handle(&handle) .result(result.clone()) .data_opt(data) .metadata_opt(end_metadata) .build(), + pending_marks, + Some(&lifecycle_subscribers), )?; Ok(result) } Err(error) => { let end_metadata = metadata_with_otel_status(metadata, "ERROR", Some(error.to_string())); - let _ = emit_tool_end_without_output(&handle, end_metadata); + let _ = emit_tool_end_without_output(&handle, end_metadata, &lifecycle_subscribers); Err(error) } } diff --git a/crates/core/src/plugin/dynamic/native.rs b/crates/core/src/plugin/dynamic/native.rs index 1ba46758..f612e884 100644 --- a/crates/core/src/plugin/dynamic/native.rs +++ b/crates/core/src/plugin/dynamic/native.rs @@ -48,6 +48,7 @@ use crate::api::scope::{ EmitMarkEventParams, PopScopeParams, PushScopeParams, ScopeAttributes, ScopeHandle, ScopeType, }; use crate::api::scope::{event as emit_scope_mark, get_handle, pop_scope, push_scope}; +use crate::api::tool::ToolExecutionInterceptOutcome; use crate::error::{FlowError, Result as FlowResult}; use crate::plugin::{ ConfigDiagnostic, DiagnosticLevel, Plugin, PluginError, PluginRegistrationContext, @@ -1524,7 +1525,7 @@ fn wrap_tool_execution_fn( let args_string = native_string_from_json(&args) .ok_or_else(|| FlowError::Internal("failed to allocate native args".into()))?; let next_ctx = Box::into_raw(Box::new(next)) as *mut c_void; - let mut out = ptr::null_mut(); + let mut out_outcome = ptr::null_mut(); let status = unsafe { cb( user_data.ptr, @@ -1532,7 +1533,7 @@ fn wrap_tool_execution_fn( args_string, native_tool_next, next_ctx, - &mut out, + &mut out_outcome, ) }; unsafe { @@ -1541,15 +1542,21 @@ fn wrap_tool_execution_fn( native_string_free(args_string); } if status != NemoRelayStatus::Ok { - if !out.is_null() { - unsafe { native_string_free(out) }; + if !out_outcome.is_null() { + unsafe { native_string_free(out_outcome) }; } return Err(flow_error_from_status( status, "native tool execution failed", )); } - take_json_from_native_string(out, "native tool execution returned null") + let outcome_json = take_json_from_native_string( + out_outcome, + "native tool execution returned null outcome", + )?; + serde_json::from_value::(outcome_json).map_err(|err| { + FlowError::Internal(format!("invalid native tool execution outcome JSON: {err}")) + }) }) }) } diff --git a/crates/core/src/plugin/dynamic/worker.rs b/crates/core/src/plugin/dynamic/worker.rs index ee43ace8..d8424043 100644 --- a/crates/core/src/plugin/dynamic/worker.rs +++ b/crates/core/src/plugin/dynamic/worker.rs @@ -61,6 +61,7 @@ use crate::api::scope::{ EmitMarkEventParams, PopScopeParams, PushScopeParams, ScopeAttributes, ScopeHandle, ScopeType, event as emit_scope_mark, pop_scope, push_scope, }; +use crate::api::tool::ToolExecutionInterceptOutcome; use crate::codec::request::AnnotatedLlmRequest; use crate::error::{FlowError, Result as FlowResult}; use crate::plugin::{ @@ -1180,7 +1181,7 @@ impl WorkerPluginCallback { tool_name: &str, value: Json, next: ToolExecutionNextFn, - ) -> FlowResult { + ) -> FlowResult { let continuation_id = self .host_state .insert_continuation(Continuation::Tool(next))?; @@ -1190,7 +1191,28 @@ impl WorkerPluginCallback { Some(continuation_id), Some(invoke_request_payload_tool(tool_name, value)), ); - json_from_invoke_response(self.invoke_async(request).await?) + let response = self.invoke_async(request).await?; + match response.result { + Some(invoke_response_result::Result::ToolExecution(result)) => { + let outcome = + required_envelope(result.outcome, "tool execution intercept outcome")?; + if outcome.schema != "nemo.relay.ToolExecutionInterceptOutcome@1" { + return Err(FlowError::Internal(format!( + "worker returned unsupported tool execution intercept outcome schema: {}", + outcome.schema + ))); + } + decode_json_envelope(&outcome).map_err(|err| { + FlowError::Internal(format!( + "worker returned invalid tool execution intercept outcome: {err}" + )) + }) + } + Some(invoke_response_result::Result::Error(error)) => Err(worker_error_to_flow(error)), + _ => Err(FlowError::Internal( + "worker tool execution intercept returned unexpected result".into(), + )), + } } fn invoke_llm_request_json( diff --git a/crates/core/src/plugins/nemo_guardrails/python.rs b/crates/core/src/plugins/nemo_guardrails/python.rs index 20a27fd3..d5220003 100644 --- a/crates/core/src/plugins/nemo_guardrails/python.rs +++ b/crates/core/src/plugins/nemo_guardrails/python.rs @@ -99,13 +99,14 @@ pub(super) fn register_local_backend( }; let tool_result = next(current_args.clone()).await?; - if !enable_tool_output { - return Ok(tool_result); - } - - runtime - .check_tool_output(&tool_name, ¤t_args, &tool_result) - .await + let tool_result = if enable_tool_output { + runtime + .check_tool_output(&tool_name, ¤t_args, &tool_result) + .await? + } else { + tool_result + }; + Ok(tool_result.into()) }) }); ctx.register_tool_execution_intercept( diff --git a/crates/core/src/plugins/nemo_guardrails/remote.rs b/crates/core/src/plugins/nemo_guardrails/remote.rs index de2e800c..e868f819 100644 --- a/crates/core/src/plugins/nemo_guardrails/remote.rs +++ b/crates/core/src/plugins/nemo_guardrails/remote.rs @@ -938,13 +938,14 @@ pub(super) fn register_remote_backend( }; let tool_result = next(current_args.clone()).await?; - if !enable_tool_output { - return Ok(tool_result); - } - - runtime - .check_tool_output(&tool_name, ¤t_args, &tool_result) - .await + let tool_result = if enable_tool_output { + runtime + .check_tool_output(&tool_name, ¤t_args, &tool_result) + .await? + } else { + tool_result + }; + Ok(tool_result.into()) }) }); ctx.register_tool_execution_intercept( diff --git a/crates/core/tests/fixtures/native_plugin/src/lib.rs b/crates/core/tests/fixtures/native_plugin/src/lib.rs index cb99cb66..4886a6c9 100644 --- a/crates/core/tests/fixtures/native_plugin/src/lib.rs +++ b/crates/core/tests/fixtures/native_plugin/src/lib.rs @@ -8,7 +8,8 @@ use nemo_relay_plugin::{ CategoryProfile, ConfigDiagnostic, DiagnosticLevel, Event, EventCategory, Json, LlmJsonStream, LlmRequest, LlmRequestInterceptOutcome, NemoRelayNativeHostApiV1, NemoRelayNativePluginContext, NemoRelayNativePluginV1, NemoRelayNativeString, NemoRelayStatus, - NativePlugin, PendingMarkSpec, PluginContext, PluginRuntime, ScopeCategory, ScopeType, + NemoRelayNativeToolNextFn, NativePlugin, PendingMarkSpec, PluginContext, PluginRuntime, + ScopeCategory, ScopeType, ToolExecutionInterceptOutcome, }; use serde_json::{Map, json}; @@ -95,7 +96,21 @@ impl NativePlugin for FixtureNativePlugin { } else { next.call(args)? }; - Ok(mark_json(result, "native_plugin_tool_execution")) + let result = mark_json(result, "native_plugin_tool_execution"); + Ok( + ToolExecutionInterceptOutcome::new(result).with_pending_mark( + PendingMarkSpec::builder() + .name("fixture.native.tool_execution.mark") + .category(EventCategory::custom()) + .category_profile(CategoryProfile { + subtype: Some("fixture.native.tool_execution".into()), + ..CategoryProfile::default() + }) + .data(json!({ "source": "native_tool_execution" })) + .metadata(json!({ "fixture": true })) + .build(), + ), + ) } })?; @@ -350,6 +365,23 @@ pub unsafe extern "C" fn nemo_relay_fixture_register_error( } } +#[unsafe(no_mangle)] +pub unsafe extern "C" fn nemo_relay_fixture_tool_outcome_errors( + host: *const NemoRelayNativeHostApiV1, + out: *mut NemoRelayNativePluginV1, +) -> NemoRelayStatus { + unsafe { + write_raw_descriptor( + host, + out, + "fixture_native", + None, + None, + Some(raw_register_tool_outcome_errors), + ) + } +} + type RawValidate = unsafe extern "C" fn( *mut c_void, *const NemoRelayNativeString, @@ -432,6 +464,80 @@ unsafe extern "C" fn raw_register_error( NemoRelayStatus::Internal } +unsafe extern "C" fn raw_register_tool_outcome_errors( + user_data: *mut c_void, + _plugin_config_json: *const NemoRelayNativeString, + ctx: *mut NemoRelayNativePluginContext, +) -> NemoRelayStatus { + let Some(host) = (unsafe { raw_host_from_user_data(user_data) }) else { + return NemoRelayStatus::NullPointer; + }; + let name = unsafe { raw_host_string(host, "fixture_raw_tool_outcome") }; + if name.is_null() { + return NemoRelayStatus::Internal; + } + let status = unsafe { + (host.plugin_context_register_tool_execution_intercept)( + ctx, + name, + 0, + raw_tool_outcome_callback, + user_data, + None, + ) + }; + unsafe { (host.string_free)(name) }; + status +} + +unsafe extern "C" fn raw_tool_outcome_callback( + user_data: *mut c_void, + name: *const NemoRelayNativeString, + _args_json: *const NemoRelayNativeString, + _next_fn: NemoRelayNativeToolNextFn, + _next_ctx: *mut c_void, + out_outcome_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus { + if out_outcome_json.is_null() { + return NemoRelayStatus::NullPointer; + } + unsafe { *out_outcome_json = ptr::null_mut() }; + let Some(host) = (unsafe { raw_host_from_user_data(user_data) }) else { + return NemoRelayStatus::NullPointer; + }; + let Some(name) = (unsafe { raw_host_string_value(host, name) }) else { + return NemoRelayStatus::InvalidUtf8; + }; + match name.as_str() { + "fixture-null-outcome" => NemoRelayStatus::Ok, + "fixture-malformed-outcome" => { + unsafe { + *out_outcome_json = raw_host_string(host, r#"{"pending_marks":[]}"#); + } + NemoRelayStatus::Ok + } + "fixture-status-error-outcome" => { + unsafe { + *out_outcome_json = raw_host_string( + host, + r#"{"result":{"stale":true},"pending_marks":[]}"#, + ); + set_raw_last_error_from_user_data(user_data, "fixture tool execution failed"); + } + NemoRelayStatus::Internal + } + _ => { + unsafe { + *out_outcome_json = raw_host_string( + host, + r#"{"result":{"raw_tool_outcome":true},"pending_marks":[]}"#, + ); + } + NemoRelayStatus::Ok + } + } +} + unsafe extern "C" fn raw_drop_host(user_data: *mut c_void) { if !user_data.is_null() { drop(unsafe { Box::from_raw(user_data as *mut NemoRelayNativeHostApiV1) }); @@ -480,3 +586,23 @@ unsafe fn raw_host_string( ptr::null_mut() } } + +unsafe fn raw_host_string_value( + host: &NemoRelayNativeHostApiV1, + value: *const NemoRelayNativeString, +) -> Option { + if value.is_null() { + return None; + } + let len = unsafe { (host.string_len)(value) }; + let data = unsafe { (host.string_data)(value) }; + if data.is_null() && len > 0 { + return None; + } + let bytes = if len == 0 { + &[][..] + } else { + unsafe { std::slice::from_raw_parts(data, len) } + }; + std::str::from_utf8(bytes).ok().map(str::to_owned) +} diff --git a/crates/core/tests/fixtures/worker_plugin/src/main.rs b/crates/core/tests/fixtures/worker_plugin/src/main.rs index 6291e1dc..8fe888cf 100644 --- a/crates/core/tests/fixtures/worker_plugin/src/main.rs +++ b/crates/core/tests/fixtures/worker_plugin/src/main.rs @@ -3,7 +3,7 @@ use nemo_relay_worker::{ JsonStream, LlmNext, LlmStreamNext, PluginContext, ScopeType, ToolNext, WorkerPlugin, - WorkerSdkError, serve_plugin, + ToolExecutionInterceptOutcome, WorkerSdkError, serve_plugin, }; use nemo_relay_worker::{ ConfigDiagnostic, DiagnosticLevel, Json, LlmRequest, PendingMarkSpec, @@ -155,7 +155,17 @@ impl WorkerPlugin for FixtureWorkerPlugin { let result = next .call(mark_json(args, "worker_plugin_tool_execution_request")) .await?; - Ok(mark_json(result, "worker_plugin_tool_execution")) + Ok( + ToolExecutionInterceptOutcome::new(mark_json( + result, + "worker_plugin_tool_execution", + )) + .with_pending_mark( + PendingMarkSpec::builder() + .name("fixture.worker.tool_execution.mark") + .build(), + ), + ) }, ); diff --git a/crates/core/tests/integration/api_surface_tests.rs b/crates/core/tests/integration/api_surface_tests.rs index 3bc8538a..f35dc213 100644 --- a/crates/core/tests/integration/api_surface_tests.rs +++ b/crates/core/tests/integration/api_surface_tests.rs @@ -394,7 +394,7 @@ fn test_global_registry_and_subscriber_wrappers_cover_success_and_duplicates() { register_tool_execution_intercept( "tool-execution", 1, - Arc::new(|_name, args, _next| Box::pin(async move { Ok(args) })), + Arc::new(|_name, args, _next| Box::pin(async move { Ok(args.into()) })), ) .unwrap(); assert!(deregister_tool_execution_intercept("tool-execution").unwrap()); @@ -570,7 +570,7 @@ fn test_scope_registry_and_subscriber_wrappers_cover_success_duplicates_and_miss &scope.uuid, "tool-execution", 1, - Arc::new(|_name, args, _next| Box::pin(async move { Ok(args) })), + Arc::new(|_name, args, _next| Box::pin(async move { Ok(args.into()) })), ) .unwrap(); assert!(scope_deregister_tool_execution_intercept(&scope.uuid, "tool-execution").unwrap()); @@ -690,7 +690,7 @@ fn test_scope_registry_and_subscriber_wrappers_cover_success_duplicates_and_miss &scope.uuid, "missing-tool-exec", 1, - Arc::new(|_name, args, _next| Box::pin(async move { Ok(args) })), + Arc::new(|_name, args, _next| Box::pin(async move { Ok(args.into()) })), ) .unwrap_err(), "scope", diff --git a/crates/core/tests/integration/middleware_tests.rs b/crates/core/tests/integration/middleware_tests.rs index cf9d29e2..b978195b 100644 --- a/crates/core/tests/integration/middleware_tests.rs +++ b/crates/core/tests/integration/middleware_tests.rs @@ -51,10 +51,11 @@ use nemo_relay::api::scope::{ScopeHandle, ScopeType}; use nemo_relay::api::scope::{pop_scope, push_scope}; use nemo_relay::api::subscriber::{deregister_subscriber, flush_subscribers, register_subscriber}; use nemo_relay::api::tool::{ - tool_call, tool_call_end, tool_call_execute, tool_conditional_execution, - tool_request_intercepts, + ToolExecutionInterceptOutcome, tool_call, tool_call_end, tool_call_execute, + tool_conditional_execution, tool_request_intercepts, }; use nemo_relay::error::FlowError; +use nemo_relay::plugin::{PluginRegistrationContext, rollback_registrations}; use serde_json::json; // All tests share the global context, so we serialize them. @@ -456,7 +457,7 @@ async fn test_execution_intercept_calls_next() { Arc::new(|_name, args, next| { Box::pin(async move { // Call next — this should reach the original callable - next(args).await + next(args).await.map(Into::into) }) }), ) @@ -505,7 +506,7 @@ async fn test_execution_intercept_skips_next() { Arc::new(|_name, _args, _next| { Box::pin(async move { // Return a custom result without calling next - Ok(json!({"intercepted": true})) + Ok(json!({"intercepted": true}).into()) }) }), ) @@ -558,7 +559,7 @@ async fn test_execution_intercept_chain_ordering() { o.lock().unwrap().push("intercept_1_before".into()); let result = next(args).await; o.lock().unwrap().push("intercept_1_after".into()); - result + result.map(Into::into) }) }), ) @@ -575,7 +576,7 @@ async fn test_execution_intercept_chain_ordering() { o.lock().unwrap().push("intercept_2_before".into()); let result = next(args).await; o.lock().unwrap().push("intercept_2_after".into()); - result + result.map(Into::into) }) }), ) @@ -631,7 +632,7 @@ async fn test_execution_intercept_modifies_args() { args.as_object_mut() .unwrap() .insert("injected".into(), json!(true)); - next(args).await + next(args).await.map(Into::into) }) }), ) @@ -656,6 +657,449 @@ async fn test_execution_intercept_modifies_args() { deregister_tool_execution_intercept("arg_modifier").unwrap(); } +#[tokio::test] +async fn test_tool_execution_outcome_marks_follow_end_with_tool_parentage() { + let _lock = TEST_MUTEX.lock().unwrap(); + reset_global(); + setup_isolated_thread(); + + let events = Arc::new(Mutex::new(Vec::::new())); + let captured = events.clone(); + register_subscriber( + "tool_outcome_mark_observer", + Arc::new(move |event| captured.lock().unwrap().push(event.clone())), + ) + .unwrap(); + + let mut plugin_ctx = PluginRegistrationContext::new(); + plugin_ctx + .register_tool_execution_intercept( + "outcome_outer", + 1, + Arc::new(|_name, args, next| { + Box::pin(async move { + let result = next(args).await?; + Ok( + ToolExecutionInterceptOutcome::new(result).with_pending_mark( + PendingMarkSpec::builder() + .name("tool.mark.outer") + .data(json!({"layer": "outer"})) + .build(), + ), + ) + }) + }), + ) + .unwrap(); + register_tool_execution_intercept( + "passthrough_between_outcomes", + 2, + Arc::new(|_name, args, next| Box::pin(async move { next(args).await.map(Into::into) })), + ) + .unwrap(); + plugin_ctx + .register_tool_execution_intercept( + "outcome_inner", + 3, + Arc::new(|_name, args, next| { + Box::pin(async move { + let mut result = next(args).await?; + result["compressed"] = json!(true); + Ok( + ToolExecutionInterceptOutcome::new(result).with_pending_mark( + PendingMarkSpec::builder() + .name("tool.mark.inner") + .category(EventCategory::custom()) + .category_profile( + CategoryProfile::builder() + .subtype("example.tool.compression") + .build(), + ) + .data(json!({"saved_tokens": 12})) + .metadata(json!({"source": "test"})) + .build(), + ), + ) + }) + }), + ) + .unwrap(); + + let result = tool_call_execute( + nemo_relay::api::tool::ToolCallExecuteParams::builder() + .name("tool-outcome") + .args(json!({"value": 42})) + .func(Arc::new(|args| Box::pin(async move { Ok(args) }))) + .build(), + ) + .await + .unwrap(); + assert_eq!(result, json!({"value": 42, "compressed": true})); + assert!(result.get("pending_marks").is_none()); + + flush_subscribers().unwrap(); + let captured = events.lock().unwrap(); + let start_index = captured + .iter() + .position(|event| { + event.name() == "tool-outcome" && event.scope_category() == Some(ScopeCategory::Start) + }) + .unwrap(); + let end_index = captured + .iter() + .position(|event| { + event.name() == "tool-outcome" && event.scope_category() == Some(ScopeCategory::End) + }) + .unwrap(); + let inner_index = captured + .iter() + .position(|event| event.name() == "tool.mark.inner") + .unwrap(); + let outer_index = captured + .iter() + .position(|event| event.name() == "tool.mark.outer") + .unwrap(); + assert!(start_index < end_index); + assert!(end_index < inner_index); + assert!(inner_index < outer_index); + + let start = &captured[start_index]; + let end = &captured[end_index]; + let inner = &captured[inner_index]; + let outer = &captured[outer_index]; + assert_eq!(inner.parent_uuid(), Some(start.uuid())); + assert_eq!(outer.parent_uuid(), Some(start.uuid())); + assert!(inner.timestamp() > end.timestamp()); + assert!(outer.timestamp() > inner.timestamp()); + assert_eq!(end.data().unwrap(), &result); + assert_eq!(inner.category().map(EventCategory::as_str), Some("custom")); + assert_eq!( + inner + .category_profile() + .and_then(|profile| profile.subtype.as_deref()), + Some("example.tool.compression") + ); + assert_eq!(inner.data().unwrap()["saved_tokens"], 12); + assert_eq!(inner.metadata().unwrap()["source"], "test"); + drop(captured); + + deregister_tool_execution_intercept("passthrough_between_outcomes").unwrap(); + let mut registrations = plugin_ctx.into_registrations(); + rollback_registrations(&mut registrations); + deregister_subscriber("tool_outcome_mark_observer").unwrap(); +} + +#[tokio::test] +async fn test_tool_execution_error_discards_downstream_pending_marks() { + let _lock = TEST_MUTEX.lock().unwrap(); + reset_global(); + setup_isolated_thread(); + + let events = Arc::new(Mutex::new(Vec::::new())); + let captured = events.clone(); + register_subscriber( + "tool_outcome_error_observer", + Arc::new(move |event| captured.lock().unwrap().push(event.clone())), + ) + .unwrap(); + + register_tool_execution_intercept( + "error_after_outcome", + 1, + Arc::new(|_name, args, next| { + Box::pin(async move { + let _ = next(args).await?; + Err(FlowError::Internal("outer failure".into())) + }) + }), + ) + .unwrap(); + let mut plugin_ctx = PluginRegistrationContext::new(); + plugin_ctx + .register_tool_execution_intercept( + "outcome_before_error", + 2, + Arc::new(|_name, args, next| { + Box::pin(async move { + let result = next(args).await?; + Ok( + ToolExecutionInterceptOutcome::new(result).with_pending_mark( + PendingMarkSpec::builder() + .name("tool.mark.must_not_emit") + .build(), + ), + ) + }) + }), + ) + .unwrap(); + + let error = tool_call_execute( + nemo_relay::api::tool::ToolCallExecuteParams::builder() + .name("tool-outcome-error") + .args(json!({})) + .func(Arc::new(|args| Box::pin(async move { Ok(args) }))) + .build(), + ) + .await + .unwrap_err(); + assert!(error.to_string().contains("outer failure")); + + flush_subscribers().unwrap(); + let captured = events.lock().unwrap(); + assert!( + captured + .iter() + .all(|event| event.name() != "tool.mark.must_not_emit") + ); + assert!(captured.iter().any(|event| { + event.name() == "tool-outcome-error" && event.scope_category() == Some(ScopeCategory::End) + })); + drop(captured); + + deregister_tool_execution_intercept("error_after_outcome").unwrap(); + let mut registrations = plugin_ctx.into_registrations(); + rollback_registrations(&mut registrations); + deregister_subscriber("tool_outcome_error_observer").unwrap(); +} + +#[tokio::test] +async fn test_managed_tool_reuses_start_subscriber_snapshot_for_end_and_marks() { + let _lock = TEST_MUTEX.lock().unwrap(); + reset_global(); + setup_isolated_thread(); + + let original_events = Arc::new(Mutex::new(Vec::::new())); + let captured_original = original_events.clone(); + register_subscriber( + "tool_lifecycle_original", + Arc::new(move |event| captured_original.lock().unwrap().push(event.clone())), + ) + .unwrap(); + + let replacement_events = Arc::new(Mutex::new(Vec::::new())); + let captured_replacement = replacement_events.clone(); + let mut plugin_ctx = PluginRegistrationContext::new(); + plugin_ctx + .register_tool_execution_intercept( + "mutate_tool_subscribers", + 1, + Arc::new(move |_name, args, next| { + let captured_replacement = captured_replacement.clone(); + Box::pin(async move { + assert!(deregister_subscriber("tool_lifecycle_original").unwrap()); + register_subscriber( + "tool_lifecycle_replacement", + Arc::new(move |event| { + captured_replacement.lock().unwrap().push(event.clone()); + }), + ) + .unwrap(); + let result = next(args).await?; + Ok( + ToolExecutionInterceptOutcome::new(result).with_pending_mark( + PendingMarkSpec::builder() + .name("tool.snapshot.mark") + .build(), + ), + ) + }) + }), + ) + .unwrap(); + + tool_call_execute( + nemo_relay::api::tool::ToolCallExecuteParams::builder() + .name("tool-subscriber-snapshot") + .args(json!({"value": 1})) + .func(Arc::new(|args| Box::pin(async move { Ok(args) }))) + .build(), + ) + .await + .unwrap(); + flush_subscribers().unwrap(); + + let original_events = original_events.lock().unwrap(); + assert!(original_events.iter().any(|event| { + event.name() == "tool-subscriber-snapshot" + && event.scope_category() == Some(ScopeCategory::Start) + })); + assert!(original_events.iter().any(|event| { + event.name() == "tool-subscriber-snapshot" + && event.scope_category() == Some(ScopeCategory::End) + })); + assert!( + original_events + .iter() + .any(|event| event.name() == "tool.snapshot.mark") + ); + drop(original_events); + assert!(replacement_events.lock().unwrap().is_empty()); + + assert!(deregister_subscriber("tool_lifecycle_replacement").unwrap()); + let mut registrations = plugin_ctx.into_registrations(); + rollback_registrations(&mut registrations); +} + +#[tokio::test] +async fn test_managed_tool_reuses_start_subscriber_snapshot_for_error_end() { + let _lock = TEST_MUTEX.lock().unwrap(); + reset_global(); + setup_isolated_thread(); + + let original_events = Arc::new(Mutex::new(Vec::::new())); + let captured_original = original_events.clone(); + register_subscriber( + "tool_error_original", + Arc::new(move |event| captured_original.lock().unwrap().push(event.clone())), + ) + .unwrap(); + + let replacement_events = Arc::new(Mutex::new(Vec::::new())); + let captured_replacement = replacement_events.clone(); + register_tool_execution_intercept( + "mutate_tool_error_subscribers", + 1, + Arc::new(move |_name, _args, _next| { + let captured_replacement = captured_replacement.clone(); + Box::pin(async move { + assert!(deregister_subscriber("tool_error_original").unwrap()); + register_subscriber( + "tool_error_replacement", + Arc::new(move |event| { + captured_replacement.lock().unwrap().push(event.clone()); + }), + ) + .unwrap(); + Err(FlowError::Internal("managed tool failure".into())) + }) + }), + ) + .unwrap(); + + let error = tool_call_execute( + nemo_relay::api::tool::ToolCallExecuteParams::builder() + .name("tool-error-subscriber-snapshot") + .args(json!({})) + .func(Arc::new(|args| Box::pin(async move { Ok(args) }))) + .build(), + ) + .await + .unwrap_err(); + assert!(error.to_string().contains("managed tool failure")); + flush_subscribers().unwrap(); + + let original_events = original_events.lock().unwrap(); + assert!(original_events.iter().any(|event| { + event.name() == "tool-error-subscriber-snapshot" + && event.scope_category() == Some(ScopeCategory::Start) + })); + assert!(original_events.iter().any(|event| { + event.name() == "tool-error-subscriber-snapshot" + && event.scope_category() == Some(ScopeCategory::End) + })); + drop(original_events); + assert!(replacement_events.lock().unwrap().is_empty()); + + deregister_tool_execution_intercept("mutate_tool_error_subscribers").unwrap(); + assert!(deregister_subscriber("tool_error_replacement").unwrap()); +} + +#[tokio::test] +async fn test_repeated_next_marks_follow_invocation_order_not_completion_order() { + let _lock = TEST_MUTEX.lock().unwrap(); + reset_global(); + setup_isolated_thread(); + + let events = Arc::new(Mutex::new(Vec::::new())); + let captured_events = events.clone(); + register_subscriber( + "tool_concurrent_next_observer", + Arc::new(move |event| captured_events.lock().unwrap().push(event.clone())), + ) + .unwrap(); + + register_tool_execution_intercept( + "concurrent_next", + 1, + Arc::new(|_name, _args, next| { + Box::pin(async move { + let first = next(json!({"branch": "first", "delay_ms": 40})); + let second = next(json!({"branch": "second", "delay_ms": 1})); + let (first, second) = tokio::join!(first, second); + Ok(json!({ + "first": first?, + "second": second?, + }) + .into()) + }) + }), + ) + .unwrap(); + + let completion_order = Arc::new(Mutex::new(Vec::::new())); + let captured_completion_order = completion_order.clone(); + let mut plugin_ctx = PluginRegistrationContext::new(); + plugin_ctx + .register_tool_execution_intercept( + "delayed_outcomes", + 2, + Arc::new(move |_name, args, next| { + let captured_completion_order = captured_completion_order.clone(); + Box::pin(async move { + let branch = args["branch"].as_str().unwrap().to_string(); + let delay_ms = args["delay_ms"].as_u64().unwrap(); + tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; + let result = next(args).await?; + captured_completion_order + .lock() + .unwrap() + .push(branch.clone()); + Ok( + ToolExecutionInterceptOutcome::new(result).with_pending_mark( + PendingMarkSpec::builder() + .name(format!("tool.concurrent.{branch}")) + .build(), + ), + ) + }) + }), + ) + .unwrap(); + + tool_call_execute( + nemo_relay::api::tool::ToolCallExecuteParams::builder() + .name("tool-concurrent-next") + .args(json!({})) + .func(Arc::new(|args| Box::pin(async move { Ok(args) }))) + .build(), + ) + .await + .unwrap(); + flush_subscribers().unwrap(); + + assert_eq!( + *completion_order.lock().unwrap(), + vec!["second".to_string(), "first".to_string()] + ); + let events = events.lock().unwrap(); + let marks = events + .iter() + .filter(|event| event.name().starts_with("tool.concurrent.")) + .collect::>(); + assert_eq!( + marks.iter().map(|event| event.name()).collect::>(), + ["tool.concurrent.first", "tool.concurrent.second"] + ); + assert!(marks[0].timestamp() < marks[1].timestamp()); + drop(events); + + deregister_tool_execution_intercept("concurrent_next").unwrap(); + let mut registrations = plugin_ctx.into_registrations(); + rollback_registrations(&mut registrations); + deregister_subscriber("tool_concurrent_next_observer").unwrap(); +} + // ========================================================================= // Guardrail Conditional Execution Tests // ========================================================================= @@ -991,7 +1435,7 @@ async fn test_scope_local_execution_intercept_cleanup() { 1, Arc::new(move |_name, args, next| { ic.fetch_add(1, Ordering::SeqCst); - Box::pin(async move { next(args).await }) + Box::pin(async move { next(args).await.map(Into::into) }) }), ) .unwrap(); @@ -1149,7 +1593,7 @@ async fn test_scope_local_and_global_execution_intercept_merge() { o.lock().unwrap().push("global_before".into()); let r = next(args).await; o.lock().unwrap().push("global_after".into()); - r + r.map(Into::into) }) }), ) @@ -1167,7 +1611,7 @@ async fn test_scope_local_and_global_execution_intercept_merge() { o.lock().unwrap().push("local_before".into()); let r = next(args).await; o.lock().unwrap().push("local_after".into()); - r + r.map(Into::into) }) }), ) @@ -1290,7 +1734,7 @@ async fn test_conditional_rejection_prevents_execution() { 1, Arc::new(move |_name, args, next| { ec.store(true, Ordering::SeqCst); - Box::pin(async move { next(args).await }) + Box::pin(async move { next(args).await.map(Into::into) }) }), ) .unwrap(); @@ -1829,7 +2273,7 @@ async fn test_tool_middleware_callbacks_run_without_registry_or_scope_locks() { Arc::new(move |_, args, next| { record_middleware_callback(&tracked, "tool_execution_global"); assert_middleware_callback_locks_are_free(); - Box::pin(async move { next(args).await }) + Box::pin(async move { next(args).await.map(Into::into) }) }), ) .unwrap(); @@ -1841,7 +2285,7 @@ async fn test_tool_middleware_callbacks_run_without_registry_or_scope_locks() { Arc::new(move |_, args, next| { record_middleware_callback(&tracked, "tool_execution_scope"); assert_middleware_callback_locks_are_free(); - Box::pin(async move { next(args).await }) + Box::pin(async move { next(args).await.map(Into::into) }) }), ) .unwrap(); @@ -2233,7 +2677,7 @@ async fn test_full_pipeline_integration() { let o = o4.clone(); Box::pin(async move { o.lock().unwrap().push("execution_intercept".into()); - next(args).await + next(args).await.map(Into::into) }) }), ) diff --git a/crates/core/tests/integration/native_plugin_tests.rs b/crates/core/tests/integration/native_plugin_tests.rs index 0d2b64f0..85f6a4b6 100644 --- a/crates/core/tests/integration/native_plugin_tests.rs +++ b/crates/core/tests/integration/native_plugin_tests.rs @@ -156,6 +156,7 @@ async fn sdk_cdylib_registers_tool_request_intercept() { tool_result["args"]["native_plugin_tool_execution_request"], true ); + assert!(tool_result.get("pending_marks").is_none()); flush_subscribers().expect("native fixture events should flush"); let first_events = events.lock().unwrap().clone(); @@ -197,6 +198,34 @@ async fn sdk_cdylib_registers_tool_request_intercept() { tool_end.output().unwrap()["native_plugin_tool_sanitize_response"], true ); + assert!(tool_end.output().unwrap().get("pending_marks").is_none()); + let tool_mark = find_event(&first_events, "fixture.native.tool_execution.mark", None); + assert_eq!(tool_mark.parent_uuid(), Some(tool_start.uuid())); + assert_eq!( + tool_mark.category().map(|category| category.as_str()), + Some("custom") + ); + assert_eq!( + tool_mark + .category_profile() + .and_then(|profile| profile.subtype.as_deref()), + Some("fixture.native.tool_execution") + ); + assert_eq!(tool_mark.data().unwrap()["source"], "native_tool_execution"); + assert_eq!(tool_mark.metadata().unwrap()["fixture"], true); + assert!(tool_mark.timestamp() > tool_end.timestamp()); + let tool_end_index = first_events + .iter() + .position(|event| { + event.name() == "native-fixture-tool" + && event.scope_category() == Some(ScopeCategory::End) + }) + .unwrap(); + let tool_mark_index = first_events + .iter() + .position(|event| event.name() == "fixture.native.tool_execution.mark") + .unwrap(); + assert!(tool_end_index < tool_mark_index); events.lock().unwrap().clear(); let isolated_next_stack = create_scope_stack(); @@ -506,6 +535,73 @@ async fn native_validation_diagnostics_prevent_initialization() { activation.clear(); } +#[tokio::test] +async fn native_tool_execution_rejects_null_malformed_and_error_outcomes() { + let _guard = NATIVE_PLUGIN_TEST_LOCK.lock().await; + let fixture = build_fixture_plugin(); + let manifest_ref = + write_manifest_with_symbol(&fixture, "nemo_relay_fixture_tool_outcome_errors"); + let activation = load_native_plugins([load_spec("fixture_native", &manifest_ref)]) + .expect("raw native outcome fixture should load"); + let mut cleanup = NativePluginTestCleanup::new(); + + let mut plugin_config = PluginConfig::default(); + plugin_config.components.push(PluginComponentSpec { + kind: "fixture_native".into(), + enabled: true, + config: Map::new(), + }); + initialize_plugins_exact(plugin_config) + .await + .expect("raw native outcome fixture should initialize"); + cleanup.mark_plugin_configuration_active(); + + for (name, expected) in [ + ( + "fixture-null-outcome", + "native tool execution returned null outcome", + ), + ( + "fixture-malformed-outcome", + "invalid native tool execution outcome JSON", + ), + ( + "fixture-status-error-outcome", + "fixture tool execution failed", + ), + ] { + let error = tool_call_execute( + ToolCallExecuteParams::builder() + .name(name) + .args(json!({ "input": true })) + .func(Arc::new(|args| Box::pin(async move { Ok(args) }))) + .build(), + ) + .await + .expect_err("invalid native tool outcome should fail") + .to_string(); + assert!( + error.contains(expected), + "expected {expected:?} in {error:?}" + ); + } + + let result = tool_call_execute( + ToolCallExecuteParams::builder() + .name("fixture-valid-outcome") + .args(json!({ "input": true })) + .func(Arc::new(|args| Box::pin(async move { Ok(args) }))) + .build(), + ) + .await + .expect("native loader should remain usable after rejected outcomes"); + assert_eq!(result["raw_tool_outcome"], true); + assert!(result.get("pending_marks").is_none()); + + drop(cleanup); + activation.clear(); +} + #[test] fn native_loader_rejects_missing_library() { let _guard = NATIVE_PLUGIN_TEST_LOCK.blocking_lock(); diff --git a/crates/core/tests/integration/worker_plugin_tests.rs b/crates/core/tests/integration/worker_plugin_tests.rs index 2e8756a1..842b6c19 100644 --- a/crates/core/tests/integration/worker_plugin_tests.rs +++ b/crates/core/tests/integration/worker_plugin_tests.rs @@ -144,6 +144,9 @@ async fn rust_worker_registers_and_invokes_all_current_surfaces() { tool_end.output().unwrap()["worker_plugin_tool_sanitize_response"], true ); + let tool_mark = find_event(&captured_events, "fixture.worker.tool_execution.mark", None); + assert_eq!(tool_mark.parent_uuid(), Some(tool_start.uuid())); + assert!(tool_mark.timestamp() > tool_end.timestamp()); let llm_execute_response = llm_call_execute( LlmCallExecuteParams::builder() diff --git a/crates/core/tests/unit/plugin_tests.rs b/crates/core/tests/unit/plugin_tests.rs index 282389a2..282c5c57 100644 --- a/crates/core/tests/unit/plugin_tests.rs +++ b/crates/core/tests/unit/plugin_tests.rs @@ -752,7 +752,7 @@ fn test_plugin_registration_context_covers_all_registration_helpers() { ctx.register_tool_execution_intercept( "tool-exec", 1, - Arc::new(|_name, args, _next| Box::pin(async move { Ok(args) })), + Arc::new(|_name, args, _next| Box::pin(async move { Ok(args.into()) })), ) .unwrap(); ctx.register_llm_request_intercept( @@ -1223,14 +1223,14 @@ fn test_plugin_registration_context_maps_duplicate_registration_errors() { ctx.register_tool_execution_intercept( "tool-exec", 1, - Arc::new(|_name, args, _next| Box::pin(async move { Ok(args) })), + Arc::new(|_name, args, _next| Box::pin(async move { Ok(args.into()) })), ) .unwrap(); expect_registration_failed( ctx.register_tool_execution_intercept( "tool-exec", 1, - Arc::new(|_name, args, _next| Box::pin(async move { Ok(args) })), + Arc::new(|_name, args, _next| Box::pin(async move { Ok(args.into()) })), ), "tool execution intercept:", ); @@ -1313,7 +1313,7 @@ fn test_plugin_registration_context_maps_deregistration_errors() { ctx.register_tool_execution_intercept( "tool-exec", 1, - Arc::new(|_name, args, _next| Box::pin(async move { Ok(args) })), + Arc::new(|_name, args, _next| Box::pin(async move { Ok(args.into()) })), ) .unwrap(); diff --git a/crates/ffi/nemo_relay.h b/crates/ffi/nemo_relay.h index 2b1eaa5f..0f5402c3 100644 --- a/crates/ffi/nemo_relay.h +++ b/crates/ffi/nemo_relay.h @@ -325,6 +325,14 @@ typedef char *(*NemoRelayToolExecNextFn)(const char *args_json, void *next_ctx); * Callback for tool execution intercepts. Receives arguments as JSON plus * a `next` callback and its context. Call `next_fn(args, next_ctx)` to invoke * the next layer in the middleware chain, or return directly to short-circuit. + * The `result` field is passed to the remaining middleware and application; + * `pending_marks` are Relay-owned lifecycle metadata emitted after the + * tool-end event and are not included in the application-visible result. + * The returned JSON must contain a `result` field and may contain a + * `pending_marks` array. The returned string must be allocated with `malloc` + * or an equivalent allocation compatible with `nemo_relay_string_free`. + * Ownership transfers to Relay when the callback returns; the callback must + * not free or reuse the string afterward, and Relay frees it exactly once. */ typedef char *(*NemoRelayToolExecInterceptCb)(void *user_data, const char *args_json, diff --git a/crates/ffi/src/callable.rs b/crates/ffi/src/callable.rs index 9baaba2b..26cb4000 100644 --- a/crates/ffi/src/callable.rs +++ b/crates/ffi/src/callable.rs @@ -25,13 +25,14 @@ use libc::c_char; use nemo_relay::api::runtime::{ EventSubscriberFn, LlmConditionalFn, LlmExecutionNextFn, LlmRequestInterceptFn, LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionNextFn, ToolConditionalFn, - ToolExecutionNextFn, ToolInterceptFn, ToolSanitizeFn, + ToolExecutionFn, ToolExecutionNextFn, ToolInterceptFn, ToolSanitizeFn, }; use serde_json::Value as Json; use tokio_stream::{Stream, StreamExt}; use nemo_relay::api::event::Event; use nemo_relay::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; +use nemo_relay::api::tool::ToolExecutionInterceptOutcome; use nemo_relay::codec::request::AnnotatedLlmRequest as AnnotatedLLMRequest; use nemo_relay::codec::traits::LlmCodec; use nemo_relay::error::{FlowError, Result}; @@ -81,6 +82,14 @@ pub type NemoRelayToolExecNextFn = /// Callback for tool execution intercepts. Receives arguments as JSON plus /// a `next` callback and its context. Call `next_fn(args, next_ctx)` to invoke /// the next layer in the middleware chain, or return directly to short-circuit. +/// The `result` field is passed to the remaining middleware and application; +/// `pending_marks` are Relay-owned lifecycle metadata emitted after the +/// tool-end event and are not included in the application-visible result. +/// The returned JSON must contain a `result` field and may contain a +/// `pending_marks` array. The returned string must be allocated with `malloc` +/// or an equivalent allocation compatible with `nemo_relay_string_free`. +/// Ownership transfers to Relay when the callback returns; the callback must +/// not free or reuse the string afterward, and Relay frees it exactly once. pub type NemoRelayToolExecInterceptCb = unsafe extern "C" fn( user_data: *mut libc::c_void, args_json: *const c_char, @@ -335,19 +344,16 @@ pub fn wrap_tool_exec_fn( }) } -/// Wrap a C tool execution intercept callback into an `Arc ...>`. +/// Wrap a C tool execution intercept callback into a [`ToolExecutionFn`]. /// /// The wrapper packages the Rust `ToolExecutionNextFn` into a C-callable -/// `(next_fn, next_ctx)` pair and passes both to the C intercept callback. +/// `(next_fn, next_ctx)` pair and passes both to the C intercept callback. The +/// callback must return a serialized [`ToolExecutionInterceptOutcome`]. pub fn wrap_tool_exec_intercept_fn( cb: NemoRelayToolExecInterceptCb, user_data: *mut libc::c_void, free_fn: NemoRelayFreeFn, -) -> Arc< - dyn Fn(&str, Json, ToolExecutionNextFn) -> Pin> + Send>> - + Send - + Sync, -> { +) -> ToolExecutionFn { let ud = make_user_data(user_data, free_fn); Arc::new(move |_name: &str, args: Json, next: ToolExecutionNextFn| { let ud = ud.clone(); @@ -387,10 +393,14 @@ pub fn wrap_tool_exec_intercept_fn( let result_ptr = unsafe { cb(ud.ptr, c_args, tool_next_trampoline, next_ctx) }; unsafe { drop(Box::from_raw(next_ctx as *mut ToolExecutionNextFn)) }; unsafe { nemo_relay_string_free_internal(c_args) }; - let result = + let outcome_json = json_result_from_ptr(result_ptr, "tool execution intercept callback failed")?; unsafe { nemo_relay_string_free_internal(result_ptr) }; - Ok(result) + serde_json::from_value::(outcome_json).map_err(|error| { + FlowError::Internal(format!( + "invalid tool execution intercept outcome JSON: {error}" + )) + }) }) }) } diff --git a/crates/ffi/tests/integration/api_tests.rs b/crates/ffi/tests/integration/api_tests.rs index d6668df4..8d9e11b8 100644 --- a/crates/ffi/tests/integration/api_tests.rs +++ b/crates/ffi/tests/integration/api_tests.rs @@ -328,7 +328,16 @@ unsafe extern "C" fn tool_exec_intercept_cb( next_fn: NemoRelayToolExecNextFn, next_ctx: *mut libc::c_void, ) -> *mut c_char { - unsafe { next_fn(args_json, next_ctx) } + let result_ptr = unsafe { next_fn(args_json, next_ctx) }; + if result_ptr.is_null() { + return ptr::null_mut(); + } + let result: Json = + serde_json::from_str(unsafe { CStr::from_ptr(result_ptr) }.to_str().unwrap()).unwrap(); + unsafe { nemo_relay_string_free(result_ptr) }; + CString::new(json!({ "result": result, "pending_marks": [] }).to_string()) + .unwrap() + .into_raw() } unsafe extern "C" fn llm_request_cb( diff --git a/crates/ffi/tests/unit/api/registry_tests.rs b/crates/ffi/tests/unit/api/registry_tests.rs index 72afba04..c0809d2c 100644 --- a/crates/ffi/tests/unit/api/registry_tests.rs +++ b/crates/ffi/tests/unit/api/registry_tests.rs @@ -1154,7 +1154,7 @@ fn test_ffi_duplicate_registration_sweep_and_helper_callbacks() { .unwrap(); assert_eq!( serde_json::from_str::(&tool_intercept_json).unwrap(), - json!({"next": true}) + json!({"result": {"next": true}, "pending_marks": []}) ); let request = cstring(r#"{"headers":{},"content":{"model":"ffi-model","messages":[]}}"#); diff --git a/crates/ffi/tests/unit/api_tests.rs b/crates/ffi/tests/unit/api_tests.rs index 78f37919..cf89acd7 100644 --- a/crates/ffi/tests/unit/api_tests.rs +++ b/crates/ffi/tests/unit/api_tests.rs @@ -325,7 +325,16 @@ unsafe extern "C" fn tool_exec_intercept_cb( next_fn: NemoRelayToolExecNextFn, next_ctx: *mut libc::c_void, ) -> *mut c_char { - unsafe { next_fn(args_json, next_ctx) } + let result_ptr = unsafe { next_fn(args_json, next_ctx) }; + if result_ptr.is_null() { + return ptr::null_mut(); + } + let result: Json = + serde_json::from_str(unsafe { CStr::from_ptr(result_ptr) }.to_str().unwrap()).unwrap(); + unsafe { nemo_relay_string_free(result_ptr) }; + CString::new(json!({ "result": result, "pending_marks": [] }).to_string()) + .unwrap() + .into_raw() } unsafe extern "C" fn llm_request_cb( diff --git a/crates/ffi/tests/unit/callable_tests.rs b/crates/ffi/tests/unit/callable_tests.rs index b8fc9a30..c512d732 100644 --- a/crates/ffi/tests/unit/callable_tests.rs +++ b/crates/ffi/tests/unit/callable_tests.rs @@ -93,7 +93,32 @@ unsafe extern "C" fn tool_exec_intercept_cb( serde_json::from_str(unsafe { CStr::from_ptr(result_ptr) }.to_str().unwrap()).unwrap(); unsafe { nemo_relay_string_free_internal(result_ptr) }; result["intercepted"] = json!(true); - CString::new(result.to_string()).unwrap().into_raw() + CString::new( + json!({ + "result": result, + "pending_marks": [{ + "name": "ffi.tool.execution", + "category": "custom", + "category_profile": { "subtype": "ffi.tool.execution" }, + "data": { "source": "c" }, + "metadata": { "fixture": true }, + }], + }) + .to_string(), + ) + .unwrap() + .into_raw() +} + +unsafe extern "C" fn tool_exec_legacy_intercept_cb( + _user_data: *mut libc::c_void, + _args_json: *const c_char, + _next_fn: NemoRelayToolExecNextFn, + _next_ctx: *mut libc::c_void, +) -> *mut c_char { + CString::new(r#"{"legacy_result":true}"#) + .unwrap() + .into_raw() } /// Intercept-specific callback with the unified annotated-aware signature @@ -270,8 +295,34 @@ fn test_wrap_tool_exec_and_intercept_callbacks() { let intercepted = runtime .block_on(intercept("tool", json!({"v": 1}), next)) .unwrap(); - assert_eq!(intercepted["intercepted"], json!(true)); - assert_eq!(intercepted["from_next"]["v"], json!(1)); + assert_eq!(intercepted.result["intercepted"], json!(true)); + assert_eq!(intercepted.result["from_next"]["v"], json!(1)); + assert_eq!(intercepted.pending_marks.len(), 1); + let mark = &intercepted.pending_marks[0]; + assert_eq!(mark.name, "ffi.tool.execution"); + assert_eq!( + mark.category.as_ref().map(|category| category.as_str()), + Some("custom") + ); + assert_eq!( + mark.category_profile + .as_ref() + .and_then(|profile| profile.subtype.as_deref()), + Some("ffi.tool.execution") + ); + assert_eq!(mark.data.as_ref().unwrap()["source"], "c"); + assert_eq!(mark.metadata.as_ref().unwrap()["fixture"], true); + + let legacy_intercept = + wrap_tool_exec_intercept_fn(tool_exec_legacy_intercept_cb, std::ptr::null_mut(), None); + let next: ToolExecutionNextFn = Arc::new(|args| Box::pin(async move { Ok(args) })); + let err = runtime + .block_on(legacy_intercept("tool", json!({}), next)) + .unwrap_err(); + assert!( + err.to_string() + .contains("invalid tool execution intercept outcome JSON") + ); let failing_intercept = wrap_tool_exec_intercept_fn(tool_exec_intercept_cb, std::ptr::null_mut(), None); diff --git a/crates/node/plugin.d.ts b/crates/node/plugin.d.ts index e0dc1d4b..7d8ac88d 100644 --- a/crates/node/plugin.d.ts +++ b/crates/node/plugin.d.ts @@ -45,6 +45,27 @@ export interface PluginConfig { policy?: ConfigPolicy; } +/** A mark Relay materializes under a managed lifecycle. */ +export interface PendingMarkSpec { + name: string; + category?: string | null; + categoryProfile?: Json; + data?: Json; + metadata?: Json; +} + +/** + * Canonical result returned by a tool execution intercept. + * + * `result` is passed to the remaining middleware and application. `pendingMarks` + * are Relay-owned lifecycle metadata emitted after the tool-end event and are + * not included in the application-visible result. + */ +export interface ToolExecutionInterceptOutcome { + result: Json; + pendingMarks?: PendingMarkSpec[]; +} + /** Component-scoped registration context passed to plugin handlers. */ export interface PluginContext { /** Register an infallible event subscriber for this component. */ @@ -85,13 +106,7 @@ export interface PluginContext { callback: (args: { name: string; request: Json; annotated: Json | null }) => { request: Json; annotated?: Json | null; - pendingMarks?: Array<{ - name: string; - category?: string | null; - categoryProfile?: Json; - data?: Json; - metadata?: Json; - }>; + pendingMarks?: PendingMarkSpec[]; }, ): void; /** Register an LLM execution intercept for this component. */ @@ -116,11 +131,17 @@ export interface PluginContext { breakChain: boolean, callback: (name: string, args: Json) => Json, ): void; - /** Register a tool execution intercept for this component. */ + /** + * Register tool execution middleware that returns a canonical outcome. + * The `next` callback resolves to the raw downstream result. + */ registerToolExecutionIntercept( name: string, priority: number, - callback: (args: Json, next: (args: Json) => Json | Promise) => Json | Promise, + callback: ( + args: Json, + next: (args: Json) => Json | Promise, + ) => ToolExecutionInterceptOutcome | Promise, ): void; } diff --git a/crates/node/src/api/mod.rs b/crates/node/src/api/mod.rs index 503a335b..efd757fd 100644 --- a/crates/node/src/api/mod.rs +++ b/crates/node/src/api/mod.rs @@ -2130,6 +2130,9 @@ pub fn register_tool_execution_intercept( env: Env, name: String, priority: i32, + #[napi( + ts_arg_type = "(args: Json, next: (args: Json) => Json | Promise) => { result: Json; pendingMarks?: Array<{ name: string; category?: string | null; categoryProfile?: Json; data?: Json; metadata?: Json }> } | Promise<{ result: Json; pendingMarks?: Array<{ name: string; category?: string | null; categoryProfile?: Json; data?: Json; metadata?: Json }> }>" + )] callable: JsFunction, ) -> Result<()> { let key = PromiseAwareKey::GlobalToolExecution(name.clone()); @@ -2560,6 +2563,9 @@ pub fn scope_register_tool_execution_intercept( scope_uuid: String, name: String, priority: i32, + #[napi( + ts_arg_type = "(args: Json, next: (args: Json) => Json | Promise) => { result: Json; pendingMarks?: Array<{ name: string; category?: string | null; categoryProfile?: Json; data?: Json; metadata?: Json }> } | Promise<{ result: Json; pendingMarks?: Array<{ name: string; category?: string | null; categoryProfile?: Json; data?: Json; metadata?: Json }> }>" + )] callable: JsFunction, ) -> Result<()> { let key = PromiseAwareKey::ScopeToolExecution { diff --git a/crates/node/src/callable.rs b/crates/node/src/callable.rs index 54cb93bf..57ba722c 100644 --- a/crates/node/src/callable.rs +++ b/crates/node/src/callable.rs @@ -25,6 +25,7 @@ use tokio_stream::StreamExt; use nemo_relay::api::event::{CategoryProfile, Event, EventCategory, PendingMarkSpec}; use nemo_relay::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; +use nemo_relay::api::tool::ToolExecutionInterceptOutcome; use nemo_relay::codec::request::AnnotatedLlmRequest; use nemo_relay::codec::response::AnnotatedLlmResponse; use nemo_relay::codec::traits::{LlmCodec, LlmResponseCodec}; @@ -625,21 +626,35 @@ pub fn wrap_js_response_codec( }) } -/// Wrap a JS function `(args, next) => result` for tool execution intercept. +/// Wrap a JS function `(args, next) => { result, pendingMarks? }` for tool execution intercept. /// /// The JS callback receives the tool arguments and a real `next(args)` function /// that returns a Promise for the downstream result. pub fn wrap_js_tool_exec_intercept_fn( func: Arc, -) -> Arc< - dyn Fn(&str, Json, ToolExecutionNextFn) -> Pin> + Send>> - + Send - + Sync, -> { +) -> nemo_relay::api::runtime::ToolExecutionFn { Arc::new(move |_name: &str, args: Json, next: ToolExecutionNextFn| { let func = func.clone(); let next_json: JsonNextFn = Arc::new(move |next_args| next(next_args)); - Box::pin(async move { func.call_with_json_next(args, next_json).await }) + Box::pin(async move { + let result = func.call_with_json_next(args, next_json).await?; + #[derive(Deserialize)] + #[serde(rename_all = "camelCase")] + struct JsOutcome { + result: Json, + #[serde(default)] + pending_marks: Vec, + } + let outcome: JsOutcome = serde_json::from_value(result).map_err(|error| { + FlowError::Internal(format!( + "invalid JS tool execution intercept outcome: {error}" + )) + })?; + Ok(ToolExecutionInterceptOutcome { + result: outcome.result, + pending_marks: outcome.pending_marks.into_iter().map(Into::into).collect(), + }) + }) }) } diff --git a/crates/node/tests/scope_local_tests.mjs b/crates/node/tests/scope_local_tests.mjs index 313e78e0..6a5b94f5 100644 --- a/crates/node/tests/scope_local_tests.mjs +++ b/crates/node/tests/scope_local_tests.mjs @@ -494,8 +494,10 @@ describe('Scope-local auto-cleanup on scope pop', () => { fromPoppedScope: true, }); return { - ...result, - wrapped: true, + result: { + ...result, + wrapped: true, + }, }; }); popScope(scope); @@ -699,8 +701,10 @@ describe('Priority merge of global and scope-local middleware', () => { from_global: true, }); return { - ...result, - global_exec: true, + result: { + ...result, + global_exec: true, + }, }; }); @@ -711,8 +715,10 @@ describe('Priority merge of global and scope-local middleware', () => { from_scope: true, }); return { - ...result, - scope_exec: true, + result: { + ...result, + scope_exec: true, + }, }; }); @@ -1253,7 +1259,10 @@ describe('Scope-local subscriber receives events', () => { () => scopeDeregisterToolConditionalExecutionGuardrail('not-a-uuid', 'bad_tool_cond'), () => scopeRegisterToolRequestIntercept('not-a-uuid', 'bad_tool_int', 10, false, (_name, args) => args), () => scopeDeregisterToolRequestIntercept('not-a-uuid', 'bad_tool_int'), - () => scopeRegisterToolExecutionIntercept('not-a-uuid', 'bad_tool_exec', 10, async (args, next) => next(args)), + () => + scopeRegisterToolExecutionIntercept('not-a-uuid', 'bad_tool_exec', 10, async (args, next) => ({ + result: await next(args), + })), () => scopeDeregisterToolExecutionIntercept('not-a-uuid', 'bad_tool_exec'), () => scopeRegisterLlmSanitizeRequestGuardrail('not-a-uuid', 'bad_llm_req', 10, (request) => request), () => scopeDeregisterLlmSanitizeRequestGuardrail('not-a-uuid', 'bad_llm_req'), diff --git a/crates/node/tests/tools_tests.mjs b/crates/node/tests/tools_tests.mjs index 9b7265a1..6cb9fdc0 100644 --- a/crates/node/tests/tools_tests.mjs +++ b/crates/node/tests/tools_tests.mjs @@ -571,7 +571,7 @@ describe('Tool intercepts', () => { }); it('execution intercept register/deregister', () => { - registerToolExecutionIntercept('node_tool_exec_int', 10, async (args, next) => next(args)); + registerToolExecutionIntercept('node_tool_exec_int', 10, async (args, next) => ({ result: await next(args) })); deregisterToolExecutionIntercept('node_tool_exec_int'); }); @@ -627,30 +627,59 @@ describe('Tool intercepts', () => { }); it('execution intercept composes with next', async () => { + const events = []; + registerSubscriber('node_tool_exec_mark_sub', (event) => events.push(event)); registerToolExecutionIntercept('node_tool_exec_repl', 10, async (args, next) => { const result = await next({ ...args, intercepted: true, }); return { - ...result, - wrapped: true, + result: { + ...result, + wrapped: true, + }, + pendingMarks: [{ name: 'node.tool.execution' }], }; }); - const result = await toolCallExecute( - 'replaced_tool', - {}, - (args) => ({ - original: !args.intercepted, - }), - null, - null, - null, - null, - ); - assert.equal(result.original, false); - assert.equal(result.wrapped, true); - deregisterToolExecutionIntercept('node_tool_exec_repl'); + try { + const result = await toolCallExecute( + 'replaced_tool', + {}, + (args) => ({ + original: !args.intercepted, + }), + null, + null, + null, + null, + ); + assert.equal(result.original, false); + assert.equal(result.wrapped, true); + await waitForSubscriberCallbacks( + () => + events.some( + (event) => + event.name === 'replaced_tool' && event.kind === 'scope' && event.scope_category === 'end', + ) && events.some((event) => event.name === 'node.tool.execution'), + ); + const start = events.find( + (event) => event.name === 'replaced_tool' && event.kind === 'scope' && event.scope_category === 'start', + ); + const end = events.find( + (event) => event.name === 'replaced_tool' && event.kind === 'scope' && event.scope_category === 'end', + ); + const mark = events.find((event) => event.name === 'node.tool.execution'); + assert.ok(start, 'expected tool start event'); + assert.ok(end, 'expected tool end event'); + assert.ok(mark, 'expected tool execution pending mark'); + assert.equal(mark.parent_uuid, start.uuid); + assert.ok(events.indexOf(end) < events.indexOf(mark), 'expected tool end before pending mark'); + assert.ok(mark.timestamp > end.timestamp, 'expected pending mark timestamp after tool end'); + } finally { + deregisterToolExecutionIntercept('node_tool_exec_repl'); + deregisterSubscriber('node_tool_exec_mark_sub'); + } }); it('execution intercept propagates Error messages', async () => { @@ -678,6 +707,33 @@ describe('Tool intercepts', () => { } }); + it('execution intercept rejects legacy raw results', async () => { + registerToolExecutionIntercept('node_tool_exec_legacy', 10, async () => ({ legacyResult: true })); + try { + await assert.rejects( + () => toolCallExecute('legacy_tool', {}, (args) => args, null, null, null, null), + /invalid JS tool execution intercept outcome/i, + ); + } finally { + deregisterToolExecutionIntercept('node_tool_exec_legacy'); + } + }); + + it('execution intercept rejects unknown pending-mark fields', async () => { + registerToolExecutionIntercept('node_tool_exec_bad_mark', 10, async () => ({ + result: { ok: true }, + pendingMarks: [{ name: 'node.bad.mark', category_profile: { subtype: 'invalid.snake.case' } }], + })); + try { + await assert.rejects( + () => toolCallExecute('bad_mark_tool', {}, (args) => args, null, null, null, null), + /unknown field.*category_profile/i, + ); + } finally { + deregisterToolExecutionIntercept('node_tool_exec_bad_mark'); + } + }); + it('async execute falls back to unknown error for primitive rejections', async () => { await assert.rejects( () => diff --git a/crates/plugin/src/lib.rs b/crates/plugin/src/lib.rs index 662b191a..3f96f982 100644 --- a/crates/plugin/src/lib.rs +++ b/crates/plugin/src/lib.rs @@ -21,7 +21,7 @@ pub use nemo_relay_types::api::event::{ }; pub use nemo_relay_types::api::llm::{LlmAttributes, LlmRequest, LlmRequestInterceptOutcome}; pub use nemo_relay_types::api::scope::{HandleAttributes, ScopeAttributes, ScopeType}; -pub use nemo_relay_types::api::tool::ToolAttributes; +pub use nemo_relay_types::api::tool::{ToolAttributes, ToolExecutionInterceptOutcome}; pub use nemo_relay_types::codec::request::AnnotatedLlmRequest; pub use nemo_relay_types::codec::response::AnnotatedLlmResponse; pub use nemo_relay_types::plugin::{ConfigDiagnostic, DiagnosticLevel}; @@ -223,7 +223,7 @@ pub type NemoRelayNativeToolExecutionCb = unsafe extern "C" fn( args_json: *const NemoRelayNativeString, next_fn: NemoRelayNativeToolNextFn, next_ctx: *mut c_void, - out_json: *mut *mut NemoRelayNativeString, + out_outcome_json: *mut *mut NemoRelayNativeString, ) -> NemoRelayStatus; /// Native LLM request transform callback for request sanitizers. @@ -1353,6 +1353,10 @@ impl<'a> PluginContext<'a> { } /// Registers a typed tool execution intercept. + /// + /// The callback returns a [`ToolExecutionInterceptOutcome`]. Calling + /// [`ToolNext::call`] continues the chain and returns only the raw + /// downstream result JSON; Relay retains downstream pending marks. pub fn register_tool_execution_intercept( &mut self, name: &str, @@ -1360,7 +1364,10 @@ impl<'a> PluginContext<'a> { callback: F, ) -> Result<()> where - F: for<'next> Fn(&str, Json, ToolNext<'next>) -> Result + Send + Sync + 'static, + F: for<'next> Fn(&str, Json, ToolNext<'next>) -> Result + + Send + + Sync + + 'static, { let user_data = typed_callback_user_data(self.host, callback); let status = unsafe { @@ -1996,15 +2003,18 @@ unsafe extern "C" fn typed_tool_execution_trampoline( args_json: *const NemoRelayNativeString, next_fn: NemoRelayNativeToolNextFn, next_ctx: *mut c_void, - out_json: *mut *mut NemoRelayNativeString, + out_outcome_json: *mut *mut NemoRelayNativeString, ) -> NemoRelayStatus where - F: for<'next> Fn(&str, Json, ToolNext<'next>) -> Result + Send + Sync + 'static, + F: for<'next> Fn(&str, Json, ToolNext<'next>) -> Result + + Send + + Sync + + 'static, { - if user_data.is_null() || out_json.is_null() { + if user_data.is_null() || out_outcome_json.is_null() { return NemoRelayStatus::NullPointer; } - unsafe { *out_json = ptr::null_mut() }; + unsafe { *out_outcome_json = ptr::null_mut() }; let state = unsafe { &*(user_data as *const TypedCallback) }; let result = catch_unwind(AssertUnwindSafe(|| { let name = read_required_host_string(&state.host, name, "tool name")?; @@ -2015,7 +2025,15 @@ where next_ctx, }; match (state.callback)(&name, args, next) { - Ok(output) => Ok::<_, NemoRelayStatus>(write_json(&state.host, &output, out_json)), + Ok(outcome) => { + let Some(outcome) = HostString::from_json(&state.host, &outcome) else { + set_last_error(&state.host, "failed to allocate tool execution outcome"); + return Ok(NemoRelayStatus::Internal); + }; + unsafe { *out_outcome_json = outcome.ptr }; + std::mem::forget(outcome); + Ok(NemoRelayStatus::Ok) + } Err(message) => Ok(callback_error(&state.host, message)), } })); diff --git a/crates/plugin/tests/typed_callbacks.rs b/crates/plugin/tests/typed_callbacks.rs index b1fba367..19495518 100644 --- a/crates/plugin/tests/typed_callbacks.rs +++ b/crates/plugin/tests/typed_callbacks.rs @@ -13,8 +13,8 @@ use std::sync::{ }; use nemo_relay_plugin::{ - AnnotatedLlmRequest, ConfigDiagnostic, DiagnosticLevel, Event, Json, LlmJsonStream, LlmNext, - LlmRequest, LlmRequestInterceptOutcome, LlmStream, LlmStreamNext, + AnnotatedLlmRequest, CategoryProfile, ConfigDiagnostic, DiagnosticLevel, Event, EventCategory, + Json, LlmJsonStream, LlmNext, LlmRequest, LlmRequestInterceptOutcome, LlmStream, LlmStreamNext, NEMO_RELAY_NATIVE_ABI_VERSION, NativePlugin, NemoRelayNativeEventSubscriberCb, NemoRelayNativeFreeFn, NemoRelayNativeHostApiV1, NemoRelayNativeJsonCb, NemoRelayNativeLlmConditionalCb, NemoRelayNativeLlmExecutionCb, NemoRelayNativeLlmRequestCb, @@ -23,7 +23,8 @@ use nemo_relay_plugin::{ NemoRelayNativeScopeHandle, NemoRelayNativeScopeStack, NemoRelayNativeScopeStackBinding, NemoRelayNativeScopeType, NemoRelayNativeString, NemoRelayNativeToolConditionalCb, NemoRelayNativeToolExecutionCb, NemoRelayNativeToolJsonCb, NemoRelayNativeWithScopeStackCb, - NemoRelayStatus, PendingMarkSpec, PluginContext, PluginRuntime, ScopeType, ToolNext, + NemoRelayStatus, PendingMarkSpec, PluginContext, PluginRuntime, ScopeType, + ToolExecutionInterceptOutcome, ToolNext, }; use serde_json::{Map, json}; @@ -2464,7 +2465,7 @@ fn typed_callbacks_reject_null_abi_pointers_before_decoding_inputs() { } let mut ctx = test_context(&host); - ctx.register_tool_execution_intercept("tool-exec", 0, |_name, value, _next| Ok(value)) + ctx.register_tool_execution_intercept("tool-exec", 0, |_name, value, _next| Ok(value.into())) .unwrap(); let registration = take_tool_execution_registration(); let name = host_string(&host, "tool"); @@ -2789,7 +2790,7 @@ fn typed_callbacks_report_invalid_json_for_each_decoder_family() { } let mut ctx = test_context(&host); - ctx.register_tool_execution_intercept("tool-exec", 0, |_name, value, _next| Ok(value)) + ctx.register_tool_execution_intercept("tool-exec", 0, |_name, value, _next| Ok(value.into())) .unwrap(); let registration = take_tool_execution_registration(); let name = host_string(&host, "tool"); @@ -3486,7 +3487,21 @@ fn typed_tool_execution_registration_calls_next() { let called = Arc::new(AtomicUsize::new(0)); let mut ctx = test_context(&host); ctx.register_tool_execution_intercept("tool", 23, |_name, args, next: ToolNext<'_>| { - next.call(args) + let result = next.call(args)?; + Ok( + ToolExecutionInterceptOutcome::new(result).with_pending_mark( + PendingMarkSpec::builder() + .name("plugin.tool.completed") + .category(EventCategory::custom()) + .category_profile(CategoryProfile { + subtype: Some("plugin.tool.pending".into()), + ..CategoryProfile::default() + }) + .data(json!({ "saved_tokens": 7 })) + .metadata(json!({ "source": "typed-test" })) + .build(), + ), + ) }) .unwrap(); @@ -3512,7 +3527,19 @@ fn typed_tool_execution_registration_calls_next() { }; assert_eq!(status, NemoRelayStatus::Ok); assert_eq!(called.load(Ordering::SeqCst), 1); - assert_eq!(read_json_and_free(&host, out)["next_called"], json!(true)); + let outcome = read_json_and_free(&host, out); + assert_eq!(outcome["result"]["next_called"], json!(true)); + assert_eq!(outcome["pending_marks"][0]["name"], "plugin.tool.completed"); + assert_eq!(outcome["pending_marks"][0]["category"], "custom"); + assert_eq!( + outcome["pending_marks"][0]["category_profile"]["subtype"], + "plugin.tool.pending" + ); + assert_eq!(outcome["pending_marks"][0]["data"]["saved_tokens"], 7); + assert_eq!( + outcome["pending_marks"][0]["metadata"]["source"], + "typed-test" + ); unsafe { (host.string_free)(name); (host.string_free)(args); @@ -3521,6 +3548,45 @@ fn typed_tool_execution_registration_calls_next() { } } +#[test] +fn typed_tool_execution_does_not_publish_partial_outcome() { + let _guard = begin_test(); + let host = test_host(); + let mut ctx = test_context(&host); + ctx.register_tool_execution_intercept("tool", 0, |_name, args, _next| { + Ok(ToolExecutionInterceptOutcome::new(args)) + }) + .unwrap(); + + let registration = take_tool_execution_registration(); + let name = host_string(&host, "tool"); + let args = json_host_string(&host, json!({ "input": true })); + let stale_outcome = host_string(&host, r#"{"stale":true}"#); + let mut out_outcome = stale_outcome; + *STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = Some(0); + let live_before = live_host_strings(); + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + args, + fake_tool_next, + ptr::null_mut(), + &mut out_outcome, + ) + }; + *STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = None; + assert_eq!(status, NemoRelayStatus::Internal); + assert!(out_outcome.is_null()); + assert_eq!(live_host_strings(), live_before); + unsafe { + (host.string_free)(stale_outcome); + (host.string_free)(name); + (host.string_free)(args); + registration.free(); + } +} + #[test] fn typed_tool_execution_surfaces_next_status_failures() { let _guard = begin_test(); @@ -3528,7 +3594,7 @@ fn typed_tool_execution_surfaces_next_status_failures() { let called = Arc::new(AtomicUsize::new(0)); let mut ctx = test_context(&host); ctx.register_tool_execution_intercept("tool", 0, |_name, args, next: ToolNext<'_>| { - next.call(args) + next.call(args).map(Into::into) }) .unwrap(); @@ -3572,7 +3638,7 @@ fn typed_tool_execution_surfaces_invalid_next_json() { let called = Arc::new(AtomicUsize::new(0)); let mut ctx = test_context(&host); ctx.register_tool_execution_intercept("tool", 0, |_name, args, next: ToolNext<'_>| { - next.call(args) + next.call(args).map(Into::into) }) .unwrap(); @@ -3618,7 +3684,7 @@ fn typed_tool_execution_surfaces_null_next_output() { let called = Arc::new(AtomicUsize::new(0)); let mut ctx = test_context(&host); ctx.register_tool_execution_intercept("tool", 0, |_name, args, next: ToolNext<'_>| { - next.call(args) + next.call(args).map(Into::into) }) .unwrap(); diff --git a/crates/python/src/py_callable.rs b/crates/python/src/py_callable.rs index 7b8629be..8d59fe72 100644 --- a/crates/python/src/py_callable.rs +++ b/crates/python/src/py_callable.rs @@ -36,6 +36,7 @@ use tokio_stream::Stream; use nemo_relay::api::event::Event; use nemo_relay::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; +use nemo_relay::api::tool::ToolExecutionInterceptOutcome; use nemo_relay::codec::request::AnnotatedLlmRequest as AnnotatedLLMRequest; use nemo_relay::codec::response::AnnotatedLlmResponse as AnnotatedLLMResponse; use nemo_relay::codec::traits::{LlmCodec, LlmResponseCodec}; @@ -43,6 +44,7 @@ use nemo_relay::codec::traits::{LlmCodec, LlmResponseCodec}; use crate::convert::{json_to_py, py_to_json}; use crate::py_types::{ PyAnnotatedLLMRequest, PyAnnotatedLLMResponse, PyLLMRequest, PyLLMRequestInterceptOutcome, + PyToolExecutionInterceptOutcome, }; type PyValueFuture = Pin>> + Send>>; @@ -393,26 +395,21 @@ impl PyLlmStreamNextFn { } } -/// Wrap a Python callable `(Json, next) -> Json` for tool execution intercepts. +/// Wrap a Python callable `(Json, next) -> ToolExecutionInterceptOutcome` for tool execution intercepts. /// The `next` parameter is a `PyToolNextFn` that the Python code can `await`. pub fn wrap_py_tool_exec_intercept_fn( py_fn: Py, -) -> Arc< - dyn Fn( - &str, - Json, - ToolExecutionNextFn, - ) -> Pin> + Send>> - + Send - + Sync, -> { +) -> nemo_relay::api::runtime::ToolExecutionFn { let py_fn = Arc::new(py_fn); Arc::new(move |name: &str, args: Json, next: ToolExecutionNextFn| { let py_fn = py_fn.clone(); let name = name.to_string(); Box::pin(async move { let outcome: FlowResult< - Result>> + Send>>>, + Result< + ToolExecutionInterceptOutcome, + Pin>> + Send>>, + >, > = Python::attach(|py| { let py_args = json_to_py(py, &args).map_err(|e: PyErr| FlowError::Internal(e.to_string()))?; @@ -440,9 +437,14 @@ pub fn wrap_py_tool_exec_intercept_fn( Box>> + Send>, >)) } else { - let json = - py_to_json(bound).map_err(|e: PyErr| FlowError::Internal(e.to_string()))?; - Ok(Ok(json)) + let outcome = result + .extract::(py) + .map_err(|e| { + FlowError::Internal(format!( + "tool execution intercept must return ToolExecutionInterceptOutcome: {e}" + )) + })?; + Ok(Ok(outcome.inner)) } }); @@ -453,8 +455,14 @@ pub fn wrap_py_tool_exec_intercept_fn( .await .map_err(|e| FlowError::Internal(e.to_string()))?; Python::attach(|py| { - py_to_json(py_result.bind(py)) - .map_err(|e: PyErr| FlowError::Internal(e.to_string())) + py_result + .extract::(py) + .map(|value| value.inner) + .map_err(|e| { + FlowError::Internal(format!( + "tool execution intercept must return ToolExecutionInterceptOutcome: {e}" + )) + }) }) } } diff --git a/crates/python/src/py_types/core.rs b/crates/python/src/py_types/core.rs index fd7aba21..aa6fc16b 100644 --- a/crates/python/src/py_types/core.rs +++ b/crates/python/src/py_types/core.rs @@ -10,6 +10,7 @@ use super::{ }; use nemo_relay::api::event::{CategoryProfile, EventCategory, PendingMarkSpec}; use nemo_relay::api::llm::LlmRequestInterceptOutcome; +use nemo_relay::api::tool::ToolExecutionInterceptOutcome; // --------------------------------------------------------------------------- // LlmStream (async iterator) @@ -728,3 +729,39 @@ impl PyLLMRequestInterceptOutcome { .collect() } } + +/// Canonical result returned by Python tool execution intercepts. +#[pyclass(name = "ToolExecutionInterceptOutcome", from_py_object)] +#[derive(Clone)] +pub struct PyToolExecutionInterceptOutcome { + pub inner: ToolExecutionInterceptOutcome, +} + +#[pymethods] +impl PyToolExecutionInterceptOutcome { + #[new] + #[pyo3(signature = (result, pending_marks=Vec::new()))] + fn new(result: &Bound<'_, PyAny>, pending_marks: Vec) -> PyResult { + Ok(Self { + inner: ToolExecutionInterceptOutcome { + result: py_to_json(result)?, + pending_marks: pending_marks.into_iter().map(|value| value.inner).collect(), + }, + }) + } + + #[getter] + fn result(&self, py: Python<'_>) -> PyResult> { + json_to_py(py, &self.inner.result) + } + + #[getter] + fn pending_marks(&self) -> Vec { + self.inner + .pending_marks + .iter() + .cloned() + .map(|inner| PyPendingMarkSpec { inner }) + .collect() + } +} diff --git a/crates/python/src/py_types/mod.rs b/crates/python/src/py_types/mod.rs index 7a2b2826..6fcd1dec 100644 --- a/crates/python/src/py_types/mod.rs +++ b/crates/python/src/py_types/mod.rs @@ -134,6 +134,7 @@ pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/crates/python/tests/coverage/coverage_tests.rs b/crates/python/tests/coverage/coverage_tests.rs index 4fa67313..49e98ebf 100644 --- a/crates/python/tests/coverage/coverage_tests.rs +++ b/crates/python/tests/coverage/coverage_tests.rs @@ -38,6 +38,12 @@ fn load_module<'py>(py: Python<'py>, code: &str) -> Bound<'py, PyModule> { ) .unwrap(); module + .setattr( + "ToolOutcome", + py.get_type::(), + ) + .unwrap(); + module } fn make_request() -> LlmRequest { @@ -409,7 +415,7 @@ def tool_request_intercept(name, value): return value async def tool_execution_intercept(name, value, next): - return await next(value) + return ToolOutcome(await next(value)) class CoveragePlugin: def validate(self, plugin_config): @@ -714,7 +720,7 @@ async def tool_exec(args): async def tool_intercept(name, args, next): result = await next({"x": args["x"] + 1}) result["wrapped"] = True - return result + return ToolOutcome(result) async def llm_exec(request): return {"model": request.content["model"]} @@ -726,6 +732,13 @@ async def llm_intercept(name, request, next): "#, ); + module + .setattr( + "ToolOutcome", + py.get_type::(), + ) + .unwrap(); + let tool_exec_py: Py = module.getattr("tool_exec").unwrap().unbind(); let tool_intercept_py: Py = module.getattr("tool_intercept").unwrap().unbind(); let llm_exec_py: Py = module.getattr("llm_exec").unwrap().unbind(); @@ -746,7 +759,7 @@ async def llm_intercept(name, request, next): tool_intercept("tool", json!({"x": 2}), tool_next) .await .unwrap(), - json!({"next": 3, "wrapped": true}) + json!({"next": 3, "wrapped": true}).into() ); let llm_exec = wrap_py_llm_exec_fn(llm_exec_py); diff --git a/crates/python/tests/coverage/py_api_coverage_tests.rs b/crates/python/tests/coverage/py_api_coverage_tests.rs index 9ce9107a..fa1bb80a 100644 --- a/crates/python/tests/coverage/py_api_coverage_tests.rs +++ b/crates/python/tests/coverage/py_api_coverage_tests.rs @@ -177,7 +177,7 @@ async def tool_exec(args): async def tool_exec_intercept(name, args, next): result = await next({"value": args["value"] + 3}) result["tool_intercepted"] = True - return result + return ToolExecutionInterceptOutcome(result) def llm_sanitize_request(request): return request @@ -334,6 +334,14 @@ async def run_stream(api, request, func, collector, finalizer, handle, attribute types_module.getattr("LLMRequestInterceptOutcome").unwrap(), ) .unwrap(); + helpers + .setattr( + "ToolExecutionInterceptOutcome", + types_module + .getattr("ToolExecutionInterceptOutcome") + .unwrap(), + ) + .unwrap(); helpers .setattr( "AnnotatedLLMRequest", diff --git a/crates/python/tests/coverage/py_callable_coverage_tests.rs b/crates/python/tests/coverage/py_callable_coverage_tests.rs index cac6bbd6..10f3908a 100644 --- a/crates/python/tests/coverage/py_callable_coverage_tests.rs +++ b/crates/python/tests/coverage/py_callable_coverage_tests.rs @@ -51,7 +51,7 @@ def sync_tool_exec(args): return {"sync_tool": args["x"] + 1} def sync_tool_intercept(name, args, next): - return {"name": name, "value": args["x"] + 2} + return ToolOutcome({"name": name, "value": args["x"] + 2}) def sync_llm_exec(request): return {"model": request.content["model"], "mode": "sync"} @@ -99,6 +99,12 @@ class RaisingResponseCodec: py.get_type::(), ) .unwrap(); + module + .setattr( + "ToolOutcome", + py.get_type::(), + ) + .unwrap(); let tool_exec_py: Py = module.getattr("sync_tool_exec").unwrap().unbind(); let tool_intercept_py: Py = module.getattr("sync_tool_intercept").unwrap().unbind(); @@ -120,7 +126,7 @@ class RaisingResponseCodec: tool_intercept("tool", json!({"x": 3}), tool_next) .await .unwrap(), - json!({"name": "tool", "value": 5}) + json!({"name": "tool", "value": 5}).into() ); let llm_exec = wrap_py_llm_exec_fn(llm_exec_py); diff --git a/crates/python/tests/coverage/py_plugin_coverage_tests.rs b/crates/python/tests/coverage/py_plugin_coverage_tests.rs index 0eba263c..33b9c822 100644 --- a/crates/python/tests/coverage/py_plugin_coverage_tests.rs +++ b/crates/python/tests/coverage/py_plugin_coverage_tests.rs @@ -24,6 +24,12 @@ fn load_module<'py>(py: Python<'py>, code: &str) -> Bound<'py, PyModule> { ) .unwrap(); module + .setattr( + "ToolOutcome", + py.get_type::(), + ) + .unwrap(); + module } fn with_event_loop(py: Python<'_>, f: impl FnOnce(Bound<'_, PyAny>) -> T) -> T { @@ -142,7 +148,7 @@ def tool_request_intercept(name, value): return value async def tool_execution_intercept(name, value, next): - return await next(value) + return ToolOutcome(await next(value)) "#, ); @@ -598,7 +604,7 @@ def tool_request_intercept(name, value): return value async def tool_execution_intercept(name, value, next): - return await next(value) + return ToolOutcome(await next(value)) "#, ); @@ -882,7 +888,7 @@ def tool_request_intercept(name, value): return value async def tool_execution_intercept(name, value, next): - return await next(value) + return ToolOutcome(await next(value)) "#, ); diff --git a/crates/types/src/api/event.rs b/crates/types/src/api/event.rs index d08bf818..839419d7 100644 --- a/crates/types/src/api/event.rs +++ b/crates/types/src/api/event.rs @@ -364,7 +364,7 @@ pub struct MarkEvent { pub category_profile: Option, } -/// Mark requested by middleware before its owning runtime scope exists. +/// Mark requested by middleware for materialization by a lifecycle owner. /// /// The runtime assigns the parent UUID, event UUID, and timestamp when it /// materializes the mark at the appropriate lifecycle boundary. diff --git a/crates/types/src/api/tool.rs b/crates/types/src/api/tool.rs index db4f5980..37242507 100644 --- a/crates/types/src/api/tool.rs +++ b/crates/types/src/api/tool.rs @@ -6,6 +6,9 @@ use bitflags::bitflags; use serde::{Deserialize, Serialize}; +use crate::Json; +use crate::api::event::PendingMarkSpec; + bitflags! { /// Bitflags that modify tool-call behavior and observability. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] @@ -14,3 +17,40 @@ bitflags! { const REMOTE = 0b01; } } + +/// Canonical result returned by a tool execution intercept. +/// +/// `result` is passed to the remaining middleware and application. `pending_marks` +/// are Relay-owned lifecycle metadata retained separately and emitted after the +/// tool-end event; they are not included in the application-visible result. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ToolExecutionInterceptOutcome { + /// Tool result returned to the remaining middleware and application. + pub result: Json, + /// Ordered marks for the managed tool lifecycle owner to emit. + #[serde(default)] + pub pending_marks: Vec, +} + +impl ToolExecutionInterceptOutcome { + /// Create an outcome without pending marks. + pub fn new(result: Json) -> Self { + Self { + result, + pending_marks: Vec::new(), + } + } + + /// Append one pending mark while preserving callback order. + #[must_use] + pub fn with_pending_mark(mut self, mark: PendingMarkSpec) -> Self { + self.pending_marks.push(mark); + self + } +} + +impl From for ToolExecutionInterceptOutcome { + fn from(result: Json) -> Self { + Self::new(result) + } +} diff --git a/crates/types/tests/serialization_tests.rs b/crates/types/tests/serialization_tests.rs index 39bb8fbc..bb050744 100644 --- a/crates/types/tests/serialization_tests.rs +++ b/crates/types/tests/serialization_tests.rs @@ -10,6 +10,7 @@ use nemo_relay_types::api::event::{ llm_attributes_to_strings, }; use nemo_relay_types::api::llm::{LlmAttributes, LlmRequest, LlmRequestInterceptOutcome}; +use nemo_relay_types::api::tool::ToolExecutionInterceptOutcome; use nemo_relay_types::codec::request::{AnnotatedLlmRequest, Message, MessageContent}; use nemo_relay_types::codec::response::AnnotatedLlmResponse; use serde_json::{Map, json}; @@ -171,3 +172,53 @@ fn llm_request_intercept_outcome_converts_from_request_inputs() { LlmRequestInterceptOutcome::new(request, Some(annotated_request)) ); } + +#[test] +fn tool_execution_intercept_outcome_round_trips_pending_marks() { + let outcome = ToolExecutionInterceptOutcome::new(json!({"stdout": "compacted"})) + .with_pending_mark( + PendingMarkSpec::builder() + .name("tool.output.compacted") + .category(EventCategory::custom()) + .category_profile( + CategoryProfile::builder() + .subtype("optimizer.saved_tokens") + .build(), + ) + .data(json!({"saved_tokens": 12})) + .metadata(json!({"source": "test"})) + .build(), + ); + + let encoded = serde_json::to_value(&outcome).expect("outcome should serialize"); + assert_eq!(encoded["result"]["stdout"], "compacted"); + assert_eq!(encoded["pending_marks"][0]["name"], "tool.output.compacted"); + assert_eq!(encoded["pending_marks"][0]["category"], "custom"); + + let decoded: ToolExecutionInterceptOutcome = + serde_json::from_value(encoded).expect("outcome should deserialize"); + assert_eq!(decoded, outcome); + + let defaults: ToolExecutionInterceptOutcome = serde_json::from_value(json!({ + "result": "plain", + "future_field": true + })) + .expect("omitted pending marks and unknown fields should be accepted"); + assert!(defaults.pending_marks.is_empty()); + assert_eq!(defaults.result, json!("plain")); + + assert!( + serde_json::from_value::(json!({ + "pending_marks": [] + })) + .is_err(), + "result is required" + ); +} + +#[test] +fn tool_execution_intercept_outcome_converts_from_json() { + let result = json!({"value": 42}); + let outcome: ToolExecutionInterceptOutcome = result.clone().into(); + assert_eq!(outcome, ToolExecutionInterceptOutcome::new(result)); +} diff --git a/crates/worker-proto/proto/nemo/relay/worker/v1/plugin_worker.proto b/crates/worker-proto/proto/nemo/relay/worker/v1/plugin_worker.proto index 9a7e7f70..9dbe9a67 100644 --- a/crates/worker-proto/proto/nemo/relay/worker/v1/plugin_worker.proto +++ b/crates/worker-proto/proto/nemo/relay/worker/v1/plugin_worker.proto @@ -166,6 +166,7 @@ message InvokeResponse { GuardrailResult guardrail = 3; LlmRequestInterceptResult llm_request = 4; WorkerError error = 5; + ToolExecutionInterceptResult tool_execution = 6; } } @@ -184,6 +185,10 @@ message LlmRequestInterceptResult { JsonEnvelope outcome = 1; } +message ToolExecutionInterceptResult { + JsonEnvelope outcome = 1; +} + message StreamChunk { oneof item { JsonEnvelope value = 1; diff --git a/crates/worker/src/lib.rs b/crates/worker/src/lib.rs index 9115e6d3..56921f54 100644 --- a/crates/worker/src/lib.rs +++ b/crates/worker/src/lib.rs @@ -37,6 +37,7 @@ pub use nemo_relay_types::Json; pub use nemo_relay_types::api::event::{Event, PendingMarkSpec}; pub use nemo_relay_types::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; pub use nemo_relay_types::api::scope::ScopeType; +pub use nemo_relay_types::api::tool::ToolExecutionInterceptOutcome; use nemo_relay_types::codec::request::AnnotatedLlmRequest; pub use nemo_relay_types::plugin::{ConfigDiagnostic, DiagnosticLevel}; use nemo_relay_worker_proto::v1::plugin_worker_server::{PluginWorker, PluginWorkerServer}; @@ -47,8 +48,8 @@ use nemo_relay_worker_proto::v1::{ HealthResponse, InvokeRequest, InvokeResponse, JsonEnvelope, JsonResult, LlmNextRequest, LlmRequestInterceptResult, LlmStreamNextRequest, PopScopeRequest, PushScopeRequest, RegisterRequest, RegisterResponse, Registration, RegistrationSurface, ScopeContext, - ShutdownRequest, StreamChunk, ToolNextRequest, ValidateRequest, ValidateResponse, WorkerAck, - WorkerError, + ShutdownRequest, StreamChunk, ToolExecutionInterceptResult, ToolNextRequest, ValidateRequest, + ValidateResponse, WorkerAck, WorkerError, }; use nemo_relay_worker_proto::{WORKER_PROTOCOL_GRPC_V1, decode_json_envelope, json_envelope}; use tokio::net::TcpListener; @@ -120,7 +121,9 @@ type SubscriberFn = Arc; type ToolSanitizeFn = Arc Json + Send + Sync>; type ToolConditionalFn = Arc Result> + Send + Sync>; type ToolRequestFn = Arc Result + Send + Sync>; -type ToolExecutionFn = Arc BoxFutureResult + Send + Sync>; +type ToolExecutionFn = Arc< + dyn Fn(&str, Json, ToolNext) -> BoxFutureResult + Send + Sync, +>; type LlmSanitizeRequestFn = Arc LlmRequest + Send + Sync>; type LlmSanitizeResponseFn = Arc Json + Send + Sync>; type LlmConditionalFn = Arc Result> + Send + Sync>; @@ -271,6 +274,10 @@ impl PluginContext { } /// Registers a tool execution intercept. + /// + /// The callback returns a [`ToolExecutionInterceptOutcome`]. Calling + /// [`ToolNext::call`] continues the chain and returns only the raw + /// downstream result JSON; Relay retains downstream pending marks. pub fn register_tool_execution_intercept( &mut self, name: &str, @@ -278,7 +285,7 @@ impl PluginContext { callback: F, ) where F: Fn(&str, Json, ToolNext) -> Fut + Send + Sync + 'static, - Fut: Future> + Send + 'static, + Fut: Future> + Send + 'static, { self.push_registration( name, @@ -1364,7 +1371,7 @@ impl WorkerService { }; let future = with_thread_scope(&scope, || handler(&payload.tool_name, payload.value, next)); - Ok(json_response(future.await?)) + Ok(tool_execution_response(future.await?)?) } RegistrationSurface::LlmSanitizeRequestGuardrail => { let payload = llm_payload(request.payload)?; @@ -1671,6 +1678,21 @@ fn llm_request_response(outcome: LlmRequestInterceptOutcome) -> Result Result { + Ok(InvokeResponse { + result: Some( + nemo_relay_worker_proto::v1::invoke_response::Result::ToolExecution( + ToolExecutionInterceptResult { + outcome: Some(json_envelope( + "nemo.relay.ToolExecutionInterceptOutcome@1", + &outcome, + )?), + }, + ), + ), + }) +} + fn stream_chunk_to_json(chunk: StreamChunk) -> Result { match chunk.item { Some(nemo_relay_worker_proto::v1::stream_chunk::Item::Value(value)) => { diff --git a/crates/worker/tests/worker_sdk_tests.rs b/crates/worker/tests/worker_sdk_tests.rs index f8baa1fa..cbd84de5 100644 --- a/crates/worker/tests/worker_sdk_tests.rs +++ b/crates/worker/tests/worker_sdk_tests.rs @@ -17,11 +17,11 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH}; use futures_util::{Stream, StreamExt}; #[cfg(unix)] use hyper_util::rt::TokioIo; -use nemo_relay_types::api::event::{BaseEvent, Event, MarkEvent}; +use nemo_relay_types::api::event::{BaseEvent, Event, MarkEvent, PendingMarkSpec}; use nemo_relay_worker::{ Json, JsonStream, LlmNext, LlmRequest, LlmStreamNext, PluginContext, PluginRuntime, Result, - ScopeType, ToolNext, WorkerPlugin, WorkerSdkError, WorkerServerConfig, serve_plugin, - serve_plugin_arc, serve_plugin_arc_with_config, + ScopeType, ToolExecutionInterceptOutcome, ToolNext, WorkerPlugin, WorkerSdkError, + WorkerServerConfig, serve_plugin, serve_plugin_arc, serve_plugin_arc_with_config, }; use nemo_relay_worker_proto::v1::plugin_worker_client::PluginWorkerClient; use nemo_relay_worker_proto::v1::relay_host_runtime_server::{ @@ -505,7 +505,7 @@ async fn worker_service_invokes_every_registration_surface() { "phase", "tool_request", ); - let tool_exec = invoke_json( + let tool_outcome = invoke_tool_execution( &mut client, tool_invoke( "tool-exec", @@ -514,6 +514,9 @@ async fn worker_service_invokes_every_registration_surface() { ), ) .await; + assert_eq!(tool_outcome.pending_marks.len(), 1); + assert_eq!(tool_outcome.pending_marks[0].name, "worker.tool.execution"); + let tool_exec = tool_outcome.result; assert_json_field(tool_exec.clone(), "next", "tool"); assert_json_field(tool_exec, "phase", "tool_exec"); assert!( @@ -1400,7 +1403,7 @@ impl WorkerPlugin for CancellationPlugin { async move { let _cancelled = CancelledOnDrop(unary_cancelled); unary_started.notify_one(); - std::future::pending::>().await + std::future::pending::>().await } }); @@ -1540,7 +1543,16 @@ impl WorkerPlugin for SurfacePlugin { runtime.pop_scope(&handle, None, None).await?; runtime.drop_scope_stack(&stack_id).await?; let next_value = next.call(value).await?; - Ok(set_json_field(next_value, "phase", "tool_exec")) + Ok(ToolExecutionInterceptOutcome::new(set_json_field( + next_value, + "phase", + "tool_exec", + )) + .with_pending_mark( + PendingMarkSpec::builder() + .name("worker.tool.execution") + .build(), + )) } }); let scope_runtime = runtime.clone(); @@ -1572,7 +1584,7 @@ impl WorkerPlugin for SurfacePlugin { .await?; runtime.pop_scope(&handle, None, None).await?; } - Ok(Json::Null) + Ok(Json::Null.into()) } }); @@ -2251,6 +2263,32 @@ async fn invoke_json(client: &mut PluginWorkerClient, request: InvokeRe nemo_relay_worker_proto::v1::invoke_response::Result::Json(result) => { decode_json_envelope(&result.value.expect("json value")).expect("decode JSON result") } + nemo_relay_worker_proto::v1::invoke_response::Result::ToolExecution(result) => { + decode_json_envelope::( + &result.outcome.expect("tool execution outcome"), + ) + .expect("decode tool execution outcome") + .result + } + other => panic!("unexpected invoke result: {other:?}"), + } +} + +async fn invoke_tool_execution( + client: &mut PluginWorkerClient, + request: InvokeRequest, +) -> ToolExecutionInterceptOutcome { + let response = client + .invoke(Request::new(request)) + .await + .expect("invoke succeeds") + .into_inner(); + match response.result.expect("invoke result") { + nemo_relay_worker_proto::v1::invoke_response::Result::ToolExecution(result) => { + let outcome = result.outcome.expect("tool execution outcome"); + assert_eq!(outcome.schema, "nemo.relay.ToolExecutionInterceptOutcome@1"); + decode_json_envelope(&outcome).expect("decode tool execution outcome") + } other => panic!("unexpected invoke result: {other:?}"), } } diff --git a/examples/rust-native-plugin/src/lib.rs b/examples/rust-native-plugin/src/lib.rs index 4d9130e8..af19ee8f 100644 --- a/examples/rust-native-plugin/src/lib.rs +++ b/examples/rust-native-plugin/src/lib.rs @@ -4,7 +4,7 @@ use nemo_relay_plugin::{ CategoryProfile, ConfigDiagnostic, DiagnosticLevel, Event, EventCategory, Json, LlmJsonStream, LlmRequest, LlmRequestInterceptOutcome, NativePlugin, PendingMarkSpec, PluginContext, - PluginRuntime, ScopeCategory, ScopeType, + PluginRuntime, ScopeCategory, ScopeType, ToolExecutionInterceptOutcome, }; use serde_json::{Map, json}; @@ -215,7 +215,20 @@ impl NativePlugin for ExampleNativePlugin { move |_name, args, next| { let request = tag_json(args, "native_tool_execution_request", &tag); let result = next.call(request)?; - Ok(tag_json(result, "native_tool_execution_response", &tag)) + let result = tag_json(result, "native_tool_execution_response", &tag); + Ok( + ToolExecutionInterceptOutcome::new(result).with_pending_mark( + PendingMarkSpec::builder() + .name("example.native.tool_execution") + .category(EventCategory::custom()) + .category_profile(CategoryProfile { + subtype: Some("example.native.tool_result_rewrite".into()), + ..CategoryProfile::default() + }) + .data(json!({ "tag": &tag })) + .build(), + ), + ) } })?; diff --git a/go/nemo_relay/adaptive_plugin_test.go b/go/nemo_relay/adaptive_plugin_test.go index 7b53a713..025ae70f 100644 --- a/go/nemo_relay/adaptive_plugin_test.go +++ b/go/nemo_relay/adaptive_plugin_test.go @@ -129,12 +129,13 @@ func registerLifecycleInterceptors(ctx *PluginContext, pluginKind string) error if err := ctx.RegisterToolExecutionIntercept( "tool_exec", 7, - func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (json.RawMessage, error) { + func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (ToolExecutionInterceptOutcome, error) { resultJSON, err := next(args) if err != nil { - return nil, err + return ToolExecutionInterceptOutcome{}, err } - return decorateJSONPayload(resultJSON, "goToolExecPlugin", pluginKind) + result, err := decorateJSONPayload(resultJSON, "goToolExecPlugin", pluginKind) + return ToolExecutionInterceptOutcome{Result: result}, err }, ); err != nil { return err @@ -460,8 +461,8 @@ func TestPluginFuncsAndClosedContextBranches(t *testing.T) { }) }}, {"tool execution", func() error { - return closed.RegisterToolExecutionIntercept("tool_exec", 1, func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (json.RawMessage, error) { - return next(args) + return closed.RegisterToolExecutionIntercept("tool_exec", 1, func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (ToolExecutionInterceptOutcome, error) { + return toolExecutionOutcome(next(args)) }) }}, } diff --git a/go/nemo_relay/callbacks.go b/go/nemo_relay/callbacks.go index 856e6fa2..8d771043 100644 --- a/go/nemo_relay/callbacks.go +++ b/go/nemo_relay/callbacks.go @@ -157,8 +157,9 @@ type ToolExecutionFunc func(args json.RawMessage) (json.RawMessage, error) // following the middleware chain pattern. It receives the tool arguments and // a `next` function. Call `next` to invoke the next intercept in the chain // (or the original tool implementation if this is the innermost intercept). -// Skip calling `next` to short-circuit the chain entirely. -type ToolExecutionInterceptFunc func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (json.RawMessage, error) +// Skip calling `next` to short-circuit the chain entirely. The callback returns +// the canonical outcome containing the tool result and any pending marks. +type ToolExecutionInterceptFunc func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (ToolExecutionInterceptOutcome, error) // LLMResponseFunc is a callback that transforms an LLM response. It receives // the response as plain JSON and must return the (possibly modified) response @@ -218,7 +219,7 @@ type LLMRequestDTO struct { Content json.RawMessage `json:"content"` } -// PendingMarkSpec describes a mark Relay emits after starting a managed LLM call. +// PendingMarkSpec describes a mark Relay materializes under a managed lifecycle. type PendingMarkSpec struct { Name string `json:"name"` Category *string `json:"category,omitempty"` @@ -234,6 +235,15 @@ type LLMRequestInterceptOutcome struct { PendingMarks []PendingMarkSpec `json:"pending_marks"` } +// ToolExecutionInterceptOutcome is the canonical result of a tool execution +// intercept. Result is passed to the remaining middleware and application; +// PendingMarks are Relay-owned lifecycle metadata emitted after the tool-end +// event and are not included in the application-visible result. +type ToolExecutionInterceptOutcome struct { + Result json.RawMessage `json:"result"` + PendingMarks []PendingMarkSpec `json:"pending_marks"` +} + // LLMRequestInterceptFunc is a callback for LLM request intercepts. When // annotatedJSON is non-nil, request.Content is read-only, request.Headers may // be changed, and the returned annotation is authoritative for provider body @@ -484,12 +494,20 @@ func goToolExecInterceptTrampoline(userData unsafe.Pointer, argsJSON *C.char, ne defer C.nemo_relay_string_free(result) return json.RawMessage(C.GoString(result)), nil } - result, err := fn(goArgs, goNext) + outcome, err := fn(goArgs, goNext) if err != nil { setLastErrorMessage(err.Error()) return nil } - return C.CString(string(result)) + if outcome.PendingMarks == nil { + outcome.PendingMarks = []PendingMarkSpec{} + } + outcomeJSON, err := jsonMarshal(outcome) + if err != nil { + setLastErrorMessage(err.Error()) + return nil + } + return C.CString(string(outcomeJSON)) } //export goLlmExecInterceptTrampoline diff --git a/go/nemo_relay/callbacks_test.go b/go/nemo_relay/callbacks_test.go index 0da74a49..2655608d 100644 --- a/go/nemo_relay/callbacks_test.go +++ b/go/nemo_relay/callbacks_test.go @@ -8,6 +8,10 @@ import ( "testing" ) +func toolExecutionOutcome(result json.RawMessage, err error) (ToolExecutionInterceptOutcome, error) { + return ToolExecutionInterceptOutcome{Result: result}, err +} + func TestRegisterAndUnregisterClosure(t *testing.T) { fn := ToolExecutionFunc(func(args json.RawMessage) (json.RawMessage, error) { return args, nil diff --git a/go/nemo_relay/deregister_test.go b/go/nemo_relay/deregister_test.go index 437b4764..ab6446e2 100644 --- a/go/nemo_relay/deregister_test.go +++ b/go/nemo_relay/deregister_test.go @@ -135,8 +135,8 @@ func TestRegisterDeregisterReregisterToolRequestIntercept(t *testing.T) { func TestRegisterDeregisterReregisterToolExecutionIntercept(t *testing.T) { name := "go_reregister_exec_int" - fn := func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (json.RawMessage, error) { - return next(args) + fn := func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (ToolExecutionInterceptOutcome, error) { + return toolExecutionOutcome(next(args)) } err := RegisterToolExecutionIntercept(name, 1, fn) @@ -235,8 +235,8 @@ func TestDeregisterAllInterceptTypes(t *testing.T) { func(n string, args json.RawMessage) json.RawMessage { return args }, ) RegisterToolExecutionIntercept("go_dereg_all_exec_int", 1, - func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (json.RawMessage, error) { - return next(args) + func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (ToolExecutionInterceptOutcome, error) { + return toolExecutionOutcome(next(args)) }, ) diff --git a/go/nemo_relay/error_test.go b/go/nemo_relay/error_test.go index ccee08d2..1d32b108 100644 --- a/go/nemo_relay/error_test.go +++ b/go/nemo_relay/error_test.go @@ -89,8 +89,8 @@ func TestAlreadyExistsErrorOnDuplicateToolRequestIntercept(t *testing.T) { func TestAlreadyExistsErrorOnDuplicateToolExecutionIntercept(t *testing.T) { name := "go_err_dup_exec_int" - fn := func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (json.RawMessage, error) { - return next(args) + fn := func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (ToolExecutionInterceptOutcome, error) { + return toolExecutionOutcome(next(args)) } err := RegisterToolExecutionIntercept(name, 1, fn) diff --git a/go/nemo_relay/intercepts/intercepts_test.go b/go/nemo_relay/intercepts/intercepts_test.go index ed16007b..3b298310 100644 --- a/go/nemo_relay/intercepts/intercepts_test.go +++ b/go/nemo_relay/intercepts/intercepts_test.go @@ -47,16 +47,16 @@ func runGlobalToolInterceptShorthandChecks(t *testing.T) { } if err := intercepts.RegisterToolExecution("intercepts_tool_exec", 1, - func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (json.RawMessage, error) { + func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (nemo_relay.ToolExecutionInterceptOutcome, error) { result, err := next(args) if err != nil { - return nil, err + return nemo_relay.ToolExecutionInterceptOutcome{}, err } var payload map[string]interface{} _ = json.Unmarshal(result, &payload) payload["wrapped"] = true out, _ := json.Marshal(payload) - return out, nil + return nemo_relay.ToolExecutionInterceptOutcome{Result: out}, nil }, ); err != nil { t.Fatalf("RegisterToolExecution failed: %v", err) @@ -174,8 +174,9 @@ func runScopeLocalToolInterceptShorthandChecks(t *testing.T, scopeUUID string) { t.Fatalf("ScopeRegisterToolRequest failed: %v", err) } if err := intercepts.ScopeRegisterToolExecution(scopeUUID, "intercepts_scope_tool_exec", 1, - func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (json.RawMessage, error) { - return next(args) + func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (nemo_relay.ToolExecutionInterceptOutcome, error) { + result, err := next(args) + return nemo_relay.ToolExecutionInterceptOutcome{Result: result}, err }, ); err != nil { t.Fatalf("ScopeRegisterToolExecution failed: %v", err) diff --git a/go/nemo_relay/scope_local_test.go b/go/nemo_relay/scope_local_test.go index 9e3e6235..ca9fe3be 100644 --- a/go/nemo_relay/scope_local_test.go +++ b/go/nemo_relay/scope_local_test.go @@ -698,16 +698,16 @@ func TestScopeLocalToolExecutionIntercept(t *testing.T) { stack.Run(func() { handle, _ := PushScope("exec_intercept_scope", ScopeTypeAgent) defer PopScope(handle) - err := ScopeRegisterToolExecutionIntercept(handle.UUID(), "scope_exec_int", 1, func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (json.RawMessage, error) { + err := ScopeRegisterToolExecutionIntercept(handle.UUID(), "scope_exec_int", 1, func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (ToolExecutionInterceptOutcome, error) { result, err := next(args) if err != nil { - return nil, err + return ToolExecutionInterceptOutcome{}, err } var m map[string]interface{} json.Unmarshal(result, &m) m["exec_intercepted"] = true out, _ := json.Marshal(m) - return out, nil + return ToolExecutionInterceptOutcome{Result: out}, nil }) if err != nil { t.Fatalf("ScopeRegisterToolExecutionIntercept failed: %v", err) @@ -1034,9 +1034,9 @@ func assertScopeLocalToolWrappersDeregister(t *testing.T, scopeUUID string) { &executionInterceptCalls, func() error { return ScopeRegisterToolExecutionIntercept(scopeUUID, "tool_scope_exec_int", 1, - func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (json.RawMessage, error) { + func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (ToolExecutionInterceptOutcome, error) { executionInterceptCalls++ - return next(args) + return toolExecutionOutcome(next(args)) }, ) }, diff --git a/go/nemo_relay/tools_test.go b/go/nemo_relay/tools_test.go index 930e589c..cb355c42 100644 --- a/go/nemo_relay/tools_test.go +++ b/go/nemo_relay/tools_test.go @@ -314,8 +314,8 @@ func TestToolRequestInterceptRegisterDeregister(t *testing.T) { func TestToolExecutionInterceptRegisterDeregister(t *testing.T) { err := RegisterToolExecutionIntercept("go_exec_int", 1, - func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (json.RawMessage, error) { - return next(args) + func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (ToolExecutionInterceptOutcome, error) { + return toolExecutionOutcome(next(args)) }, ) if err != nil { @@ -368,9 +368,9 @@ func TestToolRequestInterceptModifiesArgs(t *testing.T) { func TestToolExecutionInterceptReplacesFunc(t *testing.T) { RegisterToolExecutionIntercept("go_exec_replace", 1, - func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (json.RawMessage, error) { + func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (ToolExecutionInterceptOutcome, error) { // Short-circuit: don't call next, return directly - return json.RawMessage(`{"from_intercept": true}`), nil + return ToolExecutionInterceptOutcome{Result: json.RawMessage(`{"from_intercept": true}`)}, nil }, ) @@ -454,16 +454,16 @@ func TestToolFullPipelineInterceptsAndExecute(t *testing.T) { // Register an execution intercept that wraps the callable RegisterToolExecutionIntercept("go_pipe_exec_int", 1, - func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (json.RawMessage, error) { + func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (ToolExecutionInterceptOutcome, error) { result, err := next(args) if err != nil { - return nil, err + return ToolExecutionInterceptOutcome{}, err } var m map[string]interface{} json.Unmarshal(result, &m) m["exec_intercepted"] = true out, _ := json.Marshal(m) - return out, nil + return ToolExecutionInterceptOutcome{Result: out}, nil }, ) defer DeregisterToolExecutionIntercept("go_pipe_exec_int") @@ -668,7 +668,7 @@ func TestToolCallableErrorPropagation(t *testing.T) { func TestToolExecutionInterceptWrapsCallable(t *testing.T) { // Register an execution intercept that modifies args and result RegisterToolExecutionIntercept("go_wrap_exec_int", 1, - func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (json.RawMessage, error) { + func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (ToolExecutionInterceptOutcome, error) { // Before: modify args var m map[string]interface{} json.Unmarshal(args, &m) @@ -678,7 +678,7 @@ func TestToolExecutionInterceptWrapsCallable(t *testing.T) { // Call the next function in the chain result, err := next(modifiedArgs) if err != nil { - return nil, err + return ToolExecutionInterceptOutcome{}, err } // After: modify result @@ -686,7 +686,7 @@ func TestToolExecutionInterceptWrapsCallable(t *testing.T) { json.Unmarshal(result, &out) out["after_exec"] = true final, _ := json.Marshal(out) - return final, nil + return ToolExecutionInterceptOutcome{Result: final}, nil }, ) defer DeregisterToolExecutionIntercept("go_wrap_exec_int") @@ -718,10 +718,115 @@ func TestToolExecutionInterceptWrapsCallable(t *testing.T) { } } +func TestToolExecutionInterceptEmitsPendingMarks(t *testing.T) { + const ( + interceptName = "go_pending_mark_exec" + subscriberName = "go_pending_mark_sub" + toolName = "go_pending_mark_tool" + markName = "go.tool.execution" + ) + + var mu sync.Mutex + events := make([]Event, 0, 3) + if err := RegisterSubscriber(subscriberName, func(event Event) { + mu.Lock() + events = append(events, event) + mu.Unlock() + }); err != nil { + t.Fatalf("RegisterSubscriber failed: %v", err) + } + t.Cleanup(func() { _ = DeregisterSubscriber(subscriberName) }) + + category := "custom" + if err := RegisterToolExecutionIntercept( + interceptName, + 1, + func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (ToolExecutionInterceptOutcome, error) { + result, err := next(args) + if err != nil { + return ToolExecutionInterceptOutcome{}, err + } + return ToolExecutionInterceptOutcome{ + Result: result, + PendingMarks: []PendingMarkSpec{{ + Name: markName, + Category: &category, + CategoryProfile: json.RawMessage(`{"subtype":"go.tool.execution"}`), + Data: json.RawMessage(`{"source":"go"}`), + Metadata: json.RawMessage(`{"fixture":true}`), + }}, + }, nil + }, + ); err != nil { + t.Fatalf(registerFailed, err) + } + t.Cleanup(func() { _ = DeregisterToolExecutionIntercept(interceptName) }) + + result, err := ToolCallExecute( + toolName, + json.RawMessage(`{"value":42}`), + func(args json.RawMessage) (json.RawMessage, error) { return args, nil }, + ) + if err != nil { + t.Fatalf(toolCallExecuteFailed, err) + } + var applicationResult map[string]any + if err := json.Unmarshal(result, &applicationResult); err != nil { + t.Fatalf("decode tool result: %v", err) + } + if applicationResult["value"] != float64(42) { + t.Fatalf("unexpected tool result: %s", result) + } + if _, leaked := applicationResult["pending_marks"]; leaked { + t.Fatalf("pending marks leaked into tool result: %s", result) + } + + if err := FlushSubscribers(); err != nil { + t.Fatalf(toolFlushSubscribersFailed, err) + } + mu.Lock() + captured := append([]Event(nil), events...) + mu.Unlock() + + startIndex, endIndex, markIndex := -1, -1, -1 + var start, mark Event + for index, event := range captured { + switch { + case event.Name() == toolName && event.Kind() == "scope" && event.ScopeCategory() == "start": + startIndex, start = index, event + case event.Name() == toolName && event.Kind() == "scope" && event.ScopeCategory() == "end": + endIndex = index + case event.Name() == markName && event.Kind() == "mark": + markIndex, mark = index, event + } + } + if startIndex < 0 || endIndex < 0 || markIndex < 0 { + t.Fatalf("missing lifecycle events: start=%d end=%d mark=%d", startIndex, endIndex, markIndex) + } + if !(startIndex < endIndex && endIndex < markIndex) { + t.Fatalf("unexpected lifecycle order: start=%d end=%d mark=%d", startIndex, endIndex, markIndex) + } + if mark.ParentUUID() != start.UUID() { + t.Fatalf("mark parent %q does not match tool UUID %q", mark.ParentUUID(), start.UUID()) + } + if mark.Category() != category { + t.Fatalf("mark category = %q, expected %q", mark.Category(), category) + } + assertJSONFieldString(t, mark.CategoryProfile(), "subtype", "go.tool.execution") + assertJSONFieldString(t, mark.Data(), "source", "go") + var metadata map[string]any + if err := json.Unmarshal(mark.Metadata(), &metadata); err != nil { + t.Fatalf("decode mark metadata: %v", err) + } + if metadata["fixture"] != true { + t.Fatalf("unexpected mark metadata: %s", mark.Metadata()) + } +} + func TestToolExecutionInterceptSeesNextError(t *testing.T) { RegisterToolExecutionIntercept("go_wrap_exec_err", 1, - func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (json.RawMessage, error) { - return next(args) + func(args json.RawMessage, next func(json.RawMessage) (json.RawMessage, error)) (ToolExecutionInterceptOutcome, error) { + return toolExecutionOutcome(next(args)) }, ) defer DeregisterToolExecutionIntercept("go_wrap_exec_err") diff --git a/python/nemo_relay/__init__.py b/python/nemo_relay/__init__.py index 285c0a9d..bd18a47d 100644 --- a/python/nemo_relay/__init__.py +++ b/python/nemo_relay/__init__.py @@ -111,6 +111,7 @@ async def main(): ScopeStack, ScopeType, ToolAttributes, + ToolExecutionInterceptOutcome, ToolHandle, ) from nemo_relay._native import create_scope_stack as _create_scope_stack @@ -164,7 +165,7 @@ async def main(): #: next callable. It may await and return ``next(args)`` or short-circuit. ToolExecutionIntercept: TypeAlias = Callable[ [str, Json, Callable[[Json], Awaitable[Json]]], - Json | Awaitable[Json], + ToolExecutionInterceptOutcome | Awaitable[ToolExecutionInterceptOutcome], ] #: Request intercept callback that returns the canonical request, annotation, #: and pending-mark outcome passed to later intercepts and managed execution. @@ -458,6 +459,7 @@ def worker() -> None: "MarkEvent", "ScopeHandle", "ToolHandle", + "ToolExecutionInterceptOutcome", "LLMHandle", "LLMRequest", "LLMRequestInterceptOutcome", diff --git a/python/nemo_relay/__init__.pyi b/python/nemo_relay/__init__.pyi index 2c5a9145..61400d12 100644 --- a/python/nemo_relay/__init__.pyi +++ b/python/nemo_relay/__init__.pyi @@ -108,6 +108,9 @@ from nemo_relay._native import ( from nemo_relay._native import ( ToolAttributes as ToolAttributes, ) +from nemo_relay._native import ( + ToolExecutionInterceptOutcome as ToolExecutionInterceptOutcome, +) from nemo_relay._native import ( ToolHandle as ToolHandle, ) @@ -199,7 +202,7 @@ Return: """ ToolExecutionIntercept: TypeAlias = Callable[ [str, Json, Callable[[Json], Awaitable[Json]]], - Json | Awaitable[Json], + ToolExecutionInterceptOutcome | Awaitable[ToolExecutionInterceptOutcome], ] """Execution intercept callback that wraps tool execution. @@ -207,7 +210,7 @@ Arguments: The tool name, current JSON arguments, and next callable. Return: - A JSON-compatible result, either directly or as an awaitable. + A canonical tool execution outcome, either directly or as an awaitable. Exceptional flow: The callback may short-circuit by not invoking ``next``. Exceptions diff --git a/python/nemo_relay/_native.pyi b/python/nemo_relay/_native.pyi index 6acc400f..81a211cf 100644 --- a/python/nemo_relay/_native.pyi +++ b/python/nemo_relay/_native.pyi @@ -38,7 +38,7 @@ _LlmConditionalExecutionGuardrail: TypeAlias = Callable[["LLMRequest"], Optional _ToolRequestIntercept: TypeAlias = Callable[[str, _Json], _Json] _ToolExecutionIntercept: TypeAlias = Callable[ [str, _Json, Callable[[_Json], Awaitable[_Json]]], - _Json | Awaitable[_Json], + "ToolExecutionInterceptOutcome | Awaitable[ToolExecutionInterceptOutcome]", ] _LlmRequestIntercept: TypeAlias = Callable[ [str, "LLMRequest", "AnnotatedLLMRequest | None"], @@ -383,7 +383,7 @@ class LLMRequest: ... class PendingMarkSpec: - """A runtime-owned mark specification returned by request middleware.""" + """A runtime-owned mark specification returned by lifecycle middleware.""" def __init__( self, name: str, @@ -418,6 +418,23 @@ class LLMRequestInterceptOutcome: @property def pending_marks(self) -> list[PendingMarkSpec]: ... +class ToolExecutionInterceptOutcome: + """Canonical result returned by a tool execution intercept. + + ``result`` is passed to the remaining middleware and application. + ``pending_marks`` are Relay-owned lifecycle metadata emitted after the + tool-end event and are not included in the application-visible result. + """ + def __init__( + self, + result: _Json, + pending_marks: list[PendingMarkSpec] = ..., + ) -> None: ... + @property + def result(self) -> _Json: ... + @property + def pending_marks(self) -> list[PendingMarkSpec]: ... + class AnnotatedLLMRequest: """Structured view of an LLM request produced by a codec. @@ -1745,7 +1762,10 @@ def register_tool_execution_intercept(name: str, priority: int, callable: _ToolE Args: name: Unique intercept name. priority: Execution order; lower values run first. - callable: Middleware callback that may call or short-circuit ``next``. + callable: Middleware callback returning + ``ToolExecutionInterceptOutcome``. It may call or short-circuit + ``next``; ``next`` resolves to the raw downstream result while + Relay retains downstream pending marks. Returns: ``None``. @@ -1975,7 +1995,10 @@ def scope_register_tool_execution_intercept( scope_uuid: UUID of the owning scope. name: Unique intercept name within that scope. priority: Execution order; lower values run first. - callable: Middleware callback used while the owning scope is active. + callable: Middleware callback returning + ``ToolExecutionInterceptOutcome`` while the owning scope is active. + Its ``next`` continuation resolves to the raw downstream result + while Relay retains downstream pending marks. Returns: ``None``. diff --git a/python/nemo_relay/intercepts.py b/python/nemo_relay/intercepts.py index 5392a6c4..0ed3b124 100644 --- a/python/nemo_relay/intercepts.py +++ b/python/nemo_relay/intercepts.py @@ -128,6 +128,7 @@ def register_tool_execution(name: str, priority: int, fn: ToolExecutionIntercept fn: Callable invoked as ``fn(tool_name, args, next_call)``. The callback may await or call ``next_call(args)`` to continue the chain, modify the result, or bypass downstream execution entirely. + It must return ``ToolExecutionInterceptOutcome``. Returns: None: This function returns after the intercept is registered. diff --git a/python/nemo_relay/scope_local.py b/python/nemo_relay/scope_local.py index 909a5bcc..7c1191c5 100644 --- a/python/nemo_relay/scope_local.py +++ b/python/nemo_relay/scope_local.py @@ -271,7 +271,8 @@ def register_tool_execution(scope_handle, name, priority, fn): priority: Execution order for the intercept. Lower values run first. fn: Callable invoked as ``fn(tool_name, args, next_call)``. It may call ``next_call(args)`` to continue execution, modify the result, or - short-circuit the tool call entirely. + short-circuit the tool call entirely. It must return + ``ToolExecutionInterceptOutcome``. Returns: None: This function returns after the scope-local intercept is diff --git a/python/plugin/src/nemo_relay_plugin/__init__.py b/python/plugin/src/nemo_relay_plugin/__init__.py index 676a0de3..7d11e087 100644 --- a/python/plugin/src/nemo_relay_plugin/__init__.py +++ b/python/plugin/src/nemo_relay_plugin/__init__.py @@ -21,8 +21,9 @@ LlmRequest: A Relay LLM request represented as a JSON object. AnnotatedLlmRequest: An annotated Relay LLM request represented as a JSON object. - PendingMarkSpec: A mark Relay emits under the future managed LLM scope. + PendingMarkSpec: A mark Relay emits under its managed lifecycle scope. LlmRequestInterceptOutcome: Canonical LLM request-intercept result. + ToolExecutionInterceptOutcome: Canonical tool execution-intercept result. DiagnosticLevel: Severity of a configuration diagnostic. ConfigDiagnostic: Structured configuration warning or error. ScopeType: Semantic category for a Relay execution scope. @@ -76,6 +77,7 @@ SubscriberCallback, ToolConditionalCallback, ToolExecutionCallback, + ToolExecutionInterceptOutcome, ToolNext, ToolRequestCallback, ToolSanitizeCallback, @@ -107,6 +109,7 @@ "SubscriberCallback", "ToolConditionalCallback", "ToolExecutionCallback", + "ToolExecutionInterceptOutcome", "ToolNext", "ToolRequestCallback", "ToolSanitizeCallback", diff --git a/python/plugin/src/nemo_relay_plugin/_api.py b/python/plugin/src/nemo_relay_plugin/_api.py index 43b1d9ce..615e8d86 100644 --- a/python/plugin/src/nemo_relay_plugin/_api.py +++ b/python/plugin/src/nemo_relay_plugin/_api.py @@ -16,7 +16,7 @@ LlmRequest: A Relay LLM request represented as a JSON object. AnnotatedLlmRequest: An annotated Relay LLM request represented as a JSON object. - PendingMarkSpec: A mark Relay emits under the future managed LLM scope. + PendingMarkSpec: A mark Relay emits under its managed lifecycle scope. LlmRequestInterceptOutcome: Canonical LLM request-intercept result. DiagnosticLevel: Severity of a configuration diagnostic. ConfigDiagnostic: Structured configuration warning or error. @@ -91,6 +91,7 @@ LLM_REQUEST_SCHEMA = "nemo.relay.LlmRequest@1" ANNOTATED_LLM_REQUEST_SCHEMA = "nemo.relay.AnnotatedLlmRequest@1" LLM_REQUEST_INTERCEPT_OUTCOME_SCHEMA = "nemo.relay.LlmRequestInterceptOutcome@1" +TOOL_EXECUTION_INTERCEPT_OUTCOME_SCHEMA = "nemo.relay.ToolExecutionInterceptOutcome@1" PLUGIN_DIAGNOSTICS_SCHEMA = "nemo.relay.PluginDiagnostics@1" _OBJECT_SCHEMAS = frozenset( { @@ -98,6 +99,7 @@ LLM_REQUEST_SCHEMA, ANNOTATED_LLM_REQUEST_SCHEMA, LLM_REQUEST_INTERCEPT_OUTCOME_SCHEMA, + TOOL_EXECUTION_INTERCEPT_OUTCOME_SCHEMA, } ) _UNREGISTERED = object() @@ -183,7 +185,7 @@ def to_json(self) -> dict[str, Any]: @dataclass(slots=True) class PendingMarkSpec: - """Describe a mark Relay emits after starting a managed LLM call.""" + """Describe a mark Relay emits under a managed lifecycle scope.""" name: str category: str | None = None @@ -224,6 +226,28 @@ def to_json(self) -> dict[str, Json]: } +@dataclass(slots=True) +class ToolExecutionInterceptOutcome: + """Canonical result returned by a Python worker tool execution intercept.""" + + result: Json + pending_marks: list[PendingMarkSpec] = field(default_factory=list) + + def to_json(self) -> dict[str, Json]: + """Convert this outcome to the canonical worker-envelope payload.""" + marks = [] + for mark in self.pending_marks: + if not isinstance(mark, PendingMarkSpec): + raise WorkerSdkError( + "tool execution intercept outcome pending_marks must contain PendingMarkSpec values" + ) + marks.append(mark.to_json()) + return { + "result": self.result, + "pending_marks": marks, + } + + def _normalize_diagnostic(value: Mapping[str, Any]) -> dict[str, Any]: try: level = DiagnosticLevel(value.get("level")).value @@ -372,7 +396,10 @@ def register(self, ctx: PluginContext, config: Json) -> None | Awaitable[None]: ToolSanitizeCallback: TypeAlias = Callable[[str, Json], Json | Awaitable[Json]] ToolConditionalCallback: TypeAlias = Callable[[str, Json], str | None | Awaitable[str | None]] ToolRequestCallback: TypeAlias = Callable[[str, Json], Json | Awaitable[Json]] -ToolExecutionCallback: TypeAlias = Callable[[str, Json, "ToolNext"], Json | Awaitable[Json]] +ToolExecutionCallback: TypeAlias = Callable[ + [str, Json, "ToolNext"], + ToolExecutionInterceptOutcome | Awaitable[ToolExecutionInterceptOutcome], +] LlmSanitizeRequestCallback: TypeAlias = Callable[[LlmRequest], LlmRequest | Awaitable[LlmRequest]] LlmSanitizeResponseCallback: TypeAlias = Callable[[Json], Json | Awaitable[Json]] LlmConditionalCallback: TypeAlias = Callable[[LlmRequest], str | None | Awaitable[str | None]] @@ -565,9 +592,9 @@ def register_tool_execution_intercept( Args: name: Component-local registration name. callback: Function receiving ``(tool_name, arguments, next_call)`` - and returning the tool result as JSON, directly or through an - awaitable. It can call :meth:`ToolNext.call` zero, one, or - multiple times while the invocation is active. + and returning :class:`ToolExecutionInterceptOutcome`, directly + or through an awaitable. It can call :meth:`ToolNext.call` + zero, one, or multiple times while the invocation is active. priority: Execution order. Lower values run first. """ self._push_registration(name, pb.TOOL_EXECUTION_INTERCEPT, priority, False) @@ -1437,7 +1464,16 @@ async def _invoke_result(self, request: Any) -> Any: ToolNext(self._runtime, request.continuation_id), ) ) - return _json_response(result) + if not isinstance(result, ToolExecutionInterceptOutcome): + raise WorkerSdkError("tool execution intercept must return ToolExecutionInterceptOutcome") + return pb.InvokeResponse( + tool_execution=pb.ToolExecutionInterceptResult( + outcome=_json_envelope( + TOOL_EXECUTION_INTERCEPT_OUTCOME_SCHEMA, + result.to_json(), + ), + ) + ) if request.surface == pb.LLM_SANITIZE_REQUEST_GUARDRAIL: return _json_response( await _maybe_await( diff --git a/python/tests/plugin/test_worker_sdk.py b/python/tests/plugin/test_worker_sdk.py index e4cf4b54..3f33fbb2 100644 --- a/python/tests/plugin/test_worker_sdk.py +++ b/python/tests/plugin/test_worker_sdk.py @@ -32,6 +32,7 @@ PluginContext, PluginRuntime, ScopeType, + ToolExecutionInterceptOutcome, ToolNext, WorkerPlugin, WorkerSdkError, @@ -44,6 +45,7 @@ JSON_SCHEMA, LLM_REQUEST_INTERCEPT_OUTCOME_SCHEMA, LLM_REQUEST_SCHEMA, + TOOL_EXECUTION_INTERCEPT_OUTCOME_SCHEMA, WORKER_PROTOCOL, _announced_worker_endpoint, _decode_required_envelope, @@ -192,9 +194,12 @@ def tool_block(name: str, value: Json) -> str | None: async def tool_request(name: str, value: Json) -> Json: return _tag(value, f"request_{name}") - async def tool_execution(name: str, value: Json, next_call: ToolNext) -> Json: + async def tool_execution(name: str, value: Json, next_call: ToolNext) -> ToolExecutionInterceptOutcome: result = await next_call.call(_tag(value, f"execute_{name}")) - return _tag(result, "tool_execution") + return ToolExecutionInterceptOutcome( + result=_tag(result, "tool_execution"), + pending_marks=[PendingMarkSpec("worker.tool.execution")], + ) def llm_sanitize_request(request: Json) -> Json: return _tag_llm_request(request, "llm_sanitize_request") @@ -905,9 +910,10 @@ async def test_unary_invoke_success_paths(service: _WorkerService, host_stub: Re tool_request = await _invoke_json_async(service, "tool_request", pb.TOOL_REQUEST_INTERCEPT) assert tool_request["tag"] == "request_lookup" - tool_execution = await _invoke_json_async(service, "tool_execution", pb.TOOL_EXECUTION_INTERCEPT) - assert tool_execution["tag"] == "tool_execution" - assert tool_execution["next_tool"]["tag"] == "execute_lookup" + tool_execution = await _invoke_tool_execution_async(service, "tool_execution") + assert tool_execution["result"]["tag"] == "tool_execution" + assert tool_execution["result"]["next_tool"]["tag"] == "execute_lookup" + assert tool_execution["pending_marks"][0]["name"] == "worker.tool.execution" llm_sanitize_request = await _invoke_json_async( service, @@ -1066,6 +1072,30 @@ def invalid_result(name: str, request: Json, annotated: Json | None) -> Any: assert "must be a JSON object" in response.error.message +async def test_tool_execution_intercept_rejects_legacy_raw_result(): + class LegacyResultPlugin(WorkerPlugin): + plugin_id = "tests.legacy_tool_execution_result" + + def register(self, ctx: PluginContext, config: Json) -> None: + del config + + def legacy_result(name: str, value: Json, next_call: ToolNext) -> Any: + del name, value, next_call + return {"legacy_result": True} + + ctx.register_tool_execution_intercept("legacy", legacy_result) + + service = _service(LegacyResultPlugin(), RecordingHostStub()) + await _register(service) + response = await service.Invoke( + _tool_request("legacy", pb.TOOL_EXECUTION_INTERCEPT, {}), + AbortContext(), + ) + + assert response.WhichOneof("result") == "error" + assert "must return ToolExecutionInterceptOutcome" in response.error.message + + @pytest.mark.parametrize( ("request_factory", "expected_message"), [ @@ -1551,7 +1581,7 @@ class CancelPlugin(WorkerPlugin): def register(self, ctx: PluginContext, config: Json) -> None: del config - async def tool_execution(tool_name: str, value: Json, next_call: ToolNext) -> Json: + async def tool_execution(tool_name: str, value: Json, next_call: ToolNext) -> ToolExecutionInterceptOutcome: del tool_name, value, next_call started.set() try: @@ -1560,6 +1590,7 @@ async def tool_execution(tool_name: str, value: Json, next_call: ToolNext) -> Js cancelled.set() await release.wait() raise + raise AssertionError("unreachable") ctx.register_tool_execution_intercept("cancel", tool_execution) @@ -2251,6 +2282,19 @@ async def _invoke_json_async( return _envelope_value(response.json.value) +async def _invoke_tool_execution_async( + service: _WorkerService, + registration_name: str, +) -> Json: + response = await service.Invoke( + _tool_request(registration_name, pb.TOOL_EXECUTION_INTERCEPT, {"query": "relay"}), + AbortContext(), + ) + assert response.WhichOneof("result") == "tool_execution", response + assert response.tool_execution.outcome.schema == TOOL_EXECUTION_INTERCEPT_OUTCOME_SCHEMA + return _envelope_value(response.tool_execution.outcome) + + def _envelope_value(envelope: Any) -> Json: return json.loads(envelope.json.decode("utf-8")) diff --git a/python/tests/test_scope_local.py b/python/tests/test_scope_local.py index 7746ed48..65440080 100644 --- a/python/tests/test_scope_local.py +++ b/python/tests/test_scope_local.py @@ -20,6 +20,7 @@ MarkEvent, ScopeEvent, ScopeType, + ToolExecutionInterceptOutcome, guardrails, llm, scope, @@ -495,7 +496,7 @@ def my_tool(args): handle, "sl_exec_intercept", 1, - lambda name, args, next_fn: {"from": "intercept"}, + lambda name, args, next_fn: ToolExecutionInterceptOutcome({"from": "intercept"}), ) result = await tools.execute("exec_int_tool", {}, my_tool) @@ -517,7 +518,7 @@ def my_tool(args): def intercept_fn(name, args, next_fn): # Cannot call next_fn here — it returns a Future. args["x"] = args["x"] + 1 - return {"value": args["x"] * 2, "intercepted": True} + return ToolExecutionInterceptOutcome({"value": args["x"] * 2, "intercepted": True}) with scope.scope("exec_next_scope", ScopeType.Agent) as handle: scope_local.register_tool_execution(handle, "sl_exec_next", 1, intercept_fn) @@ -593,7 +594,12 @@ async def stream_intercept(request_inner, next_fn): scope_local.register_tool_request(handle, "sl_tool_req_cov", 1, False, lambda name, args: args) assert scope_local.deregister_tool_request(handle, "sl_tool_req_cov") is True - scope_local.register_tool_execution(handle, "sl_tool_exec_cov", 1, lambda name, args, next_fn: args) + scope_local.register_tool_execution( + handle, + "sl_tool_exec_cov", + 1, + lambda name, args, next_fn: ToolExecutionInterceptOutcome(args), + ) assert scope_local.deregister_tool_execution(handle, "sl_tool_exec_cov") is True scope_local.register_llm_sanitize_request(handle, "sl_llm_req_cov", 1, lambda req: req) diff --git a/python/tests/test_tools.py b/python/tests/test_tools.py index 4fe9f8f2..18a0eaf9 100644 --- a/python/tests/test_tools.py +++ b/python/tests/test_tools.py @@ -8,9 +8,12 @@ import pytest from nemo_relay import ( + MarkEvent, + PendingMarkSpec, ScopeEvent, ScopeType, ToolAttributes, + ToolExecutionInterceptOutcome, ToolHandle, guardrails, intercepts, @@ -274,7 +277,7 @@ def test_execution_intercept_register_deregister(self): intercepts.register_tool_execution( "py_exec_int", 1, - lambda name, args, next: {"intercepted": True}, + lambda name, args, next: ToolExecutionInterceptOutcome({"intercepted": True}), ) assert intercepts.deregister_tool_execution("py_exec_int") @@ -327,7 +330,7 @@ async def test_execution_intercept_replaces_func(self): intercepts.register_tool_execution( "py_exec_replace", 1, - lambda name, args, next: {"from_intercept": True}, + lambda name, args, next: ToolExecutionInterceptOutcome({"from_intercept": True}), ) def original_func(args): @@ -340,12 +343,15 @@ def original_func(args): intercepts.deregister_tool_execution("py_exec_replace") async def test_execution_intercept_can_await_next(self): + events = [] + async def middleware(name, args, next): result = await next({"value": args["value"] + 1}) result["from_intercept"] = True - return result + return ToolExecutionInterceptOutcome(result, [PendingMarkSpec("python.tool.execution")]) intercepts.register_tool_execution("py_exec_next", 1, middleware) + subscribers.register("py_exec_mark_sub", lambda event: events.append(event)) def original(args): return {"value": args["value"] * 2} @@ -353,8 +359,29 @@ def original(args): try: result = await tools.execute("next_tool", {"value": 2}, original) assert result == {"value": 6, "from_intercept": True} + subscribers.flush() + start = _tool_event(events, "next_tool", "start") + end = _tool_event(events, "next_tool", "end") + mark = next( + event for event in events if isinstance(event, MarkEvent) and event.name == "python.tool.execution" + ) + assert mark.parent_uuid == start.uuid + assert events.index(end) < events.index(mark) finally: intercepts.deregister_tool_execution("py_exec_next") + subscribers.deregister("py_exec_mark_sub") + + async def test_execution_intercept_rejects_legacy_raw_result(self): + intercepts.register_tool_execution( + "py_exec_legacy", + 1, + lambda name, args, next: {"legacy_result": True}, # type: ignore[arg-type] # ty: ignore[invalid-argument-type] + ) + try: + with pytest.raises(RuntimeError, match="must return ToolExecutionInterceptOutcome"): + await tools.execute("legacy_tool", {}, lambda args: args) + finally: + intercepts.deregister_tool_execution("py_exec_legacy") async def test_request_intercept_break_chain(self): def first_fn(name, args):