From c4faa00b9411e9db61dcaab5a136d2782fbf7858 Mon Sep 17 00:00:00 2001 From: Vartan Benohanian Date: Thu, 27 Aug 2020 00:29:53 -0400 Subject: [PATCH] Return errors on nil credentials, move code around Signed-off-by: Vartan Benohanian --- reddit/reddit-oauth.go | 14 ++----- reddit/reddit-options.go | 31 ++++++++------- reddit/reddit-options_test.go | 49 ++++++++++++------------ reddit/reddit-user-agent.go | 42 +++++++++++++++++++++ reddit/reddit.go | 71 +++++++++++------------------------ reddit/reddit_test.go | 7 +++- 6 files changed, 110 insertions(+), 104 deletions(-) create mode 100644 reddit/reddit-user-agent.go diff --git a/reddit/reddit-oauth.go b/reddit/reddit-oauth.go index b18efdd..13030d6 100644 --- a/reddit/reddit-oauth.go +++ b/reddit/reddit-oauth.go @@ -51,16 +51,8 @@ func (s *oauthTokenSource) Token() (*oauth2.Token, error) { } func oauthTransport(client *Client) http.RoundTripper { - // We need to set a custom user agent, because using the one set by default by the - // stdlib gives us 429 Too Many Request responses from the Reddit API. - userAgentTransport := &userAgentTransport{ - userAgent: client.UserAgent(), - Base: client.client.Transport, - } - - httpClient := &http.Client{Transport: userAgentTransport} - ctx := context.Background() - ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) + httpClient := &http.Client{Transport: client.client.Transport} + ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient) config := &oauth2.Config{ ClientID: client.ID, @@ -80,6 +72,6 @@ func oauthTransport(client *Client) http.RoundTripper { return &oauth2.Transport{ Source: tokenSource, - Base: userAgentTransport, + Base: client.client.Transport, } } diff --git a/reddit/reddit-options.go b/reddit/reddit-options.go index c1b434d..4986131 100644 --- a/reddit/reddit-options.go +++ b/reddit/reddit-options.go @@ -4,7 +4,6 @@ import ( "errors" "net/http" "net/url" - "os" ) // Opt is a configuration option to initialize a client. @@ -27,21 +26,21 @@ func WithHTTPClient(httpClient *http.Client) Opt { // GO_REDDIT_CLIENT_SECRET to set the client's secret. // GO_REDDIT_CLIENT_USERNAME to set the client's username. // GO_REDDIT_CLIENT_PASSWORD to set the client's password. -func FromEnv(c *Client) error { - if v := os.Getenv("GO_REDDIT_CLIENT_ID"); v != "" { - c.ID = v - } - if v := os.Getenv("GO_REDDIT_CLIENT_SECRET"); v != "" { - c.Secret = v - } - if v := os.Getenv("GO_REDDIT_CLIENT_USERNAME"); v != "" { - c.Username = v - } - if v := os.Getenv("GO_REDDIT_CLIENT_PASSWORD"); v != "" { - c.Password = v - } - return nil -} +// func FromEnv(c *Client) error { +// if v := os.Getenv("GO_REDDIT_CLIENT_ID"); v != "" { +// c.ID = v +// } +// if v := os.Getenv("GO_REDDIT_CLIENT_SECRET"); v != "" { +// c.Secret = v +// } +// if v := os.Getenv("GO_REDDIT_CLIENT_USERNAME"); v != "" { +// c.Username = v +// } +// if v := os.Getenv("GO_REDDIT_CLIENT_PASSWORD"); v != "" { +// c.Password = v +// } +// return nil +// } // WithBaseURL sets the base URL for the client to make requests to. func WithBaseURL(u string) Opt { diff --git a/reddit/reddit-options_test.go b/reddit/reddit-options_test.go index 1853855..15ca5bf 100644 --- a/reddit/reddit-options_test.go +++ b/reddit/reddit-options_test.go @@ -3,65 +3,64 @@ package reddit import ( "net/http" "net/url" - "os" "testing" "github.com/stretchr/testify/require" ) func TestWithHTTPClient(t *testing.T) { - _, err := NewClient(nil, WithHTTPClient(nil)) + _, err := NewClient(&Credentials{}, WithHTTPClient(nil)) require.EqualError(t, err, "httpClient: cannot be nil") - _, err = NewClient(nil, WithHTTPClient(&http.Client{})) + _, err = NewClient(&Credentials{}, WithHTTPClient(&http.Client{})) require.NoError(t, err) } -func TestFromEnv(t *testing.T) { - os.Setenv("GO_REDDIT_CLIENT_ID", "id1") - defer os.Unsetenv("GO_REDDIT_CLIENT_ID") +// func TestFromEnv(t *testing.T) { +// os.Setenv("GO_REDDIT_CLIENT_ID", "id1") +// defer os.Unsetenv("GO_REDDIT_CLIENT_ID") - os.Setenv("GO_REDDIT_CLIENT_SECRET", "secret1") - defer os.Unsetenv("GO_REDDIT_CLIENT_SECRET") +// os.Setenv("GO_REDDIT_CLIENT_SECRET", "secret1") +// defer os.Unsetenv("GO_REDDIT_CLIENT_SECRET") - os.Setenv("GO_REDDIT_CLIENT_USERNAME", "username1") - defer os.Unsetenv("GO_REDDIT_CLIENT_USERNAME") +// os.Setenv("GO_REDDIT_CLIENT_USERNAME", "username1") +// defer os.Unsetenv("GO_REDDIT_CLIENT_USERNAME") - os.Setenv("GO_REDDIT_CLIENT_PASSWORD", "password1") - defer os.Unsetenv("GO_REDDIT_CLIENT_PASSWORD") +// os.Setenv("GO_REDDIT_CLIENT_PASSWORD", "password1") +// defer os.Unsetenv("GO_REDDIT_CLIENT_PASSWORD") - c, err := NewClient(nil, FromEnv) - require.NoError(t, err) +// c, err := NewClient(&Credentials{}, FromEnv) +// require.NoError(t, err) - type values struct { - id, secret, username, password string - } +// type values struct { +// id, secret, username, password string +// } - expect := values{"id1", "secret1", "username1", "password1"} - actual := values{c.ID, c.Secret, c.Username, c.Password} - require.Equal(t, expect, actual) -} +// expect := values{"id1", "secret1", "username1", "password1"} +// actual := values{c.ID, c.Secret, c.Username, c.Password} +// require.Equal(t, expect, actual) +// } func TestWithBaseURL(t *testing.T) { - c, err := NewClient(nil, WithBaseURL(":")) + c, err := NewClient(&Credentials{}, WithBaseURL(":")) urlErr, ok := err.(*url.Error) require.True(t, ok) require.Equal(t, "parse", urlErr.Op) baseURL := "http://localhost:8080" - c, err = NewClient(nil, WithBaseURL(baseURL)) + c, err = NewClient(&Credentials{}, WithBaseURL(baseURL)) require.NoError(t, err) require.Equal(t, baseURL, c.BaseURL.String()) } func TestWithTokenURL(t *testing.T) { - c, err := NewClient(nil, WithTokenURL(":")) + c, err := NewClient(&Credentials{}, WithTokenURL(":")) urlErr, ok := err.(*url.Error) require.True(t, ok) require.Equal(t, "parse", urlErr.Op) tokenURL := "http://localhost:8080/api/v1/access_token" - c, err = NewClient(nil, WithTokenURL(tokenURL)) + c, err = NewClient(&Credentials{}, WithTokenURL(tokenURL)) require.NoError(t, err) require.Equal(t, tokenURL, c.TokenURL.String()) } diff --git a/reddit/reddit-user-agent.go b/reddit/reddit-user-agent.go new file mode 100644 index 0000000..5dea93e --- /dev/null +++ b/reddit/reddit-user-agent.go @@ -0,0 +1,42 @@ +package reddit + +import "net/http" + +// cloneRequest returns a clone of the provided *http.Request. +// The clone is a shallow copy of the struct and its Header map, +// since we'll only be modify the headers. +// Per the specification of http.RoundTripper, we should not directly modify a request. +func cloneRequest(r *http.Request) *http.Request { + r2 := new(http.Request) + *r2 = *r + // deep copy of the Header + r2.Header = make(http.Header, len(r.Header)) + for k, s := range r.Header { + r2.Header[k] = append([]string(nil), s...) + } + return r2 +} + +// Sets the User-Agent header for requests. +type userAgentTransport struct { + userAgent string + Base http.RoundTripper +} + +func (t *userAgentTransport) setUserAgent(req *http.Request) *http.Request { + req2 := cloneRequest(req) + req2.Header.Set(headerUserAgent, t.userAgent) + return req2 +} + +func (t *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req2 := t.setUserAgent(req) + return t.base().RoundTrip(req2) +} + +func (t *userAgentTransport) base() http.RoundTripper { + if t.Base != nil { + return t.Base + } + return http.DefaultTransport +} diff --git a/reddit/reddit.go b/reddit/reddit.go index 9ea8397..37eb8d7 100644 --- a/reddit/reddit.go +++ b/reddit/reddit.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "io/ioutil" @@ -31,45 +32,6 @@ const ( headerUserAgent = "User-Agent" ) -// cloneRequest returns a clone of the provided *http.Request. -// The clone is a shallow copy of the struct and its Header map, -// since we'll only be modify the headers. -// Per the specification of http.RoundTripper, we should not directly modify a request. -func cloneRequest(r *http.Request) *http.Request { - r2 := new(http.Request) - *r2 = *r - // deep copy of the Header - r2.Header = make(http.Header, len(r.Header)) - for k, s := range r.Header { - r2.Header[k] = append([]string(nil), s...) - } - return r2 -} - -// Sets the User-Agent header for requests. -type userAgentTransport struct { - userAgent string - Base http.RoundTripper -} - -func (t *userAgentTransport) setUserAgent(req *http.Request) *http.Request { - req2 := cloneRequest(req) - req2.Header.Set(headerUserAgent, t.userAgent) - return req2 -} - -func (t *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) { - req2 := t.setUserAgent(req) - return t.base().RoundTrip(req2) -} - -func (t *userAgentTransport) base() http.RoundTripper { - if t.Base != nil { - return t.Base - } - return http.DefaultTransport -} - // RequestCompletionCallback defines the type of the request callback function. type RequestCompletionCallback func(*http.Request, *http.Response) @@ -152,8 +114,11 @@ func newClient() *Client { // NewClient returns a new Reddit API client. func NewClient(creds *Credentials, opts ...Opt) (*Client, error) { - client := newClient() + if creds == nil { + return nil, errors.New("must provide credentials to initialize *reddit.Client") + } + client := newClient() for _, opt := range opts { if err := opt(client); err != nil { return nil, err @@ -164,6 +129,14 @@ func NewClient(creds *Credentials, opts ...Opt) (*Client, error) { client.client = &http.Client{} } + // We need to set a custom user agent, because using the one set by the + // stdlib gives us 429 Too Many Requests responses from the Reddit API. + userAgentTransport := &userAgentTransport{ + userAgent: client.UserAgent(), + Base: client.client.Transport, + } + client.client.Transport = userAgentTransport + // todo... // Some endpoints (notably the ones to get random subreddits/posts) redirect to a // reddit.com url, which returns a 403 Forbidden for some reason, unless the url's @@ -183,15 +156,13 @@ func NewClient(creds *Credentials, opts ...Opt) (*Client, error) { } } - if creds != nil { - client.ID = creds.ID - client.Secret = creds.Secret - client.Username = creds.Username - client.Password = creds.Password + client.ID = creds.ID + client.Secret = creds.Secret + client.Username = creds.Username + client.Password = creds.Password - oauthTransport := oauthTransport(client) - client.client.Transport = oauthTransport - } + oauthTransport := oauthTransport(client) + client.client.Transport = oauthTransport return client, nil } @@ -204,7 +175,7 @@ func (c *Client) UserAgent() string { return c.userAgent } -// NewRequest creates an API request. +// NewRequest creates an API request with a JSON body. // The path is the relative URL which will be resolves to the BaseURL of the Client. // It should always be specified without a preceding slash. func (c *Client) NewRequest(method string, path string, body interface{}) (*http.Request, error) { @@ -253,7 +224,7 @@ func (c *Client) NewRequestWithForm(method string, path string, form url.Values) return req, nil } -// Response is a PlayNetwork response. This wraps the standard http.Response returned from PlayNetwork. +// Response is a Reddit response. This wraps the standard http.Response returned from Reddit. type Response struct { *http.Response } diff --git a/reddit/reddit_test.go b/reddit/reddit_test.go index b7fdb9d..7820605 100644 --- a/reddit/reddit_test.go +++ b/reddit/reddit_test.go @@ -92,17 +92,20 @@ func testClientDefaults(t *testing.T, c *Client) { } func TestNewClient(t *testing.T) { - c, err := NewClient(nil) + c, err := NewClient(&Credentials{}) require.NoError(t, err) testClientDefaults(t, c) } func TestNewClient_Error(t *testing.T) { + _, err := NewClient(nil) + require.EqualError(t, err, "must provide credentials to initialize *reddit.Client") + errorOpt := func(c *Client) error { return errors.New("foo") } - _, err := NewClient(nil, errorOpt) + _, err = NewClient(&Credentials{}, errorOpt) require.EqualError(t, err, "foo") }