From 83240db9f66d4f45f35b74fe3c7b60a1d62e59ec Mon Sep 17 00:00:00 2001
From: Andrew Dunai
Date: Fri, 30 Jul 2021 23:33:58 +0300
Subject: [PATCH] Add support for Application Only OAuth (aka "two-legged
oauth")
Signed-off-by: Andrew Dunai
---
reddit/reddit-oauth.go | 51 ++++++++++++++++++++++++-----------
reddit/reddit-options.go | 10 +++++++
reddit/reddit-options_test.go | 26 ++++++++++++++++++
reddit/reddit.go | 3 +--
4 files changed, 73 insertions(+), 17 deletions(-)
diff --git a/reddit/reddit-oauth.go b/reddit/reddit-oauth.go
index 13030d6..89e22bc 100644
--- a/reddit/reddit-oauth.go
+++ b/reddit/reddit-oauth.go
@@ -38,6 +38,7 @@ import (
"net/http"
"golang.org/x/oauth2"
+ "golang.org/x/oauth2/clientcredentials"
)
type oauthTokenSource struct {
@@ -50,28 +51,48 @@ func (s *oauthTokenSource) Token() (*oauth2.Token, error) {
return s.config.PasswordCredentialsToken(s.ctx, s.username, s.password)
}
+type oauthTwoLeggedTokenSource struct {
+ ctx context.Context
+ config *clientcredentials.Config
+}
+
+func (s *oauthTwoLeggedTokenSource) Token() (*oauth2.Token, error) {
+ return s.config.Token(s.ctx)
+}
+
func oauthTransport(client *Client) http.RoundTripper {
httpClient := &http.Client{Transport: client.client.Transport}
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
- config := &oauth2.Config{
- ClientID: client.ID,
- ClientSecret: client.Secret,
- Endpoint: oauth2.Endpoint{
- TokenURL: client.TokenURL.String(),
- AuthStyle: oauth2.AuthStyleInHeader,
- },
+ var tokenSource oauth2.TokenSource
+ if client.applicationOnlyOAuth {
+ tokenSource = &oauthTwoLeggedTokenSource{
+ ctx: ctx,
+ config: &clientcredentials.Config{
+ ClientID: client.ID,
+ ClientSecret: client.Secret,
+ TokenURL: client.TokenURL.String(),
+ AuthStyle: oauth2.AuthStyleInHeader,
+ },
+ }
+ } else {
+ tokenSource = &oauthTokenSource{
+ ctx: ctx,
+ config: &oauth2.Config{
+ ClientID: client.ID,
+ ClientSecret: client.Secret,
+ Endpoint: oauth2.Endpoint{
+ TokenURL: client.TokenURL.String(),
+ AuthStyle: oauth2.AuthStyleInHeader,
+ },
+ },
+ username: client.Username,
+ password: client.Password,
+ }
}
- tokenSource := oauth2.ReuseTokenSource(nil, &oauthTokenSource{
- ctx: ctx,
- config: config,
- username: client.Username,
- password: client.Password,
- })
-
return &oauth2.Transport{
- Source: tokenSource,
+ Source: oauth2.ReuseTokenSource(nil, tokenSource),
Base: client.client.Transport,
}
}
diff --git a/reddit/reddit-options.go b/reddit/reddit-options.go
index 25d397b..04cfab1 100644
--- a/reddit/reddit-options.go
+++ b/reddit/reddit-options.go
@@ -55,6 +55,16 @@ func WithTokenURL(u string) Opt {
}
}
+// WithApplicationOnlyOAuth sets authentication flow to "Application Only OAuth".
+// Only ID and Secret are required to be set in client. Username and Password are ignored.
+// The flow is described here: https://github.com/reddit-archive/reddit/wiki/OAuth2#application-only-oauth
+func WithApplicationOnlyOAuth(o bool) Opt {
+ return func(c *Client) error {
+ c.applicationOnlyOAuth = o
+ return nil
+ }
+}
+
// FromEnv configures the client with values from environment variables.
// Supported environment variables:
// GO_REDDIT_CLIENT_ID to set the client's id.
diff --git a/reddit/reddit-options_test.go b/reddit/reddit-options_test.go
index 3db5ce8..cd0e499 100644
--- a/reddit/reddit-options_test.go
+++ b/reddit/reddit-options_test.go
@@ -1,13 +1,17 @@
package reddit
import (
+ "bytes"
"fmt"
+ "io"
+ "io/ioutil"
"net/http"
"net/url"
"os"
"testing"
"github.com/stretchr/testify/require"
+ "golang.org/x/oauth2"
)
func TestWithHTTPClient(t *testing.T) {
@@ -52,6 +56,28 @@ func TestWithTokenURL(t *testing.T) {
require.Equal(t, tokenURL, c.TokenURL.String())
}
+type RequestInterceptor struct {
+ interceptedBody string
+}
+
+func (t *RequestInterceptor) RoundTrip(r *http.Request) (*http.Response, error) {
+ requestBody, _ := ioutil.ReadAll(r.Body)
+ t.interceptedBody = string(requestBody)
+ var body bytes.Buffer
+ body.WriteString(`{"access_token": "foobar", "expires_in": 3600, "scope": "*", "token_type": "bearer"}`)
+ return &http.Response{Status: "200 OK", StatusCode: 200, Body: io.NopCloser(&body)}, nil
+}
+
+func TestWithApplicationOnlyOAuth(t *testing.T) {
+ requestInterceptor := &RequestInterceptor{}
+ c, err := NewClient(Credentials{ID: "id", Secret: "secret"}, WithApplicationOnlyOAuth(true), WithHTTPClient(&http.Client{Transport: requestInterceptor}))
+ require.NoError(t, err)
+ token, err := c.client.Transport.(*oauth2.Transport).Source.Token()
+ require.NoError(t, err)
+ require.Equal(t, token.AccessToken, "foobar")
+ require.Equal(t, "grant_type=client_credentials", requestInterceptor.interceptedBody)
+}
+
func TestFromEnv(t *testing.T) {
os.Setenv("GO_REDDIT_CLIENT_ID", "id1")
defer os.Unsetenv("GO_REDDIT_CLIENT_ID")
diff --git a/reddit/reddit.go b/reddit/reddit.go
index bafda0c..c292a9e 100644
--- a/reddit/reddit.go
+++ b/reddit/reddit.go
@@ -16,7 +16,6 @@ import (
"time"
"github.com/google/go-querystring/query"
- "golang.org/x/oauth2"
)
const (
@@ -96,7 +95,7 @@ type Client struct {
Widget *WidgetService
Wiki *WikiService
- oauth2Transport *oauth2.Transport
+ applicationOnlyOAuth bool
onRequestCompleted RequestCompletionCallback
}