Last active
May 23, 2024 22:36
-
-
Save maxsei/fef76d3306056fbdfea66fcd7fc30c49 to your computer and use it in GitHub Desktop.
streaming for EnTeRpRiSe
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package stream | |
import ( | |
"context" | |
"errors" | |
"sync" | |
) | |
// CancelableMessage is a request handed to Stream.Start: a payload plus the
// context the sender created for it and that context's cancel function.
// Start calls cancel to end the request context, optionally with a cause.
type CancelableMessage[T any] struct {
	// data is the request payload (a value, or a subscriber channel).
	data T
	// ctx bounds how long Start may spend handling the request.
	ctx context.Context
	// cancel ends ctx; a non-nil cause describes why the request failed.
	cancel context.CancelCauseFunc
}
func New[T any](ctx context.Context) *Stream[T] { | |
res := Stream[T]{ | |
next: make(chan CancelableMessage[T]), | |
subscribe: make(chan CancelableMessage[chan T]), | |
unsubscribe: make(chan CancelableMessage[(<-chan T)]), | |
} | |
res.ctx, res.cancel = context.WithCancel(ctx) | |
return &res | |
} | |
type Stream[T any] struct { | |
subscribers []chan T | |
ctx context.Context | |
cancel context.CancelFunc | |
state T | |
// Channels. | |
next chan CancelableMessage[T] | |
deref chan CancelableMessage[(<-chan T)] | |
subscribe chan CancelableMessage[chan T] | |
unsubscribe chan CancelableMessage[(<-chan T)] | |
} | |
func (s *Stream[T]) Start() { | |
processingEvents: | |
for { | |
select { | |
case <-s.ctx.Done(): | |
for _, sub := range s.subscribers { | |
close(sub) | |
} | |
s.subscribers = s.subscribers[:0] | |
return | |
case subscriber := <-s.unsubscribe: | |
i := s.findSubscriberIndex(subscriber.data) | |
if i == -1 { | |
subscriber.cancel(errors.New("subscriber not found")) | |
continue processingEvents | |
} | |
close(s.subscribers[i]) | |
s.subscribers = append(s.subscribers[:i], s.subscribers[i+1:]...) | |
case subscriber := <-s.subscribe: | |
// i := s.findSubscriberIndex(subscriber.data) | |
// if i != -1 { | |
// subscriber.cancel(errors.New("subscriber already exists")) | |
// return | |
// } | |
s.subscribers = append(s.subscribers, subscriber.data) | |
case message := <-s.next: | |
var wg sync.WaitGroup | |
wg.Add(len(s.subscribers)) | |
for i := range s.subscribers { | |
go func(i int) { | |
defer wg.Done() | |
select { | |
case s.subscribers[i] <- message.data: | |
case <-message.ctx.Done(): | |
// TODO: deal with slow consumers here... <16-05-24, Max Schulte> // | |
// TODO: I feel like consumers must have their own context too | |
// so that we can "continue" to other consumers instead of just | |
// returning when the message context has run out. We can have both. | |
// Perhaps the unsubscribe method can be part of a consumer object | |
// as well as context expiration for explicit and implicit | |
// unsubscriptions. <17-05-24, Max Schulte> // | |
return | |
} | |
}(i) | |
} | |
wg.Wait() | |
case subscriber := <-s.deref: | |
i := s.findSubscriberIndex(subscriber.data) | |
if i == -1 { | |
subscriber.cancel(errors.New("subscriber not found")) | |
continue processingEvents | |
} | |
select { | |
case s.subscribers[i] <- s.state: | |
case <-subscriber.ctx.Done(): | |
// TODO: deal with slow consumers here... <16-05-24, Max Schulte> // | |
continue processingEvents | |
} | |
} | |
} | |
} | |
func (s *Stream[T]) Ctx() context.Context { return s.ctx } | |
func (s *Stream[T]) Close() { s.cancel() } | |
func (s *Stream[T]) findSubscriberIndex(subscriber <-chan T) int { | |
for i := range s.subscribers { | |
if s.subscribers[i] == subscriber { | |
return i | |
} | |
} | |
return -1 | |
} | |
func sendMsg[T any](parent context.Context, message T, ch chan CancelableMessage[T]) error { | |
ctx, cancel := context.WithCancelCause(parent) | |
// NB: Calling cancel is delegated to the Start() method for error | |
// handling/control flow. | |
select { | |
case <-ctx.Done(): | |
return ctx.Err() | |
case ch <- CancelableMessage[T]{ctx: ctx, cancel: cancel, data: message}: | |
} | |
return nil | |
} | |
func (s *Stream[T]) Next(ctx context.Context, message T) error { | |
return sendMsg(ctx, message, s.next) | |
} | |
func (s *Stream[T]) Deref(ctx context.Context, subscriber <-chan T) error { | |
return sendMsg(ctx, subscriber, s.deref) | |
} | |
func (s *Stream[T]) Subscribe(ctx context.Context) (<-chan T, error) { | |
subscriber := make(chan T) | |
if err := sendMsg(ctx, subscriber, s.subscribe); err != nil { | |
return nil, err | |
} | |
return subscriber, nil | |
} | |
func (s *Stream[T]) Unsubscribe(ctx context.Context, subscriber <-chan T) error { | |
return sendMsg(ctx, subscriber, s.unsubscribe) | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package stream | |
import ( | |
"context" | |
"sync" | |
"sync/atomic" | |
"testing" | |
) | |
func TestSingleProducerSingleConsumer(t *testing.T) { | |
// Setup stream. | |
s := New[int](context.Background()) | |
go s.Start() | |
// Consumer | |
consumer, err := s.Subscribe(context.Background()) | |
if err != nil { | |
t.Error(err) | |
} | |
// Producer | |
go func() { | |
for i := 0; i < 32; i++ { | |
if err := s.Next(context.Background(), i); err != nil { | |
t.Error(err) | |
} | |
} | |
s.Close() | |
}() | |
// Listen to consumer. | |
for message := range consumer { | |
t.Log(message) | |
} | |
} | |
func TestSingleProducerMultipeConsumer(t *testing.T) { | |
// Setup stream. | |
s := New[int](context.Background()) | |
go s.Start() | |
const ExpectedMessageCount = 32 | |
// Setup consumers. | |
consumers := make([]<-chan int, 16) | |
for i := range consumers { | |
var err error | |
consumers[i], err = s.Subscribe(context.Background()) | |
if err != nil { | |
t.Error(err) | |
} | |
} | |
var consumerReportsCount int64 | |
// Listen to consumers and make sure they get all the messages. | |
var wg sync.WaitGroup | |
wg.Add(len(consumers)) | |
for id := range consumers { | |
go func(id int) { | |
consumer := consumers[id] | |
var actualMessageCount int | |
for range consumer { | |
actualMessageCount += 1 | |
} | |
if actualMessageCount != ExpectedMessageCount { | |
t.Errorf("consumer %03d: expected %d got %d", id, ExpectedMessageCount, actualMessageCount) | |
} | |
atomic.AddInt64(&consumerReportsCount, 1) | |
wg.Done() | |
}(id) | |
} | |
// Produce all values wait for consumers to receive them and close the stream. | |
for i := range make([]struct{}, ExpectedMessageCount) { | |
if err := s.Next(context.Background(), i); err != nil { | |
t.Error(err) | |
} | |
} | |
s.Close() | |
wg.Wait() | |
// Ensure that all consumers have reported their counts. | |
if int(consumerReportsCount) != len(consumers) { | |
t.Fatal("did not get a report from all consumers") | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment