Return errors on nil credentials, move code around
Signed-off-by: Vartan Benohanian <vartanbeno@gmail.com>
This commit is contained in:
parent
c0f0cf8be6
commit
c4faa00b94
6 changed files with 110 additions and 104 deletions
|
@ -51,16 +51,8 @@ func (s *oauthTokenSource) Token() (*oauth2.Token, error) {
|
|||
}
|
||||
|
||||
func oauthTransport(client *Client) http.RoundTripper {
|
||||
// We need to set a custom user agent, because using the one set by default by the
|
||||
// stdlib gives us 429 Too Many Request responses from the Reddit API.
|
||||
userAgentTransport := &userAgentTransport{
|
||||
userAgent: client.UserAgent(),
|
||||
Base: client.client.Transport,
|
||||
}
|
||||
|
||||
httpClient := &http.Client{Transport: userAgentTransport}
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
|
||||
httpClient := &http.Client{Transport: client.client.Transport}
|
||||
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
|
||||
|
||||
config := &oauth2.Config{
|
||||
ClientID: client.ID,
|
||||
|
@ -80,6 +72,6 @@ func oauthTransport(client *Client) http.RoundTripper {
|
|||
|
||||
return &oauth2.Transport{
|
||||
Source: tokenSource,
|
||||
Base: userAgentTransport,
|
||||
Base: client.client.Transport,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
)
|
||||
|
||||
// Opt is a configuration option to initialize a client.
|
||||
|
@ -27,21 +26,21 @@ func WithHTTPClient(httpClient *http.Client) Opt {
|
|||
// GO_REDDIT_CLIENT_SECRET to set the client's secret.
|
||||
// GO_REDDIT_CLIENT_USERNAME to set the client's username.
|
||||
// GO_REDDIT_CLIENT_PASSWORD to set the client's password.
|
||||
func FromEnv(c *Client) error {
|
||||
if v := os.Getenv("GO_REDDIT_CLIENT_ID"); v != "" {
|
||||
c.ID = v
|
||||
}
|
||||
if v := os.Getenv("GO_REDDIT_CLIENT_SECRET"); v != "" {
|
||||
c.Secret = v
|
||||
}
|
||||
if v := os.Getenv("GO_REDDIT_CLIENT_USERNAME"); v != "" {
|
||||
c.Username = v
|
||||
}
|
||||
if v := os.Getenv("GO_REDDIT_CLIENT_PASSWORD"); v != "" {
|
||||
c.Password = v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// func FromEnv(c *Client) error {
|
||||
// if v := os.Getenv("GO_REDDIT_CLIENT_ID"); v != "" {
|
||||
// c.ID = v
|
||||
// }
|
||||
// if v := os.Getenv("GO_REDDIT_CLIENT_SECRET"); v != "" {
|
||||
// c.Secret = v
|
||||
// }
|
||||
// if v := os.Getenv("GO_REDDIT_CLIENT_USERNAME"); v != "" {
|
||||
// c.Username = v
|
||||
// }
|
||||
// if v := os.Getenv("GO_REDDIT_CLIENT_PASSWORD"); v != "" {
|
||||
// c.Password = v
|
||||
// }
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// WithBaseURL sets the base URL for the client to make requests to.
|
||||
func WithBaseURL(u string) Opt {
|
||||
|
|
|
@ -3,65 +3,64 @@ package reddit
|
|||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWithHTTPClient(t *testing.T) {
|
||||
_, err := NewClient(nil, WithHTTPClient(nil))
|
||||
_, err := NewClient(&Credentials{}, WithHTTPClient(nil))
|
||||
require.EqualError(t, err, "httpClient: cannot be nil")
|
||||
|
||||
_, err = NewClient(nil, WithHTTPClient(&http.Client{}))
|
||||
_, err = NewClient(&Credentials{}, WithHTTPClient(&http.Client{}))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestFromEnv(t *testing.T) {
|
||||
os.Setenv("GO_REDDIT_CLIENT_ID", "id1")
|
||||
defer os.Unsetenv("GO_REDDIT_CLIENT_ID")
|
||||
// func TestFromEnv(t *testing.T) {
|
||||
// os.Setenv("GO_REDDIT_CLIENT_ID", "id1")
|
||||
// defer os.Unsetenv("GO_REDDIT_CLIENT_ID")
|
||||
|
||||
os.Setenv("GO_REDDIT_CLIENT_SECRET", "secret1")
|
||||
defer os.Unsetenv("GO_REDDIT_CLIENT_SECRET")
|
||||
// os.Setenv("GO_REDDIT_CLIENT_SECRET", "secret1")
|
||||
// defer os.Unsetenv("GO_REDDIT_CLIENT_SECRET")
|
||||
|
||||
os.Setenv("GO_REDDIT_CLIENT_USERNAME", "username1")
|
||||
defer os.Unsetenv("GO_REDDIT_CLIENT_USERNAME")
|
||||
// os.Setenv("GO_REDDIT_CLIENT_USERNAME", "username1")
|
||||
// defer os.Unsetenv("GO_REDDIT_CLIENT_USERNAME")
|
||||
|
||||
os.Setenv("GO_REDDIT_CLIENT_PASSWORD", "password1")
|
||||
defer os.Unsetenv("GO_REDDIT_CLIENT_PASSWORD")
|
||||
// os.Setenv("GO_REDDIT_CLIENT_PASSWORD", "password1")
|
||||
// defer os.Unsetenv("GO_REDDIT_CLIENT_PASSWORD")
|
||||
|
||||
c, err := NewClient(nil, FromEnv)
|
||||
require.NoError(t, err)
|
||||
// c, err := NewClient(&Credentials{}, FromEnv)
|
||||
// require.NoError(t, err)
|
||||
|
||||
type values struct {
|
||||
id, secret, username, password string
|
||||
}
|
||||
// type values struct {
|
||||
// id, secret, username, password string
|
||||
// }
|
||||
|
||||
expect := values{"id1", "secret1", "username1", "password1"}
|
||||
actual := values{c.ID, c.Secret, c.Username, c.Password}
|
||||
require.Equal(t, expect, actual)
|
||||
}
|
||||
// expect := values{"id1", "secret1", "username1", "password1"}
|
||||
// actual := values{c.ID, c.Secret, c.Username, c.Password}
|
||||
// require.Equal(t, expect, actual)
|
||||
// }
|
||||
|
||||
func TestWithBaseURL(t *testing.T) {
|
||||
c, err := NewClient(nil, WithBaseURL(":"))
|
||||
c, err := NewClient(&Credentials{}, WithBaseURL(":"))
|
||||
urlErr, ok := err.(*url.Error)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "parse", urlErr.Op)
|
||||
|
||||
baseURL := "http://localhost:8080"
|
||||
c, err = NewClient(nil, WithBaseURL(baseURL))
|
||||
c, err = NewClient(&Credentials{}, WithBaseURL(baseURL))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, baseURL, c.BaseURL.String())
|
||||
}
|
||||
|
||||
func TestWithTokenURL(t *testing.T) {
|
||||
c, err := NewClient(nil, WithTokenURL(":"))
|
||||
c, err := NewClient(&Credentials{}, WithTokenURL(":"))
|
||||
urlErr, ok := err.(*url.Error)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "parse", urlErr.Op)
|
||||
|
||||
tokenURL := "http://localhost:8080/api/v1/access_token"
|
||||
c, err = NewClient(nil, WithTokenURL(tokenURL))
|
||||
c, err = NewClient(&Credentials{}, WithTokenURL(tokenURL))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tokenURL, c.TokenURL.String())
|
||||
}
|
||||
|
|
42
reddit/reddit-user-agent.go
Normal file
42
reddit/reddit-user-agent.go
Normal file
|
@ -0,0 +1,42 @@
|
|||
package reddit
|
||||
|
||||
import "net/http"
|
||||
|
||||
// cloneRequest returns a clone of the provided *http.Request.
|
||||
// The clone is a shallow copy of the struct and its Header map,
|
||||
// since we'll only be modify the headers.
|
||||
// Per the specification of http.RoundTripper, we should not directly modify a request.
|
||||
func cloneRequest(r *http.Request) *http.Request {
|
||||
r2 := new(http.Request)
|
||||
*r2 = *r
|
||||
// deep copy of the Header
|
||||
r2.Header = make(http.Header, len(r.Header))
|
||||
for k, s := range r.Header {
|
||||
r2.Header[k] = append([]string(nil), s...)
|
||||
}
|
||||
return r2
|
||||
}
|
||||
|
||||
// Sets the User-Agent header for requests.
|
||||
type userAgentTransport struct {
|
||||
userAgent string
|
||||
Base http.RoundTripper
|
||||
}
|
||||
|
||||
func (t *userAgentTransport) setUserAgent(req *http.Request) *http.Request {
|
||||
req2 := cloneRequest(req)
|
||||
req2.Header.Set(headerUserAgent, t.userAgent)
|
||||
return req2
|
||||
}
|
||||
|
||||
func (t *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
req2 := t.setUserAgent(req)
|
||||
return t.base().RoundTrip(req2)
|
||||
}
|
||||
|
||||
func (t *userAgentTransport) base() http.RoundTripper {
|
||||
if t.Base != nil {
|
||||
return t.Base
|
||||
}
|
||||
return http.DefaultTransport
|
||||
}
|
|
@ -4,6 +4,7 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
|
@ -31,45 +32,6 @@ const (
|
|||
headerUserAgent = "User-Agent"
|
||||
)
|
||||
|
||||
// cloneRequest returns a clone of the provided *http.Request.
|
||||
// The clone is a shallow copy of the struct and its Header map,
|
||||
// since we'll only be modify the headers.
|
||||
// Per the specification of http.RoundTripper, we should not directly modify a request.
|
||||
func cloneRequest(r *http.Request) *http.Request {
|
||||
r2 := new(http.Request)
|
||||
*r2 = *r
|
||||
// deep copy of the Header
|
||||
r2.Header = make(http.Header, len(r.Header))
|
||||
for k, s := range r.Header {
|
||||
r2.Header[k] = append([]string(nil), s...)
|
||||
}
|
||||
return r2
|
||||
}
|
||||
|
||||
// Sets the User-Agent header for requests.
|
||||
type userAgentTransport struct {
|
||||
userAgent string
|
||||
Base http.RoundTripper
|
||||
}
|
||||
|
||||
func (t *userAgentTransport) setUserAgent(req *http.Request) *http.Request {
|
||||
req2 := cloneRequest(req)
|
||||
req2.Header.Set(headerUserAgent, t.userAgent)
|
||||
return req2
|
||||
}
|
||||
|
||||
func (t *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
req2 := t.setUserAgent(req)
|
||||
return t.base().RoundTrip(req2)
|
||||
}
|
||||
|
||||
func (t *userAgentTransport) base() http.RoundTripper {
|
||||
if t.Base != nil {
|
||||
return t.Base
|
||||
}
|
||||
return http.DefaultTransport
|
||||
}
|
||||
|
||||
// RequestCompletionCallback defines the type of the request callback function.
|
||||
type RequestCompletionCallback func(*http.Request, *http.Response)
|
||||
|
||||
|
@ -152,8 +114,11 @@ func newClient() *Client {
|
|||
|
||||
// NewClient returns a new Reddit API client.
|
||||
func NewClient(creds *Credentials, opts ...Opt) (*Client, error) {
|
||||
client := newClient()
|
||||
if creds == nil {
|
||||
return nil, errors.New("must provide credentials to initialize *reddit.Client")
|
||||
}
|
||||
|
||||
client := newClient()
|
||||
for _, opt := range opts {
|
||||
if err := opt(client); err != nil {
|
||||
return nil, err
|
||||
|
@ -164,6 +129,14 @@ func NewClient(creds *Credentials, opts ...Opt) (*Client, error) {
|
|||
client.client = &http.Client{}
|
||||
}
|
||||
|
||||
// We need to set a custom user agent, because using the one set by the
|
||||
// stdlib gives us 429 Too Many Requests responses from the Reddit API.
|
||||
userAgentTransport := &userAgentTransport{
|
||||
userAgent: client.UserAgent(),
|
||||
Base: client.client.Transport,
|
||||
}
|
||||
client.client.Transport = userAgentTransport
|
||||
|
||||
// todo...
|
||||
// Some endpoints (notably the ones to get random subreddits/posts) redirect to a
|
||||
// reddit.com url, which returns a 403 Forbidden for some reason, unless the url's
|
||||
|
@ -183,15 +156,13 @@ func NewClient(creds *Credentials, opts ...Opt) (*Client, error) {
|
|||
}
|
||||
}
|
||||
|
||||
if creds != nil {
|
||||
client.ID = creds.ID
|
||||
client.Secret = creds.Secret
|
||||
client.Username = creds.Username
|
||||
client.Password = creds.Password
|
||||
client.ID = creds.ID
|
||||
client.Secret = creds.Secret
|
||||
client.Username = creds.Username
|
||||
client.Password = creds.Password
|
||||
|
||||
oauthTransport := oauthTransport(client)
|
||||
client.client.Transport = oauthTransport
|
||||
}
|
||||
oauthTransport := oauthTransport(client)
|
||||
client.client.Transport = oauthTransport
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
@ -204,7 +175,7 @@ func (c *Client) UserAgent() string {
|
|||
return c.userAgent
|
||||
}
|
||||
|
||||
// NewRequest creates an API request.
|
||||
// NewRequest creates an API request with a JSON body.
|
||||
// The path is the relative URL which will be resolves to the BaseURL of the Client.
|
||||
// It should always be specified without a preceding slash.
|
||||
func (c *Client) NewRequest(method string, path string, body interface{}) (*http.Request, error) {
|
||||
|
@ -253,7 +224,7 @@ func (c *Client) NewRequestWithForm(method string, path string, form url.Values)
|
|||
return req, nil
|
||||
}
|
||||
|
||||
// Response is a PlayNetwork response. This wraps the standard http.Response returned from PlayNetwork.
|
||||
// Response is a Reddit response. This wraps the standard http.Response returned from Reddit.
|
||||
type Response struct {
|
||||
*http.Response
|
||||
}
|
||||
|
|
|
@ -92,17 +92,20 @@ func testClientDefaults(t *testing.T, c *Client) {
|
|||
}
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
c, err := NewClient(nil)
|
||||
c, err := NewClient(&Credentials{})
|
||||
require.NoError(t, err)
|
||||
testClientDefaults(t, c)
|
||||
}
|
||||
|
||||
func TestNewClient_Error(t *testing.T) {
|
||||
_, err := NewClient(nil)
|
||||
require.EqualError(t, err, "must provide credentials to initialize *reddit.Client")
|
||||
|
||||
errorOpt := func(c *Client) error {
|
||||
return errors.New("foo")
|
||||
}
|
||||
|
||||
_, err := NewClient(nil, errorOpt)
|
||||
_, err = NewClient(&Credentials{}, errorOpt)
|
||||
require.EqualError(t, err, "foo")
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue