diff --git a/pkg/api/api_test.go b/pkg/api/api_test.go index 601bbfc7..854a10c5 100644 --- a/pkg/api/api_test.go +++ b/pkg/api/api_test.go @@ -14,8 +14,8 @@ import ( "github.com/stretchr/testify/require" "github.com/xmtp/example-notification-server-go/mocks" "github.com/xmtp/example-notification-server-go/pkg/interfaces" - "github.com/xmtp/example-notification-server-go/pkg/logging" "github.com/xmtp/example-notification-server-go/pkg/options" + "github.com/xmtp/example-notification-server-go/pkg/testutils" proto "github.com/xmtp/example-notification-server-go/pkg/proto/notifications/v1" protoconnect "github.com/xmtp/example-notification-server-go/pkg/proto/notifications/v1/notificationsv1connect" ) @@ -23,7 +23,6 @@ import ( const INSTALLATION_ID = "install1" type testContext struct { - cleanup func() client protoconnect.NotificationsClient ctx context.Context httpClient *http.Client @@ -33,7 +32,8 @@ type testContext struct { } func setupTest(t *testing.T) testContext { - ctx := context.Background() + t.Helper() + ctx := t.Context() listener, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) port := listener.Addr().(*net.TCPAddr).Port @@ -44,18 +44,17 @@ func setupTest(t *testing.T) testContext { DisableKeepAlives: true, }, } - apiServer := NewApiServer(logging.CreateLogger("console", "info"), options.ApiOptions{Port: port}, installationsMock, subscriptionsMock) + apiServer := NewApiServer(testutils.TestLogger(t), options.ApiOptions{Port: port}, installationsMock, subscriptionsMock) require.NoError(t, apiServer.SetListener(listener)) apiServer.Start() time.Sleep(50 * time.Millisecond) - cleanup := func() { + t.Cleanup(func() { httpClient.CloseIdleConnections() apiServer.Stop() - } + }) return testContext{ - cleanup: cleanup, client: protoconnect.NewNotificationsClient(httpClient, fmt.Sprintf("http://127.0.0.1:%d", port)), ctx: ctx, httpClient: httpClient, @@ -67,7 +66,7 @@ func setupTest(t *testing.T) testContext { func Test_SetListenerAfterStartReturnsError(t *testing.T) { apiServer := NewApiServer( - logging.CreateLogger("console", "info"), + testutils.TestLogger(t), options.ApiOptions{Port: 18081}, mocks.NewInstallations(t), mocks.NewSubscriptions(t), @@ -87,7 +86,6 @@ func Test_SetListenerAfterStartReturnsError(t *testing.T) { func Test_RegisterInstallation(t *testing.T) { ctx := setupTest(t) - defer ctx.cleanup() deviceToken := "foo" validUntil := time.Now() @@ -118,7 +116,6 @@ func Test_RegisterInstallation(t *testing.T) { func Test_RegisterInstallationError(t *testing.T) { ctx := setupTest(t) - defer ctx.cleanup() ctx.installationsMock.On( "Register", @@ -142,7 +139,6 @@ func Test_RegisterInstallationError(t *testing.T) { func Test_DeleteInstallation(t *testing.T) { ctx := setupTest(t) - defer ctx.cleanup() ctx.installationsMock.On("Delete", mock.Anything, mock.Anything). Return(nil) @@ -165,7 +161,6 @@ func Test_DeleteInstallation(t *testing.T) { func Test_Subscribe(t *testing.T) { ctx := setupTest(t) - defer ctx.cleanup() topics := []string{"topic1"} ctx.subscriptionsMock.On( @@ -195,7 +190,6 @@ func Test_Subscribe(t *testing.T) { func Test_SubscribeError(t *testing.T) { ctx := setupTest(t) - defer ctx.cleanup() ctx.subscriptionsMock.On( "Subscribe", @@ -218,7 +212,6 @@ func Test_SubscribeError(t *testing.T) { func Test_Unsubscribe(t *testing.T) { ctx := setupTest(t) - defer ctx.cleanup() topics := []string{"topic1"} ctx.subscriptionsMock.On( diff --git a/pkg/db/db_test.go b/pkg/db/db_test.go index 817c3dc8..7f3ad2b0 100644 --- a/pkg/db/db_test.go +++ b/pkg/db/db_test.go @@ -1,14 +1,13 @@ package db_test import ( - "context" "database/sql" "testing" "time" "github.com/stretchr/testify/require" database "github.com/xmtp/example-notification-server-go/pkg/db" - testdb "github.com/xmtp/example-notification-server-go/test" + testdb "github.com/xmtp/example-notification-server-go/pkg/testutils" ) func TestMigrateFreshDatabase(t *testing.T) { @@ -73,7 +72,7 @@ func TestMigrateExistingLegacySchema(t *testing.T) { require.NoError(t, database.Migrate(t.Context(), db)) var version int - err := db.QueryRowContext(context.Background(), `SELECT version FROM schema_migrations`).Scan(&version) + err := db.QueryRowContext(t.Context(), `SELECT version FROM schema_migrations`).Scan(&version) require.NoError(t, err) latest, latestErr := database.LatestMigrationVersion() @@ -113,7 +112,7 @@ func assertRelationExists(t *testing.T, db *sql.DB, name string) { t.Helper() var exists bool - err := db.QueryRowContext(context.Background(), `SELECT to_regclass($1) IS NOT NULL`, "public."+name).Scan(&exists) + err := db.QueryRowContext(t.Context(), `SELECT to_regclass($1) IS NOT NULL`, "public."+name).Scan(&exists) require.NoError(t, err) require.True(t, exists, name) } diff --git a/pkg/delivery/http_test.go b/pkg/delivery/http_test.go index e5d94f73..8ac8080c 100644 --- a/pkg/delivery/http_test.go +++ b/pkg/delivery/http_test.go @@ -8,7 +8,6 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/xmtp/example-notification-server-go/pkg/interfaces" "github.com/xmtp/example-notification-server-go/pkg/options" @@ -48,9 +47,9 @@ func TestHttpDelivery_SendSuccess(t *testing.T) { server, d := testServerAndDelivery(t, countingHandler(&requestCount, http.StatusOK), 3, 10) defer server.Close() - err := d.Send(context.Background(), newTestRequest()) + err := d.Send(t.Context(), newTestRequest()) require.NoError(t, err) - assert.Equal(t, int32(1), atomic.LoadInt32(&requestCount)) + require.Equal(t, int32(1), atomic.LoadInt32(&requestCount)) } func TestHttpDelivery_RetryOnFailureThenSuccess(t *testing.T) { @@ -65,9 +64,9 @@ func TestHttpDelivery_RetryOnFailureThenSuccess(t *testing.T) { }, 3, 10) defer server.Close() - err := d.Send(context.Background(), newTestRequest()) + err := d.Send(t.Context(), newTestRequest()) require.NoError(t, err) - assert.Equal(t, int32(2), atomic.LoadInt32(&requestCount)) + require.Equal(t, int32(2), atomic.LoadInt32(&requestCount)) } func TestHttpDelivery_ExhaustsAttempts(t *testing.T) { @@ -76,10 +75,10 @@ func TestHttpDelivery_ExhaustsAttempts(t *testing.T) { server, d := testServerAndDelivery(t, countingHandler(&requestCount, http.StatusInternalServerError), maxAttempts, 10) defer server.Close() - err := d.Send(context.Background(), newTestRequest()) + err := d.Send(t.Context(), newTestRequest()) require.Error(t, err) - assert.Equal(t, "HTTP request failed", err.Error()) - assert.Equal(t, int32(maxAttempts), atomic.LoadInt32(&requestCount)) + require.Equal(t, "HTTP request failed", err.Error()) + require.Equal(t, int32(maxAttempts), atomic.LoadInt32(&requestCount)) } func TestHttpDelivery_ContextCancellation(t *testing.T) { @@ -87,7 +86,7 @@ func TestHttpDelivery_ContextCancellation(t *testing.T) { server, d := testServerAndDelivery(t, countingHandler(&requestCount, http.StatusInternalServerError), 5, 500) defer server.Close() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(t.Context()) done := make(chan error, 1) go func() { @@ -100,9 +99,9 @@ func TestHttpDelivery_ContextCancellation(t *testing.T) { err := <-done require.Error(t, err) - assert.Equal(t, context.Canceled, err) + require.Equal(t, context.Canceled, err) // Should have made only 1 request before context was cancelled during backoff - assert.Equal(t, int32(1), atomic.LoadInt32(&requestCount)) + require.Equal(t, int32(1), atomic.LoadInt32(&requestCount)) } func TestHttpDelivery_DefaultConfig(t *testing.T) { @@ -112,8 +111,8 @@ func TestHttpDelivery_DefaultConfig(t *testing.T) { InitialRetryDelayMs: 250, }) - assert.Equal(t, 1, d.maxAttempts) - assert.Equal(t, 250*time.Millisecond, d.initialRetryDelay) + require.Equal(t, 1, d.maxAttempts) + require.Equal(t, 250*time.Millisecond, d.initialRetryDelay) } func TestHttpDelivery_ExponentialBackoff(t *testing.T) { @@ -124,7 +123,7 @@ func TestHttpDelivery_ExponentialBackoff(t *testing.T) { }, 4, 50) defer server.Close() - _ = d.Send(context.Background(), newTestRequest()) + _ = d.Send(t.Context(), newTestRequest()) // Should have 4 requests total (maxAttempts=4) require.Len(t, timestamps, 4) @@ -135,7 +134,7 @@ func TestHttpDelivery_ExponentialBackoff(t *testing.T) { gap := timestamps[i].Sub(timestamps[i-1]) expectedDelay := time.Duration(50*(1<