diff --git a/crates/adaptive/src/acg_component.rs b/crates/adaptive/src/acg_component.rs index 7d5ccaddc..3d344866a 100644 --- a/crates/adaptive/src/acg_component.rs +++ b/crates/adaptive/src/acg_component.rs @@ -595,10 +595,25 @@ pub(crate) fn create_acg_llm_request_intercept( plugin: Arc, ) -> LlmRequestInterceptFn { Arc::new(move |_name: &str, request: LlmRequest, annotated| { + let input_content = request.content.clone(); let translated = translate_request(&request, &agent_id, &provider, plugin.as_ref(), &hot_cache) .unwrap_or(request); - Ok((translated, annotated)) + if annotated.is_some() && translated.content != input_content { + let translated_annotated = build_semantic_request_view(&translated) + .map_err(|error| nemo_relay::error::FlowError::Internal(error.to_string()))? + .annotated_request; + return Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + LlmRequest { + headers: translated.headers, + content: input_content, + }, + Some(translated_annotated), + )); + } + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + translated, annotated, + )) }) } diff --git a/crates/adaptive/src/adaptive_hints_intercept.rs b/crates/adaptive/src/adaptive_hints_intercept.rs index 05613aadb..c9f245505 100644 --- a/crates/adaptive/src/adaptive_hints_intercept.rs +++ b/crates/adaptive/src/adaptive_hints_intercept.rs @@ -86,22 +86,24 @@ fn manual_agent_hints(manual: u32, effective_agent_id: &str, scope_depth: usize) } } -fn inject_agent_hints(request: &mut LlmRequest, hints: &AgentHints) { +fn inject_agent_hints( + request: &mut LlmRequest, + annotated: &mut Option, + hints: &AgentHints, +) { let Ok(serialized_hints) = serde_json::to_value(hints) else { return; }; - if let Some(body) = request.content.as_object_mut() { - if !body.contains_key("nvext") { - body.insert( - "nvext".to_string(), - serde_json::Value::Object(serde_json::Map::new()), - ); - } - if let Some(nvext) = body - .get_mut("nvext") - .and_then(|value| value.as_object_mut()) - { + let body = annotated + .as_mut() + .map(|annotated| &mut annotated.extra) + .or_else(|| request.content.as_object_mut()); + if let Some(body) = body { + let nvext = body + .entry("nvext".to_string()) + .or_insert_with(|| serde_json::Value::Object(serde_json::Map::new())); + if let Some(nvext) = nvext.as_object_mut() { nvext.insert("agent_hints".to_string(), serialized_hints.clone()); } } @@ -172,7 +174,9 @@ impl AdaptiveHintsIntercept { pub fn into_request_fn(self) -> LlmRequestInterceptFn { let this = Arc::new(self); Arc::new( - move |_name: &str, mut request: LlmRequest, annotated: Option| { + move |_name: &str, + mut request: LlmRequest, + mut annotated: Option| { let scope_path = extract_scope_path(); let manual_ls = read_manual_latency_sensitivity(); let scope_depth = scope_path.len(); @@ -189,10 +193,12 @@ impl AdaptiveHintsIntercept { ); if let Some(hints) = final_hints { - inject_agent_hints(&mut request, &hints); + inject_agent_hints(&mut request, &mut annotated, &hints); } - Ok((request, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + request, annotated, + )) }, ) } diff --git a/crates/adaptive/tests/integration/runtime_integration_tests.rs b/crates/adaptive/tests/integration/runtime_integration_tests.rs index cdee3c050..d02f4f4d4 100644 --- a/crates/adaptive/tests/integration/runtime_integration_tests.rs +++ b/crates/adaptive/tests/integration/runtime_integration_tests.rs @@ -588,7 +588,7 @@ async fn test_adaptive_plugin_registers_and_passes_calls_through() { }, ) .unwrap(); - assert_eq!(request.content["messages"], json!([])); + assert_eq!(request.request.content["messages"], json!([])); let llm_func: LlmExecutionNextFn = Arc::new(|_req: LlmRequest| Box::pin(async { Ok(json!({"response": "ok"})) })); @@ -721,7 +721,9 @@ impl Plugin for HeaderPlugin { false, Arc::new(|_name, mut request, annotated| { request.headers.insert("x-plugin".into(), json!("set")); - Ok((request, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + request, annotated, + )) }), )?; ctx.register_tool_request_intercept( @@ -806,7 +808,7 @@ async fn test_top_level_plugin_registers_request_and_execution_intercepts() { }, ) .unwrap(); - assert_eq!(request.headers.get("x-plugin"), Some(&json!("set"))); + assert_eq!(request.request.headers.get("x-plugin"), Some(&json!("set"))); let tool_func: ToolExecutionNextFn = Arc::new(|args| Box::pin(async move { Ok(args) })); let tool_result = tool_call_execute( diff --git a/crates/adaptive/tests/unit/acg_component_tests.rs b/crates/adaptive/tests/unit/acg_component_tests.rs index cc4c2f7e7..bbcf4d024 100644 --- a/crates/adaptive/tests/unit/acg_component_tests.rs +++ b/crates/adaptive/tests/unit/acg_component_tests.rs @@ -1087,12 +1087,14 @@ fn acg_component_request_intercept_passes_original_request_and_annotation_when_t plugin, ); - let (translated, returned_annotated) = intercept( + let outcome = intercept( "anthropic", invalid_request.clone(), Some(annotated.clone()), ) .expect("request intercept should pass through"); + let translated = outcome.request; + let returned_annotated = outcome.annotated_request; assert_eq!(translated.content, invalid_request.content); assert_eq!( diff --git a/crates/adaptive/tests/unit/adaptive_hints_intercept_tests.rs b/crates/adaptive/tests/unit/adaptive_hints_intercept_tests.rs index a7efa190e..9d260fac5 100644 --- a/crates/adaptive/tests/unit/adaptive_hints_intercept_tests.rs +++ b/crates/adaptive/tests/unit/adaptive_hints_intercept_tests.rs @@ -194,7 +194,7 @@ fn test_adaptive_hints_intercept_injects_prediction_hints_and_manual_override() stream: None, extra: serde_json::Map::new(), }; - let (request, returned_annotated) = req_fn( + let outcome = req_fn( "model", LlmRequest { headers: serde_json::Map::new(), @@ -203,8 +203,12 @@ fn test_adaptive_hints_intercept_injects_prediction_hints_and_manual_override() Some(annotated.clone()), ) .unwrap(); + let request = outcome.request; + let returned_annotated = outcome.annotated_request; - let body_hints = &request.content["nvext"]["agent_hints"]; + assert_eq!(request.content, serde_json::json!({})); + let returned_annotated = returned_annotated.expect("annotation should be preserved"); + let body_hints = &returned_annotated.extra["nvext"]["agent_hints"]; assert_eq!(body_hints["osl"], serde_json::json!(150)); assert_eq!(body_hints["iat"], serde_json::json!(200)); assert_eq!(body_hints["latency_sensitivity"], serde_json::json!(5.0)); @@ -215,7 +219,11 @@ fn test_adaptive_hints_intercept_injects_prediction_hints_and_manual_override() request.headers.get(AGENT_HINTS_HEADER_KEY).unwrap(), body_hints ); - assert_eq!(returned_annotated, Some(annotated)); + let mut expected_annotated = annotated; + expected_annotated + .extra + .insert("nvext".into(), returned_annotated.extra["nvext"].clone()); + assert_eq!(returned_annotated, expected_annotated); pop_scope( nemo_relay::api::scope::PopScopeParams::builder() @@ -256,7 +264,7 @@ fn test_adaptive_hints_intercept_uses_defaults_and_ignores_poisoned_cache() { })); let req_fn = AdaptiveHintsIntercept::new(hot_cache, "fallback-agent".to_string()).into_request_fn(); - let (request, annotated) = req_fn( + let outcome = req_fn( "model", LlmRequest { headers: serde_json::Map::new(), @@ -265,6 +273,8 @@ fn test_adaptive_hints_intercept_uses_defaults_and_ignores_poisoned_cache() { None, ) .unwrap(); + let request = outcome.request; + let annotated = outcome.annotated_request; assert_eq!( request.headers.get(AGENT_HINTS_HEADER_KEY), Some(&serde_json::to_value(&defaults).unwrap()) @@ -293,7 +303,7 @@ fn test_adaptive_hints_intercept_uses_defaults_and_ignores_poisoned_cache() { }); let poisoned_req_fn = AdaptiveHintsIntercept::new(poisoned_cache, "fallback-agent".to_string()).into_request_fn(); - let (poisoned_request, _) = poisoned_req_fn( + let poisoned_outcome = poisoned_req_fn( "model", LlmRequest { headers: serde_json::Map::new(), @@ -302,6 +312,7 @@ fn test_adaptive_hints_intercept_uses_defaults_and_ignores_poisoned_cache() { None, ) .unwrap(); + let poisoned_request = poisoned_outcome.request; assert!( poisoned_request .headers @@ -344,7 +355,7 @@ fn test_apply_manual_latency_override_and_inject_agent_hints_cover_manual_paths( headers: serde_json::Map::new(), content: serde_json::json!("scalar"), }; - inject_agent_hints(&mut non_object_request, &manual_only); + inject_agent_hints(&mut non_object_request, &mut None, &manual_only); assert_eq!( non_object_request.headers.get(AGENT_HINTS_HEADER_KEY), Some(&serde_json::to_value(&manual_only).unwrap()) diff --git a/crates/adaptive/tests/unit/plugin_component_tests.rs b/crates/adaptive/tests/unit/plugin_component_tests.rs index 854ff21fb..044504d19 100644 --- a/crates/adaptive/tests/unit/plugin_component_tests.rs +++ b/crates/adaptive/tests/unit/plugin_component_tests.rs @@ -365,7 +365,7 @@ async fn adaptive_plugin_registers_runtime_and_rolls_back_registration() { }, ) .unwrap(); - assert!(request.headers.is_empty()); + assert!(request.request.headers.is_empty()); let mut registrations = ctx.into_registrations(); assert_eq!(registrations.len(), 1); diff --git a/crates/adaptive/tests/unit/runtime_features_tests.rs b/crates/adaptive/tests/unit/runtime_features_tests.rs index cf09f84b8..cdae1a8d0 100644 --- a/crates/adaptive/tests/unit/runtime_features_tests.rs +++ b/crates/adaptive/tests/unit/runtime_features_tests.rs @@ -137,7 +137,11 @@ fn assert_llm_request_intercept_registered(name: &str) { name, i32::MAX, false, - Arc::new(|_name, request, annotated| Ok((request, annotated))), + Arc::new(|_name, request, annotated| { + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + request, annotated, + )) + }), ), name, ); @@ -148,7 +152,11 @@ fn assert_llm_request_intercept_absent(name: &str) { name, i32::MAX, false, - Arc::new(|_name, request, annotated| Ok((request, annotated))), + Arc::new(|_name, request, annotated| { + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + request, annotated, + )) + }), ) .unwrap(); deregister_llm_request_intercept(name).unwrap(); @@ -549,7 +557,7 @@ async fn adaptive_hints_feature_registers_request_intercept() { }, ) .unwrap(); - assert!(request.headers.contains_key(AGENT_HINTS_HEADER_KEY)); + assert!(request.request.headers.contains_key(AGENT_HINTS_HEADER_KEY)); let mut registrations = ctx.finish(); rollback_registrations(&mut registrations); @@ -712,7 +720,11 @@ async fn registration_context_registers_all_supported_callback_types() { "adaptive_test_request", 5, false, - Arc::new(|_name, request, annotated| Ok((request, annotated))), + Arc::new(|_name, request, annotated| { + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + request, annotated, + )) + }), ) .unwrap(); ctx.register_llm_execution_intercept( diff --git a/crates/adaptive/tests/unit/runtime_tests.rs b/crates/adaptive/tests/unit/runtime_tests.rs index 06c87f525..77231ccdb 100644 --- a/crates/adaptive/tests/unit/runtime_tests.rs +++ b/crates/adaptive/tests/unit/runtime_tests.rs @@ -631,7 +631,7 @@ async fn adaptive_runtime_bind_scope_requires_registration_and_passes_through_wi let translated = llm_request_intercepts("anthropic", request.clone()) .expect("request intercept chain should pass through when no hot-cache state exists"); - assert_eq!(translated.content, request.content); + assert_eq!(translated.request.content, request.content); pop_scope(PopScopeParams::builder().handle_uuid(&scope.uuid).build()) .expect("scope pop should succeed"); } diff --git a/crates/core/src/api/llm.rs b/crates/core/src/api/llm.rs index 0af799d04..6ac9a895a 100644 --- a/crates/core/src/api/llm.rs +++ b/crates/core/src/api/llm.rs @@ -3,17 +3,19 @@ use std::sync::Arc; -use chrono::{DateTime, Utc}; +use chrono::{DateTime, TimeDelta, Utc}; use serde::{Deserialize, Serialize}; use serde_json::json; use typed_builder::TypedBuilder; use uuid::Uuid; +use crate::api::event::{BaseEvent, Event, MarkEvent, PendingMarkSpec}; use crate::api::runtime::NemoRelayContextState; use crate::api::runtime::current_scope_stack; use crate::api::runtime::global_context; use crate::api::runtime::{ - LlmCollectorFn, LlmExecutionNextFn, LlmFinalizerFn, LlmJsonStream, LlmStreamExecutionNextFn, + EventSubscriberFn, LlmCollectorFn, LlmExecutionNextFn, LlmFinalizerFn, LlmJsonStream, + LlmStreamExecutionNextFn, }; use crate::api::scope::event; use crate::api::scope::{EmitMarkEventParams, ScopeHandle}; @@ -28,7 +30,7 @@ use crate::error::{FlowError, Result}; use crate::json::Json; use crate::stream::LlmStreamWrapper; -pub use nemo_relay_types::api::llm::{LlmAttributes, LlmRequest}; +pub use nemo_relay_types::api::llm::{LlmAttributes, LlmRequest, LlmRequestInterceptOutcome}; /// Runtime-owned handle identifying an active or completed LLM call. #[derive(Debug, Clone, Serialize, Deserialize, TypedBuilder)] @@ -260,20 +262,39 @@ fn emit_llm_start( request_codec: Option<&dyn LlmCodec>, ) -> Result<()> { ensure_runtime_owner()?; - let (entries, subscribers) = { + let subscribers = { + let scope_stack = current_scope_stack(); + let scope_guard = scope_stack.read().expect("scope stack lock poisoned"); + snapshot_event_subscribers(scope_guard.collect_scope_local_subscribers())? + }; + emit_llm_start_with_subscribers( + handle, + request, + annotated_request, + request_codec, + &subscribers, + ) +} + +fn emit_llm_start_with_subscribers( + handle: &LlmHandle, + request: &LlmRequest, + annotated_request: Option>, + request_codec: Option<&dyn LlmCodec>, + subscribers: &[EventSubscriberFn], +) -> Result<()> { + ensure_runtime_owner()?; + let entries = { let scope_stack = current_scope_stack(); let scope_guard = scope_stack.read().expect("scope stack lock poisoned"); let scope_locals = scope_guard.collect_scope_local_registries(|registries| { ®istries.llm_sanitize_request_guardrails }); - let scope_subscribers = scope_guard.collect_scope_local_subscribers(); - let subscribers = snapshot_event_subscribers(scope_subscribers)?; let context = global_context(); let state = context .read() .map_err(|error| FlowError::Internal(error.to_string()))?; - let entries = state.llm_sanitize_request_entries(&scope_locals); - (entries, subscribers) + state.llm_sanitize_request_entries(&scope_locals) }; let sanitized_request = NemoRelayContextState::llm_sanitize_request_snapshot_chain(request.clone(), &entries); @@ -294,7 +315,34 @@ fn emit_llm_start( .map_err(|error| FlowError::Internal(error.to_string()))?; state.build_llm_start_event(handle, Some(input), annotated_request) }; - NemoRelayContextState::emit_event(&event, &subscribers); + NemoRelayContextState::emit_event(&event, subscribers); + Ok(()) +} + +fn emit_pending_request_marks( + handle: &LlmHandle, + marks: Vec, + subscribers: &[EventSubscriberFn], +) -> Result<()> { + if marks.is_empty() { + return Ok(()); + } + ensure_runtime_owner()?; + let timestamp = handle.started_at + TimeDelta::microseconds(1); + for mark in marks { + let event = Event::Mark(MarkEvent::new( + BaseEvent::builder() + .name(mark.name) + .parent_uuid(handle.uuid) + .timestamp(timestamp) + .data_opt(mark.data) + .metadata_opt(mark.metadata) + .build(), + mark.category, + mark.category_profile, + )); + NemoRelayContextState::emit_event(&event, subscribers); + } Ok(()) } @@ -387,12 +435,14 @@ pub fn llm_call_end(params: LlmCallEndParams<'_>) -> Result<()> { response_codec_errors_fatal: true, attach_estimated_cost: false, }, + None, ) } fn llm_call_end_with_behavior( params: LlmCallEndParams<'_>, behavior: LlmCallEndBehavior, + lifecycle_subscribers: Option<&[EventSubscriberFn]>, ) -> Result<()> { let LlmCallEndParams { handle, @@ -411,7 +461,10 @@ fn llm_call_end_with_behavior( ®istries.llm_sanitize_response_guardrails }); let scope_subscribers = scope_guard.collect_scope_local_subscribers(); - let subscribers = snapshot_event_subscribers(scope_subscribers)?; + let subscribers = match lifecycle_subscribers { + Some(subscribers) => subscribers.to_vec(), + None => snapshot_event_subscribers(scope_subscribers)?, + }; let context = global_context(); let state = context .read() @@ -471,13 +524,20 @@ fn llm_call_end_with_behavior( } } -fn emit_llm_end_without_output(handle: &LlmHandle, metadata: Option) -> Result<()> { +fn emit_llm_end_without_output( + handle: &LlmHandle, + metadata: Option, + lifecycle_subscribers: Option<&[EventSubscriberFn]>, +) -> Result<()> { ensure_runtime_owner()?; let (event, subscribers) = { let scope_stack = current_scope_stack(); let scope_guard = scope_stack.read().expect("scope stack lock poisoned"); let scope_subscribers = scope_guard.collect_scope_local_subscribers(); - let subscribers = snapshot_event_subscribers(scope_subscribers)?; + let subscribers = match lifecycle_subscribers { + Some(subscribers) => subscribers.to_vec(), + None => snapshot_event_subscribers(scope_subscribers)?, + }; let context = global_context(); let state = context .read() @@ -587,7 +647,7 @@ pub async fn llm_call_execute(params: LlmCallExecuteParams) -> Result { } let request_codec = codec.clone(); - let (intercepted_request, annotated_request) = + let (intercepted_request, annotated_request, pending_marks) = run_request_intercepts_with_codec(&name, request, codec)?; let handle = create_llm_handle( @@ -600,12 +660,19 @@ pub async fn llm_call_execute(params: LlmCallExecuteParams) -> Result { .model_name_opt(model_name) .build(), )?; - emit_llm_start( + let lifecycle_subscribers = { + let scope_stack = current_scope_stack(); + let scope_guard = scope_stack.read().expect("scope stack lock poisoned"); + snapshot_event_subscribers(scope_guard.collect_scope_local_subscribers())? + }; + emit_llm_start_with_subscribers( &handle, &intercepted_request, annotated_request.clone(), request_codec.as_deref(), + &lifecycle_subscribers, )?; + emit_pending_request_marks(&handle, pending_marks, &lifecycle_subscribers)?; let execution = { let scope_stack = current_scope_stack(); @@ -633,13 +700,15 @@ pub async fn llm_call_execute(params: LlmCallExecuteParams) -> Result { response_codec_errors_fatal: false, attach_estimated_cost: true, }, + Some(&lifecycle_subscribers), )?; Ok(response) } Err(error) => { let end_metadata = metadata_with_otel_status(metadata, "ERROR", Some(error.to_string())); - let _ = emit_llm_end_without_output(&handle, end_metadata); + let _ = + emit_llm_end_without_output(&handle, end_metadata, Some(&lifecycle_subscribers)); Err(error) } } @@ -743,7 +812,7 @@ pub async fn llm_stream_call_execute(params: LlmStreamCallExecuteParams) -> Resu } let request_codec = codec.clone(); - let (intercepted_request, annotated_request) = + let (intercepted_request, annotated_request, pending_marks) = run_request_intercepts_with_codec(&name, request, codec)?; let handle = create_llm_handle( @@ -756,12 +825,19 @@ pub async fn llm_stream_call_execute(params: LlmStreamCallExecuteParams) -> Resu .model_name_opt(model_name) .build(), )?; - emit_llm_start( + let lifecycle_subscribers = { + let scope_stack = current_scope_stack(); + let scope_guard = scope_stack.read().expect("scope stack lock poisoned"); + snapshot_event_subscribers(scope_guard.collect_scope_local_subscribers())? + }; + emit_llm_start_with_subscribers( &handle, &intercepted_request, annotated_request, request_codec.as_deref(), + &lifecycle_subscribers, )?; + emit_pending_request_marks(&handle, pending_marks, &lifecycle_subscribers)?; let execution = { let scope_stack = current_scope_stack(); @@ -778,21 +854,22 @@ pub async fn llm_stream_call_execute(params: LlmStreamCallExecuteParams) -> Resu match execution(intercepted_request).await { Ok(raw_stream) => { - let wrapper = LlmStreamWrapper::new( + let wrapper = LlmStreamWrapper::new_managed( raw_stream, handle, collector, finalizer, - data, metadata, response_codec, + lifecycle_subscribers, ); Ok(Box::pin(wrapper) as LlmJsonStream) } Err(error) => { let end_metadata = metadata_with_otel_status(metadata, "ERROR", Some(error.to_string())); - let _ = emit_llm_end_without_output(&handle, end_metadata); + let _ = + emit_llm_end_without_output(&handle, end_metadata, Some(&lifecycle_subscribers)); Err(error) } } @@ -817,7 +894,14 @@ pub async fn llm_stream_call_execute(params: LlmStreamCallExecuteParams) -> Resu /// # Notes /// Conditional guardrails, codecs, and execution intercepts are not run by /// this helper. -pub fn llm_request_intercepts(name: &str, request: LlmRequest) -> Result { +/// Run the LLM request-intercept chain and return its complete outcome. +/// +/// This helper does not emit the returned marks because it does not own an LLM +/// lifecycle. Callers must attach them to the lifecycle they own. +pub fn llm_request_intercepts( + name: &str, + request: LlmRequest, +) -> Result { ensure_runtime_owner()?; let entries = { let scope_stack = current_scope_stack(); @@ -830,10 +914,9 @@ pub fn llm_request_intercepts(name: &str, request: LlmRequest) -> Result Result> + /// - Third argument: Optional normalized request annotation to carry forward. /// /// # Returns -/// A [`Result`] containing the transformed request and optional annotation. +/// A [`Result`] containing the canonical request-intercept outcome. +/// Without a request codec, the returned request is authoritative. With a +/// request codec, its headers remain writable while its content must remain +/// unchanged; provider-body edits must be returned through the required +/// annotation. /// /// # Errors /// The callback can return any [`FlowError`](crate::error::FlowError) to abort /// the request-intercept chain. pub type LlmRequestInterceptFn = Arc< - dyn Fn( - &str, - LlmRequest, - Option, - ) -> Result<(LlmRequest, Option)> + dyn Fn(&str, LlmRequest, Option) -> Result + Send + Sync, >; diff --git a/crates/core/src/api/runtime/state.rs b/crates/core/src/api/runtime/state.rs index 70276786d..eb6a57bf9 100644 --- a/crates/core/src/api/runtime/state.rs +++ b/crates/core/src/api/runtime/state.rs @@ -1027,6 +1027,8 @@ impl NemoRelayContextState { /// - `annotated`: Optional normalized request annotation to carry through /// the chain. /// - `entries`: Intercept snapshots to evaluate. + /// - `codec_active`: Whether request content is owned by the normalized + /// annotation and must remain unchanged by callbacks. /// /// # Returns /// A [`Result`] containing the final request and annotation pair. @@ -1042,19 +1044,38 @@ impl NemoRelayContextState { request: LlmRequest, annotated: Option, entries: &[Intercept], - ) -> crate::error::Result<(LlmRequest, Option)> { + codec_active: bool, + ) -> crate::error::Result { let mut request_value = request; let mut annotated_value = annotated; + let mut pending_marks = Vec::new(); for entry in entries { - let (new_request, new_annotated) = - (entry.payload.callable)(name, request_value, annotated_value)?; - request_value = new_request; - annotated_value = new_annotated; + let input_content = request_value.content.clone(); + let outcome = (entry.payload.callable)(name, request_value, annotated_value)?; + if codec_active && outcome.request.content != input_content { + return Err(crate::error::FlowError::InvalidArgument(format!( + "LLM request intercept '{}' changed request.content while a request codec is active; modify annotated_request instead", + entry.name + ))); + } + if codec_active && outcome.annotated_request.is_none() { + return Err(crate::error::FlowError::InvalidArgument(format!( + "LLM request intercept '{}' omitted annotated_request while a request codec is active", + entry.name + ))); + } + request_value = outcome.request; + annotated_value = outcome.annotated_request; + pending_marks.extend(outcome.pending_marks); if entry.payload.break_chain { break; } } - Ok((request_value, annotated_value)) + Ok(crate::api::llm::LlmRequestInterceptOutcome { + request: request_value, + annotated_request: annotated_value, + pending_marks, + }) } /// Build the composed non-streaming LLM execution continuation chain. @@ -1124,11 +1145,7 @@ impl NemoRelayContextState { fn end_timestamp_after(started_at: chrono::DateTime) -> chrono::DateTime { let now = Utc::now(); - if now > started_at { - now - } else { - started_at + Duration::microseconds(1) - } + std::cmp::max(now, started_at + Duration::microseconds(1)) } impl Default for NemoRelayContextState { diff --git a/crates/core/src/api/shared.rs b/crates/core/src/api/shared.rs index 861cd41c2..28965542c 100644 --- a/crates/core/src/api/shared.rs +++ b/crates/core/src/api/shared.rs @@ -74,8 +74,11 @@ pub(crate) fn run_request_intercepts_with_codec( name: &str, request: LlmRequest, codec: Option>, -) -> Result<(LlmRequest, Option>)> { - let original = request.clone(); +) -> Result<( + LlmRequest, + Option>, + Vec, +)> { let annotated = match &codec { Some(codec) => Some(codec.decode(&request)?), None => None, @@ -94,18 +97,23 @@ pub(crate) fn run_request_intercepts_with_codec( state.llm_request_intercept_entries(&scope_locals) }; - let (intercepted_request, intercepted_annotated) = + let outcome = crate::api::runtime::NemoRelayContextState::llm_request_intercepts_snapshot_chain( - name, request, annotated, &entries, + name, + request, + annotated, + &entries, + codec.is_some(), )?; + let pending_marks = outcome.pending_marks; - match (codec, intercepted_annotated) { + match (codec, outcome.annotated_request) { (Some(codec), Some(annotated)) => { - let mut encoded = codec.encode(&annotated, &original)?; - encoded.headers = intercepted_request.headers; - Ok((encoded, Some(Arc::new(annotated)))) + let mut encoded = codec.encode(&annotated, &outcome.request)?; + encoded.headers = outcome.request.headers; + Ok((encoded, Some(Arc::new(annotated)), pending_marks)) } - _ => Ok((intercepted_request, None)), + (_, annotated) => Ok((outcome.request, annotated.map(Arc::new), pending_marks)), } } diff --git a/crates/core/src/plugin/dynamic/native.rs b/crates/core/src/plugin/dynamic/native.rs index 0f1c14444..1ba467588 100644 --- a/crates/core/src/plugin/dynamic/native.rs +++ b/crates/core/src/plugin/dynamic/native.rs @@ -32,7 +32,7 @@ use tokio::runtime::Runtime; use tokio_stream::{Stream, StreamExt}; use crate::api::event::Event; -use crate::api::llm::LlmRequest; +use crate::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; use crate::api::runtime::{ EventSubscriberFn, LlmConditionalFn, LlmExecutionFn, LlmExecutionNextFn, LlmJsonStream, LlmRequestInterceptFn, LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionFn, @@ -48,7 +48,6 @@ use crate::api::scope::{ EmitMarkEventParams, PopScopeParams, PushScopeParams, ScopeAttributes, ScopeHandle, ScopeType, }; use crate::api::scope::{event as emit_scope_mark, get_handle, pop_scope, push_scope}; -use crate::codec::request::AnnotatedLlmRequest; use crate::error::{FlowError, Result as FlowResult}; use crate::plugin::{ ConfigDiagnostic, DiagnosticLevel, Plugin, PluginError, PluginRegistrationContext, @@ -1712,16 +1711,14 @@ fn wrap_llm_request_intercept_fn( } None => ptr::null_mut(), }; - let mut out_request = ptr::null_mut(); - let mut out_annotated = ptr::null_mut(); + let mut out_outcome = ptr::null_mut(); let status = unsafe { cb( user_data.ptr, name_string, request_string, annotated_string, - &mut out_request, - &mut out_annotated, + &mut out_outcome, ) }; unsafe { @@ -1731,39 +1728,23 @@ fn wrap_llm_request_intercept_fn( } if status != NemoRelayStatus::Ok { unsafe { - native_string_free(out_request); - native_string_free(out_annotated); + native_string_free(out_outcome); } return Err(flow_error_from_status( status, "native LLM request intercept failed", )); } - let request_json = json_from_native_string( - out_request, - "native LLM request intercept returned null request", + let outcome_json = json_from_native_string( + out_outcome, + "native LLM request intercept returned null outcome", ); - let annotated_json = if out_annotated.is_null() { - Ok(None) - } else { - json_from_native_string(out_annotated, "invalid annotated request").map(Some) - }; unsafe { - native_string_free(out_request); - native_string_free(out_annotated); + native_string_free(out_outcome); } - let request_json = request_json?; - let annotated_json = annotated_json?; - let request: LlmRequest = serde_json::from_value(request_json) - .map_err(|err| FlowError::Internal(format!("invalid LLM request JSON: {err}")))?; - let annotated = annotated_json - .map(|annotated_json| { - serde_json::from_value::(annotated_json).map_err(|err| { - FlowError::Internal(format!("invalid annotated request JSON: {err}")) - }) - }) - .transpose()?; - Ok((request, annotated)) + serde_json::from_value::(outcome_json?).map_err(|err| { + FlowError::Internal(format!("invalid LLM request intercept outcome JSON: {err}")) + }) }) } diff --git a/crates/core/src/plugin/dynamic/worker.rs b/crates/core/src/plugin/dynamic/worker.rs index 2dfed5f4b..c06c26cd0 100644 --- a/crates/core/src/plugin/dynamic/worker.rs +++ b/crates/core/src/plugin/dynamic/worker.rs @@ -1195,7 +1195,7 @@ impl WorkerPluginCallback { model_name: &str, request: LlmRequest, annotated: Option, - ) -> FlowResult<(LlmRequest, Option)> { + ) -> FlowResult { let invoke = self.base_request( registration_name, RegistrationSurface::LlmRequestIntercept, @@ -1210,26 +1210,18 @@ impl WorkerPluginCallback { let response = self.invoke_blocking(invoke)?; match response.result { Some(invoke_response_result::Result::LlmRequest(result)) => { - let request = required_envelope(result.request, "llm request intercept request")?; - let request = decode_json_envelope::(&request).map_err(|err| { - FlowError::Internal(format!("worker returned invalid LLM request: {err}")) - })?; - let annotated = if result.has_annotated_request { - let envelope = required_envelope( - result.annotated_request, - "llm request intercept annotated request", - )?; - Some( - decode_json_envelope::(&envelope).map_err(|err| { - FlowError::Internal(format!( - "worker returned invalid annotated LLM request: {err}" - )) - })?, - ) - } else { - None - }; - Ok((request, annotated)) + let outcome = required_envelope(result.outcome, "llm request intercept outcome")?; + if outcome.schema != "nemo.relay.LlmRequestInterceptOutcome@1" { + return Err(FlowError::Internal(format!( + "worker returned unsupported LLM request intercept outcome schema: {}", + outcome.schema + ))); + } + decode_json_envelope(&outcome).map_err(|err| { + FlowError::Internal(format!( + "worker returned invalid LLM request intercept outcome: {err}" + )) + }) } Some(invoke_response_result::Result::Error(error)) => Err(worker_error_to_flow(error)), _ => Err(FlowError::Internal( diff --git a/crates/core/src/stream.rs b/crates/core/src/stream.rs index c2dde21f5..239228c83 100644 --- a/crates/core/src/stream.rs +++ b/crates/core/src/stream.rs @@ -35,7 +35,7 @@ use crate::api::event::{BaseEvent, MarkEvent}; use crate::api::llm::LlmHandle; use crate::api::runtime::NemoRelayContextState; use crate::api::runtime::global_context; -use crate::api::runtime::{ScopeStackHandle, current_scope_stack}; +use crate::api::runtime::{EventSubscriberFn, ScopeStackHandle, current_scope_stack}; use crate::api::shared::metadata_with_otel_status; use crate::codec::response::{AnnotatedLlmResponse, attach_estimated_cost_for_provider}; use crate::codec::traits::LlmResponseCodec; @@ -63,6 +63,7 @@ pub struct LlmStreamWrapper { finalizer: Option Json + Send>>, response_codec: Option>, metadata: Option, + subscribers: Vec, chunk_index: u64, ended: bool, } @@ -97,6 +98,36 @@ impl LlmStreamWrapper { _data: Option, metadata: Option, response_codec: Option>, + ) -> Self { + let subscribers = { + let scope_stack = current_scope_stack(); + let scope_guard = scope_stack.read().expect("scope stack lock poisoned"); + let scope_subscribers = scope_guard.collect_scope_local_subscribers(); + let context = global_context(); + context + .read() + .map(|state| state.collect_event_subscribers(&scope_subscribers)) + .unwrap_or_default() + }; + Self::new_managed( + inner, + handle, + collector, + finalizer, + metadata, + response_codec, + subscribers, + ) + } + + pub(crate) fn new_managed( + inner: Pin> + Send>>, + handle: LlmHandle, + collector: Box Result<()> + Send>, + finalizer: Box Json + Send>, + metadata: Option, + response_codec: Option>, + subscribers: Vec, ) -> Self { Self { inner, @@ -106,6 +137,7 @@ impl LlmStreamWrapper { finalizer: Some(finalizer), response_codec, metadata, + subscribers, chunk_index: 0, ended: false, } @@ -155,19 +187,17 @@ impl LlmStreamWrapper { let ss_guard = self.scope_stack.read().expect("scope stack lock poisoned"); let sl = ss_guard.collect_scope_local_registries(|r| &r.llm_sanitize_response_guardrails); - let sl_subs = ss_guard.collect_scope_local_subscribers(); let ctx = global_context(); let state = ctx.read(); match state { Ok(state) => { - let subscribers = state.collect_event_subscribers(&sl_subs); let entries = state.llm_sanitize_response_entries(&sl); - Some((entries, subscribers)) + Some(entries) } Err(_) => None, } }; - let Some((entries, subscribers)) = snapshot else { + let Some(entries) = snapshot else { return; }; let sanitized = @@ -197,7 +227,7 @@ impl LlmStreamWrapper { } }; if let Some(event) = event_snapshot { - NemoRelayContextState::emit_event(&event, &subscribers); + NemoRelayContextState::emit_event(&event, &self.subscribers); } } @@ -205,15 +235,10 @@ impl LlmStreamWrapper { fn emit_chunk_mark(&self, chunk_index: u64, raw_chunk: &Json) { let data = llm_chunk_mark_data(chunk_index, raw_chunk); let event_snapshot = { - let Ok(ss_guard) = self.scope_stack.read() else { - return; - }; - let sl_subs = ss_guard.collect_scope_local_subscribers(); let ctx = global_context(); let state = ctx.read(); match state { Ok(state) => { - let subscribers = state.collect_event_subscribers(&sl_subs); let event = state.create_event(MarkEvent::new( BaseEvent::builder() .name("llm.chunk") @@ -223,13 +248,13 @@ impl LlmStreamWrapper { None, None, )); - Some((event, subscribers)) + Some(event) } Err(_) => None, } }; - if let Some((event, subscribers)) = event_snapshot { - NemoRelayContextState::emit_event(&event, &subscribers); + if let Some(event) = event_snapshot { + NemoRelayContextState::emit_event(&event, &self.subscribers); } } } diff --git a/crates/core/tests/fixtures/native_plugin/src/lib.rs b/crates/core/tests/fixtures/native_plugin/src/lib.rs index 9fc58bd41..cb99cb66c 100644 --- a/crates/core/tests/fixtures/native_plugin/src/lib.rs +++ b/crates/core/tests/fixtures/native_plugin/src/lib.rs @@ -5,10 +5,10 @@ use std::ffi::c_void; use std::ptr; use nemo_relay_plugin::{ - ConfigDiagnostic, DiagnosticLevel, Event, Json, LlmJsonStream, LlmRequest, - NemoRelayNativeHostApiV1, NemoRelayNativePluginContext, NemoRelayNativePluginV1, - NemoRelayNativeString, NemoRelayStatus, NativePlugin, PluginContext, PluginRuntime, - ScopeCategory, ScopeType, + CategoryProfile, ConfigDiagnostic, DiagnosticLevel, Event, EventCategory, Json, LlmJsonStream, + LlmRequest, LlmRequestInterceptOutcome, NemoRelayNativeHostApiV1, + NemoRelayNativePluginContext, NemoRelayNativePluginV1, NemoRelayNativeString, NemoRelayStatus, + NativePlugin, PendingMarkSpec, PluginContext, PluginRuntime, ScopeCategory, ScopeType, }; use serde_json::{Map, json}; @@ -119,9 +119,21 @@ impl NativePlugin for FixtureNativePlugin { 0, false, |_name, request, annotated| { - Ok(( + Ok(LlmRequestInterceptOutcome::new( mark_llm_request(request, "native_plugin_llm_request_intercept"), annotated, + ) + .with_pending_mark( + PendingMarkSpec::builder() + .name("fixture.native.llm_request.mark") + .category(EventCategory::custom()) + .category_profile(CategoryProfile { + subtype: Some("fixture.native.pending".into()), + ..CategoryProfile::default() + }) + .data(json!({ "source": "native_request_intercept" })) + .metadata(json!({ "fixture": true })) + .build(), )) }, )?; diff --git a/crates/core/tests/fixtures/worker_plugin/src/main.rs b/crates/core/tests/fixtures/worker_plugin/src/main.rs index c6e1e2509..6291e1dc4 100644 --- a/crates/core/tests/fixtures/worker_plugin/src/main.rs +++ b/crates/core/tests/fixtures/worker_plugin/src/main.rs @@ -5,7 +5,9 @@ use nemo_relay_worker::{ JsonStream, LlmNext, LlmStreamNext, PluginContext, ScopeType, ToolNext, WorkerPlugin, WorkerSdkError, serve_plugin, }; -use nemo_relay_worker::{ConfigDiagnostic, DiagnosticLevel, Json, LlmRequest}; +use nemo_relay_worker::{ + ConfigDiagnostic, DiagnosticLevel, Json, LlmRequest, PendingMarkSpec, +}; use serde_json::json; struct FixtureWorkerPlugin; @@ -182,13 +184,29 @@ impl WorkerPlugin for FixtureWorkerPlugin { "fixture LLM request error requested".into(), )); } - let annotated = annotated.map(|mut annotated| { - annotated - .extra - .insert("worker_plugin_annotated_request".into(), json!(true)); - annotated - }); - Ok((mark_llm_request(request, "worker_plugin_llm_request_intercept"), annotated)) + let (request, annotated) = match annotated { + Some(mut annotated) => { + annotated + .extra + .insert("worker_plugin_annotated_request".into(), json!(true)); + (request, Some(annotated)) + } + None => ( + mark_llm_request(request, "worker_plugin_llm_request_intercept"), + None, + ), + }; + Ok(nemo_relay_worker::LlmRequestInterceptOutcome::new( + request, + annotated, + ) + .with_pending_mark( + PendingMarkSpec::builder() + .name("fixture.worker.llm_request.mark") + .data(json!({ "source": "worker_request_intercept" })) + .metadata(json!({ "fixture": true })) + .build(), + )) }, ); ctx.register_llm_execution_intercept( diff --git a/crates/core/tests/integration/api_surface_tests.rs b/crates/core/tests/integration/api_surface_tests.rs index e751c15b3..3bc8538a0 100644 --- a/crates/core/tests/integration/api_surface_tests.rs +++ b/crates/core/tests/integration/api_surface_tests.rs @@ -423,7 +423,11 @@ fn test_global_registry_and_subscriber_wrappers_cover_success_and_duplicates() { "llm-request", 1, false, - Arc::new(|_name, request, annotated| Ok((request, annotated))), + Arc::new(|_name, request, annotated| { + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + request, annotated, + )) + }), ) .unwrap(); assert!(deregister_llm_request_intercept("llm-request").unwrap()); @@ -612,7 +616,11 @@ fn test_scope_registry_and_subscriber_wrappers_cover_success_duplicates_and_miss "llm-request", 1, false, - Arc::new(|_name, request, annotated| Ok((request, annotated))), + Arc::new(|_name, request, annotated| { + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + request, annotated, + )) + }), ) .unwrap(); assert!(scope_deregister_llm_request_intercept(&scope.uuid, "llm-request").unwrap()); @@ -931,7 +939,9 @@ async fn test_llm_api_emits_sanitized_events_and_covers_error_paths() { false, Arc::new(|_name, mut request, annotated| { request.headers.insert("x-intercepted".into(), json!(true)); - Ok((request, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + request, annotated, + )) }), ) .unwrap(); @@ -940,7 +950,10 @@ async fn test_llm_api_emits_sanitized_events_and_covers_error_paths() { make_llm_request(json!({"messages": [{"role": "user", "content": "hello"}]})), ) .unwrap(); - assert_eq!(intercepted.headers.get("x-intercepted"), Some(&json!(true))); + assert_eq!( + intercepted.request.headers.get("x-intercepted"), + Some(&json!(true)) + ); deregister_llm_request_intercept("llm-request").unwrap(); register_llm_conditional_execution_guardrail( diff --git a/crates/core/tests/integration/middleware_tests.rs b/crates/core/tests/integration/middleware_tests.rs index 612c5a83d..cf9d29e20 100644 --- a/crates/core/tests/integration/middleware_tests.rs +++ b/crates/core/tests/integration/middleware_tests.rs @@ -14,12 +14,14 @@ use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; use std::sync::{Arc, Mutex}; use futures::StreamExt; -use nemo_relay::api::event::{Event, ScopeCategory}; -use nemo_relay::api::llm::LlmRequest; +use nemo_relay::api::event::{ + CategoryProfile, Event, EventCategory, PendingMarkSpec, ScopeCategory, +}; use nemo_relay::api::llm::{ LlmCallExecuteParams, LlmStreamCallExecuteParams, llm_call_execute, llm_request_intercepts, llm_stream_call_execute, }; +use nemo_relay::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; use nemo_relay::api::registry::{ deregister_llm_conditional_execution_guardrail, deregister_llm_execution_intercept, deregister_llm_request_intercept, deregister_llm_sanitize_request_guardrail, @@ -1700,13 +1702,17 @@ fn test_llm_request_intercept_registry_mutations_apply_to_later_calls() { Arc::new(move |_, request, annotated| { record_middleware_callback(&tracked, "llm_request_late"); assert_middleware_callback_locks_are_free(); - Ok((request, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + request, annotated, + )) }), ) .unwrap(); } - Ok((request, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + request, annotated, + )) }), ) .unwrap(); @@ -1719,7 +1725,7 @@ fn test_llm_request_intercept_registry_mutations_apply_to_later_calls() { }, ) .unwrap(); - assert_eq!(request.content["round"], 1); + assert_eq!(request.request.content["round"], 1); assert_middleware_callback_labels(&callbacks, &["llm_request_initial"]); callbacks.lock().unwrap().clear(); @@ -1731,7 +1737,7 @@ fn test_llm_request_intercept_registry_mutations_apply_to_later_calls() { }, ) .unwrap(); - assert_eq!(request.content["round"], 2); + assert_eq!(request.request.content["round"], 2); assert_middleware_callback_labels(&callbacks, &["llm_request_initial", "llm_request_late"]); deregister_llm_request_intercept("snapshot_llm_request_initial").unwrap(); @@ -1947,7 +1953,9 @@ async fn test_llm_middleware_callbacks_run_without_registry_or_scope_locks() { Arc::new(move |_, request, annotated| { record_middleware_callback(&tracked, "llm_request_global"); assert_middleware_callback_locks_are_free(); - Ok((request, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + request, annotated, + )) }), ) .unwrap(); @@ -1960,7 +1968,9 @@ async fn test_llm_middleware_callbacks_run_without_registry_or_scope_locks() { Arc::new(move |_, request, annotated| { record_middleware_callback(&tracked, "llm_request_scope"); assert_middleware_callback_locks_are_free(); - Ok((request, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + request, annotated, + )) }), ) .unwrap(); @@ -2571,7 +2581,9 @@ async fn test_llm_request_intercept_transforms() { false, Arc::new(|_name: &str, mut req: LlmRequest, annotated| { req.headers.insert("x-intercepted".into(), json!(true)); - Ok((req, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, annotated, + )) }), ) .unwrap(); @@ -2582,12 +2594,223 @@ async fn test_llm_request_intercept_transforms() { }; let result = llm_request_intercepts("test_llm", request).unwrap(); - assert_eq!(result.headers["x-intercepted"], true); + assert_eq!(result.request.headers["x-intercepted"], true); // Cleanup deregister_llm_request_intercept("llm_req_i").unwrap(); } +#[test] +fn test_llm_request_intercept_pending_marks_preserve_order_and_break_chain() { + let _lock = TEST_MUTEX.lock().unwrap(); + reset_global(); + setup_isolated_thread(); + + for (name, priority, break_chain, mark_name) in [ + ("pending_first", 1, false, "first"), + ("pending_break", 2, true, "second"), + ("pending_skipped", 3, false, "skipped"), + ] { + register_llm_request_intercept( + name, + priority, + break_chain, + Arc::new(move |_name, request, annotated| { + Ok(LlmRequestInterceptOutcome::new(request, annotated) + .with_pending_mark(PendingMarkSpec::builder().name(mark_name).build())) + }), + ) + .unwrap(); + } + + let outcome = llm_request_intercepts( + "llm", + LlmRequest { + headers: serde_json::Map::new(), + content: json!({"prompt": "hello"}), + }, + ) + .unwrap(); + + assert_eq!( + outcome + .pending_marks + .iter() + .map(|mark| mark.name.as_str()) + .collect::>(), + ["first", "second"] + ); + assert_eq!(outcome.request.content["prompt"], "hello"); + + for name in ["pending_first", "pending_break", "pending_skipped"] { + deregister_llm_request_intercept(name).unwrap(); + } +} + +#[tokio::test] +async fn test_managed_llm_emits_pending_marks_under_started_scope() { + let _lock = TEST_MUTEX.lock().unwrap(); + reset_global(); + setup_isolated_thread(); + + let events = Arc::new(Mutex::new(Vec::::new())); + let captured = events.clone(); + register_subscriber( + "pending_mark_observer", + Arc::new(move |event: &Event| captured.lock().unwrap().push(event.clone())), + ) + .unwrap(); + + register_llm_request_intercept( + "pending_managed", + 1, + false, + Arc::new(|_name, request, annotated| { + Ok(LlmRequestInterceptOutcome::new(request, annotated) + .with_pending_mark( + PendingMarkSpec::builder() + .name("request.optimized") + .category(EventCategory::custom()) + .category_profile( + CategoryProfile::builder() + .subtype("optimizer.saved_tokens") + .build(), + ) + .data(json!({"saved_tokens": 12})) + .build(), + ) + .with_pending_mark( + PendingMarkSpec::builder() + .name("request.optimized.second") + .build(), + )) + }), + ) + .unwrap(); + + let provider_request = Arc::new(Mutex::new(None::)); + let captured_request = provider_request.clone(); + llm_call_execute( + LlmCallExecuteParams::builder() + .name("pending-managed-llm") + .request(LlmRequest { + headers: serde_json::Map::from_iter([( + "x-pending-mark-test".into(), + json!("preserved"), + )]), + content: json!({"prompt": "hello"}), + }) + .func(Arc::new(move |request| { + *captured_request.lock().unwrap() = Some(request); + Box::pin(async { Ok(json!({"response": "done"})) }) + })) + .build(), + ) + .await + .unwrap(); + + let provider_request = provider_request.lock().unwrap().clone().unwrap(); + assert_eq!( + provider_request.headers.get("x-pending-mark-test"), + Some(&json!("preserved")) + ); + assert_eq!(provider_request.content["prompt"], "hello"); + assert!(provider_request.content.get("pending_marks").is_none()); + assert!(provider_request.content.get("annotated_request").is_none()); + + let captured = captured_events_snapshot(&events); + let start = captured + .iter() + .find(|event| { + event.name() == "pending-managed-llm" + && event.scope_category() == Some(ScopeCategory::Start) + }) + .unwrap(); + let mark = captured + .iter() + .find(|event| event.name() == "request.optimized") + .unwrap(); + let second_mark = captured + .iter() + .find(|event| event.name() == "request.optimized.second") + .unwrap(); + let end = captured + .iter() + .find(|event| { + event.name() == "pending-managed-llm" + && event.scope_category() == Some(ScopeCategory::End) + }) + .unwrap(); + assert_eq!(mark.parent_uuid(), Some(start.uuid())); + assert_eq!(second_mark.parent_uuid(), Some(start.uuid())); + assert!(mark.timestamp() > start.timestamp()); + assert_eq!(mark.timestamp(), second_mark.timestamp()); + assert!(end.timestamp() >= mark.timestamp()); + assert_eq!(mark.data().unwrap()["saved_tokens"], 12); + + deregister_llm_request_intercept("pending_managed").unwrap(); + deregister_subscriber("pending_mark_observer").unwrap(); +} + +#[tokio::test] +async fn test_failed_request_intercept_does_not_emit_pending_marks_or_start_scope() { + let _lock = TEST_MUTEX.lock().unwrap(); + reset_global(); + setup_isolated_thread(); + + let events = Arc::new(Mutex::new(Vec::::new())); + let captured = events.clone(); + register_subscriber( + "failed_pending_mark_observer", + Arc::new(move |event: &Event| captured.lock().unwrap().push(event.clone())), + ) + .unwrap(); + register_llm_request_intercept( + "pending_before_failure", + 1, + false, + Arc::new(|_name, request, annotated| { + Ok(LlmRequestInterceptOutcome::new(request, annotated) + .with_pending_mark(PendingMarkSpec::builder().name("must.not.emit").build())) + }), + ) + .unwrap(); + register_llm_request_intercept( + "pending_failure", + 2, + false, + Arc::new(|_name, _request, _annotated| { + Err(FlowError::Internal("request intercept failed".into())) + }), + ) + .unwrap(); + + let provider_called = Arc::new(AtomicBool::new(false)); + let called = provider_called.clone(); + let result = llm_call_execute( + LlmCallExecuteParams::builder() + .name("failed-pending-llm") + .request(LlmRequest { + headers: serde_json::Map::new(), + content: json!({"prompt": "hello"}), + }) + .func(Arc::new(move |_request| { + called.store(true, Ordering::SeqCst); + Box::pin(async { Ok(json!({"response": "unexpected"})) }) + })) + .build(), + ) + .await; + + assert!(result.is_err()); + assert!(!provider_called.load(Ordering::SeqCst)); + assert!(captured_events_snapshot(&events).is_empty()); + + deregister_llm_request_intercept("pending_before_failure").unwrap(); + deregister_llm_request_intercept("pending_failure").unwrap(); + deregister_subscriber("failed_pending_mark_observer").unwrap(); +} + /// LLM execution intercept middleware chain with next(). #[tokio::test] async fn test_llm_execution_intercept_chain() { @@ -2672,7 +2895,9 @@ async fn test_llm_start_emits_before_short_circuit_execution_intercept() { .as_object_mut() .unwrap() .insert("phase".into(), json!("request")); - Ok((req, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, annotated, + )) }), ) .unwrap(); @@ -2764,7 +2989,9 @@ async fn test_llm_stream_start_emits_before_short_circuit_execution_intercept() .as_object_mut() .unwrap() .insert("phase".into(), json!("request")); - Ok((req, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, annotated, + )) }), ) .unwrap(); diff --git a/crates/core/tests/integration/native_plugin_tests.rs b/crates/core/tests/integration/native_plugin_tests.rs index fc56422a0..0d2b64f05 100644 --- a/crates/core/tests/integration/native_plugin_tests.rs +++ b/crates/core/tests/integration/native_plugin_tests.rs @@ -340,6 +340,24 @@ async fn sdk_cdylib_registers_tool_request_intercept() { llm_start.input().unwrap()["content"]["native_plugin_llm_request_intercept"], true ); + let pending_mark = find_event(&managed_llm_events, "fixture.native.llm_request.mark", None); + assert_eq!(pending_mark.parent_uuid(), Some(llm_start.uuid())); + assert_eq!( + pending_mark.category().map(|category| category.as_str()), + Some("custom") + ); + assert_eq!( + pending_mark + .category_profile() + .and_then(|profile| profile.subtype.as_deref()), + Some("fixture.native.pending") + ); + assert_eq!( + pending_mark.data().unwrap()["source"], + "native_request_intercept" + ); + assert_eq!(pending_mark.metadata().unwrap()["fixture"], true); + assert!(pending_mark.timestamp() > llm_start.timestamp()); let llm_end = find_event( &managed_llm_events, "native-fixture-llm-execute", @@ -398,6 +416,13 @@ async fn sdk_cdylib_registers_tool_request_intercept() { assert_eq!(*collected_stream_chunks.lock().unwrap(), stream_chunks); flush_subscribers().expect("stream native fixture events should flush"); let stream_events = events.lock().unwrap().clone(); + let stream_start = find_event( + &stream_events, + "native-fixture-llm-stream", + Some(ScopeCategory::Start), + ); + let stream_pending_mark = find_event(&stream_events, "fixture.native.llm_request.mark", None); + assert_eq!(stream_pending_mark.parent_uuid(), Some(stream_start.uuid())); let stream_end = find_event( &stream_events, "native-fixture-llm-stream", diff --git a/crates/core/tests/integration/pipeline_tests.rs b/crates/core/tests/integration/pipeline_tests.rs index eefe05cd3..cf411e509 100644 --- a/crates/core/tests/integration/pipeline_tests.rs +++ b/crates/core/tests/integration/pipeline_tests.rs @@ -14,11 +14,11 @@ use futures::StreamExt; use serde_json::json; use tokio_stream::Stream; -use nemo_relay::api::event::{Event, ScopeCategory}; -use nemo_relay::api::llm::LlmRequest; +use nemo_relay::api::event::{Event, PendingMarkSpec, ScopeCategory}; use nemo_relay::api::llm::{ LlmCallExecuteParams, LlmStreamCallExecuteParams, llm_call_execute, llm_stream_call_execute, }; +use nemo_relay::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; use nemo_relay::api::registry::{ deregister_llm_request_intercept, deregister_llm_sanitize_request_guardrail, deregister_llm_sanitize_response_guardrail, register_llm_request_intercept, @@ -310,7 +310,9 @@ async fn test_decode_runs_before_intercepts() { false, Arc::new(move |_name, req, annotated| { *cap.lock().unwrap() = Some(annotated.clone()); - Ok((req, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, annotated, + )) }), ) .unwrap(); @@ -359,10 +361,14 @@ async fn test_encode_runs_after_intercepts() { "modify_model", 1, false, - Arc::new(|_name, req, annotated| { + Arc::new(|_name, mut req, annotated| { let mut ann = annotated.unwrap(); ann.model = Some("modified".into()); - Ok((req, Some(ann))) + req.headers.insert("x-codec-route".into(), json!("blue")); + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, + Some(ann), + )) }), ) .unwrap(); @@ -397,11 +403,185 @@ async fn test_encode_runs_after_intercepts() { let captured_req = exec_request.lock().unwrap(); let req = captured_req.as_ref().unwrap(); assert_eq!(req.content["model"], json!("modified")); + assert_eq!(req.headers["x-codec-route"], json!("blue")); // Cleanup deregister_llm_request_intercept("modify_model").unwrap(); } +#[tokio::test] +async fn test_codec_rejects_raw_content_mutation_before_lifecycle() { + let _lock = TEST_MUTEX.lock().unwrap(); + reset_global(); + setup_isolated_thread(); + + let events = Arc::new(Mutex::new(Vec::new())); + let captured_events = events.clone(); + register_subscriber( + "codec_raw_mutation_subscriber", + Arc::new(move |event| captured_events.lock().unwrap().push(event.clone())), + ) + .unwrap(); + + let later_intercept_called = Arc::new(Mutex::new(false)); + register_llm_request_intercept( + "codec_raw_mutation", + 1, + false, + Arc::new(|_name, mut request, annotated| { + request.content["model"] = json!("raw-model-edit"); + Ok(LlmRequestInterceptOutcome::new(request, annotated) + .with_pending_mark(PendingMarkSpec::builder().name("must.not.emit").build())) + }), + ) + .unwrap(); + let later_called = later_intercept_called.clone(); + register_llm_request_intercept( + "codec_after_raw_mutation", + 2, + false, + Arc::new(move |_name, request, annotated| { + *later_called.lock().unwrap() = true; + Ok(LlmRequestInterceptOutcome::new(request, annotated)) + }), + ) + .unwrap(); + + let provider_called = Arc::new(Mutex::new(false)); + let called = provider_called.clone(); + let provider: LlmExecutionNextFn = Arc::new(move |_request| { + *called.lock().unwrap() = true; + Box::pin(async { Ok(json!({"response": "unexpected"})) }) + }); + let (codec, _, encode_log) = make_tracking_codec("codec_raw_read_only"); + let error = llm_call_execute( + LlmCallExecuteParams::builder() + .name("test_llm") + .request(make_llm_request(json!({ + "model": "original", + "messages": [{"role": "user", "content": "hello"}] + }))) + .func(provider) + .codec(codec) + .build(), + ) + .await + .expect_err("raw content mutation must fail on the codec path"); + + assert!(error.to_string().contains("request.content")); + assert!(!*later_intercept_called.lock().unwrap()); + assert!(!*provider_called.lock().unwrap()); + assert!(encode_log.lock().unwrap().is_empty()); + assert!(captured_events_snapshot(&events).is_empty()); + + deregister_llm_request_intercept("codec_raw_mutation").unwrap(); + deregister_llm_request_intercept("codec_after_raw_mutation").unwrap(); + deregister_subscriber("codec_raw_mutation_subscriber").unwrap(); +} + +#[tokio::test] +async fn test_codec_rejects_missing_annotation_before_lifecycle() { + let _lock = TEST_MUTEX.lock().unwrap(); + reset_global(); + setup_isolated_thread(); + + register_llm_request_intercept( + "codec_missing_annotation", + 1, + false, + Arc::new(|_name, request, _annotated| Ok(LlmRequestInterceptOutcome::new(request, None))), + ) + .unwrap(); + + let provider_called = Arc::new(Mutex::new(false)); + let called = provider_called.clone(); + let provider: LlmExecutionNextFn = Arc::new(move |_request| { + *called.lock().unwrap() = true; + Box::pin(async { Ok(json!({"response": "unexpected"})) }) + }); + let (codec, _, encode_log) = make_tracking_codec("codec_annotation_required"); + let error = llm_call_execute( + LlmCallExecuteParams::builder() + .name("test_llm") + .request(make_llm_request(json!({ + "messages": [{"role": "user", "content": "hello"}] + }))) + .func(provider) + .codec(codec) + .build(), + ) + .await + .expect_err("missing annotation must fail on the codec path"); + + assert!(error.to_string().contains("omitted annotated_request")); + assert!(!*provider_called.lock().unwrap()); + assert!(encode_log.lock().unwrap().is_empty()); + + deregister_llm_request_intercept("codec_missing_annotation").unwrap(); +} + +#[tokio::test] +async fn test_stream_codec_rejects_raw_content_mutation_before_lifecycle() { + let _lock = TEST_MUTEX.lock().unwrap(); + reset_global(); + setup_isolated_thread(); + + let events = Arc::new(Mutex::new(Vec::new())); + let captured_events = events.clone(); + register_subscriber( + "stream_codec_raw_mutation_subscriber", + Arc::new(move |event| captured_events.lock().unwrap().push(event.clone())), + ) + .unwrap(); + register_llm_request_intercept( + "stream_codec_raw_mutation", + 1, + false, + Arc::new(|_name, mut request, annotated| { + request.content["model"] = json!("raw-stream-edit"); + Ok(LlmRequestInterceptOutcome::new(request, annotated)) + }), + ) + .unwrap(); + + let provider_called = Arc::new(Mutex::new(false)); + let called = provider_called.clone(); + let provider: LlmStreamExecutionNextFn = Arc::new(move |_request| { + *called.lock().unwrap() = true; + Box::pin(async { + let stream: Pin> + Send>> = + Box::pin(futures::stream::empty()); + Ok(stream) + }) + }); + let (codec, _, _) = make_tracking_codec("stream_codec_raw_read_only"); + let result = llm_stream_call_execute( + LlmStreamCallExecuteParams::builder() + .name("test_stream") + .request(make_llm_request(json!({ + "model": "original", + "messages": [{"role": "user", "content": "hello"}] + }))) + .func(provider) + .collector(Box::new(|_| Ok(()))) + .finalizer(Box::new(|| json!({"done": true}))) + .codec(codec) + .build(), + ) + .await; + let error = match result { + Ok(_) => panic!("raw content mutation must fail on the streaming codec path"), + Err(error) => error, + }; + + assert!(error.to_string().contains("request.content")); + assert!(!*provider_called.lock().unwrap()); + assert!(captured_events_snapshot(&events).is_empty()); + + deregister_llm_request_intercept("stream_codec_raw_mutation").unwrap(); + deregister_subscriber("stream_codec_raw_mutation_subscriber").unwrap(); +} + // =========================================================================== // Intercept receives both LlmRequest and AnnotatedLlmRequest // =========================================================================== @@ -424,7 +604,9 @@ async fn test_annotated_intercept_receives_both() { false, Arc::new(move |_name, req, annotated| { *cp.lock().unwrap() = Some((req.clone(), annotated.clone())); - Ok((req, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, annotated, + )) }), ) .unwrap(); @@ -465,12 +647,12 @@ async fn test_annotated_intercept_receives_both() { // =========================================================================== #[tokio::test] -async fn test_legacy_intercept_backward_compat() { +async fn test_canonical_intercept_with_and_without_codec() { let _lock = TEST_MUTEX.lock().unwrap(); reset_global(); setup_isolated_thread(); - // Part 1: Legacy intercept with no Codec + // Part 1: canonical intercept with no codec. let legacy_called_1 = Arc::new(Mutex::new(false)); let lc1 = legacy_called_1.clone(); register_llm_request_intercept( @@ -480,7 +662,9 @@ async fn test_legacy_intercept_backward_compat() { Arc::new(move |_name, mut req, annotated| { *lc1.lock().unwrap() = true; req.headers.insert("x-legacy".into(), json!("was-here")); - Ok((req, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, annotated, + )) }), ) .unwrap(); @@ -510,7 +694,7 @@ async fn test_legacy_intercept_backward_compat() { // Cleanup part 1 deregister_llm_request_intercept("legacy_1").unwrap(); - // Part 2: Legacy intercept WITH Codec — legacy intercept still runs + // Part 2: canonical intercept with a codec. reset_global(); let (codec, _, _) = make_tracking_codec("codec_D"); @@ -524,7 +708,9 @@ async fn test_legacy_intercept_backward_compat() { Arc::new(move |_name, mut req, annotated| { *lc2.lock().unwrap() = true; req.headers.insert("x-legacy-2".into(), json!("also-here")); - Ok((req, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, annotated, + )) }), ) .unwrap(); @@ -580,7 +766,9 @@ async fn test_stream_path_also_decodes() { false, Arc::new(move |_name, req, annotated| { *ca.lock().unwrap() = Some(annotated.clone()); - Ok((req, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, annotated, + )) }), ) .unwrap(); @@ -644,7 +832,9 @@ async fn test_shared_helper_both_paths() { false, Arc::new(move |_name, req, annotated| { *acc.lock().unwrap() += 1; - Ok((req, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, annotated, + )) }), ) .unwrap(); @@ -721,7 +911,9 @@ async fn test_explicit_codec_param_overrides() { if let Some(ref ann) = annotated { *cm.lock().unwrap() = ann.model.clone(); } - Ok((req, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, annotated, + )) }), ) .unwrap(); @@ -768,7 +960,10 @@ async fn test_encode_merge_not_replace() { Arc::new(|_name, req, annotated| { let mut ann = annotated.unwrap(); ann.model = Some("new_model".into()); - Ok((req, Some(ann))) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, + Some(ann), + )) }), ) .unwrap(); @@ -836,7 +1031,9 @@ async fn test_unified_chain_priority_order() { false, Arc::new(move |_name, req, annotated| { cl1.lock().unwrap().push("legacy_p10".into()); - Ok((req, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, annotated, + )) }), ) .unwrap(); @@ -849,7 +1046,9 @@ async fn test_unified_chain_priority_order() { false, Arc::new(move |_name, req, annotated| { cl2.lock().unwrap().push("annotated_p5".into()); - Ok((req, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, annotated, + )) }), ) .unwrap(); @@ -897,7 +1096,9 @@ async fn test_no_codec_annotated_intercept_receives_none() { false, Arc::new(move |_name, req, annotated| { *ca.lock().unwrap() = Some(annotated.clone()); - Ok((req, annotated)) + Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new( + req, annotated, + )) }), ) .unwrap(); diff --git a/crates/core/tests/integration/worker_plugin_tests.rs b/crates/core/tests/integration/worker_plugin_tests.rs index 93b4360b4..ae1274f5b 100644 --- a/crates/core/tests/integration/worker_plugin_tests.rs +++ b/crates/core/tests/integration/worker_plugin_tests.rs @@ -182,6 +182,13 @@ async fn rust_worker_registers_and_invokes_all_current_surfaces() { llm_start.input().unwrap()["content"]["worker_plugin_llm_sanitize_request"], true ); + let pending_mark = find_event(&captured_events, "fixture.worker.llm_request.mark", None); + assert_eq!(pending_mark.parent_uuid(), Some(llm_start.uuid())); + assert_eq!( + pending_mark.data().unwrap()["source"], + "worker_request_intercept" + ); + assert_eq!(pending_mark.metadata().unwrap()["fixture"], true); let llm_end = find_event( &captured_events, "worker-fixture-llm-execute", diff --git a/crates/core/tests/unit/dynamic_worker_tests.rs b/crates/core/tests/unit/dynamic_worker_tests.rs index f779f8370..eff98271f 100644 --- a/crates/core/tests/unit/dynamic_worker_tests.rs +++ b/crates/core/tests/unit/dynamic_worker_tests.rs @@ -335,29 +335,31 @@ async fn callback_helpers_cover_worker_response_edges() { }, "llm_intercept_invalid_request" => InvokeResponse { result: Some(InvokeResult::LlmRequest(LlmRequestInterceptResult { - request: Some(JsonEnvelope { - schema: LLM_REQUEST_SCHEMA.into(), - json: b"null".to_vec(), + outcome: Some(JsonEnvelope { + schema: "nemo.relay.LlmRequestInterceptOutcome@1".into(), + json: br#"{"request":null}"#.to_vec(), }), - annotated_request: None, - has_annotated_request: false, })), }, "llm_intercept_missing_annotated" => InvokeResponse { result: Some(InvokeResult::LlmRequest(LlmRequestInterceptResult { - request: Some(valid_llm_request_envelope()), - annotated_request: None, - has_annotated_request: true, + outcome: Some(JsonEnvelope { + schema: "nemo.relay.LegacyLlmRequestInterceptResult@1".into(), + json: br#"{}"#.to_vec(), + }), })), }, "llm_intercept_invalid_annotated" => InvokeResponse { result: Some(InvokeResult::LlmRequest(LlmRequestInterceptResult { - request: Some(valid_llm_request_envelope()), - annotated_request: Some(JsonEnvelope { - schema: ANNOTATED_LLM_REQUEST_SCHEMA.into(), - json: b"null".to_vec(), + outcome: Some(JsonEnvelope { + schema: "nemo.relay.LlmRequestInterceptOutcome@1".into(), + json: serde_json::to_vec(&json!({ + "request": valid_llm_request(), + "annotated_request": 3, + "pending_marks": [], + })) + .unwrap(), }), - has_annotated_request: true, })), }, "llm_intercept_error" => InvokeResponse { @@ -405,7 +407,11 @@ async fn callback_helpers_cover_worker_response_edges() { None, ) .expect_err("invalid LLM intercept request should fail"); - assert!(error.to_string().contains("invalid LLM request")); + assert!( + error + .to_string() + .contains("invalid LLM request intercept outcome") + ); let error = callback .invoke_llm_request_intercept( @@ -414,11 +420,11 @@ async fn callback_helpers_cover_worker_response_edges() { valid_llm_request(), None, ) - .expect_err("missing annotated request should fail when flagged present"); + .expect_err("legacy outcome schema should fail"); assert!( error .to_string() - .contains("llm request intercept annotated request is missing") + .contains("unsupported LLM request intercept outcome schema") ); let error = callback @@ -429,7 +435,11 @@ async fn callback_helpers_cover_worker_response_edges() { None, ) .expect_err("invalid annotated request should fail"); - assert!(error.to_string().contains("invalid annotated LLM request")); + assert!( + error + .to_string() + .contains("invalid LLM request intercept outcome") + ); let error = callback .invoke_llm_request_intercept("llm_intercept_error", "model", valid_llm_request(), None) @@ -1327,10 +1337,6 @@ fn valid_llm_request() -> LlmRequest { } } -fn valid_llm_request_envelope() -> JsonEnvelope { - json_envelope(LLM_REQUEST_SCHEMA, &valid_llm_request()).expect("llm request envelope") -} - async fn fake_callback_service( invoke: impl Fn(InvokeRequest) -> InvokeResponse + Send + Sync + 'static, ) -> (WorkerPluginCallback, oneshot::Sender<()>) { diff --git a/crates/core/tests/unit/plugin_tests.rs b/crates/core/tests/unit/plugin_tests.rs index cda76a07c..282389a21 100644 --- a/crates/core/tests/unit/plugin_tests.rs +++ b/crates/core/tests/unit/plugin_tests.rs @@ -10,7 +10,7 @@ use std::sync::{Mutex, OnceLock}; use serde_json::json; -use crate::api::llm::LlmRequest; +use crate::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; use crate::api::llm::{llm_conditional_execution, llm_request_intercepts}; use crate::api::runtime::NemoRelayContextState; use crate::api::runtime::global_context; @@ -93,7 +93,7 @@ impl Plugin for TestPlugin { false, Arc::new(|_name, mut request, annotated| { request.headers.insert("x-plugin".into(), json!(true)); - Ok((request, annotated)) + Ok(LlmRequestInterceptOutcome::new(request, annotated)) }), ) }) @@ -471,7 +471,7 @@ fn test_plugin_registration_context_registers_and_rolls_back() { }, ) .unwrap(); - assert_eq!(request.headers.get("x-plugin"), Some(&json!(true))); + assert_eq!(request.request.headers.get("x-plugin"), Some(&json!(true))); let mut registrations = ctx.into_registrations(); rollback_registrations(&mut registrations); @@ -484,7 +484,7 @@ fn test_plugin_registration_context_registers_and_rolls_back() { }, ) .unwrap(); - assert_eq!(request.headers.get("x-plugin"), None); + assert_eq!(request.request.headers.get("x-plugin"), None); reset_global(); } @@ -515,7 +515,7 @@ fn test_initialize_plugins_registers_and_clears_components() { }, ) .unwrap(); - assert_eq!(request.headers.get("x-plugin"), Some(&json!(true))); + assert_eq!(request.request.headers.get("x-plugin"), Some(&json!(true))); clear_plugin_configuration().unwrap(); let request = llm_request_intercepts( @@ -526,7 +526,7 @@ fn test_initialize_plugins_registers_and_clears_components() { }, ) .unwrap(); - assert_eq!(request.headers.get("x-plugin"), None); + assert_eq!(request.request.headers.get("x-plugin"), None); reset_global(); } @@ -759,7 +759,9 @@ fn test_plugin_registration_context_covers_all_registration_helpers() { "llm-request", 1, false, - Arc::new(|_name, request, annotated| Ok((request, annotated))), + Arc::new(|_name, request, annotated| { + Ok(LlmRequestInterceptOutcome::new(request, annotated)) + }), ) .unwrap(); ctx.register_llm_execution_intercept( @@ -1059,7 +1061,9 @@ fn test_plugin_registration_context_maps_duplicate_registration_errors() { "llm-request", 1, false, - Arc::new(|_name, request, annotated| Ok((request, annotated))), + Arc::new(|_name, request, annotated| { + Ok(LlmRequestInterceptOutcome::new(request, annotated)) + }), ) .unwrap(); expect_registration_failed( @@ -1067,7 +1071,9 @@ fn test_plugin_registration_context_maps_duplicate_registration_errors() { "llm-request", 1, false, - Arc::new(|_name, request, annotated| Ok((request, annotated))), + Arc::new(|_name, request, annotated| { + Ok(LlmRequestInterceptOutcome::new(request, annotated)) + }), ), "llm request intercept:", ); @@ -1246,7 +1252,9 @@ fn test_plugin_registration_context_maps_deregistration_errors() { "llm-request", 1, false, - Arc::new(|_name, request, annotated| Ok((request, annotated))), + Arc::new(|_name, request, annotated| { + Ok(LlmRequestInterceptOutcome::new(request, annotated)) + }), ) .unwrap(); ctx.register_tool_sanitize_request_guardrail( diff --git a/crates/core/tests/unit/shared_tests.rs b/crates/core/tests/unit/shared_tests.rs index fe3eeab59..851ea4eee 100644 --- a/crates/core/tests/unit/shared_tests.rs +++ b/crates/core/tests/unit/shared_tests.rs @@ -8,7 +8,7 @@ use std::sync::Arc; use serde_json::{Map, json}; -use crate::api::llm::LlmRequest; +use crate::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; use crate::api::registry::{deregister_llm_request_intercept, register_llm_request_intercept}; use crate::api::runtime::NemoRelayContextState; use crate::api::runtime::global_context; @@ -156,25 +156,34 @@ fn test_run_request_intercepts_with_codec_none_and_codec_paths() { Arc::new(|_name, mut request, annotated| { assert!(annotated.is_none()); request.headers.insert("x-no-codec".into(), json!(true)); - Ok((request, None)) + let mut annotated = SharedTestCodec.decode(&request)?; + annotated.model = Some("interceptor-model".into()); + Ok(LlmRequestInterceptOutcome::new(request, Some(annotated))) }), ) .unwrap(); - let (request_without_codec, annotated_without_codec) = run_request_intercepts_with_codec( - "shared", - LlmRequest { - headers: Map::new(), - content: json!({"prompt": "hello"}), - }, - None, - ) - .unwrap(); + let (request_without_codec, annotated_without_codec, pending_marks_without_codec) = + run_request_intercepts_with_codec( + "shared", + LlmRequest { + headers: Map::new(), + content: json!({"prompt": "hello"}), + }, + None, + ) + .unwrap(); assert_eq!( request_without_codec.headers.get("x-no-codec"), Some(&json!(true)) ); - assert!(annotated_without_codec.is_none()); + assert_eq!( + annotated_without_codec + .as_deref() + .and_then(|annotated| annotated.model.as_deref()), + Some("interceptor-model") + ); + assert!(pending_marks_without_codec.is_empty()); deregister_llm_request_intercept("shared-none").unwrap(); register_llm_request_intercept( @@ -185,21 +194,22 @@ fn test_run_request_intercepts_with_codec_none_and_codec_paths() { let mut annotated = annotated.expect("codec should provide annotated request"); annotated.model = Some("intercepted-model".into()); request.headers.insert("x-codec".into(), json!(true)); - Ok((request, Some(annotated))) + Ok(LlmRequestInterceptOutcome::new(request, Some(annotated))) }), ) .unwrap(); let codec: Arc = Arc::new(SharedTestCodec); - let (request_with_codec, annotated_with_codec) = run_request_intercepts_with_codec( - "shared", - LlmRequest { - headers: Map::new(), - content: json!({"prompt": "hello"}), - }, - Some(codec), - ) - .unwrap(); + let (request_with_codec, annotated_with_codec, pending_marks_with_codec) = + run_request_intercepts_with_codec( + "shared", + LlmRequest { + headers: Map::new(), + content: json!({"prompt": "hello"}), + }, + Some(codec), + ) + .unwrap(); assert_eq!( request_with_codec.headers.get("x-codec"), @@ -215,6 +225,7 @@ fn test_run_request_intercepts_with_codec_none_and_codec_paths() { .and_then(|annotated| annotated.model.as_deref()), Some("intercepted-model") ); + assert!(pending_marks_with_codec.is_empty()); deregister_llm_request_intercept("shared-codec").unwrap(); reset_global(); diff --git a/crates/ffi/nemo_relay.h b/crates/ffi/nemo_relay.h index f07b3b2dc..2b1eaa5fb 100644 --- a/crates/ffi/nemo_relay.h +++ b/crates/ffi/nemo_relay.h @@ -254,15 +254,21 @@ typedef char *(*NemoRelayLlmConditionalCb)(void *user_data, const struct FfiLLMR * C callback type for LLM request intercepts with unified annotated-aware * signature. Receives the intercept name, the opaque `FfiLLMRequest`, and * optionally the annotated request as a JSON C string (null if no Codec - * resolved). Writes transformed outputs to `out_request` and - * `out_annotated_json`. Returns `NemoRelayStatus`. + * resolved). Writes one owned canonical outcome JSON string to + * `out_outcome_json`. Any non-null string written there must be allocated by + * `nemo_relay_llm_request_intercept_outcome_json_new` or by an allocation + * compatible with `nemo_relay_string_free`. Ownership transfers to Relay + * when the callback returns; the callback must not free or reuse the string + * afterward. Relay frees it exactly once, even when the callback returns an + * error status. With a Codec, the outcome must preserve request content and + * return the annotation; only request headers and annotation fields are + * writable. Returns `NemoRelayStatus`. */ typedef NemoRelayStatus (*NemoRelayLlmRequestInterceptCb)(void *user_data, const char *name, const struct FfiLLMRequest *request, const char *annotated_json, - struct FfiLLMRequest **out_request, - char **out_annotated_json); + char **out_outcome_json); /** * Runtime-provided "next" callback for LLM execution middleware chain. @@ -399,6 +405,30 @@ NemoRelayStatus nemo_relay_llm_request_intercepts(const char *name, const char *native_json, char **out); +/** + * Allocate canonical JSON for a C LLM request-intercept callback result. + * + * `annotated_json` may be null. `pending_marks_json` may be null, in which + * case an empty list is serialized. When used by a + * `NemoRelayLlmRequestInterceptCb`, assign the successful output to the + * callback's `out_outcome_json`; ownership transfers to Relay when the + * callback returns, so the callback must not free or reuse it. Outside a + * callback, the caller owns the returned string and must release it with + * `nemo_relay_string_free`. + * + * # Safety + * + * `request` must point to a live `FfiLLMRequest`, optional JSON inputs must + * be valid null-terminated strings when non-null, and `out_outcome_json` must + * be writable. A successful output must either be transferred through a + * callback's `out_outcome_json` or freed by its caller with + * `nemo_relay_string_free`. + */ +NemoRelayStatus nemo_relay_llm_request_intercept_outcome_json_new(const struct FfiLLMRequest *request, + const char *annotated_json, + const char *pending_marks_json, + char **out_outcome_json); + /** * Run the registered LLM conditional execution guardrail chain. * diff --git a/crates/ffi/src/api/mod.rs b/crates/ffi/src/api/mod.rs index a20e27ccf..a0d9ee5e2 100644 --- a/crates/ffi/src/api/mod.rs +++ b/crates/ffi/src/api/mod.rs @@ -35,14 +35,15 @@ use crate::error::{ status_from_plugin_error, }; use crate::types::{ - FfiAtifExporter, FfiAtofExporter, FfiCodecHandle, FfiLLMHandle, FfiOpenInferenceSubscriber, - FfiOpenTelemetrySubscriber, FfiPluginContext, FfiScopeHandle, FfiScopeStack, - FfiThreadScopeStackBinding, FfiToolHandle, NemoRelayScopeType, + FfiAtifExporter, FfiAtofExporter, FfiCodecHandle, FfiLLMHandle, FfiLLMRequest, + FfiOpenInferenceSubscriber, FfiOpenTelemetrySubscriber, FfiPluginContext, FfiScopeHandle, + FfiScopeStack, FfiThreadScopeStackBinding, FfiToolHandle, NemoRelayScopeType, }; pub use crate::types::{nemo_relay_openinference_subscriber_free, nemo_relay_otel_subscriber_free}; use libc::c_char; +use nemo_relay::api::event::PendingMarkSpec; use nemo_relay::api::llm as core_llm_api; -use nemo_relay::api::llm::{LlmAttributes, LlmRequest}; +use nemo_relay::api::llm::{LlmAttributes, LlmRequest, LlmRequestInterceptOutcome}; use nemo_relay::api::registry as core_registry_api; use nemo_relay::api::runtime::{LlmExecutionNextFn, LlmStreamExecutionNextFn, ToolExecutionNextFn}; use nemo_relay::api::runtime::{ @@ -239,6 +240,87 @@ pub unsafe extern "C" fn nemo_relay_llm_request_intercepts( } } +/// Allocate canonical JSON for a C LLM request-intercept callback result. +/// +/// `annotated_json` may be null. `pending_marks_json` may be null, in which +/// case an empty list is serialized. When used by a +/// `NemoRelayLlmRequestInterceptCb`, assign the successful output to the +/// callback's `out_outcome_json`; ownership transfers to Relay when the +/// callback returns, so the callback must not free or reuse it. Outside a +/// callback, the caller owns the returned string and must release it with +/// `nemo_relay_string_free`. +/// +/// # Safety +/// +/// `request` must point to a live `FfiLLMRequest`, optional JSON inputs must +/// be valid null-terminated strings when non-null, and `out_outcome_json` must +/// be writable. A successful output must either be transferred through a +/// callback's `out_outcome_json` or freed by its caller with +/// `nemo_relay_string_free`. +#[unsafe(no_mangle)] +pub unsafe extern "C" fn nemo_relay_llm_request_intercept_outcome_json_new( + request: *const FfiLLMRequest, + annotated_json: *const c_char, + pending_marks_json: *const c_char, + out_outcome_json: *mut *mut c_char, +) -> NemoRelayStatus { + clear_last_error(); + if out_outcome_json.is_null() { + set_last_error("out_outcome_json must be non-null"); + return NemoRelayStatus::NullPointer; + } + unsafe { *out_outcome_json = std::ptr::null_mut() }; + if request.is_null() { + set_last_error("request must be non-null"); + return NemoRelayStatus::NullPointer; + } + let annotated_request = if annotated_json.is_null() { + None + } else { + let value = match c_str_to_json(annotated_json) { + Some(value) => value, + None => return NemoRelayStatus::InvalidJson, + }; + match serde_json::from_value(value) { + Ok(value) => Some(value), + Err(error) => { + set_last_error(&format!("invalid annotated request JSON: {error}")); + return NemoRelayStatus::InvalidJson; + } + } + }; + let pending_marks = if pending_marks_json.is_null() { + Vec::new() + } else { + let value = match c_str_to_json(pending_marks_json) { + Some(value) => value, + None => return NemoRelayStatus::InvalidJson, + }; + match serde_json::from_value::>(value) { + Ok(value) => value, + Err(error) => { + set_last_error(&format!("invalid pending marks JSON: {error}")); + return NemoRelayStatus::InvalidJson; + } + } + }; + let outcome = LlmRequestInterceptOutcome { + request: unsafe { &*request }.0.clone(), + annotated_request, + pending_marks, + }; + match serde_json::to_value(outcome) { + Ok(value) => { + unsafe { *out_outcome_json = json_to_c_string(&value) }; + NemoRelayStatus::Ok + } + Err(error) => { + set_last_error(&format!("failed to serialize intercept outcome: {error}")); + NemoRelayStatus::Internal + } + } +} + /// Run the registered LLM conditional execution guardrail chain. /// /// Returns `NemoRelayStatus::Ok` if all guardrails pass, or diff --git a/crates/ffi/src/callable.rs b/crates/ffi/src/callable.rs index ea3bf4c28..9baaba2b2 100644 --- a/crates/ffi/src/callable.rs +++ b/crates/ffi/src/callable.rs @@ -31,7 +31,7 @@ use serde_json::Value as Json; use tokio_stream::{Stream, StreamExt}; use nemo_relay::api::event::Event; -use nemo_relay::api::llm::LlmRequest; +use nemo_relay::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; use nemo_relay::codec::request::AnnotatedLlmRequest as AnnotatedLLMRequest; use nemo_relay::codec::traits::LlmCodec; use nemo_relay::error::{FlowError, Result}; @@ -171,15 +171,21 @@ pub type NemoRelayCodecEncodeFn = Option< /// C callback type for LLM request intercepts with unified annotated-aware /// signature. Receives the intercept name, the opaque `FfiLLMRequest`, and /// optionally the annotated request as a JSON C string (null if no Codec -/// resolved). Writes transformed outputs to `out_request` and -/// `out_annotated_json`. Returns `NemoRelayStatus`. +/// resolved). Writes one owned canonical outcome JSON string to +/// `out_outcome_json`. Any non-null string written there must be allocated by +/// `nemo_relay_llm_request_intercept_outcome_json_new` or by an allocation +/// compatible with `nemo_relay_string_free`. Ownership transfers to Relay +/// when the callback returns; the callback must not free or reuse the string +/// afterward. Relay frees it exactly once, even when the callback returns an +/// error status. With a Codec, the outcome must preserve request content and +/// return the annotation; only request headers and annotation fields are +/// writable. Returns `NemoRelayStatus`. pub type NemoRelayLlmRequestInterceptCb = unsafe extern "C" fn( user_data: *mut libc::c_void, name: *const c_char, request: *const FfiLLMRequest, annotated_json: *const c_char, - out_request: *mut *mut FfiLLMRequest, - out_annotated_json: *mut *mut c_char, + out_outcome_json: *mut *mut c_char, ) -> NemoRelayStatus; /// Callback for collecting intercepted stream chunks. Invoked with each chunk @@ -559,8 +565,8 @@ pub fn wrap_json_fn( /// Wrap a C LLM request intercept callback (annotated-aware) into a Rust /// `LlmRequestInterceptFn` closure. The callback receives the intercept name, -/// the opaque `FfiLLMRequest`, and the annotated JSON (or null). It writes -/// the transformed request and annotated JSON to output pointers. +/// the opaque `FfiLLMRequest`, and the annotated JSON (or null). It writes one +/// owned canonical outcome JSON string. pub fn wrap_llm_request_intercept_fn( cb: NemoRelayLlmRequestInterceptCb, user_data: *mut libc::c_void, @@ -587,9 +593,7 @@ pub fn wrap_llm_request_intercept_fn( std::ptr::null() }; - // Initialize output pointers - let mut out_request: *mut FfiLLMRequest = std::ptr::null_mut(); - let mut out_annotated: *mut c_char = std::ptr::null_mut(); + let mut out_outcome: *mut c_char = std::ptr::null_mut(); let status = unsafe { cb( @@ -597,8 +601,7 @@ pub fn wrap_llm_request_intercept_fn( c_name.as_ptr(), ffi_req, annotated_ptr, - &mut out_request, - &mut out_annotated, + &mut out_outcome, ) }; @@ -606,32 +609,29 @@ pub fn wrap_llm_request_intercept_fn( unsafe { drop(Box::from_raw(ffi_req)) }; if status != NemoRelayStatus::Ok { + unsafe { nemo_relay_string_free_internal(out_outcome) }; let message = last_error_message() .unwrap_or_else(|| "request intercept callback failed".to_string()); return Err(FlowError::Internal(message)); } - // Read output request - let new_request = if out_request.is_null() { + if out_outcome.is_null() { return Err(FlowError::Internal( - "request intercept returned null out_request".to_string(), + "request intercept returned null out_outcome_json".to_string(), )); - } else { - let boxed = unsafe { Box::from_raw(out_request) }; - boxed.0 - }; - - // Read output annotated - let new_annotated = if out_annotated.is_null() { - None - } else { - let s = unsafe { CStr::from_ptr(out_annotated) }.to_string_lossy(); - let parsed: Option = serde_json::from_str(&s).ok(); - unsafe { nemo_relay_string_free_internal(out_annotated) }; - parsed - }; - - Ok((new_request, new_annotated)) + } + let outcome = unsafe { CStr::from_ptr(out_outcome) } + .to_str() + .map_err(|error| FlowError::Internal(format!("invalid outcome UTF-8: {error}"))) + .and_then(|json| { + serde_json::from_str::(json).map_err(|error| { + FlowError::Internal(format!( + "invalid LLM request intercept outcome JSON: {error}" + )) + }) + }); + unsafe { nemo_relay_string_free_internal(out_outcome) }; + outcome }, ) } diff --git a/crates/ffi/tests/integration/api_tests.rs b/crates/ffi/tests/integration/api_tests.rs index 5f38fbcbe..d6668df4f 100644 --- a/crates/ffi/tests/integration/api_tests.rs +++ b/crates/ffi/tests/integration/api_tests.rs @@ -377,10 +377,15 @@ unsafe extern "C" fn llm_request_intercept_cb( _name: *const c_char, request: *const FfiLLMRequest, _annotated_json: *const c_char, - out_request: *mut *mut FfiLLMRequest, - _out_annotated_json: *mut *mut c_char, + out_outcome_json: *mut *mut c_char, ) -> NemoRelayStatus { - unsafe { *out_request = llm_request_cb(ptr::null_mut(), request) }; + let transformed = unsafe { Box::from_raw(llm_request_cb(ptr::null_mut(), request)) }; + let outcome = json!({ + "request": transformed.0, + "annotated_request": null, + "pending_marks": [], + }); + unsafe { *out_outcome_json = CString::new(outcome.to_string()).unwrap().into_raw() }; NemoRelayStatus::Ok } diff --git a/crates/ffi/tests/integration/callable_extra_tests.rs b/crates/ffi/tests/integration/callable_extra_tests.rs index 0401abe2b..d82c158f0 100644 --- a/crates/ffi/tests/integration/callable_extra_tests.rs +++ b/crates/ffi/tests/integration/callable_extra_tests.rs @@ -40,8 +40,7 @@ unsafe extern "C" fn llm_request_intercept_status_error_cb( _name: *const c_char, _request: *const FfiLLMRequest, _annotated_json: *const c_char, - _out_request: *mut *mut FfiLLMRequest, - _out_annotated_json: *mut *mut c_char, + _out_outcome_json: *mut *mut c_char, ) -> NemoRelayStatus { NemoRelayStatus::Internal } @@ -51,8 +50,7 @@ unsafe extern "C" fn llm_request_intercept_null_out_request_cb( _name: *const c_char, _request: *const FfiLLMRequest, _annotated_json: *const c_char, - _out_request: *mut *mut FfiLLMRequest, - _out_annotated_json: *mut *mut c_char, + _out_outcome_json: *mut *mut c_char, ) -> NemoRelayStatus { NemoRelayStatus::Ok } @@ -60,15 +58,11 @@ unsafe extern "C" fn llm_request_intercept_null_out_request_cb( unsafe extern "C" fn llm_request_intercept_invalid_annotated_cb( _user_data: *mut libc::c_void, _name: *const c_char, - request: *const FfiLLMRequest, + _request: *const FfiLLMRequest, _annotated_json: *const c_char, - out_request: *mut *mut FfiLLMRequest, - out_annotated_json: *mut *mut c_char, + out_outcome_json: *mut *mut c_char, ) -> NemoRelayStatus { - unsafe { - *out_request = Box::into_raw(Box::new(FfiLLMRequest((&*request).0.clone()))); - *out_annotated_json = CString::new("not-json").unwrap().into_raw(); - } + unsafe { *out_outcome_json = CString::new("not-json").unwrap().into_raw() }; NemoRelayStatus::Ok } @@ -217,16 +211,18 @@ fn test_callable_extra_request_intercept_and_codec_paths() { None, ); let err = intercept_null("llm", request.clone(), None).unwrap_err(); - assert!(err.to_string().contains("null out_request")); + assert!(err.to_string().contains("null out_outcome_json")); let intercept_invalid_annotated = wrap_llm_request_intercept_fn( llm_request_intercept_invalid_annotated_cb, ptr::null_mut(), None, ); - let (_request_out, annotated_out) = intercept_invalid_annotated("llm", request.clone(), None) - .expect("invalid annotated JSON should be tolerated as None"); - assert!(annotated_out.is_none()); + let err = intercept_invalid_annotated("llm", request.clone(), None).unwrap_err(); + assert!( + err.to_string() + .contains("invalid LLM request intercept outcome JSON") + ); let sanitize = wrap_llm_sanitize_request_fn(llm_request_passthrough_cb, ptr::null_mut(), None); let sanitized = sanitize(request.clone()); diff --git a/crates/ffi/tests/unit/api/core_tests.rs b/crates/ffi/tests/unit/api/core_tests.rs index 08ce98c2c..261ef8a0e 100644 --- a/crates/ffi/tests/unit/api/core_tests.rs +++ b/crates/ffi/tests/unit/api/core_tests.rs @@ -5,6 +5,62 @@ use super::*; +#[test] +fn test_ffi_llm_request_intercept_outcome_json_allocation_and_validation() { + let headers = cstring(r#"{"x-test":true}"#); + let content = cstring(r#"{"model":"test"}"#); + let request = unsafe { nemo_relay_llm_request_new(headers.as_ptr(), content.as_ptr()) }; + assert!(!request.is_null()); + + let marks = cstring(r#"[{"name":"first"},{"name":"second","data":{"order":2}}]"#); + let mut outcome_json = ptr::null_mut(); + assert_eq!( + unsafe { + api::nemo_relay_llm_request_intercept_outcome_json_new( + request, + ptr::null(), + marks.as_ptr(), + &mut outcome_json, + ) + }, + NemoRelayStatus::Ok + ); + let outcome = unsafe { returned_json(outcome_json) }; + assert_eq!(outcome["request"]["headers"]["x-test"], true); + assert_eq!(outcome["annotated_request"], Json::Null); + assert_eq!(outcome["pending_marks"][0]["name"], "first"); + assert_eq!(outcome["pending_marks"][1]["data"]["order"], 2); + + let malformed_marks = cstring(r#"{"name":"not-an-array"}"#); + assert_eq!( + unsafe { + api::nemo_relay_llm_request_intercept_outcome_json_new( + request, + ptr::null(), + malformed_marks.as_ptr(), + &mut outcome_json, + ) + }, + NemoRelayStatus::InvalidJson + ); + assert!(outcome_json.is_null()); + outcome_json = std::ptr::dangling_mut(); + assert_eq!( + unsafe { + api::nemo_relay_llm_request_intercept_outcome_json_new( + ptr::null(), + ptr::null(), + ptr::null(), + &mut outcome_json, + ) + }, + NemoRelayStatus::NullPointer + ); + assert!(outcome_json.is_null()); + + unsafe { nemo_relay_llm_request_free(request) }; +} + #[test] fn test_ffi_plugin_config_validate_initialize_and_clear() { let _guard = TEST_MUTEX.lock().unwrap(); diff --git a/crates/ffi/tests/unit/api/registry_tests.rs b/crates/ffi/tests/unit/api/registry_tests.rs index 378358057..72afba048 100644 --- a/crates/ffi/tests/unit/api/registry_tests.rs +++ b/crates/ffi/tests/unit/api/registry_tests.rs @@ -381,7 +381,7 @@ fn test_ffi_helper_rejection_and_null_name_paths() { NemoRelayStatus::Ok ); let llm_json = returned_json(llm_out); - assert_eq!(llm_json["content"]["model"], json!("ffi-model")); + assert_eq!(llm_json["request"]["content"]["model"], json!("ffi-model")); assert_eq!( nemo_relay_llm_request_intercepts(llm_name.as_ptr(), request.as_ptr(), ptr::null_mut()), @@ -1536,7 +1536,10 @@ fn test_ffi_llm_execute_stream_and_atif_exporter() { NemoRelayStatus::Ok ); let helper_json = returned_json(helper_out); - assert_eq!(helper_json["content"]["intercepted"], json!(true)); + assert_eq!( + helper_json["request"]["content"]["intercepted"], + json!(true) + ); assert_eq!( nemo_relay_llm_conditional_execution(request.as_ptr()), diff --git a/crates/ffi/tests/unit/api_tests.rs b/crates/ffi/tests/unit/api_tests.rs index 8f4e728a0..78f37919d 100644 --- a/crates/ffi/tests/unit/api_tests.rs +++ b/crates/ffi/tests/unit/api_tests.rs @@ -374,10 +374,15 @@ unsafe extern "C" fn llm_request_intercept_cb( _name: *const c_char, request: *const FfiLLMRequest, _annotated_json: *const c_char, - out_request: *mut *mut FfiLLMRequest, - _out_annotated_json: *mut *mut c_char, + out_outcome_json: *mut *mut c_char, ) -> NemoRelayStatus { - unsafe { *out_request = llm_request_cb(ptr::null_mut(), request) }; + let transformed = unsafe { Box::from_raw(llm_request_cb(ptr::null_mut(), request)) }; + let outcome = json!({ + "request": transformed.0, + "annotated_request": null, + "pending_marks": [], + }); + unsafe { *out_outcome_json = CString::new(outcome.to_string()).unwrap().into_raw() }; NemoRelayStatus::Ok } @@ -386,8 +391,7 @@ unsafe extern "C" fn llm_request_intercept_fail_cb( _name: *const c_char, _request: *const FfiLLMRequest, _annotated_json: *const c_char, - _out_request: *mut *mut FfiLLMRequest, - _out_annotated_json: *mut *mut c_char, + _out_outcome_json: *mut *mut c_char, ) -> NemoRelayStatus { crate::error::set_last_error("llm request intercept callback failed"); NemoRelayStatus::Internal diff --git a/crates/ffi/tests/unit/callable_tests.rs b/crates/ffi/tests/unit/callable_tests.rs index 3089f5e83..b8fc9a30e 100644 --- a/crates/ffi/tests/unit/callable_tests.rs +++ b/crates/ffi/tests/unit/callable_tests.rs @@ -103,20 +103,24 @@ unsafe extern "C" fn llm_request_intercept_cb( _name: *const c_char, request: *const FfiLLMRequest, annotated_json: *const c_char, - out_request: *mut *mut FfiLLMRequest, - out_annotated_json: *mut *mut c_char, + out_outcome_json: *mut *mut c_char, ) -> NemoRelayStatus { let mut req = unsafe { (&*request).0.clone() }; req.content["intercepted"] = json!(true); - unsafe { *out_request = Box::into_raw(Box::new(FfiLLMRequest(req))) }; - if annotated_json.is_null() { - unsafe { *out_annotated_json = std::ptr::null_mut() }; + let annotated = if annotated_json.is_null() { + Json::Null } else { let s = unsafe { CStr::from_ptr(annotated_json) } .to_string_lossy() .into_owned(); - unsafe { *out_annotated_json = CString::new(s).unwrap().into_raw() }; - } + serde_json::from_str(&s).unwrap() + }; + let outcome = json!({ + "request": req, + "annotated_request": annotated, + "pending_marks": [], + }); + unsafe { *out_outcome_json = CString::new(outcome.to_string()).unwrap().into_raw() }; NemoRelayStatus::Ok } @@ -283,8 +287,8 @@ fn test_wrap_tool_exec_and_intercept_callbacks() { fn test_wrap_llm_request_response_and_conditional_callbacks() { let request_intercept = wrap_llm_request_intercept_fn(llm_request_intercept_cb, std::ptr::null_mut(), None); - let (intercepted, _annotated) = request_intercept("llm", make_request(), None).unwrap(); - assert_eq!(intercepted.content["intercepted"], json!(true)); + let outcome = request_intercept("llm", make_request(), None).unwrap(); + assert_eq!(outcome.request.content["intercepted"], json!(true)); let sanitize_request = wrap_llm_sanitize_request_fn(llm_request_null_cb, std::ptr::null_mut(), None); @@ -338,10 +342,11 @@ fn test_wrap_llm_request_intercept_with_annotated_input() { stream: None, extra: serde_json::Map::from_iter([("annotated".into(), json!(true))]), }; - let (intercepted, annotated_out) = - request_intercept("llm", make_request(), Some(annotated)).unwrap(); - assert_eq!(intercepted.content["intercepted"], json!(true)); - let annotated_out = annotated_out.expect("expected annotated request output"); + let outcome = request_intercept("llm", make_request(), Some(annotated)).unwrap(); + assert_eq!(outcome.request.content["intercepted"], json!(true)); + let annotated_out = outcome + .annotated_request + .expect("expected annotated request output"); assert_eq!(annotated_out.model.as_deref(), Some("test-model")); assert_eq!(annotated_out.extra.get("annotated"), Some(&json!(true))); } diff --git a/crates/node/plugin.d.ts b/crates/node/plugin.d.ts index 4fbb6be8a..e0dc1d4ba 100644 --- a/crates/node/plugin.d.ts +++ b/crates/node/plugin.d.ts @@ -84,7 +84,14 @@ export interface PluginContext { breakChain: boolean, callback: (args: { name: string; request: Json; annotated: Json | null }) => { request: Json; - annotated: Json | null; + annotated?: Json | null; + pendingMarks?: Array<{ + name: string; + category?: string | null; + categoryProfile?: Json; + data?: Json; + metadata?: Json; + }>; }, ): void; /** Register an LLM execution intercept for this component. */ diff --git a/crates/node/src/api/mod.rs b/crates/node/src/api/mod.rs index 3796607e7..503a335b3 100644 --- a/crates/node/src/api/mod.rs +++ b/crates/node/src/api/mod.rs @@ -2259,6 +2259,9 @@ pub fn register_llm_request_intercept( name: String, priority: i32, break_chain: bool, + #[napi( + ts_arg_type = "(args: { name: string; request: Json; annotated: Json | null }) => { request: Json; annotated?: Json | null; pendingMarks?: Array<{ name: string; category?: string | null; categoryProfile?: Json; data?: Json; metadata?: Json }> }" + )] callable: ThreadsafeFunction, ) -> Result<()> { core_registry_api::register_llm_request_intercept( @@ -2729,6 +2732,9 @@ pub fn scope_register_llm_request_intercept( name: String, priority: i32, break_chain: bool, + #[napi( + ts_arg_type = "(args: { name: string; request: Json; annotated: Json | null }) => { request: Json; annotated?: Json | null; pendingMarks?: Array<{ name: string; category?: string | null; categoryProfile?: Json; data?: Json; metadata?: Json }> }" + )] callable: ThreadsafeFunction, ) -> Result<()> { let uuid = uuid::Uuid::parse_str(&scope_uuid) @@ -2942,7 +2948,9 @@ pub fn tool_conditional_execution(env: Env, name: String, args: Json) -> Result< /// Run the registered LLM request intercept chain on the given request. /// The `request` should be a JSON object with `headers` and `content` fields matching /// the `LlmRequest` schema. Returns the transformed request as JSON. -#[napi(ts_return_type = "Promise")] +#[napi( + ts_return_type = "Promise<{ request: Json; annotated: Json | null; pendingMarks: Array<{ name: string; category?: string | null; categoryProfile?: Json; data?: Json; metadata?: Json }> }>" +)] pub fn llm_request_intercepts(env: Env, name: String, request: Json) -> Result { let llm_request: LlmRequest = serde_json::from_value(request) .map_err(|e| napi::Error::from_reason(format!("invalid LlmRequest: {e}")))?; @@ -2952,7 +2960,13 @@ pub fn llm_request_intercepts(env: Env, name: String, request: Json) -> Result, + #[serde(default)] + category_profile: Option, + #[serde(default)] + data: Option, + #[serde(default)] + metadata: Option, +} + +impl From for PendingMarkSpec { + fn from(mark: JsPendingMarkSpec) -> Self { + Self { + name: mark.name, + category: mark.category, + category_profile: mark.category_profile, + data: mark.data, + metadata: mark.metadata, + } + } +} + +impl From for JsPendingMarkSpec { + fn from(mark: PendingMarkSpec) -> Self { + Self { + name: mark.name, + category: mark.category, + category_profile: mark.category_profile, + data: mark.data, + metadata: mark.metadata, + } + } +} + +/// Convert canonical pending marks to JavaScript-facing DTOs. +#[must_use] +pub(crate) fn js_pending_marks(marks: Vec) -> Vec { + marks.into_iter().map(Into::into).collect() +} + fn recv_json_or_null(rx: std::sync::mpsc::Receiver, error_prefix: &str) -> Json { rx.recv().unwrap_or_else(|e| { record_callback_error(format!("{error_prefix}: {e}")); @@ -208,7 +254,9 @@ pub fn wrap_js_tool_exec_fn( /// /// The JS callback receives a single JSON object /// `{ name: string, request: LlmRequest, annotated: AnnotatedLlmRequest | null }` -/// and must return `{ request: LlmRequest, annotated: AnnotatedLlmRequest | null }`. +/// and must return `{ request, annotated?, pendingMarks? }`. +/// When `annotated` is non-null, request content is read-only and provider-body +/// edits must be made through the returned annotation; headers remain writable. pub fn wrap_js_llm_request_intercept_fn( func: ThreadsafeFunction, ) -> LlmRequestInterceptFn { @@ -217,7 +265,7 @@ pub fn wrap_js_llm_request_intercept_fn( move |name: &str, request: LlmRequest, annotated: Option| - -> Result<(LlmRequest, Option)> { + -> Result { let func = func.clone(); let req_json = serde_json::to_value(&request).unwrap_or(Json::Null); let annotated_json = annotated @@ -245,32 +293,23 @@ pub fn wrap_js_llm_request_intercept_fn( } let result = recv_json_result(rx, "JS LLM request intercept callback failed")?; - // Validate expected shape: { "request": {...}, "annotated": ... } - let obj = result.as_object().ok_or_else(|| { - FlowError::Internal( - "JS LLM request intercept: expected object with 'request' and 'annotated' fields".to_string(), - ) - })?; - - let new_request: LlmRequest = serde_json::from_value( - obj.get("request").cloned().unwrap_or(Json::Null), - ) - .map_err(|e| { - FlowError::Internal(format!( - "JS LLM request intercept: failed to deserialize request: {e}" - )) + #[derive(Deserialize)] + #[serde(rename_all = "camelCase")] + struct JsOutcome { + request: LlmRequest, + #[serde(default)] + annotated: Option, + #[serde(default)] + pending_marks: Vec, + } + let outcome: JsOutcome = serde_json::from_value(result).map_err(|e| { + FlowError::Internal(format!("invalid JS LLM request intercept outcome: {e}")) })?; - - let new_annotated: Option = match obj.get("annotated") { - Some(Json::Null) | None => None, - Some(val) => Some(serde_json::from_value(val.clone()).map_err(|e| { - FlowError::Internal(format!( - "JS LLM request intercept: failed to deserialize annotated: {e}" - )) - })?), - }; - - Ok((new_request, new_annotated)) + Ok(LlmRequestInterceptOutcome { + request: outcome.request, + annotated_request: outcome.annotated, + pending_marks: outcome.pending_marks.into_iter().map(Into::into).collect(), + }) }, ) } diff --git a/crates/node/tests/codec_tests.mjs b/crates/node/tests/codec_tests.mjs index c6bd40bc2..400ad62d6 100644 --- a/crates/node/tests/codec_tests.mjs +++ b/crates/node/tests/codec_tests.mjs @@ -153,6 +153,42 @@ describe('Codec pipeline integration', () => { } }); + it('rejects raw request content edits before provider execution', async () => { + let providerCalled = false; + registerLlmRequestIntercept('raw-content-test', 10, false, ({ request, annotated }) => ({ + request: { + ...request, + content: { + ...request.content, + model: 'raw-model-edit', + }, + }, + annotated, + })); + + const handle = pushScope('raw-content-scope', ScopeType.Agent); + try { + await assert.rejects( + () => + execWithCodec( + 'test-llm', + makeRequest(), + async () => { + providerCalled = true; + return { choices: [] }; + }, + mockDecode, + mockEncode, + ), + /request\.content/, + ); + assert.equal(providerCalled, false); + } finally { + popScope(handle); + deregisterLlmRequestIntercept('raw-content-test'); + } + }); + it('different codec functions produce different results', async () => { let usedModel = null; diff --git a/crates/node/tests/llm_tests.mjs b/crates/node/tests/llm_tests.mjs index fdc9000ba..022afb8ad 100644 --- a/crates/node/tests/llm_tests.mjs +++ b/crates/node/tests/llm_tests.mjs @@ -592,7 +592,7 @@ describe('LLM intercepts', () => { null, null, ), - /expected object with 'request' and 'annotated' fields/i, + /invalid JS LLM request intercept outcome/i, ); } finally { deregisterLlmRequestIntercept('node_llm_req_bad'); @@ -729,7 +729,7 @@ describe('LLM intercepts', () => { null, ); - for (; ;) { + for (;;) { const chunk = await stream.next(); if (chunk === null) { break; @@ -769,7 +769,7 @@ describe('LLM intercepts', () => { null, ); - for (; ;) { + for (;;) { const chunk = await stream.next(); if (chunk === null) { break; @@ -826,11 +826,36 @@ describe('LLM intercepts', () => { return { request, annotated, + pendingMarks: [ + { + name: 'request.first', + categoryProfile: { subtype: 'optimizer.saved_tokens' }, + data: { order: 1 }, + }, + { name: 'request.second', metadata: { source: 'node' } }, + ], }; }); const result = await llmRequestIntercepts('helper_llm', makeNative()); - assert.equal(result.content.helper, true); + assert.equal(result.request.content.helper, true); + assert.equal(result.annotated, null); + assert.deepEqual(result.pendingMarks, [ + { + name: 'request.first', + category: null, + categoryProfile: { subtype: 'optimizer.saved_tokens' }, + data: { order: 1 }, + metadata: null, + }, + { + name: 'request.second', + category: null, + categoryProfile: null, + data: null, + metadata: { source: 'node' }, + }, + ]); deregisterLlmRequestIntercept('node_llm_req_helper'); }); diff --git a/crates/node/tests/scope_local_tests.mjs b/crates/node/tests/scope_local_tests.mjs index ebefb65d1..313e78e00 100644 --- a/crates/node/tests/scope_local_tests.mjs +++ b/crates/node/tests/scope_local_tests.mjs @@ -913,7 +913,7 @@ describe('Scope-local LLM intercepts', () => { null, null, ), - /expected object with 'request' and 'annotated' fields/i, + /invalid JS LLM request intercept outcome/i, ); } finally { scopeDeregisterLlmRequestIntercept(scope.uuid, 'sl_llm_req_bad'); diff --git a/crates/plugin/src/lib.rs b/crates/plugin/src/lib.rs index 31f0c4e40..662b191a2 100644 --- a/crates/plugin/src/lib.rs +++ b/crates/plugin/src/lib.rs @@ -16,8 +16,10 @@ use std::ptr; use std::sync::Mutex; pub use nemo_relay_types::Json; -pub use nemo_relay_types::api::event::{Event, ScopeCategory}; -pub use nemo_relay_types::api::llm::{LlmAttributes, LlmRequest}; +pub use nemo_relay_types::api::event::{ + CategoryProfile, Event, EventCategory, PendingMarkSpec, ScopeCategory, +}; +pub use nemo_relay_types::api::llm::{LlmAttributes, LlmRequest, LlmRequestInterceptOutcome}; pub use nemo_relay_types::api::scope::{HandleAttributes, ScopeAttributes, ScopeType}; pub use nemo_relay_types::api::tool::ToolAttributes; pub use nemo_relay_types::codec::request::AnnotatedLlmRequest; @@ -251,8 +253,7 @@ pub type NemoRelayNativeLlmRequestInterceptCb = unsafe extern "C" fn( name: *const NemoRelayNativeString, request_json: *const NemoRelayNativeString, annotated_json: *const NemoRelayNativeString, - out_request_json: *mut *mut NemoRelayNativeString, - out_annotated_json: *mut *mut NemoRelayNativeString, + out_outcome_json: *mut *mut NemoRelayNativeString, ) -> NemoRelayStatus; /// Native LLM execution intercept callback. @@ -1467,11 +1468,7 @@ impl<'a> PluginContext<'a> { callback: F, ) -> Result<()> where - F: Fn( - &str, - LlmRequest, - Option, - ) -> Result<(LlmRequest, Option)> + F: Fn(&str, LlmRequest, Option) -> Result + Send + Sync + 'static, @@ -2121,25 +2118,19 @@ unsafe extern "C" fn typed_llm_request_intercept_trampoline( name: *const NemoRelayNativeString, request_json: *const NemoRelayNativeString, annotated_json: *const NemoRelayNativeString, - out_request_json: *mut *mut NemoRelayNativeString, - out_annotated_json: *mut *mut NemoRelayNativeString, + out_outcome_json: *mut *mut NemoRelayNativeString, ) -> NemoRelayStatus where - F: Fn( - &str, - LlmRequest, - Option, - ) -> Result<(LlmRequest, Option)> + F: Fn(&str, LlmRequest, Option) -> Result + Send + Sync + 'static, { - if user_data.is_null() || out_request_json.is_null() || out_annotated_json.is_null() { + if user_data.is_null() || out_outcome_json.is_null() { return NemoRelayStatus::NullPointer; } unsafe { - *out_request_json = ptr::null_mut(); - *out_annotated_json = ptr::null_mut(); + *out_outcome_json = ptr::null_mut(); } let state = unsafe { &*(user_data as *const TypedCallback) }; let result = catch_unwind(AssertUnwindSafe(|| { @@ -2148,35 +2139,15 @@ where let annotated: Option = read_optional_json_value(&state.host, annotated_json, "annotated LLM request")?; match (state.callback)(&name, request, annotated) { - Ok((request, annotated)) => { - let Some(request) = HostString::from_json(&state.host, &request) else { - set_last_error(&state.host, "failed to allocate LLM request output"); + Ok(outcome) => { + let Some(outcome) = HostString::from_json(&state.host, &outcome) else { + set_last_error(&state.host, "failed to allocate LLM request outcome"); return Ok(NemoRelayStatus::Internal); }; - let annotated = match annotated.as_ref() { - Some(annotated) => { - let Some(annotated) = HostString::from_json(&state.host, annotated) else { - set_last_error( - &state.host, - "failed to allocate annotated LLM request output", - ); - return Ok(NemoRelayStatus::Internal); - }; - Some(annotated) - } - None => None, - }; unsafe { - *out_request_json = request.ptr; - *out_annotated_json = annotated - .as_ref() - .map(|annotated| annotated.ptr) - .unwrap_or(ptr::null_mut()); - } - std::mem::forget(request); - if let Some(annotated) = annotated { - std::mem::forget(annotated); + *out_outcome_json = outcome.ptr; } + std::mem::forget(outcome); Ok(NemoRelayStatus::Ok) } Err(message) => Ok(callback_error(&state.host, message)), @@ -2646,7 +2617,7 @@ pub unsafe fn export_plugin( } unsafe { *out = NemoRelayNativePluginV1::default() }; let host_ref = unsafe { &*host }; - export_plugin_checked(host_ref, out, plugin) + export_plugin_checked(host_ref, out, || plugin) } /// Initializes a native plugin descriptor from a constructor callback. @@ -2670,21 +2641,18 @@ where } unsafe { *out = NemoRelayNativePluginV1::default() }; let host_ref = unsafe { &*host }; - if host_ref.abi_version != NEMO_RELAY_NATIVE_ABI_VERSION { - return NemoRelayStatus::InvalidArg; - } - if host_ref.struct_size < std::mem::size_of::() { - return NemoRelayStatus::InvalidArg; - } - - export_plugin_checked(host_ref, out, constructor()) + export_plugin_checked(host_ref, out, constructor) } -fn export_plugin_checked( +fn export_plugin_checked( host_ref: &NemoRelayNativeHostApiV1, out: *mut NemoRelayNativePluginV1, - plugin: P, -) -> NemoRelayStatus { + constructor: F, +) -> NemoRelayStatus +where + P: NativePlugin, + F: FnOnce() -> P, +{ if host_ref.abi_version != NEMO_RELAY_NATIVE_ABI_VERSION { return NemoRelayStatus::InvalidArg; } @@ -2692,6 +2660,7 @@ fn export_plugin_checked( return NemoRelayStatus::InvalidArg; } + let plugin = constructor(); let kind = plugin.plugin_kind().to_owned(); let allows_multiple_components = plugin.allows_multiple_components(); let Some(kind_handle) = HostString::new(host_ref, &kind) else { diff --git a/crates/plugin/tests/typed_callbacks.rs b/crates/plugin/tests/typed_callbacks.rs index d3a502ffb..b1fba3674 100644 --- a/crates/plugin/tests/typed_callbacks.rs +++ b/crates/plugin/tests/typed_callbacks.rs @@ -14,16 +14,16 @@ use std::sync::{ use nemo_relay_plugin::{ AnnotatedLlmRequest, ConfigDiagnostic, DiagnosticLevel, Event, Json, LlmJsonStream, LlmNext, - LlmRequest, LlmStream, LlmStreamNext, NEMO_RELAY_NATIVE_ABI_VERSION, NativePlugin, - NemoRelayNativeEventSubscriberCb, NemoRelayNativeFreeFn, NemoRelayNativeHostApiV1, - NemoRelayNativeJsonCb, NemoRelayNativeLlmConditionalCb, NemoRelayNativeLlmExecutionCb, - NemoRelayNativeLlmRequestCb, NemoRelayNativeLlmRequestInterceptCb, - NemoRelayNativeLlmStreamExecutionCb, NemoRelayNativeLlmStreamV1, NemoRelayNativePluginContext, - NemoRelayNativePluginV1, NemoRelayNativeScopeHandle, NemoRelayNativeScopeStack, - NemoRelayNativeScopeStackBinding, NemoRelayNativeScopeType, NemoRelayNativeString, - NemoRelayNativeToolConditionalCb, NemoRelayNativeToolExecutionCb, NemoRelayNativeToolJsonCb, - NemoRelayNativeWithScopeStackCb, NemoRelayStatus, PluginContext, PluginRuntime, ScopeType, - ToolNext, + LlmRequest, LlmRequestInterceptOutcome, LlmStream, LlmStreamNext, + NEMO_RELAY_NATIVE_ABI_VERSION, NativePlugin, NemoRelayNativeEventSubscriberCb, + NemoRelayNativeFreeFn, NemoRelayNativeHostApiV1, NemoRelayNativeJsonCb, + NemoRelayNativeLlmConditionalCb, NemoRelayNativeLlmExecutionCb, NemoRelayNativeLlmRequestCb, + NemoRelayNativeLlmRequestInterceptCb, NemoRelayNativeLlmStreamExecutionCb, + NemoRelayNativeLlmStreamV1, NemoRelayNativePluginContext, NemoRelayNativePluginV1, + NemoRelayNativeScopeHandle, NemoRelayNativeScopeStack, NemoRelayNativeScopeStackBinding, + NemoRelayNativeScopeType, NemoRelayNativeString, NemoRelayNativeToolConditionalCb, + NemoRelayNativeToolExecutionCb, NemoRelayNativeToolJsonCb, NemoRelayNativeWithScopeStackCb, + NemoRelayStatus, PendingMarkSpec, PluginContext, PluginRuntime, ScopeType, ToolNext, }; use serde_json::{Map, json}; @@ -2580,14 +2580,13 @@ fn typed_callbacks_reject_null_abi_pointers_before_decoding_inputs() { let mut ctx = test_context(&host); ctx.register_llm_request_intercept("llm-request-intercept", 0, false, |_name, request, ann| { - Ok((request, ann)) + Ok(LlmRequestInterceptOutcome::new(request, ann)) }) .unwrap(); let registration = take_llm_request_intercept_registration(); let name = host_string(&host, "llm"); let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); - let mut out_request = ptr::null_mut(); - let mut out_annotated = ptr::null_mut(); + let mut out_outcome = ptr::null_mut(); assert_eq!( unsafe { (registration.cb)( @@ -2595,8 +2594,7 @@ fn typed_callbacks_reject_null_abi_pointers_before_decoding_inputs() { name, request, ptr::null(), - &mut out_request, - &mut out_annotated, + &mut out_outcome, ) }, NemoRelayStatus::NullPointer @@ -2609,20 +2607,6 @@ fn typed_callbacks_reject_null_abi_pointers_before_decoding_inputs() { request, ptr::null(), ptr::null_mut(), - &mut out_annotated, - ) - }, - NemoRelayStatus::NullPointer - ); - assert_eq!( - unsafe { - (registration.cb)( - registration.user_data as *mut c_void, - name, - request, - ptr::null(), - &mut out_request, - ptr::null_mut(), ) }, NemoRelayStatus::NullPointer @@ -2889,14 +2873,13 @@ fn typed_callbacks_report_invalid_json_for_each_decoder_family() { let mut ctx = test_context(&host); ctx.register_llm_request_intercept("llm-request", 0, false, |_name, request, ann| { - Ok((request, ann)) + Ok(LlmRequestInterceptOutcome::new(request, ann)) }) .unwrap(); let registration = take_llm_request_intercept_registration(); let name = host_string(&host, "llm"); let bad_request = host_string(&host, "{not json"); - let mut out_request = ptr::null_mut(); - let mut out_annotated = ptr::null_mut(); + let mut out_outcome = ptr::null_mut(); assert_eq!( unsafe { (registration.cb)( @@ -2904,8 +2887,7 @@ fn typed_callbacks_report_invalid_json_for_each_decoder_family() { name, bad_request, ptr::null(), - &mut out_request, - &mut out_annotated, + &mut out_outcome, ) }, NemoRelayStatus::InvalidJson @@ -2919,8 +2901,7 @@ fn typed_callbacks_report_invalid_json_for_each_decoder_family() { name, request, bad_annotation, - &mut out_request, - &mut out_annotated, + &mut out_outcome, ) }, NemoRelayStatus::InvalidJson @@ -3091,8 +3072,7 @@ fn typed_callbacks_map_additional_callback_errors() { let registration = take_llm_request_intercept_registration(); let name = host_string(&host, "llm"); let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); - let mut out_request = ptr::null_mut(); - let mut out_annotated = ptr::null_mut(); + let mut out_outcome = ptr::null_mut(); assert_eq!( unsafe { (registration.cb)( @@ -3100,8 +3080,7 @@ fn typed_callbacks_map_additional_callback_errors() { name, request, ptr::null(), - &mut out_request, - &mut out_annotated, + &mut out_outcome, ) }, NemoRelayStatus::Internal @@ -4231,7 +4210,10 @@ fn typed_llm_request_intercept_does_not_publish_partial_outputs() { let host = test_host(); let mut ctx = test_context(&host); ctx.register_llm_request_intercept("llm", 0, false, |_name, request, _annotated| { - Ok((request, Some(test_annotated_llm_request()))) + Ok(LlmRequestInterceptOutcome::new( + request, + Some(test_annotated_llm_request()), + )) }) .unwrap(); @@ -4241,11 +4223,9 @@ fn typed_llm_request_intercept_does_not_publish_partial_outputs() { assert!(!registration.break_chain); let name = host_string(&host, "llm"); let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); - let stale_request = host_string(&host, r#"{"stale":"request"}"#); - let stale_annotated = host_string(&host, r#"{"stale":"annotated"}"#); - let mut out_request = stale_request; - let mut out_annotated = stale_annotated; - *STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = Some(1); + let stale_outcome = host_string(&host, r#"{"stale":"outcome"}"#); + let mut out_outcome = stale_outcome; + *STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = Some(0); let live_before = live_host_strings(); let status = unsafe { (registration.cb)( @@ -4253,18 +4233,15 @@ fn typed_llm_request_intercept_does_not_publish_partial_outputs() { name, request, ptr::null(), - &mut out_request, - &mut out_annotated, + &mut out_outcome, ) }; *STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = None; assert_eq!(status, NemoRelayStatus::Internal); - assert!(out_request.is_null()); - assert!(out_annotated.is_null()); + assert!(out_outcome.is_null()); assert_eq!(live_host_strings(), live_before); unsafe { - (host.string_free)(stale_request); - (host.string_free)(stale_annotated); + (host.string_free)(stale_outcome); (host.string_free)(name); (host.string_free)(request); registration.free(); @@ -4272,15 +4249,14 @@ fn typed_llm_request_intercept_does_not_publish_partial_outputs() { let mut ctx = test_context(&host); ctx.register_llm_request_intercept("llm", 0, false, |_name, request, _annotated| { - Ok((request, None)) + Ok(LlmRequestInterceptOutcome::new(request, None)) }) .unwrap(); let registration = take_llm_request_intercept_registration(); let name = host_string(&host, "llm"); let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); - let mut out_request = ptr::null_mut(); - let mut out_annotated = ptr::null_mut(); + let mut out_outcome = ptr::null_mut(); *STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = Some(0); let status = unsafe { (registration.cb)( @@ -4288,14 +4264,12 @@ fn typed_llm_request_intercept_does_not_publish_partial_outputs() { name, request, ptr::null(), - &mut out_request, - &mut out_annotated, + &mut out_outcome, ) }; *STRING_NEW_REMAINING_SUCCESSES.lock().unwrap() = None; assert_eq!(status, NemoRelayStatus::Internal); - assert!(out_request.is_null()); - assert!(out_annotated.is_null()); + assert!(out_outcome.is_null()); unsafe { (host.string_free)(name); (host.string_free)(request); @@ -4313,7 +4287,10 @@ fn typed_llm_request_intercept_round_trips_request_and_annotations() { assert!(annotated.is_some()); request.headers.insert("x-mutated".into(), json!(true)); request.content["rewritten"] = json!(true); - Ok((request, Some(test_annotated_llm_request()))) + Ok(LlmRequestInterceptOutcome::new( + request, + Some(test_annotated_llm_request()), + )) }) .unwrap(); @@ -4327,24 +4304,22 @@ fn typed_llm_request_intercept_round_trips_request_and_annotations() { &host, serde_json::to_value(test_annotated_llm_request()).unwrap(), ); - let mut out_request = ptr::null_mut(); - let mut out_annotated = ptr::null_mut(); + let mut out_outcome = ptr::null_mut(); let status = unsafe { (registration.cb)( registration.user_data as *mut c_void, name, request, annotated, - &mut out_request, - &mut out_annotated, + &mut out_outcome, ) }; assert_eq!(status, NemoRelayStatus::Ok); - let out_request = read_json_and_free(&host, out_request); - assert_eq!(out_request["headers"]["x-mutated"], json!(true)); - assert_eq!(out_request["content"]["rewritten"], json!(true)); - let out_annotated = read_json_and_free(&host, out_annotated); - assert_eq!(out_annotated["messages"], json!([])); + let outcome = read_json_and_free(&host, out_outcome); + assert_eq!(outcome["request"]["headers"]["x-mutated"], json!(true)); + assert_eq!(outcome["request"]["content"]["rewritten"], json!(true)); + assert_eq!(outcome["annotated_request"]["messages"], json!([])); + assert_eq!(outcome["pending_marks"], json!([])); unsafe { (host.string_free)(name); @@ -4355,35 +4330,87 @@ fn typed_llm_request_intercept_round_trips_request_and_annotations() { let mut ctx = test_context(&host); ctx.register_llm_request_intercept("llm", 0, false, |_name, request, _annotated| { - Ok((request, None)) + Ok(LlmRequestInterceptOutcome::new(request, None)) }) .unwrap(); let registration = take_llm_request_intercept_registration(); let name = host_string(&host, "llm"); let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); - let mut out_request = ptr::null_mut(); - let mut out_annotated = host_string(&host, r#"{"stale":true}"#); - let stale_annotated = out_annotated; + let mut out_outcome = host_string(&host, r#"{"stale":true}"#); + let stale_outcome = out_outcome; let status = unsafe { (registration.cb)( registration.user_data as *mut c_void, name, request, ptr::null(), - &mut out_request, - &mut out_annotated, + &mut out_outcome, ) }; assert_eq!(status, NemoRelayStatus::Ok); - assert!(out_annotated.is_null()); + let outcome = read_json_and_free(&host, out_outcome); + assert!(outcome["annotated_request"].is_null()); + assert_eq!(outcome["request"]["content"]["input"], json!(true)); + assert_eq!(outcome["pending_marks"], json!([])); + unsafe { + (host.string_free)(stale_outcome); + (host.string_free)(name); + (host.string_free)(request); + registration.free(); + } +} + +#[test] +fn typed_llm_request_intercept_serializes_canonical_outcome() { + let _guard = begin_test(); + let host = test_host(); + let mut ctx = test_context(&host); + ctx.register_llm_request_intercept("llm", 23, false, |_name, mut request, annotated| { + request.content["rewritten"] = json!(true); + Ok( + LlmRequestInterceptOutcome::new(request, annotated).with_pending_mark( + PendingMarkSpec::builder() + .name("plugin.request.rewritten") + .data(json!({ "saved_tokens": 7 })) + .build(), + ), + ) + }) + .unwrap(); + + let registration = take_llm_request_intercept_registration(); + assert_eq!(registration.priority, 23); + assert!(!registration.break_chain); + let name = host_string(&host, "llm"); + let request = json_host_string(&host, serde_json::to_value(test_llm_request()).unwrap()); + let annotated = json_host_string( + &host, + serde_json::to_value(test_annotated_llm_request()).unwrap(), + ); + let mut out_outcome = ptr::null_mut(); + let status = unsafe { + (registration.cb)( + registration.user_data as *mut c_void, + name, + request, + annotated, + &mut out_outcome, + ) + }; + assert_eq!(status, NemoRelayStatus::Ok); + let outcome = read_json_and_free(&host, out_outcome); + assert_eq!(outcome["request"]["content"]["rewritten"], true); + assert_eq!(outcome["annotated_request"]["messages"], json!([])); assert_eq!( - read_json_and_free(&host, out_request)["content"]["input"], - json!(true) + outcome["pending_marks"][0]["name"], + "plugin.request.rewritten" ); + assert_eq!(outcome["pending_marks"][0]["data"]["saved_tokens"], 7); + unsafe { - (host.string_free)(stale_annotated); (host.string_free)(name); (host.string_free)(request); + (host.string_free)(annotated); registration.free(); } } diff --git a/crates/python/src/py_api/mod.rs b/crates/python/src/py_api/mod.rs index f6682b43e..2603e92f3 100644 --- a/crates/python/src/py_api/mod.rs +++ b/crates/python/src/py_api/mod.rs @@ -1207,9 +1207,12 @@ fn tool_conditional_execution(name: &str, args: &Bound<'_, PyAny>) -> PyResult<( /// Returns: /// The (possibly transformed) ``LlmRequest``. #[pyfunction] -fn llm_request_intercepts(name: &str, request: PyLLMRequest) -> PyResult { +fn llm_request_intercepts( + name: &str, + request: PyLLMRequest, +) -> PyResult { let result = core_llm_api::llm_request_intercepts(name, request.inner).map_err(to_py_err)?; - Ok(PyLLMRequest { inner: result }) + Ok(crate::py_types::PyLLMRequestInterceptOutcome { inner: result }) } /// Run the registered LLM conditional execution guardrail chain. diff --git a/crates/python/src/py_callable.rs b/crates/python/src/py_callable.rs index 07bc392ea..7b8629be5 100644 --- a/crates/python/src/py_callable.rs +++ b/crates/python/src/py_callable.rs @@ -35,13 +35,15 @@ use serde_json::Value as Json; use tokio_stream::Stream; use nemo_relay::api::event::Event; -use nemo_relay::api::llm::LlmRequest; +use nemo_relay::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; use nemo_relay::codec::request::AnnotatedLlmRequest as AnnotatedLLMRequest; use nemo_relay::codec::response::AnnotatedLlmResponse as AnnotatedLLMResponse; use nemo_relay::codec::traits::{LlmCodec, LlmResponseCodec}; use crate::convert::{json_to_py, py_to_json}; -use crate::py_types::{PyAnnotatedLLMRequest, PyAnnotatedLLMResponse, PyLLMRequest}; +use crate::py_types::{ + PyAnnotatedLLMRequest, PyAnnotatedLLMResponse, PyLLMRequest, PyLLMRequestInterceptOutcome, +}; type PyValueFuture = Pin>> + Send>>; @@ -643,13 +645,15 @@ pub fn wrap_py_llm_conditional_fn(py_fn: Py) -> LlmConditionalFn { /// Wrap a Python callable for unified LLM request intercepts. /// /// The Python function receives ``(name: str, request: LlmRequest, annotated: AnnotatedLLMRequest | None)`` -/// and must return ``(LlmRequest, AnnotatedLLMRequest | None)``. +/// and must return ``LLMRequestInterceptOutcome``. +/// When ``annotated`` is present, request content is read-only and provider-body +/// edits must be made through the returned annotation; headers remain writable. pub fn wrap_py_llm_request_intercept_fn(py_fn: Py) -> LlmRequestInterceptFn { Arc::new( move |name: &str, request: LlmRequest, annotated: Option| - -> FlowResult<(LlmRequest, Option)> { + -> FlowResult { Python::attach(|py| { let py_req = PyLLMRequest { inner: request.clone(), @@ -673,42 +677,14 @@ pub fn wrap_py_llm_request_intercept_fn(py_fn: Py) -> LlmRequestIntercept FlowError::Internal(format!("LLM request intercept callable failed: {e}")) })?; - // Extract the tuple (LlmRequest, AnnotatedLLMRequest | None) - let tuple = result.bind(py); - let new_req: PyLLMRequest = tuple - .get_item(0) - .map_err(|e| { - FlowError::Internal(format!( - "LLM request intercept result[0] extraction failed: {e}" - )) - })? - .extract() + result + .extract::(py) + .map(|value| value.inner) .map_err(|e| { FlowError::Internal(format!( - "LLM request intercept result[0] is not LlmRequest: {e}" + "LLM request intercept must return LLMRequestInterceptOutcome: {e}" )) - })?; - let ann_item = tuple.get_item(1).map_err(|e| { - FlowError::Internal(format!( - "LLM request intercept result[1] extraction failed: {e}" - )) - })?; - let new_ann = if ann_item.is_none() { - None - } else { - Some( - ann_item - .extract::() - .map_err(|e| { - FlowError::Internal(format!( - "LLM request intercept result[1] is not AnnotatedLLMRequest: {e}" - )) - })? - .inner, - ) - }; - - Ok((new_req.inner, new_ann)) + }) }) }, ) diff --git a/crates/python/src/py_types/core.rs b/crates/python/src/py_types/core.rs index 98dfc6a59..fd7aba216 100644 --- a/crates/python/src/py_types/core.rs +++ b/crates/python/src/py_types/core.rs @@ -4,10 +4,12 @@ use pyo3::prelude::*; use super::{ - Bound, CoreScopeType, FlowResult, LlmAttributes, LlmHandle, LlmRequest, PyAny, PyErr, PyRef, - PyResult, Python, ScopeAttributes, ScopeHandle, ScopeStackHandle, ToolAttributes, ToolHandle, - json_to_py, opt_json_to_py, py_to_json, + AnnotatedLLMRequest, Bound, CoreScopeType, FlowResult, LlmAttributes, LlmHandle, LlmRequest, + PyAnnotatedLLMRequest, PyAny, PyErr, PyRef, PyResult, Python, ScopeAttributes, ScopeHandle, + ScopeStackHandle, ToolAttributes, ToolHandle, json_to_py, opt_json_to_py, py_to_json, }; +use nemo_relay::api::event::{CategoryProfile, EventCategory, PendingMarkSpec}; +use nemo_relay::api::llm::LlmRequestInterceptOutcome; // --------------------------------------------------------------------------- // LlmStream (async iterator) @@ -597,3 +599,132 @@ impl PyLLMRequest { "LLMRequest(...)".to_string() } } + +/// A mark to emit immediately after the managed LLM start event. +#[pyclass(name = "PendingMarkSpec", from_py_object)] +#[derive(Clone)] +pub struct PyPendingMarkSpec { + pub inner: PendingMarkSpec, +} + +#[pymethods] +impl PyPendingMarkSpec { + #[new] + #[pyo3(signature = (name, category=None, category_profile=None, data=None, metadata=None))] + fn new( + name: String, + category: Option, + category_profile: Option<&Bound<'_, PyAny>>, + data: Option<&Bound<'_, PyAny>>, + metadata: Option<&Bound<'_, PyAny>>, + ) -> PyResult { + let category = category + .map(|value| serde_json::from_value::(serde_json::Value::String(value))) + .transpose() + .map_err(|error| pyo3::exceptions::PyValueError::new_err(error.to_string()))?; + let category_profile = category_profile + .map(py_to_json) + .transpose()? + .map(serde_json::from_value::) + .transpose() + .map_err(|error| pyo3::exceptions::PyValueError::new_err(error.to_string()))?; + Ok(Self { + inner: PendingMarkSpec { + name, + category, + category_profile, + data: data.map(py_to_json).transpose()?, + metadata: metadata.map(py_to_json).transpose()?, + }, + }) + } + + #[getter] + fn name(&self) -> &str { + &self.inner.name + } + + #[getter] + fn category(&self) -> Option { + self.inner + .category + .as_ref() + .and_then(|value| serde_json::to_value(value).ok()) + .and_then(|value| value.as_str().map(str::to_owned)) + } + + #[getter] + fn category_profile(&self, py: Python<'_>) -> PyResult> { + opt_json_to_py( + py, + &self + .inner + .category_profile + .as_ref() + .map(serde_json::to_value) + .transpose() + .map_err(|error| pyo3::exceptions::PyRuntimeError::new_err(error.to_string()))?, + ) + } + + #[getter] + fn data(&self, py: Python<'_>) -> PyResult> { + opt_json_to_py(py, &self.inner.data) + } + + #[getter] + fn metadata(&self, py: Python<'_>) -> PyResult> { + opt_json_to_py(py, &self.inner.metadata) + } +} + +/// Canonical result returned by Python LLM request intercepts. +#[pyclass(name = "LLMRequestInterceptOutcome", from_py_object)] +#[derive(Clone)] +pub struct PyLLMRequestInterceptOutcome { + pub inner: LlmRequestInterceptOutcome, +} + +#[pymethods] +impl PyLLMRequestInterceptOutcome { + #[new] + #[pyo3(signature = (request, annotated_request=None, pending_marks=Vec::new()))] + fn new( + request: PyLLMRequest, + annotated_request: Option, + pending_marks: Vec, + ) -> Self { + Self { + inner: LlmRequestInterceptOutcome { + request: request.inner, + annotated_request: annotated_request.map(|value| value.inner), + pending_marks: pending_marks.into_iter().map(|value| value.inner).collect(), + }, + } + } + + #[getter] + fn request(&self) -> PyLLMRequest { + PyLLMRequest { + inner: self.inner.request.clone(), + } + } + + #[getter] + fn annotated_request(&self) -> Option { + self.inner + .annotated_request + .clone() + .map(|inner: AnnotatedLLMRequest| PyAnnotatedLLMRequest { inner }) + } + + #[getter] + fn pending_marks(&self) -> Vec { + self.inner + .pending_marks + .iter() + .cloned() + .map(|inner| PyPendingMarkSpec { inner }) + .collect() + } +} diff --git a/crates/python/src/py_types/mod.rs b/crates/python/src/py_types/mod.rs index f4e018536..7a2b28265 100644 --- a/crates/python/src/py_types/mod.rs +++ b/crates/python/src/py_types/mod.rs @@ -132,6 +132,8 @@ pub fn register(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/crates/python/tests/coverage/coverage_tests.rs b/crates/python/tests/coverage/coverage_tests.rs index 6c3205e00..4fa673131 100644 --- a/crates/python/tests/coverage/coverage_tests.rs +++ b/crates/python/tests/coverage/coverage_tests.rs @@ -30,7 +30,14 @@ fn load_module<'py>(py: Python<'py>, code: &str) -> Bound<'py, PyModule> { let code = CString::new(code).unwrap(); let file_name = CString::new("coverage_tests.py").unwrap(); let module_name = CString::new("coverage_tests").unwrap(); - PyModule::from_code(py, &code, &file_name, &module_name).unwrap() + let module = PyModule::from_code(py, &code, &file_name, &module_name).unwrap(); + module + .setattr( + "Outcome", + py.get_type::(), + ) + .unwrap(); + module } fn make_request() -> LlmRequest { @@ -390,7 +397,7 @@ def llm_conditional(request): return None def llm_request_intercept(name, request, annotated): - return (request, annotated) + return Outcome(request, annotated) async def llm_execution_intercept(name, request, next): return await next(request) diff --git a/crates/python/tests/coverage/py_adaptive_coverage_tests.rs b/crates/python/tests/coverage/py_adaptive_coverage_tests.rs index 6c040bb67..aa5f8ef95 100644 --- a/crates/python/tests/coverage/py_adaptive_coverage_tests.rs +++ b/crates/python/tests/coverage/py_adaptive_coverage_tests.rs @@ -135,7 +135,7 @@ def bind_scope_and_translate(api_module, runtime, request): handle = api_module.push_scope("adaptive-runtime-cov", api_module.ScopeType.Agent) try: runtime.bind_scope(handle) - return api_module.llm_request_intercepts("anthropic", request).content + return api_module.llm_request_intercepts("anthropic", request).request.content finally: api_module.pop_scope(handle) "# diff --git a/crates/python/tests/coverage/py_api_coverage_tests.rs b/crates/python/tests/coverage/py_api_coverage_tests.rs index 03542b8ea..9ce9107ab 100644 --- a/crates/python/tests/coverage/py_api_coverage_tests.rs +++ b/crates/python/tests/coverage/py_api_coverage_tests.rs @@ -193,9 +193,13 @@ def llm_conditional(request): def llm_request_intercept(name, request, annotated): headers = dict(request.headers) headers["x-intercepted"] = "1" - content = dict(request.content) - content["messages"] = [{"role": "user", "content": "hello from intercept"}] - return (LLMRequest(headers, content), annotated) + if annotated is None: + content = dict(request.content) + content["messages"] = [{"role": "user", "content": "hello from intercept"}] + else: + content = request.content + annotated.messages = [{"role": "user", "content": "hello from intercept"}] + return LLMRequestInterceptOutcome(LLMRequest(headers, content), annotated) async def llm_exec(request): return { @@ -324,6 +328,12 @@ async def run_stream(api, request, func, collector, finalizer, handle, attribute helpers .setattr("LLMRequest", types_module.getattr("LLMRequest").unwrap()) .unwrap(); + helpers + .setattr( + "LLMRequestInterceptOutcome", + types_module.getattr("LLMRequestInterceptOutcome").unwrap(), + ) + .unwrap(); helpers .setattr( "AnnotatedLLMRequest", @@ -457,7 +467,11 @@ async def run_stream(api, request, func, collector, finalizer, handle, attribute }; let intercepted_request = llm_request_intercepts("demo-llm", llm_request.clone()).unwrap(); assert_eq!( - intercepted_request.inner.headers.get("x-intercepted"), + intercepted_request + .inner + .request + .headers + .get("x-intercepted"), Some(&json!("1")) ); llm_conditional_execution(llm_request.clone()).unwrap(); diff --git a/crates/python/tests/coverage/py_callable_coverage_tests.rs b/crates/python/tests/coverage/py_callable_coverage_tests.rs index 9c9e0f7a5..cac6bbd6e 100644 --- a/crates/python/tests/coverage/py_callable_coverage_tests.rs +++ b/crates/python/tests/coverage/py_callable_coverage_tests.rs @@ -60,7 +60,7 @@ def sync_llm_intercept(name, request, next): return {"name": name, "model": request.content["model"], "mode": "sync"} def request_echo(name, request, annotated): - return (request, annotated) + return Outcome(request, annotated) def request_bad_annotated(name, request, annotated): return (request, {"bad": True}) @@ -93,6 +93,12 @@ class RaisingResponseCodec: raise RuntimeError("decode boom") "#, ); + module + .setattr( + "Outcome", + py.get_type::(), + ) + .unwrap(); let tool_exec_py: Py = module.getattr("sync_tool_exec").unwrap().unbind(); let tool_intercept_py: Py = module.getattr("sync_tool_intercept").unwrap().unbind(); @@ -142,9 +148,11 @@ class RaisingResponseCodec: "model": "codec-model" })) .unwrap(); - let (_request, echoed_ann) = - request_intercept("llm", make_request(), Some(annotated.clone())).unwrap(); - assert_eq!(echoed_ann.unwrap().last_user_message(), Some("annotated")); + let outcome = request_intercept("llm", make_request(), Some(annotated.clone())).unwrap(); + assert_eq!( + outcome.annotated_request.unwrap().last_user_message(), + Some("annotated") + ); let bad_request_intercept = wrap_py_llm_request_intercept_fn( module.getattr("request_bad_annotated").unwrap().unbind(), @@ -153,7 +161,7 @@ class RaisingResponseCodec: bad_request_intercept("llm", make_request(), Some(annotated)) .unwrap_err() .to_string() - .contains("result[1] is not AnnotatedLLMRequest") + .contains("must return LLMRequestInterceptOutcome") ); let short_request_intercept = wrap_py_llm_request_intercept_fn( @@ -163,7 +171,7 @@ class RaisingResponseCodec: short_request_intercept("llm", make_request(), None) .unwrap_err() .to_string() - .contains("result[1] extraction failed") + .contains("must return LLMRequestInterceptOutcome") ); let mut collector = wrap_py_collector_fn(module.getattr("collector_ok").unwrap().unbind()); diff --git a/crates/python/tests/coverage/py_plugin_coverage_tests.rs b/crates/python/tests/coverage/py_plugin_coverage_tests.rs index dbb8a21b9..0eba263c2 100644 --- a/crates/python/tests/coverage/py_plugin_coverage_tests.rs +++ b/crates/python/tests/coverage/py_plugin_coverage_tests.rs @@ -16,7 +16,14 @@ fn load_module<'py>(py: Python<'py>, code: &str) -> Bound<'py, PyModule> { let code = CString::new(code).unwrap(); let file_name = CString::new("py_plugin_coverage_tests.py").unwrap(); let module_name = CString::new("py_plugin_coverage_tests").unwrap(); - PyModule::from_code(py, &code, &file_name, &module_name).unwrap() + let module = PyModule::from_code(py, &code, &file_name, &module_name).unwrap(); + module + .setattr( + "Outcome", + py.get_type::(), + ) + .unwrap(); + module } fn with_event_loop(py: Python<'_>, f: impl FnOnce(Bound<'_, PyAny>) -> T) -> T { @@ -123,7 +130,7 @@ def llm_conditional(request): return None def llm_request_intercept(name, request, annotated): - return (request, annotated) + return Outcome(request, annotated) async def llm_execution_intercept(name, request, next): return await next(request) @@ -579,7 +586,7 @@ def llm_conditional(request): return None def llm_request_intercept(name, request, annotated): - return (request, annotated) + return Outcome(request, annotated) async def llm_execution_intercept(name, request, next): return await next(request) @@ -863,7 +870,7 @@ def llm_conditional(request): return None def llm_request_intercept(name, request, annotated): - return (request, annotated) + return Outcome(request, annotated) async def llm_execution_intercept(name, request, next): return await next(request) diff --git a/crates/types/src/api/event.rs b/crates/types/src/api/event.rs index af17546ae..d08bf8180 100644 --- a/crates/types/src/api/event.rs +++ b/crates/types/src/api/event.rs @@ -364,6 +364,29 @@ pub struct MarkEvent { pub category_profile: Option, } +/// Mark requested by middleware before its owning runtime scope exists. +/// +/// The runtime assigns the parent UUID, event UUID, and timestamp when it +/// materializes the mark at the appropriate lifecycle boundary. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, TypedBuilder)] +#[builder(field_defaults(setter(into, strip_option(ignore_invalid, fallback_suffix = "_opt"))))] +pub struct PendingMarkSpec { + /// Human-readable mark name. + pub name: String, + /// Optional semantic category for the mark. + #[builder(default)] + pub category: Option, + /// Optional category-specific typed fields. + #[builder(default)] + pub category_profile: Option, + /// Optional application payload attached to the mark. + #[builder(default)] + pub data: Option, + /// Optional metadata attached to the mark. + #[builder(default)] + pub metadata: Option, +} + impl MarkEvent { /// Construct a mark event from a base envelope and optional category data. /// diff --git a/crates/types/src/api/llm.rs b/crates/types/src/api/llm.rs index caa917a65..121754818 100644 --- a/crates/types/src/api/llm.rs +++ b/crates/types/src/api/llm.rs @@ -7,6 +7,8 @@ use bitflags::bitflags; use serde::{Deserialize, Serialize}; use crate::Json; +use crate::api::event::PendingMarkSpec; +use crate::codec::request::AnnotatedLlmRequest; bitflags! { /// Bitflags that modify LLM-call behavior and observability. @@ -20,10 +22,66 @@ bitflags! { } /// JSON-shaped LLM request payload passed through the runtime. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct LlmRequest { /// Provider-specific request headers. pub headers: serde_json::Map, /// Provider-specific request body. pub content: Json, } + +/// Result of an LLM request intercept that can schedule lifecycle marks. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct LlmRequestInterceptOutcome { + /// Rewritten provider request when no request codec is active. + /// + /// With a request codec, callbacks may rewrite `headers`, but `content` + /// is read-only and provider-body changes must be made through + /// [`Self::annotated_request`]. + pub request: LlmRequest, + /// Optional normalized request annotation to carry forward. + /// + /// This is required and authoritative for provider content when a request + /// codec is active. It remains optional when no request codec is active. + #[serde(default)] + pub annotated_request: Option, + /// Ordered marks to emit after Relay creates and starts the LLM scope. + #[serde(default)] + pub pending_marks: Vec, +} + +impl LlmRequestInterceptOutcome { + /// Create an outcome without pending marks. + pub fn new(request: LlmRequest, annotated_request: Option) -> Self { + Self { + request, + annotated_request, + pending_marks: Vec::new(), + } + } + + /// Append one pending mark while preserving interceptor order. + #[must_use] + pub fn with_pending_mark(mut self, mark: PendingMarkSpec) -> Self { + self.pending_marks.push(mark); + self + } +} + +impl From for LlmRequestInterceptOutcome { + fn from(request: LlmRequest) -> Self { + Self::new(request, None) + } +} + +impl From<(LlmRequest, AnnotatedLlmRequest)> for LlmRequestInterceptOutcome { + fn from((request, annotated_request): (LlmRequest, AnnotatedLlmRequest)) -> Self { + Self::new(request, Some(annotated_request)) + } +} + +impl From<(LlmRequest, Option)> for LlmRequestInterceptOutcome { + fn from((request, annotated_request): (LlmRequest, Option)) -> Self { + Self::new(request, annotated_request) + } +} diff --git a/crates/types/tests/serialization_tests.rs b/crates/types/tests/serialization_tests.rs index 9e899ff7d..39bb8fbc7 100644 --- a/crates/types/tests/serialization_tests.rs +++ b/crates/types/tests/serialization_tests.rs @@ -6,10 +6,10 @@ use std::sync::Arc; use nemo_relay_types::api::event::{ - BaseEvent, CategoryProfile, Event, EventCategory, ScopeCategory, ScopeEvent, + BaseEvent, CategoryProfile, Event, EventCategory, PendingMarkSpec, ScopeCategory, ScopeEvent, llm_attributes_to_strings, }; -use nemo_relay_types::api::llm::{LlmAttributes, LlmRequest}; +use nemo_relay_types::api::llm::{LlmAttributes, LlmRequest, LlmRequestInterceptOutcome}; use nemo_relay_types::codec::request::{AnnotatedLlmRequest, Message, MessageContent}; use nemo_relay_types::codec::response::AnnotatedLlmResponse; use serde_json::{Map, json}; @@ -78,3 +78,96 @@ fn event_round_trips_with_annotated_llm_profiles() { Some("resp_1") ); } + +#[test] +fn llm_request_intercept_outcome_round_trips_pending_marks() { + let outcome = LlmRequestInterceptOutcome::new( + LlmRequest { + headers: Map::new(), + content: json!({ "prompt": "hello" }), + }, + None, + ) + .with_pending_mark( + PendingMarkSpec::builder() + .name("request.optimized") + .category(EventCategory::custom()) + .category_profile( + CategoryProfile::builder() + .subtype("optimizer.saved_tokens") + .build(), + ) + .data(json!({ "saved_tokens": 12 })) + .metadata(json!({ "source": "test" })) + .build(), + ); + + let encoded = serde_json::to_value(&outcome).expect("outcome should serialize"); + assert_eq!(encoded["pending_marks"][0]["name"], "request.optimized"); + assert_eq!(encoded["pending_marks"][0]["category"], "custom"); + assert!(encoded["annotated_request"].is_null()); + + let mut encoded_without_pending_marks = encoded.clone(); + encoded_without_pending_marks + .as_object_mut() + .unwrap() + .remove("pending_marks"); + let decoded_without_pending_marks: LlmRequestInterceptOutcome = + serde_json::from_value(encoded_without_pending_marks) + .expect("outcome without pending marks should deserialize"); + assert!(decoded_without_pending_marks.pending_marks.is_empty()); + + let decoded_defaults: LlmRequestInterceptOutcome = serde_json::from_value(json!({ + "request": {"headers": {}, "content": {"prompt": "hello"}}, + "future_field": true + })) + .expect("omitted optional fields and unknown fields should be accepted"); + assert!(decoded_defaults.annotated_request.is_none()); + assert!(decoded_defaults.pending_marks.is_empty()); + + assert!( + serde_json::from_value::(json!({ + "annotated_request": null, + "pending_marks": [] + })) + .is_err(), + "request is required" + ); + + let decoded: LlmRequestInterceptOutcome = + serde_json::from_value(encoded).expect("outcome should deserialize"); + assert_eq!(decoded, outcome); +} + +#[test] +fn llm_request_intercept_outcome_converts_from_request_inputs() { + let request = LlmRequest { + headers: Map::new(), + content: json!({ "prompt": "hello" }), + }; + let annotated_request: AnnotatedLlmRequest = serde_json::from_value(json!({ + "messages": [], + "model": "model" + })) + .expect("annotated request should deserialize"); + + let request_only: LlmRequestInterceptOutcome = request.clone().into(); + assert_eq!( + request_only, + LlmRequestInterceptOutcome::new(request.clone(), None) + ); + + let required_annotation: LlmRequestInterceptOutcome = + (request.clone(), annotated_request.clone()).into(); + assert_eq!( + required_annotation, + LlmRequestInterceptOutcome::new(request.clone(), Some(annotated_request.clone())) + ); + + let optional_annotation: LlmRequestInterceptOutcome = + (request.clone(), Some(annotated_request.clone())).into(); + assert_eq!( + optional_annotation, + LlmRequestInterceptOutcome::new(request, Some(annotated_request)) + ); +} diff --git a/crates/worker-proto/proto/nemo/relay/worker/v1/plugin_worker.proto b/crates/worker-proto/proto/nemo/relay/worker/v1/plugin_worker.proto index 134154584..9a7e7f704 100644 --- a/crates/worker-proto/proto/nemo/relay/worker/v1/plugin_worker.proto +++ b/crates/worker-proto/proto/nemo/relay/worker/v1/plugin_worker.proto @@ -181,9 +181,7 @@ message GuardrailResult { } message LlmRequestInterceptResult { - JsonEnvelope request = 1; - JsonEnvelope annotated_request = 2; - bool has_annotated_request = 3; + JsonEnvelope outcome = 1; } message StreamChunk { diff --git a/crates/worker/src/lib.rs b/crates/worker/src/lib.rs index 1f4b559c1..9115e6d32 100644 --- a/crates/worker/src/lib.rs +++ b/crates/worker/src/lib.rs @@ -34,8 +34,8 @@ use futures_util::{Stream, StreamExt}; #[cfg(unix)] use hyper_util::rt::TokioIo; pub use nemo_relay_types::Json; -pub use nemo_relay_types::api::event::Event; -pub use nemo_relay_types::api::llm::LlmRequest; +pub use nemo_relay_types::api::event::{Event, PendingMarkSpec}; +pub use nemo_relay_types::api::llm::{LlmRequest, LlmRequestInterceptOutcome}; pub use nemo_relay_types::api::scope::ScopeType; use nemo_relay_types::codec::request::AnnotatedLlmRequest; pub use nemo_relay_types::plugin::{ConfigDiagnostic, DiagnosticLevel}; @@ -125,11 +125,7 @@ type LlmSanitizeRequestFn = Arc LlmRequest + Send + Sync>; type LlmSanitizeResponseFn = Arc Json + Send + Sync>; type LlmConditionalFn = Arc Result> + Send + Sync>; type LlmRequestFn = Arc< - dyn Fn( - &str, - LlmRequest, - Option, - ) -> Result<(LlmRequest, Option)> + dyn Fn(&str, LlmRequest, Option) -> Result + Send + Sync, >; @@ -364,11 +360,7 @@ impl PluginContext { break_chain: bool, callback: F, ) where - F: Fn( - &str, - LlmRequest, - Option, - ) -> Result<(LlmRequest, Option)> + F: Fn(&str, LlmRequest, Option) -> Result + Send + Sync + 'static, @@ -1407,10 +1399,10 @@ impl WorkerService { .map(|value| decode_json_envelope::(&value)) .transpose()?; let handler = self.llm_request(&request.registration_name)?; - let (request, annotated) = with_thread_scope(&scope, || { + let outcome = with_thread_scope(&scope, || { handler(&payload.model_name, request_value, annotated) })?; - Ok(llm_request_response(request, annotated)?) + Ok(llm_request_response(outcome)?) } RegistrationSurface::LlmExecutionIntercept => { let payload = llm_payload(request.payload)?; @@ -1664,20 +1656,15 @@ fn guardrail_response(reason: Option) -> InvokeResponse { } } -fn llm_request_response( - request: LlmRequest, - annotated: Option, -) -> Result { +fn llm_request_response(outcome: LlmRequestInterceptOutcome) -> Result { Ok(InvokeResponse { result: Some( nemo_relay_worker_proto::v1::invoke_response::Result::LlmRequest( LlmRequestInterceptResult { - request: Some(json_envelope("nemo.relay.LlmRequest@1", &request)?), - annotated_request: annotated - .as_ref() - .map(|value| json_envelope("nemo.relay.AnnotatedLlmRequest@1", value)) - .transpose()?, - has_annotated_request: annotated.is_some(), + outcome: Some(json_envelope( + "nemo.relay.LlmRequestInterceptOutcome@1", + &outcome, + )?), }, ), ), diff --git a/crates/worker/tests/worker_sdk_tests.rs b/crates/worker/tests/worker_sdk_tests.rs index 8a64b130c..f8baa1fa9 100644 --- a/crates/worker/tests/worker_sdk_tests.rs +++ b/crates/worker/tests/worker_sdk_tests.rs @@ -1590,7 +1590,10 @@ impl WorkerPlugin for SurfacePlugin { .and_then(|blocked| blocked.then(|| "blocked-llm".into()))) }); ctx.register_llm_request_intercept("llm-request", 1, false, |_, request, annotated| { - Ok((set_llm_phase(request, "llm_request"), annotated)) + Ok(nemo_relay_worker::LlmRequestInterceptOutcome::new( + set_llm_phase(request, "llm_request"), + annotated, + )) }); ctx.register_llm_execution_intercept( @@ -2280,7 +2283,11 @@ async fn invoke_llm_request( .into_inner(); match response.result.expect("invoke result") { nemo_relay_worker_proto::v1::invoke_response::Result::LlmRequest(result) => { - decode_json_envelope(&result.request.expect("llm request")).expect("decode LLM request") + decode_json_envelope::( + &result.outcome.expect("llm request outcome"), + ) + .expect("decode LLM request outcome") + .request } other => panic!("unexpected invoke result: {other:?}"), } diff --git a/examples/rust-native-plugin/README.md b/examples/rust-native-plugin/README.md index 82fa142a6..ea38314f2 100644 --- a/examples/rust-native-plugin/README.md +++ b/examples/rust-native-plugin/README.md @@ -84,9 +84,16 @@ The example registers the following runtime behavior: - Request and execution intercepts for tools that mutate JSON payloads and call continuations. - LLM sanitize request/response guardrails. -- LLM request, execution, and stream execution intercepts. +- An LLM request intercept that rewrites the request and schedules a mark. Relay + emits that mark after the LLM start event with the LLM scope as its parent. +- LLM execution and stream execution intercepts. - Runtime mark and scope events. - A plugin-owned isolated scope stack for non-correlated visibility. Native plugins are not sandboxed. They run in the Relay process and must not unwind across ABI callbacks. + +Request intercepts do not own an LLM lifecycle because they run before Relay +creates the LLM scope. `register_llm_request_intercept` returns one +`LlmRequestInterceptOutcome`, whose `pending_marks` Relay emits in interceptor +order after the LLM start event and before provider execution. diff --git a/examples/rust-native-plugin/src/lib.rs b/examples/rust-native-plugin/src/lib.rs index f18296f54..4d9130e8c 100644 --- a/examples/rust-native-plugin/src/lib.rs +++ b/examples/rust-native-plugin/src/lib.rs @@ -2,8 +2,9 @@ // SPDX-License-Identifier: Apache-2.0 use nemo_relay_plugin::{ - ConfigDiagnostic, DiagnosticLevel, Event, Json, LlmJsonStream, LlmRequest, NativePlugin, - PluginContext, PluginRuntime, ScopeCategory, ScopeType, + CategoryProfile, ConfigDiagnostic, DiagnosticLevel, Event, EventCategory, Json, LlmJsonStream, + LlmRequest, LlmRequestInterceptOutcome, NativePlugin, PendingMarkSpec, PluginContext, + PluginRuntime, ScopeCategory, ScopeType, }; use serde_json::{Map, json}; @@ -235,9 +236,20 @@ impl NativePlugin for ExampleNativePlugin { ctx.register_llm_request_intercept("example_llm_request", 20, false, { let tag = config.tag.clone(); move |_name, request, annotated| { - Ok(( + Ok(LlmRequestInterceptOutcome::new( tag_llm_request(request, "native_llm_request_intercept", &tag), annotated, + ) + .with_pending_mark( + PendingMarkSpec::builder() + .name("example.native.llm_request_intercept") + .category(EventCategory::custom()) + .category_profile(CategoryProfile { + subtype: Some("example.native.request_rewrite".into()), + ..CategoryProfile::default() + }) + .data(json!({ "tag": &tag })) + .build(), )) } })?; diff --git a/go/nemo_relay/adaptive_plugin_test.go b/go/nemo_relay/adaptive_plugin_test.go index 44d580b0c..7b53a7131 100644 --- a/go/nemo_relay/adaptive_plugin_test.go +++ b/go/nemo_relay/adaptive_plugin_test.go @@ -104,12 +104,13 @@ func registerLifecycleInterceptors(ctx *PluginContext, pluginKind string) error "llm_request", 7, false, - func(name string, headers, content, annotated json.RawMessage) (json.RawMessage, json.RawMessage, json.RawMessage, error) { - out, err := decorateJSONPayload(headers, "x-go-plugin", pluginKind) + func(name string, request LLMRequestDTO, annotated json.RawMessage) (LLMRequestInterceptOutcome, error) { + out, err := decorateJSONPayload(request.Headers, "x-go-plugin", pluginKind) if err != nil { - return nil, nil, nil, err + return LLMRequestInterceptOutcome{}, err } - return out, content, annotated, nil + request.Headers = out + return LLMRequestInterceptOutcome{Request: request, AnnotatedRequest: annotated}, nil }, ); err != nil { return err @@ -441,8 +442,8 @@ func TestPluginFuncsAndClosedContextBranches(t *testing.T) { return closed.RegisterLlmConditionalExecutionGuardrail("llm_conditional", 1, func(headers, content json.RawMessage) *string { return nil }) }}, {"llm request", func() error { - return closed.RegisterLlmRequestIntercept("llm_request", 1, false, func(name string, headers, content, annotated json.RawMessage) (json.RawMessage, json.RawMessage, json.RawMessage, error) { - return headers, content, annotated, nil + return closed.RegisterLlmRequestIntercept("llm_request", 1, false, func(name string, request LLMRequestDTO, annotated json.RawMessage) (LLMRequestInterceptOutcome, error) { + return LLMRequestInterceptOutcome{Request: request, AnnotatedRequest: annotated}, nil }) }}, {"tool request", func() error { diff --git a/go/nemo_relay/callbacks.go b/go/nemo_relay/callbacks.go index 68f3d7f71..856e6fa28 100644 --- a/go/nemo_relay/callbacks.go +++ b/go/nemo_relay/callbacks.go @@ -212,16 +212,37 @@ type CodecFunc struct { Encode func(annotatedJSON json.RawMessage, originalHeadersJSON, originalContentJSON json.RawMessage) (json.RawMessage, error) } -// LLMRequestInterceptFunc is a callback for LLM request intercepts with -// the unified annotated-aware signature. It receives the intercept name, -// request headers/content, and optionally the annotated request JSON (nil if -// no Codec resolved). Returns the (possibly modified) headers, content, and -// annotated JSON. +// LLMRequestDTO is the JSON-shaped request used by request intercept outcomes. +type LLMRequestDTO struct { + Headers json.RawMessage `json:"headers"` + Content json.RawMessage `json:"content"` +} + +// PendingMarkSpec describes a mark Relay emits after starting a managed LLM call. +type PendingMarkSpec struct { + Name string `json:"name"` + Category *string `json:"category,omitempty"` + CategoryProfile json.RawMessage `json:"category_profile,omitempty"` + Data json.RawMessage `json:"data,omitempty"` + Metadata json.RawMessage `json:"metadata,omitempty"` +} + +// LLMRequestInterceptOutcome is the canonical result of an LLM request intercept. +type LLMRequestInterceptOutcome struct { + Request LLMRequestDTO `json:"request"` + AnnotatedRequest json.RawMessage `json:"annotated_request"` + PendingMarks []PendingMarkSpec `json:"pending_marks"` +} + +// LLMRequestInterceptFunc is a callback for LLM request intercepts. When +// annotatedJSON is non-nil, request.Content is read-only, request.Headers may +// be changed, and the returned annotation is authoritative for provider body +// content. Without an annotation, the full request is writable. type LLMRequestInterceptFunc func( name string, - headers, content json.RawMessage, + request LLMRequestDTO, annotatedJSON json.RawMessage, -) (newHeaders, newContent, newAnnotatedJSON json.RawMessage, err error) +) (LLMRequestInterceptOutcome, error) func codecDecodePayload(codec *CodecFunc, headers, content json.RawMessage) (json.RawMessage, error) { return codec.Decode(headers, content) @@ -236,8 +257,8 @@ func llmRequestInterceptPayload( name string, headers, content json.RawMessage, annotatedJSON json.RawMessage, -) (json.RawMessage, json.RawMessage, json.RawMessage, error) { - return fn(name, headers, content, annotatedJSON) +) (LLMRequestInterceptOutcome, error) { + return fn(name, LLMRequestDTO{Headers: headers, Content: content}, annotatedJSON) } func pluginValidatePayload(plugin Plugin, pluginConfigJSON json.RawMessage) (json.RawMessage, error) { @@ -512,7 +533,7 @@ func goCodecEncodeTrampoline(userData unsafe.Pointer, annotatedJSON *C.char, ori //export goLlmRequestInterceptTrampoline func goLlmRequestInterceptTrampoline( userData unsafe.Pointer, name *C.char, request *C.FfiLLMRequest, - annotatedJSON *C.char, outRequest **C.FfiLLMRequest, outAnnotatedJSON **C.char, + annotatedJSON *C.char, outOutcomeJSON **C.char, ) C.int32_t { fn := lookupClosure(userData).(LLMRequestInterceptFunc) goName := C.GoString(name) @@ -526,20 +547,20 @@ func goLlmRequestInterceptTrampoline( if annotatedJSON != nil { goAnnotated = json.RawMessage(C.GoString(annotatedJSON)) } - newHeaders, newContent, newAnnotated, err := llmRequestInterceptPayload(fn, goName, goHeaders, goContent, goAnnotated) + outcome, err := llmRequestInterceptPayload(fn, goName, goHeaders, goContent, goAnnotated) if err != nil { setLastErrorMessage(err.Error()) return 5 // NemoRelayStatus::Internal } - // Create output FfiLLMRequest - cNewHeaders := C.CString(string(newHeaders)) - cNewContent := C.CString(string(newContent)) - defer C.free(unsafe.Pointer(cNewHeaders)) - defer C.free(unsafe.Pointer(cNewContent)) - *outRequest = C.nemo_relay_llm_request_new(cNewHeaders, cNewContent) - if newAnnotated != nil { - *outAnnotatedJSON = C.CString(string(newAnnotated)) + if outcome.PendingMarks == nil { + outcome.PendingMarks = []PendingMarkSpec{} + } + outcomeJSON, err := jsonMarshal(outcome) + if err != nil { + setLastErrorMessage(err.Error()) + return 5 } + *outOutcomeJSON = C.CString(string(outcomeJSON)) return 0 // NemoRelayStatus::Ok } diff --git a/go/nemo_relay/intercepts/intercepts.go b/go/nemo_relay/intercepts/intercepts.go index 5872b74a6..98030c400 100644 --- a/go/nemo_relay/intercepts/intercepts.go +++ b/go/nemo_relay/intercepts/intercepts.go @@ -206,6 +206,6 @@ func ToolRequestIntercepts(name string, args json.RawMessage) (json.RawMessage, // LlmRequestIntercepts runs the registered LLM request intercept chain and // returns the transformed request. This is a shorthand for // [nemo_relay.LlmRequestIntercepts]. -func LlmRequestIntercepts(name string, request json.RawMessage) (json.RawMessage, error) { +func LlmRequestIntercepts(name string, request json.RawMessage) (nemo_relay.LLMRequestInterceptOutcome, error) { return nemo_relay.LlmRequestIntercepts(name, request) } diff --git a/go/nemo_relay/intercepts/intercepts_test.go b/go/nemo_relay/intercepts/intercepts_test.go index 1bb6eb8da..ed16007be 100644 --- a/go/nemo_relay/intercepts/intercepts_test.go +++ b/go/nemo_relay/intercepts/intercepts_test.go @@ -87,12 +87,12 @@ func runGlobalLLMInterceptShorthandChecks(t *testing.T) { t.Helper() if err := intercepts.RegisterLlmRequest("intercepts_llm_req", 1, false, - func(name string, headers, content, annotated json.RawMessage) (json.RawMessage, json.RawMessage, json.RawMessage, error) { + func(name string, request nemo_relay.LLMRequestDTO, annotated json.RawMessage) (nemo_relay.LLMRequestInterceptOutcome, error) { var payload map[string]interface{} - _ = json.Unmarshal(content, &payload) + _ = json.Unmarshal(request.Content, &payload) payload["intercepted"] = true - out, _ := json.Marshal(payload) - return headers, out, annotated, nil + request.Content, _ = json.Marshal(payload) + return nemo_relay.LLMRequestInterceptOutcome{Request: request, AnnotatedRequest: annotated}, nil }, ); err != nil { t.Fatalf("RegisterLlmRequest failed: %v", err) @@ -109,7 +109,7 @@ func runGlobalLLMInterceptShorthandChecks(t *testing.T) { var llmReq struct { Content map[string]interface{} `json:"content"` } - if err := json.Unmarshal(transformedRequest, &llmReq); err != nil { + if err := json.Unmarshal(transformedRequest.Request.Content, &llmReq.Content); err != nil { t.Fatalf("unmarshal llm request: %v", err) } if llmReq.Content["intercepted"] != true { @@ -198,8 +198,8 @@ func runScopeLocalLLMInterceptShorthandChecks(t *testing.T, scopeUUID string) { t.Helper() if err := intercepts.ScopeRegisterLlmRequest(scopeUUID, "intercepts_scope_llm_req", 1, false, - func(name string, headers, content, annotated json.RawMessage) (json.RawMessage, json.RawMessage, json.RawMessage, error) { - return headers, content, annotated, nil + func(name string, request nemo_relay.LLMRequestDTO, annotated json.RawMessage) (nemo_relay.LLMRequestInterceptOutcome, error) { + return nemo_relay.LLMRequestInterceptOutcome{Request: request, AnnotatedRequest: annotated}, nil }, ); err != nil { t.Fatalf("ScopeRegisterLlmRequest failed: %v", err) diff --git a/go/nemo_relay/llm/llm.go b/go/nemo_relay/llm/llm.go index 662fa6b3e..742e16e99 100644 --- a/go/nemo_relay/llm/llm.go +++ b/go/nemo_relay/llm/llm.go @@ -63,7 +63,7 @@ func StreamExecute(name string, native interface{}, fn nemo_relay.LLMExecutionFu // RequestIntercepts runs the registered LLM request intercept chain on the // given request and returns the transformed request. This is a shorthand for // [nemo_relay.LlmRequestIntercepts]. -func RequestIntercepts(name string, request json.RawMessage) (json.RawMessage, error) { +func RequestIntercepts(name string, request json.RawMessage) (nemo_relay.LLMRequestInterceptOutcome, error) { return nemo_relay.LlmRequestIntercepts(name, request) } diff --git a/go/nemo_relay/llm/llm_shorthand_test.go b/go/nemo_relay/llm/llm_shorthand_test.go index 2be60eadc..40e3604bc 100644 --- a/go/nemo_relay/llm/llm_shorthand_test.go +++ b/go/nemo_relay/llm/llm_shorthand_test.go @@ -45,12 +45,12 @@ func assertLLMRequestInterceptShorthand(t *testing.T) { t.Helper() if err := nemo_relay.RegisterLlmRequestIntercept("llm_req_int", 1, false, - func(name string, headers, content, annotated json.RawMessage) (json.RawMessage, json.RawMessage, json.RawMessage, error) { + func(name string, request nemo_relay.LLMRequestDTO, annotated json.RawMessage) (nemo_relay.LLMRequestInterceptOutcome, error) { var payload map[string]interface{} - _ = json.Unmarshal(content, &payload) + _ = json.Unmarshal(request.Content, &payload) payload["intercepted"] = true - out, _ := json.Marshal(payload) - return headers, out, annotated, nil + request.Content, _ = json.Marshal(payload) + return nemo_relay.LLMRequestInterceptOutcome{Request: request, AnnotatedRequest: annotated}, nil }, ); err != nil { t.Fatalf("RegisterLlmRequestIntercept failed: %v", err) @@ -67,7 +67,7 @@ func assertLLMRequestInterceptShorthand(t *testing.T) { var intercepted struct { Content map[string]interface{} `json:"content"` } - if err := json.Unmarshal(request, &intercepted); err != nil { + if err := json.Unmarshal(request.Request.Content, &intercepted.Content); err != nil { t.Fatalf("unmarshal request: %v", err) } if intercepted.Content["intercepted"] != true { diff --git a/go/nemo_relay/llm_test.go b/go/nemo_relay/llm_test.go index 7b82ef77c..0ca101ce4 100644 --- a/go/nemo_relay/llm_test.go +++ b/go/nemo_relay/llm_test.go @@ -440,8 +440,8 @@ func TestLlmConditionalBlocksExecution(t *testing.T) { func TestLlmRequestInterceptRegisterDeregister(t *testing.T) { err := RegisterLlmRequestIntercept("go_llm_req", 1, false, - func(name string, headers, content, annotated json.RawMessage) (json.RawMessage, json.RawMessage, json.RawMessage, error) { - return headers, content, annotated, nil + func(name string, request LLMRequestDTO, annotated json.RawMessage) (LLMRequestInterceptOutcome, error) { + return LLMRequestInterceptOutcome{Request: request, AnnotatedRequest: annotated}, nil }, ) if err != nil { @@ -521,14 +521,15 @@ func TestLlmStreamExecutionInterceptCanCallNext(t *testing.T) { func TestLlmRequestInterceptModifies(t *testing.T) { RegisterLlmRequestIntercept("go_llm_req_mod", 1, false, - func(name string, headers, content, annotated json.RawMessage) (json.RawMessage, json.RawMessage, json.RawMessage, error) { + func(name string, request LLMRequestDTO, annotated json.RawMessage) (LLMRequestInterceptOutcome, error) { var m map[string]interface{} - json.Unmarshal(content, &m) + json.Unmarshal(request.Content, &m) m["intercepted"] = true - out, _ := json.Marshal(m) - return headers, out, annotated, nil + request.Content, _ = json.Marshal(m) + return LLMRequestInterceptOutcome{Request: request, AnnotatedRequest: annotated}, nil }, ) + t.Cleanup(func() { _ = DeregisterLlmRequestIntercept("go_llm_req_mod") }) request := makeRequest() result, err := LlmCallExecute("int_llm", request, @@ -550,8 +551,6 @@ func TestLlmRequestInterceptModifies(t *testing.T) { if output["saw_intercepted"] != true { t.Fatalf("expected saw_intercepted=true, got %v", output) } - - DeregisterLlmRequestIntercept("go_llm_req_mod") } func TestLlmExecutionInterceptReplaces(t *testing.T) { diff --git a/go/nemo_relay/nemo_relay.go b/go/nemo_relay/nemo_relay.go index 9799a9a76..87ab55291 100644 --- a/go/nemo_relay/nemo_relay.go +++ b/go/nemo_relay/nemo_relay.go @@ -148,7 +148,7 @@ extern int32_t nemo_relay_register_llm_conditional_execution_guardrail(const cha extern int32_t nemo_relay_deregister_llm_conditional_execution_guardrail(const char* name); // LLM intercepts -typedef int32_t (*NemoRelayLlmRequestInterceptCb)(void* user_data, const char* name, const FfiLLMRequest* request, const char* annotated_json, FfiLLMRequest** out_request, char** out_annotated_json); +typedef int32_t (*NemoRelayLlmRequestInterceptCb)(void* user_data, const char* name, const FfiLLMRequest* request, const char* annotated_json, char** out_outcome_json); extern int32_t nemo_relay_register_llm_request_intercept(const char* name, int32_t priority, _Bool break_chain, NemoRelayLlmRequestInterceptCb cb, void* user_data, NemoRelayFreeFn free_fn); extern int32_t nemo_relay_deregister_llm_request_intercept(const char* name); typedef char* (*NemoRelayLlmExecNextFn)(const char* native_json, void* next_ctx); @@ -269,7 +269,7 @@ extern char* goLlmExecInterceptTrampoline(void*, const char*, NemoRelayLlmExecNe extern char* goCodecDecodeTrampoline(void*, const FfiLLMRequest*); extern char* goCodecEncodeTrampoline(void*, const char*, const FfiLLMRequest*); extern int32_t goLlmRequestInterceptTrampoline( - void*, const char*, const FfiLLMRequest*, const char*, FfiLLMRequest**, char**); + void*, const char*, const FfiLLMRequest*, const char*, char**); */ import "C" @@ -2431,7 +2431,7 @@ func ToolConditionalExecution(name string, args json.RawMessage) error { // LlmRequestIntercepts runs the registered LLM request intercept chain on the // given request (serialized as JSON) and returns the transformed request JSON. -func LlmRequestIntercepts(name string, request json.RawMessage) (json.RawMessage, error) { +func LlmRequestIntercepts(name string, request json.RawMessage) (LLMRequestInterceptOutcome, error) { cName := C.CString(name) defer C.free(unsafe.Pointer(cName)) cRequest := C.CString(string(request)) @@ -2440,10 +2440,14 @@ func LlmRequestIntercepts(name string, request json.RawMessage) (json.RawMessage var out *C.char status := C.nemo_relay_llm_request_intercepts(cName, cRequest, &out) if err := checkStatus(status); err != nil { - return nil, err + return LLMRequestInterceptOutcome{}, err } defer C.nemo_relay_string_free(out) - return json.RawMessage(C.GoString(out)), nil + var outcome LLMRequestInterceptOutcome + if err := jsonUnmarshal([]byte(C.GoString(out)), &outcome); err != nil { + return LLMRequestInterceptOutcome{}, err + } + return outcome, nil } // LlmConditionalExecution runs the registered LLM conditional execution diff --git a/go/nemo_relay/plugin.go b/go/nemo_relay/plugin.go index c4a8affcf..af39171b6 100644 --- a/go/nemo_relay/plugin.go +++ b/go/nemo_relay/plugin.go @@ -18,7 +18,7 @@ typedef char* (*NemoRelayToolConditionalFn)(void* user_data, const char* name, c typedef void* (*NemoRelayLlmRequestCb)(void* user_data, const void* request); typedef char* (*NemoRelayLlmResponseFn)(void* user_data, const char* response_json); typedef char* (*NemoRelayLlmConditionalCb)(void* user_data, const void* request); -typedef int32_t (*NemoRelayLlmRequestInterceptCb)(void* user_data, const char* name, const void* request, const char* annotated_json, void** out_request, char** out_annotated_json); +typedef int32_t (*NemoRelayLlmRequestInterceptCb)(void* user_data, const char* name, const void* request, const char* annotated_json, char** out_outcome_json); typedef char* (*NemoRelayLlmExecNextFn)(const char* native_json, void* next_ctx); typedef char* (*NemoRelayLlmExecInterceptCb)(void* user_data, const char* native_json, NemoRelayLlmExecNextFn next_fn, void* next_ctx); typedef char* (*NemoRelayToolExecNextFn)(const char* args_json, void* next_ctx); @@ -56,7 +56,7 @@ extern void* goLlmRequestTrampoline(void*, const void*); extern char* goLlmResponseTrampoline(void*, const char*); extern char* goLlmConditionalTrampoline(void*, const void*); extern char* goLlmExecInterceptTrampoline(void*, const char*, NemoRelayLlmExecNextFn, void*); -extern int32_t goLlmRequestInterceptTrampoline(void*, const char*, const void*, const char*, void**, char**); +extern int32_t goLlmRequestInterceptTrampoline(void*, const char*, const void*, const char*, char**); extern char* goToolExecInterceptTrampoline(void*, const char*, NemoRelayToolExecNextFn, void*); */ import "C" diff --git a/go/nemo_relay/scope_local_test.go b/go/nemo_relay/scope_local_test.go index ca40b5218..9e3e62351 100644 --- a/go/nemo_relay/scope_local_test.go +++ b/go/nemo_relay/scope_local_test.go @@ -1128,9 +1128,9 @@ func assertScopeLocalLLMWrappersDeregister(t *testing.T, scopeUUID string, reque &requestInterceptCalls, func() error { return ScopeRegisterLlmRequestIntercept(scopeUUID, "llm_scope_req_int", 1, false, - func(name string, headers, content, annotated json.RawMessage) (json.RawMessage, json.RawMessage, json.RawMessage, error) { + func(name string, request LLMRequestDTO, annotated json.RawMessage) (LLMRequestInterceptOutcome, error) { requestInterceptCalls++ - return headers, content, annotated, nil + return LLMRequestInterceptOutcome{Request: request, AnnotatedRequest: annotated}, nil }, ) }, diff --git a/go/nemo_relay/top_level_coverage_test.go b/go/nemo_relay/top_level_coverage_test.go index 455d945f5..8c4684bae 100644 --- a/go/nemo_relay/top_level_coverage_test.go +++ b/go/nemo_relay/top_level_coverage_test.go @@ -270,15 +270,17 @@ func assertCodecEncodeCoverage(t *testing.T, request *LLMRequest) { func assertLlmRequestInterceptPayloadCoverage(t *testing.T, request *LLMRequest) { t.Helper() - newHeaders, newContent, newAnnotated, err := llmRequestInterceptPayload( - func(name string, headers, content, annotatedJSON json.RawMessage) (json.RawMessage, json.RawMessage, json.RawMessage, error) { + outcome, err := llmRequestInterceptPayload( + func(name string, request LLMRequestDTO, annotatedJSON json.RawMessage) (LLMRequestInterceptOutcome, error) { if name != coverageInterceptName { t.Fatalf("unexpected intercept name: %q", name) } if string(annotatedJSON) != `{"annotated":true}` { t.Fatalf("unexpected annotated payload: %s", annotatedJSON) } - return json.RawMessage(`{"updated":"headers"}`), json.RawMessage(`{"model":"updated"}`), json.RawMessage(`{"annotated":"updated"}`), nil + request.Headers = json.RawMessage(`{"updated":"headers"}`) + request.Content = json.RawMessage(`{"model":"updated"}`) + return LLMRequestInterceptOutcome{Request: request, AnnotatedRequest: json.RawMessage(`{"annotated":"updated"}`)}, nil }, coverageInterceptName, request.Headers(), @@ -288,22 +290,22 @@ func assertLlmRequestInterceptPayloadCoverage(t *testing.T, request *LLMRequest) if err != nil { t.Fatalf("expected intercept success, got %v", err) } - if string(newHeaders) != `{"updated":"headers"}` { - t.Fatalf("unexpected output headers: %s", newHeaders) + if string(outcome.Request.Headers) != `{"updated":"headers"}` { + t.Fatalf("unexpected output headers: %s", outcome.Request.Headers) } - if string(newContent) != `{"model":"updated"}` { - t.Fatalf("unexpected output content: %s", newContent) + if string(outcome.Request.Content) != `{"model":"updated"}` { + t.Fatalf("unexpected output content: %s", outcome.Request.Content) } - if string(newAnnotated) != `{"annotated":"updated"}` { - t.Fatalf("unexpected output annotated json: %s", newAnnotated) + if string(outcome.AnnotatedRequest) != `{"annotated":"updated"}` { + t.Fatalf("unexpected output annotated json: %s", outcome.AnnotatedRequest) } - _, _, _, err = llmRequestInterceptPayload( - func(name string, headers, content, annotatedJSON json.RawMessage) (json.RawMessage, json.RawMessage, json.RawMessage, error) { + _, err = llmRequestInterceptPayload( + func(name string, request LLMRequestDTO, annotatedJSON json.RawMessage) (LLMRequestInterceptOutcome, error) { if annotatedJSON != nil { t.Fatalf("expected nil annotated JSON, got %s", annotatedJSON) } - return nil, nil, nil, errors.New("forced intercept failure") + return LLMRequestInterceptOutcome{}, errors.New("forced intercept failure") }, coverageInterceptName, request.Headers(), @@ -400,8 +402,8 @@ func assertLlmCodecInterceptCoverage(t *testing.T) { }, } - if err := RegisterLlmRequestIntercept("coverage_llm_codec_success", 1, false, func(name string, headers, content, annotatedJSON json.RawMessage) (json.RawMessage, json.RawMessage, json.RawMessage, error) { - return headers, content, json.RawMessage(`{"messages":[{"role":"user","content":"updated"}],"model":"updated-model"}`), nil + if err := RegisterLlmRequestIntercept("coverage_llm_codec_success", 1, false, func(name string, request LLMRequestDTO, annotatedJSON json.RawMessage) (LLMRequestInterceptOutcome, error) { + return LLMRequestInterceptOutcome{Request: request, AnnotatedRequest: json.RawMessage(`{"messages":[{"role":"user","content":"updated"}],"model":"updated-model"}`)}, nil }); err != nil { t.Fatalf("RegisterLlmRequestIntercept success case failed: %v", err) } @@ -415,13 +417,37 @@ func assertLlmCodecInterceptCoverage(t *testing.T) { t.Fatalf("expected codec-backed intercept success, got %v", err) } - if err := RegisterLlmRequestIntercept("coverage_llm_codec_error", 1, false, func(name string, headers, content, annotatedJSON json.RawMessage) (json.RawMessage, json.RawMessage, json.RawMessage, error) { - return nil, nil, nil, errors.New("forced codec-backed intercept failure") + if err := RegisterLlmRequestIntercept("coverage_llm_codec_raw_content", 1, false, func(name string, request LLMRequestDTO, annotatedJSON json.RawMessage) (LLMRequestInterceptOutcome, error) { + request.Content = json.RawMessage(`{"model":"raw-model-edit"}`) + return LLMRequestInterceptOutcome{Request: request, AnnotatedRequest: annotatedJSON}, nil + }); err != nil { + t.Fatalf("RegisterLlmRequestIntercept raw content case failed: %v", err) + } + t.Cleanup(func() { + if err := DeregisterLlmRequestIntercept("coverage_llm_codec_raw_content"); err != nil { + t.Errorf("failed to deregister raw content intercept: %v", err) + } + }) + providerCalled := false + _, err := LlmCallExecute("coverage_llm_codec_raw_content", map[string]any{ + "headers": map[string]any{}, + "content": map[string]any{"model": coverageModelName}, + }, func(json.RawMessage) (json.RawMessage, error) { + providerCalled = true + return json.RawMessage(`{"content":"unexpected"}`), nil + }, WithLLMCodec(requestCodec)) + assertErrorContains(t, err, "request.content", "codec-backed raw content mutation") + if providerCalled { + t.Fatal("provider should not run after a codec-backed raw content mutation") + } + + if err := RegisterLlmRequestIntercept("coverage_llm_codec_error", 1, false, func(name string, request LLMRequestDTO, annotatedJSON json.RawMessage) (LLMRequestInterceptOutcome, error) { + return LLMRequestInterceptOutcome{}, errors.New("forced codec-backed intercept failure") }); err != nil { t.Fatalf("RegisterLlmRequestIntercept error case failed: %v", err) } defer DeregisterLlmRequestIntercept("coverage_llm_codec_error") - _, err := LlmCallExecute("coverage_llm_codec_error", map[string]any{ + _, err = LlmCallExecute("coverage_llm_codec_error", map[string]any{ "headers": map[string]any{}, "content": map[string]any{"model": coverageModelName}, }, func(json.RawMessage) (json.RawMessage, error) { diff --git a/go/nemo_relay/wrapper_coverage_test.go b/go/nemo_relay/wrapper_coverage_test.go index 70eb9dd26..19812aa29 100644 --- a/go/nemo_relay/wrapper_coverage_test.go +++ b/go/nemo_relay/wrapper_coverage_test.go @@ -112,12 +112,12 @@ func TestStandaloneMiddlewareHelpers(t *testing.T) { } if err := RegisterLlmRequestIntercept("go_standalone_llm_req", 1, false, - func(name string, headers, content, annotated json.RawMessage) (json.RawMessage, json.RawMessage, json.RawMessage, error) { + func(name string, request LLMRequestDTO, annotated json.RawMessage) (LLMRequestInterceptOutcome, error) { var payload map[string]interface{} - _ = json.Unmarshal(content, &payload) + _ = json.Unmarshal(request.Content, &payload) payload["intercepted"] = true - out, _ := json.Marshal(payload) - return headers, out, annotated, nil + request.Content, _ = json.Marshal(payload) + return LLMRequestInterceptOutcome{Request: request, AnnotatedRequest: annotated}, nil }, ); err != nil { t.Fatalf("RegisterLlmRequestIntercept failed: %v", err) @@ -131,7 +131,7 @@ func TestStandaloneMiddlewareHelpers(t *testing.T) { var llmPayload struct { Content map[string]interface{} `json:"content"` } - if err := json.Unmarshal(request, &llmPayload); err != nil { + if err := json.Unmarshal(request.Request.Content, &llmPayload.Content); err != nil { t.Fatalf("unmarshal llm request: %v", err) } if llmPayload.Content["intercepted"] != true { diff --git a/python/nemo_relay/__init__.py b/python/nemo_relay/__init__.py index 44623269b..285c0a9d4 100644 --- a/python/nemo_relay/__init__.py +++ b/python/nemo_relay/__init__.py @@ -43,11 +43,13 @@ def add_header( name: str, request: nemo_relay.LLMRequest, annotated: nemo_relay.AnnotatedLLMRequest | None - ) -> tuple[nemo_relay.LLMRequest, nemo_relay.AnnotatedLLMRequest | None]: + ) -> nemo_relay.LLMRequestInterceptOutcome: # The request object is immutable, however we can return a new instance with updated headers. headers = request.headers.copy() headers["Authorization"] = "Bearer test-token" - return nemo_relay.LLMRequest(headers=headers, content=request.content), annotated + return nemo_relay.LLMRequestInterceptOutcome( + nemo_relay.LLMRequest(headers=headers, content=request.content), annotated + ) async def tool_impl(args): return {"echo": args["query"]} @@ -96,11 +98,13 @@ async def main(): LLMAttributes, LLMHandle, LLMRequest, + LLMRequestInterceptOutcome, MarkEvent, OpenInferenceConfig, OpenInferenceSubscriber, OpenTelemetryConfig, OpenTelemetrySubscriber, + PendingMarkSpec, ScopeAttributes, ScopeEvent, ScopeHandle, @@ -162,12 +166,11 @@ async def main(): [str, Json, Callable[[Json], Awaitable[Json]]], Json | Awaitable[Json], ] -#: Request intercept callback that rewrites raw and annotated LLM requests -#: together. The return tuple supplies the request and optional annotated view -#: passed to later request intercepts and execution. +#: Request intercept callback that returns the canonical request, annotation, +#: and pending-mark outcome passed to later intercepts and managed execution. LlmRequestIntercept: TypeAlias = Callable[ [str, LLMRequest, AnnotatedLLMRequest | None], - tuple[LLMRequest, AnnotatedLLMRequest | None], + LLMRequestInterceptOutcome, ] #: Execution intercept callback that wraps non-streaming LLM execution. The #: callback receives the logical LLM name, request, and next callable. It may @@ -457,6 +460,8 @@ def worker() -> None: "ToolHandle", "LLMHandle", "LLMRequest", + "LLMRequestInterceptOutcome", + "PendingMarkSpec", "Event", "AnnotatedLLMRequest", "AnnotatedLLMResponse", diff --git a/python/nemo_relay/__init__.pyi b/python/nemo_relay/__init__.pyi index d3c2a3c8c..2c5a91456 100644 --- a/python/nemo_relay/__init__.pyi +++ b/python/nemo_relay/__init__.pyi @@ -69,6 +69,9 @@ from nemo_relay._native import ( from nemo_relay._native import ( LLMRequest as LLMRequest, ) +from nemo_relay._native import ( + LLMRequestInterceptOutcome as LLMRequestInterceptOutcome, +) from nemo_relay._native import ( MarkEvent as MarkEvent, ) @@ -84,6 +87,9 @@ from nemo_relay._native import ( from nemo_relay._native import ( OpenTelemetrySubscriber as OpenTelemetrySubscriber, ) +from nemo_relay._native import ( + PendingMarkSpec as PendingMarkSpec, +) from nemo_relay._native import ( ScopeAttributes as ScopeAttributes, ) @@ -209,7 +215,7 @@ Exceptional flow: """ LlmRequestIntercept: TypeAlias = Callable[ [str, LLMRequest, AnnotatedLLMRequest | None], - tuple[LLMRequest, AnnotatedLLMRequest | None], + LLMRequestInterceptOutcome, ] """Request intercept callback that rewrites raw and annotated LLM requests. @@ -217,7 +223,7 @@ Arguments: The logical LLM name, raw request, and optional annotated request view. Return: - The request and optional annotated view passed to later middleware. + The complete canonical outcome passed to later middleware. """ LlmExecutionIntercept: TypeAlias = Callable[ [str, LLMRequest, Callable[[LLMRequest], Awaitable[Json]]], diff --git a/python/nemo_relay/_native.pyi b/python/nemo_relay/_native.pyi index 8aed064ab..6acc400f5 100644 --- a/python/nemo_relay/_native.pyi +++ b/python/nemo_relay/_native.pyi @@ -42,7 +42,7 @@ _ToolExecutionIntercept: TypeAlias = Callable[ ] _LlmRequestIntercept: TypeAlias = Callable[ [str, "LLMRequest", "AnnotatedLLMRequest | None"], - tuple["LLMRequest", "AnnotatedLLMRequest | None"], + "LLMRequestInterceptOutcome", ] _LlmExecutionIntercept: TypeAlias = Callable[ [str, "LLMRequest", Callable[["LLMRequest"], Awaitable[_Json]]], @@ -382,6 +382,42 @@ class LLMRequest: """Return the request content body as a JSON object.""" ... +class PendingMarkSpec: + """A runtime-owned mark specification returned by request middleware.""" + def __init__( + self, + name: str, + category: Optional[str] = ..., + category_profile: Optional[_Json] = ..., + data: Optional[_Json] = ..., + metadata: Optional[_Json] = ..., + ) -> None: ... + @property + def name(self) -> str: ... + @property + def category(self) -> Optional[str]: ... + @property + def category_profile(self) -> Optional[_Json]: ... + @property + def data(self) -> Optional[_Json]: ... + @property + def metadata(self) -> Optional[_Json]: ... + +class LLMRequestInterceptOutcome: + """Canonical result returned by an LLM request intercept.""" + def __init__( + self, + request: LLMRequest, + annotated_request: Optional[AnnotatedLLMRequest] = ..., + pending_marks: list[PendingMarkSpec] = ..., + ) -> None: ... + @property + def request(self) -> LLMRequest: ... + @property + def annotated_request(self) -> Optional[AnnotatedLLMRequest]: ... + @property + def pending_marks(self) -> list[PendingMarkSpec]: ... + class AnnotatedLLMRequest: """Structured view of an LLM request produced by a codec. @@ -1509,7 +1545,7 @@ def tool_conditional_execution(name: str, args: _Json) -> None: """ ... -def llm_request_intercepts(name: str, request: LLMRequest) -> LLMRequest: +def llm_request_intercepts(name: str, request: LLMRequest) -> LLMRequestInterceptOutcome: """Run the registered LLM request-intercept chain. Args: diff --git a/python/nemo_relay/intercepts.py b/python/nemo_relay/intercepts.py index ac9f340df..5392a6c46 100644 --- a/python/nemo_relay/intercepts.py +++ b/python/nemo_relay/intercepts.py @@ -14,11 +14,13 @@ def add_header( name: str, request: nemo_relay.LLMRequest, annotated: nemo_relay.AnnotatedLLMRequest | None - ) -> tuple[nemo_relay.LLMRequest, nemo_relay.AnnotatedLLMRequest | None]: + ) -> nemo_relay.LLMRequestInterceptOutcome: # The request object is immutable, however we can return a new instance with updated headers. headers = request.headers.copy() headers["X-Trace"] = "demo" - return nemo_relay.LLMRequest(headers=headers, content=request.content), annotated + return nemo_relay.LLMRequestInterceptOutcome( + nemo_relay.LLMRequest(headers=headers, content=request.content), annotated + ) nemo_relay.intercepts.register_llm_request("trace-header", 10, False, add_header) """ @@ -167,9 +169,9 @@ def register_llm_request(name: str, priority: int, break_chain: bool, fn: LlmReq priority: Execution order for the intercept. Lower values run first. break_chain: Whether to stop applying lower-priority request intercepts after this intercept runs. - fn: Callable invoked as ``fn(name, request, annotated)`` that returns a - tuple of ``(request, annotated)`` for the next intercept or the - provider callback. + fn: Callable invoked as ``fn(name, request, annotated)`` that returns an + ``nemo_relay.LLMRequestInterceptOutcome`` for the next intercept or + the provider callback. Returns: None: This function returns after the intercept is registered. @@ -186,10 +188,12 @@ def register_llm_request(name: str, priority: int, break_chain: bool, fn: LlmReq def add_header( name: str, request: nemo_relay.LLMRequest, annotated: nemo_relay.AnnotatedLLMRequest | None - ) -> tuple[nemo_relay.LLMRequest, nemo_relay.AnnotatedLLMRequest | None]: + ) -> nemo_relay.LLMRequestInterceptOutcome: headers = request.headers.copy() headers["X-Trace"] = "req-123" - return nemo_relay.LLMRequest(headers=headers, content=request.content), annotated + return nemo_relay.LLMRequestInterceptOutcome( + nemo_relay.LLMRequest(headers=headers, content=request.content), annotated + ) nemo_relay.intercepts.register_llm_request( "trace-header", diff --git a/python/nemo_relay/llm.py b/python/nemo_relay/llm.py index a99724593..78657143f 100644 --- a/python/nemo_relay/llm.py +++ b/python/nemo_relay/llm.py @@ -378,7 +378,8 @@ def request_intercepts(name, request): intercept chain. Returns: - LLMRequest: The request produced by the final request intercept. + LLMRequestInterceptOutcome: The complete request, annotation, and + pending-mark outcome produced by the intercept chain. Notes: This runs only the request-intercept chain. It does not execute diff --git a/python/plugin/README.md b/python/plugin/README.md index ee0caeca9..2378cffa1 100644 --- a/python/plugin/README.md +++ b/python/plugin/README.md @@ -36,6 +36,27 @@ Set `load.entrypoint` to `your_module:main` in `relay-plugin.toml`. Relay imports that function and awaits the returned coroutine when it starts the worker process. +LLM request intercepts return one canonical outcome: + +```python +from nemo_relay_plugin import LlmRequestInterceptOutcome, PendingMarkSpec + + +def intercept(model_name, request, annotated): + del model_name + headers = {**request.get("headers", {}), "x-policy": "checked"} + return LlmRequestInterceptOutcome( + request={**request, "headers": headers}, + annotated_request=annotated, + pending_marks=[PendingMarkSpec("acme.policy.checked")], + ) +``` + +When `annotated` is present, it is authoritative for provider-body content: +leave raw `request["content"]` unchanged, edit normalized fields or provider +extensions through the annotation, and use `request["headers"]` for transport +headers. + The SDK owns gRPC serving, JSON envelope conversion, callback dispatch, continuations, host runtime calls, and local scope-stack binding. Its private protobuf bindings are generated from the canonical Relay schema while the diff --git a/python/plugin/src/nemo_relay_plugin/__init__.py b/python/plugin/src/nemo_relay_plugin/__init__.py index 2ed6a2267..676a0de3b 100644 --- a/python/plugin/src/nemo_relay_plugin/__init__.py +++ b/python/plugin/src/nemo_relay_plugin/__init__.py @@ -21,6 +21,8 @@ LlmRequest: A Relay LLM request represented as a JSON object. AnnotatedLlmRequest: An annotated Relay LLM request represented as a JSON object. + PendingMarkSpec: A mark Relay emits under the future managed LLM scope. + LlmRequestInterceptOutcome: Canonical LLM request-intercept result. DiagnosticLevel: Severity of a configuration diagnostic. ConfigDiagnostic: Structured configuration warning or error. ScopeType: Semantic category for a Relay execution scope. @@ -62,10 +64,12 @@ LlmNext, LlmRequest, LlmRequestCallback, + LlmRequestInterceptOutcome, LlmSanitizeRequestCallback, LlmSanitizeResponseCallback, LlmStreamExecutionCallback, LlmStreamNext, + PendingMarkSpec, PluginContext, PluginRuntime, ScopeType, @@ -91,12 +95,14 @@ "LlmNext", "LlmRequest", "LlmRequestCallback", + "LlmRequestInterceptOutcome", "LlmSanitizeRequestCallback", "LlmSanitizeResponseCallback", "LlmStreamNext", "LlmStreamExecutionCallback", "PluginContext", "PluginRuntime", + "PendingMarkSpec", "ScopeType", "SubscriberCallback", "ToolConditionalCallback", diff --git a/python/plugin/src/nemo_relay_plugin/_api.py b/python/plugin/src/nemo_relay_plugin/_api.py index 4feabbc75..06bef440e 100644 --- a/python/plugin/src/nemo_relay_plugin/_api.py +++ b/python/plugin/src/nemo_relay_plugin/_api.py @@ -16,6 +16,8 @@ LlmRequest: A Relay LLM request represented as a JSON object. AnnotatedLlmRequest: An annotated Relay LLM request represented as a JSON object. + PendingMarkSpec: A mark Relay emits under the future managed LLM scope. + LlmRequestInterceptOutcome: Canonical LLM request-intercept result. DiagnosticLevel: Severity of a configuration diagnostic. ConfigDiagnostic: Structured configuration warning or error. ScopeType: Semantic category for a Relay execution scope. @@ -63,7 +65,7 @@ import tempfile import tomllib from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Iterator, Mapping -from dataclasses import asdict, dataclass +from dataclasses import asdict, dataclass, field from enum import Enum from importlib import metadata from pathlib import Path @@ -88,8 +90,16 @@ EVENT_SCHEMA = "nemo.relay.Event@1" LLM_REQUEST_SCHEMA = "nemo.relay.LlmRequest@1" ANNOTATED_LLM_REQUEST_SCHEMA = "nemo.relay.AnnotatedLlmRequest@1" +LLM_REQUEST_INTERCEPT_OUTCOME_SCHEMA = "nemo.relay.LlmRequestInterceptOutcome@1" PLUGIN_DIAGNOSTICS_SCHEMA = "nemo.relay.PluginDiagnostics@1" -_OBJECT_SCHEMAS = frozenset({EVENT_SCHEMA, LLM_REQUEST_SCHEMA, ANNOTATED_LLM_REQUEST_SCHEMA}) +_OBJECT_SCHEMAS = frozenset( + { + EVENT_SCHEMA, + LLM_REQUEST_SCHEMA, + ANNOTATED_LLM_REQUEST_SCHEMA, + LLM_REQUEST_INTERCEPT_OUTCOME_SCHEMA, + } +) _UNREGISTERED = object() _SCOPE_CONTEXT: contextvars.ContextVar[_BoundScopeContext | None] = contextvars.ContextVar( "nemo_relay_plugin_scope_context", @@ -171,6 +181,49 @@ def to_json(self) -> dict[str, Any]: return _normalize_diagnostic(asdict(self)) +@dataclass(slots=True) +class PendingMarkSpec: + """Describe a mark Relay emits after starting a managed LLM call.""" + + name: str + category: str | None = None + category_profile: Json | None = None + data: Json | None = None + metadata: Json | None = None + + def to_json(self) -> dict[str, Json]: + """Convert this pending mark to its canonical JSON object.""" + if not isinstance(self.name, str): + raise WorkerSdkError("pending mark name must be a string") + return asdict(self) + + +@dataclass(slots=True) +class LlmRequestInterceptOutcome: + """Canonical result returned by a Python worker LLM request intercept.""" + + request: LlmRequest + annotated_request: AnnotatedLlmRequest | None = None + pending_marks: list[PendingMarkSpec] = field(default_factory=list) + + def to_json(self) -> dict[str, Json]: + """Convert this outcome to the canonical worker-envelope payload.""" + if not isinstance(self.request, dict): + raise WorkerSdkError("LLM request intercept outcome request must be a JSON object") + if self.annotated_request is not None and not isinstance(self.annotated_request, dict): + raise WorkerSdkError("LLM request intercept outcome annotated_request must be a JSON object or None") + marks = [] + for mark in self.pending_marks: + if not isinstance(mark, PendingMarkSpec): + raise WorkerSdkError("LLM request intercept outcome pending_marks must contain PendingMarkSpec values") + marks.append(mark.to_json()) + return { + "request": self.request, + "annotated_request": self.annotated_request, + "pending_marks": marks, + } + + def _normalize_diagnostic(value: Mapping[str, Any]) -> dict[str, Any]: try: level = DiagnosticLevel(value.get("level")).value @@ -325,9 +378,7 @@ def register(self, ctx: PluginContext, config: Json) -> None | Awaitable[None]: LlmConditionalCallback: TypeAlias = Callable[[LlmRequest], str | None | Awaitable[str | None]] LlmRequestCallback: TypeAlias = Callable[ [str, LlmRequest, AnnotatedLlmRequest | None], - LlmRequest - | tuple[LlmRequest, AnnotatedLlmRequest | None] - | Awaitable[LlmRequest | tuple[LlmRequest, AnnotatedLlmRequest | None]], + LlmRequestInterceptOutcome | Awaitable[LlmRequestInterceptOutcome], ] LlmExecutionCallback: TypeAlias = Callable[[str, LlmRequest, "LlmNext"], Json | Awaitable[Json]] LlmStreamExecutionCallback: TypeAlias = Callable[ @@ -593,10 +644,11 @@ def register_llm_request_intercept( Args: name: Component-local registration name. callback: Function receiving ``(model_name, request, - annotated_request)``. Return a replacement :data:`LlmRequest` - or ``(request, annotated_request)`` tuple, directly or through - an awaitable. ``annotated_request`` is ``None`` when the host - did not provide one. + annotated_request)``. Return + :class:`LlmRequestInterceptOutcome`, directly or through an + awaitable. ``annotated_request`` is ``None`` when the host did + not provide one. When present, it is authoritative for provider + body content and the raw request content must remain unchanged. priority: Execution order. Lower values run first. break_chain: Whether Relay skips later, lower-priority request intercepts after this callback runs. @@ -1429,15 +1481,14 @@ async def _invoke_result(self, request: Any) -> Any: annotated, ) ) - if isinstance(result, tuple): - llm_request, annotated = result - else: - llm_request = result + if not isinstance(result, LlmRequestInterceptOutcome): + raise WorkerSdkError("LLM request intercept must return LlmRequestInterceptOutcome") return pb.InvokeResponse( llm_request=pb.LlmRequestInterceptResult( - request=_json_envelope(LLM_REQUEST_SCHEMA, llm_request), - annotated_request=_optional_json_envelope(annotated, ANNOTATED_LLM_REQUEST_SCHEMA), - has_annotated_request=annotated is not None, + outcome=_json_envelope( + LLM_REQUEST_INTERCEPT_OUTCOME_SCHEMA, + result.to_json(), + ), ) ) if request.surface == pb.LLM_EXECUTION_INTERCEPT: diff --git a/python/tests/integrations/langchain_tests/test_middleware.py b/python/tests/integrations/langchain_tests/test_middleware.py index 7d2c7813c..fffb369b3 100644 --- a/python/tests/integrations/langchain_tests/test_middleware.py +++ b/python/tests/integrations/langchain_tests/test_middleware.py @@ -323,7 +323,7 @@ def change_request(name: str, request: nemo_relay.LLMRequest, annotated: Any): else message for message in annotated.messages ] - return request, annotated + return nemo_relay.LLMRequestInterceptOutcome(request, annotated) nemo_relay.intercepts.register_llm_request("test_langchain_change_request", 1, False, change_request) try: diff --git a/python/tests/plugin/test_worker_sdk.py b/python/tests/plugin/test_worker_sdk.py index ee474e2c3..b011da9ab 100644 --- a/python/tests/plugin/test_worker_sdk.py +++ b/python/tests/plugin/test_worker_sdk.py @@ -13,7 +13,7 @@ import tempfile from collections.abc import AsyncIterator from pathlib import Path -from typing import Any, Literal +from typing import Any, Literal, cast import pytest @@ -26,6 +26,8 @@ ConfigDiagnostic, DiagnosticLevel, Json, + LlmRequestInterceptOutcome, + PendingMarkSpec, PluginContext, PluginRuntime, ScopeType, @@ -39,6 +41,7 @@ ANNOTATED_LLM_REQUEST_SCHEMA, EVENT_SCHEMA, JSON_SCHEMA, + LLM_REQUEST_INTERCEPT_OUTCOME_SCHEMA, LLM_REQUEST_SCHEMA, WORKER_PROTOCOL, _announced_worker_endpoint, @@ -201,9 +204,13 @@ def llm_block(request: Json) -> str | None: del request return "llm blocked" - def llm_request(name: str, request: Json, annotated: Json | None) -> tuple[Json, Json]: + def llm_request(name: str, request: Json, annotated: Json | None) -> LlmRequestInterceptOutcome: del name - return _tag_llm_request(request, "llm_request"), _tag(annotated or {}, "annotated") + return LlmRequestInterceptOutcome( + request=_tag_llm_request(request, "llm_request"), + annotated_request=_tag(annotated, "annotated") if annotated is not None else None, + pending_marks=[PendingMarkSpec("worker.pending", data={"source": "python"})], + ) async def llm_execution(name: str, request: Json, next_call: Any) -> Json: result = await next_call.call(_tag_llm_request(request, f"llm_execute_{name}")) @@ -937,9 +944,32 @@ async def test_unary_invoke_success_paths(service: _WorkerService, host_stub: Re ), AbortContext(), ) - assert _envelope_value(llm_request.llm_request.request)["content"]["llm_request"] - assert _envelope_value(llm_request.llm_request.annotated_request)["tag"] == "annotated" - assert llm_request.llm_request.has_annotated_request + assert llm_request.llm_request.outcome.schema == LLM_REQUEST_INTERCEPT_OUTCOME_SCHEMA + outcome = _envelope_value(llm_request.llm_request.outcome) + assert outcome["request"]["content"]["llm_request"] + assert outcome["annotated_request"]["tag"] == "annotated" + assert outcome["pending_marks"] == [ + { + "name": "worker.pending", + "category": None, + "category_profile": None, + "data": {"source": "python"}, + "metadata": None, + } + ] + + request_only = await service.Invoke( + _invoke_request( + "llm_request", + pb.LLM_REQUEST_INTERCEPT, + llm=_llm_payload(request={"content": {"prompt": "hello"}}), + ), + AbortContext(), + ) + request_only_outcome = _envelope_value(request_only.llm_request.outcome) + assert request_only_outcome["request"]["content"]["llm_request"] + assert request_only_outcome["annotated_request"] is None + assert request_only_outcome["pending_marks"] == outcome["pending_marks"] llm_execution = await _invoke_json_async( service, @@ -1011,8 +1041,8 @@ def register(self, ctx: PluginContext, config: Json) -> None: def invalid_result(name: str, request: Json, annotated: Json | None) -> Any: del name if invalid_part == "request": - return [] - return request, [] + return LlmRequestInterceptOutcome(request=cast(Any, [])) + return LlmRequestInterceptOutcome(request=request, annotated_request=cast(Any, [])) ctx.register_llm_request_intercept("invalid", invalid_result) @@ -1127,9 +1157,9 @@ class RequestOnlyPlugin(WorkerPlugin): def register(self, ctx: PluginContext, config: Json) -> None: del config - def llm_request(name: str, request: Json, annotated: Json | None) -> Json: + def llm_request(name: str, request: Json, annotated: Json | None) -> LlmRequestInterceptOutcome: del name, annotated - return _tag_llm_request(request, "request_only") + return LlmRequestInterceptOutcome(request=_tag_llm_request(request, "request_only")) ctx.register_llm_request_intercept("request_only", llm_request) @@ -1143,8 +1173,10 @@ def llm_request(name: str, request: Json, annotated: Json | None) -> Json: ), AbortContext(), ) - assert _envelope_value(response.llm_request.request)["content"]["request_only"] - assert not response.llm_request.has_annotated_request + outcome = _envelope_value(response.llm_request.outcome) + assert outcome["request"]["content"]["request_only"] + assert outcome["annotated_request"] is None + assert outcome["pending_marks"] == [] async def test_stream_invoke_success_and_failures(service: _WorkerService, host_stub: RecordingHostStub): diff --git a/python/tests/test_adaptive.py b/python/tests/test_adaptive.py index f3fb92547..c7e956e7b 100644 --- a/python/tests/test_adaptive.py +++ b/python/tests/test_adaptive.py @@ -9,7 +9,17 @@ import pytest -from nemo_relay import AnnotatedLLMRequest, JsonObject, LLMRequest, ScopeType, llm, plugin, scope, tools +from nemo_relay import ( + AnnotatedLLMRequest, + JsonObject, + LLMRequest, + LLMRequestInterceptOutcome, + ScopeType, + llm, + plugin, + scope, + tools, +) from nemo_relay import adaptive as adaptive_module from nemo_relay.adaptive import ( ADAPTIVE_PLUGIN_KIND, @@ -213,7 +223,7 @@ async def test_adaptive_runtime_bind_scope_passes_through_without_state(self): with scope.scope("adaptive-runtime-translate", ScopeType.Agent) as handle: runtime.bind_scope(handle) translated = llm.request_intercepts("anthropic", request) - assert translated.content == { + assert translated.request.content == { "messages": [{"role": "user", "content": "Hello"}], "system": "You are helpful.", "model": "claude-sonnet-4-20250514", @@ -287,7 +297,7 @@ def register(self, plugin_config, context): def intercept(_name, request, annotated): headers = dict(request.headers) headers["x-python-plugin"] = f"priority:{priority}" - return LLMRequest(headers, request.content), annotated + return LLMRequestInterceptOutcome(LLMRequest(headers, request.content), annotated) async def llm_exec_intercept(_name, request, next_call): response = await next_call(request) diff --git a/python/tests/test_codecs.py b/python/tests/test_codecs.py index d49c70513..d5ef6810b 100644 --- a/python/tests/test_codecs.py +++ b/python/tests/test_codecs.py @@ -12,10 +12,13 @@ from typing import cast +import pytest + from nemo_relay import ( AnnotatedLLMRequest, JsonObject, LLMRequest, + LLMRequestInterceptOutcome, intercepts, llm, ) @@ -267,7 +270,7 @@ def annotated_intercept(name, request, annotated): *annotated.messages, ] annotated.messages = new_messages - return (request, annotated) + return LLMRequestInterceptOutcome(request, annotated) intercepts.register_llm_request("test-annot-intercept-pipeline", 1, False, annotated_intercept) @@ -286,6 +289,48 @@ def func(request): finally: intercepts.deregister_llm_request("test-annot-intercept-pipeline") + async def test_codec_rejects_raw_content_edits_before_provider(self): + """Codec-aware intercepts must edit the annotation, not the raw body.""" + provider_called = False + + def raw_content_intercept(name, request, annotated): + content = {**request.content, "model": "raw-model-edit"} + return LLMRequestInterceptOutcome(LLMRequest(request.headers, content), annotated) + + def provider(request): + nonlocal provider_called + provider_called = True + return {"unexpected": True} + + intercepts.register_llm_request("test-codec-raw-content", 1, False, raw_content_intercept) + try: + with pytest.raises(RuntimeError, match=r"request\.content"): + await llm.execute("pipeline-llm", make_request(), provider, codec=SimpleCodec()) + assert not provider_called + finally: + intercepts.deregister_llm_request("test-codec-raw-content") + + async def test_codec_rejects_missing_annotation_before_provider(self): + """Codec-aware intercepts must return the decoded annotation.""" + provider_called = False + + def missing_annotation_intercept(name, request, annotated): + assert annotated is not None + return LLMRequestInterceptOutcome(request, None) + + def provider(request): + nonlocal provider_called + provider_called = True + return {"unexpected": True} + + intercepts.register_llm_request("test-codec-missing-annotation", 1, False, missing_annotation_intercept) + try: + with pytest.raises(RuntimeError, match="omitted annotated_request"): + await llm.execute("pipeline-llm", make_request(), provider, codec=SimpleCodec()) + assert not provider_called + finally: + intercepts.deregister_llm_request("test-codec-missing-annotation") + async def test_codec_parameter(self): """codec parameter passes the specified codec instance directly.""" alternate = AlternateCodec() @@ -295,7 +340,7 @@ async def test_codec_parameter(self): def annotated_intercept(name, request, annotated): if annotated is not None: intercept_data["extra"] = annotated.extra - return (request, annotated) + return LLMRequestInterceptOutcome(request, annotated) intercepts.register_llm_request("test-annot-intercept-cn", 1, False, annotated_intercept) @@ -324,7 +369,7 @@ def annotated_intercept(name, request, annotated): assert annotated is not None assert isinstance(annotated, AnnotatedLLMRequest) assert isinstance(request, LLMRequest) - return (request, annotated) + return LLMRequestInterceptOutcome(request, annotated) intercepts.register_llm_request("test-annot-typed", 1, False, annotated_intercept) diff --git a/python/tests/test_integration_codecs.py b/python/tests/test_integration_codecs.py index 4e4facf16..7330d9593 100644 --- a/python/tests/test_integration_codecs.py +++ b/python/tests/test_integration_codecs.py @@ -13,7 +13,15 @@ from typing import cast -from nemo_relay import AnnotatedLLMRequest, JsonObject, LLMRequest, ScopeType, llm, scope +from nemo_relay import ( + AnnotatedLLMRequest, + JsonObject, + LLMRequest, + LLMRequestInterceptOutcome, + ScopeType, + llm, + scope, +) from nemo_relay.codecs import ( LlmCodec, ) @@ -448,7 +456,7 @@ def annotated_intercept(name, request, annotated): if annotated is not None: intercept_data["model"] = annotated.model intercept_data["messages"] = annotated.messages - return (request, annotated) + return LLMRequestInterceptOutcome(request, annotated) intercepts.register_llm_request("delegation-test-intercept", 1, False, annotated_intercept) diff --git a/python/tests/test_llm.py b/python/tests/test_llm.py index 49aa8f070..5e9c738ef 100644 --- a/python/tests/test_llm.py +++ b/python/tests/test_llm.py @@ -11,6 +11,8 @@ LLMAttributes, LLMHandle, LLMRequest, + LLMRequestInterceptOutcome, + PendingMarkSpec, ScopeEvent, ScopeType, guardrails, @@ -285,20 +287,34 @@ def func(request): class TestLLMIntercepts: def test_request_intercept(self): # Request intercepts now operate on LLMRequest - intercepts.register_llm_request("py_llm_req", 1, False, lambda name, request, annotated: (request, annotated)) + intercepts.register_llm_request( + "py_llm_req", + 1, + False, + lambda name, request, annotated: LLMRequestInterceptOutcome(request, annotated), + ) assert intercepts.deregister_llm_request("py_llm_req") def test_request_intercepts_direct(self): + pending_mark = PendingMarkSpec("request.direct", data={"source": "python"}) + def intercept_fn(name, request, annotated): content = request.content content["direct"] = True - return LLMRequest(request.headers, content), annotated + return LLMRequestInterceptOutcome( + LLMRequest(request.headers, content), + annotated, + [pending_mark], + ) intercepts.register_llm_request("py_llm_req_direct", 1, False, intercept_fn) transformed = llm.request_intercepts("direct_llm", make_request()) intercepts.deregister_llm_request("py_llm_req_direct") - assert transformed.content["direct"] is True + assert transformed.request.content["direct"] is True + assert len(transformed.pending_marks) == 1 + assert transformed.pending_marks[0].name == pending_mark.name + assert transformed.pending_marks[0].data == pending_mark.data def test_request_intercept_raises_on_exception(self): intercepts.register_llm_request( @@ -316,7 +332,7 @@ def test_request_intercept_raises_on_exception(self): def test_request_intercept_raises_on_invalid_return(self): intercepts.register_llm_request("py_llm_req_bad_return", 1, False, lambda name, request, annotated: object()) # type: ignore[arg-type] # ty: ignore[invalid-argument-type] try: - with pytest.raises(RuntimeError, match="result\\[0\\] extraction failed"): + with pytest.raises(RuntimeError, match="must return LLMRequestInterceptOutcome"): llm.request_intercepts("bad_return_llm", make_request()) finally: intercepts.deregister_llm_request("py_llm_req_bad_return") @@ -358,7 +374,7 @@ def intercept_fn(name, request, annotated): # Request intercepts now operate on LLMRequest content = request.content content["intercepted"] = True - return LLMRequest(request.headers, content), annotated + return LLMRequestInterceptOutcome(LLMRequest(request.headers, content), annotated) intercepts.register_llm_request("py_llm_req_mod", 1, False, intercept_fn) diff --git a/python/tests/test_scope_local.py b/python/tests/test_scope_local.py index 0b4fcfe8d..7746ed48e 100644 --- a/python/tests/test_scope_local.py +++ b/python/tests/test_scope_local.py @@ -16,6 +16,7 @@ from nemo_relay import ( JsonObject, LLMRequest, + LLMRequestInterceptOutcome, MarkEvent, ScopeEvent, ScopeType, @@ -648,7 +649,10 @@ async def test_scope_local_llm_request_intercept_modifies_request(self): request = LLMRequest({}, {"messages": [], "model": "scope-local"}) def intercept(name, req, annotated): - return LLMRequest(req.headers, {**req.content, "intercepted": True}), annotated + return LLMRequestInterceptOutcome( + LLMRequest(req.headers, {**req.content, "intercepted": True}), + annotated, + ) with scope.scope("sl_llm_request_scope", ScopeType.Agent) as handle: scope_local.register_llm_request(handle, "sl_llm_request", 1, False, intercept)