Rate limit errors, prevent requests if above limit
Signed-off-by: Vartan Benohanian <vartanbeno@gmail.com>
This commit is contained in:
parent
28b59f02c7
commit
6f79da6e35
3 changed files with 170 additions and 2 deletions
|
@ -5,6 +5,7 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// APIError is an error coming from Reddit.
|
||||
|
@ -72,4 +73,33 @@ func (r *ErrorResponse) Error() string {
|
|||
)
|
||||
}
|
||||
|
||||
// todo: rate limit errors
|
||||
// RateLimitError occurs when the client is sending too many requests to Reddit in a given time frame.
|
||||
type RateLimitError struct {
|
||||
// Rate specifies the last known rate limit for the client
|
||||
Rate Rate
|
||||
// HTTP response that caused this error
|
||||
Response *http.Response
|
||||
// Error message
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *RateLimitError) Error() string {
|
||||
return fmt.Sprintf(
|
||||
"%s %s: %d %s %s",
|
||||
e.Response.Request.Method, e.Response.Request.URL, e.Response.StatusCode, e.Message, e.formateRateReset(),
|
||||
)
|
||||
}
|
||||
|
||||
func (e *RateLimitError) formateRateReset() string {
|
||||
d := time.Until(e.Rate.Reset).Round(time.Second)
|
||||
|
||||
isNegative := d < 0
|
||||
if isNegative {
|
||||
d *= -1
|
||||
}
|
||||
|
||||
if isNegative {
|
||||
return fmt.Sprintf("[rate limit was reset %s ago]", d)
|
||||
}
|
||||
return fmt.Sprintf("[rate limit will reset in %s]", d)
|
||||
}
|
||||
|
|
|
@ -10,7 +10,10 @@ import (
|
|||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/go-querystring/query"
|
||||
"golang.org/x/oauth2"
|
||||
|
@ -30,6 +33,10 @@ const (
|
|||
headerContentType = "Content-Type"
|
||||
headerAccept = "Accept"
|
||||
headerUserAgent = "User-Agent"
|
||||
|
||||
headerRateLimitRemaining = "x-ratelimit-remaining"
|
||||
headerRateLimitUsed = "x-ratelimit-used"
|
||||
headerRateLimitReset = "x-ratelimit-reset"
|
||||
)
|
||||
|
||||
// DefaultClient is a readonly client with limited access to the Reddit API.
|
||||
|
@ -56,6 +63,9 @@ type Client struct {
|
|||
|
||||
userAgent string
|
||||
|
||||
rateMu sync.Mutex
|
||||
rate Rate
|
||||
|
||||
ID string
|
||||
Secret string
|
||||
Username string
|
||||
|
@ -286,11 +296,15 @@ type Response struct {
|
|||
// Pagination anchor indicating there are more results before this id.
|
||||
// todo: not sure yet if responses ever contain this
|
||||
Before string
|
||||
|
||||
// Rate limit information.
|
||||
Rate Rate
|
||||
}
|
||||
|
||||
// newResponse creates a new Response for the provided http.Response.
|
||||
func newResponse(r *http.Response) *Response {
|
||||
response := Response{Response: r}
|
||||
response.Rate = parseRate(r)
|
||||
return &response
|
||||
}
|
||||
|
||||
|
@ -299,10 +313,35 @@ func (r *Response) populateAnchors(a anchor) {
|
|||
r.Before = a.Before()
|
||||
}
|
||||
|
||||
// parseRate parses the rate related headers.
|
||||
func parseRate(r *http.Response) Rate {
|
||||
var rate Rate
|
||||
if remaining := r.Header.Get(headerRateLimitRemaining); remaining != "" {
|
||||
v, _ := strconv.ParseFloat(remaining, 64)
|
||||
rate.Remaining = int(v)
|
||||
}
|
||||
if used := r.Header.Get(headerRateLimitUsed); used != "" {
|
||||
rate.Used, _ = strconv.Atoi(used)
|
||||
}
|
||||
if reset := r.Header.Get(headerRateLimitReset); reset != "" {
|
||||
if v, _ := strconv.ParseInt(reset, 10, 64); v != 0 {
|
||||
rate.Reset = time.Now().Truncate(time.Second).Add(time.Second * time.Duration(v))
|
||||
}
|
||||
}
|
||||
return rate
|
||||
}
|
||||
|
||||
// Do sends an API request and returns the API response. The API response is JSON decoded and stored in the value
|
||||
// pointed to by v, or returned as an error if an API error has occurred. If v implements the io.Writer interface,
|
||||
// the raw response will be written to v, without attempting to decode it.
|
||||
func (c *Client) Do(ctx context.Context, req *http.Request, v interface{}) (*Response, error) {
|
||||
if err := c.checkRateLimitBeforeDo(req); err != nil {
|
||||
return &Response{
|
||||
Response: err.Response,
|
||||
Rate: err.Rate,
|
||||
}, err
|
||||
}
|
||||
|
||||
resp, err := DoRequestWithClient(ctx, c.client, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -315,6 +354,10 @@ func (c *Client) Do(ctx context.Context, req *http.Request, v interface{}) (*Res
|
|||
|
||||
response := newResponse(resp)
|
||||
|
||||
c.rateMu.Lock()
|
||||
c.rate = response.Rate
|
||||
c.rateMu.Unlock()
|
||||
|
||||
err = CheckResponse(resp)
|
||||
if err != nil {
|
||||
return response, err
|
||||
|
@ -341,6 +384,30 @@ func (c *Client) Do(ctx context.Context, req *http.Request, v interface{}) (*Res
|
|||
return response, nil
|
||||
}
|
||||
|
||||
func (c *Client) checkRateLimitBeforeDo(req *http.Request) *RateLimitError {
|
||||
c.rateMu.Lock()
|
||||
rate := c.rate
|
||||
c.rateMu.Unlock()
|
||||
|
||||
if !rate.Reset.IsZero() && rate.Remaining == 0 && time.Now().Before(rate.Reset) {
|
||||
// Create a fake 429 response.
|
||||
resp := &http.Response{
|
||||
Status: http.StatusText(http.StatusTooManyRequests),
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Request: req,
|
||||
Header: make(http.Header),
|
||||
Body: ioutil.NopCloser(strings.NewReader("")),
|
||||
}
|
||||
return &RateLimitError{
|
||||
Rate: rate,
|
||||
Response: resp,
|
||||
Message: fmt.Sprintf("API rate limit still exceeded until %s, not making remote request.", rate.Reset),
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// id returns the client's Reddit ID.
|
||||
func (c *Client) id(ctx context.Context) (string, *Response, error) {
|
||||
if c.redditID != "" {
|
||||
|
@ -371,6 +438,15 @@ func DoRequestWithClient(ctx context.Context, client *http.Client, req *http.Req
|
|||
// A response is considered an error if it has a status code outside the 200 range.
|
||||
// Reddit also sometimes sends errors with 200 codes; we check for those too.
|
||||
func CheckResponse(r *http.Response) error {
|
||||
if r.Header.Get(headerRateLimitRemaining) == "0" {
|
||||
err := &RateLimitError{
|
||||
Rate: parseRate(r),
|
||||
Response: r,
|
||||
}
|
||||
err.Message = fmt.Sprintf("API rate limit has been exceeded until %s.", err.Rate.Reset)
|
||||
return err
|
||||
}
|
||||
|
||||
jsonErrorResponse := &JSONErrorResponse{Response: r}
|
||||
|
||||
data, err := ioutil.ReadAll(r.Body)
|
||||
|
@ -400,7 +476,17 @@ func CheckResponse(r *http.Response) error {
|
|||
return errorResponse
|
||||
}
|
||||
|
||||
// A lot of Reddit's responses returns a "thing": { "kind": "...", "data": {...} }
|
||||
// Rate represents the rate limit for the client.
|
||||
type Rate struct {
|
||||
// The number of remaining requests the client can make in the current 10-minute window.
|
||||
Remaining int `json:"remaining"`
|
||||
// The number of requests the client has made in the current 10-minute window.
|
||||
Used int `json:"used"`
|
||||
// The time at which the current rate limit will reset.
|
||||
Reset time.Time `json:"reset"`
|
||||
}
|
||||
|
||||
// A lot of Reddit's responses return a "thing": { "kind": "...", "data": {...} }
|
||||
// So this is just a nice convenient method to have.
|
||||
func (c *Client) getThing(ctx context.Context, path string, opts interface{}) (*thing, *Response, error) {
|
||||
path, err := addOptions(path, opts)
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
@ -221,3 +222,54 @@ func TestClient_ErrorResponse(t *testing.T) {
|
|||
require.EqualError(t, err, fmt.Sprintf(`GET %s/api/v1/test: 403 error message`, client.BaseURL))
|
||||
require.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestClient_Do_RateLimitError(t *testing.T) {
|
||||
client, mux, teardown := setup()
|
||||
defer teardown()
|
||||
|
||||
var counter int
|
||||
mux.HandleFunc("/api/v1/test", func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodGet, r.Method)
|
||||
defer func() { counter++ }()
|
||||
|
||||
switch counter {
|
||||
case 0:
|
||||
w.Header().Set(headerRateLimitRemaining, "500")
|
||||
w.Header().Set(headerRateLimitUsed, "100")
|
||||
w.Header().Set(headerRateLimitReset, "120")
|
||||
case 1:
|
||||
w.Header().Set(headerRateLimitRemaining, "0")
|
||||
w.Header().Set(headerRateLimitUsed, "600")
|
||||
w.Header().Set(headerRateLimitReset, "240")
|
||||
}
|
||||
})
|
||||
|
||||
req, err := client.NewRequest(http.MethodGet, "api/v1/test", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
client.rate.Remaining = 0
|
||||
client.rate.Reset = time.Now().Add(time.Minute)
|
||||
|
||||
resp, err := client.Do(ctx, req, nil)
|
||||
require.Equal(t, 0, counter)
|
||||
require.IsType(t, &RateLimitError{}, err)
|
||||
require.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
|
||||
|
||||
client.rate = Rate{}
|
||||
|
||||
resp, err = client.Do(ctx, req, nil)
|
||||
require.Equal(t, 1, counter)
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, 500, resp.Rate.Remaining)
|
||||
require.Equal(t, 100, resp.Rate.Used)
|
||||
require.Equal(t, time.Now().Truncate(time.Second).Add(time.Minute*2), resp.Rate.Reset)
|
||||
|
||||
resp, err = client.Do(ctx, req, nil)
|
||||
require.Equal(t, 2, counter)
|
||||
require.IsType(t, &RateLimitError{}, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, 0, resp.Rate.Remaining)
|
||||
require.Equal(t, 600, resp.Rate.Used)
|
||||
require.Equal(t, time.Now().Truncate(time.Second).Add(time.Minute*4), resp.Rate.Reset)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue