Use oauth2 package's PasswordCredentialsToken method
Signed-off-by: Vartan Benohanian <vartanbeno@gmail.com>
This commit is contained in:
parent
e49f4bf738
commit
a9c581df34
4 changed files with 65 additions and 62 deletions
29
reddit.go
29
reddit.go
|
@ -48,13 +48,13 @@ func cloneRequest(r *http.Request) *http.Request {
|
||||||
|
|
||||||
// Sets the User-Agent header for requests.
|
// Sets the User-Agent header for requests.
|
||||||
type userAgentTransport struct {
|
type userAgentTransport struct {
|
||||||
ua string
|
userAgent string
|
||||||
Base http.RoundTripper
|
Base http.RoundTripper
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *userAgentTransport) setUserAgent(req *http.Request) *http.Request {
|
func (t *userAgentTransport) setUserAgent(req *http.Request) *http.Request {
|
||||||
req2 := cloneRequest(req)
|
req2 := cloneRequest(req)
|
||||||
req2.Header.Set(headerUserAgent, t.ua)
|
req2.Header.Set(headerUserAgent, t.userAgent)
|
||||||
return req2
|
return req2
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -163,25 +163,20 @@ func NewClient(httpClient *http.Client, opts ...Opt) (c *Client, err error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.userAgent = fmt.Sprintf("golang:%s:v%s (by /u/%s)", libraryName, libraryVersion, c.Username)
|
oauthTransport := oauthTransport(c)
|
||||||
userAgentTransport := &userAgentTransport{
|
c.client.Transport = oauthTransport
|
||||||
ua: c.userAgent,
|
|
||||||
Base: c.client.Transport,
|
|
||||||
}
|
|
||||||
|
|
||||||
oauth2Config := oauth2Config{
|
|
||||||
id: c.ID,
|
|
||||||
secret: c.Secret,
|
|
||||||
username: c.Username,
|
|
||||||
password: c.Password,
|
|
||||||
tokenURL: c.TokenURL.String(),
|
|
||||||
userAgentTransport: userAgentTransport,
|
|
||||||
}
|
|
||||||
c.client.Transport = oauth2Transport(oauth2Config)
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UserAgent returns the client's user agent.
|
||||||
|
func (c *Client) UserAgent() string {
|
||||||
|
if c.userAgent == "" {
|
||||||
|
c.userAgent = fmt.Sprintf("golang:%s:v%s (by /u/%s)", libraryName, libraryVersion, c.Username)
|
||||||
|
}
|
||||||
|
return c.userAgent
|
||||||
|
}
|
||||||
|
|
||||||
// NewRequest creates an API request.
|
// NewRequest creates an API request.
|
||||||
// The path is the relative URL which will be resolves to the BaseURL of the Client.
|
// The path is the relative URL which will be resolves to the BaseURL of the Client.
|
||||||
// It should always be specified without a preceding slash.
|
// It should always be specified without a preceding slash.
|
||||||
|
|
|
@ -36,63 +36,50 @@ package reddit
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"golang.org/x/oauth2/clientcredentials"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var endpoint = oauth2.Endpoint{
|
type oauthTokenSource struct {
|
||||||
TokenURL: "https://www.reddit.com/api/v1/access_token",
|
ctx context.Context
|
||||||
AuthStyle: oauth2.AuthStyleInHeader,
|
config *oauth2.Config
|
||||||
|
username, password string
|
||||||
}
|
}
|
||||||
|
|
||||||
type oauth2Config struct {
|
func (s *oauthTokenSource) Token() (*oauth2.Token, error) {
|
||||||
id string
|
return s.config.PasswordCredentialsToken(s.ctx, s.username, s.password)
|
||||||
secret string
|
}
|
||||||
username string
|
|
||||||
password string
|
|
||||||
tokenURL string
|
|
||||||
|
|
||||||
|
func oauthTransport(client *Client) http.RoundTripper {
|
||||||
// We need to set a custom user agent, because using the one set by default by the
|
// We need to set a custom user agent, because using the one set by default by the
|
||||||
// stdlib gives us 429 Too Many Request responses from the Reddit API.
|
// stdlib gives us 429 Too Many Request responses from the Reddit API.
|
||||||
userAgentTransport *userAgentTransport
|
userAgentTransport := &userAgentTransport{
|
||||||
}
|
userAgent: client.UserAgent(),
|
||||||
|
Base: client.client.Transport,
|
||||||
func oauth2Transport(c oauth2Config) *oauth2.Transport {
|
|
||||||
// todo: use oauth2.Config.PasswordCredentialsToken
|
|
||||||
params := url.Values{
|
|
||||||
"grant_type": {"password"},
|
|
||||||
"username": {c.username},
|
|
||||||
"password": {c.password},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg := clientcredentials.Config{
|
httpClient := &http.Client{Transport: userAgentTransport}
|
||||||
ClientID: c.id,
|
ctx := context.Background()
|
||||||
ClientSecret: c.secret,
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
|
||||||
TokenURL: c.tokenURL,
|
|
||||||
|
config := &oauth2.Config{
|
||||||
|
ClientID: client.ID,
|
||||||
|
ClientSecret: client.Secret,
|
||||||
|
Endpoint: oauth2.Endpoint{
|
||||||
|
TokenURL: client.TokenURL.String(),
|
||||||
AuthStyle: oauth2.AuthStyleInHeader,
|
AuthStyle: oauth2.AuthStyleInHeader,
|
||||||
EndpointParams: params,
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
httpClient := &http.Client{Transport: c.userAgentTransport}
|
tokenSource := oauth2.ReuseTokenSource(nil, &oauthTokenSource{
|
||||||
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
|
ctx: ctx,
|
||||||
|
config: config,
|
||||||
|
username: client.Username,
|
||||||
|
password: client.Password,
|
||||||
|
})
|
||||||
|
|
||||||
src := cfg.TokenSource(ctx)
|
return &oauth2.Transport{
|
||||||
tr := &oauth2.Transport{
|
Source: tokenSource,
|
||||||
Source: src,
|
Base: userAgentTransport,
|
||||||
Base: c.userAgentTransport,
|
|
||||||
}
|
|
||||||
return tr
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithCredentials sets the necessary values for the client to authenticate via OAuth2.
|
|
||||||
func WithCredentials(id, secret, username, password string) Opt {
|
|
||||||
return func(c *Client) error {
|
|
||||||
c.ID = id
|
|
||||||
c.Secret = secret
|
|
||||||
c.Username = username
|
|
||||||
c.Password = password
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,6 +35,17 @@ func FromEnv(c *Client) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithCredentials sets the necessary values for the client to authenticate via OAuth2.
|
||||||
|
func WithCredentials(id, secret, username, password string) Opt {
|
||||||
|
return func(c *Client) error {
|
||||||
|
c.ID = id
|
||||||
|
c.Secret = secret
|
||||||
|
c.Username = username
|
||||||
|
c.Password = password
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithBaseURL sets the base URL for the client to make requests to.
|
// WithBaseURL sets the base URL for the client to make requests to.
|
||||||
func WithBaseURL(u string) Opt {
|
func WithBaseURL(u string) Opt {
|
||||||
return func(c *Client) error {
|
return func(c *Client) error {
|
||||||
|
|
|
@ -32,6 +32,16 @@ func TestFromEnv(t *testing.T) {
|
||||||
assert.Equal(t, expect, actual)
|
assert.Equal(t, expect, actual)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWithCredentials(t *testing.T) {
|
||||||
|
withCredentials := WithCredentials("id1", "secret1", "username1", "password1")
|
||||||
|
c, err := NewClient(nil, withCredentials)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "id1", c.ID)
|
||||||
|
assert.Equal(t, "secret1", c.Secret)
|
||||||
|
assert.Equal(t, "username1", c.Username)
|
||||||
|
assert.Equal(t, "password1", c.Password)
|
||||||
|
}
|
||||||
|
|
||||||
func TestWithBaseURL(t *testing.T) {
|
func TestWithBaseURL(t *testing.T) {
|
||||||
baseURL := "http://localhost:8080"
|
baseURL := "http://localhost:8080"
|
||||||
c, err := NewClient(nil, WithBaseURL(baseURL))
|
c, err := NewClient(nil, WithBaseURL(baseURL))
|
||||||
|
|
Loading…
Reference in a new issue