From 67dcdd891795a678f9f72093831a52d6a0602c4f Mon Sep 17 00:00:00 2001 From: Bryan Bednarski Date: Mon, 29 Jun 2026 12:34:29 -0600 Subject: [PATCH 1/9] feat(plugin): support pending marks from LLM intercepts Signed-off-by: Bryan Bednarski --- crates/core/src/api/llm.rs | 56 ++++- crates/core/src/api/registry.rs | 57 ++++- crates/core/src/api/runtime.rs | 6 +- crates/core/src/api/runtime/callbacks.rs | 12 +- crates/core/src/api/runtime/state.rs | 33 +-- crates/core/src/api/shared.rs | 17 +- crates/core/src/context/registries.rs | 4 +- crates/core/src/plugin.rs | 44 +++- crates/core/src/plugin/dynamic/native.rs | 56 +++-- .../tests/fixtures/native_plugin/src/lib.rs | 24 +- .../tests/integration/middleware_tests.rs | 212 +++++++++++++++++- .../tests/integration/native_plugin_tests.rs | 25 +++ crates/core/tests/unit/shared_tests.rs | 40 ++-- crates/plugin/src/lib.rs | 107 ++++++++- crates/plugin/tests/typed_callbacks.rs | 84 ++++++- crates/types/src/api/event.rs | 23 ++ crates/types/src/api/llm.rs | 48 +++- crates/types/tests/serialization_tests.rs | 37 ++- examples/rust-native-plugin/README.md | 11 +- examples/rust-native-plugin/src/lib.rs | 20 +- 20 files changed, 795 insertions(+), 121 deletions(-) diff --git a/crates/core/src/api/llm.rs b/crates/core/src/api/llm.rs index 0af799d04..7910222e5 100644 --- a/crates/core/src/api/llm.rs +++ b/crates/core/src/api/llm.rs @@ -3,12 +3,13 @@ use std::sync::Arc; -use chrono::{DateTime, Utc}; +use chrono::{DateTime, TimeDelta, Utc}; use serde::{Deserialize, Serialize}; use serde_json::json; use typed_builder::TypedBuilder; use uuid::Uuid; +use crate::api::event::{BaseEvent, Event, MarkEvent, PendingMarkSpec}; use crate::api::runtime::NemoRelayContextState; use crate::api::runtime::current_scope_stack; use crate::api::runtime::global_context; @@ -28,7 +29,7 @@ use crate::error::{FlowError, Result}; use crate::json::Json; use crate::stream::LlmStreamWrapper; -pub use nemo_relay_types::api::llm::{LlmAttributes, LlmRequest}; +pub 
use nemo_relay_types::api::llm::{LlmAttributes, LlmRequest, LlmRequestInterceptOutcome}; /// Runtime-owned handle identifying an active or completed LLM call. #[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder)] @@ -298,6 +299,35 @@ fn emit_llm_start( Ok(()) } +fn emit_pending_request_marks(handle: &LlmHandle, marks: Vec) -> Result<()> { + if marks.is_empty() { + return Ok(()); + } + ensure_runtime_owner()?; + let subscribers = { + let scope_stack = current_scope_stack(); + let scope_guard = scope_stack.read().expect("scope stack lock poisoned"); + snapshot_event_subscribers(scope_guard.collect_scope_local_subscribers())? + }; + for (index, mark) in marks.into_iter().enumerate() { + let timestamp = handle.started_at + + TimeDelta::microseconds(i64::try_from(index).unwrap_or_default() + 1); + let event = Event::Mark(MarkEvent::new( + BaseEvent::builder() + .name(mark.name) + .parent_uuid(handle.uuid) + .timestamp(timestamp) + .data_opt(mark.data) + .metadata_opt(mark.metadata) + .build(), + mark.category, + mark.category_profile, + )); + NemoRelayContextState::emit_event(&event, &subscribers); + } + Ok(()) +} + /// Start a manual LLM lifecycle span. 
/// /// This emits an LLM-start event after applying sanitize-request guardrails to @@ -587,7 +617,7 @@ pub async fn llm_call_execute(params: LlmCallExecuteParams) -> Result { } let request_codec = codec.clone(); - let (intercepted_request, annotated_request) = + let (intercepted_request, annotated_request, pending_marks) = run_request_intercepts_with_codec(&name, request, codec)?; let handle = create_llm_handle( @@ -606,6 +636,7 @@ pub async fn llm_call_execute(params: LlmCallExecuteParams) -> Result { annotated_request.clone(), request_codec.as_deref(), )?; + emit_pending_request_marks(&handle, pending_marks)?; let execution = { let scope_stack = current_scope_stack(); @@ -743,7 +774,7 @@ pub async fn llm_stream_call_execute(params: LlmStreamCallExecuteParams) -> Resu } let request_codec = codec.clone(); - let (intercepted_request, annotated_request) = + let (intercepted_request, annotated_request, pending_marks) = run_request_intercepts_with_codec(&name, request, codec)?; let handle = create_llm_handle( @@ -762,6 +793,7 @@ pub async fn llm_stream_call_execute(params: LlmStreamCallExecuteParams) -> Resu annotated_request, request_codec.as_deref(), )?; + emit_pending_request_marks(&handle, pending_marks)?; let execution = { let scope_stack = current_scope_stack(); @@ -818,6 +850,17 @@ pub async fn llm_stream_call_execute(params: LlmStreamCallExecuteParams) -> Resu /// Conditional guardrails, codecs, and execution intercepts are not run by /// this helper. pub fn llm_request_intercepts(name: &str, request: LlmRequest) -> Result { + Ok(llm_request_intercepts_with_marks(name, request)?.request) +} + +/// Run the LLM request-intercept chain and return pending lifecycle marks. +/// +/// This helper does not emit the returned marks because it does not own an LLM +/// lifecycle. Callers must attach them to the lifecycle they own. 
+pub fn llm_request_intercepts_with_marks( + name: &str, + request: LlmRequest, +) -> Result { ensure_runtime_owner()?; let entries = { let scope_stack = current_scope_stack(); @@ -830,10 +873,7 @@ pub fn llm_request_intercepts(name: &str, request: LlmRequest) -> Result Result<()> { + register_llm_request_intercept_with_marks( + name, + priority, + break_chain, + std::sync::Arc::new(move |name, request, annotated| { + callable(name, request, annotated).map(Into::into) + }), + ) +} global_execution_registry_api!( /// Register a global LLM execution intercept. /// Execution intercepts can wrap or replace the non-streaming provider @@ -653,15 +669,32 @@ scope_guardrail_registry_api!( LlmConditionalFn ); scope_intercept_registry_api!( - /// Register a scope-local LLM request intercept. - /// Request intercepts can rewrite or annotate LLM requests inside the - /// owning scope. - scope_register_llm_request_intercept, + /// Register a scope-local LLM request intercept that can schedule lifecycle marks. + scope_register_llm_request_intercept_with_marks, /// Deregister a scope-local LLM request intercept. scope_deregister_llm_request_intercept, llm_request_intercepts, - LlmRequestInterceptFn + LlmRequestInterceptWithMarksFn ); + +/// Register a scope-local LLM request intercept without pending marks. +pub fn scope_register_llm_request_intercept( + scope_uuid: &uuid::Uuid, + name: &str, + priority: i32, + break_chain: bool, + callable: LlmRequestInterceptFn, +) -> Result<()> { + scope_register_llm_request_intercept_with_marks( + scope_uuid, + name, + priority, + break_chain, + std::sync::Arc::new(move |name, request, annotated| { + callable(name, request, annotated).map(Into::into) + }), + ) +} scope_execution_registry_api!( /// Register a scope-local LLM execution intercept. 
/// Execution intercepts can wrap or replace the non-streaming provider diff --git a/crates/core/src/api/runtime.rs b/crates/core/src/api/runtime.rs index 2351ae352..4d0c372c4 100644 --- a/crates/core/src/api/runtime.rs +++ b/crates/core/src/api/runtime.rs @@ -11,9 +11,9 @@ pub mod subscriber_dispatcher; pub use callbacks::{ EventSubscriberFn, LlmCollectorFn, LlmConditionalFn, LlmExecutionFn, LlmExecutionNextFn, - LlmFinalizerFn, LlmJsonStream, LlmRequestInterceptFn, LlmSanitizeRequestFn, - LlmSanitizeResponseFn, LlmStreamExecutionFn, LlmStreamExecutionNextFn, ToolConditionalFn, - ToolExecutionFn, ToolExecutionNextFn, ToolInterceptFn, ToolSanitizeFn, + LlmFinalizerFn, LlmJsonStream, LlmRequestInterceptFn, LlmRequestInterceptWithMarksFn, + LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionFn, LlmStreamExecutionNextFn, + ToolConditionalFn, ToolExecutionFn, ToolExecutionNextFn, ToolInterceptFn, ToolSanitizeFn, }; pub use global::global_context; pub use scope_stack::{ diff --git a/crates/core/src/api/runtime/callbacks.rs b/crates/core/src/api/runtime/callbacks.rs index 980c47d74..dfccd2fea 100644 --- a/crates/core/src/api/runtime/callbacks.rs +++ b/crates/core/src/api/runtime/callbacks.rs @@ -15,7 +15,7 @@ use std::sync::Arc; use tokio_stream::Stream; use crate::api::event::Event; -use crate::api::llm::LlmRequest; +use crate::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; use crate::codec::request::AnnotatedLlmRequest; use crate::error::Result; use crate::json::Json; @@ -177,6 +177,16 @@ pub type LlmRequestInterceptFn = Arc< + Send + Sync, >; +/// Rewrite or annotate an LLM request and schedule marks under its future scope. +/// +/// This callback has the same inputs as [`LlmRequestInterceptFn`] but returns a +/// structured outcome whose pending marks are emitted after the LLM-start +/// event and before provider execution. 
+pub type LlmRequestInterceptWithMarksFn = Arc< + dyn Fn(&str, LlmRequest, Option) -> Result + + Send + + Sync, +>; /// Continuation type invoked by non-streaming LLM execution intercepts. /// /// Execution intercepts use this callable to continue the non-streaming LLM diff --git a/crates/core/src/api/runtime/state.rs b/crates/core/src/api/runtime/state.rs index 70276786d..ca03d3b29 100644 --- a/crates/core/src/api/runtime/state.rs +++ b/crates/core/src/api/runtime/state.rs @@ -20,10 +20,10 @@ use crate::api::llm::{CreateLlmHandleParams, EndLlmHandleParams}; use crate::api::llm::{LlmHandle, LlmRequest}; use crate::api::registry::{ExecutionIntercept, Guardrail, Intercept}; use crate::api::runtime::callbacks::{ - EventSubscriberFn, LlmConditionalFn, LlmExecutionFn, LlmExecutionNextFn, LlmRequestInterceptFn, - LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionFn, LlmStreamExecutionNextFn, - LlmStreamExecutionRegistryRefs, ToolConditionalFn, ToolExecutionFn, ToolExecutionNextFn, - ToolInterceptFn, ToolSanitizeFn, + EventSubscriberFn, LlmConditionalFn, LlmExecutionFn, LlmExecutionNextFn, + LlmRequestInterceptWithMarksFn, LlmSanitizeRequestFn, LlmSanitizeResponseFn, + LlmStreamExecutionFn, LlmStreamExecutionNextFn, LlmStreamExecutionRegistryRefs, + ToolConditionalFn, ToolExecutionFn, ToolExecutionNextFn, ToolInterceptFn, ToolSanitizeFn, }; use crate::api::runtime::subscriber_dispatcher; use crate::api::scope::{CreateScopeHandleParams, EndScopeHandleParams, ScopeHandle, ScopeType}; @@ -63,7 +63,7 @@ pub struct NemoRelayContextState { /// Global LLM guardrails that can reject execution before the provider callback runs. pub(crate) llm_conditional_execution_guardrails: SortedRegistry>, /// Global LLM request intercepts that can rewrite or annotate requests. - pub(crate) llm_request_intercepts: SortedRegistry>, + pub(crate) llm_request_intercepts: SortedRegistry>, /// Global non-streaming LLM execution intercepts that wrap callback execution. 
pub(crate) llm_execution_intercepts: SortedRegistry>, /// Global streaming LLM execution intercepts that wrap stream-producing callbacks. @@ -1011,8 +1011,8 @@ impl NemoRelayContextState { /// are released. pub(crate) fn llm_request_intercept_entries( &self, - scope_locals: &[&SortedRegistry>], - ) -> Vec> { + scope_locals: &[&SortedRegistry>], + ) -> Vec> { merge_intercept_entries(&self.llm_request_intercepts, scope_locals) .into_iter() .cloned() @@ -1041,20 +1041,25 @@ impl NemoRelayContextState { name: &str, request: LlmRequest, annotated: Option, - entries: &[Intercept], - ) -> crate::error::Result<(LlmRequest, Option)> { + entries: &[Intercept], + ) -> crate::error::Result { let mut request_value = request; let mut annotated_value = annotated; + let mut pending_marks = Vec::new(); for entry in entries { - let (new_request, new_annotated) = - (entry.payload.callable)(name, request_value, annotated_value)?; - request_value = new_request; - annotated_value = new_annotated; + let outcome = (entry.payload.callable)(name, request_value, annotated_value)?; + request_value = outcome.request; + annotated_value = outcome.annotated_request; + pending_marks.extend(outcome.pending_marks); if entry.payload.break_chain { break; } } - Ok((request_value, annotated_value)) + Ok(crate::api::llm::LlmRequestInterceptOutcome { + request: request_value, + annotated_request: annotated_value, + pending_marks, + }) } /// Build the composed non-streaming LLM execution continuation chain. 
diff --git a/crates/core/src/api/shared.rs b/crates/core/src/api/shared.rs index 861cd41c2..9d6dff8c2 100644 --- a/crates/core/src/api/shared.rs +++ b/crates/core/src/api/shared.rs @@ -74,7 +74,11 @@ pub(crate) fn run_request_intercepts_with_codec( name: &str, request: LlmRequest, codec: Option>, -) -> Result<(LlmRequest, Option>)> { +) -> Result<( + LlmRequest, + Option>, + Vec, +)> { let original = request.clone(); let annotated = match &codec { Some(codec) => Some(codec.decode(&request)?), @@ -94,18 +98,19 @@ pub(crate) fn run_request_intercepts_with_codec( state.llm_request_intercept_entries(&scope_locals) }; - let (intercepted_request, intercepted_annotated) = + let outcome = crate::api::runtime::NemoRelayContextState::llm_request_intercepts_snapshot_chain( name, request, annotated, &entries, )?; + let pending_marks = outcome.pending_marks; - match (codec, intercepted_annotated) { + match (codec, outcome.annotated_request) { (Some(codec), Some(annotated)) => { let mut encoded = codec.encode(&annotated, &original)?; - encoded.headers = intercepted_request.headers; - Ok((encoded, Some(Arc::new(annotated)))) + encoded.headers = outcome.request.headers; + Ok((encoded, Some(Arc::new(annotated)), pending_marks)) } - _ => Ok((intercepted_request, None)), + _ => Ok((outcome.request, None, pending_marks)), } } diff --git a/crates/core/src/context/registries.rs b/crates/core/src/context/registries.rs index 2a0d2fde9..c6781eee0 100644 --- a/crates/core/src/context/registries.rs +++ b/crates/core/src/context/registries.rs @@ -11,7 +11,7 @@ use std::collections::HashMap; use crate::api::registry::{ExecutionIntercept, Guardrail, Intercept}; use crate::api::runtime::{ - EventSubscriberFn, LlmConditionalFn, LlmExecutionFn, LlmRequestInterceptFn, + EventSubscriberFn, LlmConditionalFn, LlmExecutionFn, LlmRequestInterceptWithMarksFn, LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionFn, ToolConditionalFn, ToolExecutionFn, ToolInterceptFn, ToolSanitizeFn, }; @@ -41,7 
+41,7 @@ pub(crate) struct ScopeLocalRegistries { /// LLM guardrails that can reject execution before the provider callback runs. pub(crate) llm_conditional_execution_guardrails: SortedRegistry>, /// LLM request intercepts that can rewrite or annotate requests. - pub(crate) llm_request_intercepts: SortedRegistry>, + pub(crate) llm_request_intercepts: SortedRegistry>, /// Non-streaming LLM execution intercepts that wrap callback execution. pub(crate) llm_execution_intercepts: SortedRegistry>, /// Streaming LLM execution intercepts that wrap stream-producing callbacks. diff --git a/crates/core/src/plugin.rs b/crates/core/src/plugin.rs index 94420f6df..c2934e3d0 100644 --- a/crates/core/src/plugin.rs +++ b/crates/core/src/plugin.rs @@ -27,15 +27,16 @@ use crate::api::registry::{ deregister_tool_request_intercept, deregister_tool_sanitize_request_guardrail, deregister_tool_sanitize_response_guardrail, register_llm_conditional_execution_guardrail, register_llm_execution_intercept, register_llm_request_intercept, - register_llm_sanitize_request_guardrail, register_llm_sanitize_response_guardrail, - register_llm_stream_execution_intercept, register_tool_conditional_execution_guardrail, - register_tool_execution_intercept, register_tool_request_intercept, - register_tool_sanitize_request_guardrail, register_tool_sanitize_response_guardrail, + register_llm_request_intercept_with_marks, register_llm_sanitize_request_guardrail, + register_llm_sanitize_response_guardrail, register_llm_stream_execution_intercept, + register_tool_conditional_execution_guardrail, register_tool_execution_intercept, + register_tool_request_intercept, register_tool_sanitize_request_guardrail, + register_tool_sanitize_response_guardrail, }; use crate::api::runtime::{ EventSubscriberFn, LlmConditionalFn, LlmExecutionFn, LlmRequestInterceptFn, - LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionFn, ToolConditionalFn, - ToolExecutionFn, ToolInterceptFn, ToolSanitizeFn, + 
LlmRequestInterceptWithMarksFn, LlmSanitizeRequestFn, LlmSanitizeResponseFn, + LlmStreamExecutionFn, ToolConditionalFn, ToolExecutionFn, ToolInterceptFn, ToolSanitizeFn, }; use crate::api::subscriber::{deregister_subscriber, register_subscriber}; pub use nemo_relay_types::plugin::{ConfigDiagnostic, DiagnosticLevel}; @@ -350,6 +351,37 @@ impl PluginRegistrationContext { Ok(()) } + /// Registers an LLM request intercept that can schedule lifecycle marks. + pub fn register_llm_request_intercept_with_marks( + &mut self, + name: &str, + priority: i32, + break_chain: bool, + callback: LlmRequestInterceptWithMarksFn, + ) -> Result<()> { + let qualified_name = self.qualify_name(name); + register_llm_request_intercept_with_marks(&qualified_name, priority, break_chain, callback) + .map_err(|err| { + PluginError::RegistrationFailed(format!("llm request intercept: {err}")) + })?; + + let name_owned = qualified_name; + self.registrations.push(PluginRegistration::new( + "plugin", + name_owned.clone(), + Box::new(move || { + deregister_llm_request_intercept(&name_owned) + .map(|_| ()) + .map_err(|err| { + PluginError::RegistrationFailed(format!( + "llm request intercept deregistration failed: {err}" + )) + }) + }), + )); + Ok(()) + } + /// Registers a tool sanitize-request guardrail and records its rollback closure. 
pub fn register_tool_sanitize_request_guardrail( &mut self, diff --git a/crates/core/src/plugin/dynamic/native.rs b/crates/core/src/plugin/dynamic/native.rs index 0f1c14444..c541ba6d5 100644 --- a/crates/core/src/plugin/dynamic/native.rs +++ b/crates/core/src/plugin/dynamic/native.rs @@ -26,18 +26,19 @@ use nemo_relay_plugin::{ NemoRelayNativeWithScopeStackCb, NemoRelayStatus, }; use semver::{Version, VersionReq}; +use serde::Deserialize; use serde_json::{Map, Value as Json}; use sha2::{Digest, Sha256}; use tokio::runtime::Runtime; use tokio_stream::{Stream, StreamExt}; -use crate::api::event::Event; -use crate::api::llm::LlmRequest; +use crate::api::event::{Event, PendingMarkSpec}; +use crate::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; use crate::api::runtime::{ EventSubscriberFn, LlmConditionalFn, LlmExecutionFn, LlmExecutionNextFn, LlmJsonStream, - LlmRequestInterceptFn, LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionFn, - LlmStreamExecutionNextFn, ToolConditionalFn, ToolExecutionFn, ToolExecutionNextFn, - ToolInterceptFn, ToolSanitizeFn, + LlmRequestInterceptWithMarksFn, LlmSanitizeRequestFn, LlmSanitizeResponseFn, + LlmStreamExecutionFn, LlmStreamExecutionNextFn, ToolConditionalFn, ToolExecutionFn, + ToolExecutionNextFn, ToolInterceptFn, ToolSanitizeFn, }; use crate::api::runtime::{ ScopeStackHandle, ThreadScopeStackBinding, capture_thread_scope_stack, create_scope_stack, @@ -1332,7 +1333,7 @@ unsafe extern "C" fn native_plugin_context_register_llm_request_intercept( Ok(name) => name, Err(status) => return status, }; - match ctx.register_llm_request_intercept( + match ctx.register_llm_request_intercept_with_marks( &name, priority, break_chain, @@ -1690,7 +1691,7 @@ fn wrap_llm_request_intercept_fn( cb: NemoRelayNativeLlmRequestInterceptCb, user_data: *mut c_void, free_fn: NemoRelayNativeFreeFn, -) -> LlmRequestInterceptFn { +) -> LlmRequestInterceptWithMarksFn { let user_data = make_user_data(instance, user_data, free_fn); 
Arc::new(move |name, request, annotated| { clear_native_last_error(); @@ -1756,17 +1757,42 @@ fn wrap_llm_request_intercept_fn( let annotated_json = annotated_json?; let request: LlmRequest = serde_json::from_value(request_json) .map_err(|err| FlowError::Internal(format!("invalid LLM request JSON: {err}")))?; - let annotated = annotated_json - .map(|annotated_json| { - serde_json::from_value::(annotated_json).map_err(|err| { - FlowError::Internal(format!("invalid annotated request JSON: {err}")) - }) - }) - .transpose()?; - Ok((request, annotated)) + let (annotated_request, pending_marks) = match annotated_json { + Some(value) + if value.get(nemo_relay_types::api::llm::NATIVE_LLM_INTERCEPT_OUTCOME_FIELD) + == Some(&Json::Bool(true)) => + { + let metadata: NativeLlmRequestInterceptOutcome = serde_json::from_value(value) + .map_err(|err| { + FlowError::Internal(format!("invalid marked LLM outcome JSON: {err}")) + })?; + (metadata.annotated_request, metadata.pending_marks) + } + Some(value) => { + let annotated = + serde_json::from_value::(value).map_err(|err| { + FlowError::Internal(format!("invalid annotated request JSON: {err}")) + })?; + (Some(annotated), Vec::new()) + } + None => (None, Vec::new()), + }; + Ok(LlmRequestInterceptOutcome { + request, + annotated_request, + pending_marks, + }) }) } +#[derive(Deserialize)] +struct NativeLlmRequestInterceptOutcome { + #[serde(rename = "__nemo_relay_llm_intercept_outcome")] + _marked_outcome: bool, + annotated_request: Option, + pending_marks: Vec, +} + fn wrap_llm_execution_fn( instance: Arc, cb: NemoRelayNativeLlmExecutionCb, diff --git a/crates/core/tests/fixtures/native_plugin/src/lib.rs b/crates/core/tests/fixtures/native_plugin/src/lib.rs index 9fc58bd41..4eeebd5ba 100644 --- a/crates/core/tests/fixtures/native_plugin/src/lib.rs +++ b/crates/core/tests/fixtures/native_plugin/src/lib.rs @@ -5,10 +5,10 @@ use std::ffi::c_void; use std::ptr; use nemo_relay_plugin::{ - ConfigDiagnostic, DiagnosticLevel, Event, Json, 
LlmJsonStream, LlmRequest, - NemoRelayNativeHostApiV1, NemoRelayNativePluginContext, NemoRelayNativePluginV1, - NemoRelayNativeString, NemoRelayStatus, NativePlugin, PluginContext, PluginRuntime, - ScopeCategory, ScopeType, + CategoryProfile, ConfigDiagnostic, DiagnosticLevel, Event, EventCategory, Json, LlmJsonStream, + LlmRequest, LlmRequestInterceptOutcome, NemoRelayNativeHostApiV1, + NemoRelayNativePluginContext, NemoRelayNativePluginV1, NemoRelayNativeString, NemoRelayStatus, + NativePlugin, PendingMarkSpec, PluginContext, PluginRuntime, ScopeCategory, ScopeType, }; use serde_json::{Map, json}; @@ -114,14 +114,26 @@ impl NativePlugin for FixtureNativePlugin { 0, |_request| Ok(None), )?; - ctx.register_llm_request_intercept( + ctx.register_llm_request_intercept_with_marks( "fixture_llm_request_intercept", 0, false, |_name, request, annotated| { - Ok(( + Ok(LlmRequestInterceptOutcome::new( mark_llm_request(request, "native_plugin_llm_request_intercept"), annotated, + ) + .with_pending_mark( + PendingMarkSpec::builder() + .name("fixture.native.llm_request.mark") + .category(EventCategory::custom()) + .category_profile(CategoryProfile { + subtype: Some("fixture.native.pending".into()), + ..CategoryProfile::default() + }) + .data(json!({ "source": "native_request_intercept" })) + .metadata(json!({ "fixture": true })) + .build(), )) }, )?; diff --git a/crates/core/tests/integration/middleware_tests.rs b/crates/core/tests/integration/middleware_tests.rs index 612c5a83d..0394f6c10 100644 --- a/crates/core/tests/integration/middleware_tests.rs +++ b/crates/core/tests/integration/middleware_tests.rs @@ -14,12 +14,14 @@ use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; use std::sync::{Arc, Mutex}; use futures::StreamExt; -use nemo_relay::api::event::{Event, ScopeCategory}; -use nemo_relay::api::llm::LlmRequest; +use nemo_relay::api::event::{ + CategoryProfile, Event, EventCategory, PendingMarkSpec, ScopeCategory, +}; use nemo_relay::api::llm::{ 
LlmCallExecuteParams, LlmStreamCallExecuteParams, llm_call_execute, llm_request_intercepts, - llm_stream_call_execute, + llm_request_intercepts_with_marks, llm_stream_call_execute, }; +use nemo_relay::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; use nemo_relay::api::registry::{ deregister_llm_conditional_execution_guardrail, deregister_llm_execution_intercept, deregister_llm_request_intercept, deregister_llm_sanitize_request_guardrail, @@ -28,13 +30,14 @@ use nemo_relay::api::registry::{ deregister_tool_request_intercept, deregister_tool_sanitize_request_guardrail, deregister_tool_sanitize_response_guardrail, register_llm_conditional_execution_guardrail, register_llm_execution_intercept, register_llm_request_intercept, - register_llm_sanitize_request_guardrail, register_llm_sanitize_response_guardrail, - register_llm_stream_execution_intercept, register_tool_conditional_execution_guardrail, - register_tool_execution_intercept, register_tool_request_intercept, - register_tool_sanitize_request_guardrail, register_tool_sanitize_response_guardrail, - scope_register_llm_conditional_execution_guardrail, scope_register_llm_execution_intercept, - scope_register_llm_request_intercept, scope_register_llm_sanitize_request_guardrail, - scope_register_llm_sanitize_response_guardrail, scope_register_llm_stream_execution_intercept, + register_llm_request_intercept_with_marks, register_llm_sanitize_request_guardrail, + register_llm_sanitize_response_guardrail, register_llm_stream_execution_intercept, + register_tool_conditional_execution_guardrail, register_tool_execution_intercept, + register_tool_request_intercept, register_tool_sanitize_request_guardrail, + register_tool_sanitize_response_guardrail, scope_register_llm_conditional_execution_guardrail, + scope_register_llm_execution_intercept, scope_register_llm_request_intercept, + scope_register_llm_sanitize_request_guardrail, scope_register_llm_sanitize_response_guardrail, + 
scope_register_llm_stream_execution_intercept, scope_register_tool_conditional_execution_guardrail, scope_register_tool_execution_intercept, scope_register_tool_request_intercept, scope_register_tool_sanitize_request_guardrail, scope_register_tool_sanitize_response_guardrail, @@ -2588,6 +2591,195 @@ async fn test_llm_request_intercept_transforms() { deregister_llm_request_intercept("llm_req_i").unwrap(); } +#[test] +fn test_llm_request_intercept_pending_marks_preserve_order_and_break_chain() { + let _lock = TEST_MUTEX.lock().unwrap(); + reset_global(); + setup_isolated_thread(); + + for (name, priority, break_chain, mark_name) in [ + ("pending_first", 1, false, "first"), + ("pending_break", 2, true, "second"), + ("pending_skipped", 3, false, "skipped"), + ] { + register_llm_request_intercept_with_marks( + name, + priority, + break_chain, + Arc::new(move |_name, request, annotated| { + Ok(LlmRequestInterceptOutcome::new(request, annotated) + .with_pending_mark(PendingMarkSpec::builder().name(mark_name).build())) + }), + ) + .unwrap(); + } + + let outcome = llm_request_intercepts_with_marks( + "llm", + LlmRequest { + headers: serde_json::Map::new(), + content: json!({"prompt": "hello"}), + }, + ) + .unwrap(); + + assert_eq!( + outcome + .pending_marks + .iter() + .map(|mark| mark.name.as_str()) + .collect::>(), + ["first", "second"] + ); + assert_eq!(outcome.request.content["prompt"], "hello"); + + for name in ["pending_first", "pending_break", "pending_skipped"] { + deregister_llm_request_intercept(name).unwrap(); + } +} + +#[tokio::test] +async fn test_managed_llm_emits_pending_marks_under_started_scope() { + let _lock = TEST_MUTEX.lock().unwrap(); + reset_global(); + setup_isolated_thread(); + + let events = Arc::new(Mutex::new(Vec::::new())); + let captured = events.clone(); + register_subscriber( + "pending_mark_observer", + Arc::new(move |event: &Event| captured.lock().unwrap().push(event.clone())), + ) + .unwrap(); + + 
register_llm_request_intercept_with_marks( + "pending_managed", + 1, + false, + Arc::new(|_name, request, annotated| { + Ok( + LlmRequestInterceptOutcome::new(request, annotated).with_pending_mark( + PendingMarkSpec::builder() + .name("request.optimized") + .category(EventCategory::custom()) + .category_profile( + CategoryProfile::builder() + .subtype("optimizer.saved_tokens") + .build(), + ) + .data(json!({"saved_tokens": 12})) + .build(), + ), + ) + }), + ) + .unwrap(); + + let provider_request = Arc::new(Mutex::new(None::)); + let captured_request = provider_request.clone(); + llm_call_execute( + LlmCallExecuteParams::builder() + .name("pending-managed-llm") + .request(LlmRequest { + headers: serde_json::Map::new(), + content: json!({"prompt": "hello"}), + }) + .func(Arc::new(move |request| { + *captured_request.lock().unwrap() = Some(request); + Box::pin(async { Ok(json!({"response": "done"})) }) + })) + .build(), + ) + .await + .unwrap(); + + let provider_request = provider_request.lock().unwrap().clone().unwrap(); + assert!( + serde_json::to_value(provider_request) + .unwrap() + .get("__nemo_relay_llm_intercept_outcome") + .is_none() + ); + + let captured = captured_events_snapshot(&events); + let start = captured + .iter() + .find(|event| { + event.name() == "pending-managed-llm" + && event.scope_category() == Some(ScopeCategory::Start) + }) + .unwrap(); + let mark = captured + .iter() + .find(|event| event.name() == "request.optimized") + .unwrap(); + assert_eq!(mark.parent_uuid(), Some(start.uuid())); + assert!(mark.timestamp() > start.timestamp()); + assert_eq!(mark.data().unwrap()["saved_tokens"], 12); + + deregister_llm_request_intercept("pending_managed").unwrap(); + deregister_subscriber("pending_mark_observer").unwrap(); +} + +#[tokio::test] +async fn test_failed_request_intercept_does_not_emit_pending_marks_or_start_scope() { + let _lock = TEST_MUTEX.lock().unwrap(); + reset_global(); + setup_isolated_thread(); + + let events = 
Arc::new(Mutex::new(Vec::::new())); + let captured = events.clone(); + register_subscriber( + "failed_pending_mark_observer", + Arc::new(move |event: &Event| captured.lock().unwrap().push(event.clone())), + ) + .unwrap(); + register_llm_request_intercept_with_marks( + "pending_before_failure", + 1, + false, + Arc::new(|_name, request, annotated| { + Ok(LlmRequestInterceptOutcome::new(request, annotated) + .with_pending_mark(PendingMarkSpec::builder().name("must.not.emit").build())) + }), + ) + .unwrap(); + register_llm_request_intercept_with_marks( + "pending_failure", + 2, + false, + Arc::new(|_name, _request, _annotated| { + Err(FlowError::Internal("request intercept failed".into())) + }), + ) + .unwrap(); + + let provider_called = Arc::new(AtomicBool::new(false)); + let called = provider_called.clone(); + let result = llm_call_execute( + LlmCallExecuteParams::builder() + .name("failed-pending-llm") + .request(LlmRequest { + headers: serde_json::Map::new(), + content: json!({"prompt": "hello"}), + }) + .func(Arc::new(move |_request| { + called.store(true, Ordering::SeqCst); + Box::pin(async { Ok(json!({"response": "unexpected"})) }) + })) + .build(), + ) + .await; + + assert!(result.is_err()); + assert!(!provider_called.load(Ordering::SeqCst)); + assert!(captured_events_snapshot(&events).is_empty()); + + deregister_llm_request_intercept("pending_before_failure").unwrap(); + deregister_llm_request_intercept("pending_failure").unwrap(); + deregister_subscriber("failed_pending_mark_observer").unwrap(); +} + /// LLM execution intercept middleware chain with next(). 
#[tokio::test] async fn test_llm_execution_intercept_chain() { diff --git a/crates/core/tests/integration/native_plugin_tests.rs b/crates/core/tests/integration/native_plugin_tests.rs index c209a553e..a41aa31a1 100644 --- a/crates/core/tests/integration/native_plugin_tests.rs +++ b/crates/core/tests/integration/native_plugin_tests.rs @@ -342,6 +342,24 @@ async fn sdk_cdylib_registers_tool_request_intercept() { llm_start.input().unwrap()["content"]["native_plugin_llm_request_intercept"], true ); + let pending_mark = find_event(&managed_llm_events, "fixture.native.llm_request.mark", None); + assert_eq!(pending_mark.parent_uuid(), Some(llm_start.uuid())); + assert_eq!( + pending_mark.category().map(|category| category.as_str()), + Some("custom") + ); + assert_eq!( + pending_mark + .category_profile() + .and_then(|profile| profile.subtype.as_deref()), + Some("fixture.native.pending") + ); + assert_eq!( + pending_mark.data().unwrap()["source"], + "native_request_intercept" + ); + assert_eq!(pending_mark.metadata().unwrap()["fixture"], true); + assert!(pending_mark.timestamp() > llm_start.timestamp()); let llm_end = find_event( &managed_llm_events, "native-fixture-llm-execute", @@ -400,6 +418,13 @@ async fn sdk_cdylib_registers_tool_request_intercept() { assert_eq!(*collected_stream_chunks.lock().unwrap(), stream_chunks); flush_subscribers().expect("stream native fixture events should flush"); let stream_events = events.lock().unwrap().clone(); + let stream_start = find_event( + &stream_events, + "native-fixture-llm-stream", + Some(ScopeCategory::Start), + ); + let stream_pending_mark = find_event(&stream_events, "fixture.native.llm_request.mark", None); + assert_eq!(stream_pending_mark.parent_uuid(), Some(stream_start.uuid())); let stream_end = find_event( &stream_events, "native-fixture-llm-stream", diff --git a/crates/core/tests/unit/shared_tests.rs b/crates/core/tests/unit/shared_tests.rs index fe3eeab59..416cbb4f1 100644 --- a/crates/core/tests/unit/shared_tests.rs 
+++ b/crates/core/tests/unit/shared_tests.rs @@ -161,20 +161,22 @@ fn test_run_request_intercepts_with_codec_none_and_codec_paths() { ) .unwrap(); - let (request_without_codec, annotated_without_codec) = run_request_intercepts_with_codec( - "shared", - LlmRequest { - headers: Map::new(), - content: json!({"prompt": "hello"}), - }, - None, - ) - .unwrap(); + let (request_without_codec, annotated_without_codec, pending_marks_without_codec) = + run_request_intercepts_with_codec( + "shared", + LlmRequest { + headers: Map::new(), + content: json!({"prompt": "hello"}), + }, + None, + ) + .unwrap(); assert_eq!( request_without_codec.headers.get("x-no-codec"), Some(&json!(true)) ); assert!(annotated_without_codec.is_none()); + assert!(pending_marks_without_codec.is_empty()); deregister_llm_request_intercept("shared-none").unwrap(); register_llm_request_intercept( @@ -191,15 +193,16 @@ fn test_run_request_intercepts_with_codec_none_and_codec_paths() { .unwrap(); let codec: Arc = Arc::new(SharedTestCodec); - let (request_with_codec, annotated_with_codec) = run_request_intercepts_with_codec( - "shared", - LlmRequest { - headers: Map::new(), - content: json!({"prompt": "hello"}), - }, - Some(codec), - ) - .unwrap(); + let (request_with_codec, annotated_with_codec, pending_marks_with_codec) = + run_request_intercepts_with_codec( + "shared", + LlmRequest { + headers: Map::new(), + content: json!({"prompt": "hello"}), + }, + Some(codec), + ) + .unwrap(); assert_eq!( request_with_codec.headers.get("x-codec"), @@ -215,6 +218,7 @@ fn test_run_request_intercepts_with_codec_none_and_codec_paths() { .and_then(|annotated| annotated.model.as_deref()), Some("intercepted-model") ); + assert!(pending_marks_with_codec.is_empty()); deregister_llm_request_intercept("shared-codec").unwrap(); reset_global(); diff --git a/crates/plugin/src/lib.rs b/crates/plugin/src/lib.rs index 31f0c4e40..007322cf7 100644 --- a/crates/plugin/src/lib.rs +++ b/crates/plugin/src/lib.rs @@ -16,8 +16,10 @@ use 
std::ptr; use std::sync::Mutex; pub use nemo_relay_types::Json; -pub use nemo_relay_types::api::event::{Event, ScopeCategory}; -pub use nemo_relay_types::api::llm::{LlmAttributes, LlmRequest}; +pub use nemo_relay_types::api::event::{ + CategoryProfile, Event, EventCategory, PendingMarkSpec, ScopeCategory, +}; +pub use nemo_relay_types::api::llm::{LlmAttributes, LlmRequest, LlmRequestInterceptOutcome}; pub use nemo_relay_types::api::scope::{HandleAttributes, ScopeAttributes, ScopeType}; pub use nemo_relay_types::api::tool::ToolAttributes; pub use nemo_relay_types::codec::request::AnnotatedLlmRequest; @@ -1490,6 +1492,39 @@ impl<'a> PluginContext<'a> { finish_typed_registration::(self.host, status, user_data, "llm request intercept") } + /// Registers a typed LLM request intercept that can schedule lifecycle marks. + pub fn register_llm_request_intercept_with_marks( + &mut self, + name: &str, + priority: i32, + break_chain: bool, + callback: F, + ) -> Result<()> + where + F: Fn(&str, LlmRequest, Option) -> Result + + Send + + Sync + + 'static, + { + let user_data = typed_callback_user_data(self.host, callback); + let status = unsafe { + self.register_llm_request_intercept_raw( + name, + priority, + break_chain, + typed_llm_request_intercept_with_marks_trampoline::, + user_data, + Some(drop_typed_callback::), + ) + }; + finish_typed_registration::( + self.host, + status, + user_data, + "LLM request intercept with marks", + ) + } + /// Registers a typed LLM execution intercept. 
pub fn register_llm_execution_intercept( &mut self, @@ -2189,6 +2224,74 @@ where } } +#[derive(Serialize)] +struct NativeLlmRequestInterceptOutcome<'a> { + #[serde(rename = "__nemo_relay_llm_intercept_outcome")] + marked_outcome: bool, + annotated_request: &'a Option, + pending_marks: &'a [PendingMarkSpec], +} + +unsafe extern "C" fn typed_llm_request_intercept_with_marks_trampoline( + user_data: *mut c_void, + name: *const NemoRelayNativeString, + request_json: *const NemoRelayNativeString, + annotated_json: *const NemoRelayNativeString, + out_request_json: *mut *mut NemoRelayNativeString, + out_annotated_json: *mut *mut NemoRelayNativeString, +) -> NemoRelayStatus +where + F: Fn(&str, LlmRequest, Option) -> Result + + Send + + Sync + + 'static, +{ + if user_data.is_null() || out_request_json.is_null() || out_annotated_json.is_null() { + return NemoRelayStatus::NullPointer; + } + unsafe { + *out_request_json = ptr::null_mut(); + *out_annotated_json = ptr::null_mut(); + } + let state = unsafe { &*(user_data as *const TypedCallback) }; + let result = catch_unwind(AssertUnwindSafe(|| { + let name = read_required_host_string(&state.host, name, "LLM name")?; + let request: LlmRequest = read_json_value(&state.host, request_json, "LLM request")?; + let annotated: Option = + read_optional_json_value(&state.host, annotated_json, "annotated LLM request")?; + match (state.callback)(&name, request, annotated) { + Ok(outcome) => { + let Some(request) = HostString::from_json(&state.host, &outcome.request) else { + set_last_error(&state.host, "failed to allocate LLM request output"); + return Ok(NemoRelayStatus::Internal); + }; + let metadata = NativeLlmRequestInterceptOutcome { + marked_outcome: true, + annotated_request: &outcome.annotated_request, + pending_marks: &outcome.pending_marks, + }; + let Some(metadata) = HostString::from_json(&state.host, &metadata) else { + set_last_error(&state.host, "failed to allocate marked LLM outcome"); + return 
Ok(NemoRelayStatus::Internal); + }; + unsafe { + *out_request_json = request.ptr; + *out_annotated_json = metadata.ptr; + } + std::mem::forget(request); + std::mem::forget(metadata); + Ok(NemoRelayStatus::Ok) + } + Err(message) => Ok(callback_error(&state.host, message)), + } + })); + match result { + Ok(Ok(status)) => status, + Ok(Err(status)) => status, + Err(_) => callback_panic(&state.host, "LLM request intercept with marks callback"), + } +} + unsafe extern "C" fn typed_llm_execution_trampoline( user_data: *mut c_void, name: *const NemoRelayNativeString, diff --git a/crates/plugin/tests/typed_callbacks.rs b/crates/plugin/tests/typed_callbacks.rs index d3a502ffb..7f5850070 100644 --- a/crates/plugin/tests/typed_callbacks.rs +++ b/crates/plugin/tests/typed_callbacks.rs @@ -14,16 +14,16 @@ use std::sync::{ use nemo_relay_plugin::{ AnnotatedLlmRequest, ConfigDiagnostic, DiagnosticLevel, Event, Json, LlmJsonStream, LlmNext, - LlmRequest, LlmStream, LlmStreamNext, NEMO_RELAY_NATIVE_ABI_VERSION, NativePlugin, - NemoRelayNativeEventSubscriberCb, NemoRelayNativeFreeFn, NemoRelayNativeHostApiV1, - NemoRelayNativeJsonCb, NemoRelayNativeLlmConditionalCb, NemoRelayNativeLlmExecutionCb, - NemoRelayNativeLlmRequestCb, NemoRelayNativeLlmRequestInterceptCb, - NemoRelayNativeLlmStreamExecutionCb, NemoRelayNativeLlmStreamV1, NemoRelayNativePluginContext, - NemoRelayNativePluginV1, NemoRelayNativeScopeHandle, NemoRelayNativeScopeStack, - NemoRelayNativeScopeStackBinding, NemoRelayNativeScopeType, NemoRelayNativeString, - NemoRelayNativeToolConditionalCb, NemoRelayNativeToolExecutionCb, NemoRelayNativeToolJsonCb, - NemoRelayNativeWithScopeStackCb, NemoRelayStatus, PluginContext, PluginRuntime, ScopeType, - ToolNext, + LlmRequest, LlmRequestInterceptOutcome, LlmStream, LlmStreamNext, + NEMO_RELAY_NATIVE_ABI_VERSION, NativePlugin, NemoRelayNativeEventSubscriberCb, + NemoRelayNativeFreeFn, NemoRelayNativeHostApiV1, NemoRelayNativeJsonCb, + NemoRelayNativeLlmConditionalCb, 
NemoRelayNativeLlmExecutionCb, NemoRelayNativeLlmRequestCb, + NemoRelayNativeLlmRequestInterceptCb, NemoRelayNativeLlmStreamExecutionCb, + NemoRelayNativeLlmStreamV1, NemoRelayNativePluginContext, NemoRelayNativePluginV1, + NemoRelayNativeScopeHandle, NemoRelayNativeScopeStack, NemoRelayNativeScopeStackBinding, + NemoRelayNativeScopeType, NemoRelayNativeString, NemoRelayNativeToolConditionalCb, + NemoRelayNativeToolExecutionCb, NemoRelayNativeToolJsonCb, NemoRelayNativeWithScopeStackCb, + NemoRelayStatus, PendingMarkSpec, PluginContext, PluginRuntime, ScopeType, ToolNext, }; use serde_json::{Map, json}; @@ -4388,6 +4388,70 @@ fn typed_llm_request_intercept_round_trips_request_and_annotations() { } } +#[test] +fn typed_llm_request_intercept_with_marks_uses_tagged_annotation_envelope() { + let _guard = begin_test(); + let host = test_host(); + let mut ctx = test_context(&host); + ctx.register_llm_request_intercept_with_marks( + "llm", + 23, + false, + |_name, mut request, annotated| { + request.content["rewritten"] = json!(true); + Ok( + LlmRequestInterceptOutcome::new(request, annotated).with_pending_mark( + PendingMarkSpec::builder() + .name("plugin.request.rewritten") + .data(json!({ "saved_tokens": 7 })) + .build(), + ), + ) + }, + ) + .unwrap(); + + let registration = take_llm_request_intercept_registration(); + assert_eq!(registration.priority, 23); + assert!(!registration.break_chain); + let name = host_string(&host, "llm"); + let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); + let mut out_request = ptr::null_mut(); + let mut out_annotated = ptr::null_mut(); + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + request, + ptr::null(), + &mut out_request, + &mut out_annotated, + ) + }; + assert_eq!(status, NemoRelayStatus::Ok); + let out_request = read_json_and_free(&host, out_request); + assert_eq!(out_request["content"]["rewritten"], true); + assert!( + out_request + 
.pointer("/content/__nemo_relay_llm_intercept_outcome") + .is_none() + ); + let metadata = read_json_and_free(&host, out_annotated); + assert_eq!(metadata["__nemo_relay_llm_intercept_outcome"], true); + assert_eq!(metadata["annotated_request"], Json::Null); + assert_eq!( + metadata["pending_marks"][0]["name"], + "plugin.request.rewritten" + ); + assert_eq!(metadata["pending_marks"][0]["data"]["saved_tokens"], 7); + + unsafe { + (host.string_free)(name); + (host.string_free)(request); + registration.free(); + } +} + struct DropCounter(Arc); impl Drop for DropCounter { diff --git a/crates/types/src/api/event.rs b/crates/types/src/api/event.rs index af17546ae..d08bf8180 100644 --- a/crates/types/src/api/event.rs +++ b/crates/types/src/api/event.rs @@ -364,6 +364,29 @@ pub struct MarkEvent { pub category_profile: Option, } +/// Mark requested by middleware before its owning runtime scope exists. +/// +/// The runtime assigns the parent UUID, event UUID, and timestamp when it +/// materializes the mark at the appropriate lifecycle boundary. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, TypedBuilder)] +#[builder(field_defaults(setter(into, strip_option(ignore_invalid, fallback_suffix = "_opt"))))] +pub struct PendingMarkSpec { + /// Human-readable mark name. + pub name: String, + /// Optional semantic category for the mark. + #[builder(default)] + pub category: Option, + /// Optional category-specific typed fields. + #[builder(default)] + pub category_profile: Option, + /// Optional application payload attached to the mark. + #[builder(default)] + pub data: Option, + /// Optional metadata attached to the mark. + #[builder(default)] + pub metadata: Option, +} + impl MarkEvent { /// Construct a mark event from a base envelope and optional category data. 
/// diff --git a/crates/types/src/api/llm.rs b/crates/types/src/api/llm.rs index caa917a65..c4dba6efb 100644 --- a/crates/types/src/api/llm.rs +++ b/crates/types/src/api/llm.rs @@ -7,6 +7,15 @@ use bitflags::bitflags; use serde::{Deserialize, Serialize}; use crate::Json; +use crate::api::event::PendingMarkSpec; +use crate::codec::request::AnnotatedLlmRequest; + +/// Private native-ABI tag used for marked LLM request-intercept outcomes. +/// +/// Native plugin authors should return [`LlmRequestInterceptOutcome`] through +/// the plugin SDK instead of reading or writing this field directly. +#[doc(hidden)] +pub const NATIVE_LLM_INTERCEPT_OUTCOME_FIELD: &str = "__nemo_relay_llm_intercept_outcome"; bitflags! { /// Bitflags that modify LLM-call behavior and observability. @@ -20,10 +29,47 @@ bitflags! { } /// JSON-shaped LLM request payload passed through the runtime. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct LlmRequest { /// Provider-specific request headers. pub headers: serde_json::Map, /// Provider-specific request body. pub content: Json, } + +/// Result of an LLM request intercept that can schedule lifecycle marks. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct LlmRequestInterceptOutcome { + /// Rewritten provider request. + pub request: LlmRequest, + /// Optional normalized request annotation to carry forward. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub annotated_request: Option, + /// Ordered marks to emit after Relay creates and starts the LLM scope. + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub pending_marks: Vec, +} + +impl LlmRequestInterceptOutcome { + /// Create an outcome without pending marks. + pub fn new(request: LlmRequest, annotated_request: Option) -> Self { + Self { + request, + annotated_request, + pending_marks: Vec::new(), + } + } + + /// Append one pending mark while preserving interceptor order. 
+ #[must_use] + pub fn with_pending_mark(mut self, mark: PendingMarkSpec) -> Self { + self.pending_marks.push(mark); + self + } +} + +impl From<(LlmRequest, Option)> for LlmRequestInterceptOutcome { + fn from((request, annotated_request): (LlmRequest, Option)) -> Self { + Self::new(request, annotated_request) + } +} diff --git a/crates/types/tests/serialization_tests.rs b/crates/types/tests/serialization_tests.rs index 9e899ff7d..afe434778 100644 --- a/crates/types/tests/serialization_tests.rs +++ b/crates/types/tests/serialization_tests.rs @@ -6,10 +6,10 @@ use std::sync::Arc; use nemo_relay_types::api::event::{ - BaseEvent, CategoryProfile, Event, EventCategory, ScopeCategory, ScopeEvent, + BaseEvent, CategoryProfile, Event, EventCategory, PendingMarkSpec, ScopeCategory, ScopeEvent, llm_attributes_to_strings, }; -use nemo_relay_types::api::llm::{LlmAttributes, LlmRequest}; +use nemo_relay_types::api::llm::{LlmAttributes, LlmRequest, LlmRequestInterceptOutcome}; use nemo_relay_types::codec::request::{AnnotatedLlmRequest, Message, MessageContent}; use nemo_relay_types::codec::response::AnnotatedLlmResponse; use serde_json::{Map, json}; @@ -78,3 +78,36 @@ fn event_round_trips_with_annotated_llm_profiles() { Some("resp_1") ); } + +#[test] +fn llm_request_intercept_outcome_round_trips_pending_marks() { + let outcome = LlmRequestInterceptOutcome::new( + LlmRequest { + headers: Map::new(), + content: json!({ "prompt": "hello" }), + }, + None, + ) + .with_pending_mark( + PendingMarkSpec::builder() + .name("request.optimized") + .category(EventCategory::custom()) + .category_profile( + CategoryProfile::builder() + .subtype("optimizer.saved_tokens") + .build(), + ) + .data(json!({ "saved_tokens": 12 })) + .metadata(json!({ "source": "test" })) + .build(), + ); + + let encoded = serde_json::to_value(&outcome).expect("outcome should serialize"); + assert_eq!(encoded["pending_marks"][0]["name"], "request.optimized"); + assert_eq!(encoded["pending_marks"][0]["category"], 
"custom"); + assert!(encoded.get("annotated_request").is_none()); + + let decoded: LlmRequestInterceptOutcome = + serde_json::from_value(encoded).expect("outcome should deserialize"); + assert_eq!(decoded, outcome); +} diff --git a/examples/rust-native-plugin/README.md b/examples/rust-native-plugin/README.md index 69e77d0e4..3cafe2979 100644 --- a/examples/rust-native-plugin/README.md +++ b/examples/rust-native-plugin/README.md @@ -70,9 +70,18 @@ The example registers the following runtime behavior: - Request and execution intercepts for tools that mutate JSON payloads and call continuations. - LLM sanitize request/response guardrails. -- LLM request, execution, and stream execution intercepts. +- An LLM request intercept that rewrites the request and schedules a mark. Relay + emits that mark after the LLM start event with the LLM scope as its parent. +- LLM execution and stream execution intercepts. - Runtime mark and scope events. - A plugin-owned isolated scope stack for non-correlated visibility. Native plugins are not sandboxed. They run in the Relay process and must not unwind across ABI callbacks. + +Request intercepts do not own an LLM lifecycle because they run before Relay +creates the LLM scope. Use `register_llm_request_intercept_with_marks` to return +`PendingMarkSpec` values. Relay emits them in interceptor order after the LLM +start event and before provider execution. The legacy +`register_llm_request_intercept` API remains available for intercepts that only +rewrite requests. 
diff --git a/examples/rust-native-plugin/src/lib.rs b/examples/rust-native-plugin/src/lib.rs index f18296f54..d7e4d961f 100644 --- a/examples/rust-native-plugin/src/lib.rs +++ b/examples/rust-native-plugin/src/lib.rs @@ -2,8 +2,9 @@ // SPDX-License-Identifier: Apache-2.0 use nemo_relay_plugin::{ - ConfigDiagnostic, DiagnosticLevel, Event, Json, LlmJsonStream, LlmRequest, NativePlugin, - PluginContext, PluginRuntime, ScopeCategory, ScopeType, + CategoryProfile, ConfigDiagnostic, DiagnosticLevel, Event, EventCategory, Json, LlmJsonStream, + LlmRequest, LlmRequestInterceptOutcome, NativePlugin, PendingMarkSpec, PluginContext, + PluginRuntime, ScopeCategory, ScopeType, }; use serde_json::{Map, json}; @@ -232,12 +233,23 @@ impl NativePlugin for ExampleNativePlugin { Ok(block_llms.then(|| "LLM call blocked by Rust native plugin".to_string())) } })?; - ctx.register_llm_request_intercept("example_llm_request", 20, false, { + ctx.register_llm_request_intercept_with_marks("example_llm_request", 20, false, { let tag = config.tag.clone(); move |_name, request, annotated| { - Ok(( + Ok(LlmRequestInterceptOutcome::new( tag_llm_request(request, "native_llm_request_intercept", &tag), annotated, + ) + .with_pending_mark( + PendingMarkSpec::builder() + .name("example.native.llm_request_intercept") + .category(EventCategory::custom()) + .category_profile(CategoryProfile { + subtype: Some("example.native.request_rewrite".into()), + ..CategoryProfile::default() + }) + .data(json!({ "tag": &tag })) + .build(), )) } })?; From 192ff23fe4251ec688a980cd6523085ef5130e3f Mon Sep 17 00:00:00 2001 From: Bryan Bednarski Date: Mon, 29 Jun 2026 18:00:12 -0600 Subject: [PATCH 2/9] fix(llm): preserve request intercept outcomes Signed-off-by: Bryan Bednarski --- crates/core/src/api/shared.rs | 2 +- crates/core/src/plugin/dynamic/native.rs | 13 +++++++++++++ crates/core/tests/unit/shared_tests.rs | 11 +++++++++-- crates/plugin/tests/typed_callbacks.rs | 9 +++++++-- 
crates/types/tests/serialization_tests.rs | 10 ++++++++++ 5 files changed, 40 insertions(+), 5 deletions(-) diff --git a/crates/core/src/api/shared.rs b/crates/core/src/api/shared.rs index 9d6dff8c2..53f14a599 100644 --- a/crates/core/src/api/shared.rs +++ b/crates/core/src/api/shared.rs @@ -110,7 +110,7 @@ pub(crate) fn run_request_intercepts_with_codec( encoded.headers = outcome.request.headers; Ok((encoded, Some(Arc::new(annotated)), pending_marks)) } - _ => Ok((outcome.request, None, pending_marks)), + (_, annotated) => Ok((outcome.request, annotated.map(Arc::new), pending_marks)), } } diff --git a/crates/core/src/plugin/dynamic/native.rs b/crates/core/src/plugin/dynamic/native.rs index c541ba6d5..6582bef95 100644 --- a/crates/core/src/plugin/dynamic/native.rs +++ b/crates/core/src/plugin/dynamic/native.rs @@ -1790,9 +1790,22 @@ struct NativeLlmRequestInterceptOutcome { #[serde(rename = "__nemo_relay_llm_intercept_outcome")] _marked_outcome: bool, annotated_request: Option, + #[serde(default)] pending_marks: Vec, } +#[cfg(test)] +#[test] +fn native_llm_request_intercept_outcome_defaults_omitted_pending_marks() { + let outcome: NativeLlmRequestInterceptOutcome = serde_json::from_value(serde_json::json!({ + "__nemo_relay_llm_intercept_outcome": true + })) + .unwrap(); + + assert!(outcome.annotated_request.is_none()); + assert!(outcome.pending_marks.is_empty()); +} + fn wrap_llm_execution_fn( instance: Arc, cb: NemoRelayNativeLlmExecutionCb, diff --git a/crates/core/tests/unit/shared_tests.rs b/crates/core/tests/unit/shared_tests.rs index 416cbb4f1..c45ca5a5e 100644 --- a/crates/core/tests/unit/shared_tests.rs +++ b/crates/core/tests/unit/shared_tests.rs @@ -156,7 +156,9 @@ fn test_run_request_intercepts_with_codec_none_and_codec_paths() { Arc::new(|_name, mut request, annotated| { assert!(annotated.is_none()); request.headers.insert("x-no-codec".into(), json!(true)); - Ok((request, None)) + let mut annotated = SharedTestCodec.decode(&request)?; + annotated.model 
= Some("interceptor-model".into()); + Ok((request, Some(annotated))) }), ) .unwrap(); @@ -175,7 +177,12 @@ fn test_run_request_intercepts_with_codec_none_and_codec_paths() { request_without_codec.headers.get("x-no-codec"), Some(&json!(true)) ); - assert!(annotated_without_codec.is_none()); + assert_eq!( + annotated_without_codec + .as_deref() + .and_then(|annotated| annotated.model.as_deref()), + Some("interceptor-model") + ); assert!(pending_marks_without_codec.is_empty()); deregister_llm_request_intercept("shared-none").unwrap(); diff --git a/crates/plugin/tests/typed_callbacks.rs b/crates/plugin/tests/typed_callbacks.rs index 7f5850070..ae88024b0 100644 --- a/crates/plugin/tests/typed_callbacks.rs +++ b/crates/plugin/tests/typed_callbacks.rs @@ -4416,6 +4416,10 @@ fn typed_llm_request_intercept_with_marks_uses_tagged_annotation_envelope() { assert!(!registration.break_chain); let name = host_string(&host, "llm"); let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); + let annotated = json_host_string( + &host, + serde_json::to_value(test_annotated_llm_request()).unwrap(), + ); let mut out_request = ptr::null_mut(); let mut out_annotated = ptr::null_mut(); let status = unsafe { @@ -4423,7 +4427,7 @@ fn typed_llm_request_intercept_with_marks_uses_tagged_annotation_envelope() { registration.user_data as *mut c_void, name, request, - ptr::null(), + annotated, &mut out_request, &mut out_annotated, ) @@ -4438,7 +4442,7 @@ fn typed_llm_request_intercept_with_marks_uses_tagged_annotation_envelope() { ); let metadata = read_json_and_free(&host, out_annotated); assert_eq!(metadata["__nemo_relay_llm_intercept_outcome"], true); - assert_eq!(metadata["annotated_request"], Json::Null); + assert_eq!(metadata["annotated_request"]["messages"], json!([])); assert_eq!( metadata["pending_marks"][0]["name"], "plugin.request.rewritten" @@ -4448,6 +4452,7 @@ fn typed_llm_request_intercept_with_marks_uses_tagged_annotation_envelope() { unsafe { 
(host.string_free)(name); (host.string_free)(request); + (host.string_free)(annotated); registration.free(); } } diff --git a/crates/types/tests/serialization_tests.rs b/crates/types/tests/serialization_tests.rs index afe434778..fb423ecf9 100644 --- a/crates/types/tests/serialization_tests.rs +++ b/crates/types/tests/serialization_tests.rs @@ -107,6 +107,16 @@ fn llm_request_intercept_outcome_round_trips_pending_marks() { assert_eq!(encoded["pending_marks"][0]["category"], "custom"); assert!(encoded.get("annotated_request").is_none()); + let mut encoded_without_pending_marks = encoded.clone(); + encoded_without_pending_marks + .as_object_mut() + .unwrap() + .remove("pending_marks"); + let decoded_without_pending_marks: LlmRequestInterceptOutcome = + serde_json::from_value(encoded_without_pending_marks) + .expect("outcome without pending marks should deserialize"); + assert!(decoded_without_pending_marks.pending_marks.is_empty()); + let decoded: LlmRequestInterceptOutcome = serde_json::from_value(encoded).expect("outcome should deserialize"); assert_eq!(decoded, outcome); From 1509dec7c4976031d330b3bd8774c3f9b7a151da Mon Sep 17 00:00:00 2001 From: Bryan Bednarski Date: Tue, 30 Jun 2026 14:54:23 -0600 Subject: [PATCH 3/9] feat: finalize canonical LLM request-intercept outcomes Unify request-intercept results across the core runtime, native ABI, worker protocol, C, Go, Python, Node.js, and WebAssembly. Add ordered pending lifecycle marks and remove legacy tuple, split-output, and mark-specific registration variants. 
Signed-off-by: Bryan Bednarski --- crates/adaptive/src/acg_component.rs | 4 +- .../adaptive/src/adaptive_hints_intercept.rs | 4 +- .../integration/runtime_integration_tests.rs | 8 +- .../tests/unit/acg_component_tests.rs | 4 +- .../unit/adaptive_hints_intercept_tests.rs | 11 +- .../tests/unit/plugin_component_tests.rs | 2 +- .../tests/unit/runtime_features_tests.rs | 20 +- crates/adaptive/tests/unit/runtime_tests.rs | 2 +- crates/core/src/api/llm.rs | 109 ++++++--- crates/core/src/api/registry.rs | 50 +---- crates/core/src/api/runtime.rs | 6 +- crates/core/src/api/runtime/callbacks.rs | 16 +- crates/core/src/api/runtime/state.rs | 22 +- crates/core/src/context/registries.rs | 4 +- crates/core/src/plugin.rs | 44 +--- crates/core/src/plugin/dynamic/native.rs | 115 +++------- crates/core/src/plugin/dynamic/worker.rs | 34 ++- crates/core/src/stream.rs | 53 +++-- .../tests/fixtures/native_plugin/src/lib.rs | 2 +- .../tests/fixtures/worker_plugin/src/main.rs | 5 +- .../tests/integration/api_surface_tests.rs | 21 +- .../tests/integration/middleware_tests.rs | 96 +++++--- .../core/tests/integration/pipeline_tests.rs | 56 +++-- .../core/tests/unit/dynamic_worker_tests.rs | 48 ++-- crates/core/tests/unit/plugin_tests.rs | 28 ++- crates/core/tests/unit/shared_tests.rs | 6 +- crates/ffi/nemo_relay.h | 26 ++- crates/ffi/src/api/mod.rs | 81 ++++++- crates/ffi/src/callable.rs | 55 ++--- crates/ffi/tests/integration/api_tests.rs | 11 +- .../tests/integration/callable_extra_tests.rs | 26 +-- crates/ffi/tests/unit/api/core_tests.rs | 54 +++++ crates/ffi/tests/unit/api/registry_tests.rs | 7 +- crates/ffi/tests/unit/api_tests.rs | 14 +- crates/ffi/tests/unit/callable_tests.rs | 31 +-- crates/node/plugin.d.ts | 9 +- crates/node/src/api/mod.rs | 18 +- crates/node/src/callable.rs | 50 ++--- crates/node/tests/llm_tests.mjs | 25 ++- crates/node/tests/scope_local_tests.mjs | 2 +- crates/plugin/src/lib.rs | 167 +++----------- crates/plugin/tests/typed_callbacks.rs | 207 ++++++++---------- 
crates/python/src/py_api/mod.rs | 7 +- crates/python/src/py_callable.rs | 48 +--- crates/python/src/py_types/core.rs | 137 +++++++++++- crates/python/src/py_types/mod.rs | 2 + .../python/tests/coverage/coverage_tests.rs | 11 +- .../coverage/py_adaptive_coverage_tests.rs | 2 +- .../tests/coverage/py_api_coverage_tests.rs | 14 +- .../coverage/py_callable_coverage_tests.rs | 20 +- .../coverage/py_plugin_coverage_tests.rs | 15 +- crates/types/src/api/llm.rs | 17 +- crates/types/tests/serialization_tests.rs | 19 +- crates/wasm/src/api/mod.rs | 7 +- crates/wasm/src/callable.rs | 27 ++- crates/wasm/tests/coverage/callable_tests.rs | 7 +- crates/wasm/wrappers/esm/plugin.d.ts | 9 +- .../nemo/relay/worker/v1/plugin_worker.proto | 4 +- crates/worker/src/lib.rs | 35 +-- crates/worker/tests/worker_sdk_tests.rs | 11 +- docs/build-plugins/code-examples.mdx | 7 +- docs/build-plugins/register-behavior.mdx | 7 +- .../instrument-applications/code-examples.mdx | 4 +- .../code-examples.mdx | 3 +- .../provider-codecs.mdx | 4 +- .../llm-request-intercept-outcomes.mdx | 57 +++++ examples/rust-native-plugin/README.md | 8 +- examples/rust-native-plugin/src/lib.rs | 2 +- go/nemo_relay/adaptive_plugin_test.go | 13 +- go/nemo_relay/callbacks.go | 56 +++-- go/nemo_relay/intercepts/intercepts.go | 2 +- go/nemo_relay/intercepts/intercepts_test.go | 14 +- go/nemo_relay/llm/llm.go | 2 +- go/nemo_relay/llm/llm_shorthand_test.go | 10 +- go/nemo_relay/llm_test.go | 12 +- go/nemo_relay/nemo_relay.go | 14 +- go/nemo_relay/plugin.go | 4 +- go/nemo_relay/scope_local_test.go | 4 +- go/nemo_relay/top_level_coverage_test.go | 34 +-- go/nemo_relay/wrapper_coverage_test.go | 10 +- python/nemo_relay/__init__.py | 17 +- python/nemo_relay/__init__.pyi | 10 +- python/nemo_relay/_native.pyi | 40 +++- python/nemo_relay/intercepts.py | 12 +- python/nemo_relay/llm.py | 3 +- .../langchain_tests/test_middleware.py | 2 +- python/tests/test_adaptive.py | 16 +- python/tests/test_codecs.py | 7 +- 
python/tests/test_integration_codecs.py | 12 +- python/tests/test_llm.py | 16 +- python/tests/test_scope_local.py | 6 +- 91 files changed, 1381 insertions(+), 972 deletions(-) create mode 100644 docs/reference/llm-request-intercept-outcomes.mdx diff --git a/crates/adaptive/src/acg_component.rs b/crates/adaptive/src/acg_component.rs index 7d5ccaddc..9d7026252 100644 --- a/crates/adaptive/src/acg_component.rs +++ b/crates/adaptive/src/acg_component.rs @@ -598,7 +598,9 @@ pub(crate) fn create_acg_llm_request_intercept( let translated = translate_request(&request, &agent_id, &provider, plugin.as_ref(), &hot_cache) .unwrap_or(request); - Ok((translated, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + translated, annotated, + )) }) } diff --git a/crates/adaptive/src/adaptive_hints_intercept.rs b/crates/adaptive/src/adaptive_hints_intercept.rs index 05613aadb..b4d9fe2b3 100644 --- a/crates/adaptive/src/adaptive_hints_intercept.rs +++ b/crates/adaptive/src/adaptive_hints_intercept.rs @@ -192,7 +192,9 @@ impl AdaptiveHintsIntercept { inject_agent_hints(&mut request, &hints); } - Ok((request, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + request, annotated, + )) }, ) } diff --git a/crates/adaptive/tests/integration/runtime_integration_tests.rs b/crates/adaptive/tests/integration/runtime_integration_tests.rs index cdee3c050..d02f4f4d4 100644 --- a/crates/adaptive/tests/integration/runtime_integration_tests.rs +++ b/crates/adaptive/tests/integration/runtime_integration_tests.rs @@ -588,7 +588,7 @@ async fn test_adaptive_plugin_registers_and_passes_calls_through() { }, ) .unwrap(); - assert_eq!(request.content["messages"], json!([])); + assert_eq!(request.request.content["messages"], json!([])); let llm_func: LlmExecutionNextFn = Arc::new(|_req: LlmRequest| Box::pin(async { Ok(json!({"response": "ok"})) })); @@ -721,7 +721,9 @@ impl Plugin for HeaderPlugin { false, Arc::new(|_name, mut request, annotated| { 
request.headers.insert("x-plugin".into(), json!("set")); - Ok((request, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + request, annotated, + )) }), )?; ctx.register_tool_request_intercept( @@ -806,7 +808,7 @@ async fn test_top_level_plugin_registers_request_and_execution_intercepts() { }, ) .unwrap(); - assert_eq!(request.headers.get("x-plugin"), Some(&json!("set"))); + assert_eq!(request.request.headers.get("x-plugin"), Some(&json!("set"))); let tool_func: ToolExecutionNextFn = Arc::new(|args| Box::pin(async move { Ok(args) })); let tool_result = tool_call_execute( diff --git a/crates/adaptive/tests/unit/acg_component_tests.rs b/crates/adaptive/tests/unit/acg_component_tests.rs index cc4c2f7e7..bbcf4d024 100644 --- a/crates/adaptive/tests/unit/acg_component_tests.rs +++ b/crates/adaptive/tests/unit/acg_component_tests.rs @@ -1087,12 +1087,14 @@ fn acg_component_request_intercept_passes_original_request_and_annotation_when_t plugin, ); - let (translated, returned_annotated) = intercept( + let outcome = intercept( "anthropic", invalid_request.clone(), Some(annotated.clone()), ) .expect("request intercept should pass through"); + let translated = outcome.request; + let returned_annotated = outcome.annotated_request; assert_eq!(translated.content, invalid_request.content); assert_eq!( diff --git a/crates/adaptive/tests/unit/adaptive_hints_intercept_tests.rs b/crates/adaptive/tests/unit/adaptive_hints_intercept_tests.rs index a7efa190e..d68b3b016 100644 --- a/crates/adaptive/tests/unit/adaptive_hints_intercept_tests.rs +++ b/crates/adaptive/tests/unit/adaptive_hints_intercept_tests.rs @@ -194,7 +194,7 @@ fn test_adaptive_hints_intercept_injects_prediction_hints_and_manual_override() stream: None, extra: serde_json::Map::new(), }; - let (request, returned_annotated) = req_fn( + let outcome = req_fn( "model", LlmRequest { headers: serde_json::Map::new(), @@ -203,6 +203,8 @@ fn 
test_adaptive_hints_intercept_injects_prediction_hints_and_manual_override() Some(annotated.clone()), ) .unwrap(); + let request = outcome.request; + let returned_annotated = outcome.annotated_request; let body_hints = &request.content["nvext"]["agent_hints"]; assert_eq!(body_hints["osl"], serde_json::json!(150)); @@ -256,7 +258,7 @@ fn test_adaptive_hints_intercept_uses_defaults_and_ignores_poisoned_cache() { })); let req_fn = AdaptiveHintsIntercept::new(hot_cache, "fallback-agent".to_string()).into_request_fn(); - let (request, annotated) = req_fn( + let outcome = req_fn( "model", LlmRequest { headers: serde_json::Map::new(), @@ -265,6 +267,8 @@ fn test_adaptive_hints_intercept_uses_defaults_and_ignores_poisoned_cache() { None, ) .unwrap(); + let request = outcome.request; + let annotated = outcome.annotated_request; assert_eq!( request.headers.get(AGENT_HINTS_HEADER_KEY), Some(&serde_json::to_value(&defaults).unwrap()) @@ -293,7 +297,7 @@ fn test_adaptive_hints_intercept_uses_defaults_and_ignores_poisoned_cache() { }); let poisoned_req_fn = AdaptiveHintsIntercept::new(poisoned_cache, "fallback-agent".to_string()).into_request_fn(); - let (poisoned_request, _) = poisoned_req_fn( + let poisoned_outcome = poisoned_req_fn( "model", LlmRequest { headers: serde_json::Map::new(), @@ -302,6 +306,7 @@ fn test_adaptive_hints_intercept_uses_defaults_and_ignores_poisoned_cache() { None, ) .unwrap(); + let poisoned_request = poisoned_outcome.request; assert!( poisoned_request .headers diff --git a/crates/adaptive/tests/unit/plugin_component_tests.rs b/crates/adaptive/tests/unit/plugin_component_tests.rs index 03fe3f009..ebc6001f6 100644 --- a/crates/adaptive/tests/unit/plugin_component_tests.rs +++ b/crates/adaptive/tests/unit/plugin_component_tests.rs @@ -376,7 +376,7 @@ async fn adaptive_plugin_registers_runtime_and_rolls_back_registration() { }, ) .unwrap(); - assert!(request.headers.is_empty()); + assert!(request.request.headers.is_empty()); let mut registrations = 
ctx.into_registrations(); assert_eq!(registrations.len(), 1); diff --git a/crates/adaptive/tests/unit/runtime_features_tests.rs b/crates/adaptive/tests/unit/runtime_features_tests.rs index c64818606..93be673f3 100644 --- a/crates/adaptive/tests/unit/runtime_features_tests.rs +++ b/crates/adaptive/tests/unit/runtime_features_tests.rs @@ -141,7 +141,11 @@ fn assert_llm_request_intercept_registered(name: &str) { name, i32::MAX, false, - Arc::new(|_name, request, annotated| Ok((request, annotated))), + Arc::new(|_name, request, annotated| { + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + request, annotated, + )) + }), ), name, ); @@ -152,7 +156,11 @@ fn assert_llm_request_intercept_absent(name: &str) { name, i32::MAX, false, - Arc::new(|_name, request, annotated| Ok((request, annotated))), + Arc::new(|_name, request, annotated| { + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + request, annotated, + )) + }), ) .unwrap(); deregister_llm_request_intercept(name).unwrap(); @@ -553,7 +561,7 @@ async fn adaptive_hints_feature_registers_request_intercept() { }, ) .unwrap(); - assert!(request.headers.contains_key(AGENT_HINTS_HEADER_KEY)); + assert!(request.request.headers.contains_key(AGENT_HINTS_HEADER_KEY)); let mut registrations = ctx.finish(); rollback_registrations(&mut registrations); @@ -716,7 +724,11 @@ async fn registration_context_registers_all_supported_callback_types() { "adaptive_test_request", 5, false, - Arc::new(|_name, request, annotated| Ok((request, annotated))), + Arc::new(|_name, request, annotated| { + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + request, annotated, + )) + }), ) .unwrap(); ctx.register_llm_execution_intercept( diff --git a/crates/adaptive/tests/unit/runtime_tests.rs b/crates/adaptive/tests/unit/runtime_tests.rs index 5802a58a6..63c87429d 100644 --- a/crates/adaptive/tests/unit/runtime_tests.rs +++ b/crates/adaptive/tests/unit/runtime_tests.rs @@ -630,7 +630,7 @@ async fn 
adaptive_runtime_bind_scope_requires_registration_and_passes_through_wi let translated = llm_request_intercepts("anthropic", request.clone()) .expect("request intercept chain should pass through when no hot-cache state exists"); - assert_eq!(translated.content, request.content); + assert_eq!(translated.request.content, request.content); pop_scope(PopScopeParams::builder().handle_uuid(&scope.uuid).build()) .expect("scope pop should succeed"); } diff --git a/crates/core/src/api/llm.rs b/crates/core/src/api/llm.rs index 7910222e5..822a35572 100644 --- a/crates/core/src/api/llm.rs +++ b/crates/core/src/api/llm.rs @@ -14,7 +14,8 @@ use crate::api::runtime::NemoRelayContextState; use crate::api::runtime::current_scope_stack; use crate::api::runtime::global_context; use crate::api::runtime::{ - LlmCollectorFn, LlmExecutionNextFn, LlmFinalizerFn, LlmJsonStream, LlmStreamExecutionNextFn, + EventSubscriberFn, LlmCollectorFn, LlmExecutionNextFn, LlmFinalizerFn, LlmJsonStream, + LlmStreamExecutionNextFn, }; use crate::api::scope::event; use crate::api::scope::{EmitMarkEventParams, ScopeHandle}; @@ -261,20 +262,39 @@ fn emit_llm_start( request_codec: Option<&dyn LlmCodec>, ) -> Result<()> { ensure_runtime_owner()?; - let (entries, subscribers) = { + let subscribers = { + let scope_stack = current_scope_stack(); + let scope_guard = scope_stack.read().expect("scope stack lock poisoned"); + snapshot_event_subscribers(scope_guard.collect_scope_local_subscribers())? 
+ }; + emit_llm_start_with_subscribers( + handle, + request, + annotated_request, + request_codec, + &subscribers, + ) +} + +fn emit_llm_start_with_subscribers( + handle: &LlmHandle, + request: &LlmRequest, + annotated_request: Option>, + request_codec: Option<&dyn LlmCodec>, + subscribers: &[EventSubscriberFn], +) -> Result<()> { + ensure_runtime_owner()?; + let entries = { let scope_stack = current_scope_stack(); let scope_guard = scope_stack.read().expect("scope stack lock poisoned"); let scope_locals = scope_guard.collect_scope_local_registries(|registries| { ®istries.llm_sanitize_request_guardrails }); - let scope_subscribers = scope_guard.collect_scope_local_subscribers(); - let subscribers = snapshot_event_subscribers(scope_subscribers)?; let context = global_context(); let state = context .read() .map_err(|error| FlowError::Internal(error.to_string()))?; - let entries = state.llm_sanitize_request_entries(&scope_locals); - (entries, subscribers) + state.llm_sanitize_request_entries(&scope_locals) }; let sanitized_request = NemoRelayContextState::llm_sanitize_request_snapshot_chain(request.clone(), &entries); @@ -295,23 +315,21 @@ fn emit_llm_start( .map_err(|error| FlowError::Internal(error.to_string()))?; state.build_llm_start_event(handle, Some(input), annotated_request) }; - NemoRelayContextState::emit_event(&event, &subscribers); + NemoRelayContextState::emit_event(&event, subscribers); Ok(()) } -fn emit_pending_request_marks(handle: &LlmHandle, marks: Vec) -> Result<()> { +fn emit_pending_request_marks( + handle: &LlmHandle, + marks: Vec, + subscribers: &[EventSubscriberFn], +) -> Result<()> { if marks.is_empty() { return Ok(()); } ensure_runtime_owner()?; - let subscribers = { - let scope_stack = current_scope_stack(); - let scope_guard = scope_stack.read().expect("scope stack lock poisoned"); - snapshot_event_subscribers(scope_guard.collect_scope_local_subscribers())? 
- }; - for (index, mark) in marks.into_iter().enumerate() { - let timestamp = handle.started_at - + TimeDelta::microseconds(i64::try_from(index).unwrap_or_default() + 1); + let timestamp = handle.started_at + TimeDelta::microseconds(1); + for mark in marks { let event = Event::Mark(MarkEvent::new( BaseEvent::builder() .name(mark.name) @@ -323,7 +341,7 @@ fn emit_pending_request_marks(handle: &LlmHandle, marks: Vec) - mark.category, mark.category_profile, )); - NemoRelayContextState::emit_event(&event, &subscribers); + NemoRelayContextState::emit_event(&event, subscribers); } Ok(()) } @@ -417,12 +435,14 @@ pub fn llm_call_end(params: LlmCallEndParams<'_>) -> Result<()> { response_codec_errors_fatal: true, attach_estimated_cost: false, }, + None, ) } fn llm_call_end_with_behavior( params: LlmCallEndParams<'_>, behavior: LlmCallEndBehavior, + lifecycle_subscribers: Option<&[EventSubscriberFn]>, ) -> Result<()> { let LlmCallEndParams { handle, @@ -441,7 +461,10 @@ fn llm_call_end_with_behavior( ®istries.llm_sanitize_response_guardrails }); let scope_subscribers = scope_guard.collect_scope_local_subscribers(); - let subscribers = snapshot_event_subscribers(scope_subscribers)?; + let subscribers = match lifecycle_subscribers { + Some(subscribers) => subscribers.to_vec(), + None => snapshot_event_subscribers(scope_subscribers)?, + }; let context = global_context(); let state = context .read() @@ -501,13 +524,20 @@ fn llm_call_end_with_behavior( } } -fn emit_llm_end_without_output(handle: &LlmHandle, metadata: Option) -> Result<()> { +fn emit_llm_end_without_output( + handle: &LlmHandle, + metadata: Option, + lifecycle_subscribers: Option<&[EventSubscriberFn]>, +) -> Result<()> { ensure_runtime_owner()?; let (event, subscribers) = { let scope_stack = current_scope_stack(); let scope_guard = scope_stack.read().expect("scope stack lock poisoned"); let scope_subscribers = scope_guard.collect_scope_local_subscribers(); - let subscribers = 
snapshot_event_subscribers(scope_subscribers)?; + let subscribers = match lifecycle_subscribers { + Some(subscribers) => subscribers.to_vec(), + None => snapshot_event_subscribers(scope_subscribers)?, + }; let context = global_context(); let state = context .read() @@ -630,13 +660,19 @@ pub async fn llm_call_execute(params: LlmCallExecuteParams) -> Result { .model_name_opt(model_name) .build(), )?; - emit_llm_start( + let lifecycle_subscribers = { + let scope_stack = current_scope_stack(); + let scope_guard = scope_stack.read().expect("scope stack lock poisoned"); + snapshot_event_subscribers(scope_guard.collect_scope_local_subscribers())? + }; + emit_llm_start_with_subscribers( &handle, &intercepted_request, annotated_request.clone(), request_codec.as_deref(), + &lifecycle_subscribers, )?; - emit_pending_request_marks(&handle, pending_marks)?; + emit_pending_request_marks(&handle, pending_marks, &lifecycle_subscribers)?; let execution = { let scope_stack = current_scope_stack(); @@ -664,13 +700,15 @@ pub async fn llm_call_execute(params: LlmCallExecuteParams) -> Result { response_codec_errors_fatal: false, attach_estimated_cost: true, }, + Some(&lifecycle_subscribers), )?; Ok(response) } Err(error) => { let end_metadata = metadata_with_otel_status(metadata, "ERROR", Some(error.to_string())); - let _ = emit_llm_end_without_output(&handle, end_metadata); + let _ = + emit_llm_end_without_output(&handle, end_metadata, Some(&lifecycle_subscribers)); Err(error) } } @@ -787,13 +825,19 @@ pub async fn llm_stream_call_execute(params: LlmStreamCallExecuteParams) -> Resu .model_name_opt(model_name) .build(), )?; - emit_llm_start( + let lifecycle_subscribers = { + let scope_stack = current_scope_stack(); + let scope_guard = scope_stack.read().expect("scope stack lock poisoned"); + snapshot_event_subscribers(scope_guard.collect_scope_local_subscribers())? 
+ }; + emit_llm_start_with_subscribers( &handle, &intercepted_request, annotated_request, request_codec.as_deref(), + &lifecycle_subscribers, )?; - emit_pending_request_marks(&handle, pending_marks)?; + emit_pending_request_marks(&handle, pending_marks, &lifecycle_subscribers)?; let execution = { let scope_stack = current_scope_stack(); @@ -810,21 +854,22 @@ pub async fn llm_stream_call_execute(params: LlmStreamCallExecuteParams) -> Resu match execution(intercepted_request).await { Ok(raw_stream) => { - let wrapper = LlmStreamWrapper::new( + let wrapper = LlmStreamWrapper::new_managed( raw_stream, handle, collector, finalizer, - data, metadata, response_codec, + lifecycle_subscribers, ); Ok(Box::pin(wrapper) as LlmJsonStream) } Err(error) => { let end_metadata = metadata_with_otel_status(metadata, "ERROR", Some(error.to_string())); - let _ = emit_llm_end_without_output(&handle, end_metadata); + let _ = + emit_llm_end_without_output(&handle, end_metadata, Some(&lifecycle_subscribers)); Err(error) } } @@ -849,15 +894,11 @@ pub async fn llm_stream_call_execute(params: LlmStreamCallExecuteParams) -> Resu /// # Notes /// Conditional guardrails, codecs, and execution intercepts are not run by /// this helper. -pub fn llm_request_intercepts(name: &str, request: LlmRequest) -> Result { - Ok(llm_request_intercepts_with_marks(name, request)?.request) -} - -/// Run the LLM request-intercept chain and return pending lifecycle marks. +/// Run the LLM request-intercept chain and return its complete outcome. /// /// This helper does not emit the returned marks because it does not own an LLM /// lifecycle. Callers must attach them to the lifecycle they own. 
-pub fn llm_request_intercepts_with_marks( +pub fn llm_request_intercepts( name: &str, request: LlmRequest, ) -> Result { diff --git a/crates/core/src/api/registry.rs b/crates/core/src/api/registry.rs index d40bebe67..a4273291d 100644 --- a/crates/core/src/api/registry.rs +++ b/crates/core/src/api/registry.rs @@ -5,9 +5,9 @@ //! intercepts, and subscribers. use crate::api::runtime::{ - LlmConditionalFn, LlmExecutionFn, LlmRequestInterceptFn, LlmRequestInterceptWithMarksFn, - LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionFn, ToolConditionalFn, - ToolExecutionFn, ToolInterceptFn, ToolSanitizeFn, + LlmConditionalFn, LlmExecutionFn, LlmRequestInterceptFn, LlmSanitizeRequestFn, + LlmSanitizeResponseFn, LlmStreamExecutionFn, ToolConditionalFn, ToolExecutionFn, + ToolInterceptFn, ToolSanitizeFn, }; use crate::api::runtime::{current_scope_stack, global_context}; use crate::api::shared::ensure_runtime_owner; @@ -548,29 +548,12 @@ global_guardrail_registry_api!( ); global_intercept_registry_api!( /// Register a global LLM request intercept that can schedule lifecycle marks. - register_llm_request_intercept_with_marks, + register_llm_request_intercept, /// Deregister a global LLM request intercept. deregister_llm_request_intercept, llm_request_intercepts, - LlmRequestInterceptWithMarksFn + LlmRequestInterceptFn ); - -/// Register a global LLM request intercept without pending marks. -pub fn register_llm_request_intercept( - name: &str, - priority: i32, - break_chain: bool, - callable: LlmRequestInterceptFn, -) -> Result<()> { - register_llm_request_intercept_with_marks( - name, - priority, - break_chain, - std::sync::Arc::new(move |name, request, annotated| { - callable(name, request, annotated).map(Into::into) - }), - ) -} global_execution_registry_api!( /// Register a global LLM execution intercept. 
/// Execution intercepts can wrap or replace the non-streaming provider @@ -670,31 +653,12 @@ scope_guardrail_registry_api!( ); scope_intercept_registry_api!( /// Register a scope-local LLM request intercept that can schedule lifecycle marks. - scope_register_llm_request_intercept_with_marks, + scope_register_llm_request_intercept, /// Deregister a scope-local LLM request intercept. scope_deregister_llm_request_intercept, llm_request_intercepts, - LlmRequestInterceptWithMarksFn + LlmRequestInterceptFn ); - -/// Register a scope-local LLM request intercept without pending marks. -pub fn scope_register_llm_request_intercept( - scope_uuid: &uuid::Uuid, - name: &str, - priority: i32, - break_chain: bool, - callable: LlmRequestInterceptFn, -) -> Result<()> { - scope_register_llm_request_intercept_with_marks( - scope_uuid, - name, - priority, - break_chain, - std::sync::Arc::new(move |name, request, annotated| { - callable(name, request, annotated).map(Into::into) - }), - ) -} scope_execution_registry_api!( /// Register a scope-local LLM execution intercept. 
/// Execution intercepts can wrap or replace the non-streaming provider diff --git a/crates/core/src/api/runtime.rs b/crates/core/src/api/runtime.rs index 4d0c372c4..2351ae352 100644 --- a/crates/core/src/api/runtime.rs +++ b/crates/core/src/api/runtime.rs @@ -11,9 +11,9 @@ pub mod subscriber_dispatcher; pub use callbacks::{ EventSubscriberFn, LlmCollectorFn, LlmConditionalFn, LlmExecutionFn, LlmExecutionNextFn, - LlmFinalizerFn, LlmJsonStream, LlmRequestInterceptFn, LlmRequestInterceptWithMarksFn, - LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionFn, LlmStreamExecutionNextFn, - ToolConditionalFn, ToolExecutionFn, ToolExecutionNextFn, ToolInterceptFn, ToolSanitizeFn, + LlmFinalizerFn, LlmJsonStream, LlmRequestInterceptFn, LlmSanitizeRequestFn, + LlmSanitizeResponseFn, LlmStreamExecutionFn, LlmStreamExecutionNextFn, ToolConditionalFn, + ToolExecutionFn, ToolExecutionNextFn, ToolInterceptFn, ToolSanitizeFn, }; pub use global::global_context; pub use scope_stack::{ diff --git a/crates/core/src/api/runtime/callbacks.rs b/crates/core/src/api/runtime/callbacks.rs index dfccd2fea..7e6c0848c 100644 --- a/crates/core/src/api/runtime/callbacks.rs +++ b/crates/core/src/api/runtime/callbacks.rs @@ -163,26 +163,12 @@ pub type LlmConditionalFn = Arc Result> + /// - Third argument: Optional normalized request annotation to carry forward. /// /// # Returns -/// A [`Result`] containing the transformed request and optional annotation. +/// A [`Result`] containing the canonical request-intercept outcome. /// /// # Errors /// The callback can return any [`FlowError`](crate::error::FlowError) to abort /// the request-intercept chain. pub type LlmRequestInterceptFn = Arc< - dyn Fn( - &str, - LlmRequest, - Option, - ) -> Result<(LlmRequest, Option)> - + Send - + Sync, ->; -/// Rewrite or annotate an LLM request and schedule marks under its future scope. 
-/// -/// This callback has the same inputs as [`LlmRequestInterceptFn`] but returns a -/// structured outcome whose pending marks are emitted after the LLM-start -/// event and before provider execution. -pub type LlmRequestInterceptWithMarksFn = Arc< dyn Fn(&str, LlmRequest, Option) -> Result + Send + Sync, diff --git a/crates/core/src/api/runtime/state.rs b/crates/core/src/api/runtime/state.rs index ca03d3b29..00b9010ab 100644 --- a/crates/core/src/api/runtime/state.rs +++ b/crates/core/src/api/runtime/state.rs @@ -20,10 +20,10 @@ use crate::api::llm::{CreateLlmHandleParams, EndLlmHandleParams}; use crate::api::llm::{LlmHandle, LlmRequest}; use crate::api::registry::{ExecutionIntercept, Guardrail, Intercept}; use crate::api::runtime::callbacks::{ - EventSubscriberFn, LlmConditionalFn, LlmExecutionFn, LlmExecutionNextFn, - LlmRequestInterceptWithMarksFn, LlmSanitizeRequestFn, LlmSanitizeResponseFn, - LlmStreamExecutionFn, LlmStreamExecutionNextFn, LlmStreamExecutionRegistryRefs, - ToolConditionalFn, ToolExecutionFn, ToolExecutionNextFn, ToolInterceptFn, ToolSanitizeFn, + EventSubscriberFn, LlmConditionalFn, LlmExecutionFn, LlmExecutionNextFn, LlmRequestInterceptFn, + LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionFn, LlmStreamExecutionNextFn, + LlmStreamExecutionRegistryRefs, ToolConditionalFn, ToolExecutionFn, ToolExecutionNextFn, + ToolInterceptFn, ToolSanitizeFn, }; use crate::api::runtime::subscriber_dispatcher; use crate::api::scope::{CreateScopeHandleParams, EndScopeHandleParams, ScopeHandle, ScopeType}; @@ -63,7 +63,7 @@ pub struct NemoRelayContextState { /// Global LLM guardrails that can reject execution before the provider callback runs. pub(crate) llm_conditional_execution_guardrails: SortedRegistry>, /// Global LLM request intercepts that can rewrite or annotate requests. 
- pub(crate) llm_request_intercepts: SortedRegistry>, + pub(crate) llm_request_intercepts: SortedRegistry>, /// Global non-streaming LLM execution intercepts that wrap callback execution. pub(crate) llm_execution_intercepts: SortedRegistry>, /// Global streaming LLM execution intercepts that wrap stream-producing callbacks. @@ -1011,8 +1011,8 @@ impl NemoRelayContextState { /// are released. pub(crate) fn llm_request_intercept_entries( &self, - scope_locals: &[&SortedRegistry>], - ) -> Vec> { + scope_locals: &[&SortedRegistry>], + ) -> Vec> { merge_intercept_entries(&self.llm_request_intercepts, scope_locals) .into_iter() .cloned() @@ -1041,7 +1041,7 @@ impl NemoRelayContextState { name: &str, request: LlmRequest, annotated: Option, - entries: &[Intercept], + entries: &[Intercept], ) -> crate::error::Result { let mut request_value = request; let mut annotated_value = annotated; @@ -1129,11 +1129,7 @@ impl NemoRelayContextState { fn end_timestamp_after(started_at: chrono::DateTime) -> chrono::DateTime { let now = Utc::now(); - if now > started_at { - now - } else { - started_at + Duration::microseconds(1) - } + std::cmp::max(now, started_at + Duration::microseconds(1)) } impl Default for NemoRelayContextState { diff --git a/crates/core/src/context/registries.rs b/crates/core/src/context/registries.rs index c6781eee0..2a0d2fde9 100644 --- a/crates/core/src/context/registries.rs +++ b/crates/core/src/context/registries.rs @@ -11,7 +11,7 @@ use std::collections::HashMap; use crate::api::registry::{ExecutionIntercept, Guardrail, Intercept}; use crate::api::runtime::{ - EventSubscriberFn, LlmConditionalFn, LlmExecutionFn, LlmRequestInterceptWithMarksFn, + EventSubscriberFn, LlmConditionalFn, LlmExecutionFn, LlmRequestInterceptFn, LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionFn, ToolConditionalFn, ToolExecutionFn, ToolInterceptFn, ToolSanitizeFn, }; @@ -41,7 +41,7 @@ pub(crate) struct ScopeLocalRegistries { /// LLM guardrails that can reject execution 
before the provider callback runs. pub(crate) llm_conditional_execution_guardrails: SortedRegistry>, /// LLM request intercepts that can rewrite or annotate requests. - pub(crate) llm_request_intercepts: SortedRegistry>, + pub(crate) llm_request_intercepts: SortedRegistry>, /// Non-streaming LLM execution intercepts that wrap callback execution. pub(crate) llm_execution_intercepts: SortedRegistry>, /// Streaming LLM execution intercepts that wrap stream-producing callbacks. diff --git a/crates/core/src/plugin.rs b/crates/core/src/plugin.rs index c2934e3d0..94420f6df 100644 --- a/crates/core/src/plugin.rs +++ b/crates/core/src/plugin.rs @@ -27,16 +27,15 @@ use crate::api::registry::{ deregister_tool_request_intercept, deregister_tool_sanitize_request_guardrail, deregister_tool_sanitize_response_guardrail, register_llm_conditional_execution_guardrail, register_llm_execution_intercept, register_llm_request_intercept, - register_llm_request_intercept_with_marks, register_llm_sanitize_request_guardrail, - register_llm_sanitize_response_guardrail, register_llm_stream_execution_intercept, - register_tool_conditional_execution_guardrail, register_tool_execution_intercept, - register_tool_request_intercept, register_tool_sanitize_request_guardrail, - register_tool_sanitize_response_guardrail, + register_llm_sanitize_request_guardrail, register_llm_sanitize_response_guardrail, + register_llm_stream_execution_intercept, register_tool_conditional_execution_guardrail, + register_tool_execution_intercept, register_tool_request_intercept, + register_tool_sanitize_request_guardrail, register_tool_sanitize_response_guardrail, }; use crate::api::runtime::{ EventSubscriberFn, LlmConditionalFn, LlmExecutionFn, LlmRequestInterceptFn, - LlmRequestInterceptWithMarksFn, LlmSanitizeRequestFn, LlmSanitizeResponseFn, - LlmStreamExecutionFn, ToolConditionalFn, ToolExecutionFn, ToolInterceptFn, ToolSanitizeFn, + LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionFn, 
ToolConditionalFn, + ToolExecutionFn, ToolInterceptFn, ToolSanitizeFn, }; use crate::api::subscriber::{deregister_subscriber, register_subscriber}; pub use nemo_relay_types::plugin::{ConfigDiagnostic, DiagnosticLevel}; @@ -351,37 +350,6 @@ impl PluginRegistrationContext { Ok(()) } - /// Registers an LLM request intercept that can schedule lifecycle marks. - pub fn register_llm_request_intercept_with_marks( - &mut self, - name: &str, - priority: i32, - break_chain: bool, - callback: LlmRequestInterceptWithMarksFn, - ) -> Result<()> { - let qualified_name = self.qualify_name(name); - register_llm_request_intercept_with_marks(&qualified_name, priority, break_chain, callback) - .map_err(|err| { - PluginError::RegistrationFailed(format!("llm request intercept: {err}")) - })?; - - let name_owned = qualified_name; - self.registrations.push(PluginRegistration::new( - "plugin", - name_owned.clone(), - Box::new(move || { - deregister_llm_request_intercept(&name_owned) - .map(|_| ()) - .map_err(|err| { - PluginError::RegistrationFailed(format!( - "llm request intercept deregistration failed: {err}" - )) - }) - }), - )); - Ok(()) - } - /// Registers a tool sanitize-request guardrail and records its rollback closure. 
pub fn register_tool_sanitize_request_guardrail( &mut self, diff --git a/crates/core/src/plugin/dynamic/native.rs b/crates/core/src/plugin/dynamic/native.rs index 6582bef95..381d9b8e3 100644 --- a/crates/core/src/plugin/dynamic/native.rs +++ b/crates/core/src/plugin/dynamic/native.rs @@ -15,30 +15,29 @@ use std::task::{Context, Poll}; use chrono::{DateTime, Utc}; use libloading::{Library, Symbol}; use nemo_relay_plugin::{ - NEMO_RELAY_NATIVE_ABI_VERSION, NemoRelayNativeEventSubscriberCb, NemoRelayNativeFreeFn, - NemoRelayNativeHostApiV1, NemoRelayNativeJsonCb, NemoRelayNativeLlmConditionalCb, - NemoRelayNativeLlmExecutionCb, NemoRelayNativeLlmRequestCb, - NemoRelayNativeLlmRequestInterceptCb, NemoRelayNativeLlmStreamExecutionCb, - NemoRelayNativeLlmStreamV1, NemoRelayNativePluginContext, NemoRelayNativePluginEntry, - NemoRelayNativePluginV1, NemoRelayNativeScopeHandle, NemoRelayNativeScopeStack, - NemoRelayNativeScopeStackBinding, NemoRelayNativeScopeType, NemoRelayNativeString, - NemoRelayNativeToolConditionalCb, NemoRelayNativeToolExecutionCb, NemoRelayNativeToolJsonCb, - NemoRelayNativeWithScopeStackCb, NemoRelayStatus, + NEMO_RELAY_NATIVE_ABI_VERSION, NEMO_RELAY_NATIVE_LLM_INTERCEPT_OUTCOME_CONTRACT_VERSION, + NemoRelayNativeEventSubscriberCb, NemoRelayNativeFreeFn, NemoRelayNativeHostApiV1, + NemoRelayNativeJsonCb, NemoRelayNativeLlmConditionalCb, NemoRelayNativeLlmExecutionCb, + NemoRelayNativeLlmRequestCb, NemoRelayNativeLlmRequestInterceptCb, + NemoRelayNativeLlmStreamExecutionCb, NemoRelayNativeLlmStreamV1, NemoRelayNativePluginContext, + NemoRelayNativePluginEntry, NemoRelayNativePluginV1, NemoRelayNativeScopeHandle, + NemoRelayNativeScopeStack, NemoRelayNativeScopeStackBinding, NemoRelayNativeScopeType, + NemoRelayNativeString, NemoRelayNativeToolConditionalCb, NemoRelayNativeToolExecutionCb, + NemoRelayNativeToolJsonCb, NemoRelayNativeWithScopeStackCb, NemoRelayStatus, }; use semver::{Version, VersionReq}; -use serde::Deserialize; use serde_json::{Map, 
Value as Json}; use sha2::{Digest, Sha256}; use tokio::runtime::Runtime; use tokio_stream::{Stream, StreamExt}; -use crate::api::event::{Event, PendingMarkSpec}; +use crate::api::event::Event; use crate::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; use crate::api::runtime::{ EventSubscriberFn, LlmConditionalFn, LlmExecutionFn, LlmExecutionNextFn, LlmJsonStream, - LlmRequestInterceptWithMarksFn, LlmSanitizeRequestFn, LlmSanitizeResponseFn, - LlmStreamExecutionFn, LlmStreamExecutionNextFn, ToolConditionalFn, ToolExecutionFn, - ToolExecutionNextFn, ToolInterceptFn, ToolSanitizeFn, + LlmRequestInterceptFn, LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionFn, + LlmStreamExecutionNextFn, ToolConditionalFn, ToolExecutionFn, ToolExecutionNextFn, + ToolInterceptFn, ToolSanitizeFn, }; use crate::api::runtime::{ ScopeStackHandle, ThreadScopeStackBinding, capture_thread_scope_stack, create_scope_stack, @@ -49,7 +48,6 @@ use crate::api::scope::{ EmitMarkEventParams, PopScopeParams, PushScopeParams, ScopeAttributes, ScopeHandle, ScopeType, }; use crate::api::scope::{event as emit_scope_mark, get_handle, pop_scope, push_scope}; -use crate::codec::request::AnnotatedLlmRequest; use crate::error::{FlowError, Result as FlowResult}; use crate::plugin::{ ConfigDiagnostic, DiagnosticLevel, Plugin, PluginError, PluginRegistrationContext, @@ -403,6 +401,13 @@ fn validate_plugin_descriptor( plugin.struct_size ))); } + if plugin.llm_request_intercept_outcome_contract_version + != NEMO_RELAY_NATIVE_LLM_INTERCEPT_OUTCOME_CONTRACT_VERSION + { + return Err(PluginError::InvalidConfig(format!( + "native plugin '{plugin_id}' returned an incompatible LLM request-intercept outcome contract" + ))); + } if plugin.plugin_kind.is_null() { return Err(PluginError::InvalidConfig(format!( "native plugin '{plugin_id}' returned a null plugin_kind" @@ -598,6 +603,8 @@ fn native_host_api() -> *const NemoRelayNativeHostApiV1 { scope_stack_binding_free: native_scope_stack_binding_free, 
scope_stack_active: native_scope_stack_active, scope_stack_with_current: native_scope_stack_with_current, + llm_request_intercept_outcome_contract_version: + NEMO_RELAY_NATIVE_LLM_INTERCEPT_OUTCOME_CONTRACT_VERSION, }) as *const _ } @@ -1333,7 +1340,7 @@ unsafe extern "C" fn native_plugin_context_register_llm_request_intercept( Ok(name) => name, Err(status) => return status, }; - match ctx.register_llm_request_intercept_with_marks( + match ctx.register_llm_request_intercept( &name, priority, break_chain, @@ -1691,7 +1698,7 @@ fn wrap_llm_request_intercept_fn( cb: NemoRelayNativeLlmRequestInterceptCb, user_data: *mut c_void, free_fn: NemoRelayNativeFreeFn, -) -> LlmRequestInterceptWithMarksFn { +) -> LlmRequestInterceptFn { let user_data = make_user_data(instance, user_data, free_fn); Arc::new(move |name, request, annotated| { clear_native_last_error(); @@ -1713,16 +1720,14 @@ fn wrap_llm_request_intercept_fn( } None => ptr::null_mut(), }; - let mut out_request = ptr::null_mut(); - let mut out_annotated = ptr::null_mut(); + let mut out_outcome = ptr::null_mut(); let status = unsafe { cb( user_data.ptr, name_string, request_string, annotated_string, - &mut out_request, - &mut out_annotated, + &mut out_outcome, ) }; unsafe { @@ -1732,80 +1737,26 @@ fn wrap_llm_request_intercept_fn( } if status != NemoRelayStatus::Ok { unsafe { - native_string_free(out_request); - native_string_free(out_annotated); + native_string_free(out_outcome); } return Err(flow_error_from_status( status, "native LLM request intercept failed", )); } - let request_json = json_from_native_string( - out_request, - "native LLM request intercept returned null request", + let outcome_json = json_from_native_string( + out_outcome, + "native LLM request intercept returned null outcome", ); - let annotated_json = if out_annotated.is_null() { - Ok(None) - } else { - json_from_native_string(out_annotated, "invalid annotated request").map(Some) - }; unsafe { - native_string_free(out_request); - 
native_string_free(out_annotated); + native_string_free(out_outcome); } - let request_json = request_json?; - let annotated_json = annotated_json?; - let request: LlmRequest = serde_json::from_value(request_json) - .map_err(|err| FlowError::Internal(format!("invalid LLM request JSON: {err}")))?; - let (annotated_request, pending_marks) = match annotated_json { - Some(value) - if value.get(nemo_relay_types::api::llm::NATIVE_LLM_INTERCEPT_OUTCOME_FIELD) - == Some(&Json::Bool(true)) => - { - let metadata: NativeLlmRequestInterceptOutcome = serde_json::from_value(value) - .map_err(|err| { - FlowError::Internal(format!("invalid marked LLM outcome JSON: {err}")) - })?; - (metadata.annotated_request, metadata.pending_marks) - } - Some(value) => { - let annotated = - serde_json::from_value::(value).map_err(|err| { - FlowError::Internal(format!("invalid annotated request JSON: {err}")) - })?; - (Some(annotated), Vec::new()) - } - None => (None, Vec::new()), - }; - Ok(LlmRequestInterceptOutcome { - request, - annotated_request, - pending_marks, + serde_json::from_value::(outcome_json?).map_err(|err| { + FlowError::Internal(format!("invalid LLM request intercept outcome JSON: {err}")) }) }) } -#[derive(Deserialize)] -struct NativeLlmRequestInterceptOutcome { - #[serde(rename = "__nemo_relay_llm_intercept_outcome")] - _marked_outcome: bool, - annotated_request: Option, - #[serde(default)] - pending_marks: Vec, -} - -#[cfg(test)] -#[test] -fn native_llm_request_intercept_outcome_defaults_omitted_pending_marks() { - let outcome: NativeLlmRequestInterceptOutcome = serde_json::from_value(serde_json::json!({ - "__nemo_relay_llm_intercept_outcome": true - })) - .unwrap(); - - assert!(outcome.annotated_request.is_none()); - assert!(outcome.pending_marks.is_empty()); -} - fn wrap_llm_execution_fn( instance: Arc, cb: NemoRelayNativeLlmExecutionCb, diff --git a/crates/core/src/plugin/dynamic/worker.rs b/crates/core/src/plugin/dynamic/worker.rs index a4e46e562..3088c31f5 100644 --- 
a/crates/core/src/plugin/dynamic/worker.rs +++ b/crates/core/src/plugin/dynamic/worker.rs @@ -1125,7 +1125,7 @@ impl WorkerPluginCallback { model_name: &str, request: LlmRequest, annotated: Option, - ) -> FlowResult<(LlmRequest, Option)> { + ) -> FlowResult { let invoke = self.base_request( registration_name, RegistrationSurface::LlmRequestIntercept, @@ -1140,26 +1140,18 @@ impl WorkerPluginCallback { let response = self.invoke_blocking(invoke)?; match response.result { Some(invoke_response_result::Result::LlmRequest(result)) => { - let request = required_envelope(result.request, "llm request intercept request")?; - let request = decode_json_envelope::(&request).map_err(|err| { - FlowError::Internal(format!("worker returned invalid LLM request: {err}")) - })?; - let annotated = if result.has_annotated_request { - let envelope = required_envelope( - result.annotated_request, - "llm request intercept annotated request", - )?; - Some( - decode_json_envelope::(&envelope).map_err(|err| { - FlowError::Internal(format!( - "worker returned invalid annotated LLM request: {err}" - )) - })?, - ) - } else { - None - }; - Ok((request, annotated)) + let outcome = required_envelope(result.outcome, "llm request intercept outcome")?; + if outcome.schema != "nemo.relay.LlmRequestInterceptOutcome@1" { + return Err(FlowError::Internal(format!( + "worker returned unsupported LLM request intercept outcome schema: {}", + outcome.schema + ))); + } + decode_json_envelope(&outcome).map_err(|err| { + FlowError::Internal(format!( + "worker returned invalid LLM request intercept outcome: {err}" + )) + }) } Some(invoke_response_result::Result::Error(error)) => Err(worker_error_to_flow(error)), _ => Err(FlowError::Internal( diff --git a/crates/core/src/stream.rs b/crates/core/src/stream.rs index c2dde21f5..239228c83 100644 --- a/crates/core/src/stream.rs +++ b/crates/core/src/stream.rs @@ -35,7 +35,7 @@ use crate::api::event::{BaseEvent, MarkEvent}; use crate::api::llm::LlmHandle; use 
crate::api::runtime::NemoRelayContextState; use crate::api::runtime::global_context; -use crate::api::runtime::{ScopeStackHandle, current_scope_stack}; +use crate::api::runtime::{EventSubscriberFn, ScopeStackHandle, current_scope_stack}; use crate::api::shared::metadata_with_otel_status; use crate::codec::response::{AnnotatedLlmResponse, attach_estimated_cost_for_provider}; use crate::codec::traits::LlmResponseCodec; @@ -63,6 +63,7 @@ pub struct LlmStreamWrapper { finalizer: Option Json + Send>>, response_codec: Option>, metadata: Option, + subscribers: Vec, chunk_index: u64, ended: bool, } @@ -97,6 +98,36 @@ impl LlmStreamWrapper { _data: Option, metadata: Option, response_codec: Option>, + ) -> Self { + let subscribers = { + let scope_stack = current_scope_stack(); + let scope_guard = scope_stack.read().expect("scope stack lock poisoned"); + let scope_subscribers = scope_guard.collect_scope_local_subscribers(); + let context = global_context(); + context + .read() + .map(|state| state.collect_event_subscribers(&scope_subscribers)) + .unwrap_or_default() + }; + Self::new_managed( + inner, + handle, + collector, + finalizer, + metadata, + response_codec, + subscribers, + ) + } + + pub(crate) fn new_managed( + inner: Pin> + Send>>, + handle: LlmHandle, + collector: Box Result<()> + Send>, + finalizer: Box Json + Send>, + metadata: Option, + response_codec: Option>, + subscribers: Vec, ) -> Self { Self { inner, @@ -106,6 +137,7 @@ impl LlmStreamWrapper { finalizer: Some(finalizer), response_codec, metadata, + subscribers, chunk_index: 0, ended: false, } @@ -155,19 +187,17 @@ impl LlmStreamWrapper { let ss_guard = self.scope_stack.read().expect("scope stack lock poisoned"); let sl = ss_guard.collect_scope_local_registries(|r| &r.llm_sanitize_response_guardrails); - let sl_subs = ss_guard.collect_scope_local_subscribers(); let ctx = global_context(); let state = ctx.read(); match state { Ok(state) => { - let subscribers = state.collect_event_subscribers(&sl_subs); let 
entries = state.llm_sanitize_response_entries(&sl); - Some((entries, subscribers)) + Some(entries) } Err(_) => None, } }; - let Some((entries, subscribers)) = snapshot else { + let Some(entries) = snapshot else { return; }; let sanitized = @@ -197,7 +227,7 @@ impl LlmStreamWrapper { } }; if let Some(event) = event_snapshot { - NemoRelayContextState::emit_event(&event, &subscribers); + NemoRelayContextState::emit_event(&event, &self.subscribers); } } @@ -205,15 +235,10 @@ impl LlmStreamWrapper { fn emit_chunk_mark(&self, chunk_index: u64, raw_chunk: &Json) { let data = llm_chunk_mark_data(chunk_index, raw_chunk); let event_snapshot = { - let Ok(ss_guard) = self.scope_stack.read() else { - return; - }; - let sl_subs = ss_guard.collect_scope_local_subscribers(); let ctx = global_context(); let state = ctx.read(); match state { Ok(state) => { - let subscribers = state.collect_event_subscribers(&sl_subs); let event = state.create_event(MarkEvent::new( BaseEvent::builder() .name("llm.chunk") @@ -223,13 +248,13 @@ impl LlmStreamWrapper { None, None, )); - Some((event, subscribers)) + Some(event) } Err(_) => None, } }; - if let Some((event, subscribers)) = event_snapshot { - NemoRelayContextState::emit_event(&event, &subscribers); + if let Some(event) = event_snapshot { + NemoRelayContextState::emit_event(&event, &self.subscribers); } } } diff --git a/crates/core/tests/fixtures/native_plugin/src/lib.rs b/crates/core/tests/fixtures/native_plugin/src/lib.rs index 4eeebd5ba..cb99cb66c 100644 --- a/crates/core/tests/fixtures/native_plugin/src/lib.rs +++ b/crates/core/tests/fixtures/native_plugin/src/lib.rs @@ -114,7 +114,7 @@ impl NativePlugin for FixtureNativePlugin { 0, |_request| Ok(None), )?; - ctx.register_llm_request_intercept_with_marks( + ctx.register_llm_request_intercept( "fixture_llm_request_intercept", 0, false, diff --git a/crates/core/tests/fixtures/worker_plugin/src/main.rs b/crates/core/tests/fixtures/worker_plugin/src/main.rs index c6e1e2509..28de01f8a 100644 
--- a/crates/core/tests/fixtures/worker_plugin/src/main.rs +++ b/crates/core/tests/fixtures/worker_plugin/src/main.rs @@ -188,7 +188,10 @@ impl WorkerPlugin for FixtureWorkerPlugin { .insert("worker_plugin_annotated_request".into(), json!(true)); annotated }); - Ok((mark_llm_request(request, "worker_plugin_llm_request_intercept"), annotated)) + Ok(nemo_relay_worker::LlmRequestInterceptOutcome::new( + mark_llm_request(request, "worker_plugin_llm_request_intercept"), + annotated, + )) }, ); ctx.register_llm_execution_intercept( diff --git a/crates/core/tests/integration/api_surface_tests.rs b/crates/core/tests/integration/api_surface_tests.rs index e751c15b3..3bc8538a0 100644 --- a/crates/core/tests/integration/api_surface_tests.rs +++ b/crates/core/tests/integration/api_surface_tests.rs @@ -423,7 +423,11 @@ fn test_global_registry_and_subscriber_wrappers_cover_success_and_duplicates() { "llm-request", 1, false, - Arc::new(|_name, request, annotated| Ok((request, annotated))), + Arc::new(|_name, request, annotated| { + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + request, annotated, + )) + }), ) .unwrap(); assert!(deregister_llm_request_intercept("llm-request").unwrap()); @@ -612,7 +616,11 @@ fn test_scope_registry_and_subscriber_wrappers_cover_success_duplicates_and_miss "llm-request", 1, false, - Arc::new(|_name, request, annotated| Ok((request, annotated))), + Arc::new(|_name, request, annotated| { + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + request, annotated, + )) + }), ) .unwrap(); assert!(scope_deregister_llm_request_intercept(&scope.uuid, "llm-request").unwrap()); @@ -931,7 +939,9 @@ async fn test_llm_api_emits_sanitized_events_and_covers_error_paths() { false, Arc::new(|_name, mut request, annotated| { request.headers.insert("x-intercepted".into(), json!(true)); - Ok((request, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + request, annotated, + )) }), ) .unwrap(); @@ -940,7 +950,10 @@ async fn 
test_llm_api_emits_sanitized_events_and_covers_error_paths() { make_llm_request(json!({"messages": [{"role": "user", "content": "hello"}]})), ) .unwrap(); - assert_eq!(intercepted.headers.get("x-intercepted"), Some(&json!(true))); + assert_eq!( + intercepted.request.headers.get("x-intercepted"), + Some(&json!(true)) + ); deregister_llm_request_intercept("llm-request").unwrap(); register_llm_conditional_execution_guardrail( diff --git a/crates/core/tests/integration/middleware_tests.rs b/crates/core/tests/integration/middleware_tests.rs index 0394f6c10..8e8c9ea92 100644 --- a/crates/core/tests/integration/middleware_tests.rs +++ b/crates/core/tests/integration/middleware_tests.rs @@ -19,7 +19,7 @@ use nemo_relay::api::event::{ }; use nemo_relay::api::llm::{ LlmCallExecuteParams, LlmStreamCallExecuteParams, llm_call_execute, llm_request_intercepts, - llm_request_intercepts_with_marks, llm_stream_call_execute, + llm_stream_call_execute, }; use nemo_relay::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; use nemo_relay::api::registry::{ @@ -30,14 +30,13 @@ use nemo_relay::api::registry::{ deregister_tool_request_intercept, deregister_tool_sanitize_request_guardrail, deregister_tool_sanitize_response_guardrail, register_llm_conditional_execution_guardrail, register_llm_execution_intercept, register_llm_request_intercept, - register_llm_request_intercept_with_marks, register_llm_sanitize_request_guardrail, - register_llm_sanitize_response_guardrail, register_llm_stream_execution_intercept, - register_tool_conditional_execution_guardrail, register_tool_execution_intercept, - register_tool_request_intercept, register_tool_sanitize_request_guardrail, - register_tool_sanitize_response_guardrail, scope_register_llm_conditional_execution_guardrail, - scope_register_llm_execution_intercept, scope_register_llm_request_intercept, - scope_register_llm_sanitize_request_guardrail, scope_register_llm_sanitize_response_guardrail, - scope_register_llm_stream_execution_intercept, + 
register_llm_sanitize_request_guardrail, register_llm_sanitize_response_guardrail, + register_llm_stream_execution_intercept, register_tool_conditional_execution_guardrail, + register_tool_execution_intercept, register_tool_request_intercept, + register_tool_sanitize_request_guardrail, register_tool_sanitize_response_guardrail, + scope_register_llm_conditional_execution_guardrail, scope_register_llm_execution_intercept, + scope_register_llm_request_intercept, scope_register_llm_sanitize_request_guardrail, + scope_register_llm_sanitize_response_guardrail, scope_register_llm_stream_execution_intercept, scope_register_tool_conditional_execution_guardrail, scope_register_tool_execution_intercept, scope_register_tool_request_intercept, scope_register_tool_sanitize_request_guardrail, scope_register_tool_sanitize_response_guardrail, @@ -1703,13 +1702,17 @@ fn test_llm_request_intercept_registry_mutations_apply_to_later_calls() { Arc::new(move |_, request, annotated| { record_middleware_callback(&tracked, "llm_request_late"); assert_middleware_callback_locks_are_free(); - Ok((request, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + request, annotated, + )) }), ) .unwrap(); } - Ok((request, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + request, annotated, + )) }), ) .unwrap(); @@ -1722,7 +1725,7 @@ fn test_llm_request_intercept_registry_mutations_apply_to_later_calls() { }, ) .unwrap(); - assert_eq!(request.content["round"], 1); + assert_eq!(request.request.content["round"], 1); assert_middleware_callback_labels(&callbacks, &["llm_request_initial"]); callbacks.lock().unwrap().clear(); @@ -1734,7 +1737,7 @@ fn test_llm_request_intercept_registry_mutations_apply_to_later_calls() { }, ) .unwrap(); - assert_eq!(request.content["round"], 2); + assert_eq!(request.request.content["round"], 2); assert_middleware_callback_labels(&callbacks, &["llm_request_initial", "llm_request_late"]); 
deregister_llm_request_intercept("snapshot_llm_request_initial").unwrap(); @@ -1950,7 +1953,9 @@ async fn test_llm_middleware_callbacks_run_without_registry_or_scope_locks() { Arc::new(move |_, request, annotated| { record_middleware_callback(&tracked, "llm_request_global"); assert_middleware_callback_locks_are_free(); - Ok((request, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + request, annotated, + )) }), ) .unwrap(); @@ -1963,7 +1968,9 @@ async fn test_llm_middleware_callbacks_run_without_registry_or_scope_locks() { Arc::new(move |_, request, annotated| { record_middleware_callback(&tracked, "llm_request_scope"); assert_middleware_callback_locks_are_free(); - Ok((request, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + request, annotated, + )) }), ) .unwrap(); @@ -2574,7 +2581,9 @@ async fn test_llm_request_intercept_transforms() { false, Arc::new(|_name: &str, mut req: LlmRequest, annotated| { req.headers.insert("x-intercepted".into(), json!(true)); - Ok((req, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, annotated, + )) }), ) .unwrap(); @@ -2585,7 +2594,7 @@ async fn test_llm_request_intercept_transforms() { }; let result = llm_request_intercepts("test_llm", request).unwrap(); - assert_eq!(result.headers["x-intercepted"], true); + assert_eq!(result.request.headers["x-intercepted"], true); // Cleanup deregister_llm_request_intercept("llm_req_i").unwrap(); @@ -2602,7 +2611,7 @@ fn test_llm_request_intercept_pending_marks_preserve_order_and_break_chain() { ("pending_break", 2, true, "second"), ("pending_skipped", 3, false, "skipped"), ] { - register_llm_request_intercept_with_marks( + register_llm_request_intercept( name, priority, break_chain, @@ -2614,7 +2623,7 @@ fn test_llm_request_intercept_pending_marks_preserve_order_and_break_chain() { .unwrap(); } - let outcome = llm_request_intercepts_with_marks( + let outcome = llm_request_intercepts( "llm", LlmRequest { headers: 
serde_json::Map::new(), @@ -2652,13 +2661,13 @@ async fn test_managed_llm_emits_pending_marks_under_started_scope() { ) .unwrap(); - register_llm_request_intercept_with_marks( + register_llm_request_intercept( "pending_managed", 1, false, Arc::new(|_name, request, annotated| { - Ok( - LlmRequestInterceptOutcome::new(request, annotated).with_pending_mark( + Ok(LlmRequestInterceptOutcome::new(request, annotated) + .with_pending_mark( PendingMarkSpec::builder() .name("request.optimized") .category(EventCategory::custom()) @@ -2669,8 +2678,12 @@ async fn test_managed_llm_emits_pending_marks_under_started_scope() { ) .data(json!({"saved_tokens": 12})) .build(), - ), - ) + ) + .with_pending_mark( + PendingMarkSpec::builder() + .name("request.optimized.second") + .build(), + )) }), ) .unwrap(); @@ -2694,12 +2707,9 @@ async fn test_managed_llm_emits_pending_marks_under_started_scope() { .unwrap(); let provider_request = provider_request.lock().unwrap().clone().unwrap(); - assert!( - serde_json::to_value(provider_request) - .unwrap() - .get("__nemo_relay_llm_intercept_outcome") - .is_none() - ); + let provider_json = serde_json::to_value(provider_request).unwrap(); + assert!(provider_json.get("pending_marks").is_none()); + assert!(provider_json.get("annotated_request").is_none()); let captured = captured_events_snapshot(&events); let start = captured @@ -2713,8 +2723,22 @@ async fn test_managed_llm_emits_pending_marks_under_started_scope() { .iter() .find(|event| event.name() == "request.optimized") .unwrap(); + let second_mark = captured + .iter() + .find(|event| event.name() == "request.optimized.second") + .unwrap(); + let end = captured + .iter() + .find(|event| { + event.name() == "pending-managed-llm" + && event.scope_category() == Some(ScopeCategory::End) + }) + .unwrap(); assert_eq!(mark.parent_uuid(), Some(start.uuid())); + assert_eq!(second_mark.parent_uuid(), Some(start.uuid())); assert!(mark.timestamp() > start.timestamp()); + assert_eq!(mark.timestamp(), 
second_mark.timestamp()); + assert!(end.timestamp() >= mark.timestamp()); assert_eq!(mark.data().unwrap()["saved_tokens"], 12); deregister_llm_request_intercept("pending_managed").unwrap(); @@ -2734,7 +2758,7 @@ async fn test_failed_request_intercept_does_not_emit_pending_marks_or_start_scop Arc::new(move |event: &Event| captured.lock().unwrap().push(event.clone())), ) .unwrap(); - register_llm_request_intercept_with_marks( + register_llm_request_intercept( "pending_before_failure", 1, false, @@ -2744,7 +2768,7 @@ async fn test_failed_request_intercept_does_not_emit_pending_marks_or_start_scop }), ) .unwrap(); - register_llm_request_intercept_with_marks( + register_llm_request_intercept( "pending_failure", 2, false, @@ -2864,7 +2888,9 @@ async fn test_llm_start_emits_before_short_circuit_execution_intercept() { .as_object_mut() .unwrap() .insert("phase".into(), json!("request")); - Ok((req, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, annotated, + )) }), ) .unwrap(); @@ -2956,7 +2982,9 @@ async fn test_llm_stream_start_emits_before_short_circuit_execution_intercept() .as_object_mut() .unwrap() .insert("phase".into(), json!("request")); - Ok((req, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, annotated, + )) }), ) .unwrap(); diff --git a/crates/core/tests/integration/pipeline_tests.rs b/crates/core/tests/integration/pipeline_tests.rs index eefe05cd3..4431a6102 100644 --- a/crates/core/tests/integration/pipeline_tests.rs +++ b/crates/core/tests/integration/pipeline_tests.rs @@ -310,7 +310,9 @@ async fn test_decode_runs_before_intercepts() { false, Arc::new(move |_name, req, annotated| { *cap.lock().unwrap() = Some(annotated.clone()); - Ok((req, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, annotated, + )) }), ) .unwrap(); @@ -362,7 +364,10 @@ async fn test_encode_runs_after_intercepts() { Arc::new(|_name, req, annotated| { let mut ann = annotated.unwrap(); ann.model = 
Some("modified".into()); - Ok((req, Some(ann))) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, + Some(ann), + )) }), ) .unwrap(); @@ -424,7 +429,9 @@ async fn test_annotated_intercept_receives_both() { false, Arc::new(move |_name, req, annotated| { *cp.lock().unwrap() = Some((req.clone(), annotated.clone())); - Ok((req, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, annotated, + )) }), ) .unwrap(); @@ -465,12 +472,12 @@ async fn test_annotated_intercept_receives_both() { // =========================================================================== #[tokio::test] -async fn test_legacy_intercept_backward_compat() { +async fn test_canonical_intercept_with_and_without_codec() { let _lock = TEST_MUTEX.lock().unwrap(); reset_global(); setup_isolated_thread(); - // Part 1: Legacy intercept with no Codec + // Part 1: canonical intercept with no codec. let legacy_called_1 = Arc::new(Mutex::new(false)); let lc1 = legacy_called_1.clone(); register_llm_request_intercept( @@ -480,7 +487,9 @@ async fn test_legacy_intercept_backward_compat() { Arc::new(move |_name, mut req, annotated| { *lc1.lock().unwrap() = true; req.headers.insert("x-legacy".into(), json!("was-here")); - Ok((req, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, annotated, + )) }), ) .unwrap(); @@ -510,7 +519,7 @@ async fn test_legacy_intercept_backward_compat() { // Cleanup part 1 deregister_llm_request_intercept("legacy_1").unwrap(); - // Part 2: Legacy intercept WITH Codec — legacy intercept still runs + // Part 2: canonical intercept with a codec. 
reset_global(); let (codec, _, _) = make_tracking_codec("codec_D"); @@ -524,7 +533,9 @@ async fn test_legacy_intercept_backward_compat() { Arc::new(move |_name, mut req, annotated| { *lc2.lock().unwrap() = true; req.headers.insert("x-legacy-2".into(), json!("also-here")); - Ok((req, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, annotated, + )) }), ) .unwrap(); @@ -580,7 +591,9 @@ async fn test_stream_path_also_decodes() { false, Arc::new(move |_name, req, annotated| { *ca.lock().unwrap() = Some(annotated.clone()); - Ok((req, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, annotated, + )) }), ) .unwrap(); @@ -644,7 +657,9 @@ async fn test_shared_helper_both_paths() { false, Arc::new(move |_name, req, annotated| { *acc.lock().unwrap() += 1; - Ok((req, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, annotated, + )) }), ) .unwrap(); @@ -721,7 +736,9 @@ async fn test_explicit_codec_param_overrides() { if let Some(ref ann) = annotated { *cm.lock().unwrap() = ann.model.clone(); } - Ok((req, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, annotated, + )) }), ) .unwrap(); @@ -768,7 +785,10 @@ async fn test_encode_merge_not_replace() { Arc::new(|_name, req, annotated| { let mut ann = annotated.unwrap(); ann.model = Some("new_model".into()); - Ok((req, Some(ann))) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, + Some(ann), + )) }), ) .unwrap(); @@ -836,7 +856,9 @@ async fn test_unified_chain_priority_order() { false, Arc::new(move |_name, req, annotated| { cl1.lock().unwrap().push("legacy_p10".into()); - Ok((req, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, annotated, + )) }), ) .unwrap(); @@ -849,7 +871,9 @@ async fn test_unified_chain_priority_order() { false, Arc::new(move |_name, req, annotated| { cl2.lock().unwrap().push("annotated_p5".into()); - Ok((req, annotated)) + 
Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, annotated, + )) }), ) .unwrap(); @@ -897,7 +921,9 @@ async fn test_no_codec_annotated_intercept_receives_none() { false, Arc::new(move |_name, req, annotated| { *ca.lock().unwrap() = Some(annotated.clone()); - Ok((req, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, annotated, + )) }), ) .unwrap(); diff --git a/crates/core/tests/unit/dynamic_worker_tests.rs b/crates/core/tests/unit/dynamic_worker_tests.rs index aa77c0433..ea06cc67c 100644 --- a/crates/core/tests/unit/dynamic_worker_tests.rs +++ b/crates/core/tests/unit/dynamic_worker_tests.rs @@ -335,29 +335,31 @@ async fn callback_helpers_cover_worker_response_edges() { }, "llm_intercept_invalid_request" => InvokeResponse { result: Some(InvokeResult::LlmRequest(LlmRequestInterceptResult { - request: Some(JsonEnvelope { - schema: LLM_REQUEST_SCHEMA.into(), - json: b"null".to_vec(), + outcome: Some(JsonEnvelope { + schema: "nemo.relay.LlmRequestInterceptOutcome@1".into(), + json: br#"{"request":null}"#.to_vec(), }), - annotated_request: None, - has_annotated_request: false, })), }, "llm_intercept_missing_annotated" => InvokeResponse { result: Some(InvokeResult::LlmRequest(LlmRequestInterceptResult { - request: Some(valid_llm_request_envelope()), - annotated_request: None, - has_annotated_request: true, + outcome: Some(JsonEnvelope { + schema: "nemo.relay.LegacyLlmRequestInterceptResult@1".into(), + json: br#"{}"#.to_vec(), + }), })), }, "llm_intercept_invalid_annotated" => InvokeResponse { result: Some(InvokeResult::LlmRequest(LlmRequestInterceptResult { - request: Some(valid_llm_request_envelope()), - annotated_request: Some(JsonEnvelope { - schema: ANNOTATED_LLM_REQUEST_SCHEMA.into(), - json: b"null".to_vec(), + outcome: Some(JsonEnvelope { + schema: "nemo.relay.LlmRequestInterceptOutcome@1".into(), + json: serde_json::to_vec(&json!({ + "request": valid_llm_request(), + "annotated_request": 3, + "pending_marks": [], + 
})) + .unwrap(), }), - has_annotated_request: true, })), }, "llm_intercept_error" => InvokeResponse { @@ -405,7 +407,11 @@ async fn callback_helpers_cover_worker_response_edges() { None, ) .expect_err("invalid LLM intercept request should fail"); - assert!(error.to_string().contains("invalid LLM request")); + assert!( + error + .to_string() + .contains("invalid LLM request intercept outcome") + ); let error = callback .invoke_llm_request_intercept( @@ -414,11 +420,11 @@ async fn callback_helpers_cover_worker_response_edges() { valid_llm_request(), None, ) - .expect_err("missing annotated request should fail when flagged present"); + .expect_err("legacy outcome schema should fail"); assert!( error .to_string() - .contains("llm request intercept annotated request is missing") + .contains("unsupported LLM request intercept outcome schema") ); let error = callback @@ -429,7 +435,11 @@ async fn callback_helpers_cover_worker_response_edges() { None, ) .expect_err("invalid annotated request should fail"); - assert!(error.to_string().contains("invalid annotated LLM request")); + assert!( + error + .to_string() + .contains("invalid LLM request intercept outcome") + ); let error = callback .invoke_llm_request_intercept("llm_intercept_error", "model", valid_llm_request(), None) @@ -968,10 +978,6 @@ fn valid_llm_request() -> LlmRequest { } } -fn valid_llm_request_envelope() -> JsonEnvelope { - json_envelope(LLM_REQUEST_SCHEMA, &valid_llm_request()).expect("llm request envelope") -} - async fn fake_callback_service( invoke: impl Fn(InvokeRequest) -> InvokeResponse + Send + Sync + 'static, ) -> (WorkerPluginCallback, oneshot::Sender<()>) { diff --git a/crates/core/tests/unit/plugin_tests.rs b/crates/core/tests/unit/plugin_tests.rs index 0e127cbcb..8eedbbbd8 100644 --- a/crates/core/tests/unit/plugin_tests.rs +++ b/crates/core/tests/unit/plugin_tests.rs @@ -10,7 +10,7 @@ use std::sync::{Mutex, OnceLock}; use serde_json::json; -use crate::api::llm::LlmRequest; +use 
crate::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; use crate::api::llm::{llm_conditional_execution, llm_request_intercepts}; use crate::api::runtime::NemoRelayContextState; use crate::api::runtime::global_context; @@ -97,7 +97,7 @@ impl Plugin for TestPlugin { false, Arc::new(|_name, mut request, annotated| { request.headers.insert("x-plugin".into(), json!(true)); - Ok((request, annotated)) + Ok(LlmRequestInterceptOutcome::new(request, annotated)) }), ) }) @@ -475,7 +475,7 @@ fn test_plugin_registration_context_registers_and_rolls_back() { }, ) .unwrap(); - assert_eq!(request.headers.get("x-plugin"), Some(&json!(true))); + assert_eq!(request.request.headers.get("x-plugin"), Some(&json!(true))); let mut registrations = ctx.into_registrations(); rollback_registrations(&mut registrations); @@ -488,7 +488,7 @@ fn test_plugin_registration_context_registers_and_rolls_back() { }, ) .unwrap(); - assert_eq!(request.headers.get("x-plugin"), None); + assert_eq!(request.request.headers.get("x-plugin"), None); reset_global(); } @@ -519,7 +519,7 @@ fn test_initialize_plugins_registers_and_clears_components() { }, ) .unwrap(); - assert_eq!(request.headers.get("x-plugin"), Some(&json!(true))); + assert_eq!(request.request.headers.get("x-plugin"), Some(&json!(true))); clear_plugin_configuration().unwrap(); let request = llm_request_intercepts( @@ -530,7 +530,7 @@ fn test_initialize_plugins_registers_and_clears_components() { }, ) .unwrap(); - assert_eq!(request.headers.get("x-plugin"), None); + assert_eq!(request.request.headers.get("x-plugin"), None); reset_global(); } @@ -763,7 +763,9 @@ fn test_plugin_registration_context_covers_all_registration_helpers() { "llm-request", 1, false, - Arc::new(|_name, request, annotated| Ok((request, annotated))), + Arc::new(|_name, request, annotated| { + Ok(LlmRequestInterceptOutcome::new(request, annotated)) + }), ) .unwrap(); ctx.register_llm_execution_intercept( @@ -1063,7 +1065,9 @@ fn 
test_plugin_registration_context_maps_duplicate_registration_errors() { "llm-request", 1, false, - Arc::new(|_name, request, annotated| Ok((request, annotated))), + Arc::new(|_name, request, annotated| { + Ok(LlmRequestInterceptOutcome::new(request, annotated)) + }), ) .unwrap(); expect_registration_failed( @@ -1071,7 +1075,9 @@ fn test_plugin_registration_context_maps_duplicate_registration_errors() { "llm-request", 1, false, - Arc::new(|_name, request, annotated| Ok((request, annotated))), + Arc::new(|_name, request, annotated| { + Ok(LlmRequestInterceptOutcome::new(request, annotated)) + }), ), "llm request intercept:", ); @@ -1250,7 +1256,9 @@ fn test_plugin_registration_context_maps_deregistration_errors() { "llm-request", 1, false, - Arc::new(|_name, request, annotated| Ok((request, annotated))), + Arc::new(|_name, request, annotated| { + Ok(LlmRequestInterceptOutcome::new(request, annotated)) + }), ) .unwrap(); ctx.register_tool_sanitize_request_guardrail( diff --git a/crates/core/tests/unit/shared_tests.rs b/crates/core/tests/unit/shared_tests.rs index c45ca5a5e..851ea4eee 100644 --- a/crates/core/tests/unit/shared_tests.rs +++ b/crates/core/tests/unit/shared_tests.rs @@ -8,7 +8,7 @@ use std::sync::Arc; use serde_json::{Map, json}; -use crate::api::llm::LlmRequest; +use crate::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; use crate::api::registry::{deregister_llm_request_intercept, register_llm_request_intercept}; use crate::api::runtime::NemoRelayContextState; use crate::api::runtime::global_context; @@ -158,7 +158,7 @@ fn test_run_request_intercepts_with_codec_none_and_codec_paths() { request.headers.insert("x-no-codec".into(), json!(true)); let mut annotated = SharedTestCodec.decode(&request)?; annotated.model = Some("interceptor-model".into()); - Ok((request, Some(annotated))) + Ok(LlmRequestInterceptOutcome::new(request, Some(annotated))) }), ) .unwrap(); @@ -194,7 +194,7 @@ fn test_run_request_intercepts_with_codec_none_and_codec_paths() { let 
mut annotated = annotated.expect("codec should provide annotated request"); annotated.model = Some("intercepted-model".into()); request.headers.insert("x-codec".into(), json!(true)); - Ok((request, Some(annotated))) + Ok(LlmRequestInterceptOutcome::new(request, Some(annotated))) }), ) .unwrap(); diff --git a/crates/ffi/nemo_relay.h b/crates/ffi/nemo_relay.h index f07b3b2dc..b8bbe8269 100644 --- a/crates/ffi/nemo_relay.h +++ b/crates/ffi/nemo_relay.h @@ -254,15 +254,14 @@ typedef char *(*NemoRelayLlmConditionalCb)(void *user_data, const struct FfiLLMR * C callback type for LLM request intercepts with unified annotated-aware * signature. Receives the intercept name, the opaque `FfiLLMRequest`, and * optionally the annotated request as a JSON C string (null if no Codec - * resolved). Writes transformed outputs to `out_request` and - * `out_annotated_json`. Returns `NemoRelayStatus`. + * resolved). Writes one owned canonical outcome JSON string to + * `out_outcome_json`. Returns `NemoRelayStatus`. */ typedef NemoRelayStatus (*NemoRelayLlmRequestInterceptCb)(void *user_data, const char *name, const struct FfiLLMRequest *request, const char *annotated_json, - struct FfiLLMRequest **out_request, - char **out_annotated_json); + char **out_outcome_json); /** * Runtime-provided "next" callback for LLM execution middleware chain. @@ -399,6 +398,25 @@ NemoRelayStatus nemo_relay_llm_request_intercepts(const char *name, const char *native_json, char **out); +/** + * Allocate canonical JSON for a C LLM request-intercept callback result. + * + * `annotated_json` may be null. `pending_marks_json` may be null, in which + * case an empty list is serialized. The caller owns the returned string and + * must release it with [`nemo_relay_string_free`]. + * + * # Safety + * + * `request` must point to a live [`FfiLLMRequest`], optional JSON inputs must + * be valid null-terminated strings when non-null, and `out_outcome_json` must + * be writable. 
The caller must free a successful output with + * [`nemo_relay_string_free`]. + */ +NemoRelayStatus nemo_relay_llm_request_intercept_outcome_json_new(const struct FfiLLMRequest *request, + const char *annotated_json, + const char *pending_marks_json, + char **out_outcome_json); + /** * Run the registered LLM conditional execution guardrail chain. * diff --git a/crates/ffi/src/api/mod.rs b/crates/ffi/src/api/mod.rs index a20e27ccf..c6d60889a 100644 --- a/crates/ffi/src/api/mod.rs +++ b/crates/ffi/src/api/mod.rs @@ -35,14 +35,15 @@ use crate::error::{ status_from_plugin_error, }; use crate::types::{ - FfiAtifExporter, FfiAtofExporter, FfiCodecHandle, FfiLLMHandle, FfiOpenInferenceSubscriber, - FfiOpenTelemetrySubscriber, FfiPluginContext, FfiScopeHandle, FfiScopeStack, - FfiThreadScopeStackBinding, FfiToolHandle, NemoRelayScopeType, + FfiAtifExporter, FfiAtofExporter, FfiCodecHandle, FfiLLMHandle, FfiLLMRequest, + FfiOpenInferenceSubscriber, FfiOpenTelemetrySubscriber, FfiPluginContext, FfiScopeHandle, + FfiScopeStack, FfiThreadScopeStackBinding, FfiToolHandle, NemoRelayScopeType, }; pub use crate::types::{nemo_relay_openinference_subscriber_free, nemo_relay_otel_subscriber_free}; use libc::c_char; +use nemo_relay::api::event::PendingMarkSpec; use nemo_relay::api::llm as core_llm_api; -use nemo_relay::api::llm::{LlmAttributes, LlmRequest}; +use nemo_relay::api::llm::{LlmAttributes, LlmRequest, LlmRequestInterceptOutcome}; use nemo_relay::api::registry as core_registry_api; use nemo_relay::api::runtime::{LlmExecutionNextFn, LlmStreamExecutionNextFn, ToolExecutionNextFn}; use nemo_relay::api::runtime::{ @@ -239,6 +240,78 @@ pub unsafe extern "C" fn nemo_relay_llm_request_intercepts( } } +/// Allocate canonical JSON for a C LLM request-intercept callback result. +/// +/// `annotated_json` may be null. `pending_marks_json` may be null, in which +/// case an empty list is serialized. 
The caller owns the returned string and +/// must release it with [`nemo_relay_string_free`]. +/// +/// # Safety +/// +/// `request` must point to a live [`FfiLLMRequest`], optional JSON inputs must +/// be valid null-terminated strings when non-null, and `out_outcome_json` must +/// be writable. The caller must free a successful output with +/// [`nemo_relay_string_free`]. +#[unsafe(no_mangle)] +pub unsafe extern "C" fn nemo_relay_llm_request_intercept_outcome_json_new( + request: *const FfiLLMRequest, + annotated_json: *const c_char, + pending_marks_json: *const c_char, + out_outcome_json: *mut *mut c_char, +) -> NemoRelayStatus { + clear_last_error(); + if request.is_null() || out_outcome_json.is_null() { + set_last_error("request and out_outcome_json must be non-null"); + return NemoRelayStatus::NullPointer; + } + unsafe { *out_outcome_json = std::ptr::null_mut() }; + let annotated_request = if annotated_json.is_null() { + None + } else { + let value = match c_str_to_json(annotated_json) { + Some(value) => value, + None => return NemoRelayStatus::InvalidJson, + }; + match serde_json::from_value(value) { + Ok(value) => Some(value), + Err(error) => { + set_last_error(&format!("invalid annotated request JSON: {error}")); + return NemoRelayStatus::InvalidJson; + } + } + }; + let pending_marks = if pending_marks_json.is_null() { + Vec::new() + } else { + let value = match c_str_to_json(pending_marks_json) { + Some(value) => value, + None => return NemoRelayStatus::InvalidJson, + }; + match serde_json::from_value::>(value) { + Ok(value) => value, + Err(error) => { + set_last_error(&format!("invalid pending marks JSON: {error}")); + return NemoRelayStatus::InvalidJson; + } + } + }; + let outcome = LlmRequestInterceptOutcome { + request: unsafe { &*request }.0.clone(), + annotated_request, + pending_marks, + }; + match serde_json::to_value(outcome) { + Ok(value) => { + unsafe { *out_outcome_json = json_to_c_string(&value) }; + NemoRelayStatus::Ok + } + Err(error) => { 
+ set_last_error(&format!("failed to serialize intercept outcome: {error}")); + NemoRelayStatus::Internal + } + } +} + /// Run the registered LLM conditional execution guardrail chain. /// /// Returns `NemoRelayStatus::Ok` if all guardrails pass, or diff --git a/crates/ffi/src/callable.rs b/crates/ffi/src/callable.rs index ea3bf4c28..3e1537311 100644 --- a/crates/ffi/src/callable.rs +++ b/crates/ffi/src/callable.rs @@ -31,7 +31,7 @@ use serde_json::Value as Json; use tokio_stream::{Stream, StreamExt}; use nemo_relay::api::event::Event; -use nemo_relay::api::llm::LlmRequest; +use nemo_relay::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; use nemo_relay::codec::request::AnnotatedLlmRequest as AnnotatedLLMRequest; use nemo_relay::codec::traits::LlmCodec; use nemo_relay::error::{FlowError, Result}; @@ -171,15 +171,14 @@ pub type NemoRelayCodecEncodeFn = Option< /// C callback type for LLM request intercepts with unified annotated-aware /// signature. Receives the intercept name, the opaque `FfiLLMRequest`, and /// optionally the annotated request as a JSON C string (null if no Codec -/// resolved). Writes transformed outputs to `out_request` and -/// `out_annotated_json`. Returns `NemoRelayStatus`. +/// resolved). Writes one owned canonical outcome JSON string to +/// `out_outcome_json`. Returns `NemoRelayStatus`. pub type NemoRelayLlmRequestInterceptCb = unsafe extern "C" fn( user_data: *mut libc::c_void, name: *const c_char, request: *const FfiLLMRequest, annotated_json: *const c_char, - out_request: *mut *mut FfiLLMRequest, - out_annotated_json: *mut *mut c_char, + out_outcome_json: *mut *mut c_char, ) -> NemoRelayStatus; /// Callback for collecting intercepted stream chunks. Invoked with each chunk @@ -559,8 +558,8 @@ pub fn wrap_json_fn( /// Wrap a C LLM request intercept callback (annotated-aware) into a Rust /// `LlmRequestInterceptFn` closure. The callback receives the intercept name, -/// the opaque `FfiLLMRequest`, and the annotated JSON (or null). 
It writes -/// the transformed request and annotated JSON to output pointers. +/// the opaque `FfiLLMRequest`, and the annotated JSON (or null). It writes one +/// owned canonical outcome JSON string. pub fn wrap_llm_request_intercept_fn( cb: NemoRelayLlmRequestInterceptCb, user_data: *mut libc::c_void, @@ -587,9 +586,7 @@ pub fn wrap_llm_request_intercept_fn( std::ptr::null() }; - // Initialize output pointers - let mut out_request: *mut FfiLLMRequest = std::ptr::null_mut(); - let mut out_annotated: *mut c_char = std::ptr::null_mut(); + let mut out_outcome: *mut c_char = std::ptr::null_mut(); let status = unsafe { cb( @@ -597,8 +594,7 @@ pub fn wrap_llm_request_intercept_fn( c_name.as_ptr(), ffi_req, annotated_ptr, - &mut out_request, - &mut out_annotated, + &mut out_outcome, ) }; @@ -606,32 +602,29 @@ pub fn wrap_llm_request_intercept_fn( unsafe { drop(Box::from_raw(ffi_req)) }; if status != NemoRelayStatus::Ok { + unsafe { nemo_relay_string_free_internal(out_outcome) }; let message = last_error_message() .unwrap_or_else(|| "request intercept callback failed".to_string()); return Err(FlowError::Internal(message)); } - // Read output request - let new_request = if out_request.is_null() { + if out_outcome.is_null() { return Err(FlowError::Internal( - "request intercept returned null out_request".to_string(), + "request intercept returned null out_outcome_json".to_string(), )); - } else { - let boxed = unsafe { Box::from_raw(out_request) }; - boxed.0 - }; - - // Read output annotated - let new_annotated = if out_annotated.is_null() { - None - } else { - let s = unsafe { CStr::from_ptr(out_annotated) }.to_string_lossy(); - let parsed: Option = serde_json::from_str(&s).ok(); - unsafe { nemo_relay_string_free_internal(out_annotated) }; - parsed - }; - - Ok((new_request, new_annotated)) + } + let outcome = unsafe { CStr::from_ptr(out_outcome) } + .to_str() + .map_err(|error| FlowError::Internal(format!("invalid outcome UTF-8: {error}"))) + .and_then(|json| { + 
serde_json::from_str::(json).map_err(|error| { + FlowError::Internal(format!( + "invalid LLM request intercept outcome JSON: {error}" + )) + }) + }); + unsafe { nemo_relay_string_free_internal(out_outcome) }; + outcome }, ) } diff --git a/crates/ffi/tests/integration/api_tests.rs b/crates/ffi/tests/integration/api_tests.rs index 5f38fbcbe..d6668df4f 100644 --- a/crates/ffi/tests/integration/api_tests.rs +++ b/crates/ffi/tests/integration/api_tests.rs @@ -377,10 +377,15 @@ unsafe extern "C" fn llm_request_intercept_cb( _name: *const c_char, request: *const FfiLLMRequest, _annotated_json: *const c_char, - out_request: *mut *mut FfiLLMRequest, - _out_annotated_json: *mut *mut c_char, + out_outcome_json: *mut *mut c_char, ) -> NemoRelayStatus { - unsafe { *out_request = llm_request_cb(ptr::null_mut(), request) }; + let transformed = unsafe { Box::from_raw(llm_request_cb(ptr::null_mut(), request)) }; + let outcome = json!({ + "request": transformed.0, + "annotated_request": null, + "pending_marks": [], + }); + unsafe { *out_outcome_json = CString::new(outcome.to_string()).unwrap().into_raw() }; NemoRelayStatus::Ok } diff --git a/crates/ffi/tests/integration/callable_extra_tests.rs b/crates/ffi/tests/integration/callable_extra_tests.rs index 0401abe2b..d82c158f0 100644 --- a/crates/ffi/tests/integration/callable_extra_tests.rs +++ b/crates/ffi/tests/integration/callable_extra_tests.rs @@ -40,8 +40,7 @@ unsafe extern "C" fn llm_request_intercept_status_error_cb( _name: *const c_char, _request: *const FfiLLMRequest, _annotated_json: *const c_char, - _out_request: *mut *mut FfiLLMRequest, - _out_annotated_json: *mut *mut c_char, + _out_outcome_json: *mut *mut c_char, ) -> NemoRelayStatus { NemoRelayStatus::Internal } @@ -51,8 +50,7 @@ unsafe extern "C" fn llm_request_intercept_null_out_request_cb( _name: *const c_char, _request: *const FfiLLMRequest, _annotated_json: *const c_char, - _out_request: *mut *mut FfiLLMRequest, - _out_annotated_json: *mut *mut c_char, + 
_out_outcome_json: *mut *mut c_char, ) -> NemoRelayStatus { NemoRelayStatus::Ok } @@ -60,15 +58,11 @@ unsafe extern "C" fn llm_request_intercept_null_out_request_cb( unsafe extern "C" fn llm_request_intercept_invalid_annotated_cb( _user_data: *mut libc::c_void, _name: *const c_char, - request: *const FfiLLMRequest, + _request: *const FfiLLMRequest, _annotated_json: *const c_char, - out_request: *mut *mut FfiLLMRequest, - out_annotated_json: *mut *mut c_char, + out_outcome_json: *mut *mut c_char, ) -> NemoRelayStatus { - unsafe { - *out_request = Box::into_raw(Box::new(FfiLLMRequest((&*request).0.clone()))); - *out_annotated_json = CString::new("not-json").unwrap().into_raw(); - } + unsafe { *out_outcome_json = CString::new("not-json").unwrap().into_raw() }; NemoRelayStatus::Ok } @@ -217,16 +211,18 @@ fn test_callable_extra_request_intercept_and_codec_paths() { None, ); let err = intercept_null("llm", request.clone(), None).unwrap_err(); - assert!(err.to_string().contains("null out_request")); + assert!(err.to_string().contains("null out_outcome_json")); let intercept_invalid_annotated = wrap_llm_request_intercept_fn( llm_request_intercept_invalid_annotated_cb, ptr::null_mut(), None, ); - let (_request_out, annotated_out) = intercept_invalid_annotated("llm", request.clone(), None) - .expect("invalid annotated JSON should be tolerated as None"); - assert!(annotated_out.is_none()); + let err = intercept_invalid_annotated("llm", request.clone(), None).unwrap_err(); + assert!( + err.to_string() + .contains("invalid LLM request intercept outcome JSON") + ); let sanitize = wrap_llm_sanitize_request_fn(llm_request_passthrough_cb, ptr::null_mut(), None); let sanitized = sanitize(request.clone()); diff --git a/crates/ffi/tests/unit/api/core_tests.rs b/crates/ffi/tests/unit/api/core_tests.rs index 08ce98c2c..735ca68ca 100644 --- a/crates/ffi/tests/unit/api/core_tests.rs +++ b/crates/ffi/tests/unit/api/core_tests.rs @@ -5,6 +5,60 @@ use super::*; +#[test] +fn 
test_ffi_llm_request_intercept_outcome_json_allocation_and_validation() { + let headers = cstring(r#"{"x-test":true}"#); + let content = cstring(r#"{"model":"test"}"#); + let request = unsafe { nemo_relay_llm_request_new(headers.as_ptr(), content.as_ptr()) }; + assert!(!request.is_null()); + + let marks = cstring(r#"[{"name":"first"},{"name":"second","data":{"order":2}}]"#); + let mut outcome_json = ptr::null_mut(); + assert_eq!( + unsafe { + api::nemo_relay_llm_request_intercept_outcome_json_new( + request, + ptr::null(), + marks.as_ptr(), + &mut outcome_json, + ) + }, + NemoRelayStatus::Ok + ); + let outcome = unsafe { returned_json(outcome_json) }; + assert_eq!(outcome["request"]["headers"]["x-test"], true); + assert_eq!(outcome["annotated_request"], Json::Null); + assert_eq!(outcome["pending_marks"][0]["name"], "first"); + assert_eq!(outcome["pending_marks"][1]["data"]["order"], 2); + + let malformed_marks = cstring(r#"{"name":"not-an-array"}"#); + assert_eq!( + unsafe { + api::nemo_relay_llm_request_intercept_outcome_json_new( + request, + ptr::null(), + malformed_marks.as_ptr(), + &mut outcome_json, + ) + }, + NemoRelayStatus::InvalidJson + ); + assert!(outcome_json.is_null()); + assert_eq!( + unsafe { + api::nemo_relay_llm_request_intercept_outcome_json_new( + ptr::null(), + ptr::null(), + ptr::null(), + &mut outcome_json, + ) + }, + NemoRelayStatus::NullPointer + ); + + unsafe { nemo_relay_llm_request_free(request) }; +} + #[test] fn test_ffi_plugin_config_validate_initialize_and_clear() { let _guard = TEST_MUTEX.lock().unwrap(); diff --git a/crates/ffi/tests/unit/api/registry_tests.rs b/crates/ffi/tests/unit/api/registry_tests.rs index 378358057..72afba048 100644 --- a/crates/ffi/tests/unit/api/registry_tests.rs +++ b/crates/ffi/tests/unit/api/registry_tests.rs @@ -381,7 +381,7 @@ fn test_ffi_helper_rejection_and_null_name_paths() { NemoRelayStatus::Ok ); let llm_json = returned_json(llm_out); - assert_eq!(llm_json["content"]["model"], json!("ffi-model")); 
+ assert_eq!(llm_json["request"]["content"]["model"], json!("ffi-model")); assert_eq!( nemo_relay_llm_request_intercepts(llm_name.as_ptr(), request.as_ptr(), ptr::null_mut()), @@ -1536,7 +1536,10 @@ fn test_ffi_llm_execute_stream_and_atif_exporter() { NemoRelayStatus::Ok ); let helper_json = returned_json(helper_out); - assert_eq!(helper_json["content"]["intercepted"], json!(true)); + assert_eq!( + helper_json["request"]["content"]["intercepted"], + json!(true) + ); assert_eq!( nemo_relay_llm_conditional_execution(request.as_ptr()), diff --git a/crates/ffi/tests/unit/api_tests.rs b/crates/ffi/tests/unit/api_tests.rs index 8f4e728a0..78f37919d 100644 --- a/crates/ffi/tests/unit/api_tests.rs +++ b/crates/ffi/tests/unit/api_tests.rs @@ -374,10 +374,15 @@ unsafe extern "C" fn llm_request_intercept_cb( _name: *const c_char, request: *const FfiLLMRequest, _annotated_json: *const c_char, - out_request: *mut *mut FfiLLMRequest, - _out_annotated_json: *mut *mut c_char, + out_outcome_json: *mut *mut c_char, ) -> NemoRelayStatus { - unsafe { *out_request = llm_request_cb(ptr::null_mut(), request) }; + let transformed = unsafe { Box::from_raw(llm_request_cb(ptr::null_mut(), request)) }; + let outcome = json!({ + "request": transformed.0, + "annotated_request": null, + "pending_marks": [], + }); + unsafe { *out_outcome_json = CString::new(outcome.to_string()).unwrap().into_raw() }; NemoRelayStatus::Ok } @@ -386,8 +391,7 @@ unsafe extern "C" fn llm_request_intercept_fail_cb( _name: *const c_char, _request: *const FfiLLMRequest, _annotated_json: *const c_char, - _out_request: *mut *mut FfiLLMRequest, - _out_annotated_json: *mut *mut c_char, + _out_outcome_json: *mut *mut c_char, ) -> NemoRelayStatus { crate::error::set_last_error("llm request intercept callback failed"); NemoRelayStatus::Internal diff --git a/crates/ffi/tests/unit/callable_tests.rs b/crates/ffi/tests/unit/callable_tests.rs index 3089f5e83..b8fc9a30e 100644 --- a/crates/ffi/tests/unit/callable_tests.rs +++ 
b/crates/ffi/tests/unit/callable_tests.rs @@ -103,20 +103,24 @@ unsafe extern "C" fn llm_request_intercept_cb( _name: *const c_char, request: *const FfiLLMRequest, annotated_json: *const c_char, - out_request: *mut *mut FfiLLMRequest, - out_annotated_json: *mut *mut c_char, + out_outcome_json: *mut *mut c_char, ) -> NemoRelayStatus { let mut req = unsafe { (&*request).0.clone() }; req.content["intercepted"] = json!(true); - unsafe { *out_request = Box::into_raw(Box::new(FfiLLMRequest(req))) }; - if annotated_json.is_null() { - unsafe { *out_annotated_json = std::ptr::null_mut() }; + let annotated = if annotated_json.is_null() { + Json::Null } else { let s = unsafe { CStr::from_ptr(annotated_json) } .to_string_lossy() .into_owned(); - unsafe { *out_annotated_json = CString::new(s).unwrap().into_raw() }; - } + serde_json::from_str(&s).unwrap() + }; + let outcome = json!({ + "request": req, + "annotated_request": annotated, + "pending_marks": [], + }); + unsafe { *out_outcome_json = CString::new(outcome.to_string()).unwrap().into_raw() }; NemoRelayStatus::Ok } @@ -283,8 +287,8 @@ fn test_wrap_tool_exec_and_intercept_callbacks() { fn test_wrap_llm_request_response_and_conditional_callbacks() { let request_intercept = wrap_llm_request_intercept_fn(llm_request_intercept_cb, std::ptr::null_mut(), None); - let (intercepted, _annotated) = request_intercept("llm", make_request(), None).unwrap(); - assert_eq!(intercepted.content["intercepted"], json!(true)); + let outcome = request_intercept("llm", make_request(), None).unwrap(); + assert_eq!(outcome.request.content["intercepted"], json!(true)); let sanitize_request = wrap_llm_sanitize_request_fn(llm_request_null_cb, std::ptr::null_mut(), None); @@ -338,10 +342,11 @@ fn test_wrap_llm_request_intercept_with_annotated_input() { stream: None, extra: serde_json::Map::from_iter([("annotated".into(), json!(true))]), }; - let (intercepted, annotated_out) = - request_intercept("llm", make_request(), Some(annotated)).unwrap(); - 
assert_eq!(intercepted.content["intercepted"], json!(true)); - let annotated_out = annotated_out.expect("expected annotated request output"); + let outcome = request_intercept("llm", make_request(), Some(annotated)).unwrap(); + assert_eq!(outcome.request.content["intercepted"], json!(true)); + let annotated_out = outcome + .annotated_request + .expect("expected annotated request output"); assert_eq!(annotated_out.model.as_deref(), Some("test-model")); assert_eq!(annotated_out.extra.get("annotated"), Some(&json!(true))); } diff --git a/crates/node/plugin.d.ts b/crates/node/plugin.d.ts index 4fbb6be8a..c0368b3bb 100644 --- a/crates/node/plugin.d.ts +++ b/crates/node/plugin.d.ts @@ -84,7 +84,14 @@ export interface PluginContext { breakChain: boolean, callback: (args: { name: string; request: Json; annotated: Json | null }) => { request: Json; - annotated: Json | null; + annotated?: Json | null; + pendingMarks?: Array<{ + name: string; + category?: string | null; + category_profile?: Json; + data?: Json; + metadata?: Json; + }>; }, ): void; /** Register an LLM execution intercept for this component. 
*/ diff --git a/crates/node/src/api/mod.rs b/crates/node/src/api/mod.rs index 7f8ccdd14..0300714b9 100644 --- a/crates/node/src/api/mod.rs +++ b/crates/node/src/api/mod.rs @@ -2259,6 +2259,9 @@ pub fn register_llm_request_intercept( name: String, priority: i32, break_chain: bool, + #[napi( + ts_arg_type = "(args: { name: string; request: Json; annotated: Json | null }) => { request: Json; annotated?: Json | null; pendingMarks?: Array<{ name: string; category?: string | null; category_profile?: Json; data?: Json; metadata?: Json }> }" + )] callable: ThreadsafeFunction, ) -> Result<()> { core_registry_api::register_llm_request_intercept( @@ -2729,6 +2732,9 @@ pub fn scope_register_llm_request_intercept( name: String, priority: i32, break_chain: bool, + #[napi( + ts_arg_type = "(args: { name: string; request: Json; annotated: Json | null }) => { request: Json; annotated?: Json | null; pendingMarks?: Array<{ name: string; category?: string | null; category_profile?: Json; data?: Json; metadata?: Json }> }" + )] callable: ThreadsafeFunction, ) -> Result<()> { let uuid = uuid::Uuid::parse_str(&scope_uuid) @@ -2942,7 +2948,9 @@ pub fn tool_conditional_execution(env: Env, name: String, args: Json) -> Result< /// Run the registered LLM request intercept chain on the given request. /// The `request` should be a JSON object with `headers` and `content` fields matching /// the `LlmRequest` schema. Returns the transformed request as JSON. 
-#[napi(ts_return_type = "Promise")] +#[napi( + ts_return_type = "Promise<{ request: Json; annotated: Json | null; pendingMarks: Array<{ name: string; category?: string | null; category_profile?: Json; data?: Json; metadata?: Json }> }>" +)] pub fn llm_request_intercepts(env: Env, name: String, request: Json) -> Result { let llm_request: LlmRequest = serde_json::from_value(request) .map_err(|e| napi::Error::from_reason(format!("invalid LlmRequest: {e}")))?; @@ -2952,7 +2960,13 @@ pub fn llm_request_intercepts(env: Env, name: String, request: Json) -> Result, ) -> LlmRequestInterceptFn { @@ -217,7 +218,7 @@ pub fn wrap_js_llm_request_intercept_fn( move |name: &str, request: LlmRequest, annotated: Option| - -> Result<(LlmRequest, Option)> { + -> Result { let func = func.clone(); let req_json = serde_json::to_value(&request).unwrap_or(Json::Null); let annotated_json = annotated @@ -245,32 +246,23 @@ pub fn wrap_js_llm_request_intercept_fn( } let result = recv_json_result(rx, "JS LLM request intercept callback failed")?; - // Validate expected shape: { "request": {...}, "annotated": ... 
} - let obj = result.as_object().ok_or_else(|| { - FlowError::Internal( - "JS LLM request intercept: expected object with 'request' and 'annotated' fields".to_string(), - ) - })?; - - let new_request: LlmRequest = serde_json::from_value( - obj.get("request").cloned().unwrap_or(Json::Null), - ) - .map_err(|e| { - FlowError::Internal(format!( - "JS LLM request intercept: failed to deserialize request: {e}" - )) + #[derive(Deserialize)] + #[serde(rename_all = "camelCase")] + struct JsOutcome { + request: LlmRequest, + #[serde(default)] + annotated: Option, + #[serde(default)] + pending_marks: Vec, + } + let outcome: JsOutcome = serde_json::from_value(result).map_err(|e| { + FlowError::Internal(format!("invalid JS LLM request intercept outcome: {e}")) })?; - - let new_annotated: Option = match obj.get("annotated") { - Some(Json::Null) | None => None, - Some(val) => Some(serde_json::from_value(val.clone()).map_err(|e| { - FlowError::Internal(format!( - "JS LLM request intercept: failed to deserialize annotated: {e}" - )) - })?), - }; - - Ok((new_request, new_annotated)) + Ok(LlmRequestInterceptOutcome { + request: outcome.request, + annotated_request: outcome.annotated, + pending_marks: outcome.pending_marks, + }) }, ) } diff --git a/crates/node/tests/llm_tests.mjs b/crates/node/tests/llm_tests.mjs index fdc9000ba..2f69cb1e7 100644 --- a/crates/node/tests/llm_tests.mjs +++ b/crates/node/tests/llm_tests.mjs @@ -592,7 +592,7 @@ describe('LLM intercepts', () => { null, null, ), - /expected object with 'request' and 'annotated' fields/i, + /invalid JS LLM request intercept outcome/i, ); } finally { deregisterLlmRequestIntercept('node_llm_req_bad'); @@ -826,11 +826,32 @@ describe('LLM intercepts', () => { return { request, annotated, + pendingMarks: [ + { name: 'request.first', data: { order: 1 } }, + { name: 'request.second', metadata: { source: 'node' } }, + ], }; }); const result = await llmRequestIntercepts('helper_llm', makeNative()); - 
assert.equal(result.content.helper, true); + assert.equal(result.request.content.helper, true); + assert.equal(result.annotated, null); + assert.deepEqual(result.pendingMarks, [ + { + name: 'request.first', + category: null, + category_profile: null, + data: { order: 1 }, + metadata: null, + }, + { + name: 'request.second', + category: null, + category_profile: null, + data: null, + metadata: { source: 'node' }, + }, + ]); deregisterLlmRequestIntercept('node_llm_req_helper'); }); diff --git a/crates/node/tests/scope_local_tests.mjs b/crates/node/tests/scope_local_tests.mjs index ebefb65d1..313e78e00 100644 --- a/crates/node/tests/scope_local_tests.mjs +++ b/crates/node/tests/scope_local_tests.mjs @@ -913,7 +913,7 @@ describe('Scope-local LLM intercepts', () => { null, null, ), - /expected object with 'request' and 'annotated' fields/i, + /invalid JS LLM request intercept outcome/i, ); } finally { scopeDeregisterLlmRequestIntercept(scope.uuid, 'sl_llm_req_bad'); diff --git a/crates/plugin/src/lib.rs b/crates/plugin/src/lib.rs index 007322cf7..70fb32ce2 100644 --- a/crates/plugin/src/lib.rs +++ b/crates/plugin/src/lib.rs @@ -30,6 +30,8 @@ use serde_json::Map; /// Native plugin ABI version supported by this crate. pub const NEMO_RELAY_NATIVE_ABI_VERSION: u32 = 1; +/// Final canonical LLM request-intercept outcome contract version. +pub const NEMO_RELAY_NATIVE_LLM_INTERCEPT_OUTCOME_CONTRACT_VERSION: u32 = 1; /// Status codes returned by stable native ABI functions. #[repr(i32)] @@ -253,8 +255,7 @@ pub type NemoRelayNativeLlmRequestInterceptCb = unsafe extern "C" fn( name: *const NemoRelayNativeString, request_json: *const NemoRelayNativeString, annotated_json: *const NemoRelayNativeString, - out_request_json: *mut *mut NemoRelayNativeString, - out_annotated_json: *mut *mut NemoRelayNativeString, + out_outcome_json: *mut *mut NemoRelayNativeString, ) -> NemoRelayStatus; /// Native LLM execution intercept callback. 
@@ -496,6 +497,8 @@ pub struct NemoRelayNativeHostApiV1 { cb: NemoRelayNativeWithScopeStackCb, user_data: *mut c_void, ) -> NemoRelayStatus, + /// Required canonical LLM request-intercept outcome contract version. + pub llm_request_intercept_outcome_contract_version: u32, } // The host API table is immutable after construction. Function pointers and @@ -520,6 +523,8 @@ pub struct NemoRelayNativePluginV1 { pub register: Option, /// Optional plugin-owned state destructor. pub drop: NemoRelayNativePluginDropFn, + /// Required canonical LLM request-intercept outcome contract version. + pub llm_request_intercept_outcome_contract_version: u32, } impl Default for NemoRelayNativePluginV1 { @@ -532,6 +537,8 @@ impl Default for NemoRelayNativePluginV1 { validate: None, register: None, drop: None, + llm_request_intercept_outcome_contract_version: + NEMO_RELAY_NATIVE_LLM_INTERCEPT_OUTCOME_CONTRACT_VERSION, } } } @@ -1469,11 +1476,7 @@ impl<'a> PluginContext<'a> { callback: F, ) -> Result<()> where - F: Fn( - &str, - LlmRequest, - Option, - ) -> Result<(LlmRequest, Option)> + F: Fn(&str, LlmRequest, Option) -> Result + Send + Sync + 'static, @@ -1492,39 +1495,6 @@ impl<'a> PluginContext<'a> { finish_typed_registration::(self.host, status, user_data, "llm request intercept") } - /// Registers a typed LLM request intercept that can schedule lifecycle marks. 
- pub fn register_llm_request_intercept_with_marks( - &mut self, - name: &str, - priority: i32, - break_chain: bool, - callback: F, - ) -> Result<()> - where - F: Fn(&str, LlmRequest, Option) -> Result - + Send - + Sync - + 'static, - { - let user_data = typed_callback_user_data(self.host, callback); - let status = unsafe { - self.register_llm_request_intercept_raw( - name, - priority, - break_chain, - typed_llm_request_intercept_with_marks_trampoline::, - user_data, - Some(drop_typed_callback::), - ) - }; - finish_typed_registration::( - self.host, - status, - user_data, - "LLM request intercept with marks", - ) - } - /// Registers a typed LLM execution intercept. pub fn register_llm_execution_intercept( &mut self, @@ -2156,89 +2126,7 @@ unsafe extern "C" fn typed_llm_request_intercept_trampoline( name: *const NemoRelayNativeString, request_json: *const NemoRelayNativeString, annotated_json: *const NemoRelayNativeString, - out_request_json: *mut *mut NemoRelayNativeString, - out_annotated_json: *mut *mut NemoRelayNativeString, -) -> NemoRelayStatus -where - F: Fn( - &str, - LlmRequest, - Option, - ) -> Result<(LlmRequest, Option)> - + Send - + Sync - + 'static, -{ - if user_data.is_null() || out_request_json.is_null() || out_annotated_json.is_null() { - return NemoRelayStatus::NullPointer; - } - unsafe { - *out_request_json = ptr::null_mut(); - *out_annotated_json = ptr::null_mut(); - } - let state = unsafe { &*(user_data as *const TypedCallback) }; - let result = catch_unwind(AssertUnwindSafe(|| { - let name = read_required_host_string(&state.host, name, "LLM name")?; - let request: LlmRequest = read_json_value(&state.host, request_json, "LLM request")?; - let annotated: Option = - read_optional_json_value(&state.host, annotated_json, "annotated LLM request")?; - match (state.callback)(&name, request, annotated) { - Ok((request, annotated)) => { - let Some(request) = HostString::from_json(&state.host, &request) else { - set_last_error(&state.host, "failed to 
allocate LLM request output"); - return Ok(NemoRelayStatus::Internal); - }; - let annotated = match annotated.as_ref() { - Some(annotated) => { - let Some(annotated) = HostString::from_json(&state.host, annotated) else { - set_last_error( - &state.host, - "failed to allocate annotated LLM request output", - ); - return Ok(NemoRelayStatus::Internal); - }; - Some(annotated) - } - None => None, - }; - unsafe { - *out_request_json = request.ptr; - *out_annotated_json = annotated - .as_ref() - .map(|annotated| annotated.ptr) - .unwrap_or(ptr::null_mut()); - } - std::mem::forget(request); - if let Some(annotated) = annotated { - std::mem::forget(annotated); - } - Ok(NemoRelayStatus::Ok) - } - Err(message) => Ok(callback_error(&state.host, message)), - } - })); - match result { - Ok(Ok(status)) => status, - Ok(Err(status)) => status, - Err(_) => callback_panic(&state.host, "LLM request intercept callback"), - } -} - -#[derive(Serialize)] -struct NativeLlmRequestInterceptOutcome<'a> { - #[serde(rename = "__nemo_relay_llm_intercept_outcome")] - marked_outcome: bool, - annotated_request: &'a Option, - pending_marks: &'a [PendingMarkSpec], -} - -unsafe extern "C" fn typed_llm_request_intercept_with_marks_trampoline( - user_data: *mut c_void, - name: *const NemoRelayNativeString, - request_json: *const NemoRelayNativeString, - annotated_json: *const NemoRelayNativeString, - out_request_json: *mut *mut NemoRelayNativeString, - out_annotated_json: *mut *mut NemoRelayNativeString, + out_outcome_json: *mut *mut NemoRelayNativeString, ) -> NemoRelayStatus where F: Fn(&str, LlmRequest, Option) -> Result @@ -2246,12 +2134,11 @@ where + Sync + 'static, { - if user_data.is_null() || out_request_json.is_null() || out_annotated_json.is_null() { + if user_data.is_null() || out_outcome_json.is_null() { return NemoRelayStatus::NullPointer; } unsafe { - *out_request_json = ptr::null_mut(); - *out_annotated_json = ptr::null_mut(); + *out_outcome_json = ptr::null_mut(); } let state = unsafe { 
&*(user_data as *const TypedCallback) }; let result = catch_unwind(AssertUnwindSafe(|| { @@ -2261,25 +2148,14 @@ where read_optional_json_value(&state.host, annotated_json, "annotated LLM request")?; match (state.callback)(&name, request, annotated) { Ok(outcome) => { - let Some(request) = HostString::from_json(&state.host, &outcome.request) else { - set_last_error(&state.host, "failed to allocate LLM request output"); - return Ok(NemoRelayStatus::Internal); - }; - let metadata = NativeLlmRequestInterceptOutcome { - marked_outcome: true, - annotated_request: &outcome.annotated_request, - pending_marks: &outcome.pending_marks, - }; - let Some(metadata) = HostString::from_json(&state.host, &metadata) else { - set_last_error(&state.host, "failed to allocate marked LLM outcome"); + let Some(outcome) = HostString::from_json(&state.host, &outcome) else { + set_last_error(&state.host, "failed to allocate LLM request outcome"); return Ok(NemoRelayStatus::Internal); }; unsafe { - *out_request_json = request.ptr; - *out_annotated_json = metadata.ptr; + *out_outcome_json = outcome.ptr; } - std::mem::forget(request); - std::mem::forget(metadata); + std::mem::forget(outcome); Ok(NemoRelayStatus::Ok) } Err(message) => Ok(callback_error(&state.host, message)), @@ -2288,7 +2164,7 @@ where match result { Ok(Ok(status)) => status, Ok(Err(status)) => status, - Err(_) => callback_panic(&state.host, "LLM request intercept with marks callback"), + Err(_) => callback_panic(&state.host, "LLM request intercept callback"), } } @@ -2779,6 +2655,11 @@ where if host_ref.struct_size < std::mem::size_of::() { return NemoRelayStatus::InvalidArg; } + if host_ref.llm_request_intercept_outcome_contract_version + != NEMO_RELAY_NATIVE_LLM_INTERCEPT_OUTCOME_CONTRACT_VERSION + { + return NemoRelayStatus::InvalidArg; + } export_plugin_checked(host_ref, out, constructor()) } @@ -2813,6 +2694,8 @@ fn export_plugin_checked( validate: Some(validate_trampoline::
<P>
), register: Some(register_trampoline::<P>
), drop: Some(drop_plugin_state::<P>
), + llm_request_intercept_outcome_contract_version: + NEMO_RELAY_NATIVE_LLM_INTERCEPT_OUTCOME_CONTRACT_VERSION, }; } std::mem::forget(kind_handle); diff --git a/crates/plugin/tests/typed_callbacks.rs b/crates/plugin/tests/typed_callbacks.rs index ae88024b0..a61d1e99d 100644 --- a/crates/plugin/tests/typed_callbacks.rs +++ b/crates/plugin/tests/typed_callbacks.rs @@ -15,9 +15,10 @@ use std::sync::{ use nemo_relay_plugin::{ AnnotatedLlmRequest, ConfigDiagnostic, DiagnosticLevel, Event, Json, LlmJsonStream, LlmNext, LlmRequest, LlmRequestInterceptOutcome, LlmStream, LlmStreamNext, - NEMO_RELAY_NATIVE_ABI_VERSION, NativePlugin, NemoRelayNativeEventSubscriberCb, - NemoRelayNativeFreeFn, NemoRelayNativeHostApiV1, NemoRelayNativeJsonCb, - NemoRelayNativeLlmConditionalCb, NemoRelayNativeLlmExecutionCb, NemoRelayNativeLlmRequestCb, + NEMO_RELAY_NATIVE_ABI_VERSION, NEMO_RELAY_NATIVE_LLM_INTERCEPT_OUTCOME_CONTRACT_VERSION, + NativePlugin, NemoRelayNativeEventSubscriberCb, NemoRelayNativeFreeFn, + NemoRelayNativeHostApiV1, NemoRelayNativeJsonCb, NemoRelayNativeLlmConditionalCb, + NemoRelayNativeLlmExecutionCb, NemoRelayNativeLlmRequestCb, NemoRelayNativeLlmRequestInterceptCb, NemoRelayNativeLlmStreamExecutionCb, NemoRelayNativeLlmStreamV1, NemoRelayNativePluginContext, NemoRelayNativePluginV1, NemoRelayNativeScopeHandle, NemoRelayNativeScopeStack, NemoRelayNativeScopeStackBinding, @@ -296,17 +297,17 @@ fn native_abi_v1_struct_sizes_are_self_describing() { #[cfg(target_pointer_width = "64")] { assert_eq!(align_of::(), 8); - assert_eq!(size_of::(), 272); + assert_eq!(size_of::(), 280); assert_eq!( host_api_offsets(), [ 0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, - 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256, 264, + 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256, 264, 272, ] ); assert_eq!(align_of::(), 8); - assert_eq!(size_of::(), 56); - assert_eq!(plugin_offsets(), [0, 8, 16, 24, 32, 40, 48]); + 
assert_eq!(size_of::(), 64); + assert_eq!(plugin_offsets(), [0, 8, 16, 24, 32, 40, 48, 56]); assert_eq!(align_of::(), 8); assert_eq!(size_of::(), 40); assert_eq!(stream_offsets(), [0, 8, 16, 24, 32]); @@ -315,24 +316,24 @@ fn native_abi_v1_struct_sizes_are_self_describing() { #[cfg(target_pointer_width = "32")] { assert_eq!(align_of::(), 4); - assert_eq!(size_of::(), 136); + assert_eq!(size_of::(), 140); assert_eq!( host_api_offsets(), [ 0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 64, 68, 72, 76, 80, - 84, 88, 92, 96, 100, 104, 108, 112, 116, 120, 124, 128, 132, + 84, 88, 92, 96, 100, 104, 108, 112, 116, 120, 124, 128, 132, 136, ] ); assert_eq!(align_of::(), 4); - assert_eq!(size_of::(), 28); - assert_eq!(plugin_offsets(), [0, 4, 8, 12, 16, 20, 24]); + assert_eq!(size_of::(), 32); + assert_eq!(plugin_offsets(), [0, 4, 8, 12, 16, 20, 24, 28]); assert_eq!(align_of::(), 4); assert_eq!(size_of::(), 20); assert_eq!(stream_offsets(), [0, 4, 8, 12, 16]); } } -fn host_api_offsets() -> [usize; 34] { +fn host_api_offsets() -> [usize; 35] { [ offset_of!(NemoRelayNativeHostApiV1, abi_version), offset_of!(NemoRelayNativeHostApiV1, struct_size), @@ -401,10 +402,14 @@ fn host_api_offsets() -> [usize; 34] { offset_of!(NemoRelayNativeHostApiV1, scope_stack_binding_free), offset_of!(NemoRelayNativeHostApiV1, scope_stack_active), offset_of!(NemoRelayNativeHostApiV1, scope_stack_with_current), + offset_of!( + NemoRelayNativeHostApiV1, + llm_request_intercept_outcome_contract_version + ), ] } -fn plugin_offsets() -> [usize; 7] { +fn plugin_offsets() -> [usize; 8] { [ offset_of!(NemoRelayNativePluginV1, struct_size), offset_of!(NemoRelayNativePluginV1, plugin_kind), @@ -413,6 +418,10 @@ fn plugin_offsets() -> [usize; 7] { offset_of!(NemoRelayNativePluginV1, validate), offset_of!(NemoRelayNativePluginV1, register), offset_of!(NemoRelayNativePluginV1, drop), + offset_of!( + NemoRelayNativePluginV1, + llm_request_intercept_outcome_contract_version + ), ] } @@ -1091,6 
+1100,8 @@ fn test_host() -> NemoRelayNativeHostApiV1 { scope_stack_binding_free: capture_scope_stack_binding_free, scope_stack_active: true_scope_stack_active, scope_stack_with_current: capture_scope_stack_with_current, + llm_request_intercept_outcome_contract_version: + NEMO_RELAY_NATIVE_LLM_INTERCEPT_OUTCOME_CONTRACT_VERSION, } } @@ -2580,14 +2591,13 @@ fn typed_callbacks_reject_null_abi_pointers_before_decoding_inputs() { let mut ctx = test_context(&host); ctx.register_llm_request_intercept("llm-request-intercept", 0, false, |_name, request, ann| { - Ok((request, ann)) + Ok(LlmRequestInterceptOutcome::new(request, ann)) }) .unwrap(); let registration = take_llm_request_intercept_registration(); let name = host_string(&host, "llm"); let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); - let mut out_request = ptr::null_mut(); - let mut out_annotated = ptr::null_mut(); + let mut out_outcome = ptr::null_mut(); assert_eq!( unsafe { (registration.cb)( @@ -2595,8 +2605,7 @@ fn typed_callbacks_reject_null_abi_pointers_before_decoding_inputs() { name, request, ptr::null(), - &mut out_request, - &mut out_annotated, + &mut out_outcome, ) }, NemoRelayStatus::NullPointer @@ -2609,20 +2618,6 @@ fn typed_callbacks_reject_null_abi_pointers_before_decoding_inputs() { request, ptr::null(), ptr::null_mut(), - &mut out_annotated, - ) - }, - NemoRelayStatus::NullPointer - ); - assert_eq!( - unsafe { - (registration.cb)( - registration.user_data as *mut c_void, - name, - request, - ptr::null(), - &mut out_request, - ptr::null_mut(), ) }, NemoRelayStatus::NullPointer @@ -2889,14 +2884,13 @@ fn typed_callbacks_report_invalid_json_for_each_decoder_family() { let mut ctx = test_context(&host); ctx.register_llm_request_intercept("llm-request", 0, false, |_name, request, ann| { - Ok((request, ann)) + Ok(LlmRequestInterceptOutcome::new(request, ann)) }) .unwrap(); let registration = take_llm_request_intercept_registration(); let name = 
host_string(&host, "llm"); let bad_request = host_string(&host, "{not json"); - let mut out_request = ptr::null_mut(); - let mut out_annotated = ptr::null_mut(); + let mut out_outcome = ptr::null_mut(); assert_eq!( unsafe { (registration.cb)( @@ -2904,8 +2898,7 @@ fn typed_callbacks_report_invalid_json_for_each_decoder_family() { name, bad_request, ptr::null(), - &mut out_request, - &mut out_annotated, + &mut out_outcome, ) }, NemoRelayStatus::InvalidJson @@ -2919,8 +2912,7 @@ fn typed_callbacks_report_invalid_json_for_each_decoder_family() { name, request, bad_annotation, - &mut out_request, - &mut out_annotated, + &mut out_outcome, ) }, NemoRelayStatus::InvalidJson @@ -3091,8 +3083,7 @@ fn typed_callbacks_map_additional_callback_errors() { let registration = take_llm_request_intercept_registration(); let name = host_string(&host, "llm"); let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); - let mut out_request = ptr::null_mut(); - let mut out_annotated = ptr::null_mut(); + let mut out_outcome = ptr::null_mut(); assert_eq!( unsafe { (registration.cb)( @@ -3100,8 +3091,7 @@ fn typed_callbacks_map_additional_callback_errors() { name, request, ptr::null(), - &mut out_request, - &mut out_annotated, + &mut out_outcome, ) }, NemoRelayStatus::Internal @@ -4231,7 +4221,10 @@ fn typed_llm_request_intercept_does_not_publish_partial_outputs() { let host = test_host(); let mut ctx = test_context(&host); ctx.register_llm_request_intercept("llm", 0, false, |_name, request, _annotated| { - Ok((request, Some(test_annotated_llm_request()))) + Ok(LlmRequestInterceptOutcome::new( + request, + Some(test_annotated_llm_request()), + )) }) .unwrap(); @@ -4241,11 +4234,9 @@ fn typed_llm_request_intercept_does_not_publish_partial_outputs() { assert!(!registration.break_chain); let name = host_string(&host, "llm"); let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); - let stale_request = host_string(&host, 
r#"{"stale":"request"}"#); - let stale_annotated = host_string(&host, r#"{"stale":"annotated"}"#); - let mut out_request = stale_request; - let mut out_annotated = stale_annotated; - *STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = Some(1); + let stale_outcome = host_string(&host, r#"{"stale":"outcome"}"#); + let mut out_outcome = stale_outcome; + *STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = Some(0); let live_before = live_host_strings(); let status = unsafe { (registration.cb)( @@ -4253,18 +4244,15 @@ fn typed_llm_request_intercept_does_not_publish_partial_outputs() { name, request, ptr::null(), - &mut out_request, - &mut out_annotated, + &mut out_outcome, ) }; *STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = None; assert_eq!(status, NemoRelayStatus::Internal); - assert!(out_request.is_null()); - assert!(out_annotated.is_null()); + assert!(out_outcome.is_null()); assert_eq!(live_host_strings(), live_before); unsafe { - (host.string_free)(stale_request); - (host.string_free)(stale_annotated); + (host.string_free)(stale_outcome); (host.string_free)(name); (host.string_free)(request); registration.free(); @@ -4272,15 +4260,14 @@ fn typed_llm_request_intercept_does_not_publish_partial_outputs() { let mut ctx = test_context(&host); ctx.register_llm_request_intercept("llm", 0, false, |_name, request, _annotated| { - Ok((request, None)) + Ok(LlmRequestInterceptOutcome::new(request, None)) }) .unwrap(); let registration = take_llm_request_intercept_registration(); let name = host_string(&host, "llm"); let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); - let mut out_request = ptr::null_mut(); - let mut out_annotated = ptr::null_mut(); + let mut out_outcome = ptr::null_mut(); *STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = Some(0); let status = unsafe { (registration.cb)( @@ -4288,14 +4275,12 @@ fn typed_llm_request_intercept_does_not_publish_partial_outputs() { name, request, ptr::null(), - &mut out_request, - &mut 
out_annotated, + &mut out_outcome, ) }; *STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = None; assert_eq!(status, NemoRelayStatus::Internal); - assert!(out_request.is_null()); - assert!(out_annotated.is_null()); + assert!(out_outcome.is_null()); unsafe { (host.string_free)(name); (host.string_free)(request); @@ -4313,7 +4298,10 @@ fn typed_llm_request_intercept_round_trips_request_and_annotations() { assert!(annotated.is_some()); request.headers.insert("x-mutated".into(), json!(true)); request.content["rewritten"] = json!(true); - Ok((request, Some(test_annotated_llm_request()))) + Ok(LlmRequestInterceptOutcome::new( + request, + Some(test_annotated_llm_request()), + )) }) .unwrap(); @@ -4327,24 +4315,22 @@ fn typed_llm_request_intercept_round_trips_request_and_annotations() { &host, serde_json::to_value(test_annotated_llm_request()).unwrap(), ); - let mut out_request = ptr::null_mut(); - let mut out_annotated = ptr::null_mut(); + let mut out_outcome = ptr::null_mut(); let status = unsafe { (registration.cb)( registration.user_data as *mut c_void, name, request, annotated, - &mut out_request, - &mut out_annotated, + &mut out_outcome, ) }; assert_eq!(status, NemoRelayStatus::Ok); - let out_request = read_json_and_free(&host, out_request); - assert_eq!(out_request["headers"]["x-mutated"], json!(true)); - assert_eq!(out_request["content"]["rewritten"], json!(true)); - let out_annotated = read_json_and_free(&host, out_annotated); - assert_eq!(out_annotated["messages"], json!([])); + let outcome = read_json_and_free(&host, out_outcome); + assert_eq!(outcome["request"]["headers"]["x-mutated"], json!(true)); + assert_eq!(outcome["request"]["content"]["rewritten"], json!(true)); + assert_eq!(outcome["annotated_request"]["messages"], json!([])); + assert_eq!(outcome["pending_marks"], json!([])); unsafe { (host.string_free)(name); @@ -4355,33 +4341,30 @@ fn typed_llm_request_intercept_round_trips_request_and_annotations() { let mut ctx = test_context(&host); 
ctx.register_llm_request_intercept("llm", 0, false, |_name, request, _annotated| { - Ok((request, None)) + Ok(LlmRequestInterceptOutcome::new(request, None)) }) .unwrap(); let registration = take_llm_request_intercept_registration(); let name = host_string(&host, "llm"); let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); - let mut out_request = ptr::null_mut(); - let mut out_annotated = host_string(&host, r#"{"stale":true}"#); - let stale_annotated = out_annotated; + let mut out_outcome = host_string(&host, r#"{"stale":true}"#); + let stale_outcome = out_outcome; let status = unsafe { (registration.cb)( registration.user_data as *mut c_void, name, request, ptr::null(), - &mut out_request, - &mut out_annotated, + &mut out_outcome, ) }; assert_eq!(status, NemoRelayStatus::Ok); - assert!(out_annotated.is_null()); - assert_eq!( - read_json_and_free(&host, out_request)["content"]["input"], - json!(true) - ); + let outcome = read_json_and_free(&host, out_outcome); + assert!(outcome["annotated_request"].is_null()); + assert_eq!(outcome["request"]["content"]["input"], json!(true)); + assert_eq!(outcome["pending_marks"], json!([])); unsafe { - (host.string_free)(stale_annotated); + (host.string_free)(stale_outcome); (host.string_free)(name); (host.string_free)(request); registration.free(); @@ -4389,26 +4372,21 @@ fn typed_llm_request_intercept_round_trips_request_and_annotations() { } #[test] -fn typed_llm_request_intercept_with_marks_uses_tagged_annotation_envelope() { +fn typed_llm_request_intercept_serializes_canonical_outcome() { let _guard = begin_test(); let host = test_host(); let mut ctx = test_context(&host); - ctx.register_llm_request_intercept_with_marks( - "llm", - 23, - false, - |_name, mut request, annotated| { - request.content["rewritten"] = json!(true); - Ok( - LlmRequestInterceptOutcome::new(request, annotated).with_pending_mark( - PendingMarkSpec::builder() - .name("plugin.request.rewritten") - .data(json!({ 
"saved_tokens": 7 })) - .build(), - ), - ) - }, - ) + ctx.register_llm_request_intercept("llm", 23, false, |_name, mut request, annotated| { + request.content["rewritten"] = json!(true); + Ok( + LlmRequestInterceptOutcome::new(request, annotated).with_pending_mark( + PendingMarkSpec::builder() + .name("plugin.request.rewritten") + .data(json!({ "saved_tokens": 7 })) + .build(), + ), + ) + }) .unwrap(); let registration = take_llm_request_intercept_registration(); @@ -4420,34 +4398,25 @@ fn typed_llm_request_intercept_with_marks_uses_tagged_annotation_envelope() { &host, serde_json::to_value(test_annotated_llm_request()).unwrap(), ); - let mut out_request = ptr::null_mut(); - let mut out_annotated = ptr::null_mut(); + let mut out_outcome = ptr::null_mut(); let status = unsafe { (registration.cb)( registration.user_data as *mut c_void, name, request, annotated, - &mut out_request, - &mut out_annotated, + &mut out_outcome, ) }; assert_eq!(status, NemoRelayStatus::Ok); - let out_request = read_json_and_free(&host, out_request); - assert_eq!(out_request["content"]["rewritten"], true); - assert!( - out_request - .pointer("/content/__nemo_relay_llm_intercept_outcome") - .is_none() - ); - let metadata = read_json_and_free(&host, out_annotated); - assert_eq!(metadata["__nemo_relay_llm_intercept_outcome"], true); - assert_eq!(metadata["annotated_request"]["messages"], json!([])); + let outcome = read_json_and_free(&host, out_outcome); + assert_eq!(outcome["request"]["content"]["rewritten"], true); + assert_eq!(outcome["annotated_request"]["messages"], json!([])); assert_eq!( - metadata["pending_marks"][0]["name"], + outcome["pending_marks"][0]["name"], "plugin.request.rewritten" ); - assert_eq!(metadata["pending_marks"][0]["data"]["saved_tokens"], 7); + assert_eq!(outcome["pending_marks"][0]["data"]["saved_tokens"], 7); unsafe { (host.string_free)(name); @@ -4767,6 +4736,8 @@ fn direct_export_plugin_validates_host_table_and_kind_allocation() { validate: None, register: None, 
drop: None, + llm_request_intercept_outcome_contract_version: + NEMO_RELAY_NATIVE_LLM_INTERCEPT_OUTCOME_CONTRACT_VERSION, }; assert_eq!( unsafe { nemo_relay_plugin::export_plugin(&bad_host, &mut plugin, CountingPlugin) }, @@ -4977,6 +4948,8 @@ fn exported_entry_symbol_validates_args_before_constructor() { validate: None, register: None, drop: None, + llm_request_intercept_outcome_contract_version: + NEMO_RELAY_NATIVE_LLM_INTERCEPT_OUTCOME_CONTRACT_VERSION, }; assert_eq!( unsafe { constructor_counting_entry(&bad_host, &mut plugin) }, diff --git a/crates/python/src/py_api/mod.rs b/crates/python/src/py_api/mod.rs index f6682b43e..2603e92f3 100644 --- a/crates/python/src/py_api/mod.rs +++ b/crates/python/src/py_api/mod.rs @@ -1207,9 +1207,12 @@ fn tool_conditional_execution(name: &str, args: &Bound<'_, PyAny>) -> PyResult<( /// Returns: /// The (possibly transformed) ``LlmRequest``. #[pyfunction] -fn llm_request_intercepts(name: &str, request: PyLLMRequest) -> PyResult { +fn llm_request_intercepts( + name: &str, + request: PyLLMRequest, +) -> PyResult { let result = core_llm_api::llm_request_intercepts(name, request.inner).map_err(to_py_err)?; - Ok(PyLLMRequest { inner: result }) + Ok(crate::py_types::PyLLMRequestInterceptOutcome { inner: result }) } /// Run the registered LLM conditional execution guardrail chain. 
diff --git a/crates/python/src/py_callable.rs b/crates/python/src/py_callable.rs index 07bc392ea..b47d39834 100644 --- a/crates/python/src/py_callable.rs +++ b/crates/python/src/py_callable.rs @@ -35,13 +35,15 @@ use serde_json::Value as Json; use tokio_stream::Stream; use nemo_relay::api::event::Event; -use nemo_relay::api::llm::LlmRequest; +use nemo_relay::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; use nemo_relay::codec::request::AnnotatedLlmRequest as AnnotatedLLMRequest; use nemo_relay::codec::response::AnnotatedLlmResponse as AnnotatedLLMResponse; use nemo_relay::codec::traits::{LlmCodec, LlmResponseCodec}; use crate::convert::{json_to_py, py_to_json}; -use crate::py_types::{PyAnnotatedLLMRequest, PyAnnotatedLLMResponse, PyLLMRequest}; +use crate::py_types::{ + PyAnnotatedLLMRequest, PyAnnotatedLLMResponse, PyLLMRequest, PyLLMRequestInterceptOutcome, +}; type PyValueFuture = Pin>> + Send>>; @@ -643,13 +645,13 @@ pub fn wrap_py_llm_conditional_fn(py_fn: Py) -> LlmConditionalFn { /// Wrap a Python callable for unified LLM request intercepts. /// /// The Python function receives ``(name: str, request: LlmRequest, annotated: AnnotatedLLMRequest | None)`` -/// and must return ``(LlmRequest, AnnotatedLLMRequest | None)``. +/// and must return ``LLMRequestInterceptOutcome``. 
pub fn wrap_py_llm_request_intercept_fn(py_fn: Py) -> LlmRequestInterceptFn { Arc::new( move |name: &str, request: LlmRequest, annotated: Option| - -> FlowResult<(LlmRequest, Option)> { + -> FlowResult { Python::attach(|py| { let py_req = PyLLMRequest { inner: request.clone(), @@ -673,42 +675,14 @@ pub fn wrap_py_llm_request_intercept_fn(py_fn: Py) -> LlmRequestIntercept FlowError::Internal(format!("LLM request intercept callable failed: {e}")) })?; - // Extract the tuple (LlmRequest, AnnotatedLLMRequest | None) - let tuple = result.bind(py); - let new_req: PyLLMRequest = tuple - .get_item(0) - .map_err(|e| { - FlowError::Internal(format!( - "LLM request intercept result[0] extraction failed: {e}" - )) - })? - .extract() + result + .extract::(py) + .map(|value| value.inner) .map_err(|e| { FlowError::Internal(format!( - "LLM request intercept result[0] is not LlmRequest: {e}" + "LLM request intercept must return LLMRequestInterceptOutcome: {e}" )) - })?; - let ann_item = tuple.get_item(1).map_err(|e| { - FlowError::Internal(format!( - "LLM request intercept result[1] extraction failed: {e}" - )) - })?; - let new_ann = if ann_item.is_none() { - None - } else { - Some( - ann_item - .extract::() - .map_err(|e| { - FlowError::Internal(format!( - "LLM request intercept result[1] is not AnnotatedLLMRequest: {e}" - )) - })? 
- .inner, - ) - }; - - Ok((new_req.inner, new_ann)) + }) }) }, ) diff --git a/crates/python/src/py_types/core.rs b/crates/python/src/py_types/core.rs index 98dfc6a59..fd7aba216 100644 --- a/crates/python/src/py_types/core.rs +++ b/crates/python/src/py_types/core.rs @@ -4,10 +4,12 @@ use pyo3::prelude::*; use super::{ - Bound, CoreScopeType, FlowResult, LlmAttributes, LlmHandle, LlmRequest, PyAny, PyErr, PyRef, - PyResult, Python, ScopeAttributes, ScopeHandle, ScopeStackHandle, ToolAttributes, ToolHandle, - json_to_py, opt_json_to_py, py_to_json, + AnnotatedLLMRequest, Bound, CoreScopeType, FlowResult, LlmAttributes, LlmHandle, LlmRequest, + PyAnnotatedLLMRequest, PyAny, PyErr, PyRef, PyResult, Python, ScopeAttributes, ScopeHandle, + ScopeStackHandle, ToolAttributes, ToolHandle, json_to_py, opt_json_to_py, py_to_json, }; +use nemo_relay::api::event::{CategoryProfile, EventCategory, PendingMarkSpec}; +use nemo_relay::api::llm::LlmRequestInterceptOutcome; // --------------------------------------------------------------------------- // LlmStream (async iterator) @@ -597,3 +599,132 @@ impl PyLLMRequest { "LLMRequest(...)".to_string() } } + +/// A mark to emit immediately after the managed LLM start event. +#[pyclass(name = "PendingMarkSpec", from_py_object)] +#[derive(Clone)] +pub struct PyPendingMarkSpec { + pub inner: PendingMarkSpec, +} + +#[pymethods] +impl PyPendingMarkSpec { + #[new] + #[pyo3(signature = (name, category=None, category_profile=None, data=None, metadata=None))] + fn new( + name: String, + category: Option, + category_profile: Option<&Bound<'_, PyAny>>, + data: Option<&Bound<'_, PyAny>>, + metadata: Option<&Bound<'_, PyAny>>, + ) -> PyResult { + let category = category + .map(|value| serde_json::from_value::(serde_json::Value::String(value))) + .transpose() + .map_err(|error| pyo3::exceptions::PyValueError::new_err(error.to_string()))?; + let category_profile = category_profile + .map(py_to_json) + .transpose()? 
+ .map(serde_json::from_value::) + .transpose() + .map_err(|error| pyo3::exceptions::PyValueError::new_err(error.to_string()))?; + Ok(Self { + inner: PendingMarkSpec { + name, + category, + category_profile, + data: data.map(py_to_json).transpose()?, + metadata: metadata.map(py_to_json).transpose()?, + }, + }) + } + + #[getter] + fn name(&self) -> &str { + &self.inner.name + } + + #[getter] + fn category(&self) -> Option { + self.inner + .category + .as_ref() + .and_then(|value| serde_json::to_value(value).ok()) + .and_then(|value| value.as_str().map(str::to_owned)) + } + + #[getter] + fn category_profile(&self, py: Python<'_>) -> PyResult> { + opt_json_to_py( + py, + &self + .inner + .category_profile + .as_ref() + .map(serde_json::to_value) + .transpose() + .map_err(|error| pyo3::exceptions::PyRuntimeError::new_err(error.to_string()))?, + ) + } + + #[getter] + fn data(&self, py: Python<'_>) -> PyResult> { + opt_json_to_py(py, &self.inner.data) + } + + #[getter] + fn metadata(&self, py: Python<'_>) -> PyResult> { + opt_json_to_py(py, &self.inner.metadata) + } +} + +/// Canonical result returned by Python LLM request intercepts. 
+#[pyclass(name = "LLMRequestInterceptOutcome", from_py_object)] +#[derive(Clone)] +pub struct PyLLMRequestInterceptOutcome { + pub inner: LlmRequestInterceptOutcome, +} + +#[pymethods] +impl PyLLMRequestInterceptOutcome { + #[new] + #[pyo3(signature = (request, annotated_request=None, pending_marks=Vec::new()))] + fn new( + request: PyLLMRequest, + annotated_request: Option, + pending_marks: Vec, + ) -> Self { + Self { + inner: LlmRequestInterceptOutcome { + request: request.inner, + annotated_request: annotated_request.map(|value| value.inner), + pending_marks: pending_marks.into_iter().map(|value| value.inner).collect(), + }, + } + } + + #[getter] + fn request(&self) -> PyLLMRequest { + PyLLMRequest { + inner: self.inner.request.clone(), + } + } + + #[getter] + fn annotated_request(&self) -> Option { + self.inner + .annotated_request + .clone() + .map(|inner: AnnotatedLLMRequest| PyAnnotatedLLMRequest { inner }) + } + + #[getter] + fn pending_marks(&self) -> Vec { + self.inner + .pending_marks + .iter() + .cloned() + .map(|inner| PyPendingMarkSpec { inner }) + .collect() + } +} diff --git a/crates/python/src/py_types/mod.rs b/crates/python/src/py_types/mod.rs index f4e018536..7a2b28265 100644 --- a/crates/python/src/py_types/mod.rs +++ b/crates/python/src/py_types/mod.rs @@ -132,6 +132,8 @@ pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/crates/python/tests/coverage/coverage_tests.rs b/crates/python/tests/coverage/coverage_tests.rs index 6c3205e00..4fa673131 100644 --- a/crates/python/tests/coverage/coverage_tests.rs +++ b/crates/python/tests/coverage/coverage_tests.rs @@ -30,7 +30,14 @@ fn load_module<'py>(py: Python<'py>, code: &str) -> Bound<'py, PyModule> { let code = CString::new(code).unwrap(); let file_name = CString::new("coverage_tests.py").unwrap(); let module_name = 
CString::new("coverage_tests").unwrap(); - PyModule::from_code(py, &code, &file_name, &module_name).unwrap() + let module = PyModule::from_code(py, &code, &file_name, &module_name).unwrap(); + module + .setattr( + "Outcome", + py.get_type::(), + ) + .unwrap(); + module } fn make_request() -> LlmRequest { @@ -390,7 +397,7 @@ def llm_conditional(request): return None def llm_request_intercept(name, request, annotated): - return (request, annotated) + return Outcome(request, annotated) async def llm_execution_intercept(name, request, next): return await next(request) diff --git a/crates/python/tests/coverage/py_adaptive_coverage_tests.rs b/crates/python/tests/coverage/py_adaptive_coverage_tests.rs index 6c040bb67..aa5f8ef95 100644 --- a/crates/python/tests/coverage/py_adaptive_coverage_tests.rs +++ b/crates/python/tests/coverage/py_adaptive_coverage_tests.rs @@ -135,7 +135,7 @@ def bind_scope_and_translate(api_module, runtime, request): handle = api_module.push_scope("adaptive-runtime-cov", api_module.ScopeType.Agent) try: runtime.bind_scope(handle) - return api_module.llm_request_intercepts("anthropic", request).content + return api_module.llm_request_intercepts("anthropic", request).request.content finally: api_module.pop_scope(handle) "# diff --git a/crates/python/tests/coverage/py_api_coverage_tests.rs b/crates/python/tests/coverage/py_api_coverage_tests.rs index 03542b8ea..409493997 100644 --- a/crates/python/tests/coverage/py_api_coverage_tests.rs +++ b/crates/python/tests/coverage/py_api_coverage_tests.rs @@ -195,7 +195,7 @@ def llm_request_intercept(name, request, annotated): headers["x-intercepted"] = "1" content = dict(request.content) content["messages"] = [{"role": "user", "content": "hello from intercept"}] - return (LLMRequest(headers, content), annotated) + return LLMRequestInterceptOutcome(LLMRequest(headers, content), annotated) async def llm_exec(request): return { @@ -324,6 +324,12 @@ async def run_stream(api, request, func, collector, finalizer, 
handle, attribute helpers .setattr("LLMRequest", types_module.getattr("LLMRequest").unwrap()) .unwrap(); + helpers + .setattr( + "LLMRequestInterceptOutcome", + types_module.getattr("LLMRequestInterceptOutcome").unwrap(), + ) + .unwrap(); helpers .setattr( "AnnotatedLLMRequest", @@ -457,7 +463,11 @@ async def run_stream(api, request, func, collector, finalizer, handle, attribute }; let intercepted_request = llm_request_intercepts("demo-llm", llm_request.clone()).unwrap(); assert_eq!( - intercepted_request.inner.headers.get("x-intercepted"), + intercepted_request + .inner + .request + .headers + .get("x-intercepted"), Some(&json!("1")) ); llm_conditional_execution(llm_request.clone()).unwrap(); diff --git a/crates/python/tests/coverage/py_callable_coverage_tests.rs b/crates/python/tests/coverage/py_callable_coverage_tests.rs index 9c9e0f7a5..cac6bbd6e 100644 --- a/crates/python/tests/coverage/py_callable_coverage_tests.rs +++ b/crates/python/tests/coverage/py_callable_coverage_tests.rs @@ -60,7 +60,7 @@ def sync_llm_intercept(name, request, next): return {"name": name, "model": request.content["model"], "mode": "sync"} def request_echo(name, request, annotated): - return (request, annotated) + return Outcome(request, annotated) def request_bad_annotated(name, request, annotated): return (request, {"bad": True}) @@ -93,6 +93,12 @@ class RaisingResponseCodec: raise RuntimeError("decode boom") "#, ); + module + .setattr( + "Outcome", + py.get_type::(), + ) + .unwrap(); let tool_exec_py: Py = module.getattr("sync_tool_exec").unwrap().unbind(); let tool_intercept_py: Py = module.getattr("sync_tool_intercept").unwrap().unbind(); @@ -142,9 +148,11 @@ class RaisingResponseCodec: "model": "codec-model" })) .unwrap(); - let (_request, echoed_ann) = - request_intercept("llm", make_request(), Some(annotated.clone())).unwrap(); - assert_eq!(echoed_ann.unwrap().last_user_message(), Some("annotated")); + let outcome = request_intercept("llm", make_request(), 
Some(annotated.clone())).unwrap(); + assert_eq!( + outcome.annotated_request.unwrap().last_user_message(), + Some("annotated") + ); let bad_request_intercept = wrap_py_llm_request_intercept_fn( module.getattr("request_bad_annotated").unwrap().unbind(), @@ -153,7 +161,7 @@ class RaisingResponseCodec: bad_request_intercept("llm", make_request(), Some(annotated)) .unwrap_err() .to_string() - .contains("result[1] is not AnnotatedLLMRequest") + .contains("must return LLMRequestInterceptOutcome") ); let short_request_intercept = wrap_py_llm_request_intercept_fn( @@ -163,7 +171,7 @@ class RaisingResponseCodec: short_request_intercept("llm", make_request(), None) .unwrap_err() .to_string() - .contains("result[1] extraction failed") + .contains("must return LLMRequestInterceptOutcome") ); let mut collector = wrap_py_collector_fn(module.getattr("collector_ok").unwrap().unbind()); diff --git a/crates/python/tests/coverage/py_plugin_coverage_tests.rs b/crates/python/tests/coverage/py_plugin_coverage_tests.rs index dbb8a21b9..0eba263c2 100644 --- a/crates/python/tests/coverage/py_plugin_coverage_tests.rs +++ b/crates/python/tests/coverage/py_plugin_coverage_tests.rs @@ -16,7 +16,14 @@ fn load_module<'py>(py: Python<'py>, code: &str) -> Bound<'py, PyModule> { let code = CString::new(code).unwrap(); let file_name = CString::new("py_plugin_coverage_tests.py").unwrap(); let module_name = CString::new("py_plugin_coverage_tests").unwrap(); - PyModule::from_code(py, &code, &file_name, &module_name).unwrap() + let module = PyModule::from_code(py, &code, &file_name, &module_name).unwrap(); + module + .setattr( + "Outcome", + py.get_type::(), + ) + .unwrap(); + module } fn with_event_loop(py: Python<'_>, f: impl FnOnce(Bound<'_, PyAny>) -> T) -> T { @@ -123,7 +130,7 @@ def llm_conditional(request): return None def llm_request_intercept(name, request, annotated): - return (request, annotated) + return Outcome(request, annotated) async def llm_execution_intercept(name, request, next): 
return await next(request) @@ -579,7 +586,7 @@ def llm_conditional(request): return None def llm_request_intercept(name, request, annotated): - return (request, annotated) + return Outcome(request, annotated) async def llm_execution_intercept(name, request, next): return await next(request) @@ -863,7 +870,7 @@ def llm_conditional(request): return None def llm_request_intercept(name, request, annotated): - return (request, annotated) + return Outcome(request, annotated) async def llm_execution_intercept(name, request, next): return await next(request) diff --git a/crates/types/src/api/llm.rs b/crates/types/src/api/llm.rs index c4dba6efb..6cc8c6083 100644 --- a/crates/types/src/api/llm.rs +++ b/crates/types/src/api/llm.rs @@ -10,13 +10,6 @@ use crate::Json; use crate::api::event::PendingMarkSpec; use crate::codec::request::AnnotatedLlmRequest; -/// Private native-ABI tag used for marked LLM request-intercept outcomes. -/// -/// Native plugin authors should return [`LlmRequestInterceptOutcome`] through -/// the plugin SDK instead of reading or writing this field directly. -#[doc(hidden)] -pub const NATIVE_LLM_INTERCEPT_OUTCOME_FIELD: &str = "__nemo_relay_llm_intercept_outcome"; - bitflags! { /// Bitflags that modify LLM-call behavior and observability. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] @@ -43,10 +36,10 @@ pub struct LlmRequestInterceptOutcome { /// Rewritten provider request. pub request: LlmRequest, /// Optional normalized request annotation to carry forward. - #[serde(default, skip_serializing_if = "Option::is_none")] + #[serde(default)] pub annotated_request: Option, /// Ordered marks to emit after Relay creates and starts the LLM scope. 
- #[serde(default, skip_serializing_if = "Vec::is_empty")] + #[serde(default)] pub pending_marks: Vec, } @@ -67,9 +60,3 @@ impl LlmRequestInterceptOutcome { self } } - -impl From<(LlmRequest, Option)> for LlmRequestInterceptOutcome { - fn from((request, annotated_request): (LlmRequest, Option)) -> Self { - Self::new(request, annotated_request) - } -} diff --git a/crates/types/tests/serialization_tests.rs b/crates/types/tests/serialization_tests.rs index fb423ecf9..cede82fe5 100644 --- a/crates/types/tests/serialization_tests.rs +++ b/crates/types/tests/serialization_tests.rs @@ -105,7 +105,7 @@ fn llm_request_intercept_outcome_round_trips_pending_marks() { let encoded = serde_json::to_value(&outcome).expect("outcome should serialize"); assert_eq!(encoded["pending_marks"][0]["name"], "request.optimized"); assert_eq!(encoded["pending_marks"][0]["category"], "custom"); - assert!(encoded.get("annotated_request").is_none()); + assert!(encoded["annotated_request"].is_null()); let mut encoded_without_pending_marks = encoded.clone(); encoded_without_pending_marks @@ -117,6 +117,23 @@ fn llm_request_intercept_outcome_round_trips_pending_marks() { .expect("outcome without pending marks should deserialize"); assert!(decoded_without_pending_marks.pending_marks.is_empty()); + let decoded_defaults: LlmRequestInterceptOutcome = serde_json::from_value(json!({ + "request": {"headers": {}, "content": {"prompt": "hello"}}, + "future_field": true + })) + .expect("omitted optional fields and unknown fields should be accepted"); + assert!(decoded_defaults.annotated_request.is_none()); + assert!(decoded_defaults.pending_marks.is_empty()); + + assert!( + serde_json::from_value::(json!({ + "annotated_request": null, + "pending_marks": [] + })) + .is_err(), + "request is required" + ); + let decoded: LlmRequestInterceptOutcome = serde_json::from_value(encoded).expect("outcome should deserialize"); assert_eq!(decoded, outcome); diff --git a/crates/wasm/src/api/mod.rs 
b/crates/wasm/src/api/mod.rs index 22c8ead65..e445d07cb 100644 --- a/crates/wasm/src/api/mod.rs +++ b/crates/wasm/src/api/mod.rs @@ -2036,8 +2036,11 @@ pub fn llm_request_intercepts_wasm( let llm_request: CoreLlmRequest = serde_json::from_value(request_json) .map_err(|e| to_js_err(FlowError::Internal(e.to_string())))?; let result = relay_llm_api::llm_request_intercepts(name, llm_request).map_err(to_js_err)?; - let result_json = - serde_json::to_value(&result).map_err(|e| to_js_err(FlowError::Internal(e.to_string())))?; + let result_json = serde_json::json!({ + "request": result.request, + "annotated": result.annotated_request, + "pendingMarks": result.pending_marks, + }); Ok(json_to_js(&result_json)) } diff --git a/crates/wasm/src/callable.rs b/crates/wasm/src/callable.rs index 4cf4b2135..1279421c4 100644 --- a/crates/wasm/src/callable.rs +++ b/crates/wasm/src/callable.rs @@ -29,7 +29,9 @@ use wasm_bindgen::JsValue; use wasm_bindgen_futures::JsFuture; use nemo_relay::api::event::Event; -use nemo_relay::api::llm::LlmRequest; +#[cfg(target_arch = "wasm32")] +use nemo_relay::api::event::PendingMarkSpec; +use nemo_relay::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; use nemo_relay::api::runtime::{ EventSubscriberFn, LlmConditionalFn, LlmExecutionNextFn, LlmRequestInterceptFn, LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionNextFn, ToolConditionalFn, @@ -266,7 +268,7 @@ pub fn wrap_js_tool_exec_fn( pub fn wrap_js_llm_request_intercept_fn(_func: Function) -> LlmRequestInterceptFn { Arc::new( move |_name: &str, request: LlmRequest, annotated: Option| { - Ok((request, annotated)) + Ok(LlmRequestInterceptOutcome::new(request, annotated)) }, ) } @@ -278,7 +280,7 @@ pub fn wrap_js_llm_request_intercept_fn(func: Function) -> LlmRequestInterceptFn move |name: &str, request: LlmRequest, annotated: Option| - -> Result<(LlmRequest, Option)> { + -> Result { let req_json = serde_json::to_value(&request).unwrap_or(Json::Null); let js_name = 
JsValue::from_str(name); let js_req = json_to_js(&req_json); @@ -340,7 +342,24 @@ pub fn wrap_js_llm_request_intercept_fn(func: Function) -> LlmRequestInterceptFn ) }; - Ok((new_request, new_annotated)) + let js_pending_marks = + js_sys::Reflect::get(&result, &JsValue::from_str("pendingMarks")) + .map_err(|e| FlowError::Internal(js_error_message(&e)))?; + let pending_marks = if js_pending_marks.is_null() || js_pending_marks.is_undefined() { + Vec::new() + } else { + let marks_json = js_to_json(&js_pending_marks) + .map_err(|e| FlowError::Internal(js_error_message(&e)))?; + serde_json::from_value::>(marks_json).map_err(|e| { + FlowError::Internal(format!("failed to deserialize pendingMarks: {e}")) + })? + }; + + Ok(LlmRequestInterceptOutcome { + request: new_request, + annotated_request: new_annotated, + pending_marks, + }) }, ) } diff --git a/crates/wasm/tests/coverage/callable_tests.rs b/crates/wasm/tests/coverage/callable_tests.rs index fddbd1f8e..652329570 100644 --- a/crates/wasm/tests/coverage/callable_tests.rs +++ b/crates/wasm/tests/coverage/callable_tests.rs @@ -36,9 +36,10 @@ fn native_tool_and_llm_wrapper_fallbacks_are_stable() { content: json!({"messages": []}), }; let llm_intercept = wrap_js_llm_request_intercept_fn(dummy_function()); - let (request, annotated) = llm_intercept("llm", llm_request.clone(), None).unwrap(); - assert_eq!(request.content, llm_request.content); - assert!(annotated.is_none()); + let outcome = llm_intercept("llm", llm_request.clone(), None).unwrap(); + assert_eq!(outcome.request.content, llm_request.content); + assert!(outcome.annotated_request.is_none()); + assert!(outcome.pending_marks.is_empty()); let llm_sanitize = wrap_js_llm_sanitize_request_fn(dummy_function()); assert_eq!( diff --git a/crates/wasm/wrappers/esm/plugin.d.ts b/crates/wasm/wrappers/esm/plugin.d.ts index b664e8d5c..adc708e72 100644 --- a/crates/wasm/wrappers/esm/plugin.d.ts +++ b/crates/wasm/wrappers/esm/plugin.d.ts @@ -82,7 +82,14 @@ export interface 
PluginContext { breakChain: boolean, callback: (args: { name: string; request: Json; annotated: Json | null }) => { request: Json; - annotated: Json | null; + annotated?: Json | null; + pendingMarks?: Array<{ + name: string; + category?: string | null; + category_profile?: Json; + data?: Json; + metadata?: Json; + }>; }, ): void; /** Register an LLM execution intercept for this component. */ diff --git a/crates/worker-proto/proto/nemo/relay/worker/v1/plugin_worker.proto b/crates/worker-proto/proto/nemo/relay/worker/v1/plugin_worker.proto index 134154584..9a7e7f704 100644 --- a/crates/worker-proto/proto/nemo/relay/worker/v1/plugin_worker.proto +++ b/crates/worker-proto/proto/nemo/relay/worker/v1/plugin_worker.proto @@ -181,9 +181,7 @@ message GuardrailResult { } message LlmRequestInterceptResult { - JsonEnvelope request = 1; - JsonEnvelope annotated_request = 2; - bool has_annotated_request = 3; + JsonEnvelope outcome = 1; } message StreamChunk { diff --git a/crates/worker/src/lib.rs b/crates/worker/src/lib.rs index 4bb394c0d..7a4fb47fa 100644 --- a/crates/worker/src/lib.rs +++ b/crates/worker/src/lib.rs @@ -20,8 +20,8 @@ use futures_util::{Stream, StreamExt}; #[cfg(unix)] use hyper_util::rt::TokioIo; pub use nemo_relay_types::Json; -pub use nemo_relay_types::api::event::Event; -pub use nemo_relay_types::api::llm::LlmRequest; +pub use nemo_relay_types::api::event::{Event, PendingMarkSpec}; +pub use nemo_relay_types::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; pub use nemo_relay_types::api::scope::ScopeType; use nemo_relay_types::codec::request::AnnotatedLlmRequest; pub use nemo_relay_types::plugin::{ConfigDiagnostic, DiagnosticLevel}; @@ -111,11 +111,7 @@ type LlmSanitizeRequestFn = Arc LlmRequest + Send + Sync>; type LlmSanitizeResponseFn = Arc Json + Send + Sync>; type LlmConditionalFn = Arc Result> + Send + Sync>; type LlmRequestFn = Arc< - dyn Fn( - &str, - LlmRequest, - Option, - ) -> Result<(LlmRequest, Option)> + dyn Fn(&str, LlmRequest, Option) -> 
Result + Send + Sync, >; @@ -350,11 +346,7 @@ impl PluginContext { break_chain: bool, callback: F, ) where - F: Fn( - &str, - LlmRequest, - Option, - ) -> Result<(LlmRequest, Option)> + F: Fn(&str, LlmRequest, Option) -> Result + Send + Sync + 'static, @@ -1111,10 +1103,10 @@ impl WorkerService { .map(|value| decode_json_envelope::(&value)) .transpose()?; let handler = self.llm_request(&request.registration_name)?; - let (request, annotated) = with_thread_scope(&scope, || { + let outcome = with_thread_scope(&scope, || { handler(&payload.model_name, request_value, annotated) })?; - Ok(llm_request_response(request, annotated)?) + Ok(llm_request_response(outcome)?) } RegistrationSurface::LlmExecutionIntercept => { let payload = llm_payload(request.payload)?; @@ -1368,20 +1360,15 @@ fn guardrail_response(reason: Option) -> InvokeResponse { } } -fn llm_request_response( - request: LlmRequest, - annotated: Option, -) -> Result { +fn llm_request_response(outcome: LlmRequestInterceptOutcome) -> Result { Ok(InvokeResponse { result: Some( nemo_relay_worker_proto::v1::invoke_response::Result::LlmRequest( LlmRequestInterceptResult { - request: Some(json_envelope("nemo.relay.LlmRequest@1", &request)?), - annotated_request: annotated - .as_ref() - .map(|value| json_envelope("nemo.relay.AnnotatedLlmRequest@1", value)) - .transpose()?, - has_annotated_request: annotated.is_some(), + outcome: Some(json_envelope( + "nemo.relay.LlmRequestInterceptOutcome@1", + &outcome, + )?), }, ), ), diff --git a/crates/worker/tests/worker_sdk_tests.rs b/crates/worker/tests/worker_sdk_tests.rs index 30ae616a4..13df4abe8 100644 --- a/crates/worker/tests/worker_sdk_tests.rs +++ b/crates/worker/tests/worker_sdk_tests.rs @@ -1308,7 +1308,10 @@ impl WorkerPlugin for SurfacePlugin { .and_then(|blocked| blocked.then(|| "blocked-llm".into()))) }); ctx.register_llm_request_intercept("llm-request", 1, false, |_, request, annotated| { - Ok((set_llm_phase(request, "llm_request"), annotated)) + 
Ok(nemo_relay_worker::LlmRequestInterceptOutcome::new( + set_llm_phase(request, "llm_request"), + annotated, + )) }); ctx.register_llm_execution_intercept( @@ -1998,7 +2001,11 @@ async fn invoke_llm_request( .into_inner(); match response.result.expect("invoke result") { nemo_relay_worker_proto::v1::invoke_response::Result::LlmRequest(result) => { - decode_json_envelope(&result.request.expect("llm request")).expect("decode LLM request") + decode_json_envelope::( + &result.outcome.expect("llm request outcome"), + ) + .expect("decode LLM request outcome") + .request } other => panic!("unexpected invoke result: {other:?}"), } diff --git a/docs/build-plugins/code-examples.mdx b/docs/build-plugins/code-examples.mdx index f2114c2cf..068e33f1d 100644 --- a/docs/build-plugins/code-examples.mdx +++ b/docs/build-plugins/code-examples.mdx @@ -37,11 +37,14 @@ class HeaderPlugin: name: str, request: nemo_relay.LLMRequest, annotated: nemo_relay.AnnotatedLLMRequest | None - ) -> tuple[nemo_relay.LLMRequest, nemo_relay.AnnotatedLLMRequest | None]: + ) -> nemo_relay.LLMRequestInterceptOutcome: # The request object is immutable, however we can return a new instance with updated headers. 
headers = request.headers.copy() headers[plugin_config["header_name"]] = plugin_config["value"] - return nemo_relay.LLMRequest(headers=headers, content=request.content), annotated + return nemo_relay.LLMRequestInterceptOutcome( + nemo_relay.LLMRequest(headers=headers, content=request.content), + annotated, + ) context.register_llm_request_intercept("inject-header", 100, False, add_header) diff --git a/docs/build-plugins/register-behavior.mdx b/docs/build-plugins/register-behavior.mdx index d995ca773..04022fba8 100644 --- a/docs/build-plugins/register-behavior.mdx +++ b/docs/build-plugins/register-behavior.mdx @@ -51,10 +51,13 @@ class HeaderPlugin: name: str, request: nemo_relay.LLMRequest, annotated: nemo_relay.AnnotatedLLMRequest | None - ) -> tuple[nemo_relay.LLMRequest, nemo_relay.AnnotatedLLMRequest | None]: + ) -> nemo_relay.LLMRequestInterceptOutcome: headers = request.headers.copy() headers[plugin_config["header_name"]] = plugin_config["value"] - return nemo_relay.LLMRequest(headers=headers, content=request.content), annotated + return nemo_relay.LLMRequestInterceptOutcome( + nemo_relay.LLMRequest(headers=headers, content=request.content), + annotated, + ) context.register_llm_request_intercept("inject-header", 100, False, add_header) diff --git a/docs/instrument-applications/code-examples.mdx b/docs/instrument-applications/code-examples.mdx index 64d687da3..58b03404c 100644 --- a/docs/instrument-applications/code-examples.mdx +++ b/docs/instrument-applications/code-examples.mdx @@ -297,8 +297,8 @@ let request = LlmRequest { headers: Default::default(), content: json!({"messages": [{"role": "user", "content": "hello"}]}), }; -let rewritten = llm_request_intercepts("demo-provider", request)?; -llm_conditional_execution(&rewritten)?; +let outcome = llm_request_intercepts("demo-provider", request)?; +llm_conditional_execution(&outcome.request)?; ``` diff --git a/docs/integrate-into-frameworks/code-examples.mdx b/docs/integrate-into-frameworks/code-examples.mdx 
index c57181bb2..c3d0b4d05 100644 --- a/docs/integrate-into-frameworks/code-examples.mdx +++ b/docs/integrate-into-frameworks/code-examples.mdx @@ -208,7 +208,8 @@ use serde_json::json; let rewritten_args = tool_request_intercepts("search", json!({"query": "weather"}))?; let request = LlmRequest { headers: Default::default(), content: json!({"messages": []}) }; -let rewritten_request = llm_request_intercepts("demo-provider", request)?; +let outcome = llm_request_intercepts("demo-provider", request)?; +let rewritten_request = outcome.request; ``` diff --git a/docs/integrate-into-frameworks/provider-codecs.mdx b/docs/integrate-into-frameworks/provider-codecs.mdx index f0d632661..03815a6a6 100644 --- a/docs/integrate-into-frameworks/provider-codecs.mdx +++ b/docs/integrate-into-frameworks/provider-codecs.mdx @@ -87,7 +87,7 @@ from nemo_relay.codecs import OpenAIChatCodec def add_system_message(_name, request, annotated): if annotated is None: - return request, annotated + return nemo_relay.LLMRequestInterceptOutcome(request) # Attributes of the annotated request can be re-assigned, but cannot be modified in-place. # For example `annotated.messages.append(...)` would not work, but re-assigning @@ -96,7 +96,7 @@ def add_system_message(_name, request, annotated): {"role": "system", "content": "Answer with concise technical detail."}, *annotated.messages, ] - return request, annotated + return nemo_relay.LLMRequestInterceptOutcome(request, annotated) nemo_relay.intercepts.register_llm_request( "framework.add_system_message", diff --git a/docs/reference/llm-request-intercept-outcomes.mdx b/docs/reference/llm-request-intercept-outcomes.mdx new file mode 100644 index 000000000..9c2de1ab6 --- /dev/null +++ b/docs/reference/llm-request-intercept-outcomes.mdx @@ -0,0 +1,57 @@ +--- +title: "LLM Request Intercept Outcomes" +description: "Canonical request-intercept result and managed lifecycle behavior." 
+--- +{/* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: Apache-2.0 */} + +Every LLM request intercept returns one canonical outcome: + +```json +{ + "request": {"headers": {}, "content": {}}, + "annotated_request": null, + "pending_marks": [] +} +``` + +`request` is required. `annotated_request` defaults to `null` when omitted on +input, and `pending_marks` defaults to an empty list. Canonical serialization +includes all three fields. A pending mark contains only `name`, optional +`category` and `category_profile`, and optional `data` and `metadata`. Relay +owns event UUIDs, parent UUIDs, and timestamps. + +Python callbacks return `LLMRequestInterceptOutcome`; Rust callbacks return +`LlmRequestInterceptOutcome`; Go callbacks return +`LLMRequestInterceptOutcome`; and Node.js and WebAssembly callbacks return +`{ request, annotated?, pendingMarks? }`. Public C callbacks write one owned +canonical outcome JSON string. Native ABI v1 uses one host-owned outcome JSON +string, and `grpc-v1` uses a `JsonEnvelope` whose schema is +`nemo.relay.LlmRequestInterceptOutcome@1`. + +The standalone request-intercept helper returns the complete outcome but does +not emit its pending marks because it does not own an LLM lifecycle. + +## Managed Lifecycle + +Managed execution runs all effective global and scope-local intercepts before +creating the LLM handle. Each rewritten request and annotation feeds the next +intercept, while pending marks append in middleware order. A breaking +intercept's marks are retained. If any intercept fails or its boundary result +is malformed, Relay discards all accumulated marks and creates no LLM +lifecycle. + +After successful interception, Relay creates the handle and captures one +subscriber snapshot. It emits the LLM start at `T`, every pending mark at +`T + 1µs` in returned order with the LLM UUID as parent, and the LLM end no +earlier than `T + 1µs`. 
Streaming and non-streaming calls use the same rules. +Pending marks are never added to the provider request, annotated request, +codec input, sanitizer input, or start payload. + +## Migration + +This finalizes unpublished native ABI v1 and `grpc-v1` contracts. Rebuild all +development native plugins and workers. Replace tuple results, split C/Go +outputs, metadata envelopes, and parallel mark-aware registrations with the +canonical outcome and the existing `register_llm_request_intercept` +registration name. diff --git a/examples/rust-native-plugin/README.md b/examples/rust-native-plugin/README.md index 9c8df8860..ea38314f2 100644 --- a/examples/rust-native-plugin/README.md +++ b/examples/rust-native-plugin/README.md @@ -94,8 +94,6 @@ Native plugins are not sandboxed. They run in the Relay process and must not unwind across ABI callbacks. Request intercepts do not own an LLM lifecycle because they run before Relay -creates the LLM scope. Use `register_llm_request_intercept_with_marks` to return -`PendingMarkSpec` values. Relay emits them in interceptor order after the LLM -start event and before provider execution. The legacy -`register_llm_request_intercept` API remains available for intercepts that only -rewrite requests. +creates the LLM scope. A `register_llm_request_intercept` callback returns one +`LlmRequestInterceptOutcome`, whose `pending_marks` Relay emits in interceptor +order after the LLM start event and before provider execution. 
diff --git a/examples/rust-native-plugin/src/lib.rs b/examples/rust-native-plugin/src/lib.rs index d7e4d961f..4d9130e8c 100644 --- a/examples/rust-native-plugin/src/lib.rs +++ b/examples/rust-native-plugin/src/lib.rs @@ -233,7 +233,7 @@ impl NativePlugin for ExampleNativePlugin { Ok(block_llms.then(|| "LLM call blocked by Rust native plugin".to_string())) } })?; - ctx.register_llm_request_intercept_with_marks("example_llm_request", 20, false, { + ctx.register_llm_request_intercept("example_llm_request", 20, false, { let tag = config.tag.clone(); move |_name, request, annotated| { Ok(LlmRequestInterceptOutcome::new( diff --git a/go/nemo_relay/adaptive_plugin_test.go b/go/nemo_relay/adaptive_plugin_test.go index 44d580b0c..7b53a7131 100644 --- a/go/nemo_relay/adaptive_plugin_test.go +++ b/go/nemo_relay/adaptive_plugin_test.go @@ -104,12 +104,13 @@ func registerLifecycleInterceptors(ctx *PluginContext, pluginKind string) error "llm_request", 7, false, - func(name string, headers, content, annotated json.RawMessage) (json.RawMessage, json.RawMessage, json.RawMessage, error) { - out, err := decorateJSONPayload(headers, "x-go-plugin", pluginKind) + func(name string, request LLMRequestDTO, annotated json.RawMessage) (LLMRequestInterceptOutcome, error) { + out, err := decorateJSONPayload(request.Headers, "x-go-plugin", pluginKind) if err != nil { - return nil, nil, nil, err + return LLMRequestInterceptOutcome{}, err } - return out, content, annotated, nil + request.Headers = out + return LLMRequestInterceptOutcome{Request: request, AnnotatedRequest: annotated}, nil }, ); err != nil { return err @@ -441,8 +442,8 @@ func TestPluginFuncsAndClosedContextBranches(t *testing.T) { return closed.RegisterLlmConditionalExecutionGuardrail("llm_conditional", 1, func(headers, content json.RawMessage) *string { return nil }) }}, {"llm request", func() error { - return closed.RegisterLlmRequestIntercept("llm_request", 1, false, func(name string, headers, content, annotated 
json.RawMessage) (json.RawMessage, json.RawMessage, json.RawMessage, error) { - return headers, content, annotated, nil + return closed.RegisterLlmRequestIntercept("llm_request", 1, false, func(name string, request LLMRequestDTO, annotated json.RawMessage) (LLMRequestInterceptOutcome, error) { + return LLMRequestInterceptOutcome{Request: request, AnnotatedRequest: annotated}, nil }) }}, {"tool request", func() error { diff --git a/go/nemo_relay/callbacks.go b/go/nemo_relay/callbacks.go index 68f3d7f71..58182417d 100644 --- a/go/nemo_relay/callbacks.go +++ b/go/nemo_relay/callbacks.go @@ -212,16 +212,34 @@ type CodecFunc struct { Encode func(annotatedJSON json.RawMessage, originalHeadersJSON, originalContentJSON json.RawMessage) (json.RawMessage, error) } -// LLMRequestInterceptFunc is a callback for LLM request intercepts with -// the unified annotated-aware signature. It receives the intercept name, -// request headers/content, and optionally the annotated request JSON (nil if -// no Codec resolved). Returns the (possibly modified) headers, content, and -// annotated JSON. +// LLMRequestDTO is the JSON-shaped request used by request intercept outcomes. +type LLMRequestDTO struct { + Headers json.RawMessage `json:"headers"` + Content json.RawMessage `json:"content"` +} + +// PendingMarkSpec describes a mark Relay emits after starting a managed LLM call. +type PendingMarkSpec struct { + Name string `json:"name"` + Category *string `json:"category,omitempty"` + CategoryProfile json.RawMessage `json:"category_profile,omitempty"` + Data json.RawMessage `json:"data,omitempty"` + Metadata json.RawMessage `json:"metadata,omitempty"` +} + +// LLMRequestInterceptOutcome is the canonical result of an LLM request intercept. 
+type LLMRequestInterceptOutcome struct { + Request LLMRequestDTO `json:"request"` + AnnotatedRequest json.RawMessage `json:"annotated_request"` + PendingMarks []PendingMarkSpec `json:"pending_marks"` +} + +// LLMRequestInterceptFunc is a callback for LLM request intercepts. type LLMRequestInterceptFunc func( name string, - headers, content json.RawMessage, + request LLMRequestDTO, annotatedJSON json.RawMessage, -) (newHeaders, newContent, newAnnotatedJSON json.RawMessage, err error) +) (LLMRequestInterceptOutcome, error) func codecDecodePayload(codec *CodecFunc, headers, content json.RawMessage) (json.RawMessage, error) { return codec.Decode(headers, content) @@ -236,8 +254,8 @@ func llmRequestInterceptPayload( name string, headers, content json.RawMessage, annotatedJSON json.RawMessage, -) (json.RawMessage, json.RawMessage, json.RawMessage, error) { - return fn(name, headers, content, annotatedJSON) +) (LLMRequestInterceptOutcome, error) { + return fn(name, LLMRequestDTO{Headers: headers, Content: content}, annotatedJSON) } func pluginValidatePayload(plugin Plugin, pluginConfigJSON json.RawMessage) (json.RawMessage, error) { @@ -512,7 +530,7 @@ func goCodecEncodeTrampoline(userData unsafe.Pointer, annotatedJSON *C.char, ori //export goLlmRequestInterceptTrampoline func goLlmRequestInterceptTrampoline( userData unsafe.Pointer, name *C.char, request *C.FfiLLMRequest, - annotatedJSON *C.char, outRequest **C.FfiLLMRequest, outAnnotatedJSON **C.char, + annotatedJSON *C.char, outOutcomeJSON **C.char, ) C.int32_t { fn := lookupClosure(userData).(LLMRequestInterceptFunc) goName := C.GoString(name) @@ -526,20 +544,20 @@ func goLlmRequestInterceptTrampoline( if annotatedJSON != nil { goAnnotated = json.RawMessage(C.GoString(annotatedJSON)) } - newHeaders, newContent, newAnnotated, err := llmRequestInterceptPayload(fn, goName, goHeaders, goContent, goAnnotated) + outcome, err := llmRequestInterceptPayload(fn, goName, goHeaders, goContent, goAnnotated) if err != nil { 
setLastErrorMessage(err.Error()) return 5 // NemoRelayStatus::Internal } - // Create output FfiLLMRequest - cNewHeaders := C.CString(string(newHeaders)) - cNewContent := C.CString(string(newContent)) - defer C.free(unsafe.Pointer(cNewHeaders)) - defer C.free(unsafe.Pointer(cNewContent)) - *outRequest = C.nemo_relay_llm_request_new(cNewHeaders, cNewContent) - if newAnnotated != nil { - *outAnnotatedJSON = C.CString(string(newAnnotated)) + if outcome.PendingMarks == nil { + outcome.PendingMarks = []PendingMarkSpec{} + } + outcomeJSON, err := jsonMarshal(outcome) + if err != nil { + setLastErrorMessage(err.Error()) + return 5 } + *outOutcomeJSON = C.CString(string(outcomeJSON)) return 0 // NemoRelayStatus::Ok } diff --git a/go/nemo_relay/intercepts/intercepts.go b/go/nemo_relay/intercepts/intercepts.go index 5872b74a6..98030c400 100644 --- a/go/nemo_relay/intercepts/intercepts.go +++ b/go/nemo_relay/intercepts/intercepts.go @@ -206,6 +206,6 @@ func ToolRequestIntercepts(name string, args json.RawMessage) (json.RawMessage, // LlmRequestIntercepts runs the registered LLM request intercept chain and // returns the transformed request. This is a shorthand for // [nemo_relay.LlmRequestIntercepts]. 
-func LlmRequestIntercepts(name string, request json.RawMessage) (json.RawMessage, error) { +func LlmRequestIntercepts(name string, request json.RawMessage) (nemo_relay.LLMRequestInterceptOutcome, error) { return nemo_relay.LlmRequestIntercepts(name, request) } diff --git a/go/nemo_relay/intercepts/intercepts_test.go b/go/nemo_relay/intercepts/intercepts_test.go index 1bb6eb8da..ed16007be 100644 --- a/go/nemo_relay/intercepts/intercepts_test.go +++ b/go/nemo_relay/intercepts/intercepts_test.go @@ -87,12 +87,12 @@ func runGlobalLLMInterceptShorthandChecks(t *testing.T) { t.Helper() if err := intercepts.RegisterLlmRequest("intercepts_llm_req", 1, false, - func(name string, headers, content, annotated json.RawMessage) (json.RawMessage, json.RawMessage, json.RawMessage, error) { + func(name string, request nemo_relay.LLMRequestDTO, annotated json.RawMessage) (nemo_relay.LLMRequestInterceptOutcome, error) { var payload map[string]interface{} - _ = json.Unmarshal(content, &payload) + _ = json.Unmarshal(request.Content, &payload) payload["intercepted"] = true - out, _ := json.Marshal(payload) - return headers, out, annotated, nil + request.Content, _ = json.Marshal(payload) + return nemo_relay.LLMRequestInterceptOutcome{Request: request, AnnotatedRequest: annotated}, nil }, ); err != nil { t.Fatalf("RegisterLlmRequest failed: %v", err) @@ -109,7 +109,7 @@ func runGlobalLLMInterceptShorthandChecks(t *testing.T) { var llmReq struct { Content map[string]interface{} `json:"content"` } - if err := json.Unmarshal(transformedRequest, &llmReq); err != nil { + if err := json.Unmarshal(transformedRequest.Request.Content, &llmReq.Content); err != nil { t.Fatalf("unmarshal llm request: %v", err) } if llmReq.Content["intercepted"] != true { @@ -198,8 +198,8 @@ func runScopeLocalLLMInterceptShorthandChecks(t *testing.T, scopeUUID string) { t.Helper() if err := intercepts.ScopeRegisterLlmRequest(scopeUUID, "intercepts_scope_llm_req", 1, false, - func(name string, headers, content, 
annotated json.RawMessage) (json.RawMessage, json.RawMessage, json.RawMessage, error) { - return headers, content, annotated, nil + func(name string, request nemo_relay.LLMRequestDTO, annotated json.RawMessage) (nemo_relay.LLMRequestInterceptOutcome, error) { + return nemo_relay.LLMRequestInterceptOutcome{Request: request, AnnotatedRequest: annotated}, nil }, ); err != nil { t.Fatalf("ScopeRegisterLlmRequest failed: %v", err) diff --git a/go/nemo_relay/llm/llm.go b/go/nemo_relay/llm/llm.go index 662fa6b3e..742e16e99 100644 --- a/go/nemo_relay/llm/llm.go +++ b/go/nemo_relay/llm/llm.go @@ -63,7 +63,7 @@ func StreamExecute(name string, native interface{}, fn nemo_relay.LLMExecutionFu // RequestIntercepts runs the registered LLM request intercept chain on the // given request and returns the transformed request. This is a shorthand for // [nemo_relay.LlmRequestIntercepts]. -func RequestIntercepts(name string, request json.RawMessage) (json.RawMessage, error) { +func RequestIntercepts(name string, request json.RawMessage) (nemo_relay.LLMRequestInterceptOutcome, error) { return nemo_relay.LlmRequestIntercepts(name, request) } diff --git a/go/nemo_relay/llm/llm_shorthand_test.go b/go/nemo_relay/llm/llm_shorthand_test.go index 2be60eadc..40e3604bc 100644 --- a/go/nemo_relay/llm/llm_shorthand_test.go +++ b/go/nemo_relay/llm/llm_shorthand_test.go @@ -45,12 +45,12 @@ func assertLLMRequestInterceptShorthand(t *testing.T) { t.Helper() if err := nemo_relay.RegisterLlmRequestIntercept("llm_req_int", 1, false, - func(name string, headers, content, annotated json.RawMessage) (json.RawMessage, json.RawMessage, json.RawMessage, error) { + func(name string, request nemo_relay.LLMRequestDTO, annotated json.RawMessage) (nemo_relay.LLMRequestInterceptOutcome, error) { var payload map[string]interface{} - _ = json.Unmarshal(content, &payload) + _ = json.Unmarshal(request.Content, &payload) payload["intercepted"] = true - out, _ := json.Marshal(payload) - return headers, out, annotated, 
nil + request.Content, _ = json.Marshal(payload) + return nemo_relay.LLMRequestInterceptOutcome{Request: request, AnnotatedRequest: annotated}, nil }, ); err != nil { t.Fatalf("RegisterLlmRequestIntercept failed: %v", err) @@ -67,7 +67,7 @@ func assertLLMRequestInterceptShorthand(t *testing.T) { var intercepted struct { Content map[string]interface{} `json:"content"` } - if err := json.Unmarshal(request, &intercepted); err != nil { + if err := json.Unmarshal(request.Request.Content, &intercepted.Content); err != nil { t.Fatalf("unmarshal request: %v", err) } if intercepted.Content["intercepted"] != true { diff --git a/go/nemo_relay/llm_test.go b/go/nemo_relay/llm_test.go index 7b82ef77c..da94d2fde 100644 --- a/go/nemo_relay/llm_test.go +++ b/go/nemo_relay/llm_test.go @@ -440,8 +440,8 @@ func TestLlmConditionalBlocksExecution(t *testing.T) { func TestLlmRequestInterceptRegisterDeregister(t *testing.T) { err := RegisterLlmRequestIntercept("go_llm_req", 1, false, - func(name string, headers, content, annotated json.RawMessage) (json.RawMessage, json.RawMessage, json.RawMessage, error) { - return headers, content, annotated, nil + func(name string, request LLMRequestDTO, annotated json.RawMessage) (LLMRequestInterceptOutcome, error) { + return LLMRequestInterceptOutcome{Request: request, AnnotatedRequest: annotated}, nil }, ) if err != nil { @@ -521,12 +521,12 @@ func TestLlmStreamExecutionInterceptCanCallNext(t *testing.T) { func TestLlmRequestInterceptModifies(t *testing.T) { RegisterLlmRequestIntercept("go_llm_req_mod", 1, false, - func(name string, headers, content, annotated json.RawMessage) (json.RawMessage, json.RawMessage, json.RawMessage, error) { + func(name string, request LLMRequestDTO, annotated json.RawMessage) (LLMRequestInterceptOutcome, error) { var m map[string]interface{} - json.Unmarshal(content, &m) + json.Unmarshal(request.Content, &m) m["intercepted"] = true - out, _ := json.Marshal(m) - return headers, out, annotated, nil + request.Content, _ = 
json.Marshal(m) + return LLMRequestInterceptOutcome{Request: request, AnnotatedRequest: annotated}, nil }, ) diff --git a/go/nemo_relay/nemo_relay.go b/go/nemo_relay/nemo_relay.go index 9799a9a76..87ab55291 100644 --- a/go/nemo_relay/nemo_relay.go +++ b/go/nemo_relay/nemo_relay.go @@ -148,7 +148,7 @@ extern int32_t nemo_relay_register_llm_conditional_execution_guardrail(const cha extern int32_t nemo_relay_deregister_llm_conditional_execution_guardrail(const char* name); // LLM intercepts -typedef int32_t (*NemoRelayLlmRequestInterceptCb)(void* user_data, const char* name, const FfiLLMRequest* request, const char* annotated_json, FfiLLMRequest** out_request, char** out_annotated_json); +typedef int32_t (*NemoRelayLlmRequestInterceptCb)(void* user_data, const char* name, const FfiLLMRequest* request, const char* annotated_json, char** out_outcome_json); extern int32_t nemo_relay_register_llm_request_intercept(const char* name, int32_t priority, _Bool break_chain, NemoRelayLlmRequestInterceptCb cb, void* user_data, NemoRelayFreeFn free_fn); extern int32_t nemo_relay_deregister_llm_request_intercept(const char* name); typedef char* (*NemoRelayLlmExecNextFn)(const char* native_json, void* next_ctx); @@ -269,7 +269,7 @@ extern char* goLlmExecInterceptTrampoline(void*, const char*, NemoRelayLlmExecNe extern char* goCodecDecodeTrampoline(void*, const FfiLLMRequest*); extern char* goCodecEncodeTrampoline(void*, const char*, const FfiLLMRequest*); extern int32_t goLlmRequestInterceptTrampoline( - void*, const char*, const FfiLLMRequest*, const char*, FfiLLMRequest**, char**); + void*, const char*, const FfiLLMRequest*, const char*, char**); */ import "C" @@ -2431,7 +2431,7 @@ func ToolConditionalExecution(name string, args json.RawMessage) error { // LlmRequestIntercepts runs the registered LLM request intercept chain on the // given request (serialized as JSON) and returns the transformed request JSON. 
-func LlmRequestIntercepts(name string, request json.RawMessage) (json.RawMessage, error) { +func LlmRequestIntercepts(name string, request json.RawMessage) (LLMRequestInterceptOutcome, error) { cName := C.CString(name) defer C.free(unsafe.Pointer(cName)) cRequest := C.CString(string(request)) @@ -2440,10 +2440,14 @@ func LlmRequestIntercepts(name string, request json.RawMessage) (json.RawMessage var out *C.char status := C.nemo_relay_llm_request_intercepts(cName, cRequest, &out) if err := checkStatus(status); err != nil { - return nil, err + return LLMRequestInterceptOutcome{}, err } defer C.nemo_relay_string_free(out) - return json.RawMessage(C.GoString(out)), nil + var outcome LLMRequestInterceptOutcome + if err := jsonUnmarshal([]byte(C.GoString(out)), &outcome); err != nil { + return LLMRequestInterceptOutcome{}, err + } + return outcome, nil } // LlmConditionalExecution runs the registered LLM conditional execution diff --git a/go/nemo_relay/plugin.go b/go/nemo_relay/plugin.go index c4a8affcf..af39171b6 100644 --- a/go/nemo_relay/plugin.go +++ b/go/nemo_relay/plugin.go @@ -18,7 +18,7 @@ typedef char* (*NemoRelayToolConditionalFn)(void* user_data, const char* name, c typedef void* (*NemoRelayLlmRequestCb)(void* user_data, const void* request); typedef char* (*NemoRelayLlmResponseFn)(void* user_data, const char* response_json); typedef char* (*NemoRelayLlmConditionalCb)(void* user_data, const void* request); -typedef int32_t (*NemoRelayLlmRequestInterceptCb)(void* user_data, const char* name, const void* request, const char* annotated_json, void** out_request, char** out_annotated_json); +typedef int32_t (*NemoRelayLlmRequestInterceptCb)(void* user_data, const char* name, const void* request, const char* annotated_json, char** out_outcome_json); typedef char* (*NemoRelayLlmExecNextFn)(const char* native_json, void* next_ctx); typedef char* (*NemoRelayLlmExecInterceptCb)(void* user_data, const char* native_json, NemoRelayLlmExecNextFn next_fn, void* next_ctx); 
typedef char* (*NemoRelayToolExecNextFn)(const char* args_json, void* next_ctx); @@ -56,7 +56,7 @@ extern void* goLlmRequestTrampoline(void*, const void*); extern char* goLlmResponseTrampoline(void*, const char*); extern char* goLlmConditionalTrampoline(void*, const void*); extern char* goLlmExecInterceptTrampoline(void*, const char*, NemoRelayLlmExecNextFn, void*); -extern int32_t goLlmRequestInterceptTrampoline(void*, const char*, const void*, const char*, void**, char**); +extern int32_t goLlmRequestInterceptTrampoline(void*, const char*, const void*, const char*, char**); extern char* goToolExecInterceptTrampoline(void*, const char*, NemoRelayToolExecNextFn, void*); */ import "C" diff --git a/go/nemo_relay/scope_local_test.go b/go/nemo_relay/scope_local_test.go index ca40b5218..9e3e62351 100644 --- a/go/nemo_relay/scope_local_test.go +++ b/go/nemo_relay/scope_local_test.go @@ -1128,9 +1128,9 @@ func assertScopeLocalLLMWrappersDeregister(t *testing.T, scopeUUID string, reque &requestInterceptCalls, func() error { return ScopeRegisterLlmRequestIntercept(scopeUUID, "llm_scope_req_int", 1, false, - func(name string, headers, content, annotated json.RawMessage) (json.RawMessage, json.RawMessage, json.RawMessage, error) { + func(name string, request LLMRequestDTO, annotated json.RawMessage) (LLMRequestInterceptOutcome, error) { requestInterceptCalls++ - return headers, content, annotated, nil + return LLMRequestInterceptOutcome{Request: request, AnnotatedRequest: annotated}, nil }, ) }, diff --git a/go/nemo_relay/top_level_coverage_test.go b/go/nemo_relay/top_level_coverage_test.go index 455d945f5..48eba4fd6 100644 --- a/go/nemo_relay/top_level_coverage_test.go +++ b/go/nemo_relay/top_level_coverage_test.go @@ -270,15 +270,17 @@ func assertCodecEncodeCoverage(t *testing.T, request *LLMRequest) { func assertLlmRequestInterceptPayloadCoverage(t *testing.T, request *LLMRequest) { t.Helper() - newHeaders, newContent, newAnnotated, err := llmRequestInterceptPayload( - 
func(name string, headers, content, annotatedJSON json.RawMessage) (json.RawMessage, json.RawMessage, json.RawMessage, error) { + outcome, err := llmRequestInterceptPayload( + func(name string, request LLMRequestDTO, annotatedJSON json.RawMessage) (LLMRequestInterceptOutcome, error) { if name != coverageInterceptName { t.Fatalf("unexpected intercept name: %q", name) } if string(annotatedJSON) != `{"annotated":true}` { t.Fatalf("unexpected annotated payload: %s", annotatedJSON) } - return json.RawMessage(`{"updated":"headers"}`), json.RawMessage(`{"model":"updated"}`), json.RawMessage(`{"annotated":"updated"}`), nil + request.Headers = json.RawMessage(`{"updated":"headers"}`) + request.Content = json.RawMessage(`{"model":"updated"}`) + return LLMRequestInterceptOutcome{Request: request, AnnotatedRequest: json.RawMessage(`{"annotated":"updated"}`)}, nil }, coverageInterceptName, request.Headers(), @@ -288,22 +290,22 @@ func assertLlmRequestInterceptPayloadCoverage(t *testing.T, request *LLMRequest) if err != nil { t.Fatalf("expected intercept success, got %v", err) } - if string(newHeaders) != `{"updated":"headers"}` { - t.Fatalf("unexpected output headers: %s", newHeaders) + if string(outcome.Request.Headers) != `{"updated":"headers"}` { + t.Fatalf("unexpected output headers: %s", outcome.Request.Headers) } - if string(newContent) != `{"model":"updated"}` { - t.Fatalf("unexpected output content: %s", newContent) + if string(outcome.Request.Content) != `{"model":"updated"}` { + t.Fatalf("unexpected output content: %s", outcome.Request.Content) } - if string(newAnnotated) != `{"annotated":"updated"}` { - t.Fatalf("unexpected output annotated json: %s", newAnnotated) + if string(outcome.AnnotatedRequest) != `{"annotated":"updated"}` { + t.Fatalf("unexpected output annotated json: %s", outcome.AnnotatedRequest) } - _, _, _, err = llmRequestInterceptPayload( - func(name string, headers, content, annotatedJSON json.RawMessage) (json.RawMessage, json.RawMessage, 
json.RawMessage, error) { + _, err = llmRequestInterceptPayload( + func(name string, request LLMRequestDTO, annotatedJSON json.RawMessage) (LLMRequestInterceptOutcome, error) { if annotatedJSON != nil { t.Fatalf("expected nil annotated JSON, got %s", annotatedJSON) } - return nil, nil, nil, errors.New("forced intercept failure") + return LLMRequestInterceptOutcome{}, errors.New("forced intercept failure") }, coverageInterceptName, request.Headers(), @@ -400,8 +402,8 @@ func assertLlmCodecInterceptCoverage(t *testing.T) { }, } - if err := RegisterLlmRequestIntercept("coverage_llm_codec_success", 1, false, func(name string, headers, content, annotatedJSON json.RawMessage) (json.RawMessage, json.RawMessage, json.RawMessage, error) { - return headers, content, json.RawMessage(`{"messages":[{"role":"user","content":"updated"}],"model":"updated-model"}`), nil + if err := RegisterLlmRequestIntercept("coverage_llm_codec_success", 1, false, func(name string, request LLMRequestDTO, annotatedJSON json.RawMessage) (LLMRequestInterceptOutcome, error) { + return LLMRequestInterceptOutcome{Request: request, AnnotatedRequest: json.RawMessage(`{"messages":[{"role":"user","content":"updated"}],"model":"updated-model"}`)}, nil }); err != nil { t.Fatalf("RegisterLlmRequestIntercept success case failed: %v", err) } @@ -415,8 +417,8 @@ func assertLlmCodecInterceptCoverage(t *testing.T) { t.Fatalf("expected codec-backed intercept success, got %v", err) } - if err := RegisterLlmRequestIntercept("coverage_llm_codec_error", 1, false, func(name string, headers, content, annotatedJSON json.RawMessage) (json.RawMessage, json.RawMessage, json.RawMessage, error) { - return nil, nil, nil, errors.New("forced codec-backed intercept failure") + if err := RegisterLlmRequestIntercept("coverage_llm_codec_error", 1, false, func(name string, request LLMRequestDTO, annotatedJSON json.RawMessage) (LLMRequestInterceptOutcome, error) { + return LLMRequestInterceptOutcome{}, errors.New("forced codec-backed 
intercept failure") }); err != nil { t.Fatalf("RegisterLlmRequestIntercept error case failed: %v", err) } diff --git a/go/nemo_relay/wrapper_coverage_test.go b/go/nemo_relay/wrapper_coverage_test.go index 70eb9dd26..19812aa29 100644 --- a/go/nemo_relay/wrapper_coverage_test.go +++ b/go/nemo_relay/wrapper_coverage_test.go @@ -112,12 +112,12 @@ func TestStandaloneMiddlewareHelpers(t *testing.T) { } if err := RegisterLlmRequestIntercept("go_standalone_llm_req", 1, false, - func(name string, headers, content, annotated json.RawMessage) (json.RawMessage, json.RawMessage, json.RawMessage, error) { + func(name string, request LLMRequestDTO, annotated json.RawMessage) (LLMRequestInterceptOutcome, error) { var payload map[string]interface{} - _ = json.Unmarshal(content, &payload) + _ = json.Unmarshal(request.Content, &payload) payload["intercepted"] = true - out, _ := json.Marshal(payload) - return headers, out, annotated, nil + request.Content, _ = json.Marshal(payload) + return LLMRequestInterceptOutcome{Request: request, AnnotatedRequest: annotated}, nil }, ); err != nil { t.Fatalf("RegisterLlmRequestIntercept failed: %v", err) @@ -131,7 +131,7 @@ func TestStandaloneMiddlewareHelpers(t *testing.T) { var llmPayload struct { Content map[string]interface{} `json:"content"` } - if err := json.Unmarshal(request, &llmPayload); err != nil { + if err := json.Unmarshal(request.Request.Content, &llmPayload.Content); err != nil { t.Fatalf("unmarshal llm request: %v", err) } if llmPayload.Content["intercepted"] != true { diff --git a/python/nemo_relay/__init__.py b/python/nemo_relay/__init__.py index a14e8c8a6..2ef611e84 100644 --- a/python/nemo_relay/__init__.py +++ b/python/nemo_relay/__init__.py @@ -43,11 +43,13 @@ def add_header( name: str, request: nemo_relay.LLMRequest, annotated: nemo_relay.AnnotatedLLMRequest | None - ) -> tuple[nemo_relay.LLMRequest, nemo_relay.AnnotatedLLMRequest | None]: + ) -> nemo_relay.LLMRequestInterceptOutcome: # The request object is immutable, 
however we can return a new instance with updated headers. headers = request.headers.copy() headers["Authorization"] = "Bearer test-token" - return nemo_relay.LLMRequest(headers=headers, content=request.content), annotated + return nemo_relay.LLMRequestInterceptOutcome( + nemo_relay.LLMRequest(headers=headers, content=request.content), annotated + ) async def tool_impl(args): return {"echo": args["query"]} @@ -96,11 +98,13 @@ async def main(): LLMAttributes, LLMHandle, LLMRequest, + LLMRequestInterceptOutcome, MarkEvent, OpenInferenceConfig, OpenInferenceSubscriber, OpenTelemetryConfig, OpenTelemetrySubscriber, + PendingMarkSpec, ScopeAttributes, ScopeEvent, ScopeHandle, @@ -162,12 +166,11 @@ async def main(): [str, Json, Callable[[Json], Awaitable[Json]]], Json | Awaitable[Json], ] -#: Request intercept callback that rewrites raw and annotated LLM requests -#: together. The return tuple supplies the request and optional annotated view -#: passed to later request intercepts and execution. +#: Request intercept callback that returns the canonical request, annotation, +#: and pending-mark outcome passed to later intercepts and managed execution. LlmRequestIntercept: TypeAlias = Callable[ [str, LLMRequest, AnnotatedLLMRequest | None], - tuple[LLMRequest, AnnotatedLLMRequest | None], + LLMRequestInterceptOutcome, ] #: Execution intercept callback that wraps non-streaming LLM execution. The #: callback receives the logical LLM name, request, and next callable. 
It may @@ -457,6 +460,8 @@ def worker() -> None: "ToolHandle", "LLMHandle", "LLMRequest", + "LLMRequestInterceptOutcome", + "PendingMarkSpec", "Event", "AnnotatedLLMRequest", "AnnotatedLLMResponse", diff --git a/python/nemo_relay/__init__.pyi b/python/nemo_relay/__init__.pyi index 676a0a6c1..d682011a4 100644 --- a/python/nemo_relay/__init__.pyi +++ b/python/nemo_relay/__init__.pyi @@ -69,6 +69,9 @@ from nemo_relay._native import ( from nemo_relay._native import ( LLMRequest as LLMRequest, ) +from nemo_relay._native import ( + LLMRequestInterceptOutcome as LLMRequestInterceptOutcome, +) from nemo_relay._native import ( MarkEvent as MarkEvent, ) @@ -84,6 +87,9 @@ from nemo_relay._native import ( from nemo_relay._native import ( OpenTelemetrySubscriber as OpenTelemetrySubscriber, ) +from nemo_relay._native import ( + PendingMarkSpec as PendingMarkSpec, +) from nemo_relay._native import ( ScopeAttributes as ScopeAttributes, ) @@ -209,7 +215,7 @@ Exceptional flow: """ LlmRequestIntercept: TypeAlias = Callable[ [str, LLMRequest, AnnotatedLLMRequest | None], - tuple[LLMRequest, AnnotatedLLMRequest | None], + LLMRequestInterceptOutcome, ] """Request intercept callback that rewrites raw and annotated LLM requests. @@ -217,7 +223,7 @@ Arguments: The logical LLM name, raw request, and optional annotated request view. Return: - The request and optional annotated view passed to later middleware. + The complete canonical outcome passed to later middleware. 
""" LlmExecutionIntercept: TypeAlias = Callable[ [str, LLMRequest, Callable[[LLMRequest], Awaitable[Json]]], diff --git a/python/nemo_relay/_native.pyi b/python/nemo_relay/_native.pyi index 8aed064ab..6acc400f5 100644 --- a/python/nemo_relay/_native.pyi +++ b/python/nemo_relay/_native.pyi @@ -42,7 +42,7 @@ _ToolExecutionIntercept: TypeAlias = Callable[ ] _LlmRequestIntercept: TypeAlias = Callable[ [str, "LLMRequest", "AnnotatedLLMRequest | None"], - tuple["LLMRequest", "AnnotatedLLMRequest | None"], + "LLMRequestInterceptOutcome", ] _LlmExecutionIntercept: TypeAlias = Callable[ [str, "LLMRequest", Callable[["LLMRequest"], Awaitable[_Json]]], @@ -382,6 +382,42 @@ class LLMRequest: """Return the request content body as a JSON object.""" ... +class PendingMarkSpec: + """A runtime-owned mark specification returned by request middleware.""" + def __init__( + self, + name: str, + category: Optional[str] = ..., + category_profile: Optional[_Json] = ..., + data: Optional[_Json] = ..., + metadata: Optional[_Json] = ..., + ) -> None: ... + @property + def name(self) -> str: ... + @property + def category(self) -> Optional[str]: ... + @property + def category_profile(self) -> Optional[_Json]: ... + @property + def data(self) -> Optional[_Json]: ... + @property + def metadata(self) -> Optional[_Json]: ... + +class LLMRequestInterceptOutcome: + """Canonical result returned by an LLM request intercept.""" + def __init__( + self, + request: LLMRequest, + annotated_request: Optional[AnnotatedLLMRequest] = ..., + pending_marks: list[PendingMarkSpec] = ..., + ) -> None: ... + @property + def request(self) -> LLMRequest: ... + @property + def annotated_request(self) -> Optional[AnnotatedLLMRequest]: ... + @property + def pending_marks(self) -> list[PendingMarkSpec]: ... + class AnnotatedLLMRequest: """Structured view of an LLM request produced by a codec. @@ -1509,7 +1545,7 @@ def tool_conditional_execution(name: str, args: _Json) -> None: """ ... 
-def llm_request_intercepts(name: str, request: LLMRequest) -> LLMRequest: +def llm_request_intercepts(name: str, request: LLMRequest) -> LLMRequestInterceptOutcome: """Run the registered LLM request-intercept chain. Args: diff --git a/python/nemo_relay/intercepts.py b/python/nemo_relay/intercepts.py index ac9f340df..836c2e76a 100644 --- a/python/nemo_relay/intercepts.py +++ b/python/nemo_relay/intercepts.py @@ -14,11 +14,13 @@ def add_header( name: str, request: nemo_relay.LLMRequest, annotated: nemo_relay.AnnotatedLLMRequest | None - ) -> tuple[nemo_relay.LLMRequest, nemo_relay.AnnotatedLLMRequest | None]: + ) -> nemo_relay.LLMRequestInterceptOutcome: # The request object is immutable, however we can return a new instance with updated headers. headers = request.headers.copy() headers["X-Trace"] = "demo" - return nemo_relay.LLMRequest(headers=headers, content=request.content), annotated + return nemo_relay.LLMRequestInterceptOutcome( + nemo_relay.LLMRequest(headers=headers, content=request.content), annotated + ) nemo_relay.intercepts.register_llm_request("trace-header", 10, False, add_header) """ @@ -186,10 +188,12 @@ def register_llm_request(name: str, priority: int, break_chain: bool, fn: LlmReq def add_header( name: str, request: nemo_relay.LLMRequest, annotated: nemo_relay.AnnotatedLLMRequest | None - ) -> tuple[nemo_relay.LLMRequest, nemo_relay.AnnotatedLLMRequest | None]: + ) -> nemo_relay.LLMRequestInterceptOutcome: headers = request.headers.copy() headers["X-Trace"] = "req-123" - return nemo_relay.LLMRequest(headers=headers, content=request.content), annotated + return nemo_relay.LLMRequestInterceptOutcome( + nemo_relay.LLMRequest(headers=headers, content=request.content), annotated + ) nemo_relay.intercepts.register_llm_request( "trace-header", diff --git a/python/nemo_relay/llm.py b/python/nemo_relay/llm.py index a99724593..78657143f 100644 --- a/python/nemo_relay/llm.py +++ b/python/nemo_relay/llm.py @@ -378,7 +378,8 @@ def request_intercepts(name, 
request): intercept chain. Returns: - LLMRequest: The request produced by the final request intercept. + LLMRequestInterceptOutcome: The complete request, annotation, and + pending-mark outcome produced by the intercept chain. Notes: This runs only the request-intercept chain. It does not execute diff --git a/python/tests/integrations/langchain_tests/test_middleware.py b/python/tests/integrations/langchain_tests/test_middleware.py index 7d2c7813c..fffb369b3 100644 --- a/python/tests/integrations/langchain_tests/test_middleware.py +++ b/python/tests/integrations/langchain_tests/test_middleware.py @@ -323,7 +323,7 @@ def change_request(name: str, request: nemo_relay.LLMRequest, annotated: Any): else message for message in annotated.messages ] - return request, annotated + return nemo_relay.LLMRequestInterceptOutcome(request, annotated) nemo_relay.intercepts.register_llm_request("test_langchain_change_request", 1, False, change_request) try: diff --git a/python/tests/test_adaptive.py b/python/tests/test_adaptive.py index f3fb92547..c7e956e7b 100644 --- a/python/tests/test_adaptive.py +++ b/python/tests/test_adaptive.py @@ -9,7 +9,17 @@ import pytest -from nemo_relay import AnnotatedLLMRequest, JsonObject, LLMRequest, ScopeType, llm, plugin, scope, tools +from nemo_relay import ( + AnnotatedLLMRequest, + JsonObject, + LLMRequest, + LLMRequestInterceptOutcome, + ScopeType, + llm, + plugin, + scope, + tools, +) from nemo_relay import adaptive as adaptive_module from nemo_relay.adaptive import ( ADAPTIVE_PLUGIN_KIND, @@ -213,7 +223,7 @@ async def test_adaptive_runtime_bind_scope_passes_through_without_state(self): with scope.scope("adaptive-runtime-translate", ScopeType.Agent) as handle: runtime.bind_scope(handle) translated = llm.request_intercepts("anthropic", request) - assert translated.content == { + assert translated.request.content == { "messages": [{"role": "user", "content": "Hello"}], "system": "You are helpful.", "model": "claude-sonnet-4-20250514", @@ -287,7 
+297,7 @@ def register(self, plugin_config, context): def intercept(_name, request, annotated): headers = dict(request.headers) headers["x-python-plugin"] = f"priority:{priority}" - return LLMRequest(headers, request.content), annotated + return LLMRequestInterceptOutcome(LLMRequest(headers, request.content), annotated) async def llm_exec_intercept(_name, request, next_call): response = await next_call(request) diff --git a/python/tests/test_codecs.py b/python/tests/test_codecs.py index d49c70513..6caab0ca4 100644 --- a/python/tests/test_codecs.py +++ b/python/tests/test_codecs.py @@ -16,6 +16,7 @@ AnnotatedLLMRequest, JsonObject, LLMRequest, + LLMRequestInterceptOutcome, intercepts, llm, ) @@ -267,7 +268,7 @@ def annotated_intercept(name, request, annotated): *annotated.messages, ] annotated.messages = new_messages - return (request, annotated) + return LLMRequestInterceptOutcome(request, annotated) intercepts.register_llm_request("test-annot-intercept-pipeline", 1, False, annotated_intercept) @@ -295,7 +296,7 @@ async def test_codec_parameter(self): def annotated_intercept(name, request, annotated): if annotated is not None: intercept_data["extra"] = annotated.extra - return (request, annotated) + return LLMRequestInterceptOutcome(request, annotated) intercepts.register_llm_request("test-annot-intercept-cn", 1, False, annotated_intercept) @@ -324,7 +325,7 @@ def annotated_intercept(name, request, annotated): assert annotated is not None assert isinstance(annotated, AnnotatedLLMRequest) assert isinstance(request, LLMRequest) - return (request, annotated) + return LLMRequestInterceptOutcome(request, annotated) intercepts.register_llm_request("test-annot-typed", 1, False, annotated_intercept) diff --git a/python/tests/test_integration_codecs.py b/python/tests/test_integration_codecs.py index a98b72d04..a2f4be2ec 100644 --- a/python/tests/test_integration_codecs.py +++ b/python/tests/test_integration_codecs.py @@ -13,7 +13,15 @@ from typing import cast -from 
nemo_relay import AnnotatedLLMRequest, JsonObject, LLMRequest, ScopeType, llm, scope +from nemo_relay import ( + AnnotatedLLMRequest, + JsonObject, + LLMRequest, + LLMRequestInterceptOutcome, + ScopeType, + llm, + scope, +) from nemo_relay.codecs import ( LlmCodec, ) @@ -448,7 +456,7 @@ def annotated_intercept(name, request, annotated): if annotated is not None: intercept_data["model"] = annotated.model intercept_data["messages"] = annotated.messages - return (request, annotated) + return LLMRequestInterceptOutcome(request, annotated) intercepts.register_llm_request("delegation-test-intercept", 1, False, annotated_intercept) diff --git a/python/tests/test_llm.py b/python/tests/test_llm.py index 49aa8f070..30c390164 100644 --- a/python/tests/test_llm.py +++ b/python/tests/test_llm.py @@ -11,6 +11,7 @@ LLMAttributes, LLMHandle, LLMRequest, + LLMRequestInterceptOutcome, ScopeEvent, ScopeType, guardrails, @@ -285,20 +286,25 @@ def func(request): class TestLLMIntercepts: def test_request_intercept(self): # Request intercepts now operate on LLMRequest - intercepts.register_llm_request("py_llm_req", 1, False, lambda name, request, annotated: (request, annotated)) + intercepts.register_llm_request( + "py_llm_req", + 1, + False, + lambda name, request, annotated: LLMRequestInterceptOutcome(request, annotated), + ) assert intercepts.deregister_llm_request("py_llm_req") def test_request_intercepts_direct(self): def intercept_fn(name, request, annotated): content = request.content content["direct"] = True - return LLMRequest(request.headers, content), annotated + return LLMRequestInterceptOutcome(LLMRequest(request.headers, content), annotated) intercepts.register_llm_request("py_llm_req_direct", 1, False, intercept_fn) transformed = llm.request_intercepts("direct_llm", make_request()) intercepts.deregister_llm_request("py_llm_req_direct") - assert transformed.content["direct"] is True + assert transformed.request.content["direct"] is True def 
test_request_intercept_raises_on_exception(self): intercepts.register_llm_request( @@ -316,7 +322,7 @@ def test_request_intercept_raises_on_exception(self): def test_request_intercept_raises_on_invalid_return(self): intercepts.register_llm_request("py_llm_req_bad_return", 1, False, lambda name, request, annotated: object()) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] try: - with pytest.raises(RuntimeError, match="result\\[0\\] extraction failed"): + with pytest.raises(RuntimeError, match="must return LLMRequestInterceptOutcome"): llm.request_intercepts("bad_return_llm", make_request()) finally: intercepts.deregister_llm_request("py_llm_req_bad_return") @@ -358,7 +364,7 @@ def intercept_fn(name, request, annotated): # Request intercepts now operate on LLMRequest content = request.content content["intercepted"] = True - return LLMRequest(request.headers, content), annotated + return LLMRequestInterceptOutcome(LLMRequest(request.headers, content), annotated) intercepts.register_llm_request("py_llm_req_mod", 1, False, intercept_fn) diff --git a/python/tests/test_scope_local.py b/python/tests/test_scope_local.py index 0b4fcfe8d..7746ed48e 100644 --- a/python/tests/test_scope_local.py +++ b/python/tests/test_scope_local.py @@ -16,6 +16,7 @@ from nemo_relay import ( JsonObject, LLMRequest, + LLMRequestInterceptOutcome, MarkEvent, ScopeEvent, ScopeType, @@ -648,7 +649,10 @@ async def test_scope_local_llm_request_intercept_modifies_request(self): request = LLMRequest({}, {"messages": [], "model": "scope-local"}) def intercept(name, req, annotated): - return LLMRequest(req.headers, {**req.content, "intercepted": True}), annotated + return LLMRequestInterceptOutcome( + LLMRequest(req.headers, {**req.content, "intercepted": True}), + annotated, + ) with scope.scope("sl_llm_request_scope", ScopeType.Agent) as handle: scope_local.register_llm_request(handle, "sl_llm_request", 1, False, intercept) From dfb40cf78db4c2f845eb747ddd2fb549cc8ff750 Mon Sep 17 00:00:00 
2001 From: Bryan Bednarski Date: Tue, 30 Jun 2026 17:25:02 -0600 Subject: [PATCH 4/9] fix(llm): enforce codec request authority Signed-off-by: Bryan Bednarski --- crates/adaptive/src/acg_component.rs | 13 ++ .../adaptive/src/adaptive_hints_intercept.rs | 32 ++-- .../unit/adaptive_hints_intercept_tests.rs | 12 +- crates/core/src/api/llm.rs | 4 +- crates/core/src/api/runtime/callbacks.rs | 4 + crates/core/src/api/runtime/state.rs | 16 ++ crates/core/src/api/shared.rs | 9 +- .../tests/fixtures/worker_plugin/src/main.rs | 20 +- .../core/tests/integration/pipeline_tests.rs | 181 +++++++++++++++++- crates/ffi/nemo_relay.h | 4 +- crates/ffi/src/callable.rs | 4 +- crates/node/src/callable.rs | 2 + crates/node/tests/codec_tests.mjs | 36 ++++ crates/python/src/py_callable.rs | 2 + crates/types/src/api/llm.rs | 9 +- crates/wasm/src/callable.rs | 2 + .../provider-codecs.mdx | 9 +- .../llm-request-intercept-outcomes.mdx | 44 ++++- go/nemo_relay/callbacks.go | 5 +- go/nemo_relay/top_level_coverage_test.go | 24 ++- python/plugin/README.md | 21 ++ .../plugin/src/nemo_relay_plugin/__init__.py | 6 + python/plugin/src/nemo_relay_plugin/_api.py | 83 ++++++-- python/tests/plugin/test_worker_sdk.py | 43 +++-- python/tests/test_codecs.py | 23 +++ 25 files changed, 540 insertions(+), 68 deletions(-) diff --git a/crates/adaptive/src/acg_component.rs b/crates/adaptive/src/acg_component.rs index 9d7026252..3d344866a 100644 --- a/crates/adaptive/src/acg_component.rs +++ b/crates/adaptive/src/acg_component.rs @@ -595,9 +595,22 @@ pub(crate) fn create_acg_llm_request_intercept( plugin: Arc, ) -> LlmRequestInterceptFn { Arc::new(move |_name: &str, request: LlmRequest, annotated| { + let input_content = request.content.clone(); let translated = translate_request(&request, &agent_id, &provider, plugin.as_ref(), &hot_cache) .unwrap_or(request); + if annotated.is_some() && translated.content != input_content { + let translated_annotated = build_semantic_request_view(&translated) + .map_err(|error| 
nemo_relay::error::FlowError::Internal(error.to_string()))? + .annotated_request; + return Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + LlmRequest { + headers: translated.headers, + content: input_content, + }, + Some(translated_annotated), + )); + } Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( translated, annotated, )) diff --git a/crates/adaptive/src/adaptive_hints_intercept.rs b/crates/adaptive/src/adaptive_hints_intercept.rs index b4d9fe2b3..c9f245505 100644 --- a/crates/adaptive/src/adaptive_hints_intercept.rs +++ b/crates/adaptive/src/adaptive_hints_intercept.rs @@ -86,22 +86,24 @@ fn manual_agent_hints(manual: u32, effective_agent_id: &str, scope_depth: usize) } } -fn inject_agent_hints(request: &mut LlmRequest, hints: &AgentHints) { +fn inject_agent_hints( + request: &mut LlmRequest, + annotated: &mut Option, + hints: &AgentHints, +) { let Ok(serialized_hints) = serde_json::to_value(hints) else { return; }; - if let Some(body) = request.content.as_object_mut() { - if !body.contains_key("nvext") { - body.insert( - "nvext".to_string(), - serde_json::Value::Object(serde_json::Map::new()), - ); - } - if let Some(nvext) = body - .get_mut("nvext") - .and_then(|value| value.as_object_mut()) - { + let body = annotated + .as_mut() + .map(|annotated| &mut annotated.extra) + .or_else(|| request.content.as_object_mut()); + if let Some(body) = body { + let nvext = body + .entry("nvext".to_string()) + .or_insert_with(|| serde_json::Value::Object(serde_json::Map::new())); + if let Some(nvext) = nvext.as_object_mut() { nvext.insert("agent_hints".to_string(), serialized_hints.clone()); } } @@ -172,7 +174,9 @@ impl AdaptiveHintsIntercept { pub fn into_request_fn(self) -> LlmRequestInterceptFn { let this = Arc::new(self); Arc::new( - move |_name: &str, mut request: LlmRequest, annotated: Option| { + move |_name: &str, + mut request: LlmRequest, + mut annotated: Option| { let scope_path = extract_scope_path(); let manual_ls = 
read_manual_latency_sensitivity(); let scope_depth = scope_path.len(); @@ -189,7 +193,7 @@ impl AdaptiveHintsIntercept { ); if let Some(hints) = final_hints { - inject_agent_hints(&mut request, &hints); + inject_agent_hints(&mut request, &mut annotated, &hints); } Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( diff --git a/crates/adaptive/tests/unit/adaptive_hints_intercept_tests.rs b/crates/adaptive/tests/unit/adaptive_hints_intercept_tests.rs index d68b3b016..9d260fac5 100644 --- a/crates/adaptive/tests/unit/adaptive_hints_intercept_tests.rs +++ b/crates/adaptive/tests/unit/adaptive_hints_intercept_tests.rs @@ -206,7 +206,9 @@ fn test_adaptive_hints_intercept_injects_prediction_hints_and_manual_override() let request = outcome.request; let returned_annotated = outcome.annotated_request; - let body_hints = &request.content["nvext"]["agent_hints"]; + assert_eq!(request.content, serde_json::json!({})); + let returned_annotated = returned_annotated.expect("annotation should be preserved"); + let body_hints = &returned_annotated.extra["nvext"]["agent_hints"]; assert_eq!(body_hints["osl"], serde_json::json!(150)); assert_eq!(body_hints["iat"], serde_json::json!(200)); assert_eq!(body_hints["latency_sensitivity"], serde_json::json!(5.0)); @@ -217,7 +219,11 @@ fn test_adaptive_hints_intercept_injects_prediction_hints_and_manual_override() request.headers.get(AGENT_HINTS_HEADER_KEY).unwrap(), body_hints ); - assert_eq!(returned_annotated, Some(annotated)); + let mut expected_annotated = annotated; + expected_annotated + .extra + .insert("nvext".into(), returned_annotated.extra["nvext"].clone()); + assert_eq!(returned_annotated, expected_annotated); pop_scope( nemo_relay::api::scope::PopScopeParams::builder() @@ -349,7 +355,7 @@ fn test_apply_manual_latency_override_and_inject_agent_hints_cover_manual_paths( headers: serde_json::Map::new(), content: serde_json::json!("scalar"), }; - inject_agent_hints(&mut non_object_request, &manual_only); + 
inject_agent_hints(&mut non_object_request, &mut None, &manual_only); assert_eq!( non_object_request.headers.get(AGENT_HINTS_HEADER_KEY), Some(&serde_json::to_value(&manual_only).unwrap()) diff --git a/crates/core/src/api/llm.rs b/crates/core/src/api/llm.rs index 822a35572..6ac9a895a 100644 --- a/crates/core/src/api/llm.rs +++ b/crates/core/src/api/llm.rs @@ -914,7 +914,9 @@ pub fn llm_request_intercepts( .map_err(|error| FlowError::Internal(error.to_string()))?; state.llm_request_intercept_entries(&scope_locals) }; - NemoRelayContextState::llm_request_intercepts_snapshot_chain(name, request, None, &entries) + NemoRelayContextState::llm_request_intercepts_snapshot_chain( + name, request, None, &entries, false, + ) } /// Run only the LLM conditional-execution guardrail chain. diff --git a/crates/core/src/api/runtime/callbacks.rs b/crates/core/src/api/runtime/callbacks.rs index 7e6c0848c..58e5f8512 100644 --- a/crates/core/src/api/runtime/callbacks.rs +++ b/crates/core/src/api/runtime/callbacks.rs @@ -164,6 +164,10 @@ pub type LlmConditionalFn = Arc Result> + /// /// # Returns /// A [`Result`] containing the canonical request-intercept outcome. +/// Without a request codec, the returned request is authoritative. With a +/// request codec, its headers remain writable while its content must remain +/// unchanged; provider-body edits must be returned through the required +/// annotation. /// /// # Errors /// The callback can return any [`FlowError`](crate::error::FlowError) to abort diff --git a/crates/core/src/api/runtime/state.rs b/crates/core/src/api/runtime/state.rs index 00b9010ab..eb6a57bf9 100644 --- a/crates/core/src/api/runtime/state.rs +++ b/crates/core/src/api/runtime/state.rs @@ -1027,6 +1027,8 @@ impl NemoRelayContextState { /// - `annotated`: Optional normalized request annotation to carry through /// the chain. /// - `entries`: Intercept snapshots to evaluate. 
+ /// - `codec_active`: Whether request content is owned by the normalized + /// annotation and must remain unchanged by callbacks. /// /// # Returns /// A [`Result`] containing the final request and annotation pair. @@ -1042,12 +1044,26 @@ impl NemoRelayContextState { request: LlmRequest, annotated: Option, entries: &[Intercept], + codec_active: bool, ) -> crate::error::Result { let mut request_value = request; let mut annotated_value = annotated; let mut pending_marks = Vec::new(); for entry in entries { + let input_content = request_value.content.clone(); let outcome = (entry.payload.callable)(name, request_value, annotated_value)?; + if codec_active && outcome.request.content != input_content { + return Err(crate::error::FlowError::InvalidArgument(format!( + "LLM request intercept '{}' changed request.content while a request codec is active; modify annotated_request instead", + entry.name + ))); + } + if codec_active && outcome.annotated_request.is_none() { + return Err(crate::error::FlowError::InvalidArgument(format!( + "LLM request intercept '{}' omitted annotated_request while a request codec is active", + entry.name + ))); + } request_value = outcome.request; annotated_value = outcome.annotated_request; pending_marks.extend(outcome.pending_marks); diff --git a/crates/core/src/api/shared.rs b/crates/core/src/api/shared.rs index 53f14a599..28965542c 100644 --- a/crates/core/src/api/shared.rs +++ b/crates/core/src/api/shared.rs @@ -79,7 +79,6 @@ pub(crate) fn run_request_intercepts_with_codec( Option>, Vec, )> { - let original = request.clone(); let annotated = match &codec { Some(codec) => Some(codec.decode(&request)?), None => None, @@ -100,13 +99,17 @@ pub(crate) fn run_request_intercepts_with_codec( let outcome = crate::api::runtime::NemoRelayContextState::llm_request_intercepts_snapshot_chain( - name, request, annotated, &entries, + name, + request, + annotated, + &entries, + codec.is_some(), )?; let pending_marks = outcome.pending_marks; match (codec, 
outcome.annotated_request) { (Some(codec), Some(annotated)) => { - let mut encoded = codec.encode(&annotated, &original)?; + let mut encoded = codec.encode(&annotated, &outcome.request)?; encoded.headers = outcome.request.headers; Ok((encoded, Some(Arc::new(annotated)), pending_marks)) } diff --git a/crates/core/tests/fixtures/worker_plugin/src/main.rs b/crates/core/tests/fixtures/worker_plugin/src/main.rs index 28de01f8a..1e296dd28 100644 --- a/crates/core/tests/fixtures/worker_plugin/src/main.rs +++ b/crates/core/tests/fixtures/worker_plugin/src/main.rs @@ -182,14 +182,20 @@ impl WorkerPlugin for FixtureWorkerPlugin { "fixture LLM request error requested".into(), )); } - let annotated = annotated.map(|mut annotated| { - annotated - .extra - .insert("worker_plugin_annotated_request".into(), json!(true)); - annotated - }); + let (request, annotated) = match annotated { + Some(mut annotated) => { + annotated + .extra + .insert("worker_plugin_annotated_request".into(), json!(true)); + (request, Some(annotated)) + } + None => ( + mark_llm_request(request, "worker_plugin_llm_request_intercept"), + None, + ), + }; Ok(nemo_relay_worker::LlmRequestInterceptOutcome::new( - mark_llm_request(request, "worker_plugin_llm_request_intercept"), + request, annotated, )) }, diff --git a/crates/core/tests/integration/pipeline_tests.rs b/crates/core/tests/integration/pipeline_tests.rs index 4431a6102..cf411e509 100644 --- a/crates/core/tests/integration/pipeline_tests.rs +++ b/crates/core/tests/integration/pipeline_tests.rs @@ -14,11 +14,11 @@ use futures::StreamExt; use serde_json::json; use tokio_stream::Stream; -use nemo_relay::api::event::{Event, ScopeCategory}; -use nemo_relay::api::llm::LlmRequest; +use nemo_relay::api::event::{Event, PendingMarkSpec, ScopeCategory}; use nemo_relay::api::llm::{ LlmCallExecuteParams, LlmStreamCallExecuteParams, llm_call_execute, llm_stream_call_execute, }; +use nemo_relay::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; use 
nemo_relay::api::registry::{ deregister_llm_request_intercept, deregister_llm_sanitize_request_guardrail, deregister_llm_sanitize_response_guardrail, register_llm_request_intercept, @@ -361,9 +361,10 @@ async fn test_encode_runs_after_intercepts() { "modify_model", 1, false, - Arc::new(|_name, req, annotated| { + Arc::new(|_name, mut req, annotated| { let mut ann = annotated.unwrap(); ann.model = Some("modified".into()); + req.headers.insert("x-codec-route".into(), json!("blue")); Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( req, Some(ann), @@ -402,11 +403,185 @@ async fn test_encode_runs_after_intercepts() { let captured_req = exec_request.lock().unwrap(); let req = captured_req.as_ref().unwrap(); assert_eq!(req.content["model"], json!("modified")); + assert_eq!(req.headers["x-codec-route"], json!("blue")); // Cleanup deregister_llm_request_intercept("modify_model").unwrap(); } +#[tokio::test] +async fn test_codec_rejects_raw_content_mutation_before_lifecycle() { + let _lock = TEST_MUTEX.lock().unwrap(); + reset_global(); + setup_isolated_thread(); + + let events = Arc::new(Mutex::new(Vec::new())); + let captured_events = events.clone(); + register_subscriber( + "codec_raw_mutation_subscriber", + Arc::new(move |event| captured_events.lock().unwrap().push(event.clone())), + ) + .unwrap(); + + let later_intercept_called = Arc::new(Mutex::new(false)); + register_llm_request_intercept( + "codec_raw_mutation", + 1, + false, + Arc::new(|_name, mut request, annotated| { + request.content["model"] = json!("raw-model-edit"); + Ok(LlmRequestInterceptOutcome::new(request, annotated) + .with_pending_mark(PendingMarkSpec::builder().name("must.not.emit").build())) + }), + ) + .unwrap(); + let later_called = later_intercept_called.clone(); + register_llm_request_intercept( + "codec_after_raw_mutation", + 2, + false, + Arc::new(move |_name, request, annotated| { + *later_called.lock().unwrap() = true; + Ok(LlmRequestInterceptOutcome::new(request, annotated)) + }), + 
) + .unwrap(); + + let provider_called = Arc::new(Mutex::new(false)); + let called = provider_called.clone(); + let provider: LlmExecutionNextFn = Arc::new(move |_request| { + *called.lock().unwrap() = true; + Box::pin(async { Ok(json!({"response": "unexpected"})) }) + }); + let (codec, _, encode_log) = make_tracking_codec("codec_raw_read_only"); + let error = llm_call_execute( + LlmCallExecuteParams::builder() + .name("test_llm") + .request(make_llm_request(json!({ + "model": "original", + "messages": [{"role": "user", "content": "hello"}] + }))) + .func(provider) + .codec(codec) + .build(), + ) + .await + .expect_err("raw content mutation must fail on the codec path"); + + assert!(error.to_string().contains("request.content")); + assert!(!*later_intercept_called.lock().unwrap()); + assert!(!*provider_called.lock().unwrap()); + assert!(encode_log.lock().unwrap().is_empty()); + assert!(captured_events_snapshot(&events).is_empty()); + + deregister_llm_request_intercept("codec_raw_mutation").unwrap(); + deregister_llm_request_intercept("codec_after_raw_mutation").unwrap(); + deregister_subscriber("codec_raw_mutation_subscriber").unwrap(); +} + +#[tokio::test] +async fn test_codec_rejects_missing_annotation_before_lifecycle() { + let _lock = TEST_MUTEX.lock().unwrap(); + reset_global(); + setup_isolated_thread(); + + register_llm_request_intercept( + "codec_missing_annotation", + 1, + false, + Arc::new(|_name, request, _annotated| Ok(LlmRequestInterceptOutcome::new(request, None))), + ) + .unwrap(); + + let provider_called = Arc::new(Mutex::new(false)); + let called = provider_called.clone(); + let provider: LlmExecutionNextFn = Arc::new(move |_request| { + *called.lock().unwrap() = true; + Box::pin(async { Ok(json!({"response": "unexpected"})) }) + }); + let (codec, _, encode_log) = make_tracking_codec("codec_annotation_required"); + let error = llm_call_execute( + LlmCallExecuteParams::builder() + .name("test_llm") + .request(make_llm_request(json!({ + "messages": 
[{"role": "user", "content": "hello"}] + }))) + .func(provider) + .codec(codec) + .build(), + ) + .await + .expect_err("missing annotation must fail on the codec path"); + + assert!(error.to_string().contains("omitted annotated_request")); + assert!(!*provider_called.lock().unwrap()); + assert!(encode_log.lock().unwrap().is_empty()); + + deregister_llm_request_intercept("codec_missing_annotation").unwrap(); +} + +#[tokio::test] +async fn test_stream_codec_rejects_raw_content_mutation_before_lifecycle() { + let _lock = TEST_MUTEX.lock().unwrap(); + reset_global(); + setup_isolated_thread(); + + let events = Arc::new(Mutex::new(Vec::new())); + let captured_events = events.clone(); + register_subscriber( + "stream_codec_raw_mutation_subscriber", + Arc::new(move |event| captured_events.lock().unwrap().push(event.clone())), + ) + .unwrap(); + register_llm_request_intercept( + "stream_codec_raw_mutation", + 1, + false, + Arc::new(|_name, mut request, annotated| { + request.content["model"] = json!("raw-stream-edit"); + Ok(LlmRequestInterceptOutcome::new(request, annotated)) + }), + ) + .unwrap(); + + let provider_called = Arc::new(Mutex::new(false)); + let called = provider_called.clone(); + let provider: LlmStreamExecutionNextFn = Arc::new(move |_request| { + *called.lock().unwrap() = true; + Box::pin(async { + let stream: Pin> + Send>> = + Box::pin(futures::stream::empty()); + Ok(stream) + }) + }); + let (codec, _, _) = make_tracking_codec("stream_codec_raw_read_only"); + let result = llm_stream_call_execute( + LlmStreamCallExecuteParams::builder() + .name("test_stream") + .request(make_llm_request(json!({ + "model": "original", + "messages": [{"role": "user", "content": "hello"}] + }))) + .func(provider) + .collector(Box::new(|_| Ok(()))) + .finalizer(Box::new(|| json!({"done": true}))) + .codec(codec) + .build(), + ) + .await; + let error = match result { + Ok(_) => panic!("raw content mutation must fail on the streaming codec path"), + Err(error) => error, + }; + + 
assert!(error.to_string().contains("request.content")); + assert!(!*provider_called.lock().unwrap()); + assert!(captured_events_snapshot(&events).is_empty()); + + deregister_llm_request_intercept("stream_codec_raw_mutation").unwrap(); + deregister_subscriber("stream_codec_raw_mutation_subscriber").unwrap(); +} + // =========================================================================== // Intercept receives both LlmRequest and AnnotatedLlmRequest // =========================================================================== diff --git a/crates/ffi/nemo_relay.h b/crates/ffi/nemo_relay.h index b8bbe8269..72a18bc3e 100644 --- a/crates/ffi/nemo_relay.h +++ b/crates/ffi/nemo_relay.h @@ -255,7 +255,9 @@ typedef char *(*NemoRelayLlmConditionalCb)(void *user_data, const struct FfiLLMR * signature. Receives the intercept name, the opaque `FfiLLMRequest`, and * optionally the annotated request as a JSON C string (null if no Codec * resolved). Writes one owned canonical outcome JSON string to - * `out_outcome_json`. Returns `NemoRelayStatus`. + * `out_outcome_json`. With a Codec, the outcome must preserve request content + * and return the annotation; only request headers and annotation fields are + * writable. Returns `NemoRelayStatus`. */ typedef NemoRelayStatus (*NemoRelayLlmRequestInterceptCb)(void *user_data, const char *name, diff --git a/crates/ffi/src/callable.rs b/crates/ffi/src/callable.rs index 3e1537311..f079697ea 100644 --- a/crates/ffi/src/callable.rs +++ b/crates/ffi/src/callable.rs @@ -172,7 +172,9 @@ pub type NemoRelayCodecEncodeFn = Option< /// signature. Receives the intercept name, the opaque `FfiLLMRequest`, and /// optionally the annotated request as a JSON C string (null if no Codec /// resolved). Writes one owned canonical outcome JSON string to -/// `out_outcome_json`. Returns `NemoRelayStatus`. +/// `out_outcome_json`. 
With a Codec, the outcome must preserve request content +/// and return the annotation; only request headers and annotation fields are +/// writable. Returns `NemoRelayStatus`. pub type NemoRelayLlmRequestInterceptCb = unsafe extern "C" fn( user_data: *mut libc::c_void, name: *const c_char, diff --git a/crates/node/src/callable.rs b/crates/node/src/callable.rs index ef5e78b64..aba114523 100644 --- a/crates/node/src/callable.rs +++ b/crates/node/src/callable.rs @@ -210,6 +210,8 @@ pub fn wrap_js_tool_exec_fn( /// The JS callback receives a single JSON object /// `{ name: string, request: LlmRequest, annotated: AnnotatedLlmRequest | null }` /// and must return `{ request, annotated?, pendingMarks? }`. +/// When `annotated` is non-null, request content is read-only and provider-body +/// edits must be made through the returned annotation; headers remain writable. pub fn wrap_js_llm_request_intercept_fn( func: ThreadsafeFunction, ) -> LlmRequestInterceptFn { diff --git a/crates/node/tests/codec_tests.mjs b/crates/node/tests/codec_tests.mjs index c6bd40bc2..400ad62d6 100644 --- a/crates/node/tests/codec_tests.mjs +++ b/crates/node/tests/codec_tests.mjs @@ -153,6 +153,42 @@ describe('Codec pipeline integration', () => { } }); + it('rejects raw request content edits before provider execution', async () => { + let providerCalled = false; + registerLlmRequestIntercept('raw-content-test', 10, false, ({ request, annotated }) => ({ + request: { + ...request, + content: { + ...request.content, + model: 'raw-model-edit', + }, + }, + annotated, + })); + + const handle = pushScope('raw-content-scope', ScopeType.Agent); + try { + await assert.rejects( + () => + execWithCodec( + 'test-llm', + makeRequest(), + async () => { + providerCalled = true; + return { choices: [] }; + }, + mockDecode, + mockEncode, + ), + /request\.content/, + ); + assert.equal(providerCalled, false); + } finally { + popScope(handle); + deregisterLlmRequestIntercept('raw-content-test'); + } + }); + 
it('different codec functions produce different results', async () => { let usedModel = null; diff --git a/crates/python/src/py_callable.rs b/crates/python/src/py_callable.rs index b47d39834..7b8629be5 100644 --- a/crates/python/src/py_callable.rs +++ b/crates/python/src/py_callable.rs @@ -646,6 +646,8 @@ pub fn wrap_py_llm_conditional_fn(py_fn: Py) -> LlmConditionalFn { /// /// The Python function receives ``(name: str, request: LlmRequest, annotated: AnnotatedLLMRequest | None)`` /// and must return ``LLMRequestInterceptOutcome``. +/// When ``annotated`` is present, request content is read-only and provider-body +/// edits must be made through the returned annotation; headers remain writable. pub fn wrap_py_llm_request_intercept_fn(py_fn: Py) -> LlmRequestInterceptFn { Arc::new( move |name: &str, diff --git a/crates/types/src/api/llm.rs b/crates/types/src/api/llm.rs index 6cc8c6083..e7d34680d 100644 --- a/crates/types/src/api/llm.rs +++ b/crates/types/src/api/llm.rs @@ -33,9 +33,16 @@ pub struct LlmRequest { /// Result of an LLM request intercept that can schedule lifecycle marks. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct LlmRequestInterceptOutcome { - /// Rewritten provider request. + /// Rewritten provider request when no request codec is active. + /// + /// With a request codec, callbacks may rewrite `headers`, but `content` + /// is read-only and provider-body changes must be made through + /// [`Self::annotated_request`]. pub request: LlmRequest, /// Optional normalized request annotation to carry forward. + /// + /// This is required and authoritative for provider content when a request + /// codec is active. It remains optional when no request codec is active. #[serde(default)] pub annotated_request: Option, /// Ordered marks to emit after Relay creates and starts the LLM scope. 
diff --git a/crates/wasm/src/callable.rs b/crates/wasm/src/callable.rs index 1279421c4..9ecf722f2 100644 --- a/crates/wasm/src/callable.rs +++ b/crates/wasm/src/callable.rs @@ -264,6 +264,8 @@ pub fn wrap_js_tool_exec_fn( /// /// Supports both `(name, request, annotated) => { request, annotated }` and /// `({ name, request, annotated }) => { request, annotated }`. +/// When `annotated` is non-null, request content is read-only and provider-body +/// edits must be made through the returned annotation; headers remain writable. #[cfg(not(target_arch = "wasm32"))] pub fn wrap_js_llm_request_intercept_fn(_func: Function) -> LlmRequestInterceptFn { Arc::new( diff --git a/docs/integrate-into-frameworks/provider-codecs.mdx b/docs/integrate-into-frameworks/provider-codecs.mdx index 03815a6a6..90df6c7f6 100644 --- a/docs/integrate-into-frameworks/provider-codecs.mdx +++ b/docs/integrate-into-frameworks/provider-codecs.mdx @@ -41,10 +41,17 @@ When a managed LLM call has a request codec: 1. NeMo Relay calls `decode` before LLM request intercepts run. 2. Request intercepts receive both the raw request and the annotated request. -3. Intercepts may edit the raw request, the annotated request, or both. +3. Intercepts edit provider-body fields through the annotated request and may + edit transport headers through the raw request. Raw `request.content` is + read-only while the codec is active. 4. NeMo Relay calls `encode` to merge the annotated request back into the original raw request. 5. Execution intercepts and the provider callback receive the encoded provider request. +If a codec-aware intercept changes raw `request.content` or omits the returned +annotation, Relay rejects the outcome before creating the LLM lifecycle. When +no request codec is active, the raw request remains fully writable and is the +provider-visible source of truth. 
+ When a managed LLM call has a response codec, NeMo Relay decodes the raw provider response for observability and attaches the result to the emitted LLM end event. The response codec does not rewrite the value returned to the application. Use [Provider Response Codecs](/integrate-into-frameworks/provider-response-codecs) for response-only behavior and custom response codec examples. Codec implementations should preserve fields they do not understand. Treat `encode` as a merge operation over the original provider payload, not as a full replacement. diff --git a/docs/reference/llm-request-intercept-outcomes.mdx b/docs/reference/llm-request-intercept-outcomes.mdx index 9c2de1ab6..bc4a91c2b 100644 --- a/docs/reference/llm-request-intercept-outcomes.mdx +++ b/docs/reference/llm-request-intercept-outcomes.mdx @@ -21,12 +21,49 @@ includes all three fields. A pending mark contains only `name`, optional `category` and `category_profile`, and optional `data` and `metadata`. Relay owns event UUIDs, parent UUIDs, and timestamps. +## Request Authority + +The provider-body source of truth depends only on whether a request codec is +active: + +| Request codec | Provider body source | Header source | +| --- | --- | --- | +| No codec | `outcome.request.content` | `outcome.request.headers` | +| Active codec | `outcome.annotated_request` | `outcome.request.headers` | + +With an active codec, `request.content` is read-only context. Every intercept +must return an annotation and make provider-body changes through that +annotation, including its flattened `extra` fields for provider-specific data. +Relay rejects a changed raw body or missing annotation at the offending +intercept before invoking later middleware or creating an LLM lifecycle. 
+ +```mermaid +flowchart TD + INPUT["Original LlmRequest"] --> CODEC{"Request codec active?"} + + CODEC -->|No| RAWCHAIN["Run intercept chain"] + RAWCHAIN --> RAWPROVIDER["Provider receives outcome.request"] + + CODEC -->|Yes| DECODE["Decode content into annotated_request"] + DECODE --> INTERCEPT["Invoke next intercept"] + INTERCEPT --> CHECKANN{"Annotation returned?"} + CHECKANN -->|No| FAIL["Fail before lifecycle"] + CHECKANN -->|Yes| CHECKRAW{"request.content unchanged?"} + CHECKRAW -->|No| FAIL + CHECKRAW -->|Yes| MORE{"More intercepts?"} + MORE -->|Yes| INTERCEPT + MORE -->|No| ENCODE["Encode final annotated_request"] + ENCODE --> HEADERS["Apply final request.headers"] + HEADERS --> PROVIDER["Provider receives one resolved LlmRequest"] +``` + Python callbacks return `LLMRequestInterceptOutcome`; Rust callbacks return `LlmRequestInterceptOutcome`; Go callbacks return `LLMRequestInterceptOutcome`; and Node.js and WebAssembly callbacks return `{ request, annotated?, pendingMarks? }`. Public C callbacks write one owned canonical outcome JSON string. Native ABI v1 uses one host-owned outcome JSON -string, and `grpc-v1` uses a `JsonEnvelope` whose schema is +string. Rust and Python `grpc-v1` worker SDKs return their canonical outcome +type in a `JsonEnvelope` whose schema is `nemo.relay.LlmRequestInterceptOutcome@1`. The standalone request-intercept helper returns the complete outcome but does @@ -35,8 +72,9 @@ not emit its pending marks because it does not own an LLM lifecycle. ## Managed Lifecycle Managed execution runs all effective global and scope-local intercepts before -creating the LLM handle. Each rewritten request and annotation feeds the next -intercept, while pending marks append in middleware order. A breaking +creating the LLM handle. Each accepted request/annotation pair feeds the next +intercept under the authority rules above, while pending marks append in +middleware order. A breaking intercept's marks are retained. 
If any intercept fails or its boundary result is malformed, Relay discards all accumulated marks and creates no LLM lifecycle. diff --git a/go/nemo_relay/callbacks.go b/go/nemo_relay/callbacks.go index 58182417d..856e6fa28 100644 --- a/go/nemo_relay/callbacks.go +++ b/go/nemo_relay/callbacks.go @@ -234,7 +234,10 @@ type LLMRequestInterceptOutcome struct { PendingMarks []PendingMarkSpec `json:"pending_marks"` } -// LLMRequestInterceptFunc is a callback for LLM request intercepts. +// LLMRequestInterceptFunc is a callback for LLM request intercepts. When +// annotatedJSON is non-nil, request.Content is read-only, request.Headers may +// be changed, and the returned annotation is authoritative for provider body +// content. Without an annotation, the full request is writable. type LLMRequestInterceptFunc func( name string, request LLMRequestDTO, diff --git a/go/nemo_relay/top_level_coverage_test.go b/go/nemo_relay/top_level_coverage_test.go index 48eba4fd6..bfa5a6043 100644 --- a/go/nemo_relay/top_level_coverage_test.go +++ b/go/nemo_relay/top_level_coverage_test.go @@ -417,13 +417,35 @@ func assertLlmCodecInterceptCoverage(t *testing.T) { t.Fatalf("expected codec-backed intercept success, got %v", err) } + if err := RegisterLlmRequestIntercept("coverage_llm_codec_raw_content", 1, false, func(name string, request LLMRequestDTO, annotatedJSON json.RawMessage) (LLMRequestInterceptOutcome, error) { + request.Content = json.RawMessage(`{"model":"raw-model-edit"}`) + return LLMRequestInterceptOutcome{Request: request, AnnotatedRequest: annotatedJSON}, nil + }); err != nil { + t.Fatalf("RegisterLlmRequestIntercept raw content case failed: %v", err) + } + providerCalled := false + _, err := LlmCallExecute("coverage_llm_codec_raw_content", map[string]any{ + "headers": map[string]any{}, + "content": map[string]any{"model": coverageModelName}, + }, func(json.RawMessage) (json.RawMessage, error) { + providerCalled = true + return json.RawMessage(`{"content":"unexpected"}`), nil 
+ }, WithLLMCodec(requestCodec)) + assertErrorContains(t, err, "request.content", "codec-backed raw content mutation") + if providerCalled { + t.Fatal("provider should not run after a codec-backed raw content mutation") + } + if err := DeregisterLlmRequestIntercept("coverage_llm_codec_raw_content"); err != nil { + t.Fatalf("failed to deregister raw content intercept: %v", err) + } + if err := RegisterLlmRequestIntercept("coverage_llm_codec_error", 1, false, func(name string, request LLMRequestDTO, annotatedJSON json.RawMessage) (LLMRequestInterceptOutcome, error) { return LLMRequestInterceptOutcome{}, errors.New("forced codec-backed intercept failure") }); err != nil { t.Fatalf("RegisterLlmRequestIntercept error case failed: %v", err) } defer DeregisterLlmRequestIntercept("coverage_llm_codec_error") - _, err := LlmCallExecute("coverage_llm_codec_error", map[string]any{ + _, err = LlmCallExecute("coverage_llm_codec_error", map[string]any{ "headers": map[string]any{}, "content": map[string]any{"model": coverageModelName}, }, func(json.RawMessage) (json.RawMessage, error) { diff --git a/python/plugin/README.md b/python/plugin/README.md index d0cf2da15..b913e6372 100644 --- a/python/plugin/README.md +++ b/python/plugin/README.md @@ -36,6 +36,27 @@ Set `load.entrypoint` to `your_module:main` in `relay-plugin.toml`. Relay imports that function and awaits the returned coroutine when it starts the worker process. 
+LLM request intercepts return one canonical outcome: + +```python +from nemo_relay_plugin import LlmRequestInterceptOutcome, PendingMarkSpec + + +def intercept(model_name, request, annotated): + del model_name + headers = {**request.get("headers", {}), "x-policy": "checked"} + return LlmRequestInterceptOutcome( + request={**request, "headers": headers}, + annotated_request=annotated, + pending_marks=[PendingMarkSpec("acme.policy.checked")], + ) +``` + +When `annotated` is present, it is authoritative for provider-body content: +leave raw `request["content"]` unchanged, edit normalized fields or provider +extensions through the annotation, and use `request["headers"]` for transport +headers. + The SDK owns gRPC serving, JSON envelope conversion, callback dispatch, continuations, host runtime calls, and local scope-stack binding. Its private protobuf bindings are generated from the canonical Relay schema while the diff --git a/python/plugin/src/nemo_relay_plugin/__init__.py b/python/plugin/src/nemo_relay_plugin/__init__.py index 2ed6a2267..676a0de3b 100644 --- a/python/plugin/src/nemo_relay_plugin/__init__.py +++ b/python/plugin/src/nemo_relay_plugin/__init__.py @@ -21,6 +21,8 @@ LlmRequest: A Relay LLM request represented as a JSON object. AnnotatedLlmRequest: An annotated Relay LLM request represented as a JSON object. + PendingMarkSpec: A mark Relay emits under the future managed LLM scope. + LlmRequestInterceptOutcome: Canonical LLM request-intercept result. DiagnosticLevel: Severity of a configuration diagnostic. ConfigDiagnostic: Structured configuration warning or error. ScopeType: Semantic category for a Relay execution scope. 
@@ -62,10 +64,12 @@ LlmNext, LlmRequest, LlmRequestCallback, + LlmRequestInterceptOutcome, LlmSanitizeRequestCallback, LlmSanitizeResponseCallback, LlmStreamExecutionCallback, LlmStreamNext, + PendingMarkSpec, PluginContext, PluginRuntime, ScopeType, @@ -91,12 +95,14 @@ "LlmNext", "LlmRequest", "LlmRequestCallback", + "LlmRequestInterceptOutcome", "LlmSanitizeRequestCallback", "LlmSanitizeResponseCallback", "LlmStreamNext", "LlmStreamExecutionCallback", "PluginContext", "PluginRuntime", + "PendingMarkSpec", "ScopeType", "SubscriberCallback", "ToolConditionalCallback", diff --git a/python/plugin/src/nemo_relay_plugin/_api.py b/python/plugin/src/nemo_relay_plugin/_api.py index b0d21c43f..ffe07aba2 100644 --- a/python/plugin/src/nemo_relay_plugin/_api.py +++ b/python/plugin/src/nemo_relay_plugin/_api.py @@ -16,6 +16,8 @@ LlmRequest: A Relay LLM request represented as a JSON object. AnnotatedLlmRequest: An annotated Relay LLM request represented as a JSON object. + PendingMarkSpec: A mark Relay emits under the future managed LLM scope. + LlmRequestInterceptOutcome: Canonical LLM request-intercept result. DiagnosticLevel: Severity of a configuration diagnostic. ConfigDiagnostic: Structured configuration warning or error. ScopeType: Semantic category for a Relay execution scope. 
@@ -63,7 +65,7 @@ import tempfile import tomllib from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Iterator, Mapping -from dataclasses import asdict, dataclass +from dataclasses import asdict, dataclass, field from enum import Enum from importlib import metadata from pathlib import Path @@ -88,8 +90,16 @@ EVENT_SCHEMA = "nemo.relay.Event@1" LLM_REQUEST_SCHEMA = "nemo.relay.LlmRequest@1" ANNOTATED_LLM_REQUEST_SCHEMA = "nemo.relay.AnnotatedLlmRequest@1" +LLM_REQUEST_INTERCEPT_OUTCOME_SCHEMA = "nemo.relay.LlmRequestInterceptOutcome@1" PLUGIN_DIAGNOSTICS_SCHEMA = "nemo.relay.PluginDiagnostics@1" -_OBJECT_SCHEMAS = frozenset({EVENT_SCHEMA, LLM_REQUEST_SCHEMA, ANNOTATED_LLM_REQUEST_SCHEMA}) +_OBJECT_SCHEMAS = frozenset( + { + EVENT_SCHEMA, + LLM_REQUEST_SCHEMA, + ANNOTATED_LLM_REQUEST_SCHEMA, + LLM_REQUEST_INTERCEPT_OUTCOME_SCHEMA, + } +) _UNREGISTERED = object() _SCOPE_CONTEXT: contextvars.ContextVar[_BoundScopeContext | None] = contextvars.ContextVar( "nemo_relay_plugin_scope_context", @@ -171,6 +181,49 @@ def to_json(self) -> dict[str, Any]: return _normalize_diagnostic(asdict(self)) +@dataclass(slots=True) +class PendingMarkSpec: + """Describe a mark Relay emits after starting a managed LLM call.""" + + name: str + category: str | None = None + category_profile: Json | None = None + data: Json | None = None + metadata: Json | None = None + + def to_json(self) -> dict[str, Json]: + """Convert this pending mark to its canonical JSON object.""" + if not isinstance(self.name, str): + raise WorkerSdkError("pending mark name must be a string") + return asdict(self) + + +@dataclass(slots=True) +class LlmRequestInterceptOutcome: + """Canonical result returned by a Python worker LLM request intercept.""" + + request: LlmRequest + annotated_request: AnnotatedLlmRequest | None = None + pending_marks: list[PendingMarkSpec] = field(default_factory=list) + + def to_json(self) -> dict[str, Json]: + """Convert this outcome to the canonical worker-envelope 
payload.""" + if not isinstance(self.request, dict): + raise WorkerSdkError("LLM request intercept outcome request must be a JSON object") + if self.annotated_request is not None and not isinstance(self.annotated_request, dict): + raise WorkerSdkError("LLM request intercept outcome annotated_request must be a JSON object or None") + marks = [] + for mark in self.pending_marks: + if not isinstance(mark, PendingMarkSpec): + raise WorkerSdkError("LLM request intercept outcome pending_marks must contain PendingMarkSpec values") + marks.append(mark.to_json()) + return { + "request": self.request, + "annotated_request": self.annotated_request, + "pending_marks": marks, + } + + def _normalize_diagnostic(value: Mapping[str, Any]) -> dict[str, Any]: try: level = DiagnosticLevel(value.get("level")).value @@ -325,9 +378,7 @@ def register(self, ctx: PluginContext, config: Json) -> None | Awaitable[None]: LlmConditionalCallback: TypeAlias = Callable[[LlmRequest], str | None | Awaitable[str | None]] LlmRequestCallback: TypeAlias = Callable[ [str, LlmRequest, AnnotatedLlmRequest | None], - LlmRequest - | tuple[LlmRequest, AnnotatedLlmRequest | None] - | Awaitable[LlmRequest | tuple[LlmRequest, AnnotatedLlmRequest | None]], + LlmRequestInterceptOutcome | Awaitable[LlmRequestInterceptOutcome], ] LlmExecutionCallback: TypeAlias = Callable[[str, LlmRequest, "LlmNext"], Json | Awaitable[Json]] LlmStreamExecutionCallback: TypeAlias = Callable[ @@ -593,10 +644,11 @@ def register_llm_request_intercept( Args: name: Component-local registration name. callback: Function receiving ``(model_name, request, - annotated_request)``. Return a replacement :data:`LlmRequest` - or ``(request, annotated_request)`` tuple, directly or through - an awaitable. ``annotated_request`` is ``None`` when the host - did not provide one. + annotated_request)``. Return + :class:`LlmRequestInterceptOutcome`, directly or through an + awaitable. ``annotated_request`` is ``None`` when the host did + not provide one. 
When present, it is authoritative for provider + body content and the raw request content must remain unchanged. priority: Execution order. Lower values run first. break_chain: Whether Relay skips later, lower-priority request intercepts after this callback runs. @@ -1358,15 +1410,14 @@ async def _invoke_result(self, request: Any) -> Any: annotated, ) ) - if isinstance(result, tuple): - llm_request, annotated = result - else: - llm_request = result + if not isinstance(result, LlmRequestInterceptOutcome): + raise WorkerSdkError("LLM request intercept must return LlmRequestInterceptOutcome") return pb.InvokeResponse( llm_request=pb.LlmRequestInterceptResult( - request=_json_envelope(LLM_REQUEST_SCHEMA, llm_request), - annotated_request=_optional_json_envelope(annotated, ANNOTATED_LLM_REQUEST_SCHEMA), - has_annotated_request=annotated is not None, + outcome=_json_envelope( + LLM_REQUEST_INTERCEPT_OUTCOME_SCHEMA, + result.to_json(), + ), ) ) if request.surface == pb.LLM_EXECUTION_INTERCEPT: diff --git a/python/tests/plugin/test_worker_sdk.py b/python/tests/plugin/test_worker_sdk.py index b0a7a2b27..8c3ac1d8a 100644 --- a/python/tests/plugin/test_worker_sdk.py +++ b/python/tests/plugin/test_worker_sdk.py @@ -13,7 +13,7 @@ import tempfile from collections.abc import AsyncIterator from pathlib import Path -from typing import Any, Literal +from typing import Any, Literal, cast import pytest @@ -26,6 +26,8 @@ ConfigDiagnostic, DiagnosticLevel, Json, + LlmRequestInterceptOutcome, + PendingMarkSpec, PluginContext, PluginRuntime, ScopeType, @@ -39,6 +41,7 @@ ANNOTATED_LLM_REQUEST_SCHEMA, EVENT_SCHEMA, JSON_SCHEMA, + LLM_REQUEST_INTERCEPT_OUTCOME_SCHEMA, LLM_REQUEST_SCHEMA, WORKER_PROTOCOL, _announced_worker_endpoint, @@ -201,9 +204,13 @@ def llm_block(request: Json) -> str | None: del request return "llm blocked" - def llm_request(name: str, request: Json, annotated: Json | None) -> tuple[Json, Json]: + def llm_request(name: str, request: Json, annotated: Json | None) -> 
LlmRequestInterceptOutcome: del name - return _tag_llm_request(request, "llm_request"), _tag(annotated or {}, "annotated") + return LlmRequestInterceptOutcome( + request=_tag_llm_request(request, "llm_request"), + annotated_request=_tag(annotated or {}, "annotated"), + pending_marks=[PendingMarkSpec("worker.pending", data={"source": "python"})], + ) async def llm_execution(name: str, request: Json, next_call: Any) -> Json: result = await next_call.call(_tag_llm_request(request, f"llm_execute_{name}")) @@ -937,9 +944,19 @@ async def test_unary_invoke_success_paths(service: _WorkerService, host_stub: Re ), AbortContext(), ) - assert _envelope_value(llm_request.llm_request.request)["content"]["llm_request"] - assert _envelope_value(llm_request.llm_request.annotated_request)["tag"] == "annotated" - assert llm_request.llm_request.has_annotated_request + assert llm_request.llm_request.outcome.schema == LLM_REQUEST_INTERCEPT_OUTCOME_SCHEMA + outcome = _envelope_value(llm_request.llm_request.outcome) + assert outcome["request"]["content"]["llm_request"] + assert outcome["annotated_request"]["tag"] == "annotated" + assert outcome["pending_marks"] == [ + { + "name": "worker.pending", + "category": None, + "category_profile": None, + "data": {"source": "python"}, + "metadata": None, + } + ] llm_execution = await _invoke_json_async( service, @@ -1011,8 +1028,8 @@ def register(self, ctx: PluginContext, config: Json) -> None: def invalid_result(name: str, request: Json, annotated: Json | None) -> Any: del name if invalid_part == "request": - return [] - return request, [] + return LlmRequestInterceptOutcome(request=cast(Any, [])) + return LlmRequestInterceptOutcome(request=request, annotated_request=cast(Any, [])) ctx.register_llm_request_intercept("invalid", invalid_result) @@ -1127,9 +1144,9 @@ class RequestOnlyPlugin(WorkerPlugin): def register(self, ctx: PluginContext, config: Json) -> None: del config - def llm_request(name: str, request: Json, annotated: Json | None) -> 
Json: + def llm_request(name: str, request: Json, annotated: Json | None) -> LlmRequestInterceptOutcome: del name, annotated - return _tag_llm_request(request, "request_only") + return LlmRequestInterceptOutcome(request=_tag_llm_request(request, "request_only")) ctx.register_llm_request_intercept("request_only", llm_request) @@ -1143,8 +1160,10 @@ def llm_request(name: str, request: Json, annotated: Json | None) -> Json: ), AbortContext(), ) - assert _envelope_value(response.llm_request.request)["content"]["request_only"] - assert not response.llm_request.has_annotated_request + outcome = _envelope_value(response.llm_request.outcome) + assert outcome["request"]["content"]["request_only"] + assert outcome["annotated_request"] is None + assert outcome["pending_marks"] == [] async def test_stream_invoke_success_and_failures(service: _WorkerService, host_stub: RecordingHostStub): diff --git a/python/tests/test_codecs.py b/python/tests/test_codecs.py index 6caab0ca4..8080825a7 100644 --- a/python/tests/test_codecs.py +++ b/python/tests/test_codecs.py @@ -12,6 +12,8 @@ from typing import cast +import pytest + from nemo_relay import ( AnnotatedLLMRequest, JsonObject, @@ -287,6 +289,27 @@ def func(request): finally: intercepts.deregister_llm_request("test-annot-intercept-pipeline") + async def test_codec_rejects_raw_content_edits_before_provider(self): + """Codec-aware intercepts must edit the annotation, not the raw body.""" + provider_called = False + + def raw_content_intercept(name, request, annotated): + content = {**request.content, "model": "raw-model-edit"} + return LLMRequestInterceptOutcome(LLMRequest(request.headers, content), annotated) + + def provider(request): + nonlocal provider_called + provider_called = True + return {"unexpected": True} + + intercepts.register_llm_request("test-codec-raw-content", 1, False, raw_content_intercept) + try: + with pytest.raises(RuntimeError, match=r"request\.content"): + await llm.execute("pipeline-llm", make_request(), 
provider, codec=SimpleCodec()) + assert not provider_called + finally: + intercepts.deregister_llm_request("test-codec-raw-content") + async def test_codec_parameter(self): """codec parameter passes the specified codec instance directly.""" alternate = AlternateCodec() From 5ec1922d080dc4a0174ce343461b1b7b6ce2cd9a Mon Sep 17 00:00:00 2001 From: Bryan Bednarski Date: Tue, 30 Jun 2026 20:48:48 -0600 Subject: [PATCH 5/9] fix(bindings): camel-case pending mark profiles in JS Signed-off-by: Bryan Bednarski --- crates/node/plugin.d.ts | 2 +- crates/node/src/api/mod.rs | 8 +-- crates/node/src/callable.rs | 54 ++++++++++++++-- crates/node/tests/llm_tests.mjs | 14 +++-- crates/wasm/src/api/mod.rs | 2 +- crates/wasm/src/callable.rs | 63 ++++++++++++++++--- crates/wasm/tests-js/llm_tests.mjs | 20 +++++- crates/wasm/tests/coverage/callable_tests.rs | 34 ++++++++++ crates/wasm/wrappers/esm/plugin.d.ts | 2 +- .../llm-request-intercept-outcomes.mdx | 10 +-- 10 files changed, 179 insertions(+), 30 deletions(-) diff --git a/crates/node/plugin.d.ts b/crates/node/plugin.d.ts index c0368b3bb..e0dc1d4ba 100644 --- a/crates/node/plugin.d.ts +++ b/crates/node/plugin.d.ts @@ -88,7 +88,7 @@ export interface PluginContext { pendingMarks?: Array<{ name: string; category?: string | null; - category_profile?: Json; + categoryProfile?: Json; data?: Json; metadata?: Json; }>; diff --git a/crates/node/src/api/mod.rs b/crates/node/src/api/mod.rs index 0300714b9..59d90dc91 100644 --- a/crates/node/src/api/mod.rs +++ b/crates/node/src/api/mod.rs @@ -2260,7 +2260,7 @@ pub fn register_llm_request_intercept( priority: i32, break_chain: bool, #[napi( - ts_arg_type = "(args: { name: string; request: Json; annotated: Json | null }) => { request: Json; annotated?: Json | null; pendingMarks?: Array<{ name: string; category?: string | null; category_profile?: Json; data?: Json; metadata?: Json }> }" + ts_arg_type = "(args: { name: string; request: Json; annotated: Json | null }) => { request: Json; 
annotated?: Json | null; pendingMarks?: Array<{ name: string; category?: string | null; categoryProfile?: Json; data?: Json; metadata?: Json }> }" )] callable: ThreadsafeFunction, ) -> Result<()> { @@ -2733,7 +2733,7 @@ pub fn scope_register_llm_request_intercept( priority: i32, break_chain: bool, #[napi( - ts_arg_type = "(args: { name: string; request: Json; annotated: Json | null }) => { request: Json; annotated?: Json | null; pendingMarks?: Array<{ name: string; category?: string | null; category_profile?: Json; data?: Json; metadata?: Json }> }" + ts_arg_type = "(args: { name: string; request: Json; annotated: Json | null }) => { request: Json; annotated?: Json | null; pendingMarks?: Array<{ name: string; category?: string | null; categoryProfile?: Json; data?: Json; metadata?: Json }> }" )] callable: ThreadsafeFunction, ) -> Result<()> { @@ -2949,7 +2949,7 @@ pub fn tool_conditional_execution(env: Env, name: String, args: Json) -> Result< /// The `request` should be a JSON object with `headers` and `content` fields matching /// the `LlmRequest` schema. Returns the transformed request as JSON. 
#[napi( - ts_return_type = "Promise<{ request: Json; annotated: Json | null; pendingMarks: Array<{ name: string; category?: string | null; category_profile?: Json; data?: Json; metadata?: Json }> }>" + ts_return_type = "Promise<{ request: Json; annotated: Json | null; pendingMarks: Array<{ name: string; category?: string | null; categoryProfile?: Json; data?: Json; metadata?: Json }> }>" )] pub fn llm_request_intercepts(env: Env, name: String, request: Json) -> Result { let llm_request: LlmRequest = serde_json::from_value(request) @@ -2964,7 +2964,7 @@ pub fn llm_request_intercepts(env: Env, name: String, request: Json) -> Result, + #[serde(default)] + category_profile: Option, + #[serde(default)] + data: Option, + #[serde(default)] + metadata: Option, +} + +impl From for PendingMarkSpec { + fn from(mark: JsPendingMarkSpec) -> Self { + Self { + name: mark.name, + category: mark.category, + category_profile: mark.category_profile, + data: mark.data, + metadata: mark.metadata, + } + } +} + +impl From for JsPendingMarkSpec { + fn from(mark: PendingMarkSpec) -> Self { + Self { + name: mark.name, + category: mark.category, + category_profile: mark.category_profile, + data: mark.data, + metadata: mark.metadata, + } + } +} + +pub(crate) fn js_pending_marks(marks: Vec) -> Vec { + marks.into_iter().map(Into::into).collect() +} + fn recv_json_or_null(rx: std::sync::mpsc::Receiver, error_prefix: &str) -> Json { rx.recv().unwrap_or_else(|e| { record_callback_error(format!("{error_prefix}: {e}")); @@ -255,7 +301,7 @@ pub fn wrap_js_llm_request_intercept_fn( #[serde(default)] annotated: Option, #[serde(default)] - pending_marks: Vec, + pending_marks: Vec, } let outcome: JsOutcome = serde_json::from_value(result).map_err(|e| { FlowError::Internal(format!("invalid JS LLM request intercept outcome: {e}")) @@ -263,7 +309,7 @@ pub fn wrap_js_llm_request_intercept_fn( Ok(LlmRequestInterceptOutcome { request: outcome.request, annotated_request: outcome.annotated, - pending_marks: 
outcome.pending_marks, + pending_marks: outcome.pending_marks.into_iter().map(Into::into).collect(), }) }, ) diff --git a/crates/node/tests/llm_tests.mjs b/crates/node/tests/llm_tests.mjs index 2f69cb1e7..022afb8ad 100644 --- a/crates/node/tests/llm_tests.mjs +++ b/crates/node/tests/llm_tests.mjs @@ -729,7 +729,7 @@ describe('LLM intercepts', () => { null, ); - for (; ;) { + for (;;) { const chunk = await stream.next(); if (chunk === null) { break; @@ -769,7 +769,7 @@ describe('LLM intercepts', () => { null, ); - for (; ;) { + for (;;) { const chunk = await stream.next(); if (chunk === null) { break; @@ -827,7 +827,11 @@ describe('LLM intercepts', () => { request, annotated, pendingMarks: [ - { name: 'request.first', data: { order: 1 } }, + { + name: 'request.first', + categoryProfile: { subtype: 'optimizer.saved_tokens' }, + data: { order: 1 }, + }, { name: 'request.second', metadata: { source: 'node' } }, ], }; @@ -840,14 +844,14 @@ describe('LLM intercepts', () => { { name: 'request.first', category: null, - category_profile: null, + categoryProfile: { subtype: 'optimizer.saved_tokens' }, data: { order: 1 }, metadata: null, }, { name: 'request.second', category: null, - category_profile: null, + categoryProfile: null, data: null, metadata: { source: 'node' }, }, diff --git a/crates/wasm/src/api/mod.rs b/crates/wasm/src/api/mod.rs index e445d07cb..e3affd091 100644 --- a/crates/wasm/src/api/mod.rs +++ b/crates/wasm/src/api/mod.rs @@ -2039,7 +2039,7 @@ pub fn llm_request_intercepts_wasm( let result_json = serde_json::json!({ "request": result.request, "annotated": result.annotated_request, - "pendingMarks": result.pending_marks, + "pendingMarks": callable::js_pending_marks(result.pending_marks), }); Ok(json_to_js(&result_json)) } diff --git a/crates/wasm/src/callable.rs b/crates/wasm/src/callable.rs index 9ecf722f2..c1a8f5afa 100644 --- a/crates/wasm/src/callable.rs +++ b/crates/wasm/src/callable.rs @@ -16,8 +16,7 @@ use std::sync::Arc; use js_sys::Function; 
#[cfg(target_arch = "wasm32")] use send_wrapper::SendWrapper; -#[cfg(target_arch = "wasm32")] -use serde::Serialize; +use serde::{Deserialize, Serialize}; use serde_json::Value as Json; #[cfg(target_arch = "wasm32")] use tokio_stream::StreamExt; @@ -28,9 +27,7 @@ use wasm_bindgen::JsValue; #[cfg(target_arch = "wasm32")] use wasm_bindgen_futures::JsFuture; -use nemo_relay::api::event::Event; -#[cfg(target_arch = "wasm32")] -use nemo_relay::api::event::PendingMarkSpec; +use nemo_relay::api::event::{CategoryProfile, Event, EventCategory, PendingMarkSpec}; use nemo_relay::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; use nemo_relay::api::runtime::{ EventSubscriberFn, LlmConditionalFn, LlmExecutionNextFn, LlmRequestInterceptFn, @@ -50,6 +47,52 @@ use crate::convert::{js_callback_to_json, js_to_json, json_to_js}; #[cfg(target_arch = "wasm32")] use crate::types::WasmEvent; +/// JavaScript-facing pending mark DTO. +/// +/// The public WebAssembly API uses camelCase while canonical Relay JSON keeps +/// snake_case field names. +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "camelCase", deny_unknown_fields)] +pub(crate) struct JsPendingMarkSpec { + name: String, + #[serde(default)] + category: Option, + #[serde(default)] + category_profile: Option, + #[serde(default)] + data: Option, + #[serde(default)] + metadata: Option, +} + +impl From for PendingMarkSpec { + fn from(mark: JsPendingMarkSpec) -> Self { + Self { + name: mark.name, + category: mark.category, + category_profile: mark.category_profile, + data: mark.data, + metadata: mark.metadata, + } + } +} + +impl From for JsPendingMarkSpec { + fn from(mark: PendingMarkSpec) -> Self { + Self { + name: mark.name, + category: mark.category, + category_profile: mark.category_profile, + data: mark.data, + metadata: mark.metadata, + } + } +} + +pub(crate) fn js_pending_marks(marks: Vec) -> Vec { + marks.into_iter().map(Into::into).collect() +} + /// Extract a human-readable error message from a `JsValue`. 
/// /// Tries `.as_string()` first (for string errors), then falls back to debug format. @@ -352,9 +395,13 @@ pub fn wrap_js_llm_request_intercept_fn(func: Function) -> LlmRequestInterceptFn } else { let marks_json = js_to_json(&js_pending_marks) .map_err(|e| FlowError::Internal(js_error_message(&e)))?; - serde_json::from_value::>(marks_json).map_err(|e| { - FlowError::Internal(format!("failed to deserialize pendingMarks: {e}")) - })? + serde_json::from_value::>(marks_json) + .map_err(|e| { + FlowError::Internal(format!("failed to deserialize pendingMarks: {e}")) + })? + .into_iter() + .map(Into::into) + .collect() }; Ok(LlmRequestInterceptOutcome { diff --git a/crates/wasm/tests-js/llm_tests.mjs b/crates/wasm/tests-js/llm_tests.mjs index 1b5818008..6733c26d2 100644 --- a/crates/wasm/tests-js/llm_tests.mjs +++ b/crates/wasm/tests-js/llm_tests.mjs @@ -213,11 +213,27 @@ test('WebAssembly llm and stream flows work from the generated Node package', as }, }, annotated, + pendingMarks: [ + { + name: 'request.optimized', + categoryProfile: { subtype: 'optimizer.saved_tokens' }, + data: { savedTokens: 12 }, + }, + ], })); try { - const interceptedRequest = wasm.llmRequestIntercepts('pkg_llm', request); - assert.equal(interceptedRequest.content.intercepted, true); + const outcome = wasm.llmRequestIntercepts('pkg_llm', request); + assert.equal(outcome.request.content.intercepted, true); + assert.deepEqual(outcome.pendingMarks, [ + { + name: 'request.optimized', + category: null, + categoryProfile: { subtype: 'optimizer.saved_tokens' }, + data: { savedTokens: 12 }, + metadata: null, + }, + ]); wasm.llmConditionalExecution(request); } finally { wasm.deregisterLlmRequestIntercept(llmInterceptName); diff --git a/crates/wasm/tests/coverage/callable_tests.rs b/crates/wasm/tests/coverage/callable_tests.rs index 652329570..27ede3f89 100644 --- a/crates/wasm/tests/coverage/callable_tests.rs +++ b/crates/wasm/tests/coverage/callable_tests.rs @@ -14,6 +14,40 @@ fn dummy_function() -> 
Function { JsValue::NULL.unchecked_into() } +#[test] +fn pending_mark_dto_uses_camel_case_without_changing_canonical_fields() { + let dto: JsPendingMarkSpec = serde_json::from_value(json!({ + "name": "request.optimized", + "categoryProfile": {"subtype": "optimizer.saved_tokens"}, + "data": {"savedTokens": 12} + })) + .unwrap(); + let canonical: PendingMarkSpec = dto.into(); + assert_eq!( + canonical + .category_profile + .as_ref() + .unwrap() + .subtype + .as_deref(), + Some("optimizer.saved_tokens") + ); + + let dto_json = serde_json::to_value(JsPendingMarkSpec::from(canonical)).unwrap(); + assert_eq!( + dto_json["categoryProfile"]["subtype"], + "optimizer.saved_tokens" + ); + assert!(dto_json.get("category_profile").is_none()); + assert!( + serde_json::from_value::(json!({ + "name": "wire-name-is-invalid-in-js", + "category_profile": {"subtype": "invalid"} + })) + .is_err() + ); +} + #[test] fn native_tool_and_llm_wrapper_fallbacks_are_stable() { let tool = wrap_js_tool_fn(dummy_function()); diff --git a/crates/wasm/wrappers/esm/plugin.d.ts b/crates/wasm/wrappers/esm/plugin.d.ts index adc708e72..bb3ab95ec 100644 --- a/crates/wasm/wrappers/esm/plugin.d.ts +++ b/crates/wasm/wrappers/esm/plugin.d.ts @@ -86,7 +86,7 @@ export interface PluginContext { pendingMarks?: Array<{ name: string; category?: string | null; - category_profile?: Json; + categoryProfile?: Json; data?: Json; metadata?: Json; }>; diff --git a/docs/reference/llm-request-intercept-outcomes.mdx b/docs/reference/llm-request-intercept-outcomes.mdx index bc4a91c2b..b184a35ec 100644 --- a/docs/reference/llm-request-intercept-outcomes.mdx +++ b/docs/reference/llm-request-intercept-outcomes.mdx @@ -60,10 +60,12 @@ flowchart TD Python callbacks return `LLMRequestInterceptOutcome`; Rust callbacks return `LlmRequestInterceptOutcome`; Go callbacks return `LLMRequestInterceptOutcome`; and Node.js and WebAssembly callbacks return -`{ request, annotated?, pendingMarks? }`. 
Public C callbacks write one owned -canonical outcome JSON string. Native ABI v1 uses one host-owned outcome JSON -string. Rust and Python `grpc-v1` worker SDKs return their canonical outcome -type in a `JsonEnvelope` whose schema is +`{ request, annotated?, pendingMarks? }`, with `categoryProfile` on each +JavaScript pending-mark DTO. The canonical JSON forms retain `pending_marks` +and `category_profile`. Public C callbacks write one owned canonical outcome +JSON string. Native ABI v1 uses one host-owned outcome JSON string. Rust and +Python `grpc-v1` worker SDKs return their canonical outcome type in a +`JsonEnvelope` whose schema is `nemo.relay.LlmRequestInterceptOutcome@1`. The standalone request-intercept helper returns the complete outcome but does From 6cb7ff98295eb83ceb9734e8cf2d383ad7a3cdb9 Mon Sep 17 00:00:00 2001 From: Bryan Bednarski Date: Tue, 30 Jun 2026 21:35:34 -0600 Subject: [PATCH 6/9] fix(llm): harden request intercept boundary contracts Signed-off-by: Bryan Bednarski --- crates/core/src/bindings.rs | 60 +++++++++++++++++++ crates/core/src/lib.rs | 2 + .../tests/fixtures/worker_plugin/src/main.rs | 11 +++- .../tests/integration/middleware_tests.rs | 15 +++-- .../tests/integration/worker_plugin_tests.rs | 7 +++ crates/ffi/nemo_relay.h | 24 +++++--- crates/ffi/src/api/mod.rs | 23 ++++--- crates/ffi/src/callable.rs | 9 ++- crates/ffi/tests/unit/api/core_tests.rs | 2 + crates/node/src/callable.rs | 51 +--------------- .../tests/coverage/py_api_coverage_tests.rs | 8 ++- crates/wasm/src/callable.rs | 54 ++--------------- crates/wasm/tests/coverage/callable_tests.rs | 1 + docs/build-plugins/register-behavior.mdx | 3 +- .../instrument-applications/code-examples.mdx | 8 +-- .../code-examples.mdx | 6 +- go/nemo_relay/top_level_coverage_test.go | 8 ++- python/tests/plugin/test_worker_sdk.py | 15 ++++- python/tests/test_codecs.py | 21 +++++++ 19 files changed, 198 insertions(+), 130 deletions(-) create mode 100644 crates/core/src/bindings.rs diff --git 
a/crates/core/src/bindings.rs b/crates/core/src/bindings.rs new file mode 100644 index 000000000..5329972ab --- /dev/null +++ b/crates/core/src/bindings.rs @@ -0,0 +1,60 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Shared data conversions used by language bindings. + +/// JavaScript-specific data transfer objects shared by Node.js and WebAssembly. +pub mod js { + use serde::{Deserialize, Serialize}; + use serde_json::Value as Json; + + use crate::api::event::{CategoryProfile, EventCategory, PendingMarkSpec}; + + /// JavaScript-facing pending mark DTO. + /// + /// JavaScript bindings use camelCase while canonical Relay JSON keeps + /// snake_case field names. + #[derive(Debug, Deserialize, Serialize)] + #[serde(rename_all = "camelCase", deny_unknown_fields)] + pub struct JsPendingMarkSpec { + name: String, + #[serde(default)] + category: Option, + #[serde(default)] + category_profile: Option, + #[serde(default)] + data: Option, + #[serde(default)] + metadata: Option, + } + + impl From for PendingMarkSpec { + fn from(mark: JsPendingMarkSpec) -> Self { + Self { + name: mark.name, + category: mark.category, + category_profile: mark.category_profile, + data: mark.data, + metadata: mark.metadata, + } + } + } + + impl From for JsPendingMarkSpec { + fn from(mark: PendingMarkSpec) -> Self { + Self { + name: mark.name, + category: mark.category, + category_profile: mark.category_profile, + data: mark.data, + metadata: mark.metadata, + } + } + } + + /// Convert canonical pending marks to JavaScript-facing DTOs. + #[must_use] + pub fn js_pending_marks(marks: Vec) -> Vec { + marks.into_iter().map(Into::into).collect() + } +} diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index ffd03a4dd..c1532869e 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -54,6 +54,8 @@ //! All middleware is priority-ordered (ascending) and registered by name for //! 
easy addition and removal at runtime. pub mod api; +#[doc(hidden)] +pub mod bindings; pub mod codec; pub mod config_editor; mod context; diff --git a/crates/core/tests/fixtures/worker_plugin/src/main.rs b/crates/core/tests/fixtures/worker_plugin/src/main.rs index 1e296dd28..6291e1dc4 100644 --- a/crates/core/tests/fixtures/worker_plugin/src/main.rs +++ b/crates/core/tests/fixtures/worker_plugin/src/main.rs @@ -5,7 +5,9 @@ use nemo_relay_worker::{ JsonStream, LlmNext, LlmStreamNext, PluginContext, ScopeType, ToolNext, WorkerPlugin, WorkerSdkError, serve_plugin, }; -use nemo_relay_worker::{ConfigDiagnostic, DiagnosticLevel, Json, LlmRequest}; +use nemo_relay_worker::{ + ConfigDiagnostic, DiagnosticLevel, Json, LlmRequest, PendingMarkSpec, +}; use serde_json::json; struct FixtureWorkerPlugin; @@ -197,6 +199,13 @@ impl WorkerPlugin for FixtureWorkerPlugin { Ok(nemo_relay_worker::LlmRequestInterceptOutcome::new( request, annotated, + ) + .with_pending_mark( + PendingMarkSpec::builder() + .name("fixture.worker.llm_request.mark") + .data(json!({ "source": "worker_request_intercept" })) + .metadata(json!({ "fixture": true })) + .build(), )) }, ); diff --git a/crates/core/tests/integration/middleware_tests.rs b/crates/core/tests/integration/middleware_tests.rs index 8e8c9ea92..cf9d29e20 100644 --- a/crates/core/tests/integration/middleware_tests.rs +++ b/crates/core/tests/integration/middleware_tests.rs @@ -2694,7 +2694,10 @@ async fn test_managed_llm_emits_pending_marks_under_started_scope() { LlmCallExecuteParams::builder() .name("pending-managed-llm") .request(LlmRequest { - headers: serde_json::Map::new(), + headers: serde_json::Map::from_iter([( + "x-pending-mark-test".into(), + json!("preserved"), + )]), content: json!({"prompt": "hello"}), }) .func(Arc::new(move |request| { @@ -2707,9 +2710,13 @@ async fn test_managed_llm_emits_pending_marks_under_started_scope() { .unwrap(); let provider_request = provider_request.lock().unwrap().clone().unwrap(); - let 
provider_json = serde_json::to_value(provider_request).unwrap(); - assert!(provider_json.get("pending_marks").is_none()); - assert!(provider_json.get("annotated_request").is_none()); + assert_eq!( + provider_request.headers.get("x-pending-mark-test"), + Some(&json!("preserved")) + ); + assert_eq!(provider_request.content["prompt"], "hello"); + assert!(provider_request.content.get("pending_marks").is_none()); + assert!(provider_request.content.get("annotated_request").is_none()); let captured = captured_events_snapshot(&events); let start = captured diff --git a/crates/core/tests/integration/worker_plugin_tests.rs b/crates/core/tests/integration/worker_plugin_tests.rs index 6b0fa1735..21f2a1d0e 100644 --- a/crates/core/tests/integration/worker_plugin_tests.rs +++ b/crates/core/tests/integration/worker_plugin_tests.rs @@ -182,6 +182,13 @@ async fn rust_worker_registers_and_invokes_all_current_surfaces() { llm_start.input().unwrap()["content"]["worker_plugin_llm_sanitize_request"], true ); + let pending_mark = find_event(&captured_events, "fixture.worker.llm_request.mark", None); + assert_eq!(pending_mark.parent_uuid(), Some(llm_start.uuid())); + assert_eq!( + pending_mark.data().unwrap()["source"], + "worker_request_intercept" + ); + assert_eq!(pending_mark.metadata().unwrap()["fixture"], true); let llm_end = find_event( &captured_events, "worker-fixture-llm-execute", diff --git a/crates/ffi/nemo_relay.h b/crates/ffi/nemo_relay.h index 72a18bc3e..2b1eaa5fb 100644 --- a/crates/ffi/nemo_relay.h +++ b/crates/ffi/nemo_relay.h @@ -255,8 +255,13 @@ typedef char *(*NemoRelayLlmConditionalCb)(void *user_data, const struct FfiLLMR * signature. Receives the intercept name, the opaque `FfiLLMRequest`, and * optionally the annotated request as a JSON C string (null if no Codec * resolved). Writes one owned canonical outcome JSON string to - * `out_outcome_json`. 
With a Codec, the outcome must preserve request content - * and return the annotation; only request headers and annotation fields are + * `out_outcome_json`. Any non-null string written there must be allocated by + * `nemo_relay_llm_request_intercept_outcome_json_new` or by an allocation + * compatible with `nemo_relay_string_free`. Ownership transfers to Relay + * when the callback returns; the callback must not free or reuse the string + * afterward. Relay frees it exactly once, even when the callback returns an + * error status. With a Codec, the outcome must preserve request content and + * return the annotation; only request headers and annotation fields are * writable. Returns `NemoRelayStatus`. */ typedef NemoRelayStatus (*NemoRelayLlmRequestInterceptCb)(void *user_data, @@ -404,15 +409,20 @@ NemoRelayStatus nemo_relay_llm_request_intercepts(const char *name, * Allocate canonical JSON for a C LLM request-intercept callback result. * * `annotated_json` may be null. `pending_marks_json` may be null, in which - * case an empty list is serialized. The caller owns the returned string and - * must release it with [`nemo_relay_string_free`]. + * case an empty list is serialized. When used by a + * `NemoRelayLlmRequestInterceptCb`, assign the successful output to the + * callback's `out_outcome_json`; ownership transfers to Relay when the + * callback returns, so the callback must not free or reuse it. Outside a + * callback, the caller owns the returned string and must release it with + * `nemo_relay_string_free`. * * # Safety * - * `request` must point to a live [`FfiLLMRequest`], optional JSON inputs must + * `request` must point to a live `FfiLLMRequest`, optional JSON inputs must * be valid null-terminated strings when non-null, and `out_outcome_json` must - * be writable. The caller must free a successful output with - * [`nemo_relay_string_free`]. + * be writable. 
A successful output must either be transferred through a + * callback's `out_outcome_json` or freed by its caller with + * `nemo_relay_string_free`. */ NemoRelayStatus nemo_relay_llm_request_intercept_outcome_json_new(const struct FfiLLMRequest *request, const char *annotated_json, diff --git a/crates/ffi/src/api/mod.rs b/crates/ffi/src/api/mod.rs index c6d60889a..a0d9ee5e2 100644 --- a/crates/ffi/src/api/mod.rs +++ b/crates/ffi/src/api/mod.rs @@ -243,15 +243,20 @@ pub unsafe extern "C" fn nemo_relay_llm_request_intercepts( /// Allocate canonical JSON for a C LLM request-intercept callback result. /// /// `annotated_json` may be null. `pending_marks_json` may be null, in which -/// case an empty list is serialized. The caller owns the returned string and -/// must release it with [`nemo_relay_string_free`]. +/// case an empty list is serialized. When used by a +/// `NemoRelayLlmRequestInterceptCb`, assign the successful output to the +/// callback's `out_outcome_json`; ownership transfers to Relay when the +/// callback returns, so the callback must not free or reuse it. Outside a +/// callback, the caller owns the returned string and must release it with +/// `nemo_relay_string_free`. /// /// # Safety /// -/// `request` must point to a live [`FfiLLMRequest`], optional JSON inputs must +/// `request` must point to a live `FfiLLMRequest`, optional JSON inputs must /// be valid null-terminated strings when non-null, and `out_outcome_json` must -/// be writable. The caller must free a successful output with -/// [`nemo_relay_string_free`]. +/// be writable. A successful output must either be transferred through a +/// callback's `out_outcome_json` or freed by its caller with +/// `nemo_relay_string_free`. 
#[unsafe(no_mangle)] pub unsafe extern "C" fn nemo_relay_llm_request_intercept_outcome_json_new( request: *const FfiLLMRequest, @@ -260,11 +265,15 @@ pub unsafe extern "C" fn nemo_relay_llm_request_intercept_outcome_json_new( out_outcome_json: *mut *mut c_char, ) -> NemoRelayStatus { clear_last_error(); - if request.is_null() || out_outcome_json.is_null() { - set_last_error("request and out_outcome_json must be non-null"); + if out_outcome_json.is_null() { + set_last_error("out_outcome_json must be non-null"); return NemoRelayStatus::NullPointer; } unsafe { *out_outcome_json = std::ptr::null_mut() }; + if request.is_null() { + set_last_error("request must be non-null"); + return NemoRelayStatus::NullPointer; + } let annotated_request = if annotated_json.is_null() { None } else { diff --git a/crates/ffi/src/callable.rs b/crates/ffi/src/callable.rs index f079697ea..9baaba2b2 100644 --- a/crates/ffi/src/callable.rs +++ b/crates/ffi/src/callable.rs @@ -172,8 +172,13 @@ pub type NemoRelayCodecEncodeFn = Option< /// signature. Receives the intercept name, the opaque `FfiLLMRequest`, and /// optionally the annotated request as a JSON C string (null if no Codec /// resolved). Writes one owned canonical outcome JSON string to -/// `out_outcome_json`. With a Codec, the outcome must preserve request content -/// and return the annotation; only request headers and annotation fields are +/// `out_outcome_json`. Any non-null string written there must be allocated by +/// `nemo_relay_llm_request_intercept_outcome_json_new` or by an allocation +/// compatible with `nemo_relay_string_free`. Ownership transfers to Relay +/// when the callback returns; the callback must not free or reuse the string +/// afterward. Relay frees it exactly once, even when the callback returns an +/// error status. With a Codec, the outcome must preserve request content and +/// return the annotation; only request headers and annotation fields are /// writable. Returns `NemoRelayStatus`. 
pub type NemoRelayLlmRequestInterceptCb = unsafe extern "C" fn( user_data: *mut libc::c_void, diff --git a/crates/ffi/tests/unit/api/core_tests.rs b/crates/ffi/tests/unit/api/core_tests.rs index 735ca68ca..261ef8a0e 100644 --- a/crates/ffi/tests/unit/api/core_tests.rs +++ b/crates/ffi/tests/unit/api/core_tests.rs @@ -44,6 +44,7 @@ fn test_ffi_llm_request_intercept_outcome_json_allocation_and_validation() { NemoRelayStatus::InvalidJson ); assert!(outcome_json.is_null()); + outcome_json = std::ptr::dangling_mut(); assert_eq!( unsafe { api::nemo_relay_llm_request_intercept_outcome_json_new( @@ -55,6 +56,7 @@ fn test_ffi_llm_request_intercept_outcome_json_allocation_and_validation() { }, NemoRelayStatus::NullPointer ); + assert!(outcome_json.is_null()); unsafe { nemo_relay_llm_request_free(request) }; } diff --git a/crates/node/src/callable.rs b/crates/node/src/callable.rs index a00fbfdb5..87386fbfb 100644 --- a/crates/node/src/callable.rs +++ b/crates/node/src/callable.rs @@ -19,12 +19,13 @@ use nemo_relay::api::runtime::{ LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionNextFn, ToolConditionalFn, ToolExecutionNextFn, ToolInterceptFn, ToolSanitizeFn, }; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; use serde_json::Value as Json; use tokio_stream::StreamExt; -use nemo_relay::api::event::{CategoryProfile, Event, EventCategory, PendingMarkSpec}; +use nemo_relay::api::event::Event; use nemo_relay::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; +pub(crate) use nemo_relay::bindings::js::{JsPendingMarkSpec, js_pending_marks}; use nemo_relay::codec::request::AnnotatedLlmRequest; use nemo_relay::codec::response::AnnotatedLlmResponse; use nemo_relay::codec::traits::{LlmCodec, LlmResponseCodec}; @@ -34,52 +35,6 @@ use crate::convert::{callback_json, record_callback_error}; use crate::promise_call::{JsonNextFn, JsonStreamNextFn, PromiseAwareFn}; use crate::types::JsEvent; -/// JavaScript-facing pending mark DTO. 
-/// -/// The public Node API uses camelCase while the canonical Relay wire contract -/// keeps snake_case field names. -#[derive(Debug, Deserialize, Serialize)] -#[serde(rename_all = "camelCase", deny_unknown_fields)] -pub(crate) struct JsPendingMarkSpec { - name: String, - #[serde(default)] - category: Option, - #[serde(default)] - category_profile: Option, - #[serde(default)] - data: Option, - #[serde(default)] - metadata: Option, -} - -impl From for PendingMarkSpec { - fn from(mark: JsPendingMarkSpec) -> Self { - Self { - name: mark.name, - category: mark.category, - category_profile: mark.category_profile, - data: mark.data, - metadata: mark.metadata, - } - } -} - -impl From for JsPendingMarkSpec { - fn from(mark: PendingMarkSpec) -> Self { - Self { - name: mark.name, - category: mark.category, - category_profile: mark.category_profile, - data: mark.data, - metadata: mark.metadata, - } - } -} - -pub(crate) fn js_pending_marks(marks: Vec) -> Vec { - marks.into_iter().map(Into::into).collect() -} - fn recv_json_or_null(rx: std::sync::mpsc::Receiver, error_prefix: &str) -> Json { rx.recv().unwrap_or_else(|e| { record_callback_error(format!("{error_prefix}: {e}")); diff --git a/crates/python/tests/coverage/py_api_coverage_tests.rs b/crates/python/tests/coverage/py_api_coverage_tests.rs index 409493997..9ce9107ab 100644 --- a/crates/python/tests/coverage/py_api_coverage_tests.rs +++ b/crates/python/tests/coverage/py_api_coverage_tests.rs @@ -193,8 +193,12 @@ def llm_conditional(request): def llm_request_intercept(name, request, annotated): headers = dict(request.headers) headers["x-intercepted"] = "1" - content = dict(request.content) - content["messages"] = [{"role": "user", "content": "hello from intercept"}] + if annotated is None: + content = dict(request.content) + content["messages"] = [{"role": "user", "content": "hello from intercept"}] + else: + content = request.content + annotated.messages = [{"role": "user", "content": "hello from intercept"}] return 
LLMRequestInterceptOutcome(LLMRequest(headers, content), annotated) async def llm_exec(request): diff --git a/crates/wasm/src/callable.rs b/crates/wasm/src/callable.rs index c1a8f5afa..5fd3d6b00 100644 --- a/crates/wasm/src/callable.rs +++ b/crates/wasm/src/callable.rs @@ -16,7 +16,8 @@ use std::sync::Arc; use js_sys::Function; #[cfg(target_arch = "wasm32")] use send_wrapper::SendWrapper; -use serde::{Deserialize, Serialize}; +#[cfg(target_arch = "wasm32")] +use serde::Serialize; use serde_json::Value as Json; #[cfg(target_arch = "wasm32")] use tokio_stream::StreamExt; @@ -27,13 +28,16 @@ use wasm_bindgen::JsValue; #[cfg(target_arch = "wasm32")] use wasm_bindgen_futures::JsFuture; -use nemo_relay::api::event::{CategoryProfile, Event, EventCategory, PendingMarkSpec}; +use nemo_relay::api::event::Event; use nemo_relay::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; use nemo_relay::api::runtime::{ EventSubscriberFn, LlmConditionalFn, LlmExecutionNextFn, LlmRequestInterceptFn, LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionNextFn, ToolConditionalFn, ToolExecutionNextFn, ToolInterceptFn, ToolSanitizeFn, }; +#[cfg(any(target_arch = "wasm32", test))] +pub(crate) use nemo_relay::bindings::js::JsPendingMarkSpec; +pub(crate) use nemo_relay::bindings::js::js_pending_marks; use nemo_relay::codec::request::AnnotatedLlmRequest; #[cfg(target_arch = "wasm32")] use nemo_relay::codec::response::AnnotatedLlmResponse; @@ -47,52 +51,6 @@ use crate::convert::{js_callback_to_json, js_to_json, json_to_js}; #[cfg(target_arch = "wasm32")] use crate::types::WasmEvent; -/// JavaScript-facing pending mark DTO. -/// -/// The public WebAssembly API uses camelCase while canonical Relay JSON keeps -/// snake_case field names. 
-#[derive(Debug, Deserialize, Serialize)] -#[serde(rename_all = "camelCase", deny_unknown_fields)] -pub(crate) struct JsPendingMarkSpec { - name: String, - #[serde(default)] - category: Option, - #[serde(default)] - category_profile: Option, - #[serde(default)] - data: Option, - #[serde(default)] - metadata: Option, -} - -impl From for PendingMarkSpec { - fn from(mark: JsPendingMarkSpec) -> Self { - Self { - name: mark.name, - category: mark.category, - category_profile: mark.category_profile, - data: mark.data, - metadata: mark.metadata, - } - } -} - -impl From for JsPendingMarkSpec { - fn from(mark: PendingMarkSpec) -> Self { - Self { - name: mark.name, - category: mark.category, - category_profile: mark.category_profile, - data: mark.data, - metadata: mark.metadata, - } - } -} - -pub(crate) fn js_pending_marks(marks: Vec) -> Vec { - marks.into_iter().map(Into::into).collect() -} - /// Extract a human-readable error message from a `JsValue`. /// /// Tries `.as_string()` first (for string errors), then falls back to debug format. diff --git a/crates/wasm/tests/coverage/callable_tests.rs b/crates/wasm/tests/coverage/callable_tests.rs index 27ede3f89..46540a222 100644 --- a/crates/wasm/tests/coverage/callable_tests.rs +++ b/crates/wasm/tests/coverage/callable_tests.rs @@ -4,6 +4,7 @@ //! Coverage tests for callable in the NeMo Relay WebAssembly crate. 
use super::*; +use nemo_relay::api::event::PendingMarkSpec; use nemo_relay::codec::request::AnnotatedLlmRequest; use serde_json::json; use tokio_stream::StreamExt; diff --git a/docs/build-plugins/register-behavior.mdx b/docs/build-plugins/register-behavior.mdx index 04022fba8..82f01803f 100644 --- a/docs/build-plugins/register-behavior.mdx +++ b/docs/build-plugins/register-behavior.mdx @@ -102,6 +102,7 @@ plugin.register('header-plugin', headerPlugin); ```rust +use nemo_relay::api::llm::LlmRequestInterceptOutcome; use nemo_relay::plugin::{ register_plugin, ConfigDiagnostic, DiagnosticLevel, Plugin, PluginRegistrationContext, Result as PluginResult, @@ -169,7 +170,7 @@ impl Plugin for HeaderPlugin { request .headers .insert(header_name.clone(), header_value.clone().into()); - Ok((request, annotated)) + Ok(LlmRequestInterceptOutcome::new(request, annotated)) }), )?; Ok(()) diff --git a/docs/instrument-applications/code-examples.mdx b/docs/instrument-applications/code-examples.mdx index 58b03404c..81a9d4407 100644 --- a/docs/instrument-applications/code-examples.mdx +++ b/docs/instrument-applications/code-examples.mdx @@ -260,8 +260,8 @@ tool_args = nemo_relay.tools.request_intercepts("search", {"query": "weather"}) nemo_relay.tools.conditional_execution("search", tool_args) llm_request = LLMRequest({}, {"messages": [{"role": "user", "content": "hello"}]}) -llm_request = nemo_relay.llm.request_intercepts("demo-provider", llm_request) -nemo_relay.llm.conditional_execution(llm_request) +outcome = nemo_relay.llm.request_intercepts("demo-provider", llm_request) +nemo_relay.llm.conditional_execution(outcome.request) ``` @@ -279,8 +279,8 @@ const toolArgs = await toolRequestIntercepts('search', { query: 'weather' }); await toolConditionalExecution('search', toolArgs); const request = new LlmRequest({}, { messages: [{ role: 'user', content: 'hello' }] }); -const rewritten = await llmRequestIntercepts('demo-provider', request); -await llmConditionalExecution(rewritten); 
+const outcome = await llmRequestIntercepts('demo-provider', request); +await llmConditionalExecution(outcome.request); ``` diff --git a/docs/integrate-into-frameworks/code-examples.mdx b/docs/integrate-into-frameworks/code-examples.mdx index c3d0b4d05..698eed6fd 100644 --- a/docs/integrate-into-frameworks/code-examples.mdx +++ b/docs/integrate-into-frameworks/code-examples.mdx @@ -184,10 +184,11 @@ import nemo_relay from nemo_relay import LLMRequest rewritten_args = nemo_relay.tools.request_intercepts("search", {"query": "weather"}) -rewritten_request = nemo_relay.llm.request_intercepts( +outcome = nemo_relay.llm.request_intercepts( "demo-provider", LLMRequest({}, {"messages": []}), ) +rewritten_request = outcome.request ``` @@ -196,7 +197,8 @@ rewritten_request = nemo_relay.llm.request_intercepts( import { LlmRequest, llmRequestIntercepts, toolRequestIntercepts } from 'nemo-relay-node'; const rewrittenArgs = await toolRequestIntercepts('search', { query: 'weather' }); -const rewrittenRequest = await llmRequestIntercepts('demo-provider', new LlmRequest({}, { messages: [] })); +const outcome = await llmRequestIntercepts('demo-provider', new LlmRequest({}, { messages: [] })); +const rewrittenRequest = outcome.request; ``` diff --git a/go/nemo_relay/top_level_coverage_test.go b/go/nemo_relay/top_level_coverage_test.go index bfa5a6043..8c4684bae 100644 --- a/go/nemo_relay/top_level_coverage_test.go +++ b/go/nemo_relay/top_level_coverage_test.go @@ -423,6 +423,11 @@ func assertLlmCodecInterceptCoverage(t *testing.T) { }); err != nil { t.Fatalf("RegisterLlmRequestIntercept raw content case failed: %v", err) } + t.Cleanup(func() { + if err := DeregisterLlmRequestIntercept("coverage_llm_codec_raw_content"); err != nil { + t.Errorf("failed to deregister raw content intercept: %v", err) + } + }) providerCalled := false _, err := LlmCallExecute("coverage_llm_codec_raw_content", map[string]any{ "headers": map[string]any{}, @@ -435,9 +440,6 @@ func 
assertLlmCodecInterceptCoverage(t *testing.T) { if providerCalled { t.Fatal("provider should not run after a codec-backed raw content mutation") } - if err := DeregisterLlmRequestIntercept("coverage_llm_codec_raw_content"); err != nil { - t.Fatalf("failed to deregister raw content intercept: %v", err) - } if err := RegisterLlmRequestIntercept("coverage_llm_codec_error", 1, false, func(name string, request LLMRequestDTO, annotatedJSON json.RawMessage) (LLMRequestInterceptOutcome, error) { return LLMRequestInterceptOutcome{}, errors.New("forced codec-backed intercept failure") diff --git a/python/tests/plugin/test_worker_sdk.py b/python/tests/plugin/test_worker_sdk.py index 8c3ac1d8a..5ef9728e6 100644 --- a/python/tests/plugin/test_worker_sdk.py +++ b/python/tests/plugin/test_worker_sdk.py @@ -208,7 +208,7 @@ def llm_request(name: str, request: Json, annotated: Json | None) -> LlmRequestI del name return LlmRequestInterceptOutcome( request=_tag_llm_request(request, "llm_request"), - annotated_request=_tag(annotated or {}, "annotated"), + annotated_request=_tag(annotated, "annotated") if annotated is not None else None, pending_marks=[PendingMarkSpec("worker.pending", data={"source": "python"})], ) @@ -958,6 +958,19 @@ async def test_unary_invoke_success_paths(service: _WorkerService, host_stub: Re } ] + request_only = await service.Invoke( + _invoke_request( + "llm_request", + pb.LLM_REQUEST_INTERCEPT, + llm=_llm_payload(request={"content": {"prompt": "hello"}}), + ), + AbortContext(), + ) + request_only_outcome = _envelope_value(request_only.llm_request.outcome) + assert request_only_outcome["request"]["content"]["llm_request"] + assert request_only_outcome["annotated_request"] is None + assert request_only_outcome["pending_marks"] == outcome["pending_marks"] + llm_execution = await _invoke_json_async( service, "llm_execution", diff --git a/python/tests/test_codecs.py b/python/tests/test_codecs.py index 8080825a7..d5ef6810b 100644 --- a/python/tests/test_codecs.py 
+++ b/python/tests/test_codecs.py @@ -310,6 +310,27 @@ def provider(request): finally: intercepts.deregister_llm_request("test-codec-raw-content") + async def test_codec_rejects_missing_annotation_before_provider(self): + """Codec-aware intercepts must return the decoded annotation.""" + provider_called = False + + def missing_annotation_intercept(name, request, annotated): + assert annotated is not None + return LLMRequestInterceptOutcome(request, None) + + def provider(request): + nonlocal provider_called + provider_called = True + return {"unexpected": True} + + intercepts.register_llm_request("test-codec-missing-annotation", 1, False, missing_annotation_intercept) + try: + with pytest.raises(RuntimeError, match="omitted annotated_request"): + await llm.execute("pipeline-llm", make_request(), provider, codec=SimpleCodec()) + assert not provider_called + finally: + intercepts.deregister_llm_request("test-codec-missing-annotation") + async def test_codec_parameter(self): """codec parameter passes the specified codec instance directly.""" alternate = AlternateCodec() From 66e8704702f382368b8aaca4445e9fda1c5b6c1f Mon Sep 17 00:00:00 2001 From: Bryan Bednarski Date: Tue, 30 Jun 2026 22:33:59 -0600 Subject: [PATCH 7/9] fix(plugin): validate intercept outcome contract before export Signed-off-by: Bryan Bednarski --- crates/plugin/src/lib.rs | 30 ++++++++++++-------------- crates/plugin/tests/typed_callbacks.rs | 21 ++++++++++++++++++ go/nemo_relay/llm_test.go | 3 +-- python/nemo_relay/intercepts.py | 6 +++--- python/tests/test_llm.py | 12 ++++++++++- 5 files changed, 50 insertions(+), 22 deletions(-) diff --git a/crates/plugin/src/lib.rs b/crates/plugin/src/lib.rs index 70fb32ce2..5231ce803 100644 --- a/crates/plugin/src/lib.rs +++ b/crates/plugin/src/lib.rs @@ -2625,7 +2625,7 @@ pub unsafe fn export_plugin( } unsafe { *out = NemoRelayNativePluginV1::default() }; let host_ref = unsafe { &*host }; - export_plugin_checked(host_ref, out, plugin) + 
export_plugin_checked(host_ref, out, || plugin) } /// Initializes a native plugin descriptor from a constructor callback. @@ -2649,6 +2649,18 @@ where } unsafe { *out = NemoRelayNativePluginV1::default() }; let host_ref = unsafe { &*host }; + export_plugin_checked(host_ref, out, constructor) +} + +fn export_plugin_checked( + host_ref: &NemoRelayNativeHostApiV1, + out: *mut NemoRelayNativePluginV1, + constructor: F, +) -> NemoRelayStatus +where + P: NativePlugin, + F: FnOnce() -> P, +{ if host_ref.abi_version != NEMO_RELAY_NATIVE_ABI_VERSION { return NemoRelayStatus::InvalidArg; } @@ -2661,21 +2673,7 @@ where return NemoRelayStatus::InvalidArg; } - export_plugin_checked(host_ref, out, constructor()) -} - -fn export_plugin_checked( - host_ref: &NemoRelayNativeHostApiV1, - out: *mut NemoRelayNativePluginV1, - plugin: P, -) -> NemoRelayStatus { - if host_ref.abi_version != NEMO_RELAY_NATIVE_ABI_VERSION { - return NemoRelayStatus::InvalidArg; - } - if host_ref.struct_size < std::mem::size_of::() { - return NemoRelayStatus::InvalidArg; - } - + let plugin = constructor(); let kind = plugin.plugin_kind().to_owned(); let allows_multiple_components = plugin.allows_multiple_components(); let Some(kind_handle) = HostString::new(host_ref, &kind) else { diff --git a/crates/plugin/tests/typed_callbacks.rs b/crates/plugin/tests/typed_callbacks.rs index a61d1e99d..8e751289e 100644 --- a/crates/plugin/tests/typed_callbacks.rs +++ b/crates/plugin/tests/typed_callbacks.rs @@ -4754,6 +4754,18 @@ fn direct_export_plugin_validates_host_table_and_kind_allocation() { NemoRelayStatus::InvalidArg ); + let mut incompatible_host = host; + incompatible_host.llm_request_intercept_outcome_contract_version = + NEMO_RELAY_NATIVE_LLM_INTERCEPT_OUTCOME_CONTRACT_VERSION + 1; + assert_eq!( + unsafe { + nemo_relay_plugin::export_plugin(&incompatible_host, &mut plugin, CountingPlugin) + }, + NemoRelayStatus::InvalidArg + ); + assert!(plugin.plugin_kind.is_null()); + assert!(plugin.user_data.is_null()); + 
*STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = Some(0); assert_eq!( unsafe { nemo_relay_plugin::export_plugin(&host, &mut plugin, CountingPlugin) }, @@ -4976,6 +4988,15 @@ fn exported_entry_symbol_validates_args_before_constructor() { NemoRelayStatus::InvalidArg ); assert_eq!(CONSTRUCTOR_CALLS.load(Ordering::SeqCst), 0); + + let mut incompatible_host = host; + incompatible_host.llm_request_intercept_outcome_contract_version = + NEMO_RELAY_NATIVE_LLM_INTERCEPT_OUTCOME_CONTRACT_VERSION + 1; + assert_eq!( + unsafe { constructor_counting_entry(&incompatible_host, &mut plugin) }, + NemoRelayStatus::InvalidArg + ); + assert_eq!(CONSTRUCTOR_CALLS.load(Ordering::SeqCst), 0); } #[test] diff --git a/go/nemo_relay/llm_test.go b/go/nemo_relay/llm_test.go index da94d2fde..0ca101ce4 100644 --- a/go/nemo_relay/llm_test.go +++ b/go/nemo_relay/llm_test.go @@ -529,6 +529,7 @@ func TestLlmRequestInterceptModifies(t *testing.T) { return LLMRequestInterceptOutcome{Request: request, AnnotatedRequest: annotated}, nil }, ) + t.Cleanup(func() { _ = DeregisterLlmRequestIntercept("go_llm_req_mod") }) request := makeRequest() result, err := LlmCallExecute("int_llm", request, @@ -550,8 +551,6 @@ func TestLlmRequestInterceptModifies(t *testing.T) { if output["saw_intercepted"] != true { t.Fatalf("expected saw_intercepted=true, got %v", output) } - - DeregisterLlmRequestIntercept("go_llm_req_mod") } func TestLlmExecutionInterceptReplaces(t *testing.T) { diff --git a/python/nemo_relay/intercepts.py b/python/nemo_relay/intercepts.py index 836c2e76a..5392a6c46 100644 --- a/python/nemo_relay/intercepts.py +++ b/python/nemo_relay/intercepts.py @@ -169,9 +169,9 @@ def register_llm_request(name: str, priority: int, break_chain: bool, fn: LlmReq priority: Execution order for the intercept. Lower values run first. break_chain: Whether to stop applying lower-priority request intercepts after this intercept runs. 
- fn: Callable invoked as ``fn(name, request, annotated)`` that returns a - tuple of ``(request, annotated)`` for the next intercept or the - provider callback. + fn: Callable invoked as ``fn(name, request, annotated)`` that returns an + ``nemo_relay.LLMRequestInterceptOutcome`` for the next intercept or + the provider callback. Returns: None: This function returns after the intercept is registered. diff --git a/python/tests/test_llm.py b/python/tests/test_llm.py index 30c390164..5e9c738ef 100644 --- a/python/tests/test_llm.py +++ b/python/tests/test_llm.py @@ -12,6 +12,7 @@ LLMHandle, LLMRequest, LLMRequestInterceptOutcome, + PendingMarkSpec, ScopeEvent, ScopeType, guardrails, @@ -295,16 +296,25 @@ def test_request_intercept(self): assert intercepts.deregister_llm_request("py_llm_req") def test_request_intercepts_direct(self): + pending_mark = PendingMarkSpec("request.direct", data={"source": "python"}) + def intercept_fn(name, request, annotated): content = request.content content["direct"] = True - return LLMRequestInterceptOutcome(LLMRequest(request.headers, content), annotated) + return LLMRequestInterceptOutcome( + LLMRequest(request.headers, content), + annotated, + [pending_mark], + ) intercepts.register_llm_request("py_llm_req_direct", 1, False, intercept_fn) transformed = llm.request_intercepts("direct_llm", make_request()) intercepts.deregister_llm_request("py_llm_req_direct") assert transformed.request.content["direct"] is True + assert len(transformed.pending_marks) == 1 + assert transformed.pending_marks[0].name == pending_mark.name + assert transformed.pending_marks[0].data == pending_mark.data def test_request_intercept_raises_on_exception(self): intercepts.register_llm_request( From 3ea2c333b041f915db0f149110529bbd680341d1 Mon Sep 17 00:00:00 2001 From: Bryan Bednarski Date: Wed, 1 Jul 2026 07:48:44 -0600 Subject: [PATCH 8/9] chore: move LLM intercept docs to follow-up Signed-off-by: Bryan Bednarski --- docs/build-plugins/code-examples.mdx | 7 +- 
docs/build-plugins/register-behavior.mdx | 10 +- .../instrument-applications/code-examples.mdx | 12 +-- .../code-examples.mdx | 9 +- .../provider-codecs.mdx | 13 +-- .../llm-request-intercept-outcomes.mdx | 97 ------------------- 6 files changed, 17 insertions(+), 131 deletions(-) delete mode 100644 docs/reference/llm-request-intercept-outcomes.mdx diff --git a/docs/build-plugins/code-examples.mdx b/docs/build-plugins/code-examples.mdx index 068e33f1d..f2114c2cf 100644 --- a/docs/build-plugins/code-examples.mdx +++ b/docs/build-plugins/code-examples.mdx @@ -37,14 +37,11 @@ class HeaderPlugin: name: str, request: nemo_relay.LLMRequest, annotated: nemo_relay.AnnotatedLLMRequest | None - ) -> nemo_relay.LLMRequestInterceptOutcome: + ) -> tuple[nemo_relay.LLMRequest, nemo_relay.AnnotatedLLMRequest | None]: # The request object is immutable, however we can return a new instance with updated headers. headers = request.headers.copy() headers[plugin_config["header_name"]] = plugin_config["value"] - return nemo_relay.LLMRequestInterceptOutcome( - nemo_relay.LLMRequest(headers=headers, content=request.content), - annotated, - ) + return nemo_relay.LLMRequest(headers=headers, content=request.content), annotated context.register_llm_request_intercept("inject-header", 100, False, add_header) diff --git a/docs/build-plugins/register-behavior.mdx b/docs/build-plugins/register-behavior.mdx index 82f01803f..d995ca773 100644 --- a/docs/build-plugins/register-behavior.mdx +++ b/docs/build-plugins/register-behavior.mdx @@ -51,13 +51,10 @@ class HeaderPlugin: name: str, request: nemo_relay.LLMRequest, annotated: nemo_relay.AnnotatedLLMRequest | None - ) -> nemo_relay.LLMRequestInterceptOutcome: + ) -> tuple[nemo_relay.LLMRequest, nemo_relay.AnnotatedLLMRequest | None]: headers = request.headers.copy() headers[plugin_config["header_name"]] = plugin_config["value"] - return nemo_relay.LLMRequestInterceptOutcome( - nemo_relay.LLMRequest(headers=headers, content=request.content), - 
annotated, - ) + return nemo_relay.LLMRequest(headers=headers, content=request.content), annotated context.register_llm_request_intercept("inject-header", 100, False, add_header) @@ -102,7 +99,6 @@ plugin.register('header-plugin', headerPlugin); ```rust -use nemo_relay::api::llm::LlmRequestInterceptOutcome; use nemo_relay::plugin::{ register_plugin, ConfigDiagnostic, DiagnosticLevel, Plugin, PluginRegistrationContext, Result as PluginResult, @@ -170,7 +166,7 @@ impl Plugin for HeaderPlugin { request .headers .insert(header_name.clone(), header_value.clone().into()); - Ok(LlmRequestInterceptOutcome::new(request, annotated)) + Ok((request, annotated)) }), )?; Ok(()) diff --git a/docs/instrument-applications/code-examples.mdx b/docs/instrument-applications/code-examples.mdx index 81a9d4407..64d687da3 100644 --- a/docs/instrument-applications/code-examples.mdx +++ b/docs/instrument-applications/code-examples.mdx @@ -260,8 +260,8 @@ tool_args = nemo_relay.tools.request_intercepts("search", {"query": "weather"}) nemo_relay.tools.conditional_execution("search", tool_args) llm_request = LLMRequest({}, {"messages": [{"role": "user", "content": "hello"}]}) -outcome = nemo_relay.llm.request_intercepts("demo-provider", llm_request) -nemo_relay.llm.conditional_execution(outcome.request) +llm_request = nemo_relay.llm.request_intercepts("demo-provider", llm_request) +nemo_relay.llm.conditional_execution(llm_request) ``` @@ -279,8 +279,8 @@ const toolArgs = await toolRequestIntercepts('search', { query: 'weather' }); await toolConditionalExecution('search', toolArgs); const request = new LlmRequest({}, { messages: [{ role: 'user', content: 'hello' }] }); -const outcome = await llmRequestIntercepts('demo-provider', request); -await llmConditionalExecution(outcome.request); +const rewritten = await llmRequestIntercepts('demo-provider', request); +await llmConditionalExecution(rewritten); ``` @@ -297,8 +297,8 @@ let request = LlmRequest { headers: Default::default(), content: 
json!({"messages": [{"role": "user", "content": "hello"}]}), }; -let outcome = llm_request_intercepts("demo-provider", request)?; -llm_conditional_execution(&outcome.request)?; +let rewritten = llm_request_intercepts("demo-provider", request)?; +llm_conditional_execution(&rewritten)?; ``` diff --git a/docs/integrate-into-frameworks/code-examples.mdx b/docs/integrate-into-frameworks/code-examples.mdx index ffd9a7830..d989191fa 100644 --- a/docs/integrate-into-frameworks/code-examples.mdx +++ b/docs/integrate-into-frameworks/code-examples.mdx @@ -184,11 +184,10 @@ import nemo_relay from nemo_relay import LLMRequest rewritten_args = nemo_relay.tools.request_intercepts("search", {"query": "weather"}) -outcome = nemo_relay.llm.request_intercepts( +rewritten_request = nemo_relay.llm.request_intercepts( "demo-provider", LLMRequest({}, {"messages": []}), ) -rewritten_request = outcome.request ``` @@ -197,8 +196,7 @@ rewritten_request = outcome.request import { LlmRequest, llmRequestIntercepts, toolRequestIntercepts } from 'nemo-relay-node'; const rewrittenArgs = await toolRequestIntercepts('search', { query: 'weather' }); -const outcome = await llmRequestIntercepts('demo-provider', new LlmRequest({}, { messages: [] })); -const rewrittenRequest = outcome.request; +const rewrittenRequest = await llmRequestIntercepts('demo-provider', new LlmRequest({}, { messages: [] })); ``` @@ -210,8 +208,7 @@ use serde_json::json; let rewritten_args = tool_request_intercepts("search", json!({"query": "weather"}))?; let request = LlmRequest { headers: Default::default(), content: json!({"messages": []}) }; -let outcome = llm_request_intercepts("demo-provider", request)?; -let rewritten_request = outcome.request; +let rewritten_request = llm_request_intercepts("demo-provider", request)?; ``` diff --git a/docs/integrate-into-frameworks/provider-codecs.mdx b/docs/integrate-into-frameworks/provider-codecs.mdx index 90df6c7f6..f0d632661 100644 --- 
a/docs/integrate-into-frameworks/provider-codecs.mdx +++ b/docs/integrate-into-frameworks/provider-codecs.mdx @@ -41,17 +41,10 @@ When a managed LLM call has a request codec: 1. NeMo Relay calls `decode` before LLM request intercepts run. 2. Request intercepts receive both the raw request and the annotated request. -3. Intercepts edit provider-body fields through the annotated request and may - edit transport headers through the raw request. Raw `request.content` is - read-only while the codec is active. +3. Intercepts may edit the raw request, the annotated request, or both. 4. NeMo Relay calls `encode` to merge the annotated request back into the original raw request. 5. Execution intercepts and the provider callback receive the encoded provider request. -If a codec-aware intercept changes raw `request.content` or omits the returned -annotation, Relay rejects the outcome before creating the LLM lifecycle. When -no request codec is active, the raw request remains fully writable and is the -provider-visible source of truth. - When a managed LLM call has a response codec, NeMo Relay decodes the raw provider response for observability and attaches the result to the emitted LLM end event. The response codec does not rewrite the value returned to the application. Use [Provider Response Codecs](/integrate-into-frameworks/provider-response-codecs) for response-only behavior and custom response codec examples. Codec implementations should preserve fields they do not understand. Treat `encode` as a merge operation over the original provider payload, not as a full replacement. @@ -94,7 +87,7 @@ from nemo_relay.codecs import OpenAIChatCodec def add_system_message(_name, request, annotated): if annotated is None: - return nemo_relay.LLMRequestInterceptOutcome(request) + return request, annotated # Attributes of the annotated request can be re-assigned, but cannot be modified in-place. 
# For example `annotated.messages.append(...)` would not work, but re-assigning @@ -103,7 +96,7 @@ def add_system_message(_name, request, annotated): {"role": "system", "content": "Answer with concise technical detail."}, *annotated.messages, ] - return nemo_relay.LLMRequestInterceptOutcome(request, annotated) + return request, annotated nemo_relay.intercepts.register_llm_request( "framework.add_system_message", diff --git a/docs/reference/llm-request-intercept-outcomes.mdx b/docs/reference/llm-request-intercept-outcomes.mdx deleted file mode 100644 index b184a35ec..000000000 --- a/docs/reference/llm-request-intercept-outcomes.mdx +++ /dev/null @@ -1,97 +0,0 @@ ---- -title: "LLM Request Intercept Outcomes" -description: "Canonical request-intercept result and managed lifecycle behavior." ---- -{/* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -SPDX-License-Identifier: Apache-2.0 */} - -Every LLM request intercept returns one canonical outcome: - -```json -{ - "request": {"headers": {}, "content": {}}, - "annotated_request": null, - "pending_marks": [] -} -``` - -`request` is required. `annotated_request` defaults to `null` when omitted on -input, and `pending_marks` defaults to an empty list. Canonical serialization -includes all three fields. A pending mark contains only `name`, optional -`category` and `category_profile`, and optional `data` and `metadata`. Relay -owns event UUIDs, parent UUIDs, and timestamps. - -## Request Authority - -The provider-body source of truth depends only on whether a request codec is -active: - -| Request codec | Provider body source | Header source | -| --- | --- | --- | -| No codec | `outcome.request.content` | `outcome.request.headers` | -| Active codec | `outcome.annotated_request` | `outcome.request.headers` | - -With an active codec, `request.content` is read-only context. 
Every intercept -must return an annotation and make provider-body changes through that -annotation, including its flattened `extra` fields for provider-specific data. -Relay rejects a changed raw body or missing annotation at the offending -intercept before invoking later middleware or creating an LLM lifecycle. - -```mermaid -flowchart TD - INPUT["Original LlmRequest"] --> CODEC{"Request codec active?"} - - CODEC -->|No| RAWCHAIN["Run intercept chain"] - RAWCHAIN --> RAWPROVIDER["Provider receives outcome.request"] - - CODEC -->|Yes| DECODE["Decode content into annotated_request"] - DECODE --> INTERCEPT["Invoke next intercept"] - INTERCEPT --> CHECKANN{"Annotation returned?"} - CHECKANN -->|No| FAIL["Fail before lifecycle"] - CHECKANN -->|Yes| CHECKRAW{"request.content unchanged?"} - CHECKRAW -->|No| FAIL - CHECKRAW -->|Yes| MORE{"More intercepts?"} - MORE -->|Yes| INTERCEPT - MORE -->|No| ENCODE["Encode final annotated_request"] - ENCODE --> HEADERS["Apply final request.headers"] - HEADERS --> PROVIDER["Provider receives one resolved LlmRequest"] -``` - -Python callbacks return `LLMRequestInterceptOutcome`; Rust callbacks return -`LlmRequestInterceptOutcome`; Go callbacks return -`LLMRequestInterceptOutcome`; and Node.js and WebAssembly callbacks return -`{ request, annotated?, pendingMarks? }`, with `categoryProfile` on each -JavaScript pending-mark DTO. The canonical JSON forms retain `pending_marks` -and `category_profile`. Public C callbacks write one owned canonical outcome -JSON string. Native ABI v1 uses one host-owned outcome JSON string. Rust and -Python `grpc-v1` worker SDKs return their canonical outcome type in a -`JsonEnvelope` whose schema is -`nemo.relay.LlmRequestInterceptOutcome@1`. - -The standalone request-intercept helper returns the complete outcome but does -not emit its pending marks because it does not own an LLM lifecycle. 
- -## Managed Lifecycle - -Managed execution runs all effective global and scope-local intercepts before -creating the LLM handle. Each accepted request/annotation pair feeds the next -intercept under the authority rules above, while pending marks append in -middleware order. A breaking -intercept's marks are retained. If any intercept fails or its boundary result -is malformed, Relay discards all accumulated marks and creates no LLM -lifecycle. - -After successful interception, Relay creates the handle and captures one -subscriber snapshot. It emits the LLM start at `T`, every pending mark at -`T + 1µs` in returned order with the LLM UUID as parent, and the LLM end no -earlier than `T + 1µs`. Streaming and non-streaming calls use the same rules. -Pending marks are never added to the provider request, annotated request, -codec input, sanitizer input, or start payload. - -## Migration - -This finalizes unpublished native ABI v1 and `grpc-v1` contracts. Rebuild all -development native plugins and workers. Replace tuple results, split C/Go -outputs, metadata envelopes, and parallel mark-aware registrations with the -canonical outcome and the existing `register_llm_request_intercept` -registration name. 
From 24c9d3838de43ec7236dc6e55b6590f90734fb40 Mon Sep 17 00:00:00 2001 From: Bryan Bednarski Date: Wed, 1 Jul 2026 09:51:16 -0600 Subject: [PATCH 9/9] refactor(llm): finalize request intercept outcome contract - add ergonomic outcome conversions and restore registry documentation - keep JavaScript pending-mark DTOs binding-local - remove the unreleased native ABI contract flag Signed-off-by: Bryan Bednarski --- crates/core/src/api/registry.rs | 8 ++- crates/core/src/bindings.rs | 60 ---------------------- crates/core/src/lib.rs | 2 - crates/core/src/plugin/dynamic/native.rs | 27 ++++------ crates/node/src/callable.rs | 50 ++++++++++++++++-- crates/plugin/src/lib.rs | 15 ------ crates/plugin/tests/typed_callbacks.rs | 62 +++++------------------ crates/types/src/api/llm.rs | 18 +++++++ crates/types/tests/serialization_tests.rs | 33 ++++++++++++ crates/wasm/src/callable.rs | 53 ++++++++++++++++--- 10 files changed, 173 insertions(+), 155 deletions(-) delete mode 100644 crates/core/src/bindings.rs diff --git a/crates/core/src/api/registry.rs b/crates/core/src/api/registry.rs index a4273291d..c7989c7ae 100644 --- a/crates/core/src/api/registry.rs +++ b/crates/core/src/api/registry.rs @@ -547,7 +547,9 @@ global_guardrail_registry_api!( LlmConditionalFn ); global_intercept_registry_api!( - /// Register a global LLM request intercept that can schedule lifecycle marks. + /// Register a global LLM request intercept. + /// Request intercepts can rewrite or annotate the outgoing LLM request and + /// schedule lifecycle marks for the resulting LLM scope. register_llm_request_intercept, /// Deregister a global LLM request intercept. deregister_llm_request_intercept, @@ -652,7 +654,9 @@ scope_guardrail_registry_api!( LlmConditionalFn ); scope_intercept_registry_api!( - /// Register a scope-local LLM request intercept that can schedule lifecycle marks. + /// Register a scope-local LLM request intercept. 
+ /// Request intercepts can rewrite or annotate LLM requests inside the + /// owning scope and schedule lifecycle marks for the resulting LLM scope. scope_register_llm_request_intercept, /// Deregister a scope-local LLM request intercept. scope_deregister_llm_request_intercept, diff --git a/crates/core/src/bindings.rs b/crates/core/src/bindings.rs deleted file mode 100644 index 5329972ab..000000000 --- a/crates/core/src/bindings.rs +++ /dev/null @@ -1,60 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 - -//! Shared data conversions used by language bindings. - -/// JavaScript-specific data transfer objects shared by Node.js and WebAssembly. -pub mod js { - use serde::{Deserialize, Serialize}; - use serde_json::Value as Json; - - use crate::api::event::{CategoryProfile, EventCategory, PendingMarkSpec}; - - /// JavaScript-facing pending mark DTO. - /// - /// JavaScript bindings use camelCase while canonical Relay JSON keeps - /// snake_case field names. - #[derive(Debug, Deserialize, Serialize)] - #[serde(rename_all = "camelCase", deny_unknown_fields)] - pub struct JsPendingMarkSpec { - name: String, - #[serde(default)] - category: Option<EventCategory>, - #[serde(default)] - category_profile: Option<CategoryProfile>, - #[serde(default)] - data: Option<Json>, - #[serde(default)] - metadata: Option<Json>, - } - - impl From<JsPendingMarkSpec> for PendingMarkSpec { - fn from(mark: JsPendingMarkSpec) -> Self { - Self { - name: mark.name, - category: mark.category, - category_profile: mark.category_profile, - data: mark.data, - metadata: mark.metadata, - } - } - } - - impl From<PendingMarkSpec> for JsPendingMarkSpec { - fn from(mark: PendingMarkSpec) -> Self { - Self { - name: mark.name, - category: mark.category, - category_profile: mark.category_profile, - data: mark.data, - metadata: mark.metadata, - } - } - } - - /// Convert canonical pending marks to JavaScript-facing DTOs. 
- #[must_use] - pub fn js_pending_marks(marks: Vec<PendingMarkSpec>) -> Vec<JsPendingMarkSpec> { - marks.into_iter().map(Into::into).collect() - } -} diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index c1532869e..ffd03a4dd 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -54,8 +54,6 @@ //! All middleware is priority-ordered (ascending) and registered by name for //! easy addition and removal at runtime. pub mod api; -#[doc(hidden)] -pub mod bindings; pub mod codec; pub mod config_editor; mod context; diff --git a/crates/core/src/plugin/dynamic/native.rs b/crates/core/src/plugin/dynamic/native.rs index 381d9b8e3..1ba467588 100644 --- a/crates/core/src/plugin/dynamic/native.rs +++ b/crates/core/src/plugin/dynamic/native.rs @@ -15,15 +15,15 @@ use std::task::{Context, Poll}; use chrono::{DateTime, Utc}; use libloading::{Library, Symbol}; use nemo_relay_plugin::{ - NEMO_RELAY_NATIVE_ABI_VERSION, NEMO_RELAY_NATIVE_LLM_INTERCEPT_OUTCOME_CONTRACT_VERSION, - NemoRelayNativeEventSubscriberCb, NemoRelayNativeFreeFn, NemoRelayNativeHostApiV1, - NemoRelayNativeJsonCb, NemoRelayNativeLlmConditionalCb, NemoRelayNativeLlmExecutionCb, - NemoRelayNativeLlmRequestCb, NemoRelayNativeLlmRequestInterceptCb, - NemoRelayNativeLlmStreamExecutionCb, NemoRelayNativeLlmStreamV1, NemoRelayNativePluginContext, - NemoRelayNativePluginEntry, NemoRelayNativePluginV1, NemoRelayNativeScopeHandle, - NemoRelayNativeScopeStack, NemoRelayNativeScopeStackBinding, NemoRelayNativeScopeType, - NemoRelayNativeString, NemoRelayNativeToolConditionalCb, NemoRelayNativeToolExecutionCb, - NemoRelayNativeToolJsonCb, NemoRelayNativeWithScopeStackCb, NemoRelayStatus, + NEMO_RELAY_NATIVE_ABI_VERSION, NemoRelayNativeEventSubscriberCb, NemoRelayNativeFreeFn, + NemoRelayNativeHostApiV1, NemoRelayNativeJsonCb, NemoRelayNativeLlmConditionalCb, + NemoRelayNativeLlmExecutionCb, NemoRelayNativeLlmRequestCb, + NemoRelayNativeLlmRequestInterceptCb, NemoRelayNativeLlmStreamExecutionCb, + NemoRelayNativeLlmStreamV1, 
NemoRelayNativePluginContext, NemoRelayNativePluginEntry, + NemoRelayNativePluginV1, NemoRelayNativeScopeHandle, NemoRelayNativeScopeStack, + NemoRelayNativeScopeStackBinding, NemoRelayNativeScopeType, NemoRelayNativeString, + NemoRelayNativeToolConditionalCb, NemoRelayNativeToolExecutionCb, NemoRelayNativeToolJsonCb, + NemoRelayNativeWithScopeStackCb, NemoRelayStatus, }; use semver::{Version, VersionReq}; use serde_json::{Map, Value as Json}; @@ -401,13 +401,6 @@ fn validate_plugin_descriptor( plugin.struct_size ))); } - if plugin.llm_request_intercept_outcome_contract_version - != NEMO_RELAY_NATIVE_LLM_INTERCEPT_OUTCOME_CONTRACT_VERSION - { - return Err(PluginError::InvalidConfig(format!( - "native plugin '{plugin_id}' returned an incompatible LLM request-intercept outcome contract" - ))); - } if plugin.plugin_kind.is_null() { return Err(PluginError::InvalidConfig(format!( "native plugin '{plugin_id}' returned a null plugin_kind" @@ -603,8 +596,6 @@ fn native_host_api() -> *const NemoRelayNativeHostApiV1 { scope_stack_binding_free: native_scope_stack_binding_free, scope_stack_active: native_scope_stack_active, scope_stack_with_current: native_scope_stack_with_current, - llm_request_intercept_outcome_contract_version: - NEMO_RELAY_NATIVE_LLM_INTERCEPT_OUTCOME_CONTRACT_VERSION, }) as *const _ } diff --git a/crates/node/src/callable.rs b/crates/node/src/callable.rs index 87386fbfb..54cb93bff 100644 --- a/crates/node/src/callable.rs +++ b/crates/node/src/callable.rs @@ -19,13 +19,12 @@ use nemo_relay::api::runtime::{ LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionNextFn, ToolConditionalFn, ToolExecutionNextFn, ToolInterceptFn, ToolSanitizeFn, }; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use serde_json::Value as Json; use tokio_stream::StreamExt; -use nemo_relay::api::event::Event; +use nemo_relay::api::event::{CategoryProfile, Event, EventCategory, PendingMarkSpec}; use nemo_relay::api::llm::{LlmRequest, 
LlmRequestInterceptOutcome}; -pub(crate) use nemo_relay::bindings::js::{JsPendingMarkSpec, js_pending_marks}; use nemo_relay::codec::request::AnnotatedLlmRequest; use nemo_relay::codec::response::AnnotatedLlmResponse; use nemo_relay::codec::traits::{LlmCodec, LlmResponseCodec}; @@ -35,6 +34,51 @@ use crate::convert::{callback_json, record_callback_error}; use crate::promise_call::{JsonNextFn, JsonStreamNextFn, PromiseAwareFn}; use crate::types::JsEvent; +/// JavaScript-facing pending mark DTO. +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "camelCase", deny_unknown_fields)] +pub(crate) struct JsPendingMarkSpec { + name: String, + #[serde(default)] + category: Option, + #[serde(default)] + category_profile: Option, + #[serde(default)] + data: Option, + #[serde(default)] + metadata: Option, +} + +impl From for PendingMarkSpec { + fn from(mark: JsPendingMarkSpec) -> Self { + Self { + name: mark.name, + category: mark.category, + category_profile: mark.category_profile, + data: mark.data, + metadata: mark.metadata, + } + } +} + +impl From for JsPendingMarkSpec { + fn from(mark: PendingMarkSpec) -> Self { + Self { + name: mark.name, + category: mark.category, + category_profile: mark.category_profile, + data: mark.data, + metadata: mark.metadata, + } + } +} + +/// Convert canonical pending marks to JavaScript-facing DTOs. +#[must_use] +pub(crate) fn js_pending_marks(marks: Vec) -> Vec { + marks.into_iter().map(Into::into).collect() +} + fn recv_json_or_null(rx: std::sync::mpsc::Receiver, error_prefix: &str) -> Json { rx.recv().unwrap_or_else(|e| { record_callback_error(format!("{error_prefix}: {e}")); diff --git a/crates/plugin/src/lib.rs b/crates/plugin/src/lib.rs index 5231ce803..662b191a2 100644 --- a/crates/plugin/src/lib.rs +++ b/crates/plugin/src/lib.rs @@ -30,8 +30,6 @@ use serde_json::Map; /// Native plugin ABI version supported by this crate. 
pub const NEMO_RELAY_NATIVE_ABI_VERSION: u32 = 1; -/// Final canonical LLM request-intercept outcome contract version. -pub const NEMO_RELAY_NATIVE_LLM_INTERCEPT_OUTCOME_CONTRACT_VERSION: u32 = 1; /// Status codes returned by stable native ABI functions. #[repr(i32)] @@ -497,8 +495,6 @@ pub struct NemoRelayNativeHostApiV1 { cb: NemoRelayNativeWithScopeStackCb, user_data: *mut c_void, ) -> NemoRelayStatus, - /// Required canonical LLM request-intercept outcome contract version. - pub llm_request_intercept_outcome_contract_version: u32, } // The host API table is immutable after construction. Function pointers and @@ -523,8 +519,6 @@ pub struct NemoRelayNativePluginV1 { pub register: Option, /// Optional plugin-owned state destructor. pub drop: NemoRelayNativePluginDropFn, - /// Required canonical LLM request-intercept outcome contract version. - pub llm_request_intercept_outcome_contract_version: u32, } impl Default for NemoRelayNativePluginV1 { @@ -537,8 +531,6 @@ impl Default for NemoRelayNativePluginV1 { validate: None, register: None, drop: None, - llm_request_intercept_outcome_contract_version: - NEMO_RELAY_NATIVE_LLM_INTERCEPT_OUTCOME_CONTRACT_VERSION, } } } @@ -2667,11 +2659,6 @@ where if host_ref.struct_size < std::mem::size_of::() { return NemoRelayStatus::InvalidArg; } - if host_ref.llm_request_intercept_outcome_contract_version - != NEMO_RELAY_NATIVE_LLM_INTERCEPT_OUTCOME_CONTRACT_VERSION - { - return NemoRelayStatus::InvalidArg; - } let plugin = constructor(); let kind = plugin.plugin_kind().to_owned(); @@ -2692,8 +2679,6 @@ where validate: Some(validate_trampoline::

), register: Some(register_trampoline::

), drop: Some(drop_plugin_state::

), - llm_request_intercept_outcome_contract_version: - NEMO_RELAY_NATIVE_LLM_INTERCEPT_OUTCOME_CONTRACT_VERSION, }; } std::mem::forget(kind_handle); diff --git a/crates/plugin/tests/typed_callbacks.rs b/crates/plugin/tests/typed_callbacks.rs index 8e751289e..b1fba3674 100644 --- a/crates/plugin/tests/typed_callbacks.rs +++ b/crates/plugin/tests/typed_callbacks.rs @@ -15,10 +15,9 @@ use std::sync::{ use nemo_relay_plugin::{ AnnotatedLlmRequest, ConfigDiagnostic, DiagnosticLevel, Event, Json, LlmJsonStream, LlmNext, LlmRequest, LlmRequestInterceptOutcome, LlmStream, LlmStreamNext, - NEMO_RELAY_NATIVE_ABI_VERSION, NEMO_RELAY_NATIVE_LLM_INTERCEPT_OUTCOME_CONTRACT_VERSION, - NativePlugin, NemoRelayNativeEventSubscriberCb, NemoRelayNativeFreeFn, - NemoRelayNativeHostApiV1, NemoRelayNativeJsonCb, NemoRelayNativeLlmConditionalCb, - NemoRelayNativeLlmExecutionCb, NemoRelayNativeLlmRequestCb, + NEMO_RELAY_NATIVE_ABI_VERSION, NativePlugin, NemoRelayNativeEventSubscriberCb, + NemoRelayNativeFreeFn, NemoRelayNativeHostApiV1, NemoRelayNativeJsonCb, + NemoRelayNativeLlmConditionalCb, NemoRelayNativeLlmExecutionCb, NemoRelayNativeLlmRequestCb, NemoRelayNativeLlmRequestInterceptCb, NemoRelayNativeLlmStreamExecutionCb, NemoRelayNativeLlmStreamV1, NemoRelayNativePluginContext, NemoRelayNativePluginV1, NemoRelayNativeScopeHandle, NemoRelayNativeScopeStack, NemoRelayNativeScopeStackBinding, @@ -297,17 +296,17 @@ fn native_abi_v1_struct_sizes_are_self_describing() { #[cfg(target_pointer_width = "64")] { assert_eq!(align_of::(), 8); - assert_eq!(size_of::(), 280); + assert_eq!(size_of::(), 272); assert_eq!( host_api_offsets(), [ 0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, - 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256, 264, 272, + 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256, 264, ] ); assert_eq!(align_of::(), 8); - assert_eq!(size_of::(), 64); - assert_eq!(plugin_offsets(), [0, 8, 16, 24, 32, 40, 48, 
56]); + assert_eq!(size_of::(), 56); + assert_eq!(plugin_offsets(), [0, 8, 16, 24, 32, 40, 48]); assert_eq!(align_of::(), 8); assert_eq!(size_of::(), 40); assert_eq!(stream_offsets(), [0, 8, 16, 24, 32]); @@ -316,24 +315,24 @@ fn native_abi_v1_struct_sizes_are_self_describing() { #[cfg(target_pointer_width = "32")] { assert_eq!(align_of::(), 4); - assert_eq!(size_of::(), 140); + assert_eq!(size_of::(), 136); assert_eq!( host_api_offsets(), [ 0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 64, 68, 72, 76, 80, - 84, 88, 92, 96, 100, 104, 108, 112, 116, 120, 124, 128, 132, 136, + 84, 88, 92, 96, 100, 104, 108, 112, 116, 120, 124, 128, 132, ] ); assert_eq!(align_of::(), 4); - assert_eq!(size_of::(), 32); - assert_eq!(plugin_offsets(), [0, 4, 8, 12, 16, 20, 24, 28]); + assert_eq!(size_of::(), 28); + assert_eq!(plugin_offsets(), [0, 4, 8, 12, 16, 20, 24]); assert_eq!(align_of::(), 4); assert_eq!(size_of::(), 20); assert_eq!(stream_offsets(), [0, 4, 8, 12, 16]); } } -fn host_api_offsets() -> [usize; 35] { +fn host_api_offsets() -> [usize; 34] { [ offset_of!(NemoRelayNativeHostApiV1, abi_version), offset_of!(NemoRelayNativeHostApiV1, struct_size), @@ -402,14 +401,10 @@ fn host_api_offsets() -> [usize; 35] { offset_of!(NemoRelayNativeHostApiV1, scope_stack_binding_free), offset_of!(NemoRelayNativeHostApiV1, scope_stack_active), offset_of!(NemoRelayNativeHostApiV1, scope_stack_with_current), - offset_of!( - NemoRelayNativeHostApiV1, - llm_request_intercept_outcome_contract_version - ), ] } -fn plugin_offsets() -> [usize; 8] { +fn plugin_offsets() -> [usize; 7] { [ offset_of!(NemoRelayNativePluginV1, struct_size), offset_of!(NemoRelayNativePluginV1, plugin_kind), @@ -418,10 +413,6 @@ fn plugin_offsets() -> [usize; 8] { offset_of!(NemoRelayNativePluginV1, validate), offset_of!(NemoRelayNativePluginV1, register), offset_of!(NemoRelayNativePluginV1, drop), - offset_of!( - NemoRelayNativePluginV1, - llm_request_intercept_outcome_contract_version - ), ] } @@ -1100,8 
+1091,6 @@ fn test_host() -> NemoRelayNativeHostApiV1 { scope_stack_binding_free: capture_scope_stack_binding_free, scope_stack_active: true_scope_stack_active, scope_stack_with_current: capture_scope_stack_with_current, - llm_request_intercept_outcome_contract_version: - NEMO_RELAY_NATIVE_LLM_INTERCEPT_OUTCOME_CONTRACT_VERSION, } } @@ -4736,8 +4725,6 @@ fn direct_export_plugin_validates_host_table_and_kind_allocation() { validate: None, register: None, drop: None, - llm_request_intercept_outcome_contract_version: - NEMO_RELAY_NATIVE_LLM_INTERCEPT_OUTCOME_CONTRACT_VERSION, }; assert_eq!( unsafe { nemo_relay_plugin::export_plugin(&bad_host, &mut plugin, CountingPlugin) }, @@ -4754,18 +4741,6 @@ fn direct_export_plugin_validates_host_table_and_kind_allocation() { NemoRelayStatus::InvalidArg ); - let mut incompatible_host = host; - incompatible_host.llm_request_intercept_outcome_contract_version = - NEMO_RELAY_NATIVE_LLM_INTERCEPT_OUTCOME_CONTRACT_VERSION + 1; - assert_eq!( - unsafe { - nemo_relay_plugin::export_plugin(&incompatible_host, &mut plugin, CountingPlugin) - }, - NemoRelayStatus::InvalidArg - ); - assert!(plugin.plugin_kind.is_null()); - assert!(plugin.user_data.is_null()); - *STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = Some(0); assert_eq!( unsafe { nemo_relay_plugin::export_plugin(&host, &mut plugin, CountingPlugin) }, @@ -4960,8 +4935,6 @@ fn exported_entry_symbol_validates_args_before_constructor() { validate: None, register: None, drop: None, - llm_request_intercept_outcome_contract_version: - NEMO_RELAY_NATIVE_LLM_INTERCEPT_OUTCOME_CONTRACT_VERSION, }; assert_eq!( unsafe { constructor_counting_entry(&bad_host, &mut plugin) }, @@ -4988,15 +4961,6 @@ fn exported_entry_symbol_validates_args_before_constructor() { NemoRelayStatus::InvalidArg ); assert_eq!(CONSTRUCTOR_CALLS.load(Ordering::SeqCst), 0); - - let mut incompatible_host = host; - incompatible_host.llm_request_intercept_outcome_contract_version = - 
NEMO_RELAY_NATIVE_LLM_INTERCEPT_OUTCOME_CONTRACT_VERSION + 1; - assert_eq!( - unsafe { constructor_counting_entry(&incompatible_host, &mut plugin) }, - NemoRelayStatus::InvalidArg - ); - assert_eq!(CONSTRUCTOR_CALLS.load(Ordering::SeqCst), 0); } #[test] diff --git a/crates/types/src/api/llm.rs b/crates/types/src/api/llm.rs index e7d34680d..121754818 100644 --- a/crates/types/src/api/llm.rs +++ b/crates/types/src/api/llm.rs @@ -67,3 +67,21 @@ impl LlmRequestInterceptOutcome { self } } + +impl From for LlmRequestInterceptOutcome { + fn from(request: LlmRequest) -> Self { + Self::new(request, None) + } +} + +impl From<(LlmRequest, AnnotatedLlmRequest)> for LlmRequestInterceptOutcome { + fn from((request, annotated_request): (LlmRequest, AnnotatedLlmRequest)) -> Self { + Self::new(request, Some(annotated_request)) + } +} + +impl From<(LlmRequest, Option)> for LlmRequestInterceptOutcome { + fn from((request, annotated_request): (LlmRequest, Option)) -> Self { + Self::new(request, annotated_request) + } +} diff --git a/crates/types/tests/serialization_tests.rs b/crates/types/tests/serialization_tests.rs index cede82fe5..39bb8fbc7 100644 --- a/crates/types/tests/serialization_tests.rs +++ b/crates/types/tests/serialization_tests.rs @@ -138,3 +138,36 @@ fn llm_request_intercept_outcome_round_trips_pending_marks() { serde_json::from_value(encoded).expect("outcome should deserialize"); assert_eq!(decoded, outcome); } + +#[test] +fn llm_request_intercept_outcome_converts_from_request_inputs() { + let request = LlmRequest { + headers: Map::new(), + content: json!({ "prompt": "hello" }), + }; + let annotated_request: AnnotatedLlmRequest = serde_json::from_value(json!({ + "messages": [], + "model": "model" + })) + .expect("annotated request should deserialize"); + + let request_only: LlmRequestInterceptOutcome = request.clone().into(); + assert_eq!( + request_only, + LlmRequestInterceptOutcome::new(request.clone(), None) + ); + + let required_annotation: 
LlmRequestInterceptOutcome = + (request.clone(), annotated_request.clone()).into(); + assert_eq!( + required_annotation, + LlmRequestInterceptOutcome::new(request.clone(), Some(annotated_request.clone())) + ); + + let optional_annotation: LlmRequestInterceptOutcome = + (request.clone(), Some(annotated_request.clone())).into(); + assert_eq!( + optional_annotation, + LlmRequestInterceptOutcome::new(request, Some(annotated_request)) + ); +} diff --git a/crates/wasm/src/callable.rs b/crates/wasm/src/callable.rs index 5fd3d6b00..51a9175d1 100644 --- a/crates/wasm/src/callable.rs +++ b/crates/wasm/src/callable.rs @@ -16,8 +16,7 @@ use std::sync::Arc; use js_sys::Function; #[cfg(target_arch = "wasm32")] use send_wrapper::SendWrapper; -#[cfg(target_arch = "wasm32")] -use serde::Serialize; +use serde::{Deserialize, Serialize}; use serde_json::Value as Json; #[cfg(target_arch = "wasm32")] use tokio_stream::StreamExt; @@ -28,16 +27,13 @@ use wasm_bindgen::JsValue; #[cfg(target_arch = "wasm32")] use wasm_bindgen_futures::JsFuture; -use nemo_relay::api::event::Event; +use nemo_relay::api::event::{CategoryProfile, Event, EventCategory, PendingMarkSpec}; use nemo_relay::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; use nemo_relay::api::runtime::{ EventSubscriberFn, LlmConditionalFn, LlmExecutionNextFn, LlmRequestInterceptFn, LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionNextFn, ToolConditionalFn, ToolExecutionNextFn, ToolInterceptFn, ToolSanitizeFn, }; -#[cfg(any(target_arch = "wasm32", test))] -pub(crate) use nemo_relay::bindings::js::JsPendingMarkSpec; -pub(crate) use nemo_relay::bindings::js::js_pending_marks; use nemo_relay::codec::request::AnnotatedLlmRequest; #[cfg(target_arch = "wasm32")] use nemo_relay::codec::response::AnnotatedLlmResponse; @@ -51,6 +47,51 @@ use crate::convert::{js_callback_to_json, js_to_json, json_to_js}; #[cfg(target_arch = "wasm32")] use crate::types::WasmEvent; +/// JavaScript-facing pending mark DTO. 
+#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "camelCase", deny_unknown_fields)] +pub(crate) struct JsPendingMarkSpec { + name: String, + #[serde(default)] + category: Option, + #[serde(default)] + category_profile: Option, + #[serde(default)] + data: Option, + #[serde(default)] + metadata: Option, +} + +impl From for PendingMarkSpec { + fn from(mark: JsPendingMarkSpec) -> Self { + Self { + name: mark.name, + category: mark.category, + category_profile: mark.category_profile, + data: mark.data, + metadata: mark.metadata, + } + } +} + +impl From for JsPendingMarkSpec { + fn from(mark: PendingMarkSpec) -> Self { + Self { + name: mark.name, + category: mark.category, + category_profile: mark.category_profile, + data: mark.data, + metadata: mark.metadata, + } + } +} + +/// Convert canonical pending marks to JavaScript-facing DTOs. +#[must_use] +pub(crate) fn js_pending_marks(marks: Vec) -> Vec { + marks.into_iter().map(Into::into).collect() +} + /// Extract a human-readable error message from a `JsValue`. /// /// Tries `.as_string()` first (for string errors), then falls back to debug format.