diff --git a/docs/api/swagger.yaml b/docs/api/swagger.yaml index e76e4e6cf..235309ba9 100644 --- a/docs/api/swagger.yaml +++ b/docs/api/swagger.yaml @@ -7168,6 +7168,34 @@ paths: summary: View instance rule with the given id. tags: - admin + /api/v1/announcements: + get: + description: 'THIS ENDPOINT IS CURRENTLY NOT FULLY IMPLEMENTED: it will always return an empty array.' + operationId: announcementsGet + produces: + - application/json + responses: + "200": + description: "" + schema: + items: + type: object + maxItems: 0 + type: array + "400": + description: bad request + "401": + description: unauthorized + "406": + description: not acceptable + "500": + description: internal server error + security: + - OAuth2 Bearer: + - read:announcements + summary: Get an array of currently active announcements. + tags: + - announcements /api/v1/apps: post: consumes: @@ -9945,6 +9973,112 @@ paths: summary: Create a new status using the given form field parameters. tags: - statuses + put: + consumes: + - application/json + - application/x-www-form-urlencoded + description: The parameters can also be given in the body of the request, as JSON, if the content-type is set to 'application/json'. + operationId: statusEdit + parameters: + - description: |- + Text content of the status. + If media_ids is provided, this becomes optional. + Attaching a poll is optional while status is provided. + in: formData + name: status + type: string + x-go-name: Status + - description: |- + Array of Attachment ids to be attached as media. + If provided, status becomes optional, and poll cannot be used. + + If the status is being submitted as a form, the key is 'media_ids[]', + but if it's json or xml, the key is 'media_ids'. + in: formData + items: + type: string + name: media_ids + type: array + x-go-name: MediaIDs + - description: |- + Array of possible poll answers. + If provided, media_ids cannot be used, and poll[expires_in] must be provided. + in: formData + items: + type: string + name: poll[options][] + type: array + x-go-name: PollOptions + - description: |- + Duration the poll should be open, in seconds. + If provided, media_ids cannot be used, and poll[options] must be provided. + format: int64 + in: formData + name: poll[expires_in] + type: integer + x-go-name: PollExpiresIn + - default: false + description: Allow multiple choices on this poll. + in: formData + name: poll[multiple] + type: boolean + x-go-name: PollMultiple + - default: true + description: Hide vote counts until the poll ends. + in: formData + name: poll[hide_totals] + type: boolean + x-go-name: PollHideTotals + - description: Status and attached media should be marked as sensitive. + in: formData + name: sensitive + type: boolean + x-go-name: Sensitive + - description: |- + Text to be shown as a warning or subject before the actual content. + Statuses are generally collapsed behind this field. + in: formData + name: spoiler_text + type: string + x-go-name: SpoilerText + - description: ISO 639 language code for this status. + in: formData + name: language + type: string + x-go-name: Language + - description: Content type to use when parsing this status. + enum: + - text/plain + - text/markdown + in: formData + name: content_type + type: string + x-go-name: ContentType + produces: + - application/json + responses: + "200": + description: The latest status revision. + schema: + $ref: '#/definitions/status' + "400": + description: bad request + "401": + description: unauthorized + "403": + description: forbidden + "404": + description: not found + "406": + description: not acceptable + "500": + description: internal server error + security: + - OAuth2 Bearer: + - write:statuses + summary: Edit an existing status using the given form field parameters. + tags: + - statuses /api/v1/statuses/{id}: delete: description: |- diff --git a/go.mod b/go.mod index a6a030eb6..c8537f915 100644 --- a/go.mod +++ b/go.mod @@ -62,7 +62,7 @@ require ( github.com/miekg/dns v1.1.62 github.com/minio/minio-go/v7 v7.0.81 github.com/mitchellh/mapstructure v1.5.0 - github.com/ncruces/go-sqlite3 v0.21.2 + github.com/ncruces/go-sqlite3 v0.21.3 github.com/oklog/ulid v1.3.1 github.com/prometheus/client_golang v1.20.5 github.com/spf13/cobra v1.8.1 diff --git a/go.sum b/go.sum index 09c4f9565..3dca6a3ae 100644 --- a/go.sum +++ b/go.sum @@ -434,8 +434,8 @@ github.com/moul/http2curl v1.0.0 h1:dRMWoAtb+ePxMlLkrCbAqh4TlPHXvoGUSQ323/9Zahs= github.com/moul/http2curl v1.0.0/go.mod h1:8UbvGypXm98wA/IqH45anm5Y2Z6ep6O31QGOAZ3H0fQ= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= -github.com/ncruces/go-sqlite3 v0.21.2 h1:X7Ao4BwtS9h308lFtZA/stkvrzEHvAdp8g4Gko7Ehjs= -github.com/ncruces/go-sqlite3 v0.21.2/go.mod h1:zxMOaSG5kFYVFK4xQa0pdwIszqxqJ0W0BxBgwdrNjuA= +github.com/ncruces/go-sqlite3 v0.21.3 h1:hHkfNQLcbnxPJZhC/RGw9SwP3bfkv/Y0xUHWsr1CdMQ= +github.com/ncruces/go-sqlite3 v0.21.3/go.mod h1:zxMOaSG5kFYVFK4xQa0pdwIszqxqJ0W0BxBgwdrNjuA= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M= diff --git a/internal/api/client.go b/internal/api/client.go index 77a63eb89..60daddf87 100644 --- a/internal/api/client.go +++ b/internal/api/client.go @@ -23,6 +23,7 @@ "github.com/gin-gonic/gin" "github.com/superseriousbusiness/gotosocial/internal/api/client/accounts" "github.com/superseriousbusiness/gotosocial/internal/api/client/admin" + "github.com/superseriousbusiness/gotosocial/internal/api/client/announcements" "github.com/superseriousbusiness/gotosocial/internal/api/client/apps" "github.com/superseriousbusiness/gotosocial/internal/api/client/blocks" "github.com/superseriousbusiness/gotosocial/internal/api/client/bookmarks" @@ -66,6 +67,7 @@ type Client struct { accounts *accounts.Module // api/v1/accounts, api/v1/profile admin *admin.Module // api/v1/admin + announcements *announcements.Module // api/v1/announcements apps *apps.Module // api/v1/apps blocks *blocks.Module // api/v1/blocks bookmarks *bookmarks.Module // api/v1/bookmarks @@ -117,6 +119,7 @@ func (c *Client) Route(r *router.Router, m ...gin.HandlerFunc) { h := apiGroup.Handle c.accounts.Route(h) c.admin.Route(h) + c.announcements.Route(h) c.apps.Route(h) c.blocks.Route(h) c.bookmarks.Route(h) @@ -156,6 +159,7 @@ func NewClient(state *state.State, p *processing.Processor) *Client { accounts: accounts.New(p), admin: admin.New(state, p), + announcements: announcements.New(p), apps: apps.New(p), blocks: blocks.New(p), bookmarks: bookmarks.New(p), diff --git a/internal/api/client/announcements/announcements.go b/internal/api/client/announcements/announcements.go new file mode 100644 index 000000000..611a1c53e --- /dev/null +++ b/internal/api/client/announcements/announcements.go @@ -0,0 +1,42 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package announcements + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "github.com/superseriousbusiness/gotosocial/internal/processing" +) + +// BasePath is the base path for this api module, excluding the api prefix +const BasePath = "/v1/announcements" + +type Module struct { + processor *processing.Processor +} + +func New(processor *processing.Processor) *Module { + return &Module{ + processor: processor, + } +} + +func (m *Module) Route(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) { + attachHandler(http.MethodGet, BasePath, m.AnnouncementsGETHandler) +} diff --git a/internal/api/client/announcements/announcementsget.go b/internal/api/client/announcements/announcementsget.go new file mode 100644 index 000000000..04bd5f285 --- /dev/null +++ b/internal/api/client/announcements/announcementsget.go @@ -0,0 +1,74 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package announcements + +import ( + "net/http" + + "github.com/gin-gonic/gin" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" +) + +// AnnouncementsGETHandler swagger:operation GET /api/v1/announcements announcementsGet +// +// Get an array of currently active announcements. +// +// THIS ENDPOINT IS CURRENTLY NOT FULLY IMPLEMENTED: it will always return an empty array. +// +// --- +// tags: +// - announcements +// +// produces: +// - application/json +// +// security: +// - OAuth2 Bearer: +// - read:announcements +// +// responses: +// '200': +// schema: +// type: array +// items: +// type: object +// maxItems: 0 +// '400': +// description: bad request +// '401': +// description: unauthorized +// '406': +// description: not acceptable +// '500': +// description: internal server error +func (m *Module) AnnouncementsGETHandler(c *gin.Context) { + _, err := oauth.Authed(c, true, true, true, true) + if err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1) + return + } + + apiutil.JSON(c, http.StatusOK, apiutil.EmptyJSONArray) +} diff --git a/internal/api/client/statuses/status.go b/internal/api/client/statuses/status.go index 33af9c456..88b34cbf5 100644 --- a/internal/api/client/statuses/status.go +++ b/internal/api/client/statuses/status.go @@ -83,9 +83,10 @@ func New(processor *processing.Processor) *Module { } func (m *Module) Route(attachHandler func(method string, path string, f ...gin.HandlerFunc) gin.IRoutes) { - // create / get / delete status + // create / get / edit / delete status attachHandler(http.MethodPost, BasePath, m.StatusCreatePOSTHandler) attachHandler(http.MethodGet, BasePathWithID, m.StatusGETHandler) + attachHandler(http.MethodPut, BasePathWithID, m.StatusEditPUTHandler) attachHandler(http.MethodDelete, BasePathWithID, m.StatusDELETEHandler) // fave stuff diff --git a/internal/api/client/statuses/statuscreate.go b/internal/api/client/statuses/statuscreate.go index 8198d5358..c83cdbad7 100644 --- a/internal/api/client/statuses/statuscreate.go +++ b/internal/api/client/statuses/statuscreate.go @@ -27,11 +27,9 @@ "github.com/go-playground/form/v4" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" - "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/util" - "github.com/superseriousbusiness/gotosocial/internal/validate" ) // StatusCreatePOSTHandler swagger:operation POST /api/v1/statuses statusCreate @@ -272,9 +270,9 @@ func (m *Module) StatusCreatePOSTHandler(c *gin.Context) { return } - form, err := parseStatusCreateForm(c) - if err != nil { - apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1) + form, errWithCode := parseStatusCreateForm(c) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) return } @@ -287,11 +285,6 @@ func (m *Module) StatusCreatePOSTHandler(c *gin.Context) { // } // form.Status += "\n\nsent from " + user + "'s iphone\n" - if errWithCode := validateStatusCreateForm(form); errWithCode != nil { - apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) - return - } - apiStatus, errWithCode := m.processor.Status().Create( c.Request.Context(), authed.Account, @@ -303,7 +296,7 @@ func (m *Module) StatusCreatePOSTHandler(c *gin.Context) { return } - c.JSON(http.StatusOK, apiStatus) + apiutil.JSON(c, http.StatusOK, apiStatus) } // intPolicyFormBinding satisfies gin's binding.Binding interface. @@ -328,108 +321,69 @@ func (intPolicyFormBinding) Bind(req *http.Request, obj any) error { return decoder.Decode(obj, req.Form) } -func parseStatusCreateForm(c *gin.Context) (*apimodel.StatusCreateRequest, error) { +func parseStatusCreateForm(c *gin.Context) (*apimodel.StatusCreateRequest, gtserror.WithCode) { form := new(apimodel.StatusCreateRequest) switch ct := c.ContentType(); ct { case binding.MIMEJSON: // Just bind with default json binding. if err := c.ShouldBindWith(form, binding.JSON); err != nil { - return nil, err + return nil, gtserror.NewErrorBadRequest( + err, + err.Error(), + ) } case binding.MIMEPOSTForm: // Bind with default form binding first. if err := c.ShouldBindWith(form, binding.FormPost); err != nil { - return nil, err + return nil, gtserror.NewErrorBadRequest( + err, + err.Error(), + ) } // Now do custom binding. intReqForm := new(apimodel.StatusInteractionPolicyForm) if err := c.ShouldBindWith(intReqForm, intPolicyFormBinding{}); err != nil { - return nil, err + return nil, gtserror.NewErrorBadRequest( + err, + err.Error(), + ) } + form.InteractionPolicy = intReqForm.InteractionPolicy case binding.MIMEMultipartPOSTForm: // Bind with default form binding first. if err := c.ShouldBindWith(form, binding.FormMultipart); err != nil { - return nil, err + return nil, gtserror.NewErrorBadRequest( + err, + err.Error(), + ) } // Now do custom binding. intReqForm := new(apimodel.StatusInteractionPolicyForm) if err := c.ShouldBindWith(intReqForm, intPolicyFormBinding{}); err != nil { - return nil, err + return nil, gtserror.NewErrorBadRequest( + err, + err.Error(), + ) } + form.InteractionPolicy = intReqForm.InteractionPolicy default: - err := fmt.Errorf( - "content-type %s not supported for this endpoint; supported content-types are %s, %s, %s", - ct, binding.MIMEJSON, binding.MIMEPOSTForm, binding.MIMEMultipartPOSTForm, - ) - return nil, err - } - - return form, nil -} - -// validateStatusCreateForm checks the form for disallowed -// combinations of attachments, overlength inputs, etc. -// -// Side effect: normalizes the post's language tag. -func validateStatusCreateForm(form *apimodel.StatusCreateRequest) gtserror.WithCode { - var ( - chars = len([]rune(form.Status)) + len([]rune(form.SpoilerText)) - maxChars = config.GetStatusesMaxChars() - mediaFiles = len(form.MediaIDs) - maxMediaFiles = config.GetStatusesMediaMaxFiles() - hasMedia = mediaFiles != 0 - hasPoll = form.Poll != nil - ) - - if chars == 0 && !hasMedia && !hasPoll { - // Status must contain *some* kind of content. - const text = "no status content, content warning, media, or poll provided" - return gtserror.NewErrorBadRequest(errors.New(text), text) - } - - if chars > maxChars { - text := fmt.Sprintf( - "status too long, %d characters provided (including content warning) but limit is %d", - chars, maxChars, - ) - return gtserror.NewErrorBadRequest(errors.New(text), text) - } - - if mediaFiles > maxMediaFiles { - text := fmt.Sprintf( - "too many media files attached to status, %d attached but limit is %d", - mediaFiles, maxMediaFiles, - ) - return gtserror.NewErrorBadRequest(errors.New(text), text) - } - - if form.Poll != nil { - if errWithCode := validateStatusPoll(form); errWithCode != nil { - return errWithCode - } + text := fmt.Sprintf("content-type %s not supported for this endpoint; supported content-types are %s, %s, %s", + ct, binding.MIMEJSON, binding.MIMEPOSTForm, binding.MIMEMultipartPOSTForm) + return nil, gtserror.NewErrorNotAcceptable(errors.New(text), text) } + // Check not scheduled status. if form.ScheduledAt != "" { const text = "scheduled_at is not yet implemented" - return gtserror.NewErrorNotImplemented(errors.New(text), text) - } - - // Validate + normalize - // language tag if provided. - if form.Language != "" { - lang, err := validate.Language(form.Language) - if err != nil { - return gtserror.NewErrorBadRequest(err, err.Error()) - } - form.Language = lang + return nil, gtserror.NewErrorNotImplemented(errors.New(text), text) } // Check if the deprecated "federated" field was @@ -438,42 +392,9 @@ func validateStatusCreateForm(form *apimodel.StatusCreateRequest) gtserror.WithC form.LocalOnly = util.Ptr(!*form.Federated) // nolint:staticcheck } - return nil -} + // Normalize poll expiry time if a poll was given. + if form.Poll != nil && form.Poll.ExpiresInI != nil { -func validateStatusPoll(form *apimodel.StatusCreateRequest) gtserror.WithCode { - var ( - maxPollOptions = config.GetStatusesPollMaxOptions() - pollOptions = len(form.Poll.Options) - maxPollOptionChars = config.GetStatusesPollOptionMaxChars() - ) - - if pollOptions == 0 { - const text = "poll with no options" - return gtserror.NewErrorBadRequest(errors.New(text), text) - } - - if pollOptions > maxPollOptions { - text := fmt.Sprintf( - "too many poll options provided, %d provided but limit is %d", - pollOptions, maxPollOptions, - ) - return gtserror.NewErrorBadRequest(errors.New(text), text) - } - - for _, option := range form.Poll.Options { - optionChars := len([]rune(option)) - if optionChars > maxPollOptionChars { - text := fmt.Sprintf( - "poll option too long, %d characters provided but limit is %d", - optionChars, maxPollOptionChars, - ) - return gtserror.NewErrorBadRequest(errors.New(text), text) - } - } - - // Normalize poll expiry if necessary. - if form.Poll.ExpiresInI != nil { // If we parsed this as JSON, expires_in // may be either a float64 or a string. expiresIn, err := apiutil.ParseDuration( @@ -481,13 +402,10 @@ func validateStatusPoll(form *apimodel.StatusCreateRequest) gtserror.WithCode { "expires_in", ) if err != nil { - return gtserror.NewErrorBadRequest(err, err.Error()) - } - - if expiresIn != nil { - form.Poll.ExpiresIn = *expiresIn + return nil, gtserror.NewErrorBadRequest(err, err.Error()) } + form.Poll.ExpiresIn = util.PtrOrZero(expiresIn) } - return nil + return form, nil } diff --git a/internal/api/client/statuses/statusdelete.go b/internal/api/client/statuses/statusdelete.go index 7ee240dff..fa62d6893 100644 --- a/internal/api/client/statuses/statusdelete.go +++ b/internal/api/client/statuses/statusdelete.go @@ -95,5 +95,5 @@ func (m *Module) StatusDELETEHandler(c *gin.Context) { return } - c.JSON(http.StatusOK, apiStatus) + apiutil.JSON(c, http.StatusOK, apiStatus) } diff --git a/internal/api/client/statuses/statusedit.go b/internal/api/client/statuses/statusedit.go new file mode 100644 index 000000000..dfd7d651e --- /dev/null +++ b/internal/api/client/statuses/statusedit.go @@ -0,0 +1,249 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package statuses + +import ( + "errors" + "fmt" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/gin-gonic/gin/binding" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/oauth" + "github.com/superseriousbusiness/gotosocial/internal/util" +) + +// StatusEditPUTHandler swagger:operation PUT /api/v1/statuses statusEdit +// +// Edit an existing status using the given form field parameters. +// +// The parameters can also be given in the body of the request, as JSON, if the content-type is set to 'application/json'. +// +// --- +// tags: +// - statuses +// +// consumes: +// - application/json +// - application/x-www-form-urlencoded +// +// parameters: +// - +// name: status +// x-go-name: Status +// description: |- +// Text content of the status. +// If media_ids is provided, this becomes optional. +// Attaching a poll is optional while status is provided. +// type: string +// in: formData +// - +// name: media_ids +// x-go-name: MediaIDs +// description: |- +// Array of Attachment ids to be attached as media. +// If provided, status becomes optional, and poll cannot be used. +// +// If the status is being submitted as a form, the key is 'media_ids[]', +// but if it's json or xml, the key is 'media_ids'. +// type: array +// items: +// type: string +// in: formData +// - +// name: poll[options][] +// x-go-name: PollOptions +// description: |- +// Array of possible poll answers. +// If provided, media_ids cannot be used, and poll[expires_in] must be provided. +// type: array +// items: +// type: string +// in: formData +// - +// name: poll[expires_in] +// x-go-name: PollExpiresIn +// description: |- +// Duration the poll should be open, in seconds. +// If provided, media_ids cannot be used, and poll[options] must be provided. +// type: integer +// format: int64 +// in: formData +// - +// name: poll[multiple] +// x-go-name: PollMultiple +// description: Allow multiple choices on this poll. +// type: boolean +// default: false +// in: formData +// - +// name: poll[hide_totals] +// x-go-name: PollHideTotals +// description: Hide vote counts until the poll ends. +// type: boolean +// default: true +// in: formData +// - +// name: sensitive +// x-go-name: Sensitive +// description: Status and attached media should be marked as sensitive. +// type: boolean +// in: formData +// - +// name: spoiler_text +// x-go-name: SpoilerText +// description: |- +// Text to be shown as a warning or subject before the actual content. +// Statuses are generally collapsed behind this field. +// type: string +// in: formData +// - +// name: language +// x-go-name: Language +// description: ISO 639 language code for this status. +// type: string +// in: formData +// - +// name: content_type +// x-go-name: ContentType +// description: Content type to use when parsing this status. +// type: string +// enum: +// - text/plain +// - text/markdown +// in: formData +// +// produces: +// - application/json +// +// security: +// - OAuth2 Bearer: +// - write:statuses +// +// responses: +// '200': +// description: "The latest status revision." +// schema: +// "$ref": "#/definitions/status" +// '400': +// description: bad request +// '401': +// description: unauthorized +// '403': +// description: forbidden +// '404': +// description: not found +// '406': +// description: not acceptable +// '500': +// description: internal server error +func (m *Module) StatusEditPUTHandler(c *gin.Context) { + authed, err := oauth.Authed(c, true, true, true, true) + if err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGetV1) + return + } + + if authed.Account.IsMoving() { + apiutil.ForbiddenAfterMove(c) + return + } + + if _, err := apiutil.NegotiateAccept(c, apiutil.JSONAcceptHeaders...); err != nil { + apiutil.ErrorHandler(c, gtserror.NewErrorNotAcceptable(err, err.Error()), m.processor.InstanceGetV1) + return + } + + form, errWithCode := parseStatusEditForm(c) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + apiStatus, errWithCode := m.processor.Status().Edit( + c.Request.Context(), + authed.Account, + c.Param(IDKey), + form, + ) + if errWithCode != nil { + apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) + return + } + + apiutil.JSON(c, http.StatusOK, apiStatus) +} + +func parseStatusEditForm(c *gin.Context) (*apimodel.StatusEditRequest, gtserror.WithCode) { + form := new(apimodel.StatusEditRequest) + + switch ct := c.ContentType(); ct { + case binding.MIMEJSON: + // Just bind with default json binding. + if err := c.ShouldBindWith(form, binding.JSON); err != nil { + return nil, gtserror.NewErrorBadRequest( + err, + err.Error(), + ) + } + + case binding.MIMEPOSTForm: + // Bind with default form binding first. + if err := c.ShouldBindWith(form, binding.FormPost); err != nil { + return nil, gtserror.NewErrorBadRequest( + err, + err.Error(), + ) + } + + case binding.MIMEMultipartPOSTForm: + // Bind with default form binding first. + if err := c.ShouldBindWith(form, binding.FormMultipart); err != nil { + return nil, gtserror.NewErrorBadRequest( + err, + err.Error(), + ) + } + + default: + text := fmt.Sprintf("content-type %s not supported for this endpoint; supported content-types are %s, %s, %s", + ct, binding.MIMEJSON, binding.MIMEPOSTForm, binding.MIMEMultipartPOSTForm) + return nil, gtserror.NewErrorNotAcceptable(errors.New(text), text) + } + + // Normalize poll expiry time if a poll was given. + if form.Poll != nil && form.Poll.ExpiresInI != nil { + + // If we parsed this as JSON, expires_in + // may be either a float64 or a string. + expiresIn, err := apiutil.ParseDuration( + form.Poll.ExpiresInI, + "expires_in", + ) + if err != nil { + return nil, gtserror.NewErrorBadRequest(err, err.Error()) + } + form.Poll.ExpiresIn = util.PtrOrZero(expiresIn) + } + + return form, nil + +} diff --git a/internal/api/client/statuses/statusedit_test.go b/internal/api/client/statuses/statusedit_test.go new file mode 100644 index 000000000..43b283d6d --- /dev/null +++ b/internal/api/client/statuses/statusedit_test.go @@ -0,0 +1,32 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package statuses_test + +import ( + "testing" + + "github.com/stretchr/testify/suite" +) + +type StatusEditTestSuite struct { + StatusStandardTestSuite +} + +func TestStatusEditTestSuite(t *testing.T) { + suite.Run(t, new(StatusEditTestSuite)) +} diff --git a/internal/api/client/statuses/statussource_test.go b/internal/api/client/statuses/statussource_test.go index 28b1e6852..797a462ed 100644 --- a/internal/api/client/statuses/statussource_test.go +++ b/internal/api/client/statuses/statussource_test.go @@ -91,7 +91,7 @@ func (suite *StatusSourceTestSuite) TestGetSource() { suite.Equal(`{ "id": "01F8MHAMCHF6Y650WCRSCP4WMY", - "text": "**STATUS EDITS ARE NOT CURRENTLY SUPPORTED IN GOTOSOCIAL (coming in 2024)**\nYou can review the original text of your status below, but you will not be able to submit this edit.\n\n---\n\nhello everyone!", + "text": "hello everyone!", "spoiler_text": "introduction post" }`, dst.String()) } diff --git a/internal/api/model/attachment.go b/internal/api/model/attachment.go index f037a09aa..1d910343c 100644 --- a/internal/api/model/attachment.go +++ b/internal/api/model/attachment.go @@ -23,12 +23,15 @@ // // swagger: ignore type AttachmentRequest struct { + // Media file. File *multipart.FileHeader `form:"file" binding:"required"` + // Description of the media file. Optional. // This will be used as alt-text for users of screenreaders etc. // example: This is an image of some kittens, they are very cute and fluffy. Description string `form:"description"` + // Focus of the media file. Optional. // If present, it should be in the form of two comma-separated floats between -1 and 1. // example: -0.5,0.565 @@ -39,16 +42,38 @@ type AttachmentRequest struct { // // swagger:ignore type AttachmentUpdateRequest struct { + // Description of the media file. // This will be used as alt-text for users of screenreaders etc. // allowEmptyValue: true Description *string `form:"description" json:"description" xml:"description"` + // Focus of the media file. // If present, it should be in the form of two comma-separated floats between -1 and 1. // allowEmptyValue: true Focus *string `form:"focus" json:"focus" xml:"focus"` } +// AttachmentAttributesRequest models an edit request for attachment attributes. +// +// swagger:ignore +type AttachmentAttributesRequest struct { + + // The ID of the attachment. + // example: 01FC31DZT1AYWDZ8XTCRWRBYRK + ID string `form:"id" json:"id"` + + // Description of the media file. + // This will be used as alt-text for users of screenreaders etc. + // allowEmptyValue: true + Description string `form:"description" json:"description"` + + // Focus of the media file. + // If present, it should be in the form of two comma-separated floats between -1 and 1. + // allowEmptyValue: true + Focus string `form:"focus" json:"focus"` +} + // Attachment models a media attachment. // // swagger:model attachment diff --git a/internal/api/model/status.go b/internal/api/model/status.go index 724134b77..ea9fbaa35 100644 --- a/internal/api/model/status.go +++ b/internal/api/model/status.go @@ -197,36 +197,50 @@ type StatusReblogged struct { // // swagger:ignore type StatusCreateRequest struct { + // Text content of the status. // If media_ids is provided, this becomes optional. // Attaching a poll is optional while status is provided. Status string `form:"status" json:"status"` + // Array of Attachment ids to be attached as media. // If provided, status becomes optional, and poll cannot be used. MediaIDs []string `form:"media_ids[]" json:"media_ids"` + // Poll to include with this status. Poll *PollRequest `form:"poll" json:"poll"` + // ID of the status being replied to, if status is a reply. InReplyToID string `form:"in_reply_to_id" json:"in_reply_to_id"` + // Status and attached media should be marked as sensitive. Sensitive bool `form:"sensitive" json:"sensitive"` + // Text to be shown as a warning or subject before the actual content. // Statuses are generally collapsed behind this field. SpoilerText string `form:"spoiler_text" json:"spoiler_text"` + // Visibility of the posted status. Visibility Visibility `form:"visibility" json:"visibility"` - // Set to "true" if this status should not be federated, ie. it should be a "local only" status. + + // Set to "true" if this status should not be + // federated,ie. it should be a "local only" status. LocalOnly *bool `form:"local_only" json:"local_only"` + // Deprecated: Only used if LocalOnly is not set. Federated *bool `form:"federated" json:"federated"` + // ISO 8601 Datetime at which to schedule a status. // Providing this parameter will cause ScheduledStatus to be returned instead of Status. // Must be at least 5 minutes in the future. ScheduledAt string `form:"scheduled_at" json:"scheduled_at"` + // ISO 639 language code for this status. Language string `form:"language" json:"language"` + // Content type to use when parsing this status. ContentType StatusContentType `form:"content_type" json:"content_type"` + // Interaction policy to use for this status. InteractionPolicy *InteractionPolicy `form:"-" json:"interaction_policy"` } @@ -236,6 +250,7 @@ type StatusCreateRequest struct { // // swagger:ignore type StatusInteractionPolicyForm struct { + // Interaction policy to use for this status. InteractionPolicy *InteractionPolicy `form:"interaction_policy" json:"-"` } @@ -250,13 +265,18 @@ type StatusInteractionPolicyForm struct { // VisibilityNone is visible to nobody. This is only used for the visibility of web statuses. VisibilityNone Visibility = "none" // VisibilityPublic is visible to everyone, and will be available via the web even for nonauthenticated users. + VisibilityPublic Visibility = "public" + // VisibilityUnlisted is visible to everyone, but only on home timelines, lists, etc. VisibilityUnlisted Visibility = "unlisted" + // VisibilityPrivate is visible only to followers of the account that posted the status. VisibilityPrivate Visibility = "private" + // VisibilityMutualsOnly is visible only to mutual followers of the account that posted the status. VisibilityMutualsOnly Visibility = "mutuals_only" + // VisibilityDirect is visible only to accounts tagged in the status. It is equivalent to a direct message. VisibilityDirect Visibility = "direct" ) @@ -268,7 +288,8 @@ type StatusInteractionPolicyForm struct { // swagger:type string type StatusContentType string -// Content type to use when parsing submitted status into an html-formatted status +// Content type to use when parsing submitted +// status into an html-formatted status. const ( StatusContentTypePlain StatusContentType = "text/plain" StatusContentTypeMarkdown StatusContentType = "text/markdown" @@ -280,11 +301,14 @@ type StatusInteractionPolicyForm struct { // // swagger:model statusSource type StatusSource struct { + // ID of the status. // example: 01FBVD42CQ3ZEEVMW180SBX03B ID string `json:"id"` + // Plain-text source of a status. Text string `json:"text"` + // Plain-text version of spoiler text. SpoilerText string `json:"spoiler_text"` } @@ -294,27 +318,69 @@ type StatusSource struct { // // swagger:model statusEdit type StatusEdit struct { + // The content of this status at this revision. // Should be HTML, but might also be plaintext in some cases. // example:

Hey this is a status!

Content string `json:"content"` + // Subject, summary, or content warning for the status at this revision. // example: warning nsfw SpoilerText string `json:"spoiler_text"` + // Status marked sensitive at this revision. // example: false Sensitive bool `json:"sensitive"` + // The date when this revision was created (ISO 8601 Datetime). // example: 2021-07-30T09:20:25+00:00 CreatedAt string `json:"created_at"` + // The account that authored this status. Account *Account `json:"account"` + // The poll attached to the status at this revision. // Note that edits changing the poll options will be collapsed together into one edit, since this action resets the poll. // nullable: true Poll *Poll `json:"poll"` + // Media that is attached to this status. MediaAttachments []*Attachment `json:"media_attachments"` + // Custom emoji to be used when rendering status content. Emojis []Emoji `json:"emojis"` } + +// StatusEditRequest models status edit parameters. +// +// swagger:ignore +type StatusEditRequest struct { + + // Text content of the status. + // If media_ids is provided, this becomes optional. + // Attaching a poll is optional while status is provided. + Status string `form:"status" json:"status"` + + // Text to be shown as a warning or subject before the actual content. + // Statuses are generally collapsed behind this field. + SpoilerText string `form:"spoiler_text" json:"spoiler_text"` + + // Content type to use when parsing this status. + ContentType StatusContentType `form:"content_type" json:"content_type"` + + // Status and attached media should be marked as sensitive. + Sensitive bool `form:"sensitive" json:"sensitive"` + + // ISO 639 language code for this status. + Language string `form:"language" json:"language"` + + // Array of Attachment ids to be attached as media. + // If provided, status becomes optional, and poll cannot be used. + MediaIDs []string `form:"media_ids[]" json:"media_ids"` + + // Array of Attachment attributes to be updated in attached media. + MediaAttributes []AttachmentAttributesRequest `form:"media_attributes[]" json:"media_attributes"` + + // Poll to include with this status. + Poll *PollRequest `form:"poll" json:"poll"` +} diff --git a/internal/api/util/parseform.go b/internal/api/util/parseform.go index 3eab065f2..8bb10012c 100644 --- a/internal/api/util/parseform.go +++ b/internal/api/util/parseform.go @@ -18,13 +18,55 @@ package util import ( + "errors" "fmt" "strconv" + "strings" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/util" ) +// ParseFocus parses a media attachment focus parameters from incoming API string. +func ParseFocus(focus string) (focusx, focusy float32, errWithCode gtserror.WithCode) { + if focus == "" { + return + } + spl := strings.Split(focus, ",") + if len(spl) != 2 { + const text = "missing comma separator" + errWithCode = gtserror.NewErrorBadRequest( + errors.New(text), + text, + ) + return + } + xStr := spl[0] + yStr := spl[1] + fx, err := strconv.ParseFloat(xStr, 32) + if err != nil || fx > 1 || fx < -1 { + text := fmt.Sprintf("invalid x focus: %s", xStr) + errWithCode = gtserror.NewErrorBadRequest( + errors.New(text), + text, + ) + return + } + fy, err := strconv.ParseFloat(yStr, 32) + if err != nil || fy > 1 || fy < -1 { + text := fmt.Sprintf("invalid y focus: %s", xStr) + errWithCode = gtserror.NewErrorBadRequest( + errors.New(text), + text, + ) + return + } + focusx = float32(fx) + focusy = float32(fy) + return +} + // ParseDuration parses the given raw interface belonging // the given fieldName as an integer duration. func ParseDuration(rawI any, fieldName string) (*int, error) { diff --git a/internal/db/bundb/status.go b/internal/db/bundb/status.go index fa31f3459..fea5594dd 100644 --- a/internal/db/bundb/status.go +++ b/internal/db/bundb/status.go @@ -297,17 +297,6 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status) } } - if !status.EditsPopulated() { - // Status edits are out-of-date with IDs, repopulate. - status.Edits, err = s.state.DB.GetStatusEditsByIDs( - gtscontext.SetBarebones(ctx), - status.EditIDs, - ) - if err != nil { - errs.Appendf("error populating status edits: %w", err) - } - } - if status.CreatedWithApplicationID != "" && status.CreatedWithApplication == nil { // Populate the status' expected CreatedWithApplication (not always set). status.CreatedWithApplication, err = s.state.DB.GetApplicationByID( @@ -322,6 +311,23 @@ func (s *statusDB) PopulateStatus(ctx context.Context, status *gtsmodel.Status) return errs.Combine() } +func (s *statusDB) PopulateStatusEdits(ctx context.Context, status *gtsmodel.Status) error { + var err error + + if !status.EditsPopulated() { + // Status edits are out-of-date with IDs, repopulate. + status.Edits, err = s.state.DB.GetStatusEditsByIDs( + gtscontext.SetBarebones(ctx), + status.EditIDs, + ) + if err != nil { + return gtserror.Newf("error populating status edits: %w", err) + } + } + + return nil +} + func (s *statusDB) PutStatus(ctx context.Context, status *gtsmodel.Status) error { return s.state.Caches.DB.Status.Store(status, func() error { // It is safe to run this database transaction within cache.Store diff --git a/internal/db/status.go b/internal/db/status.go index ade900728..6bf9653c8 100644 --- a/internal/db/status.go +++ b/internal/db/status.go @@ -41,8 +41,12 @@ type Status interface { GetStatusBoost(ctx context.Context, boostOfID string, byAccountID string) (*gtsmodel.Status, error) // PopulateStatus ensures that all sub-models of a status are populated (e.g. mentions, attachments, etc). + // Except for edits, to fetch these please call PopulateStatusEdits() . PopulateStatus(ctx context.Context, status *gtsmodel.Status) error + // PopulateStatusEdits ensures that status' edits are fully popualted. + PopulateStatusEdits(ctx context.Context, status *gtsmodel.Status) error + // PutStatus stores one status in the database. PutStatus(ctx context.Context, status *gtsmodel.Status) error diff --git a/internal/federation/dereferencing/dereferencer.go b/internal/federation/dereferencing/dereferencer.go index 3bff0d1a2..5e7b2b9c0 100644 --- a/internal/federation/dereferencing/dereferencer.go +++ b/internal/federation/dereferencing/dereferencer.go @@ -66,7 +66,7 @@ // causing loads of dereferencing calls. Fresh = util.Ptr(FreshnessWindow(5 * time.Minute)) - // 10 seconds. + // 5 seconds. // // Freshest is useful when you want an // immediately up to date model of something @@ -74,7 +74,7 @@ // // Be careful using this one; it can cause // lots of unnecessary traffic if used unwisely. - Freshest = util.Ptr(FreshnessWindow(10 * time.Second)) + Freshest = util.Ptr(FreshnessWindow(5 * time.Second)) ) // Dereferencer wraps logic and functionality for doing dereferencing diff --git a/internal/federation/dereferencing/status.go b/internal/federation/dereferencing/status.go index d19669891..0a75a4802 100644 --- a/internal/federation/dereferencing/status.go +++ b/internal/federation/dereferencing/status.go @@ -35,6 +35,7 @@ "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/util" + "github.com/superseriousbusiness/gotosocial/internal/util/xslices" ) // statusFresh returns true if the given status is still @@ -1000,12 +1001,21 @@ func (d *Dereferencer) fetchStatusEmojis( // Set latest emojis. status.Emojis = emojis - // Iterate over and set changed emoji IDs. + // Extract IDs from latest slice of emojis. status.EmojiIDs = make([]string, len(emojis)) for i, emoji := range emojis { status.EmojiIDs[i] = emoji.ID } + // Combine both old and new emojis, as statuses.emojis + // keeps track of emojis for both old and current edits. + status.EmojiIDs = append(status.EmojiIDs, existing.EmojiIDs...) + status.Emojis = append(status.Emojis, existing.Emojis...) + status.EmojiIDs = xslices.Deduplicate(status.EmojiIDs) + status.Emojis = xslices.DeduplicateFunc(status.Emojis, + func(e *gtsmodel.Emoji) string { return e.ID }, + ) + return true, nil } @@ -1118,10 +1128,10 @@ func (d *Dereferencer) handleStatusEdit( var edited bool // Preallocate max slice length. - cols = make([]string, 0, 13) + cols = make([]string, 1, 13) // Always update `fetched_at`. - cols = append(cols, "fetched_at") + cols[0] = "fetched_at" // Check for edited status content. if existing.Content != status.Content { @@ -1187,6 +1197,13 @@ func (d *Dereferencer) handleStatusEdit( // Attached emojis changed. cols = append(cols, "emojis") // i.e. EmojiIDs + // We specifically store both *new* AND *old* edit + // revision emojis in the statuses.emojis column. + emojiByID := func(e *gtsmodel.Emoji) string { return e.ID } + status.Emojis = append(status.Emojis, existing.Emojis...) + status.Emojis = xslices.DeduplicateFunc(status.Emojis, emojiByID) + status.EmojiIDs = xslices.Gather(status.EmojiIDs[:0], status.Emojis, emojiByID) + // Emojis changed doesn't necessarily // indicate an edit, it may just not have // been previously populated properly. @@ -1230,7 +1247,8 @@ func (d *Dereferencer) handleStatusEdit( // Poll only set if existing contained them. edit.PollOptions = existing.Poll.Options - if !*existing.Poll.HideCounts || pollChanged { + if pollChanged || !*existing.Poll.HideCounts || + !existing.Poll.ClosedAt.IsZero() { // If the counts are allowed to be // shown, or poll has changed, then // include poll vote counts in edit. diff --git a/internal/processing/admin/rule.go b/internal/processing/admin/rule.go index d1ee63cc8..8134c21cd 100644 --- a/internal/processing/admin/rule.go +++ b/internal/processing/admin/rule.go @@ -27,6 +27,7 @@ "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" + "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/util" ) @@ -42,7 +43,7 @@ func (p *Processor) RulesGet( apiRules := make([]*apimodel.AdminInstanceRule, len(rules)) for i := range rules { - apiRules[i] = p.converter.InstanceRuleToAdminAPIRule(&rules[i]) + apiRules[i] = typeutils.InstanceRuleToAdminAPIRule(&rules[i]) } return apiRules, nil @@ -58,7 +59,7 @@ func (p *Processor) RuleGet(ctx context.Context, id string) (*apimodel.AdminInst return nil, gtserror.NewErrorInternalError(err) } - return p.converter.InstanceRuleToAdminAPIRule(rule), nil + return typeutils.InstanceRuleToAdminAPIRule(rule), nil } // RuleCreate adds a new rule to the instance. @@ -77,7 +78,7 @@ func (p *Processor) RuleCreate(ctx context.Context, form *apimodel.InstanceRuleC return nil, gtserror.NewErrorInternalError(err) } - return p.converter.InstanceRuleToAdminAPIRule(rule), nil + return typeutils.InstanceRuleToAdminAPIRule(rule), nil } // RuleUpdate updates text for an existing rule. @@ -99,7 +100,7 @@ func (p *Processor) RuleUpdate(ctx context.Context, id string, form *apimodel.In return nil, gtserror.NewErrorInternalError(err) } - return p.converter.InstanceRuleToAdminAPIRule(updatedRule), nil + return typeutils.InstanceRuleToAdminAPIRule(updatedRule), nil } // RuleDelete deletes an existing rule. @@ -120,5 +121,5 @@ func (p *Processor) RuleDelete(ctx context.Context, id string) (*apimodel.AdminI return nil, gtserror.NewErrorInternalError(err) } - return p.converter.InstanceRuleToAdminAPIRule(deletedRule), nil + return typeutils.InstanceRuleToAdminAPIRule(deletedRule), nil } diff --git a/internal/processing/common/status.go b/internal/processing/common/status.go index da5cf1290..01f2ab72d 100644 --- a/internal/processing/common/status.go +++ b/internal/processing/common/status.go @@ -31,6 +31,40 @@ "github.com/superseriousbusiness/gotosocial/internal/log" ) +// GetOwnStatus fetches the given status with ID, +// and ensures that it belongs to given requester. +func (p *Processor) GetOwnStatus( + ctx context.Context, + requester *gtsmodel.Account, + targetID string, +) ( + *gtsmodel.Status, + gtserror.WithCode, +) { + target, err := p.state.DB.GetStatusByID(ctx, targetID) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + err := gtserror.Newf("error getting from db: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + + switch { + case target == nil: + const text = "target status not found" + return nil, gtserror.NewErrorNotFound( + errors.New(text), + text, + ) + + case target.AccountID != requester.ID: + return nil, gtserror.NewErrorNotFound( + errors.New("status does not belong to requester"), + "target status not found", + ) + } + + return target, nil +} + // GetTargetStatusBy fetches the target status with db load // function, given the authorized (or, nil) requester's // account. This returns an approprate gtserror.WithCode diff --git a/internal/processing/instance.go b/internal/processing/instance.go index fab71b1de..2f4c40416 100644 --- a/internal/processing/instance.go +++ b/internal/processing/instance.go @@ -29,6 +29,7 @@ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/text" + "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/util" "github.com/superseriousbusiness/gotosocial/internal/validate" ) @@ -133,7 +134,7 @@ func (p *Processor) InstanceGetRules(ctx context.Context) ([]apimodel.InstanceRu return nil, gtserror.NewErrorInternalError(fmt.Errorf("db error fetching instance: %s", err)) } - return p.converter.InstanceRulesToAPIRules(i.Rules), nil + return typeutils.InstanceRulesToAPIRules(i.Rules), nil } func (p *Processor) InstancePatch(ctx context.Context, form *apimodel.InstanceSettingsUpdateRequest) (*apimodel.InstanceV1, gtserror.WithCode) { diff --git a/internal/processing/media/create.go b/internal/processing/media/create.go index ca1f1c3c6..5ea630618 100644 --- a/internal/processing/media/create.go +++ b/internal/processing/media/create.go @@ -25,6 +25,7 @@ "codeberg.org/gruf/go-iotools" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -45,10 +46,21 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, form } // Parse focus details from API form input. - focusX, focusY, err := parseFocus(form.Focus) - if err != nil { - text := fmt.Sprintf("could not parse focus value %s: %s", form.Focus, err) - return nil, gtserror.NewErrorBadRequest(errors.New(text), text) + focusX, focusY, errWithCode := apiutil.ParseFocus(form.Focus) + if errWithCode != nil { + return nil, errWithCode + } + + // If description provided, + // process and validate it. + // + // This may not yet be set as it + // is often set on status post. + if form.Description != "" { + form.Description, errWithCode = processDescription(form.Description) + if errWithCode != nil { + return nil, errWithCode + } } // Open multipart file reader. @@ -58,7 +70,7 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, form return nil, gtserror.NewErrorInternalError(err) } - // Wrap the multipart file reader to ensure is limited to max. + // Wrap multipart file reader to ensure is limited to max size. rc, _, _ := iotools.UpdateReadCloserLimit(mpfile, maxszInt64) // Create local media and write to instance storage. diff --git a/internal/processing/media/update.go b/internal/processing/media/update.go index d3a9cfe61..c8592395f 100644 --- a/internal/processing/media/update.go +++ b/internal/processing/media/update.go @@ -23,6 +23,8 @@ "fmt" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" @@ -47,17 +49,27 @@ func (p *Processor) Update(ctx context.Context, account *gtsmodel.Account, media var updatingColumns []string if form.Description != nil { - attachment.Description = text.SanitizeToPlaintext(*form.Description) + // Sanitize and validate incoming description. + description, errWithCode := processDescription( + *form.Description, + ) + if errWithCode != nil { + return nil, errWithCode + } + + attachment.Description = description updatingColumns = append(updatingColumns, "description") } if form.Focus != nil { - focusx, focusy, err := parseFocus(*form.Focus) - if err != nil { - return nil, gtserror.NewErrorBadRequest(err) + // Parse focus details from API form input. + focusX, focusY, errWithCode := apiutil.ParseFocus(*form.Focus) + if errWithCode != nil { + return nil, errWithCode } - attachment.FileMeta.Focus.X = focusx - attachment.FileMeta.Focus.Y = focusy + + attachment.FileMeta.Focus.X = focusX + attachment.FileMeta.Focus.Y = focusY updatingColumns = append(updatingColumns, "focus_x", "focus_y") } @@ -72,3 +84,21 @@ func (p *Processor) Update(ctx context.Context, account *gtsmodel.Account, media return &a, nil } + +// processDescription will sanitize and valid description against server configuration. +func processDescription(description string) (string, gtserror.WithCode) { + description = text.SanitizeToPlaintext(description) + chars := len([]rune(description)) + + if min := config.GetMediaDescriptionMinChars(); chars < min { + text := fmt.Sprintf("media description less than min chars (%d)", min) + return "", gtserror.NewErrorBadRequest(errors.New(text), text) + } + + if max := config.GetMediaDescriptionMaxChars(); chars > max { + text := fmt.Sprintf("media description exceeds max chars (%d)", max) + return "", gtserror.NewErrorBadRequest(errors.New(text), text) + } + + return description, nil +} diff --git a/internal/processing/media/util.go b/internal/processing/media/util.go deleted file mode 100644 index 0ca2697fd..000000000 --- a/internal/processing/media/util.go +++ /dev/null @@ -1,62 +0,0 @@ -// GoToSocial -// Copyright (C) GoToSocial Authors admin@gotosocial.org -// SPDX-License-Identifier: AGPL-3.0-or-later -// -// This program is free software: you can redistribute it and/or modify -// it under the terms of the GNU Affero General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Affero General Public License for more details. -// -// You should have received a copy of the GNU Affero General Public License -// along with this program. If not, see . - -package media - -import ( - "fmt" - "strconv" - "strings" -) - -func parseFocus(focus string) (focusx, focusy float32, err error) { - if focus == "" { - return - } - spl := strings.Split(focus, ",") - if len(spl) != 2 { - err = fmt.Errorf("improperly formatted focus %s", focus) - return - } - xStr := spl[0] - yStr := spl[1] - if xStr == "" || yStr == "" { - err = fmt.Errorf("improperly formatted focus %s", focus) - return - } - fx, err := strconv.ParseFloat(xStr, 32) - if err != nil { - err = fmt.Errorf("improperly formatted focus %s: %s", focus, err) - return - } - if fx > 1 || fx < -1 { - err = fmt.Errorf("improperly formatted focus %s", focus) - return - } - focusx = float32(fx) - fy, err := strconv.ParseFloat(yStr, 32) - if err != nil { - err = fmt.Errorf("improperly formatted focus %s: %s", focus, err) - return - } - if fy > 1 || fy < -1 { - err = fmt.Errorf("improperly formatted focus %s", focus) - return - } - focusy = float32(fy) - return -} diff --git a/internal/processing/status/common.go b/internal/processing/status/common.go new file mode 100644 index 000000000..3f2b7b6cb --- /dev/null +++ b/internal/processing/status/common.go @@ -0,0 +1,351 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package status + +import ( + "context" + "errors" + "fmt" + "time" + + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/config" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/id" + "github.com/superseriousbusiness/gotosocial/internal/text" + "github.com/superseriousbusiness/gotosocial/internal/util/xslices" + "github.com/superseriousbusiness/gotosocial/internal/validate" +) + +// validateStatusContent will validate the common +// content fields across status write endpoints against +// current server configuration (e.g. max char counts). +func validateStatusContent( + status string, + spoiler string, + mediaIDs []string, + poll *apimodel.PollRequest, +) gtserror.WithCode { + totalChars := len([]rune(status)) + + len([]rune(spoiler)) + + if totalChars == 0 && len(mediaIDs) == 0 && poll == nil { + const text = "status contains no text, media or poll" + return gtserror.NewErrorBadRequest(errors.New(text), text) + } + + if max := config.GetStatusesMaxChars(); totalChars > max { + text := fmt.Sprintf("text with spoiler exceed max chars (%d)", max) + return gtserror.NewErrorBadRequest(errors.New(text), text) + } + + if max := config.GetStatusesMediaMaxFiles(); len(mediaIDs) > max { + text := fmt.Sprintf("media files exceed max count (%d)", max) + return gtserror.NewErrorBadRequest(errors.New(text), text) + } + + if poll != nil { + switch max := config.GetStatusesPollMaxOptions(); { + case len(poll.Options) == 0: + const text = "poll cannot have no options" + return gtserror.NewErrorBadRequest(errors.New(text), text) + + case len(poll.Options) > max: + text := fmt.Sprintf("poll options exceed max count (%d)", max) + return gtserror.NewErrorBadRequest(errors.New(text), text) + } + + max := config.GetStatusesPollOptionMaxChars() + for i, option := range poll.Options { + switch l := len([]rune(option)); { + case l == 0: + const text = "poll option cannot be empty" + return gtserror.NewErrorBadRequest(errors.New(text), text) + + case l > max: + text := fmt.Sprintf("poll option %d exceed max chars (%d)", i, max) + return gtserror.NewErrorBadRequest(errors.New(text), text) + } + } + } + + return nil +} + +// statusContent encompasses the set of common processed +// status content fields from status write operations for +// an easily returnable type, without needing to allocate +// an entire gtsmodel.Status{} model. +type statusContent struct { + Content string + ContentWarning string + PollOptions []string + Language string + MentionIDs []string + Mentions []*gtsmodel.Mention + EmojiIDs []string + Emojis []*gtsmodel.Emoji + TagIDs []string + Tags []*gtsmodel.Tag +} + +func (p *Processor) processContent( + ctx context.Context, + author *gtsmodel.Account, + statusID string, + contentType string, + content string, + contentWarning string, + language string, + poll *apimodel.PollRequest, +) ( + *statusContent, + gtserror.WithCode, +) { + if language == "" { + // Ensure we have a status language. + language = author.Settings.Language + if language == "" { + const text = "account default language unset" + return nil, gtserror.NewErrorInternalError( + errors.New(text), + ) + } + } + + var err error + + // Validate + normalize determined language. + language, err = validate.Language(language) + if err != nil { + text := fmt.Sprintf("invalid language tag: %v", err) + return nil, gtserror.NewErrorBadRequest( + errors.New(text), + text, + ) + } + + // format is the currently set text formatting + // function, according to the provided content-type. + var format text.FormatFunc + + if contentType == "" { + // If content type wasn't specified, use + // the author's preferred content-type. + contentType = author.Settings.StatusContentType + } + + switch contentType { + + // Format status according to text/plain. + case "", string(apimodel.StatusContentTypePlain): + format = p.formatter.FromPlain + + // Format status according to text/markdown. + case string(apimodel.StatusContentTypeMarkdown): + format = p.formatter.FromMarkdown + + // Unknown. + default: + const text = "invalid status format" + return nil, gtserror.NewErrorBadRequest( + errors.New(text), + text, + ) + } + + // Allocate a structure to hold the + // majority of formatted content without + // needing to alloc a whole gtsmodel.Status{}. + var status statusContent + status.Language = language + + // formatInput is a shorthand function to format the given input string with the + // currently set 'formatFunc', passing in all required args and returning result. + formatInput := func(formatFunc text.FormatFunc, input string) *text.FormatResult { + return formatFunc(ctx, p.parseMention, author.ID, statusID, input) + } + + // Sanitize input status text and format. + contentRes := formatInput(format, content) + + // Gather results of formatted. + status.Content = contentRes.HTML + status.Mentions = contentRes.Mentions + status.Emojis = contentRes.Emojis + status.Tags = contentRes.Tags + + // From here-on-out just use emoji-only + // plain-text formatting as the FormatFunc. + format = p.formatter.FromPlainEmojiOnly + + // Sanitize content warning and format. + warning := text.SanitizeToPlaintext(contentWarning) + warningRes := formatInput(format, warning) + + // Gather results of the formatted. + status.ContentWarning = warningRes.HTML + status.Emojis = append(status.Emojis, warningRes.Emojis...) + + if poll != nil { + // Pre-allocate slice of poll options of expected length. + status.PollOptions = make([]string, len(poll.Options)) + for i, option := range poll.Options { + + // Sanitize each poll option and format. + option = text.SanitizeToPlaintext(option) + optionRes := formatInput(format, option) + + // Gather results of the formatted. + status.PollOptions[i] = optionRes.HTML + status.Emojis = append(status.Emojis, optionRes.Emojis...) + } + + // Also update options on the form. + poll.Options = status.PollOptions + } + + // We may have received multiple copies of the same emoji, deduplicate these first. + status.Emojis = xslices.DeduplicateFunc(status.Emojis, func(e *gtsmodel.Emoji) string { + return e.ID + }) + + // Gather up the IDs of mentions from parsed content. + status.MentionIDs = xslices.Gather(nil, status.Mentions, + func(m *gtsmodel.Mention) string { + return m.ID + }, + ) + + // Gather up the IDs of tags from parsed content. + status.TagIDs = xslices.Gather(nil, status.Tags, + func(t *gtsmodel.Tag) string { + return t.ID + }, + ) + + // Gather up the IDs of emojis in updated content. + status.EmojiIDs = xslices.Gather(nil, status.Emojis, + func(e *gtsmodel.Emoji) string { + return e.ID + }, + ) + + return &status, nil +} + +func (p *Processor) processMedia( + ctx context.Context, + authorID string, + statusID string, + mediaIDs []string, +) ( + []*gtsmodel.MediaAttachment, + gtserror.WithCode, +) { + // No media provided! + if len(mediaIDs) == 0 { + return nil, nil + } + + // Get configured min/max supported descr chars. + minChars := config.GetMediaDescriptionMinChars() + maxChars := config.GetMediaDescriptionMaxChars() + + // Pre-allocate slice of media attachments of expected length. + attachments := make([]*gtsmodel.MediaAttachment, len(mediaIDs)) + for i, id := range mediaIDs { + + // Look for media attachment by ID in database. + media, err := p.state.DB.GetAttachmentByID(ctx, id) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + err := gtserror.Newf("error getting media from db: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + + // Check media exists and is owned by author + // (this masks finding out media ownership info). + if media == nil || media.AccountID != authorID { + text := fmt.Sprintf("media not found: %s", id) + return nil, gtserror.NewErrorBadRequest(errors.New(text), text) + } + + // Check media isn't already attached to another status. + if (media.StatusID != "" && media.StatusID != statusID) || + (media.ScheduledStatusID != "" && media.ScheduledStatusID != statusID) { + text := fmt.Sprintf("media already attached to status: %s", id) + return nil, gtserror.NewErrorBadRequest(errors.New(text), text) + } + + // Check media description chars within range, + // this needs to be done here as lots of clients + // only update media description on status post. + switch chars := len([]rune(media.Description)); { + case chars < minChars: + text := fmt.Sprintf("media description less than min chars (%d)", minChars) + return nil, gtserror.NewErrorBadRequest(errors.New(text), text) + + case chars > maxChars: + text := fmt.Sprintf("media description exceeds max chars (%d)", maxChars) + return nil, gtserror.NewErrorBadRequest(errors.New(text), text) + } + + // Set media at index. + attachments[i] = media + } + + return attachments, nil +} + +func (p *Processor) processPoll( + ctx context.Context, + statusID string, + form *apimodel.PollRequest, + now time.Time, // used for expiry time +) ( + *gtsmodel.Poll, + gtserror.WithCode, +) { + var expiresAt time.Time + + // Set an expiry time if one given. + if in := form.ExpiresIn; in > 0 { + expiresIn := time.Duration(in) + expiresAt = now.Add(expiresIn * time.Second) + } + + // Create new poll model. + poll := >smodel.Poll{ + ID: id.NewULIDFromTime(now), + Multiple: &form.Multiple, + HideCounts: &form.HideTotals, + Options: form.Options, + StatusID: statusID, + ExpiresAt: expiresAt, + } + + // Insert the newly created poll model in the database. + if err := p.state.DB.PutPoll(ctx, poll); err != nil { + err := gtserror.Newf("error inserting poll in db: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + + return poll, nil +} diff --git a/internal/processing/status/create.go b/internal/processing/status/create.go index 340cf9ff3..af9831b9c 100644 --- a/internal/processing/status/create.go +++ b/internal/processing/status/create.go @@ -19,29 +19,22 @@ import ( "context" - "errors" - "fmt" "time" "github.com/superseriousbusiness/gotosocial/internal/ap" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" - "github.com/superseriousbusiness/gotosocial/internal/config" - "github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/messages" - "github.com/superseriousbusiness/gotosocial/internal/text" "github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/uris" "github.com/superseriousbusiness/gotosocial/internal/util" - "github.com/superseriousbusiness/gotosocial/internal/util/xslices" ) // Create processes the given form to create a new status, returning the api model representation of that status if it's OK. -// -// Precondition: the form's fields should have already been validated and normalized by the caller. +// Note this also handles validation of incoming form field data. func (p *Processor) Create( ctx context.Context, requester *gtsmodel.Account, @@ -51,7 +44,17 @@ func (p *Processor) Create( *apimodel.Status, gtserror.WithCode, ) { - // Ensure account populated; we'll need settings. + // Validate incoming form status content. + if errWithCode := validateStatusContent( + form.Status, + form.SpoilerText, + form.MediaIDs, + form.Poll, + ); errWithCode != nil { + return nil, errWithCode + } + + // Ensure account populated; we'll need their settings. if err := p.state.DB.PopulateAccount(ctx, requester); err != nil { log.Errorf(ctx, "error(s) populating account, will continue: %s", err) } @@ -59,6 +62,30 @@ func (p *Processor) Create( // Generate new ID for status. statusID := id.NewULID() + // Process incoming status content fields. + content, errWithCode := p.processContent(ctx, + requester, + statusID, + string(form.ContentType), + form.Status, + form.SpoilerText, + form.Language, + form.Poll, + ) + if errWithCode != nil { + return nil, errWithCode + } + + // Process incoming status attachments. + media, errWithCode := p.processMedia(ctx, + requester.ID, + statusID, + form.MediaIDs, + ) + if errWithCode != nil { + return nil, errWithCode + } + // Generate necessary URIs for username, to build status URIs. accountURIs := uris.GenerateURIsForAccount(requester.Username) @@ -78,16 +105,36 @@ func (p *Processor) Create( ActivityStreamsType: ap.ObjectNote, Sensitive: &form.Sensitive, CreatedWithApplicationID: application.ID, - Text: form.Status, + + // Set validated language. + Language: content.Language, + + // Set formatted status content. + Content: content.Content, + ContentWarning: content.ContentWarning, + Text: form.Status, // raw + + // Set gathered mentions. + MentionIDs: content.MentionIDs, + Mentions: content.Mentions, + + // Set gathered emojis. + EmojiIDs: content.EmojiIDs, + Emojis: content.Emojis, + + // Set gathered tags. + TagIDs: content.TagIDs, + Tags: content.Tags, + + // Set gathered media. + AttachmentIDs: form.MediaIDs, + Attachments: media, // Assume not pending approval; this may // change when permissivity is checked. PendingApproval: util.Ptr(false), } - // Process any attached poll. - p.processPoll(status, form.Poll) - // Check + attach in-reply-to status. if errWithCode := p.processInReplyTo(ctx, requester, @@ -101,10 +148,6 @@ func (p *Processor) Create( return nil, errWithCode } - if errWithCode := p.processMediaIDs(ctx, form, requester.ID, status); errWithCode != nil { - return nil, errWithCode - } - if err := p.processVisibility(ctx, form, requester.Settings.Privacy, status); err != nil { return nil, gtserror.NewErrorInternalError(err) } @@ -115,36 +158,49 @@ func (p *Processor) Create( return nil, errWithCode } - if err := processLanguage(form, requester.Settings.Language, status); err != nil { - return nil, gtserror.NewErrorInternalError(err) + if status.ContentWarning != "" && len(status.AttachmentIDs) > 0 { + // If a content-warning is set, and + // the status contains media, always + // set the status sensitive flag. + status.Sensitive = util.Ptr(true) } - if err := p.processContent(ctx, p.parseMention, form, status); err != nil { - return nil, gtserror.NewErrorInternalError(err) - } - - if status.Poll != nil { - // Try to insert the new status poll in the database. - if err := p.state.DB.PutPoll(ctx, status.Poll); err != nil { - err := gtserror.Newf("error inserting poll in db: %w", err) - return nil, gtserror.NewErrorInternalError(err) + if form.Poll != nil { + // Process poll, inserting into database. + poll, errWithCode := p.processPoll(ctx, + statusID, + form.Poll, + now, + ) + if errWithCode != nil { + return nil, errWithCode } + + // Set poll and its ID + // on status before insert. + status.PollID = poll.ID + status.Poll = poll + poll.Status = status + + // Update the status' ActivityPub type to Question. + status.ActivityStreamsType = ap.ActivityQuestion } - // Insert this new status in the database. + // Insert this newly prepared status into the database. if err := p.state.DB.PutStatus(ctx, status); err != nil { + err := gtserror.Newf("error inserting status in db: %w", err) return nil, gtserror.NewErrorInternalError(err) } if status.Poll != nil && !status.Poll.ExpiresAt.IsZero() { - // Now that the status is inserted, and side effects queued, - // attempt to schedule an expiry handler for the status poll. + // Now that the status is inserted, attempt to + // schedule an expiry handler for the status poll. if err := p.polls.ScheduleExpiry(ctx, status.Poll); err != nil { log.Errorf(ctx, "error scheduling poll expiry: %v", err) } } - // send it back to the client API worker for async side-effects. + // Send it to the client API worker for async side-effects. p.state.Workers.Client.Queue.Push(&messages.FromClientAPI{ APObjectType: ap.ObjectNote, APActivityType: ap.ActivityCreate, @@ -172,43 +228,6 @@ func (p *Processor) Create( return p.c.GetAPIStatus(ctx, requester, status) } -func (p *Processor) processPoll(status *gtsmodel.Status, poll *apimodel.PollRequest) { - if poll == nil { - // No poll set. - // Nothing to do. - return - } - - var expiresAt time.Time - - // Now will have been set - // as the status creation. - now := status.CreatedAt - - // Update the status AS type to "Question". - status.ActivityStreamsType = ap.ActivityQuestion - - // Set an expiry time if one given. - if in := poll.ExpiresIn; in > 0 { - expiresIn := time.Duration(in) - expiresAt = now.Add(expiresIn * time.Second) - } - - // Create new poll for status. - status.Poll = >smodel.Poll{ - ID: id.NewULID(), - Multiple: &poll.Multiple, - HideCounts: &poll.HideTotals, - Options: poll.Options, - StatusID: status.ID, - Status: status, - ExpiresAt: expiresAt, - } - - // Set poll ID on the status. - status.PollID = status.Poll.ID -} - func (p *Processor) processInReplyTo(ctx context.Context, requester *gtsmodel.Account, status *gtsmodel.Status, inReplyToID string) gtserror.WithCode { if inReplyToID == "" { // Not a reply. @@ -332,53 +351,6 @@ func (p *Processor) processThreadID(ctx context.Context, status *gtsmodel.Status return nil } -func (p *Processor) processMediaIDs(ctx context.Context, form *apimodel.StatusCreateRequest, thisAccountID string, status *gtsmodel.Status) gtserror.WithCode { - if form.MediaIDs == nil { - return nil - } - - // Get minimum allowed char descriptions. - minChars := config.GetMediaDescriptionMinChars() - - attachments := []*gtsmodel.MediaAttachment{} - attachmentIDs := []string{} - - for _, mediaID := range form.MediaIDs { - attachment, err := p.state.DB.GetAttachmentByID(ctx, mediaID) - if err != nil && !errors.Is(err, db.ErrNoEntries) { - err := gtserror.Newf("error fetching media from db: %w", err) - return gtserror.NewErrorInternalError(err) - } - - if attachment == nil { - text := fmt.Sprintf("media %s not found", mediaID) - return gtserror.NewErrorBadRequest(errors.New(text), text) - } - - if attachment.AccountID != thisAccountID { - text := fmt.Sprintf("media %s does not belong to account", mediaID) - return gtserror.NewErrorBadRequest(errors.New(text), text) - } - - if attachment.StatusID != "" || attachment.ScheduledStatusID != "" { - text := fmt.Sprintf("media %s already attached to status", mediaID) - return gtserror.NewErrorBadRequest(errors.New(text), text) - } - - if length := len([]rune(attachment.Description)); length < minChars { - text := fmt.Sprintf("media %s description too short, at least %d required", mediaID, minChars) - return gtserror.NewErrorBadRequest(errors.New(text), text) - } - - attachments = append(attachments, attachment) - attachmentIDs = append(attachmentIDs, attachment.ID) - } - - status.Attachments = attachments - status.AttachmentIDs = attachmentIDs - return nil -} - func (p *Processor) processVisibility( ctx context.Context, form *apimodel.StatusCreateRequest, @@ -474,99 +446,3 @@ func processInteractionPolicy( // setting it explicitly to save space. return nil } - -func processLanguage(form *apimodel.StatusCreateRequest, accountDefaultLanguage string, status *gtsmodel.Status) error { - if form.Language != "" { - status.Language = form.Language - } else { - status.Language = accountDefaultLanguage - } - if status.Language == "" { - return errors.New("no language given either in status create form or account default") - } - return nil -} - -func (p *Processor) processContent(ctx context.Context, parseMention gtsmodel.ParseMentionFunc, form *apimodel.StatusCreateRequest, status *gtsmodel.Status) error { - if form.ContentType == "" { - // If content type wasn't specified, use the author's preferred content-type. - contentType := apimodel.StatusContentType(status.Account.Settings.StatusContentType) - form.ContentType = contentType - } - - // format is the currently set text formatting - // function, according to the provided content-type. - var format text.FormatFunc - - // formatInput is a shorthand function to format the given input string with the - // currently set 'formatFunc', passing in all required args and returning result. - formatInput := func(formatFunc text.FormatFunc, input string) *text.FormatResult { - return formatFunc(ctx, parseMention, status.AccountID, status.ID, input) - } - - switch form.ContentType { - // None given / set, - // use default (plain). - case "": - fallthrough - - // Format status according to text/plain. - case apimodel.StatusContentTypePlain: - format = p.formatter.FromPlain - - // Format status according to text/markdown. - case apimodel.StatusContentTypeMarkdown: - format = p.formatter.FromMarkdown - - // Unknown. - default: - return fmt.Errorf("invalid status format: %q", form.ContentType) - } - - // Sanitize status text and format. - contentRes := formatInput(format, form.Status) - - // Collect formatted results. - status.Content = contentRes.HTML - status.Mentions = append(status.Mentions, contentRes.Mentions...) - status.Emojis = append(status.Emojis, contentRes.Emojis...) - status.Tags = append(status.Tags, contentRes.Tags...) - - // From here-on-out just use emoji-only - // plain-text formatting as the FormatFunc. - format = p.formatter.FromPlainEmojiOnly - - // Sanitize content warning and format. - spoiler := text.SanitizeToPlaintext(form.SpoilerText) - warningRes := formatInput(format, spoiler) - - // Collect formatted results. - status.ContentWarning = warningRes.HTML - status.Emojis = append(status.Emojis, warningRes.Emojis...) - - if status.Poll != nil { - for i := range status.Poll.Options { - // Sanitize each option title name and format. - option := text.SanitizeToPlaintext(status.Poll.Options[i]) - optionRes := formatInput(format, option) - - // Collect each formatted result. - status.Poll.Options[i] = optionRes.HTML - status.Emojis = append(status.Emojis, optionRes.Emojis...) - } - } - - // Gather all the database IDs from each of the gathered status mentions, tags, and emojis. - status.MentionIDs = xslices.Gather(nil, status.Mentions, func(mention *gtsmodel.Mention) string { return mention.ID }) - status.TagIDs = xslices.Gather(nil, status.Tags, func(tag *gtsmodel.Tag) string { return tag.ID }) - status.EmojiIDs = xslices.Gather(nil, status.Emojis, func(emoji *gtsmodel.Emoji) string { return emoji.ID }) - - if status.ContentWarning != "" && len(status.AttachmentIDs) > 0 { - // If a content-warning is set, and - // the status contains media, always - // set the status sensitive flag. - status.Sensitive = util.Ptr(true) - } - - return nil -} diff --git a/internal/processing/status/create_test.go b/internal/processing/status/create_test.go index 84168880e..d0a5c7f92 100644 --- a/internal/processing/status/create_test.go +++ b/internal/processing/status/create_test.go @@ -170,7 +170,7 @@ func (suite *StatusCreateTestSuite) TestProcessMediaDescriptionTooShort() { } apiStatus, err := suite.status.Create(ctx, creatingAccount, creatingApplication, statusCreateForm) - suite.EqualError(err, "media 01F8MH8RMYQ6MSNY3JM2XT1CQ5 description too short, at least 100 required") + suite.EqualError(err, "media description less than min chars (100)") suite.Nil(apiStatus) } diff --git a/internal/processing/status/edit.go b/internal/processing/status/edit.go new file mode 100644 index 000000000..d16092a57 --- /dev/null +++ b/internal/processing/status/edit.go @@ -0,0 +1,555 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package status + +import ( + "context" + "errors" + "fmt" + "slices" + "time" + + "github.com/superseriousbusiness/gotosocial/internal/ap" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" + "github.com/superseriousbusiness/gotosocial/internal/db" + "github.com/superseriousbusiness/gotosocial/internal/gtserror" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/id" + "github.com/superseriousbusiness/gotosocial/internal/log" + "github.com/superseriousbusiness/gotosocial/internal/messages" + "github.com/superseriousbusiness/gotosocial/internal/util/xslices" +) + +// Edit ... +func (p *Processor) Edit( + ctx context.Context, + requester *gtsmodel.Account, + statusID string, + form *apimodel.StatusEditRequest, +) ( + *apimodel.Status, + gtserror.WithCode, +) { + // Fetch status and ensure it's owned by requesting account. + status, errWithCode := p.c.GetOwnStatus(ctx, requester, statusID) + if errWithCode != nil { + return nil, errWithCode + } + + // Ensure this isn't a boost. + if status.BoostOfID != "" { + return nil, gtserror.NewErrorNotFound( + errors.New("status is a boost wrapper"), + "target status not found", + ) + } + + // Ensure account populated; we'll need their settings. + if err := p.state.DB.PopulateAccount(ctx, requester); err != nil { + log.Errorf(ctx, "error(s) populating account, will continue: %s", err) + } + + // We need the status populated including all historical edits. + if err := p.state.DB.PopulateStatusEdits(ctx, status); err != nil { + err := gtserror.Newf("error getting status edits from db: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + + // Time of edit. + now := time.Now() + + // Validate incoming form edit content. + if errWithCode := validateStatusContent( + form.Status, + form.SpoilerText, + form.MediaIDs, + form.Poll, + ); errWithCode != nil { + return nil, errWithCode + } + + // Process incoming status edit content fields. + content, errWithCode := p.processContent(ctx, + requester, + statusID, + string(form.ContentType), + form.Status, + form.SpoilerText, + form.Language, + form.Poll, + ) + if errWithCode != nil { + return nil, errWithCode + } + + // Process new status attachments to use. + media, errWithCode := p.processMedia(ctx, + requester.ID, + statusID, + form.MediaIDs, + ) + if errWithCode != nil { + return nil, errWithCode + } + + // Process incoming edits of any attached media. + mediaEdited, errWithCode := p.processMediaEdits(ctx, + media, + form.MediaAttributes, + ) + if errWithCode != nil { + return nil, errWithCode + } + + // Process incoming edits of any attached status poll. + poll, pollEdited, errWithCode := p.processPollEdit(ctx, + statusID, + status.Poll, + form.Poll, + now, + ) + if errWithCode != nil { + return nil, errWithCode + } + + // Check if new status poll was set. + pollChanged := (poll != status.Poll) + + // Determine whether there were any changes possibly + // causing a change to embedded mentions, tags, emojis. + contentChanged := (status.Content != content.Content) + warningChanged := (status.ContentWarning != content.ContentWarning) + languageChanged := (status.Language != content.Language) + anyContentChanged := contentChanged || warningChanged || + pollEdited // encapsulates pollChanged too + + // Check if status media attachments have changed. + mediaChanged := !slices.Equal(status.AttachmentIDs, + form.MediaIDs, + ) + + // Track status columns we + // need to update in database. + cols := make([]string, 2, 13) + cols[0] = "updated_at" + cols[1] = "edits" + + if contentChanged { + // Update status text. + // + // Note we don't update these + // status fields right away so + // we can save current version. + cols = append(cols, "content") + cols = append(cols, "text") + } + + if warningChanged { + // Update status content warning. + // + // Note we don't update these + // status fields right away so + // we can save current version. + cols = append(cols, "content_warning") + } + + if languageChanged { + // Update status language pref. + // + // Note we don't update these + // status fields right away so + // we can save current version. + cols = append(cols, "language") + } + + if *status.Sensitive != form.Sensitive { + // Update status sensitivity pref. + // + // Note we don't update these + // status fields right away so + // we can save current version. + cols = append(cols, "sensitive") + } + + if mediaChanged { + // Updated status media attachments. + // + // Note we don't update these + // status fields right away so + // we can save current version. + cols = append(cols, "attachments") + } + + if pollChanged { + // Updated attached status poll. + // + // Note we don't update these + // status fields right away so + // we can save current version. + cols = append(cols, "poll_id") + + if status.Poll == nil || poll == nil { + // Went from with-poll to without-poll + // or vice-versa. This changes AP type. + cols = append(cols, "activity_streams_type") + } + } + + if anyContentChanged { + if !slices.Equal(status.MentionIDs, content.MentionIDs) { + // Update attached status mentions. + cols = append(cols, "mentions") + status.MentionIDs = content.MentionIDs + status.Mentions = content.Mentions + } + + if !slices.Equal(status.TagIDs, content.TagIDs) { + // Updated attached status tags. + cols = append(cols, "tags") + status.TagIDs = content.TagIDs + status.Tags = content.Tags + } + + if !slices.Equal(status.EmojiIDs, content.EmojiIDs) { + // We specifically store both *new* AND *old* edit + // revision emojis in the statuses.emojis column. + emojiByID := func(e *gtsmodel.Emoji) string { return e.ID } + status.Emojis = append(status.Emojis, content.Emojis...) + status.Emojis = xslices.DeduplicateFunc(status.Emojis, emojiByID) + status.EmojiIDs = xslices.Gather(status.EmojiIDs[:0], status.Emojis, emojiByID) + + // Update attached status emojis. + cols = append(cols, "emojis") + } + } + + // If no status columns were updated, no media and + // no poll were edited, there's nothing to do! + if len(cols) == 2 && !mediaEdited && !pollEdited { + const text = "status was not changed" + return nil, gtserror.NewErrorUnprocessableEntity( + errors.New(text), + text, + ) + } + + // Create an edit to store a + // historical snapshot of status. + var edit gtsmodel.StatusEdit + edit.ID = id.NewULIDFromTime(now) + edit.Content = status.Content + edit.ContentWarning = status.ContentWarning + edit.Text = status.Text + edit.Language = status.Language + edit.Sensitive = status.Sensitive + edit.StatusID = status.ID + edit.CreatedAt = status.UpdatedAt + + // Copy existing media and descriptions. + edit.AttachmentIDs = status.AttachmentIDs + if l := len(status.Attachments); l > 0 { + edit.AttachmentDescriptions = make([]string, l) + for i, attach := range status.Attachments { + edit.AttachmentDescriptions[i] = attach.Description + } + } + + if status.Poll != nil { + // Poll only set if existed previously. + edit.PollOptions = status.Poll.Options + + if pollChanged || !*status.Poll.HideCounts || + !status.Poll.ClosedAt.IsZero() { + // If the counts are allowed to be + // shown, or poll has changed, then + // include poll vote counts in edit. + edit.PollVotes = status.Poll.Votes + } + } + + // Insert this new edit of existing status into database. + if err := p.state.DB.PutStatusEdit(ctx, &edit); err != nil { + err := gtserror.Newf("error putting edit in database: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + + // Add edit to list of edits on the status. + status.EditIDs = append(status.EditIDs, edit.ID) + status.Edits = append(status.Edits, &edit) + + // Now historical status data is stored, + // update the other necessary status fields. + status.Content = content.Content + status.ContentWarning = content.ContentWarning + status.Text = form.Status + status.Language = content.Language + status.Sensitive = &form.Sensitive + status.AttachmentIDs = form.MediaIDs + status.Attachments = media + status.UpdatedAt = now + + if poll != nil { + // Set relevent fields for latest with poll. + status.ActivityStreamsType = ap.ActivityQuestion + status.PollID = poll.ID + status.Poll = poll + } else { + // Set relevant fields for latest without poll. + status.ActivityStreamsType = ap.ObjectNote + status.PollID = "" + status.Poll = nil + } + + // Finally update the existing status model in the database. + if err := p.state.DB.UpdateStatus(ctx, status, cols...); err != nil { + err := gtserror.Newf("error updating status in db: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + + if pollChanged && status.Poll != nil && !status.Poll.ExpiresAt.IsZero() { + // Now the status is updated, attempt to schedule + // an expiry handler for the changed status poll. + if err := p.polls.ScheduleExpiry(ctx, status.Poll); err != nil { + log.Errorf(ctx, "error scheduling poll expiry: %v", err) + } + } + + // Send it to the client API worker for async side-effects. + p.state.Workers.Client.Queue.Push(&messages.FromClientAPI{ + APObjectType: ap.ObjectNote, + APActivityType: ap.ActivityUpdate, + GTSModel: status, + Origin: requester, + }) + + // Return an API model of the updated status. + return p.c.GetAPIStatus(ctx, requester, status) +} + +// HistoryGet gets edit history for the target status, taking account of privacy settings and blocks etc. +func (p *Processor) HistoryGet(ctx context.Context, requester *gtsmodel.Account, targetStatusID string) ([]*apimodel.StatusEdit, gtserror.WithCode) { + target, errWithCode := p.c.GetVisibleTargetStatus(ctx, + requester, + targetStatusID, + nil, // default freshness + ) + if errWithCode != nil { + return nil, errWithCode + } + + if err := p.state.DB.PopulateStatusEdits(ctx, target); err != nil { + err := gtserror.Newf("error getting status edits from db: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + + edits, err := p.converter.StatusToAPIEdits(ctx, target) + if err != nil { + err := gtserror.Newf("error converting status edits: %w", err) + return nil, gtserror.NewErrorInternalError(err) + } + + return edits, nil +} + +func (p *Processor) processMediaEdits( + ctx context.Context, + attachs []*gtsmodel.MediaAttachment, + attrs []apimodel.AttachmentAttributesRequest, +) ( + bool, + gtserror.WithCode, +) { + var edited bool + + for _, attr := range attrs { + // Search the media attachments slice for index of media with attr.ID. + i := slices.IndexFunc(attachs, func(m *gtsmodel.MediaAttachment) bool { + return m.ID == attr.ID + }) + if i == -1 { + text := fmt.Sprintf("media not found: %s", attr.ID) + return false, gtserror.NewErrorBadRequest(errors.New(text), text) + } + + // Get attach at index. + attach := attachs[i] + + // Track which columns need + // updating in database query. + cols := make([]string, 0, 2) + + // Check for description change. + if attr.Description != attach.Description { + attach.Description = attr.Description + cols = append(cols, "description") + } + + if attr.Focus != "" { + // Parse provided media focus parameters from string. + fx, fy, errWithCode := apiutil.ParseFocus(attr.Focus) + if errWithCode != nil { + return false, errWithCode + } + + // Check for change in focus coords. + if attach.FileMeta.Focus.X != fx || + attach.FileMeta.Focus.Y != fy { + attach.FileMeta.Focus.X = fx + attach.FileMeta.Focus.Y = fy + cols = append(cols, "focus_x", "focus_y") + } + } + + if len(cols) > 0 { + // Media attachment was changed, update this in database. + err := p.state.DB.UpdateAttachment(ctx, attach, cols...) + if err != nil { + err := gtserror.Newf("error updating attachment in db: %w", err) + return false, gtserror.NewErrorInternalError(err) + } + + // Set edited. + edited = true + } + } + + return edited, nil +} + +func (p *Processor) processPollEdit( + ctx context.Context, + statusID string, + original *gtsmodel.Poll, + form *apimodel.PollRequest, + now time.Time, // used for expiry time +) ( + *gtsmodel.Poll, + bool, + gtserror.WithCode, +) { + if form == nil { + if original != nil { + // No poll was given but there's an existing poll, + // this indicates the original needs to be deleted. + if err := p.deletePoll(ctx, original); err != nil { + return nil, true, gtserror.NewErrorInternalError(err) + } + + // Existing was deleted. + return nil, true, nil + } + + // No change in poll. + return nil, false, nil + } + + switch { + // No existing poll. + case original == nil: + + // Any change that effects voting, i.e. options, allow multiple + // or re-opening a closed poll requires deleting the existing poll. + case !slices.Equal(form.Options, original.Options) || + (form.Multiple != *original.Multiple) || + (!original.ClosedAt.IsZero() && form.ExpiresIn != 0): + if err := p.deletePoll(ctx, original); err != nil { + return nil, true, gtserror.NewErrorInternalError(err) + } + + // Any other changes only require a model + // update, and at-most a new expiry handler. + default: + var cols []string + + // Check if the hide counts field changed. + if form.HideTotals != *original.HideCounts { + cols = append(cols, "hide_counts") + original.HideCounts = &form.HideTotals + } + + var expiresAt time.Time + + // Determine expiry time if given. + if in := form.ExpiresIn; in > 0 { + expiresIn := time.Duration(in) + expiresAt = now.Add(expiresIn * time.Second) + } + + // Check for expiry time. + if !expiresAt.IsZero() { + + if !original.ExpiresAt.IsZero() { + // Existing had expiry, cancel scheduled handler. + _ = p.state.Workers.Scheduler.Cancel(original.ID) + } + + // Since expiry is given as a duration + // we always treat > 0 as a change as + // we can't know otherwise unfortunately. + cols = append(cols, "expires_at") + original.ExpiresAt = expiresAt + } + + if len(cols) == 0 { + // Were no changes to poll. + return original, false, nil + } + + // Update the original poll model in the database with these columns. + if err := p.state.DB.UpdatePoll(ctx, original, cols...); err != nil { + err := gtserror.Newf("error updating poll.expires_at in db: %w", err) + return nil, true, gtserror.NewErrorInternalError(err) + } + + if !expiresAt.IsZero() { + // Updated poll has an expiry, schedule a new expiry handler. + if err := p.polls.ScheduleExpiry(ctx, original); err != nil { + log.Errorf(ctx, "error scheduling poll expiry: %v", err) + } + } + + // Existing poll was updated. + return original, true, nil + } + + // If we reached here then an entirely + // new status poll needs to be created. + poll, errWithCode := p.processPoll(ctx, + statusID, + form, + now, + ) + return poll, true, errWithCode +} + +func (p *Processor) deletePoll(ctx context.Context, poll *gtsmodel.Poll) error { + if !poll.ExpiresAt.IsZero() && !poll.ClosedAt.IsZero() { + // Poll has an expiry and has not yet closed, + // cancel any expiry handler before deletion. + _ = p.state.Workers.Scheduler.Cancel(poll.ID) + } + + // Delete the given poll from the database. + err := p.state.DB.DeletePollByID(ctx, poll.ID) + if err != nil && !errors.Is(err, db.ErrNoEntries) { + return gtserror.Newf("error deleting poll from db: %w", err) + } + + return nil +} diff --git a/internal/processing/status/edit_test.go b/internal/processing/status/edit_test.go new file mode 100644 index 000000000..393c3efc2 --- /dev/null +++ b/internal/processing/status/edit_test.go @@ -0,0 +1,544 @@ +// GoToSocial +// Copyright (C) GoToSocial Authors admin@gotosocial.org +// SPDX-License-Identifier: AGPL-3.0-or-later +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package status_test + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/suite" + apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" + "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/util" + "github.com/superseriousbusiness/gotosocial/internal/util/xslices" +) + +type StatusEditTestSuite struct { + StatusStandardTestSuite +} + +func (suite *StatusEditTestSuite) TestSimpleEdit() { + // Create cancellable context to use for test. + ctx, cncl := context.WithCancel(context.Background()) + defer cncl() + + // Get a local account to use as test requester. + requester := suite.testAccounts["local_account_1"] + requester, _ = suite.state.DB.GetAccountByID(ctx, requester.ID) + + // Get requester's existing status to perform an edit on. + status := suite.testStatuses["local_account_1_status_9"] + status, _ = suite.state.DB.GetStatusByID(ctx, status.ID) + + // Prepare a simple status edit. + form := &apimodel.StatusEditRequest{ + Status: "

