diff --git a/README.md b/README.md
index b22317f..25f313b 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
[![Actions Status](https://github.com/vartanbeno/go-reddit/workflows/tests/badge.svg)](https://github.com/vartanbeno/go-reddit/actions)
[![Go Report Card](https://goreportcard.com/badge/github.com/vartanbeno/go-reddit)](https://goreportcard.com/report/github.com/vartanbeno/go-reddit)
-Featured in [issue 327 of Golang Weekly](https://golangweekly.com/issues/327) 🎉
+**Featured in [issue 327 of Golang Weekly](https://golangweekly.com/issues/327) 🎉**
go-reddit is a Go client library for accessing the Reddit API.
@@ -33,27 +33,36 @@ package main
import "github.com/vartanbeno/go-reddit/reddit"
func main() {
- credentials := &reddit.Credentials{
- ID: "id",
- Secret: "secret",
- Username: "username",
- Password: "password",
- }
- client, _ := reddit.NewClient(credentials)
+ withCredentials := reddit.WithCredentials("id", "secret", "username", "password")
+ client, _ := reddit.NewClient(withCredentials)
}
```
-You can pass in a number of options to `NewClient` that further configure the client. For example, to use a custom HTTP client:
+You can pass in a number of options to `NewClient` to further configure the client (see [reddit/reddit-options.go](reddit/reddit-options.go)). For example, to use a custom HTTP client:
```go
httpClient := &http.Client{Timeout: time.Second * 30}
-client, _ := reddit.NewClient(credentials, reddit.WithHTTPClient(httpClient))
+client, _ := reddit.NewClient(withCredentials, reddit.WithHTTPClient(httpClient))
```
-If this option is not used, it will be set to `&http.Client{}` by default. More options are available in the [reddit/reddit-options.go](reddit/reddit-options.go) file.
+### Read-Only Mode
+
+The global `DefaultClient` variable is a valid, read-only client with limited access to the Reddit API, much like a logged out user. You can initialize your own via `NewReadonlyClient`:
+
+```go
+client, _ := reddit.NewReadonlyClient()
+```
## Examples
+
+ Configure the client from environment variables.
+
+```go
+client, _ := reddit.NewClient(reddit.FromEnv)
+```
+
+
Submit a comment.
diff --git a/examples/client-on-request-completed/main.go b/examples/client-on-request-completed/main.go
index 19601f2..c57c058 100644
--- a/examples/client-on-request-completed/main.go
+++ b/examples/client-on-request-completed/main.go
@@ -18,14 +18,7 @@ func main() {
}
func run() (err error) {
- credentials := &reddit.Credentials{
- ID: "id",
- Secret: "secret",
- Username: "username",
- Password: "password",
- }
-
- client, err := reddit.NewClient(credentials)
+ client, err := reddit.NewReadonlyClient()
if err != nil {
return
}
@@ -35,13 +28,7 @@ func run() (err error) {
client.Subreddit.Search(ctx, "programming", nil)
client.Subreddit.SearchNames(ctx, "monitor")
client.Subreddit.SearchPosts(ctx, "react", "webdev", nil)
- client.User.Posts(ctx, &reddit.ListUserOverviewOptions{
- ListOptions: reddit.ListOptions{
- Limit: 50,
- },
- Sort: "top",
- Time: "month",
- })
+ client.Subreddit.HotPosts(ctx, "golang", &reddit.ListOptions{Limit: 5})
return
}
diff --git a/examples/get-subreddit/main.go b/examples/get-subreddit/main.go
index cf12205..7d4ad6b 100644
--- a/examples/get-subreddit/main.go
+++ b/examples/get-subreddit/main.go
@@ -17,24 +17,11 @@ func main() {
}
func run() (err error) {
- credentials := &reddit.Credentials{
- ID: "id",
- Secret: "secret",
- Username: "username",
- Password: "password",
- }
-
- client, err := reddit.NewClient(credentials)
- if err != nil {
- return
- }
-
- sr, _, err := client.Subreddit.Get(ctx, "golang")
+ sr, _, err := reddit.DefaultClient.Subreddit.Get(ctx, "golang")
if err != nil {
return
}
fmt.Printf("%s was created on %s and has %d subscribers.\n", sr.NamePrefixed, sr.Created.Local(), sr.Subscribers)
-
return
}
diff --git a/examples/get-top-posts/main.go b/examples/get-top-posts/main.go
index 4b36962..bc876b7 100644
--- a/examples/get-top-posts/main.go
+++ b/examples/get-top-posts/main.go
@@ -17,22 +17,10 @@ func main() {
}
func run() (err error) {
- credentials := &reddit.Credentials{
- ID: "id",
- Secret: "secret",
- Username: "username",
- Password: "password",
- }
-
- client, err := reddit.NewClient(credentials)
- if err != nil {
- return
- }
-
// Let's get the top 200 posts of r/golang.
// Reddit returns a maximum of 100 posts at a time,
// so we'll need to separate this into 2 requests.
- result, _, err := client.Subreddit.TopPosts(ctx, "golang", &reddit.ListPostOptions{
+ result, _, err := reddit.DefaultClient.Subreddit.TopPosts(ctx, "golang", &reddit.ListPostOptions{
ListOptions: reddit.ListOptions{
Limit: 100,
},
@@ -48,7 +36,7 @@ func run() (err error) {
// The SetAfter option sets the id of an item that Reddit
// will use as an anchor point for the returned listing.
- result, _, err = client.Subreddit.TopPosts(ctx, "golang", &reddit.ListPostOptions{
+ result, _, err = reddit.DefaultClient.Subreddit.TopPosts(ctx, "golang", &reddit.ListPostOptions{
ListOptions: reddit.ListOptions{
Limit: 100,
After: result.After,
diff --git a/examples/stream-posts/main.go b/examples/stream-posts/main.go
index 670a909..9532861 100644
--- a/examples/stream-posts/main.go
+++ b/examples/stream-posts/main.go
@@ -3,8 +3,9 @@ package main
import (
"context"
"fmt"
- "log"
"os"
+ "os/signal"
+ "syscall"
"time"
"github.com/vartanbeno/go-reddit/reddit"
@@ -13,44 +14,35 @@ import (
var ctx = context.Background()
func main() {
- if err := run(); err != nil {
- log.Fatal(err)
- }
-}
+ sig := make(chan os.Signal, 1)
+ defer close(sig)
+ signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
-func run() (err error) {
- credentials := &reddit.Credentials{
- ID: "id",
- Secret: "secret",
- Username: "username",
- Password: "password",
- }
-
- client, err := reddit.NewClient(credentials)
- if err != nil {
- return
- }
-
- posts, errs, stop := client.Stream.Posts("AskReddit", reddit.StreamInterval(time.Second*3), reddit.StreamDiscardInitial)
+ posts, errs, stop := reddit.DefaultClient.Stream.Posts("AskReddit", reddit.StreamInterval(time.Second*3), reddit.StreamDiscardInitial)
defer stop()
- go func() {
- for {
- select {
- case post, ok := <-posts:
- if !ok {
- return
- }
- fmt.Printf("Received post: %s\n", post.Title)
- case err, ok := <-errs:
- if !ok {
- return
- }
- fmt.Fprintf(os.Stderr, "Error! %v\n", err)
- }
- }
- }()
+ timer := time.NewTimer(time.Minute)
+ defer timer.Stop()
- <-time.After(time.Minute)
- return
+ for {
+ select {
+ case post, ok := <-posts:
+ if !ok {
+ return
+ }
+ fmt.Printf("Received post: %s\n", post.Title)
+ case err, ok := <-errs:
+ if !ok {
+ return
+ }
+ fmt.Fprintf(os.Stderr, "Error! %v\n", err)
+ case rcvSig, ok := <-sig:
+ if !ok {
+ return
+ }
+ fmt.Printf("Stopping due to %s signal\n", rcvSig)
+ case <-timer.C:
+ return
+ }
+ }
}
diff --git a/examples/submit-post/main.go b/examples/submit-post/main.go
index 3fe6911..7f03daf 100644
--- a/examples/submit-post/main.go
+++ b/examples/submit-post/main.go
@@ -17,14 +17,9 @@ func main() {
}
func run() (err error) {
- credentials := &reddit.Credentials{
- ID: "id",
- Secret: "secret",
- Username: "username",
- Password: "password",
- }
+ withCredentials := reddit.WithCredentials("id", "secret", "username", "password")
- client, err := reddit.NewClient(credentials)
+ client, err := reddit.NewClient(withCredentials)
if err != nil {
return
}
@@ -51,6 +46,5 @@ func run() (err error) {
}
fmt.Printf("The link post is available at: %s\n", post.URL)
-
return
}
diff --git a/reddit/reddit-options.go b/reddit/reddit-options.go
index 7ace2c3..dee2c4d 100644
--- a/reddit/reddit-options.go
+++ b/reddit/reddit-options.go
@@ -4,11 +4,23 @@ import (
"errors"
"net/http"
"net/url"
+ "os"
)
// Opt is a configuration option to initialize a client.
type Opt func(*Client) error
+// WithCredentials sets the credentials used to authenticate with the Reddit API.
+func WithCredentials(id, secret, username, password string) Opt {
+ return func(c *Client) error {
+ c.ID = id
+ c.Secret = secret
+ c.Username = username
+ c.Password = password
+ return nil
+ }
+}
+
// WithHTTPClient sets the HTTP client which will be used to make requests.
func WithHTTPClient(httpClient *http.Client) Opt {
return func(c *Client) error {
@@ -20,6 +32,16 @@ func WithHTTPClient(httpClient *http.Client) Opt {
}
}
+// WithUserAgent sets the User-Agent header for requests made with the client.
+// Reddit recommends the following format for the user agent:
+// :: (by /u/)
+func WithUserAgent(ua string) Opt {
+ return func(c *Client) error {
+ c.userAgent = ua
+ return nil
+ }
+}
+
// WithBaseURL sets the base URL for the client to make requests to.
func WithBaseURL(u string) Opt {
return func(c *Client) error {
@@ -43,3 +65,25 @@ func WithTokenURL(u string) Opt {
return nil
}
}
+
+// FromEnv configures the client with values from environment variables.
+// Supported environment variables:
+// GO_REDDIT_CLIENT_ID to set the client's id.
+// GO_REDDIT_CLIENT_SECRET to set the client's secret.
+// GO_REDDIT_CLIENT_USERNAME to set the client's username.
+// GO_REDDIT_CLIENT_PASSWORD to set the client's password.
+func FromEnv(c *Client) error {
+ if v := os.Getenv("GO_REDDIT_CLIENT_ID"); v != "" {
+ c.ID = v
+ }
+ if v := os.Getenv("GO_REDDIT_CLIENT_SECRET"); v != "" {
+ c.Secret = v
+ }
+ if v := os.Getenv("GO_REDDIT_CLIENT_USERNAME"); v != "" {
+ c.Username = v
+ }
+ if v := os.Getenv("GO_REDDIT_CLIENT_PASSWORD"); v != "" {
+ c.Password = v
+ }
+ return nil
+}
diff --git a/reddit/reddit-options_test.go b/reddit/reddit-options_test.go
index 31b0248..5a3b3d8 100644
--- a/reddit/reddit-options_test.go
+++ b/reddit/reddit-options_test.go
@@ -1,41 +1,83 @@
package reddit
import (
+ "fmt"
"net/http"
"net/url"
+ "os"
"testing"
"github.com/stretchr/testify/require"
)
+func TestWithCredentials(t *testing.T) {
+ c, err := NewClient(WithCredentials("id1", "secret1", "username1", "password1"))
+ require.NoError(t, err)
+ require.Equal(t, "id1", c.ID)
+ require.Equal(t, "secret1", c.Secret)
+ require.Equal(t, "username1", c.Username)
+ require.Equal(t, "password1", c.Password)
+}
+
func TestWithHTTPClient(t *testing.T) {
- _, err := NewClient(&Credentials{}, WithHTTPClient(nil))
+ _, err := NewClient(WithHTTPClient(nil))
require.EqualError(t, err, "httpClient: cannot be nil")
- _, err = NewClient(&Credentials{}, WithHTTPClient(&http.Client{}))
+ _, err = NewClient(WithHTTPClient(&http.Client{}))
require.NoError(t, err)
}
+func TestWithUserAgent(t *testing.T) {
+ c, err := NewClient(WithUserAgent("test"))
+ require.NoError(t, err)
+ require.Equal(t, "test", c.UserAgent())
+
+ c, err = NewClient(WithUserAgent(""))
+ require.NoError(t, err)
+ require.Equal(t, fmt.Sprintf("golang:%s:v%s", libraryName, libraryVersion), c.UserAgent())
+}
+
func TestWithBaseURL(t *testing.T) {
- c, err := NewClient(&Credentials{}, WithBaseURL(":"))
+ c, err := NewClient(WithBaseURL(":"))
urlErr, ok := err.(*url.Error)
require.True(t, ok)
require.Equal(t, "parse", urlErr.Op)
baseURL := "http://localhost:8080"
- c, err = NewClient(&Credentials{}, WithBaseURL(baseURL))
+ c, err = NewClient(WithBaseURL(baseURL))
require.NoError(t, err)
require.Equal(t, baseURL, c.BaseURL.String())
}
func TestWithTokenURL(t *testing.T) {
- c, err := NewClient(&Credentials{}, WithTokenURL(":"))
+ c, err := NewClient(WithTokenURL(":"))
urlErr, ok := err.(*url.Error)
require.True(t, ok)
require.Equal(t, "parse", urlErr.Op)
tokenURL := "http://localhost:8080/api/v1/access_token"
- c, err = NewClient(&Credentials{}, WithTokenURL(tokenURL))
+ c, err = NewClient(WithTokenURL(tokenURL))
require.NoError(t, err)
require.Equal(t, tokenURL, c.TokenURL.String())
}
+
+func TestFromEnv(t *testing.T) {
+ os.Setenv("GO_REDDIT_CLIENT_ID", "id1")
+ defer os.Unsetenv("GO_REDDIT_CLIENT_ID")
+
+ os.Setenv("GO_REDDIT_CLIENT_SECRET", "secret1")
+ defer os.Unsetenv("GO_REDDIT_CLIENT_SECRET")
+
+ os.Setenv("GO_REDDIT_CLIENT_USERNAME", "username1")
+ defer os.Unsetenv("GO_REDDIT_CLIENT_USERNAME")
+
+ os.Setenv("GO_REDDIT_CLIENT_PASSWORD", "password1")
+ defer os.Unsetenv("GO_REDDIT_CLIENT_PASSWORD")
+
+ c, err := NewClient(FromEnv)
+ require.NoError(t, err)
+ require.Equal(t, "id1", c.ID)
+ require.Equal(t, "secret1", c.Secret)
+ require.Equal(t, "username1", c.Username)
+ require.Equal(t, "password1", c.Password)
+}
diff --git a/reddit/reddit-user-agent.go b/reddit/reddit-user-agent.go
index 5dea93e..0e69b67 100644
--- a/reddit/reddit-user-agent.go
+++ b/reddit/reddit-user-agent.go
@@ -18,6 +18,8 @@ func cloneRequest(r *http.Request) *http.Request {
}
// Sets the User-Agent header for requests.
+// We need to set a custom user agent because using the one set by the
+// stdlib gives us 429 Too Many Requests responses from the Reddit API.
type userAgentTransport struct {
userAgent string
Base http.RoundTripper
diff --git a/reddit/reddit.go b/reddit/reddit.go
index 044eb63..616c9a6 100644
--- a/reddit/reddit.go
+++ b/reddit/reddit.go
@@ -4,7 +4,6 @@ import (
"bytes"
"context"
"encoding/json"
- "errors"
"fmt"
"io"
"io/ioutil"
@@ -21,8 +20,9 @@ const (
libraryName = "github.com/vartanbeno/go-reddit"
libraryVersion = "1.0.0"
- defaultBaseURL = "https://oauth.reddit.com"
- defaultTokenURL = "https://www.reddit.com/api/v1/access_token"
+ defaultBaseURL = "https://oauth.reddit.com"
+ defaultBaseURLReadonly = "https://reddit.com"
+ defaultTokenURL = "https://www.reddit.com/api/v1/access_token"
mediaTypeJSON = "application/json"
mediaTypeForm = "application/x-www-form-urlencoded"
@@ -32,6 +32,9 @@ const (
headerUserAgent = "User-Agent"
)
+// DefaultClient is a readonly client with limited access to the Reddit API.
+var DefaultClient, _ = NewReadonlyClient()
+
// RequestCompletionCallback defines the type of the request callback function.
type RequestCompletionCallback func(*http.Request, *http.Response)
@@ -113,16 +116,9 @@ func newClient() *Client {
}
// NewClient returns a new Reddit API client.
-func NewClient(creds *Credentials, opts ...Opt) (*Client, error) {
- if creds == nil {
- return nil, errors.New("must provide credentials to initialize *reddit.Client")
- }
-
+// Use an Opt to configure the client credentials, such as WithCredentials or FromEnv.
+func NewClient(opts ...Opt) (*Client, error) {
client := newClient()
- client.ID = creds.ID
- client.Secret = creds.Secret
- client.Username = creds.Username
- client.Password = creds.Password
for _, opt := range opts {
if err := opt(client); err != nil {
@@ -134,31 +130,14 @@ func NewClient(creds *Credentials, opts ...Opt) (*Client, error) {
client.client = &http.Client{}
}
- // We need to set a custom user agent, because using the one set by the
- // stdlib gives us 429 Too Many Requests responses from the Reddit API.
userAgentTransport := &userAgentTransport{
userAgent: client.UserAgent(),
Base: client.client.Transport,
}
client.client.Transport = userAgentTransport
- // todo...
- // Some endpoints (notably the ones to get random subreddits/posts) redirect to a
- // reddit.com url, which returns a 403 Forbidden for some reason, unless the url's
- // host is changed to oauth.reddit.com
if client.client.CheckRedirect == nil {
- client.client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
- redirectURL := req.URL.String()
- redirectURL = strings.Replace(redirectURL, "https://www.reddit.com", defaultBaseURL, 1)
-
- reqURL, err := url.Parse(redirectURL)
- if err != nil {
- return err
- }
- req.URL = reqURL
-
- return nil
- }
+ client.client.CheckRedirect = client.redirect
}
oauthTransport := oauthTransport(client)
@@ -167,10 +146,71 @@ func NewClient(creds *Credentials, opts ...Opt) (*Client, error) {
return client, nil
}
+// NewReadonlyClient returns a new read-only Reddit API client.
+// The client will have limited access to the Reddit API.
+// Options that modify credentials (such as WithCredentials or FromEnv) won't have any effect on this client.
+func NewReadonlyClient(opts ...Opt) (*Client, error) {
+ client := newClient()
+ client.BaseURL, _ = url.Parse(defaultBaseURLReadonly)
+
+ for _, opt := range opts {
+ if err := opt(client); err != nil {
+ return nil, err
+ }
+ }
+
+ if client.client == nil {
+ client.client = &http.Client{}
+ }
+
+ userAgentTransport := &userAgentTransport{
+ userAgent: client.UserAgent(),
+ Base: client.client.Transport,
+ }
+ client.client.Transport = userAgentTransport
+
+ return client, nil
+}
+
+// todo...
+// Some endpoints (notably the ones to get random subreddits/posts) redirect to a
+// reddit.com url, which returns a 403 Forbidden for some reason, unless the url's
+// host is changed to oauth.reddit.com
+func (c *Client) redirect(req *http.Request, via []*http.Request) error {
+ redirectURL := req.URL.String()
+ redirectURL = strings.Replace(redirectURL, "https://www.reddit.com", defaultBaseURL, 1)
+
+ reqURL, err := url.Parse(redirectURL)
+ if err != nil {
+ return err
+ }
+ req.URL = reqURL
+
+ return nil
+}
+
+// The readonly Reddit url needs .json at the end of its path to return responses in JSON instead of HTML.
+func (c *Client) appendJSONExtensionToRequestPath(req *http.Request) {
+ readonlyURL, err := url.Parse(defaultBaseURLReadonly)
+ if err != nil {
+ return
+ }
+
+ if req.URL.Host != readonlyURL.Host {
+ return
+ }
+
+ req.URL.Path += ".json"
+}
+
// UserAgent returns the client's user agent.
func (c *Client) UserAgent() string {
if c.userAgent == "" {
- c.userAgent = fmt.Sprintf("golang:%s:v%s (by /u/%s)", libraryName, libraryVersion, c.Username)
+ userAgent := fmt.Sprintf("golang:%s:v%s", libraryName, libraryVersion)
+ if c.Username != "" {
+ userAgent += fmt.Sprintf(" (by /u/%s)", c.Username)
+ }
+ c.userAgent = userAgent
}
return c.userAgent
}
@@ -198,6 +238,7 @@ func (c *Client) NewRequest(method string, path string, body interface{}) (*http
return nil, err
}
+ c.appendJSONExtensionToRequestPath(req)
req.Header.Add(headerContentType, mediaTypeJSON)
req.Header.Add(headerAccept, mediaTypeJSON)
@@ -218,6 +259,7 @@ func (c *Client) NewRequestWithForm(method string, path string, form url.Values)
return nil, err
}
+ c.appendJSONExtensionToRequestPath(req)
req.Header.Add(headerContentType, mediaTypeForm)
req.Header.Add(headerAccept, mediaTypeJSON)
diff --git a/reddit/reddit_test.go b/reddit/reddit_test.go
index 7820605..746df69 100644
--- a/reddit/reddit_test.go
+++ b/reddit/reddit_test.go
@@ -32,7 +32,7 @@ func setup() (*Client, *http.ServeMux, func()) {
})
client, _ := NewClient(
- &Credentials{"id1", "secret1", "user1", "password1"},
+ WithCredentials("id1", "secret1", "user1", "password1"),
WithBaseURL(server.URL),
WithTokenURL(server.URL+"/api/v1/access_token"),
)
@@ -82,7 +82,7 @@ func testClientServices(t *testing.T, c *Client) {
}
func testClientDefaultUserAgent(t *testing.T, c *Client) {
- expectedUserAgent := fmt.Sprintf("golang:%s:v%s (by /u/)", libraryName, libraryVersion)
+ expectedUserAgent := fmt.Sprintf("golang:%s:v%s", libraryName, libraryVersion)
require.Equal(t, expectedUserAgent, c.UserAgent())
}
@@ -92,23 +92,54 @@ func testClientDefaults(t *testing.T, c *Client) {
}
func TestNewClient(t *testing.T) {
- c, err := NewClient(&Credentials{})
+ c, err := NewClient()
require.NoError(t, err)
testClientDefaults(t, c)
}
func TestNewClient_Error(t *testing.T) {
- _, err := NewClient(nil)
- require.EqualError(t, err, "must provide credentials to initialize *reddit.Client")
+ _, err := NewClient()
+ require.NoError(t, err)
errorOpt := func(c *Client) error {
return errors.New("foo")
}
- _, err = NewClient(&Credentials{}, errorOpt)
+ _, err = NewClient(errorOpt)
require.EqualError(t, err, "foo")
}
+func TestNewReadonlyClient(t *testing.T) {
+ c, err := NewReadonlyClient()
+ require.NoError(t, err)
+ require.Equal(t, c.BaseURL.String(), defaultBaseURLReadonly)
+}
+
+func TestNewReadonlyClient_Error(t *testing.T) {
+ _, err := NewReadonlyClient()
+ require.NoError(t, err)
+
+ errorOpt := func(c *Client) error {
+ return errors.New("foo")
+ }
+
+ _, err = NewReadonlyClient(errorOpt)
+ require.EqualError(t, err, "foo")
+}
+
+func TestDefaultClient(t *testing.T) {
+ require.NotNil(t, DefaultClient)
+}
+
+func TestClient_Readonly_NewRequest(t *testing.T) {
+ c, err := NewReadonlyClient()
+ require.NoError(t, err)
+
+ req, err := c.NewRequest(http.MethodGet, "r/golang", nil)
+ require.NoError(t, err)
+ require.Equal(t, defaultBaseURLReadonly+"/r/golang.json", req.URL.String())
+}
+
func TestClient_OnRequestComplemented(t *testing.T) {
client, mux, teardown := setup()
defer teardown()