diff --git a/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceTools.kt b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceTools.kt index 54d5c8c44..dc126f681 100644 --- a/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceTools.kt +++ b/conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceTools.kt @@ -176,11 +176,11 @@ fun Server.registerConformanceTools() { CreateMessageRequest( CreateMessageRequestParams( maxTokens = 10000, - messages = listOf(SamplingMessage(Role.User, TextContent(prompt))), + messages = listOf(SamplingMessage(Role.User, listOf(TextContent(prompt)))), ), ), ) - CallToolResult(listOf(TextContent(result.content.toString()))) + CallToolResult(listOf(TextContent(result.content.joinToString("\n") { it.toString() }))) } // 9. Elicitation diff --git a/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt b/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt index d0acda59d..e5f66383a 100644 --- a/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt +++ b/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt @@ -106,7 +106,7 @@ class ClientTest { ), options = ClientOptions( capabilities = ClientCapabilities( - sampling = EmptyJsonObject, + sampling = ClientCapabilities.sampling, ), ), ) @@ -152,7 +152,7 @@ class ClientTest { ), options = ClientOptions( capabilities = ClientCapabilities( - sampling = EmptyJsonObject, + sampling = ClientCapabilities.sampling, ), ), ) @@ -339,7 +339,7 @@ class ClientTest { val client = Client( clientInfo = Implementation(name = "test client", version = "1.0"), options = ClientOptions( - capabilities = ClientCapabilities(sampling = EmptyJsonObject), + capabilities = ClientCapabilities(sampling = ClientCapabilities.sampling), ), ) @@ -633,18 +633,16 @@ class ClientTest { ), options = ClientOptions( capabilities = ClientCapabilities( - sampling = EmptyJsonObject, + sampling = ClientCapabilities.sampling, ), ), ) client.setRequestHandler(Method.Defined.SamplingCreateMessage) { _, _ -> CreateMessageResult( - model = "test-model", role = Role.Assistant, - content = TextContent( - text = "Test response", - ), + content = TextContent(text = "Test response"), + model = "test-model", ) } @@ -670,7 +668,7 @@ class ClientTest { val client = Client( clientInfo = Implementation(name = "test client", version = "1.0"), options = ClientOptions( - capabilities = ClientCapabilities(sampling = EmptyJsonObject), + capabilities = ClientCapabilities(sampling = ClientCapabilities.sampling), ), ) diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SamplingTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SamplingTest.kt new file mode 100644 index 000000000..f161d8ec0 --- /dev/null +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SamplingTest.kt @@ -0,0 +1,251 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.client.ClientOptions +import io.modelcontextprotocol.kotlin.sdk.shared.InMemoryTransport +import io.modelcontextprotocol.kotlin.sdk.types.ClientCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.CreateMessageRequest +import io.modelcontextprotocol.kotlin.sdk.types.CreateMessageRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.CreateMessageResult +import io.modelcontextprotocol.kotlin.sdk.types.EmptyJsonObject +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import io.modelcontextprotocol.kotlin.sdk.types.IncludeContext +import io.modelcontextprotocol.kotlin.sdk.types.Method +import io.modelcontextprotocol.kotlin.sdk.types.Role +import io.modelcontextprotocol.kotlin.sdk.types.SamplingMessage +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.StopReason +import io.modelcontextprotocol.kotlin.sdk.types.TextContent +import io.modelcontextprotocol.kotlin.sdk.types.Tool +import io.modelcontextprotocol.kotlin.sdk.types.ToolChoice +import io.modelcontextprotocol.kotlin.sdk.types.ToolResultContent +import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema +import io.modelcontextprotocol.kotlin.sdk.types.ToolUseContent +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonObject +import org.junit.jupiter.api.assertDoesNotThrow +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +class SamplingTest { + + private val dummyTool = Tool( + name = "t", + inputSchema = ToolSchema(properties = buildJsonObject { }, required = emptyList()), + ) + + private val weatherTool = Tool( + name = "get_weather", + description = "Return the current temperature in Celsius.", + inputSchema = ToolSchema( + properties = buildJsonObject { + put("location", buildJsonObject { put("type", JsonPrimitive("string")) }) + }, + required = listOf("location"), + ), + ) + + private val minimalMessages = listOf(SamplingMessage(Role.User, TextContent("hi"))) + + /** + * Builds a connected [Server]+[Client] pair using [InMemoryTransport]. + * + * @param clientCapabilities the capabilities the client advertises during initialize + * @param samplingHandler the handler the client uses to respond to sampling requests + * @return Pair of (server, sessionId) ready for [Server.createMessage] calls. + */ + private fun buildPair( + clientCapabilities: ClientCapabilities = ClientCapabilities( + sampling = ClientCapabilities.Sampling(), + ), + samplingHandler: (CreateMessageRequest) -> CreateMessageResult = { _ -> + CreateMessageResult(role = Role.Assistant, content = TextContent("ok"), model = "m") + }, + ): Pair { + val server = Server( + serverInfo = Implementation(name = "srv", version = "1.0"), + options = ServerOptions(capabilities = ServerCapabilities()), + ) + + val client = Client( + clientInfo = Implementation(name = "cli", version = "1.0"), + options = ClientOptions(capabilities = clientCapabilities), + ) + + client.setRequestHandler(Method.Defined.SamplingCreateMessage) { req, _ -> + samplingHandler(req) + } + + val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + + var sessionId: String? = null + runBlocking { + val sessionDeferred = CompletableDeferred() + launch { client.connect(clientTransport) } + launch { + val session = server.createSession(serverTransport) + sessionDeferred.complete(session.sessionId) + } + sessionId = sessionDeferred.await() + } + + return Pair(server, checkNotNull(sessionId)) + } + + // ============================================================================ + // Server.createMessage — capability enforcement (SEP-1577) + // ============================================================================ + + @Test + fun `tools field rejected when client has no sampling tools capability`() { + val (server, sessionId) = buildPair() + assertFailsWith { + runBlocking { + server.createMessage( + sessionId = sessionId, + params = CreateMessageRequest( + params = CreateMessageRequestParams( + maxTokens = 100, + messages = minimalMessages, + tools = listOf(dummyTool), + ), + ), + ) + } + } + } + + @Test + fun `toolChoice field rejected when client has no sampling tools capability`() { + val (server, sessionId) = buildPair() + assertFailsWith { + runBlocking { + server.createMessage( + sessionId = sessionId, + params = CreateMessageRequest( + params = CreateMessageRequestParams( + maxTokens = 100, + messages = minimalMessages, + toolChoice = ToolChoice(), + ), + ), + ) + } + } + } + + @Test + fun `includeContext with no sampling context capability succeeds with a warning`() { + val (server, sessionId) = buildPair() + assertDoesNotThrow { + runBlocking { + server.createMessage( + sessionId = sessionId, + params = CreateMessageRequest( + params = CreateMessageRequestParams( + maxTokens = 100, + messages = minimalMessages, + includeContext = IncludeContext.ThisServer, + ), + ), + ) + } + } + } + + // ============================================================================ + // End-to-end tool-loop integration + // ============================================================================ + + @Test + fun `server sends tools, client returns tool_use then final text`() { + var turn = 0 + val (server, sessionId) = buildPair( + clientCapabilities = ClientCapabilities( + sampling = ClientCapabilities.Sampling(tools = EmptyJsonObject), + ), + samplingHandler = { _ -> + turn++ + if (turn == 1) { + CreateMessageResult( + role = Role.Assistant, + content = listOf( + TextContent("Let me check."), + ToolUseContent( + id = "call_1", + name = "get_weather", + input = buildJsonObject { put("location", JsonPrimitive("London")) }, + ), + ), + model = "test", + stopReason = StopReason.ToolUse, + ) + } else { + CreateMessageResult( + role = Role.Assistant, + content = TextContent("The temperature in London is 20°C."), + model = "test", + stopReason = StopReason.EndTurn, + ) + } + }, + ) + + runBlocking { + val messages = mutableListOf( + SamplingMessage(Role.User, TextContent("What is the weather in London?")), + ) + + // — Turn 1: server sends tools, expects tool_use stop reason + val first = server.createMessage( + sessionId = sessionId, + params = CreateMessageRequest( + params = CreateMessageRequestParams( + maxTokens = 256, + messages = messages.toList(), + tools = listOf(weatherTool), + toolChoice = ToolChoice(mode = ToolChoice.Mode.Auto), + ), + ), + ) + assertEquals(StopReason.ToolUse, first.stopReason) + assertEquals(2, first.content.size) + + // — Append assistant turn and inject tool result + messages.add(SamplingMessage(Role.Assistant, first.content)) + val toolUse = first.content.filterIsInstance().single() + messages.add( + SamplingMessage( + Role.User, + ToolResultContent( + toolUseId = toolUse.id, + content = listOf(TextContent("""{"tempC":20}""")), + ), + ), + ) + + // — Turn 2: server sends updated history, expects final text + val second = server.createMessage( + sessionId = sessionId, + params = CreateMessageRequest( + params = CreateMessageRequestParams( + maxTokens = 256, + messages = messages.toList(), + tools = listOf(weatherTool), + toolChoice = ToolChoice(mode = ToolChoice.Mode.Auto), + ), + ), + ) + assertEquals(StopReason.EndTurn, second.stopReason) + assertEquals(1, second.content.size) + val text = second.content.single() as TextContent + assertEquals("The temperature in London is 20°C.", text.text) + + server.close() + } + } +} diff --git a/kotlin-sdk-client/api/kotlin-sdk-client.api b/kotlin-sdk-client/api/kotlin-sdk-client.api index da0e752a7..8b6fb79b4 100644 --- a/kotlin-sdk-client/api/kotlin-sdk-client.api +++ b/kotlin-sdk-client/api/kotlin-sdk-client.api @@ -41,6 +41,7 @@ public class io/modelcontextprotocol/kotlin/sdk/client/Client : io/modelcontextp public static synthetic fun subscribeResource$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/types/SubscribeRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; public final fun unsubscribeResource (Lio/modelcontextprotocol/kotlin/sdk/types/UnsubscribeRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static synthetic fun unsubscribeResource$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/types/UnsubscribeRequest;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; + protected fun wrapRequestHandler (Lio/modelcontextprotocol/kotlin/sdk/types/Method;Lkotlin/jvm/functions/Function3;)Lkotlin/jvm/functions/Function3; } public final class io/modelcontextprotocol/kotlin/sdk/client/ClientKt { diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt index 34b9a5bf3..49b20d7cb 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/Client.kt @@ -4,6 +4,7 @@ import io.github.oshai.kotlinlogging.KotlinLogging import io.modelcontextprotocol.kotlin.sdk.ExperimentalMcpApi import io.modelcontextprotocol.kotlin.sdk.shared.Protocol import io.modelcontextprotocol.kotlin.sdk.shared.ProtocolOptions +import io.modelcontextprotocol.kotlin.sdk.shared.RequestHandlerExtra import io.modelcontextprotocol.kotlin.sdk.shared.RequestOptions import io.modelcontextprotocol.kotlin.sdk.shared.Transport import io.modelcontextprotocol.kotlin.sdk.types.BooleanSchema @@ -13,6 +14,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.CallToolResult import io.modelcontextprotocol.kotlin.sdk.types.ClientCapabilities import io.modelcontextprotocol.kotlin.sdk.types.CompleteRequest import io.modelcontextprotocol.kotlin.sdk.types.CompleteResult +import io.modelcontextprotocol.kotlin.sdk.types.CreateMessageRequest import io.modelcontextprotocol.kotlin.sdk.types.DoubleSchema import io.modelcontextprotocol.kotlin.sdk.types.ElicitRequest import io.modelcontextprotocol.kotlin.sdk.types.ElicitRequestFormParams @@ -45,7 +47,9 @@ import io.modelcontextprotocol.kotlin.sdk.types.PingRequest import io.modelcontextprotocol.kotlin.sdk.types.PrimitiveSchemaDefinition import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequest import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceResult +import io.modelcontextprotocol.kotlin.sdk.types.Request import io.modelcontextprotocol.kotlin.sdk.types.RequestMeta +import io.modelcontextprotocol.kotlin.sdk.types.RequestResult import io.modelcontextprotocol.kotlin.sdk.types.Root import io.modelcontextprotocol.kotlin.sdk.types.RootsListChangedNotification import io.modelcontextprotocol.kotlin.sdk.types.SUPPORTED_PROTOCOL_VERSIONS @@ -323,6 +327,26 @@ public open class Client(private val clientInfo: Implementation, options: Client } } + /** + * Wraps incoming-request handlers with SEP-1577 client-side enforcement. + * + * For `sampling/createMessage`: if the incoming request carries `tools` or + * `toolChoice` but this client did not advertise [ClientCapabilities.Sampling.tools], + * the wrapper throws an [McpException] with JSON-RPC error code `InvalidParams` + * before the user-supplied handler runs. Matches the TypeScript SDK wrapper in + * `Client.setRequestHandler`. + */ + override fun wrapRequestHandler( + method: Method, + block: suspend (T, RequestHandlerExtra) -> RequestResult?, + ): suspend (T, RequestHandlerExtra) -> RequestResult? { + if (method != Method.Defined.SamplingCreateMessage) return block + return { request, extra -> + (request as? CreateMessageRequest)?.let { validateSamplingToolsCapability(it, capabilities) } + block(request, extra) + } + } + /** * Sends a ping request to the server to check connectivity. * diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SamplingValidation.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SamplingValidation.kt new file mode 100644 index 000000000..1290cb0fa --- /dev/null +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SamplingValidation.kt @@ -0,0 +1,24 @@ +package io.modelcontextprotocol.kotlin.sdk.client + +import io.modelcontextprotocol.kotlin.sdk.types.ClientCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.CreateMessageRequest +import io.modelcontextprotocol.kotlin.sdk.types.McpException +import io.modelcontextprotocol.kotlin.sdk.types.RPCError + +/** + * Validates SEP-1577 client-side enforcement: an incoming `sampling/createMessage` + * request that carries `tools` or `toolChoice` requires the client to have advertised + * the `sampling.tools` sub-capability. When the capability is missing, throws an + * [McpException] with JSON-RPC error code `InvalidParams`, matching TypeScript SDK + * client enforcement. + */ +internal fun validateSamplingToolsCapability(request: CreateMessageRequest, capabilities: ClientCapabilities) { + val params = request.params + if (params.tools == null && params.toolChoice == null) return + if (capabilities.sampling?.tools != null) return + val field = if (params.tools != null) "tools" else "toolChoice" + throw McpException( + code = RPCError.ErrorCode.INVALID_PARAMS, + message = "Client does not support sampling with tools but request contains $field parameter", + ) +} diff --git a/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientSamplingValidationTest.kt b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientSamplingValidationTest.kt new file mode 100644 index 000000000..7775a08f3 --- /dev/null +++ b/kotlin-sdk-client/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientSamplingValidationTest.kt @@ -0,0 +1,98 @@ +package io.modelcontextprotocol.kotlin.sdk.client + +import io.modelcontextprotocol.kotlin.sdk.types.ClientCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.CreateMessageRequest +import io.modelcontextprotocol.kotlin.sdk.types.CreateMessageRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.EmptyJsonObject +import io.modelcontextprotocol.kotlin.sdk.types.McpException +import io.modelcontextprotocol.kotlin.sdk.types.RPCError +import io.modelcontextprotocol.kotlin.sdk.types.Role +import io.modelcontextprotocol.kotlin.sdk.types.SamplingMessage +import io.modelcontextprotocol.kotlin.sdk.types.TextContent +import io.modelcontextprotocol.kotlin.sdk.types.Tool +import io.modelcontextprotocol.kotlin.sdk.types.ToolChoice +import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema +import kotlinx.serialization.json.buildJsonObject +import org.junit.jupiter.api.assertDoesNotThrow +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +class ClientSamplingValidationTest { + + private val dummyTool = Tool( + name = "t", + inputSchema = ToolSchema(properties = buildJsonObject { }, required = emptyList()), + ) + + private val minimalMessages = listOf(SamplingMessage(Role.User, TextContent("hi"))) + + private val noToolsCaps = ClientCapabilities(sampling = ClientCapabilities.sampling) + private val withToolsCaps = ClientCapabilities( + sampling = ClientCapabilities.Sampling(tools = EmptyJsonObject), + ) + + @Test + fun `request without tools or toolChoice is always accepted`() { + val request = CreateMessageRequest( + CreateMessageRequestParams(maxTokens = 10, messages = minimalMessages), + ) + assertDoesNotThrow { validateSamplingToolsCapability(request, noToolsCaps) } + } + + @Test + fun `tools without sampling tools capability throws InvalidParams`() { + val request = CreateMessageRequest( + CreateMessageRequestParams( + maxTokens = 10, + messages = minimalMessages, + tools = listOf(dummyTool), + ), + ) + val exception = assertFailsWith { + validateSamplingToolsCapability(request, noToolsCaps) + } + assertEquals(RPCError.ErrorCode.INVALID_PARAMS, exception.code) + check(exception.message?.contains("tools") == true) + } + + @Test + fun `toolChoice without sampling tools capability throws InvalidParams`() { + val request = CreateMessageRequest( + CreateMessageRequestParams( + maxTokens = 10, + messages = minimalMessages, + toolChoice = ToolChoice(mode = ToolChoice.Mode.Required), + ), + ) + val exception = assertFailsWith { + validateSamplingToolsCapability(request, noToolsCaps) + } + assertEquals(RPCError.ErrorCode.INVALID_PARAMS, exception.code) + check(exception.message?.contains("toolChoice") == true) + } + + @Test + fun `tools with sampling tools capability is accepted`() { + val request = CreateMessageRequest( + CreateMessageRequestParams( + maxTokens = 10, + messages = minimalMessages, + tools = listOf(dummyTool), + ), + ) + assertDoesNotThrow { validateSamplingToolsCapability(request, withToolsCaps) } + } + + @Test + fun `toolChoice with sampling tools capability is accepted`() { + val request = CreateMessageRequest( + CreateMessageRequestParams( + maxTokens = 10, + messages = minimalMessages, + toolChoice = ToolChoice(mode = ToolChoice.Mode.Auto), + ), + ) + assertDoesNotThrow { validateSamplingToolsCapability(request, withToolsCaps) } + } +} diff --git a/kotlin-sdk-core/api/kotlin-sdk-core.api b/kotlin-sdk-core/api/kotlin-sdk-core.api index b6fe94228..c843aea06 100644 --- a/kotlin-sdk-core/api/kotlin-sdk-core.api +++ b/kotlin-sdk-core/api/kotlin-sdk-core.api @@ -79,6 +79,7 @@ public abstract class io/modelcontextprotocol/kotlin/sdk/shared/Protocol { public final fun setFallbackRequestHandler (Lkotlin/jvm/functions/Function3;)V public final fun setNotificationHandler (Lio/modelcontextprotocol/kotlin/sdk/types/Method;Lkotlin/jvm/functions/Function1;)V public final fun setRequestHandlerInternal (Lio/modelcontextprotocol/kotlin/sdk/types/Method;Lkotlin/jvm/functions/Function3;)V + protected fun wrapRequestHandler (Lio/modelcontextprotocol/kotlin/sdk/types/Method;Lkotlin/jvm/functions/Function3;)Lkotlin/jvm/functions/Function3; } public final class io/modelcontextprotocol/kotlin/sdk/shared/ProtocolKt { @@ -606,21 +607,23 @@ public final class io/modelcontextprotocol/kotlin/sdk/types/CancelledNotificatio public final class io/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities { public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Companion; public fun ()V + public fun (Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Sampling;Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Roots;Lkotlinx/serialization/json/JsonObject;Lkotlinx/serialization/json/JsonObject;Ljava/util/Map;)V + public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Sampling;Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Roots;Lkotlinx/serialization/json/JsonObject;Lkotlinx/serialization/json/JsonObject;Ljava/util/Map;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public fun (Lkotlinx/serialization/json/JsonObject;Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Roots;Lkotlinx/serialization/json/JsonObject;Lkotlinx/serialization/json/JsonObject;Ljava/util/Map;)V public synthetic fun (Lkotlinx/serialization/json/JsonObject;Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Roots;Lkotlinx/serialization/json/JsonObject;Lkotlinx/serialization/json/JsonObject;Ljava/util/Map;ILkotlin/jvm/internal/DefaultConstructorMarker;)V - public final fun component1 ()Lkotlinx/serialization/json/JsonObject; + public final fun component1 ()Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Sampling; public final fun component2 ()Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Roots; public final fun component3 ()Lkotlinx/serialization/json/JsonObject; public final fun component4 ()Lkotlinx/serialization/json/JsonObject; public final fun component5 ()Ljava/util/Map; - public final fun copy (Lkotlinx/serialization/json/JsonObject;Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Roots;Lkotlinx/serialization/json/JsonObject;Lkotlinx/serialization/json/JsonObject;Ljava/util/Map;)Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities; - public static synthetic fun copy$default (Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities;Lkotlinx/serialization/json/JsonObject;Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Roots;Lkotlinx/serialization/json/JsonObject;Lkotlinx/serialization/json/JsonObject;Ljava/util/Map;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities; + public final fun copy (Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Sampling;Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Roots;Lkotlinx/serialization/json/JsonObject;Lkotlinx/serialization/json/JsonObject;Ljava/util/Map;)Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities; + public static synthetic fun copy$default (Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities;Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Sampling;Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Roots;Lkotlinx/serialization/json/JsonObject;Lkotlinx/serialization/json/JsonObject;Ljava/util/Map;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities; public fun equals (Ljava/lang/Object;)Z public final fun getElicitation ()Lkotlinx/serialization/json/JsonObject; public final fun getExperimental ()Lkotlinx/serialization/json/JsonObject; public final fun getExtensions ()Ljava/util/Map; public final fun getRoots ()Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Roots; - public final fun getSampling ()Lkotlinx/serialization/json/JsonObject; + public final fun getSampling ()Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Sampling; public fun hashCode ()I public fun toString ()Ljava/lang/String; } @@ -638,7 +641,7 @@ public final synthetic class io/modelcontextprotocol/kotlin/sdk/types/ClientCapa public final class io/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Companion { public final fun getElicitation ()Lkotlinx/serialization/json/JsonObject; - public final fun getSampling ()Lkotlinx/serialization/json/JsonObject; + public final fun getSampling ()Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Sampling; public final fun serializer ()Lkotlinx/serialization/KSerializer; } @@ -671,6 +674,37 @@ public final class io/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$R public final fun serializer ()Lkotlinx/serialization/KSerializer; } +public final class io/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Sampling { + public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Sampling$Companion; + public fun ()V + public fun (Lkotlinx/serialization/json/JsonObject;Lkotlinx/serialization/json/JsonObject;)V + public synthetic fun (Lkotlinx/serialization/json/JsonObject;Lkotlinx/serialization/json/JsonObject;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun component1 ()Lkotlinx/serialization/json/JsonObject; + public final fun component2 ()Lkotlinx/serialization/json/JsonObject; + public final fun copy (Lkotlinx/serialization/json/JsonObject;Lkotlinx/serialization/json/JsonObject;)Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Sampling; + public static synthetic fun copy$default (Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Sampling;Lkotlinx/serialization/json/JsonObject;Lkotlinx/serialization/json/JsonObject;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Sampling; + public fun equals (Ljava/lang/Object;)Z + public final fun getContext ()Lkotlinx/serialization/json/JsonObject; + public final fun getTools ()Lkotlinx/serialization/json/JsonObject; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final synthetic class io/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Sampling$$serializer : kotlinx/serialization/internal/GeneratedSerializer { + public static final field INSTANCE Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Sampling$$serializer; + public final fun childSerializers ()[Lkotlinx/serialization/KSerializer; + public final fun deserialize (Lkotlinx/serialization/encoding/Decoder;)Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Sampling; + public synthetic fun deserialize (Lkotlinx/serialization/encoding/Decoder;)Ljava/lang/Object; + public final fun getDescriptor ()Lkotlinx/serialization/descriptors/SerialDescriptor; + public final fun serialize (Lkotlinx/serialization/encoding/Encoder;Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Sampling;)V + public synthetic fun serialize (Lkotlinx/serialization/encoding/Encoder;Ljava/lang/Object;)V + public fun typeParametersSerializers ()[Lkotlinx/serialization/KSerializer; +} + +public final class io/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Sampling$Companion { + public final fun serializer ()Lkotlinx/serialization/KSerializer; +} + public final class io/modelcontextprotocol/kotlin/sdk/types/ClientCapabilitiesBuilder { public fun ()V public final fun build ()Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities; @@ -681,8 +715,7 @@ public final class io/modelcontextprotocol/kotlin/sdk/types/ClientCapabilitiesBu public final fun extensions (Ljava/util/Map;)V public final fun roots (Ljava/lang/Boolean;)V public static synthetic fun roots$default (Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilitiesBuilder;Ljava/lang/Boolean;ILjava/lang/Object;)V - public final fun sampling (Lkotlin/jvm/functions/Function1;)V - public final fun sampling (Lkotlinx/serialization/json/JsonObject;)V + public final fun sampling (Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Sampling;)V } public abstract interface class io/modelcontextprotocol/kotlin/sdk/types/ClientNotification : io/modelcontextprotocol/kotlin/sdk/types/Notification { @@ -934,6 +967,8 @@ public final class io/modelcontextprotocol/kotlin/sdk/types/ContentTypes : java/ public static final field IMAGE Lio/modelcontextprotocol/kotlin/sdk/types/ContentTypes; public static final field RESOURCE_LINK Lio/modelcontextprotocol/kotlin/sdk/types/ContentTypes; public static final field TEXT Lio/modelcontextprotocol/kotlin/sdk/types/ContentTypes; + public static final field TOOL_RESULT Lio/modelcontextprotocol/kotlin/sdk/types/ContentTypes; + public static final field TOOL_USE Lio/modelcontextprotocol/kotlin/sdk/types/ContentTypes; public static fun getEntries ()Lkotlin/enums/EnumEntries; public final fun getValue ()Ljava/lang/String; public static fun valueOf (Ljava/lang/String;)Lio/modelcontextprotocol/kotlin/sdk/types/ContentTypes; @@ -1006,9 +1041,11 @@ public final class io/modelcontextprotocol/kotlin/sdk/types/CreateMessageRequest public final class io/modelcontextprotocol/kotlin/sdk/types/CreateMessageRequestParams : io/modelcontextprotocol/kotlin/sdk/types/RequestParams { public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/types/CreateMessageRequestParams$Companion; - public synthetic fun (ILjava/util/List;Lio/modelcontextprotocol/kotlin/sdk/types/ModelPreferences;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/types/IncludeContext;Ljava/lang/Double;Ljava/util/List;Lkotlinx/serialization/json/JsonObject;Lkotlinx/serialization/json/JsonObject;ILkotlin/jvm/internal/DefaultConstructorMarker;)V - public synthetic fun (ILjava/util/List;Lio/modelcontextprotocol/kotlin/sdk/types/ModelPreferences;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/types/IncludeContext;Ljava/lang/Double;Ljava/util/List;Lkotlinx/serialization/json/JsonObject;Lkotlinx/serialization/json/JsonObject;Lkotlin/jvm/internal/DefaultConstructorMarker;)V + public synthetic fun (ILjava/util/List;Lio/modelcontextprotocol/kotlin/sdk/types/ModelPreferences;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/types/IncludeContext;Ljava/lang/Double;Ljava/util/List;Lkotlinx/serialization/json/JsonObject;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/types/ToolChoice;Lkotlinx/serialization/json/JsonObject;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public synthetic fun (ILjava/util/List;Lio/modelcontextprotocol/kotlin/sdk/types/ModelPreferences;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/types/IncludeContext;Ljava/lang/Double;Ljava/util/List;Lkotlinx/serialization/json/JsonObject;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/types/ToolChoice;Lkotlinx/serialization/json/JsonObject;Lkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun component1 ()I + public final fun component10 ()Lio/modelcontextprotocol/kotlin/sdk/types/ToolChoice; + public final fun component11-VI-3G7E ()Lkotlinx/serialization/json/JsonObject; public final fun component2 ()Ljava/util/List; public final fun component3 ()Lio/modelcontextprotocol/kotlin/sdk/types/ModelPreferences; public final fun component4 ()Ljava/lang/String; @@ -1016,9 +1053,9 @@ public final class io/modelcontextprotocol/kotlin/sdk/types/CreateMessageRequest public final fun component6 ()Ljava/lang/Double; public final fun component7 ()Ljava/util/List; public final fun component8 ()Lkotlinx/serialization/json/JsonObject; - public final fun component9-VI-3G7E ()Lkotlinx/serialization/json/JsonObject; - public final fun copy-qtvSJzw (ILjava/util/List;Lio/modelcontextprotocol/kotlin/sdk/types/ModelPreferences;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/types/IncludeContext;Ljava/lang/Double;Ljava/util/List;Lkotlinx/serialization/json/JsonObject;Lkotlinx/serialization/json/JsonObject;)Lio/modelcontextprotocol/kotlin/sdk/types/CreateMessageRequestParams; - public static synthetic fun copy-qtvSJzw$default (Lio/modelcontextprotocol/kotlin/sdk/types/CreateMessageRequestParams;ILjava/util/List;Lio/modelcontextprotocol/kotlin/sdk/types/ModelPreferences;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/types/IncludeContext;Ljava/lang/Double;Ljava/util/List;Lkotlinx/serialization/json/JsonObject;Lkotlinx/serialization/json/JsonObject;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/types/CreateMessageRequestParams; + public final fun component9 ()Ljava/util/List; + public final fun copy-APmi5RE (ILjava/util/List;Lio/modelcontextprotocol/kotlin/sdk/types/ModelPreferences;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/types/IncludeContext;Ljava/lang/Double;Ljava/util/List;Lkotlinx/serialization/json/JsonObject;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/types/ToolChoice;Lkotlinx/serialization/json/JsonObject;)Lio/modelcontextprotocol/kotlin/sdk/types/CreateMessageRequestParams; + public static synthetic fun copy-APmi5RE$default (Lio/modelcontextprotocol/kotlin/sdk/types/CreateMessageRequestParams;ILjava/util/List;Lio/modelcontextprotocol/kotlin/sdk/types/ModelPreferences;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/types/IncludeContext;Ljava/lang/Double;Ljava/util/List;Lkotlinx/serialization/json/JsonObject;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/types/ToolChoice;Lkotlinx/serialization/json/JsonObject;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/types/CreateMessageRequestParams; public fun equals (Ljava/lang/Object;)Z public final fun getIncludeContext ()Lio/modelcontextprotocol/kotlin/sdk/types/IncludeContext; public final fun getMaxTokens ()I @@ -1029,6 +1066,8 @@ public final class io/modelcontextprotocol/kotlin/sdk/types/CreateMessageRequest public final fun getStopSequences ()Ljava/util/List; public final fun getSystemPrompt ()Ljava/lang/String; public final fun getTemperature ()Ljava/lang/Double; + public final fun getToolChoice ()Lio/modelcontextprotocol/kotlin/sdk/types/ToolChoice; + public final fun getTools ()Ljava/util/List; public fun hashCode ()I public fun toString ()Ljava/lang/String; } @@ -1050,17 +1089,19 @@ public final class io/modelcontextprotocol/kotlin/sdk/types/CreateMessageRequest public final class io/modelcontextprotocol/kotlin/sdk/types/CreateMessageResult : io/modelcontextprotocol/kotlin/sdk/types/ClientResult { public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/types/CreateMessageResult$Companion; - public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/types/Role;Lio/modelcontextprotocol/kotlin/sdk/types/MediaContent;Ljava/lang/String;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;ILkotlin/jvm/internal/DefaultConstructorMarker;)V - public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/types/Role;Lio/modelcontextprotocol/kotlin/sdk/types/MediaContent;Ljava/lang/String;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;Lkotlin/jvm/internal/DefaultConstructorMarker;)V + public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/types/Role;Lio/modelcontextprotocol/kotlin/sdk/types/SamplingMessageContent;Ljava/lang/String;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/types/Role;Lio/modelcontextprotocol/kotlin/sdk/types/SamplingMessageContent;Ljava/lang/String;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;Lkotlin/jvm/internal/DefaultConstructorMarker;)V + public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/types/Role;Ljava/util/List;Ljava/lang/String;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/types/Role;Ljava/util/List;Ljava/lang/String;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;Lkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun component1 ()Lio/modelcontextprotocol/kotlin/sdk/types/Role; - public final fun component2 ()Lio/modelcontextprotocol/kotlin/sdk/types/MediaContent; + public final fun component2 ()Ljava/util/List; public final fun component3 ()Ljava/lang/String; public final fun component4-6olV-UY ()Ljava/lang/String; public final fun component5 ()Lkotlinx/serialization/json/JsonObject; - public final fun copy-ctVX1tw (Lio/modelcontextprotocol/kotlin/sdk/types/Role;Lio/modelcontextprotocol/kotlin/sdk/types/MediaContent;Ljava/lang/String;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;)Lio/modelcontextprotocol/kotlin/sdk/types/CreateMessageResult; - public static synthetic fun copy-ctVX1tw$default (Lio/modelcontextprotocol/kotlin/sdk/types/CreateMessageResult;Lio/modelcontextprotocol/kotlin/sdk/types/Role;Lio/modelcontextprotocol/kotlin/sdk/types/MediaContent;Ljava/lang/String;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/types/CreateMessageResult; + public final fun copy-ctVX1tw (Lio/modelcontextprotocol/kotlin/sdk/types/Role;Ljava/util/List;Ljava/lang/String;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;)Lio/modelcontextprotocol/kotlin/sdk/types/CreateMessageResult; + public static synthetic fun copy-ctVX1tw$default (Lio/modelcontextprotocol/kotlin/sdk/types/CreateMessageResult;Lio/modelcontextprotocol/kotlin/sdk/types/Role;Ljava/util/List;Ljava/lang/String;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/types/CreateMessageResult; public fun equals (Ljava/lang/Object;)Z - public final fun getContent ()Lio/modelcontextprotocol/kotlin/sdk/types/MediaContent; + public final fun getContent ()Ljava/util/List; public fun getMeta ()Lkotlinx/serialization/json/JsonObject; public final fun getModel ()Ljava/lang/String; public final fun getRole ()Lio/modelcontextprotocol/kotlin/sdk/types/Role; @@ -2937,7 +2978,7 @@ public final class io/modelcontextprotocol/kotlin/sdk/types/McpException : java/ public final fun getData ()Lkotlinx/serialization/json/JsonElement; } -public abstract interface class io/modelcontextprotocol/kotlin/sdk/types/MediaContent : io/modelcontextprotocol/kotlin/sdk/types/ContentBlock { +public abstract interface class io/modelcontextprotocol/kotlin/sdk/types/MediaContent : io/modelcontextprotocol/kotlin/sdk/types/ContentBlock, io/modelcontextprotocol/kotlin/sdk/types/SamplingMessageContent { public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/types/MediaContent$Companion; } @@ -4199,13 +4240,18 @@ public final class io/modelcontextprotocol/kotlin/sdk/types/Roots_dslKt { public final class io/modelcontextprotocol/kotlin/sdk/types/SamplingMessage { public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/types/SamplingMessage$Companion; - public fun (Lio/modelcontextprotocol/kotlin/sdk/types/Role;Lio/modelcontextprotocol/kotlin/sdk/types/MediaContent;)V + public fun (Lio/modelcontextprotocol/kotlin/sdk/types/Role;Lio/modelcontextprotocol/kotlin/sdk/types/SamplingMessageContent;Lkotlinx/serialization/json/JsonObject;)V + public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/types/Role;Lio/modelcontextprotocol/kotlin/sdk/types/SamplingMessageContent;Lkotlinx/serialization/json/JsonObject;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public fun (Lio/modelcontextprotocol/kotlin/sdk/types/Role;Ljava/util/List;Lkotlinx/serialization/json/JsonObject;)V + public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/types/Role;Ljava/util/List;Lkotlinx/serialization/json/JsonObject;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun component1 ()Lio/modelcontextprotocol/kotlin/sdk/types/Role; - public final fun component2 ()Lio/modelcontextprotocol/kotlin/sdk/types/MediaContent; - public final fun copy (Lio/modelcontextprotocol/kotlin/sdk/types/Role;Lio/modelcontextprotocol/kotlin/sdk/types/MediaContent;)Lio/modelcontextprotocol/kotlin/sdk/types/SamplingMessage; - public static synthetic fun copy$default (Lio/modelcontextprotocol/kotlin/sdk/types/SamplingMessage;Lio/modelcontextprotocol/kotlin/sdk/types/Role;Lio/modelcontextprotocol/kotlin/sdk/types/MediaContent;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/types/SamplingMessage; + public final fun component2 ()Ljava/util/List; + public final fun component3 ()Lkotlinx/serialization/json/JsonObject; + public final fun copy (Lio/modelcontextprotocol/kotlin/sdk/types/Role;Ljava/util/List;Lkotlinx/serialization/json/JsonObject;)Lio/modelcontextprotocol/kotlin/sdk/types/SamplingMessage; + public static synthetic fun copy$default (Lio/modelcontextprotocol/kotlin/sdk/types/SamplingMessage;Lio/modelcontextprotocol/kotlin/sdk/types/Role;Ljava/util/List;Lkotlinx/serialization/json/JsonObject;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/types/SamplingMessage; public fun equals (Ljava/lang/Object;)Z - public final fun getContent ()Lio/modelcontextprotocol/kotlin/sdk/types/MediaContent; + public final fun getContent ()Ljava/util/List; + public final fun getMeta ()Lkotlinx/serialization/json/JsonObject; public final fun getRole ()Lio/modelcontextprotocol/kotlin/sdk/types/Role; public fun hashCode ()I public fun toString ()Ljava/lang/String; @@ -4233,6 +4279,15 @@ public final class io/modelcontextprotocol/kotlin/sdk/types/SamplingMessageBuild public final fun user (Lio/modelcontextprotocol/kotlin/sdk/types/MediaContent;)V } +public abstract interface class io/modelcontextprotocol/kotlin/sdk/types/SamplingMessageContent : io/modelcontextprotocol/kotlin/sdk/types/WithMeta { + public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/types/SamplingMessageContent$Companion; + public abstract fun getType ()Lio/modelcontextprotocol/kotlin/sdk/types/ContentTypes; +} + +public final class io/modelcontextprotocol/kotlin/sdk/types/SamplingMessageContent$Companion { + public final fun serializer ()Lkotlinx/serialization/KSerializer; +} + public final class io/modelcontextprotocol/kotlin/sdk/types/Sampling_dslKt { public static final fun assistant (Lio/modelcontextprotocol/kotlin/sdk/types/SamplingMessageBuilder;Lkotlin/jvm/functions/Function0;)V public static final fun assistantAudio (Lio/modelcontextprotocol/kotlin/sdk/types/SamplingMessageBuilder;Lkotlin/jvm/functions/Function1;)V @@ -4508,6 +4563,7 @@ public final class io/modelcontextprotocol/kotlin/sdk/types/StopReason$Companion public final fun getEndTurn--MeDoko ()Ljava/lang/String; public final fun getMaxTokens--MeDoko ()Ljava/lang/String; public final fun getStopSequence--MeDoko ()Ljava/lang/String; + public final fun getToolUse--MeDoko ()Ljava/lang/String; public final fun serializer ()Lkotlinx/serialization/KSerializer; } @@ -5077,6 +5133,49 @@ public final class io/modelcontextprotocol/kotlin/sdk/types/ToolAnnotations$Comp public final fun serializer ()Lkotlinx/serialization/KSerializer; } +public final class io/modelcontextprotocol/kotlin/sdk/types/ToolChoice { + public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/types/ToolChoice$Companion; + public fun ()V + public fun (Lio/modelcontextprotocol/kotlin/sdk/types/ToolChoice$Mode;)V + public synthetic fun (Lio/modelcontextprotocol/kotlin/sdk/types/ToolChoice$Mode;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun component1 ()Lio/modelcontextprotocol/kotlin/sdk/types/ToolChoice$Mode; + public final fun copy (Lio/modelcontextprotocol/kotlin/sdk/types/ToolChoice$Mode;)Lio/modelcontextprotocol/kotlin/sdk/types/ToolChoice; + public static synthetic fun copy$default (Lio/modelcontextprotocol/kotlin/sdk/types/ToolChoice;Lio/modelcontextprotocol/kotlin/sdk/types/ToolChoice$Mode;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/types/ToolChoice; + public fun equals (Ljava/lang/Object;)Z + public final fun getMode ()Lio/modelcontextprotocol/kotlin/sdk/types/ToolChoice$Mode; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final synthetic class io/modelcontextprotocol/kotlin/sdk/types/ToolChoice$$serializer : kotlinx/serialization/internal/GeneratedSerializer { + public static final field INSTANCE Lio/modelcontextprotocol/kotlin/sdk/types/ToolChoice$$serializer; + public final fun childSerializers ()[Lkotlinx/serialization/KSerializer; + public final fun deserialize (Lkotlinx/serialization/encoding/Decoder;)Lio/modelcontextprotocol/kotlin/sdk/types/ToolChoice; + public synthetic fun deserialize (Lkotlinx/serialization/encoding/Decoder;)Ljava/lang/Object; + public final fun getDescriptor ()Lkotlinx/serialization/descriptors/SerialDescriptor; + public final fun serialize (Lkotlinx/serialization/encoding/Encoder;Lio/modelcontextprotocol/kotlin/sdk/types/ToolChoice;)V + public synthetic fun serialize (Lkotlinx/serialization/encoding/Encoder;Ljava/lang/Object;)V + public fun typeParametersSerializers ()[Lkotlinx/serialization/KSerializer; +} + +public final class io/modelcontextprotocol/kotlin/sdk/types/ToolChoice$Companion { + public final fun serializer ()Lkotlinx/serialization/KSerializer; +} + +public final class io/modelcontextprotocol/kotlin/sdk/types/ToolChoice$Mode : java/lang/Enum { + public static final field Auto Lio/modelcontextprotocol/kotlin/sdk/types/ToolChoice$Mode; + public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/types/ToolChoice$Mode$Companion; + public static final field None Lio/modelcontextprotocol/kotlin/sdk/types/ToolChoice$Mode; + public static final field Required Lio/modelcontextprotocol/kotlin/sdk/types/ToolChoice$Mode; + public static fun getEntries ()Lkotlin/enums/EnumEntries; + public static fun valueOf (Ljava/lang/String;)Lio/modelcontextprotocol/kotlin/sdk/types/ToolChoice$Mode; + public static fun values ()[Lio/modelcontextprotocol/kotlin/sdk/types/ToolChoice$Mode; +} + +public final class io/modelcontextprotocol/kotlin/sdk/types/ToolChoice$Mode$Companion { + public final fun serializer ()Lkotlinx/serialization/KSerializer; +} + public final class io/modelcontextprotocol/kotlin/sdk/types/ToolExecution { public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/types/ToolExecution$Companion; public fun ()V @@ -5137,6 +5236,43 @@ public final class io/modelcontextprotocol/kotlin/sdk/types/ToolListChangedNotif public final fun serializer ()Lkotlinx/serialization/KSerializer; } +public final class io/modelcontextprotocol/kotlin/sdk/types/ToolResultContent : io/modelcontextprotocol/kotlin/sdk/types/SamplingMessageContent { + public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/types/ToolResultContent$Companion; + public fun (Ljava/lang/String;Ljava/util/List;Lkotlinx/serialization/json/JsonObject;Ljava/lang/Boolean;Lkotlinx/serialization/json/JsonObject;)V + public synthetic fun (Ljava/lang/String;Ljava/util/List;Lkotlinx/serialization/json/JsonObject;Ljava/lang/Boolean;Lkotlinx/serialization/json/JsonObject;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun component1 ()Ljava/lang/String; + public final fun component2 ()Ljava/util/List; + public final fun component3 ()Lkotlinx/serialization/json/JsonObject; + public final fun component4 ()Ljava/lang/Boolean; + public final fun component5 ()Lkotlinx/serialization/json/JsonObject; + public final fun copy (Ljava/lang/String;Ljava/util/List;Lkotlinx/serialization/json/JsonObject;Ljava/lang/Boolean;Lkotlinx/serialization/json/JsonObject;)Lio/modelcontextprotocol/kotlin/sdk/types/ToolResultContent; + public static synthetic fun copy$default (Lio/modelcontextprotocol/kotlin/sdk/types/ToolResultContent;Ljava/lang/String;Ljava/util/List;Lkotlinx/serialization/json/JsonObject;Ljava/lang/Boolean;Lkotlinx/serialization/json/JsonObject;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/types/ToolResultContent; + public fun equals (Ljava/lang/Object;)Z + public final fun getContent ()Ljava/util/List; + public fun getMeta ()Lkotlinx/serialization/json/JsonObject; + public final fun getStructuredContent ()Lkotlinx/serialization/json/JsonObject; + public final fun getToolUseId ()Ljava/lang/String; + public fun getType ()Lio/modelcontextprotocol/kotlin/sdk/types/ContentTypes; + public fun hashCode ()I + public final fun isError ()Ljava/lang/Boolean; + public fun toString ()Ljava/lang/String; +} + +public final synthetic class io/modelcontextprotocol/kotlin/sdk/types/ToolResultContent$$serializer : kotlinx/serialization/internal/GeneratedSerializer { + public static final field INSTANCE Lio/modelcontextprotocol/kotlin/sdk/types/ToolResultContent$$serializer; + public final fun childSerializers ()[Lkotlinx/serialization/KSerializer; + public final fun deserialize (Lkotlinx/serialization/encoding/Decoder;)Lio/modelcontextprotocol/kotlin/sdk/types/ToolResultContent; + public synthetic fun deserialize (Lkotlinx/serialization/encoding/Decoder;)Ljava/lang/Object; + public final fun getDescriptor ()Lkotlinx/serialization/descriptors/SerialDescriptor; + public final fun serialize (Lkotlinx/serialization/encoding/Encoder;Lio/modelcontextprotocol/kotlin/sdk/types/ToolResultContent;)V + public synthetic fun serialize (Lkotlinx/serialization/encoding/Encoder;Ljava/lang/Object;)V + public fun typeParametersSerializers ()[Lkotlinx/serialization/KSerializer; +} + +public final class io/modelcontextprotocol/kotlin/sdk/types/ToolResultContent$Companion { + public final fun serializer ()Lkotlinx/serialization/KSerializer; +} + public final class io/modelcontextprotocol/kotlin/sdk/types/ToolSchema { public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/types/ToolSchema$Companion; public fun ()V @@ -5173,6 +5309,41 @@ public final class io/modelcontextprotocol/kotlin/sdk/types/ToolSchema$Companion public final fun serializer ()Lkotlinx/serialization/KSerializer; } +public final class io/modelcontextprotocol/kotlin/sdk/types/ToolUseContent : io/modelcontextprotocol/kotlin/sdk/types/SamplingMessageContent { + public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/types/ToolUseContent$Companion; + public fun (Ljava/lang/String;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;Lkotlinx/serialization/json/JsonObject;)V + public synthetic fun (Ljava/lang/String;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;Lkotlinx/serialization/json/JsonObject;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun component1 ()Ljava/lang/String; + public final fun component2 ()Ljava/lang/String; + public final fun component3 ()Lkotlinx/serialization/json/JsonObject; + public final fun component4 ()Lkotlinx/serialization/json/JsonObject; + public final fun copy (Ljava/lang/String;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;Lkotlinx/serialization/json/JsonObject;)Lio/modelcontextprotocol/kotlin/sdk/types/ToolUseContent; + public static synthetic fun copy$default (Lio/modelcontextprotocol/kotlin/sdk/types/ToolUseContent;Ljava/lang/String;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;Lkotlinx/serialization/json/JsonObject;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/types/ToolUseContent; + public fun equals (Ljava/lang/Object;)Z + public final fun getId ()Ljava/lang/String; + public final fun getInput ()Lkotlinx/serialization/json/JsonObject; + public fun getMeta ()Lkotlinx/serialization/json/JsonObject; + public final fun getName ()Ljava/lang/String; + public fun getType ()Lio/modelcontextprotocol/kotlin/sdk/types/ContentTypes; + public fun hashCode ()I + public fun toString ()Ljava/lang/String; +} + +public final synthetic class io/modelcontextprotocol/kotlin/sdk/types/ToolUseContent$$serializer : kotlinx/serialization/internal/GeneratedSerializer { + public static final field INSTANCE Lio/modelcontextprotocol/kotlin/sdk/types/ToolUseContent$$serializer; + public final fun childSerializers ()[Lkotlinx/serialization/KSerializer; + public final fun deserialize (Lkotlinx/serialization/encoding/Decoder;)Lio/modelcontextprotocol/kotlin/sdk/types/ToolUseContent; + public synthetic fun deserialize (Lkotlinx/serialization/encoding/Decoder;)Ljava/lang/Object; + public final fun getDescriptor ()Lkotlinx/serialization/descriptors/SerialDescriptor; + public final fun serialize (Lkotlinx/serialization/encoding/Encoder;Lio/modelcontextprotocol/kotlin/sdk/types/ToolUseContent;)V + public synthetic fun serialize (Lkotlinx/serialization/encoding/Encoder;Ljava/lang/Object;)V + public fun typeParametersSerializers ()[Lkotlinx/serialization/KSerializer; +} + +public final class io/modelcontextprotocol/kotlin/sdk/types/ToolUseContent$Companion { + public final fun serializer ()Lkotlinx/serialization/KSerializer; +} + public final class io/modelcontextprotocol/kotlin/sdk/types/ToolsKt { public static final fun error (Lio/modelcontextprotocol/kotlin/sdk/types/CallToolResult$Companion;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;)Lio/modelcontextprotocol/kotlin/sdk/types/CallToolResult; public static synthetic fun error$default (Lio/modelcontextprotocol/kotlin/sdk/types/CallToolResult$Companion;Ljava/lang/String;Lkotlinx/serialization/json/JsonObject;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/types/CallToolResult; diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt index bc3cf7673..75cebd3dd 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt @@ -556,16 +556,31 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio block: suspend (T, RequestHandlerExtra) -> RequestResult?, ) { assertRequestHandlerCapability(method) + val wrapped = wrapRequestHandler(method, block) _requestHandlers.update { current -> current.put(method.value) { jSONRPCRequest, extraHandler -> val request = jSONRPCRequest.fromJSON() - val response = block(request as T, extraHandler) + val response = wrapped(request as T, extraHandler) response } } } + /** + * Subclass hook to wrap an incoming-request handler before it is registered. + * + * Called once by [setRequestHandler] during registration. Subclasses may return a + * new function that, when invoked, performs additional checks (capability gates, + * schema validation, etc.) before delegating to [block]. The default implementation + * is the identity. + */ + @Suppress("UNUSED_PARAMETER") + protected open fun wrapRequestHandler( + method: Method, + block: suspend (T, RequestHandlerExtra) -> RequestResult?, + ): suspend (T, RequestHandlerExtra) -> RequestResult? = block + /** * Removes the request handler for the given method. */ diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/capabilities.dsl.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/capabilities.dsl.kt index d936381d9..75e923b2d 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/capabilities.dsl.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/capabilities.dsl.kt @@ -21,7 +21,7 @@ import kotlinx.serialization.json.buildJsonObject * val request = buildInitializeRequest { * protocolVersion = "1.0" * capabilities { - * sampling(ClientCapabilities.sampling) + * sampling(ClientCapabilities.Sampling()) * roots(listChanged = true) * experimental { * put("customFeature", JsonPrimitive(true)) @@ -36,46 +36,32 @@ import kotlinx.serialization.json.buildJsonObject */ @McpDsl public class ClientCapabilitiesBuilder @PublishedApi internal constructor() { - private var sampling: JsonObject? = null + private var sampling: ClientCapabilities.Sampling? = null private var roots: ClientCapabilities.Roots? = null private var elicitation: JsonObject? = null private var extensions: Map? = null private var experimental: JsonObject? = null /** - * Indicates that the client supports sampling from an LLM. + * Sampling capability configuration. See [ClientCapabilities.Sampling]. * - * Use [ClientCapabilities.sampling] for default empty configuration. + * Pass `ClientCapabilities.Sampling()` to enable base sampling with no sub-capabilities. + * Construct `ClientCapabilities.Sampling(tools = EmptyJsonObject, context = EmptyJsonObject)` + * directly to enable SEP-1577 sub-capabilities. * * Example: * ```kotlin * capabilities { - * sampling(ClientCapabilities.sampling) + * sampling(ClientCapabilities.Sampling(tools = EmptyJsonObject)) * } * ``` * * @param value The sampling capability configuration */ - public fun sampling(value: JsonObject) { + public fun sampling(value: ClientCapabilities.Sampling) { this.sampling = value } - /** - * Indicates that the client supports sampling from an LLM with custom configuration. - * - * Example: - * ```kotlin - * capabilities { - * sampling { - * put("temperature", JsonPrimitive(0.7)) - * } - * } - * ``` - * - * @param block Lambda for building the sampling configuration - */ - public fun sampling(block: JsonObjectBuilder.() -> Unit): Unit = sampling(buildJsonObject(block)) - /** * Indicates that the client supports listing roots. * diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/capabilities.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/capabilities.kt index 14248fae7..e13a75036 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/capabilities.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/capabilities.kt @@ -41,6 +41,7 @@ public data class Implementation( * supports that capability. * * @property sampling Present if the client supports sampling from an LLM. + * Use [ClientCapabilities.Sampling] to configure SEP-1577 sub-capabilities (tools, context). * @property roots Present if the client supports listing roots. * @property elicitation Present if the client supports elicitation from the server. * @property experimental Experimental, non-standard capabilities that the client supports. @@ -51,7 +52,7 @@ public data class Implementation( */ @Serializable public data class ClientCapabilities( - public val sampling: JsonObject? = null, + public val sampling: Sampling? = null, public val roots: Roots? = null, public val elicitation: JsonObject? = null, public val experimental: JsonObject? = null, @@ -59,11 +60,51 @@ public data class ClientCapabilities( ) { /** - * @property sampling convenience value to enable the sampling capability + * Source-compatibility constructor retaining the pre-SEP-1577 `sampling: JsonObject?` + * shape. Any non-null [sampling] is converted to an empty [Sampling] (sub-capabilities + * cannot be recovered from the old opaque `JsonObject`). + */ + @Deprecated( + "ClientCapabilities.sampling is now typed. Pass a ClientCapabilities.Sampling? " + + "instead of JsonObject?.", + ReplaceWith( + "ClientCapabilities(sampling?.let { ClientCapabilities.Sampling() }, " + + "roots, elicitation, experimental, extensions)", + ), + ) + public constructor( + sampling: JsonObject?, + roots: Roots? = null, + elicitation: JsonObject? = null, + experimental: JsonObject? = null, + extensions: Map? = null, + ) : this( + sampling = sampling?.let { Sampling() }, + roots = roots, + elicitation = elicitation, + experimental = experimental, + extensions = extensions, + ) + + /** + * sub-capabilities for sampling. + * + * @property context Present if the client supports context inclusion via + * [CreateMessageRequestParams.includeContext] with values other than [IncludeContext.None]. + * Servers SHOULD avoid non-`none` values when this field is absent. + * @property tools Present if the client supports tool use via + * [CreateMessageRequestParams.tools] / [CreateMessageRequestParams.toolChoice]. + * Servers MUST NOT send `tools`/`toolChoice` when this field is absent. + */ + @Serializable + public data class Sampling(public val context: JsonObject? = null, public val tools: JsonObject? = null) + + /** + * @property sampling convenience value to enable the base sampling capability with no sub-capabilities * @property elicitation convenience value to enable the elicitation capability */ public companion object { - public val sampling: JsonObject = EmptyJsonObject + public val sampling: Sampling = Sampling() public val elicitation: JsonObject = EmptyJsonObject } diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/content.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/content.kt index a1c9e4fec..d30c22e40 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/content.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/content.kt @@ -29,6 +29,12 @@ public enum class ContentTypes(public val value: String) { @SerialName("resource") EMBEDDED_RESOURCE("resource"), + + @SerialName("tool_use") + TOOL_USE("tool_use"), + + @SerialName("tool_result") + TOOL_RESULT("tool_result"), } /** @@ -43,9 +49,27 @@ public sealed interface ContentBlock : WithMeta { /** * Content block that carries media data such as text, images, or audio. + * + * Every [MediaContent] is also a valid [SamplingMessageContent]; the sampling content + * hierarchy additionally admits [ToolUseContent] and [ToolResultContent]. */ @Serializable(with = MediaContentPolymorphicSerializer::class) -public sealed interface MediaContent : ContentBlock +public sealed interface MediaContent : + ContentBlock, + SamplingMessageContent + +/** + * Content block that can appear inside a [SamplingMessage] or [CreateMessageResult]. + * + * Implemented by [TextContent], [ImageContent], [AudioContent], [ToolUseContent], + * and [ToolResultContent]. + * + * @property type discriminator identifying the content block subtype + */ +@Serializable(with = SamplingMessageContentPolymorphicSerializer::class) +public sealed interface SamplingMessageContent : WithMeta { + public val type: ContentTypes +} /** * Text provided to or from an LLM. @@ -174,3 +198,50 @@ public data class EmbeddedResource( @EncodeDefault public override val type: ContentTypes = ContentTypes.EMBEDDED_RESOURCE } + +/** + * A request from the assistant to invoke a tool during sampling. + * + * @property id Unique identifier for this tool use; matches a subsequent + * [ToolResultContent.toolUseId] that reports the result. + * @property name The tool name (must match a tool declared in the sampling request's tools list). + * @property input The arguments to pass to the tool, conforming to the tool's input schema. + * @property meta property/parameter is reserved by MCP to allow clients and servers + * to attach additional metadata to their interactions. + */ +@Serializable +public data class ToolUseContent( + val id: String, + val name: String, + val input: JsonObject, + @SerialName("_meta") + override val meta: JsonObject? = null, +) : SamplingMessageContent { + @EncodeDefault + public override val type: ContentTypes = ContentTypes.TOOL_USE +} + +/** + * The result of a tool call previously requested via [ToolUseContent], supplied back + * to the assistant on the next sampling turn. + * + * @property toolUseId The id of the [ToolUseContent] this result corresponds to. + * @property content The unstructured result, following the same shape as [CallToolResult.content]. + * @property structuredContent Optional structured result; if the tool declared an output schema, + * this SHOULD conform to it. + * @property isError Whether the tool call ended in error. Defaults to absent (treated as false). + * @property meta property/parameter is reserved by MCP to allow clients and servers + * to attach additional metadata to their interactions. + */ +@Serializable +public data class ToolResultContent( + val toolUseId: String, + val content: List = emptyList(), + val structuredContent: JsonObject? = null, + val isError: Boolean? = null, + @SerialName("_meta") + override val meta: JsonObject? = null, +) : SamplingMessageContent { + @EncodeDefault + public override val type: ContentTypes = ContentTypes.TOOL_RESULT +} diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/initialize.dsl.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/initialize.dsl.kt index 2cb58b77b..f3e1575bc 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/initialize.dsl.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/initialize.dsl.kt @@ -21,7 +21,7 @@ import kotlin.contracts.contract * val request = buildInitializeRequest { * protocolVersion = "2024-11-05" * capabilities { - * sampling(ClientCapabilities.sampling) + * sampling(ClientCapabilities.Sampling()) * roots(listChanged = true) * } * info("MyClient", "1.0.0") @@ -33,7 +33,7 @@ import kotlin.contracts.contract * val request = buildInitializeRequest { * protocolVersion = "2024-11-05" * capabilities { - * sampling(ClientCapabilities.sampling) + * sampling(ClientCapabilities.Sampling()) * experimental { * put("feature", JsonPrimitive(true)) * } @@ -94,7 +94,7 @@ public class InitializeRequestBuilder @PublishedApi internal constructor() : Req * Example: * ```kotlin * capabilities(ClientCapabilities( - * sampling = ClientCapabilities.sampling, + * sampling = ClientCapabilities.Sampling(), * roots = ClientCapabilities.Roots(listChanged = true) * )) * ``` @@ -113,7 +113,7 @@ public class InitializeRequestBuilder @PublishedApi internal constructor() : Req * Example: * ```kotlin * capabilities { - * sampling(ClientCapabilities.sampling) + * sampling(ClientCapabilities.Sampling()) * roots(listChanged = true) * elicitation(ClientCapabilities.elicitation) * } diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/sampling.dsl.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/sampling.dsl.kt index 1bc00e281..4752ff50b 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/sampling.dsl.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/sampling.dsl.kt @@ -112,8 +112,8 @@ public class CreateMessageRequestBuilder @PublishedApi internal constructor() : * Example: * ```kotlin * messages(listOf( - * SamplingMessage(Role.User, TextContent("Hello")), - * SamplingMessage(Role.Assistant, TextContent("Hi!")) + * SamplingMessage(Role.User, listOf(TextContent("Hello"))), + * SamplingMessage(Role.Assistant, listOf(TextContent("Hi!"))) * )) * ``` */ @@ -246,7 +246,7 @@ public class SamplingMessageBuilder @PublishedApi internal constructor() { * ``` */ public fun user(content: MediaContent) { - messages.add(SamplingMessage(Role.User, content)) + messages.add(SamplingMessage(Role.User, listOf(content))) } /** @@ -258,7 +258,7 @@ public class SamplingMessageBuilder @PublishedApi internal constructor() { * ``` */ public fun assistant(content: MediaContent) { - messages.add(SamplingMessage(Role.Assistant, content)) + messages.add(SamplingMessage(Role.Assistant, listOf(content))) } @PublishedApi diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/sampling.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/sampling.kt index 97fc0c06c..211fdc798 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/sampling.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/sampling.kt @@ -75,11 +75,33 @@ public data class ModelPreferences( * * Used in sampling requests to provide conversation context and history to the LLM. * + * Under SEP-1577, [content] is a list of [SamplingMessageContent] blocks rather than a + * single block. On the wire, a list of size 1 is serialised as a single object (wire-compatible + * with pre-SEP single-block content); a list of size 2+ is serialised as a JSON array. + * * @property role The role of the message sender (user, assistant, or system). - * @property content The content of the message. Can be text, image, or audio content. + * @property content The content blocks of the message; MUST contain at least one block. + * @property meta Optional metadata reserved by MCP for clients and servers. */ @Serializable -public data class SamplingMessage(val role: Role, val content: MediaContent) +public data class SamplingMessage( + val role: Role, + @Serializable(with = SamplingContentSerializer::class) + val content: List, + @SerialName("_meta") + val meta: JsonObject? = null, +) { + /** + * Convenience constructor for a single-block message. Wraps [content] in a + * singleton list so call sites can write `SamplingMessage(Role.User, TextContent("hi"))` + * without the explicit `listOf(...)`. + */ + public constructor( + role: Role, + content: SamplingMessageContent, + meta: JsonObject? = null, + ) : this(role, listOf(content), meta) +} // ============================================================================ // sampling/createMessage @@ -184,6 +206,12 @@ public data class CreateMessageRequest(override val params: CreateMessageRequest * @property stopSequences Optional list of sequences that will stop generation if encountered. * @property metadata Optional metadata to pass through to the LLM provider. * The format of this metadata is provider-specific. + * @property tools Optional list of tools the model may use during generation. + * The client MUST return an error if this field is present but the client did not advertise + * [ClientCapabilities.Sampling.tools]. + * @property toolChoice Optional policy controlling how the model uses the provided [tools]. + * The client MUST return an error if this field is present but the client did not advertise + * [ClientCapabilities.Sampling.tools]. * @property meta Optional metadata for this request. May include a progressToken for * out-of-band progress notifications. */ @@ -197,6 +225,8 @@ public data class CreateMessageRequestParams( val temperature: Double? = null, val stopSequences: List? = null, val metadata: JsonObject? = null, + val tools: List? = null, + val toolChoice: ToolChoice? = null, @SerialName("_meta") override val meta: RequestMeta? = null, ) : RequestParams @@ -226,22 +256,37 @@ public enum class IncludeContext { * to inspect the response (human in the loop) and decide whether to allow the server to see it. * * @property role The role of the message sender. Typically [Role.Assistant] for LLM-generated responses. - * @property content The generated content. Can be text, image, or audio content. - * @property model The name of the model that generated the message (e.g., "claude-3-opus-20240229", - * "gpt-4-turbo-preview"). This helps the server understand which model was used. + * @property content The generated content blocks; at least one block is required. + * @property model The name of the model that generated the message (e.g., "claude-3-opus-20240229"). * @property stopReason The reason why sampling stopped, if known. - * Common values: "end_turn", "max_tokens", "stop_sequence", "content_filter" + * Common values: [StopReason.EndTurn], [StopReason.StopSequence], [StopReason.MaxTokens], + * [StopReason.ToolUse]. * @property meta Optional metadata for this response. */ @Serializable public data class CreateMessageResult( val role: Role, - val content: MediaContent, + @Serializable(with = SamplingContentSerializer::class) + val content: List, val model: String, val stopReason: StopReason? = null, @SerialName("_meta") override val meta: JsonObject? = null, -) : ClientResult +) : ClientResult { + /** + * Convenience constructor for a single-block response. Wraps [content] in a + * singleton list so call sites can write + * `CreateMessageResult(Role.Assistant, TextContent("ok"), "model-name")` + * without the explicit `listOf(...)`. + */ + public constructor( + role: Role, + content: SamplingMessageContent, + model: String, + stopReason: StopReason? = null, + meta: JsonObject? = null, + ) : this(role, listOf(content), model, stopReason, meta) +} /** * The reason why the LLM stopped generating tokens. @@ -255,10 +300,38 @@ public value class StopReason(public val value: String) { * @property EndTurn generation ended naturally * @property StopSequence a stop sequence was encountered * @property MaxTokens the maximum token limit was reached + * @property ToolUse the assistant issued a tool call (see SEP-1577) */ public companion object { public val EndTurn: StopReason = StopReason("endTurn") public val StopSequence: StopReason = StopReason("stopSequence") public val MaxTokens: StopReason = StopReason("maxTokens") + public val ToolUse: StopReason = StopReason("toolUse") + } +} + +/** + * Controls tool-selection behaviour for [CreateMessageRequest] under SEP-1577. + * + * @property mode Tool-use policy: + * - [Mode.Auto]: the model decides whether to use tools (default when [mode] is absent) + * - [Mode.Required]: the model MUST use at least one tool before completing + * - [Mode.None]: the model MUST NOT use any tool + */ +@Serializable +public data class ToolChoice(public val mode: Mode? = null) { + /** + * Allowed values of [ToolChoice.mode] per SEP-1577. + */ + @Serializable + public enum class Mode { + @SerialName("auto") + Auto, + + @SerialName("required") + Required, + + @SerialName("none") + None, } } diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/serializers.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/serializers.kt index 3e259575b..4a1abc61b 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/serializers.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/serializers.kt @@ -4,13 +4,18 @@ import io.github.oshai.kotlinlogging.KotlinLogging import kotlinx.serialization.DeserializationStrategy import kotlinx.serialization.KSerializer import kotlinx.serialization.SerializationException +import kotlinx.serialization.builtins.ListSerializer import kotlinx.serialization.descriptors.PrimitiveKind import kotlinx.serialization.descriptors.PrimitiveSerialDescriptor import kotlinx.serialization.descriptors.SerialDescriptor import kotlinx.serialization.encoding.Decoder import kotlinx.serialization.encoding.Encoder +import kotlinx.serialization.json.JsonArray import kotlinx.serialization.json.JsonContentPolymorphicSerializer +import kotlinx.serialization.json.JsonDecoder import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonEncoder +import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.JsonPrimitive import kotlinx.serialization.json.contentOrNull import kotlinx.serialization.json.jsonObject @@ -121,6 +126,101 @@ internal object MediaContentPolymorphicSerializer : } } +/** + * Polymorphic serializer for [SamplingMessageContent] types. + * Determines the concrete type based on the "type" field in JSON. + */ +internal object SamplingMessageContentPolymorphicSerializer : + JsonContentPolymorphicSerializer(SamplingMessageContent::class) { + override fun selectDeserializer(element: JsonElement): DeserializationStrategy = + when (element.getType()) { + ContentTypes.TEXT.value -> TextContent.serializer() + + ContentTypes.IMAGE.value -> ImageContent.serializer() + + ContentTypes.AUDIO.value -> AudioContent.serializer() + + ContentTypes.TOOL_USE.value -> ToolUseContent.serializer() + + ContentTypes.TOOL_RESULT.value -> ToolResultContent.serializer() + + else -> throw SerializationException( + "Unknown sampling message content type: ${element.getTypeOrNull()}", + ) + } +} + +/** + * Wire-format serializer for `List` honouring SEP-1577's + * single-object-or-array content shape. + * + * **Why this exists.** SEP-1577 widens the `content` field of [SamplingMessage] and + * [CreateMessageResult] from a single content block to either a single block or an + * array of blocks. Pre-SEP peers only understand a single object on the wire; post-SEP + * peers may send arrays when a turn carries multiple blocks (e.g. an assistant reply + * containing `[TextContent, ToolUseContent]` during a tool-loop step). + * + * The Kotlin API uses `List` uniformly to avoid branching in + * caller code. This serializer bridges the type difference: + * + * - **Read:** accepts both shapes; a single JSON object becomes `listOf(block)`, an + * array becomes the decoded list. + * - **Write:** size-1 list → single object (wire-compatible with pre-SEP peers); + * size≥2 → array; empty list throws (sampling content has no valid empty meaning + * per spec). + */ +internal object SamplingContentSerializer : KSerializer> { + + private val listSerializer = ListSerializer(SamplingMessageContentPolymorphicSerializer) + + override val descriptor: SerialDescriptor = listSerializer.descriptor + + override fun serialize(encoder: Encoder, value: List) { + check(value.isNotEmpty()) { "content must contain at least one block" } + val jsonEncoder = encoder as? JsonEncoder + ?: throw SerializationException("SamplingContentSerializer requires a Json encoder") + if (value.size == 1) { + jsonEncoder.encodeJsonElement( + jsonEncoder.json.encodeToJsonElement( + SamplingMessageContentPolymorphicSerializer, + value[0], + ), + ) + } else { + jsonEncoder.encodeJsonElement( + jsonEncoder.json.encodeToJsonElement(listSerializer, value), + ) + } + } + + override fun deserialize(decoder: Decoder): List { + val jsonDecoder = decoder as? JsonDecoder + ?: throw SerializationException("SamplingContentSerializer requires a Json decoder") + return decodeFromElement(jsonDecoder, jsonDecoder.decodeJsonElement()) + } + + private fun decodeFromElement(jsonDecoder: JsonDecoder, element: JsonElement): List = + when (element) { + is JsonArray -> { + if (element.isEmpty()) { + throw SerializationException("content must contain at least one block") + } + jsonDecoder.json.decodeFromJsonElement(listSerializer, element) + } + + is JsonObject -> listOf( + jsonDecoder.json.decodeFromJsonElement( + SamplingMessageContentPolymorphicSerializer, + element, + ), + ) + + else -> throw SerializationException( + "content must be a JSON object or array of objects, got $element", + ) + } +} + // ============================================================================ // Resource Serializers // ============================================================================ diff --git a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/CapabilitiesTest.kt b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/CapabilitiesTest.kt index 171c249dc..8b2e4f849 100644 --- a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/CapabilitiesTest.kt +++ b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/CapabilitiesTest.kt @@ -1,6 +1,7 @@ package io.modelcontextprotocol.kotlin.sdk.types import io.kotest.assertions.json.shouldEqualJson +import io.kotest.matchers.shouldBe import io.modelcontextprotocol.kotlin.test.utils.verifyDeserialization import io.modelcontextprotocol.kotlin.test.utils.verifySerialization import io.modelcontextprotocol.kotlin.test.utils.verifySerializationRoundTrip @@ -8,6 +9,7 @@ import kotlinx.serialization.json.buildJsonObject import kotlinx.serialization.json.put import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertNotNull import kotlin.test.assertNull class CapabilitiesTest { @@ -119,7 +121,7 @@ class CapabilitiesTest { @Test fun `should serialize ClientCapabilities with sampling`() { val capabilities = ClientCapabilities( - sampling = ClientCapabilities.sampling, + sampling = ClientCapabilities.Sampling(), ) verifySerialization( capabilities, @@ -266,7 +268,7 @@ class CapabilitiesTest { "io.modelcontextprotocol/ui" to EmptyJsonObject, ) val capabilities = ClientCapabilities( - sampling = ClientCapabilities.sampling, + sampling = ClientCapabilities.Sampling(), roots = ClientCapabilities.Roots(listChanged = true), elicitation = ClientCapabilities.elicitation, experimental = experimental, @@ -309,7 +311,7 @@ class CapabilitiesTest { val capabilities = verifyDeserialization(McpJson, json) - assertEquals(EmptyJsonObject, capabilities.sampling) + assertEquals(ClientCapabilities.Sampling(), capabilities.sampling) assertEquals(true, capabilities.roots?.listChanged) assertEquals(EmptyJsonObject, capabilities.elicitation) } @@ -330,7 +332,7 @@ class CapabilitiesTest { @Test fun `should serialize and deserialize ClientCapabilities round trip`() { val original = ClientCapabilities( - sampling = ClientCapabilities.sampling, + sampling = ClientCapabilities.Sampling(), roots = ClientCapabilities.Roots(listChanged = false), elicitation = ClientCapabilities.elicitation, extensions = mapOf( @@ -734,10 +736,11 @@ class CapabilitiesTest { } """.trimIndent() - val capabilities = verifyDeserialization(McpJson, json) + // Sampling is now a typed struct; unknown fields are ignored on deserialization + val capabilities = McpJson.decodeFromString(json) - // Should not fail - additionalProperties are allowed - assertEquals("customValue", capabilities.sampling?.get("customProperty")?.toString()?.trim('"')) + // Should not fail - additionalProperties are allowed (unknown fields are ignored by Sampling) + assertNotNull(capabilities.sampling) } @Test @@ -755,4 +758,50 @@ class CapabilitiesTest { // Should not fail - additionalProperties are allowed assertEquals("debug", capabilities.logging?.get("level")?.toString()?.trim('"')) } + + // ============================================================================ + // ClientCapabilities.Sampling sub-capabilities (SEP-1577) + // ============================================================================ + + @Test + fun `empty sampling serialises as empty object`() { + val caps = ClientCapabilities(sampling = ClientCapabilities.Sampling()) + val json = McpJson.encodeToString(ClientCapabilities.serializer(), caps) + check("\"sampling\":{}" in json) { json } + } + + @Test + fun `sampling with tools sub-capability serialises tools`() { + val caps = ClientCapabilities(sampling = ClientCapabilities.Sampling(tools = EmptyJsonObject)) + val json = McpJson.encodeToString(ClientCapabilities.serializer(), caps) + check("\"tools\":{}" in json) { json } + } + + @Test + fun `sampling with context sub-capability serialises context`() { + val caps = ClientCapabilities(sampling = ClientCapabilities.Sampling(context = EmptyJsonObject)) + val json = McpJson.encodeToString(ClientCapabilities.serializer(), caps) + check("\"context\":{}" in json) { json } + } + + @Test + fun `sampling with combined tools and context round-trips`() { + val original = ClientCapabilities( + sampling = ClientCapabilities.Sampling( + context = EmptyJsonObject, + tools = EmptyJsonObject, + ), + ) + val json = McpJson.encodeToString(ClientCapabilities.serializer(), original) + val decoded = McpJson.decodeFromString(ClientCapabilities.serializer(), json) + decoded shouldBe original + decoded.sampling?.tools shouldBe EmptyJsonObject + decoded.sampling?.context shouldBe EmptyJsonObject + } + + @Test + fun `absent sampling means unsupported`() { + val caps = ClientCapabilities() + caps.sampling shouldBe null + } } diff --git a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/ContentTest.kt b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/ContentTest.kt index f110326fb..c6dee5dca 100644 --- a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/ContentTest.kt +++ b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/ContentTest.kt @@ -1,17 +1,26 @@ package io.modelcontextprotocol.kotlin.sdk.types +import io.kotest.matchers.shouldBe +import io.kotest.matchers.types.shouldBeInstanceOf import io.modelcontextprotocol.kotlin.test.utils.verifyDeserialization import io.modelcontextprotocol.kotlin.test.utils.verifySerialization +import kotlinx.serialization.SerializationException +import kotlinx.serialization.builtins.ListSerializer +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive import kotlinx.serialization.json.buildJsonObject import kotlinx.serialization.json.jsonPrimitive import kotlinx.serialization.json.put import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertFailsWith import kotlin.test.assertIs import kotlin.test.assertNull class ContentTest { + private val samplingContentSerializer = ListSerializer(SamplingMessageContent.serializer()) + @Test fun `should serialize TextContent with minimal fields`() { val content = TextContent(text = "Hello, MCP!") @@ -308,4 +317,133 @@ class ContentTest { assertIs(content[1]) assertIs(content[2]) } + + // ============================================================================ + // SamplingMessageContent hierarchy (SEP-1577) + // ============================================================================ + + @Test + fun `TextContent implements SamplingMessageContent`() { + val c: SamplingMessageContent = TextContent("hi") + c.shouldBeInstanceOf() + } + + @Test + fun `ImageContent implements SamplingMessageContent`() { + val c: SamplingMessageContent = ImageContent(data = "AA==", mimeType = "image/png") + c.shouldBeInstanceOf() + } + + @Test + fun `AudioContent implements SamplingMessageContent`() { + val c: SamplingMessageContent = AudioContent(data = "AA==", mimeType = "audio/wav") + c.shouldBeInstanceOf() + } + + // ============================================================================ + // ToolUseContent + // ============================================================================ + + @Test + fun `ToolUseContent round-trips through McpJson`() { + val input = buildJsonObject { put("location", JsonPrimitive("London")) } + val original = ToolUseContent(id = "call_1", name = "get_weather", input = input) + val encoded = McpJson.encodeToString(ToolUseContent.serializer(), original) + val decoded = McpJson.decodeFromString(ToolUseContent.serializer(), encoded) + decoded shouldBe original + } + + @Test + fun `ToolUseContent serialises with type discriminator tool_use`() { + val c = ToolUseContent(id = "x", name = "n", input = JsonObject(emptyMap())) + val json = McpJson.encodeToString(ToolUseContent.serializer(), c) + check("\"type\":\"tool_use\"" in json) { "missing discriminator in: $json" } + } + + @Test + fun `ToolUseContent preserves meta when present`() { + val meta = buildJsonObject { put("cacheControl", JsonPrimitive("ephemeral")) } + val c = ToolUseContent(id = "x", name = "n", input = JsonObject(emptyMap()), meta = meta) + val json = McpJson.encodeToString(ToolUseContent.serializer(), c) + val decoded = McpJson.decodeFromString(ToolUseContent.serializer(), json) + decoded.meta shouldBe meta + } + + // ============================================================================ + // ToolResultContent + // ============================================================================ + + @Test + fun `ToolResultContent round-trips with text content block`() { + val original = ToolResultContent( + toolUseId = "call_1", + content = listOf(TextContent("20°C sunny")), + ) + val encoded = McpJson.encodeToString(ToolResultContent.serializer(), original) + val decoded = McpJson.decodeFromString(ToolResultContent.serializer(), encoded) + decoded shouldBe original + } + + @Test + fun `ToolResultContent round-trips with mixed content including resource_link`() { + val original = ToolResultContent( + toolUseId = "call_2", + content = listOf( + TextContent("see file"), + ResourceLink(name = "log", uri = "file:///tmp/a.log"), + ), + ) + val encoded = McpJson.encodeToString(ToolResultContent.serializer(), original) + val decoded = McpJson.decodeFromString(ToolResultContent.serializer(), encoded) + decoded shouldBe original + } + + @Test + fun `ToolResultContent serialises with discriminator and preserves structuredContent and isError`() { + val original = ToolResultContent( + toolUseId = "call_3", + content = emptyList(), + structuredContent = buildJsonObject { put("temp", JsonPrimitive(20)) }, + isError = true, + ) + val json = McpJson.encodeToString(ToolResultContent.serializer(), original) + check("\"type\":\"tool_result\"" in json) { "missing discriminator in: $json" } + val decoded = McpJson.decodeFromString(ToolResultContent.serializer(), json) + decoded shouldBe original + } + + // ============================================================================ + // SamplingMessageContentPolymorphicSerializer + // ============================================================================ + + @Test + fun `SamplingMessageContent polymorphic serializer decodes text`() { + val json = """[{"type":"text","text":"hi"}]""" + val list = McpJson.decodeFromString(samplingContentSerializer, json) + list[0].shouldBeInstanceOf().text shouldBe "hi" + } + + @Test + fun `SamplingMessageContent polymorphic serializer decodes tool_use`() { + val json = """[{"type":"tool_use","id":"a","name":"n","input":{}}]""" + val list = McpJson.decodeFromString(samplingContentSerializer, json) + val use = list[0].shouldBeInstanceOf() + use.id shouldBe "a" + use.name shouldBe "n" + } + + @Test + fun `SamplingMessageContent polymorphic serializer decodes tool_result`() { + val json = """[{"type":"tool_result","toolUseId":"a","content":[]}]""" + val list = McpJson.decodeFromString(samplingContentSerializer, json) + list[0].shouldBeInstanceOf().toolUseId shouldBe "a" + } + + @Test + fun `SamplingMessageContent polymorphic serializer throws on unknown discriminator`() { + val json = """[{"type":"bogus","x":1}]""" + assertFailsWith { + McpJson.decodeFromString(samplingContentSerializer, json) + } + } } diff --git a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/InitializeTest.kt b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/InitializeTest.kt index 3da2a8bd9..ed97c18fd 100644 --- a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/InitializeTest.kt +++ b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/InitializeTest.kt @@ -21,7 +21,7 @@ class InitializeTest { InitializeRequestParams( protocolVersion = "2024-11-05", capabilities = ClientCapabilities( - sampling = ClientCapabilities.sampling, + sampling = ClientCapabilities.Sampling(), roots = ClientCapabilities.Roots(listChanged = true), elicitation = ClientCapabilities.elicitation, experimental = buildJsonObject { diff --git a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/SamplingTest.kt b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/SamplingTest.kt index 696f32216..0432abac1 100644 --- a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/SamplingTest.kt +++ b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/SamplingTest.kt @@ -1,7 +1,13 @@ package io.modelcontextprotocol.kotlin.sdk.types +import io.kotest.matchers.shouldBe +import io.kotest.matchers.types.shouldBeInstanceOf import io.modelcontextprotocol.kotlin.test.utils.verifyDeserialization import io.modelcontextprotocol.kotlin.test.utils.verifySerialization +import kotlinx.serialization.Serializable +import kotlinx.serialization.SerializationException +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive import kotlinx.serialization.json.buildJsonObject import kotlinx.serialization.json.double import kotlinx.serialization.json.jsonPrimitive @@ -14,6 +20,21 @@ import kotlin.test.assertNotNull class SamplingTest { + @Serializable + private data class Holder( + @Serializable(with = SamplingContentSerializer::class) + val content: List, + ) + + private val dummyTool = Tool( + name = "get_weather", + description = "returns weather", + inputSchema = ToolSchema( + properties = buildJsonObject { }, + required = emptyList(), + ), + ) + @Test fun `should serialize ModelHint`() { val hint = ModelHint(name = "claude-3-5-sonnet") @@ -66,7 +87,7 @@ class SamplingTest { fun `should serialize SamplingMessage`() { val message = SamplingMessage( role = Role.User, - content = TextContent(text = "Summarize the latest release."), + content = listOf(TextContent(text = "Summarize the latest release.")), ) verifySerialization( @@ -93,11 +114,11 @@ class SamplingTest { messages = listOf( SamplingMessage( role = Role.User, - content = TextContent(text = "You are a helpful assistant."), + content = listOf(TextContent(text = "You are a helpful assistant.")), ), SamplingMessage( role = Role.User, - content = TextContent(text = "Provide a short summary."), + content = listOf(TextContent(text = "Provide a short summary.")), ), ), modelPreferences = ModelPreferences( @@ -202,7 +223,7 @@ class SamplingTest { val message = params.messages.first() assertEquals(Role.User, message.role) - val content = assertIs(message.content) + val content = assertIs(message.content.single()) assertEquals("Draft a project update.", content.text) val preferences = assertNotNull(params.modelPreferences) @@ -259,7 +280,7 @@ class SamplingTest { val result = verifyDeserialization(McpJson, json) assertEquals(Role.Assistant, result.role) - val text = assertIs(result.content) + val text = assertIs(result.content.single()) assertEquals("Summary complete.", text.text) assertEquals("gpt-4o", result.model) assertEquals(StopReason.StopSequence, result.stopReason) @@ -267,4 +288,211 @@ class SamplingTest { assertNotNull(meta) assertEquals(1200.5, meta["latencyMs"]?.jsonPrimitive?.double) } + + // ============================================================================ + // SamplingMessage shape (SEP-1577) + // ============================================================================ + + @Test + fun `SamplingMessage content is a list of SamplingMessageContent`() { + val m = SamplingMessage( + role = Role.User, + content = listOf(TextContent("hi")), + ) + m.content.size shouldBe 1 + (m.content[0] as TextContent).text shouldBe "hi" + } + + @Test + fun `SamplingMessage single-element content serialises as single object`() { + val m = SamplingMessage(role = Role.User, content = listOf(TextContent("hi"))) + val json = McpJson.encodeToString(SamplingMessage.serializer(), m) + check("""{"role":"user","content":""" in json) { "expected role/content prefix, got $json" } + check("\"type\":\"text\"" in json && "\"text\":\"hi\"" in json) { + "expected single-object content with text discriminator, got $json" + } + check("\"content\":[" !in json) { "expected single-object wire, but array form found: $json" } + } + + @Test + fun `SamplingMessage multi-element content serialises as array`() { + val m = SamplingMessage( + role = Role.Assistant, + content = listOf( + TextContent("Let me use a tool"), + ToolUseContent(id = "c1", name = "get_weather", input = JsonObject(emptyMap())), + ), + ) + val json = McpJson.encodeToString(SamplingMessage.serializer(), m) + check("\"content\":[" in json) { "expected array wire form, got $json" } + } + + @Test + fun `SamplingMessage _meta round-trips`() { + val meta = buildJsonObject { put("k", JsonPrimitive("v")) } + val m = SamplingMessage(role = Role.User, content = listOf(TextContent("hi")), meta = meta) + val json = McpJson.encodeToString(SamplingMessage.serializer(), m) + val decoded = McpJson.decodeFromString(SamplingMessage.serializer(), json) + decoded.meta shouldBe meta + } + + // ============================================================================ + // CreateMessageRequestParams: tools / toolChoice + StopReason.ToolUse + // ============================================================================ + + @Test + fun `CreateMessageRequestParams tools and toolChoice default to null`() { + val params = CreateMessageRequestParams(maxTokens = 100, messages = emptyList()) + params.tools shouldBe null + params.toolChoice shouldBe null + } + + @Test + fun `CreateMessageRequestParams round-trips tools and toolChoice`() { + val original = CreateMessageRequestParams( + maxTokens = 100, + messages = emptyList(), + tools = listOf(dummyTool), + toolChoice = ToolChoice(mode = ToolChoice.Mode.Required), + ) + val encoded = McpJson.encodeToString(CreateMessageRequestParams.serializer(), original) + val decoded = McpJson.decodeFromString(CreateMessageRequestParams.serializer(), encoded) + decoded.tools?.single()?.name shouldBe "get_weather" + decoded.toolChoice shouldBe ToolChoice(mode = ToolChoice.Mode.Required) + } + + @Test + fun `StopReason ToolUse serialises as toolUse`() { + StopReason.ToolUse.value shouldBe "toolUse" + } + + // ============================================================================ + // CreateMessageResult shape (SEP-1577) + // ============================================================================ + + @Test + fun `CreateMessageResult single-block content serialises as single object`() { + val r = CreateMessageResult( + role = Role.Assistant, + content = TextContent("42"), + model = "test-model", + stopReason = StopReason.EndTurn, + ) + val json = McpJson.encodeToString(CreateMessageResult.serializer(), r) + check("\"content\":{" in json) { "expected single-object content wire form, got $json" } + check("\"content\":[" !in json) { "expected NOT to use array wire form for size-1, got $json" } + } + + @Test + fun `CreateMessageResult multi-block content with ToolUse stopReason round-trips`() { + val r = CreateMessageResult( + role = Role.Assistant, + content = listOf( + TextContent("Let me use a tool"), + ToolUseContent(id = "c1", name = "get_weather", input = JsonObject(emptyMap())), + ), + model = "test-model", + stopReason = StopReason.ToolUse, + ) + val json = McpJson.encodeToString(CreateMessageResult.serializer(), r) + val decoded = McpJson.decodeFromString(CreateMessageResult.serializer(), json) + decoded shouldBe r + decoded.stopReason shouldBe StopReason.ToolUse + decoded.content.size shouldBe 2 + } + + @Test + fun `CreateMessageResult pre-SEP single-object wire decodes correctly`() { + val json = """{"role":"assistant","content":{"type":"text","text":"hi"},"model":"m"}""" + val decoded = McpJson.decodeFromString(CreateMessageResult.serializer(), json) + decoded.content.size shouldBe 1 + (decoded.content[0] as TextContent).text shouldBe "hi" + } + + // ============================================================================ + // SamplingContentSerializer (single-or-array wire heuristic) + // ============================================================================ + + @Test + fun `SamplingContentSerializer decodes single object into list of one`() { + val json = """{"content":{"type":"text","text":"hi"}}""" + val h = McpJson.decodeFromString(Holder.serializer(), json) + h.content.size shouldBe 1 + h.content[0].shouldBeInstanceOf().text shouldBe "hi" + } + + @Test + fun `SamplingContentSerializer decodes array into list`() { + val json = """{"content":[{"type":"text","text":"a"},{"type":"text","text":"b"}]}""" + val h = McpJson.decodeFromString(Holder.serializer(), json) + h.content.size shouldBe 2 + } + + @Test + fun `SamplingContentSerializer encodes list of size one as a single object`() { + val h = Holder(content = listOf(TextContent("hi"))) + val json = McpJson.encodeToString(Holder.serializer(), h) + json shouldBe """{"content":{"text":"hi","type":"text"}}""" + } + + @Test + fun `SamplingContentSerializer encodes list of size two as an array`() { + val h = Holder(content = listOf(TextContent("a"), TextContent("b"))) + val json = McpJson.encodeToString(Holder.serializer(), h) + json shouldBe """{"content":[{"text":"a","type":"text"},{"text":"b","type":"text"}]}""" + } + + @Test + fun `SamplingContentSerializer encoding an empty list throws`() { + val h = Holder(content = emptyList()) + assertFailsWith { + McpJson.encodeToString(Holder.serializer(), h) + } + } + + @Test + fun `SamplingContentSerializer decoding an empty array throws`() { + val json = """{"content":[]}""" + assertFailsWith { + McpJson.decodeFromString(Holder.serializer(), json) + } + } + + // ============================================================================ + // ToolChoice + // ============================================================================ + + @Test + fun `ToolChoice round-trips auto mode`() { + val original = ToolChoice(mode = ToolChoice.Mode.Auto) + val json = McpJson.encodeToString(ToolChoice.serializer(), original) + json shouldBe """{"mode":"auto"}""" + McpJson.decodeFromString(ToolChoice.serializer(), json) shouldBe original + } + + @Test + fun `ToolChoice round-trips required mode`() { + val original = ToolChoice(mode = ToolChoice.Mode.Required) + McpJson.decodeFromString( + ToolChoice.serializer(), + McpJson.encodeToString(ToolChoice.serializer(), original), + ) shouldBe original + } + + @Test + fun `ToolChoice round-trips none mode`() { + val original = ToolChoice(mode = ToolChoice.Mode.None) + McpJson.decodeFromString( + ToolChoice.serializer(), + McpJson.encodeToString(ToolChoice.serializer(), original), + ) shouldBe original + } + + @Test + fun `ToolChoice absent mode serialises as empty object and deserialises to null mode`() { + val original = ToolChoice() + val json = McpJson.encodeToString(ToolChoice.serializer(), original) + json shouldBe """{}""" + McpJson.decodeFromString(ToolChoice.serializer(), json) shouldBe ToolChoice(mode = null) + } } diff --git a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/dsl/CapabilitiesDslTest.kt b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/dsl/CapabilitiesDslTest.kt index 53dd5c79f..ee2fe434c 100644 --- a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/dsl/CapabilitiesDslTest.kt +++ b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/dsl/CapabilitiesDslTest.kt @@ -4,10 +4,10 @@ import io.kotest.matchers.nulls.shouldBeNull import io.kotest.matchers.nulls.shouldNotBeNull import io.kotest.matchers.shouldBe import io.modelcontextprotocol.kotlin.sdk.ExperimentalMcpApi +import io.modelcontextprotocol.kotlin.sdk.types.ClientCapabilities import io.modelcontextprotocol.kotlin.sdk.types.EmptyJsonObject import io.modelcontextprotocol.kotlin.sdk.types.buildInitializeRequest import kotlinx.serialization.json.buildJsonObject -import kotlinx.serialization.json.double import kotlinx.serialization.json.jsonPrimitive import kotlinx.serialization.json.put import kotlin.test.Test @@ -42,11 +42,7 @@ class CapabilitiesDslTest { val request = buildInitializeRequest { protocolVersion = "2024-11-05" capabilities { - sampling { - put("temperature", 0.7) - put("maxTokens", 1000) - put("topP", 0.95) - } + sampling(ClientCapabilities.Sampling(tools = EmptyJsonObject)) roots(listChanged = true) elicitation { put("mode", "interactive") @@ -72,9 +68,7 @@ class CapabilitiesDslTest { request.params.capabilities.shouldNotBeNull { sampling shouldNotBeNull { - get("temperature")?.jsonPrimitive?.double shouldBe 0.7 - get("maxTokens")?.jsonPrimitive?.content shouldBe "1000" - get("topP")?.jsonPrimitive?.double shouldBe 0.95 + tools shouldBe EmptyJsonObject } roots shouldNotBeNull { listChanged shouldBe true @@ -174,8 +168,8 @@ class CapabilitiesDslTest { val request = buildInitializeRequest { protocolVersion = "2024-11-05" capabilities { - sampling { put("temperature", 0.5) } - sampling { put("temperature", 0.9) } // Should overwrite + sampling(ClientCapabilities.Sampling()) + sampling(ClientCapabilities.Sampling(tools = EmptyJsonObject)) // Should overwrite experimental { put("feature", "v1") } experimental { put("feature", "v2") } // Should overwrite } @@ -184,7 +178,7 @@ class CapabilitiesDslTest { request.params.capabilities.shouldNotBeNull { sampling shouldNotBeNull { - get("temperature")?.jsonPrimitive?.double shouldBe 0.9 + tools shouldBe EmptyJsonObject } experimental shouldNotBeNull { get("feature")?.jsonPrimitive?.content shouldBe "v2" diff --git a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/dsl/ContentDslTest.kt b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/dsl/ContentDslTest.kt index 180a30829..4801d02e1 100644 --- a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/dsl/ContentDslTest.kt +++ b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/dsl/ContentDslTest.kt @@ -42,7 +42,7 @@ class ContentDslTest { } } - (request.params.messages[0].content as TextContent).shouldNotBeNull { + (request.params.messages[0].content.single() as TextContent).shouldNotBeNull { text shouldBe "Hello, world!" annotations.shouldBeNull() meta.shouldBeNull() @@ -71,7 +71,7 @@ class ContentDslTest { } } - (request.params.messages[0].content as TextContent).shouldNotBeNull { + (request.params.messages[0].content.single() as TextContent).shouldNotBeNull { text shouldBe "Analyze the quarterly sales report for Q4 2025" annotations shouldNotBeNull { audience shouldBe listOf(Role.User, Role.Assistant) @@ -98,7 +98,7 @@ class ContentDslTest { } } - (request.params.messages[0].content as TextContent).text shouldBe "Hello 🌍! Ça va? Москва 北京 مرحبا" + (request.params.messages[0].content.single() as TextContent).text shouldBe "Hello 🌍! Ça va? Москва 北京 مرحبا" } @Test @@ -112,7 +112,7 @@ class ContentDslTest { } } - (request.params.messages[0].content as TextContent).text shouldBe "" + (request.params.messages[0].content.single() as TextContent).text shouldBe "" } @Test @@ -144,7 +144,7 @@ class ContentDslTest { } } - (request.params.messages[0].content as ImageContent).shouldNotBeNull { + (request.params.messages[0].content.single() as ImageContent).shouldNotBeNull { data shouldBe "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" mimeType shouldBe "image/png" @@ -177,7 +177,7 @@ class ContentDslTest { } } - (request.params.messages[0].content as ImageContent).shouldNotBeNull { + (request.params.messages[0].content.single() as ImageContent).shouldNotBeNull { mimeType shouldBe "image/jpeg" annotations shouldNotBeNull { audience shouldBe listOf(Role.Assistant) @@ -208,7 +208,7 @@ class ContentDslTest { } } } - (request.params.messages[0].content as ImageContent).mimeType shouldBe mime + (request.params.messages[0].content.single() as ImageContent).mimeType shouldBe mime } } @@ -256,7 +256,7 @@ class ContentDslTest { } } - (request.params.messages[0].content as AudioContent).shouldNotBeNull { + (request.params.messages[0].content.single() as AudioContent).shouldNotBeNull { data shouldBe "UklGRiQAAABXQVZFZm10IBAAAAABAAEAQB8AAEAfAAABAAgAAABmYWN0BAAAAAAAAABkYXRhAAAAAA==" mimeType shouldBe "audio/wav" annotations.shouldBeNull() @@ -288,7 +288,7 @@ class ContentDslTest { } } - (request.params.messages[0].content as AudioContent).shouldNotBeNull { + (request.params.messages[0].content.single() as AudioContent).shouldNotBeNull { mimeType shouldBe "audio/mpeg" annotations shouldNotBeNull { audience shouldBe listOf(Role.User) @@ -349,7 +349,7 @@ class ContentDslTest { } } } - (requestMin.params.messages[0].content as TextContent).annotations?.priority shouldBe 0.0 + (requestMin.params.messages[0].content.single() as TextContent).annotations?.priority shouldBe 0.0 // Priority = 1.0 (maximum) val requestMax = buildCreateMessageRequest { @@ -361,7 +361,7 @@ class ContentDslTest { } } } - (requestMax.params.messages[0].content as TextContent).annotations?.priority shouldBe 1.0 + (requestMax.params.messages[0].content.single() as TextContent).annotations?.priority shouldBe 1.0 } @Test @@ -376,7 +376,7 @@ class ContentDslTest { } } - (request.params.messages[0].content as TextContent).annotations shouldNotBeNull { + (request.params.messages[0].content.single() as TextContent).annotations shouldNotBeNull { priority shouldBe 0.5 audience shouldBe listOf(Role.User) lastModified.shouldBeNull() diff --git a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/dsl/InitializeDslTest.kt b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/dsl/InitializeDslTest.kt index aa53cce9d..17e5ae61d 100644 --- a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/dsl/InitializeDslTest.kt +++ b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/dsl/InitializeDslTest.kt @@ -9,7 +9,6 @@ import io.modelcontextprotocol.kotlin.sdk.types.EmptyJsonObject import io.modelcontextprotocol.kotlin.sdk.types.Implementation import io.modelcontextprotocol.kotlin.sdk.types.buildInitializeRequest import kotlinx.serialization.json.buildJsonObject -import kotlinx.serialization.json.int import kotlinx.serialization.json.jsonPrimitive import kotlinx.serialization.json.put import kotlin.test.Test @@ -21,9 +20,7 @@ class InitializeDslTest { val request = buildInitializeRequest { protocolVersion = "2024-11-05" capabilities { - sampling { - put("maxTokens", 100) - } + sampling(ClientCapabilities.Sampling()) roots(listChanged = true) elicitation { put("mode", "interactive") @@ -47,7 +44,7 @@ class InitializeDslTest { request.params.protocolVersion shouldBe "2024-11-05" request.params.capabilities.shouldNotBeNull { - sampling?.get("maxTokens")?.jsonPrimitive?.int shouldBe 100 + sampling shouldNotBeNull { } roots?.listChanged shouldBe true elicitation?.get("mode")?.jsonPrimitive?.content shouldBe "interactive" experimental?.get("custom")?.jsonPrimitive?.content shouldBe "true" @@ -79,8 +76,8 @@ class InitializeDslTest { } @Test - fun `capabilities DSL should support direct JsonObject values`() { - val samplingObj = buildJsonObject { put("key", "value") } + fun `capabilities DSL should support direct typed values`() { + val samplingValue = ClientCapabilities.Sampling(tools = EmptyJsonObject) val elicitationObj = buildJsonObject { put("key", "value") } val experimentalObj = buildJsonObject { put("key", "value") } val extensionsMap = mapOf( @@ -90,7 +87,7 @@ class InitializeDslTest { val request = buildInitializeRequest { protocolVersion = "1.0" capabilities { - sampling(samplingObj) + sampling(samplingValue) elicitation(elicitationObj) experimental(experimentalObj) extensions(extensionsMap) @@ -99,7 +96,7 @@ class InitializeDslTest { } request.params.capabilities.shouldNotBeNull { - sampling shouldBe samplingObj + sampling shouldBe samplingValue elicitation shouldBe elicitationObj experimental shouldBe experimentalObj extensions shouldBe extensionsMap diff --git a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/dsl/SamplingDslTest.kt b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/dsl/SamplingDslTest.kt index c53a4008c..10be8b094 100644 --- a/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/dsl/SamplingDslTest.kt +++ b/kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/dsl/SamplingDslTest.kt @@ -71,7 +71,7 @@ class SamplingDslTest { request.params.stopSequences shouldBe listOf("STOP") request.params.messages shouldHaveSize 5 - (request.params.messages[2].content as TextContent).shouldNotBeNull { + (request.params.messages[2].content.single() as TextContent).shouldNotBeNull { text shouldBe "Text with annotations" annotations shouldNotBeNull { audience shouldBe listOf(Role.User) @@ -81,13 +81,13 @@ class SamplingDslTest { meta?.get("key")?.jsonPrimitive?.content shouldBe "value" } - (request.params.messages[3].content as ImageContent).shouldNotBeNull { + (request.params.messages[3].content.single() as ImageContent).shouldNotBeNull { data shouldBe "base64image" mimeType shouldBe "image/png" annotations?.priority shouldBe 0.5 } - (request.params.messages[4].content as AudioContent).shouldNotBeNull { + (request.params.messages[4].content.single() as AudioContent).shouldNotBeNull { data shouldBe "base64audio" mimeType shouldBe "audio/wav" } @@ -102,7 +102,7 @@ class SamplingDslTest { @Test fun `buildCreateMessageRequest should support direct assignments`() { - val messages = listOf(SamplingMessage(Role.User, TextContent("Hello"))) + val messages = listOf(SamplingMessage(Role.User, listOf(TextContent("Hello")))) val preferences = ModelPreferences(costPriority = 0.1) val request = buildCreateMessageRequest { maxTokens = 100 @@ -126,8 +126,8 @@ class SamplingDslTest { assistant(content) } } - request.params.messages[0].content shouldBe content - request.params.messages[1].content shouldBe content + request.params.messages[0].content.single() shouldBe content + request.params.messages[1].content.single() shouldBe content } @Test diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ClientConnection.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ClientConnection.kt index 23dbe89a6..e64b0da81 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ClientConnection.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ClientConnection.kt @@ -11,6 +11,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.ElicitRequestURLParams import io.modelcontextprotocol.kotlin.sdk.types.ElicitResult import io.modelcontextprotocol.kotlin.sdk.types.ElicitationCompleteNotification import io.modelcontextprotocol.kotlin.sdk.types.EmptyResult +import io.modelcontextprotocol.kotlin.sdk.types.IncludeContext import io.modelcontextprotocol.kotlin.sdk.types.ListRootsRequest import io.modelcontextprotocol.kotlin.sdk.types.ListRootsResult import io.modelcontextprotocol.kotlin.sdk.types.LoggingLevel @@ -200,13 +201,34 @@ internal class ClientConnectionImpl(private val session: ServerSession) : Client } override suspend fun createMessage(request: CreateMessageRequest, options: RequestOptions?): CreateMessageResult { - with(request.params) { - logger.debug { - "Creating message with ${messages.size} messages, maxTokens=$maxTokens, " + - "temperature=$temperature, " + - "systemPrompt=${if (systemPrompt != null) "present" else "absent"}" + val caps = session.clientCapabilities + val params = request.params + + if (params.tools != null || params.toolChoice != null) { + requireNotNull(caps?.sampling?.tools) { + "Client did not advertise sampling.tools capability; cannot send " + + "tools/toolChoice in sampling/createMessage request." + } + } + + if (params.includeContext != null && params.includeContext != IncludeContext.None) { + if (caps?.sampling?.context == null) { + logger.warn { + "Client did not advertise sampling.context capability but server requested " + + "includeContext=${params.includeContext}. This is soft-deprecated and may be " + + "rejected by future spec versions." + } } } + + validateSamplingMessages(params.messages) + + logger.debug { + "Creating message with ${params.messages.size} messages, maxTokens=${params.maxTokens}, " + + "temperature=${params.temperature}, " + + "systemPrompt=${if (params.systemPrompt != null) "present" else "absent"}, " + + "tools=${params.tools?.size ?: 0}, toolChoice=${params.toolChoice?.mode}" + } logger.trace { "Full createMessage params: $request" } return request(request, options) } diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SamplingValidation.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SamplingValidation.kt new file mode 100644 index 000000000..ef408080a --- /dev/null +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SamplingValidation.kt @@ -0,0 +1,51 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.modelcontextprotocol.kotlin.sdk.types.SamplingMessage +import io.modelcontextprotocol.kotlin.sdk.types.ToolResultContent +import io.modelcontextprotocol.kotlin.sdk.types.ToolUseContent + +/** + * Validates tool_use / tool_result rules on the last two messages of [messages]. + * + * Only the boundary between the previous message and the final message matters, because + * earlier turns were already validated when they were appended; the sole freshly-built + * portion of a sampling request is its tail. + * + * Rules enforced: + * + * 1. If the last message contains any `tool_result` block, it MUST contain only + * `tool_result` blocks (no mixing with text/image/audio/tool_use). + * 2. If the last message contains any `tool_result`, the previous message MUST contain + * matching `tool_use` blocks. + * 3. If the previous message contains `tool_use` blocks, the last message's + * `tool_result` ids MUST form exactly the same set. + * + * On the first violation throws [IllegalArgumentException]. No-op when there are fewer + * than two messages or no tool_use / tool_result blocks are involved. + */ +internal fun validateSamplingMessages(messages: List) { + if (messages.isEmpty()) return + + val last = messages.last().content + val hasToolResult = last.any { it is ToolResultContent } + + val previous = messages.getOrNull(messages.size - 2)?.content.orEmpty() + val hasPreviousToolUse = previous.any { it is ToolUseContent } + + if (hasToolResult) { + require(last.all { it is ToolResultContent }) { + "The last message must contain only tool_result content if any is present" + } + require(hasPreviousToolUse) { + "tool_result blocks are not matching any tool_use from the previous message" + } + } + + if (hasPreviousToolUse) { + val toolUseIds = previous.filterIsInstance().map { it.id }.toSet() + val toolResultIds = last.filterIsInstance().map { it.toolUseId }.toSet() + require(toolUseIds == toolResultIds) { + "ids of tool_result blocks and tool_use blocks from previous message do not match" + } + } +} diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SamplingTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SamplingTest.kt new file mode 100644 index 000000000..2b5446bfa --- /dev/null +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SamplingTest.kt @@ -0,0 +1,86 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.modelcontextprotocol.kotlin.sdk.types.Role +import io.modelcontextprotocol.kotlin.sdk.types.SamplingMessage +import io.modelcontextprotocol.kotlin.sdk.types.TextContent +import io.modelcontextprotocol.kotlin.sdk.types.ToolResultContent +import io.modelcontextprotocol.kotlin.sdk.types.ToolUseContent +import kotlinx.serialization.json.JsonObject +import org.junit.jupiter.api.assertDoesNotThrow +import kotlin.test.Test +import kotlin.test.assertFailsWith + +/** + * Unit tests for [validateSamplingMessages]: pure-function checks of the SEP-1577 + * last-two-messages tool_use / tool_result rules. + * + * End-to-end and capability-enforcement tests for `Server.createMessage` live in the + * `integration-test` module under the same package. + */ +class SamplingTest { + + private fun toolUse(id: String) = ToolUseContent(id = id, name = "t", input = JsonObject(emptyMap())) + private fun toolResult(id: String) = ToolResultContent(toolUseId = id, content = emptyList()) + + @Test + fun `validate empty message list is valid`() { + assertDoesNotThrow { validateSamplingMessages(emptyList()) } + } + + @Test + fun `validate text-only conversation is valid`() { + assertDoesNotThrow { + validateSamplingMessages( + listOf( + SamplingMessage(Role.User, TextContent("hi")), + SamplingMessage(Role.Assistant, TextContent("hello")), + ), + ) + } + } + + @Test + fun `validate matched tool_use and tool_result at boundary is valid`() { + assertDoesNotThrow { + validateSamplingMessages( + listOf( + SamplingMessage(Role.Assistant, listOf(TextContent("using tool"), toolUse("c1"))), + SamplingMessage(Role.User, toolResult("c1")), + ), + ) + } + } + + @Test + fun `validate orphan tool_result with no previous message fails`() { + assertFailsWith { + validateSamplingMessages( + listOf(SamplingMessage(Role.User, toolResult("missing"))), + ) + } + } + + @Test + fun `validate tool_result mixed with text in last message fails`() { + assertFailsWith { + validateSamplingMessages( + listOf( + SamplingMessage(Role.Assistant, toolUse("c1")), + SamplingMessage(Role.User, listOf(toolResult("c1"), TextContent("extra"))), + ), + ) + } + } + + @Test + fun `validate tool_result ids must match tool_use ids in previous message`() { + assertFailsWith { + validateSamplingMessages( + listOf( + SamplingMessage(Role.Assistant, toolUse("c1")), + SamplingMessage(Role.User, toolResult("wrong_id")), + ), + ) + } + } +}