[bugfix] fix higher-level explicit domain rules causing issues with lower-level domain blocking (#2513)

* fix the sort direction of domain cache child nodes ...

* add more domain cache test cases

* add specific test for this bug to database domain test suite (thanks for writing this @tsmethurst!)

* remove unused field (this was a previous attempt at a fix)

* remove debugging println statements 😇
This commit is contained in:
kim 2024-01-09 13:12:43 +00:00 committed by tobi
parent d5c305dc6e
commit ccecf5a7e4
3 changed files with 114 additions and 19 deletions

View file

@ -19,11 +19,10 @@
import ( import (
"fmt" "fmt"
"slices"
"strings" "strings"
"sync/atomic" "sync/atomic"
"unsafe" "unsafe"
"golang.org/x/exp/slices"
) )
// Cache provides a means of caching domains in memory to reduce // Cache provides a means of caching domains in memory to reduce
@ -58,6 +57,24 @@ func (c *Cache) Matches(domain string, load func() ([]string, error)) (bool, err
return false, fmt.Errorf("error reloading cache: %w", err) return false, fmt.Errorf("error reloading cache: %w", err)
} }
// Ensure the domains being inserted into the cache
// are sorted by number of domain parts. i.e. those
// with less parts are inserted last, else this can
// allow domains to fall through the matching code!
slices.SortFunc(domains, func(a, b string) int {
const k = +1
an := strings.Count(a, ".")
bn := strings.Count(b, ".")
switch {
case an < bn:
return +k
case an > bn:
return -k
default:
return 0
}
})
// Allocate new radix trie // Allocate new radix trie
// node to store matches. // node to store matches.
root := new(root) root := new(root)
@ -98,13 +115,13 @@ type root struct{
// Add will add the given domain to the radix trie. // Add will add the given domain to the radix trie.
func (r *root) Add(domain string) { func (r *root) Add(domain string) {
r.root.add(strings.Split(domain, ".")) r.root.Add(strings.Split(domain, "."))
} }
// Match will return whether the given domain matches // Match will return whether the given domain matches
// an existing stored domain in this radix trie. // an existing stored domain in this radix trie.
func (r *root) Match(domain string) bool { func (r *root) Match(domain string) bool {
return r.root.match(strings.Split(domain, ".")) return r.root.Match(strings.Split(domain, "."))
} }
// Sort will sort the entire radix trie ensuring that // Sort will sort the entire radix trie ensuring that
@ -118,7 +135,7 @@ func (r *root) Sort() {
// String returns a string representation of node (and its descendants). // String returns a string representation of node (and its descendants).
func (r *root) String() string { func (r *root) String() string {
buf := new(strings.Builder) buf := new(strings.Builder)
r.root.writestr(buf, "") r.root.WriteStr(buf, "")
return buf.String() return buf.String()
} }
@ -127,7 +144,7 @@ type node struct {
child []*node child []*node
} }
func (n *node) add(parts []string) { func (n *node) Add(parts []string) {
if len(parts) == 0 { if len(parts) == 0 {
panic("invalid domain") panic("invalid domain")
} }
@ -169,7 +186,7 @@ func (n *node) add(parts []string) {
} }
} }
func (n *node) match(parts []string) bool { func (n *node) Match(parts []string) bool {
for len(parts) > 0 { for len(parts) > 0 {
// Pop next domain part. // Pop next domain part.
i := len(parts) - 1 i := len(parts) - 1
@ -230,8 +247,16 @@ func (n *node) getChild(part string) *node {
func (n *node) sort() { func (n *node) sort() {
// Sort this node's slice of child nodes. // Sort this node's slice of child nodes.
slices.SortFunc(n.child, func(i, j *node) bool { slices.SortFunc(n.child, func(i, j *node) int {
return i.part < j.part const k = -1
switch {
case i.part < j.part:
return +k
case i.part > j.part:
return -k
default:
return 0
}
}) })
// Sort each child node's children. // Sort each child node's children.
@ -240,7 +265,7 @@ func (n *node) sort() {
} }
} }
func (n *node) writestr(buf *strings.Builder, prefix string) { func (n *node) WriteStr(buf *strings.Builder, prefix string) {
if prefix != "" { if prefix != "" {
// Suffix joining '.' // Suffix joining '.'
prefix += "." prefix += "."
@ -255,6 +280,6 @@ func (n *node) writestr(buf *strings.Builder, prefix string) {
// Iterate through node children. // Iterate through node children.
for _, child := range n.child { for _, child := range n.child {
child.writestr(buf, prefix) child.WriteStr(buf, prefix)
} }
} }

View file

@ -28,9 +28,13 @@ func TestCache(t *testing.T) {
c := new(domain.Cache) c := new(domain.Cache)
cachedDomains := []string{ cachedDomains := []string{
"google.com", "google.com", //
"google.co.uk", "mail.google.com", // should be ignored since covered above
"pleroma.bad.host", "dev.mail.google.com", // same again
"google.co.uk", //
"mail.google.co.uk", //
"pleroma.bad.host", //
"pleroma.still.a.bad.host", //
} }
loader := func() ([]string, error) { loader := func() ([]string, error) {
@ -38,22 +42,25 @@ func TestCache(t *testing.T) {
return cachedDomains, nil return cachedDomains, nil
} }
// Check a list of known cached domains. // Check a list of known matching domains.
for _, domain := range []string{ for _, domain := range []string{
"google.com", "google.com",
"mail.google.com", "mail.google.com",
"dev.mail.google.com",
"google.co.uk", "google.co.uk",
"mail.google.co.uk", "mail.google.co.uk",
"pleroma.bad.host", "pleroma.bad.host",
"dev.pleroma.bad.host", "dev.pleroma.bad.host",
"pleroma.still.a.bad.host",
"dev.pleroma.still.a.bad.host",
} { } {
t.Logf("checking domain matches: %s", domain) t.Logf("checking domain matches: %s", domain)
if b, _ := c.Matches(domain, loader); !b { if b, _ := c.Matches(domain, loader); !b {
t.Errorf("domain should be matched: %s", domain) t.Fatalf("domain should be matched: %s", domain)
} }
} }
// Check a list of known uncached domains. // Check a list of known unmatched domains.
for _, domain := range []string{ for _, domain := range []string{
"askjeeves.com", "askjeeves.com",
"ask-kim.co.uk", "ask-kim.co.uk",
@ -61,10 +68,11 @@ func TestCache(t *testing.T) {
"mail.google.ie", "mail.google.ie",
"gts.bad.host", "gts.bad.host",
"mastodon.bad.host", "mastodon.bad.host",
"akkoma.still.a.bad.host",
} { } {
t.Logf("checking domain isn't matched: %s", domain) t.Logf("checking domain isn't matched: %s", domain)
if b, _ := c.Matches(domain, loader); b { if b, _ := c.Matches(domain, loader); b {
t.Errorf("domain should not be matched: %s", domain) t.Fatalf("domain should not be matched: %s", domain)
} }
} }
@ -80,6 +88,6 @@ func TestCache(t *testing.T) {
t.Log("load: returning known error") t.Log("load: returning known error")
return nil, knownErr return nil, knownErr
}); !errors.Is(err, knownErr) { }); !errors.Is(err, knownErr) {
t.Errorf("matches did not return expected error: %v", err) t.Fatalf("matches did not return expected error: %v", err)
} }
} }

View file

@ -19,6 +19,7 @@
import ( import (
"context" "context"
"slices"
"testing" "testing"
"time" "time"
@ -212,6 +213,67 @@ func (suite *DomainTestSuite) TestIsDomainBlockedNonASCII2() {
suite.True(blocked) suite.True(blocked)
} }
func (suite *DomainTestSuite) TestIsOtherDomainBlockedWildcardAndExplicit() {
ctx := context.Background()
blocks := []*gtsmodel.DomainBlock{
{
ID: "01G204214Y9TNJEBX39C7G88SW",
Domain: "bad.apples",
CreatedByAccountID: suite.testAccounts["admin_account"].ID,
CreatedByAccount: suite.testAccounts["admin_account"],
},
{
ID: "01HKPSVQ864FQ2JJ01CDGPHHMJ",
Domain: "some.bad.apples",
CreatedByAccountID: suite.testAccounts["admin_account"].ID,
CreatedByAccount: suite.testAccounts["admin_account"],
},
}
for _, block := range blocks {
if err := suite.db.CreateDomainBlock(ctx, block); err != nil {
suite.FailNow(err.Error())
}
}
// Ensure each block created
// above is now present in the db.
dbBlocks, err := suite.db.GetDomainBlocks(ctx)
if err != nil {
suite.FailNow(err.Error())
}
for _, block := range blocks {
if !slices.ContainsFunc(
dbBlocks,
func(dbBlock *gtsmodel.DomainBlock) bool {
return block.Domain == dbBlock.Domain
},
) {
suite.FailNow("", "stored blocks did not contain %s", block.Domain)
}
}
// All domains and subdomains
// should now be blocked, even
// ones without an explicit block.
for _, domain := range []string{
"bad.apples",
"some.bad.apples",
"other.bad.apples",
} {
blocked, err := suite.db.IsDomainBlocked(ctx, domain)
if err != nil {
suite.FailNow(err.Error())
}
if !blocked {
suite.Fail("", "domain %s should be blocked", domain)
}
}
}
func TestDomainTestSuite(t *testing.T) { func TestDomainTestSuite(t *testing.T) {
suite.Run(t, new(DomainTestSuite)) suite.Run(t, new(DomainTestSuite))
} }