start working on multiple tag names for tag timeline

This commit is contained in:
tsmethurst 2023-10-27 15:24:21 +02:00
parent 0b978f2c56
commit 10eed3c80d
6 changed files with 51 additions and 27 deletions

View file

@ -125,10 +125,17 @@ func (m *Module) TagTimelineGETHandler(c *gin.Context) {
return return
} }
// Append any additional tags
// passed as `any[]` parameter.
tagNames := append(
[]string{tagName},
c.QueryArray(apiutil.TagAnyKey)...,
)
resp, errWithCode := m.processor.Timeline().TagTimelineGet( resp, errWithCode := m.processor.Timeline().TagTimelineGet(
c.Request.Context(), c.Request.Context(),
authed.Account, authed.Account,
tagName, tagNames,
c.Query(apiutil.MaxIDKey), c.Query(apiutil.MaxIDKey),
c.Query(apiutil.SinceIDKey), c.Query(apiutil.SinceIDKey),
c.Query(apiutil.MinIDKey), c.Query(apiutil.MinIDKey),

View file

@ -54,6 +54,7 @@
/* Tag keys */ /* Tag keys */
TagNameKey = "tag_name" TagNameKey = "tag_name"
TagAnyKey = "any[]"
/* Web endpoint keys */ /* Web endpoint keys */

View file

@ -463,7 +463,7 @@ func (t *timelineDB) GetListTimeline(
func (t *timelineDB) GetTagTimeline( func (t *timelineDB) GetTagTimeline(
ctx context.Context, ctx context.Context,
tagID string, tagIDs []string,
maxID string, maxID string,
sinceID string, sinceID string,
minID string, minID string,
@ -492,8 +492,8 @@ func (t *timelineDB) GetTagTimeline(
). ).
// Public only. // Public only.
Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic). Where("? = ?", bun.Ident("status.visibility"), gtsmodel.VisibilityPublic).
// This tag only. // Provided tag IDs only.
Where("? = ?", bun.Ident("status_to_tag.tag_id"), tagID) Where("? IN (?)", bun.Ident("status_to_tag.tag_id"), bun.In(tagIDs))
if maxID == "" || maxID >= id.Highest { if maxID == "" || maxID >= id.Highest {
const future = 24 * time.Hour const future = 24 * time.Hour

View file

@ -311,7 +311,7 @@ func (suite *TimelineTestSuite) TestGetTagTimelineNoParams() {
tag = suite.testTags["welcome"] tag = suite.testTags["welcome"]
) )
s, err := suite.db.GetTagTimeline(ctx, tag.ID, "", "", "", 1) s, err := suite.db.GetTagTimeline(ctx, []string{tag.ID}, "", "", "", 1)
if err != nil { if err != nil {
suite.FailNow(err.Error()) suite.FailNow(err.Error())
} }

View file

@ -49,7 +49,7 @@ type Timeline interface {
// Statuses should be returned in descending order of when they were created (newest first). // Statuses should be returned in descending order of when they were created (newest first).
GetListTimeline(ctx context.Context, listID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.Status, error) GetListTimeline(ctx context.Context, listID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.Status, error)
// GetTagTimeline returns a slice of public-visibility statuses that use the given tagID. // GetTagTimeline returns a slice of public-visibility statuses that use the given tagIDs.
// Statuses should be returned in descending order of when they were created (newest first). // Statuses should be returned in descending order of when they were created (newest first).
GetTagTimeline(ctx context.Context, tagID string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.Status, error) GetTagTimeline(ctx context.Context, tagIDs []string, maxID string, sinceID string, minID string, limit int) ([]*gtsmodel.Status, error)
} }

View file

@ -32,18 +32,20 @@
) )
// TagTimelineGet gets a pageable timeline for the given // TagTimelineGet gets a pageable timeline for the given
// tagName and given paging parameters. It will ensure // tagNames and given paging parameters. It will ensure
// that each status in the timeline is actually visible // that each status in the timeline is actually visible
// to requestingAcct before returning it. // to requestingAcct before returning it.
func (p *Processor) TagTimelineGet( func (p *Processor) TagTimelineGet(
ctx context.Context, ctx context.Context,
requestingAcct *gtsmodel.Account, requestingAcct *gtsmodel.Account,
tagName string, tagNames []string,
maxID string, maxID string,
sinceID string, sinceID string,
minID string, minID string,
limit int, limit int,
) (*apimodel.PageableResponse, gtserror.WithCode) { ) (*apimodel.PageableResponse, gtserror.WithCode) {
tagIDs := make([]string, 0, len(tagNames))
for _, tagName := range tagNames {
tag, errWithCode := p.getTag(ctx, tagName) tag, errWithCode := p.getTag(ctx, tagName)
if errWithCode != nil { if errWithCode != nil {
return nil, errWithCode return nil, errWithCode
@ -55,7 +57,10 @@ func (p *Processor) TagTimelineGet(
return nil, gtserror.NewErrorNotFound(err, err.Error()) return nil, gtserror.NewErrorNotFound(err, err.Error())
} }
statuses, err := p.state.DB.GetTagTimeline(ctx, tag.ID, maxID, sinceID, minID, limit) tagIDs = append(tagIDs, tag.ID)
}
statuses, err := p.state.DB.GetTagTimeline(ctx, tagIDs, maxID, sinceID, minID, limit)
if err != nil && !errors.Is(err, db.ErrNoEntries) { if err != nil && !errors.Is(err, db.ErrNoEntries) {
err = gtserror.Newf("db error getting statuses: %w", err) err = gtserror.Newf("db error getting statuses: %w", err)
return nil, gtserror.NewErrorInternalError(err) return nil, gtserror.NewErrorInternalError(err)
@ -66,8 +71,7 @@ func (p *Processor) TagTimelineGet(
requestingAcct, requestingAcct,
statuses, statuses,
limit, limit,
// Use API URL for tag. tagNames,
"/api/v1/timelines/tag/"+tagName,
) )
} }
@ -95,7 +99,7 @@ func (p *Processor) packageTagResponse(
requestingAcct *gtsmodel.Account, requestingAcct *gtsmodel.Account,
statuses []*gtsmodel.Status, statuses []*gtsmodel.Status,
limit int, limit int,
requestPath string, tagNames []string,
) (*apimodel.PageableResponse, gtserror.WithCode) { ) (*apimodel.PageableResponse, gtserror.WithCode) {
count := len(statuses) count := len(statuses)
if count == 0 { if count == 0 {
@ -131,11 +135,23 @@ func (p *Processor) packageTagResponse(
items = append(items, apiStatus) items = append(items, apiStatus)
} }
// Use first / "primary" tag for API endpoint.
path := "/api/v1/timelines/tag/" + tagNames[0]
// Add any additional tags.
var extraQueryParams []string
if len(tagNames) > 1 {
for _, tagName := range tagNames[1:] {
extraQueryParams = append(extraQueryParams, "any[]="+tagName)
}
}
return util.PackagePageableResponse(util.PageableResponseParams{ return util.PackagePageableResponse(util.PageableResponseParams{
Items: items, Items: items,
Path: requestPath, Path: path,
NextMaxIDValue: nextMaxIDValue, NextMaxIDValue: nextMaxIDValue,
PrevMinIDValue: prevMinIDValue, PrevMinIDValue: prevMinIDValue,
Limit: limit, Limit: limit,
ExtraQueryParams: extraQueryParams,
}) })
} }