Use string instead of slice for subreddits, edit comments

Signed-off-by: Vartan Benohanian <vartanbeno@gmail.com>
This commit is contained in:
Vartan Benohanian 2020-08-02 13:42:53 -04:00
parent bab7ff8e14
commit e74360e08d
5 changed files with 49 additions and 42 deletions

View File

@ -55,7 +55,7 @@ type oauth2Config struct {
tokenURL string tokenURL string
// We need to set a custom user agent, because using the one set by default by the // We need to set a custom user agent, because using the one set by default by the
// stdlib gives us 429 Too Many Request responses from the Reddit API // stdlib gives us 429 Too Many Request responses from the Reddit API.
userAgentTransport *userAgentTransport userAgentTransport *userAgentTransport
} }
@ -74,8 +74,6 @@ func oauth2Transport(c oauth2Config) *oauth2.Transport {
EndpointParams: params, EndpointParams: params,
} }
// We need to set a custom user agent, because using the one set by default by the
// stdlib gives us 429 Too Many Request responses from the Reddit API
httpClient := &http.Client{Transport: c.userAgentTransport} httpClient := &http.Client{Transport: c.userAgentTransport}
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient) ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
@ -87,8 +85,8 @@ func oauth2Transport(c oauth2Config) *oauth2.Transport {
return tr return tr
} }
// WithOAuth2 sets the necessary values for the client to authenticate via OAuth2 // WithCredentials sets the necessary values for the client to authenticate via OAuth2.
func WithOAuth2(id, secret, username, password string) Opt { func WithCredentials(id, secret, username, password string) Opt {
return func(c *Client) error { return func(c *Client) error {
c.ID = id c.ID = id
c.Secret = secret c.Secret = secret

View File

@ -5,7 +5,7 @@ import (
"os" "os"
) )
// Opt is a configuration option to initialize a client // Opt is a configuration option to initialize a client.
type Opt func(*Client) error type Opt func(*Client) error
// FromEnv configures the client with values from environment variables. // FromEnv configures the client with values from environment variables.
@ -35,7 +35,7 @@ func FromEnv(c *Client) error {
return nil return nil
} }
// WithBaseURL sets the base URL for the client to make requests to // WithBaseURL sets the base URL for the client to make requests to.
func WithBaseURL(u string) Opt { func WithBaseURL(u string) Opt {
return func(c *Client) error { return func(c *Client) error {
url, err := url.Parse(u) url, err := url.Parse(u)
@ -48,7 +48,7 @@ func WithBaseURL(u string) Opt {
} }
} }
// WithTokenURL sets the url used to get access tokens // WithTokenURL sets the url used to get access tokens.
func WithTokenURL(u string) Opt { func WithTokenURL(u string) Opt {
return func(c *Client) error { return func(c *Client) error {
url, err := url.Parse(u) url, err := url.Parse(u)

View File

@ -39,7 +39,7 @@ func setup() {
}) })
client, _ = NewClient(nil, client, _ = NewClient(nil,
WithOAuth2("id1", "secret1", "user1", "password1"), WithCredentials("id1", "secret1", "user1", "password1"),
WithBaseURL(server.URL), WithBaseURL(server.URL),
WithTokenURL(server.URL+"/api/v1/access_token"), WithTokenURL(server.URL+"/api/v1/access_token"),
) )

View File

@ -40,10 +40,10 @@ type Moderator struct {
Permissions []string `json:"mod_permissions"` Permissions []string `json:"mod_permissions"`
} }
func (s *SubredditService) getPosts(ctx context.Context, sort string, subreddits []string, opts ...SearchOptionSetter) (*Posts, *Response, error) { func (s *SubredditService) getPosts(ctx context.Context, sort string, subreddit string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
path := sort path := sort
if len(subreddits) > 0 { if subreddit != "" {
path = fmt.Sprintf("r/%s/%s", strings.Join(subreddits, "+"), sort) path = fmt.Sprintf("r/%s/%s", subreddit, sort)
} }
form := newSearchOptions(opts...) form := newSearchOptions(opts...)
@ -63,36 +63,41 @@ func (s *SubredditService) getPosts(ctx context.Context, sort string, subreddits
return root.getPosts(), resp, nil return root.getPosts(), resp, nil
} }
// Hot returns the hottest posts from the specified subreddits. // Hot returns the hottest posts from the specified subreddit.
// To search through multiple, separate the names with a plus (+), e.g. "golang+test".
// If none are defined, it returns the ones from your subscribed subreddits. // If none are defined, it returns the ones from your subscribed subreddits.
// Note: when looking for hot posts in a subreddit, it will include the stickied // Note: when looking for hot posts in a subreddit, it will include the stickied
// posts (if any) PLUS posts from the limit parameter (25 by default). // posts (if any) PLUS posts from the limit parameter (25 by default).
func (s *SubredditService) Hot(ctx context.Context, subreddits []string, opts ...SearchOptionSetter) (*Posts, *Response, error) { func (s *SubredditService) Hot(ctx context.Context, subreddit string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
return s.getPosts(ctx, "hot", subreddits, opts...) return s.getPosts(ctx, "hot", subreddit, opts...)
} }
// New returns the newest posts from the specified subreddits. // New returns the newest posts from the specified subreddit.
// To search through multiple, separate the names with a plus (+), e.g. "golang+test".
// If none are defined, it returns the ones from your subscribed subreddits. // If none are defined, it returns the ones from your subscribed subreddits.
func (s *SubredditService) New(ctx context.Context, subreddits []string, opts ...SearchOptionSetter) (*Posts, *Response, error) { func (s *SubredditService) New(ctx context.Context, subreddit string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
return s.getPosts(ctx, "new", subreddits, opts...) return s.getPosts(ctx, "new", subreddit, opts...)
} }
// Rising returns the rising posts from the specified subreddits. // Rising returns the rising posts from the specified subreddit.
// To search through multiple, separate the names with a plus (+), e.g. "golang+test".
// If none are defined, it returns the ones from your subscribed subreddits. // If none are defined, it returns the ones from your subscribed subreddits.
func (s *SubredditService) Rising(ctx context.Context, subreddits []string, opts ...SearchOptionSetter) (*Posts, *Response, error) { func (s *SubredditService) Rising(ctx context.Context, subreddit string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
return s.getPosts(ctx, "rising", subreddits, opts...) return s.getPosts(ctx, "rising", subreddit, opts...)
} }
// Controversial returns the most controversial posts from the specified subreddits. // Controversial returns the most controversial posts from the specified subreddit.
// To search through multiple, separate the names with a plus (+), e.g. "golang+test".
// If none are defined, it returns the ones from your subscribed subreddits. // If none are defined, it returns the ones from your subscribed subreddits.
func (s *SubredditService) Controversial(ctx context.Context, subreddits []string, opts ...SearchOptionSetter) (*Posts, *Response, error) { func (s *SubredditService) Controversial(ctx context.Context, subreddit string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
return s.getPosts(ctx, "controversial", subreddits, opts...) return s.getPosts(ctx, "controversial", subreddit, opts...)
} }
// Top returns the top posts from the specified subreddits. // Top returns the top posts from the specified subreddit.
// To search through multiple, separate the names with a plus (+), e.g. "golang+test".
// If none are defined, it returns the ones from your subscribed subreddits. // If none are defined, it returns the ones from your subscribed subreddits.
func (s *SubredditService) Top(ctx context.Context, subreddits []string, opts ...SearchOptionSetter) (*Posts, *Response, error) { func (s *SubredditService) Top(ctx context.Context, subreddit string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
return s.getPosts(ctx, "top", subreddits, opts...) return s.getPosts(ctx, "top", subreddit, opts...)
} }
// Get gets a subreddit by name. // Get gets a subreddit by name.
@ -243,19 +248,23 @@ func (s *SubredditService) SearchNames(ctx context.Context, query string) ([]str
return root.Names, resp, nil return root.Names, resp, nil
} }
// SearchPosts searches for posts. // SearchPosts searches for posts in the specified subreddit.
// If the list of subreddits provided is empty, the search is run against r/all. // To search through multiple, separate the names with a plus (+), e.g. "golang+test".
func (s *SubredditService) SearchPosts(ctx context.Context, query string, subreddits []string, opts ...SearchOptionSetter) (*Posts, *Response, error) { // If no subreddit is provided, the search is run against r/all.
if len(subreddits) > 0 { func (s *SubredditService) SearchPosts(ctx context.Context, query string, subreddit string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
if subreddit == "" {
subreddit = "all"
}
notAll := !strings.EqualFold(subreddit, "all")
if notAll {
opts = append(opts, setRestrict) opts = append(opts, setRestrict)
} else {
subreddits = append(subreddits, "all")
} }
opts = append(opts, setQuery(query)) opts = append(opts, setQuery(query))
form := newSearchOptions(opts...) form := newSearchOptions(opts...)
path := fmt.Sprintf("r/%s/search", strings.Join(subreddits, "+")) path := fmt.Sprintf("r/%s/search", subreddit)
path = addQuery(path, form) path = addQuery(path, form)
req, err := s.client.NewRequest(http.MethodGet, path, nil) req, err := s.client.NewRequest(http.MethodGet, path, nil)

View File

@ -231,7 +231,7 @@ func TestSubredditService_Hot(t *testing.T) {
fmt.Fprint(w, blob) fmt.Fprint(w, blob)
}) })
posts, _, err := client.Subreddit.Hot(ctx, []string{"test"}) posts, _, err := client.Subreddit.Hot(ctx, "test")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, expectedPosts, posts) assert.Equal(t, expectedPosts, posts)
} }
@ -248,7 +248,7 @@ func TestSubredditService_New(t *testing.T) {
fmt.Fprint(w, blob) fmt.Fprint(w, blob)
}) })
posts, _, err := client.Subreddit.New(ctx, []string{"test"}) posts, _, err := client.Subreddit.New(ctx, "test")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, expectedPosts, posts) assert.Equal(t, expectedPosts, posts)
} }
@ -265,7 +265,7 @@ func TestSubredditService_Rising(t *testing.T) {
fmt.Fprint(w, blob) fmt.Fprint(w, blob)
}) })
posts, _, err := client.Subreddit.Rising(ctx, []string{"test"}) posts, _, err := client.Subreddit.Rising(ctx, "test")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, expectedPosts, posts) assert.Equal(t, expectedPosts, posts)
} }
@ -282,7 +282,7 @@ func TestSubredditService_Controversial(t *testing.T) {
fmt.Fprint(w, blob) fmt.Fprint(w, blob)
}) })
posts, _, err := client.Subreddit.Controversial(ctx, []string{"test"}) posts, _, err := client.Subreddit.Controversial(ctx, "test")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, expectedPosts, posts) assert.Equal(t, expectedPosts, posts)
} }
@ -299,7 +299,7 @@ func TestSubredditService_Top(t *testing.T) {
fmt.Fprint(w, blob) fmt.Fprint(w, blob)
}) })
posts, _, err := client.Subreddit.Top(ctx, []string{"test"}) posts, _, err := client.Subreddit.Top(ctx, "test")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, expectedPosts, posts) assert.Equal(t, expectedPosts, posts)
} }
@ -638,7 +638,7 @@ func TestSubredditService_SearchPosts(t *testing.T) {
fmt.Fprint(w, blob) fmt.Fprint(w, blob)
}) })
posts, _, err := client.Subreddit.SearchPosts(ctx, "test", nil, SetAfter("t3_testpost")) posts, _, err := client.Subreddit.SearchPosts(ctx, "test", "", SetAfter("t3_testpost"))
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, expectedSearchPosts, posts) assert.Equal(t, expectedSearchPosts, posts)
} }
@ -665,7 +665,7 @@ func TestSubredditService_SearchPosts_InSubreddit(t *testing.T) {
fmt.Fprint(w, blob) fmt.Fprint(w, blob)
}) })
posts, _, err := client.Subreddit.SearchPosts(ctx, "test", []string{"test"}, SetAfter("t3_testpost")) posts, _, err := client.Subreddit.SearchPosts(ctx, "test", "test", SetAfter("t3_testpost"))
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, expectedSearchPosts, posts) assert.Equal(t, expectedSearchPosts, posts)
} }
@ -692,7 +692,7 @@ func TestSubredditService_SearchPosts_InSubreddits(t *testing.T) {
fmt.Fprint(w, blob) fmt.Fprint(w, blob)
}) })
posts, _, err := client.Subreddit.SearchPosts(ctx, "test", []string{"test", "golang", "nba"}, SetAfter("t3_testpost")) posts, _, err := client.Subreddit.SearchPosts(ctx, "test", "test+golang+nba", SetAfter("t3_testpost"))
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, expectedSearchPosts, posts) assert.Equal(t, expectedSearchPosts, posts)
} }