mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2025-01-09 08:00:13 +00:00
[chore/bugfix] Break Websockets logic into smaller read/write functions, don't log expected errors (#1932)
* [chore/bugfix] Break Websockets logic into smaller read/write functions, don't log expected errors * tweak * tidy up, use control message
This commit is contained in:
parent
ba0bc06b8c
commit
3d16962173
|
@ -149,60 +149,78 @@
|
||||||
// '400':
|
// '400':
|
||||||
// description: bad request
|
// description: bad request
|
||||||
func (m *Module) StreamGETHandler(c *gin.Context) {
|
func (m *Module) StreamGETHandler(c *gin.Context) {
|
||||||
|
var (
|
||||||
|
account *gtsmodel.Account
|
||||||
|
errWithCode gtserror.WithCode
|
||||||
|
)
|
||||||
|
|
||||||
// First we check for a query param provided access token
|
// Try query param access token.
|
||||||
token := c.Query(AccessTokenQueryKey)
|
token := c.Query(AccessTokenQueryKey)
|
||||||
if token == "" {
|
if token == "" {
|
||||||
// Else we check the HTTP header provided token
|
// Try fallback HTTP header provided token.
|
||||||
token = c.GetHeader(AccessTokenHeader)
|
token = c.GetHeader(AccessTokenHeader)
|
||||||
}
|
}
|
||||||
|
|
||||||
var account *gtsmodel.Account
|
|
||||||
if token != "" {
|
if token != "" {
|
||||||
// Check the explicit token
|
// Token was provided, use it to authorize stream.
|
||||||
var errWithCode gtserror.WithCode
|
|
||||||
account, errWithCode = m.processor.Stream().Authorize(c.Request.Context(), token)
|
account, errWithCode = m.processor.Stream().Authorize(c.Request.Context(), token)
|
||||||
if errWithCode != nil {
|
|
||||||
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
// If no explicit token was provided, try regular oauth
|
// No explicit token was provided:
|
||||||
auth, errStr := oauth.Authed(c, true, true, true, true)
|
// try regular oauth as a last resort.
|
||||||
if errStr != nil {
|
account, errWithCode = func() (*gtsmodel.Account, gtserror.WithCode) {
|
||||||
err := gtserror.NewErrorUnauthorized(errStr, errStr.Error())
|
authed, err := oauth.Authed(c, true, true, true, true)
|
||||||
apiutil.ErrorHandler(c, err, m.processor.InstanceGetV1)
|
if err != nil {
|
||||||
return
|
return nil, gtserror.NewErrorUnauthorized(err, err.Error())
|
||||||
}
|
}
|
||||||
account = auth.Account
|
|
||||||
|
return authed.Account, nil
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the initial stream type, if there is one.
|
if errWithCode != nil {
|
||||||
// By appending other query params to the streamType,
|
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
||||||
// we can allow for streaming for specific list IDs
|
return
|
||||||
// or hashtags.
|
}
|
||||||
|
|
||||||
|
// Get the initial requested stream type, if there is one.
|
||||||
streamType := c.Query(StreamQueryKey)
|
streamType := c.Query(StreamQueryKey)
|
||||||
|
|
||||||
|
// By appending other query params to the streamType, we
|
||||||
|
// can allow streaming for specific list IDs or hashtags.
|
||||||
|
// The streamType in this case will end up looking like
|
||||||
|
// `hashtag:example` or `list:01H3YF48G8B7KTPQFS8D2QBVG8`.
|
||||||
if list := c.Query(StreamListKey); list != "" {
|
if list := c.Query(StreamListKey); list != "" {
|
||||||
streamType += ":" + list
|
streamType += ":" + list
|
||||||
} else if tag := c.Query(StreamTagKey); tag != "" {
|
} else if tag := c.Query(StreamTagKey); tag != "" {
|
||||||
streamType += ":" + tag
|
streamType += ":" + tag
|
||||||
}
|
}
|
||||||
|
|
||||||
stream, errWithCode := m.processor.Stream().Open(c.Request.Context(), account, streamType)
|
// Open a stream with the processor; this lets processor
|
||||||
|
// functions pass messages into a channel, which we can
|
||||||
|
// then read from and put into a websockets connection.
|
||||||
|
stream, errWithCode := m.processor.Stream().Open(
|
||||||
|
c.Request.Context(),
|
||||||
|
account,
|
||||||
|
streamType,
|
||||||
|
)
|
||||||
if errWithCode != nil {
|
if errWithCode != nil {
|
||||||
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
l := log.WithContext(c.Request.Context()).
|
l := log.
|
||||||
|
WithContext(c.Request.Context()).
|
||||||
WithFields(kv.Fields{
|
WithFields(kv.Fields{
|
||||||
{"account", account.Username},
|
{"username", account.Username},
|
||||||
{"streamID", stream.ID},
|
{"streamID", stream.ID},
|
||||||
{"streamType", streamType},
|
|
||||||
}...)
|
}...)
|
||||||
|
|
||||||
// Upgrade the incoming HTTP request, which hijacks the underlying
|
// Upgrade the incoming HTTP request. This hijacks the
|
||||||
// connection and reuses it for the websocket (non-http) protocol.
|
// underlying connection and reuses it for the websocket
|
||||||
|
// (non-http) protocol.
|
||||||
|
//
|
||||||
|
// If the upgrade fails, then Upgrade replies to the client
|
||||||
|
// with an HTTP error response.
|
||||||
wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, nil)
|
wsConn, err := m.wsUpgrade.Upgrade(c.Writer, c.Request, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
l.Errorf("error upgrading websocket connection: %v", err)
|
l.Errorf("error upgrading websocket connection: %v", err)
|
||||||
|
@ -210,125 +228,208 @@ func (m *Module) StreamGETHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
l.Info("opened websocket connection")
|
||||||
|
|
||||||
|
// We perform the main websocket rw loops in a separate
|
||||||
|
// goroutine in order to let the upgrade handler return.
|
||||||
|
// This prevents the upgrade handler from holding open any
|
||||||
|
// throttle / rate-limit request tokens which could become
|
||||||
|
// problematic on instances with multiple users.
|
||||||
|
go m.handleWSConn(account.Username, wsConn, stream)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleWSConn handles a two-way websocket streaming connection.
|
||||||
|
// It will both read messages from the connection, and push messages
|
||||||
|
// into the connection. If any errors are encountered while reading
|
||||||
|
// or writing (including expected errors like clients leaving), the
|
||||||
|
// connection will be closed.
|
||||||
|
func (m *Module) handleWSConn(username string, wsConn *websocket.Conn, stream *streampkg.Stream) {
|
||||||
|
// Create new context for the lifetime of this connection.
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
l := log.
|
||||||
|
WithContext(ctx).
|
||||||
|
WithFields(kv.Fields{
|
||||||
|
{"username", username},
|
||||||
|
{"streamID", stream.ID},
|
||||||
|
}...)
|
||||||
|
|
||||||
|
// Create ticker to send keepalive pings
|
||||||
|
pinger := time.NewTicker(m.dTicker)
|
||||||
|
|
||||||
|
// Read messages coming from the Websocket client connection into the server.
|
||||||
go func() {
|
go func() {
|
||||||
// We perform the main websocket send loop in a separate
|
defer cancel()
|
||||||
// goroutine in order to let the upgrade handler return.
|
m.readFromWSConn(ctx, username, wsConn, stream)
|
||||||
// This prevents the upgrade handler from holding open any
|
}()
|
||||||
// throttle / rate-limit request tokens which could become
|
|
||||||
// problematic on instances with multiple users.
|
|
||||||
l.Info("opened websocket connection")
|
|
||||||
defer l.Info("closed websocket connection")
|
|
||||||
|
|
||||||
// Create new context for lifetime of the connection
|
// Write messages coming from the processor into the Websocket client connection.
|
||||||
ctx, cncl := context.WithCancel(context.Background())
|
go func() {
|
||||||
|
defer cancel()
|
||||||
|
m.writeToWSConn(ctx, username, wsConn, stream, pinger)
|
||||||
|
}()
|
||||||
|
|
||||||
// Create ticker to send alive pings
|
// Wait for either the read or write functions to close, to indicate
|
||||||
pinger := time.NewTicker(m.dTicker)
|
// that the client has left, or something else has gone wrong.
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
defer func() {
|
// Tidy up underlying websocket connection.
|
||||||
// Signal done
|
if err := wsConn.Close(); err != nil {
|
||||||
cncl()
|
l.Errorf("error closing websocket connection: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Close websocket conn
|
// Close processor channel so the processor knows
|
||||||
_ = wsConn.Close()
|
// not to send any more messages to this stream.
|
||||||
|
close(stream.Hangup)
|
||||||
|
|
||||||
// Close processor stream
|
// Stop ping ticker (tiny resource saving).
|
||||||
close(stream.Hangup)
|
pinger.Stop()
|
||||||
|
|
||||||
// Stop ping ticker
|
l.Info("closed websocket connection")
|
||||||
pinger.Stop()
|
}
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
// readFromWSConn reads control messages coming in from the given
|
||||||
// Signal done
|
// websockets connection, and modifies the subscription StreamTypes
|
||||||
defer cncl()
|
// of the given stream accordingly after acquiring a lock on it.
|
||||||
|
//
|
||||||
|
// This is a blocking function; will return only on read error or
|
||||||
|
// if the given context is canceled.
|
||||||
|
func (m *Module) readFromWSConn(
|
||||||
|
ctx context.Context,
|
||||||
|
username string,
|
||||||
|
wsConn *websocket.Conn,
|
||||||
|
stream *streampkg.Stream,
|
||||||
|
) {
|
||||||
|
l := log.
|
||||||
|
WithContext(ctx).
|
||||||
|
WithFields(kv.Fields{
|
||||||
|
{"username", username},
|
||||||
|
{"streamID", stream.ID},
|
||||||
|
}...)
|
||||||
|
|
||||||
for {
|
readLoop:
|
||||||
// We have to listen for received websocket messages in
|
for {
|
||||||
// order to trigger the underlying wsConn.PingHandler().
|
select {
|
||||||
//
|
case <-ctx.Done():
|
||||||
// Read JSON objects from the client and act on them
|
// Connection closed.
|
||||||
var msg map[string]string
|
break readLoop
|
||||||
err := wsConn.ReadJSON(&msg)
|
|
||||||
if err != nil {
|
|
||||||
if ctx.Err() == nil {
|
|
||||||
// Only log error if the connection was not closed
|
|
||||||
// by us. Uncanceled context indicates this is the case.
|
|
||||||
l.Errorf("error reading from websocket: %v", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
l.Tracef("received message from websocket: %v", msg)
|
|
||||||
|
|
||||||
// If the message contains 'stream' and 'type' fields, we can
|
default:
|
||||||
// update the set of timelines that are subscribed for events.
|
// Read JSON objects from the client and act on them.
|
||||||
updateType, ok := msg["type"]
|
var msg map[string]string
|
||||||
if !ok {
|
if err := wsConn.ReadJSON(&msg); err != nil {
|
||||||
l.Warn("'type' field not provided")
|
// Only log an error if something weird happened.
|
||||||
continue
|
// See: https://www.rfc-editor.org/rfc/rfc6455.html#section-11.7
|
||||||
|
if websocket.IsUnexpectedCloseError(err, []int{
|
||||||
|
websocket.CloseNormalClosure,
|
||||||
|
websocket.CloseGoingAway,
|
||||||
|
websocket.CloseNoStatusReceived,
|
||||||
|
}...) {
|
||||||
|
l.Errorf("error reading from websocket: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
updateStream, ok := msg["stream"]
|
// The connection is gone; no
|
||||||
if !ok {
|
// further streaming possible.
|
||||||
l.Warn("'stream' field not provided")
|
break readLoop
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ignore if the updateStreamType is unknown (or missing),
|
|
||||||
// so a bad client can't cause extra memory allocations
|
|
||||||
if !slices.Contains(streampkg.AllStatusTimelines, updateStream) {
|
|
||||||
l.Warnf("unknown 'stream' field: %v", msg)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
updateList, ok := msg["list"]
|
|
||||||
if ok {
|
|
||||||
updateStream += ":" + updateList
|
|
||||||
}
|
|
||||||
|
|
||||||
switch updateType {
|
|
||||||
case "subscribe":
|
|
||||||
stream.Lock()
|
|
||||||
stream.StreamTypes[updateStream] = true
|
|
||||||
stream.Unlock()
|
|
||||||
case "unsubscribe":
|
|
||||||
stream.Lock()
|
|
||||||
delete(stream.StreamTypes, updateStream)
|
|
||||||
stream.Unlock()
|
|
||||||
default:
|
|
||||||
l.Warnf("invalid 'type' field: %v", msg)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
|
|
||||||
for {
|
// Messages *from* the WS connection are infrequent
|
||||||
select {
|
// and usually interesting, so log this at info.
|
||||||
// Connection closed
|
l.Infof("received message from websocket: %v", msg)
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
|
|
||||||
// Received next stream message
|
// If the message contains 'stream' and 'type' fields, we can
|
||||||
case msg := <-stream.Messages:
|
// update the set of timelines that are subscribed for events.
|
||||||
l.Tracef("sending message to websocket: %+v", msg)
|
updateType, ok := msg["type"]
|
||||||
if err := wsConn.WriteJSON(msg); err != nil {
|
if !ok {
|
||||||
l.Debugf("error writing json to websocket: %v", err)
|
l.Warn("'type' field not provided")
|
||||||
return
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset on each successful send.
|
updateStream, ok := msg["stream"]
|
||||||
pinger.Reset(m.dTicker)
|
if !ok {
|
||||||
|
l.Warn("'stream' field not provided")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// Send keep-alive "ping"
|
// Ignore if the updateStreamType is unknown (or missing),
|
||||||
case <-pinger.C:
|
// so a bad client can't cause extra memory allocations
|
||||||
l.Trace("pinging websocket ...")
|
if !slices.Contains(streampkg.AllStatusTimelines, updateStream) {
|
||||||
if err := wsConn.WriteMessage(
|
l.Warnf("unknown 'stream' field: %v", msg)
|
||||||
websocket.PingMessage,
|
continue
|
||||||
[]byte{},
|
}
|
||||||
); err != nil {
|
|
||||||
l.Debugf("error writing ping to websocket: %v", err)
|
updateList, ok := msg["list"]
|
||||||
return
|
if ok {
|
||||||
}
|
updateStream += ":" + updateList
|
||||||
|
}
|
||||||
|
|
||||||
|
switch updateType {
|
||||||
|
case "subscribe":
|
||||||
|
stream.Lock()
|
||||||
|
stream.StreamTypes[updateStream] = true
|
||||||
|
stream.Unlock()
|
||||||
|
case "unsubscribe":
|
||||||
|
stream.Lock()
|
||||||
|
delete(stream.StreamTypes, updateStream)
|
||||||
|
stream.Unlock()
|
||||||
|
default:
|
||||||
|
l.Warnf("invalid 'type' field: %v", msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}
|
||||||
|
|
||||||
|
l.Debug("finished reading from websocket connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeToWSConn receives messages coming from the processor via the
|
||||||
|
// given stream, and writes them into the given websockets connection.
|
||||||
|
// This function also handles sending ping messages into the websockets
|
||||||
|
// connection to keep it alive when no other activity occurs.
|
||||||
|
//
|
||||||
|
// This is a blocking function; will return only on write error or
|
||||||
|
// if the given context is canceled.
|
||||||
|
func (m *Module) writeToWSConn(
|
||||||
|
ctx context.Context,
|
||||||
|
username string,
|
||||||
|
wsConn *websocket.Conn,
|
||||||
|
stream *streampkg.Stream,
|
||||||
|
pinger *time.Ticker,
|
||||||
|
) {
|
||||||
|
l := log.
|
||||||
|
WithContext(ctx).
|
||||||
|
WithFields(kv.Fields{
|
||||||
|
{"username", username},
|
||||||
|
{"streamID", stream.ID},
|
||||||
|
}...)
|
||||||
|
|
||||||
|
writeLoop:
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
// Connection closed.
|
||||||
|
break writeLoop
|
||||||
|
|
||||||
|
case msg := <-stream.Messages:
|
||||||
|
// Received a new message from the processor.
|
||||||
|
l.Tracef("writing message to websocket: %+v", msg)
|
||||||
|
if err := wsConn.WriteJSON(msg); err != nil {
|
||||||
|
l.Debugf("error writing json to websocket: %v", err)
|
||||||
|
break writeLoop
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset pinger on successful send, since
|
||||||
|
// we know the connection is still there.
|
||||||
|
pinger.Reset(m.dTicker)
|
||||||
|
|
||||||
|
case <-pinger.C:
|
||||||
|
// Time to send a keep-alive "ping".
|
||||||
|
l.Trace("writing ping control message to websocket")
|
||||||
|
if err := wsConn.WriteControl(websocket.PingMessage, nil, time.Time{}); err != nil {
|
||||||
|
l.Debugf("error writing ping to websocket: %v", err)
|
||||||
|
break writeLoop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
l.Debug("finished writing to websocket connection")
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,15 +42,18 @@ type Module struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(processor *processing.Processor, dTicker time.Duration, wsBuf int) *Module {
|
func New(processor *processing.Processor, dTicker time.Duration, wsBuf int) *Module {
|
||||||
|
// We expect CORS requests for websockets,
|
||||||
|
// (via eg., semaphore.social) so be lenient.
|
||||||
|
// TODO: make this customizable?
|
||||||
|
checkOrigin := func(r *http.Request) bool { return true }
|
||||||
|
|
||||||
return &Module{
|
return &Module{
|
||||||
processor: processor,
|
processor: processor,
|
||||||
dTicker: dTicker,
|
dTicker: dTicker,
|
||||||
wsUpgrade: websocket.Upgrader{
|
wsUpgrade: websocket.Upgrader{
|
||||||
ReadBufferSize: wsBuf, // we don't expect reads
|
ReadBufferSize: wsBuf,
|
||||||
WriteBufferSize: wsBuf,
|
WriteBufferSize: wsBuf,
|
||||||
|
CheckOrigin: checkOrigin,
|
||||||
// we expect cors requests (via eg., semaphore.social) so be lenient
|
|
||||||
CheckOrigin: func(r *http.Request) bool { return true },
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue