mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2024-12-22 18:22:11 +00:00
[chore/bugfix] Domain block tidying up, Implement first pass of 207 Multi-Status
(#1886)
* [chore/refactor] update domain block processing * expose domain block import errors a lil better * move/remove unused query keys
This commit is contained in:
parent
d9c69f6ce0
commit
e70bf8a6c8
|
@ -42,8 +42,6 @@
|
|||
EmailPath = BasePath + "/email"
|
||||
EmailTestPath = EmailPath + "/test"
|
||||
|
||||
ExportQueryKey = "export"
|
||||
ImportQueryKey = "import"
|
||||
IDKey = "id"
|
||||
FilterQueryKey = "filter"
|
||||
MaxShortcodeDomainKey = "max_shortcode_domain"
|
||||
|
|
|
@ -21,7 +21,6 @@
|
|||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
|
||||
|
@ -140,48 +139,78 @@ func (m *Module) DomainBlocksPOSTHandler(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
imp := false
|
||||
importString := c.Query(ImportQueryKey)
|
||||
if importString != "" {
|
||||
i, err := strconv.ParseBool(importString)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("error parsing %s: %s", ImportQueryKey, err)
|
||||
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
imp = i
|
||||
importing, errWithCode := apiutil.ParseDomainBlockImport(c.Query(apiutil.DomainBlockImportKey), false)
|
||||
if errWithCode != nil {
|
||||
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
|
||||
form := &apimodel.DomainBlockCreateRequest{}
|
||||
form := new(apimodel.DomainBlockCreateRequest)
|
||||
if err := c.ShouldBind(form); err != nil {
|
||||
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
|
||||
if err := validateCreateDomainBlock(form, imp); err != nil {
|
||||
err := fmt.Errorf("error validating form: %s", err)
|
||||
if err := validateCreateDomainBlock(form, importing); err != nil {
|
||||
err := fmt.Errorf("error validating form: %w", err)
|
||||
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
|
||||
if imp {
|
||||
// we're importing multiple blocks
|
||||
domainBlocks, errWithCode := m.processor.Admin().DomainBlocksImport(c.Request.Context(), authed.Account, form.Domains)
|
||||
if !importing {
|
||||
// Single domain block creation.
|
||||
domainBlock, errWithCode := m.processor.Admin().DomainBlockCreate(
|
||||
c.Request.Context(),
|
||||
authed.Account,
|
||||
form.Domain,
|
||||
form.Obfuscate,
|
||||
form.PublicComment,
|
||||
form.PrivateComment,
|
||||
"", // No sub ID for single block creation.
|
||||
)
|
||||
if errWithCode != nil {
|
||||
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, domainBlocks)
|
||||
|
||||
c.JSON(http.StatusOK, domainBlock)
|
||||
return
|
||||
}
|
||||
|
||||
// we're just creating one block
|
||||
domainBlock, errWithCode := m.processor.Admin().DomainBlockCreate(c.Request.Context(), authed.Account, form.Domain, form.Obfuscate, form.PublicComment, form.PrivateComment, "")
|
||||
// We're importing multiple domain blocks,
|
||||
// so we're looking at a multi-status response.
|
||||
multiStatus, errWithCode := m.processor.Admin().DomainBlocksImport(
|
||||
c.Request.Context(),
|
||||
authed.Account,
|
||||
form.Domains, // Pass the file through.
|
||||
)
|
||||
if errWithCode != nil {
|
||||
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, domainBlock)
|
||||
|
||||
// TODO: Return 207 and multiStatus data nicely
|
||||
// when supported by the admin panel.
|
||||
|
||||
if multiStatus.Metadata.Failure != 0 {
|
||||
failures := make(map[string]any, multiStatus.Metadata.Failure)
|
||||
for _, entry := range multiStatus.Data {
|
||||
// nolint:forcetypeassert
|
||||
failures[entry.Resource.(string)] = entry.Message
|
||||
}
|
||||
|
||||
err := fmt.Errorf("one or more errors importing domain blocks: %+v", failures)
|
||||
apiutil.ErrorHandler(c, gtserror.NewErrorUnprocessableEntity(err, err.Error()), m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
|
||||
// Success, return slice of domain blocks.
|
||||
domainBlocks := make([]any, 0, multiStatus.Metadata.Success)
|
||||
for _, entry := range multiStatus.Data {
|
||||
domainBlocks = append(domainBlocks, entry.Resource)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, domainBlocks)
|
||||
}
|
||||
|
||||
func validateCreateDomainBlock(form *apimodel.DomainBlockCreateRequest, imp bool) error {
|
||||
|
|
|
@ -18,10 +18,8 @@
|
|||
package admin
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
|
||||
|
@ -87,26 +85,19 @@ func (m *Module) DomainBlockGETHandler(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
domainBlockID := c.Param(IDKey)
|
||||
if domainBlockID == "" {
|
||||
err := errors.New("no domain block id specified")
|
||||
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
|
||||
domainBlockID, errWithCode := apiutil.ParseID(c.Param(apiutil.IDKey))
|
||||
if errWithCode != nil {
|
||||
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
|
||||
export := false
|
||||
exportString := c.Query(ExportQueryKey)
|
||||
if exportString != "" {
|
||||
i, err := strconv.ParseBool(exportString)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("error parsing %s: %s", ExportQueryKey, err)
|
||||
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
export = i
|
||||
export, errWithCode := apiutil.ParseDomainBlockExport(c.Query(apiutil.DomainBlockExportKey), false)
|
||||
if errWithCode != nil {
|
||||
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
|
||||
domainBlock, errWithCode := m.processor.Admin().DomainBlockGet(c.Request.Context(), authed.Account, domainBlockID, export)
|
||||
domainBlock, errWithCode := m.processor.Admin().DomainBlockGet(c.Request.Context(), domainBlockID, export)
|
||||
if errWithCode != nil {
|
||||
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
||||
return
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
|
||||
|
@ -92,16 +91,10 @@ func (m *Module) DomainBlocksGETHandler(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
export := false
|
||||
exportString := c.Query(ExportQueryKey)
|
||||
if exportString != "" {
|
||||
i, err := strconv.ParseBool(exportString)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("error parsing %s: %s", ExportQueryKey, err)
|
||||
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
export = i
|
||||
export, errWithCode := apiutil.ParseDomainBlockExport(c.Query(apiutil.DomainBlockExportKey), false)
|
||||
if errWithCode != nil {
|
||||
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
|
||||
return
|
||||
}
|
||||
|
||||
domainBlocks, errWithCode := m.processor.Admin().DomainBlocksGet(c.Request.Context(), authed.Account, export)
|
||||
|
|
90
internal/api/model/multistatus.go
Normal file
90
internal/api/model/multistatus.go
Normal file
|
@ -0,0 +1,90 @@
|
|||
// GoToSocial
|
||||
// Copyright (C) GoToSocial Authors admin@gotosocial.org
|
||||
// SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
//
|
||||
// This program is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU Affero General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// This program is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU Affero General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU Affero General Public License
|
||||
// along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
package model
|
||||
|
||||
// MultiStatus models a multistatus HTTP response body.
|
||||
// This model should be transmitted along with http code
|
||||
// 207 MULTI-STATUS to indicate a mixture of responses.
|
||||
// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/207
|
||||
//
|
||||
// swagger:model multiStatus
|
||||
type MultiStatus struct {
|
||||
Data []MultiStatusEntry `json:"data"`
|
||||
Metadata MultiStatusMetadata `json:"metadata"`
|
||||
}
|
||||
|
||||
// MultiStatusEntry models one entry in multistatus data.
|
||||
// It can model either a success or a failure. The type
|
||||
// and value of `Resource` is left to the discretion of
|
||||
// the caller, but at minimum it should be expected to be
|
||||
// JSON-serializable.
|
||||
//
|
||||
// swagger:model multiStatusEntry
|
||||
type MultiStatusEntry struct {
|
||||
// The resource/result for this entry.
|
||||
// Value may be any type, check the docs
|
||||
// per endpoint to see which to expect.
|
||||
Resource any `json:"resource"`
|
||||
// Message/error message for this entry.
|
||||
Message string `json:"message"`
|
||||
// HTTP status code of this entry.
|
||||
Status int `json:"status"`
|
||||
}
|
||||
|
||||
// MultiStatusMetadata models an at-a-glance summary of
|
||||
// the data contained in the MultiStatus.
|
||||
//
|
||||
// swagger:model multiStatusMetadata
|
||||
type MultiStatusMetadata struct {
|
||||
// Success count + failure count.
|
||||
Total int `json:"total"`
|
||||
// Count of successful results (2xx).
|
||||
Success int `json:"success"`
|
||||
// Count of unsuccessful results (!2xx).
|
||||
Failure int `json:"failure"`
|
||||
}
|
||||
|
||||
// NewMultiStatus returns a new MultiStatus API model with
|
||||
// the provided entries, which will be iterated through to
|
||||
// look for 2xx and non 2xx status codes, in order to count
|
||||
// successes and failures.
|
||||
func NewMultiStatus(entries []MultiStatusEntry) *MultiStatus {
|
||||
var (
|
||||
successCount int
|
||||
failureCount int
|
||||
total = len(entries)
|
||||
)
|
||||
|
||||
for _, e := range entries {
|
||||
// Outside 2xx range = failure.
|
||||
if e.Status > 299 || e.Status < 200 {
|
||||
failureCount++
|
||||
} else {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
return &MultiStatus{
|
||||
Data: entries,
|
||||
Metadata: MultiStatusMetadata{
|
||||
Total: total,
|
||||
Success: successCount,
|
||||
Failure: failureCount,
|
||||
},
|
||||
}
|
||||
}
|
|
@ -27,6 +27,7 @@
|
|||
const (
|
||||
/* Common keys */
|
||||
|
||||
IDKey = "id"
|
||||
LimitKey = "limit"
|
||||
LocalKey = "local"
|
||||
MaxIDKey = "max_id"
|
||||
|
@ -41,6 +42,11 @@
|
|||
SearchQueryKey = "q"
|
||||
SearchResolveKey = "resolve"
|
||||
SearchTypeKey = "type"
|
||||
|
||||
/* Domain block keys */
|
||||
|
||||
DomainBlockExportKey = "export"
|
||||
DomainBlockImportKey = "import"
|
||||
)
|
||||
|
||||
// parseError returns gtserror.WithCode set to 400 Bad Request, to indicate
|
||||
|
@ -50,6 +56,8 @@ func parseError(key string, value, defaultValue any, err error) gtserror.WithCod
|
|||
return gtserror.NewErrorBadRequest(err, err.Error())
|
||||
}
|
||||
|
||||
// requiredError returns gtserror.WithCode set to 400 Bad Request, to indicate
|
||||
// to the caller a required key value was not provided, or was empty.
|
||||
func requiredError(key string) gtserror.WithCode {
|
||||
err := fmt.Errorf("required key %s was not set or had empty value", key)
|
||||
return gtserror.NewErrorBadRequest(err, err.Error())
|
||||
|
@ -60,111 +68,51 @@ func requiredError(key string) gtserror.WithCode {
|
|||
*/
|
||||
|
||||
func ParseLimit(value string, defaultValue int, max, min int) (int, gtserror.WithCode) {
|
||||
key := LimitKey
|
||||
|
||||
if value == "" {
|
||||
return defaultValue, nil
|
||||
}
|
||||
|
||||
i, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return defaultValue, parseError(key, value, defaultValue, err)
|
||||
}
|
||||
|
||||
if i > max {
|
||||
i = max
|
||||
} else if i < min {
|
||||
i = min
|
||||
}
|
||||
|
||||
return i, nil
|
||||
return parseInt(value, defaultValue, max, min, LimitKey)
|
||||
}
|
||||
|
||||
func ParseLocal(value string, defaultValue bool) (bool, gtserror.WithCode) {
|
||||
key := LimitKey
|
||||
|
||||
if value == "" {
|
||||
return defaultValue, nil
|
||||
}
|
||||
|
||||
i, err := strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
return defaultValue, parseError(key, value, defaultValue, err)
|
||||
}
|
||||
|
||||
return i, nil
|
||||
return parseBool(value, defaultValue, LocalKey)
|
||||
}
|
||||
|
||||
func ParseSearchExcludeUnreviewed(value string, defaultValue bool) (bool, gtserror.WithCode) {
|
||||
key := SearchExcludeUnreviewedKey
|
||||
|
||||
if value == "" {
|
||||
return defaultValue, nil
|
||||
}
|
||||
|
||||
i, err := strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
return defaultValue, parseError(key, value, defaultValue, err)
|
||||
}
|
||||
|
||||
return i, nil
|
||||
return parseBool(value, defaultValue, SearchExcludeUnreviewedKey)
|
||||
}
|
||||
|
||||
func ParseSearchFollowing(value string, defaultValue bool) (bool, gtserror.WithCode) {
|
||||
key := SearchFollowingKey
|
||||
|
||||
if value == "" {
|
||||
return defaultValue, nil
|
||||
}
|
||||
|
||||
i, err := strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
return defaultValue, parseError(key, value, defaultValue, err)
|
||||
}
|
||||
|
||||
return i, nil
|
||||
return parseBool(value, defaultValue, SearchFollowingKey)
|
||||
}
|
||||
|
||||
func ParseSearchOffset(value string, defaultValue int, max, min int) (int, gtserror.WithCode) {
|
||||
key := SearchOffsetKey
|
||||
|
||||
if value == "" {
|
||||
return defaultValue, nil
|
||||
}
|
||||
|
||||
i, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return defaultValue, parseError(key, value, defaultValue, err)
|
||||
}
|
||||
|
||||
if i > max {
|
||||
i = max
|
||||
} else if i < min {
|
||||
i = min
|
||||
}
|
||||
|
||||
return i, nil
|
||||
return parseInt(value, defaultValue, max, min, SearchOffsetKey)
|
||||
}
|
||||
|
||||
func ParseSearchResolve(value string, defaultValue bool) (bool, gtserror.WithCode) {
|
||||
key := SearchResolveKey
|
||||
return parseBool(value, defaultValue, SearchResolveKey)
|
||||
}
|
||||
|
||||
if value == "" {
|
||||
return defaultValue, nil
|
||||
}
|
||||
func ParseDomainBlockExport(value string, defaultValue bool) (bool, gtserror.WithCode) {
|
||||
return parseBool(value, defaultValue, DomainBlockExportKey)
|
||||
}
|
||||
|
||||
i, err := strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
return defaultValue, parseError(key, value, defaultValue, err)
|
||||
}
|
||||
|
||||
return i, nil
|
||||
func ParseDomainBlockImport(value string, defaultValue bool) (bool, gtserror.WithCode) {
|
||||
return parseBool(value, defaultValue, DomainBlockImportKey)
|
||||
}
|
||||
|
||||
/*
|
||||
Parse functions for *REQUIRED* parameters.
|
||||
*/
|
||||
|
||||
func ParseID(value string) (string, gtserror.WithCode) {
|
||||
key := IDKey
|
||||
|
||||
if value == "" {
|
||||
return "", requiredError(key)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func ParseSearchLookup(value string) (string, gtserror.WithCode) {
|
||||
key := SearchLookupKey
|
||||
|
||||
|
@ -184,3 +132,39 @@ func ParseSearchQuery(value string) (string, gtserror.WithCode) {
|
|||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
/*
|
||||
Internal functions
|
||||
*/
|
||||
|
||||
func parseBool(value string, defaultValue bool, key string) (bool, gtserror.WithCode) {
|
||||
if value == "" {
|
||||
return defaultValue, nil
|
||||
}
|
||||
|
||||
i, err := strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
return defaultValue, parseError(key, value, defaultValue, err)
|
||||
}
|
||||
|
||||
return i, nil
|
||||
}
|
||||
|
||||
func parseInt(value string, defaultValue int, max int, min int, key string) (int, gtserror.WithCode) {
|
||||
if value == "" {
|
||||
return defaultValue, nil
|
||||
}
|
||||
|
||||
i, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return defaultValue, parseError(key, value, defaultValue, err)
|
||||
}
|
||||
|
||||
if i > max {
|
||||
i = max
|
||||
} else if i < min {
|
||||
i = min
|
||||
}
|
||||
|
||||
return i, nil
|
||||
}
|
||||
|
|
22
internal/cache/gts.go
vendored
22
internal/cache/gts.go
vendored
|
@ -35,6 +35,7 @@ type GTSCaches struct {
|
|||
emojiCategory *result.Cache[*gtsmodel.EmojiCategory]
|
||||
follow *result.Cache[*gtsmodel.Follow]
|
||||
followRequest *result.Cache[*gtsmodel.FollowRequest]
|
||||
instance *result.Cache[*gtsmodel.Instance]
|
||||
list *result.Cache[*gtsmodel.List]
|
||||
listEntry *result.Cache[*gtsmodel.ListEntry]
|
||||
media *result.Cache[*gtsmodel.MediaAttachment]
|
||||
|
@ -59,6 +60,7 @@ func (c *GTSCaches) Init() {
|
|||
c.initEmojiCategory()
|
||||
c.initFollow()
|
||||
c.initFollowRequest()
|
||||
c.initInstance()
|
||||
c.initList()
|
||||
c.initListEntry()
|
||||
c.initMedia()
|
||||
|
@ -80,6 +82,7 @@ func (c *GTSCaches) Start() {
|
|||
tryStart(c.emojiCategory, config.GetCacheGTSEmojiCategorySweepFreq())
|
||||
tryStart(c.follow, config.GetCacheGTSFollowSweepFreq())
|
||||
tryStart(c.followRequest, config.GetCacheGTSFollowRequestSweepFreq())
|
||||
tryStart(c.instance, config.GetCacheGTSInstanceSweepFreq())
|
||||
tryStart(c.list, config.GetCacheGTSListSweepFreq())
|
||||
tryStart(c.listEntry, config.GetCacheGTSListEntrySweepFreq())
|
||||
tryStart(c.media, config.GetCacheGTSMediaSweepFreq())
|
||||
|
@ -106,6 +109,7 @@ func (c *GTSCaches) Stop() {
|
|||
tryStop(c.emojiCategory, config.GetCacheGTSEmojiCategorySweepFreq())
|
||||
tryStop(c.follow, config.GetCacheGTSFollowSweepFreq())
|
||||
tryStop(c.followRequest, config.GetCacheGTSFollowRequestSweepFreq())
|
||||
tryStop(c.instance, config.GetCacheGTSInstanceSweepFreq())
|
||||
tryStop(c.list, config.GetCacheGTSListSweepFreq())
|
||||
tryStop(c.listEntry, config.GetCacheGTSListEntrySweepFreq())
|
||||
tryStop(c.media, config.GetCacheGTSMediaSweepFreq())
|
||||
|
@ -154,6 +158,11 @@ func (c *GTSCaches) FollowRequest() *result.Cache[*gtsmodel.FollowRequest] {
|
|||
return c.followRequest
|
||||
}
|
||||
|
||||
// Instance provides access to the gtsmodel Instance database cache.
|
||||
func (c *GTSCaches) Instance() *result.Cache[*gtsmodel.Instance] {
|
||||
return c.instance
|
||||
}
|
||||
|
||||
// List provides access to the gtsmodel List database cache.
|
||||
func (c *GTSCaches) List() *result.Cache[*gtsmodel.List] {
|
||||
return c.list
|
||||
|
@ -301,6 +310,19 @@ func (c *GTSCaches) initFollowRequest() {
|
|||
c.followRequest.SetTTL(config.GetCacheGTSFollowRequestTTL(), true)
|
||||
}
|
||||
|
||||
func (c *GTSCaches) initInstance() {
|
||||
c.instance = result.New([]result.Lookup{
|
||||
{Name: "ID"},
|
||||
{Name: "Domain"},
|
||||
}, func(i1 *gtsmodel.Instance) *gtsmodel.Instance {
|
||||
i2 := new(gtsmodel.Instance)
|
||||
*i2 = *i1
|
||||
return i1
|
||||
}, config.GetCacheGTSInstanceMaxSize())
|
||||
c.instance.SetTTL(config.GetCacheGTSInstanceTTL(), true)
|
||||
c.emojiCategory.IgnoreErrors(ignoreErrors)
|
||||
}
|
||||
|
||||
func (c *GTSCaches) initList() {
|
||||
c.list = result.New([]result.Lookup{
|
||||
{Name: "ID"},
|
||||
|
|
|
@ -200,6 +200,10 @@ type GTSCacheConfiguration struct {
|
|||
FollowRequestTTL time.Duration `name:"follow-request-ttl"`
|
||||
FollowRequestSweepFreq time.Duration `name:"follow-request-sweep-freq"`
|
||||
|
||||
InstanceMaxSize int `name:"instance-max-size"`
|
||||
InstanceTTL time.Duration `name:"instance-ttl"`
|
||||
InstanceSweepFreq time.Duration `name:"instance-sweep-freq"`
|
||||
|
||||
ListMaxSize int `name:"list-max-size"`
|
||||
ListTTL time.Duration `name:"list-ttl"`
|
||||
ListSweepFreq time.Duration `name:"list-sweep-freq"`
|
||||
|
|
|
@ -154,6 +154,10 @@
|
|||
FollowRequestTTL: time.Minute * 30,
|
||||
FollowRequestSweepFreq: time.Minute,
|
||||
|
||||
InstanceMaxSize: 2000,
|
||||
InstanceTTL: time.Minute * 30,
|
||||
InstanceSweepFreq: time.Minute,
|
||||
|
||||
ListMaxSize: 2000,
|
||||
ListTTL: time.Minute * 30,
|
||||
ListSweepFreq: time.Minute,
|
||||
|
|
|
@ -2828,6 +2828,81 @@ func GetCacheGTSFollowRequestSweepFreq() time.Duration {
|
|||
// SetCacheGTSFollowRequestSweepFreq safely sets the value for global configuration 'Cache.GTS.FollowRequestSweepFreq' field
|
||||
func SetCacheGTSFollowRequestSweepFreq(v time.Duration) { global.SetCacheGTSFollowRequestSweepFreq(v) }
|
||||
|
||||
// GetCacheGTSInstanceMaxSize safely fetches the Configuration value for state's 'Cache.GTS.InstanceMaxSize' field
|
||||
func (st *ConfigState) GetCacheGTSInstanceMaxSize() (v int) {
|
||||
st.mutex.Lock()
|
||||
v = st.config.Cache.GTS.InstanceMaxSize
|
||||
st.mutex.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// SetCacheGTSInstanceMaxSize safely sets the Configuration value for state's 'Cache.GTS.InstanceMaxSize' field
|
||||
func (st *ConfigState) SetCacheGTSInstanceMaxSize(v int) {
|
||||
st.mutex.Lock()
|
||||
defer st.mutex.Unlock()
|
||||
st.config.Cache.GTS.InstanceMaxSize = v
|
||||
st.reloadToViper()
|
||||
}
|
||||
|
||||
// CacheGTSInstanceMaxSizeFlag returns the flag name for the 'Cache.GTS.InstanceMaxSize' field
|
||||
func CacheGTSInstanceMaxSizeFlag() string { return "cache-gts-instance-max-size" }
|
||||
|
||||
// GetCacheGTSInstanceMaxSize safely fetches the value for global configuration 'Cache.GTS.InstanceMaxSize' field
|
||||
func GetCacheGTSInstanceMaxSize() int { return global.GetCacheGTSInstanceMaxSize() }
|
||||
|
||||
// SetCacheGTSInstanceMaxSize safely sets the value for global configuration 'Cache.GTS.InstanceMaxSize' field
|
||||
func SetCacheGTSInstanceMaxSize(v int) { global.SetCacheGTSInstanceMaxSize(v) }
|
||||
|
||||
// GetCacheGTSInstanceTTL safely fetches the Configuration value for state's 'Cache.GTS.InstanceTTL' field
|
||||
func (st *ConfigState) GetCacheGTSInstanceTTL() (v time.Duration) {
|
||||
st.mutex.Lock()
|
||||
v = st.config.Cache.GTS.InstanceTTL
|
||||
st.mutex.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// SetCacheGTSInstanceTTL safely sets the Configuration value for state's 'Cache.GTS.InstanceTTL' field
|
||||
func (st *ConfigState) SetCacheGTSInstanceTTL(v time.Duration) {
|
||||
st.mutex.Lock()
|
||||
defer st.mutex.Unlock()
|
||||
st.config.Cache.GTS.InstanceTTL = v
|
||||
st.reloadToViper()
|
||||
}
|
||||
|
||||
// CacheGTSInstanceTTLFlag returns the flag name for the 'Cache.GTS.InstanceTTL' field
|
||||
func CacheGTSInstanceTTLFlag() string { return "cache-gts-instance-ttl" }
|
||||
|
||||
// GetCacheGTSInstanceTTL safely fetches the value for global configuration 'Cache.GTS.InstanceTTL' field
|
||||
func GetCacheGTSInstanceTTL() time.Duration { return global.GetCacheGTSInstanceTTL() }
|
||||
|
||||
// SetCacheGTSInstanceTTL safely sets the value for global configuration 'Cache.GTS.InstanceTTL' field
|
||||
func SetCacheGTSInstanceTTL(v time.Duration) { global.SetCacheGTSInstanceTTL(v) }
|
||||
|
||||
// GetCacheGTSInstanceSweepFreq safely fetches the Configuration value for state's 'Cache.GTS.InstanceSweepFreq' field
|
||||
func (st *ConfigState) GetCacheGTSInstanceSweepFreq() (v time.Duration) {
|
||||
st.mutex.Lock()
|
||||
v = st.config.Cache.GTS.InstanceSweepFreq
|
||||
st.mutex.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// SetCacheGTSInstanceSweepFreq safely sets the Configuration value for state's 'Cache.GTS.InstanceSweepFreq' field
|
||||
func (st *ConfigState) SetCacheGTSInstanceSweepFreq(v time.Duration) {
|
||||
st.mutex.Lock()
|
||||
defer st.mutex.Unlock()
|
||||
st.config.Cache.GTS.InstanceSweepFreq = v
|
||||
st.reloadToViper()
|
||||
}
|
||||
|
||||
// CacheGTSInstanceSweepFreqFlag returns the flag name for the 'Cache.GTS.InstanceSweepFreq' field
|
||||
func CacheGTSInstanceSweepFreqFlag() string { return "cache-gts-instance-sweep-freq" }
|
||||
|
||||
// GetCacheGTSInstanceSweepFreq safely fetches the value for global configuration 'Cache.GTS.InstanceSweepFreq' field
|
||||
func GetCacheGTSInstanceSweepFreq() time.Duration { return global.GetCacheGTSInstanceSweepFreq() }
|
||||
|
||||
// SetCacheGTSInstanceSweepFreq safely sets the value for global configuration 'Cache.GTS.InstanceSweepFreq' field
|
||||
func SetCacheGTSInstanceSweepFreq(v time.Duration) { global.SetCacheGTSInstanceSweepFreq(v) }
|
||||
|
||||
// GetCacheGTSListMaxSize safely fetches the Configuration value for state's 'Cache.GTS.ListMaxSize' field
|
||||
func (st *ConfigState) GetCacheGTSListMaxSize() (v int) {
|
||||
st.mutex.Lock()
|
||||
|
|
|
@ -179,7 +179,8 @@ func NewBunDBService(ctx context.Context, state *state.State) (db.DB, error) {
|
|||
state: state,
|
||||
},
|
||||
Instance: &instanceDB{
|
||||
conn: conn,
|
||||
conn: conn,
|
||||
state: state,
|
||||
},
|
||||
List: &listDB{
|
||||
conn: conn,
|
||||
|
|
|
@ -42,7 +42,7 @@ func (d *domainDB) CreateDomainBlock(ctx context.Context, block *gtsmodel.Domain
|
|||
return err
|
||||
}
|
||||
|
||||
// Attempt to store domain in DB
|
||||
// Attempt to store domain block in DB
|
||||
if _, err := d.conn.NewInsert().
|
||||
Model(block).
|
||||
Exec(ctx); err != nil {
|
||||
|
@ -82,6 +82,33 @@ func (d *domainDB) GetDomainBlock(ctx context.Context, domain string) (*gtsmodel
|
|||
return &block, nil
|
||||
}
|
||||
|
||||
func (d *domainDB) GetDomainBlocks(ctx context.Context) ([]*gtsmodel.DomainBlock, error) {
|
||||
blocks := []*gtsmodel.DomainBlock{}
|
||||
|
||||
if err := d.conn.
|
||||
NewSelect().
|
||||
Model(&blocks).
|
||||
Scan(ctx); err != nil {
|
||||
return nil, d.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
return blocks, nil
|
||||
}
|
||||
|
||||
func (d *domainDB) GetDomainBlockByID(ctx context.Context, id string) (*gtsmodel.DomainBlock, db.Error) {
|
||||
var block gtsmodel.DomainBlock
|
||||
|
||||
q := d.conn.
|
||||
NewSelect().
|
||||
Model(&block).
|
||||
Where("? = ?", bun.Ident("domain_block.id"), id)
|
||||
if err := q.Scan(ctx); err != nil {
|
||||
return nil, d.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
return &block, nil
|
||||
}
|
||||
|
||||
func (d *domainDB) DeleteDomainBlock(ctx context.Context, domain string) db.Error {
|
||||
// Normalize the domain as punycode
|
||||
domain, err := util.Punify(domain)
|
||||
|
|
|
@ -19,15 +19,23 @@
|
|||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/superseriousbusiness/gotosocial/internal/config"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/db"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/id"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/log"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/state"
|
||||
"github.com/superseriousbusiness/gotosocial/internal/util"
|
||||
"github.com/uptrace/bun"
|
||||
)
|
||||
|
||||
type instanceDB struct {
|
||||
conn *DBConn
|
||||
conn *DBConn
|
||||
state *state.State
|
||||
}
|
||||
|
||||
func (i *instanceDB) CountInstanceUsers(ctx context.Context, domain string) (int, db.Error) {
|
||||
|
@ -99,62 +107,236 @@ func (i *instanceDB) CountInstanceDomains(ctx context.Context, domain string) (i
|
|||
}
|
||||
|
||||
func (i *instanceDB) GetInstance(ctx context.Context, domain string) (*gtsmodel.Instance, db.Error) {
|
||||
instance := >smodel.Instance{}
|
||||
// Normalize the domain as punycode
|
||||
var err error
|
||||
domain, err = util.Punify(domain)
|
||||
if err != nil {
|
||||
return nil, gtserror.Newf("error punifying domain %s: %w", domain, err)
|
||||
}
|
||||
|
||||
if err := i.conn.
|
||||
NewSelect().
|
||||
Model(instance).
|
||||
Where("? = ?", bun.Ident("instance.domain"), domain).
|
||||
Scan(ctx); err != nil {
|
||||
return nil, i.conn.ProcessError(err)
|
||||
return i.getInstance(
|
||||
ctx,
|
||||
"Domain",
|
||||
func(instance *gtsmodel.Instance) error {
|
||||
return i.conn.NewSelect().
|
||||
Model(instance).
|
||||
Where("? = ?", bun.Ident("instance.domain"), domain).
|
||||
Scan(ctx)
|
||||
},
|
||||
domain,
|
||||
)
|
||||
}
|
||||
|
||||
func (i *instanceDB) GetInstanceByID(ctx context.Context, id string) (*gtsmodel.Instance, error) {
|
||||
return i.getInstance(
|
||||
ctx,
|
||||
"ID",
|
||||
func(instance *gtsmodel.Instance) error {
|
||||
return i.conn.NewSelect().
|
||||
Model(instance).
|
||||
Where("? = ?", bun.Ident("instance.id"), id).
|
||||
Scan(ctx)
|
||||
},
|
||||
id,
|
||||
)
|
||||
}
|
||||
|
||||
func (i *instanceDB) getInstance(ctx context.Context, lookup string, dbQuery func(*gtsmodel.Instance) error, keyParts ...any) (*gtsmodel.Instance, db.Error) {
|
||||
// Fetch instance from database cache with loader callback
|
||||
instance, err := i.state.Caches.GTS.Instance().Load(lookup, func() (*gtsmodel.Instance, error) {
|
||||
var instance gtsmodel.Instance
|
||||
|
||||
// Not cached! Perform database query.
|
||||
if err := dbQuery(&instance); err != nil {
|
||||
return nil, i.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
return &instance, nil
|
||||
}, keyParts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if gtscontext.Barebones(ctx) {
|
||||
// no need to fully populate.
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
// Further populate the instance fields where applicable.
|
||||
if err := i.populateInstance(ctx, instance); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
func (i *instanceDB) populateInstance(ctx context.Context, instance *gtsmodel.Instance) error {
|
||||
var (
|
||||
err error
|
||||
errs = make(gtserror.MultiError, 0, 2)
|
||||
)
|
||||
|
||||
if instance.DomainBlockID != "" && instance.DomainBlock == nil {
|
||||
// Instance domain block is not set, fetch from database.
|
||||
instance.DomainBlock, err = i.state.DB.GetDomainBlock(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
instance.Domain,
|
||||
)
|
||||
if err != nil {
|
||||
errs.Append(gtserror.Newf("error populating instance domain block: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if instance.ContactAccountID != "" && instance.ContactAccount == nil {
|
||||
// Instance domain block is not set, fetch from database.
|
||||
instance.ContactAccount, err = i.state.DB.GetAccountByID(
|
||||
gtscontext.SetBarebones(ctx),
|
||||
instance.ContactAccountID,
|
||||
)
|
||||
if err != nil {
|
||||
errs.Append(gtserror.Newf("error populating instance contact account: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return errs.Combine()
|
||||
}
|
||||
|
||||
func (i *instanceDB) PutInstance(ctx context.Context, instance *gtsmodel.Instance) error {
|
||||
// Normalize the domain as punycode
|
||||
var err error
|
||||
instance.Domain, err = util.Punify(instance.Domain)
|
||||
if err != nil {
|
||||
return gtserror.Newf("error punifying domain %s: %w", instance.Domain, err)
|
||||
}
|
||||
|
||||
return i.state.Caches.GTS.Instance().Store(instance, func() error {
|
||||
_, err := i.conn.NewInsert().Model(instance).Exec(ctx)
|
||||
return i.conn.ProcessError(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (i *instanceDB) UpdateInstance(ctx context.Context, instance *gtsmodel.Instance, columns ...string) error {
|
||||
// Normalize the domain as punycode
|
||||
var err error
|
||||
instance.Domain, err = util.Punify(instance.Domain)
|
||||
if err != nil {
|
||||
return gtserror.Newf("error punifying domain %s: %w", instance.Domain, err)
|
||||
}
|
||||
|
||||
// Update the instance's last-updated
|
||||
instance.UpdatedAt = time.Now()
|
||||
if len(columns) != 0 {
|
||||
columns = append(columns, "updated_at")
|
||||
}
|
||||
|
||||
return i.state.Caches.GTS.Instance().Store(instance, func() error {
|
||||
_, err := i.conn.
|
||||
NewUpdate().
|
||||
Model(instance).
|
||||
Where("? = ?", bun.Ident("instance.id"), instance.ID).
|
||||
Column(columns...).
|
||||
Exec(ctx)
|
||||
return i.conn.ProcessError(err)
|
||||
})
|
||||
}
|
||||
|
||||
func (i *instanceDB) GetInstancePeers(ctx context.Context, includeSuspended bool) ([]*gtsmodel.Instance, db.Error) {
|
||||
instances := []*gtsmodel.Instance{}
|
||||
instanceIDs := []string{}
|
||||
|
||||
q := i.conn.
|
||||
NewSelect().
|
||||
Model(&instances).
|
||||
TableExpr("? AS ?", bun.Ident("instances"), bun.Ident("instance")).
|
||||
// Select just the IDs of each instance.
|
||||
Column("instance.id").
|
||||
// Exclude our own instance.
|
||||
Where("? != ?", bun.Ident("instance.domain"), config.GetHost())
|
||||
|
||||
if !includeSuspended {
|
||||
q = q.Where("? IS NULL", bun.Ident("instance.suspended_at"))
|
||||
}
|
||||
|
||||
if err := q.Scan(ctx); err != nil {
|
||||
if err := q.Scan(ctx, &instanceIDs); err != nil {
|
||||
return nil, i.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
if len(instanceIDs) == 0 {
|
||||
return make([]*gtsmodel.Instance, 0), nil
|
||||
}
|
||||
|
||||
instances := make([]*gtsmodel.Instance, 0, len(instanceIDs))
|
||||
|
||||
for _, id := range instanceIDs {
|
||||
// Select each instance by its ID.
|
||||
instance, err := i.GetInstanceByID(ctx, id)
|
||||
if err != nil {
|
||||
log.Errorf(ctx, "error getting instance %q: %v", id, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Append to return slice.
|
||||
instances = append(instances, instance)
|
||||
}
|
||||
|
||||
return instances, nil
|
||||
}
|
||||
|
||||
func (i *instanceDB) GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
|
||||
accounts := []*gtsmodel.Account{}
|
||||
// Ensure reasonable
|
||||
if limit < 0 {
|
||||
limit = 0
|
||||
}
|
||||
|
||||
q := i.conn.NewSelect().
|
||||
Model(&accounts).
|
||||
// Normalize the domain as punycode.
|
||||
var err error
|
||||
domain, err = util.Punify(domain)
|
||||
if err != nil {
|
||||
return nil, gtserror.Newf("error punifying domain %s: %w", domain, err)
|
||||
}
|
||||
|
||||
// Make educated guess for slice size
|
||||
accountIDs := make([]string, 0, limit)
|
||||
|
||||
q := i.conn.
|
||||
NewSelect().
|
||||
TableExpr("? AS ?", bun.Ident("accounts"), bun.Ident("account")).
|
||||
// Select just the account ID.
|
||||
Column("account.id").
|
||||
// Select accounts belonging to given domain.
|
||||
Where("? = ?", bun.Ident("account.domain"), domain).
|
||||
Order("account.id DESC")
|
||||
|
||||
if maxID != "" {
|
||||
q = q.Where("? < ?", bun.Ident("account.id"), maxID)
|
||||
if maxID == "" {
|
||||
maxID = id.Highest
|
||||
}
|
||||
q = q.Where("? < ?", bun.Ident("account.id"), maxID)
|
||||
|
||||
if limit > 0 {
|
||||
q = q.Limit(limit)
|
||||
}
|
||||
|
||||
if err := q.Scan(ctx); err != nil {
|
||||
if err := q.Scan(ctx, &accountIDs); err != nil {
|
||||
return nil, i.conn.ProcessError(err)
|
||||
}
|
||||
|
||||
if len(accounts) == 0 {
|
||||
// Catch case of no accounts early.
|
||||
count := len(accountIDs)
|
||||
if count == 0 {
|
||||
return nil, db.ErrNoEntries
|
||||
}
|
||||
|
||||
// Select each account by its ID.
|
||||
accounts := make([]*gtsmodel.Account, 0, count)
|
||||
for _, id := range accountIDs {
|
||||
account, err := i.state.DB.GetAccountByID(ctx, id)
|
||||
if err != nil {
|
||||
log.Errorf(ctx, "error getting account %q: %v", id, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Append to return slice.
|
||||
accounts = append(accounts, account)
|
||||
}
|
||||
|
||||
return accounts, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -26,13 +26,19 @@
|
|||
|
||||
// Domain contains DB functions related to domains and domain blocks.
|
||||
type Domain interface {
|
||||
// CreateDomainBlock ...
|
||||
// CreateDomainBlock puts the given instance-level domain block into the database.
|
||||
CreateDomainBlock(ctx context.Context, block *gtsmodel.DomainBlock) Error
|
||||
|
||||
// GetDomainBlock ...
|
||||
// GetDomainBlock returns one instance-level domain block with the given domain, if it exists.
|
||||
GetDomainBlock(ctx context.Context, domain string) (*gtsmodel.DomainBlock, Error)
|
||||
|
||||
// DeleteDomainBlock ...
|
||||
// GetDomainBlockByID returns one instance-level domain block with the given id, if it exists.
|
||||
GetDomainBlockByID(ctx context.Context, id string) (*gtsmodel.DomainBlock, Error)
|
||||
|
||||
// GetDomainBlocks returns all instance-level domain blocks currently enforced by this instance.
|
||||
GetDomainBlocks(ctx context.Context) ([]*gtsmodel.DomainBlock, error)
|
||||
|
||||
// DeleteDomainBlock deletes an instance-level domain block with the given domain, if it exists.
|
||||
DeleteDomainBlock(ctx context.Context, domain string) Error
|
||||
|
||||
// IsDomainBlocked checks if an instance-level domain block exists for the given domain string (eg., `example.org`).
|
||||
|
|
|
@ -37,6 +37,15 @@ type Instance interface {
|
|||
// GetInstance returns the instance entry for the given domain, if it exists.
|
||||
GetInstance(ctx context.Context, domain string) (*gtsmodel.Instance, Error)
|
||||
|
||||
// GetInstanceByID returns the instance entry corresponding to the given id, if it exists.
|
||||
GetInstanceByID(ctx context.Context, id string) (*gtsmodel.Instance, error)
|
||||
|
||||
// PutInstance inserts the given instance into the database.
|
||||
PutInstance(ctx context.Context, instance *gtsmodel.Instance) error
|
||||
|
||||
// UpdateInstance updates the given instance entry.
|
||||
UpdateInstance(ctx context.Context, instance *gtsmodel.Instance, columns ...string) error
|
||||
|
||||
// GetInstanceAccounts returns a slice of accounts from the given instance, arranged by ID.
|
||||
GetInstanceAccounts(ctx context.Context, domain string, maxID string, limit int) ([]*gtsmodel.Account, Error)
|
||||
|
||||
|
|
|
@ -256,7 +256,7 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr
|
|||
return nil, false, err
|
||||
}
|
||||
|
||||
if err := f.db.Put(ctx, instance); err != nil && !errors.Is(err, db.ErrAlreadyExists) {
|
||||
if err := f.db.PutInstance(ctx, instance); err != nil && !errors.Is(err, db.ErrAlreadyExists) {
|
||||
err = gtserror.Newf("error inserting instance entry for %s: %w", pubKeyOwner.Host, err)
|
||||
return nil, false, err
|
||||
}
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"strings"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"codeberg.org/gruf/go-kv"
|
||||
|
@ -40,20 +40,30 @@
|
|||
"github.com/superseriousbusiness/gotosocial/internal/text"
|
||||
)
|
||||
|
||||
func (p *Processor) DomainBlockCreate(ctx context.Context, account *gtsmodel.Account, domain string, obfuscate bool, publicComment string, privateComment string, subscriptionID string) (*apimodel.DomainBlock, gtserror.WithCode) {
|
||||
// domain blocks will always be lowercase
|
||||
domain = strings.ToLower(domain)
|
||||
// DomainBlockCreate creates an instance-level block against the given domain,
|
||||
// and then processes side effects of that block (deleting accounts, media, etc).
|
||||
//
|
||||
// If a domain block already exists for the domain, side effects will be retried.
|
||||
func (p *Processor) DomainBlockCreate(
|
||||
ctx context.Context,
|
||||
account *gtsmodel.Account,
|
||||
domain string,
|
||||
obfuscate bool,
|
||||
publicComment string,
|
||||
privateComment string,
|
||||
subscriptionID string,
|
||||
) (*apimodel.DomainBlock, gtserror.WithCode) {
|
||||
// Check if a block already exists for this domain.
|
||||
domainBlock, err := p.state.DB.GetDomainBlock(ctx, domain)
|
||||
if err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
// Something went wrong in the DB.
|
||||
err = gtserror.Newf("db error getting domain block %s: %w", domain, err)
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
// first check if we already have a block -- if err == nil we already had a block so we can skip a whole lot of work
|
||||
block, err := p.state.DB.GetDomainBlock(ctx, domain)
|
||||
if err != nil {
|
||||
if !errors.Is(err, db.ErrNoEntries) {
|
||||
// something went wrong in the DB
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error checking for existence of domain block %s: %s", domain, err))
|
||||
}
|
||||
|
||||
// there's no block for this domain yet so create one
|
||||
newBlock := >smodel.DomainBlock{
|
||||
if domainBlock == nil {
|
||||
// No block exists yet, create it.
|
||||
domainBlock = >smodel.DomainBlock{
|
||||
ID: id.NewULID(),
|
||||
Domain: domain,
|
||||
CreatedByAccountID: account.ID,
|
||||
|
@ -63,249 +73,408 @@ func (p *Processor) DomainBlockCreate(ctx context.Context, account *gtsmodel.Acc
|
|||
SubscriptionID: subscriptionID,
|
||||
}
|
||||
|
||||
// Insert the new block into the database
|
||||
if err := p.state.DB.CreateDomainBlock(ctx, newBlock); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error putting new domain block %s: %s", domain, err))
|
||||
// Insert the new block into the database.
|
||||
if err := p.state.DB.CreateDomainBlock(ctx, domainBlock); err != nil {
|
||||
err = gtserror.Newf("db error putting domain block %s: %s", domain, err)
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
// Set the newly created block
|
||||
block = newBlock
|
||||
|
||||
// Process the side effects of the domain block asynchronously since it might take a while
|
||||
go func() {
|
||||
p.initiateDomainBlockSideEffects(context.Background(), account, block)
|
||||
}()
|
||||
}
|
||||
|
||||
// Convert our gts model domain block into an API model
|
||||
apiDomainBlock, err := p.tc.DomainBlockToAPIDomainBlock(ctx, block, false)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting domain block to frontend/api representation %s: %s", domain, err))
|
||||
}
|
||||
// Process the side effects of the domain block
|
||||
// asynchronously since it might take a while.
|
||||
p.state.Workers.ClientAPI.Enqueue(func(ctx context.Context) {
|
||||
p.domainBlockSideEffects(ctx, account, domainBlock)
|
||||
})
|
||||
|
||||
return apiDomainBlock, nil
|
||||
return p.apiDomainBlock(ctx, domainBlock)
|
||||
}
|
||||
|
||||
// initiateDomainBlockSideEffects should be called asynchronously, to process the side effects of a domain block:
|
||||
// DomainBlocksImport handles the import of multiple domain blocks,
|
||||
// by calling the DomainBlockCreate function for each domain in the
|
||||
// provided file. Will return a slice of processed domain blocks.
|
||||
//
|
||||
// 1. Strip most info away from the instance entry for the domain.
|
||||
// 2. Delete the instance account for that instance if it exists.
|
||||
// 3. Select all accounts from this instance and pass them through the delete functionality of the processor.
|
||||
func (p *Processor) initiateDomainBlockSideEffects(ctx context.Context, account *gtsmodel.Account, block *gtsmodel.DomainBlock) {
|
||||
l := log.WithContext(ctx).WithFields(kv.Fields{{"domain", block.Domain}}...)
|
||||
l.Debug("processing domain block side effects")
|
||||
|
||||
// if we have an instance entry for this domain, update it with the new block ID and clear all fields
|
||||
instance := >smodel.Instance{}
|
||||
if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "domain", Value: block.Domain}}, instance); err == nil {
|
||||
updatingColumns := []string{
|
||||
"title",
|
||||
"updated_at",
|
||||
"suspended_at",
|
||||
"domain_block_id",
|
||||
"short_description",
|
||||
"description",
|
||||
"terms",
|
||||
"contact_email",
|
||||
"contact_account_username",
|
||||
"contact_account_id",
|
||||
"version",
|
||||
}
|
||||
instance.Title = ""
|
||||
instance.UpdatedAt = time.Now()
|
||||
instance.SuspendedAt = time.Now()
|
||||
instance.DomainBlockID = block.ID
|
||||
instance.ShortDescription = ""
|
||||
instance.Description = ""
|
||||
instance.Terms = ""
|
||||
instance.ContactEmail = ""
|
||||
instance.ContactAccountUsername = ""
|
||||
instance.ContactAccountID = ""
|
||||
instance.Version = ""
|
||||
if err := p.state.DB.UpdateByID(ctx, instance, instance.ID, updatingColumns...); err != nil {
|
||||
l.Errorf("domainBlockProcessSideEffects: db error updating instance: %s", err)
|
||||
}
|
||||
l.Debug("domainBlockProcessSideEffects: instance entry updated")
|
||||
}
|
||||
|
||||
// if we have an instance account for this instance, delete it
|
||||
if instanceAccount, err := p.state.DB.GetAccountByUsernameDomain(ctx, block.Domain, block.Domain); err == nil {
|
||||
if err := p.state.DB.DeleteAccount(ctx, instanceAccount.ID); err != nil {
|
||||
l.Errorf("domainBlockProcessSideEffects: db error deleting instance account: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// delete accounts through the normal account deletion system (which should also delete media + posts + remove posts from timelines)
|
||||
|
||||
limit := 20 // just select 20 accounts at a time so we don't nuke our DB/mem with one huge query
|
||||
var maxID string // this is initially an empty string so we'll start at the top of accounts list (sorted by ID)
|
||||
|
||||
selectAccountsLoop:
|
||||
for {
|
||||
accounts, err := p.state.DB.GetInstanceAccounts(ctx, block.Domain, maxID, limit)
|
||||
if err != nil {
|
||||
if err == db.ErrNoEntries {
|
||||
// no accounts left for this instance so we're done
|
||||
l.Infof("domainBlockProcessSideEffects: done iterating through accounts for domain %s", block.Domain)
|
||||
break selectAccountsLoop
|
||||
}
|
||||
// an actual error has occurred
|
||||
l.Errorf("domainBlockProcessSideEffects: db error selecting accounts for domain %s: %s", block.Domain, err)
|
||||
break selectAccountsLoop
|
||||
}
|
||||
|
||||
for i, a := range accounts {
|
||||
l.Debugf("putting delete for account %s in the clientAPI channel", a.Username)
|
||||
|
||||
// pass the account delete through the client api channel for processing
|
||||
p.state.Workers.EnqueueClientAPI(ctx, messages.FromClientAPI{
|
||||
APObjectType: ap.ActorPerson,
|
||||
APActivityType: ap.ActivityDelete,
|
||||
GTSModel: block,
|
||||
OriginAccount: account,
|
||||
TargetAccount: a,
|
||||
})
|
||||
|
||||
// if this is the last account in the slice, set the maxID appropriately for the next query
|
||||
if i == len(accounts)-1 {
|
||||
maxID = a.ID
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DomainBlocksImport handles the import of a bunch of domain blocks at once, by calling the DomainBlockCreate function for each domain in the provided file.
|
||||
func (p *Processor) DomainBlocksImport(ctx context.Context, account *gtsmodel.Account, domains *multipart.FileHeader) ([]*apimodel.DomainBlock, gtserror.WithCode) {
|
||||
f, err := domains.Open()
|
||||
// In the case of total failure, a gtserror.WithCode will be returned
|
||||
// so that the caller can respond appropriately. In the case of
|
||||
// partial or total success, a MultiStatus model will be returned,
|
||||
// which contains information about success/failure count, so that
|
||||
// the caller can retry any failures as they wish.
|
||||
func (p *Processor) DomainBlocksImport(
|
||||
ctx context.Context,
|
||||
account *gtsmodel.Account,
|
||||
domainsF *multipart.FileHeader,
|
||||
) (*apimodel.MultiStatus, gtserror.WithCode) {
|
||||
// Open the provided file.
|
||||
file, err := domainsF.Open()
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorBadRequest(fmt.Errorf("DomainBlocksImport: error opening attachment: %s", err))
|
||||
err = gtserror.Newf("error opening attachment: %w", err)
|
||||
return nil, gtserror.NewErrorBadRequest(err, err.Error())
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Copy the file contents into a buffer.
|
||||
buf := new(bytes.Buffer)
|
||||
size, err := io.Copy(buf, f)
|
||||
size, err := io.Copy(buf, file)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorBadRequest(fmt.Errorf("DomainBlocksImport: error reading attachment: %s", err))
|
||||
err = gtserror.Newf("error reading attachment: %w", err)
|
||||
return nil, gtserror.NewErrorBadRequest(err, err.Error())
|
||||
}
|
||||
|
||||
// Ensure we actually read something.
|
||||
if size == 0 {
|
||||
return nil, gtserror.NewErrorBadRequest(errors.New("DomainBlocksImport: could not read provided attachment: size 0 bytes"))
|
||||
err = gtserror.New("error reading attachment: size 0 bytes")
|
||||
return nil, gtserror.NewErrorBadRequest(err, err.Error())
|
||||
}
|
||||
|
||||
d := []apimodel.DomainBlock{}
|
||||
if err := json.Unmarshal(buf.Bytes(), &d); err != nil {
|
||||
return nil, gtserror.NewErrorBadRequest(fmt.Errorf("DomainBlocksImport: could not read provided attachment: %s", err))
|
||||
// Parse bytes as slice of domain blocks.
|
||||
domainBlocks := make([]*apimodel.DomainBlock, 0)
|
||||
if err := json.Unmarshal(buf.Bytes(), &domainBlocks); err != nil {
|
||||
err = gtserror.Newf("error parsing attachment as domain blocks: %w", err)
|
||||
return nil, gtserror.NewErrorBadRequest(err, err.Error())
|
||||
}
|
||||
|
||||
blocks := []*apimodel.DomainBlock{}
|
||||
for _, d := range d {
|
||||
block, err := p.DomainBlockCreate(ctx, account, d.Domain.Domain, false, d.PublicComment, "", "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
count := len(domainBlocks)
|
||||
if count == 0 {
|
||||
err = gtserror.New("error importing domain blocks: 0 entries provided")
|
||||
return nil, gtserror.NewErrorBadRequest(err, err.Error())
|
||||
}
|
||||
|
||||
// Try to process each domain block, differentiating
|
||||
// between successes and errors so that the caller can
|
||||
// try failed imports again if desired.
|
||||
multiStatusEntries := make([]apimodel.MultiStatusEntry, 0, count)
|
||||
|
||||
for _, domainBlock := range domainBlocks {
|
||||
var (
|
||||
domain = domainBlock.Domain.Domain
|
||||
obfuscate = domainBlock.Obfuscate
|
||||
publicComment = domainBlock.PublicComment
|
||||
privateComment = domainBlock.PrivateComment
|
||||
subscriptionID = "" // No sub ID for imports.
|
||||
errWithCode gtserror.WithCode
|
||||
)
|
||||
|
||||
domainBlock, errWithCode = p.DomainBlockCreate(
|
||||
ctx,
|
||||
account,
|
||||
domain,
|
||||
obfuscate,
|
||||
publicComment,
|
||||
privateComment,
|
||||
subscriptionID,
|
||||
)
|
||||
|
||||
var entry *apimodel.MultiStatusEntry
|
||||
|
||||
if errWithCode != nil {
|
||||
entry = &apimodel.MultiStatusEntry{
|
||||
// Use the failed domain entry as the resource value.
|
||||
Resource: domain,
|
||||
Message: errWithCode.Safe(),
|
||||
Status: errWithCode.Code(),
|
||||
}
|
||||
} else {
|
||||
entry = &apimodel.MultiStatusEntry{
|
||||
// Use successfully created API model domain block as the resource value.
|
||||
Resource: domainBlock,
|
||||
Message: http.StatusText(http.StatusOK),
|
||||
Status: http.StatusOK,
|
||||
}
|
||||
}
|
||||
|
||||
blocks = append(blocks, block)
|
||||
multiStatusEntries = append(multiStatusEntries, *entry)
|
||||
}
|
||||
|
||||
return blocks, nil
|
||||
return apimodel.NewMultiStatus(multiStatusEntries), nil
|
||||
}
|
||||
|
||||
// DomainBlocksGet returns all existing domain blocks.
|
||||
// If export is true, the format will be suitable for writing out to an export.
|
||||
// DomainBlocksGet returns all existing domain blocks. If export is
|
||||
// true, the format will be suitable for writing out to an export.
|
||||
func (p *Processor) DomainBlocksGet(ctx context.Context, account *gtsmodel.Account, export bool) ([]*apimodel.DomainBlock, gtserror.WithCode) {
|
||||
domainBlocks := []*gtsmodel.DomainBlock{}
|
||||
|
||||
if err := p.state.DB.GetAll(ctx, &domainBlocks); err != nil {
|
||||
if !errors.Is(err, db.ErrNoEntries) {
|
||||
// something has gone really wrong
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
domainBlocks, err := p.state.DB.GetDomainBlocks(ctx)
|
||||
if err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
err = gtserror.Newf("db error getting domain blocks: %w", err)
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
apiDomainBlocks := []*apimodel.DomainBlock{}
|
||||
for _, b := range domainBlocks {
|
||||
apiDomainBlock, err := p.tc.DomainBlockToAPIDomainBlock(ctx, b, export)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
apiDomainBlocks := make([]*apimodel.DomainBlock, 0, len(domainBlocks))
|
||||
for _, domainBlock := range domainBlocks {
|
||||
apiDomainBlock, errWithCode := p.apiDomainBlock(ctx, domainBlock)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
apiDomainBlocks = append(apiDomainBlocks, apiDomainBlock)
|
||||
}
|
||||
|
||||
return apiDomainBlocks, nil
|
||||
}
|
||||
|
||||
// DomainBlockGet returns one domain block with the given id.
|
||||
// If export is true, the format will be suitable for writing out to an export.
|
||||
func (p *Processor) DomainBlockGet(ctx context.Context, account *gtsmodel.Account, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) {
|
||||
domainBlock := >smodel.DomainBlock{}
|
||||
|
||||
if err := p.state.DB.GetByID(ctx, id, domainBlock); err != nil {
|
||||
if !errors.Is(err, db.ErrNoEntries) {
|
||||
// something has gone really wrong
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
// there are no entries for this ID
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("no entry for ID %s", id))
|
||||
}
|
||||
|
||||
apiDomainBlock, err := p.tc.DomainBlockToAPIDomainBlock(ctx, domainBlock, export)
|
||||
// DomainBlockGet returns one domain block with the given id. If export
|
||||
// is true, the format will be suitable for writing out to an export.
|
||||
func (p *Processor) DomainBlockGet(ctx context.Context, id string, export bool) (*apimodel.DomainBlock, gtserror.WithCode) {
|
||||
domainBlock, err := p.state.DB.GetDomainBlockByID(ctx, id)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNoEntries) {
|
||||
err = fmt.Errorf("no domain block exists with id %s", id)
|
||||
return nil, gtserror.NewErrorNotFound(err, err.Error())
|
||||
}
|
||||
|
||||
// Something went wrong in the DB.
|
||||
err = gtserror.Newf("db error getting domain block %s: %w", id, err)
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
return p.apiDomainBlock(ctx, domainBlock)
|
||||
}
|
||||
|
||||
// DomainBlockDelete removes one domain block with the given ID,
|
||||
// and processes side effects of removing the block asynchronously.
|
||||
func (p *Processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.DomainBlock, gtserror.WithCode) {
|
||||
domainBlock, err := p.state.DB.GetDomainBlockByID(ctx, id)
|
||||
if err != nil {
|
||||
if !errors.Is(err, db.ErrNoEntries) {
|
||||
// Real error.
|
||||
err = gtserror.Newf("db error getting domain block: %w", err)
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
// There are just no entries for this ID.
|
||||
err = fmt.Errorf("no domain block entry exists with ID %s", id)
|
||||
return nil, gtserror.NewErrorNotFound(err, err.Error())
|
||||
}
|
||||
|
||||
// Prepare the domain block to return, *before* the deletion goes through.
|
||||
apiDomainBlock, errWithCode := p.apiDomainBlock(ctx, domainBlock)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
// Copy value of the domain block.
|
||||
domainBlockC := new(gtsmodel.DomainBlock)
|
||||
*domainBlockC = *domainBlock
|
||||
|
||||
// Delete the original domain block.
|
||||
if err := p.state.DB.DeleteDomainBlock(ctx, domainBlock.Domain); err != nil {
|
||||
err = gtserror.Newf("db error deleting domain block: %w", err)
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
// Process the side effects of the domain unblock
|
||||
// asynchronously since it might take a while.
|
||||
p.state.Workers.ClientAPI.Enqueue(func(ctx context.Context) {
|
||||
p.domainUnblockSideEffects(ctx, domainBlockC) // Use the copy.
|
||||
})
|
||||
|
||||
return apiDomainBlock, nil
|
||||
}
|
||||
|
||||
// DomainBlockDelete removes one domain block with the given ID.
|
||||
func (p *Processor) DomainBlockDelete(ctx context.Context, account *gtsmodel.Account, id string) (*apimodel.DomainBlock, gtserror.WithCode) {
|
||||
domainBlock := >smodel.DomainBlock{}
|
||||
// stubbifyInstance renders the given instance as a stub,
|
||||
// removing most information from it and marking it as
|
||||
// suspended.
|
||||
//
|
||||
// For caller's convenience, this function returns the db
|
||||
// names of all columns that are updated by it.
|
||||
func stubbifyInstance(instance *gtsmodel.Instance, domainBlockID string) []string {
|
||||
instance.Title = ""
|
||||
instance.SuspendedAt = time.Now()
|
||||
instance.DomainBlockID = domainBlockID
|
||||
instance.ShortDescription = ""
|
||||
instance.Description = ""
|
||||
instance.Terms = ""
|
||||
instance.ContactEmail = ""
|
||||
instance.ContactAccountUsername = ""
|
||||
instance.ContactAccountID = ""
|
||||
instance.Version = ""
|
||||
|
||||
if err := p.state.DB.GetByID(ctx, id, domainBlock); err != nil {
|
||||
if !errors.Is(err, db.ErrNoEntries) {
|
||||
// something has gone really wrong
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
// there are no entries for this ID
|
||||
return nil, gtserror.NewErrorNotFound(fmt.Errorf("no entry for ID %s", id))
|
||||
return []string{
|
||||
"title",
|
||||
"suspended_at",
|
||||
"domain_block_id",
|
||||
"short_description",
|
||||
"description",
|
||||
"terms",
|
||||
"contact_email",
|
||||
"contact_account_username",
|
||||
"contact_account_id",
|
||||
"version",
|
||||
}
|
||||
}
|
||||
|
||||
// prepare the domain block to return
|
||||
// apiDomainBlock is a cheeky shortcut function for returning the API
|
||||
// version of the given domainBlock, or an appropriate error if
|
||||
// something goes wrong.
|
||||
func (p *Processor) apiDomainBlock(ctx context.Context, domainBlock *gtsmodel.DomainBlock) (*apimodel.DomainBlock, gtserror.WithCode) {
|
||||
apiDomainBlock, err := p.tc.DomainBlockToAPIDomainBlock(ctx, domainBlock, false)
|
||||
if err != nil {
|
||||
err = gtserror.Newf("error converting domain block for %s to api model : %w", domainBlock.Domain, err)
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
// Delete the domain block
|
||||
if err := p.state.DB.DeleteDomainBlock(ctx, domainBlock.Domain); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(err)
|
||||
}
|
||||
|
||||
// remove the domain block reference from the instance, if we have an entry for it
|
||||
i := >smodel.Instance{}
|
||||
if err := p.state.DB.GetWhere(ctx, []db.Where{
|
||||
{Key: "domain", Value: domainBlock.Domain},
|
||||
{Key: "domain_block_id", Value: id},
|
||||
}, i); err == nil {
|
||||
updatingColumns := []string{"suspended_at", "domain_block_id", "updated_at"}
|
||||
i.SuspendedAt = time.Time{}
|
||||
i.DomainBlockID = ""
|
||||
i.UpdatedAt = time.Now()
|
||||
if err := p.state.DB.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("couldn't update database entry for instance %s: %s", domainBlock.Domain, err))
|
||||
}
|
||||
}
|
||||
|
||||
// unsuspend all accounts whose suspension origin was this domain block
|
||||
// 1. remove the 'suspended_at' entry from their accounts
|
||||
if err := p.state.DB.UpdateWhere(ctx, []db.Where{
|
||||
{Key: "suspension_origin", Value: domainBlock.ID},
|
||||
}, "suspended_at", nil, &[]*gtsmodel.Account{}); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error removing suspended_at from accounts: %s", err))
|
||||
}
|
||||
|
||||
// 2. remove the 'suspension_origin' entry from their accounts
|
||||
if err := p.state.DB.UpdateWhere(ctx, []db.Where{
|
||||
{Key: "suspension_origin", Value: domainBlock.ID},
|
||||
}, "suspension_origin", nil, &[]*gtsmodel.Account{}); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("database error removing suspension_origin from accounts: %s", err))
|
||||
}
|
||||
|
||||
return apiDomainBlock, nil
|
||||
}
|
||||
|
||||
// rangeAccounts iterates through all accounts originating from the
|
||||
// given domain, and calls the provided range function on each account.
|
||||
// If an error is returned from the range function, the loop will stop
|
||||
// and return the error.
|
||||
func (p *Processor) rangeAccounts(
|
||||
ctx context.Context,
|
||||
domain string,
|
||||
rangeF func(*gtsmodel.Account) error,
|
||||
) error {
|
||||
var (
|
||||
limit = 50 // Limit selection to avoid spiking mem/cpu.
|
||||
maxID string // Start with empty string to select from top.
|
||||
)
|
||||
|
||||
for {
|
||||
// Get (next) page of accounts.
|
||||
accounts, err := p.state.DB.GetInstanceAccounts(ctx, domain, maxID, limit)
|
||||
if err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
// Real db error.
|
||||
return gtserror.Newf("db error getting instance accounts: %w", err)
|
||||
}
|
||||
|
||||
if len(accounts) == 0 {
|
||||
// No accounts left, we're done.
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set next max ID for paging down.
|
||||
maxID = accounts[len(accounts)-1].ID
|
||||
|
||||
// Call provided range function.
|
||||
for _, account := range accounts {
|
||||
if err := rangeF(account); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// domainBlockSideEffects processes the side effects of a domain block:
|
||||
//
|
||||
// 1. Strip most info away from the instance entry for the domain.
|
||||
// 2. Pass each account from the domain to the processor for deletion.
|
||||
//
|
||||
// It should be called asynchronously, since it can take a while when
|
||||
// there are many accounts present on the given domain.
|
||||
func (p *Processor) domainBlockSideEffects(ctx context.Context, account *gtsmodel.Account, block *gtsmodel.DomainBlock) {
|
||||
l := log.
|
||||
WithContext(ctx).
|
||||
WithFields(kv.Fields{
|
||||
{"domain", block.Domain},
|
||||
}...)
|
||||
l.Debug("processing domain block side effects")
|
||||
|
||||
// If we have an instance entry for this domain,
|
||||
// update it with the new block ID and clear all fields
|
||||
instance, err := p.state.DB.GetInstance(ctx, block.Domain)
|
||||
if err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
l.Errorf("db error getting instance %s: %q", block.Domain, err)
|
||||
}
|
||||
|
||||
if instance != nil {
|
||||
// We had an entry for this domain.
|
||||
columns := stubbifyInstance(instance, block.ID)
|
||||
if err := p.state.DB.UpdateInstance(ctx, instance, columns...); err != nil {
|
||||
l.Errorf("db error updating instance: %s", err)
|
||||
} else {
|
||||
l.Debug("instance entry updated")
|
||||
}
|
||||
}
|
||||
|
||||
// For each account that belongs to this domain, create
|
||||
// an account delete message to process via the client API
|
||||
// worker pool, to remove that account's posts, media, etc.
|
||||
msgs := []messages.FromClientAPI{}
|
||||
if err := p.rangeAccounts(ctx, block.Domain, func(account *gtsmodel.Account) error {
|
||||
msgs = append(msgs, messages.FromClientAPI{
|
||||
APObjectType: ap.ActorPerson,
|
||||
APActivityType: ap.ActivityDelete,
|
||||
GTSModel: block,
|
||||
OriginAccount: account,
|
||||
TargetAccount: account,
|
||||
})
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
l.Errorf("error while ranging through accounts: %q", err)
|
||||
}
|
||||
|
||||
// Batch process all accreted messages.
|
||||
p.state.Workers.EnqueueClientAPI(ctx, msgs...)
|
||||
}
|
||||
|
||||
// domainUnblockSideEffects processes the side effects of undoing a
|
||||
// domain block:
|
||||
//
|
||||
// 1. Mark instance entry as no longer suspended.
|
||||
// 2. Mark each account from the domain as no longer suspended, if the
|
||||
// suspension origin corresponds to the ID of the provided domain block.
|
||||
//
|
||||
// It should be called asynchronously, since it can take a while when
|
||||
// there are many accounts present on the given domain.
|
||||
func (p *Processor) domainUnblockSideEffects(ctx context.Context, block *gtsmodel.DomainBlock) {
|
||||
l := log.
|
||||
WithContext(ctx).
|
||||
WithFields(kv.Fields{
|
||||
{"domain", block.Domain},
|
||||
}...)
|
||||
l.Debug("processing domain unblock side effects")
|
||||
|
||||
// Update instance entry for this domain, if we have it.
|
||||
instance, err := p.state.DB.GetInstance(ctx, block.Domain)
|
||||
if err != nil && !errors.Is(err, db.ErrNoEntries) {
|
||||
l.Errorf("db error getting instance %s: %q", block.Domain, err)
|
||||
}
|
||||
|
||||
if instance != nil {
|
||||
// We had an entry, update it to signal
|
||||
// that it's no longer suspended.
|
||||
instance.SuspendedAt = time.Time{}
|
||||
instance.DomainBlockID = ""
|
||||
if err := p.state.DB.UpdateInstance(
|
||||
ctx,
|
||||
instance,
|
||||
"suspended_at",
|
||||
"domain_block_id",
|
||||
); err != nil {
|
||||
l.Errorf("db error updating instance: %s", err)
|
||||
} else {
|
||||
l.Debug("instance entry updated")
|
||||
}
|
||||
}
|
||||
|
||||
// Unsuspend all accounts whose suspension origin was this domain block.
|
||||
if err := p.rangeAccounts(ctx, block.Domain, func(account *gtsmodel.Account) error {
|
||||
if account.SuspensionOrigin == "" || account.SuspendedAt.IsZero() {
|
||||
// Account wasn't suspended, nothing to do.
|
||||
return nil
|
||||
}
|
||||
|
||||
if account.SuspensionOrigin != block.ID {
|
||||
// Account was suspended, but not by
|
||||
// this domain block, leave it alone.
|
||||
return nil
|
||||
}
|
||||
|
||||
// Account was suspended by this domain
|
||||
// block, mark it as unsuspended.
|
||||
account.SuspendedAt = time.Time{}
|
||||
account.SuspensionOrigin = ""
|
||||
|
||||
if err := p.state.DB.UpdateAccount(
|
||||
ctx,
|
||||
account,
|
||||
"suspended_at",
|
||||
"suspension_origin",
|
||||
); err != nil {
|
||||
return gtserror.Newf("db error updating account %s: %w", account.Username, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
l.Errorf("error while ranging through accounts: %q", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -34,11 +34,12 @@
|
|||
)
|
||||
|
||||
func (p *Processor) getThisInstance(ctx context.Context) (*gtsmodel.Instance, error) {
|
||||
i := >smodel.Instance{}
|
||||
if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "domain", Value: config.GetHost()}}, i); err != nil {
|
||||
instance, err := p.state.DB.GetInstance(ctx, config.GetHost())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return i, nil
|
||||
|
||||
return instance, nil
|
||||
}
|
||||
|
||||
func (p *Processor) InstanceGetV1(ctx context.Context) (*apimodel.InstanceV1, gtserror.WithCode) {
|
||||
|
@ -137,9 +138,10 @@ func (p *Processor) InstancePeersGet(ctx context.Context, includeSuspended bool,
|
|||
|
||||
func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSettingsUpdateRequest) (*apimodel.InstanceV1, gtserror.WithCode) {
|
||||
// fetch the instance entry from the db for processing
|
||||
i := >smodel.Instance{}
|
||||
host := config.GetHost()
|
||||
if err := p.state.DB.GetWhere(ctx, []db.Where{{Key: "domain", Value: host}}, i); err != nil {
|
||||
|
||||
instance, err := p.state.DB.GetInstance(ctx, host)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance %s: %s", host, err))
|
||||
}
|
||||
|
||||
|
@ -157,7 +159,7 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
|
|||
return nil, gtserror.NewErrorBadRequest(err, fmt.Sprintf("site title invalid: %s", err))
|
||||
}
|
||||
updatingColumns = append(updatingColumns, "title")
|
||||
i.Title = text.SanitizePlaintext(*form.Title) // don't allow html in site title
|
||||
instance.Title = text.SanitizePlaintext(*form.Title) // don't allow html in site title
|
||||
}
|
||||
|
||||
// validate & update site contact account if it's set on the form
|
||||
|
@ -192,7 +194,7 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
|
|||
return nil, gtserror.NewErrorBadRequest(err, err.Error())
|
||||
}
|
||||
updatingColumns = append(updatingColumns, "contact_account_id")
|
||||
i.ContactAccountID = contactAccount.ID
|
||||
instance.ContactAccountID = contactAccount.ID
|
||||
}
|
||||
|
||||
// validate & update site contact email if it's set on the form
|
||||
|
@ -204,7 +206,7 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
|
|||
}
|
||||
}
|
||||
updatingColumns = append(updatingColumns, "contact_email")
|
||||
i.ContactEmail = contactEmail
|
||||
instance.ContactEmail = contactEmail
|
||||
}
|
||||
|
||||
// validate & update site short description if it's set on the form
|
||||
|
@ -213,7 +215,7 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
|
|||
return nil, gtserror.NewErrorBadRequest(err, err.Error())
|
||||
}
|
||||
updatingColumns = append(updatingColumns, "short_description")
|
||||
i.ShortDescription = text.SanitizeHTML(*form.ShortDescription) // html is OK in site description, but we should sanitize it
|
||||
instance.ShortDescription = text.SanitizeHTML(*form.ShortDescription) // html is OK in site description, but we should sanitize it
|
||||
}
|
||||
|
||||
// validate & update site description if it's set on the form
|
||||
|
@ -222,7 +224,7 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
|
|||
return nil, gtserror.NewErrorBadRequest(err, err.Error())
|
||||
}
|
||||
updatingColumns = append(updatingColumns, "description")
|
||||
i.Description = text.SanitizeHTML(*form.Description) // html is OK in site description, but we should sanitize it
|
||||
instance.Description = text.SanitizeHTML(*form.Description) // html is OK in site description, but we should sanitize it
|
||||
}
|
||||
|
||||
// validate & update site terms if it's set on the form
|
||||
|
@ -231,7 +233,7 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
|
|||
return nil, gtserror.NewErrorBadRequest(err, err.Error())
|
||||
}
|
||||
updatingColumns = append(updatingColumns, "terms")
|
||||
i.Terms = text.SanitizeHTML(*form.Terms) // html is OK in site terms, but we should sanitize it
|
||||
instance.Terms = text.SanitizeHTML(*form.Terms) // html is OK in site terms, but we should sanitize it
|
||||
}
|
||||
|
||||
var updateInstanceAccount bool
|
||||
|
@ -273,12 +275,12 @@ func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSe
|
|||
}
|
||||
|
||||
if len(updatingColumns) != 0 {
|
||||
if err := p.state.DB.UpdateByID(ctx, i, i.ID, updatingColumns...); err != nil {
|
||||
if err := p.state.DB.UpdateInstance(ctx, instance, updatingColumns...); err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error updating instance %s: %s", host, err))
|
||||
}
|
||||
}
|
||||
|
||||
ai, err := p.tc.InstanceToAPIV1Instance(ctx, i)
|
||||
ai, err := p.tc.InstanceToAPIV1Instance(ctx, instance)
|
||||
if err != nil {
|
||||
return nil, gtserror.NewErrorInternalError(fmt.Errorf("error converting instance to api representation: %s", err))
|
||||
}
|
||||
|
|
|
@ -40,6 +40,9 @@ EXPECT=$(cat <<"EOF"
|
|||
"follow-request-ttl": 1800000000000,
|
||||
"follow-sweep-freq": 60000000000,
|
||||
"follow-ttl": 1800000000000,
|
||||
"instance-max-size": 2000,
|
||||
"instance-sweep-freq": 60000000000,
|
||||
"instance-ttl": 1800000000000,
|
||||
"list-entry-max-size": 2000,
|
||||
"list-entry-sweep-freq": 60000000000,
|
||||
"list-entry-ttl": 1800000000000,
|
||||
|
|
Loading…
Reference in a new issue