diff --git a/reddit_oauth.go b/reddit_oauth.go index ce22113..a2783cd 100644 --- a/reddit_oauth.go +++ b/reddit_oauth.go @@ -55,7 +55,7 @@ type oauth2Config struct { tokenURL string // We need to set a custom user agent, because using the one set by default by the - // stdlib gives us 429 Too Many Request responses from the Reddit API + // stdlib gives us 429 Too Many Request responses from the Reddit API. userAgentTransport *userAgentTransport } @@ -74,8 +74,6 @@ func oauth2Transport(c oauth2Config) *oauth2.Transport { EndpointParams: params, } - // We need to set a custom user agent, because using the one set by default by the - // stdlib gives us 429 Too Many Request responses from the Reddit API httpClient := &http.Client{Transport: c.userAgentTransport} ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient) @@ -87,8 +85,8 @@ func oauth2Transport(c oauth2Config) *oauth2.Transport { return tr } -// WithOAuth2 sets the necessary values for the client to authenticate via OAuth2 -func WithOAuth2(id, secret, username, password string) Opt { +// WithCredentials sets the necessary values for the client to authenticate via OAuth2. +func WithCredentials(id, secret, username, password string) Opt { return func(c *Client) error { c.ID = id c.Secret = secret diff --git a/reddit_options.go b/reddit_options.go index 0d56d5f..ad27ee2 100644 --- a/reddit_options.go +++ b/reddit_options.go @@ -5,7 +5,7 @@ import ( "os" ) -// Opt is a configuration option to initialize a client +// Opt is a configuration option to initialize a client. type Opt func(*Client) error // FromEnv configures the client with values from environment variables. @@ -35,7 +35,7 @@ func FromEnv(c *Client) error { return nil } -// WithBaseURL sets the base URL for the client to make requests to +// WithBaseURL sets the base URL for the client to make requests to. func WithBaseURL(u string) Opt { return func(c *Client) error { url, err := url.Parse(u) @@ -48,7 +48,7 @@ func WithBaseURL(u string) Opt { } } -// WithTokenURL sets the url used to get access tokens +// WithTokenURL sets the url used to get access tokens. func WithTokenURL(u string) Opt { return func(c *Client) error { url, err := url.Parse(u) diff --git a/reddit_test.go b/reddit_test.go index 7f90b37..6857e9c 100644 --- a/reddit_test.go +++ b/reddit_test.go @@ -39,7 +39,7 @@ func setup() { }) client, _ = NewClient(nil, - WithOAuth2("id1", "secret1", "user1", "password1"), + WithCredentials("id1", "secret1", "user1", "password1"), WithBaseURL(server.URL), WithTokenURL(server.URL+"/api/v1/access_token"), ) diff --git a/subreddit.go b/subreddit.go index 9a9718e..ea48cc0 100644 --- a/subreddit.go +++ b/subreddit.go @@ -40,10 +40,10 @@ type Moderator struct { Permissions []string `json:"mod_permissions"` } -func (s *SubredditService) getPosts(ctx context.Context, sort string, subreddits []string, opts ...SearchOptionSetter) (*Posts, *Response, error) { +func (s *SubredditService) getPosts(ctx context.Context, sort string, subreddit string, opts ...SearchOptionSetter) (*Posts, *Response, error) { path := sort - if len(subreddits) > 0 { - path = fmt.Sprintf("r/%s/%s", strings.Join(subreddits, "+"), sort) + if subreddit != "" { + path = fmt.Sprintf("r/%s/%s", subreddit, sort) } form := newSearchOptions(opts...) @@ -63,36 +63,41 @@ func (s *SubredditService) getPosts(ctx context.Context, sort string, subreddits return root.getPosts(), resp, nil } -// Hot returns the hottest posts from the specified subreddits. +// Hot returns the hottest posts from the specified subreddit. +// To search through multiple, separate the names with a plus (+), e.g. "golang+test". // If none are defined, it returns the ones from your subscribed subreddits. // Note: when looking for hot posts in a subreddit, it will include the stickied // posts (if any) PLUS posts from the limit parameter (25 by default). -func (s *SubredditService) Hot(ctx context.Context, subreddits []string, opts ...SearchOptionSetter) (*Posts, *Response, error) { - return s.getPosts(ctx, "hot", subreddits, opts...) +func (s *SubredditService) Hot(ctx context.Context, subreddit string, opts ...SearchOptionSetter) (*Posts, *Response, error) { + return s.getPosts(ctx, "hot", subreddit, opts...) } -// New returns the newest posts from the specified subreddits. +// New returns the newest posts from the specified subreddit. +// To search through multiple, separate the names with a plus (+), e.g. "golang+test". // If none are defined, it returns the ones from your subscribed subreddits. -func (s *SubredditService) New(ctx context.Context, subreddits []string, opts ...SearchOptionSetter) (*Posts, *Response, error) { - return s.getPosts(ctx, "new", subreddits, opts...) +func (s *SubredditService) New(ctx context.Context, subreddit string, opts ...SearchOptionSetter) (*Posts, *Response, error) { + return s.getPosts(ctx, "new", subreddit, opts...) } -// Rising returns the rising posts from the specified subreddits. +// Rising returns the rising posts from the specified subreddit. +// To search through multiple, separate the names with a plus (+), e.g. "golang+test". // If none are defined, it returns the ones from your subscribed subreddits. -func (s *SubredditService) Rising(ctx context.Context, subreddits []string, opts ...SearchOptionSetter) (*Posts, *Response, error) { - return s.getPosts(ctx, "rising", subreddits, opts...) +func (s *SubredditService) Rising(ctx context.Context, subreddit string, opts ...SearchOptionSetter) (*Posts, *Response, error) { + return s.getPosts(ctx, "rising", subreddit, opts...) } -// Controversial returns the most controversial posts from the specified subreddits. +// Controversial returns the most controversial posts from the specified subreddit. +// To search through multiple, separate the names with a plus (+), e.g. "golang+test". // If none are defined, it returns the ones from your subscribed subreddits. -func (s *SubredditService) Controversial(ctx context.Context, subreddits []string, opts ...SearchOptionSetter) (*Posts, *Response, error) { - return s.getPosts(ctx, "controversial", subreddits, opts...) +func (s *SubredditService) Controversial(ctx context.Context, subreddit string, opts ...SearchOptionSetter) (*Posts, *Response, error) { + return s.getPosts(ctx, "controversial", subreddit, opts...) } -// Top returns the top posts from the specified subreddits. +// Top returns the top posts from the specified subreddit. +// To search through multiple, separate the names with a plus (+), e.g. "golang+test". // If none are defined, it returns the ones from your subscribed subreddits. -func (s *SubredditService) Top(ctx context.Context, subreddits []string, opts ...SearchOptionSetter) (*Posts, *Response, error) { - return s.getPosts(ctx, "top", subreddits, opts...) +func (s *SubredditService) Top(ctx context.Context, subreddit string, opts ...SearchOptionSetter) (*Posts, *Response, error) { + return s.getPosts(ctx, "top", subreddit, opts...) } // Get gets a subreddit by name. @@ -243,19 +248,23 @@ func (s *SubredditService) SearchNames(ctx context.Context, query string) ([]str return root.Names, resp, nil } -// SearchPosts searches for posts. -// If the list of subreddits provided is empty, the search is run against r/all. -func (s *SubredditService) SearchPosts(ctx context.Context, query string, subreddits []string, opts ...SearchOptionSetter) (*Posts, *Response, error) { - if len(subreddits) > 0 { +// SearchPosts searches for posts in the specified subreddit. +// To search through multiple, separate the names with a plus (+), e.g. "golang+test". +// If no subreddit is provided, the search is run against r/all. +func (s *SubredditService) SearchPosts(ctx context.Context, query string, subreddit string, opts ...SearchOptionSetter) (*Posts, *Response, error) { + if subreddit == "" { + subreddit = "all" + } + + notAll := !strings.EqualFold(subreddit, "all") + if notAll { opts = append(opts, setRestrict) - } else { - subreddits = append(subreddits, "all") } opts = append(opts, setQuery(query)) form := newSearchOptions(opts...) - path := fmt.Sprintf("r/%s/search", strings.Join(subreddits, "+")) + path := fmt.Sprintf("r/%s/search", subreddit) path = addQuery(path, form) req, err := s.client.NewRequest(http.MethodGet, path, nil) diff --git a/subreddit_test.go b/subreddit_test.go index 5d89c13..5d9284a 100644 --- a/subreddit_test.go +++ b/subreddit_test.go @@ -231,7 +231,7 @@ func TestSubredditService_Hot(t *testing.T) { fmt.Fprint(w, blob) }) - posts, _, err := client.Subreddit.Hot(ctx, []string{"test"}) + posts, _, err := client.Subreddit.Hot(ctx, "test") assert.NoError(t, err) assert.Equal(t, expectedPosts, posts) } @@ -248,7 +248,7 @@ func TestSubredditService_New(t *testing.T) { fmt.Fprint(w, blob) }) - posts, _, err := client.Subreddit.New(ctx, []string{"test"}) + posts, _, err := client.Subreddit.New(ctx, "test") assert.NoError(t, err) assert.Equal(t, expectedPosts, posts) } @@ -265,7 +265,7 @@ func TestSubredditService_Rising(t *testing.T) { fmt.Fprint(w, blob) }) - posts, _, err := client.Subreddit.Rising(ctx, []string{"test"}) + posts, _, err := client.Subreddit.Rising(ctx, "test") assert.NoError(t, err) assert.Equal(t, expectedPosts, posts) } @@ -282,7 +282,7 @@ func TestSubredditService_Controversial(t *testing.T) { fmt.Fprint(w, blob) }) - posts, _, err := client.Subreddit.Controversial(ctx, []string{"test"}) + posts, _, err := client.Subreddit.Controversial(ctx, "test") assert.NoError(t, err) assert.Equal(t, expectedPosts, posts) } @@ -299,7 +299,7 @@ func TestSubredditService_Top(t *testing.T) { fmt.Fprint(w, blob) }) - posts, _, err := client.Subreddit.Top(ctx, []string{"test"}) + posts, _, err := client.Subreddit.Top(ctx, "test") assert.NoError(t, err) assert.Equal(t, expectedPosts, posts) } @@ -638,7 +638,7 @@ func TestSubredditService_SearchPosts(t *testing.T) { fmt.Fprint(w, blob) }) - posts, _, err := client.Subreddit.SearchPosts(ctx, "test", nil, SetAfter("t3_testpost")) + posts, _, err := client.Subreddit.SearchPosts(ctx, "test", "", SetAfter("t3_testpost")) assert.NoError(t, err) assert.Equal(t, expectedSearchPosts, posts) } @@ -665,7 +665,7 @@ func TestSubredditService_SearchPosts_InSubreddit(t *testing.T) { fmt.Fprint(w, blob) }) - posts, _, err := client.Subreddit.SearchPosts(ctx, "test", []string{"test"}, SetAfter("t3_testpost")) + posts, _, err := client.Subreddit.SearchPosts(ctx, "test", "test", SetAfter("t3_testpost")) assert.NoError(t, err) assert.Equal(t, expectedSearchPosts, posts) } @@ -692,7 +692,7 @@ func TestSubredditService_SearchPosts_InSubreddits(t *testing.T) { fmt.Fprint(w, blob) }) - posts, _, err := client.Subreddit.SearchPosts(ctx, "test", []string{"test", "golang", "nba"}, SetAfter("t3_testpost")) + posts, _, err := client.Subreddit.SearchPosts(ctx, "test", "test+golang+nba", SetAfter("t3_testpost")) assert.NoError(t, err) assert.Equal(t, expectedSearchPosts, posts) }