Rate limit errors, prevent requests if above limit

Signed-off-by: Vartan Benohanian <vartanbeno@gmail.com>
This commit is contained in:
Vartan Benohanian 2020-09-28 22:13:50 -04:00
parent 28b59f02c7
commit 6f79da6e35
3 changed files with 170 additions and 2 deletions

View file

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
"time"
) )
// APIError is an error coming from Reddit. // APIError is an error coming from Reddit.
@ -72,4 +73,33 @@ func (r *ErrorResponse) Error() string {
) )
} }
// todo: rate limit errors // RateLimitError occurs when the client is sending too many requests to Reddit in a given time frame.
type RateLimitError struct {
// Rate specifies the last known rate limit for the client
Rate Rate
// HTTP response that caused this error
Response *http.Response
// Error message
Message string
}
func (e *RateLimitError) Error() string {
return fmt.Sprintf(
"%s %s: %d %s %s",
e.Response.Request.Method, e.Response.Request.URL, e.Response.StatusCode, e.Message, e.formateRateReset(),
)
}
func (e *RateLimitError) formateRateReset() string {
d := time.Until(e.Rate.Reset).Round(time.Second)
isNegative := d < 0
if isNegative {
d *= -1
}
if isNegative {
return fmt.Sprintf("[rate limit was reset %s ago]", d)
}
return fmt.Sprintf("[rate limit will reset in %s]", d)
}

View file

