diff --git a/reddit/reddit.go b/reddit/reddit.go index c44e241..25773ae 100644 --- a/reddit/reddit.go +++ b/reddit/reddit.go @@ -102,6 +102,7 @@ type Client struct { Moderation *ModerationService Multi *MultiService Post *PostService + Stream *StreamService Subreddit *SubredditService User *UserService @@ -151,6 +152,7 @@ func newClient(httpClient *http.Client) *Client { c.Message = &MessageService{client: c} c.Moderation = &ModerationService{client: c} c.Multi = &MultiService{client: c} + c.Stream = &StreamService{client: c} c.Subreddit = &SubredditService{client: c} c.User = &UserService{client: c} diff --git a/reddit/reddit_test.go b/reddit/reddit_test.go index 30812c8..442a1a6 100644 --- a/reddit/reddit_test.go +++ b/reddit/reddit_test.go @@ -77,6 +77,7 @@ func testClientServices(t *testing.T, c *Client) { "Moderation", "Multi", "Post", + "Stream", "Subreddit", "User", } diff --git a/reddit/stream.go b/reddit/stream.go new file mode 100644 index 0000000..634feb1 --- /dev/null +++ b/reddit/stream.go @@ -0,0 +1,93 @@ +package reddit + +import ( + "context" + "sync" + "time" +) + +// StreamService allows streaming new content from Reddit as it appears. +type StreamService struct { + client *Client +} + +// Posts streams posts from the specified subreddit. +// It returns 2 channels and a function: +// - a channel into which new posts will be sent +// - a channel into which any errors will be sent +// - a function that the client can call once to stop the streaming and close the channels +// Because of the 100 post limit imposed by Reddit when fetching posts, some high-traffic +// streams might drop submissions between API requests, such as when streaming r/all. +func (s *StreamService) Posts(subreddit string) (<-chan *Post, <-chan error, func()) { + posts := make(chan *Post) + errs := make(chan error) + ticker := time.NewTicker(time.Second * 5) + + // originally used the "before" parameter, but if that post gets deleted, subsequent requests + // would just return empty listings; easier to just keep track of all post ids encountered + ids := set{} + + go func() { + for ; ; <-ticker.C { + result, err := s.getPosts(subreddit) + if err != nil { + errs <- err + continue + } + + for _, post := range result.Posts { + id := post.FullID + + // if this post id is already part of the set, it means that it and the ones + // after it in the list have already been streamed, so break out of the loop + if ids.Exists(id) { + break + } + + ids.Add(id) + posts <- post + } + } + }() + + var once sync.Once + return posts, errs, func() { + once.Do(func() { + ticker.Stop() + close(posts) + close(errs) + }) + } +} + +func (s *StreamService) getPosts(subreddit string) (*Posts, error) { + opts := &ListOptions{ + Limit: 100, + } + + result, _, err := s.client.Subreddit.NewPosts(context.Background(), subreddit, opts) + if err != nil { + return nil, err + } + + return result, nil +} + +type set map[string]struct{} + +func (s set) Add(v string) { + s[v] = struct{}{} +} + +func (s set) Delete(v string) { + delete(s, v) +} + +func (s set) Len() int { + return len(s) +} + +func (s set) Exists(v string) bool { + _, ok := s[v] + return ok +} diff --git a/reddit/subreddit.go b/reddit/subreddit.go index c6d47f7..70f9e79 100644 --- a/reddit/subreddit.go +++ b/reddit/subreddit.go @@ -63,7 +63,8 @@ type Bans struct { Before string `json:"before"` } -func (s *SubredditService) getPosts(ctx context.Context, sort string, subreddit string, opts *ListPostOptions) (*Posts, *Response, error) { +// todo: interface{}, seriously? +func (s *SubredditService) getPosts(ctx context.Context, sort string, subreddit string, opts interface{}) (*Posts, *Response, error) { path := sort if subreddit != "" { path = fmt.Sprintf("r/%s/%s", subreddit, sort) @@ -95,7 +96,7 @@ func (s *SubredditService) getPosts(ctx context.Context, sort string, subreddit // To search through all and filter out subreddits, provide "all-name1-name2". // Note: when looking for hot posts in a subreddit, it will include the stickied // posts (if any) PLUS posts from the limit parameter (25 by default). -func (s *SubredditService) HotPosts(ctx context.Context, subreddit string, opts *ListPostOptions) (*Posts, *Response, error) { +func (s *SubredditService) HotPosts(ctx context.Context, subreddit string, opts *ListOptions) (*Posts, *Response, error) { return s.getPosts(ctx, "hot", subreddit, opts) } @@ -104,7 +105,7 @@ func (s *SubredditService) HotPosts(ctx context.Context, subreddit string, opts // If none are defined, it returns the ones from your subscribed subreddits. // To search through all, just specify "all". // To search through all and filter out subreddits, provide "all-name1-name2". -func (s *SubredditService) NewPosts(ctx context.Context, subreddit string, opts *ListPostOptions) (*Posts, *Response, error) { +func (s *SubredditService) NewPosts(ctx context.Context, subreddit string, opts *ListOptions) (*Posts, *Response, error) { return s.getPosts(ctx, "new", subreddit, opts) } @@ -113,7 +114,7 @@ func (s *SubredditService) NewPosts(ctx context.Context, subreddit string, opts // If none are defined, it returns the ones from your subscribed subreddits. // To search through all, just specify "all". // To search through all and filter out subreddits, provide "all-name1-name2". -func (s *SubredditService) RisingPosts(ctx context.Context, subreddit string, opts *ListPostOptions) (*Posts, *Response, error) { +func (s *SubredditService) RisingPosts(ctx context.Context, subreddit string, opts *ListOptions) (*Posts, *Response, error) { return s.getPosts(ctx, "rising", subreddit, opts) }