This commit is contained in:
Daniel Ponte 2022-03-12 10:04:11 -05:00
commit 3e50b43650
4 changed files with 73 additions and 17 deletions

View file

@ -38,6 +38,7 @@ import (
"net/http" "net/http"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/oauth2/clientcredentials"
) )
type oauthTokenSource struct { type oauthTokenSource struct {
@ -50,28 +51,48 @@ func (s *oauthTokenSource) Token() (*oauth2.Token, error) {
return s.config.PasswordCredentialsToken(s.ctx, s.username, s.password) return s.config.PasswordCredentialsToken(s.ctx, s.username, s.password)
} }
type oauthTwoLeggedTokenSource struct {
ctx context.Context
config *clientcredentials.Config
}
func (s *oauthTwoLeggedTokenSource) Token() (*oauth2.Token, error) {
return s.config.Token(s.ctx)
}
func oauthTransport(client *Client) http.RoundTripper { func oauthTransport(client *Client) http.RoundTripper {
httpClient := &http.Client{Transport: client.client.Transport} httpClient := &http.Client{Transport: client.client.Transport}
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient) ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
config := &oauth2.Config{ var tokenSource oauth2.TokenSource
ClientID: client.ID, if client.applicationOnlyOAuth {
ClientSecret: client.Secret, tokenSource = &oauthTwoLeggedTokenSource{
Endpoint: oauth2.Endpoint{ ctx: ctx,
TokenURL: client.TokenURL.String(), config: &clientcredentials.Config{
AuthStyle: oauth2.AuthStyleInHeader, ClientID: client.ID,
}, ClientSecret: client.Secret,
TokenURL: client.TokenURL.String(),
AuthStyle: oauth2.AuthStyleInHeader,
},
}
} else {
tokenSource = &oauthTokenSource{
ctx: ctx,
config: &oauth2.Config{
ClientID: client.ID,
ClientSecret: client.Secret,
Endpoint: oauth2.Endpoint{
TokenURL: client.TokenURL.String(),
AuthStyle: oauth2.AuthStyleInHeader,
},
},
username: client.Username,
password: client.Password,
}
} }
tokenSource := oauth2.ReuseTokenSource(nil, &oauthTokenSource{
ctx: ctx,
config: config,
username: client.Username,
password: client.Password,
})
return &oauth2.Transport{ return &oauth2.Transport{
Source: tokenSource, Source: oauth2.ReuseTokenSource(nil, tokenSource),
Base: client.client.Transport, Base: client.client.Transport,
} }
} }

View file

@ -55,6 +55,16 @@ func WithTokenURL(u string) Opt {
} }
} }
// WithApplicationOnlyOAuth sets authentication flow to "Application Only OAuth".
// Only ID and Secret are required to be set in client. Username and Password are ignored.
// The flow is described here: https://github.com/reddit-archive/reddit/wiki/OAuth2#application-only-oauth
func WithApplicationOnlyOAuth(o bool) Opt {
return func(c *Client) error {
c.applicationOnlyOAuth = o
return nil
}
}
// FromEnv configures the client with values from environment variables. // FromEnv configures the client with values from environment variables.
// Supported environment variables: // Supported environment variables:
// GO_REDDIT_CLIENT_ID to set the client's id. // GO_REDDIT_CLIENT_ID to set the client's id.

View file

@ -1,13 +1,17 @@
package reddit package reddit
import ( import (
"bytes"
"fmt" "fmt"
"io"
"io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/oauth2"
) )
func TestWithHTTPClient(t *testing.T) { func TestWithHTTPClient(t *testing.T) {
@ -52,6 +56,28 @@ func TestWithTokenURL(t *testing.T) {
require.Equal(t, tokenURL, c.TokenURL.String()) require.Equal(t, tokenURL, c.TokenURL.String())
} }
type RequestInterceptor struct {
interceptedBody string
}
func (t *RequestInterceptor) RoundTrip(r *http.Request) (*http.Response, error) {
requestBody, _ := ioutil.ReadAll(r.Body)
t.interceptedBody = string(requestBody)
var body bytes.Buffer
body.WriteString(`{"access_token": "foobar", "expires_in": 3600, "scope": "*", "token_type": "bearer"}`)
return &http.Response{Status: "200 OK", StatusCode: 200, Body: io.NopCloser(&body)}, nil
}
func TestWithApplicationOnlyOAuth(t *testing.T) {
requestInterceptor := &RequestInterceptor{}
c, err := NewClient(Credentials{ID: "id", Secret: "secret"}, WithApplicationOnlyOAuth(true), WithHTTPClient(&http.Client{Transport: requestInterceptor}))
require.NoError(t, err)
token, err := c.client.Transport.(*oauth2.Transport).Source.Token()
require.NoError(t, err)
require.Equal(t, token.AccessToken, "foobar")
require.Equal(t, "grant_type=client_credentials", requestInterceptor.interceptedBody)
}
func TestFromEnv(t *testing.T) { func TestFromEnv(t *testing.T) {
os.Setenv("GO_REDDIT_CLIENT_ID", "id1") os.Setenv("GO_REDDIT_CLIENT_ID", "id1")
defer os.Unsetenv("GO_REDDIT_CLIENT_ID") defer os.Unsetenv("GO_REDDIT_CLIENT_ID")

View file

@ -16,7 +16,6 @@ import (
"time" "time"
"github.com/google/go-querystring/query" "github.com/google/go-querystring/query"
"golang.org/x/oauth2"
) )
const ( const (
@ -96,7 +95,7 @@ type Client struct {
Widget *WidgetService Widget *WidgetService
Wiki *WikiService Wiki *WikiService
oauth2Transport *oauth2.Transport applicationOnlyOAuth bool
onRequestCompleted RequestCompletionCallback onRequestCompleted RequestCompletionCallback
} }