Merge branch 'master' of https://github.com/and3rson/go-reddit
This commit is contained in:
commit
3e50b43650
4 changed files with 73 additions and 17 deletions
|
@ -38,6 +38,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
"golang.org/x/oauth2/clientcredentials"
|
||||||
)
|
)
|
||||||
|
|
||||||
type oauthTokenSource struct {
|
type oauthTokenSource struct {
|
||||||
|
@ -50,28 +51,48 @@ func (s *oauthTokenSource) Token() (*oauth2.Token, error) {
|
||||||
return s.config.PasswordCredentialsToken(s.ctx, s.username, s.password)
|
return s.config.PasswordCredentialsToken(s.ctx, s.username, s.password)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type oauthTwoLeggedTokenSource struct {
|
||||||
|
ctx context.Context
|
||||||
|
config *clientcredentials.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *oauthTwoLeggedTokenSource) Token() (*oauth2.Token, error) {
|
||||||
|
return s.config.Token(s.ctx)
|
||||||
|
}
|
||||||
|
|
||||||
func oauthTransport(client *Client) http.RoundTripper {
|
func oauthTransport(client *Client) http.RoundTripper {
|
||||||
httpClient := &http.Client{Transport: client.client.Transport}
|
httpClient := &http.Client{Transport: client.client.Transport}
|
||||||
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
|
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
|
||||||
|
|
||||||
config := &oauth2.Config{
|
var tokenSource oauth2.TokenSource
|
||||||
ClientID: client.ID,
|
if client.applicationOnlyOAuth {
|
||||||
ClientSecret: client.Secret,
|
tokenSource = &oauthTwoLeggedTokenSource{
|
||||||
Endpoint: oauth2.Endpoint{
|
ctx: ctx,
|
||||||
TokenURL: client.TokenURL.String(),
|
config: &clientcredentials.Config{
|
||||||
AuthStyle: oauth2.AuthStyleInHeader,
|
ClientID: client.ID,
|
||||||
},
|
ClientSecret: client.Secret,
|
||||||
|
TokenURL: client.TokenURL.String(),
|
||||||
|
AuthStyle: oauth2.AuthStyleInHeader,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
tokenSource = &oauthTokenSource{
|
||||||
|
ctx: ctx,
|
||||||
|
config: &oauth2.Config{
|
||||||
|
ClientID: client.ID,
|
||||||
|
ClientSecret: client.Secret,
|
||||||
|
Endpoint: oauth2.Endpoint{
|
||||||
|
TokenURL: client.TokenURL.String(),
|
||||||
|
AuthStyle: oauth2.AuthStyleInHeader,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
username: client.Username,
|
||||||
|
password: client.Password,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenSource := oauth2.ReuseTokenSource(nil, &oauthTokenSource{
|
|
||||||
ctx: ctx,
|
|
||||||
config: config,
|
|
||||||
username: client.Username,
|
|
||||||
password: client.Password,
|
|
||||||
})
|
|
||||||
|
|
||||||
return &oauth2.Transport{
|
return &oauth2.Transport{
|
||||||
Source: tokenSource,
|
Source: oauth2.ReuseTokenSource(nil, tokenSource),
|
||||||
Base: client.client.Transport,
|
Base: client.client.Transport,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -55,6 +55,16 @@ func WithTokenURL(u string) Opt {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithApplicationOnlyOAuth sets authentication flow to "Application Only OAuth".
|
||||||
|
// Only ID and Secret are required to be set in client. Username and Password are ignored.
|
||||||
|
// The flow is described here: https://github.com/reddit-archive/reddit/wiki/OAuth2#application-only-oauth
|
||||||
|
func WithApplicationOnlyOAuth(o bool) Opt {
|
||||||
|
return func(c *Client) error {
|
||||||
|
c.applicationOnlyOAuth = o
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// FromEnv configures the client with values from environment variables.
|
// FromEnv configures the client with values from environment variables.
|
||||||
// Supported environment variables:
|
// Supported environment variables:
|
||||||
// GO_REDDIT_CLIENT_ID to set the client's id.
|
// GO_REDDIT_CLIENT_ID to set the client's id.
|
||||||
|
|
|
@ -1,13 +1,17 @@
|
||||||
package reddit
|
package reddit
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestWithHTTPClient(t *testing.T) {
|
func TestWithHTTPClient(t *testing.T) {
|
||||||
|
@ -52,6 +56,28 @@ func TestWithTokenURL(t *testing.T) {
|
||||||
require.Equal(t, tokenURL, c.TokenURL.String())
|
require.Equal(t, tokenURL, c.TokenURL.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type RequestInterceptor struct {
|
||||||
|
interceptedBody string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *RequestInterceptor) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||||
|
requestBody, _ := ioutil.ReadAll(r.Body)
|
||||||
|
t.interceptedBody = string(requestBody)
|
||||||
|
var body bytes.Buffer
|
||||||
|
body.WriteString(`{"access_token": "foobar", "expires_in": 3600, "scope": "*", "token_type": "bearer"}`)
|
||||||
|
return &http.Response{Status: "200 OK", StatusCode: 200, Body: io.NopCloser(&body)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWithApplicationOnlyOAuth(t *testing.T) {
|
||||||
|
requestInterceptor := &RequestInterceptor{}
|
||||||
|
c, err := NewClient(Credentials{ID: "id", Secret: "secret"}, WithApplicationOnlyOAuth(true), WithHTTPClient(&http.Client{Transport: requestInterceptor}))
|
||||||
|
require.NoError(t, err)
|
||||||
|
token, err := c.client.Transport.(*oauth2.Transport).Source.Token()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, token.AccessToken, "foobar")
|
||||||
|
require.Equal(t, "grant_type=client_credentials", requestInterceptor.interceptedBody)
|
||||||
|
}
|
||||||
|
|
||||||
func TestFromEnv(t *testing.T) {
|
func TestFromEnv(t *testing.T) {
|
||||||
os.Setenv("GO_REDDIT_CLIENT_ID", "id1")
|
os.Setenv("GO_REDDIT_CLIENT_ID", "id1")
|
||||||
defer os.Unsetenv("GO_REDDIT_CLIENT_ID")
|
defer os.Unsetenv("GO_REDDIT_CLIENT_ID")
|
||||||
|
|
|
@ -16,7 +16,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/go-querystring/query"
|
"github.com/google/go-querystring/query"
|
||||||
"golang.org/x/oauth2"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -96,7 +95,7 @@ type Client struct {
|
||||||
Widget *WidgetService
|
Widget *WidgetService
|
||||||
Wiki *WikiService
|
Wiki *WikiService
|
||||||
|
|
||||||
oauth2Transport *oauth2.Transport
|
applicationOnlyOAuth bool
|
||||||
|
|
||||||
onRequestCompleted RequestCompletionCallback
|
onRequestCompleted RequestCompletionCallback
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue