Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions crates/rmcp/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,15 @@ impl ProtocolVersion {
pub const V_2024_11_05: Self = Self(Cow::Borrowed("2024-11-05"));
// Keep LATEST at 2025-03-26 until full 2025-06-18 compliance and automated testing are in place.
pub const LATEST: Self = Self::V_2025_03_26;

/// All protocol versions known to this SDK.
pub const KNOWN_VERSIONS: &[Self] =
&[Self::V_2024_11_05, Self::V_2025_03_26, Self::V_2025_06_18];

/// Returns the string representation of this protocol version.
pub fn as_str(&self) -> &str {
&self.0
}
}

impl Serialize for ProtocolVersion {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@ where
uri: std::sync::Arc<str>,
session_id: std::sync::Arc<str>,
mut auth_token: Option<String>,
custom_headers: HashMap<HeaderName, HeaderValue>,
) -> Result<(), crate::transport::streamable_http_client::StreamableHttpError<Self::Error>>
{
if auth_token.is_none() {
auth_token = Some(self.get_access_token().await?);
}
self.http_client
.delete_session(uri, session_id, auth_token)
.delete_session(uri, session_id, auth_token, custom_headers)
.await
}

Expand All @@ -33,6 +34,7 @@ where
session_id: std::sync::Arc<str>,
last_event_id: Option<String>,
mut auth_token: Option<String>,
custom_headers: HashMap<HeaderName, HeaderValue>,
) -> Result<
futures::stream::BoxStream<'static, Result<sse_stream::Sse, sse_stream::Error>>,
crate::transport::streamable_http_client::StreamableHttpError<Self::Error>,
Expand All @@ -41,7 +43,7 @@ where
auth_token = Some(self.get_access_token().await?);
}
self.http_client
.get_stream(uri, session_id, last_event_id, auth_token)
.get_stream(uri, session_id, last_event_id, auth_token, custom_headers)
.await
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,43 @@ impl From<reqwest::Error> for StreamableHttpError<reqwest::Error> {
}
}

/// Reserved headers that must not be overridden by user-supplied custom headers.
/// `MCP-Protocol-Version` is in this list but is allowed through because the worker
/// injects it after initialization.
const RESERVED_HEADERS: &[&str] = &[
"accept",
HEADER_SESSION_ID,
HEADER_MCP_PROTOCOL_VERSION,
HEADER_LAST_EVENT_ID,
];

/// Applies custom headers to a request builder, rejecting reserved headers
/// except `MCP-Protocol-Version` (which the worker injects after init).
fn apply_custom_headers(
mut builder: reqwest::RequestBuilder,
custom_headers: HashMap<HeaderName, HeaderValue>,
) -> Result<reqwest::RequestBuilder, StreamableHttpError<reqwest::Error>> {
for (name, value) in custom_headers {
if RESERVED_HEADERS
.iter()
.any(|&r| name.as_str().eq_ignore_ascii_case(r))
{
if name
.as_str()
.eq_ignore_ascii_case(HEADER_MCP_PROTOCOL_VERSION)
{
builder = builder.header(name, value);
continue;
}
return Err(StreamableHttpError::ReservedHeaderConflict(
name.to_string(),
));
}
builder = builder.header(name, value);
}
Ok(builder)
}

