Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,15 @@ import io.kotest.matchers.string.shouldContain
import io.modelcontextprotocol.kotlin.sdk.types.GetPromptRequest
import io.modelcontextprotocol.kotlin.sdk.types.GetPromptRequestParams
import io.modelcontextprotocol.kotlin.sdk.types.GetPromptResult
import io.modelcontextprotocol.kotlin.sdk.types.ListPromptsRequest
import io.modelcontextprotocol.kotlin.sdk.types.ListPromptsResult
import io.modelcontextprotocol.kotlin.sdk.types.McpException
import io.modelcontextprotocol.kotlin.sdk.types.Method
import io.modelcontextprotocol.kotlin.sdk.types.PaginatedRequestParams
import io.modelcontextprotocol.kotlin.sdk.types.Prompt
import io.modelcontextprotocol.kotlin.sdk.types.PromptArgument
import io.modelcontextprotocol.kotlin.sdk.types.PromptMessage
import io.modelcontextprotocol.kotlin.sdk.types.RPCError
import io.modelcontextprotocol.kotlin.sdk.types.Role
import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities
import io.modelcontextprotocol.kotlin.sdk.types.TextContent
Expand All @@ -18,7 +24,9 @@ import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import java.util.concurrent.CopyOnWriteArrayList
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertNotNull
import kotlin.test.assertTrue

Expand Down Expand Up @@ -697,4 +705,69 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() {
exception.message shouldBe expectedMessage
}
}

@Test
fun testListPromptsPagination() = runBlocking(Dispatchers.IO) {
val receivedCursors = CopyOnWriteArrayList<String?>()
val page1 = listOf(Prompt(name = "p-1", description = "desc"), Prompt(name = "p-2", description = "desc"))
val page2 = listOf(Prompt(name = "p-3", description = "desc"), Prompt(name = "p-4", description = "desc"))
val page3 = listOf(Prompt(name = "p-5", description = "desc"))

server.sessions.forEach { (_, session) ->
session.setRequestHandler<ListPromptsRequest>(Method.Defined.PromptsList) { request, _ ->
receivedCursors += request.cursor
when (request.cursor) {
null -> ListPromptsResult(prompts = page1, nextCursor = "cursor-2")
"cursor-2" -> ListPromptsResult(prompts = page2, nextCursor = "cursor-3")
"cursor-3" -> ListPromptsResult(prompts = page3, nextCursor = null)
else -> error("Unexpected cursor: ${request.cursor}")
}
}
Comment on lines +717 to +725
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not entirely clear to me what exactly is being verified in such tests, since the test effectively reproduces the same logic that it is supposed to validate

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point — the original test was just re-running the same drop/take logic the client was supposed to be exercising. Refactored to a hardcoded when(cursor) handler that returns three pre-built pages with explicit cursor-2 / cursor-3 tokens, and the test now records every cursor the client sends and asserts the full sequence ([null, cursor-2, cursor-3]) plus the exact accumulated list. The server side is no longer the logic under test.

}

val collected = mutableListOf<Prompt>()
var cursor: String? = null
do {
val request = if (cursor == null) {
ListPromptsRequest()
} else {
ListPromptsRequest(PaginatedRequestParams(cursor = cursor))
}
val response = client.listPrompts(request)
collected += response.prompts
cursor = response.nextCursor
} while (cursor != null)

assertEquals(
listOf(null, "cursor-2", "cursor-3"),
receivedCursors.toList(),
"Client must forward each nextCursor into the next request",
)
assertEquals(
listOf("p-1", "p-2", "p-3", "p-4", "p-5"),
collected.map { it.name },
"Client must accumulate pages in order without duplicates or reordering",
)
}

