From 17fd35661d6ef8a6399a2594d3231fe40ca3a916 Mon Sep 17 00:00:00 2001 From: Vyr Cossont Date: Sat, 30 Nov 2024 12:24:13 -0800 Subject: [PATCH] Store Web Push subscriptions in DB --- internal/cache/cache.go | 2 + internal/cache/db.go | 53 ++++- internal/cache/invalidate.go | 10 + internal/cache/size.go | 18 +- internal/config/config.go | 2 + internal/config/defaults.go | 2 + internal/config/helpers.gen.go | 58 +++++ internal/db/admin.go | 8 - internal/db/application.go | 3 + internal/db/bundb/admin.go | 36 ---- internal/db/bundb/application.go | 10 + internal/db/bundb/bundb.go | 5 + internal/db/bundb/notification_test.go | 2 +- internal/db/bundb/webpush.go | 203 ++++++++++++++++++ internal/db/db.go | 1 + internal/db/webpush.go | 53 +++++ internal/gtsmodel/notification.go | 23 +- internal/gtsmodel/vapidkeypair.go | 6 +- internal/gtsmodel/webpushsubscription.go | 67 ++++++ internal/processing/timeline/notification.go | 2 +- .../processing/workers/fromfediapi_test.go | 4 +- internal/processing/workers/surfacenotify.go | 4 +- test/envparsing.sh | 2 + testrig/db.go | 9 + testrig/testmodels.go | 32 ++- 25 files changed, 546 insertions(+), 69 deletions(-) create mode 100644 internal/db/bundb/webpush.go create mode 100644 internal/db/webpush.go create mode 100644 internal/gtsmodel/webpushsubscription.go diff --git a/internal/cache/cache.go b/internal/cache/cache.go index a4f9f2044..152ae33d7 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -115,6 +115,8 @@ func (c *Caches) Init() { c.initUserMute() c.initUserMuteIDs() c.initWebfinger() + c.initWebPushSubscription() + c.initWebPushSubscriptionIDs() c.initVisibility() c.initStatusesFilterableFields() } diff --git a/internal/cache/db.go b/internal/cache/db.go index aac11236a..c264d5567 100644 --- a/internal/cache/db.go +++ b/internal/cache/db.go @@ -252,6 +252,15 @@ type DBCaches struct { // UserMuteIDs provides access to the user mute IDs database cache. UserMuteIDs SliceCache[string] + + // VAPIDKeyPair caches the server's VAPID key pair. + VAPIDKeyPair atomic.Pointer[gtsmodel.VAPIDKeyPair] + + // WebPushSubscription provides access to the gtsmodel WebPushSubscription database cache. + WebPushSubscription StructCache[*gtsmodel.WebPushSubscription] + + // WebPushSubscriptionIDs provides access to the Web Push subscription IDs database cache. + WebPushSubscriptionIDs SliceCache[string] } // NOTE: @@ -1509,9 +1518,10 @@ func (c *Caches) initToken() { {Fields: "Refresh"}, {Fields: "ClientID", Multiple: true}, }, - MaxSize: cap, - IgnoreErr: ignoreErrors, - Copy: copyF, + MaxSize: cap, + IgnoreErr: ignoreErrors, + Copy: copyF, + Invalidate: c.OnInvalidateToken, }) } @@ -1621,3 +1631,40 @@ func (c *Caches) initUserMuteIDs() { c.DB.UserMuteIDs.Init(0, cap) } + +func (c *Caches) initWebPushSubscription() { + cap := calculateResultCacheMax( + sizeofWebPushSubscription(), // model in-mem size. + config.GetCacheWebPushSubscriptionMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + copyF := func(s1 *gtsmodel.WebPushSubscription) *gtsmodel.WebPushSubscription { + s2 := new(gtsmodel.WebPushSubscription) + *s2 = *s1 + return s2 + } + + c.DB.WebPushSubscription.Init(structr.CacheConfig[*gtsmodel.WebPushSubscription]{ + Indices: []structr.IndexConfig{ + {Fields: "ID"}, + {Fields: "TokenID"}, + {Fields: "AccountID", Multiple: true}, + }, + MaxSize: cap, + IgnoreErr: ignoreErrors, + Invalidate: c.OnInvalidateWebPushSubscription, + Copy: copyF, + }) +} + +func (c *Caches) initWebPushSubscriptionIDs() { + cap := calculateSliceCacheMax( + config.GetCacheWebPushSubscriptionIDsMemRatio(), + ) + + log.Infof(nil, "cache size = %d", cap) + + c.DB.WebPushSubscriptionIDs.Init(0, cap) +} diff --git a/internal/cache/invalidate.go b/internal/cache/invalidate.go index 9b42e88f6..be3eaa735 100644 --- a/internal/cache/invalidate.go +++ b/internal/cache/invalidate.go @@ -278,6 +278,11 @@ func (c *Caches) OnInvalidateStatusFave(fave *gtsmodel.StatusFave) { c.DB.StatusFaveIDs.Invalidate(fave.StatusID) } +func (c *Caches) OnInvalidateToken(token *gtsmodel.Token) { + // Invalidate token's push subscription. + c.DB.WebPushSubscription.Invalidate("ID", token.ID) +} + func (c *Caches) OnInvalidateUser(user *gtsmodel.User) { // Invalidate local account ID cached visibility. c.Visibility.Invalidate("ItemID", user.AccountID) @@ -291,3 +296,8 @@ func (c *Caches) OnInvalidateUserMute(mute *gtsmodel.UserMute) { // Invalidate source account's user mute lists. c.DB.UserMuteIDs.Invalidate(mute.AccountID) } + +func (c *Caches) OnInvalidateWebPushSubscription(subscription *gtsmodel.WebPushSubscription) { + // Invalidate source account's Web Push subscription list. + c.DB.WebPushSubscriptionIDs.Invalidate(subscription.AccountID) +} diff --git a/internal/cache/size.go b/internal/cache/size.go index 26f4096ed..abed1e3b6 100644 --- a/internal/cache/size.go +++ b/internal/cache/size.go @@ -66,6 +66,14 @@ // be a serialized string of almost any type, so we pick a // nice serialized key size on the upper end of normal. sizeofResultKey = 2 * sizeofIDStr + + // exampleWebPushAuth is a Base64-encoded 16-byte random auth secret. + // This secret is consumed as Base64 by webpush-go. + exampleWebPushAuth = "ZVxqlt5fzVgmSz2aqiA2XQ==" + + // exampleWebPushP256dh is a Base64-encoded DH P-256 public key. + // This secret is consumed as Base64 by webpush-go. + exampleWebPushP256dh = "OrpejO16gV97uBXew/T0I7YoUv/CX8fz0z4g8RrQ+edXJqQPjX3XVSo2P0HhcCpCOR1+Dzj5LFcK9jYNqX7SBg==" ) var ( @@ -558,7 +566,7 @@ func sizeofMove() uintptr { func sizeofNotification() uintptr { return uintptr(size.Of(>smodel.Notification{ ID: exampleID, - NotificationType: gtsmodel.NotificationFave, + NotificationType: gtsmodel.NotificationFavourite, CreatedAt: exampleTime, TargetAccountID: exampleID, OriginAccountID: exampleID, @@ -786,3 +794,11 @@ func sizeofUserMute() uintptr { Notifications: util.Ptr(false), })) } + +func sizeofWebPushSubscription() uintptr { + return uintptr(size.Of(>smodel.WebPushSubscription{ + TokenID: exampleID, + Auth: exampleWebPushAuth, + P256dh: exampleWebPushP256dh, + })) +} diff --git a/internal/config/config.go b/internal/config/config.go index 2e3ad8ec1..d9491740e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -248,6 +248,8 @@ type CacheConfiguration struct { UserMuteMemRatio float64 `name:"user-mute-mem-ratio"` UserMuteIDsMemRatio float64 `name:"user-mute-ids-mem-ratio"` WebfingerMemRatio float64 `name:"webfinger-mem-ratio"` + WebPushSubscriptionMemRatio float64 `name:"web-push-subscription-mem-ratio"` + WebPushSubscriptionIDsMemRatio float64 `name:"web-push-subscription-ids-mem-ratio"` VisibilityMemRatio float64 `name:"visibility-mem-ratio"` } diff --git a/internal/config/defaults.go b/internal/config/defaults.go index 9b45002d0..0b28b9025 100644 --- a/internal/config/defaults.go +++ b/internal/config/defaults.go @@ -209,6 +209,8 @@ UserMuteMemRatio: 2, UserMuteIDsMemRatio: 3, WebfingerMemRatio: 0.1, + WebPushSubscriptionMemRatio: 1, + WebPushSubscriptionIDsMemRatio: 1, VisibilityMemRatio: 2, }, diff --git a/internal/config/helpers.gen.go b/internal/config/helpers.gen.go index a35622f8e..2c554d87a 100644 --- a/internal/config/helpers.gen.go +++ b/internal/config/helpers.gen.go @@ -4162,6 +4162,64 @@ func GetCacheWebfingerMemRatio() float64 { return global.GetCacheWebfingerMemRat // SetCacheWebfingerMemRatio safely sets the value for global configuration 'Cache.WebfingerMemRatio' field func SetCacheWebfingerMemRatio(v float64) { global.SetCacheWebfingerMemRatio(v) } +// GetCacheWebPushSubscriptionMemRatio safely fetches the Configuration value for state's 'Cache.WebPushSubscriptionMemRatio' field +func (st *ConfigState) GetCacheWebPushSubscriptionMemRatio() (v float64) { + st.mutex.RLock() + v = st.config.Cache.WebPushSubscriptionMemRatio + st.mutex.RUnlock() + return +} + +// SetCacheWebPushSubscriptionMemRatio safely sets the Configuration value for state's 'Cache.WebPushSubscriptionMemRatio' field +func (st *ConfigState) SetCacheWebPushSubscriptionMemRatio(v float64) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.WebPushSubscriptionMemRatio = v + st.reloadToViper() +} + +// CacheWebPushSubscriptionMemRatioFlag returns the flag name for the 'Cache.WebPushSubscriptionMemRatio' field +func CacheWebPushSubscriptionMemRatioFlag() string { return "cache-web-push-subscription-mem-ratio" } + +// GetCacheWebPushSubscriptionMemRatio safely fetches the value for global configuration 'Cache.WebPushSubscriptionMemRatio' field +func GetCacheWebPushSubscriptionMemRatio() float64 { + return global.GetCacheWebPushSubscriptionMemRatio() +} + +// SetCacheWebPushSubscriptionMemRatio safely sets the value for global configuration 'Cache.WebPushSubscriptionMemRatio' field +func SetCacheWebPushSubscriptionMemRatio(v float64) { global.SetCacheWebPushSubscriptionMemRatio(v) } + +// GetCacheWebPushSubscriptionIDsMemRatio safely fetches the Configuration value for state's 'Cache.WebPushSubscriptionIDsMemRatio' field +func (st *ConfigState) GetCacheWebPushSubscriptionIDsMemRatio() (v float64) { + st.mutex.RLock() + v = st.config.Cache.WebPushSubscriptionIDsMemRatio + st.mutex.RUnlock() + return +} + +// SetCacheWebPushSubscriptionIDsMemRatio safely sets the Configuration value for state's 'Cache.WebPushSubscriptionIDsMemRatio' field +func (st *ConfigState) SetCacheWebPushSubscriptionIDsMemRatio(v float64) { + st.mutex.Lock() + defer st.mutex.Unlock() + st.config.Cache.WebPushSubscriptionIDsMemRatio = v + st.reloadToViper() +} + +// CacheWebPushSubscriptionIDsMemRatioFlag returns the flag name for the 'Cache.WebPushSubscriptionIDsMemRatio' field +func CacheWebPushSubscriptionIDsMemRatioFlag() string { + return "cache-web-push-subscription-ids-mem-ratio" +} + +// GetCacheWebPushSubscriptionIDsMemRatio safely fetches the value for global configuration 'Cache.WebPushSubscriptionIDsMemRatio' field +func GetCacheWebPushSubscriptionIDsMemRatio() float64 { + return global.GetCacheWebPushSubscriptionIDsMemRatio() +} + +// SetCacheWebPushSubscriptionIDsMemRatio safely sets the value for global configuration 'Cache.WebPushSubscriptionIDsMemRatio' field +func SetCacheWebPushSubscriptionIDsMemRatio(v float64) { + global.SetCacheWebPushSubscriptionIDsMemRatio(v) +} + // GetCacheVisibilityMemRatio safely fetches the Configuration value for state's 'Cache.VisibilityMemRatio' field func (st *ConfigState) GetCacheVisibilityMemRatio() (v float64) { st.mutex.RLock() diff --git a/internal/db/admin.go b/internal/db/admin.go index 77fbbe613..1f24c7932 100644 --- a/internal/db/admin.go +++ b/internal/db/admin.go @@ -68,14 +68,6 @@ type Admin interface { // the number of pending sign-ups sitting in the backlog. CountUnhandledSignups(ctx context.Context) (int, error) - // GetVAPIDKeyPair retrieves the existing VAPID key pair, if there is one. - // If there isn't, it returns nil. - GetVAPIDKeyPair(ctx context.Context) (*gtsmodel.VAPIDKeyPair, error) - - // PutVAPIDKeyPair stores a VAPID key pair. - // This should be called at most once, during server startup. - PutVAPIDKeyPair(ctx context.Context, vapidKeyPair *gtsmodel.VAPIDKeyPair) error - /* ACTION FUNCS */ diff --git a/internal/db/application.go b/internal/db/application.go index b71e593c2..5a4068431 100644 --- a/internal/db/application.go +++ b/internal/db/application.go @@ -48,6 +48,9 @@ type Application interface { // GetAllTokens ... GetAllTokens(ctx context.Context) ([]*gtsmodel.Token, error) + // GetTokenByID ... + GetTokenByID(ctx context.Context, id string) (*gtsmodel.Token, error) + // GetTokenByCode ... GetTokenByCode(ctx context.Context, code string) (*gtsmodel.Token, error) diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go index 266b351f5..ff398fca5 100644 --- a/internal/db/bundb/admin.go +++ b/internal/db/bundb/admin.go @@ -48,9 +48,6 @@ type adminDB struct { db *bun.DB state *state.State - - // Since the VAPID key pair is very small and never written to concurrently, we can cache it here. - vapidKeyPair *gtsmodel.VAPIDKeyPair } func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, error) { @@ -445,39 +442,6 @@ func (a *adminDB) CountUnhandledSignups(ctx context.Context) (int, error) { Count(ctx) } -func (a *adminDB) GetVAPIDKeyPair(ctx context.Context) (*gtsmodel.VAPIDKeyPair, error) { - // Look for cached keys. - if a.vapidKeyPair != nil { - return a.vapidKeyPair, nil - } - - // Look for previously generated keys in the database. - if err := a.db.NewSelect(). - Model(a.vapidKeyPair). - Limit(1). - Scan(ctx); // nocollapse - err != nil && !errors.Is(err, db.ErrNoEntries) { - return nil, gtserror.Newf("DB error getting VAPID key pair: %w", err) - } - - return a.vapidKeyPair, nil -} - -func (a *adminDB) PutVAPIDKeyPair(ctx context.Context, vapidKeyPair *gtsmodel.VAPIDKeyPair) error { - // Store the keys in the database. - if _, err := a.db.NewInsert(). - Model(a.vapidKeyPair). - Exec(ctx); // nocollapse - err != nil { - return gtserror.Newf("DB error putting VAPID key pair: %w", err) - } - - // Cache the keys. - a.vapidKeyPair = vapidKeyPair - - return nil -} - /* ACTION FUNCS */ diff --git a/internal/db/bundb/application.go b/internal/db/bundb/application.go index cbba499b0..92fc5ea2b 100644 --- a/internal/db/bundb/application.go +++ b/internal/db/bundb/application.go @@ -174,6 +174,16 @@ func(uncached []string) ([]*gtsmodel.Token, error) { return tokens, nil } +func (a *applicationDB) GetTokenByID(ctx context.Context, code string) (*gtsmodel.Token, error) { + return a.getTokenBy( + "ID", + func(t *gtsmodel.Token) error { + return a.db.NewSelect().Model(t).Where("? = ?", bun.Ident("id"), code).Scan(ctx) + }, + code, + ) +} + func (a *applicationDB) GetTokenByCode(ctx context.Context, code string) (*gtsmodel.Token, error) { return a.getTokenBy( "Code", diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go index 70132fe58..c307e0356 100644 --- a/internal/db/bundb/bundb.go +++ b/internal/db/bundb/bundb.go @@ -87,6 +87,7 @@ type DBService struct { db.Timeline db.User db.Tombstone + db.WebPush db.WorkerTask db *bun.DB } @@ -296,6 +297,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) { db: db, state: state, }, + WebPush: &webPushDB{ + db: db, + state: state, + }, WorkerTask: &workerTaskDB{ db: db, }, diff --git a/internal/db/bundb/notification_test.go b/internal/db/bundb/notification_test.go index 8e2fb8031..8cc778071 100644 --- a/internal/db/bundb/notification_test.go +++ b/internal/db/bundb/notification_test.go @@ -66,7 +66,7 @@ func (suite *NotificationTestSuite) spamNotifs() { notif := >smodel.Notification{ ID: notifID, - NotificationType: gtsmodel.NotificationFave, + NotificationType: gtsmodel.NotificationFavourite, CreatedAt: time.Now(), TargetAccountID: targetAccountID, OriginAccountID: originAccountID, diff --git a/internal/db/bundb/webpush.go b/internal/db/bundb/webpush.go new file mode 100644 index 000000000..bb2ee2ba2 --- /dev/null +++ b/internal/db/bundb/webpush.go @@ -0,0 +1,203 @@ +package bundb + +import ( + "context" + "errors" + + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/state" + "github.com/superseriousbusiness/gotosocial/internal/util/xslices" + "github.com/uptrace/bun" +) + +type webPushDB struct { + db *bun.DB + state *state.State +} + +func (w *webPushDB) GetVAPIDKeyPair(ctx context.Context) (*gtsmodel.VAPIDKeyPair, error) { + // Look for cached keys. + vapidKeyPair := w.state.Caches.DB.VAPIDKeyPair.Load() + if vapidKeyPair != nil { + return vapidKeyPair, nil + } + + // Look for previously generated keys in the database. + if err := w.db.NewSelect(). + Model(vapidKeyPair). + Limit(1). + Scan(ctx); // nocollapse + err != nil && !errors.Is(err, db.ErrNoEntries) { + return nil, err + } + + // Cache the keys. + w.state.Caches.DB.VAPIDKeyPair.Store(vapidKeyPair) + + return vapidKeyPair, nil +} + +func (w *webPushDB) PutVAPIDKeyPair(ctx context.Context, vapidKeyPair *gtsmodel.VAPIDKeyPair) error { + // Store the keys in the database. + if _, err := w.db.NewInsert(). + Model(vapidKeyPair). + Exec(ctx); // nocollapse + err != nil { + return err + } + + // Cache the keys. + w.state.Caches.DB.VAPIDKeyPair.Store(vapidKeyPair) + + return nil +} + +func (w *webPushDB) GetWebPushSubscriptionByTokenID(ctx context.Context, tokenID string) (*gtsmodel.WebPushSubscription, error) { + return w.state.Caches.DB.WebPushSubscription.LoadOne( + "TokenID", + func() (*gtsmodel.WebPushSubscription, error) { + var subscription gtsmodel.WebPushSubscription + err := w.db. + NewSelect(). + Model(&subscription). + Where("? = ?", bun.Ident("token_id"), tokenID). + Scan(ctx) + return &subscription, err + }, + tokenID, + ) +} + +func (w *webPushDB) PutWebPushSubscription(ctx context.Context, subscription *gtsmodel.WebPushSubscription) error { + return w.state.Caches.DB.WebPushSubscription.Store(subscription, func() error { + _, err := w.db.NewInsert(). + Model(subscription). + Exec(ctx) + return err + }) +} + +func (w *webPushDB) UpdateWebPushSubscription(ctx context.Context, subscription *gtsmodel.WebPushSubscription, columns ...string) error { + // If we're updating by column, ensure "updated_at" is included. + if len(columns) > 0 { + columns = append(columns, "updated_at") + } + + // Update database. + if _, err := w.db. + NewUpdate(). + Model(subscription). + Column(columns...). + Where("? = ?", bun.Ident("id"), subscription.ID). + Exec(ctx); // nocollapse + err != nil { + return err + } + + // Update cache. + w.state.Caches.DB.WebPushSubscription.Put(subscription) + + return nil +} + +func (w *webPushDB) DeleteWebPushSubscriptionByTokenID(ctx context.Context, tokenID string) error { + // Deleted partial model for cache invalidation. + var deleted gtsmodel.WebPushSubscription + + // Delete subscription, returning subset of columns used by invalidation hook. + if _, err := w.db.NewDelete(). + Model(&deleted). + Where("? = ?", bun.Ident("token_id"), tokenID). + Returning("?", bun.Ident("account_id")). + Exec(ctx); // nocollapse + err != nil && !errors.Is(err, db.ErrNoEntries) { + return err + } + + // Invalidate cached subscription by token ID. + w.state.Caches.DB.WebPushSubscription.Invalidate("TokenID", tokenID) + + // Call invalidate hook directly. + w.state.Caches.OnInvalidateWebPushSubscription(&deleted) + + return nil +} + +func (w *webPushDB) GetWebPushSubscriptionsByAccountID(ctx context.Context, accountID string) ([]*gtsmodel.WebPushSubscription, error) { + // Fetch IDs of all subscriptions created by this account. + subscriptionIDs, err := loadPagedIDs(&w.state.Caches.DB.WebPushSubscriptionIDs, accountID, nil, func() ([]string, error) { + // Subscription IDs not in cache. Perform DB query. + var subscriptionIDs []string + if _, err := w.db. + NewSelect(). + Model((*gtsmodel.WebPushSubscription)(nil)). + Column("id"). + Where("? = ?", bun.Ident("account_id"), accountID). + Order("id DESC"). + Exec(ctx, &subscriptionIDs); // nocollapse + err != nil && !errors.Is(err, db.ErrNoEntries) { + return nil, err + } + return subscriptionIDs, nil + }) + if len(subscriptionIDs) == 0 { + return nil, nil + } + + // Get each subscription by ID from the cache or DB. + subscriptions, err := w.state.Caches.DB.WebPushSubscription.LoadIDs("ID", + subscriptionIDs, + func(uncached []string) ([]*gtsmodel.WebPushSubscription, error) { + subscriptions := make([]*gtsmodel.WebPushSubscription, 0, len(uncached)) + if err := w.db. + NewSelect(). + Model(&subscriptions). + Where("? IN (?)", bun.Ident("id"), bun.In(uncached)). + Scan(ctx); // nocollapse + err != nil { + return nil, err + } + return subscriptions, nil + }, + ) + if err != nil { + return nil, err + } + + // Put the subscription structs in the same order as the filter IDs. + xslices.OrderBy( + subscriptions, + subscriptionIDs, + func(subscription *gtsmodel.WebPushSubscription) string { + return subscription.ID + }, + ) + + return subscriptions, nil +} + +func (w *webPushDB) DeleteWebPushSubscriptionsByAccountID(ctx context.Context, accountID string) error { + // Deleted partial models for cache invalidation. + var deleted []*gtsmodel.WebPushSubscription + + // Delete subscriptions, returning subset of columns. + if _, err := w.db.NewDelete(). + Model(&deleted). + Where("? = ?", bun.Ident("account_id"), accountID). + Returning("?", bun.Ident("account_id")). + Exec(ctx); // nocollapse + err != nil && !errors.Is(err, db.ErrNoEntries) { + return err + } + + // Invalidate cached subscriptions by account ID. + w.state.Caches.DB.WebPushSubscription.Invalidate("AccountID", accountID) + + // Call invalidate hooks directly in case those entries weren't cached. + for _, subscription := range deleted { + w.state.Caches.OnInvalidateWebPushSubscription(subscription) + } + + return nil +} diff --git a/internal/db/db.go b/internal/db/db.go index c42985912..b7e2b29bd 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -57,5 +57,6 @@ type DB interface { Timeline User Tombstone + WebPush WorkerTask } diff --git a/internal/db/webpush.go b/internal/db/webpush.go new file mode 100644 index 000000000..6752657d7 --- /dev/null +++ b/internal/db/webpush.go @@ -0,0 +1,53 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package db + +import ( + "context" + + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" +) + +// WebPush contains functions related to Web Push notifications. +type WebPush interface { + // GetVAPIDKeyPair retrieves the server's existing VAPID key pair, if there is one. + // If there isn't, it returns nil. + GetVAPIDKeyPair(ctx context.Context) (*gtsmodel.VAPIDKeyPair, error) + + // PutVAPIDKeyPair stores the server's VAPID key pair. + // This should be called at most once, during server startup. + PutVAPIDKeyPair(ctx context.Context, vapidKeyPair *gtsmodel.VAPIDKeyPair) error + + // GetWebPushSubscriptionByTokenID retrieves an access token's Web Push subscription, if there is one. + GetWebPushSubscriptionByTokenID(ctx context.Context, tokenID string) (*gtsmodel.WebPushSubscription, error) + + // PutWebPushSubscription creates an access token's Web Push subscription. + PutWebPushSubscription(ctx context.Context, subscription *gtsmodel.WebPushSubscription) error + + // UpdateWebPushSubscription updates an access token's Web Push subscription. + UpdateWebPushSubscription(ctx context.Context, subscription *gtsmodel.WebPushSubscription, columns ...string) error + + // DeleteWebPushSubscriptionByTokenID deletes an access token's Web Push subscription, if there is one. + DeleteWebPushSubscriptionByTokenID(ctx context.Context, tokenID string) error + + // GetWebPushSubscriptionsByAccountID retrieves an account's list of Web Push subscriptions. + GetWebPushSubscriptionsByAccountID(ctx context.Context, accountID string) ([]*gtsmodel.WebPushSubscription, error) + + // DeleteWebPushSubscriptionsByAccountID deletes an account's list of Web Push subscriptions. + DeleteWebPushSubscriptionsByAccountID(ctx context.Context, accountID string) error +} diff --git a/internal/gtsmodel/notification.go b/internal/gtsmodel/notification.go index 1ef805081..bdaa3f563 100644 --- a/internal/gtsmodel/notification.go +++ b/internal/gtsmodel/notification.go @@ -48,13 +48,14 @@ type Notification struct { NotificationFollowRequest NotificationType = 2 // NotificationFollowRequest -- someone requested to follow you NotificationMention NotificationType = 3 // NotificationMention -- someone mentioned you in their status NotificationReblog NotificationType = 4 // NotificationReblog -- someone boosted one of your statuses - NotificationFave NotificationType = 5 // NotificationFave -- someone faved/liked one of your statuses + NotificationFavourite NotificationType = 5 // NotificationFavourite -- someone faved/liked one of your statuses NotificationPoll NotificationType = 6 // NotificationPoll -- a poll you voted in or created has ended NotificationStatus NotificationType = 7 // NotificationStatus -- someone you enabled notifications for has posted a status. - NotificationSignup NotificationType = 8 // NotificationSignup -- someone has submitted a new account sign-up to the instance. - NotificationPendingFave NotificationType = 9 // Someone has faved a status of yours, which requires approval by you. - NotificationPendingReply NotificationType = 10 // Someone has replied to a status of yours, which requires approval by you. - NotificationPendingReblog NotificationType = 11 // Someone has boosted a status of yours, which requires approval by you. + NotificationAdminSignup NotificationType = 8 // NotificationAdminSignup -- someone has submitted a new account sign-up to the instance. + NotificationPendingFave NotificationType = 9 // NotificationPendingFave -- Someone has faved a status of yours, which requires approval by you. + NotificationPendingReply NotificationType = 10 // NotificationPendingReply -- Someone has replied to a status of yours, which requires approval by you. + NotificationPendingReblog NotificationType = 11 // NotificationPendingReblog -- Someone has boosted a status of yours, which requires approval by you. + NotificationAdminReport NotificationType = 12 // NotificationAdminReport -- someone has submitted a new report to the instance. ) // String returns a stringified, frontend API compatible form of NotificationType. @@ -68,13 +69,13 @@ func (t NotificationType) String() string { return "mention" case NotificationReblog: return "reblog" - case NotificationFave: + case NotificationFavourite: return "favourite" case NotificationPoll: return "poll" case NotificationStatus: return "status" - case NotificationSignup: + case NotificationAdminSignup: return "admin.sign_up" case NotificationPendingFave: return "pending.favourite" @@ -82,6 +83,8 @@ func (t NotificationType) String() string { return "pending.reply" case NotificationPendingReblog: return "pending.reblog" + case NotificationAdminReport: + return "admin.report" default: panic("invalid notification type") } @@ -99,19 +102,21 @@ func ParseNotificationType(in string) NotificationType { case "reblog": return NotificationReblog case "favourite": - return NotificationFave + return NotificationFavourite case "poll": return NotificationPoll case "status": return NotificationStatus case "admin.sign_up": - return NotificationSignup + return NotificationAdminSignup case "pending.favourite": return NotificationPendingFave case "pending.reply": return NotificationPendingReply case "pending.reblog": return NotificationPendingReblog + case "admin.report": + return NotificationAdminReport default: return NotificationUnknown } diff --git a/internal/gtsmodel/vapidkeypair.go b/internal/gtsmodel/vapidkeypair.go index 85883df45..56b7edda8 100644 --- a/internal/gtsmodel/vapidkeypair.go +++ b/internal/gtsmodel/vapidkeypair.go @@ -22,7 +22,7 @@ // // See: https://datatracker.ietf.org/doc/html/rfc8292 type VAPIDKeyPair struct { - ID int `bun:"pk,notnull"` - Public string `bun:"notnull,nullzero"` - Private string `bun:"notnull,nullzero"` + ID int `bun:",pk,notnull"` + Public string `bun:",notnull,nullzero"` + Private string `bun:",notnull,nullzero"` } diff --git a/internal/gtsmodel/webpushsubscription.go b/internal/gtsmodel/webpushsubscription.go new file mode 100644 index 000000000..b14fb1caf --- /dev/null +++ b/internal/gtsmodel/webpushsubscription.go @@ -0,0 +1,67 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package gtsmodel + +import ( + "time" +) + +// WebPushSubscription represents an access token's Web Push subscription. +// There can be at most one per access token. +type WebPushSubscription struct { + // ID of this subscription in the database. + ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"` + + // CreatedAt is the time this subscription was created. + CreatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` + + // UpdatedAt is the time this subscription was last updated. + UpdatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` + + // AccountID of the local account that created this subscription. + AccountID string `bun:"type:CHAR(26),notnull,nullzero"` + + // TokenID is the ID of the associated access token. + // There can be at most one subscription for any given access token, + TokenID string `bun:"type:CHAR(26),nullzero,notnull,unique"` + + // Endpoint is the URL receiving Web Push notifications for this subscription. + Endpoint string `bun:",nullzero,notnull"` + + // Auth is a Base64-encoded authentication secret. + Auth string `bun:",nullzero,notnull"` + + // P256dh is a Base64-encoded Diffie-Hellman public key on the P-256 elliptic curve. + P256dh string `bun:",nullzero,notnull"` + + // NotifyFollow and friends control which notifications are delivered to a given subscription. + // Corresponds to NotificationType and model.PushSubscriptionAlerts. + NotifyFollow *bool `bun:",nullzero,notnull,default:false"` + NotifyFollowRequest *bool `bun:",nullzero,notnull,default:false"` + NotifyFavourite *bool `bun:",nullzero,notnull,default:false"` + NotifyMention *bool `bun:",nullzero,notnull,default:false"` + NotifyReblog *bool `bun:",nullzero,notnull,default:false"` + NotifyPoll *bool `bun:",nullzero,notnull,default:false"` + NotifyStatus *bool `bun:",nullzero,notnull,default:false"` + NotifyUpdate *bool `bun:",nullzero,notnull,default:false"` + NotifyAdminSignup *bool `bun:",nullzero,notnull,default:false"` + NotifyAdminReport *bool `bun:",nullzero,notnull,default:false"` + NotifyPendingFave *bool `bun:",nullzero,notnull,default:false"` + NotifyPendingReply *bool `bun:",nullzero,notnull,default:false"` + NotifyPendingReblog *bool `bun:",nullzero,notnull,default:false"` +} diff --git a/internal/processing/timeline/notification.go b/internal/processing/timeline/notification.go index a242c7b74..09636e7eb 100644 --- a/internal/processing/timeline/notification.go +++ b/internal/processing/timeline/notification.go @@ -184,7 +184,7 @@ func (p *Processor) notifVisible( // If this is a new local account sign-up, // skip normal visibility checking because // origin account won't be confirmed yet. - if n.NotificationType == gtsmodel.NotificationSignup { + if n.NotificationType == gtsmodel.NotificationAdminSignup { return true, nil } diff --git a/internal/processing/workers/fromfediapi_test.go b/internal/processing/workers/fromfediapi_test.go index d7d7454e7..ca6fe38f9 100644 --- a/internal/processing/workers/fromfediapi_test.go +++ b/internal/processing/workers/fromfediapi_test.go @@ -241,7 +241,7 @@ func (suite *FromFediAPITestSuite) TestProcessFave() { notif := >smodel.Notification{} err = testStructs.State.DB.GetWhere(context.Background(), where, notif) suite.NoError(err) - suite.Equal(gtsmodel.NotificationFave, notif.NotificationType) + suite.Equal(gtsmodel.NotificationFavourite, notif.NotificationType) suite.Equal(fave.TargetAccountID, notif.TargetAccountID) suite.Equal(fave.AccountID, notif.OriginAccountID) suite.Equal(fave.StatusID, notif.StatusID) @@ -314,7 +314,7 @@ func (suite *FromFediAPITestSuite) TestProcessFaveWithDifferentReceivingAccount( notif := >smodel.Notification{} err = testStructs.State.DB.GetWhere(context.Background(), where, notif) suite.NoError(err) - suite.Equal(gtsmodel.NotificationFave, notif.NotificationType) + suite.Equal(gtsmodel.NotificationFavourite, notif.NotificationType) suite.Equal(fave.TargetAccountID, notif.TargetAccountID) suite.Equal(fave.AccountID, notif.OriginAccountID) suite.Equal(fave.StatusID, notif.StatusID) diff --git a/internal/processing/workers/surfacenotify.go b/internal/processing/workers/surfacenotify.go index 1520d2ec0..7773e80d3 100644 --- a/internal/processing/workers/surfacenotify.go +++ b/internal/processing/workers/surfacenotify.go @@ -250,7 +250,7 @@ func (s *Surface) notifyFave( // notify status author // of fave by account. if err := s.Notify(ctx, - gtsmodel.NotificationFave, + gtsmodel.NotificationFavourite, fave.TargetAccount, fave.Account, fave.StatusID, @@ -521,7 +521,7 @@ func (s *Surface) notifySignup(ctx context.Context, newUser *gtsmodel.User) erro var errs gtserror.MultiError for _, mod := range modAccounts { if err := s.Notify(ctx, - gtsmodel.NotificationSignup, + gtsmodel.NotificationAdminSignup, mod, newUser.Account, "", diff --git a/test/envparsing.sh b/test/envparsing.sh index 927c5f98b..e5e69a710 100755 --- a/test/envparsing.sh +++ b/test/envparsing.sh @@ -75,6 +75,8 @@ EXPECT=$(cat << "EOF" "user-mute-ids-mem-ratio": 3, "user-mute-mem-ratio": 2, "visibility-mem-ratio": 2, + "web-push-subscription-ids-mem-ratio": 1, + "web-push-subscription-mem-ratio": 1, "webfinger-mem-ratio": 0.1 }, "config-path": "internal/config/testdata/test.yaml", diff --git a/testrig/db.go b/testrig/db.go index 5e423431c..c107b9b05 100644 --- a/testrig/db.go +++ b/testrig/db.go @@ -19,6 +19,7 @@ import ( "context" + webpushgo "github.com/SherClockHolmes/webpush-go" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db/bundb" @@ -60,6 +61,8 @@ >smodel.ThreadToStatus{}, >smodel.User{}, >smodel.UserMute{}, + >smodel.VAPIDKeyPair{}, + >smodel.WebPushSubscription{}, >smodel.Emoji{}, >smodel.Instance{}, >smodel.Notification{}, @@ -347,6 +350,12 @@ func StandardDBSetup(db db.DB, accounts map[string]*gtsmodel.Account) { } } + for _, v := range NewTestWebPushSubscriptions() { + if err := db.Put(ctx, v); err != nil { + log.Panic(nil, err) + } + } + for _, v := range NewTestInteractionRequests() { if err := db.Put(ctx, v); err != nil { log.Panic(nil, err) diff --git a/testrig/testmodels.go b/testrig/testmodels.go index ae69b9e81..c9c0c7be5 100644 --- a/testrig/testmodels.go +++ b/testrig/testmodels.go @@ -2475,7 +2475,7 @@ func NewTestNotifications() map[string]*gtsmodel.Notification { return map[string]*gtsmodel.Notification{ "local_account_1_like": { ID: "01F8Q0ANPTWW10DAKTX7BRPBJP", - NotificationType: gtsmodel.NotificationFave, + NotificationType: gtsmodel.NotificationFavourite, CreatedAt: TimeMustParse("2022-05-14T13:21:09+02:00"), TargetAccountID: "01F8MH1H7YV1Z7D2C8K2730QBF", OriginAccountID: "01F8MH17FWEB39HZJ76B6VXSKF", @@ -2484,7 +2484,7 @@ func NewTestNotifications() map[string]*gtsmodel.Notification { }, "local_account_2_like": { ID: "01GTS6PRPXJYZBPFFQ56PP0XR8", - NotificationType: gtsmodel.NotificationFave, + NotificationType: gtsmodel.NotificationFavourite, CreatedAt: TimeMustParse("2022-01-13T12:45:01+02:00"), TargetAccountID: "01F8MH17FWEB39HZJ76B6VXSKF", OriginAccountID: "01F8MH5NBDF2MV7CTC4Q5128HF", @@ -2493,7 +2493,7 @@ func NewTestNotifications() map[string]*gtsmodel.Notification { }, "new_signup": { ID: "01HTM9TETMB3YQCBKZ7KD4KV02", - NotificationType: gtsmodel.NotificationSignup, + NotificationType: gtsmodel.NotificationAdminSignup, CreatedAt: TimeMustParse("2022-06-04T13:12:00Z"), TargetAccountID: "01F8MH17FWEB39HZJ76B6VXSKF", OriginAccountID: "01F8MH0BBE4FHXPH513MBVFHB0", @@ -3476,6 +3476,32 @@ func NewTestUserMutes() map[string]*gtsmodel.UserMute { return map[string]*gtsmodel.UserMute{} } +func NewTestWebPushSubscriptions() map[string]*gtsmodel.WebPushSubscription { + return map[string]*gtsmodel.WebPushSubscription{ + "local_account_1_token_1": { + ID: "01G65Z755AFWAKHE12NY0CQ9FH", + AccountID: "01F8MH1H7YV1Z7D2C8K2730QBF", + TokenID: "01F8MGTQW4DKTDF8SW5CT9HYGA", + Endpoint: "https://example.test/push", + Auth: "cgna/fzrYLDQyPf5hD7IsA==", + P256dh: "BMYVItYVOX+AHBdtA62Q0i6c+F7MV2Gia3aoDr8mvHkuPBNIOuTLDfmFcnBqoZcQk6BtLcIONbxhHpy2R+mYIUY=", + NotifyFollow: util.Ptr(true), + NotifyFollowRequest: util.Ptr(true), + NotifyFavourite: util.Ptr(true), + NotifyMention: util.Ptr(true), + NotifyReblog: util.Ptr(true), + NotifyPoll: util.Ptr(true), + NotifyStatus: util.Ptr(true), + NotifyUpdate: util.Ptr(true), + NotifyAdminSignup: util.Ptr(true), + NotifyAdminReport: util.Ptr(true), + NotifyPendingFave: util.Ptr(true), + NotifyPendingReply: util.Ptr(true), + NotifyPendingReblog: util.Ptr(true), + }, + } +} + func NewTestInteractionRequests() map[string]*gtsmodel.InteractionRequest { return map[string]*gtsmodel.InteractionRequest{ "admin_account_reply_turtle": {