diff --git a/rust/README.md b/rust/README.md index 9bf1fddd9..a903e906f 100644 --- a/rust/README.md +++ b/rust/README.md @@ -406,6 +406,37 @@ The closure receives the full [`ToolInvocation`](crate::types::ToolInvocation) a Reach for the `ToolHandler` trait directly when you need shared state across multiple methods or want a named type that shows up by name in stack traces. +### Tool Handler Cancellation + +Every `ToolInvocation` carries an optional `cancellation_token: Option` that fires when the in-flight handler should stop early. The SDK populates it on dispatch; it's `None` only for invocations you construct yourself (e.g. in tests). Two sources can cancel it: + +- **`session.abort().await?`** — cancels all currently in-flight handlers and also sends the `session.abort` RPC to stop the agentic loop. +- **`session.cancel_tool_call(tool_call_id)`** — cancels only the named handler without affecting others or the agentic loop. Returns `true` if an in-flight handler with that ID was found; `false` otherwise. + +Handlers that don't need cancellation can ignore the token. Handlers that do long-running work can cooperate: + +```rust,ignore +use github_copilot_sdk::tool::ToolHandler; +use github_copilot_sdk::types::ToolInvocation; +use github_copilot_sdk::{Error, ErrorKind, ToolResult}; + +struct LongRunningTool; + +impl ToolHandler for LongRunningTool { + async fn call(&self, inv: ToolInvocation) -> Result { + let Some(token) = inv.cancellation_token.clone() else { + return do_expensive_work().await; + }; + tokio::select! { + _ = token.cancelled() => { + Err(Error::with_message(ErrorKind::Cancelled, "tool call cancelled")) + } + result = do_expensive_work() => result, + } + } +} +``` + ### Permission Policies Set a permission policy directly on `SessionConfig` with the chainable builders. They install a synthesized `PermissionHandler` so only permission requests are intercepted; every other event flows through unchanged. diff --git a/rust/src/session.rs b/rust/src/session.rs index fed6705da..47a505916 100644 --- a/rust/src/session.rs +++ b/rust/src/session.rs @@ -158,6 +158,15 @@ pub struct Session { /// via [`Session::cancellation_token`] to bind their own work to /// the session lifetime. shutdown: CancellationToken, + /// Cancellation tokens for all currently in-flight tool handlers, keyed + /// by `tool_call_id`. + /// + /// Each dispatched [`ToolInvocation`](crate::types::ToolInvocation) + /// receives a child token registered here. [`Session::abort`] cancels + /// every token in the map; [`Session::cancel_tool_call`] cancels exactly + /// one. The event-loop task removes the entry once the handler future + /// resolves. Shared with the event loop via `Arc>`. + in_flight_tool_calls: Arc>>, /// Only populated while a `send_and_wait` call is in flight. /// /// Sync `parking_lot::Mutex` because the lock is never held across an @@ -500,12 +509,29 @@ impl Session { /// Abort the current agent turn. /// + /// Cancels the agentic loop and propagates cancellation to all in-flight + /// tool handlers via the [`CancellationToken`] on each + /// [`ToolInvocation`](crate::types::ToolInvocation). Handlers can check + /// [`is_cancelled()`](CancellationToken::is_cancelled) or `select!` on + /// [`cancelled()`](CancellationToken::cancelled) to stop early. + /// + /// To cancel a single handler without aborting the agentic loop, use + /// [`cancel_tool_call`](Self::cancel_tool_call) instead. + /// /// # Cancel safety /// /// **Cancel-safe.** Single `session.abort` RPC; the underlying /// [`Client::call`](crate::Client::call) is cancel-safe via the /// writer-actor. pub async fn abort(&self) -> Result<(), Error> { + // Cancel all in-flight handlers before sending the RPC so they can + // begin cleanup while the network round-trip is in flight. + { + let guard = self.in_flight_tool_calls.lock(); + for token in guard.values() { + token.cancel(); + } + } self.client .call( "session.abort", @@ -515,6 +541,25 @@ impl Session { Ok(()) } + /// Cancel a single in-flight tool handler by its `tool_call_id`. + /// + /// Fires only the cancellation token for the named handler and removes it + /// from the in-flight registry, leaving all other handlers and the + /// agentic loop untouched. Use [`abort`](Self::abort) to cancel the full + /// turn. + /// + /// Returns `true` if a handler with that ID was found and cancelled, + /// `false` if no matching in-flight handler exists. + pub fn cancel_tool_call(&self, tool_call_id: &str) -> bool { + let mut guard = self.in_flight_tool_calls.lock(); + if let Some(token) = guard.remove(tool_call_id) { + token.cancel(); + true + } else { + false + } + } + /// Switch to a different model. /// /// Pass `None` for `opts` if no extra configuration is needed. @@ -916,6 +961,7 @@ impl Client { let idle_waiter = Arc::new(ParkingLotMutex::new(None)); let open_canvases = Arc::new(parking_lot::RwLock::new(Vec::new())); let shutdown = CancellationToken::new(); + let in_flight_tool_calls = Arc::new(ParkingLotMutex::new(HashMap::new())); let (event_tx, _) = tokio::sync::broadcast::channel(512); // For cloud sessions (use_server_generated_id), defer session @@ -1017,6 +1063,7 @@ impl Client { open_canvases.clone(), event_tx.clone(), shutdown.clone(), + in_flight_tool_calls.clone(), ); tracing::debug!( elapsed_ms = setup_start.elapsed().as_millis(), @@ -1041,6 +1088,7 @@ impl Client { client: self.clone(), event_loop: ParkingLotMutex::new(Some(event_loop)), shutdown, + in_flight_tool_calls, idle_waiter, capabilities, open_canvases, @@ -1173,6 +1221,7 @@ impl Client { let idle_waiter = Arc::new(ParkingLotMutex::new(None)); let open_canvases = Arc::new(parking_lot::RwLock::new(Vec::new())); let shutdown = CancellationToken::new(); + let in_flight_tool_calls = Arc::new(ParkingLotMutex::new(HashMap::new())); let (event_tx, _) = tokio::sync::broadcast::channel(512); let event_loop = spawn_event_loop( session_id.clone(), @@ -1189,6 +1238,7 @@ impl Client { open_canvases.clone(), event_tx.clone(), shutdown.clone(), + in_flight_tool_calls.clone(), ); let mut registration = PendingSessionRegistration::new(self.clone(), session_id.clone(), shutdown.clone()); @@ -1284,6 +1334,7 @@ impl Client { client: self.clone(), event_loop: ParkingLotMutex::new(Some(event_loop)), shutdown, + in_flight_tool_calls, idle_waiter, capabilities, open_canvases, @@ -1397,6 +1448,7 @@ fn spawn_event_loop( open_canvases: Arc>>, event_tx: tokio::sync::broadcast::Sender, shutdown: CancellationToken, + in_flight_tool_calls: Arc>>, ) -> JoinHandle<()> { let crate::router::SessionChannels { mut notifications, @@ -1421,7 +1473,7 @@ fn spawn_event_loop( _ = shutdown.cancelled() => break, Some(notification) = notifications.recv() => { handle_notification( - &session_id, &client, &handlers, &command_handlers, notification, &idle_waiter, &capabilities, &open_canvases, &event_tx, + &session_id, &client, &handlers, &command_handlers, notification, &idle_waiter, &capabilities, &open_canvases, &event_tx, &shutdown, &in_flight_tool_calls, ).await; } Some(request) = requests.recv() => { @@ -1494,6 +1546,8 @@ async fn handle_notification( capabilities: &Arc>, open_canvases: &Arc>>, event_tx: &tokio::sync::broadcast::Sender, + shutdown: &CancellationToken, + in_flight_tool_calls: &Arc>>, ) { let dispatch_start = Instant::now(); let event = notification.event.clone(); @@ -1741,6 +1795,8 @@ async fn handle_notification( session_id = %sid, request_id = %request_id ); + let shutdown = shutdown.clone(); + let in_flight_tool_calls = in_flight_tool_calls.clone(); tokio::spawn( async move { // `tool_name.is_empty()` would have produced a `None` @@ -1770,6 +1826,10 @@ async fn handle_notification( } let tool_call_id = data.tool_call_id.clone(); let tool_name = data.tool_name.clone(); + let cancellation_token = shutdown.child_token(); + in_flight_tool_calls + .lock() + .insert(tool_call_id.clone(), cancellation_token.clone()); let invocation = ToolInvocation { session_id: sid.clone(), tool_call_id: data.tool_call_id, @@ -1777,6 +1837,7 @@ async fn handle_notification( arguments: data .arguments .unwrap_or(Value::Object(serde_json::Map::new())), + cancellation_token: Some(cancellation_token), traceparent: data.traceparent, tracestate: data.tracestate, }; @@ -1785,6 +1846,9 @@ async fn handle_notification( Ok(r) => r, Err(e) => tool_failure_result(e.to_string()), }; + // Remove the entry whether the handler succeeded, failed, + // or was cancelled — the token is no longer needed. + in_flight_tool_calls.lock().remove(&tool_call_id); tracing::debug!( elapsed_ms = handler_start.elapsed().as_millis(), session_id = %sid, @@ -2320,7 +2384,12 @@ fn inject_transform_sections_resume( #[cfg(test)] mod tests { + use std::collections::HashMap; + use std::sync::Arc; + + use parking_lot::Mutex as ParkingLotMutex; use serde_json::json; + use tokio_util::sync::CancellationToken; use super::notification_permission_payload; use crate::handler::PermissionResult; @@ -2349,4 +2418,103 @@ mod tests { Some(json!({ "kind": "user-not-available" })) ); } + + // Simulate the in-flight map mechanics used by Session without needing a + // real CLI connection. + fn make_map() -> Arc>> { + Arc::new(ParkingLotMutex::new(HashMap::new())) + } + + fn cancel_tool_call( + map: &Arc>>, + tool_call_id: &str, + ) -> bool { + let mut guard = map.lock(); + if let Some(token) = guard.remove(tool_call_id) { + token.cancel(); + true + } else { + false + } + } + + #[test] + fn cancel_tool_call_cancels_only_the_targeted_handler() { + let map = make_map(); + let token_a = CancellationToken::new(); + let token_b = CancellationToken::new(); + map.lock().insert("tc_a".to_string(), token_a.clone()); + map.lock().insert("tc_b".to_string(), token_b.clone()); + + // Cancelling A leaves B untouched. + assert!(cancel_tool_call(&map, "tc_a")); + assert!(token_a.is_cancelled()); + assert!(!token_b.is_cancelled()); + + // The entry is removed from the map. + assert_eq!(map.lock().len(), 1); + assert!(!map.lock().contains_key("tc_a")); + assert!(map.lock().contains_key("tc_b")); + } + + #[test] + fn cancel_tool_call_returns_false_for_unknown_id() { + let map = make_map(); + assert!(!cancel_tool_call(&map, "nonexistent")); + } + + #[test] + fn abort_cancels_all_in_flight_tokens() { + let map = make_map(); + let token_a = CancellationToken::new(); + let token_b = CancellationToken::new(); + map.lock().insert("tc_a".to_string(), token_a.clone()); + map.lock().insert("tc_b".to_string(), token_b.clone()); + + // Simulate abort(): cancel all tokens in the map. + { + let guard = map.lock(); + for token in guard.values() { + token.cancel(); + } + } + + assert!(token_a.is_cancelled()); + assert!(token_b.is_cancelled()); + } + + /// Verify the end-to-end contract: a handler that selects on its + /// `cancellation_token.cancelled()` unblocks when the map entry is + /// cancelled (as `abort()` would do). This exercises the same path the + /// real dispatch code uses — insert token into map, pass child to handler, + /// cancel map entry — without requiring a live CLI connection. + #[tokio::test] + async fn abort_unblocks_handler_awaiting_cancellation() { + let map = make_map(); + let shutdown = CancellationToken::new(); + + // Simulate dispatch: create a child token, register it, hand it to + // the "handler" task. + let token = shutdown.child_token(); + map.lock().insert("tc_x".to_string(), token.clone()); + + let handler = tokio::spawn(async move { + // Handler blocks until its token fires. + token.cancelled().await; + }); + + // Simulate abort(): cancel every token in the map. + { + let guard = map.lock(); + for t in guard.values() { + t.cancel(); + } + } + + // The handler task must complete promptly once cancelled. + tokio::time::timeout(std::time::Duration::from_secs(1), handler) + .await + .expect("handler should complete within timeout after abort") + .expect("handler task should not panic"); + } } diff --git a/rust/src/tool.rs b/rust/src/tool.rs index 189bc6f21..f49b0d0a6 100644 --- a/rust/src/tool.rs +++ b/rust/src/tool.rs @@ -566,8 +566,7 @@ mod tests { tool_call_id: "tc1".to_string(), tool_name: "echo".to_string(), arguments: serde_json::json!({"msg": "hello"}), - traceparent: None, - tracestate: None, + ..Default::default() }; let result = tool.call(inv).await.unwrap(); @@ -606,8 +605,7 @@ mod tests { tool_call_id: "tc1".to_string(), tool_name: "weather".to_string(), arguments: serde_json::json!({"city": "Seattle"}), - traceparent: None, - tracestate: None, + ..Default::default() }; match handler.call(inv).await.unwrap() { ToolResult::Text(s) => assert_eq!(s, "sunny in Seattle"), @@ -688,8 +686,7 @@ mod tests { tool_call_id: "tc1".to_string(), tool_name: "get_weather".to_string(), arguments: serde_json::json!({"city": "Seattle", "unit": "celsius"}), - traceparent: None, - tracestate: None, + ..Default::default() }; let result = tool.call(inv).await.unwrap(); @@ -707,8 +704,7 @@ mod tests { tool_call_id: "tc1".to_string(), tool_name: "get_weather".to_string(), arguments: serde_json::json!({"wrong_field": 42}), - traceparent: None, - tracestate: None, + ..Default::default() }; let err = tool.call(inv).await.unwrap_err(); @@ -728,8 +724,7 @@ mod tests { tool_call_id: "tc1".to_string(), tool_name: "get_weather".to_string(), arguments: serde_json::json!({"city": "Portland"}), - traceparent: None, - tracestate: None, + ..Default::default() }) .await .expect("ToolHandler::call should succeed for matching args"); @@ -739,4 +734,55 @@ mod tests { } } } + + #[tokio::test] + async fn tool_invocation_cancellation_token_fires_on_cancel() { + let token = tokio_util::sync::CancellationToken::new(); + let inv = ToolInvocation { + session_id: SessionId::from("s1"), + tool_call_id: "tc1".to_string(), + tool_name: "echo".to_string(), + arguments: serde_json::json!({}), + cancellation_token: Some(token.clone()), + ..Default::default() + }; + + let inv_token = inv.cancellation_token.expect("token was set"); + assert!(!inv_token.is_cancelled()); + token.cancel(); + assert!(inv_token.is_cancelled()); + } + + #[tokio::test] + async fn tool_invocation_child_token_cancelled_when_parent_fires() { + let parent = tokio_util::sync::CancellationToken::new(); + let child = parent.child_token(); + + let inv = ToolInvocation { + session_id: SessionId::from("s1"), + tool_call_id: "tc1".to_string(), + tool_name: "echo".to_string(), + arguments: serde_json::json!({}), + cancellation_token: Some(child), + ..Default::default() + }; + + let inv_token = inv.cancellation_token.expect("token was set"); + assert!(!inv_token.is_cancelled()); + parent.cancel(); + assert!(inv_token.is_cancelled()); + } + + #[test] + fn tool_invocation_cancellation_token_defaults_to_none() { + let inv = ToolInvocation { + session_id: SessionId::from("s1"), + tool_call_id: "tc1".to_string(), + tool_name: "echo".to_string(), + arguments: serde_json::json!({}), + ..Default::default() + }; + + assert!(inv.cancellation_token.is_none()); + } } diff --git a/rust/src/types.rs b/rust/src/types.rs index c0643ec66..f807f7765 100644 --- a/rust/src/types.rs +++ b/rust/src/types.rs @@ -11,6 +11,7 @@ use std::time::Duration; use serde::{Deserialize, Serialize}; use serde_json::Value; +use tokio_util::sync::CancellationToken; use crate::canvas::{CanvasDeclaration, CanvasHandler}; use crate::generated::api_types::OpenCanvasInstance; @@ -3934,6 +3935,21 @@ pub struct ToolInvocation { pub tool_name: String, /// Tool arguments as JSON. pub arguments: Value, + /// Cancellation signal for this tool invocation. + /// + /// Populated by the SDK when dispatching a handler. Fires when + /// [`Session::abort`](crate::Session::abort) or + /// [`Session::cancel_tool_call`](crate::Session::cancel_tool_call) is + /// called while this handler is in flight. Handlers can check + /// [`is_cancelled()`](CancellationToken::is_cancelled) or `select!` on + /// [`cancelled()`](CancellationToken::cancelled) to cooperatively stop + /// work early. Handlers that don't need cancellation can ignore this field. + /// + /// `None` for invocations constructed outside SDK dispatch (e.g. in + /// tests), so handlers should treat a missing token as "no cancellation + /// signal available." + #[serde(skip)] + pub cancellation_token: Option, /// W3C Trace Context `traceparent` header propagated from the CLI's /// `execute_tool` span. Pass through to OpenTelemetry-aware code so /// child spans created inside the handler are parented to the CLI