Use oauth2 package's PasswordCredentialsToken method
Signed-off-by: Vartan Benohanian <vartanbeno@gmail.com>
This commit is contained in:
parent
e49f4bf738
commit
a9c581df34
4 changed files with 65 additions and 62 deletions
31
reddit.go
31
reddit.go
|
@ -48,13 +48,13 @@ func cloneRequest(r *http.Request) *http.Request {
|
|||
|
||||
// Sets the User-Agent header for requests.
|
||||
type userAgentTransport struct {
|
||||
ua string
|
||||
Base http.RoundTripper
|
||||
userAgent string
|
||||
Base http.RoundTripper
|
||||
}
|
||||
|
||||
func (t *userAgentTransport) setUserAgent(req *http.Request) *http.Request {
|
||||
req2 := cloneRequest(req)
|
||||
req2.Header.Set(headerUserAgent, t.ua)
|
||||
req2.Header.Set(headerUserAgent, t.userAgent)
|
||||
return req2
|
||||
}
|
||||
|
||||
|
@ -163,25 +163,20 @@ func NewClient(httpClient *http.Client, opts ...Opt) (c *Client, err error) {
|
|||
}
|
||||
}
|
||||
|
||||
c.userAgent = fmt.Sprintf("golang:%s:v%s (by /u/%s)", libraryName, libraryVersion, c.Username)
|
||||
userAgentTransport := &userAgentTransport{
|
||||
ua: c.userAgent,
|
||||
Base: c.client.Transport,
|
||||
}
|
||||
|
||||
oauth2Config := oauth2Config{
|
||||
id: c.ID,
|
||||
secret: c.Secret,
|
||||
username: c.Username,
|
||||
password: c.Password,
|
||||
tokenURL: c.TokenURL.String(),
|
||||
userAgentTransport: userAgentTransport,
|
||||
}
|
||||
c.client.Transport = oauth2Transport(oauth2Config)
|
||||
oauthTransport := oauthTransport(c)
|
||||
c.client.Transport = oauthTransport
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// UserAgent returns the client's user agent.
|
||||
func (c *Client) UserAgent() string {
|
||||
if c.userAgent == "" {
|
||||
c.userAgent = fmt.Sprintf("golang:%s:v%s (by /u/%s)", libraryName, libraryVersion, c.Username)
|
||||
}
|
||||
return c.userAgent
|
||||
}
|
||||
|
||||
// NewRequest creates an API request.
|
||||
// The path is the relative URL which will be resolves to the BaseURL of the Client.
|
||||
// It should always be specified without a preceding slash.
|
||||
|
|
|
@ -36,63 +36,50 @@ package reddit
|
|||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/clientcredentials"
|
||||
)
|
||||
|
||||
var endpoint = oauth2.Endpoint{
|
||||
TokenURL: "https://www.reddit.com/api/v1/access_token",
|
||||
AuthStyle: oauth2.AuthStyleInHeader,
|
||||
type oauthTokenSource struct {
|
||||
ctx context.Context
|
||||
config *oauth2.Config
|
||||
username, password string
|
||||
}
|
||||
|
||||
type oauth2Config struct {
|
||||
id string
|
||||
secret string
|
||||
username string
|
||||
password string
|
||||
tokenURL string
|
||||
func (s *oauthTokenSource) Token() (*oauth2.Token, error) {
|
||||
return s.config.PasswordCredentialsToken(s.ctx, s.username, s.password)
|
||||
}
|
||||
|
||||
func oauthTransport(client *Client) http.RoundTripper {
|
||||
// We need to set a custom user agent, because using the one set by default by the
|
||||
// stdlib gives us 429 Too Many Request responses from the Reddit API.
|
||||
userAgentTransport *userAgentTransport
|
||||
}
|
||||
|
||||
func oauth2Transport(c oauth2Config) *oauth2.Transport {
|
||||
// todo: use oauth2.Config.PasswordCredentialsToken
|
||||
params := url.Values{
|
||||
"grant_type": {"password"},
|
||||
"username": {c.username},
|
||||
"password": {c.password},
|
||||
userAgentTransport := &userAgentTransport{
|
||||
userAgent: client.UserAgent(),
|
||||
Base: client.client.Transport,
|
||||
}
|
||||
|
||||
cfg := clientcredentials.Config{
|
||||
ClientID: c.id,
|
||||
ClientSecret: c.secret,
|
||||
TokenURL: c.tokenURL,
|
||||
AuthStyle: oauth2.AuthStyleInHeader,
|
||||
EndpointParams: params,
|
||||
httpClient := &http.Client{Transport: userAgentTransport}
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
|
||||
|
||||
config := &oauth2.Config{
|
||||
ClientID: client.ID,
|
||||
ClientSecret: client.Secret,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
TokenURL: client.TokenURL.String(),
|
||||
AuthStyle: oauth2.AuthStyleInHeader,
|
||||
},
|
||||
}
|
||||
|
||||
httpClient := &http.Client{Transport: c.userAgentTransport}
|
||||
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
|
||||
tokenSource := oauth2.ReuseTokenSource(nil, &oauthTokenSource{
|
||||
ctx: ctx,
|
||||
config: config,
|
||||
username: client.Username,
|
||||
password: client.Password,
|
||||
})
|
||||
|
||||
src := cfg.TokenSource(ctx)
|
||||
tr := &oauth2.Transport{
|
||||
Source: src,
|
||||
Base: c.userAgentTransport,
|
||||
}
|
||||
return tr
|
||||
}
|
||||
|
||||
// WithCredentials sets the necessary values for the client to authenticate via OAuth2.
|
||||
func WithCredentials(id, secret, username, password string) Opt {
|
||||
return func(c *Client) error {
|
||||
c.ID = id
|
||||
c.Secret = secret
|
||||
c.Username = username
|
||||
c.Password = password
|
||||
return nil
|
||||
return &oauth2.Transport{
|
||||
Source: tokenSource,
|
||||
Base: userAgentTransport,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,6 +35,17 @@ func FromEnv(c *Client) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// WithCredentials sets the necessary values for the client to authenticate via OAuth2.
|
||||
func WithCredentials(id, secret, username, password string) Opt {
|
||||
return func(c *Client) error {
|
||||
c.ID = id
|
||||
c.Secret = secret
|
||||
c.Username = username
|
||||
c.Password = password
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithBaseURL sets the base URL for the client to make requests to.
|
||||
func WithBaseURL(u string) Opt {
|
||||
return func(c *Client) error {
|
||||
|
|
|
@ -32,6 +32,16 @@ func TestFromEnv(t *testing.T) {
|
|||
assert.Equal(t, expect, actual)
|
||||
}
|
||||
|
||||
func TestWithCredentials(t *testing.T) {
|
||||
withCredentials := WithCredentials("id1", "secret1", "username1", "password1")
|
||||
c, err := NewClient(nil, withCredentials)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "id1", c.ID)
|
||||
assert.Equal(t, "secret1", c.Secret)
|
||||
assert.Equal(t, "username1", c.Username)
|
||||
assert.Equal(t, "password1", c.Password)
|
||||
}
|
||||
|
||||
func TestWithBaseURL(t *testing.T) {
|
||||
baseURL := "http://localhost:8080"
|
||||
c, err := NewClient(nil, WithBaseURL(baseURL))
|
||||
|
|
Loading…
Reference in a new issue