Use oauth2 package's PasswordCredentialsToken method

Signed-off-by: Vartan Benohanian <vartanbeno@gmail.com>
This commit is contained in:
Vartan Benohanian 2020-08-06 17:41:20 -04:00
parent e49f4bf738
commit a9c581df34
4 changed files with 65 additions and 62 deletions

View file

@ -48,13 +48,13 @@ func cloneRequest(r *http.Request) *http.Request {
// Sets the User-Agent header for requests. // Sets the User-Agent header for requests.
type userAgentTransport struct { type userAgentTransport struct {
ua string userAgent string
Base http.RoundTripper Base http.RoundTripper
} }
func (t *userAgentTransport) setUserAgent(req *http.Request) *http.Request { func (t *userAgentTransport) setUserAgent(req *http.Request) *http.Request {
req2 := cloneRequest(req) req2 := cloneRequest(req)
req2.Header.Set(headerUserAgent, t.ua) req2.Header.Set(headerUserAgent, t.userAgent)
return req2 return req2
} }
@ -163,25 +163,20 @@ func NewClient(httpClient *http.Client, opts ...Opt) (c *Client, err error) {
} }
} }
c.userAgent = fmt.Sprintf("golang:%s:v%s (by /u/%s)", libraryName, libraryVersion, c.Username) oauthTransport := oauthTransport(c)
userAgentTransport := &userAgentTransport{ c.client.Transport = oauthTransport
ua: c.userAgent,
Base: c.client.Transport,
}
oauth2Config := oauth2Config{
id: c.ID,
secret: c.Secret,
username: c.Username,
password: c.Password,
tokenURL: c.TokenURL.String(),
userAgentTransport: userAgentTransport,
}
c.client.Transport = oauth2Transport(oauth2Config)
return return
} }
// UserAgent returns the client's user agent.
func (c *Client) UserAgent() string {
if c.userAgent == "" {
c.userAgent = fmt.Sprintf("golang:%s:v%s (by /u/%s)", libraryName, libraryVersion, c.Username)
}
return c.userAgent
}
// NewRequest creates an API request. // NewRequest creates an API request.
// The path is the relative URL which will be resolves to the BaseURL of the Client. // The path is the relative URL which will be resolves to the BaseURL of the Client.
// It should always be specified without a preceding slash. // It should always be specified without a preceding slash.

View file

@ -36,63 +36,50 @@ package reddit
import ( import (
"context" "context"
"net/http" "net/http"
"net/url"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
) )
var endpoint = oauth2.Endpoint{ type oauthTokenSource struct {
TokenURL: "https://www.reddit.com/api/v1/access_token", ctx context.Context
AuthStyle: oauth2.AuthStyleInHeader, config *oauth2.Config
username, password string
} }
type oauth2Config struct { func (s *oauthTokenSource) Token() (*oauth2.Token, error) {
id string return s.config.PasswordCredentialsToken(s.ctx, s.username, s.password)
secret string }
username string
password string
tokenURL string
func oauthTransport(client *Client) http.RoundTripper {
// We need to set a custom user agent, because using the one set by default by the // We need to set a custom user agent, because using the one set by default by the
// stdlib gives us 429 Too Many Request responses from the Reddit API. // stdlib gives us 429 Too Many Request responses from the Reddit API.
userAgentTransport *userAgentTransport userAgentTransport := &userAgentTransport{
} userAgent: client.UserAgent(),
Base: client.client.Transport,
func oauth2Transport(c oauth2Config) *oauth2.Transport {
// todo: use oauth2.Config.PasswordCredentialsToken
params := url.Values{
"grant_type": {"password"},
"username": {c.username},
"password": {c.password},
} }
cfg := clientcredentials.Config{ httpClient := &http.Client{Transport: userAgentTransport}
ClientID: c.id, ctx := context.Background()
ClientSecret: c.secret, ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
TokenURL: c.tokenURL,
config := &oauth2.Config{
ClientID: client.ID,
ClientSecret: client.Secret,
Endpoint: oauth2.Endpoint{
TokenURL: client.TokenURL.String(),
AuthStyle: oauth2.AuthStyleInHeader, AuthStyle: oauth2.AuthStyleInHeader,
EndpointParams: params, },
} }
httpClient := &http.Client{Transport: c.userAgentTransport} tokenSource := oauth2.ReuseTokenSource(nil, &oauthTokenSource{
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient) ctx: ctx,
config: config,
username: client.Username,
password: client.Password,
})
src := cfg.TokenSource(ctx) return &oauth2.Transport{
tr := &oauth2.Transport{ Source: tokenSource,
Source: src, Base: userAgentTransport,
Base: c.userAgentTransport,
}
return tr
}
// WithCredentials sets the necessary values for the client to authenticate via OAuth2.
func WithCredentials(id, secret, username, password string) Opt {
return func(c *Client) error {
c.ID = id
c.Secret = secret
c.Username = username
c.Password = password
return nil
} }
} }

View file

@ -35,6 +35,17 @@ func FromEnv(c *Client) error {
return nil return nil
} }
// WithCredentials sets the necessary values for the client to authenticate via OAuth2.
func WithCredentials(id, secret, username, password string) Opt {
return func(c *Client) error {
c.ID = id
c.Secret = secret
c.Username = username
c.Password = password
return nil
}
}
// WithBaseURL sets the base URL for the client to make requests to. // WithBaseURL sets the base URL for the client to make requests to.
func WithBaseURL(u string) Opt { func WithBaseURL(u string) Opt {
return func(c *Client) error { return func(c *Client) error {

View file

@ -32,6 +32,16 @@ func TestFromEnv(t *testing.T) {
assert.Equal(t, expect, actual) assert.Equal(t, expect, actual)
} }
func TestWithCredentials(t *testing.T) {
withCredentials := WithCredentials("id1", "secret1", "username1", "password1")
c, err := NewClient(nil, withCredentials)
assert.NoError(t, err)
assert.Equal(t, "id1", c.ID)
assert.Equal(t, "secret1", c.Secret)
assert.Equal(t, "username1", c.Username)
assert.Equal(t, "password1", c.Password)
}
func TestWithBaseURL(t *testing.T) { func TestWithBaseURL(t *testing.T) {
baseURL := "http://localhost:8080" baseURL := "http://localhost:8080"
c, err := NewClient(nil, WithBaseURL(baseURL)) c, err := NewClient(nil, WithBaseURL(baseURL))