Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion crates/adaptive/src/acg_component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -595,10 +595,25 @@ pub(crate) fn create_acg_llm_request_intercept(
plugin: Arc<dyn ProviderPlugin>,
) -> LlmRequestInterceptFn {
Arc::new(move |_name: &str, request: LlmRequest, annotated| {
let input_content = request.content.clone();
let translated =
translate_request(&request, &agent_id, &provider, plugin.as_ref(), &hot_cache)
.unwrap_or(request);
Ok((translated, annotated))
if annotated.is_some() && translated.content != input_content {
let translated_annotated = build_semantic_request_view(&translated)
.map_err(|error| nemo_relay::error::FlowError::Internal(error.to_string()))?
.annotated_request;
return Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new(
LlmRequest {
headers: translated.headers,
content: input_content,
},
Some(translated_annotated),
));
Comment thread
bbednarski9 marked this conversation as resolved.
}
Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new(
translated, annotated,
))
})
}

Expand Down
36 changes: 21 additions & 15 deletions crates/adaptive/src/adaptive_hints_intercept.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,22 +86,24 @@ fn manual_agent_hints(manual: u32, effective_agent_id: &str, scope_depth: usize)
}
}

fn inject_agent_hints(request: &mut LlmRequest, hints: &AgentHints) {
fn inject_agent_hints(
request: &mut LlmRequest,
annotated: &mut Option<AnnotatedLlmRequest>,
hints: &AgentHints,
) {
let Ok(serialized_hints) = serde_json::to_value(hints) else {
return;
};

if let Some(body) = request.content.as_object_mut() {
if !body.contains_key("nvext") {
body.insert(
"nvext".to_string(),
serde_json::Value::Object(serde_json::Map::new()),
);
}
if let Some(nvext) = body
.get_mut("nvext")
.and_then(|value| value.as_object_mut())
{
let body = annotated
.as_mut()
.map(|annotated| &mut annotated.extra)
.or_else(|| request.content.as_object_mut());
if let Some(body) = body {
let nvext = body
.entry("nvext".to_string())
.or_insert_with(|| serde_json::Value::Object(serde_json::Map::new()));
if let Some(nvext) = nvext.as_object_mut() {
nvext.insert("agent_hints".to_string(), serialized_hints.clone());
}
}
Expand Down Expand Up @@ -172,7 +174,9 @@ impl AdaptiveHintsIntercept {
pub fn into_request_fn(self) -> LlmRequestInterceptFn {
let this = Arc::new(self);
Arc::new(
move |_name: &str, mut request: LlmRequest, annotated: Option<AnnotatedLlmRequest>| {
move |_name: &str,
mut request: LlmRequest,
mut annotated: Option<AnnotatedLlmRequest>| {
let scope_path = extract_scope_path();
let manual_ls = read_manual_latency_sensitivity();
let scope_depth = scope_path.len();
Expand All @@ -189,10 +193,12 @@ impl AdaptiveHintsIntercept {
);

if let Some(hints) = final_hints {
inject_agent_hints(&mut request, &hints);
inject_agent_hints(&mut request, &mut annotated, &hints);
}

Ok((request, annotated))
Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new(
request, annotated,
))
},
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ async fn test_adaptive_plugin_registers_and_passes_calls_through() {
},
)
.unwrap();
assert_eq!(request.content["messages"], json!([]));
assert_eq!(request.request.content["messages"], json!([]));

let llm_func: LlmExecutionNextFn =
Arc::new(|_req: LlmRequest| Box::pin(async { Ok(json!({"response": "ok"})) }));
Expand Down Expand Up @@ -721,7 +721,9 @@ impl Plugin for HeaderPlugin {
false,
Arc::new(|_name, mut request, annotated| {
request.headers.insert("x-plugin".into(), json!("set"));
Ok((request, annotated))
Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new(
request, annotated,
))
}),
)?;
ctx.register_tool_request_intercept(
Expand Down Expand Up @@ -806,7 +808,7 @@ async fn test_top_level_plugin_registers_request_and_execution_intercepts() {
},
)
.unwrap();
assert_eq!(request.headers.get("x-plugin"), Some(&json!("set")));
assert_eq!(request.request.headers.get("x-plugin"), Some(&json!("set")));

let tool_func: ToolExecutionNextFn = Arc::new(|args| Box::pin(async move { Ok(args) }));
let tool_result = tool_call_execute(
Expand Down
4 changes: 3 additions & 1 deletion crates/adaptive/tests/unit/acg_component_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1087,12 +1087,14 @@ fn acg_component_request_intercept_passes_original_request_and_annotation_when_t
plugin,
);

let (translated, returned_annotated) = intercept(
let outcome = intercept(
"anthropic",
invalid_request.clone(),
Some(annotated.clone()),
)
.expect("request intercept should pass through");
let translated = outcome.request;
let returned_annotated = outcome.annotated_request;

assert_eq!(translated.content, invalid_request.content);
assert_eq!(
Expand Down
23 changes: 17 additions & 6 deletions crates/adaptive/tests/unit/adaptive_hints_intercept_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ fn test_adaptive_hints_intercept_injects_prediction_hints_and_manual_override()
stream: None,
extra: serde_json::Map::new(),
};
let (request, returned_annotated) = req_fn(
let outcome = req_fn(
"model",
LlmRequest {
headers: serde_json::Map::new(),
Expand All @@ -203,8 +203,12 @@ fn test_adaptive_hints_intercept_injects_prediction_hints_and_manual_override()
Some(annotated.clone()),
)
.unwrap();
let request = outcome.request;
let returned_annotated = outcome.annotated_request;

let body_hints = &request.content["nvext"]["agent_hints"];
assert_eq!(request.content, serde_json::json!({}));
let returned_annotated = returned_annotated.expect("annotation should be preserved");
let body_hints = &returned_annotated.extra["nvext"]["agent_hints"];
assert_eq!(body_hints["osl"], serde_json::json!(150));
assert_eq!(body_hints["iat"], serde_json::json!(200));
assert_eq!(body_hints["latency_sensitivity"], serde_json::json!(5.0));
Expand All @@ -215,7 +219,11 @@ fn test_adaptive_hints_intercept_injects_prediction_hints_and_manual_override()
request.headers.get(AGENT_HINTS_HEADER_KEY).unwrap(),
body_hints
);
assert_eq!(returned_annotated, Some(annotated));
let mut expected_annotated = annotated;
expected_annotated
.extra
.insert("nvext".into(), returned_annotated.extra["nvext"].clone());
assert_eq!(returned_annotated, expected_annotated);

pop_scope(
nemo_relay::api::scope::PopScopeParams::builder()
Expand Down Expand Up @@ -256,7 +264,7 @@ fn test_adaptive_hints_intercept_uses_defaults_and_ignores_poisoned_cache() {
}));
let req_fn =
AdaptiveHintsIntercept::new(hot_cache, "fallback-agent".to_string()).into_request_fn();
let (request, annotated) = req_fn(
let outcome = req_fn(
"model",
LlmRequest {
headers: serde_json::Map::new(),
Expand All @@ -265,6 +273,8 @@ fn test_adaptive_hints_intercept_uses_defaults_and_ignores_poisoned_cache() {
None,
)
.unwrap();
let request = outcome.request;
let annotated = outcome.annotated_request;
assert_eq!(
request.headers.get(AGENT_HINTS_HEADER_KEY),
Some(&serde_json::to_value(&defaults).unwrap())
Expand Down Expand Up @@ -293,7 +303,7 @@ fn test_adaptive_hints_intercept_uses_defaults_and_ignores_poisoned_cache() {
});
let poisoned_req_fn =
AdaptiveHintsIntercept::new(poisoned_cache, "fallback-agent".to_string()).into_request_fn();
let (poisoned_request, _) = poisoned_req_fn(
let poisoned_outcome = poisoned_req_fn(
"model",
LlmRequest {
headers: serde_json::Map::new(),
Expand All @@ -302,6 +312,7 @@ fn test_adaptive_hints_intercept_uses_defaults_and_ignores_poisoned_cache() {
None,
)
.unwrap();
let poisoned_request = poisoned_outcome.request;
assert!(
poisoned_request
.headers
Expand Down Expand Up @@ -344,7 +355,7 @@ fn test_apply_manual_latency_override_and_inject_agent_hints_cover_manual_paths(
headers: serde_json::Map::new(),
content: serde_json::json!("scalar"),
};
inject_agent_hints(&mut non_object_request, &manual_only);
inject_agent_hints(&mut non_object_request, &mut None, &manual_only);
assert_eq!(
non_object_request.headers.get(AGENT_HINTS_HEADER_KEY),
Some(&serde_json::to_value(&manual_only).unwrap())
Expand Down
2 changes: 1 addition & 1 deletion crates/adaptive/tests/unit/plugin_component_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ async fn adaptive_plugin_registers_runtime_and_rolls_back_registration() {
},
)
.unwrap();
assert!(request.headers.is_empty());
assert!(request.request.headers.is_empty());

let mut registrations = ctx.into_registrations();
assert_eq!(registrations.len(), 1);
Expand Down
20 changes: 16 additions & 4 deletions crates/adaptive/tests/unit/runtime_features_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,11 @@ fn assert_llm_request_intercept_registered(name: &str) {
name,
i32::MAX,
false,
Arc::new(|_name, request, annotated| Ok((request, annotated))),
Arc::new(|_name, request, annotated| {
Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new(
request, annotated,
))
}),
),
name,
);
Expand All @@ -148,7 +152,11 @@ fn assert_llm_request_intercept_absent(name: &str) {
name,
i32::MAX,
false,
Arc::new(|_name, request, annotated| Ok((request, annotated))),
Arc::new(|_name, request, annotated| {
Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new(
request, annotated,
))
}),
)
.unwrap();
deregister_llm_request_intercept(name).unwrap();
Expand Down Expand Up @@ -549,7 +557,7 @@ async fn adaptive_hints_feature_registers_request_intercept() {
},
)
.unwrap();
assert!(request.headers.contains_key(AGENT_HINTS_HEADER_KEY));
assert!(request.request.headers.contains_key(AGENT_HINTS_HEADER_KEY));

let mut registrations = ctx.finish();
rollback_registrations(&mut registrations);
Expand Down Expand Up @@ -712,7 +720,11 @@ async fn registration_context_registers_all_supported_callback_types() {
"adaptive_test_request",
5,
false,
Arc::new(|_name, request, annotated| Ok((request, annotated))),
Arc::new(|_name, request, annotated| {
Ok(nemo_relay::api::llm::LlmRequestInterceptOutcome::new(
request, annotated,
))
}),
)
.unwrap();
ctx.register_llm_execution_intercept(
Expand Down
2 changes: 1 addition & 1 deletion crates/adaptive/tests/unit/runtime_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ async fn adaptive_runtime_bind_scope_requires_registration_and_passes_through_wi
let translated = llm_request_intercepts("anthropic", request.clone())
.expect("request intercept chain should pass through when no hot-cache state exists");

assert_eq!(translated.content, request.content);
assert_eq!(translated.request.content, request.content);
pop_scope(PopScopeParams::builder().handle_uuid(&scope.uuid).build())
.expect("scope pop should succeed");
}
Loading
Loading