From dfc7656579349bda98d3097c473efbb6000e233b Mon Sep 17 00:00:00 2001 From: kim <89579420+NyaaaWhatsUpDoc@users.noreply.github.com> Date: Tue, 9 Jan 2024 13:12:43 +0000 Subject: [PATCH] [bugfix] fix higher-level explicit domain rules causing issues with lower-level domain blocking (#2513) * fix the sort direction of domain cache child nodes ... * add more domain cache test cases * add specific test for this bug to database domain test suite (thanks for writing this @tsmethurst!) * remove unused field (this was a previous attempt at a fix) * remove debugging println statements :innocent: --- internal/cache/domain/domain.go | 47 ++++++++++++++++----- internal/cache/domain/domain_test.go | 24 +++++++---- internal/db/bundb/domain_test.go | 62 ++++++++++++++++++++++++++++ 3 files changed, 114 insertions(+), 19 deletions(-) diff --git a/internal/cache/domain/domain.go b/internal/cache/domain/domain.go index 1b836ed28..274a244f7 100644 --- a/internal/cache/domain/domain.go +++ b/internal/cache/domain/domain.go @@ -19,10 +19,9 @@ import ( "fmt" + "slices" "strings" "sync/atomic" - - "golang.org/x/exp/slices" ) // Cache provides a means of caching domains in memory to reduce @@ -57,6 +56,24 @@ func (c *Cache) Matches(domain string, load func() ([]string, error)) (bool, err return false, fmt.Errorf("error reloading cache: %w", err) } + // Ensure the domains being inserted into the cache + // are sorted by number of domain parts. i.e. those + // with less parts are inserted last, else this can + // allow domains to fall through the matching code! + slices.SortFunc(domains, func(a, b string) int { + const k = +1 + an := strings.Count(a, ".") + bn := strings.Count(b, ".") + switch { + case an < bn: + return +k + case an > bn: + return -k + default: + return 0 + } + }) + // Allocate new radix trie // node to store matches. ptr = new(root) @@ -94,13 +111,13 @@ type root struct{ // Add will add the given domain to the radix trie. func (r *root) Add(domain string) { - r.root.add(strings.Split(domain, ".")) + r.root.Add(strings.Split(domain, ".")) } // Match will return whether the given domain matches // an existing stored domain in this radix trie. func (r *root) Match(domain string) bool { - return r.root.match(strings.Split(domain, ".")) + return r.root.Match(strings.Split(domain, ".")) } // Sort will sort the entire radix trie ensuring that @@ -114,7 +131,7 @@ func (r *root) Sort() { // String returns a string representation of node (and its descendants). func (r *root) String() string { buf := new(strings.Builder) - r.root.writestr(buf, "") + r.root.WriteStr(buf, "") return buf.String() } @@ -123,7 +140,7 @@ type node struct { child []*node } -func (n *node) add(parts []string) { +func (n *node) Add(parts []string) { if len(parts) == 0 { panic("invalid domain") } @@ -165,7 +182,7 @@ func (n *node) add(parts []string) { } } -func (n *node) match(parts []string) bool { +func (n *node) Match(parts []string) bool { for len(parts) > 0 { // Pop next domain part. i := len(parts) - 1 @@ -226,8 +243,16 @@ func (n *node) getChild(part string) *node { func (n *node) sort() { // Sort this node's slice of child nodes. - slices.SortFunc(n.child, func(i, j *node) bool { - return i.part < j.part + slices.SortFunc(n.child, func(i, j *node) int { + const k = -1 + switch { + case i.part < j.part: + return +k + case i.part > j.part: + return -k + default: + return 0 + } }) // Sort each child node's children. @@ -236,7 +261,7 @@ func (n *node) sort() { } } -func (n *node) writestr(buf *strings.Builder, prefix string) { +func (n *node) WriteStr(buf *strings.Builder, prefix string) { if prefix != "" { // Suffix joining '.' prefix += "." @@ -251,6 +276,6 @@ func (n *node) writestr(buf *strings.Builder, prefix string) { // Iterate through node children. for _, child := range n.child { - child.writestr(buf, prefix) + child.WriteStr(buf, prefix) } } diff --git a/internal/cache/domain/domain_test.go b/internal/cache/domain/domain_test.go index 9e091e1d0..974425b7c 100644 --- a/internal/cache/domain/domain_test.go +++ b/internal/cache/domain/domain_test.go @@ -28,9 +28,13 @@ func TestCache(t *testing.T) { c := new(domain.Cache) cachedDomains := []string{ - "google.com", - "google.co.uk", - "pleroma.bad.host", + "google.com", // + "mail.google.com", // should be ignored since covered above + "dev.mail.google.com", // same again + "google.co.uk", // + "mail.google.co.uk", // + "pleroma.bad.host", // + "pleroma.still.a.bad.host", // } loader := func() ([]string, error) { @@ -38,22 +42,25 @@ func TestCache(t *testing.T) { return cachedDomains, nil } - // Check a list of known cached domains. + // Check a list of known matching domains. for _, domain := range []string{ "google.com", "mail.google.com", + "dev.mail.google.com", "google.co.uk", "mail.google.co.uk", "pleroma.bad.host", "dev.pleroma.bad.host", + "pleroma.still.a.bad.host", + "dev.pleroma.still.a.bad.host", } { t.Logf("checking domain matches: %s", domain) if b, _ := c.Matches(domain, loader); !b { - t.Errorf("domain should be matched: %s", domain) + t.Fatalf("domain should be matched: %s", domain) } } - // Check a list of known uncached domains. + // Check a list of known unmatched domains. for _, domain := range []string{ "askjeeves.com", "ask-kim.co.uk", @@ -61,10 +68,11 @@ func TestCache(t *testing.T) { "mail.google.ie", "gts.bad.host", "mastodon.bad.host", + "akkoma.still.a.bad.host", } { t.Logf("checking domain isn't matched: %s", domain) if b, _ := c.Matches(domain, loader); b { - t.Errorf("domain should not be matched: %s", domain) + t.Fatalf("domain should not be matched: %s", domain) } } @@ -80,6 +88,6 @@ func TestCache(t *testing.T) { t.Log("load: returning known error") return nil, knownErr }); !errors.Is(err, knownErr) { - t.Errorf("matches did not return expected error: %v", err) + t.Fatalf("matches did not return expected error: %v", err) } } diff --git a/internal/db/bundb/domain_test.go b/internal/db/bundb/domain_test.go index ff687cf59..8164259e8 100644 --- a/internal/db/bundb/domain_test.go +++ b/internal/db/bundb/domain_test.go @@ -19,6 +19,7 @@ import ( "context" + "slices" "testing" "time" @@ -212,6 +213,67 @@ func (suite *DomainTestSuite) TestIsDomainBlockedNonASCII2() { suite.True(blocked) } +func (suite *DomainTestSuite) TestIsOtherDomainBlockedWildcardAndExplicit() { + ctx := context.Background() + + blocks := []*gtsmodel.DomainBlock{ + { + ID: "01G204214Y9TNJEBX39C7G88SW", + Domain: "bad.apples", + CreatedByAccountID: suite.testAccounts["admin_account"].ID, + CreatedByAccount: suite.testAccounts["admin_account"], + }, + { + ID: "01HKPSVQ864FQ2JJ01CDGPHHMJ", + Domain: "some.bad.apples", + CreatedByAccountID: suite.testAccounts["admin_account"].ID, + CreatedByAccount: suite.testAccounts["admin_account"], + }, + } + + for _, block := range blocks { + if err := suite.db.CreateDomainBlock(ctx, block); err != nil { + suite.FailNow(err.Error()) + } + } + + // Ensure each block created + // above is now present in the db. + dbBlocks, err := suite.db.GetDomainBlocks(ctx) + if err != nil { + suite.FailNow(err.Error()) + } + + for _, block := range blocks { + if !slices.ContainsFunc( + dbBlocks, + func(dbBlock *gtsmodel.DomainBlock) bool { + return block.Domain == dbBlock.Domain + }, + ) { + suite.FailNow("", "stored blocks did not contain %s", block.Domain) + } + } + + // All domains and subdomains + // should now be blocked, even + // ones without an explicit block. + for _, domain := range []string{ + "bad.apples", + "some.bad.apples", + "other.bad.apples", + } { + blocked, err := suite.db.IsDomainBlocked(ctx, domain) + if err != nil { + suite.FailNow(err.Error()) + } + + if !blocked { + suite.Fail("", "domain %s should be blocked", domain) + } + } +} + func TestDomainTestSuite(t *testing.T) { suite.Run(t, new(DomainTestSuite)) }