[performance] remove last of relational queries to instead rely on caches (#2091)

This commit is contained in:
kim 2023-08-10 15:08:41 +01:00 committed by GitHub
parent 9770d54237
commit 91cbcd589e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 507 additions and 107 deletions

View file

@ -75,8 +75,8 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
return return
} }
app := &gtsmodel.Application{} app, err := m.db.GetApplicationByClientID(c.Request.Context(), clientID)
if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: clientID}}, app); err != nil { if err != nil {
m.clearSession(s) m.clearSession(s)
safe := fmt.Sprintf("application for %s %s could not be retrieved", sessionClientID, clientID) safe := fmt.Sprintf("application for %s %s could not be retrieved", sessionClientID, clientID)
var errWithCode gtserror.WithCode var errWithCode gtserror.WithCode

View file

@ -107,8 +107,8 @@ func (m *Module) CallbackGETHandler(c *gin.Context) {
return return
} }
app := &gtsmodel.Application{} app, err := m.db.GetApplicationByClientID(c.Request.Context(), sessionClientID)
if err := m.db.GetWhere(c.Request.Context(), []db.Where{{Key: sessionClientID, Value: clientID}}, app); err != nil { if err != nil {
m.clearSession(s) m.clearSession(s)
safe := fmt.Sprintf("application for %s %s could not be retrieved", sessionClientID, clientID) safe := fmt.Sprintf("application for %s %s could not be retrieved", sessionClientID, clientID)
var errWithCode gtserror.WithCode var errWithCode gtserror.WithCode

27
internal/cache/gts.go vendored
View file

@ -32,6 +32,7 @@
type GTSCaches struct { type GTSCaches struct {
account *result.Cache[*gtsmodel.Account] account *result.Cache[*gtsmodel.Account]
accountNote *result.Cache[*gtsmodel.AccountNote] accountNote *result.Cache[*gtsmodel.AccountNote]
application *result.Cache[*gtsmodel.Application]
block *result.Cache[*gtsmodel.Block] block *result.Cache[*gtsmodel.Block]
blockIDs *SliceCache[string] blockIDs *SliceCache[string]
boostOfIDs *SliceCache[string] boostOfIDs *SliceCache[string]
@ -67,6 +68,7 @@ type GTSCaches struct {
func (c *GTSCaches) Init() { func (c *GTSCaches) Init() {
c.initAccount() c.initAccount()
c.initAccountNote() c.initAccountNote()
c.initApplication()
c.initBlock() c.initBlock()
c.initBlockIDs() c.initBlockIDs()
c.initBoostOfIDs() c.initBoostOfIDs()
@ -117,6 +119,11 @@ func (c *GTSCaches) AccountNote() *result.Cache[*gtsmodel.AccountNote] {
return c.accountNote return c.accountNote
} }
// Application provides access to the gtsmodel Application database cache.
func (c *GTSCaches) Application() *result.Cache[*gtsmodel.Application] {
return c.application
}
// Block provides access to the gtsmodel Block (account) database cache. // Block provides access to the gtsmodel Block (account) database cache.
func (c *GTSCaches) Block() *result.Cache[*gtsmodel.Block] { func (c *GTSCaches) Block() *result.Cache[*gtsmodel.Block] {
return c.block return c.block
@ -303,6 +310,26 @@ func (c *GTSCaches) initAccountNote() {
c.accountNote.IgnoreErrors(ignoreErrors) c.accountNote.IgnoreErrors(ignoreErrors)
} }
func (c *GTSCaches) initApplication() {
// Calculate maximum cache size.
cap := calculateResultCacheMax(
sizeofApplication(), // model in-mem size.
config.GetCacheApplicationMemRatio(),
)
log.Infof(nil, "Application cache size = %d", cap)
c.application = result.New([]result.Lookup{
{Name: "ID"},
{Name: "ClientID"},
}, func(a1 *gtsmodel.Application) *gtsmodel.Application {
a2 := new(gtsmodel.Application)
*a2 = *a1
return a2
}, cap)
c.application.IgnoreErrors(ignoreErrors)
}
func (c *GTSCaches) initBlock() { func (c *GTSCaches) initBlock() {
// Calculate maximum cache size. // Calculate maximum cache size.
cap := calculateResultCacheMax( cap := calculateResultCacheMax(

View file

@ -155,6 +155,7 @@ func totalOfRatios() float64 {
return 0 + return 0 +
config.GetCacheAccountMemRatio() + config.GetCacheAccountMemRatio() +
config.GetCacheAccountNoteMemRatio() + config.GetCacheAccountNoteMemRatio() +
config.GetCacheApplicationMemRatio() +
config.GetCacheBlockMemRatio() + config.GetCacheBlockMemRatio() +
config.GetCacheBlockIDsMemRatio() + config.GetCacheBlockIDsMemRatio() +
config.GetCacheBoostOfIDsMemRatio() + config.GetCacheBoostOfIDsMemRatio() +
@ -217,7 +218,7 @@ func sizeofAccount() uintptr {
SilencedAt: time.Now(), SilencedAt: time.Now(),
SuspendedAt: time.Now(), SuspendedAt: time.Now(),
HideCollections: func() *bool { ok := true; return &ok }(), HideCollections: func() *bool { ok := true; return &ok }(),
SuspensionOrigin: "", SuspensionOrigin: exampleID,
EnableRSS: func() *bool { ok := true; return &ok }(), EnableRSS: func() *bool { ok := true; return &ok }(),
})) }))
} }
@ -231,6 +232,20 @@ func sizeofAccountNote() uintptr {
})) }))
} }
func sizeofApplication() uintptr {
return uintptr(size.Of(&gtsmodel.Application{
ID: exampleID,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
Name: exampleUsername,
Website: exampleURI,
RedirectURI: exampleURI,
ClientID: exampleID,
ClientSecret: exampleID,
Scopes: exampleTextSmall,
}))
}
func sizeofBlock() uintptr { func sizeofBlock() uintptr {
return uintptr(size.Of(&gtsmodel.Block{ return uintptr(size.Of(&gtsmodel.Block{
ID: exampleID, ID: exampleID,
@ -500,5 +515,31 @@ func sizeofVisibility() uintptr {
} }
func sizeofUser() uintptr { func sizeofUser() uintptr {
return uintptr(size.Of(&gtsmodel.User{})) return uintptr(size.Of(&gtsmodel.User{
ID: exampleID,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
Email: exampleURI,
AccountID: exampleID,
EncryptedPassword: exampleTextSmall,
CurrentSignInAt: time.Now(),
LastSignInAt: time.Now(),
InviteID: exampleID,
ChosenLanguages: []string{"en", "fr", "jp"},
FilteredLanguages: []string{"en", "fr", "jp"},
Locale: "en",
CreatedByApplicationID: exampleID,
LastEmailedAt: time.Now(),
ConfirmationToken: exampleTextSmall,
ConfirmationSentAt: time.Now(),
ConfirmedAt: time.Now(),
UnconfirmedEmail: exampleURI,
Moderator: func() *bool { ok := true; return &ok }(),
Admin: func() *bool { ok := true; return &ok }(),
Disabled: func() *bool { ok := true; return &ok }(),
Approved: func() *bool { ok := true; return &ok }(),
ResetPasswordToken: exampleTextSmall,
ResetPasswordSentAt: time.Now(),
ExternalID: exampleID,
}))
} }

View file

@ -178,6 +178,7 @@ type CacheConfiguration struct {
MemoryTarget bytesize.Size `name:"memory-target"` MemoryTarget bytesize.Size `name:"memory-target"`
AccountMemRatio float64 `name:"account-mem-ratio"` AccountMemRatio float64 `name:"account-mem-ratio"`
AccountNoteMemRatio float64 `name:"account-note-mem-ratio"` AccountNoteMemRatio float64 `name:"account-note-mem-ratio"`
ApplicationMemRatio float64 `name:"application-mem-ratio"`
BlockMemRatio float64 `name:"block-mem-ratio"` BlockMemRatio float64 `name:"block-mem-ratio"`
BlockIDsMemRatio float64 `name:"block-mem-ratio"` BlockIDsMemRatio float64 `name:"block-mem-ratio"`
BoostOfIDsMemRatio float64 `name:"boost-of-ids-mem-ratio"` BoostOfIDsMemRatio float64 `name:"boost-of-ids-mem-ratio"`

View file

@ -147,6 +147,7 @@
// be able to make some more sense :D // be able to make some more sense :D
AccountMemRatio: 18, AccountMemRatio: 18,
AccountNoteMemRatio: 0.1, AccountNoteMemRatio: 0.1,
ApplicationMemRatio: 0.1,
BlockMemRatio: 3, BlockMemRatio: 3,
BlockIDsMemRatio: 3, BlockIDsMemRatio: 3,
BoostOfIDsMemRatio: 3, BoostOfIDsMemRatio: 3,
@ -170,7 +171,7 @@
StatusFaveIDsMemRatio: 3, StatusFaveIDsMemRatio: 3,
TagMemRatio: 3, TagMemRatio: 3,
TombstoneMemRatio: 2, TombstoneMemRatio: 2,
UserMemRatio: 0.1, UserMemRatio: 0.25,
WebfingerMemRatio: 0.1, WebfingerMemRatio: 0.1,
VisibilityMemRatio: 2, VisibilityMemRatio: 2,
}, },

View file

@ -2499,6 +2499,31 @@ func GetCacheAccountNoteMemRatio() float64 { return global.GetCacheAccountNoteMe
// SetCacheAccountNoteMemRatio safely sets the value for global configuration 'Cache.AccountNoteMemRatio' field // SetCacheAccountNoteMemRatio safely sets the value for global configuration 'Cache.AccountNoteMemRatio' field
func SetCacheAccountNoteMemRatio(v float64) { global.SetCacheAccountNoteMemRatio(v) } func SetCacheAccountNoteMemRatio(v float64) { global.SetCacheAccountNoteMemRatio(v) }
// GetCacheApplicationMemRatio safely fetches the Configuration value for state's 'Cache.ApplicationMemRatio' field
func (st *ConfigState) GetCacheApplicationMemRatio() (v float64) {
st.mutex.RLock()
v = st.config.Cache.ApplicationMemRatio
st.mutex.RUnlock()
return
}
// SetCacheApplicationMemRatio safely sets the Configuration value for state's 'Cache.ApplicationMemRatio' field
func (st *ConfigState) SetCacheApplicationMemRatio(v float64) {
st.mutex.Lock()
defer st.mutex.Unlock()
st.config.Cache.ApplicationMemRatio = v
st.reloadToViper()
}
// CacheApplicationMemRatioFlag returns the flag name for the 'Cache.ApplicationMemRatio' field
func CacheApplicationMemRatioFlag() string { return "cache-application-mem-ratio" }
// GetCacheApplicationMemRatio safely fetches the value for global configuration 'Cache.ApplicationMemRatio' field
func GetCacheApplicationMemRatio() float64 { return global.GetCacheApplicationMemRatio() }
// SetCacheApplicationMemRatio safely sets the value for global configuration 'Cache.ApplicationMemRatio' field
func SetCacheApplicationMemRatio(v float64) { global.SetCacheApplicationMemRatio(v) }
// GetCacheBlockMemRatio safely fetches the Configuration value for state's 'Cache.BlockMemRatio' field // GetCacheBlockMemRatio safely fetches the Configuration value for state's 'Cache.BlockMemRatio' field
func (st *ConfigState) GetCacheBlockMemRatio() (v float64) { func (st *ConfigState) GetCacheBlockMemRatio() (v float64) {
st.mutex.RLock() st.mutex.RLock()

View file

@ -0,0 +1,38 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package db
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
type Application interface {
// GetApplicationByID fetches the application from the database with corresponding ID value.
GetApplicationByID(ctx context.Context, id string) (*gtsmodel.Application, error)
// GetApplicationByClientID fetches the application from the database with corresponding client_id value.
GetApplicationByClientID(ctx context.Context, clientID string) (*gtsmodel.Application, error)
// PutApplication places the new application in the database, erroring on non-unique ID or client_id.
PutApplication(ctx context.Context, app *gtsmodel.Application) error
// DeleteApplicationByClientID deletes the application with corresponding client_id value from the database.
DeleteApplicationByClientID(ctx context.Context, clientID string) error
}

View file

@ -0,0 +1,97 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/uptrace/bun"
)
type applicationDB struct {
db *WrappedDB
state *state.State
}
func (a *applicationDB) GetApplicationByID(ctx context.Context, id string) (*gtsmodel.Application, error) {
return a.getApplication(
ctx,
"ID",
func(app *gtsmodel.Application) error {
return a.db.NewSelect().Model(app).Where("? = ?", bun.Ident("id"), id).Scan(ctx)
},
id,
)
}
func (a *applicationDB) GetApplicationByClientID(ctx context.Context, clientID string) (*gtsmodel.Application, error) {
return a.getApplication(
ctx,
"ClientID",
func(app *gtsmodel.Application) error {
return a.db.NewSelect().Model(app).Where("? = ?", bun.Ident("client_id"), clientID).Scan(ctx)
},
clientID,
)
}
func (a *applicationDB) getApplication(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Application) error, keyParts ...any) (*gtsmodel.Application, error) {
return a.state.Caches.GTS.Application().Load(lookup, func() (*gtsmodel.Application, error) {
var app gtsmodel.Application
// Not cached! Perform database query.
if err := dbQuery(&app); err != nil {
return nil, a.db.ProcessError(err)
}
return &app, nil
}, keyParts...)
}
func (a *applicationDB) PutApplication(ctx context.Context, app *gtsmodel.Application) error {
return a.state.Caches.GTS.Application().Store(app, func() error {
_, err := a.db.NewInsert().Model(app).Exec(ctx)
return a.db.ProcessError(err)
})
}
func (a *applicationDB) DeleteApplicationByClientID(ctx context.Context, clientID string) error {
// Attempt to delete application.
if _, err := a.db.NewDelete().
Table("applications").
Where("? = ?", bun.Ident("client_id"), clientID).
Exec(ctx); err != nil {
return a.db.ProcessError(err)
}
// NOTE about further side effects:
//
// We don't need to handle updating any statuses or users
// (both of which may contain refs to applications), as
// DeleteApplication__() is only ever called during an
// account deletion, which handles deletion of the user
// and all their statuses already.
//
// Clear application from the cache.
a.state.Caches.GTS.Application().Invalidate("ClientID", clientID)
return nil
}

View file

@ -0,0 +1,128 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb_test
import (
"context"
"errors"
"reflect"
"testing"
"time"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
type ApplicationTestSuite struct {
BunDBStandardTestSuite
}
func (suite *ApplicationTestSuite) TestGetApplicationBy() {
t := suite.T()
// Create a new context for this test.
ctx, cncl := context.WithCancel(context.Background())
defer cncl()
// Sentinel error to mark avoiding a test case.
sentinelErr := errors.New("sentinel")
// isEqual checks if 2 application models are equal.
isEqual := func(a1, a2 gtsmodel.Application) bool {
// Clear database-set fields.
a1.CreatedAt = time.Time{}
a2.CreatedAt = time.Time{}
a1.UpdatedAt = time.Time{}
a2.UpdatedAt = time.Time{}
return reflect.DeepEqual(a1, a2)
}
for _, app := range suite.testApplications {
for lookup, dbfunc := range map[string]func() (*gtsmodel.Application, error){
"id": func() (*gtsmodel.Application, error) {
return suite.db.GetApplicationByID(ctx, app.ID)
},
"client_id": func() (*gtsmodel.Application, error) {
return suite.db.GetApplicationByClientID(ctx, app.ClientID)
},
} {
// Clear database caches.
suite.state.Caches.Init()
t.Logf("checking database lookup %q", lookup)
// Perform database function.
checkApp, err := dbfunc()
if err != nil {
if err == sentinelErr {
continue
}
t.Errorf("error encountered for database lookup %q: %v", lookup, err)
continue
}
// Check received application data.
if !isEqual(*checkApp, *app) {
t.Errorf("application does not contain expected data: %+v", checkApp)
continue
}
}
}
}
func (suite *ApplicationTestSuite) TestDeleteApplicationBy() {
t := suite.T()
// Create a new context for this test.
ctx, cncl := context.WithCancel(context.Background())
defer cncl()
for _, app := range suite.testApplications {
for lookup, dbfunc := range map[string]func() error{
"client_id": func() error {
return suite.db.DeleteApplicationByClientID(ctx, app.ClientID)
},
} {
// Clear database caches.
suite.state.Caches.Init()
t.Logf("checking database lookup %q", lookup)
// Perform database function.
err := dbfunc()
if err != nil {
t.Errorf("error encountered for database lookup %q: %v", lookup, err)
continue
}
// Ensure this application has been deleted and cache cleared.
if _, err := suite.db.GetApplicationByID(ctx, app.ID); err != db.ErrNoEntries {
t.Errorf("application does not appear to have been deleted %q: %v", lookup, err)
continue
}
}
}
}
func TestApplicationTestSuite(t *testing.T) {
suite.Run(t, new(ApplicationTestSuite))
}

View file

@ -60,6 +60,7 @@
type DBService struct { type DBService struct {
db.Account db.Account
db.Admin db.Admin
db.Application
db.Basic db.Basic
db.Domain db.Domain
db.Emoji db.Emoji
@ -168,6 +169,10 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
db: db, db: db,
state: state, state: state,
}, },
Application: &applicationDB{
db: db,
state: state,
},
Basic: &basicDB{ Basic: &basicDB{
db: db, db: db,
}, },

View file

@ -37,19 +37,12 @@ type statusDB struct {
state *state.State state *state.State
} }
func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
return s.db.
NewSelect().
Model(status).
Relation("CreatedWithApplication")
}
func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, error) { func (s *statusDB) GetStatusByID(ctx context.Context, id string) (*gtsmodel.Status, error) {
return s.getStatus( return s.getStatus(
ctx, ctx,
"ID", "ID",
func(status *gtsmodel.Status) error { func(status *gtsmodel.Status) error {
return s.newStatusQ(status).Where("? = ?", bun.Ident("status.id"), id).Scan(ctx) return s.db.NewSelect().Model(status).Where("? = ?", bun.Ident("status.id"), id).Scan(ctx)
}, },
id, id,
) )
@ -78,7 +71,7 @@ func (s *statusDB) GetStatusByURI(ctx context.Context, uri string) (*gtsmodel.St
ctx, ctx,
"URI", "URI",
func(status *gtsmodel.Status) error { func(status *gtsmodel.Status) error {
return s.newStatusQ(status).Where("? = ?", bun.Ident("status.uri"), uri).Scan(ctx) return s.db.NewSelect().Model(status).Where("? = ?", bun.Ident("status.uri"), uri).Scan(ctx)
}, },
uri, uri,
) )
@ -89,7 +82,7 @@ func (s *statusDB) GetStatusByURL(ctx context.Context, url string) (*gtsmodel.St
ctx, ctx,
"URL", "URL",
func(status *gtsmodel.Status) error { func(status *gtsmodel.Status) error {
return s.newStatusQ(status).Where("? = ?", bun.Ident("status.url"), url).Scan(ctx) return s.db.NewSelect().Model(status).Where("? = ?", bun.Ident("status.url"), url).Scan(ctx)
}, },
url, url,
) )
@ -100,7 +93,7 @@ func (s *statusDB) GetStatusBoost(ctx context.Context, boostOfID string, byAccou
ctx, ctx,
"BoostOfID.AccountID", "BoostOfID.AccountID",
func(status *gtsmodel.Status) error { func(status *gtsmodel.Status) error {
return s.newStatusQ(status). return s.db.NewSelect().Model(status).
Where("status.boost_of_id = ?", boostOfID). Where("status.boost_of_id = ?", boostOfID).
Where("status.account_id = ?", byAccountID). Where("status.account_id = ?", byAccountID).
@ -264,6 +257,17 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status)
} }
} }
if status.CreatedWithApplicationID != "" && status.CreatedWithApplication == nil {
// Populate the status' expected CreatedWithApplication (not always set).
status.CreatedWithApplication, err = s.state.DB.GetApplicationByID(
ctx, // these are already barebones
status.CreatedWithApplicationID,
)
if err != nil {
errs.Appendf("error populating status application: %w", err)
}
}
return errs.Combine() return errs.Combine()
} }

View file

@ -24,6 +24,7 @@
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/uptrace/bun" "github.com/uptrace/bun"
@ -35,107 +36,125 @@ type userDB struct {
} }
func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, error) { func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, error) {
return u.state.Caches.GTS.User().Load("ID", func() (*gtsmodel.User, error) { return u.getUser(
var user gtsmodel.User ctx,
"ID",
q := u.db. func(user *gtsmodel.User) error {
NewSelect(). return u.db.NewSelect().Model(user).Where("? = ?", bun.Ident("id"), id).Scan(ctx)
Model(&user). },
Relation("Account"). id,
Where("? = ?", bun.Ident("user.id"), id) )
if err := q.Scan(ctx); err != nil {
return nil, u.db.ProcessError(err)
} }
return &user, nil func (u *userDB) GetUsersByIDs(ctx context.Context, ids []string) ([]*gtsmodel.User, error) {
}, id) var (
users = make([]*gtsmodel.User, 0, len(ids))
// Collect errors instead of
// returning early on any.
errs gtserror.MultiError
)
for _, id := range ids {
// Attempt to fetch user from DB.
user, err := u.GetUserByID(ctx, id)
if err != nil {
errs.Appendf("error getting user %s: %w", id, err)
continue
}
// Append user to return slice.
users = append(users, user)
}
return users, errs.Combine()
} }
func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, error) { func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, error) {
return u.state.Caches.GTS.User().Load("AccountID", func() (*gtsmodel.User, error) { return u.getUser(
var user gtsmodel.User ctx,
"AccountID",
q := u.db. func(user *gtsmodel.User) error {
NewSelect(). return u.db.NewSelect().Model(user).Where("? = ?", bun.Ident("account_id"), accountID).Scan(ctx)
Model(&user). },
Relation("Account"). accountID,
Where("? = ?", bun.Ident("user.account_id"), accountID) )
if err := q.Scan(ctx); err != nil {
return nil, u.db.ProcessError(err)
} }
return &user, nil func (u *userDB) GetUserByEmailAddress(ctx context.Context, email string) (*gtsmodel.User, error) {
}, accountID) return u.getUser(
} ctx,
"Email",
func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, error) { func(user *gtsmodel.User) error {
return u.state.Caches.GTS.User().Load("Email", func() (*gtsmodel.User, error) { return u.db.NewSelect().Model(user).Where("? = ?", bun.Ident("email"), email).Scan(ctx)
var user gtsmodel.User },
email,
q := u.db. )
NewSelect().
Model(&user).
Relation("Account").
Where("? = ?", bun.Ident("user.email"), emailAddress)
if err := q.Scan(ctx); err != nil {
return nil, u.db.ProcessError(err)
}
return &user, nil
}, emailAddress)
} }
func (u *userDB) GetUserByExternalID(ctx context.Context, id string) (*gtsmodel.User, error) { func (u *userDB) GetUserByExternalID(ctx context.Context, id string) (*gtsmodel.User, error) {
return u.state.Caches.GTS.User().Load("ExternalID", func() (*gtsmodel.User, error) { return u.getUser(
ctx,
"ExternalID",
func(user *gtsmodel.User) error {
return u.db.NewSelect().Model(user).Where("? = ?", bun.Ident("external_id"), id).Scan(ctx)
},
id,
)
}
func (u *userDB) GetUserByConfirmationToken(ctx context.Context, token string) (*gtsmodel.User, error) {
return u.getUser(
ctx,
"ConfirmationToken",
func(user *gtsmodel.User) error {
return u.db.NewSelect().Model(user).Where("? = ?", bun.Ident("confirmation_token"), token).Scan(ctx)
},
token,
)
}
func (u *userDB) getUser(ctx context.Context, lookup string, dbQuery func(*gtsmodel.User) error, keyParts ...any) (*gtsmodel.User, error) {
// Fetch user from database cache with loader callback.
user, err := u.state.Caches.GTS.User().Load(lookup, func() (*gtsmodel.User, error) {
var user gtsmodel.User var user gtsmodel.User
q := u.db. // Not cached! perform database query.
NewSelect(). if err := dbQuery(&user); err != nil {
Model(&user).
Relation("Account").
Where("? = ?", bun.Ident("user.external_id"), id)
if err := q.Scan(ctx); err != nil {
return nil, u.db.ProcessError(err) return nil, u.db.ProcessError(err)
} }
return &user, nil return &user, nil
}, id) }, keyParts...)
if err != nil {
return nil, err
} }
func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, error) { // Fetch the related account model for this user.
return u.state.Caches.GTS.User().Load("ConfirmationToken", func() (*gtsmodel.User, error) { user.Account, err = u.state.DB.GetAccountByID(
var user gtsmodel.User gtscontext.SetBarebones(ctx),
user.AccountID,
q := u.db. )
NewSelect(). if err != nil {
Model(&user). return nil, gtserror.Newf("error populating user account: %w", err)
Relation("Account").
Where("? = ?", bun.Ident("user.confirmation_token"), confirmationToken)
if err := q.Scan(ctx); err != nil {
return nil, u.db.ProcessError(err)
} }
return &user, nil return user, nil
}, confirmationToken)
} }
func (u *userDB) GetAllUsers(ctx context.Context) ([]*gtsmodel.User, error) { func (u *userDB) GetAllUsers(ctx context.Context) ([]*gtsmodel.User, error) {
var users []*gtsmodel.User var userIDs []string
q := u.db.
NewSelect().
Model(&users).
Relation("Account")
if err := q.Scan(ctx); err != nil { // Scan all user IDs into slice.
if err := u.db.NewSelect().
Table("users").
Column("id").
Scan(ctx, &userIDs); err != nil {
return nil, u.db.ProcessError(err) return nil, u.db.ProcessError(err)
} }
return users, nil // Transform user IDs into user slice.
return u.GetUsersByIDs(ctx, userIDs)
} }
func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) error { func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) error {

View file

@ -26,6 +26,7 @@
type DB interface { type DB interface {
Account Account
Admin Admin
Application
Basic Basic
Domain Domain
Emoji Emoji

View file

@ -22,7 +22,6 @@
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/oauth2/v4" "github.com/superseriousbusiness/oauth2/v4"
@ -125,8 +124,8 @@ func TokenCheck(dbConn db.DB, validateBearerToken func(r *http.Request) (oauth2.
log.Tracef(ctx, "authenticated client %s with bearer token, scope is %s", clientID, ti.GetScope()) log.Tracef(ctx, "authenticated client %s with bearer token, scope is %s", clientID, ti.GetScope())
// fetch app for this token // fetch app for this token
app := &gtsmodel.Application{} app, err := dbConn.GetApplicationByClientID(ctx, clientID)
if err := dbConn.GetWhere(ctx, []db.Where{{Key: "client_id", Value: clientID}}, app); err != nil { if err != nil {
if err != db.ErrNoEntries { if err != db.ErrNoEntries {
log.Errorf(ctx, "database error looking for application with clientID %s: %s", clientID, err) log.Errorf(ctx, "database error looking for application with clientID %s: %s", clientID, err)
return return
@ -134,6 +133,7 @@ func TokenCheck(dbConn db.DB, validateBearerToken func(r *http.Request) (oauth2.
log.Warnf(ctx, "no app found for client %s", clientID) log.Warnf(ctx, "no app found for client %s", clientID)
return return
} }
c.Set(oauth.SessionAuthorizedApplication, app) c.Set(oauth.SessionAuthorizedApplication, app)
} }
} }

View file

@ -46,12 +46,6 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
}...) }...)
l.Trace("beginning account delete process") l.Trace("beginning account delete process")
if account.IsLocal() {
if err := p.deleteUserAndTokensForAccount(ctx, account); err != nil {
return gtserror.NewErrorInternalError(err)
}
}
if err := p.deleteAccountFollows(ctx, account); err != nil { if err := p.deleteAccountFollows(ctx, account); err != nil {
return gtserror.NewErrorInternalError(err) return gtserror.NewErrorInternalError(err)
} }
@ -72,6 +66,14 @@ func (p *Processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
return gtserror.NewErrorInternalError(err) return gtserror.NewErrorInternalError(err)
} }
if account.IsLocal() {
// we tokens, applications and clients for account as one of the last
// stages during deletion, as other database models rely on these.
if err := p.deleteUserAndTokensForAccount(ctx, account); err != nil {
return gtserror.NewErrorInternalError(err)
}
}
// To prevent the account being created again, // To prevent the account being created again,
// stubbify it and update it in the db. // stubbify it and update it in the db.
// The account will not be deleted, but it // The account will not be deleted, but it
@ -129,7 +131,7 @@ func (p *Processor) deleteUserAndTokensForAccount(ctx context.Context, account *
} }
// Delete any OAuth applications associated with this token. // Delete any OAuth applications associated with this token.
if err := p.state.DB.DeleteWhere(ctx, []db.Where{{Key: "client_id", Value: t.ClientID}}, &[]*gtsmodel.Application{}); err != nil { if err := p.state.DB.DeleteApplicationByClientID(ctx, t.ClientID); err != nil {
return gtserror.Newf("db error deleting application: %w", err) return gtserror.Newf("db error deleting application: %w", err)
} }
@ -305,7 +307,17 @@ func (p *Processor) deleteAccountStatuses(ctx context.Context, account *gtsmodel
statusLoop: statusLoop:
for { for {
// Page through account's statuses. // Page through account's statuses.
statuses, err = p.state.DB.GetAccountStatuses(ctx, account.ID, deleteSelectLimit, false, false, maxID, "", false, false) statuses, err = p.state.DB.GetAccountStatuses(
ctx,
account.ID,
deleteSelectLimit,
false,
false,
maxID,
"",
false,
false,
)
if err != nil && !errors.Is(err, db.ErrNoEntries) { if err != nil && !errors.Is(err, db.ErrNoEntries) {
// Make sure we don't have a real error. // Make sure we don't have a real error.
return err return err

View file

@ -61,7 +61,7 @@ func (p *Processor) AppCreate(ctx context.Context, authed *oauth.Auth, form *api
} }
// chuck it in the db // chuck it in the db
if err := p.state.DB.Put(ctx, app); err != nil { if err := p.state.DB.PutApplication(ctx, app); err != nil {
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
} }

View file

@ -699,8 +699,8 @@ func (c *converter) StatusToAPIStatus(ctx context.Context, s *gtsmodel.Status, r
} }
if appID := s.CreatedWithApplicationID; appID != "" { if appID := s.CreatedWithApplicationID; appID != "" {
app := &gtsmodel.Application{} app, err := c.db.GetApplicationByID(ctx, appID)
if err := c.db.GetByID(ctx, appID, app); err != nil { if err != nil {
return nil, fmt.Errorf("error getting application %s: %w", appID, err) return nil, fmt.Errorf("error getting application %s: %w", appID, err)
} }

View file

@ -20,6 +20,7 @@ EXPECT=$(cat << "EOF"
"cache": { "cache": {
"account-mem-ratio": 18, "account-mem-ratio": 18,
"account-note-mem-ratio": 0.1, "account-note-mem-ratio": 0.1,
"application-mem-ratio": 0.1,
"block-mem-ratio": 3, "block-mem-ratio": 3,
"boost-of-ids-mem-ratio": 3, "boost-of-ids-mem-ratio": 3,
"emoji-category-mem-ratio": 0.1, "emoji-category-mem-ratio": 0.1,
@ -43,7 +44,7 @@ EXPECT=$(cat << "EOF"
"status-mem-ratio": 18, "status-mem-ratio": 18,
"tag-mem-ratio": 3, "tag-mem-ratio": 3,
"tombstone-mem-ratio": 2, "tombstone-mem-ratio": 2,
"user-mem-ratio": 0.1, "user-mem-ratio": 0.25,
"visibility-mem-ratio": 2, "visibility-mem-ratio": 2,
"webfinger-mem-ratio": 0.1 "webfinger-mem-ratio": 0.1
}, },