Store Web Push subscriptions in DB

This commit is contained in:
Vyr Cossont 2024-11-30 12:24:13 -08:00
parent 5a2e8341a2
commit 17fd35661d
25 changed files with 546 additions and 69 deletions

View file

@ -115,6 +115,8 @@ func (c *Caches) Init() {
c.initUserMute() c.initUserMute()
c.initUserMuteIDs() c.initUserMuteIDs()
c.initWebfinger() c.initWebfinger()
c.initWebPushSubscription()
c.initWebPushSubscriptionIDs()
c.initVisibility() c.initVisibility()
c.initStatusesFilterableFields() c.initStatusesFilterableFields()
} }

53
internal/cache/db.go vendored
View file

@ -252,6 +252,15 @@ type DBCaches struct {
// UserMuteIDs provides access to the user mute IDs database cache. // UserMuteIDs provides access to the user mute IDs database cache.
UserMuteIDs SliceCache[string] UserMuteIDs SliceCache[string]
// VAPIDKeyPair caches the server's VAPID key pair.
VAPIDKeyPair atomic.Pointer[gtsmodel.VAPIDKeyPair]
// WebPushSubscription provides access to the gtsmodel WebPushSubscription database cache.
WebPushSubscription StructCache[*gtsmodel.WebPushSubscription]
// WebPushSubscriptionIDs provides access to the Web Push subscription IDs database cache.
WebPushSubscriptionIDs SliceCache[string]
} }
// NOTE: // NOTE:
@ -1509,9 +1518,10 @@ func (c *Caches) initToken() {
{Fields: "Refresh"}, {Fields: "Refresh"},
{Fields: "ClientID", Multiple: true}, {Fields: "ClientID", Multiple: true},
}, },
MaxSize: cap, MaxSize: cap,
IgnoreErr: ignoreErrors, IgnoreErr: ignoreErrors,
Copy: copyF, Copy: copyF,
Invalidate: c.OnInvalidateToken,
}) })
} }
@ -1621,3 +1631,40 @@ func (c *Caches) initUserMuteIDs() {
c.DB.UserMuteIDs.Init(0, cap) c.DB.UserMuteIDs.Init(0, cap)
} }
func (c *Caches) initWebPushSubscription() {
cap := calculateResultCacheMax(
sizeofWebPushSubscription(), // model in-mem size.
config.GetCacheWebPushSubscriptionMemRatio(),
)
log.Infof(nil, "cache size = %d", cap)
copyF := func(s1 *gtsmodel.WebPushSubscription) *gtsmodel.WebPushSubscription {
s2 := new(gtsmodel.WebPushSubscription)
*s2 = *s1
return s2
}
c.DB.WebPushSubscription.Init(structr.CacheConfig[*gtsmodel.WebPushSubscription]{
Indices: []structr.IndexConfig{
{Fields: "ID"},
{Fields: "TokenID"},
{Fields: "AccountID", Multiple: true},
},
MaxSize: cap,
IgnoreErr: ignoreErrors,
Invalidate: c.OnInvalidateWebPushSubscription,
Copy: copyF,
})
}
func (c *Caches) initWebPushSubscriptionIDs() {
cap := calculateSliceCacheMax(
config.GetCacheWebPushSubscriptionIDsMemRatio(),
)
log.Infof(nil, "cache size = %d", cap)
c.DB.WebPushSubscriptionIDs.Init(0, cap)
}

View file

