Use string instead of slice for subreddits, edit comments
Signed-off-by: Vartan Benohanian <vartanbeno@gmail.com>
This commit is contained in:
parent
bab7ff8e14
commit
e74360e08d
5 changed files with 49 additions and 42 deletions
|
@ -55,7 +55,7 @@ type oauth2Config struct {
|
|||
tokenURL string
|
||||
|
||||
// We need to set a custom user agent, because using the one set by default by the
|
||||
// stdlib gives us 429 Too Many Request responses from the Reddit API
|
||||
// stdlib gives us 429 Too Many Request responses from the Reddit API.
|
||||
userAgentTransport *userAgentTransport
|
||||
}
|
||||
|
||||
|
@ -74,8 +74,6 @@ func oauth2Transport(c oauth2Config) *oauth2.Transport {
|
|||
EndpointParams: params,
|
||||
}
|
||||
|
||||
// We need to set a custom user agent, because using the one set by default by the
|
||||
// stdlib gives us 429 Too Many Request responses from the Reddit API
|
||||
httpClient := &http.Client{Transport: c.userAgentTransport}
|
||||
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
|
||||
|
||||
|
@ -87,8 +85,8 @@ func oauth2Transport(c oauth2Config) *oauth2.Transport {
|
|||
return tr
|
||||
}
|
||||
|
||||
// WithOAuth2 sets the necessary values for the client to authenticate via OAuth2
|
||||
func WithOAuth2(id, secret, username, password string) Opt {
|
||||
// WithCredentials sets the necessary values for the client to authenticate via OAuth2.
|
||||
func WithCredentials(id, secret, username, password string) Opt {
|
||||
return func(c *Client) error {
|
||||
c.ID = id
|
||||
c.Secret = secret
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"os"
|
||||
)
|
||||
|
||||
// Opt is a configuration option to initialize a client
|
||||
// Opt is a configuration option to initialize a client.
|
||||
type Opt func(*Client) error
|
||||
|
||||
// FromEnv configures the client with values from environment variables.
|
||||
|
@ -35,7 +35,7 @@ func FromEnv(c *Client) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// WithBaseURL sets the base URL for the client to make requests to
|
||||
// WithBaseURL sets the base URL for the client to make requests to.
|
||||
func WithBaseURL(u string) Opt {
|
||||
return func(c *Client) error {
|
||||
url, err := url.Parse(u)
|
||||
|
@ -48,7 +48,7 @@ func WithBaseURL(u string) Opt {
|
|||
}
|
||||
}
|
||||
|
||||
// WithTokenURL sets the url used to get access tokens
|
||||
// WithTokenURL sets the url used to get access tokens.
|
||||
func WithTokenURL(u string) Opt {
|
||||
return func(c *Client) error {
|
||||
url, err := url.Parse(u)
|
||||
|
|
|
@ -39,7 +39,7 @@ func setup() {
|
|||
})
|
||||
|
||||
client, _ = NewClient(nil,
|
||||
WithOAuth2("id1", "secret1", "user1", "password1"),
|
||||
WithCredentials("id1", "secret1", "user1", "password1"),
|
||||
WithBaseURL(server.URL),
|
||||
WithTokenURL(server.URL+"/api/v1/access_token"),
|
||||
)
|
||||
|
|
59
subreddit.go
59
subreddit.go
|
@ -40,10 +40,10 @@ type Moderator struct {
|
|||
Permissions []string `json:"mod_permissions"`
|
||||
}
|
||||
|
||||
func (s *SubredditService) getPosts(ctx context.Context, sort string, subreddits []string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
|
||||
func (s *SubredditService) getPosts(ctx context.Context, sort string, subreddit string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
|
||||
path := sort
|
||||
if len(subreddits) > 0 {
|
||||
path = fmt.Sprintf("r/%s/%s", strings.Join(subreddits, "+"), sort)
|
||||
if subreddit != "" {
|
||||
path = fmt.Sprintf("r/%s/%s", subreddit, sort)
|
||||
}
|
||||
|
||||
form := newSearchOptions(opts...)
|
||||
|
@ -63,36 +63,41 @@ func (s *SubredditService) getPosts(ctx context.Context, sort string, subreddits
|
|||
return root.getPosts(), resp, nil
|
||||
}
|
||||
|
||||
// Hot returns the hottest posts from the specified subreddits.
|
||||
// Hot returns the hottest posts from the specified subreddit.
|
||||
// To search through multiple, separate the names with a plus (+), e.g. "golang+test".
|
||||
// If none are defined, it returns the ones from your subscribed subreddits.
|
||||
// Note: when looking for hot posts in a subreddit, it will include the stickied
|
||||
// posts (if any) PLUS posts from the limit parameter (25 by default).
|
||||
func (s *SubredditService) Hot(ctx context.Context, subreddits []string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
|
||||
return s.getPosts(ctx, "hot", subreddits, opts...)
|
||||
func (s *SubredditService) Hot(ctx context.Context, subreddit string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
|
||||
return s.getPosts(ctx, "hot", subreddit, opts...)
|
||||
}
|
||||
|
||||
// New returns the newest posts from the specified subreddits.
|
||||
// New returns the newest posts from the specified subreddit.
|
||||
// To search through multiple, separate the names with a plus (+), e.g. "golang+test".
|
||||
// If none are defined, it returns the ones from your subscribed subreddits.
|
||||
func (s *SubredditService) New(ctx context.Context, subreddits []string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
|
||||
return s.getPosts(ctx, "new", subreddits, opts...)
|
||||
func (s *SubredditService) New(ctx context.Context, subreddit string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
|
||||
return s.getPosts(ctx, "new", subreddit, opts...)
|
||||
}
|
||||
|
||||
// Rising returns the rising posts from the specified subreddits.
|
||||
// Rising returns the rising posts from the specified subreddit.
|
||||
// To search through multiple, separate the names with a plus (+), e.g. "golang+test".
|
||||
// If none are defined, it returns the ones from your subscribed subreddits.
|
||||
func (s *SubredditService) Rising(ctx context.Context, subreddits []string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
|
||||
return s.getPosts(ctx, "rising", subreddits, opts...)
|
||||
func (s *SubredditService) Rising(ctx context.Context, subreddit string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
|
||||
return s.getPosts(ctx, "rising", subreddit, opts...)
|
||||
}
|
||||
|
||||
// Controversial returns the most controversial posts from the specified subreddits.
|
||||
// Controversial returns the most controversial posts from the specified subreddit.
|
||||
// To search through multiple, separate the names with a plus (+), e.g. "golang+test".
|
||||
// If none are defined, it returns the ones from your subscribed subreddits.
|
||||
func (s *SubredditService) Controversial(ctx context.Context, subreddits []string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
|
||||
return s.getPosts(ctx, "controversial", subreddits, opts...)
|
||||
func (s *SubredditService) Controversial(ctx context.Context, subreddit string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
|
||||
return s.getPosts(ctx, "controversial", subreddit, opts...)
|
||||
}
|
||||
|
||||
// Top returns the top posts from the specified subreddits.
|
||||
// Top returns the top posts from the specified subreddit.
|
||||
// To search through multiple, separate the names with a plus (+), e.g. "golang+test".
|
||||
// If none are defined, it returns the ones from your subscribed subreddits.
|
||||
func (s *SubredditService) Top(ctx context.Context, subreddits []string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
|
||||
return s.getPosts(ctx, "top", subreddits, opts...)
|
||||
func (s *SubredditService) Top(ctx context.Context, subreddit string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
|
||||
return s.getPosts(ctx, "top", subreddit, opts...)
|
||||
}
|
||||
|
||||
// Get gets a subreddit by name.
|
||||
|
@ -243,19 +248,23 @@ func (s *SubredditService) SearchNames(ctx context.Context, query string) ([]str
|
|||
return root.Names, resp, nil
|
||||
}
|
||||
|
||||
// SearchPosts searches for posts.
|
||||
// If the list of subreddits provided is empty, the search is run against r/all.
|
||||
func (s *SubredditService) SearchPosts(ctx context.Context, query string, subreddits []string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
|
||||
if len(subreddits) > 0 {
|
||||
// SearchPosts searches for posts in the specified subreddit.
|
||||
// To search through multiple, separate the names with a plus (+), e.g. "golang+test".
|
||||
// If no subreddit is provided, the search is run against r/all.
|
||||
func (s *SubredditService) SearchPosts(ctx context.Context, query string, subreddit string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
|
||||
if subreddit == "" {
|
||||
subreddit = "all"
|
||||
}
|
||||
|
||||
notAll := !strings.EqualFold(subreddit, "all")
|
||||
if notAll {
|
||||
opts = append(opts, setRestrict)
|
||||
} else {
|
||||
subreddits = append(subreddits, "all")
|
||||
}
|
||||
|
||||
opts = append(opts, setQuery(query))
|
||||
form := newSearchOptions(opts...)
|
||||
|
||||
path := fmt.Sprintf("r/%s/search", strings.Join(subreddits, "+"))
|
||||
path := fmt.Sprintf("r/%s/search", subreddit)
|
||||
path = addQuery(path, form)
|
||||
|
||||
req, err := s.client.NewRequest(http.MethodGet, path, nil)
|
||||
|
|
|
@ -231,7 +231,7 @@ func TestSubredditService_Hot(t *testing.T) {
|
|||
fmt.Fprint(w, blob)
|
||||
})
|
||||
|
||||
posts, _, err := client.Subreddit.Hot(ctx, []string{"test"})
|
||||
posts, _, err := client.Subreddit.Hot(ctx, "test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedPosts, posts)
|
||||
}
|
||||
|
@ -248,7 +248,7 @@ func TestSubredditService_New(t *testing.T) {
|
|||
fmt.Fprint(w, blob)
|
||||
})
|
||||
|
||||
posts, _, err := client.Subreddit.New(ctx, []string{"test"})
|
||||
posts, _, err := client.Subreddit.New(ctx, "test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedPosts, posts)
|
||||
}
|
||||
|
@ -265,7 +265,7 @@ func TestSubredditService_Rising(t *testing.T) {
|
|||
fmt.Fprint(w, blob)
|
||||
})
|
||||
|
||||
posts, _, err := client.Subreddit.Rising(ctx, []string{"test"})
|
||||
posts, _, err := client.Subreddit.Rising(ctx, "test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedPosts, posts)
|
||||
}
|
||||
|
@ -282,7 +282,7 @@ func TestSubredditService_Controversial(t *testing.T) {
|
|||
fmt.Fprint(w, blob)
|
||||
})
|
||||
|
||||
posts, _, err := client.Subreddit.Controversial(ctx, []string{"test"})
|
||||
posts, _, err := client.Subreddit.Controversial(ctx, "test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedPosts, posts)
|
||||
}
|
||||
|
@ -299,7 +299,7 @@ func TestSubredditService_Top(t *testing.T) {
|
|||
fmt.Fprint(w, blob)
|
||||
})
|
||||
|
||||
posts, _, err := client.Subreddit.Top(ctx, []string{"test"})
|
||||
posts, _, err := client.Subreddit.Top(ctx, "test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedPosts, posts)
|
||||
}
|
||||
|
@ -638,7 +638,7 @@ func TestSubredditService_SearchPosts(t *testing.T) {
|
|||
fmt.Fprint(w, blob)
|
||||
})
|
||||
|
||||
posts, _, err := client.Subreddit.SearchPosts(ctx, "test", nil, SetAfter("t3_testpost"))
|
||||
posts, _, err := client.Subreddit.SearchPosts(ctx, "test", "", SetAfter("t3_testpost"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedSearchPosts, posts)
|
||||
}
|
||||
|
@ -665,7 +665,7 @@ func TestSubredditService_SearchPosts_InSubreddit(t *testing.T) {
|
|||
fmt.Fprint(w, blob)
|
||||
})
|
||||
|
||||
posts, _, err := client.Subreddit.SearchPosts(ctx, "test", []string{"test"}, SetAfter("t3_testpost"))
|
||||
posts, _, err := client.Subreddit.SearchPosts(ctx, "test", "test", SetAfter("t3_testpost"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedSearchPosts, posts)
|
||||
}
|
||||
|
@ -692,7 +692,7 @@ func TestSubredditService_SearchPosts_InSubreddits(t *testing.T) {
|
|||
fmt.Fprint(w, blob)
|
||||
})
|
||||
|
||||
posts, _, err := client.Subreddit.SearchPosts(ctx, "test", []string{"test", "golang", "nba"}, SetAfter("t3_testpost"))
|
||||
posts, _, err := client.Subreddit.SearchPosts(ctx, "test", "test+golang+nba", SetAfter("t3_testpost"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expectedSearchPosts, posts)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue