diff --git a/dotnet/test/Harness/TestHelper.cs b/dotnet/test/Harness/TestHelper.cs index af7ebe9a..6dd919bc 100644 --- a/dotnet/test/Harness/TestHelper.cs +++ b/dotnet/test/Harness/TestHelper.cs @@ -73,4 +73,29 @@ async void CheckExistingMessages() return null; } + + public static async Task GetNextEventOfTypeAsync( + CopilotSession session, + TimeSpan? timeout = null) where T : SessionEvent + { + var tcs = new TaskCompletionSource(); + using var cts = new CancellationTokenSource(timeout ?? TimeSpan.FromSeconds(60)); + + using var subscription = session.On(evt => + { + if (evt is T matched) + { + tcs.TrySetResult(matched); + } + else if (evt is SessionErrorEvent error) + { + tcs.TrySetException(new Exception(error.Data.Message ?? "session error")); + } + }); + + cts.Token.Register(() => tcs.TrySetException( + new TimeoutException($"Timeout waiting for event of type '{typeof(T).Name}'"))); + + return await tcs.Task; + } } diff --git a/dotnet/test/SessionTests.cs b/dotnet/test/SessionTests.cs index 2e1119f5..a8be3741 100644 --- a/dotnet/test/SessionTests.cs +++ b/dotnet/test/SessionTests.cs @@ -201,23 +201,32 @@ public async Task Should_Abort_A_Session() { var session = await Client.CreateSessionAsync(); + // Set up wait for tool execution to start BEFORE sending + var toolStartTask = TestHelper.GetNextEventOfTypeAsync(session); + var sessionIdleTask = TestHelper.GetNextEventOfTypeAsync(session); + // Send a message that will take some time to process - await session.SendAsync(new MessageOptions { Prompt = "What is 1+1?" }); + await session.SendAsync(new MessageOptions + { + Prompt = "run the shell command 'sleep 100' (note this works on both bash and PowerShell)" + }); + + // Wait for tool execution to start + await toolStartTask; - // Abort the session immediately + // Abort the session await session.AbortAsync(); + await sessionIdleTask; // The session should still be alive and usable after abort var messages = await session.GetMessagesAsync(); Assert.NotEmpty(messages); - // TODO: We should do something to verify it really did abort (e.g., is there an abort event we can see, - // or can we check that the session became idle without receiving an assistant message?). Right now - // I'm not seeing any evidence that it actually does abort. + // Verify an abort event exists in messages + Assert.Contains(messages, m => m is AbortEvent); // We should be able to send another message - await session.SendAsync(new MessageOptions { Prompt = "What is 2+2?" }); - var answer = await TestHelper.GetFinalAssistantMessageAsync(session); + var answer = await session.SendAndWaitAsync(new MessageOptions { Prompt = "What is 2+2?" }); Assert.NotNull(answer); Assert.Contains("4", answer!.Data.Content ?? string.Empty); } diff --git a/go/client.go b/go/client.go index ca06335d..267b84b9 100644 --- a/go/client.go +++ b/go/client.go @@ -39,8 +39,6 @@ import ( "strings" "sync" "time" - - "github.com/github/copilot-sdk/go/generated" ) // Client manages the connection to the Copilot CLI server and provides session management. @@ -923,7 +921,7 @@ func (c *Client) setupNotificationHandler() { return } - event, err := generated.UnmarshalSessionEvent(eventJSON) + event, err := UnmarshalSessionEvent(eventJSON) if err != nil { return } diff --git a/go/e2e/session_test.go b/go/e2e/session_test.go index 3de45eb5..adcbbc66 100644 --- a/go/e2e/session_test.go +++ b/go/e2e/session_test.go @@ -472,18 +472,57 @@ func TestSession(t *testing.T) { t.Fatalf("Failed to create session: %v", err) } - // Send a message that will take some time to process - _, err = session.Send(copilot.MessageOptions{Prompt: "What is 1+1?"}) + // Set up event listeners BEFORE sending to avoid race conditions + toolStartCh := make(chan *copilot.SessionEvent, 1) + toolStartErrCh := make(chan error, 1) + go func() { + evt, err := testharness.GetNextEventOfType(session, copilot.ToolExecutionStart, 60*time.Second) + if err != nil { + toolStartErrCh <- err + } else { + toolStartCh <- evt + } + }() + + sessionIdleCh := make(chan *copilot.SessionEvent, 1) + sessionIdleErrCh := make(chan error, 1) + go func() { + evt, err := testharness.GetNextEventOfType(session, copilot.SessionIdle, 60*time.Second) + if err != nil { + sessionIdleErrCh <- err + } else { + sessionIdleCh <- evt + } + }() + + // Send a message that triggers a long-running shell command + _, err = session.Send(copilot.MessageOptions{Prompt: "run the shell command 'sleep 100' (note this works on both bash and PowerShell)"}) if err != nil { t.Fatalf("Failed to send message: %v", err) } - // Abort the session immediately + // Wait for tool.execution_start + select { + case <-toolStartCh: + // Tool execution has started + case err := <-toolStartErrCh: + t.Fatalf("Failed waiting for tool.execution_start: %v", err) + } + + // Abort the session err = session.Abort() if err != nil { t.Fatalf("Failed to abort session: %v", err) } + // Wait for session.idle after abort + select { + case <-sessionIdleCh: + // Session is idle + case err := <-sessionIdleErrCh: + t.Fatalf("Failed waiting for session.idle after abort: %v", err) + } + // The session should still be alive and usable after abort messages, err := session.GetMessages() if err != nil { @@ -493,15 +532,22 @@ func TestSession(t *testing.T) { t.Error("Expected messages to exist after abort") } - // We should be able to send another message - _, err = session.Send(copilot.MessageOptions{Prompt: "What is 2+2?"}) - if err != nil { - t.Fatalf("Failed to send message after abort: %v", err) + // Verify messages contain an abort event + hasAbortEvent := false + for _, msg := range messages { + if msg.Type == copilot.Abort { + hasAbortEvent = true + break + } + } + if !hasAbortEvent { + t.Error("Expected messages to contain an 'abort' event") } - answer, err := testharness.GetFinalAssistantMessage(session, 60*time.Second) + // We should be able to send another message + answer, err := session.SendAndWait(copilot.MessageOptions{Prompt: "What is 2+2?"}, 60*time.Second) if err != nil { - t.Fatalf("Failed to get assistant message after abort: %v", err) + t.Fatalf("Failed to send message after abort: %v", err) } if answer.Data.Content == nil || !strings.Contains(*answer.Data.Content, "4") { diff --git a/go/e2e/testharness/helper.go b/go/e2e/testharness/helper.go index 2edaf61a..b75dd6e2 100644 --- a/go/e2e/testharness/helper.go +++ b/go/e2e/testharness/helper.go @@ -54,6 +54,41 @@ func GetFinalAssistantMessage(session *copilot.Session, timeout time.Duration) ( } } +// GetNextEventOfType waits for and returns the next event of the specified type from a session. +func GetNextEventOfType(session *copilot.Session, eventType copilot.SessionEventType, timeout time.Duration) (*copilot.SessionEvent, error) { + result := make(chan *copilot.SessionEvent, 1) + errCh := make(chan error, 1) + + unsubscribe := session.On(func(event copilot.SessionEvent) { + switch event.Type { + case eventType: + select { + case result <- &event: + default: + } + case copilot.SessionError: + msg := "session error" + if event.Data.Message != nil { + msg = *event.Data.Message + } + select { + case errCh <- errors.New(msg): + default: + } + } + }) + defer unsubscribe() + + select { + case evt := <-result: + return evt, nil + case err := <-errCh: + return nil, err + case <-time.After(timeout): + return nil, errors.New("timeout waiting for event: " + string(eventType)) + } +} + func getExistingFinalResponse(session *copilot.Session) (*copilot.SessionEvent, error) { messages, err := session.GetMessages() if err != nil { diff --git a/go/generated/session_events.go b/go/generated_session_events.go similarity index 99% rename from go/generated/session_events.go rename to go/generated_session_events.go index 6445846b..80bd1dc1 100644 --- a/go/generated/session_events.go +++ b/go/generated_session_events.go @@ -2,7 +2,7 @@ // // Generated from: @github/copilot/session-events.schema.json // Generated by: scripts/generate-session-types.ts -// Generated at: 2026-01-20T04:18:06.667Z +// Generated at: 2026-01-20T12:53:00.653Z // // To update these types: // 1. Update the schema in copilot-agent-runtime @@ -14,7 +14,7 @@ // sessionEvent, err := UnmarshalSessionEvent(bytes) // bytes, err = sessionEvent.Marshal() -package generated +package copilot import "bytes" import "errors" diff --git a/go/session.go b/go/session.go index 36685b5b..ddafb96e 100644 --- a/go/session.go +++ b/go/session.go @@ -6,8 +6,6 @@ import ( "fmt" "sync" "time" - - "github.com/github/copilot-sdk/go/generated" ) type sessionHandler struct { @@ -159,17 +157,17 @@ func (s *Session) SendAndWait(options MessageOptions, timeout time.Duration) (*S unsubscribe := s.On(func(event SessionEvent) { switch event.Type { - case generated.AssistantMessage: + case AssistantMessage: mu.Lock() eventCopy := event lastAssistantMessage = &eventCopy mu.Unlock() - case generated.SessionIdle: + case SessionIdle: select { case idleCh <- struct{}{}: default: } - case generated.SessionError: + case SessionError: errMsg := "session error" if event.Data.Message != nil { errMsg = *event.Data.Message @@ -387,7 +385,7 @@ func (s *Session) GetMessages() ([]SessionEvent, error) { continue } - event, err := generated.UnmarshalSessionEvent(eventJSON) + event, err := UnmarshalSessionEvent(eventJSON) if err != nil { continue } diff --git a/go/session_test.go b/go/session_test.go new file mode 100644 index 00000000..40874a65 --- /dev/null +++ b/go/session_test.go @@ -0,0 +1,121 @@ +package copilot + +import ( + "sync" + "testing" +) + +func TestSession_On(t *testing.T) { + t.Run("multiple handlers all receive events", func(t *testing.T) { + session := &Session{ + handlers: make([]sessionHandler, 0), + } + + var received1, received2, received3 bool + session.On(func(event SessionEvent) { received1 = true }) + session.On(func(event SessionEvent) { received2 = true }) + session.On(func(event SessionEvent) { received3 = true }) + + session.dispatchEvent(SessionEvent{Type: "test"}) + + if !received1 || !received2 || !received3 { + t.Errorf("Expected all handlers to receive event, got received1=%v, received2=%v, received3=%v", + received1, received2, received3) + } + }) + + t.Run("unsubscribing one handler does not affect others", func(t *testing.T) { + session := &Session{ + handlers: make([]sessionHandler, 0), + } + + var count1, count2, count3 int + session.On(func(event SessionEvent) { count1++ }) + unsub2 := session.On(func(event SessionEvent) { count2++ }) + session.On(func(event SessionEvent) { count3++ }) + + // First event - all handlers receive it + session.dispatchEvent(SessionEvent{Type: "test"}) + + // Unsubscribe handler 2 + unsub2() + + // Second event - only handlers 1 and 3 should receive it + session.dispatchEvent(SessionEvent{Type: "test"}) + + if count1 != 2 { + t.Errorf("Expected handler 1 to receive 2 events, got %d", count1) + } + if count2 != 1 { + t.Errorf("Expected handler 2 to receive 1 event (before unsubscribe), got %d", count2) + } + if count3 != 2 { + t.Errorf("Expected handler 3 to receive 2 events, got %d", count3) + } + }) + + t.Run("calling unsubscribe multiple times is safe", func(t *testing.T) { + session := &Session{ + handlers: make([]sessionHandler, 0), + } + + var count int + unsub := session.On(func(event SessionEvent) { count++ }) + + session.dispatchEvent(SessionEvent{Type: "test"}) + + // Call unsubscribe multiple times - should not panic + unsub() + unsub() + unsub() + + session.dispatchEvent(SessionEvent{Type: "test"}) + + if count != 1 { + t.Errorf("Expected handler to receive 1 event, got %d", count) + } + }) + + t.Run("handlers are called in registration order", func(t *testing.T) { + session := &Session{ + handlers: make([]sessionHandler, 0), + } + + var order []int + session.On(func(event SessionEvent) { order = append(order, 1) }) + session.On(func(event SessionEvent) { order = append(order, 2) }) + session.On(func(event SessionEvent) { order = append(order, 3) }) + + session.dispatchEvent(SessionEvent{Type: "test"}) + + if len(order) != 3 || order[0] != 1 || order[1] != 2 || order[2] != 3 { + t.Errorf("Expected handlers to be called in order [1,2,3], got %v", order) + } + }) + + t.Run("concurrent subscribe and unsubscribe is safe", func(t *testing.T) { + session := &Session{ + handlers: make([]sessionHandler, 0), + } + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + unsub := session.On(func(event SessionEvent) {}) + unsub() + }() + } + wg.Wait() + + // Should not panic and handlers should be empty + session.handlerMutex.RLock() + count := len(session.handlers) + session.handlerMutex.RUnlock() + + if count != 0 { + t.Errorf("Expected 0 handlers after all unsubscribes, got %d", count) + } + }) +} diff --git a/go/types.go b/go/types.go index d4883206..72f4959b 100644 --- a/go/types.go +++ b/go/types.go @@ -1,11 +1,5 @@ package copilot -import ( - "github.com/github/copilot-sdk/go/generated" -) - -type SessionEvent = generated.SessionEvent - // ConnectionState represents the client connection state type ConnectionState string @@ -258,13 +252,6 @@ type MessageOptions struct { Mode string } -// Attachment represents a file or directory attachment -type Attachment struct { - Type string `json:"type"` // "file" or "directory" - Path string `json:"path"` - DisplayName string `json:"displayName,omitempty"` -} - // SessionEventHandler is a callback for session events type SessionEventHandler func(event SessionEvent) diff --git a/nodejs/scripts/generate-session-types.ts b/nodejs/scripts/generate-session-types.ts index faeb24f7..961d6bae 100644 --- a/nodejs/scripts/generate-session-types.ts +++ b/nodejs/scripts/generate-session-types.ts @@ -272,7 +272,7 @@ async function generateGoTypes(schemaPath: string) { inputData, lang: "go", rendererOptions: { - package: "generated", + package: "copilot", }, }); @@ -289,7 +289,7 @@ async function generateGoTypes(schemaPath: string) { `; - const outputPath = path.join(__dirname, "../../go/generated/session_events.go"); + const outputPath = path.join(__dirname, "../../go/generated_session_events.go"); await fs.mkdir(path.dirname(outputPath), { recursive: true }); await fs.writeFile(outputPath, banner + generatedCode, "utf-8"); diff --git a/nodejs/test/e2e/harness/sdkTestHelper.ts b/nodejs/test/e2e/harness/sdkTestHelper.ts index 03414a7f..4e8ff203 100644 --- a/nodejs/test/e2e/harness/sdkTestHelper.ts +++ b/nodejs/test/e2e/harness/sdkTestHelper.ts @@ -2,8 +2,7 @@ * Copyright (c) Microsoft Corporation. All rights reserved. *--------------------------------------------------------------------------------------------*/ -import { AssistantMessageEvent } from "@github/copilot/sdk"; -import { CopilotSession } from "../../../src"; +import { AssistantMessageEvent, CopilotSession, SessionEvent } from "../../../src"; export async function getFinalAssistantMessage( session: CopilotSession @@ -54,13 +53,19 @@ function getExistingFinalResponse( } function getFutureFinalResponse(session: CopilotSession): Promise { - return new Promise((resolve, reject) => { + return new Promise((resolve, reject) => { let finalAssistantMessage: AssistantMessageEvent | undefined; session.on((event) => { if (event.type === "assistant.message") { finalAssistantMessage = event; } else if (event.type === "session.idle") { - resolve(finalAssistantMessage); + if (!finalAssistantMessage) { + reject( + new Error("Received session.idle without a preceding assistant.message") + ); + } else { + resolve(finalAssistantMessage); + } } else if (event.type === "session.error") { const error = new Error(event.data.message); error.stack = event.data.stack; @@ -106,3 +111,20 @@ export function formatError(error: unknown): string { return String(error); } } + +export function getNextEventOfType( + session: CopilotSession, + eventType: SessionEvent["type"] +): Promise { + return new Promise((resolve, reject) => { + const unsubscribe = session.on((event) => { + if (event.type === eventType) { + unsubscribe(); + resolve(event); + } else if (event.type === "session.error") { + unsubscribe(); + reject(new Error(`${event.data.message}\n${event.data.stack}`)); + } + }); + }); +} diff --git a/nodejs/test/e2e/session.test.ts b/nodejs/test/e2e/session.test.ts index 6779b004..45b1bd42 100644 --- a/nodejs/test/e2e/session.test.ts +++ b/nodejs/test/e2e/session.test.ts @@ -2,7 +2,7 @@ import { describe, expect, it, onTestFinished } from "vitest"; import { ParsedHttpExchange } from "../../../test/harness/replayingCapiProxy.js"; import { CopilotClient } from "../../src/index.js"; import { CLI_PATH, createSdkTestContext } from "./harness/sdkTestContext.js"; -import { getFinalAssistantMessage } from "./harness/sdkTestHelper.js"; +import { getFinalAssistantMessage, getNextEventOfType } from "./harness/sdkTestHelper.js"; describe("Sessions", async () => { const { copilotClient: client, openAiEndpoint, homeDir } = await createSdkTestContext(); @@ -230,15 +230,23 @@ describe("Sessions", async () => { it("should abort a session", async () => { const session = await client.createSession(); - // Send a message that will take some time to process - await session.sendAndWait({ prompt: "What is 1+1?" }); + // Set up event listeners BEFORE sending to avoid race conditions + const nextToolCallStart = getNextEventOfType(session, "tool.execution_start"); + const nextSessionIdle = getNextEventOfType(session, "session.idle"); + + await session.send({ + prompt: "run the shell command 'sleep 100' (note this works on both bash and PowerShell)", + }); - // Abort the session immediately + // Abort once we see a tool execution start + await nextToolCallStart; await session.abort(); + await nextSessionIdle; // The session should still be alive and usable after abort const messages = await session.getMessages(); expect(messages.length).toBeGreaterThan(0); + expect(messages.some((m) => m.type === "abort")).toBe(true); // We should be able to send another message const answer = await session.sendAndWait({ prompt: "What is 2+2?" }); diff --git a/python/e2e/test_session.py b/python/e2e/test_session.py index e54465e1..ad2704fe 100644 --- a/python/e2e/test_session.py +++ b/python/e2e/test_session.py @@ -5,7 +5,7 @@ from copilot import CopilotClient from copilot.types import Tool -from .testharness import E2ETestContext, get_final_assistant_message +from .testharness import E2ETestContext, get_final_assistant_message, get_next_event_of_type pytestmark = pytest.mark.asyncio(loop_scope="module") @@ -256,21 +256,42 @@ async def test_should_resume_session_with_custom_provider(self, ctx: E2ETestCont assert session2.session_id == session_id async def test_should_abort_a_session(self, ctx: E2ETestContext): + import asyncio + session = await ctx.client.create_session() - # Send a message that will take some time to process - await session.send({"prompt": "What is 1+1?"}) + # Set up event listeners BEFORE sending to avoid race conditions + wait_for_tool_start = asyncio.create_task( + get_next_event_of_type(session, "tool.execution_start", timeout=60.0) + ) + wait_for_session_idle = asyncio.create_task( + get_next_event_of_type(session, "session.idle", timeout=30.0) + ) - # Abort the session immediately + # Send a message that will trigger a long-running shell command + await session.send( + {"prompt": "run the shell command 'sleep 100' (works on bash and PowerShell)"} + ) + + # Wait for the tool to start executing + _ = await wait_for_tool_start + + # Abort the session while the tool is running await session.abort() + # Wait for session to become idle after abort + _ = await wait_for_session_idle + # The session should still be alive and usable after abort messages = await session.get_messages() assert len(messages) > 0 + # Verify an abort event exists in messages + abort_events = [m for m in messages if m.type.value == "abort"] + assert len(abort_events) > 0, "Expected an abort event in messages" + # We should be able to send another message - await session.send({"prompt": "What is 2+2?"}) - answer = await get_final_assistant_message(session) + answer = await session.send_and_wait({"prompt": "What is 2+2?"}) assert "4" in answer.data.content async def test_should_receive_streaming_delta_events_when_streaming_is_enabled( diff --git a/python/e2e/testharness/__init__.py b/python/e2e/testharness/__init__.py index 2a711fc4..58a36028 100644 --- a/python/e2e/testharness/__init__.py +++ b/python/e2e/testharness/__init__.py @@ -1,7 +1,13 @@ """Test harness for E2E tests.""" from .context import CLI_PATH, E2ETestContext -from .helper import get_final_assistant_message +from .helper import get_final_assistant_message, get_next_event_of_type from .proxy import CapiProxy -__all__ = ["CLI_PATH", "E2ETestContext", "CapiProxy", "get_final_assistant_message"] +__all__ = [ + "CLI_PATH", + "E2ETestContext", + "CapiProxy", + "get_final_assistant_message", + "get_next_event_of_type", +] diff --git a/python/e2e/testharness/helper.py b/python/e2e/testharness/helper.py index 2111846d..85f1427f 100644 --- a/python/e2e/testharness/helper.py +++ b/python/e2e/testharness/helper.py @@ -125,3 +125,39 @@ def read_file(work_dir: str, filename: str) -> str: filepath = os.path.join(work_dir, filename) with open(filepath) as f: return f.read() + + +async def get_next_event_of_type(session: CopilotSession, event_type: str, timeout: float = 30.0): + """ + Wait for and return the next event of a specific type from a session. + + Args: + session: The session to wait on + event_type: The event type to wait for (e.g., "tool.execution_start", "session.idle") + timeout: Maximum time to wait in seconds + + Returns: + The matching event + + Raises: + TimeoutError: If no matching event arrives within timeout + RuntimeError: If a session error occurs + """ + result_future: asyncio.Future = asyncio.get_event_loop().create_future() + + def on_event(event): + if result_future.done(): + return + + if event.type.value == event_type: + result_future.set_result(event) + elif event.type.value == "session.error": + msg = event.data.message if event.data.message else "session error" + result_future.set_exception(RuntimeError(msg)) + + unsubscribe = session.on(on_event) + + try: + return await asyncio.wait_for(result_future, timeout=timeout) + finally: + unsubscribe() diff --git a/test/snapshots/session/should_abort_a_session.yaml b/test/snapshots/session/should_abort_a_session.yaml index de6c928f..70685dd6 100644 --- a/test/snapshots/session/should_abort_a_session.yaml +++ b/test/snapshots/session/should_abort_a_session.yaml @@ -5,11 +5,45 @@ conversations: - role: system content: ${system} - role: user - content: What is 1+1? + content: run the shell command 'sleep 100' (note this works on both bash and PowerShell) + - role: assistant + tool_calls: + - id: toolcall_0 + type: function + function: + name: report_intent + arguments: '{"intent":"Running sleep command"}' + - role: assistant + tool_calls: + - id: toolcall_1 + type: function + function: + name: ${shell} + arguments: '{"command":"sleep 100","description":"Run sleep command for 100 seconds","initial_wait":105,"mode":"sync"}' + - messages: + - role: system + content: ${system} + - role: user + content: run the shell command 'sleep 100' (note this works on both bash and PowerShell) + - role: assistant + tool_calls: + - id: toolcall_0 + type: function + function: + name: report_intent + arguments: '{"intent":"Running sleep command"}' + - id: toolcall_1 + type: function + function: + name: ${shell} + arguments: '{"command":"sleep 100","description":"Run sleep command for 100 seconds","initial_wait":105,"mode":"sync"}' + - role: tool + tool_call_id: toolcall_0 + content: Intent logged + - role: tool + tool_call_id: toolcall_1 + content: The execution of this tool, or a previous tool was interrupted. - role: user content: What is 2+2? - role: assistant - content: |- - 1+1 = 2 - - 2+2 = 4 + content: 2+2 equals 4.