Use string instead of slice for subreddits, edit comments
Signed-off-by: Vartan Benohanian <vartanbeno@gmail.com>
This commit is contained in:
parent
bab7ff8e14
commit
e74360e08d
5 changed files with 49 additions and 42 deletions
|
@ -55,7 +55,7 @@ type oauth2Config struct {
|
||||||
tokenURL string
|
tokenURL string
|
||||||
|
|
||||||
// We need to set a custom user agent, because using the one set by default by the
|
// We need to set a custom user agent, because using the one set by default by the
|
||||||
// stdlib gives us 429 Too Many Request responses from the Reddit API
|
// stdlib gives us 429 Too Many Request responses from the Reddit API.
|
||||||
userAgentTransport *userAgentTransport
|
userAgentTransport *userAgentTransport
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -74,8 +74,6 @@ func oauth2Transport(c oauth2Config) *oauth2.Transport {
|
||||||
EndpointParams: params,
|
EndpointParams: params,
|
||||||
}
|
}
|
||||||
|
|
||||||
// We need to set a custom user agent, because using the one set by default by the
|
|
||||||
// stdlib gives us 429 Too Many Request responses from the Reddit API
|
|
||||||
httpClient := &http.Client{Transport: c.userAgentTransport}
|
httpClient := &http.Client{Transport: c.userAgentTransport}
|
||||||
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
|
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)
|
||||||
|
|
||||||
|
@ -87,8 +85,8 @@ func oauth2Transport(c oauth2Config) *oauth2.Transport {
|
||||||
return tr
|
return tr
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithOAuth2 sets the necessary values for the client to authenticate via OAuth2
|
// WithCredentials sets the necessary values for the client to authenticate via OAuth2.
|
||||||
func WithOAuth2(id, secret, username, password string) Opt {
|
func WithCredentials(id, secret, username, password string) Opt {
|
||||||
return func(c *Client) error {
|
return func(c *Client) error {
|
||||||
c.ID = id
|
c.ID = id
|
||||||
c.Secret = secret
|
c.Secret = secret
|
||||||
|
|
|
@ -5,7 +5,7 @@ import (
|
||||||
"os"
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Opt is a configuration option to initialize a client
|
// Opt is a configuration option to initialize a client.
|
||||||
type Opt func(*Client) error
|
type Opt func(*Client) error
|
||||||
|
|
||||||
// FromEnv configures the client with values from environment variables.
|
// FromEnv configures the client with values from environment variables.
|
||||||
|
@ -35,7 +35,7 @@ func FromEnv(c *Client) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithBaseURL sets the base URL for the client to make requests to
|
// WithBaseURL sets the base URL for the client to make requests to.
|
||||||
func WithBaseURL(u string) Opt {
|
func WithBaseURL(u string) Opt {
|
||||||
return func(c *Client) error {
|
return func(c *Client) error {
|
||||||
url, err := url.Parse(u)
|
url, err := url.Parse(u)
|
||||||
|
@ -48,7 +48,7 @@ func WithBaseURL(u string) Opt {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithTokenURL sets the url used to get access tokens
|
// WithTokenURL sets the url used to get access tokens.
|
||||||
func WithTokenURL(u string) Opt {
|
func WithTokenURL(u string) Opt {
|
||||||
return func(c *Client) error {
|
return func(c *Client) error {
|
||||||
url, err := url.Parse(u)
|
url, err := url.Parse(u)
|
||||||
|
|
|
@ -39,7 +39,7 @@ func setup() {
|
||||||
})
|
})
|
||||||
|
|
||||||
client, _ = NewClient(nil,
|
client, _ = NewClient(nil,
|
||||||
WithOAuth2("id1", "secret1", "user1", "password1"),
|
WithCredentials("id1", "secret1", "user1", "password1"),
|
||||||
WithBaseURL(server.URL),
|
WithBaseURL(server.URL),
|
||||||
WithTokenURL(server.URL+"/api/v1/access_token"),
|
WithTokenURL(server.URL+"/api/v1/access_token"),
|
||||||
)
|
)
|
||||||
|
|
59
subreddit.go
59
subreddit.go
|
@ -40,10 +40,10 @@ type Moderator struct {
|
||||||
Permissions []string `json:"mod_permissions"`
|
Permissions []string `json:"mod_permissions"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SubredditService) getPosts(ctx context.Context, sort string, subreddits []string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
|
func (s *SubredditService) getPosts(ctx context.Context, sort string, subreddit string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
|
||||||
path := sort
|
path := sort
|
||||||
if len(subreddits) > 0 {
|
if subreddit != "" {
|
||||||
path = fmt.Sprintf("r/%s/%s", strings.Join(subreddits, "+"), sort)
|
path = fmt.Sprintf("r/%s/%s", subreddit, sort)
|
||||||
}
|
}
|
||||||
|
|
||||||
form := newSearchOptions(opts...)
|
form := newSearchOptions(opts...)
|
||||||
|
@ -63,36 +63,41 @@ func (s *SubredditService) getPosts(ctx context.Context, sort string, subreddits
|
||||||
return root.getPosts(), resp, nil
|
return root.getPosts(), resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Hot returns the hottest posts from the specified subreddits.
|
// Hot returns the hottest posts from the specified subreddit.
|
||||||
|
// To search through multiple, separate the names with a plus (+), e.g. "golang+test".
|
||||||
// If none are defined, it returns the ones from your subscribed subreddits.
|
// If none are defined, it returns the ones from your subscribed subreddits.
|
||||||
// Note: when looking for hot posts in a subreddit, it will include the stickied
|
// Note: when looking for hot posts in a subreddit, it will include the stickied
|
||||||
// posts (if any) PLUS posts from the limit parameter (25 by default).
|
// posts (if any) PLUS posts from the limit parameter (25 by default).
|
||||||
func (s *SubredditService) Hot(ctx context.Context, subreddits []string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
|
func (s *SubredditService) Hot(ctx context.Context, subreddit string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
|
||||||
return s.getPosts(ctx, "hot", subreddits, opts...)
|
return s.getPosts(ctx, "hot", subreddit, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// New returns the newest posts from the specified subreddits.
|
// New returns the newest posts from the specified subreddit.
|
||||||
|
// To search through multiple, separate the names with a plus (+), e.g. "golang+test".
|
||||||
// If none are defined, it returns the ones from your subscribed subreddits.
|
// If none are defined, it returns the ones from your subscribed subreddits.
|
||||||
func (s *SubredditService) New(ctx context.Context, subreddits []string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
|
func (s *SubredditService) New(ctx context.Context, subreddit string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
|
||||||
return s.getPosts(ctx, "new", subreddits, opts...)
|
return s.getPosts(ctx, "new", subreddit, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rising returns the rising posts from the specified subreddits.
|
// Rising returns the rising posts from the specified subreddit.
|
||||||
|
// To search through multiple, separate the names with a plus (+), e.g. "golang+test".
|
||||||
// If none are defined, it returns the ones from your subscribed subreddits.
|
// If none are defined, it returns the ones from your subscribed subreddits.
|
||||||
func (s *SubredditService) Rising(ctx context.Context, subreddits []string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
|
func (s *SubredditService) Rising(ctx context.Context, subreddit string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
|
||||||
return s.getPosts(ctx, "rising", subreddits, opts...)
|
return s.getPosts(ctx, "rising", subreddit, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Controversial returns the most controversial posts from the specified subreddits.
|
// Controversial returns the most controversial posts from the specified subreddit.
|
||||||
|
// To search through multiple, separate the names with a plus (+), e.g. "golang+test".
|
||||||
// If none are defined, it returns the ones from your subscribed subreddits.
|
// If none are defined, it returns the ones from your subscribed subreddits.
|
||||||
func (s *SubredditService) Controversial(ctx context.Context, subreddits []string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
|
func (s *SubredditService) Controversial(ctx context.Context, subreddit string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
|
||||||
return s.getPosts(ctx, "controversial", subreddits, opts...)
|
return s.getPosts(ctx, "controversial", subreddit, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Top returns the top posts from the specified subreddits.
|
// Top returns the top posts from the specified subreddit.
|
||||||
|
// To search through multiple, separate the names with a plus (+), e.g. "golang+test".
|
||||||
// If none are defined, it returns the ones from your subscribed subreddits.
|
// If none are defined, it returns the ones from your subscribed subreddits.
|
||||||
func (s *SubredditService) Top(ctx context.Context, subreddits []string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
|
func (s *SubredditService) Top(ctx context.Context, subreddit string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
|
||||||
return s.getPosts(ctx, "top", subreddits, opts...)
|
return s.getPosts(ctx, "top", subreddit, opts...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get gets a subreddit by name.
|
// Get gets a subreddit by name.
|
||||||
|
@ -243,19 +248,23 @@ func (s *SubredditService) SearchNames(ctx context.Context, query string) ([]str
|
||||||
return root.Names, resp, nil
|
return root.Names, resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SearchPosts searches for posts.
|
// SearchPosts searches for posts in the specified subreddit.
|
||||||
// If the list of subreddits provided is empty, the search is run against r/all.
|
// To search through multiple, separate the names with a plus (+), e.g. "golang+test".
|
||||||
func (s *SubredditService) SearchPosts(ctx context.Context, query string, subreddits []string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
|
// If no subreddit is provided, the search is run against r/all.
|
||||||
if len(subreddits) > 0 {
|
func (s *SubredditService) SearchPosts(ctx context.Context, query string, subreddit string, opts ...SearchOptionSetter) (*Posts, *Response, error) {
|
||||||
|
if subreddit == "" {
|
||||||
|
subreddit = "all"
|
||||||
|
}
|
||||||
|
|
||||||
|
notAll := !strings.EqualFold(subreddit, "all")
|
||||||
|
if notAll {
|
||||||
opts = append(opts, setRestrict)
|
opts = append(opts, setRestrict)
|
||||||
} else {
|
|
||||||
subreddits = append(subreddits, "all")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
opts = append(opts, setQuery(query))
|
opts = append(opts, setQuery(query))
|
||||||
form := newSearchOptions(opts...)
|
form := newSearchOptions(opts...)
|
||||||
|
|
||||||
path := fmt.Sprintf("r/%s/search", strings.Join(subreddits, "+"))
|
path := fmt.Sprintf("r/%s/search", subreddit)
|
||||||
path = addQuery(path, form)
|
path = addQuery(path, form)
|
||||||
|
|
||||||
req, err := s.client.NewRequest(http.MethodGet, path, nil)
|
req, err := s.client.NewRequest(http.MethodGet, path, nil)
|
||||||
|
|
|
@ -231,7 +231,7 @@ func TestSubredditService_Hot(t *testing.T) {
|
||||||
fmt.Fprint(w, blob)
|
fmt.Fprint(w, blob)
|
||||||
})
|
})
|
||||||
|
|
||||||
posts, _, err := client.Subreddit.Hot(ctx, []string{"test"})
|
posts, _, err := client.Subreddit.Hot(ctx, "test")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, expectedPosts, posts)
|
assert.Equal(t, expectedPosts, posts)
|
||||||
}
|
}
|
||||||
|
@ -248,7 +248,7 @@ func TestSubredditService_New(t *testing.T) {
|
||||||
fmt.Fprint(w, blob)
|
fmt.Fprint(w, blob)
|
||||||
})
|
})
|
||||||
|
|
||||||
posts, _, err := client.Subreddit.New(ctx, []string{"test"})
|
posts, _, err := client.Subreddit.New(ctx, "test")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, expectedPosts, posts)
|
assert.Equal(t, expectedPosts, posts)
|
||||||
}
|
}
|
||||||
|
@ -265,7 +265,7 @@ func TestSubredditService_Rising(t *testing.T) {
|
||||||
fmt.Fprint(w, blob)
|
fmt.Fprint(w, blob)
|
||||||
})
|
})
|
||||||
|
|
||||||
posts, _, err := client.Subreddit.Rising(ctx, []string{"test"})
|
posts, _, err := client.Subreddit.Rising(ctx, "test")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, expectedPosts, posts)
|
assert.Equal(t, expectedPosts, posts)
|
||||||
}
|
}
|
||||||
|
@ -282,7 +282,7 @@ func TestSubredditService_Controversial(t *testing.T) {
|
||||||
fmt.Fprint(w, blob)
|
fmt.Fprint(w, blob)
|
||||||
})
|
})
|
||||||
|
|
||||||
posts, _, err := client.Subreddit.Controversial(ctx, []string{"test"})
|
posts, _, err := client.Subreddit.Controversial(ctx, "test")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, expectedPosts, posts)
|
assert.Equal(t, expectedPosts, posts)
|
||||||
}
|
}
|
||||||
|
@ -299,7 +299,7 @@ func TestSubredditService_Top(t *testing.T) {
|
||||||
fmt.Fprint(w, blob)
|
fmt.Fprint(w, blob)
|
||||||
})
|
})
|
||||||
|
|
||||||
posts, _, err := client.Subreddit.Top(ctx, []string{"test"})
|
posts, _, err := client.Subreddit.Top(ctx, "test")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, expectedPosts, posts)
|
assert.Equal(t, expectedPosts, posts)
|
||||||
}
|
}
|
||||||
|
@ -638,7 +638,7 @@ func TestSubredditService_SearchPosts(t *testing.T) {
|
||||||
fmt.Fprint(w, blob)
|
fmt.Fprint(w, blob)
|
||||||
})
|
})
|
||||||
|
|
||||||
posts, _, err := client.Subreddit.SearchPosts(ctx, "test", nil, SetAfter("t3_testpost"))
|
posts, _, err := client.Subreddit.SearchPosts(ctx, "test", "", SetAfter("t3_testpost"))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, expectedSearchPosts, posts)
|
assert.Equal(t, expectedSearchPosts, posts)
|
||||||
}
|
}
|
||||||
|
@ -665,7 +665,7 @@ func TestSubredditService_SearchPosts_InSubreddit(t *testing.T) {
|
||||||
fmt.Fprint(w, blob)
|
fmt.Fprint(w, blob)
|
||||||
})
|
})
|
||||||
|
|
||||||
posts, _, err := client.Subreddit.SearchPosts(ctx, "test", []string{"test"}, SetAfter("t3_testpost"))
|
posts, _, err := client.Subreddit.SearchPosts(ctx, "test", "test", SetAfter("t3_testpost"))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, expectedSearchPosts, posts)
|
assert.Equal(t, expectedSearchPosts, posts)
|
||||||
}
|
}
|
||||||
|
@ -692,7 +692,7 @@ func TestSubredditService_SearchPosts_InSubreddits(t *testing.T) {
|
||||||
fmt.Fprint(w, blob)
|
fmt.Fprint(w, blob)
|
||||||
})
|
})
|
||||||
|
|
||||||
posts, _, err := client.Subreddit.SearchPosts(ctx, "test", []string{"test", "golang", "nba"}, SetAfter("t3_testpost"))
|
posts, _, err := client.Subreddit.SearchPosts(ctx, "test", "test+golang+nba", SetAfter("t3_testpost"))
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, expectedSearchPosts, posts)
|
assert.Equal(t, expectedSearchPosts, posts)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue