mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2025-01-21 22:00:21 +00:00
[performance] add user cache and database (#879)
* go fmt * add + use user cache and database * fix import * update tests * remove unused relation
This commit is contained in:
parent
f7af7c061c
commit
56f53a2a6f
|
@ -26,9 +26,7 @@
|
|||
|
||||
"github.com/superseriousbusiness/gotosocial/cmd/gotosocial/action"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db/bundb"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/validate"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
@ -92,8 +90,8 @@
|
|||
return err
|
||||
}
|
||||
|
||||
u := >smodel.User{}
|
||||
if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
|
||||
u, err := dbConn.GetUserByAccountID(ctx, a.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -130,8 +128,8 @@
|
|||
return err
|
||||
}
|
||||
|
||||
u := >smodel.User{}
|
||||
if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
|
||||
u, err := dbConn.GetUserByAccountID(ctx, a.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -139,7 +137,7 @@
|
|||
admin := true
|
||||
u.Admin = &admin
|
||||
u.UpdatedAt = time.Now()
|
||||
if err := dbConn.UpdateByPrimaryKey(ctx, u, updatingColumns...); err != nil {
|
||||
if _, err := dbConn.UpdateUser(ctx, u, updatingColumns...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -166,8 +164,8 @@
|
|||
return err
|
||||
}
|
||||
|
||||
u := >smodel.User{}
|
||||
if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
|
||||
u, err := dbConn.GetUserByAccountID(ctx, a.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -175,7 +173,7 @@
|
|||
admin := false
|
||||
u.Admin = &admin
|
||||
u.UpdatedAt = time.Now()
|
||||
if err := dbConn.UpdateByPrimaryKey(ctx, u, updatingColumns...); err != nil {
|
||||
if _, err := dbConn.UpdateUser(ctx, u, updatingColumns...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -202,8 +200,8 @@
|
|||
return err
|
||||
}
|
||||
|
||||
u := >smodel.User{}
|
||||
if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
|
||||
u, err := dbConn.GetUserByAccountID(ctx, a.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -211,7 +209,7 @@
|
|||
disabled := true
|
||||
u.Disabled = &disabled
|
||||
u.UpdatedAt = time.Now()
|
||||
if err := dbConn.UpdateByPrimaryKey(ctx, u, updatingColumns...); err != nil {
|
||||
if _, err := dbConn.UpdateUser(ctx, u, updatingColumns...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -252,8 +250,8 @@
|
|||
return err
|
||||
}
|
||||
|
||||
u := >smodel.User{}
|
||||
if err := dbConn.GetWhere(ctx, []db.Where{{Key: "account_id", Value: a.ID}}, u); err != nil {
|
||||
u, err := dbConn.GetUserByAccountID(ctx, a.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -265,7 +263,7 @@
|
|||
updatingColumns := []string{"encrypted_password", "updated_at"}
|
||||
u.EncryptedPassword = string(pw)
|
||||
u.UpdatedAt = time.Now()
|
||||
if err := dbConn.UpdateByPrimaryKey(ctx, u, updatingColumns...); err != nil {
|
||||
if _, err := dbConn.UpdateUser(ctx, u, updatingColumns...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -94,8 +94,8 @@ func (m *Module) AuthorizeGETHandler(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
user := >smodel.User{}
|
||||
if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil {
|
||||
user, err := m.db.GetUserByID(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
m.clearSession(s)
|
||||
safe := fmt.Sprintf("user with id %s could not be retrieved", userID)
|
||||
var errWithCode gtserror.WithCode
|
||||
|
@ -213,8 +213,8 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
user := >smodel.User{}
|
||||
if err := m.db.GetByID(c.Request.Context(), userID, user); err != nil {
|
||||
user, err := m.db.GetUserByID(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
m.clearSession(s)
|
||||
safe := fmt.Sprintf("user with id %s could not be retrieved", userID)
|
||||
var errWithCode gtserror.WithCode
|
||||
|
|
|
@ -76,8 +76,11 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() {
|
|||
doTest := func(testCase authorizeHandlerTestCase) {
|
||||
ctx, recorder := suite.newContext(http.MethodGet, auth.OauthAuthorizePath, nil, "")
|
||||
|
||||
user := suite.testUsers["unconfirmed_account"]
|
||||
account := suite.testAccounts["unconfirmed_account"]
|
||||
user := >smodel.User{}
|
||||
account := >smodel.Account{}
|
||||
|
||||
*user = *suite.testUsers["unconfirmed_account"]
|
||||
*account = *suite.testAccounts["unconfirmed_account"]
|
||||
|
||||
testSession := sessions.Default(ctx)
|
||||
testSession.Set(sessionUserID, user.ID)
|
||||
|
@ -91,8 +94,7 @@ func (suite *AuthAuthorizeTestSuite) TestAccountAuthorizeHandler() {
|
|||
testCase.description = fmt.Sprintf("%s, %t, %s", user.Email, *user.Disabled, account.SuspendedAt)
|
||||
|
||||
updatingColumns = append(updatingColumns, "updated_at")
|
||||
user.UpdatedAt = time.Now()
|
||||
err := suite.db.UpdateByPrimaryKey(context.Background(), user, updatingColumns...)
|
||||
_, err := suite.db.UpdateUser(context.Background(), user, updatingColumns...)
|
||||
suite.NoError(err)
|
||||
_, err = suite.db.UpdateAccount(context.Background(), account)
|
||||
suite.NoError(err)
|
||||
|
|
|
@ -134,8 +134,7 @@ func (m *Module) parseUserFromClaims(ctx context.Context, claims *oidc.Claims, i
|
|||
|
||||
// see if we already have a user for this email address
|
||||
// if so, we don't need to continue + create one
|
||||
user := >smodel.User{}
|
||||
err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: claims.Email}}, user)
|
||||
user, err := m.db.GetUserByEmailAddress(ctx, claims.Email)
|
||||
if err == nil {
|
||||
return user, nil
|
||||
}
|
||||
|
|
|
@ -28,9 +28,7 @@
|
|||
"github.com/gin-gonic/gin"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/api"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
|
@ -119,8 +117,8 @@ func (m *Module) ValidatePassword(ctx context.Context, email string, password st
|
|||
return incorrectPassword(err)
|
||||
}
|
||||
|
||||
user := >smodel.User{}
|
||||
if err := m.db.GetWhere(ctx, []db.Where{{Key: "email", Value: email}}, user); err != nil {
|
||||
user, err := m.db.GetUserByEmailAddress(ctx, email)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err)
|
||||
return incorrectPassword(err)
|
||||
}
|
||||
|
|
|
@ -52,8 +52,8 @@ func (m *Module) TokenCheck(c *gin.Context) {
|
|||
log.Tracef("authenticated user %s with bearer token, scope is %s", userID, ti.GetScope())
|
||||
|
||||
// fetch user for this token
|
||||
user := >smodel.User{}
|
||||
if err := m.db.GetByID(ctx, userID, user); err != nil {
|
||||
user, err := m.db.GetUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
if err != db.ErrNoEntries {
|
||||
log.Errorf("database error looking for user with id %s: %s", userID, err)
|
||||
return
|
||||
|
@ -80,22 +80,25 @@ func (m *Module) TokenCheck(c *gin.Context) {
|
|||
c.Set(oauth.SessionAuthorizedUser, user)
|
||||
|
||||
// fetch account for this token
|
||||
acct, err := m.db.GetAccountByID(ctx, user.AccountID)
|
||||
if err != nil {
|
||||
if err != db.ErrNoEntries {
|
||||
log.Errorf("database error looking for account with id %s: %s", user.AccountID, err)
|
||||
if user.Account == nil {
|
||||
acct, err := m.db.GetAccountByID(ctx, user.AccountID)
|
||||
if err != nil {
|
||||
if err != db.ErrNoEntries {
|
||||
log.Errorf("database error looking for account with id %s: %s", user.AccountID, err)
|
||||
return
|
||||
}
|
||||
log.Warnf("no account found for userID %s", userID)
|
||||
return
|
||||
}
|
||||
log.Warnf("no account found for userID %s", userID)
|
||||
return
|
||||
user.Account = acct
|
||||
}
|
||||
|
||||
if !acct.SuspendedAt.IsZero() {
|
||||
if !user.Account.SuspendedAt.IsZero() {
|
||||
log.Warnf("authenticated user %s's account (accountId=%s) has been suspended", userID, user.AccountID)
|
||||
return
|
||||
}
|
||||
|
||||
c.Set(oauth.SessionAuthorizedAccount, acct)
|
||||
c.Set(oauth.SessionAuthorizedAccount, user.Account)
|
||||
}
|
||||
|
||||
// check for application token
|
||||
|
|
141
internal/cache/user.go
vendored
Normal file
141
internal/cache/user.go
vendored
Normal file
|
@ -0,0 +1,141 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"codeberg.org/gruf/go-cache/v2"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
)
|
||||
|
||||
// UserCache is a cache wrapper to provide lookups for gtsmodel.User
|
||||
type UserCache struct {
|
||||
cache cache.LookupCache[string, string, *gtsmodel.User]
|
||||
}
|
||||
|
||||
// NewUserCache returns a new instantiated UserCache object
|
||||
func NewUserCache() *UserCache {
|
||||
c := &UserCache{}
|
||||
c.cache = cache.NewLookup(cache.LookupCfg[string, string, *gtsmodel.User]{
|
||||
RegisterLookups: func(lm *cache.LookupMap[string, string]) {
|
||||
lm.RegisterLookup("accountid")
|
||||
lm.RegisterLookup("email")
|
||||
lm.RegisterLookup("unconfirmedemail")
|
||||
lm.RegisterLookup("confirmationtoken")
|
||||
},
|
||||
|
||||
AddLookups: func(lm *cache.LookupMap[string, string], user *gtsmodel.User) {
|
||||
lm.Set("accountid", user.AccountID, user.ID)
|
||||
if email := user.Email; email != "" {
|
||||
lm.Set("email", email, user.ID)
|
||||
}
|
||||
if unconfirmedEmail := user.UnconfirmedEmail; unconfirmedEmail != "" {
|
||||
lm.Set("unconfirmedemail", unconfirmedEmail, user.ID)
|
||||
}
|
||||
if confirmationToken := user.ConfirmationToken; confirmationToken != "" {
|
||||
lm.Set("confirmationtoken", confirmationToken, user.ID)
|
||||
}
|
||||
},
|
||||
|
||||
DeleteLookups: func(lm *cache.LookupMap[string, string], user *gtsmodel.User) {
|
||||
lm.Delete("accountid", user.AccountID)
|
||||
if email := user.Email; email != "" {
|
||||
lm.Delete("email", email)
|
||||
}
|
||||
if unconfirmedEmail := user.UnconfirmedEmail; unconfirmedEmail != "" {
|
||||
lm.Delete("unconfirmedemail", unconfirmedEmail)
|
||||
}
|
||||
if confirmationToken := user.ConfirmationToken; confirmationToken != "" {
|
||||
lm.Delete("confirmationtoken", confirmationToken)
|
||||
}
|
||||
},
|
||||
})
|
||||
c.cache.SetTTL(time.Minute*5, false)
|
||||
c.cache.Start(time.Second * 10)
|
||||
return c
|
||||
}
|
||||
|
||||
// GetByID attempts to fetch a user from the cache by its ID, you will receive a copy for thread-safety
|
||||
func (c *UserCache) GetByID(id string) (*gtsmodel.User, bool) {
|
||||
return c.cache.Get(id)
|
||||
}
|
||||
|
||||
// GetByAccountID attempts to fetch a user from the cache by its account ID, you will receive a copy for thread-safety
|
||||
func (c *UserCache) GetByAccountID(accountID string) (*gtsmodel.User, bool) {
|
||||
return c.cache.GetBy("accountid", accountID)
|
||||
}
|
||||
|
||||
// GetByEmail attempts to fetch a user from the cache by its email address, you will receive a copy for thread-safety
|
||||
func (c *UserCache) GetByEmail(email string) (*gtsmodel.User, bool) {
|
||||
return c.cache.GetBy("email", email)
|
||||
}
|
||||
|
||||
// GetByUnconfirmedEmail attempts to fetch a user from the cache by its confirmation token, you will receive a copy for thread-safety
|
||||
func (c *UserCache) GetByConfirmationToken(token string) (*gtsmodel.User, bool) {
|
||||
return c.cache.GetBy("confirmationtoken", token)
|
||||
}
|
||||
|
||||
// Put places a user in the cache, ensuring that the object place is a copy for thread-safety
|
||||
func (c *UserCache) Put(user *gtsmodel.User) {
|
||||
if user == nil || user.ID == "" {
|
||||
panic("invalid user")
|
||||
}
|
||||
c.cache.Set(user.ID, copyUser(user))
|
||||
}
|
||||
|
||||
// Invalidate invalidates one user from the cache using the ID of the user as key.
|
||||
func (c *UserCache) Invalidate(userID string) {
|
||||
c.cache.Invalidate(userID)
|
||||
}
|
||||
|
||||
func copyUser(user *gtsmodel.User) *gtsmodel.User {
|
||||
return >smodel.User{
|
||||
ID: user.ID,
|
||||
CreatedAt: user.CreatedAt,
|
||||
UpdatedAt: user.UpdatedAt,
|
||||
Email: user.Email,
|
||||
AccountID: user.AccountID,
|
||||
Account: nil,
|
||||
EncryptedPassword: user.EncryptedPassword,
|
||||
SignUpIP: user.SignUpIP,
|
||||
CurrentSignInAt: user.CurrentSignInAt,
|
||||
CurrentSignInIP: user.CurrentSignInIP,
|
||||
LastSignInAt: user.LastSignInAt,
|
||||
LastSignInIP: user.LastSignInIP,
|
||||
SignInCount: user.SignInCount,
|
||||
InviteID: user.InviteID,
|
||||
ChosenLanguages: user.ChosenLanguages,
|
||||
FilteredLanguages: user.FilteredLanguages,
|
||||
Locale: user.Locale,
|
||||
CreatedByApplicationID: user.CreatedByApplicationID,
|
||||
CreatedByApplication: nil,
|
||||
LastEmailedAt: user.LastEmailedAt,
|
||||
ConfirmationToken: user.ConfirmationToken,
|
||||
ConfirmationSentAt: user.ConfirmationSentAt,
|
||||
ConfirmedAt: user.ConfirmedAt,
|
||||
UnconfirmedEmail: user.UnconfirmedEmail,
|
||||
Moderator: copyBoolPtr(user.Moderator),
|
||||
Admin: copyBoolPtr(user.Admin),
|
||||
Disabled: copyBoolPtr(user.Disabled),
|
||||
Approved: copyBoolPtr(user.Approved),
|
||||
ResetPasswordToken: user.ResetPasswordToken,
|
||||
ResetPasswordSentAt: user.ResetPasswordSentAt,
|
||||
}
|
||||
}
|
|
@ -30,6 +30,7 @@
|
|||
"time"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/ap"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/cache"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
|
@ -40,7 +41,8 @@
|
|||
)
|
||||
|
||||
type adminDB struct {
|
||||
conn *DBConn
|
||||
conn *DBConn
|
||||
userCache *cache.UserCache
|
||||
}
|
||||
|
||||
func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (bool, db.Error) {
|
||||
|
@ -175,6 +177,7 @@ func (a *adminDB) NewSignup(ctx context.Context, username string, reason string,
|
|||
Exec(ctx); err != nil {
|
||||
return nil, a.conn.ProcessError(err)
|
||||
}
|
||||
a.userCache.Put(u)
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
|
|
@ -87,6 +87,7 @@ type DBService struct {
|
|||
db.Session
|
||||
db.Status
|
||||
db.Timeline
|
||||
db.User
|
||||
conn *DBConn
|
||||
}
|
||||
|
||||
|
@ -181,13 +182,15 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {
|
|||
notifCache.SetTTL(time.Minute*5, false)
|
||||
notifCache.Start(time.Second * 10)
|
||||
|
||||
// Prepare domain block cache
|
||||
// Prepare other caches
|
||||
blockCache := cache.NewDomainBlockCache()
|
||||
userCache := cache.NewUserCache()
|
||||
|
||||
ps := &DBService{
|
||||
Account: accounts,
|
||||
Admin: &adminDB{
|
||||
conn: conn,
|
||||
conn: conn,
|
||||
userCache: userCache,
|
||||
},
|
||||
Basic: &basicDB{
|
||||
conn: conn,
|
||||
|
@ -219,7 +222,11 @@ func NewBunDBService(ctx context.Context) (db.DB, error) {
|
|||
},
|
||||
Status: status,
|
||||
Timeline: timeline,
|
||||
conn: conn,
|
||||
User: &userDB{
|
||||
conn: conn,
|
||||
cache: userCache,
|
||||
},
|
||||
conn: conn,
|
||||
}
|
||||
|
||||
// we can confidently return this useable service now
|
||||
|
|
151
internal/db/bundb/user.go
Normal file
151
internal/db/bundb/user.go
Normal file
|
@ -0,0 +1,151 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package bundb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/cache"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
type userDB struct {
|
||||
conn *DBConn
|
||||
cache *cache.UserCache
|
||||
}
|
||||
|
||||
func (u *userDB) newUserQ(user *gtsmodel.User) *bun.SelectQuery {
|
||||
return u.conn.
|
||||
NewSelect().
|
||||
Model(user).
|
||||
Relation("Account")
|
||||
}
|
||||
|
||||
func (u *userDB) getUser(ctx context.Context, cacheGet func() (*gtsmodel.User, bool), dbQuery func(*gtsmodel.User) error) (*gtsmodel.User, db.Error) {
|
||||
// Attempt to fetch cached user
|
||||
user, cached := cacheGet()
|
||||
|
||||
if !cached {
|
||||
user = >smodel.User{}
|
||||
|
||||
// Not cached! Perform database query
|
||||
err := dbQuery(user)
|
||||
if err != nil {
|
||||
return nil, u.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
// Place in the cache
|
||||
u.cache.Put(user)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (u *userDB) GetUserByID(ctx context.Context, id string) (*gtsmodel.User, db.Error) {
|
||||
return u.getUser(
|
||||
ctx,
|
||||
func() (*gtsmodel.User, bool) {
|
||||
return u.cache.GetByID(id)
|
||||
},
|
||||
func(user *gtsmodel.User) error {
|
||||
return u.newUserQ(user).Where("user.id = ?", id).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (u *userDB) GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, db.Error) {
|
||||
return u.getUser(
|
||||
ctx,
|
||||
func() (*gtsmodel.User, bool) {
|
||||
return u.cache.GetByAccountID(accountID)
|
||||
},
|
||||
func(user *gtsmodel.User) error {
|
||||
return u.newUserQ(user).Where("user.account_id = ?", accountID).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (u *userDB) GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, db.Error) {
|
||||
return u.getUser(
|
||||
ctx,
|
||||
func() (*gtsmodel.User, bool) {
|
||||
return u.cache.GetByEmail(emailAddress)
|
||||
},
|
||||
func(user *gtsmodel.User) error {
|
||||
return u.newUserQ(user).Where("user.email = ?", emailAddress).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (u *userDB) GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, db.Error) {
|
||||
return u.getUser(
|
||||
ctx,
|
||||
func() (*gtsmodel.User, bool) {
|
||||
return u.cache.GetByConfirmationToken(confirmationToken)
|
||||
},
|
||||
func(user *gtsmodel.User) error {
|
||||
return u.newUserQ(user).Where("user.confirmation_token = ?", confirmationToken).Scan(ctx)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (u *userDB) PutUser(ctx context.Context, user *gtsmodel.User) (*gtsmodel.User, db.Error) {
|
||||
if _, err := u.conn.
|
||||
NewInsert().
|
||||
Model(user).
|
||||
Exec(ctx); err != nil {
|
||||
return nil, u.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
u.cache.Put(user)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (u *userDB) UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) (*gtsmodel.User, db.Error) {
|
||||
// Update the user's last-updated
|
||||
user.UpdatedAt = time.Now()
|
||||
|
||||
if _, err := u.conn.
|
||||
NewUpdate().
|
||||
Model(user).
|
||||
WherePK().
|
||||
Column(columns...).
|
||||
Exec(ctx); err != nil {
|
||||
return nil, u.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
u.cache.Invalidate(user.ID)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (u *userDB) DeleteUserByID(ctx context.Context, userID string) db.Error {
|
||||
if _, err := u.conn.
|
||||
NewDelete().
|
||||
Model(>smodel.User{ID: userID}).
|
||||
WherePK().
|
||||
Exec(ctx); err != nil {
|
||||
return u.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
u.cache.Invalidate(userID)
|
||||
return nil
|
||||
}
|
73
internal/db/bundb/user_test.go
Normal file
73
internal/db/bundb/user_test.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package bundb_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
)
|
||||
|
||||
type UserTestSuite struct {
|
||||
BunDBStandardTestSuite
|
||||
}
|
||||
|
||||
func (suite *UserTestSuite) TestGetUser() {
|
||||
user, err := suite.db.GetUserByID(context.Background(), suite.testUsers["local_account_1"].ID)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(user)
|
||||
}
|
||||
|
||||
func (suite *UserTestSuite) TestGetUserByEmailAddress() {
|
||||
user, err := suite.db.GetUserByEmailAddress(context.Background(), suite.testUsers["local_account_1"].Email)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(user)
|
||||
}
|
||||
|
||||
func (suite *UserTestSuite) TestGetUserByAccountID() {
|
||||
user, err := suite.db.GetUserByAccountID(context.Background(), suite.testAccounts["local_account_1"].ID)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(user)
|
||||
}
|
||||
|
||||
func (suite *UserTestSuite) TestUpdateUserSelectedColumns() {
|
||||
testUser := suite.testUsers["local_account_1"]
|
||||
user := >smodel.User{
|
||||
ID: testUser.ID,
|
||||
Email: "whatever",
|
||||
Locale: "es",
|
||||
}
|
||||
|
||||
user, err := suite.db.UpdateUser(context.Background(), user, "email", "locale")
|
||||
suite.NoError(err)
|
||||
suite.NotNil(user)
|
||||
|
||||
dbUser, err := suite.db.GetUserByID(context.Background(), testUser.ID)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(dbUser)
|
||||
suite.Equal("whatever", dbUser.Email)
|
||||
suite.Equal("es", dbUser.Locale)
|
||||
suite.Equal(testUser.AccountID, dbUser.AccountID)
|
||||
}
|
||||
|
||||
func TestUserTestSuite(t *testing.T) {
|
||||
suite.Run(t, new(UserTestSuite))
|
||||
}
|
|
@ -44,6 +44,7 @@ type DB interface {
|
|||
Session
|
||||
Status
|
||||
Timeline
|
||||
User
|
||||
|
||||
/*
|
||||
USEFUL CONVERSION FUNCTIONS
|
||||
|
|
42
internal/db/user.go
Normal file
42
internal/db/user.go
Normal file
|
@ -0,0 +1,42 @@
|
|||
/*
|
||||
GoToSocial
|
||||
Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
)
|
||||
|
||||
// User contains functions related to user getting/setting/creation.
|
||||
type User interface {
|
||||
// GetUserByID returns one user with the given ID, or an error if something goes wrong.
|
||||
GetUserByID(ctx context.Context, id string) (*gtsmodel.User, Error)
|
||||
// GetUserByAccountID returns one user by its account ID, or an error if something goes wrong.
|
||||
GetUserByAccountID(ctx context.Context, accountID string) (*gtsmodel.User, Error)
|
||||
// GetUserByID returns one user with the given email address, or an error if something goes wrong.
|
||||
GetUserByEmailAddress(ctx context.Context, emailAddress string) (*gtsmodel.User, Error)
|
||||
// GetUserByConfirmationToken returns one user by its confirmation token, or an error if something goes wrong.
|
||||
GetUserByConfirmationToken(ctx context.Context, confirmationToken string) (*gtsmodel.User, Error)
|
||||
// UpdateUser updates one user by its primary key. If columns is set, only given columns
|
||||
// will be updated. If not set, all columns will be updated.
|
||||
UpdateUser(ctx context.Context, user *gtsmodel.User, columns ...string) (*gtsmodel.User, Error)
|
||||
// DeleteUserByID deletes one user by its ID.
|
||||
DeleteUserByID(ctx context.Context, userID string) Error
|
||||
}
|
|
@ -70,13 +70,14 @@ func (p *processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
|
|||
|
||||
// 1. Delete account's application(s), clients, and oauth tokens
|
||||
// we only need to do this step for local account since remote ones won't have any tokens or applications on our server
|
||||
var user *gtsmodel.User
|
||||
if account.Domain == "" {
|
||||
// see if we can get a user for this account
|
||||
u := >smodel.User{}
|
||||
if err := p.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, u); err == nil {
|
||||
var err error
|
||||
if user, err = p.db.GetUserByAccountID(ctx, account.ID); err == nil {
|
||||
// we got one! select all tokens with the user's ID
|
||||
tokens := []*gtsmodel.Token{}
|
||||
if err := p.db.GetWhere(ctx, []db.Where{{Key: "user_id", Value: u.ID}}, &tokens); err == nil {
|
||||
if err := p.db.GetWhere(ctx, []db.Where{{Key: "user_id", Value: user.ID}}, &tokens); err == nil {
|
||||
// we have some tokens to delete
|
||||
for _, t := range tokens {
|
||||
// delete client(s) associated with this token
|
||||
|
@ -240,9 +241,11 @@ func (p *processor) Delete(ctx context.Context, account *gtsmodel.Account, origi
|
|||
// TODO
|
||||
|
||||
// 16. Delete account's user
|
||||
l.Debug("deleting account user")
|
||||
if err := p.db.DeleteWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, >smodel.User{}); err != nil {
|
||||
return gtserror.NewErrorInternalError(err)
|
||||
if user != nil {
|
||||
l.Debug("deleting account user")
|
||||
if err := p.db.DeleteUserByID(ctx, user.ID); err != nil {
|
||||
return gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
}
|
||||
|
||||
// 17. Delete account's timeline
|
||||
|
@ -288,8 +291,8 @@ func (p *processor) DeleteLocal(ctx context.Context, account *gtsmodel.Account,
|
|||
|
||||
if form.DeleteOriginID == account.ID {
|
||||
// the account owner themself has requested deletion via the API, get their user from the db
|
||||
user := >smodel.User{}
|
||||
if err := p.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, user); err != nil {
|
||||
user, err := p.db.GetUserByAccountID(ctx, account.ID)
|
||||
if err != nil {
|
||||
return gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
|
|
|
@ -29,7 +29,6 @@
|
|||
"github.com/superseriousbusiness/activity/pub"
|
||||
"github.com/superseriousbusiness/activity/streams"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/ap"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/messages"
|
||||
|
@ -138,8 +137,8 @@ func (p *processor) processCreateAccountFromClientAPI(ctx context.Context, clien
|
|||
}
|
||||
|
||||
// get the user this account belongs to
|
||||
user := >smodel.User{}
|
||||
if err := p.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: account.ID}}, user); err != nil {
|
||||
user, err := p.db.GetUserByAccountID(ctx, account.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -370,7 +370,7 @@ func (suite *FromFederatorTestSuite) TestProcessAccountDelete() {
|
|||
// no statuses from foss satan should be left in the database
|
||||
if !testrig.WaitFor(func() bool {
|
||||
s, err := suite.db.GetAccountStatuses(ctx, deletedAccount.ID, 0, false, false, "", "", false, false, false)
|
||||
return s == nil && err == db.ErrNoEntries
|
||||
return s == nil && err == db.ErrNoEntries
|
||||
}) {
|
||||
suite.FailNow("timeout waiting for statuses to be deleted")
|
||||
}
|
||||
|
|
|
@ -142,8 +142,8 @@ func (p *processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
|
|||
return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("account with username %s not retrievable", *form.ContactUsername))
|
||||
}
|
||||
// make sure it has a user associated with it
|
||||
contactUser := >smodel.User{}
|
||||
if err := p.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: contactAccount.ID}}, contactUser); err != nil {
|
||||
contactUser, err := p.db.GetUserByAccountID(ctx, contactAccount.ID)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("user for account with username %s not retrievable", *form.ContactUsername))
|
||||
}
|
||||
// suspended accounts cannot be contact accounts
|
||||
|
|
|
@ -40,8 +40,8 @@ func (p *processor) AuthorizeStreamingRequest(ctx context.Context, accessToken s
|
|||
return nil, gtserror.NewErrorUnauthorized(err)
|
||||
}
|
||||
|
||||
user := >smodel.User{}
|
||||
if err := p.db.GetByID(ctx, uid, user); err != nil {
|
||||
user, err := p.db.GetUserByID(ctx, uid)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
err := fmt.Errorf("no user found for validated uid %s", uid)
|
||||
return nil, gtserror.NewErrorUnauthorized(err)
|
||||
|
|
|
@ -89,8 +89,8 @@ func (p *processor) ConfirmEmail(ctx context.Context, token string) (*gtsmodel.U
|
|||
return nil, gtserror.NewErrorNotFound(errors.New("no token provided"))
|
||||
}
|
||||
|
||||
user := >smodel.User{}
|
||||
if err := p.db.GetWhere(ctx, []db.Where{{Key: "confirmation_token", Value: token}}, user); err != nil {
|
||||
user, err := p.db.GetUserByConfirmationToken(ctx, token)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
return nil, gtserror.NewErrorNotFound(err)
|
||||
}
|
||||
|
|
|
@ -46,9 +46,9 @@ func (suite *InternalToFrontendTestSuite) TestAccountToFrontend() {
|
|||
func (suite *InternalToFrontendTestSuite) TestAccountToFrontendWithEmojiStruct() {
|
||||
testAccount := suite.testAccounts["local_account_1"] // take zork for this test
|
||||
testEmoji := suite.testEmojis["rainbow"]
|
||||
|
||||
|
||||
testAccount.Emojis = []*gtsmodel.Emoji{testEmoji}
|
||||
|
||||
|
||||
apiAccount, err := suite.typeconverter.AccountToAPIAccountPublic(context.Background(), testAccount)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(apiAccount)
|
||||
|
@ -61,9 +61,9 @@ func (suite *InternalToFrontendTestSuite) TestAccountToFrontendWithEmojiStruct()
|
|||
func (suite *InternalToFrontendTestSuite) TestAccountToFrontendWithEmojiIDs() {
|
||||
testAccount := suite.testAccounts["local_account_1"] // take zork for this test
|
||||
testEmoji := suite.testEmojis["rainbow"]
|
||||
|
||||
|
||||
testAccount.EmojiIDs = []string{testEmoji.ID}
|
||||
|
||||
|
||||
apiAccount, err := suite.typeconverter.AccountToAPIAccountPublic(context.Background(), testAccount)
|
||||
suite.NoError(err)
|
||||
suite.NotNil(apiAccount)
|
||||
|
|
|
@ -68,8 +68,8 @@ func (f *filter) StatusVisible(ctx context.Context, targetStatus *gtsmodel.Statu
|
|||
// if the target user doesn't exist (anymore) then the status also shouldn't be visible
|
||||
// note: we only do this for local users
|
||||
if targetAccount.Domain == "" {
|
||||
targetUser := >smodel.User{}
|
||||
if err := f.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: targetAccount.ID}}, targetUser); err != nil {
|
||||
targetUser, err := f.db.GetUserByAccountID(ctx, targetAccount.ID)
|
||||
if err != nil {
|
||||
l.Debug("target user could not be selected")
|
||||
if err == db.ErrNoEntries {
|
||||
return false, nil
|
||||
|
@ -98,8 +98,8 @@ func (f *filter) StatusVisible(ctx context.Context, targetStatus *gtsmodel.Statu
|
|||
// if the requesting user doesn't exist (anymore) then the status also shouldn't be visible
|
||||
// note: we only do this for local users
|
||||
if requestingAccount.Domain == "" {
|
||||
requestingUser := >smodel.User{}
|
||||
if err := f.db.GetWhere(ctx, []db.Where{{Key: "account_id", Value: requestingAccount.ID}}, requestingUser); err != nil {
|
||||
requestingUser, err := f.db.GetUserByAccountID(ctx, requestingAccount.ID)
|
||||
if err != nil {
|
||||
// if the requesting account is local but doesn't have a corresponding user in the db this is a problem
|
||||
l.Debug("requesting user could not be selected")
|
||||
if err == db.ErrNoEntries {
|
||||
|
|
Loading…
Reference in a new issue