this is some edited status text!

", + SpoilerText: "shhhhh", + Sensitive: true, + Language: "fr", // hoh hoh hoh + MediaIDs: nil, + MediaAttributes: nil, + Poll: nil, + } + + // Pass the prepared form to the status processor to perform the edit. + apiStatus, errWithCode := suite.status.Edit(ctx, requester, status.ID, form) + suite.NotNil(apiStatus) + suite.NoError(errWithCode) + + // Check response against input form data. + suite.Equal(form.Status, apiStatus.Text) + suite.Equal(form.SpoilerText, apiStatus.SpoilerText) + suite.Equal(form.Sensitive, apiStatus.Sensitive) + suite.Equal(form.Language, *apiStatus.Language) + suite.NotEqual(util.FormatISO8601(status.UpdatedAt), *apiStatus.EditedAt) + + // Fetched the latest version of edited status from the database. + latestStatus, err := suite.state.DB.GetStatusByID(ctx, status.ID) + suite.NoError(err) + + // Check latest status against input form data. + suite.Equal(form.Status, latestStatus.Text) + suite.Equal(form.SpoilerText, latestStatus.ContentWarning) + suite.Equal(form.Sensitive, *latestStatus.Sensitive) + suite.Equal(form.Language, latestStatus.Language) + suite.Equal(len(status.EditIDs)+1, len(latestStatus.EditIDs)) + suite.NotEqual(status.UpdatedAt, latestStatus.UpdatedAt) + + // Populate all historical edits for this status. + err = suite.state.DB.PopulateStatusEdits(ctx, latestStatus) + suite.NoError(err) + + // Check previous status edit matches original status content. + previousEdit := latestStatus.Edits[len(latestStatus.Edits)-1] + suite.Equal(status.Content, previousEdit.Content) + suite.Equal(status.Text, previousEdit.Text) + suite.Equal(status.ContentWarning, previousEdit.ContentWarning) + suite.Equal(*status.Sensitive, *previousEdit.Sensitive) + suite.Equal(status.Language, previousEdit.Language) + suite.Equal(status.UpdatedAt, previousEdit.CreatedAt) +} + +func (suite *StatusEditTestSuite) TestEditAddPoll() { + // Create cancellable context to use for test. + ctx, cncl := context.WithCancel(context.Background()) + defer cncl() + + // Get a local account to use as test requester. + requester := suite.testAccounts["local_account_1"] + requester, _ = suite.state.DB.GetAccountByID(ctx, requester.ID) + + // Get requester's existing status to perform an edit on. + status := suite.testStatuses["local_account_1_status_9"] + status, _ = suite.state.DB.GetStatusByID(ctx, status.ID) + + // Prepare edit adding a status poll. + form := &apimodel.StatusEditRequest{ + Status: "

