[performance] add user cache and database (#879)

* go fmt

* add + use user cache and database

* fix import

* update tests

* remove unused relation
This commit is contained in:
tobi 2022-10-03 10:46:11 +02:00 committed by GitHub
parent f7af7c061c
commit 56f53a2a6f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 490 additions and 70 deletions

View file

@ -26,9 +26,7 @@
"github.com/superseriousbusiness/gotosocial/cmd/gotosocial/action" "github.com/superseriousbusiness/gotosocial/cmd/gotosocial/action"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/db/bundb" "github.com/superseriousbusiness/gotosocial/internal/db/bundb"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/validate" "github.com/superseriousbusiness/gotosocial/internal/validate"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
@ -92,8 +90,8 @@
return err return err
} }
u := &gtsmodel.User{} u, err := dbConn.GetUserByAccountID(ctx, a.ID)
if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil { if err != nil {
return err return err
} }
@ -130,8 +128,8 @@
return err return err
} }
u := &gtsmodel.User{} u, err := dbConn.GetUserByAccountID(ctx, a.ID)
if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil { if err != nil {
return err return err
} }
@ -139,7 +137,7 @@
admin := true admin := true
u.Admin = &admin u.Admin = &admin
u.UpdatedAt = time.Now() u.UpdatedAt = time.Now()
if err := dbConn.UpdateByPrimaryKey(ctx, u, updatingColumns...); err != nil { if _, err := dbConn.UpdateUser(ctx, u, updatingColumns...); err != nil {
return err return err
} }
@ -166,8 +164,8 @@
return err return err
} }
u := &gtsmodel.User{} u, err := dbConn.GetUserByAccountID(ctx, a.ID)
if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil { if err != nil {
return err return err
} }
@ -175,7 +173,7 @@
admin := false admin := false
u.Admin = &admin u.Admin = &admin
u.UpdatedAt = time.Now() u.UpdatedAt = time.Now()
if err := dbConn.UpdateByPrimaryKey(ctx, u, updatingColumns...); err != nil { if _, err := dbConn.UpdateUser(ctx, u, updatingColumns...); err != nil {
return err return err
} }
@ -202,8 +200,8 @@
return err return err
} }
u := &gtsmodel.User{} u, err := dbConn.GetUserByAccountID(ctx, a.ID)
if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil { if err != nil {
return err return err
} }
@ -211,7 +209,7 @@
disabled := true disabled := true
u.Disabled = &disabled u.Disabled = &disabled
u.UpdatedAt = time.Now() u.UpdatedAt = time.Now()
if err := dbConn.UpdateByPrimaryKey(ctx, u, updatingColumns...); err != nil { if _, err := dbConn.UpdateUser(ctx, u, updatingColumns...); err != nil {
return err return err
} }
@ -252,8 +250,8 @@
return err return err
} }
u := &gtsmodel.User{} u, err := dbConn.GetUserByAccountID(ctx, a.ID)
if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil { if err != nil {
return err return err
} }
@ -265,7 +263,7 @@
updatingColumns := []string{"encrypted_password", "updated_at"} updatingColumns := []string{"encrypted_password", "updated_at"}
u.EncryptedPassword = string(pw) u.EncryptedPassword = string(pw)
u.UpdatedAt = time.Now() u.UpdatedAt = time.Now()
if err := dbConn.UpdateByPrimaryKey(ctx, u, updatingColumns...); err != nil { if _, err := dbConn.UpdateUser(ctx, u, updatingColumns...); err != nil {
return err return err
} }

View file

