From d142255dbdb12f225c90a86c2f8eb9fba9e5fd52 Mon Sep 17 00:00:00 2001 From: Vartan Benohanian Date: Tue, 21 Jul 2020 23:59:53 -0400 Subject: [PATCH] Add more tests for SubredditService Signed-off-by: Vartan Benohanian --- subreddit.go | 108 +++++++++--------- subreddit_test.go | 158 ++++++++++++++++++++++++++- testdata/subreddit/search-names.json | 11 ++ testdata/subreddit/search.json | 74 +++++++++++++ 4 files changed, 290 insertions(+), 61 deletions(-) create mode 100644 testdata/subreddit/search-names.json create mode 100644 testdata/subreddit/search.json diff --git a/subreddit.go b/subreddit.go index e9a937c..1b22637 100644 --- a/subreddit.go +++ b/subreddit.go @@ -22,21 +22,21 @@ type rootSubreddit struct { Data *Subreddit `json:"data,omitempty"` } -type rootSubredditNames struct { - Names []string `json:"names,omitempty"` +type rootSubredditInfoList struct { + Subreddits []*SubredditInfo `json:"subreddits,omitempty"` } -type rootSubredditShorts struct { - Subreddits []SubredditShort `json:"subreddits,omitempty"` -} - -// SubredditShort represents minimal information about a subreddit -type SubredditShort struct { +// SubredditInfo represents minimal information about a subreddit. +type SubredditInfo struct { Name string `json:"name,omitempty"` Subscribers int `json:"subscriber_count"` ActiveUsers int `json:"active_user_count"` } +type rootSubredditNames struct { + Names []string `json:"names,omitempty"` +} + type rootModeratorList struct { Kind string `json:"kind,omitempty"` Data struct { @@ -61,13 +61,13 @@ func (s *SubredditService) GetPosts() *PostFinder { return f.Sort(SortHot).FromAll() } -// GetByName gets a subreddit by name -func (s *SubredditService) GetByName(ctx context.Context, subreddit string) (*Subreddit, *Response, error) { - if subreddit == "" { - return nil, nil, errors.New("empty subreddit name provided") +// Get gets a subreddit by name. +func (s *SubredditService) Get(ctx context.Context, name string) (*Subreddit, *Response, error) { + if name == "" { + return nil, nil, errors.New("name: must not be empty") } - path := fmt.Sprintf("r/%s/about", subreddit) + path := fmt.Sprintf("r/%s/about", name) req, err := s.client.NewRequest(http.MethodGet, path, nil) if err != nil { return nil, nil, err @@ -127,6 +127,17 @@ func (s *SubredditService) GetSticky2(ctx context.Context, name string) (*Post, return s.getSticky(ctx, name, 2) } +func (s *SubredditService) handleSubscription(ctx context.Context, form url.Values) (*Response, error) { + path := "api/subscribe" + + req, err := s.client.NewRequestWithForm(http.MethodPost, path, form) + if err != nil { + return nil, err + } + + return s.client.Do(ctx, req, nil) +} + // Subscribe subscribes to subreddits based on their names. func (s *SubredditService) Subscribe(ctx context.Context, subreddits ...string) (*Response, error) { form := url.Values{} @@ -159,8 +170,30 @@ func (s *SubredditService) UnsubscribeByID(ctx context.Context, ids ...string) ( return s.handleSubscription(ctx, form) } -// SearchSubredditNames searches for subreddits with names beginning with the query provided. -func (s *SubredditService) SearchSubredditNames(ctx context.Context, query string) ([]string, *Response, error) { +// Search searches for subreddits with names beginning with the query provided. +// They hold a very minimal amount of info. +func (s *SubredditService) Search(ctx context.Context, query string) ([]*SubredditInfo, *Response, error) { + path := "api/search_subreddits" + + form := url.Values{} + form.Set("query", query) + + req, err := s.client.NewRequestWithForm(http.MethodPost, path, form) + if err != nil { + return nil, nil, err + } + + root := new(rootSubredditInfoList) + resp, err := s.client.Do(ctx, req, root) + if err != nil { + return nil, resp, err + } + + return root.Subreddits, resp, nil +} + +// SearchNames searches for subreddits with names beginning with the query provided. +func (s *SubredditService) SearchNames(ctx context.Context, query string) ([]string, *Response, error) { path := fmt.Sprintf("api/search_reddit_names?query=%s", query) req, err := s.client.NewRequest(http.MethodGet, path, nil) @@ -177,41 +210,6 @@ func (s *SubredditService) SearchSubredditNames(ctx context.Context, query strin return root.Names, resp, nil } -// SearchSubredditInfo searches for subreddits with names beginning with the query provided. -// They hold a bit more info that just the name, but still not much. -func (s *SubredditService) SearchSubredditInfo(ctx context.Context, query string) ([]SubredditShort, *Response, error) { - path := fmt.Sprintf("api/search_subreddits?query=%s", query) - - req, err := s.client.NewRequest(http.MethodPost, path, nil) - if err != nil { - return nil, nil, err - } - - root := new(rootSubredditShorts) - resp, err := s.client.Do(ctx, req, root) - if err != nil { - return nil, resp, err - } - - return root.Subreddits, resp, nil -} - -func (s *SubredditService) handleSubscription(ctx context.Context, form url.Values) (*Response, error) { - path := "api/subscribe" - - req, err := s.client.NewRequestWithForm(http.MethodPost, path, form) - if err != nil { - return nil, err - } - - resp, err := s.client.Do(ctx, req, nil) - if err != nil { - return resp, err - } - - return resp, nil -} - func (s *SubredditService) getSubreddits(ctx context.Context, path string, opts *ListOptions) (*Subreddits, *Response, error) { path, err := addOptions(path, opts) if err != nil { @@ -229,15 +227,7 @@ func (s *SubredditService) getSubreddits(ctx context.Context, path string, opts return nil, resp, err } - l := new(Subreddits) - - if root.Data != nil { - l.Subreddits = root.Data.Things.Subreddits - l.After = root.Data.After - l.Before = root.Data.Before - } - - return l, resp, nil + return root.getSubreddits(), resp, nil } // getSticky returns one of the 2 stickied posts of the subreddit (if they exist). diff --git a/subreddit_test.go b/subreddit_test.go index ab20712..96ff990 100644 --- a/subreddit_test.go +++ b/subreddit_test.go @@ -3,6 +3,7 @@ package reddit import ( "fmt" "net/http" + "net/url" "testing" "time" @@ -113,6 +114,26 @@ var expectedSticky = &postAndComments{ }, } +var expectSubredditInfos = []*SubredditInfo{ + {Name: "golang", Subscribers: 119_722, ActiveUsers: 531}, + {Name: "golang_infosec", Subscribers: 1_776, ActiveUsers: 0}, + {Name: "GolangJobOfferings", Subscribers: 863, ActiveUsers: 1}, + {Name: "golang2", Subscribers: 626, ActiveUsers: 0}, + {Name: "GolangUnofficial", Subscribers: 239, ActiveUsers: 4}, + {Name: "golanguage", Subscribers: 247, ActiveUsers: 4}, + {Name: "golang_jobs", Subscribers: 16, ActiveUsers: 4}, +} + +var expectSubredditNames = []string{ + "golang", + "golang_infosec", + "GolangJobOfferings", + "golanguage", + "golang2", + "GolangUnofficial", + "golang_jobs", +} + var expectedModerators = []Moderator{ {ID: "t2_test1", Name: "testuser1", Permissions: []string{"all"}}, {ID: "t2_test2", Name: "testuser2", Permissions: []string{"all"}}, @@ -132,7 +153,7 @@ var expectedRandomSubreddit = &Subreddit{ Subscribers: 52357, } -func TestSubredditService_GetByName(t *testing.T) { +func TestSubredditService_Get(t *testing.T) { setup() defer teardown() @@ -144,7 +165,10 @@ func TestSubredditService_GetByName(t *testing.T) { fmt.Fprint(w, blob) }) - subreddit, _, err := client.Subreddit.GetByName(ctx, "golang") + _, _, err = client.Subreddit.Get(ctx, "") + assert.EqualError(t, err, "name: must not be empty") + + subreddit, _, err := client.Subreddit.Get(ctx, "golang") assert.NoError(t, err) assert.Equal(t, expectedSubreddit, subreddit) } @@ -293,6 +317,136 @@ func TestSubredditService_GetModerated(t *testing.T) { // // fmt.Println(string(b)) // } +func TestSubredditService_Subscribe(t *testing.T) { + setup() + defer teardown() + + mux.HandleFunc("/api/subscribe", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + + form := url.Values{} + form.Set("action", "sub") + form.Set("sr_name", "test,golang,nba") + + err := r.ParseForm() + assert.NoError(t, err) + assert.Equal(t, form, r.Form) + }) + + _, err := client.Subreddit.Subscribe(ctx, "test", "golang", "nba") + assert.NoError(t, err) +} + +func TestSubredditService_SubscribeByID(t *testing.T) { + setup() + defer teardown() + + mux.HandleFunc("/api/subscribe", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + + form := url.Values{} + form.Set("action", "sub") + form.Set("sr", "t5_test1,t5_test2,t5_test3") + + err := r.ParseForm() + assert.NoError(t, err) + assert.Equal(t, form, r.Form) + }) + + _, err := client.Subreddit.SubscribeByID(ctx, "t5_test1", "t5_test2", "t5_test3") + assert.NoError(t, err) +} + +func TestSubredditService_Unsubscribe(t *testing.T) { + setup() + defer teardown() + + mux.HandleFunc("/api/subscribe", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + + form := url.Values{} + form.Set("action", "unsub") + form.Set("sr_name", "test,golang,nba") + + err := r.ParseForm() + assert.NoError(t, err) + assert.Equal(t, form, r.Form) + }) + + _, err := client.Subreddit.Unsubscribe(ctx, "test", "golang", "nba") + assert.NoError(t, err) +} + +func TestSubredditService_UnsubscribeByID(t *testing.T) { + setup() + defer teardown() + + mux.HandleFunc("/api/subscribe", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + + form := url.Values{} + form.Set("action", "unsub") + form.Set("sr", "t5_test1,t5_test2,t5_test3") + + err := r.ParseForm() + assert.NoError(t, err) + assert.Equal(t, form, r.Form) + }) + + _, err := client.Subreddit.UnsubscribeByID(ctx, "t5_test1", "t5_test2", "t5_test3") + assert.NoError(t, err) +} + +func TestSubredditService_Search(t *testing.T) { + setup() + defer teardown() + + blob, err := readFileContents("testdata/subreddit/search.json") + assert.NoError(t, err) + + mux.HandleFunc("/api/search_subreddits", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPost, r.Method) + + form := url.Values{} + form.Set("query", "golang") + + err := r.ParseForm() + assert.NoError(t, err) + assert.Equal(t, form, r.PostForm) + + fmt.Fprint(w, blob) + }) + + subreddits, _, err := client.Subreddit.Search(ctx, "golang") + assert.NoError(t, err) + assert.Equal(t, expectSubredditInfos, subreddits) +} + +func TestSubredditService_SearchNames(t *testing.T) { + setup() + defer teardown() + + blob, err := readFileContents("testdata/subreddit/search-names.json") + assert.NoError(t, err) + + mux.HandleFunc("/api/search_reddit_names", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodGet, r.Method) + + form := url.Values{} + form.Set("query", "golang") + + err := r.ParseForm() + assert.NoError(t, err) + assert.Equal(t, form, r.Form) + + fmt.Fprint(w, blob) + }) + + names, _, err := client.Subreddit.SearchNames(ctx, "golang") + assert.NoError(t, err) + assert.Equal(t, expectSubredditNames, names) +} + func TestSubredditService_Moderators(t *testing.T) { setup() defer teardown() diff --git a/testdata/subreddit/search-names.json b/testdata/subreddit/search-names.json new file mode 100644 index 0000000..99791d2 --- /dev/null +++ b/testdata/subreddit/search-names.json @@ -0,0 +1,11 @@ +{ + "names": [ + "golang", + "golang_infosec", + "GolangJobOfferings", + "golanguage", + "golang2", + "GolangUnofficial", + "golang_jobs" + ] +} diff --git a/testdata/subreddit/search.json b/testdata/subreddit/search.json new file mode 100644 index 0000000..707add0 --- /dev/null +++ b/testdata/subreddit/search.json @@ -0,0 +1,74 @@ +{ + "subreddits": [ + { + "active_user_count": 531, + "icon_img": "", + "key_color": "", + "name": "golang", + "subscriber_count": 119722, + "is_chat_post_feature_enabled": false, + "allow_chat_post_creation": false, + "allow_images": true + }, + { + "active_user_count": 0, + "icon_img": "", + "key_color": "", + "name": "golang_infosec", + "subscriber_count": 1776, + "is_chat_post_feature_enabled": true, + "allow_chat_post_creation": false, + "allow_images": true + }, + { + "active_user_count": 1, + "icon_img": "", + "key_color": "", + "name": "GolangJobOfferings", + "subscriber_count": 863, + "is_chat_post_feature_enabled": true, + "allow_chat_post_creation": false, + "allow_images": true + }, + { + "active_user_count": 0, + "icon_img": "", + "key_color": "#24a0ed", + "name": "golang2", + "subscriber_count": 626, + "is_chat_post_feature_enabled": false, + "allow_chat_post_creation": false, + "allow_images": true + }, + { + "active_user_count": 4, + "icon_img": "", + "key_color": "", + "name": "GolangUnofficial", + "subscriber_count": 239, + "is_chat_post_feature_enabled": false, + "allow_chat_post_creation": false, + "allow_images": true + }, + { + "active_user_count": 4, + "icon_img": "", + "key_color": "", + "name": "golanguage", + "subscriber_count": 247, + "is_chat_post_feature_enabled": false, + "allow_chat_post_creation": false, + "allow_images": true + }, + { + "active_user_count": 4, + "icon_img": "", + "key_color": "", + "name": "golang_jobs", + "subscriber_count": 16, + "is_chat_post_feature_enabled": false, + "allow_chat_post_creation": false, + "allow_images": true + } + ] +}