Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,11 @@ fun Server.registerConformanceTools() {
CreateMessageRequest(
CreateMessageRequestParams(
maxTokens = 10000,
messages = listOf(SamplingMessage(Role.User, TextContent(prompt))),
messages = listOf(SamplingMessage(Role.User, listOf(TextContent(prompt)))),
),
),
)
CallToolResult(listOf(TextContent(result.content.toString())))
CallToolResult(listOf(TextContent(result.content.joinToString("\n") { it.toString() })))
Copy link

Copilot AI Apr 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

result.content.joinToString("\n") { it.toString() } will produce Kotlin data-class toString() representations (e.g. TextContent(text=..., type=...)) rather than the model’s text output. If this conformance tool is intended to return the sampled text, consider extracting TextContent.text (and deciding how to handle non-text blocks/tool_use blocks) instead of relying on toString().

Suggested change
CallToolResult(listOf(TextContent(result.content.joinToString("\n") { it.toString() })))
val sampledText = result.content.joinToString("\n") {
when (it) {
is TextContent -> it.text
is ImageContent -> "[image content]"
is AudioContent -> "[audio content]"
is EmbeddedResource -> "[embedded resource]"
}
}
CallToolResult(listOf(TextContent(sampledText)))

Copilot uses AI. Check for mistakes.
}

// 9. Elicitation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class ClientTest {
),
options = ClientOptions(
capabilities = ClientCapabilities(
sampling = EmptyJsonObject,
sampling = ClientCapabilities.sampling,
),
),
)
Expand Down Expand Up @@ -152,7 +152,7 @@ class ClientTest {
),
options = ClientOptions(
capabilities = ClientCapabilities(
sampling = EmptyJsonObject,
sampling = ClientCapabilities.sampling,
),
),
)
Expand Down Expand Up @@ -339,7 +339,7 @@ class ClientTest {
val client = Client(
clientInfo = Implementation(name = "test client", version = "1.0"),
options = ClientOptions(
capabilities = ClientCapabilities(sampling = EmptyJsonObject),
capabilities = ClientCapabilities(sampling = ClientCapabilities.sampling),
),
)

Expand Down Expand Up @@ -633,18 +633,16 @@ class ClientTest {
),
options = ClientOptions(
capabilities = ClientCapabilities(
sampling = EmptyJsonObject,
sampling = ClientCapabilities.sampling,
),
),
)

client.setRequestHandler<CreateMessageRequest>(Method.Defined.SamplingCreateMessage) { _, _ ->
CreateMessageResult(
model = "test-model",
role = Role.Assistant,
content = TextContent(
text = "Test response",
),
content = TextContent(text = "Test response"),
model = "test-model",
)
}

