Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions dotnet/test/Harness/TestHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,29 @@ async void CheckExistingMessages()

return null;
}

public static async Task<T> GetNextEventOfTypeAsync<T>(
CopilotSession session,
TimeSpan? timeout = null) where T : SessionEvent
{
var tcs = new TaskCompletionSource<T>();
using var cts = new CancellationTokenSource(timeout ?? TimeSpan.FromSeconds(60));

using var subscription = session.On(evt =>
{
if (evt is T matched)
{
tcs.TrySetResult(matched);
}
else if (evt is SessionErrorEvent error)
{
tcs.TrySetException(new Exception(error.Data.Message ?? "session error"));
}
});

cts.Token.Register(() => tcs.TrySetException(
new TimeoutException($"Timeout waiting for event of type '{typeof(T).Name}'")));

return await tcs.Task;
}
}
23 changes: 16 additions & 7 deletions dotnet/test/SessionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -201,23 +201,32 @@ public async Task Should_Abort_A_Session()
{
var session = await Client.CreateSessionAsync();

// Set up wait for tool execution to start BEFORE sending
var toolStartTask = TestHelper.GetNextEventOfTypeAsync<ToolExecutionStartEvent>(session);
var sessionIdleTask = TestHelper.GetNextEventOfTypeAsync<SessionIdleEvent>(session);

// Send a message that will take some time to process
await session.SendAsync(new MessageOptions { Prompt = "What is 1+1?" });
await session.SendAsync(new MessageOptions
{
Prompt = "run the shell command 'sleep 100' (note this works on both bash and PowerShell)"
});

// Wait for tool execution to start
await toolStartTask;

// Abort the session immediately
// Abort the session
await session.AbortAsync();
await sessionIdleTask;

// The session should still be alive and usable after abort
var messages = await session.GetMessagesAsync();
Assert.NotEmpty(messages);

// TODO: We should do something to verify it really did abort (e.g., is there an abort event we can see,
// or can we check that the session became idle without receiving an assistant message?). Right now
// I'm not seeing any evidence that it actually does abort.
// Verify an abort event exists in messages
Assert.Contains(messages, m => m is AbortEvent);

// We should be able to send another message
await session.SendAsync(new MessageOptions { Prompt = "What is 2+2?" });
var answer = await TestHelper.GetFinalAssistantMessageAsync(session);
var answer = await session.SendAndWaitAsync(new MessageOptions { Prompt = "What is 2+2?" });
Assert.NotNull(answer);
Assert.Contains("4", answer!.Data.Content ?? string.Empty);
}
Expand Down
4 changes: 1 addition & 3 deletions go/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ import (
"strings"
"sync"
"time"

"github.com/github/copilot-sdk/go/generated"
)

// Client manages the connection to the Copilot CLI server and provides session management.
Expand Down Expand Up @@ -923,7 +921,7 @@ func (c *Client) setupNotificationHandler() {
return
}

event, err := generated.UnmarshalSessionEvent(eventJSON)
event, err := UnmarshalSessionEvent(eventJSON)
if err != nil {
return
}
Expand Down
64 changes: 55 additions & 9 deletions go/e2e/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -472,18 +472,57 @@ func TestSession(t *testing.T) {
t.Fatalf("Failed to create session: %v", err)
}

// Send a message that will take some time to process
_, err = session.Send(copilot.MessageOptions{Prompt: "What is 1+1?"})
// Set up event listeners BEFORE sending to avoid race conditions
toolStartCh := make(chan *copilot.SessionEvent, 1)
toolStartErrCh := make(chan error, 1)
go func() {
evt, err := testharness.GetNextEventOfType(session, copilot.ToolExecutionStart, 60*time.Second)
if err != nil {
toolStartErrCh <- err
} else {
toolStartCh <- evt
}
}()

