Create getThing method to remove duplicate code

Signed-off-by: Vartan Benohanian <vartanbeno@gmail.com>
This commit is contained in:
Vartan Benohanian 2020-09-18 11:27:42 -04:00
parent 1eb75d163c
commit e01896137a
7 changed files with 65 additions and 160 deletions

View file

@ -236,19 +236,11 @@ func (s *AccountService) Info(ctx context.Context) (*User, *Response, error) {
// Karma returns a breakdown of your karma per subreddit.
func (s *AccountService) Karma(ctx context.Context) ([]*SubredditKarma, *Response, error) {
path := "api/v1/me/karma"
req, err := s.client.NewRequest(http.MethodGet, path, nil)
if err != nil {
return nil, nil, err
}
root := new(thing)
resp, err := s.client.Do(ctx, req, root)
t, resp, err := s.client.getThing(ctx, path, nil)
if err != nil {
return nil, resp, err
}
karma, _ := root.Karma()
karma, _ := t.Karma()
return karma, resp, nil
}
@ -291,19 +283,11 @@ func (s *AccountService) UpdateSettings(ctx context.Context, settings *Settings)
// Trophies returns a list of your trophies.
func (s *AccountService) Trophies(ctx context.Context) ([]*Trophy, *Response, error) {
path := "api/v1/me/trophies"
req, err := s.client.NewRequest(http.MethodGet, path, nil)
if err != nil {
return nil, nil, err
}
root := new(thing)
resp, err := s.client.Do(ctx, req, root)
t, resp, err := s.client.getThing(ctx, path, nil)
if err != nil {
return nil, resp, err
}
trophies, _ := root.TrophyList()
trophies, _ := t.TrophyList()
return trophies, resp, nil
}

View file

@ -221,40 +221,26 @@ func (p *LiveThreadPermissions) String() (s string) {
// This returns an empty 204 response if no thread is currently featured.
func (s *LiveThreadService) Now(ctx context.Context) (*LiveThread, *Response, error) {
path := "api/live/happening_now"
req, err := s.client.NewRequest(http.MethodGet, path, nil)
if err != nil {
return nil, nil, err
}
root := new(thing)
resp, err := s.client.Do(ctx, req, root)
t, resp, err := s.client.getThing(ctx, path, nil)
if err != nil {
if err == io.EOF && resp != nil && resp.StatusCode == http.StatusNoContent {
return nil, resp, nil
}
return nil, resp, err
}
t, _ := root.LiveThread()
return t, resp, nil
liveThread, _ := t.LiveThread()
return liveThread, resp, nil
}
// Get information about a live thread.
func (s *LiveThreadService) Get(ctx context.Context, id string) (*LiveThread, *Response, error) {
path := fmt.Sprintf("live/%s/about", id)
req, err := s.client.NewRequest(http.MethodGet, path, nil)
if err != nil {
return nil, nil, err
}
root := new(thing)
resp, err := s.client.Do(ctx, req, root)
t, resp, err := s.client.getThing(ctx, path, nil)
if err != nil {
return nil, resp, err
}
t, _ := root.LiveThread()
return t, resp, nil
liveThread, _ := t.LiveThread()
return liveThread, resp, nil
}
// GetMultiple gets information about multiple live threads.

View file

@ -100,14 +100,6 @@ type MultiCreateOrUpdateRequest struct {
Visibility string `json:"visibility,omitempty"`
}
// Form parameterizes the fields and returns the form.
func (r *MultiCreateOrUpdateRequest) Form() url.Values {
byteValue, _ := json.Marshal(r)
form := url.Values{}
form.Set("model", string(byteValue))
return form
}
type rootMultiDescription struct {
Body string `json:"body_md"`
}
@ -115,18 +107,11 @@ type rootMultiDescription struct {
// Get the multireddit from its url path.
func (s *MultiService) Get(ctx context.Context, multiPath string) (*Multi, *Response, error) {
path := fmt.Sprintf("api/multi/%s", multiPath)
req, err := s.client.NewRequest(http.MethodGet, path, nil)
if err != nil {
return nil, nil, err
}
root := new(thing)
resp, err := s.client.Do(ctx, req, root)
t, resp, err := s.client.getThing(ctx, path, nil)
if err != nil {
return nil, resp, err
}
multi, _ := root.Multi()
multi, _ := t.Multi()
return multi, resp, nil
}
@ -198,8 +183,16 @@ func (s *MultiService) Create(ctx context.Context, createRequest *MultiCreateOrU
return nil, nil, errors.New("createRequest: cannot be nil")
}
byteValue, err := json.Marshal(createRequest)
if err != nil {
return nil, nil, err
}
form := url.Values{}
form.Set("model", string(byteValue))
path := "api/multi"
req, err := s.client.NewRequest(http.MethodPost, path, createRequest.Form())
req, err := s.client.NewRequest(http.MethodPost, path, form)
if err != nil {
return nil, nil, err
}
@ -221,8 +214,16 @@ func (s *MultiService) Update(ctx context.Context, multiPath string, updateReque
return nil, nil, errors.New("updateRequest: cannot be nil")
}
byteValue, err := json.Marshal(updateRequest)
if err != nil {
return nil, nil, err
}
form := url.Values{}
form.Set("model", string(byteValue))
path := fmt.Sprintf("api/multi/%s", multiPath)
req, err := s.client.NewRequest(http.MethodPut, path, updateRequest.Form())
req, err := s.client.NewRequest(http.MethodPut, path, form)
if err != nil {
return nil, nil, err
}
@ -250,29 +251,20 @@ func (s *MultiService) Delete(ctx context.Context, multiPath string) (*Response,
// Description gets a multireddit's description.
func (s *MultiService) Description(ctx context.Context, multiPath string) (string, *Response, error) {
path := fmt.Sprintf("api/multi/%s/description", multiPath)
req, err := s.client.NewRequest(http.MethodGet, path, nil)
if err != nil {
return "", nil, err
}
root := new(thing)
resp, err := s.client.Do(ctx, req, root)
t, resp, err := s.client.getThing(ctx, path, nil)
if err != nil {
return "", resp, err
}
multiDescription, _ := root.MultiDescription()
multiDescription, _ := t.MultiDescription()
return multiDescription, resp, nil
}
// UpdateDescription updates a multireddit's description.
func (s *MultiService) UpdateDescription(ctx context.Context, multiPath string, description string) (string, *Response, error) {
path := fmt.Sprintf("api/multi/%s/description", multiPath)
form := url.Values{}
form.Set("model", fmt.Sprintf(`{"body_md":"%s"}`, description))
path := fmt.Sprintf("api/multi/%s/description", multiPath)
req, err := s.client.NewRequest(http.MethodPut, path, form)
if err != nil {
return "", nil, err

View file

@ -398,7 +398,9 @@ func CheckResponse(r *http.Response) error {
return errorResponse
}
func (c *Client) getListing(ctx context.Context, path string, opts interface{}) (*listing, *Response, error) {
// A lot of Reddit's responses returns a "thing": { "kind": "...", "data": {...} }
// So this is just a nice convenient method to have.
func (c *Client) getThing(ctx context.Context, path string, opts interface{}) (*thing, *Response, error) {
path, err := addOptions(path, opts)
if err != nil {
return nil, nil, err
@ -409,14 +411,22 @@ func (c *Client) getListing(ctx context.Context, path string, opts interface{})
return nil, nil, err
}
root := new(thing)
resp, err := c.Do(ctx, req, root)
t := new(thing)
resp, err := c.Do(ctx, req, t)
if err != nil {
return nil, resp, err
}
listing, _ := root.Listing()
return listing, resp, nil
return t, resp, nil
}
func (c *Client) getListing(ctx context.Context, path string, opts interface{}) (*listing, *Response, error) {
t, resp, err := c.getThing(ctx, path, opts)
if err != nil {
return nil, resp, err
}
l, _ := t.Listing()
return l, resp, nil
}
// ListOptions specifies the optional parameters to various API calls that return a listing.

View file

@ -225,18 +225,12 @@ func (s *SubredditService) Get(ctx context.Context, name string) (*Subreddit, *R
}
path := fmt.Sprintf("r/%s/about", name)
req, err := s.client.NewRequest(http.MethodGet, path, nil)
if err != nil {
return nil, nil, err
}
root := new(thing)
resp, err := s.client.Do(ctx, req, root)
t, resp, err := s.client.getThing(ctx, path, nil)
if err != nil {
return nil, resp, err
}
sr, _ := root.Subreddit()
sr, _ := t.Subreddit()
return sr, resp, nil
}
@ -410,24 +404,13 @@ func (s *SubredditService) SearchPosts(ctx context.Context, query string, subred
RestrictSubreddits bool `url:"restrict_sr,omitempty"`
}{query, notAll}
path, err = addOptions(path, params)
if err != nil {
return nil, nil, err
}
req, err := s.client.NewRequest(http.MethodGet, path, nil)
if err != nil {
return nil, nil, err
}
root := new(thing)
resp, err := s.client.Do(ctx, req, root)
t, resp, err := s.client.getThing(ctx, path, params)
if err != nil {
return nil, resp, err
}
listing, _ := root.Listing()
return listing.Posts(), resp, nil
l, _ := t.Listing()
return l.Posts(), resp, nil
}
func (s *SubredditService) getSubreddits(ctx context.Context, path string, opts *ListSubredditOptions) ([]*Subreddit, *Response, error) {
@ -789,19 +772,11 @@ func (s *SubredditService) Traffic(ctx context.Context, subreddit string) ([]*Su
// StyleSheet returns the subreddit's style sheet, as well as some information about images.
func (s *SubredditService) StyleSheet(ctx context.Context, subreddit string) (*SubredditStyleSheet, *Response, error) {
path := fmt.Sprintf("r/%s/about/stylesheet", subreddit)
req, err := s.client.NewRequest(http.MethodGet, path, nil)
if err != nil {
return nil, nil, err
}
root := new(thing)
resp, err := s.client.Do(ctx, req, root)
t, resp, err := s.client.getThing(ctx, path, nil)
if err != nil {
return nil, resp, err
}
styleSheet, _ := root.StyleSheet()
styleSheet, _ := t.StyleSheet()
return styleSheet, resp, nil
}

View file

@ -61,18 +61,11 @@ type Trophy struct {
// Get returns information about the user.
func (s *UserService) Get(ctx context.Context, username string) (*User, *Response, error) {
path := fmt.Sprintf("user/%s/about", username)
req, err := s.client.NewRequest(http.MethodGet, path, nil)
if err != nil {
return nil, nil, err
}
root := new(thing)
resp, err := s.client.Do(ctx, req, root)
t, resp, err := s.client.getThing(ctx, path, nil)
if err != nil {
return nil, resp, err
}
user, _ := root.User()
user, _ := t.User()
return user, resp, nil
}
@ -380,19 +373,11 @@ func (s *UserService) Trophies(ctx context.Context) ([]*Trophy, *Response, error
// TrophiesOf returns a list of the specified user's trophies.
func (s *UserService) TrophiesOf(ctx context.Context, username string) ([]*Trophy, *Response, error) {
path := fmt.Sprintf("api/v1/user/%s/trophies", username)
req, err := s.client.NewRequest(http.MethodGet, path, nil)
if err != nil {
return nil, nil, err
}
root := new(thing)
resp, err := s.client.Do(ctx, req, root)
t, resp, err := s.client.getThing(ctx, path, nil)
if err != nil {
return nil, resp, err
}
trophies, _ := root.TrophyList()
trophies, _ := t.TrophyList()
return trophies, resp, nil
}

View file

@ -192,23 +192,12 @@ func (s *WikiService) PageRevision(ctx context.Context, subreddit, page, revisio
RevisionID string `url:"v,omitempty"`
}{revisionID}
path, err := addOptions(path, params)
if err != nil {
return nil, nil, err
}
req, err := s.client.NewRequest(http.MethodGet, path, nil)
if err != nil {
return nil, nil, err
}
root := new(thing)
resp, err := s.client.Do(ctx, req, root)
t, resp, err := s.client.getThing(ctx, path, params)
if err != nil {
return nil, resp, err
}
wikiPage, _ := root.WikiPage()
wikiPage, _ := t.WikiPage()
return wikiPage, resp, nil
}
@ -216,19 +205,11 @@ func (s *WikiService) PageRevision(ctx context.Context, subreddit, page, revisio
// Returns 403 Forbidden if the wiki is disabled.
func (s *WikiService) Pages(ctx context.Context, subreddit string) ([]string, *Response, error) {
path := fmt.Sprintf("r/%s/wiki/pages", subreddit)
req, err := s.client.NewRequest(http.MethodGet, path, nil)
if err != nil {
return nil, nil, err
}
root := new(thing)
resp, err := s.client.Do(ctx, req, root)
t, resp, err := s.client.getThing(ctx, path, nil)
if err != nil {
return nil, resp, err
}
wikiPages, _ := root.WikiPages()
wikiPages, _ := t.WikiPages()
return wikiPages, resp, nil
}
@ -271,19 +252,11 @@ func (s *WikiService) Revert(ctx context.Context, subreddit, page, revisionID st
// Settings gets the subreddit's wiki page's settings.
func (s *WikiService) Settings(ctx context.Context, subreddit, page string) (*WikiPageSettings, *Response, error) {
path := fmt.Sprintf("r/%s/wiki/settings/%s", subreddit, page)
req, err := s.client.NewRequest(http.MethodGet, path, nil)
if err != nil {
return nil, nil, err
}
root := new(thing)
resp, err := s.client.Do(ctx, req, root)
t, resp, err := s.client.getThing(ctx, path, nil)
if err != nil {
return nil, resp, err
}
settings, _ := root.WikiPageSettings()
settings, _ := t.WikiPageSettings()
return settings, resp, nil
}