@ -278,6 +278,11 @@ func (c *Caches) OnInvalidateStatusFave(fave *gtsmodel.StatusFave) {
c.DB.StatusFaveIDs.Invalidate(fave.StatusID) c.DB.StatusFaveIDs.Invalidate(fave.StatusID)
} }
func (c *Caches) OnInvalidateToken(token *gtsmodel.Token) {
// Invalidate token's push subscription.
c.DB.WebPushSubscription.Invalidate("ID", token.ID)
}
func (c *Caches) OnInvalidateUser(user *gtsmodel.User) { func (c *Caches) OnInvalidateUser(user *gtsmodel.User) {
// Invalidate local account ID cached visibility. // Invalidate local account ID cached visibility.
c.Visibility.Invalidate("ItemID", user.AccountID) c.Visibility.Invalidate("ItemID", user.AccountID)
@ -291,3 +296,8 @@ func (c *Caches) OnInvalidateUserMute(mute *gtsmodel.UserMute) {
// Invalidate source account's user mute lists. // Invalidate source account's user mute lists.
c.DB.UserMuteIDs.Invalidate(mute.AccountID) c.DB.UserMuteIDs.Invalidate(mute.AccountID)
} }
func (c *Caches) OnInvalidateWebPushSubscription(subscription *gtsmodel.WebPushSubscription) {
// Invalidate source account's Web Push subscription list.
c.DB.WebPushSubscriptionIDs.Invalidate(subscription.AccountID)
}

View file

@ -66,6 +66,14 @@
// be a serialized string of almost any type, so we pick a // be a serialized string of almost any type, so we pick a
// nice serialized key size on the upper end of normal. // nice serialized key size on the upper end of normal.
sizeofResultKey = 2 * sizeofIDStr sizeofResultKey = 2 * sizeofIDStr
// exampleWebPushAuth is a Base64-encoded 16-byte random auth secret.
// This secret is consumed as Base64 by webpush-go.
exampleWebPushAuth = "ZVxqlt5fzVgmSz2aqiA2XQ=="
// exampleWebPushP256dh is a Base64-encoded DH P-256 public key.
// This secret is consumed as Base64 by webpush-go.
exampleWebPushP256dh = "OrpejO16gV97uBXew/T0I7YoUv/CX8fz0z4g8RrQ+edXJqQPjX3XVSo2P0HhcCpCOR1+Dzj5LFcK9jYNqX7SBg=="
) )
var ( var (
@ -558,7 +566,7 @@ func sizeofMove() uintptr {
func sizeofNotification() uintptr { func sizeofNotification() uintptr {
return uintptr(size.Of(&gtsmodel.Notification{ return uintptr(size.Of(&gtsmodel.Notification{
ID: exampleID, ID: exampleID,
NotificationType: gtsmodel.NotificationFave, NotificationType: gtsmodel.NotificationFavourite,
CreatedAt: exampleTime, CreatedAt: exampleTime,
TargetAccountID: exampleID, TargetAccountID: exampleID,
OriginAccountID: exampleID, OriginAccountID: exampleID,
@ -786,3 +794,11 @@ func sizeofUserMute() uintptr {
Notifications: util.Ptr(false), Notifications: util.Ptr(false),
})) }))
} }
func sizeofWebPushSubscription() uintptr {
return uintptr(size.Of(&gtsmodel.WebPushSubscription{
TokenID: exampleID,
Auth: exampleWebPushAuth,
P256dh: exampleWebPushP256dh,
}))
}

View file

@ -248,6 +248,8 @@ type CacheConfiguration struct {
UserMuteMemRatio float64 `name:"user-mute-mem-ratio"` UserMuteMemRatio float64 `name:"user-mute-mem-ratio"`
UserMuteIDsMemRatio float64 `name:"user-mute-ids-mem-ratio"` UserMuteIDsMemRatio float64 `name:"user-mute-ids-mem-ratio"`
WebfingerMemRatio float64 `name:"webfinger-mem-ratio"` WebfingerMemRatio float64 `name:"webfinger-mem-ratio"`
WebPushSubscriptionMemRatio float64 `name:"web-push-subscription-mem-ratio"`
WebPushSubscriptionIDsMemRatio float64 `name:"web-push-subscription-ids-mem-ratio"`
VisibilityMemRatio float64 `name:"visibility-mem-ratio"` VisibilityMemRatio float64 `name:"visibility-mem-ratio"`
} }

View file

@ -209,6 +209,8 @@
UserMuteMemRatio: 2, UserMuteMemRatio: 2,
UserMuteIDsMemRatio: 3, UserMuteIDsMemRatio: 3,
WebfingerMemRatio: 0.1, WebfingerMemRatio: 0.1,
WebPushSubscriptionMemRatio: 1,
WebPushSubscriptionIDsMemRatio: 1,
VisibilityMemRatio: 2, VisibilityMemRatio: 2,
}, },

View file

@ -4162,6 +4162,64 @@ func GetCacheWebfingerMemRatio() float64 { return global.GetCacheWebfingerMemRat
// SetCacheWebfingerMemRatio safely sets the value for global configuration 'Cache.WebfingerMemRatio' field // SetCacheWebfingerMemRatio safely sets the value for global configuration 'Cache.WebfingerMemRatio' field
func SetCacheWebfingerMemRatio(v float64) { global.SetCacheWebfingerMemRatio(v) } func SetCacheWebfingerMemRatio(v float64) { global.SetCacheWebfingerMemRatio(v) }
// GetCacheWebPushSubscriptionMemRatio safely fetches the Configuration value for state's 'Cache.WebPushSubscriptionMemRatio' field
func (st *ConfigState) GetCacheWebPushSubscriptionMemRatio() (v float64) {
st.mutex.RLock()
v = st.config.Cache.WebPushSubscriptionMemRatio
st.mutex.RUnlock()
return
}
// SetCacheWebPushSubscriptionMemRatio safely sets the Configuration value for state's 'Cache.WebPushSubscriptionMemRatio' field
func (st *ConfigState) SetCacheWebPushSubscriptionMemRatio(v float64) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.WebPushSubscriptionMemRatio = v
st.reloadToViper()
}
// CacheWebPushSubscriptionMemRatioFlag returns the flag name for the 'Cache.WebPushSubscriptionMemRatio' field
func CacheWebPushSubscriptionMemRatioFlag() string { return "cache-web-push-subscription-mem-ratio" }
// GetCacheWebPushSubscriptionMemRatio safely fetches the value for global configuration 'Cache.WebPushSubscriptionMemRatio' field
func GetCacheWebPushSubscriptionMemRatio() float64 {
return global.GetCacheWebPushSubscriptionMemRatio()
}
// SetCacheWebPushSubscriptionMemRatio safely sets the value for global configuration 'Cache.WebPushSubscriptionMemRatio' field
func SetCacheWebPushSubscriptionMemRatio(v float64) { global.SetCacheWebPushSubscriptionMemRatio(v) }
// GetCacheWebPushSubscriptionIDsMemRatio safely fetches the Configuration value for state's 'Cache.WebPushSubscriptionIDsMemRatio' field
func (st *ConfigState) GetCacheWebPushSubscriptionIDsMemRatio() (v float64) {
st.mutex.RLock()
v = st.config.Cache.WebPushSubscriptionIDsMemRatio
st.mutex.RUnlock()
return
}
// SetCacheWebPushSubscriptionIDsMemRatio safely sets the Configuration value for state's 'Cache.WebPushSubscriptionIDsMemRatio' field
func (st *ConfigState) SetCacheWebPushSubscriptionIDsMemRatio(v float64) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.WebPushSubscriptionIDsMemRatio = v
st.reloadToViper()
}
// CacheWebPushSubscriptionIDsMemRatioFlag returns the flag name for the 'Cache.WebPushSubscriptionIDsMemRatio' field
func CacheWebPushSubscriptionIDsMemRatioFlag() string {
return "cache-web-push-subscription-ids-mem-ratio"
}
// GetCacheWebPushSubscriptionIDsMemRatio safely fetches the value for global configuration 'Cache.WebPushSubscriptionIDsMemRatio' field
func GetCacheWebPushSubscriptionIDsMemRatio() float64 {
return global.GetCacheWebPushSubscriptionIDsMemRatio()
}
// SetCacheWebPushSubscriptionIDsMemRatio safely sets the value for global configuration 'Cache.WebPushSubscriptionIDsMemRatio' field
func SetCacheWebPushSubscriptionIDsMemRatio(v float64) {
global.SetCacheWebPushSubscriptionIDsMemRatio(v)
}
// GetCacheVisibilityMemRatio safely fetches the Configuration value for state's 'Cache.VisibilityMemRatio' field // GetCacheVisibilityMemRatio safely fetches the Configuration value for state's 'Cache.VisibilityMemRatio' field
func (st *ConfigState) GetCacheVisibilityMemRatio() (v float64) { func (st *ConfigState) GetCacheVisibilityMemRatio() (v float64) {
st.mutex.RLock() st.mutex.RLock()

View file

@ -68,14 +68,6 @@ type Admin interface {
// the number of pending sign-ups sitting in the backlog. // the number of pending sign-ups sitting in the backlog.
CountUnhandledSignups(ctx context.Context) (int, error) CountUnhandledSignups(ctx context.Context) (int, error)
// GetVAPIDKeyPair retrieves the existing VAPID key pair, if there is one.
// If there isn't, it returns nil.
GetVAPIDKeyPair(ctx context.Context) (*gtsmodel.VAPIDKeyPair, error)
// PutVAPIDKeyPair stores a VAPID key pair.
// This should be called at most once, during server startup.
PutVAPIDKeyPair(ctx context.Context, vapidKeyPair *gtsmodel.VAPIDKeyPair) error
/* /*
ACTION FUNCS ACTION FUNCS
*/ */

View file

@ -48,6 +48,9 @@ type Application interface {
// GetAllTokens ... // GetAllTokens ...
GetAllTokens(ctx context.Context) ([]*gtsmodel.Token, error) GetAllTokens(ctx context.Context) ([]*gtsmodel.Token, error)
// GetTokenByID ...
GetTokenByID(ctx context.Context, id string) (*gtsmodel.Token, error)
// GetTokenByCode ... // GetTokenByCode ...
GetTokenByCode(ctx context.Context, code string) (*gtsmodel.Token, error) GetTokenByCode(ctx context.Context, code string) (*gtsmodel.Token, error)

View file

@ -48,9 +48,6 @@
type adminDB struct { type adminDB struct {
db *bun.DB db *bun.DB
state *state.State state *state.State
// Since the VAPID key pair is very small and never written to concurrently, we can cache it here.
vapidKeyPair *gtsmodel.VAPIDKeyPair
} }
func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, error) { func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, error) {
@ -445,39 +442,6 @@ func (a *adminDB) CountUnhandledSignups(ctx context.Context) (int, error) {
Count(ctx) Count(ctx)
} }
func (a *adminDB) GetVAPIDKeyPair(ctx context.Context) (*gtsmodel.VAPIDKeyPair, error) {
// Look for cached keys.
if a.vapidKeyPair != nil {
return a.vapidKeyPair, nil
}
// Look for previously generated keys in the database.
if err := a.db.NewSelect().
Model(a.vapidKeyPair).
Limit(1).
Scan(ctx); // nocollapse
err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.Newf("DB error getting VAPID key pair: %w", err)
}
return a.vapidKeyPair, nil
}
func (a *adminDB) PutVAPIDKeyPair(ctx context.Context, vapidKeyPair *gtsmodel.VAPIDKeyPair) error {
// Store the keys in the database.
if _, err := a.db.NewInsert().
Model(a.vapidKeyPair).
Exec(ctx); // nocollapse
err != nil {
return gtserror.Newf("DB error putting VAPID key pair: %w", err)
}
// Cache the keys.
a.vapidKeyPair = vapidKeyPair
return nil
}
/* /*
ACTION FUNCS ACTION FUNCS
*/ */

View file

@ -174,6 +174,16 @@ func(uncached []string) ([]*gtsmodel.Token, error) {
return tokens, nil return tokens, nil
} }
func (a *applicationDB) GetTokenByID(ctx context.Context, code string) (*gtsmodel.Token, error) {
return a.getTokenBy(
"ID",
func(t *gtsmodel.Token) error {
return a.db.NewSelect().Model(t).Where("? = ?", bun.Ident("id"), code).Scan(ctx)
},
code,
)
}
func (a *applicationDB) GetTokenByCode(ctx context.Context, code string) (*gtsmodel.Token, error) { func (a *applicationDB) GetTokenByCode(ctx context.Context, code string) (*gtsmodel.Token, error) {
return a.getTokenBy( return a.getTokenBy(
"Code", "Code",

View file

@ -87,6 +87,7 @@ type DBService struct {
db.Timeline db.Timeline
db.User db.User
db.Tombstone db.Tombstone
db.WebPush
db.WorkerTask db.WorkerTask
db *bun.DB db *bun.DB
} }
@ -296,6 +297,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
db: db, db: db,
state: state, state: state,
}, },
WebPush: &webPushDB{
db: db,
state: state,
},
WorkerTask: &workerTaskDB{ WorkerTask: &workerTaskDB{
db: db, db: db,
}, },

View file

@ -66,7 +66,7 @@ func (suite *NotificationTestSuite) spamNotifs() {
notif := &gtsmodel.Notification{ notif := &gtsmodel.Notification{
ID: notifID, ID: notifID,
NotificationType: gtsmodel.NotificationFave, NotificationType: gtsmodel.NotificationFavourite,
CreatedAt: time.Now(), CreatedAt: time.Now(),
TargetAccountID: targetAccountID, TargetAccountID: targetAccountID,
OriginAccountID: originAccountID, OriginAccountID: originAccountID,

View file

@ -0,0 +1,203 @@
package bundb
import (
"context"
"errors"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices"
"github.com/uptrace/bun"
)
type webPushDB struct {
db *bun.DB
state *state.State
}
func (w *webPushDB) GetVAPIDKeyPair(ctx context.Context) (*gtsmodel.VAPIDKeyPair, error) {
// Look for cached keys.
vapidKeyPair := w.state.Caches.DB.VAPIDKeyPair.Load()
if vapidKeyPair != nil {
return vapidKeyPair, nil
}
// Look for previously generated keys in the database.
if err := w.db.NewSelect().
Model(vapidKeyPair).
Limit(1).
Scan(ctx); // nocollapse
err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, err
}
// Cache the keys.
w.state.Caches.DB.VAPIDKeyPair.Store(vapidKeyPair)
return vapidKeyPair, nil
}
func (w *webPushDB) PutVAPIDKeyPair(ctx context.Context, vapidKeyPair *gtsmodel.VAPIDKeyPair) error {
// Store the keys in the database.
if _, err := w.db.NewInsert().
Model(vapidKeyPair).
Exec(ctx); // nocollapse
err != nil {
return err
}
// Cache the keys.
w.state.Caches.DB.VAPIDKeyPair.Store(vapidKeyPair)
return nil
}
func (w *webPushDB) GetWebPushSubscriptionByTokenID(ctx context.Context, tokenID string) (*gtsmodel.WebPushSubscription, error) {
return w.state.Caches.DB.WebPushSubscription.LoadOne(
"TokenID",
func() (*gtsmodel.WebPushSubscription, error) {
var subscription gtsmodel.WebPushSubscription
err := w.db.
NewSelect().
Model(&subscription).
Where("? = ?", bun.Ident("token_id"), tokenID).
Scan(ctx)
return &subscription, err
},
tokenID,
)
}
func (w *webPushDB) PutWebPushSubscription(ctx context.Context, subscription *gtsmodel.WebPushSubscription) error {
return w.state.Caches.DB.WebPushSubscription.Store(subscription, func() error {
_, err := w.db.NewInsert().
Model(subscription).
Exec(ctx)
return err
})
}
func (w *webPushDB) UpdateWebPushSubscription(ctx context.Context, subscription *gtsmodel.WebPushSubscription, columns ...string) error {
// If we're updating by column, ensure "updated_at" is included.
if len(columns) > 0 {
columns = append(columns, "updated_at")
}
// Update database.
if _, err := w.db.
NewUpdate().
Model(subscription).
Column(columns...).
Where("? = ?", bun.Ident("id"), subscription.ID).
Exec(ctx); // nocollapse
err != nil {
return err
}
// Update cache.
w.state.Caches.DB.WebPushSubscription.Put(subscription)
return nil
}
func (w *webPushDB) DeleteWebPushSubscriptionByTokenID(ctx context.Context, tokenID string) error {
// Deleted partial model for cache invalidation.
var deleted gtsmodel.WebPushSubscription
// Delete subscription, returning subset of columns used by invalidation hook.
if _, err := w.db.NewDelete().
Model(&deleted).
Where("? = ?", bun.Ident("token_id"), tokenID).
Returning("?", bun.Ident("account_id")).
Exec(ctx); // nocollapse
err != nil && !errors.Is(err, db.ErrNoEntries) {
return err
}
// Invalidate cached subscription by token ID.
w.state.Caches.DB.WebPushSubscription.Invalidate("TokenID", tokenID)
// Call invalidate hook directly.
w.state.Caches.OnInvalidateWebPushSubscription(&deleted)
return nil
}
func (w *webPushDB) GetWebPushSubscriptionsByAccountID(ctx context.Context, accountID string) ([]*gtsmodel.WebPushSubscription, error) {
// Fetch IDs of all subscriptions created by this account.
subscriptionIDs, err := loadPagedIDs(&w.state.Caches.DB.WebPushSubscriptionIDs, accountID, nil, func() ([]string, error) {
// Subscription IDs not in cache. Perform DB query.
var subscriptionIDs []string
if _, err := w.db.
NewSelect().
Model((*gtsmodel.WebPushSubscription)(nil)).
Column("id").
Where("? = ?", bun.Ident("account_id"), accountID).
Order("id DESC").
Exec(ctx, &subscriptionIDs); // nocollapse
err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, err
}
return subscriptionIDs, nil
})
if len(subscriptionIDs) == 0 {
return nil, nil
}
// Get each subscription by ID from the cache or DB.
subscriptions, err := w.state.Caches.DB.WebPushSubscription.LoadIDs("ID",
subscriptionIDs,
func(uncached []string) ([]*gtsmodel.WebPushSubscription, error) {
subscriptions := make([]*gtsmodel.WebPushSubscription, 0, len(uncached))
if err := w.db.
NewSelect().
Model(&subscriptions).
Where("? IN (?)", bun.Ident("id"), bun.In(uncached)).
Scan(ctx); // nocollapse
err != nil {
return nil, err
}
return subscriptions, nil
},
)
if err != nil {
return nil, err
}
// Put the subscription structs in the same order as the filter IDs.
xslices.OrderBy(
subscriptions,
subscriptionIDs,
func(subscription *gtsmodel.WebPushSubscription) string {
return subscription.ID
},
)
return subscriptions, nil
}
func (w *webPushDB) DeleteWebPushSubscriptionsByAccountID(ctx context.Context, accountID string) error {
// Deleted partial models for cache invalidation.
var deleted []*gtsmodel.WebPushSubscription
// Delete subscriptions, returning subset of columns.
if _, err := w.db.NewDelete().
Model(&deleted).
Where("? = ?", bun.Ident("account_id"), accountID).
Returning("?", bun.Ident("account_id")).
Exec(ctx); // nocollapse
err != nil && !errors.Is(err, db.ErrNoEntries) {
return err
}
// Invalidate cached subscriptions by account ID.
w.state.Caches.DB.WebPushSubscription.Invalidate("AccountID", accountID)
// Call invalidate hooks directly in case those entries weren't cached.
for _, subscription := range deleted {
w.state.Caches.OnInvalidateWebPushSubscription(subscription)
}
return nil
}

View file

@ -57,5 +57,6 @@ type DB interface {
Timeline Timeline
User User
Tombstone Tombstone
WebPush
WorkerTask WorkerTask
} }

53
internal/db/webpush.go Normal file
View file

@ -0,0 +1,53 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package db
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// WebPush contains functions related to Web Push notifications.
type WebPush interface {
// GetVAPIDKeyPair retrieves the server's existing VAPID key pair, if there is one.
// If there isn't, it returns nil.
GetVAPIDKeyPair(ctx context.Context) (*gtsmodel.VAPIDKeyPair, error)
// PutVAPIDKeyPair stores the server's VAPID key pair.
// This should be called at most once, during server startup.
PutVAPIDKeyPair(ctx context.Context, vapidKeyPair *gtsmodel.VAPIDKeyPair) error
// GetWebPushSubscriptionByTokenID retrieves an access token's Web Push subscription, if there is one.
GetWebPushSubscriptionByTokenID(ctx context.Context, tokenID string) (*gtsmodel.WebPushSubscription, error)
// PutWebPushSubscription creates an access token's Web Push subscription.
PutWebPushSubscription(ctx context.Context, subscription *gtsmodel.WebPushSubscription) error
// UpdateWebPushSubscription updates an access token's Web Push subscription.
UpdateWebPushSubscription(ctx context.Context, subscription *gtsmodel.WebPushSubscription, columns ...string) error
// DeleteWebPushSubscriptionByTokenID deletes an access token's Web Push subscription, if there is one.
DeleteWebPushSubscriptionByTokenID(ctx context.Context, tokenID string) error
// GetWebPushSubscriptionsByAccountID retrieves an account's list of Web Push subscriptions.
GetWebPushSubscriptionsByAccountID(ctx context.Context, accountID string) ([]*gtsmodel.WebPushSubscription, error)
// DeleteWebPushSubscriptionsByAccountID deletes an account's list of Web Push subscriptions.
DeleteWebPushSubscriptionsByAccountID(ctx context.Context, accountID string) error
}

View file

@ -48,13 +48,14 @@ type Notification struct {
NotificationFollowRequest NotificationType = 2 // NotificationFollowRequest -- someone requested to follow you NotificationFollowRequest NotificationType = 2 // NotificationFollowRequest -- someone requested to follow you
NotificationMention NotificationType = 3 // NotificationMention -- someone mentioned you in their status NotificationMention NotificationType = 3 // NotificationMention -- someone mentioned you in their status
NotificationReblog NotificationType = 4 // NotificationReblog -- someone boosted one of your statuses NotificationReblog NotificationType = 4 // NotificationReblog -- someone boosted one of your statuses
NotificationFave NotificationType = 5 // NotificationFave -- someone faved/liked one of your statuses NotificationFavourite NotificationType = 5 // NotificationFavourite -- someone faved/liked one of your statuses
NotificationPoll NotificationType = 6 // NotificationPoll -- a poll you voted in or created has ended NotificationPoll NotificationType = 6 // NotificationPoll -- a poll you voted in or created has ended
NotificationStatus NotificationType = 7 // NotificationStatus -- someone you enabled notifications for has posted a status. NotificationStatus NotificationType = 7 // NotificationStatus -- someone you enabled notifications for has posted a status.
NotificationSignup NotificationType = 8 // NotificationSignup -- someone has submitted a new account sign-up to the instance. NotificationAdminSignup NotificationType = 8 // NotificationAdminSignup -- someone has submitted a new account sign-up to the instance.
NotificationPendingFave NotificationType = 9 // Someone has faved a status of yours, which requires approval by you. NotificationPendingFave NotificationType = 9 // NotificationPendingFave -- Someone has faved a status of yours, which requires approval by you.
NotificationPendingReply NotificationType = 10 // Someone has replied to a status of yours, which requires approval by you. NotificationPendingReply NotificationType = 10 // NotificationPendingReply -- Someone has replied to a status of yours, which requires approval by you.
NotificationPendingReblog NotificationType = 11 // Someone has boosted a status of yours, which requires approval by you. NotificationPendingReblog NotificationType = 11 // NotificationPendingReblog -- Someone has boosted a status of yours, which requires approval by you.
NotificationAdminReport NotificationType = 12 // NotificationAdminReport -- someone has submitted a new report to the instance.
) )
// String returns a stringified, frontend API compatible form of NotificationType. // String returns a stringified, frontend API compatible form of NotificationType.
@ -68,13 +69,13 @@ func (t NotificationType) String() string {
return "mention" return "mention"
case NotificationReblog: case NotificationReblog:
return "reblog" return "reblog"
case NotificationFave: case NotificationFavourite:
return "favourite" return "favourite"
case NotificationPoll: case NotificationPoll:
return "poll" return "poll"
case NotificationStatus: case NotificationStatus:
return "status" return "status"
case NotificationSignup: case NotificationAdminSignup:
return "admin.sign_up" return "admin.sign_up"
case NotificationPendingFave: case NotificationPendingFave:
return "pending.favourite" return "pending.favourite"
@ -82,6 +83,8 @@ func (t NotificationType) String() string {
return "pending.reply" return "pending.reply"
case NotificationPendingReblog: case NotificationPendingReblog:
return "pending.reblog" return "pending.reblog"
case NotificationAdminReport:
return "admin.report"
default: default:
panic("invalid notification type") panic("invalid notification type")
} }
@ -99,19 +102,21 @@ func ParseNotificationType(in string) NotificationType {
case "reblog": case "reblog":
return NotificationReblog return NotificationReblog
case "favourite": case "favourite":
return NotificationFave return NotificationFavourite
case "poll": case "poll":
return NotificationPoll return NotificationPoll
case "status": case "status":
return NotificationStatus return NotificationStatus
case "admin.sign_up": case "admin.sign_up":
return NotificationSignup return NotificationAdminSignup
case "pending.favourite": case "pending.favourite":
return NotificationPendingFave return NotificationPendingFave
case "pending.reply": case "pending.reply":
return NotificationPendingReply return NotificationPendingReply
case "pending.reblog": case "pending.reblog":
return NotificationPendingReblog return NotificationPendingReblog
case "admin.report":
return NotificationAdminReport
default: default:
return NotificationUnknown return NotificationUnknown
} }

