diff --git a/internal/cache/domain/domain.go b/internal/cache/domain/domain.go new file mode 100644 index 000000000..4697f05a6 --- /dev/null +++ b/internal/cache/domain/domain.go @@ -0,0 +1,170 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package domain + +import ( + "fmt" + "time" + + "codeberg.org/gruf/go-cache/v3/ttl" + "github.com/miekg/dns" +) + +// BlockCache provides a means of caching domain blocks in memory to reduce load +// on an underlying storage mechanism, e.g. a database. +// +// It consists of a TTL primary cache that stores calculated domain string to block results, +// that on cache miss is filled by calculating block status by iterating over a list of all of +// the domain blocks stored in memory. This reduces CPU usage required by not need needing to +// iterate through a possible 100-1000s long block list, while saving memory by having a primary +// cache of limited size that evicts stale entries. The raw list of all domain blocks should in +// most cases be negligible when it comes to memory usage. +// +// The in-memory block list is kept up-to-date by means of a passed loader function during every +// call to .IsBlocked(). In the case of a nil internal block list, the loader function is called to +// hydrate the cache with the latest list of domain blocks. The .Clear() function can be used to invalidate +// the cache, e.g. when a domain block is added / deleted from the database. It will drop the current +// list of domain blocks and clear all entries from the primary cache. +type BlockCache struct { + pcache *ttl.Cache[string, bool] // primary cache of domains -> block results + blocks []block // raw list of all domain blocks, nil => not loaded. +} + +// New returns a new initialized BlockCache instance with given primary cache capacity and TTL. +func New(pcap int, pttl time.Duration) *BlockCache { + c := new(BlockCache) + c.pcache = new(ttl.Cache[string, bool]) + c.pcache.Init(0, pcap, pttl) + return c +} + +// Start will start the cache background eviction routine with given sweep frequency. If already running or a freq <= 0 provided, this is a no-op. This will block until the eviction routine has started. +func (b *BlockCache) Start(pfreq time.Duration) bool { + return b.pcache.Start(pfreq) +} + +// Stop will stop cache background eviction routine. If not running this is a no-op. This will block until the eviction routine has stopped. +func (b *BlockCache) Stop() bool { + return b.pcache.Stop() +} + +// IsBlocked checks whether domain is blocked. If the cache is not currently loaded, then the provided load function is used to hydrate it. +// NOTE: be VERY careful using any kind of locking mechanism within the load function, as this itself is ran within the cache mutex lock. +func (b *BlockCache) IsBlocked(domain string, load func() ([]string, error)) (bool, error) { + var blocked bool + + // Acquire cache lock + b.pcache.Lock() + defer b.pcache.Unlock() + + // Check primary cache for result + entry, ok := b.pcache.Cache.Get(domain) + if ok { + return entry.Value, nil + } + + if b.blocks == nil { + // Cache is not hydrated + // + // Load domains from callback + domains, err := load() + if err != nil { + return false, fmt.Errorf("error reloading cache: %w", err) + } + + // Drop all domain blocks and recreate + b.blocks = make([]block, len(domains)) + + for i, domain := range domains { + // Store pre-split labels for each domain block + b.blocks[i].labels = dns.SplitDomainName(domain) + } + } + + // Split domain into it separate labels + labels := dns.SplitDomainName(domain) + + // Compare this to our stored blocks + for _, block := range b.blocks { + if block.Blocks(labels) { + blocked = true + break + } + } + + // Store block result in primary cache + b.pcache.Cache.Set(domain, &ttl.Entry[string, bool]{ + Key: domain, + Value: blocked, + Expiry: time.Now().Add(b.pcache.TTL), + }) + + return blocked, nil +} + +// Clear will drop the currently loaded domain list, and clear the primary cache. +// This will trigger a reload on next call to .IsBlocked(). +func (b *BlockCache) Clear() { + // Drop all blocks. + b.pcache.Lock() + b.blocks = nil + b.pcache.Unlock() + + // Clear needs to be done _outside_ of + // lock, as also acquires a mutex lock. + b.pcache.Clear() +} + +// block represents a domain block, and stores the +// deconstructed labels of a singular domain block. +// e.g. []string{"gts", "superseriousbusiness", "org"}. +type block struct { + labels []string +} + +// Blocks checks whether the separated domain labels of an +// incoming domain matches the stored (receiving struct) block. +func (b block) Blocks(labels []string) bool { + // Calculate length difference + d := len(labels) - len(b.labels) + if d < 0 { + return false + } + + // Iterate backwards through domain block's + // labels, omparing against the incoming domain's. + // + // So for the following input: + // labels = []string{"mail", "google", "com"} + // b.labels = []string{"google", "com"} + // + // These would be matched in reverse order along + // the entirety of the block object's labels: + // "com" => match + // "google" => match + // + // And so would reach the end and return true. + for i := len(b.labels) - 1; i >= 0; i-- { + if b.labels[i] != labels[i+d] { + return false + } + } + + return true +} diff --git a/internal/cache/domain/domain_test.go b/internal/cache/domain/domain_test.go new file mode 100644 index 000000000..416ce5012 --- /dev/null +++ b/internal/cache/domain/domain_test.go @@ -0,0 +1,85 @@ +/* + GoToSocial + Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . +*/ + +package domain_test + +import ( + "errors" + "testing" + "time" + + "github.com/superseriousbusiness/gotosocial/internal/cache/domain" +) + +func TestBlockCache(t *testing.T) { + c := domain.New(100, time.Second) + + blocks := []string{ + "google.com", + "google.co.uk", + "pleroma.bad.host", + } + + loader := func() ([]string, error) { + t.Log("load: returning blocked domains") + return blocks, nil + } + + // Check a list of known blocked domains. + for _, domain := range []string{ + "google.com", + "mail.google.com", + "google.co.uk", + "mail.google.co.uk", + "pleroma.bad.host", + "dev.pleroma.bad.host", + } { + t.Logf("checking domain is blocked: %s", domain) + if b, _ := c.IsBlocked(domain, loader); !b { + t.Errorf("domain should be blocked: %s", domain) + } + } + + // Check a list of known unblocked domains. + for _, domain := range []string{ + "askjeeves.com", + "ask-kim.co.uk", + "google.ie", + "mail.google.ie", + "gts.bad.host", + "mastodon.bad.host", + } { + t.Logf("checking domain isn't blocked: %s", domain) + if b, _ := c.IsBlocked(domain, loader); b { + t.Errorf("domain should not be blocked: %s", domain) + } + } + + // Clear the cache + c.Clear() + + knownErr := errors.New("known error") + + // Check that reload is actually performed and returns our error + if _, err := c.IsBlocked("", func() ([]string, error) { + t.Log("load: returning known error") + return nil, knownErr + }); !errors.Is(err, knownErr) { + t.Errorf("is blocked did not return expected error: %v", err) + } +} diff --git a/internal/cache/gts.go b/internal/cache/gts.go index 6083b8693..3fa25ddef 100644 --- a/internal/cache/gts.go +++ b/internal/cache/gts.go @@ -20,6 +20,7 @@ import ( "codeberg.org/gruf/go-cache/v3/result" + "github.com/superseriousbusiness/gotosocial/internal/cache/domain" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" ) @@ -41,8 +42,8 @@ type GTSCaches interface { // Block provides access to the gtsmodel Block (account) database cache. Block() *result.Cache[*gtsmodel.Block] - // DomainBlock provides access to the gtsmodel DomainBlock database cache. - DomainBlock() *result.Cache[*gtsmodel.DomainBlock] + // DomainBlock provides access to the domain block database cache. + DomainBlock() *domain.BlockCache // Emoji provides access to the gtsmodel Emoji database cache. Emoji() *result.Cache[*gtsmodel.Emoji] @@ -74,7 +75,7 @@ func NewGTS() GTSCaches { type gtsCaches struct { account *result.Cache[*gtsmodel.Account] block *result.Cache[*gtsmodel.Block] - domainBlock *result.Cache[*gtsmodel.DomainBlock] + domainBlock *domain.BlockCache emoji *result.Cache[*gtsmodel.Emoji] emojiCategory *result.Cache[*gtsmodel.EmojiCategory] mention *result.Cache[*gtsmodel.Mention] @@ -151,7 +152,7 @@ func (c *gtsCaches) Block() *result.Cache[*gtsmodel.Block] { return c.block } -func (c *gtsCaches) DomainBlock() *result.Cache[*gtsmodel.DomainBlock] { +func (c *gtsCaches) DomainBlock() *domain.BlockCache { return c.domainBlock } @@ -212,14 +213,10 @@ func (c *gtsCaches) initBlock() { } func (c *gtsCaches) initDomainBlock() { - c.domainBlock = result.NewSized([]result.Lookup{ - {Name: "Domain"}, - }, func(d1 *gtsmodel.DomainBlock) *gtsmodel.DomainBlock { - d2 := new(gtsmodel.DomainBlock) - *d2 = *d1 - return d2 - }, config.GetCacheGTSDomainBlockMaxSize()) - c.domainBlock.SetTTL(config.GetCacheGTSDomainBlockTTL(), true) + c.domainBlock = domain.New( + config.GetCacheGTSDomainBlockMaxSize(), + config.GetCacheGTSDomainBlockTTL(), + ) } func (c *gtsCaches) initEmoji() { diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go index a5d9f61e2..5407f9656 100644 --- a/internal/db/bundb/domain.go +++ b/internal/db/bundb/domain.go @@ -50,46 +50,52 @@ func normalizeDomain(domain string) (out string, err error) { func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) db.Error { var err error + // Normalize the domain as punycode block.Domain, err = normalizeDomain(block.Domain) if err != nil { return err } - return d.state.Caches.GTS.DomainBlock().Store(block, func() error { - _, err := d.conn.NewInsert(). - Model(block). - Exec(ctx) + // Attempt to store domain in DB + if _, err := d.conn.NewInsert(). + Model(block). + Exec(ctx); err != nil { return d.conn.ProcessError(err) - }) + } + + // Clear the domain block cache (for later reload) + d.state.Caches.GTS.DomainBlock().Clear() + + return nil } func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, db.Error) { var err error + // Normalize the domain as punycode domain, err = normalizeDomain(domain) if err != nil { return nil, err } - return d.state.Caches.GTS.DomainBlock().Load("Domain", func() (*gtsmodel.DomainBlock, error) { - // Check for easy case, domain referencing *us* - if domain == "" || domain == config.GetAccountDomain() { - return nil, db.ErrNoEntries - } + // Check for easy case, domain referencing *us* + if domain == "" || domain == config.GetAccountDomain() || + domain == config.GetHost() { + return nil, db.ErrNoEntries + } - var block gtsmodel.DomainBlock + var block gtsmodel.DomainBlock - q := d.conn. - NewSelect(). - Model(&block). - Where("? = ?", bun.Ident("domain_block.domain"), domain). - Limit(1) - if err := q.Scan(ctx); err != nil { - return nil, d.conn.ProcessError(err) - } + // Look for block matching domain in DB + q := d.conn. + NewSelect(). + Model(&block). + Where("? = ?", bun.Ident("domain_block.domain"), domain) + if err := q.Scan(ctx); err != nil { + return nil, d.conn.ProcessError(err) + } - return &block, nil - }, domain) + return &block, nil } func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Error { @@ -108,18 +114,39 @@ func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Erro return d.conn.ProcessError(err) } - // Clear domain from cache - d.state.Caches.GTS.DomainBlock().Invalidate(domain) + // Clear the domain block cache (for later reload) + d.state.Caches.GTS.DomainBlock().Clear() return nil } func (d *domainDB) IsDomainBlocked(ctx context.Context, domain string) (bool, db.Error) { - block, err := d.GetDomainBlock(ctx, domain) - if err == nil || err == db.ErrNoEntries { - return (block != nil), nil + // Normalize the domain as punycode + domain, err := normalizeDomain(domain) + if err != nil { + return false, err } - return false, err + + // Check for easy case, domain referencing *us* + if domain == "" || domain == config.GetAccountDomain() || + domain == config.GetHost() { + return false, nil + } + + // Check the cache for a domain block (hydrating the cache with callback if necessary) + return d.state.Caches.GTS.DomainBlock().IsBlocked(domain, func() ([]string, error) { + var domains []string + + // Scan list of all blocked domains from DB + q := d.conn.NewSelect(). + Table("domain_blocks"). + Column("domain") + if err := q.Scan(ctx, &domains); err != nil { + return nil, d.conn.ProcessError(err) + } + + return domains, nil + }) } func (d *domainDB) AreDomainsBlocked(ctx context.Context, domains []string) (bool, db.Error) { diff --git a/internal/db/bundb/domain_test.go b/internal/db/bundb/domain_test.go index 41a73ff80..8091e6585 100644 --- a/internal/db/bundb/domain_test.go +++ b/internal/db/bundb/domain_test.go @@ -56,6 +56,38 @@ func (suite *DomainTestSuite) TestIsDomainBlocked() { suite.WithinDuration(time.Now(), domainBlock.CreatedAt, 10*time.Second) } +func (suite *DomainTestSuite) TestIsDomainBlockedWildcard() { + ctx := context.Background() + + domainBlock := >smodel.DomainBlock{ + ID: "01G204214Y9TNJEBX39C7G88SW", + Domain: "bad.apples", + CreatedByAccountID: suite.testAccounts["admin_account"].ID, + CreatedByAccount: suite.testAccounts["admin_account"], + } + + // no domain block exists for the given domain yet + blocked, err := suite.db.IsDomainBlocked(ctx, domainBlock.Domain) + suite.NoError(err) + suite.False(blocked) + + err = suite.db.CreateDomainBlock(ctx, domainBlock) + suite.NoError(err) + + // Start with the base block domain + domain := domainBlock.Domain + + for _, part := range []string{"extra", "domain", "parts"} { + // Prepend the next domain part + domain = part + "." + domain + + // Check that domain block is wildcarded for this subdomain + blocked, err = suite.db.IsDomainBlocked(ctx, domainBlock.Domain) + suite.NoError(err) + suite.True(blocked) + } +} + func (suite *DomainTestSuite) TestIsDomainBlockedNonASCII() { ctx := context.Background()