Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions crates/adaptive/src/intercepts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,26 +128,35 @@ pub(crate) fn create_tool_execution_intercept_with_mode(
let name = name.to_string();
Box::pin(async move {
let Some(cohort_key) = resolve_warm_first_cohort_key(&name, &mode, &cache) else {
return next(args).await;
return next(args).await.map(Into::into);
};

match resolve_warm_first_role(&registry, cohort_key.clone()).await {
WarmFirstRole::Primer(gate) => {
let result = next(args).await;
gate.release();
cleanup_cohort_gate(&registry, &cohort_key, &gate).await;
result
result.map(Into::into)
}
WarmFirstRole::Follower(gate) => {
let _ = tokio::time::timeout(
Duration::from_millis(WARM_FIRST_MAX_WAIT_MS),
gate.wait_for_release(),
)
.await;
next(args).await
next(args).await.map(Into::into)
}
}
}) as Pin<Box<dyn Future<Output = FlowResult<Json>> + Send>>
})
as Pin<
Box<
dyn Future<
Output = FlowResult<
nemo_relay::api::tool::ToolExecutionInterceptOutcome,
>,
> + Send,
>,
>
})
}

Expand Down
10 changes: 5 additions & 5 deletions crates/adaptive/tests/unit/intercepts_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ async fn test_tool_intercept_calls_next() {

let result = intercept("test", json!({"input": 1}), next).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), json!({"result": "ok"}));
assert_eq!(result.unwrap(), json!({"result": "ok"}).into());
}

#[tokio::test]
Expand All @@ -125,7 +125,7 @@ async fn test_tool_intercept_with_populated_cache() {
// Should not panic and should return next's result
let result = intercept("test", json!({"tool_input": "data"}), next).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), json!({"from_next": true}));
assert_eq!(result.unwrap(), json!({"from_next": true}).into());
}

#[tokio::test]
Expand All @@ -147,7 +147,7 @@ async fn test_tool_intercept_passes_args_to_next() {
let input = json!({"tool_arg": "value", "count": 42});
let result = intercept("test", input.clone(), next).await;
assert!(result.is_ok());
assert_eq!(result.unwrap(), input);
assert_eq!(result.unwrap(), input.into());
}

#[test]
Expand Down Expand Up @@ -325,8 +325,8 @@ async fn test_schedule_mode_intercept_waits_for_primer_before_running_follower()
tokio::task::yield_now().await;
let follower = tokio::spawn(intercept("search", json!({"call": 2}), next.clone()));

assert_eq!(primer.await.unwrap().unwrap(), json!({"call": 1}));
assert_eq!(follower.await.unwrap().unwrap(), json!({"call": 2}));
assert_eq!(primer.await.unwrap().unwrap(), json!({"call": 1}).into());
assert_eq!(follower.await.unwrap().unwrap(), json!({"call": 2}).into());
assert_eq!(next_order.load(Ordering::SeqCst), 2);

