Return errors on nil credentials, move code around

Signed-off-by: Vartan Benohanian <vartanbeno@gmail.com>
This commit is contained in:
Vartan Benohanian 2020-08-27 00:29:53 -04:00
parent c0f0cf8be6
commit c4faa00b94
6 changed files with 110 additions and 104 deletions

View File

@ -51,16 +51,8 @@ func (s *oauthTokenSource) Token() (*oauth2.Token, error) {
}
func oauthTransport(client *Client) http.RoundTripper {
// We need to set a custom user agent, because using the one set by default by the
// stdlib gives us 429 Too Many Request responses from the Reddit API.
userAgentTransport := &userAgentTransport{
userAgent: client.UserAgent(),
Base: client.client.Transport,
}
httpClient := &http.Client{Transport: userAgentTransport}
ctx := context.Background()
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
httpClient := &http.Client{Transport: client.client.Transport}
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
config := &oauth2.Config{
ClientID: client.ID,
@ -80,6 +72,6 @@ func oauthTransport(client *Client) http.RoundTripper {
return &oauth2.Transport{
Source: tokenSource,
Base: userAgentTransport,
Base: client.client.Transport,
}
}

View File

@ -4,7 +4,6 @@ import (
"errors"
"net/http"
"net/url"
"os"
)
// Opt is a configuration option to initialize a client.
@ -27,21 +26,21 @@ func WithHTTPClient(httpClient *http.Client) Opt {
// GO_REDDIT_CLIENT_SECRET to set the client's secret.
// GO_REDDIT_CLIENT_USERNAME to set the client's username.
// GO_REDDIT_CLIENT_PASSWORD to set the client's password.
func FromEnv(c *Client) error {
if v := os.Getenv("GO_REDDIT_CLIENT_ID"); v != "" {
c.ID = v
}
if v := os.Getenv("GO_REDDIT_CLIENT_SECRET"); v != "" {
c.Secret = v
}
if v := os.Getenv("GO_REDDIT_CLIENT_USERNAME"); v != "" {
c.Username = v
}
if v := os.Getenv("GO_REDDIT_CLIENT_PASSWORD"); v != "" {
c.Password = v
}
return nil
}
// func FromEnv(c *Client) error {
// if v := os.Getenv("GO_REDDIT_CLIENT_ID"); v != "" {
// c.ID = v
// }
// if v := os.Getenv("GO_REDDIT_CLIENT_SECRET"); v != "" {
// c.Secret = v
// }
// if v := os.Getenv("GO_REDDIT_CLIENT_USERNAME"); v != "" {
// c.Username = v
// }
// if v := os.Getenv("GO_REDDIT_CLIENT_PASSWORD"); v != "" {
// c.Password = v
// }
// return nil
// }
// WithBaseURL sets the base URL for the client to make requests to.
func WithBaseURL(u string) Opt {

View File

@ -3,65 +3,64 @@ package reddit
import (
"net/http"
"net/url"
"os"
"testing"
"github.com/stretchr/testify/require"
)
func TestWithHTTPClient(t *testing.T) {
_, err := NewClient(nil, WithHTTPClient(nil))
_, err := NewClient(&Credentials{}, WithHTTPClient(nil))
require.EqualError(t, err, "httpClient: cannot be nil")
_, err = NewClient(nil, WithHTTPClient(&http.Client{}))
_, err = NewClient(&Credentials{}, WithHTTPClient(&http.Client{}))
require.NoError(t, err)
}
func TestFromEnv(t *testing.T) {
os.Setenv("GO_REDDIT_CLIENT_ID", "id1")
defer os.Unsetenv("GO_REDDIT_CLIENT_ID")
// func TestFromEnv(t *testing.T) {
// os.Setenv("GO_REDDIT_CLIENT_ID", "id1")
// defer os.Unsetenv("GO_REDDIT_CLIENT_ID")
os.Setenv("GO_REDDIT_CLIENT_SECRET", "secret1")
defer os.Unsetenv("GO_REDDIT_CLIENT_SECRET")
// os.Setenv("GO_REDDIT_CLIENT_SECRET", "secret1")
// defer os.Unsetenv("GO_REDDIT_CLIENT_SECRET")
os.Setenv("GO_REDDIT_CLIENT_USERNAME", "username1")
defer os.Unsetenv("GO_REDDIT_CLIENT_USERNAME")
// os.Setenv("GO_REDDIT_CLIENT_USERNAME", "username1")
// defer os.Unsetenv("GO_REDDIT_CLIENT_USERNAME")
os.Setenv("GO_REDDIT_CLIENT_PASSWORD", "password1")
defer os.Unsetenv("GO_REDDIT_CLIENT_PASSWORD")
// os.Setenv("GO_REDDIT_CLIENT_PASSWORD", "password1")
// defer os.Unsetenv("GO_REDDIT_CLIENT_PASSWORD")
c, err := NewClient(nil, FromEnv)
require.NoError(t, err)
// c, err := NewClient(&Credentials{}, FromEnv)
// require.NoError(t, err)
type values struct {
id, secret, username, password string
}
// type values struct {
// id, secret, username, password string
// }
expect := values{"id1", "secret1", "username1", "password1"}
actual := values{c.ID, c.Secret, c.Username, c.Password}
require.Equal(t, expect, actual)
}
// expect := values{"id1", "secret1", "username1", "password1"}
// actual := values{c.ID, c.Secret, c.Username, c.Password}
// require.Equal(t, expect, actual)
// }
func TestWithBaseURL(t *testing.T) {
c, err := NewClient(nil, WithBaseURL(":"))
c, err := NewClient(&Credentials{}, WithBaseURL(":"))
urlErr, ok := err.(*url.Error)
require.True(t, ok)
require.Equal(t, "parse", urlErr.Op)
baseURL := "http://localhost:8080"
c, err = NewClient(nil, WithBaseURL(baseURL))
c, err = NewClient(&Credentials{}, WithBaseURL(baseURL))
require.NoError(t, err)
require.Equal(t, baseURL, c.BaseURL.String())
}
func TestWithTokenURL(t *testing.T) {
c, err := NewClient(nil, WithTokenURL(":"))
c, err := NewClient(&Credentials{}, WithTokenURL(":"))
urlErr, ok := err.(*url.Error)
require.True(t, ok)
require.Equal(t, "parse", urlErr.Op)
tokenURL := "http://localhost:8080/api/v1/access_token"
c, err = NewClient(nil, WithTokenURL(tokenURL))
c, err = NewClient(&Credentials{}, WithTokenURL(tokenURL))
require.NoError(t, err)
require.Equal(t, tokenURL, c.TokenURL.String())
}

View File

@ -0,0 +1,42 @@
package reddit
import "net/http"
// cloneRequest returns a clone of the provided *http.Request.
// The clone is a shallow copy of the struct and its Header map,
// since we'll only be modify the headers.
// Per the specification of http.RoundTripper, we should not directly modify a request.
func cloneRequest(r *http.Request) *http.Request {
r2 := new(http.Request)
*r2 = *r
// deep copy of the Header
r2.Header = make(http.Header, len(r.Header))
for k, s := range r.Header {
r2.Header[k] = append([]string(nil), s...)
}
return r2
}
// Sets the User-Agent header for requests.
type userAgentTransport struct {
userAgent string
Base http.RoundTripper
}
func (t *userAgentTransport) setUserAgent(req *http.Request) *http.Request {
req2 := cloneRequest(req)
req2.Header.Set(headerUserAgent, t.userAgent)
return req2
}
func (t *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req2 := t.setUserAgent(req)
return t.base().RoundTrip(req2)
}
func (t *userAgentTransport) base() http.RoundTripper {
if t.Base != nil {
return t.Base
}
return http.DefaultTransport
}

View File

@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
@ -31,45 +32,6 @@ const (
headerUserAgent = "User-Agent"
)
// cloneRequest returns a clone of the provided *http.Request.
// The clone is a shallow copy of the struct and its Header map,
// since we'll only be modify the headers.
// Per the specification of http.RoundTripper, we should not directly modify a request.
func cloneRequest(r *http.Request) *http.Request {
r2 := new(http.Request)
*r2 = *r
// deep copy of the Header
r2.Header = make(http.Header, len(r.Header))
for k, s := range r.Header {
r2.Header[k] = append([]string(nil), s...)
}
return r2
}
// Sets the User-Agent header for requests.
type userAgentTransport struct {
userAgent string
Base http.RoundTripper
}
func (t *userAgentTransport) setUserAgent(req *http.Request) *http.Request {
req2 := cloneRequest(req)
req2.Header.Set(headerUserAgent, t.userAgent)
return req2
}
func (t *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req2 := t.setUserAgent(req)
return t.base().RoundTrip(req2)
}
func (t *userAgentTransport) base() http.RoundTripper {
if t.Base != nil {
return t.Base
}
return http.DefaultTransport
}
// RequestCompletionCallback defines the type of the request callback function.
type RequestCompletionCallback func(*http.Request, *http.Response)
@ -152,8 +114,11 @@ func newClient() *Client {
// NewClient returns a new Reddit API client.
func NewClient(creds *Credentials, opts ...Opt) (*Client, error) {
client := newClient()
if creds == nil {
return nil, errors.New("must provide credentials to initialize *reddit.Client")
}
client := newClient()
for _, opt := range opts {
if err := opt(client); err != nil {
return nil, err
@ -164,6 +129,14 @@ func NewClient(creds *Credentials, opts ...Opt) (*Client, error) {
client.client = &http.Client{}
}
// We need to set a custom user agent, because using the one set by the
// stdlib gives us 429 Too Many Requests responses from the Reddit API.
userAgentTransport := &userAgentTransport{
userAgent: client.UserAgent(),
Base: client.client.Transport,
}
client.client.Transport = userAgentTransport
// todo...
// Some endpoints (notably the ones to get random subreddits/posts) redirect to a
// reddit.com url, which returns a 403 Forbidden for some reason, unless the url's
@ -183,15 +156,13 @@ func NewClient(creds *Credentials, opts ...Opt) (*Client, error) {
}
}
if creds != nil {
client.ID = creds.ID
client.Secret = creds.Secret
client.Username = creds.Username
client.Password = creds.Password
client.ID = creds.ID
client.Secret = creds.Secret
client.Username = creds.Username
client.Password = creds.Password
oauthTransport := oauthTransport(client)
client.client.Transport = oauthTransport
}
oauthTransport := oauthTransport(client)
client.client.Transport = oauthTransport
return client, nil
}
@ -204,7 +175,7 @@ func (c *Client) UserAgent() string {
return c.userAgent
}
// NewRequest creates an API request.
// NewRequest creates an API request with a JSON body.
// The path is the relative URL which will be resolves to the BaseURL of the Client.
// It should always be specified without a preceding slash.
func (c *Client) NewRequest(method string, path string, body interface{}) (*http.Request, error) {
@ -253,7 +224,7 @@ func (c *Client) NewRequestWithForm(method string, path string, form url.Values)
return req, nil
}
// Response is a PlayNetwork response. This wraps the standard http.Response returned from PlayNetwork.
// Response is a Reddit response. This wraps the standard http.Response returned from Reddit.
type Response struct {
*http.Response
}

View File

@ -92,17 +92,20 @@ func testClientDefaults(t *testing.T, c *Client) {
}
func TestNewClient(t *testing.T) {
c, err := NewClient(nil)
c, err := NewClient(&Credentials{})
require.NoError(t, err)
testClientDefaults(t, c)
}
func TestNewClient_Error(t *testing.T) {
_, err := NewClient(nil)
require.EqualError(t, err, "must provide credentials to initialize *reddit.Client")
errorOpt := func(c *Client) error {
return errors.New("foo")
}
_, err := NewClient(nil, errorOpt)
_, err = NewClient(&Credentials{}, errorOpt)
require.EqualError(t, err, "foo")
}