snoobert/reddit/reddit_test.go

276 lines
6.5 KiB
Go
Raw Normal View History

package reddit
import (
"context"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"reflect"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// ctx is the background context the tests in this file pass to client.Do.
var ctx = context.Background()
// setup starts an in-process HTTP test server and returns a Client pointed at
// it, the server's mux so that tests can register their own handlers, and a
// teardown function that shuts the server down.
//
// The mux comes pre-registered with a stub for /api/v1/access_token that
// always answers with a fixed bearer token; the client is configured to use
// that endpoint as its token URL and the server's URL as its base URL.
//
// setup panics if the client cannot be constructed. It has no *testing.T to
// report through, and returning a nil client would only surface later as an
// unrelated nil-pointer dereference inside the calling test.
func setup() (*Client, *http.ServeMux, func()) {
	mux := http.NewServeMux()
	server := httptest.NewServer(mux)

	mux.HandleFunc("/api/v1/access_token", func(w http.ResponseWriter, r *http.Request) {
		response := `{
"access_token": "token1",
"token_type": "bearer",
"expires_in": 3600,
"scope": "*"
}`
		w.Header().Add(headerContentType, mediaTypeJSON)
		fmt.Fprint(w, response)
	})

	client, err := NewClient(
		WithCredentials("id1", "secret1", "user1", "password1"),
		WithBaseURL(server.URL),
		WithTokenURL(server.URL+"/api/v1/access_token"),
	)
	if err != nil {
		// Release the listener before bailing out so a failed setup does
		// not leak the test server.
		server.Close()
		panic(fmt.Sprintf("setup: creating test client: %v", err))
	}

	return client, mux, server.Close
}
// readFileContents returns the entire contents of the file at path as a
// string. If the file cannot be opened or read, it returns the empty string
// together with the underlying error.
func readFileContents(path string) (string, error) {
	f, openErr := os.Open(path)
	if openErr != nil {
		return "", openErr
	}
	defer f.Close()

	data, readErr := ioutil.ReadAll(f)
	if readErr != nil {
		return "", readErr
	}

	return string(data), nil
}
// testClientServices asserts that every service field listed below exists on
// the Client struct and is non-nil.
//
// The fields are looked up by name through reflection, so a field that has
// been renamed or removed is reported as a regular test failure instead of a
// reflect panic.
func testClientServices(t *testing.T, c *Client) {
	t.Helper()

	// reflect.Indirect on a nil pointer yields an invalid Value, on which
	// FieldByName panics; fail with a readable message first.
	require.NotNil(t, c, "client should not be nil")

	services := []string{
		"Account",
		"Collection",
		"Comment",
		"Emoji",
		"Flair",
		"Gold",
		"Listings",
		"LiveThread",
		"Message",
		"Moderation",
		"Multi",
		"Post",
		"Stream",
		"Subreddit",
		"User",
		"Widget",
		"Wiki",
	}

	cv := reflect.Indirect(reflect.ValueOf(c))
	for _, s := range services {
		field := cv.FieldByName(s)
		// FieldByName returns the zero Value for an unknown name, and
		// calling IsNil on that would panic.
		require.Truef(t, field.IsValid(), "Client has no field named %s", s)
		require.Falsef(t, field.IsNil(), "c.%s should not be nil", s)
	}
}
// testClientDefaultUserAgent asserts that c reports the default user agent,
// which is built from libraryName and libraryVersion in the form
// "golang:<name>:v<version>".
func testClientDefaultUserAgent(t *testing.T, c *Client) {
	// Mark as a helper so a failure is attributed to the calling test.
	t.Helper()

	expectedUserAgent := fmt.Sprintf("golang:%s:v%s", libraryName, libraryVersion)
	require.Equal(t, expectedUserAgent, c.UserAgent())
}
// testClientDefaults runs the default-configuration checks against c: first
// the default user agent, then the presence of every service field.
func testClientDefaults(t *testing.T, c *Client) {
	// Mark as a helper so a failure is attributed to the calling test.
	t.Helper()

	testClientDefaultUserAgent(t, c)
	testClientServices(t, c)
}
// TestNewClient checks that NewClient called without options succeeds and
// yields a client that passes the default-configuration checks.
func TestNewClient(t *testing.T) {
	client, err := NewClient()
	require.NoError(t, err)

	testClientDefaults(t, client)
}
// TestNewClient_Error checks that an error returned by an option is handed
// back unchanged by NewClient.
func TestNewClient_Error(t *testing.T) {
	// Baseline: construction without options must succeed.
	_, err := NewClient()
	require.NoError(t, err)

	// An option that always fails.
	failingOpt := func(*Client) error {
		return errors.New("foo")
	}

	_, err = NewClient(failingOpt)
	require.EqualError(t, err, "foo")
}
// TestNewReadonlyClient checks that NewReadonlyClient called without options
// succeeds and that the resulting client's base URL is
// defaultBaseURLReadonly.
func TestNewReadonlyClient(t *testing.T) {
	c, err := NewReadonlyClient()
	require.NoError(t, err)

	// require.Equal takes (expected, actual); keeping that order makes the
	// failure output label the two values correctly.
	require.Equal(t, defaultBaseURLReadonly, c.BaseURL.String())
}
// TestNewReadonlyClient_Error checks that an error returned by an option is
// handed back unchanged by NewReadonlyClient.
func TestNewReadonlyClient_Error(t *testing.T) {
	// Baseline: construction without options must succeed.
	_, err := NewReadonlyClient()
	require.NoError(t, err)

	// An option that always fails.
	failingOpt := func(*Client) error {
		return errors.New("foo")
	}

	_, err = NewReadonlyClient(failingOpt)
	require.EqualError(t, err, "foo")
}
// TestDefaultClient asserts that the package-level DefaultClient identifier
// is not nil.
//
// NOTE(review): DefaultClient is referenced without a call here. If it is
// declared as a function rather than a variable, this assertion inspects the
// function value itself, which is never nil — confirm against its declaration.
func TestDefaultClient(t *testing.T) {
require.NotNil(t, DefaultClient)
}
// TestClient_Readonly_NewRequest checks that a request built by a read-only
// client for the path "r/golang" targets defaultBaseURLReadonly with
// "/r/golang.json" appended.
func TestClient_Readonly_NewRequest(t *testing.T) {
	client, err := NewReadonlyClient()
	require.NoError(t, err)

	request, err := client.NewRequest(http.MethodGet, "r/golang", nil)
	require.NoError(t, err)

	wantURL := defaultBaseURLReadonly + "/r/golang.json"
	require.Equal(t, wantURL, request.URL.String())
}
// TestClient_OnRequestComplemented checks that the callback registered with
// OnRequestCompleted is invoked exactly once for every call to Do, by
// counting invocations across six requests.
func TestClient_OnRequestComplemented(t *testing.T) {
	client, mux, teardown := setup()
	defer teardown()

	var calls int
	client.OnRequestCompleted(func(*http.Request, *http.Response) {
		calls++
	})

	mux.HandleFunc("/api/v1/test", func(w http.ResponseWriter, r *http.Request) {
		// Handlers run on a server goroutine. require.* ends in FailNow,
		// which may only be called from the goroutine running the test,
		// so report with Errorf instead.
		if r.Method != http.MethodGet {
			t.Errorf("unexpected request method: got %s, want %s", r.Method, http.MethodGet)
		}
	})

	req, err := client.NewRequest(http.MethodGet, "api/v1/test", nil)
	require.NoError(t, err)

	// The results of Do are deliberately discarded throughout: only the
	// number of callback invocations is under test.
	_, _ = client.Do(ctx, req, nil)
	require.Equal(t, 1, calls)

	for n := 0; n < 4; n++ {
		_, _ = client.Do(ctx, req, nil)
	}
	require.Equal(t, 5, calls)

	_, _ = client.Do(ctx, req, nil)
	require.Equal(t, 6, calls)
}
// TestClient_JSONErrorResponse checks that a 200 response whose body carries
// a "json.errors" array is returned by Do as a *JSONErrorResponse with the
// expected message, while the response itself still reports status 200.
func TestClient_JSONErrorResponse(t *testing.T) {
	client, mux, teardown := setup()
	defer teardown()

	mux.HandleFunc("/api/v1/test", func(w http.ResponseWriter, r *http.Request) {
		// Handlers run on a server goroutine. require.* ends in FailNow,
		// which may only be called from the goroutine running the test,
		// so report with Errorf instead.
		if r.Method != http.MethodGet {
			t.Errorf("unexpected request method: got %s, want %s", r.Method, http.MethodGet)
		}
		fmt.Fprint(w, `{
"json": {
"errors": [
[
"TEST_ERROR",
"this is a test error",
"test field"
]
]
}
}`)
	})

	req, err := client.NewRequest(http.MethodGet, "api/v1/test", nil)
	require.NoError(t, err)

	resp, err := client.Do(ctx, req, nil)
	require.IsType(t, &JSONErrorResponse{}, err)
	require.EqualError(t, err, fmt.Sprintf(`GET %s/api/v1/test: 200 field "test field" caused TEST_ERROR: this is a test error`, client.BaseURL))

	// Guard the dereference below so a missing response is a test failure,
	// not a nil-pointer panic.
	require.NotNil(t, resp)
	require.Equal(t, http.StatusOK, resp.StatusCode)
}
// TestClient_ErrorResponse checks that a 403 response with a JSON "message"
// body is returned by Do as an *ErrorResponse with the expected message, and
// that the response reports status 403.
func TestClient_ErrorResponse(t *testing.T) {
	client, mux, teardown := setup()
	defer teardown()

	mux.HandleFunc("/api/v1/test", func(w http.ResponseWriter, r *http.Request) {
		// Handlers run on a server goroutine. require.* ends in FailNow,
		// which may only be called from the goroutine running the test,
		// so report with Errorf instead.
		if r.Method != http.MethodGet {
			t.Errorf("unexpected request method: got %s, want %s", r.Method, http.MethodGet)
		}
		w.WriteHeader(http.StatusForbidden)
		fmt.Fprint(w, `{
"message": "error message"
}`)
	})

	req, err := client.NewRequest(http.MethodGet, "api/v1/test", nil)
	require.NoError(t, err)

	resp, err := client.Do(ctx, req, nil)
	require.IsType(t, &ErrorResponse{}, err)
	require.EqualError(t, err, fmt.Sprintf(`GET %s/api/v1/test: 403 error message`, client.BaseURL))

	// Guard the dereference below so a missing response is a test failure,
	// not a nil-pointer panic.
	require.NotNil(t, resp)
	require.Equal(t, http.StatusForbidden, resp.StatusCode)
}
// TestClient_Do_RateLimitError exercises the client's rate-limit handling in
// three steps:
//
//  1. With the client's stored rate already exhausted, Do must fail with a
//     *RateLimitError and status 429 without the request reaching the server.
//  2. After the stored rate is cleared, a request goes through and the rate
//     headers of the response are exposed on resp.Rate.
//  3. A response announcing zero remaining requests yields a *RateLimitError
//     while still carrying the server's 200 status and rate values.
func TestClient_Do_RateLimitError(t *testing.T) {
	client, mux, teardown := setup()
	defer teardown()

	// counter records how many requests actually reached the server.
	var counter int
	mux.HandleFunc("/api/v1/test", func(w http.ResponseWriter, r *http.Request) {
		// Handlers run on a server goroutine. require.* ends in FailNow,
		// which may only be called from the goroutine running the test,
		// so report with Errorf instead.
		if r.Method != http.MethodGet {
			t.Errorf("unexpected request method: got %s, want %s", r.Method, http.MethodGet)
		}
		defer func() { counter++ }()

		switch counter {
		case 0:
			w.Header().Set(headerRateLimitRemaining, "500")
			w.Header().Set(headerRateLimitUsed, "100")
			w.Header().Set(headerRateLimitReset, "120")
		case 1:
			w.Header().Set(headerRateLimitRemaining, "0")
			w.Header().Set(headerRateLimitUsed, "600")
			w.Header().Set(headerRateLimitReset, "240")
		}
	})

	req, err := client.NewRequest(http.MethodGet, "api/v1/test", nil)
	require.NoError(t, err)

	// Step 1: the locally stored rate is exhausted and resets in the future.
	client.rate.Remaining = 0
	client.rate.Reset = time.Now().Add(time.Minute)

	resp, err := client.Do(ctx, req, nil)
	require.Equal(t, 0, counter)
	require.IsType(t, &RateLimitError{}, err)
	require.NotNil(t, resp)
	require.Equal(t, http.StatusTooManyRequests, resp.StatusCode)

	// Step 2: clear the stored rate so the request is sent.
	client.rate = Rate{}

	resp, err = client.Do(ctx, req, nil)
	require.Equal(t, 1, counter)
	require.NoError(t, err)
	require.NotNil(t, resp)
	require.Equal(t, http.StatusOK, resp.StatusCode)
	require.Equal(t, 500, resp.Rate.Remaining)
	require.Equal(t, 100, resp.Rate.Used)
	// The expected reset time is recomputed from the current clock here, so
	// it can land one second later than the value the client computed when
	// it parsed the response. Allow for that instead of requiring equality.
	require.WithinDuration(t, time.Now().Truncate(time.Second).Add(time.Minute*2), resp.Rate.Reset, time.Second)

	// Step 3: the server reports that no requests remain.
	resp, err = client.Do(ctx, req, nil)
	require.Equal(t, 2, counter)
	require.IsType(t, &RateLimitError{}, err)
	require.NotNil(t, resp)
	require.Equal(t, http.StatusOK, resp.StatusCode)
	require.Equal(t, 0, resp.Rate.Remaining)
	require.Equal(t, 600, resp.Rate.Used)
	require.WithinDuration(t, time.Now().Truncate(time.Second).Add(time.Minute*4), resp.Rate.Reset, time.Second)
}