@ -94,8 +94,8 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
return return
} }
user := &gtsmodel.User{} user, err := m.db.GetUserByID(c.Request.Context(), userID)
if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil { if err != nil {
m.clearSession(s) m.clearSession(s)
safe := fmt.Sprintf("user with id %s could not be retrieved", userID) safe := fmt.Sprintf("user with id %s could not be retrieved", userID)
var errWithCode gtserror.WithCode var errWithCode gtserror.WithCode
@ -213,8 +213,8 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) {
return return
} }
user := &gtsmodel.User{} user, err := m.db.GetUserByID(c.Request.Context(), userID)
if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil { if err != nil {
m.clearSession(s) m.clearSession(s)
safe := fmt.Sprintf("user with id %s could not be retrieved", userID) safe := fmt.Sprintf("user with id %s could not be retrieved", userID)
var errWithCode gtserror.WithCode var errWithCode gtserror.WithCode

View file

@ -76,8 +76,11 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() {
doTest := func(testCase authorizeHandlerTestCase) { doTest := func(testCase authorizeHandlerTestCase) {
ctx, recorder := suite.newContext(http.MethodGet, auth.OauthAuthorizePath, nil, "") ctx, recorder := suite.newContext(http.MethodGet, auth.OauthAuthorizePath, nil, "")
user := suite.testUsers["unconfirmed_account"] user := &gtsmodel.User{}
account := suite.testAccounts["unconfirmed_account"] account := &gtsmodel.Account{}
*user = *suite.testUsers["unconfirmed_account"]
*account = *suite.testAccounts["unconfirmed_account"]
testSession := sessions.Default(ctx) testSession := sessions.Default(ctx)
testSession.Set(sessionUserID, user.ID) testSession.Set(sessionUserID, user.ID)
@ -91,8 +94,7 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() {
testCase.description = fmt.Sprintf("%s, %t, %s", user.Email, *user.Disabled, account.SuspendedAt) testCase.description = fmt.Sprintf("%s, %t, %s", user.Email, *user.Disabled, account.SuspendedAt)
updatingColumns = append(updatingColumns, "updated_at") updatingColumns = append(updatingColumns, "updated_at")
user.UpdatedAt = time.Now() _, err := suite.db.UpdateUser(context.Background(), user, updatingColumns...)
err := suite.db.UpdateByPrimaryKey(context.Background(), user, updatingColumns...)
suite.NoError(err) suite.NoError(err)
_, err = suite.db.UpdateAccount(context.Background(), account) _, err = suite.db.UpdateAccount(context.Background(), account)
suite.NoError(err) suite.NoError(err)

View file

@ -134,8 +134,7 @@ func (m *Module) parseUserFromClaims(ctx context.Context, claims *oidc.Claims, i
// see if we already have a user for this email address // see if we already have a user for this email address
// if so, we don't need to continue + create one // if so, we don't need to continue + create one
user := &gtsmodel.User{} user, err := m.db.GetUserByEmailAddress(ctx, claims.Email)
err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: claims.Email}}, user)
if err == nil { if err == nil {
return user, nil return user, nil
} }

View file

@ -28,9 +28,7 @@
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/superseriousbusiness/gotosocial/internal/api" "github.com/superseriousbusiness/gotosocial/internal/api"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
@ -119,8 +117,8 @@ func (m *Module) ValidatePassword(ctx context.Context, email string, password st
return incorrectPassword(err) return incorrectPassword(err)
} }
user := &gtsmodel.User{} user, err := m.db.GetUserByEmailAddress(ctx, email)
if err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: email}}, user); err != nil { if err != nil {
err := fmt.Errorf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err) err := fmt.Errorf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err)
return incorrectPassword(err) return incorrectPassword(err)
} }

View file

