[bugfix] notification types missing from link header (#3571)

* ensure notification types get included in link header query for notifications

* fix type query keys
This commit is contained in:
kim 2024-11-25 15:33:21 +00:00 committed by GitHub
parent c454b1b488
commit a444adee97
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 87 additions and 103 deletions

View file

@ -18,14 +18,13 @@
package notifications package notifications
import ( import (
"fmt"
"net/http" "net/http"
"strconv"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util" apiutil "github.com/superseriousbusiness/gotosocial/internal/api/util"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/paging"
) )
// NotificationsGETHandler swagger:operation GET /api/v1/notifications notifications // NotificationsGETHandler swagger:operation GET /api/v1/notifications notifications
@ -152,18 +151,6 @@ func (m *Module) NotificationsGETHandler(c *gin.Context) {
return return
} }
limit := 20
limitString := c.Query(LimitKey)
if limitString != "" {
i, err := strconv.ParseInt(limitString, 10, 32)
if err != nil {
err := fmt.Errorf("error parsing %s: %s", LimitKey, err)
apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
return
}
limit = int(i)
}
types, errWithCode := apiutil.ParseNotificationTypes(c.QueryArray(TypesKey)) types, errWithCode := apiutil.ParseNotificationTypes(c.QueryArray(TypesKey))
if errWithCode != nil { if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1) apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
@ -176,13 +163,20 @@ func (m *Module) NotificationsGETHandler(c *gin.Context) {
return return
} }
page, errWithCode := paging.ParseIDPage(c,
1, // min limit
80, // max limit
20, // no limit
)
if errWithCode != nil {
apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
return
}
resp, errWithCode := m.processor.Timeline().NotificationsGet( resp, errWithCode := m.processor.Timeline().NotificationsGet(
c.Request.Context(), c.Request.Context(),
authed, authed,
c.Query(MaxIDKey), page,
c.Query(SinceIDKey),
c.Query(MinIDKey),
limit,
types, types,
exclTypes, exclTypes,
) )

View file