this is some edited status text!

", + SpoilerText: "", + Sensitive: true, + Language: "fr", // hoh hoh hoh + MediaIDs: nil, + MediaAttributes: nil, + Poll: &apimodel.PollRequest{ + Options: []string{"yes", "no", "spiderman"}, + ExpiresIn: int(time.Minute), + Multiple: true, + HideTotals: false, + }, + } + + // Pass the prepared form to the status processor to perform the edit. + apiStatus, errWithCode := suite.status.Edit(ctx, requester, status.ID, form) + suite.NotNil(apiStatus) + suite.NoError(errWithCode) + + // Check response against input form data. + suite.Equal(form.Status, apiStatus.Text) + suite.Equal(form.SpoilerText, apiStatus.SpoilerText) + suite.Equal(form.Sensitive, apiStatus.Sensitive) + suite.Equal(form.Language, *apiStatus.Language) + suite.NotEqual(util.FormatISO8601(status.UpdatedAt), *apiStatus.EditedAt) + suite.NotNil(apiStatus.Poll) + suite.Equal(form.Poll.Options, xslices.Gather(nil, apiStatus.Poll.Options, func(opt apimodel.PollOption) string { + return opt.Title + })) + + // Fetched the latest version of edited status from the database. + latestStatus, err := suite.state.DB.GetStatusByID(ctx, status.ID) + suite.NoError(err) + + // Check latest status against input form data. + suite.Equal(form.Status, latestStatus.Text) + suite.Equal(form.SpoilerText, latestStatus.ContentWarning) + suite.Equal(form.Sensitive, *latestStatus.Sensitive) + suite.Equal(form.Language, latestStatus.Language) + suite.Equal(len(status.EditIDs)+1, len(latestStatus.EditIDs)) + suite.NotEqual(status.UpdatedAt, latestStatus.UpdatedAt) + suite.NotNil(latestStatus.Poll) + suite.Equal(form.Poll.Options, latestStatus.Poll.Options) + + // Ensure that a poll expiry handler was scheduled on status edit. + expiryWorker := suite.state.Workers.Scheduler.Cancel(latestStatus.PollID) + suite.Equal(form.Poll.ExpiresIn > 0, expiryWorker) + + // Populate all historical edits for this status. + err = suite.state.DB.PopulateStatusEdits(ctx, latestStatus) + suite.NoError(err) + + // Check previous status edit matches original status content. + previousEdit := latestStatus.Edits[len(latestStatus.Edits)-1] + suite.Equal(status.Content, previousEdit.Content) + suite.Equal(status.Text, previousEdit.Text) + suite.Equal(status.ContentWarning, previousEdit.ContentWarning) + suite.Equal(*status.Sensitive, *previousEdit.Sensitive) + suite.Equal(status.Language, previousEdit.Language) + suite.Equal(status.UpdatedAt, previousEdit.CreatedAt) + suite.Equal(status.Poll != nil, len(previousEdit.PollOptions) > 0) +} + +func (suite *StatusEditTestSuite) TestEditAddPollNoExpiry() { + // Create cancellable context to use for test. + ctx, cncl := context.WithCancel(context.Background()) + defer cncl() + + // Get a local account to use as test requester. + requester := suite.testAccounts["local_account_1"] + requester, _ = suite.state.DB.GetAccountByID(ctx, requester.ID) + + // Get requester's existing status to perform an edit on. + status := suite.testStatuses["local_account_1_status_9"] + status, _ = suite.state.DB.GetStatusByID(ctx, status.ID) + + // Prepare edit adding an endless poll. + form := &apimodel.StatusEditRequest{ + Status: "