sessionIdleCh := make(chan *copilot.SessionEvent, 1)
sessionIdleErrCh := make(chan error, 1)
go func() {
evt, err := testharness.GetNextEventOfType(session, copilot.SessionIdle, 60*time.Second)
if err != nil {
sessionIdleErrCh <- err
} else {
sessionIdleCh <- evt
}
}()

// Send a message that triggers a long-running shell command
_, err = session.Send(copilot.MessageOptions{Prompt: "run the shell command 'sleep 100' (note this works on both bash and PowerShell)"})
if err != nil {
t.Fatalf("Failed to send message: %v", err)
}

// Abort the session immediately
// Wait for tool.execution_start
select {
case <-toolStartCh:
// Tool execution has started
case err := <-toolStartErrCh:
t.Fatalf("Failed waiting for tool.execution_start: %v", err)
}

// Abort the session
err = session.Abort()
if err != nil {
t.Fatalf("Failed to abort session: %v", err)
}

// Wait for session.idle after abort
select {
case <-sessionIdleCh:
// Session is idle
case err := <-sessionIdleErrCh:
t.Fatalf("Failed waiting for session.idle after abort: %v", err)
}

// The session should still be alive and usable after abort
messages, err := session.GetMessages()
if err != nil {
Expand All @@ -493,15 +532,22 @@ func TestSession(t *testing.T) {
t.Error("Expected messages to exist after abort")
}

// We should be able to send another message
_, err = session.Send(copilot.MessageOptions{Prompt: "What is 2+2?"})
if err != nil {
t.Fatalf("Failed to send message after abort: %v", err)
// Verify messages contain an abort event
hasAbortEvent := false
for _, msg := range messages {
if msg.Type == copilot.Abort {
hasAbortEvent = true
break
}
}
if !hasAbortEvent {
t.Error("Expected messages to contain an 'abort' event")
}

answer, err := testharness.GetFinalAssistantMessage(session, 60*time.Second)
// We should be able to send another message
answer, err := session.SendAndWait(copilot.MessageOptions{Prompt: "What is 2+2?"}, 60*time.Second)
if err != nil {
t.Fatalf("Failed to get assistant message after abort: %v", err)
t.Fatalf("Failed to send message after abort: %v", err)
}

if answer.Data.Content == nil || !strings.Contains(*answer.Data.Content, "4") {
Expand Down
35 changes: 35 additions & 0 deletions go/e2e/testharness/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,41 @@ func GetFinalAssistantMessage(session *copilot.Session, timeout time.Duration) (
}
}

// GetNextEventOfType waits for and returns the next event of the specified type from a session.
func GetNextEventOfType(session *copilot.Session, eventType copilot.SessionEventType, timeout time.Duration) (*copilot.SessionEvent, error) {
result := make(chan *copilot.SessionEvent, 1)
errCh := make(chan error, 1)

unsubscribe := session.On(func(event copilot.SessionEvent) {
switch event.Type {
case eventType:
select {
case result <- &event:
default:
}
case copilot.SessionError:
msg := "session error"
if event.Data.Message != nil {
msg = *event.Data.Message
}
select {
case errCh <- errors.New(msg):
default:
}
}
})
defer unsubscribe()

select {
case evt := <-result:
return evt, nil
case err := <-errCh:
return nil, err
case <-time.After(timeout):
return nil, errors.New("timeout waiting for event: " + string(eventType))
}
}

func getExistingFinalResponse(session *copilot.Session) (*copilot.SessionEvent, error) {
messages, err := session.GetMessages()
if err != nil {
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 4 additions & 6 deletions go/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import (
"fmt"
"sync"
"time"

"github.com/github/copilot-sdk/go/generated"
)

type sessionHandler struct {
Expand Down Expand Up @@ -159,17 +157,17 @@ func (s *Session) SendAndWait(options MessageOptions, timeout time.Duration) (*S

unsubscribe := s.On(func(event SessionEvent) {
switch event.Type {
case generated.AssistantMessage:
case AssistantMessage:
mu.Lock()
eventCopy := event
lastAssistantMessage = &eventCopy
mu.Unlock()
case generated.SessionIdle:
case SessionIdle:
select {
case idleCh <- struct{}{}:
default:
}
case generated.SessionError:
case SessionError:
errMsg := "session error"
if event.Data.Message != nil {
errMsg = *event.Data.Message
Expand Down Expand Up @@ -387,7 +385,7 @@ func (s *Session) GetMessages() ([]SessionEvent, error) {
continue
}

event, err := generated.UnmarshalSessionEvent(eventJSON)
event, err := UnmarshalSessionEvent(eventJSON)
if err != nil {
continue
}
Expand Down
121 changes: 121 additions & 0 deletions go/session_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package copilot

import (
"sync"
"testing"
)

func TestSession_On(t *testing.T) {
t.Run("multiple handlers all receive events", func(t *testing.T) {
session := &Session{
handlers: make([]sessionHandler, 0),
}

var received1, received2, received3 bool
session.On(func(event SessionEvent) { received1 = true })
session.On(func(event SessionEvent) { received2 = true })
session.On(func(event SessionEvent) { received3 = true })

session.dispatchEvent(SessionEvent{Type: "test"})

if !received1 || !received2 || !received3 {
t.Errorf("Expected all handlers to receive event, got received1=%v, received2=%v, received3=%v",
received1, received2, received3)
}
})

t.Run("unsubscribing one handler does not affect others", func(t *testing.T) {
session := &Session{
handlers: make([]sessionHandler, 0),
}

var count1, count2, count3 int
session.On(func(event SessionEvent) { count1++ })
unsub2 := session.On(func(event SessionEvent) { count2++ })
session.On(func(event SessionEvent) { count3++ })

// First event - all handlers receive it
session.dispatchEvent(SessionEvent{Type: "test"})

// Unsubscribe handler 2
unsub2()

// Second event - only handlers 1 and 3 should receive it
session.dispatchEvent(SessionEvent{Type: "test"})

if count1 != 2 {
t.Errorf("Expected handler 1 to receive 2 events, got %d", count1)
}
if count2 != 1 {
t.Errorf("Expected handler 2 to receive 1 event (before unsubscribe), got %d", count2)
}
if count3 != 2 {
t.Errorf("Expected handler 3 to receive 2 events, got %d", count3)
}
})

t.Run("calling unsubscribe multiple times is safe", func(t *testing.T) {
session := &Session{
handlers: make([]sessionHandler, 0),
}

var count int
unsub := session.On(func(event SessionEvent) { count++ })

session.dispatchEvent(SessionEvent{Type: "test"})

// Call unsubscribe multiple times - should not panic
unsub()
unsub()
unsub()

session.dispatchEvent(SessionEvent{Type: "test"})

if count != 1 {
t.Errorf("Expected handler to receive 1 event, got %d", count)
}
})

t.Run("handlers are called in registration order", func(t *testing.T) {
session := &Session{
handlers: make([]sessionHandler, 0),
}

var order []int
session.On(func(event SessionEvent) { order = append(order, 1) })
session.On(func(event SessionEvent) { order = append(order, 2) })
session.On(func(event SessionEvent) { order = append(order, 3) })

session.dispatchEvent(SessionEvent{Type: "test"})

if len(order) != 3 || order[0] != 1 || order[1] != 2 || order[2] != 3 {
t.Errorf("Expected handlers to be called in order [1,2,3], got %v", order)
}
})

t.Run("concurrent subscribe and unsubscribe is safe", func(t *testing.T) {
session := &Session{
handlers: make([]sessionHandler, 0),
}

var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
unsub := session.On(func(event SessionEvent) {})
unsub()
}()
}
wg.Wait()

// Should not panic and handlers should be empty
session.handlerMutex.RLock()
count := len(session.handlers)
session.handlerMutex.RUnlock()

if count != 0 {
t.Errorf("Expected 0 handlers after all unsubscribes, got %d", count)
}
})
}
Loading
Loading