Expand All @@ -670,7 +668,7 @@ class ClientTest {
val client = Client(
clientInfo = Implementation(name = "test client", version = "1.0"),
options = ClientOptions(
capabilities = ClientCapabilities(sampling = EmptyJsonObject),
capabilities = ClientCapabilities(sampling = ClientCapabilities.sampling),
),
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
package io.modelcontextprotocol.kotlin.sdk.server

import io.modelcontextprotocol.kotlin.sdk.client.Client
import io.modelcontextprotocol.kotlin.sdk.client.ClientOptions
import io.modelcontextprotocol.kotlin.sdk.shared.InMemoryTransport
import io.modelcontextprotocol.kotlin.sdk.types.ClientCapabilities
import io.modelcontextprotocol.kotlin.sdk.types.CreateMessageRequest
import io.modelcontextprotocol.kotlin.sdk.types.CreateMessageRequestParams
import io.modelcontextprotocol.kotlin.sdk.types.CreateMessageResult
import io.modelcontextprotocol.kotlin.sdk.types.EmptyJsonObject
import io.modelcontextprotocol.kotlin.sdk.types.Implementation
import io.modelcontextprotocol.kotlin.sdk.types.IncludeContext
import io.modelcontextprotocol.kotlin.sdk.types.Method
import io.modelcontextprotocol.kotlin.sdk.types.Role
import io.modelcontextprotocol.kotlin.sdk.types.SamplingMessage
import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities
import io.modelcontextprotocol.kotlin.sdk.types.StopReason
import io.modelcontextprotocol.kotlin.sdk.types.TextContent
import io.modelcontextprotocol.kotlin.sdk.types.Tool
import io.modelcontextprotocol.kotlin.sdk.types.ToolChoice
import io.modelcontextprotocol.kotlin.sdk.types.ToolResultContent
import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema
import io.modelcontextprotocol.kotlin.sdk.types.ToolUseContent
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.serialization.json.JsonPrimitive
import kotlinx.serialization.json.buildJsonObject
import org.junit.jupiter.api.assertDoesNotThrow
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith

class SamplingTest {

private val dummyTool = Tool(
name = "t",
inputSchema = ToolSchema(properties = buildJsonObject { }, required = emptyList()),
)

private val weatherTool = Tool(
name = "get_weather",
description = "Return the current temperature in Celsius.",
inputSchema = ToolSchema(
properties = buildJsonObject {
put("location", buildJsonObject { put("type", JsonPrimitive("string")) })
},
required = listOf("location"),
),
)

private val minimalMessages = listOf(SamplingMessage(Role.User, TextContent("hi")))

/**
* Builds a connected [Server]+[Client] pair using [InMemoryTransport].
*
* @param clientCapabilities the capabilities the client advertises during initialize
* @param samplingHandler the handler the client uses to respond to sampling requests
* @return Pair of (server, sessionId) ready for [Server.createMessage] calls.
*/
private fun buildPair(
clientCapabilities: ClientCapabilities = ClientCapabilities(
sampling = ClientCapabilities.Sampling(),
),
samplingHandler: (CreateMessageRequest) -> CreateMessageResult = { _ ->
CreateMessageResult(role = Role.Assistant, content = TextContent("ok"), model = "m")
},
): Pair<Server, String> {
val server = Server(
serverInfo = Implementation(name = "srv", version = "1.0"),
options = ServerOptions(capabilities = ServerCapabilities()),
)

val client = Client(
clientInfo = Implementation(name = "cli", version = "1.0"),
options = ClientOptions(capabilities = clientCapabilities),
)

client.setRequestHandler<CreateMessageRequest>(Method.Defined.SamplingCreateMessage) { req, _ ->
samplingHandler(req)
}

val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair()

var sessionId: String? = null
runBlocking {
val sessionDeferred = CompletableDeferred<String>()
launch { client.connect(clientTransport) }
launch {
val session = server.createSession(serverTransport)
sessionDeferred.complete(session.sessionId)
}
sessionId = sessionDeferred.await()
}

return Pair(server, checkNotNull(sessionId))
}

// ============================================================================
// Server.createMessage — capability enforcement (SEP-1577)
// ============================================================================

@Test
fun `tools field rejected when client has no sampling tools capability`() {
val (server, sessionId) = buildPair()
assertFailsWith<IllegalArgumentException> {
runBlocking {
server.createMessage(
sessionId = sessionId,
params = CreateMessageRequest(
params = CreateMessageRequestParams(
maxTokens = 100,
messages = minimalMessages,
tools = listOf(dummyTool),
),
),
)
}
}
}

@Test
fun `toolChoice field rejected when client has no sampling tools capability`() {
val (server, sessionId) = buildPair()
assertFailsWith<IllegalArgumentException> {
runBlocking {
server.createMessage(
sessionId = sessionId,
params = CreateMessageRequest(
params = CreateMessageRequestParams(
maxTokens = 100,
messages = minimalMessages,
toolChoice = ToolChoice(),
),
),
)
}
}
}

@Test
fun `includeContext with no sampling context capability succeeds with a warning`() {
val (server, sessionId) = buildPair()
assertDoesNotThrow {
runBlocking {
server.createMessage(
sessionId = sessionId,
params = CreateMessageRequest(
params = CreateMessageRequestParams(
maxTokens = 100,
messages = minimalMessages,
includeContext = IncludeContext.ThisServer,
),
),
)
}
}
}

// ============================================================================
// End-to-end tool-loop integration
// ============================================================================

@Test
fun `server sends tools, client returns tool_use then final text`() {
var turn = 0
val (server, sessionId) = buildPair(
clientCapabilities = ClientCapabilities(
sampling = ClientCapabilities.Sampling(tools = EmptyJsonObject),
),
samplingHandler = { _ ->
turn++
if (turn == 1) {
CreateMessageResult(
role = Role.Assistant,
content = listOf(
TextContent("Let me check."),
ToolUseContent(
id = "call_1",
name = "get_weather",
input = buildJsonObject { put("location", JsonPrimitive("London")) },
),
),
model = "test",
stopReason = StopReason.ToolUse,
)
} else {
CreateMessageResult(
role = Role.Assistant,
content = TextContent("The temperature in London is 20°C."),
model = "test",
stopReason = StopReason.EndTurn,
)
}
},
)

runBlocking {
val messages = mutableListOf(
SamplingMessage(Role.User, TextContent("What is the weather in London?")),
)

// — Turn 1: server sends tools, expects tool_use stop reason
val first = server.createMessage(
sessionId = sessionId,
params = CreateMessageRequest(
params = CreateMessageRequestParams(
maxTokens = 256,
messages = messages.toList(),
tools = listOf(weatherTool),
toolChoice = ToolChoice(mode = ToolChoice.Mode.Auto),
),
),
)
assertEquals(StopReason.ToolUse, first.stopReason)
assertEquals(2, first.content.size)

// — Append assistant turn and inject tool result
messages.add(SamplingMessage(Role.Assistant, first.content))
val toolUse = first.content.filterIsInstance<ToolUseContent>().single()
messages.add(
SamplingMessage(
Role.User,
ToolResultContent(
toolUseId = toolUse.id,
content = listOf(TextContent("""{"tempC":20}""")),
),
),
)

// — Turn 2: server sends updated history, expects final text
val second = server.createMessage(
sessionId = sessionId,
params = CreateMessageRequest(
params = CreateMessageRequestParams(
maxTokens = 256,
messages = messages.toList(),
tools = listOf(weatherTool),
toolChoice = ToolChoice(mode = ToolChoice.Mode.Auto),
),
),
)
assertEquals(StopReason.EndTurn, second.stopReason)
assertEquals(1, second.content.size)
val text = second.content.single() as TextContent
assertEquals("The temperature in London is 20°C.", text.text)

server.close()
}
}
}
1 change: 1 addition & 0 deletions kotlin-sdk-client/api/kotlin-sdk-client.api
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ public class io/modelcontextprotocol/kotlin/sdk/client/Client : io/modelcontextp
public static synthetic fun subscribeResource$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/types/SubscribeRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
public final fun unsubscribeResource (Lio/modelcontextprotocol/kotlin/sdk/types/UnsubscribeRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static synthetic fun unsubscribeResource$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/types/UnsubscribeRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
protected fun wrapRequestHandler (Lio/modelcontextprotocol/kotlin/sdk/types/Method;Lkotlin/jvm/functions/Function3;)Lkotlin/jvm/functions/Function3;
}

public final class io/modelcontextprotocol/kotlin/sdk/client/ClientKt {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import io.github.oshai.kotlinlogging.KotlinLogging
import io.modelcontextprotocol.kotlin.sdk.ExperimentalMcpApi
import io.modelcontextprotocol.kotlin.sdk.shared.Protocol
import io.modelcontextprotocol.kotlin.sdk.shared.ProtocolOptions
import io.modelcontextprotocol.kotlin.sdk.shared.RequestHandlerExtra
import io.modelcontextprotocol.kotlin.sdk.shared.RequestOptions
import io.modelcontextprotocol.kotlin.sdk.shared.Transport
import io.modelcontextprotocol.kotlin.sdk.types.BooleanSchema
Expand All @@ -13,6 +14,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.CallToolResult
import io.modelcontextprotocol.kotlin.sdk.types.ClientCapabilities
import io.modelcontextprotocol.kotlin.sdk.types.CompleteRequest
import io.modelcontextprotocol.kotlin.sdk.types.CompleteResult
import io.modelcontextprotocol.kotlin.sdk.types.CreateMessageRequest
import io.modelcontextprotocol.kotlin.sdk.types.DoubleSchema
import io.modelcontextprotocol.kotlin.sdk.types.ElicitRequest
import io.modelcontextprotocol.kotlin.sdk.types.ElicitRequestFormParams
Expand Down Expand Up @@ -45,7 +47,9 @@ import io.modelcontextprotocol.kotlin.sdk.types.PingRequest
import io.modelcontextprotocol.kotlin.sdk.types.PrimitiveSchemaDefinition
import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequest
import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceResult
import io.modelcontextprotocol.kotlin.sdk.types.Request
import io.modelcontextprotocol.kotlin.sdk.types.RequestMeta
import io.modelcontextprotocol.kotlin.sdk.types.RequestResult
import io.modelcontextprotocol.kotlin.sdk.types.Root
import io.modelcontextprotocol.kotlin.sdk.types.RootsListChangedNotification
import io.modelcontextprotocol.kotlin.sdk.types.SUPPORTED_PROTOCOL_VERSIONS
Expand Down Expand Up @@ -323,6 +327,26 @@ public open class Client(private val clientInfo: Implementation, options: Client
}
}

/**
* Wraps incoming-request handlers with SEP-1577 client-side enforcement.
*
* For `sampling/createMessage`: if the incoming request carries `tools` or
* `toolChoice` but this client did not advertise [ClientCapabilities.Sampling.tools],
* the wrapper throws an [McpException] with JSON-RPC error code `InvalidParams`
* before the user-supplied handler runs. Matches the TypeScript SDK wrapper in
* `Client.setRequestHandler`.
*/
override fun <T : Request> wrapRequestHandler(
method: Method,
block: suspend (T, RequestHandlerExtra) -> RequestResult?,
): suspend (T, RequestHandlerExtra) -> RequestResult? {
if (method != Method.Defined.SamplingCreateMessage) return block
return { request, extra ->
(request as? CreateMessageRequest)?.let { validateSamplingToolsCapability(it, capabilities) }
block(request, extra)
}
}

/**
* Sends a ping request to the server to check connectivity.
*
Expand Down
Loading
Loading