diff --git a/reddit/reddit-oauth.go b/reddit/reddit-oauth.go index 13030d6..89e22bc 100644 --- a/reddit/reddit-oauth.go +++ b/reddit/reddit-oauth.go @@ -38,6 +38,7 @@ import ( "net/http" "golang.org/x/oauth2" + "golang.org/x/oauth2/clientcredentials" ) type oauthTokenSource struct { @@ -50,28 +51,48 @@ func (s *oauthTokenSource) Token() (*oauth2.Token, error) { return s.config.PasswordCredentialsToken(s.ctx, s.username, s.password) } +type oauthTwoLeggedTokenSource struct { + ctx context.Context + config *clientcredentials.Config +} + +func (s *oauthTwoLeggedTokenSource) Token() (*oauth2.Token, error) { + return s.config.Token(s.ctx) +} + func oauthTransport(client *Client) http.RoundTripper { httpClient := &http.Client{Transport: client.client.Transport} ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient) - config := &oauth2.Config{ - ClientID: client.ID, - ClientSecret: client.Secret, - Endpoint: oauth2.Endpoint{ - TokenURL: client.TokenURL.String(), - AuthStyle: oauth2.AuthStyleInHeader, - }, + var tokenSource oauth2.TokenSource + if client.applicationOnlyOAuth { + tokenSource = &oauthTwoLeggedTokenSource{ + ctx: ctx, + config: &clientcredentials.Config{ + ClientID: client.ID, + ClientSecret: client.Secret, + TokenURL: client.TokenURL.String(), + AuthStyle: oauth2.AuthStyleInHeader, + }, + } + } else { + tokenSource = &oauthTokenSource{ + ctx: ctx, + config: &oauth2.Config{ + ClientID: client.ID, + ClientSecret: client.Secret, + Endpoint: oauth2.Endpoint{ + TokenURL: client.TokenURL.String(), + AuthStyle: oauth2.AuthStyleInHeader, + }, + }, + username: client.Username, + password: client.Password, + } } - tokenSource := oauth2.ReuseTokenSource(nil, &oauthTokenSource{ - ctx: ctx, - config: config, - username: client.Username, - password: client.Password, - }) - return &oauth2.Transport{ - Source: tokenSource, + Source: oauth2.ReuseTokenSource(nil, tokenSource), Base: client.client.Transport, } } diff --git a/reddit/reddit-options.go b/reddit/reddit-options.go index 25d397b..04cfab1 100644 --- a/reddit/reddit-options.go +++ b/reddit/reddit-options.go @@ -55,6 +55,16 @@ func WithTokenURL(u string) Opt { } } +// WithApplicationOnlyOAuth sets authentication flow to "Application Only OAuth". +// Only ID and Secret are required to be set in client. Username and Password are ignored. +// The flow is described here: https://github.com/reddit-archive/reddit/wiki/OAuth2#application-only-oauth +func WithApplicationOnlyOAuth(o bool) Opt { + return func(c *Client) error { + c.applicationOnlyOAuth = o + return nil + } +} + // FromEnv configures the client with values from environment variables. // Supported environment variables: // GO_REDDIT_CLIENT_ID to set the client's id. diff --git a/reddit/reddit-options_test.go b/reddit/reddit-options_test.go index 3db5ce8..cd0e499 100644 --- a/reddit/reddit-options_test.go +++ b/reddit/reddit-options_test.go @@ -1,13 +1,17 @@ package reddit import ( + "bytes" "fmt" + "io" + "io/ioutil" "net/http" "net/url" "os" "testing" "github.com/stretchr/testify/require" + "golang.org/x/oauth2" ) func TestWithHTTPClient(t *testing.T) { @@ -52,6 +56,28 @@ func TestWithTokenURL(t *testing.T) { require.Equal(t, tokenURL, c.TokenURL.String()) } +type RequestInterceptor struct { + interceptedBody string +} + +func (t *RequestInterceptor) RoundTrip(r *http.Request) (*http.Response, error) { + requestBody, _ := ioutil.ReadAll(r.Body) + t.interceptedBody = string(requestBody) + var body bytes.Buffer + body.WriteString(`{"access_token": "foobar", "expires_in": 3600, "scope": "*", "token_type": "bearer"}`) + return &http.Response{Status: "200 OK", StatusCode: 200, Body: io.NopCloser(&body)}, nil +} + +func TestWithApplicationOnlyOAuth(t *testing.T) { + requestInterceptor := &RequestInterceptor{} + c, err := NewClient(Credentials{ID: "id", Secret: "secret"}, WithApplicationOnlyOAuth(true), WithHTTPClient(&http.Client{Transport: requestInterceptor})) + require.NoError(t, err) + token, err := c.client.Transport.(*oauth2.Transport).Source.Token() + require.NoError(t, err) + require.Equal(t, token.AccessToken, "foobar") + require.Equal(t, "grant_type=client_credentials", requestInterceptor.interceptedBody) +} + func TestFromEnv(t *testing.T) { os.Setenv("GO_REDDIT_CLIENT_ID", "id1") defer os.Unsetenv("GO_REDDIT_CLIENT_ID") diff --git a/reddit/reddit.go b/reddit/reddit.go index bafda0c..c292a9e 100644 --- a/reddit/reddit.go +++ b/reddit/reddit.go @@ -16,7 +16,6 @@ import ( "time" "github.com/google/go-querystring/query" - "golang.org/x/oauth2" ) const ( @@ -96,7 +95,7 @@ type Client struct { Widget *WidgetService Wiki *WikiService - oauth2Transport *oauth2.Transport + applicationOnlyOAuth bool onRequestCompleted RequestCompletionCallback }