From 1dfa7fe0d51b75792db7b0c28ffad7d1f650834d Mon Sep 17 00:00:00 2001 From: tobi <31960611+tsmethurst@users.noreply.github.com> Date: Thu, 3 Nov 2022 15:03:12 +0100 Subject: [PATCH] [bugfix] Wrap media in read closer (#941) * use readcloser for content.Content * call media postdata function no matter what * return a readcloser from data func * tidy of logic of readertostore * fix whoopsie --- internal/api/client/fileserver/servefile.go | 9 ++-- internal/api/model/content.go | 2 +- internal/federation/dereferencing/account.go | 4 +- internal/federation/dereferencing/emoji.go | 2 +- internal/federation/dereferencing/media.go | 2 +- internal/media/manager_test.go | 54 ++++++++++---------- internal/media/processingemoji.go | 14 +++-- internal/media/processingmedia.go | 40 +++++++++------ internal/media/pruneremote_test.go | 4 +- internal/media/types.go | 2 +- internal/processing/account/update.go | 4 +- internal/processing/admin/createemoji.go | 2 +- internal/processing/media/create.go | 2 +- internal/processing/media/getfile.go | 25 +++++---- internal/processing/media/getfile_test.go | 14 ++--- 15 files changed, 88 insertions(+), 92 deletions(-) diff --git a/internal/api/client/fileserver/servefile.go b/internal/api/client/fileserver/servefile.go index 236a2d8ac..e4eca770f 100644 --- a/internal/api/client/fileserver/servefile.go +++ b/internal/api/client/fileserver/servefile.go @@ -20,7 +20,6 @@ import ( "fmt" - "io" "net/http" "strconv" @@ -86,12 +85,10 @@ func (m *FileServer) ServeFile(c *gin.Context) { } defer func() { - // if the content is a ReadCloser (ie., it's streamed from storage), close it when we're done + // close content when we're done if content.Content != nil { - if closer, ok := content.Content.(io.ReadCloser); ok { - if err := closer.Close(); err != nil { - log.Errorf("ServeFile: error closing readcloser: %s", err) - } + if err := content.Content.Close(); err != nil { + log.Errorf("ServeFile: error closing readcloser: %s", err) } } }() diff --git a/internal/api/model/content.go b/internal/api/model/content.go index aa02a99c3..ecce07356 100644 --- a/internal/api/model/content.go +++ b/internal/api/model/content.go @@ -30,7 +30,7 @@ type Content struct { // ContentLength in bytes ContentLength int64 // Actual content - Content io.Reader + Content io.ReadCloser // Resource URL to forward to if the file can be fetched from the storage directly (e.g signed S3 URL) URL *url.URL } diff --git a/internal/federation/dereferencing/account.go b/internal/federation/dereferencing/account.go index dfe7693fb..21900d47b 100644 --- a/internal/federation/dereferencing/account.go +++ b/internal/federation/dereferencing/account.go @@ -496,7 +496,7 @@ func (d *deref) fetchRemoteAccountMedia(ctx context.Context, targetAccount *gtsm } } - data := func(innerCtx context.Context) (io.Reader, int64, error) { + data := func(innerCtx context.Context) (io.ReadCloser, int64, error) { return t.DereferenceMedia(innerCtx, avatarIRI) } @@ -562,7 +562,7 @@ func (d *deref) fetchRemoteAccountMedia(ctx context.Context, targetAccount *gtsm } } - data := func(innerCtx context.Context) (io.Reader, int64, error) { + data := func(innerCtx context.Context) (io.ReadCloser, int64, error) { return t.DereferenceMedia(innerCtx, headerIRI) } diff --git a/internal/federation/dereferencing/emoji.go b/internal/federation/dereferencing/emoji.go index 3cdb1d52d..2d32c8803 100644 --- a/internal/federation/dereferencing/emoji.go +++ b/internal/federation/dereferencing/emoji.go @@ -42,7 +42,7 @@ func (d *deref) GetRemoteEmoji(ctx context.Context, requestingUsername string, r return nil, fmt.Errorf("GetRemoteEmoji: error parsing url: %s", err) } - dataFunc := func(innerCtx context.Context) (io.Reader, int64, error) { + dataFunc := func(innerCtx context.Context) (io.ReadCloser, int64, error) { return t.DereferenceMedia(innerCtx, derefURI) } diff --git a/internal/federation/dereferencing/media.go b/internal/federation/dereferencing/media.go index 1b99eaa96..afa184a98 100644 --- a/internal/federation/dereferencing/media.go +++ b/internal/federation/dereferencing/media.go @@ -42,7 +42,7 @@ func (d *deref) GetRemoteMedia(ctx context.Context, requestingUsername string, a return nil, fmt.Errorf("GetRemoteMedia: error parsing url: %s", err) } - dataFunc := func(innerCtx context.Context) (io.Reader, int64, error) { + dataFunc := func(innerCtx context.Context) (io.ReadCloser, int64, error) { return t.DereferenceMedia(innerCtx, derefURI) } diff --git a/internal/media/manager_test.go b/internal/media/manager_test.go index e00cdd98d..b50235054 100644 --- a/internal/media/manager_test.go +++ b/internal/media/manager_test.go @@ -43,13 +43,13 @@ type ManagerTestSuite struct { func (suite *ManagerTestSuite) TestEmojiProcessBlocking() { ctx := context.Background() - data := func(_ context.Context) (io.Reader, int64, error) { + data := func(_ context.Context) (io.ReadCloser, int64, error) { // load bytes from a test image b, err := os.ReadFile("./test/rainbow-original.png") if err != nil { panic(err) } - return bytes.NewBuffer(b), int64(len(b)), nil + return io.NopCloser(bytes.NewBuffer(b)), int64(len(b)), nil } emojiID := "01GDQ9G782X42BAMFASKP64343" @@ -114,12 +114,12 @@ func (suite *ManagerTestSuite) TestEmojiProcessBlockingRefresh() { oldEmojiImagePath := emojiToUpdate.ImagePath oldEmojiImageStaticPath := emojiToUpdate.ImageStaticPath - data := func(_ context.Context) (io.Reader, int64, error) { + data := func(_ context.Context) (io.ReadCloser, int64, error) { b, err := os.ReadFile("./test/gts_pixellated-original.png") if err != nil { panic(err) } - return bytes.NewBuffer(b), int64(len(b)), nil + return io.NopCloser(bytes.NewBuffer(b)), int64(len(b)), nil } emojiID := emojiToUpdate.ID @@ -197,13 +197,13 @@ func (suite *ManagerTestSuite) TestEmojiProcessBlockingRefresh() { func (suite *ManagerTestSuite) TestEmojiProcessBlockingTooLarge() { ctx := context.Background() - data := func(_ context.Context) (io.Reader, int64, error) { + data := func(_ context.Context) (io.ReadCloser, int64, error) { // load bytes from a test image b, err := os.ReadFile("./test/big-panda.gif") if err != nil { panic(err) } - return bytes.NewBuffer(b), int64(len(b)), nil + return io.NopCloser(bytes.NewBuffer(b)), int64(len(b)), nil } emojiID := "01GDQ9G782X42BAMFASKP64343" @@ -221,13 +221,13 @@ func (suite *ManagerTestSuite) TestEmojiProcessBlockingTooLarge() { func (suite *ManagerTestSuite) TestEmojiProcessBlockingTooLargeNoSizeGiven() { ctx := context.Background() - data := func(_ context.Context) (io.Reader, int64, error) { + data := func(_ context.Context) (io.ReadCloser, int64, error) { // load bytes from a test image b, err := os.ReadFile("./test/big-panda.gif") if err != nil { panic(err) } - return bytes.NewBuffer(b), int64(len(b)), nil + return io.NopCloser(bytes.NewBuffer(b)), int64(len(b)), nil } emojiID := "01GDQ9G782X42BAMFASKP64343" @@ -245,13 +245,13 @@ func (suite *ManagerTestSuite) TestEmojiProcessBlockingTooLargeNoSizeGiven() { func (suite *ManagerTestSuite) TestEmojiProcessBlockingNoFileSizeGiven() { ctx := context.Background() - data := func(_ context.Context) (io.Reader, int64, error) { + data := func(_ context.Context) (io.ReadCloser, int64, error) { // load bytes from a test image b, err := os.ReadFile("./test/rainbow-original.png") if err != nil { panic(err) } - return bytes.NewBuffer(b), -1, nil + return io.NopCloser(bytes.NewBuffer(b)), -1, nil } emojiID := "01GDQ9G782X42BAMFASKP64343" @@ -307,13 +307,13 @@ func (suite *ManagerTestSuite) TestEmojiProcessBlockingNoFileSizeGiven() { func (suite *ManagerTestSuite) TestSimpleJpegProcessBlocking() { ctx := context.Background() - data := func(_ context.Context) (io.Reader, int64, error) { + data := func(_ context.Context) (io.ReadCloser, int64, error) { // load bytes from a test image b, err := os.ReadFile("./test/test-jpeg.jpg") if err != nil { panic(err) } - return bytes.NewBuffer(b), int64(len(b)), nil + return io.NopCloser(bytes.NewBuffer(b)), int64(len(b)), nil } accountID := "01FS1X72SK9ZPW0J1QQ68BD264" @@ -379,14 +379,14 @@ func (suite *ManagerTestSuite) TestSimpleJpegProcessBlocking() { func (suite *ManagerTestSuite) TestSimpleJpegProcessBlockingNoContentLengthGiven() { ctx := context.Background() - data := func(_ context.Context) (io.Reader, int64, error) { + data := func(_ context.Context) (io.ReadCloser, int64, error) { // load bytes from a test image b, err := os.ReadFile("./test/test-jpeg.jpg") if err != nil { panic(err) } // give length as -1 to indicate unknown - return bytes.NewBuffer(b), -1, nil + return io.NopCloser(bytes.NewBuffer(b)), -1, nil } accountID := "01FS1X72SK9ZPW0J1QQ68BD264" @@ -452,7 +452,7 @@ func (suite *ManagerTestSuite) TestSimpleJpegProcessBlockingNoContentLengthGiven func (suite *ManagerTestSuite) TestSimpleJpegProcessBlockingReadCloser() { ctx := context.Background() - data := func(_ context.Context) (io.Reader, int64, error) { + data := func(_ context.Context) (io.ReadCloser, int64, error) { // open test image as a file f, err := os.Open("./test/test-jpeg.jpg") if err != nil { @@ -525,13 +525,13 @@ func (suite *ManagerTestSuite) TestSimpleJpegProcessBlockingReadCloser() { func (suite *ManagerTestSuite) TestPngNoAlphaChannelProcessBlocking() { ctx := context.Background() - data := func(_ context.Context) (io.Reader, int64, error) { + data := func(_ context.Context) (io.ReadCloser, int64, error) { // load bytes from a test image b, err := os.ReadFile("./test/test-png-noalphachannel.png") if err != nil { panic(err) } - return bytes.NewBuffer(b), int64(len(b)), nil + return io.NopCloser(bytes.NewBuffer(b)), int64(len(b)), nil } accountID := "01FS1X72SK9ZPW0J1QQ68BD264" @@ -597,13 +597,13 @@ func (suite *ManagerTestSuite) TestPngNoAlphaChannelProcessBlocking() { func (suite *ManagerTestSuite) TestPngAlphaChannelProcessBlocking() { ctx := context.Background() - data := func(_ context.Context) (io.Reader, int64, error) { + data := func(_ context.Context) (io.ReadCloser, int64, error) { // load bytes from a test image b, err := os.ReadFile("./test/test-png-alphachannel.png") if err != nil { panic(err) } - return bytes.NewBuffer(b), int64(len(b)), nil + return io.NopCloser(bytes.NewBuffer(b)), int64(len(b)), nil } accountID := "01FS1X72SK9ZPW0J1QQ68BD264" @@ -669,13 +669,13 @@ func (suite *ManagerTestSuite) TestPngAlphaChannelProcessBlocking() { func (suite *ManagerTestSuite) TestSimpleJpegProcessBlockingWithCallback() { ctx := context.Background() - data := func(_ context.Context) (io.Reader, int64, error) { + data := func(_ context.Context) (io.ReadCloser, int64, error) { // load bytes from a test image b, err := os.ReadFile("./test/test-jpeg.jpg") if err != nil { panic(err) } - return bytes.NewBuffer(b), int64(len(b)), nil + return io.NopCloser(bytes.NewBuffer(b)), int64(len(b)), nil } // test the callback function by setting a simple boolean @@ -752,13 +752,13 @@ func (suite *ManagerTestSuite) TestSimpleJpegProcessBlockingWithCallback() { func (suite *ManagerTestSuite) TestSimpleJpegProcessAsync() { ctx := context.Background() - data := func(_ context.Context) (io.Reader, int64, error) { + data := func(_ context.Context) (io.ReadCloser, int64, error) { // load bytes from a test image b, err := os.ReadFile("./test/test-jpeg.jpg") if err != nil { panic(err) } - return bytes.NewBuffer(b), int64(len(b)), nil + return io.NopCloser(bytes.NewBuffer(b)), int64(len(b)), nil } accountID := "01FS1X72SK9ZPW0J1QQ68BD264" @@ -837,9 +837,9 @@ func (suite *ManagerTestSuite) TestSimpleJpegQueueSpamming() { panic(err) } - data := func(_ context.Context) (io.Reader, int64, error) { + data := func(_ context.Context) (io.ReadCloser, int64, error) { // load bytes from a test image - return bytes.NewReader(b), int64(len(b)), nil + return io.NopCloser(bytes.NewReader(b)), int64(len(b)), nil } accountID := "01FS1X72SK9ZPW0J1QQ68BD264" @@ -913,13 +913,13 @@ func (suite *ManagerTestSuite) TestSimpleJpegQueueSpamming() { func (suite *ManagerTestSuite) TestSimpleJpegProcessBlockingWithDiskStorage() { ctx := context.Background() - data := func(_ context.Context) (io.Reader, int64, error) { + data := func(_ context.Context) (io.ReadCloser, int64, error) { // load bytes from a test image b, err := os.ReadFile("./test/test-jpeg.jpg") if err != nil { panic(err) } - return bytes.NewBuffer(b), int64(len(b)), nil + return io.NopCloser(bytes.NewBuffer(b)), int64(len(b)), nil } accountID := "01FS1X72SK9ZPW0J1QQ68BD264" diff --git a/internal/media/processingemoji.go b/internal/media/processingemoji.go index e1c6f2efb..79bc23998 100644 --- a/internal/media/processingemoji.go +++ b/internal/media/processingemoji.go @@ -193,24 +193,22 @@ func (p *ProcessingEmoji) store(ctx context.Context) error { return nil } - // execute the data function to get the reader out of it - reader, fileSize, err := p.data(ctx) + // execute the data function to get the readcloser out of it + rc, fileSize, err := p.data(ctx) if err != nil { return fmt.Errorf("store: error executing data function: %s", err) } // defer closing the reader when we're done with it defer func() { - if rc, ok := reader.(io.ReadCloser); ok { - if err := rc.Close(); err != nil { - log.Errorf("store: error closing readcloser: %s", err) - } + if err := rc.Close(); err != nil { + log.Errorf("store: error closing readcloser: %s", err) } }() // extract no more than 261 bytes from the beginning of the file -- this is the header firstBytes := make([]byte, maxFileHeaderBytes) - if _, err := reader.Read(firstBytes); err != nil { + if _, err := rc.Read(firstBytes); err != nil { return fmt.Errorf("store: error reading initial %d bytes: %s", maxFileHeaderBytes, err) } @@ -242,7 +240,7 @@ func (p *ProcessingEmoji) store(ctx context.Context) error { p.emoji.ImageContentType = contentType // concatenate the first bytes with the existing bytes still in the reader (thanks Mara) - readerToStore := io.MultiReader(bytes.NewBuffer(firstBytes), reader) + readerToStore := io.MultiReader(bytes.NewBuffer(firstBytes), rc) var maxEmojiSize int64 if p.emoji.Domain == "" { diff --git a/internal/media/processingmedia.go b/internal/media/processingmedia.go index 1f5a58b9f..573df1d0e 100644 --- a/internal/media/processingmedia.go +++ b/internal/media/processingmedia.go @@ -263,24 +263,31 @@ func (p *ProcessingMedia) store(ctx context.Context) error { return nil } - // execute the data function to get the reader out of it - reader, fileSize, err := p.data(ctx) + // execute the data function to get the readcloser out of it + rc, fileSize, err := p.data(ctx) if err != nil { return fmt.Errorf("store: error executing data function: %s", err) } // defer closing the reader when we're done with it defer func() { - if rc, ok := reader.(io.ReadCloser); ok { - if err := rc.Close(); err != nil { - log.Errorf("store: error closing readcloser: %s", err) + if err := rc.Close(); err != nil { + log.Errorf("store: error closing readcloser: %s", err) + } + }() + + // execute the postData function no matter what happens + defer func() { + if p.postData != nil { + if err := p.postData(ctx); err != nil { + log.Errorf("store: error executing postData: %s", err) } } }() // extract no more than 261 bytes from the beginning of the file -- this is the header firstBytes := make([]byte, maxFileHeaderBytes) - if _, err := reader.Read(firstBytes); err != nil { + if _, err := rc.Read(firstBytes); err != nil { return fmt.Errorf("store: error reading initial %d bytes: %s", maxFileHeaderBytes, err) } @@ -303,29 +310,36 @@ func (p *ProcessingMedia) store(ctx context.Context) error { extension := split[1] // something like 'jpeg' // concatenate the cleaned up first bytes with the existing bytes still in the reader (thanks Mara) - readerToStore := io.MultiReader(bytes.NewBuffer(firstBytes), reader) + multiReader := io.MultiReader(bytes.NewBuffer(firstBytes), rc) // use the extension to derive the attachment type // and, while we're in here, clean up exif data from // the image if we already know the fileSize + var readerToStore io.Reader switch extension { case mimeGif: p.attachment.Type = gtsmodel.FileTypeImage + // nothing to terminate, we can just store the multireader + readerToStore = multiReader case mimeJpeg, mimePng: p.attachment.Type = gtsmodel.FileTypeImage if fileSize > 0 { - var err error - readerToStore, err = terminator.Terminate(readerToStore, int(fileSize), extension) + terminated, err := terminator.Terminate(multiReader, int(fileSize), extension) if err != nil { return fmt.Errorf("store: exif error: %s", err) } defer func() { - if rc, ok := readerToStore.(io.ReadCloser); ok { - if err := rc.Close(); err != nil { + if closer, ok := terminated.(io.Closer); ok { + if err := closer.Close(); err != nil { log.Errorf("store: error closing terminator reader: %s", err) } } }() + // store the exif-terminated version of what was in the multireader + readerToStore = terminated + } else { + // can't terminate if we don't know the file size, so just store the multiReader + readerToStore = multiReader } default: return fmt.Errorf("store: couldn't process %s", extension) @@ -347,10 +361,6 @@ func (p *ProcessingMedia) store(ctx context.Context) error { p.attachment.File.FileSize = int(fileSize) p.read = true - if p.postData != nil { - return p.postData(ctx) - } - log.Tracef("store: finished storing initial data for attachment %s", p.attachment.URL) return nil } diff --git a/internal/media/pruneremote_test.go b/internal/media/pruneremote_test.go index e71d3310b..d3a01b7be 100644 --- a/internal/media/pruneremote_test.go +++ b/internal/media/pruneremote_test.go @@ -74,13 +74,13 @@ func (suite *PruneRemoteTestSuite) TestPruneAndRecache() { suite.ErrorIs(err, storage.ErrNotFound) // now recache the image.... - data := func(_ context.Context) (io.Reader, int64, error) { + data := func(_ context.Context) (io.ReadCloser, int64, error) { // load bytes from a test image b, err := os.ReadFile("../../testrig/media/thoughtsofdog-original.jpeg") if err != nil { panic(err) } - return bytes.NewBuffer(b), int64(len(b)), nil + return io.NopCloser(bytes.NewBuffer(b)), int64(len(b)), nil } processingRecache, err := suite.manager.RecacheMedia(ctx, data, nil, testAttachment.ID) suite.NoError(err) diff --git a/internal/media/types.go b/internal/media/types.go index 3238916b8..763a8137f 100644 --- a/internal/media/types.go +++ b/internal/media/types.go @@ -118,7 +118,7 @@ type AdditionalEmojiInfo struct { } // DataFunc represents a function used to retrieve the raw bytes of a piece of media. -type DataFunc func(ctx context.Context) (reader io.Reader, fileSize int64, err error) +type DataFunc func(ctx context.Context) (reader io.ReadCloser, fileSize int64, err error) // PostDataCallbackFunc represents a function executed after the DataFunc has been executed, // and the returned reader has been read. It can be used to clean up any remaining resources. diff --git a/internal/processing/account/update.go b/internal/processing/account/update.go index f39361c06..2ef3bfe25 100644 --- a/internal/processing/account/update.go +++ b/internal/processing/account/update.go @@ -192,7 +192,7 @@ func (p *processor) UpdateAvatar(ctx context.Context, avatar *multipart.FileHead return nil, fmt.Errorf("UpdateAvatar: avatar with size %d exceeded max image size of %d bytes", avatar.Size, maxImageSize) } - dataFunc := func(innerCtx context.Context) (io.Reader, int64, error) { + dataFunc := func(innerCtx context.Context) (io.ReadCloser, int64, error) { f, err := avatar.Open() return f, avatar.Size, err } @@ -219,7 +219,7 @@ func (p *processor) UpdateHeader(ctx context.Context, header *multipart.FileHead return nil, fmt.Errorf("UpdateHeader: header with size %d exceeded max image size of %d bytes", header.Size, maxImageSize) } - dataFunc := func(innerCtx context.Context) (io.Reader, int64, error) { + dataFunc := func(innerCtx context.Context) (io.ReadCloser, int64, error) { f, err := header.Open() return f, header.Size, err } diff --git a/internal/processing/admin/createemoji.go b/internal/processing/admin/createemoji.go index 93ae17496..a315e144e 100644 --- a/internal/processing/admin/createemoji.go +++ b/internal/processing/admin/createemoji.go @@ -52,7 +52,7 @@ func (p *processor) EmojiCreate(ctx context.Context, account *gtsmodel.Account, emojiURI := uris.GenerateURIForEmoji(emojiID) - data := func(innerCtx context.Context) (io.Reader, int64, error) { + data := func(innerCtx context.Context) (io.ReadCloser, int64, error) { f, err := form.Image.Open() return f, form.Image.Size, err } diff --git a/internal/processing/media/create.go b/internal/processing/media/create.go index eb0c251e9..451a77391 100644 --- a/internal/processing/media/create.go +++ b/internal/processing/media/create.go @@ -30,7 +30,7 @@ ) func (p *processor) Create(ctx context.Context, account *gtsmodel.Account, form *apimodel.AttachmentRequest) (*apimodel.Attachment, gtserror.WithCode) { - data := func(innerCtx context.Context) (io.Reader, int64, error) { + data := func(innerCtx context.Context) (io.ReadCloser, int64, error) { f, err := form.File.Open() return f, form.File.Size, err } diff --git a/internal/processing/media/getfile.go b/internal/processing/media/getfile.go index 693b8685b..522d47435 100644 --- a/internal/processing/media/getfile.go +++ b/internal/processing/media/getfile.go @@ -29,6 +29,7 @@ apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" "github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel" + "github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/media" "github.com/superseriousbusiness/gotosocial/internal/uris" ) @@ -139,7 +140,7 @@ func (p *processor) getAttachmentContent(ctx context.Context, requestingAccount // if it's the thumbnail that's requested then the user will have to wait a bit while we process the // large version and derive a thumbnail from it, so use the normal recaching procedure: fetch the media, // process it, then return the thumbnail data - data = func(innerCtx context.Context) (io.Reader, int64, error) { + data = func(innerCtx context.Context) (io.ReadCloser, int64, error) { transport, err := p.transportController.NewTransportForUsername(innerCtx, requestingUsername) if err != nil { return nil, 0, err @@ -168,9 +169,9 @@ func (p *processor) getAttachmentContent(ctx context.Context, requestingAccount bufferedReader := bufio.NewReaderSize(pipeReader, int(attachmentContent.ContentLength)) // the caller will read from the buffered reader, so it doesn't matter if they drop out without reading everything - attachmentContent.Content = bufferedReader + attachmentContent.Content = io.NopCloser(bufferedReader) - data = func(innerCtx context.Context) (io.Reader, int64, error) { + data = func(innerCtx context.Context) (io.ReadCloser, int64, error) { transport, err := p.transportController.NewTransportForUsername(innerCtx, requestingUsername) if err != nil { return nil, 0, err @@ -195,17 +196,15 @@ func (p *processor) getAttachmentContent(ctx context.Context, requestingAccount // close the pipewriter after data has been piped into it, so the reader on the other side doesn't block; // we don't need to close the reader here because that's the caller's responsibility postDataCallback = func(innerCtx context.Context) error { - // flush the buffered writer into the buffer of the reader... - if err := bufferedWriter.Flush(); err != nil { - return err - } + // close the underlying pipe writer when we're done with it + defer func() { + if err := pipeWriter.Close(); err != nil { + log.Errorf("getAttachmentContent: error closing pipeWriter: %s", err) + } + }() - // and close the underlying pipe writer - if err := pipeWriter.Close(); err != nil { - return err - } - - return nil + // and flush the buffered writer into the buffer of the reader + return bufferedWriter.Flush() } } diff --git a/internal/processing/media/getfile_test.go b/internal/processing/media/getfile_test.go index 6e5271607..ba7269535 100644 --- a/internal/processing/media/getfile_test.go +++ b/internal/processing/media/getfile_test.go @@ -91,10 +91,7 @@ func (suite *GetFileTestSuite) TestGetRemoteFileUncached() { suite.NotNil(content) b, err := io.ReadAll(content.Content) suite.NoError(err) - - if closer, ok := content.Content.(io.Closer); ok { - suite.NoError(closer.Close()) - } + suite.NoError(content.Content.Close()) suite.Equal(suite.testRemoteAttachments[testAttachment.RemoteURL].Data, b) suite.Equal(suite.testRemoteAttachments[testAttachment.RemoteURL].ContentType, content.ContentType) @@ -151,9 +148,7 @@ func (suite *GetFileTestSuite) TestGetRemoteFileUncachedInterrupted() { suite.NoError(err) // close the reader - if closer, ok := content.Content.(io.Closer); ok { - suite.NoError(closer.Close()) - } + suite.NoError(content.Content.Close()) // the attachment should still be updated in the database even though the caller hung up if !testrig.WaitFor(func() bool { @@ -201,10 +196,7 @@ func (suite *GetFileTestSuite) TestGetRemoteFileThumbnailUncached() { suite.NotNil(content) b, err := io.ReadAll(content.Content) suite.NoError(err) - - if closer, ok := content.Content.(io.Closer); ok { - suite.NoError(closer.Close()) - } + suite.NoError(content.Content.Close()) suite.Equal(thumbnailBytes, b) suite.Equal("image/jpeg", content.ContentType)