diff --git a/example/config.yaml b/example/config.yaml
index cb21e733b..ea33e4ca4 100644
--- a/example/config.yaml
+++ b/example/config.yaml
@@ -230,71 +230,95 @@ db-sqlite-cache-size: "8MiB"
db-sqlite-busy-timeout: "5m"
cache:
+ # Cache configuration options:
+ #
+ # max-size = maximum cached objects count
+ # ttl = cached object lifetime
+ # sweep-freq = frequency to look for stale cache objects
+ # (zero will disable cache sweeping)
+
+ #############################
+ #### VISIBILITY CACHES ######
+ #############################
+ #
+ # Configure Status and account
+ # visibility cache.
+
+ visibility-max-size: 2000
+ visibility-ttl: "30m"
+ visibility-sweep-freq: "1m"
+
gts:
###########################
#### DATABASE CACHES ######
###########################
#
- # Database cache configuration:
- #
- # Allows configuration of caches used
- # when loading GTS models from the database.
- #
- # max-size = maximum cached objects count
- # ttl = cached object lifetime
- # sweep-freq = frequency to look for stale cache objects
+ # Configure GTS database
+ # model caches.
- account-max-size: 500
- account-ttl: "5m"
- account-sweep-freq: "30s"
+ account-max-size: 2000
+ account-ttl: "30m"
+ account-sweep-freq: "1m"
block-max-size: 100
- block-ttl: "5m"
- block-sweep-freq: "30s"
+ block-ttl: "30m"
+ block-sweep-freq: "1m"
- domain-block-max-size: 1000
+ domain-block-max-size: 2000
domain-block-ttl: "24h"
domain-block-sweep-freq: "1m"
- emoji-max-size: 500
- emoji-ttl: "5m"
- emoji-sweep-freq: "30s"
+ emoji-max-size: 2000
+ emoji-ttl: "30m"
+ emoji-sweep-freq: "1m"
emoji-category-max-size: 100
- emoji-category-ttl: "5m"
- emoji-category-sweep-freq: "30s"
+ emoji-category-ttl: "30m"
+ emoji-category-sweep-freq: "1m"
- media-max-size: 500
- media-ttl: "5m"
- media-sweep-freq: "30s"
+ follow-max-size: 2000
+ follow-ttl: "30m"
+ follow-sweep-freq: "1m"
- mention-max-size: 500
- mention-ttl: "5m"
- mention-sweep-freq: "30s"
+ follow-request-max-size: 2000
+ follow-request-ttl: "30m"
+ follow-request-sweep-freq: "1m"
- notification-max-size: 500
- notification-ttl: "5m"
- notification-sweep-freq: "30s"
+ media-max-size: 1000
+ media-ttl: "30m"
+ media-sweep-freq: "1m"
+
+ mention-max-size: 2000
+ mention-ttl: "30m"
+ mention-sweep-freq: "1m"
+
+ notification-max-size: 1000
+ notification-ttl: "30m"
+ notification-sweep-freq: "1m"
report-max-size: 100
- report-ttl: "5m"
- report-sweep-freq: "30s"
+ report-ttl: "30m"
+ report-sweep-freq: "1m"
- status-max-size: 500
- status-ttl: "5m"
- status-sweep-freq: "30s"
+ status-max-size: 2000
+ status-ttl: "30m"
+ status-sweep-freq: "1m"
- tombstone-max-size: 100
- tombstone-ttl: "5m"
- tombstone-sweep-freq: "30s"
+ status-fave-max-size: 2000
+ status-fave-ttl: "30m"
+ status-fave-sweep-freq: "1m"
- user-max-size: 100
- user-ttl: "5m"
- user-sweep-freq: "30s"
+ tombstone-max-size: 500
+ tombstone-ttl: "30m"
+ tombstone-sweep-freq: "1m"
+
+ user-max-size: 500
+ user-ttl: "30m"
+ user-sweep-freq: "1m"
webfinger-max-size": 250
webfinger-ttl: "24h"
- webfinger-sweep-freq": "15m"
+ webfinger-sweep-freq": "1m"
######################
##### WEB CONFIG #####
diff --git a/internal/api/client/statuses/statusboost_test.go b/internal/api/client/statuses/statusboost_test.go
index 2d7be8c72..aea0e20e0 100644
--- a/internal/api/client/statuses/statusboost_test.go
+++ b/internal/api/client/statuses/statusboost_test.go
@@ -168,7 +168,7 @@ func (suite *StatusBoostTestSuite) TestPostBoostOwnFollowersOnly() {
suite.Equal("really cool gts application", responseStatus.Reblog.Application.Name)
}
-// try to boost a status that's not boostable
+// try to boost a status that's not boostable / visible to us
func (suite *StatusBoostTestSuite) TestPostUnboostable() {
t := suite.testTokens["local_account_1"]
oauthToken := oauth.DBTokenToToken(t)
@@ -197,13 +197,13 @@ func (suite *StatusBoostTestSuite) TestPostUnboostable() {
suite.statusModule.StatusBoostPOSTHandler(ctx)
// check response
- suite.Equal(http.StatusForbidden, recorder.Code) // we 403 unboostable statuses
+ suite.Equal(http.StatusNotFound, recorder.Code) // we 404 unboostable statuses
result := recorder.Result()
defer result.Body.Close()
b, err := ioutil.ReadAll(result.Body)
suite.NoError(err)
- suite.Equal(`{"error":"Forbidden"}`, string(b))
+ suite.Equal(`{"error":"Not Found"}`, string(b))
}
// try to boost a status that's not visible to the user
diff --git a/internal/cache/ap.go b/internal/cache/ap.go
index 204752f54..6498d7991 100644
--- a/internal/cache/ap.go
+++ b/internal/cache/ap.go
@@ -17,27 +17,14 @@
package cache
-type APCaches interface {
- // Init will initialize all the ActivityPub caches in this collection.
- // NOTE: the cache MUST NOT be in use anywhere, this is not thread-safe.
- Init()
+type APCaches struct{}
- // Start will attempt to start all of the ActivityPub caches, or panic.
- Start()
+// Init will initialize all the ActivityPub caches in this collection.
+// NOTE: the cache MUST NOT be in use anywhere, this is not thread-safe.
+func (c *APCaches) Init() {}
- // Stop will attempt to stop all of the ActivityPub caches, or panic.
- Stop()
-}
+// Start will attempt to start all of the ActivityPub caches, or panic.
+func (c *APCaches) Start() {}
-// NewAP returns a new default implementation of APCaches.
-func NewAP() APCaches {
- return &apCaches{}
-}
-
-type apCaches struct{}
-
-func (c *apCaches) Init() {}
-
-func (c *apCaches) Start() {}
-
-func (c *apCaches) Stop() {}
+// Stop will attempt to stop all of the ActivityPub caches, or panic.
+func (c *APCaches) Stop() {}
diff --git a/internal/cache/cache.go b/internal/cache/cache.go
index 834542a52..913d6eca7 100644
--- a/internal/cache/cache.go
+++ b/internal/cache/cache.go
@@ -17,13 +17,23 @@
package cache
+import (
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+)
+
type Caches struct {
// GTS provides access to the collection of gtsmodel object caches.
+ // (used by the database).
GTS GTSCaches
// AP provides access to the collection of ActivityPub object caches.
+ // (planned to be used by the typeconverter).
AP APCaches
+ // Visibility provides access to the item visibility cache.
+ // (used by the visibility filter).
+ Visibility VisibilityCache
+
// prevent pass-by-value.
_ nocopy
}
@@ -31,29 +41,77 @@ type Caches struct {
// Init will (re)initialize both the GTS and AP cache collections.
// NOTE: the cache MUST NOT be in use anywhere, this is not thread-safe.
func (c *Caches) Init() {
- if c.GTS == nil {
- // use default impl
- c.GTS = NewGTS()
- }
-
- if c.AP == nil {
- // use default impl
- c.AP = NewAP()
- }
-
- // initialize caches
c.GTS.Init()
c.AP.Init()
+ c.Visibility.Init()
+
+ // Setup cache invalidate hooks.
+ // !! READ THE METHOD COMMENT
+ c.setuphooks()
}
// Start will start both the GTS and AP cache collections.
func (c *Caches) Start() {
c.GTS.Start()
c.AP.Start()
+ c.Visibility.Start()
}
// Stop will stop both the GTS and AP cache collections.
func (c *Caches) Stop() {
c.GTS.Stop()
c.AP.Stop()
+ c.Visibility.Stop()
+}
+
+// setuphooks sets necessary cache invalidation hooks between caches,
+// as an invalidation indicates a database UPDATE / DELETE. INSERT is
+// not handled by invalidation hooks and must be invalidated manually.
+func (c *Caches) setuphooks() {
+ c.GTS.Account().SetInvalidateCallback(func(account *gtsmodel.Account) {
+ // Invalidate account ID cached visibility.
+ c.Visibility.Invalidate("ItemID", account.ID)
+ c.Visibility.Invalidate("RequesterID", account.ID)
+ })
+
+ c.GTS.Block().SetInvalidateCallback(func(block *gtsmodel.Block) {
+ // Invalidate block origin account ID cached visibility.
+ c.Visibility.Invalidate("ItemID", block.AccountID)
+ c.Visibility.Invalidate("RequesterID", block.AccountID)
+
+ // Invalidate block target account ID cached visibility.
+ c.Visibility.Invalidate("ItemID", block.TargetAccountID)
+ c.Visibility.Invalidate("RequesterID", block.TargetAccountID)
+ })
+
+ c.GTS.Follow().SetInvalidateCallback(func(follow *gtsmodel.Follow) {
+ // Invalidate follow origin account ID cached visibility.
+ c.Visibility.Invalidate("ItemID", follow.AccountID)
+ c.Visibility.Invalidate("RequesterID", follow.AccountID)
+
+ // Invalidate follow target account ID cached visibility.
+ c.Visibility.Invalidate("ItemID", follow.TargetAccountID)
+ c.Visibility.Invalidate("RequesterID", follow.TargetAccountID)
+ })
+
+ c.GTS.FollowRequest().SetInvalidateCallback(func(followReq *gtsmodel.FollowRequest) {
+ // Invalidate follow request origin account ID cached visibility.
+ c.Visibility.Invalidate("ItemID", followReq.AccountID)
+ c.Visibility.Invalidate("RequesterID", followReq.AccountID)
+
+ // Invalidate follow request target account ID cached visibility.
+ c.Visibility.Invalidate("ItemID", followReq.TargetAccountID)
+ c.Visibility.Invalidate("RequesterID", followReq.TargetAccountID)
+ })
+
+ c.GTS.Status().SetInvalidateCallback(func(status *gtsmodel.Status) {
+ // Invalidate status ID cached visibility.
+ c.Visibility.Invalidate("ItemID", status.ID)
+ })
+
+ c.GTS.User().SetInvalidateCallback(func(user *gtsmodel.User) {
+ // Invalidate local account ID cached visibility.
+ c.Visibility.Invalidate("ItemID", user.AccountID)
+ c.Visibility.Invalidate("RequesterID", user.AccountID)
+ })
}
diff --git a/internal/cache/gts.go b/internal/cache/gts.go
index 72c3211a8..392fc8449 100644
--- a/internal/cache/gts.go
+++ b/internal/cache/gts.go
@@ -25,240 +25,221 @@
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
-type GTSCaches interface {
- // Init will initialize all the gtsmodel caches in this collection.
- // NOTE: the cache MUST NOT be in use anywhere, this is not thread-safe.
- Init()
-
- // Start will attempt to start all of the gtsmodel caches, or panic.
- Start()
-
- // Stop will attempt to stop all of the gtsmodel caches, or panic.
- Stop()
-
- // Account provides access to the gtsmodel Account database cache.
- Account() *result.Cache[*gtsmodel.Account]
-
- // Block provides access to the gtsmodel Block (account) database cache.
- Block() *result.Cache[*gtsmodel.Block]
-
- // DomainBlock provides access to the domain block database cache.
- DomainBlock() *domain.BlockCache
-
- // Emoji provides access to the gtsmodel Emoji database cache.
- Emoji() *result.Cache[*gtsmodel.Emoji]
-
- // EmojiCategory provides access to the gtsmodel EmojiCategory database cache.
- EmojiCategory() *result.Cache[*gtsmodel.EmojiCategory]
-
- // Mention provides access to the gtsmodel Mention database cache.
- Mention() *result.Cache[*gtsmodel.Mention]
-
- // Media provides access to the gtsmodel Media database cache.
- Media() *result.Cache[*gtsmodel.MediaAttachment]
-
- // Notification provides access to the gtsmodel Notification database cache.
- Notification() *result.Cache[*gtsmodel.Notification]
-
- // Report provides access to the gtsmodel Report database cache.
- Report() *result.Cache[*gtsmodel.Report]
-
- // Status provides access to the gtsmodel Status database cache.
- Status() *result.Cache[*gtsmodel.Status]
-
- // Tombstone provides access to the gtsmodel Tombstone database cache.
- Tombstone() *result.Cache[*gtsmodel.Tombstone]
-
- // User provides access to the gtsmodel User database cache.
- User() *result.Cache[*gtsmodel.User]
-
- // Webfinger
- Webfinger() *ttl.Cache[string, string]
-}
-
-// NewGTS returns a new default implementation of GTSCaches.
-func NewGTS() GTSCaches {
- return >sCaches{}
-}
-
-type gtsCaches struct {
- account *result.Cache[*gtsmodel.Account]
- block *result.Cache[*gtsmodel.Block]
+type GTSCaches struct {
+ account *result.Cache[*gtsmodel.Account]
+ block *result.Cache[*gtsmodel.Block]
+ // TODO: maybe should be moved out of here since it's
+ // not actually doing anything with gtsmodel.DomainBlock.
domainBlock *domain.BlockCache
emoji *result.Cache[*gtsmodel.Emoji]
emojiCategory *result.Cache[*gtsmodel.EmojiCategory]
+ follow *result.Cache[*gtsmodel.Follow]
+ followRequest *result.Cache[*gtsmodel.FollowRequest]
media *result.Cache[*gtsmodel.MediaAttachment]
mention *result.Cache[*gtsmodel.Mention]
notification *result.Cache[*gtsmodel.Notification]
report *result.Cache[*gtsmodel.Report]
status *result.Cache[*gtsmodel.Status]
+ statusFave *result.Cache[*gtsmodel.StatusFave]
tombstone *result.Cache[*gtsmodel.Tombstone]
user *result.Cache[*gtsmodel.User]
- webfinger *ttl.Cache[string, string]
+ // TODO: move out of GTS caches since not using database models.
+ webfinger *ttl.Cache[string, string]
}
-func (c *gtsCaches) Init() {
+// Init will initialize all the gtsmodel caches in this collection.
+// NOTE: the cache MUST NOT be in use anywhere, this is not thread-safe.
+func (c *GTSCaches) Init() {
c.initAccount()
c.initBlock()
c.initDomainBlock()
c.initEmoji()
c.initEmojiCategory()
+ c.initFollow()
+ c.initFollowRequest()
c.initMedia()
c.initMention()
c.initNotification()
c.initReport()
c.initStatus()
+ c.initStatusFave()
c.initTombstone()
c.initUser()
c.initWebfinger()
}
-func (c *gtsCaches) Start() {
- tryUntil("starting gtsmodel.Account cache", 5, func() bool {
- return c.account.Start(config.GetCacheGTSAccountSweepFreq())
+// Start will attempt to start all of the gtsmodel caches, or panic.
+func (c *GTSCaches) Start() {
+ tryStart(c.account, config.GetCacheGTSAccountSweepFreq())
+ tryStart(c.block, config.GetCacheGTSBlockSweepFreq())
+ tryUntil("starting domain block cache", 5, func() bool {
+ if sweep := config.GetCacheGTSDomainBlockSweepFreq(); sweep > 0 {
+ return c.domainBlock.Start(sweep)
+ }
+ return true
})
- tryUntil("starting gtsmodel.Block cache", 5, func() bool {
- return c.block.Start(config.GetCacheGTSBlockSweepFreq())
- })
- tryUntil("starting gtsmodel.DomainBlock cache", 5, func() bool {
- return c.domainBlock.Start(config.GetCacheGTSDomainBlockSweepFreq())
- })
- tryUntil("starting gtsmodel.Emoji cache", 5, func() bool {
- return c.emoji.Start(config.GetCacheGTSEmojiSweepFreq())
- })
- tryUntil("starting gtsmodel.EmojiCategory cache", 5, func() bool {
- return c.emojiCategory.Start(config.GetCacheGTSEmojiCategorySweepFreq())
- })
- tryUntil("starting gtsmodel.MediaAttachment cache", 5, func() bool {
- return c.media.Start(config.GetCacheGTSMediaSweepFreq())
- })
- tryUntil("starting gtsmodel.Mention cache", 5, func() bool {
- return c.mention.Start(config.GetCacheGTSMentionSweepFreq())
- })
- tryUntil("starting gtsmodel.Notification cache", 5, func() bool {
- return c.notification.Start(config.GetCacheGTSNotificationSweepFreq())
- })
- tryUntil("starting gtsmodel.Report cache", 5, func() bool {
- return c.report.Start(config.GetCacheGTSReportSweepFreq())
- })
- tryUntil("starting gtsmodel.Status cache", 5, func() bool {
- return c.status.Start(config.GetCacheGTSStatusSweepFreq())
- })
- tryUntil("starting gtsmodel.Tombstone cache", 5, func() bool {
- return c.tombstone.Start(config.GetCacheGTSTombstoneSweepFreq())
- })
- tryUntil("starting gtsmodel.User cache", 5, func() bool {
- return c.user.Start(config.GetCacheGTSUserSweepFreq())
- })
- tryUntil("starting gtsmodel.Webfinger cache", 5, func() bool {
- return c.webfinger.Start(config.GetCacheGTSWebfingerSweepFreq())
+ tryStart(c.emoji, config.GetCacheGTSEmojiSweepFreq())
+ tryStart(c.emojiCategory, config.GetCacheGTSEmojiCategorySweepFreq())
+ tryStart(c.follow, config.GetCacheGTSFollowSweepFreq())
+ tryStart(c.followRequest, config.GetCacheGTSFollowRequestSweepFreq())
+ tryStart(c.media, config.GetCacheGTSMediaSweepFreq())
+ tryStart(c.mention, config.GetCacheGTSMentionSweepFreq())
+ tryStart(c.notification, config.GetCacheGTSNotificationSweepFreq())
+ tryStart(c.report, config.GetCacheGTSReportSweepFreq())
+ tryStart(c.status, config.GetCacheGTSStatusSweepFreq())
+ tryStart(c.statusFave, config.GetCacheGTSStatusFaveSweepFreq())
+ tryStart(c.tombstone, config.GetCacheGTSTombstoneSweepFreq())
+ tryStart(c.user, config.GetCacheGTSUserSweepFreq())
+ tryUntil("starting *gtsmodel.Webfinger cache", 5, func() bool {
+ if sweep := config.GetCacheGTSWebfingerSweepFreq(); sweep > 0 {
+ return c.webfinger.Start(sweep)
+ }
+ return true
})
}
-func (c *gtsCaches) Stop() {
- tryUntil("stopping gtsmodel.Account cache", 5, c.account.Stop)
- tryUntil("stopping gtsmodel.Block cache", 5, c.block.Stop)
- tryUntil("stopping gtsmodel.DomainBlock cache", 5, c.domainBlock.Stop)
- tryUntil("stopping gtsmodel.Emoji cache", 5, c.emoji.Stop)
- tryUntil("stopping gtsmodel.EmojiCategory cache", 5, c.emojiCategory.Stop)
- tryUntil("stopping gtsmodel.MediaAttachment cache", 5, c.media.Stop)
- tryUntil("stopping gtsmodel.Mention cache", 5, c.mention.Stop)
- tryUntil("stopping gtsmodel.Notification cache", 5, c.notification.Stop)
- tryUntil("stopping gtsmodel.Report cache", 5, c.report.Stop)
- tryUntil("stopping gtsmodel.Status cache", 5, c.status.Stop)
- tryUntil("stopping gtsmodel.Tombstone cache", 5, c.tombstone.Stop)
- tryUntil("stopping gtsmodel.User cache", 5, c.user.Stop)
- tryUntil("stopping gtsmodel.Webfinger cache", 5, c.webfinger.Stop)
+// Stop will attempt to stop all of the gtsmodel caches, or panic.
+func (c *GTSCaches) Stop() {
+ tryStop(c.account, config.GetCacheGTSAccountSweepFreq())
+ tryStop(c.block, config.GetCacheGTSBlockSweepFreq())
+ tryUntil("stopping domain block cache", 5, c.domainBlock.Stop)
+ tryStop(c.emoji, config.GetCacheGTSEmojiSweepFreq())
+ tryStop(c.emojiCategory, config.GetCacheGTSEmojiCategorySweepFreq())
+ tryStop(c.follow, config.GetCacheGTSFollowSweepFreq())
+ tryStop(c.followRequest, config.GetCacheGTSFollowRequestSweepFreq())
+ tryStop(c.media, config.GetCacheGTSMediaSweepFreq())
+ tryStop(c.mention, config.GetCacheGTSNotificationSweepFreq())
+ tryStop(c.notification, config.GetCacheGTSNotificationSweepFreq())
+ tryStop(c.report, config.GetCacheGTSReportSweepFreq())
+ tryStop(c.status, config.GetCacheGTSStatusSweepFreq())
+ tryStop(c.statusFave, config.GetCacheGTSStatusFaveSweepFreq())
+ tryStop(c.tombstone, config.GetCacheGTSTombstoneSweepFreq())
+ tryStop(c.user, config.GetCacheGTSUserSweepFreq())
+ tryUntil("stopping *gtsmodel.Webfinger cache", 5, c.webfinger.Stop)
}
-func (c *gtsCaches) Account() *result.Cache[*gtsmodel.Account] {
+// Account provides access to the gtsmodel Account database cache.
+func (c *GTSCaches) Account() *result.Cache[*gtsmodel.Account] {
return c.account
}
-func (c *gtsCaches) Block() *result.Cache[*gtsmodel.Block] {
+// Block provides access to the gtsmodel Block (account) database cache.
+func (c *GTSCaches) Block() *result.Cache[*gtsmodel.Block] {
return c.block
}
-func (c *gtsCaches) DomainBlock() *domain.BlockCache {
+// DomainBlock provides access to the domain block database cache.
+func (c *GTSCaches) DomainBlock() *domain.BlockCache {
return c.domainBlock
}
-func (c *gtsCaches) Emoji() *result.Cache[*gtsmodel.Emoji] {
+// Emoji provides access to the gtsmodel Emoji database cache.
+func (c *GTSCaches) Emoji() *result.Cache[*gtsmodel.Emoji] {
return c.emoji
}
-func (c *gtsCaches) EmojiCategory() *result.Cache[*gtsmodel.EmojiCategory] {
+// EmojiCategory provides access to the gtsmodel EmojiCategory database cache.
+func (c *GTSCaches) EmojiCategory() *result.Cache[*gtsmodel.EmojiCategory] {
return c.emojiCategory
}
-func (c *gtsCaches) Media() *result.Cache[*gtsmodel.MediaAttachment] {
+// Follow provides access to the gtsmodel Follow database cache.
+func (c *GTSCaches) Follow() *result.Cache[*gtsmodel.Follow] {
+ return c.follow
+}
+
+// FollowRequest provides access to the gtsmodel FollowRequest database cache.
+func (c *GTSCaches) FollowRequest() *result.Cache[*gtsmodel.FollowRequest] {
+ return c.followRequest
+}
+
+// Media provides access to the gtsmodel Media database cache.
+func (c *GTSCaches) Media() *result.Cache[*gtsmodel.MediaAttachment] {
return c.media
}
-func (c *gtsCaches) Mention() *result.Cache[*gtsmodel.Mention] {
+// Mention provides access to the gtsmodel Mention database cache.
+func (c *GTSCaches) Mention() *result.Cache[*gtsmodel.Mention] {
return c.mention
}
-func (c *gtsCaches) Notification() *result.Cache[*gtsmodel.Notification] {
+// Notification provides access to the gtsmodel Notification database cache.
+func (c *GTSCaches) Notification() *result.Cache[*gtsmodel.Notification] {
return c.notification
}
-func (c *gtsCaches) Report() *result.Cache[*gtsmodel.Report] {
+// Report provides access to the gtsmodel Report database cache.
+func (c *GTSCaches) Report() *result.Cache[*gtsmodel.Report] {
return c.report
}
-func (c *gtsCaches) Status() *result.Cache[*gtsmodel.Status] {
+// Status provides access to the gtsmodel Status database cache.
+func (c *GTSCaches) Status() *result.Cache[*gtsmodel.Status] {
return c.status
}
-func (c *gtsCaches) Tombstone() *result.Cache[*gtsmodel.Tombstone] {
+// StatusFave provides access to the gtsmodel StatusFave database cache.
+func (c *GTSCaches) StatusFave() *result.Cache[*gtsmodel.StatusFave] {
+ return c.statusFave
+}
+
+// Tombstone provides access to the gtsmodel Tombstone database cache.
+func (c *GTSCaches) Tombstone() *result.Cache[*gtsmodel.Tombstone] {
return c.tombstone
}
-func (c *gtsCaches) User() *result.Cache[*gtsmodel.User] {
+// User provides access to the gtsmodel User database cache.
+func (c *GTSCaches) User() *result.Cache[*gtsmodel.User] {
return c.user
}
-func (c *gtsCaches) Webfinger() *ttl.Cache[string, string] {
+// Webfinger provides access to the webfinger URL cache.
+func (c *GTSCaches) Webfinger() *ttl.Cache[string, string] {
return c.webfinger
}
-func (c *gtsCaches) initAccount() {
+func (c *GTSCaches) initAccount() {
c.account = result.New([]result.Lookup{
{Name: "ID"},
{Name: "URI"},
{Name: "URL"},
{Name: "Username.Domain"},
{Name: "PublicKeyURI"},
+ {Name: "InboxURI"},
+ {Name: "OutboxURI"},
+ {Name: "FollowersURI"},
+ {Name: "FollowingURI"},
}, func(a1 *gtsmodel.Account) *gtsmodel.Account {
a2 := new(gtsmodel.Account)
*a2 = *a1
return a2
}, config.GetCacheGTSAccountMaxSize())
c.account.SetTTL(config.GetCacheGTSAccountTTL(), true)
+ c.account.IgnoreErrors(ignoreErrors)
}
-func (c *gtsCaches) initBlock() {
+func (c *GTSCaches) initBlock() {
c.block = result.New([]result.Lookup{
{Name: "ID"},
- {Name: "AccountID.TargetAccountID"},
{Name: "URI"},
+ {Name: "AccountID.TargetAccountID"},
}, func(b1 *gtsmodel.Block) *gtsmodel.Block {
b2 := new(gtsmodel.Block)
*b2 = *b1
return b2
}, config.GetCacheGTSBlockMaxSize())
c.block.SetTTL(config.GetCacheGTSBlockTTL(), true)
+ c.block.IgnoreErrors(ignoreErrors)
}
-func (c *gtsCaches) initDomainBlock() {
+func (c *GTSCaches) initDomainBlock() {
c.domainBlock = domain.New(
config.GetCacheGTSDomainBlockMaxSize(),
config.GetCacheGTSDomainBlockTTL(),
)
}
-func (c *gtsCaches) initEmoji() {
+func (c *GTSCaches) initEmoji() {
c.emoji = result.New([]result.Lookup{
{Name: "ID"},
{Name: "URI"},
@@ -270,9 +251,10 @@ func (c *gtsCaches) initEmoji() {
return e2
}, config.GetCacheGTSEmojiMaxSize())
c.emoji.SetTTL(config.GetCacheGTSEmojiTTL(), true)
+ c.emoji.IgnoreErrors(ignoreErrors)
}
-func (c *gtsCaches) initEmojiCategory() {
+func (c *GTSCaches) initEmojiCategory() {
c.emojiCategory = result.New([]result.Lookup{
{Name: "ID"},
{Name: "Name"},
@@ -282,9 +264,36 @@ func (c *gtsCaches) initEmojiCategory() {
return c2
}, config.GetCacheGTSEmojiCategoryMaxSize())
c.emojiCategory.SetTTL(config.GetCacheGTSEmojiCategoryTTL(), true)
+ c.emojiCategory.IgnoreErrors(ignoreErrors)
}
-func (c *gtsCaches) initMedia() {
+func (c *GTSCaches) initFollow() {
+ c.follow = result.New([]result.Lookup{
+ {Name: "ID"},
+ {Name: "URI"},
+ {Name: "AccountID.TargetAccountID"},
+ }, func(f1 *gtsmodel.Follow) *gtsmodel.Follow {
+ f2 := new(gtsmodel.Follow)
+ *f2 = *f1
+ return f2
+ }, config.GetCacheGTSFollowMaxSize())
+ c.follow.SetTTL(config.GetCacheGTSFollowTTL(), true)
+}
+
+func (c *GTSCaches) initFollowRequest() {
+ c.followRequest = result.New([]result.Lookup{
+ {Name: "ID"},
+ {Name: "URI"},
+ {Name: "AccountID.TargetAccountID"},
+ }, func(f1 *gtsmodel.FollowRequest) *gtsmodel.FollowRequest {
+ f2 := new(gtsmodel.FollowRequest)
+ *f2 = *f1
+ return f2
+ }, config.GetCacheGTSFollowRequestMaxSize())
+ c.followRequest.SetTTL(config.GetCacheGTSFollowRequestTTL(), true)
+}
+
+func (c *GTSCaches) initMedia() {
c.media = result.New([]result.Lookup{
{Name: "ID"},
}, func(m1 *gtsmodel.MediaAttachment) *gtsmodel.MediaAttachment {
@@ -293,9 +302,10 @@ func (c *gtsCaches) initMedia() {
return m2
}, config.GetCacheGTSMediaMaxSize())
c.media.SetTTL(config.GetCacheGTSMediaTTL(), true)
+ c.media.IgnoreErrors(ignoreErrors)
}
-func (c *gtsCaches) initMention() {
+func (c *GTSCaches) initMention() {
c.mention = result.New([]result.Lookup{
{Name: "ID"},
}, func(m1 *gtsmodel.Mention) *gtsmodel.Mention {
@@ -304,9 +314,10 @@ func (c *gtsCaches) initMention() {
return m2
}, config.GetCacheGTSMentionMaxSize())
c.mention.SetTTL(config.GetCacheGTSMentionTTL(), true)
+ c.mention.IgnoreErrors(ignoreErrors)
}
-func (c *gtsCaches) initNotification() {
+func (c *GTSCaches) initNotification() {
c.notification = result.New([]result.Lookup{
{Name: "ID"},
}, func(n1 *gtsmodel.Notification) *gtsmodel.Notification {
@@ -315,9 +326,10 @@ func (c *gtsCaches) initNotification() {
return n2
}, config.GetCacheGTSNotificationMaxSize())
c.notification.SetTTL(config.GetCacheGTSNotificationTTL(), true)
+ c.notification.IgnoreErrors(ignoreErrors)
}
-func (c *gtsCaches) initReport() {
+func (c *GTSCaches) initReport() {
c.report = result.New([]result.Lookup{
{Name: "ID"},
}, func(r1 *gtsmodel.Report) *gtsmodel.Report {
@@ -326,9 +338,10 @@ func (c *gtsCaches) initReport() {
return r2
}, config.GetCacheGTSReportMaxSize())
c.report.SetTTL(config.GetCacheGTSReportTTL(), true)
+ c.report.IgnoreErrors(ignoreErrors)
}
-func (c *gtsCaches) initStatus() {
+func (c *GTSCaches) initStatus() {
c.status = result.New([]result.Lookup{
{Name: "ID"},
{Name: "URI"},
@@ -339,10 +352,24 @@ func (c *gtsCaches) initStatus() {
return s2
}, config.GetCacheGTSStatusMaxSize())
c.status.SetTTL(config.GetCacheGTSStatusTTL(), true)
+ c.status.IgnoreErrors(ignoreErrors)
+}
+
+func (c *GTSCaches) initStatusFave() {
+ c.statusFave = result.New([]result.Lookup{
+ {Name: "ID"},
+ {Name: "AccountID.StatusID"},
+ }, func(f1 *gtsmodel.StatusFave) *gtsmodel.StatusFave {
+ f2 := new(gtsmodel.StatusFave)
+ *f2 = *f1
+ return f2
+ }, config.GetCacheGTSStatusFaveMaxSize())
+ c.status.SetTTL(config.GetCacheGTSStatusFaveTTL(), true)
+ c.status.IgnoreErrors(ignoreErrors)
}
// initTombstone will initialize the gtsmodel.Tombstone cache.
-func (c *gtsCaches) initTombstone() {
+func (c *GTSCaches) initTombstone() {
c.tombstone = result.New([]result.Lookup{
{Name: "ID"},
{Name: "URI"},
@@ -352,9 +379,10 @@ func (c *gtsCaches) initTombstone() {
return t2
}, config.GetCacheGTSTombstoneMaxSize())
c.tombstone.SetTTL(config.GetCacheGTSTombstoneTTL(), true)
+ c.tombstone.IgnoreErrors(ignoreErrors)
}
-func (c *gtsCaches) initUser() {
+func (c *GTSCaches) initUser() {
c.user = result.New([]result.Lookup{
{Name: "ID"},
{Name: "AccountID"},
@@ -367,9 +395,10 @@ func (c *gtsCaches) initUser() {
return u2
}, config.GetCacheGTSUserMaxSize())
c.user.SetTTL(config.GetCacheGTSUserTTL(), true)
+ c.user.IgnoreErrors(ignoreErrors)
}
-func (c *gtsCaches) initWebfinger() {
+func (c *GTSCaches) initWebfinger() {
c.webfinger = ttl.New[string, string](
0,
config.GetCacheGTSWebfingerMaxSize(),
diff --git a/internal/cache/util.go b/internal/cache/util.go
index 066e477e9..1ffd72876 100644
--- a/internal/cache/util.go
+++ b/internal/cache/util.go
@@ -17,7 +17,30 @@
package cache
-import "github.com/superseriousbusiness/gotosocial/internal/log"
+import (
+ "context"
+ "errors"
+ "fmt"
+ "time"
+
+ "codeberg.org/gruf/go-cache/v3/result"
+ errorsv2 "codeberg.org/gruf/go-errors/v2"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
+)
+
+// SentinelError is returned to indicate a non-permanent error return,
+// i.e. a situation in which we do not want a cache a negative result.
+var SentinelError = errors.New("BUG: error should not be returned") //nolint:revive
+
+// ignoreErrors is an error ignoring function capable of being passed to
+// caches, which specifically catches and ignores our sentinel error type.
+func ignoreErrors(err error) bool {
+ return errorsv2.Is(
+ SentinelError,
+ context.DeadlineExceeded,
+ context.Canceled,
+ )
+}
// nocopy when embedded will signal linter to
// error on pass-by-value of parent struct.
@@ -27,6 +50,26 @@ func (*nocopy) Lock() {}
func (*nocopy) Unlock() {}
+// tryStart will attempt to start the given cache only if sweep duration > 0 (sweeping is enabled).
+func tryStart[ValueType any](cache *result.Cache[ValueType], sweep time.Duration) {
+ if sweep > 0 {
+ var z ValueType
+ msg := fmt.Sprintf("starting %T cache", z)
+ tryUntil(msg, 5, func() bool {
+ return cache.Start(sweep)
+ })
+ }
+}
+
+// tryStop will attempt to stop the given cache only if sweep duration > 0 (sweeping is enabled).
+func tryStop[ValueType any](cache *result.Cache[ValueType], sweep time.Duration) {
+ if sweep > 0 {
+ var z ValueType
+ msg := fmt.Sprintf("stopping %T cache", z)
+ tryUntil(msg, 5, cache.Stop)
+ }
+}
+
// tryUntil will attempt to call 'do' for 'count' attempts, before panicking with 'msg'.
func tryUntil(msg string, count int, do func() bool) {
for i := 0; i < count; i++ {
diff --git a/internal/cache/visibility.go b/internal/cache/visibility.go
new file mode 100644
index 000000000..8706a8015
--- /dev/null
+++ b/internal/cache/visibility.go
@@ -0,0 +1,81 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package cache
+
+import (
+ "codeberg.org/gruf/go-cache/v3/result"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+)
+
+type VisibilityCache struct {
+ *result.Cache[*CachedVisibility]
+}
+
+// Init will initialize the visibility cache in this collection.
+// NOTE: the cache MUST NOT be in use anywhere, this is not thread-safe.
+func (c *VisibilityCache) Init() {
+ c.Cache = result.New([]result.Lookup{
+ {Name: "ItemID"},
+ {Name: "RequesterID"},
+ {Name: "Type.RequesterID.ItemID"},
+ }, func(v1 *CachedVisibility) *CachedVisibility {
+ v2 := new(CachedVisibility)
+ *v2 = *v1
+ return v2
+ }, config.GetCacheVisibilityMaxSize())
+ c.Cache.SetTTL(config.GetCacheVisibilityTTL(), true)
+ c.Cache.IgnoreErrors(ignoreErrors)
+}
+
+// Start will attempt to start the visibility cache, or panic.
+func (c *VisibilityCache) Start() {
+ tryStart(c.Cache, config.GetCacheVisibilitySweepFreq())
+}
+
+// Stop will attempt to stop the visibility cache, or panic.
+func (c *VisibilityCache) Stop() {
+ tryStop(c.Cache, config.GetCacheVisibilitySweepFreq())
+}
+
+// VisibilityType represents a visibility lookup type.
+// We use a byte type here to improve performance in the
+// result cache when generating the key.
+type VisibilityType byte
+
+const (
+ // Possible cache visibility lookup types.
+ VisibilityTypeAccount = VisibilityType('a')
+ VisibilityTypeStatus = VisibilityType('s')
+ VisibilityTypeHome = VisibilityType('h')
+ VisibilityTypePublic = VisibilityType('p')
+)
+
+// CachedVisibility represents a cached visibility lookup value.
+type CachedVisibility struct {
+ // ItemID is the ID of the item in question (status / account).
+ ItemID string
+
+ // RequesterID is the ID of the requesting account for this visibility lookup.
+ RequesterID string
+
+ // Type is the visibility lookup type.
+ Type VisibilityType
+
+ // Value is the actual visibility value.
+ Value bool
+}
diff --git a/internal/config/config.go b/internal/config/config.go
index a1e00ea8d..ab353f32a 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -157,6 +157,10 @@ type Configuration struct {
type CacheConfiguration struct {
GTS GTSCacheConfiguration `name:"gts"`
+
+ VisibilityMaxSize int `name:"visibility-max-size"`
+ VisibilityTTL time.Duration `name:"visibility-ttl"`
+ VisibilitySweepFreq time.Duration `name:"visibility-sweep-freq"`
}
type GTSCacheConfiguration struct {
@@ -180,6 +184,14 @@ type GTSCacheConfiguration struct {
EmojiCategoryTTL time.Duration `name:"emoji-category-ttl"`
EmojiCategorySweepFreq time.Duration `name:"emoji-category-sweep-freq"`
+ FollowMaxSize int `name:"follow-max-size"`
+ FollowTTL time.Duration `name:"follow-ttl"`
+ FollowSweepFreq time.Duration `name:"follow-sweep-freq"`
+
+ FollowRequestMaxSize int `name:"follow-request-max-size"`
+ FollowRequestTTL time.Duration `name:"follow-request-ttl"`
+ FollowRequestSweepFreq time.Duration `name:"follow-request-sweep-freq"`
+
MediaMaxSize int `name:"media-max-size"`
MediaTTL time.Duration `name:"media-ttl"`
MediaSweepFreq time.Duration `name:"media-sweep-freq"`
@@ -200,6 +212,10 @@ type GTSCacheConfiguration struct {
StatusTTL time.Duration `name:"status-ttl"`
StatusSweepFreq time.Duration `name:"status-sweep-freq"`
+ StatusFaveMaxSize int `name:"status-fave-max-size"`
+ StatusFaveTTL time.Duration `name:"status-fave-ttl"`
+ StatusFaveSweepFreq time.Duration `name:"status-fave-sweep-freq"`
+
TombstoneMaxSize int `name:"tombstone-max-size"`
TombstoneTTL time.Duration `name:"tombstone-ttl"`
TombstoneSweepFreq time.Duration `name:"tombstone-sweep-freq"`
diff --git a/internal/config/defaults.go b/internal/config/defaults.go
index 17cc71086..999b81c65 100644
--- a/internal/config/defaults.go
+++ b/internal/config/defaults.go
@@ -51,7 +51,7 @@
DbSqliteJournalMode: "WAL",
DbSqliteSynchronous: "NORMAL",
DbSqliteCacheSize: 8 * bytesize.MiB,
- DbSqliteBusyTimeout: time.Minute * 5,
+ DbSqliteBusyTimeout: time.Minute * 30,
WebTemplateBaseDir: "./web/template/",
WebAssetBaseDir: "./web/assets/",
@@ -119,58 +119,74 @@
Cache: CacheConfiguration{
GTS: GTSCacheConfiguration{
- AccountMaxSize: 500,
- AccountTTL: time.Minute * 5,
- AccountSweepFreq: time.Second * 30,
+ AccountMaxSize: 2000,
+ AccountTTL: time.Minute * 30,
+ AccountSweepFreq: time.Minute,
- BlockMaxSize: 100,
- BlockTTL: time.Minute * 5,
- BlockSweepFreq: time.Second * 30,
+ BlockMaxSize: 1000,
+ BlockTTL: time.Minute * 30,
+ BlockSweepFreq: time.Minute,
- DomainBlockMaxSize: 1000,
+ DomainBlockMaxSize: 2000,
DomainBlockTTL: time.Hour * 24,
DomainBlockSweepFreq: time.Minute,
- EmojiMaxSize: 500,
- EmojiTTL: time.Minute * 5,
- EmojiSweepFreq: time.Second * 30,
+ EmojiMaxSize: 2000,
+ EmojiTTL: time.Minute * 30,
+ EmojiSweepFreq: time.Minute,
EmojiCategoryMaxSize: 100,
- EmojiCategoryTTL: time.Minute * 5,
- EmojiCategorySweepFreq: time.Second * 30,
+ EmojiCategoryTTL: time.Minute * 30,
+ EmojiCategorySweepFreq: time.Minute,
- MediaMaxSize: 500,
- MediaTTL: time.Minute * 5,
- MediaSweepFreq: time.Second * 30,
+ FollowMaxSize: 2000,
+ FollowTTL: time.Minute * 30,
+ FollowSweepFreq: time.Minute,
- MentionMaxSize: 500,
- MentionTTL: time.Minute * 5,
- MentionSweepFreq: time.Second * 30,
+ FollowRequestMaxSize: 2000,
+ FollowRequestTTL: time.Minute * 30,
+ FollowRequestSweepFreq: time.Minute,
- NotificationMaxSize: 500,
- NotificationTTL: time.Minute * 5,
- NotificationSweepFreq: time.Second * 30,
+ MediaMaxSize: 1000,
+ MediaTTL: time.Minute * 30,
+ MediaSweepFreq: time.Minute,
+
+ MentionMaxSize: 2000,
+ MentionTTL: time.Minute * 30,
+ MentionSweepFreq: time.Minute,
+
+ NotificationMaxSize: 1000,
+ NotificationTTL: time.Minute * 30,
+ NotificationSweepFreq: time.Minute,
ReportMaxSize: 100,
- ReportTTL: time.Minute * 5,
- ReportSweepFreq: time.Second * 30,
+ ReportTTL: time.Minute * 30,
+ ReportSweepFreq: time.Minute,
- StatusMaxSize: 500,
- StatusTTL: time.Minute * 5,
- StatusSweepFreq: time.Second * 30,
+ StatusMaxSize: 2000,
+ StatusTTL: time.Minute * 30,
+ StatusSweepFreq: time.Minute,
- TombstoneMaxSize: 100,
- TombstoneTTL: time.Minute * 5,
- TombstoneSweepFreq: time.Second * 30,
+ StatusFaveMaxSize: 2000,
+ StatusFaveTTL: time.Minute * 30,
+ StatusFaveSweepFreq: time.Minute,
- UserMaxSize: 100,
- UserTTL: time.Minute * 5,
- UserSweepFreq: time.Second * 30,
+ TombstoneMaxSize: 500,
+ TombstoneTTL: time.Minute * 30,
+ TombstoneSweepFreq: time.Minute,
+
+ UserMaxSize: 500,
+ UserTTL: time.Minute * 30,
+ UserSweepFreq: time.Minute,
WebfingerMaxSize: 250,
WebfingerTTL: time.Hour * 24,
WebfingerSweepFreq: time.Minute * 15,
},
+
+ VisibilityMaxSize: 2000,
+ VisibilityTTL: time.Minute * 30,
+ VisibilitySweepFreq: time.Minute,
},
AdminMediaPruneDryRun: true,
diff --git a/internal/config/helpers.gen.go b/internal/config/helpers.gen.go
index 6fc195ad0..e35eb0665 100644
--- a/internal/config/helpers.gen.go
+++ b/internal/config/helpers.gen.go
@@ -2501,6 +2501,158 @@ func GetCacheGTSEmojiCategorySweepFreq() time.Duration {
// SetCacheGTSEmojiCategorySweepFreq safely sets the value for global configuration 'Cache.GTS.EmojiCategorySweepFreq' field
func SetCacheGTSEmojiCategorySweepFreq(v time.Duration) { global.SetCacheGTSEmojiCategorySweepFreq(v) }
+// GetCacheGTSFollowMaxSize safely fetches the Configuration value for state's 'Cache.GTS.FollowMaxSize' field
+func (st *ConfigState) GetCacheGTSFollowMaxSize() (v int) {
+ st.mutex.Lock()
+ v = st.config.Cache.GTS.FollowMaxSize
+ st.mutex.Unlock()
+ return
+}
+
+// SetCacheGTSFollowMaxSize safely sets the Configuration value for state's 'Cache.GTS.FollowMaxSize' field
+func (st *ConfigState) SetCacheGTSFollowMaxSize(v int) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.GTS.FollowMaxSize = v
+ st.reloadToViper()
+}
+
+// CacheGTSFollowMaxSizeFlag returns the flag name for the 'Cache.GTS.FollowMaxSize' field
+func CacheGTSFollowMaxSizeFlag() string { return "cache-gts-follow-max-size" }
+
+// GetCacheGTSFollowMaxSize safely fetches the value for global configuration 'Cache.GTS.FollowMaxSize' field
+func GetCacheGTSFollowMaxSize() int { return global.GetCacheGTSFollowMaxSize() }
+
+// SetCacheGTSFollowMaxSize safely sets the value for global configuration 'Cache.GTS.FollowMaxSize' field
+func SetCacheGTSFollowMaxSize(v int) { global.SetCacheGTSFollowMaxSize(v) }
+
+// GetCacheGTSFollowTTL safely fetches the Configuration value for state's 'Cache.GTS.FollowTTL' field
+func (st *ConfigState) GetCacheGTSFollowTTL() (v time.Duration) {
+ st.mutex.Lock()
+ v = st.config.Cache.GTS.FollowTTL
+ st.mutex.Unlock()
+ return
+}
+
+// SetCacheGTSFollowTTL safely sets the Configuration value for state's 'Cache.GTS.FollowTTL' field
+func (st *ConfigState) SetCacheGTSFollowTTL(v time.Duration) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.GTS.FollowTTL = v
+ st.reloadToViper()
+}
+
+// CacheGTSFollowTTLFlag returns the flag name for the 'Cache.GTS.FollowTTL' field
+func CacheGTSFollowTTLFlag() string { return "cache-gts-follow-ttl" }
+
+// GetCacheGTSFollowTTL safely fetches the value for global configuration 'Cache.GTS.FollowTTL' field
+func GetCacheGTSFollowTTL() time.Duration { return global.GetCacheGTSFollowTTL() }
+
+// SetCacheGTSFollowTTL safely sets the value for global configuration 'Cache.GTS.FollowTTL' field
+func SetCacheGTSFollowTTL(v time.Duration) { global.SetCacheGTSFollowTTL(v) }
+
+// GetCacheGTSFollowSweepFreq safely fetches the Configuration value for state's 'Cache.GTS.FollowSweepFreq' field
+func (st *ConfigState) GetCacheGTSFollowSweepFreq() (v time.Duration) {
+ st.mutex.Lock()
+ v = st.config.Cache.GTS.FollowSweepFreq
+ st.mutex.Unlock()
+ return
+}
+
+// SetCacheGTSFollowSweepFreq safely sets the Configuration value for state's 'Cache.GTS.FollowSweepFreq' field
+func (st *ConfigState) SetCacheGTSFollowSweepFreq(v time.Duration) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.GTS.FollowSweepFreq = v
+ st.reloadToViper()
+}
+
+// CacheGTSFollowSweepFreqFlag returns the flag name for the 'Cache.GTS.FollowSweepFreq' field
+func CacheGTSFollowSweepFreqFlag() string { return "cache-gts-follow-sweep-freq" }
+
+// GetCacheGTSFollowSweepFreq safely fetches the value for global configuration 'Cache.GTS.FollowSweepFreq' field
+func GetCacheGTSFollowSweepFreq() time.Duration { return global.GetCacheGTSFollowSweepFreq() }
+
+// SetCacheGTSFollowSweepFreq safely sets the value for global configuration 'Cache.GTS.FollowSweepFreq' field
+func SetCacheGTSFollowSweepFreq(v time.Duration) { global.SetCacheGTSFollowSweepFreq(v) }
+
+// GetCacheGTSFollowRequestMaxSize safely fetches the Configuration value for state's 'Cache.GTS.FollowRequestMaxSize' field
+func (st *ConfigState) GetCacheGTSFollowRequestMaxSize() (v int) {
+ st.mutex.Lock()
+ v = st.config.Cache.GTS.FollowRequestMaxSize
+ st.mutex.Unlock()
+ return
+}
+
+// SetCacheGTSFollowRequestMaxSize safely sets the Configuration value for state's 'Cache.GTS.FollowRequestMaxSize' field
+func (st *ConfigState) SetCacheGTSFollowRequestMaxSize(v int) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.GTS.FollowRequestMaxSize = v
+ st.reloadToViper()
+}
+
+// CacheGTSFollowRequestMaxSizeFlag returns the flag name for the 'Cache.GTS.FollowRequestMaxSize' field
+func CacheGTSFollowRequestMaxSizeFlag() string { return "cache-gts-follow-request-max-size" }
+
+// GetCacheGTSFollowRequestMaxSize safely fetches the value for global configuration 'Cache.GTS.FollowRequestMaxSize' field
+func GetCacheGTSFollowRequestMaxSize() int { return global.GetCacheGTSFollowRequestMaxSize() }
+
+// SetCacheGTSFollowRequestMaxSize safely sets the value for global configuration 'Cache.GTS.FollowRequestMaxSize' field
+func SetCacheGTSFollowRequestMaxSize(v int) { global.SetCacheGTSFollowRequestMaxSize(v) }
+
+// GetCacheGTSFollowRequestTTL safely fetches the Configuration value for state's 'Cache.GTS.FollowRequestTTL' field
+func (st *ConfigState) GetCacheGTSFollowRequestTTL() (v time.Duration) {
+ st.mutex.Lock()
+ v = st.config.Cache.GTS.FollowRequestTTL
+ st.mutex.Unlock()
+ return
+}
+
+// SetCacheGTSFollowRequestTTL safely sets the Configuration value for state's 'Cache.GTS.FollowRequestTTL' field
+func (st *ConfigState) SetCacheGTSFollowRequestTTL(v time.Duration) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.GTS.FollowRequestTTL = v
+ st.reloadToViper()
+}
+
+// CacheGTSFollowRequestTTLFlag returns the flag name for the 'Cache.GTS.FollowRequestTTL' field
+func CacheGTSFollowRequestTTLFlag() string { return "cache-gts-follow-request-ttl" }
+
+// GetCacheGTSFollowRequestTTL safely fetches the value for global configuration 'Cache.GTS.FollowRequestTTL' field
+func GetCacheGTSFollowRequestTTL() time.Duration { return global.GetCacheGTSFollowRequestTTL() }
+
+// SetCacheGTSFollowRequestTTL safely sets the value for global configuration 'Cache.GTS.FollowRequestTTL' field
+func SetCacheGTSFollowRequestTTL(v time.Duration) { global.SetCacheGTSFollowRequestTTL(v) }
+
+// GetCacheGTSFollowRequestSweepFreq safely fetches the Configuration value for state's 'Cache.GTS.FollowRequestSweepFreq' field
+func (st *ConfigState) GetCacheGTSFollowRequestSweepFreq() (v time.Duration) {
+ st.mutex.Lock()
+ v = st.config.Cache.GTS.FollowRequestSweepFreq
+ st.mutex.Unlock()
+ return
+}
+
+// SetCacheGTSFollowRequestSweepFreq safely sets the Configuration value for state's 'Cache.GTS.FollowRequestSweepFreq' field
+func (st *ConfigState) SetCacheGTSFollowRequestSweepFreq(v time.Duration) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.GTS.FollowRequestSweepFreq = v
+ st.reloadToViper()
+}
+
+// CacheGTSFollowRequestSweepFreqFlag returns the flag name for the 'Cache.GTS.FollowRequestSweepFreq' field
+func CacheGTSFollowRequestSweepFreqFlag() string { return "cache-gts-follow-request-sweep-freq" }
+
+// GetCacheGTSFollowRequestSweepFreq safely fetches the value for global configuration 'Cache.GTS.FollowRequestSweepFreq' field
+func GetCacheGTSFollowRequestSweepFreq() time.Duration {
+ return global.GetCacheGTSFollowRequestSweepFreq()
+}
+
+// SetCacheGTSFollowRequestSweepFreq safely sets the value for global configuration 'Cache.GTS.FollowRequestSweepFreq' field
+func SetCacheGTSFollowRequestSweepFreq(v time.Duration) { global.SetCacheGTSFollowRequestSweepFreq(v) }
+
// GetCacheGTSMediaMaxSize safely fetches the Configuration value for state's 'Cache.GTS.MediaMaxSize' field
func (st *ConfigState) GetCacheGTSMediaMaxSize() (v int) {
st.mutex.Lock()
@@ -2878,6 +3030,81 @@ func GetCacheGTSStatusSweepFreq() time.Duration { return global.GetCacheGTSStatu
// SetCacheGTSStatusSweepFreq safely sets the value for global configuration 'Cache.GTS.StatusSweepFreq' field
func SetCacheGTSStatusSweepFreq(v time.Duration) { global.SetCacheGTSStatusSweepFreq(v) }
+// GetCacheGTSStatusFaveMaxSize safely fetches the Configuration value for state's 'Cache.GTS.StatusFaveMaxSize' field
+func (st *ConfigState) GetCacheGTSStatusFaveMaxSize() (v int) {
+ st.mutex.Lock()
+ v = st.config.Cache.GTS.StatusFaveMaxSize
+ st.mutex.Unlock()
+ return
+}
+
+// SetCacheGTSStatusFaveMaxSize safely sets the Configuration value for state's 'Cache.GTS.StatusFaveMaxSize' field
+func (st *ConfigState) SetCacheGTSStatusFaveMaxSize(v int) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.GTS.StatusFaveMaxSize = v
+ st.reloadToViper()
+}
+
+// CacheGTSStatusFaveMaxSizeFlag returns the flag name for the 'Cache.GTS.StatusFaveMaxSize' field
+func CacheGTSStatusFaveMaxSizeFlag() string { return "cache-gts-status-fave-max-size" }
+
+// GetCacheGTSStatusFaveMaxSize safely fetches the value for global configuration 'Cache.GTS.StatusFaveMaxSize' field
+func GetCacheGTSStatusFaveMaxSize() int { return global.GetCacheGTSStatusFaveMaxSize() }
+
+// SetCacheGTSStatusFaveMaxSize safely sets the value for global configuration 'Cache.GTS.StatusFaveMaxSize' field
+func SetCacheGTSStatusFaveMaxSize(v int) { global.SetCacheGTSStatusFaveMaxSize(v) }
+
+// GetCacheGTSStatusFaveTTL safely fetches the Configuration value for state's 'Cache.GTS.StatusFaveTTL' field
+func (st *ConfigState) GetCacheGTSStatusFaveTTL() (v time.Duration) {
+ st.mutex.Lock()
+ v = st.config.Cache.GTS.StatusFaveTTL
+ st.mutex.Unlock()
+ return
+}
+
+// SetCacheGTSStatusFaveTTL safely sets the Configuration value for state's 'Cache.GTS.StatusFaveTTL' field
+func (st *ConfigState) SetCacheGTSStatusFaveTTL(v time.Duration) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.GTS.StatusFaveTTL = v
+ st.reloadToViper()
+}
+
+// CacheGTSStatusFaveTTLFlag returns the flag name for the 'Cache.GTS.StatusFaveTTL' field
+func CacheGTSStatusFaveTTLFlag() string { return "cache-gts-status-fave-ttl" }
+
+// GetCacheGTSStatusFaveTTL safely fetches the value for global configuration 'Cache.GTS.StatusFaveTTL' field
+func GetCacheGTSStatusFaveTTL() time.Duration { return global.GetCacheGTSStatusFaveTTL() }
+
+// SetCacheGTSStatusFaveTTL safely sets the value for global configuration 'Cache.GTS.StatusFaveTTL' field
+func SetCacheGTSStatusFaveTTL(v time.Duration) { global.SetCacheGTSStatusFaveTTL(v) }
+
+// GetCacheGTSStatusFaveSweepFreq safely fetches the Configuration value for state's 'Cache.GTS.StatusFaveSweepFreq' field
+func (st *ConfigState) GetCacheGTSStatusFaveSweepFreq() (v time.Duration) {
+ st.mutex.Lock()
+ v = st.config.Cache.GTS.StatusFaveSweepFreq
+ st.mutex.Unlock()
+ return
+}
+
+// SetCacheGTSStatusFaveSweepFreq safely sets the Configuration value for state's 'Cache.GTS.StatusFaveSweepFreq' field
+func (st *ConfigState) SetCacheGTSStatusFaveSweepFreq(v time.Duration) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.GTS.StatusFaveSweepFreq = v
+ st.reloadToViper()
+}
+
+// CacheGTSStatusFaveSweepFreqFlag returns the flag name for the 'Cache.GTS.StatusFaveSweepFreq' field
+func CacheGTSStatusFaveSweepFreqFlag() string { return "cache-gts-status-fave-sweep-freq" }
+
+// GetCacheGTSStatusFaveSweepFreq safely fetches the value for global configuration 'Cache.GTS.StatusFaveSweepFreq' field
+func GetCacheGTSStatusFaveSweepFreq() time.Duration { return global.GetCacheGTSStatusFaveSweepFreq() }
+
+// SetCacheGTSStatusFaveSweepFreq safely sets the value for global configuration 'Cache.GTS.StatusFaveSweepFreq' field
+func SetCacheGTSStatusFaveSweepFreq(v time.Duration) { global.SetCacheGTSStatusFaveSweepFreq(v) }
+
// GetCacheGTSTombstoneMaxSize safely fetches the Configuration value for state's 'Cache.GTS.TombstoneMaxSize' field
func (st *ConfigState) GetCacheGTSTombstoneMaxSize() (v int) {
st.mutex.Lock()
@@ -3103,6 +3330,81 @@ func GetCacheGTSWebfingerSweepFreq() time.Duration { return global.GetCacheGTSWe
// SetCacheGTSWebfingerSweepFreq safely sets the value for global configuration 'Cache.GTS.WebfingerSweepFreq' field
func SetCacheGTSWebfingerSweepFreq(v time.Duration) { global.SetCacheGTSWebfingerSweepFreq(v) }
+// GetCacheVisibilityMaxSize safely fetches the Configuration value for state's 'Cache.VisibilityMaxSize' field
+func (st *ConfigState) GetCacheVisibilityMaxSize() (v int) {
+ st.mutex.Lock()
+ v = st.config.Cache.VisibilityMaxSize
+ st.mutex.Unlock()
+ return
+}
+
+// SetCacheVisibilityMaxSize safely sets the Configuration value for state's 'Cache.VisibilityMaxSize' field
+func (st *ConfigState) SetCacheVisibilityMaxSize(v int) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.VisibilityMaxSize = v
+ st.reloadToViper()
+}
+
+// CacheVisibilityMaxSizeFlag returns the flag name for the 'Cache.VisibilityMaxSize' field
+func CacheVisibilityMaxSizeFlag() string { return "cache-visibility-max-size" }
+
+// GetCacheVisibilityMaxSize safely fetches the value for global configuration 'Cache.VisibilityMaxSize' field
+func GetCacheVisibilityMaxSize() int { return global.GetCacheVisibilityMaxSize() }
+
+// SetCacheVisibilityMaxSize safely sets the value for global configuration 'Cache.VisibilityMaxSize' field
+func SetCacheVisibilityMaxSize(v int) { global.SetCacheVisibilityMaxSize(v) }
+
+// GetCacheVisibilityTTL safely fetches the Configuration value for state's 'Cache.VisibilityTTL' field
+func (st *ConfigState) GetCacheVisibilityTTL() (v time.Duration) {
+ st.mutex.Lock()
+ v = st.config.Cache.VisibilityTTL
+ st.mutex.Unlock()
+ return
+}
+
+// SetCacheVisibilityTTL safely sets the Configuration value for state's 'Cache.VisibilityTTL' field
+func (st *ConfigState) SetCacheVisibilityTTL(v time.Duration) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.VisibilityTTL = v
+ st.reloadToViper()
+}
+
+// CacheVisibilityTTLFlag returns the flag name for the 'Cache.VisibilityTTL' field
+func CacheVisibilityTTLFlag() string { return "cache-visibility-ttl" }
+
+// GetCacheVisibilityTTL safely fetches the value for global configuration 'Cache.VisibilityTTL' field
+func GetCacheVisibilityTTL() time.Duration { return global.GetCacheVisibilityTTL() }
+
+// SetCacheVisibilityTTL safely sets the value for global configuration 'Cache.VisibilityTTL' field
+func SetCacheVisibilityTTL(v time.Duration) { global.SetCacheVisibilityTTL(v) }
+
+// GetCacheVisibilitySweepFreq safely fetches the Configuration value for state's 'Cache.VisibilitySweepFreq' field
+func (st *ConfigState) GetCacheVisibilitySweepFreq() (v time.Duration) {
+ st.mutex.Lock()
+ v = st.config.Cache.VisibilitySweepFreq
+ st.mutex.Unlock()
+ return
+}
+
+// SetCacheVisibilitySweepFreq safely sets the Configuration value for state's 'Cache.VisibilitySweepFreq' field
+func (st *ConfigState) SetCacheVisibilitySweepFreq(v time.Duration) {
+ st.mutex.Lock()
+ defer st.mutex.Unlock()
+ st.config.Cache.VisibilitySweepFreq = v
+ st.reloadToViper()
+}
+
+// CacheVisibilitySweepFreqFlag returns the flag name for the 'Cache.VisibilitySweepFreq' field
+func CacheVisibilitySweepFreqFlag() string { return "cache-visibility-sweep-freq" }
+
+// GetCacheVisibilitySweepFreq safely fetches the value for global configuration 'Cache.VisibilitySweepFreq' field
+func GetCacheVisibilitySweepFreq() time.Duration { return global.GetCacheVisibilitySweepFreq() }
+
+// SetCacheVisibilitySweepFreq safely sets the value for global configuration 'Cache.VisibilitySweepFreq' field
+func SetCacheVisibilitySweepFreq(v time.Duration) { global.SetCacheVisibilitySweepFreq(v) }
+
// GetAdminAccountUsername safely fetches the Configuration value for state's 'AdminAccountUsername' field
func (st *ConfigState) GetAdminAccountUsername() (v string) {
st.mutex.Lock()
diff --git a/internal/db/account.go b/internal/db/account.go
index 6ecfea018..4a08918b0 100644
--- a/internal/db/account.go
+++ b/internal/db/account.go
@@ -41,6 +41,21 @@ type Account interface {
// GetAccountByPubkeyID returns one account with the given public key URI (ID), or an error if something goes wrong.
GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmodel.Account, Error)
+ // GetAccountByInboxURI returns one account with the given inbox_uri, or an error if something goes wrong.
+ GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
+
+ // GetAccountByOutboxURI returns one account with the given outbox_uri, or an error if something goes wrong.
+ GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
+
+ // GetAccountByFollowingURI returns one account with the given following_uri, or an error if something goes wrong.
+ GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
+
+ // GetAccountByFollowersURI returns one account with the given followers_uri, or an error if something goes wrong.
+ GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, Error)
+
+ // PopulateAccount ensures that all sub-models of an account are populated (e.g. avatar, header etc).
+ PopulateAccount(ctx context.Context, account *gtsmodel.Account) error
+
// PutAccount puts one account in the database.
PutAccount(ctx context.Context, account *gtsmodel.Account) Error
diff --git a/internal/db/bundb/account.go b/internal/db/bundb/account.go
index df73168e2..ccf7aaa46 100644
--- a/internal/db/bundb/account.go
+++ b/internal/db/bundb/account.go
@@ -20,11 +20,13 @@
import (
"context"
"errors"
+ "fmt"
"strings"
"time"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
@@ -37,18 +39,15 @@ type accountDB struct {
state *state.State
}
-func (a *accountDB) newAccountQ(account *gtsmodel.Account) *bun.SelectQuery {
- return a.conn.
- NewSelect().
- Model(account)
-}
-
func (a *accountDB) GetAccountByID(ctx context.Context, id string) (*gtsmodel.Account, db.Error) {
return a.getAccount(
ctx,
"ID",
func(account *gtsmodel.Account) error {
- return a.newAccountQ(account).Where("? = ?", bun.Ident("account.id"), id).Scan(ctx)
+ return a.conn.NewSelect().
+ Model(account).
+ Where("? = ?", bun.Ident("account.id"), id).
+ Scan(ctx)
},
id,
)
@@ -59,7 +58,10 @@ func (a *accountDB) GetAccountByURI(ctx context.Context, uri string) (*gtsmodel.
ctx,
"URI",
func(account *gtsmodel.Account) error {
- return a.newAccountQ(account).Where("? = ?", bun.Ident("account.uri"), uri).Scan(ctx)
+ return a.conn.NewSelect().
+ Model(account).
+ Where("? = ?", bun.Ident("account.uri"), uri).
+ Scan(ctx)
},
uri,
)
@@ -70,7 +72,10 @@ func (a *accountDB) GetAccountByURL(ctx context.Context, url string) (*gtsmodel.
ctx,
"URL",
func(account *gtsmodel.Account) error {
- return a.newAccountQ(account).Where("? = ?", bun.Ident("account.url"), url).Scan(ctx)
+ return a.conn.NewSelect().
+ Model(account).
+ Where("? = ?", bun.Ident("account.url"), url).
+ Scan(ctx)
},
url,
)
@@ -81,7 +86,8 @@ func (a *accountDB) GetAccountByUsernameDomain(ctx context.Context, username str
ctx,
"Username.Domain",
func(account *gtsmodel.Account) error {
- q := a.newAccountQ(account)
+ q := a.conn.NewSelect().
+ Model(account)
if domain != "" {
q = q.
@@ -105,12 +111,71 @@ func (a *accountDB) GetAccountByPubkeyID(ctx context.Context, id string) (*gtsmo
ctx,
"PublicKeyURI",
func(account *gtsmodel.Account) error {
- return a.newAccountQ(account).Where("? = ?", bun.Ident("account.public_key_uri"), id).Scan(ctx)
+ return a.conn.NewSelect().
+ Model(account).
+ Where("? = ?", bun.Ident("account.public_key_uri"), id).
+ Scan(ctx)
},
id,
)
}
+func (a *accountDB) GetAccountByInboxURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
+ return a.getAccount(
+ ctx,
+ "InboxURI",
+ func(account *gtsmodel.Account) error {
+ return a.conn.NewSelect().
+ Model(account).
+ Where("? = ?", bun.Ident("account.inbox_uri"), uri).
+ Scan(ctx)
+ },
+ uri,
+ )
+}
+
+func (a *accountDB) GetAccountByOutboxURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
+ return a.getAccount(
+ ctx,
+ "OutboxURI",
+ func(account *gtsmodel.Account) error {
+ return a.conn.NewSelect().
+ Model(account).
+ Where("? = ?", bun.Ident("account.outbox_uri"), uri).
+ Scan(ctx)
+ },
+ uri,
+ )
+}
+
+func (a *accountDB) GetAccountByFollowersURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
+ return a.getAccount(
+ ctx,
+ "FollowersURI",
+ func(account *gtsmodel.Account) error {
+ return a.conn.NewSelect().
+ Model(account).
+ Where("? = ?", bun.Ident("account.followers_uri"), uri).
+ Scan(ctx)
+ },
+ uri,
+ )
+}
+
+func (a *accountDB) GetAccountByFollowingURI(ctx context.Context, uri string) (*gtsmodel.Account, db.Error) {
+ return a.getAccount(
+ ctx,
+ "FollowingURI",
+ func(account *gtsmodel.Account) error {
+ return a.conn.NewSelect().
+ Model(account).
+ Where("? = ?", bun.Ident("account.following_uri"), uri).
+ Scan(ctx)
+ },
+ uri,
+ )
+}
+
func (a *accountDB) GetInstanceAccount(ctx context.Context, domain string) (*gtsmodel.Account, db.Error) {
var username string
@@ -141,33 +206,58 @@ func (a *accountDB) getAccount(ctx context.Context, lookup string, dbQuery func(
return nil, err
}
- if account.AvatarMediaAttachmentID != "" {
- // Set the account's related avatar
- account.AvatarMediaAttachment, err = a.state.DB.GetAttachmentByID(ctx, account.AvatarMediaAttachmentID)
- if err != nil {
- log.Errorf(ctx, "error getting account %s avatar: %v", account.ID, err)
- }
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return account, nil
}
- if account.HeaderMediaAttachmentID != "" {
- // Set the account's related header
- account.HeaderMediaAttachment, err = a.state.DB.GetAttachmentByID(ctx, account.HeaderMediaAttachmentID)
- if err != nil {
- log.Errorf(ctx, "error getting account %s header: %v", account.ID, err)
- }
- }
-
- if len(account.EmojiIDs) > 0 {
- // Set the account's related emojis
- account.Emojis, err = a.state.DB.GetEmojisByIDs(ctx, account.EmojiIDs)
- if err != nil {
- log.Errorf(ctx, "error getting account %s emojis: %v", account.ID, err)
- }
+ // Further populate the account fields where applicable.
+ if err := a.PopulateAccount(ctx, account); err != nil {
+ return nil, err
}
return account, nil
}
+func (a *accountDB) PopulateAccount(ctx context.Context, account *gtsmodel.Account) error {
+ var err error
+
+ if account.AvatarMediaAttachment == nil && account.AvatarMediaAttachmentID != "" {
+ // Account avatar attachment is not set, fetch from database.
+ account.AvatarMediaAttachment, err = a.state.DB.GetAttachmentByID(
+ ctx, // these are already barebones
+ account.AvatarMediaAttachmentID,
+ )
+ if err != nil {
+ return fmt.Errorf("error populating account avatar: %w", err)
+ }
+ }
+
+ if account.HeaderMediaAttachment == nil && account.HeaderMediaAttachmentID != "" {
+ // Account header attachment is not set, fetch from database.
+ account.HeaderMediaAttachment, err = a.state.DB.GetAttachmentByID(
+ ctx, // these are already barebones
+ account.HeaderMediaAttachmentID,
+ )
+ if err != nil {
+ return fmt.Errorf("error populating account header: %w", err)
+ }
+ }
+
+ if !account.EmojisPopulated() {
+ // Account emojis are out-of-date with IDs, repopulate.
+ account.Emojis, err = a.state.DB.GetEmojisByIDs(
+ ctx, // these are already barebones
+ account.EmojiIDs,
+ )
+ if err != nil {
+ return fmt.Errorf("error populating account emojis: %w", err)
+ }
+ }
+
+ return nil
+}
+
func (a *accountDB) PutAccount(ctx context.Context, account *gtsmodel.Account) db.Error {
return a.state.Caches.GTS.Account().Store(account, func() error {
// It is safe to run this database transaction within cache.Store
@@ -198,7 +288,7 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
columns = append(columns, "updated_at")
}
- return a.state.Caches.GTS.Account().Store(account, func() error {
+ err := a.state.Caches.GTS.Account().Store(account, func() error {
// It is safe to run this database transaction within cache.Store
// as the cache does not attempt a mutex lock until AFTER hook.
//
@@ -234,6 +324,11 @@ func (a *accountDB) UpdateAccount(ctx context.Context, account *gtsmodel.Account
return err
})
})
+ if err != nil {
+ return err
+ }
+
+ return nil
}
func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
@@ -258,7 +353,9 @@ func (a *accountDB) DeleteAccount(ctx context.Context, id string) db.Error {
return err
}
+ // Invalidate account from database lookups.
a.state.Caches.GTS.Account().Invalidate("ID", id)
+
return nil
}
diff --git a/internal/db/bundb/account_test.go b/internal/db/bundb/account_test.go
index b7e8aaadc..2241ab783 100644
--- a/internal/db/bundb/account_test.go
+++ b/internal/db/bundb/account_test.go
@@ -21,6 +21,8 @@
"context"
"crypto/rand"
"crypto/rsa"
+ "errors"
+ "reflect"
"strings"
"testing"
"time"
@@ -61,44 +63,149 @@ func (suite *AccountTestSuite) TestGetAccountStatusesMediaOnly() {
suite.Len(statuses, 1)
}
-func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() {
- account, err := suite.db.GetAccountByID(context.Background(), suite.testAccounts["local_account_1"].ID)
- if err != nil {
- suite.FailNow(err.Error())
+func (suite *AccountTestSuite) TestGetAccountBy() {
+ t := suite.T()
+
+ // Create a new context for this test.
+ ctx, cncl := context.WithCancel(context.Background())
+ defer cncl()
+
+ // Sentinel error to mark avoiding a test case.
+ sentinelErr := errors.New("sentinel")
+
+ // isEqual checks if 2 account models are equal.
+ isEqual := func(a1, a2 gtsmodel.Account) bool {
+ // Clear populated sub-models.
+ a1.HeaderMediaAttachment = nil
+ a2.HeaderMediaAttachment = nil
+ a1.AvatarMediaAttachment = nil
+ a2.AvatarMediaAttachment = nil
+ a1.Emojis = nil
+ a2.Emojis = nil
+
+ // Clear database-set fields.
+ a1.CreatedAt = time.Time{}
+ a2.CreatedAt = time.Time{}
+ a1.UpdatedAt = time.Time{}
+ a2.UpdatedAt = time.Time{}
+
+ // Manually compare keys.
+ pk1 := a1.PublicKey
+ pv1 := a1.PrivateKey
+ pk2 := a2.PublicKey
+ pv2 := a2.PrivateKey
+ a1.PublicKey = nil
+ a1.PrivateKey = nil
+ a2.PublicKey = nil
+ a2.PrivateKey = nil
+
+ return reflect.DeepEqual(a1, a2) &&
+ ((pk1 == nil && pk2 == nil) || pk1.Equal(pk2)) &&
+ ((pv1 == nil && pv2 == nil) || pv1.Equal(pv2))
}
- suite.NotNil(account)
- suite.NotNil(account.AvatarMediaAttachment)
- suite.NotEmpty(account.AvatarMediaAttachment.URL)
- suite.NotNil(account.HeaderMediaAttachment)
- suite.NotEmpty(account.HeaderMediaAttachment.URL)
-}
-func (suite *AccountTestSuite) TestGetAccountByUsernameDomain() {
- testAccount1 := suite.testAccounts["local_account_1"]
- account1, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount1.Username, testAccount1.Domain)
- suite.NoError(err)
- suite.NotNil(account1)
+ for _, account := range suite.testAccounts {
+ for lookup, dbfunc := range map[string]func() (*gtsmodel.Account, error){
+ "id": func() (*gtsmodel.Account, error) {
+ return suite.db.GetAccountByID(ctx, account.ID)
+ },
- testAccount2 := suite.testAccounts["remote_account_1"]
- account2, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount2.Username, testAccount2.Domain)
- suite.NoError(err)
- suite.NotNil(account2)
-}
+ "uri": func() (*gtsmodel.Account, error) {
+ return suite.db.GetAccountByURI(ctx, account.URI)
+ },
-func (suite *AccountTestSuite) TestGetAccountByUsernameDomainMixedCase() {
- testAccount := suite.testAccounts["remote_account_2"]
+ "url": func() (*gtsmodel.Account, error) {
+ if account.URL == "" {
+ return nil, sentinelErr
+ }
+ return suite.db.GetAccountByURL(ctx, account.URL)
+ },
- account1, err := suite.db.GetAccountByUsernameDomain(context.Background(), testAccount.Username, testAccount.Domain)
- suite.NoError(err)
- suite.NotNil(account1)
+ "username@domain": func() (*gtsmodel.Account, error) {
+ return suite.db.GetAccountByUsernameDomain(ctx, account.Username, account.Domain)
+ },
- account2, err := suite.db.GetAccountByUsernameDomain(context.Background(), strings.ToUpper(testAccount.Username), testAccount.Domain)
- suite.NoError(err)
- suite.NotNil(account2)
+ "username_upper@domain": func() (*gtsmodel.Account, error) {
+ return suite.db.GetAccountByUsernameDomain(ctx, strings.ToUpper(account.Username), account.Domain)
+ },
- account3, err := suite.db.GetAccountByUsernameDomain(context.Background(), strings.ToLower(testAccount.Username), testAccount.Domain)
- suite.NoError(err)
- suite.NotNil(account3)
+ "username_lower@domain": func() (*gtsmodel.Account, error) {
+ return suite.db.GetAccountByUsernameDomain(ctx, strings.ToLower(account.Username), account.Domain)
+ },
+
+ "public_key_uri": func() (*gtsmodel.Account, error) {
+ if account.PublicKeyURI == "" {
+ return nil, sentinelErr
+ }
+ return suite.db.GetAccountByPubkeyID(ctx, account.PublicKeyURI)
+ },
+
+ "inbox_uri": func() (*gtsmodel.Account, error) {
+ if account.InboxURI == "" {
+ return nil, sentinelErr
+ }
+ return suite.db.GetAccountByInboxURI(ctx, account.InboxURI)
+ },
+
+ "outbox_uri": func() (*gtsmodel.Account, error) {
+ if account.OutboxURI == "" {
+ return nil, sentinelErr
+ }
+ return suite.db.GetAccountByOutboxURI(ctx, account.OutboxURI)
+ },
+
+ "following_uri": func() (*gtsmodel.Account, error) {
+ if account.FollowingURI == "" {
+ return nil, sentinelErr
+ }
+ return suite.db.GetAccountByFollowingURI(ctx, account.FollowingURI)
+ },
+
+ "followers_uri": func() (*gtsmodel.Account, error) {
+ if account.FollowersURI == "" {
+ return nil, sentinelErr
+ }
+ return suite.db.GetAccountByFollowersURI(ctx, account.FollowersURI)
+ },
+ } {
+
+ // Clear database caches.
+ suite.state.Caches.Init()
+
+ t.Logf("checking database lookup %q", lookup)
+
+ // Perform database function.
+ checkAcc, err := dbfunc()
+ if err != nil {
+ if err == sentinelErr {
+ continue
+ }
+
+ t.Errorf("error encountered for database lookup %q: %v", lookup, err)
+ continue
+ }
+
+ // Check received account data.
+ if !isEqual(*checkAcc, *account) {
+ t.Errorf("account does not contain expected data: %+v", checkAcc)
+ continue
+ }
+
+ // Check that avatar attachment populated.
+ if account.AvatarMediaAttachmentID != "" &&
+ (checkAcc.AvatarMediaAttachment == nil || checkAcc.AvatarMediaAttachment.ID != account.AvatarMediaAttachmentID) {
+ t.Errorf("account avatar media attachment not correctly populated for: %+v", account)
+ continue
+ }
+
+ // Check that header attachment populated.
+ if account.HeaderMediaAttachmentID != "" &&
+ (checkAcc.HeaderMediaAttachment == nil || checkAcc.HeaderMediaAttachment.ID != account.HeaderMediaAttachmentID) {
+ t.Errorf("account header media attachment not correctly populated for: %+v", account)
+ continue
+ }
+ }
+ }
}
func (suite *AccountTestSuite) TestUpdateAccount() {
diff --git a/internal/db/bundb/basic_test.go b/internal/db/bundb/basic_test.go
index f238f2273..a24deac9e 100644
--- a/internal/db/bundb/basic_test.go
+++ b/internal/db/bundb/basic_test.go
@@ -19,6 +19,8 @@
import (
"context"
+ "crypto/rand"
+ "crypto/rsa"
"testing"
"time"
@@ -40,6 +42,12 @@ func (suite *BasicTestSuite) TestGetAccountByID() {
}
func (suite *BasicTestSuite) TestPutAccountWithBunDefaultFields() {
+ key, err := rsa.GenerateKey(rand.Reader, 2048)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ // Create an account that only just matches constraints.
testAccount := >smodel.Account{
ID: "01GADR1AH9VCKH8YYCM86XSZ00",
Username: "test",
@@ -49,6 +57,7 @@ func (suite *BasicTestSuite) TestPutAccountWithBunDefaultFields() {
OutboxURI: "https://example.org/users/test/outbox",
ActorType: "Person",
PublicKeyURI: "https://example.org/test#main-key",
+ PublicKey: &key.PublicKey,
}
if err := suite.db.Put(context.Background(), testAccount); err != nil {
@@ -99,7 +108,7 @@ func (suite *BasicTestSuite) TestPutAccountWithBunDefaultFields() {
suite.Empty(a.FeaturedCollectionURI)
suite.Equal(testAccount.ActorType, a.ActorType)
suite.Nil(a.PrivateKey)
- suite.Nil(a.PublicKey)
+ suite.EqualValues(key.PublicKey, *a.PublicKey)
suite.Equal(testAccount.PublicKeyURI, a.PublicKeyURI)
suite.Zero(a.SensitizedAt)
suite.Zero(a.SilencedAt)
diff --git a/internal/db/bundb/media.go b/internal/db/bundb/media.go
index 1a9d3be05..d17d64b35 100644
--- a/internal/db/bundb/media.go
+++ b/internal/db/bundb/media.go
@@ -47,6 +47,24 @@ func(attachment *gtsmodel.MediaAttachment) error {
)
}
+func (m *mediaDB) GetAttachmentsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.MediaAttachment, error) {
+ attachments := make([]*gtsmodel.MediaAttachment, 0, len(ids))
+
+ for _, id := range ids {
+ // Attempt fetch from DB
+ attachment, err := m.GetAttachmentByID(ctx, id)
+ if err != nil {
+ log.Errorf(ctx, "error getting attachment %q: %v", id, err)
+ continue
+ }
+
+ // Append attachment
+ attachments = append(attachments, attachment)
+ }
+
+ return attachments, nil
+}
+
func (m *mediaDB) getAttachment(ctx context.Context, lookup string, dbQuery func(*gtsmodel.MediaAttachment) error, keyParts ...any) (*gtsmodel.MediaAttachment, db.Error) {
return m.state.Caches.GTS.Media().Load(lookup, func() (*gtsmodel.MediaAttachment, error) {
var attachment gtsmodel.MediaAttachment
@@ -118,7 +136,7 @@ func (m *mediaDB) GetRemoteOlderThan(ctx context.Context, olderThan time.Time, l
return nil, m.conn.ProcessError(err)
}
- return m.getAttachments(ctx, attachmentIDs)
+ return m.GetAttachmentsByIDs(ctx, attachmentIDs)
}
func (m *mediaDB) CountRemoteOlderThan(ctx context.Context, olderThan time.Time) (int, db.Error) {
@@ -163,7 +181,7 @@ func (m *mediaDB) GetAvatarsAndHeaders(ctx context.Context, maxID string, limit
return nil, m.conn.ProcessError(err)
}
- return m.getAttachments(ctx, attachmentIDs)
+ return m.GetAttachmentsByIDs(ctx, attachmentIDs)
}
func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time, limit int) ([]*gtsmodel.MediaAttachment, db.Error) {
@@ -189,7 +207,7 @@ func (m *mediaDB) GetLocalUnattachedOlderThan(ctx context.Context, olderThan tim
return nil, m.conn.ProcessError(err)
}
- return m.getAttachments(ctx, attachmentIDs)
+ return m.GetAttachmentsByIDs(ctx, attachmentIDs)
}
func (m *mediaDB) CountLocalUnattachedOlderThan(ctx context.Context, olderThan time.Time) (int, db.Error) {
@@ -211,21 +229,3 @@ func (m *mediaDB) CountLocalUnattachedOlderThan(ctx context.Context, olderThan t
return count, nil
}
-
-func (m *mediaDB) getAttachments(ctx context.Context, ids []string) ([]*gtsmodel.MediaAttachment, db.Error) {
- attachments := make([]*gtsmodel.MediaAttachment, 0, len(ids))
-
- for _, id := range ids {
- // Attempt fetch from DB
- attachment, err := m.GetAttachmentByID(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error getting attachment %q: %v", id, err)
- continue
- }
-
- // Append attachment
- attachments = append(attachments, attachment)
- }
-
- return attachments, nil
-}
diff --git a/internal/db/bundb/mention.go b/internal/db/bundb/mention.go
index 3a543f3c2..e64d6dac4 100644
--- a/internal/db/bundb/mention.go
+++ b/internal/db/bundb/mention.go
@@ -19,8 +19,10 @@
import (
"context"
+ "fmt"
"github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
@@ -32,20 +34,13 @@ type mentionDB struct {
state *state.State
}
-func (m *mentionDB) newMentionQ(i interface{}) *bun.SelectQuery {
- return m.conn.
- NewSelect().
- Model(i).
- Relation("Status").
- Relation("OriginAccount").
- Relation("TargetAccount")
-}
-
func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mention, db.Error) {
- return m.state.Caches.GTS.Mention().Load("ID", func() (*gtsmodel.Mention, error) {
+ mention, err := m.state.Caches.GTS.Mention().Load("ID", func() (*gtsmodel.Mention, error) {
var mention gtsmodel.Mention
- q := m.newMentionQ(&mention).
+ q := m.conn.
+ NewSelect().
+ Model(&mention).
Where("? = ?", bun.Ident("mention.id"), id)
if err := q.Scan(ctx); err != nil {
@@ -54,6 +49,38 @@ func (m *mentionDB) GetMention(ctx context.Context, id string) (*gtsmodel.Mentio
return &mention, nil
}, id)
+ if err != nil {
+ return nil, err
+ }
+
+ // Set the mention originating status.
+ mention.Status, err = m.state.DB.GetStatusByID(
+ gtscontext.SetBarebones(ctx),
+ mention.StatusID,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("error populating mention status: %w", err)
+ }
+
+ // Set the mention origin account model.
+ mention.OriginAccount, err = m.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ mention.OriginAccountID,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("error populating mention origin account: %w", err)
+ }
+
+ // Set the mention target account model.
+ mention.TargetAccount, err = m.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ mention.TargetAccountID,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("error populating mention target account: %w", err)
+ }
+
+ return mention, nil
}
func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, db.Error) {
@@ -73,3 +100,25 @@ func (m *mentionDB) GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.
return mentions, nil
}
+
+func (m *mentionDB) PutMention(ctx context.Context, mention *gtsmodel.Mention) error {
+ return m.state.Caches.GTS.Mention().Store(mention, func() error {
+ _, err := m.conn.NewInsert().Model(mention).Exec(ctx)
+ return m.conn.ProcessError(err)
+ })
+}
+
+func (m *mentionDB) DeleteMentionByID(ctx context.Context, id string) error {
+ if _, err := m.conn.
+ NewDelete().
+ Table("mentions").
+ Where("? = ?", bun.Ident("id"), id).
+ Exec(ctx); err != nil {
+ return m.conn.ProcessError(err)
+ }
+
+ // Invalidate mention from the lookup cache.
+ m.state.Caches.GTS.Mention().Invalidate("ID", id)
+
+ return nil
+}
diff --git a/internal/db/bundb/migrations/20230328105630_chore_refactoring.go b/internal/db/bundb/migrations/20230328105630_chore_refactoring.go
new file mode 100644
index 000000000..3bf9d59ef
--- /dev/null
+++ b/internal/db/bundb/migrations/20230328105630_chore_refactoring.go
@@ -0,0 +1,167 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package migrations
+
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/uptrace/bun"
+)
+
+func init() {
+ up := func(ctx context.Context, db *bun.DB) error {
+ // To update unique constraint on public key, we need to migrate accounts into a new table.
+ // See section 7 here: https://www.sqlite.org/lang_altertable.html
+
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ // Create the new accounts table.
+ if _, err := tx.
+ NewCreateTable().
+ ModelTableExpr("new_accounts").
+ Model(>smodel.Account{}).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ // If we don't specify columns explicitly,
+ // Postgres gives the following error when
+ // transferring accounts to new_accounts:
+ //
+ // ERROR: column "fetched_at" is of type timestamp with time zone but expression is of type character varying at character 35
+ // HINT: You will need to rewrite or cast the expression.
+ //
+ // Rather than do funky casting to fix this,
+ // it's simpler to just specify all columns.
+ columns := []string{
+ "id",
+ "created_at",
+ "updated_at",
+ "fetched_at",
+ "username",
+ "domain",
+ "avatar_media_attachment_id",
+ "avatar_remote_url",
+ "header_media_attachment_id",
+ "header_remote_url",
+ "display_name",
+ "emojis",
+ "fields",
+ "note",
+ "note_raw",
+ "memorial",
+ "also_known_as",
+ "moved_to_account_id",
+ "bot",
+ "reason",
+ "locked",
+ "discoverable",
+ "privacy",
+ "sensitive",
+ "language",
+ "status_content_type",
+ "custom_css",
+ "uri",
+ "url",
+ "inbox_uri",
+ "shared_inbox_uri",
+ "outbox_uri",
+ "following_uri",
+ "followers_uri",
+ "featured_collection_uri",
+ "actor_type",
+ "private_key",
+ "public_key",
+ "public_key_uri",
+ "sensitized_at",
+ "silenced_at",
+ "suspended_at",
+ "hide_collections",
+ "suspension_origin",
+ "enable_rss",
+ }
+
+ // Copy all accounts to the new table.
+ if _, err := tx.
+ NewInsert().
+ Table("new_accounts").
+ Table("accounts").
+ Column(columns...).
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ // Drop the old table.
+ if _, err := tx.
+ NewDropTable().
+ Table("accounts").
+ Exec(ctx); err != nil {
+ return err
+ }
+
+ // Rename new table to old table.
+ if _, err := tx.
+ ExecContext(
+ ctx,
+ "ALTER TABLE ? RENAME TO ?",
+ bun.Ident("new_accounts"),
+ bun.Ident("accounts"),
+ ); err != nil {
+ return err
+ }
+
+ // Add all account indexes to the new table.
+ for index, columns := range map[string][]string{
+ // Standard indices.
+ "accounts_id_idx": {"id"},
+ "accounts_suspended_at_idx": {"suspended_at"},
+ "accounts_domain_idx": {"domain"},
+ "accounts_username_domain_idx": {"username", "domain"},
+ // URI indices.
+ "accounts_uri_idx": {"uri"},
+ "accounts_url_idx": {"url"},
+ "accounts_inbox_uri_idx": {"inbox_uri"},
+ "accounts_outbox_uri_idx": {"outbox_uri"},
+ "accounts_followers_uri_idx": {"followers_uri"},
+ "accounts_following_uri_idx": {"following_uri"},
+ "accounts_public_key_uri_idx": {"public_key_uri"},
+ } {
+ if _, err := tx.
+ NewCreateIndex().
+ Table("accounts").
+ Index(index).
+ Column(columns...).
+ Exec(ctx); err != nil {
+ return err
+ }
+ }
+
+ return nil
+ })
+ }
+
+ down := func(ctx context.Context, db *bun.DB) error {
+ return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
+ return nil
+ })
+ }
+
+ if err := Migrations.Register(up, down); err != nil {
+ panic(err)
+ }
+}
diff --git a/internal/db/bundb/notification.go b/internal/db/bundb/notification.go
index b1e7f45ff..f32aed092 100644
--- a/internal/db/bundb/notification.go
+++ b/internal/db/bundb/notification.go
@@ -33,7 +33,7 @@ type notificationDB struct {
state *state.State
}
-func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) {
+func (n *notificationDB) GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, db.Error) {
return n.state.Caches.GTS.Notification().Load("ID", func() (*gtsmodel.Notification, error) {
var notif gtsmodel.Notification
@@ -48,7 +48,7 @@ func (n *notificationDB) GetNotification(ctx context.Context, id string) (*gtsmo
}, id)
}
-func (n *notificationDB) GetNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {
+func (n *notificationDB) GetAccountNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {
// Ensure reasonable
if limit < 0 {
limit = 0
@@ -92,7 +92,7 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
// reason for this is that for each notif, we can instead get it from our cache if it's cached
for _, id := range notifIDs {
// Attempt fetch from DB
- notif, err := n.GetNotification(ctx, id)
+ notif, err := n.GetNotificationByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting notification %q: %v", id, err)
continue
@@ -105,7 +105,14 @@ func (n *notificationDB) GetNotifications(ctx context.Context, accountID string,
return notifs, nil
}
-func (n *notificationDB) DeleteNotification(ctx context.Context, id string) db.Error {
+func (n *notificationDB) PutNotification(ctx context.Context, notif *gtsmodel.Notification) error {
+ return n.state.Caches.GTS.Notification().Store(notif, func() error {
+ _, err := n.conn.NewInsert().Model(notif).Exec(ctx)
+ return n.conn.ProcessError(err)
+ })
+}
+
+func (n *notificationDB) DeleteNotificationByID(ctx context.Context, id string) db.Error {
if _, err := n.conn.
NewDelete().
TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
@@ -118,19 +125,23 @@ func (n *notificationDB) DeleteNotification(ctx context.Context, id string) db.E
return nil
}
-func (n *notificationDB) DeleteNotifications(ctx context.Context, targetAccountID string, originAccountID string) db.Error {
+func (n *notificationDB) DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) db.Error {
if targetAccountID == "" && originAccountID == "" {
return errors.New("DeleteNotifications: one of targetAccountID or originAccountID must be set")
}
// Capture notification IDs in a RETURNING statement.
- ids := []string{}
+ var ids []string
q := n.conn.
NewDelete().
TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
Returning("?", bun.Ident("id"))
+ if len(types) > 0 {
+ q = q.Where("? IN (?)", bun.Ident("notification.notification_type"), bun.In(types))
+ }
+
if targetAccountID != "" {
q = q.Where("? = ?", bun.Ident("notification.target_account_id"), targetAccountID)
}
@@ -153,7 +164,7 @@ func (n *notificationDB) DeleteNotifications(ctx context.Context, targetAccountI
func (n *notificationDB) DeleteNotificationsForStatus(ctx context.Context, statusID string) db.Error {
// Capture notification IDs in a RETURNING statement.
- ids := []string{}
+ var ids []string
q := n.conn.
NewDelete().
diff --git a/internal/db/bundb/notification_test.go b/internal/db/bundb/notification_test.go
index 117fc329c..bdee911b3 100644
--- a/internal/db/bundb/notification_test.go
+++ b/internal/db/bundb/notification_test.go
@@ -85,11 +85,11 @@ type NotificationTestSuite struct {
BunDBStandardTestSuite
}
-func (suite *NotificationTestSuite) TestGetNotificationsWithSpam() {
+func (suite *NotificationTestSuite) TestGetAccountNotificationsWithSpam() {
suite.spamNotifs()
testAccount := suite.testAccounts["local_account_1"]
before := time.Now()
- notifications, err := suite.db.GetNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
+ notifications, err := suite.db.GetAccountNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
suite.NoError(err)
timeTaken := time.Since(before)
fmt.Printf("\n\n\n withSpam: got %d notifications in %s\n\n\n", len(notifications), timeTaken)
@@ -100,10 +100,10 @@ func (suite *NotificationTestSuite) TestGetNotificationsWithSpam() {
}
}
-func (suite *NotificationTestSuite) TestGetNotificationsWithoutSpam() {
+func (suite *NotificationTestSuite) TestGetAccountNotificationsWithoutSpam() {
testAccount := suite.testAccounts["local_account_1"]
before := time.Now()
- notifications, err := suite.db.GetNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
+ notifications, err := suite.db.GetAccountNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
suite.NoError(err)
timeTaken := time.Since(before)
fmt.Printf("\n\n\n withoutSpam: got %d notifications in %s\n\n\n", len(notifications), timeTaken)
@@ -117,10 +117,10 @@ func (suite *NotificationTestSuite) TestGetNotificationsWithoutSpam() {
func (suite *NotificationTestSuite) TestDeleteNotificationsWithSpam() {
suite.spamNotifs()
testAccount := suite.testAccounts["local_account_1"]
- err := suite.db.DeleteNotifications(context.Background(), testAccount.ID, "")
+ err := suite.db.DeleteNotifications(context.Background(), nil, testAccount.ID, "")
suite.NoError(err)
- notifications, err := suite.db.GetNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
+ notifications, err := suite.db.GetAccountNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
suite.NoError(err)
suite.NotNil(notifications)
suite.Empty(notifications)
@@ -129,10 +129,10 @@ func (suite *NotificationTestSuite) TestDeleteNotificationsWithSpam() {
func (suite *NotificationTestSuite) TestDeleteNotificationsWithTwoAccounts() {
suite.spamNotifs()
testAccount := suite.testAccounts["local_account_1"]
- err := suite.db.DeleteNotifications(context.Background(), testAccount.ID, "")
+ err := suite.db.DeleteNotifications(context.Background(), nil, testAccount.ID, "")
suite.NoError(err)
- notifications, err := suite.db.GetNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
+ notifications, err := suite.db.GetAccountNotifications(context.Background(), testAccount.ID, []string{}, 20, id.Highest, id.Lowest)
suite.NoError(err)
suite.NotNil(notifications)
suite.Empty(notifications)
@@ -146,7 +146,7 @@ func (suite *NotificationTestSuite) TestDeleteNotificationsWithTwoAccounts() {
func (suite *NotificationTestSuite) TestDeleteNotificationsOriginatingFromAccount() {
testAccount := suite.testAccounts["local_account_2"]
- if err := suite.db.DeleteNotifications(context.Background(), "", testAccount.ID); err != nil {
+ if err := suite.db.DeleteNotifications(context.Background(), nil, "", testAccount.ID); err != nil {
suite.FailNow(err.Error())
}
@@ -166,7 +166,7 @@ func (suite *NotificationTestSuite) TestDeleteNotificationsOriginatingFromAndTar
originAccount := suite.testAccounts["local_account_2"]
targetAccount := suite.testAccounts["admin_account"]
- if err := suite.db.DeleteNotifications(context.Background(), targetAccount.ID, originAccount.ID); err != nil {
+ if err := suite.db.DeleteNotifications(context.Background(), nil, targetAccount.ID, originAccount.ID); err != nil {
suite.FailNow(err.Error())
}
diff --git a/internal/db/bundb/relationship.go b/internal/db/bundb/relationship.go
index 21a29b5dc..82559a213 100644
--- a/internal/db/bundb/relationship.go
+++ b/internal/db/bundb/relationship.go
@@ -23,8 +23,8 @@
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/uptrace/bun"
)
@@ -34,603 +34,212 @@ type relationshipDB struct {
state *state.State
}
-func (r *relationshipDB) IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, db.Error) {
- // Look for a block in direction of account1->account2
- block1, err := r.getBlock(ctx, account1, account2)
- if err != nil && !errors.Is(err, db.ErrNoEntries) {
- return false, err
- }
-
- if block1 != nil {
- // account1 blocks account2
- return true, nil
- } else if !eitherDirection {
- // Don't check for mutli-directional
- return false, nil
- }
-
- // Look for a block in direction of account2->account1
- block2, err := r.getBlock(ctx, account2, account1)
- if err != nil && !errors.Is(err, db.ErrNoEntries) {
- return false, err
- }
-
- return (block2 != nil), nil
-}
-
-func (r *relationshipDB) GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) {
- // Fetch block from database
- block, err := r.getBlock(ctx, account1, account2)
- if err != nil {
- return nil, err
- }
-
- // Set the block originating account
- block.Account, err = r.state.DB.GetAccountByID(ctx, block.AccountID)
- if err != nil {
- return nil, err
- }
-
- // Set the block target account
- block.TargetAccount, err = r.state.DB.GetAccountByID(ctx, block.TargetAccountID)
- if err != nil {
- return nil, err
- }
-
- return block, nil
-}
-
-func (r *relationshipDB) getBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, db.Error) {
- return r.state.Caches.GTS.Block().Load("AccountID.TargetAccountID", func() (*gtsmodel.Block, error) {
- var block gtsmodel.Block
-
- q := r.conn.NewSelect().Model(&block).
- Where("? = ?", bun.Ident("block.account_id"), account1).
- Where("? = ?", bun.Ident("block.target_account_id"), account2)
- if err := q.Scan(ctx); err != nil {
- return nil, r.conn.ProcessError(err)
- }
-
- return &block, nil
- }, account1, account2)
-}
-
-func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) db.Error {
- return r.state.Caches.GTS.Block().Store(block, func() error {
- _, err := r.conn.NewInsert().Model(block).Exec(ctx)
- return r.conn.ProcessError(err)
- })
-}
-
-func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) db.Error {
- if _, err := r.conn.
- NewDelete().
- TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
- Where("? = ?", bun.Ident("block.id"), id).
- Exec(ctx); err != nil {
- return r.conn.ProcessError(err)
- }
-
- // Drop any old value from cache by this ID
- r.state.Caches.GTS.Block().Invalidate("ID", id)
- return nil
-}
-
-func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) db.Error {
- if _, err := r.conn.
- NewDelete().
- TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
- Where("? = ?", bun.Ident("block.uri"), uri).
- Exec(ctx); err != nil {
- return r.conn.ProcessError(err)
- }
-
- // Drop any old value from cache by this URI
- r.state.Caches.GTS.Block().Invalidate("URI", uri)
- return nil
-}
-
-func (r *relationshipDB) DeleteBlocksByOriginAccountID(ctx context.Context, originAccountID string) db.Error {
- blockIDs := []string{}
-
- q := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
- Column("block.id").
- Where("? = ?", bun.Ident("block.account_id"), originAccountID)
-
- if err := q.Scan(ctx, &blockIDs); err != nil {
- return r.conn.ProcessError(err)
- }
-
- for _, blockID := range blockIDs {
- if err := r.DeleteBlockByID(ctx, blockID); err != nil {
- return err
- }
- }
-
- return nil
-}
-
-func (r *relationshipDB) DeleteBlocksByTargetAccountID(ctx context.Context, targetAccountID string) db.Error {
- blockIDs := []string{}
-
- q := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("blocks"), bun.Ident("block")).
- Column("block.id").
- Where("? = ?", bun.Ident("block.target_account_id"), targetAccountID)
-
- if err := q.Scan(ctx, &blockIDs); err != nil {
- return r.conn.ProcessError(err)
- }
-
- for _, blockID := range blockIDs {
- if err := r.DeleteBlockByID(ctx, blockID); err != nil {
- return err
- }
- }
-
- return nil
-}
-
func (r *relationshipDB) GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) {
- rel := >smodel.Relationship{
- ID: targetAccount,
+ var rel gtsmodel.Relationship
+ rel.ID = targetAccount
+
+ // check if the requesting follows the target
+ follow, err := r.GetFollow(
+ gtscontext.SetBarebones(ctx),
+ requestingAccount,
+ targetAccount,
+ )
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return nil, fmt.Errorf("GetRelationship: error fetching follow: %w", err)
}
- // check if the requesting account follows the target account
- follow := >smodel.Follow{}
- if err := r.conn.
- NewSelect().
- Model(follow).
- Column("follow.show_reblogs", "follow.notify").
- Where("? = ?", bun.Ident("follow.account_id"), requestingAccount).
- Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount).
- Limit(1).
- Scan(ctx); err != nil {
- if err := r.conn.ProcessError(err); err != db.ErrNoEntries {
- return nil, fmt.Errorf("GetRelationship: error fetching follow: %s", err)
- }
- // no follow exists so these are all false
- rel.Following = false
- rel.ShowingReblogs = false
- rel.Notifying = false
- } else {
+ if follow != nil {
// follow exists so we can fill these fields out...
rel.Following = true
rel.ShowingReblogs = *follow.ShowReblogs
rel.Notifying = *follow.Notify
}
- // check if the target account follows the requesting account
- followedByQ := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
- Column("follow.id").
- Where("? = ?", bun.Ident("follow.account_id"), targetAccount).
- Where("? = ?", bun.Ident("follow.target_account_id"), requestingAccount)
- followedBy, err := r.conn.Exists(ctx, followedByQ)
+ // check if the target follows the requesting
+ rel.FollowedBy, err = r.IsFollowing(ctx,
+ targetAccount,
+ requestingAccount,
+ )
if err != nil {
- return nil, fmt.Errorf("GetRelationship: error checking followedBy: %s", err)
+ return nil, fmt.Errorf("GetRelationship: error checking followedBy: %w", err)
}
- rel.FollowedBy = followedBy
- // check if there's a pending following request from requesting account to target account
- requestedQ := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
- Column("follow_request.id").
- Where("? = ?", bun.Ident("follow_request.account_id"), requestingAccount).
- Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount)
- requested, err := r.conn.Exists(ctx, requestedQ)
+ // check if requesting has follow requested target
+ rel.Requested, err = r.IsFollowRequested(ctx,
+ requestingAccount,
+ targetAccount,
+ )
if err != nil {
- return nil, fmt.Errorf("GetRelationship: error checking requested: %s", err)
+ return nil, fmt.Errorf("GetRelationship: error checking requested: %w", err)
}
- rel.Requested = requested
// check if the requesting account is blocking the target account
- blockA2T, err := r.getBlock(ctx, requestingAccount, targetAccount)
- if err != nil && !errors.Is(err, db.ErrNoEntries) {
- return nil, fmt.Errorf("GetRelationship: error checking blocking: %s", err)
+ rel.Blocking, err = r.IsBlocked(ctx, requestingAccount, targetAccount)
+ if err != nil {
+ return nil, fmt.Errorf("GetRelationship: error checking blocking: %w", err)
}
- rel.Blocking = (blockA2T != nil)
// check if the requesting account is blocked by the target account
- blockT2A, err := r.getBlock(ctx, targetAccount, requestingAccount)
- if err != nil && !errors.Is(err, db.ErrNoEntries) {
- return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %s", err)
- }
- rel.BlockedBy = (blockT2A != nil)
-
- return rel, nil
-}
-
-func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
- if sourceAccount == nil || targetAccount == nil {
- return false, nil
- }
-
- q := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
- Column("follow.id").
- Where("? = ?", bun.Ident("follow.account_id"), sourceAccount.ID).
- Where("? = ?", bun.Ident("follow.target_account_id"), targetAccount.ID)
-
- return r.conn.Exists(ctx, q)
-}
-
-func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
- if sourceAccount == nil || targetAccount == nil {
- return false, nil
- }
-
- q := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
- Column("follow_request.id").
- Where("? = ?", bun.Ident("follow_request.account_id"), sourceAccount.ID).
- Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccount.ID)
-
- return r.conn.Exists(ctx, q)
-}
-
-func (r *relationshipDB) IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
- if account1 == nil || account2 == nil {
- return false, nil
- }
-
- // make sure account 1 follows account 2
- f1, err := r.IsFollowing(ctx, account1, account2)
+ rel.BlockedBy, err = r.IsBlocked(ctx, targetAccount, requestingAccount)
if err != nil {
- return false, err
+ return nil, fmt.Errorf("GetRelationship: error checking blockedBy: %w", err)
}
- // make sure account 2 follows account 1
- f2, err := r.IsFollowing(ctx, account2, account1)
- if err != nil {
- return false, err
- }
-
- return f1 && f2, nil
+ return &rel, nil
}
-func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
- // Get original follow request.
- var followRequestID string
- if err := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
- Column("follow_request.id").
- Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
- Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
- Scan(ctx, &followRequestID); err != nil {
+func (r *relationshipDB) GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
+ var followIDs []string
+ if err := newSelectFollows(r.conn, accountID).
+ Scan(ctx, &followIDs); err != nil {
return nil, r.conn.ProcessError(err)
}
-
- followRequest, err := r.getFollowRequest(ctx, followRequestID)
- if err != nil {
- return nil, r.conn.ProcessError(err)
- }
-
- // Create a new follow to 'replace'
- // the original follow request with.
- follow := >smodel.Follow{
- ID: followRequest.ID,
- AccountID: originAccountID,
- Account: followRequest.Account,
- TargetAccountID: targetAccountID,
- TargetAccount: followRequest.TargetAccount,
- URI: followRequest.URI,
- }
-
- // If the follow already exists, just
- // replace the URI with the new one.
- if _, err := r.conn.
- NewInsert().
- Model(follow).
- On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI).
- Exec(ctx); err != nil {
- return nil, r.conn.ProcessError(err)
- }
-
- // Delete original follow request.
- if _, err := r.conn.
- NewDelete().
- TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
- Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID).
- Exec(ctx); err != nil {
- return nil, r.conn.ProcessError(err)
- }
-
- // Delete original follow request notification.
- if err := r.deleteFollowRequestNotif(ctx, originAccountID, targetAccountID); err != nil {
- return nil, err
- }
-
- // return the new follow
- return follow, nil
+ return r.GetFollowsByIDs(ctx, followIDs)
}
-func (r *relationshipDB) RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, db.Error) {
- // Get original follow request.
- var followRequestID string
- if err := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
- Column("follow_request.id").
- Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
- Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
- Scan(ctx, &followRequestID); err != nil {
+func (r *relationshipDB) GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
+ var followIDs []string
+ if err := newSelectLocalFollows(r.conn, accountID).
+ Scan(ctx, &followIDs); err != nil {
return nil, r.conn.ProcessError(err)
}
-
- followRequest, err := r.getFollowRequest(ctx, followRequestID)
- if err != nil {
- return nil, r.conn.ProcessError(err)
- }
-
- // Delete original follow request.
- if _, err := r.conn.
- NewDelete().
- TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
- Where("? = ?", bun.Ident("follow_request.id"), followRequest.ID).
- Exec(ctx); err != nil {
- return nil, r.conn.ProcessError(err)
- }
-
- // Delete original follow request notification.
- if err := r.deleteFollowRequestNotif(ctx, originAccountID, targetAccountID); err != nil {
- return nil, err
- }
-
- // Return the now deleted follow request.
- return followRequest, nil
+ return r.GetFollowsByIDs(ctx, followIDs)
}
-func (r *relationshipDB) deleteFollowRequestNotif(ctx context.Context, originAccountID string, targetAccountID string) db.Error {
- var id string
- if err := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
- Column("notification.id").
- Where("? = ?", bun.Ident("notification.origin_account_id"), originAccountID).
- Where("? = ?", bun.Ident("notification.target_account_id"), targetAccountID).
- Where("? = ?", bun.Ident("notification.notification_type"), gtsmodel.NotificationFollowRequest).
- Limit(1). // There should only be one!
- Scan(ctx, &id); err != nil {
- err = r.conn.ProcessError(err)
- if errors.Is(err, db.ErrNoEntries) {
- // If no entries, the notif didn't
- // exist anyway so nothing to do here.
- return nil
- }
- // Return on real error.
- return err
- }
-
- return r.state.DB.DeleteNotification(ctx, id)
-}
-
-func (r *relationshipDB) getFollow(ctx context.Context, id string) (*gtsmodel.Follow, db.Error) {
- follow := >smodel.Follow{}
-
- err := r.conn.
- NewSelect().
- Model(follow).
- Where("? = ?", bun.Ident("follow.id"), id).
- Scan(ctx)
- if err != nil {
+func (r *relationshipDB) GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
+ var followIDs []string
+ if err := newSelectFollowers(r.conn, accountID).
+ Scan(ctx, &followIDs); err != nil {
return nil, r.conn.ProcessError(err)
}
-
- follow.Account, err = r.state.DB.GetAccountByID(ctx, follow.AccountID)
- if err != nil {
- log.Errorf(ctx, "error getting follow account %q: %v", follow.AccountID, err)
- }
-
- follow.TargetAccount, err = r.state.DB.GetAccountByID(ctx, follow.TargetAccountID)
- if err != nil {
- log.Errorf(ctx, "error getting follow target account %q: %v", follow.TargetAccountID, err)
- }
-
- return follow, nil
+ return r.GetFollowsByIDs(ctx, followIDs)
}
-func (r *relationshipDB) GetLocalFollowersIDs(ctx context.Context, targetAccountID string) ([]string, db.Error) {
- accountIDs := []string{}
+func (r *relationshipDB) GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error) {
+ var followIDs []string
+ if err := newSelectLocalFollowers(r.conn, accountID).
+ Scan(ctx, &followIDs); err != nil {
+ return nil, r.conn.ProcessError(err)
+ }
+ return r.GetFollowsByIDs(ctx, followIDs)
+}
- // Select only the account ID of each follow.
- q := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
- ColumnExpr("? AS ?", bun.Ident("follow.account_id"), bun.Ident("account_id")).
- Where("? = ?", bun.Ident("follow.target_account_id"), targetAccountID)
+func (r *relationshipDB) CountAccountFollows(ctx context.Context, accountID string) (int, error) {
+ n, err := newSelectFollows(r.conn, accountID).Count(ctx)
+ return n, r.conn.ProcessError(err)
+}
- // Join on accounts table to select only
- // those with NULL domain (local accounts).
- q = q.
- Join("JOIN ? AS ? ON ? = ?",
- bun.Ident("accounts"),
- bun.Ident("account"),
- bun.Ident("follow.account_id"),
- bun.Ident("account.id"),
+func (r *relationshipDB) CountAccountLocalFollows(ctx context.Context, accountID string) (int, error) {
+ n, err := newSelectLocalFollows(r.conn, accountID).Count(ctx)
+ return n, r.conn.ProcessError(err)
+}
+
+func (r *relationshipDB) CountAccountFollowers(ctx context.Context, accountID string) (int, error) {
+ n, err := newSelectFollowers(r.conn, accountID).Count(ctx)
+ return n, r.conn.ProcessError(err)
+}
+
+func (r *relationshipDB) CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error) {
+ n, err := newSelectLocalFollowers(r.conn, accountID).Count(ctx)
+ return n, r.conn.ProcessError(err)
+}
+
+func (r *relationshipDB) GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) {
+ var followReqIDs []string
+ if err := newSelectFollowRequests(r.conn, accountID).
+ Scan(ctx, &followReqIDs); err != nil {
+ return nil, r.conn.ProcessError(err)
+ }
+ return r.GetFollowRequestsByIDs(ctx, followReqIDs)
+}
+
+func (r *relationshipDB) GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error) {
+ var followReqIDs []string
+ if err := newSelectFollowRequesting(r.conn, accountID).
+ Scan(ctx, &followReqIDs); err != nil {
+ return nil, r.conn.ProcessError(err)
+ }
+ return r.GetFollowRequestsByIDs(ctx, followReqIDs)
+}
+
+func (r *relationshipDB) CountAccountFollowRequests(ctx context.Context, accountID string) (int, error) {
+ n, err := newSelectFollowRequests(r.conn, accountID).Count(ctx)
+ return n, r.conn.ProcessError(err)
+}
+
+func (r *relationshipDB) CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error) {
+ n, err := newSelectFollowRequesting(r.conn, accountID).Count(ctx)
+ return n, r.conn.ProcessError(err)
+}
+
+// newSelectFollowRequests returns a new select query for all rows in the follow_requests table with target_account_id = accountID.
+func newSelectFollowRequests(conn *DBConn, accountID string) *bun.SelectQuery {
+ return conn.NewSelect().
+ TableExpr("?", bun.Ident("follow_requests")).
+ ColumnExpr("?", bun.Ident("id")).
+ Where("? = ?", bun.Ident("target_account_id"), accountID).
+ OrderExpr("? DESC", bun.Ident("updated_at"))
+}
+
+// newSelectFollowRequesting returns a new select query for all rows in the follow_requests table with account_id = accountID.
+func newSelectFollowRequesting(conn *DBConn, accountID string) *bun.SelectQuery {
+ return conn.NewSelect().
+ TableExpr("?", bun.Ident("follow_requests")).
+ ColumnExpr("?", bun.Ident("id")).
+ Where("? = ?", bun.Ident("target_account_id"), accountID).
+ OrderExpr("? DESC", bun.Ident("updated_at"))
+}
+
+// newSelectFollows returns a new select query for all rows in the follows table with account_id = accountID.
+func newSelectFollows(conn *DBConn, accountID string) *bun.SelectQuery {
+ return conn.NewSelect().
+ Table("follows").
+ Column("id").
+ Where("? = ?", bun.Ident("account_id"), accountID).
+ OrderExpr("? DESC", bun.Ident("updated_at"))
+}
+
+// newSelectLocalFollows returns a new select query for all rows in the follows table with
+// account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local).
+func newSelectLocalFollows(conn *DBConn, accountID string) *bun.SelectQuery {
+ return conn.NewSelect().
+ Table("follows").
+ Column("id").
+ Where("? = ? AND ? IN (?)",
+ bun.Ident("account_id"),
+ accountID,
+ bun.Ident("target_account_id"),
+ conn.NewSelect().
+ Table("accounts").
+ Column("id").
+ Where("? IS NULL", bun.Ident("domain")),
).
- Where("? IS NULL", bun.Ident("account.domain"))
-
- // We don't *really* need to order these,
- // but it makes it more consistent to do so.
- q = q.Order("account_id DESC")
-
- if err := q.Scan(ctx, &accountIDs); err != nil {
- return nil, r.conn.ProcessError(err)
- }
-
- return accountIDs, nil
+ OrderExpr("? DESC", bun.Ident("updated_at"))
}
-func (r *relationshipDB) GetFollows(ctx context.Context, accountID string, targetAccountID string) ([]*gtsmodel.Follow, db.Error) {
- ids := []string{}
-
- q := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
- Column("follow.id").
- Order("follow.updated_at DESC")
-
- if accountID != "" {
- q = q.Where("? = ?", bun.Ident("follow.account_id"), accountID)
- }
-
- if targetAccountID != "" {
- q = q.Where("? = ?", bun.Ident("follow.target_account_id"), targetAccountID)
- }
-
- if err := q.Scan(ctx, &ids); err != nil {
- return nil, r.conn.ProcessError(err)
- }
-
- follows := make([]*gtsmodel.Follow, 0, len(ids))
- for _, id := range ids {
- follow, err := r.getFollow(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error getting follow %q: %v", id, err)
- continue
- }
-
- follows = append(follows, follow)
- }
-
- return follows, nil
+// newSelectFollowers returns a new select query for all rows in the follows table with target_account_id = accountID.
+func newSelectFollowers(conn *DBConn, accountID string) *bun.SelectQuery {
+ return conn.NewSelect().
+ Table("follows").
+ Column("id").
+ Where("? = ?", bun.Ident("target_account_id"), accountID).
+ OrderExpr("? DESC", bun.Ident("updated_at"))
}
-func (r *relationshipDB) CountFollows(ctx context.Context, accountID string, targetAccountID string) (int, db.Error) {
- q := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
- Column("follow.id")
-
- if accountID != "" {
- q = q.Where("? = ?", bun.Ident("follow.account_id"), accountID)
- }
-
- if targetAccountID != "" {
- q = q.Where("? = ?", bun.Ident("follow.target_account_id"), targetAccountID)
- }
-
- return q.Count(ctx)
-}
-
-func (r *relationshipDB) getFollowRequest(ctx context.Context, id string) (*gtsmodel.FollowRequest, db.Error) {
- followRequest := >smodel.FollowRequest{}
-
- err := r.conn.
- NewSelect().
- Model(followRequest).
- Where("? = ?", bun.Ident("follow_request.id"), id).
- Scan(ctx)
- if err != nil {
- return nil, r.conn.ProcessError(err)
- }
-
- followRequest.Account, err = r.state.DB.GetAccountByID(ctx, followRequest.AccountID)
- if err != nil {
- log.Errorf(ctx, "error getting follow request account %q: %v", followRequest.AccountID, err)
- }
-
- followRequest.TargetAccount, err = r.state.DB.GetAccountByID(ctx, followRequest.TargetAccountID)
- if err != nil {
- log.Errorf(ctx, "error getting follow request target account %q: %v", followRequest.TargetAccountID, err)
- }
-
- return followRequest, nil
-}
-
-func (r *relationshipDB) GetFollowRequests(ctx context.Context, accountID string, targetAccountID string) ([]*gtsmodel.FollowRequest, db.Error) {
- ids := []string{}
-
- q := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
- Column("follow_request.id")
-
- if accountID != "" {
- q = q.Where("? = ?", bun.Ident("follow_request.account_id"), accountID)
- }
-
- if targetAccountID != "" {
- q = q.Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID)
- }
-
- if err := q.Scan(ctx, &ids); err != nil {
- return nil, r.conn.ProcessError(err)
- }
-
- followRequests := make([]*gtsmodel.FollowRequest, 0, len(ids))
- for _, id := range ids {
- followRequest, err := r.getFollowRequest(ctx, id)
- if err != nil {
- log.Errorf(ctx, "error getting follow request %q: %v", id, err)
- continue
- }
-
- followRequests = append(followRequests, followRequest)
- }
-
- return followRequests, nil
-}
-
-func (r *relationshipDB) CountFollowRequests(ctx context.Context, accountID string, targetAccountID string) (int, db.Error) {
- q := r.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
- Column("follow_request.id").
- Order("follow_request.updated_at DESC")
-
- if accountID != "" {
- q = q.Where("? = ?", bun.Ident("follow_request.account_id"), accountID)
- }
-
- if targetAccountID != "" {
- q = q.Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID)
- }
-
- return q.Count(ctx)
-}
-
-func (r *relationshipDB) Unfollow(ctx context.Context, originAccountID string, targetAccountID string) (string, db.Error) {
- uri := new(string)
-
- _, err := r.conn.
- NewDelete().
- TableExpr("? AS ?", bun.Ident("follows"), bun.Ident("follow")).
- Where("? = ?", bun.Ident("follow.target_account_id"), targetAccountID).
- Where("? = ?", bun.Ident("follow.account_id"), originAccountID).
- Returning("?", bun.Ident("uri")).Exec(ctx, uri)
-
- // Only return proper errors.
- if err = r.conn.ProcessError(err); err != db.ErrNoEntries {
- return *uri, err
- }
-
- return *uri, nil
-}
-
-func (r *relationshipDB) UnfollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (string, db.Error) {
- uri := new(string)
-
- _, err := r.conn.
- NewDelete().
- TableExpr("? AS ?", bun.Ident("follow_requests"), bun.Ident("follow_request")).
- Where("? = ?", bun.Ident("follow_request.target_account_id"), targetAccountID).
- Where("? = ?", bun.Ident("follow_request.account_id"), originAccountID).
- Returning("?", bun.Ident("uri")).Exec(ctx, uri)
-
- // Only return proper errors.
- if err = r.conn.ProcessError(err); err != db.ErrNoEntries {
- return *uri, err
- }
-
- return *uri, nil
+// newSelectLocalFollowers returns a new select query for all rows in the follows table with
+// target_account_id = accountID where the corresponding account ID has a NULL domain (i.e. is local).
+func newSelectLocalFollowers(conn *DBConn, accountID string) *bun.SelectQuery {
+ return conn.NewSelect().
+ Table("follows").
+ Column("id").
+ Where("? = ? AND ? IN (?)",
+ bun.Ident("target_account_id"),
+ accountID,
+ bun.Ident("account_id"),
+ conn.NewSelect().
+ Table("accounts").
+ Column("id").
+ Where("? IS NULL", bun.Ident("domain")),
+ ).
+ OrderExpr("? DESC", bun.Ident("updated_at"))
}
diff --git a/internal/db/bundb/relationship_block.go b/internal/db/bundb/relationship_block.go
new file mode 100644
index 000000000..9232ea984
--- /dev/null
+++ b/internal/db/bundb/relationship_block.go
@@ -0,0 +1,218 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package bundb
+
+import (
+ "context"
+ "errors"
+ "fmt"
+
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/uptrace/bun"
+)
+
+func (r *relationshipDB) IsBlocked(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) {
+ block, err := r.GetBlock(
+ gtscontext.SetBarebones(ctx),
+ sourceAccountID,
+ targetAccountID,
+ )
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return false, err
+ }
+ return (block != nil), nil
+}
+
+func (r *relationshipDB) IsEitherBlocked(ctx context.Context, accountID1 string, accountID2 string) (bool, error) {
+ // Look for a block in direction of account1->account2
+ b1, err := r.IsBlocked(ctx, accountID1, accountID2)
+ if err != nil || b1 {
+ return true, err
+ }
+
+ // Look for a block in direction of account2->account1
+ b2, err := r.IsBlocked(ctx, accountID2, accountID1)
+ if err != nil || b2 {
+ return true, err
+ }
+
+ return false, nil
+}
+
+func (r *relationshipDB) GetBlockByID(ctx context.Context, id string) (*gtsmodel.Block, error) {
+ return r.getBlock(
+ ctx,
+ "ID",
+ func(block *gtsmodel.Block) error {
+ return r.conn.NewSelect().Model(block).
+ Where("? = ?", bun.Ident("block.id"), id).
+ Scan(ctx)
+ },
+ id,
+ )
+}
+
+func (r *relationshipDB) GetBlockByURI(ctx context.Context, uri string) (*gtsmodel.Block, error) {
+ return r.getBlock(
+ ctx,
+ "URI",
+ func(block *gtsmodel.Block) error {
+ return r.conn.NewSelect().Model(block).
+ Where("? = ?", bun.Ident("block.uri"), uri).
+ Scan(ctx)
+ },
+ uri,
+ )
+}
+
+func (r *relationshipDB) GetBlock(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Block, error) {
+ return r.getBlock(
+ ctx,
+ "AccountID.TargetAccountID",
+ func(block *gtsmodel.Block) error {
+ return r.conn.NewSelect().Model(block).
+ Where("? = ?", bun.Ident("block.account_id"), sourceAccountID).
+ Where("? = ?", bun.Ident("block.target_account_id"), targetAccountID).
+ Scan(ctx)
+ },
+ sourceAccountID,
+ targetAccountID,
+ )
+}
+
+func (r *relationshipDB) getBlock(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Block) error, keyParts ...any) (*gtsmodel.Block, error) {
+ // Fetch block from cache with loader callback
+ block, err := r.state.Caches.GTS.Block().Load(lookup, func() (*gtsmodel.Block, error) {
+ var block gtsmodel.Block
+
+ // Not cached! Perform database query
+ if err := dbQuery(&block); err != nil {
+ return nil, r.conn.ProcessError(err)
+ }
+
+ return &block, nil
+ }, keyParts...)
+ if err != nil {
+ // already processe
+ return nil, err
+ }
+
+ if gtscontext.Barebones(ctx) {
+ // Only a barebones model was requested.
+ return block, nil
+ }
+
+ // Set the block source account
+ block.Account, err = r.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ block.AccountID,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("error getting block source account: %w", err)
+ }
+
+ // Set the block target account
+ block.TargetAccount, err = r.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ block.TargetAccountID,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("error getting block target account: %w", err)
+ }
+
+ return block, nil
+}
+
+func (r *relationshipDB) PutBlock(ctx context.Context, block *gtsmodel.Block) error {
+ err := r.state.Caches.GTS.Block().Store(block, func() error {
+ _, err := r.conn.NewInsert().Model(block).Exec(ctx)
+ return r.conn.ProcessError(err)
+ })
+ if err != nil {
+ return err
+ }
+
+ // Invalidate block origin account ID cached visibility.
+ r.state.Caches.Visibility.Invalidate("ItemID", block.AccountID)
+ r.state.Caches.Visibility.Invalidate("RequesterID", block.AccountID)
+
+ // Invalidate block target account ID cached visibility.
+ r.state.Caches.Visibility.Invalidate("ItemID", block.TargetAccountID)
+ r.state.Caches.Visibility.Invalidate("RequesterID", block.TargetAccountID)
+
+ return nil
+}
+
+func (r *relationshipDB) DeleteBlockByID(ctx context.Context, id string) error {
+ block, err := r.GetBlockByID(gtscontext.SetBarebones(ctx), id)
+ if err != nil {
+ return err
+ }
+ return r.deleteBlock(ctx, block)
+}
+
+func (r *relationshipDB) DeleteBlockByURI(ctx context.Context, uri string) error {
+ block, err := r.GetBlockByURI(gtscontext.SetBarebones(ctx), uri)
+ if err != nil {
+ return err
+ }
+ return r.deleteBlock(ctx, block)
+}
+
+func (r *relationshipDB) deleteBlock(ctx context.Context, block *gtsmodel.Block) error {
+ if _, err := r.conn.
+ NewDelete().
+ Table("blocks").
+ Where("? = ?", bun.Ident("id"), block.ID).
+ Exec(ctx); err != nil {
+ return r.conn.ProcessError(err)
+ }
+
+ // Invalidate block from cache lookups.
+ r.state.Caches.GTS.Block().Invalidate("ID", block.ID)
+
+ return nil
+}
+
+func (r *relationshipDB) DeleteAccountBlocks(ctx context.Context, accountID string) error {
+ var blockIDs []string
+
+ if err := r.conn.NewSelect().
+ Table("blocks").
+ ColumnExpr("?", bun.Ident("id")).
+ WhereOr("? = ? OR ? = ?",
+ bun.Ident("account_id"),
+ accountID,
+ bun.Ident("target_account_id"),
+ accountID,
+ ).
+ Scan(ctx, &blockIDs); err != nil {
+ return r.conn.ProcessError(err)
+ }
+
+ for _, id := range blockIDs {
+ if err := r.DeleteBlockByID(ctx, id); err != nil {
+ log.Errorf(ctx, "error deleting block %q: %v", id, err)
+ }
+ }
+
+ return nil
+}
diff --git a/internal/db/bundb/relationship_follow.go b/internal/db/bundb/relationship_follow.go
new file mode 100644
index 000000000..4a315d116
--- /dev/null
+++ b/internal/db/bundb/relationship_follow.go
@@ -0,0 +1,243 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package bundb
+
+import (
+ "context"
+ "errors"
+ "fmt"
+
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/uptrace/bun"
+)
+
+func (r *relationshipDB) GetFollowByID(ctx context.Context, id string) (*gtsmodel.Follow, error) {
+ return r.getFollow(
+ ctx,
+ "ID",
+ func(follow *gtsmodel.Follow) error {
+ return r.conn.NewSelect().
+ Model(follow).
+ Where("? = ?", bun.Ident("id"), id).
+ Scan(ctx)
+ },
+ id,
+ )
+}
+
+func (r *relationshipDB) GetFollowByURI(ctx context.Context, uri string) (*gtsmodel.Follow, error) {
+ return r.getFollow(
+ ctx,
+ "URI",
+ func(follow *gtsmodel.Follow) error {
+ return r.conn.NewSelect().
+ Model(follow).
+ Where("? = ?", bun.Ident("uri"), uri).
+ Scan(ctx)
+ },
+ uri,
+ )
+}
+
+func (r *relationshipDB) GetFollow(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error) {
+ return r.getFollow(
+ ctx,
+ "AccountID.TargetAccountID",
+ func(follow *gtsmodel.Follow) error {
+ return r.conn.NewSelect().
+ Model(follow).
+ Where("? = ?", bun.Ident("account_id"), sourceAccountID).
+ Where("? = ?", bun.Ident("target_account_id"), targetAccountID).
+ Scan(ctx)
+ },
+ sourceAccountID,
+ targetAccountID,
+ )
+}
+
+func (r *relationshipDB) GetFollowsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.Follow, error) {
+ // Preallocate slice of expected length.
+ follows := make([]*gtsmodel.Follow, 0, len(ids))
+
+ for _, id := range ids {
+ // Fetch follow model for this ID.
+ follow, err := r.GetFollowByID(ctx, id)
+ if err != nil {
+ log.Errorf(ctx, "error getting follow %q: %v", id, err)
+ continue
+ }
+
+ // Append to return slice.
+ follows = append(follows, follow)
+ }
+
+ return follows, nil
+}
+
+func (r *relationshipDB) IsFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) {
+ follow, err := r.GetFollow(
+ gtscontext.SetBarebones(ctx),
+ sourceAccountID,
+ targetAccountID,
+ )
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return false, err
+ }
+ return (follow != nil), nil
+}
+
+func (r *relationshipDB) IsMutualFollowing(ctx context.Context, accountID1 string, accountID2 string) (bool, db.Error) {
+ // make sure account 1 follows account 2
+ f1, err := r.IsFollowing(ctx,
+ accountID1,
+ accountID2,
+ )
+ if !f1 /* f1 = false when err != nil */ {
+ return false, err
+ }
+
+ // make sure account 2 follows account 1
+ f2, err := r.IsFollowing(ctx,
+ accountID2,
+ accountID1,
+ )
+ if !f2 /* f2 = false when err != nil */ {
+ return false, err
+ }
+
+ return true, nil
+}
+
+func (r *relationshipDB) getFollow(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Follow) error, keyParts ...any) (*gtsmodel.Follow, error) {
+ // Fetch follow from database cache with loader callback
+ follow, err := r.state.Caches.GTS.Follow().Load(lookup, func() (*gtsmodel.Follow, error) {
+ var follow gtsmodel.Follow
+
+ // Not cached! Perform database query
+ if err := dbQuery(&follow); err != nil {
+ return nil, r.conn.ProcessError(err)
+ }
+
+ return &follow, nil
+ }, keyParts...)
+ if err != nil {
+ // error already processed
+ return nil, err
+ }
+
+ if gtscontext.Barebones(ctx) {
+ // Only a barebones model was requested.
+ return follow, nil
+ }
+
+ // Set the follow source account
+ follow.Account, err = r.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ follow.AccountID,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("error getting follow source account: %w", err)
+ }
+
+ // Set the follow target account
+ follow.TargetAccount, err = r.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ follow.TargetAccountID,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("error getting follow target account: %w", err)
+ }
+
+ return follow, nil
+}
+
+func (r *relationshipDB) PutFollow(ctx context.Context, follow *gtsmodel.Follow) error {
+ err := r.state.Caches.GTS.Follow().Store(follow, func() error {
+ _, err := r.conn.NewInsert().Model(follow).Exec(ctx)
+ return r.conn.ProcessError(err)
+ })
+ if err != nil {
+ return err
+ }
+
+ // Invalidate follow origin account ID cached visibility.
+ r.state.Caches.Visibility.Invalidate("ItemID", follow.AccountID)
+ r.state.Caches.Visibility.Invalidate("RequesterID", follow.AccountID)
+
+ // Invalidate follow target account ID cached visibility.
+ r.state.Caches.Visibility.Invalidate("ItemID", follow.TargetAccountID)
+ r.state.Caches.Visibility.Invalidate("RequesterID", follow.TargetAccountID)
+
+ return nil
+}
+
+func (r *relationshipDB) DeleteFollowByID(ctx context.Context, id string) error {
+ if _, err := r.conn.NewDelete().
+ Table("follows").
+ Where("? = ?", bun.Ident("id"), id).
+ Exec(ctx); err != nil {
+ return r.conn.ProcessError(err)
+ }
+
+ // Invalidate follow from cache lookups.
+ r.state.Caches.GTS.Follow().Invalidate("ID", id)
+
+ return nil
+}
+
+func (r *relationshipDB) DeleteFollowByURI(ctx context.Context, uri string) error {
+ if _, err := r.conn.NewDelete().
+ Table("follows").
+ Where("? = ?", bun.Ident("uri"), uri).
+ Exec(ctx); err != nil {
+ return r.conn.ProcessError(err)
+ }
+
+ // Invalidate follow from cache lookups.
+ r.state.Caches.GTS.Follow().Invalidate("URI", uri)
+
+ return nil
+}
+
+func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID string) error {
+ var followIDs []string
+
+ if _, err := r.conn.
+ NewDelete().
+ Table("follows").
+ WhereOr("? = ? OR ? = ?",
+ bun.Ident("account_id"),
+ accountID,
+ bun.Ident("target_account_id"),
+ accountID,
+ ).
+ Returning("?", bun.Ident("id")).
+ Exec(ctx, &followIDs); err != nil {
+ return r.conn.ProcessError(err)
+ }
+
+ // Invalidate each returned ID.
+ for _, id := range followIDs {
+ r.state.Caches.GTS.Follow().Invalidate("ID", id)
+ }
+
+ return nil
+}
diff --git a/internal/db/bundb/relationship_follow_req.go b/internal/db/bundb/relationship_follow_req.go
new file mode 100644
index 000000000..11200338d
--- /dev/null
+++ b/internal/db/bundb/relationship_follow_req.go
@@ -0,0 +1,293 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package bundb
+
+import (
+ "context"
+ "errors"
+ "fmt"
+
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/uptrace/bun"
+)
+
+func (r *relationshipDB) GetFollowRequestByID(ctx context.Context, id string) (*gtsmodel.FollowRequest, error) {
+ return r.getFollowRequest(
+ ctx,
+ "ID",
+ func(followReq *gtsmodel.FollowRequest) error {
+ return r.conn.NewSelect().
+ Model(followReq).
+ Where("? = ?", bun.Ident("id"), id).
+ Scan(ctx)
+ },
+ id,
+ )
+}
+
+func (r *relationshipDB) GetFollowRequestByURI(ctx context.Context, uri string) (*gtsmodel.FollowRequest, error) {
+ return r.getFollowRequest(
+ ctx,
+ "URI",
+ func(followReq *gtsmodel.FollowRequest) error {
+ return r.conn.NewSelect().
+ Model(followReq).
+ Where("? = ?", bun.Ident("uri"), uri).
+ Scan(ctx)
+ },
+ uri,
+ )
+}
+
+func (r *relationshipDB) GetFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, error) {
+ return r.getFollowRequest(
+ ctx,
+ "AccountID.TargetAccountID",
+ func(followReq *gtsmodel.FollowRequest) error {
+ return r.conn.NewSelect().
+ Model(followReq).
+ Where("? = ?", bun.Ident("account_id"), sourceAccountID).
+ Where("? = ?", bun.Ident("target_account_id"), targetAccountID).
+ Scan(ctx)
+ },
+ sourceAccountID,
+ targetAccountID,
+ )
+}
+
+func (r *relationshipDB) GetFollowRequestsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.FollowRequest, error) {
+ // Preallocate slice of expected length.
+ followReqs := make([]*gtsmodel.FollowRequest, 0, len(ids))
+
+ for _, id := range ids {
+ // Fetch follow request model for this ID.
+ followReq, err := r.GetFollowRequestByID(ctx, id)
+ if err != nil {
+ log.Errorf(ctx, "error getting follow request %q: %v", id, err)
+ continue
+ }
+
+ // Append to return slice.
+ followReqs = append(followReqs, followReq)
+ }
+
+ return followReqs, nil
+}
+
+func (r *relationshipDB) IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, db.Error) {
+ followReq, err := r.GetFollowRequest(
+ gtscontext.SetBarebones(ctx),
+ sourceAccountID,
+ targetAccountID,
+ )
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ return false, err
+ }
+ return (followReq != nil), nil
+}
+
+func (r *relationshipDB) getFollowRequest(ctx context.Context, lookup string, dbQuery func(*gtsmodel.FollowRequest) error, keyParts ...any) (*gtsmodel.FollowRequest, error) {
+ // Fetch follow request from database cache with loader callback
+ followReq, err := r.state.Caches.GTS.FollowRequest().Load(lookup, func() (*gtsmodel.FollowRequest, error) {
+ var followReq gtsmodel.FollowRequest
+
+ // Not cached! Perform database query
+ if err := dbQuery(&followReq); err != nil {
+ return nil, r.conn.ProcessError(err)
+ }
+
+ return &followReq, nil
+ }, keyParts...)
+ if err != nil {
+ // error already processed
+ return nil, err
+ }
+
+ if gtscontext.Barebones(ctx) {
+ // Only a barebones model was requested.
+ return followReq, nil
+ }
+
+ // Set the follow request source account
+ followReq.Account, err = r.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ followReq.AccountID,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("error getting follow request source account: %w", err)
+ }
+
+ // Set the follow request target account
+ followReq.TargetAccount, err = r.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ followReq.TargetAccountID,
+ )
+ if err != nil {
+ return nil, fmt.Errorf("error getting follow request target account: %w", err)
+ }
+
+ return followReq, nil
+}
+
+func (r *relationshipDB) PutFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest) error {
+ err := r.state.Caches.GTS.FollowRequest().Store(follow, func() error {
+ _, err := r.conn.NewInsert().Model(follow).Exec(ctx)
+ return r.conn.ProcessError(err)
+ })
+ if err != nil {
+ return err
+ }
+
+ // Invalidate follow request origin account ID cached visibility.
+ r.state.Caches.Visibility.Invalidate("ItemID", follow.AccountID)
+ r.state.Caches.Visibility.Invalidate("RequesterID", follow.AccountID)
+
+ // Invalidate follow request target account ID cached visibility.
+ r.state.Caches.Visibility.Invalidate("ItemID", follow.TargetAccountID)
+ r.state.Caches.Visibility.Invalidate("RequesterID", follow.TargetAccountID)
+
+ return nil
+}
+
+func (r *relationshipDB) AcceptFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
+ // Get original follow request.
+ followReq, err := r.GetFollowRequest(ctx, sourceAccountID, targetAccountID)
+ if err != nil {
+ return nil, err
+ }
+
+ // Create a new follow to 'replace'
+ // the original follow request with.
+ follow := >smodel.Follow{
+ ID: followReq.ID,
+ AccountID: sourceAccountID,
+ Account: followReq.Account,
+ TargetAccountID: targetAccountID,
+ TargetAccount: followReq.TargetAccount,
+ URI: followReq.URI,
+ }
+
+ // If the follow already exists, just
+ // replace the URI with the new one.
+ if _, err := r.conn.
+ NewInsert().
+ Model(follow).
+ On("CONFLICT (?,?) DO UPDATE set ? = ?", bun.Ident("account_id"), bun.Ident("target_account_id"), bun.Ident("uri"), follow.URI).
+ Exec(ctx); err != nil {
+ return nil, r.conn.ProcessError(err)
+ }
+
+ // Delete original follow request.
+ if _, err := r.conn.
+ NewDelete().
+ Table("follow_requests").
+ Where("? = ?", bun.Ident("id"), followReq.ID).
+ Exec(ctx); err != nil {
+ return nil, r.conn.ProcessError(err)
+ }
+
+ // Invalidate follow request from cache lookups.
+ r.state.Caches.GTS.FollowRequest().Invalidate("ID", followReq.ID)
+
+ // Delete original follow request notification
+ if err := r.state.DB.DeleteNotifications(ctx, []string{
+ string(gtsmodel.NotificationFollowRequest),
+ }, targetAccountID, sourceAccountID); err != nil {
+ return nil, err
+ }
+
+ return follow, nil
+}
+
+func (r *relationshipDB) RejectFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) db.Error {
+ // Get original follow request.
+ followReq, err := r.GetFollowRequest(ctx, sourceAccountID, targetAccountID)
+ if err != nil {
+ return err
+ }
+
+ // Delete original follow request.
+ if _, err := r.conn.
+ NewDelete().
+ Table("follow_requests").
+ Where("? = ?", bun.Ident("id"), followReq.ID).
+ Exec(ctx); err != nil {
+ return r.conn.ProcessError(err)
+ }
+
+ // Delete original follow request notification
+ return r.state.DB.DeleteNotifications(ctx, []string{
+ string(gtsmodel.NotificationFollowRequest),
+ }, targetAccountID, sourceAccountID)
+}
+
+func (r *relationshipDB) DeleteFollowRequestByID(ctx context.Context, id string) error {
+ if _, err := r.conn.NewDelete().
+ Table("follow_requests").
+ Where("? = ?", bun.Ident("id"), id).
+ Exec(ctx); err != nil {
+ return r.conn.ProcessError(err)
+ }
+
+ // Invalidate follow request from cache lookups.
+ r.state.Caches.GTS.FollowRequest().Invalidate("ID", id)
+
+ return nil
+}
+
+func (r *relationshipDB) DeleteFollowRequestByURI(ctx context.Context, uri string) error {
+ if _, err := r.conn.NewDelete().
+ Table("follow_requests").
+ Where("? = ?", bun.Ident("uri"), uri).
+ Exec(ctx); err != nil {
+ return r.conn.ProcessError(err)
+ }
+
+ // Invalidate follow request from cache lookups.
+ r.state.Caches.GTS.FollowRequest().Invalidate("URI", uri)
+
+ return nil
+}
+
+func (r *relationshipDB) DeleteAccountFollowRequests(ctx context.Context, accountID string) error {
+ var followIDs []string
+
+ if _, err := r.conn.
+ NewDelete().
+ Table("follow_requests").
+ WhereOr("? = ? OR ? = ?",
+ bun.Ident("account_id"),
+ accountID,
+ bun.Ident("target_account_id"),
+ accountID,
+ ).
+ Returning("?", bun.Ident("id")).
+ Exec(ctx, &followIDs); err != nil {
+ return r.conn.ProcessError(err)
+ }
+
+ // Invalidate each returned ID.
+ for _, id := range followIDs {
+ r.state.Caches.GTS.FollowRequest().Invalidate("ID", id)
+ }
+
+ return nil
+}
diff --git a/internal/db/bundb/relationship_test.go b/internal/db/bundb/relationship_test.go
index 3d307ecde..00583d175 100644
--- a/internal/db/bundb/relationship_test.go
+++ b/internal/db/bundb/relationship_test.go
@@ -19,17 +19,359 @@
import (
"context"
+ "errors"
+ "reflect"
"testing"
+ "time"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/id"
)
type RelationshipTestSuite struct {
BunDBStandardTestSuite
}
+func (suite *RelationshipTestSuite) TestGetBlockBy() {
+ t := suite.T()
+
+ // Create a new context for this test.
+ ctx, cncl := context.WithCancel(context.Background())
+ defer cncl()
+
+ // Sentinel error to mark avoiding a test case.
+ sentinelErr := errors.New("sentinel")
+
+ // isEqual checks if 2 block models are equal.
+ isEqual := func(b1, b2 gtsmodel.Block) bool {
+ // Clear populated sub-models.
+ b1.Account = nil
+ b2.Account = nil
+ b1.TargetAccount = nil
+ b2.TargetAccount = nil
+
+ // Clear database-set fields.
+ b1.CreatedAt = time.Time{}
+ b2.CreatedAt = time.Time{}
+ b1.UpdatedAt = time.Time{}
+ b2.UpdatedAt = time.Time{}
+
+ return reflect.DeepEqual(b1, b2)
+ }
+
+ var testBlocks []*gtsmodel.Block
+
+ for _, account1 := range suite.testAccounts {
+ for _, account2 := range suite.testAccounts {
+ if account1.ID == account2.ID {
+ // don't block *yourself* ...
+ continue
+ }
+
+ // Create new account block.
+ block := >smodel.Block{
+ ID: id.NewULID(),
+ URI: "http://127.0.0.1:8080/" + id.NewULID(),
+ AccountID: account1.ID,
+ TargetAccountID: account2.ID,
+ }
+
+ // Attempt to place the block in database (if not already).
+ if err := suite.db.PutBlock(ctx, block); err != nil {
+ if err != db.ErrAlreadyExists {
+ // Unrecoverable database error.
+ t.Fatalf("error creating block: %v", err)
+ }
+
+ // Fetch existing block from database between accounts.
+ block, _ = suite.db.GetBlock(ctx, account1.ID, account2.ID)
+ continue
+ }
+
+ // Append generated block to test cases.
+ testBlocks = append(testBlocks, block)
+ }
+ }
+
+ for _, block := range testBlocks {
+ for lookup, dbfunc := range map[string]func() (*gtsmodel.Block, error){
+ "id": func() (*gtsmodel.Block, error) {
+ return suite.db.GetBlockByID(ctx, block.ID)
+ },
+
+ "uri": func() (*gtsmodel.Block, error) {
+ return suite.db.GetBlockByURI(ctx, block.URI)
+ },
+
+ "origin_target": func() (*gtsmodel.Block, error) {
+ return suite.db.GetBlock(ctx, block.AccountID, block.TargetAccountID)
+ },
+ } {
+
+ // Clear database caches.
+ suite.state.Caches.Init()
+
+ t.Logf("checking database lookup %q", lookup)
+
+ // Perform database function.
+ checkBlock, err := dbfunc()
+ if err != nil {
+ if err == sentinelErr {
+ continue
+ }
+
+ t.Errorf("error encountered for database lookup %q: %v", lookup, err)
+ continue
+ }
+
+ // Check received block data.
+ if !isEqual(*checkBlock, *block) {
+ t.Errorf("block does not contain expected data: %+v", checkBlock)
+ continue
+ }
+
+ // Check that block origin account populated.
+ if checkBlock.Account == nil || checkBlock.Account.ID != block.AccountID {
+ t.Errorf("block origin account not correctly populated for: %+v", checkBlock)
+ continue
+ }
+
+ // Check that block target account populated.
+ if checkBlock.TargetAccount == nil || checkBlock.TargetAccount.ID != block.TargetAccountID {
+ t.Errorf("block target account not correctly populated for: %+v", checkBlock)
+ continue
+ }
+ }
+ }
+}
+
+func (suite *RelationshipTestSuite) TestGetFollowBy() {
+ t := suite.T()
+
+ // Create a new context for this test.
+ ctx, cncl := context.WithCancel(context.Background())
+ defer cncl()
+
+ // Sentinel error to mark avoiding a test case.
+ sentinelErr := errors.New("sentinel")
+
+ // isEqual checks if 2 follow models are equal.
+ isEqual := func(f1, f2 gtsmodel.Follow) bool {
+ // Clear populated sub-models.
+ f1.Account = nil
+ f2.Account = nil
+ f1.TargetAccount = nil
+ f2.TargetAccount = nil
+
+ // Clear database-set fields.
+ f1.CreatedAt = time.Time{}
+ f2.CreatedAt = time.Time{}
+ f1.UpdatedAt = time.Time{}
+ f2.UpdatedAt = time.Time{}
+
+ return reflect.DeepEqual(f1, f2)
+ }
+
+ var testFollows []*gtsmodel.Follow
+
+ for _, account1 := range suite.testAccounts {
+ for _, account2 := range suite.testAccounts {
+ if account1.ID == account2.ID {
+ // don't follow *yourself* ...
+ continue
+ }
+
+ // Create new account follow.
+ follow := >smodel.Follow{
+ ID: id.NewULID(),
+ URI: "http://127.0.0.1:8080/" + id.NewULID(),
+ AccountID: account1.ID,
+ TargetAccountID: account2.ID,
+ }
+
+ // Attempt to place the follow in database (if not already).
+ if err := suite.db.PutFollow(ctx, follow); err != nil {
+ if err != db.ErrAlreadyExists {
+ // Unrecoverable database error.
+ t.Fatalf("error creating follow: %v", err)
+ }
+
+ // Fetch existing follow from database between accounts.
+ follow, _ = suite.db.GetFollow(ctx, account1.ID, account2.ID)
+ continue
+ }
+
+ // Append generated follow to test cases.
+ testFollows = append(testFollows, follow)
+ }
+ }
+
+ for _, follow := range testFollows {
+ for lookup, dbfunc := range map[string]func() (*gtsmodel.Follow, error){
+ "id": func() (*gtsmodel.Follow, error) {
+ return suite.db.GetFollowByID(ctx, follow.ID)
+ },
+
+ "uri": func() (*gtsmodel.Follow, error) {
+ return suite.db.GetFollowByURI(ctx, follow.URI)
+ },
+
+ "origin_target": func() (*gtsmodel.Follow, error) {
+ return suite.db.GetFollow(ctx, follow.AccountID, follow.TargetAccountID)
+ },
+ } {
+ // Clear database caches.
+ suite.state.Caches.Init()
+
+ t.Logf("checking database lookup %q", lookup)
+
+ // Perform database function.
+ checkFollow, err := dbfunc()
+ if err != nil {
+ if err == sentinelErr {
+ continue
+ }
+
+ t.Errorf("error encountered for database lookup %q: %v", lookup, err)
+ continue
+ }
+
+ // Check received follow data.
+ if !isEqual(*checkFollow, *follow) {
+ t.Errorf("follow does not contain expected data: %+v", checkFollow)
+ continue
+ }
+
+ // Check that follow origin account populated.
+ if checkFollow.Account == nil || checkFollow.Account.ID != follow.AccountID {
+ t.Errorf("follow origin account not correctly populated for: %+v", checkFollow)
+ continue
+ }
+
+ // Check that follow target account populated.
+ if checkFollow.TargetAccount == nil || checkFollow.TargetAccount.ID != follow.TargetAccountID {
+ t.Errorf("follow target account not correctly populated for: %+v", checkFollow)
+ continue
+ }
+ }
+ }
+}
+
+func (suite *RelationshipTestSuite) TestGetFollowRequestBy() {
+ t := suite.T()
+
+ // Create a new context for this test.
+ ctx, cncl := context.WithCancel(context.Background())
+ defer cncl()
+
+ // Sentinel error to mark avoiding a test case.
+ sentinelErr := errors.New("sentinel")
+
+ // isEqual checks if 2 follow request models are equal.
+ isEqual := func(f1, f2 gtsmodel.FollowRequest) bool {
+ // Clear populated sub-models.
+ f1.Account = nil
+ f2.Account = nil
+ f1.TargetAccount = nil
+ f2.TargetAccount = nil
+
+ // Clear database-set fields.
+ f1.CreatedAt = time.Time{}
+ f2.CreatedAt = time.Time{}
+ f1.UpdatedAt = time.Time{}
+ f2.UpdatedAt = time.Time{}
+
+ return reflect.DeepEqual(f1, f2)
+ }
+
+ var testFollowReqs []*gtsmodel.FollowRequest
+
+ for _, account1 := range suite.testAccounts {
+ for _, account2 := range suite.testAccounts {
+ if account1.ID == account2.ID {
+ // don't follow *yourself* ...
+ continue
+ }
+
+ // Create new account follow request.
+ followReq := >smodel.FollowRequest{
+ ID: id.NewULID(),
+ URI: "http://127.0.0.1:8080/" + id.NewULID(),
+ AccountID: account1.ID,
+ TargetAccountID: account2.ID,
+ }
+
+ // Attempt to place the follow in database (if not already).
+ if err := suite.db.PutFollowRequest(ctx, followReq); err != nil {
+ if err != db.ErrAlreadyExists {
+ // Unrecoverable database error.
+ t.Fatalf("error creating follow request: %v", err)
+ }
+
+ // Fetch existing follow request from database between accounts.
+ followReq, _ = suite.db.GetFollowRequest(ctx, account1.ID, account2.ID)
+ continue
+ }
+
+ // Append generated follow request to test cases.
+ testFollowReqs = append(testFollowReqs, followReq)
+ }
+ }
+
+ for _, followReq := range testFollowReqs {
+ for lookup, dbfunc := range map[string]func() (*gtsmodel.FollowRequest, error){
+ "id": func() (*gtsmodel.FollowRequest, error) {
+ return suite.db.GetFollowRequestByID(ctx, followReq.ID)
+ },
+
+ "uri": func() (*gtsmodel.FollowRequest, error) {
+ return suite.db.GetFollowRequestByURI(ctx, followReq.URI)
+ },
+
+ "origin_target": func() (*gtsmodel.FollowRequest, error) {
+ return suite.db.GetFollowRequest(ctx, followReq.AccountID, followReq.TargetAccountID)
+ },
+ } {
+
+ // Clear database caches.
+ suite.state.Caches.Init()
+
+ t.Logf("checking database lookup %q", lookup)
+
+ // Perform database function.
+ checkFollowReq, err := dbfunc()
+ if err != nil {
+ if err == sentinelErr {
+ continue
+ }
+
+ t.Errorf("error encountered for database lookup %q: %v", lookup, err)
+ continue
+ }
+
+ // Check received follow request data.
+ if !isEqual(*checkFollowReq, *followReq) {
+ t.Errorf("follow request does not contain expected data: %+v", checkFollowReq)
+ continue
+ }
+
+ // Check that follow request origin account populated.
+ if checkFollowReq.Account == nil || checkFollowReq.Account.ID != followReq.AccountID {
+ t.Errorf("follow request origin account not correctly populated for: %+v", checkFollowReq)
+ continue
+ }
+
+ // Check that follow request target account populated.
+ if checkFollowReq.TargetAccount == nil || checkFollowReq.TargetAccount.ID != followReq.TargetAccountID {
+ t.Errorf("follow request target account not correctly populated for: %+v", checkFollowReq)
+ continue
+ }
+ }
+ }
+}
+
func (suite *RelationshipTestSuite) TestIsBlocked() {
ctx := context.Background()
@@ -37,11 +379,11 @@ func (suite *RelationshipTestSuite) TestIsBlocked() {
account2 := suite.testAccounts["local_account_2"].ID
// no blocks exist between account 1 and account 2
- blocked, err := suite.db.IsBlocked(ctx, account1, account2, false)
+ blocked, err := suite.db.IsBlocked(ctx, account1, account2)
suite.NoError(err)
suite.False(blocked)
- blocked, err = suite.db.IsBlocked(ctx, account2, account1, false)
+ blocked, err = suite.db.IsBlocked(ctx, account2, account1)
suite.NoError(err)
suite.False(blocked)
@@ -56,45 +398,24 @@ func (suite *RelationshipTestSuite) TestIsBlocked() {
}
// account 1 now blocks account 2
- blocked, err = suite.db.IsBlocked(ctx, account1, account2, false)
+ blocked, err = suite.db.IsBlocked(ctx, account1, account2)
suite.NoError(err)
suite.True(blocked)
// account 2 doesn't block account 1
- blocked, err = suite.db.IsBlocked(ctx, account2, account1, false)
+ blocked, err = suite.db.IsBlocked(ctx, account2, account1)
suite.NoError(err)
suite.False(blocked)
// a block exists in either direction between the two
- blocked, err = suite.db.IsBlocked(ctx, account1, account2, true)
+ blocked, err = suite.db.IsEitherBlocked(ctx, account1, account2)
suite.NoError(err)
suite.True(blocked)
- blocked, err = suite.db.IsBlocked(ctx, account2, account1, true)
+ blocked, err = suite.db.IsEitherBlocked(ctx, account2, account1)
suite.NoError(err)
suite.True(blocked)
}
-func (suite *RelationshipTestSuite) TestGetBlock() {
- ctx := context.Background()
-
- account1 := suite.testAccounts["local_account_1"].ID
- account2 := suite.testAccounts["local_account_2"].ID
-
- if err := suite.db.PutBlock(ctx, >smodel.Block{
- ID: "01G202BCSXXJZ70BHB5KCAHH8C",
- URI: "http://localhost:8080/some_block_uri_1",
- AccountID: account1,
- TargetAccountID: account2,
- }); err != nil {
- suite.FailNow(err.Error())
- }
-
- block, err := suite.db.GetBlock(ctx, account1, account2)
- suite.NoError(err)
- suite.NotNil(block)
- suite.Equal("01G202BCSXXJZ70BHB5KCAHH8C", block.ID)
-}
-
func (suite *RelationshipTestSuite) TestDeleteBlockByID() {
ctx := context.Background()
@@ -157,7 +478,7 @@ func (suite *RelationshipTestSuite) TestDeleteBlockByURI() {
suite.Nil(block)
}
-func (suite *RelationshipTestSuite) TestDeleteBlocksByOriginAccountID() {
+func (suite *RelationshipTestSuite) TestDeleteAccountBlocks() {
ctx := context.Background()
// put a block in first
@@ -179,38 +500,7 @@ func (suite *RelationshipTestSuite) TestDeleteBlocksByOriginAccountID() {
suite.Equal("01G202BCSXXJZ70BHB5KCAHH8C", block.ID)
// delete the block by originAccountID
- err = suite.db.DeleteBlocksByOriginAccountID(ctx, account1)
- suite.NoError(err)
-
- // block should be gone
- block, err = suite.db.GetBlock(ctx, account1, account2)
- suite.ErrorIs(err, db.ErrNoEntries)
- suite.Nil(block)
-}
-
-func (suite *RelationshipTestSuite) TestDeleteBlocksByTargetAccountID() {
- ctx := context.Background()
-
- // put a block in first
- account1 := suite.testAccounts["local_account_1"].ID
- account2 := suite.testAccounts["local_account_2"].ID
- if err := suite.db.PutBlock(ctx, >smodel.Block{
- ID: "01G202BCSXXJZ70BHB5KCAHH8C",
- URI: "http://localhost:8080/some_block_uri_1",
- AccountID: account1,
- TargetAccountID: account2,
- }); err != nil {
- suite.FailNow(err.Error())
- }
-
- // make sure the block is in the db
- block, err := suite.db.GetBlock(ctx, account1, account2)
- suite.NoError(err)
- suite.NotNil(block)
- suite.Equal("01G202BCSXXJZ70BHB5KCAHH8C", block.ID)
-
- // delete the block by targetAccountID
- err = suite.db.DeleteBlocksByTargetAccountID(ctx, account2)
+ err = suite.db.DeleteAccountBlocks(ctx, account1)
suite.NoError(err)
// block should be gone
@@ -244,7 +534,7 @@ func (suite *RelationshipTestSuite) TestGetRelationship() {
func (suite *RelationshipTestSuite) TestIsFollowingYes() {
requestingAccount := suite.testAccounts["local_account_1"]
targetAccount := suite.testAccounts["admin_account"]
- isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount)
+ isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount.ID, targetAccount.ID)
suite.NoError(err)
suite.True(isFollowing)
}
@@ -252,7 +542,7 @@ func (suite *RelationshipTestSuite) TestIsFollowingYes() {
func (suite *RelationshipTestSuite) TestIsFollowingNo() {
requestingAccount := suite.testAccounts["admin_account"]
targetAccount := suite.testAccounts["local_account_2"]
- isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount, targetAccount)
+ isFollowing, err := suite.db.IsFollowing(context.Background(), requestingAccount.ID, targetAccount.ID)
suite.NoError(err)
suite.False(isFollowing)
}
@@ -260,7 +550,7 @@ func (suite *RelationshipTestSuite) TestIsFollowingNo() {
func (suite *RelationshipTestSuite) TestIsMutualFollowing() {
requestingAccount := suite.testAccounts["local_account_1"]
targetAccount := suite.testAccounts["admin_account"]
- isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount)
+ isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount.ID, targetAccount.ID)
suite.NoError(err)
suite.True(isMutualFollowing)
}
@@ -268,7 +558,7 @@ func (suite *RelationshipTestSuite) TestIsMutualFollowing() {
func (suite *RelationshipTestSuite) TestIsMutualFollowingNo() {
requestingAccount := suite.testAccounts["local_account_1"]
targetAccount := suite.testAccounts["local_account_2"]
- isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount, targetAccount)
+ isMutualFollowing, err := suite.db.IsMutualFollowing(context.Background(), requestingAccount.ID, targetAccount.ID)
suite.NoError(err)
suite.True(isMutualFollowing)
}
@@ -306,7 +596,7 @@ func (suite *RelationshipTestSuite) TestAcceptFollowRequestOK() {
suite.Equal(followRequest.URI, follow.URI)
// Ensure notification is deleted.
- notification, err := suite.db.GetNotification(ctx, followRequestNotification.ID)
+ notification, err := suite.db.GetNotificationByID(ctx, followRequestNotification.ID)
suite.ErrorIs(err, db.ErrNoEntries)
suite.Nil(notification)
}
@@ -389,7 +679,7 @@ func (suite *RelationshipTestSuite) TestRejectFollowRequestOK() {
TargetAccountID: targetAccount.ID,
}
- if err := suite.db.Put(ctx, followRequest); err != nil {
+ if err := suite.db.PutFollowRequest(ctx, followRequest); err != nil {
suite.FailNow(err.Error())
}
@@ -404,12 +694,11 @@ func (suite *RelationshipTestSuite) TestRejectFollowRequestOK() {
suite.FailNow(err.Error())
}
- rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
+ err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
suite.NoError(err)
- suite.NotNil(rejectedFollowRequest)
// Ensure notification is deleted.
- notification, err := suite.db.GetNotification(ctx, followRequestNotification.ID)
+ notification, err := suite.db.GetNotificationByID(ctx, followRequestNotification.ID)
suite.ErrorIs(err, db.ErrNoEntries)
suite.Nil(notification)
}
@@ -419,9 +708,8 @@ func (suite *RelationshipTestSuite) TestRejectFollowRequestNotExisting() {
account := suite.testAccounts["admin_account"]
targetAccount := suite.testAccounts["local_account_2"]
- rejectedFollowRequest, err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
+ err := suite.db.RejectFollowRequest(ctx, account.ID, targetAccount.ID)
suite.ErrorIs(err, db.ErrNoEntries)
- suite.Nil(rejectedFollowRequest)
}
func (suite *RelationshipTestSuite) TestGetAccountFollowRequests() {
@@ -440,42 +728,49 @@ func (suite *RelationshipTestSuite) TestGetAccountFollowRequests() {
suite.FailNow(err.Error())
}
- followRequests, err := suite.db.GetFollowRequests(ctx, "", targetAccount.ID)
+ followRequests, err := suite.db.GetAccountFollowRequests(ctx, targetAccount.ID)
suite.NoError(err)
suite.Len(followRequests, 1)
}
func (suite *RelationshipTestSuite) TestGetAccountFollows() {
account := suite.testAccounts["local_account_1"]
- follows, err := suite.db.GetFollows(context.Background(), account.ID, "")
+ follows, err := suite.db.GetAccountFollows(context.Background(), account.ID)
suite.NoError(err)
suite.Len(follows, 2)
}
-func (suite *RelationshipTestSuite) TestCountAccountFollows() {
+func (suite *RelationshipTestSuite) TestCountAccountFollowsLocalOnly() {
account := suite.testAccounts["local_account_1"]
- followsCount, err := suite.db.CountFollows(context.Background(), account.ID, "")
+ followsCount, err := suite.db.CountAccountLocalFollows(context.Background(), account.ID)
suite.NoError(err)
suite.Equal(2, followsCount)
}
-func (suite *RelationshipTestSuite) TestGetAccountFollowedBy() {
+func (suite *RelationshipTestSuite) TestCountAccountFollows() {
account := suite.testAccounts["local_account_1"]
- follows, err := suite.db.GetFollows(context.Background(), "", account.ID)
+ followsCount, err := suite.db.CountAccountFollows(context.Background(), account.ID)
+ suite.NoError(err)
+ suite.Equal(2, followsCount)
+}
+
+func (suite *RelationshipTestSuite) TestGetAccountFollowers() {
+ account := suite.testAccounts["local_account_1"]
+ follows, err := suite.db.GetAccountFollowers(context.Background(), account.ID)
suite.NoError(err)
suite.Len(follows, 2)
}
-func (suite *RelationshipTestSuite) TestGetLocalFollowersIDs() {
+func (suite *RelationshipTestSuite) TestCountAccountFollowers() {
account := suite.testAccounts["local_account_1"]
- accountIDs, err := suite.db.GetLocalFollowersIDs(context.Background(), account.ID)
+ followsCount, err := suite.db.CountAccountFollowers(context.Background(), account.ID)
suite.NoError(err)
- suite.EqualValues([]string{"01F8MH5NBDF2MV7CTC4Q5128HF", "01F8MH17FWEB39HZJ76B6VXSKF"}, accountIDs)
+ suite.Equal(2, followsCount)
}
-func (suite *RelationshipTestSuite) TestCountAccountFollowedBy() {
+func (suite *RelationshipTestSuite) TestCountAccountFollowersLocalOnly() {
account := suite.testAccounts["local_account_1"]
- followsCount, err := suite.db.CountFollows(context.Background(), "", account.ID)
+ followsCount, err := suite.db.CountAccountLocalFollowers(context.Background(), account.ID)
suite.NoError(err)
suite.Equal(2, followsCount)
}
@@ -484,18 +779,25 @@ func (suite *RelationshipTestSuite) TestUnfollowExisting() {
originAccount := suite.testAccounts["local_account_1"]
targetAccount := suite.testAccounts["admin_account"]
- uri, err := suite.db.Unfollow(context.Background(), originAccount.ID, targetAccount.ID)
+ follow, err := suite.db.GetFollow(context.Background(), originAccount.ID, targetAccount.ID)
suite.NoError(err)
- suite.Equal("http://localhost:8080/users/the_mighty_zork/follow/01F8PY8RHWRQZV038T4E8T9YK8", uri)
+ suite.NotNil(follow)
+
+ err = suite.db.DeleteFollowByID(context.Background(), follow.ID)
+ suite.NoError(err)
+
+ follow, err = suite.db.GetFollow(context.Background(), originAccount.ID, targetAccount.ID)
+ suite.EqualError(err, db.ErrNoEntries.Error())
+ suite.Nil(follow)
}
func (suite *RelationshipTestSuite) TestUnfollowNotExisting() {
originAccount := suite.testAccounts["local_account_1"]
targetAccountID := "01GTVD9N484CZ6AM90PGGNY7GQ"
- uri, err := suite.db.Unfollow(context.Background(), originAccount.ID, targetAccountID)
- suite.NoError(err)
- suite.Empty(uri)
+ follow, err := suite.db.GetFollow(context.Background(), originAccount.ID, targetAccountID)
+ suite.EqualError(err, db.ErrNoEntries.Error())
+ suite.Nil(follow)
}
func (suite *RelationshipTestSuite) TestUnfollowRequestExisting() {
@@ -510,22 +812,29 @@ func (suite *RelationshipTestSuite) TestUnfollowRequestExisting() {
TargetAccountID: targetAccount.ID,
}
- if err := suite.db.Put(ctx, followRequest); err != nil {
+ if err := suite.db.PutFollowRequest(ctx, followRequest); err != nil {
suite.FailNow(err.Error())
}
- uri, err := suite.db.UnfollowRequest(context.Background(), originAccount.ID, targetAccount.ID)
+ followRequest, err := suite.db.GetFollowRequest(context.Background(), originAccount.ID, targetAccount.ID)
suite.NoError(err)
- suite.Equal("http://localhost:8080/weeeeeeeeeeeeeeeee", uri)
+ suite.NotNil(followRequest)
+
+ err = suite.db.DeleteFollowRequestByID(context.Background(), followRequest.ID)
+ suite.NoError(err)
+
+ followRequest, err = suite.db.GetFollowRequest(context.Background(), originAccount.ID, targetAccount.ID)
+ suite.EqualError(err, db.ErrNoEntries.Error())
+ suite.Nil(followRequest)
}
func (suite *RelationshipTestSuite) TestUnfollowRequestNotExisting() {
originAccount := suite.testAccounts["local_account_1"]
targetAccountID := "01GTVD9N484CZ6AM90PGGNY7GQ"
- uri, err := suite.db.UnfollowRequest(context.Background(), originAccount.ID, targetAccountID)
- suite.NoError(err)
- suite.Empty(uri)
+ followRequest, err := suite.db.GetFollowRequest(context.Background(), originAccount.ID, targetAccountID)
+ suite.EqualError(err, db.ErrNoEntries.Error())
+ suite.Nil(followRequest)
}
func TestRelationshipTestSuite(t *testing.T) {
diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go
index deec9a118..c2b5546f8 100644
--- a/internal/db/bundb/status.go
+++ b/internal/db/bundb/status.go
@@ -26,6 +26,7 @@
"time"
"github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
@@ -41,7 +42,6 @@ func (s *statusDB) newStatusQ(status interface{}) *bun.SelectQuery {
return s.conn.
NewSelect().
Model(status).
- Relation("Attachments").
Relation("Tags").
Relation("CreatedWithApplication")
}
@@ -102,81 +102,143 @@ func (s *statusDB) getStatus(ctx context.Context, lookup string, dbQuery func(*g
status, err := s.state.Caches.GTS.Status().Load(lookup, func() (*gtsmodel.Status, error) {
var status gtsmodel.Status
- // Not cached! Perform database query
+ // Not cached! Perform database query.
if err := dbQuery(&status); err != nil {
return nil, s.conn.ProcessError(err)
}
- if status.InReplyToID != "" {
- // Also load in-reply-to status
- status.InReplyTo = new(gtsmodel.Status)
- err := s.conn.NewSelect().Model(status.InReplyTo).
- Where("? = ?", bun.Ident("status.id"), status.InReplyToID).
- Scan(ctx)
- if err != nil {
- return nil, s.conn.ProcessError(err)
- }
- }
-
- if status.BoostOfID != "" {
- // Also load original boosted status
- status.BoostOf = new(gtsmodel.Status)
- err := s.conn.NewSelect().Model(status.BoostOf).
- Where("? = ?", bun.Ident("status.id"), status.BoostOfID).
- Scan(ctx)
- if err != nil {
- return nil, s.conn.ProcessError(err)
- }
- }
-
return &status, nil
}, keyParts...)
if err != nil {
- // error already processed
return nil, err
}
- // Set the status author account
- status.Account, err = s.state.DB.GetAccountByID(ctx, status.AccountID)
- if err != nil {
- return nil, fmt.Errorf("error getting status account: %w", err)
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return status, nil
}
- if id := status.BoostOfAccountID; id != "" {
- // Set boost of status' author account
- status.BoostOfAccount, err = s.state.DB.GetAccountByID(ctx, id)
- if err != nil {
- return nil, fmt.Errorf("error getting boosted status account: %w", err)
- }
- }
-
- if id := status.InReplyToAccountID; id != "" {
- // Set in-reply-to status' author account
- status.InReplyToAccount, err = s.state.DB.GetAccountByID(ctx, id)
- if err != nil {
- return nil, fmt.Errorf("error getting in reply to status account: %w", err)
- }
- }
-
- if len(status.EmojiIDs) > 0 {
- // Fetch status emojis
- status.Emojis, err = s.state.DB.GetEmojisByIDs(ctx, status.EmojiIDs)
- if err != nil {
- return nil, fmt.Errorf("error getting status emojis: %w", err)
- }
- }
-
- if len(status.MentionIDs) > 0 {
- // Fetch status mentions
- status.Mentions, err = s.state.DB.GetMentions(ctx, status.MentionIDs)
- if err != nil {
- return nil, fmt.Errorf("error getting status mentions: %w", err)
- }
+ // Further populate the status fields where applicable.
+ if err := s.PopulateStatus(ctx, status); err != nil {
+ return nil, err
}
return status, nil
}
+func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status) error {
+ var err error
+
+ if status.Account == nil {
+ // Status author is not set, fetch from database.
+ status.Account, err = s.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ status.AccountID,
+ )
+ if err != nil {
+ return fmt.Errorf("error populating status author: %w", err)
+ }
+ }
+
+ if status.InReplyToID != "" && status.InReplyTo == nil {
+ // Status parent is not set, fetch from database.
+ status.InReplyTo, err = s.GetStatusByID(
+ gtscontext.SetBarebones(ctx),
+ status.InReplyToID,
+ )
+ if err != nil {
+ return fmt.Errorf("error populating status parent: %w", err)
+ }
+ }
+
+ if status.InReplyToID != "" {
+ if status.InReplyTo == nil {
+ // Status parent is not set, fetch from database.
+ status.InReplyTo, err = s.GetStatusByID(
+ gtscontext.SetBarebones(ctx),
+ status.InReplyToID,
+ )
+ if err != nil {
+ return fmt.Errorf("error populating status parent: %w", err)
+ }
+ }
+
+ if status.InReplyToAccount == nil {
+ // Status parent author is not set, fetch from database.
+ status.InReplyToAccount, err = s.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ status.InReplyToAccountID,
+ )
+ if err != nil {
+ return fmt.Errorf("error populating status parent author: %w", err)
+ }
+ }
+ }
+
+ if status.BoostOfID != "" {
+ if status.BoostOf == nil {
+ // Status boost is not set, fetch from database.
+ status.BoostOf, err = s.GetStatusByID(
+ gtscontext.SetBarebones(ctx),
+ status.BoostOfID,
+ )
+ if err != nil {
+ return fmt.Errorf("error populating status boost: %w", err)
+ }
+ }
+
+ if status.BoostOfAccount == nil {
+ // Status boost author is not set, fetch from database.
+ status.BoostOfAccount, err = s.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ status.BoostOfAccountID,
+ )
+ if err != nil {
+ return fmt.Errorf("error populating status boost author: %w", err)
+ }
+ }
+ }
+
+ if !status.AttachmentsPopulated() {
+ // Status attachments are out-of-date with IDs, repopulate.
+ status.Attachments, err = s.state.DB.GetAttachmentsByIDs(
+ ctx, // these are already barebones
+ status.AttachmentIDs,
+ )
+ if err != nil {
+ return fmt.Errorf("error populating status attachments: %w", err)
+ }
+ }
+
+ // TODO: once we don't fetch using relations.
+ // if !status.TagsPopulated() {
+ // }
+
+ if !status.MentionsPopulated() {
+ // Status mentions are out-of-date with IDs, repopulate.
+ status.Mentions, err = s.state.DB.GetMentions(
+ ctx, // leave fully populated for now
+ status.MentionIDs,
+ )
+ if err != nil {
+ return fmt.Errorf("error populating status mentions: %w", err)
+ }
+ }
+
+ if !status.EmojisPopulated() {
+ // Status emojis are out-of-date with IDs, repopulate.
+ status.Emojis, err = s.state.DB.GetEmojisByIDs(
+ ctx, // these are already barebones
+ status.EmojiIDs,
+ )
+ if err != nil {
+ return fmt.Errorf("error populating status emojis: %w", err)
+ }
+ }
+
+ return nil
+}
+
func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Error {
err := s.state.Caches.GTS.Status().Store(status, func() error {
// It is safe to run this database transaction within cache.Store
@@ -239,12 +301,16 @@ func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) db.Er
})
})
if err != nil {
- // already processed
return err
}
for _, id := range status.AttachmentIDs {
- // Clear updated media attachment IDs from cache
+ // Invalidate media attachments from cache.
+ //
+ // NOTE: this is needed due to the way in which
+ // we upload status attachments, and only after
+ // update them with a known status ID. This is
+ // not the case for header/avatar attachments.
s.state.Caches.GTS.Media().Invalidate("ID", id)
}
@@ -322,14 +388,19 @@ func (s *statusDB) UpdateStatus(ctx context.Context, status *gtsmodel.Status, co
return err
}
+ // Invalidate status from database lookups.
+ s.state.Caches.GTS.Status().Invalidate("ID", status.ID)
+
for _, id := range status.AttachmentIDs {
- // Clear updated media attachment IDs from cache
+ // Invalidate media attachments from cache.
+ //
+ // NOTE: this is needed due to the way in which
+ // we upload status attachments, and only after
+ // update them with a known status ID. This is
+ // not the case for header/avatar attachments.
s.state.Caches.GTS.Media().Invalidate("ID", id)
}
- // Drop any old status value from cache by this ID
- s.state.Caches.GTS.Status().Invalidate("ID", status.ID)
-
return nil
}
@@ -367,8 +438,12 @@ func (s *statusDB) DeleteStatusByID(ctx context.Context, id string) db.Error {
return err
}
- // Drop any old value from cache by this ID
+ // Invalidate status from database lookups.
s.state.Caches.GTS.Status().Invalidate("ID", id)
+
+ // Invalidate status from all visibility lookups.
+ s.state.Caches.Visibility.Invalidate("ItemID", id)
+
return nil
}
diff --git a/internal/db/bundb/statusfave.go b/internal/db/bundb/statusfave.go
index c42ab249f..0f7e5df74 100644
--- a/internal/db/bundb/statusfave.go
+++ b/internal/db/bundb/statusfave.go
@@ -23,6 +23,7 @@
"fmt"
"github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state"
@@ -34,29 +35,82 @@ type statusFaveDB struct {
state *state.State
}
-func (s *statusFaveDB) GetStatusFave(ctx context.Context, id string) (*gtsmodel.StatusFave, db.Error) {
- fave := new(gtsmodel.StatusFave)
+func (s *statusFaveDB) GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, db.Error) {
+ return s.getStatusFave(
+ ctx,
+ "AccountID.StatusID",
+ func(fave *gtsmodel.StatusFave) error {
+ return s.conn.
+ NewSelect().
+ Model(fave).
+ Where("? = ?", bun.Ident("account_id"), accountID).
+ Where("? = ?", bun.Ident("status_id"), statusID).
+ Scan(ctx)
+ },
+ accountID,
+ statusID,
+ )
+}
- err := s.conn.
- NewSelect().
- Model(fave).
- Where("? = ?", bun.Ident("status_fave.ID"), id).
- Scan(ctx)
+func (s *statusFaveDB) GetStatusFaveByID(ctx context.Context, id string) (*gtsmodel.StatusFave, db.Error) {
+ return s.getStatusFave(
+ ctx,
+ "ID",
+ func(fave *gtsmodel.StatusFave) error {
+ return s.conn.
+ NewSelect().
+ Model(fave).
+ Where("? = ?", bun.Ident("id"), id).
+ Scan(ctx)
+ },
+ id,
+ )
+}
+
+func (s *statusFaveDB) getStatusFave(ctx context.Context, lookup string, dbQuery func(*gtsmodel.StatusFave) error, keyParts ...any) (*gtsmodel.StatusFave, error) {
+ // Fetch status fave from database cache with loader callback
+ fave, err := s.state.Caches.GTS.StatusFave().Load(lookup, func() (*gtsmodel.StatusFave, error) {
+ var fave gtsmodel.StatusFave
+
+ // Not cached! Perform database query.
+ if err := dbQuery(&fave); err != nil {
+ return nil, s.conn.ProcessError(err)
+ }
+
+ return &fave, nil
+ }, keyParts...)
if err != nil {
- return nil, s.conn.ProcessError(err)
+ return nil, err
}
- fave.Account, err = s.state.DB.GetAccountByID(ctx, fave.AccountID)
+ if gtscontext.Barebones(ctx) {
+ // no need to fully populate.
+ return fave, nil
+ }
+
+ // Fetch the status fave author account.
+ fave.Account, err = s.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ fave.AccountID,
+ )
if err != nil {
return nil, fmt.Errorf("error getting status fave account %q: %w", fave.AccountID, err)
}
- fave.TargetAccount, err = s.state.DB.GetAccountByID(ctx, fave.TargetAccountID)
+ // Fetch the status fave target account.
+ fave.TargetAccount, err = s.state.DB.GetAccountByID(
+ gtscontext.SetBarebones(ctx),
+ fave.TargetAccountID,
+ )
if err != nil {
return nil, fmt.Errorf("error getting status fave target account %q: %w", fave.TargetAccountID, err)
}
- fave.Status, err = s.state.DB.GetStatusByID(ctx, fave.StatusID)
+ // Fetch the status fave target status.
+ fave.Status, err = s.state.DB.GetStatusByID(
+ gtscontext.SetBarebones(ctx),
+ fave.StatusID,
+ )
if err != nil {
return nil, fmt.Errorf("error getting status fave status %q: %w", fave.StatusID, err)
}
@@ -64,38 +118,22 @@ func (s *statusFaveDB) GetStatusFave(ctx context.Context, id string) (*gtsmodel.
return fave, nil
}
-func (s *statusFaveDB) GetStatusFaveByAccountID(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, db.Error) {
- var id string
-
- err := s.conn.
- NewSelect().
- TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")).
- Column("status_fave.id").
- Where("? = ?", bun.Ident("status_fave.account_id"), accountID).
- Where("? = ?", bun.Ident("status_fave.status_id"), statusID).
- Scan(ctx, &id)
- if err != nil {
- return nil, s.conn.ProcessError(err)
- }
-
- return s.GetStatusFave(ctx, id)
-}
-
-func (s *statusFaveDB) GetStatusFaves(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, db.Error) {
+func (s *statusFaveDB) GetStatusFavesForStatus(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, db.Error) {
ids := []string{}
if err := s.conn.
NewSelect().
- TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")).
- Column("status_fave.id").
- Where("? = ?", bun.Ident("status_fave.status_id"), statusID).
+ Table("status_faves").
+ Column("id").
+ Where("? = ?", bun.Ident("status_id"), statusID).
Scan(ctx, &ids); err != nil {
return nil, s.conn.ProcessError(err)
}
faves := make([]*gtsmodel.StatusFave, 0, len(ids))
+
for _, id := range ids {
- fave, err := s.GetStatusFave(ctx, id)
+ fave, err := s.GetStatusFaveByID(ctx, id)
if err != nil {
log.Errorf(ctx, "error getting status fave %q: %v", id, err)
continue
@@ -107,23 +145,27 @@ func (s *statusFaveDB) GetStatusFaves(ctx context.Context, statusID string) ([]*
return faves, nil
}
-func (s *statusFaveDB) PutStatusFave(ctx context.Context, statusFave *gtsmodel.StatusFave) db.Error {
- _, err := s.conn.
- NewInsert().
- Model(statusFave).
- Exec(ctx)
-
- return s.conn.ProcessError(err)
+func (s *statusFaveDB) PutStatusFave(ctx context.Context, fave *gtsmodel.StatusFave) db.Error {
+ return s.state.Caches.GTS.StatusFave().Store(fave, func() error {
+ _, err := s.conn.
+ NewInsert().
+ Model(fave).
+ Exec(ctx)
+ return s.conn.ProcessError(err)
+ })
}
-func (s *statusFaveDB) DeleteStatusFave(ctx context.Context, id string) db.Error {
- _, err := s.conn.
+func (s *statusFaveDB) DeleteStatusFaveByID(ctx context.Context, id string) db.Error {
+ if _, err := s.conn.
NewDelete().
- TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")).
- Where("? = ?", bun.Ident("status_fave.id"), id).
- Exec(ctx)
+ Table("status_faves").
+ Where("? = ?", bun.Ident("id"), id).
+ Exec(ctx); err != nil {
+ return s.conn.ProcessError(err)
+ }
- return s.conn.ProcessError(err)
+ s.state.Caches.GTS.StatusFave().Invalidate("ID", id)
+ return nil
}
func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID string, originAccountID string) db.Error {
@@ -131,42 +173,52 @@ func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID st
return errors.New("DeleteStatusFaves: one of targetAccountID or originAccountID must be set")
}
- // TODO: Capture fave IDs in a RETURNING
- // statement (when faves have a cache),
- // + use the IDs to invalidate cache entries.
+ // Capture fave IDs in a RETURNING statement.
+ var faveIDs []string
q := s.conn.
NewDelete().
- TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave"))
+ Table("status_faves").
+ Returning("?", bun.Ident("id"))
if targetAccountID != "" {
- q = q.Where("? = ?", bun.Ident("status_fave.target_account_id"), targetAccountID)
+ q = q.Where("? = ?", bun.Ident("target_account_id"), targetAccountID)
}
if originAccountID != "" {
- q = q.Where("? = ?", bun.Ident("status_fave.account_id"), originAccountID)
+ q = q.Where("? = ?", bun.Ident("account_id"), originAccountID)
}
- if _, err := q.Exec(ctx); err != nil {
+ if _, err := q.Exec(ctx, &faveIDs); err != nil {
return s.conn.ProcessError(err)
}
+ for _, id := range faveIDs {
+ // Invalidate each of the returned status fave IDs.
+ s.state.Caches.GTS.StatusFave().Invalidate("ID", id)
+ }
+
return nil
}
func (s *statusFaveDB) DeleteStatusFavesForStatus(ctx context.Context, statusID string) db.Error {
- // TODO: Capture fave IDs in a RETURNING
- // statement (when faves have a cache),
- // + use the IDs to invalidate cache entries.
+ // Capture fave IDs in a RETURNING statement.
+ var faveIDs []string
q := s.conn.
NewDelete().
- TableExpr("? AS ?", bun.Ident("status_faves"), bun.Ident("status_fave")).
- Where("? = ?", bun.Ident("status_fave.status_id"), statusID)
+ Table("status_faves").
+ Where("? = ?", bun.Ident("status_id"), statusID).
+ Returning("?", bun.Ident("id"))
- if _, err := q.Exec(ctx); err != nil {
+ if _, err := q.Exec(ctx, &faveIDs); err != nil {
return s.conn.ProcessError(err)
}
+ for _, id := range faveIDs {
+ // Invalidate each of the returned status fave IDs.
+ s.state.Caches.GTS.StatusFave().Invalidate("ID", id)
+ }
+
return nil
}
diff --git a/internal/db/bundb/statusfave_test.go b/internal/db/bundb/statusfave_test.go
index 98e495bf3..7218390bc 100644
--- a/internal/db/bundb/statusfave_test.go
+++ b/internal/db/bundb/statusfave_test.go
@@ -35,7 +35,7 @@ type StatusFaveTestSuite struct {
func (suite *StatusFaveTestSuite) TestGetStatusFaves() {
testStatus := suite.testStatuses["admin_account_status_1"]
- faves, err := suite.db.GetStatusFaves(context.Background(), testStatus.ID)
+ faves, err := suite.db.GetStatusFavesForStatus(context.Background(), testStatus.ID)
if err != nil {
suite.FailNow(err.Error())
}
@@ -51,7 +51,7 @@ func (suite *StatusFaveTestSuite) TestGetStatusFaves() {
func (suite *StatusFaveTestSuite) TestGetStatusFavesNone() {
testStatus := suite.testStatuses["admin_account_status_4"]
- faves, err := suite.db.GetStatusFaves(context.Background(), testStatus.ID)
+ faves, err := suite.db.GetStatusFavesForStatus(context.Background(), testStatus.ID)
if err != nil {
suite.FailNow(err.Error())
}
@@ -63,7 +63,7 @@ func (suite *StatusFaveTestSuite) TestGetStatusFaveByAccountID() {
testAccount := suite.testAccounts["local_account_1"]
testStatus := suite.testStatuses["admin_account_status_1"]
- fave, err := suite.db.GetStatusFaveByAccountID(context.Background(), testAccount.ID, testStatus.ID)
+ fave, err := suite.db.GetStatusFave(context.Background(), testAccount.ID, testStatus.ID)
suite.NoError(err)
suite.NotNil(fave)
}
@@ -129,17 +129,17 @@ func (suite *StatusFaveTestSuite) TestDeleteStatusFave() {
testFave := suite.testFaves["local_account_1_admin_account_status_1"]
ctx := context.Background()
- if err := suite.db.DeleteStatusFave(ctx, testFave.ID); err != nil {
+ if err := suite.db.DeleteStatusFaveByID(ctx, testFave.ID); err != nil {
suite.FailNow(err.Error())
}
- fave, err := suite.db.GetStatusFave(ctx, testFave.ID)
+ fave, err := suite.db.GetStatusFaveByID(ctx, testFave.ID)
suite.ErrorIs(err, db.ErrNoEntries)
suite.Nil(fave)
}
func (suite *StatusFaveTestSuite) TestDeleteStatusFaveNonExisting() {
- err := suite.db.DeleteStatusFave(context.Background(), "01GVAV715K6Y2SG9ZKS9ZA8G7G")
+ err := suite.db.DeleteStatusFaveByID(context.Background(), "01GVAV715K6Y2SG9ZKS9ZA8G7G")
suite.NoError(err)
}
diff --git a/internal/db/bundb/timeline.go b/internal/db/bundb/timeline.go
index ea4a87d03..1ab140103 100644
--- a/internal/db/bundb/timeline.go
+++ b/internal/db/bundb/timeline.go
@@ -61,9 +61,12 @@ func (t *timelineDB) GetHomeTimeline(ctx context.Context, accountID string, maxI
Order("status.id DESC")
if maxID == "" {
+ const future = 24 * time.Hour
+
var err error
- // don't return statuses more than five minutes in the future
- maxID, err = id.NewULIDFromTime(time.Now().Add(5 * time.Minute))
+
+ // don't return statuses more than 24hr in the future
+ maxID, err = id.NewULIDFromTime(time.Now().Add(future))
if err != nil {
return nil, err
}
@@ -138,15 +141,16 @@ func (t *timelineDB) GetPublicTimeline(ctx context.Context, maxID string, sinceI
TableExpr("? AS ?", bun.Ident("statuses"), bun.Ident("status")).
Column("status.id").
Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic).
- WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_id")).
- WhereGroup(" AND ", whereEmptyOrNull("status.in_reply_to_uri")).
WhereGroup(" AND ", whereEmptyOrNull("status.boost_of_id")).
Order("status.id DESC")
if maxID == "" {
+ const future = 24 * time.Hour
+
var err error
- // don't return statuses more than five minutes in the future
- maxID, err = id.NewULIDFromTime(time.Now().Add(5 * time.Minute))
+
+ // don't return statuses more than 24hr in the future
+ maxID, err = id.NewULIDFromTime(time.Now().Add(future))
if err != nil {
return nil, err
}
diff --git a/internal/db/bundb/timeline_test.go b/internal/db/bundb/timeline_test.go
index 5a447111c..d6632b38c 100644
--- a/internal/db/bundb/timeline_test.go
+++ b/internal/db/bundb/timeline_test.go
@@ -34,15 +34,32 @@ type TimelineTestSuite struct {
}
func (suite *TimelineTestSuite) TestGetPublicTimeline() {
- ctx := context.Background()
+ var count int
+ for _, status := range suite.testStatuses {
+ if status.Visibility == gtsmodel.VisibilityPublic &&
+ status.BoostOfID == "" {
+ count++
+ }
+ }
+
+ ctx := context.Background()
s, err := suite.db.GetPublicTimeline(ctx, "", "", "", 20, false)
suite.NoError(err)
- suite.Len(s, 6)
+ suite.Len(s, count)
}
func (suite *TimelineTestSuite) TestGetPublicTimelineWithFutureStatus() {
+ var count int
+
+ for _, status := range suite.testStatuses {
+ if status.Visibility == gtsmodel.VisibilityPublic &&
+ status.BoostOfID == "" {
+ count++
+ }
+ }
+
ctx := context.Background()
futureStatus := getFutureStatus()
@@ -53,7 +70,7 @@ func (suite *TimelineTestSuite) TestGetPublicTimelineWithFutureStatus() {
suite.NoError(err)
suite.NotContains(s, futureStatus)
- suite.Len(s, 6)
+ suite.Len(s, count)
}
func (suite *TimelineTestSuite) TestGetHomeTimeline() {
diff --git a/internal/db/media.go b/internal/db/media.go
index d86f9fe84..05609ba52 100644
--- a/internal/db/media.go
+++ b/internal/db/media.go
@@ -29,6 +29,9 @@ type Media interface {
// GetAttachmentByID gets a single attachment by its ID.
GetAttachmentByID(ctx context.Context, id string) (*gtsmodel.MediaAttachment, Error)
+ // GetAttachmentsByIDs fetches a list of media attachments for given IDs.
+ GetAttachmentsByIDs(ctx context.Context, ids []string) ([]*gtsmodel.MediaAttachment, error)
+
// PutAttachment inserts the given attachment into the database.
PutAttachment(ctx context.Context, media *gtsmodel.MediaAttachment) error
diff --git a/internal/db/mention.go b/internal/db/mention.go
index d66394a5d..348f946a2 100644
--- a/internal/db/mention.go
+++ b/internal/db/mention.go
@@ -30,4 +30,10 @@ type Mention interface {
// GetMentions gets multiple mentions.
GetMentions(ctx context.Context, ids []string) ([]*gtsmodel.Mention, Error)
+
+ // PutMention will insert the given mention into the database.
+ PutMention(ctx context.Context, mention *gtsmodel.Mention) error
+
+ // DeleteMentionByID will delete mention with given ID from the database.
+ DeleteMentionByID(ctx context.Context, id string) error
}
diff --git a/internal/db/notification.go b/internal/db/notification.go
index 18e40b4c1..fd3affe90 100644
--- a/internal/db/notification.go
+++ b/internal/db/notification.go
@@ -28,14 +28,17 @@ type Notification interface {
// GetNotifications returns a slice of notifications that pertain to the given accountID.
//
// Returned notifications will be ordered ID descending (ie., highest/newest to lowest/oldest).
- GetNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error)
+ GetAccountNotifications(ctx context.Context, accountID string, excludeTypes []string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error)
// GetNotification returns one notification according to its id.
- GetNotification(ctx context.Context, id string) (*gtsmodel.Notification, Error)
+ GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, Error)
- // DeleteNotification deletes one notification according to its id,
+ // PutNotification will insert the given notification into the database.
+ PutNotification(ctx context.Context, notif *gtsmodel.Notification) error
+
+ // DeleteNotificationByID deletes one notification according to its id,
// and removes that notification from the in-memory cache.
- DeleteNotification(ctx context.Context, id string) Error
+ DeleteNotificationByID(ctx context.Context, id string) Error
// DeleteNotifications mass deletes notifications targeting targetAccountID
// and/or originating from originAccountID.
@@ -50,7 +53,7 @@ type Notification interface {
// originate from originAccountID will be deleted.
//
// At least one parameter must not be an empty string.
- DeleteNotifications(ctx context.Context, targetAccountID string, originAccountID string) Error
+ DeleteNotifications(ctx context.Context, types []string, targetAccountID string, originAccountID string) Error
// DeleteNotificationsForStatus deletes all notifications that relate to
// the given statusID. This function is useful when a status has been deleted,
diff --git a/internal/db/relationship.go b/internal/db/relationship.go
index d13a73dea..838647154 100644
--- a/internal/db/relationship.go
+++ b/internal/db/relationship.go
@@ -25,42 +25,86 @@
// Relationship contains functions for getting or modifying the relationship between two accounts.
type Relationship interface {
- // IsBlocked checks whether account 1 has a block in place against account2.
- // If eitherDirection is true, then the function returns true if account1 blocks account2, OR if account2 blocks account1.
- IsBlocked(ctx context.Context, account1 string, account2 string, eitherDirection bool) (bool, Error)
+ // IsBlocked checks whether source account has a block in place against target.
+ IsBlocked(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error)
+
+ // IsEitherBlocked checks whether there is a block in place between either of account1 and account2.
+ IsEitherBlocked(ctx context.Context, accountID1 string, accountID2 string) (bool, error)
+
+ // GetBlockByID fetches block with given ID from the database.
+ GetBlockByID(ctx context.Context, id string) (*gtsmodel.Block, error)
+
+ // GetBlockByURI fetches block with given AP URI from the database.
+ GetBlockByURI(ctx context.Context, uri string) (*gtsmodel.Block, error)
// GetBlock returns the block from account1 targeting account2, if it exists, or an error if it doesn't.
- //
- // Because this is slower than Blocked, only use it if you need the actual Block struct for some reason,
- // not if you're just checking for the existence of a block.
- GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, Error)
+ GetBlock(ctx context.Context, account1 string, account2 string) (*gtsmodel.Block, error)
// PutBlock attempts to place the given account block in the database.
- PutBlock(ctx context.Context, block *gtsmodel.Block) Error
+ PutBlock(ctx context.Context, block *gtsmodel.Block) error
// DeleteBlockByID removes block with given ID from the database.
- DeleteBlockByID(ctx context.Context, id string) Error
+ DeleteBlockByID(ctx context.Context, id string) error
// DeleteBlockByURI removes block with given AP URI from the database.
- DeleteBlockByURI(ctx context.Context, uri string) Error
+ DeleteBlockByURI(ctx context.Context, uri string) error
- // DeleteBlocksByOriginAccountID removes any blocks with accountID equal to originAccountID.
- DeleteBlocksByOriginAccountID(ctx context.Context, originAccountID string) Error
-
- // DeleteBlocksByTargetAccountID removes any blocks with given targetAccountID.
- DeleteBlocksByTargetAccountID(ctx context.Context, targetAccountID string) Error
+ // DeleteAccountBlocks will delete all database blocks to / from the given account ID.
+ DeleteAccountBlocks(ctx context.Context, accountID string) error
// GetRelationship retrieves the relationship of the targetAccount to the requestingAccount.
GetRelationship(ctx context.Context, requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error)
- // IsFollowing returns true if sourceAccount follows target account, or an error if something goes wrong while finding out.
- IsFollowing(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
+ // GetFollowByID fetches follow with given ID from the database.
+ GetFollowByID(ctx context.Context, id string) (*gtsmodel.Follow, error)
- // IsFollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out.
- IsFollowRequested(ctx context.Context, sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
+ // GetFollowByURI fetches follow with given AP URI from the database.
+ GetFollowByURI(ctx context.Context, uri string) (*gtsmodel.Follow, error)
+
+ // GetFollow retrieves a follow if it exists between source and target accounts.
+ GetFollow(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.Follow, error)
+
+ // GetFollowRequestByID fetches follow request with given ID from the database.
+ GetFollowRequestByID(ctx context.Context, id string) (*gtsmodel.FollowRequest, error)
+
+ // GetFollowRequestByURI fetches follow request with given AP URI from the database.
+ GetFollowRequestByURI(ctx context.Context, uri string) (*gtsmodel.FollowRequest, error)
+
+ // GetFollowRequest retrieves a follow request if it exists between source and target accounts.
+ GetFollowRequest(ctx context.Context, sourceAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, error)
+
+ // IsFollowing returns true if sourceAccount follows target account, or an error if something goes wrong while finding out.
+ IsFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error)
// IsMutualFollowing returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out.
- IsMutualFollowing(ctx context.Context, account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, Error)
+ IsMutualFollowing(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error)
+
+ // IsFollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out.
+ IsFollowRequested(ctx context.Context, sourceAccountID string, targetAccountID string) (bool, Error)
+
+ // PutFollow attempts to place the given account follow in the database.
+ PutFollow(ctx context.Context, follow *gtsmodel.Follow) error
+
+ // PutFollowRequest attempts to place the given account follow request in the database.
+ PutFollowRequest(ctx context.Context, follow *gtsmodel.FollowRequest) error
+
+ // DeleteFollowByID deletes a follow from the database with the given ID.
+ DeleteFollowByID(ctx context.Context, id string) error
+
+ // DeleteFollowByURI deletes a follow from the database with the given URI.
+ DeleteFollowByURI(ctx context.Context, uri string) error
+
+ // DeleteFollowRequestByID deletes a follow request from the database with the given ID.
+ DeleteFollowRequestByID(ctx context.Context, id string) error
+
+ // DeleteFollowRequestByURI deletes a follow request from the database with the given URI.
+ DeleteFollowRequestByURI(ctx context.Context, uri string) error
+
+ // DeleteAccountFollows will delete all database follows to / from the given account ID.
+ DeleteAccountFollows(ctx context.Context, accountID string) error
+
+ // DeleteAccountFollowRequests will delete all database follow requests to / from the given account ID.
+ DeleteAccountFollowRequests(ctx context.Context, accountID string) error
// AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table.
// In other words, it should create the follow, and delete the existing follow request.
@@ -69,65 +113,41 @@ type Relationship interface {
AcceptFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error)
// RejectFollowRequest fetches a follow request from the database, and then deletes it.
- //
- // The deleted follow request will be returned so that further processing can be done on it.
- RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (*gtsmodel.FollowRequest, Error)
+ RejectFollowRequest(ctx context.Context, originAccountID string, targetAccountID string) Error
- // GetFollows returns a slice of follows owned by the given accountID, and/or
- // targeting the given account id.
- //
- // If accountID is set and targetAccountID isn't, then all follows created by
- // accountID will be returned.
- //
- // If targetAccountID is set and accountID isn't, then all follows targeting
- // targetAccountID will be returned.
- //
- // If both accountID and targetAccountID are set, then only 0 or 1 follows will
- // be in the returned slice.
- GetFollows(ctx context.Context, accountID string, targetAccountID string) ([]*gtsmodel.Follow, Error)
+ // GetAccountFollows returns a slice of follows owned by the given accountID.
+ GetAccountFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
- // GetLocalFollowersIDs returns a list of local account IDs which follow the
- // targetAccountID. The returned IDs are not guaranteed to be ordered in any
- // particular way, so take care.
- GetLocalFollowersIDs(ctx context.Context, targetAccountID string) ([]string, Error)
+ // GetAccountLocalFollows returns a slice of follows owned by the given accountID, only including follows from this instance.
+ GetAccountLocalFollows(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
- // CountFollows is like GetFollows, but just counts rather than returning.
- CountFollows(ctx context.Context, accountID string, targetAccountID string) (int, Error)
+ // CountAccountFollows returns the amount of accounts that the given accountID is following.
+ CountAccountFollows(ctx context.Context, accountID string) (int, error)
- // GetFollowRequests returns a slice of follows requests owned by the given
- // accountID, and/or targeting the given account id.
- //
- // If accountID is set and targetAccountID isn't, then all requests created by
- // accountID will be returned.
- //
- // If targetAccountID is set and accountID isn't, then all requests targeting
- // targetAccountID will be returned.
- //
- // If both accountID and targetAccountID are set, then only 0 or 1 requests will
- // be in the returned slice.
- GetFollowRequests(ctx context.Context, accountID string, targetAccountID string) ([]*gtsmodel.FollowRequest, Error)
+ // CountAccountLocalFollows returns the amount of accounts that the given accountID is following, only including follows from this instance.
+ CountAccountLocalFollows(ctx context.Context, accountID string) (int, error)
- // CountFollowRequests is like GetFollowRequests, but just counts rather than returning.
- CountFollowRequests(ctx context.Context, accountID string, targetAccountID string) (int, Error)
+ // GetAccountFollowers fetches follows that target given accountID.
+ GetAccountFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
- // Unfollow removes a follow targeting targetAccountID and originating
- // from originAccountID.
- //
- // If a follow was removed this way, the AP URI of the follow will be
- // returned to the caller, so that further processing can take place
- // if necessary.
- //
- // If no follow was removed this way, the returned string will be empty.
- Unfollow(ctx context.Context, originAccountID string, targetAccountID string) (string, Error)
+ // GetAccountLocalFollowers fetches follows that target given accountID, only including follows from this instance.
+ GetAccountLocalFollowers(ctx context.Context, accountID string) ([]*gtsmodel.Follow, error)
- // UnfollowRequest removes a follow request targeting targetAccountID
- // and originating from originAccountID.
- //
- // If a follow request was removed this way, the AP URI of the follow
- // request will be returned to the caller, so that further processing
- // can take place if necessary.
- //
- // If no follow request was removed this way, the returned string will
- // be empty.
- UnfollowRequest(ctx context.Context, originAccountID string, targetAccountID string) (string, Error)
+ // CountAccountFollowers returns the amounts that the given ID is followed by.
+ CountAccountFollowers(ctx context.Context, accountID string) (int, error)
+
+ // CountAccountLocalFollowers returns the amounts that the given ID is followed by, only including follows from this instance.
+ CountAccountLocalFollowers(ctx context.Context, accountID string) (int, error)
+
+ // GetAccountFollowRequests returns all follow requests targeting the given account.
+ GetAccountFollowRequests(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error)
+
+ // GetAccountFollowRequesting returns all follow requests originating from the given account.
+ GetAccountFollowRequesting(ctx context.Context, accountID string) ([]*gtsmodel.FollowRequest, error)
+
+ // CountAccountFollowRequests returns number of follow requests targeting the given account.
+ CountAccountFollowRequests(ctx context.Context, accountID string) (int, error)
+
+ // CountAccountFollowerRequests returns number of follow requests originating from the given account.
+ CountAccountFollowRequesting(ctx context.Context, accountID string) (int, error)
}
diff --git a/internal/db/status.go b/internal/db/status.go
index 16728983a..fdce19094 100644
--- a/internal/db/status.go
+++ b/internal/db/status.go
@@ -37,6 +37,9 @@ type Status interface {
// GetStatusByURL returns one status from the database, with no rel fields populated, only their linking ID / URIs
GetStatusByURL(ctx context.Context, uri string) (*gtsmodel.Status, Error)
+ // PopulateStatus ensures that all sub-models of a status are populated (e.g. mentions, attachments, etc).
+ PopulateStatus(ctx context.Context, status *gtsmodel.Status) error
+
// PutStatus stores one status in the database.
PutStatus(ctx context.Context, status *gtsmodel.Status) Error
diff --git a/internal/db/statusfave.go b/internal/db/statusfave.go
index 2d55592aa..b435da514 100644
--- a/internal/db/statusfave.go
+++ b/internal/db/statusfave.go
@@ -24,22 +24,22 @@
)
type StatusFave interface {
- // GetStatusFave returns one status fave with the given id.
- GetStatusFave(ctx context.Context, id string) (*gtsmodel.StatusFave, Error)
-
// GetStatusFaveByAccountID gets one status fave created by the given
// accountID, targeting the given statusID.
- GetStatusFaveByAccountID(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, Error)
+ GetStatusFave(ctx context.Context, accountID string, statusID string) (*gtsmodel.StatusFave, Error)
+
+ // GetStatusFave returns one status fave with the given id.
+ GetStatusFaveByID(ctx context.Context, id string) (*gtsmodel.StatusFave, Error)
// GetStatusFaves returns a slice of faves/likes of the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
- GetStatusFaves(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, Error)
+ GetStatusFavesForStatus(ctx context.Context, statusID string) ([]*gtsmodel.StatusFave, Error)
// PutStatusFave inserts the given statusFave into the database.
PutStatusFave(ctx context.Context, statusFave *gtsmodel.StatusFave) Error
// DeleteStatusFave deletes one status fave with the given id.
- DeleteStatusFave(ctx context.Context, id string) Error
+ DeleteStatusFaveByID(ctx context.Context, id string) Error
// DeleteStatusFaves mass deletes status faves targeting targetAccountID
// and/or originating from originAccountID and/or faving statusID.
diff --git a/internal/federation/dereferencing/status.go b/internal/federation/dereferencing/status.go
index 55a41ff09..8e202d585 100644
--- a/internal/federation/dereferencing/status.go
+++ b/internal/federation/dereferencing/status.go
@@ -383,7 +383,7 @@ func (d *deref) populateStatusMentions(ctx context.Context, status *gtsmodel.Sta
TargetAccountURL: targetAccount.URL,
}
- if err := d.db.Put(ctx, newMention); err != nil {
+ if err := d.db.PutMention(ctx, newMention); err != nil {
return fmt.Errorf("populateStatusMentions: error creating mention: %s", err)
}
diff --git a/internal/federation/federatingdb/accept.go b/internal/federation/federatingdb/accept.go
index 8f0e40694..1c39bc134 100644
--- a/internal/federation/federatingdb/accept.go
+++ b/internal/federation/federatingdb/accept.go
@@ -25,8 +25,6 @@
"codeberg.org/gruf/go-logger/v2/level"
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/ap"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/messages"
"github.com/superseriousbusiness/gotosocial/internal/uris"
@@ -63,16 +61,16 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA
acceptedObjectIRI := iter.GetIRI()
if uris.IsFollowPath(acceptedObjectIRI) {
// ACCEPT FOLLOW
- gtsFollowRequest := >smodel.FollowRequest{}
- if err := f.state.DB.GetWhere(ctx, []db.Where{{Key: "uri", Value: acceptedObjectIRI.String()}}, gtsFollowRequest); err != nil {
+ followReq, err := f.state.DB.GetFollowRequestByURI(ctx, acceptedObjectIRI.String())
+ if err != nil {
return fmt.Errorf("ACCEPT: couldn't get follow request with id %s from the database: %s", acceptedObjectIRI.String(), err)
}
// make sure the addressee of the original follow is the same as whatever inbox this landed in
- if gtsFollowRequest.AccountID != receivingAccount.ID {
+ if followReq.AccountID != receivingAccount.ID {
return errors.New("ACCEPT: follow object account and inbox account were not the same")
}
- follow, err := f.state.DB.AcceptFollowRequest(ctx, gtsFollowRequest.AccountID, gtsFollowRequest.TargetAccountID)
+ follow, err := f.state.DB.AcceptFollowRequest(ctx, followReq.AccountID, followReq.TargetAccountID)
if err != nil {
return err
}
diff --git a/internal/federation/federatingdb/create.go b/internal/federation/federatingdb/create.go
index a82cd5cdf..166ea34a9 100644
--- a/internal/federation/federatingdb/create.go
+++ b/internal/federation/federatingdb/create.go
@@ -262,7 +262,7 @@ func (f *federatingDB) activityFollow(ctx context.Context, asType vocab.Type, re
followRequest.ID = id.NewULID()
- if err := f.state.DB.Put(ctx, followRequest); err != nil {
+ if err := f.state.DB.PutFollowRequest(ctx, followRequest); err != nil {
return fmt.Errorf("activityFollow: database error inserting follow request: %s", err)
}
diff --git a/internal/federation/federatingdb/followers.go b/internal/federation/federatingdb/followers.go
index d829a304c..4ca2e2683 100644
--- a/internal/federation/federatingdb/followers.go
+++ b/internal/federation/federatingdb/followers.go
@@ -38,7 +38,7 @@ func (f *federatingDB) Followers(ctx context.Context, actorIRI *url.URL) (follow
return nil, err
}
- follows, err := f.state.DB.GetFollows(ctx, "", acct.ID)
+ follows, err := f.state.DB.GetAccountFollowers(ctx, acct.ID)
if err != nil {
return nil, fmt.Errorf("Followers: db error getting followers for account id %s: %s", acct.ID, err)
}
diff --git a/internal/federation/federatingdb/following.go b/internal/federation/federatingdb/following.go
index 7b0db5660..391a2f810 100644
--- a/internal/federation/federatingdb/following.go
+++ b/internal/federation/federatingdb/following.go
@@ -38,7 +38,7 @@ func (f *federatingDB) Following(ctx context.Context, actorIRI *url.URL) (follow
return nil, err
}
- follows, err := f.state.DB.GetFollows(ctx, acct.ID, "")
+ follows, err := f.state.DB.GetAccountFollows(ctx, acct.ID)
if err != nil {
return nil, fmt.Errorf("Following: db error getting following for account id %s: %w", acct.ID, err)
}
diff --git a/internal/federation/federatingdb/inbox.go b/internal/federation/federatingdb/inbox.go
index 954a192e4..18974ba79 100644
--- a/internal/federation/federatingdb/inbox.go
+++ b/internal/federation/federatingdb/inbox.go
@@ -89,7 +89,7 @@ func (f *federatingDB) InboxesForIRI(c context.Context, iri *url.URL) (inboxIRIs
return nil, fmt.Errorf("couldn't find local account with username %s: %s", localAccountUsername, err)
}
- follows, err := f.state.DB.GetFollows(c, "", account.ID)
+ follows, err := f.state.DB.GetAccountFollowers(c, account.ID)
if err != nil {
return nil, fmt.Errorf("couldn't get followers of local account %s: %s", localAccountUsername, err)
}
diff --git a/internal/federation/federatingdb/reject.go b/internal/federation/federatingdb/reject.go
index e19224283..ceaee83ef 100644
--- a/internal/federation/federatingdb/reject.go
+++ b/internal/federation/federatingdb/reject.go
@@ -25,8 +25,6 @@
"codeberg.org/gruf/go-logger/v2/level"
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/ap"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/uris"
)
@@ -62,17 +60,17 @@ func (f *federatingDB) Reject(ctx context.Context, reject vocab.ActivityStreamsR
rejectedObjectIRI := iter.GetIRI()
if uris.IsFollowPath(rejectedObjectIRI) {
// REJECT FOLLOW
- gtsFollowRequest := >smodel.FollowRequest{}
- if err := f.state.DB.GetWhere(ctx, []db.Where{{Key: "uri", Value: rejectedObjectIRI.String()}}, gtsFollowRequest); err != nil {
+ followReq, err := f.state.DB.GetFollowRequestByURI(ctx, rejectedObjectIRI.String())
+ if err != nil {
return fmt.Errorf("Reject: couldn't get follow request with id %s from the database: %s", rejectedObjectIRI.String(), err)
}
// make sure the addressee of the original follow is the same as whatever inbox this landed in
- if gtsFollowRequest.AccountID != receivingAccount.ID {
+ if followReq.AccountID != receivingAccount.ID {
return errors.New("Reject: follow object account and inbox account were not the same")
}
- if _, err := f.state.DB.RejectFollowRequest(ctx, gtsFollowRequest.AccountID, gtsFollowRequest.TargetAccountID); err != nil {
+ if err := f.state.DB.RejectFollowRequest(ctx, followReq.AccountID, followReq.TargetAccountID); err != nil {
return err
}
@@ -101,7 +99,7 @@ func (f *federatingDB) Reject(ctx context.Context, reject vocab.ActivityStreamsR
if gtsFollow.AccountID != receivingAccount.ID {
return errors.New("Reject: follow object account and inbox account were not the same")
}
- if _, err := f.state.DB.RejectFollowRequest(ctx, gtsFollow.AccountID, gtsFollow.TargetAccountID); err != nil {
+ if err := f.state.DB.RejectFollowRequest(ctx, gtsFollow.AccountID, gtsFollow.TargetAccountID); err != nil {
return err
}
diff --git a/internal/federation/federatingdb/undo.go b/internal/federation/federatingdb/undo.go
index 5e75b22f7..517aa9cc6 100644
--- a/internal/federation/federatingdb/undo.go
+++ b/internal/federation/federatingdb/undo.go
@@ -26,7 +26,6 @@
"github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log"
)
@@ -80,11 +79,11 @@ func (f *federatingDB) Undo(ctx context.Context, undo vocab.ActivityStreamsUndo)
return errors.New("UNDO: follow object account and inbox account were not the same")
}
// delete any existing FOLLOW
- if err := f.state.DB.DeleteWhere(ctx, []db.Where{{Key: "uri", Value: gtsFollow.URI}}, >smodel.Follow{}); err != nil {
+ if err := f.state.DB.DeleteFollowByURI(ctx, gtsFollow.URI); err != nil && !errors.Is(err, db.ErrNoEntries) {
return fmt.Errorf("UNDO: db error removing follow: %s", err)
}
// delete any existing FOLLOW REQUEST
- if err := f.state.DB.DeleteWhere(ctx, []db.Where{{Key: "uri", Value: gtsFollow.URI}}, >smodel.FollowRequest{}); err != nil {
+ if err := f.state.DB.DeleteFollowRequestByURI(ctx, gtsFollow.URI); err != nil && !errors.Is(err, db.ErrNoEntries) {
return fmt.Errorf("UNDO: db error removing follow request: %s", err)
}
l.Debug("follow undone")
diff --git a/internal/federation/federatingdb/util.go b/internal/federation/federatingdb/util.go
index df446a5f3..8ad209f5e 100644
--- a/internal/federation/federatingdb/util.go
+++ b/internal/federation/federatingdb/util.go
@@ -231,7 +231,7 @@ func (f *federatingDB) ActorForInbox(ctx context.Context, inboxIRI *url.URL) (ac
// getAccountForIRI returns the account that corresponds to or owns the given IRI.
func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gtsmodel.Account, error) {
var (
- acct = >smodel.Account{}
+ acct *gtsmodel.Account
err error
)
@@ -245,7 +245,7 @@ func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gts
}
return acct, nil
case uris.IsInboxPath(iri):
- if err = f.state.DB.GetWhere(ctx, []db.Where{{Key: "inbox_uri", Value: iri.String()}}, acct); err != nil {
+ if acct, err = f.state.DB.GetAccountByInboxURI(ctx, iri.String()); err != nil {
if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to inbox %s", iri.String())
}
@@ -253,7 +253,7 @@ func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gts
}
return acct, nil
case uris.IsOutboxPath(iri):
- if err = f.state.DB.GetWhere(ctx, []db.Where{{Key: "outbox_uri", Value: iri.String()}}, acct); err != nil {
+ if acct, err = f.state.DB.GetAccountByOutboxURI(ctx, iri.String()); err != nil {
if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to outbox %s", iri.String())
}
@@ -261,7 +261,7 @@ func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gts
}
return acct, nil
case uris.IsFollowersPath(iri):
- if err = f.state.DB.GetWhere(ctx, []db.Where{{Key: "followers_uri", Value: iri.String()}}, acct); err != nil {
+ if acct, err = f.state.DB.GetAccountByFollowersURI(ctx, iri.String()); err != nil {
if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to followers_uri %s", iri.String())
}
@@ -269,7 +269,7 @@ func (f *federatingDB) getAccountForIRI(ctx context.Context, iri *url.URL) (*gts
}
return acct, nil
case uris.IsFollowingPath(iri):
- if err = f.state.DB.GetWhere(ctx, []db.Where{{Key: "following_uri", Value: iri.String()}}, acct); err != nil {
+ if acct, err = f.state.DB.GetAccountByFollowingURI(ctx, iri.String()); err != nil {
if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to following_uri %s", iri.String())
}
diff --git a/internal/federation/federatingprotocol.go b/internal/federation/federatingprotocol.go
index 37d43276b..ed0a216fe 100644
--- a/internal/federation/federatingprotocol.go
+++ b/internal/federation/federatingprotocol.go
@@ -283,7 +283,7 @@ func (f *federator) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, er
return false, errors.New("requesting account not set on request context, so couldn't determine blocks")
}
// the receiver shouldn't block the sender
- blocked, err = f.db.IsBlocked(ctx, receivingAccount.ID, requestingAccount.ID, false)
+ blocked, err = f.db.IsBlocked(ctx, receivingAccount.ID, requestingAccount.ID)
if err != nil {
return false, fmt.Errorf("error checking user-level blocks: %s", err)
}
@@ -309,7 +309,7 @@ func (f *federator) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, er
for _, involvedAccountID := range deduped {
// the involved account shouldn't block whoever is making this request
- blocked, err = f.db.IsBlocked(ctx, involvedAccountID, requestingAccount.ID, false)
+ blocked, err = f.db.IsBlocked(ctx, involvedAccountID, requestingAccount.ID)
if err != nil {
return false, fmt.Errorf("error checking user-level otherInvolvedIRI blocks: %s", err)
}
@@ -318,7 +318,7 @@ func (f *federator) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, er
}
// whoever is receiving this request shouldn't block the involved account
- blocked, err = f.db.IsBlocked(ctx, receivingAccount.ID, involvedAccountID, false)
+ blocked, err = f.db.IsBlocked(ctx, receivingAccount.ID, involvedAccountID)
if err != nil {
return false, fmt.Errorf("error checking user-level otherInvolvedIRI blocks: %s", err)
}
diff --git a/internal/gtscontext/context.go b/internal/gtscontext/context.go
new file mode 100644
index 000000000..7d4a44774
--- /dev/null
+++ b/internal/gtscontext/context.go
@@ -0,0 +1,43 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package gtscontext
+
+import "context"
+
+// package private context key type.
+type ctxkey uint
+
+const (
+ // context keys.
+ _ ctxkey = iota
+ barebonesKey
+)
+
+// Barebones returns whether the "barebones" context key has been set. This
+// can be used to indicate to the database, for example, that only a barebones
+// model need be returned, Allowing it to skip populating sub models.
+func Barebones(ctx context.Context) bool {
+ _, ok := ctx.Value(barebonesKey).(struct{})
+ return ok
+}
+
+// SetBarebones sets the "barebones" context flag and returns this wrapped context.
+// See Barebones() for further information on the "barebones" context flag..
+func SetBarebones(ctx context.Context) context.Context {
+ return context.WithValue(ctx, barebonesKey, struct{}{})
+}
diff --git a/internal/gtsmodel/account.go b/internal/gtsmodel/account.go
index 54d8c3d71..42e4f5ea6 100644
--- a/internal/gtsmodel/account.go
+++ b/internal/gtsmodel/account.go
@@ -27,6 +27,7 @@
"time"
"github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
)
// Account represents either a local or a remote fediverse account, gotosocial or otherwise (mastodon, pleroma, etc).
@@ -35,8 +36,8 @@ type Account struct {
CreatedAt time.Time `validate:"-" bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item created.
UpdatedAt time.Time `validate:"-" bun:"type:timestamptz,nullzero,notnull,default:current_timestamp"` // when was item was last updated.
FetchedAt time.Time `validate:"required_with=Domain" bun:"type:timestamptz,nullzero"` // when was item (remote) last fetched.
- Username string `validate:"required" bun:",nullzero,notnull,unique:userdomain"` // Username of the account, should just be a string of [a-zA-Z0-9_]. Can be added to domain to create the full username in the form ``[username]@[domain]`` eg., ``user_96@example.org``. Username and domain should be unique *with* each other
- Domain string `validate:"omitempty,fqdn" bun:",nullzero,unique:userdomain"` // Domain of the account, will be null if this is a local account, otherwise something like ``example.org``. Should be unique with username.
+ Username string `validate:"required" bun:",nullzero,notnull,unique:usernamedomain"` // Username of the account, should just be a string of [a-zA-Z0-9_]. Can be added to domain to create the full username in the form ``[username]@[domain]`` eg., ``user_96@example.org``. Username and domain should be unique *with* each other
+ Domain string `validate:"omitempty,fqdn" bun:",nullzero,unique:usernamedomain"` // Domain of the account, will be null if this is a local account, otherwise something like ``example.org``. Should be unique with username.
AvatarMediaAttachmentID string `validate:"omitempty,ulid" bun:"type:CHAR(26),nullzero"` // Database ID of the media attachment, if present
AvatarMediaAttachment *MediaAttachment `validate:"-" bun:"rel:belongs-to"` // MediaAttachment corresponding to avatarMediaAttachmentID
AvatarRemoteURL string `validate:"omitempty,url" bun:",nullzero"` // For a non-local account, where can the header be fetched?
@@ -70,8 +71,8 @@ type Account struct {
FollowersURI string `validate:"required_without=Domain,omitempty,url" bun:",nullzero,unique"` // URI for getting the followers list of this account
FeaturedCollectionURI string `validate:"required_without=Domain,omitempty,url" bun:",nullzero,unique"` // URL for getting the featured collection list of this account
ActorType string `validate:"oneof=Application Group Organization Person Service" bun:",nullzero,notnull"` // What type of activitypub actor is this account?
- PrivateKey *rsa.PrivateKey `validate:"required_without=Domain"` // Privatekey for validating activitypub requests, will only be defined for local accounts
- PublicKey *rsa.PublicKey `validate:"required"` // Publickey for encoding activitypub requests, will be defined for both local and remote accounts
+ PrivateKey *rsa.PrivateKey `validate:"required_without=Domain" bun:""` // Privatekey for validating activitypub requests, will only be defined for local accounts
+ PublicKey *rsa.PublicKey `validate:"required" bun:",notnull,unique"` // Publickey for encoding activitypub requests, will be defined for both local and remote accounts
PublicKeyURI string `validate:"required,url" bun:",nullzero,notnull,unique"` // Web-reachable location of this account's public key
SensitizedAt time.Time `validate:"-" bun:"type:timestamptz,nullzero"` // When was this account set to have all its media shown as sensitive?
SilencedAt time.Time `validate:"-" bun:"type:timestamptz,nullzero"` // When was this account silenced (eg., statuses only visible to followers, not public)?
@@ -82,23 +83,44 @@ type Account struct {
}
// IsLocal returns whether account is a local user account.
-func (a Account) IsLocal() bool {
+func (a *Account) IsLocal() bool {
return a.Domain == "" || a.Domain == config.GetHost() || a.Domain == config.GetAccountDomain()
}
// IsRemote returns whether account is a remote user account.
-func (a Account) IsRemote() bool {
+func (a *Account) IsRemote() bool {
return !a.IsLocal()
}
// IsInstance returns whether account is an instance internal actor account.
-func (a Account) IsInstance() bool {
+func (a *Account) IsInstance() bool {
return a.Username == a.Domain ||
a.FollowersURI == "" ||
a.FollowingURI == "" ||
(a.Username == "internal.fetch" && strings.Contains(a.Note, "internal service actor"))
}
+// EmojisPopulated returns whether emojis are populated according to current EmojiIDs.
+func (a *Account) EmojisPopulated() bool {
+ if len(a.EmojiIDs) != len(a.Emojis) {
+ // this is the quickest indicator.
+ return false
+ }
+
+ // Emojis must be in same order.
+ for i, id := range a.EmojiIDs {
+ if a.Emojis[i] == nil {
+ log.Warnf(nil, "nil emoji in slice for account %s", a.URI)
+ continue
+ }
+ if a.Emojis[i].ID != id {
+ return false
+ }
+ }
+
+ return true
+}
+
// AccountToEmoji is an intermediate struct to facilitate the many2many relationship between an account and one or more emojis.
type AccountToEmoji struct {
AccountID string `validate:"ulid,required" bun:"type:CHAR(26),unique:accountemoji,nullzero,notnull"`
diff --git a/internal/gtsmodel/status.go b/internal/gtsmodel/status.go
index eda24e316..d04deecef 100644
--- a/internal/gtsmodel/status.go
+++ b/internal/gtsmodel/status.go
@@ -19,6 +19,8 @@
import (
"time"
+
+ "github.com/superseriousbusiness/gotosocial/internal/log"
)
// Status represents a user-created 'post' or 'status' in the database, either remote or local
@@ -65,27 +67,120 @@ type Status struct {
Likeable *bool `validate:"-" bun:",notnull"` // This status can be liked/faved
}
-/*
- The below functions are added onto the gtsmodel status so that it satisfies
- the Timelineable interface in internal/timeline.
-*/
-
+// GetID implements timeline.Timelineable{}.
func (s *Status) GetID() string {
return s.ID
}
+// GetAccountID implements timeline.Timelineable{}.
func (s *Status) GetAccountID() string {
return s.AccountID
}
+// GetBoostID implements timeline.Timelineable{}.
func (s *Status) GetBoostOfID() string {
return s.BoostOfID
}
+// GetBoostOfAccountID implements timeline.Timelineable{}.
func (s *Status) GetBoostOfAccountID() string {
return s.BoostOfAccountID
}
+// AttachmentsPopulated returns whether media attachments are populated according to current AttachmentIDs.
+func (s *Status) AttachmentsPopulated() bool {
+ if len(s.AttachmentIDs) != len(s.Attachments) {
+ // this is the quickest indicator.
+ return false
+ }
+
+ // Attachments must be in same order.
+ for i, id := range s.AttachmentIDs {
+ if s.Attachments[i] == nil {
+ log.Warnf(nil, "nil attachment in slice for status %s", s.URI)
+ continue
+ }
+ if s.Attachments[i].ID != id {
+ return false
+ }
+ }
+
+ return true
+}
+
+// TagsPopulated returns whether tags are populated according to current TagIDs.
+func (s *Status) TagsPopulated() bool {
+ if len(s.TagIDs) != len(s.Tags) {
+ // this is the quickest indicator.
+ return false
+ }
+
+ // Tags must be in same order.
+ for i, id := range s.TagIDs {
+ if s.Tags[i] == nil {
+ log.Warnf(nil, "nil tag in slice for status %s", s.URI)
+ continue
+ }
+ if s.Tags[i].ID != id {
+ return false
+ }
+ }
+
+ return true
+}
+
+// MentionsPopulated returns whether mentions are populated according to current MentionIDs.
+func (s *Status) MentionsPopulated() bool {
+ if len(s.MentionIDs) != len(s.Mentions) {
+ // this is the quickest indicator.
+ return false
+ }
+
+ // Mentions must be in same order.
+ for i, id := range s.MentionIDs {
+ if s.Mentions[i] == nil {
+ log.Warnf(nil, "nil mention in slice for status %s", s.URI)
+ continue
+ }
+ if s.Mentions[i].ID != id {
+ return false
+ }
+ }
+
+ return true
+}
+
+// EmojisPopulated returns whether emojis are populated according to current EmojiIDs.
+func (s *Status) EmojisPopulated() bool {
+ if len(s.EmojiIDs) != len(s.Emojis) {
+ // this is the quickest indicator.
+ return false
+ }
+
+ // Emojis must be in same order.
+ for i, id := range s.EmojiIDs {
+ if s.Emojis[i] == nil {
+ log.Warnf(nil, "nil emoji in slice for status %s", s.URI)
+ continue
+ }
+ if s.Emojis[i].ID != id {
+ return false
+ }
+ }
+
+ return true
+}
+
+// MentionsAccount returns whether status mentions the given account ID.
+func (s *Status) MentionsAccount(id string) bool {
+ for _, mention := range s.Mentions {
+ if mention.TargetAccountID == id {
+ return true
+ }
+ }
+ return false
+}
+
// StatusToTag is an intermediate struct to facilitate the many2many relationship between a status and one or more tags.
type StatusToTag struct {
StatusID string `validate:"ulid,required" bun:"type:CHAR(26),unique:statustag,nullzero,notnull"`
diff --git a/internal/processing/account/account.go b/internal/processing/account/account.go
index ef6bc0477..328b682ac 100644
--- a/internal/processing/account/account.go
+++ b/internal/processing/account/account.go
@@ -36,7 +36,7 @@ type Processor struct {
tc typeutils.TypeConverter
mediaManager media.Manager
oauthServer oauth.Server
- filter visibility.Filter
+ filter *visibility.Filter
formatter text.Formatter
federator federation.Federator
parseMention gtsmodel.ParseMentionFunc
@@ -49,6 +49,7 @@ func New(
mediaManager media.Manager,
oauthServer oauth.Server,
federator federation.Federator,
+ filter *visibility.Filter,
parseMention gtsmodel.ParseMentionFunc,
) Processor {
return Processor{
@@ -56,7 +57,7 @@ func New(
tc: tc,
mediaManager: mediaManager,
oauthServer: oauthServer,
- filter: visibility.NewFilter(state.DB),
+ filter: filter,
formatter: text.NewFormatter(state.DB),
federator: federator,
parseMention: parseMention,
diff --git a/internal/processing/account/account_test.go b/internal/processing/account/account_test.go
index 0898e707b..eed6ad7e3 100644
--- a/internal/processing/account/account_test.go
+++ b/internal/processing/account/account_test.go
@@ -34,6 +34,7 @@
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
+ "github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -101,7 +102,9 @@ func (suite *AccountStandardTestSuite) SetupTest() {
suite.federator = testrig.NewTestFederator(&suite.state, suite.transportController, suite.mediaManager)
suite.sentEmails = make(map[string]string)
suite.emailSender = testrig.NewEmailSender("../../../web/template/", suite.sentEmails)
- suite.accountProcessor = account.New(&suite.state, suite.tc, suite.mediaManager, suite.oauthServer, suite.federator, processing.GetParseMentionFunc(suite.db, suite.federator))
+
+ filter := visibility.NewFilter(&suite.state)
+ suite.accountProcessor = account.New(&suite.state, suite.tc, suite.mediaManager, suite.oauthServer, suite.federator, filter, processing.GetParseMentionFunc(suite.db, suite.federator))
testrig.StandardDBSetup(suite.db, nil)
testrig.StandardStorageSetup(suite.storage, "../../../testrig/media")
}
diff --git a/internal/processing/account/bookmarks.go b/internal/processing/account/bookmarks.go
index 56f0fc9e0..32075f592 100644
--- a/internal/processing/account/bookmarks.go
+++ b/internal/processing/account/bookmarks.go
@@ -56,7 +56,7 @@ func (p *Processor) BookmarksGet(ctx context.Context, requestingAccount *gtsmode
return nil, gtserror.NewErrorInternalError(err) // A real error has occurred.
}
- visible, err := p.filter.StatusVisible(ctx, status, requestingAccount)
+ visible, err := p.filter.StatusVisible(ctx, requestingAccount, status)
if err != nil {
log.Errorf(ctx, "error checking bookmarked status visibility: %s", err)
continue
diff --git a/internal/processing/account/delete.go b/internal/processing/account/delete.go
index 9c59e1b99..f3dfecc7b 100644
--- a/internal/processing/account/delete.go
+++ b/internal/processing/account/delete.go
@@ -150,25 +150,25 @@ func (p *Processor) deleteUserAndTokensForAccount(ctx context.Context, account *
// - Follow requests created by account.
func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel.Account) error {
// Delete follows targeting this account.
- followedBy, err := p.state.DB.GetFollows(ctx, "", account.ID)
+ followedBy, err := p.state.DB.GetAccountFollowers(ctx, account.ID)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return fmt.Errorf("deleteAccountFollows: db error getting follows targeting account %s: %w", account.ID, err)
}
for _, follow := range followedBy {
- if _, err := p.state.DB.Unfollow(ctx, follow.AccountID, account.ID); err != nil {
+ if err := p.state.DB.DeleteFollowByID(ctx, follow.ID); err != nil {
return fmt.Errorf("deleteAccountFollows: db error unfollowing account followedBy: %w", err)
}
}
// Delete follow requests targeting this account.
- followRequestedBy, err := p.state.DB.GetFollowRequests(ctx, "", account.ID)
+ followRequestedBy, err := p.state.DB.GetAccountFollowRequests(ctx, account.ID)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return fmt.Errorf("deleteAccountFollows: db error getting follow requests targeting account %s: %w", account.ID, err)
}
for _, followRequest := range followRequestedBy {
- if _, err := p.state.DB.UnfollowRequest(ctx, followRequest.AccountID, account.ID); err != nil {
+ if err := p.state.DB.DeleteFollowRequestByID(ctx, followRequest.ID); err != nil {
return fmt.Errorf("deleteAccountFollows: db error unfollowing account followRequestedBy: %w", err)
}
}
@@ -183,7 +183,7 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel.
)
// Delete follows originating from this account.
- following, err := p.state.DB.GetFollows(ctx, account.ID, "")
+ following, err := p.state.DB.GetAccountFollows(ctx, account.ID)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return fmt.Errorf("deleteAccountFollows: db error getting follows owned by account %s: %w", account.ID, err)
}
@@ -191,15 +191,9 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel.
// For each follow owned by this account, unfollow
// and process side effects (noop if remote account).
for _, follow := range following {
- if uri, err := p.state.DB.Unfollow(ctx, account.ID, follow.TargetAccountID); err != nil {
+ if err := p.state.DB.DeleteFollowByID(ctx, follow.ID); err != nil {
return fmt.Errorf("deleteAccountFollows: db error unfollowing account: %w", err)
- } else if uri == "" {
- // There was no follow after all.
- // Some race condition? Skip.
- log.WithContext(ctx).WithField("follow", follow).Warn("Unfollow did not return uri, likely race condition")
- continue
}
-
if msg := unfollowSideEffects(ctx, account, follow); msg != nil {
// There was a side effect to process.
msgs = append(msgs, *msg)
@@ -207,7 +201,7 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel.
}
// Delete follow requests originating from this account.
- followRequesting, err := p.state.DB.GetFollowRequests(ctx, account.ID, "")
+ followRequesting, err := p.state.DB.GetAccountFollowRequesting(ctx, account.ID)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return fmt.Errorf("deleteAccountFollows: db error getting follow requests owned by account %s: %w", account.ID, err)
}
@@ -215,23 +209,15 @@ func (p *Processor) deleteAccountFollows(ctx context.Context, account *gtsmodel.
// For each follow owned by this account, unfollow
// and process side effects (noop if remote account).
for _, followRequest := range followRequesting {
- uri, err := p.state.DB.UnfollowRequest(ctx, account.ID, followRequest.TargetAccountID)
- if err != nil {
- return fmt.Errorf("deleteAccountFollows: db error unfollowRequesting account: %w", err)
- }
-
- if uri == "" {
- // There was no follow request after all.
- // Some race condition? Skip.
- log.WithContext(ctx).WithField("followRequest", followRequest).Warn("UnfollowRequest did not return uri, likely race condition")
- continue
+ if err := p.state.DB.DeleteFollowRequestByID(ctx, followRequest.ID); err != nil {
+ return fmt.Errorf("deleteAccountFollows: db error unfollowingRequesting account: %w", err)
}
// Dummy out a follow so our side effects func
// has something to work with. This follow will
// never enter the db, it's just for convenience.
follow := >smodel.Follow{
- URI: uri,
+ URI: followRequest.URI,
AccountID: followRequest.AccountID,
Account: followRequest.Account,
TargetAccountID: followRequest.TargetAccountID,
@@ -284,16 +270,9 @@ func (p *Processor) unfollowSideEffectsFunc(deletedAccount *gtsmodel.Account) fu
}
func (p *Processor) deleteAccountBlocks(ctx context.Context, account *gtsmodel.Account) error {
- // Delete blocks created by this account.
- if err := p.state.DB.DeleteBlocksByOriginAccountID(ctx, account.ID); err != nil {
- return fmt.Errorf("deleteAccountBlocks: db error deleting blocks created by account %s: %w", account.ID, err)
+ if err := p.state.DB.DeleteAccountBlocks(ctx, account.ID); err != nil {
+ return fmt.Errorf("deleteAccountBlocks: db error deleting account blocks for %s: %w", account.ID, err)
}
-
- // Delete blocks targeting this account.
- if err := p.state.DB.DeleteBlocksByTargetAccountID(ctx, account.ID); err != nil {
- return fmt.Errorf("deleteAccountBlocks: db error deleting blocks targeting account %s: %w", account.ID, err)
- }
-
return nil
}
@@ -386,13 +365,13 @@ func (p *Processor) deleteAccountStatuses(ctx context.Context, account *gtsmodel
}
func (p *Processor) deleteAccountNotifications(ctx context.Context, account *gtsmodel.Account) error {
- // Delete all notifications targeting given account.
- if err := p.state.DB.DeleteNotifications(ctx, account.ID, ""); err != nil && !errors.Is(err, db.ErrNoEntries) {
+ // Delete all notifications of all types targeting given account.
+ if err := p.state.DB.DeleteNotifications(ctx, nil, account.ID, ""); err != nil && !errors.Is(err, db.ErrNoEntries) {
return err
}
- // Delete all notifications originating from given account.
- if err := p.state.DB.DeleteNotifications(ctx, "", account.ID); err != nil && !errors.Is(err, db.ErrNoEntries) {
+ // Delete all notifications of all types originating from given account.
+ if err := p.state.DB.DeleteNotifications(ctx, nil, "", account.ID); err != nil && !errors.Is(err, db.ErrNoEntries) {
return err
}
diff --git a/internal/processing/account/follow.go b/internal/processing/account/follow.go
index 8d053e92a..ab8fecd94 100644
--- a/internal/processing/account/follow.go
+++ b/internal/processing/account/follow.go
@@ -40,7 +40,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
}
// Check if a follow exists already.
- if follows, err := p.state.DB.IsFollowing(ctx, requestingAccount, targetAccount); err != nil {
+ if follows, err := p.state.DB.IsFollowing(ctx, requestingAccount.ID, targetAccount.ID); err != nil {
err = fmt.Errorf("FollowCreate: db error checking follow: %w", err)
return nil, gtserror.NewErrorInternalError(err)
} else if follows {
@@ -49,7 +49,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
}
// Check if a follow request exists already.
- if followRequested, err := p.state.DB.IsFollowRequested(ctx, requestingAccount, targetAccount); err != nil {
+ if followRequested, err := p.state.DB.IsFollowRequested(ctx, requestingAccount.ID, targetAccount.ID); err != nil {
err = fmt.Errorf("FollowCreate: db error checking follow request: %w", err)
return nil, gtserror.NewErrorInternalError(err)
} else if followRequested {
@@ -75,7 +75,7 @@ func (p *Processor) FollowCreate(ctx context.Context, requestingAccount *gtsmode
Notify: form.Notify,
}
- if err := p.state.DB.Put(ctx, fr); err != nil {
+ if err := p.state.DB.PutFollowRequest(ctx, fr); err != nil {
err = fmt.Errorf("FollowCreate: error creating follow request in db: %s", err)
return nil, gtserror.NewErrorInternalError(err)
}
@@ -141,7 +141,7 @@ func (p *Processor) getFollowTarget(ctx context.Context, requestingAccountID str
}
// Do nothing if a block exists in either direction between accounts.
- if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccountID, targetAccountID, true); err != nil {
+ if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccountID, targetAccountID); err != nil {
err = fmt.Errorf("db error checking block between accounts: %w", err)
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
@@ -173,12 +173,30 @@ func (p *Processor) getFollowTarget(ctx context.Context, requestingAccountID str
// messages will be returned which should then be processed by a client
// api worker.
func (p *Processor) unfollow(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) ([]messages.FromClientAPI, error) {
- msgs := []messages.FromClientAPI{}
+ var msgs []messages.FromClientAPI
- if fURI, err := p.state.DB.Unfollow(ctx, requestingAccount.ID, targetAccount.ID); err != nil {
- err = fmt.Errorf("unfollow: error deleting follow from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err)
+ // Get follow from requesting account to target account.
+ follow, err := p.state.DB.GetFollow(ctx, requestingAccount.ID, targetAccount.ID)
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ err = fmt.Errorf("unfollow: error getting follow from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err)
return nil, err
- } else if fURI != "" {
+ }
+
+ if follow != nil {
+ // Delete known follow from database with ID.
+ err = p.state.DB.DeleteFollowByID(ctx, follow.ID)
+ if err != nil {
+ if !errors.Is(err, db.ErrNoEntries) {
+ err = fmt.Errorf("unfollow: error deleting request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err)
+ return nil, err
+ }
+
+ // If err == db.ErrNoEntries here then it
+ // indicates a race condition with another
+ // unfollow for the same requester->target.
+ return msgs, nil
+ }
+
// Follow status changed, process side effects.
msgs = append(msgs, messages.FromClientAPI{
APObjectType: ap.ActivityFollow,
@@ -186,25 +204,43 @@ func (p *Processor) unfollow(ctx context.Context, requestingAccount *gtsmodel.Ac
GTSModel: >smodel.Follow{
AccountID: requestingAccount.ID,
TargetAccountID: targetAccount.ID,
- URI: fURI,
+ URI: follow.URI,
},
OriginAccount: requestingAccount,
TargetAccount: targetAccount,
})
}
- if frURI, err := p.state.DB.UnfollowRequest(ctx, requestingAccount.ID, targetAccount.ID); err != nil {
- err = fmt.Errorf("unfollow: error deleting follow request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err)
+ // Get follow request from requesting account to target account.
+ followReq, err := p.state.DB.GetFollowRequest(ctx, requestingAccount.ID, targetAccount.ID)
+ if err != nil && !errors.Is(err, db.ErrNoEntries) {
+ err = fmt.Errorf("unfollow: error getting follow request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err)
return nil, err
- } else if frURI != "" {
- // Follow request status changed, process side effects.
+ }
+
+ if followReq != nil {
+ // Delete known follow request from database with ID.
+ err = p.state.DB.DeleteFollowRequestByID(ctx, followReq.ID)
+ if err != nil {
+ if !errors.Is(err, db.ErrNoEntries) {
+ err = fmt.Errorf("unfollow: error deleting follow request from %s targeting %s: %w", requestingAccount.ID, targetAccount.ID, err)
+ return nil, err
+ }
+
+ // If err == db.ErrNoEntries here then it
+ // indicates a race condition with another
+ // unfollow for the same requester->target.
+ return msgs, nil
+ }
+
+ // Follow status changed, process side effects.
msgs = append(msgs, messages.FromClientAPI{
APObjectType: ap.ActivityFollow,
APActivityType: ap.ActivityUndo,
GTSModel: >smodel.Follow{
AccountID: requestingAccount.ID,
TargetAccountID: targetAccount.ID,
- URI: frURI,
+ URI: followReq.URI,
},
OriginAccount: requestingAccount,
TargetAccount: targetAccount,
diff --git a/internal/processing/account/get.go b/internal/processing/account/get.go
index cfa6d4896..84d00c46b 100644
--- a/internal/processing/account/get.go
+++ b/internal/processing/account/get.go
@@ -73,7 +73,7 @@ func (p *Processor) getFor(ctx context.Context, requestingAccount *gtsmodel.Acco
var blocked bool
var err error
if requestingAccount != nil {
- blocked, err = p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccount.ID, true)
+ blocked, err = p.state.DB.IsEitherBlocked(ctx, requestingAccount.ID, targetAccount.ID)
if err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error checking account block: %s", err))
}
diff --git a/internal/processing/account/relationships.go b/internal/processing/account/relationships.go
index 4afe36afe..d12d989ef 100644
--- a/internal/processing/account/relationships.go
+++ b/internal/processing/account/relationships.go
@@ -31,7 +31,7 @@
// FollowersGet fetches a list of the target account's followers.
func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
- if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {
+ if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccount.ID, targetAccountID); err != nil {
err = fmt.Errorf("FollowersGet: db error checking block: %w", err)
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
@@ -39,7 +39,7 @@ func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmode
return nil, gtserror.NewErrorNotFound(err)
}
- follows, err := p.state.DB.GetFollows(ctx, "", targetAccountID)
+ follows, err := p.state.DB.GetAccountFollowers(ctx, targetAccountID)
if err != nil {
if !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("FollowersGet: db error getting followers: %w", err)
@@ -53,7 +53,7 @@ func (p *Processor) FollowersGet(ctx context.Context, requestingAccount *gtsmode
// FollowingGet fetches a list of the accounts that target account is following.
func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
- if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {
+ if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccount.ID, targetAccountID); err != nil {
err = fmt.Errorf("FollowingGet: db error checking block: %w", err)
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
@@ -61,7 +61,7 @@ func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmode
return nil, gtserror.NewErrorNotFound(err)
}
- follows, err := p.state.DB.GetFollows(ctx, targetAccountID, "")
+ follows, err := p.state.DB.GetAccountFollows(ctx, targetAccountID)
if err != nil {
if !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("FollowingGet: db error getting followers: %w", err)
@@ -70,7 +70,7 @@ func (p *Processor) FollowingGet(ctx context.Context, requestingAccount *gtsmode
return []apimodel.Account{}, nil
}
- return p.accountsFromFollows(ctx, follows, requestingAccount.ID)
+ return p.targetAccountsFromFollows(ctx, follows, requestingAccount.ID)
}
// RelationshipGet returns a relationship model describing the relationship of the targetAccount to the Authed account.
@@ -101,7 +101,7 @@ func (p *Processor) accountsFromFollows(ctx context.Context, follows []*gtsmodel
continue
}
- if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccountID, follow.AccountID, true); err != nil {
+ if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccountID, follow.AccountID); err != nil {
err = fmt.Errorf("accountsFromFollows: db error checking block: %w", err)
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
@@ -113,8 +113,35 @@ func (p *Processor) accountsFromFollows(ctx context.Context, follows []*gtsmodel
err = fmt.Errorf("accountsFromFollows: error converting account to api account: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
+
+ accounts = append(accounts, *account)
+ }
+ return accounts, nil
+}
+
+func (p *Processor) targetAccountsFromFollows(ctx context.Context, follows []*gtsmodel.Follow, requestingAccountID string) ([]apimodel.Account, gtserror.WithCode) {
+ accounts := make([]apimodel.Account, 0, len(follows))
+ for _, follow := range follows {
+ if follow.TargetAccount == nil {
+ // No account set for some reason; just skip.
+ log.WithContext(ctx).WithField("follow", follow).Warn("follow had no associated target account")
+ continue
+ }
+
+ if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccountID, follow.TargetAccountID); err != nil {
+ err = fmt.Errorf("targetAccountsFromFollows: db error checking block: %w", err)
+ return nil, gtserror.NewErrorInternalError(err)
+ } else if blocked {
+ continue
+ }
+
+ account, err := p.tc.AccountToAPIAccountPublic(ctx, follow.TargetAccount)
+ if err != nil {
+ err = fmt.Errorf("targetAccountsFromFollows: error converting account to api account: %w", err)
+ return nil, gtserror.NewErrorInternalError(err)
+ }
+
accounts = append(accounts, *account)
}
-
return accounts, nil
}
diff --git a/internal/processing/account/statuses.go b/internal/processing/account/statuses.go
index 9ff23ad4b..0b4ee5a2a 100644
--- a/internal/processing/account/statuses.go
+++ b/internal/processing/account/statuses.go
@@ -19,6 +19,7 @@
import (
"context"
+ "errors"
"fmt"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@@ -32,10 +33,11 @@
// the account given in authed.
func (p *Processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, excludeReblogs bool, maxID string, minID string, pinned bool, mediaOnly bool, publicOnly bool) (*apimodel.PageableResponse, gtserror.WithCode) {
if requestingAccount != nil {
- if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, targetAccountID, true); err != nil {
+ if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccount.ID, targetAccountID); err != nil {
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
- return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
+ err := errors.New("block exists between accounts")
+ return nil, gtserror.NewErrorNotFound(err)
}
}
@@ -57,14 +59,10 @@ func (p *Processor) StatusesGet(ctx context.Context, requestingAccount *gtsmodel
return nil, gtserror.NewErrorInternalError(err)
}
- // Filtering + serialization process is the same for
- // either pinned status queries or 'normal' ones.
- filtered := make([]*gtsmodel.Status, 0, len(statuses))
- for _, s := range statuses {
- visible, err := p.filter.StatusVisible(ctx, s, requestingAccount)
- if err == nil && visible {
- filtered = append(filtered, s)
- }
+ // Filtering + serialization process is the same for either pinned status queries or 'normal' ones.
+ filtered, err := p.filter.StatusesVisible(ctx, requestingAccount, statuses)
+ if err != nil {
+ return nil, gtserror.NewErrorInternalError(err)
}
count := len(filtered)
diff --git a/internal/processing/fedi/common.go b/internal/processing/fedi/common.go
index 4a83c2f97..91b3030e1 100644
--- a/internal/processing/fedi/common.go
+++ b/internal/processing/fedi/common.go
@@ -45,7 +45,7 @@ func (p *Processor) authenticate(ctx context.Context, requestedUsername string)
return
}
- blocked, err := p.state.DB.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true)
+ blocked, err := p.state.DB.IsEitherBlocked(ctx, requestedAccount.ID, requestingAccount.ID)
if err != nil {
errWithCode = gtserror.NewErrorInternalError(err)
return
diff --git a/internal/processing/fedi/fedi.go b/internal/processing/fedi/fedi.go
index 935eebd4d..92a23a543 100644
--- a/internal/processing/fedi/fedi.go
+++ b/internal/processing/fedi/fedi.go
@@ -28,15 +28,15 @@ type Processor struct {
state *state.State
federator federation.Federator
tc typeutils.TypeConverter
- filter visibility.Filter
+ filter *visibility.Filter
}
// New returns a new fedi processor.
-func New(state *state.State, tc typeutils.TypeConverter, federator federation.Federator) Processor {
+func New(state *state.State, tc typeutils.TypeConverter, federator federation.Federator, filter *visibility.Filter) Processor {
return Processor{
state: state,
federator: federator,
tc: tc,
- filter: visibility.NewFilter(state.DB),
+ filter: filter,
}
}
diff --git a/internal/processing/fedi/status.go b/internal/processing/fedi/status.go
index 2595bef7f..072ff6aaf 100644
--- a/internal/processing/fedi/status.go
+++ b/internal/processing/fedi/status.go
@@ -44,7 +44,7 @@ func (p *Processor) StatusGet(ctx context.Context, requestedUsername string, req
return nil, gtserror.NewErrorNotFound(fmt.Errorf("status with id %s does not belong to account with id %s", status.ID, requestedAccount.ID))
}
- visible, err := p.filter.StatusVisible(ctx, status, requestingAccount)
+ visible, err := p.filter.StatusVisible(ctx, requestingAccount, status)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -82,7 +82,7 @@ func (p *Processor) StatusRepliesGet(ctx context.Context, requestedUsername stri
return nil, gtserror.NewErrorNotFound(fmt.Errorf("status with id %s does not belong to account with id %s", status.ID, requestedAccount.ID))
}
- visible, err := p.filter.StatusVisible(ctx, status, requestingAccount)
+ visible, err := p.filter.StatusVisible(ctx, requestedAccount, status)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -143,13 +143,13 @@ func (p *Processor) StatusRepliesGet(ctx context.Context, requestedUsername stri
}
// only show replies that the status owner can see
- visibleToStatusOwner, err := p.filter.StatusVisible(ctx, r, requestedAccount)
+ visibleToStatusOwner, err := p.filter.StatusVisible(ctx, requestedAccount, r)
if err != nil || !visibleToStatusOwner {
continue
}
// only show replies that the requester can see
- visibleToRequester, err := p.filter.StatusVisible(ctx, r, requestingAccount)
+ visibleToRequester, err := p.filter.StatusVisible(ctx, requestingAccount, r)
if err != nil || !visibleToRequester {
continue
}
diff --git a/internal/processing/fedi/user.go b/internal/processing/fedi/user.go
index 62518ad6f..3343ae8bc 100644
--- a/internal/processing/fedi/user.go
+++ b/internal/processing/fedi/user.go
@@ -62,7 +62,7 @@ func (p *Processor) UserGet(ctx context.Context, requestedUsername string, reque
return nil, gtserror.NewErrorUnauthorized(err)
}
- blocked, err := p.state.DB.IsBlocked(ctx, requestedAccount.ID, requestingAccount.ID, true)
+ blocked, err := p.state.DB.IsEitherBlocked(ctx, requestedAccount.ID, requestingAccount.ID)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
diff --git a/internal/processing/followrequest.go b/internal/processing/followrequest.go
index 7b8c36050..6587b73bb 100644
--- a/internal/processing/followrequest.go
+++ b/internal/processing/followrequest.go
@@ -31,7 +31,7 @@
)
func (p *Processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]apimodel.Account, gtserror.WithCode) {
- followRequests, err := p.state.DB.GetFollowRequests(ctx, "", auth.Account.ID)
+ followRequests, err := p.state.DB.GetAccountFollowRequests(ctx, auth.Account.ID)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -49,8 +49,10 @@ func (p *Processor) FollowRequestsGet(ctx context.Context, auth *oauth.Auth) ([]
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
+
accts = append(accts, *apiAcct)
}
+
return accts, nil
}
@@ -79,7 +81,12 @@ func (p *Processor) FollowRequestAccept(ctx context.Context, auth *oauth.Auth, a
}
func (p *Processor) FollowRequestReject(ctx context.Context, auth *oauth.Auth, accountID string) (*apimodel.Relationship, gtserror.WithCode) {
- followRequest, err := p.state.DB.RejectFollowRequest(ctx, accountID, auth.Account.ID)
+ followRequest, err := p.state.DB.GetFollowRequest(ctx, accountID, auth.Account.ID)
+ if err != nil {
+ return nil, gtserror.NewErrorNotFound(err)
+ }
+
+ err = p.state.DB.RejectFollowRequest(ctx, accountID, auth.Account.ID)
if err != nil {
return nil, gtserror.NewErrorNotFound(err)
}
diff --git a/internal/processing/fromcommon.go b/internal/processing/fromcommon.go
index 49a05da5d..93d61c533 100644
--- a/internal/processing/fromcommon.go
+++ b/internal/processing/fromcommon.go
@@ -39,11 +39,12 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e
if status.Mentions == nil {
// there are mentions but they're not fully populated on the status yet so do this
- menchies, err := p.state.DB.GetMentions(ctx, status.MentionIDs)
+ mentions, err := p.state.DB.GetMentions(ctx, status.MentionIDs)
if err != nil {
return fmt.Errorf("notifyStatus: error getting mentions for status %s from the db: %s", status.ID, err)
}
- status.Mentions = menchies
+
+ status.Mentions = mentions
}
// now we have mentions as full gtsmodel.Mention structs on the status we can continue
@@ -88,7 +89,7 @@ func (p *Processor) notifyStatus(ctx context.Context, status *gtsmodel.Status) e
Status: status,
}
- if err := p.state.DB.Put(ctx, notif); err != nil {
+ if err := p.state.DB.PutNotification(ctx, notif); err != nil {
return fmt.Errorf("notifyStatus: error putting notification in database: %s", err)
}
@@ -130,7 +131,7 @@ func (p *Processor) notifyFollowRequest(ctx context.Context, followRequest *gtsm
OriginAccountID: followRequest.AccountID,
}
- if err := p.state.DB.Put(ctx, notif); err != nil {
+ if err := p.state.DB.PutNotification(ctx, notif); err != nil {
return fmt.Errorf("notifyFollowRequest: error putting notification in database: %s", err)
}
@@ -171,7 +172,7 @@ func (p *Processor) notifyFollow(ctx context.Context, follow *gtsmodel.Follow, t
OriginAccountID: follow.AccountID,
OriginAccount: follow.Account,
}
- if err := p.state.DB.Put(ctx, notif); err != nil {
+ if err := p.state.DB.PutNotification(ctx, notif); err != nil {
return fmt.Errorf("notifyFollow: error putting notification in database: %s", err)
}
@@ -219,7 +220,7 @@ func (p *Processor) notifyFave(ctx context.Context, fave *gtsmodel.StatusFave) e
Status: fave.Status,
}
- if err := p.state.DB.Put(ctx, notif); err != nil {
+ if err := p.state.DB.PutNotification(ctx, notif); err != nil {
return fmt.Errorf("notifyFave: error putting notification in database: %s", err)
}
@@ -293,7 +294,7 @@ func (p *Processor) notifyAnnounce(ctx context.Context, status *gtsmodel.Status)
Status: status,
}
- if err := p.state.DB.Put(ctx, notif); err != nil {
+ if err := p.state.DB.PutNotification(ctx, notif); err != nil {
return fmt.Errorf("notifyAnnounce: error putting notification in database: %s", err)
}
@@ -403,39 +404,39 @@ func (p *Processor) notifyReportClosed(ctx context.Context, report *gtsmodel.Rep
// timelineStatus processes the given new status and inserts it into
// the HOME timelines of accounts that follow the status author.
func (p *Processor) timelineStatus(ctx context.Context, status *gtsmodel.Status) error {
- // make sure the author account is pinned onto the status
if status.Account == nil {
- a, err := p.state.DB.GetAccountByID(ctx, status.AccountID)
- if err != nil {
- return fmt.Errorf("timelineStatus: error getting author account with id %s: %s", status.AccountID, err)
+ // ensure status fully populated (including account)
+ if err := p.state.DB.PopulateStatus(ctx, status); err != nil {
+ return fmt.Errorf("timelineStatus: error populating status with id %s: %w", status.ID, err)
}
- status.Account = a
}
- // Get LOCAL followers of the account that posted the status;
- // we know that remote accounts don't have timelines on this
- // instance, so there's no point selecting them too.
- accountIDs, err := p.state.DB.GetLocalFollowersIDs(ctx, status.AccountID)
+ // get local followers of the account that posted the status
+ follows, err := p.state.DB.GetAccountLocalFollowers(ctx, status.AccountID)
if err != nil {
- return fmt.Errorf("timelineStatus: error getting followers for account id %s: %s", status.AccountID, err)
+ return fmt.Errorf("timelineStatus: error getting followers for account id %s: %w", status.AccountID, err)
}
// If the poster is also local, add a fake entry for them
// so they can see their own status in their timeline.
if status.Account.IsLocal() {
- accountIDs = append(accountIDs, status.AccountID)
+ follows = append(follows, >smodel.Follow{
+ AccountID: status.AccountID,
+ Account: status.Account,
+ })
}
- // Timeline the status for each local following account.
- errors := gtserror.MultiError{}
- for _, accountID := range accountIDs {
- if err := p.timelineStatusForAccount(ctx, status, accountID); err != nil {
- errors.Append(err)
+ var errs gtserror.MultiError
+
+ for _, follow := range follows {
+ // Timeline the status for each local following account.
+ if err := p.timelineStatusForAccount(ctx, follow.Account, status); err != nil {
+ errs.Append(err)
}
}
- if len(errors) != 0 {
- return fmt.Errorf("timelineStatus: one or more errors timelining statuses: %w", errors.Combine())
+ if len(errs) != 0 {
+ return fmt.Errorf("timelineStatus: one or more errors timelining statuses: %w", errs.Combine())
}
return nil
@@ -446,34 +447,28 @@ func (p *Processor) timelineStatus(ctx context.Context, status *gtsmodel.Status)
//
// If the status was inserted into the home timeline of the given account,
// it will also be streamed via websockets to the user.
-func (p *Processor) timelineStatusForAccount(ctx context.Context, status *gtsmodel.Status, accountID string) error {
- // get the timeline owner account
- timelineAccount, err := p.state.DB.GetAccountByID(ctx, accountID)
- if err != nil {
- return fmt.Errorf("timelineStatusForAccount: error getting account for timeline with id %s: %w", accountID, err)
- }
-
+func (p *Processor) timelineStatusForAccount(ctx context.Context, account *gtsmodel.Account, status *gtsmodel.Status) error {
// make sure the status is timelineable
- if timelineable, err := p.filter.StatusHometimelineable(ctx, status, timelineAccount); err != nil {
- return fmt.Errorf("timelineStatusForAccount: error getting timelineability for status for timeline with id %s: %w", accountID, err)
+ if timelineable, err := p.filter.StatusHomeTimelineable(ctx, account, status); err != nil {
+ return fmt.Errorf("timelineStatusForAccount: error getting timelineability for status for timeline with id %s: %w", account.ID, err)
} else if !timelineable {
return nil
}
// stick the status in the timeline for the account and then immediately prepare it so they can see it right away
- if inserted, err := p.statusTimelines.IngestAndPrepare(ctx, status, timelineAccount.ID); err != nil {
+ if inserted, err := p.statusTimelines.IngestAndPrepare(ctx, status, account.ID); err != nil {
return fmt.Errorf("timelineStatusForAccount: error ingesting status %s: %w", status.ID, err)
} else if !inserted {
return nil
}
// the status was inserted so stream it to the user
- apiStatus, err := p.tc.StatusToAPIStatus(ctx, status, timelineAccount)
+ apiStatus, err := p.tc.StatusToAPIStatus(ctx, status, account)
if err != nil {
return fmt.Errorf("timelineStatusForAccount: error converting status %s to frontend representation: %w", status.ID, err)
}
- if err := p.stream.Update(apiStatus, timelineAccount, stream.TimelineHome); err != nil {
+ if err := p.stream.Update(apiStatus, account, stream.TimelineHome); err != nil {
return fmt.Errorf("timelineStatusForAccount: error streaming update for status %s: %w", status.ID, err)
}
@@ -513,8 +508,8 @@ func (p *Processor) wipeStatus(ctx context.Context, statusToDelete *gtsmodel.Sta
}
// delete all mention entries generated by this status
- for _, m := range statusToDelete.MentionIDs {
- if err := p.state.DB.DeleteByID(ctx, m, >smodel.Mention{}); err != nil {
+ for _, id := range statusToDelete.MentionIDs {
+ if err := p.state.DB.DeleteMentionByID(ctx, id); err != nil {
return err
}
}
diff --git a/internal/processing/fromfederator_test.go b/internal/processing/fromfederator_test.go
index d6f4ff555..58d644287 100644
--- a/internal/processing/fromfederator_test.go
+++ b/internal/processing/fromfederator_test.go
@@ -358,10 +358,10 @@ func (suite *FromFederatorTestSuite) TestProcessAccountDelete() {
suite.ErrorIs(err, db.ErrNoEntries)
// the mufos should be gone now too
- satanFollowsZork, err := suite.db.IsFollowing(ctx, deletedAccount, receivingAccount)
+ satanFollowsZork, err := suite.db.IsFollowing(ctx, deletedAccount.ID, receivingAccount.ID)
suite.NoError(err)
suite.False(satanFollowsZork)
- zorkFollowsSatan, err := suite.db.IsFollowing(ctx, receivingAccount, deletedAccount)
+ zorkFollowsSatan, err := suite.db.IsFollowing(ctx, receivingAccount.ID, deletedAccount.ID)
suite.NoError(err)
suite.False(zorkFollowsSatan)
diff --git a/internal/processing/media/getfile.go b/internal/processing/media/getfile.go
index 761a352da..293093ac2 100644
--- a/internal/processing/media/getfile.go
+++ b/internal/processing/media/getfile.go
@@ -63,7 +63,7 @@ func (p *Processor) GetFile(ctx context.Context, requestingAccount *gtsmodel.Acc
// make sure the requesting account and the media account don't block each other
if requestingAccount != nil {
- blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, owningAccountID, true)
+ blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccount.ID, owningAccountID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block status could not be established between accounts %s and %s: %s", owningAccountID, requestingAccount.ID, err))
}
diff --git a/internal/processing/notification.go b/internal/processing/notification.go
index aa81d863a..48c8f92ac 100644
--- a/internal/processing/notification.go
+++ b/internal/processing/notification.go
@@ -30,7 +30,7 @@
)
func (p *Processor) NotificationsGet(ctx context.Context, authed *oauth.Auth, excludeTypes []string, limit int, maxID string, sinceID string) (*apimodel.PageableResponse, gtserror.WithCode) {
- notifs, err := p.state.DB.GetNotifications(ctx, authed.Account.ID, excludeTypes, limit, maxID, sinceID)
+ notifs, err := p.state.DB.GetAccountNotifications(ctx, authed.Account.ID, excludeTypes, limit, maxID, sinceID)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@@ -73,8 +73,8 @@ func (p *Processor) NotificationsGet(ctx context.Context, authed *oauth.Auth, ex
}
func (p *Processor) NotificationsClear(ctx context.Context, authed *oauth.Auth) gtserror.WithCode {
- // Delete all notifications that target the authorized account.
- if err := p.state.DB.DeleteNotifications(ctx, authed.Account.ID, ""); err != nil && !errors.Is(err, db.ErrNoEntries) {
+ // Delete all notifications of all types that target the authorized account.
+ if err := p.state.DB.DeleteNotifications(ctx, nil, authed.Account.ID, ""); err != nil && !errors.Is(err, db.ErrNoEntries) {
return gtserror.NewErrorInternalError(err)
}
diff --git a/internal/processing/processor.go b/internal/processing/processor.go
index 3e3854f69..a61a57f88 100644
--- a/internal/processing/processor.go
+++ b/internal/processing/processor.go
@@ -47,8 +47,8 @@ type Processor struct {
mediaManager mm.Manager
statusTimelines timeline.Manager
state *state.State
- filter visibility.Filter
emailSender email.Sender
+ filter *visibility.Filter
/*
SUB-PROCESSORS
@@ -107,7 +107,7 @@ func NewProcessor(
) *Processor {
parseMentionFunc := GetParseMentionFunc(state.DB, federator)
- filter := visibility.NewFilter(state.DB)
+ filter := visibility.NewFilter(state)
processor := &Processor{
federator: federator,
@@ -126,12 +126,12 @@ func NewProcessor(
}
// sub processors
- processor.account = account.New(state, tc, mediaManager, oauthServer, federator, parseMentionFunc)
+ processor.account = account.New(state, tc, mediaManager, oauthServer, federator, filter, parseMentionFunc)
processor.admin = admin.New(state, tc, mediaManager, federator.TransportController(), emailSender)
- processor.fedi = fedi.New(state, tc, federator)
+ processor.fedi = fedi.New(state, tc, federator, filter)
processor.media = media.New(state, tc, mediaManager, federator.TransportController())
processor.report = report.New(state, tc)
- processor.status = status.New(state, tc, parseMentionFunc)
+ processor.status = status.New(state, tc, filter, parseMentionFunc)
processor.stream = stream.New(state, oauthServer)
processor.user = user.New(state, emailSender)
@@ -139,22 +139,24 @@ func NewProcessor(
}
func (p *Processor) EnqueueClientAPI(ctx context.Context, msgs ...messages.FromClientAPI) {
- log.Trace(ctx, "enqueuing client API")
+ log.Trace(ctx, "enqueuing")
_ = p.state.Workers.ClientAPI.MustEnqueueCtx(ctx, func(ctx context.Context) {
for _, msg := range msgs {
+ log.Trace(ctx, "processing: %+v", msg)
if err := p.ProcessFromClientAPI(ctx, msg); err != nil {
- log.WithContext(ctx).WithField("msg", msg).Errorf("error processing client API message: %v", err)
+ log.Errorf(ctx, "error processing client API message: %v", err)
}
}
})
}
func (p *Processor) EnqueueFederator(ctx context.Context, msgs ...messages.FromFederator) {
- log.Trace(ctx, "enqueuing federator")
+ log.Trace(ctx, "enqueuing")
_ = p.state.Workers.Federator.MustEnqueueCtx(ctx, func(ctx context.Context) {
for _, msg := range msgs {
+ log.Trace(ctx, "processing: %+v", msg)
if err := p.ProcessFromFederator(ctx, msg); err != nil {
- log.WithContext(ctx).WithField("msg", msg).Errorf("error processing federator message: %v", err)
+ log.Errorf(ctx, "error processing federator message: %v", err)
}
}
})
diff --git a/internal/processing/search.go b/internal/processing/search.go
index aebb72ecd..9aa89a17b 100644
--- a/internal/processing/search.go
+++ b/internal/processing/search.go
@@ -177,7 +177,7 @@ func (p *Processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a
*/
for _, foundAccount := range foundAccounts {
// make sure there's no block in either direction between the account and the requester
- blocked, err := p.state.DB.IsBlocked(ctx, authed.Account.ID, foundAccount.ID, true)
+ blocked, err := p.state.DB.IsEitherBlocked(ctx, authed.Account.ID, foundAccount.ID)
if err != nil {
err = fmt.Errorf("SearchGet: error checking block between %s and %s: %s", authed.Account.ID, foundAccount.ID, err)
return nil, gtserror.NewErrorInternalError(err)
@@ -199,7 +199,7 @@ func (p *Processor) SearchGet(ctx context.Context, authed *oauth.Auth, search *a
for _, foundStatus := range foundStatuses {
// make sure each found status is visible to the requester
- visible, err := p.filter.StatusVisible(ctx, foundStatus, authed.Account)
+ visible, err := p.filter.StatusVisible(ctx, authed.Account, foundStatus)
if err != nil {
err = fmt.Errorf("SearchGet: error checking visibility of status %s for account %s: %s", foundStatus.ID, authed.Account.ID, err)
return nil, gtserror.NewErrorInternalError(err)
diff --git a/internal/processing/status/boost.go b/internal/processing/status/boost.go
index f5b5a4052..e5d38d9d2 100644
--- a/internal/processing/status/boost.go
+++ b/internal/processing/status/boost.go
@@ -55,12 +55,11 @@ func (p *Processor) BoostCreate(ctx context.Context, requestingAccount *gtsmodel
targetStatus = targetStatus.BoostOf
}
- boostable, err := p.filter.StatusBoostable(ctx, targetStatus, requestingAccount)
+ boostable, err := p.filter.StatusBoostable(ctx, requestingAccount, targetStatus)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error seeing if status %s is boostable: %s", targetStatus.ID, err))
- }
- if !boostable {
- return nil, gtserror.NewErrorForbidden(errors.New("status is not boostable"))
+ } else if !boostable {
+ return nil, gtserror.NewErrorNotFound(errors.New("status is not boostable"))
}
// it's visible! it's boostable! so let's boost the FUCK out of it
@@ -99,7 +98,7 @@ func (p *Processor) BoostRemove(ctx context.Context, requestingAccount *gtsmodel
return nil, gtserror.NewErrorNotFound(fmt.Errorf("no status owner for status %s", targetStatusID))
}
- visible, err := p.filter.StatusVisible(ctx, targetStatus, requestingAccount)
+ visible, err := p.filter.StatusVisible(ctx, requestingAccount, targetStatus)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error seeing if status %s is visible: %s", targetStatus.ID, err))
}
@@ -180,7 +179,7 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm
targetStatus = boostedStatus
}
- visible, err := p.filter.StatusVisible(ctx, targetStatus, requestingAccount)
+ visible, err := p.filter.StatusVisible(ctx, requestingAccount, targetStatus)
if err != nil {
err = fmt.Errorf("BoostedBy: error seeing if status %s is visible: %s", targetStatus.ID, err)
return nil, gtserror.NewErrorNotFound(err)
@@ -199,7 +198,7 @@ func (p *Processor) StatusBoostedBy(ctx context.Context, requestingAccount *gtsm
// filter account IDs so the user doesn't see accounts they blocked or which blocked them
accountIDs := make([]string, 0, len(statusReblogs))
for _, s := range statusReblogs {
- blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, s.AccountID, true)
+ blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccount.ID, s.AccountID)
if err != nil {
err = fmt.Errorf("BoostedBy: error checking blocks: %s", err)
return nil, gtserror.NewErrorNotFound(err)
diff --git a/internal/processing/status/common.go b/internal/processing/status/common.go
index 5b73e9c94..d6478d35a 100644
--- a/internal/processing/status/common.go
+++ b/internal/processing/status/common.go
@@ -43,12 +43,7 @@ func (p *Processor) getVisibleStatus(ctx context.Context, requestingAccount *gts
return nil, gtserror.NewErrorNotFound(err)
}
- if targetStatus.Account == nil {
- err = fmt.Errorf("getVisibleStatus: no status owner for status %s", targetStatusID)
- return nil, gtserror.NewErrorNotFound(err)
- }
-
- visible, err := p.filter.StatusVisible(ctx, targetStatus, requestingAccount)
+ visible, err := p.filter.StatusVisible(ctx, requestingAccount, targetStatus)
if err != nil {
err = fmt.Errorf("getVisibleStatus: error seeing if status %s is visible: %w", targetStatus.ID, err)
return nil, gtserror.NewErrorNotFound(err)
diff --git a/internal/processing/status/create.go b/internal/processing/status/create.go
index 71db8c18e..2d9c3a196 100644
--- a/internal/processing/status/create.go
+++ b/internal/processing/status/create.go
@@ -133,7 +133,7 @@ func processReplyToID(ctx context.Context, dbService db.DB, form *apimodel.Advan
return gtserror.NewErrorInternalError(err)
}
- if blocked, err := dbService.IsBlocked(ctx, thisAccountID, repliedAccount.ID, true); err != nil {
+ if blocked, err := dbService.IsEitherBlocked(ctx, thisAccountID, repliedAccount.ID); err != nil {
err := fmt.Errorf("db error checking block: %s", err)
return gtserror.NewErrorInternalError(err)
} else if blocked {
diff --git a/internal/processing/status/fave.go b/internal/processing/status/fave.go
index da1bae8a1..77d3f67e9 100644
--- a/internal/processing/status/fave.go
+++ b/internal/processing/status/fave.go
@@ -88,7 +88,7 @@ func (p *Processor) FaveRemove(ctx context.Context, requestingAccount *gtsmodel.
}
// We have a fave to remove.
- if err := p.state.DB.DeleteStatusFave(ctx, existingFave.ID); err != nil {
+ if err := p.state.DB.DeleteStatusFaveByID(ctx, existingFave.ID); err != nil {
err = fmt.Errorf("FaveRemove: error removing status fave: %w", err)
return nil, gtserror.NewErrorInternalError(err)
}
@@ -112,7 +112,7 @@ func (p *Processor) FavedBy(ctx context.Context, requestingAccount *gtsmodel.Acc
return nil, errWithCode
}
- statusFaves, err := p.state.DB.GetStatusFaves(ctx, targetStatus.ID)
+ statusFaves, err := p.state.DB.GetStatusFavesForStatus(ctx, targetStatus.ID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("FavedBy: error seeing who faved status: %s", err))
}
@@ -122,7 +122,7 @@ func (p *Processor) FavedBy(ctx context.Context, requestingAccount *gtsmodel.Acc
// and which don't block them.
apiAccounts := make([]*apimodel.Account, 0, len(statusFaves))
for _, fave := range statusFaves {
- if blocked, err := p.state.DB.IsBlocked(ctx, requestingAccount.ID, fave.AccountID, true); err != nil {
+ if blocked, err := p.state.DB.IsEitherBlocked(ctx, requestingAccount.ID, fave.AccountID); err != nil {
err = fmt.Errorf("FavedBy: error checking blocks: %w", err)
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
@@ -157,7 +157,7 @@ func (p *Processor) getFaveTarget(ctx context.Context, requestingAccount *gtsmod
return nil, nil, gtserror.NewErrorForbidden(err, err.Error())
}
- fave, err := p.state.DB.GetStatusFaveByAccountID(ctx, requestingAccount.ID, targetStatusID)
+ fave, err := p.state.DB.GetStatusFave(ctx, requestingAccount.ID, targetStatusID)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
err = fmt.Errorf("getFaveTarget: error checking existing fave: %w", err)
return nil, nil, gtserror.NewErrorInternalError(err)
diff --git a/internal/processing/status/get.go b/internal/processing/status/get.go
index 251a095de..236f6f126 100644
--- a/internal/processing/status/get.go
+++ b/internal/processing/status/get.go
@@ -54,7 +54,7 @@ func (p *Processor) ContextGet(ctx context.Context, requestingAccount *gtsmodel.
}
for _, status := range parents {
- if v, err := p.filter.StatusVisible(ctx, status, requestingAccount); err == nil && v {
+ if v, err := p.filter.StatusVisible(ctx, requestingAccount, status); err == nil && v {
apiStatus, err := p.tc.StatusToAPIStatus(ctx, status, requestingAccount)
if err == nil {
context.Ancestors = append(context.Ancestors, *apiStatus)
@@ -72,7 +72,7 @@ func (p *Processor) ContextGet(ctx context.Context, requestingAccount *gtsmodel.
}
for _, status := range children {
- if v, err := p.filter.StatusVisible(ctx, status, requestingAccount); err == nil && v {
+ if v, err := p.filter.StatusVisible(ctx, requestingAccount, status); err == nil && v {
apiStatus, err := p.tc.StatusToAPIStatus(ctx, status, requestingAccount)
if err == nil {
context.Descendants = append(context.Descendants, *apiStatus)
diff --git a/internal/processing/status/status.go b/internal/processing/status/status.go
index 8bb9f4cca..2bc1b62ce 100644
--- a/internal/processing/status/status.go
+++ b/internal/processing/status/status.go
@@ -28,17 +28,17 @@
type Processor struct {
state *state.State
tc typeutils.TypeConverter
- filter visibility.Filter
+ filter *visibility.Filter
formatter text.Formatter
parseMention gtsmodel.ParseMentionFunc
}
// New returns a new status processor.
-func New(state *state.State, tc typeutils.TypeConverter, parseMention gtsmodel.ParseMentionFunc) Processor {
+func New(state *state.State, tc typeutils.TypeConverter, filter *visibility.Filter, parseMention gtsmodel.ParseMentionFunc) Processor {
return Processor{
state: state,
tc: tc,
- filter: visibility.NewFilter(state.DB),
+ filter: filter,
formatter: text.NewFormatter(state.DB),
parseMention: parseMention,
}
diff --git a/internal/processing/status/status_test.go b/internal/processing/status/status_test.go
index e7e6d90fe..bef0a6e69 100644
--- a/internal/processing/status/status_test.go
+++ b/internal/processing/status/status_test.go
@@ -29,6 +29,7 @@
"github.com/superseriousbusiness/gotosocial/internal/storage"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
+ "github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
)
@@ -85,7 +86,9 @@ func (suite *StatusStandardTestSuite) SetupTest() {
suite.state.Storage = suite.storage
suite.mediaManager = testrig.NewTestMediaManager(&suite.state)
suite.federator = testrig.NewTestFederator(&suite.state, suite.tc, suite.mediaManager)
- suite.status = status.New(&suite.state, suite.typeConverter, processing.GetParseMentionFunc(suite.db, suite.federator))
+
+ filter := visibility.NewFilter(&suite.state)
+ suite.status = status.New(&suite.state, suite.typeConverter, filter, processing.GetParseMentionFunc(suite.db, suite.federator))
testrig.StandardDBSetup(suite.db, suite.testAccounts)
testrig.StandardStorageSetup(suite.storage, "../../../testrig/media")
diff --git a/internal/processing/statustimeline.go b/internal/processing/statustimeline.go
index 4e7cf8147..4e46b59dc 100644
--- a/internal/processing/statustimeline.go
+++ b/internal/processing/statustimeline.go
@@ -47,9 +47,9 @@ func StatusGrabFunction(database db.DB) timeline.GrabFunction {
return nil, false, fmt.Errorf("statusGrabFunction: error getting statuses from db: %s", err)
}
- items := []timeline.Timelineable{}
- for _, s := range statuses {
- items = append(items, s)
+ items := make([]timeline.Timelineable, len(statuses))
+ for i, s := range statuses {
+ items[i] = s
}
return items, false, nil
@@ -57,7 +57,7 @@ func StatusGrabFunction(database db.DB) timeline.GrabFunction {
}
// StatusFilterFunction returns a function that satisfies the FilterFunction interface in internal/timeline.
-func StatusFilterFunction(database db.DB, filter visibility.Filter) timeline.FilterFunction {
+func StatusFilterFunction(database db.DB, filter *visibility.Filter) timeline.FilterFunction {
return func(ctx context.Context, timelineAccountID string, item timeline.Timelineable) (shouldIndex bool, err error) {
status, ok := item.(*gtsmodel.Status)
if !ok {
@@ -69,7 +69,7 @@ func StatusFilterFunction(database db.DB, filter visibility.Filter) timeline.Fil
return false, fmt.Errorf("statusFilterFunction: error getting account with id %s", timelineAccountID)
}
- timelineable, err := filter.StatusHometimelineable(ctx, status, requestingAccount)
+ timelineable, err := filter.StatusHomeTimelineable(ctx, requestingAccount, status)
if err != nil {
log.Warnf(ctx, "error checking hometimelineability of status %s for account %s: %s", status.ID, timelineAccountID, err)
}
@@ -253,8 +253,7 @@ func (p *Processor) FavedTimelineGet(ctx context.Context, authed *oauth.Auth, ma
func (p *Processor) filterPublicStatuses(ctx context.Context, authed *oauth.Auth, statuses []*gtsmodel.Status) ([]*apimodel.Status, error) {
apiStatuses := []*apimodel.Status{}
for _, s := range statuses {
- targetAccount := >smodel.Account{}
- if err := p.state.DB.GetByID(ctx, s.AccountID, targetAccount); err != nil {
+ if _, err := p.state.DB.GetAccountByID(ctx, s.AccountID); err != nil {
if err == db.ErrNoEntries {
log.Debugf(ctx, "skipping status %s because account %s can't be found in the db", s.ID, s.AccountID)
continue
@@ -262,7 +261,7 @@ func (p *Processor) filterPublicStatuses(ctx context.Context, authed *oauth.Auth
return nil, gtserror.NewErrorInternalError(fmt.Errorf("filterPublicStatuses: error getting status author: %s", err))
}
- timelineable, err := p.filter.StatusPublictimelineable(ctx, s, authed.Account)
+ timelineable, err := p.filter.StatusPublicTimelineable(ctx, authed.Account, s)
if err != nil {
log.Debugf(ctx, "skipping status %s because of an error checking status visibility: %s", s.ID, err)
continue
@@ -286,8 +285,7 @@ func (p *Processor) filterPublicStatuses(ctx context.Context, authed *oauth.Auth
func (p *Processor) filterFavedStatuses(ctx context.Context, authed *oauth.Auth, statuses []*gtsmodel.Status) ([]*apimodel.Status, error) {
apiStatuses := []*apimodel.Status{}
for _, s := range statuses {
- targetAccount := >smodel.Account{}
- if err := p.state.DB.GetByID(ctx, s.AccountID, targetAccount); err != nil {
+ if _, err := p.state.DB.GetAccountByID(ctx, s.AccountID); err != nil {
if err == db.ErrNoEntries {
log.Debugf(ctx, "skipping status %s because account %s can't be found in the db", s.ID, s.AccountID)
continue
@@ -295,7 +293,7 @@ func (p *Processor) filterFavedStatuses(ctx context.Context, authed *oauth.Auth,
return nil, gtserror.NewErrorInternalError(fmt.Errorf("filterPublicStatuses: error getting status author: %s", err))
}
- timelineable, err := p.filter.StatusVisible(ctx, s, authed.Account)
+ timelineable, err := p.filter.StatusVisible(ctx, authed.Account, s)
if err != nil {
log.Debugf(ctx, "skipping status %s because of an error checking status visibility: %s", s.ID, err)
continue
diff --git a/internal/text/goldmark_extension.go b/internal/text/goldmark_extension.go
index 23ac1169d..a12c618dc 100644
--- a/internal/text/goldmark_extension.go
+++ b/internal/text/goldmark_extension.go
@@ -240,7 +240,7 @@ func (r *customRenderer) renderMention(w mdutil.BufWriter, source []byte, node a
n, ok := node.(*mention) // this function is only registered for kindMention
if !ok {
- log.Panic(nil, "type assertion failed")
+ log.Panic(r.ctx, "type assertion failed")
}
text := string(n.Segment.Value(source))
@@ -248,7 +248,7 @@ func (r *customRenderer) renderMention(w mdutil.BufWriter, source []byte, node a
// we don't have much recourse if this fails
if _, err := w.WriteString(html); err != nil {
- log.Errorf(nil, "error writing HTML: %s", err)
+ log.Errorf(r.ctx, "error writing HTML: %s", err)
}
return ast.WalkSkipChildren, nil
}
@@ -260,7 +260,7 @@ func (r *customRenderer) renderHashtag(w mdutil.BufWriter, source []byte, node a
n, ok := node.(*hashtag) // this function is only registered for kindHashtag
if !ok {
- log.Panic(nil, "type assertion failed")
+ log.Panic(r.ctx, "type assertion failed")
}
text := string(n.Segment.Value(source))
@@ -269,7 +269,7 @@ func (r *customRenderer) renderHashtag(w mdutil.BufWriter, source []byte, node a
_, err := w.WriteString(html)
// we don't have much recourse if this fails
if err != nil {
- log.Errorf(nil, "error writing HTML: %s", err)
+ log.Errorf(r.ctx, "error writing HTML: %s", err)
}
return ast.WalkSkipChildren, nil
}
@@ -282,7 +282,7 @@ func (r *customRenderer) renderEmoji(w mdutil.BufWriter, source []byte, node ast
n, ok := node.(*emoji) // this function is only registered for kindEmoji
if !ok {
- log.Panic(nil, "type assertion failed")
+ log.Panic(r.ctx, "type assertion failed")
}
text := string(n.Segment.Value(source))
shortcode := text[1 : len(text)-1]
@@ -307,7 +307,7 @@ func (r *customRenderer) renderEmoji(w mdutil.BufWriter, source []byte, node ast
// we don't have much recourse if this fails
if _, err := w.WriteString(text); err != nil {
- log.Errorf(nil, "error writing HTML: %s", err)
+ log.Errorf(r.ctx, "error writing HTML: %s", err)
}
return ast.WalkSkipChildren, nil
}
diff --git a/internal/text/replace.go b/internal/text/replace.go
index 3c6586e92..e8e02454e 100644
--- a/internal/text/replace.go
+++ b/internal/text/replace.go
@@ -22,6 +22,7 @@
"strings"
"github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/util"
"golang.org/x/text/unicode/norm"
@@ -37,15 +38,15 @@
// replaceMention takes a string in the form @username@domain.com or @localusername
func (r *customRenderer) replaceMention(text string) string {
- menchie, err := r.parseMention(r.ctx, text, r.accountID, r.statusID)
+ mention, err := r.parseMention(r.ctx, text, r.accountID, r.statusID)
if err != nil {
- log.Errorf(nil, "error parsing mention %s from status: %s", text, err)
+ log.Errorf(r.ctx, "error parsing mention %s from status: %s", text, err)
return text
}
if r.statusID != "" {
- if err := r.f.db.Put(r.ctx, menchie); err != nil {
- log.Errorf(nil, "error putting mention in db: %s", err)
+ if err := r.f.db.PutMention(r.ctx, mention); err != nil {
+ log.Errorf(r.ctx, "error putting mention in db: %s", err)
return text
}
}
@@ -53,27 +54,29 @@ func (r *customRenderer) replaceMention(text string) string {
// only append if it's not been listed yet
listed := false
for _, m := range r.result.Mentions {
- if menchie.ID == m.ID {
+ if mention.ID == m.ID {
listed = true
break
}
}
if !listed {
- r.result.Mentions = append(r.result.Mentions, menchie)
+ r.result.Mentions = append(r.result.Mentions, mention)
}
- // make sure we have an account attached to this mention
- if menchie.TargetAccount == nil {
- a, err := r.f.db.GetAccountByID(r.ctx, menchie.TargetAccountID)
+ if mention.TargetAccount == nil {
+ // Fetch mention target account if not yet populated.
+ mention.TargetAccount, err = r.f.db.GetAccountByID(
+ gtscontext.SetBarebones(r.ctx),
+ mention.TargetAccountID,
+ )
if err != nil {
- log.Errorf(nil, "error getting account with id %s from the db: %s", menchie.TargetAccountID, err)
+ log.Errorf(r.ctx, "error populating mention target account: %v", err)
return text
}
- menchie.TargetAccount = a
}
// The mention's target is our target
- targetAccount := menchie.TargetAccount
+ targetAccount := mention.TargetAccount
var b strings.Builder
@@ -105,7 +108,7 @@ func (r *customRenderer) replaceHashtag(text string) string {
tag, err := r.f.db.TagStringToTag(r.ctx, normalized, r.accountID)
if err != nil {
- log.Errorf(nil, "error generating hashtags from status: %s", err)
+ log.Errorf(r.ctx, "error generating hashtags from status: %s", err)
return text
}
@@ -121,7 +124,7 @@ func (r *customRenderer) replaceHashtag(text string) string {
err = r.f.db.Put(r.ctx, tag)
if err != nil {
if !errors.Is(err, db.ErrAlreadyExists) {
- log.Errorf(nil, "error putting tags in db: %s", err)
+ log.Errorf(r.ctx, "error putting tags in db: %s", err)
return text
}
}
diff --git a/internal/timeline/get_test.go b/internal/timeline/get_test.go
index 7b76178ae..071f33aaf 100644
--- a/internal/timeline/get_test.go
+++ b/internal/timeline/get_test.go
@@ -26,7 +26,6 @@
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/processing"
- "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
@@ -42,15 +41,14 @@ func (suite *GetTestSuite) SetupSuite() {
}
func (suite *GetTestSuite) SetupTest() {
- var state state.State
- state.Caches.Init()
+ suite.state.Caches.Init()
testrig.InitTestLog()
testrig.InitTestConfig()
- suite.db = testrig.NewTestDB(&state)
+ suite.db = testrig.NewTestDB(&suite.state)
suite.tc = testrig.NewTestTypeConverter(suite.db)
- suite.filter = visibility.NewFilter(suite.db)
+ suite.filter = visibility.NewFilter(&suite.state)
testrig.StandardDBSetup(suite.db, nil)
diff --git a/internal/timeline/index_test.go b/internal/timeline/index_test.go
index 27e47fb2a..d2d1741f6 100644
--- a/internal/timeline/index_test.go
+++ b/internal/timeline/index_test.go
@@ -25,7 +25,6 @@
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/processing"
- "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
@@ -41,15 +40,14 @@ func (suite *IndexTestSuite) SetupSuite() {
}
func (suite *IndexTestSuite) SetupTest() {
- var state state.State
- state.Caches.Init()
+ suite.state.Caches.Init()
testrig.InitTestLog()
testrig.InitTestConfig()
- suite.db = testrig.NewTestDB(&state)
+ suite.db = testrig.NewTestDB(&suite.state)
suite.tc = testrig.NewTestTypeConverter(suite.db)
- suite.filter = visibility.NewFilter(suite.db)
+ suite.filter = visibility.NewFilter(&suite.state)
testrig.StandardDBSetup(suite.db, nil)
diff --git a/internal/timeline/manager_test.go b/internal/timeline/manager_test.go
index 8614bbc95..8fc4984d1 100644
--- a/internal/timeline/manager_test.go
+++ b/internal/timeline/manager_test.go
@@ -23,7 +23,6 @@
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/processing"
- "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
@@ -39,15 +38,14 @@ func (suite *ManagerTestSuite) SetupSuite() {
}
func (suite *ManagerTestSuite) SetupTest() {
- var state state.State
- state.Caches.Init()
+ suite.state.Caches.Init()
testrig.InitTestLog()
testrig.InitTestConfig()
- suite.db = testrig.NewTestDB(&state)
+ suite.db = testrig.NewTestDB(&suite.state)
suite.tc = testrig.NewTestTypeConverter(suite.db)
- suite.filter = visibility.NewFilter(suite.db)
+ suite.filter = visibility.NewFilter(&suite.state)
testrig.StandardDBSetup(suite.db, nil)
diff --git a/internal/timeline/prune_test.go b/internal/timeline/prune_test.go
index 36aa411b3..7daf88b83 100644
--- a/internal/timeline/prune_test.go
+++ b/internal/timeline/prune_test.go
@@ -25,7 +25,6 @@
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/processing"
- "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
"github.com/superseriousbusiness/gotosocial/testrig"
@@ -41,15 +40,14 @@ func (suite *PruneTestSuite) SetupSuite() {
}
func (suite *PruneTestSuite) SetupTest() {
- var state state.State
- state.Caches.Init()
+ suite.state.Caches.Init()
testrig.InitTestLog()
testrig.InitTestConfig()
- suite.db = testrig.NewTestDB(&state)
+ suite.db = testrig.NewTestDB(&suite.state)
suite.tc = testrig.NewTestTypeConverter(suite.db)
- suite.filter = visibility.NewFilter(suite.db)
+ suite.filter = visibility.NewFilter(&suite.state)
testrig.StandardDBSetup(suite.db, nil)
diff --git a/internal/timeline/timeline_test.go b/internal/timeline/timeline_test.go
index 94e2681c8..2207a3418 100644
--- a/internal/timeline/timeline_test.go
+++ b/internal/timeline/timeline_test.go
@@ -21,6 +21,7 @@
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/timeline"
"github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/visibility"
@@ -29,8 +30,9 @@
type TimelineStandardTestSuite struct {
suite.Suite
db db.DB
+ state state.State
tc typeutils.TypeConverter
- filter visibility.Filter
+ filter *visibility.Filter
testAccounts map[string]*gtsmodel.Account
testStatuses map[string]*gtsmodel.Status
diff --git a/internal/typeutils/converter_test.go b/internal/typeutils/converter_test.go
index b040b9854..88c0256c8 100644
--- a/internal/typeutils/converter_test.go
+++ b/internal/typeutils/converter_test.go
@@ -470,6 +470,7 @@
type TypeUtilsTestSuite struct {
suite.Suite
db db.DB
+ state state.State
testAccounts map[string]*gtsmodel.Account
testStatuses map[string]*gtsmodel.Status
testAttachments map[string]*gtsmodel.MediaAttachment
@@ -482,13 +483,12 @@ type TypeUtilsTestSuite struct {
}
func (suite *TypeUtilsTestSuite) SetupSuite() {
- var state state.State
- state.Caches.Init()
+ suite.state.Caches.Init()
testrig.InitTestConfig()
testrig.InitTestLog()
- suite.db = testrig.NewTestDB(&state)
+ suite.db = testrig.NewTestDB(&suite.state)
suite.testAccounts = testrig.NewTestAccounts()
suite.testStatuses = testrig.NewTestStatuses()
suite.testAttachments = testrig.NewTestAttachments()
@@ -500,6 +500,7 @@ func (suite *TypeUtilsTestSuite) SetupSuite() {
}
func (suite *TypeUtilsTestSuite) SetupTest() {
+ suite.state.Caches.Init() // reset
testrig.StandardDBSetup(suite.db, nil)
}
diff --git a/internal/typeutils/internaltofrontend.go b/internal/typeutils/internaltofrontend.go
index bc42226ff..198bed099 100644
--- a/internal/typeutils/internaltofrontend.go
+++ b/internal/typeutils/internaltofrontend.go
@@ -59,7 +59,7 @@ func (c *converter) AccountToAPIAccountSensitive(ctx context.Context, a *gtsmode
// then adding the Source object to it...
// check pending follow requests aimed at this account
- frc, err := c.db.CountFollowRequests(ctx, "", a.ID)
+ frc, err := c.db.CountAccountFollowRequests(ctx, a.ID)
if err != nil {
return nil, fmt.Errorf("error counting follow requests: %s", err)
}
@@ -84,13 +84,13 @@ func (c *converter) AccountToAPIAccountSensitive(ctx context.Context, a *gtsmode
func (c *converter) AccountToAPIAccountPublic(ctx context.Context, a *gtsmodel.Account) (*apimodel.Account, error) {
// count followers
- followersCount, err := c.db.CountFollows(ctx, "", a.ID)
+ followersCount, err := c.db.CountAccountFollowers(ctx, a.ID)
if err != nil {
return nil, fmt.Errorf("error counting followers: %s", err)
}
// count following
- followingCount, err := c.db.CountFollows(ctx, a.ID, "")
+ followingCount, err := c.db.CountAccountFollows(ctx, a.ID)
if err != nil {
return nil, fmt.Errorf("error counting following: %s", err)
}
diff --git a/internal/visibility/account.go b/internal/visibility/account.go
new file mode 100644
index 000000000..ca532f5dd
--- /dev/null
+++ b/internal/visibility/account.go
@@ -0,0 +1,151 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package visibility
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/superseriousbusiness/gotosocial/internal/cache"
+ "github.com/superseriousbusiness/gotosocial/internal/config"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
+)
+
+// AccountVisible will check if given account is visible to requester, accounting for requester with no auth (i.e is nil), suspensions, disabled local users and account blocks.
+func (f *Filter) AccountVisible(ctx context.Context, requester *gtsmodel.Account, account *gtsmodel.Account) (bool, error) {
+ // By default we assume no auth.
+ requesterID := noauth
+
+ if requester != nil {
+ // Use provided account ID.
+ requesterID = requester.ID
+ }
+
+ visibility, err := f.state.Caches.Visibility.Load("Type.RequesterID.ItemID", func() (*cache.CachedVisibility, error) {
+ // Visibility not yet cached, perform visibility lookup.
+ visible, err := f.isAccountVisibleTo(ctx, requester, account)
+ if err != nil {
+ return nil, err
+ }
+
+ // Return visibility value.
+ return &cache.CachedVisibility{
+ ItemID: account.ID,
+ RequesterID: requesterID,
+ Type: cache.VisibilityTypeAccount,
+ Value: visible,
+ }, nil
+ }, "account", requesterID, account.ID)
+ if err != nil {
+ return false, err
+ }
+
+ return visibility.Value, nil
+}
+
+// isAccountVisibleTo will check if account is visible to requester. It is the "meat" of the logic to Filter{}.AccountVisible() which is called within cache loader callback.
+func (f *Filter) isAccountVisibleTo(ctx context.Context, requester *gtsmodel.Account, account *gtsmodel.Account) (bool, error) {
+ // Check whether target account is visible to anyone.
+ visible, err := f.isAccountVisible(ctx, account)
+ if err != nil {
+ return false, fmt.Errorf("isAccountVisibleTo: error checking account %s visibility: %w", account.ID, err)
+ }
+
+ if !visible {
+ log.Trace(ctx, "target account is not visible to anyone")
+ return false, nil
+ }
+
+ if requester == nil {
+ // It seems stupid, but when un-authed all accounts are
+ // visible to allow for federation to work correctly.
+ return true, nil
+ }
+
+ // If requester is not visible, they cannot *see* either.
+ visible, err = f.isAccountVisible(ctx, requester)
+ if err != nil {
+ return false, fmt.Errorf("isAccountVisibleTo: error checking account %s visibility: %w", account.ID, err)
+ }
+
+ if !visible {
+ log.Trace(ctx, "requesting account cannot see other accounts")
+ return false, nil
+ }
+
+ // Check whether either blocks the other.
+ blocked, err := f.state.DB.IsEitherBlocked(ctx,
+ requester.ID,
+ account.ID,
+ )
+ if err != nil {
+ return false, fmt.Errorf("isAccountVisibleTo: error checking account blocks: %w", err)
+ }
+
+ if blocked {
+ log.Trace(ctx, "block exists between accounts")
+ return false, nil
+ }
+
+ return true, nil
+}
+
+// isAccountVisible will check if given account should be visible at all, e.g. it may not be if suspended or disabled.
+func (f *Filter) isAccountVisible(ctx context.Context, account *gtsmodel.Account) (bool, error) {
+ if account.IsLocal() {
+ // This is a local account.
+
+ if account.Username == config.GetHost() {
+ // This is the instance actor account.
+ return true, nil
+ }
+
+ // Fetch the local user model for this account.
+ user, err := f.state.DB.GetUserByAccountID(ctx, account.ID)
+ if err != nil {
+ return false, err
+ }
+
+ // Make sure that user is active (i.e. not disabled, not approved etc).
+ if *user.Disabled || !*user.Approved || user.ConfirmedAt.IsZero() {
+ log.Trace(ctx, "local account not active")
+ return false, nil
+ }
+ } else {
+ // This is a remote account.
+
+ // Check whether remote account's domain is blocked.
+ blocked, err := f.state.DB.IsDomainBlocked(ctx, account.Domain)
+ if err != nil {
+ return false, err
+ }
+
+ if blocked {
+ log.Trace(ctx, "remote account domain blocked")
+ return false, nil
+ }
+ }
+
+ if !account.SuspendedAt.IsZero() {
+ log.Trace(ctx, "account suspended")
+ return false, nil
+ }
+
+ return true, nil
+}
diff --git a/internal/visibility/boostable.go b/internal/visibility/boostable.go
new file mode 100644
index 000000000..7c8bda324
--- /dev/null
+++ b/internal/visibility/boostable.go
@@ -0,0 +1,62 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package visibility
+
+import (
+ "context"
+
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
+)
+
+// StatusBoostable checks if given status is boostable by requester, checking boolean status visibility to requester and ultimately the AP status visibility setting.
+func (f *Filter) StatusBoostable(ctx context.Context, requester *gtsmodel.Account, status *gtsmodel.Status) (bool, error) {
+ if status.Visibility == gtsmodel.VisibilityDirect {
+ log.Trace(ctx, "direct statuses are not boostable")
+ return false, nil
+ }
+
+ // Check whether status is visible to requesting account.
+ visible, err := f.StatusVisible(ctx, requester, status)
+ if err != nil {
+ return false, err
+ }
+
+ if !visible {
+ log.Trace(ctx, "status not visible to requesting account")
+ return false, nil
+ }
+
+ if requester.ID == status.AccountID {
+ // Status author can always boost non-directs.
+ return true, nil
+ }
+
+ if status.Visibility == gtsmodel.VisibilityFollowersOnly ||
+ status.Visibility == gtsmodel.VisibilityMutualsOnly {
+ log.Trace(ctx, "unauthored %s status not boostable", status.Visibility)
+ return false, nil
+ }
+
+ if !*status.Boostable {
+ log.Trace(ctx, "status marked not boostable")
+ return false, nil
+ }
+
+ return true, nil
+}
diff --git a/internal/visibility/statusboostable_test.go b/internal/visibility/boostable_test.go
similarity index 81%
rename from internal/visibility/statusboostable_test.go
rename to internal/visibility/boostable_test.go
index cdadd82a3..fd29e7305 100644
--- a/internal/visibility/statusboostable_test.go
+++ b/internal/visibility/boostable_test.go
@@ -33,7 +33,7 @@ func (suite *StatusBoostableTestSuite) TestOwnPublicBoostable() {
testAccount := suite.testAccounts["local_account_1"]
ctx := context.Background()
- boostable, err := suite.filter.StatusBoostable(ctx, testStatus, testAccount)
+ boostable, err := suite.filter.StatusBoostable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.True(boostable)
@@ -44,7 +44,7 @@ func (suite *StatusBoostableTestSuite) TestOwnUnlockedBoostable() {
testAccount := suite.testAccounts["local_account_1"]
ctx := context.Background()
- boostable, err := suite.filter.StatusBoostable(ctx, testStatus, testAccount)
+ boostable, err := suite.filter.StatusBoostable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.True(boostable)
@@ -55,7 +55,7 @@ func (suite *StatusBoostableTestSuite) TestOwnMutualsOnlyNonInteractiveBoostable
testAccount := suite.testAccounts["local_account_1"]
ctx := context.Background()
- boostable, err := suite.filter.StatusBoostable(ctx, testStatus, testAccount)
+ boostable, err := suite.filter.StatusBoostable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.True(boostable)
@@ -66,7 +66,7 @@ func (suite *StatusBoostableTestSuite) TestOwnMutualsOnlyBoostable() {
testAccount := suite.testAccounts["local_account_1"]
ctx := context.Background()
- boostable, err := suite.filter.StatusBoostable(ctx, testStatus, testAccount)
+ boostable, err := suite.filter.StatusBoostable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.True(boostable)
@@ -77,7 +77,7 @@ func (suite *StatusBoostableTestSuite) TestOwnFollowersOnlyBoostable() {
testAccount := suite.testAccounts["local_account_1"]
ctx := context.Background()
- boostable, err := suite.filter.StatusBoostable(ctx, testStatus, testAccount)
+ boostable, err := suite.filter.StatusBoostable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.True(boostable)
@@ -88,7 +88,7 @@ func (suite *StatusBoostableTestSuite) TestOwnDirectNotBoostable() {
testAccount := suite.testAccounts["local_account_2"]
ctx := context.Background()
- boostable, err := suite.filter.StatusBoostable(ctx, testStatus, testAccount)
+ boostable, err := suite.filter.StatusBoostable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.False(boostable)
@@ -99,7 +99,7 @@ func (suite *StatusBoostableTestSuite) TestOtherPublicBoostable() {
testAccount := suite.testAccounts["local_account_1"]
ctx := context.Background()
- boostable, err := suite.filter.StatusBoostable(ctx, testStatus, testAccount)
+ boostable, err := suite.filter.StatusBoostable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.True(boostable)
@@ -110,7 +110,7 @@ func (suite *StatusBoostableTestSuite) TestOtherUnlistedBoostable() {
testAccount := suite.testAccounts["local_account_2"]
ctx := context.Background()
- boostable, err := suite.filter.StatusBoostable(ctx, testStatus, testAccount)
+ boostable, err := suite.filter.StatusBoostable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.True(boostable)
@@ -121,7 +121,7 @@ func (suite *StatusBoostableTestSuite) TestOtherFollowersOnlyNotBoostable() {
testAccount := suite.testAccounts["local_account_1"]
ctx := context.Background()
- boostable, err := suite.filter.StatusBoostable(ctx, testStatus, testAccount)
+ boostable, err := suite.filter.StatusBoostable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.False(boostable)
@@ -132,19 +132,19 @@ func (suite *StatusBoostableTestSuite) TestOtherDirectNotBoostable() {
testAccount := suite.testAccounts["local_account_1"]
ctx := context.Background()
- boostable, err := suite.filter.StatusBoostable(ctx, testStatus, testAccount)
+ boostable, err := suite.filter.StatusBoostable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.False(boostable)
}
-func (suite *StatusBoostableTestSuite) TestRemoteFollowersOnlyNotVisibleError() {
+func (suite *StatusBoostableTestSuite) TestRemoteFollowersOnlyNotVisible() {
testStatus := suite.testStatuses["local_account_1_status_5"]
testAccount := suite.testAccounts["remote_account_1"]
ctx := context.Background()
- boostable, err := suite.filter.StatusBoostable(ctx, testStatus, testAccount)
- suite.Assert().Error(err)
+ boostable, err := suite.filter.StatusBoostable(ctx, testAccount, testStatus)
+ suite.NoError(err)
suite.False(boostable)
}
diff --git a/internal/visibility/filter.go b/internal/visibility/filter.go
index caa622d09..c9f007ccf 100644
--- a/internal/visibility/filter.go
+++ b/internal/visibility/filter.go
@@ -18,46 +18,20 @@
package visibility
import (
- "context"
-
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/state"
)
-// Filter packages up a bunch of logic for checking whether given statuses or accounts are visible to a requester.
-type Filter interface {
- // StatusVisible returns true if targetStatus is visible to requestingAccount, based on the
- // privacy settings of the status, and any blocks/mutes that might exist between the two accounts
- // or account domains, and other relevant accounts mentioned in or replied to by the status.
- StatusVisible(ctx context.Context, targetStatus *gtsmodel.Status, requestingAccount *gtsmodel.Account) (bool, error)
+// noauth is a placeholder ID used in cache lookups
+// when there is no authorized account ID to use.
+const noauth = "noauth"
- // StatusesVisible calls StatusVisible for each status in the statuses slice, and returns a slice of only
- // statuses which are visible to the requestingAccount.
- StatusesVisible(ctx context.Context, statuses []*gtsmodel.Status, requestingAccount *gtsmodel.Account) ([]*gtsmodel.Status, error)
-
- // StatusHometimelineable returns true if targetStatus should be in the home timeline of the requesting account.
- //
- // This function will call StatusVisible internally, so it's not necessary to call it beforehand.
- StatusHometimelineable(ctx context.Context, targetStatus *gtsmodel.Status, requestingAccount *gtsmodel.Account) (bool, error)
-
- // StatusPublictimelineable returns true if targetStatus should be in the public timeline of the requesting account.
- //
- // This function will call StatusVisible internally, so it's not necessary to call it beforehand.
- StatusPublictimelineable(ctx context.Context, targetStatus *gtsmodel.Status, timelineOwnerAccount *gtsmodel.Account) (bool, error)
-
- // StatusBoostable returns true if targetStatus can be boosted by the requesting account.
- //
- // this function will call StatusVisible internally so it's not necessary to call it beforehand.
- StatusBoostable(ctx context.Context, targetStatus *gtsmodel.Status, requestingAccount *gtsmodel.Account) (bool, error)
-}
-
-type filter struct {
- db db.DB
+// Filter packages up a bunch of logic for checking whether
+// given statuses or accounts are visible to a requester.
+type Filter struct {
+ state *state.State
}
// NewFilter returns a new Filter interface that will use the provided database.
-func NewFilter(db db.DB) Filter {
- return &filter{
- db: db,
- }
+func NewFilter(state *state.State) *Filter {
+ return &Filter{state: state}
}
diff --git a/internal/visibility/filter_test.go b/internal/visibility/filter_test.go
index 500f46239..41f06079a 100644
--- a/internal/visibility/filter_test.go
+++ b/internal/visibility/filter_test.go
@@ -29,7 +29,8 @@
type FilterStandardTestSuite struct {
// standard suite interfaces
suite.Suite
- db db.DB
+ db db.DB
+ state state.State
// standard suite models
testTokens map[string]*gtsmodel.Token
@@ -43,7 +44,7 @@ type FilterStandardTestSuite struct {
testMentions map[string]*gtsmodel.Mention
testFollows map[string]*gtsmodel.Follow
- filter visibility.Filter
+ filter *visibility.Filter
}
func (suite *FilterStandardTestSuite) SetupSuite() {
@@ -60,14 +61,13 @@ func (suite *FilterStandardTestSuite) SetupSuite() {
}
func (suite *FilterStandardTestSuite) SetupTest() {
- var state state.State
- state.Caches.Init()
+ suite.state.Caches.Init()
testrig.InitTestConfig()
testrig.InitTestLog()
- suite.db = testrig.NewTestDB(&state)
- suite.filter = visibility.NewFilter(suite.db)
+ suite.db = testrig.NewTestDB(&suite.state)
+ suite.filter = visibility.NewFilter(&suite.state)
testrig.StandardDBSetup(suite.db, nil)
}
diff --git a/internal/visibility/home_timeline.go b/internal/visibility/home_timeline.go
new file mode 100644
index 000000000..3f0f1f16a
--- /dev/null
+++ b/internal/visibility/home_timeline.go
@@ -0,0 +1,165 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package visibility
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/superseriousbusiness/gotosocial/internal/cache"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
+)
+
+// StatusHomeTimelineable checks if given status should be included on owner's home timeline. Primarily relying on status visibility to owner and the AP visibility setting, but also taking into account thread replies etc.
+func (f *Filter) StatusHomeTimelineable(ctx context.Context, owner *gtsmodel.Account, status *gtsmodel.Status) (bool, error) {
+ // By default we assume no auth.
+ requesterID := noauth
+
+ if owner != nil {
+ // Use provided account ID.
+ requesterID = owner.ID
+ }
+
+ visibility, err := f.state.Caches.Visibility.Load("Type.RequesterID.ItemID", func() (*cache.CachedVisibility, error) {
+ // Visibility not yet cached, perform timeline visibility lookup.
+ visible, err := f.isStatusHomeTimelineable(ctx, owner, status)
+ if err != nil {
+ return nil, err
+ }
+
+ // Return visibility value.
+ return &cache.CachedVisibility{
+ ItemID: status.ID,
+ RequesterID: requesterID,
+ Type: cache.VisibilityTypeHome,
+ Value: visible,
+ }, nil
+ }, "home", requesterID, status.ID)
+ if err != nil {
+ if err == cache.SentinelError {
+ // Filter-out our temporary
+ // race-condition error.
+ return false, nil
+ }
+
+ return false, err
+ }
+
+ return visibility.Value, nil
+}
+
+func (f *Filter) isStatusHomeTimelineable(ctx context.Context, owner *gtsmodel.Account, status *gtsmodel.Status) (bool, error) {
+ if status.CreatedAt.After(time.Now().Add(24 * time.Hour)) {
+ // Statuses made over 1 day in the future we don't show...
+ log.Warnf(ctx, "status >24hrs in the future: %+v", status)
+ return false, nil
+ }
+
+ // Check whether status is visible to timeline owner.
+ visible, err := f.StatusVisible(ctx, owner, status)
+ if err != nil {
+ return false, err
+ }
+
+ if !visible {
+ log.Trace(ctx, "status not visible to timeline owner")
+ return false, nil
+ }
+
+ if status.AccountID == owner.ID {
+ // Author can always see their status.
+ return true, nil
+ }
+
+ if status.MentionsAccount(owner.ID) {
+ // Can always see when you are mentioned.
+ return true, nil
+ }
+
+ var (
+ parent *gtsmodel.Status
+ included bool
+ oneAuthor bool
+ )
+
+ for parent = status; parent.InReplyToURI != ""; {
+ // Fetch next parent to lookup.
+ parentID := parent.InReplyToID
+ if parentID == "" {
+ log.Warnf(ctx, "status not yet deref'd: %s", parent.InReplyToURI)
+ return false, cache.SentinelError
+ }
+
+ // Get the next parent in the chain from DB.
+ parent, err = f.state.DB.GetStatusByID(
+ gtscontext.SetBarebones(ctx),
+ parentID,
+ )
+ if err != nil {
+ return false, fmt.Errorf("isStatusHomeTimelineable: error getting status parent %s: %w", parentID, err)
+ }
+
+ if (parent.AccountID == owner.ID) ||
+ parent.MentionsAccount(owner.ID) {
+ // Owner is in / mentioned in
+ // this status thread.
+ included = true
+ break
+ }
+
+ if oneAuthor {
+ // Check if this is a single-author status thread.
+ oneAuthor = (parent.AccountID == status.AccountID)
+ }
+ }
+
+ if parent != status && !included && !oneAuthor {
+ log.Trace(ctx, "ignoring visible reply to conversation thread excluding owner")
+ return false, nil
+ }
+
+ // At this point status is either a top-level status, a reply in a single
+ // author thread (e.g. "this is my weird-ass take and here is why 1/10 🧵"),
+ // or a thread mentioning / including timeline owner.
+
+ if status.Visibility == gtsmodel.VisibilityFollowersOnly ||
+ status.Visibility == gtsmodel.VisibilityMutualsOnly {
+ // Followers/mutuals only post that already passed the status
+ // visibility check, (i.e. we follow / mutuals with author).
+ return true, nil
+ }
+
+ // Ensure owner follows author of public/unlocked status.
+ follow, err := f.state.DB.IsFollowing(ctx,
+ owner.ID,
+ status.AccountID,
+ )
+ if err != nil {
+ return false, fmt.Errorf("isStatusHomeTimelineable: error checking follow %s->%s: %w", owner.ID, status.AccountID, err)
+ }
+
+ if !follow {
+ log.Trace(ctx, "ignoring visible status from unfollowed author")
+ return false, nil
+ }
+
+ return true, nil
+}
diff --git a/internal/visibility/statushometimelineable_test.go b/internal/visibility/home_timeline_test.go
similarity index 83%
rename from internal/visibility/statushometimelineable_test.go
rename to internal/visibility/home_timeline_test.go
index 8f7e51362..b81f8fe4c 100644
--- a/internal/visibility/statushometimelineable_test.go
+++ b/internal/visibility/home_timeline_test.go
@@ -25,86 +25,77 @@
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/testrig"
)
-type StatusStatusHometimelineableTestSuite struct {
+type StatusStatusHomeTimelineableTestSuite struct {
FilterStandardTestSuite
}
-func (suite *StatusStatusHometimelineableTestSuite) TestOwnStatusHometimelineable() {
+func (suite *StatusStatusHomeTimelineableTestSuite) TestOwnStatusHomeTimelineable() {
testStatus := suite.testStatuses["local_account_1_status_1"]
testAccount := suite.testAccounts["local_account_1"]
ctx := context.Background()
- timelineable, err := suite.filter.StatusHometimelineable(ctx, testStatus, testAccount)
+ timelineable, err := suite.filter.StatusHomeTimelineable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.True(timelineable)
}
-func (suite *StatusStatusHometimelineableTestSuite) TestFollowingStatusHometimelineable() {
+func (suite *StatusStatusHomeTimelineableTestSuite) TestFollowingStatusHomeTimelineable() {
testStatus := suite.testStatuses["local_account_2_status_1"]
testAccount := suite.testAccounts["local_account_1"]
ctx := context.Background()
- timelineable, err := suite.filter.StatusHometimelineable(ctx, testStatus, testAccount)
+ timelineable, err := suite.filter.StatusHomeTimelineable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.True(timelineable)
}
-func (suite *StatusStatusHometimelineableTestSuite) TestNotFollowingStatusHometimelineable() {
+func (suite *StatusStatusHomeTimelineableTestSuite) TestNotFollowingStatusHomeTimelineable() {
testStatus := suite.testStatuses["remote_account_1_status_1"]
testAccount := suite.testAccounts["local_account_1"]
ctx := context.Background()
- timelineable, err := suite.filter.StatusHometimelineable(ctx, testStatus, testAccount)
+ timelineable, err := suite.filter.StatusHomeTimelineable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.False(timelineable)
}
-func (suite *StatusStatusHometimelineableTestSuite) TestStatusTooNewNotTimelineable() {
+func (suite *StatusStatusHomeTimelineableTestSuite) TestStatusTooNewNotTimelineable() {
testStatus := >smodel.Status{}
*testStatus = *suite.testStatuses["local_account_1_status_1"]
- var err error
- testStatus.ID, err = id.NewULIDFromTime(time.Now().Add(10 * time.Minute))
- if err != nil {
- suite.FailNow(err.Error())
- }
+ testStatus.CreatedAt = time.Now().Add(25 * time.Hour)
testAccount := suite.testAccounts["local_account_1"]
ctx := context.Background()
- timelineable, err := suite.filter.StatusHometimelineable(ctx, testStatus, testAccount)
+ timelineable, err := suite.filter.StatusHomeTimelineable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.False(timelineable)
}
-func (suite *StatusStatusHometimelineableTestSuite) TestStatusNotTooNewTimelineable() {
+func (suite *StatusStatusHomeTimelineableTestSuite) TestStatusNotTooNewTimelineable() {
testStatus := >smodel.Status{}
*testStatus = *suite.testStatuses["local_account_1_status_1"]
- var err error
- testStatus.ID, err = id.NewULIDFromTime(time.Now().Add(4 * time.Minute))
- if err != nil {
- suite.FailNow(err.Error())
- }
+ testStatus.CreatedAt = time.Now().Add(23 * time.Hour)
testAccount := suite.testAccounts["local_account_1"]
ctx := context.Background()
- timelineable, err := suite.filter.StatusHometimelineable(ctx, testStatus, testAccount)
+ timelineable, err := suite.filter.StatusHomeTimelineable(ctx, testAccount, testStatus)
suite.NoError(err)
suite.True(timelineable)
}
-func (suite *StatusStatusHometimelineableTestSuite) TestChainReplyFollowersOnly() {
+func (suite *StatusStatusHomeTimelineableTestSuite) TestChainReplyFollowersOnly() {
ctx := context.Background()
// This scenario makes sure that we don't timeline a status which is a followers-only
@@ -112,9 +103,8 @@ func (suite *StatusStatusHometimelineableTestSuite) TestChainReplyFollowersOnly(
// timeline owner account doesn't follow.
//
// In other words, remote_account_1 posts a followers-only status, which local_account_1 replies to;
- // THEN, local_account_1 replies to their own reply. We don't want this last status to appear
- // in the timeline of local_account_2, even though they follow local_account_1, because they
- // *don't* follow remote_account_1.
+ // THEN, local_account_1 replies to their own reply. None of these statuses should appear to
+ // local_account_2 since they don't follow the original parent.
//
// See: https://github.com/superseriousbusiness/gotosocial/issues/501
@@ -152,7 +142,7 @@ func (suite *StatusStatusHometimelineableTestSuite) TestChainReplyFollowersOnly(
suite.FailNow(err.Error())
}
// this status should not be hometimelineable for local_account_2
- originalStatusTimelineable, err := suite.filter.StatusHometimelineable(ctx, originalStatus, timelineOwnerAccount)
+ originalStatusTimelineable, err := suite.filter.StatusHomeTimelineable(ctx, timelineOwnerAccount, originalStatus)
suite.NoError(err)
suite.False(originalStatusTimelineable)
@@ -185,8 +175,8 @@ func (suite *StatusStatusHometimelineableTestSuite) TestChainReplyFollowersOnly(
if err := suite.db.PutStatus(ctx, firstReplyStatus); err != nil {
suite.FailNow(err.Error())
}
- // this status should not be hometimelineable for local_account_2
- firstReplyStatusTimelineable, err := suite.filter.StatusHometimelineable(ctx, firstReplyStatus, timelineOwnerAccount)
+ // this status should be hometimelineable for local_account_2
+ firstReplyStatusTimelineable, err := suite.filter.StatusHomeTimelineable(ctx, timelineOwnerAccount, firstReplyStatus)
suite.NoError(err)
suite.False(firstReplyStatusTimelineable)
@@ -221,12 +211,12 @@ func (suite *StatusStatusHometimelineableTestSuite) TestChainReplyFollowersOnly(
}
// this status should ALSO not be hometimelineable for local_account_2
- secondReplyStatusTimelineable, err := suite.filter.StatusHometimelineable(ctx, secondReplyStatus, timelineOwnerAccount)
+ secondReplyStatusTimelineable, err := suite.filter.StatusHomeTimelineable(ctx, timelineOwnerAccount, secondReplyStatus)
suite.NoError(err)
suite.False(secondReplyStatusTimelineable)
}
-func (suite *StatusStatusHometimelineableTestSuite) TestChainReplyPublicAndUnlocked() {
+func (suite *StatusStatusHomeTimelineableTestSuite) TestChainReplyPublicAndUnlocked() {
ctx := context.Background()
// This scenario is exactly the same as the above test, but for a mix of unlocked + public posts
@@ -265,7 +255,7 @@ func (suite *StatusStatusHometimelineableTestSuite) TestChainReplyPublicAndUnloc
suite.FailNow(err.Error())
}
// this status should not be hometimelineable for local_account_2
- originalStatusTimelineable, err := suite.filter.StatusHometimelineable(ctx, originalStatus, timelineOwnerAccount)
+ originalStatusTimelineable, err := suite.filter.StatusHomeTimelineable(ctx, timelineOwnerAccount, originalStatus)
suite.NoError(err)
suite.False(originalStatusTimelineable)
@@ -299,7 +289,7 @@ func (suite *StatusStatusHometimelineableTestSuite) TestChainReplyPublicAndUnloc
suite.FailNow(err.Error())
}
// this status should not be hometimelineable for local_account_2
- firstReplyStatusTimelineable, err := suite.filter.StatusHometimelineable(ctx, firstReplyStatus, timelineOwnerAccount)
+ firstReplyStatusTimelineable, err := suite.filter.StatusHomeTimelineable(ctx, timelineOwnerAccount, firstReplyStatus)
suite.NoError(err)
suite.False(firstReplyStatusTimelineable)
@@ -334,11 +324,11 @@ func (suite *StatusStatusHometimelineableTestSuite) TestChainReplyPublicAndUnloc
}
// this status should ALSO not be hometimelineable for local_account_2
- secondReplyStatusTimelineable, err := suite.filter.StatusHometimelineable(ctx, secondReplyStatus, timelineOwnerAccount)
+ secondReplyStatusTimelineable, err := suite.filter.StatusHomeTimelineable(ctx, timelineOwnerAccount, secondReplyStatus)
suite.NoError(err)
suite.False(secondReplyStatusTimelineable)
}
-func TestStatusHometimelineableTestSuite(t *testing.T) {
- suite.Run(t, new(StatusStatusHometimelineableTestSuite))
+func TestStatusHomeTimelineableTestSuite(t *testing.T) {
+ suite.Run(t, new(StatusStatusHomeTimelineableTestSuite))
}
diff --git a/internal/visibility/public_timeline.go b/internal/visibility/public_timeline.go
new file mode 100644
index 000000000..13ac07831
--- /dev/null
+++ b/internal/visibility/public_timeline.go
@@ -0,0 +1,121 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package visibility
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/superseriousbusiness/gotosocial/internal/cache"
+ "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
+)
+
+// StatusHomeTimelineable checks if given status should be included on requester's public timeline. Primarily relying on status visibility to requester and the AP visibility setting, and ignoring conversation threads.
+func (f *Filter) StatusPublicTimelineable(ctx context.Context, requester *gtsmodel.Account, status *gtsmodel.Status) (bool, error) {
+ // By default we assume no auth.
+ requesterID := noauth
+
+ if requester != nil {
+ // Use provided account ID.
+ requesterID = requester.ID
+ }
+
+ visibility, err := f.state.Caches.Visibility.Load("Type.RequesterID.ItemID", func() (*cache.CachedVisibility, error) {
+ // Visibility not yet cached, perform timeline visibility lookup.
+ visible, err := f.isStatusPublicTimelineable(ctx, requester, status)
+ if err != nil {
+ return nil, err
+ }
+
+ // Return visibility value.
+ return &cache.CachedVisibility{
+ ItemID: status.ID,
+ RequesterID: requesterID,
+ Type: cache.VisibilityTypePublic,
+ Value: visible,
+ }, nil
+ }, "public", requesterID, status.ID)
+ if err != nil {
+ if err == cache.SentinelError {
+ // Filter-out our temporary
+ // race-condition error.
+ return false, nil
+ }
+
+ return false, err
+ }
+
+ return visibility.Value, nil
+}
+
+func (f *Filter) isStatusPublicTimelineable(ctx context.Context, requester *gtsmodel.Account, status *gtsmodel.Status) (bool, error) {
+ if status.CreatedAt.After(time.Now().Add(24 * time.Hour)) {
+ // Statuses made over 1 day in the future we don't show...
+ log.Warnf(ctx, "status >24hrs in the future: %+v", status)
+ return false, nil
+ }
+
+ // Don't show boosts on timeline.
+ if status.BoostOfID != "" {
+ return false, nil
+ }
+
+ // Check whether status is visible to requesting account.
+ visible, err := f.StatusVisible(ctx, requester, status)
+ if err != nil {
+ return false, err
+ }
+
+ if !visible {
+ log.Trace(ctx, "status not visible to timeline requester")
+ return false, nil
+ }
+
+ for parent := status; parent.InReplyToURI != ""; {
+ // Fetch next parent to lookup.
+ parentID := parent.InReplyToID
+ if parentID == "" {
+ log.Warnf(ctx, "status not yet deref'd: %s", parent.InReplyToURI)
+ return false, cache.SentinelError
+ }
+
+ // Get the next parent in the chain from DB.
+ parent, err = f.state.DB.GetStatusByID(
+ gtscontext.SetBarebones(ctx),
+ parentID,
+ )
+ if err != nil {
+ return false, fmt.Errorf("isStatusHomeTimelineable: error getting status parent %s: %w", parentID, err)
+ }
+
+ if parent.AccountID != status.AccountID {
+ // This is not a single author reply-chain-thread,
+ // instead is an actualy conversation. Don't timeline.
+ log.Trace(ctx, "ignoring multi-author reply-chain")
+ return false, nil
+ }
+ }
+
+ // This is either a visible status in a
+ // single-author thread, or a visible top
+ // level status. Show on public timeline.
+ return true, nil
+}
diff --git a/internal/visibility/relevantaccounts.go b/internal/visibility/relevantaccounts.go
deleted file mode 100644
index 2389b8544..000000000
--- a/internal/visibility/relevantaccounts.go
+++ /dev/null
@@ -1,230 +0,0 @@
-// GoToSocial
-// Copyright (C) GoToSocial Authors admin@gotosocial.org
-// SPDX-License-Identifier: AGPL-3.0-or-later
-//
-// This program is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Affero General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// This program is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Affero General Public License for more details.
-//
-// You should have received a copy of the GNU Affero General Public License
-// along with this program. If not, see .
-
-package visibility
-
-import (
- "context"
- "errors"
- "fmt"
-
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
-)
-
-// relevantAccounts denotes accounts that are replied to, boosted by, or mentioned in a status.
-type relevantAccounts struct {
- // Who wrote the status
- Account *gtsmodel.Account
- // Who is the status replying to
- InReplyToAccount *gtsmodel.Account
- // Which accounts are mentioned (tagged) in the status
- MentionedAccounts []*gtsmodel.Account
- // Who authed the boosted status
- BoostedAccount *gtsmodel.Account
- // If the boosted status replies to another account, who does it reply to?
- BoostedInReplyToAccount *gtsmodel.Account
- // Who is mentioned (tagged) in the boosted status
- BoostedMentionedAccounts []*gtsmodel.Account
-}
-
-func (f *filter) relevantAccounts(ctx context.Context, status *gtsmodel.Status, getBoosted bool) (*relevantAccounts, error) {
- relAccts := &relevantAccounts{
- MentionedAccounts: []*gtsmodel.Account{},
- BoostedMentionedAccounts: []*gtsmodel.Account{},
- }
-
- /*
- Here's what we need to try and extract from the status:
-
- // 1. Who wrote the status
- Account *gtsmodel.Account
-
- // 2. Who is the status replying to
- InReplyToAccount *gtsmodel.Account
-
- // 3. Which accounts are mentioned (tagged) in the status
- MentionedAccounts []*gtsmodel.Account
-
- if getBoosted:
- // 4. Who wrote the boosted status
- BoostedAccount *gtsmodel.Account
-
- // 5. If the boosted status replies to another account, who does it reply to?
- BoostedInReplyToAccount *gtsmodel.Account
-
- // 6. Who is mentioned (tagged) in the boosted status
- BoostedMentionedAccounts []*gtsmodel.Account
- */
-
- // 1. Account.
- // Account might be set on the status already
- if status.Account != nil {
- // it was set
- relAccts.Account = status.Account
- } else {
- // it wasn't set, so get it from the db
- account, err := f.db.GetAccountByID(ctx, status.AccountID)
- if err != nil {
- return nil, fmt.Errorf("relevantAccounts: error getting account with id %s: %s", status.AccountID, err)
- }
- // set it on the status in case we need it further along
- status.Account = account
- // set it on relevant accounts
- relAccts.Account = account
- }
-
- // 2. InReplyToAccount
- // only get this if InReplyToAccountID is set
- if status.InReplyToAccountID != "" {
- // InReplyToAccount might be set on the status already
- if status.InReplyToAccount != nil {
- // it was set
- relAccts.InReplyToAccount = status.InReplyToAccount
- } else {
- // it wasn't set, so get it from the db
- inReplyToAccount, err := f.db.GetAccountByID(ctx, status.InReplyToAccountID)
- if err != nil {
- return nil, fmt.Errorf("relevantAccounts: error getting inReplyToAccount with id %s: %s", status.InReplyToAccountID, err)
- }
- // set it on the status in case we need it further along
- status.InReplyToAccount = inReplyToAccount
- // set it on relevant accounts
- relAccts.InReplyToAccount = inReplyToAccount
- }
- }
-
- // 3. MentionedAccounts
- // First check if status.Mentions is populated with all mentions that correspond to status.MentionIDs
- for _, mID := range status.MentionIDs {
- if mID == "" {
- continue
- }
- if !idIn(mID, status.Mentions) {
- // mention with ID isn't in status.Mentions
- mention, err := f.db.GetMention(ctx, mID)
- if err != nil {
- return nil, fmt.Errorf("relevantAccounts: error getting mention with id %s: %s", mID, err)
- }
- if mention == nil {
- return nil, fmt.Errorf("relevantAccounts: mention with id %s was nil", mID)
- }
- status.Mentions = append(status.Mentions, mention)
- }
- }
- // now filter mentions to make sure we only have mentions with a corresponding ID
- nm := []*gtsmodel.Mention{}
- for _, m := range status.Mentions {
- if m == nil {
- continue
- }
- if mentionIn(m, status.MentionIDs) {
- nm = append(nm, m)
- relAccts.MentionedAccounts = append(relAccts.MentionedAccounts, m.TargetAccount)
- }
- }
- status.Mentions = nm
-
- if len(status.Mentions) != len(status.MentionIDs) {
- return nil, errors.New("relevantAccounts: mentions length did not correspond with mentionIDs length")
- }
-
- // if getBoosted is set, we should check the same properties on the boosted account as well
- if getBoosted {
- // 4, 5, 6. Boosted status items
- // get the boosted status if it's not set on the status already
- if status.BoostOfID != "" && status.BoostOf == nil {
- boostedStatus, err := f.db.GetStatusByID(ctx, status.BoostOfID)
- if err != nil {
- return nil, fmt.Errorf("relevantAccounts: error getting boosted status with id %s: %s", status.BoostOfID, err)
- }
- status.BoostOf = boostedStatus
- }
-
- if status.BoostOf != nil {
- // return relevant accounts for the boosted status
- boostedRelAccts, err := f.relevantAccounts(ctx, status.BoostOf, false) // false because we don't want to recurse
- if err != nil {
- return nil, fmt.Errorf("relevantAccounts: error getting relevant accounts of boosted status %s: %s", status.BoostOf.ID, err)
- }
- relAccts.BoostedAccount = boostedRelAccts.Account
- relAccts.BoostedInReplyToAccount = boostedRelAccts.InReplyToAccount
- relAccts.BoostedMentionedAccounts = boostedRelAccts.MentionedAccounts
- }
- }
-
- return relAccts, nil
-}
-
-// domainBlockedRelevant checks through all relevant accounts attached to a status
-// to make sure none of them are domain blocked by this instance.
-func (f *filter) domainBlockedRelevant(ctx context.Context, r *relevantAccounts) (bool, error) {
- domains := []string{}
-
- if r.Account != nil {
- domains = append(domains, r.Account.Domain)
- }
-
- if r.InReplyToAccount != nil {
- domains = append(domains, r.InReplyToAccount.Domain)
- }
-
- for _, a := range r.MentionedAccounts {
- if a != nil {
- domains = append(domains, a.Domain)
- }
- }
-
- if r.BoostedAccount != nil {
- domains = append(domains, r.BoostedAccount.Domain)
- }
-
- if r.BoostedInReplyToAccount != nil {
- domains = append(domains, r.BoostedInReplyToAccount.Domain)
- }
-
- for _, a := range r.BoostedMentionedAccounts {
- if a != nil {
- domains = append(domains, a.Domain)
- }
- }
-
- return f.db.AreDomainsBlocked(ctx, domains)
-}
-
-func idIn(id string, mentions []*gtsmodel.Mention) bool {
- for _, m := range mentions {
- if m == nil {
- continue
- }
- if m.ID == id {
- return true
- }
- }
- return false
-}
-
-func mentionIn(mention *gtsmodel.Mention, ids []string) bool {
- if mention == nil {
- return false
- }
- for _, i := range ids {
- if mention.ID == i {
- return true
- }
- }
- return false
-}
diff --git a/internal/visibility/status.go b/internal/visibility/status.go
new file mode 100644
index 000000000..dc8261624
--- /dev/null
+++ b/internal/visibility/status.go
@@ -0,0 +1,217 @@
+// GoToSocial
+// Copyright (C) GoToSocial Authors admin@gotosocial.org
+// SPDX-License-Identifier: AGPL-3.0-or-later
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program. If not, see .
+
+package visibility
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/superseriousbusiness/gotosocial/internal/cache"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
+)
+
+// StatusesVisible calls StatusVisible for each status in the statuses slice, and returns a slice of only statuses which are visible to the requester.
+func (f *Filter) StatusesVisible(ctx context.Context, requester *gtsmodel.Account, statuses []*gtsmodel.Status) ([]*gtsmodel.Status, error) {
+ // Preallocate slice of maximum possible length.
+ filtered := make([]*gtsmodel.Status, 0, len(statuses))
+
+ for _, status := range statuses {
+ // Check whether status is visible to requester.
+ visible, err := f.StatusVisible(ctx, requester, status)
+ if err != nil {
+ return nil, err
+ }
+
+ if visible {
+ // Add filtered status to ret slice.
+ filtered = append(filtered, status)
+ }
+ }
+
+ return filtered, nil
+}
+
+// StatusVisible will check if given status is visible to requester, accounting for requester with no auth (i.e is nil), suspensions, disabled local users, account blocks and status privacy.
+func (f *Filter) StatusVisible(ctx context.Context, requester *gtsmodel.Account, status *gtsmodel.Status) (bool, error) {
+ // By default we assume no auth.
+ requesterID := noauth
+
+ if requester != nil {
+ // Use provided account ID.
+ requesterID = requester.ID
+ }
+
+ visibility, err := f.state.Caches.Visibility.Load("Type.RequesterID.ItemID", func() (*cache.CachedVisibility, error) {
+ // Visibility not yet cached, perform visibility lookup.
+ visible, err := f.isStatusVisible(ctx, requester, status)
+ if err != nil {
+ return nil, err
+ }
+
+ // Return visibility value.
+ return &cache.CachedVisibility{
+ ItemID: status.ID,
+ RequesterID: requesterID,
+ Type: cache.VisibilityTypeStatus,
+ Value: visible,
+ }, nil
+ }, "status", requesterID, status.ID)
+ if err != nil {
+ return false, err
+ }
+
+ return visibility.Value, nil
+}
+
+// isStatusVisible will check if status is visible to requester. It is the "meat" of the logic to Filter{}.StatusVisible() which is called within cache loader callback.
+func (f *Filter) isStatusVisible(ctx context.Context, requester *gtsmodel.Account, status *gtsmodel.Status) (bool, error) {
+ // Ensure that status is fully populated for further processing.
+ if err := f.state.DB.PopulateStatus(ctx, status); err != nil {
+ return false, err
+ }
+
+ // Check whether status accounts are visible to the requester.
+ visible, err := f.areStatusAccountsVisible(ctx, requester, status)
+ if err != nil {
+ return false, fmt.Errorf("isStatusVisible: error checking status %s account visibility: %w", status.ID, err)
+ } else if !visible {
+ return false, nil
+ }
+
+ if status.Visibility == gtsmodel.VisibilityPublic {
+ // This status will be visible to all.
+ return true, nil
+ }
+
+ if requester == nil {
+ // This request is WITHOUT auth, and status is NOT public.
+ log.Trace(ctx, "unauthorized request to non-public status")
+ return false, nil
+ }
+
+ if status.Visibility == gtsmodel.VisibilityUnlocked {
+ // This status is visible to all auth'd accounts.
+ return true, nil
+ }
+
+ if requester.ID == status.AccountID {
+ // Author can always see their own status.
+ return true, nil
+ }
+
+ if status.MentionsAccount(requester.ID) {
+ // Status mentions the requesting account.
+ return true, nil
+ }
+
+ if status.BoostOf != nil {
+ if !status.BoostOf.MentionsPopulated() {
+ // Boosted status needs its mentions populating, fetch these from database.
+ status.BoostOf.Mentions, err = f.state.DB.GetMentions(ctx, status.BoostOf.MentionIDs)
+ if err != nil {
+ return false, fmt.Errorf("isStatusVisible: error populating boosted status %s mentions: %w", status.BoostOfID, err)
+ }
+ }
+
+ if status.BoostOf.MentionsAccount(requester.ID) {
+ // Boosted status mentions the requesting account.
+ return true, nil
+ }
+ }
+
+ switch status.Visibility {
+ case gtsmodel.VisibilityFollowersOnly:
+ // Check requester follows status author.
+ follows, err := f.state.DB.IsFollowing(ctx,
+ requester.ID,
+ status.AccountID,
+ )
+ if err != nil {
+ return false, fmt.Errorf("isStatusVisible: error checking follow %s->%s: %w", requester.ID, status.AccountID, err)
+ }
+
+ if !follows {
+ log.Trace(ctx, "follow-only status not visible to requester")
+ return false, nil
+ }
+
+ return true, nil
+
+ case gtsmodel.VisibilityMutualsOnly:
+ // Check mutual following between requester and author.
+ mutuals, err := f.state.DB.IsMutualFollowing(ctx,
+ requester.ID,
+ status.AccountID,
+ )
+ if err != nil {
+ return false, fmt.Errorf("isStatusVisible: error checking mutual follow %s<->%s: %w", requester.ID, status.AccountID, err)
+ }
+
+ if !mutuals {
+ log.Trace(ctx, "mutual-only status not visible to requester")
+ return false, nil
+ }
+
+ return true, nil
+
+ case gtsmodel.VisibilityDirect:
+ log.Trace(ctx, "direct status not visible to requester")
+ return false, nil
+
+ default:
+ log.Warnf(ctx, "unexpected status visibility %s for %s", status.Visibility, status.URI)
+ return false, nil
+ }
+}
+
+// areStatusAccountsVisible calls Filter{}.AccountVisible() on status author and the status boost-of (if set) author, returning visibility of status (and boost-of) to requester.
+func (f *Filter) areStatusAccountsVisible(ctx context.Context, requester *gtsmodel.Account, status *gtsmodel.Status) (bool, error) {
+ // Check whether status author's account is visible to requester.
+ visible, err := f.AccountVisible(ctx, requester, status.Account)
+ if err != nil {
+ return false, err
+ }
+
+ if !visible {
+ log.Trace(ctx, "status author not visible to requester")
+ return false, nil
+ }
+
+ if status.BoostOfID != "" {
+ // This is a boosted status.
+
+ if status.AccountID == status.BoostOfAccountID {
+ // Some clout-chaser boosted their own status, tch.
+ return true, nil
+ }
+
+ // Check whether boosted status author's account is visible to requester.
+ visible, err := f.AccountVisible(ctx, requester, status.BoostOfAccount)
+ if err != nil {
+ return false, err
+ }
+
+ if !visible {
+ log.Trace(ctx, "boosted status author not visible to requester")
+ return false, nil
+ }
+ }
+
+ return true, nil
+}
diff --git a/internal/visibility/statusvisible_test.go b/internal/visibility/status_test.go
similarity index 61%
rename from internal/visibility/statusvisible_test.go
rename to internal/visibility/status_test.go
index bd799d7ca..ad6bc66df 100644
--- a/internal/visibility/statusvisible_test.go
+++ b/internal/visibility/status_test.go
@@ -34,7 +34,7 @@ func (suite *StatusVisibleTestSuite) TestOwnStatusVisible() {
testAccount := suite.testAccounts["local_account_1"]
ctx := context.Background()
- visible, err := suite.filter.StatusVisible(ctx, testStatus, testAccount)
+ visible, err := suite.filter.StatusVisible(ctx, testAccount, testStatus)
suite.NoError(err)
suite.True(visible)
@@ -48,7 +48,7 @@ func (suite *StatusVisibleTestSuite) TestOwnDMVisible() {
suite.NoError(err)
testAccount := suite.testAccounts["local_account_2"]
- visible, err := suite.filter.StatusVisible(ctx, testStatus, testAccount)
+ visible, err := suite.filter.StatusVisible(ctx, testAccount, testStatus)
suite.NoError(err)
suite.True(visible)
@@ -62,7 +62,7 @@ func (suite *StatusVisibleTestSuite) TestDMVisibleToTarget() {
suite.NoError(err)
testAccount := suite.testAccounts["local_account_1"]
- visible, err := suite.filter.StatusVisible(ctx, testStatus, testAccount)
+ visible, err := suite.filter.StatusVisible(ctx, testAccount, testStatus)
suite.NoError(err)
suite.True(visible)
@@ -76,7 +76,7 @@ func (suite *StatusVisibleTestSuite) TestDMNotVisibleIfNotMentioned() {
suite.NoError(err)
testAccount := suite.testAccounts["admin_account"]
- visible, err := suite.filter.StatusVisible(ctx, testStatus, testAccount)
+ visible, err := suite.filter.StatusVisible(ctx, testAccount, testStatus)
suite.NoError(err)
suite.False(visible)
@@ -92,7 +92,7 @@ func (suite *StatusVisibleTestSuite) TestStatusNotVisibleIfNotMutuals() {
suite.NoError(err)
testAccount := suite.testAccounts["local_account_2"]
- visible, err := suite.filter.StatusVisible(ctx, testStatus, testAccount)
+ visible, err := suite.filter.StatusVisible(ctx, testAccount, testStatus)
suite.NoError(err)
suite.False(visible)
@@ -108,12 +108,54 @@ func (suite *StatusVisibleTestSuite) TestStatusNotVisibleIfNotFollowing() {
suite.NoError(err)
testAccount := suite.testAccounts["admin_account"]
- visible, err := suite.filter.StatusVisible(ctx, testStatus, testAccount)
+ visible, err := suite.filter.StatusVisible(ctx, testAccount, testStatus)
suite.NoError(err)
suite.False(visible)
}
+func (suite *StatusVisibleTestSuite) TestStatusNotVisibleIfNotMutualsCached() {
+ ctx := context.Background()
+ testStatusID := suite.testStatuses["local_account_1_status_4"].ID
+ testStatus, err := suite.db.GetStatusByID(ctx, testStatusID)
+ suite.NoError(err)
+ testAccount := suite.testAccounts["local_account_2"]
+
+ // Perform a status visibility check while mutuals, this shsould be true.
+ visible, err := suite.filter.StatusVisible(ctx, testAccount, testStatus)
+ suite.NoError(err)
+ suite.True(visible)
+
+ err = suite.db.DeleteFollowByID(ctx, suite.testFollows["local_account_2_local_account_1"].ID)
+ suite.NoError(err)
+
+ // Perform a status visibility check after unfollow, this should be false.
+ visible, err = suite.filter.StatusVisible(ctx, testAccount, testStatus)
+ suite.NoError(err)
+ suite.False(visible)
+}
+
+func (suite *StatusVisibleTestSuite) TestStatusNotVisibleIfNotFollowingCached() {
+ ctx := context.Background()
+ testStatusID := suite.testStatuses["local_account_1_status_5"].ID
+ testStatus, err := suite.db.GetStatusByID(ctx, testStatusID)
+ suite.NoError(err)
+ testAccount := suite.testAccounts["admin_account"]
+
+ // Perform a status visibility check while following, this shsould be true.
+ visible, err := suite.filter.StatusVisible(ctx, testAccount, testStatus)
+ suite.NoError(err)
+ suite.True(visible)
+
+ err = suite.db.DeleteFollowByID(ctx, suite.testFollows["admin_account_local_account_1"].ID)
+ suite.NoError(err)
+
+ // Perform a status visibility check after unfollow, this should be false.
+ visible, err = suite.filter.StatusVisible(ctx, testAccount, testStatus)
+ suite.NoError(err)
+ suite.False(visible)
+}
+
func TestStatusVisibleTestSuite(t *testing.T) {
suite.Run(t, new(StatusVisibleTestSuite))
}
diff --git a/internal/visibility/statusboostable.go b/internal/visibility/statusboostable.go
deleted file mode 100644
index e008008c2..000000000
--- a/internal/visibility/statusboostable.go
+++ /dev/null
@@ -1,60 +0,0 @@
-// GoToSocial
-// Copyright (C) GoToSocial Authors admin@gotosocial.org
-// SPDX-License-Identifier: AGPL-3.0-or-later
-//
-// This program is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Affero General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// This program is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Affero General Public License for more details.
-//
-// You should have received a copy of the GNU Affero General Public License
-// along with this program. If not, see .
-
-package visibility
-
-import (
- "context"
- "errors"
- "fmt"
-
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/log"
-)
-
-func (f *filter) StatusBoostable(ctx context.Context, targetStatus *gtsmodel.Status, requestingAccount *gtsmodel.Account) (bool, error) {
- // if the status isn't visible, it certainly isn't boostable
- visible, err := f.StatusVisible(ctx, targetStatus, requestingAccount)
- if err != nil {
- return false, fmt.Errorf("error seeing if status %s is visible: %s", targetStatus.ID, err)
- }
- if !visible {
- return false, errors.New("status is not visible")
- }
-
- // direct messages are never boostable, even if they're visible
- if targetStatus.Visibility == gtsmodel.VisibilityDirect {
- log.Trace(ctx, "status is not boostable because it is a DM")
- return false, nil
- }
-
- // the original account should always be able to boost its own non-DM statuses
- if requestingAccount.ID == targetStatus.Account.ID {
- log.Trace(ctx, "status is boostable because author is booster")
- return true, nil
- }
-
- // if status is followers-only and not the author's, it is not boostable
- if targetStatus.Visibility == gtsmodel.VisibilityFollowersOnly {
- log.Trace(ctx, "status not boostable because it is followers-only")
- return false, nil
- }
-
- // otherwise, status is as boostable as it says it is
- log.Trace(ctx, "defaulting to status.boostable value")
- return *targetStatus.Boostable, nil
-}
diff --git a/internal/visibility/statushometimelineable.go b/internal/visibility/statushometimelineable.go
deleted file mode 100644
index b5d5b836e..000000000
--- a/internal/visibility/statushometimelineable.go
+++ /dev/null
@@ -1,126 +0,0 @@
-// GoToSocial
-// Copyright (C) GoToSocial Authors admin@gotosocial.org
-// SPDX-License-Identifier: AGPL-3.0-or-later
-//
-// This program is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Affero General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// This program is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Affero General Public License for more details.
-//
-// You should have received a copy of the GNU Affero General Public License
-// along with this program. If not, see .
-
-package visibility
-
-import (
- "context"
- "fmt"
- "time"
-
- "codeberg.org/gruf/go-kv"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/id"
- "github.com/superseriousbusiness/gotosocial/internal/log"
-)
-
-func (f *filter) StatusHometimelineable(ctx context.Context, targetStatus *gtsmodel.Status, timelineOwnerAccount *gtsmodel.Account) (bool, error) {
- l := log.WithContext(ctx).
- WithFields(kv.Fields{{"statusID", targetStatus.ID}}...)
-
- // don't timeline statuses more than 5 min in the future
- maxID, err := id.NewULIDFromTime(time.Now().Add(5 * time.Minute))
- if err != nil {
- return false, err
- }
-
- if targetStatus.ID > maxID {
- l.Debug("status not hometimelineable because it's from more than 5 minutes in the future")
- return false, nil
- }
-
- // status owner should always be able to see their own status in their timeline so we can return early if this is the case
- if targetStatus.AccountID == timelineOwnerAccount.ID {
- return true, nil
- }
-
- v, err := f.StatusVisible(ctx, targetStatus, timelineOwnerAccount)
- if err != nil {
- return false, fmt.Errorf("StatusHometimelineable: error checking visibility of status with id %s: %s", targetStatus.ID, err)
- }
-
- if !v {
- l.Debug("status is not hometimelineable because it's not visible to the requester")
- return false, nil
- }
-
- for _, m := range targetStatus.Mentions {
- if m.TargetAccountID == timelineOwnerAccount.ID {
- // if we're mentioned we should be able to see the post
- return true, nil
- }
- }
-
- // check we follow the originator of the status
- if targetStatus.Account == nil {
- tsa, err := f.db.GetAccountByID(ctx, targetStatus.AccountID)
- if err != nil {
- return false, fmt.Errorf("StatusHometimelineable: error getting status author account with id %s: %s", targetStatus.AccountID, err)
- }
- targetStatus.Account = tsa
- }
- following, err := f.db.IsFollowing(ctx, timelineOwnerAccount, targetStatus.Account)
- if err != nil {
- return false, fmt.Errorf("StatusHometimelineable: error checking if %s follows %s: %s", timelineOwnerAccount.ID, targetStatus.AccountID, err)
- }
- if !following {
- return false, nil
- }
-
- // Don't timeline a status whose parent hasn't been dereferenced yet or can't be dereferenced.
- // If we have the reply to URI but don't have an ID for the replied-to account or the replied-to status in our database, we haven't dereferenced it yet.
- if targetStatus.InReplyToURI != "" && (targetStatus.InReplyToID == "" || targetStatus.InReplyToAccountID == "") {
- return false, nil
- }
-
- // if a status replies to an ID we know in the database, we need to check that parent status too
- if targetStatus.InReplyToID != "" {
- // pin the reply to status on to this status if it hasn't been done already
- if targetStatus.InReplyTo == nil {
- rs, err := f.db.GetStatusByID(ctx, targetStatus.InReplyToID)
- if err != nil {
- return false, fmt.Errorf("StatusHometimelineable: error getting replied to status with id %s: %s", targetStatus.InReplyToID, err)
- }
- targetStatus.InReplyTo = rs
- }
-
- // pin the reply to account on to this status if it hasn't been done already
- if targetStatus.InReplyToAccount == nil {
- ra, err := f.db.GetAccountByID(ctx, targetStatus.InReplyToAccountID)
- if err != nil {
- return false, fmt.Errorf("StatusHometimelineable: error getting replied to account with id %s: %s", targetStatus.InReplyToAccountID, err)
- }
- targetStatus.InReplyToAccount = ra
- }
-
- // if it's a reply to the timelineOwnerAccount, we don't need to check if the timelineOwnerAccount follows itself, just return true, they can see it
- if targetStatus.InReplyToAccountID == timelineOwnerAccount.ID {
- return true, nil
- }
-
- // make sure the parent status is also home timelineable, otherwise we shouldn't timeline this one either
- parentStatusTimelineable, err := f.StatusHometimelineable(ctx, targetStatus.InReplyTo, timelineOwnerAccount)
- if err != nil {
- return false, fmt.Errorf("StatusHometimelineable: error checking timelineability of parent status %s of status %s: %s", targetStatus.InReplyToID, targetStatus.ID, err)
- }
- if !parentStatusTimelineable {
- return false, nil
- }
- }
-
- return true, nil
-}
diff --git a/internal/visibility/statuspublictimelineable.go b/internal/visibility/statuspublictimelineable.go
deleted file mode 100644
index ae2e95f06..000000000
--- a/internal/visibility/statuspublictimelineable.go
+++ /dev/null
@@ -1,72 +0,0 @@
-// GoToSocial
-// Copyright (C) GoToSocial Authors admin@gotosocial.org
-// SPDX-License-Identifier: AGPL-3.0-or-later
-//
-// This program is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Affero General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// This program is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Affero General Public License for more details.
-//
-// You should have received a copy of the GNU Affero General Public License
-// along with this program. If not, see .
-
-package visibility
-
-import (
- "context"
- "fmt"
- "time"
-
- "codeberg.org/gruf/go-kv"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/id"
- "github.com/superseriousbusiness/gotosocial/internal/log"
-)
-
-func (f *filter) StatusPublictimelineable(ctx context.Context, targetStatus *gtsmodel.Status, timelineOwnerAccount *gtsmodel.Account) (bool, error) {
- l := log.WithContext(ctx).
- WithFields(kv.Fields{{"statusID", targetStatus.ID}}...)
-
- // don't timeline statuses more than 5 min in the future
- maxID, err := id.NewULIDFromTime(time.Now().Add(5 * time.Minute))
- if err != nil {
- return false, err
- }
-
- if targetStatus.ID > maxID {
- l.Debug("status not hometimelineable because it's from more than 5 minutes in the future")
- return false, nil
- }
-
- // Don't timeline boosted statuses
- if targetStatus.BoostOfID != "" {
- return false, nil
- }
-
- // Don't timeline a reply
- if targetStatus.InReplyToURI != "" || targetStatus.InReplyToID != "" || targetStatus.InReplyToAccountID != "" {
- return false, nil
- }
-
- // status owner should always be able to see their own status in their timeline so we can return early if this is the case
- if timelineOwnerAccount != nil && targetStatus.AccountID == timelineOwnerAccount.ID {
- return true, nil
- }
-
- v, err := f.StatusVisible(ctx, targetStatus, timelineOwnerAccount)
- if err != nil {
- return false, fmt.Errorf("StatusPublictimelineable: error checking visibility of status with id %s: %s", targetStatus.ID, err)
- }
-
- if !v {
- l.Debug("status is not publicTimelineable because it's not visible to the requester")
- return false, nil
- }
-
- return true, nil
-}
diff --git a/internal/visibility/statusvisible.go b/internal/visibility/statusvisible.go
deleted file mode 100644
index 91a1f6221..000000000
--- a/internal/visibility/statusvisible.go
+++ /dev/null
@@ -1,252 +0,0 @@
-// GoToSocial
-// Copyright (C) GoToSocial Authors admin@gotosocial.org
-// SPDX-License-Identifier: AGPL-3.0-or-later
-//
-// This program is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Affero General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// This program is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Affero General Public License for more details.
-//
-// You should have received a copy of the GNU Affero General Public License
-// along with this program. If not, see .
-
-package visibility
-
-import (
- "context"
- "fmt"
-
- "codeberg.org/gruf/go-kv"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/log"
-)
-
-func (f *filter) StatusVisible(ctx context.Context, targetStatus *gtsmodel.Status, requestingAccount *gtsmodel.Account) (bool, error) {
- l := log.WithContext(ctx).
- WithFields(kv.Fields{{"statusID", targetStatus.ID}}...)
-
- // Fetch any relevant accounts for the target status
- const getBoosted = true
- relevantAccounts, err := f.relevantAccounts(ctx, targetStatus, getBoosted)
- if err != nil {
- l.Debugf("error pulling relevant accounts for status %s: %s", targetStatus.ID, err)
- return false, fmt.Errorf("StatusVisible: error pulling relevant accounts for status %s: %s", targetStatus.ID, err)
- }
-
- // Check we have determined a target account
- targetAccount := relevantAccounts.Account
- if targetAccount == nil {
- l.Trace("target account is not set")
- return false, nil
- }
-
- // Check for domain blocks among relevant accounts
- domainBlocked, err := f.domainBlockedRelevant(ctx, relevantAccounts)
- if err != nil {
- l.Debugf("error checking domain block: %s", err)
- return false, fmt.Errorf("error checking domain block: %s", err)
- } else if domainBlocked {
- return false, nil
- }
-
- // if target account is suspended then don't show the status
- if !targetAccount.SuspendedAt.IsZero() {
- l.Trace("target account suspended at is not zero")
- return false, nil
- }
-
- // if the target user doesn't exist (anymore) then the status also shouldn't be visible
- // note: we only do this for local users
- if targetAccount.Domain == "" {
- targetUser, err := f.db.GetUserByAccountID(ctx, targetAccount.ID)
- if err != nil {
- l.Debug("target user could not be selected")
- if err == db.ErrNoEntries {
- return false, nil
- }
- return false, fmt.Errorf("StatusVisible: db error selecting user for local target account %s: %s", targetAccount.ID, err)
- }
-
- // if target user is disabled, not yet approved, or not confirmed then don't show the status
- // (although in the latter two cases it's unlikely they posted a status yet anyway, but you never know!)
- if *targetUser.Disabled || !*targetUser.Approved || targetUser.ConfirmedAt.IsZero() {
- l.Trace("target user is disabled, not approved, or not confirmed")
- return false, nil
- }
- }
-
- // If requesting account is nil, that means whoever requested the status didn't auth, or their auth failed.
- // In this case, we can still serve the status if it's public, otherwise we definitely shouldn't.
- if requestingAccount == nil {
- if targetStatus.Visibility == gtsmodel.VisibilityPublic {
- return true, nil
- }
- l.Trace("requesting account is nil but the target status isn't public")
- return false, nil
- }
-
- // if the requesting user doesn't exist (anymore) then the status also shouldn't be visible
- // note: we only do this for local users
- if requestingAccount.Domain == "" {
- requestingUser, err := f.db.GetUserByAccountID(ctx, requestingAccount.ID)
- if err != nil {
- // if the requesting account is local but doesn't have a corresponding user in the db this is a problem
- l.Debug("requesting user could not be selected")
- if err == db.ErrNoEntries {
- return false, nil
- }
- return false, fmt.Errorf("StatusVisible: db error selecting user for local requesting account %s: %s", requestingAccount.ID, err)
- }
- // okay, user exists, so make sure it has full privileges/is confirmed/approved
- if *requestingUser.Disabled || !*requestingUser.Approved || requestingUser.ConfirmedAt.IsZero() {
- l.Trace("requesting account is local but corresponding user is either disabled, not approved, or not confirmed")
- return false, nil
- }
- }
-
- // if requesting account is suspended then don't show the status -- although they probably shouldn't have gotten
- // this far (ie., been authed) in the first place: this is just for safety.
- if !requestingAccount.SuspendedAt.IsZero() {
- l.Trace("requesting account is suspended")
- return false, nil
- }
-
- // if the target status belongs to the requesting account, they should always be able to view it at this point
- if targetStatus.AccountID == requestingAccount.ID {
- return true, nil
- }
-
- // At this point we have a populated targetAccount, targetStatus, and requestingAccount, so we can check for blocks and whathaveyou
- // First check if a block exists directly between the target account (which authored the status) and the requesting account.
- if blocked, err := f.db.IsBlocked(ctx, targetAccount.ID, requestingAccount.ID, true); err != nil {
- l.Debugf("something went wrong figuring out if the accounts have a block: %s", err)
- return false, err
- } else if blocked {
- // don't allow the status to be viewed if a block exists in *either* direction between these two accounts, no creepy stalking please
- l.Trace("a block exists between requesting account and target account")
- return false, nil
- }
-
- // If not in reply to the requesting account, check if inReplyToAccount is blocked
- if relevantAccounts.InReplyToAccount != nil && relevantAccounts.InReplyToAccount.ID != requestingAccount.ID {
- if blocked, err := f.db.IsBlocked(ctx, relevantAccounts.InReplyToAccount.ID, requestingAccount.ID, true); err != nil {
- return false, err
- } else if blocked {
- l.Trace("a block exists between requesting account and reply to account")
- return false, nil
- }
- }
-
- // status boosts accounts id
- if relevantAccounts.BoostedAccount != nil {
- if blocked, err := f.db.IsBlocked(ctx, relevantAccounts.BoostedAccount.ID, requestingAccount.ID, true); err != nil {
- return false, err
- } else if blocked {
- l.Trace("a block exists between requesting account and boosted account")
- return false, nil
- }
- }
-
- // status boosts a reply to account id
- if relevantAccounts.BoostedInReplyToAccount != nil {
- if blocked, err := f.db.IsBlocked(ctx, relevantAccounts.BoostedInReplyToAccount.ID, requestingAccount.ID, true); err != nil {
- return false, err
- } else if blocked {
- l.Trace("a block exists between requesting account and boosted reply to account")
- return false, nil
- }
- }
-
- // boost mentions accounts
- for _, a := range relevantAccounts.BoostedMentionedAccounts {
- if a == nil {
- continue
- }
- if blocked, err := f.db.IsBlocked(ctx, a.ID, requestingAccount.ID, true); err != nil {
- return false, err
- } else if blocked {
- l.Trace("a block exists between requesting account and a boosted mentioned account")
- return false, nil
- }
- }
-
- // Iterate mentions to check for blocks or requester mentions
- isMentioned, blockAmongMentions := false, false
- for _, a := range relevantAccounts.MentionedAccounts {
- if a == nil {
- continue
- }
-
- if blocked, err := f.db.IsBlocked(ctx, a.ID, requestingAccount.ID, true); err != nil {
- return false, err
- } else if blocked {
- blockAmongMentions = true
- break
- }
-
- if a.ID == requestingAccount.ID {
- isMentioned = true
- }
- }
-
- if blockAmongMentions {
- l.Trace("a block exists between requesting account and a mentioned account")
- return false, nil
- } else if isMentioned {
- // Requester mentioned, should always be visible
- return true, nil
- }
-
- // at this point we know neither account blocks the other, or another account mentioned or otherwise referred to in the status
- // that means it's now just a matter of checking the visibility settings of the status itself
- switch targetStatus.Visibility {
- case gtsmodel.VisibilityPublic, gtsmodel.VisibilityUnlocked:
- // no problem here
- case gtsmodel.VisibilityFollowersOnly:
- // Followers-only post, check for a one-way follow to target
- follows, err := f.db.IsFollowing(ctx, requestingAccount, targetAccount)
- if err != nil {
- return false, err
- }
- if !follows {
- l.Trace("requested status is followers only but requesting account is not a follower")
- return false, nil
- }
- case gtsmodel.VisibilityMutualsOnly:
- // Mutuals-only post, check for a mutual follow
- mutuals, err := f.db.IsMutualFollowing(ctx, requestingAccount, targetAccount)
- if err != nil {
- return false, err
- }
- if !mutuals {
- l.Trace("requested status is mutuals only but accounts aren't mufos")
- return false, nil
- }
- case gtsmodel.VisibilityDirect:
- l.Trace("requesting account requests a direct status it's not mentioned in")
- return false, nil // it's not mentioned -_-
- }
-
- // If we reached here, all is okay
- return true, nil
-}
-
-func (f *filter) StatusesVisible(ctx context.Context, statuses []*gtsmodel.Status, requestingAccount *gtsmodel.Account) ([]*gtsmodel.Status, error) {
- filtered := []*gtsmodel.Status{}
- for _, s := range statuses {
- visible, err := f.StatusVisible(ctx, s, requestingAccount)
- if err != nil {
- return nil, err
- }
- if visible {
- filtered = append(filtered, s)
- }
- }
- return filtered, nil
-}
diff --git a/test/envparsing.sh b/test/envparsing.sh
index 361866881..2a9be5155 100755
--- a/test/envparsing.sh
+++ b/test/envparsing.sh
@@ -20,43 +20,55 @@ EXPECT=$(cat <<"EOF"
"account-max-size": 99,
"account-sweep-freq": 1000000000,
"account-ttl": 10800000000000,
- "block-max-size": 100,
- "block-sweep-freq": 30000000000,
- "block-ttl": 300000000000,
- "domain-block-max-size": 1000,
+ "block-max-size": 1000,
+ "block-sweep-freq": 60000000000,
+ "block-ttl": 1800000000000,
+ "domain-block-max-size": 2000,
"domain-block-sweep-freq": 60000000000,
"domain-block-ttl": 86400000000000,
"emoji-category-max-size": 100,
- "emoji-category-sweep-freq": 30000000000,
- "emoji-category-ttl": 300000000000,
- "emoji-max-size": 500,
- "emoji-sweep-freq": 30000000000,
- "emoji-ttl": 300000000000,
- "media-max-size": 500,
- "media-sweep-freq": 30000000000,
- "media-ttl": 300000000000,
- "mention-max-size": 500,
- "mention-sweep-freq": 30000000000,
- "mention-ttl": 300000000000,
- "notification-max-size": 500,
- "notification-sweep-freq": 30000000000,
- "notification-ttl": 300000000000,
+ "emoji-category-sweep-freq": 60000000000,
+ "emoji-category-ttl": 1800000000000,
+ "emoji-max-size": 2000,
+ "emoji-sweep-freq": 60000000000,
+ "emoji-ttl": 1800000000000,
+ "follow-max-size": 2000,
+ "follow-request-max-size": 2000,
+ "follow-request-sweep-freq": 60000000000,
+ "follow-request-ttl": 1800000000000,
+ "follow-sweep-freq": 60000000000,
+ "follow-ttl": 1800000000000,
+ "media-max-size": 1000,
+ "media-sweep-freq": 60000000000,
+ "media-ttl": 1800000000000,
+ "mention-max-size": 2000,
+ "mention-sweep-freq": 60000000000,
+ "mention-ttl": 1800000000000,
+ "notification-max-size": 1000,
+ "notification-sweep-freq": 60000000000,
+ "notification-ttl": 1800000000000,
"report-max-size": 100,
- "report-sweep-freq": 30000000000,
- "report-ttl": 300000000000,
- "status-max-size": 500,
- "status-sweep-freq": 30000000000,
- "status-ttl": 300000000000,
- "tombstone-max-size": 100,
- "tombstone-sweep-freq": 30000000000,
- "tombstone-ttl": 300000000000,
- "user-max-size": 100,
- "user-sweep-freq": 30000000000,
- "user-ttl": 300000000000,
+ "report-sweep-freq": 60000000000,
+ "report-ttl": 1800000000000,
+ "status-fave-max-size": 2000,
+ "status-fave-sweep-freq": 60000000000,
+ "status-fave-ttl": 1800000000000,
+ "status-max-size": 2000,
+ "status-sweep-freq": 60000000000,
+ "status-ttl": 1800000000000,
+ "tombstone-max-size": 500,
+ "tombstone-sweep-freq": 60000000000,
+ "tombstone-ttl": 1800000000000,
+ "user-max-size": 500,
+ "user-sweep-freq": 60000000000,
+ "user-ttl": 1800000000000,
"webfinger-max-size": 250,
"webfinger-sweep-freq": 900000000000,
"webfinger-ttl": 86400000000000
- }
+ },
+ "visibility-max-size": 2000,
+ "visibility-sweep-freq": 60000000000,
+ "visibility-ttl": 1800000000000
},
"config-path": "internal/config/testdata/test.yaml",
"db-address": ":memory:",