@Test
fun testListPromptsInvalidCursor() = runBlocking(Dispatchers.IO) {
val invalidCursor = "not-a-valid-cursor"

server.sessions.forEach { (_, session) ->
session.setRequestHandler<ListPromptsRequest>(Method.Defined.PromptsList) { _, _ ->
throw McpException(
code = RPCError.ErrorCode.INVALID_PARAMS,
message = "Invalid cursor: $invalidCursor",
)
}
}

val exception = assertFailsWith<McpException> {
client.listPrompts(ListPromptsRequest(PaginatedRequestParams(cursor = invalidCursor)))
}

assertEquals(RPCError.ErrorCode.INVALID_PARAMS, exception.code)
assertEquals("Invalid cursor: $invalidCursor", exception.message)
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
package io.modelcontextprotocol.kotlin.sdk.integration.kotlin

import io.modelcontextprotocol.kotlin.sdk.types.BlobResourceContents
import io.modelcontextprotocol.kotlin.sdk.types.ListResourcesRequest
import io.modelcontextprotocol.kotlin.sdk.types.ListResourcesResult
import io.modelcontextprotocol.kotlin.sdk.types.McpException
import io.modelcontextprotocol.kotlin.sdk.types.Method
import io.modelcontextprotocol.kotlin.sdk.types.PaginatedRequestParams
import io.modelcontextprotocol.kotlin.sdk.types.RPCError
import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequest
import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequestParams
import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceResult
import io.modelcontextprotocol.kotlin.sdk.types.Resource
import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities
import io.modelcontextprotocol.kotlin.sdk.types.SubscribeRequest
import io.modelcontextprotocol.kotlin.sdk.types.SubscribeRequestParams
Expand All @@ -18,8 +23,10 @@ import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import java.util.concurrent.CopyOnWriteArrayList
import java.util.concurrent.atomic.AtomicBoolean
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertNotNull
import kotlin.test.assertTrue

Expand Down Expand Up @@ -309,4 +316,76 @@ abstract class AbstractResourceIntegrationTest : KotlinTestBase() {
assertTrue(result.contents.isNotEmpty(), "Result contents should not be empty")
}
}

@Test
fun testListResourcesPagination() = runBlocking(Dispatchers.IO) {
val receivedCursors = CopyOnWriteArrayList<String?>()
val page1 = listOf(
Resource(uri = "test://r-1", name = "r-1"),
Resource(uri = "test://r-2", name = "r-2"),
Resource(uri = "test://r-3", name = "r-3"),
)
val page2 = listOf(
Resource(uri = "test://r-4", name = "r-4"),
Resource(uri = "test://r-5", name = "r-5"),
)
val page3 = listOf(Resource(uri = "test://r-6", name = "r-6"))

server.sessions.forEach { (_, session) ->
session.setRequestHandler<ListResourcesRequest>(Method.Defined.ResourcesList) { request, _ ->
receivedCursors += request.cursor
when (request.cursor) {
null -> ListResourcesResult(resources = page1, nextCursor = "cursor-2")
"cursor-2" -> ListResourcesResult(resources = page2, nextCursor = "cursor-3")
"cursor-3" -> ListResourcesResult(resources = page3, nextCursor = null)
else -> error("Unexpected cursor: ${request.cursor}")
}
}
}

val collected = mutableListOf<Resource>()
var cursor: String? = null
do {
val request = if (cursor == null) {
ListResourcesRequest()
} else {
ListResourcesRequest(PaginatedRequestParams(cursor = cursor))
}
val response = client.listResources(request)
collected += response.resources
cursor = response.nextCursor
} while (cursor != null)

assertEquals(
listOf(null, "cursor-2", "cursor-3"),
receivedCursors.toList(),
"Client must forward each nextCursor into the next request",
)
assertEquals(
listOf("test://r-1", "test://r-2", "test://r-3", "test://r-4", "test://r-5", "test://r-6"),
collected.map { it.uri },
"Client must accumulate pages in order without duplicates or reordering",
)
}

