From 291e18099050ff9e19b8ee25c2ffad68d9baafef Mon Sep 17 00:00:00 2001 From: kim <89579420+NyaaaWhatsUpDoc@users.noreply.github.com> Date: Tue, 20 Feb 2024 18:07:49 +0000 Subject: [PATCH] [bugfix] fix possible mutex lockup during streaming code (#2633) * rewrite Stream{} to use much less mutex locking, update related code * use new context for the stream context * ensure stream gets closed on return of writeTo / readFrom WSConn() * ensure stream write timeout gets cancelled * remove embedded context type from Stream{}, reformat log messages for consistency * use c.Request.Context() for context passed into Stream().Open() * only return 1 boolean, fix tests to expect multiple stream types in messages * changes to ping logic * further improved ping logic * don't export unused function types, update message sending to only include relevant stream type * ensure stream gets closed :facepalm: * update to error log on failed json marshal (instead of panic) * inverse websocket read error checking to _ignore_ expected close errors --- internal/api/client/streaming/stream.go | 243 +++++------ internal/processing/stream/delete.go | 34 +- internal/processing/stream/notification.go | 21 +- .../processing/stream/notification_test.go | 7 +- internal/processing/stream/open.go | 97 +---- internal/processing/stream/statusupdate.go | 21 +- .../processing/stream/statusupdate_test.go | 7 +- internal/processing/stream/stream.go | 46 +-- internal/processing/stream/update.go | 18 +- .../processing/workers/fromclientapi_test.go | 27 +- .../processing/workers/fromfediapi_test.go | 53 +-- internal/processing/workers/surfacenotify.go | 5 +- .../processing/workers/surfacetimeline.go | 18 +- internal/stream/stream.go | 389 +++++++++++++++--- 14 files changed, 535 insertions(+), 451 deletions(-) diff --git a/internal/api/client/streaming/stream.go b/internal/api/client/streaming/stream.go index 266b64976..8df4e9e76 100644 --- a/internal/api/client/streaming/stream.go +++ 
b/internal/api/client/streaming/stream.go @@ -22,10 +22,10 @@ "slices" "time" - "codeberg.org/gruf/go-kv" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/oauth" streampkg "github.com/superseriousbusiness/gotosocial/internal/stream" @@ -202,7 +202,7 @@ func (m *Module) StreamGETHandler(c *gin.Context) { // functions pass messages into a channel, which we can // then read from and put into a websockets connection. stream, errWithCode := m.processor.Stream().Open( - c.Request.Context(), + c.Request.Context(), // this ctx is only used for logging account, streamType, ) @@ -213,10 +213,8 @@ func (m *Module) StreamGETHandler(c *gin.Context) { l := log. WithContext(c.Request.Context()). - WithFields(kv.Fields{ - {"username", account.Username}, - {"streamID", stream.ID}, - }...) + WithField("streamID", id.NewULID()). + WithField("username", account.Username) // Upgrade the incoming HTTP request. This hijacks the // underlying connection and reuses it for the websocket @@ -227,18 +225,16 @@ func (m *Module) StreamGETHandler(c *gin.Context) { wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, nil) if err != nil { l.Errorf("error upgrading websocket connection: %v", err) - close(stream.Hangup) + stream.Close() return } - l.Info("opened websocket connection") - // We perform the main websocket rw loops in a separate // goroutine in order to let the upgrade handler return. // This prevents the upgrade handler from holding open any // throttle / rate-limit request tokens which could become // problematic on instances with multiple users. 
- go m.handleWSConn(account.Username, wsConn, stream) + go m.handleWSConn(&l, wsConn, stream) } // handleWSConn handles a two-way websocket streaming connection. @@ -246,48 +242,39 @@ func (m *Module) StreamGETHandler(c *gin.Context) { // into the connection. If any errors are encountered while reading // or writing (including expected errors like clients leaving), the // connection will be closed. -func (m *Module) handleWSConn(username string, wsConn *websocket.Conn, stream *streampkg.Stream) { - // Create new context for the lifetime of this connection. - ctx, cancel := context.WithCancel(context.Background()) +func (m *Module) handleWSConn(l *log.Entry, wsConn *websocket.Conn, stream *streampkg.Stream) { + l.Info("opened websocket connection") - l := log. - WithContext(ctx). - WithFields(kv.Fields{ - {"username", username}, - {"streamID", stream.ID}, - }...) + // Create new async context with cancel. + ctx, cncl := context.WithCancel(context.Background()) - // Create ticker to send keepalive pings - pinger := time.NewTicker(m.dTicker) - - // Read messages coming from the Websocket client connection into the server. go func() { - defer cancel() - m.readFromWSConn(ctx, username, wsConn, stream) + defer cncl() + + // Read messages from websocket to server. + m.readFromWSConn(ctx, wsConn, stream, l) }() - // Write messages coming from the processor into the Websocket client connection. go func() { - defer cancel() - m.writeToWSConn(ctx, username, wsConn, stream, pinger) + defer cncl() + + // Write messages from processor in websocket conn. + m.writeToWSConn(ctx, wsConn, stream, m.dTicker, l) }() - // Wait for either the read or write functions to close, to indicate - // that the client has left, or something else has gone wrong. + // Wait for ctx + // to be closed. <-ctx.Done() + // Close stream + // straightaway. + stream.Close() + // Tidy up underlying websocket connection. 
if err := wsConn.Close(); err != nil { l.Errorf("error closing websocket connection: %v", err) } - // Close processor channel so the processor knows - // not to send any more messages to this stream. - close(stream.Hangup) - - // Stop ping ticker (tiny resource saving). - pinger.Stop() - l.Info("closed websocket connection") } @@ -299,89 +286,64 @@ func (m *Module) handleWSConn(username string, wsConn *websocket.Conn, stream *s // if the given context is canceled. func (m *Module) readFromWSConn( ctx context.Context, - username string, wsConn *websocket.Conn, stream *streampkg.Stream, + l *log.Entry, ) { - l := log. - WithContext(ctx). - WithFields(kv.Fields{ - {"username", username}, - {"streamID", stream.ID}, - }...) -readLoop: for { - select { - case <-ctx.Done(): - // Connection closed. - break readLoop + var msg struct { + Type string `json:"type"` + Stream string `json:"stream"` + List string `json:"list,omitempty"` + } + // Read JSON objects from the client and act on them. + if err := wsConn.ReadJSON(&msg); err != nil { + // Only log an error if something weird happened. + // See: https://www.rfc-editor.org/rfc/rfc6455.html#section-11.7 + if !websocket.IsCloseError(err, []int{ + websocket.CloseNormalClosure, + websocket.CloseGoingAway, + websocket.CloseNoStatusReceived, + }...) { + l.Errorf("error during websocket read: %v", err) + } + + // The connection is gone; no + // further streaming possible. + break + } + + // Messages *from* the WS connection are infrequent + // and usually interesting, so log this at info. + l.Infof("received websocket message: %+v", msg) + + // Ignore if the updateStreamType is unknown (or missing), + // so a bad client can't cause extra memory allocations + if !slices.Contains(streampkg.AllStatusTimelines, msg.Stream) { + l.Warnf("unknown 'stream' field: %v", msg) + continue + } + + if msg.List != "" { + // If a list is given, add this to + // the stream name as this is how we + // track stream types internally. 
+ msg.Stream += ":" + msg.List + } + + switch msg.Type { + case "subscribe": + stream.Subscribe(msg.Stream) + case "unsubscribe": + stream.Unsubscribe(msg.Stream) default: - // Read JSON objects from the client and act on them. - var msg map[string]string - if err := wsConn.ReadJSON(&msg); err != nil { - // Only log an error if something weird happened. - // See: https://www.rfc-editor.org/rfc/rfc6455.html#section-11.7 - if websocket.IsUnexpectedCloseError(err, []int{ - websocket.CloseNormalClosure, - websocket.CloseGoingAway, - websocket.CloseNoStatusReceived, - }...) { - l.Errorf("error reading from websocket: %v", err) - } - - // The connection is gone; no - // further streaming possible. - break readLoop - } - - // Messages *from* the WS connection are infrequent - // and usually interesting, so log this at info. - l.Infof("received message from websocket: %v", msg) - - // If the message contains 'stream' and 'type' fields, we can - // update the set of timelines that are subscribed for events. 
- updateType, ok := msg["type"] - if !ok { - l.Warn("'type' field not provided") - continue - } - - updateStream, ok := msg["stream"] - if !ok { - l.Warn("'stream' field not provided") - continue - } - - // Ignore if the updateStreamType is unknown (or missing), - // so a bad client can't cause extra memory allocations - if !slices.Contains(streampkg.AllStatusTimelines, updateStream) { - l.Warnf("unknown 'stream' field: %v", msg) - continue - } - - updateList, ok := msg["list"] - if ok { - updateStream += ":" + updateList - } - - switch updateType { - case "subscribe": - stream.Lock() - stream.StreamTypes[updateStream] = true - stream.Unlock() - case "unsubscribe": - stream.Lock() - delete(stream.StreamTypes, updateStream) - stream.Unlock() - default: - l.Warnf("invalid 'type' field: %v", msg) - } + l.Warnf("invalid 'type' field: %v", msg) } } - l.Debug("finished reading from websocket connection") + l.Debug("finished websocket read") } // writeToWSConn receives messages coming from the processor via the @@ -393,46 +355,47 @@ func (m *Module) readFromWSConn( // if the given context is canceled. func (m *Module) writeToWSConn( ctx context.Context, - username string, wsConn *websocket.Conn, stream *streampkg.Stream, - pinger *time.Ticker, + ping time.Duration, + l *log.Entry, ) { - l := log. - WithContext(ctx). - WithFields(kv.Fields{ - {"username", username}, - {"streamID", stream.ID}, - }...) - -writeLoop: for { - select { - case <-ctx.Done(): - // Connection closed. - break writeLoop + // Wrap context with timeout to send a ping. + pingctx, cncl := context.WithTimeout(ctx, ping) - case msg := <-stream.Messages: - // Received a new message from the processor. - l.Tracef("writing message to websocket: %+v", msg) - if err := wsConn.WriteJSON(msg); err != nil { - l.Debugf("error writing json to websocket: %v", err) - break writeLoop - } + // Block on receipt of msg. 
+ msg, ok := stream.Recv(pingctx) - // Reset pinger on successful send, since - // we know the connection is still there. - pinger.Reset(m.dTicker) + // Check if cancel because ping. + pinged := (pingctx.Err() != nil) + cncl() - case <-pinger.C: - // Time to send a keep-alive "ping". - l.Trace("writing ping control message to websocket") + switch { + case !ok && pinged: + // The ping context timed out! + l.Trace("writing websocket ping") + + // Wrapped context time-out, send a keep-alive "ping". if err := wsConn.WriteControl(websocket.PingMessage, nil, time.Time{}); err != nil { - l.Debugf("error writing ping to websocket: %v", err) - break writeLoop + l.Debugf("error writing websocket ping: %v", err) + break } + + case !ok: + // Stream was + // closed. + return + } + + l.Trace("writing websocket message: %+v", msg) + + // Received a new message from the processor. + if err := wsConn.WriteJSON(msg); err != nil { + l.Debugf("error writing websocket message: %v", err) + break } } - l.Debug("finished writing to websocket connection") + l.Debug("finished websocket write") } diff --git a/internal/processing/stream/delete.go b/internal/processing/stream/delete.go index d7745eef8..1c61b98d3 100644 --- a/internal/processing/stream/delete.go +++ b/internal/processing/stream/delete.go @@ -18,38 +18,16 @@ package stream import ( - "fmt" - "strings" + "context" "github.com/superseriousbusiness/gotosocial/internal/stream" ) // Delete streams the delete of the given statusID to *ALL* open streams. 
-func (p *Processor) Delete(statusID string) error { - errs := []string{} - - // get all account IDs with open streams - accountIDs := []string{} - p.streamMap.Range(func(k interface{}, _ interface{}) bool { - key, ok := k.(string) - if !ok { - panic("streamMap key was not a string (account id)") - } - - accountIDs = append(accountIDs, key) - return true +func (p *Processor) Delete(ctx context.Context, statusID string) { + p.streams.PostAll(ctx, stream.Message{ + Payload: statusID, + Event: stream.EventTypeDelete, + Stream: stream.AllStatusTimelines, }) - - // stream the delete to every account - for _, accountID := range accountIDs { - if err := p.toAccount(statusID, stream.EventTypeDelete, stream.AllStatusTimelines, accountID); err != nil { - errs = append(errs, err.Error()) - } - } - - if len(errs) != 0 { - return fmt.Errorf("one or more errors streaming status delete: %s", strings.Join(errs, ";")) - } - - return nil } diff --git a/internal/processing/stream/notification.go b/internal/processing/stream/notification.go index 63d7c5d11..a16da11e6 100644 --- a/internal/processing/stream/notification.go +++ b/internal/processing/stream/notification.go @@ -18,20 +18,29 @@ package stream import ( + "context" "encoding/json" - "fmt" + "codeberg.org/gruf/go-byteutil" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/stream" ) // Notify streams the given notification to any open, appropriate streams belonging to the given account. 
-func (p *Processor) Notify(n *apimodel.Notification, account *gtsmodel.Account) error { - bytes, err := json.Marshal(n) +func (p *Processor) Notify(ctx context.Context, account *gtsmodel.Account, notif *apimodel.Notification) { + b, err := json.Marshal(notif) if err != nil { - return fmt.Errorf("error marshalling notification to json: %s", err) + log.Errorf(ctx, "error marshaling json: %v", err) + return } - - return p.toAccount(string(bytes), stream.EventTypeNotification, []string{stream.TimelineNotifications, stream.TimelineHome}, account.ID) + p.streams.Post(ctx, account.ID, stream.Message{ + Payload: byteutil.B2S(b), + Event: stream.EventTypeNotification, + Stream: []string{ + stream.TimelineNotifications, + stream.TimelineHome, + }, + }) } diff --git a/internal/processing/stream/notification_test.go b/internal/processing/stream/notification_test.go index 2138f0025..e12f23abe 100644 --- a/internal/processing/stream/notification_test.go +++ b/internal/processing/stream/notification_test.go @@ -49,10 +49,11 @@ func (suite *NotificationTestSuite) TestStreamNotification() { Account: followAccountAPIModel, } - err = suite.streamProcessor.Notify(notification, account) - suite.NoError(err) + suite.streamProcessor.Notify(context.Background(), account, notification) + + msg, ok := openStream.Recv(context.Background()) + suite.True(ok) - msg := <-openStream.Messages dst := new(bytes.Buffer) err = json.Indent(dst, []byte(msg.Payload), "", " ") suite.NoError(err) diff --git a/internal/processing/stream/open.go b/internal/processing/stream/open.go index 1c041309f..2f2bbd4a3 100644 --- a/internal/processing/stream/open.go +++ b/internal/processing/stream/open.go @@ -19,13 +19,10 @@ import ( "context" - "errors" - "fmt" "codeberg.org/gruf/go-kv" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/id" 
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/stream" ) @@ -37,97 +34,5 @@ func (p *Processor) Open(ctx context.Context, account *gtsmodel.Account, streamT {"streamType", streamType}, }...) l.Debug("received open stream request") - - var ( - streamID string - err error - ) - - // Each stream needs a unique ID so we know to close it. - streamID, err = id.NewRandomULID() - if err != nil { - return nil, gtserror.NewErrorInternalError(fmt.Errorf("error generating stream id: %w", err)) - } - - // Each stream can be subscibed to multiple types. - // Record them in a set, and include the initial one - // if it was given to us. - streamTypes := map[string]any{} - if streamType != "" { - streamTypes[streamType] = true - } - - newStream := &stream.Stream{ - ID: streamID, - StreamTypes: streamTypes, - Messages: make(chan *stream.Message, 100), - Hangup: make(chan interface{}, 1), - Connected: true, - } - go p.waitToCloseStream(account, newStream) - - v, ok := p.streamMap.Load(account.ID) - if ok { - // There is an entry in the streamMap - // for this account. Parse it out. - streamsForAccount, ok := v.(*stream.StreamsForAccount) - if !ok { - return nil, gtserror.NewErrorInternalError(errors.New("stream map error")) - } - - // Append new stream to existing entry. - streamsForAccount.Lock() - streamsForAccount.Streams = append(streamsForAccount.Streams, newStream) - streamsForAccount.Unlock() - } else { - // There is no entry in the streamMap for - // this account yet. Create one and store it. - p.streamMap.Store(account.ID, &stream.StreamsForAccount{ - Streams: []*stream.Stream{ - newStream, - }, - }) - } - - return newStream, nil -} - -// waitToCloseStream waits until the hangup channel is closed for the given stream. 
-// It then iterates through the map of streams stored by the processor, removes the stream from it, -// and then closes the messages channel of the stream to indicate that the channel should no longer be read from. -func (p *Processor) waitToCloseStream(account *gtsmodel.Account, thisStream *stream.Stream) { - <-thisStream.Hangup // wait for a hangup message - - // lock the stream to prevent more messages being put in it while we work - thisStream.Lock() - defer thisStream.Unlock() - - // indicate the stream is no longer connected - thisStream.Connected = false - - // load and parse the entry for this account from the stream map - v, ok := p.streamMap.Load(account.ID) - if !ok || v == nil { - return - } - streamsForAccount, ok := v.(*stream.StreamsForAccount) - if !ok { - return - } - - // lock the streams for account while we remove this stream from its slice - streamsForAccount.Lock() - defer streamsForAccount.Unlock() - - // put everything into modified streams *except* the stream we're removing - modifiedStreams := []*stream.Stream{} - for _, s := range streamsForAccount.Streams { - if s.ID != thisStream.ID { - modifiedStreams = append(modifiedStreams, s) - } - } - streamsForAccount.Streams = modifiedStreams - - // finally close the messages channel so no more messages can be read from it - close(thisStream.Messages) + return p.streams.Open(account.ID, streamType), nil } diff --git a/internal/processing/stream/statusupdate.go b/internal/processing/stream/statusupdate.go index fd8e388ce..bd4658873 100644 --- a/internal/processing/stream/statusupdate.go +++ b/internal/processing/stream/statusupdate.go @@ -18,21 +18,26 @@ package stream import ( + "context" "encoding/json" - "fmt" + "codeberg.org/gruf/go-byteutil" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" 
"github.com/superseriousbusiness/gotosocial/internal/stream" ) -// StatusUpdate streams the given edited status to any open, appropriate -// streams belonging to the given account. -func (p *Processor) StatusUpdate(s *apimodel.Status, account *gtsmodel.Account, streamTypes []string) error { - bytes, err := json.Marshal(s) +// StatusUpdate streams the given edited status to any open, appropriate streams belonging to the given account. +func (p *Processor) StatusUpdate(ctx context.Context, account *gtsmodel.Account, status *apimodel.Status, streamType string) { + b, err := json.Marshal(status) if err != nil { - return fmt.Errorf("error marshalling status to json: %s", err) + log.Errorf(ctx, "error marshaling json: %v", err) + return } - - return p.toAccount(string(bytes), stream.EventTypeStatusUpdate, streamTypes, account.ID) + p.streams.Post(ctx, account.ID, stream.Message{ + Payload: byteutil.B2S(b), + Event: stream.EventTypeStatusUpdate, + Stream: []string{streamType}, + }) } diff --git a/internal/processing/stream/statusupdate_test.go b/internal/processing/stream/statusupdate_test.go index 7b987b412..8814c966f 100644 --- a/internal/processing/stream/statusupdate_test.go +++ b/internal/processing/stream/statusupdate_test.go @@ -42,10 +42,11 @@ func (suite *StatusUpdateTestSuite) TestStreamNotification() { apiStatus, err := typeutils.NewConverter(&suite.state).StatusToAPIStatus(context.Background(), editedStatus, account) suite.NoError(err) - err = suite.streamProcessor.StatusUpdate(apiStatus, account, []string{stream.TimelineHome}) - suite.NoError(err) + suite.streamProcessor.StatusUpdate(context.Background(), account, apiStatus, stream.TimelineHome) + + msg, ok := openStream.Recv(context.Background()) + suite.True(ok) - msg := <-openStream.Messages dst := new(bytes.Buffer) err = json.Indent(dst, []byte(msg.Payload), "", " ") suite.NoError(err) diff --git a/internal/processing/stream/stream.go b/internal/processing/stream/stream.go index a5b3b9386..0b7285b58 
100644 --- a/internal/processing/stream/stream.go +++ b/internal/processing/stream/stream.go @@ -18,8 +18,6 @@ package stream import ( - "sync" - "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/stream" @@ -28,53 +26,13 @@ type Processor struct { state *state.State oauthServer oauth.Server - streamMap *sync.Map + streams stream.Streams } func New(state *state.State, oauthServer oauth.Server) Processor { return Processor{ state: state, oauthServer: oauthServer, - streamMap: &sync.Map{}, + streams: stream.Streams{}, } } - -// toAccount streams the given payload with the given event type to any streams currently open for the given account ID. -func (p *Processor) toAccount(payload string, event string, streamTypes []string, accountID string) error { - // Load all streams open for this account. - v, ok := p.streamMap.Load(accountID) - if !ok { - return nil // No entry = nothing to stream. - } - streamsForAccount := v.(*stream.StreamsForAccount) - - streamsForAccount.Lock() - defer streamsForAccount.Unlock() - - for _, s := range streamsForAccount.Streams { - s.Lock() - defer s.Unlock() - - if !s.Connected { - continue - } - - typeLoop: - for _, streamType := range streamTypes { - if _, found := s.StreamTypes[streamType]; found { - s.Messages <- &stream.Message{ - Stream: []string{streamType}, - Event: string(event), - Payload: payload, - } - - // Break out to the outer loop, - // to avoid sending duplicates of - // the same event to the same stream. 
- break typeLoop - } - } - } - - return nil -} diff --git a/internal/processing/stream/update.go b/internal/processing/stream/update.go index ee70bda11..a84763d51 100644 --- a/internal/processing/stream/update.go +++ b/internal/processing/stream/update.go @@ -18,20 +18,26 @@ package stream import ( + "context" "encoding/json" - "fmt" + "codeberg.org/gruf/go-byteutil" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/stream" ) // Update streams the given update to any open, appropriate streams belonging to the given account. -func (p *Processor) Update(s *apimodel.Status, account *gtsmodel.Account, streamTypes []string) error { - bytes, err := json.Marshal(s) +func (p *Processor) Update(ctx context.Context, account *gtsmodel.Account, status *apimodel.Status, streamType string) { + b, err := json.Marshal(status) if err != nil { - return fmt.Errorf("error marshalling status to json: %s", err) + log.Errorf(ctx, "error marshaling json: %v", err) + return } - - return p.toAccount(string(bytes), stream.EventTypeUpdate, streamTypes, account.ID) + p.streams.Post(ctx, account.ID, stream.Message{ + Payload: byteutil.B2S(b), + Event: stream.EventTypeUpdate, + Stream: []string{streamType}, + }) } diff --git a/internal/processing/workers/fromclientapi_test.go b/internal/processing/workers/fromclientapi_test.go index 05526f437..3d3630b11 100644 --- a/internal/processing/workers/fromclientapi_test.go +++ b/internal/processing/workers/fromclientapi_test.go @@ -116,23 +116,20 @@ func (suite *FromClientAPITestSuite) checkStreamed( expectPayload string, expectEventType string, ) { - var msg *stream.Message -streamLoop: - for { - select { - case msg = <-str.Messages: - break streamLoop // Got it. - case <-time.After(5 * time.Second): - break streamLoop // Didn't get it. 
- } + + // Set a 5s timeout on context. + ctx := context.Background() + ctx, cncl := context.WithTimeout(ctx, time.Second*5) + defer cncl() + + msg, ok := str.Recv(ctx) + + if expectMessage && !ok { + suite.FailNow("expected a message but message was not received") } - if expectMessage && msg == nil { - suite.FailNow("expected a message but message was nil") - } - - if !expectMessage && msg != nil { - suite.FailNow("expected no message but message was not nil") + if !expectMessage && ok { + suite.FailNow("expected no message but message was received") } if expectPayload != "" && msg.Payload != expectPayload { diff --git a/internal/processing/workers/fromfediapi_test.go b/internal/processing/workers/fromfediapi_test.go index 799eaf2dc..446355628 100644 --- a/internal/processing/workers/fromfediapi_test.go +++ b/internal/processing/workers/fromfediapi_test.go @@ -130,14 +130,9 @@ func (suite *FromFediAPITestSuite) TestProcessReplyMention() { suite.Equal(replyingStatus.ID, notif.StatusID) suite.False(*notif.Read) - // the notification should be streamed - var msg *stream.Message - select { - case msg = <-wssStream.Messages: - // fine - case <-time.After(5 * time.Second): - suite.FailNow("no message from wssStream") - } + ctx, _ := context.WithTimeout(context.Background(), time.Second*5) + msg, ok := wssStream.Recv(ctx) + suite.True(ok) suite.Equal(stream.EventTypeNotification, msg.Event) suite.NotEmpty(msg.Payload) @@ -203,14 +198,10 @@ func (suite *FromFediAPITestSuite) TestProcessFave() { suite.Equal(fave.StatusID, notif.StatusID) suite.False(*notif.Read) - // 2. 
a notification should be streamed - var msg *stream.Message - select { - case msg = <-wssStream.Messages: - // fine - case <-time.After(5 * time.Second): - suite.FailNow("no message from wssStream") - } + ctx, _ := context.WithTimeout(context.Background(), time.Second*5) + msg, ok := wssStream.Recv(ctx) + suite.True(ok) + suite.Equal(stream.EventTypeNotification, msg.Event) suite.NotEmpty(msg.Payload) suite.EqualValues([]string{stream.TimelineNotifications}, msg.Stream) @@ -277,7 +268,9 @@ func (suite *FromFediAPITestSuite) TestProcessFaveWithDifferentReceivingAccount( suite.False(*notif.Read) // 2. no notification should be streamed to the account that received the fave message, because they weren't the target - suite.Empty(wssStream.Messages) + ctx, _ := context.WithTimeout(context.Background(), time.Second*5) + _, ok := wssStream.Recv(ctx) + suite.False(ok) } func (suite *FromFediAPITestSuite) TestProcessAccountDelete() { @@ -405,14 +398,10 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestLocked() { }) suite.NoError(err) - // a notification should be streamed - var msg *stream.Message - select { - case msg = <-wssStream.Messages: - // fine - case <-time.After(5 * time.Second): - suite.FailNow("no message from wssStream") - } + ctx, _ = context.WithTimeout(ctx, time.Second*5) + msg, ok := wssStream.Recv(context.Background()) + suite.True(ok) + suite.Equal(stream.EventTypeNotification, msg.Event) suite.NotEmpty(msg.Payload) suite.EqualValues([]string{stream.TimelineHome}, msg.Stream) @@ -423,7 +412,7 @@ func (suite *FromFediAPITestSuite) TestProcessFollowRequestLocked() { suite.Equal(originAccount.ID, notif.Account.ID) // no messages should have been sent out, since we didn't need to federate an accept - suite.Empty(suite.httpClient.SentMessages) + suite.Empty(&suite.httpClient.SentMessages) } func (suite *FromFediAPITestSuite) TestProcessFollowRequestUnlocked() { @@ -503,14 +492,10 @@ func (suite *FromFediAPITestSuite) 
TestProcessFollowRequestUnlocked() { suite.Equal(originAccount.URI, accept.To) suite.Equal("Accept", accept.Type) - // a notification should be streamed - var msg *stream.Message - select { - case msg = <-wssStream.Messages: - // fine - case <-time.After(5 * time.Second): - suite.FailNow("no message from wssStream") - } + ctx, _ = context.WithTimeout(ctx, time.Second*5) + msg, ok := wssStream.Recv(context.Background()) + suite.True(ok) + suite.Equal(stream.EventTypeNotification, msg.Event) suite.NotEmpty(msg.Payload) suite.EqualValues([]string{stream.TimelineHome}, msg.Stream) diff --git a/internal/processing/workers/surfacenotify.go b/internal/processing/workers/surfacenotify.go index 39798f45e..a8c36248c 100644 --- a/internal/processing/workers/surfacenotify.go +++ b/internal/processing/workers/surfacenotify.go @@ -394,10 +394,7 @@ func (s *surface) notify( if err != nil { return gtserror.Newf("error converting notification to api representation: %w", err) } - - if err := s.stream.Notify(apiNotif, targetAccount); err != nil { - return gtserror.Newf("error streaming notification to account: %w", err) - } + s.stream.Notify(ctx, targetAccount, apiNotif) return nil } diff --git a/internal/processing/workers/surfacetimeline.go b/internal/processing/workers/surfacetimeline.go index e63b8a7c0..14634f846 100644 --- a/internal/processing/workers/surfacetimeline.go +++ b/internal/processing/workers/surfacetimeline.go @@ -348,11 +348,7 @@ func (s *surface) timelineStatus( err = gtserror.Newf("error converting status %s to frontend representation: %w", status.ID, err) return true, err } - - if err := s.stream.Update(apiStatus, account, []string{streamType}); err != nil { - err = gtserror.Newf("error streaming update for status %s: %w", status.ID, err) - return true, err - } + s.stream.Update(ctx, account, apiStatus, streamType) return true, nil } @@ -363,12 +359,11 @@ func (s *surface) deleteStatusFromTimelines(ctx context.Context, statusID string if err := 
s.state.Timelines.Home.WipeItemFromAllTimelines(ctx, statusID); err != nil { return err } - if err := s.state.Timelines.List.WipeItemFromAllTimelines(ctx, statusID); err != nil { return err } - - return s.stream.Delete(statusID) + s.stream.Delete(ctx, statusID) + return nil } // invalidateStatusFromTimelines does cache invalidation on the given status by @@ -555,11 +550,6 @@ func (s *surface) timelineStreamStatusUpdate( err = gtserror.Newf("error converting status %s to frontend representation: %w", status.ID, err) return err } - - if err := s.stream.StatusUpdate(apiStatus, account, []string{streamType}); err != nil { - err = gtserror.Newf("error streaming update for status %s: %w", status.ID, err) - return err - } - + s.stream.StatusUpdate(ctx, account, apiStatus, streamType) return nil } diff --git a/internal/stream/stream.go b/internal/stream/stream.go index da5647433..ec22464f5 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -17,36 +17,65 @@ package stream -import "sync" - -const ( - // EventTypeNotification -- a user should be shown a notification - EventTypeNotification string = "notification" - // EventTypeUpdate -- a user should be shown an update in their timeline - EventTypeUpdate string = "update" - // EventTypeDelete -- something should be deleted from a user - EventTypeDelete string = "delete" - // EventTypeStatusUpdate -- something in the user's timeline has been edited - // (yes this is a confusing name, blame Mastodon) - EventTypeStatusUpdate string = "status.update" +import ( + "context" + "maps" + "slices" + "sync" + "sync/atomic" ) const ( - // TimelineLocal -- public statuses from the LOCAL timeline. - TimelineLocal string = "public:local" - // TimelinePublic -- public statuses, including federated ones. - TimelinePublic string = "public" - // TimelineHome -- statuses for a user's Home timeline. - TimelineHome string = "user" - // TimelineNotifications -- notification events. 
- TimelineNotifications string = "user:notification" - // TimelineDirect -- statuses sent to a user directly. - TimelineDirect string = "direct" - // TimelineList -- statuses for a user's list timeline. - TimelineList string = "list" + // EventTypeNotification -- a user + // should be shown a notification. + EventTypeNotification = "notification" + + // EventTypeUpdate -- a user should + // be shown an update in their timeline. + EventTypeUpdate = "update" + + // EventTypeDelete -- something + // should be deleted from a user. + EventTypeDelete = "delete" + + // EventTypeStatusUpdate -- something in the + // user's timeline has been edited (yes this + // is a confusing name, blame Mastodon ...). + EventTypeStatusUpdate = "status.update" ) -// AllStatusTimelines contains all Timelines that a status could conceivably be delivered to -- useful for doing deletes. +const ( + // TimelineLocal: + // All public posts originating from this + // server. Analogous to the local timeline. + TimelineLocal = "public:local" + + // TimelinePublic: + // All public posts known to the server. + // Analogous to the federated timeline. + TimelinePublic = "public" + + // TimelineHome: + // Events related to the current user, such + // as home feed updates and notifications. + TimelineHome = "user" + + // TimelineNotifications: + // Notifications for the current user. + TimelineNotifications = "user:notification" + + // TimelineDirect: + // Updates to direct conversations. + TimelineDirect = "direct" + + // TimelineList: + // Updates to a specific list. + TimelineList = "list" +) + +// AllStatusTimelines contains all Timelines +// that a status could conceivably be delivered +// to, useful for sending out status deletes. var AllStatusTimelines = []string{ TimelineLocal, TimelinePublic, @@ -55,38 +84,298 @@ TimelineList, } -// StreamsForAccount is a wrapper for the multiple streams that one account can have running at the same time. 
// Streams tracks every currently open Stream, grouped
// by the ID of the account that opened it. The zero
// value is ready to use; the map is allocated on first Open.
type Streams struct {
	streams map[string][]*Stream
	mutex   sync.Mutex
}

// Open registers and returns a new Stream for the given account ID,
// subscribed to each of the given stream types. It panics if no
// stream types are given. Calling Close on the returned Stream
// removes it from this Streams again.
func (s *Streams) Open(accountID string, streamTypes ...string) *Stream {
	if len(streamTypes) == 0 {
		panic("no stream types given")
	}

	// Prep new Stream.
	str := new(Stream)
	str.done = make(chan struct{})
	str.msgCh = make(chan Message, 50) // TODO: make configurable
	for _, streamType := range streamTypes {
		str.Subscribe(streamType)
	}

	// Register close callback to remove the
	// stream from our internal per-account map.
	str.close = func() {
		s.mutex.Lock()
		defer s.mutex.Unlock()

		strs := slices.DeleteFunc(s.streams[accountID], func(other *Stream) bool {
			return other == str // remove 'str' ptr
		})

		if len(strs) == 0 {
			// Last stream for this account: drop the
			// key so the map doesn't grow unbounded
			// with every account that ever connected.
			delete(s.streams, accountID)
			return
		}

		s.streams[accountID] = strs
	}

	// TODO: add configurable
	// max streams per account.

	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.streams == nil {
		// Main stream-map needs allocating.
		s.streams = make(map[string][]*Stream)
	}

	// Add new stream for account.
	s.streams[accountID] = append(s.streams[accountID], str)

	return str
}

// Post will post the given message to all streams of the given account
// ID that are subscribed to at least one of the message's stream types.
// It returns true only if every matching stream accepted the message
// (and therefore also when no stream matched); it returns false if any
// send was abandoned because ctx was canceled or the stream was closed.
func (s *Streams) Post(ctx context.Context, accountID string, msg Message) bool {
	s.mutex.Lock()
	sends := appendSends(ctx, nil, s.streams[accountID], msg)
	s.mutex.Unlock()

	// Execute sends outside lock.
	return runSends(sends)
}

// PostAll will post the given message to all streams, of any account,
// that are subscribed to at least one of the message's stream types.
// The return value has the same meaning as for Post.
func (s *Streams) PostAll(ctx context.Context, msg Message) bool {
	var sends []func() bool

	s.mutex.Lock()
	for _, strs := range s.streams {
		sends = appendSends(ctx, sends, strs, msg)
	}
	s.mutex.Unlock()

	// Execute sends outside lock.
	return runSends(sends)
}

// appendSends appends to sends one deferred send function for each
// stream in strs that supports any of msg's stream types. The sends
// are NOT executed here: callers hold the Streams mutex while calling
// this, and a send may block on a full message channel, so executing
// under the lock could deadlock against a stream's close callback.
func appendSends(ctx context.Context, sends []func() bool, strs []*Stream, msg Message) []func() bool {
	for _, str := range strs {
		// Check whether stream supports any of our message targets.
		stype := str.getStreamType(msg.Stream...)
		if stype == "" {
			continue
		}

		// Rescope var to prevent ptr reuse
		// in closure (pre-Go1.22 loop semantics).
		stream := str

		// Use a message copy to *only*
		// include the supported stream.
		msgCopy := Message{
			Stream:  []string{stype},
			Event:   msg.Event,
			Payload: msg.Payload,
		}

		sends = append(sends, func() bool {
			return stream.send(ctx, msgCopy)
		})
	}
	return sends
}

// runSends executes every given send function (it does not stop
// at the first failure) and reports whether all of them succeeded.
func runSends(sends []func() bool) bool {
	// Must start true: folding with
	// && from false can never succeed.
	ok := true
	for _, send := range sends {
		if !send() {
			ok = false
		}
	}
	return ok
}

// Stream represents one open stream for a client.
// Streams are created by (*Streams).Open, which
// sets up the channels and close hook used below.
type Stream struct {

	// atomically updated ptr to a read-only copy
	// of supported stream types in a hashmap. this
	// gets updated via CAS operations in .cas().
	types atomic.Pointer[map[string]struct{}]

	// closed exactly once by Close()
	// to signal the stream is finished.
	done chan struct{}

	// guards Close() so that concurrent
	// callers cannot double-close 'done'.
	closeOnce sync.Once

	// inbound msg ch.
	msgCh chan Message

	// close hook to remove
	// stream from Streams{}.
	close func()
}

// Subscribe will add given type to the types this stream
// supports. It is a no-op if the type is already present.
func (s *Stream) Subscribe(streamType string) {
	s.cas(func(m map[string]struct{}) bool {
		if _, ok := m[streamType]; ok {
			return false
		}
		m[streamType] = struct{}{}
		return true
	})
}

// Unsubscribe will remove given type (if
// found) from the types this stream supports.
func (s *Stream) Unsubscribe(streamType string) {
	s.cas(func(m map[string]struct{}) bool {
		if _, ok := m[streamType]; !ok {
			return false
		}
		delete(m, streamType)
		return true
	})
}

// getStreamType returns the first stream type in given list that
// this stream supports, or an empty string if it supports none.
func (s *Stream) getStreamType(streamTypes ...string) string {
	ptr := s.types.Load()
	if ptr == nil {
		return ""
	}
	for _, streamType := range streamTypes {
		if _, ok := (*ptr)[streamType]; ok {
			return streamType
		}
	}
	return ""
}

// send will block on posting a new Message{}, returning early with
// a false value if provided context is canceled, or stream closed.
func (s *Stream) send(ctx context.Context, msg Message) bool {
	select {
	case <-s.done:
		return false
	case <-ctx.Done():
		return false
	case s.msgCh <- msg:
		return true
	}
}

// Recv will block on receiving Message{}, returning early with a
// false value if provided context is canceled, or stream closed.
func (s *Stream) Recv(ctx context.Context) (Message, bool) {
	select {
	case <-s.done:
		return Message{}, false
	case <-ctx.Done():
		return Message{}, false
	case msg := <-s.msgCh:
		return msg, true
	}
}

// Close marks the stream as closed, unblocking any pending send /
// Recv calls, and then removes it from the parent Streams per-account
// map. It is safe to call Close multiple times, including concurrently
// from several goroutines; only the first call has any effect.
func (s *Stream) Close() {
	s.closeOnce.Do(func() {
		close(s.done)
		s.close()
	})
}

// cas will perform a Compare And Swap operation on s.types using the
// modifier func: fn receives a private clone of the current types map,
// and returns whether it changed anything. If it did, the clone is
// swapped in; if another goroutine updated s.types in the meantime,
// the whole operation is retried against the newer value.
func (s *Stream) cas(fn func(map[string]struct{}) bool) {
	if fn == nil {
		panic("nil function")
	}
	for {
		var m map[string]struct{}

		// Get current value.
		ptr := s.types.Load()

		if ptr == nil {
			// Allocate new types map.
			m = make(map[string]struct{})
		} else {
			// Clone r-only map.
			m = maps.Clone(*ptr)
		}

		// Apply changes; nothing to
		// store if fn modified nothing.
		if !fn(m) {
			return
		}

		// Attempt to Compare And Swap ptr.
		if s.types.CompareAndSwap(ptr, &m) {
			return
		}
	}
}

// Message represents
// one streamed message.
type Message struct {

	// All the stream types this
	// message should be delivered to.
	Stream []string `json:"stream"`

	// The event type of the message
	// (update/delete/notification etc)
	Event string `json:"event"`

	// The actual payload of the message. In case of an
	// update or notification, this will be a JSON string.
	Payload string `json:"payload"`
}