diff --git a/cmd/gotosocial/action/server/server.go b/cmd/gotosocial/action/server/server.go index d43599b05..bc56e21f0 100644 --- a/cmd/gotosocial/action/server/server.go +++ b/cmd/gotosocial/action/server/server.go @@ -35,7 +35,6 @@ "github.com/superseriousbusiness/gotosocial/internal/middleware" "go.uber.org/automaxprocs/maxprocs" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db/bundb" "github.com/superseriousbusiness/gotosocial/internal/email" @@ -45,7 +44,6 @@ "github.com/superseriousbusiness/gotosocial/internal/httpclient" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oidc" "github.com/superseriousbusiness/gotosocial/internal/processing" @@ -107,19 +105,11 @@ state.Workers.Start() defer state.Workers.Stop() - // Create the client API and federator worker pools - // NOTE: these MUST NOT be used until they are passed to the - // processor and it is started. The reason being that the processor - // sets the Worker process functions and start the underlying pools - // TODO: move these into state.Workers (and maybe reformat worker pools). - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - // build backend handlers mediaManager := media.NewManager(&state) oauthServer := oauth.New(ctx, dbService) typeConverter := typeutils.NewConverter(dbService) - federatingDB := federatingdb.New(dbService, fedWorker, typeConverter) + federatingDB := federatingdb.New(&state, typeConverter) transportController := transport.NewController(dbService, federatingDB, &federation.Clock{}, client) federator := federation.NewFederator(dbService, federatingDB, transportController, typeConverter, mediaManager) @@ -140,11 +130,15 @@ } // create the message processor using the other services we've created so far - processor := processing.NewProcessor(typeConverter, federator, oauthServer, mediaManager, storage, dbService, emailSender, clientWorker, fedWorker) + processor := processing.NewProcessor(typeConverter, federator, oauthServer, mediaManager, &state, emailSender) if err := processor.Start(); err != nil { return fmt.Errorf("error creating processor: %s", err) } + // Set state client / federator worker enqueue functions + state.Workers.EnqueueClientAPI = processor.EnqueueClientAPI + state.Workers.EnqueueFederator = processor.EnqueueFederator + /* HTTP router initialization */ diff --git a/cmd/gotosocial/action/testrig/testrig.go b/cmd/gotosocial/action/testrig/testrig.go index 3be7907fe..68bb94ec3 100644 --- a/cmd/gotosocial/action/testrig/testrig.go +++ b/cmd/gotosocial/action/testrig/testrig.go @@ -33,14 +33,13 @@ "github.com/superseriousbusiness/gotosocial/cmd/gotosocial/action" "github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/gotosocial" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/log" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/middleware" "github.com/superseriousbusiness/gotosocial/internal/oidc" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/web" "github.com/superseriousbusiness/gotosocial/testrig" @@ -48,37 +47,44 @@ // Start creates and starts a gotosocial testrig server var Start action.GTSAction = func(ctx context.Context) error { + var state state.State + testrig.InitTestConfig() testrig.InitTestLog() - dbService := testrig.NewTestDB() - testrig.StandardDBSetup(dbService, nil) - var storageBackend *storage.Driver - if os.Getenv("GTS_STORAGE_BACKEND") == "s3" { - storageBackend, _ = storage.NewS3Storage() - } else { - storageBackend = testrig.NewInMemoryStorage() - } - testrig.StandardStorageSetup(storageBackend, "./testrig/media") + // Initialize caches + state.Caches.Init() + state.Caches.Start() + defer state.Caches.Stop() - // Create client API and federator worker pools - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) + state.DB = testrig.NewTestDB(&state) + testrig.StandardDBSetup(state.DB, nil) + + if os.Getenv("GTS_STORAGE_BACKEND") == "s3" { + state.Storage, _ = storage.NewS3Storage() + } else { + state.Storage = testrig.NewInMemoryStorage() + } + testrig.StandardStorageSetup(state.Storage, "./testrig/media") + + // Initialize workers. + state.Workers.Start() + defer state.Workers.Stop() // build backend handlers - transportController := testrig.NewTestTransportController(testrig.NewMockHTTPClient(func(req *http.Request) (*http.Response, error) { + transportController := testrig.NewTestTransportController(&state, testrig.NewMockHTTPClient(func(req *http.Request) (*http.Response, error) { r := io.NopCloser(bytes.NewReader([]byte{})) return &http.Response{ StatusCode: 200, Body: r, }, nil - }, ""), dbService, fedWorker) - mediaManager := testrig.NewTestMediaManager(dbService, storageBackend) - federator := testrig.NewTestFederator(dbService, transportController, storageBackend, mediaManager, fedWorker) + }, "")) + mediaManager := testrig.NewTestMediaManager(&state) + federator := testrig.NewTestFederator(&state, transportController, mediaManager) emailSender := testrig.NewEmailSender("./web/template/", nil) - processor := testrig.NewTestProcessor(dbService, storageBackend, federator, emailSender, mediaManager, clientWorker, fedWorker) + processor := testrig.NewTestProcessor(&state, federator, emailSender, mediaManager) if err := processor.Start(); err != nil { return fmt.Errorf("error starting processor: %s", err) } @@ -87,7 +93,7 @@ HTTP router initialization */ - router := testrig.NewTestRouter(dbService) + router := testrig.NewTestRouter(state.DB) // attach global middlewares which are used for every request router.AttachGlobalMiddleware( @@ -112,7 +118,7 @@ } } - routerSession, err := dbService.GetSession(ctx) + routerSession, err := state.DB.GetSession(ctx) if err != nil { return fmt.Errorf("error retrieving router session for session middleware: %w", err) } @@ -123,13 +129,13 @@ } var ( - authModule = api.NewAuth(dbService, processor, idp, routerSession, sessionName) // auth/oauth paths - clientModule = api.NewClient(dbService, processor) // api client endpoints - fileserverModule = api.NewFileserver(processor) // fileserver endpoints - wellKnownModule = api.NewWellKnown(processor) // .well-known endpoints - nodeInfoModule = api.NewNodeInfo(processor) // nodeinfo endpoint - activityPubModule = api.NewActivityPub(dbService, processor) // ActivityPub endpoints - webModule = web.New(dbService, processor) // web pages + user profiles + settings panels etc + authModule = api.NewAuth(state.DB, processor, idp, routerSession, sessionName) // auth/oauth paths + clientModule = api.NewClient(state.DB, processor) // api client endpoints + fileserverModule = api.NewFileserver(processor) // fileserver endpoints + wellKnownModule = api.NewWellKnown(processor) // .well-known endpoints + nodeInfoModule = api.NewNodeInfo(processor) // nodeinfo endpoint + activityPubModule = api.NewActivityPub(state.DB, processor) // ActivityPub endpoints + webModule = web.New(state.DB, processor) // web pages + user profiles + settings panels etc ) // these should be routed in order @@ -142,7 +148,7 @@ activityPubModule.RoutePublicKey(router) webModule.Route(router) - gts, err := gotosocial.NewServer(dbService, router, federator, mediaManager) + gts, err := gotosocial.NewServer(state.DB, router, federator, mediaManager) if err != nil { return fmt.Errorf("error creating gotosocial service: %s", err) } @@ -157,8 +163,8 @@ sig := <-sigs log.Infof(ctx, "received signal %s, shutting down", sig) - testrig.StandardDBTeardown(dbService) - testrig.StandardStorageTeardown(storageBackend) + testrig.StandardDBTeardown(state.DB) + testrig.StandardStorageTeardown(state.Storage) // close down all running services in order if err := gts.Stop(ctx); err != nil { diff --git a/internal/api/activitypub/emoji/emojiget_test.go b/internal/api/activitypub/emoji/emojiget_test.go index cd7333955..8f99efdfc 100644 --- a/internal/api/activitypub/emoji/emojiget_test.go +++ b/internal/api/activitypub/emoji/emojiget_test.go @@ -27,15 +27,14 @@ "github.com/gin-gonic/gin" "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/api/activitypub/emoji" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/middleware" "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/testrig" @@ -50,6 +49,7 @@ type EmojiGetTestSuite struct { emailSender email.Sender processor *processing.Processor storage *storage.Driver + state state.State testEmojis map[string]*gtsmodel.Emoji testAccounts map[string]*gtsmodel.Account @@ -65,19 +65,23 @@ func (suite *EmojiGetTestSuite) SetupSuite() { } func (suite *EmojiGetTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.StartWorkers(&suite.state) + testrig.InitTestConfig() testrig.InitTestLog() - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - - suite.db = testrig.NewTestDB() - suite.tc = testrig.NewTestTypeConverter(suite.db) + suite.db = testrig.NewTestDB(&suite.state) + suite.state.DB = suite.db suite.storage = testrig.NewInMemoryStorage() - suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) - suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) + suite.state.Storage = suite.storage + + suite.tc = testrig.NewTestTypeConverter(suite.db) + + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) + suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) - suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) + suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) suite.emojiModule = emoji.New(suite.processor) testrig.StandardDBSetup(suite.db, suite.testAccounts) testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") @@ -90,6 +94,7 @@ func (suite *EmojiGetTestSuite) SetupTest() { func (suite *EmojiGetTestSuite) TearDownTest() { testrig.StandardDBTeardown(suite.db) testrig.StandardStorageTeardown(suite.storage) + testrig.StopWorkers(&suite.state) } func (suite *EmojiGetTestSuite) TestGetEmoji() { diff --git a/internal/api/activitypub/users/inboxpost_test.go b/internal/api/activitypub/users/inboxpost_test.go index 0ad63abf7..fa23204c9 100644 --- a/internal/api/activitypub/users/inboxpost_test.go +++ b/internal/api/activitypub/users/inboxpost_test.go @@ -34,11 +34,9 @@ "github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -86,13 +84,10 @@ func (suite *InboxPostTestSuite) TestPostBlock() { suite.NoError(err) body := bytes.NewReader(bodyJson) - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - - tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker) - federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) + tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")) + federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager) emailSender := testrig.NewEmailSender("../../../../web/template/", nil) - processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) + processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager) userModule := users.New(processor) suite.NoError(processor.Start()) @@ -190,13 +185,10 @@ func (suite *InboxPostTestSuite) TestPostUnblock() { suite.NoError(err) body := bytes.NewReader(bodyJson) - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - - tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker) - federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) + tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")) + federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager) emailSender := testrig.NewEmailSender("../../../../web/template/", nil) - processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) + processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager) userModule := users.New(processor) suite.NoError(processor.Start()) @@ -291,9 +283,6 @@ func (suite *InboxPostTestSuite) TestPostUpdate() { suite.NoError(err) body := bytes.NewReader(bodyJson) - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - // use a different version of the mock http client which serves the updated // version of the remote account, as though it had been updated there too; // this is needed so it can be dereferenced + updated properly @@ -301,10 +290,11 @@ func (suite *InboxPostTestSuite) TestPostUpdate() { mockHTTPClient.TestRemotePeople = map[string]vocab.ActivityStreamsPerson{ updatedAccount.URI: asAccount, } - tc := testrig.NewTestTransportController(mockHTTPClient, suite.db, fedWorker) - federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) + + tc := testrig.NewTestTransportController(&suite.state, mockHTTPClient) + federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager) emailSender := testrig.NewEmailSender("../../../../web/template/", nil) - processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) + processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager) userModule := users.New(processor) suite.NoError(processor.Start()) @@ -430,15 +420,12 @@ func (suite *InboxPostTestSuite) TestPostDelete() { suite.NoError(err) body := bytes.NewReader(bodyJson) - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - - tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker) - federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) + tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")) + federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager) emailSender := testrig.NewEmailSender("../../../../web/template/", nil) - processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) - suite.NoError(processor.Start()) + processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager) userModule := users.New(processor) + suite.NoError(processor.Start()) // setup request recorder := httptest.NewRecorder() diff --git a/internal/api/activitypub/users/outboxget_test.go b/internal/api/activitypub/users/outboxget_test.go index 6e5c4e1e0..8f3306a25 100644 --- a/internal/api/activitypub/users/outboxget_test.go +++ b/internal/api/activitypub/users/outboxget_test.go @@ -32,8 +32,6 @@ "github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -104,13 +102,10 @@ func (suite *OutboxGetTestSuite) TestGetOutboxFirstPage() { signedRequest := derefRequests["foss_satan_dereference_zork_outbox_first"] targetAccount := suite.testAccounts["local_account_1"] - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - - tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker) - federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) + tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")) + federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager) emailSender := testrig.NewEmailSender("../../../../web/template/", nil) - processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) + processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager) userModule := users.New(processor) suite.NoError(processor.Start()) @@ -182,13 +177,10 @@ func (suite *OutboxGetTestSuite) TestGetOutboxNextPage() { signedRequest := derefRequests["foss_satan_dereference_zork_outbox_next"] targetAccount := suite.testAccounts["local_account_1"] - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - - tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker) - federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) + tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")) + federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager) emailSender := testrig.NewEmailSender("../../../../web/template/", nil) - processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) + processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager) userModule := users.New(processor) suite.NoError(processor.Start()) diff --git a/internal/api/activitypub/users/repliesget_test.go b/internal/api/activitypub/users/repliesget_test.go index 4e985a0a1..92e5cddfa 100644 --- a/internal/api/activitypub/users/repliesget_test.go +++ b/internal/api/activitypub/users/repliesget_test.go @@ -33,8 +33,6 @@ "github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -104,13 +102,10 @@ func (suite *RepliesGetTestSuite) TestGetRepliesNext() { targetAccount := suite.testAccounts["local_account_1"] targetStatus := suite.testStatuses["local_account_1_status_1"] - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - - tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker) - federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) + tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")) + federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager) emailSender := testrig.NewEmailSender("../../../../web/template/", nil) - processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) + processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager) userModule := users.New(processor) suite.NoError(processor.Start()) @@ -172,13 +167,10 @@ func (suite *RepliesGetTestSuite) TestGetRepliesLast() { targetAccount := suite.testAccounts["local_account_1"] targetStatus := suite.testStatuses["local_account_1_status_1"] - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - - tc := testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker) - federator := testrig.NewTestFederator(suite.db, tc, suite.storage, suite.mediaManager, fedWorker) + tc := testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")) + federator := testrig.NewTestFederator(&suite.state, tc, suite.mediaManager) emailSender := testrig.NewEmailSender("../../../../web/template/", nil) - processor := testrig.NewTestProcessor(suite.db, suite.storage, federator, emailSender, suite.mediaManager, clientWorker, fedWorker) + processor := testrig.NewTestProcessor(&suite.state, federator, emailSender, suite.mediaManager) userModule := users.New(processor) suite.NoError(processor.Start()) diff --git a/internal/api/activitypub/users/user_test.go b/internal/api/activitypub/users/user_test.go index 0124925b9..d025eada0 100644 --- a/internal/api/activitypub/users/user_test.go +++ b/internal/api/activitypub/users/user_test.go @@ -22,15 +22,14 @@ "github.com/gin-gonic/gin" "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/api/activitypub/users" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/middleware" "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/testrig" @@ -46,6 +45,7 @@ type UserStandardTestSuite struct { emailSender email.Sender processor *processing.Processor storage *storage.Driver + state state.State // standard suite models testTokens map[string]*gtsmodel.Token @@ -75,19 +75,21 @@ func (suite *UserStandardTestSuite) SetupSuite() { } func (suite *UserStandardTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.StartWorkers(&suite.state) + testrig.InitTestConfig() testrig.InitTestLog() - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&suite.state) + suite.state.DB = suite.db suite.tc = testrig.NewTestTypeConverter(suite.db) suite.storage = testrig.NewInMemoryStorage() - suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) - suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) + suite.state.Storage = suite.storage + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) + suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) - suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) + suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) suite.userModule = users.New(suite.processor) testrig.StandardDBSetup(suite.db, suite.testAccounts) testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") @@ -100,4 +102,5 @@ func (suite *UserStandardTestSuite) SetupTest() { func (suite *UserStandardTestSuite) TearDownTest() { testrig.StandardDBTeardown(suite.db) testrig.StandardStorageTeardown(suite.storage) + testrig.StopWorkers(&suite.state) } diff --git a/internal/api/auth/auth_test.go b/internal/api/auth/auth_test.go index a5e518cda..1a15155bd 100644 --- a/internal/api/auth/auth_test.go +++ b/internal/api/auth/auth_test.go @@ -28,17 +28,16 @@ "github.com/gin-gonic/gin" "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/api/auth" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/middleware" "github.com/superseriousbusiness/gotosocial/internal/oidc" "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -47,6 +46,7 @@ type AuthStandardTestSuite struct { suite.Suite db db.DB storage *storage.Driver + state state.State mediaManager media.Manager federator federation.Federator processor *processing.Processor @@ -78,18 +78,19 @@ func (suite *AuthStandardTestSuite) SetupSuite() { } func (suite *AuthStandardTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.InitTestConfig() testrig.InitTestLog() - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&suite.state) + suite.state.DB = suite.db suite.storage = testrig.NewInMemoryStorage() - suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) - suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) + suite.state.Storage = suite.storage + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) + suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media")), suite.mediaManager) suite.emailSender = testrig.NewEmailSender("../../../web/template/", nil) - suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) + suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) suite.authModule = auth.New(suite.db, suite.processor, suite.idp) testrig.StandardDBSetup(suite.db, suite.testAccounts) } diff --git a/internal/api/client/accounts/account_test.go b/internal/api/client/accounts/account_test.go index 5a25c12f1..ab3f4cd1f 100644 --- a/internal/api/client/accounts/account_test.go +++ b/internal/api/client/accounts/account_test.go @@ -27,16 +27,15 @@ "github.com/gin-gonic/gin" "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/api/client/accounts" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -51,6 +50,7 @@ type AccountStandardTestSuite struct { processor *processing.Processor emailSender email.Sender sentEmails map[string]string + state state.State // standard suite models testTokens map[string]*gtsmodel.Token @@ -76,19 +76,22 @@ func (suite *AccountStandardTestSuite) SetupSuite() { } func (suite *AccountStandardTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.StartWorkers(&suite.state) + testrig.InitTestConfig() testrig.InitTestLog() - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&suite.state) + suite.state.DB = suite.db suite.storage = testrig.NewInMemoryStorage() - suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) - suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) + suite.state.Storage = suite.storage + + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) + suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) suite.sentEmails = make(map[string]string) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails) - suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) + suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) suite.accountsModule = accounts.New(suite.processor) testrig.StandardDBSetup(suite.db, nil) testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") @@ -99,6 +102,7 @@ func (suite *AccountStandardTestSuite) SetupTest() { func (suite *AccountStandardTestSuite) TearDownTest() { testrig.StandardDBTeardown(suite.db) testrig.StandardStorageTeardown(suite.storage) + testrig.StopWorkers(&suite.state) } func (suite *AccountStandardTestSuite) newContext(recorder *httptest.ResponseRecorder, requestMethod string, requestBody []byte, requestPath string, bodyContentType string) *gin.Context { diff --git a/internal/api/client/admin/admin_test.go b/internal/api/client/admin/admin_test.go index 4f3f48904..1d19635f0 100644 --- a/internal/api/client/admin/admin_test.go +++ b/internal/api/client/admin/admin_test.go @@ -27,16 +27,15 @@ "github.com/gin-gonic/gin" "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/api/client/admin" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -51,6 +50,7 @@ type AdminStandardTestSuite struct { processor *processing.Processor emailSender email.Sender sentEmails map[string]string + state state.State // standard suite models testTokens map[string]*gtsmodel.Token @@ -82,19 +82,22 @@ func (suite *AdminStandardTestSuite) SetupSuite() { } func (suite *AdminStandardTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.StartWorkers(&suite.state) + testrig.InitTestConfig() testrig.InitTestLog() - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&suite.state) + suite.state.DB = suite.db suite.storage = testrig.NewInMemoryStorage() - suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) - suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) + suite.state.Storage = suite.storage + + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) + suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) suite.sentEmails = make(map[string]string) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails) - suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) + suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) suite.adminModule = admin.New(suite.processor) testrig.StandardDBSetup(suite.db, nil) testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") @@ -103,6 +106,7 @@ func (suite *AdminStandardTestSuite) SetupTest() { func (suite *AdminStandardTestSuite) TearDownTest() { testrig.StandardDBTeardown(suite.db) testrig.StandardStorageTeardown(suite.storage) + testrig.StopWorkers(&suite.state) } func (suite *AdminStandardTestSuite) newContext(recorder *httptest.ResponseRecorder, requestMethod string, requestBody []byte, requestPath string, bodyContentType string) *gin.Context { diff --git a/internal/api/client/bookmarks/bookmarks_test.go b/internal/api/client/bookmarks/bookmarks_test.go index c39ad49f3..931d504f7 100644 --- a/internal/api/client/bookmarks/bookmarks_test.go +++ b/internal/api/client/bookmarks/bookmarks_test.go @@ -32,16 +32,15 @@ "github.com/superseriousbusiness/gotosocial/internal/api/client/bookmarks" "github.com/superseriousbusiness/gotosocial/internal/api/client/statuses" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/testrig" @@ -57,6 +56,7 @@ type BookmarkTestSuite struct { emailSender email.Sender processor *processing.Processor storage *storage.Driver + state state.State // standard suite models testTokens map[string]*gtsmodel.Token @@ -87,22 +87,25 @@ func (suite *BookmarkTestSuite) SetupSuite() { } func (suite *BookmarkTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.StartWorkers(&suite.state) + testrig.InitTestConfig() testrig.InitTestLog() - suite.db = testrig.NewTestDB() - suite.tc = testrig.NewTestTypeConverter(suite.db) + suite.db = testrig.NewTestDB(&suite.state) + suite.state.DB = suite.db suite.storage = testrig.NewInMemoryStorage() + suite.state.Storage = suite.storage + + suite.tc = testrig.NewTestTypeConverter(suite.db) testrig.StandardDBSetup(suite.db, nil) testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - - suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) - suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) + suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) - suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) + suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) suite.statusModule = statuses.New(suite.processor) suite.bookmarkModule = bookmarks.New(suite.processor) @@ -112,6 +115,7 @@ func (suite *BookmarkTestSuite) SetupTest() { func (suite *BookmarkTestSuite) TearDownTest() { testrig.StandardDBTeardown(suite.db) testrig.StandardStorageTeardown(suite.storage) + testrig.StopWorkers(&suite.state) } func (suite *BookmarkTestSuite) getBookmarks( diff --git a/internal/api/client/favourites/favourites_test.go b/internal/api/client/favourites/favourites_test.go index 7949aa38c..71c7097cc 100644 --- a/internal/api/client/favourites/favourites_test.go +++ b/internal/api/client/favourites/favourites_test.go @@ -21,14 +21,13 @@ import ( "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/api/client/favourites" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/testrig" @@ -44,6 +43,7 @@ type FavouritesStandardTestSuite struct { emailSender email.Sender processor *processing.Processor storage *storage.Driver + state state.State // standard suite models testTokens map[string]*gtsmodel.Token @@ -71,22 +71,25 @@ func (suite *FavouritesStandardTestSuite) SetupSuite() { } func (suite *FavouritesStandardTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.StartWorkers(&suite.state) + testrig.InitTestConfig() testrig.InitTestLog() - suite.db = testrig.NewTestDB() - suite.tc = testrig.NewTestTypeConverter(suite.db) + suite.db = testrig.NewTestDB(&suite.state) + suite.state.DB = suite.db suite.storage = testrig.NewInMemoryStorage() + suite.state.Storage = suite.storage + + suite.tc = testrig.NewTestTypeConverter(suite.db) testrig.StandardDBSetup(suite.db, nil) testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - - suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) - suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) + suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) - suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) + suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) suite.favModule = favourites.New(suite.processor) suite.NoError(suite.processor.Start()) @@ -95,6 +98,7 @@ func (suite *FavouritesStandardTestSuite) SetupTest() { func (suite *FavouritesStandardTestSuite) TearDownTest() { testrig.StandardDBTeardown(suite.db) testrig.StandardStorageTeardown(suite.storage) + testrig.StopWorkers(&suite.state) } func (suite *FavouritesStandardTestSuite) TestProcessFave() {} diff --git a/internal/api/client/followrequests/followrequest_test.go b/internal/api/client/followrequests/followrequest_test.go index 7a08479ab..294dbc7ed 100644 --- a/internal/api/client/followrequests/followrequest_test.go +++ b/internal/api/client/followrequests/followrequest_test.go @@ -26,16 +26,15 @@ "github.com/gin-gonic/gin" "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/api/client/followrequests" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -48,6 +47,7 @@ type FollowRequestStandardTestSuite struct { federator federation.Federator processor *processing.Processor emailSender email.Sender + state state.State // standard suite models testTokens map[string]*gtsmodel.Token @@ -73,18 +73,21 @@ func (suite *FollowRequestStandardTestSuite) SetupSuite() { } func (suite *FollowRequestStandardTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.StartWorkers(&suite.state) + testrig.InitTestConfig() testrig.InitTestLog() - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&suite.state) + suite.state.DB = suite.db suite.storage = testrig.NewInMemoryStorage() - suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) - suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) + suite.state.Storage = suite.storage + + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) + suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) - suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) + suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) suite.followRequestModule = followrequests.New(suite.processor) testrig.StandardDBSetup(suite.db, nil) testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") @@ -95,6 +98,7 @@ func (suite *FollowRequestStandardTestSuite) SetupTest() { func (suite *FollowRequestStandardTestSuite) TearDownTest() { testrig.StandardDBTeardown(suite.db) testrig.StandardStorageTeardown(suite.storage) + testrig.StopWorkers(&suite.state) } func (suite *FollowRequestStandardTestSuite) newContext(recorder *httptest.ResponseRecorder, requestMethod string, requestBody []byte, requestPath string, bodyContentType string) *gin.Context { diff --git a/internal/api/client/instance/instance_test.go b/internal/api/client/instance/instance_test.go index ff622febe..6870d2a44 100644 --- a/internal/api/client/instance/instance_test.go +++ b/internal/api/client/instance/instance_test.go @@ -26,16 +26,15 @@ "github.com/gin-gonic/gin" "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/api/client/instance" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -50,6 +49,7 @@ type InstanceStandardTestSuite struct { processor *processing.Processor emailSender email.Sender sentEmails map[string]string + state state.State // standard suite models testTokens map[string]*gtsmodel.Token @@ -75,19 +75,22 @@ func (suite *InstanceStandardTestSuite) SetupSuite() { } func (suite *InstanceStandardTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.StartWorkers(&suite.state) + testrig.InitTestConfig() testrig.InitTestLog() - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&suite.state) + suite.state.DB = suite.db suite.storage = testrig.NewInMemoryStorage() - suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) - suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) + suite.state.Storage = suite.storage + + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) + suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) suite.sentEmails = make(map[string]string) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails) - suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) + suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) suite.instanceModule = instance.New(suite.processor) testrig.StandardDBSetup(suite.db, nil) testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") @@ -96,6 +99,7 @@ func (suite *InstanceStandardTestSuite) SetupTest() { func (suite *InstanceStandardTestSuite) TearDownTest() { testrig.StandardDBTeardown(suite.db) testrig.StandardStorageTeardown(suite.storage) + testrig.StopWorkers(&suite.state) } func (suite *InstanceStandardTestSuite) newContext(recorder *httptest.ResponseRecorder, method string, path string, body []byte, contentType string, auth bool) *gin.Context { diff --git a/internal/api/client/media/mediacreate_test.go b/internal/api/client/media/mediacreate_test.go index caa40b061..6439895f3 100644 --- a/internal/api/client/media/mediacreate_test.go +++ b/internal/api/client/media/mediacreate_test.go @@ -33,7 +33,6 @@ "github.com/stretchr/testify/suite" mediamodule "github.com/superseriousbusiness/gotosocial/internal/api/client/media" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" @@ -41,9 +40,9 @@ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/testrig" @@ -60,6 +59,7 @@ type MediaCreateTestSuite struct { oauthServer oauth.Server emailSender email.Sender processor *processing.Processor + state state.State // standard suite models testTokens map[string]*gtsmodel.Token @@ -78,21 +78,24 @@ type MediaCreateTestSuite struct { */ func (suite *MediaCreateTestSuite) SetupSuite() { + suite.state.Caches.Init() + testrig.StartWorkers(&suite.state) + // setup standard items testrig.InitTestConfig() testrig.InitTestLog() - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&suite.state) + suite.state.DB = suite.db suite.storage = testrig.NewInMemoryStorage() + suite.state.Storage = suite.storage + suite.tc = testrig.NewTestTypeConverter(suite.db) - suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) suite.oauthServer = testrig.NewTestOauthServer(suite.db) - suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) + suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) - suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) + suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) // setup module being tested suite.mediaModule = mediamodule.New(suite.processor) @@ -102,11 +105,15 @@ func (suite *MediaCreateTestSuite) TearDownSuite() { if err := suite.db.Stop(context.Background()); err != nil { log.Panicf(nil, "error closing db connection: %s", err) } + testrig.StopWorkers(&suite.state) } func (suite *MediaCreateTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.StandardDBSetup(suite.db, nil) testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") + suite.testTokens = testrig.NewTestTokens() suite.testClients = testrig.NewTestClients() suite.testApplications = testrig.NewTestApplications() diff --git a/internal/api/client/media/mediaupdate_test.go b/internal/api/client/media/mediaupdate_test.go index cb96e8aa1..75657e1b5 100644 --- a/internal/api/client/media/mediaupdate_test.go +++ b/internal/api/client/media/mediaupdate_test.go @@ -31,7 +31,6 @@ "github.com/stretchr/testify/suite" mediamodule "github.com/superseriousbusiness/gotosocial/internal/api/client/media" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" @@ -39,9 +38,9 @@ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/testrig" @@ -58,6 +57,7 @@ type MediaUpdateTestSuite struct { oauthServer oauth.Server emailSender email.Sender processor *processing.Processor + state state.State // standard suite models testTokens map[string]*gtsmodel.Token @@ -76,21 +76,23 @@ type MediaUpdateTestSuite struct { */ func (suite *MediaUpdateTestSuite) SetupSuite() { + testrig.StartWorkers(&suite.state) + // setup standard items testrig.InitTestConfig() testrig.InitTestLog() - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&suite.state) + suite.state.DB = suite.db suite.storage = testrig.NewInMemoryStorage() + suite.state.Storage = suite.storage + suite.tc = testrig.NewTestTypeConverter(suite.db) - suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) suite.oauthServer = testrig.NewTestOauthServer(suite.db) - suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) + suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) - suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) + suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) // setup module being tested suite.mediaModule = mediamodule.New(suite.processor) @@ -100,11 +102,15 @@ func (suite *MediaUpdateTestSuite) TearDownSuite() { if err := suite.db.Stop(context.Background()); err != nil { log.Panicf(nil, "error closing db connection: %s", err) } + testrig.StopWorkers(&suite.state) } func (suite *MediaUpdateTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.StandardDBSetup(suite.db, nil) testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") + suite.testTokens = testrig.NewTestTokens() suite.testClients = testrig.NewTestClients() suite.testApplications = testrig.NewTestApplications() diff --git a/internal/api/client/reports/reports_test.go b/internal/api/client/reports/reports_test.go index 1c5a532b9..cdab0b77b 100644 --- a/internal/api/client/reports/reports_test.go +++ b/internal/api/client/reports/reports_test.go @@ -21,14 +21,13 @@ import ( "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/api/client/reports" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -42,6 +41,7 @@ type ReportsStandardTestSuite struct { processor *processing.Processor emailSender email.Sender sentEmails map[string]string + state state.State // standard suite models testTokens map[string]*gtsmodel.Token @@ -67,19 +67,22 @@ func (suite *ReportsStandardTestSuite) SetupSuite() { } func (suite *ReportsStandardTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.StartWorkers(&suite.state) + testrig.InitTestConfig() testrig.InitTestLog() - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&suite.state) + suite.state.DB = suite.db suite.storage = testrig.NewInMemoryStorage() - suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) - suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) + suite.state.Storage = suite.storage + + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) + suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) suite.sentEmails = make(map[string]string) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails) - suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) + suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) suite.reportsModule = reports.New(suite.processor) testrig.StandardDBSetup(suite.db, nil) testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") @@ -90,4 +93,5 @@ func (suite *ReportsStandardTestSuite) SetupTest() { func (suite *ReportsStandardTestSuite) TearDownTest() { testrig.StandardDBTeardown(suite.db) testrig.StandardStorageTeardown(suite.storage) + testrig.StopWorkers(&suite.state) } diff --git a/internal/api/client/search/search_test.go b/internal/api/client/search/search_test.go index 4580f6f9d..153328cc3 100644 --- a/internal/api/client/search/search_test.go +++ b/internal/api/client/search/search_test.go @@ -26,16 +26,15 @@ "github.com/gin-gonic/gin" "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/api/client/search" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -50,6 +49,7 @@ type SearchStandardTestSuite struct { processor *processing.Processor emailSender email.Sender sentEmails map[string]string + state state.State // standard suite models testTokens map[string]*gtsmodel.Token @@ -71,19 +71,22 @@ func (suite *SearchStandardTestSuite) SetupSuite() { } func (suite *SearchStandardTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.StartWorkers(&suite.state) + testrig.InitTestConfig() testrig.InitTestLog() - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&suite.state) + suite.state.DB = suite.db suite.storage = testrig.NewInMemoryStorage() - suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) - suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) + suite.state.Storage = suite.storage + + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) + suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) suite.sentEmails = make(map[string]string) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails) - suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) + suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) suite.searchModule = search.New(suite.processor) testrig.StandardDBSetup(suite.db, nil) testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") @@ -94,6 +97,7 @@ func (suite *SearchStandardTestSuite) SetupTest() { func (suite *SearchStandardTestSuite) TearDownTest() { testrig.StandardDBTeardown(suite.db) testrig.StandardStorageTeardown(suite.storage) + testrig.StopWorkers(&suite.state) } func (suite *SearchStandardTestSuite) newContext(recorder *httptest.ResponseRecorder, requestPath string) *gin.Context { diff --git a/internal/api/client/statuses/status_test.go b/internal/api/client/statuses/status_test.go index a87fd36f7..93745ffd8 100644 --- a/internal/api/client/statuses/status_test.go +++ b/internal/api/client/statuses/status_test.go @@ -21,14 +21,13 @@ import ( "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/api/client/statuses" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/testrig" @@ -44,6 +43,7 @@ type StatusStandardTestSuite struct { emailSender email.Sender processor *processing.Processor storage *storage.Driver + state state.State // standard suite models testTokens map[string]*gtsmodel.Token @@ -71,22 +71,26 @@ func (suite *StatusStandardTestSuite) SetupSuite() { } func (suite *StatusStandardTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.StartWorkers(&suite.state) + testrig.InitTestConfig() testrig.InitTestLog() - suite.db = testrig.NewTestDB() - suite.tc = testrig.NewTestTypeConverter(suite.db) + suite.db = testrig.NewTestDB(&suite.state) + suite.state.DB = suite.db suite.storage = testrig.NewInMemoryStorage() + suite.state.Storage = suite.storage + + suite.tc = testrig.NewTestTypeConverter(suite.db) + testrig.StandardDBSetup(suite.db, nil) testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - - suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) - suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) + suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) - suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) + suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) suite.statusModule = statuses.New(suite.processor) suite.NoError(suite.processor.Start()) @@ -95,4 +99,5 @@ func (suite *StatusStandardTestSuite) SetupTest() { func (suite *StatusStandardTestSuite) TearDownTest() { testrig.StandardDBTeardown(suite.db) testrig.StandardStorageTeardown(suite.storage) + testrig.StopWorkers(&suite.state) } diff --git a/internal/api/client/streaming/streaming_test.go b/internal/api/client/streaming/streaming_test.go index 5fb470af8..ac27aad8a 100644 --- a/internal/api/client/streaming/streaming_test.go +++ b/internal/api/client/streaming/streaming_test.go @@ -32,15 +32,14 @@ "github.com/gin-gonic/gin" "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/api/client/streaming" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/testrig" @@ -56,6 +55,7 @@ type StreamingTestSuite struct { emailSender email.Sender processor *processing.Processor storage *storage.Driver + state state.State // standard suite models testTokens map[string]*gtsmodel.Token @@ -83,22 +83,25 @@ func (suite *StreamingTestSuite) SetupSuite() { } func (suite *StreamingTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.StartWorkers(&suite.state) + testrig.InitTestConfig() testrig.InitTestLog() - suite.db = testrig.NewTestDB() - suite.tc = testrig.NewTestTypeConverter(suite.db) + suite.db = testrig.NewTestDB(&suite.state) + suite.state.DB = suite.db suite.storage = testrig.NewInMemoryStorage() + suite.state.Storage = suite.storage + + suite.tc = testrig.NewTestTypeConverter(suite.db) testrig.StandardDBSetup(suite.db, nil) testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - - suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) - suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) + suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) - suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) + suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) suite.streamingModule = streaming.New(suite.processor, 1, 4096) suite.NoError(suite.processor.Start()) } @@ -106,6 +109,7 @@ func (suite *StreamingTestSuite) SetupTest() { func (suite *StreamingTestSuite) TearDownTest() { testrig.StandardDBTeardown(suite.db) testrig.StandardStorageTeardown(suite.storage) + testrig.StopWorkers(&suite.state) } // Addr is a fake network interface which implements the net.Addr interface diff --git a/internal/api/client/user/user_test.go b/internal/api/client/user/user_test.go index c990abb56..ce117059e 100644 --- a/internal/api/client/user/user_test.go +++ b/internal/api/client/user/user_test.go @@ -21,14 +21,13 @@ import ( "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/api/client/user" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/testrig" @@ -43,6 +42,7 @@ type UserStandardTestSuite struct { emailSender email.Sender processor *processing.Processor storage *storage.Driver + state state.State testTokens map[string]*gtsmodel.Token testClients map[string]*gtsmodel.Client @@ -56,23 +56,29 @@ type UserStandardTestSuite struct { } func (suite *UserStandardTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.StartWorkers(&suite.state) + testrig.InitTestConfig() testrig.InitTestLog() - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) + suite.testTokens = testrig.NewTestTokens() suite.testClients = testrig.NewTestClients() suite.testApplications = testrig.NewTestApplications() suite.testUsers = testrig.NewTestUsers() suite.testAccounts = testrig.NewTestAccounts() - suite.db = testrig.NewTestDB() + + suite.db = testrig.NewTestDB(&suite.state) + suite.state.DB = suite.db suite.storage = testrig.NewInMemoryStorage() + suite.state.Storage = suite.storage + suite.tc = testrig.NewTestTypeConverter(suite.db) - suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) - suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) + suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) suite.sentEmails = make(map[string]string) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", suite.sentEmails) - suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) + suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) suite.userModule = user.New(suite.processor) testrig.StandardDBSetup(suite.db, suite.testAccounts) testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media") @@ -83,4 +89,5 @@ func (suite *UserStandardTestSuite) SetupTest() { func (suite *UserStandardTestSuite) TearDownTest() { testrig.StandardDBTeardown(suite.db) testrig.StandardStorageTeardown(suite.storage) + testrig.StopWorkers(&suite.state) } diff --git a/internal/api/fileserver/fileserver_test.go b/internal/api/fileserver/fileserver_test.go index 0a6879e70..0e0dd9434 100644 --- a/internal/api/fileserver/fileserver_test.go +++ b/internal/api/fileserver/fileserver_test.go @@ -23,16 +23,15 @@ "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/api/fileserver" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/testrig" @@ -43,6 +42,7 @@ type FileserverTestSuite struct { suite.Suite db db.DB storage *storage.Driver + state state.State federator federation.Federator tc typeutils.TypeConverter processor *processing.Processor @@ -67,26 +67,32 @@ type FileserverTestSuite struct { */ func (suite *FileserverTestSuite) SetupSuite() { + testrig.StartWorkers(&suite.state) + testrig.InitTestConfig() testrig.InitTestLog() - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&suite.state) + suite.state.DB = suite.db suite.storage = testrig.NewInMemoryStorage() - suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) - suite.emailSender = testrig.NewEmailSender("../../../web/template/", nil) + suite.state.Storage = suite.storage + + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) + suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media")), suite.mediaManager) + suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) - suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, testrig.NewTestMediaManager(suite.db, suite.storage), clientWorker, fedWorker) suite.tc = testrig.NewTestTypeConverter(suite.db) - suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) suite.oauthServer = testrig.NewTestOauthServer(suite.db) + suite.emailSender = testrig.NewEmailSender("../../../web/template/", nil) suite.fileServer = fileserver.New(suite.processor) } func (suite *FileserverTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.StartWorkers(&suite.state) + testrig.StandardDBSetup(suite.db, nil) testrig.StandardStorageSetup(suite.storage, "../../../testrig/media") suite.testTokens = testrig.NewTestTokens() @@ -101,9 +107,11 @@ func (suite *FileserverTestSuite) TearDownSuite() { if err := suite.db.Stop(context.Background()); err != nil { log.Panicf(nil, "error closing db connection: %s", err) } + testrig.StopWorkers(&suite.state) } func (suite *FileserverTestSuite) TearDownTest() { testrig.StandardDBTeardown(suite.db) testrig.StandardStorageTeardown(suite.storage) + testrig.StopWorkers(&suite.state) } diff --git a/internal/api/wellknown/webfinger/webfinger_test.go b/internal/api/wellknown/webfinger/webfinger_test.go index 38228e928..3148279c5 100644 --- a/internal/api/wellknown/webfinger/webfinger_test.go +++ b/internal/api/wellknown/webfinger/webfinger_test.go @@ -26,15 +26,14 @@ "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/ap" "github.com/superseriousbusiness/gotosocial/internal/api/wellknown/webfinger" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/testrig" @@ -44,6 +43,7 @@ type WebfingerStandardTestSuite struct { // standard suite interfaces suite.Suite db db.DB + state state.State tc typeutils.TypeConverter mediaManager media.Manager federator federation.Federator @@ -76,19 +76,21 @@ func (suite *WebfingerStandardTestSuite) SetupSuite() { } func (suite *WebfingerStandardTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.StartWorkers(&suite.state) + testrig.InitTestLog() testrig.InitTestConfig() - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&suite.state) + suite.state.DB = suite.db suite.tc = testrig.NewTestTypeConverter(suite.db) suite.storage = testrig.NewInMemoryStorage() - suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) - suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker) + suite.state.Storage = suite.storage + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) + suite.federator = testrig.NewTestFederator(&suite.state, testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../../testrig/media")), suite.mediaManager) suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil) - suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, suite.mediaManager, clientWorker, fedWorker) + suite.processor = testrig.NewTestProcessor(&suite.state, suite.federator, suite.emailSender, suite.mediaManager) suite.webfingerModule = webfinger.New(suite.processor) suite.oauthServer = testrig.NewTestOauthServer(suite.db) testrig.StandardDBSetup(suite.db, suite.testAccounts) @@ -100,6 +102,7 @@ func (suite *WebfingerStandardTestSuite) SetupTest() { func (suite *WebfingerStandardTestSuite) TearDownTest() { testrig.StandardDBTeardown(suite.db) testrig.StandardStorageTeardown(suite.storage) + testrig.StopWorkers(&suite.state) } func accountDomainAccount() *gtsmodel.Account { diff --git a/internal/api/wellknown/webfinger/webfingerget_test.go b/internal/api/wellknown/webfinger/webfingerget_test.go index 7587dfee1..a345d0602 100644 --- a/internal/api/wellknown/webfinger/webfingerget_test.go +++ b/internal/api/wellknown/webfinger/webfingerget_test.go @@ -30,9 +30,7 @@ "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/api/wellknown/webfinger" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -91,9 +89,7 @@ func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByHo config.SetHost("gts.example.org") config.SetAccountDomain("example.org") - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(suite.db, suite.storage), suite.storage, suite.db, suite.emailSender, clientWorker, fedWorker) + suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(&suite.state), &suite.state, suite.emailSender) suite.webfingerModule = webfinger.New(suite.processor) targetAccount := accountDomainAccount() @@ -148,9 +144,7 @@ func (suite *WebfingerGetTestSuite) TestFingerUserWithDifferentAccountDomainByAc config.SetHost("gts.example.org") config.SetAccountDomain("example.org") - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(suite.db, suite.storage), suite.storage, suite.db, suite.emailSender, clientWorker, fedWorker) + suite.processor = processing.NewProcessor(suite.tc, suite.federator, testrig.NewTestOauthServer(suite.db), testrig.NewTestMediaManager(&suite.state), &suite.state, suite.emailSender) suite.webfingerModule = webfinger.New(suite.processor) targetAccount := accountDomainAccount() diff --git a/internal/concurrency/workers.go b/internal/concurrency/workers.go deleted file mode 100644 index ed99509cf..000000000 --- a/internal/concurrency/workers.go +++ /dev/null @@ -1,141 +0,0 @@ -/* - GoToSocial - Copyright (C) 2021-2023 GoToSocial Authors admin@gotosocial.org - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see . -*/ - -package concurrency - -import ( - "context" - "errors" - "fmt" - "path" - "reflect" - "runtime" - - "codeberg.org/gruf/go-kv" - "codeberg.org/gruf/go-runners" - "github.com/superseriousbusiness/gotosocial/internal/log" -) - -// WorkerPool represents a proccessor for MsgType objects, using a worker pool to allocate resources. -type WorkerPool[MsgType any] struct { - workers runners.WorkerPool - process func(context.Context, MsgType) error - nw, nq int - wtype string // contains worker type for logging -} - -// New returns a new WorkerPool[MsgType] with given number of workers and queue ratio, -// where the queue ratio is multiplied by no. workers to get queue size. If args < 1 -// then suitable defaults are determined from the runtime's GOMAXPROCS variable. -func NewWorkerPool[MsgType any](workers int, queueRatio int) *WorkerPool[MsgType] { - var zero MsgType - - if workers < 1 { - // ensure sensible workers - workers = runtime.GOMAXPROCS(0) * 4 - } - if queueRatio < 1 { - // ensure sensible ratio - queueRatio = 100 - } - - // Calculate the short type string for the msg type - msgType := reflect.TypeOf(zero).String() - _, msgType = path.Split(msgType) - - w := &WorkerPool[MsgType]{ - process: nil, - nw: workers, - nq: workers * queueRatio, - wtype: fmt.Sprintf("worker.Worker[%s]", msgType), - } - - // Log new worker creation with worker type prefix - log.Infof(nil, "%s created with workers=%d queue=%d", - w.wtype, - workers, - workers*queueRatio, - ) - - return w -} - -// Start will attempt to start the underlying worker pool, or return error. -func (w *WorkerPool[MsgType]) Start() error { - log.Infof(nil, "%s starting", w.wtype) - - // Check processor was set - if w.process == nil { - return errors.New("nil Worker.process function") - } - - // Attempt to start pool - if !w.workers.Start(w.nw, w.nq) { - return errors.New("failed to start Worker pool") - } - - return nil -} - -// Stop will attempt to stop the underlying worker pool, or return error. -func (w *WorkerPool[MsgType]) Stop() error { - log.Infof(nil, "%s stopping", w.wtype) - - // Attempt to stop pool - if !w.workers.Stop() { - return errors.New("failed to stop Worker pool") - } - - return nil -} - -// SetProcessor will set the Worker's processor function, which is called for each queued message. -func (w *WorkerPool[MsgType]) SetProcessor(fn func(context.Context, MsgType) error) { - if w.process != nil { - log.Panicf(nil, "%s Worker.process is already set", w.wtype) - } - w.process = fn -} - -// Queue will queue provided message to be processed with there's a free worker. -func (w *WorkerPool[MsgType]) Queue(msg MsgType) { - log.Tracef(nil, "%s queueing message: %+v", w.wtype, msg) - - // Create new process function for msg - process := func(ctx context.Context) { - if err := w.process(ctx, msg); err != nil { - log.WithContext(ctx). - WithFields(kv.Fields{ - kv.Field{K: "type", V: w.wtype}, - kv.Field{K: "error", V: err}, - }...).Error("message processing error") - } - } - - // Attempt a fast-enqueue of process - if !w.workers.EnqueueNow(process) { - // No spot acquired, log warning - log.WithFields(kv.Fields{ - kv.Field{K: "type", V: w.wtype}, - kv.Field{K: "queue", V: w.workers.Queue()}, - }...).Warn("full worker queue") - - // Block on enqueuing process func - w.workers.Enqueue(process) - } -} diff --git a/internal/db/bundb/admin_test.go b/internal/db/bundb/admin_test.go index b0da97ef1..ce255d036 100644 --- a/internal/db/bundb/admin_test.go +++ b/internal/db/bundb/admin_test.go @@ -70,8 +70,8 @@ func (suite *AdminTestSuite) TestIsEmailAvailableDomainBlocked() { } func (suite *AdminTestSuite) TestCreateInstanceAccount() { - // reinitialize test DB to clear caches - suite.db = testrig.NewTestDB() + // reinitialize db caches to clear + suite.state.Caches.Init() // we need to take an empty db for this... testrig.StandardDBTeardown(suite.db) // ...with tables created but no data diff --git a/internal/db/bundb/bundb_test.go b/internal/db/bundb/bundb_test.go index e050c2b5d..bad8bfc72 100644 --- a/internal/db/bundb/bundb_test.go +++ b/internal/db/bundb/bundb_test.go @@ -22,13 +22,15 @@ "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/testrig" ) type BunDBStandardTestSuite struct { // standard suite interfaces suite.Suite - db db.DB + db db.DB + state state.State // standard suite models testTokens map[string]*gtsmodel.Token @@ -61,9 +63,10 @@ func (suite *BunDBStandardTestSuite) SetupSuite() { } func (suite *BunDBStandardTestSuite) SetupTest() { + suite.state.Caches.Init() testrig.InitTestConfig() testrig.InitTestLog() - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&suite.state) testrig.StandardDBSetup(suite.db, suite.testAccounts) } diff --git a/internal/federation/dereferencing/dereferencer_test.go b/internal/federation/dereferencing/dereferencer_test.go index daca8b7de..f5b59b0ed 100644 --- a/internal/federation/dereferencing/dereferencer_test.go +++ b/internal/federation/dereferencing/dereferencer_test.go @@ -21,11 +21,10 @@ import ( "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/activity/streams/vocab" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/federation/dereferencing" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/messages" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -34,6 +33,7 @@ type DereferencerStandardTestSuite struct { suite.Suite db db.DB storage *storage.Driver + state state.State testRemoteStatuses map[string]vocab.ActivityStreamsNote testRemotePeople map[string]vocab.ActivityStreamsPerson @@ -58,12 +58,19 @@ func (suite *DereferencerStandardTestSuite) SetupTest() { suite.testRemoteAttachments = testrig.NewTestFediAttachments("../../../testrig/media") suite.testEmojis = testrig.NewTestEmojis() - suite.db = testrig.NewTestDB() + suite.state.Caches.Init() + testrig.StartWorkers(&suite.state) + + suite.db = testrig.NewTestDB(&suite.state) suite.storage = testrig.NewInMemoryStorage() - suite.dereferencer = dereferencing.NewDereferencer(suite.db, testrig.NewTestTypeConverter(suite.db), testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, concurrency.NewWorkerPool[messages.FromFederator](-1, -1)), testrig.NewTestMediaManager(suite.db, suite.storage)) + suite.state.DB = suite.db + suite.state.Storage = suite.storage + media := testrig.NewTestMediaManager(&suite.state) + suite.dereferencer = dereferencing.NewDereferencer(suite.db, testrig.NewTestTypeConverter(suite.db), testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media")), media) testrig.StandardDBSetup(suite.db, nil) } func (suite *DereferencerStandardTestSuite) TearDownTest() { testrig.StandardDBTeardown(suite.db) + testrig.StopWorkers(&suite.state) } diff --git a/internal/federation/federatingactor_test.go b/internal/federation/federatingactor_test.go index 0d1d8e37f..f63ecd827 100644 --- a/internal/federation/federatingactor_test.go +++ b/internal/federation/federatingactor_test.go @@ -27,10 +27,8 @@ "time" "github.com/stretchr/testify/suite" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -56,14 +54,12 @@ func (suite *FederatingActorTestSuite) TestSendNoRemoteFollowers() { ) testActivity := testrig.WrapAPNoteInCreate(testrig.URLMustParse("http://localhost:8080/whatever_some_create"), testrig.URLMustParse(testAccount.URI), time.Now(), testNote) - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - // setup transport controller with a no-op client so we don't make external calls httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") - tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) + tc := testrig.NewTestTransportController(&suite.state, httpClient) // setup module being tested - federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) + federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state)) activity, err := federator.FederatingActor().Send(ctx, testrig.URLMustParse(testAccount.OutboxURI), testActivity) suite.NoError(err) @@ -105,12 +101,10 @@ func (suite *FederatingActorTestSuite) TestSendRemoteFollower() { ) testActivity := testrig.WrapAPNoteInCreate(testrig.URLMustParse("http://localhost:8080/whatever_some_create"), testrig.URLMustParse(testAccount.URI), testrig.TimeMustParse("2022-06-02T12:22:21+02:00"), testNote) - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") - tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) + tc := testrig.NewTestTransportController(&suite.state, httpClient) // setup module being tested - federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) + federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state)) activity, err := federator.FederatingActor().Send(ctx, testrig.URLMustParse(testAccount.OutboxURI), testActivity) suite.NoError(err) diff --git a/internal/federation/federatingdb/accept.go b/internal/federation/federatingdb/accept.go index d3e227a10..184d2b09d 100644 --- a/internal/federation/federatingdb/accept.go +++ b/internal/federation/federatingdb/accept.go @@ -65,7 +65,7 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA if uris.IsFollowPath(acceptedObjectIRI) { // ACCEPT FOLLOW gtsFollowRequest := >smodel.FollowRequest{} - if err := f.db.GetWhere(ctx, []db.Where{{Key: "uri", Value: acceptedObjectIRI.String()}}, gtsFollowRequest); err != nil { + if err := f.state.DB.GetWhere(ctx, []db.Where{{Key: "uri", Value: acceptedObjectIRI.String()}}, gtsFollowRequest); err != nil { return fmt.Errorf("ACCEPT: couldn't get follow request with id %s from the database: %s", acceptedObjectIRI.String(), err) } @@ -73,12 +73,12 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA if gtsFollowRequest.AccountID != receivingAccount.ID { return errors.New("ACCEPT: follow object account and inbox account were not the same") } - follow, err := f.db.AcceptFollowRequest(ctx, gtsFollowRequest.AccountID, gtsFollowRequest.TargetAccountID) + follow, err := f.state.DB.AcceptFollowRequest(ctx, gtsFollowRequest.AccountID, gtsFollowRequest.TargetAccountID) if err != nil { return err } - f.fedWorker.Queue(messages.FromFederator{ + f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{ APObjectType: ap.ActivityFollow, APActivityType: ap.ActivityAccept, GTSModel: follow, @@ -108,12 +108,12 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA if gtsFollow.AccountID != receivingAccount.ID { return errors.New("ACCEPT: follow object account and inbox account were not the same") } - follow, err := f.db.AcceptFollowRequest(ctx, gtsFollow.AccountID, gtsFollow.TargetAccountID) + follow, err := f.state.DB.AcceptFollowRequest(ctx, gtsFollow.AccountID, gtsFollow.TargetAccountID) if err != nil { return err } - f.fedWorker.Queue(messages.FromFederator{ + f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{ APObjectType: ap.ActivityFollow, APActivityType: ap.ActivityAccept, GTSModel: follow, diff --git a/internal/federation/federatingdb/announce.go b/internal/federation/federatingdb/announce.go index f4d145148..552a95ba9 100644 --- a/internal/federation/federatingdb/announce.go +++ b/internal/federation/federatingdb/announce.go @@ -59,7 +59,7 @@ func (f *federatingDB) Announce(ctx context.Context, announce vocab.ActivityStre } // it's a new announce so pass it back to the processor async for dereferencing etc - f.fedWorker.Queue(messages.FromFederator{ + f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{ APObjectType: ap.ActivityAnnounce, APActivityType: ap.ActivityCreate, GTSModel: boost, diff --git a/internal/federation/federatingdb/announce_test.go b/internal/federation/federatingdb/announce_test.go index 6c0d969f4..d9158f383 100644 --- a/internal/federation/federatingdb/announce_test.go +++ b/internal/federation/federatingdb/announce_test.go @@ -25,6 +25,7 @@ "github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/gotosocial/internal/ap" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/id" ) type AnnounceTestSuite struct { @@ -74,6 +75,13 @@ func (suite *AnnounceTestSuite) TestAnnounceTwice() { suite.True(ok) suite.Equal(announcingAccount.ID, boost.AccountID) + // Insert the boost-of status into the + // DB cache to emulate processor handling + boost.ID, _ = id.NewULIDFromTime(boost.CreatedAt) + suite.state.Caches.GTS.Status().Store(boost, func() error { + return nil + }) + // only the URI will be set on the boosted status because it still needs to be dereferenced suite.NotEmpty(boost.BoostOf.URI) diff --git a/internal/federation/federatingdb/create.go b/internal/federation/federatingdb/create.go index bf3e7f75d..ca87131fe 100644 --- a/internal/federation/federatingdb/create.go +++ b/internal/federation/federatingdb/create.go @@ -103,11 +103,11 @@ func (f *federatingDB) activityBlock(ctx context.Context, asType vocab.Type, rec block.ID = id.NewULID() - if err := f.db.PutBlock(ctx, block); err != nil { + if err := f.state.DB.PutBlock(ctx, block); err != nil { return fmt.Errorf("activityBlock: database error inserting block: %s", err) } - f.fedWorker.Queue(messages.FromFederator{ + f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{ APObjectType: ap.ActivityBlock, APActivityType: ap.ActivityCreate, GTSModel: block, @@ -202,7 +202,7 @@ func (f *federatingDB) createNote(ctx context.Context, note vocab.ActivityStream return nil } // pass the note iri into the processor and have it do the dereferencing instead of doing it here - f.fedWorker.Queue(messages.FromFederator{ + f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{ APObjectType: ap.ObjectNote, APActivityType: ap.ActivityCreate, APIri: id.GetIRI(), @@ -226,7 +226,7 @@ func (f *federatingDB) createNote(ctx context.Context, note vocab.ActivityStream } status.ID = statusID - if err := f.db.PutStatus(ctx, status); err != nil { + if err := f.state.DB.PutStatus(ctx, status); err != nil { if errors.Is(err, db.ErrAlreadyExists) { // the status already exists in the database, which means we've already handled everything else, // so we can just return nil here and be done with it. @@ -236,7 +236,7 @@ func (f *federatingDB) createNote(ctx context.Context, note vocab.ActivityStream return fmt.Errorf("createNote: database error inserting status: %s", err) } - f.fedWorker.Queue(messages.FromFederator{ + f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{ APObjectType: ap.ObjectNote, APActivityType: ap.ActivityCreate, GTSModel: status, @@ -263,11 +263,11 @@ func (f *federatingDB) activityFollow(ctx context.Context, asType vocab.Type, re followRequest.ID = id.NewULID() - if err := f.db.Put(ctx, followRequest); err != nil { + if err := f.state.DB.Put(ctx, followRequest); err != nil { return fmt.Errorf("activityFollow: database error inserting follow request: %s", err) } - f.fedWorker.Queue(messages.FromFederator{ + f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{ APObjectType: ap.ActivityFollow, APActivityType: ap.ActivityCreate, GTSModel: followRequest, @@ -294,11 +294,11 @@ func (f *federatingDB) activityLike(ctx context.Context, asType vocab.Type, rece fave.ID = id.NewULID() - if err := f.db.Put(ctx, fave); err != nil { + if err := f.state.DB.Put(ctx, fave); err != nil { return fmt.Errorf("activityLike: database error inserting fave: %s", err) } - f.fedWorker.Queue(messages.FromFederator{ + f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{ APObjectType: ap.ActivityLike, APActivityType: ap.ActivityCreate, GTSModel: fave, @@ -325,11 +325,11 @@ func (f *federatingDB) activityFlag(ctx context.Context, asType vocab.Type, rece report.ID = id.NewULID() - if err := f.db.PutReport(ctx, report); err != nil { + if err := f.state.DB.PutReport(ctx, report); err != nil { return fmt.Errorf("activityFlag: database error inserting report: %w", err) } - f.fedWorker.Queue(messages.FromFederator{ + f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{ APObjectType: ap.ActivityFlag, APActivityType: ap.ActivityCreate, GTSModel: report, diff --git a/internal/federation/federatingdb/db.go b/internal/federation/federatingdb/db.go index 24455a553..af4aceeeb 100644 --- a/internal/federation/federatingdb/db.go +++ b/internal/federation/federatingdb/db.go @@ -24,9 +24,7 @@ "codeberg.org/gruf/go-mutexes" "github.com/superseriousbusiness/activity/pub" "github.com/superseriousbusiness/activity/streams/vocab" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/messages" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/typeutils" ) @@ -43,17 +41,15 @@ type DB interface { // It doesn't care what the underlying implementation of the DB interface is, as long as it works. type federatingDB struct { locks mutexes.MutexMap - db db.DB - fedWorker *concurrency.WorkerPool[messages.FromFederator] + state *state.State typeConverter typeutils.TypeConverter } // New returns a DB interface using the given database and config -func New(db db.DB, fedWorker *concurrency.WorkerPool[messages.FromFederator], tc typeutils.TypeConverter) DB { +func New(state *state.State, tc typeutils.TypeConverter) DB { fdb := federatingDB{ locks: mutexes.NewMap(-1, -1), // use defaults - db: db, - fedWorker: fedWorker, + state: state, typeConverter: tc, } return &fdb diff --git a/internal/federation/federatingdb/delete.go b/internal/federation/federatingdb/delete.go index a1890b56b..695f199b4 100644 --- a/internal/federation/federatingdb/delete.go +++ b/internal/federation/federatingdb/delete.go @@ -51,9 +51,9 @@ func (f *federatingDB) Delete(ctx context.Context, id *url.URL) error { // in a delete we only get the URI, we can't know if we have a status or a profile or something else, // so we have to try a few different things... - if s, err := f.db.GetStatusByURI(ctx, id.String()); err == nil && requestingAccount.ID == s.AccountID { + if s, err := f.state.DB.GetStatusByURI(ctx, id.String()); err == nil && requestingAccount.ID == s.AccountID { l.Debugf("uri is for STATUS with id: %s", s.ID) - f.fedWorker.Queue(messages.FromFederator{ + f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{ APObjectType: ap.ObjectNote, APActivityType: ap.ActivityDelete, GTSModel: s, @@ -61,9 +61,9 @@ func (f *federatingDB) Delete(ctx context.Context, id *url.URL) error { }) } - if a, err := f.db.GetAccountByURI(ctx, id.String()); err == nil && requestingAccount.ID == a.ID { + if a, err := f.state.DB.GetAccountByURI(ctx, id.String()); err == nil && requestingAccount.ID == a.ID { l.Debugf("uri is for ACCOUNT with id %s", a.ID) - f.fedWorker.Queue(messages.FromFederator{ + f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{ APObjectType: ap.ObjectProfile, APActivityType: ap.ActivityDelete, GTSModel: a, diff --git a/internal/federation/federatingdb/federatingdb_test.go b/internal/federation/federatingdb/federatingdb_test.go index dd5a5f5f9..b0893f246 100644 --- a/internal/federation/federatingdb/federatingdb_test.go +++ b/internal/federation/federatingdb/federatingdb_test.go @@ -23,11 +23,11 @@ "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/ap" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/messages" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -36,9 +36,9 @@ type FederatingDBTestSuite struct { suite.Suite db db.DB tc typeutils.TypeConverter - fedWorker *concurrency.WorkerPool[messages.FromFederator] fromFederator chan messages.FromFederator federatingDB federatingdb.DB + state state.State testTokens map[string]*gtsmodel.Token testClients map[string]*gtsmodel.Client @@ -66,22 +66,33 @@ func (suite *FederatingDBTestSuite) SetupTest() { testrig.InitTestConfig() testrig.InitTestLog() - suite.fedWorker = concurrency.NewWorkerPool[messages.FromFederator](-1, -1) + suite.state.Caches.Init() + testrig.StartWorkers(&suite.state) + suite.fromFederator = make(chan messages.FromFederator, 10) - suite.fedWorker.SetProcessor(func(ctx context.Context, msg messages.FromFederator) error { + suite.state.Workers.EnqueueFederator = func(ctx context.Context, msg messages.FromFederator) { suite.fromFederator <- msg - return nil - }) - _ = suite.fedWorker.Start() - suite.db = testrig.NewTestDB() + } + + suite.db = testrig.NewTestDB(&suite.state) suite.testActivities = testrig.NewTestActivities(suite.testAccounts) suite.tc = testrig.NewTestTypeConverter(suite.db) - suite.federatingDB = testrig.NewTestFederatingDB(suite.db, suite.fedWorker) + suite.federatingDB = testrig.NewTestFederatingDB(&suite.state) testrig.StandardDBSetup(suite.db, suite.testAccounts) + + suite.state.DB = suite.db } func (suite *FederatingDBTestSuite) TearDownTest() { testrig.StandardDBTeardown(suite.db) + testrig.StopWorkers(&suite.state) + for suite.fromFederator != nil { + select { + case <-suite.fromFederator: + default: + return + } + } } func createTestContext(receivingAccount *gtsmodel.Account, requestingAccount *gtsmodel.Account) context.Context { diff --git a/internal/federation/federatingdb/followers.go b/internal/federation/federatingdb/followers.go index c47a2b625..69746c99b 100644 --- a/internal/federation/federatingdb/followers.go +++ b/internal/federation/federatingdb/followers.go @@ -29,7 +29,7 @@ func (f *federatingDB) Followers(ctx context.Context, actorIRI *url.URL) (follow return nil, err } - acctFollowers, err := f.db.GetAccountFollowedBy(ctx, acct.ID, false) + acctFollowers, err := f.state.DB.GetAccountFollowedBy(ctx, acct.ID, false) if err != nil { return nil, fmt.Errorf("Followers: db error getting followers for account id %s: %s", acct.ID, err) } @@ -37,7 +37,7 @@ func (f *federatingDB) Followers(ctx context.Context, actorIRI *url.URL) (follow iris := []*url.URL{} for _, follow := range acctFollowers { if follow.Account == nil { - a, err := f.db.GetAccountByID(ctx, follow.AccountID) + a, err := f.state.DB.GetAccountByID(ctx, follow.AccountID) if err != nil { errWrapped := fmt.Errorf("Followers: db error getting account id %s: %s", follow.AccountID, err) if err == db.ErrNoEntries { diff --git a/internal/federation/federatingdb/following.go b/internal/federation/federatingdb/following.go index f4f07bb25..9c22c0574 100644 --- a/internal/federation/federatingdb/following.go +++ b/internal/federation/federatingdb/following.go @@ -47,7 +47,7 @@ func (f *federatingDB) Following(ctx context.Context, actorIRI *url.URL) (follow return nil, err } - acctFollowing, err := f.db.GetAccountFollows(ctx, acct.ID) + acctFollowing, err := f.state.DB.GetAccountFollows(ctx, acct.ID) if err != nil { return nil, fmt.Errorf("Following: db error getting following for account id %s: %s", acct.ID, err) } @@ -55,7 +55,7 @@ func (f *federatingDB) Following(ctx context.Context, actorIRI *url.URL) (follow iris := []*url.URL{} for _, follow := range acctFollowing { if follow.TargetAccount == nil { - a, err := f.db.GetAccountByID(ctx, follow.TargetAccountID) + a, err := f.state.DB.GetAccountByID(ctx, follow.TargetAccountID) if err != nil { errWrapped := fmt.Errorf("Following: db error getting account id %s: %s", follow.TargetAccountID, err) if err == db.ErrNoEntries { diff --git a/internal/federation/federatingdb/get.go b/internal/federation/federatingdb/get.go index 92a79d70f..1d687f110 100644 --- a/internal/federation/federatingdb/get.go +++ b/internal/federation/federatingdb/get.go @@ -39,13 +39,13 @@ func (f *federatingDB) Get(ctx context.Context, id *url.URL) (value vocab.Type, switch { case uris.IsUserPath(id): - acct, err := f.db.GetAccountByURI(ctx, id.String()) + acct, err := f.state.DB.GetAccountByURI(ctx, id.String()) if err != nil { return nil, err } return f.typeConverter.AccountToAS(ctx, acct) case uris.IsStatusesPath(id): - status, err := f.db.GetStatusByURI(ctx, id.String()) + status, err := f.state.DB.GetStatusByURI(ctx, id.String()) if err != nil { return nil, err } diff --git a/internal/federation/federatingdb/inbox.go b/internal/federation/federatingdb/inbox.go index 5ec735bd4..1a6da4ef0 100644 --- a/internal/federation/federatingdb/inbox.go +++ b/internal/federation/federatingdb/inbox.go @@ -85,12 +85,12 @@ func (f *federatingDB) InboxesForIRI(c context.Context, iri *url.URL) (inboxIRIs return nil, fmt.Errorf("couldn't extract local account username from uri %s: %s", iri, err) } - account, err := f.db.GetAccountByUsernameDomain(c, localAccountUsername, "") + account, err := f.state.DB.GetAccountByUsernameDomain(c, localAccountUsername, "") if err != nil { return nil, fmt.Errorf("couldn't find local account with username %s: %s", localAccountUsername, err) } - follows, err := f.db.GetAccountFollowedBy(c, account.ID, false) + follows, err := f.state.DB.GetAccountFollowedBy(c, account.ID, false) if err != nil { return nil, fmt.Errorf("couldn't get followers of local account %s: %s", localAccountUsername, err) } @@ -98,7 +98,7 @@ func (f *federatingDB) InboxesForIRI(c context.Context, iri *url.URL) (inboxIRIs for _, follow := range follows { // make sure we retrieved the following account from the db if follow.Account == nil { - followingAccount, err := f.db.GetAccountByID(c, follow.AccountID) + followingAccount, err := f.state.DB.GetAccountByID(c, follow.AccountID) if err != nil { if err == db.ErrNoEntries { continue @@ -126,7 +126,7 @@ func (f *federatingDB) InboxesForIRI(c context.Context, iri *url.URL) (inboxIRIs } // check if this is just an account IRI... - if account, err := f.db.GetAccountByURI(c, iri.String()); err == nil { + if account, err := f.state.DB.GetAccountByURI(c, iri.String()); err == nil { // deliver to a shared inbox if we have that option var inbox string if config.GetInstanceDeliverToSharedInboxes() && account.SharedInboxURI != nil && *account.SharedInboxURI != "" { diff --git a/internal/federation/federatingdb/owns.go b/internal/federation/federatingdb/owns.go index def0fa518..2c11e8148 100644 --- a/internal/federation/federatingdb/owns.go +++ b/internal/federation/federatingdb/owns.go @@ -54,7 +54,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) { if err != nil { return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err) } - status, err := f.db.GetStatusByURI(ctx, uid) + status, err := f.state.DB.GetStatusByURI(ctx, uid) if err != nil { if err == db.ErrNoEntries { // there are no entries for this status @@ -71,7 +71,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) { if err != nil { return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err) } - if _, err := f.db.GetAccountByUsernameDomain(ctx, username, ""); err != nil { + if _, err := f.state.DB.GetAccountByUsernameDomain(ctx, username, ""); err != nil { if err == db.ErrNoEntries { // there are no entries for this username return false, nil @@ -88,7 +88,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) { if err != nil { return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err) } - if _, err := f.db.GetAccountByUsernameDomain(ctx, username, ""); err != nil { + if _, err := f.state.DB.GetAccountByUsernameDomain(ctx, username, ""); err != nil { if err == db.ErrNoEntries { // there are no entries for this username return false, nil @@ -105,7 +105,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) { if err != nil { return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err) } - if _, err := f.db.GetAccountByUsernameDomain(ctx, username, ""); err != nil { + if _, err := f.state.DB.GetAccountByUsernameDomain(ctx, username, ""); err != nil { if err == db.ErrNoEntries { // there are no entries for this username return false, nil @@ -122,7 +122,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) { if err != nil { return false, fmt.Errorf("error parsing like path for url %s: %s", id.String(), err) } - if _, err := f.db.GetAccountByUsernameDomain(ctx, username, ""); err != nil { + if _, err := f.state.DB.GetAccountByUsernameDomain(ctx, username, ""); err != nil { if err == db.ErrNoEntries { // there are no entries for this username return false, nil @@ -130,7 +130,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) { // an actual error happened return false, fmt.Errorf("database error fetching account with username %s: %s", username, err) } - if err := f.db.GetByID(ctx, likeID, >smodel.StatusFave{}); err != nil { + if err := f.state.DB.GetByID(ctx, likeID, >smodel.StatusFave{}); err != nil { if err == db.ErrNoEntries { // there are no entries return false, nil @@ -147,7 +147,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) { if err != nil { return false, fmt.Errorf("error parsing block path for url %s: %s", id.String(), err) } - if _, err := f.db.GetAccountByUsernameDomain(ctx, username, ""); err != nil { + if _, err := f.state.DB.GetAccountByUsernameDomain(ctx, username, ""); err != nil { if err == db.ErrNoEntries { // there are no entries for this username return false, nil @@ -155,7 +155,7 @@ func (f *federatingDB) Owns(ctx context.Context, id *url.URL) (bool, error) { // an actual error happened return false, fmt.Errorf("database error fetching account with username %s: %s", username, err) } - if err := f.db.GetByID(ctx, blockID, >smodel.Block{}); err != nil { + if err := f.state.DB.GetByID(ctx, blockID, >smodel.Block{}); err != nil { if err == db.ErrNoEntries { // there are no entries return false, nil diff --git a/internal/federation/federatingdb/reject.go b/internal/federation/federatingdb/reject.go index 3c3cd7c75..d443cd6cb 100644 --- a/internal/federation/federatingdb/reject.go +++ b/internal/federation/federatingdb/reject.go @@ -64,7 +64,7 @@ func (f *federatingDB) Reject(ctx context.Context, reject vocab.ActivityStreamsR if uris.IsFollowPath(rejectedObjectIRI) { // REJECT FOLLOW gtsFollowRequest := >smodel.FollowRequest{} - if err := f.db.GetWhere(ctx, []db.Where{{Key: "uri", Value: rejectedObjectIRI.String()}}, gtsFollowRequest); err != nil { + if err := f.state.DB.GetWhere(ctx, []db.Where{{Key: "uri", Value: rejectedObjectIRI.String()}}, gtsFollowRequest); err != nil { return fmt.Errorf("Reject: couldn't get follow request with id %s from the database: %s", rejectedObjectIRI.String(), err) } @@ -73,7 +73,7 @@ func (f *federatingDB) Reject(ctx context.Context, reject vocab.ActivityStreamsR return errors.New("Reject: follow object account and inbox account were not the same") } - if _, err := f.db.RejectFollowRequest(ctx, gtsFollowRequest.AccountID, gtsFollowRequest.TargetAccountID); err != nil { + if _, err := f.state.DB.RejectFollowRequest(ctx, gtsFollowRequest.AccountID, gtsFollowRequest.TargetAccountID); err != nil { return err } @@ -102,7 +102,7 @@ func (f *federatingDB) Reject(ctx context.Context, reject vocab.ActivityStreamsR if gtsFollow.AccountID != receivingAccount.ID { return errors.New("Reject: follow object account and inbox account were not the same") } - if _, err := f.db.RejectFollowRequest(ctx, gtsFollow.AccountID, gtsFollow.TargetAccountID); err != nil { + if _, err := f.state.DB.RejectFollowRequest(ctx, gtsFollow.AccountID, gtsFollow.TargetAccountID); err != nil { return err } diff --git a/internal/federation/federatingdb/undo.go b/internal/federation/federatingdb/undo.go index b239aabb4..e33b365fa 100644 --- a/internal/federation/federatingdb/undo.go +++ b/internal/federation/federatingdb/undo.go @@ -81,11 +81,11 @@ func (f *federatingDB) Undo(ctx context.Context, undo vocab.ActivityStreamsUndo) return errors.New("UNDO: follow object account and inbox account were not the same") } // delete any existing FOLLOW - if err := f.db.DeleteWhere(ctx, []db.Where{{Key: "uri", Value: gtsFollow.URI}}, >smodel.Follow{}); err != nil { + if err := f.state.DB.DeleteWhere(ctx, []db.Where{{Key: "uri", Value: gtsFollow.URI}}, >smodel.Follow{}); err != nil { return fmt.Errorf("UNDO: db error removing follow: %s", err) } // delete any existing FOLLOW REQUEST - if err := f.db.DeleteWhere(ctx, []db.Where{{Key: "uri", Value: gtsFollow.URI}}, >smodel.FollowRequest{}); err != nil { + if err := f.state.DB.DeleteWhere(ctx, []db.Where{{Key: "uri", Value: gtsFollow.URI}}, >smodel.FollowRequest{}); err != nil { return fmt.Errorf("UNDO: db error removing follow request: %s", err) } l.Debug("follow undone") @@ -114,7 +114,7 @@ func (f *federatingDB) Undo(ctx context.Context, undo vocab.ActivityStreamsUndo) return errors.New("UNDO: block object account and inbox account were not the same") } // delete any existing BLOCK - if err := f.db.DeleteBlockByURI(ctx, gtsBlock.URI); err != nil { + if err := f.state.DB.DeleteBlockByURI(ctx, gtsBlock.URI); err != nil { return fmt.Errorf("UNDO: db error removing block: %s", err) } l.Debug("block undone") diff --git a/internal/federation/federatingdb/update.go b/internal/federation/federatingdb/update.go index 570729a31..bed5de4db 100644 --- a/internal/federation/federatingdb/update.go +++ b/internal/federation/federatingdb/update.go @@ -138,7 +138,7 @@ func (f *federatingDB) Update(ctx context.Context, asType vocab.Type) error { // pass to the processor for further updating of eg., avatar/header, emojis // the actual db insert/update will take place a bit later - f.fedWorker.Queue(messages.FromFederator{ + f.state.Workers.EnqueueFederator(ctx, messages.FromFederator{ APObjectType: ap.ObjectProfile, APActivityType: ap.ActivityUpdate, GTSModel: updatedAcct, diff --git a/internal/federation/federatingdb/util.go b/internal/federation/federatingdb/util.go index 64f32d39c..f63eb6dc9 100644 --- a/internal/federation/federatingdb/util.go +++ b/internal/federation/federatingdb/util.go @@ -95,7 +95,7 @@ func (f *federatingDB) NewID(ctx context.Context, t vocab.Type) (idURL *url.URL, // take the IRI of the first actor we can find (there should only be one) if iter.IsIRI() { // if there's an error here, just use the fallback behavior -- we don't need to return an error here - if actorAccount, err := f.db.GetAccountByURI(ctx, iter.GetIRI().String()); err == nil { + if actorAccount, err := f.state.DB.GetAccountByURI(ctx, iter.GetIRI().String()); err == nil { newID, err := id.NewRandomULID() if err != nil { return nil, err @@ -238,7 +238,7 @@ func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gts switch { case uris.IsUserPath(iri): - if acct, err = f.db.GetAccountByURI(ctx, iri.String()); err != nil { + if acct, err = f.state.DB.GetAccountByURI(ctx, iri.String()); err != nil { if err == db.ErrNoEntries { return nil, fmt.Errorf("no actor found that corresponds to uri %s", iri.String()) } @@ -246,7 +246,7 @@ func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gts } return acct, nil case uris.IsInboxPath(iri): - if err = f.db.GetWhere(ctx, []db.Where{{Key: "inbox_uri", Value: iri.String()}}, acct); err != nil { + if err = f.state.DB.GetWhere(ctx, []db.Where{{Key: "inbox_uri", Value: iri.String()}}, acct); err != nil { if err == db.ErrNoEntries { return nil, fmt.Errorf("no actor found that corresponds to inbox %s", iri.String()) } @@ -254,7 +254,7 @@ func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gts } return acct, nil case uris.IsOutboxPath(iri): - if err = f.db.GetWhere(ctx, []db.Where{{Key: "outbox_uri", Value: iri.String()}}, acct); err != nil { + if err = f.state.DB.GetWhere(ctx, []db.Where{{Key: "outbox_uri", Value: iri.String()}}, acct); err != nil { if err == db.ErrNoEntries { return nil, fmt.Errorf("no actor found that corresponds to outbox %s", iri.String()) } @@ -262,7 +262,7 @@ func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gts } return acct, nil case uris.IsFollowersPath(iri): - if err = f.db.GetWhere(ctx, []db.Where{{Key: "followers_uri", Value: iri.String()}}, acct); err != nil { + if err = f.state.DB.GetWhere(ctx, []db.Where{{Key: "followers_uri", Value: iri.String()}}, acct); err != nil { if err == db.ErrNoEntries { return nil, fmt.Errorf("no actor found that corresponds to followers_uri %s", iri.String()) } @@ -270,7 +270,7 @@ func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gts } return acct, nil case uris.IsFollowingPath(iri): - if err = f.db.GetWhere(ctx, []db.Where{{Key: "following_uri", Value: iri.String()}}, acct); err != nil { + if err = f.state.DB.GetWhere(ctx, []db.Where{{Key: "following_uri", Value: iri.String()}}, acct); err != nil { if err == db.ErrNoEntries { return nil, fmt.Errorf("no actor found that corresponds to following_uri %s", iri.String()) } diff --git a/internal/federation/federatingprotocol_test.go b/internal/federation/federatingprotocol_test.go index faa168a71..e66cd78cb 100644 --- a/internal/federation/federatingprotocol_test.go +++ b/internal/federation/federatingprotocol_test.go @@ -28,10 +28,8 @@ "github.com/go-fed/httpsig" "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/ap" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -43,12 +41,10 @@ func (suite *FederatingProtocolTestSuite) TestPostInboxRequestBodyHook1() { // the activity we're gonna use activity := suite.testActivities["dm_for_zork"] - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") - tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) + tc := testrig.NewTestTransportController(&suite.state, httpClient) // setup module being tested - federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) + federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state)) // setup request ctx := context.Background() @@ -74,13 +70,11 @@ func (suite *FederatingProtocolTestSuite) TestPostInboxRequestBodyHook2() { // the activity we're gonna use activity := suite.testActivities["reply_to_turtle_for_zork"] - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") - tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) + tc := testrig.NewTestTransportController(&suite.state, httpClient) // setup module being tested - federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) + federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state)) // setup request ctx := context.Background() @@ -107,13 +101,11 @@ func (suite *FederatingProtocolTestSuite) TestPostInboxRequestBodyHook3() { // the activity we're gonna use activity := suite.testActivities["reply_to_turtle_for_turtle"] - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") - tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) + tc := testrig.NewTestTransportController(&suite.state, httpClient) // setup module being tested - federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) + federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state)) // setup request ctx := context.Background() @@ -142,13 +134,11 @@ func (suite *FederatingProtocolTestSuite) TestAuthenticatePostInbox() { sendingAccount := suite.testAccounts["remote_account_1"] inboxAccount := suite.testAccounts["local_account_1"] - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") - tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) + tc := testrig.NewTestTransportController(&suite.state, httpClient) // now setup module being tested, with the mock transport controller - federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) + federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state)) request := httptest.NewRequest(http.MethodPost, "http://localhost:8080/users/the_mighty_zork/inbox", nil) // we need these headers for the request to be validated @@ -187,13 +177,11 @@ func (suite *FederatingProtocolTestSuite) TestAuthenticatePostGone() { activity := suite.testActivities["delete_https://somewhere.mysterious/users/rest_in_piss#main-key"] inboxAccount := suite.testAccounts["local_account_1"] - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") - tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) + tc := testrig.NewTestTransportController(&suite.state, httpClient) // now setup module being tested, with the mock transport controller - federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) + federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state)) request := httptest.NewRequest(http.MethodPost, "http://localhost:8080/users/the_mighty_zork/inbox", nil) // we need these headers for the request to be validated @@ -231,13 +219,11 @@ func (suite *FederatingProtocolTestSuite) TestAuthenticatePostGoneNoTombstoneYet activity := suite.testActivities["delete_https://somewhere.mysterious/users/rest_in_piss#main-key"] inboxAccount := suite.testAccounts["local_account_1"] - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") - tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) + tc := testrig.NewTestTransportController(&suite.state, httpClient) // now setup module being tested, with the mock transport controller - federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) + federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state)) request := httptest.NewRequest(http.MethodPost, "http://localhost:8080/users/the_mighty_zork/inbox", nil) // we need these headers for the request to be validated @@ -271,10 +257,9 @@ func (suite *FederatingProtocolTestSuite) TestAuthenticatePostGoneNoTombstoneYet } func (suite *FederatingProtocolTestSuite) TestBlocked1() { - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") - tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) - federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) + tc := testrig.NewTestTransportController(&suite.state, httpClient) + federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state)) sendingAccount := suite.testAccounts["remote_account_1"] inboxAccount := suite.testAccounts["local_account_1"] @@ -294,10 +279,9 @@ func (suite *FederatingProtocolTestSuite) TestBlocked1() { } func (suite *FederatingProtocolTestSuite) TestBlocked2() { - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") - tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) - federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) + tc := testrig.NewTestTransportController(&suite.state, httpClient) + federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state)) sendingAccount := suite.testAccounts["remote_account_1"] inboxAccount := suite.testAccounts["local_account_1"] @@ -328,10 +312,9 @@ func (suite *FederatingProtocolTestSuite) TestBlocked2() { } func (suite *FederatingProtocolTestSuite) TestBlocked3() { - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") - tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) - federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) + tc := testrig.NewTestTransportController(&suite.state, httpClient) + federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state)) sendingAccount := suite.testAccounts["remote_account_1"] inboxAccount := suite.testAccounts["local_account_1"] @@ -365,10 +348,9 @@ func (suite *FederatingProtocolTestSuite) TestBlocked3() { } func (suite *FederatingProtocolTestSuite) TestBlocked4() { - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) httpClient := testrig.NewMockHTTPClient(nil, "../../testrig/media") - tc := testrig.NewTestTransportController(httpClient, suite.db, fedWorker) - federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(suite.db, fedWorker), tc, suite.tc, testrig.NewTestMediaManager(suite.db, suite.storage)) + tc := testrig.NewTestTransportController(&suite.state, httpClient) + federator := federation.NewFederator(suite.db, testrig.NewTestFederatingDB(&suite.state), tc, suite.tc, testrig.NewTestMediaManager(&suite.state)) sendingAccount := suite.testAccounts["remote_account_1"] inboxAccount := suite.testAccounts["local_account_1"] diff --git a/internal/federation/federator_test.go b/internal/federation/federator_test.go index da6038ace..8a045aa1f 100644 --- a/internal/federation/federator_test.go +++ b/internal/federation/federator_test.go @@ -23,6 +23,7 @@ "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/testrig" @@ -32,6 +33,7 @@ type FederatorStandardTestSuite struct { suite.Suite db db.DB storage *storage.Driver + state state.State tc typeutils.TypeConverter testAccounts map[string]*gtsmodel.Account testStatuses map[string]*gtsmodel.Status @@ -42,8 +44,9 @@ type FederatorStandardTestSuite struct { // SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout func (suite *FederatorStandardTestSuite) SetupSuite() { // setup standard items + testrig.StartWorkers(&suite.state) suite.storage = testrig.NewInMemoryStorage() - suite.tc = testrig.NewTestTypeConverter(suite.db) + suite.state.Storage = suite.storage suite.testAccounts = testrig.NewTestAccounts() suite.testStatuses = testrig.NewTestStatuses() suite.testTombstones = testrig.NewTestTombstones() @@ -52,7 +55,10 @@ func (suite *FederatorStandardTestSuite) SetupSuite() { func (suite *FederatorStandardTestSuite) SetupTest() { testrig.InitTestConfig() testrig.InitTestLog() - suite.db = testrig.NewTestDB() + suite.state.Caches.Init() + suite.db = testrig.NewTestDB(&suite.state) + suite.tc = testrig.NewTestTypeConverter(suite.db) + suite.state.DB = suite.db suite.testActivities = testrig.NewTestActivities(suite.testAccounts) testrig.StandardDBSetup(suite.db, suite.testAccounts) } diff --git a/internal/media/media_test.go b/internal/media/media_test.go index d9f01c1ff..393126ac7 100644 --- a/internal/media/media_test.go +++ b/internal/media/media_test.go @@ -20,11 +20,10 @@ import ( "github.com/stretchr/testify/suite" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" gtsmodel "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/testrig" @@ -35,6 +34,7 @@ type MediaStandardTestSuite struct { db db.DB storage *storage.Driver + state state.State manager media.Manager transportController transport.Controller testAttachments map[string]*gtsmodel.MediaAttachment @@ -46,21 +46,27 @@ func (suite *MediaStandardTestSuite) SetupSuite() { testrig.InitTestConfig() testrig.InitTestLog() - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&suite.state) suite.storage = testrig.NewInMemoryStorage() + suite.state.DB = suite.db + suite.state.Storage = suite.storage } func (suite *MediaStandardTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.StartWorkers(&suite.state) + testrig.StandardStorageSetup(suite.storage, "../../testrig/media") testrig.StandardDBSetup(suite.db, nil) suite.testAttachments = testrig.NewTestAttachments() suite.testAccounts = testrig.NewTestAccounts() suite.testEmojis = testrig.NewTestEmojis() - suite.manager = testrig.NewTestMediaManager(suite.db, suite.storage) - suite.transportController = testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../testrig/media"), suite.db, concurrency.NewWorkerPool[messages.FromFederator](0, 0)) + suite.manager = testrig.NewTestMediaManager(&suite.state) + suite.transportController = testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../testrig/media")) } func (suite *MediaStandardTestSuite) TearDownTest() { testrig.StandardDBTeardown(suite.db) testrig.StandardStorageTeardown(suite.storage) + testrig.StopWorkers(&suite.state) } diff --git a/internal/oauth/clientstore_test.go b/internal/oauth/clientstore_test.go index 92c117bb3..a243383da 100644 --- a/internal/oauth/clientstore_test.go +++ b/internal/oauth/clientstore_test.go @@ -25,6 +25,7 @@ "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/testrig" "github.com/superseriousbusiness/oauth2/v4/models" ) @@ -32,6 +33,7 @@ type PgClientStoreTestSuite struct { suite.Suite db db.DB + state state.State testClientID string testClientSecret string testClientDomain string @@ -48,9 +50,11 @@ func (suite *PgClientStoreTestSuite) SetupSuite() { // SetupTest creates a postgres connection and creates the oauth_clients table before each test func (suite *PgClientStoreTestSuite) SetupTest() { + suite.state.Caches.Init() testrig.InitTestLog() testrig.InitTestConfig() - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&suite.state) + suite.state.DB = suite.db testrig.StandardDBSetup(suite.db, nil) } diff --git a/internal/processing/account/account.go b/internal/processing/account/account.go index 41315d483..62330c0dc 100644 --- a/internal/processing/account/account.go +++ b/internal/processing/account/account.go @@ -19,13 +19,11 @@ package account import ( - "github.com/superseriousbusiness/gotosocial/internal/concurrency" - "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/text" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/visibility" @@ -35,35 +33,32 @@ // // It also contains logic for actions towards accounts such as following, blocking, seeing follows, etc. type Processor struct { + state *state.State tc typeutils.TypeConverter mediaManager media.Manager - clientWorker *concurrency.WorkerPool[messages.FromClientAPI] oauthServer oauth.Server filter visibility.Filter formatter text.Formatter - db db.DB federator federation.Federator parseMention gtsmodel.ParseMentionFunc } // New returns a new account processor. func New( - db db.DB, + state *state.State, tc typeutils.TypeConverter, mediaManager media.Manager, oauthServer oauth.Server, - clientWorker *concurrency.WorkerPool[messages.FromClientAPI], federator federation.Federator, parseMention gtsmodel.ParseMentionFunc, ) Processor { return Processor{ + state: state, tc: tc, mediaManager: mediaManager, - clientWorker: clientWorker, oauthServer: oauthServer, - filter: visibility.NewFilter(db), - formatter: text.NewFormatter(db), - db: db, + filter: visibility.NewFilter(state.DB), + formatter: text.NewFormatter(state.DB), federator: federator, parseMention: parseMention, } diff --git a/internal/processing/account/account_test.go b/internal/processing/account/account_test.go index 2e7cdb994..7a2e5aa8d 100644 --- a/internal/processing/account/account_test.go +++ b/internal/processing/account/account_test.go @@ -22,7 +22,6 @@ "context" "github.com/stretchr/testify/suite" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" @@ -32,6 +31,7 @@ "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing/account" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/internal/typeutils" @@ -44,6 +44,7 @@ type AccountStandardTestSuite struct { db db.DB tc typeutils.TypeConverter storage *storage.Driver + state state.State mediaManager media.Manager oauthServer oauth.Server fromClientAPIChan chan messages.FromClientAPI @@ -76,30 +77,30 @@ func (suite *AccountStandardTestSuite) SetupSuite() { } func (suite *AccountStandardTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.StartWorkers(&suite.state) + testrig.InitTestLog() testrig.InitTestConfig() - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - clientWorker.SetProcessor(func(_ context.Context, msg messages.FromClientAPI) error { - suite.fromClientAPIChan <- msg - return nil - }) - - _ = fedWorker.Start() - _ = clientWorker.Start() - - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&suite.state) + suite.state.DB = suite.db suite.tc = testrig.NewTestTypeConverter(suite.db) suite.storage = testrig.NewInMemoryStorage() - suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) + suite.state.Storage = suite.storage + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) suite.oauthServer = testrig.NewTestOauthServer(suite.db) + suite.fromClientAPIChan = make(chan messages.FromClientAPI, 100) - suite.transportController = testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, fedWorker) - suite.federator = testrig.NewTestFederator(suite.db, suite.transportController, suite.storage, suite.mediaManager, fedWorker) + suite.state.Workers.EnqueueClientAPI = func(ctx context.Context, msg messages.FromClientAPI) { + suite.fromClientAPIChan <- msg + } + + suite.transportController = testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media")) + suite.federator = testrig.NewTestFederator(&suite.state, suite.transportController, suite.mediaManager) suite.sentEmails = make(map[string]string) suite.emailSender = testrig.NewEmailSender("../../../web/template/", suite.sentEmails) - suite.accountProcessor = account.New(suite.db, suite.tc, suite.mediaManager, suite.oauthServer, clientWorker, suite.federator, processing.GetParseMentionFunc(suite.db, suite.federator)) + suite.accountProcessor = account.New(&suite.state, suite.tc, suite.mediaManager, suite.oauthServer, suite.federator, processing.GetParseMentionFunc(suite.db, suite.federator)) testrig.StandardDBSetup(suite.db, nil) testrig.StandardStorageSetup(suite.storage, "../../../testrig/media") } @@ -107,4 +108,5 @@ func (suite *AccountStandardTestSuite) SetupTest() { func (suite *AccountStandardTestSuite) TearDownTest() { testrig.StandardDBTeardown(suite.db) testrig.StandardStorageTeardown(suite.storage) + testrig.StopWorkers(&suite.state) } diff --git a/internal/processing/account/block.go b/internal/processing/account/block.go index 99effd3a3..edec106b1 100644 --- a/internal/processing/account/block.go +++ b/internal/processing/account/block.go @@ -36,13 +36,13 @@ // BlockCreate handles the creation of a block from requestingAccount to targetAccountID, either remote or local. func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { // make sure the target account actually exists in our db - targetAccount, err := p.db.GetAccountByID(ctx, targetAccountID) + targetAccount, err := p.state.DB.GetAccountByID(ctx, targetAccountID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: error getting account %s from the db: %s", targetAccountID, err)) } // if requestingAccount already blocks target account, we don't need to do anything - if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, false); err != nil { + if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, false); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error checking existence of block: %s", err)) } else if blocked { return p.RelationshipGet(ctx, requestingAccount, targetAccountID) @@ -64,18 +64,18 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel block.URI = uris.GenerateURIForBlock(requestingAccount.Username, newBlockID) // whack it in the database - if err := p.db.PutBlock(ctx, block); err != nil { + if err := p.state.DB.PutBlock(ctx, block); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error creating block in db: %s", err)) } // clear any follows or follow requests from the blocked account to the target account -- this is a simple delete - if err := p.db.DeleteWhere(ctx, []db.Where{ + if err := p.state.DB.DeleteWhere(ctx, []db.Where{ {Key: "account_id", Value: targetAccountID}, {Key: "target_account_id", Value: requestingAccount.ID}, }, >smodel.Follow{}); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow in db: %s", err)) } - if err := p.db.DeleteWhere(ctx, []db.Where{ + if err := p.state.DB.DeleteWhere(ctx, []db.Where{ {Key: "account_id", Value: targetAccountID}, {Key: "target_account_id", Value: requestingAccount.ID}, }, >smodel.FollowRequest{}); err != nil { @@ -89,12 +89,12 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel var frChanged bool var frURI string fr := >smodel.FollowRequest{} - if err := p.db.GetWhere(ctx, []db.Where{ + if err := p.state.DB.GetWhere(ctx, []db.Where{ {Key: "account_id", Value: requestingAccount.ID}, {Key: "target_account_id", Value: targetAccountID}, }, fr); err == nil { frURI = fr.URI - if err := p.db.DeleteByID(ctx, fr.ID, fr); err != nil { + if err := p.state.DB.DeleteByID(ctx, fr.ID, fr); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow request from db: %s", err)) } frChanged = true @@ -104,12 +104,12 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel var fChanged bool var fURI string f := >smodel.Follow{} - if err := p.db.GetWhere(ctx, []db.Where{ + if err := p.state.DB.GetWhere(ctx, []db.Where{ {Key: "account_id", Value: requestingAccount.ID}, {Key: "target_account_id", Value: targetAccountID}, }, f); err == nil { fURI = f.URI - if err := p.db.DeleteByID(ctx, f.ID, f); err != nil { + if err := p.state.DB.DeleteByID(ctx, f.ID, f); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error removing follow from db: %s", err)) } fChanged = true @@ -117,7 +117,7 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel // follow request status changed so send the UNDO activity to the channel for async processing if frChanged { - p.clientWorker.Queue(messages.FromClientAPI{ + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ActivityFollow, APActivityType: ap.ActivityUndo, GTSModel: >smodel.Follow{ @@ -132,7 +132,7 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel // follow status changed so send the UNDO activity to the channel for async processing if fChanged { - p.clientWorker.Queue(messages.FromClientAPI{ + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ActivityFollow, APActivityType: ap.ActivityUndo, GTSModel: >smodel.Follow{ @@ -146,7 +146,7 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel } // handle the rest of the block process asynchronously - p.clientWorker.Queue(messages.FromClientAPI{ + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ActivityBlock, APActivityType: ap.ActivityCreate, GTSModel: block, @@ -160,23 +160,23 @@ func (p *Processor) BlockCreate(ctx context.Context, requestingAccount *gtsmodel // BlockRemove handles the removal of a block from requestingAccount to targetAccountID, either remote or local. func (p *Processor) BlockRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { // make sure the target account actually exists in our db - targetAccount, err := p.db.GetAccountByID(ctx, targetAccountID) + targetAccount, err := p.state.DB.GetAccountByID(ctx, targetAccountID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: error getting account %s from the db: %s", targetAccountID, err)) } // check if a block exists, and remove it if it does - block, err := p.db.GetBlock(ctx, requestingAccount.ID, targetAccountID) + block, err := p.state.DB.GetBlock(ctx, requestingAccount.ID, targetAccountID) if err == nil { // we got a block, remove it block.Account = requestingAccount block.TargetAccount = targetAccount - if err := p.db.DeleteBlockByID(ctx, block.ID); err != nil { + if err := p.state.DB.DeleteBlockByID(ctx, block.ID); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockRemove: error removing block from db: %s", err)) } // send the UNDO activity to the client worker for async processing - p.clientWorker.Queue(messages.FromClientAPI{ + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ActivityBlock, APActivityType: ap.ActivityUndo, GTSModel: block, diff --git a/internal/processing/account/bookmarks.go b/internal/processing/account/bookmarks.go index 28688c20d..cf53e63bb 100644 --- a/internal/processing/account/bookmarks.go +++ b/internal/processing/account/bookmarks.go @@ -34,7 +34,7 @@ // BookmarksGet returns a pageable response of statuses that are bookmarked by requestingAccount. // Paging for this response is done based on bookmark ID rather than status ID. func (p *Processor) BookmarksGet(ctx context.Context, requestingAccount *gtsmodel.Account, limit int, maxID string, minID string) (*apimodel.PageableResponse, gtserror.WithCode) { - bookmarks, err := p.db.GetBookmarks(ctx, requestingAccount.ID, limit, maxID, minID) + bookmarks, err := p.state.DB.GetBookmarks(ctx, requestingAccount.ID, limit, maxID, minID) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -47,7 +47,7 @@ func (p *Processor) BookmarksGet(ctx context.Context, requestingAccount *gtsmode ) for _, bookmark := range bookmarks { - status, err := p.db.GetStatusByID(ctx, bookmark.StatusID) + status, err := p.state.DB.GetStatusByID(ctx, bookmark.StatusID) if err != nil { if errors.Is(err, db.ErrNoEntries) { // We just don't have the status for some reason. diff --git a/internal/processing/account/create.go b/internal/processing/account/create.go index 8b82bc681..9c9cfb57f 100644 --- a/internal/processing/account/create.go +++ b/internal/processing/account/create.go @@ -35,7 +35,7 @@ // Create processes the given form for creating a new account, returning an oauth token for that account if successful. func (p *Processor) Create(ctx context.Context, applicationToken oauth2.TokenInfo, application *gtsmodel.Application, form *apimodel.AccountCreateRequest) (*apimodel.Token, gtserror.WithCode) { - emailAvailable, err := p.db.IsEmailAvailable(ctx, form.Email) + emailAvailable, err := p.state.DB.IsEmailAvailable(ctx, form.Email) if err != nil { return nil, gtserror.NewErrorBadRequest(err) } @@ -43,7 +43,7 @@ func (p *Processor) Create(ctx context.Context, applicationToken oauth2.TokenInf return nil, gtserror.NewErrorConflict(fmt.Errorf("email address %s is not available", form.Email)) } - usernameAvailable, err := p.db.IsUsernameAvailable(ctx, form.Username) + usernameAvailable, err := p.state.DB.IsUsernameAvailable(ctx, form.Username) if err != nil { return nil, gtserror.NewErrorBadRequest(err) } @@ -61,7 +61,7 @@ func (p *Processor) Create(ctx context.Context, applicationToken oauth2.TokenInf } log.Trace(ctx, "creating new username and account") - user, err := p.db.NewSignup(ctx, form.Username, text.SanitizePlaintext(reason), approvalRequired, form.Email, form.Password, form.IP, form.Locale, application.ID, false, "", false) + user, err := p.state.DB.NewSignup(ctx, form.Username, text.SanitizePlaintext(reason), approvalRequired, form.Email, form.Password, form.IP, form.Locale, application.ID, false, "", false) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error creating new signup in the database: %s", err)) } @@ -73,7 +73,7 @@ func (p *Processor) Create(ctx context.Context, applicationToken oauth2.TokenInf } if user.Account == nil { - a, err := p.db.GetAccountByID(ctx, user.AccountID) + a, err := p.state.DB.GetAccountByID(ctx, user.AccountID) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error getting new account from the database: %s", err)) } @@ -82,7 +82,7 @@ func (p *Processor) Create(ctx context.Context, applicationToken oauth2.TokenInf // there are side effects for creating a new account (sending confirmation emails etc) // so pass a message to the processor so that it can do it asynchronously - p.clientWorker.Queue(messages.FromClientAPI{ + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ObjectProfile, APActivityType: ap.ActivityCreate, GTSModel: user.Account, diff --git a/internal/processing/account/delete.go b/internal/processing/account/delete.go index 58a967337..eea4a621e 100644 --- a/internal/processing/account/delete.go +++ b/internal/processing/account/delete.go @@ -54,22 +54,22 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi if account.Domain == "" { // see if we can get a user for this account var err error - if user, err = p.db.GetUserByAccountID(ctx, account.ID); err == nil { + if user, err = p.state.DB.GetUserByAccountID(ctx, account.ID); err == nil { // we got one! select all tokens with the user's ID tokens := []*gtsmodel.Token{} - if err := p.db.GetWhere(ctx, []db.Where{{Key: "user_id", Value: user.ID}}, &tokens); err == nil { + if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "user_id", Value: user.ID}}, &tokens); err == nil { // we have some tokens to delete for _, t := range tokens { // delete client(s) associated with this token - if err := p.db.DeleteByID(ctx, t.ClientID, >smodel.Client{}); err != nil { + if err := p.state.DB.DeleteByID(ctx, t.ClientID, >smodel.Client{}); err != nil { l.Errorf("error deleting oauth client: %s", err) } // delete application(s) associated with this token - if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "client_id", Value: t.ClientID}}, >smodel.Application{}); err != nil { + if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "client_id", Value: t.ClientID}}, >smodel.Application{}); err != nil { l.Errorf("error deleting application: %s", err) } // delete the token itself - if err := p.db.DeleteByID(ctx, t.ID, t); err != nil { + if err := p.state.DB.DeleteByID(ctx, t.ID, t); err != nil { l.Errorf("error deleting oauth token: %s", err) } } @@ -80,12 +80,12 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi // 2. Delete account's blocks l.Trace("deleting account blocks") // first delete any blocks that this account created - if err := p.db.DeleteBlocksByOriginAccountID(ctx, account.ID); err != nil { + if err := p.state.DB.DeleteBlocksByOriginAccountID(ctx, account.ID); err != nil { l.Errorf("error deleting blocks created by account: %s", err) } // now delete any blocks that target this account - if err := p.db.DeleteBlocksByTargetAccountID(ctx, account.ID); err != nil { + if err := p.state.DB.DeleteBlocksByTargetAccountID(ctx, account.ID); err != nil { l.Errorf("error deleting blocks targeting account: %s", err) } @@ -96,12 +96,12 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi // TODO: federate these if necessary l.Trace("deleting account follow requests") // first delete any follow requests that this account created - if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil { + if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil { l.Errorf("error deleting follow requests created by account: %s", err) } // now delete any follow requests that target this account - if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil { + if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.FollowRequest{}); err != nil { l.Errorf("error deleting follow requests targeting account: %s", err) } @@ -109,12 +109,12 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi // TODO: federate these if necessary l.Trace("deleting account follows") // first delete any follows that this account created - if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil { + if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil { l.Errorf("error deleting follows created by account: %s", err) } // now delete any follows that target this account - if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil { + if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Follow{}); err != nil { l.Errorf("error deleting follows targeting account: %s", err) } @@ -129,7 +129,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi for { // Fetch next block of account statuses from database - statuses, err := p.db.GetAccountStatuses(ctx, account.ID, 20, false, false, maxID, "", false, false) + statuses, err := p.state.DB.GetAccountStatuses(ctx, account.ID, 20, false, false, maxID, "", false, false) if err != nil { if !errors.Is(err, db.ErrNoEntries) { // an actual error has occurred @@ -149,7 +149,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi l.Tracef("queue client API status delete: %s", status.ID) // pass the status delete through the client api channel for processing - p.clientWorker.Queue(messages.FromClientAPI{ + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ObjectNote, APActivityType: ap.ActivityDelete, GTSModel: status, @@ -158,7 +158,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi }) // Look for any boosts of this status in DB - boosts, err := p.db.GetStatusReblogs(ctx, status) + boosts, err := p.state.DB.GetStatusReblogs(ctx, status) if err != nil && !errors.Is(err, db.ErrNoEntries) { l.Errorf("error fetching status reblogs for %q: %v", status.ID, err) continue @@ -167,7 +167,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi for _, boost := range boosts { if boost.Account == nil { // Fetch the relevant account for this status boost - boostAcc, err := p.db.GetAccountByID(ctx, boost.AccountID) + boostAcc, err := p.state.DB.GetAccountByID(ctx, boost.AccountID) if err != nil { l.Errorf("error fetching boosted status account for %q: %v", boost.AccountID, err) continue @@ -180,7 +180,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi l.Tracef("queue client API boost delete: %s", status.ID) // pass the boost delete through the client api channel for processing - p.clientWorker.Queue(messages.FromClientAPI{ + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ActivityAnnounce, APActivityType: ap.ActivityUndo, GTSModel: status, @@ -197,31 +197,31 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi // 10. Delete account's notifications l.Trace("deleting account notifications") // first notifications created by account - if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "origin_account_id", Value: account.ID}}, &[]*gtsmodel.Notification{}); err != nil { + if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "origin_account_id", Value: account.ID}}, &[]*gtsmodel.Notification{}); err != nil { l.Errorf("error deleting notifications created by account: %s", err) } // now notifications targeting account - if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Notification{}); err != nil { + if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "target_account_id", Value: account.ID}}, &[]*gtsmodel.Notification{}); err != nil { l.Errorf("error deleting notifications targeting account: %s", err) } // 11. Delete account's bookmarks l.Trace("deleting account bookmarks") - if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusBookmark{}); err != nil { + if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusBookmark{}); err != nil { l.Errorf("error deleting bookmarks created by account: %s", err) } // 12. Delete account's faves // TODO: federate these if necessary l.Trace("deleting account faves") - if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusFave{}); err != nil { + if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusFave{}); err != nil { l.Errorf("error deleting faves created by account: %s", err) } // 13. Delete account's mutes l.Trace("deleting account mutes") - if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusMute{}); err != nil { + if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &[]*gtsmodel.StatusMute{}); err != nil { l.Errorf("error deleting status mutes created by account: %s", err) } @@ -234,7 +234,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi // 16. Delete account's user if user != nil { l.Trace("deleting account user") - if err := p.db.DeleteUserByID(ctx, user.ID); err != nil { + if err := p.state.DB.DeleteUserByID(ctx, user.ID); err != nil { return gtserror.NewErrorInternalError(err) } } @@ -261,7 +261,7 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi account.Discoverable = &discoverable account.SuspendedAt = time.Now() account.SuspensionOrigin = origin - err := p.db.UpdateAccount(ctx, account) + err := p.state.DB.UpdateAccount(ctx, account) if err != nil { return gtserror.NewErrorInternalError(err) } @@ -281,7 +281,7 @@ func (p *Processor) DeleteLocal(ctx context.Context, account *gtsmodel.Account, if form.DeleteOriginID == account.ID { // the account owner themself has requested deletion via the API, get their user from the db - user, err := p.db.GetUserByAccountID(ctx, account.ID) + user, err := p.state.DB.GetUserByAccountID(ctx, account.ID) if err != nil { return gtserror.NewErrorInternalError(err) } @@ -301,7 +301,7 @@ func (p *Processor) DeleteLocal(ctx context.Context, account *gtsmodel.Account, } else { // the delete has been requested by some other account, grab it; // if we've reached this point we know it has permission already - requestingAccount, err := p.db.GetAccountByID(ctx, form.DeleteOriginID) + requestingAccount, err := p.state.DB.GetAccountByID(ctx, form.DeleteOriginID) if err != nil { return gtserror.NewErrorInternalError(err) } @@ -310,7 +310,7 @@ func (p *Processor) DeleteLocal(ctx context.Context, account *gtsmodel.Account, } // put the delete in the processor queue to handle the rest of it asynchronously - p.clientWorker.Queue(fromClientAPIMessage) + p.state.Workers.EnqueueClientAPI(ctx, fromClientAPIMessage) return nil } diff --git a/internal/processing/account/follow.go b/internal/processing/account/follow.go index d4d479be7..ac65c39f2 100644 --- a/internal/processing/account/follow.go +++ b/internal/processing/account/follow.go @@ -35,14 +35,14 @@ // FollowCreate handles a follow request to an account, either remote or local. func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) { // if there's a block between the accounts we shouldn't create the request ofc - if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, form.ID, true); err != nil { + if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, form.ID, true); err != nil { return nil, gtserror.NewErrorInternalError(err) } else if blocked { return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts")) } // make sure the target account actually exists in our db - targetAcct, err := p.db.GetAccountByID(ctx, form.ID) + targetAcct, err := p.state.DB.GetAccountByID(ctx, form.ID) if err != nil { if err == db.ErrNoEntries { return nil, gtserror.NewErrorNotFound(fmt.Errorf("accountfollowcreate: account %s not found in the db: %s", form.ID, err)) @@ -51,7 +51,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode } // check if a follow exists already - if follows, err := p.db.IsFollowing(ctx, requestingAccount, targetAcct); err != nil { + if follows, err := p.state.DB.IsFollowing(ctx, requestingAccount, targetAcct); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error checking follow in db: %s", err)) } else if follows { // already follows so just return the relationship @@ -59,7 +59,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode } // check if a follow request exists already - if followRequested, err := p.db.IsFollowRequested(ctx, requestingAccount, targetAcct); err != nil { + if followRequested, err := p.state.DB.IsFollowRequested(ctx, requestingAccount, targetAcct); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error checking follow request in db: %s", err)) } else if followRequested { // already follow requested so just return the relationship @@ -95,13 +95,13 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode } // whack it in the database - if err := p.db.Put(ctx, fr); err != nil { + if err := p.state.DB.Put(ctx, fr); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error creating follow request in db: %s", err)) } // if it's a local account that's not locked we can just straight up accept the follow request if !*targetAcct.Locked && targetAcct.Domain == "" { - if _, err := p.db.AcceptFollowRequest(ctx, requestingAccount.ID, form.ID); err != nil { + if _, err := p.state.DB.AcceptFollowRequest(ctx, requestingAccount.ID, form.ID); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error accepting folow request for local unlocked account: %s", err)) } // return the new relationship @@ -109,7 +109,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode } // otherwise we leave the follow request as it is and we handle the rest of the process asynchronously - p.clientWorker.Queue(messages.FromClientAPI{ + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ActivityFollow, APActivityType: ap.ActivityCreate, GTSModel: fr, @@ -124,7 +124,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode // FollowRemove handles the removal of a follow/follow request to an account, either remote or local. func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) { // if there's a block between the accounts we shouldn't do anything - blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true) + blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -133,7 +133,7 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode } // make sure the target account actually exists in our db - targetAcct, err := p.db.GetAccountByID(ctx, targetAccountID) + targetAcct, err := p.state.DB.GetAccountByID(ctx, targetAccountID) if err != nil { if err == db.ErrNoEntries { return nil, gtserror.NewErrorNotFound(fmt.Errorf("AccountFollowRemove: account %s not found in the db: %s", targetAccountID, err)) @@ -144,12 +144,12 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode var frChanged bool var frURI string fr := >smodel.FollowRequest{} - if err := p.db.GetWhere(ctx, []db.Where{ + if err := p.state.DB.GetWhere(ctx, []db.Where{ {Key: "account_id", Value: requestingAccount.ID}, {Key: "target_account_id", Value: targetAccountID}, }, fr); err == nil { frURI = fr.URI - if err := p.db.DeleteByID(ctx, fr.ID, fr); err != nil { + if err := p.state.DB.DeleteByID(ctx, fr.ID, fr); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("AccountFollowRemove: error removing follow request from db: %s", err)) } frChanged = true @@ -159,12 +159,12 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode var fChanged bool var fURI string f := >smodel.Follow{} - if err := p.db.GetWhere(ctx, []db.Where{ + if err := p.state.DB.GetWhere(ctx, []db.Where{ {Key: "account_id", Value: requestingAccount.ID}, {Key: "target_account_id", Value: targetAccountID}, }, f); err == nil { fURI = f.URI - if err := p.db.DeleteByID(ctx, f.ID, f); err != nil { + if err := p.state.DB.DeleteByID(ctx, f.ID, f); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("AccountFollowRemove: error removing follow from db: %s", err)) } fChanged = true @@ -172,7 +172,7 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode // follow request status changed so send the UNDO activity to the channel for async processing if frChanged { - p.clientWorker.Queue(messages.FromClientAPI{ + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ActivityFollow, APActivityType: ap.ActivityUndo, GTSModel: >smodel.Follow{ @@ -187,7 +187,7 @@ func (p *Processor) FollowRemove(ctx context.Context, requestingAccount *gtsmode // follow status changed so send the UNDO activity to the channel for async processing if fChanged { - p.clientWorker.Queue(messages.FromClientAPI{ + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ActivityFollow, APActivityType: ap.ActivityUndo, GTSModel: >smodel.Follow{ diff --git a/internal/processing/account/get.go b/internal/processing/account/get.go index 11de1ddac..2c650254f 100644 --- a/internal/processing/account/get.go +++ b/internal/processing/account/get.go @@ -33,7 +33,7 @@ // Get processes the given request for account information. func (p *Processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, gtserror.WithCode) { - targetAccount, err := p.db.GetAccountByID(ctx, targetAccountID) + targetAccount, err := p.state.DB.GetAccountByID(ctx, targetAccountID) if err != nil { if err == db.ErrNoEntries { return nil, gtserror.NewErrorNotFound(errors.New("account not found")) @@ -46,7 +46,7 @@ func (p *Processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account // GetLocalByUsername processes the given request for account information targeting a local account by username. func (p *Processor) GetLocalByUsername(ctx context.Context, requestingAccount *gtsmodel.Account, username string) (*apimodel.Account, gtserror.WithCode) { - targetAccount, err := p.db.GetAccountByUsernameDomain(ctx, username, "") + targetAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, username, "") if err != nil { if err == db.ErrNoEntries { return nil, gtserror.NewErrorNotFound(errors.New("account not found")) @@ -59,7 +59,7 @@ func (p *Processor) GetLocalByUsername(ctx context.Context, requestingAccount *g // GetCustomCSSForUsername returns custom css for the given local username. func (p *Processor) GetCustomCSSForUsername(ctx context.Context, username string) (string, gtserror.WithCode) { - customCSS, err := p.db.GetAccountCustomCSSByUsername(ctx, username) + customCSS, err := p.state.DB.GetAccountCustomCSSByUsername(ctx, username) if err != nil { if err == db.ErrNoEntries { return "", gtserror.NewErrorNotFound(errors.New("account not found")) @@ -74,7 +74,7 @@ func (p *Processor) getFor(ctx context.Context, requestingAccount *gtsmodel.Acco var blocked bool var err error if requestingAccount != nil { - blocked, err = p.db.IsBlocked(ctx, requestingAccount.ID, targetAccount.ID, true) + blocked, err = p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccount.ID, true) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking account block: %s", err)) } diff --git a/internal/processing/account/relationships.go b/internal/processing/account/relationships.go index cb2789829..f60216f95 100644 --- a/internal/processing/account/relationships.go +++ b/internal/processing/account/relationships.go @@ -31,14 +31,14 @@ // FollowersGet fetches a list of the target account's followers. func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { - if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil { + if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil { return nil, gtserror.NewErrorInternalError(err) } else if blocked { return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts")) } accounts := []apimodel.Account{} - follows, err := p.db.GetAccountFollowedBy(ctx, targetAccountID, false) + follows, err := p.state.DB.GetAccountFollowedBy(ctx, targetAccountID, false) if err != nil { if err == db.ErrNoEntries { return accounts, nil @@ -47,7 +47,7 @@ func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmode } for _, f := range follows { - blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, f.AccountID, true) + blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, f.AccountID, true) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -56,7 +56,7 @@ func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmode } if f.Account == nil { - a, err := p.db.GetAccountByID(ctx, f.AccountID) + a, err := p.state.DB.GetAccountByID(ctx, f.AccountID) if err != nil { if err == db.ErrNoEntries { continue @@ -77,14 +77,14 @@ func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmode // FollowingGet fetches a list of the accounts that target account is following. func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) { - if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil { + if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil { return nil, gtserror.NewErrorInternalError(err) } else if blocked { return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts")) } accounts := []apimodel.Account{} - follows, err := p.db.GetAccountFollows(ctx, targetAccountID) + follows, err := p.state.DB.GetAccountFollows(ctx, targetAccountID) if err != nil { if err == db.ErrNoEntries { return accounts, nil @@ -93,7 +93,7 @@ func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmode } for _, f := range follows { - blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, f.AccountID, true) + blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, f.AccountID, true) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -102,7 +102,7 @@ func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmode } if f.TargetAccount == nil { - a, err := p.db.GetAccountByID(ctx, f.TargetAccountID) + a, err := p.state.DB.GetAccountByID(ctx, f.TargetAccountID) if err != nil { if err == db.ErrNoEntries { continue @@ -127,7 +127,7 @@ func (p *Processor) RelationshipGet(ctx context.Context, requestingAccount *gtsm return nil, gtserror.NewErrorForbidden(errors.New("not authed")) } - gtsR, err := p.db.GetRelationship(ctx, requestingAccount.ID, targetAccountID) + gtsR, err := p.state.DB.GetRelationship(ctx, requestingAccount.ID, targetAccountID) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error getting relationship: %s", err)) } diff --git a/internal/processing/account/rss.go b/internal/processing/account/rss.go index 22065cf8e..61fcc1c51 100644 --- a/internal/processing/account/rss.go +++ b/internal/processing/account/rss.go @@ -34,7 +34,7 @@ // GetRSSFeedForUsername returns RSS feed for the given local username. func (p *Processor) GetRSSFeedForUsername(ctx context.Context, username string) (func() (string, gtserror.WithCode), time.Time, gtserror.WithCode) { - account, err := p.db.GetAccountByUsernameDomain(ctx, username, "") + account, err := p.state.DB.GetAccountByUsernameDomain(ctx, username, "") if err != nil { if err == db.ErrNoEntries { return nil, time.Time{}, gtserror.NewErrorNotFound(errors.New("GetRSSFeedForUsername: account not found")) @@ -46,13 +46,13 @@ func (p *Processor) GetRSSFeedForUsername(ctx context.Context, username string) return nil, time.Time{}, gtserror.NewErrorNotFound(errors.New("GetRSSFeedForUsername: account RSS feed not enabled")) } - lastModified, err := p.db.GetAccountLastPosted(ctx, account.ID, true) + lastModified, err := p.state.DB.GetAccountLastPosted(ctx, account.ID, true) if err != nil { return nil, time.Time{}, gtserror.NewErrorInternalError(fmt.Errorf("GetRSSFeedForUsername: db error: %s", err)) } return func() (string, gtserror.WithCode) { - statuses, err := p.db.GetAccountWebStatuses(ctx, account.ID, rssFeedLength, "") + statuses, err := p.state.DB.GetAccountWebStatuses(ctx, account.ID, rssFeedLength, "") if err != nil && err != db.ErrNoEntries { return "", gtserror.NewErrorInternalError(fmt.Errorf("GetRSSFeedForUsername: db error: %s", err)) } @@ -65,7 +65,7 @@ func (p *Processor) GetRSSFeedForUsername(ctx context.Context, username string) var image *feeds.Image if account.AvatarMediaAttachmentID != "" { if account.AvatarMediaAttachment == nil { - avatar, err := p.db.GetAttachmentByID(ctx, account.AvatarMediaAttachmentID) + avatar, err := p.state.DB.GetAttachmentByID(ctx, account.AvatarMediaAttachmentID) if err != nil { return "", gtserror.NewErrorInternalError(fmt.Errorf("GetRSSFeedForUsername: db error fetching avatar attachment: %s", err)) } diff --git a/internal/processing/account/statuses.go b/internal/processing/account/statuses.go index 7ff6de2ff..9961dbdbe 100644 --- a/internal/processing/account/statuses.go +++ b/internal/processing/account/statuses.go @@ -33,7 +33,7 @@ // the account given in authed. func (p *Processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, pinned bool, mediaOnly bool, publicOnly bool) (*apimodel.PageableResponse, gtserror.WithCode) { if requestingAccount != nil { - if blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil { + if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil { return nil, gtserror.NewErrorInternalError(err) } else if blocked { return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts")) @@ -46,10 +46,10 @@ func (p *Processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel ) if pinned { // Get *ONLY* pinned statuses. - statuses, err = p.db.GetAccountPinnedStatuses(ctx, targetAccountID) + statuses, err = p.state.DB.GetAccountPinnedStatuses(ctx, targetAccountID) } else { // Get account statuses which *may* include pinned ones. - statuses, err = p.db.GetAccountStatuses(ctx, targetAccountID, limit, excludeReplies, excludeReblogs, maxID, minID, mediaOnly, publicOnly) + statuses, err = p.state.DB.GetAccountStatuses(ctx, targetAccountID, limit, excludeReplies, excludeReblogs, maxID, minID, mediaOnly, publicOnly) } if err != nil { if err == db.ErrNoEntries { @@ -120,7 +120,7 @@ func (p *Processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel // WebStatusesGet fetches a number of statuses (in descending order) from the given account. It selects only // statuses which are suitable for showing on the public web profile of an account. func (p *Processor) WebStatusesGet(ctx context.Context, targetAccountID string, maxID string) (*apimodel.PageableResponse, gtserror.WithCode) { - acct, err := p.db.GetAccountByID(ctx, targetAccountID) + acct, err := p.state.DB.GetAccountByID(ctx, targetAccountID) if err != nil { if err == db.ErrNoEntries { err := fmt.Errorf("account %s not found in the db, not getting web statuses for it", targetAccountID) @@ -134,7 +134,7 @@ func (p *Processor) WebStatusesGet(ctx context.Context, targetAccountID string, return nil, gtserror.NewErrorNotFound(err) } - statuses, err := p.db.GetAccountWebStatuses(ctx, targetAccountID, 10, maxID) + statuses, err := p.state.DB.GetAccountWebStatuses(ctx, targetAccountID, 10, maxID) if err != nil { if err == db.ErrNoEntries { return util.EmptyPageableResponse(), nil diff --git a/internal/processing/account/update.go b/internal/processing/account/update.go index cffbbb0c5..537857cee 100644 --- a/internal/processing/account/update.go +++ b/internal/processing/account/update.go @@ -165,12 +165,12 @@ func (p *Processor) Update(ctx context.Context, account *gtsmodel.Account, form account.EnableRSS = form.EnableRSS } - err := p.db.UpdateAccount(ctx, account) + err := p.state.DB.UpdateAccount(ctx, account) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("could not update account %s: %s", account.ID, err)) } - p.clientWorker.Queue(messages.FromClientAPI{ + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ObjectProfile, APActivityType: ap.ActivityUpdate, GTSModel: account, diff --git a/internal/processing/admin/account.go b/internal/processing/admin/account.go index d23d1fbfe..ba4c5d4eb 100644 --- a/internal/processing/admin/account.go +++ b/internal/processing/admin/account.go @@ -31,7 +31,7 @@ ) func (p *Processor) AccountAction(ctx context.Context, account *gtsmodel.Account, form *apimodel.AdminAccountActionRequest) gtserror.WithCode { - targetAccount, err := p.db.GetAccountByID(ctx, form.TargetAccountID) + targetAccount, err := p.state.DB.GetAccountByID(ctx, form.TargetAccountID) if err != nil { return gtserror.NewErrorInternalError(err) } @@ -47,7 +47,7 @@ func (p *Processor) AccountAction(ctx context.Context, account *gtsmodel.Account case string(gtsmodel.AdminActionSuspend): adminAction.Type = gtsmodel.AdminActionSuspend // pass the account delete through the client api channel for processing - p.clientWorker.Queue(messages.FromClientAPI{ + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ActorPerson, APActivityType: ap.ActivityDelete, OriginAccount: account, @@ -57,7 +57,7 @@ func (p *Processor) AccountAction(ctx context.Context, account *gtsmodel.Account return gtserror.NewErrorBadRequest(fmt.Errorf("admin action type %s is not supported for this endpoint", form.Type)) } - if err := p.db.Put(ctx, adminAction); err != nil { + if err := p.state.DB.Put(ctx, adminAction); err != nil { return gtserror.NewErrorInternalError(err) } diff --git a/internal/processing/admin/admin.go b/internal/processing/admin/admin.go index 54827b8fd..ba09969dc 100644 --- a/internal/processing/admin/admin.go +++ b/internal/processing/admin/admin.go @@ -19,32 +19,25 @@ package admin import ( - "github.com/superseriousbusiness/gotosocial/internal/concurrency" - "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" - "github.com/superseriousbusiness/gotosocial/internal/storage" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/internal/typeutils" ) type Processor struct { + state *state.State tc typeutils.TypeConverter mediaManager media.Manager transportController transport.Controller - storage *storage.Driver - clientWorker *concurrency.WorkerPool[messages.FromClientAPI] - db db.DB } // New returns a new admin processor. -func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, transportController transport.Controller, storage *storage.Driver, clientWorker *concurrency.WorkerPool[messages.FromClientAPI]) Processor { +func New(state *state.State, tc typeutils.TypeConverter, mediaManager media.Manager, transportController transport.Controller) Processor { return Processor{ + state: state, tc: tc, mediaManager: mediaManager, transportController: transportController, - storage: storage, - clientWorker: clientWorker, - db: db, } } diff --git a/internal/processing/admin/domainblock.go b/internal/processing/admin/domainblock.go index 415ac610f..dd22f72e6 100644 --- a/internal/processing/admin/domainblock.go +++ b/internal/processing/admin/domainblock.go @@ -28,7 +28,7 @@ func (p *Processor) DomainBlockCreate(ctx context.Context, account *gtsmodel.Acc domain = strings.ToLower(domain) // first check if we already have a block -- if err == nil we already had a block so we can skip a whole lot of work - block, err := p.db.GetDomainBlock(ctx, domain) + block, err := p.state.DB.GetDomainBlock(ctx, domain) if err != nil { if !errors.Is(err, db.ErrNoEntries) { // something went wrong in the DB @@ -47,7 +47,7 @@ func (p *Processor) DomainBlockCreate(ctx context.Context, account *gtsmodel.Acc } // Insert the new block into the database - if err := p.db.CreateDomainBlock(ctx, newBlock); err != nil { + if err := p.state.DB.CreateDomainBlock(ctx, newBlock); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error putting new domain block %s: %s", domain, err)) } @@ -80,7 +80,7 @@ func (p *Processor) initiateDomainBlockSideEffects(ctx context.Context, account // if we have an instance entry for this domain, update it with the new block ID and clear all fields instance := >smodel.Instance{} - if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: block.Domain}}, instance); err == nil { + if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "domain", Value: block.Domain}}, instance); err == nil { updatingColumns := []string{ "title", "updated_at", @@ -105,15 +105,15 @@ func (p *Processor) initiateDomainBlockSideEffects(ctx context.Context, account instance.ContactAccountUsername = "" instance.ContactAccountID = "" instance.Version = "" - if err := p.db.UpdateByID(ctx, instance, instance.ID, updatingColumns...); err != nil { + if err := p.state.DB.UpdateByID(ctx, instance, instance.ID, updatingColumns...); err != nil { l.Errorf("domainBlockProcessSideEffects: db error updating instance: %s", err) } l.Debug("domainBlockProcessSideEffects: instance entry updated") } // if we have an instance account for this instance, delete it - if instanceAccount, err := p.db.GetAccountByUsernameDomain(ctx, block.Domain, block.Domain); err == nil { - if err := p.db.DeleteAccount(ctx, instanceAccount.ID); err != nil { + if instanceAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, block.Domain, block.Domain); err == nil { + if err := p.state.DB.DeleteAccount(ctx, instanceAccount.ID); err != nil { l.Errorf("domainBlockProcessSideEffects: db error deleting instance account: %s", err) } } @@ -125,7 +125,7 @@ func (p *Processor) initiateDomainBlockSideEffects(ctx context.Context, account selectAccountsLoop: for { - accounts, err := p.db.GetInstanceAccounts(ctx, block.Domain, maxID, limit) + accounts, err := p.state.DB.GetInstanceAccounts(ctx, block.Domain, maxID, limit) if err != nil { if err == db.ErrNoEntries { // no accounts left for this instance so we're done @@ -141,7 +141,7 @@ func (p *Processor) initiateDomainBlockSideEffects(ctx context.Context, account l.Debugf("putting delete for account %s in the clientAPI channel", a.Username) // pass the account delete through the client api channel for processing - p.clientWorker.Queue(messages.FromClientAPI{ + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ActorPerson, APActivityType: ap.ActivityDelete, GTSModel: block, @@ -195,7 +195,7 @@ func (p *Processor) DomainBlocksImport(ctx context.Context, account *gtsmodel.Ac func (p *Processor) DomainBlocksGet(ctx context.Context, account *gtsmodel.Account, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) { domainBlocks := []*gtsmodel.DomainBlock{} - if err := p.db.GetAll(ctx, &domainBlocks); err != nil { + if err := p.state.DB.GetAll(ctx, &domainBlocks); err != nil { if !errors.Is(err, db.ErrNoEntries) { // something has gone really wrong return nil, gtserror.NewErrorInternalError(err) @@ -219,7 +219,7 @@ func (p *Processor) DomainBlocksGet(ctx context.Context, account *gtsmodel.Accou func (p *Processor) DomainBlockGet(ctx context.Context, account *gtsmodel.Account, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) { domainBlock := >smodel.DomainBlock{} - if err := p.db.GetByID(ctx, id, domainBlock); err != nil { + if err := p.state.DB.GetByID(ctx, id, domainBlock); err != nil { if !errors.Is(err, db.ErrNoEntries) { // something has gone really wrong return nil, gtserror.NewErrorInternalError(err) @@ -240,7 +240,7 @@ func (p *Processor) DomainBlockGet(ctx context.Context, account *gtsmodel.Accoun func (p *Processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.DomainBlock, gtserror.WithCode) { domainBlock := >smodel.DomainBlock{} - if err := p.db.GetByID(ctx, id, domainBlock); err != nil { + if err := p.state.DB.GetByID(ctx, id, domainBlock); err != nil { if !errors.Is(err, db.ErrNoEntries) { // something has gone really wrong return nil, gtserror.NewErrorInternalError(err) @@ -256,13 +256,13 @@ func (p *Processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Acc } // Delete the domain block - if err := p.db.DeleteDomainBlock(ctx, domainBlock.Domain); err != nil { + if err := p.state.DB.DeleteDomainBlock(ctx, domainBlock.Domain); err != nil { return nil, gtserror.NewErrorInternalError(err) } // remove the domain block reference from the instance, if we have an entry for it i := >smodel.Instance{} - if err := p.db.GetWhere(ctx, []db.Where{ + if err := p.state.DB.GetWhere(ctx, []db.Where{ {Key: "domain", Value: domainBlock.Domain}, {Key: "domain_block_id", Value: id}, }, i); err == nil { @@ -270,21 +270,21 @@ func (p *Processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Acc i.SuspendedAt = time.Time{} i.DomainBlockID = "" i.UpdatedAt = time.Now() - if err := p.db.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil { + if err := p.state.DB.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("couldn't update database entry for instance %s: %s", domainBlock.Domain, err)) } } // unsuspend all accounts whose suspension origin was this domain block // 1. remove the 'suspended_at' entry from their accounts - if err := p.db.UpdateWhere(ctx, []db.Where{ + if err := p.state.DB.UpdateWhere(ctx, []db.Where{ {Key: "suspension_origin", Value: domainBlock.ID}, }, "suspended_at", nil, &[]*gtsmodel.Account{}); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error removing suspended_at from accounts: %s", err)) } // 2. remove the 'suspension_origin' entry from their accounts - if err := p.db.UpdateWhere(ctx, []db.Where{ + if err := p.state.DB.UpdateWhere(ctx, []db.Where{ {Key: "suspension_origin", Value: domainBlock.ID}, }, "suspension_origin", nil, &[]*gtsmodel.Account{}); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error removing suspension_origin from accounts: %s", err)) diff --git a/internal/processing/admin/emoji.go b/internal/processing/admin/emoji.go index 391d18525..3eacbf888 100644 --- a/internal/processing/admin/emoji.go +++ b/internal/processing/admin/emoji.go @@ -42,7 +42,7 @@ func (p *Processor) EmojiCreate(ctx context.Context, account *gtsmodel.Account, return nil, gtserror.NewErrorUnauthorized(fmt.Errorf("user %s not an admin", user.ID), "user is not an admin") } - maybeExisting, err := p.db.GetEmojiByShortcodeDomain(ctx, form.Shortcode, "") + maybeExisting, err := p.state.DB.GetEmojiByShortcodeDomain(ctx, form.Shortcode, "") if maybeExisting != nil { return nil, gtserror.NewErrorConflict(fmt.Errorf("emoji with shortcode %s already exists", form.Shortcode), fmt.Sprintf("emoji with shortcode %s already exists", form.Shortcode)) } @@ -110,7 +110,7 @@ func (p *Processor) EmojisGet( return nil, gtserror.NewErrorUnauthorized(fmt.Errorf("user %s not an admin", user.ID), "user is not an admin") } - emojis, err := p.db.GetEmojis(ctx, domain, includeDisabled, includeEnabled, shortcode, maxShortcodeDomain, minShortcodeDomain, limit) + emojis, err := p.state.DB.GetEmojis(ctx, domain, includeDisabled, includeEnabled, shortcode, maxShortcodeDomain, minShortcodeDomain, limit) if err != nil && !errors.Is(err, db.ErrNoEntries) { err := fmt.Errorf("EmojisGet: db error: %s", err) return nil, gtserror.NewErrorInternalError(err) @@ -176,7 +176,7 @@ func (p *Processor) EmojiGet(ctx context.Context, account *gtsmodel.Account, use return nil, gtserror.NewErrorUnauthorized(fmt.Errorf("user %s not an admin", user.ID), "user is not an admin") } - emoji, err := p.db.GetEmojiByID(ctx, id) + emoji, err := p.state.DB.GetEmojiByID(ctx, id) if err != nil { if errors.Is(err, db.ErrNoEntries) { err = fmt.Errorf("EmojiGet: no emoji with id %s found in the db", id) @@ -197,7 +197,7 @@ func (p *Processor) EmojiGet(ctx context.Context, account *gtsmodel.Account, use // EmojiDelete deletes one emoji from the database, with the given id. func (p *Processor) EmojiDelete(ctx context.Context, id string) (*apimodel.AdminEmoji, gtserror.WithCode) { - emoji, err := p.db.GetEmojiByID(ctx, id) + emoji, err := p.state.DB.GetEmojiByID(ctx, id) if err != nil { if errors.Is(err, db.ErrNoEntries) { err = fmt.Errorf("EmojiDelete: no emoji with id %s found in the db", id) @@ -218,7 +218,7 @@ func (p *Processor) EmojiDelete(ctx context.Context, id string) (*apimodel.Admin return nil, gtserror.NewErrorInternalError(err) } - if err := p.db.DeleteEmojiByID(ctx, id); err != nil { + if err := p.state.DB.DeleteEmojiByID(ctx, id); err != nil { err := fmt.Errorf("EmojiDelete: db error: %s", err) return nil, gtserror.NewErrorInternalError(err) } @@ -228,7 +228,7 @@ func (p *Processor) EmojiDelete(ctx context.Context, id string) (*apimodel.Admin // EmojiUpdate updates one emoji with the given id, using the provided form parameters. func (p *Processor) EmojiUpdate(ctx context.Context, id string, form *apimodel.EmojiUpdateRequest) (*apimodel.AdminEmoji, gtserror.WithCode) { - emoji, err := p.db.GetEmojiByID(ctx, id) + emoji, err := p.state.DB.GetEmojiByID(ctx, id) if err != nil { if errors.Is(err, db.ErrNoEntries) { err = fmt.Errorf("EmojiUpdate: no emoji with id %s found in the db", id) @@ -253,7 +253,7 @@ func (p *Processor) EmojiUpdate(ctx context.Context, id string, form *apimodel.E // EmojiCategoriesGet returns all custom emoji categories that exist on this instance. func (p *Processor) EmojiCategoriesGet(ctx context.Context) ([]*apimodel.EmojiCategory, gtserror.WithCode) { - categories, err := p.db.GetEmojiCategories(ctx) + categories, err := p.state.DB.GetEmojiCategories(ctx) if err != nil { err := fmt.Errorf("EmojiCategoriesGet: db error: %s", err) return nil, gtserror.NewErrorInternalError(err) @@ -277,7 +277,7 @@ func (p *Processor) EmojiCategoriesGet(ctx context.Context) ([]*apimodel.EmojiCa */ func (p *Processor) getOrCreateEmojiCategory(ctx context.Context, name string) (*gtsmodel.EmojiCategory, error) { - category, err := p.db.GetEmojiCategoryByName(ctx, name) + category, err := p.state.DB.GetEmojiCategoryByName(ctx, name) if err == nil { return category, nil } @@ -299,7 +299,7 @@ func (p *Processor) getOrCreateEmojiCategory(ctx context.Context, name string) ( Name: name, } - if err := p.db.PutEmojiCategory(ctx, category); err != nil { + if err := p.state.DB.PutEmojiCategory(ctx, category); err != nil { err = fmt.Errorf("GetOrCreateEmojiCategory: error putting new emoji category in the database: %s", err) return nil, err } @@ -319,7 +319,7 @@ func (p *Processor) emojiUpdateCopy(ctx context.Context, emoji *gtsmodel.Emoji, return nil, gtserror.NewErrorBadRequest(err, err.Error()) } - maybeExisting, err := p.db.GetEmojiByShortcodeDomain(ctx, *shortcode, "") + maybeExisting, err := p.state.DB.GetEmojiByShortcodeDomain(ctx, *shortcode, "") if maybeExisting != nil { err := fmt.Errorf("emojiUpdateCopy: emoji %s could not be copied, emoji with shortcode %s already exists on this instance", emoji.ID, *shortcode) return nil, gtserror.NewErrorConflict(err, err.Error()) @@ -339,7 +339,7 @@ func (p *Processor) emojiUpdateCopy(ctx context.Context, emoji *gtsmodel.Emoji, newEmojiURI := uris.GenerateURIForEmoji(newEmojiID) data := func(ctx context.Context) (reader io.ReadCloser, fileSize int64, err error) { - rc, err := p.storage.GetStream(ctx, emoji.ImagePath) + rc, err := p.state.Storage.GetStream(ctx, emoji.ImagePath) return rc, int64(emoji.ImageFileSize), err } @@ -386,7 +386,7 @@ func (p *Processor) emojiUpdateDisable(ctx context.Context, emoji *gtsmodel.Emoj emojiDisabled := true emoji.Disabled = &emojiDisabled - updatedEmoji, err := p.db.UpdateEmoji(ctx, emoji, "updated_at", "disabled") + updatedEmoji, err := p.state.DB.UpdateEmoji(ctx, emoji, "updated_at", "disabled") if err != nil { err = fmt.Errorf("emojiUpdateDisable: error updating emoji %s: %s", emoji.ID, err) return nil, gtserror.NewErrorInternalError(err) @@ -443,7 +443,7 @@ func (p *Processor) emojiUpdateModify(ctx context.Context, emoji *gtsmodel.Emoji } var err error - updatedEmoji, err = p.db.UpdateEmoji(ctx, emoji, columns...) + updatedEmoji, err = p.state.DB.UpdateEmoji(ctx, emoji, columns...) if err != nil { err = fmt.Errorf("emojiUpdateModify: error updating emoji %s: %s", emoji.ID, err) return nil, gtserror.NewErrorInternalError(err) diff --git a/internal/processing/admin/report.go b/internal/processing/admin/report.go index 3a6028bca..bed97e204 100644 --- a/internal/processing/admin/report.go +++ b/internal/processing/admin/report.go @@ -43,7 +43,7 @@ func (p *Processor) ReportsGet( minID string, limit int, ) (*apimodel.PageableResponse, gtserror.WithCode) { - reports, err := p.db.GetReports(ctx, resolved, accountID, targetAccountID, maxID, sinceID, minID, limit) + reports, err := p.state.DB.GetReports(ctx, resolved, accountID, targetAccountID, maxID, sinceID, minID, limit) if err != nil { if err == db.ErrNoEntries { return util.EmptyPageableResponse(), nil @@ -95,7 +95,7 @@ func (p *Processor) ReportsGet( // ReportGet returns one report, with the given ID. func (p *Processor) ReportGet(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.AdminReport, gtserror.WithCode) { - report, err := p.db.GetReportByID(ctx, id) + report, err := p.state.DB.GetReportByID(ctx, id) if err != nil { if err == db.ErrNoEntries { return nil, gtserror.NewErrorNotFound(err) @@ -113,7 +113,7 @@ func (p *Processor) ReportGet(ctx context.Context, account *gtsmodel.Account, id // ReportResolve marks a report with the given id as resolved, and stores the provided actionTakenComment (if not null). func (p *Processor) ReportResolve(ctx context.Context, account *gtsmodel.Account, id string, actionTakenComment *string) (*apimodel.AdminReport, gtserror.WithCode) { - report, err := p.db.GetReportByID(ctx, id) + report, err := p.state.DB.GetReportByID(ctx, id) if err != nil { if err == db.ErrNoEntries { return nil, gtserror.NewErrorNotFound(err) @@ -134,7 +134,7 @@ func (p *Processor) ReportResolve(ctx context.Context, account *gtsmodel.Account columns = append(columns, "action_taken") } - updatedReport, err := p.db.UpdateReport(ctx, report, columns...) + updatedReport, err := p.state.DB.UpdateReport(ctx, report, columns...) if err != nil { return nil, gtserror.NewErrorInternalError(err) } diff --git a/internal/processing/app.go b/internal/processing/app.go index f2a938b22..e4cda5a43 100644 --- a/internal/processing/app.go +++ b/internal/processing/app.go @@ -62,7 +62,7 @@ func (p *Processor) AppCreate(ctx context.Context, authed *oauth.Auth, form *api } // chuck it in the db - if err := p.db.Put(ctx, app); err != nil { + if err := p.state.DB.Put(ctx, app); err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -76,7 +76,7 @@ func (p *Processor) AppCreate(ctx context.Context, authed *oauth.Auth, form *api } // chuck it in the db - if err := p.db.Put(ctx, oc); err != nil { + if err := p.state.DB.Put(ctx, oc); err != nil { return nil, gtserror.NewErrorInternalError(err) } diff --git a/internal/processing/blocks.go b/internal/processing/blocks.go index 6dd9c3de9..754954f02 100644 --- a/internal/processing/blocks.go +++ b/internal/processing/blocks.go @@ -31,7 +31,7 @@ ) func (p *Processor) BlocksGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, limit int) (*apimodel.BlocksResponse, gtserror.WithCode) { - accounts, nextMaxID, prevMinID, err := p.db.GetAccountBlocks(ctx, authed.Account.ID, maxID, sinceID, limit) + accounts, nextMaxID, prevMinID, err := p.state.DB.GetAccountBlocks(ctx, authed.Account.ID, maxID, sinceID, limit) if err != nil { if err == db.ErrNoEntries { // there are just no entries diff --git a/internal/processing/fedi/collections.go b/internal/processing/fedi/collections.go index 78a65bebe..627511c3b 100644 --- a/internal/processing/fedi/collections.go +++ b/internal/processing/fedi/collections.go @@ -84,8 +84,8 @@ func (p *Processor) OutboxGet(ctx context.Context, requestedUsername string, pag // scenario 2 -- get the requested page // limit pages to 30 entries per page - publicStatuses, err := p.db.GetAccountStatuses(ctx, requestedAccount.ID, 30, true, true, maxID, minID, false, true) - if err != nil && err != db.ErrNoEntries { + publicStatuses, err := p.state.DB.GetAccountStatuses(ctx, requestedAccount.ID, 30, true, true, maxID, minID, false, true) + if err != nil && !errors.Is(err, db.ErrNoEntries) { return nil, gtserror.NewErrorInternalError(err) } @@ -161,7 +161,7 @@ func (p *Processor) FeaturedCollectionGet(ctx context.Context, requestedUsername return nil, errWithCode } - statuses, err := p.db.GetAccountPinnedStatuses(ctx, requestedAccount.ID) + statuses, err := p.state.DB.GetAccountPinnedStatuses(ctx, requestedAccount.ID) if err != nil { if !errors.Is(err, db.ErrNoEntries) { return nil, gtserror.NewErrorInternalError(err) diff --git a/internal/processing/fedi/common.go b/internal/processing/fedi/common.go index 37c604ded..a2c7f9b37 100644 --- a/internal/processing/fedi/common.go +++ b/internal/processing/fedi/common.go @@ -29,7 +29,7 @@ ) func (p *Processor) authenticate(ctx context.Context, requestedUsername string) (requestedAccount, requestingAccount *gtsmodel.Account, errWithCode gtserror.WithCode) { - requestedAccount, err := p.db.GetAccountByUsernameDomain(ctx, requestedUsername, "") + requestedAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, requestedUsername, "") if err != nil { errWithCode = gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err)) return @@ -46,7 +46,7 @@ func (p *Processor) authenticate(ctx context.Context, requestedUsername string) return } - blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true) + blocked, err := p.state.DB.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true) if err != nil { errWithCode = gtserror.NewErrorInternalError(err) return diff --git a/internal/processing/fedi/emoji.go b/internal/processing/fedi/emoji.go index 0b1dd3440..b2618ca13 100644 --- a/internal/processing/fedi/emoji.go +++ b/internal/processing/fedi/emoji.go @@ -32,7 +32,7 @@ func (p *Processor) EmojiGet(ctx context.Context, requestedEmojiID string) (inte return nil, errWithCode } - requestedEmoji, err := p.db.GetEmojiByID(ctx, requestedEmojiID) + requestedEmoji, err := p.state.DB.GetEmojiByID(ctx, requestedEmojiID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting emoji with id %s: %s", requestedEmojiID, err)) } diff --git a/internal/processing/fedi/fedi.go b/internal/processing/fedi/fedi.go index e72d037f5..c8f78c5a6 100644 --- a/internal/processing/fedi/fedi.go +++ b/internal/processing/fedi/fedi.go @@ -19,25 +19,25 @@ package fedi import ( - "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/federation" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/visibility" ) type Processor struct { - db db.DB + state *state.State federator federation.Federator tc typeutils.TypeConverter filter visibility.Filter } // New returns a new fedi processor. -func New(db db.DB, tc typeutils.TypeConverter, federator federation.Federator) Processor { +func New(state *state.State, tc typeutils.TypeConverter, federator federation.Federator) Processor { return Processor{ - db: db, + state: state, federator: federator, tc: tc, - filter: visibility.NewFilter(db), + filter: visibility.NewFilter(state.DB), } } diff --git a/internal/processing/fedi/status.go b/internal/processing/fedi/status.go index fbadcb290..60ebb3c84 100644 --- a/internal/processing/fedi/status.go +++ b/internal/processing/fedi/status.go @@ -36,7 +36,7 @@ func (p *Processor) StatusGet(ctx context.Context, requestedUsername string, req return nil, errWithCode } - status, err := p.db.GetStatusByID(ctx, requestedStatusID) + status, err := p.state.DB.GetStatusByID(ctx, requestedStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(err) } @@ -74,7 +74,7 @@ func (p *Processor) StatusRepliesGet(ctx context.Context, requestedUsername stri return nil, errWithCode } - status, err := p.db.GetStatusByID(ctx, requestedStatusID) + status, err := p.state.DB.GetStatusByID(ctx, requestedStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(err) } @@ -125,7 +125,7 @@ func (p *Processor) StatusRepliesGet(ctx context.Context, requestedUsername stri default: // scenario 3 // get immediate children - replies, err := p.db.GetStatusChildren(ctx, status, true, minID) + replies, err := p.state.DB.GetStatusChildren(ctx, status, true, minID) if err != nil { return nil, gtserror.NewErrorInternalError(err) } diff --git a/internal/processing/fedi/user.go b/internal/processing/fedi/user.go index 899d063d1..35e756e57 100644 --- a/internal/processing/fedi/user.go +++ b/internal/processing/fedi/user.go @@ -34,7 +34,7 @@ // before returning a JSON serializable interface to the caller. func (p *Processor) UserGet(ctx context.Context, requestedUsername string, requestURL *url.URL) (interface{}, gtserror.WithCode) { // Get the instance-local account the request is referring to. - requestedAccount, err := p.db.GetAccountByUsernameDomain(ctx, requestedUsername, "") + requestedAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, requestedUsername, "") if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err)) } @@ -63,7 +63,7 @@ func (p *Processor) UserGet(ctx context.Context, requestedUsername string, reque return nil, gtserror.NewErrorUnauthorized(err) } - blocked, err := p.db.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true) + blocked, err := p.state.DB.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true) if err != nil { return nil, gtserror.NewErrorInternalError(err) } diff --git a/internal/processing/fedi/wellknown.go b/internal/processing/fedi/wellknown.go index 75ed34ec2..6f113ac5d 100644 --- a/internal/processing/fedi/wellknown.go +++ b/internal/processing/fedi/wellknown.go @@ -64,12 +64,12 @@ func (p *Processor) NodeInfoRelGet(ctx context.Context) (*apimodel.WellKnownResp func (p *Processor) NodeInfoGet(ctx context.Context) (*apimodel.Nodeinfo, gtserror.WithCode) { host := config.GetHost() - userCount, err := p.db.CountInstanceUsers(ctx, host) + userCount, err := p.state.DB.CountInstanceUsers(ctx, host) if err != nil { return nil, gtserror.NewErrorInternalError(err) } - postCount, err := p.db.CountInstanceStatuses(ctx, host) + postCount, err := p.state.DB.CountInstanceStatuses(ctx, host) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -99,7 +99,7 @@ func (p *Processor) NodeInfoGet(ctx context.Context) (*apimodel.Nodeinfo, gtserr // WebfingerGet handles the GET for a webfinger resource. Most commonly, it will be used for returning account lookups. func (p *Processor) WebfingerGet(ctx context.Context, requestedUsername string) (*apimodel.WellKnownResponse, gtserror.WithCode) { // Get the local account the request is referring to. - requestedAccount, err := p.db.GetAccountByUsernameDomain(ctx, requestedUsername, "") + requestedAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, requestedUsername, "") if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("database error getting account with username %s: %s", requestedUsername, err)) } diff --git a/internal/processing/followrequest.go b/internal/processing/followrequest.go index 1f1b7f3c2..9bd13cc0b 100644 --- a/internal/processing/followrequest.go +++ b/internal/processing/followrequest.go @@ -30,7 +30,7 @@ ) func (p *Processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode) { - frs, err := p.db.GetAccountFollowRequests(ctx, auth.Account.ID) + frs, err := p.state.DB.GetAccountFollowRequests(ctx, auth.Account.ID) if err != nil { if err != db.ErrNoEntries { return nil, gtserror.NewErrorInternalError(err) @@ -40,7 +40,7 @@ func (p *Processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([] accts := []apimodel.Account{} for _, fr := range frs { if fr.Account == nil { - frAcct, err := p.db.GetAccountByID(ctx, fr.AccountID) + frAcct, err := p.state.DB.GetAccountByID(ctx, fr.AccountID) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -57,13 +57,13 @@ func (p *Processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([] } func (p *Processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) { - follow, err := p.db.AcceptFollowRequest(ctx, accountID, auth.Account.ID) + follow, err := p.state.DB.AcceptFollowRequest(ctx, accountID, auth.Account.ID) if err != nil { return nil, gtserror.NewErrorNotFound(err) } if follow.Account == nil { - followAccount, err := p.db.GetAccountByID(ctx, follow.AccountID) + followAccount, err := p.state.DB.GetAccountByID(ctx, follow.AccountID) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -71,14 +71,14 @@ func (p *Processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, a } if follow.TargetAccount == nil { - followTargetAccount, err := p.db.GetAccountByID(ctx, follow.TargetAccountID) + followTargetAccount, err := p.state.DB.GetAccountByID(ctx, follow.TargetAccountID) if err != nil { return nil, gtserror.NewErrorInternalError(err) } follow.TargetAccount = followTargetAccount } - p.clientWorker.Queue(messages.FromClientAPI{ + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ActivityFollow, APActivityType: ap.ActivityAccept, GTSModel: follow, @@ -86,7 +86,7 @@ func (p *Processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, a TargetAccount: follow.TargetAccount, }) - gtsR, err := p.db.GetRelationship(ctx, auth.Account.ID, accountID) + gtsR, err := p.state.DB.GetRelationship(ctx, auth.Account.ID, accountID) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -100,13 +100,13 @@ func (p *Processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, a } func (p *Processor) FollowRequestReject(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) { - followRequest, err := p.db.RejectFollowRequest(ctx, accountID, auth.Account.ID) + followRequest, err := p.state.DB.RejectFollowRequest(ctx, accountID, auth.Account.ID) if err != nil { return nil, gtserror.NewErrorNotFound(err) } if followRequest.Account == nil { - a, err := p.db.GetAccountByID(ctx, followRequest.AccountID) + a, err := p.state.DB.GetAccountByID(ctx, followRequest.AccountID) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -114,14 +114,14 @@ func (p *Processor) FollowRequestReject(ctx context.Context, auth *oauth.Auth, a } if followRequest.TargetAccount == nil { - a, err := p.db.GetAccountByID(ctx, followRequest.TargetAccountID) + a, err := p.state.DB.GetAccountByID(ctx, followRequest.TargetAccountID) if err != nil { return nil, gtserror.NewErrorInternalError(err) } followRequest.TargetAccount = a } - p.clientWorker.Queue(messages.FromClientAPI{ + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ActivityFollow, APActivityType: ap.ActivityReject, GTSModel: followRequest, @@ -129,7 +129,7 @@ func (p *Processor) FollowRequestReject(ctx context.Context, auth *oauth.Auth, a TargetAccount: followRequest.TargetAccount, }) - gtsR, err := p.db.GetRelationship(ctx, auth.Account.ID, accountID) + gtsR, err := p.state.DB.GetRelationship(ctx, auth.Account.ID, accountID) if err != nil { return nil, gtserror.NewErrorInternalError(err) } diff --git a/internal/processing/fromclientapi.go b/internal/processing/fromclientapi.go index 701f425f6..209a27105 100644 --- a/internal/processing/fromclientapi.go +++ b/internal/processing/fromclientapi.go @@ -143,7 +143,7 @@ func (p *Processor) processCreateAccountFromClientAPI(ctx context.Context, clien } // get the user this account belongs to - user, err := p.db.GetUserByAccountID(ctx, account.ID) + user, err := p.state.DB.GetUserByAccountID(ctx, account.ID) if err != nil { return err } @@ -293,7 +293,7 @@ func (p *Processor) processUndoAnnounceFromClientAPI(ctx context.Context, client return errors.New("undo was not parseable as *gtsmodel.Status") } - if err := p.db.DeleteStatusByID(ctx, boost.ID); err != nil { + if err := p.state.DB.DeleteStatusByID(ctx, boost.ID); err != nil { return err } @@ -422,7 +422,7 @@ func (p *Processor) federateStatus(ctx context.Context, status *gtsmodel.Status) } if status.Account == nil { - statusAccount, err := p.db.GetAccountByID(ctx, status.AccountID) + statusAccount, err := p.state.DB.GetAccountByID(ctx, status.AccountID) if err != nil { return fmt.Errorf("federateStatus: error fetching status author account: %s", err) } @@ -455,7 +455,7 @@ func (p *Processor) federateStatus(ctx context.Context, status *gtsmodel.Status) func (p *Processor) federateStatusDelete(ctx context.Context, status *gtsmodel.Status) error { if status.Account == nil { - statusAccount, err := p.db.GetAccountByID(ctx, status.AccountID) + statusAccount, err := p.state.DB.GetAccountByID(ctx, status.AccountID) if err != nil { return fmt.Errorf("federateStatusDelete: error fetching status author account: %s", err) } @@ -642,7 +642,7 @@ func (p *Processor) federateUnannounce(ctx context.Context, boost *gtsmodel.Stat func (p *Processor) federateAcceptFollowRequest(ctx context.Context, follow *gtsmodel.Follow) error { if follow.Account == nil { - a, err := p.db.GetAccountByID(ctx, follow.AccountID) + a, err := p.state.DB.GetAccountByID(ctx, follow.AccountID) if err != nil { return err } @@ -651,7 +651,7 @@ func (p *Processor) federateAcceptFollowRequest(ctx context.Context, follow *gts originAccount := follow.Account if follow.TargetAccount == nil { - a, err := p.db.GetAccountByID(ctx, follow.TargetAccountID) + a, err := p.state.DB.GetAccountByID(ctx, follow.TargetAccountID) if err != nil { return err } @@ -715,7 +715,7 @@ func (p *Processor) federateAcceptFollowRequest(ctx context.Context, follow *gts func (p *Processor) federateRejectFollowRequest(ctx context.Context, followRequest *gtsmodel.FollowRequest) error { if followRequest.Account == nil { - a, err := p.db.GetAccountByID(ctx, followRequest.AccountID) + a, err := p.state.DB.GetAccountByID(ctx, followRequest.AccountID) if err != nil { return err } @@ -724,7 +724,7 @@ func (p *Processor) federateRejectFollowRequest(ctx context.Context, followReque originAccount := followRequest.Account if followRequest.TargetAccount == nil { - a, err := p.db.GetAccountByID(ctx, followRequest.TargetAccountID) + a, err := p.state.DB.GetAccountByID(ctx, followRequest.TargetAccountID) if err != nil { return err } @@ -844,7 +844,7 @@ func (p *Processor) federateAccountUpdate(ctx context.Context, updatedAccount *g func (p *Processor) federateBlock(ctx context.Context, block *gtsmodel.Block) error { if block.Account == nil { - blockAccount, err := p.db.GetAccountByID(ctx, block.AccountID) + blockAccount, err := p.state.DB.GetAccountByID(ctx, block.AccountID) if err != nil { return fmt.Errorf("federateBlock: error getting block account from database: %s", err) } @@ -852,7 +852,7 @@ func (p *Processor) federateBlock(ctx context.Context, block *gtsmodel.Block) er } if block.TargetAccount == nil { - blockTargetAccount, err := p.db.GetAccountByID(ctx, block.TargetAccountID) + blockTargetAccount, err := p.state.DB.GetAccountByID(ctx, block.TargetAccountID) if err != nil { return fmt.Errorf("federateBlock: error getting block target account from database: %s", err) } @@ -880,7 +880,7 @@ func (p *Processor) federateBlock(ctx context.Context, block *gtsmodel.Block) er func (p *Processor) federateUnblock(ctx context.Context, block *gtsmodel.Block) error { if block.Account == nil { - blockAccount, err := p.db.GetAccountByID(ctx, block.AccountID) + blockAccount, err := p.state.DB.GetAccountByID(ctx, block.AccountID) if err != nil { return fmt.Errorf("federateUnblock: error getting block account from database: %s", err) } @@ -888,7 +888,7 @@ func (p *Processor) federateUnblock(ctx context.Context, block *gtsmodel.Block) } if block.TargetAccount == nil { - blockTargetAccount, err := p.db.GetAccountByID(ctx, block.TargetAccountID) + blockTargetAccount, err := p.state.DB.GetAccountByID(ctx, block.TargetAccountID) if err != nil { return fmt.Errorf("federateUnblock: error getting block target account from database: %s", err) } @@ -934,7 +934,7 @@ func (p *Processor) federateUnblock(ctx context.Context, block *gtsmodel.Block) func (p *Processor) federateReport(ctx context.Context, report *gtsmodel.Report) error { if report.TargetAccount == nil { - reportTargetAccount, err := p.db.GetAccountByID(ctx, report.TargetAccountID) + reportTargetAccount, err := p.state.DB.GetAccountByID(ctx, report.TargetAccountID) if err != nil { return fmt.Errorf("federateReport: error getting report target account from database: %w", err) } @@ -942,7 +942,7 @@ func (p *Processor) federateReport(ctx context.Context, report *gtsmodel.Report) } if len(report.StatusIDs) > 0 && len(report.Statuses) == 0 { - statuses, err := p.db.GetStatuses(ctx, report.StatusIDs) + statuses, err := p.state.DB.GetStatuses(ctx, report.StatusIDs) if err != nil { return fmt.Errorf("federateReport: error getting report statuses from database: %w", err) } @@ -966,7 +966,7 @@ func (p *Processor) federateReport(ctx context.Context, report *gtsmodel.Report) // deliver the flag using the outbox of the // instance account to anonymize the report - instanceAccount, err := p.db.GetInstanceAccount(ctx, "") + instanceAccount, err := p.state.DB.GetInstanceAccount(ctx, "") if err != nil { return fmt.Errorf("federateReport: error getting instance account: %w", err) } diff --git a/internal/processing/fromcommon.go b/internal/processing/fromcommon.go index 3e4c62c6c..f9e732732 100644 --- a/internal/processing/fromcommon.go +++ b/internal/processing/fromcommon.go @@ -38,7 +38,7 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e if status.Mentions == nil { // there are mentions but they're not fully populated on the status yet so do this - menchies, err := p.db.GetMentions(ctx, status.MentionIDs) + menchies, err := p.state.DB.GetMentions(ctx, status.MentionIDs) if err != nil { return fmt.Errorf("notifyStatus: error getting mentions for status %s from the db: %s", status.ID, err) } @@ -49,7 +49,7 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e for _, m := range status.Mentions { // make sure this is a local account, otherwise we don't need to create a notification for it if m.TargetAccount == nil { - a, err := p.db.GetAccountByID(ctx, m.TargetAccountID) + a, err := p.state.DB.GetAccountByID(ctx, m.TargetAccountID) if err != nil { // we don't have the account or there's been an error return fmt.Errorf("notifyStatus: error getting account with id %s from the db: %s", m.TargetAccountID, err) @@ -62,7 +62,7 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e } // make sure a notif doesn't already exist for this mention - if err := p.db.GetWhere(ctx, []db.Where{ + if err := p.state.DB.GetWhere(ctx, []db.Where{ {Key: "notification_type", Value: gtsmodel.NotificationMention}, {Key: "target_account_id", Value: m.TargetAccountID}, {Key: "origin_account_id", Value: m.OriginAccountID}, @@ -87,7 +87,7 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e Status: status, } - if err := p.db.Put(ctx, notif); err != nil { + if err := p.state.DB.Put(ctx, notif); err != nil { return fmt.Errorf("notifyStatus: error putting notification in database: %s", err) } @@ -108,7 +108,7 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e func (p *Processor) notifyFollowRequest(ctx context.Context, followRequest *gtsmodel.FollowRequest) error { // make sure we have the target account pinned on the follow request if followRequest.TargetAccount == nil { - a, err := p.db.GetAccountByID(ctx, followRequest.TargetAccountID) + a, err := p.state.DB.GetAccountByID(ctx, followRequest.TargetAccountID) if err != nil { return err } @@ -129,7 +129,7 @@ func (p *Processor) notifyFollowRequest(ctx context.Context, followRequest *gtsm OriginAccountID: followRequest.AccountID, } - if err := p.db.Put(ctx, notif); err != nil { + if err := p.state.DB.Put(ctx, notif); err != nil { return fmt.Errorf("notifyFollowRequest: error putting notification in database: %s", err) } @@ -153,7 +153,7 @@ func (p *Processor) notifyFollow(ctx context.Context, follow *gtsmodel.Follow, t } // first remove the follow request notification - if err := p.db.DeleteWhere(ctx, []db.Where{ + if err := p.state.DB.DeleteWhere(ctx, []db.Where{ {Key: "notification_type", Value: gtsmodel.NotificationFollowRequest}, {Key: "target_account_id", Value: follow.TargetAccountID}, {Key: "origin_account_id", Value: follow.AccountID}, @@ -170,7 +170,7 @@ func (p *Processor) notifyFollow(ctx context.Context, follow *gtsmodel.Follow, t OriginAccountID: follow.AccountID, OriginAccount: follow.Account, } - if err := p.db.Put(ctx, notif); err != nil { + if err := p.state.DB.Put(ctx, notif); err != nil { return fmt.Errorf("notifyFollow: error putting notification in database: %s", err) } @@ -194,7 +194,7 @@ func (p *Processor) notifyFave(ctx context.Context, fave *gtsmodel.StatusFave) e } if fave.TargetAccount == nil { - a, err := p.db.GetAccountByID(ctx, fave.TargetAccountID) + a, err := p.state.DB.GetAccountByID(ctx, fave.TargetAccountID) if err != nil { return err } @@ -218,7 +218,7 @@ func (p *Processor) notifyFave(ctx context.Context, fave *gtsmodel.StatusFave) e Status: fave.Status, } - if err := p.db.Put(ctx, notif); err != nil { + if err := p.state.DB.Put(ctx, notif); err != nil { return fmt.Errorf("notifyFave: error putting notification in database: %s", err) } @@ -242,7 +242,7 @@ func (p *Processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status) } if status.BoostOf == nil { - boostedStatus, err := p.db.GetStatusByID(ctx, status.BoostOfID) + boostedStatus, err := p.state.DB.GetStatusByID(ctx, status.BoostOfID) if err != nil { return fmt.Errorf("notifyAnnounce: error getting status with id %s: %s", status.BoostOfID, err) } @@ -250,7 +250,7 @@ func (p *Processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status) } if status.BoostOfAccount == nil { - boostedAcct, err := p.db.GetAccountByID(ctx, status.BoostOfAccountID) + boostedAcct, err := p.state.DB.GetAccountByID(ctx, status.BoostOfAccountID) if err != nil { return fmt.Errorf("notifyAnnounce: error getting account with id %s: %s", status.BoostOfAccountID, err) } @@ -269,7 +269,7 @@ func (p *Processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status) } // make sure a notif doesn't already exist for this announce - err := p.db.GetWhere(ctx, []db.Where{ + err := p.state.DB.GetWhere(ctx, []db.Where{ {Key: "notification_type", Value: gtsmodel.NotificationReblog}, {Key: "target_account_id", Value: status.BoostOfAccountID}, {Key: "origin_account_id", Value: status.AccountID}, @@ -292,7 +292,7 @@ func (p *Processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status) Status: status, } - if err := p.db.Put(ctx, notif); err != nil { + if err := p.state.DB.Put(ctx, notif); err != nil { return fmt.Errorf("notifyAnnounce: error putting notification in database: %s", err) } @@ -314,7 +314,7 @@ func (p *Processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status) func (p *Processor) timelineStatus(ctx context.Context, status *gtsmodel.Status) error { // make sure the author account is pinned onto the status if status.Account == nil { - a, err := p.db.GetAccountByID(ctx, status.AccountID) + a, err := p.state.DB.GetAccountByID(ctx, status.AccountID) if err != nil { return fmt.Errorf("timelineStatus: error getting author account with id %s: %s", status.AccountID, err) } @@ -322,7 +322,7 @@ func (p *Processor) timelineStatus(ctx context.Context, status *gtsmodel.Status) } // get local followers of the account that posted the status - follows, err := p.db.GetAccountFollowedBy(ctx, status.AccountID, true) + follows, err := p.state.DB.GetAccountFollowedBy(ctx, status.AccountID, true) if err != nil { return fmt.Errorf("timelineStatus: error getting followers for account id %s: %s", status.AccountID, err) } @@ -374,7 +374,7 @@ func (p *Processor) timelineStatusForAccount(ctx context.Context, status *gtsmod defer wg.Done() // get the timeline owner account - timelineAccount, err := p.db.GetAccountByID(ctx, accountID) + timelineAccount, err := p.state.DB.GetAccountByID(ctx, accountID) if err != nil { errors <- fmt.Errorf("timelineStatusForAccount: error getting account for timeline with id %s: %s", accountID, err) return @@ -446,28 +446,28 @@ func (p *Processor) wipeStatus(ctx context.Context, statusToDelete *gtsmodel.Sta // delete all mention entries generated by this status for _, m := range statusToDelete.MentionIDs { - if err := p.db.DeleteByID(ctx, m, >smodel.Mention{}); err != nil { + if err := p.state.DB.DeleteByID(ctx, m, >smodel.Mention{}); err != nil { return err } } // delete all notification entries generated by this status - if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.Notification{}); err != nil { + if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.Notification{}); err != nil { return err } // delete all bookmarks that point to this status - if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.StatusBookmark{}); err != nil { + if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: statusToDelete.ID}}, &[]*gtsmodel.StatusBookmark{}); err != nil { return err } // delete all boosts for this status + remove them from timelines - if boosts, err := p.db.GetStatusReblogs(ctx, statusToDelete); err == nil { + if boosts, err := p.state.DB.GetStatusReblogs(ctx, statusToDelete); err == nil { for _, b := range boosts { if err := p.deleteStatusFromTimelines(ctx, b); err != nil { return err } - if err := p.db.DeleteStatusByID(ctx, b.ID); err != nil { + if err := p.state.DB.DeleteStatusByID(ctx, b.ID); err != nil { return err } } @@ -479,7 +479,7 @@ func (p *Processor) wipeStatus(ctx context.Context, statusToDelete *gtsmodel.Sta } // delete the status itself - if err := p.db.DeleteStatusByID(ctx, statusToDelete.ID); err != nil { + if err := p.state.DB.DeleteStatusByID(ctx, statusToDelete.ID); err != nil { return err } diff --git a/internal/processing/fromfederator.go b/internal/processing/fromfederator.go index eea3c529d..afddedf93 100644 --- a/internal/processing/fromfederator.go +++ b/internal/processing/fromfederator.go @@ -139,7 +139,7 @@ func (p *Processor) processCreateStatusFromFederator(ctx context.Context, federa // make sure the account is pinned if status.Account == nil { - a, err := p.db.GetAccountByID(ctx, status.AccountID) + a, err := p.state.DB.GetAccountByID(ctx, status.AccountID) if err != nil { return err } @@ -185,7 +185,7 @@ func (p *Processor) processCreateFaveFromFederator(ctx context.Context, federato // make sure the account is pinned if incomingFave.Account == nil { - a, err := p.db.GetAccountByID(ctx, incomingFave.AccountID) + a, err := p.state.DB.GetAccountByID(ctx, incomingFave.AccountID) if err != nil { return err } @@ -227,7 +227,7 @@ func (p *Processor) processCreateFollowRequestFromFederator(ctx context.Context, // make sure the account is pinned if followRequest.Account == nil { - a, err := p.db.GetAccountByID(ctx, followRequest.AccountID) + a, err := p.state.DB.GetAccountByID(ctx, followRequest.AccountID) if err != nil { return err } @@ -254,7 +254,7 @@ func (p *Processor) processCreateFollowRequestFromFederator(ctx context.Context, } if followRequest.TargetAccount == nil { - a, err := p.db.GetAccountByID(ctx, followRequest.TargetAccountID) + a, err := p.state.DB.GetAccountByID(ctx, followRequest.TargetAccountID) if err != nil { return err } @@ -267,7 +267,7 @@ func (p *Processor) processCreateFollowRequestFromFederator(ctx context.Context, } // if the target account isn't locked, we should already accept the follow and notify about the new follower instead - follow, err := p.db.AcceptFollowRequest(ctx, followRequest.AccountID, followRequest.TargetAccountID) + follow, err := p.state.DB.AcceptFollowRequest(ctx, followRequest.AccountID, followRequest.TargetAccountID) if err != nil { return err } @@ -288,7 +288,7 @@ func (p *Processor) processCreateAnnounceFromFederator(ctx context.Context, fede // make sure the account is pinned if incomingAnnounce.Account == nil { - a, err := p.db.GetAccountByID(ctx, incomingAnnounce.AccountID) + a, err := p.state.DB.GetAccountByID(ctx, incomingAnnounce.AccountID) if err != nil { return err } @@ -324,7 +324,7 @@ func (p *Processor) processCreateAnnounceFromFederator(ctx context.Context, fede } incomingAnnounce.ID = incomingAnnounceID - if err := p.db.PutStatus(ctx, incomingAnnounce); err != nil { + if err := p.state.DB.PutStatus(ctx, incomingAnnounce); err != nil { return fmt.Errorf("error adding dereferenced announce to the db: %s", err) } diff --git a/internal/processing/fromfederator_test.go b/internal/processing/fromfederator_test.go index 6913b22af..d8f8ad6e1 100644 --- a/internal/processing/fromfederator_test.go +++ b/internal/processing/fromfederator_test.go @@ -344,7 +344,6 @@ func (suite *FromFederatorTestSuite) TestProcessAccountDelete() { suite.NoError(err) // now they are mufos! - err = suite.processor.ProcessFromFederator(ctx, messages.FromFederator{ APObjectType: ap.ObjectProfile, APActivityType: ap.ActivityDelete, diff --git a/internal/processing/instance.go b/internal/processing/instance.go index c3dc4dcea..3ca807af3 100644 --- a/internal/processing/instance.go +++ b/internal/processing/instance.go @@ -35,7 +35,7 @@ func (p *Processor) getThisInstance(ctx context.Context) (*gtsmodel.Instance, error) { i := >smodel.Instance{} - if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: config.GetHost()}}, i); err != nil { + if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "domain", Value: config.GetHost()}}, i); err != nil { return nil, err } return i, nil @@ -73,7 +73,7 @@ func (p *Processor) InstancePeersGet(ctx context.Context, includeSuspended bool, domains := []*apimodel.Domain{} if includeOpen { - instances, err := p.db.GetInstancePeers(ctx, false) + instances, err := p.state.DB.GetInstancePeers(ctx, false) if err != nil && err != db.ErrNoEntries { err = fmt.Errorf("error selecting instance peers: %s", err) return nil, gtserror.NewErrorInternalError(err) @@ -87,7 +87,7 @@ func (p *Processor) InstancePeersGet(ctx context.Context, includeSuspended bool, if includeSuspended { domainBlocks := []*gtsmodel.DomainBlock{} - if err := p.db.GetAll(ctx, &domainBlocks); err != nil && err != db.ErrNoEntries { + if err := p.state.DB.GetAll(ctx, &domainBlocks); err != nil && err != db.ErrNoEntries { return nil, gtserror.NewErrorInternalError(err) } @@ -124,12 +124,12 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe // fetch the instance entry from the db for processing i := >smodel.Instance{} host := config.GetHost() - if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: host}}, i); err != nil { + if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "domain", Value: host}}, i); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance %s: %s", host, err)) } // fetch the instance account from the db for processing - ia, err := p.db.GetInstanceAccount(ctx, "") + ia, err := p.state.DB.GetInstanceAccount(ctx, "") if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance account %s: %s", host, err)) } @@ -148,12 +148,12 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe // validate & update site contact account if it's set on the form if form.ContactUsername != nil { // make sure the account with the given username exists in the db - contactAccount, err := p.db.GetAccountByUsernameDomain(ctx, *form.ContactUsername, "") + contactAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, *form.ContactUsername, "") if err != nil { return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("account with username %s not retrievable", *form.ContactUsername)) } // make sure it has a user associated with it - contactUser, err := p.db.GetUserByAccountID(ctx, contactAccount.ID) + contactUser, err := p.state.DB.GetUserByAccountID(ctx, contactAccount.ID) if err != nil { return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("user for account with username %s not retrievable", *form.ContactUsername)) } @@ -233,7 +233,7 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe } else if form.AvatarDescription != nil && ia.AvatarMediaAttachment != nil { // process just the description for the existing avatar ia.AvatarMediaAttachment.Description = *form.AvatarDescription - if err := p.db.UpdateByID(ctx, ia.AvatarMediaAttachment, ia.AvatarMediaAttachmentID, "description"); err != nil { + if err := p.state.DB.UpdateByID(ctx, ia.AvatarMediaAttachment, ia.AvatarMediaAttachmentID, "description"); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance avatar description: %s", err)) } } @@ -252,13 +252,13 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe if updateInstanceAccount { // if either avatar or header is updated, we need // to update the instance account that stores them - if err := p.db.UpdateAccount(ctx, ia); err != nil { + if err := p.state.DB.UpdateAccount(ctx, ia); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance account: %s", err)) } } if len(updatingColumns) != 0 { - if err := p.db.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil { + if err := p.state.DB.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance %s: %s", host, err)) } } diff --git a/internal/processing/media/delete.go b/internal/processing/media/delete.go index 6507fcae4..02bd6cd0d 100644 --- a/internal/processing/media/delete.go +++ b/internal/processing/media/delete.go @@ -13,7 +13,7 @@ // Delete deletes the media attachment with the given ID, including all files pertaining to that attachment. func (p *Processor) Delete(ctx context.Context, mediaAttachmentID string) gtserror.WithCode { - attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID) + attachment, err := p.state.DB.GetAttachmentByID(ctx, mediaAttachmentID) if err != nil { if err == db.ErrNoEntries { // attachment already gone @@ -27,20 +27,20 @@ func (p *Processor) Delete(ctx context.Context, mediaAttachmentID string) gtserr // delete the thumbnail from storage if attachment.Thumbnail.Path != "" { - if err := p.storage.Delete(ctx, attachment.Thumbnail.Path); err != nil && !errors.Is(err, storage.ErrNotFound) { + if err := p.state.Storage.Delete(ctx, attachment.Thumbnail.Path); err != nil && !errors.Is(err, storage.ErrNotFound) { errs = append(errs, fmt.Sprintf("remove thumbnail at path %s: %s", attachment.Thumbnail.Path, err)) } } // delete the file from storage if attachment.File.Path != "" { - if err := p.storage.Delete(ctx, attachment.File.Path); err != nil && !errors.Is(err, storage.ErrNotFound) { + if err := p.state.Storage.Delete(ctx, attachment.File.Path); err != nil && !errors.Is(err, storage.ErrNotFound) { errs = append(errs, fmt.Sprintf("remove file at path %s: %s", attachment.File.Path, err)) } } // delete the attachment - if err := p.db.DeleteByID(ctx, mediaAttachmentID, attachment); err != nil && !errors.Is(err, db.ErrNoEntries) { + if err := p.state.DB.DeleteByID(ctx, mediaAttachmentID, attachment); err != nil && !errors.Is(err, db.ErrNoEntries) { errs = append(errs, fmt.Sprintf("remove attachment: %s", err)) } diff --git a/internal/processing/media/getemoji.go b/internal/processing/media/getemoji.go index 4c0ce9930..fba059f60 100644 --- a/internal/processing/media/getemoji.go +++ b/internal/processing/media/getemoji.go @@ -31,7 +31,7 @@ // GetCustomEmojis returns a list of all useable local custom emojis stored on this instance. // 'useable' in this context means visible and picker, and not disabled. func (p *Processor) GetCustomEmojis(ctx context.Context) ([]*apimodel.Emoji, gtserror.WithCode) { - emojis, err := p.db.GetUseableEmojis(ctx) + emojis, err := p.state.DB.GetUseableEmojis(ctx) if err != nil { if err != db.ErrNoEntries { return nil, gtserror.NewErrorNotFound(fmt.Errorf("db error retrieving custom emojis: %s", err)) diff --git a/internal/processing/media/getfile.go b/internal/processing/media/getfile.go index 2a4ef2097..f9c6c23c2 100644 --- a/internal/processing/media/getfile.go +++ b/internal/processing/media/getfile.go @@ -54,7 +54,7 @@ func (p *Processor) GetFile(ctx context.Context, requestingAccount *gtsmodel.Acc owningAccountID := form.AccountID // get the account that owns the media and make sure it's not suspended - owningAccount, err := p.db.GetAccountByID(ctx, owningAccountID) + owningAccount, err := p.state.DB.GetAccountByID(ctx, owningAccountID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("account with id %s could not be selected from the db: %s", owningAccountID, err)) } @@ -64,7 +64,7 @@ func (p *Processor) GetFile(ctx context.Context, requestingAccount *gtsmodel.Acc // make sure the requesting account and the media account don't block each other if requestingAccount != nil { - blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, owningAccountID, true) + blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, owningAccountID, true) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("block status could not be established between accounts %s and %s: %s", owningAccountID, requestingAccount.ID, err)) } @@ -117,7 +117,7 @@ func parseSize(s string) (media.Size, error) { func (p *Processor) getAttachmentContent(ctx context.Context, requestingAccount *gtsmodel.Account, wantedMediaID string, owningAccountID string, mediaSize media.Size) (*apimodel.Content, gtserror.WithCode) { // retrieve attachment from the database and do basic checks on it - a, err := p.db.GetAttachmentByID(ctx, wantedMediaID) + a, err := p.state.DB.GetAttachmentByID(ctx, wantedMediaID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("attachment %s could not be taken from the db: %s", wantedMediaID, err)) } @@ -209,7 +209,7 @@ func (p *Processor) getEmojiContent(ctx context.Context, fileName string, owning // so this is more reliable than using full size url imageStaticURL := uris.GenerateURIForAttachment(owningAccountID, string(media.TypeEmoji), string(media.SizeStatic), fileName, "png") - e, err := p.db.GetEmojiByStaticURL(ctx, imageStaticURL) + e, err := p.state.DB.GetEmojiByStaticURL(ctx, imageStaticURL) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("emoji %s could not be taken from the db: %s", fileName, err)) } @@ -237,12 +237,12 @@ func (p *Processor) getEmojiContent(ctx context.Context, fileName string, owning func (p *Processor) retrieveFromStorage(ctx context.Context, storagePath string, content *apimodel.Content) (*apimodel.Content, gtserror.WithCode) { // If running on S3 storage with proxying disabled then // just fetch a pre-signed URL instead of serving the content. - if url := p.storage.URL(ctx, storagePath); url != nil { + if url := p.state.Storage.URL(ctx, storagePath); url != nil { content.URL = url return content, nil } - reader, err := p.storage.GetStream(ctx, storagePath) + reader, err := p.state.Storage.GetStream(ctx, storagePath) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error retrieving from storage: %s", err)) } diff --git a/internal/processing/media/getmedia.go b/internal/processing/media/getmedia.go index 03d5ba770..dad6ac538 100644 --- a/internal/processing/media/getmedia.go +++ b/internal/processing/media/getmedia.go @@ -30,7 +30,7 @@ ) func (p *Processor) Get(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) { - attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID) + attachment, err := p.state.DB.GetAttachmentByID(ctx, mediaAttachmentID) if err != nil { if err == db.ErrNoEntries { // attachment doesn't exist diff --git a/internal/processing/media/media.go b/internal/processing/media/media.go index ca95e276f..51585102a 100644 --- a/internal/processing/media/media.go +++ b/internal/processing/media/media.go @@ -19,28 +19,25 @@ package media import ( - "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/storage" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/internal/typeutils" ) type Processor struct { + state *state.State tc typeutils.TypeConverter mediaManager media.Manager transportController transport.Controller - storage *storage.Driver - db db.DB } // New returns a new media processor. -func New(db db.DB, tc typeutils.TypeConverter, mediaManager media.Manager, transportController transport.Controller, storage *storage.Driver) Processor { +func New(state *state.State, tc typeutils.TypeConverter, mediaManager media.Manager, transportController transport.Controller) Processor { return Processor{ + state: state, tc: tc, mediaManager: mediaManager, transportController: transportController, - storage: storage, - db: db, } } diff --git a/internal/processing/media/media_test.go b/internal/processing/media/media_test.go index 1d223a66c..e706dbd7a 100644 --- a/internal/processing/media/media_test.go +++ b/internal/processing/media/media_test.go @@ -20,12 +20,11 @@ import ( "github.com/stretchr/testify/suite" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" mediaprocessing "github.com/superseriousbusiness/gotosocial/internal/processing/media" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/internal/typeutils" @@ -38,6 +37,7 @@ type MediaStandardTestSuite struct { db db.DB tc typeutils.TypeConverter storage *storage.Driver + state state.State mediaManager media.Manager transportController transport.Controller @@ -67,15 +67,19 @@ func (suite *MediaStandardTestSuite) SetupSuite() { } func (suite *MediaStandardTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.InitTestConfig() testrig.InitTestLog() - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&suite.state) + suite.state.DB = suite.db suite.tc = testrig.NewTestTypeConverter(suite.db) suite.storage = testrig.NewInMemoryStorage() - suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) - suite.transportController = testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, concurrency.NewWorkerPool[messages.FromFederator](-1, -1)) - suite.mediaProcessor = mediaprocessing.New(suite.db, suite.tc, suite.mediaManager, suite.transportController, suite.storage) + suite.state.Storage = suite.storage + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) + suite.transportController = testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media")) + suite.mediaProcessor = mediaprocessing.New(&suite.state, suite.tc, suite.mediaManager, suite.transportController) testrig.StandardDBSetup(suite.db, nil) testrig.StandardStorageSetup(suite.storage, "../../../testrig/media") } diff --git a/internal/processing/media/unattach.go b/internal/processing/media/unattach.go index 816b5134e..7c6f7dbac 100644 --- a/internal/processing/media/unattach.go +++ b/internal/processing/media/unattach.go @@ -33,7 +33,7 @@ // Unattach unattaches the media attachment with the given ID from any statuses it was attached to, making it available // for reattachment again. func (p *Processor) Unattach(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string) (*apimodel.Attachment, gtserror.WithCode) { - attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID) + attachment, err := p.state.DB.GetAttachmentByID(ctx, mediaAttachmentID) if err != nil { if err == db.ErrNoEntries { return nil, gtserror.NewErrorNotFound(errors.New("attachment doesn't exist in the db")) @@ -49,7 +49,7 @@ func (p *Processor) Unattach(ctx context.Context, account *gtsmodel.Account, med attachment.UpdatedAt = time.Now() attachment.StatusID = "" - if err := p.db.UpdateByID(ctx, attachment, attachment.ID, updatingColumns...); err != nil { + if err := p.state.DB.UpdateByID(ctx, attachment, attachment.ID, updatingColumns...); err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("db error updating attachment: %s", err)) } diff --git a/internal/processing/media/update.go b/internal/processing/media/update.go index c03df705b..cf49168f0 100644 --- a/internal/processing/media/update.go +++ b/internal/processing/media/update.go @@ -32,7 +32,7 @@ // Update updates a media attachment with the given id, using the provided form parameters. func (p *Processor) Update(ctx context.Context, account *gtsmodel.Account, mediaAttachmentID string, form *apimodel.AttachmentUpdateRequest) (*apimodel.Attachment, gtserror.WithCode) { - attachment, err := p.db.GetAttachmentByID(ctx, mediaAttachmentID) + attachment, err := p.state.DB.GetAttachmentByID(ctx, mediaAttachmentID) if err != nil { if err == db.ErrNoEntries { // attachment doesn't exist @@ -62,7 +62,7 @@ func (p *Processor) Update(ctx context.Context, account *gtsmodel.Account, media updatingColumns = append(updatingColumns, "focus_x", "focus_y") } - if err := p.db.UpdateByID(ctx, attachment, attachment.ID, updatingColumns...); err != nil { + if err := p.state.DB.UpdateByID(ctx, attachment, attachment.ID, updatingColumns...); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error updating media: %s", err)) } diff --git a/internal/processing/notification.go b/internal/processing/notification.go index 05d0e82ee..57100e743 100644 --- a/internal/processing/notification.go +++ b/internal/processing/notification.go @@ -29,7 +29,7 @@ ) func (p *Processor) NotificationsGet(ctx context.Context, authed *oauth.Auth, excludeTypes []string, limit int, maxID string, sinceID string) (*apimodel.PageableResponse, gtserror.WithCode) { - notifs, err := p.db.GetNotifications(ctx, authed.Account.ID, excludeTypes, limit, maxID, sinceID) + notifs, err := p.state.DB.GetNotifications(ctx, authed.Account.ID, excludeTypes, limit, maxID, sinceID) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -72,7 +72,7 @@ func (p *Processor) NotificationsGet(ctx context.Context, authed *oauth.Auth, ex } func (p *Processor) NotificationsClear(ctx context.Context, authed *oauth.Auth) gtserror.WithCode { - err := p.db.ClearNotifications(ctx, authed.Account.ID) + err := p.state.DB.ClearNotifications(ctx, authed.Account.ID) if err != nil { return gtserror.NewErrorInternalError(err) } diff --git a/internal/processing/processor.go b/internal/processing/processor.go index 07fcdb8b3..bb75aab76 100644 --- a/internal/processing/processor.go +++ b/internal/processing/processor.go @@ -19,10 +19,11 @@ package processing import ( - "github.com/superseriousbusiness/gotosocial/internal/concurrency" - "github.com/superseriousbusiness/gotosocial/internal/db" + "context" + "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" + "github.com/superseriousbusiness/gotosocial/internal/log" mm "github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/oauth" @@ -34,23 +35,19 @@ "github.com/superseriousbusiness/gotosocial/internal/processing/status" "github.com/superseriousbusiness/gotosocial/internal/processing/stream" "github.com/superseriousbusiness/gotosocial/internal/processing/user" - "github.com/superseriousbusiness/gotosocial/internal/storage" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/timeline" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/visibility" ) type Processor struct { - clientWorker *concurrency.WorkerPool[messages.FromClientAPI] - fedWorker *concurrency.WorkerPool[messages.FromFederator] - federator federation.Federator tc typeutils.TypeConverter oauthServer oauth.Server mediaManager mm.Manager - storage *storage.Driver statusTimelines timeline.Manager - db db.DB + state *state.State filter visibility.Filter /* @@ -105,76 +102,65 @@ func NewProcessor( federator federation.Federator, oauthServer oauth.Server, mediaManager mm.Manager, - storage *storage.Driver, - db db.DB, + state *state.State, emailSender email.Sender, - clientWorker *concurrency.WorkerPool[messages.FromClientAPI], - fedWorker *concurrency.WorkerPool[messages.FromFederator], ) *Processor { - parseMentionFunc := GetParseMentionFunc(db, federator) + parseMentionFunc := GetParseMentionFunc(state.DB, federator) - filter := visibility.NewFilter(db) + filter := visibility.NewFilter(state.DB) - return &Processor{ - clientWorker: clientWorker, - fedWorker: fedWorker, - - federator: federator, - tc: tc, - oauthServer: oauthServer, - mediaManager: mediaManager, - storage: storage, - statusTimelines: timeline.NewManager(StatusGrabFunction(db), StatusFilterFunction(db, filter), StatusPrepareFunction(db, tc), StatusSkipInsertFunction()), - db: db, - filter: filter, - - // sub processors - account: account.New(db, tc, mediaManager, oauthServer, clientWorker, federator, parseMentionFunc), - admin: admin.New(db, tc, mediaManager, federator.TransportController(), storage, clientWorker), - fedi: fedi.New(db, tc, federator), - media: media.New(db, tc, mediaManager, federator.TransportController(), storage), - report: report.New(db, tc, clientWorker), - status: status.New(db, tc, clientWorker, parseMentionFunc), - stream: stream.New(db, oauthServer), - user: user.New(db, emailSender), + processor := &Processor{ + federator: federator, + tc: tc, + oauthServer: oauthServer, + mediaManager: mediaManager, + statusTimelines: timeline.NewManager( + StatusGrabFunction(state.DB), + StatusFilterFunction(state.DB, filter), + StatusPrepareFunction(state.DB, tc), + StatusSkipInsertFunction(), + ), + state: state, + filter: filter, } + + // sub processors + processor.account = account.New(state, tc, mediaManager, oauthServer, federator, parseMentionFunc) + processor.admin = admin.New(state, tc, mediaManager, federator.TransportController()) + processor.fedi = fedi.New(state, tc, federator) + processor.media = media.New(state, tc, mediaManager, federator.TransportController()) + processor.report = report.New(state, tc) + processor.status = status.New(state, tc, parseMentionFunc) + processor.stream = stream.New(state, oauthServer) + processor.user = user.New(state, emailSender) + + return processor } -// Start starts the Processor, reading from its channels and passing messages back and forth. +func (p *Processor) EnqueueClientAPI(ctx context.Context, msg messages.FromClientAPI) { + log.WithContext(ctx).WithField("msg", msg).Trace("enqueuing client API") + _ = p.state.Workers.ClientAPI.MustEnqueueCtx(ctx, func(ctx context.Context) { + if err := p.ProcessFromClientAPI(ctx, msg); err != nil { + log.Errorf(ctx, "error processing client API message: %v", err) + } + }) +} + +func (p *Processor) EnqueueFederator(ctx context.Context, msg messages.FromFederator) { + log.WithContext(ctx).WithField("msg", msg).Trace("enqueuing federator") + _ = p.state.Workers.Federator.MustEnqueueCtx(ctx, func(ctx context.Context) { + if err := p.ProcessFromFederator(ctx, msg); err != nil { + log.Errorf(ctx, "error processing federator message: %v", err) + } + }) +} + +// Start starts the Processor. func (p *Processor) Start() error { - // Setup and start the client API worker pool - p.clientWorker.SetProcessor(p.ProcessFromClientAPI) - if err := p.clientWorker.Start(); err != nil { - return err - } - - // Setup and start the federator worker pool - p.fedWorker.SetProcessor(p.ProcessFromFederator) - if err := p.fedWorker.Start(); err != nil { - return err - } - - // Start status timelines - if err := p.statusTimelines.Start(); err != nil { - return err - } - - return nil + return p.statusTimelines.Start() } -// Stop stops the processor cleanly, finishing handling any remaining messages before closing down. +// Stop stops the processor cleanly. func (p *Processor) Stop() error { - if err := p.clientWorker.Stop(); err != nil { - return err - } - - if err := p.fedWorker.Stop(); err != nil { - return err - } - - if err := p.statusTimelines.Stop(); err != nil { - return err - } - - return nil + return p.statusTimelines.Stop() } diff --git a/internal/processing/processor_test.go b/internal/processing/processor_test.go index 44857cb47..d8da87bcc 100644 --- a/internal/processing/processor_test.go +++ b/internal/processing/processor_test.go @@ -20,15 +20,14 @@ import ( "github.com/stretchr/testify/suite" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/internal/typeutils" @@ -40,6 +39,7 @@ type ProcessingStandardTestSuite struct { suite.Suite db db.DB storage *storage.Driver + state state.State mediaManager media.Manager typeconverter typeutils.TypeConverter httpClient *testrig.MockHTTPClient @@ -86,25 +86,29 @@ func (suite *ProcessingStandardTestSuite) SetupSuite() { } func (suite *ProcessingStandardTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.StartWorkers(&suite.state) + testrig.InitTestConfig() testrig.InitTestLog() - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&suite.state) + suite.state.DB = suite.db suite.testActivities = testrig.NewTestActivities(suite.testAccounts) suite.storage = testrig.NewInMemoryStorage() + suite.state.Storage = suite.storage suite.typeconverter = testrig.NewTestTypeConverter(suite.db) suite.httpClient = testrig.NewMockHTTPClient(nil, "../../testrig/media") - clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - - suite.transportController = testrig.NewTestTransportController(suite.httpClient, suite.db, fedWorker) - suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) - suite.federator = testrig.NewTestFederator(suite.db, suite.transportController, suite.storage, suite.mediaManager, fedWorker) + suite.transportController = testrig.NewTestTransportController(&suite.state, suite.httpClient) + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) + suite.federator = testrig.NewTestFederator(&suite.state, suite.transportController, suite.mediaManager) suite.oauthServer = testrig.NewTestOauthServer(suite.db) suite.emailSender = testrig.NewEmailSender("../../web/template/", nil) - suite.processor = processing.NewProcessor(suite.typeconverter, suite.federator, suite.oauthServer, suite.mediaManager, suite.storage, suite.db, suite.emailSender, clientWorker, fedWorker) + suite.processor = processing.NewProcessor(suite.typeconverter, suite.federator, suite.oauthServer, suite.mediaManager, &suite.state, suite.emailSender) + suite.state.Workers.EnqueueClientAPI = suite.processor.EnqueueClientAPI + suite.state.Workers.EnqueueFederator = suite.processor.EnqueueFederator testrig.StandardDBSetup(suite.db, suite.testAccounts) testrig.StandardStorageSetup(suite.storage, "../../testrig/media") @@ -119,4 +123,5 @@ func (suite *ProcessingStandardTestSuite) TearDownTest() { if err := suite.processor.Stop(); err != nil { panic(err) } + testrig.StopWorkers(&suite.state) } diff --git a/internal/processing/report/create.go b/internal/processing/report/create.go index 726d11666..e0918554e 100644 --- a/internal/processing/report/create.go +++ b/internal/processing/report/create.go @@ -41,7 +41,7 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, form } // validate + fetch target account - targetAccount, err := p.db.GetAccountByID(ctx, form.AccountID) + targetAccount, err := p.state.DB.GetAccountByID(ctx, form.AccountID) if err != nil { if errors.Is(err, db.ErrNoEntries) { err = fmt.Errorf("account with ID %s does not exist", form.AccountID) @@ -52,7 +52,7 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, form } // fetch statuses by IDs given in the report form (noop if no statuses given) - statuses, err := p.db.GetStatuses(ctx, form.StatusIDs) + statuses, err := p.state.DB.GetStatuses(ctx, form.StatusIDs) if err != nil { err = fmt.Errorf("db error fetching report target statuses: %w", err) return nil, gtserror.NewErrorInternalError(err) @@ -79,11 +79,11 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, form Forwarded: &form.Forward, } - if err := p.db.PutReport(ctx, report); err != nil { + if err := p.state.DB.PutReport(ctx, report); err != nil { return nil, gtserror.NewErrorInternalError(err) } - p.clientWorker.Queue(messages.FromClientAPI{ + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ObjectProfile, APActivityType: ap.ActivityFlag, GTSModel: report, diff --git a/internal/processing/report/get.go b/internal/processing/report/get.go index af2079b8a..0348c397c 100644 --- a/internal/processing/report/get.go +++ b/internal/processing/report/get.go @@ -32,7 +32,7 @@ // Get returns the user view of a moderation report, with the given id. func (p *Processor) Get(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.Report, gtserror.WithCode) { - report, err := p.db.GetReportByID(ctx, id) + report, err := p.state.DB.GetReportByID(ctx, id) if err != nil { if err == db.ErrNoEntries { return nil, gtserror.NewErrorNotFound(err) @@ -64,7 +64,7 @@ func (p *Processor) GetMultiple( minID string, limit int, ) (*apimodel.PageableResponse, gtserror.WithCode) { - reports, err := p.db.GetReports(ctx, resolved, account.ID, targetAccountID, maxID, sinceID, minID, limit) + reports, err := p.state.DB.GetReports(ctx, resolved, account.ID, targetAccountID, maxID, sinceID, minID, limit) if err != nil { if err == db.ErrNoEntries { return util.EmptyPageableResponse(), nil diff --git a/internal/processing/report/report.go b/internal/processing/report/report.go index b5f4b301e..bc634af2e 100644 --- a/internal/processing/report/report.go +++ b/internal/processing/report/report.go @@ -19,22 +19,18 @@ package report import ( - "github.com/superseriousbusiness/gotosocial/internal/concurrency" - "github.com/superseriousbusiness/gotosocial/internal/db" - "github.com/superseriousbusiness/gotosocial/internal/messages" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/typeutils" ) type Processor struct { - db db.DB - tc typeutils.TypeConverter - clientWorker *concurrency.WorkerPool[messages.FromClientAPI] + state *state.State + tc typeutils.TypeConverter } -func New(db db.DB, tc typeutils.TypeConverter, clientWorker *concurrency.WorkerPool[messages.FromClientAPI]) Processor { +func New(state *state.State, tc typeutils.TypeConverter) Processor { return Processor{ - tc: tc, - db: db, - clientWorker: clientWorker, + state: state, + tc: tc, } } diff --git a/internal/processing/search.go b/internal/processing/search.go index 05a1fe353..c5592fffd 100644 --- a/internal/processing/search.go +++ b/internal/processing/search.go @@ -88,7 +88,7 @@ func (p *Processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a if username, domain, err := util.ExtractNamestringParts(maybeNamestring); err == nil { l.Trace("search term is a mention, looking it up...") - blocked, err := p.db.IsDomainBlocked(ctx, domain) + blocked, err := p.state.DB.IsDomainBlocked(ctx, domain) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking domain block: %w", err)) } @@ -120,7 +120,7 @@ func (p *Processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a if uri, err := url.Parse(query); err == nil { if uri.Scheme == "https" || uri.Scheme == "http" { l.Trace("search term is a uri, looking it up...") - blocked, err := p.db.IsURIBlocked(ctx, uri) + blocked, err := p.state.DB.IsURIBlocked(ctx, uri) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking domain block: %w", err)) } @@ -178,7 +178,7 @@ func (p *Processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a */ for _, foundAccount := range foundAccounts { // make sure there's no block in either direction between the account and the requester - blocked, err := p.db.IsBlocked(ctx, authed.Account.ID, foundAccount.ID, true) + blocked, err := p.state.DB.IsBlocked(ctx, authed.Account.ID, foundAccount.ID, true) if err != nil { err = fmt.Errorf("SearchGet: error checking block between %s and %s: %s", authed.Account.ID, foundAccount.ID, err) return nil, gtserror.NewErrorInternalError(err) @@ -246,14 +246,14 @@ func (p *Processor) searchAccountByURI(ctx context.Context, authed *oauth.Auth, ) // Search the database for existing account with ID URI. - account, err = p.db.GetAccountByURI(ctx, uriStr) + account, err = p.state.DB.GetAccountByURI(ctx, uriStr) if err != nil && !errors.Is(err, db.ErrNoEntries) { return nil, fmt.Errorf("searchAccountByURI: error checking database for account %s: %w", uriStr, err) } if account == nil { // Else, search the database for existing by ID URL. - account, err = p.db.GetAccountByURL(ctx, uriStr) + account, err = p.state.DB.GetAccountByURL(ctx, uriStr) if err != nil { if !errors.Is(err, db.ErrNoEntries) { return nil, fmt.Errorf("searchAccountByURI: error checking database for account %s: %w", uriStr, err) @@ -281,7 +281,7 @@ func (p *Processor) searchAccountByUsernameDomain(ctx context.Context, authed *o } // Search the database for existing account with USERNAME@DOMAIN - account, err := p.db.GetAccountByUsernameDomain(ctx, username, domain) + account, err := p.state.DB.GetAccountByUsernameDomain(ctx, username, domain) if err != nil { if !errors.Is(err, db.ErrNoEntries) { return nil, fmt.Errorf("searchAccountByUsernameDomain: error checking database for account %s@%s: %w", username, domain, err) diff --git a/internal/processing/status/bookmark.go b/internal/processing/status/bookmark.go index dde31ea7d..cf3787da2 100644 --- a/internal/processing/status/bookmark.go +++ b/internal/processing/status/bookmark.go @@ -32,7 +32,7 @@ // BookmarkCreate adds a bookmark for the requestingAccount, targeting the given status (no-op if bookmark already exists). func (p *Processor) BookmarkCreate(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) + targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -50,7 +50,7 @@ func (p *Processor) BookmarkCreate(ctx context.Context, requestingAccount *gtsmo // first check if the status is already bookmarked, if so we don't need to do anything newBookmark := true gtsBookmark := >smodel.StatusBookmark{} - if err := p.db.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err == nil { + if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err == nil { // we already have a bookmark for this status newBookmark = false } @@ -67,7 +67,7 @@ func (p *Processor) BookmarkCreate(ctx context.Context, requestingAccount *gtsmo Status: targetStatus, } - if err := p.db.Put(ctx, gtsBookmark); err != nil { + if err := p.state.DB.Put(ctx, gtsBookmark); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error putting bookmark in database: %s", err)) } } @@ -83,7 +83,7 @@ func (p *Processor) BookmarkCreate(ctx context.Context, requestingAccount *gtsmo // BookmarkRemove removes a bookmark for the requesting account, targeting the given status (no-op if bookmark doesn't exist). func (p *Processor) BookmarkRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) + targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -101,13 +101,13 @@ func (p *Processor) BookmarkRemove(ctx context.Context, requestingAccount *gtsmo // first check if the status is actually bookmarked toUnbookmark := false gtsBookmark := >smodel.StatusBookmark{} - if err := p.db.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err == nil { + if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err == nil { // we have a bookmark for this status toUnbookmark = true } if toUnbookmark { - if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err != nil { + if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsBookmark); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error unfaveing status: %s", err)) } } diff --git a/internal/processing/status/boost.go b/internal/processing/status/boost.go index 4dfe17019..6756d816c 100644 --- a/internal/processing/status/boost.go +++ b/internal/processing/status/boost.go @@ -33,7 +33,7 @@ // BoostCreate processes the boost/reblog of a given status, returning the newly-created boost if all is well. func (p *Processor) BoostCreate(ctx context.Context, requestingAccount *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) + targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -47,7 +47,7 @@ func (p *Processor) BoostCreate(ctx context.Context, requestingAccount *gtsmodel // boost boosts, and it looks absolutely bizarre in the UI if targetStatus.BoostOfID != "" { if targetStatus.BoostOf == nil { - b, err := p.db.GetStatusByID(ctx, targetStatus.BoostOfID) + b, err := p.state.DB.GetStatusByID(ctx, targetStatus.BoostOfID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("couldn't fetch boosted status %s", targetStatus.BoostOfID)) } @@ -74,12 +74,12 @@ func (p *Processor) BoostCreate(ctx context.Context, requestingAccount *gtsmodel boostWrapperStatus.BoostOfAccount = targetStatus.Account // put the boost in the database - if err := p.db.PutStatus(ctx, boostWrapperStatus); err != nil { + if err := p.state.DB.PutStatus(ctx, boostWrapperStatus); err != nil { return nil, gtserror.NewErrorInternalError(err) } // send it back to the processor for async processing - p.clientWorker.Queue(messages.FromClientAPI{ + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ActivityAnnounce, APActivityType: ap.ActivityCreate, GTSModel: boostWrapperStatus, @@ -98,7 +98,7 @@ func (p *Processor) BoostCreate(ctx context.Context, requestingAccount *gtsmodel // BoostRemove processes the unboost/unreblog of a given status, returning the status if all is well. func (p *Processor) BoostRemove(ctx context.Context, requestingAccount *gtsmodel.Account, application *gtsmodel.Application, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) + targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -128,7 +128,7 @@ func (p *Processor) BoostRemove(ctx context.Context, requestingAccount *gtsmodel Value: requestingAccount.ID, }, } - err = p.db.GetWhere(ctx, where, gtsBoost) + err = p.state.DB.GetWhere(ctx, where, gtsBoost) if err == nil { // we have a boost toUnboost = true @@ -151,7 +151,7 @@ func (p *Processor) BoostRemove(ctx context.Context, requestingAccount *gtsmodel gtsBoost.BoostOf.Account = targetStatus.Account // send it back to the processor for async processing - p.clientWorker.Queue(messages.FromClientAPI{ + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ActivityAnnounce, APActivityType: ap.ActivityUndo, GTSModel: gtsBoost, @@ -170,7 +170,7 @@ func (p *Processor) BoostRemove(ctx context.Context, requestingAccount *gtsmodel // StatusBoostedBy returns a slice of accounts that have boosted the given status, filtered according to privacy settings. func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) + targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID) if err != nil { wrapped := fmt.Errorf("BoostedBy: error fetching status %s: %s", targetStatusID, err) if !errors.Is(err, db.ErrNoEntries) { @@ -181,7 +181,7 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm if boostOfID := targetStatus.BoostOfID; boostOfID != "" { // the target status is a boost wrapper, redirect this request to the status it boosts - boostedStatus, err := p.db.GetStatusByID(ctx, boostOfID) + boostedStatus, err := p.state.DB.GetStatusByID(ctx, boostOfID) if err != nil { wrapped := fmt.Errorf("BoostedBy: error fetching status %s: %s", boostOfID, err) if !errors.Is(err, db.ErrNoEntries) { @@ -202,7 +202,7 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm return nil, gtserror.NewErrorNotFound(err) } - statusReblogs, err := p.db.GetStatusReblogs(ctx, targetStatus) + statusReblogs, err := p.state.DB.GetStatusReblogs(ctx, targetStatus) if err != nil { err = fmt.Errorf("BoostedBy: error seeing who boosted status: %s", err) return nil, gtserror.NewErrorNotFound(err) @@ -211,7 +211,7 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm // filter account IDs so the user doesn't see accounts they blocked or which blocked them accountIDs := make([]string, 0, len(statusReblogs)) for _, s := range statusReblogs { - blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, s.AccountID, true) + blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, s.AccountID, true) if err != nil { err = fmt.Errorf("BoostedBy: error checking blocks: %s", err) return nil, gtserror.NewErrorNotFound(err) @@ -226,7 +226,7 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm // fetch accounts + create their API representations apiAccounts := make([]*apimodel.Account, 0, len(accountIDs)) for _, accountID := range accountIDs { - account, err := p.db.GetAccountByID(ctx, accountID) + account, err := p.state.DB.GetAccountByID(ctx, accountID) if err != nil { wrapped := fmt.Errorf("BoostedBy: error fetching account %s: %s", accountID, err) if !errors.Is(err, db.ErrNoEntries) { diff --git a/internal/processing/status/create.go b/internal/processing/status/create.go index f47c850dd..4e5399469 100644 --- a/internal/processing/status/create.go +++ b/internal/processing/status/create.go @@ -61,11 +61,11 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, appli Text: form.Status, } - if errWithCode := processReplyToID(ctx, p.db, form, account.ID, newStatus); errWithCode != nil { + if errWithCode := processReplyToID(ctx, p.state.DB, form, account.ID, newStatus); errWithCode != nil { return nil, errWithCode } - if errWithCode := processMediaIDs(ctx, p.db, form, account.ID, newStatus); errWithCode != nil { + if errWithCode := processMediaIDs(ctx, p.state.DB, form, account.ID, newStatus); errWithCode != nil { return nil, errWithCode } @@ -77,17 +77,17 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, appli return nil, gtserror.NewErrorInternalError(err) } - if err := processContent(ctx, p.db, p.formatter, p.parseMention, form, account.ID, newStatus); err != nil { + if err := processContent(ctx, p.state.DB, p.formatter, p.parseMention, form, account.ID, newStatus); err != nil { return nil, gtserror.NewErrorInternalError(err) } // put the new status in the database - if err := p.db.PutStatus(ctx, newStatus); err != nil { + if err := p.state.DB.PutStatus(ctx, newStatus); err != nil { return nil, gtserror.NewErrorInternalError(err) } // send it back to the processor for async processing - p.clientWorker.Queue(messages.FromClientAPI{ + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ObjectNote, APActivityType: ap.ActivityCreate, GTSModel: newStatus, diff --git a/internal/processing/status/delete.go b/internal/processing/status/delete.go index d3a03aad6..0e9510e08 100644 --- a/internal/processing/status/delete.go +++ b/internal/processing/status/delete.go @@ -32,7 +32,7 @@ // Delete processes the delete of a given status, returning the deleted status if the delete goes through. func (p *Processor) Delete(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) + targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -50,7 +50,7 @@ func (p *Processor) Delete(ctx context.Context, requestingAccount *gtsmodel.Acco } // send the status back to the processor for async processing - p.clientWorker.Queue(messages.FromClientAPI{ + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ObjectNote, APActivityType: ap.ActivityDelete, GTSModel: targetStatus, diff --git a/internal/processing/status/fave.go b/internal/processing/status/fave.go index 3bcb1835f..3025c720d 100644 --- a/internal/processing/status/fave.go +++ b/internal/processing/status/fave.go @@ -35,7 +35,7 @@ // FaveCreate processes the faving of a given status, returning the updated status if the fave goes through. func (p *Processor) FaveCreate(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) + targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -57,7 +57,7 @@ func (p *Processor) FaveCreate(ctx context.Context, requestingAccount *gtsmodel. // first check if the status is already faved, if so we don't need to do anything newFave := true gtsFave := >smodel.StatusFave{} - if err := p.db.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave); err == nil { + if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave); err == nil { // we already have a fave for this status newFave = false } @@ -77,12 +77,12 @@ func (p *Processor) FaveCreate(ctx context.Context, requestingAccount *gtsmodel. URI: uris.GenerateURIForLike(requestingAccount.Username, thisFaveID), } - if err := p.db.Put(ctx, gtsFave); err != nil { + if err := p.state.DB.Put(ctx, gtsFave); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error putting fave in database: %s", err)) } // send it back to the processor for async processing - p.clientWorker.Queue(messages.FromClientAPI{ + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ActivityLike, APActivityType: ap.ActivityCreate, GTSModel: gtsFave, @@ -102,7 +102,7 @@ func (p *Processor) FaveCreate(ctx context.Context, requestingAccount *gtsmodel. // FaveRemove processes the unfaving of a given status, returning the updated status if the fave goes through. func (p *Processor) FaveRemove(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) + targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -122,7 +122,7 @@ func (p *Processor) FaveRemove(ctx context.Context, requestingAccount *gtsmodel. var toUnfave bool gtsFave := >smodel.StatusFave{} - err = p.db.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave) + err = p.state.DB.GetWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave) if err == nil { // we have a fave toUnfave = true @@ -138,12 +138,12 @@ func (p *Processor) FaveRemove(ctx context.Context, requestingAccount *gtsmodel. if toUnfave { // we had a fave, so take some action to get rid of it - if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave); err != nil { + if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "status_id", Value: targetStatus.ID}, {Key: "account_id", Value: requestingAccount.ID}}, gtsFave); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error unfaveing status: %s", err)) } // send it back to the processor for async processing - p.clientWorker.Queue(messages.FromClientAPI{ + p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{ APObjectType: ap.ActivityLike, APActivityType: ap.ActivityUndo, GTSModel: gtsFave, @@ -162,7 +162,7 @@ func (p *Processor) FaveRemove(ctx context.Context, requestingAccount *gtsmodel. // FavedBy returns a slice of accounts that have liked the given status, filtered according to privacy settings. func (p *Processor) FavedBy(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) ([]*apimodel.Account, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) + targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -178,7 +178,7 @@ func (p *Processor) FavedBy(ctx context.Context, requestingAccount *gtsmodel.Acc return nil, gtserror.NewErrorNotFound(errors.New("status is not visible")) } - statusFaves, err := p.db.GetStatusFaves(ctx, targetStatus) + statusFaves, err := p.state.DB.GetStatusFaves(ctx, targetStatus) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error seeing who faved status: %s", err)) } @@ -186,7 +186,7 @@ func (p *Processor) FavedBy(ctx context.Context, requestingAccount *gtsmodel.Acc // filter the list so the user doesn't see accounts they blocked or which blocked them filteredAccounts := []*gtsmodel.Account{} for _, fave := range statusFaves { - blocked, err := p.db.IsBlocked(ctx, requestingAccount.ID, fave.AccountID, true) + blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, fave.AccountID, true) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking blocks: %s", err)) } diff --git a/internal/processing/status/get.go b/internal/processing/status/get.go index edefeb440..51c384c44 100644 --- a/internal/processing/status/get.go +++ b/internal/processing/status/get.go @@ -31,7 +31,7 @@ // Get gets the given status, taking account of privacy settings and blocks etc. func (p *Processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) + targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -57,7 +57,7 @@ func (p *Processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account // ContextGet returns the context (previous and following posts) from the given status ID. func (p *Processor) ContextGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Context, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) + targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID) if err != nil { return nil, gtserror.NewErrorNotFound(fmt.Errorf("error fetching status %s: %s", targetStatusID, err)) } @@ -78,7 +78,7 @@ func (p *Processor) ContextGet(ctx context.Context, requestingAccount *gtsmodel. Descendants: []apimodel.Status{}, } - parents, err := p.db.GetStatusParents(ctx, targetStatus, false) + parents, err := p.state.DB.GetStatusParents(ctx, targetStatus, false) if err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -96,7 +96,7 @@ func (p *Processor) ContextGet(ctx context.Context, requestingAccount *gtsmodel. return context.Ancestors[i].ID < context.Ancestors[j].ID }) - children, err := p.db.GetStatusChildren(ctx, targetStatus, false, "") + children, err := p.state.DB.GetStatusChildren(ctx, targetStatus, false, "") if err != nil { return nil, gtserror.NewErrorInternalError(err) } diff --git a/internal/processing/status/pin.go b/internal/processing/status/pin.go index 3e50b0c73..6001a147f 100644 --- a/internal/processing/status/pin.go +++ b/internal/processing/status/pin.go @@ -39,7 +39,7 @@ // - Status is public, unlisted, or followers-only. // - Status is not a boost. func (p *Processor) getPinnableStatus(ctx context.Context, targetStatusID string, requestingAccountID string) (*gtsmodel.Status, gtserror.WithCode) { - targetStatus, err := p.db.GetStatusByID(ctx, targetStatusID) + targetStatus, err := p.state.DB.GetStatusByID(ctx, targetStatusID) if err != nil { err = fmt.Errorf("error fetching status %s: %w", targetStatusID, err) return nil, gtserror.NewErrorNotFound(err) @@ -84,7 +84,7 @@ func (p *Processor) PinCreate(ctx context.Context, requestingAccount *gtsmodel.A return nil, gtserror.NewErrorUnprocessableEntity(err, err.Error()) } - pinnedCount, err := p.db.CountAccountPinned(ctx, requestingAccount.ID) + pinnedCount, err := p.state.DB.CountAccountPinned(ctx, requestingAccount.ID) if err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking number of pinned statuses: %w", err)) } @@ -95,7 +95,7 @@ func (p *Processor) PinCreate(ctx context.Context, requestingAccount *gtsmodel.A } targetStatus.PinnedAt = time.Now() - if err := p.db.UpdateStatus(ctx, targetStatus, "pinned_at"); err != nil { + if err := p.state.DB.UpdateStatus(ctx, targetStatus, "pinned_at"); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error pinning status: %w", err)) } @@ -126,7 +126,7 @@ func (p *Processor) PinRemove(ctx context.Context, requestingAccount *gtsmodel.A if targetStatus.PinnedAt.IsZero() { targetStatus.PinnedAt = time.Time{} - if err := p.db.UpdateStatus(ctx, targetStatus, "pinned_at"); err != nil { + if err := p.state.DB.UpdateStatus(ctx, targetStatus, "pinned_at"); err != nil { return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error unpinning status: %w", err)) } } diff --git a/internal/processing/status/status.go b/internal/processing/status/status.go index c91fd85d1..909b06481 100644 --- a/internal/processing/status/status.go +++ b/internal/processing/status/status.go @@ -19,32 +19,28 @@ package status import ( - "github.com/superseriousbusiness/gotosocial/internal/concurrency" - "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/messages" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/text" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/visibility" ) type Processor struct { + state *state.State tc typeutils.TypeConverter - db db.DB filter visibility.Filter formatter text.Formatter - clientWorker *concurrency.WorkerPool[messages.FromClientAPI] parseMention gtsmodel.ParseMentionFunc } // New returns a new status processor. -func New(db db.DB, tc typeutils.TypeConverter, clientWorker *concurrency.WorkerPool[messages.FromClientAPI], parseMention gtsmodel.ParseMentionFunc) Processor { +func New(state *state.State, tc typeutils.TypeConverter, parseMention gtsmodel.ParseMentionFunc) Processor { return Processor{ + state: state, tc: tc, - db: db, - filter: visibility.NewFilter(db), - formatter: text.NewFormatter(db), - clientWorker: clientWorker, + filter: visibility.NewFilter(state.DB), + formatter: text.NewFormatter(state.DB), parseMention: parseMention, } } diff --git a/internal/processing/status/status_test.go b/internal/processing/status/status_test.go index 272d2c8ea..1b35b69db 100644 --- a/internal/processing/status/status_test.go +++ b/internal/processing/status/status_test.go @@ -19,17 +19,14 @@ package status_test import ( - "context" - "github.com/stretchr/testify/suite" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/processing" "github.com/superseriousbusiness/gotosocial/internal/processing/status" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/storage" "github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/internal/typeutils" @@ -42,9 +39,9 @@ type StatusStandardTestSuite struct { typeConverter typeutils.TypeConverter tc transport.Controller storage *storage.Driver + state state.State mediaManager media.Manager federator federation.Federator - clientWorker *concurrency.WorkerPool[messages.FromClientAPI] // standard suite models testTokens map[string]*gtsmodel.Token @@ -74,21 +71,22 @@ func (suite *StatusStandardTestSuite) SetupSuite() { } func (suite *StatusStandardTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.StartWorkers(&suite.state) + testrig.InitTestConfig() testrig.InitTestLog() - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&suite.state) suite.typeConverter = testrig.NewTestTypeConverter(suite.db) - suite.clientWorker = concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1) - suite.tc = testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../testrig/media"), suite.db, fedWorker) + suite.state.DB = suite.db + + suite.tc = testrig.NewTestTransportController(&suite.state, testrig.NewMockHTTPClient(nil, "../../../testrig/media")) suite.storage = testrig.NewInMemoryStorage() - suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage) - suite.federator = testrig.NewTestFederator(suite.db, suite.tc, suite.storage, suite.mediaManager, fedWorker) - suite.status = status.New(suite.db, suite.typeConverter, suite.clientWorker, processing.GetParseMentionFunc(suite.db, suite.federator)) - suite.clientWorker.SetProcessor(func(ctx context.Context, msg messages.FromClientAPI) error { return nil }) - suite.NoError(suite.clientWorker.Start()) + suite.state.Storage = suite.storage + suite.mediaManager = testrig.NewTestMediaManager(&suite.state) + suite.federator = testrig.NewTestFederator(&suite.state, suite.tc, suite.mediaManager) + suite.status = status.New(&suite.state, suite.typeConverter, processing.GetParseMentionFunc(suite.db, suite.federator)) testrig.StandardDBSetup(suite.db, suite.testAccounts) testrig.StandardStorageSetup(suite.storage, "../../../testrig/media") @@ -97,4 +95,5 @@ func (suite *StatusStandardTestSuite) SetupTest() { func (suite *StatusStandardTestSuite) TearDownTest() { testrig.StandardDBTeardown(suite.db) testrig.StandardStorageTeardown(suite.storage) + testrig.StopWorkers(&suite.state) } diff --git a/internal/processing/statustimeline.go b/internal/processing/statustimeline.go index 7c9f36f16..8c8e20316 100644 --- a/internal/processing/statustimeline.go +++ b/internal/processing/statustimeline.go @@ -173,7 +173,7 @@ func (p *Processor) HomeTimelineGet(ctx context.Context, authed *oauth.Auth, max } func (p *Processor) PublicTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, sinceID string, minID string, limit int, local bool) (*apimodel.PageableResponse, gtserror.WithCode) { - statuses, err := p.db.GetPublicTimeline(ctx, maxID, sinceID, minID, limit, local) + statuses, err := p.state.DB.GetPublicTimeline(ctx, maxID, sinceID, minID, limit, local) if err != nil { if err == db.ErrNoEntries { // there are just no entries left @@ -218,7 +218,7 @@ func (p *Processor) PublicTimelineGet(ctx context.Context, authed *oauth.Auth, m } func (p *Processor) FavedTimelineGet(ctx context.Context, authed *oauth.Auth, maxID string, minID string, limit int) (*apimodel.PageableResponse, gtserror.WithCode) { - statuses, nextMaxID, prevMinID, err := p.db.GetFavedTimeline(ctx, authed.Account.ID, maxID, minID, limit) + statuses, nextMaxID, prevMinID, err := p.state.DB.GetFavedTimeline(ctx, authed.Account.ID, maxID, minID, limit) if err != nil { if err == db.ErrNoEntries { // there are just no entries left @@ -255,7 +255,7 @@ func (p *Processor) filterPublicStatuses(ctx context.Context, authed *oauth.Auth apiStatuses := []*apimodel.Status{} for _, s := range statuses { targetAccount := >smodel.Account{} - if err := p.db.GetByID(ctx, s.AccountID, targetAccount); err != nil { + if err := p.state.DB.GetByID(ctx, s.AccountID, targetAccount); err != nil { if err == db.ErrNoEntries { log.Debugf(ctx, "skipping status %s because account %s can't be found in the db", s.ID, s.AccountID) continue @@ -288,7 +288,7 @@ func (p *Processor) filterFavedStatuses(ctx context.Context, authed *oauth.Auth, apiStatuses := []*apimodel.Status{} for _, s := range statuses { targetAccount := >smodel.Account{} - if err := p.db.GetByID(ctx, s.AccountID, targetAccount); err != nil { + if err := p.state.DB.GetByID(ctx, s.AccountID, targetAccount); err != nil { if err == db.ErrNoEntries { log.Debugf(ctx, "skipping status %s because account %s can't be found in the db", s.ID, s.AccountID) continue diff --git a/internal/processing/stream/authorize.go b/internal/processing/stream/authorize.go index 5f6811db9..a30e6fb33 100644 --- a/internal/processing/stream/authorize.go +++ b/internal/processing/stream/authorize.go @@ -41,7 +41,7 @@ func (p *Processor) Authorize(ctx context.Context, accessToken string) (*gtsmode return nil, gtserror.NewErrorUnauthorized(err) } - user, err := p.db.GetUserByID(ctx, uid) + user, err := p.state.DB.GetUserByID(ctx, uid) if err != nil { if err == db.ErrNoEntries { err := fmt.Errorf("no user found for validated uid %s", uid) @@ -50,7 +50,7 @@ func (p *Processor) Authorize(ctx context.Context, accessToken string) (*gtsmode return nil, gtserror.NewErrorInternalError(err) } - acct, err := p.db.GetAccountByID(ctx, user.AccountID) + acct, err := p.state.DB.GetAccountByID(ctx, user.AccountID) if err != nil { if err == db.ErrNoEntries { err := fmt.Errorf("no account found for validated uid %s", uid) diff --git a/internal/processing/stream/stream.go b/internal/processing/stream/stream.go index 3c38e720a..a10ab2474 100644 --- a/internal/processing/stream/stream.go +++ b/internal/processing/stream/stream.go @@ -22,22 +22,21 @@ "errors" "sync" - "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/stream" ) type Processor struct { - db db.DB + state *state.State oauthServer oauth.Server - streamMap *sync.Map + streamMap sync.Map } -func New(db db.DB, oauthServer oauth.Server) Processor { +func New(state *state.State, oauthServer oauth.Server) Processor { return Processor{ - db: db, + state: state, oauthServer: oauthServer, - streamMap: &sync.Map{}, } } diff --git a/internal/processing/stream/stream_test.go b/internal/processing/stream/stream_test.go index 907c7e1d0..9e1eb57f2 100644 --- a/internal/processing/stream/stream_test.go +++ b/internal/processing/stream/stream_test.go @@ -24,6 +24,7 @@ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/processing/stream" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -33,19 +34,23 @@ type StreamTestSuite struct { testTokens map[string]*gtsmodel.Token db db.DB oauthServer oauth.Server + state state.State streamProcessor stream.Processor } func (suite *StreamTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.InitTestLog() testrig.InitTestConfig() suite.testAccounts = testrig.NewTestAccounts() suite.testTokens = testrig.NewTestTokens() - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&suite.state) + suite.state.DB = suite.db suite.oauthServer = testrig.NewTestOauthServer(suite.db) - suite.streamProcessor = stream.New(suite.db, suite.oauthServer) + suite.streamProcessor = stream.New(&suite.state, suite.oauthServer) testrig.StandardDBSetup(suite.db, suite.testAccounts) } diff --git a/internal/processing/user/email.go b/internal/processing/user/email.go index 349e27f47..c55488954 100644 --- a/internal/processing/user/email.go +++ b/internal/processing/user/email.go @@ -56,7 +56,7 @@ func (p *Processor) EmailSendConfirmation(ctx context.Context, user *gtsmodel.Us // pull our instance entry from the database so we can greet the user nicely in the email instance := >smodel.Instance{} host := config.GetHost() - if err := p.db.GetWhere(ctx, []db.Where{{Key: "domain", Value: host}}, instance); err != nil { + if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "domain", Value: host}}, instance); err != nil { return fmt.Errorf("SendConfirmEmail: error getting instance: %s", err) } @@ -78,7 +78,7 @@ func (p *Processor) EmailSendConfirmation(ctx context.Context, user *gtsmodel.Us user.LastEmailedAt = time.Now() user.UpdatedAt = time.Now() - if err := p.db.UpdateByID(ctx, user, user.ID, updatingColumns...); err != nil { + if err := p.state.DB.UpdateByID(ctx, user, user.ID, updatingColumns...); err != nil { return fmt.Errorf("SendConfirmEmail: error updating user entry after email sent: %s", err) } @@ -92,7 +92,7 @@ func (p *Processor) EmailConfirm(ctx context.Context, token string) (*gtsmodel.U return nil, gtserror.NewErrorNotFound(errors.New("no token provided")) } - user, err := p.db.GetUserByConfirmationToken(ctx, token) + user, err := p.state.DB.GetUserByConfirmationToken(ctx, token) if err != nil { if err == db.ErrNoEntries { return nil, gtserror.NewErrorNotFound(err) @@ -101,7 +101,7 @@ func (p *Processor) EmailConfirm(ctx context.Context, token string) (*gtsmodel.U } if user.Account == nil { - a, err := p.db.GetAccountByID(ctx, user.AccountID) + a, err := p.state.DB.GetAccountByID(ctx, user.AccountID) if err != nil { return nil, gtserror.NewErrorNotFound(err) } @@ -129,7 +129,7 @@ func (p *Processor) EmailConfirm(ctx context.Context, token string) (*gtsmodel.U user.ConfirmationToken = "" user.UpdatedAt = time.Now() - if err := p.db.UpdateByID(ctx, user, user.ID, updatingColumns...); err != nil { + if err := p.state.DB.UpdateByID(ctx, user, user.ID, updatingColumns...); err != nil { return nil, gtserror.NewErrorInternalError(err) } diff --git a/internal/processing/user/password.go b/internal/processing/user/password.go index 3475e005e..72ef5ffa7 100644 --- a/internal/processing/user/password.go +++ b/internal/processing/user/password.go @@ -44,7 +44,7 @@ func (p *Processor) PasswordChange(ctx context.Context, user *gtsmodel.User, old user.EncryptedPassword = string(newPasswordHash) - if err := p.db.UpdateUser(ctx, user, "encrypted_password"); err != nil { + if err := p.state.DB.UpdateUser(ctx, user, "encrypted_password"); err != nil { return gtserror.NewErrorInternalError(err) } diff --git a/internal/processing/user/user.go b/internal/processing/user/user.go index fce628d0c..4fda4c1f6 100644 --- a/internal/processing/user/user.go +++ b/internal/processing/user/user.go @@ -19,19 +19,19 @@ package user import ( - "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" + "github.com/superseriousbusiness/gotosocial/internal/state" ) type Processor struct { + state *state.State emailSender email.Sender - db db.DB } // New returns a new user processor -func New(db db.DB, emailSender email.Sender) Processor { +func New(state *state.State, emailSender email.Sender) Processor { return Processor{ + state: state, emailSender: emailSender, - db: db, } } diff --git a/internal/processing/user/user_test.go b/internal/processing/user/user_test.go index 83ab5892e..7379b568e 100644 --- a/internal/processing/user/user_test.go +++ b/internal/processing/user/user_test.go @@ -24,6 +24,7 @@ "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/processing/user" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -31,6 +32,7 @@ type UserStandardTestSuite struct { suite.Suite emailSender email.Sender db db.DB + state state.State testUsers map[string]*gtsmodel.User @@ -40,15 +42,19 @@ type UserStandardTestSuite struct { } func (suite *UserStandardTestSuite) SetupTest() { + suite.state.Caches.Init() + testrig.InitTestConfig() testrig.InitTestLog() - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&suite.state) + suite.state.DB = suite.db + suite.sentEmails = make(map[string]string) suite.emailSender = testrig.NewEmailSender("../../../web/template/", suite.sentEmails) suite.testUsers = testrig.NewTestUsers() - suite.user = user.New(suite.db, suite.emailSender) + suite.user = user.New(&suite.state, suite.emailSender) testrig.StandardDBSetup(suite.db, nil) } diff --git a/internal/text/formatter_test.go b/internal/text/formatter_test.go index 32ae74488..304a538fc 100644 --- a/internal/text/formatter_test.go +++ b/internal/text/formatter_test.go @@ -20,12 +20,12 @@ import ( "context" + "github.com/stretchr/testify/suite" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/text" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -66,13 +66,15 @@ func (suite *TextStandardTestSuite) SetupSuite() { } func (suite *TextStandardTestSuite) SetupTest() { + var state state.State + state.Caches.Init() + testrig.InitTestLog() testrig.InitTestConfig() - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&state) - fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1) - federator := testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../testrig/media"), suite.db, fedWorker), nil, nil, fedWorker) + federator := testrig.NewTestFederator(&state, testrig.NewTestTransportController(&state, testrig.NewMockHTTPClient(nil, "../../testrig/media")), nil) suite.parseMention = processing.GetParseMentionFunc(suite.db, federator) suite.formatter = text.NewFormatter(suite.db) diff --git a/internal/timeline/get_test.go b/internal/timeline/get_test.go index 9be1fdb90..0c866c7a8 100644 --- a/internal/timeline/get_test.go +++ b/internal/timeline/get_test.go @@ -27,6 +27,7 @@ "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/timeline" "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" @@ -42,10 +43,13 @@ func (suite *GetTestSuite) SetupSuite() { } func (suite *GetTestSuite) SetupTest() { + var state state.State + state.Caches.Init() + testrig.InitTestLog() testrig.InitTestConfig() - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&state) suite.tc = testrig.NewTestTypeConverter(suite.db) suite.filter = visibility.NewFilter(suite.db) diff --git a/internal/timeline/index_test.go b/internal/timeline/index_test.go index 692688aba..9d79f12c2 100644 --- a/internal/timeline/index_test.go +++ b/internal/timeline/index_test.go @@ -26,6 +26,7 @@ "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/timeline" "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" @@ -41,10 +42,13 @@ func (suite *IndexTestSuite) SetupSuite() { } func (suite *IndexTestSuite) SetupTest() { + var state state.State + state.Caches.Init() + testrig.InitTestLog() testrig.InitTestConfig() - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&state) suite.tc = testrig.NewTestTypeConverter(suite.db) suite.filter = visibility.NewFilter(suite.db) diff --git a/internal/timeline/manager_test.go b/internal/timeline/manager_test.go index 03804bf78..e033ffda4 100644 --- a/internal/timeline/manager_test.go +++ b/internal/timeline/manager_test.go @@ -24,6 +24,7 @@ "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/timeline" "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" @@ -39,10 +40,13 @@ func (suite *ManagerTestSuite) SetupSuite() { } func (suite *ManagerTestSuite) SetupTest() { + var state state.State + state.Caches.Init() + testrig.InitTestLog() testrig.InitTestConfig() - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&state) suite.tc = testrig.NewTestTypeConverter(suite.db) suite.filter = visibility.NewFilter(suite.db) diff --git a/internal/timeline/prune_test.go b/internal/timeline/prune_test.go index 9d539e0e0..48bba41dc 100644 --- a/internal/timeline/prune_test.go +++ b/internal/timeline/prune_test.go @@ -26,6 +26,7 @@ "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/processing" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/timeline" "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" @@ -41,10 +42,13 @@ func (suite *PruneTestSuite) SetupSuite() { } func (suite *PruneTestSuite) SetupTest() { + var state state.State + state.Caches.Init() + testrig.InitTestLog() testrig.InitTestConfig() - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&state) suite.tc = testrig.NewTestTypeConverter(suite.db) suite.filter = visibility.NewFilter(suite.db) diff --git a/internal/trans/import_test.go b/internal/trans/import_test.go index 128ac58a3..a53305c79 100644 --- a/internal/trans/import_test.go +++ b/internal/trans/import_test.go @@ -27,6 +27,7 @@ "github.com/google/uuid" "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/trans" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -57,8 +58,11 @@ func (suite *ImportMinimalTestSuite) TestImportMinimalOK() { suite.NotEmpty(b) fmt.Println(string(b)) + var state state.State + state.Caches.Init() + // create a new database with just the tables created, no entries - newDB := testrig.NewTestDB() + newDB := testrig.NewTestDB(&state) importer := trans.NewImporter(newDB) err = importer.Import(ctx, tempFilePath) diff --git a/internal/trans/trans_test.go b/internal/trans/trans_test.go index 9364891a0..2b6bbb57b 100644 --- a/internal/trans/trans_test.go +++ b/internal/trans/trans_test.go @@ -22,6 +22,7 @@ "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -32,12 +33,15 @@ type TransTestSuite struct { } func (suite *TransTestSuite) SetupTest() { + var state state.State + state.Caches.Init() + testrig.InitTestConfig() testrig.InitTestLog() suite.testAccounts = testrig.NewTestAccounts() - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&state) testrig.StandardDBSetup(suite.db, nil) } diff --git a/internal/typeutils/converter_test.go b/internal/typeutils/converter_test.go index c6f3c2579..bc81a7c6d 100644 --- a/internal/typeutils/converter_test.go +++ b/internal/typeutils/converter_test.go @@ -23,6 +23,7 @@ "github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -481,10 +482,13 @@ type TypeUtilsTestSuite struct { } func (suite *TypeUtilsTestSuite) SetupSuite() { + var state state.State + state.Caches.Init() + testrig.InitTestConfig() testrig.InitTestLog() - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&state) suite.testAccounts = testrig.NewTestAccounts() suite.testStatuses = testrig.NewTestStatuses() suite.testAttachments = testrig.NewTestAttachments() diff --git a/internal/visibility/filter_test.go b/internal/visibility/filter_test.go index bd7a8671e..9697dd72c 100644 --- a/internal/visibility/filter_test.go +++ b/internal/visibility/filter_test.go @@ -22,6 +22,7 @@ "github.com/stretchr/testify/suite" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/visibility" "github.com/superseriousbusiness/gotosocial/testrig" ) @@ -60,10 +61,13 @@ func (suite *FilterStandardTestSuite) SetupSuite() { } func (suite *FilterStandardTestSuite) SetupTest() { + var state state.State + state.Caches.Init() + testrig.InitTestConfig() testrig.InitTestLog() - suite.db = testrig.NewTestDB() + suite.db = testrig.NewTestDB(&state) suite.filter = visibility.NewFilter(suite.db) testrig.StandardDBSetup(suite.db, nil) diff --git a/internal/workers/workers.go b/internal/workers/workers.go index 77b3065ce..b29d115aa 100644 --- a/internal/workers/workers.go +++ b/internal/workers/workers.go @@ -19,20 +19,28 @@ package workers import ( + "context" "log" "runtime" "codeberg.org/gruf/go-runners" "codeberg.org/gruf/go-sched" + "github.com/superseriousbusiness/gotosocial/internal/messages" ) type Workers struct { // Main task scheduler instance. Scheduler sched.Scheduler - // Processor / federator worker pools. - // ClientAPI runners.WorkerPool - // Federator runners.WorkerPool + // ClientAPI / federator worker pools. + ClientAPI runners.WorkerPool + Federator runners.WorkerPool + + // Enqueue functions for clientAPI / federator worker pools, + // these are pointers to Processor{}.Enqueue___() msg functions. + // This prevents dependency cycling as Processor depends on Workers. + EnqueueClientAPI func(context.Context, messages.FromClientAPI) + EnqueueFederator func(context.Context, messages.FromFederator) // Media manager worker pools. Media runners.WorkerPool @@ -50,13 +58,13 @@ func (w *Workers) Start() { return w.Scheduler.Start(nil) }) - // tryUntil("starting client API workerpool", 5, func() bool { - // return w.ClientAPI.Start(4*maxprocs, 400*maxprocs) - // }) + tryUntil("starting client API workerpool", 5, func() bool { + return w.ClientAPI.Start(4*maxprocs, 400*maxprocs) + }) - // tryUntil("starting federator workerpool", 5, func() bool { - // return w.Federator.Start(4*maxprocs, 400*maxprocs) - // }) + tryUntil("starting federator workerpool", 5, func() bool { + return w.Federator.Start(4*maxprocs, 400*maxprocs) + }) tryUntil("starting media workerpool", 5, func() bool { return w.Media.Start(8*maxprocs, 80*maxprocs) @@ -66,8 +74,8 @@ func (w *Workers) Start() { // Stop will stop all of the contained worker pools (and global scheduler). func (w *Workers) Stop() { tryUntil("stopping scheduler", 5, w.Scheduler.Stop) - // tryUntil("stopping client API workerpool", 5, w.ClientAPI.Stop) - // tryUntil("stopping federator workerpool", 5, w.Federator.Stop) + tryUntil("stopping client API workerpool", 5, w.ClientAPI.Stop) + tryUntil("stopping federator workerpool", 5, w.Federator.Stop) tryUntil("stopping media workerpool", 5, w.Media.Stop) } diff --git a/testrig/db.go b/testrig/db.go index 8479347eb..1a29aa8b9 100644 --- a/testrig/db.go +++ b/testrig/db.go @@ -71,7 +71,7 @@ // // If the environment variable GTS_DB_PORT is set, it will take that // value as the port instead. -func NewTestDB() db.DB { +func NewTestDB(state *state.State) db.DB { if alternateAddress := os.Getenv("GTS_DB_ADDRESS"); alternateAddress != "" { config.SetDbAddress(alternateAddress) } @@ -88,10 +88,9 @@ func NewTestDB() db.DB { config.SetDbPort(int(port)) } - var state state.State state.Caches.Init() - testDB, err := bundb.NewBunDBService(context.Background(), &state) + testDB, err := bundb.NewBunDBService(context.Background(), state) if err != nil { log.Panic(nil, err) } diff --git a/testrig/federatingdb.go b/testrig/federatingdb.go index 9b1f1961e..27adc4c51 100644 --- a/testrig/federatingdb.go +++ b/testrig/federatingdb.go @@ -19,13 +19,11 @@ package testrig import ( - "github.com/superseriousbusiness/gotosocial/internal/concurrency" - "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb" - "github.com/superseriousbusiness/gotosocial/internal/messages" + "github.com/superseriousbusiness/gotosocial/internal/state" ) // NewTestFederatingDB returns a federating DB with the underlying db -func NewTestFederatingDB(db db.DB, fedWorker *concurrency.WorkerPool[messages.FromFederator]) federatingdb.DB { - return federatingdb.New(db, fedWorker, NewTestTypeConverter(db)) +func NewTestFederatingDB(state *state.State) federatingdb.DB { + return federatingdb.New(state, NewTestTypeConverter(state.DB)) } diff --git a/testrig/federator.go b/testrig/federator.go index 605a2c8f3..bc150633e 100644 --- a/testrig/federator.go +++ b/testrig/federator.go @@ -19,16 +19,13 @@ package testrig import ( - "github.com/superseriousbusiness/gotosocial/internal/concurrency" - "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" - "github.com/superseriousbusiness/gotosocial/internal/storage" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/transport" ) // NewTestFederator returns a federator with the given database and (mock!!) transport controller. -func NewTestFederator(db db.DB, tc transport.Controller, storage *storage.Driver, mediaManager media.Manager, fedWorker *concurrency.WorkerPool[messages.FromFederator]) federation.Federator { - return federation.NewFederator(db, NewTestFederatingDB(db, fedWorker), tc, NewTestTypeConverter(db), mediaManager) +func NewTestFederator(state *state.State, tc transport.Controller, mediaManager media.Manager) federation.Federator { + return federation.NewFederator(state.DB, NewTestFederatingDB(state), tc, NewTestTypeConverter(state.DB), mediaManager) } diff --git a/testrig/mediahandler.go b/testrig/mediahandler.go index a1863218c..b4b992b0b 100644 --- a/testrig/mediahandler.go +++ b/testrig/mediahandler.go @@ -19,17 +19,12 @@ package testrig import ( - "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/state" - "github.com/superseriousbusiness/gotosocial/internal/storage" ) // NewTestMediaManager returns a media handler with the default test config, and the given db and storage. -func NewTestMediaManager(db db.DB, storage *storage.Driver) media.Manager { - var state state.State - state.DB = db - state.Storage = storage - state.Workers.Start() - return media.NewManager(&state) +func NewTestMediaManager(state *state.State) media.Manager { + StartWorkers(state) // ensure started + return media.NewManager(state) } diff --git a/testrig/processor.go b/testrig/processor.go index f451d4ad0..856ee523d 100644 --- a/testrig/processor.go +++ b/testrig/processor.go @@ -19,17 +19,17 @@ package testrig import ( - "github.com/superseriousbusiness/gotosocial/internal/concurrency" - "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/media" - "github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/processing" - "github.com/superseriousbusiness/gotosocial/internal/storage" + "github.com/superseriousbusiness/gotosocial/internal/state" ) // NewTestProcessor returns a Processor suitable for testing purposes -func NewTestProcessor(db db.DB, storage *storage.Driver, federator federation.Federator, emailSender email.Sender, mediaManager media.Manager, clientWorker *concurrency.WorkerPool[messages.FromClientAPI], fedWorker *concurrency.WorkerPool[messages.FromFederator]) *processing.Processor { - return processing.NewProcessor(NewTestTypeConverter(db), federator, NewTestOauthServer(db), mediaManager, storage, db, emailSender, clientWorker, fedWorker) +func NewTestProcessor(state *state.State, federator federation.Federator, emailSender email.Sender, mediaManager media.Manager) *processing.Processor { + p := processing.NewProcessor(NewTestTypeConverter(state.DB), federator, NewTestOauthServer(state.DB), mediaManager, state, emailSender) + state.Workers.EnqueueClientAPI = p.EnqueueClientAPI + state.Workers.EnqueueFederator = p.EnqueueFederator + return p } diff --git a/testrig/transportcontroller.go b/testrig/transportcontroller.go index 7565a741c..9657205f6 100644 --- a/testrig/transportcontroller.go +++ b/testrig/transportcontroller.go @@ -30,12 +30,10 @@ "github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/activity/streams/vocab" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" - "github.com/superseriousbusiness/gotosocial/internal/concurrency" - "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" - "github.com/superseriousbusiness/gotosocial/internal/messages" + "github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/transport" ) @@ -53,8 +51,8 @@ // Unlike the other test interfaces provided in this package, you'll probably want to call this function // PER TEST rather than per suite, so that the do function can be set on a test by test (or even more granular) // basis. -func NewTestTransportController(client pub.HttpClient, db db.DB, fedWorker *concurrency.WorkerPool[messages.FromFederator]) transport.Controller { - return transport.NewController(db, NewTestFederatingDB(db, fedWorker), &federation.Clock{}, client) +func NewTestTransportController(state *state.State, client pub.HttpClient) transport.Controller { + return transport.NewController(state.DB, NewTestFederatingDB(state), &federation.Clock{}, client) } type MockHTTPClient struct { diff --git a/testrig/util.go b/testrig/util.go index cc392b315..0cda93024 100644 --- a/testrig/util.go +++ b/testrig/util.go @@ -20,13 +20,34 @@ import ( "bytes" + "context" "io" "mime/multipart" "net/url" "os" "time" + + "github.com/superseriousbusiness/gotosocial/internal/messages" + "github.com/superseriousbusiness/gotosocial/internal/state" ) +func StartWorkers(state *state.State) { + state.Workers.EnqueueClientAPI = func(context.Context, messages.FromClientAPI) {} + state.Workers.EnqueueFederator = func(context.Context, messages.FromFederator) {} + + _ = state.Workers.Scheduler.Start(nil) + _ = state.Workers.ClientAPI.Start(1, 10) + _ = state.Workers.Federator.Start(1, 10) + _ = state.Workers.Media.Start(1, 10) +} + +func StopWorkers(state *state.State) { + _ = state.Workers.Scheduler.Stop() + _ = state.Workers.ClientAPI.Stop() + _ = state.Workers.Federator.Stop() + _ = state.Workers.Media.Stop() +} + // CreateMultipartFormData is a handy function for taking a fieldname and a filename, and creating a multipart form bytes buffer // with the file contents set in the given fieldname. The extraFields param can be used to add extra FormFields to the request, as necessary. // The returned bytes.Buffer b can be used like so: