Add global DefaultClient (readonly), WithUserAgent option
Signed-off-by: Vartan Benohanian <vartanbeno@gmail.com>
This commit is contained in:
parent
aac3b48d6e
commit
37e712b334
11 changed files with 260 additions and 142 deletions
31
README.md
31
README.md
|
@ -3,7 +3,7 @@
|
||||||
[![Actions Status](https://github.com/vartanbeno/go-reddit/workflows/tests/badge.svg)](https://github.com/vartanbeno/go-reddit/actions)
|
[![Actions Status](https://github.com/vartanbeno/go-reddit/workflows/tests/badge.svg)](https://github.com/vartanbeno/go-reddit/actions)
|
||||||
[![Go Report Card](https://goreportcard.com/badge/github.com/vartanbeno/go-reddit)](https://goreportcard.com/report/github.com/vartanbeno/go-reddit)
|
[![Go Report Card](https://goreportcard.com/badge/github.com/vartanbeno/go-reddit)](https://goreportcard.com/report/github.com/vartanbeno/go-reddit)
|
||||||
|
|
||||||
Featured in [issue 327 of Golang Weekly](https://golangweekly.com/issues/327) 🎉
|
**Featured in [issue 327 of Golang Weekly](https://golangweekly.com/issues/327) 🎉**
|
||||||
|
|
||||||
go-reddit is a Go client library for accessing the Reddit API.
|
go-reddit is a Go client library for accessing the Reddit API.
|
||||||
|
|
||||||
|
@ -33,27 +33,36 @@ package main
|
||||||
import "github.com/vartanbeno/go-reddit/reddit"
|
import "github.com/vartanbeno/go-reddit/reddit"
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
credentials := &reddit.Credentials{
|
withCredentials := reddit.WithCredentials("id", "secret", "username", "password")
|
||||||
ID: "id",
|
client, _ := reddit.NewClient(withCredentials)
|
||||||
Secret: "secret",
|
|
||||||
Username: "username",
|
|
||||||
Password: "password",
|
|
||||||
}
|
|
||||||
client, _ := reddit.NewClient(credentials)
|
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
You can pass in a number of options to `NewClient` that further configure the client. For example, to use a custom HTTP client:
|
You can pass in a number of options to `NewClient` to further configure the client (see [reddit/reddit-options.go](reddit/reddit-options.go)). For example, to use a custom HTTP client:
|
||||||
|
|
||||||
```go
|
```go
|
||||||
httpClient := &http.Client{Timeout: time.Second * 30}
|
httpClient := &http.Client{Timeout: time.Second * 30}
|
||||||
client, _ := reddit.NewClient(credentials, reddit.WithHTTPClient(httpClient))
|
client, _ := reddit.NewClient(withCredentials, reddit.WithHTTPClient(httpClient))
|
||||||
```
|
```
|
||||||
|
|
||||||
If this option is not used, it will be set to `&http.Client{}` by default. More options are available in the [reddit/reddit-options.go](reddit/reddit-options.go) file.
|
### Read-Only Mode
|
||||||
|
|
||||||
|
The global `DefaultClient` variable is a valid, read-only client with limited access to the Reddit API, much like a logged out user. You can initialize your own via `NewReadonlyClient`:
|
||||||
|
|
||||||
|
```go
|
||||||
|
client, _ := reddit.NewReadonlyClient()
|
||||||
|
```
|
||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Configure the client from environment variables.</summary>
|
||||||
|
|
||||||
|
```go
|
||||||
|
client, _ := reddit.NewClient(reddit.FromEnv)
|
||||||
|
```
|
||||||
|
</details>
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary>Submit a comment.</summary>
|
<summary>Submit a comment.</summary>
|
||||||
|
|
||||||
|
|
|
@ -18,14 +18,7 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func run() (err error) {
|
func run() (err error) {
|
||||||
credentials := &reddit.Credentials{
|
client, err := reddit.NewReadonlyClient()
|
||||||
ID: "id",
|
|
||||||
Secret: "secret",
|
|
||||||
Username: "username",
|
|
||||||
Password: "password",
|
|
||||||
}
|
|
||||||
|
|
||||||
client, err := reddit.NewClient(credentials)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -35,13 +28,7 @@ func run() (err error) {
|
||||||
client.Subreddit.Search(ctx, "programming", nil)
|
client.Subreddit.Search(ctx, "programming", nil)
|
||||||
client.Subreddit.SearchNames(ctx, "monitor")
|
client.Subreddit.SearchNames(ctx, "monitor")
|
||||||
client.Subreddit.SearchPosts(ctx, "react", "webdev", nil)
|
client.Subreddit.SearchPosts(ctx, "react", "webdev", nil)
|
||||||
client.User.Posts(ctx, &reddit.ListUserOverviewOptions{
|
client.Subreddit.HotPosts(ctx, "golang", &reddit.ListOptions{Limit: 5})
|
||||||
ListOptions: reddit.ListOptions{
|
|
||||||
Limit: 50,
|
|
||||||
},
|
|
||||||
Sort: "top",
|
|
||||||
Time: "month",
|
|
||||||
})
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,24 +17,11 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func run() (err error) {
|
func run() (err error) {
|
||||||
credentials := &reddit.Credentials{
|
sr, _, err := reddit.DefaultClient.Subreddit.Get(ctx, "golang")
|
||||||
ID: "id",
|
|
||||||
Secret: "secret",
|
|
||||||
Username: "username",
|
|
||||||
Password: "password",
|
|
||||||
}
|
|
||||||
|
|
||||||
client, err := reddit.NewClient(credentials)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sr, _, err := client.Subreddit.Get(ctx, "golang")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("%s was created on %s and has %d subscribers.\n", sr.NamePrefixed, sr.Created.Local(), sr.Subscribers)
|
fmt.Printf("%s was created on %s and has %d subscribers.\n", sr.NamePrefixed, sr.Created.Local(), sr.Subscribers)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,22 +17,10 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func run() (err error) {
|
func run() (err error) {
|
||||||
credentials := &reddit.Credentials{
|
|
||||||
ID: "id",
|
|
||||||
Secret: "secret",
|
|
||||||
Username: "username",
|
|
||||||
Password: "password",
|
|
||||||
}
|
|
||||||
|
|
||||||
client, err := reddit.NewClient(credentials)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Let's get the top 200 posts of r/golang.
|
// Let's get the top 200 posts of r/golang.
|
||||||
// Reddit returns a maximum of 100 posts at a time,
|
// Reddit returns a maximum of 100 posts at a time,
|
||||||
// so we'll need to separate this into 2 requests.
|
// so we'll need to separate this into 2 requests.
|
||||||
result, _, err := client.Subreddit.TopPosts(ctx, "golang", &reddit.ListPostOptions{
|
result, _, err := reddit.DefaultClient.Subreddit.TopPosts(ctx, "golang", &reddit.ListPostOptions{
|
||||||
ListOptions: reddit.ListOptions{
|
ListOptions: reddit.ListOptions{
|
||||||
Limit: 100,
|
Limit: 100,
|
||||||
},
|
},
|
||||||
|
@ -48,7 +36,7 @@ func run() (err error) {
|
||||||
|
|
||||||
// The SetAfter option sets the id of an item that Reddit
|
// The SetAfter option sets the id of an item that Reddit
|
||||||
// will use as an anchor point for the returned listing.
|
// will use as an anchor point for the returned listing.
|
||||||
result, _, err = client.Subreddit.TopPosts(ctx, "golang", &reddit.ListPostOptions{
|
result, _, err = reddit.DefaultClient.Subreddit.TopPosts(ctx, "golang", &reddit.ListPostOptions{
|
||||||
ListOptions: reddit.ListOptions{
|
ListOptions: reddit.ListOptions{
|
||||||
Limit: 100,
|
Limit: 100,
|
||||||
After: result.After,
|
After: result.After,
|
||||||
|
|
|
@ -3,8 +3,9 @@ package main
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"os"
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/vartanbeno/go-reddit/reddit"
|
"github.com/vartanbeno/go-reddit/reddit"
|
||||||
|
@ -13,28 +14,16 @@ import (
|
||||||
var ctx = context.Background()
|
var ctx = context.Background()
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
if err := run(); err != nil {
|
sig := make(chan os.Signal, 1)
|
||||||
log.Fatal(err)
|
defer close(sig)
|
||||||
}
|
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
|
||||||
}
|
|
||||||
|
|
||||||
func run() (err error) {
|
posts, errs, stop := reddit.DefaultClient.Stream.Posts("AskReddit", reddit.StreamInterval(time.Second*3), reddit.StreamDiscardInitial)
|
||||||
credentials := &reddit.Credentials{
|
|
||||||
ID: "id",
|
|
||||||
Secret: "secret",
|
|
||||||
Username: "username",
|
|
||||||
Password: "password",
|
|
||||||
}
|
|
||||||
|
|
||||||
client, err := reddit.NewClient(credentials)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
posts, errs, stop := client.Stream.Posts("AskReddit", reddit.StreamInterval(time.Second*3), reddit.StreamDiscardInitial)
|
|
||||||
defer stop()
|
defer stop()
|
||||||
|
|
||||||
go func() {
|
timer := time.NewTimer(time.Minute)
|
||||||
|
defer timer.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case post, ok := <-posts:
|
case post, ok := <-posts:
|
||||||
|
@ -47,10 +36,13 @@ func run() (err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
fmt.Fprintf(os.Stderr, "Error! %v\n", err)
|
fmt.Fprintf(os.Stderr, "Error! %v\n", err)
|
||||||
}
|
case rcvSig, ok := <-sig:
|
||||||
}
|
if !ok {
|
||||||
}()
|
|
||||||
|
|
||||||
<-time.After(time.Minute)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
fmt.Printf("Stopping due to %s signal\n", rcvSig)
|
||||||
|
case <-timer.C:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -17,14 +17,9 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func run() (err error) {
|
func run() (err error) {
|
||||||
credentials := &reddit.Credentials{
|
withCredentials := reddit.WithCredentials("id", "secret", "username", "password")
|
||||||
ID: "id",
|
|
||||||
Secret: "secret",
|
|
||||||
Username: "username",
|
|
||||||
Password: "password",
|
|
||||||
}
|
|
||||||
|
|
||||||
client, err := reddit.NewClient(credentials)
|
client, err := reddit.NewClient(withCredentials)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -51,6 +46,5 @@ func run() (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("The link post is available at: %s\n", post.URL)
|
fmt.Printf("The link post is available at: %s\n", post.URL)
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,11 +4,23 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Opt is a configuration option to initialize a client.
|
// Opt is a configuration option to initialize a client.
|
||||||
type Opt func(*Client) error
|
type Opt func(*Client) error
|
||||||
|
|
||||||
|
// WithCredentials sets the credentials used to authenticate with the Reddit API.
|
||||||
|
func WithCredentials(id, secret, username, password string) Opt {
|
||||||
|
return func(c *Client) error {
|
||||||
|
c.ID = id
|
||||||
|
c.Secret = secret
|
||||||
|
c.Username = username
|
||||||
|
c.Password = password
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithHTTPClient sets the HTTP client which will be used to make requests.
|
// WithHTTPClient sets the HTTP client which will be used to make requests.
|
||||||
func WithHTTPClient(httpClient *http.Client) Opt {
|
func WithHTTPClient(httpClient *http.Client) Opt {
|
||||||
return func(c *Client) error {
|
return func(c *Client) error {
|
||||||
|
@ -20,6 +32,16 @@ func WithHTTPClient(httpClient *http.Client) Opt {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithUserAgent sets the User-Agent header for requests made with the client.
|
||||||
|
// Reddit recommends the following format for the user agent:
|
||||||
|
// <platform>:<app ID>:<version string> (by /u/<reddit username>)
|
||||||
|
func WithUserAgent(ua string) Opt {
|
||||||
|
return func(c *Client) error {
|
||||||
|
c.userAgent = ua
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithBaseURL sets the base URL for the client to make requests to.
|
// WithBaseURL sets the base URL for the client to make requests to.
|
||||||
func WithBaseURL(u string) Opt {
|
func WithBaseURL(u string) Opt {
|
||||||
return func(c *Client) error {
|
return func(c *Client) error {
|
||||||
|
@ -43,3 +65,25 @@ func WithTokenURL(u string) Opt {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FromEnv configures the client with values from environment variables.
|
||||||
|
// Supported environment variables:
|
||||||
|
// GO_REDDIT_CLIENT_ID to set the client's id.
|
||||||
|
// GO_REDDIT_CLIENT_SECRET to set the client's secret.
|
||||||
|
// GO_REDDIT_CLIENT_USERNAME to set the client's username.
|
||||||
|
// GO_REDDIT_CLIENT_PASSWORD to set the client's password.
|
||||||
|
func FromEnv(c *Client) error {
|
||||||
|
if v := os.Getenv("GO_REDDIT_CLIENT_ID"); v != "" {
|
||||||
|
c.ID = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("GO_REDDIT_CLIENT_SECRET"); v != "" {
|
||||||
|
c.Secret = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("GO_REDDIT_CLIENT_USERNAME"); v != "" {
|
||||||
|
c.Username = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("GO_REDDIT_CLIENT_PASSWORD"); v != "" {
|
||||||
|
c.Password = v
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -1,41 +1,83 @@
|
||||||
package reddit
|
package reddit
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestWithCredentials(t *testing.T) {
|
||||||
|
c, err := NewClient(WithCredentials("id1", "secret1", "username1", "password1"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "id1", c.ID)
|
||||||
|
require.Equal(t, "secret1", c.Secret)
|
||||||
|
require.Equal(t, "username1", c.Username)
|
||||||
|
require.Equal(t, "password1", c.Password)
|
||||||
|
}
|
||||||
|
|
||||||
func TestWithHTTPClient(t *testing.T) {
|
func TestWithHTTPClient(t *testing.T) {
|
||||||
_, err := NewClient(&Credentials{}, WithHTTPClient(nil))
|
_, err := NewClient(WithHTTPClient(nil))
|
||||||
require.EqualError(t, err, "httpClient: cannot be nil")
|
require.EqualError(t, err, "httpClient: cannot be nil")
|
||||||
|
|
||||||
_, err = NewClient(&Credentials{}, WithHTTPClient(&http.Client{}))
|
_, err = NewClient(WithHTTPClient(&http.Client{}))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWithUserAgent(t *testing.T) {
|
||||||
|
c, err := NewClient(WithUserAgent("test"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "test", c.UserAgent())
|
||||||
|
|
||||||
|
c, err = NewClient(WithUserAgent(""))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, fmt.Sprintf("golang:%s:v%s", libraryName, libraryVersion), c.UserAgent())
|
||||||
|
}
|
||||||
|
|
||||||
func TestWithBaseURL(t *testing.T) {
|
func TestWithBaseURL(t *testing.T) {
|
||||||
c, err := NewClient(&Credentials{}, WithBaseURL(":"))
|
c, err := NewClient(WithBaseURL(":"))
|
||||||
urlErr, ok := err.(*url.Error)
|
urlErr, ok := err.(*url.Error)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
require.Equal(t, "parse", urlErr.Op)
|
require.Equal(t, "parse", urlErr.Op)
|
||||||
|
|
||||||
baseURL := "http://localhost:8080"
|
baseURL := "http://localhost:8080"
|
||||||
c, err = NewClient(&Credentials{}, WithBaseURL(baseURL))
|
c, err = NewClient(WithBaseURL(baseURL))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, baseURL, c.BaseURL.String())
|
require.Equal(t, baseURL, c.BaseURL.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWithTokenURL(t *testing.T) {
|
func TestWithTokenURL(t *testing.T) {
|
||||||
c, err := NewClient(&Credentials{}, WithTokenURL(":"))
|
c, err := NewClient(WithTokenURL(":"))
|
||||||
urlErr, ok := err.(*url.Error)
|
urlErr, ok := err.(*url.Error)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
require.Equal(t, "parse", urlErr.Op)
|
require.Equal(t, "parse", urlErr.Op)
|
||||||
|
|
||||||
tokenURL := "http://localhost:8080/api/v1/access_token"
|
tokenURL := "http://localhost:8080/api/v1/access_token"
|
||||||
c, err = NewClient(&Credentials{}, WithTokenURL(tokenURL))
|
c, err = NewClient(WithTokenURL(tokenURL))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, tokenURL, c.TokenURL.String())
|
require.Equal(t, tokenURL, c.TokenURL.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFromEnv(t *testing.T) {
|
||||||
|
os.Setenv("GO_REDDIT_CLIENT_ID", "id1")
|
||||||
|
defer os.Unsetenv("GO_REDDIT_CLIENT_ID")
|
||||||
|
|
||||||
|
os.Setenv("GO_REDDIT_CLIENT_SECRET", "secret1")
|
||||||
|
defer os.Unsetenv("GO_REDDIT_CLIENT_SECRET")
|
||||||
|
|
||||||
|
os.Setenv("GO_REDDIT_CLIENT_USERNAME", "username1")
|
||||||
|
defer os.Unsetenv("GO_REDDIT_CLIENT_USERNAME")
|
||||||
|
|
||||||
|
os.Setenv("GO_REDDIT_CLIENT_PASSWORD", "password1")
|
||||||
|
defer os.Unsetenv("GO_REDDIT_CLIENT_PASSWORD")
|
||||||
|
|
||||||
|
c, err := NewClient(FromEnv)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "id1", c.ID)
|
||||||
|
require.Equal(t, "secret1", c.Secret)
|
||||||
|
require.Equal(t, "username1", c.Username)
|
||||||
|
require.Equal(t, "password1", c.Password)
|
||||||
|
}
|
||||||
|
|
|
@ -18,6 +18,8 @@ func cloneRequest(r *http.Request) *http.Request {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sets the User-Agent header for requests.
|
// Sets the User-Agent header for requests.
|
||||||
|
// We need to set a custom user agent because using the one set by the
|
||||||
|
// stdlib gives us 429 Too Many Requests responses from the Reddit API.
|
||||||
type userAgentTransport struct {
|
type userAgentTransport struct {
|
||||||
userAgent string
|
userAgent string
|
||||||
Base http.RoundTripper
|
Base http.RoundTripper
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
@ -22,6 +21,7 @@ const (
|
||||||
libraryVersion = "1.0.0"
|
libraryVersion = "1.0.0"
|
||||||
|
|
||||||
defaultBaseURL = "https://oauth.reddit.com"
|
defaultBaseURL = "https://oauth.reddit.com"
|
||||||
|
defaultBaseURLReadonly = "https://reddit.com"
|
||||||
defaultTokenURL = "https://www.reddit.com/api/v1/access_token"
|
defaultTokenURL = "https://www.reddit.com/api/v1/access_token"
|
||||||
|
|
||||||
mediaTypeJSON = "application/json"
|
mediaTypeJSON = "application/json"
|
||||||
|
@ -32,6 +32,9 @@ const (
|
||||||
headerUserAgent = "User-Agent"
|
headerUserAgent = "User-Agent"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// DefaultClient is a readonly client with limited access to the Reddit API.
|
||||||
|
var DefaultClient, _ = NewReadonlyClient()
|
||||||
|
|
||||||
// RequestCompletionCallback defines the type of the request callback function.
|
// RequestCompletionCallback defines the type of the request callback function.
|
||||||
type RequestCompletionCallback func(*http.Request, *http.Response)
|
type RequestCompletionCallback func(*http.Request, *http.Response)
|
||||||
|
|
||||||
|
@ -113,16 +116,9 @@ func newClient() *Client {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient returns a new Reddit API client.
|
// NewClient returns a new Reddit API client.
|
||||||
func NewClient(creds *Credentials, opts ...Opt) (*Client, error) {
|
// Use an Opt to configure the client credentials, such as WithCredentials or FromEnv.
|
||||||
if creds == nil {
|
func NewClient(opts ...Opt) (*Client, error) {
|
||||||
return nil, errors.New("must provide credentials to initialize *reddit.Client")
|
|
||||||
}
|
|
||||||
|
|
||||||
client := newClient()
|
client := newClient()
|
||||||
client.ID = creds.ID
|
|
||||||
client.Secret = creds.Secret
|
|
||||||
client.Username = creds.Username
|
|
||||||
client.Password = creds.Password
|
|
||||||
|
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
if err := opt(client); err != nil {
|
if err := opt(client); err != nil {
|
||||||
|
@ -134,20 +130,53 @@ func NewClient(creds *Credentials, opts ...Opt) (*Client, error) {
|
||||||
client.client = &http.Client{}
|
client.client = &http.Client{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// We need to set a custom user agent, because using the one set by the
|
|
||||||
// stdlib gives us 429 Too Many Requests responses from the Reddit API.
|
|
||||||
userAgentTransport := &userAgentTransport{
|
userAgentTransport := &userAgentTransport{
|
||||||
userAgent: client.UserAgent(),
|
userAgent: client.UserAgent(),
|
||||||
Base: client.client.Transport,
|
Base: client.client.Transport,
|
||||||
}
|
}
|
||||||
client.client.Transport = userAgentTransport
|
client.client.Transport = userAgentTransport
|
||||||
|
|
||||||
|
if client.client.CheckRedirect == nil {
|
||||||
|
client.client.CheckRedirect = client.redirect
|
||||||
|
}
|
||||||
|
|
||||||
|
oauthTransport := oauthTransport(client)
|
||||||
|
client.client.Transport = oauthTransport
|
||||||
|
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewReadonlyClient returns a new read-only Reddit API client.
|
||||||
|
// The client will have limited access to the Reddit API.
|
||||||
|
// Options that modify credentials (such as WithCredentials or FromEnv) won't have any effect on this client.
|
||||||
|
func NewReadonlyClient(opts ...Opt) (*Client, error) {
|
||||||
|
client := newClient()
|
||||||
|
client.BaseURL, _ = url.Parse(defaultBaseURLReadonly)
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
if err := opt(client); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if client.client == nil {
|
||||||
|
client.client = &http.Client{}
|
||||||
|
}
|
||||||
|
|
||||||
|
userAgentTransport := &userAgentTransport{
|
||||||
|
userAgent: client.UserAgent(),
|
||||||
|
Base: client.client.Transport,
|
||||||
|
}
|
||||||
|
client.client.Transport = userAgentTransport
|
||||||
|
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
// todo...
|
// todo...
|
||||||
// Some endpoints (notably the ones to get random subreddits/posts) redirect to a
|
// Some endpoints (notably the ones to get random subreddits/posts) redirect to a
|
||||||
// reddit.com url, which returns a 403 Forbidden for some reason, unless the url's
|
// reddit.com url, which returns a 403 Forbidden for some reason, unless the url's
|
||||||
// host is changed to oauth.reddit.com
|
// host is changed to oauth.reddit.com
|
||||||
if client.client.CheckRedirect == nil {
|
func (c *Client) redirect(req *http.Request, via []*http.Request) error {
|
||||||
client.client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
|
||||||
redirectURL := req.URL.String()
|
redirectURL := req.URL.String()
|
||||||
redirectURL = strings.Replace(redirectURL, "https://www.reddit.com", defaultBaseURL, 1)
|
redirectURL = strings.Replace(redirectURL, "https://www.reddit.com", defaultBaseURL, 1)
|
||||||
|
|
||||||
|
@ -159,18 +188,29 @@ func NewClient(creds *Credentials, opts ...Opt) (*Client, error) {
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The readonly Reddit url needs .json at the end of its path to return responses in JSON instead of HTML.
|
||||||
|
func (c *Client) appendJSONExtensionToRequestPath(req *http.Request) {
|
||||||
|
readonlyURL, err := url.Parse(defaultBaseURLReadonly)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
oauthTransport := oauthTransport(client)
|
if req.URL.Host != readonlyURL.Host {
|
||||||
client.client.Transport = oauthTransport
|
return
|
||||||
|
}
|
||||||
|
|
||||||
return client, nil
|
req.URL.Path += ".json"
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserAgent returns the client's user agent.
|
// UserAgent returns the client's user agent.
|
||||||
func (c *Client) UserAgent() string {
|
func (c *Client) UserAgent() string {
|
||||||
if c.userAgent == "" {
|
if c.userAgent == "" {
|
||||||
c.userAgent = fmt.Sprintf("golang:%s:v%s (by /u/%s)", libraryName, libraryVersion, c.Username)
|
userAgent := fmt.Sprintf("golang:%s:v%s", libraryName, libraryVersion)
|
||||||
|
if c.Username != "" {
|
||||||
|
userAgent += fmt.Sprintf(" (by /u/%s)", c.Username)
|
||||||
|
}
|
||||||
|
c.userAgent = userAgent
|
||||||
}
|
}
|
||||||
return c.userAgent
|
return c.userAgent
|
||||||
}
|
}
|
||||||
|
@ -198,6 +238,7 @@ func (c *Client) NewRequest(method string, path string, body interface{}) (*http
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.appendJSONExtensionToRequestPath(req)
|
||||||
req.Header.Add(headerContentType, mediaTypeJSON)
|
req.Header.Add(headerContentType, mediaTypeJSON)
|
||||||
req.Header.Add(headerAccept, mediaTypeJSON)
|
req.Header.Add(headerAccept, mediaTypeJSON)
|
||||||
|
|
||||||
|
@ -218,6 +259,7 @@ func (c *Client) NewRequestWithForm(method string, path string, form url.Values)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.appendJSONExtensionToRequestPath(req)
|
||||||
req.Header.Add(headerContentType, mediaTypeForm)
|
req.Header.Add(headerContentType, mediaTypeForm)
|
||||||
req.Header.Add(headerAccept, mediaTypeJSON)
|
req.Header.Add(headerAccept, mediaTypeJSON)
|
||||||
|
|
||||||
|
|
|
@ -32,7 +32,7 @@ func setup() (*Client, *http.ServeMux, func()) {
|
||||||
})
|
})
|
||||||
|
|
||||||
client, _ := NewClient(
|
client, _ := NewClient(
|
||||||
&Credentials{"id1", "secret1", "user1", "password1"},
|
WithCredentials("id1", "secret1", "user1", "password1"),
|
||||||
WithBaseURL(server.URL),
|
WithBaseURL(server.URL),
|
||||||
WithTokenURL(server.URL+"/api/v1/access_token"),
|
WithTokenURL(server.URL+"/api/v1/access_token"),
|
||||||
)
|
)
|
||||||
|
@ -82,7 +82,7 @@ func testClientServices(t *testing.T, c *Client) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func testClientDefaultUserAgent(t *testing.T, c *Client) {
|
func testClientDefaultUserAgent(t *testing.T, c *Client) {
|
||||||
expectedUserAgent := fmt.Sprintf("golang:%s:v%s (by /u/)", libraryName, libraryVersion)
|
expectedUserAgent := fmt.Sprintf("golang:%s:v%s", libraryName, libraryVersion)
|
||||||
require.Equal(t, expectedUserAgent, c.UserAgent())
|
require.Equal(t, expectedUserAgent, c.UserAgent())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -92,23 +92,54 @@ func testClientDefaults(t *testing.T, c *Client) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewClient(t *testing.T) {
|
func TestNewClient(t *testing.T) {
|
||||||
c, err := NewClient(&Credentials{})
|
c, err := NewClient()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
testClientDefaults(t, c)
|
testClientDefaults(t, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewClient_Error(t *testing.T) {
|
func TestNewClient_Error(t *testing.T) {
|
||||||
_, err := NewClient(nil)
|
_, err := NewClient()
|
||||||
require.EqualError(t, err, "must provide credentials to initialize *reddit.Client")
|
require.NoError(t, err)
|
||||||
|
|
||||||
errorOpt := func(c *Client) error {
|
errorOpt := func(c *Client) error {
|
||||||
return errors.New("foo")
|
return errors.New("foo")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = NewClient(&Credentials{}, errorOpt)
|
_, err = NewClient(errorOpt)
|
||||||
require.EqualError(t, err, "foo")
|
require.EqualError(t, err, "foo")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewReadonlyClient(t *testing.T) {
|
||||||
|
c, err := NewReadonlyClient()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, c.BaseURL.String(), defaultBaseURLReadonly)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewReadonlyClient_Error(t *testing.T) {
|
||||||
|
_, err := NewReadonlyClient()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
errorOpt := func(c *Client) error {
|
||||||
|
return errors.New("foo")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = NewReadonlyClient(errorOpt)
|
||||||
|
require.EqualError(t, err, "foo")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultClient(t *testing.T) {
|
||||||
|
require.NotNil(t, DefaultClient)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClient_Readonly_NewRequest(t *testing.T) {
|
||||||
|
c, err := NewReadonlyClient()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req, err := c.NewRequest(http.MethodGet, "r/golang", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, defaultBaseURLReadonly+"/r/golang.json", req.URL.String())
|
||||||
|
}
|
||||||
|
|
||||||
func TestClient_OnRequestComplemented(t *testing.T) {
|
func TestClient_OnRequestComplemented(t *testing.T) {
|
||||||
client, mux, teardown := setup()
|
client, mux, teardown := setup()
|
||||||
defer teardown()
|
defer teardown()
|
||||||
|
|
Loading…
Reference in a new issue