diff --git a/private-message.go b/private-message.go new file mode 100644 index 0000000..e16dc5b --- /dev/null +++ b/private-message.go @@ -0,0 +1,123 @@ +package reddit + +import ( + "context" + "errors" + "net/http" + "net/url" + "strings" +) + +// PrivateMessageService handles communication with the private message +// related methods of the Reddit API. +// +// Reddit API docs: https://www.reddit.com/dev/api/#section_messages +type PrivateMessageService struct { + client *Client +} + +// ReadAll marks all messages/comments as read. It queues up the task on Reddit's end. +// A successful response returns 202 to acknowledge acceptance of the request. +// This endpoint is heavily rate limited. +func (s *PrivateMessageService) ReadAll(ctx context.Context) (*Response, error) { + path := "api/read_all_messages" + + req, err := s.client.NewRequest(http.MethodPost, path, nil) + if err != nil { + return nil, err + } + + return s.client.Do(ctx, req, nil) +} + +// Read marks a message/comment as read via its full ID. +func (s *PrivateMessageService) Read(ctx context.Context, ids ...string) (*Response, error) { + if len(ids) == 0 { + return nil, errors.New("must provide at least 1 id") + } + + path := "api/read_message" + + form := url.Values{} + form.Set("id", strings.Join(ids, ",")) + + req, err := s.client.NewRequestWithForm(http.MethodPost, path, form) + if err != nil { + return nil, err + } + + return s.client.Do(ctx, req, nil) +} + +// Unread marks a message/comment as unread via its full ID. +func (s *PrivateMessageService) Unread(ctx context.Context, ids ...string) (*Response, error) { + if len(ids) == 0 { + return nil, errors.New("must provide at least 1 id") + } + + path := "api/unread_message" + + form := url.Values{} + form.Set("id", strings.Join(ids, ",")) + + req, err := s.client.NewRequestWithForm(http.MethodPost, path, form) + if err != nil { + return nil, err + } + + return s.client.Do(ctx, req, nil) +} + +// Block blocks the author of a thing via the thing's full ID. +// The thing can be a post, comment or message. +func (s *PrivateMessageService) Block(ctx context.Context, id string) (*Response, error) { + path := "api/block" + + form := url.Values{} + form.Set("id", id) + + req, err := s.client.NewRequestWithForm(http.MethodPost, path, form) + if err != nil { + return nil, err + } + + return s.client.Do(ctx, req, nil) +} + +// Collapse collapses messages. +func (s *PrivateMessageService) Collapse(ctx context.Context, ids ...string) (*Response, error) { + if len(ids) == 0 { + return nil, errors.New("must provide at least 1 id") + } + + path := "api/collapse_message" + + form := url.Values{} + form.Set("id", strings.Join(ids, ",")) + + req, err := s.client.NewRequestWithForm(http.MethodPost, path, form) + if err != nil { + return nil, err + } + + return s.client.Do(ctx, req, nil) +} + +// Uncollapse uncollapses messages. +func (s *PrivateMessageService) Uncollapse(ctx context.Context, ids ...string) (*Response, error) { + if len(ids) == 0 { + return nil, errors.New("must provide at least 1 id") + } + + path := "api/uncollapse_message" + + form := url.Values{} + form.Set("id", strings.Join(ids, ",")) + + req, err := s.client.NewRequestWithForm(http.MethodPost, path, form) + if err != nil { + return nil, err + } + + return s.client.Do(ctx, req, nil) +} diff --git a/private-message_test.go b/private-message_test.go new file mode 100644 index 0000000..9ec7b6a --- /dev/null +++ b/private-message_test.go @@ -0,0 +1,130 @@ +package reddit + +import ( + "net/http" + "net/url" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPrivateMessageService_ReadAll(t *testing.T) { + setup() + defer teardown() + + mux.HandleFunc("/api/read_all_messages", func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + w.WriteHeader(http.StatusAccepted) + }) + + res, err := client.PrivateMessage.ReadAll(ctx) + require.NoError(t, err) + require.Equal(t, http.StatusAccepted, res.StatusCode) +} + +func TestPrivateMessageService_Read(t *testing.T) { + setup() + defer teardown() + + mux.HandleFunc("/api/read_message", func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + + form := url.Values{} + form.Set("id", "test1,test2,test3") + + err := r.ParseForm() + require.NoError(t, err) + require.Equal(t, form, r.Form) + }) + + _, err := client.PrivateMessage.Read(ctx) + require.EqualError(t, err, "must provide at least 1 id") + + _, err = client.PrivateMessage.Read(ctx, "test1", "test2", "test3") + require.NoError(t, err) +} + +func TestPrivateMessageService_Unread(t *testing.T) { + setup() + defer teardown() + + mux.HandleFunc("/api/unread_message", func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + + form := url.Values{} + form.Set("id", "test1,test2,test3") + + err := r.ParseForm() + require.NoError(t, err) + require.Equal(t, form, r.Form) + }) + + _, err := client.PrivateMessage.Unread(ctx) + require.EqualError(t, err, "must provide at least 1 id") + + _, err = client.PrivateMessage.Unread(ctx, "test1", "test2", "test3") + require.NoError(t, err) +} + +func TestPrivateMessageService_Block(t *testing.T) { + setup() + defer teardown() + + mux.HandleFunc("/api/block", func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + + form := url.Values{} + form.Set("id", "test") + + err := r.ParseForm() + require.NoError(t, err) + require.Equal(t, form, r.Form) + }) + + _, err := client.PrivateMessage.Block(ctx, "test") + require.NoError(t, err) +} + +func TestPrivateMessageService_Collapse(t *testing.T) { + setup() + defer teardown() + + mux.HandleFunc("/api/collapse_message", func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + + form := url.Values{} + form.Set("id", "test1,test2,test3") + + err := r.ParseForm() + require.NoError(t, err) + require.Equal(t, form, r.Form) + }) + + _, err := client.PrivateMessage.Collapse(ctx) + require.EqualError(t, err, "must provide at least 1 id") + + _, err = client.PrivateMessage.Collapse(ctx, "test1", "test2", "test3") + require.NoError(t, err) +} + +func TestPrivateMessageService_Uncollapse(t *testing.T) { + setup() + defer teardown() + + mux.HandleFunc("/api/uncollapse_message", func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + + form := url.Values{} + form.Set("id", "test1,test2,test3") + + err := r.ParseForm() + require.NoError(t, err) + require.Equal(t, form, r.Form) + }) + + _, err := client.PrivateMessage.Uncollapse(ctx) + require.EqualError(t, err, "must provide at least 1 id") + + _, err = client.PrivateMessage.Uncollapse(ctx, "test1", "test2", "test3") + require.NoError(t, err) +} diff --git a/reddit.go b/reddit.go index 035858b..05e8ed6 100644 --- a/reddit.go +++ b/reddit.go @@ -91,17 +91,18 @@ type Client struct { // This is the client's user ID in Reddit's database. redditID string - Account *AccountService - Collection *CollectionService - Comment *CommentService - Emoji *EmojiService - Flair *FlairService - Listings *ListingsService - Moderation *ModerationService - Multi *MultiService - Post *PostService - Subreddit *SubredditService - User *UserService + Account *AccountService + Collection *CollectionService + Comment *CommentService + Emoji *EmojiService + Flair *FlairService + Listings *ListingsService + Moderation *ModerationService + Multi *MultiService + Post *PostService + PrivateMessage *PrivateMessageService + Subreddit *SubredditService + User *UserService oauth2Transport *oauth2.Transport @@ -147,6 +148,7 @@ func newClient(httpClient *http.Client) *Client { c.Listings = &ListingsService{client: c} c.Moderation = &ModerationService{client: c} c.Multi = &MultiService{client: c} + c.PrivateMessage = &PrivateMessageService{client: c} c.Subreddit = &SubredditService{client: c} c.User = &UserService{client: c} diff --git a/reddit_test.go b/reddit_test.go index 3c2f200..bf0b7cd 100644 --- a/reddit_test.go +++ b/reddit_test.go @@ -75,6 +75,7 @@ func testClientServices(t *testing.T, c *Client) { "Moderation", "Multi", "Post", + "PrivateMessage", "Subreddit", "User", }