@ -52,8 +52,8 @@ func (m *Module) TokenCheck(c *gin.Context) {
log.Tracef("authenticated user %s with bearer token, scope is %s", userID, ti.GetScope()) log.Tracef("authenticated user %s with bearer token, scope is %s", userID, ti.GetScope())
// fetch user for this token // fetch user for this token
user := &gtsmodel.User{} user, err := m.db.GetUserByID(ctx, userID)
if err := m.db.GetByID(ctx, userID, user); err != nil { if err != nil {
if err != db.ErrNoEntries { if err != db.ErrNoEntries {
log.Errorf("database error looking for user with id %s: %s", userID, err) log.Errorf("database error looking for user with id %s: %s", userID, err)
return return
@ -80,6 +80,7 @@ func (m *Module) TokenCheck(c *gin.Context) {
c.Set(oauth.SessionAuthorizedUser, user) c.Set(oauth.SessionAuthorizedUser, user)
// fetch account for this token // fetch account for this token
if user.Account == nil {
acct, err := m.db.GetAccountByID(ctx, user.AccountID) acct, err := m.db.GetAccountByID(ctx, user.AccountID)
if err != nil { if err != nil {
if err != db.ErrNoEntries { if err != db.ErrNoEntries {
@ -89,13 +90,15 @@ func (m *Module) TokenCheck(c *gin.Context) {
log.Warnf("no account found for userID %s", userID) log.Warnf("no account found for userID %s", userID)
return return
} }
user.Account = acct
}
if !acct.SuspendedAt.IsZero() { if !user.Account.SuspendedAt.IsZero() {
log.Warnf("authenticated user %s's account (accountId=%s) has been suspended", userID, user.AccountID) log.Warnf("authenticated user %s's account (accountId=%s) has been suspended", userID, user.AccountID)
return return
} }
c.Set(oauth.SessionAuthorizedAccount, acct) c.Set(oauth.SessionAuthorizedAccount, user.Account)
} }
// check for application token // check for application token

141
internal/cache/user.go vendored Normal file
View file

@ -0,0 +1,141 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package cache
import (
"time"
"codeberg.org/gruf/go-cache/v2"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// UserCache is a cache wrapper to provide lookups for gtsmodel.User
type UserCache struct {
cache cache.LookupCache[string, string, *gtsmodel.User]
}
// NewUserCache returns a new instantiated UserCache object
func NewUserCache() *UserCache {
c := &UserCache{}
c.cache = cache.NewLookup(cache.LookupCfg[string, string, *gtsmodel.User]{
RegisterLookups: func(lm *cache.LookupMap[string, string]) {
lm.RegisterLookup("accountid")
lm.RegisterLookup("email")
lm.RegisterLookup("unconfirmedemail")
lm.RegisterLookup("confirmationtoken")
},
AddLookups: func(lm *cache.LookupMap[string, string], user *gtsmodel.User) {
lm.Set("accountid", user.AccountID, user.ID)
if email := user.Email; email != "" {
lm.Set("email", email, user.ID)
}
if unconfirmedEmail := user.UnconfirmedEmail; unconfirmedEmail != "" {
lm.Set("unconfirmedemail", unconfirmedEmail, user.ID)
}
if confirmationToken := user.ConfirmationToken; confirmationToken != "" {
lm.Set("confirmationtoken", confirmationToken, user.ID)
}
},
DeleteLookups: func(lm *cache.LookupMap[string, string], user *gtsmodel.User) {
lm.Delete("accountid", user.AccountID)
if email := user.Email; email != "" {
lm.Delete("email", email)
}
if unconfirmedEmail := user.UnconfirmedEmail; unconfirmedEmail != "" {
lm.Delete("unconfirmedemail", unconfirmedEmail)
}
if confirmationToken := user.ConfirmationToken; confirmationToken != "" {
lm.Delete("confirmationtoken", confirmationToken)
}
},
})
c.cache.SetTTL(time.Minute*5, false)
c.cache.Start(time.Second * 10)
return c
}
// GetByID attempts to fetch a user from the cache by its ID, you will receive a copy for thread-safety
func (c *UserCache) GetByID(id string) (*gtsmodel.User, bool) {
return c.cache.Get(id)
}
// GetByAccountID attempts to fetch a user from the cache by its account ID, you will receive a copy for thread-safety
func (c *UserCache) GetByAccountID(accountID string) (*gtsmodel.User, bool) {
return c.cache.GetBy("accountid", accountID)
}
// GetByEmail attempts to fetch a user from the cache by its email address, you will receive a copy for thread-safety
func (c *UserCache) GetByEmail(email string) (*gtsmodel.User, bool) {
return c.cache.GetBy("email", email)
}
// GetByUnconfirmedEmail attempts to fetch a user from the cache by its confirmation token, you will receive a copy for thread-safety
func (c *UserCache) GetByConfirmationToken(token string) (*gtsmodel.User, bool) {
return c.cache.GetBy("confirmationtoken", token)
}
// Put places a user in the cache, ensuring that the object place is a copy for thread-safety
func (c *UserCache) Put(user *gtsmodel.User) {
if user == nil || user.ID == "" {
panic("invalid user")
}
c.cache.Set(user.ID, copyUser(user))
}
// Invalidate invalidates one user from the cache using the ID of the user as key.
func (c *UserCache) Invalidate(userID string) {
c.cache.Invalidate(userID)
}
func copyUser(user *gtsmodel.User) *gtsmodel.User {
return &gtsmodel.User{
ID: user.ID,
CreatedAt: user.CreatedAt,
UpdatedAt: user.UpdatedAt,
Email: user.Email,
AccountID: user.AccountID,
Account: nil,
EncryptedPassword: user.EncryptedPassword,
SignUpIP: user.SignUpIP,
CurrentSignInAt: user.CurrentSignInAt,
CurrentSignInIP: user.CurrentSignInIP,
LastSignInAt: user.LastSignInAt,
LastSignInIP: user.LastSignInIP,
SignInCount: user.SignInCount,
InviteID: user.InviteID,
ChosenLanguages: user.ChosenLanguages,
FilteredLanguages: user.FilteredLanguages,
Locale: user.Locale,
CreatedByApplicationID: user.CreatedByApplicationID,
CreatedByApplication: nil,
LastEmailedAt: user.LastEmailedAt,
ConfirmationToken: user.ConfirmationToken,
ConfirmationSentAt: user.ConfirmationSentAt,
ConfirmedAt: user.ConfirmedAt,
UnconfirmedEmail: user.UnconfirmedEmail,
Moderator: copyBoolPtr(user.Moderator),
Admin: copyBoolPtr(user.Admin),
Disabled: copyBoolPtr(user.Disabled),
Approved: copyBoolPtr(user.Approved),
ResetPasswordToken: user.ResetPasswordToken,
ResetPasswordSentAt: user.ResetPasswordSentAt,
}
}

View file

@ -30,6 +30,7 @@
"time" "time"
"github.com/superseriousbusiness/gotosocial/internal/ap" "github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -41,6 +42,7 @@
type adminDB struct { type adminDB struct {
conn *DBConn conn *DBConn
userCache *cache.UserCache
} }
func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) { func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) {
@ -175,6 +177,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
Exec(ctx); err != nil { Exec(ctx); err != nil {
return nil, a.conn.ProcessError(err) return nil, a.conn.ProcessError(err)
} }
a.userCache.Put(u)
return u, nil return u, nil
} }

View file

@ -87,6 +87,7 @@ type DBService struct {
db.Session db.Session
db.Status db.Status
db.Timeline db.Timeline
db.User
conn *DBConn conn *DBConn
} }
@ -181,13 +182,15 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {
notifCache.SetTTL(time.Minute*5, false) notifCache.SetTTL(time.Minute*5, false)
notifCache.Start(time.Second * 10) notifCache.Start(time.Second * 10)
// Prepare domain block cache // Prepare other caches
blockCache := cache.NewDomainBlockCache() blockCache := cache.NewDomainBlockCache()
userCache := cache.NewUserCache()
ps := &DBService{ ps := &DBService{
Account: accounts, Account: accounts,
Admin: &adminDB{ Admin: &adminDB{
conn: conn, conn: conn,
userCache: userCache,
}, },
Basic: &basicDB{ Basic: &basicDB{
conn: conn, conn: conn,
@ -219,6 +222,10 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {
}, },
Status: status, Status: status,
Timeline: timeline, Timeline: timeline,
User: &userDB{
conn: conn,
cache: userCache,
},
conn: conn, conn: conn,
} }

151
internal/db/bundb/user.go Normal file
View file

@ -0,0 +1,151 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb
import (
"context"
"time"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/uptrace/bun"
)
type userDB struct {
conn *DBConn
cache *cache.UserCache
}
func (u *userDB) newUserQ(user *gtsmodel.User) *bun.SelectQuery {
return u.conn.
NewSelect().
Model(user).
Relation("Account")
}
func (u *userDB) getUser(ctx context.Context, cacheGet func() (*gtsmodel.User, bool), dbQuery func(*gtsmodel.User) error) (*gtsmodel.User, db.Error) {
// Attempt to fetch cached user
user, cached := cacheGet()
if !cached {
user = &gtsmodel.User{}
// Not cached! Perform database query
err := dbQuery(user)
if err != nil {
return nil, u.conn.ProcessError(err)
}
// Place in the cache
u.cache.Put(user)
}
return user, nil
}
func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, db.Error) {
return u.getUser(
ctx,
func() (*gtsmodel.User, bool) {
return u.cache.GetByID(id)
},
func(user *gtsmodel.User) error {
return u.newUserQ(user).Where("user.id = ?", id).Scan(ctx)
},
)
}
func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, db.Error) {
return u.getUser(
ctx,
func() (*gtsmodel.User, bool) {
return u.cache.GetByAccountID(accountID)
},
func(user *gtsmodel.User) error {
return u.newUserQ(user).Where("user.account_id = ?", accountID).Scan(ctx)
},
)
}
func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, db.Error) {
return u.getUser(
ctx,
func() (*gtsmodel.User, bool) {
return u.cache.GetByEmail(emailAddress)
},
func(user *gtsmodel.User) error {
return u.newUserQ(user).Where("user.email = ?", emailAddress).Scan(ctx)
},
)
}
func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, db.Error) {
return u.getUser(
ctx,
func() (*gtsmodel.User, bool) {
return u.cache.GetByConfirmationToken(confirmationToken)
},
func(user *gtsmodel.User) error {
return u.newUserQ(user).Where("user.confirmation_token = ?", confirmationToken).Scan(ctx)
},
)
}
func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) (*gtsmodel.User, db.Error) {
if _, err := u.conn.
NewInsert().
Model(user).
Exec(ctx); err != nil {
return nil, u.conn.ProcessError(err)
}
u.cache.Put(user)
return user, nil
}
func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) (*gtsmodel.User, db.Error) {
// Update the user's last-updated
user.UpdatedAt = time.Now()
if _, err := u.conn.
NewUpdate().
Model(user).
WherePK().
Column(columns...).
Exec(ctx); err != nil {
return nil, u.conn.ProcessError(err)
}
u.cache.Invalidate(user.ID)
return user, nil
}
func (u *userDB) DeleteUserByID(ctx context.Context, userID string) db.Error {
if _, err := u.conn.
NewDelete().
Model(&gtsmodel.User{ID: userID}).
WherePK().
Exec(ctx); err != nil {
return u.conn.ProcessError(err)
}
u.cache.Invalidate(userID)
return nil
}

View file

@ -0,0 +1,73 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package bundb_test
import (
"context"
"testing"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
type UserTestSuite struct {
BunDBStandardTestSuite
}
func (suite *UserTestSuite) TestGetUser() {
user, err := suite.db.GetUserByID(context.Background(), suite.testUsers["local_account_1"].ID)
suite.NoError(err)
suite.NotNil(user)
}
func (suite *UserTestSuite) TestGetUserByEmailAddress() {
user, err := suite.db.GetUserByEmailAddress(context.Background(), suite.testUsers["local_account_1"].Email)
suite.NoError(err)
suite.NotNil(user)
}
func (suite *UserTestSuite) TestGetUserByAccountID() {
user, err := suite.db.GetUserByAccountID(context.Background(), suite.testAccounts["local_account_1"].ID)
suite.NoError(err)
suite.NotNil(user)
}
func (suite *UserTestSuite) TestUpdateUserSelectedColumns() {
testUser := suite.testUsers["local_account_1"]
user := &gtsmodel.User{
ID: testUser.ID,
Email: "whatever",
Locale: "es",
}
user, err := suite.db.UpdateUser(context.Background(), user, "email", "locale")
suite.NoError(err)
suite.NotNil(user)
dbUser, err := suite.db.GetUserByID(context.Background(), testUser.ID)
suite.NoError(err)
suite.NotNil(dbUser)
suite.Equal("whatever", dbUser.Email)
suite.Equal("es", dbUser.Locale)
suite.Equal(testUser.AccountID, dbUser.AccountID)
}
func TestUserTestSuite(t *testing.T) {
suite.Run(t, new(UserTestSuite))
}

View file

@ -44,6 +44,7 @@ type DB interface {
Session Session
Status Status
Timeline Timeline
User
/* /*
USEFUL CONVERSION FUNCTIONS USEFUL CONVERSION FUNCTIONS

42
internal/db/user.go Normal file
View file

@ -0,0 +1,42 @@
/*
GoToSocial
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import (
"context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// User contains functions related to user getting/setting/creation.
type User interface {
// GetUserByID returns one user with the given ID, or an error if something goes wrong.
GetUserByID(ctx context.Context, id string) (*gtsmodel.User, Error)
// GetUserByAccountID returns one user by its account ID, or an error if something goes wrong.
GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, Error)
// GetUserByID returns one user with the given email address, or an error if something goes wrong.
GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, Error)
// GetUserByConfirmationToken returns one user by its confirmation token, or an error if something goes wrong.
GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, Error)
// UpdateUser updates one user by its primary key. If columns is set, only given columns
// will be updated. If not set, all columns will be updated.
UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) (*gtsmodel.User, Error)
// DeleteUserByID deletes one user by its ID.
DeleteUserByID(ctx context.Context, userID string) Error
}

View file

@ -70,13 +70,14 @@ func (p *processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
// 1. Delete account's application(s), clients, and oauth tokens // 1. Delete account's application(s), clients, and oauth tokens
// we only need to do this step for local account since remote ones won't have any tokens or applications on our server // we only need to do this step for local account since remote ones won't have any tokens or applications on our server
var user *gtsmodel.User
if account.Domain == "" { if account.Domain == "" {
// see if we can get a user for this account // see if we can get a user for this account
u := &gtsmodel.User{} var err error
if err := p.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, u); err == nil { if user, err = p.db.GetUserByAccountID(ctx, account.ID); err == nil {
// we got one! select all tokens with the user's ID // we got one! select all tokens with the user's ID
tokens := []*gtsmodel.Token{} tokens := []*gtsmodel.Token{}
if err := p.db.GetWhere(ctx, []db.Where{{Key: "user_id", Value: u.ID}}, &tokens); err == nil { if err := p.db.GetWhere(ctx, []db.Where{{Key: "user_id", Value: user.ID}}, &tokens); err == nil {
// we have some tokens to delete // we have some tokens to delete
for _, t := range tokens { for _, t := range tokens {
// delete client(s) associated with this token // delete client(s) associated with this token
@ -240,10 +241,12 @@ func (p *processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
// TODO // TODO
// 16. Delete account's user // 16. Delete account's user
if user != nil {
l.Debug("deleting account user") l.Debug("deleting account user")
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, &gtsmodel.User{}); err != nil { if err := p.db.DeleteUserByID(ctx, user.ID); err != nil {
return gtserror.NewErrorInternalError(err) return gtserror.NewErrorInternalError(err)
} }
}
// 17. Delete account's timeline // 17. Delete account's timeline
// TODO // TODO
@ -288,8 +291,8 @@ func (p *processor) DeleteLocal(ctx context.Context, account *gtsmodel.Account,
if form.DeleteOriginID == account.ID { if form.DeleteOriginID == account.ID {
// the account owner themself has requested deletion via the API, get their user from the db // the account owner themself has requested deletion via the API, get their user from the db
user := &gtsmodel.User{} user, err := p.db.GetUserByAccountID(ctx, account.ID)
if err := p.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, user); err != nil { if err != nil {
return gtserror.NewErrorInternalError(err) return gtserror.NewErrorInternalError(err)
} }

View file

@ -29,7 +29,6 @@
"github.com/superseriousbusiness/activity/pub" "github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/gotosocial/internal/ap" "github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/messages" "github.com/superseriousbusiness/gotosocial/internal/messages"
@ -138,8 +137,8 @@ func (p *processor) processCreateAccountFromClientAPI(ctx context.Context, clien
} }
// get the user this account belongs to // get the user this account belongs to
user := &gtsmodel.User{} user, err := p.db.GetUserByAccountID(ctx, account.ID)
if err := p.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, user); err != nil { if err != nil {
return err return err
} }

View file

@ -142,8 +142,8 @@ func (p *processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("account with username %s not retrievable", *form.ContactUsername)) return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("account with username %s not retrievable", *form.ContactUsername))
} }
// make sure it has a user associated with it // make sure it has a user associated with it
contactUser := &gtsmodel.User{} contactUser, err := p.db.GetUserByAccountID(ctx, contactAccount.ID)
if err := p.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: contactAccount.ID}}, contactUser); err != nil { if err != nil {
return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("user for account with username %s not retrievable", *form.ContactUsername)) return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("user for account with username %s not retrievable", *form.ContactUsername))
} }
// suspended accounts cannot be contact accounts // suspended accounts cannot be contact accounts

View file

@ -40,8 +40,8 @@ func (p *processor) AuthorizeStreamingRequest(ctx context.Context, accessToken s
return nil, gtserror.NewErrorUnauthorized(err) return nil, gtserror.NewErrorUnauthorized(err)
} }
user := &gtsmodel.User{} user, err := p.db.GetUserByID(ctx, uid)
if err := p.db.GetByID(ctx, uid, user); err != nil { if err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
err := fmt.Errorf("no user found for validated uid %s", uid) err := fmt.Errorf("no user found for validated uid %s", uid)
return nil, gtserror.NewErrorUnauthorized(err) return nil, gtserror.NewErrorUnauthorized(err)

View file

@ -89,8 +89,8 @@ func (p *processor) ConfirmEmail(ctx context.Context, token string) (*gtsmodel.U
return nil, gtserror.NewErrorNotFound(errors.New("no token provided")) return nil, gtserror.NewErrorNotFound(errors.New("no token provided"))
} }
user := &gtsmodel.User{} user, err := p.db.GetUserByConfirmationToken(ctx, token)
if err := p.db.GetWhere(ctx, []db.Where{{Key: "confirmation_token", Value: token}}, user); err != nil { if err != nil {
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(err) return nil, gtserror.NewErrorNotFound(err)
} }

View file

@ -68,8 +68,8 @@ func (f *filter) StatusVisible(ctx context.Context, targetStatus *gtsmodel.Statu
// if the target user doesn't exist (anymore) then the status also shouldn't be visible // if the target user doesn't exist (anymore) then the status also shouldn't be visible
// note: we only do this for local users // note: we only do this for local users
if targetAccount.Domain == "" { if targetAccount.Domain == "" {
targetUser := &gtsmodel.User{} targetUser, err := f.db.GetUserByAccountID(ctx, targetAccount.ID)
if err := f.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: targetAccount.ID}}, targetUser); err != nil { if err != nil {
l.Debug("target user could not be selected") l.Debug("target user could not be selected")
if err == db.ErrNoEntries { if err == db.ErrNoEntries {
return false, nil return false, nil
@ -98,8 +98,8 @@ func (f *filter) StatusVisible(ctx context.Context, targetStatus *gtsmodel.Statu
// if the requesting user doesn't exist (anymore) then the status also shouldn't be visible // if the requesting user doesn't exist (anymore) then the status also shouldn't be visible
// note: we only do this for local users // note: we only do this for local users
if requestingAccount.Domain == "" { if requestingAccount.Domain == "" {
requestingUser := &gtsmodel.User{} requestingUser, err := f.db.GetUserByAccountID(ctx, requestingAccount.ID)
if err := f.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: requestingAccount.ID}}, requestingUser); err != nil { if err != nil {
// if the requesting account is local but doesn't have a corresponding user in the db this is a problem // if the requesting account is local but doesn't have a corresponding user in the db this is a problem
l.Debug("requesting user could not be selected") l.Debug("requesting user could not be selected")
if err == db.ErrNoEntries { if err == db.ErrNoEntries {