diff --git a/README.md b/README.md index b22317f..25f313b 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![Actions Status](https://github.com/vartanbeno/go-reddit/workflows/tests/badge.svg)](https://github.com/vartanbeno/go-reddit/actions) [![Go Report Card](https://goreportcard.com/badge/github.com/vartanbeno/go-reddit)](https://goreportcard.com/report/github.com/vartanbeno/go-reddit) -Featured in [issue 327 of Golang Weekly](https://golangweekly.com/issues/327) 🎉 +**Featured in [issue 327 of Golang Weekly](https://golangweekly.com/issues/327) 🎉** go-reddit is a Go client library for accessing the Reddit API. @@ -33,27 +33,36 @@ package main import "github.com/vartanbeno/go-reddit/reddit" func main() { - credentials := &reddit.Credentials{ - ID: "id", - Secret: "secret", - Username: "username", - Password: "password", - } - client, _ := reddit.NewClient(credentials) + withCredentials := reddit.WithCredentials("id", "secret", "username", "password") + client, _ := reddit.NewClient(withCredentials) } ``` -You can pass in a number of options to `NewClient` that further configure the client. For example, to use a custom HTTP client: +You can pass in a number of options to `NewClient` to further configure the client (see [reddit/reddit-options.go](reddit/reddit-options.go)). For example, to use a custom HTTP client: ```go httpClient := &http.Client{Timeout: time.Second * 30} -client, _ := reddit.NewClient(credentials, reddit.WithHTTPClient(httpClient)) +client, _ := reddit.NewClient(withCredentials, reddit.WithHTTPClient(httpClient)) ``` -If this option is not used, it will be set to `&http.Client{}` by default. More options are available in the [reddit/reddit-options.go](reddit/reddit-options.go) file. +### Read-Only Mode + +The global `DefaultClient` variable is a valid, read-only client with limited access to the Reddit API, much like a logged out user. You can initialize your own via `NewReadonlyClient`: + +```go +client, _ := reddit.NewReadonlyClient() +``` ## Examples +
+ Configure the client from environment variables. + +```go +client, _ := reddit.NewClient(reddit.FromEnv) +``` +
+
Submit a comment. diff --git a/examples/client-on-request-completed/main.go b/examples/client-on-request-completed/main.go index 19601f2..c57c058 100644 --- a/examples/client-on-request-completed/main.go +++ b/examples/client-on-request-completed/main.go @@ -18,14 +18,7 @@ func main() { } func run() (err error) { - credentials := &reddit.Credentials{ - ID: "id", - Secret: "secret", - Username: "username", - Password: "password", - } - - client, err := reddit.NewClient(credentials) + client, err := reddit.NewReadonlyClient() if err != nil { return } @@ -35,13 +28,7 @@ func run() (err error) { client.Subreddit.Search(ctx, "programming", nil) client.Subreddit.SearchNames(ctx, "monitor") client.Subreddit.SearchPosts(ctx, "react", "webdev", nil) - client.User.Posts(ctx, &reddit.ListUserOverviewOptions{ - ListOptions: reddit.ListOptions{ - Limit: 50, - }, - Sort: "top", - Time: "month", - }) + client.Subreddit.HotPosts(ctx, "golang", &reddit.ListOptions{Limit: 5}) return } diff --git a/examples/get-subreddit/main.go b/examples/get-subreddit/main.go index cf12205..7d4ad6b 100644 --- a/examples/get-subreddit/main.go +++ b/examples/get-subreddit/main.go @@ -17,24 +17,11 @@ func main() { } func run() (err error) { - credentials := &reddit.Credentials{ - ID: "id", - Secret: "secret", - Username: "username", - Password: "password", - } - - client, err := reddit.NewClient(credentials) - if err != nil { - return - } - - sr, _, err := client.Subreddit.Get(ctx, "golang") + sr, _, err := reddit.DefaultClient.Subreddit.Get(ctx, "golang") if err != nil { return } fmt.Printf("%s was created on %s and has %d subscribers.\n", sr.NamePrefixed, sr.Created.Local(), sr.Subscribers) - return } diff --git a/examples/get-top-posts/main.go b/examples/get-top-posts/main.go index 4b36962..bc876b7 100644 --- a/examples/get-top-posts/main.go +++ b/examples/get-top-posts/main.go @@ -17,22 +17,10 @@ func main() { } func run() (err error) { - credentials := &reddit.Credentials{ - ID: "id", - Secret: "secret", - Username: "username", - Password: "password", - } - - client, err := reddit.NewClient(credentials) - if err != nil { - return - } - // Let's get the top 200 posts of r/golang. // Reddit returns a maximum of 100 posts at a time, // so we'll need to separate this into 2 requests. - result, _, err := client.Subreddit.TopPosts(ctx, "golang", &reddit.ListPostOptions{ + result, _, err := reddit.DefaultClient.Subreddit.TopPosts(ctx, "golang", &reddit.ListPostOptions{ ListOptions: reddit.ListOptions{ Limit: 100, }, @@ -48,7 +36,7 @@ func run() (err error) { // The SetAfter option sets the id of an item that Reddit // will use as an anchor point for the returned listing. - result, _, err = client.Subreddit.TopPosts(ctx, "golang", &reddit.ListPostOptions{ + result, _, err = reddit.DefaultClient.Subreddit.TopPosts(ctx, "golang", &reddit.ListPostOptions{ ListOptions: reddit.ListOptions{ Limit: 100, After: result.After, diff --git a/examples/stream-posts/main.go b/examples/stream-posts/main.go index 670a909..9532861 100644 --- a/examples/stream-posts/main.go +++ b/examples/stream-posts/main.go @@ -3,8 +3,9 @@ package main import ( "context" "fmt" - "log" "os" + "os/signal" + "syscall" "time" "github.com/vartanbeno/go-reddit/reddit" @@ -13,44 +14,35 @@ import ( var ctx = context.Background() func main() { - if err := run(); err != nil { - log.Fatal(err) - } -} + sig := make(chan os.Signal, 1) + defer close(sig) + signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) -func run() (err error) { - credentials := &reddit.Credentials{ - ID: "id", - Secret: "secret", - Username: "username", - Password: "password", - } - - client, err := reddit.NewClient(credentials) - if err != nil { - return - } - - posts, errs, stop := client.Stream.Posts("AskReddit", reddit.StreamInterval(time.Second*3), reddit.StreamDiscardInitial) + posts, errs, stop := reddit.DefaultClient.Stream.Posts("AskReddit", reddit.StreamInterval(time.Second*3), reddit.StreamDiscardInitial) defer stop() - go func() { - for { - select { - case post, ok := <-posts: - if !ok { - return - } - fmt.Printf("Received post: %s\n", post.Title) - case err, ok := <-errs: - if !ok { - return - } - fmt.Fprintf(os.Stderr, "Error! %v\n", err) - } - } - }() + timer := time.NewTimer(time.Minute) + defer timer.Stop() - <-time.After(time.Minute) - return + for { + select { + case post, ok := <-posts: + if !ok { + return + } + fmt.Printf("Received post: %s\n", post.Title) + case err, ok := <-errs: + if !ok { + return + } + fmt.Fprintf(os.Stderr, "Error! %v\n", err) + case rcvSig, ok := <-sig: + if !ok { + return + } + fmt.Printf("Stopping due to %s signal\n", rcvSig) + case <-timer.C: + return + } + } } diff --git a/examples/submit-post/main.go b/examples/submit-post/main.go index 3fe6911..7f03daf 100644 --- a/examples/submit-post/main.go +++ b/examples/submit-post/main.go @@ -17,14 +17,9 @@ func main() { } func run() (err error) { - credentials := &reddit.Credentials{ - ID: "id", - Secret: "secret", - Username: "username", - Password: "password", - } + withCredentials := reddit.WithCredentials("id", "secret", "username", "password") - client, err := reddit.NewClient(credentials) + client, err := reddit.NewClient(withCredentials) if err != nil { return } @@ -51,6 +46,5 @@ func run() (err error) { } fmt.Printf("The link post is available at: %s\n", post.URL) - return } diff --git a/reddit/reddit-options.go b/reddit/reddit-options.go index 7ace2c3..dee2c4d 100644 --- a/reddit/reddit-options.go +++ b/reddit/reddit-options.go @@ -4,11 +4,23 @@ import ( "errors" "net/http" "net/url" + "os" ) // Opt is a configuration option to initialize a client. type Opt func(*Client) error +// WithCredentials sets the credentials used to authenticate with the Reddit API. +func WithCredentials(id, secret, username, password string) Opt { + return func(c *Client) error { + c.ID = id + c.Secret = secret + c.Username = username + c.Password = password + return nil + } +} + // WithHTTPClient sets the HTTP client which will be used to make requests. func WithHTTPClient(httpClient *http.Client) Opt { return func(c *Client) error { @@ -20,6 +32,16 @@ func WithHTTPClient(httpClient *http.Client) Opt { } } +// WithUserAgent sets the User-Agent header for requests made with the client. +// Reddit recommends the following format for the user agent: +// :: (by /u/) +func WithUserAgent(ua string) Opt { + return func(c *Client) error { + c.userAgent = ua + return nil + } +} + // WithBaseURL sets the base URL for the client to make requests to. func WithBaseURL(u string) Opt { return func(c *Client) error { @@ -43,3 +65,25 @@ func WithTokenURL(u string) Opt { return nil } } + +// FromEnv configures the client with values from environment variables. +// Supported environment variables: +// GO_REDDIT_CLIENT_ID to set the client's id. +// GO_REDDIT_CLIENT_SECRET to set the client's secret. +// GO_REDDIT_CLIENT_USERNAME to set the client's username. +// GO_REDDIT_CLIENT_PASSWORD to set the client's password. +func FromEnv(c *Client) error { + if v := os.Getenv("GO_REDDIT_CLIENT_ID"); v != "" { + c.ID = v + } + if v := os.Getenv("GO_REDDIT_CLIENT_SECRET"); v != "" { + c.Secret = v + } + if v := os.Getenv("GO_REDDIT_CLIENT_USERNAME"); v != "" { + c.Username = v + } + if v := os.Getenv("GO_REDDIT_CLIENT_PASSWORD"); v != "" { + c.Password = v + } + return nil +} diff --git a/reddit/reddit-options_test.go b/reddit/reddit-options_test.go index 31b0248..5a3b3d8 100644 --- a/reddit/reddit-options_test.go +++ b/reddit/reddit-options_test.go @@ -1,41 +1,83 @@ package reddit import ( + "fmt" "net/http" "net/url" + "os" "testing" "github.com/stretchr/testify/require" ) +func TestWithCredentials(t *testing.T) { + c, err := NewClient(WithCredentials("id1", "secret1", "username1", "password1")) + require.NoError(t, err) + require.Equal(t, "id1", c.ID) + require.Equal(t, "secret1", c.Secret) + require.Equal(t, "username1", c.Username) + require.Equal(t, "password1", c.Password) +} + func TestWithHTTPClient(t *testing.T) { - _, err := NewClient(&Credentials{}, WithHTTPClient(nil)) + _, err := NewClient(WithHTTPClient(nil)) require.EqualError(t, err, "httpClient: cannot be nil") - _, err = NewClient(&Credentials{}, WithHTTPClient(&http.Client{})) + _, err = NewClient(WithHTTPClient(&http.Client{})) require.NoError(t, err) } +func TestWithUserAgent(t *testing.T) { + c, err := NewClient(WithUserAgent("test")) + require.NoError(t, err) + require.Equal(t, "test", c.UserAgent()) + + c, err = NewClient(WithUserAgent("")) + require.NoError(t, err) + require.Equal(t, fmt.Sprintf("golang:%s:v%s", libraryName, libraryVersion), c.UserAgent()) +} + func TestWithBaseURL(t *testing.T) { - c, err := NewClient(&Credentials{}, WithBaseURL(":")) + c, err := NewClient(WithBaseURL(":")) urlErr, ok := err.(*url.Error) require.True(t, ok) require.Equal(t, "parse", urlErr.Op) baseURL := "http://localhost:8080" - c, err = NewClient(&Credentials{}, WithBaseURL(baseURL)) + c, err = NewClient(WithBaseURL(baseURL)) require.NoError(t, err) require.Equal(t, baseURL, c.BaseURL.String()) } func TestWithTokenURL(t *testing.T) { - c, err := NewClient(&Credentials{}, WithTokenURL(":")) + c, err := NewClient(WithTokenURL(":")) urlErr, ok := err.(*url.Error) require.True(t, ok) require.Equal(t, "parse", urlErr.Op) tokenURL := "http://localhost:8080/api/v1/access_token" - c, err = NewClient(&Credentials{}, WithTokenURL(tokenURL)) + c, err = NewClient(WithTokenURL(tokenURL)) require.NoError(t, err) require.Equal(t, tokenURL, c.TokenURL.String()) } + +func TestFromEnv(t *testing.T) { + os.Setenv("GO_REDDIT_CLIENT_ID", "id1") + defer os.Unsetenv("GO_REDDIT_CLIENT_ID") + + os.Setenv("GO_REDDIT_CLIENT_SECRET", "secret1") + defer os.Unsetenv("GO_REDDIT_CLIENT_SECRET") + + os.Setenv("GO_REDDIT_CLIENT_USERNAME", "username1") + defer os.Unsetenv("GO_REDDIT_CLIENT_USERNAME") + + os.Setenv("GO_REDDIT_CLIENT_PASSWORD", "password1") + defer os.Unsetenv("GO_REDDIT_CLIENT_PASSWORD") + + c, err := NewClient(FromEnv) + require.NoError(t, err) + require.Equal(t, "id1", c.ID) + require.Equal(t, "secret1", c.Secret) + require.Equal(t, "username1", c.Username) + require.Equal(t, "password1", c.Password) +} diff --git a/reddit/reddit-user-agent.go b/reddit/reddit-user-agent.go index 5dea93e..0e69b67 100644 --- a/reddit/reddit-user-agent.go +++ b/reddit/reddit-user-agent.go @@ -18,6 +18,8 @@ func cloneRequest(r *http.Request) *http.Request { } // Sets the User-Agent header for requests. +// We need to set a custom user agent because using the one set by the +// stdlib gives us 429 Too Many Requests responses from the Reddit API. type userAgentTransport struct { userAgent string Base http.RoundTripper diff --git a/reddit/reddit.go b/reddit/reddit.go index 044eb63..616c9a6 100644 --- a/reddit/reddit.go +++ b/reddit/reddit.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "encoding/json" - "errors" "fmt" "io" "io/ioutil" @@ -21,8 +20,9 @@ const ( libraryName = "github.com/vartanbeno/go-reddit" libraryVersion = "1.0.0" - defaultBaseURL = "https://oauth.reddit.com" - defaultTokenURL = "https://www.reddit.com/api/v1/access_token" + defaultBaseURL = "https://oauth.reddit.com" + defaultBaseURLReadonly = "https://reddit.com" + defaultTokenURL = "https://www.reddit.com/api/v1/access_token" mediaTypeJSON = "application/json" mediaTypeForm = "application/x-www-form-urlencoded" @@ -32,6 +32,9 @@ const ( headerUserAgent = "User-Agent" ) +// DefaultClient is a readonly client with limited access to the Reddit API. +var DefaultClient, _ = NewReadonlyClient() + // RequestCompletionCallback defines the type of the request callback function. type RequestCompletionCallback func(*http.Request, *http.Response) @@ -113,16 +116,9 @@ func newClient() *Client { } // NewClient returns a new Reddit API client. -func NewClient(creds *Credentials, opts ...Opt) (*Client, error) { - if creds == nil { - return nil, errors.New("must provide credentials to initialize *reddit.Client") - } - +// Use an Opt to configure the client credentials, such as WithCredentials or FromEnv. +func NewClient(opts ...Opt) (*Client, error) { client := newClient() - client.ID = creds.ID - client.Secret = creds.Secret - client.Username = creds.Username - client.Password = creds.Password for _, opt := range opts { if err := opt(client); err != nil { @@ -134,31 +130,14 @@ func NewClient(creds *Credentials, opts ...Opt) (*Client, error) { client.client = &http.Client{} } - // We need to set a custom user agent, because using the one set by the - // stdlib gives us 429 Too Many Requests responses from the Reddit API. userAgentTransport := &userAgentTransport{ userAgent: client.UserAgent(), Base: client.client.Transport, } client.client.Transport = userAgentTransport - // todo... - // Some endpoints (notably the ones to get random subreddits/posts) redirect to a - // reddit.com url, which returns a 403 Forbidden for some reason, unless the url's - // host is changed to oauth.reddit.com if client.client.CheckRedirect == nil { - client.client.CheckRedirect = func(req *http.Request, via []*http.Request) error { - redirectURL := req.URL.String() - redirectURL = strings.Replace(redirectURL, "https://www.reddit.com", defaultBaseURL, 1) - - reqURL, err := url.Parse(redirectURL) - if err != nil { - return err - } - req.URL = reqURL - - return nil - } + client.client.CheckRedirect = client.redirect } oauthTransport := oauthTransport(client) @@ -167,10 +146,71 @@ func NewClient(creds *Credentials, opts ...Opt) (*Client, error) { return client, nil } +// NewReadonlyClient returns a new read-only Reddit API client. +// The client will have limited access to the Reddit API. +// Options that modify credentials (such as WithCredentials or FromEnv) won't have any effect on this client. +func NewReadonlyClient(opts ...Opt) (*Client, error) { + client := newClient() + client.BaseURL, _ = url.Parse(defaultBaseURLReadonly) + + for _, opt := range opts { + if err := opt(client); err != nil { + return nil, err + } + } + + if client.client == nil { + client.client = &http.Client{} + } + + userAgentTransport := &userAgentTransport{ + userAgent: client.UserAgent(), + Base: client.client.Transport, + } + client.client.Transport = userAgentTransport + + return client, nil +} + +// todo... +// Some endpoints (notably the ones to get random subreddits/posts) redirect to a +// reddit.com url, which returns a 403 Forbidden for some reason, unless the url's +// host is changed to oauth.reddit.com +func (c *Client) redirect(req *http.Request, via []*http.Request) error { + redirectURL := req.URL.String() + redirectURL = strings.Replace(redirectURL, "https://www.reddit.com", defaultBaseURL, 1) + + reqURL, err := url.Parse(redirectURL) + if err != nil { + return err + } + req.URL = reqURL + + return nil +} + +// The readonly Reddit url needs .json at the end of its path to return responses in JSON instead of HTML. +func (c *Client) appendJSONExtensionToRequestPath(req *http.Request) { + readonlyURL, err := url.Parse(defaultBaseURLReadonly) + if err != nil { + return + } + + if req.URL.Host != readonlyURL.Host { + return + } + + req.URL.Path += ".json" +} + // UserAgent returns the client's user agent. func (c *Client) UserAgent() string { if c.userAgent == "" { - c.userAgent = fmt.Sprintf("golang:%s:v%s (by /u/%s)", libraryName, libraryVersion, c.Username) + userAgent := fmt.Sprintf("golang:%s:v%s", libraryName, libraryVersion) + if c.Username != "" { + userAgent += fmt.Sprintf(" (by /u/%s)", c.Username) + } + c.userAgent = userAgent } return c.userAgent } @@ -198,6 +238,7 @@ func (c *Client) NewRequest(method string, path string, body interface{}) (*http return nil, err } + c.appendJSONExtensionToRequestPath(req) req.Header.Add(headerContentType, mediaTypeJSON) req.Header.Add(headerAccept, mediaTypeJSON) @@ -218,6 +259,7 @@ func (c *Client) NewRequestWithForm(method string, path string, form url.Values) return nil, err } + c.appendJSONExtensionToRequestPath(req) req.Header.Add(headerContentType, mediaTypeForm) req.Header.Add(headerAccept, mediaTypeJSON) diff --git a/reddit/reddit_test.go b/reddit/reddit_test.go index 7820605..746df69 100644 --- a/reddit/reddit_test.go +++ b/reddit/reddit_test.go @@ -32,7 +32,7 @@ func setup() (*Client, *http.ServeMux, func()) { }) client, _ := NewClient( - &Credentials{"id1", "secret1", "user1", "password1"}, + WithCredentials("id1", "secret1", "user1", "password1"), WithBaseURL(server.URL), WithTokenURL(server.URL+"/api/v1/access_token"), ) @@ -82,7 +82,7 @@ func testClientServices(t *testing.T, c *Client) { } func testClientDefaultUserAgent(t *testing.T, c *Client) { - expectedUserAgent := fmt.Sprintf("golang:%s:v%s (by /u/)", libraryName, libraryVersion) + expectedUserAgent := fmt.Sprintf("golang:%s:v%s", libraryName, libraryVersion) require.Equal(t, expectedUserAgent, c.UserAgent()) } @@ -92,23 +92,54 @@ func testClientDefaults(t *testing.T, c *Client) { } func TestNewClient(t *testing.T) { - c, err := NewClient(&Credentials{}) + c, err := NewClient() require.NoError(t, err) testClientDefaults(t, c) } func TestNewClient_Error(t *testing.T) { - _, err := NewClient(nil) - require.EqualError(t, err, "must provide credentials to initialize *reddit.Client") + _, err := NewClient() + require.NoError(t, err) errorOpt := func(c *Client) error { return errors.New("foo") } - _, err = NewClient(&Credentials{}, errorOpt) + _, err = NewClient(errorOpt) require.EqualError(t, err, "foo") } +func TestNewReadonlyClient(t *testing.T) { + c, err := NewReadonlyClient() + require.NoError(t, err) + require.Equal(t, c.BaseURL.String(), defaultBaseURLReadonly) +} + +func TestNewReadonlyClient_Error(t *testing.T) { + _, err := NewReadonlyClient() + require.NoError(t, err) + + errorOpt := func(c *Client) error { + return errors.New("foo") + } + + _, err = NewReadonlyClient(errorOpt) + require.EqualError(t, err, "foo") +} + +func TestDefaultClient(t *testing.T) { + require.NotNil(t, DefaultClient) +} + +func TestClient_Readonly_NewRequest(t *testing.T) { + c, err := NewReadonlyClient() + require.NoError(t, err) + + req, err := c.NewRequest(http.MethodGet, "r/golang", nil) + require.NoError(t, err) + require.Equal(t, defaultBaseURLReadonly+"/r/golang.json", req.URL.String()) +} + func TestClient_OnRequestComplemented(t *testing.T) { client, mux, teardown := setup() defer teardown()