Skip to content

Commit

Permalink
Ability to specify response HTTP status code for Throttle middleware (#…
Browse files Browse the repository at this point in the history
…571)

* Fix typo in doc comment for Throttle middleware

* Add ability to specify response HTTP status code for Throttle middleware
  • Loading branch information
vasayxtx authored Sep 18, 2024
1 parent 2c4d128 commit 1f927a8
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 10 deletions.
20 changes: 14 additions & 6 deletions middleware/throttle.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ type ThrottleOpts struct {
Limit int
BacklogLimit int
BacklogTimeout time.Duration
StatusCode int
}

// Throttle is a middleware that limits number of currently processed requests
// at a time across all users. Note: Throttle is not a rate-limiter per user,
// instead it just puts a ceiling on the number of currently in-flight requests
// instead it just puts a ceiling on the number of current in-flight requests
// being processed from the point from where the Throttle middleware is mounted.
func Throttle(limit int) func(http.Handler) http.Handler {
return ThrottleWithOpts(ThrottleOpts{Limit: limit, BacklogTimeout: defaultBacklogTimeout})
Expand All @@ -49,10 +50,16 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler {
panic("chi/middleware: Throttle expects backlogLimit to be positive")
}

statusCode := opts.StatusCode
if statusCode == 0 {
statusCode = http.StatusTooManyRequests
}

t := throttler{
tokens: make(chan token, opts.Limit),
backlogTokens: make(chan token, opts.Limit+opts.BacklogLimit),
backlogTimeout: opts.BacklogTimeout,
statusCode: statusCode,
retryAfterFn: opts.RetryAfterFn,
}

Expand All @@ -72,7 +79,7 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler {

case <-ctx.Done():
t.setRetryAfterHeaderIfNeeded(w, true)
http.Error(w, errContextCanceled, http.StatusTooManyRequests)
http.Error(w, errContextCanceled, t.statusCode)
return

case btok := <-t.backlogTokens:
Expand All @@ -85,12 +92,12 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler {
select {
case <-timer.C:
t.setRetryAfterHeaderIfNeeded(w, false)
http.Error(w, errTimedOut, http.StatusTooManyRequests)
http.Error(w, errTimedOut, t.statusCode)
return
case <-ctx.Done():
timer.Stop()
t.setRetryAfterHeaderIfNeeded(w, true)
http.Error(w, errContextCanceled, http.StatusTooManyRequests)
http.Error(w, errContextCanceled, t.statusCode)
return
case tok := <-t.tokens:
defer func() {
Expand All @@ -103,7 +110,7 @@ func ThrottleWithOpts(opts ThrottleOpts) func(http.Handler) http.Handler {

default:
t.setRetryAfterHeaderIfNeeded(w, false)
http.Error(w, errCapacityExceeded, http.StatusTooManyRequests)
http.Error(w, errCapacityExceeded, t.statusCode)
return
}
}
Expand All @@ -119,8 +126,9 @@ type token struct{}
type throttler struct {
tokens chan token
backlogTokens chan token
retryAfterFn func(ctxDone bool) time.Duration
backlogTimeout time.Duration
statusCode int
retryAfterFn func(ctxDone bool) time.Duration
}

// setRetryAfterHeaderIfNeeded sets Retry-After HTTP header if corresponding retryAfterFn option of throttler is initialized.
Expand Down
55 changes: 51 additions & 4 deletions middleware/throttle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ func TestThrottleTriggerGatewayTimeout(t *testing.T) {
res, err := client.Get(server.URL)
assertNoError(t, err)
assertEqual(t, http.StatusOK, res.StatusCode)

}(i)
}

Expand All @@ -136,7 +135,6 @@ func TestThrottleTriggerGatewayTimeout(t *testing.T) {
assertNoError(t, err)
assertEqual(t, http.StatusTooManyRequests, res.StatusCode)
assertEqual(t, errTimedOut, strings.TrimSpace(string(buf)))

}(i)
}

Expand Down Expand Up @@ -175,7 +173,6 @@ func TestThrottleMaximum(t *testing.T) {
buf, err := ioutil.ReadAll(res.Body)
assertNoError(t, err)
assertEqual(t, testContent, buf)

}(i)
}

Expand All @@ -196,7 +193,6 @@ func TestThrottleMaximum(t *testing.T) {
assertNoError(t, err)
assertEqual(t, http.StatusTooManyRequests, res.StatusCode)
assertEqual(t, errCapacityExceeded, strings.TrimSpace(string(buf)))

}(i)
}

Expand Down Expand Up @@ -252,3 +248,54 @@ func TestThrottleMaximum(t *testing.T) {
wg.Wait()
}*/

func TestThrottleCustomStatusCode(t *testing.T) {
const timeout = time.Second * 3

wait := make(chan struct{})

r := chi.NewRouter()
r.Use(ThrottleWithOpts(ThrottleOpts{Limit: 1, StatusCode: http.StatusServiceUnavailable}))
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
select {
case <-wait:
case <-time.After(timeout):
}
w.WriteHeader(http.StatusOK)
})
server := httptest.NewServer(r)
defer server.Close()

const totalRequestCount = 5

codes := make(chan int, totalRequestCount)
errs := make(chan error, totalRequestCount)
client := &http.Client{Timeout: timeout}
for i := 0; i < totalRequestCount; i++ {
go func() {
resp, err := client.Get(server.URL)
if err != nil {
errs <- err
return
}
codes <- resp.StatusCode
}()
}

waitResponse := func(wantCode int) {
select {
case err := <-errs:
t.Fatal(err)
case code := <-codes:
assertEqual(t, wantCode, code)
case <-time.After(timeout):
t.Fatalf("waiting %d code, timeout exceeded", wantCode)
}
}

for i := 0; i < totalRequestCount-1; i++ {
waitResponse(http.StatusServiceUnavailable)
}
close(wait) // Allow the last request to proceed.
waitResponse(http.StatusOK)
}

0 comments on commit 1f927a8

Please sign in to comment.