View file

@ -22,7 +22,7 @@
// //
// See: https://datatracker.ietf.org/doc/html/rfc8292 // See: https://datatracker.ietf.org/doc/html/rfc8292
type VAPIDKeyPair struct { type VAPIDKeyPair struct {
ID int `bun:"pk,notnull"` ID int `bun:",pk,notnull"`
Public string `bun:"notnull,nullzero"` Public string `bun:",notnull,nullzero"`
Private string `bun:"notnull,nullzero"` Private string `bun:",notnull,nullzero"`
} }

View file

@ -0,0 +1,67 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package gtsmodel
import (
"time"
)
// WebPushSubscription represents an access token's Web Push subscription.
// There can be at most one per access token.
type WebPushSubscription struct {
// ID of this subscription in the database.
ID string `bun:"type:CHAR(26),pk,nullzero,notnull,unique"`
// CreatedAt is the time this subscription was created.
CreatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"`
// UpdatedAt is the time this subscription was last updated.
UpdatedAt time.Time `bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"`
// AccountID of the local account that created this subscription.
AccountID string `bun:"type:CHAR(26),notnull,nullzero"`
// TokenID is the ID of the associated access token.
// There can be at most one subscription for any given access token,
TokenID string `bun:"type:CHAR(26),nullzero,notnull,unique"`
// Endpoint is the URL receiving Web Push notifications for this subscription.
Endpoint string `bun:",nullzero,notnull"`
// Auth is a Base64-encoded authentication secret.
Auth string `bun:",nullzero,notnull"`
// P256dh is a Base64-encoded Diffie-Hellman public key on the P-256 elliptic curve.
P256dh string `bun:",nullzero,notnull"`
// NotifyFollow and friends control which notifications are delivered to a given subscription.
// Corresponds to NotificationType and model.PushSubscriptionAlerts.
NotifyFollow *bool `bun:",nullzero,notnull,default:false"`
NotifyFollowRequest *bool `bun:",nullzero,notnull,default:false"`
NotifyFavourite *bool `bun:",nullzero,notnull,default:false"`
NotifyMention *bool `bun:",nullzero,notnull,default:false"`
NotifyReblog *bool `bun:",nullzero,notnull,default:false"`
NotifyPoll *bool `bun:",nullzero,notnull,default:false"`
NotifyStatus *bool `bun:",nullzero,notnull,default:false"`
NotifyUpdate *bool `bun:",nullzero,notnull,default:false"`
NotifyAdminSignup *bool `bun:",nullzero,notnull,default:false"`
NotifyAdminReport *bool `bun:",nullzero,notnull,default:false"`
NotifyPendingFave *bool `bun:",nullzero,notnull,default:false"`
NotifyPendingReply *bool `bun:",nullzero,notnull,default:false"`
NotifyPendingReblog *bool `bun:",nullzero,notnull,default:false"`
}

