From e01896137a5b6b8dc4e5aaba70a56521df0cf79a Mon Sep 17 00:00:00 2001 From: Vartan Benohanian Date: Fri, 18 Sep 2020 11:27:42 -0400 Subject: [PATCH] Create getThing method to remove duplicate code Signed-off-by: Vartan Benohanian --- reddit/account.go | 24 ++++--------------- reddit/live-thread.go | 26 +++++---------------- reddit/multi.go | 54 ++++++++++++++++++------------------------- reddit/reddit.go | 20 ++++++++++++---- reddit/subreddit.go | 39 ++++++------------------------- reddit/user.go | 23 ++++-------------- reddit/wiki.go | 39 +++++-------------------------- 7 files changed, 65 insertions(+), 160 deletions(-) diff --git a/reddit/account.go b/reddit/account.go index 58449cb..a4ed04d 100644 --- a/reddit/account.go +++ b/reddit/account.go @@ -236,19 +236,11 @@ func (s *AccountService) Info(ctx context.Context) (*User, *Response, error) { // Karma returns a breakdown of your karma per subreddit. func (s *AccountService) Karma(ctx context.Context) ([]*SubredditKarma, *Response, error) { path := "api/v1/me/karma" - - req, err := s.client.NewRequest(http.MethodGet, path, nil) - if err != nil { - return nil, nil, err - } - - root := new(thing) - resp, err := s.client.Do(ctx, req, root) + t, resp, err := s.client.getThing(ctx, path, nil) if err != nil { return nil, resp, err } - - karma, _ := root.Karma() + karma, _ := t.Karma() return karma, resp, nil } @@ -291,19 +283,11 @@ func (s *AccountService) UpdateSettings(ctx context.Context, settings *Settings) // Trophies returns a list of your trophies. func (s *AccountService) Trophies(ctx context.Context) ([]*Trophy, *Response, error) { path := "api/v1/me/trophies" - - req, err := s.client.NewRequest(http.MethodGet, path, nil) - if err != nil { - return nil, nil, err - } - - root := new(thing) - resp, err := s.client.Do(ctx, req, root) + t, resp, err := s.client.getThing(ctx, path, nil) if err != nil { return nil, resp, err } - - trophies, _ := root.TrophyList() + trophies, _ := t.TrophyList() return trophies, resp, nil } diff --git a/reddit/live-thread.go b/reddit/live-thread.go index b32d13b..bac46d4 100644 --- a/reddit/live-thread.go +++ b/reddit/live-thread.go @@ -221,40 +221,26 @@ func (p *LiveThreadPermissions) String() (s string) { // This returns an empty 204 response if no thread is currently featured. func (s *LiveThreadService) Now(ctx context.Context) (*LiveThread, *Response, error) { path := "api/live/happening_now" - req, err := s.client.NewRequest(http.MethodGet, path, nil) - if err != nil { - return nil, nil, err - } - - root := new(thing) - resp, err := s.client.Do(ctx, req, root) + t, resp, err := s.client.getThing(ctx, path, nil) if err != nil { if err == io.EOF && resp != nil && resp.StatusCode == http.StatusNoContent { return nil, resp, nil } return nil, resp, err } - - t, _ := root.LiveThread() - return t, resp, nil + liveThread, _ := t.LiveThread() + return liveThread, resp, nil } // Get information about a live thread. func (s *LiveThreadService) Get(ctx context.Context, id string) (*LiveThread, *Response, error) { path := fmt.Sprintf("live/%s/about", id) - req, err := s.client.NewRequest(http.MethodGet, path, nil) - if err != nil { - return nil, nil, err - } - - root := new(thing) - resp, err := s.client.Do(ctx, req, root) + t, resp, err := s.client.getThing(ctx, path, nil) if err != nil { return nil, resp, err } - - t, _ := root.LiveThread() - return t, resp, nil + liveThread, _ := t.LiveThread() + return liveThread, resp, nil } // GetMultiple gets information about multiple live threads. diff --git a/reddit/multi.go b/reddit/multi.go index 4ae15b3..185d434 100644 --- a/reddit/multi.go +++ b/reddit/multi.go @@ -100,14 +100,6 @@ type MultiCreateOrUpdateRequest struct { Visibility string `json:"visibility,omitempty"` } -// Form parameterizes the fields and returns the form. -func (r *MultiCreateOrUpdateRequest) Form() url.Values { - byteValue, _ := json.Marshal(r) - form := url.Values{} - form.Set("model", string(byteValue)) - return form -} - type rootMultiDescription struct { Body string `json:"body_md"` } @@ -115,18 +107,11 @@ type rootMultiDescription struct { // Get the multireddit from its url path. func (s *MultiService) Get(ctx context.Context, multiPath string) (*Multi, *Response, error) { path := fmt.Sprintf("api/multi/%s", multiPath) - req, err := s.client.NewRequest(http.MethodGet, path, nil) - if err != nil { - return nil, nil, err - } - - root := new(thing) - resp, err := s.client.Do(ctx, req, root) + t, resp, err := s.client.getThing(ctx, path, nil) if err != nil { return nil, resp, err } - - multi, _ := root.Multi() + multi, _ := t.Multi() return multi, resp, nil } @@ -198,8 +183,16 @@ func (s *MultiService) Create(ctx context.Context, createRequest *MultiCreateOrU return nil, nil, errors.New("createRequest: cannot be nil") } + byteValue, err := json.Marshal(createRequest) + if err != nil { + return nil, nil, err + } + + form := url.Values{} + form.Set("model", string(byteValue)) + path := "api/multi" - req, err := s.client.NewRequest(http.MethodPost, path, createRequest.Form()) + req, err := s.client.NewRequest(http.MethodPost, path, form) if err != nil { return nil, nil, err } @@ -221,8 +214,16 @@ func (s *MultiService) Update(ctx context.Context, multiPath string, updateReque return nil, nil, errors.New("updateRequest: cannot be nil") } + byteValue, err := json.Marshal(updateRequest) + if err != nil { + return nil, nil, err + } + + form := url.Values{} + form.Set("model", string(byteValue)) + path := fmt.Sprintf("api/multi/%s", multiPath) - req, err := s.client.NewRequest(http.MethodPut, path, updateRequest.Form()) + req, err := s.client.NewRequest(http.MethodPut, path, form) if err != nil { return nil, nil, err } @@ -250,29 +251,20 @@ func (s *MultiService) Delete(ctx context.Context, multiPath string) (*Response, // Description gets a multireddit's description. func (s *MultiService) Description(ctx context.Context, multiPath string) (string, *Response, error) { path := fmt.Sprintf("api/multi/%s/description", multiPath) - - req, err := s.client.NewRequest(http.MethodGet, path, nil) - if err != nil { - return "", nil, err - } - - root := new(thing) - resp, err := s.client.Do(ctx, req, root) + t, resp, err := s.client.getThing(ctx, path, nil) if err != nil { return "", resp, err } - - multiDescription, _ := root.MultiDescription() + multiDescription, _ := t.MultiDescription() return multiDescription, resp, nil } // UpdateDescription updates a multireddit's description. func (s *MultiService) UpdateDescription(ctx context.Context, multiPath string, description string) (string, *Response, error) { - path := fmt.Sprintf("api/multi/%s/description", multiPath) - form := url.Values{} form.Set("model", fmt.Sprintf(`{"body_md":"%s"}`, description)) + path := fmt.Sprintf("api/multi/%s/description", multiPath) req, err := s.client.NewRequest(http.MethodPut, path, form) if err != nil { return "", nil, err diff --git a/reddit/reddit.go b/reddit/reddit.go index 63f98b1..a103686 100644 --- a/reddit/reddit.go +++ b/reddit/reddit.go @@ -398,7 +398,9 @@ func CheckResponse(r *http.Response) error { return errorResponse } -func (c *Client) getListing(ctx context.Context, path string, opts interface{}) (*listing, *Response, error) { +// A lot of Reddit's responses returns a "thing": { "kind": "...", "data": {...} } +// So this is just a nice convenient method to have. +func (c *Client) getThing(ctx context.Context, path string, opts interface{}) (*thing, *Response, error) { path, err := addOptions(path, opts) if err != nil { return nil, nil, err @@ -409,14 +411,22 @@ func (c *Client) getListing(ctx context.Context, path string, opts interface{}) return nil, nil, err } - root := new(thing) - resp, err := c.Do(ctx, req, root) + t := new(thing) + resp, err := c.Do(ctx, req, t) if err != nil { return nil, resp, err } - listing, _ := root.Listing() - return listing, resp, nil + return t, resp, nil +} + +func (c *Client) getListing(ctx context.Context, path string, opts interface{}) (*listing, *Response, error) { + t, resp, err := c.getThing(ctx, path, opts) + if err != nil { + return nil, resp, err + } + l, _ := t.Listing() + return l, resp, nil } // ListOptions specifies the optional parameters to various API calls that return a listing. diff --git a/reddit/subreddit.go b/reddit/subreddit.go index 490b8e8..0aeb363 100644 --- a/reddit/subreddit.go +++ b/reddit/subreddit.go @@ -225,18 +225,12 @@ func (s *SubredditService) Get(ctx context.Context, name string) (*Subreddit, *R } path := fmt.Sprintf("r/%s/about", name) - req, err := s.client.NewRequest(http.MethodGet, path, nil) - if err != nil { - return nil, nil, err - } - - root := new(thing) - resp, err := s.client.Do(ctx, req, root) + t, resp, err := s.client.getThing(ctx, path, nil) if err != nil { return nil, resp, err } - sr, _ := root.Subreddit() + sr, _ := t.Subreddit() return sr, resp, nil } @@ -410,24 +404,13 @@ func (s *SubredditService) SearchPosts(ctx context.Context, query string, subred RestrictSubreddits bool `url:"restrict_sr,omitempty"` }{query, notAll} - path, err = addOptions(path, params) - if err != nil { - return nil, nil, err - } - - req, err := s.client.NewRequest(http.MethodGet, path, nil) - if err != nil { - return nil, nil, err - } - - root := new(thing) - resp, err := s.client.Do(ctx, req, root) + t, resp, err := s.client.getThing(ctx, path, params) if err != nil { return nil, resp, err } - listing, _ := root.Listing() - return listing.Posts(), resp, nil + l, _ := t.Listing() + return l.Posts(), resp, nil } func (s *SubredditService) getSubreddits(ctx context.Context, path string, opts *ListSubredditOptions) ([]*Subreddit, *Response, error) { @@ -789,19 +772,11 @@ func (s *SubredditService) Traffic(ctx context.Context, subreddit string) ([]*Su // StyleSheet returns the subreddit's style sheet, as well as some information about images. func (s *SubredditService) StyleSheet(ctx context.Context, subreddit string) (*SubredditStyleSheet, *Response, error) { path := fmt.Sprintf("r/%s/about/stylesheet", subreddit) - - req, err := s.client.NewRequest(http.MethodGet, path, nil) - if err != nil { - return nil, nil, err - } - - root := new(thing) - resp, err := s.client.Do(ctx, req, root) + t, resp, err := s.client.getThing(ctx, path, nil) if err != nil { return nil, resp, err } - - styleSheet, _ := root.StyleSheet() + styleSheet, _ := t.StyleSheet() return styleSheet, resp, nil } diff --git a/reddit/user.go b/reddit/user.go index e6d3164..f1839f5 100644 --- a/reddit/user.go +++ b/reddit/user.go @@ -61,18 +61,11 @@ type Trophy struct { // Get returns information about the user. func (s *UserService) Get(ctx context.Context, username string) (*User, *Response, error) { path := fmt.Sprintf("user/%s/about", username) - req, err := s.client.NewRequest(http.MethodGet, path, nil) - if err != nil { - return nil, nil, err - } - - root := new(thing) - resp, err := s.client.Do(ctx, req, root) + t, resp, err := s.client.getThing(ctx, path, nil) if err != nil { return nil, resp, err } - - user, _ := root.User() + user, _ := t.User() return user, resp, nil } @@ -380,19 +373,11 @@ func (s *UserService) Trophies(ctx context.Context) ([]*Trophy, *Response, error // TrophiesOf returns a list of the specified user's trophies. func (s *UserService) TrophiesOf(ctx context.Context, username string) ([]*Trophy, *Response, error) { path := fmt.Sprintf("api/v1/user/%s/trophies", username) - - req, err := s.client.NewRequest(http.MethodGet, path, nil) - if err != nil { - return nil, nil, err - } - - root := new(thing) - resp, err := s.client.Do(ctx, req, root) + t, resp, err := s.client.getThing(ctx, path, nil) if err != nil { return nil, resp, err } - - trophies, _ := root.TrophyList() + trophies, _ := t.TrophyList() return trophies, resp, nil } diff --git a/reddit/wiki.go b/reddit/wiki.go index 114a8b1..ecbae23 100644 --- a/reddit/wiki.go +++ b/reddit/wiki.go @@ -192,23 +192,12 @@ func (s *WikiService) PageRevision(ctx context.Context, subreddit, page, revisio RevisionID string `url:"v,omitempty"` }{revisionID} - path, err := addOptions(path, params) - if err != nil { - return nil, nil, err - } - - req, err := s.client.NewRequest(http.MethodGet, path, nil) - if err != nil { - return nil, nil, err - } - - root := new(thing) - resp, err := s.client.Do(ctx, req, root) + t, resp, err := s.client.getThing(ctx, path, params) if err != nil { return nil, resp, err } - wikiPage, _ := root.WikiPage() + wikiPage, _ := t.WikiPage() return wikiPage, resp, nil } @@ -216,19 +205,11 @@ func (s *WikiService) PageRevision(ctx context.Context, subreddit, page, revisio // Returns 403 Forbidden if the wiki is disabled. func (s *WikiService) Pages(ctx context.Context, subreddit string) ([]string, *Response, error) { path := fmt.Sprintf("r/%s/wiki/pages", subreddit) - - req, err := s.client.NewRequest(http.MethodGet, path, nil) - if err != nil { - return nil, nil, err - } - - root := new(thing) - resp, err := s.client.Do(ctx, req, root) + t, resp, err := s.client.getThing(ctx, path, nil) if err != nil { return nil, resp, err } - - wikiPages, _ := root.WikiPages() + wikiPages, _ := t.WikiPages() return wikiPages, resp, nil } @@ -271,19 +252,11 @@ func (s *WikiService) Revert(ctx context.Context, subreddit, page, revisionID st // Settings gets the subreddit's wiki page's settings. func (s *WikiService) Settings(ctx context.Context, subreddit, page string) (*WikiPageSettings, *Response, error) { path := fmt.Sprintf("r/%s/wiki/settings/%s", subreddit, page) - - req, err := s.client.NewRequest(http.MethodGet, path, nil) - if err != nil { - return nil, nil, err - } - - root := new(thing) - resp, err := s.client.Do(ctx, req, root) + t, resp, err := s.client.getThing(ctx, path, nil) if err != nil { return nil, resp, err } - - settings, _ := root.WikiPageSettings() + settings, _ := t.WikiPageSettings() return settings, resp, nil }