this is some edited status text!

", + SpoilerText: "", + Sensitive: true, + Language: "fr", // hoh hoh hoh + MediaIDs: nil, + MediaAttributes: nil, + Poll: &apimodel.PollRequest{ + Options: []string{"yes", "no", "spiderman"}, + ExpiresIn: 0, + Multiple: true, + HideTotals: false, + }, + } + + // Pass the prepared form to the status processor to perform the edit. + apiStatus, errWithCode := suite.status.Edit(ctx, requester, status.ID, form) + suite.NotNil(apiStatus) + suite.NoError(errWithCode) + + // Check response against input form data. + suite.Equal(form.Status, apiStatus.Text) + suite.Equal(form.SpoilerText, apiStatus.SpoilerText) + suite.Equal(form.Sensitive, apiStatus.Sensitive) + suite.Equal(form.Language, *apiStatus.Language) + suite.NotEqual(util.FormatISO8601(status.UpdatedAt), *apiStatus.EditedAt) + suite.NotNil(apiStatus.Poll) + suite.Equal(form.Poll.Options, xslices.Gather(nil, apiStatus.Poll.Options, func(opt apimodel.PollOption) string { + return opt.Title + })) + + // Fetched the latest version of edited status from the database. + latestStatus, err := suite.state.DB.GetStatusByID(ctx, status.ID) + suite.NoError(err) + + // Check latest status against input form data. + suite.Equal(form.Status, latestStatus.Text) + suite.Equal(form.SpoilerText, latestStatus.ContentWarning) + suite.Equal(form.Sensitive, *latestStatus.Sensitive) + suite.Equal(form.Language, latestStatus.Language) + suite.Equal(len(status.EditIDs)+1, len(latestStatus.EditIDs)) + suite.NotEqual(status.UpdatedAt, latestStatus.UpdatedAt) + suite.NotNil(latestStatus.Poll) + suite.Equal(form.Poll.Options, latestStatus.Poll.Options) + + // Ensure that a poll expiry handler was *not* scheduled on status edit. + expiryWorker := suite.state.Workers.Scheduler.Cancel(latestStatus.PollID) + suite.Equal(form.Poll.ExpiresIn > 0, expiryWorker) + + // Populate all historical edits for this status. + err = suite.state.DB.PopulateStatusEdits(ctx, latestStatus) + suite.NoError(err) + + // Check previous status edit matches original status content. + previousEdit := latestStatus.Edits[len(latestStatus.Edits)-1] + suite.Equal(status.Content, previousEdit.Content) + suite.Equal(status.Text, previousEdit.Text) + suite.Equal(status.ContentWarning, previousEdit.ContentWarning) + suite.Equal(*status.Sensitive, *previousEdit.Sensitive) + suite.Equal(status.Language, previousEdit.Language) + suite.Equal(status.UpdatedAt, previousEdit.CreatedAt) + suite.Equal(status.Poll != nil, len(previousEdit.PollOptions) > 0) +} + +func (suite *StatusEditTestSuite) TestEditMediaDescription() { + // Create cancellable context to use for test. + ctx, cncl := context.WithCancel(context.Background()) + defer cncl() + + // Get a local account to use as test requester. + requester := suite.testAccounts["local_account_1"] + requester, _ = suite.state.DB.GetAccountByID(ctx, requester.ID) + + // Get requester's existing status to perform an edit on. + status := suite.testStatuses["local_account_1_status_4"] + status, _ = suite.state.DB.GetStatusByID(ctx, status.ID) + + // Prepare edit changing media description. + form := &apimodel.StatusEditRequest{ + Status: "

