diff --git a/internal/api/client/fileserver/fileserver_test.go b/internal/api/client/fileserver/fileserver_test.go
new file mode 100644
index 000000000..f1fab5672
--- /dev/null
+++ b/internal/api/client/fileserver/fileserver_test.go
@@ -0,0 +1,109 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package fileserver_test
+
+import (
+ "context"
+
+ "github.com/stretchr/testify/suite"
+ "github.com/superseriousbusiness/gotosocial/internal/api/client/fileserver"
+ "github.com/superseriousbusiness/gotosocial/internal/concurrency"
+ "github.com/superseriousbusiness/gotosocial/internal/db"
+ "github.com/superseriousbusiness/gotosocial/internal/email"
+ "github.com/superseriousbusiness/gotosocial/internal/federation"
+ "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
+ "github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/superseriousbusiness/gotosocial/internal/media"
+ "github.com/superseriousbusiness/gotosocial/internal/messages"
+ "github.com/superseriousbusiness/gotosocial/internal/oauth"
+ "github.com/superseriousbusiness/gotosocial/internal/processing"
+ "github.com/superseriousbusiness/gotosocial/internal/storage"
+ "github.com/superseriousbusiness/gotosocial/internal/typeutils"
+ "github.com/superseriousbusiness/gotosocial/testrig"
+)
+
+type FileserverTestSuite struct {
+ // standard suite interfaces
+ suite.Suite
+ db db.DB
+ storage *storage.Driver
+ federator federation.Federator
+ tc typeutils.TypeConverter
+ processor processing.Processor
+ mediaManager media.Manager
+ oauthServer oauth.Server
+ emailSender email.Sender
+
+ // standard suite models
+ testTokens map[string]*gtsmodel.Token
+ testClients map[string]*gtsmodel.Client
+ testApplications map[string]*gtsmodel.Application
+ testUsers map[string]*gtsmodel.User
+ testAccounts map[string]*gtsmodel.Account
+ testAttachments map[string]*gtsmodel.MediaAttachment
+
+ // item being tested
+ fileServer *fileserver.FileServer
+}
+
+/*
+ TEST INFRASTRUCTURE
+*/
+
+func (suite *FileserverTestSuite) SetupSuite() {
+ testrig.InitTestConfig()
+ testrig.InitTestLog()
+
+ fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
+ clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
+
+ suite.db = testrig.NewTestDB()
+ suite.storage = testrig.NewInMemoryStorage()
+ suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
+ suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
+
+ suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, testrig.NewTestMediaManager(suite.db, suite.storage), clientWorker, fedWorker)
+ suite.tc = testrig.NewTestTypeConverter(suite.db)
+ suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
+ suite.oauthServer = testrig.NewTestOauthServer(suite.db)
+
+ suite.fileServer = fileserver.New(suite.processor).(*fileserver.FileServer)
+}
+
+func (suite *FileserverTestSuite) SetupTest() {
+ testrig.StandardDBSetup(suite.db, nil)
+ testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
+ suite.testTokens = testrig.NewTestTokens()
+ suite.testClients = testrig.NewTestClients()
+ suite.testApplications = testrig.NewTestApplications()
+ suite.testUsers = testrig.NewTestUsers()
+ suite.testAccounts = testrig.NewTestAccounts()
+ suite.testAttachments = testrig.NewTestAttachments()
+}
+
+func (suite *FileserverTestSuite) TearDownSuite() {
+ if err := suite.db.Stop(context.Background()); err != nil {
+ log.Panicf("error closing db connection: %s", err)
+ }
+}
+
+func (suite *FileserverTestSuite) TearDownTest() {
+ testrig.StandardDBTeardown(suite.db)
+ testrig.StandardStorageTeardown(suite.storage)
+}
diff --git a/internal/api/client/fileserver/servefile.go b/internal/api/client/fileserver/servefile.go
index e4eca770f..d2328a5fc 100644
--- a/internal/api/client/fileserver/servefile.go
+++ b/internal/api/client/fileserver/servefile.go
@@ -19,7 +19,9 @@
package fileserver
import (
+ "bytes"
"fmt"
+ "io"
"net/http"
"strconv"
@@ -120,5 +122,14 @@ func (m *FileServer) ServeFile(c *gin.Context) {
return
}
- c.DataFromReader(http.StatusOK, content.ContentLength, format, content.Content, nil)
+ // try to slurp the first few bytes to make sure we have something
+ b := bytes.NewBuffer(make([]byte, 0, 64))
+ if _, err := io.CopyN(b, content.Content, 64); err != nil {
+ err = fmt.Errorf("ServeFile: error reading from content: %w", err)
+ api.ErrorHandler(c, gtserror.NewErrorNotFound(err, err.Error()), m.processor.InstanceGet)
+ return
+ }
+
+ // we're good, return the slurped bytes + the rest of the content
+ c.DataFromReader(http.StatusOK, content.ContentLength, format, io.MultiReader(b, content.Content), nil)
}
diff --git a/internal/api/client/fileserver/servefile_test.go b/internal/api/client/fileserver/servefile_test.go
index a6c46e23f..1ca0c60d6 100644
--- a/internal/api/client/fileserver/servefile_test.go
+++ b/internal/api/client/fileserver/servefile_test.go
@@ -20,196 +20,251 @@
import (
"context"
- "fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
- "github.com/gin-gonic/gin"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/api/client/fileserver"
- "github.com/superseriousbusiness/gotosocial/internal/concurrency"
- "github.com/superseriousbusiness/gotosocial/internal/db"
- "github.com/superseriousbusiness/gotosocial/internal/email"
- "github.com/superseriousbusiness/gotosocial/internal/federation"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/media"
- "github.com/superseriousbusiness/gotosocial/internal/messages"
- "github.com/superseriousbusiness/gotosocial/internal/oauth"
- "github.com/superseriousbusiness/gotosocial/internal/processing"
- "github.com/superseriousbusiness/gotosocial/internal/storage"
- "github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type ServeFileTestSuite struct {
- // standard suite interfaces
- suite.Suite
- db db.DB
- storage *storage.Driver
- federator federation.Federator
- tc typeutils.TypeConverter
- processor processing.Processor
- mediaManager media.Manager
- oauthServer oauth.Server
- emailSender email.Sender
-
- // standard suite models
- testTokens map[string]*gtsmodel.Token
- testClients map[string]*gtsmodel.Client
- testApplications map[string]*gtsmodel.Application
- testUsers map[string]*gtsmodel.User
- testAccounts map[string]*gtsmodel.Account
- testAttachments map[string]*gtsmodel.MediaAttachment
-
- // item being tested
- fileServer *fileserver.FileServer
+ FileserverTestSuite
}
-/*
- TEST INFRASTRUCTURE
-*/
-
-func (suite *ServeFileTestSuite) SetupSuite() {
- // setup standard items
- testrig.InitTestConfig()
- testrig.InitTestLog()
-
- fedWorker := concurrency.NewWorkerPool[messages.FromFederator](-1, -1)
- clientWorker := concurrency.NewWorkerPool[messages.FromClientAPI](-1, -1)
-
- suite.db = testrig.NewTestDB()
- suite.storage = testrig.NewInMemoryStorage()
- suite.federator = testrig.NewTestFederator(suite.db, testrig.NewTestTransportController(testrig.NewMockHTTPClient(nil, "../../../../testrig/media"), suite.db, fedWorker), suite.storage, suite.mediaManager, fedWorker)
- suite.emailSender = testrig.NewEmailSender("../../../../web/template/", nil)
-
- suite.processor = testrig.NewTestProcessor(suite.db, suite.storage, suite.federator, suite.emailSender, testrig.NewTestMediaManager(suite.db, suite.storage), clientWorker, fedWorker)
- suite.tc = testrig.NewTestTypeConverter(suite.db)
- suite.mediaManager = testrig.NewTestMediaManager(suite.db, suite.storage)
- suite.oauthServer = testrig.NewTestOauthServer(suite.db)
-
- // setup module being tested
- suite.fileServer = fileserver.New(suite.processor).(*fileserver.FileServer)
-}
-
-func (suite *ServeFileTestSuite) TearDownSuite() {
- if err := suite.db.Stop(context.Background()); err != nil {
- log.Panicf("error closing db connection: %s", err)
- }
-}
-
-func (suite *ServeFileTestSuite) SetupTest() {
- testrig.StandardDBSetup(suite.db, nil)
- testrig.StandardStorageSetup(suite.storage, "../../../../testrig/media")
- suite.testTokens = testrig.NewTestTokens()
- suite.testClients = testrig.NewTestClients()
- suite.testApplications = testrig.NewTestApplications()
- suite.testUsers = testrig.NewTestUsers()
- suite.testAccounts = testrig.NewTestAccounts()
- suite.testAttachments = testrig.NewTestAttachments()
-}
-
-func (suite *ServeFileTestSuite) TearDownTest() {
- testrig.StandardDBTeardown(suite.db)
- testrig.StandardStorageTeardown(suite.storage)
-}
-
-/*
- ACTUAL TESTS
-*/
-
-func (suite *ServeFileTestSuite) TestServeOriginalFileSuccessful() {
- targetAttachment, ok := suite.testAttachments["admin_account_status_1_attachment_1"]
- suite.True(ok)
- suite.NotNil(targetAttachment)
-
+// GetFile is just a convenience function to save repetition in this test suite.
+// It takes the required params to serve a file, calls the handler, and returns
+// the http status code, the response headers, and the parsed body bytes.
+func (suite *ServeFileTestSuite) GetFile(
+ accountID string,
+ mediaType media.Type,
+ mediaSize media.Size,
+ filename string,
+) (code int, headers http.Header, body []byte) {
recorder := httptest.NewRecorder()
- ctx, _ := testrig.CreateGinTestContext(recorder, nil)
- ctx.Request = httptest.NewRequest(http.MethodGet, targetAttachment.URL, nil)
- ctx.Request.Header.Set("accept", "*/*")
- // normally the router would populate these params from the path values,
- // but because we're calling the ServeFile function directly, we need to set them manually.
- ctx.Params = gin.Params{
- gin.Param{
- Key: fileserver.AccountIDKey,
- Value: targetAttachment.AccountID,
- },
- gin.Param{
- Key: fileserver.MediaTypeKey,
- Value: string(media.TypeAttachment),
- },
- gin.Param{
- Key: fileserver.MediaSizeKey,
- Value: string(media.SizeOriginal),
- },
- gin.Param{
- Key: fileserver.FileNameKey,
- Value: fmt.Sprintf("%s.jpeg", targetAttachment.ID),
- },
+ ctx, _ := testrig.CreateGinTestContext(recorder, nil)
+ ctx.Request = httptest.NewRequest(http.MethodGet, "http://localhost:8080/whatever", nil)
+ ctx.Request.Header.Set("accept", "*/*")
+ ctx.AddParam(fileserver.AccountIDKey, accountID)
+ ctx.AddParam(fileserver.MediaTypeKey, string(mediaType))
+ ctx.AddParam(fileserver.MediaSizeKey, string(mediaSize))
+ ctx.AddParam(fileserver.FileNameKey, filename)
+
+ suite.fileServer.ServeFile(ctx)
+ code = recorder.Code
+ headers = recorder.Result().Header
+
+ var err error
+ body, err = ioutil.ReadAll(recorder.Body)
+ if err != nil {
+ suite.FailNow(err.Error())
}
- // call the function we're testing and check status code
- suite.fileServer.ServeFile(ctx)
- suite.EqualValues(http.StatusOK, recorder.Code)
- suite.EqualValues("image/jpeg", recorder.Header().Get("content-type"))
-
- b, err := ioutil.ReadAll(recorder.Body)
- suite.NoError(err)
- suite.NotNil(b)
-
- fileInStorage, err := suite.storage.Get(ctx, targetAttachment.File.Path)
- suite.NoError(err)
- suite.NotNil(fileInStorage)
- suite.Equal(b, fileInStorage)
+ return
}
-func (suite *ServeFileTestSuite) TestServeSmallFileSuccessful() {
- targetAttachment, ok := suite.testAttachments["admin_account_status_1_attachment_1"]
- suite.True(ok)
- suite.NotNil(targetAttachment)
+// UncacheAttachment is a convenience function that uncaches the targetAttachment by
+// removing its associated files from storage, and updating the database.
+func (suite *ServeFileTestSuite) UncacheAttachment(targetAttachment *gtsmodel.MediaAttachment) {
+ ctx := context.Background()
- recorder := httptest.NewRecorder()
- ctx, _ := testrig.CreateGinTestContext(recorder, nil)
- ctx.Request = httptest.NewRequest(http.MethodGet, targetAttachment.Thumbnail.URL, nil)
- ctx.Request.Header.Set("accept", "*/*")
+ cached := false
+ targetAttachment.Cached = &cached
- // normally the router would populate these params from the path values,
- // but because we're calling the ServeFile function directly, we need to set them manually.
- ctx.Params = gin.Params{
- gin.Param{
- Key: fileserver.AccountIDKey,
- Value: targetAttachment.AccountID,
- },
- gin.Param{
- Key: fileserver.MediaTypeKey,
- Value: string(media.TypeAttachment),
- },
- gin.Param{
- Key: fileserver.MediaSizeKey,
- Value: string(media.SizeSmall),
- },
- gin.Param{
- Key: fileserver.FileNameKey,
- Value: fmt.Sprintf("%s.jpeg", targetAttachment.ID),
- },
+ if err := suite.db.UpdateByID(ctx, targetAttachment, targetAttachment.ID, "cached"); err != nil {
+ suite.FailNow(err.Error())
+ }
+ if err := suite.storage.Delete(ctx, targetAttachment.File.Path); err != nil {
+ suite.FailNow(err.Error())
+ }
+ if err := suite.storage.Delete(ctx, targetAttachment.Thumbnail.Path); err != nil {
+ suite.FailNow(err.Error())
+ }
+}
+
+func (suite *ServeFileTestSuite) TestServeOriginalLocalFileOK() {
+ targetAttachment := >smodel.MediaAttachment{}
+ *targetAttachment = *suite.testAttachments["admin_account_status_1_attachment_1"]
+ fileInStorage, err := suite.storage.Get(context.Background(), targetAttachment.File.Path)
+ if err != nil {
+ suite.FailNow(err.Error())
}
- // call the function we're testing and check status code
- suite.fileServer.ServeFile(ctx)
- suite.EqualValues(http.StatusOK, recorder.Code)
- suite.EqualValues("image/jpeg", recorder.Header().Get("content-type"))
+ code, headers, body := suite.GetFile(
+ targetAttachment.AccountID,
+ media.TypeAttachment,
+ media.SizeOriginal,
+ targetAttachment.ID+".jpeg",
+ )
- b, err := ioutil.ReadAll(recorder.Body)
- suite.NoError(err)
- suite.NotNil(b)
+ suite.Equal(http.StatusOK, code)
+ suite.Equal("image/jpeg", headers.Get("content-type"))
+ suite.Equal(fileInStorage, body)
+}
- fileInStorage, err := suite.storage.Get(ctx, targetAttachment.Thumbnail.Path)
- suite.NoError(err)
- suite.NotNil(fileInStorage)
- suite.Equal(b, fileInStorage)
+func (suite *ServeFileTestSuite) TestServeSmallLocalFileOK() {
+ targetAttachment := >smodel.MediaAttachment{}
+ *targetAttachment = *suite.testAttachments["admin_account_status_1_attachment_1"]
+ fileInStorage, err := suite.storage.Get(context.Background(), targetAttachment.Thumbnail.Path)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ code, headers, body := suite.GetFile(
+ targetAttachment.AccountID,
+ media.TypeAttachment,
+ media.SizeSmall,
+ targetAttachment.ID+".jpeg",
+ )
+
+ suite.Equal(http.StatusOK, code)
+ suite.Equal("image/jpeg", headers.Get("content-type"))
+ suite.Equal(fileInStorage, body)
+}
+
+func (suite *ServeFileTestSuite) TestServeOriginalRemoteFileOK() {
+ targetAttachment := >smodel.MediaAttachment{}
+ *targetAttachment = *suite.testAttachments["remote_account_1_status_1_attachment_1"]
+ fileInStorage, err := suite.storage.Get(context.Background(), targetAttachment.File.Path)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ code, headers, body := suite.GetFile(
+ targetAttachment.AccountID,
+ media.TypeAttachment,
+ media.SizeOriginal,
+ targetAttachment.ID+".jpeg",
+ )
+
+ suite.Equal(http.StatusOK, code)
+ suite.Equal("image/jpeg", headers.Get("content-type"))
+ suite.Equal(fileInStorage, body)
+}
+
+func (suite *ServeFileTestSuite) TestServeSmallRemoteFileOK() {
+ targetAttachment := >smodel.MediaAttachment{}
+ *targetAttachment = *suite.testAttachments["remote_account_1_status_1_attachment_1"]
+ fileInStorage, err := suite.storage.Get(context.Background(), targetAttachment.Thumbnail.Path)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ code, headers, body := suite.GetFile(
+ targetAttachment.AccountID,
+ media.TypeAttachment,
+ media.SizeSmall,
+ targetAttachment.ID+".jpeg",
+ )
+
+ suite.Equal(http.StatusOK, code)
+ suite.Equal("image/jpeg", headers.Get("content-type"))
+ suite.Equal(fileInStorage, body)
+}
+
+func (suite *ServeFileTestSuite) TestServeOriginalRemoteFileRecache() {
+ targetAttachment := >smodel.MediaAttachment{}
+ *targetAttachment = *suite.testAttachments["remote_account_1_status_1_attachment_1"]
+ fileInStorage, err := suite.storage.Get(context.Background(), targetAttachment.File.Path)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ // uncache the attachment so we'll have to refetch it from the 'remote' instance
+ suite.UncacheAttachment(targetAttachment)
+
+ code, headers, body := suite.GetFile(
+ targetAttachment.AccountID,
+ media.TypeAttachment,
+ media.SizeOriginal,
+ targetAttachment.ID+".jpeg",
+ )
+
+ suite.Equal(http.StatusOK, code)
+ suite.Equal("image/jpeg", headers.Get("content-type"))
+ suite.Equal(fileInStorage, body)
+}
+
+func (suite *ServeFileTestSuite) TestServeSmallRemoteFileRecache() {
+ targetAttachment := >smodel.MediaAttachment{}
+ *targetAttachment = *suite.testAttachments["remote_account_1_status_1_attachment_1"]
+ fileInStorage, err := suite.storage.Get(context.Background(), targetAttachment.Thumbnail.Path)
+ if err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ // uncache the attachment so we'll have to refetch it from the 'remote' instance
+ suite.UncacheAttachment(targetAttachment)
+
+ code, headers, body := suite.GetFile(
+ targetAttachment.AccountID,
+ media.TypeAttachment,
+ media.SizeSmall,
+ targetAttachment.ID+".jpeg",
+ )
+
+ suite.Equal(http.StatusOK, code)
+ suite.Equal("image/jpeg", headers.Get("content-type"))
+ suite.Equal(fileInStorage, body)
+}
+
+func (suite *ServeFileTestSuite) TestServeOriginalRemoteFileRecacheNotFound() {
+ targetAttachment := >smodel.MediaAttachment{}
+ *targetAttachment = *suite.testAttachments["remote_account_1_status_1_attachment_1"]
+
+ // uncache the attachment *and* set the remote URL to something that will return a 404
+ suite.UncacheAttachment(targetAttachment)
+ targetAttachment.RemoteURL = "http://nothing.at.this.url/weeeeeeeee"
+ if err := suite.db.UpdateByID(context.Background(), targetAttachment, targetAttachment.ID, "remote_url"); err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ code, _, _ := suite.GetFile(
+ targetAttachment.AccountID,
+ media.TypeAttachment,
+ media.SizeOriginal,
+ targetAttachment.ID+".jpeg",
+ )
+
+ suite.Equal(http.StatusNotFound, code)
+}
+
+func (suite *ServeFileTestSuite) TestServeSmallRemoteFileRecacheNotFound() {
+ targetAttachment := >smodel.MediaAttachment{}
+ *targetAttachment = *suite.testAttachments["remote_account_1_status_1_attachment_1"]
+
+ // uncache the attachment *and* set the remote URL to something that will return a 404
+ suite.UncacheAttachment(targetAttachment)
+ targetAttachment.RemoteURL = "http://nothing.at.this.url/weeeeeeeee"
+ if err := suite.db.UpdateByID(context.Background(), targetAttachment, targetAttachment.ID, "remote_url"); err != nil {
+ suite.FailNow(err.Error())
+ }
+
+ code, _, _ := suite.GetFile(
+ targetAttachment.AccountID,
+ media.TypeAttachment,
+ media.SizeSmall,
+ targetAttachment.ID+".jpeg",
+ )
+
+ suite.Equal(http.StatusNotFound, code)
+}
+
+// Callers trying to get some random-ass file that doesn't exist should just get a 404
+func (suite *ServeFileTestSuite) TestServeFileNotFound() {
+ code, _, _ := suite.GetFile(
+ "01GMMY4G9B0QEG0PQK5Q5JGJWZ",
+ media.TypeAttachment,
+ media.SizeOriginal,
+ "01GMMY68Y7E5DJ3CA3Y9SS8524.jpeg",
+ )
+
+ suite.Equal(http.StatusNotFound, code)
}
func TestServeFileTestSuite(t *testing.T) {
diff --git a/internal/iotools/io.go b/internal/iotools/io.go
new file mode 100644
index 000000000..d16a4ce9c
--- /dev/null
+++ b/internal/iotools/io.go
@@ -0,0 +1,121 @@
+/*
+ GoToSocial
+ Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
+*/
+
+package iotools
+
+import (
+ "io"
+)
+
+// ReadFnCloser takes an io.Reader and wraps it to use the provided function to implement io.Closer.
+func ReadFnCloser(r io.Reader, close func() error) io.ReadCloser {
+ return &readFnCloser{
+ Reader: r,
+ close: close,
+ }
+}
+
+type readFnCloser struct {
+ io.Reader
+ close func() error
+}
+
+func (r *readFnCloser) Close() error {
+ return r.close()
+}
+
+// WriteFnCloser takes an io.Writer and wraps it to use the provided function to implement io.Closer.
+func WriteFnCloser(w io.Writer, close func() error) io.WriteCloser {
+ return &writeFnCloser{
+ Writer: w,
+ close: close,
+ }
+}
+
+type writeFnCloser struct {
+ io.Writer
+ close func() error
+}
+
+func (r *writeFnCloser) Close() error {
+ return r.close()
+}
+
+// SilentReader wraps an io.Reader to silence any
+// error output during reads. Instead they are stored
+// and accessible (not concurrency safe!) via .Error().
+type SilentReader struct {
+ io.Reader
+ err error
+}
+
+// SilenceReader wraps an io.Reader within SilentReader{}.
+func SilenceReader(r io.Reader) *SilentReader {
+ return &SilentReader{Reader: r}
+}
+
+func (r *SilentReader) Read(b []byte) (int, error) {
+ n, err := r.Reader.Read(b)
+ if err != nil {
+ // Store error for now
+ if r.err == nil {
+ r.err = err
+ }
+
+ // Pretend we're happy
+ // to continue reading.
+ n = len(b)
+ }
+ return n, nil
+}
+
+func (r *SilentReader) Error() error {
+ return r.err
+}
+
+// SilentWriter wraps an io.Writer to silence any
+// error output during writes. Instead they are stored
+// and accessible (not concurrency safe!) via .Error().
+type SilentWriter struct {
+ io.Writer
+ err error
+}
+
+// SilenceWriter wraps an io.Writer within SilentWriter{}.
+func SilenceWriter(w io.Writer) *SilentWriter {
+ return &SilentWriter{Writer: w}
+}
+
+func (w *SilentWriter) Write(b []byte) (int, error) {
+ n, err := w.Writer.Write(b)
+ if err != nil {
+ // Store error for now
+ if w.err == nil {
+ w.err = err
+ }
+
+ // Pretend we're happy
+ // to continue writing.
+ n = len(b)
+ }
+ return n, nil
+}
+
+func (w *SilentWriter) Error() error {
+ return w.err
+}
diff --git a/internal/media/manager_test.go b/internal/media/manager_test.go
index a8912bde0..f9361a831 100644
--- a/internal/media/manager_test.go
+++ b/internal/media/manager_test.go
@@ -440,7 +440,7 @@ func (suite *ManagerTestSuite) TestSlothVineProcessBlocking() {
processedThumbnailBytes, err := suite.storage.Get(ctx, attachment.Thumbnail.Path)
suite.NoError(err)
suite.NotEmpty(processedThumbnailBytes)
-
+
processedThumbnailBytesExpected, err := os.ReadFile("./test/test-mp4-thumbnail.jpg")
suite.NoError(err)
suite.NotEmpty(processedThumbnailBytesExpected)
diff --git a/internal/processing/media/getfile.go b/internal/processing/media/getfile.go
index ddc14479a..eba3fdb7e 100644
--- a/internal/processing/media/getfile.go
+++ b/internal/processing/media/getfile.go
@@ -19,7 +19,6 @@
package media
import (
- "bufio"
"context"
"fmt"
"io"
@@ -29,7 +28,7 @@
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
- "github.com/superseriousbusiness/gotosocial/internal/log"
+ "github.com/superseriousbusiness/gotosocial/internal/iotools"
"github.com/superseriousbusiness/gotosocial/internal/media"
"github.com/superseriousbusiness/gotosocial/internal/transport"
"github.com/superseriousbusiness/gotosocial/internal/uris"
@@ -135,7 +134,6 @@ func (p *processor) getAttachmentContent(ctx context.Context, requestingAccount
}
var data media.DataFunc
- var postDataCallback media.PostDataCallbackFunc
if mediaSize == media.SizeSmall {
// if it's the thumbnail that's requested then the user will have to wait a bit while we process the
@@ -155,7 +153,7 @@ func (p *processor) getAttachmentContent(ctx context.Context, requestingAccount
//
// this looks a bit like this:
//
- // http fetch buffered pipe
+ // http fetch pipe
// remote server ------------> data function ----------------> api caller
// |
// | tee
@@ -163,54 +161,58 @@ func (p *processor) getAttachmentContent(ctx context.Context, requestingAccount
// ▼
// instance storage
- // Buffer each end of the pipe, so that if the caller drops the connection during the flow, the tee
- // reader can continue without having to worry about tee-ing into a closed or blocked pipe.
+ // This pipe will connect the caller to the in-process media retrieval...
pipeReader, pipeWriter := io.Pipe()
- bufferedWriter := bufio.NewWriterSize(pipeWriter, int(attachmentContent.ContentLength))
- bufferedReader := bufio.NewReaderSize(pipeReader, int(attachmentContent.ContentLength))
- // the caller will read from the buffered reader, so it doesn't matter if they drop out without reading everything
- attachmentContent.Content = io.NopCloser(bufferedReader)
+ // Wrap the output pipe to silence any errors during the actual media
+ // streaming process. We catch the error later but they must be silenced
+ // during stream to prevent interruptions to storage of the actual media.
+ silencedWriter := iotools.SilenceWriter(pipeWriter)
+ // Pass the reader side of the pipe to the caller to slurp from.
+ attachmentContent.Content = pipeReader
+
+ // Create a data function which injects the writer end of the pipe
+ // into the data retrieval process. If something goes wrong while
+ // doing the data retrieval, we hang up the underlying pipeReader
+ // to indicate to the caller that no data is available. It's up to
+ // the caller of this processor function to handle that gracefully.
data = func(innerCtx context.Context) (io.ReadCloser, int64, error) {
t, err := p.transportController.NewTransportForUsername(innerCtx, requestingUsername)
if err != nil {
+ // propagate the transport error to read end of pipe.
+ _ = pipeWriter.CloseWithError(fmt.Errorf("error getting transport for user: %w", err))
return nil, 0, err
}
readCloser, fileSize, err := t.DereferenceMedia(transport.WithFastfail(innerCtx), remoteMediaIRI)
if err != nil {
+ // propagate the dereference error to read end of pipe.
+ _ = pipeWriter.CloseWithError(fmt.Errorf("error dereferencing media: %w", err))
return nil, 0, err
}
- // Make a TeeReader so that everything read from the readCloser by the media manager will be written into the bufferedWriter.
- // We wrap this in a teeReadCloser which implements io.ReadCloser, so that whoever uses the teeReader can close the readCloser
- // when they're done with it.
- trc := teeReadCloser{
- teeReader: io.TeeReader(readCloser, bufferedWriter),
- close: readCloser.Close,
- }
+ // Make a TeeReader so that everything read from the readCloser,
+ // aka the remote instance, will also be written into the pipe.
+ teeReader := io.TeeReader(readCloser, silencedWriter)
- return trc, fileSize, nil
- }
+ // Wrap teereader to implement original readcloser's close,
+ // and also ensuring that we close the pipe from write end.
+ return iotools.ReadFnCloser(teeReader, func() error {
+ defer func() {
+ // We use the error (if any) encountered by the
+ // silenced writer to close connection to make sure it
+ // gets propagated to the attachment.Content reader.
+ _ = pipeWriter.CloseWithError(silencedWriter.Error())
+ }()
- // close the pipewriter after data has been piped into it, so the reader on the other side doesn't block;
- // we don't need to close the reader here because that's the caller's responsibility
- postDataCallback = func(innerCtx context.Context) error {
- // close the underlying pipe writer when we're done with it
- defer func() {
- if err := pipeWriter.Close(); err != nil {
- log.Errorf("getAttachmentContent: error closing pipeWriter: %s", err)
- }
- }()
-
- // and flush the buffered writer into the buffer of the reader
- return bufferedWriter.Flush()
+ return readCloser.Close()
+ }), fileSize, nil
}
}
// put the media recached in the queue
- processingMedia, err := p.mediaManager.RecacheMedia(ctx, data, postDataCallback, wantedMediaID)
+ processingMedia, err := p.mediaManager.RecacheMedia(ctx, data, nil, wantedMediaID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("error recaching media: %s", err))
}
diff --git a/internal/processing/media/getfile_test.go b/internal/processing/media/getfile_test.go
index ba7269535..7b9786914 100644
--- a/internal/processing/media/getfile_test.go
+++ b/internal/processing/media/getfile_test.go
@@ -19,6 +19,7 @@
package media_test
import (
+ "bytes"
"context"
"io"
"path"
@@ -143,9 +144,13 @@ func (suite *GetFileTestSuite) TestGetRemoteFileUncachedInterrupted() {
suite.NotNil(content)
// only read the first kilobyte and then stop
- b := make([]byte, 1024)
- _, err = content.Content.Read(b)
- suite.NoError(err)
+ b := make([]byte, 0, 1024)
+ if !testrig.WaitFor(func() bool {
+ read, err := io.CopyN(bytes.NewBuffer(b), content.Content, 1024)
+ return err == nil && read == 1024
+ }) {
+ suite.FailNow("timed out trying to read first 1024 bytes")
+ }
// close the reader
suite.NoError(content.Content.Close())
diff --git a/internal/processing/media/util.go b/internal/processing/media/util.go
index 9739e70b7..37dc87979 100644
--- a/internal/processing/media/util.go
+++ b/internal/processing/media/util.go
@@ -20,7 +20,6 @@
import (
"fmt"
- "io"
"strconv"
"strings"
)
@@ -62,16 +61,3 @@ func parseFocus(focus string) (focusx, focusy float32, err error) {
focusy = float32(fy)
return
}
-
-type teeReadCloser struct {
- teeReader io.Reader
- close func() error
-}
-
-func (t teeReadCloser) Read(p []byte) (n int, err error) {
- return t.teeReader.Read(p)
-}
-
-func (t teeReadCloser) Close() error {
- return t.close()
-}