[chore/bugfix] Break Websockets logic into smaller read/write functions, don't log expected errors (#1932)

* [chore/bugfix] Break Websockets logic into smaller read/write functions, don't log expected errors

* tweak

* tidy up, use control message
This commit is contained in:
tobi 2023-07-04 12:55:10 +02:00 committed by GitHub
parent ba0bc06b8c
commit 3d16962173
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 237 additions and 133 deletions

View file

@ -149,60 +149,78 @@
// '400': // '400':
// description: bad request // description: bad request
func (m *Module) StreamGETHandler(c *gin.Context) { func (m *Module) StreamGETHandler(c *gin.Context) {
var (
account *gtsmodel.Account
errWithCode gtserror.WithCode
)
// First we check for a query param provided access token // Try query param access token.
token := c.Query(AccessTokenQueryKey) token := c.Query(AccessTokenQueryKey)
if token == "" { if token == "" {
// Else we check the HTTP header provided token // Try fallback HTTP header provided token.
token = c.GetHeader(AccessTokenHeader) token = c.GetHeader(AccessTokenHeader)
} }
var account *gtsmodel.Account
if token != "" { if token != "" {
// Check the explicit token // Token was provided, use it to authorize stream.
var errWithCode gtserror.WithCode
account, errWithCode = m.processor.Stream().Authorize(c.Request.Context(), token) account, errWithCode = m.processor.Stream().Authorize(c.Request.Context(), token)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
} else { } else {
// If no explicit token was provided, try regular oauth // No explicit token was provided:
auth, errStr := oauth.Authed(c, true, true, true, true) // try regular oauth as a last resort.
if errStr != nil { account, errWithCode = func() (*gtsmodel.Account, gtserror.WithCode) {
err := gtserror.NewErrorUnauthorized(errStr, errStr.Error()) authed, err := oauth.Authed(c, true, true, true, true)
apiutil.ErrorHandler(c, err, m.processor.InstanceGetV1) if err != nil {
return return nil, gtserror.NewErrorUnauthorized(err, err.Error())
} }
account = auth.Account
return authed.Account, nil
}()
} }
// Get the initial stream type, if there is one. if errWithCode != nil {
// By appending other query params to the streamType, apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
// we can allow for streaming for specific list IDs return
// or hashtags. }
// Get the initial requested stream type, if there is one.
streamType := c.Query(StreamQueryKey) streamType := c.Query(StreamQueryKey)
// By appending other query params to the streamType, we
// can allow streaming for specific list IDs or hashtags.
// The streamType in this case will end up looking like
// `hashtag:example` or `list:01H3YF48G8B7KTPQFS8D2QBVG8`.
if list := c.Query(StreamListKey); list != "" { if list := c.Query(StreamListKey); list != "" {
streamType += ":" + list streamType += ":" + list
} else if tag := c.Query(StreamTagKey); tag != "" { } else if tag := c.Query(StreamTagKey); tag != "" {
streamType += ":" + tag streamType += ":" + tag
} }
stream, errWithCode := m.processor.Stream().Open(c.Request.Context(), account, streamType) // Open a stream with the processor; this lets processor
// functions pass messages into a channel, which we can
// then read from and put into a websockets connection.
stream, errWithCode := m.processor.Stream().Open(
c.Request.Context(),
account,
streamType,
)
if errWithCode != nil { if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return return
} }
l := log.WithContext(c.Request.Context()). l := log.
WithContext(c.Request.Context()).
WithFields(kv.Fields{ WithFields(kv.Fields{
{"account", account.Username}, {"username", account.Username},
{"streamID", stream.ID}, {"streamID", stream.ID},
{"streamType", streamType},
}...) }...)
// Upgrade the incoming HTTP request, which hijacks the underlying // Upgrade the incoming HTTP request. This hijacks the
// connection and reuses it for the websocket (non-http) protocol. // underlying connection and reuses it for the websocket
// (non-http) protocol.
//
// If the upgrade fails, then Upgrade replies to the client
// with an HTTP error response.
wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, nil) wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, nil)
if err != nil { if err != nil {
l.Errorf("error upgrading websocket connection: %v", err) l.Errorf("error upgrading websocket connection: %v", err)
@ -210,125 +228,208 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
return return
} }
l.Info("opened websocket connection")
// We perform the main websocket rw loops in a separate
// goroutine in order to let the upgrade handler return.
// This prevents the upgrade handler from holding open any
// throttle / rate-limit request tokens which could become
// problematic on instances with multiple users.
go m.handleWSConn(account.Username, wsConn, stream)
}
// handleWSConn handles a two-way websocket streaming connection.
// It will both read messages from the connection, and push messages
// into the connection. If any errors are encountered while reading
// or writing (including expected errors like clients leaving), the
// connection will be closed.
func (m *Module) handleWSConn(username string, wsConn *websocket.Conn, stream *streampkg.Stream) {
// Create new context for the lifetime of this connection.
ctx, cancel := context.WithCancel(context.Background())
l := log.
WithContext(ctx).
WithFields(kv.Fields{
{"username", username},
{"streamID", stream.ID},
}...)
// Create ticker to send keepalive pings
pinger := time.NewTicker(m.dTicker)
// Read messages coming from the Websocket client connection into the server.
go func() { go func() {
// We perform the main websocket send loop in a separate defer cancel()
// goroutine in order to let the upgrade handler return. m.readFromWSConn(ctx, username, wsConn, stream)
// This prevents the upgrade handler from holding open any }()
// throttle / rate-limit request tokens which could become
// problematic on instances with multiple users.
l.Info("opened websocket connection")
defer l.Info("closed websocket connection")
// Create new context for lifetime of the connection // Write messages coming from the processor into the Websocket client connection.
ctx, cncl := context.WithCancel(context.Background()) go func() {
defer cancel()
m.writeToWSConn(ctx, username, wsConn, stream, pinger)
}()
// Create ticker to send alive pings // Wait for either the read or write functions to close, to indicate
pinger := time.NewTicker(m.dTicker) // that the client has left, or something else has gone wrong.
<-ctx.Done()
defer func() { // Tidy up underlying websocket connection.
// Signal done if err := wsConn.Close(); err != nil {
cncl() l.Errorf("error closing websocket connection: %v", err)
}
// Close websocket conn // Close processor channel so the processor knows
_ = wsConn.Close() // not to send any more messages to this stream.
close(stream.Hangup)
// Close processor stream // Stop ping ticker (tiny resource saving).
close(stream.Hangup) pinger.Stop()
// Stop ping ticker l.Info("closed websocket connection")
pinger.Stop() }
}()
go func() { // readFromWSConn reads control messages coming in from the given
// Signal done // websockets connection, and modifies the subscription StreamTypes
defer cncl() // of the given stream accordingly after acquiring a lock on it.
//
// This is a blocking function; will return only on read error or
// if the given context is canceled.
func (m *Module) readFromWSConn(
ctx context.Context,
username string,
wsConn *websocket.Conn,
stream *streampkg.Stream,
) {
l := log.
WithContext(ctx).
WithFields(kv.Fields{
{"username", username},
{"streamID", stream.ID},
}...)
for { readLoop:
// We have to listen for received websocket messages in for {
// order to trigger the underlying wsConn.PingHandler(). select {
// case <-ctx.Done():
// Read JSON objects from the client and act on them // Connection closed.
var msg map[string]string break readLoop
err := wsConn.ReadJSON(&msg)
if err != nil {
if ctx.Err() == nil {
// Only log error if the connection was not closed
// by us. Uncanceled context indicates this is the case.
l.Errorf("error reading from websocket: %v", err)
}
return
}
l.Tracef("received message from websocket: %v", msg)
// If the message contains 'stream' and 'type' fields, we can default:
// update the set of timelines that are subscribed for events. // Read JSON objects from the client and act on them.
updateType, ok := msg["type"] var msg map[string]string
if !ok { if err := wsConn.ReadJSON(&msg); err != nil {
l.Warn("'type' field not provided") // Only log an error if something weird happened.
continue // See: https://www.rfc-editor.org/rfc/rfc6455.html#section-11.7
if websocket.IsUnexpectedCloseError(err, []int{
websocket.CloseNormalClosure,
websocket.CloseGoingAway,
websocket.CloseNoStatusReceived,
}...) {
l.Errorf("error reading from websocket: %v", err)
} }
updateStream, ok := msg["stream"] // The connection is gone; no
if !ok { // further streaming possible.
l.Warn("'stream' field not provided") break readLoop
continue
}
// Ignore if the updateStreamType is unknown (or missing),
// so a bad client can't cause extra memory allocations
if !slices.Contains(streampkg.AllStatusTimelines, updateStream) {
l.Warnf("unknown 'stream' field: %v", msg)
continue
}
updateList, ok := msg["list"]
if ok {
updateStream += ":" + updateList
}
switch updateType {
case "subscribe":
stream.Lock()
stream.StreamTypes[updateStream] = true
stream.Unlock()
case "unsubscribe":
stream.Lock()
delete(stream.StreamTypes, updateStream)
stream.Unlock()
default:
l.Warnf("invalid 'type' field: %v", msg)
}
} }
}()
for { // Messages *from* the WS connection are infrequent
select { // and usually interesting, so log this at info.
// Connection closed l.Infof("received message from websocket: %v", msg)
case <-ctx.Done():
return
// Received next stream message // If the message contains 'stream' and 'type' fields, we can
case msg := <-stream.Messages: // update the set of timelines that are subscribed for events.
l.Tracef("sending message to websocket: %+v", msg) updateType, ok := msg["type"]
if err := wsConn.WriteJSON(msg); err != nil { if !ok {
l.Debugf("error writing json to websocket: %v", err) l.Warn("'type' field not provided")
return continue
} }
// Reset on each successful send. updateStream, ok := msg["stream"]
pinger.Reset(m.dTicker) if !ok {
l.Warn("'stream' field not provided")
continue
}
// Send keep-alive "ping" // Ignore if the updateStreamType is unknown (or missing),
case <-pinger.C: // so a bad client can't cause extra memory allocations
l.Trace("pinging websocket ...") if !slices.Contains(streampkg.AllStatusTimelines, updateStream) {
if err := wsConn.WriteMessage( l.Warnf("unknown 'stream' field: %v", msg)
websocket.PingMessage, continue
[]byte{}, }
); err != nil {
l.Debugf("error writing ping to websocket: %v", err) updateList, ok := msg["list"]
return if ok {
} updateStream += ":" + updateList
}
switch updateType {
case "subscribe":
stream.Lock()
stream.StreamTypes[updateStream] = true
stream.Unlock()
case "unsubscribe":
stream.Lock()
delete(stream.StreamTypes, updateStream)
stream.Unlock()
default:
l.Warnf("invalid 'type' field: %v", msg)
} }
} }
}() }
l.Debug("finished reading from websocket connection")
}
// writeToWSConn receives messages coming from the processor via the
// given stream, and writes them into the given websockets connection.
// This function also handles sending ping messages into the websockets
// connection to keep it alive when no other activity occurs.
//
// This is a blocking function; will return only on write error or
// if the given context is canceled.
func (m *Module) writeToWSConn(
ctx context.Context,
username string,
wsConn *websocket.Conn,
stream *streampkg.Stream,
pinger *time.Ticker,
) {
l := log.
WithContext(ctx).
WithFields(kv.Fields{
{"username", username},
{"streamID", stream.ID},
}...)
writeLoop:
for {
select {
case <-ctx.Done():
// Connection closed.
break writeLoop
case msg := <-stream.Messages:
// Received a new message from the processor.
l.Tracef("writing message to websocket: %+v", msg)
if err := wsConn.WriteJSON(msg); err != nil {
l.Debugf("error writing json to websocket: %v", err)
break writeLoop
}
// Reset pinger on successful send, since
// we know the connection is still there.
pinger.Reset(m.dTicker)
case <-pinger.C:
// Time to send a keep-alive "ping".
l.Trace("writing ping control message to websocket")
if err := wsConn.WriteControl(websocket.PingMessage, nil, time.Time{}); err != nil {
l.Debugf("error writing ping to websocket: %v", err)
break writeLoop
}
}
}
l.Debug("finished writing to websocket connection")
} }

View file

@ -42,15 +42,18 @@ type Module struct {
} }
func New(processor *processing.Processor, dTicker time.Duration, wsBuf int) *Module { func New(processor *processing.Processor, dTicker time.Duration, wsBuf int) *Module {
// We expect CORS requests for websockets,
// (via eg., semaphore.social) so be lenient.
// TODO: make this customizable?
checkOrigin := func(r *http.Request) bool { return true }
return &Module{ return &Module{
processor: processor, processor: processor,
dTicker: dTicker, dTicker: dTicker,
wsUpgrade: websocket.Upgrader{ wsUpgrade: websocket.Upgrader{
ReadBufferSize: wsBuf, // we don't expect reads ReadBufferSize: wsBuf,
WriteBufferSize: wsBuf, WriteBufferSize: wsBuf,
CheckOrigin: checkOrigin,
// we expect cors requests (via eg., semaphore.social) so be lenient
CheckOrigin: func(r *http.Request) bool { return true },
}, },
} }
} }