Add streamer config, tests

Signed-off-by: Vartan Benohanian <vartanbeno@gmail.com>
This commit is contained in:
Vartan Benohanian 2020-08-22 14:26:50 -04:00
parent 49fa672619
commit da689412a1
3 changed files with 401 additions and 21 deletions

View File

@ -18,20 +18,48 @@ type StreamService struct {
// - a function that the client can call once to stop the streaming and close the channels
// Because of the 100 post limit imposed by Reddit when fetching posts, some high-traffic
// streams might drop submissions between API requests, such as when streaming r/all.
func (s *StreamService) Posts(subreddit string) (<-chan *Post, <-chan error, func()) {
func (s *StreamService) Posts(subreddit string, opts ...StreamOpt) (<-chan *Post, <-chan error, func()) {
streamConfig := &streamConfig{
Interval: defaultStreamInterval,
DiscardInitial: false,
MaxRequests: 0,
}
for _, opt := range opts {
opt(streamConfig)
}
ticker := time.NewTicker(streamConfig.Interval)
posts := make(chan *Post)
errs := make(chan error)
ticker := time.NewTicker(time.Second * 5)
var once sync.Once
stop := func() {
once.Do(func() {
ticker.Stop()
close(posts)
close(errs)
})
}
// originally used the "before" parameter, but if that post gets deleted, subsequent requests
// would just return empty listings; easier to just keep track of all post ids encountered
ids := set{}
go func() {
defer stop()
var n int
infinite := streamConfig.MaxRequests == 0
for ; ; <-ticker.C {
n++
result, err := s.getPosts(subreddit)
if err != nil {
errs <- err
if !infinite && n >= streamConfig.MaxRequests {
break
}
continue
}
@ -43,34 +71,28 @@ func (s *StreamService) Posts(subreddit string) (<-chan *Post, <-chan error, fun
if ids.Exists(id) {
break
}
ids.Add(id)
if streamConfig.DiscardInitial {
streamConfig.DiscardInitial = false
break
}
posts <- post
}
if !infinite && n >= streamConfig.MaxRequests {
break
}
}
}()
var once sync.Once
return posts, errs, func() {
once.Do(func() {
ticker.Stop()
close(posts)
close(errs)
})
}
return posts, errs, stop
}
func (s *StreamService) getPosts(subreddit string) (*Posts, error) {
opts := &ListOptions{
Limit: 100,
}
result, _, err := s.client.Subreddit.NewPosts(context.Background(), subreddit, opts)
if err != nil {
return nil, err
}
return result, nil
result, _, err := s.client.Subreddit.NewPosts(context.Background(), subreddit, &ListOptions{Limit: 100})
return result, err
}
type set map[string]struct{}

314
reddit/stream_test.go Normal file
View File

@ -0,0 +1,314 @@
package reddit
import (
"fmt"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestStreamService_Posts(t *testing.T) {
setup()
defer teardown()
var counter int
mux.HandleFunc("/r/testsubreddit/new", func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
defer func() { counter++ }()
switch counter {
case 0:
fmt.Fprint(w, `{
"kind": "Listing",
"data": {
"children": [
{
"kind": "t3",
"data": {
"name": "t3_post1"
}
},
{
"kind": "t3",
"data": {
"name": "t3_post2"
}
}
]
}
}`)
case 1:
fmt.Fprint(w, `{
"kind": "Listing",
"data": {
"children": [
{
"kind": "t3",
"data": {
"name": "t3_post3"
}
},
{
"kind": "t3",
"data": {
"name": "t3_post1"
}
}
]
}
}`)
case 2:
fmt.Fprint(w, `{
"kind": "Listing",
"data": {
"children": [
{
"kind": "t3",
"data": {
"name": "t3_post4"
}
},
{
"kind": "t3",
"data": {
"name": "t3_post5"
}
},
{
"kind": "t3",
"data": {
"name": "t3_post6"
}
}
]
}
}`)
case 3:
fmt.Fprint(w, `{
"kind": "Listing",
"data": {
"children": [
{
"kind": "t3",
"data": {
"name": "t3_post7"
}
},
{
"kind": "t3",
"data": {
"name": "t3_post8"
}
},
{
"kind": "t3",
"data": {
"name": "t3_post9"
}
},
{
"kind": "t3",
"data": {
"name": "t3_post10"
}
},
{
"kind": "t3",
"data": {
"name": "t3_post11"
}
},
{
"kind": "t3",
"data": {
"name": "t3_post12"
}
}
]
}
}`)
default:
fmt.Fprint(w, `{}`)
}
})
posts, errs, stop := client.Stream.Posts("testsubreddit", StreamInterval(time.Millisecond*10), StreamMaxRequests(4))
defer stop()
expectedPostIDs := []string{"t3_post1", "t3_post2", "t3_post3", "t3_post4", "t3_post5", "t3_post6", "t3_post7", "t3_post8", "t3_post9", "t3_post10", "t3_post11", "t3_post12"}
var i int
loop:
for i != len(expectedPostIDs) {
select {
case post, ok := <-posts:
if !ok {
break loop
}
require.Equal(t, expectedPostIDs[i], post.FullID)
case err, ok := <-errs:
if !ok {
break loop
}
require.NoError(t, err)
}
i++
}
require.Len(t, expectedPostIDs, i)
}
func TestStreamService_Posts_DiscardInitial(t *testing.T) {
setup()
defer teardown()
var counter int
mux.HandleFunc("/r/testsubreddit/new", func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method)
defer func() { counter++ }()
switch counter {
case 0:
fmt.Fprint(w, `{
"kind": "Listing",
"data": {
"children": [
{
"kind": "t3",
"data": {
"name": "t3_post1"
}
},
{
"kind": "t3",
"data": {
"name": "t3_post2"
}
}
]
}
}`)
case 1:
fmt.Fprint(w, `{
"kind": "Listing",
"data": {
"children": [
{
"kind": "t3",
"data": {
"name": "t3_post3"
}
},
{
"kind": "t3",
"data": {
"name": "t3_post1"
}
}
]
}
}`)
case 2:
fmt.Fprint(w, `{
"kind": "Listing",
"data": {
"children": [
{
"kind": "t3",
"data": {
"name": "t3_post4"
}
},
{
"kind": "t3",
"data": {
"name": "t3_post5"
}
},
{
"kind": "t3",
"data": {
"name": "t3_post6"
}
}
]
}
}`)
case 3:
fmt.Fprint(w, `{
"kind": "Listing",
"data": {
"children": [
{
"kind": "t3",
"data": {
"name": "t3_post7"
}
},
{
"kind": "t3",
"data": {
"name": "t3_post8"
}
},
{
"kind": "t3",
"data": {
"name": "t3_post9"
}
},
{
"kind": "t3",
"data": {
"name": "t3_post10"
}
},
{
"kind": "t3",
"data": {
"name": "t3_post11"
}
},
{
"kind": "t3",
"data": {
"name": "t3_post12"
}
}
]
}
}`)
default:
fmt.Fprint(w, `{}`)
}
})
posts, errs, stop := client.Stream.Posts("testsubreddit", StreamInterval(time.Millisecond*10), StreamMaxRequests(4), StreamDiscardInitial)
defer stop()
expectedPostIDs := []string{"t3_post3", "t3_post4", "t3_post5", "t3_post6", "t3_post7", "t3_post8", "t3_post9", "t3_post10", "t3_post11", "t3_post12"}
var i int
loop:
for i != len(expectedPostIDs) {
select {
case post, ok := <-posts:
if !ok {
break loop
}
require.Equal(t, expectedPostIDs[i], post.FullID)
case err, ok := <-errs:
if !ok {
break loop
}
require.NoError(t, err)
}
i++
}
require.Len(t, expectedPostIDs, i)
}

44
reddit/streamer.go Normal file
View File

@ -0,0 +1,44 @@
package reddit
import "time"
const defaultStreamInterval = time.Second * 5
type streamConfig struct {
Interval time.Duration
DiscardInitial bool
MaxRequests int
}
// StreamOpt is a configuration option to configure a stream.
type StreamOpt func(*streamConfig)
// StreamInterval sets the frequency at which data will be fetched for the stream.
// If the duration is 0 or less, it will not be set and the default will be used.
func StreamInterval(v time.Duration) StreamOpt {
return func(c *streamConfig) {
if v > 0 {
c.Interval = v
}
}
}
// StreamDiscardInitial will discard data from the first fetch for the stream.
func StreamDiscardInitial(c *streamConfig) {
c.DiscardInitial = true
}
// StreamMaxRequests sets a limit on the number of times data is fetched for a stream.
// If less than or equal to 0, it is assumed to be infinite.
func StreamMaxRequests(v int) StreamOpt {
return func(c *streamConfig) {
if v > 0 {
c.MaxRequests = v
}
}
}
// Streamer streams data to the client.
// type Streamer interface {
// Stream() (<-chan *rootListing, <-chan error, func())
// }