diff --git a/internal/db/bundb/domain.go b/internal/db/bundb/domain.go index 0d67837d7..5d262c676 100644 --- a/internal/db/bundb/domain.go +++ b/internal/db/bundb/domain.go @@ -28,6 +28,7 @@ "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "golang.org/x/net/idna" ) type domainDB struct { @@ -35,15 +36,28 @@ type domainDB struct { cache *cache.DomainBlockCache } +// normalizeDomain converts the given domain to lowercase +// then to punycode (for international domain names). +// +// Returns the resulting domain or an error if the +// punycode conversion fails. +func normalizeDomain(domain string) (out string, err error) { + out = strings.ToLower(domain) + out, err = idna.ToASCII(out) + return out, err +} + func (d *domainDB) CreateDomainBlock(ctx context.Context, block gtsmodel.DomainBlock) db.Error { - // Normalize to lowercase - block.Domain = strings.ToLower(block.Domain) + domain, err := normalizeDomain(block.Domain) + if err != nil { + return err + } + block.Domain = domain // Attempt to insert new domain block - _, err := d.conn.NewInsert(). + if _, err := d.conn.NewInsert(). Model(&block). - Exec(ctx, &block) - if err != nil { + Exec(ctx, &block); err != nil { return d.conn.ProcessError(err) } @@ -54,8 +68,11 @@ func (d *domainDB) CreateDomainBlock(ctx context.Context, block gtsmodel.DomainB } func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, db.Error) { - // Normalize to lowercase - domain = strings.ToLower(domain) + var err error + domain, err = normalizeDomain(domain) + if err != nil { + return nil, err + } // Check for easy case, domain referencing *us* if domain == "" || domain == config.GetAccountDomain() { @@ -100,15 +117,17 @@ func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel } func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Error { - // Normalize to lowercase - domain = strings.ToLower(domain) + var err error + domain, err = normalizeDomain(domain) + if err != nil { + return err + } // Attempt to delete domain block - _, err := d.conn.NewDelete(). + if _, err := d.conn.NewDelete(). Model((*gtsmodel.DomainBlock)(nil)). Where("domain = ?", domain). - Exec(ctx) - if err != nil { + Exec(ctx); err != nil { return d.conn.ProcessError(err) } diff --git a/internal/db/bundb/domain_test.go b/internal/db/bundb/domain_test.go index b326236ad..48c4a7798 100644 --- a/internal/db/bundb/domain_test.go +++ b/internal/db/bundb/domain_test.go @@ -59,6 +59,78 @@ func (suite *DomainTestSuite) TestIsDomainBlocked() { suite.True(blocked) } +func (suite *DomainTestSuite) TestIsDomainBlockedNonASCII() { + ctx := context.Background() + + now := time.Now() + + domainBlock := >smodel.DomainBlock{ + ID: "01G204214Y9TNJEBX39C7G88SW", + Domain: "xn--80aaa1bbb1h.com", + CreatedAt: now, + UpdatedAt: now, + CreatedByAccountID: suite.testAccounts["admin_account"].ID, + CreatedByAccount: suite.testAccounts["admin_account"], + } + + // no domain block exists for the given domain yet + blocked, err := suite.db.IsDomainBlocked(ctx, "какашка.com") + suite.NoError(err) + suite.False(blocked) + + blocked, err = suite.db.IsDomainBlocked(ctx, "xn--80aaa1bbb1h.com") + suite.NoError(err) + suite.False(blocked) + + err = suite.db.CreateDomainBlock(ctx, *domainBlock) + suite.NoError(err) + + // domain block now exists + blocked, err = suite.db.IsDomainBlocked(ctx, "какашка.com") + suite.NoError(err) + suite.True(blocked) + + blocked, err = suite.db.IsDomainBlocked(ctx, "xn--80aaa1bbb1h.com") + suite.NoError(err) + suite.True(blocked) +} + +func (suite *DomainTestSuite) TestIsDomainBlockedNonASCII2() { + ctx := context.Background() + + now := time.Now() + + domainBlock := >smodel.DomainBlock{ + ID: "01G204214Y9TNJEBX39C7G88SW", + Domain: "какашка.com", + CreatedAt: now, + UpdatedAt: now, + CreatedByAccountID: suite.testAccounts["admin_account"].ID, + CreatedByAccount: suite.testAccounts["admin_account"], + } + + // no domain block exists for the given domain yet + blocked, err := suite.db.IsDomainBlocked(ctx, "какашка.com") + suite.NoError(err) + suite.False(blocked) + + blocked, err = suite.db.IsDomainBlocked(ctx, "xn--80aaa1bbb1h.com") + suite.NoError(err) + suite.False(blocked) + + err = suite.db.CreateDomainBlock(ctx, *domainBlock) + suite.NoError(err) + + // domain block now exists + blocked, err = suite.db.IsDomainBlocked(ctx, "какашка.com") + suite.NoError(err) + suite.True(blocked) + + blocked, err = suite.db.IsDomainBlocked(ctx, "xn--80aaa1bbb1h.com") + suite.NoError(err) + suite.True(blocked) +} + func TestDomainTestSuite(t *testing.T) { suite.Run(t, new(DomainTestSuite)) }