Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 87 additions & 20 deletions contract/hooks.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package contract

import "net/http"
import (
"net/http"
"slices"
"sync"
)

// hooksKey is the unexported type used as the context key
// for storing and retrieving [Hooks] from a request context.
Expand All @@ -26,28 +30,91 @@ type BeforeWriteHeaderHook = func(w http.ResponseWriter, status int)
// slice that is about to be sent.
type BeforeWriteHook = func(w http.ResponseWriter, content []byte)

// Hooks defines the contract for registering and retrieving
// lifecycle callbacks during HTTP request processing. Middleware
// and handlers use these hooks to observe response events.
type Hooks interface {
// AfterResponse registers one or more callbacks to be invoked
// after the HTTP response has been fully written.
AfterResponse(callbacks ...AfterResponseHook)
// Hooks provides lifecycle hook registration for the HTTP
// request/response cycle. Middleware and handlers can attach
// callbacks that fire before headers are written, before the
// body is written, and after the response completes.
//
// All methods are safe for concurrent use.
type Hooks struct {
mutex sync.Mutex
afterResponseHooks []AfterResponseHook
beforeWriteHeaderHooks []BeforeWriteHeaderHook
beforeWriteHooks []BeforeWriteHook
}

// NewHooks creates a [Hooks] instance with empty callback slices
// ready to accept registrations via the Before* and After* methods.
func NewHooks() *Hooks {
return &Hooks{
beforeWriteHeaderHooks: []BeforeWriteHeaderHook{},
beforeWriteHooks: []BeforeWriteHook{},
afterResponseHooks: []AfterResponseHook{},
}
}

// AfterResponse registers one or more callbacks to be invoked
// after the HTTP response has been fully written.
func (hooks *Hooks) AfterResponse(callbacks ...AfterResponseHook) {
hooks.mutex.Lock()
defer hooks.mutex.Unlock()

hooks.afterResponseHooks = append(hooks.afterResponseHooks, callbacks...)
}

// AfterResponseFuncs returns a reversed clone of the registered
// AfterResponse callbacks. The reversal ensures that the most
// recently registered callback executes first (LIFO order).
func (hooks *Hooks) AfterResponseFuncs() []AfterResponseHook {
hooks.mutex.Lock()
defer hooks.mutex.Unlock()

clone := slices.Clone(hooks.afterResponseHooks)
slices.Reverse(clone)

return clone
}

// BeforeWrite registers one or more callbacks to be invoked
// just before response body bytes are written.
func (hooks *Hooks) BeforeWrite(callbacks ...BeforeWriteHook) {
hooks.mutex.Lock()
defer hooks.mutex.Unlock()

hooks.beforeWriteHooks = append(hooks.beforeWriteHooks, callbacks...)
}

// BeforeWriteFuncs returns a reversed clone of the registered
// BeforeWrite callbacks. The reversal ensures that the most
// recently registered callback executes first (LIFO order).
func (hooks *Hooks) BeforeWriteFuncs() []BeforeWriteHook {
hooks.mutex.Lock()
defer hooks.mutex.Unlock()

// AfterResponseFuncs returns all registered after-response callbacks.
AfterResponseFuncs() []AfterResponseHook
clone := slices.Clone(hooks.beforeWriteHooks)
slices.Reverse(clone)

// BeforeWrite registers one or more callbacks to be invoked
// just before response body bytes are written.
BeforeWrite(callbacks ...BeforeWriteHook)
return clone
}

// BeforeWriteHeader registers one or more callbacks to be invoked
// just before the response status code is written.
func (hooks *Hooks) BeforeWriteHeader(callbacks ...BeforeWriteHeaderHook) {
hooks.mutex.Lock()
defer hooks.mutex.Unlock()

hooks.beforeWriteHeaderHooks = append(hooks.beforeWriteHeaderHooks, callbacks...)
}

// BeforeWriteFuncs returns all registered before-write callbacks.
BeforeWriteFuncs() []BeforeWriteHook
// BeforeWriteHeaderFuncs returns a reversed clone of the registered
// BeforeWriteHeader callbacks. The reversal ensures that the most
// recently registered callback executes first (LIFO order).
func (hooks *Hooks) BeforeWriteHeaderFuncs() []BeforeWriteHeaderHook {
hooks.mutex.Lock()
defer hooks.mutex.Unlock()

// BeforeWriteHeader registers one or more callbacks to be invoked
// just before the response status code is written.
BeforeWriteHeader(callbacks ...BeforeWriteHeaderHook)
clone := slices.Clone(hooks.beforeWriteHeaderHooks)
slices.Reverse(clone)

// BeforeWriteHeaderFuncs returns all registered before-write-header callbacks.
BeforeWriteHeaderFuncs() []BeforeWriteHeaderHook
return clone
}
143 changes: 143 additions & 0 deletions contract/hooks_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package contract_test

import (
"net/http"
"net/http/httptest"
"sync"
"testing"

"github.com/stretchr/testify/require"
Expand All @@ -20,3 +23,143 @@ func TestHooksKeyIsDistinctType(t *testing.T) {

require.NotEqual(t, other, contract.HooksKey)
}

func TestNewHooksReturnsNonNil(t *testing.T) {
t.Parallel()

hooks := contract.NewHooks()

require.NotNil(t, hooks)
}

func TestHooksAfterResponseRegisters(t *testing.T) {
t.Parallel()

hooks := contract.NewHooks()

var called bool
hooks.AfterResponse(func(err error) { called = true })

fns := hooks.AfterResponseFuncs()

require.Len(t, fns, 1)

fns[0](nil)

require.True(t, called)
}

func TestHooksAfterResponseFuncsReturnsLIFO(t *testing.T) {
t.Parallel()

hooks := contract.NewHooks()

var order []int
hooks.AfterResponse(func(err error) { order = append(order, 1) })
hooks.AfterResponse(func(err error) { order = append(order, 2) })

for _, fn := range hooks.AfterResponseFuncs() {
fn(nil)
}

require.Equal(t, []int{2, 1}, order)
}

func TestHooksBeforeWriteRegisters(t *testing.T) {
t.Parallel()

hooks := contract.NewHooks()

var called bool
hooks.BeforeWrite(func(w http.ResponseWriter, content []byte) { called = true })

fns := hooks.BeforeWriteFuncs()

require.Len(t, fns, 1)

fns[0](httptest.NewRecorder(), nil)

require.True(t, called)
}

func TestHooksBeforeWriteFuncsReturnsLIFO(t *testing.T) {
t.Parallel()

hooks := contract.NewHooks()

var order []int
hooks.BeforeWrite(func(w http.ResponseWriter, content []byte) { order = append(order, 1) })
hooks.BeforeWrite(func(w http.ResponseWriter, content []byte) { order = append(order, 2) })

for _, fn := range hooks.BeforeWriteFuncs() {
fn(httptest.NewRecorder(), nil)
}

require.Equal(t, []int{2, 1}, order)
}

func TestHooksBeforeWriteHeaderRegisters(t *testing.T) {
t.Parallel()

hooks := contract.NewHooks()

var called bool
hooks.BeforeWriteHeader(func(w http.ResponseWriter, status int) { called = true })

fns := hooks.BeforeWriteHeaderFuncs()

require.Len(t, fns, 1)

fns[0](httptest.NewRecorder(), 200)

require.True(t, called)
}

func TestHooksBeforeWriteHeaderFuncsReturnsLIFO(t *testing.T) {
t.Parallel()

hooks := contract.NewHooks()

var order []int
hooks.BeforeWriteHeader(func(w http.ResponseWriter, status int) { order = append(order, 1) })
hooks.BeforeWriteHeader(func(w http.ResponseWriter, status int) { order = append(order, 2) })

for _, fn := range hooks.BeforeWriteHeaderFuncs() {
fn(httptest.NewRecorder(), 200)
}

require.Equal(t, []int{2, 1}, order)
}

func TestHooksEmptyFuncsReturnsEmptySlice(t *testing.T) {
t.Parallel()

hooks := contract.NewHooks()

require.Empty(t, hooks.AfterResponseFuncs())
require.Empty(t, hooks.BeforeWriteFuncs())
require.Empty(t, hooks.BeforeWriteHeaderFuncs())
}

func TestHooksConcurrentAccess(t *testing.T) {
t.Parallel()

hooks := contract.NewHooks()

var wg sync.WaitGroup

for range 100 {
wg.Add(1)

go func() {
defer wg.Done()

hooks.AfterResponse(func(err error) {})
hooks.AfterResponseFuncs()
}()
}

wg.Wait()

require.Len(t, hooks.AfterResponseFuncs(), 100)
}
Loading