diff --git a/reddit/errors.go b/reddit/errors.go index 549a3c2..3b20418 100644 --- a/reddit/errors.go +++ b/reddit/errors.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "strings" + "time" ) // APIError is an error coming from Reddit. @@ -72,4 +73,33 @@ func (r *ErrorResponse) Error() string { ) } -// todo: rate limit errors +// RateLimitError occurs when the client is sending too many requests to Reddit in a given time frame. +type RateLimitError struct { + // Rate specifies the last known rate limit for the client + Rate Rate + // HTTP response that caused this error + Response *http.Response + // Error message + Message string +} + +func (e *RateLimitError) Error() string { + return fmt.Sprintf( + "%s %s: %d %s %s", + e.Response.Request.Method, e.Response.Request.URL, e.Response.StatusCode, e.Message, e.formateRateReset(), + ) +} + +func (e *RateLimitError) formateRateReset() string { + d := time.Until(e.Rate.Reset).Round(time.Second) + + isNegative := d < 0 + if isNegative { + d *= -1 + } + + if isNegative { + return fmt.Sprintf("[rate limit was reset %s ago]", d) + } + return fmt.Sprintf("[rate limit will reset in %s]", d) +} diff --git a/reddit/reddit.go b/reddit/reddit.go index 3183630..939f0f6 100644 --- a/reddit/reddit.go +++ b/reddit/reddit.go @@ -10,7 +10,10 @@ import ( "net/http" "net/url" "reflect" + "strconv" "strings" + "sync" + "time" "github.com/google/go-querystring/query" "golang.org/x/oauth2" @@ -30,6 +33,10 @@ const ( headerContentType = "Content-Type" headerAccept = "Accept" headerUserAgent = "User-Agent" + + headerRateLimitRemaining = "x-ratelimit-remaining" + headerRateLimitUsed = "x-ratelimit-used" + headerRateLimitReset = "x-ratelimit-reset" ) // DefaultClient is a readonly client with limited access to the Reddit API. @@ -56,6 +63,9 @@ type Client struct { userAgent string + rateMu sync.Mutex + rate Rate + ID string Secret string Username string @@ -286,11 +296,15 @@ type Response struct { // Pagination anchor indicating there are more results before this id. // todo: not sure yet if responses ever contain this Before string + + // Rate limit information. + Rate Rate } // newResponse creates a new Response for the provided http.Response. func newResponse(r *http.Response) *Response { response := Response{Response: r} + response.Rate = parseRate(r) return &response } @@ -299,10 +313,35 @@ func (r *Response) populateAnchors(a anchor) { r.Before = a.Before() } +// parseRate parses the rate related headers. +func parseRate(r *http.Response) Rate { + var rate Rate + if remaining := r.Header.Get(headerRateLimitRemaining); remaining != "" { + v, _ := strconv.ParseFloat(remaining, 64) + rate.Remaining = int(v) + } + if used := r.Header.Get(headerRateLimitUsed); used != "" { + rate.Used, _ = strconv.Atoi(used) + } + if reset := r.Header.Get(headerRateLimitReset); reset != "" { + if v, _ := strconv.ParseInt(reset, 10, 64); v != 0 { + rate.Reset = time.Now().Truncate(time.Second).Add(time.Second * time.Duration(v)) + } + } + return rate +} + // Do sends an API request and returns the API response. The API response is JSON decoded and stored in the value // pointed to by v, or returned as an error if an API error has occurred. If v implements the io.Writer interface, // the raw response will be written to v, without attempting to decode it. func (c *Client) Do(ctx context.Context, req *http.Request, v interface{}) (*Response, error) { + if err := c.checkRateLimitBeforeDo(req); err != nil { + return &Response{ + Response: err.Response, + Rate: err.Rate, + }, err + } + resp, err := DoRequestWithClient(ctx, c.client, req) if err != nil { return nil, err @@ -315,6 +354,10 @@ func (c *Client) Do(ctx context.Context, req *http.Request, v interface{}) (*Res response := newResponse(resp) + c.rateMu.Lock() + c.rate = response.Rate + c.rateMu.Unlock() + err = CheckResponse(resp) if err != nil { return response, err @@ -341,6 +384,30 @@ func (c *Client) Do(ctx context.Context, req *http.Request, v interface{}) (*Res return response, nil } +func (c *Client) checkRateLimitBeforeDo(req *http.Request) *RateLimitError { + c.rateMu.Lock() + rate := c.rate + c.rateMu.Unlock() + + if !rate.Reset.IsZero() && rate.Remaining == 0 && time.Now().Before(rate.Reset) { + // Create a fake 429 response. + resp := &http.Response{ + Status: http.StatusText(http.StatusTooManyRequests), + StatusCode: http.StatusTooManyRequests, + Request: req, + Header: make(http.Header), + Body: ioutil.NopCloser(strings.NewReader("")), + } + return &RateLimitError{ + Rate: rate, + Response: resp, + Message: fmt.Sprintf("API rate limit still exceeded until %s, not making remote request.", rate.Reset), + } + } + + return nil +} + // id returns the client's Reddit ID. func (c *Client) id(ctx context.Context) (string, *Response, error) { if c.redditID != "" { @@ -371,6 +438,15 @@ func DoRequestWithClient(ctx context.Context, client *http.Client, req *http.Req // A response is considered an error if it has a status code outside the 200 range. // Reddit also sometimes sends errors with 200 codes; we check for those too. func CheckResponse(r *http.Response) error { + if r.Header.Get(headerRateLimitRemaining) == "0" { + err := &RateLimitError{ + Rate: parseRate(r), + Response: r, + } + err.Message = fmt.Sprintf("API rate limit has been exceeded until %s.", err.Rate.Reset) + return err + } + jsonErrorResponse := &JSONErrorResponse{Response: r} data, err := ioutil.ReadAll(r.Body) @@ -400,7 +476,17 @@ func CheckResponse(r *http.Response) error { return errorResponse } -// A lot of Reddit's responses returns a "thing": { "kind": "...", "data": {...} } +// Rate represents the rate limit for the client. +type Rate struct { + // The number of remaining requests the client can make in the current 10-minute window. + Remaining int `json:"remaining"` + // The number of requests the client has made in the current 10-minute window. + Used int `json:"used"` + // The time at which the current rate limit will reset. + Reset time.Time `json:"reset"` +} + +// A lot of Reddit's responses return a "thing": { "kind": "...", "data": {...} } // So this is just a nice convenient method to have. func (c *Client) getThing(ctx context.Context, path string, opts interface{}) (*thing, *Response, error) { path, err := addOptions(path, opts) diff --git a/reddit/reddit_test.go b/reddit/reddit_test.go index 89c762f..d252aa6 100644 --- a/reddit/reddit_test.go +++ b/reddit/reddit_test.go @@ -10,6 +10,7 @@ import ( "os" "reflect" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -221,3 +222,54 @@ func TestClient_ErrorResponse(t *testing.T) { require.EqualError(t, err, fmt.Sprintf(`GET %s/api/v1/test: 403 error message`, client.BaseURL)) require.Equal(t, http.StatusForbidden, resp.StatusCode) } + +func TestClient_Do_RateLimitError(t *testing.T) { + client, mux, teardown := setup() + defer teardown() + + var counter int + mux.HandleFunc("/api/v1/test", func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + defer func() { counter++ }() + + switch counter { + case 0: + w.Header().Set(headerRateLimitRemaining, "500") + w.Header().Set(headerRateLimitUsed, "100") + w.Header().Set(headerRateLimitReset, "120") + case 1: + w.Header().Set(headerRateLimitRemaining, "0") + w.Header().Set(headerRateLimitUsed, "600") + w.Header().Set(headerRateLimitReset, "240") + } + }) + + req, err := client.NewRequest(http.MethodGet, "api/v1/test", nil) + require.NoError(t, err) + + client.rate.Remaining = 0 + client.rate.Reset = time.Now().Add(time.Minute) + + resp, err := client.Do(ctx, req, nil) + require.Equal(t, 0, counter) + require.IsType(t, &RateLimitError{}, err) + require.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + + client.rate = Rate{} + + resp, err = client.Do(ctx, req, nil) + require.Equal(t, 1, counter) + require.Nil(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, 500, resp.Rate.Remaining) + require.Equal(t, 100, resp.Rate.Used) + require.Equal(t, time.Now().Truncate(time.Second).Add(time.Minute*2), resp.Rate.Reset) + + resp, err = client.Do(ctx, req, nil) + require.Equal(t, 2, counter) + require.IsType(t, &RateLimitError{}, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, 0, resp.Rate.Remaining) + require.Equal(t, 600, resp.Rate.Used) + require.Equal(t, time.Now().Truncate(time.Second).Add(time.Minute*4), resp.Rate.Reset) +}