From a9c581df3406419c4ec18f1d8f2231a06450c40b Mon Sep 17 00:00:00 2001 From: Vartan Benohanian Date: Thu, 6 Aug 2020 17:41:20 -0400 Subject: [PATCH] Use oauth2 package's PasswordCredentialsToken method Signed-off-by: Vartan Benohanian --- reddit.go | 31 ++++++++--------- reddit_oauth.go | 75 +++++++++++++++++------------------------- reddit_options.go | 11 +++++++ reddit_options_test.go | 10 ++++++ 4 files changed, 65 insertions(+), 62 deletions(-) diff --git a/reddit.go b/reddit.go index f8ff8f7..27f3a30 100644 --- a/reddit.go +++ b/reddit.go @@ -48,13 +48,13 @@ func cloneRequest(r *http.Request) *http.Request { // Sets the User-Agent header for requests. type userAgentTransport struct { - ua string - Base http.RoundTripper + userAgent string + Base http.RoundTripper } func (t *userAgentTransport) setUserAgent(req *http.Request) *http.Request { req2 := cloneRequest(req) - req2.Header.Set(headerUserAgent, t.ua) + req2.Header.Set(headerUserAgent, t.userAgent) return req2 } @@ -163,25 +163,20 @@ func NewClient(httpClient *http.Client, opts ...Opt) (c *Client, err error) { } } - c.userAgent = fmt.Sprintf("golang:%s:v%s (by /u/%s)", libraryName, libraryVersion, c.Username) - userAgentTransport := &userAgentTransport{ - ua: c.userAgent, - Base: c.client.Transport, - } - - oauth2Config := oauth2Config{ - id: c.ID, - secret: c.Secret, - username: c.Username, - password: c.Password, - tokenURL: c.TokenURL.String(), - userAgentTransport: userAgentTransport, - } - c.client.Transport = oauth2Transport(oauth2Config) + oauthTransport := oauthTransport(c) + c.client.Transport = oauthTransport return } +// UserAgent returns the client's user agent. +func (c *Client) UserAgent() string { + if c.userAgent == "" { + c.userAgent = fmt.Sprintf("golang:%s:v%s (by /u/%s)", libraryName, libraryVersion, c.Username) + } + return c.userAgent +} + // NewRequest creates an API request. // The path is the relative URL which will be resolves to the BaseURL of the Client. // It should always be specified without a preceding slash. diff --git a/reddit_oauth.go b/reddit_oauth.go index ef440fb..b18efdd 100644 --- a/reddit_oauth.go +++ b/reddit_oauth.go @@ -36,63 +36,50 @@ package reddit import ( "context" "net/http" - "net/url" "golang.org/x/oauth2" - "golang.org/x/oauth2/clientcredentials" ) -var endpoint = oauth2.Endpoint{ - TokenURL: "https://www.reddit.com/api/v1/access_token", - AuthStyle: oauth2.AuthStyleInHeader, +type oauthTokenSource struct { + ctx context.Context + config *oauth2.Config + username, password string } -type oauth2Config struct { - id string - secret string - username string - password string - tokenURL string +func (s *oauthTokenSource) Token() (*oauth2.Token, error) { + return s.config.PasswordCredentialsToken(s.ctx, s.username, s.password) +} +func oauthTransport(client *Client) http.RoundTripper { // We need to set a custom user agent, because using the one set by default by the // stdlib gives us 429 Too Many Request responses from the Reddit API. - userAgentTransport *userAgentTransport -} - -func oauth2Transport(c oauth2Config) *oauth2.Transport { - // todo: use oauth2.Config.PasswordCredentialsToken - params := url.Values{ - "grant_type": {"password"}, - "username": {c.username}, - "password": {c.password}, + userAgentTransport := &userAgentTransport{ + userAgent: client.UserAgent(), + Base: client.client.Transport, } - cfg := clientcredentials.Config{ - ClientID: c.id, - ClientSecret: c.secret, - TokenURL: c.tokenURL, - AuthStyle: oauth2.AuthStyleInHeader, - EndpointParams: params, + httpClient := &http.Client{Transport: userAgentTransport} + ctx := context.Background() + ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) + + config := &oauth2.Config{ + ClientID: client.ID, + ClientSecret: client.Secret, + Endpoint: oauth2.Endpoint{ + TokenURL: client.TokenURL.String(), + AuthStyle: oauth2.AuthStyleInHeader, + }, } - httpClient := &http.Client{Transport: c.userAgentTransport} - ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient) + tokenSource := oauth2.ReuseTokenSource(nil, &oauthTokenSource{ + ctx: ctx, + config: config, + username: client.Username, + password: client.Password, + }) - src := cfg.TokenSource(ctx) - tr := &oauth2.Transport{ - Source: src, - Base: c.userAgentTransport, - } - return tr -} - -// WithCredentials sets the necessary values for the client to authenticate via OAuth2. -func WithCredentials(id, secret, username, password string) Opt { - return func(c *Client) error { - c.ID = id - c.Secret = secret - c.Username = username - c.Password = password - return nil + return &oauth2.Transport{ + Source: tokenSource, + Base: userAgentTransport, } } diff --git a/reddit_options.go b/reddit_options.go index ad27ee2..ae6221b 100644 --- a/reddit_options.go +++ b/reddit_options.go @@ -35,6 +35,17 @@ func FromEnv(c *Client) error { return nil } +// WithCredentials sets the necessary values for the client to authenticate via OAuth2. +func WithCredentials(id, secret, username, password string) Opt { + return func(c *Client) error { + c.ID = id + c.Secret = secret + c.Username = username + c.Password = password + return nil + } +} + // WithBaseURL sets the base URL for the client to make requests to. func WithBaseURL(u string) Opt { return func(c *Client) error { diff --git a/reddit_options_test.go b/reddit_options_test.go index f3fc11d..7b2adb9 100644 --- a/reddit_options_test.go +++ b/reddit_options_test.go @@ -32,6 +32,16 @@ func TestFromEnv(t *testing.T) { assert.Equal(t, expect, actual) } +func TestWithCredentials(t *testing.T) { + withCredentials := WithCredentials("id1", "secret1", "username1", "password1") + c, err := NewClient(nil, withCredentials) + assert.NoError(t, err) + assert.Equal(t, "id1", c.ID) + assert.Equal(t, "secret1", c.Secret) + assert.Equal(t, "username1", c.Username) + assert.Equal(t, "password1", c.Password) +} + func TestWithBaseURL(t *testing.T) { baseURL := "http://localhost:8080" c, err := NewClient(nil, WithBaseURL(baseURL))