Set HTTP client via option. Update readme, Makefile, go.sum

Signed-off-by: Vartan Benohanian <vartanbeno@gmail.com>
This commit is contained in:
Vartan Benohanian 2020-08-26 23:13:34 -04:00
parent e1ce8a7a14
commit c0f0cf8be6
12 changed files with 90 additions and 58 deletions

View File

@ -12,7 +12,7 @@ define log
endef
.PHONY: all
all: lint fmt vet test test-coverage build
all: lint fmt vet test test-coverage
.PHONY: usage
usage:

View File

@ -37,19 +37,30 @@ func main() {
Username: "username",
Password: "password",
}
client, _ := reddit.NewClient(nil, credentials)
client, _ := reddit.NewClient(credentials)
}
```
The first argument (the one set to `nil`) is of type `*http.Client`. It will be used to make the requests. If nil, it will be set to `&http.Client{}`.
You can pass in a number of options to `NewClient` that further configure the client. For example, to use a custom HTTP client:
```go
httpClient := &http.Client{Timeout: time.Second * 30}
client, _ := reddit.NewClient(credentials, reddit.WithHTTPClient(httpClient))
```
If this option is not used, it will be set to `&http.Client{}` by default. More options are available in the [reddit/reddit-options.go](reddit/reddit-options.go) file.
## Examples
<details>
<summary>Configure the client from environment variables.</summary>
<summary>Submit a comment.</summary>
```go
client, _ := reddit.NewClient(nil, reddit.FromEnv)
comment, _, err := client.Comment.Submit(context.Background(), "t3_postid", "comment body")
if err != nil {
return err
}
fmt.Printf("Comment permalink: %s\n", comment.Permalink)
```
</details>
@ -59,7 +70,6 @@ client, _ := reddit.NewClient(nil, reddit.FromEnv)
```go
_, err := client.Post.Upvote(context.Background(), "t3_postid")
if err != nil {
fmt.Printf("Something bad happened: %v\n", err)
return err
}
```
@ -76,7 +86,6 @@ result, _, err := client.Subreddit.TopPosts(context.Background(), "golang", &red
Time: "all",
})
if err != nil {
fmt.Printf("Something bad happened: %v\n", err)
return err
}
fmt.Printf("Received %d posts.\n", len(result.Posts))
@ -87,7 +96,7 @@ More examples are available in the [examples](examples) folder.
## Design
The package design and structure are heavily inspired from [Google's GitHub API client](https://github.com/google/go-github) and [DigitalOcean's API client](https://github.com/digitalocean/godo).
The package design is heavily inspired from [Google's GitHub API client](https://github.com/google/go-github) and [DigitalOcean's API client](https://github.com/digitalocean/godo).
## License

View File

@ -25,7 +25,7 @@ func run() (err error) {
Password: "password",
}
client, err := reddit.NewClient(nil, credentials)
client, err := reddit.NewClient(credentials)
if err != nil {
return
}

View File

@ -24,7 +24,7 @@ func run() (err error) {
Password: "password",
}
client, err := reddit.NewClient(nil, credentials)
client, err := reddit.NewClient(credentials)
if err != nil {
return
}

View File

@ -24,7 +24,7 @@ func run() (err error) {
Password: "password",
}
client, err := reddit.NewClient(nil, credentials)
client, err := reddit.NewClient(credentials)
if err != nil {
return
}

View File

@ -26,7 +26,7 @@ func run() (err error) {
Password: "password",
}
client, err := reddit.NewClient(nil, credentials)
client, err := reddit.NewClient(credentials)
if err != nil {
return
}

View File

@ -24,7 +24,7 @@ func run() (err error) {
Password: "password",
}
client, err := reddit.NewClient(nil, credentials)
client, err := reddit.NewClient(credentials)
if err != nil {
return
}

6
go.sum
View File

@ -10,21 +10,17 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20200425230154-ff2c4b7c35a0 h1:Jcxah/M+oLZ/R4/z5RzfPzGbPXnVDPkEDtf2JnuxN+U=
golang.org/x/net v0.0.0-20200425230154-ff2c4b7c35a0/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d h1:TzXSXBo42m9gQenoE3b9BGiEpg5IG2JkU5FkPIawgtw=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

View File

@ -1,6 +1,8 @@
package reddit
import (
"errors"
"net/http"
"net/url"
"os"
)
@ -8,8 +10,18 @@ import (
// Opt is a configuration option to initialize a client.
type Opt func(*Client) error
// WithHTTPClient sets the HTTP client which will be used to make requests.
func WithHTTPClient(httpClient *http.Client) Opt {
return func(c *Client) error {
if httpClient == nil {
return errors.New("httpClient: cannot be nil")
}
c.client = httpClient
return nil
}
}
// FromEnv configures the client with values from environment variables.
//
// Supported environment variables:
// GO_REDDIT_CLIENT_ID to set the client's id.
// GO_REDDIT_CLIENT_SECRET to set the client's secret.
@ -19,19 +31,15 @@ func FromEnv(c *Client) error {
if v := os.Getenv("GO_REDDIT_CLIENT_ID"); v != "" {
c.ID = v
}
if v := os.Getenv("GO_REDDIT_CLIENT_SECRET"); v != "" {
c.Secret = v
}
if v := os.Getenv("GO_REDDIT_CLIENT_USERNAME"); v != "" {
c.Username = v
}
if v := os.Getenv("GO_REDDIT_CLIENT_PASSWORD"); v != "" {
c.Password = v
}
return nil
}
@ -42,7 +50,6 @@ func WithBaseURL(u string) Opt {
if err != nil {
return err
}
c.BaseURL = url
return nil
}
@ -55,7 +62,6 @@ func WithTokenURL(u string) Opt {
if err != nil {
return err
}
c.TokenURL = url
return nil
}

View File

@ -1,12 +1,22 @@
package reddit
import (
"net/http"
"net/url"
"os"
"testing"
"github.com/stretchr/testify/require"
)
func TestWithHTTPClient(t *testing.T) {
_, err := NewClient(nil, WithHTTPClient(nil))
require.EqualError(t, err, "httpClient: cannot be nil")
_, err = NewClient(nil, WithHTTPClient(&http.Client{}))
require.NoError(t, err)
}
func TestFromEnv(t *testing.T) {
os.Setenv("GO_REDDIT_CLIENT_ID", "id1")
defer os.Unsetenv("GO_REDDIT_CLIENT_ID")
@ -20,7 +30,7 @@ func TestFromEnv(t *testing.T) {
os.Setenv("GO_REDDIT_CLIENT_PASSWORD", "password1")
defer os.Unsetenv("GO_REDDIT_CLIENT_PASSWORD")
c, err := NewClient(nil, nil, FromEnv)
c, err := NewClient(nil, FromEnv)
require.NoError(t, err)
type values struct {
@ -33,15 +43,25 @@ func TestFromEnv(t *testing.T) {
}
func TestWithBaseURL(t *testing.T) {
c, err := NewClient(nil, WithBaseURL(":"))
urlErr, ok := err.(*url.Error)
require.True(t, ok)
require.Equal(t, "parse", urlErr.Op)
baseURL := "http://localhost:8080"
c, err := NewClient(nil, nil, WithBaseURL(baseURL))
c, err = NewClient(nil, WithBaseURL(baseURL))
require.NoError(t, err)
require.Equal(t, baseURL, c.BaseURL.String())
}
func TestWithTokenURL(t *testing.T) {
c, err := NewClient(nil, WithTokenURL(":"))
urlErr, ok := err.(*url.Error)
require.True(t, ok)
require.Equal(t, "parse", urlErr.Op)
tokenURL := "http://localhost:8080/api/v1/access_token"
c, err := NewClient(nil, nil, WithTokenURL(tokenURL))
c, err = NewClient(nil, WithTokenURL(tokenURL))
require.NoError(t, err)
require.Equal(t, tokenURL, c.TokenURL.String())
}

View File

@ -124,32 +124,11 @@ func (c *Client) OnRequestCompleted(rc RequestCompletionCallback) {
c.onRequestCompleted = rc
}
func newClient(httpClient *http.Client) *Client {
if httpClient == nil {
httpClient = &http.Client{}
}
// todo...
// Some endpoints (notably the ones to get random subreddits/posts) redirect to a
// reddit.com url, which returns a 403 Forbidden for some reason, unless the url's
// host is changed to oauth.reddit.com
httpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
redirectURL := req.URL.String()
redirectURL = strings.Replace(redirectURL, "https://www.reddit.com", defaultBaseURL, 1)
reqURL, err := url.Parse(redirectURL)
if err != nil {
return err
}
req.URL = reqURL
return nil
}
func newClient() *Client {
baseURL, _ := url.Parse(defaultBaseURL)
tokenURL, _ := url.Parse(defaultTokenURL)
client := &Client{client: httpClient, BaseURL: baseURL, TokenURL: tokenURL}
client := &Client{BaseURL: baseURL, TokenURL: tokenURL}
client.Account = &AccountService{client: client}
client.Collection = &CollectionService{client: client}
@ -171,10 +150,9 @@ func newClient(httpClient *http.Client) *Client {
return client
}
// NewClient returns a new Reddit API client. If a nil httpClient is provided,
// a new http.Client will be used.
func NewClient(httpClient *http.Client, creds *Credentials, opts ...Opt) (*Client, error) {
client := newClient(httpClient)
// NewClient returns a new Reddit API client.
func NewClient(creds *Credentials, opts ...Opt) (*Client, error) {
client := newClient()
for _, opt := range opts {
if err := opt(client); err != nil {
@ -182,6 +160,29 @@ func NewClient(httpClient *http.Client, creds *Credentials, opts ...Opt) (*Clien
}
}
if client.client == nil {
client.client = &http.Client{}
}
// todo...
// Some endpoints (notably the ones to get random subreddits/posts) redirect to a
// reddit.com url, which returns a 403 Forbidden for some reason, unless the url's
// host is changed to oauth.reddit.com
if client.client.CheckRedirect == nil {
client.client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
redirectURL := req.URL.String()
redirectURL = strings.Replace(redirectURL, "https://www.reddit.com", defaultBaseURL, 1)
reqURL, err := url.Parse(redirectURL)
if err != nil {
return err
}
req.URL = reqURL
return nil
}
}
if creds != nil {
client.ID = creds.ID
client.Secret = creds.Secret

View File

@ -31,7 +31,7 @@ func setup() (*Client, *http.ServeMux, func()) {
fmt.Fprint(w, response)
})
client, _ := NewClient(nil,
client, _ := NewClient(
&Credentials{"id1", "secret1", "user1", "password1"},
WithBaseURL(server.URL),
WithTokenURL(server.URL+"/api/v1/access_token"),
@ -92,7 +92,7 @@ func testClientDefaults(t *testing.T, c *Client) {
}
func TestNewClient(t *testing.T) {
c, err := NewClient(nil, nil)
c, err := NewClient(nil)
require.NoError(t, err)
testClientDefaults(t, c)
}
@ -102,7 +102,7 @@ func TestNewClient_Error(t *testing.T) {
return errors.New("foo")
}
_, err := NewClient(nil, nil, errorOpt)
_, err := NewClient(nil, errorOpt)
require.EqualError(t, err, "foo")
}