Rate limit errors, prevent requests if above limit
Signed-off-by: Vartan Benohanian <vartanbeno@gmail.com>
This commit is contained in:
parent
28b59f02c7
commit
6f79da6e35
3 changed files with 170 additions and 2 deletions
|
@ -5,6 +5,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// APIError is an error coming from Reddit.
|
// APIError is an error coming from Reddit.
|
||||||
|
@ -72,4 +73,33 @@ func (r *ErrorResponse) Error() string {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo: rate limit errors
|
// RateLimitError occurs when the client is sending too many requests to Reddit in a given time frame.
|
||||||
|
type RateLimitError struct {
|
||||||
|
// Rate specifies the last known rate limit for the client
|
||||||
|
Rate Rate
|
||||||
|
// HTTP response that caused this error
|
||||||
|
Response *http.Response
|
||||||
|
// Error message
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *RateLimitError) Error() string {
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"%s %s: %d %s %s",
|
||||||
|
e.Response.Request.Method, e.Response.Request.URL, e.Response.StatusCode, e.Message, e.formateRateReset(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *RateLimitError) formateRateReset() string {
|
||||||
|
d := time.Until(e.Rate.Reset).Round(time.Second)
|
||||||
|
|
||||||
|
isNegative := d < 0
|
||||||
|
if isNegative {
|
||||||
|
d *= -1
|
||||||
|
}
|
||||||
|
|
||||||
|
if isNegative {
|
||||||
|
return fmt.Sprintf("[rate limit was reset %s ago]", d)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("[rate limit will reset in %s]", d)
|
||||||
|
}
|
||||||
|
|
|
@ -10,7 +10,10 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/go-querystring/query"
|
"github.com/google/go-querystring/query"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
@ -30,6 +33,10 @@ const (
|
||||||
headerContentType = "Content-Type"
|
headerContentType = "Content-Type"
|
||||||
headerAccept = "Accept"
|
headerAccept = "Accept"
|
||||||
headerUserAgent = "User-Agent"
|
headerUserAgent = "User-Agent"
|
||||||
|
|
||||||
|
headerRateLimitRemaining = "x-ratelimit-remaining"
|
||||||
|
headerRateLimitUsed = "x-ratelimit-used"
|
||||||
|
headerRateLimitReset = "x-ratelimit-reset"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DefaultClient is a readonly client with limited access to the Reddit API.
|
// DefaultClient is a readonly client with limited access to the Reddit API.
|
||||||
|
@ -56,6 +63,9 @@ type Client struct {
|
||||||
|
|
||||||
userAgent string
|
userAgent string
|
||||||
|
|
||||||
|
rateMu sync.Mutex
|
||||||
|
rate Rate
|
||||||
|
|
||||||
ID string
|
ID string
|
||||||
Secret string
|
Secret string
|
||||||
Username string
|
Username string
|
||||||
|
@ -286,11 +296,15 @@ type Response struct {
|
||||||
// Pagination anchor indicating there are more results before this id.
|
// Pagination anchor indicating there are more results before this id.
|
||||||
// todo: not sure yet if responses ever contain this
|
// todo: not sure yet if responses ever contain this
|
||||||
Before string
|
Before string
|
||||||
|
|
||||||
|
// Rate limit information.
|
||||||
|
Rate Rate
|
||||||
}
|
}
|
||||||
|
|
||||||
// newResponse creates a new Response for the provided http.Response.
|
// newResponse creates a new Response for the provided http.Response.
|
||||||
func newResponse(r *http.Response) *Response {
|
func newResponse(r *http.Response) *Response {
|
||||||
response := Response{Response: r}
|
response := Response{Response: r}
|
||||||
|
response.Rate = parseRate(r)
|
||||||
return &response
|
return &response
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -299,10 +313,35 @@ func (r *Response) populateAnchors(a anchor) {
|
||||||
r.Before = a.Before()
|
r.Before = a.Before()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseRate parses the rate related headers.
|
||||||
|
func parseRate(r *http.Response) Rate {
|
||||||
|
var rate Rate
|
||||||
|
if remaining := r.Header.Get(headerRateLimitRemaining); remaining != "" {
|
||||||
|
v, _ := strconv.ParseFloat(remaining, 64)
|
||||||
|
rate.Remaining = int(v)
|
||||||
|
}
|
||||||
|
if used := r.Header.Get(headerRateLimitUsed); used != "" {
|
||||||
|
rate.Used, _ = strconv.Atoi(used)
|
||||||
|
}
|
||||||
|
if reset := r.Header.Get(headerRateLimitReset); reset != "" {
|
||||||
|
if v, _ := strconv.ParseInt(reset, 10, 64); v != 0 {
|
||||||
|
rate.Reset = time.Now().Truncate(time.Second).Add(time.Second * time.Duration(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return rate
|
||||||
|
}
|
||||||
|
|
||||||
// Do sends an API request and returns the API response. The API response is JSON decoded and stored in the value
|
// Do sends an API request and returns the API response. The API response is JSON decoded and stored in the value
|
||||||
// pointed to by v, or returned as an error if an API error has occurred. If v implements the io.Writer interface,
|
// pointed to by v, or returned as an error if an API error has occurred. If v implements the io.Writer interface,
|
||||||
// the raw response will be written to v, without attempting to decode it.
|
// the raw response will be written to v, without attempting to decode it.
|
||||||
func (c *Client) Do(ctx context.Context, req *http.Request, v interface{}) (*Response, error) {
|
func (c *Client) Do(ctx context.Context, req *http.Request, v interface{}) (*Response, error) {
|
||||||
|
if err := c.checkRateLimitBeforeDo(req); err != nil {
|
||||||
|
return &Response{
|
||||||
|
Response: err.Response,
|
||||||
|
Rate: err.Rate,
|
||||||
|
}, err
|
||||||
|
}
|
||||||
|
|
||||||
resp, err := DoRequestWithClient(ctx, c.client, req)
|
resp, err := DoRequestWithClient(ctx, c.client, req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -315,6 +354,10 @@ func (c *Client) Do(ctx context.Context, req *http.Request, v interface{}) (*Res
|
||||||
|
|
||||||
response := newResponse(resp)
|
response := newResponse(resp)
|
||||||
|
|
||||||
|
c.rateMu.Lock()
|
||||||
|
c.rate = response.Rate
|
||||||
|
c.rateMu.Unlock()
|
||||||
|
|
||||||
err = CheckResponse(resp)
|
err = CheckResponse(resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return response, err
|
return response, err
|
||||||
|
@ -341,6 +384,30 @@ func (c *Client) Do(ctx context.Context, req *http.Request, v interface{}) (*Res
|
||||||
return response, nil
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) checkRateLimitBeforeDo(req *http.Request) *RateLimitError {
|
||||||
|
c.rateMu.Lock()
|
||||||
|
rate := c.rate
|
||||||
|
c.rateMu.Unlock()
|
||||||
|
|
||||||
|
if !rate.Reset.IsZero() && rate.Remaining == 0 && time.Now().Before(rate.Reset) {
|
||||||
|
// Create a fake 429 response.
|
||||||
|
resp := &http.Response{
|
||||||
|
Status: http.StatusText(http.StatusTooManyRequests),
|
||||||
|
StatusCode: http.StatusTooManyRequests,
|
||||||
|
Request: req,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: ioutil.NopCloser(strings.NewReader("")),
|
||||||
|
}
|
||||||
|
return &RateLimitError{
|
||||||
|
Rate: rate,
|
||||||
|
Response: resp,
|
||||||
|
Message: fmt.Sprintf("API rate limit still exceeded until %s, not making remote request.", rate.Reset),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// id returns the client's Reddit ID.
|
// id returns the client's Reddit ID.
|
||||||
func (c *Client) id(ctx context.Context) (string, *Response, error) {
|
func (c *Client) id(ctx context.Context) (string, *Response, error) {
|
||||||
if c.redditID != "" {
|
if c.redditID != "" {
|
||||||
|
@ -371,6 +438,15 @@ func DoRequestWithClient(ctx context.Context, client *http.Client, req *http.Req
|
||||||
// A response is considered an error if it has a status code outside the 200 range.
|
// A response is considered an error if it has a status code outside the 200 range.
|
||||||
// Reddit also sometimes sends errors with 200 codes; we check for those too.
|
// Reddit also sometimes sends errors with 200 codes; we check for those too.
|
||||||
func CheckResponse(r *http.Response) error {
|
func CheckResponse(r *http.Response) error {
|
||||||
|
if r.Header.Get(headerRateLimitRemaining) == "0" {
|
||||||
|
err := &RateLimitError{
|
||||||
|
Rate: parseRate(r),
|
||||||
|
Response: r,
|
||||||
|
}
|
||||||
|
err.Message = fmt.Sprintf("API rate limit has been exceeded until %s.", err.Rate.Reset)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
jsonErrorResponse := &JSONErrorResponse{Response: r}
|
jsonErrorResponse := &JSONErrorResponse{Response: r}
|
||||||
|
|
||||||
data, err := ioutil.ReadAll(r.Body)
|
data, err := ioutil.ReadAll(r.Body)
|
||||||
|
@ -400,7 +476,17 @@ func CheckResponse(r *http.Response) error {
|
||||||
return errorResponse
|
return errorResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
// A lot of Reddit's responses returns a "thing": { "kind": "...", "data": {...} }
|
// Rate represents the rate limit for the client.
|
||||||
|
type Rate struct {
|
||||||
|
// The number of remaining requests the client can make in the current 10-minute window.
|
||||||
|
Remaining int `json:"remaining"`
|
||||||
|
// The number of requests the client has made in the current 10-minute window.
|
||||||
|
Used int `json:"used"`
|
||||||
|
// The time at which the current rate limit will reset.
|
||||||
|
Reset time.Time `json:"reset"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// A lot of Reddit's responses return a "thing": { "kind": "...", "data": {...} }
|
||||||
// So this is just a nice convenient method to have.
|
// So this is just a nice convenient method to have.
|
||||||
func (c *Client) getThing(ctx context.Context, path string, opts interface{}) (*thing, *Response, error) {
|
func (c *Client) getThing(ctx context.Context, path string, opts interface{}) (*thing, *Response, error) {
|
||||||
path, err := addOptions(path, opts)
|
path, err := addOptions(path, opts)
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
@ -221,3 +222,54 @@ func TestClient_ErrorResponse(t *testing.T) {
|
||||||
require.EqualError(t, err, fmt.Sprintf(`GET %s/api/v1/test: 403 error message`, client.BaseURL))
|
require.EqualError(t, err, fmt.Sprintf(`GET %s/api/v1/test: 403 error message`, client.BaseURL))
|
||||||
require.Equal(t, http.StatusForbidden, resp.StatusCode)
|
require.Equal(t, http.StatusForbidden, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestClient_Do_RateLimitError(t *testing.T) {
|
||||||
|
client, mux, teardown := setup()
|
||||||
|
defer teardown()
|
||||||
|
|
||||||
|
var counter int
|
||||||
|
mux.HandleFunc("/api/v1/test", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
require.Equal(t, http.MethodGet, r.Method)
|
||||||
|
defer func() { counter++ }()
|
||||||
|
|
||||||
|
switch counter {
|
||||||
|
case 0:
|
||||||
|
w.Header().Set(headerRateLimitRemaining, "500")
|
||||||
|
w.Header().Set(headerRateLimitUsed, "100")
|
||||||
|
w.Header().Set(headerRateLimitReset, "120")
|
||||||
|
case 1:
|
||||||
|
w.Header().Set(headerRateLimitRemaining, "0")
|
||||||
|
w.Header().Set(headerRateLimitUsed, "600")
|
||||||
|
w.Header().Set(headerRateLimitReset, "240")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
req, err := client.NewRequest(http.MethodGet, "api/v1/test", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
client.rate.Remaining = 0
|
||||||
|
client.rate.Reset = time.Now().Add(time.Minute)
|
||||||
|
|
||||||
|
resp, err := client.Do(ctx, req, nil)
|
||||||
|
require.Equal(t, 0, counter)
|
||||||
|
require.IsType(t, &RateLimitError{}, err)
|
||||||
|
require.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
|
||||||
|
|
||||||
|
client.rate = Rate{}
|
||||||
|
|
||||||
|
resp, err = client.Do(ctx, req, nil)
|
||||||
|
require.Equal(t, 1, counter)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
require.Equal(t, 500, resp.Rate.Remaining)
|
||||||
|
require.Equal(t, 100, resp.Rate.Used)
|
||||||
|
require.Equal(t, time.Now().Truncate(time.Second).Add(time.Minute*2), resp.Rate.Reset)
|
||||||
|
|
||||||
|
resp, err = client.Do(ctx, req, nil)
|
||||||
|
require.Equal(t, 2, counter)
|
||||||
|
require.IsType(t, &RateLimitError{}, err)
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
require.Equal(t, 0, resp.Rate.Remaining)
|
||||||
|
require.Equal(t, 600, resp.Rate.Used)
|
||||||
|
require.Equal(t, time.Now().Truncate(time.Second).Add(time.Minute*4), resp.Rate.Reset)
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue