From ba5eea5ed455e48971e720945b47d06d6f9e7da1 Mon Sep 17 00:00:00 2001 From: Vartan Benohanian Date: Sun, 23 Aug 2020 22:25:29 -0400 Subject: [PATCH] Create Credentials struct Signed-off-by: Vartan Benohanian --- examples/client-on-request-completed/main.go | 9 ++- examples/get-subreddit/main.go | 9 ++- examples/get-top-posts/main.go | 9 ++- examples/stream-posts/main.go | 11 +++- examples/submit-post/main.go | 9 ++- reddit/reddit-options.go | 11 ---- reddit/reddit-options_test.go | 16 +---- reddit/reddit.go | 66 ++++++++++++-------- reddit/reddit_test.go | 8 +-- 9 files changed, 84 insertions(+), 64 deletions(-) diff --git a/examples/client-on-request-completed/main.go b/examples/client-on-request-completed/main.go index bba2610..6f8dc50 100644 --- a/examples/client-on-request-completed/main.go +++ b/examples/client-on-request-completed/main.go @@ -18,9 +18,14 @@ func main() { } func run() (err error) { - withCredentials := reddit.WithCredentials("id", "secret", "username", "password") + credentials := &reddit.Credentials{ + ID: "id", + Secret: "secret", + Username: "username", + Password: "password", + } - client, err := reddit.NewClient(nil, withCredentials) + client, err := reddit.NewClient(nil, credentials) if err != nil { return } diff --git a/examples/get-subreddit/main.go b/examples/get-subreddit/main.go index 60a11b5..5413a44 100644 --- a/examples/get-subreddit/main.go +++ b/examples/get-subreddit/main.go @@ -17,9 +17,14 @@ func main() { } func run() (err error) { - withCredentials := reddit.WithCredentials("id", "secret", "username", "password") + credentials := &reddit.Credentials{ + ID: "id", + Secret: "secret", + Username: "username", + Password: "password", + } - client, err := reddit.NewClient(nil, withCredentials) + client, err := reddit.NewClient(nil, credentials) if err != nil { return } diff --git a/examples/get-top-posts/main.go b/examples/get-top-posts/main.go index 23ee0c0..7e8cb24 100644 --- a/examples/get-top-posts/main.go +++ b/examples/get-top-posts/main.go @@ -17,9 +17,14 @@ func main() { } func run() (err error) { - withCredentials := reddit.WithCredentials("id", "secret", "username", "password") + credentials := &reddit.Credentials{ + ID: "id", + Secret: "secret", + Username: "username", + Password: "password", + } - client, err := reddit.NewClient(nil, withCredentials) + client, err := reddit.NewClient(nil, credentials) if err != nil { return } diff --git a/examples/stream-posts/main.go b/examples/stream-posts/main.go index bd1c98f..1664c07 100644 --- a/examples/stream-posts/main.go +++ b/examples/stream-posts/main.go @@ -19,9 +19,14 @@ func main() { } func run() (err error) { - withCredentials := reddit.WithCredentials("id", "secret", "username", "password") + credentials := &reddit.Credentials{ + ID: "id", + Secret: "secret", + Username: "username", + Password: "password", + } - client, err := reddit.NewClient(nil, withCredentials) + client, err := reddit.NewClient(nil, credentials) if err != nil { return } @@ -41,7 +46,7 @@ func run() (err error) { if !ok { return } - fmt.Fprintf(os.Stderr, "Error! %v", err) + fmt.Fprintf(os.Stderr, "Error! %v\n", err) } } }() diff --git a/examples/submit-post/main.go b/examples/submit-post/main.go index bfb3a8d..0723c09 100644 --- a/examples/submit-post/main.go +++ b/examples/submit-post/main.go @@ -17,9 +17,14 @@ func main() { } func run() (err error) { - withCredentials := reddit.WithCredentials("id", "secret", "username", "password") + credentials := &reddit.Credentials{ + ID: "id", + Secret: "secret", + Username: "username", + Password: "password", + } - client, err := reddit.NewClient(nil, withCredentials) + client, err := reddit.NewClient(nil, credentials) if err != nil { return } diff --git a/reddit/reddit-options.go b/reddit/reddit-options.go index 555584d..5db75f7 100644 --- a/reddit/reddit-options.go +++ b/reddit/reddit-options.go @@ -35,17 +35,6 @@ func FromEnv(c *Client) error { return nil } -// WithCredentials sets the necessary values for the client to authenticate via OAuth2. -func WithCredentials(id, secret, username, password string) Opt { - return func(c *Client) error { - c.ID = id - c.Secret = secret - c.Username = username - c.Password = password - return nil - } -} - // WithBaseURL sets the base URL for the client to make requests to. func WithBaseURL(u string) Opt { return func(c *Client) error { diff --git a/reddit/reddit-options_test.go b/reddit/reddit-options_test.go index 2d24df8..7abb40b 100644 --- a/reddit/reddit-options_test.go +++ b/reddit/reddit-options_test.go @@ -20,7 +20,7 @@ func TestFromEnv(t *testing.T) { os.Setenv("GO_REDDIT_CLIENT_PASSWORD", "password1") defer os.Unsetenv("GO_REDDIT_CLIENT_PASSWORD") - c, err := NewClient(nil, FromEnv) + c, err := NewClient(nil, nil, FromEnv) require.NoError(t, err) type values struct { @@ -32,26 +32,16 @@ func TestFromEnv(t *testing.T) { require.Equal(t, expect, actual) } -func TestWithCredentials(t *testing.T) { - withCredentials := WithCredentials("id1", "secret1", "username1", "password1") - c, err := NewClient(nil, withCredentials) - require.NoError(t, err) - require.Equal(t, "id1", c.ID) - require.Equal(t, "secret1", c.Secret) - require.Equal(t, "username1", c.Username) - require.Equal(t, "password1", c.Password) -} - func TestWithBaseURL(t *testing.T) { baseURL := "http://localhost:8080" - c, err := NewClient(nil, WithBaseURL(baseURL)) + c, err := NewClient(nil, nil, WithBaseURL(baseURL)) require.NoError(t, err) require.Equal(t, baseURL, c.BaseURL.String()) } func TestWithTokenURL(t *testing.T) { tokenURL := "http://localhost:8080/api/v1/access_token" - c, err := NewClient(nil, WithTokenURL(tokenURL)) + c, err := NewClient(nil, nil, WithTokenURL(tokenURL)) require.NoError(t, err) require.Equal(t, tokenURL, c.TokenURL.String()) } diff --git a/reddit/reddit.go b/reddit/reddit.go index 25773ae..ee3a1be 100644 --- a/reddit/reddit.go +++ b/reddit/reddit.go @@ -73,6 +73,14 @@ func (t *userAgentTransport) base() http.RoundTripper { // RequestCompletionCallback defines the type of the request callback function. type RequestCompletionCallback func(*http.Request, *http.Response) +// Credentials used to authenticate to make requests to the Reddit API. +type Credentials struct { + ID string + Secret string + Username string + Password string +} + // Client manages communication with the Reddit API. type Client struct { // HTTP client used to communicate with the Reddit API. @@ -141,42 +149,50 @@ func newClient(httpClient *http.Client) *Client { baseURL, _ := url.Parse(defaultBaseURL) tokenURL, _ := url.Parse(defaultTokenURL) - c := &Client{client: httpClient, BaseURL: baseURL, TokenURL: tokenURL} + client := &Client{client: httpClient, BaseURL: baseURL, TokenURL: tokenURL} - c.Account = &AccountService{client: c} - c.Collection = &CollectionService{client: c} - c.Emoji = &EmojiService{client: c} - c.Flair = &FlairService{client: c} - c.Gold = &GoldService{client: c} - c.Listings = &ListingsService{client: c} - c.Message = &MessageService{client: c} - c.Moderation = &ModerationService{client: c} - c.Multi = &MultiService{client: c} - c.Stream = &StreamService{client: c} - c.Subreddit = &SubredditService{client: c} - c.User = &UserService{client: c} + client.Account = &AccountService{client: client} + client.Collection = &CollectionService{client: client} + client.Emoji = &EmojiService{client: client} + client.Flair = &FlairService{client: client} + client.Gold = &GoldService{client: client} + client.Listings = &ListingsService{client: client} + client.Message = &MessageService{client: client} + client.Moderation = &ModerationService{client: client} + client.Multi = &MultiService{client: client} + client.Stream = &StreamService{client: client} + client.Subreddit = &SubredditService{client: client} + client.User = &UserService{client: client} - postAndCommentService := &postAndCommentService{client: c} - c.Comment = &CommentService{client: c, postAndCommentService: postAndCommentService} - c.Post = &PostService{client: c, postAndCommentService: postAndCommentService} + postAndCommentService := &postAndCommentService{client: client} + client.Comment = &CommentService{client: client, postAndCommentService: postAndCommentService} + client.Post = &PostService{client: client, postAndCommentService: postAndCommentService} - return c + return client } -// NewClient returns a client that can make requests to the Reddit API. -func NewClient(httpClient *http.Client, opts ...Opt) (c *Client, err error) { - c = newClient(httpClient) +// NewClient returns a new Reddit API client. If a nil httpClient is provided, +// a new http.Client will be used. +func NewClient(httpClient *http.Client, creds *Credentials, opts ...Opt) (*Client, error) { + client := newClient(httpClient) for _, opt := range opts { - if err = opt(c); err != nil { - return + if err := opt(client); err != nil { + return nil, err } } - oauthTransport := oauthTransport(c) - c.client.Transport = oauthTransport + if creds != nil { + client.ID = creds.ID + client.Secret = creds.Secret + client.Username = creds.Username + client.Password = creds.Password - return + oauthTransport := oauthTransport(client) + client.client.Transport = oauthTransport + } + + return client, nil } // UserAgent returns the client's user agent. diff --git a/reddit/reddit_test.go b/reddit/reddit_test.go index 442a1a6..068bbea 100644 --- a/reddit/reddit_test.go +++ b/reddit/reddit_test.go @@ -39,7 +39,7 @@ func setup() { }) client, _ = NewClient(nil, - WithCredentials("id1", "secret1", "user1", "password1"), + &Credentials{"id1", "secret1", "user1", "password1"}, WithBaseURL(server.URL), WithTokenURL(server.URL+"/api/v1/access_token"), ) @@ -92,7 +92,7 @@ func testClientServices(t *testing.T, c *Client) { func testClientDefaultUserAgent(t *testing.T, c *Client) { expectedUserAgent := fmt.Sprintf("golang:%s:v%s (by /u/)", libraryName, libraryVersion) - require.Equal(t, expectedUserAgent, c.userAgent) + require.Equal(t, expectedUserAgent, c.UserAgent()) } func testClientDefaults(t *testing.T, c *Client) { @@ -101,7 +101,7 @@ func testClientDefaults(t *testing.T, c *Client) { } func TestNewClient(t *testing.T) { - c, err := NewClient(nil) + c, err := NewClient(nil, nil) require.NoError(t, err) testClientDefaults(t, c) } @@ -111,7 +111,7 @@ func TestNewClient_Error(t *testing.T) { return errors.New("foo") } - _, err := NewClient(nil, errorOpt) + _, err := NewClient(nil, nil, errorOpt) require.EqualError(t, err, "foo") }