diff --git a/crates/rmcp/src/model.rs b/crates/rmcp/src/model.rs index a93f98d1..9c6ebbe4 100644 --- a/crates/rmcp/src/model.rs +++ b/crates/rmcp/src/model.rs @@ -155,6 +155,15 @@ impl ProtocolVersion { pub const V_2024_11_05: Self = Self(Cow::Borrowed("2024-11-05")); // Keep LATEST at 2025-03-26 until full 2025-06-18 compliance and automated testing are in place. pub const LATEST: Self = Self::V_2025_03_26; + + /// All protocol versions known to this SDK. + pub const KNOWN_VERSIONS: &[Self] = + &[Self::V_2024_11_05, Self::V_2025_03_26, Self::V_2025_06_18]; + + /// Returns the string representation of this protocol version. + pub fn as_str(&self) -> &str { + &self.0 + } } impl Serialize for ProtocolVersion { diff --git a/crates/rmcp/src/transport/common/auth/streamable_http_client.rs b/crates/rmcp/src/transport/common/auth/streamable_http_client.rs index 35e3ed5a..47f08f13 100644 --- a/crates/rmcp/src/transport/common/auth/streamable_http_client.rs +++ b/crates/rmcp/src/transport/common/auth/streamable_http_client.rs @@ -17,13 +17,14 @@ where uri: std::sync::Arc, session_id: std::sync::Arc, mut auth_token: Option, + custom_headers: HashMap, ) -> Result<(), crate::transport::streamable_http_client::StreamableHttpError> { if auth_token.is_none() { auth_token = Some(self.get_access_token().await?); } self.http_client - .delete_session(uri, session_id, auth_token) + .delete_session(uri, session_id, auth_token, custom_headers) .await } @@ -33,6 +34,7 @@ where session_id: std::sync::Arc, last_event_id: Option, mut auth_token: Option, + custom_headers: HashMap, ) -> Result< futures::stream::BoxStream<'static, Result>, crate::transport::streamable_http_client::StreamableHttpError, @@ -41,7 +43,7 @@ where auth_token = Some(self.get_access_token().await?); } self.http_client - .get_stream(uri, session_id, last_event_id, auth_token) + .get_stream(uri, session_id, last_event_id, auth_token, custom_headers) .await } diff --git a/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs b/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs index b4cdafd1..5a39b4a4 100644 --- a/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs +++ b/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs @@ -22,6 +22,43 @@ impl From for StreamableHttpError { } } +/// Reserved headers that must not be overridden by user-supplied custom headers. +/// `MCP-Protocol-Version` is in this list but is allowed through because the worker +/// injects it after initialization. +const RESERVED_HEADERS: &[&str] = &[ + "accept", + HEADER_SESSION_ID, + HEADER_MCP_PROTOCOL_VERSION, + HEADER_LAST_EVENT_ID, +]; + +/// Applies custom headers to a request builder, rejecting reserved headers +/// except `MCP-Protocol-Version` (which the worker injects after init). +fn apply_custom_headers( + mut builder: reqwest::RequestBuilder, + custom_headers: HashMap, +) -> Result> { + for (name, value) in custom_headers { + if RESERVED_HEADERS + .iter() + .any(|&r| name.as_str().eq_ignore_ascii_case(r)) + { + if name + .as_str() + .eq_ignore_ascii_case(HEADER_MCP_PROTOCOL_VERSION) + { + builder = builder.header(name, value); + continue; + } + return Err(StreamableHttpError::ReservedHeaderConflict( + name.to_string(), + )); + } + builder = builder.header(name, value); + } + Ok(builder) +} + impl StreamableHttpClient for reqwest::Client { type Error = reqwest::Error; @@ -31,6 +68,7 @@ impl StreamableHttpClient for reqwest::Client { session_id: Arc, last_event_id: Option, auth_token: Option, + custom_headers: HashMap, ) -> Result>, StreamableHttpError> { let mut request_builder = self .get(uri.as_ref()) @@ -42,6 +80,7 @@ impl StreamableHttpClient for reqwest::Client { if let Some(auth_header) = auth_token { request_builder = request_builder.bearer_auth(auth_header); } + request_builder = apply_custom_headers(request_builder, custom_headers)?; let response = request_builder.send().await?; if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED { return Err(StreamableHttpError::ServerDoesNotSupportSse); @@ -70,15 +109,15 @@ impl StreamableHttpClient for reqwest::Client { uri: Arc, session: Arc, auth_token: Option, + custom_headers: HashMap, ) -> Result<(), StreamableHttpError> { let mut request_builder = self.delete(uri.as_ref()); if let Some(auth_header) = auth_token { request_builder = request_builder.bearer_auth(auth_header); } - let response = request_builder - .header(HEADER_SESSION_ID, session.as_ref()) - .send() - .await?; + request_builder = request_builder.header(HEADER_SESSION_ID, session.as_ref()); + request_builder = apply_custom_headers(request_builder, custom_headers)?; + let response = request_builder.send().await?; // if method no allowed if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED { @@ -104,25 +143,7 @@ impl StreamableHttpClient for reqwest::Client { request = request.bearer_auth(auth_header); } - // Apply custom headers - let reserved_headers = [ - ACCEPT.as_str(), - HEADER_SESSION_ID, - HEADER_MCP_PROTOCOL_VERSION, - HEADER_LAST_EVENT_ID, - ]; - for (name, value) in custom_headers { - if reserved_headers - .iter() - .any(|&r| name.as_str().eq_ignore_ascii_case(r)) - { - return Err(StreamableHttpError::ReservedHeaderConflict( - name.to_string(), - )); - } - - request = request.header(name, value); - } + request = apply_custom_headers(request, custom_headers)?; if let Some(session_id) = session_id { request = request.header(HEADER_SESSION_ID, session_id.as_ref()); } diff --git a/crates/rmcp/src/transport/streamable_http_client.rs b/crates/rmcp/src/transport/streamable_http_client.rs index d45613bb..1c388e50 100644 --- a/crates/rmcp/src/transport/streamable_http_client.rs +++ b/crates/rmcp/src/transport/streamable_http_client.rs @@ -11,7 +11,7 @@ use tracing::debug; use super::common::client_side_sse::{ExponentialBackoff, SseRetryPolicy, SseStreamReconnect}; use crate::{ RoleClient, - model::{ClientJsonRpcMessage, ServerJsonRpcMessage}, + model::{ClientJsonRpcMessage, ServerJsonRpcMessage, ServerResult}, transport::{ common::client_side_sse::SseAutoReconnectStream, worker::{Worker, WorkerQuitReason, WorkerSendRequest, WorkerTransport}, @@ -184,6 +184,7 @@ pub trait StreamableHttpClient: Clone + Send + 'static { uri: Arc, session_id: Arc, auth_header: Option, + custom_headers: HashMap, ) -> impl Future>> + Send + '_; fn get_stream( &self, @@ -191,6 +192,7 @@ pub trait StreamableHttpClient: Clone + Send + 'static { session_id: Arc, last_event_id: Option, auth_header: Option, + custom_headers: HashMap, ) -> impl Future< Output = Result< BoxStream<'static, Result>, @@ -210,6 +212,7 @@ struct StreamableHttpClientReconnect { pub session_id: Arc, pub uri: Arc, pub auth_header: Option, + pub custom_headers: HashMap, } impl SseStreamReconnect for StreamableHttpClientReconnect { @@ -220,15 +223,25 @@ impl SseStreamReconnect for StreamableHttpClientReconne let uri = self.uri.clone(); let session_id = self.session_id.clone(); let auth_header = self.auth_header.clone(); + let custom_headers = self.custom_headers.clone(); let last_event_id = last_event_id.map(|s| s.to_owned()); Box::pin(async move { client - .get_stream(uri, session_id, last_event_id, auth_header) + .get_stream(uri, session_id, last_event_id, auth_header, custom_headers) .await }) } } +/// Info retained for cleaning up the session when the worker exits. +struct SessionCleanupInfo { + client: C, + uri: Arc, + session_id: Arc, + auth_header: Option, + protocol_headers: HashMap, +} + #[derive(Debug, Clone, Default)] pub struct StreamableHttpClientWorker { pub client: C, @@ -357,14 +370,29 @@ impl Worker for StreamableHttpClientWorker { } None }; + // Extract the negotiated protocol version from the init response + // and build a custom headers map that includes MCP-Protocol-Version + // for all subsequent HTTP requests (per MCP 2025-06-18 spec). + let protocol_headers = { + let mut headers = config.custom_headers.clone(); + if let ServerJsonRpcMessage::Response(response) = &message { + if let ServerResult::InitializeResult(init_result) = &response.result { + if let Ok(hv) = HeaderValue::from_str(init_result.protocol_version.as_str()) { + // HeaderName::from_static requires lowercase + headers.insert(HeaderName::from_static("mcp-protocol-version"), hv); + } + } + } + headers + }; + // Store session info for cleanup when run() exits (not spawned, so cleanup completes before close() returns) - let session_cleanup_info = session_id.as_ref().map(|sid| { - ( - self.client.clone(), - config.uri.clone(), - sid.clone(), - config.auth_header.clone(), - ) + let session_cleanup_info = session_id.as_ref().map(|sid| SessionCleanupInfo { + client: self.client.clone(), + uri: config.uri.clone(), + session_id: sid.clone(), + auth_header: config.auth_header.clone(), + protocol_headers: protocol_headers.clone(), }); context.send_to_handler(message).await?; @@ -376,7 +404,7 @@ impl Worker for StreamableHttpClientWorker { initialized_notification.message, session_id.clone(), config.auth_header.clone(), - config.custom_headers.clone(), + protocol_headers.clone(), ) .await .map_err(WorkerQuitReason::fatal_context( @@ -404,10 +432,17 @@ impl Worker for StreamableHttpClientWorker { let transport_task_ct = transport_task_ct.clone(); let config_uri = config.uri.clone(); let config_auth_header = config.auth_header.clone(); + let spawn_headers = protocol_headers.clone(); streams.spawn(async move { match client - .get_stream(uri.clone(), session_id.clone(), None, auth_header.clone()) + .get_stream( + uri.clone(), + session_id.clone(), + None, + auth_header.clone(), + spawn_headers.clone(), + ) .await { Ok(stream) => { @@ -418,6 +453,7 @@ impl Worker for StreamableHttpClientWorker { session_id: session_id.clone(), uri: config_uri, auth_header: config_auth_header, + custom_headers: spawn_headers, }, retry_config, ); @@ -482,7 +518,7 @@ impl Worker for StreamableHttpClientWorker { message, session_id.clone(), config.auth_header.clone(), - config.custom_headers.clone(), + protocol_headers.clone(), ) .await; let send_result = match response { @@ -504,6 +540,7 @@ impl Worker for StreamableHttpClientWorker { session_id: session_id.clone(), uri: config.uri.clone(), auth_header: config.auth_header.clone(), + custom_headers: protocol_headers.clone(), }, self.config.retry_config.clone(), ); @@ -550,32 +587,41 @@ impl Worker for StreamableHttpClientWorker { // Cleanup session before returning (ensures close() waits for session deletion) // Use a timeout to prevent indefinite hangs if the server is unresponsive - if let Some((client, url, session_id, auth_header)) = session_cleanup_info { + if let Some(cleanup) = session_cleanup_info { const SESSION_CLEANUP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5); + let cleanup_session_id = cleanup.session_id.clone(); match tokio::time::timeout( SESSION_CLEANUP_TIMEOUT, - client.delete_session(url, session_id.clone(), auth_header), + cleanup.client.delete_session( + cleanup.uri, + cleanup.session_id, + cleanup.auth_header, + cleanup.protocol_headers, + ), ) .await { Ok(Ok(_)) => { - tracing::info!(session_id = session_id.as_ref(), "delete session success") + tracing::info!( + session_id = cleanup_session_id.as_ref(), + "delete session success" + ) } Ok(Err(StreamableHttpError::ServerDoesNotSupportDeleteSession)) => { tracing::info!( - session_id = session_id.as_ref(), + session_id = cleanup_session_id.as_ref(), "server doesn't support delete session" ) } Ok(Err(e)) => { tracing::error!( - session_id = session_id.as_ref(), + session_id = cleanup_session_id.as_ref(), "fail to delete session: {e}" ); } Err(_elapsed) => { tracing::warn!( - session_id = session_id.as_ref(), + session_id = cleanup_session_id.as_ref(), "session cleanup timed out after {:?}", SESSION_CLEANUP_TIMEOUT ); @@ -652,6 +698,7 @@ impl Worker for StreamableHttpClientWorker { /// _uri: Arc, /// _session_id: Arc, /// _auth_header: Option, +/// _custom_headers: HashMap, /// ) -> Result<(), rmcp::transport::streamable_http_client::StreamableHttpError> { /// todo!() /// } @@ -662,6 +709,7 @@ impl Worker for StreamableHttpClientWorker { /// _session_id: Arc, /// _last_event_id: Option, /// _auth_header: Option, +/// _custom_headers: HashMap, /// ) -> Result>, rmcp::transport::streamable_http_client::StreamableHttpError> { /// todo!() /// } @@ -737,6 +785,7 @@ impl StreamableHttpClientTransport { /// _uri: Arc, /// _session_id: Arc, /// _auth_header: Option, + /// _custom_headers: HashMap, /// ) -> Result<(), rmcp::transport::streamable_http_client::StreamableHttpError> { /// todo!() /// } @@ -747,6 +796,7 @@ impl StreamableHttpClientTransport { /// _session_id: Arc, /// _last_event_id: Option, /// _auth_header: Option, + /// _custom_headers: HashMap, /// ) -> Result>, rmcp::transport::streamable_http_client::StreamableHttpError> { /// todo!() /// } diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs index 37d4a008..4dffb4ea 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -11,14 +11,15 @@ use tokio_util::sync::CancellationToken; use super::session::SessionManager; use crate::{ RoleServer, - model::{ClientJsonRpcMessage, ClientRequest, GetExtensions}, + model::{ClientJsonRpcMessage, ClientRequest, GetExtensions, ProtocolVersion}, serve_server, service::serve_directly, transport::{ OneshotTransport, TransportAdapterIdentity, common::{ http_header::{ - EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_SESSION_ID, JSON_MIME_TYPE, + EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_MCP_PROTOCOL_VERSION, + HEADER_SESSION_ID, JSON_MIME_TYPE, }, server_side_http::{ BoxResponse, ServerSseMessage, accepted_response, expect_json, @@ -55,6 +56,46 @@ impl Default for StreamableHttpServerConfig { } } +#[expect( + clippy::result_large_err, + reason = "BoxResponse is intentionally large; matches other handlers in this file" +)] +/// Validates the `MCP-Protocol-Version` header on incoming HTTP requests. +/// +/// Per the MCP 2025-06-18 spec: +/// - If the header is present but contains an unsupported version, return 400 Bad Request. +/// - If the header is absent, assume `2025-03-26` for backwards compatibility (no error). +fn validate_protocol_version_header(headers: &http::HeaderMap) -> Result<(), BoxResponse> { + if let Some(value) = headers.get(HEADER_MCP_PROTOCOL_VERSION) { + let version_str = value.to_str().map_err(|_| { + Response::builder() + .status(http::StatusCode::BAD_REQUEST) + .body( + Full::new(Bytes::from( + "Bad Request: Invalid MCP-Protocol-Version header encoding", + )) + .boxed(), + ) + .expect("valid response") + })?; + let is_known = ProtocolVersion::KNOWN_VERSIONS + .iter() + .any(|v| v.as_str() == version_str); + if !is_known { + return Err(Response::builder() + .status(http::StatusCode::BAD_REQUEST) + .body( + Full::new(Bytes::from(format!( + "Bad Request: Unsupported MCP-Protocol-Version: {version_str}" + ))) + .boxed(), + ) + .expect("valid response")); + } + } + Ok(()) +} + /// # Streamable Http Server /// /// ## Extract information from raw http request @@ -207,6 +248,8 @@ where .body(Full::new(Bytes::from("Unauthorized: Session not found")).boxed()) .expect("valid response")); } + // Validate MCP-Protocol-Version header (per 2025-06-18 spec) + validate_protocol_version_header(request.headers())?; // check if last event id is provided let last_event_id = request .headers() @@ -320,6 +363,9 @@ where .expect("valid response")); } + // Validate MCP-Protocol-Version header (per 2025-06-18 spec) + validate_protocol_version_header(&part.headers)?; + // inject request part to extensions match &mut message { ClientJsonRpcMessage::Request(req) => { @@ -455,6 +501,14 @@ where Ok(response) } } else { + // Stateless mode: validate MCP-Protocol-Version on non-init requests + let is_init = matches!( + &message, + ClientJsonRpcMessage::Request(req) if matches!(req.request, ClientRequest::InitializeRequest(_)) + ); + if !is_init { + validate_protocol_version_header(&part.headers)?; + } let service = self .get_service() .map_err(internal_error_response("get service"))?; @@ -511,6 +565,8 @@ where .body(Full::new(Bytes::from("Unauthorized: Session ID is required")).boxed()) .expect("valid response")); }; + // Validate MCP-Protocol-Version header (per 2025-06-18 spec) + validate_protocol_version_header(request.headers())?; // close session self.session_manager .close_session(&session_id) diff --git a/crates/rmcp/tests/test_custom_headers.rs b/crates/rmcp/tests/test_custom_headers.rs index c9307109..82537a80 100644 --- a/crates/rmcp/tests/test_custom_headers.rs +++ b/crates/rmcp/tests/test_custom_headers.rs @@ -190,22 +190,23 @@ async fn test_post_message_rejects_mcp_session_id() { } } -/// Unit test: post_message should reject reserved header "mcp-protocol-version" +/// Unit test: post_message should allow the mcp-protocol-version header through +/// (it is injected by the worker after initialization, not a user-settable custom header) #[tokio::test] #[cfg(feature = "transport-streamable-http-client-reqwest")] -async fn test_post_message_rejects_mcp_protocol_version() { +async fn test_post_message_allows_mcp_protocol_version() { use std::sync::Arc; use rmcp::{ model::{ClientJsonRpcMessage, ClientRequest, PingRequest, RequestId}, - transport::streamable_http_client::{StreamableHttpClient, StreamableHttpError}, + transport::streamable_http_client::StreamableHttpClient, }; let client = reqwest::Client::new(); let mut custom_headers = HashMap::new(); custom_headers.insert( HeaderName::from_static("mcp-protocol-version"), - HeaderValue::from_static("1.0"), + HeaderValue::from_static("2025-03-26"), ); let message = ClientJsonRpcMessage::request( @@ -223,19 +224,20 @@ async fn test_post_message_rejects_mcp_protocol_version() { ) .await; + // The header should be allowed through (not rejected as reserved). + // The error should be a connection error (no server at localhost:9999), + // not a ReservedHeaderConflict. + assert!(result.is_err(), "Should fail due to connection error"); assert!( - result.is_err(), - "Should reject 'mcp-protocol-version' header" + !matches!( + &result, + Err(rmcp::transport::streamable_http_client::StreamableHttpError::ReservedHeaderConflict( + _ + )) + ), + "MCP-Protocol-Version should not be rejected as reserved, got: {:?}", + result ); - match result { - Err(StreamableHttpError::ReservedHeaderConflict(header_name)) => { - assert_eq!( - header_name, "mcp-protocol-version", - "Error should indicate 'mcp-protocol-version' header" - ); - } - other => panic!("Expected ReservedHeaderConflict error, got: {:?}", other), - } } /// Unit test: post_message should reject reserved header "last-event-id" @@ -529,3 +531,344 @@ async fn test_mcp_custom_headers_sent_to_server() -> anyhow::Result<()> { Ok(()) } + +/// Integration test: Verify that MCP-Protocol-Version header is sent on post-init requests +#[tokio::test] +#[cfg(all( + feature = "transport-streamable-http-client", + feature = "transport-streamable-http-client-reqwest" +))] +async fn test_mcp_protocol_version_header_sent_after_init() -> anyhow::Result<()> { + use std::{net::SocketAddr, sync::Arc}; + + use axum::{ + Router, body::Bytes, extract::State, http::StatusCode, response::IntoResponse, + routing::post, + }; + use rmcp::{ + ServiceExt, + transport::{ + StreamableHttpClientTransport, + streamable_http_client::StreamableHttpClientTransportConfig, + }, + }; + use serde_json::json; + use tokio::sync::Mutex; + + type CapturedRequests = Vec<(String, Option)>; + + #[derive(Clone)] + struct ServerState { + /// Captures the MCP-Protocol-Version header value for each request method + protocol_version_by_method: Arc>, + initialized_called: Arc, + } + + async fn mcp_handler( + State(state): State, + headers: http::HeaderMap, + body: Bytes, + ) -> impl IntoResponse { + let protocol_version = headers + .get("mcp-protocol-version") + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()); + + if let Ok(json_body) = serde_json::from_slice::(&body) { + let method = json_body + .get("method") + .and_then(|m| m.as_str()) + .unwrap_or("unknown") + .to_string(); + + state + .protocol_version_by_method + .lock() + .await + .push((method.clone(), protocol_version)); + + if method == "initialize" { + let response = json!({ + "jsonrpc": "2.0", + "id": json_body.get("id"), + "result": { + "protocolVersion": "2025-03-26", + "capabilities": {}, + "serverInfo": { + "name": "test-server", + "version": "1.0.0" + } + } + }); + return ( + StatusCode::OK, + [ + (http::header::CONTENT_TYPE, "application/json"), + ( + http::HeaderName::from_static("mcp-session-id"), + "test-session-456", + ), + ], + response.to_string(), + ); + } else if method == "notifications/initialized" { + state.initialized_called.notify_one(); + return ( + StatusCode::ACCEPTED, + [ + (http::header::CONTENT_TYPE, "application/json"), + ( + http::HeaderName::from_static("mcp-session-id"), + "test-session-456", + ), + ], + String::new(), + ); + } + } + + let response = json!({ + "jsonrpc": "2.0", + "id": 1, + "result": {} + }); + ( + StatusCode::OK, + [ + (http::header::CONTENT_TYPE, "application/json"), + ( + http::HeaderName::from_static("mcp-session-id"), + "test-session-456", + ), + ], + response.to_string(), + ) + } + + let state = ServerState { + protocol_version_by_method: Arc::new(Mutex::new(Vec::new())), + initialized_called: Arc::new(tokio::sync::Notify::new()), + }; + + let app = Router::new() + .route("/mcp", post(mcp_handler)) + .with_state(state.clone()); + + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = tokio::net::TcpListener::bind(addr).await?; + let port = listener.local_addr()?.port(); + + let server_handle = tokio::spawn(async move { axum::serve(listener, app).await }); + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + let config = + StreamableHttpClientTransportConfig::with_uri(format!("http://127.0.0.1:{}/mcp", port)); + + let transport = StreamableHttpClientTransport::from_config(config); + let client = ().serve(transport).await.expect("Failed to start client"); + + tokio::time::timeout( + std::time::Duration::from_secs(5), + state.initialized_called.notified(), + ) + .await + .expect("Initialized notification should be received"); + + // Give time for the initialized notification to be fully processed + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + let captured = state.protocol_version_by_method.lock().await; + + // The initialize request should NOT have MCP-Protocol-Version + // (the version isn't known yet) + let init_entry = captured + .iter() + .find(|(m, _)| m == "initialize") + .expect("Should have captured initialize request"); + assert_eq!( + init_entry.1, None, + "Initialize request should not have MCP-Protocol-Version header" + ); + + // The initialized notification should HAVE MCP-Protocol-Version + let initialized_entry = captured + .iter() + .find(|(m, _)| m == "notifications/initialized") + .expect("Should have captured initialized notification"); + assert_eq!( + initialized_entry.1, + Some("2025-03-26".to_string()), + "Initialized notification should include MCP-Protocol-Version: 2025-03-26" + ); + + drop(client); + server_handle.abort(); + + Ok(()) +} + +/// Integration test: Verify server rejects unsupported MCP-Protocol-Version with 400 +#[tokio::test] +#[cfg(all(feature = "transport-streamable-http-server", feature = "server",))] +async fn test_server_rejects_unsupported_protocol_version() { + use std::sync::Arc; + + use bytes::Bytes; + use http::{Method, Request, header::CONTENT_TYPE}; + use http_body_util::Full; + use rmcp::{ + handler::server::ServerHandler, + model::{ServerCapabilities, ServerInfo}, + transport::streamable_http_server::{ + StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager, + }, + }; + use serde_json::json; + + #[derive(Clone)] + struct TestHandler; + + impl ServerHandler for TestHandler { + fn get_info(&self) -> ServerInfo { + ServerInfo { + capabilities: ServerCapabilities::builder().build(), + ..Default::default() + } + } + } + + let session_manager = Arc::new(LocalSessionManager::default()); + let service = StreamableHttpService::new( + || Ok(TestHandler), + session_manager, + StreamableHttpServerConfig::default(), + ); + + // First, send an initialize request to create a session + let init_body = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2025-03-26", + "capabilities": {}, + "clientInfo": { + "name": "test-client", + "version": "1.0.0" + } + } + }); + + let init_request = Request::builder() + .method(Method::POST) + .header("Accept", "application/json, text/event-stream") + .header(CONTENT_TYPE, "application/json") + .body(Full::new(Bytes::from(init_body.to_string()))) + .unwrap(); + + let response = service.handle(init_request).await; + assert_eq!(response.status(), http::StatusCode::OK); + + // Extract session id from response + let session_id = response + .headers() + .get("mcp-session-id") + .expect("Should have session id") + .to_str() + .unwrap() + .to_string(); + + // Send initialized notification to complete handshake + let initialized_body = json!({ + "jsonrpc": "2.0", + "method": "notifications/initialized" + }); + let initialized_request = Request::builder() + .method(Method::POST) + .header("Accept", "application/json, text/event-stream") + .header(CONTENT_TYPE, "application/json") + .header("mcp-session-id", &session_id) + .header("mcp-protocol-version", "2025-03-26") + .body(Full::new(Bytes::from(initialized_body.to_string()))) + .unwrap(); + + let response = service.handle(initialized_request).await; + assert_eq!(response.status(), http::StatusCode::ACCEPTED); + + // Test 1: Valid protocol version should succeed + let valid_body = json!({ + "jsonrpc": "2.0", + "method": "notifications/initialized" + }); + let valid_request = Request::builder() + .method(Method::POST) + .header("Accept", "application/json, text/event-stream") + .header(CONTENT_TYPE, "application/json") + .header("mcp-session-id", &session_id) + .header("mcp-protocol-version", "2025-03-26") + .body(Full::new(Bytes::from(valid_body.to_string()))) + .unwrap(); + + let response = service.handle(valid_request).await; + assert_eq!( + response.status(), + http::StatusCode::ACCEPTED, + "Valid MCP-Protocol-Version should be accepted" + ); + + // Test 2: Unsupported protocol version should return 400 + let invalid_body = json!({ + "jsonrpc": "2.0", + "method": "notifications/initialized" + }); + let invalid_request = Request::builder() + .method(Method::POST) + .header("Accept", "application/json, text/event-stream") + .header(CONTENT_TYPE, "application/json") + .header("mcp-session-id", &session_id) + .header("mcp-protocol-version", "9999-01-01") + .body(Full::new(Bytes::from(invalid_body.to_string()))) + .unwrap(); + + let response = service.handle(invalid_request).await; + assert_eq!( + response.status(), + http::StatusCode::BAD_REQUEST, + "Unsupported MCP-Protocol-Version should return 400" + ); + + // Test 3: Missing protocol version should succeed (backwards compat) + let no_version_body = json!({ + "jsonrpc": "2.0", + "method": "notifications/initialized" + }); + let no_version_request = Request::builder() + .method(Method::POST) + .header("Accept", "application/json, text/event-stream") + .header(CONTENT_TYPE, "application/json") + .header("mcp-session-id", &session_id) + .body(Full::new(Bytes::from(no_version_body.to_string()))) + .unwrap(); + + let response = service.handle(no_version_request).await; + assert_eq!( + response.status(), + http::StatusCode::ACCEPTED, + "Missing MCP-Protocol-Version should be accepted (backwards compat)" + ); +} + +/// Unit test: ProtocolVersion::as_str and KNOWN_VERSIONS +#[test] +fn test_protocol_version_utilities() { + use rmcp::model::ProtocolVersion; + + assert_eq!(ProtocolVersion::V_2025_06_18.as_str(), "2025-06-18"); + assert_eq!(ProtocolVersion::V_2025_03_26.as_str(), "2025-03-26"); + assert_eq!(ProtocolVersion::V_2024_11_05.as_str(), "2024-11-05"); + + assert_eq!(ProtocolVersion::KNOWN_VERSIONS.len(), 3); + assert!(ProtocolVersion::KNOWN_VERSIONS.contains(&ProtocolVersion::V_2024_11_05)); + assert!(ProtocolVersion::KNOWN_VERSIONS.contains(&ProtocolVersion::V_2025_03_26)); + assert!(ProtocolVersion::KNOWN_VERSIONS.contains(&ProtocolVersion::V_2025_06_18)); +}