diff --git a/reddit/stream.go b/reddit/stream.go index 634feb1..0f75f0e 100644 --- a/reddit/stream.go +++ b/reddit/stream.go @@ -18,20 +18,48 @@ type StreamService struct { // - a function that the client can call once to stop the streaming and close the channels // Because of the 100 post limit imposed by Reddit when fetching posts, some high-traffic // streams might drop submissions between API requests, such as when streaming r/all. -func (s *StreamService) Posts(subreddit string) (<-chan *Post, <-chan error, func()) { +func (s *StreamService) Posts(subreddit string, opts ...StreamOpt) (<-chan *Post, <-chan error, func()) { + streamConfig := &streamConfig{ + Interval: defaultStreamInterval, + DiscardInitial: false, + MaxRequests: 0, + } + for _, opt := range opts { + opt(streamConfig) + } + + ticker := time.NewTicker(streamConfig.Interval) posts := make(chan *Post) errs := make(chan error) - ticker := time.NewTicker(time.Second * 5) + + var once sync.Once + stop := func() { + once.Do(func() { + ticker.Stop() + close(posts) + close(errs) + }) + } // originally used the "before" parameter, but if that post gets deleted, subsequent requests // would just return empty listings; easier to just keep track of all post ids encountered ids := set{} go func() { + defer stop() + + var n int + infinite := streamConfig.MaxRequests == 0 + for ; ; <-ticker.C { + n++ + result, err := s.getPosts(subreddit) if err != nil { errs <- err + if !infinite && n >= streamConfig.MaxRequests { + break + } continue } @@ -43,34 +71,28 @@ func (s *StreamService) Posts(subreddit string) (<-chan *Post, <-chan error, fun if ids.Exists(id) { break } - ids.Add(id) + + if streamConfig.DiscardInitial { + streamConfig.DiscardInitial = false + break + } + posts <- post } + + if !infinite && n >= streamConfig.MaxRequests { + break + } } }() - var once sync.Once - return posts, errs, func() { - once.Do(func() { - ticker.Stop() - close(posts) - close(errs) - }) - } + return posts, errs, stop } func (s *StreamService) getPosts(subreddit string) (*Posts, error) { - opts := &ListOptions{ - Limit: 100, - } - - result, _, err := s.client.Subreddit.NewPosts(context.Background(), subreddit, opts) - if err != nil { - return nil, err - } - - return result, nil + result, _, err := s.client.Subreddit.NewPosts(context.Background(), subreddit, &ListOptions{Limit: 100}) + return result, err } type set map[string]struct{} diff --git a/reddit/stream_test.go b/reddit/stream_test.go new file mode 100644 index 0000000..35c7c1d --- /dev/null +++ b/reddit/stream_test.go @@ -0,0 +1,314 @@ +package reddit + +import ( + "fmt" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestStreamService_Posts(t *testing.T) { + setup() + defer teardown() + + var counter int + + mux.HandleFunc("/r/testsubreddit/new", func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + defer func() { counter++ }() + + switch counter { + case 0: + fmt.Fprint(w, `{ + "kind": "Listing", + "data": { + "children": [ + { + "kind": "t3", + "data": { + "name": "t3_post1" + } + }, + { + "kind": "t3", + "data": { + "name": "t3_post2" + } + } + ] + } + }`) + case 1: + fmt.Fprint(w, `{ + "kind": "Listing", + "data": { + "children": [ + { + "kind": "t3", + "data": { + "name": "t3_post3" + } + }, + { + "kind": "t3", + "data": { + "name": "t3_post1" + } + } + ] + } + }`) + case 2: + fmt.Fprint(w, `{ + "kind": "Listing", + "data": { + "children": [ + { + "kind": "t3", + "data": { + "name": "t3_post4" + } + }, + { + "kind": "t3", + "data": { + "name": "t3_post5" + } + }, + { + "kind": "t3", + "data": { + "name": "t3_post6" + } + } + ] + } + }`) + case 3: + fmt.Fprint(w, `{ + "kind": "Listing", + "data": { + "children": [ + { + "kind": "t3", + "data": { + "name": "t3_post7" + } + }, + { + "kind": "t3", + "data": { + "name": "t3_post8" + } + }, + { + "kind": "t3", + "data": { + "name": "t3_post9" + } + }, + { + "kind": "t3", + "data": { + "name": "t3_post10" + } + }, + { + "kind": "t3", + "data": { + "name": "t3_post11" + } + }, + { + "kind": "t3", + "data": { + "name": "t3_post12" + } + } + ] + } + }`) + default: + fmt.Fprint(w, `{}`) + } + }) + + posts, errs, stop := client.Stream.Posts("testsubreddit", StreamInterval(time.Millisecond*10), StreamMaxRequests(4)) + defer stop() + + expectedPostIDs := []string{"t3_post1", "t3_post2", "t3_post3", "t3_post4", "t3_post5", "t3_post6", "t3_post7", "t3_post8", "t3_post9", "t3_post10", "t3_post11", "t3_post12"} + var i int + +loop: + for i != len(expectedPostIDs) { + select { + case post, ok := <-posts: + if !ok { + break loop + } + require.Equal(t, expectedPostIDs[i], post.FullID) + case err, ok := <-errs: + if !ok { + break loop + } + require.NoError(t, err) + } + i++ + } + + require.Len(t, expectedPostIDs, i) +} + +func TestStreamService_Posts_DiscardInitial(t *testing.T) { + setup() + defer teardown() + + var counter int + + mux.HandleFunc("/r/testsubreddit/new", func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + defer func() { counter++ }() + + switch counter { + case 0: + fmt.Fprint(w, `{ + "kind": "Listing", + "data": { + "children": [ + { + "kind": "t3", + "data": { + "name": "t3_post1" + } + }, + { + "kind": "t3", + "data": { + "name": "t3_post2" + } + } + ] + } + }`) + case 1: + fmt.Fprint(w, `{ + "kind": "Listing", + "data": { + "children": [ + { + "kind": "t3", + "data": { + "name": "t3_post3" + } + }, + { + "kind": "t3", + "data": { + "name": "t3_post1" + } + } + ] + } + }`) + case 2: + fmt.Fprint(w, `{ + "kind": "Listing", + "data": { + "children": [ + { + "kind": "t3", + "data": { + "name": "t3_post4" + } + }, + { + "kind": "t3", + "data": { + "name": "t3_post5" + } + }, + { + "kind": "t3", + "data": { + "name": "t3_post6" + } + } + ] + } + }`) + case 3: + fmt.Fprint(w, `{ + "kind": "Listing", + "data": { + "children": [ + { + "kind": "t3", + "data": { + "name": "t3_post7" + } + }, + { + "kind": "t3", + "data": { + "name": "t3_post8" + } + }, + { + "kind": "t3", + "data": { + "name": "t3_post9" + } + }, + { + "kind": "t3", + "data": { + "name": "t3_post10" + } + }, + { + "kind": "t3", + "data": { + "name": "t3_post11" + } + }, + { + "kind": "t3", + "data": { + "name": "t3_post12" + } + } + ] + } + }`) + default: + fmt.Fprint(w, `{}`) + } + }) + + posts, errs, stop := client.Stream.Posts("testsubreddit", StreamInterval(time.Millisecond*10), StreamMaxRequests(4), StreamDiscardInitial) + defer stop() + + expectedPostIDs := []string{"t3_post3", "t3_post4", "t3_post5", "t3_post6", "t3_post7", "t3_post8", "t3_post9", "t3_post10", "t3_post11", "t3_post12"} + var i int + +loop: + for i != len(expectedPostIDs) { + select { + case post, ok := <-posts: + if !ok { + break loop + } + require.Equal(t, expectedPostIDs[i], post.FullID) + case err, ok := <-errs: + if !ok { + break loop + } + require.NoError(t, err) + } + i++ + } + + require.Len(t, expectedPostIDs, i) +} diff --git a/reddit/streamer.go b/reddit/streamer.go new file mode 100644 index 0000000..2216ab8 --- /dev/null +++ b/reddit/streamer.go @@ -0,0 +1,44 @@ +package reddit + +import "time" + +const defaultStreamInterval = time.Second * 5 + +type streamConfig struct { + Interval time.Duration + DiscardInitial bool + MaxRequests int +} + +// StreamOpt is a configuration option to configure a stream. +type StreamOpt func(*streamConfig) + +// StreamInterval sets the frequency at which data will be fetched for the stream. +// If the duration is 0 or less, it will not be set and the default will be used. +func StreamInterval(v time.Duration) StreamOpt { + return func(c *streamConfig) { + if v > 0 { + c.Interval = v + } + } +} + +// StreamDiscardInitial will discard data from the first fetch for the stream. +func StreamDiscardInitial(c *streamConfig) { + c.DiscardInitial = true +} + +// StreamMaxRequests sets a limit on the number of times data is fetched for a stream. +// If less than or equal to 0, it is assumed to be infinite. +func StreamMaxRequests(v int) StreamOpt { + return func(c *streamConfig) { + if v > 0 { + c.MaxRequests = v + } + } +} + +// Streamer streams data to the client. +// type Streamer interface { +// Stream() (<-chan *rootListing, <-chan error, func()) +// }