@ -26,8 +26,8 @@
"github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/paging"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices" "github.com/superseriousbusiness/gotosocial/internal/util/xslices"
"github.com/uptrace/bun" "github.com/uptrace/bun"
@ -192,22 +192,19 @@ func (n *notificationDB) PopulateNotification(ctx context.Context, notif *gtsmod
func (n *notificationDB) GetAccountNotifications( func (n *notificationDB) GetAccountNotifications(
ctx context.Context, ctx context.Context,
accountID string, accountID string,
maxID string, page *paging.Page,
sinceID string,
minID string,
limit int,
types []gtsmodel.NotificationType, types []gtsmodel.NotificationType,
excludeTypes []gtsmodel.NotificationType, excludeTypes []gtsmodel.NotificationType,
) ([]*gtsmodel.Notification, error) { ) ([]*gtsmodel.Notification, error) {
// Ensure reasonable
if limit < 0 {
limit = 0
}
// Make educated guess for slice size
var ( var (
notifIDs = make([]string, 0, limit) // Get paging params.
frontToBack = true minID = page.GetMin()
maxID = page.GetMax()
limit = page.GetLimit()
order = page.GetOrder()
// Make educated guess for slice size
notifIDs = make([]string, 0, limit)
) )
q := n.db. q := n.db.
@ -215,23 +212,14 @@ func (n *notificationDB) GetAccountNotifications(
TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")). TableExpr("? AS ?", bun.Ident("notifications"), bun.Ident("notification")).
Column("notification.id") Column("notification.id")
if maxID == "" { if maxID != "" {
maxID = id.Highest // Return only notifs LOWER (ie., older) than maxID.
} q = q.Where("? < ?", bun.Ident("notification.id"), maxID)
// Return only notifs LOWER (ie., older) than maxID.
q = q.Where("? < ?", bun.Ident("notification.id"), maxID)
if sinceID != "" {
// Return only notifs HIGHER (ie., newer) than sinceID.
q = q.Where("? > ?", bun.Ident("notification.id"), sinceID)
} }
if minID != "" { if minID != "" {
// Return only notifs HIGHER (ie., newer) than minID. // Return only notifs HIGHER (ie., newer) than minID.
q = q.Where("? > ?", bun.Ident("notification.id"), minID) q = q.Where("? > ?", bun.Ident("notification.id"), minID)
frontToBack = false // page up
} }
if len(types) > 0 { if len(types) > 0 {
@ -251,12 +239,12 @@ func (n *notificationDB) GetAccountNotifications(
q = q.Limit(limit) q = q.Limit(limit)
} }
if frontToBack { if order == paging.OrderAscending {
// Page down.
q = q.Order("notification.id DESC")
} else {
// Page up. // Page up.
q = q.Order("notification.id ASC") q = q.Order("notification.id ASC")
} else {
// Page down.
q = q.Order("notification.id DESC")
} }
if err := q.Scan(ctx, &notifIDs); err != nil { if err := q.Scan(ctx, &notifIDs); err != nil {
@ -269,11 +257,8 @@ func (n *notificationDB) GetAccountNotifications(
// If we're paging up, we still want notifications // If we're paging up, we still want notifications
// to be sorted by ID desc, so reverse ids slice. // to be sorted by ID desc, so reverse ids slice.
// https://zchee.github.io/golang-wiki/SliceTricks/#reversing if order == paging.OrderAscending {
if !frontToBack { slices.Reverse(notifIDs)
for l, r := 0, len(notifIDs)-1; l < r; l, r = l+1, r-1 {
notifIDs[l], notifIDs[r] = notifIDs[r], notifIDs[l]
}
} }
// Fetch notification models by their IDs. // Fetch notification models by their IDs.

View file

@ -28,6 +28,7 @@
"github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/paging"
"github.com/superseriousbusiness/gotosocial/internal/util" "github.com/superseriousbusiness/gotosocial/internal/util"
) )
@ -92,10 +93,11 @@ func (suite *NotificationTestSuite) TestGetAccountNotificationsWithSpam() {
notifications, err := suite.db.GetAccountNotifications( notifications, err := suite.db.GetAccountNotifications(
gtscontext.SetBarebones(context.Background()), gtscontext.SetBarebones(context.Background()),
testAccount.ID, testAccount.ID,
id.Highest, &paging.Page{
id.Lowest, Min: paging.EitherMinID("", id.Lowest),
"", Max: paging.MaxID(id.Highest),
20, Limit: 20,
},
nil, nil,
nil, nil,
) )
@ -115,10 +117,11 @@ func (suite *NotificationTestSuite) TestGetAccountNotificationsWithoutSpam() {
notifications, err := suite.db.GetAccountNotifications( notifications, err := suite.db.GetAccountNotifications(
gtscontext.SetBarebones(context.Background()), gtscontext.SetBarebones(context.Background()),
testAccount.ID, testAccount.ID,
id.Highest, &paging.Page{
id.Lowest, Min: paging.EitherMinID("", id.Lowest),
"", Max: paging.MaxID(id.Highest),
20, Limit: 20,
},
nil, nil,
nil, nil,
) )
@ -140,10 +143,11 @@ func (suite *NotificationTestSuite) TestDeleteNotificationsWithSpam() {
notifications, err := suite.db.GetAccountNotifications( notifications, err := suite.db.GetAccountNotifications(
gtscontext.SetBarebones(context.Background()), gtscontext.SetBarebones(context.Background()),
testAccount.ID, testAccount.ID,
id.Highest, &paging.Page{
id.Lowest, Min: paging.EitherMinID("", id.Lowest),
"", Max: paging.MaxID(id.Highest),
20, Limit: 20,
},
nil, nil,
nil, nil,
) )
@ -161,10 +165,11 @@ func (suite *NotificationTestSuite) TestDeleteNotificationsWithSpam() {
notifications, err = suite.db.GetAccountNotifications( notifications, err = suite.db.GetAccountNotifications(
gtscontext.SetBarebones(context.Background()), gtscontext.SetBarebones(context.Background()),
testAccount.ID, testAccount.ID,
id.Highest, &paging.Page{
id.Lowest, Min: paging.EitherMinID("", id.Lowest),
"", Max: paging.MaxID(id.Highest),
20, Limit: 20,
},
nil, nil,
nil, nil,
) )
@ -183,10 +188,11 @@ func (suite *NotificationTestSuite) TestDeleteNotificationsWithTwoAccounts() {
notifications, err := suite.db.GetAccountNotifications( notifications, err := suite.db.GetAccountNotifications(
gtscontext.SetBarebones(context.Background()), gtscontext.SetBarebones(context.Background()),
testAccount.ID, testAccount.ID,
id.Highest, &paging.Page{
id.Lowest, Min: paging.EitherMinID("", id.Lowest),
"", Max: paging.MaxID(id.Highest),
20, Limit: 20,
},
nil, nil,
nil, nil,
) )

View file

@ -21,6 +21,7 @@
"context" "context"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/paging"
) )
// Notification contains functions for creating and getting notifications. // Notification contains functions for creating and getting notifications.
@ -29,7 +30,7 @@ type Notification interface {
// //
// Returned notifications will be ordered ID descending (ie., highest/newest to lowest/oldest). // Returned notifications will be ordered ID descending (ie., highest/newest to lowest/oldest).
// If types is empty, *all* notification types will be included. // If types is empty, *all* notification types will be included.
GetAccountNotifications(ctx context.Context, accountID string, maxID string, sinceID string, minID string, limit int, types []gtsmodel.NotificationType, excludeTypes []gtsmodel.NotificationType) ([]*gtsmodel.Notification, error) GetAccountNotifications(ctx context.Context, accountID string, page *paging.Page, types []gtsmodel.NotificationType, excludeTypes []gtsmodel.NotificationType) ([]*gtsmodel.Notification, error)
// GetNotificationByID returns one notification according to its id. // GetNotificationByID returns one notification according to its id.
GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, error) GetNotificationByID(ctx context.Context, id string) (*gtsmodel.Notification, error)

View file

@ -21,6 +21,7 @@
"context" "context"
"errors" "errors"
"fmt" "fmt"
"net/url"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
@ -31,26 +32,21 @@
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/oauth" "github.com/superseriousbusiness/gotosocial/internal/oauth"
"github.com/superseriousbusiness/gotosocial/internal/paging"
"github.com/superseriousbusiness/gotosocial/internal/util" "github.com/superseriousbusiness/gotosocial/internal/util"
) )
func (p *Processor) NotificationsGet( func (p *Processor) NotificationsGet(
ctx context.Context, ctx context.Context,
authed *oauth.Auth, authed *oauth.Auth,
maxID string, page *paging.Page,
sinceID string,
minID string,
limit int,
types []gtsmodel.NotificationType, types []gtsmodel.NotificationType,
excludeTypes []gtsmodel.NotificationType, excludeTypes []gtsmodel.NotificationType,
) (*apimodel.PageableResponse, gtserror.WithCode) { ) (*apimodel.PageableResponse, gtserror.WithCode) {
notifs, err := p.state.DB.GetAccountNotifications( notifs, err := p.state.DB.GetAccountNotifications(
ctx, ctx,
authed.Account.ID, authed.Account.ID,
maxID, page,
sinceID,
minID,
limit,
types, types,
excludeTypes, excludeTypes,
) )
@ -78,22 +74,15 @@ func (p *Processor) NotificationsGet(
compiledMutes := usermute.NewCompiledUserMuteList(mutes) compiledMutes := usermute.NewCompiledUserMuteList(mutes)
var ( var (
items = make([]interface{}, 0, count) items = make([]interface{}, 0, count)
nextMaxIDValue string
prevMinIDValue string // Get the lowest and highest
// ID values, used for paging.
lo = notifs[count-1].ID
hi = notifs[0].ID
) )
for i, n := range notifs { for _, n := range notifs {
// Set next + prev values before filtering and API
// converting, so caller can still page properly.
if i == count-1 {
nextMaxIDValue = n.ID
}
if i == 0 {
prevMinIDValue = n.ID
}
visible, err := p.notifVisible(ctx, n, authed.Account) visible, err := p.notifVisible(ctx, n, authed.Account)
if err != nil { if err != nil {
log.Debugf(ctx, "skipping notification %s because of an error checking notification visibility: %v", n.ID, err) log.Debugf(ctx, "skipping notification %s because of an error checking notification visibility: %v", n.ID, err)
@ -115,13 +104,22 @@ func (p *Processor) NotificationsGet(
items = append(items, item) items = append(items, item)
} }
return util.PackagePageableResponse(util.PageableResponseParams{ // Build type query string.
Items: items, query := make(url.Values)
Path: "api/v1/notifications", for _, typ := range types {
NextMaxIDValue: nextMaxIDValue, query.Add("types[]", typ.String())
PrevMinIDValue: prevMinIDValue, }
Limit: limit, for _, typ := range excludeTypes {
}) query.Add("exclude_types[]", typ.String())
}
return paging.PackageResponse(paging.ResponseParams{
Items: items,
Path: "/api/v1/notifications",
Next: page.Next(lo, hi),
Prev: page.Prev(lo, hi),
Query: query,
}), nil
} }
func (p *Processor) NotificationGet(ctx context.Context, account *gtsmodel.Account, targetNotifID string) (*apimodel.Notification, gtserror.WithCode) { func (p *Processor) NotificationGet(ctx context.Context, account *gtsmodel.Account, targetNotifID string) (*apimodel.Notification, gtserror.WithCode) {

View file

@ -89,7 +89,7 @@ func (suite *SurfaceNotifyTestSuite) TestSpamNotifs() {
notifs, err := testStructs.State.DB.GetAccountNotifications( notifs, err := testStructs.State.DB.GetAccountNotifications(
gtscontext.SetBarebones(ctx), gtscontext.SetBarebones(ctx),
targetAccount.ID, targetAccount.ID,
"", "", "", 0, nil, nil, nil, nil, nil,
) )
if err != nil { if err != nil {
suite.FailNow(err.Error()) suite.FailNow(err.Error())