this is some edited status text!

", + SpoilerText: "this status is now missing media", + Sensitive: true, + Language: "en", + MediaIDs: status.AttachmentIDs, + MediaAttributes: []apimodel.AttachmentAttributesRequest{ + {ID: status.AttachmentIDs[0], Description: "hello world!"}, + {ID: status.AttachmentIDs[1], Description: "media attachment numero two"}, + }, + } + + // Pass the prepared form to the status processor to perform the edit. + apiStatus, errWithCode := suite.status.Edit(ctx, requester, status.ID, form) + suite.NoError(errWithCode) + + // Check response against input form data. + suite.Equal(form.Status, apiStatus.Text) + suite.Equal(form.SpoilerText, apiStatus.SpoilerText) + suite.Equal(form.Sensitive, apiStatus.Sensitive) + suite.Equal(form.Language, *apiStatus.Language) + suite.NotEqual(util.FormatISO8601(status.UpdatedAt), *apiStatus.EditedAt) + suite.Equal(form.MediaIDs, xslices.Gather(nil, apiStatus.MediaAttachments, func(media *apimodel.Attachment) string { + return media.ID + })) + suite.Equal( + xslices.Gather(nil, form.MediaAttributes, func(attr apimodel.AttachmentAttributesRequest) string { + return attr.Description + }), + xslices.Gather(nil, apiStatus.MediaAttachments, func(media *apimodel.Attachment) string { + return *media.Description + }), + ) + + // Fetched the latest version of edited status from the database. + latestStatus, err := suite.state.DB.GetStatusByID(ctx, status.ID) + suite.NoError(err) + + // Check latest status against input form data. + suite.Equal(form.Status, latestStatus.Text) + suite.Equal(form.SpoilerText, latestStatus.ContentWarning) + suite.Equal(form.Sensitive, *latestStatus.Sensitive) + suite.Equal(form.Language, latestStatus.Language) + suite.Equal(len(status.EditIDs)+1, len(latestStatus.EditIDs)) + suite.NotEqual(status.UpdatedAt, latestStatus.UpdatedAt) + suite.Equal(form.MediaIDs, latestStatus.AttachmentIDs) + suite.Equal( + xslices.Gather(nil, form.MediaAttributes, func(attr apimodel.AttachmentAttributesRequest) string { + return attr.Description + }), + xslices.Gather(nil, latestStatus.Attachments, func(media *gtsmodel.MediaAttachment) string { + return media.Description + }), + ) + + // Populate all historical edits for this status. + err = suite.state.DB.PopulateStatusEdits(ctx, latestStatus) + suite.NoError(err) + + // Further populate edits to get attachments. + for _, edit := range latestStatus.Edits { + err = suite.state.DB.PopulateStatusEdit(ctx, edit) + suite.NoError(err) + } + + // Check previous status edit matches original status content. + previousEdit := latestStatus.Edits[len(latestStatus.Edits)-1] + suite.Equal(status.Content, previousEdit.Content) + suite.Equal(status.Text, previousEdit.Text) + suite.Equal(status.ContentWarning, previousEdit.ContentWarning) + suite.Equal(*status.Sensitive, *previousEdit.Sensitive) + suite.Equal(status.Language, previousEdit.Language) + suite.Equal(status.UpdatedAt, previousEdit.CreatedAt) + suite.Equal(status.AttachmentIDs, previousEdit.AttachmentIDs) + suite.Equal( + xslices.Gather(nil, status.Attachments, func(media *gtsmodel.MediaAttachment) string { + return media.Description + }), + previousEdit.AttachmentDescriptions, + ) +} + +func (suite *StatusEditTestSuite) TestEditAddMedia() { + // Create cancellable context to use for test. + ctx, cncl := context.WithCancel(context.Background()) + defer cncl() + + // Get a local account to use as test requester. + requester := suite.testAccounts["local_account_1"] + requester, _ = suite.state.DB.GetAccountByID(ctx, requester.ID) + + // Get some of requester's existing media, and unattach from existing status. + media1 := suite.testAttachments["local_account_1_status_4_attachment_1"] + media2 := suite.testAttachments["local_account_1_status_4_attachment_2"] + media1.StatusID, media2.StatusID = "", "" + suite.NoError(suite.state.DB.UpdateAttachment(ctx, media1, "status_id")) + suite.NoError(suite.state.DB.UpdateAttachment(ctx, media2, "status_id")) + media1, _ = suite.state.DB.GetAttachmentByID(ctx, media1.ID) + media2, _ = suite.state.DB.GetAttachmentByID(ctx, media2.ID) + + // Get requester's existing status to perform an edit on. + status := suite.testStatuses["local_account_1_status_9"] + status, _ = suite.state.DB.GetStatusByID(ctx, status.ID) + + // Prepare edit addding status media. + form := &apimodel.StatusEditRequest{ + Status: "

this is some edited status text!

", + SpoilerText: "this status now has media", + Sensitive: true, + Language: "en", + MediaIDs: []string{media1.ID, media2.ID}, + MediaAttributes: nil, + } + + // Pass the prepared form to the status processor to perform the edit. + apiStatus, errWithCode := suite.status.Edit(ctx, requester, status.ID, form) + suite.NotNil(apiStatus) + suite.NoError(errWithCode) + + // Check response against input form data. + suite.Equal(form.Status, apiStatus.Text) + suite.Equal(form.SpoilerText, apiStatus.SpoilerText) + suite.Equal(form.Sensitive, apiStatus.Sensitive) + suite.Equal(form.Language, *apiStatus.Language) + suite.NotEqual(util.FormatISO8601(status.UpdatedAt), *apiStatus.EditedAt) + suite.Equal(form.MediaIDs, xslices.Gather(nil, apiStatus.MediaAttachments, func(media *apimodel.Attachment) string { + return media.ID + })) + + // Fetched the latest version of edited status from the database. + latestStatus, err := suite.state.DB.GetStatusByID(ctx, status.ID) + suite.NoError(err) + + // Check latest status against input form data. + suite.Equal(form.Status, latestStatus.Text) + suite.Equal(form.SpoilerText, latestStatus.ContentWarning) + suite.Equal(form.Sensitive, *latestStatus.Sensitive) + suite.Equal(form.Language, latestStatus.Language) + suite.Equal(len(status.EditIDs)+1, len(latestStatus.EditIDs)) + suite.NotEqual(status.UpdatedAt, latestStatus.UpdatedAt) + suite.Equal(form.MediaIDs, latestStatus.AttachmentIDs) + + // Populate all historical edits for this status. + err = suite.state.DB.PopulateStatusEdits(ctx, latestStatus) + suite.NoError(err) + + // Check previous status edit matches original status content. + previousEdit := latestStatus.Edits[len(latestStatus.Edits)-1] + suite.Equal(status.Content, previousEdit.Content) + suite.Equal(status.Text, previousEdit.Text) + suite.Equal(status.ContentWarning, previousEdit.ContentWarning) + suite.Equal(*status.Sensitive, *previousEdit.Sensitive) + suite.Equal(status.Language, previousEdit.Language) + suite.Equal(status.UpdatedAt, previousEdit.CreatedAt) + suite.Equal(status.AttachmentIDs, previousEdit.AttachmentIDs) +} + +func (suite *StatusEditTestSuite) TestEditRemoveMedia() { + // Create cancellable context to use for test. + ctx, cncl := context.WithCancel(context.Background()) + defer cncl() + + // Get a local account to use as test requester. + requester := suite.testAccounts["local_account_1"] + requester, _ = suite.state.DB.GetAccountByID(ctx, requester.ID) + + // Get requester's existing status to perform an edit on. + status := suite.testStatuses["local_account_1_status_4"] + status, _ = suite.state.DB.GetStatusByID(ctx, status.ID) + + // Prepare edit removing status media. + form := &apimodel.StatusEditRequest{ + Status: "

this is some edited status text!

