diff --git a/internal/cache/domain/domain.go b/internal/cache/domain/domain.go
new file mode 100644
index 000000000..4697f05a6
--- /dev/null
+++ b/internal/cache/domain/domain.go
@@ -0,0 +1,170 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package domain
+
+import (
+ "fmt"
+ "time"
+
+ "codeberg.org/gruf/go-cache/v3/ttl"
+ "github.com/miekg/dns"
+)
+
+// BlockCache provides a means of caching domain blocks in memory to reduce load
+// on an underlying storage mechanism, e.g. a database.
+//
+// It consists of a TTL primary cache that stores calculated domain string to block results,
+// that on cache miss is filled by calculating block status by iterating over a list of all of
+// the domain blocks stored in memory. This reduces CPU usage required by not need needing to
+// iterate through a possible 100-1000s long block list, while saving memory by having a primary
+// cache of limited size that evicts stale entries. The raw list of all domain blocks should in
+// most cases be negligible when it comes to memory usage.
+//
+// The in-memory block list is kept up-to-date by means of a passed loader function during every
+// call to .IsBlocked(). In the case of a nil internal block list, the loader function is called to
+// hydrate the cache with the latest list of domain blocks. The .Clear() function can be used to invalidate
+// the cache, e.g. when a domain block is added / deleted from the database. It will drop the current
+// list of domain blocks and clear all entries from the primary cache.
+type BlockCache struct {
+ pcache *ttl.Cache[string, bool] // primary cache of domains -> block results
+ blocks []block // raw list of all domain blocks, nil => not loaded.
+}
+
+// New returns a new initialized BlockCache instance with given primary cache capacity and TTL.
+func New(pcap int, pttl time.Duration) *BlockCache {
+ c := new(BlockCache)
+ c.pcache = new(ttl.Cache[string, bool])
+ c.pcache.Init(0, pcap, pttl)
+ return c
+}
+
+// Start will start the cache background eviction routine with given sweep frequency. If already running or a freq <= 0 provided, this is a no-op. This will block until the eviction routine has started.
+func (b *BlockCache) Start(pfreq time.Duration) bool {
+ return b.pcache.Start(pfreq)
+}
+
+// Stop will stop cache background eviction routine. If not running this is a no-op. This will block until the eviction routine has stopped.
+func (b *BlockCache) Stop() bool {
+ return b.pcache.Stop()
+}
+
+// IsBlocked checks whether domain is blocked. If the cache is not currently loaded, then the provided load function is used to hydrate it.
+// NOTE: be VERY careful using any kind of locking mechanism within the load function, as this itself is ran within the cache mutex lock.
+func (b *BlockCache) IsBlocked(domain string, load func() ([]string, error)) (bool, error) {
+ var blocked bool
+
+ // Acquire cache lock
+ b.pcache.Lock()
+ defer b.pcache.Unlock()
+
+ // Check primary cache for result
+ entry, ok := b.pcache.Cache.Get(domain)
+ if ok {
+ return entry.Value, nil
+ }
+
+ if b.blocks == nil {
+ // Cache is not hydrated
+ //
+ // Load domains from callback
+ domains, err := load()
+ if err != nil {
+ return false, fmt.Errorf("error reloading cache: %w", err)
+ }
+
+ // Drop all domain blocks and recreate
+ b.blocks = make([]block, len(domains))
+
+ for i, domain := range domains {
+ // Store pre-split labels for each domain block
+ b.blocks[i].labels = dns.SplitDomainName(domain)
+ }
+ }
+
+ // Split domain into it separate labels
+ labels := dns.SplitDomainName(domain)
+
+ // Compare this to our stored blocks
+ for _, block := range b.blocks {
+ if block.Blocks(labels) {
+ blocked = true
+ break
+ }
+ }
+
+ // Store block result in primary cache
+ b.pcache.Cache.Set(domain, &ttl.Entry[string, bool]{
+ Key: domain,
+ Value: blocked,
+ Expiry: time.Now().Add(b.pcache.TTL),
+ })
+
+ return blocked, nil
+}
+
+// Clear will drop the currently loaded domain list, and clear the primary cache.
+// This will trigger a reload on next call to .IsBlocked().
+func (b *BlockCache) Clear() {
+ // Drop all blocks.
+ b.pcache.Lock()
+ b.blocks = nil
+ b.pcache.Unlock()
+
+ // Clear needs to be done _outside_ of
+ // lock, as also acquires a mutex lock.
+ b.pcache.Clear()
+}
+
+// block represents a domain block, and stores the
+// deconstructed labels of a singular domain block.
+// e.g. []string{"gts", "superseriousbusiness", "org"}.
+type block struct {
+ labels []string
+}
+
+// Blocks checks whether the separated domain labels of an
+// incoming domain matches the stored (receiving struct) block.
+func (b block) Blocks(labels []string) bool {
+ // Calculate length difference
+ d := len(labels) - len(b.labels)
+ if d < 0 {
+ return false
+ }
+
+ // Iterate backwards through domain block's
+ // labels, omparing against the incoming domain's.
+ //
+ // So for the following input:
+ // labels = []string{"mail", "google", "com"}
+ // b.labels = []string{"google", "com"}
+ //
+ // These would be matched in reverse order along
+ // the entirety of the block object's labels:
+ // "com" => match
+ // "google" => match
+ //
+ // And so would reach the end and return true.
+ for i := len(b.labels) - 1; i >= 0; i-- {
+ if b.labels[i] != labels[i+d] {
+ return false
+ }
+ }
+
+ return true
+}
diff --git a/internal/cache/domain/domain_test.go b/internal/cache/domain/domain_test.go
new file mode 100644
index 000000000..416ce5012
--- /dev/null
+++ b/internal/cache/domain/domain_test.go
@@ -0,0 +1,85 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package domain_test
+
+import (
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/superseriousbusiness/gotosocial/internal/cache/domain"
+)
+
+func TestBlockCache(t *testing.T) {
+ c := domain.New(100, time.Second)
+
+ blocks := []string{
+ "google.com",
+ "google.co.uk",
+ "pleroma.bad.host",
+ }
+
+ loader := func() ([]string, error) {
+ t.Log("load: returning blocked domains")
+ return blocks, nil
+ }
+
+ // Check a list of known blocked domains.
+ for _, domain := range []string{
+ "google.com",
+ "mail.google.com",
+ "google.co.uk",
+ "mail.google.co.uk",
+ "pleroma.bad.host",
+ "dev.pleroma.bad.host",
+ } {
+ t.Logf("checking domain is blocked: %s", domain)
+ if b, _ := c.IsBlocked(domain, loader); !b {
+ t.Errorf("domain should be blocked: %s", domain)
+ }
+ }
+
+ // Check a list of known unblocked domains.
+ for _, domain := range []string{
+ "askjeeves.com",
+ "ask-kim.co.uk",
+ "google.ie",
+ "mail.google.ie",
+ "gts.bad.host",
+ "mastodon.bad.host",
+ } {
+ t.Logf("checking domain isn't blocked: %s", domain)
+ if b, _ := c.IsBlocked(domain, loader); b {
+ t.Errorf("domain should not be blocked: %s", domain)
+ }
+ }
+
+ // Clear the cache
+ c.Clear()
+
+ knownErr := errors.New("known error")
+
+ // Check that reload is actually performed and returns our error
+ if _, err := c.IsBlocked("", func() ([]string, error) {
+ t.Log("load: returning known error")
+ return nil, knownErr
+ }); !errors.Is(err, knownErr) {
+ t.Errorf("is blocked did not return expected error: %v", err)
+ }
+}
diff --git a/internal/cache/gts.go b/internal/cache/gts.go
index 6083b8693..3fa25ddef 100644
--- a/internal/cache/gts.go
+++ b/internal/cache/gts.go
@@ -20,6 +20,7 @@
import (
"codeberg.org/gruf/go-cache/v3/result"
+ "github.com/superseriousbusiness/gotosocial/internal/cache/domain"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
@@ -41,8 +42,8 @@ type GTSCaches interface {
// Block provides access to the gtsmodel Block (account) database cache.
Block() *result.Cache[*gtsmodel.Block]
- // DomainBlock provides access to the gtsmodel DomainBlock database cache.
- DomainBlock() *result.Cache[*gtsmodel.DomainBlock]
+ // DomainBlock provides access to the domain block database cache.
+ DomainBlock() *domain.BlockCache
// Emoji provides access to the gtsmodel Emoji database cache.
Emoji() *result.Cache[*gtsmodel.Emoji]
@@ -74,7 +75,7 @@ func NewGTS() GTSCaches {
type gtsCaches struct {
account *result.Cache[*gtsmodel.Account]
block *result.Cache[*gtsmodel.Block]
- domainBlock *result.Cache[*gtsmodel.DomainBlock]
+ domainBlock *domain.BlockCache
emoji *result.Cache[*gtsmodel.Emoji]
emojiCategory *result.Cache[*gtsmodel.EmojiCategory]
mention *result.Cache[*gtsmodel.Mention]
@@ -151,7 +152,7 @@ func (c *gtsCaches) Block() *result.Cache[*gtsmodel.Block] {
return c.block
}
-func (c *gtsCaches) DomainBlock() *result.Cache[*gtsmodel.DomainBlock] {
+func (c *gtsCaches) DomainBlock() *domain.BlockCache {
return c.domainBlock
}
@@ -212,14 +213,10 @@ func (c *gtsCaches) initBlock() {
}
func (c *gtsCaches) initDomainBlock() {
- c.domainBlock = result.NewSized([]result.Lookup{
- {Name: "Domain"},
- }, func(d1 *gtsmodel.DomainBlock) *gtsmodel.DomainBlock {
- d2 := new(gtsmodel.DomainBlock)
- *d2 = *d1
- return d2
- }, config.GetCacheGTSDomainBlockMaxSize())
- c.domainBlock.SetTTL(config.GetCacheGTSDomainBlockTTL(), true)
+ c.domainBlock = domain.New(
+ config.GetCacheGTSDomainBlockMaxSize(),
+ config.GetCacheGTSDomainBlockTTL(),
+ )
}
func (c *gtsCaches) initEmoji() {
diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go
index a5d9f61e2..5407f9656 100644
--- a/internal/db/bundb/domain.go
+++ b/internal/db/bundb/domain.go
@@ -50,46 +50,52 @@ func normalizeDomain(domain string) (out string, err error) {
func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) db.Error {
var err error
+ // Normalize the domain as punycode
block.Domain, err = normalizeDomain(block.Domain)
if err != nil {
return err
}
- return d.state.Caches.GTS.DomainBlock().Store(block, func() error {
- _, err := d.conn.NewInsert().
- Model(block).
- Exec(ctx)
+ // Attempt to store domain in DB
+ if _, err := d.conn.NewInsert().
+ Model(block).
+ Exec(ctx); err != nil {
return d.conn.ProcessError(err)
- })
+ }
+
+ // Clear the domain block cache (for later reload)
+ d.state.Caches.GTS.DomainBlock().Clear()
+
+ return nil
}
func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, db.Error) {
var err error
+ // Normalize the domain as punycode
domain, err = normalizeDomain(domain)
if err != nil {
return nil, err
}
- return d.state.Caches.GTS.DomainBlock().Load("Domain", func() (*gtsmodel.DomainBlock, error) {
- // Check for easy case, domain referencing *us*
- if domain == "" || domain == config.GetAccountDomain() {
- return nil, db.ErrNoEntries
- }
+ // Check for easy case, domain referencing *us*
+ if domain == "" || domain == config.GetAccountDomain() ||
+ domain == config.GetHost() {
+ return nil, db.ErrNoEntries
+ }
- var block gtsmodel.DomainBlock
+ var block gtsmodel.DomainBlock
- q := d.conn.
- NewSelect().
- Model(&block).
- Where("? = ?", bun.Ident("domain_block.domain"), domain).
- Limit(1)
- if err := q.Scan(ctx); err != nil {
- return nil, d.conn.ProcessError(err)
- }
+ // Look for block matching domain in DB
+ q := d.conn.
+ NewSelect().
+ Model(&block).
+ Where("? = ?", bun.Ident("domain_block.domain"), domain)
+ if err := q.Scan(ctx); err != nil {
+ return nil, d.conn.ProcessError(err)
+ }
- return &block, nil
- }, domain)
+ return &block, nil
}
func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Error {
@@ -108,18 +114,39 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Erro
return d.conn.ProcessError(err)
}
- // Clear domain from cache
- d.state.Caches.GTS.DomainBlock().Invalidate(domain)
+ // Clear the domain block cache (for later reload)
+ d.state.Caches.GTS.DomainBlock().Clear()
return nil
}
func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) {
- block, err := d.GetDomainBlock(ctx, domain)
- if err == nil || err == db.ErrNoEntries {
- return (block != nil), nil
+ // Normalize the domain as punycode
+ domain, err := normalizeDomain(domain)
+ if err != nil {
+ return false, err
}
- return false, err
+
+ // Check for easy case, domain referencing *us*
+ if domain == "" || domain == config.GetAccountDomain() ||
+ domain == config.GetHost() {
+ return false, nil
+ }
+
+ // Check the cache for a domain block (hydrating the cache with callback if necessary)
+ return d.state.Caches.GTS.DomainBlock().IsBlocked(domain, func() ([]string, error) {
+ var domains []string
+
+ // Scan list of all blocked domains from DB
+ q := d.conn.NewSelect().
+ Table("domain_blocks").
+ Column("domain")
+ if err := q.Scan(ctx, &domains); err != nil {
+ return nil, d.conn.ProcessError(err)
+ }
+
+ return domains, nil
+ })
}
func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) {
diff --git a/internal/db/bundb/domain_test.go b/internal/db/bundb/domain_test.go
index 41a73ff80..8091e6585 100644
--- a/internal/db/bundb/domain_test.go
+++ b/internal/db/bundb/domain_test.go
@@ -56,6 +56,38 @@ func (suite *DomainTestSuite) TestIsDomainBlocked() {
suite.WithinDuration(time.Now(), domainBlock.CreatedAt, 10*time.Second)
}
+func (suite *DomainTestSuite) TestIsDomainBlockedWildcard() {
+ ctx := context.Background()
+
+ domainBlock := >smodel.DomainBlock{
+ ID: "01G204214Y9TNJEBX39C7G88SW",
+ Domain: "bad.apples",
+ CreatedByAccountID: suite.testAccounts["admin_account"].ID,
+ CreatedByAccount: suite.testAccounts["admin_account"],
+ }
+
+ // no domain block exists for the given domain yet
+ blocked, err := suite.db.IsDomainBlocked(ctx, domainBlock.Domain)
+ suite.NoError(err)
+ suite.False(blocked)
+
+ err = suite.db.CreateDomainBlock(ctx, domainBlock)
+ suite.NoError(err)
+
+ // Start with the base block domain
+ domain := domainBlock.Domain
+
+ for _, part := range []string{"extra", "domain", "parts"} {
+ // Prepend the next domain part
+ domain = part + "." + domain
+
+ // Check that domain block is wildcarded for this subdomain
+ blocked, err = suite.db.IsDomainBlocked(ctx, domainBlock.Domain)
+ suite.NoError(err)
+ suite.True(blocked)
+ }
+}
+
func (suite *DomainTestSuite) TestIsDomainBlockedNonASCII() {
ctx := context.Background()