Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 7 additions & 14 deletions pkg/api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,15 @@ import (
"github.com/stretchr/testify/require"
"github.com/xmtp/example-notification-server-go/mocks"
"github.com/xmtp/example-notification-server-go/pkg/interfaces"
"github.com/xmtp/example-notification-server-go/pkg/logging"
"github.com/xmtp/example-notification-server-go/pkg/options"
"github.com/xmtp/example-notification-server-go/pkg/testutils"
proto "github.com/xmtp/example-notification-server-go/pkg/proto/notifications/v1"
protoconnect "github.com/xmtp/example-notification-server-go/pkg/proto/notifications/v1/notificationsv1connect"
)

const INSTALLATION_ID = "install1"

type testContext struct {
cleanup func()
client protoconnect.NotificationsClient
ctx context.Context
httpClient *http.Client
Expand All @@ -33,7 +32,8 @@ type testContext struct {
}

func setupTest(t *testing.T) testContext {
ctx := context.Background()
t.Helper()
ctx := t.Context()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
port := listener.Addr().(*net.TCPAddr).Port
Expand All @@ -44,18 +44,17 @@ func setupTest(t *testing.T) testContext {
DisableKeepAlives: true,
},
}
apiServer := NewApiServer(logging.CreateLogger("console", "info"), options.ApiOptions{Port: port}, installationsMock, subscriptionsMock)
apiServer := NewApiServer(testutils.TestLogger(t), options.ApiOptions{Port: port}, installationsMock, subscriptionsMock)
require.NoError(t, apiServer.SetListener(listener))
apiServer.Start()
time.Sleep(50 * time.Millisecond)

cleanup := func() {
t.Cleanup(func() {
httpClient.CloseIdleConnections()
apiServer.Stop()
}
})

return testContext{
cleanup: cleanup,
client: protoconnect.NewNotificationsClient(httpClient, fmt.Sprintf("http://127.0.0.1:%d", port)),
ctx: ctx,
httpClient: httpClient,
Expand All @@ -67,7 +66,7 @@ func setupTest(t *testing.T) testContext {

func Test_SetListenerAfterStartReturnsError(t *testing.T) {
apiServer := NewApiServer(
logging.CreateLogger("console", "info"),
testutils.TestLogger(t),
options.ApiOptions{Port: 18081},
mocks.NewInstallations(t),
mocks.NewSubscriptions(t),
Expand All @@ -87,7 +86,6 @@ func Test_SetListenerAfterStartReturnsError(t *testing.T) {

func Test_RegisterInstallation(t *testing.T) {
ctx := setupTest(t)
defer ctx.cleanup()

deviceToken := "foo"
validUntil := time.Now()
Expand Down Expand Up @@ -118,7 +116,6 @@ func Test_RegisterInstallation(t *testing.T) {

func Test_RegisterInstallationError(t *testing.T) {
ctx := setupTest(t)
defer ctx.cleanup()

ctx.installationsMock.On(
"Register",
Expand All @@ -142,7 +139,6 @@ func Test_RegisterInstallationError(t *testing.T) {

func Test_DeleteInstallation(t *testing.T) {
ctx := setupTest(t)
defer ctx.cleanup()

ctx.installationsMock.On("Delete", mock.Anything, mock.Anything).
Return(nil)
Expand All @@ -165,7 +161,6 @@ func Test_DeleteInstallation(t *testing.T) {

func Test_Subscribe(t *testing.T) {
ctx := setupTest(t)
defer ctx.cleanup()
topics := []string{"topic1"}

ctx.subscriptionsMock.On(
Expand Down Expand Up @@ -195,7 +190,6 @@ func Test_Subscribe(t *testing.T) {

func Test_SubscribeError(t *testing.T) {
ctx := setupTest(t)
defer ctx.cleanup()

ctx.subscriptionsMock.On(
"Subscribe",
Expand All @@ -218,7 +212,6 @@ func Test_SubscribeError(t *testing.T) {

func Test_Unsubscribe(t *testing.T) {
ctx := setupTest(t)
defer ctx.cleanup()
topics := []string{"topic1"}

ctx.subscriptionsMock.On(
Expand Down
7 changes: 3 additions & 4 deletions pkg/db/db_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
package db_test

import (
"context"
"database/sql"
"testing"
"time"

"github.com/stretchr/testify/require"
database "github.com/xmtp/example-notification-server-go/pkg/db"
testdb "github.com/xmtp/example-notification-server-go/test"
testdb "github.com/xmtp/example-notification-server-go/pkg/testutils"
)

func TestMigrateFreshDatabase(t *testing.T) {
Expand Down Expand Up @@ -73,7 +72,7 @@ func TestMigrateExistingLegacySchema(t *testing.T) {
require.NoError(t, database.Migrate(t.Context(), db))

var version int
err := db.QueryRowContext(context.Background(), `SELECT version FROM schema_migrations`).Scan(&version)
err := db.QueryRowContext(t.Context(), `SELECT version FROM schema_migrations`).Scan(&version)
require.NoError(t, err)

latest, latestErr := database.LatestMigrationVersion()
Expand Down Expand Up @@ -113,7 +112,7 @@ func assertRelationExists(t *testing.T, db *sql.DB, name string) {
t.Helper()

var exists bool
err := db.QueryRowContext(context.Background(), `SELECT to_regclass($1) IS NOT NULL`, "public."+name).Scan(&exists)
err := db.QueryRowContext(t.Context(), `SELECT to_regclass($1) IS NOT NULL`, "public."+name).Scan(&exists)
require.NoError(t, err)
require.True(t, exists, name)
}
41 changes: 20 additions & 21 deletions pkg/delivery/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/xmtp/example-notification-server-go/pkg/interfaces"
"github.com/xmtp/example-notification-server-go/pkg/options"
Expand Down Expand Up @@ -48,9 +47,9 @@ func TestHttpDelivery_SendSuccess(t *testing.T) {
server, d := testServerAndDelivery(t, countingHandler(&requestCount, http.StatusOK), 3, 10)
defer server.Close()

err := d.Send(context.Background(), newTestRequest())
err := d.Send(t.Context(), newTestRequest())
require.NoError(t, err)
assert.Equal(t, int32(1), atomic.LoadInt32(&requestCount))
require.Equal(t, int32(1), atomic.LoadInt32(&requestCount))
}

func TestHttpDelivery_RetryOnFailureThenSuccess(t *testing.T) {
Expand All @@ -65,9 +64,9 @@ func TestHttpDelivery_RetryOnFailureThenSuccess(t *testing.T) {
}, 3, 10)
defer server.Close()

err := d.Send(context.Background(), newTestRequest())
err := d.Send(t.Context(), newTestRequest())
require.NoError(t, err)
assert.Equal(t, int32(2), atomic.LoadInt32(&requestCount))
require.Equal(t, int32(2), atomic.LoadInt32(&requestCount))
}

func TestHttpDelivery_ExhaustsAttempts(t *testing.T) {
Expand All @@ -76,18 +75,18 @@ func TestHttpDelivery_ExhaustsAttempts(t *testing.T) {
server, d := testServerAndDelivery(t, countingHandler(&requestCount, http.StatusInternalServerError), maxAttempts, 10)
defer server.Close()

err := d.Send(context.Background(), newTestRequest())
err := d.Send(t.Context(), newTestRequest())
require.Error(t, err)
assert.Equal(t, "HTTP request failed", err.Error())
assert.Equal(t, int32(maxAttempts), atomic.LoadInt32(&requestCount))
require.Equal(t, "HTTP request failed", err.Error())
require.Equal(t, int32(maxAttempts), atomic.LoadInt32(&requestCount))
}

func TestHttpDelivery_ContextCancellation(t *testing.T) {
var requestCount int32
server, d := testServerAndDelivery(t, countingHandler(&requestCount, http.StatusInternalServerError), 5, 500)
defer server.Close()

ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(t.Context())

done := make(chan error, 1)
go func() {
Expand All @@ -100,9 +99,9 @@ func TestHttpDelivery_ContextCancellation(t *testing.T) {

err := <-done
require.Error(t, err)
assert.Equal(t, context.Canceled, err)
require.Equal(t, context.Canceled, err)
// Should have made only 1 request before context was cancelled during backoff
assert.Equal(t, int32(1), atomic.LoadInt32(&requestCount))
require.Equal(t, int32(1), atomic.LoadInt32(&requestCount))
}

func TestHttpDelivery_DefaultConfig(t *testing.T) {
Expand All @@ -112,8 +111,8 @@ func TestHttpDelivery_DefaultConfig(t *testing.T) {
InitialRetryDelayMs: 250,
})

assert.Equal(t, 1, d.maxAttempts)
assert.Equal(t, 250*time.Millisecond, d.initialRetryDelay)
require.Equal(t, 1, d.maxAttempts)
require.Equal(t, 250*time.Millisecond, d.initialRetryDelay)
}

func TestHttpDelivery_ExponentialBackoff(t *testing.T) {
Expand All @@ -124,7 +123,7 @@ func TestHttpDelivery_ExponentialBackoff(t *testing.T) {
}, 4, 50)
defer server.Close()

_ = d.Send(context.Background(), newTestRequest())
_ = d.Send(t.Context(), newTestRequest())

// Should have 4 requests total (maxAttempts=4)
require.Len(t, timestamps, 4)
Expand All @@ -135,7 +134,7 @@ func TestHttpDelivery_ExponentialBackoff(t *testing.T) {
gap := timestamps[i].Sub(timestamps[i-1])
expectedDelay := time.Duration(50*(1<<uint(i-1))) * time.Millisecond
// Allow 30ms tolerance for test timing
assert.InDelta(t, expectedDelay.Milliseconds(), gap.Milliseconds(), 30,
require.InDelta(t, expectedDelay.Milliseconds(), gap.Milliseconds(), 30,
"gap between request %d and %d should be ~%v, got %v", i-1, i, expectedDelay, gap)
}
}
Expand All @@ -145,10 +144,10 @@ func TestHttpDelivery_SingleAttempt(t *testing.T) {
server, d := testServerAndDelivery(t, countingHandler(&requestCount, http.StatusInternalServerError), 1, 10)
defer server.Close()

err := d.Send(context.Background(), newTestRequest())
err := d.Send(t.Context(), newTestRequest())
require.Error(t, err)
// With maxAttempts=1, only one attempt is made (no retries)
assert.Equal(t, int32(1), atomic.LoadInt32(&requestCount))
require.Equal(t, int32(1), atomic.LoadInt32(&requestCount))
}

func TestHttpDelivery_MaxAttemptsClampsToMinimumOne(t *testing.T) {
Expand All @@ -159,12 +158,12 @@ func TestHttpDelivery_MaxAttemptsClampsToMinimumOne(t *testing.T) {
})

// Value of 0 should be clamped to 1
assert.Equal(t, 1, d.maxAttempts)
require.Equal(t, 1, d.maxAttempts)
}

func TestHttpDelivery_CanDeliver(t *testing.T) {
d := NewHttpDelivery(zaptest.NewLogger(t), options.HttpDeliveryOptions{})
assert.True(t, d.CanDeliver(newTestRequest()))
require.True(t, d.CanDeliver(newTestRequest()))
}

func TestHttpDelivery_AuthHeader(t *testing.T) {
Expand All @@ -182,7 +181,7 @@ func TestHttpDelivery_AuthHeader(t *testing.T) {
InitialRetryDelayMs: 10,
})

err := d.Send(context.Background(), newTestRequest())
err := d.Send(t.Context(), newTestRequest())
require.NoError(t, err)
assert.Equal(t, "Bearer test-token", receivedAuth)
require.Equal(t, "Bearer test-token", receivedAuth)
}
Loading
Loading