Use oauth2 package's PasswordCredentialsToken method

Signed-off-by: Vartan Benohanian <vartanbeno@gmail.com>
This commit is contained in:
Vartan Benohanian 2020-08-06 17:41:20 -04:00
parent e49f4bf738
commit a9c581df34
4 changed files with 65 additions and 62 deletions

View file

@ -48,13 +48,13 @@ func cloneRequest(r *http.Request) *http.Request {
// Sets the User-Agent header for requests.
type userAgentTransport struct {
ua string
Base http.RoundTripper
userAgent string
Base http.RoundTripper
}
func (t *userAgentTransport) setUserAgent(req *http.Request) *http.Request {
req2 := cloneRequest(req)
req2.Header.Set(headerUserAgent, t.ua)
req2.Header.Set(headerUserAgent, t.userAgent)
return req2
}
@ -163,25 +163,20 @@ func NewClient(httpClient *http.Client, opts ...Opt) (c *Client, err error) {
}
}
c.userAgent = fmt.Sprintf("golang:%s:v%s (by /u/%s)", libraryName, libraryVersion, c.Username)
userAgentTransport := &userAgentTransport{
ua: c.userAgent,
Base: c.client.Transport,
}
oauth2Config := oauth2Config{
id: c.ID,
secret: c.Secret,
username: c.Username,
password: c.Password,
tokenURL: c.TokenURL.String(),
userAgentTransport: userAgentTransport,
}
c.client.Transport = oauth2Transport(oauth2Config)
oauthTransport := oauthTransport(c)
c.client.Transport = oauthTransport
return
}
// UserAgent returns the client's user agent.
func (c *Client) UserAgent() string {
if c.userAgent == "" {
c.userAgent = fmt.Sprintf("golang:%s:v%s (by /u/%s)", libraryName, libraryVersion, c.Username)
}
return c.userAgent
}
// NewRequest creates an API request.
// The path is the relative URL which will be resolves to the BaseURL of the Client.
// It should always be specified without a preceding slash.

View file

@ -36,63 +36,50 @@ package reddit
import (
"context"
"net/http"
"net/url"
"golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
)
var endpoint = oauth2.Endpoint{
TokenURL: "https://www.reddit.com/api/v1/access_token",
AuthStyle: oauth2.AuthStyleInHeader,
type oauthTokenSource struct {
ctx context.Context
config *oauth2.Config
username, password string
}
type oauth2Config struct {
id string
secret string
username string
password string
tokenURL string
func (s *oauthTokenSource) Token() (*oauth2.Token, error) {
return s.config.PasswordCredentialsToken(s.ctx, s.username, s.password)
}
func oauthTransport(client *Client) http.RoundTripper {
// We need to set a custom user agent, because using the one set by default by the
// stdlib gives us 429 Too Many Request responses from the Reddit API.
userAgentTransport *userAgentTransport
}
func oauth2Transport(c oauth2Config) *oauth2.Transport {
// todo: use oauth2.Config.PasswordCredentialsToken
params := url.Values{
"grant_type": {"password"},
"username": {c.username},
"password": {c.password},
userAgentTransport := &userAgentTransport{
userAgent: client.UserAgent(),
Base: client.client.Transport,
}
cfg := clientcredentials.Config{
ClientID: c.id,
ClientSecret: c.secret,
TokenURL: c.tokenURL,
AuthStyle: oauth2.AuthStyleInHeader,
EndpointParams: params,
httpClient := &http.Client{Transport: userAgentTransport}
ctx := context.Background()
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
config := &oauth2.Config{
ClientID: client.ID,
ClientSecret: client.Secret,
Endpoint: oauth2.Endpoint{
TokenURL: client.TokenURL.String(),
AuthStyle: oauth2.AuthStyleInHeader,
},
}
httpClient := &http.Client{Transport: c.userAgentTransport}
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
tokenSource := oauth2.ReuseTokenSource(nil, &oauthTokenSource{
ctx: ctx,
config: config,
username: client.Username,
password: client.Password,
})
src := cfg.TokenSource(ctx)
tr := &oauth2.Transport{
Source: src,
Base: c.userAgentTransport,
}
return tr
}
// WithCredentials sets the necessary values for the client to authenticate via OAuth2.
func WithCredentials(id, secret, username, password string) Opt {
return func(c *Client) error {
c.ID = id
c.Secret = secret
c.Username = username
c.Password = password
return nil
return &oauth2.Transport{
Source: tokenSource,
Base: userAgentTransport,
}
}

View file

@ -35,6 +35,17 @@ func FromEnv(c *Client) error {
return nil
}
// WithCredentials sets the necessary values for the client to authenticate via OAuth2.
func WithCredentials(id, secret, username, password string) Opt {
return func(c *Client) error {
c.ID = id
c.Secret = secret
c.Username = username
c.Password = password
return nil
}
}
// WithBaseURL sets the base URL for the client to make requests to.
func WithBaseURL(u string) Opt {
return func(c *Client) error {

View file

@ -32,6 +32,16 @@ func TestFromEnv(t *testing.T) {
assert.Equal(t, expect, actual)
}
func TestWithCredentials(t *testing.T) {
withCredentials := WithCredentials("id1", "secret1", "username1", "password1")
c, err := NewClient(nil, withCredentials)
assert.NoError(t, err)
assert.Equal(t, "id1", c.ID)
assert.Equal(t, "secret1", c.Secret)
assert.Equal(t, "username1", c.Username)
assert.Equal(t, "password1", c.Password)
}
func TestWithBaseURL(t *testing.T) {
baseURL := "http://localhost:8080"
c, err := NewClient(nil, WithBaseURL(baseURL))