@ -10,7 +10,10 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"reflect" "reflect"
"strconv"
"strings" "strings"
"sync"
"time"
"github.com/google/go-querystring/query" "github.com/google/go-querystring/query"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@ -30,6 +33,10 @@ const (
headerContentType = "Content-Type" headerContentType = "Content-Type"
headerAccept = "Accept" headerAccept = "Accept"
headerUserAgent = "User-Agent" headerUserAgent = "User-Agent"
headerRateLimitRemaining = "x-ratelimit-remaining"
headerRateLimitUsed = "x-ratelimit-used"
headerRateLimitReset = "x-ratelimit-reset"
) )
// DefaultClient is a readonly client with limited access to the Reddit API. // DefaultClient is a readonly client with limited access to the Reddit API.
@ -56,6 +63,9 @@ type Client struct {
userAgent string userAgent string
rateMu sync.Mutex
rate Rate
ID string ID string
Secret string Secret string
Username string Username string
@ -286,11 +296,15 @@ type Response struct {
// Pagination anchor indicating there are more results before this id. // Pagination anchor indicating there are more results before this id.
// todo: not sure yet if responses ever contain this // todo: not sure yet if responses ever contain this
Before string Before string
// Rate limit information.
Rate Rate
} }
// newResponse creates a new Response for the provided http.Response. // newResponse creates a new Response for the provided http.Response.
func newResponse(r *http.Response) *Response { func newResponse(r *http.Response) *Response {
response := Response{Response: r} response := Response{Response: r}
response.Rate = parseRate(r)
return &response return &response
} }
@ -299,10 +313,35 @@ func (r *Response) populateAnchors(a anchor) {
r.Before = a.Before() r.Before = a.Before()
} }
// parseRate parses the rate related headers.
func parseRate(r *http.Response) Rate {
var rate Rate
if remaining := r.Header.Get(headerRateLimitRemaining); remaining != "" {
v, _ := strconv.ParseFloat(remaining, 64)
rate.Remaining = int(v)
}
if used := r.Header.Get(headerRateLimitUsed); used != "" {
rate.Used, _ = strconv.Atoi(used)
}
if reset := r.Header.Get(headerRateLimitReset); reset != "" {
if v, _ := strconv.ParseInt(reset, 10, 64); v != 0 {
rate.Reset = time.Now().Truncate(time.Second).Add(time.Second * time.Duration(v))
}
}
return rate
}
// Do sends an API request and returns the API response. The API response is JSON decoded and stored in the value // Do sends an API request and returns the API response. The API response is JSON decoded and stored in the value
// pointed to by v, or returned as an error if an API error has occurred. If v implements the io.Writer interface, // pointed to by v, or returned as an error if an API error has occurred. If v implements the io.Writer interface,
// the raw response will be written to v, without attempting to decode it. // the raw response will be written to v, without attempting to decode it.
func (c *Client) Do(ctx context.Context, req *http.Request, v interface{}) (*Response, error) { func (c *Client) Do(ctx context.Context, req *http.Request, v interface{}) (*Response, error) {
if err := c.checkRateLimitBeforeDo(req); err != nil {
return &Response{
Response: err.Response,
Rate: err.Rate,
}, err
}
resp, err := DoRequestWithClient(ctx, c.client, req) resp, err := DoRequestWithClient(ctx, c.client, req)
if err != nil { if err != nil {
return nil, err return nil, err
@ -315,6 +354,10 @@ func (c *Client) Do(ctx context.Context, req *http.Request, v interface{}) (*Res
response := newResponse(resp) response := newResponse(resp)
c.rateMu.Lock()
c.rate = response.Rate
c.rateMu.Unlock()
err = CheckResponse(resp) err = CheckResponse(resp)
if err != nil { if err != nil {
return response, err return response, err
@ -341,6 +384,30 @@ func (c *Client) Do(ctx context.Context, req *http.Request, v interface{}) (*Res
return response, nil return response, nil
} }
func (c *Client) checkRateLimitBeforeDo(req *http.Request) *RateLimitError {
c.rateMu.Lock()
rate := c.rate
c.rateMu.Unlock()
if !rate.Reset.IsZero() && rate.Remaining == 0 && time.Now().Before(rate.Reset) {
// Create a fake 429 response.
resp := &http.Response{
Status: http.StatusText(http.StatusTooManyRequests),
StatusCode: http.StatusTooManyRequests,
Request: req,
Header: make(http.Header),
Body: ioutil.NopCloser(strings.NewReader("")),
}
return &RateLimitError{
Rate: rate,
Response: resp,
Message: fmt.Sprintf("API rate limit still exceeded until %s, not making remote request.", rate.Reset),
}
}
return nil
}
// id returns the client's Reddit ID. // id returns the client's Reddit ID.
func (c *Client) id(ctx context.Context) (string, *Response, error) { func (c *Client) id(ctx context.Context) (string, *Response, error) {
if c.redditID != "" { if c.redditID != "" {
@ -371,6 +438,15 @@ func DoRequestWithClient(ctx context.Context, client *http.Client, req *http.Req
// A response is considered an error if it has a status code outside the 200 range. // A response is considered an error if it has a status code outside the 200 range.
// Reddit also sometimes sends errors with 200 codes; we check for those too. // Reddit also sometimes sends errors with 200 codes; we check for those too.
func CheckResponse(r *http.Response) error { func CheckResponse(r *http.Response) error {
if r.Header.Get(headerRateLimitRemaining) == "0" {
err := &RateLimitError{
Rate: parseRate(r),
Response: r,
}
err.Message = fmt.Sprintf("API rate limit has been exceeded until %s.", err.Rate.Reset)
return err
}
jsonErrorResponse := &JSONErrorResponse{Response: r} jsonErrorResponse := &JSONErrorResponse{Response: r}
data, err := ioutil.ReadAll(r.Body) data, err := ioutil.ReadAll(r.Body)
@ -400,7 +476,17 @@ func CheckResponse(r *http.Response) error {
return errorResponse return errorResponse
} }
// A lot of Reddit's responses returns a "thing": { "kind": "...", "data": {...} } // Rate represents the rate limit for the client.
type Rate struct {
// The number of remaining requests the client can make in the current 10-minute window.
Remaining int `json:"remaining"`
// The number of requests the client has made in the current 10-minute window.
Used int `json:"used"`
// The time at which the current rate limit will reset.
Reset time.Time `json:"reset"`
}
// A lot of Reddit's responses return a "thing": { "kind": "...", "data": {...} }
// So this is just a nice convenient method to have. // So this is just a nice convenient method to have.
func (c *Client) getThing(ctx context.Context, path string, opts interface{}) (*thing, *Response, error) { func (c *Client) getThing(ctx context.Context, path string, opts interface{}) (*thing, *Response, error) {
path, err := addOptions(path, opts) path, err := addOptions(path, opts)

View file

@ -10,6 +10,7 @@ import (
"os" "os"
"reflect" "reflect"
"testing" "testing"
"time"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -221,3 +222,54 @@ func TestClient_ErrorResponse(t *testing.T) {
require.EqualError(t, err, fmt.Sprintf(`GET %s/api/v1/test: 403 error message`, client.BaseURL)) require.EqualError(t, err, fmt.Sprintf(`GET %s/api/v1/test: 403 error message`, client.BaseURL))
require.Equal(t, http.StatusForbidden, resp.StatusCode) require.Equal(t, http.StatusForbidden, resp.StatusCode)
} }
func TestClient_Do_RateLimitError(t *testing.T) {
client, mux, teardown := setup()
defer teardown()
var counter int
mux.HandleFunc("/api/v1/test", func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
defer func() { counter++ }()
switch counter {
case 0:
w.Header().Set(headerRateLimitRemaining, "500")
w.Header().Set(headerRateLimitUsed, "100")
w.Header().Set(headerRateLimitReset, "120")
case 1:
w.Header().Set(headerRateLimitRemaining, "0")
w.Header().Set(headerRateLimitUsed, "600")
w.Header().Set(headerRateLimitReset, "240")
}
})
req, err := client.NewRequest(http.MethodGet, "api/v1/test", nil)
require.NoError(t, err)
client.rate.Remaining = 0
client.rate.Reset = time.Now().Add(time.Minute)
resp, err := client.Do(ctx, req, nil)
require.Equal(t, 0, counter)
require.IsType(t, &RateLimitError{}, err)
require.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
client.rate = Rate{}
resp, err = client.Do(ctx, req, nil)
require.Equal(t, 1, counter)
require.Nil(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, 500, resp.Rate.Remaining)
require.Equal(t, 100, resp.Rate.Used)
require.Equal(t, time.Now().Truncate(time.Second).Add(time.Minute*2), resp.Rate.Reset)
resp, err = client.Do(ctx, req, nil)
require.Equal(t, 2, counter)
require.IsType(t, &RateLimitError{}, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, 0, resp.Rate.Remaining)
require.Equal(t, 600, resp.Rate.Used)
require.Equal(t, time.Now().Truncate(time.Second).Add(time.Minute*4), resp.Rate.Reset)
}