", + SpoilerText: "this status is now missing media", + Sensitive: true, + Language: "en", + MediaIDs: nil, + MediaAttributes: nil, + } + + // Pass the prepared form to the status processor to perform the edit. + apiStatus, errWithCode := suite.status.Edit(ctx, requester, status.ID, form) + suite.NotNil(apiStatus) + suite.NoError(errWithCode) + + // Check response against input form data. + suite.Equal(form.Status, apiStatus.Text) + suite.Equal(form.SpoilerText, apiStatus.SpoilerText) + suite.Equal(form.Sensitive, apiStatus.Sensitive) + suite.Equal(form.Language, *apiStatus.Language) + suite.NotEqual(util.FormatISO8601(status.UpdatedAt), *apiStatus.EditedAt) + suite.Equal(form.MediaIDs, xslices.Gather(nil, apiStatus.MediaAttachments, func(media *apimodel.Attachment) string { + return media.ID + })) + + // Fetched the latest version of edited status from the database. + latestStatus, err := suite.state.DB.GetStatusByID(ctx, status.ID) + suite.NoError(err) + + // Check latest status against input form data. + suite.Equal(form.Status, latestStatus.Text) + suite.Equal(form.SpoilerText, latestStatus.ContentWarning) + suite.Equal(form.Sensitive, *latestStatus.Sensitive) + suite.Equal(form.Language, latestStatus.Language) + suite.Equal(len(status.EditIDs)+1, len(latestStatus.EditIDs)) + suite.NotEqual(status.UpdatedAt, latestStatus.UpdatedAt) + suite.Equal(form.MediaIDs, latestStatus.AttachmentIDs) + + // Populate all historical edits for this status. + err = suite.state.DB.PopulateStatusEdits(ctx, latestStatus) + suite.NoError(err) + + // Check previous status edit matches original status content. + previousEdit := latestStatus.Edits[len(latestStatus.Edits)-1] + suite.Equal(status.Content, previousEdit.Content) + suite.Equal(status.Text, previousEdit.Text) + suite.Equal(status.ContentWarning, previousEdit.ContentWarning) + suite.Equal(*status.Sensitive, *previousEdit.Sensitive) + suite.Equal(status.Language, previousEdit.Language) + suite.Equal(status.UpdatedAt, previousEdit.CreatedAt) + suite.Equal(status.AttachmentIDs, previousEdit.AttachmentIDs) +} + +func (suite *StatusEditTestSuite) TestEditOthersStatus1() { + // Create cancellable context to use for test. + ctx, cncl := context.WithCancel(context.Background()) + defer cncl() + + // Get a local account to use as test requester. + requester := suite.testAccounts["local_account_1"] + requester, _ = suite.state.DB.GetAccountByID(ctx, requester.ID) + + // Get remote accounts's status to attempt an edit on. + status := suite.testStatuses["remote_account_1_status_1"] + status, _ = suite.state.DB.GetStatusByID(ctx, status.ID) + + // Prepare an empty request form, this + // should be all we need to trigger it. + form := &apimodel.StatusEditRequest{} + + // Attempt to edit other remote account's status, this should return an error. + apiStatus, errWithCode := suite.status.Edit(ctx, requester, status.ID, form) + suite.Nil(apiStatus) + suite.Equal(http.StatusNotFound, errWithCode.Code()) + suite.Equal("status does not belong to requester", errWithCode.Error()) + suite.Equal("Not Found: target status not found", errWithCode.Safe()) +} + +func (suite *StatusEditTestSuite) TestEditOthersStatus2() { + // Create cancellable context to use for test. + ctx, cncl := context.WithCancel(context.Background()) + defer cncl() + + // Get a local account to use as test requester. + requester := suite.testAccounts["local_account_1"] + requester, _ = suite.state.DB.GetAccountByID(ctx, requester.ID) + + // Get other local accounts's status to attempt edit on. + status := suite.testStatuses["local_account_2_status_1"] + status, _ = suite.state.DB.GetStatusByID(ctx, status.ID) + + // Prepare an empty request form, this + // should be all we need to trigger it. + form := &apimodel.StatusEditRequest{} + + // Attempt to edit other local account's status, this should return an error. + apiStatus, errWithCode := suite.status.Edit(ctx, requester, status.ID, form) + suite.Nil(apiStatus) + suite.Equal(http.StatusNotFound, errWithCode.Code()) + suite.Equal("status does not belong to requester", errWithCode.Error()) + suite.Equal("Not Found: target status not found", errWithCode.Safe()) +} + +func TestStatusEditTestSuite(t *testing.T) { + suite.Run(t, new(StatusEditTestSuite)) +} diff --git a/internal/processing/status/get.go b/internal/processing/status/get.go index 470b93a8f..812f01683 100644 --- a/internal/processing/status/get.go +++ b/internal/processing/status/get.go @@ -19,47 +19,16 @@ import ( "context" + "errors" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" - "github.com/superseriousbusiness/gotosocial/internal/util" ) -// HistoryGet gets edit history for the target status, taking account of privacy settings and blocks etc. -// TODO: currently this just returns the latest version of the status. -func (p *Processor) HistoryGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) ([]*apimodel.StatusEdit, gtserror.WithCode) { - targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, - requestingAccount, - targetStatusID, - nil, // default freshness - ) - if errWithCode != nil { - return nil, errWithCode - } - - apiStatus, errWithCode := p.c.GetAPIStatus(ctx, requestingAccount, targetStatus) - if errWithCode != nil { - return nil, errWithCode - } - - return []*apimodel.StatusEdit{ - { - Content: apiStatus.Content, - SpoilerText: apiStatus.SpoilerText, - Sensitive: apiStatus.Sensitive, - CreatedAt: util.FormatISO8601(targetStatus.UpdatedAt), - Account: apiStatus.Account, - Poll: apiStatus.Poll, - MediaAttachments: apiStatus.MediaAttachments, - Emojis: apiStatus.Emojis, - }, - }, nil -} - // Get gets the given status, taking account of privacy settings and blocks etc. func (p *Processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.Status, gtserror.WithCode) { - targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, + target, errWithCode := p.c.GetVisibleTargetStatus(ctx, requestingAccount, targetStatusID, nil, // default freshness @@ -67,44 +36,25 @@ func (p *Processor) Get(ctx context.Context, requestingAccount *gtsmodel.Account if errWithCode != nil { return nil, errWithCode } - return p.c.GetAPIStatus(ctx, requestingAccount, targetStatus) + return p.c.GetAPIStatus(ctx, requestingAccount, target) } // SourceGet returns the *apimodel.StatusSource version of the targetStatusID. // Status must belong to the requester, and must not be a boost. -func (p *Processor) SourceGet(ctx context.Context, requestingAccount *gtsmodel.Account, targetStatusID string) (*apimodel.StatusSource, gtserror.WithCode) { - targetStatus, errWithCode := p.c.GetVisibleTargetStatus(ctx, - requestingAccount, - targetStatusID, - nil, // default freshness - ) +func (p *Processor) SourceGet(ctx context.Context, requester *gtsmodel.Account, statusID string) (*apimodel.StatusSource, gtserror.WithCode) { + status, errWithCode := p.c.GetOwnStatus(ctx, requester, statusID) if errWithCode != nil { return nil, errWithCode } - - // Redirect to wrapped status if boost. - targetStatus, errWithCode = p.c.UnwrapIfBoost( - ctx, - requestingAccount, - targetStatus, - ) - if errWithCode != nil { - return nil, errWithCode - } - - if targetStatus.AccountID != requestingAccount.ID { - err := gtserror.Newf( - "status %s does not belong to account %s", - targetStatusID, requestingAccount.ID, + if status.BoostOfID != "" { + return nil, gtserror.NewErrorNotFound( + errors.New("status is a boost wrapper"), + "target status not found", ) - return nil, gtserror.NewErrorNotFound(err) } - - statusSource, err := p.converter.StatusToAPIStatusSource(ctx, targetStatus) - if err != nil { - err = gtserror.Newf("error converting status: %w", err) - return nil, gtserror.NewErrorInternalError(err) - } - - return statusSource, nil + return &apimodel.StatusSource{ + ID: status.ID, + Text: status.Text, + SpoilerText: status.ContentWarning, + }, nil } diff --git a/internal/processing/workers/fromfediapi.go b/internal/processing/workers/fromfediapi.go index 0d6ec1836..096e285f6 100644 --- a/internal/processing/workers/fromfediapi.go +++ b/internal/processing/workers/fromfediapi.go @@ -762,7 +762,7 @@ func (p *fediAPI) UpdateAccount(ctx context.Context, fMsg *messages.FromFediAPI) account, apubAcc, - // Force refresh within 10s window. + // Force refresh within 5s window. // // Missing account updates could be // detrimental to federation if they @@ -917,17 +917,25 @@ func (p *fediAPI) UpdateStatus(ctx context.Context, fMsg *messages.FromFediAPI) return gtserror.Newf("cannot cast %T -> *gtsmodel.Status", fMsg.GTSModel) } + var freshness *dereferencing.FreshnessWindow + // Cast the updated ActivityPub statusable object . apStatus, _ := fMsg.APObject.(ap.Statusable) + if apStatus != nil { + // If an AP object was provided, we + // allow very fast refreshes that likely + // indicate a status edit after post. + freshness = dereferencing.Freshest + } + // Fetch up-to-date attach status attachments, etc. status, _, err := p.federate.RefreshStatus( ctx, fMsg.Receiving.Username, existing, apStatus, - // Force refresh within 5min window. - dereferencing.Fresh, + freshness, ) if err != nil { log.Errorf(ctx, "error refreshing status: %v", err) diff --git a/internal/typeutils/internaltofrontend.go b/internal/typeutils/internaltofrontend.go index 40671d884..59618a573 100644 --- a/internal/typeutils/internaltofrontend.go +++ b/internal/typeutils/internaltofrontend.go @@ -1217,21 +1217,6 @@ func (c *Converter) StatusToWebStatus( return webStatus, nil } -// StatusToAPIStatusSource returns the *apimodel.StatusSource of the given status. -// Callers should check beforehand whether a requester has permission to view the -// source of the status, and ensure they're passing only a local status into this function. -func (c *Converter) StatusToAPIStatusSource(ctx context.Context, s *gtsmodel.Status) (*apimodel.StatusSource, error) { - // TODO: remove this when edit support is added. - text := "**STATUS EDITS ARE NOT CURRENTLY SUPPORTED IN GOTOSOCIAL (coming in 2024)**\n" + - "You can review the original text of your status below, but you will not be able to submit this edit.\n\n---\n\n" + s.Text - - return &apimodel.StatusSource{ - ID: s.ID, - Text: text, - SpoilerText: s.ContentWarning, - }, nil -} - // statusToFrontend is a package internal function for // parsing a status into its initial frontend representation. // @@ -1473,6 +1458,149 @@ func (c *Converter) baseStatusToFrontend( return apiStatus, nil } +// StatusToAPIEdits converts a status and its historical edits (if any) to a slice of API model status edits. +func (c *Converter) StatusToAPIEdits(ctx context.Context, status *gtsmodel.Status) ([]*apimodel.StatusEdit, error) { + var media map[string]*gtsmodel.MediaAttachment + + // Gather attachments of status AND edits. + attachmentIDs := status.AllAttachmentIDs() + if len(attachmentIDs) > 0 { + + // Fetch all of the gathered status attachments from the database. + attachments, err := c.state.DB.GetAttachmentsByIDs(ctx, attachmentIDs) + if err != nil { + return nil, gtserror.Newf("error getting attachments from db: %w", err) + } + + // Generate a lookup map in 'media' of status attachments by their IDs. + media = util.KeyBy(attachments, func(m *gtsmodel.MediaAttachment) string { + return m.ID + }) + } + + // Convert the status author account to API model. + apiAccount, err := c.AccountToAPIAccountPublic(ctx, + status.Account, + ) + if err != nil { + return nil, gtserror.Newf("error converting account: %w", err) + } + + // Convert status emojis to their API models, + // this includes all status emojis both current + // and historic, so it gets passed to each edit. + apiEmojis, err := c.convertEmojisToAPIEmojis(ctx, + nil, + status.EmojiIDs, + ) + if err != nil { + return nil, gtserror.Newf("error converting emojis: %w", err) + } + + var votes []int + var options []string + + if status.Poll != nil { + // Extract status poll options. + options = status.Poll.Options + + // Show votes only if closed / allowed. + if !status.Poll.ClosedAt.IsZero() || + !*status.Poll.HideCounts { + votes = status.Poll.Votes + } + } + + // Append status itself to final slot in the edits + // so we can add its revision using the below loop. + edits := append(status.Edits, >smodel.StatusEdit{ //nolint:gocritic + Content: status.Content, + ContentWarning: status.ContentWarning, + Sensitive: status.Sensitive, + PollOptions: options, + PollVotes: votes, + AttachmentIDs: status.AttachmentIDs, + AttachmentDescriptions: nil, // no change from current + CreatedAt: status.UpdatedAt, + }) + + // Iterate through status edits, starting at newest. + apiEdits := make([]*apimodel.StatusEdit, 0, len(edits)) + for i := len(edits) - 1; i >= 0; i-- { + edit := edits[i] + + // Iterate through edit attachment IDs, getting model from 'media' lookup. + apiAttachments := make([]*apimodel.Attachment, 0, len(edit.AttachmentIDs)) + for _, id := range edit.AttachmentIDs { + attachment, ok := media[id] + if !ok { + continue + } + + // Convert each media attachment to frontend API model. + apiAttachment, err := c.AttachmentToAPIAttachment(ctx, + attachment, + ) + if err != nil { + log.Error(ctx, "error converting attachment: %v", err) + continue + } + + // Append converted media attachment to return slice. + apiAttachments = append(apiAttachments, &apiAttachment) + } + + // If media descriptions are set, update API model descriptions. + if len(edit.AttachmentIDs) == len(edit.AttachmentDescriptions) { + var j int + for i, id := range edit.AttachmentIDs { + descr := edit.AttachmentDescriptions[i] + for ; j < len(apiAttachments); j++ { + if apiAttachments[j].ID == id { + apiAttachments[j].Description = &descr + break + } + } + } + } + + // Attach status poll if set. + var apiPoll *apimodel.Poll + if len(edit.PollOptions) > 0 { + apiPoll = new(apimodel.Poll) + + // Iterate through poll options and attach to API poll model. + apiPoll.Options = make([]apimodel.PollOption, len(edit.PollOptions)) + for i, option := range edit.PollOptions { + apiPoll.Options[i] = apimodel.PollOption{ + Title: option, + } + } + + // If poll votes are attached, set vote counts. + if len(edit.PollVotes) == len(apiPoll.Options) { + for i, votes := range edit.PollVotes { + apiPoll.Options[i].VotesCount = &votes + } + } + } + + // Append this status edit to the return slice. + apiEdits = append(apiEdits, &apimodel.StatusEdit{ + CreatedAt: util.FormatISO8601(edit.CreatedAt), + Content: edit.Content, + SpoilerText: edit.ContentWarning, + Sensitive: util.PtrOrZero(edit.Sensitive), + Account: apiAccount, + Poll: apiPoll, + MediaAttachments: apiAttachments, + Emojis: apiEmojis, // same models used for whole status + all edits + }) + } + + return apiEdits, nil +} + // VisToAPIVis converts a gts visibility into its api equivalent func (c *Converter) VisToAPIVis(ctx context.Context, m gtsmodel.Visibility) apimodel.Visibility { switch m { @@ -1489,7 +1617,7 @@ func (c *Converter) VisToAPIVis(ctx context.Context, m gtsmodel.Visibility) apim } // InstanceRuleToAdminAPIRule converts a local instance rule into its api equivalent for serving at /api/v1/admin/instance/rules/:id -func (c *Converter) InstanceRuleToAPIRule(r gtsmodel.Rule) apimodel.InstanceRule { +func InstanceRuleToAPIRule(r gtsmodel.Rule) apimodel.InstanceRule { return apimodel.InstanceRule{ ID: r.ID, Text: r.Text, @@ -1497,18 +1625,16 @@ func (c *Converter) InstanceRuleToAPIRule(r gtsmodel.Rule) apimodel.InstanceRule } // InstanceRulesToAPIRules converts all local instance rules into their api equivalent for serving at /api/v1/instance/rules -func (c *Converter) InstanceRulesToAPIRules(r []gtsmodel.Rule) []apimodel.InstanceRule { +func InstanceRulesToAPIRules(r []gtsmodel.Rule) []apimodel.InstanceRule { rules := make([]apimodel.InstanceRule, len(r)) - for i, v := range r { - rules[i] = c.InstanceRuleToAPIRule(v) + rules[i] = InstanceRuleToAPIRule(v) } - return rules } // InstanceRuleToAdminAPIRule converts a local instance rule into its api equivalent for serving at /api/v1/admin/instance/rules/:id -func (c *Converter) InstanceRuleToAdminAPIRule(r *gtsmodel.Rule) *apimodel.AdminInstanceRule { +func InstanceRuleToAdminAPIRule(r *gtsmodel.Rule) *apimodel.AdminInstanceRule { return &apimodel.AdminInstanceRule{ ID: r.ID, CreatedAt: util.FormatISO8601(r.CreatedAt), @@ -1541,7 +1667,7 @@ func (c *Converter) InstanceToAPIV1Instance(ctx context.Context, i *gtsmodel.Ins ApprovalRequired: true, // approval always required InvitesEnabled: false, // todo: not supported yet MaxTootChars: uint(config.GetStatusesMaxChars()), // #nosec G115 -- Already validated. - Rules: c.InstanceRulesToAPIRules(i.Rules), + Rules: InstanceRulesToAPIRules(i.Rules), Terms: i.Terms, TermsRaw: i.TermsText, } @@ -1675,7 +1801,7 @@ func (c *Converter) InstanceToAPIV2Instance(ctx context.Context, i *gtsmodel.Ins CustomCSS: i.CustomCSS, Usage: apimodel.InstanceV2Usage{}, // todo: not implemented Languages: config.GetInstanceLanguages().TagStrs(), - Rules: c.InstanceRulesToAPIRules(i.Rules), + Rules: InstanceRulesToAPIRules(i.Rules), Terms: i.Terms, TermsText: i.TermsText, } diff --git a/internal/typeutils/internaltofrontend_test.go b/internal/typeutils/internaltofrontend_test.go index 0ec9ea05f..39a9bd9d4 100644 --- a/internal/typeutils/internaltofrontend_test.go +++ b/internal/typeutils/internaltofrontend_test.go @@ -3737,6 +3737,136 @@ func (suite *InternalToFrontendTestSuite) TestConversationToAPI() { }`, string(b)) } +func (suite *InternalToFrontendTestSuite) TestStatusToAPIEdits() { + ctx, cncl := context.WithCancel(context.Background()) + defer cncl() + + statusID := suite.testStatuses["local_account_1_status_9"].ID + + status, err := suite.state.DB.GetStatusByID(ctx, statusID) + suite.NoError(err) + + err = suite.state.DB.PopulateStatusEdits(ctx, status) + suite.NoError(err) + + apiEdits, err := suite.typeconverter.StatusToAPIEdits(ctx, status) + suite.NoError(err) + + b, err := json.MarshalIndent(apiEdits, "", " ") + suite.NoError(err) + + suite.Equal(`[ + { + "content": "\u003cp\u003ethis is the latest revision of the status, with a content-warning\u003c/p\u003e", + "spoiler_text": "edited status", + "sensitive": false, + "created_at": "2024-11-01T09:02:00.000Z", + "account": { + "id": "01F8MH1H7YV1Z7D2C8K2730QBF", + "username": "the_mighty_zork", + "acct": "the_mighty_zork", + "display_name": "original zork (he/they)", + "locked": false, + "discoverable": true, + "bot": false, + "created_at": "2022-05-20T11:09:18.000Z", + "note": "\u003cp\u003ehey yo this is my profile!\u003c/p\u003e", + "url": "http://localhost:8080/@the_mighty_zork", + "avatar": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/avatar/original/01F8MH58A357CV5K7R7TJMSH6S.jpg", + "avatar_static": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/avatar/small/01F8MH58A357CV5K7R7TJMSH6S.webp", + "avatar_description": "a green goblin looking nasty", + "avatar_media_id": "01F8MH58A357CV5K7R7TJMSH6S", + "header": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/header/original/01PFPMWK2FF0D9WMHEJHR07C3Q.jpg", + "header_static": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/header/small/01PFPMWK2FF0D9WMHEJHR07C3Q.webp", + "header_description": "A very old-school screenshot of the original team fortress mod for quake", + "header_media_id": "01PFPMWK2FF0D9WMHEJHR07C3Q", + "followers_count": 2, + "following_count": 2, + "statuses_count": 9, + "last_status_at": "2024-11-01", + "emojis": [], + "fields": [], + "enable_rss": true + }, + "poll": null, + "media_attachments": [], + "emojis": [] + }, + { + "content": "\u003cp\u003ethis is the first status edit! now with content-warning\u003c/p\u003e", + "spoiler_text": "edited status", + "sensitive": false, + "created_at": "2024-11-01T09:01:00.000Z", + "account": { + "id": "01F8MH1H7YV1Z7D2C8K2730QBF", + "username": "the_mighty_zork", + "acct": "the_mighty_zork", + "display_name": "original zork (he/they)", + "locked": false, + "discoverable": true, + "bot": false, + "created_at": "2022-05-20T11:09:18.000Z", + "note": "\u003cp\u003ehey yo this is my profile!\u003c/p\u003e", + "url": "http://localhost:8080/@the_mighty_zork", + "avatar": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/avatar/original/01F8MH58A357CV5K7R7TJMSH6S.jpg", + "avatar_static": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/avatar/small/01F8MH58A357CV5K7R7TJMSH6S.webp", + "avatar_description": "a green goblin looking nasty", + "avatar_media_id": "01F8MH58A357CV5K7R7TJMSH6S", + "header": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/header/original/01PFPMWK2FF0D9WMHEJHR07C3Q.jpg", + "header_static": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/header/small/01PFPMWK2FF0D9WMHEJHR07C3Q.webp", + "header_description": "A very old-school screenshot of the original team fortress mod for quake", + "header_media_id": "01PFPMWK2FF0D9WMHEJHR07C3Q", + "followers_count": 2, + "following_count": 2, + "statuses_count": 9, + "last_status_at": "2024-11-01", + "emojis": [], + "fields": [], + "enable_rss": true + }, + "poll": null, + "media_attachments": [], + "emojis": [] + }, + { + "content": "\u003cp\u003ethis is the original status\u003c/p\u003e", + "spoiler_text": "", + "sensitive": false, + "created_at": "2024-11-01T09:00:00.000Z", + "account": { + "id": "01F8MH1H7YV1Z7D2C8K2730QBF", + "username": "the_mighty_zork", + "acct": "the_mighty_zork", + "display_name": "original zork (he/they)", + "locked": false, + "discoverable": true, + "bot": false, + "created_at": "2022-05-20T11:09:18.000Z", + "note": "\u003cp\u003ehey yo this is my profile!\u003c/p\u003e", + "url": "http://localhost:8080/@the_mighty_zork", + "avatar": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/avatar/original/01F8MH58A357CV5K7R7TJMSH6S.jpg", + "avatar_static": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/avatar/small/01F8MH58A357CV5K7R7TJMSH6S.webp", + "avatar_description": "a green goblin looking nasty", + "avatar_media_id": "01F8MH58A357CV5K7R7TJMSH6S", + "header": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/header/original/01PFPMWK2FF0D9WMHEJHR07C3Q.jpg", + "header_static": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/header/small/01PFPMWK2FF0D9WMHEJHR07C3Q.webp", + "header_description": "A very old-school screenshot of the original team fortress mod for quake", + "header_media_id": "01PFPMWK2FF0D9WMHEJHR07C3Q", + "followers_count": 2, + "following_count": 2, + "statuses_count": 9, + "last_status_at": "2024-11-01", + "emojis": [], + "fields": [], + "enable_rss": true + }, + "poll": null, + "media_attachments": [], + "emojis": [] + } +]`, string(b)) +} + func TestInternalToFrontendTestSuite(t *testing.T) { suite.Run(t, new(InternalToFrontendTestSuite)) } diff --git a/vendor/github.com/ncruces/go-sqlite3/driver/driver.go b/vendor/github.com/ncruces/go-sqlite3/driver/driver.go index b9bb03bb7..742f308af 100644 --- a/vendor/github.com/ncruces/go-sqlite3/driver/driver.go +++ b/vendor/github.com/ncruces/go-sqlite3/driver/driver.go @@ -274,6 +274,7 @@ func (n *connector) Connect(ctx context.Context) (res driver.Conn, err error) { // if err != nil { // log.Fatal(err) // } +// defer conn.Close() // // err = conn.Raw(func(driverConn any) error { // conn := driverConn.(driver.Conn) diff --git a/vendor/github.com/ncruces/go-sqlite3/driver/util.go b/vendor/github.com/ncruces/go-sqlite3/driver/util.go index 033841157..987585576 100644 --- a/vendor/github.com/ncruces/go-sqlite3/driver/util.go +++ b/vendor/github.com/ncruces/go-sqlite3/driver/util.go @@ -12,3 +12,63 @@ func namedValues(args []driver.Value) []driver.NamedValue { } return named } + +func notWhitespace(sql string) bool { + const ( + code = iota + slash + minus + ccomment + sqlcomment + endcomment + ) + + state := code + for _, b := range ([]byte)(sql) { + if b == 0 { + break + } + + switch state { + case code: + switch b { + case '/': + state = slash + case '-': + state = minus + case ' ', ';', '\t', '\n', '\v', '\f', '\r': + continue + default: + return true + } + case slash: + if b != '*' { + return true + } + state = ccomment + case minus: + if b != '-' { + return true + } + state = sqlcomment + case ccomment: + if b == '*' { + state = endcomment + } + case sqlcomment: + if b == '\n' { + state = code + } + case endcomment: + switch b { + case '/': + state = code + case '*': + state = endcomment + default: + state = ccomment + } + } + } + return state == slash || state == minus +} diff --git a/vendor/github.com/ncruces/go-sqlite3/driver/whitespace.go b/vendor/github.com/ncruces/go-sqlite3/driver/whitespace.go deleted file mode 100644 index 8f45706f5..000000000 --- a/vendor/github.com/ncruces/go-sqlite3/driver/whitespace.go +++ /dev/null @@ -1,61 +0,0 @@ -package driver - -func notWhitespace(sql string) bool { - const ( - code = iota - slash - minus - ccomment - sqlcomment - endcomment - ) - - state := code - for _, b := range ([]byte)(sql) { - if b == 0 { - break - } - - switch state { - case code: - switch b { - case '/': - state = slash - case '-': - state = minus - case ' ', ';', '\t', '\n', '\v', '\f', '\r': - continue - default: - return true - } - case slash: - if b != '*' { - return true - } - state = ccomment - case minus: - if b != '-' { - return true - } - state = sqlcomment - case ccomment: - if b == '*' { - state = endcomment - } - case sqlcomment: - if b == '\n' { - state = code - } - case endcomment: - switch b { - case '/': - state = code - case '*': - state = endcomment - default: - state = ccomment - } - } - } - return state == slash || state == minus -} diff --git a/vendor/github.com/ncruces/go-sqlite3/internal/util/mmap_unix.go b/vendor/github.com/ncruces/go-sqlite3/internal/util/mmap_unix.go index 5d5ca3823..4ff056666 100644 --- a/vendor/github.com/ncruces/go-sqlite3/internal/util/mmap_unix.go +++ b/vendor/github.com/ncruces/go-sqlite3/internal/util/mmap_unix.go @@ -39,13 +39,13 @@ func (s *mmapState) new(ctx context.Context, mod api.Module, size int32) *Mapped // Save the newly allocated region. ptr := uint32(stack[0]) buf := View(mod, ptr, uint64(size)) - addr := unsafe.Pointer(&buf[0]) - s.regions = append(s.regions, &MappedRegion{ + res := &MappedRegion{ Ptr: ptr, - addr: addr, size: size, - }) - return s.regions[len(s.regions)-1] + addr: unsafe.Pointer(&buf[0]), + } + s.regions = append(s.regions, res) + return res } type MappedRegion struct { diff --git a/vendor/github.com/ncruces/go-sqlite3/sqlite.go b/vendor/github.com/ncruces/go-sqlite3/sqlite.go index 2afe9971c..18a2c2a73 100644 --- a/vendor/github.com/ncruces/go-sqlite3/sqlite.go +++ b/vendor/github.com/ncruces/go-sqlite3/sqlite.go @@ -265,10 +265,11 @@ func (a *arena) mark() (reset func()) { ptrs := len(a.ptrs) next := a.next return func() { - for _, ptr := range a.ptrs[ptrs:] { + rest := a.ptrs[ptrs:] + for _, ptr := range a.ptrs[:ptrs] { a.sqlt.free(ptr) } - a.ptrs = a.ptrs[:ptrs] + a.ptrs = rest a.next = next } } diff --git a/vendor/github.com/ncruces/go-sqlite3/vfs/lock.go b/vendor/github.com/ncruces/go-sqlite3/vfs/lock.go index 8828662d4..b28d83230 100644 --- a/vendor/github.com/ncruces/go-sqlite3/vfs/lock.go +++ b/vendor/github.com/ncruces/go-sqlite3/vfs/lock.go @@ -20,12 +20,10 @@ ) func (f *vfsFile) Lock(lock LockLevel) error { - // Argument check. SQLite never explicitly requests a pending lock. - if lock != LOCK_SHARED && lock != LOCK_RESERVED && lock != LOCK_EXCLUSIVE { - panic(util.AssertErr()) - } - switch { + case lock != LOCK_SHARED && lock != LOCK_RESERVED && lock != LOCK_EXCLUSIVE: + // Argument check. SQLite never explicitly requests a pending lock. + panic(util.AssertErr()) case f.lock < LOCK_NONE || f.lock > LOCK_EXCLUSIVE: // Connection state check. panic(util.AssertErr()) @@ -87,13 +85,12 @@ func (f *vfsFile) Lock(lock LockLevel) error { } func (f *vfsFile) Unlock(lock LockLevel) error { - // Argument check. - if lock != LOCK_NONE && lock != LOCK_SHARED { + switch { + case lock != LOCK_NONE && lock != LOCK_SHARED: + // Argument check. panic(util.AssertErr()) - } - - // Connection state check. - if f.lock < LOCK_NONE || f.lock > LOCK_EXCLUSIVE { + case f.lock < LOCK_NONE || f.lock > LOCK_EXCLUSIVE: + // Connection state check. panic(util.AssertErr()) } diff --git a/vendor/github.com/ncruces/go-sqlite3/vfs/shm_bsd.go b/vendor/github.com/ncruces/go-sqlite3/vfs/shm_bsd.go index 10d6dbf61..5f4f5d108 100644 --- a/vendor/github.com/ncruces/go-sqlite3/vfs/shm_bsd.go +++ b/vendor/github.com/ncruces/go-sqlite3/vfs/shm_bsd.go @@ -22,7 +22,7 @@ type vfsShmParent struct { refs int // +checklocks:vfsShmListMtx - lock [_SHM_NLOCK]int16 // +checklocks:Mutex + lock [_SHM_NLOCK]int8 // +checklocks:Mutex sync.Mutex } @@ -184,10 +184,22 @@ func (s *vfsShm) shmLock(offset, n int32, flags _ShmFlag) _ErrorCode { return rc } - // Obtain/release the appropriate file lock. + // Obtain/release the appropriate file locks. switch { case flags&_SHM_UNLOCK != 0: - return osUnlock(s.File, _SHM_BASE+int64(offset), int64(n)) + begin, end := offset, offset+n + for i := begin; i < end; i++ { + if s.vfsShmParent.lock[i] != 0 { + if i > begin { + rc |= osUnlock(s.File, _SHM_BASE+int64(begin), int64(i-begin)) + } + begin = i + 1 + } + } + if end > begin { + rc |= osUnlock(s.File, _SHM_BASE+int64(begin), int64(end-begin)) + } + return rc case flags&_SHM_SHARED != 0: rc = osReadLock(s.File, _SHM_BASE+int64(offset), int64(n)) case flags&_SHM_EXCLUSIVE != 0: @@ -196,7 +208,7 @@ func (s *vfsShm) shmLock(offset, n int32, flags _ShmFlag) _ErrorCode { panic(util.AssertErr()) } - // Release the local lock. + // Release the local lock we had acquired. if rc != _OK { s.shmMemLock(offset, n, flags^(_SHM_UNLOCK|_SHM_LOCK)) } diff --git a/vendor/github.com/ncruces/go-sqlite3/vfs/shm_dotlk.go b/vendor/github.com/ncruces/go-sqlite3/vfs/shm_dotlk.go index 17fefe562..842bea8f5 100644 --- a/vendor/github.com/ncruces/go-sqlite3/vfs/shm_dotlk.go +++ b/vendor/github.com/ncruces/go-sqlite3/vfs/shm_dotlk.go @@ -18,7 +18,7 @@ type vfsShmParent struct { shared [][_WALINDEX_PGSZ]byte refs int // +checklocks:vfsShmListMtx - lock [_SHM_NLOCK]int16 // +checklocks:Mutex + lock [_SHM_NLOCK]int8 // +checklocks:Mutex sync.Mutex } diff --git a/vendor/github.com/ncruces/go-sqlite3/vfs/shm_memlk.go b/vendor/github.com/ncruces/go-sqlite3/vfs/shm_memlk.go index 404019642..5c8071ebe 100644 --- a/vendor/github.com/ncruces/go-sqlite3/vfs/shm_memlk.go +++ b/vendor/github.com/ncruces/go-sqlite3/vfs/shm_memlk.go @@ -10,9 +10,6 @@ func (s *vfsShm) shmMemLock(offset, n int32, flags _ShmFlag) _ErrorCode { case flags&_SHM_UNLOCK != 0: for i := offset; i < offset+n; i++ { if s.lock[i] { - if s.vfsShmParent.lock[i] == 0 { - panic(util.AssertErr()) - } if s.vfsShmParent.lock[i] <= 0 { s.vfsShmParent.lock[i] = 0 } else { @@ -23,20 +20,21 @@ func (s *vfsShm) shmMemLock(offset, n int32, flags _ShmFlag) _ErrorCode { } case flags&_SHM_SHARED != 0: for i := offset; i < offset+n; i++ { - if s.lock[i] { - panic(util.AssertErr()) - } - if s.vfsShmParent.lock[i]+1 <= 0 { + if !s.lock[i] && + s.vfsShmParent.lock[i]+1 <= 0 { return _BUSY } } for i := offset; i < offset+n; i++ { - s.vfsShmParent.lock[i]++ - s.lock[i] = true + if !s.lock[i] { + s.vfsShmParent.lock[i]++ + s.lock[i] = true + } } case flags&_SHM_EXCLUSIVE != 0: for i := offset; i < offset+n; i++ { if s.lock[i] { + // SQLite never requests an exclusive lock that it already holds. panic(util.AssertErr()) } if s.vfsShmParent.lock[i] != 0 { diff --git a/vendor/github.com/ncruces/go-sqlite3/vfs/shm_ofd.go b/vendor/github.com/ncruces/go-sqlite3/vfs/shm_ofd.go index d335a85fc..dd3611193 100644 --- a/vendor/github.com/ncruces/go-sqlite3/vfs/shm_ofd.go +++ b/vendor/github.com/ncruces/go-sqlite3/vfs/shm_ofd.go @@ -110,7 +110,12 @@ func (s *vfsShm) shmMap(ctx context.Context, mod api.Module, id, size int32, ext func (s *vfsShm) shmLock(offset, n int32, flags _ShmFlag) _ErrorCode { // Argument check. - if n <= 0 || offset < 0 || offset+n > _SHM_NLOCK { + switch { + case n <= 0: + panic(util.AssertErr()) + case offset < 0 || offset+n > _SHM_NLOCK: + panic(util.AssertErr()) + case n != 1 && flags&_SHM_EXCLUSIVE == 0: panic(util.AssertErr()) } switch flags { @@ -123,9 +128,6 @@ func (s *vfsShm) shmLock(offset, n int32, flags _ShmFlag) _ErrorCode { default: panic(util.AssertErr()) } - if n != 1 && flags&_SHM_EXCLUSIVE == 0 { - panic(util.AssertErr()) - } var timeout time.Duration if s.blocking { diff --git a/vendor/modules.txt b/vendor/modules.txt index e0aef8e18..5e0734f84 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -520,7 +520,7 @@ github.com/modern-go/reflect2 # github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 ## explicit github.com/munnerz/goautoneg -# github.com/ncruces/go-sqlite3 v0.21.2 +# github.com/ncruces/go-sqlite3 v0.21.3 ## explicit; go 1.21 github.com/ncruces/go-sqlite3 github.com/ncruces/go-sqlite3/driver