reset_scope_stack();
Expand Down
16 changes: 12 additions & 4 deletions crates/adaptive/tests/unit/runtime_features_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,14 +206,22 @@ fn assert_llm_stream_execution_intercept_absent(name: &str) {

fn assert_tool_execution_intercept_registered(name: &str) {
assert_already_registered(
register_tool_execution_intercept(name, i32::MAX, Arc::new(|_name, args, next| next(args))),
register_tool_execution_intercept(
name,
i32::MAX,
Arc::new(|_name, args, next| Box::pin(async move { next(args).await.map(Into::into) })),
),
name,
);
}

fn assert_tool_execution_intercept_absent(name: &str) {
register_tool_execution_intercept(name, i32::MAX, Arc::new(|_name, args, next| next(args)))
.unwrap();
register_tool_execution_intercept(
name,
i32::MAX,
Arc::new(|_name, args, next| Box::pin(async move { next(args).await.map(Into::into) })),
)
.unwrap();
deregister_tool_execution_intercept(name).unwrap();
}

Expand Down Expand Up @@ -753,7 +761,7 @@ async fn registration_context_registers_all_supported_callback_types() {
ctx.register_tool_execution_intercept(
"adaptive_test_tool",
8,
Arc::new(|_name, args, _next| Box::pin(async move { Ok(args) })),
Arc::new(|_name, args, _next| Box::pin(async move { Ok(args.into()) })),
)
.unwrap();

Expand Down
7 changes: 5 additions & 2 deletions crates/core/src/api/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,9 @@ global_intercept_registry_api!(
);
global_execution_registry_api!(
/// Register a global tool execution intercept.
/// Execution intercepts can wrap or replace the tool callback.
/// Execution intercepts can wrap or replace the tool callback. Each
/// callback returns a canonical tool execution outcome, while its
/// continuation resolves to the raw downstream result JSON.
register_tool_execution_intercept,
/// Deregister a global tool execution intercept.
deregister_tool_execution_intercept,
Expand Down Expand Up @@ -616,7 +618,8 @@ scope_intercept_registry_api!(
scope_execution_registry_api!(
/// Register a scope-local tool execution intercept.
/// Execution intercepts can wrap or replace the tool callback inside the
/// owning scope.
/// owning scope. Each callback returns a canonical tool execution outcome,
/// while its continuation resolves to the raw downstream result JSON.
scope_register_tool_execution_intercept,
/// Deregister a scope-local tool execution intercept.
scope_deregister_tool_execution_intercept,
Expand Down
21 changes: 18 additions & 3 deletions crates/core/src/api/runtime/callbacks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use tokio_stream::Stream;

use crate::api::event::Event;
use crate::api::llm::{LlmRequest, LlmRequestInterceptOutcome};
use crate::api::tool::ToolExecutionInterceptOutcome;
use crate::codec::request::AnnotatedLlmRequest;
use crate::error::Result;
use crate::json::Json;
Expand Down Expand Up @@ -81,7 +82,9 @@ pub type ToolInterceptFn = Arc<dyn Fn(&str, Json) -> Result<Json> + Send + Sync>
/// chain.
///
/// # Returns
/// A future resolving to the tool result JSON.
/// A future resolving to the downstream tool result JSON. Pending marks from
/// downstream intercepts are retained by the runtime and are not exposed
/// through this continuation.
///
/// # Errors
/// The future resolves to an error when the remaining execution chain fails.
Expand All @@ -98,13 +101,25 @@ pub type ToolExecutionNextFn =
/// - Third argument: Continuation for the remaining execution chain.
///
/// # Returns
/// A future resolving to the tool result JSON.
/// A future resolving to the canonical tool execution outcome, containing the
/// tool result and any pending lifecycle marks produced by this intercept.
///
/// # Errors
/// The future resolves to an error when the intercept or remaining execution
/// chain fails.
pub type ToolExecutionFn = Arc<
dyn Fn(&str, Json, ToolExecutionNextFn) -> Pin<Box<dyn Future<Output = Result<Json>> + Send>>
dyn Fn(
&str,
Json,
ToolExecutionNextFn,
) -> Pin<Box<dyn Future<Output = Result<ToolExecutionInterceptOutcome>> + Send>>
+ Send
+ Sync,
>;

/// Internal continuation carrying both a tool result and accumulated marks.
pub(crate) type ToolExecutionOutcomeNextFn = Arc<
dyn Fn(Json) -> Pin<Box<dyn Future<Output = Result<ToolExecutionInterceptOutcome>> + Send>>
+ Send
+ Sync,
>;
Expand Down
65 changes: 57 additions & 8 deletions crates/core/src/api/runtime/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@

use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};

use crate::api::event::{
BaseEvent, CategoryProfile, Event, EventCategory, MarkEvent, ScopeCategory, ScopeEvent,
Expand All @@ -23,12 +24,14 @@ use crate::api::runtime::callbacks::{
EventSubscriberFn, LlmConditionalFn, LlmExecutionFn, LlmExecutionNextFn, LlmRequestInterceptFn,
LlmSanitizeRequestFn, LlmSanitizeResponseFn, LlmStreamExecutionFn, LlmStreamExecutionNextFn,
LlmStreamExecutionRegistryRefs, ToolConditionalFn, ToolExecutionFn, ToolExecutionNextFn,
ToolInterceptFn, ToolSanitizeFn,
ToolExecutionOutcomeNextFn, ToolInterceptFn, ToolSanitizeFn,
};
use crate::api::runtime::subscriber_dispatcher;
use crate::api::scope::{CreateScopeHandleParams, EndScopeHandleParams, ScopeHandle, ScopeType};
use crate::api::tool::ToolHandle;
use crate::api::tool::{CreateToolHandleParams, EndToolHandleParams};
use crate::api::tool::{
CreateToolHandleParams, EndToolHandleParams, ToolExecutionInterceptOutcome,
};
use crate::codec::request::AnnotatedLlmRequest;
use crate::codec::response::AnnotatedLlmResponse;
use crate::context::registries::{
Expand Down Expand Up @@ -821,22 +824,68 @@ impl NemoRelayContextState {
/// from the active scope stack.
///
/// # Returns
/// A composed [`ToolExecutionNextFn`] that wraps `default_fn` in every
/// matching execution intercept.
/// A composed [`ToolExecutionOutcomeNextFn`] that wraps `default_fn` in
/// every matching execution intercept.
pub(crate) fn tool_build_execution_chain(
&self,
name: &str,
default_fn: ToolExecutionNextFn,
scope_locals: &[&SortedRegistry<ExecutionIntercept<ToolExecutionFn>>],
) -> ToolExecutionNextFn {
) -> ToolExecutionOutcomeNextFn {
let matching =
merge_execution_intercept_callables(&self.tool_execution_intercepts, scope_locals);
let mut next = default_fn;
let mut next: ToolExecutionOutcomeNextFn = Arc::new(move |args| {
let default_fn = default_fn.clone();
Box::pin(async move {
default_fn(args)
.await
.map(ToolExecutionInterceptOutcome::new)
})
});
let name = name.to_string();
for (callable, _) in matching.into_iter().rev() {
let current_next = next.clone();
let current_name = name.clone();
next = Arc::new(move |args| callable(&current_name, args, current_next.clone()));
next = Arc::new(move |args| {
let callable = callable.clone();
let current_name = current_name.clone();
let next_sequence = Arc::new(AtomicUsize::new(0));
let downstream_marks = Arc::new(Mutex::new(Vec::new()));
let raw_next: ToolExecutionNextFn = {
let current_next = current_next.clone();
let next_sequence = next_sequence.clone();
let downstream_marks = downstream_marks.clone();
Arc::new(move |args| {
let sequence = next_sequence.fetch_add(1, Ordering::Relaxed);
let current_next = current_next.clone();
let downstream_marks = downstream_marks.clone();
Box::pin(async move {
let outcome = current_next(args).await?;
downstream_marks
.lock()
.expect("tool pending mark accumulator lock poisoned")
.push((sequence, outcome.pending_marks));
Ok(outcome.result)
})
})
};
Box::pin(async move {
let mut outcome = callable(&current_name, args, raw_next).await?;
let mut downstream_batches = std::mem::take(
&mut *downstream_marks
.lock()
.expect("tool pending mark accumulator lock poisoned"),
);
downstream_batches.sort_by_key(|(sequence, _)| *sequence);
let mut marks = downstream_batches
.into_iter()
.flat_map(|(_, marks)| marks)
.collect::<Vec<_>>();
marks.append(&mut outcome.pending_marks);
outcome.pending_marks = marks;
Ok(outcome)
})
});
}
next
}
Expand Down
Loading
Loading