diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go
index 705e1b118..4b4c78726 100644
--- a/internal/db/bundb/account.go
+++ b/internal/db/bundb/account.go
@@ -37,7 +37,7 @@
)
type accountDB struct {
- db *DB
+ db *bun.DB
state *state.State
}
@@ -334,7 +334,7 @@ func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) e
// It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook.
//
- return a.db.RunInTx(ctx, func(tx Tx) error {
+ return a.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// create links between this account and any emojis it uses
for _, i := range account.EmojiIDs {
if _, err := tx.NewInsert().Model(>smodel.AccountToEmoji{
@@ -363,7 +363,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
// It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook.
//
- return a.db.RunInTx(ctx, func(tx Tx) error {
+ return a.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// create links between this account and any emojis it uses
// first clear out any old emoji links
if _, err := tx.
@@ -411,7 +411,7 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) error {
return err
}
- return a.db.RunInTx(ctx, func(tx Tx) error {
+ return a.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// clear out any emoji links
if _, err := tx.
NewDelete().
diff --git a/internal/db/bundb/admin.go b/internal/db/bundb/admin.go
index e189c508e..70ae68026 100644
--- a/internal/db/bundb/admin.go
+++ b/internal/db/bundb/admin.go
@@ -45,7 +45,7 @@
const rsaKeyBits = 2048
type adminDB struct {
- db *DB
+ db *bun.DB
state *state.State
}
@@ -56,7 +56,7 @@ func (a *adminDB) IsUsernameAvailable(ctx context.Context, username string) (boo
Column("account.id").
Where("? = ?", bun.Ident("account.username"), username).
Where("? IS NULL", bun.Ident("account.domain"))
- return a.db.NotExists(ctx, q)
+ return notExists(ctx, q)
}
func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, error) {
@@ -73,7 +73,7 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, err
TableExpr("? AS ?", bun.Ident("email_domain_blocks"), bun.Ident("email_domain_block")).
Column("email_domain_block.id").
Where("? = ?", bun.Ident("email_domain_block.domain"), domain)
- emailDomainBlocked, err := a.db.Exists(ctx, emailDomainBlockedQ)
+ emailDomainBlocked, err := exists(ctx, emailDomainBlockedQ)
if err != nil {
return false, err
}
@@ -88,7 +88,7 @@ func (a *adminDB) IsEmailAvailable(ctx context.Context, email string) (bool, err
Column("user.id").
Where("? = ?", bun.Ident("user.email"), email).
WhereOr("? = ?", bun.Ident("user.unconfirmed_email"), email)
- return a.db.NotExists(ctx, q)
+ return notExists(ctx, q)
}
func (a *adminDB) NewSignup(ctx context.Context, newSignup gtsmodel.NewSignup) (*gtsmodel.User, error) {
@@ -229,7 +229,7 @@ func (a *adminDB) CreateInstanceAccount(ctx context.Context) error {
Where("? = ?", bun.Ident("account.username"), username).
Where("? IS NULL", bun.Ident("account.domain"))
- exists, err := a.db.Exists(ctx, q)
+ exists, err := exists(ctx, q)
if err != nil {
return err
}
@@ -287,7 +287,7 @@ func (a *adminDB) CreateInstanceInstance(ctx context.Context) error {
TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")).
Where("? = ?", bun.Ident("instance.domain"), host)
- exists, err := a.db.Exists(ctx, q)
+ exists, err := exists(ctx, q)
if err != nil {
return err
}
diff --git a/internal/db/bundb/application.go b/internal/db/bundb/application.go
index 2e17a0e94..f02632793 100644
--- a/internal/db/bundb/application.go
+++ b/internal/db/bundb/application.go
@@ -26,7 +26,7 @@
)
type applicationDB struct {
- db *DB
+ db *bun.DB
state *state.State
}
diff --git a/internal/db/bundb/basic.go b/internal/db/bundb/basic.go
index 488f59ad5..7b523f309 100644
--- a/internal/db/bundb/basic.go
+++ b/internal/db/bundb/basic.go
@@ -27,7 +27,7 @@
)
type basicDB struct {
- db *DB
+ db *bun.DB
}
func (b *basicDB) Put(ctx context.Context, i interface{}) error {
diff --git a/internal/db/bundb/bundb.go b/internal/db/bundb/bundb.go
index 048474782..4ecbec7b9 100644
--- a/internal/db/bundb/bundb.go
+++ b/internal/db/bundb/bundb.go
@@ -52,13 +52,6 @@
"modernc.org/sqlite"
)
-var registerTables = []interface{}{
- >smodel.AccountToEmoji{},
- >smodel.StatusToEmoji{},
- >smodel.StatusToTag{},
- >smodel.ThreadToStatus{},
-}
-
// DBService satisfies the DB interface
type DBService struct {
db.Account
@@ -88,12 +81,12 @@ type DBService struct {
db.Timeline
db.User
db.Tombstone
- db *DB
+ db *bun.DB
}
// GetDB returns the underlying database connection pool.
// Should only be used in testing + exceptional circumstance.
-func (dbService *DBService) DB() *DB {
+func (dbService *DBService) DB() *bun.DB {
return dbService.db
}
@@ -129,18 +122,18 @@ func doMigration(ctx context.Context, db *bun.DB) error {
// NewBunDBService returns a bunDB derived from the provided config, which implements the go-fed DB interface.
// Under the hood, it uses https://github.com/uptrace/bun to create and maintain a database connection.
func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
- var db *DB
+ var db *bun.DB
var err error
t := strings.ToLower(config.GetDbType())
switch t {
case "postgres":
- db, err = pgConn(ctx)
+ db, err = pgConn(ctx, state)
if err != nil {
return nil, err
}
case "sqlite":
- db, err = sqliteConn(ctx)
+ db, err = sqliteConn(ctx, state)
if err != nil {
return nil, err
}
@@ -159,14 +152,19 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
// table registration is needed for many-to-many, see:
// https://bun.uptrace.dev/orm/many-to-many-relation/
- for _, t := range registerTables {
+ for _, t := range []interface{}{
+ >smodel.AccountToEmoji{},
+ >smodel.StatusToEmoji{},
+ >smodel.StatusToTag{},
+ >smodel.ThreadToStatus{},
+ } {
db.RegisterModel(t)
}
// perform any pending database migrations: this includes
// the very first 'migration' on startup which just creates
// necessary tables
- if err := doMigration(ctx, db.bun); err != nil {
+ if err := doMigration(ctx, db); err != nil {
return nil, fmt.Errorf("db migration error: %s", err)
}
@@ -284,13 +282,18 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
return ps, nil
}
-func pgConn(ctx context.Context) (*DB, error) {
+func pgConn(ctx context.Context, state *state.State) (*bun.DB, error) {
opts, err := deriveBunDBPGOptions() //nolint:contextcheck
if err != nil {
- return nil, fmt.Errorf("could not create bundb postgres options: %s", err)
+ return nil, fmt.Errorf("could not create bundb postgres options: %w", err)
}
- sqldb := stdlib.OpenDB(*opts)
+ cfg := stdlib.RegisterConnConfig(opts)
+
+ sqldb, err := sql.Open("pgx-gts", cfg)
+ if err != nil {
+ return nil, fmt.Errorf("could not open postgres db: %w", err)
+ }
// Tune db connections for postgres, see:
// - https://bun.uptrace.dev/guide/running-bun-in-production.html#database-sql
@@ -299,18 +302,18 @@ func pgConn(ctx context.Context) (*DB, error) {
sqldb.SetMaxIdleConns(2) // assume default 2; if max idle is less than max open, it will be automatically adjusted
sqldb.SetConnMaxLifetime(5 * time.Minute) // fine to kill old connections
- db := WrapDB(bun.NewDB(sqldb, pgdialect.New()))
+ db := bun.NewDB(sqldb, pgdialect.New())
// ping to check the db is there and listening
if err := db.PingContext(ctx); err != nil {
- return nil, fmt.Errorf("postgres ping: %s", err)
+ return nil, fmt.Errorf("postgres ping: %w", err)
}
log.Info(ctx, "connected to POSTGRES database")
return db, nil
}
-func sqliteConn(ctx context.Context) (*DB, error) {
+func sqliteConn(ctx context.Context, state *state.State) (*bun.DB, error) {
// validate db address has actually been set
address := config.GetDbAddress()
if address == "" {
@@ -321,7 +324,7 @@ func sqliteConn(ctx context.Context) (*DB, error) {
address = buildSQLiteAddress(address)
// Open new DB instance
- sqldb, err := sql.Open("sqlite", address)
+ sqldb, err := sql.Open("sqlite-gts", address)
if err != nil {
if errWithCode, ok := err.(*sqlite.Error); ok {
err = errors.New(sqlite.ErrorCodeString[errWithCode.Code()])
@@ -336,15 +339,14 @@ func sqliteConn(ctx context.Context) (*DB, error) {
sqldb.SetMaxIdleConns(1) // only keep max 1 idle connection around
sqldb.SetConnMaxLifetime(0) // don't kill connections due to age
- // Wrap Bun database conn in our own wrapper
- db := WrapDB(bun.NewDB(sqldb, sqlitedialect.New()))
+ db := bun.NewDB(sqldb, sqlitedialect.New())
// ping to check the db is there and listening
if err := db.PingContext(ctx); err != nil {
if errWithCode, ok := err.(*sqlite.Error); ok {
err = errors.New(sqlite.ErrorCodeString[errWithCode.Code()])
}
- return nil, fmt.Errorf("sqlite ping: %s", err)
+ return nil, fmt.Errorf("sqlite ping: %w", err)
}
log.Infof(ctx, "connected to SQLITE database with address %s", address)
@@ -418,7 +420,7 @@ func deriveBunDBPGOptions() (*pgx.ConnConfig, error) {
// parse the PEM block into the certificate
caCert, err := x509.ParseCertificate(caPem.Bytes)
if err != nil {
- return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %s", certPath, err)
+ return nil, fmt.Errorf("could not parse cert at %s into x509 certificate: %w", certPath, err)
}
// we're happy, add it to the existing pool and then use this pool in our tls config
diff --git a/internal/db/bundb/db.go b/internal/db/bundb/db.go
deleted file mode 100644
index 2b19ba0c4..000000000
--- a/internal/db/bundb/db.go
+++ /dev/null
@@ -1,578 +0,0 @@
-// GoToSocial
-// Copyright (C) GoToSocial Authors admin@gotosocial.org
-// SPDX-License-Identifier: AGPL-3.0-or-later
-//
-// This program is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Affero General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// This program is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Affero General Public License for more details.
-//
-// You should have received a copy of the GNU Affero General Public License
-// along with this program. If not, see .
-
-package bundb
-
-import (
- "context"
- "database/sql"
- "time"
- "unsafe"
-
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtserror"
- "github.com/uptrace/bun"
- "github.com/uptrace/bun/dialect"
- "github.com/uptrace/bun/schema"
-)
-
-// DB wraps a bun database instance
-// to provide common per-dialect SQL error
-// conversions to common types, and retries
-// on returned busy (SQLite only).
-type DB struct {
- // our own wrapped db type
- // with retry backoff support.
- // kept separate to the *bun.DB
- // type to be passed into query
- // builders as bun.IConn iface
- // (this prevents double firing
- // bun query hooks).
- //
- // also holds per-dialect
- // error hook function.
- raw rawdb
-
- // bun DB interface we use
- // for dialects, and improved
- // struct marshal/unmarshaling.
- bun *bun.DB
-}
-
-// WrapDB wraps a bun database instance in our database type.
-func WrapDB(db *bun.DB) *DB {
- var errProc func(error) error
- switch name := db.Dialect().Name(); name {
- case dialect.PG:
- errProc = processPostgresError
- case dialect.SQLite:
- errProc = processSQLiteError
- default:
- panic("unknown dialect name: " + name.String())
- }
- return &DB{
- raw: rawdb{
- errHook: errProc,
- db: db.DB,
- },
- bun: db,
- }
-}
-
-// Dialect is a direct call-through to bun.DB.Dialect().
-func (db *DB) Dialect() schema.Dialect { return db.bun.Dialect() }
-
-// AddQueryHook is a direct call-through to bun.DB.AddQueryHook().
-func (db *DB) AddQueryHook(hook bun.QueryHook) { db.bun.AddQueryHook(hook) }
-
-// RegisterModels is a direct call-through to bun.DB.RegisterModels().
-func (db *DB) RegisterModel(models ...any) { db.bun.RegisterModel(models...) }
-
-// PingContext is a direct call-through to bun.DB.PingContext().
-func (db *DB) PingContext(ctx context.Context) error { return db.bun.PingContext(ctx) }
-
-// Close is a direct call-through to bun.DB.Close().
-func (db *DB) Close() error { return db.bun.Close() }
-
-// ExecContext wraps bun.DB.ExecContext() with retry-busy timeout and our own error processing.
-func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) {
- bundb := db.bun // use underlying *bun.DB interface for their query formatting
- err = retryOnBusy(ctx, func() error {
- result, err = bundb.ExecContext(ctx, query, args...)
- err = db.raw.errHook(err)
- return err
- })
- return
-}
-
-// QueryContext wraps bun.DB.ExecContext() with retry-busy timeout and our own error processing.
-func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (rows *sql.Rows, err error) {
- bundb := db.bun // use underlying *bun.DB interface for their query formatting
- err = retryOnBusy(ctx, func() error {
- rows, err = bundb.QueryContext(ctx, query, args...)
- err = db.raw.errHook(err)
- return err
- })
- return
-}
-
-// QueryRowContext wraps bun.DB.ExecContext() with retry-busy timeout and our own error processing.
-func (db *DB) QueryRowContext(ctx context.Context, query string, args ...any) (row *sql.Row) {
- bundb := db.bun // use underlying *bun.DB interface for their query formatting
- _ = retryOnBusy(ctx, func() error {
- row = bundb.QueryRowContext(ctx, query, args...)
- if err := db.raw.errHook(row.Err()); err != nil {
- updateRowError(row, err) // set new error
- }
- return row.Err()
- })
- return
-}
-
-// BeginTx wraps bun.DB.BeginTx() with retry-busy timeout and our own error processing.
-func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (tx Tx, err error) {
- var buntx bun.Tx // captured bun.Tx
- bundb := db.bun // use *bun.DB interface to return bun.Tx type
-
- err = retryOnBusy(ctx, func() error {
- buntx, err = bundb.BeginTx(ctx, opts)
- err = db.raw.errHook(err)
- return err
- })
-
- if err == nil {
- // Wrap bun.Tx in our type.
- tx = wrapTx(db, &buntx)
- }
-
- return
-}
-
-// RunInTx is functionally the same as bun.DB.RunInTx() but with retry-busy timeouts.
-func (db *DB) RunInTx(ctx context.Context, fn func(Tx) error) error {
- // Attempt to start new transaction.
- tx, err := db.BeginTx(ctx, nil)
- if err != nil {
- return err
- }
-
- var done bool
-
- defer func() {
- if !done {
- // Rollback tx.
- _ = tx.Rollback()
- }
- }()
-
- // Perform supplied transaction
- if err := fn(tx); err != nil {
- return err
- }
-
- // Commit tx.
- err = tx.Commit()
- done = true
- return err
-}
-
-func (db *DB) NewValues(model interface{}) *bun.ValuesQuery {
- // note: passing in rawdb as conn iface so no double query-hook
- // firing when passed through the bun.DB.Query___() functions.
- return bun.NewValuesQuery(db.bun, model).Conn(&db.raw)
-}
-
-func (db *DB) NewMerge() *bun.MergeQuery {
- // note: passing in rawdb as conn iface so no double query-hook
- // firing when passed through the bun.DB.Query___() functions.
- return bun.NewMergeQuery(db.bun).Conn(&db.raw)
-}
-
-func (db *DB) NewSelect() *bun.SelectQuery {
- // note: passing in rawdb as conn iface so no double query-hook
- // firing when passed through the bun.DB.Query___() functions.
- return bun.NewSelectQuery(db.bun).Conn(&db.raw)
-}
-
-func (db *DB) NewInsert() *bun.InsertQuery {
- // note: passing in rawdb as conn iface so no double query-hook
- // firing when passed through the bun.DB.Query___() functions.
- return bun.NewInsertQuery(db.bun).Conn(&db.raw)
-}
-
-func (db *DB) NewUpdate() *bun.UpdateQuery {
- // note: passing in rawdb as conn iface so no double query-hook
- // firing when passed through the bun.DB.Query___() functions.
- return bun.NewUpdateQuery(db.bun).Conn(&db.raw)
-}
-
-func (db *DB) NewDelete() *bun.DeleteQuery {
- // note: passing in rawdb as conn iface so no double query-hook
- // firing when passed through the bun.DB.Query___() functions.
- return bun.NewDeleteQuery(db.bun).Conn(&db.raw)
-}
-
-func (db *DB) NewRaw(query string, args ...interface{}) *bun.RawQuery {
- // note: passing in rawdb as conn iface so no double query-hook
- // firing when passed through the bun.DB.Query___() functions.
- return bun.NewRawQuery(db.bun, query, args...).Conn(&db.raw)
-}
-
-func (db *DB) NewCreateTable() *bun.CreateTableQuery {
- // note: passing in rawdb as conn iface so no double query-hook
- // firing when passed through the bun.DB.Query___() functions.
- return bun.NewCreateTableQuery(db.bun).Conn(&db.raw)
-}
-
-func (db *DB) NewDropTable() *bun.DropTableQuery {
- // note: passing in rawdb as conn iface so no double query-hook
- // firing when passed through the bun.DB.Query___() functions.
- return bun.NewDropTableQuery(db.bun).Conn(&db.raw)
-}
-
-func (db *DB) NewCreateIndex() *bun.CreateIndexQuery {
- // note: passing in rawdb as conn iface so no double query-hook
- // firing when passed through the bun.DB.Query___() functions.
- return bun.NewCreateIndexQuery(db.bun).Conn(&db.raw)
-}
-
-func (db *DB) NewDropIndex() *bun.DropIndexQuery {
- // note: passing in rawdb as conn iface so no double query-hook
- // firing when passed through the bun.DB.Query___() functions.
- return bun.NewDropIndexQuery(db.bun).Conn(&db.raw)
-}
-
-func (db *DB) NewTruncateTable() *bun.TruncateTableQuery {
- // note: passing in rawdb as conn iface so no double query-hook
- // firing when passed through the bun.DB.Query___() functions.
- return bun.NewTruncateTableQuery(db.bun).Conn(&db.raw)
-}
-
-func (db *DB) NewAddColumn() *bun.AddColumnQuery {
- // note: passing in rawdb as conn iface so no double query-hook
- // firing when passed through the bun.DB.Query___() functions.
- return bun.NewAddColumnQuery(db.bun).Conn(&db.raw)
-}
-
-func (db *DB) NewDropColumn() *bun.DropColumnQuery {
- // note: passing in rawdb as conn iface so no double query-hook
- // firing when passed through the bun.DB.Query___() functions.
- return bun.NewDropColumnQuery(db.bun).Conn(&db.raw)
-}
-
-// Exists checks the results of a SelectQuery for the existence of the data in question, masking ErrNoEntries errors.
-func (db *DB) Exists(ctx context.Context, query *bun.SelectQuery) (bool, error) {
- exists, err := query.Exists(ctx)
- switch err {
- case nil:
- return exists, nil
- case sql.ErrNoRows:
- return false, nil
- default:
- return false, err
- }
-}
-
-// NotExists checks the results of a SelectQuery for the non-existence of the data in question, masking ErrNoEntries errors.
-func (db *DB) NotExists(ctx context.Context, query *bun.SelectQuery) (bool, error) {
- exists, err := db.Exists(ctx, query)
- return !exists, err
-}
-
-type rawdb struct {
- // dialect specific error
- // processing function hook.
- errHook func(error) error
-
- // embedded raw
- // db interface
- db *sql.DB
-}
-
-// ExecContext wraps sql.DB.ExecContext() with retry-busy timeout and our own error processing.
-func (db *rawdb) ExecContext(ctx context.Context, query string, args ...any) (result sql.Result, err error) {
- err = retryOnBusy(ctx, func() error {
- result, err = db.db.ExecContext(ctx, query, args...)
- err = db.errHook(err)
- return err
- })
- return
-}
-
-// QueryContext wraps sql.DB.QueryContext() with retry-busy timeout and our own error processing.
-func (db *rawdb) QueryContext(ctx context.Context, query string, args ...any) (rows *sql.Rows, err error) {
- err = retryOnBusy(ctx, func() error {
- rows, err = db.db.QueryContext(ctx, query, args...)
- err = db.errHook(err)
- return err
- })
- return
-}
-
-// QueryRowContext wraps sql.DB.QueryRowContext() with retry-busy timeout and our own error processing.
-func (db *rawdb) QueryRowContext(ctx context.Context, query string, args ...any) (row *sql.Row) {
- _ = retryOnBusy(ctx, func() error {
- row = db.db.QueryRowContext(ctx, query, args...)
- err := db.errHook(row.Err())
- return err
- })
- return
-}
-
-// Tx wraps a bun transaction instance
-// to provide common per-dialect SQL error
-// conversions to common types, and retries
-// on busy commit/rollback (SQLite only).
-type Tx struct {
- // our own wrapped Tx type
- // kept separate to the *bun.Tx
- // type to be passed into query
- // builders as bun.IConn iface
- // (this prevents double firing
- // bun query hooks).
- //
- // also holds per-dialect
- // error hook function.
- raw rawtx
-
- // bun Tx interface we use
- // for dialects, and improved
- // struct marshal/unmarshaling.
- bun *bun.Tx
-}
-
-// wrapTx wraps a given bun.Tx in our own wrapping Tx type.
-func wrapTx(db *DB, tx *bun.Tx) Tx {
- return Tx{
- raw: rawtx{
- errHook: db.raw.errHook,
- tx: tx.Tx,
- },
- bun: tx,
- }
-}
-
-// ExecContext wraps bun.Tx.ExecContext() with our own error processing, WITHOUT retry-busy timeouts (as will be mid-transaction).
-func (tx Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
- buntx := tx.bun // use underlying *bun.Tx interface for their query formatting
- res, err := buntx.ExecContext(ctx, query, args...)
- err = tx.raw.errHook(err)
- return res, err
-}
-
-// QueryContext wraps bun.Tx.QueryContext() with our own error processing, WITHOUT retry-busy timeouts (as will be mid-transaction).
-func (tx Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
- buntx := tx.bun // use underlying *bun.Tx interface for their query formatting
- rows, err := buntx.QueryContext(ctx, query, args...)
- err = tx.raw.errHook(err)
- return rows, err
-}
-
-// QueryRowContext wraps bun.Tx.QueryRowContext() with our own error processing, WITHOUT retry-busy timeouts (as will be mid-transaction).
-func (tx Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
- buntx := tx.bun // use underlying *bun.Tx interface for their query formatting
- row := buntx.QueryRowContext(ctx, query, args...)
- if err := tx.raw.errHook(row.Err()); err != nil {
- updateRowError(row, err) // set new error
- }
- return row
-}
-
-// Commit wraps bun.Tx.Commit() with retry-busy timeout and our own error processing.
-func (tx Tx) Commit() (err error) {
- buntx := tx.bun // use *bun.Tx interface
- err = retryOnBusy(context.TODO(), func() error {
- err = buntx.Commit()
- err = tx.raw.errHook(err)
- return err
- })
- return
-}
-
-// Rollback wraps bun.Tx.Rollback() with retry-busy timeout and our own error processing.
-func (tx Tx) Rollback() (err error) {
- buntx := tx.bun // use *bun.Tx interface
- err = retryOnBusy(context.TODO(), func() error {
- err = buntx.Rollback()
- err = tx.raw.errHook(err)
- return err
- })
- return
-}
-
-// Dialect is a direct call-through to bun.DB.Dialect().
-func (tx Tx) Dialect() schema.Dialect {
- return tx.bun.Dialect()
-}
-
-func (tx Tx) NewValues(model interface{}) *bun.ValuesQuery {
- // note: passing in rawtx as conn iface so no double query-hook
- // firing when passed through the bun.Tx.Query___() functions.
- return tx.bun.NewValues(model).Conn(&tx.raw)
-}
-
-func (tx Tx) NewMerge() *bun.MergeQuery {
- // note: passing in rawtx as conn iface so no double query-hook
- // firing when passed through the bun.Tx.Query___() functions.
- return tx.bun.NewMerge().Conn(&tx.raw)
-}
-
-func (tx Tx) NewSelect() *bun.SelectQuery {
- // note: passing in rawtx as conn iface so no double query-hook
- // firing when passed through the bun.Tx.Query___() functions.
- return tx.bun.NewSelect().Conn(&tx.raw)
-}
-
-func (tx Tx) NewInsert() *bun.InsertQuery {
- // note: passing in rawtx as conn iface so no double query-hook
- // firing when passed through the bun.Tx.Query___() functions.
- return tx.bun.NewInsert().Conn(&tx.raw)
-}
-
-func (tx Tx) NewUpdate() *bun.UpdateQuery {
- // note: passing in rawtx as conn iface so no double query-hook
- // firing when passed through the bun.Tx.Query___() functions.
- return tx.bun.NewUpdate().Conn(&tx.raw)
-}
-
-func (tx Tx) NewDelete() *bun.DeleteQuery {
- // note: passing in rawtx as conn iface so no double query-hook
- // firing when passed through the bun.Tx.Query___() functions.
- return tx.bun.NewDelete().Conn(&tx.raw)
-}
-
-func (tx Tx) NewRaw(query string, args ...interface{}) *bun.RawQuery {
- // note: passing in rawtx as conn iface so no double query-hook
- // firing when passed through the bun.Tx.Query___() functions.
- return tx.bun.NewRaw(query, args...).Conn(&tx.raw)
-}
-
-func (tx Tx) NewCreateTable() *bun.CreateTableQuery {
- // note: passing in rawtx as conn iface so no double query-hook
- // firing when passed through the bun.Tx.Query___() functions.
- return tx.bun.NewCreateTable().Conn(&tx.raw)
-}
-
-func (tx Tx) NewDropTable() *bun.DropTableQuery {
- // note: passing in rawtx as conn iface so no double query-hook
- // firing when passed through the bun.Tx.Query___() functions.
- return tx.bun.NewDropTable().Conn(&tx.raw)
-}
-
-func (tx Tx) NewCreateIndex() *bun.CreateIndexQuery {
- // note: passing in rawtx as conn iface so no double query-hook
- // firing when passed through the bun.Tx.Query___() functions.
- return tx.bun.NewCreateIndex().Conn(&tx.raw)
-}
-
-func (tx Tx) NewDropIndex() *bun.DropIndexQuery {
- // note: passing in rawtx as conn iface so no double query-hook
- // firing when passed through the bun.Tx.Query___() functions.
- return tx.bun.NewDropIndex().Conn(&tx.raw)
-}
-
-func (tx Tx) NewTruncateTable() *bun.TruncateTableQuery {
- // note: passing in rawtx as conn iface so no double query-hook
- // firing when passed through the bun.Tx.Query___() functions.
- return tx.bun.NewTruncateTable().Conn(&tx.raw)
-}
-
-func (tx Tx) NewAddColumn() *bun.AddColumnQuery {
- // note: passing in rawtx as conn iface so no double query-hook
- // firing when passed through the bun.Tx.Query___() functions.
- return tx.bun.NewAddColumn().Conn(&tx.raw)
-}
-
-func (tx Tx) NewDropColumn() *bun.DropColumnQuery {
- // note: passing in rawtx as conn iface so no double query-hook
- // firing when passed through the bun.Tx.Query___() functions.
- return tx.bun.NewDropColumn().Conn(&tx.raw)
-}
-
-type rawtx struct {
- // dialect specific error
- // processing function hook.
- errHook func(error) error
-
- // embedded raw
- // tx interface
- tx *sql.Tx
-}
-
-// ExecContext wraps sql.Tx.ExecContext() with our own error processing, WITHOUT retry-busy timeouts (as will be mid-transaction).
-func (tx *rawtx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
- res, err := tx.tx.ExecContext(ctx, query, args...)
- err = tx.errHook(err)
- return res, err
-}
-
-// QueryContext wraps sql.Tx.QueryContext() with our own error processing, WITHOUT retry-busy timeouts (as will be mid-transaction).
-func (tx *rawtx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
- rows, err := tx.tx.QueryContext(ctx, query, args...)
- err = tx.errHook(err)
- return rows, err
-}
-
-// QueryRowContext wraps sql.Tx.QueryRowContext() with our own error processing, WITHOUT retry-busy timeouts (as will be mid-transaction).
-func (tx *rawtx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
- row := tx.tx.QueryRowContext(ctx, query, args...)
- if err := tx.errHook(row.Err()); err != nil {
- updateRowError(row, err) // set new error
- }
- return row
-}
-
-// updateRowError updates an sql.Row's internal error field using the unsafe package.
-func updateRowError(sqlrow *sql.Row, err error) {
- type row struct {
- err error
- rows *sql.Rows
- }
-
- // compile-time check to ensure sql.Row not changed.
- if unsafe.Sizeof(row{}) != unsafe.Sizeof(sql.Row{}) {
- panic("sql.Row has changed definition")
- }
-
- // this code is awful and i must be shamed for this.
- (*row)(unsafe.Pointer(sqlrow)).err = err
-}
-
-// retryOnBusy will retry given function on returned 'errBusy'.
-func retryOnBusy(ctx context.Context, fn func() error) error {
- var backoff time.Duration
-
- for i := 0; ; i++ {
- // Perform func.
- err := fn()
-
- if err != errBusy {
- // May be nil, or may be
- // some other error, either
- // way return here.
- return err
- }
-
- // backoff according to a multiplier of 2ms * 2^2n,
- // up to a maximum possible backoff time of 5 minutes.
- //
- // this works out as the following:
- // 4ms
- // 16ms
- // 64ms
- // 256ms
- // 1.024s
- // 4.096s
- // 16.384s
- // 1m5.536s
- // 4m22.144s
- backoff = 2 * time.Millisecond * (1 << (2*i + 1))
- if backoff >= 5*time.Minute {
- break
- }
-
- select {
- // Context cancelled.
- case <-ctx.Done():
-
- // Backoff for some time.
- case <-time.After(backoff):
- }
- }
-
- return gtserror.Newf("%w (waited > %s)", db.ErrBusyTimeout, backoff)
-}
diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go
index 2398e52c2..1254d79c8 100644
--- a/internal/db/bundb/domain.go
+++ b/internal/db/bundb/domain.go
@@ -31,7 +31,7 @@
)
type domainDB struct {
- db *DB
+ db *bun.DB
state *state.State
}
diff --git a/internal/db/bundb/drivers.go b/internal/db/bundb/drivers.go
new file mode 100644
index 000000000..14d84e6fa
--- /dev/null
+++ b/internal/db/bundb/drivers.go
@@ -0,0 +1,267 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package bundb
+
+import (
+ "context"
+ "database/sql"
+ "database/sql/driver"
+ "time"
+ _ "unsafe" // linkname shenanigans
+
+ pgx "github.com/jackc/pgx/v5/stdlib"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtserror"
+ "modernc.org/sqlite"
+)
+
+var (
+ // global SQL driver instances.
+ postgresDriver = pgx.GetDefaultDriver()
+ sqliteDriver = getSQLiteDriver()
+)
+
+func init() {
+ sql.Register("pgx-gts", &PostgreSQLDriver{})
+ sql.Register("sqlite-gts", &SQLiteDriver{})
+}
+
+//go:linkname getSQLiteDriver modernc.org/sqlite.newDriver
+func getSQLiteDriver() *sqlite.Driver
+
+// PostgreSQLDriver is our own wrapper around the
+// pgx/stdlib.Driver{} type in order to wrap further
+// SQL driver types with our own err processing.
+type PostgreSQLDriver struct{}
+
+func (d *PostgreSQLDriver) Open(name string) (driver.Conn, error) {
+ c, err := postgresDriver.Open(name)
+ if err != nil {
+ return nil, err
+ }
+ return &PostgreSQLConn{conn: c.(conn)}, nil
+}
+
+type PostgreSQLConn struct{ conn }
+
+func (c *PostgreSQLConn) Begin() (driver.Tx, error) {
+ return c.BeginTx(context.Background(), driver.TxOptions{})
+}
+
+func (c *PostgreSQLConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
+ tx, err := c.conn.BeginTx(ctx, opts)
+ err = processPostgresError(err)
+ return tx, err
+}
+
+func (c *PostgreSQLConn) Prepare(query string) (driver.Stmt, error) {
+ return c.PrepareContext(context.Background(), query)
+}
+
+func (c *PostgreSQLConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
+ stmt, err := c.conn.PrepareContext(ctx, query)
+ err = processPostgresError(err)
+ return stmt, err
+}
+
+func (c *PostgreSQLConn) Exec(query string, args []driver.NamedValue) (driver.Result, error) {
+ return c.ExecContext(context.Background(), query, args)
+}
+
+func (c *PostgreSQLConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
+ result, err := c.conn.ExecContext(ctx, query, args)
+ err = processPostgresError(err)
+ return result, err
+}
+
+func (c *PostgreSQLConn) Query(query string, args []driver.NamedValue) (driver.Rows, error) {
+ return c.QueryContext(context.Background(), query, args)
+}
+
+func (c *PostgreSQLConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
+ rows, err := c.conn.QueryContext(ctx, query, args)
+ err = processPostgresError(err)
+ return rows, err
+}
+
+func (c *PostgreSQLConn) Close() error {
+ return c.conn.Close()
+}
+
+type PostgreSQLTx struct{ driver.Tx }
+
+func (tx *PostgreSQLTx) Commit() error {
+ err := tx.Tx.Commit()
+ return processPostgresError(err)
+}
+
+func (tx *PostgreSQLTx) Rollback() error {
+ err := tx.Tx.Rollback()
+ return processPostgresError(err)
+}
+
+// SQLiteDriver is our own wrapper around the
+// sqlite.Driver{} type in order to wrap further
+// SQL driver types with our own functionality,
+// e.g. hooks, retries and err processing.
+type SQLiteDriver struct{}
+
+func (d *SQLiteDriver) Open(name string) (driver.Conn, error) {
+ c, err := sqliteDriver.Open(name)
+ if err != nil {
+ return nil, err
+ }
+ return &SQLiteConn{conn: c.(conn)}, nil
+}
+
+type SQLiteConn struct{ conn }
+
+func (c *SQLiteConn) Begin() (driver.Tx, error) {
+ return c.BeginTx(context.Background(), driver.TxOptions{})
+}
+
+func (c *SQLiteConn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx, err error) {
+ err = retryOnBusy(ctx, func() error {
+ tx, err = c.conn.BeginTx(ctx, opts)
+ err = processSQLiteError(err)
+ return err
+ })
+ return &SQLiteTx{Context: ctx, Tx: tx}, nil
+}
+
+func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
+ return c.PrepareContext(context.Background(), query)
+}
+
+func (c *SQLiteConn) PrepareContext(ctx context.Context, query string) (stmt driver.Stmt, err error) {
+ err = retryOnBusy(ctx, func() error {
+ stmt, err = c.conn.PrepareContext(ctx, query)
+ err = processSQLiteError(err)
+ return err
+ })
+ return
+}
+
+func (c *SQLiteConn) Exec(query string, args []driver.NamedValue) (driver.Result, error) {
+ return c.ExecContext(context.Background(), query, args)
+}
+
+func (c *SQLiteConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (result driver.Result, err error) {
+ err = retryOnBusy(ctx, func() error {
+ result, err = c.conn.ExecContext(ctx, query, args)
+ err = processSQLiteError(err)
+ return err
+ })
+ return
+}
+
+func (c *SQLiteConn) Query(query string, args []driver.NamedValue) (driver.Rows, error) {
+ return c.QueryContext(context.Background(), query, args)
+}
+
+func (c *SQLiteConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) {
+ err = retryOnBusy(ctx, func() error {
+ rows, err = c.conn.QueryContext(ctx, query, args)
+ err = processSQLiteError(err)
+ return err
+ })
+ return
+}
+
+func (c *SQLiteConn) Close() error {
+ // see: https://www.sqlite.org/pragma.html#pragma_optimize
+ const onClose = "PRAGMA analysis_limit=1000; PRAGMA optimize;"
+ _, _ = c.conn.ExecContext(context.Background(), onClose, nil)
+ return c.conn.Close()
+}
+
+type SQLiteTx struct {
+ context.Context
+ driver.Tx
+}
+
+func (tx *SQLiteTx) Commit() (err error) {
+ err = retryOnBusy(tx.Context, func() error {
+ err = tx.Tx.Commit()
+ err = processSQLiteError(err)
+ return err
+ })
+ return
+}
+
+func (tx *SQLiteTx) Rollback() (err error) {
+ err = retryOnBusy(tx.Context, func() error {
+ err = tx.Tx.Rollback()
+ err = processSQLiteError(err)
+ return err
+ })
+ return
+}
+
+type conn interface {
+ driver.Conn
+ driver.ConnPrepareContext
+ driver.ExecerContext
+ driver.QueryerContext
+ driver.ConnBeginTx
+}
+
+// retryOnBusy will retry given function on returned 'errBusy'.
+func retryOnBusy(ctx context.Context, fn func() error) error {
+ var backoff time.Duration
+
+ for i := 0; ; i++ {
+ // Perform func.
+ err := fn()
+
+ if err != errBusy {
+ // May be nil, or may be
+ // some other error, either
+ // way return here.
+ return err
+ }
+
+ // backoff according to a multiplier of 2ms * 2^2n,
+ // up to a maximum possible backoff time of 5 minutes.
+ //
+ // this works out as the following:
+ // 4ms
+ // 16ms
+ // 64ms
+ // 256ms
+ // 1.024s
+ // 4.096s
+ // 16.384s
+ // 1m5.536s
+ // 4m22.144s
+ backoff = 2 * time.Millisecond * (1 << (2*i + 1))
+ if backoff >= 5*time.Minute {
+ break
+ }
+
+ select {
+ // Context cancelled.
+ case <-ctx.Done():
+
+ // Backoff for some time.
+ case <-time.After(backoff):
+ }
+ }
+
+ return gtserror.Newf("%w (waited > %s)", db.ErrBusyTimeout, backoff)
+}
diff --git a/internal/db/bundb/emoji.go b/internal/db/bundb/emoji.go
index 608cb6417..69d33eede 100644
--- a/internal/db/bundb/emoji.go
+++ b/internal/db/bundb/emoji.go
@@ -38,7 +38,7 @@
)
type emojiDB struct {
- db *DB
+ db *bun.DB
state *state.State
}
@@ -109,7 +109,7 @@ func (e *emojiDB) DeleteEmojiByID(ctx context.Context, id string) error {
return err
}
- return e.db.RunInTx(ctx, func(tx Tx) error {
+ return e.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// Delete relational links between this emoji
// and any statuses using it, returning the
// status IDs so we can later update them.
diff --git a/internal/db/bundb/headerfilter.go b/internal/db/bundb/headerfilter.go
index 087b65c82..b02d9249e 100644
--- a/internal/db/bundb/headerfilter.go
+++ b/internal/db/bundb/headerfilter.go
@@ -29,7 +29,7 @@
)
type headerFilterDB struct {
- db *DB
+ db *bun.DB
state *state.State
}
diff --git a/internal/db/bundb/instance.go b/internal/db/bundb/instance.go
index d506e0a31..5f96f9a26 100644
--- a/internal/db/bundb/instance.go
+++ b/internal/db/bundb/instance.go
@@ -34,7 +34,7 @@
)
type instanceDB struct {
- db *DB
+ db *bun.DB
state *state.State
}
diff --git a/internal/db/bundb/list.go b/internal/db/bundb/list.go
index 5f95d3c24..fb97c8fe7 100644
--- a/internal/db/bundb/list.go
+++ b/internal/db/bundb/list.go
@@ -35,7 +35,7 @@
)
type listDB struct {
- db *DB
+ db *bun.DB
state *state.State
}
@@ -198,7 +198,7 @@ func (l *listDB) DeleteListByID(ctx context.Context, id string) error {
}
}()
- return l.db.RunInTx(ctx, func(tx Tx) error {
+ return l.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// Delete all entries attached to list.
if _, err := tx.NewDelete().
Table("list_entries").
@@ -515,7 +515,7 @@ func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEnt
}()
// Finally, insert each list entry into the database.
- return l.db.RunInTx(ctx, func(tx Tx) error {
+ return l.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
for _, entry := range entries {
entry := entry // rescope
if err := l.state.Caches.GTS.ListEntry.Store(entry, func() error {
diff --git a/internal/db/bundb/marker.go b/internal/db/bundb/marker.go
index b1dedb4f1..0ae50f269 100644
--- a/internal/db/bundb/marker.go
+++ b/internal/db/bundb/marker.go
@@ -30,7 +30,7 @@
)
type markerDB struct {
- db *DB
+ db *bun.DB
state *state.State
}
@@ -85,7 +85,7 @@ func (m *markerDB) UpdateMarker(ctx context.Context, marker *gtsmodel.Marker) er
// Optimistic concurrency control: start a transaction, try to update a row with a previously retrieved version.
// If the update in the transaction fails to actually change anything, another update happened concurrently, and
// this update should be retried by the caller, which in this case involves sending HTTP 409 to the API client.
- return m.db.RunInTx(ctx, func(tx Tx) error {
+ return m.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
result, err := tx.NewUpdate().
Model(marker).
WherePK().
diff --git a/internal/db/bundb/media.go b/internal/db/bundb/media.go
index ced38a588..99ef30d22 100644
--- a/internal/db/bundb/media.go
+++ b/internal/db/bundb/media.go
@@ -34,7 +34,7 @@
)
type mediaDB struct {
- db *DB
+ db *bun.DB
state *state.State
}
@@ -151,7 +151,7 @@ func (m *mediaDB) DeleteAttachment(ctx context.Context, id string) error {
defer m.state.Caches.GTS.Media.Invalidate("ID", id)
// Delete media attachment in new transaction.
- err = m.db.RunInTx(ctx, func(tx Tx) error {
+ err = m.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
if media.AccountID != "" {
var account gtsmodel.Account
diff --git a/internal/db/bundb/mention.go b/internal/db/bundb/mention.go
index b069423bb..156469544 100644
--- a/internal/db/bundb/mention.go
+++ b/internal/db/bundb/mention.go
@@ -33,7 +33,7 @@
)
type mentionDB struct {
- db *DB
+ db *bun.DB
state *state.State
}
diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go
index ed34222fb..3f3d5fbd6 100644
--- a/internal/db/bundb/notification.go
+++ b/internal/db/bundb/notification.go
@@ -34,7 +34,7 @@
)
type notificationDB struct {
- db *DB
+ db *bun.DB
state *state.State
}
diff --git a/internal/db/bundb/poll.go b/internal/db/bundb/poll.go
index 0dfb15621..37a1f26ab 100644
--- a/internal/db/bundb/poll.go
+++ b/internal/db/bundb/poll.go
@@ -34,7 +34,7 @@
)
type pollDB struct {
- db *DB
+ db *bun.DB
state *state.State
}
@@ -154,7 +154,7 @@ func (p *pollDB) UpdatePoll(ctx context.Context, poll *gtsmodel.Poll, cols ...st
poll.CheckVotes()
return p.state.Caches.GTS.Poll.Store(poll, func() error {
- return p.db.RunInTx(ctx, func(tx Tx) error {
+ return p.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// Update the status' "updated_at" field.
if _, err := tx.NewUpdate().
Table("statuses").
@@ -362,7 +362,7 @@ func (p *pollDB) PopulatePollVote(ctx context.Context, vote *gtsmodel.PollVote)
func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error {
return p.state.Caches.GTS.PollVote.Store(vote, func() error {
- return p.db.RunInTx(ctx, func(tx Tx) error {
+ return p.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// Try insert vote into database.
if _, err := tx.NewInsert().
Model(vote).
@@ -398,7 +398,7 @@ func (p *pollDB) PutPollVote(ctx context.Context, vote *gtsmodel.PollVote) error
}
func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error {
- err := p.db.RunInTx(ctx, func(tx Tx) error {
+ err := p.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// Delete all votes in poll.
res, err := tx.NewDelete().
Table("poll_votes").
@@ -469,7 +469,7 @@ func (p *pollDB) DeletePollVotes(ctx context.Context, pollID string) error {
}
func (p *pollDB) DeletePollVoteBy(ctx context.Context, pollID string, accountID string) error {
- err := p.db.RunInTx(ctx, func(tx Tx) error {
+ err := p.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// Slice should only ever be of length
// 0 or 1; it's a slice of slices only
// because we can't LIMIT deletes to 1.
@@ -569,7 +569,7 @@ func (p *pollDB) DeletePollVotesByAccountID(ctx context.Context, accountID strin
}
// newSelectPollVotes returns a new select query for all rows in the poll_votes table with poll_id = pollID.
-func newSelectPollVotes(db *DB, pollID string) *bun.SelectQuery {
+func newSelectPollVotes(db *bun.DB, pollID string) *bun.SelectQuery {
return db.NewSelect().
TableExpr("?", bun.Ident("poll_votes")).
ColumnExpr("?", bun.Ident("id")).
diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go
index 4c50862a1..71ae37545 100644
--- a/internal/db/bundb/relationship.go
+++ b/internal/db/bundb/relationship.go
@@ -31,7 +31,7 @@
)
type relationshipDB struct {
- db *DB
+ db *bun.DB
state *state.State
}
@@ -299,7 +299,7 @@ func (r *relationshipDB) getAccountBlockIDs(ctx context.Context, accountID strin
}
// newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID.
-func newSelectFollowRequests(db *DB, accountID string) *bun.SelectQuery {
+func newSelectFollowRequests(db *bun.DB, accountID string) *bun.SelectQuery {
return db.NewSelect().
TableExpr("?", bun.Ident("follow_requests")).
ColumnExpr("?", bun.Ident("id")).
@@ -308,7 +308,7 @@ func newSelectFollowRequests(db *DB, accountID string) *bun.SelectQuery {
}
// newSelectFollowRequesting returns a new select query for all rows in the follow_requests table with account_id = accountID.
-func newSelectFollowRequesting(db *DB, accountID string) *bun.SelectQuery {
+func newSelectFollowRequesting(db *bun.DB, accountID string) *bun.SelectQuery {
return db.NewSelect().
TableExpr("?", bun.Ident("follow_requests")).
ColumnExpr("?", bun.Ident("id")).
@@ -317,7 +317,7 @@ func newSelectFollowRequesting(db *DB, accountID string) *bun.SelectQuery {
}
// newSelectFollows returns a new select query for all rows in the follows table with account_id = accountID.
-func newSelectFollows(db *DB, accountID string) *bun.SelectQuery {
+func newSelectFollows(db *bun.DB, accountID string) *bun.SelectQuery {
return db.NewSelect().
Table("follows").
Column("id").
@@ -327,7 +327,7 @@ func newSelectFollows(db *DB, accountID string) *bun.SelectQuery {
// newSelectLocalFollows returns a new select query for all rows in the follows table with
// account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local).
-func newSelectLocalFollows(db *DB, accountID string) *bun.SelectQuery {
+func newSelectLocalFollows(db *bun.DB, accountID string) *bun.SelectQuery {
return db.NewSelect().
Table("follows").
Column("id").
@@ -344,7 +344,7 @@ func newSelectLocalFollows(db *DB, accountID string) *bun.SelectQuery {
}
// newSelectFollowers returns a new select query for all rows in the follows table with target_account_id = accountID.
-func newSelectFollowers(db *DB, accountID string) *bun.SelectQuery {
+func newSelectFollowers(db *bun.DB, accountID string) *bun.SelectQuery {
return db.NewSelect().
Table("follows").
Column("id").
@@ -354,7 +354,7 @@ func newSelectFollowers(db *DB, accountID string) *bun.SelectQuery {
// newSelectLocalFollowers returns a new select query for all rows in the follows table with
// target_account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local).
-func newSelectLocalFollowers(db *DB, accountID string) *bun.SelectQuery {
+func newSelectLocalFollowers(db *bun.DB, accountID string) *bun.SelectQuery {
return db.NewSelect().
Table("follows").
Column("id").
@@ -371,7 +371,7 @@ func newSelectLocalFollowers(db *DB, accountID string) *bun.SelectQuery {
}
// newSelectBlocks returns a new select query for all rows in the blocks table with account_id = accountID.
-func newSelectBlocks(db *DB, accountID string) *bun.SelectQuery {
+func newSelectBlocks(db *bun.DB, accountID string) *bun.SelectQuery {
return db.NewSelect().
TableExpr("?", bun.Ident("blocks")).
ColumnExpr("?", bun.Ident("id")).
diff --git a/internal/db/bundb/report.go b/internal/db/bundb/report.go
index 5b0ae17f3..486bf09f0 100644
--- a/internal/db/bundb/report.go
+++ b/internal/db/bundb/report.go
@@ -32,7 +32,7 @@
)
type reportDB struct {
- db *DB
+ db *bun.DB
state *state.State
}
diff --git a/internal/db/bundb/rule.go b/internal/db/bundb/rule.go
index ebfa89d15..e36053c38 100644
--- a/internal/db/bundb/rule.go
+++ b/internal/db/bundb/rule.go
@@ -32,7 +32,7 @@
)
type ruleDB struct {
- db *DB
+ db *bun.DB
state *state.State
}
diff --git a/internal/db/bundb/search.go b/internal/db/bundb/search.go
index f9c2df1f8..f8ae529f7 100644
--- a/internal/db/bundb/search.go
+++ b/internal/db/bundb/search.go
@@ -57,7 +57,7 @@
// This isn't ideal, of course, but at least we could cover the most common use case of
// a caller paging down through results.
type searchDB struct {
- db *DB
+ db *bun.DB
state *state.State
}
diff --git a/internal/db/bundb/session.go b/internal/db/bundb/session.go
index 9310a6463..2177a57ae 100644
--- a/internal/db/bundb/session.go
+++ b/internal/db/bundb/session.go
@@ -24,10 +24,11 @@
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
+ "github.com/uptrace/bun"
)
type sessionDB struct {
- db *DB
+ db *bun.DB
}
func (s *sessionDB) GetSession(ctx context.Context) (*gtsmodel.RouterSession, error) {
diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go
index 07a09050a..6d1788b5d 100644
--- a/internal/db/bundb/status.go
+++ b/internal/db/bundb/status.go
@@ -34,7 +34,7 @@
)
type statusDB struct {
- db *DB
+ db *bun.DB
state *state.State
}
@@ -330,7 +330,7 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) error
// It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook.
//
- return s.db.RunInTx(ctx, func(tx Tx) error {
+ return s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// create links between this status and any emojis it uses
for _, i := range status.EmojiIDs {
if _, err := tx.
@@ -414,7 +414,7 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co
// It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook.
//
- return s.db.RunInTx(ctx, func(tx Tx) error {
+ return s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// create links between this status and any emojis it uses
for _, i := range status.EmojiIDs {
if _, err := tx.
@@ -509,7 +509,7 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) error {
// On return ensure status invalidated from cache.
defer s.state.Caches.GTS.Status.Invalidate("ID", id)
- return s.db.RunInTx(ctx, func(tx Tx) error {
+ return s.db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
// delete links between this status and any emojis it uses
if _, err := tx.
NewDelete().
@@ -697,6 +697,5 @@ func (s *statusDB) IsStatusBookmarkedBy(ctx context.Context, status *gtsmodel.St
TableExpr("? AS ?", bun.Ident("status_bookmarks"), bun.Ident("status_bookmark")).
Where("? = ?", bun.Ident("status_bookmark.status_id"), status.ID).
Where("? = ?", bun.Ident("status_bookmark.account_id"), accountID)
-
- return s.db.Exists(ctx, q)
+ return exists(ctx, q)
}
diff --git a/internal/db/bundb/statusbookmark.go b/internal/db/bundb/statusbookmark.go
index 742c13966..73fced9c3 100644
--- a/internal/db/bundb/statusbookmark.go
+++ b/internal/db/bundb/statusbookmark.go
@@ -29,7 +29,7 @@
)
type statusBookmarkDB struct {
- db *DB
+ db *bun.DB
state *state.State
}
diff --git a/internal/db/bundb/statusfave.go b/internal/db/bundb/statusfave.go
index e0f018b68..d04578076 100644
--- a/internal/db/bundb/statusfave.go
+++ b/internal/db/bundb/statusfave.go
@@ -35,7 +35,7 @@
)
type statusFaveDB struct {
- db *DB
+ db *bun.DB
state *state.State
}
diff --git a/internal/db/bundb/tag.go b/internal/db/bundb/tag.go
index 66ee8cb3a..e6297d2ab 100644
--- a/internal/db/bundb/tag.go
+++ b/internal/db/bundb/tag.go
@@ -28,7 +28,7 @@
)
type tagDB struct {
- db *DB
+ db *bun.DB
state *state.State
}
diff --git a/internal/db/bundb/thread.go b/internal/db/bundb/thread.go
index 34c5f783a..a75515062 100644
--- a/internal/db/bundb/thread.go
+++ b/internal/db/bundb/thread.go
@@ -28,7 +28,7 @@
)
type threadDB struct {
- db *DB
+ db *bun.DB
state *state.State
}
diff --git a/internal/db/bundb/timeline.go b/internal/db/bundb/timeline.go
index f2ba2a9d1..e6c7e482d 100644
--- a/internal/db/bundb/timeline.go
+++ b/internal/db/bundb/timeline.go
@@ -34,7 +34,7 @@
)
type timelineDB struct {
- db *DB
+ db *bun.DB
state *state.State
}
diff --git a/internal/db/bundb/tombstone.go b/internal/db/bundb/tombstone.go
index c0e439720..64169213e 100644
--- a/internal/db/bundb/tombstone.go
+++ b/internal/db/bundb/tombstone.go
@@ -27,7 +27,7 @@
)
type tombstoneDB struct {
- db *DB
+ db *bun.DB
state *state.State
}
diff --git a/internal/db/bundb/user.go b/internal/db/bundb/user.go
index a6fa142f2..2854c0caa 100644
--- a/internal/db/bundb/user.go
+++ b/internal/db/bundb/user.go
@@ -31,7 +31,7 @@
)
type userDB struct {
- db *DB
+ db *bun.DB
state *state.State
}
diff --git a/internal/db/bundb/util.go b/internal/db/bundb/util.go
index cee20bbe1..e2dd392dc 100644
--- a/internal/db/bundb/util.go
+++ b/internal/db/bundb/util.go
@@ -18,6 +18,8 @@
package bundb
import (
+ "context"
+ "database/sql"
"slices"
"strings"
@@ -113,6 +115,25 @@ func whereStartsLike(
)
}
+// exists checks the results of a SelectQuery for the existence of the data in question, masking ErrNoEntries errors.
+func exists(ctx context.Context, query *bun.SelectQuery) (bool, error) {
+ exists, err := query.Exists(ctx)
+ switch err {
+ case nil:
+ return exists, nil
+ case sql.ErrNoRows:
+ return false, nil
+ default:
+ return false, err
+ }
+}
+
+// notExists checks the results of a SelectQuery for the non-existence of the data in question, masking ErrNoEntries errors.
+func notExists(ctx context.Context, query *bun.SelectQuery) (bool, error) {
+ exists, err := exists(ctx, query)
+ return !exists, err
+}
+
// loadPagedIDs loads a page of IDs from given SliceCache by `key`, resorting to `loadDESC` if required. Uses `page` to sort + page resulting IDs.
// NOTE: IDs returned from `cache` / `loadDESC` MUST be in descending order, otherwise paging will not work correctly / return things out of order.
func loadPagedIDs(cache *cache.SliceCache[string], key string, page *paging.Page, loadDESC func() ([]string, error)) ([]string, error) {