diff --git a/cmd/gotosocial/action/server/server.go b/cmd/gotosocial/action/server/server.go
index 42cbf318b..68b039d0c 100644
--- a/cmd/gotosocial/action/server/server.go
+++ b/cmd/gotosocial/action/server/server.go
@@ -87,9 +87,9 @@
// defer function for safe shutdown
// depending on what services were
// managed to be started.
-
- state = new(state.State)
- route *router.Router
+ state = new(state.State)
+ route *router.Router
+ process *processing.Processor
)
defer func() {
@@ -125,6 +125,23 @@
}
}
+ if process != nil {
+ const timeout = time.Minute
+
+ // Use a new timeout context to ensure
+ // persisting queued tasks does not fail!
+ // The main ctx is very likely canceled.
+ ctx := context.WithoutCancel(ctx)
+ ctx, cncl := context.WithTimeout(ctx, timeout)
+ defer cncl()
+
+ // Now that all the "moving" components have been stopped,
+ // persist any remaining queued worker tasks to the database.
+ if err := process.Admin().PersistWorkerQueues(ctx); err != nil {
+ log.Errorf(ctx, "error persisting worker queues: %v", err)
+ }
+ }
+
if state.DB != nil {
// Lastly, if database service was started,
// ensure it gets closed now all else stopped.
@@ -270,7 +287,7 @@ func(context.Context, time.Time) {
// Create the processor using all the
// other services we've created so far.
- processor := processing.NewProcessor(
+ process = processing.NewProcessor(
cleaner,
typeConverter,
federator,
@@ -286,14 +303,14 @@ func(context.Context, time.Time) {
state.Workers.Client.Init(messages.ClientMsgIndices())
state.Workers.Federator.Init(messages.FederatorMsgIndices())
state.Workers.Delivery.Init(client)
- state.Workers.Client.Process = processor.Workers().ProcessFromClientAPI
- state.Workers.Federator.Process = processor.Workers().ProcessFromFediAPI
+ state.Workers.Client.Process = process.Workers().ProcessFromClientAPI
+ state.Workers.Federator.Process = process.Workers().ProcessFromFediAPI
// Now start workers!
state.Workers.Start()
// Schedule notif tasks for all existing poll expiries.
- if err := processor.Polls().ScheduleAll(ctx); err != nil {
+ if err := process.Polls().ScheduleAll(ctx); err != nil {
return fmt.Errorf("error scheduling poll expiries: %w", err)
}
@@ -303,7 +320,7 @@ func(context.Context, time.Time) {
}
// Run advanced migrations.
- if err := processor.AdvancedMigrations().Migrate(ctx); err != nil {
+ if err := process.AdvancedMigrations().Migrate(ctx); err != nil {
return err
}
@@ -370,7 +387,7 @@ func(context.Context, time.Time) {
// attach global no route / 404 handler to the router
route.AttachNoRouteHandler(func(c *gin.Context) {
- apiutil.ErrorHandler(c, gtserror.NewErrorNotFound(errors.New(http.StatusText(http.StatusNotFound))), processor.InstanceGetV1)
+ apiutil.ErrorHandler(c, gtserror.NewErrorNotFound(errors.New(http.StatusText(http.StatusNotFound))), process.InstanceGetV1)
})
// build router modules
@@ -393,15 +410,15 @@ func(context.Context, time.Time) {
}
var (
- authModule = api.NewAuth(dbService, processor, idp, routerSession, sessionName) // auth/oauth paths
- clientModule = api.NewClient(state, processor) // api client endpoints
- metricsModule = api.NewMetrics() // Metrics endpoints
- healthModule = api.NewHealth(dbService.Ready) // Health check endpoints
- fileserverModule = api.NewFileserver(processor) // fileserver endpoints
- wellKnownModule = api.NewWellKnown(processor) // .well-known endpoints
- nodeInfoModule = api.NewNodeInfo(processor) // nodeinfo endpoint
- activityPubModule = api.NewActivityPub(dbService, processor) // ActivityPub endpoints
- webModule = web.New(dbService, processor) // web pages + user profiles + settings panels etc
+ authModule = api.NewAuth(dbService, process, idp, routerSession, sessionName) // auth/oauth paths
+ clientModule = api.NewClient(state, process) // api client endpoints
+ metricsModule = api.NewMetrics() // Metrics endpoints
+ healthModule = api.NewHealth(dbService.Ready) // Health check endpoints
+ fileserverModule = api.NewFileserver(process) // fileserver endpoints
+ wellKnownModule = api.NewWellKnown(process) // .well-known endpoints
+ nodeInfoModule = api.NewNodeInfo(process) // nodeinfo endpoint
+ activityPubModule = api.NewActivityPub(dbService, process) // ActivityPub endpoints
+ webModule = web.New(dbService, process) // web pages + user profiles + settings panels etc
)
// create required middleware
@@ -416,10 +433,11 @@ func(context.Context, time.Time) {
// throttling
cpuMultiplier := config.GetAdvancedThrottlingMultiplier()
retryAfter := config.GetAdvancedThrottlingRetryAfter()
- clThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // client api
- s2sThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // server-to-server (AP)
- fsThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // fileserver / web templates / emojis
- pkThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // throttle public key endpoint separately
+ clThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // client api
+ s2sThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // server-to-server (AP)
+ fsThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // fileserver / web templates / emojis
+ pkThrottle := middleware.Throttle(cpuMultiplier, retryAfter) // throttle public key endpoint separately
gzip := middleware.Gzip() // applied to all except fileserver
@@ -442,6 +460,11 @@ func(context.Context, time.Time) {
return fmt.Errorf("error starting router: %w", err)
}
+ // Fill worker queues from persisted task data in database.
+ if err := process.Admin().FillWorkerQueues(ctx); err != nil {
+ return fmt.Errorf("error filling worker queues: %w", err)
+ }
+
// catch shutdown signals from the operating system
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
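The persist-on-shutdown block earlier in this file leans on context.WithoutCancel (Go 1.21+): it returns a child context that keeps the parent's values but ignores the parent's cancellation, so the queue flush gets its own one-minute budget even though the main ctx is almost certainly already canceled by that point. A minimal standalone illustration (not part of the patch):

package main

import (
	"context"
	"fmt"
	"time"
)

func main() {
	parent, cancel := context.WithCancel(context.Background())
	cancel() // simulate the main ctx being canceled at shutdown

	// Detach from the canceled parent, then apply a
	// fresh timeout budget for the cleanup work itself.
	detached := context.WithoutCancel(parent)
	detached, cncl := context.WithTimeout(detached, time.Minute)
	defer cncl()

	fmt.Println(parent.Err())   // context canceled
	fmt.Println(detached.Err()) // <nil> (until the timeout elapses)
}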
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go
index 070d4eb91..d5071d141 100644
--- a/internal/db/bundb/bundb.go
+++ b/internal/db/bundb/bundb.go
@@ -84,6 +84,7 @@ type DBService struct {
db.Timeline
db.User
db.Tombstone
+ db.WorkerTask
db *bun.DB
}
@@ -302,6 +303,9 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
db: db,
state: state,
},
+ WorkerTask: &workerTaskDB{
+ db: db,
+ },
db: db,
}
diff --git a/internal/db/bundb/migrations/20240617134210_add_worker_tasks_table.go b/internal/db/bundb/migrations/20240617134210_add_worker_tasks_table.go
new file mode 100644
index 000000000..3b0ebcfd8
--- /dev/null
+++ b/internal/db/bundb/migrations/20240617134210_add_worker_tasks_table.go
@@ -0,0 +1,51 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package migrations
+
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+func init() {
+ up := func(ctx context.Context, db *bun.DB) error {
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ // WorkerTask table.
+ if _, err := tx.
+ NewCreateTable().
+ Model(&gtsmodel.WorkerTask{}).
+ IfNotExists().
+ Exec(ctx); err != nil {
+ return err
+ }
+ return nil
+ })
+ }
+
+ down := func(ctx context.Context, db *bun.DB) error {
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
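+ // No-op rollback: the worker tasks table is left in place.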
+ return nil
+ })
+ }
+
+ if err := Migrations.Register(up, down); err != nil {
+ panic(err)
+ }
+}
diff --git a/internal/db/bundb/workertask.go b/internal/db/bundb/workertask.go
new file mode 100644
index 000000000..eec51530d
--- /dev/null
+++ b/internal/db/bundb/workertask.go
@@ -0,0 +1,58 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package bundb
+
+import (
+ "context"
+ "errors"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+type workerTaskDB struct{ db *bun.DB }
+
+func (w *workerTaskDB) GetWorkerTasks(ctx context.Context) ([]*gtsmodel.WorkerTask, error) {
+ var tasks []*gtsmodel.WorkerTask
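+ // Select all rows, oldest first, so tasks replay in insertion order.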
+ if err := w.db.NewSelect().
+ Model(&tasks).
+ OrderExpr("? ASC", bun.Ident("created_at")).
+ Scan(ctx); err != nil {
+ return nil, err
+ }
+ return tasks, nil
+}
+
+func (w *workerTaskDB) PutWorkerTasks(ctx context.Context, tasks []*gtsmodel.WorkerTask) error {
+ var errs []error
+ for _, task := range tasks {
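+ // Insert tasks one at a time, collecting errors so a single failure doesn't drop the rest.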
+ _, err := w.db.NewInsert().Model(task).Exec(ctx)
+ if err != nil {
+ errs = append(errs, err)
+ }
+ }
+ return errors.Join(errs...)
+}
+
+func (w *workerTaskDB) DeleteWorkerTaskByID(ctx context.Context, id uint) error {
+ _, err := w.db.NewDelete().
+ Table("worker_tasks").
+ Where("? = ?", bun.Ident("id"), id).
+ Exec(ctx)
+ return err
+}
diff --git a/internal/db/db.go b/internal/db/db.go
index 4b2152732..cd621871a 100644
--- a/internal/db/db.go
+++ b/internal/db/db.go
@@ -56,4 +56,5 @@ type DB interface {
Timeline
User
Tombstone
+ WorkerTask
}
diff --git a/internal/db/workertask.go b/internal/db/workertask.go
new file mode 100644
index 000000000..0276f231a
--- /dev/null
+++ b/internal/db/workertask.go
@@ -0,0 +1,35 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package db
+
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
+type WorkerTask interface {
+ // GetWorkerTasks fetches all persisted worker tasks from the database.
+ GetWorkerTasks(ctx context.Context) ([]*gtsmodel.WorkerTask, error)
+
+ // PutWorkerTasks persists the given worker tasks to the database.
+ PutWorkerTasks(ctx context.Context, tasks []*gtsmodel.WorkerTask) error
+
+ // DeleteWorkerTask deletes worker task with given ID from database.
+ DeleteWorkerTaskByID(ctx context.Context, id uint) error
+}
diff --git a/internal/gtsmodel/workertask.go b/internal/gtsmodel/workertask.go
index cc8433199..758fc4cd7 100644
--- a/internal/gtsmodel/workertask.go
+++ b/internal/gtsmodel/workertask.go
@@ -34,8 +34,8 @@
// queued tasks from being lost. It is simply a
// means to store a blob of serialized task data.
type WorkerTask struct {
- ID uint `bun:""`
- WorkerType uint8 `bun:""`
- TaskData []byte `bun:""`
- CreatedAt time.Time `bun:""`
+ ID uint `bun:",pk,autoincrement"`
+ WorkerType WorkerType `bun:",notnull"`
+ TaskData []byte `bun:",nullzero,notnull"`
+ CreatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"`
}
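Combined with the db.WorkerTask interface added earlier in this patch, one row's intended lifecycle is roughly the following sketch. roundTripSketch is a hypothetical helper written for illustration only; the real per-worker-type logic lives in FillWorkerQueues / PersistWorkerQueues further down.

// roundTripSketch persists one serialized task, then recovers it.
// Assumes data came from e.g. delivery.Delivery{}.Serialize().
func roundTripSketch(ctx context.Context, state *state.State, data []byte) error {
	// Persist (as PersistWorkerQueues does at shutdown).
	task := &gtsmodel.WorkerTask{
		// ID is autoincrement.
		WorkerType: gtsmodel.DeliveryWorker,
		TaskData:   data,
		CreatedAt:  time.Now(),
	}
	if err := state.DB.PutWorkerTasks(ctx, []*gtsmodel.WorkerTask{task}); err != nil {
		return err
	}

	// Recover on next startup; rows return oldest-first.
	tasks, err := state.DB.GetWorkerTasks(ctx)
	if err != nil {
		return err
	}

	for _, t := range tasks {
		// ... deserialize t.TaskData and push it onto the
		// matching worker queue, switching on t.WorkerType ...

		// Once safely re-queued, drop the persisted row.
		if err := state.DB.DeleteWorkerTaskByID(ctx, t.ID); err != nil {
			return err
		}
	}
	return nil
}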
diff --git a/internal/httpclient/client.go b/internal/httpclient/client.go
index b78dbc2d9..30ef0b04d 100644
--- a/internal/httpclient/client.go
+++ b/internal/httpclient/client.go
@@ -197,7 +197,7 @@ func (c *Client) Do(r *http.Request) (rsp *http.Response, err error) {
// If the fast-fail flag was set, just
// attempt a single iteration instead of
// following the below retry-backoff loop.
- rsp, _, err = c.DoOnce(&req)
+ rsp, _, err = c.DoOnce(req)
if err != nil {
return nil, fmt.Errorf("%w (fast fail)", err)
}
@@ -208,7 +208,7 @@ func (c *Client) Do(r *http.Request) (rsp *http.Response, err error) {
var retry bool
// Perform the http request.
- rsp, retry, err = c.DoOnce(&req)
+ rsp, retry, err = c.DoOnce(req)
if err == nil {
return rsp, nil
}
diff --git a/internal/httpclient/request.go b/internal/httpclient/request.go
index e5a7f44d3..dfe51b160 100644
--- a/internal/httpclient/request.go
+++ b/internal/httpclient/request.go
@@ -47,8 +47,8 @@ type Request struct {
// WrapRequest wraps an existing http.Request within
// our own httpclient.Request with retry / backoff tracking.
-func WrapRequest(r *http.Request) Request {
- var rr Request
+func WrapRequest(r *http.Request) *Request {
+ rr := new(Request)
rr.Request = r
entry := log.WithContext(r.Context())
entry = entry.WithField("method", r.Method)
diff --git a/internal/messages/messages.go b/internal/messages/messages.go
index 7779633ba..d652c0c5c 100644
--- a/internal/messages/messages.go
+++ b/internal/messages/messages.go
@@ -352,7 +352,7 @@ func resolveAPObject(data map[string]interface{}) (interface{}, error) {
// we then need to wrangle back into the original type. So we also store the type name
// and use this to determine the appropriate Go structure type to unmarshal into to.
func resolveGTSModel(typ string, data []byte) (interface{}, error) {
- if typ == "" && data == nil {
+ if typ == "" {
// No data given.
return nil, nil
}
diff --git a/internal/processing/admin/workertask.go b/internal/processing/admin/workertask.go
new file mode 100644
index 000000000..6d7cc7b7a
--- /dev/null
+++ b/internal/processing/admin/workertask.go
@@ -0,0 +1,426 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package admin
+
+import (
+ "context"
+ "fmt"
+ "slices"
+ "time"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/superseriousbusiness/gotosocial/internal/messages"
+ "github.com/superseriousbusiness/gotosocial/internal/transport"
+ "github.com/superseriousbusiness/gotosocial/internal/transport/delivery"
+)
+
+// NOTE:
+// Having these functions in the processor, which is
+// usually the intermediary that performs *processing*
+// between the HTTP route handlers and the underlying
+// database / storage layers is a little odd, so this
+// may be subject to change!
+//
+// For now at least, this is a useful place that has
+// access to the underlying database and workers, and
+// causes no dependency cycles with this use case!
+
+// FillWorkerQueues recovers all serialized worker tasks from the database
+// (if any!), and pushes them to each of their relevant worker queues.
+func (p *Processor) FillWorkerQueues(ctx context.Context) error {
+ log.Info(ctx, "rehydrate!")
+
+ // Get all persisted worker tasks from db.
+ //
+ // (database returns these as ASCENDING, i.e.
+ // returned in the order they were inserted).
+ tasks, err := p.state.DB.GetWorkerTasks(ctx)
+ if err != nil {
+ return gtserror.Newf("error fetching worker tasks from db: %w", err)
+ }
+
+ var (
+ // Counts of each task type
+ // successfully recovered.
+ delivery int
+ federator int
+ client int
+
+ // Failed recoveries.
+ errors int
+ )
+
+loop:
+
+ // Handle each persisted task, dropping
+ // any we can't handle. This leaves us
+ // with a slice of tasks that can safely
+ // be deleted from the DB afterwards.
+ for i := 0; i < len(tasks); {
+ var err error
+
+ // Task at index.
+ task := tasks[i]
+
+ // Appropriate task count
+ // pointer to increment.
+ var counter *int
+
+ // Attempt to recover persisted
+ // task depending on worker type.
+ switch task.WorkerType {
+ case gtsmodel.DeliveryWorker:
+ err = p.pushDelivery(ctx, task)
+ counter = &delivery
+ case gtsmodel.FederatorWorker:
+ err = p.pushFederator(ctx, task)
+ counter = &federator
+ case gtsmodel.ClientWorker:
+ err = p.pushClient(ctx, task)
+ counter = &client
+ default:
+ err = fmt.Errorf("invalid worker type %d", task.WorkerType)
+ }
+
+ if err != nil {
+ log.Errorf(ctx, "error pushing task %d: %v", task.ID, err)
+
+ // Drop error'd task from slice.
+ tasks = slices.Delete(tasks, i, i+1)
+
+ // Incr errors.
+ errors++
+ continue loop
+ }
+
+ // Increment slice
+ // index & counter.
+ (*counter)++
+ i++
+ }
+
+ // Tasks that were successfully pushed
+ // to their appropriate workers can
+ // now safely be removed from the database.
+ for _, task := range tasks {
+ if err := p.state.DB.DeleteWorkerTaskByID(ctx, task.ID); err != nil {
+ log.Errorf(ctx, "error deleting task from db: %v", err)
+ }
+ }
+
+ // Log recovered tasks.
+ log.WithContext(ctx).
+ WithField("delivery", delivery).
+ WithField("federator", federator).
+ WithField("client", client).
+ WithField("errors", errors).
+ Info("recovered queued tasks")
+
+ return nil
+}
+
+// PersistWorkerQueues pops all queued worker tasks (that are themselves persistable, i.e. not
+// dereference tasks which are just function ptrs), serializes and persists them to the database.
+func (p *Processor) PersistWorkerQueues(ctx context.Context) error {
+ log.Info(ctx, "dehydrate!")
+
+ var (
+ // Counts of each task type
+ // successfully persisted.
+ delivery int
+ federator int
+ client int
+
+ // Failed persists.
+ errors int
+
+ // Serialized tasks to persist.
+ tasks []*gtsmodel.WorkerTask
+ )
+
+ for {
+ // Pop all queued deliveries.
+ task, err := p.popDelivery()
+ if err != nil {
+ log.Errorf(ctx, "error popping delivery: %v", err)
+ errors++ // incr error count.
+ continue
+ }
+
+ if task == nil {
+ // No more queue
+ // tasks to pop!
+ break
+ }
+
+ // Append serialized task.
+ tasks = append(tasks, task)
+ delivery++ // incr count
+ }
+
+ for {
+ // Pop queued federator msgs.
+ task, err := p.popFederator()
+ if err != nil {
+ log.Errorf(ctx, "error popping federator message: %v", err)
+ errors++ // incr count
+ continue
+ }
+
+ if task == nil {
+ // No more queue
+ // tasks to pop!
+ break
+ }
+
+ // Append serialized task.
+ tasks = append(tasks, task)
+ federator++ // incr count
+ }
+
+ for {
+ // Pop queued client msgs.
+ task, err := p.popClient()
+ if err != nil {
+ log.Errorf(ctx, "error popping client message: %v", err)
+ errors++ // incr count
+ continue
+ }
+
+ if task == nil {
+ // No more queue
+ // tasks to pop!
+ break
+ }
+
+ // Append serialized task.
+ tasks = append(tasks, task)
+ client++ // incr count
+ }
+
+ // Persist all serialized queued worker tasks to database.
+ if err := p.state.DB.PutWorkerTasks(ctx, tasks); err != nil {
+ return gtserror.Newf("error putting tasks in db: %w", err)
+ }
+
+ // Log persisted tasks.
+ log.WithContext(ctx).
+ WithField("delivery", delivery).
+ WithField("federator", federator).
+ WithField("client", client).
+ WithField("errors", errors).
+ Info("persisted queued tasks")
+
+ return nil
+}
+
+// pushDelivery parses a valid delivery.Delivery{} from serialized task data and pushes to queue.
+func (p *Processor) pushDelivery(ctx context.Context, task *gtsmodel.WorkerTask) error {
+ dlv := new(delivery.Delivery)
+
+ // Deserialize the raw worker task data into delivery.
+ if err := dlv.Deserialize(task.TaskData); err != nil {
+ return gtserror.Newf("error deserializing delivery: %w", err)
+ }
+
+ var tsport transport.Transport
+
+ if uri := dlv.ActorID; uri != "" {
+ // Fetch the actor account by provided URI from db.
+ account, err := p.state.DB.GetAccountByURI(ctx, uri)
+ if err != nil {
+ return gtserror.Newf("error getting actor account %s from db: %w", uri, err)
+ }
+
+ // Fetch a transport for request signing for actor's account username.
+ tsport, err = p.transport.NewTransportForUsername(ctx, account.Username)
+ if err != nil {
+ return gtserror.Newf("error getting transport for actor %s: %w", uri, err)
+ }
+ } else {
+ var err error
+
+ // No actor was given, will be signed by instance account.
+ tsport, err = p.transport.NewTransportForUsername(ctx, "")
+ if err != nil {
+ return gtserror.Newf("error getting instance account transport: %w", err)
+ }
+ }
+
+ // Using transport, add actor signature to delivery.
+ if err := tsport.SignDelivery(dlv); err != nil {
+ return gtserror.Newf("error signing delivery: %w", err)
+ }
+
+ // Push deserialized task to delivery queue.
+ p.state.Workers.Delivery.Queue.Push(dlv)
+
+ return nil
+}
+
+// popDelivery pops delivery.Delivery{} from queue and serializes as valid task data.
+func (p *Processor) popDelivery() (*gtsmodel.WorkerTask, error) {
+
+ // Pop waiting delivery from the delivery worker.
+ delivery, ok := p.state.Workers.Delivery.Queue.Pop()
+ if !ok {
+ return nil, nil
+ }
+
+ // Serialize the delivery task data.
+ data, err := delivery.Serialize()
+ if err != nil {
+ return nil, gtserror.Newf("error serializing delivery: %w", err)
+ }
+
+ return &gtsmodel.WorkerTask{
+ // ID is autoincrement
+ WorkerType: gtsmodel.DeliveryWorker,
+ TaskData: data,
+ CreatedAt: time.Now(),
+ }, nil
+}
+
+// pushFederator parses a valid messages.FromFediAPI{} from serialized task data and pushes to queue.
+func (p *Processor) pushFederator(ctx context.Context, task *gtsmodel.WorkerTask) error {
+ var msg messages.FromFediAPI
+
+ // Deserialize the raw worker task data into message.
+ if err := msg.Deserialize(task.TaskData); err != nil {
+ return gtserror.Newf("error deserializing federator message: %w", err)
+ }
+
+ if rcv := msg.Receiving; rcv != nil {
+ // Only a placeholder receiving account will be populated,
+ // fetch the actual model from database by persisted ID.
+ account, err := p.state.DB.GetAccountByID(ctx, rcv.ID)
+ if err != nil {
+ return gtserror.Newf("error fetching receiving account %s from db: %w", rcv.ID, err)
+ }
+
+ // Set the now populated
+ // receiving account model.
+ msg.Receiving = account
+ }
+
+ if req := msg.Requesting; req != nil {
+ // Only a placeholder requesting account will be populated,
+ // fetch the actual model from database by persisted ID.
+ account, err := p.state.DB.GetAccountByID(ctx, req.ID)
+ if err != nil {
+ return gtserror.Newf("error fetching requesting account %s from db: %w", req.ID, err)
+ }
+
+ // Set the now populated
+ // requesting account model.
+ msg.Requesting = account
+ }
+
+ // Push populated task to the federator queue.
+ p.state.Workers.Federator.Queue.Push(&msg)
+
+ return nil
+}
+
+// popFederator pops messages.FromFediAPI{} from queue and serializes as valid task data.
+func (p *Processor) popFederator() (*gtsmodel.WorkerTask, error) {
+
+ // Pop waiting message from the federator worker.
+ msg, ok := p.state.Workers.Federator.Queue.Pop()
+ if !ok {
+ return nil, nil
+ }
+
+ // Serialize message task data.
+ data, err := msg.Serialize()
+ if err != nil {
+ return nil, gtserror.Newf("error serializing federator message: %w", err)
+ }
+
+ return &gtsmodel.WorkerTask{
+ // ID is autoincrement
+ WorkerType: gtsmodel.FederatorWorker,
+ TaskData: data,
+ CreatedAt: time.Now(),
+ }, nil
+}
+
+// pushClient parses a valid messages.FromClientAPI{} from serialized task data and pushes to queue.
+func (p *Processor) pushClient(ctx context.Context, task *gtsmodel.WorkerTask) error {
+ var msg messages.FromClientAPI
+
+ // Deserialize the raw worker task data into message.
+ if err := msg.Deserialize(task.TaskData); err != nil {
+ return gtserror.Newf("error deserializing client message: %w", err)
+ }
+
+ if org := msg.Origin; org != nil {
+ // Only a placeholder origin account will be populated,
+ // fetch the actual model from database by persisted ID.
+ account, err := p.state.DB.GetAccountByID(ctx, org.ID)
+ if err != nil {
+ return gtserror.Newf("error fetching origin account %s from db: %w", org.ID, err)
+ }
+
+ // Set the now populated
+ // origin account model.
+ msg.Origin = account
+ }
+
+ if trg := msg.Target; trg != nil {
+ // Only a placeholder target account will be populated,
+ // fetch the actual model from database by persisted ID.
+ account, err := p.state.DB.GetAccountByID(ctx, trg.ID)
+ if err != nil {
+ return gtserror.Newf("error fetching target account %s from db: %w", trg.ID, err)
+ }
+
+ // Set the now populated
+ // target account model.
+ msg.Target = account
+ }
+
+ // Push populated task to the client queue.
+ p.state.Workers.Client.Queue.Push(&msg)
+
+ return nil
+}
+
+// popClient pops messages.FromClientAPI{} from queue and serializes as valid task data.
+func (p *Processor) popClient() (*gtsmodel.WorkerTask, error) {
+
+ // Pop waiting message from the client worker.
+ msg, ok := p.state.Workers.Client.Queue.Pop()
+ if !ok {
+ return nil, nil
+ }
+
+ // Serialize message task data.
+ data, err := msg.Serialize()
+ if err != nil {
+ return nil, gtserror.Newf("error serializing client message: %w", err)
+ }
+
+ return &gtsmodel.WorkerTask{
+ // ID is autoincrement
+ WorkerType: gtsmodel.ClientWorker,
+ TaskData: data,
+ CreatedAt: time.Now(),
+ }, nil
+}
diff --git a/internal/processing/admin/workertask_test.go b/internal/processing/admin/workertask_test.go
new file mode 100644
index 000000000..bf326bafd
--- /dev/null
+++ b/internal/processing/admin/workertask_test.go
@@ -0,0 +1,421 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+package admin_test
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "testing"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/ap"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/httpclient"
+ "github.com/superseriousbusiness/gotosocial/internal/messages"
+ "github.com/superseriousbusiness/gotosocial/internal/transport/delivery"
+ "github.com/superseriousbusiness/gotosocial/testrig"
+)
+
+var (
+ // TODO: move these test values into
+ // the testrig test models area. They'll
+ // need to be as both WorkerTask and as
+ // the raw types themselves.
+
+ testDeliveries = []*delivery.Delivery{
+ {
+ ObjectID: "https://google.com/users/bigboy/follow/1",
+ TargetID: "https://askjeeves.com/users/smallboy",
+ Request: toRequest("POST", "https://askjeeves.com/users/smallboy/inbox", []byte("data!"), http.Header{"Host": {"https://askjeeves.com"}}),
+ },
+ {
+ Request: toRequest("GET", "https://google.com", []byte("uwu im just a wittle seawch engwin"), http.Header{"Host": {"https://google.com"}}),
+ },
+ }
+
+ testFederatorMsgs = []*messages.FromFediAPI{
+ {
+ APObjectType: ap.ObjectNote,
+ APActivityType: ap.ActivityCreate,
+ TargetURI: "https://gotosocial.org",
+ Requesting: &gtsmodel.Account{ID: "654321"},
+ Receiving: &gtsmodel.Account{ID: "123456"},
+ },
+ {
+ APObjectType: ap.ObjectProfile,
+ APActivityType: ap.ActivityUpdate,
+ TargetURI: "https://uk-queen-is-dead.org",
+ Requesting: &gtsmodel.Account{ID: "123456"},
+ Receiving: &gtsmodel.Account{ID: "654321"},
+ },
+ }
+
+ testClientMsgs = []*messages.FromClientAPI{
+ {
+ APObjectType: ap.ObjectNote,
+ APActivityType: ap.ActivityCreate,
+ TargetURI: "https://gotosocial.org",
+ Origin: &gtsmodel.Account{ID: "654321"},
+ Target: &gtsmodel.Account{ID: "123456"},
+ },
+ {
+ APObjectType: ap.ObjectProfile,
+ APActivityType: ap.ActivityUpdate,
+ TargetURI: "https://uk-queen-is-dead.org",
+ Origin: &gtsmodel.Account{ID: "123456"},
+ Target: &gtsmodel.Account{ID: "654321"},
+ },
+ }
+)
+
+type WorkerTaskTestSuite struct {
+ AdminStandardTestSuite
+}
+
+func (suite *WorkerTaskTestSuite) TestFillWorkerQueues() {
+ ctx, cncl := context.WithCancel(context.Background())
+ defer cncl()
+
+ var tasks []*gtsmodel.WorkerTask
+
+ for _, dlv := range testDeliveries {
+ // Serialize all test deliveries.
+ data, err := dlv.Serialize()
+ if err != nil {
+ panic(err)
+ }
+
+ // Append each serialized delivery to tasks.
+ tasks = append(tasks, &gtsmodel.WorkerTask{
+ WorkerType: gtsmodel.DeliveryWorker,
+ TaskData: data,
+ })
+ }
+
+ for _, msg := range testFederatorMsgs {
+ // Serialize all test messages.
+ data, err := msg.Serialize()
+ if err != nil {
+ panic(err)
+ }
+
+ if msg.Receiving != nil {
+ // Quick hack to bypass database errors for non-existing
+ // accounts, instead we just insert this into cache ;).
+ suite.state.Caches.DB.Account.Put(msg.Receiving)
+ suite.state.Caches.DB.AccountSettings.Put(&gtsmodel.AccountSettings{
+ AccountID: msg.Receiving.ID,
+ })
+ }
+
+ if msg.Requesting != nil {
+ // Quick hack to bypass database errors for non-existing
+ // accounts, instead we just insert this into cache ;).
+ suite.state.Caches.DB.Account.Put(msg.Requesting)
+ suite.state.Caches.DB.AccountSettings.Put(&gtsmodel.AccountSettings{
+ AccountID: msg.Requesting.ID,
+ })
+ }
+
+ // Append each serialized message to tasks.
+ tasks = append(tasks, &gtsmodel.WorkerTask{
+ WorkerType: gtsmodel.FederatorWorker,
+ TaskData: data,
+ })
+ }
+
+ for _, msg := range testClientMsgs {
+ // Serialize all test messages.
+ data, err := msg.Serialize()
+ if err != nil {
+ panic(err)
+ }
+
+ if msg.Origin != nil {
+ // Quick hack to bypass database errors for non-existing
+ // accounts, instead we just insert this into cache ;).
+ suite.state.Caches.DB.Account.Put(msg.Origin)
+ suite.state.Caches.DB.AccountSettings.Put(&gtsmodel.AccountSettings{
+ AccountID: msg.Origin.ID,
+ })
+ }
+
+ if msg.Target != nil {
+ // Quick hack to bypass database errors for non-existing
+ // accounts, instead we just insert this into cache ;).
+ suite.state.Caches.DB.Account.Put(msg.Target)
+ suite.state.Caches.DB.AccountSettings.Put(&gtsmodel.AccountSettings{
+ AccountID: msg.Target.ID,
+ })
+ }
+
+ // Append each serialized message to tasks.
+ tasks = append(tasks, &gtsmodel.WorkerTask{
+ WorkerType: gtsmodel.ClientWorker,
+ TaskData: data,
+ })
+ }
+
+ // Persist all test worker tasks to the database.
+ err := suite.state.DB.PutWorkerTasks(ctx, tasks)
+ suite.NoError(err)
+
+ // Fill the worker queues from persisted task data.
+ err = suite.adminProcessor.FillWorkerQueues(ctx)
+ suite.NoError(err)
+
+ var (
+ // Recovered
+ // task counts.
+ ndelivery int
+ nfederator int
+ nclient int
+ )
+
+ // Fetch current gotosocial instance account, for later checks.
+ instanceAcc, err := suite.state.DB.GetInstanceAccount(ctx, "")
+ suite.NoError(err)
+
+ for {
+ // Pop all queued delivery tasks from worker queue.
+ dlv, ok := suite.state.Workers.Delivery.Queue.Pop()
+ if !ok {
+ break
+ }
+
+ // Incr count.
+ ndelivery++
+
+ // Check that we have this message in slice.
+ err = containsSerializable(testDeliveries, dlv)
+ suite.NoError(err)
+
+ // Check that delivery request context has instance account pubkey.
+ pubKeyID := gtscontext.OutgoingPublicKeyID(dlv.Request.Context())
+ suite.Equal(instanceAcc.PublicKeyURI, pubKeyID)
+ signfn := gtscontext.HTTPClientSignFunc(dlv.Request.Context())
+ suite.NotNil(signfn)
+ }
+
+ for {
+ // Pop all queued federator messages from worker queue.
+ msg, ok := suite.state.Workers.Federator.Queue.Pop()
+ if !ok {
+ break
+ }
+
+ // Incr count.
+ nfederator++
+
+ // Check that we have this message in slice.
+ err = containsSerializable(testFederatorMsgs, msg)
+ suite.NoError(err)
+ }
+
+ for {
+ // Pop all queued client messages from worker queue.
+ msg, ok := suite.state.Workers.Client.Queue.Pop()
+ if !ok {
+ break
+ }
+
+ // Incr count.
+ nclient++
+
+ // Check that we have this message in slice.
+ err = containsSerializable(testClientMsgs, msg)
+ suite.NoError(err)
+ }
+
+ // Ensure recovered task counts as expected.
+ suite.Equal(len(testDeliveries), ndelivery)
+ suite.Equal(len(testFederatorMsgs), nfederator)
+ suite.Equal(len(testClientMsgs), nclient)
+}
+
+func (suite *WorkerTaskTestSuite) TestPersistWorkerQueues() {
+ ctx, cncl := context.WithCancel(context.Background())
+ defer cncl()
+
+ // Push all test worker tasks to their respective queues.
+ suite.state.Workers.Delivery.Queue.Push(testDeliveries...)
+ suite.state.Workers.Federator.Queue.Push(testFederatorMsgs...)
+ suite.state.Workers.Client.Queue.Push(testClientMsgs...)
+
+ // Persist the worker queued tasks to database.
+ err := suite.adminProcessor.PersistWorkerQueues(ctx)
+ suite.NoError(err)
+
+ // Fetch all the persisted tasks from database.
+ tasks, err := suite.state.DB.GetWorkerTasks(ctx)
+ suite.NoError(err)
+
+ var (
+ // Persisted
+ // task counts.
+ ndelivery int
+ nfederator int
+ nclient int
+ )
+
+ // Check persisted task data.
+ for _, task := range tasks {
+ switch task.WorkerType {
+ case gtsmodel.DeliveryWorker:
+ var dlv delivery.Delivery
+
+ // Incr count.
+ ndelivery++
+
+ // Deserialize the persisted task data.
+ err := dlv.Deserialize(task.TaskData)
+ suite.NoError(err)
+
+ // Check that we have this delivery in slice.
+ err = containsSerializable(testDeliveries, &dlv)
+ suite.NoError(err)
+
+ case gtsmodel.FederatorWorker:
+ var msg messages.FromFediAPI
+
+ // Incr count.
+ nfederator++
+
+ // Deserialize the persisted task data.
+ err := msg.Deserialize(task.TaskData)
+ suite.NoError(err)
+
+ // Check that we have this message in slice.
+ err = containsSerializable(testFederatorMsgs, &msg)
+ suite.NoError(err)
+
+ case gtsmodel.ClientWorker:
+ var msg messages.FromClientAPI
+
+ // Incr count.
+ nclient++
+
+ // Deserialize the persisted task data.
+ err := msg.Deserialize(task.TaskData)
+ suite.NoError(err)
+
+ // Check that we have this message in slice.
+ err = containsSerializable(testClientMsgs, &msg)
+ suite.NoError(err)
+
+ default:
+ suite.T().Errorf("unexpected worker type: %d", task.WorkerType)
+ }
+ }
+
+ // Ensure persisted task counts as expected.
+ suite.Equal(len(testDeliveries), ndelivery)
+ suite.Equal(len(testFederatorMsgs), nfederator)
+ suite.Equal(len(testClientMsgs), nclient)
+}
+
+func (suite *WorkerTaskTestSuite) SetupTest() {
+ suite.AdminStandardTestSuite.SetupTest()
+ // we don't want workers running
+ testrig.StopWorkers(&suite.state)
+}
+
+func TestWorkerTaskTestSuite(t *testing.T) {
+ suite.Run(t, new(WorkerTaskTestSuite))
+}
+
+// containsSerializable returns whether the slice of serializables contains the given serializable entry.
+func containsSerializable[T interface{ Serialize() ([]byte, error) }](expect []T, have T) error {
+ // Serialize wanted value.
+ bh, err := have.Serialize()
+ if err != nil {
+ panic(err)
+ }
+
+ var strings []string
+
+ for _, t := range expect {
+ // Serialize expected value.
+ be, err := t.Serialize()
+ if err != nil {
+ panic(err)
+ }
+
+ // Alloc as string.
+ se := string(be)
+
+ if se == string(bh) {
+ // We have this entry!
+ return nil
+ }
+
+ // Add to serialized strings.
+ strings = append(strings, se)
+ }
+
+ return fmt.Errorf("could not find %s in %s", string(bh), strings)
+}
+
+// urlStr simply returns u.String() or "" if nil.
+func urlStr(u *url.URL) string {
+ if u == nil {
+ return ""
+ }
+ return u.String()
+}
+
+// accountID simply returns account.ID or "" if nil.
+func accountID(account *gtsmodel.Account) string {
+ if account == nil {
+ return ""
+ }
+ return account.ID
+}
+
+// toRequest creates httpclient.Request from HTTP method, URL and body data.
+func toRequest(method string, url string, body []byte, hdr http.Header) *httpclient.Request {
+ var rbody io.Reader
+ if body != nil {
+ rbody = bytes.NewReader(body)
+ }
+ req, err := http.NewRequest(method, url, rbody)
+ if err != nil {
+ panic(err)
+ }
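+ // Copy any provided header values onto the request.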
+ for key, values := range hdr {
+ for _, value := range values {
+ req.Header.Add(key, value)
+ }
+ }
+ return httpclient.WrapRequest(req)
+}
+
+// toJSON marshals input type as JSON data.
+func toJSON(a any) []byte {
+ b, err := json.Marshal(a)
+ if err != nil {
+ panic(err)
+ }
+ return b
+}
diff --git a/internal/transport/deliver.go b/internal/transport/deliver.go
index 30435b86f..36ad6f015 100644
--- a/internal/transport/deliver.go
+++ b/internal/transport/deliver.go
@@ -21,6 +21,7 @@
"bytes"
"context"
"encoding/json"
+ "io"
"net/http"
"net/url"
@@ -169,6 +170,38 @@ func (t *transport) prepare(
}, nil
}
+func (t *transport) SignDelivery(dlv *delivery.Delivery) error {
+ if dlv.Request.GetBody == nil {
+ return gtserror.New("delivery request body not rewindable")
+ }
+
+ // Get a new copy of the request body.
+ body, err := dlv.Request.GetBody()
+ if err != nil {
+ return gtserror.Newf("error getting request body: %w", err)
+ }
+
+ // Read body data into memory.
+ data, err := io.ReadAll(body)
+ if err != nil {
+ return gtserror.Newf("error reading request body: %w", err)
+ }
+
+ // Get signing function for POST data.
+ // (note that delivery is ALWAYS POST).
+ sign := t.signPOST(data)
+
+ // Extract delivery context.
+ ctx := dlv.Request.Context()
+
+ // Update delivery request context with signing details.
+ ctx = gtscontext.SetOutgoingPublicKeyID(ctx, t.pubKeyID)
+ ctx = gtscontext.SetHTTPClientSignFunc(ctx, sign)
+ dlv.Request.Request = dlv.Request.Request.WithContext(ctx)
+
+ return nil
+}
+
// getObjectID extracts an object ID from 'serialized' ActivityPub object map.
func getObjectID(obj map[string]interface{}) string {
switch t := obj["object"].(type) {
diff --git a/internal/transport/delivery/delivery.go b/internal/transport/delivery/delivery.go
index 1e3ebb054..e11eea83c 100644
--- a/internal/transport/delivery/delivery.go
+++ b/internal/transport/delivery/delivery.go
@@ -33,10 +33,6 @@
// be indexed (and so, dropped from queue)
// by any of these possible ID IRIs.
type Delivery struct {
- // PubKeyID is the signing public key
- // ID of the actor performing request.
- PubKeyID string
-
// ActorID contains the ActivityPub
// actor ID IRI (if any) of the activity
// being sent out by this request.
@@ -55,7 +51,7 @@ type Delivery struct {
// Request is the prepared (+ wrapped)
// httpclient.Client{} request that
// constitutes this ActivtyPub delivery.
- Request httpclient.Request
+ Request *httpclient.Request
// internal fields.
next time.Time
@@ -66,7 +62,6 @@ type Delivery struct {
// a json serialize / deserialize
// able shape that minimizes data.
type delivery struct {
- PubKeyID string `json:"pub_key_id,omitempty"`
ActorID string `json:"actor_id,omitempty"`
ObjectID string `json:"object_id,omitempty"`
TargetID string `json:"target_id,omitempty"`
@@ -101,7 +96,6 @@ func (dlv *Delivery) Serialize() ([]byte, error) {
// Marshal as internal JSON type.
return json.Marshal(delivery{
- PubKeyID: dlv.PubKeyID,
ActorID: dlv.ActorID,
ObjectID: dlv.ObjectID,
TargetID: dlv.TargetID,
@@ -125,7 +119,6 @@ func (dlv *Delivery) Deserialize(data []byte) error {
}
// Copy over simplest fields.
- dlv.PubKeyID = idlv.PubKeyID
dlv.ActorID = idlv.ActorID
dlv.ObjectID = idlv.ObjectID
dlv.TargetID = idlv.TargetID
@@ -143,6 +136,13 @@ func (dlv *Delivery) Deserialize(data []byte) error {
return err
}
+ // Copy over any stored header values.
+ for key, values := range idlv.Header {
+ for _, value := range values {
+ r.Header.Add(key, value)
+ }
+ }
+
// Wrap request in httpclient type.
dlv.Request = httpclient.WrapRequest(r)
diff --git a/internal/transport/delivery/delivery_test.go b/internal/transport/delivery/delivery_test.go
index e9eaf8fd1..81f32d5f8 100644
--- a/internal/transport/delivery/delivery_test.go
+++ b/internal/transport/delivery/delivery_test.go
@@ -35,32 +35,30 @@
}{
{
msg: delivery.Delivery{
- PubKeyID: "https://google.com/users/bigboy#pubkey",
ActorID: "https://google.com/users/bigboy",
ObjectID: "https://google.com/users/bigboy/follow/1",
TargetID: "https://askjeeves.com/users/smallboy",
- Request: toRequest("POST", "https://askjeeves.com/users/smallboy/inbox", []byte("data!")),
+ Request: toRequest("POST", "https://askjeeves.com/users/smallboy/inbox", []byte("data!"), http.Header{"Hello": {"world1", "world2"}}),
},
data: toJSON(map[string]any{
- "pub_key_id": "https://google.com/users/bigboy#pubkey",
- "actor_id": "https://google.com/users/bigboy",
- "object_id": "https://google.com/users/bigboy/follow/1",
- "target_id": "https://askjeeves.com/users/smallboy",
- "method": "POST",
- "url": "https://askjeeves.com/users/smallboy/inbox",
- "body": []byte("data!"),
- // "header": map[string][]string{},
+ "actor_id": "https://google.com/users/bigboy",
+ "object_id": "https://google.com/users/bigboy/follow/1",
+ "target_id": "https://askjeeves.com/users/smallboy",
+ "method": "POST",
+ "url": "https://askjeeves.com/users/smallboy/inbox",
+ "body": []byte("data!"),
+ "header": map[string][]string{"Hello": {"world1", "world2"}},
}),
},
{
msg: delivery.Delivery{
- Request: toRequest("GET", "https://google.com", []byte("uwu im just a wittle seawch engwin")),
+ Request: toRequest("GET", "https://google.com", []byte("uwu im just a wittle seawch engwin"), nil),
},
data: toJSON(map[string]any{
"method": "GET",
"url": "https://google.com",
"body": []byte("uwu im just a wittle seawch engwin"),
- // "header": map[string][]string{},
+ // "header": map[string][]string{},
}),
},
}
@@ -89,18 +87,18 @@ func TestDeserializeDelivery(t *testing.T) {
}
// Check that delivery fields are as expected.
- assert.Equal(t, test.msg.PubKeyID, msg.PubKeyID)
assert.Equal(t, test.msg.ActorID, msg.ActorID)
assert.Equal(t, test.msg.ObjectID, msg.ObjectID)
assert.Equal(t, test.msg.TargetID, msg.TargetID)
assert.Equal(t, test.msg.Request.Method, msg.Request.Method)
assert.Equal(t, test.msg.Request.URL, msg.Request.URL)
assert.Equal(t, readBody(test.msg.Request.Body), readBody(msg.Request.Body))
+ assert.Equal(t, test.msg.Request.Header, msg.Request.Header)
}
}
// toRequest creates httpclient.Request from HTTP method, URL and body data.
-func toRequest(method string, url string, body []byte) httpclient.Request {
+func toRequest(method string, url string, body []byte, hdr http.Header) *httpclient.Request {
var rbody io.Reader
if body != nil {
rbody = bytes.NewReader(body)
@@ -109,6 +107,11 @@ func toRequest(method string, url string, body []byte) httpclient.Request {
if err != nil {
panic(err)
}
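+ // Copy any provided header values onto the request.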
+ for key, values := range hdr {
+ for _, value := range values {
+ req.Header.Add(key, value)
+ }
+ }
return httpclient.WrapRequest(req)
}
diff --git a/internal/transport/delivery/worker.go b/internal/transport/delivery/worker.go
index ef31e94a6..d6d253769 100644
--- a/internal/transport/delivery/worker.go
+++ b/internal/transport/delivery/worker.go
@@ -19,6 +19,7 @@
import (
"context"
+ "errors"
"slices"
"time"
@@ -160,6 +161,13 @@ func (w *Worker) process(ctx context.Context) bool {
loop:
for {
+ // Before trying to get
+ // next delivery, check
+ // context still valid.
+ if ctx.Err() != nil {
+ return true
+ }
+
// Get next delivery.
dlv, ok := w.next(ctx)
if !ok {
@@ -195,16 +203,30 @@ func (w *Worker) process(ctx context.Context) bool {
// Attempt delivery of AP request.
rsp, retry, err := w.Client.DoOnce(
- &dlv.Request,
+ dlv.Request,
)
- if err == nil {
+ switch {
+ case err == nil:
// Ensure body closed.
_ = rsp.Body.Close()
continue loop
- }
- if !retry {
+ case errors.Is(err, context.Canceled) &&
+ ctx.Err() != nil:
+ // In the case of our own context
+ // being cancelled, push delivery
+ // back onto queue for persisting.
+ //
+ // Note we specifically check against
+ // context.Canceled here as it will
+ // be faster than the mutex lock of
+ // ctx.Err(), so gives an initial
+ // faster check in the if-clause.
+ w.Queue.Push(dlv)
+ continue loop
+
+ case !retry:
// Drop deliveries when no
// retry requested, or they
// reached max (either).
@@ -222,42 +244,36 @@ func (w *Worker) process(ctx context.Context) bool {
// next gets the next available delivery, blocking until available if necessary.
func (w *Worker) next(ctx context.Context) (*Delivery, bool) {
-loop:
- for {
- // Try pop next queued.
- dlv, ok := w.Queue.Pop()
+ // Try a fast-pop of queued
+ // delivery before anything.
+ dlv, ok := w.Queue.Pop()
- if !ok {
- // Check the backlog.
- if len(w.backlog) > 0 {
+ if !ok {
+ // Check the backlog.
+ if len(w.backlog) > 0 {
- // Sort by 'next' time.
- sortDeliveries(w.backlog)
+ // Sort by 'next' time.
+ sortDeliveries(w.backlog)
- // Pop next delivery.
- dlv := w.popBacklog()
+ // Pop next delivery.
+ dlv := w.popBacklog()
- return dlv, true
- }
-
- select {
- // Backlog is empty, we MUST
- // block until next enqueued.
- case <-w.Queue.Wait():
- continue loop
-
- // Worker was stopped.
- case <-ctx.Done():
- return nil, false
- }
+ return dlv, true
}
- // Replace request context for worker state canceling.
- ctx := gtscontext.WithValues(ctx, dlv.Request.Context())
- dlv.Request.Request = dlv.Request.Request.WithContext(ctx)
-
- return dlv, true
+ // Block on next delivery push
+ // OR worker context canceled.
+ dlv, ok = w.Queue.PopCtx(ctx)
+ if !ok {
+ return nil, false
+ }
}
+
+ // Replace request context for worker state canceling.
+ ctx = gtscontext.WithValues(ctx, dlv.Request.Context())
+ dlv.Request.Request = dlv.Request.Request.WithContext(ctx)
+
+ return dlv, true
}
// popBacklog pops next available from the backlog.
diff --git a/internal/transport/transport.go b/internal/transport/transport.go
index 2971ca603..7f7e985fc 100644
--- a/internal/transport/transport.go
+++ b/internal/transport/transport.go
@@ -30,6 +30,7 @@
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/httpclient"
+ "github.com/superseriousbusiness/gotosocial/internal/transport/delivery"
"github.com/superseriousbusiness/httpsig"
)
@@ -50,6 +51,10 @@ type Transport interface {
// transport client, retrying on certain preset errors.
POST(*http.Request, []byte) (*http.Response, error)
+ // SignDelivery adds HTTP request signing client "middleware"
+ // to the request context within given delivery.Delivery{}.
+ SignDelivery(*delivery.Delivery) error
+
// Deliver sends an ActivityStreams object.
Deliver(ctx context.Context, obj map[string]interface{}, to *url.URL) error
diff --git a/internal/workers/worker_msg.go b/internal/workers/worker_msg.go
index 92180651a..c7dc568d7 100644
--- a/internal/workers/worker_msg.go
+++ b/internal/workers/worker_msg.go
@@ -19,6 +19,7 @@
import (
"context"
+ "errors"
"codeberg.org/gruf/go-runners"
"codeberg.org/gruf/go-structr"
@@ -147,9 +148,25 @@ func (w *MsgWorker[T]) process(ctx context.Context) {
return
}
- // Attempt to process popped message type.
- if err := w.Process(ctx, msg); err != nil {
+ // Attempt to process message.
+ err := w.Process(ctx, msg)
+ if err != nil {
log.Errorf(ctx, "%p: error processing: %v", w, err)
+
+ if errors.Is(err, context.Canceled) &&
+ ctx.Err() != nil {
+ // In the case of our own context
+ // being cancelled, push message
+ // back onto queue for persisting.
+ //
+ // Note we specifically check against
+ // context.Canceled here as it will
+ // be faster than the mutex lock of
+ // ctx.Err(), so gives an initial
+ // faster check in the if-clause.
+ w.Queue.Push(msg)
+ break
+ }
}
}
}
diff --git a/internal/workers/workers.go b/internal/workers/workers.go
index 4d2b146b6..377a9d899 100644
--- a/internal/workers/workers.go
+++ b/internal/workers/workers.go
@@ -55,7 +55,8 @@ type Workers struct {
// StartScheduler starts the job scheduler.
func (w *Workers) StartScheduler() {
- _ = w.Scheduler.Start() // false = already running
+ // false = already running
+ _ = w.Scheduler.Start()
log.Info(nil, "started scheduler")
}
@@ -82,9 +83,12 @@ func (w *Workers) Start() {
log.Infof(nil, "started %d dereference workers", n)
}
-// Stop will stop all of the contained worker pools (and global scheduler).
+// Stop will stop all of the contained
+// worker pools (and global scheduler).
func (w *Workers) Stop() {
- _ = w.Scheduler.Stop() // false = not running
+ // false = not running
+ _ = w.Scheduler.Stop()
+ log.Info(nil, "stopped scheduler")
w.Delivery.Stop()
log.Info(nil, "stopped delivery workers")
diff --git a/testrig/db.go b/testrig/db.go
index 67a7e2439..e6b40c846 100644
--- a/testrig/db.go
+++ b/testrig/db.go
@@ -29,6 +29,8 @@
var testModels = []interface{}{
&gtsmodel.Account{},
+ &gtsmodel.AccountNote{},
+ &gtsmodel.AccountSettings{},
&gtsmodel.AccountToEmoji{},
&gtsmodel.Application{},
&gtsmodel.Block{},
@@ -67,8 +69,7 @@
&gtsmodel.Tombstone{},
&gtsmodel.Report{},
&gtsmodel.Rule{},
- &gtsmodel.AccountNote{},
- &gtsmodel.AccountSettings{},
+ &gtsmodel.WorkerTask{},
}
// NewTestDB returns a new initialized, empty database for testing.