diff --git a/internal/api/util/response.go b/internal/api/util/response.go index e22bac545..150d2ac2e 100644 --- a/internal/api/util/response.go +++ b/internal/api/util/response.go @@ -42,6 +42,12 @@ StatusInternalServerErrorJSON = mustJSON(map[string]string{ "status": http.StatusText(http.StatusInternalServerError), }) + ErrorCapacityExceeded = mustJSON(map[string]string{ + "error": "server capacity exceeded!", + }) + ErrorRateLimitReached = mustJSON(map[string]string{ + "error": "rate limit reached!", + }) EmptyJSONObject = mustJSON("{}") EmptyJSONArray = mustJSON("[]") diff --git a/internal/middleware/ratelimit.go b/internal/middleware/ratelimit.go index a59a3e608..57055fe70 100644 --- a/internal/middleware/ratelimit.go +++ b/internal/middleware/ratelimit.go @@ -29,6 +29,8 @@ "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/ulule/limiter/v3" "github.com/ulule/limiter/v3/drivers/store/memory" + + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" ) const rateLimitPeriod = 5 * time.Minute @@ -141,10 +143,12 @@ func RateLimit(limit int, exceptions []string) gin.HandlerFunc { if context.Reached { // Return JSON error message for // consistency with other endpoints. - c.AbortWithStatusJSON( + apiutil.Data(c, http.StatusTooManyRequests, - gin.H{"error": "rate limit reached"}, + apiutil.AppJSON, + apiutil.ErrorRateLimitReached, ) + c.Abort() return } diff --git a/internal/middleware/throttling.go b/internal/middleware/throttling.go index 589671547..33f46f175 100644 --- a/internal/middleware/throttling.go +++ b/internal/middleware/throttling.go @@ -29,9 +29,12 @@ "net/http" "runtime" "strconv" + "sync/atomic" "time" "github.com/gin-gonic/gin" + + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" ) // token represents a request that is being processed. @@ -80,55 +83,61 @@ func Throttle(cpuMultiplier int, retryAfter time.Duration) gin.HandlerFunc { } var ( - limit = runtime.GOMAXPROCS(0) * cpuMultiplier - backlogLimit = limit * cpuMultiplier - backlogChannelSize = limit + backlogLimit - tokens = make(chan token, limit) - backlogTokens = make(chan token, backlogChannelSize) - retryAfterStr = strconv.FormatUint(uint64(retryAfter/time.Second), 10) + limit = runtime.GOMAXPROCS(0) * cpuMultiplier + queueLimit = limit * cpuMultiplier + tokens = make(chan token, limit) + requestCount = atomic.Int64{} + retryAfterStr = strconv.FormatUint(uint64(retryAfter/time.Second), 10) ) - // prefill token channels + // prefill token channel for i := 0; i < limit; i++ { tokens <- token{} } - for i := 0; i < backlogChannelSize; i++ { - backlogTokens <- token{} - } return func(c *gin.Context) { - // inside this select, the caller tries to get a backlog token - select { - case <-c.Request.Context().Done(): - // request context has been canceled already + // Always decrement request counter. + defer func() { requestCount.Add(-1) }() + + // Increment request count. + n := requestCount.Add(1) + + // Check whether the request + // count is over queue limit. + if n > int64(queueLimit) { + c.Header("Retry-After", retryAfterStr) + apiutil.Data(c, + http.StatusTooManyRequests, + apiutil.AppJSON, + apiutil.ErrorCapacityExceeded, + ) + c.Abort() return - case btok := <-backlogTokens: + } + + // Sit and wait in the + // queue for free token. + select { + + case <-c.Request.Context().Done(): + // request context has + // been canceled already. + return + + case tok := <-tokens: + // caller has successfully + // received a token, allowing + // request to be processed. + defer func() { - // when we're finished, return the backlog token to the bucket - backlogTokens <- btok + // when we're finished, return + // this token to the bucket. + tokens <- tok }() - // inside *this* select, the caller has a backlog token, - // and they're waiting for their turn to be processed - select { - case <-c.Request.Context().Done(): - // the request context has been canceled already - return - case tok := <-tokens: - // the caller gets a token, so their request can now be processed - defer func() { - // whatever happens to the request, put the - // token back in the bucket when we're finished - tokens <- tok - }() - c.Next() // <- finally process the caller's request - } - - default: - // we don't have space in the backlog queue - c.Header("Retry-After", retryAfterStr) - c.JSON(http.StatusTooManyRequests, gin.H{"error": "server capacity exceeded"}) - c.Abort() + // Process + // request! + c.Next() } } } diff --git a/internal/middleware/throttling_test.go b/internal/middleware/throttling_test.go new file mode 100644 index 000000000..2a716ec53 --- /dev/null +++ b/internal/middleware/throttling_test.go @@ -0,0 +1,149 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +/* + The code in this file is adapted from MIT-licensed code in github.com/go-chi/chi. Thanks chi (thi)! + + See: https://github.com/go-chi/chi/blob/e6baba61759b26ddf7b14d1e02d1da81a4d76c08/middleware/throttle.go + + And: https://github.com/sponsors/pkieltyka +*/ + +package middleware_test + +import ( + "context" + "net/http" + "net/http/httptest" + "runtime" + "strconv" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/superseriousbusiness/gotosocial/internal/middleware" +) + +func TestThrottlingMiddleware(t *testing.T) { + testThrottlingMiddleware(t, 2, time.Second*10) + testThrottlingMiddleware(t, 4, time.Second*15) + testThrottlingMiddleware(t, 8, time.Second*30) +} + +func testThrottlingMiddleware(t *testing.T, cpuMulti int, retryAfter time.Duration) { + // Calculate expected request limit + queue. + limit := runtime.GOMAXPROCS(0) * cpuMulti + queueLimit := limit * cpuMulti + + // Calculate expected retry-after header string. + retryAfterStr := strconv.FormatUint(uint64(retryAfter/time.Second), 10) + + // Gin test http engine + // (used for ctx init). + e := gin.New() + + // Add middleware to the gin engine handler stack. + middleware := middleware.Throttle(cpuMulti, retryAfter) + e.Use(middleware) + + // Set the blocking gin handler. + handler := blockingHandler() + e.Handle("GET", "/", handler) + + var cncls []func() + + for i := 0; i < queueLimit+limit; i++ { + // Prepare a gin test context. + r := httptest.NewRequest("GET", "/", nil) + rw := httptest.NewRecorder() + + // Wrap request with new cancel context. + ctx, cncl := context.WithCancel(r.Context()) + r = r.WithContext(ctx) + + // Pass req through + // engine handler. + go e.ServeHTTP(rw, r) + time.Sleep(time.Millisecond) + + // Get http result. + res := rw.Result() + + if i < queueLimit { + + // Check status == 200 (default, i.e not set). + if res.StatusCode != http.StatusOK { + t.Fatalf("status code was set (%d) with queueLimit=%d and request=%d", res.StatusCode, queueLimit, i) + } + + // Add cancel to func slice. + cncls = append(cncls, cncl) + + } else { + + // Check the returned status code is expected. + if res.StatusCode != http.StatusTooManyRequests { + t.Fatalf("did not return status 429 (%d) with queueLimit=%d and request=%d", res.StatusCode, queueLimit, i) + } + + // Check the returned retry-after header is set. + if res.Header.Get("Retry-After") != retryAfterStr { + t.Fatalf("did not return retry-after %s with queueLimit=%d and request=%d", retryAfterStr, queueLimit, i) + } + + // Cancel on return. + defer cncl() + + } + } + + // Cancel all blocked reqs. + for _, cncl := range cncls { + cncl() + } + time.Sleep(time.Second) + + // Check a bunchh more requests + // can now make it through after + // previous requests were released! + for i := 0; i < limit; i++ { + + // Prepare a gin test context. + r := httptest.NewRequest("GET", "/", nil) + rw := httptest.NewRecorder() + + // Pass req through + // engine handler. + go e.ServeHTTP(rw, r) + time.Sleep(time.Millisecond) + + // Get http result. + res := rw.Result() + + // Check status == 200 (default, i.e not set). + if res.StatusCode != http.StatusOK { + t.Fatalf("status code was set (%d) with queueLimit=%d and request=%d", res.StatusCode, queueLimit, i) + } + } +} + +func blockingHandler() gin.HandlerFunc { + return func(ctx *gin.Context) { + <-ctx.Done() + ctx.Status(201) // specifically not 200 + } +}