View file

@ -184,7 +184,7 @@ func (p *Processor) notifVisible(
// If this is a new local account sign-up, // If this is a new local account sign-up,
// skip normal visibility checking because // skip normal visibility checking because
// origin account won't be confirmed yet. // origin account won't be confirmed yet.
if n.NotificationType == gtsmodel.NotificationSignup { if n.NotificationType == gtsmodel.NotificationAdminSignup {
return true, nil return true, nil
} }

View file

@ -241,7 +241,7 @@ func (suite *FromFediAPITestSuite) TestProcessFave() {
notif := &gtsmodel.Notification{} notif := &gtsmodel.Notification{}
err = testStructs.State.DB.GetWhere(context.Background(), where, notif) err = testStructs.State.DB.GetWhere(context.Background(), where, notif)
suite.NoError(err) suite.NoError(err)
suite.Equal(gtsmodel.NotificationFave, notif.NotificationType) suite.Equal(gtsmodel.NotificationFavourite, notif.NotificationType)
suite.Equal(fave.TargetAccountID, notif.TargetAccountID) suite.Equal(fave.TargetAccountID, notif.TargetAccountID)
suite.Equal(fave.AccountID, notif.OriginAccountID) suite.Equal(fave.AccountID, notif.OriginAccountID)
suite.Equal(fave.StatusID, notif.StatusID) suite.Equal(fave.StatusID, notif.StatusID)
@ -314,7 +314,7 @@ func (suite *FromFediAPITestSuite) TestProcessFaveWithDifferentReceivingAccount(
notif := &gtsmodel.Notification{} notif := &gtsmodel.Notification{}
err = testStructs.State.DB.GetWhere(context.Background(), where, notif) err = testStructs.State.DB.GetWhere(context.Background(), where, notif)
suite.NoError(err) suite.NoError(err)
suite.Equal(gtsmodel.NotificationFave, notif.NotificationType) suite.Equal(gtsmodel.NotificationFavourite, notif.NotificationType)
suite.Equal(fave.TargetAccountID, notif.TargetAccountID) suite.Equal(fave.TargetAccountID, notif.TargetAccountID)
suite.Equal(fave.AccountID, notif.OriginAccountID) suite.Equal(fave.AccountID, notif.OriginAccountID)
suite.Equal(fave.StatusID, notif.StatusID) suite.Equal(fave.StatusID, notif.StatusID)

View file

@ -250,7 +250,7 @@ func (s *Surface) notifyFave(
// notify status author // notify status author
// of fave by account. // of fave by account.
if err := s.Notify(ctx, if err := s.Notify(ctx,
gtsmodel.NotificationFave, gtsmodel.NotificationFavourite,
fave.TargetAccount, fave.TargetAccount,
fave.Account, fave.Account,
fave.StatusID, fave.StatusID,
@ -521,7 +521,7 @@ func (s *Surface) notifySignup(ctx context.Context, newUser *gtsmodel.User) erro
var errs gtserror.MultiError var errs gtserror.MultiError
for _, mod := range modAccounts { for _, mod := range modAccounts {
if err := s.Notify(ctx, if err := s.Notify(ctx,
gtsmodel.NotificationSignup, gtsmodel.NotificationAdminSignup,
mod, mod,
newUser.Account, newUser.Account,
"", "",

View file

@ -75,6 +75,8 @@ EXPECT=$(cat << "EOF"
"user-mute-ids-mem-ratio": 3, "user-mute-ids-mem-ratio": 3,
"user-mute-mem-ratio": 2, "user-mute-mem-ratio": 2,
"visibility-mem-ratio": 2, "visibility-mem-ratio": 2,
"web-push-subscription-ids-mem-ratio": 1,
"web-push-subscription-mem-ratio": 1,
"webfinger-mem-ratio": 0.1 "webfinger-mem-ratio": 0.1
}, },
"config-path": "internal/config/testdata/test.yaml", "config-path": "internal/config/testdata/test.yaml",

View file

@ -19,6 +19,7 @@
import ( import (
"context" "context"
webpushgo "github.com/SherClockHolmes/webpush-go" webpushgo "github.com/SherClockHolmes/webpush-go"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/db/bundb" "github.com/superseriousbusiness/gotosocial/internal/db/bundb"
@ -60,6 +61,8 @@
&gtsmodel.ThreadToStatus{}, &gtsmodel.ThreadToStatus{},
&gtsmodel.User{}, &gtsmodel.User{},
&gtsmodel.UserMute{}, &gtsmodel.UserMute{},
&gtsmodel.VAPIDKeyPair{},
&gtsmodel.WebPushSubscription{},
&gtsmodel.Emoji{}, &gtsmodel.Emoji{},
&gtsmodel.Instance{}, &gtsmodel.Instance{},
&gtsmodel.Notification{}, &gtsmodel.Notification{},
@ -347,6 +350,12 @@ func StandardDBSetup(db db.DB, accounts map[string]*gtsmodel.Account) {
} }
} }
for _, v := range NewTestWebPushSubscriptions() {
if err := db.Put(ctx, v); err != nil {
log.Panic(nil, err)
}
}
for _, v := range NewTestInteractionRequests() { for _, v := range NewTestInteractionRequests() {
if err := db.Put(ctx, v); err != nil { if err := db.Put(ctx, v); err != nil {
log.Panic(nil, err) log.Panic(nil, err)

View file

@ -2475,7 +2475,7 @@ func NewTestNotifications() map[string]*gtsmodel.Notification {
return map[string]*gtsmodel.Notification{ return map[string]*gtsmodel.Notification{
"local_account_1_like": { "local_account_1_like": {
ID: "01F8Q0ANPTWW10DAKTX7BRPBJP", ID: "01F8Q0ANPTWW10DAKTX7BRPBJP",
NotificationType: gtsmodel.NotificationFave, NotificationType: gtsmodel.NotificationFavourite,
CreatedAt: TimeMustParse("2022-05-14T13:21:09+02:00"), CreatedAt: TimeMustParse("2022-05-14T13:21:09+02:00"),
TargetAccountID: "01F8MH1H7YV1Z7D2C8K2730QBF", TargetAccountID: "01F8MH1H7YV1Z7D2C8K2730QBF",
OriginAccountID: "01F8MH17FWEB39HZJ76B6VXSKF", OriginAccountID: "01F8MH17FWEB39HZJ76B6VXSKF",
@ -2484,7 +2484,7 @@ func NewTestNotifications() map[string]*gtsmodel.Notification {
}, },
"local_account_2_like": { "local_account_2_like": {
ID: "01GTS6PRPXJYZBPFFQ56PP0XR8", ID: "01GTS6PRPXJYZBPFFQ56PP0XR8",
NotificationType: gtsmodel.NotificationFave, NotificationType: gtsmodel.NotificationFavourite,
CreatedAt: TimeMustParse("2022-01-13T12:45:01+02:00"), CreatedAt: TimeMustParse("2022-01-13T12:45:01+02:00"),
TargetAccountID: "01F8MH17FWEB39HZJ76B6VXSKF", TargetAccountID: "01F8MH17FWEB39HZJ76B6VXSKF",
OriginAccountID: "01F8MH5NBDF2MV7CTC4Q5128HF", OriginAccountID: "01F8MH5NBDF2MV7CTC4Q5128HF",
@ -2493,7 +2493,7 @@ func NewTestNotifications() map[string]*gtsmodel.Notification {
}, },
"new_signup": { "new_signup": {
ID: "01HTM9TETMB3YQCBKZ7KD4KV02", ID: "01HTM9TETMB3YQCBKZ7KD4KV02",
NotificationType: gtsmodel.NotificationSignup, NotificationType: gtsmodel.NotificationAdminSignup,
CreatedAt: TimeMustParse("2022-06-04T13:12:00Z"), CreatedAt: TimeMustParse("2022-06-04T13:12:00Z"),
TargetAccountID: "01F8MH17FWEB39HZJ76B6VXSKF", TargetAccountID: "01F8MH17FWEB39HZJ76B6VXSKF",
OriginAccountID: "01F8MH0BBE4FHXPH513MBVFHB0", OriginAccountID: "01F8MH0BBE4FHXPH513MBVFHB0",
@ -3476,6 +3476,32 @@ func NewTestUserMutes() map[string]*gtsmodel.UserMute {
return map[string]*gtsmodel.UserMute{} return map[string]*gtsmodel.UserMute{}
} }
func NewTestWebPushSubscriptions() map[string]*gtsmodel.WebPushSubscription {
return map[string]*gtsmodel.WebPushSubscription{
"local_account_1_token_1": {
ID: "01G65Z755AFWAKHE12NY0CQ9FH",
AccountID: "01F8MH1H7YV1Z7D2C8K2730QBF",
TokenID: "01F8MGTQW4DKTDF8SW5CT9HYGA",
Endpoint: "https://example.test/push",
Auth: "cgna/fzrYLDQyPf5hD7IsA==",
P256dh: "BMYVItYVOX+AHBdtA62Q0i6c+F7MV2Gia3aoDr8mvHkuPBNIOuTLDfmFcnBqoZcQk6BtLcIONbxhHpy2R+mYIUY=",
NotifyFollow: util.Ptr(true),
NotifyFollowRequest: util.Ptr(true),
NotifyFavourite: util.Ptr(true),
NotifyMention: util.Ptr(true),
NotifyReblog: util.Ptr(true),
NotifyPoll: util.Ptr(true),
NotifyStatus: util.Ptr(true),
NotifyUpdate: util.Ptr(true),
NotifyAdminSignup: util.Ptr(true),
NotifyAdminReport: util.Ptr(true),
NotifyPendingFave: util.Ptr(true),
NotifyPendingReply: util.Ptr(true),
NotifyPendingReblog: util.Ptr(true),
},
}
}
func NewTestInteractionRequests() map[string]*gtsmodel.InteractionRequest { func NewTestInteractionRequests() map[string]*gtsmodel.InteractionRequest {
return map[string]*gtsmodel.InteractionRequest{ return map[string]*gtsmodel.InteractionRequest{
"admin_account_reply_turtle": { "admin_account_reply_turtle": {