2020-07-11 13:49:07 -04:00
|
|
|
package reddit
|
2020-04-29 15:59:18 -04:00
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
|
|
|
"errors"
|
|
|
|
"fmt"
|
|
|
|
"io/ioutil"
|
|
|
|
"net/http"
|
|
|
|
"net/http/httptest"
|
|
|
|
"os"
|
|
|
|
"reflect"
|
|
|
|
"testing"
|
2020-09-28 22:13:50 -04:00
|
|
|
"time"
|
2020-05-29 19:50:52 -04:00
|
|
|
|
2020-08-11 16:21:07 -04:00
|
|
|
"github.com/stretchr/testify/require"
|
2020-04-29 15:59:18 -04:00
|
|
|
)
|
|
|
|
|
2020-08-23 22:34:59 -04:00
|
|
|
// ctx is the background context (never cancelled, no deadline) passed to
// every client call made by the tests in this file.
var (
	ctx = context.Background()
)
|
2020-04-29 15:59:18 -04:00
|
|
|
|
2020-11-01 22:15:20 -05:00
|
|
|
func setup(t testing.TB) (*Client, *http.ServeMux) {
|
2020-08-23 22:34:59 -04:00
|
|
|
mux := http.NewServeMux()
|
2020-11-01 22:15:20 -05:00
|
|
|
|
2020-08-23 22:34:59 -04:00
|
|
|
server := httptest.NewServer(mux)
|
2020-11-01 22:15:20 -05:00
|
|
|
t.Cleanup(server.Close)
|
2020-04-29 15:59:18 -04:00
|
|
|
|
|
|
|
mux.HandleFunc("/api/v1/access_token", func(w http.ResponseWriter, r *http.Request) {
|
2020-08-23 22:34:59 -04:00
|
|
|
response := `{
|
2020-04-29 15:59:18 -04:00
|
|
|
"access_token": "token1",
|
|
|
|
"token_type": "bearer",
|
|
|
|
"expires_in": 3600,
|
|
|
|
"scope": "*"
|
2020-08-23 22:34:59 -04:00
|
|
|
}`
|
2020-04-29 15:59:18 -04:00
|
|
|
w.Header().Add(headerContentType, mediaTypeJSON)
|
|
|
|
fmt.Fprint(w, response)
|
|
|
|
})
|
|
|
|
|
2020-08-26 23:13:34 -04:00
|
|
|
client, _ := NewClient(
|
2021-01-24 22:51:27 -05:00
|
|
|
Credentials{"id1", "secret1", "user1", "password1"},
|
2020-04-29 15:59:18 -04:00
|
|
|
WithBaseURL(server.URL),
|
|
|
|
WithTokenURL(server.URL+"/api/v1/access_token"),
|
|
|
|
)
|
|
|
|
|
2020-11-01 22:15:20 -05:00
|
|
|
return client, mux
|
2020-04-29 15:59:18 -04:00
|
|
|
}
|
|
|
|
|
2020-07-21 21:56:32 -04:00
|
|
|
// readFileContents reads the whole file at path and returns its contents as
// a string. If the file cannot be opened or read, it returns "" and the
// underlying error unchanged.
func readFileContents(path string) (string, error) {
	f, openErr := os.Open(path)
	if openErr != nil {
		return "", openErr
	}
	defer f.Close()

	data, readErr := ioutil.ReadAll(f)
	if readErr != nil {
		return "", readErr
	}

	return string(data), nil
}
|
|
|
|
|
|
|
|
func testClientServices(t *testing.T, c *Client) {
|
|
|
|
services := []string{
|
2020-06-27 23:53:59 -04:00
|
|
|
"Account",
|
2020-08-06 22:37:08 -04:00
|
|
|
"Collection",
|
2020-04-29 15:59:18 -04:00
|
|
|
"Comment",
|
2020-08-08 01:27:53 -04:00
|
|
|
"Emoji",
|
2020-04-29 15:59:18 -04:00
|
|
|
"Flair",
|
2020-08-20 17:04:44 -04:00
|
|
|
"Gold",
|
2020-04-29 15:59:18 -04:00
|
|
|
"Listings",
|
2020-09-15 21:41:12 -04:00
|
|
|
"LiveThread",
|
2020-08-16 22:22:13 -04:00
|
|
|
"Message",
|
2020-07-13 23:05:24 -04:00
|
|
|
"Moderation",
|
|
|
|
"Multi",
|
2020-06-22 21:52:34 -04:00
|
|
|
"Post",
|
2020-08-22 00:32:59 -04:00
|
|
|
"Stream",
|
2020-04-29 15:59:18 -04:00
|
|
|
"Subreddit",
|
|
|
|
"User",
|
2020-09-23 23:29:04 -04:00
|
|
|
"Widget",
|
2020-09-03 23:25:16 -04:00
|
|
|
"Wiki",
|
2020-04-29 15:59:18 -04:00
|
|
|
}
|
|
|
|
|
|
|
|
cp := reflect.ValueOf(c)
|
|
|
|
cv := reflect.Indirect(cp)
|
|
|
|
|
|
|
|
for _, s := range services {
|
2020-08-11 16:21:07 -04:00
|
|
|
require.Falsef(t, cv.FieldByName(s).IsNil(), "c.%s should not be nil", s)
|
2020-04-29 15:59:18 -04:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func testClientDefaultUserAgent(t *testing.T, c *Client) {
|
2020-08-29 02:48:22 -04:00
|
|
|
expectedUserAgent := fmt.Sprintf("golang:%s:v%s", libraryName, libraryVersion)
|
2020-08-23 22:25:29 -04:00
|
|
|
require.Equal(t, expectedUserAgent, c.UserAgent())
|
2020-04-29 15:59:18 -04:00
|
|
|
}
|
|
|
|
|
|
|
|
func testClientDefaults(t *testing.T, c *Client) {
|
|
|
|
testClientDefaultUserAgent(t, c)
|
|
|
|
testClientServices(t, c)
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestNewClient(t *testing.T) {
|
2021-01-24 22:51:27 -05:00
|
|
|
c, err := NewClient(Credentials{})
|
2020-08-11 16:21:07 -04:00
|
|
|
require.NoError(t, err)
|
2020-04-29 15:59:18 -04:00
|
|
|
testClientDefaults(t, c)
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestNewClient_Error(t *testing.T) {
|
2021-01-24 22:51:27 -05:00
|
|
|
_, err := NewClient(Credentials{})
|
2020-08-29 02:48:22 -04:00
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
errorOpt := func(c *Client) error {
|
|
|
|
return errors.New("foo")
|
|
|
|
}
|
|
|
|
|
2021-01-24 22:51:27 -05:00
|
|
|
_, err = NewClient(Credentials{}, errorOpt)
|
2020-08-29 02:48:22 -04:00
|
|
|
require.EqualError(t, err, "foo")
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestNewReadonlyClient(t *testing.T) {
|
|
|
|
c, err := NewReadonlyClient()
|
|
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, c.BaseURL.String(), defaultBaseURLReadonly)
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestNewReadonlyClient_Error(t *testing.T) {
|
|
|
|
_, err := NewReadonlyClient()
|
|
|
|
require.NoError(t, err)
|
2020-08-27 00:29:53 -04:00
|
|
|
|
2020-04-29 15:59:18 -04:00
|
|
|
errorOpt := func(c *Client) error {
|
|
|
|
return errors.New("foo")
|
|
|
|
}
|
|
|
|
|
2020-08-29 02:48:22 -04:00
|
|
|
_, err = NewReadonlyClient(errorOpt)
|
2020-08-11 16:21:07 -04:00
|
|
|
require.EqualError(t, err, "foo")
|
2020-04-29 15:59:18 -04:00
|
|
|
}
|
2020-08-11 19:10:23 -04:00
|
|
|
|
2020-08-29 02:48:22 -04:00
|
|
|
func TestDefaultClient(t *testing.T) {
|
2020-11-01 22:25:13 -05:00
|
|
|
require.NotNil(t, DefaultClient())
|
2020-08-29 02:48:22 -04:00
|
|
|
}
|
|
|
|
|
|
|
|
func TestClient_Readonly_NewRequest(t *testing.T) {
|
|
|
|
c, err := NewReadonlyClient()
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
req, err := c.NewRequest(http.MethodGet, "r/golang", nil)
|
|
|
|
require.NoError(t, err)
|
|
|
|
require.Equal(t, defaultBaseURLReadonly+"/r/golang.json", req.URL.String())
|
|
|
|
}
|
|
|
|
|
2020-08-11 19:10:23 -04:00
|
|
|
func TestClient_OnRequestComplemented(t *testing.T) {
|
2020-11-01 22:15:20 -05:00
|
|
|
client, mux := setup(t)
|
2020-08-11 19:10:23 -04:00
|
|
|
|
|
|
|
var i int
|
|
|
|
cb := func(*http.Request, *http.Response) {
|
|
|
|
i++
|
|
|
|
}
|
|
|
|
client.OnRequestCompleted(cb)
|
|
|
|
|
|
|
|
mux.HandleFunc("/api/v1/test", func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
require.Equal(t, http.MethodGet, r.Method)
|
|
|
|
})
|
|
|
|
|
|
|
|
req, err := client.NewRequest(http.MethodGet, "api/v1/test", nil)
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
_, _ = client.Do(ctx, req, nil)
|
|
|
|
require.Equal(t, 1, i)
|
|
|
|
|
|
|
|
_, _ = client.Do(ctx, req, nil)
|
|
|
|
_, _ = client.Do(ctx, req, nil)
|
|
|
|
_, _ = client.Do(ctx, req, nil)
|
|
|
|
_, _ = client.Do(ctx, req, nil)
|
|
|
|
require.Equal(t, 5, i)
|
|
|
|
|
|
|
|
_, _ = client.Do(ctx, req, nil)
|
|
|
|
require.Equal(t, 6, i)
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestClient_JSONErrorResponse(t *testing.T) {
|
2020-11-01 22:15:20 -05:00
|
|
|
client, mux := setup(t)
|
2020-08-11 19:10:23 -04:00
|
|
|
|
|
|
|
mux.HandleFunc("/api/v1/test", func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
require.Equal(t, http.MethodGet, r.Method)
|
|
|
|
fmt.Fprint(w, `{
|
|
|
|
"json": {
|
|
|
|
"errors": [
|
|
|
|
[
|
|
|
|
"TEST_ERROR",
|
|
|
|
"this is a test error",
|
|
|
|
"test field"
|
|
|
|
]
|
|
|
|
]
|
|
|
|
}
|
|
|
|
}`)
|
|
|
|
})
|
|
|
|
|
|
|
|
req, err := client.NewRequest(http.MethodGet, "api/v1/test", nil)
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
resp, err := client.Do(ctx, req, nil)
|
|
|
|
require.IsType(t, &JSONErrorResponse{}, err)
|
2020-08-23 22:34:59 -04:00
|
|
|
require.EqualError(t, err, fmt.Sprintf(`GET %s/api/v1/test: 200 field "test field" caused TEST_ERROR: this is a test error`, client.BaseURL))
|
2020-08-11 19:10:23 -04:00
|
|
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestClient_ErrorResponse(t *testing.T) {
|
2020-11-01 22:15:20 -05:00
|
|
|
client, mux := setup(t)
|
2020-08-11 19:10:23 -04:00
|
|
|
|
|
|
|
mux.HandleFunc("/api/v1/test", func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
require.Equal(t, http.MethodGet, r.Method)
|
|
|
|
w.WriteHeader(http.StatusForbidden)
|
|
|
|
fmt.Fprint(w, `{
|
|
|
|
"message": "error message"
|
|
|
|
}`)
|
|
|
|
})
|
|
|
|
|
|
|
|
req, err := client.NewRequest(http.MethodGet, "api/v1/test", nil)
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
resp, err := client.Do(ctx, req, nil)
|
|
|
|
require.IsType(t, &ErrorResponse{}, err)
|
2020-08-23 22:34:59 -04:00
|
|
|
require.EqualError(t, err, fmt.Sprintf(`GET %s/api/v1/test: 403 error message`, client.BaseURL))
|
2020-08-11 19:10:23 -04:00
|
|
|
require.Equal(t, http.StatusForbidden, resp.StatusCode)
|
|
|
|
}
|
2020-09-28 22:13:50 -04:00
|
|
|
|
|
|
|
func TestClient_Do_RateLimitError(t *testing.T) {
|
2020-11-01 22:15:20 -05:00
|
|
|
client, mux := setup(t)
|
2020-09-28 22:13:50 -04:00
|
|
|
|
|
|
|
var counter int
|
|
|
|
mux.HandleFunc("/api/v1/test", func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
require.Equal(t, http.MethodGet, r.Method)
|
|
|
|
defer func() { counter++ }()
|
|
|
|
|
|
|
|
switch counter {
|
|
|
|
case 0:
|
|
|
|
w.Header().Set(headerRateLimitRemaining, "500")
|
|
|
|
w.Header().Set(headerRateLimitUsed, "100")
|
|
|
|
w.Header().Set(headerRateLimitReset, "120")
|
|
|
|
case 1:
|
|
|
|
w.Header().Set(headerRateLimitRemaining, "0")
|
|
|
|
w.Header().Set(headerRateLimitUsed, "600")
|
|
|
|
w.Header().Set(headerRateLimitReset, "240")
|
|
|
|
}
|
|
|
|
})
|
|
|
|
|
|
|
|
req, err := client.NewRequest(http.MethodGet, "api/v1/test", nil)
|
|
|
|
require.NoError(t, err)
|
|
|
|
|
|
|
|
client.rate.Remaining = 0
|
|
|
|
client.rate.Reset = time.Now().Add(time.Minute)
|
|
|
|
|
|
|
|
resp, err := client.Do(ctx, req, nil)
|
|
|
|
require.Equal(t, 0, counter)
|
|
|
|
require.IsType(t, &RateLimitError{}, err)
|
|
|
|
require.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
|
|
|
|
|
|
|
|
client.rate = Rate{}
|
|
|
|
|
|
|
|
resp, err = client.Do(ctx, req, nil)
|
|
|
|
require.Equal(t, 1, counter)
|
|
|
|
require.Nil(t, err)
|
|
|
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
|
|
|
require.Equal(t, 500, resp.Rate.Remaining)
|
|
|
|
require.Equal(t, 100, resp.Rate.Used)
|
|
|
|
require.Equal(t, time.Now().Truncate(time.Second).Add(time.Minute*2), resp.Rate.Reset)
|
|
|
|
|
|
|
|
resp, err = client.Do(ctx, req, nil)
|
|
|
|
require.Equal(t, 2, counter)
|
|
|
|
require.IsType(t, &RateLimitError{}, err)
|
|
|
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
|
|
|
require.Equal(t, 0, resp.Rate.Remaining)
|
|
|
|
require.Equal(t, 600, resp.Rate.Used)
|
|
|
|
require.Equal(t, time.Now().Truncate(time.Second).Add(time.Minute*4), resp.Rate.Reset)
|
|
|
|
}
|