impl StreamableHttpClient for reqwest::Client {
type Error = reqwest::Error;

Expand All @@ -31,6 +68,7 @@ impl StreamableHttpClient for reqwest::Client {
session_id: Arc<str>,
last_event_id: Option<String>,
auth_token: Option<String>,
custom_headers: HashMap<HeaderName, HeaderValue>,
) -> Result<BoxStream<'static, Result<Sse, SseError>>, StreamableHttpError<Self::Error>> {
let mut request_builder = self
.get(uri.as_ref())
Expand All @@ -42,6 +80,7 @@ impl StreamableHttpClient for reqwest::Client {
if let Some(auth_header) = auth_token {
request_builder = request_builder.bearer_auth(auth_header);
}
request_builder = apply_custom_headers(request_builder, custom_headers)?;
let response = request_builder.send().await?;
if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED {
return Err(StreamableHttpError::ServerDoesNotSupportSse);
Expand Down Expand Up @@ -70,15 +109,15 @@ impl StreamableHttpClient for reqwest::Client {
uri: Arc<str>,
session: Arc<str>,
auth_token: Option<String>,
custom_headers: HashMap<HeaderName, HeaderValue>,
) -> Result<(), StreamableHttpError<Self::Error>> {
let mut request_builder = self.delete(uri.as_ref());
if let Some(auth_header) = auth_token {
request_builder = request_builder.bearer_auth(auth_header);
}
let response = request_builder
.header(HEADER_SESSION_ID, session.as_ref())
.send()
.await?;
request_builder = request_builder.header(HEADER_SESSION_ID, session.as_ref());
request_builder = apply_custom_headers(request_builder, custom_headers)?;
let response = request_builder.send().await?;

// if method no allowed
if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED {
Expand All @@ -104,25 +143,7 @@ impl StreamableHttpClient for reqwest::Client {
request = request.bearer_auth(auth_header);
}

// Apply custom headers
let reserved_headers = [
ACCEPT.as_str(),
HEADER_SESSION_ID,
HEADER_MCP_PROTOCOL_VERSION,
HEADER_LAST_EVENT_ID,
];
for (name, value) in custom_headers {
if reserved_headers
.iter()
.any(|&r| name.as_str().eq_ignore_ascii_case(r))
{
return Err(StreamableHttpError::ReservedHeaderConflict(
name.to_string(),
));
}

request = request.header(name, value);
}
request = apply_custom_headers(request, custom_headers)?;
if let Some(session_id) = session_id {
request = request.header(HEADER_SESSION_ID, session_id.as_ref());
}
Expand Down
86 changes: 68 additions & 18 deletions crates/rmcp/src/transport/streamable_http_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use tracing::debug;
use super::common::client_side_sse::{ExponentialBackoff, SseRetryPolicy, SseStreamReconnect};
use crate::{
RoleClient,
model::{ClientJsonRpcMessage, ServerJsonRpcMessage},
model::{ClientJsonRpcMessage, ServerJsonRpcMessage, ServerResult},
transport::{
common::client_side_sse::SseAutoReconnectStream,
worker::{Worker, WorkerQuitReason, WorkerSendRequest, WorkerTransport},
Expand Down Expand Up @@ -184,13 +184,15 @@ pub trait StreamableHttpClient: Clone + Send + 'static {
uri: Arc<str>,
session_id: Arc<str>,
auth_header: Option<String>,
custom_headers: HashMap<HeaderName, HeaderValue>,
) -> impl Future<Output = Result<(), StreamableHttpError<Self::Error>>> + Send + '_;
fn get_stream(
&self,
uri: Arc<str>,
session_id: Arc<str>,
last_event_id: Option<String>,
auth_header: Option<String>,
custom_headers: HashMap<HeaderName, HeaderValue>,
) -> impl Future<
Output = Result<
BoxStream<'static, Result<Sse, SseError>>,
Expand All @@ -210,6 +212,7 @@ struct StreamableHttpClientReconnect<C> {
pub session_id: Arc<str>,
pub uri: Arc<str>,
pub auth_header: Option<String>,
pub custom_headers: HashMap<HeaderName, HeaderValue>,
}

impl<C: StreamableHttpClient> SseStreamReconnect for StreamableHttpClientReconnect<C> {
Expand All @@ -220,15 +223,25 @@ impl<C: StreamableHttpClient> SseStreamReconnect for StreamableHttpClientReconne
let uri = self.uri.clone();
let session_id = self.session_id.clone();
let auth_header = self.auth_header.clone();
let custom_headers = self.custom_headers.clone();
let last_event_id = last_event_id.map(|s| s.to_owned());
Box::pin(async move {
client
.get_stream(uri, session_id, last_event_id, auth_header)
.get_stream(uri, session_id, last_event_id, auth_header, custom_headers)
.await
})
}
}

/// Info retained for cleaning up the session when the worker exits.
struct SessionCleanupInfo<C> {
client: C,
uri: Arc<str>,
session_id: Arc<str>,
auth_header: Option<String>,
protocol_headers: HashMap<HeaderName, HeaderValue>,
}

#[derive(Debug, Clone, Default)]
pub struct StreamableHttpClientWorker<C: StreamableHttpClient> {
pub client: C,
Expand Down Expand Up @@ -357,14 +370,29 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
}
None
};
// Extract the negotiated protocol version from the init response
// and build a custom headers map that includes MCP-Protocol-Version
// for all subsequent HTTP requests (per MCP 2025-06-18 spec).
let protocol_headers = {
let mut headers = config.custom_headers.clone();
if let ServerJsonRpcMessage::Response(response) = &message {
if let ServerResult::InitializeResult(init_result) = &response.result {
if let Ok(hv) = HeaderValue::from_str(init_result.protocol_version.as_str()) {
// HeaderName::from_static requires lowercase
headers.insert(HeaderName::from_static("mcp-protocol-version"), hv);
}
}
}
headers
};

// Store session info for cleanup when run() exits (not spawned, so cleanup completes before close() returns)
let session_cleanup_info = session_id.as_ref().map(|sid| {
(
self.client.clone(),
config.uri.clone(),
sid.clone(),
config.auth_header.clone(),
)
let session_cleanup_info = session_id.as_ref().map(|sid| SessionCleanupInfo {
client: self.client.clone(),
uri: config.uri.clone(),
session_id: sid.clone(),
auth_header: config.auth_header.clone(),
protocol_headers: protocol_headers.clone(),
});

context.send_to_handler(message).await?;
Expand All @@ -376,7 +404,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
initialized_notification.message,
session_id.clone(),
config.auth_header.clone(),
config.custom_headers.clone(),
protocol_headers.clone(),
)
.await
.map_err(WorkerQuitReason::fatal_context(
Expand Down Expand Up @@ -404,10 +432,17 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
let transport_task_ct = transport_task_ct.clone();
let config_uri = config.uri.clone();
let config_auth_header = config.auth_header.clone();
let spawn_headers = protocol_headers.clone();

streams.spawn(async move {
match client
.get_stream(uri.clone(), session_id.clone(), None, auth_header.clone())
.get_stream(
uri.clone(),
session_id.clone(),
None,
auth_header.clone(),
spawn_headers.clone(),
)
.await
{
Ok(stream) => {
Expand All @@ -418,6 +453,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
session_id: session_id.clone(),
uri: config_uri,
auth_header: config_auth_header,
custom_headers: spawn_headers,
},
retry_config,
);
Expand Down Expand Up @@ -482,7 +518,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
message,
session_id.clone(),
config.auth_header.clone(),
config.custom_headers.clone(),
protocol_headers.clone(),
)
.await;
let send_result = match response {
Expand All @@ -504,6 +540,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
session_id: session_id.clone(),
uri: config.uri.clone(),
auth_header: config.auth_header.clone(),
custom_headers: protocol_headers.clone(),
},
self.config.retry_config.clone(),
);
Expand Down Expand Up @@ -550,32 +587,41 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {

// Cleanup session before returning (ensures close() waits for session deletion)
// Use a timeout to prevent indefinite hangs if the server is unresponsive
if let Some((client, url, session_id, auth_header)) = session_cleanup_info {
if let Some(cleanup) = session_cleanup_info {
const SESSION_CLEANUP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);
let cleanup_session_id = cleanup.session_id.clone();
match tokio::time::timeout(
SESSION_CLEANUP_TIMEOUT,
client.delete_session(url, session_id.clone(), auth_header),
cleanup.client.delete_session(
cleanup.uri,
cleanup.session_id,
cleanup.auth_header,
cleanup.protocol_headers,
),
)
.await
{
Ok(Ok(_)) => {
tracing::info!(session_id = session_id.as_ref(), "delete session success")
tracing::info!(
session_id = cleanup_session_id.as_ref(),
"delete session success"
)
}
Ok(Err(StreamableHttpError::ServerDoesNotSupportDeleteSession)) => {
tracing::info!(
session_id = session_id.as_ref(),
session_id = cleanup_session_id.as_ref(),
"server doesn't support delete session"
)
}
Ok(Err(e)) => {
tracing::error!(
session_id = session_id.as_ref(),
session_id = cleanup_session_id.as_ref(),
"fail to delete session: {e}"
);
}
Err(_elapsed) => {
tracing::warn!(
session_id = session_id.as_ref(),
session_id = cleanup_session_id.as_ref(),
"session cleanup timed out after {:?}",
SESSION_CLEANUP_TIMEOUT
);
Expand Down Expand Up @@ -652,6 +698,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
/// _uri: Arc<str>,
/// _session_id: Arc<str>,
/// _auth_header: Option<String>,
/// _custom_headers: HashMap<HeaderName, HeaderValue>,
/// ) -> Result<(), rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
/// todo!()
/// }
Expand All @@ -662,6 +709,7 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
/// _session_id: Arc<str>,
/// _last_event_id: Option<String>,
/// _auth_header: Option<String>,
/// _custom_headers: HashMap<HeaderName, HeaderValue>,
/// ) -> Result<BoxStream<'static, Result<Sse, SseError>>, rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
/// todo!()
/// }
Expand Down Expand Up @@ -737,6 +785,7 @@ impl<C: StreamableHttpClient> StreamableHttpClientTransport<C> {
/// _uri: Arc<str>,
/// _session_id: Arc<str>,
/// _auth_header: Option<String>,
/// _custom_headers: HashMap<HeaderName, HeaderValue>,
/// ) -> Result<(), rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
/// todo!()
/// }
Expand All @@ -747,6 +796,7 @@ impl<C: StreamableHttpClient> StreamableHttpClientTransport<C> {
/// _session_id: Arc<str>,
/// _last_event_id: Option<String>,
/// _auth_header: Option<String>,
/// _custom_headers: HashMap<HeaderName, HeaderValue>,
/// ) -> Result<BoxStream<'static, Result<Sse, SseError>>, rmcp::transport::streamable_http_client::StreamableHttpError<Self::Error>> {
/// todo!()
/// }
Expand Down
Loading