gotosocial/internal/db/bundb/relationship_mute.go

319 lines
8.6 KiB
Go
Raw Normal View History

// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package bundb
import (
"context"
"errors"
"slices"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/paging"
"github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect"
)
// IsMuted reports whether sourceAccountID currently has a mute in place
// against targetAccountID. A missing mute (db.ErrNoEntries) is not treated
// as an error; any other lookup failure is returned to the caller.
func (r *relationshipDB) IsMuted(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, error) {
	// Only existence matters here, so skip populating the mute's accounts.
	barebonesCtx := gtscontext.SetBarebones(ctx)

	found, err := r.GetMute(barebonesCtx, sourceAccountID, targetAccountID)
	if err == nil || errors.Is(err, db.ErrNoEntries) {
		return found != nil, nil
	}

	return false, err
}
// GetMuteByID fetches the mute with the given ID, preferring the UserMute
// cache and falling back to a database query on a miss. Unless ctx is marked
// barebones, the mute's accounts are populated before it is returned.
func (r *relationshipDB) GetMuteByID(ctx context.Context, id string) (*gtsmodel.UserMute, error) {
	queryByID := func(m *gtsmodel.UserMute) error {
		q := r.db.NewSelect().Model(m)
		q = q.Where("? = ?", bun.Ident("id"), id)
		return q.Scan(ctx)
	}

	return r.getMute(ctx, "ID", queryByID, id)
}
// GetMute fetches the mute created by sourceAccountID against
// targetAccountID, preferring the UserMute cache and falling back to a
// database query on a miss. Unless ctx is marked barebones, the mute's
// accounts are populated before it is returned.
func (r *relationshipDB) GetMute(
	ctx context.Context,
	sourceAccountID string,
	targetAccountID string,
) (*gtsmodel.UserMute, error) {
	queryByAccounts := func(m *gtsmodel.UserMute) error {
		q := r.db.NewSelect().Model(m)
		q = q.Where("? = ?", bun.Ident("account_id"), sourceAccountID)
		q = q.Where("? = ?", bun.Ident("target_account_id"), targetAccountID)
		return q.Scan(ctx)
	}

	// Cache index keyed on the (origin account, target account) pair.
	const lookup = "AccountID,TargetAccountID"
	return r.getMute(ctx, lookup, queryByAccounts, sourceAccountID, targetAccountID)
}
// CountAccountMutes returns the number of mute IDs that getAccountMuteIDs
// lists (unpaged) for accountID, together with any error from that lookup.
func (r *relationshipDB) CountAccountMutes(ctx context.Context, accountID string) (int, error) {
	ids, err := r.getAccountMuteIDs(ctx, accountID, nil)
	count := len(ids)
	return count, err
}
// getMutesByIDs loads the mutes with the given IDs. Cached mutes are served
// from the UserMute cache; the rest are fetched with a single "id IN (...)"
// query. The result is reordered to follow ids. Unless ctx is barebones,
// each mute is populated, and any mute that fails to populate is logged and
// dropped from the result rather than failing the whole call.
func (r *relationshipDB) getMutesByIDs(ctx context.Context, ids []string) ([]*gtsmodel.UserMute, error) {
	loadUncached := func(missing []string) ([]*gtsmodel.UserMute, error) {
		n := len(missing)
		if n == 0 {
			// Everything was cached; no query needed.
			return nil, nil
		}

		fetched := make([]*gtsmodel.UserMute, 0, n)
		err := r.db.NewSelect().
			Model(&fetched).
			Where("? IN (?)", bun.Ident("id"), bun.In(missing)).
			Scan(ctx)
		if err != nil {
			return nil, err
		}

		return fetched, nil
	}

	mutes, err := r.state.Caches.DB.UserMute.LoadIDs("ID", ids, loadUncached)
	if err != nil {
		return nil, err
	}

	// The loaded mutes may not be in the order of ids; restore it.
	util.OrderBy(mutes, ids, func(m *gtsmodel.UserMute) string { return m.ID })

	if !gtscontext.Barebones(ctx) {
		// Drop mutes that cannot be populated, so callers
		// don't need nil checks on their account fields.
		mutes = slices.DeleteFunc(mutes, func(m *gtsmodel.UserMute) bool {
			popErr := r.populateMute(ctx, m)
			if popErr == nil {
				return false
			}
			log.Errorf(ctx, "error populating mute %s: %v", m.ID, popErr)
			return true
		})
	}

	return mutes, nil
}
// getMute fetches a single mute through the UserMute cache. lookup names the
// cache index and keyParts form its key. On a cache miss, dbQuery is invoked
// to scan the mute from the database. Unless ctx is marked barebones, the
// mute's accounts are populated; a population failure is returned as an error.
func (r *relationshipDB) getMute(
	ctx context.Context,
	lookup string,
	dbQuery func(*gtsmodel.UserMute) error,
	keyParts ...any,
) (*gtsmodel.UserMute, error) {
	loadFromDB := func() (*gtsmodel.UserMute, error) {
		fresh := new(gtsmodel.UserMute)
		if err := dbQuery(fresh); err != nil {
			return nil, err
		}
		return fresh, nil
	}

	mute, err := r.state.Caches.DB.UserMute.LoadOne(lookup, loadFromDB, keyParts...)
	if err != nil {
		// Returned unchanged (presumably already processed by the DB layer).
		return nil, err
	}

	if !gtscontext.Barebones(ctx) {
		if err := r.populateMute(ctx, mute); err != nil {
			return nil, err
		}
	}

	return mute, nil
}
// populateMute fills in mute.Account and mute.TargetAccount when they are
// unset, fetching each by ID with a barebones context. Errors for the two
// fields are collected and returned combined, so a successful lookup of one
// account is still stored even when the other lookup fails.
func (r *relationshipDB) populateMute(ctx context.Context, mute *gtsmodel.UserMute) error {
	var errs gtserror.MultiError

	barebonesCtx := gtscontext.SetBarebones(ctx)

	if mute.Account == nil {
		// Origin account not set; fetch it from the database.
		account, err := r.state.DB.GetAccountByID(barebonesCtx, mute.AccountID)
		mute.Account = account
		if err != nil {
			errs.Appendf("error populating mute account: %w", err)
		}
	}

	if mute.TargetAccount == nil {
		// Target account not set; fetch it from the database.
		target, err := r.state.DB.GetAccountByID(barebonesCtx, mute.TargetAccountID)
		mute.TargetAccount = target
		if err != nil {
			errs.Appendf("error populating mute target account: %w", err)
		}
	}

	return errs.Combine()
}
// PutMute upserts mute into the database, resolving conflicts on the "id"
// constraint. The write is performed through the UserMute cache's Store,
// which wraps the database write (presumably caching mute on success).
func (r *relationshipDB) PutMute(ctx context.Context, mute *gtsmodel.UserMute) error {
	upsert := func() error {
		q := NewUpsert(r.db).Model(mute).Constraint("id")
		_, err := q.Exec(ctx)
		return err
	}

	return r.state.Caches.DB.UserMute.Store(mute, upsert)
}
// DeleteMuteByID deletes the mute with the given ID. A mute that does not
// exist is not an error. The cached copy is invalidated once the delete
// has been attempted.
func (r *relationshipDB) DeleteMuteByID(ctx context.Context, id string) error {
	// The mute must be in the cache for its invalidate callback to fire
	// (which in turn invalidates other caches), so load it first.
	if _, err := r.GetMuteByID(gtscontext.SetBarebones(ctx), id); err != nil {
		if errors.Is(err, db.ErrNoEntries) {
			// Nothing to delete.
			return nil
		}
		return err
	}

	// Evict the now-cached mute after the delete below.
	defer r.state.Caches.DB.UserMute.Invalidate("ID", id)

	_, err := r.db.NewDelete().
		Table("user_mutes").
		Where("? = ?", bun.Ident("id"), id).
		Exec(ctx)
	return err
}
// DeleteAccountMutes deletes every mute in which accountID is either the
// muting account (account_id) or the muted account (target_account_id).
//
// Each affected mute is loaded into the cache before deletion so that its
// invalidate callback fires, which in turn invalidates related caches
// (e.g. visibility). On return, the account's cached outgoing and incoming
// mutes are invalidated whether or not the delete succeeded.
func (r *relationshipDB) DeleteAccountMutes(ctx context.Context, accountID string) error {
	var muteIDs []string

	// Get full list of IDs of both outgoing and incoming mutes.
	if err := r.db.NewSelect().
		Column("id").
		Table("user_mutes").
		Where("? = ? OR ? = ?",
			bun.Ident("account_id"),
			accountID,
			bun.Ident("target_account_id"),
			accountID,
		).
		Scan(ctx, &muteIDs); err != nil {
		return err
	}

	defer func() {
		// Invalidate all account's incoming / outgoing mutes on return.
		r.state.Caches.DB.UserMute.Invalidate("AccountID", accountID)
		r.state.Caches.DB.UserMute.Invalidate("TargetAccountID", accountID)
	}()

	if len(muteIDs) == 0 {
		// Nothing to load or delete.
		return nil
	}

	// Load exactly the mutes about to be deleted into the cache. This *really*
	// isn't great, but it is the only way to ensure all related caches are
	// invalidated correctly. GetAccountMutes is NOT sufficient here: it only
	// lists unexpired mutes created BY this account, so incoming and expired
	// mutes would be deleted without ever being cached or invalidated.
	if _, err := r.getMutesByIDs(gtscontext.SetBarebones(ctx), muteIDs); err != nil &&
		!errors.Is(err, db.ErrNoEntries) {
		return err
	}

	// Finally delete all from DB.
	_, err := r.db.NewDelete().
		Table("user_mutes").
		Where("? IN (?)", bun.Ident("id"), bun.In(muteIDs)).
		Exec(ctx)
	return err
}
// GetAccountMutes returns the mutes whose IDs getAccountMuteIDs lists for
// accountID, restricted to the given page (nil for all), in the same order
// as those IDs. Any error from listing the IDs is returned with a nil slice.
func (r *relationshipDB) GetAccountMutes(
	ctx context.Context,
	accountID string,
	page *paging.Page,
) ([]*gtsmodel.UserMute, error) {
	ids, listErr := r.getAccountMuteIDs(ctx, accountID, page)
	if listErr != nil {
		return nil, listErr
	}

	// Resolve the IDs to mute models (cache first, database for the rest).
	return r.getMutesByIDs(ctx, ids)
}
// getAccountMuteIDs returns the IDs of mutes created by accountID (rows where
// account_id = accountID) that either never expire (expires_at IS NULL) or
// whose expires_at is later than the current time, ordered by ID descending.
// IDs are loaded through the UserMuteIDs cache and paged per page (nil for
// all). db.ErrNoEntries from the query is treated as an empty result.
//
// NOTE(review): the expiry filter is only evaluated when the ID list is
// (re)loaded; a cached list may still hold mutes that have expired since.
func (r *relationshipDB) getAccountMuteIDs(ctx context.Context, accountID string, page *paging.Page) ([]string, error) {
	return loadPagedIDs(&r.state.Caches.DB.UserMuteIDs, accountID, page, func() ([]string, error) {
		// Dialect-specific SQL that is true when the given
		// column holds a timestamp later than "now".
		var notYetExpiredSQL string
		switch r.db.Dialect().Name() {
		case dialect.SQLite:
			// DATETIME rather than DATE: DATE('now') yields only 'YYYY-MM-DD',
			// so mutes that expired earlier today were still treated as active
			// until midnight. Assumes expires_at is stored as UTC text of the
			// form 'YYYY-MM-DD HH:MM:SS...' — confirm against bun's encoding.
			notYetExpiredSQL = "? > DATETIME('now')"
		case dialect.PG:
			notYetExpiredSQL = "? > NOW()"
		default:
			log.Panicf(ctx, "db conn %s was neither pg nor sqlite", r.db)
		}

		var muteIDs []string

		// Mute IDs not in cache. Perform DB query.
		_, err := r.db.
			NewSelect().
			TableExpr("?", bun.Ident("user_mutes")).
			ColumnExpr("?", bun.Ident("id")).
			Where("? = ?", bun.Ident("account_id"), accountID).
			WhereGroup(" AND ", func(q *bun.SelectQuery) *bun.SelectQuery {
				return q.
					Where("? IS NULL", bun.Ident("expires_at")).
					WhereOr(notYetExpiredSQL, bun.Ident("expires_at"))
			}).
			OrderExpr("? DESC", bun.Ident("id")).
			Exec(ctx, &muteIDs)
		if err != nil && !errors.Is(err, db.ErrNoEntries) {
			return nil, err
		}

		return muteIDs, nil
	})
}