Return errors on nil credentials, move code around

Signed-off-by: Vartan Benohanian <vartanbeno@gmail.com>
This commit is contained in:
Vartan Benohanian 2020-08-27 00:29:53 -04:00
parent c0f0cf8be6
commit c4faa00b94
6 changed files with 110 additions and 104 deletions

View File

@ -51,16 +51,8 @@ func (s *oauthTokenSource) Token() (*oauth2.Token, error) {
} }
func oauthTransport(client *Client) http.RoundTripper { func oauthTransport(client *Client) http.RoundTripper {
// We need to set a custom user agent, because using the one set by default by the httpClient := &http.Client{Transport: client.client.Transport}
// stdlib gives us 429 Too Many Request responses from the Reddit API. ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
userAgentTransport := &userAgentTransport{
userAgent: client.UserAgent(),
Base: client.client.Transport,
}
httpClient := &http.Client{Transport: userAgentTransport}
ctx := context.Background()
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
config := &oauth2.Config{ config := &oauth2.Config{
ClientID: client.ID, ClientID: client.ID,
@ -80,6 +72,6 @@ func oauthTransport(client *Client) http.RoundTripper {
return &oauth2.Transport{ return &oauth2.Transport{
Source: tokenSource, Source: tokenSource,
Base: userAgentTransport, Base: client.client.Transport,
} }
} }

View File

@ -4,7 +4,6 @@ import (
"errors" "errors"
"net/http" "net/http"
"net/url" "net/url"
"os"
) )
// Opt is a configuration option to initialize a client. // Opt is a configuration option to initialize a client.
@ -27,21 +26,21 @@ func WithHTTPClient(httpClient *http.Client) Opt {
// GO_REDDIT_CLIENT_SECRET to set the client's secret. // GO_REDDIT_CLIENT_SECRET to set the client's secret.
// GO_REDDIT_CLIENT_USERNAME to set the client's username. // GO_REDDIT_CLIENT_USERNAME to set the client's username.
// GO_REDDIT_CLIENT_PASSWORD to set the client's password. // GO_REDDIT_CLIENT_PASSWORD to set the client's password.
func FromEnv(c *Client) error { // func FromEnv(c *Client) error {
if v := os.Getenv("GO_REDDIT_CLIENT_ID"); v != "" { // if v := os.Getenv("GO_REDDIT_CLIENT_ID"); v != "" {
c.ID = v // c.ID = v
} // }
if v := os.Getenv("GO_REDDIT_CLIENT_SECRET"); v != "" { // if v := os.Getenv("GO_REDDIT_CLIENT_SECRET"); v != "" {
c.Secret = v // c.Secret = v
} // }
if v := os.Getenv("GO_REDDIT_CLIENT_USERNAME"); v != "" { // if v := os.Getenv("GO_REDDIT_CLIENT_USERNAME"); v != "" {
c.Username = v // c.Username = v
} // }
if v := os.Getenv("GO_REDDIT_CLIENT_PASSWORD"); v != "" { // if v := os.Getenv("GO_REDDIT_CLIENT_PASSWORD"); v != "" {
c.Password = v // c.Password = v
} // }
return nil // return nil
} // }
// WithBaseURL sets the base URL for the client to make requests to. // WithBaseURL sets the base URL for the client to make requests to.
func WithBaseURL(u string) Opt { func WithBaseURL(u string) Opt {

View File

@ -3,65 +3,64 @@ package reddit
import ( import (
"net/http" "net/http"
"net/url" "net/url"
"os"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestWithHTTPClient(t *testing.T) { func TestWithHTTPClient(t *testing.T) {
_, err := NewClient(nil, WithHTTPClient(nil)) _, err := NewClient(&Credentials{}, WithHTTPClient(nil))
require.EqualError(t, err, "httpClient: cannot be nil") require.EqualError(t, err, "httpClient: cannot be nil")
_, err = NewClient(nil, WithHTTPClient(&http.Client{})) _, err = NewClient(&Credentials{}, WithHTTPClient(&http.Client{}))
require.NoError(t, err) require.NoError(t, err)
} }
func TestFromEnv(t *testing.T) { // func TestFromEnv(t *testing.T) {
os.Setenv("GO_REDDIT_CLIENT_ID", "id1") // os.Setenv("GO_REDDIT_CLIENT_ID", "id1")
defer os.Unsetenv("GO_REDDIT_CLIENT_ID") // defer os.Unsetenv("GO_REDDIT_CLIENT_ID")
os.Setenv("GO_REDDIT_CLIENT_SECRET", "secret1") // os.Setenv("GO_REDDIT_CLIENT_SECRET", "secret1")
defer os.Unsetenv("GO_REDDIT_CLIENT_SECRET") // defer os.Unsetenv("GO_REDDIT_CLIENT_SECRET")
os.Setenv("GO_REDDIT_CLIENT_USERNAME", "username1") // os.Setenv("GO_REDDIT_CLIENT_USERNAME", "username1")
defer os.Unsetenv("GO_REDDIT_CLIENT_USERNAME") // defer os.Unsetenv("GO_REDDIT_CLIENT_USERNAME")
os.Setenv("GO_REDDIT_CLIENT_PASSWORD", "password1") // os.Setenv("GO_REDDIT_CLIENT_PASSWORD", "password1")
defer os.Unsetenv("GO_REDDIT_CLIENT_PASSWORD") // defer os.Unsetenv("GO_REDDIT_CLIENT_PASSWORD")
c, err := NewClient(nil, FromEnv) // c, err := NewClient(&Credentials{}, FromEnv)
require.NoError(t, err) // require.NoError(t, err)
type values struct { // type values struct {
id, secret, username, password string // id, secret, username, password string
} // }
expect := values{"id1", "secret1", "username1", "password1"} // expect := values{"id1", "secret1", "username1", "password1"}
actual := values{c.ID, c.Secret, c.Username, c.Password} // actual := values{c.ID, c.Secret, c.Username, c.Password}
require.Equal(t, expect, actual) // require.Equal(t, expect, actual)
} // }
func TestWithBaseURL(t *testing.T) { func TestWithBaseURL(t *testing.T) {
c, err := NewClient(nil, WithBaseURL(":")) c, err := NewClient(&Credentials{}, WithBaseURL(":"))
urlErr, ok := err.(*url.Error) urlErr, ok := err.(*url.Error)
require.True(t, ok) require.True(t, ok)
require.Equal(t, "parse", urlErr.Op) require.Equal(t, "parse", urlErr.Op)
baseURL := "http://localhost:8080" baseURL := "http://localhost:8080"
c, err = NewClient(nil, WithBaseURL(baseURL)) c, err = NewClient(&Credentials{}, WithBaseURL(baseURL))
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, baseURL, c.BaseURL.String()) require.Equal(t, baseURL, c.BaseURL.String())
} }
func TestWithTokenURL(t *testing.T) { func TestWithTokenURL(t *testing.T) {
c, err := NewClient(nil, WithTokenURL(":")) c, err := NewClient(&Credentials{}, WithTokenURL(":"))
urlErr, ok := err.(*url.Error) urlErr, ok := err.(*url.Error)
require.True(t, ok) require.True(t, ok)
require.Equal(t, "parse", urlErr.Op) require.Equal(t, "parse", urlErr.Op)
tokenURL := "http://localhost:8080/api/v1/access_token" tokenURL := "http://localhost:8080/api/v1/access_token"
c, err = NewClient(nil, WithTokenURL(tokenURL)) c, err = NewClient(&Credentials{}, WithTokenURL(tokenURL))
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, tokenURL, c.TokenURL.String()) require.Equal(t, tokenURL, c.TokenURL.String())
} }

View File

@ -0,0 +1,42 @@
package reddit
import "net/http"
// cloneRequest returns a clone of the provided *http.Request.
// The clone is a shallow copy of the struct and its Header map,
// since we'll only be modify the headers.
// Per the specification of http.RoundTripper, we should not directly modify a request.
func cloneRequest(r *http.Request) *http.Request {
r2 := new(http.Request)
*r2 = *r
// deep copy of the Header
r2.Header = make(http.Header, len(r.Header))
for k, s := range r.Header {
r2.Header[k] = append([]string(nil), s...)
}
return r2
}
// Sets the User-Agent header for requests.
type userAgentTransport struct {
userAgent string
Base http.RoundTripper
}
func (t *userAgentTransport) setUserAgent(req *http.Request) *http.Request {
req2 := cloneRequest(req)
req2.Header.Set(headerUserAgent, t.userAgent)
return req2
}
func (t *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req2 := t.setUserAgent(req)
return t.base().RoundTrip(req2)
}
func (t *userAgentTransport) base() http.RoundTripper {
if t.Base != nil {
return t.Base
}
return http.DefaultTransport
}

View File

@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -31,45 +32,6 @@ const (
headerUserAgent = "User-Agent" headerUserAgent = "User-Agent"
) )
// cloneRequest returns a clone of the provided *http.Request.
// The clone is a shallow copy of the struct and its Header map,
// since we'll only be modify the headers.
// Per the specification of http.RoundTripper, we should not directly modify a request.
func cloneRequest(r *http.Request) *http.Request {
r2 := new(http.Request)
*r2 = *r
// deep copy of the Header
r2.Header = make(http.Header, len(r.Header))
for k, s := range r.Header {
r2.Header[k] = append([]string(nil), s...)
}
return r2
}
// Sets the User-Agent header for requests.
type userAgentTransport struct {
userAgent string
Base http.RoundTripper
}
func (t *userAgentTransport) setUserAgent(req *http.Request) *http.Request {
req2 := cloneRequest(req)
req2.Header.Set(headerUserAgent, t.userAgent)
return req2
}
func (t *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req2 := t.setUserAgent(req)
return t.base().RoundTrip(req2)
}
func (t *userAgentTransport) base() http.RoundTripper {
if t.Base != nil {
return t.Base
}
return http.DefaultTransport
}
// RequestCompletionCallback defines the type of the request callback function. // RequestCompletionCallback defines the type of the request callback function.
type RequestCompletionCallback func(*http.Request, *http.Response) type RequestCompletionCallback func(*http.Request, *http.Response)
@ -152,8 +114,11 @@ func newClient() *Client {
// NewClient returns a new Reddit API client. // NewClient returns a new Reddit API client.
func NewClient(creds *Credentials, opts ...Opt) (*Client, error) { func NewClient(creds *Credentials, opts ...Opt) (*Client, error) {
client := newClient() if creds == nil {
return nil, errors.New("must provide credentials to initialize *reddit.Client")
}
client := newClient()
for _, opt := range opts { for _, opt := range opts {
if err := opt(client); err != nil { if err := opt(client); err != nil {
return nil, err return nil, err
@ -164,6 +129,14 @@ func NewClient(creds *Credentials, opts ...Opt) (*Client, error) {
client.client = &http.Client{} client.client = &http.Client{}
} }
// We need to set a custom user agent, because using the one set by the
// stdlib gives us 429 Too Many Requests responses from the Reddit API.
userAgentTransport := &userAgentTransport{
userAgent: client.UserAgent(),
Base: client.client.Transport,
}
client.client.Transport = userAgentTransport
// todo... // todo...
// Some endpoints (notably the ones to get random subreddits/posts) redirect to a // Some endpoints (notably the ones to get random subreddits/posts) redirect to a
// reddit.com url, which returns a 403 Forbidden for some reason, unless the url's // reddit.com url, which returns a 403 Forbidden for some reason, unless the url's
@ -183,15 +156,13 @@ func NewClient(creds *Credentials, opts ...Opt) (*Client, error) {
} }
} }
if creds != nil { client.ID = creds.ID
client.ID = creds.ID client.Secret = creds.Secret
client.Secret = creds.Secret client.Username = creds.Username
client.Username = creds.Username client.Password = creds.Password
client.Password = creds.Password
oauthTransport := oauthTransport(client) oauthTransport := oauthTransport(client)
client.client.Transport = oauthTransport client.client.Transport = oauthTransport
}
return client, nil return client, nil
} }
@ -204,7 +175,7 @@ func (c *Client) UserAgent() string {
return c.userAgent return c.userAgent
} }
// NewRequest creates an API request. // NewRequest creates an API request with a JSON body.
// The path is the relative URL which will be resolves to the BaseURL of the Client. // The path is the relative URL which will be resolves to the BaseURL of the Client.
// It should always be specified without a preceding slash. // It should always be specified without a preceding slash.
func (c *Client) NewRequest(method string, path string, body interface{}) (*http.Request, error) { func (c *Client) NewRequest(method string, path string, body interface{}) (*http.Request, error) {
@ -253,7 +224,7 @@ func (c *Client) NewRequestWithForm(method string, path string, form url.Values)
return req, nil return req, nil
} }
// Response is a PlayNetwork response. This wraps the standard http.Response returned from PlayNetwork. // Response is a Reddit response. This wraps the standard http.Response returned from Reddit.
type Response struct { type Response struct {
*http.Response *http.Response
} }

View File

@ -92,17 +92,20 @@ func testClientDefaults(t *testing.T, c *Client) {
} }
func TestNewClient(t *testing.T) { func TestNewClient(t *testing.T) {
c, err := NewClient(nil) c, err := NewClient(&Credentials{})
require.NoError(t, err) require.NoError(t, err)
testClientDefaults(t, c) testClientDefaults(t, c)
} }
func TestNewClient_Error(t *testing.T) { func TestNewClient_Error(t *testing.T) {
_, err := NewClient(nil)
require.EqualError(t, err, "must provide credentials to initialize *reddit.Client")
errorOpt := func(c *Client) error { errorOpt := func(c *Client) error {
return errors.New("foo") return errors.New("foo")
} }
_, err := NewClient(nil, errorOpt) _, err = NewClient(&Credentials{}, errorOpt)
require.EqualError(t, err, "foo") require.EqualError(t, err, "foo")
} }