@Test
fun testListResourcesInvalidCursor() = runBlocking(Dispatchers.IO) {
val invalidCursor = "not-a-valid-cursor"

server.sessions.forEach { (_, session) ->
session.setRequestHandler<ListResourcesRequest>(Method.Defined.ResourcesList) { _, _ ->
throw McpException(
code = RPCError.ErrorCode.INVALID_PARAMS,
message = "Invalid cursor: $invalidCursor",
)
}
}

val exception = assertFailsWith<McpException> {
client.listResources(ListResourcesRequest(PaginatedRequestParams(cursor = invalidCursor)))
}

assertEquals(RPCError.ErrorCode.INVALID_PARAMS, exception.code)
assertEquals("Invalid cursor: $invalidCursor", exception.message)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,15 @@ import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequestParams
import io.modelcontextprotocol.kotlin.sdk.types.CallToolResult
import io.modelcontextprotocol.kotlin.sdk.types.ContentBlock
import io.modelcontextprotocol.kotlin.sdk.types.ImageContent
import io.modelcontextprotocol.kotlin.sdk.types.ListToolsRequest
import io.modelcontextprotocol.kotlin.sdk.types.ListToolsResult
import io.modelcontextprotocol.kotlin.sdk.types.McpException
import io.modelcontextprotocol.kotlin.sdk.types.Method
import io.modelcontextprotocol.kotlin.sdk.types.PaginatedRequestParams
import io.modelcontextprotocol.kotlin.sdk.types.RPCError
import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities
import io.modelcontextprotocol.kotlin.sdk.types.TextContent
import io.modelcontextprotocol.kotlin.sdk.types.Tool
import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.delay
Expand All @@ -24,7 +31,9 @@ import org.junit.jupiter.api.Test
import java.text.DecimalFormat
import java.text.DecimalFormatSymbols
import java.util.Locale
import java.util.concurrent.CopyOnWriteArrayList
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertNotNull
import kotlin.test.assertTrue

Expand Down Expand Up @@ -791,4 +800,75 @@ abstract class AbstractToolIntegrationTest : KotlinTestBase() {
"Error message should indicate the tool was not found",
)
}

@Test
fun testListToolsPagination() = runBlocking(Dispatchers.IO) {
val receivedCursors = CopyOnWriteArrayList<String?>()
val page1 = listOf(
Tool(name = "t-1", inputSchema = ToolSchema()),
Tool(name = "t-2", inputSchema = ToolSchema()),
)
val page2 = listOf(
Tool(name = "t-3", inputSchema = ToolSchema()),
Tool(name = "t-4", inputSchema = ToolSchema()),
)
val page3 = listOf(Tool(name = "t-5", inputSchema = ToolSchema()))

server.sessions.forEach { (_, session) ->
session.setRequestHandler<ListToolsRequest>(Method.Defined.ToolsList) { request, _ ->
receivedCursors += request.cursor
when (request.cursor) {
null -> ListToolsResult(tools = page1, nextCursor = "cursor-2")
"cursor-2" -> ListToolsResult(tools = page2, nextCursor = "cursor-3")
"cursor-3" -> ListToolsResult(tools = page3, nextCursor = null)
else -> error("Unexpected cursor: ${request.cursor}")
}
}
}

val collected = mutableListOf<Tool>()
var cursor: String? = null
do {
val request = if (cursor == null) {
ListToolsRequest()
} else {
ListToolsRequest(PaginatedRequestParams(cursor = cursor))
}
val response = client.listTools(request)
collected += response.tools
cursor = response.nextCursor
} while (cursor != null)

assertEquals(
listOf(null, "cursor-2", "cursor-3"),
receivedCursors.toList(),
"Client must forward each nextCursor into the next request",
)
assertEquals(
listOf("t-1", "t-2", "t-3", "t-4", "t-5"),
collected.map { it.name },
"Client must accumulate pages in order without duplicates or reordering",
)
}

@Test
fun testListToolsInvalidCursor() = runBlocking(Dispatchers.IO) {
val invalidCursor = "not-a-valid-cursor"

server.sessions.forEach { (_, session) ->
session.setRequestHandler<ListToolsRequest>(Method.Defined.ToolsList) { _, _ ->
throw McpException(
code = RPCError.ErrorCode.INVALID_PARAMS,
message = "Invalid cursor: $invalidCursor",
)
}
}

val exception = assertFailsWith<McpException> {
client.listTools(ListToolsRequest(PaginatedRequestParams(cursor = invalidCursor)))
}

assertEquals(RPCError.ErrorCode.INVALID_PARAMS, exception.code)
assertEquals("Invalid cursor: $invalidCursor", exception.message)
}
}
Loading