diff --git a/cmd/gotosocial/action/server/server.go b/cmd/gotosocial/action/server/server.go index 100897a41..ddf783cda 100644 --- a/cmd/gotosocial/action/server/server.go +++ b/cmd/gotosocial/action/server/server.go @@ -30,44 +30,41 @@ "time" "github.com/KimMachineGun/automemlimit/memlimit" - webpushgo "github.com/SherClockHolmes/webpush-go" "github.com/gin-gonic/gin" "github.com/superseriousbusiness/gotosocial/cmd/gotosocial/action" "github.com/superseriousbusiness/gotosocial/internal/api" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" "github.com/superseriousbusiness/gotosocial/internal/cleaner" - "github.com/superseriousbusiness/gotosocial/internal/filter/interaction" - "github.com/superseriousbusiness/gotosocial/internal/filter/spam" - "github.com/superseriousbusiness/gotosocial/internal/filter/visibility" - "github.com/superseriousbusiness/gotosocial/internal/gtserror" - "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/media/ffmpeg" - "github.com/superseriousbusiness/gotosocial/internal/messages" - "github.com/superseriousbusiness/gotosocial/internal/metrics" - "github.com/superseriousbusiness/gotosocial/internal/middleware" - tlprocessor "github.com/superseriousbusiness/gotosocial/internal/processing/timeline" - "github.com/superseriousbusiness/gotosocial/internal/timeline" - "github.com/superseriousbusiness/gotosocial/internal/tracing" - "github.com/superseriousbusiness/gotosocial/internal/webpush" - "go.uber.org/automaxprocs/maxprocs" - "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db/bundb" "github.com/superseriousbusiness/gotosocial/internal/email" "github.com/superseriousbusiness/gotosocial/internal/federation" "github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb" + "github.com/superseriousbusiness/gotosocial/internal/filter/interaction" + "github.com/superseriousbusiness/gotosocial/internal/filter/spam" + "github.com/superseriousbusiness/gotosocial/internal/filter/visibility" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/httpclient" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/media" + "github.com/superseriousbusiness/gotosocial/internal/media/ffmpeg" + "github.com/superseriousbusiness/gotosocial/internal/messages" + "github.com/superseriousbusiness/gotosocial/internal/metrics" + "github.com/superseriousbusiness/gotosocial/internal/middleware" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oidc" "github.com/superseriousbusiness/gotosocial/internal/processing" + tlprocessor "github.com/superseriousbusiness/gotosocial/internal/processing/timeline" "github.com/superseriousbusiness/gotosocial/internal/router" "github.com/superseriousbusiness/gotosocial/internal/state" gtsstorage "github.com/superseriousbusiness/gotosocial/internal/storage" + "github.com/superseriousbusiness/gotosocial/internal/timeline" + "github.com/superseriousbusiness/gotosocial/internal/tracing" "github.com/superseriousbusiness/gotosocial/internal/transport" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/web" + "github.com/superseriousbusiness/gotosocial/internal/webpush" + "go.uber.org/automaxprocs/maxprocs" ) // Start creates and starts a gotosocial server @@ -246,19 +243,8 @@ } // Get or create a VAPID key pair. - vapidKeyPair, err := dbService.GetVAPIDKeyPair(ctx) - if err != nil { - return gtserror.Newf("error getting VAPID key pair: %w", err) - } - if vapidKeyPair == nil { - // Generate and store a new key pair. - vapidKeyPair = >smodel.VAPIDKeyPair{} - if vapidKeyPair.Private, vapidKeyPair.Public, err = webpushgo.GenerateVAPIDKeys(); err != nil { - return gtserror.Newf("error generating VAPID key pair: %w", err) - } - if err := dbService.PutVAPIDKeyPair(ctx, vapidKeyPair); err != nil { - return gtserror.Newf("error putting VAPID key pair: %w", err) - } + if _, err := dbService.GetVAPIDKeyPair(ctx); err != nil { + return gtserror.Newf("error getting or creating VAPID key pair: %w", err) } // Create a Web Push notification sender. diff --git a/internal/db/bundb/migrations/20241124012636_add_web_push_subscriptions.go b/internal/db/bundb/migrations/20241124012636_add_web_push_subscriptions.go index a9487fc24..269470ee9 100644 --- a/internal/db/bundb/migrations/20241124012636_add_web_push_subscriptions.go +++ b/internal/db/bundb/migrations/20241124012636_add_web_push_subscriptions.go @@ -41,7 +41,7 @@ func init() { } { if _, err := tx. NewCreateIndex(). - Table("web_push_subscriptions"). + Model(>smodel.WebPushSubscription{}). Index(index). Column(columns...). IfNotExists(). diff --git a/internal/db/bundb/webpush.go b/internal/db/bundb/webpush.go index 1472e7d7b..c61209573 100644 --- a/internal/db/bundb/webpush.go +++ b/internal/db/bundb/webpush.go @@ -21,6 +21,7 @@ "context" "errors" + webpushgo "github.com/SherClockHolmes/webpush-go" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -35,6 +36,44 @@ type webPushDB struct { } func (w *webPushDB) GetVAPIDKeyPair(ctx context.Context) (*gtsmodel.VAPIDKeyPair, error) { + var err error + + vapidKeyPair, err := w.getVAPIDKeyPair(ctx) + if err != nil { + return nil, err + } + if vapidKeyPair != nil { + return vapidKeyPair, nil + } + + // If there aren't any, generate new ones. + vapidKeyPair = >smodel.VAPIDKeyPair{} + if vapidKeyPair.Private, vapidKeyPair.Public, err = webpushgo.GenerateVAPIDKeys(); err != nil { + return nil, gtserror.Newf("error generating VAPID key pair: %w", err) + } + + // Store the keys in the database. + if _, err = w.db.NewInsert(). + Model(vapidKeyPair). + Exec(ctx); // nocollapse + err != nil { + if errors.Is(err, db.ErrAlreadyExists) { + // Multiple concurrent attempts to generate new keys, and this one didn't win. + // Get the results of the one that did. + return w.getVAPIDKeyPair(ctx) + } + return nil, err + } + + // Cache the keys. + w.state.Caches.DB.VAPIDKeyPair.Store(vapidKeyPair) + + return vapidKeyPair, nil +} + +// getVAPIDKeyPair gets an existing VAPID key pair from cache or DB. +// If there is no existing VAPID key pair, it returns nil, with no error. +func (w *webPushDB) getVAPIDKeyPair(ctx context.Context) (*gtsmodel.VAPIDKeyPair, error) { // Look for cached keys. vapidKeyPair := w.state.Caches.DB.VAPIDKeyPair.Load() if vapidKeyPair != nil { @@ -54,23 +93,20 @@ func (w *webPushDB) GetVAPIDKeyPair(ctx context.Context) (*gtsmodel.VAPIDKeyPair return nil, err } - // Cache the keys. - w.state.Caches.DB.VAPIDKeyPair.Store(vapidKeyPair) - return vapidKeyPair, nil } -func (w *webPushDB) PutVAPIDKeyPair(ctx context.Context, vapidKeyPair *gtsmodel.VAPIDKeyPair) error { - // Store the keys in the database. - if _, err := w.db.NewInsert(). - Model(vapidKeyPair). +func (w *webPushDB) DeleteVAPIDKeyPair(ctx context.Context) error { + // Delete any existing keys. + if _, err := w.db.NewTruncateTable(). + Model((*gtsmodel.VAPIDKeyPair)(nil)). Exec(ctx); // nocollapse err != nil { return err } - // Cache the keys. - w.state.Caches.DB.VAPIDKeyPair.Store(vapidKeyPair) + // Clear the key cache. + w.state.Caches.DB.VAPIDKeyPair.Store(nil) return nil } diff --git a/internal/db/bundb/webpush_test.go b/internal/db/bundb/webpush_test.go new file mode 100644 index 000000000..8ca83955a --- /dev/null +++ b/internal/db/bundb/webpush_test.go @@ -0,0 +1,81 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package bundb_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/suite" +) + +type WebPushTestSuite struct { + BunDBStandardTestSuite +} + +// Get the text fixture VAPID key pair. +func (suite *WebPushTestSuite) TestGetVAPIDKeyPair() { + ctx := context.Background() + + vapidKeyPair, err := suite.db.GetVAPIDKeyPair(ctx) + suite.NoError(err) + if !suite.NotNil(vapidKeyPair) { + suite.FailNow("Got a nil VAPID key pair, can't continue") + } + suite.NotEmpty(vapidKeyPair.Private) + suite.NotEmpty(vapidKeyPair.Public) + + // Get it again. It should be the same one. + vapidKeyPair2, err := suite.db.GetVAPIDKeyPair(ctx) + suite.NoError(err) + if suite.NotNil(vapidKeyPair2) { + suite.Equal(vapidKeyPair.Private, vapidKeyPair2.Private) + suite.Equal(vapidKeyPair.Public, vapidKeyPair2.Public) + } +} + +// Generate a VAPID key pair when there isn't one. +func (suite *WebPushTestSuite) TestGenerateVAPIDKeyPair() { + ctx := context.Background() + + // Delete the text fixture VAPID key pair. + if err := suite.db.DeleteVAPIDKeyPair(ctx); !suite.NoError(err) { + suite.FailNow("Test setup failed: DB error deleting fixture VAPID key pair: %v", err) + } + + // Get a new one. + vapidKeyPair, err := suite.db.GetVAPIDKeyPair(ctx) + suite.NoError(err) + if !suite.NotNil(vapidKeyPair) { + suite.FailNow("Got a nil VAPID key pair, can't continue") + } + suite.NotEmpty(vapidKeyPair.Private) + suite.NotEmpty(vapidKeyPair.Public) + + // Get it again. It should be the same one. + vapidKeyPair2, err := suite.db.GetVAPIDKeyPair(ctx) + suite.NoError(err) + if suite.NotNil(vapidKeyPair2) { + suite.Equal(vapidKeyPair.Private, vapidKeyPair2.Private) + suite.Equal(vapidKeyPair.Public, vapidKeyPair2.Public) + } +} + +func TestWebPushTestSuite(t *testing.T) { + suite.Run(t, new(WebPushTestSuite)) +} diff --git a/internal/db/webpush.go b/internal/db/webpush.go index 05c76e0d5..22bf449de 100644 --- a/internal/db/webpush.go +++ b/internal/db/webpush.go @@ -26,12 +26,11 @@ // WebPush contains functions related to Web Push notifications. type WebPush interface { // GetVAPIDKeyPair retrieves the server's existing VAPID key pair, if there is one. - // If there isn't, it returns nil. + // If there isn't one, it generates a new one, stores it, and returns that. GetVAPIDKeyPair(ctx context.Context) (*gtsmodel.VAPIDKeyPair, error) - // PutVAPIDKeyPair stores the server's VAPID key pair. - // This should be called at most once, during server startup. - PutVAPIDKeyPair(ctx context.Context, vapidKeyPair *gtsmodel.VAPIDKeyPair) error + // DeleteVAPIDKeyPair deletes the server's VAPID key pair. + DeleteVAPIDKeyPair(ctx context.Context) error // GetWebPushSubscriptionByTokenID retrieves an access token's Web Push subscription. // There may not be one, in which case an error will be returned. diff --git a/testrig/db.go b/testrig/db.go index dd19c3648..d33a63f12 100644 --- a/testrig/db.go +++ b/testrig/db.go @@ -20,7 +20,6 @@ import ( "context" - webpushgo "github.com/SherClockHolmes/webpush-go" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db/bundb" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -377,12 +376,8 @@ func StandardDBSetup(db db.DB, accounts map[string]*gtsmodel.Account) { log.Panic(ctx, err) } - vapidKeyPair := >smodel.VAPIDKeyPair{} - var err error - if vapidKeyPair.Private, vapidKeyPair.Public, err = webpushgo.GenerateVAPIDKeys(); err != nil { - log.Panic(nil, err) - } - if err = db.PutVAPIDKeyPair(ctx, vapidKeyPair); err != nil { + // Generates and stores a VAPID key pair as a side effect. + if _, err := db.GetVAPIDKeyPair(ctx); err != nil { log.Panic(nil, err) }