Database updates (#144)

* start moving some database stuff around

* continue moving db stuff around

* more fiddling

* more updates

* and some more

* and yet more

* i broke SOMETHING but what, it's a mystery

* tidy up

* vendor ttlcache

* use ttlcache

* fix up some tests

* rename some stuff

* little reminder

* some more updates
This commit is contained in:
tobi 2021-08-20 12:26:56 +02:00 committed by GitHub
parent ce190d867c
commit 4920229a3b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
164 changed files with 4850 additions and 2617 deletions

View file

@ -137,6 +137,7 @@ The following libraries and frameworks are used by GoToSocial, with gratitude
* [mvdan/xurls](https://github.com/mvdan/xurls); URL parsing regular expressions. [BSD-3-Clause License](https://spdx.org/licenses/BSD-3-Clause.html).
* [nfnt/resize](https://github.com/nfnt/resize); convenient image resizing. [ISC License](https://spdx.org/licenses/ISC.html).
* [oklog/ulid](https://github.com/oklog/ulid); sequential, database-friendly ID generation. [Apache-2.0 License](https://spdx.org/licenses/Apache-2.0.html).
* [ReneKroon/ttlcache](https://github.com/ReneKroon/ttlcache); in-memory caching. [MIT License](https://spdx.org/licenses/MIT.html).
* [russross/blackfriday](https://github.com/russross/blackfriday); markdown parsing for statuses. [Simplified BSD License](https://spdx.org/licenses/BSD-2-Clause.html).
* [sirupsen/logrus](https://github.com/sirupsen/logrus); logging. [MIT License](https://spdx.org/licenses/MIT.html).
* [stretchr/testify](https://github.com/stretchr/testify); test framework. [MIT License](https://spdx.org/licenses/MIT.html).

View file

@ -1679,7 +1679,7 @@ info:
name: AGPL3
url: https://www.gnu.org/licenses/agpl-3.0.en.html
title: GoToSocial
version: 0.1.0-SNAPSHOT-dereference_remote_replies
version: 0.1.0-SNAPSHOT
paths:
/api/v1/accounts:
post:
@ -3404,6 +3404,8 @@ paths:
description: ""
schema:
$ref: '#/definitions/swaggerStatusRepliesCollection'
"400":
description: bad request
"401":
description: unauthorized
"403":

1
go.mod
View file

@ -3,6 +3,7 @@ module github.com/superseriousbusiness/gotosocial
go 1.16
require (
github.com/ReneKroon/ttlcache v1.7.0
github.com/buckket/go-blurhash v1.1.0
github.com/coreos/go-oidc/v3 v3.0.0
github.com/cpuguy83/go-md2man/v2 v2.0.1 // indirect

4
go.sum
View file

@ -33,6 +33,8 @@ cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/ReneKroon/ttlcache v1.7.0 h1:8BkjFfrzVFXyrqnMtezAaJ6AHPSsVV10m6w28N/Fgkk=
github.com/ReneKroon/ttlcache v1.7.0/go.mod h1:8BGGzdumrIjWxdRx8zpK6L3oGMWvIXdvB2GD1cfvd+I=
github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU=
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
github.com/andybalholm/brotli v1.0.0 h1:7UCwP93aiSfvWpapti8g88vVVGp2qqtGyePsSuDafo4=
@ -425,6 +427,8 @@ go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opentelemetry.io/otel v0.13.0/go.mod h1:dlSNewoRYikTkotEnxdmuBHgzT+k/idJSfDv/FxEnOY=
go.uber.org/goleak v0.10.0 h1:G3eWbSNIskeRqtsN/1uI5B+eP73y3JUuBsv9AZjehb4=
go.uber.org/goleak v0.10.0/go.mod h1:VCZuO8V8mFPlL0F5J5GK1rtHV3DrFcQ1R8ryq7FK0aI=
golang.org/x/crypto v0.0.0-20180527072434-ab813273cd59/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20180910181607-0e37d006457b/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=

View file

@ -581,7 +581,7 @@ func ExtractMention(i Mentionable) (*gtsmodel.Mention, error) {
if hrefProp == nil || !hrefProp.IsIRI() {
return nil, errors.New("no href prop")
}
mention.MentionedAccountURI = hrefProp.GetIRI().String()
mention.TargetAccountURI = hrefProp.GetIRI().String()
return mention, nil
}

View file

@ -116,7 +116,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
return user, nil
}
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
// we have an actual error in the database
return nil, fmt.Errorf("error checking database for email %s: %s", claims.Email, err)
}
@ -128,7 +128,7 @@ func (m *Module) parseUserFromClaims(claims *oidc.Claims, ip net.IP, appID strin
return nil, fmt.Errorf("user with email address %s is unconfirmed", claims.Email)
}
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
// we have an actual error in the database
return nil, fmt.Errorf("error checking database for email %s: %s", claims.Email, err)
}

View file

@ -6,8 +6,6 @@
"github.com/gin-gonic/gin"
"github.com/go-fed/httpsig"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
@ -33,13 +31,13 @@ func (m *Module) SignatureCheck(c *gin.Context) {
// we managed to parse the url!
// if the domain is blocked we want to bail as early as possible
blockedDomain, err := m.blockedDomain(requestingPublicKeyID.Host)
blocked, err := m.db.IsURIBlocked(requestingPublicKeyID)
if err != nil {
l.Errorf("could not tell if domain %s was blocked or not: %s", requestingPublicKeyID.Host, err)
c.AbortWithStatus(http.StatusInternalServerError)
return
}
if blockedDomain {
if blocked {
l.Infof("domain %s is blocked", requestingPublicKeyID.Host)
c.AbortWithStatus(http.StatusForbidden)
return
@ -50,20 +48,3 @@ func (m *Module) SignatureCheck(c *gin.Context) {
}
}
}
func (m *Module) blockedDomain(host string) (bool, error) {
b := &gtsmodel.DomainBlock{}
err := m.db.GetWhere([]db.Where{{Key: "domain", Value: host, CaseInsensitive: true}}, b)
if err == nil {
// block exists
return true, nil
}
if _, ok := err.(db.ErrNoEntries); ok {
// there are no entries so there's no block
return false, nil
}
// there's an actual error
return false, err
}

View file

@ -18,8 +18,28 @@
package cache
import (
"time"
"github.com/ReneKroon/ttlcache"
)
// Cache defines an in-memory cache that is safe to be wiped when the application is restarted
type Cache interface {
Store(k string, v interface{}) error
Fetch(k string) (interface{}, error)
}
type cache struct {
c *ttlcache.Cache
}
// New returns a new in-memory cache.
func New() Cache {
c := ttlcache.NewCache()
c.SetTTL(30 * time.Second)
cache := &cache{
c: c,
}
return cache
}

27
internal/cache/error.go vendored Normal file
View file

@ -0,0 +1,27 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package cache
import "errors"
// Error models an error returned by the in-memory cache.
type Error error
// ErrNotFound means that a value for the requested key was not found in the cache.
var ErrNotFound = errors.New("value not found in cache")

View file

@ -16,18 +16,13 @@
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
package cache
import (
"strings"
"github.com/superseriousbusiness/gotosocial/internal/db"
)
func (ps *postgresService) Put(i interface{}) error {
_, err := ps.conn.Model(i).Insert(i)
if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
return db.ErrAlreadyExists{}
func (c *cache) Fetch(k string) (interface{}, error) {
i, stored := c.c.Get(k)
if !stored {
return nil, ErrNotFound
}
return err
return i, nil
}

24
internal/cache/store.go vendored Normal file
View file

@ -0,0 +1,24 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package cache
func (c *cache) Store(k string, v interface{}) error {
c.c.Set(k, v)
return nil
}

View file

@ -88,8 +88,8 @@
return err
}
a := &gtsmodel.Account{}
if err := dbConn.GetLocalAccountByUsername(username, a); err != nil {
a, err := dbConn.GetLocalAccountByUsername(username)
if err != nil {
return err
}
@ -123,8 +123,8 @@
return err
}
a := &gtsmodel.Account{}
if err := dbConn.GetLocalAccountByUsername(username, a); err != nil {
a, err := dbConn.GetLocalAccountByUsername(username)
if err != nil {
return err
}
@ -155,8 +155,8 @@
return err
}
a := &gtsmodel.Account{}
if err := dbConn.GetLocalAccountByUsername(username, a); err != nil {
a, err := dbConn.GetLocalAccountByUsername(username)
if err != nil {
return err
}
@ -187,8 +187,8 @@
return err
}
a := &gtsmodel.Account{}
if err := dbConn.GetLocalAccountByUsername(username, a); err != nil {
a, err := dbConn.GetLocalAccountByUsername(username)
if err != nil {
return err
}
@ -233,8 +233,8 @@
return err
}
a := &gtsmodel.Account{}
if err := dbConn.GetLocalAccountByUsername(username, a); err != nil {
a, err := dbConn.GetLocalAccountByUsername(username)
if err != nil {
return err
}

View file

@ -62,6 +62,8 @@
&gtsmodel.MediaAttachment{},
&gtsmodel.Mention{},
&gtsmodel.Status{},
&gtsmodel.StatusToEmoji{},
&gtsmodel.StatusToTag{},
&gtsmodel.StatusFave{},
&gtsmodel.StatusBookmark{},
&gtsmodel.StatusMute{},

66
internal/db/account.go Normal file
View file

@ -0,0 +1,66 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import (
"time"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// Account contains functions related to account getting/setting/creation.
type Account interface {
// GetAccountByID returns one account with the given ID, or an error if something goes wrong.
GetAccountByID(id string) (*gtsmodel.Account, Error)
// GetAccountByURI returns one account with the given URI, or an error if something goes wrong.
GetAccountByURI(uri string) (*gtsmodel.Account, Error)
// GetAccountByURL returns one account with the given URL, or an error if something goes wrong.
GetAccountByURL(uri string) (*gtsmodel.Account, Error)
// GetLocalAccountByUsername returns an account on this instance by its username.
GetLocalAccountByUsername(username string) (*gtsmodel.Account, Error)
// GetAccountFaves fetches faves/likes created by the target accountID.
GetAccountFaves(accountID string) ([]*gtsmodel.StatusFave, Error)
// GetAccountStatusesCount is a shortcut for the common action of counting statuses produced by accountID.
CountAccountStatuses(accountID string) (int, Error)
// GetAccountStatuses is a shortcut for getting the most recent statuses. accountID is optional, if not provided
// then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can
// be very memory intensive so you probably shouldn't do this!
// In case of no entries, a 'no entries' error will be returned
GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, Error)
GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, Error)
// GetAccountLastPosted simply gets the timestamp of the most recent post by the account.
//
// The returned time will be zero if account has never posted anything.
GetAccountLastPosted(accountID string) (time.Time, Error)
// SetAccountHeaderOrAvatar sets the header or avatar for the given accountID to the given media attachment.
SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) Error
// GetInstanceAccount returns the instance account for the given domain.
// If domain is empty, this instance account will be returned.
GetInstanceAccount(domain string) (*gtsmodel.Account, Error)
}

53
internal/db/admin.go Normal file
View file

@ -0,0 +1,53 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import (
"net"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
// Admin contains functions related to instance administration (new signups etc).
type Admin interface {
// IsUsernameAvailable checks whether a given username is available on our domain.
// Returns an error if the username is already taken, or something went wrong in the db.
IsUsernameAvailable(username string) Error
// IsEmailAvailable checks whether a given email address for a new account is available to be used on our domain.
// Return an error if:
// A) the email is already associated with an account
// B) we block signups from this email domain
// C) something went wrong in the db
IsEmailAvailable(email string) Error
// NewSignup creates a new user in the database with the given parameters.
// By the time this function is called, it should be assumed that all the parameters have passed validation!
NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, Error)
// CreateInstanceAccount creates an account in the database with the same username as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance user will have a username of 'example.org'.
// This is needed for things like serving files that belong to the instance and not an individual user/account.
CreateInstanceAccount() Error
// CreateInstanceInstance creates an instance in the database with the same domain as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance will have a domain of 'example.org'.
// This is needed for things like serving instance information through /api/v1/instance
CreateInstanceInstance() Error
}

87
internal/db/basic.go Normal file
View file

@ -0,0 +1,87 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import "context"
// Basic wraps basic database functionality.
type Basic interface {
// CreateTable creates a table for the given interface.
// For implementations that don't use tables, this can just return nil.
CreateTable(i interface{}) Error
// DropTable drops the table for the given interface.
// For implementations that don't use tables, this can just return nil.
DropTable(i interface{}) Error
// RegisterTable registers a table for use in many2many relations.
// For implementations that don't use tables, or many2many relations, this can just return nil.
RegisterTable(i interface{}) Error
// Stop should stop and close the database connection cleanly, returning an error if this is not possible.
// If the database implementation doesn't need to be stopped, this can just return nil.
Stop(ctx context.Context) Error
// IsHealthy should return nil if the database connection is healthy, or an error if not.
IsHealthy(ctx context.Context) Error
// GetByID gets one entry by its id. In a database like postgres, this might be the 'id' field of the entry,
// for other implementations (for example, in-memory) it might just be the key of a map.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
GetByID(id string, i interface{}) Error
// GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the
// name of the key to select from.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
GetWhere(where []Where, i interface{}) Error
// GetAll will try to get all entries of type i.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
GetAll(i interface{}) Error
// Put simply stores i. It is up to the implementation to figure out how to store it, and using what key.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
Put(i interface{}) Error
// Upsert stores or updates i based on the given conflict column, as in https://www.postgresqltutorial.com/postgresql-upsert/
// It is up to the implementation to figure out how to store it, and using what key.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
Upsert(i interface{}, conflictColumn string) Error
// UpdateByID updates i with id id.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
UpdateByID(id string, i interface{}) Error
// UpdateOneByID updates interface i with database the given database id. It will update one field of key key and value value.
UpdateOneByID(id string, key string, value interface{}, i interface{}) Error
// UpdateWhere updates column key of interface i with the given value, where the given parameters apply.
UpdateWhere(where []Where, key string, value interface{}, i interface{}) Error
// DeleteByID removes i with id id.
// If i didn't exist anyway, then no error should be returned.
DeleteByID(id string, i interface{}) Error
// DeleteWhere deletes i where key = value
// If i didn't exist anyway, then no error should be returned.
DeleteWhere(where []Where, i interface{}) Error
}

View file

@ -19,9 +19,6 @@
package db
import (
"context"
"net"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
@ -30,257 +27,19 @@
DBTypePostgres string = "POSTGRES"
)
// DB provides methods for interacting with an underlying database or other storage mechanism (for now, just postgres).
// Note that in all of the functions below, the passed interface should be a pointer or a slice, which will then be populated
// by whatever is returned from the database.
// DB provides methods for interacting with an underlying database or other storage mechanism.
type DB interface {
/*
BASIC DB FUNCTIONALITY
*/
// CreateTable creates a table for the given interface.
// For implementations that don't use tables, this can just return nil.
CreateTable(i interface{}) error
// DropTable drops the table for the given interface.
// For implementations that don't use tables, this can just return nil.
DropTable(i interface{}) error
// Stop should stop and close the database connection cleanly, returning an error if this is not possible.
// If the database implementation doesn't need to be stopped, this can just return nil.
Stop(ctx context.Context) error
// IsHealthy should return nil if the database connection is healthy, or an error if not.
IsHealthy(ctx context.Context) error
// GetByID gets one entry by its id. In a database like postgres, this might be the 'id' field of the entry,
// for other implementations (for example, in-memory) it might just be the key of a map.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
GetByID(id string, i interface{}) error
// GetWhere gets one entry where key = value. This is similar to GetByID but allows the caller to specify the
// name of the key to select from.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
GetWhere(where []Where, i interface{}) error
// GetAll will try to get all entries of type i.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
// In case of no entries, a 'no entries' error will be returned
GetAll(i interface{}) error
// Put simply stores i. It is up to the implementation to figure out how to store it, and using what key.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
Put(i interface{}) error
// Upsert stores or updates i based on the given conflict column, as in https://www.postgresqltutorial.com/postgresql-upsert/
// It is up to the implementation to figure out how to store it, and using what key.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
Upsert(i interface{}, conflictColumn string) error
// UpdateByID updates i with id id.
// The given interface i will be set to the result of the query, whatever it is. Use a pointer or a slice.
UpdateByID(id string, i interface{}) error
// UpdateOneByID updates interface i with database the given database id. It will update one field of key key and value value.
UpdateOneByID(id string, key string, value interface{}, i interface{}) error
// UpdateWhere updates column key of interface i with the given value, where the given parameters apply.
UpdateWhere(where []Where, key string, value interface{}, i interface{}) error
// DeleteByID removes i with id id.
// If i didn't exist anyway, then no error should be returned.
DeleteByID(id string, i interface{}) error
// DeleteWhere deletes i where key = value
// If i didn't exist anyway, then no error should be returned.
DeleteWhere(where []Where, i interface{}) error
/*
HANDY SHORTCUTS
*/
// AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table.
// In other words, it should create the follow, and delete the existing follow request.
//
// It will return the newly created follow for further processing.
AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, error)
// CreateInstanceAccount creates an account in the database with the same username as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance user will have a username of 'example.org'.
// This is needed for things like serving files that belong to the instance and not an individual user/account.
CreateInstanceAccount() error
// CreateInstanceInstance creates an instance in the database with the same domain as the instance host value.
// Ie., if the instance is hosted at 'example.org' the instance will have a domain of 'example.org'.
// This is needed for things like serving instance information through /api/v1/instance
CreateInstanceInstance() error
// GetAccountByUserID is a shortcut for the common action of fetching an account corresponding to a user ID.
// The given account pointer will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
GetAccountByUserID(userID string, account *gtsmodel.Account) error
// GetLocalAccountByUsername is a shortcut for the common action of fetching an account ON THIS INSTANCE
// according to its username, which should be unique.
// The given account pointer will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
GetLocalAccountByUsername(username string, account *gtsmodel.Account) error
// GetFollowRequestsForAccountID is a shortcut for the common action of fetching a list of follow requests targeting the given account ID.
// The given slice 'followRequests' will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
GetFollowRequestsForAccountID(accountID string, followRequests *[]gtsmodel.FollowRequest) error
// GetFollowingByAccountID is a shortcut for the common action of fetching a list of accounts that accountID is following.
// The given slice 'following' will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
GetFollowingByAccountID(accountID string, following *[]gtsmodel.Follow) error
// GetFollowersByAccountID is a shortcut for the common action of fetching a list of accounts that accountID is followed by.
// The given slice 'followers' will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
//
// If localOnly is set to true, then only followers from *this instance* will be returned.
GetFollowersByAccountID(accountID string, followers *[]gtsmodel.Follow, localOnly bool) error
// GetFavesByAccountID is a shortcut for the common action of fetching a list of faves made by the given accountID.
// The given slice 'faves' will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
GetFavesByAccountID(accountID string, faves *[]gtsmodel.StatusFave) error
// CountStatusesByAccountID is a shortcut for the common action of counting statuses produced by accountID.
CountStatusesByAccountID(accountID string) (int, error)
// GetStatusesForAccount is a shortcut for getting the most recent statuses. accountID is optional, if not provided
// then all statuses will be returned. If limit is set to 0, the size of the returned slice will not be limited. This can
// be very memory intensive so you probably shouldn't do this!
// In case of no entries, a 'no entries' error will be returned
GetStatusesForAccount(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, error)
GetBlocksForAccount(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error)
// GetLastStatusForAccountID simply gets the most recent status by the given account.
// The given slice 'status' pointer will be set to the result of the query, whatever it is.
// In case of no entries, a 'no entries' error will be returned
GetLastStatusForAccountID(accountID string, status *gtsmodel.Status) error
// IsUsernameAvailable checks whether a given username is available on our domain.
// Returns an error if the username is already taken, or something went wrong in the db.
IsUsernameAvailable(username string) error
// IsEmailAvailable checks whether a given email address for a new account is available to be used on our domain.
// Return an error if:
// A) the email is already associated with an account
// B) we block signups from this email domain
// C) something went wrong in the db
IsEmailAvailable(email string) error
// NewSignup creates a new user in the database with the given parameters.
// By the time this function is called, it should be assumed that all the parameters have passed validation!
NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, error)
// SetHeaderOrAvatarForAccountID sets the header or avatar for the given accountID to the given media attachment.
SetHeaderOrAvatarForAccountID(mediaAttachment *gtsmodel.MediaAttachment, accountID string) error
// GetHeaderAvatarForAccountID gets the current avatar for the given account ID.
// The passed mediaAttachment pointer will be populated with the value of the avatar, if it exists.
GetAvatarForAccountID(avatar *gtsmodel.MediaAttachment, accountID string) error
// GetHeaderForAccountID gets the current header for the given account ID.
// The passed mediaAttachment pointer will be populated with the value of the header, if it exists.
GetHeaderForAccountID(header *gtsmodel.MediaAttachment, accountID string) error
// Blocked checks whether a block exists in eiher direction between two accounts.
// That is, it returns true if account1 blocks account2, OR if account2 blocks account1.
Blocked(account1 string, account2 string) (bool, error)
// GetRelationship retrieves the relationship of the targetAccount to the requestingAccount.
GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, error)
// Follows returns true if sourceAccount follows target account, or an error if something goes wrong while finding out.
Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error)
// FollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out.
FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error)
// Mutuals returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out.
Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, error)
// GetReplyCountForStatus returns the amount of replies recorded for a status, or an error if something goes wrong
GetReplyCountForStatus(status *gtsmodel.Status) (int, error)
// GetReblogCountForStatus returns the amount of reblogs/boosts recorded for a status, or an error if something goes wrong
GetReblogCountForStatus(status *gtsmodel.Status) (int, error)
// GetFaveCountForStatus returns the amount of faves/likes recorded for a status, or an error if something goes wrong
GetFaveCountForStatus(status *gtsmodel.Status) (int, error)
// StatusParents get the parent statuses of a given status.
//
// If onlyDirect is true, only the immediate parent will be returned.
StatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error)
// StatusChildren gets the child statuses of a given status.
//
// If onlyDirect is true, only the immediate children will be returned.
StatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error)
// StatusFavedBy checks if a given status has been faved by a given account ID
StatusFavedBy(status *gtsmodel.Status, accountID string) (bool, error)
// StatusRebloggedBy checks if a given status has been reblogged/boosted by a given account ID
StatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, error)
// StatusMutedBy checks if a given status has been muted by a given account ID
StatusMutedBy(status *gtsmodel.Status, accountID string) (bool, error)
// StatusBookmarkedBy checks if a given status has been bookmarked by a given account ID
StatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, error)
// WhoFavedStatus returns a slice of accounts who faved the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error)
// WhoBoostedStatus returns a slice of accounts who boosted the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error)
// GetHomeTimelineForAccount returns a slice of statuses from accounts that are followed by the given account id.
//
// Statuses should be returned in descending order of when they were created (newest first).
GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error)
// GetPublicTimelineForAccount fetches the account's PUBLIC timeline -- ie., posts and replies that are public.
// It will use the given filters and try to return as many statuses as possible up to the limit.
//
// Statuses should be returned in descending order of when they were created (newest first).
GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error)
// GetFavedTimelineForAccount fetches the account's FAVED timeline -- ie., posts and replies that the requesting account has faved.
// It will use the given filters and try to return as many statuses as possible up to the limit.
//
// Note that unlike the other GetTimeline functions, the returned statuses will be arranged by their FAVE id, not the STATUS id.
// In other words, they'll be returned in descending order of when they were faved by the requesting user, not when they were created.
//
// Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers.
GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, error)
// GetNotificationsForAccount returns a list of notifications that pertain to the given accountID.
GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, error)
// GetUserCountForInstance returns the number of known accounts registered with the given domain.
GetUserCountForInstance(domain string) (int, error)
// GetStatusCountForInstance returns the number of known statuses posted from the given domain.
GetStatusCountForInstance(domain string) (int, error)
// GetDomainCountForInstance returns the number of known instances known that the given domain federates with.
GetDomainCountForInstance(domain string) (int, error)
// GetAccountsForInstance returns a slice of accounts from the given instance, arranged by ID.
GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, error)
Account
Admin
Basic
Domain
Instance
Media
Mention
Notification
Relationship
Status
Timeline
/*
USEFUL CONVERSION FUNCTIONS

36
internal/db/domain.go Normal file
View file

@ -0,0 +1,36 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import "net/url"
// Domain contains DB functions related to domains and domain blocks.
type Domain interface {
// IsDomainBlocked checks if an instance-level domain block exists for the given domain string (eg., `example.org`).
IsDomainBlocked(domain string) (bool, Error)
// AreDomainsBlocked checks if an instance-level domain block exists for any of the given domains strings, and returns true if even one is found.
AreDomainsBlocked(domains []string) (bool, Error)
// IsURIBlocked checks if an instance-level domain block exists for the `host` in the given URI (eg., `https://example.org/users/whatever`).
IsURIBlocked(uri *url.URL) (bool, Error)
// AreURIsBlocked checks if an instance-level domain block exists for any `host` in the given URI slice, and returns true if even one is found.
AreURIsBlocked(uris []*url.URL) (bool, Error)
}

View file

@ -18,16 +18,18 @@
package db
// ErrNoEntries is to be returned from the DB interface when no entries are found for a given query.
type ErrNoEntries struct{}
import "fmt"
func (e ErrNoEntries) Error() string {
return "no entries"
}
// Error denotes a database error.
type Error error
// ErrAlreadyExists is to be returned from the DB interface when an entry already exists for a given query or its constraints.
type ErrAlreadyExists struct{}
func (e ErrAlreadyExists) Error() string {
return "already exists"
}
var (
// ErrNoEntries is returned when a caller expected an entry for a query, but none was found.
ErrNoEntries Error = fmt.Errorf("no entries")
// ErrMultipleEntries is returned when a caller expected ONE entry for a query, but multiples were found.
ErrMultipleEntries Error = fmt.Errorf("multiple entries")
// ErrAlreadyExists is returned when a caller tries to insert a database entry that already exists in the db.
ErrAlreadyExists Error = fmt.Errorf("already exists")
// ErrUnknown denotes an unknown database error.
ErrUnknown Error = fmt.Errorf("unknown error")
)

36
internal/db/instance.go Normal file
View file

@ -0,0 +1,36 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
// Instance contains functions for instance-level actions (counting instance users etc.).
type Instance interface {
// CountInstanceUsers returns the number of known accounts registered with the given domain.
CountInstanceUsers(domain string) (int, Error)
// CountInstanceStatuses returns the number of known statuses posted from the given domain.
CountInstanceStatuses(domain string) (int, Error)
// CountInstanceDomains returns the number of known instances known that the given domain federates with.
CountInstanceDomains(domain string) (int, Error)
// GetInstanceAccounts returns a slice of accounts from the given instance, arranged by ID.
GetInstanceAccounts(domain string, maxID string, limit int) ([]*gtsmodel.Account, Error)
}

27
internal/db/media.go Normal file
View file

@ -0,0 +1,27 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
// Media contains functions related to creating/getting/removing media attachments.
type Media interface {
// GetAttachmentByID gets a single attachment by its ID
GetAttachmentByID(id string) (*gtsmodel.MediaAttachment, Error)
}

30
internal/db/mention.go Normal file
View file

@ -0,0 +1,30 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
// Mention contains functions for getting/creating mentions in the database.
type Mention interface {
// GetMention gets a single mention by ID
GetMention(id string) (*gtsmodel.Mention, Error)
// GetMentions gets multiple mentions.
GetMentions(ids []string) ([]*gtsmodel.Mention, Error)
}

View file

@ -0,0 +1,31 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
// Notification contains functions for creating and getting notifications.
type Notification interface {
// GetNotifications returns a slice of notifications that pertain to the given accountID.
//
// Returned notifications will be ordered ID descending (ie., highest/newest to lowest/oldest).
GetNotifications(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, Error)
// GetNotification returns one notification according to its id.
GetNotification(id string) (*gtsmodel.Notification, Error)
}

256
internal/db/pg/account.go Normal file
View file

@ -0,0 +1,256 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"context"
"errors"
"fmt"
"time"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
type accountDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (a *accountDB) newAccountQ(account *gtsmodel.Account) *orm.Query {
return a.conn.Model(account).
Relation("AvatarMediaAttachment").
Relation("HeaderMediaAttachment")
}
func (a *accountDB) GetAccountByID(id string) (*gtsmodel.Account, db.Error) {
account := &gtsmodel.Account{}
q := a.newAccountQ(account).
Where("account.id = ?", id)
err := processErrorResponse(q.Select())
return account, err
}
func (a *accountDB) GetAccountByURI(uri string) (*gtsmodel.Account, db.Error) {
account := &gtsmodel.Account{}
q := a.newAccountQ(account).
Where("account.uri = ?", uri)
err := processErrorResponse(q.Select())
return account, err
}
func (a *accountDB) GetAccountByURL(uri string) (*gtsmodel.Account, db.Error) {
account := &gtsmodel.Account{}
q := a.newAccountQ(account).
Where("account.url = ?", uri)
err := processErrorResponse(q.Select())
return account, err
}
func (a *accountDB) GetInstanceAccount(domain string) (*gtsmodel.Account, db.Error) {
account := &gtsmodel.Account{}
q := a.newAccountQ(account)
if domain == "" {
q = q.
Where("account.username = ?", domain).
Where("account.domain = ?", domain)
} else {
q = q.
Where("account.username = ?", domain).
Where("? IS NULL", pg.Ident("domain"))
}
err := processErrorResponse(q.Select())
return account, err
}
func (a *accountDB) GetAccountLastPosted(accountID string) (time.Time, db.Error) {
status := &gtsmodel.Status{}
q := a.conn.Model(status).
Order("id DESC").
Limit(1).
Where("account_id = ?", accountID).
Column("created_at")
err := processErrorResponse(q.Select())
return status.CreatedAt, err
}
func (a *accountDB) SetAccountHeaderOrAvatar(mediaAttachment *gtsmodel.MediaAttachment, accountID string) db.Error {
if mediaAttachment.Avatar && mediaAttachment.Header {
return errors.New("one media attachment cannot be both header and avatar")
}
var headerOrAVI string
if mediaAttachment.Avatar {
headerOrAVI = "avatar"
} else if mediaAttachment.Header {
headerOrAVI = "header"
} else {
return errors.New("given media attachment was neither a header nor an avatar")
}
// TODO: there are probably more side effects here that need to be handled
if _, err := a.conn.Model(mediaAttachment).OnConflict("(id) DO UPDATE").Insert(); err != nil {
return err
}
if _, err := a.conn.Model(&gtsmodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Update(); err != nil {
return err
}
return nil
}
func (a *accountDB) GetLocalAccountByUsername(username string) (*gtsmodel.Account, db.Error) {
account := &gtsmodel.Account{}
q := a.newAccountQ(account).
Where("username = ?", username).
Where("? IS NULL", pg.Ident("domain"))
err := processErrorResponse(q.Select())
return account, err
}
func (a *accountDB) GetAccountFaves(accountID string) ([]*gtsmodel.StatusFave, db.Error) {
faves := []*gtsmodel.StatusFave{}
if err := a.conn.Model(&faves).
Where("account_id = ?", accountID).
Select(); err != nil {
if err == pg.ErrNoRows {
return faves, nil
}
return nil, err
}
return faves, nil
}
func (a *accountDB) CountAccountStatuses(accountID string) (int, db.Error) {
return a.conn.Model(&gtsmodel.Status{}).Where("account_id = ?", accountID).Count()
}
func (a *accountDB) GetAccountStatuses(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, db.Error) {
a.log.Debugf("getting statuses for account %s", accountID)
statuses := []*gtsmodel.Status{}
q := a.conn.Model(&statuses).Order("id DESC")
if accountID != "" {
q = q.Where("account_id = ?", accountID)
}
if limit != 0 {
q = q.Limit(limit)
}
if excludeReplies {
q = q.Where("? IS NULL", pg.Ident("in_reply_to_id"))
}
if pinnedOnly {
q = q.Where("pinned = ?", true)
}
if mediaOnly {
q = q.WhereGroup(func(q *pg.Query) (*pg.Query, error) {
return q.Where("? IS NOT NULL", pg.Ident("attachments")).Where("attachments != '{}'"), nil
})
}
if maxID != "" {
q = q.Where("id < ?", maxID)
}
if err := q.Select(); err != nil {
if err == pg.ErrNoRows {
return nil, db.ErrNoEntries
}
return nil, err
}
if len(statuses) == 0 {
return nil, db.ErrNoEntries
}
a.log.Debugf("returning statuses for account %s", accountID)
return statuses, nil
}
func (a *accountDB) GetAccountBlocks(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, db.Error) {
blocks := []*gtsmodel.Block{}
fq := a.conn.Model(&blocks).
Where("block.account_id = ?", accountID).
Relation("TargetAccount").
Order("block.id DESC")
if maxID != "" {
fq = fq.Where("block.id < ?", maxID)
}
if sinceID != "" {
fq = fq.Where("block.id > ?", sinceID)
}
if limit > 0 {
fq = fq.Limit(limit)
}
err := fq.Select()
if err != nil {
if err == pg.ErrNoRows {
return nil, "", "", db.ErrNoEntries
}
return nil, "", "", err
}
if len(blocks) == 0 {
return nil, "", "", db.ErrNoEntries
}
accounts := []*gtsmodel.Account{}
for _, b := range blocks {
accounts = append(accounts, b.TargetAccount)
}
nextMaxID := blocks[len(blocks)-1].ID
prevMinID := blocks[0].ID
return accounts, nextMaxID, prevMinID, nil
}

View file

@ -0,0 +1,70 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg_test
import (
"testing"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type AccountTestSuite struct {
PGStandardTestSuite
}
func (suite *AccountTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts()
suite.testAttachments = testrig.NewTestAttachments()
suite.testStatuses = testrig.NewTestStatuses()
suite.testTags = testrig.NewTestTags()
suite.testMentions = testrig.NewTestMentions()
}
func (suite *AccountTestSuite) SetupTest() {
suite.config = testrig.NewTestConfig()
suite.db = testrig.NewTestDB()
suite.log = testrig.NewTestLog()
testrig.StandardDBSetup(suite.db, suite.testAccounts)
}
func (suite *AccountTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
}
func (suite *AccountTestSuite) TestGetAccountByIDWithExtras() {
account, err := suite.db.GetAccountByID(suite.testAccounts["local_account_1"].ID)
if err != nil {
suite.FailNow(err.Error())
}
suite.NotNil(account)
suite.NotNil(account.AvatarMediaAttachment)
suite.NotEmpty(account.AvatarMediaAttachment.URL)
suite.NotNil(account.HeaderMediaAttachment)
suite.NotEmpty(account.HeaderMediaAttachment.URL)
}
func TestAccountTestSuite(t *testing.T) {
suite.Run(t, new(AccountTestSuite))
}

235
internal/db/pg/admin.go Normal file
View file

@ -0,0 +1,235 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"context"
"crypto/rand"
"crypto/rsa"
"fmt"
"net"
"net/mail"
"strings"
"time"
"github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/util"
"golang.org/x/crypto/bcrypt"
)
type adminDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (a *adminDB) IsUsernameAvailable(username string) db.Error {
// if no error we fail because it means we found something
// if error but it's not pg.ErrNoRows then we fail
// if err is pg.ErrNoRows we're good, we found nothing so continue
if err := a.conn.Model(&gtsmodel.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil {
return fmt.Errorf("username %s already in use", username)
} else if err != pg.ErrNoRows {
return fmt.Errorf("db error: %s", err)
}
return nil
}
func (a *adminDB) IsEmailAvailable(email string) db.Error {
// parse the domain from the email
m, err := mail.ParseAddress(email)
if err != nil {
return fmt.Errorf("error parsing email address %s: %s", email, err)
}
domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
// check if the email domain is blocked
if err := a.conn.Model(&gtsmodel.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil {
// fail because we found something
return fmt.Errorf("email domain %s is blocked", domain)
} else if err != pg.ErrNoRows {
// fail because we got an unexpected error
return fmt.Errorf("db error: %s", err)
}
// check if this email is associated with a user already
if err := a.conn.Model(&gtsmodel.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil {
// fail because we found something
return fmt.Errorf("email %s already in use", email)
} else if err != pg.ErrNoRows {
// fail because we got an unexpected error
return fmt.Errorf("db error: %s", err)
}
return nil
}
func (a *adminDB) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, db.Error) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
a.log.Errorf("error creating new rsa key: %s", err)
return nil, err
}
// if something went wrong while creating a user, we might already have an account, so check here first...
acct := &gtsmodel.Account{}
err = a.conn.Model(acct).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select()
if err != nil {
// there's been an actual error
if err != pg.ErrNoRows {
return nil, fmt.Errorf("db error checking existence of account: %s", err)
}
// we just don't have an account yet create one
newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host)
newAccountID, err := id.NewRandomULID()
if err != nil {
return nil, err
}
acct = &gtsmodel.Account{
ID: newAccountID,
Username: username,
DisplayName: username,
Reason: reason,
URL: newAccountURIs.UserURL,
PrivateKey: key,
PublicKey: &key.PublicKey,
PublicKeyURI: newAccountURIs.PublicKeyURI,
ActorType: gtsmodel.ActivityStreamsPerson,
URI: newAccountURIs.UserURI,
InboxURI: newAccountURIs.InboxURI,
OutboxURI: newAccountURIs.OutboxURI,
FollowersURI: newAccountURIs.FollowersURI,
FollowingURI: newAccountURIs.FollowingURI,
FeaturedCollectionURI: newAccountURIs.CollectionURI,
}
if _, err = a.conn.Model(acct).Insert(); err != nil {
return nil, err
}
}
pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("error hashing password: %s", err)
}
newUserID, err := id.NewRandomULID()
if err != nil {
return nil, err
}
u := &gtsmodel.User{
ID: newUserID,
AccountID: acct.ID,
EncryptedPassword: string(pw),
SignUpIP: signUpIP.To4(),
Locale: locale,
UnconfirmedEmail: email,
CreatedByApplicationID: appID,
Approved: !requireApproval, // if we don't require moderator approval, just pre-approve the user
}
if emailVerified {
u.ConfirmedAt = time.Now()
u.Email = email
}
if admin {
u.Admin = true
u.Moderator = true
}
if _, err = a.conn.Model(u).Insert(); err != nil {
return nil, err
}
return u, nil
}
func (a *adminDB) CreateInstanceAccount() db.Error {
username := a.config.Host
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
a.log.Errorf("error creating new rsa key: %s", err)
return err
}
aID, err := id.NewRandomULID()
if err != nil {
return err
}
newAccountURIs := util.GenerateURIsForAccount(username, a.config.Protocol, a.config.Host)
acct := &gtsmodel.Account{
ID: aID,
Username: a.config.Host,
DisplayName: username,
URL: newAccountURIs.UserURL,
PrivateKey: key,
PublicKey: &key.PublicKey,
PublicKeyURI: newAccountURIs.PublicKeyURI,
ActorType: gtsmodel.ActivityStreamsPerson,
URI: newAccountURIs.UserURI,
InboxURI: newAccountURIs.InboxURI,
OutboxURI: newAccountURIs.OutboxURI,
FollowersURI: newAccountURIs.FollowersURI,
FollowingURI: newAccountURIs.FollowingURI,
FeaturedCollectionURI: newAccountURIs.CollectionURI,
}
inserted, err := a.conn.Model(acct).Where("username = ?", username).SelectOrInsert()
if err != nil {
return err
}
if inserted {
a.log.Infof("created instance account %s with id %s", username, acct.ID)
} else {
a.log.Infof("instance account %s already exists with id %s", username, acct.ID)
}
return nil
}
func (a *adminDB) CreateInstanceInstance() db.Error {
iID, err := id.NewRandomULID()
if err != nil {
return err
}
i := &gtsmodel.Instance{
ID: iID,
Domain: a.config.Host,
Title: a.config.Host,
URI: fmt.Sprintf("%s://%s", a.config.Protocol, a.config.Host),
}
inserted, err := a.conn.Model(i).Where("domain = ?", a.config.Host).SelectOrInsert()
if err != nil {
return err
}
if inserted {
a.log.Infof("created instance instance %s with id %s", a.config.Host, i.ID)
} else {
a.log.Infof("instance instance %s already exists with id %s", a.config.Host, i.ID)
}
return nil
}

205
internal/db/pg/basic.go Normal file
View file

@ -0,0 +1,205 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"context"
"errors"
"fmt"
"strings"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
)
type basicDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (b *basicDB) Put(i interface{}) db.Error {
_, err := b.conn.Model(i).Insert(i)
if err != nil && strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
return db.ErrAlreadyExists
}
return err
}
func (b *basicDB) GetByID(id string, i interface{}) db.Error {
if err := b.conn.Model(i).Where("id = ?", id).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries
}
return err
}
return nil
}
func (b *basicDB) GetWhere(where []db.Where, i interface{}) db.Error {
if len(where) == 0 {
return errors.New("no queries provided")
}
q := b.conn.Model(i)
for _, w := range where {
if w.Value == nil {
q = q.Where("? IS NULL", pg.Ident(w.Key))
} else {
if w.CaseInsensitive {
q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value)
} else {
q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
}
}
}
if err := q.Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries
}
return err
}
return nil
}
func (b *basicDB) GetAll(i interface{}) db.Error {
if err := b.conn.Model(i).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries
}
return err
}
return nil
}
func (b *basicDB) DeleteByID(id string, i interface{}) db.Error {
if _, err := b.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
// if there are no rows *anyway* then that's fine
// just return err if there's an actual error
if err != pg.ErrNoRows {
return err
}
}
return nil
}
func (b *basicDB) DeleteWhere(where []db.Where, i interface{}) db.Error {
if len(where) == 0 {
return errors.New("no queries provided")
}
q := b.conn.Model(i)
for _, w := range where {
q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
}
if _, err := q.Delete(); err != nil {
// if there are no rows *anyway* then that's fine
// just return err if there's an actual error
if err != pg.ErrNoRows {
return err
}
}
return nil
}
func (b *basicDB) Upsert(i interface{}, conflictColumn string) db.Error {
if _, err := b.conn.Model(i).OnConflict(fmt.Sprintf("(%s) DO UPDATE", conflictColumn)).Insert(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries
}
return err
}
return nil
}
func (b *basicDB) UpdateByID(id string, i interface{}) db.Error {
if _, err := b.conn.Model(i).Where("id = ?", id).OnConflict("(id) DO UPDATE").Insert(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries
}
return err
}
return nil
}
func (b *basicDB) UpdateOneByID(id string, key string, value interface{}, i interface{}) db.Error {
_, err := b.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update()
return err
}
func (b *basicDB) UpdateWhere(where []db.Where, key string, value interface{}, i interface{}) db.Error {
q := b.conn.Model(i)
for _, w := range where {
if w.Value == nil {
q = q.Where("? IS NULL", pg.Ident(w.Key))
} else {
if w.CaseInsensitive {
q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value)
} else {
q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
}
}
}
q = q.Set("? = ?", pg.Safe(key), value)
_, err := q.Update()
return err
}
func (b *basicDB) CreateTable(i interface{}) db.Error {
return b.conn.Model(i).CreateTable(&orm.CreateTableOptions{
IfNotExists: true,
})
}
func (b *basicDB) DropTable(i interface{}) db.Error {
return b.conn.Model(i).DropTable(&orm.DropTableOptions{
IfExists: true,
})
}
func (b *basicDB) RegisterTable(i interface{}) db.Error {
orm.RegisterTable(i)
return nil
}
func (b *basicDB) IsHealthy(ctx context.Context) db.Error {
return b.conn.Ping(ctx)
}
func (b *basicDB) Stop(ctx context.Context) db.Error {
b.log.Info("closing db connection")
if err := b.conn.Close(); err != nil {
// only cancel if there's a problem closing the db
b.cancel()
return err
}
return nil
}

View file

@ -1,67 +0,0 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"github.com/go-pg/pg/v10"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
func (ps *postgresService) GetBlocksForAccount(accountID string, maxID string, sinceID string, limit int) ([]*gtsmodel.Account, string, string, error) {
blocks := []*gtsmodel.Block{}
fq := ps.conn.Model(&blocks).
Where("block.account_id = ?", accountID).
Relation("TargetAccount").
Order("block.id DESC")
if maxID != "" {
fq = fq.Where("block.id < ?", maxID)
}
if sinceID != "" {
fq = fq.Where("block.id > ?", sinceID)
}
if limit > 0 {
fq = fq.Limit(limit)
}
err := fq.Select()
if err != nil {
if err == pg.ErrNoRows {
return nil, "", "", db.ErrNoEntries{}
}
return nil, "", "", err
}
if len(blocks) == 0 {
return nil, "", "", db.ErrNoEntries{}
}
accounts := []*gtsmodel.Account{}
for _, b := range blocks {
accounts = append(accounts, b.TargetAccount)
}
nextMaxID := blocks[len(blocks)-1].ID
prevMinID := blocks[0].ID
return accounts, nextMaxID, prevMinID, nil
}

83
internal/db/pg/domain.go Normal file
View file

@ -0,0 +1,83 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"context"
"net/url"
"github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
type domainDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (d *domainDB) IsDomainBlocked(domain string) (bool, db.Error) {
if domain == "" {
return false, nil
}
blocked, err := d.conn.
Model(&gtsmodel.DomainBlock{}).
Where("LOWER(domain) = LOWER(?)", domain).
Exists()
err = processErrorResponse(err)
return blocked, err
}
func (d *domainDB) AreDomainsBlocked(domains []string) (bool, db.Error) {
// filter out any doubles
uniqueDomains := util.UniqueStrings(domains)
for _, domain := range uniqueDomains {
if blocked, err := d.IsDomainBlocked(domain); err != nil {
return false, err
} else if blocked {
return blocked, nil
}
}
// no blocks found
return false, nil
}
func (d *domainDB) IsURIBlocked(uri *url.URL) (bool, db.Error) {
domain := uri.Hostname()
return d.IsDomainBlocked(domain)
}
func (d *domainDB) AreURIsBlocked(uris []*url.URL) (bool, db.Error) {
domains := []string{}
for _, uri := range uris {
domains = append(domains, uri.Hostname())
}
return d.AreDomainsBlocked(domains)
}

View file

@ -1,75 +0,0 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"errors"
"github.com/go-pg/pg/v10"
"github.com/superseriousbusiness/gotosocial/internal/db"
)
func (ps *postgresService) GetByID(id string, i interface{}) error {
if err := ps.conn.Model(i).Where("id = ?", id).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetWhere(where []db.Where, i interface{}) error {
if len(where) == 0 {
return errors.New("no queries provided")
}
q := ps.conn.Model(i)
for _, w := range where {
if w.Value == nil {
q = q.Where("? IS NULL", pg.Ident(w.Key))
} else {
if w.CaseInsensitive {
q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value)
} else {
q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
}
}
}
if err := q.Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetAll(i interface{}) error {
if err := ps.conn.Model(i).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
return nil
}

View file

@ -19,15 +19,26 @@
package pg
import (
"context"
"github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
func (ps *postgresService) GetUserCountForInstance(domain string) (int, error) {
q := ps.conn.Model(&[]*gtsmodel.Account{})
type instanceDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
if domain == ps.config.Host {
func (i *instanceDB) CountInstanceUsers(domain string) (int, db.Error) {
q := i.conn.Model(&[]*gtsmodel.Account{})
if domain == i.config.Host {
// if the domain is *this* domain, just count where the domain field is null
q = q.Where("? IS NULL", pg.Ident("domain"))
} else {
@ -40,10 +51,10 @@ func (ps *postgresService) GetUserCountForInstance(domain string) (int, error) {
return q.Count()
}
func (ps *postgresService) GetStatusCountForInstance(domain string) (int, error) {
q := ps.conn.Model(&[]*gtsmodel.Status{})
func (i *instanceDB) CountInstanceStatuses(domain string) (int, db.Error) {
q := i.conn.Model(&[]*gtsmodel.Status{})
if domain == ps.config.Host {
if domain == i.config.Host {
// if the domain is *this* domain, just count where local is true
q = q.Where("local = ?", true)
} else {
@ -55,10 +66,10 @@ func (ps *postgresService) GetStatusCountForInstance(domain string) (int, error)
return q.Count()
}
func (ps *postgresService) GetDomainCountForInstance(domain string) (int, error) {
q := ps.conn.Model(&[]*gtsmodel.Instance{})
func (i *instanceDB) CountInstanceDomains(domain string) (int, db.Error) {
q := i.conn.Model(&[]*gtsmodel.Instance{})
if domain == ps.config.Host {
if domain == i.config.Host {
// if the domain is *this* domain, just count other instances it knows about
// exclude domains that are blocked
q = q.Where("domain != ?", domain).Where("? IS NULL", pg.Ident("suspended_at"))
@ -70,12 +81,12 @@ func (ps *postgresService) GetDomainCountForInstance(domain string) (int, error)
return q.Count()
}
func (ps *postgresService) GetAccountsForInstance(domain string, maxID string, limit int) ([]*gtsmodel.Account, error) {
ps.log.Debug("GetAccountsForInstance")
func (i *instanceDB) GetInstanceAccounts(domain string, maxID string, limit int) ([]*gtsmodel.Account, db.Error) {
i.log.Debug("GetAccountsForInstance")
accounts := []*gtsmodel.Account{}
q := ps.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC")
q := i.conn.Model(&accounts).Where("domain = ?", domain).Order("id DESC")
if maxID != "" {
q = q.Where("id < ?", maxID)
@ -88,13 +99,13 @@ func (ps *postgresService) GetAccountsForInstance(domain string, maxID string, l
err := q.Select()
if err != nil {
if err == pg.ErrNoRows {
return nil, db.ErrNoEntries{}
return nil, db.ErrNoEntries
}
return nil, err
}
if len(accounts) == 0 {
return nil, db.ErrNoEntries{}
return nil, db.ErrNoEntries
}
return accounts, nil

View file

@ -19,39 +19,35 @@
package pg
import (
"errors"
"context"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
func (ps *postgresService) DeleteByID(id string, i interface{}) error {
if _, err := ps.conn.Model(i).Where("id = ?", id).Delete(); err != nil {
// if there are no rows *anyway* then that's fine
// just return err if there's an actual error
if err != pg.ErrNoRows {
return err
}
}
return nil
type mediaDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (ps *postgresService) DeleteWhere(where []db.Where, i interface{}) error {
if len(where) == 0 {
return errors.New("no queries provided")
}
q := ps.conn.Model(i)
for _, w := range where {
q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
}
if _, err := q.Delete(); err != nil {
// if there are no rows *anyway* then that's fine
// just return err if there's an actual error
if err != pg.ErrNoRows {
return err
}
}
return nil
func (m *mediaDB) newMediaQ(i interface{}) *orm.Query {
return m.conn.Model(i).
Relation("Account")
}
func (m *mediaDB) GetAttachmentByID(id string) (*gtsmodel.MediaAttachment, db.Error) {
attachment := &gtsmodel.MediaAttachment{}
q := m.newMediaQ(attachment).
Where("media_attachment.id = ?", id)
err := processErrorResponse(q.Select())
return attachment, err
}

108
internal/db/pg/mention.go Normal file
View file

@ -0,0 +1,108 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"context"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
type mentionDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
cache cache.Cache
}
func (m *mentionDB) cacheMention(id string, mention *gtsmodel.Mention) {
if m.cache == nil {
m.cache = cache.New()
}
if err := m.cache.Store(id, mention); err != nil {
m.log.Panicf("mentionDB: error storing in cache: %s", err)
}
}
func (m *mentionDB) mentionCached(id string) (*gtsmodel.Mention, bool) {
if m.cache == nil {
m.cache = cache.New()
return nil, false
}
mI, err := m.cache.Fetch(id)
if err != nil || mI == nil {
return nil, false
}
mention, ok := mI.(*gtsmodel.Mention)
if !ok {
m.log.Panicf("mentionDB: cached interface with key %s was not a mention", id)
}
return mention, true
}
func (m *mentionDB) newMentionQ(i interface{}) *orm.Query {
return m.conn.Model(i).
Relation("Status").
Relation("OriginAccount").
Relation("TargetAccount")
}
func (m *mentionDB) GetMention(id string) (*gtsmodel.Mention, db.Error) {
if mention, cached := m.mentionCached(id); cached {
return mention, nil
}
mention := &gtsmodel.Mention{}
q := m.newMentionQ(mention).
Where("mention.id = ?", id)
err := processErrorResponse(q.Select())
if err == nil && mention != nil {
m.cacheMention(id, mention)
}
return mention, err
}
func (m *mentionDB) GetMentions(ids []string) ([]*gtsmodel.Mention, db.Error) {
mentions := []*gtsmodel.Mention{}
for _, i := range ids {
mention, err := m.GetMention(i)
if err != nil {
return nil, processErrorResponse(err)
}
mentions = append(mentions, mention)
}
return mentions, nil
}

View file

@ -0,0 +1,135 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"context"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
type notificationDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
cache cache.Cache
}
func (n *notificationDB) cacheNotification(id string, notification *gtsmodel.Notification) {
if n.cache == nil {
n.cache = cache.New()
}
if err := n.cache.Store(id, notification); err != nil {
n.log.Panicf("notificationDB: error storing in cache: %s", err)
}
}
func (n *notificationDB) notificationCached(id string) (*gtsmodel.Notification, bool) {
if n.cache == nil {
n.cache = cache.New()
return nil, false
}
nI, err := n.cache.Fetch(id)
if err != nil || nI == nil {
return nil, false
}
notification, ok := nI.(*gtsmodel.Notification)
if !ok {
n.log.Panicf("notificationDB: cached interface with key %s was not a notification", id)
}
return notification, true
}
func (n *notificationDB) newNotificationQ(i interface{}) *orm.Query {
return n.conn.Model(i).
Relation("OriginAccount").
Relation("TargetAccount").
Relation("Status")
}
func (n *notificationDB) GetNotification(id string) (*gtsmodel.Notification, db.Error) {
if notification, cached := n.notificationCached(id); cached {
return notification, nil
}
notification := &gtsmodel.Notification{}
q := n.newNotificationQ(notification).
Where("notification.id = ?", id)
err := processErrorResponse(q.Select())
if err == nil && notification != nil {
n.cacheNotification(id, notification)
}
return notification, err
}
func (n *notificationDB) GetNotifications(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, db.Error) {
// begin by selecting just the IDs
notifIDs := []*gtsmodel.Notification{}
q := n.conn.
Model(&notifIDs).
Column("id").
Where("target_account_id = ?", accountID).
Order("id DESC")
if maxID != "" {
q = q.Where("id < ?", maxID)
}
if sinceID != "" {
q = q.Where("id > ?", sinceID)
}
if limit != 0 {
q = q.Limit(limit)
}
err := processErrorResponse(q.Select())
if err != nil {
return nil, err
}
// now we have the IDs, select the notifs one by one
// reason for this is that for each notif, we can instead get it from our cache if it's cached
notifications := []*gtsmodel.Notification{}
for _, notifID := range notifIDs {
notif, err := n.GetNotification(notifID.ID)
errP := processErrorResponse(err)
if errP != nil {
return nil, errP
}
notifications = append(notifications, notif)
}
return notifications, nil
}

View file

@ -20,15 +20,11 @@
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"net"
"net/mail"
"os"
"strings"
"time"
@ -41,12 +37,26 @@
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/util"
"golang.org/x/crypto/bcrypt"
)
var registerTables []interface{} = []interface{}{
&gtsmodel.StatusToEmoji{},
&gtsmodel.StatusToTag{},
}
// postgresService satisfies the DB interface
type postgresService struct {
db.Account
db.Admin
db.Basic
db.Domain
db.Instance
db.Media
db.Mention
db.Notification
db.Relationship
db.Status
db.Timeline
config *config.Config
conn *pg.DB
log *logrus.Logger
@ -56,6 +66,11 @@ type postgresService struct {
// NewPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface.
// Under the hood, it uses https://github.com/go-pg/pg to create and maintain a database connection.
func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logger) (db.DB, error) {
for _, t := range registerTables {
// https://pg.uptrace.dev/orm/many-to-many-relation/
orm.RegisterTable(t)
}
opts, err := derivePGOptions(c)
if err != nil {
return nil, fmt.Errorf("could not create postgres service: %s", err)
@ -91,6 +106,72 @@ func NewPostgresService(ctx context.Context, c *config.Config, log *logrus.Logge
log.Infof("connected to postgres version: %s", version)
ps := &postgresService{
Account: &accountDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Admin: &adminDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Basic: &basicDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Domain: &domainDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Instance: &instanceDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Media: &mediaDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Mention: &mentionDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Notification: &notificationDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Relationship: &relationshipDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Status: &statusDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
Timeline: &timelineDB{
config: c,
conn: conn,
log: log,
cancel: cancel,
},
config: c,
conn: conn,
log: log,
@ -199,724 +280,6 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) {
return options, nil
}
/*
BASIC DB FUNCTIONALITY
*/
func (ps *postgresService) CreateTable(i interface{}) error {
return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{
IfNotExists: true,
})
}
func (ps *postgresService) DropTable(i interface{}) error {
return ps.conn.Model(i).DropTable(&orm.DropTableOptions{
IfExists: true,
})
}
func (ps *postgresService) Stop(ctx context.Context) error {
ps.log.Info("closing db connection")
if err := ps.conn.Close(); err != nil {
// only cancel if there's a problem closing the db
ps.cancel()
return err
}
return nil
}
func (ps *postgresService) IsHealthy(ctx context.Context) error {
return ps.conn.Ping(ctx)
}
func (ps *postgresService) CreateSchema(ctx context.Context) error {
models := []interface{}{
(*gtsmodel.Account)(nil),
(*gtsmodel.Status)(nil),
(*gtsmodel.User)(nil),
}
ps.log.Info("creating db schema")
for _, model := range models {
err := ps.conn.Model(model).CreateTable(&orm.CreateTableOptions{
IfNotExists: true,
})
if err != nil {
return err
}
}
ps.log.Info("db schema created")
return nil
}
/*
HANDY SHORTCUTS
*/
func (ps *postgresService) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, error) {
// make sure the original follow request exists
fr := &gtsmodel.FollowRequest{}
if err := ps.conn.Model(fr).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Select(); err != nil {
if err == pg.ErrMultiRows {
return nil, db.ErrNoEntries{}
}
return nil, err
}
// create a new follow to 'replace' the request with
follow := &gtsmodel.Follow{
ID: fr.ID,
AccountID: originAccountID,
TargetAccountID: targetAccountID,
URI: fr.URI,
}
// if the follow already exists, just update the URI -- we don't need to do anything else
if _, err := ps.conn.Model(follow).OnConflict("ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).Insert(); err != nil {
return nil, err
}
// now remove the follow request
if _, err := ps.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Delete(); err != nil {
return nil, err
}
return follow, nil
}
func (ps *postgresService) CreateInstanceAccount() error {
username := ps.config.Host
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
ps.log.Errorf("error creating new rsa key: %s", err)
return err
}
aID, err := id.NewRandomULID()
if err != nil {
return err
}
newAccountURIs := util.GenerateURIsForAccount(username, ps.config.Protocol, ps.config.Host)
a := &gtsmodel.Account{
ID: aID,
Username: ps.config.Host,
DisplayName: username,
URL: newAccountURIs.UserURL,
PrivateKey: key,
PublicKey: &key.PublicKey,
PublicKeyURI: newAccountURIs.PublicKeyURI,
ActorType: gtsmodel.ActivityStreamsPerson,
URI: newAccountURIs.UserURI,
InboxURI: newAccountURIs.InboxURI,
OutboxURI: newAccountURIs.OutboxURI,
FollowersURI: newAccountURIs.FollowersURI,
FollowingURI: newAccountURIs.FollowingURI,
FeaturedCollectionURI: newAccountURIs.CollectionURI,
}
inserted, err := ps.conn.Model(a).Where("username = ?", username).SelectOrInsert()
if err != nil {
return err
}
if inserted {
ps.log.Infof("created instance account %s with id %s", username, a.ID)
} else {
ps.log.Infof("instance account %s already exists with id %s", username, a.ID)
}
return nil
}
func (ps *postgresService) CreateInstanceInstance() error {
iID, err := id.NewRandomULID()
if err != nil {
return err
}
i := &gtsmodel.Instance{
ID: iID,
Domain: ps.config.Host,
Title: ps.config.Host,
URI: fmt.Sprintf("%s://%s", ps.config.Protocol, ps.config.Host),
}
inserted, err := ps.conn.Model(i).Where("domain = ?", ps.config.Host).SelectOrInsert()
if err != nil {
return err
}
if inserted {
ps.log.Infof("created instance instance %s with id %s", ps.config.Host, i.ID)
} else {
ps.log.Infof("instance instance %s already exists with id %s", ps.config.Host, i.ID)
}
return nil
}
func (ps *postgresService) GetAccountByUserID(userID string, account *gtsmodel.Account) error {
user := &gtsmodel.User{
ID: userID,
}
if err := ps.conn.Model(user).Where("id = ?", userID).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
if err := ps.conn.Model(account).Where("id = ?", user.AccountID).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetLocalAccountByUsername(username string, account *gtsmodel.Account) error {
if err := ps.conn.Model(account).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetFollowRequestsForAccountID(accountID string, followRequests *[]gtsmodel.FollowRequest) error {
if err := ps.conn.Model(followRequests).Where("target_account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return nil
}
return err
}
return nil
}
func (ps *postgresService) GetFollowingByAccountID(accountID string, following *[]gtsmodel.Follow) error {
if err := ps.conn.Model(following).Where("account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return nil
}
return err
}
return nil
}
func (ps *postgresService) GetFollowersByAccountID(accountID string, followers *[]gtsmodel.Follow, localOnly bool) error {
q := ps.conn.Model(followers)
if localOnly {
// for local accounts let's get where domain is null OR where domain is an empty string, just to be safe
whereGroup := func(q *pg.Query) (*pg.Query, error) {
q = q.
WhereOr("? IS NULL", pg.Ident("a.domain")).
WhereOr("a.domain = ?", "")
return q, nil
}
q = q.ColumnExpr("follow.*").
Join("JOIN accounts AS a ON follow.account_id = TEXT(a.id)").
Where("follow.target_account_id = ?", accountID).
WhereGroup(whereGroup)
} else {
q = q.Where("target_account_id = ?", accountID)
}
if err := q.Select(); err != nil {
if err == pg.ErrNoRows {
return nil
}
return err
}
return nil
}
func (ps *postgresService) GetFavesByAccountID(accountID string, faves *[]gtsmodel.StatusFave) error {
if err := ps.conn.Model(faves).Where("account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return nil
}
return err
}
return nil
}
func (ps *postgresService) CountStatusesByAccountID(accountID string) (int, error) {
count, err := ps.conn.Model(&gtsmodel.Status{}).Where("account_id = ?", accountID).Count()
if err != nil {
if err == pg.ErrNoRows {
return 0, nil
}
return 0, err
}
return count, nil
}
func (ps *postgresService) GetStatusesForAccount(accountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]*gtsmodel.Status, error) {
ps.log.Debugf("getting statuses for account %s", accountID)
statuses := []*gtsmodel.Status{}
q := ps.conn.Model(&statuses).Order("id DESC")
if accountID != "" {
q = q.Where("account_id = ?", accountID)
}
if limit != 0 {
q = q.Limit(limit)
}
if excludeReplies {
q = q.Where("? IS NULL", pg.Ident("in_reply_to_id"))
}
if pinnedOnly {
q = q.Where("pinned = ?", true)
}
if mediaOnly {
q = q.WhereGroup(func(q *pg.Query) (*pg.Query, error) {
return q.Where("? IS NOT NULL", pg.Ident("attachments")).Where("attachments != '{}'"), nil
})
}
if maxID != "" {
q = q.Where("id < ?", maxID)
}
if err := q.Select(); err != nil {
if err == pg.ErrNoRows {
return nil, db.ErrNoEntries{}
}
return nil, err
}
if len(statuses) == 0 {
return nil, db.ErrNoEntries{}
}
ps.log.Debugf("returning statuses for account %s", accountID)
return statuses, nil
}
func (ps *postgresService) GetLastStatusForAccountID(accountID string, status *gtsmodel.Status) error {
if err := ps.conn.Model(status).Order("created_at DESC").Limit(1).Where("account_id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) IsUsernameAvailable(username string) error {
// if no error we fail because it means we found something
// if error but it's not pg.ErrNoRows then we fail
// if err is pg.ErrNoRows we're good, we found nothing so continue
if err := ps.conn.Model(&gtsmodel.Account{}).Where("username = ?", username).Where("domain = ?", nil).Select(); err == nil {
return fmt.Errorf("username %s already in use", username)
} else if err != pg.ErrNoRows {
return fmt.Errorf("db error: %s", err)
}
return nil
}
func (ps *postgresService) IsEmailAvailable(email string) error {
// parse the domain from the email
m, err := mail.ParseAddress(email)
if err != nil {
return fmt.Errorf("error parsing email address %s: %s", email, err)
}
domain := strings.Split(m.Address, "@")[1] // domain will always be the second part after @
// check if the email domain is blocked
if err := ps.conn.Model(&gtsmodel.EmailDomainBlock{}).Where("domain = ?", domain).Select(); err == nil {
// fail because we found something
return fmt.Errorf("email domain %s is blocked", domain)
} else if err != pg.ErrNoRows {
// fail because we got an unexpected error
return fmt.Errorf("db error: %s", err)
}
// check if this email is associated with a user already
if err := ps.conn.Model(&gtsmodel.User{}).Where("email = ?", email).WhereOr("unconfirmed_email = ?", email).Select(); err == nil {
// fail because we found something
return fmt.Errorf("email %s already in use", email)
} else if err != pg.ErrNoRows {
// fail because we got an unexpected error
return fmt.Errorf("db error: %s", err)
}
return nil
}
func (ps *postgresService) NewSignup(username string, reason string, requireApproval bool, email string, password string, signUpIP net.IP, locale string, appID string, emailVerified bool, admin bool) (*gtsmodel.User, error) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
ps.log.Errorf("error creating new rsa key: %s", err)
return nil, err
}
// if something went wrong while creating a user, we might already have an account, so check here first...
a := &gtsmodel.Account{}
err = ps.conn.Model(a).Where("username = ?", username).Where("? IS NULL", pg.Ident("domain")).Select()
if err != nil {
// there's been an actual error
if err != pg.ErrNoRows {
return nil, fmt.Errorf("db error checking existence of account: %s", err)
}
// we just don't have an account yet create one
newAccountURIs := util.GenerateURIsForAccount(username, ps.config.Protocol, ps.config.Host)
newAccountID, err := id.NewRandomULID()
if err != nil {
return nil, err
}
a = &gtsmodel.Account{
ID: newAccountID,
Username: username,
DisplayName: username,
Reason: reason,
URL: newAccountURIs.UserURL,
PrivateKey: key,
PublicKey: &key.PublicKey,
PublicKeyURI: newAccountURIs.PublicKeyURI,
ActorType: gtsmodel.ActivityStreamsPerson,
URI: newAccountURIs.UserURI,
InboxURI: newAccountURIs.InboxURI,
OutboxURI: newAccountURIs.OutboxURI,
FollowersURI: newAccountURIs.FollowersURI,
FollowingURI: newAccountURIs.FollowingURI,
FeaturedCollectionURI: newAccountURIs.CollectionURI,
}
if _, err = ps.conn.Model(a).Insert(); err != nil {
return nil, err
}
}
pw, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("error hashing password: %s", err)
}
newUserID, err := id.NewRandomULID()
if err != nil {
return nil, err
}
u := &gtsmodel.User{
ID: newUserID,
AccountID: a.ID,
EncryptedPassword: string(pw),
SignUpIP: signUpIP.To4(),
Locale: locale,
UnconfirmedEmail: email,
CreatedByApplicationID: appID,
Approved: !requireApproval, // if we don't require moderator approval, just pre-approve the user
}
if emailVerified {
u.ConfirmedAt = time.Now()
u.Email = email
}
if admin {
u.Admin = true
u.Moderator = true
}
if _, err = ps.conn.Model(u).Insert(); err != nil {
return nil, err
}
return u, nil
}
func (ps *postgresService) SetHeaderOrAvatarForAccountID(mediaAttachment *gtsmodel.MediaAttachment, accountID string) error {
if mediaAttachment.Avatar && mediaAttachment.Header {
return errors.New("one media attachment cannot be both header and avatar")
}
var headerOrAVI string
if mediaAttachment.Avatar {
headerOrAVI = "avatar"
} else if mediaAttachment.Header {
headerOrAVI = "header"
} else {
return errors.New("given media attachment was neither a header nor an avatar")
}
// TODO: there are probably more side effects here that need to be handled
if _, err := ps.conn.Model(mediaAttachment).OnConflict("(id) DO UPDATE").Insert(); err != nil {
return err
}
if _, err := ps.conn.Model(&gtsmodel.Account{}).Set(fmt.Sprintf("%s_media_attachment_id = ?", headerOrAVI), mediaAttachment.ID).Where("id = ?", accountID).Update(); err != nil {
return err
}
return nil
}
func (ps *postgresService) GetHeaderForAccountID(header *gtsmodel.MediaAttachment, accountID string) error {
acct := &gtsmodel.Account{}
if err := ps.conn.Model(acct).Where("id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
if acct.HeaderMediaAttachmentID == "" {
return db.ErrNoEntries{}
}
if err := ps.conn.Model(header).Where("id = ?", acct.HeaderMediaAttachmentID).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) GetAvatarForAccountID(avatar *gtsmodel.MediaAttachment, accountID string) error {
acct := &gtsmodel.Account{}
if err := ps.conn.Model(acct).Where("id = ?", accountID).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
if acct.AvatarMediaAttachmentID == "" {
return db.ErrNoEntries{}
}
if err := ps.conn.Model(avatar).Where("id = ?", acct.AvatarMediaAttachmentID).Select(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) Blocked(account1 string, account2 string) (bool, error) {
// TODO: check domain blocks as well
var blocked bool
if err := ps.conn.Model(&gtsmodel.Block{}).
Where("account_id = ?", account1).Where("target_account_id = ?", account2).
WhereOr("target_account_id = ?", account1).Where("account_id = ?", account2).
Select(); err != nil {
if err == pg.ErrNoRows {
blocked = false
return blocked, nil
}
return blocked, err
}
blocked = true
return blocked, nil
}
func (ps *postgresService) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, error) {
r := &gtsmodel.Relationship{
ID: targetAccount,
}
// check if the requesting account follows the target account
follow := &gtsmodel.Follow{}
if err := ps.conn.Model(follow).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Select(); err != nil {
if err != pg.ErrNoRows {
// a proper error
return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err)
}
// no follow exists so these are all false
r.Following = false
r.ShowingReblogs = false
r.Notifying = false
} else {
// follow exists so we can fill these fields out...
r.Following = true
r.ShowingReblogs = follow.ShowReblogs
r.Notifying = follow.Notify
}
// check if the target account follows the requesting account
followedBy, err := ps.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err)
}
r.FollowedBy = followedBy
// check if the requesting account blocks the target account
blocking, err := ps.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err)
}
r.Blocking = blocking
// check if the target account blocks the requesting account
blockedBy, err := ps.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
}
r.BlockedBy = blockedBy
// check if there's a pending following request from requesting account to target account
requested, err := ps.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
}
r.Requested = requested
return r, nil
}
func (ps *postgresService) Follows(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error) {
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
return ps.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists()
}
func (ps *postgresService) FollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, error) {
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
return ps.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", sourceAccount.ID).Where("target_account_id = ?", targetAccount.ID).Exists()
}
func (ps *postgresService) Mutuals(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, error) {
if account1 == nil || account2 == nil {
return false, nil
}
// make sure account 1 follows account 2
f1, err := ps.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", account1.ID).Where("target_account_id = ?", account2.ID).Exists()
if err != nil {
if err == pg.ErrNoRows {
return false, nil
}
return false, err
}
// make sure account 2 follows account 1
f2, err := ps.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", account2.ID).Where("target_account_id = ?", account1.ID).Exists()
if err != nil {
if err == pg.ErrNoRows {
return false, nil
}
return false, err
}
return f1 && f2, nil
}
func (ps *postgresService) GetReplyCountForStatus(status *gtsmodel.Status) (int, error) {
return ps.conn.Model(&gtsmodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count()
}
func (ps *postgresService) GetReblogCountForStatus(status *gtsmodel.Status) (int, error) {
return ps.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Count()
}
func (ps *postgresService) GetFaveCountForStatus(status *gtsmodel.Status) (int, error) {
return ps.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Count()
}
func (ps *postgresService) StatusFavedBy(status *gtsmodel.Status, accountID string) (bool, error) {
return ps.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (ps *postgresService) StatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, error) {
return ps.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (ps *postgresService) StatusMutedBy(status *gtsmodel.Status, accountID string) (bool, error) {
return ps.conn.Model(&gtsmodel.StatusMute{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (ps *postgresService) StatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, error) {
return ps.conn.Model(&gtsmodel.StatusBookmark{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (ps *postgresService) WhoFavedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error) {
accounts := []*gtsmodel.Account{}
faves := []*gtsmodel.StatusFave{}
if err := ps.conn.Model(&faves).Where("status_id = ?", status.ID).Select(); err != nil {
if err == pg.ErrNoRows {
return accounts, nil // no rows just means nobody has faved this status, so that's fine
}
return nil, err // an actual error has occurred
}
for _, f := range faves {
acc := &gtsmodel.Account{}
if err := ps.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil {
if err == pg.ErrNoRows {
continue // the account doesn't exist for some reason??? but this isn't the place to worry about that so just skip it
}
return nil, err // an actual error has occurred
}
accounts = append(accounts, acc)
}
return accounts, nil
}
func (ps *postgresService) WhoBoostedStatus(status *gtsmodel.Status) ([]*gtsmodel.Account, error) {
accounts := []*gtsmodel.Account{}
boosts := []*gtsmodel.Status{}
if err := ps.conn.Model(&boosts).Where("boost_of_id = ?", status.ID).Select(); err != nil {
if err == pg.ErrNoRows {
return accounts, nil // no rows just means nobody has boosted this status, so that's fine
}
return nil, err // an actual error has occurred
}
for _, f := range boosts {
acc := &gtsmodel.Account{}
if err := ps.conn.Model(acc).Where("id = ?", f.AccountID).Select(); err != nil {
if err == pg.ErrNoRows {
continue // the account doesn't exist for some reason??? but this isn't the place to worry about that so just skip it
}
return nil, err // an actual error has occurred
}
accounts = append(accounts, acc)
}
return accounts, nil
}
func (ps *postgresService) GetNotificationsForAccount(accountID string, limit int, maxID string, sinceID string) ([]*gtsmodel.Notification, error) {
notifications := []*gtsmodel.Notification{}
q := ps.conn.Model(&notifications).Where("target_account_id = ?", accountID)
if maxID != "" {
q = q.Where("id < ?", maxID)
}
if sinceID != "" {
q = q.Where("id > ?", sinceID)
}
if limit != 0 {
q = q.Limit(limit)
}
q = q.Order("created_at DESC")
if err := q.Select(); err != nil {
if err != pg.ErrNoRows {
return nil, err
}
}
return notifications, nil
}
/*
CONVERSION FUNCTIONS
*/
@ -988,14 +351,14 @@ func (ps *postgresService) MentionStringsToMentions(targetAccounts []string, ori
// id, createdAt and updatedAt will be populated by the db, so we have everything we need!
menchies = append(menchies, &gtsmodel.Mention{
StatusID: statusID,
OriginAccountID: ogAccount.ID,
OriginAccountURI: ogAccount.URI,
TargetAccountID: mentionedAccount.ID,
NameString: a,
MentionedAccountURI: mentionedAccount.URI,
MentionedAccountURL: mentionedAccount.URL,
GTSAccount: mentionedAccount,
StatusID: statusID,
OriginAccountID: ogAccount.ID,
OriginAccountURI: ogAccount.URI,
TargetAccountID: mentionedAccount.ID,
NameString: a,
TargetAccountURI: mentionedAccount.URI,
TargetAccountURL: mentionedAccount.URL,
OriginAccount: mentionedAccount,
})
}
return menchies, nil

47
internal/db/pg/pg_test.go Normal file
View file

@ -0,0 +1,47 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg_test
import (
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/oauth"
)
type PGStandardTestSuite struct {
// standard suite interfaces
suite.Suite
config *config.Config
db db.DB
log *logrus.Logger
// standard suite models
testTokens map[string]*oauth.Token
testClients map[string]*oauth.Client
testApplications map[string]*gtsmodel.Application
testUsers map[string]*gtsmodel.User
testAccounts map[string]*gtsmodel.Account
testAttachments map[string]*gtsmodel.MediaAttachment
testStatuses map[string]*gtsmodel.Status
testTags map[string]*gtsmodel.Tag
testMentions map[string]*gtsmodel.Mention
}

View file

@ -0,0 +1,276 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"context"
"fmt"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
type relationshipDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (r *relationshipDB) newBlockQ(block *gtsmodel.Block) *orm.Query {
return r.conn.Model(block).
Relation("Account").
Relation("TargetAccount")
}
func (r *relationshipDB) newFollowQ(follow interface{}) *orm.Query {
return r.conn.Model(follow).
Relation("Account").
Relation("TargetAccount")
}
func (r *relationshipDB) IsBlocked(account1 string, account2 string, eitherDirection bool) (bool, db.Error) {
q := r.conn.
Model(&gtsmodel.Block{}).
Where("account_id = ?", account1).
Where("target_account_id = ?", account2)
if eitherDirection {
q = q.
WhereOr("target_account_id = ?", account1).
Where("account_id = ?", account2)
}
return q.Exists()
}
func (r *relationshipDB) GetBlock(account1 string, account2 string) (*gtsmodel.Block, db.Error) {
block := &gtsmodel.Block{}
q := r.newBlockQ(block).
Where("block.account_id = ?", account1).
Where("block.target_account_id = ?", account2)
err := processErrorResponse(q.Select())
return block, err
}
func (r *relationshipDB) GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, db.Error) {
rel := &gtsmodel.Relationship{
ID: targetAccount,
}
// check if the requesting account follows the target account
follow := &gtsmodel.Follow{}
if err := r.conn.Model(follow).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Select(); err != nil {
if err != pg.ErrNoRows {
// a proper error
return nil, fmt.Errorf("getrelationship: error checking follow existence: %s", err)
}
// no follow exists so these are all false
rel.Following = false
rel.ShowingReblogs = false
rel.Notifying = false
} else {
// follow exists so we can fill these fields out...
rel.Following = true
rel.ShowingReblogs = follow.ShowReblogs
rel.Notifying = follow.Notify
}
// check if the target account follows the requesting account
followedBy, err := r.conn.Model(&gtsmodel.Follow{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking followed_by existence: %s", err)
}
rel.FollowedBy = followedBy
// check if the requesting account blocks the target account
blocking, err := r.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocking existence: %s", err)
}
rel.Blocking = blocking
// check if the target account blocks the requesting account
blockedBy, err := r.conn.Model(&gtsmodel.Block{}).Where("account_id = ?", targetAccount).Where("target_account_id = ?", requestingAccount).Exists()
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
}
rel.BlockedBy = blockedBy
// check if there's a pending following request from requesting account to target account
requested, err := r.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", requestingAccount).Where("target_account_id = ?", targetAccount).Exists()
if err != nil {
return nil, fmt.Errorf("getrelationship: error checking blocked existence: %s", err)
}
rel.Requested = requested
return rel, nil
}
func (r *relationshipDB) IsFollowing(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
q := r.conn.
Model(&gtsmodel.Follow{}).
Where("account_id = ?", sourceAccount.ID).
Where("target_account_id = ?", targetAccount.ID)
return q.Exists()
}
func (r *relationshipDB) IsFollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, db.Error) {
if sourceAccount == nil || targetAccount == nil {
return false, nil
}
q := r.conn.
Model(&gtsmodel.FollowRequest{}).
Where("account_id = ?", sourceAccount.ID).
Where("target_account_id = ?", targetAccount.ID)
return q.Exists()
}
func (r *relationshipDB) IsMutualFollowing(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, db.Error) {
if account1 == nil || account2 == nil {
return false, nil
}
// make sure account 1 follows account 2
f1, err := r.IsFollowing(account1, account2)
if err != nil {
return false, processErrorResponse(err)
}
// make sure account 2 follows account 1
f2, err := r.IsFollowing(account2, account1)
if err != nil {
return false, processErrorResponse(err)
}
return f1 && f2, nil
}
func (r *relationshipDB) AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, db.Error) {
// make sure the original follow request exists
fr := &gtsmodel.FollowRequest{}
if err := r.conn.Model(fr).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Select(); err != nil {
if err == pg.ErrMultiRows {
return nil, db.ErrNoEntries
}
return nil, err
}
// create a new follow to 'replace' the request with
follow := &gtsmodel.Follow{
ID: fr.ID,
AccountID: originAccountID,
TargetAccountID: targetAccountID,
URI: fr.URI,
}
// if the follow already exists, just update the URI -- we don't need to do anything else
if _, err := r.conn.Model(follow).OnConflict("ON CONSTRAINT follows_account_id_target_account_id_key DO UPDATE set uri = ?", follow.URI).Insert(); err != nil {
return nil, err
}
// now remove the follow request
if _, err := r.conn.Model(&gtsmodel.FollowRequest{}).Where("account_id = ?", originAccountID).Where("target_account_id = ?", targetAccountID).Delete(); err != nil {
return nil, err
}
return follow, nil
}
func (r *relationshipDB) GetAccountFollowRequests(accountID string) ([]*gtsmodel.FollowRequest, db.Error) {
followRequests := []*gtsmodel.FollowRequest{}
q := r.newFollowQ(&followRequests).
Where("target_account_id = ?", accountID)
err := processErrorResponse(q.Select())
return followRequests, err
}
func (r *relationshipDB) GetAccountFollows(accountID string) ([]*gtsmodel.Follow, db.Error) {
follows := []*gtsmodel.Follow{}
q := r.newFollowQ(&follows).
Where("account_id = ?", accountID)
err := processErrorResponse(q.Select())
return follows, err
}
func (r *relationshipDB) CountAccountFollows(accountID string, localOnly bool) (int, db.Error) {
return r.conn.
Model(&[]*gtsmodel.Follow{}).
Where("account_id = ?", accountID).
Count()
}
func (r *relationshipDB) GetAccountFollowedBy(accountID string, localOnly bool) ([]*gtsmodel.Follow, db.Error) {
follows := []*gtsmodel.Follow{}
q := r.conn.Model(&follows)
if localOnly {
// for local accounts let's get where domain is null OR where domain is an empty string, just to be safe
whereGroup := func(q *pg.Query) (*pg.Query, error) {
q = q.
WhereOr("? IS NULL", pg.Ident("a.domain")).
WhereOr("a.domain = ?", "")
return q, nil
}
q = q.ColumnExpr("follow.*").
Join("JOIN accounts AS a ON follow.account_id = TEXT(a.id)").
Where("follow.target_account_id = ?", accountID).
WhereGroup(whereGroup)
} else {
q = q.Where("target_account_id = ?", accountID)
}
if err := q.Select(); err != nil {
if err == pg.ErrNoRows {
return follows, nil
}
return nil, err
}
return follows, nil
}
func (r *relationshipDB) CountAccountFollowedBy(accountID string, localOnly bool) (int, db.Error) {
return r.conn.
Model(&[]*gtsmodel.Follow{}).
Where("target_account_id = ?", accountID).
Count()
}

318
internal/db/pg/status.go Normal file
View file

@ -0,0 +1,318 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"container/list"
"context"
"errors"
"time"
"github.com/go-pg/pg/v10"
"github.com/go-pg/pg/v10/orm"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/cache"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
type statusDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
cache cache.Cache
}
func (s *statusDB) cacheStatus(id string, status *gtsmodel.Status) {
if s.cache == nil {
s.cache = cache.New()
}
if err := s.cache.Store(id, status); err != nil {
s.log.Panicf("statusDB: error storing in cache: %s", err)
}
}
func (s *statusDB) statusCached(id string) (*gtsmodel.Status, bool) {
if s.cache == nil {
s.cache = cache.New()
return nil, false
}
sI, err := s.cache.Fetch(id)
if err != nil || sI == nil {
return nil, false
}
status, ok := sI.(*gtsmodel.Status)
if !ok {
s.log.Panicf("statusDB: cached interface with key %s was not a status", id)
}
return status, true
}
func (s *statusDB) newStatusQ(status interface{}) *orm.Query {
return s.conn.Model(status).
Relation("Attachments").
Relation("Tags").
Relation("Mentions").
Relation("Emojis").
Relation("Account").
Relation("InReplyTo").
Relation("InReplyToAccount").
Relation("BoostOf").
Relation("BoostOfAccount").
Relation("CreatedWithApplication")
}
func (s *statusDB) newFaveQ(faves interface{}) *orm.Query {
return s.conn.Model(faves).
Relation("Account").
Relation("TargetAccount").
Relation("Status")
}
func (s *statusDB) GetStatusByID(id string) (*gtsmodel.Status, db.Error) {
if status, cached := s.statusCached(id); cached {
return status, nil
}
status := &gtsmodel.Status{}
q := s.newStatusQ(status).
Where("status.id = ?", id)
err := processErrorResponse(q.Select())
if err == nil && status != nil {
s.cacheStatus(id, status)
}
return status, err
}
func (s *statusDB) GetStatusByURI(uri string) (*gtsmodel.Status, db.Error) {
if status, cached := s.statusCached(uri); cached {
return status, nil
}
status := &gtsmodel.Status{}
q := s.newStatusQ(status).
Where("LOWER(status.uri) = LOWER(?)", uri)
err := processErrorResponse(q.Select())
if err == nil && status != nil {
s.cacheStatus(uri, status)
}
return status, err
}
func (s *statusDB) GetStatusByURL(uri string) (*gtsmodel.Status, db.Error) {
if status, cached := s.statusCached(uri); cached {
return status, nil
}
status := &gtsmodel.Status{}
q := s.newStatusQ(status).
Where("LOWER(status.url) = LOWER(?)", uri)
err := processErrorResponse(q.Select())
if err == nil && status != nil {
s.cacheStatus(uri, status)
}
return status, err
}
func (s *statusDB) PutStatus(status *gtsmodel.Status) db.Error {
transaction := func(tx *pg.Tx) error {
// create links between this status and any emojis it uses
for _, i := range status.EmojiIDs {
if _, err := tx.Model(&gtsmodel.StatusToEmoji{
StatusID: status.ID,
EmojiID: i,
}).Insert(); err != nil {
return err
}
}
// create links between this status and any tags it uses
for _, i := range status.TagIDs {
if _, err := tx.Model(&gtsmodel.StatusToTag{
StatusID: status.ID,
TagID: i,
}).Insert(); err != nil {
return err
}
}
// change the status ID of the media attachments to the new status
for _, a := range status.Attachments {
a.StatusID = status.ID
a.UpdatedAt = time.Now()
if _, err := s.conn.Model(a).
Where("id = ?", a.ID).
Update(); err != nil {
return err
}
}
_, err := tx.Model(status).Insert()
return err
}
return processErrorResponse(s.conn.RunInTransaction(context.Background(), transaction))
}
func (s *statusDB) GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, db.Error) {
parents := []*gtsmodel.Status{}
s.statusParent(status, &parents, onlyDirect)
return parents, nil
}
func (s *statusDB) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
if status.InReplyToID == "" {
return
}
parentStatus, err := s.GetStatusByID(status.InReplyToID)
if err == nil {
*foundStatuses = append(*foundStatuses, parentStatus)
}
if onlyDirect {
return
}
s.statusParent(parentStatus, foundStatuses, false)
}
func (s *statusDB) GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, db.Error) {
foundStatuses := &list.List{}
foundStatuses.PushFront(status)
s.statusChildren(status, foundStatuses, onlyDirect, minID)
children := []*gtsmodel.Status{}
for e := foundStatuses.Front(); e != nil; e = e.Next() {
entry, ok := e.Value.(*gtsmodel.Status)
if !ok {
panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
}
// only append children, not the overall parent status
if entry.ID != status.ID {
children = append(children, entry)
}
}
return children, nil
}
func (s *statusDB) statusChildren(status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
immediateChildren := []*gtsmodel.Status{}
q := s.conn.Model(&immediateChildren).Where("in_reply_to_id = ?", status.ID)
if minID != "" {
q = q.Where("status.id > ?", minID)
}
if err := q.Select(); err != nil {
return
}
for _, child := range immediateChildren {
insertLoop:
for e := foundStatuses.Front(); e != nil; e = e.Next() {
entry, ok := e.Value.(*gtsmodel.Status)
if !ok {
panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
}
if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID {
foundStatuses.InsertAfter(child, e)
break insertLoop
}
}
// only do one loop if we only want direct children
if onlyDirect {
return
}
s.statusChildren(child, foundStatuses, false, minID)
}
}
func (s *statusDB) CountStatusReplies(status *gtsmodel.Status) (int, db.Error) {
return s.conn.Model(&gtsmodel.Status{}).Where("in_reply_to_id = ?", status.ID).Count()
}
func (s *statusDB) CountStatusReblogs(status *gtsmodel.Status) (int, db.Error) {
return s.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Count()
}
func (s *statusDB) CountStatusFaves(status *gtsmodel.Status) (int, db.Error) {
return s.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Count()
}
func (s *statusDB) IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
return s.conn.Model(&gtsmodel.StatusFave{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (s *statusDB) IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
return s.conn.Model(&gtsmodel.Status{}).Where("boost_of_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (s *statusDB) IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
return s.conn.Model(&gtsmodel.StatusMute{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (s *statusDB) IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, db.Error) {
return s.conn.Model(&gtsmodel.StatusBookmark{}).Where("status_id = ?", status.ID).Where("account_id = ?", accountID).Exists()
}
func (s *statusDB) GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, db.Error) {
faves := []*gtsmodel.StatusFave{}
q := s.newFaveQ(&faves).
Where("status_id = ?", status.ID)
err := processErrorResponse(q.Select())
return faves, err
}
func (s *statusDB) GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, db.Error) {
reblogs := []*gtsmodel.Status{}
q := s.newStatusQ(&reblogs).
Where("boost_of_id = ?", status.ID)
err := processErrorResponse(q.Select())
return reblogs, err
}

View file

@ -0,0 +1,134 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg_test
import (
"fmt"
"testing"
"time"
"github.com/stretchr/testify/suite"
"github.com/superseriousbusiness/gotosocial/testrig"
)
type StatusTestSuite struct {
PGStandardTestSuite
}
func (suite *StatusTestSuite) SetupSuite() {
suite.testTokens = testrig.NewTestTokens()
suite.testClients = testrig.NewTestClients()
suite.testApplications = testrig.NewTestApplications()
suite.testUsers = testrig.NewTestUsers()
suite.testAccounts = testrig.NewTestAccounts()
suite.testAttachments = testrig.NewTestAttachments()
suite.testStatuses = testrig.NewTestStatuses()
suite.testTags = testrig.NewTestTags()
suite.testMentions = testrig.NewTestMentions()
}
func (suite *StatusTestSuite) SetupTest() {
suite.config = testrig.NewTestConfig()
suite.db = testrig.NewTestDB()
suite.log = testrig.NewTestLog()
testrig.StandardDBSetup(suite.db, suite.testAccounts)
}
func (suite *StatusTestSuite) TearDownTest() {
testrig.StandardDBTeardown(suite.db)
}
func (suite *StatusTestSuite) TestGetStatusByID() {
status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_1_status_1"].ID)
if err != nil {
suite.FailNow(err.Error())
}
suite.NotNil(status)
suite.NotNil(status.Account)
suite.NotNil(status.CreatedWithApplication)
suite.Nil(status.BoostOf)
suite.Nil(status.BoostOfAccount)
suite.Nil(status.InReplyTo)
suite.Nil(status.InReplyToAccount)
}
func (suite *StatusTestSuite) TestGetStatusByURI() {
status, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
if err != nil {
suite.FailNow(err.Error())
}
suite.NotNil(status)
suite.NotNil(status.Account)
suite.NotNil(status.CreatedWithApplication)
suite.Nil(status.BoostOf)
suite.Nil(status.BoostOfAccount)
suite.Nil(status.InReplyTo)
suite.Nil(status.InReplyToAccount)
}
func (suite *StatusTestSuite) TestGetStatusWithExtras() {
status, err := suite.db.GetStatusByID(suite.testStatuses["admin_account_status_1"].ID)
if err != nil {
suite.FailNow(err.Error())
}
suite.NotNil(status)
suite.NotNil(status.Account)
suite.NotNil(status.CreatedWithApplication)
suite.NotEmpty(status.Tags)
suite.NotEmpty(status.Attachments)
suite.NotEmpty(status.Emojis)
}
func (suite *StatusTestSuite) TestGetStatusWithMention() {
status, err := suite.db.GetStatusByID(suite.testStatuses["local_account_2_status_5"].ID)
if err != nil {
suite.FailNow(err.Error())
}
suite.NotNil(status)
suite.NotNil(status.Account)
suite.NotNil(status.CreatedWithApplication)
suite.NotEmpty(status.Mentions)
suite.NotEmpty(status.MentionIDs)
suite.NotNil(status.InReplyTo)
suite.NotNil(status.InReplyToAccount)
}
func (suite *StatusTestSuite) TestGetStatusTwice() {
before1 := time.Now()
_, err := suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
suite.NoError(err)
after1 := time.Now()
duration1 := after1.Sub(before1)
fmt.Println(duration1.Nanoseconds())
before2 := time.Now()
_, err = suite.db.GetStatusByURI(suite.testStatuses["local_account_1_status_1"].URI)
suite.NoError(err)
after2 := time.Now()
duration2 := after2.Sub(before2)
fmt.Println(duration2.Nanoseconds())
// second retrieval should be several orders faster since it will be cached now
suite.Less(duration2, duration1)
}
func TestStatusTestSuite(t *testing.T) {
suite.Run(t, new(StatusTestSuite))
}

View file

@ -1,104 +0,0 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"container/list"
"errors"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
func (ps *postgresService) StatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, error) {
parents := []*gtsmodel.Status{}
ps.statusParent(status, &parents, onlyDirect)
return parents, nil
}
func (ps *postgresService) statusParent(status *gtsmodel.Status, foundStatuses *[]*gtsmodel.Status, onlyDirect bool) {
if status.InReplyToID == "" {
return
}
parentStatus := &gtsmodel.Status{}
if err := ps.conn.Model(parentStatus).Where("id = ?", status.InReplyToID).Select(); err == nil {
*foundStatuses = append(*foundStatuses, parentStatus)
}
if onlyDirect {
return
}
ps.statusParent(parentStatus, foundStatuses, false)
}
func (ps *postgresService) StatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, error) {
foundStatuses := &list.List{}
foundStatuses.PushFront(status)
ps.statusChildren(status, foundStatuses, onlyDirect, minID)
children := []*gtsmodel.Status{}
for e := foundStatuses.Front(); e != nil; e = e.Next() {
entry, ok := e.Value.(*gtsmodel.Status)
if !ok {
panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
}
// only append children, not the overall parent status
if entry.ID != status.ID {
children = append(children, entry)
}
}
return children, nil
}
func (ps *postgresService) statusChildren(status *gtsmodel.Status, foundStatuses *list.List, onlyDirect bool, minID string) {
immediateChildren := []*gtsmodel.Status{}
q := ps.conn.Model(&immediateChildren).Where("in_reply_to_id = ?", status.ID)
if minID != "" {
q = q.Where("status.id > ?", minID)
}
if err := q.Select(); err != nil {
return
}
for _, child := range immediateChildren {
insertLoop:
for e := foundStatuses.Front(); e != nil; e = e.Next() {
entry, ok := e.Value.(*gtsmodel.Status)
if !ok {
panic(errors.New("entry in foundStatuses was not a *gtsmodel.Status"))
}
if child.InReplyToAccountID != "" && entry.ID == child.InReplyToID {
foundStatuses.InsertAfter(child, e)
break insertLoop
}
}
// only do one loop if we only want direct children
if onlyDirect {
return
}
ps.statusChildren(child, foundStatuses, false, minID)
}
}

View file

@ -19,16 +19,26 @@
package pg
import (
"context"
"sort"
"github.com/go-pg/pg/v10"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
func (ps *postgresService) GetHomeTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) {
type timelineDB struct {
config *config.Config
conn *pg.DB
log *logrus.Logger
cancel context.CancelFunc
}
func (t *timelineDB) GetHomeTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
statuses := []*gtsmodel.Status{}
q := ps.conn.Model(&statuses)
q := t.conn.Model(&statuses)
q = q.ColumnExpr("status.*").
// Find out who accountID follows.
@ -74,22 +84,22 @@ func (ps *postgresService) GetHomeTimelineForAccount(accountID string, maxID str
err := q.Select()
if err != nil {
if err == pg.ErrNoRows {
return nil, db.ErrNoEntries{}
return nil, db.ErrNoEntries
}
return nil, err
}
if len(statuses) == 0 {
return nil, db.ErrNoEntries{}
return nil, db.ErrNoEntries
}
return statuses, nil
}
func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, error) {
func (t *timelineDB) GetPublicTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, db.Error) {
statuses := []*gtsmodel.Status{}
q := ps.conn.Model(&statuses).
q := t.conn.Model(&statuses).
Where("visibility = ?", gtsmodel.VisibilityPublic).
Where("? IS NULL", pg.Ident("in_reply_to_id")).
Where("? IS NULL", pg.Ident("in_reply_to_uri")).
@ -119,13 +129,13 @@ func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID s
err := q.Select()
if err != nil {
if err == pg.ErrNoRows {
return nil, db.ErrNoEntries{}
return nil, db.ErrNoEntries
}
return nil, err
}
if len(statuses) == 0 {
return nil, db.ErrNoEntries{}
return nil, db.ErrNoEntries
}
return statuses, nil
@ -133,11 +143,11 @@ func (ps *postgresService) GetPublicTimelineForAccount(accountID string, maxID s
// TODO optimize this query and the logic here, because it's slow as balls -- it takes like a literal second to return with a limit of 20!
// It might be worth serving it through a timeline instead of raw DB queries, like we do for Home feeds.
func (ps *postgresService) GetFavedTimelineForAccount(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, error) {
func (t *timelineDB) GetFavedTimeline(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, db.Error) {
faves := []*gtsmodel.StatusFave{}
fq := ps.conn.Model(&faves).
fq := t.conn.Model(&faves).
Where("account_id = ?", accountID).
Order("id DESC")
@ -156,13 +166,13 @@ func (ps *postgresService) GetFavedTimelineForAccount(accountID string, maxID st
err := fq.Select()
if err != nil {
if err == pg.ErrNoRows {
return nil, "", "", db.ErrNoEntries{}
return nil, "", "", db.ErrNoEntries
}
return nil, "", "", err
}
if len(faves) == 0 {
return nil, "", "", db.ErrNoEntries{}
return nil, "", "", db.ErrNoEntries
}
// map[statusID]faveID -- we need this to sort statuses by fave ID rather than their own ID
@ -175,16 +185,16 @@ func (ps *postgresService) GetFavedTimelineForAccount(accountID string, maxID st
}
statuses := []*gtsmodel.Status{}
err = ps.conn.Model(&statuses).Where("id IN (?)", pg.In(in)).Select()
err = t.conn.Model(&statuses).Where("id IN (?)", pg.In(in)).Select()
if err != nil {
if err == pg.ErrNoRows {
return nil, "", "", db.ErrNoEntries{}
return nil, "", "", db.ErrNoEntries
}
return nil, "", "", err
}
if len(statuses) == 0 {
return nil, "", "", db.ErrNoEntries{}
return nil, "", "", db.ErrNoEntries
}
// arrange statuses by fave ID

View file

@ -1,73 +0,0 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package pg
import (
"fmt"
"github.com/go-pg/pg/v10"
"github.com/superseriousbusiness/gotosocial/internal/db"
)
func (ps *postgresService) Upsert(i interface{}, conflictColumn string) error {
if _, err := ps.conn.Model(i).OnConflict(fmt.Sprintf("(%s) DO UPDATE", conflictColumn)).Insert(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) UpdateByID(id string, i interface{}) error {
if _, err := ps.conn.Model(i).Where("id = ?", id).OnConflict("(id) DO UPDATE").Insert(); err != nil {
if err == pg.ErrNoRows {
return db.ErrNoEntries{}
}
return err
}
return nil
}
func (ps *postgresService) UpdateOneByID(id string, key string, value interface{}, i interface{}) error {
_, err := ps.conn.Model(i).Set("? = ?", pg.Safe(key), value).Where("id = ?", id).Update()
return err
}
func (ps *postgresService) UpdateWhere(where []db.Where, key string, value interface{}, i interface{}) error {
q := ps.conn.Model(i)
for _, w := range where {
if w.Value == nil {
q = q.Where("? IS NULL", pg.Ident(w.Key))
} else {
if w.CaseInsensitive {
q = q.Where("LOWER(?) = LOWER(?)", pg.Safe(w.Key), w.Value)
} else {
q = q.Where("? = ?", pg.Safe(w.Key), w.Value)
}
}
}
q = q.Set("? = ?", pg.Safe(key), value)
_, err := q.Update()
return err
}

25
internal/db/pg/util.go Normal file
View file

@ -0,0 +1,25 @@
package pg
import (
"strings"
"github.com/go-pg/pg/v10"
"github.com/superseriousbusiness/gotosocial/internal/db"
)
// processErrorResponse parses the given error and returns an appropriate DBError.
func processErrorResponse(err error) db.Error {
switch err {
case nil:
return nil
case pg.ErrNoRows:
return db.ErrNoEntries
case pg.ErrMultiRows:
return db.ErrMultipleEntries
default:
if strings.Contains(err.Error(), "duplicate key value violates unique constraint") {
return db.ErrAlreadyExists
}
return err
}
}

View file

@ -0,0 +1,71 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
// Relationship contains functions for getting or modifying the relationship between two accounts.
type Relationship interface {
// IsBlocked checks whether account 1 has a block in place against block2.
// If eitherDirection is true, then the function returns true if account1 blocks account2, OR if account2 blocks account1.
IsBlocked(account1 string, account2 string, eitherDirection bool) (bool, Error)
// GetBlock returns the block from account1 targeting account2, if it exists, or an error if it doesn't.
//
// Because this is slower than Blocked, only use it if you need the actual Block struct for some reason,
// not if you're just checking for the existence of a block.
GetBlock(account1 string, account2 string) (*gtsmodel.Block, Error)
// GetRelationship retrieves the relationship of the targetAccount to the requestingAccount.
GetRelationship(requestingAccount string, targetAccount string) (*gtsmodel.Relationship, Error)
// IsFollowing returns true if sourceAccount follows target account, or an error if something goes wrong while finding out.
IsFollowing(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
// IsFollowRequested returns true if sourceAccount has requested to follow target account, or an error if something goes wrong while finding out.
IsFollowRequested(sourceAccount *gtsmodel.Account, targetAccount *gtsmodel.Account) (bool, Error)
// IsMutualFollowing returns true if account1 and account2 both follow each other, or an error if something goes wrong while finding out.
IsMutualFollowing(account1 *gtsmodel.Account, account2 *gtsmodel.Account) (bool, Error)
// AcceptFollowRequest moves a follow request in the database from the follow_requests table to the follows table.
// In other words, it should create the follow, and delete the existing follow request.
//
// It will return the newly created follow for further processing.
AcceptFollowRequest(originAccountID string, targetAccountID string) (*gtsmodel.Follow, Error)
// GetAccountFollowRequests returns all follow requests targeting the given account.
GetAccountFollowRequests(accountID string) ([]*gtsmodel.FollowRequest, Error)
// GetAccountFollows returns a slice of follows owned by the given accountID.
GetAccountFollows(accountID string) ([]*gtsmodel.Follow, Error)
// CountAccountFollows returns the amount of accounts that the given accountID is following.
//
// If localOnly is set to true, then only follows from *this instance* will be returned.
CountAccountFollows(accountID string, localOnly bool) (int, Error)
// GetAccountFollowedBy fetches follows that target given accountID.
//
// If localOnly is set to true, then only follows from *this instance* will be returned.
GetAccountFollowedBy(accountID string, localOnly bool) ([]*gtsmodel.Follow, Error)
// CountAccountFollowedBy returns the amounts that the given ID is followed by.
CountAccountFollowedBy(accountID string, localOnly bool) (int, Error)
}

75
internal/db/status.go Normal file
View file

@ -0,0 +1,75 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
// Status contains functions for getting statuses, creating statuses, and checking various other fields on statuses.
type Status interface {
// GetStatusByID returns one status from the database, with all rel fields populated (if possible).
GetStatusByID(id string) (*gtsmodel.Status, Error)
// GetStatusByURI returns one status from the database, with all rel fields populated (if possible).
GetStatusByURI(uri string) (*gtsmodel.Status, Error)
// GetStatusByURL returns one status from the database, with all rel fields populated (if possible).
GetStatusByURL(uri string) (*gtsmodel.Status, Error)
// PutStatus stores one status in the database.
PutStatus(status *gtsmodel.Status) Error
// CountStatusReplies returns the amount of replies recorded for a status, or an error if something goes wrong
CountStatusReplies(status *gtsmodel.Status) (int, Error)
// CountStatusReblogs returns the amount of reblogs/boosts recorded for a status, or an error if something goes wrong
CountStatusReblogs(status *gtsmodel.Status) (int, Error)
// CountStatusFaves returns the amount of faves/likes recorded for a status, or an error if something goes wrong
CountStatusFaves(status *gtsmodel.Status) (int, Error)
// GetStatusParents gets the parent statuses of a given status.
//
// If onlyDirect is true, only the immediate parent will be returned.
GetStatusParents(status *gtsmodel.Status, onlyDirect bool) ([]*gtsmodel.Status, Error)
// GetStatusChildren gets the child statuses of a given status.
//
// If onlyDirect is true, only the immediate children will be returned.
GetStatusChildren(status *gtsmodel.Status, onlyDirect bool, minID string) ([]*gtsmodel.Status, Error)
// IsStatusFavedBy checks if a given status has been faved by a given account ID
IsStatusFavedBy(status *gtsmodel.Status, accountID string) (bool, Error)
// IsStatusRebloggedBy checks if a given status has been reblogged/boosted by a given account ID
IsStatusRebloggedBy(status *gtsmodel.Status, accountID string) (bool, Error)
// IsStatusMutedBy checks if a given status has been muted by a given account ID
IsStatusMutedBy(status *gtsmodel.Status, accountID string) (bool, Error)
// IsStatusBookmarkedBy checks if a given status has been bookmarked by a given account ID
IsStatusBookmarkedBy(status *gtsmodel.Status, accountID string) (bool, Error)
// GetStatusFaves returns a slice of faves/likes of the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
GetStatusFaves(status *gtsmodel.Status) ([]*gtsmodel.StatusFave, Error)
// GetStatusReblogs returns a slice of statuses that are a boost/reblog of the given status.
// This slice will be unfiltered, not taking account of blocks and whatnot, so filter it before serving it back to a user.
GetStatusReblogs(status *gtsmodel.Status) ([]*gtsmodel.Status, Error)
}

44
internal/db/timeline.go Normal file
View file

@ -0,0 +1,44 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
import "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
// Timeline contains functionality for retrieving home/public/faved etc timelines for an account.
type Timeline interface {
// GetHomeTimeline returns a slice of statuses from accounts that are followed by the given account id.
//
// Statuses should be returned in descending order of when they were created (newest first).
GetHomeTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
// GetPublicTimeline fetches the account's PUBLIC timeline -- ie., posts and replies that are public.
// It will use the given filters and try to return as many statuses as possible up to the limit.
//
// Statuses should be returned in descending order of when they were created (newest first).
GetPublicTimeline(accountID string, maxID string, sinceID string, minID string, limit int, local bool) ([]*gtsmodel.Status, Error)
// GetFavedTimeline fetches the account's FAVED timeline -- ie., posts and replies that the requesting account has faved.
// It will use the given filters and try to return as many statuses as possible up to the limit.
//
// Note that unlike the other GetTimeline functions, the returned statuses will be arranged by their FAVE id, not the STATUS id.
// In other words, they'll be returned in descending order of when they were faved by the requesting user, not when they were created.
//
// Also note the extra return values, which correspond to the nextMaxID and prevMinID for building Link headers.
GetFavedTimeline(accountID string, maxID string, minID string, limit int) ([]*gtsmodel.Status, string, string, Error)
}

View file

@ -1,3 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package federation
import (

View file

@ -29,7 +29,6 @@
"github.com/go-fed/activity/streams/vocab"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/transport"
@ -65,8 +64,8 @@ func (d *deref) GetRemoteAccount(username string, remoteAccountID *url.URL, refr
new := true
// check if we already have the account in our db
maybeAccount := &gtsmodel.Account{}
if err := d.db.GetWhere([]db.Where{{Key: "uri", Value: remoteAccountID.String()}}, maybeAccount); err == nil {
maybeAccount, err := d.db.GetAccountByURI(remoteAccountID.String())
if err == nil {
// we've seen this account before so it's not new
new = false
if !refresh {

View file

@ -27,14 +27,14 @@
)
func (d *deref) DereferenceAnnounce(announce *gtsmodel.Status, requestingUsername string) error {
if announce.GTSBoostedStatus == nil || announce.GTSBoostedStatus.URI == "" {
if announce.BoostOf == nil || announce.BoostOf.URI == "" {
// we can't do anything unfortunately
return errors.New("DereferenceAnnounce: no URI to dereference")
}
boostedStatusURI, err := url.Parse(announce.GTSBoostedStatus.URI)
boostedStatusURI, err := url.Parse(announce.BoostOf.URI)
if err != nil {
return fmt.Errorf("DereferenceAnnounce: couldn't parse boosted status URI %s: %s", announce.GTSBoostedStatus.URI, err)
return fmt.Errorf("DereferenceAnnounce: couldn't parse boosted status URI %s: %s", announce.BoostOf.URI, err)
}
if blocked, err := d.blockedDomain(boostedStatusURI.Host); blocked || err != nil {
return fmt.Errorf("DereferenceAnnounce: domain %s is blocked", boostedStatusURI.Host)
@ -47,7 +47,7 @@ func (d *deref) DereferenceAnnounce(announce *gtsmodel.Status, requestingUsernam
boostedStatus, _, _, err := d.GetRemoteStatus(requestingUsername, boostedStatusURI, false)
if err != nil {
return fmt.Errorf("DereferenceAnnounce: error dereferencing remote status with id %s: %s", announce.GTSBoostedStatus.URI, err)
return fmt.Errorf("DereferenceAnnounce: error dereferencing remote status with id %s: %s", announce.BoostOf.URI, err)
}
announce.Content = boostedStatus.Content
@ -60,6 +60,6 @@ func (d *deref) DereferenceAnnounce(announce *gtsmodel.Status, requestingUsernam
announce.BoostOfAccountID = boostedStatus.AccountID
announce.Visibility = boostedStatus.Visibility
announce.VisibilityAdvanced = boostedStatus.VisibilityAdvanced
announce.GTSBoostedStatus = boostedStatus
announce.BoostOf = boostedStatus
return nil
}

View file

@ -31,7 +31,7 @@ func (d *deref) blockedDomain(host string) (bool, error) {
return true, nil
}
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
// there are no entries so there's no block
return false, nil
}

View file

@ -66,8 +66,8 @@ func (d *deref) GetRemoteStatus(username string, remoteStatusID *url.URL, refres
new := true
// check if we already have the status in our db
maybeStatus := &gtsmodel.Status{}
if err := d.db.GetWhere([]db.Where{{Key: "uri", Value: remoteStatusID.String()}}, maybeStatus); err == nil {
maybeStatus, err := d.db.GetStatusByURI(remoteStatusID.String())
if err == nil {
// we've seen this status before so it's not new
new = false
@ -109,7 +109,7 @@ func (d *deref) GetRemoteStatus(username string, remoteStatusID *url.URL, refres
return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error populating status fields: %s", err)
}
if err := d.db.Put(gtsStatus); err != nil {
if err := d.db.PutStatus(gtsStatus); err != nil {
return nil, statusable, new, fmt.Errorf("GetRemoteStatus: error putting new status: %s", err)
}
} else {
@ -276,7 +276,7 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername
// * the remote URL (a.RemoteURL)
// This should be enough to pass along to the media processor.
attachmentIDs := []string{}
for _, a := range status.GTSMediaAttachments {
for _, a := range status.Attachments {
l.Tracef("dereferencing attachment: %+v", a)
// it might have been processed elsewhere so check first if it's already in the database or not
@ -288,7 +288,7 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername
attachmentIDs = append(attachmentIDs, maybeAttachment.ID)
continue
}
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
// we have a real error
return fmt.Errorf("error checking db for existence of attachment with remote url %s: %s", a.RemoteURL, err)
}
@ -307,7 +307,7 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername
}
attachmentIDs = append(attachmentIDs, deferencedAttachment.ID)
}
status.Attachments = attachmentIDs
status.AttachmentIDs = attachmentIDs
// 2. Hashtags
@ -317,53 +317,84 @@ func (d *deref) populateStatusFields(status *gtsmodel.Status, requestingUsername
// At this point, mentions should have the namestring and mentionedAccountURI set on them.
//
// We should dereference any accounts mentioned here which we don't have in our db yet, by their URI.
mentions := []string{}
for _, m := range status.GTSMentions {
mentionIDs := []string{}
for _, m := range status.Mentions {
if m.ID != "" {
continue
// we've already populated this mention, since it has an ID
l.Debug("mention already populated")
continue
}
if m.TargetAccountURI == "" {
// can't do anything with this mention
l.Debug("target URI not set on mention")
continue
}
targetAccountURI, err := url.Parse(m.TargetAccountURI)
if err != nil {
l.Debugf("error parsing mentioned account uri %s: %s", m.TargetAccountURI, err)
continue
}
var targetAccount *gtsmodel.Account
if a, err := d.db.GetAccountByURL(targetAccountURI.String()); err == nil {
targetAccount = a
} else if a, _, err := d.GetRemoteAccount(requestingUsername, targetAccountURI, false); err == nil {
targetAccount = a
} else {
// we can't find the target account so bail
l.Debug("can't retrieve account targeted by mention")
continue
}
mID, err := id.NewRandomULID()
if err != nil {
return err
}
m.ID = mID
uri, err := url.Parse(m.MentionedAccountURI)
if err != nil {
l.Debugf("error parsing mentioned account uri %s: %s", m.MentionedAccountURI, err)
continue
m = &gtsmodel.Mention{
ID: mID,
StatusID: status.ID,
Status: m.Status,
CreatedAt: status.CreatedAt,
UpdatedAt: status.UpdatedAt,
OriginAccountID: status.Account.ID,
OriginAccountURI: status.AccountURI,
OriginAccount: status.Account,
TargetAccountID: targetAccount.ID,
TargetAccount: targetAccount,
NameString: m.NameString,
TargetAccountURI: targetAccount.URI,
TargetAccountURL: targetAccount.URL,
}
m.StatusID = status.ID
m.OriginAccountID = status.GTSAuthorAccount.ID
m.OriginAccountURI = status.GTSAuthorAccount.URI
targetAccount, _, err := d.GetRemoteAccount(requestingUsername, uri, false)
if err != nil {
continue
}
// by this point, we know the targetAccount exists in our database with an ID :)
m.TargetAccountID = targetAccount.ID
if err := d.db.Put(m); err != nil {
return fmt.Errorf("error creating mention: %s", err)
}
mentions = append(mentions, m.ID)
mentionIDs = append(mentionIDs, m.ID)
}
status.Mentions = mentions
status.MentionIDs = mentionIDs
// status has replyToURI but we don't have an ID yet for the status it replies to
if status.InReplyToURI != "" && status.InReplyToID == "" {
replyToStatus := &gtsmodel.Status{}
if err := d.db.GetWhere([]db.Where{{Key: "uri", Value: status.InReplyToURI}}, replyToStatus); err == nil {
statusURI, err := url.Parse(status.InReplyToURI)
if err != nil {
return err
}
if replyToStatus, err := d.db.GetStatusByURI(status.InReplyToURI); err == nil {
// we have the status
status.InReplyToID = replyToStatus.ID
status.InReplyTo = replyToStatus
status.InReplyToAccountID = replyToStatus.AccountID
status.InReplyToAccount = replyToStatus.Account
} else if replyToStatus, _, _, err := d.GetRemoteStatus(requestingUsername, statusURI, false); err == nil {
// we got the status
status.InReplyToID = replyToStatus.ID
status.InReplyTo = replyToStatus
status.InReplyToAccountID = replyToStatus.AccountID
status.InReplyToAccount = replyToStatus.Account
}
}
return nil
}

View file

@ -1,3 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package federatingdb
import (

View file

@ -1,3 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package federatingdb
import (

View file

@ -112,8 +112,8 @@ func (f *federatingDB) Create(ctx context.Context, asType vocab.Type) error {
}
status.ID = statusID
if err := f.db.Put(status); err != nil {
if _, ok := err.(db.ErrAlreadyExists); ok {
if err := f.db.PutStatus(status); err != nil {
if err == db.ErrAlreadyExists {
// the status already exists in the database, which means we've already handled everything else,
// so we can just return nil here and be done with it.
return nil

View file

@ -1,3 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package federatingdb
import (
@ -6,7 +24,6 @@
"net/url"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
@ -52,10 +69,8 @@ func (f *federatingDB) Delete(ctx context.Context, id *url.URL) error {
// in a delete we only get the URI, we can't know if we have a status or a profile or something else,
// so we have to try a few different things...
where := []db.Where{{Key: "uri", Value: id.String()}}
s := &gtsmodel.Status{}
if err := f.db.GetWhere(where, s); err == nil {
s, err := f.db.GetStatusByURI(id.String())
if err == nil {
// it's a status
l.Debugf("uri is for status with id: %s", s.ID)
if err := f.db.DeleteByID(s.ID, &gtsmodel.Status{}); err != nil {
@ -69,8 +84,8 @@ func (f *federatingDB) Delete(ctx context.Context, id *url.URL) error {
}
}
a := &gtsmodel.Account{}
if err := f.db.GetWhere(where, a); err == nil {
a, err := f.db.GetAccountByURI(id.String())
if err == nil {
// it's an account
l.Debugf("uri is for an account with id: %s", s.ID)
if err := f.db.DeleteByID(a.ID, &gtsmodel.Account{}); err != nil {

View file

@ -1,3 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package federatingdb
import (

View file

@ -31,7 +31,8 @@ func (f *federatingDB) Followers(c context.Context, actorIRI *url.URL) (follower
acct := &gtsmodel.Account{}
if util.IsUserPath(actorIRI) {
if err := f.db.GetWhere([]db.Where{{Key: "uri", Value: actorIRI.String()}}, acct); err != nil {
acct, err = f.db.GetAccountByURI(actorIRI.String())
if err != nil {
return nil, fmt.Errorf("FOLLOWERS: db error getting account with uri %s: %s", actorIRI.String(), err)
}
} else if util.IsFollowersPath(actorIRI) {
@ -42,8 +43,8 @@ func (f *federatingDB) Followers(c context.Context, actorIRI *url.URL) (follower
return nil, fmt.Errorf("FOLLOWERS: could not parse actor IRI %s as users or followers path", actorIRI.String())
}
acctFollowers := []gtsmodel.Follow{}
if err := f.db.GetFollowersByAccountID(acct.ID, &acctFollowers, false); err != nil {
acctFollowers, err := f.db.GetAccountFollowedBy(acct.ID, false)
if err != nil {
return nil, fmt.Errorf("FOLLOWERS: db error getting followers for account id %s: %s", acct.ID, err)
}

View file

@ -8,7 +8,6 @@
"github.com/go-fed/activity/streams"
"github.com/go-fed/activity/streams/vocab"
"github.com/sirupsen/logrus"
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util"
)
@ -28,21 +27,37 @@ func (f *federatingDB) Following(c context.Context, actorIRI *url.URL) (followin
)
l.Debugf("entering FOLLOWING function with actorIRI %s", actorIRI.String())
acct := &gtsmodel.Account{}
var acct *gtsmodel.Account
if util.IsUserPath(actorIRI) {
if err := f.db.GetWhere([]db.Where{{Key: "uri", Value: actorIRI.String()}}, acct); err != nil {
username, err := util.ParseUserPath(actorIRI)
if err != nil {
return nil, fmt.Errorf("FOLLOWING: error parsing user path: %s", err)
}
a, err := f.db.GetLocalAccountByUsername(username)
if err != nil {
return nil, fmt.Errorf("FOLLOWING: db error getting account with uri %s: %s", actorIRI.String(), err)
}
acct = a
} else if util.IsFollowingPath(actorIRI) {
if err := f.db.GetWhere([]db.Where{{Key: "following_uri", Value: actorIRI.String()}}, acct); err != nil {
username, err := util.ParseFollowingPath(actorIRI)
if err != nil {
return nil, fmt.Errorf("FOLLOWING: error parsing following path: %s", err)
}
a, err := f.db.GetLocalAccountByUsername(username)
if err != nil {
return nil, fmt.Errorf("FOLLOWING: db error getting account with following uri %s: %s", actorIRI.String(), err)
}
acct = a
} else {
return nil, fmt.Errorf("FOLLOWING: could not parse actor IRI %s as users or following path", actorIRI.String())
}
acctFollowing := []gtsmodel.Follow{}
if err := f.db.GetFollowingByAccountID(acct.ID, &acctFollowing); err != nil {
acctFollowing, err := f.db.GetAccountFollows(acct.ID)
if err != nil {
return nil, fmt.Errorf("FOLLOWING: db error getting following for account id %s: %s", acct.ID, err)
}

View file

@ -43,8 +43,8 @@ func (f *federatingDB) Get(c context.Context, id *url.URL) (value vocab.Type, er
l.Debug("entering GET function")
if util.IsUserPath(id) {
acct := &gtsmodel.Account{}
if err := f.db.GetWhere([]db.Where{{Key: "uri", Value: id.String()}}, acct); err != nil {
acct, err := f.db.GetAccountByURI(id.String())
if err != nil {
return nil, err
}
l.Debug("is user path! returning account")

View file

@ -1,3 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package federatingdb
import (

View file

@ -1,3 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package federatingdb
import (

View file

@ -1,3 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package federatingdb
import (
@ -62,7 +80,7 @@ func (f *federatingDB) OutboxForInbox(c context.Context, inboxIRI *url.URL) (out
}
acct := &gtsmodel.Account{}
if err := f.db.GetWhere([]db.Where{{Key: "inbox_uri", Value: inboxIRI.String()}}, acct); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to inbox %s", inboxIRI.String())
}
return nil, fmt.Errorf("db error searching for actor with inbox %s", inboxIRI.String())

View file

@ -54,16 +54,16 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
if err != nil {
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
}
if err := f.db.GetWhere([]db.Where{{Key: "uri", Value: uid}}, &gtsmodel.Status{}); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
status, err := f.db.GetStatusByURI(uid)
if err != nil {
if err == db.ErrNoEntries {
// there are no entries for this status
return false, nil
}
// an actual error happened
return false, fmt.Errorf("database error fetching status with id %s: %s", uid, err)
}
l.Debugf("we own url %s", id.String())
return true, nil
return status.Local, nil
}
if util.IsUserPath(id) {
@ -71,8 +71,8 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
if err != nil {
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
}
if err := f.db.GetLocalAccountByUsername(username, &gtsmodel.Account{}); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if _, err := f.db.GetLocalAccountByUsername(username); err != nil {
if err == db.ErrNoEntries {
// there are no entries for this username
return false, nil
}
@ -88,8 +88,8 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
if err != nil {
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
}
if err := f.db.GetLocalAccountByUsername(username, &gtsmodel.Account{}); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if _, err := f.db.GetLocalAccountByUsername(username); err != nil {
if err == db.ErrNoEntries {
// there are no entries for this username
return false, nil
}
@ -105,8 +105,8 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
if err != nil {
return false, fmt.Errorf("error parsing statuses path for url %s: %s", id.String(), err)
}
if err := f.db.GetLocalAccountByUsername(username, &gtsmodel.Account{}); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if _, err := f.db.GetLocalAccountByUsername(username); err != nil {
if err == db.ErrNoEntries {
// there are no entries for this username
return false, nil
}
@ -122,8 +122,8 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
if err != nil {
return false, fmt.Errorf("error parsing like path for url %s: %s", id.String(), err)
}
if err := f.db.GetLocalAccountByUsername(username, &gtsmodel.Account{}); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if _, err := f.db.GetLocalAccountByUsername(username); err != nil {
if err == db.ErrNoEntries {
// there are no entries for this username
return false, nil
}
@ -131,7 +131,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
return false, fmt.Errorf("database error fetching account with username %s: %s", username, err)
}
if err := f.db.GetByID(likeID, &gtsmodel.StatusFave{}); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
// there are no entries
return false, nil
}
@ -147,8 +147,8 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
if err != nil {
return false, fmt.Errorf("error parsing block path for url %s: %s", id.String(), err)
}
if err := f.db.GetLocalAccountByUsername(username, &gtsmodel.Account{}); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if _, err := f.db.GetLocalAccountByUsername(username); err != nil {
if err == db.ErrNoEntries {
// there are no entries for this username
return false, nil
}
@ -156,7 +156,7 @@ func (f *federatingDB) Owns(c context.Context, id *url.URL) (bool, error) {
return false, fmt.Errorf("database error fetching account with username %s: %s", username, err)
}
if err := f.db.GetByID(blockID, &gtsmodel.Block{}); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
// there are no entries
return false, nil
}

View file

@ -1,3 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package federatingdb
import (

View file

@ -1,3 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package federatingdb
import (

View file

@ -97,8 +97,8 @@ func (f *federatingDB) NewID(c context.Context, t vocab.Type) (idURL *url.URL, e
for iter := actorProp.Begin(); iter != actorProp.End(); iter = iter.Next() {
// take the IRI of the first actor we can find (there should only be one)
if iter.IsIRI() {
actorAccount := &gtsmodel.Account{}
if err := f.db.GetWhere([]db.Where{{Key: "uri", Value: iter.GetIRI().String()}}, actorAccount); err == nil { // if there's an error here, just use the fallback behavior -- we don't need to return an error here
// if there's an error here, just use the fallback behavior -- we don't need to return an error here
if actorAccount, err := f.db.GetAccountByURI(iter.GetIRI().String()); err == nil {
newID, err := id.NewRandomULID()
if err != nil {
return nil, err
@ -213,7 +213,7 @@ func (f *federatingDB) ActorForOutbox(c context.Context, outboxIRI *url.URL) (ac
}
acct := &gtsmodel.Account{}
if err := f.db.GetWhere([]db.Where{{Key: "outbox_uri", Value: outboxIRI.String()}}, acct); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to outbox %s", outboxIRI.String())
}
return nil, fmt.Errorf("db error searching for actor with outbox %s", outboxIRI.String())
@ -238,7 +238,7 @@ func (f *federatingDB) ActorForInbox(c context.Context, inboxIRI *url.URL) (acto
}
acct := &gtsmodel.Account{}
if err := f.db.GetWhere([]db.Where{{Key: "inbox_uri", Value: inboxIRI.String()}}, acct); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
return nil, fmt.Errorf("no actor found that corresponds to inbox %s", inboxIRI.String())
}
return nil, fmt.Errorf("db error searching for actor with inbox %s", inboxIRI.String())

View file

@ -113,8 +113,8 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr
return nil, false, errors.New("username was empty")
}
requestedAccount := &gtsmodel.Account{}
if err := f.db.GetLocalAccountByUsername(username, requestedAccount); err != nil {
requestedAccount, err := f.db.GetLocalAccountByUsername(username)
if err != nil {
return nil, false, fmt.Errorf("could not fetch requested account with username %s: %s", username, err)
}
@ -132,7 +132,7 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr
// authentication has passed, so add an instance entry for this instance if it hasn't been done already
i := &gtsmodel.Instance{}
if err := f.db.GetWhere([]db.Where{{Key: "domain", Value: publicKeyOwnerURI.Host, CaseInsensitive: true}}, i); err != nil {
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
// there's been an actual error
return ctx, false, fmt.Errorf("error getting requesting account with public key id %s: %s", publicKeyOwnerURI.String(), err)
}
@ -176,8 +176,6 @@ func (f *federator) AuthenticatePostInbox(ctx context.Context, w http.ResponseWr
// Finally, if the authentication and authorization succeeds, then
// blocked must be false and error nil. The request will continue
// to be processed.
//
// TODO: implement domain block checking here as well
func (f *federator) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, error) {
l := f.log.WithFields(logrus.Fields{
"func": "Blocked",
@ -191,19 +189,18 @@ func (f *federator) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, er
return false, errors.New("requested account not set on request context, so couldn't determine blocks")
}
for _, uri := range actorIRIs {
blockedDomain, err := f.blockedDomain(uri.Host)
if err != nil {
return false, fmt.Errorf("error checking domain block: %s", err)
}
if blockedDomain {
return true, nil
}
blocked, err := f.db.AreURIsBlocked(actorIRIs)
if err != nil {
return false, fmt.Errorf("error checking domain blocks: %s", err)
}
if blocked {
return blocked, nil
}
requestingAccount := &gtsmodel.Account{}
if err := f.db.GetWhere([]db.Where{{Key: "uri", Value: uri.String()}}, requestingAccount); err != nil {
_, ok := err.(db.ErrNoEntries)
if ok {
for _, uri := range actorIRIs {
requestingAccount, err := f.db.GetAccountByURI(uri.String())
if err != nil {
if err == db.ErrNoEntries {
// we don't have an entry for this account so it's not blocked
// TODO: allow a different default to be set for this behavior
continue
@ -211,12 +208,11 @@ func (f *federator) Blocked(ctx context.Context, actorIRIs []*url.URL) (bool, er
return false, fmt.Errorf("error getting account with uri %s: %s", uri.String(), err)
}
// check if requested account blocks requesting account
if err := f.db.GetWhere([]db.Where{
{Key: "account_id", Value: requestedAccount.ID},
{Key: "target_account_id", Value: requestingAccount.ID},
}, &gtsmodel.Block{}); err == nil {
// a block exists
blocked, err = f.db.IsBlocked(requestedAccount.ID, requestingAccount.ID, true)
if err != nil {
return false, fmt.Errorf("error checking account block: %s", err)
}
if blocked {
return true, nil
}
}

View file

@ -30,7 +30,7 @@
)
func (f *federator) FingerRemoteAccount(requestingUsername string, targetUsername string, targetDomain string) (*url.URL, error) {
if blocked, err := f.blockedDomain(targetDomain); blocked || err != nil {
if blocked, err := f.db.IsDomainBlocked(targetDomain); blocked || err != nil {
return nil, fmt.Errorf("FingerRemoteAccount: domain %s is blocked", targetDomain)
}

View file

@ -1,3 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package federation
import "net/url"

View file

@ -1,3 +1,21 @@
/*
GoToSocial
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package federation
import (

View file

@ -1,23 +0,0 @@
package federation
import (
"github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
)
func (f *federator) blockedDomain(host string) (bool, error) {
b := &gtsmodel.DomainBlock{}
err := f.db.GetWhere([]db.Where{{Key: "domain", Value: host, CaseInsensitive: true}}, b)
if err == nil {
// block exists
return true, nil
}
if _, ok := err.(db.ErrNoEntries); ok {
// there are no entries so there's no block
return false, nil
}
// there's an actual error
return false, err
}

View file

@ -45,11 +45,13 @@ type Account struct {
*/
// ID of the avatar as a media attachment
AvatarMediaAttachmentID string `pg:"type:CHAR(26)"`
AvatarMediaAttachmentID string `pg:"type:CHAR(26)"`
AvatarMediaAttachment *MediaAttachment `pg:"rel:has-one"`
// For a non-local account, where can the header be fetched?
AvatarRemoteURL string
// ID of the header as a media attachment
HeaderMediaAttachmentID string `pg:"type:CHAR(26)"`
HeaderMediaAttachmentID string `pg:"type:CHAR(26)"`
HeaderMediaAttachment *MediaAttachment `pg:"rel:has-one"`
// For a non-local account, where can the header be fetched?
HeaderRemoteURL string
// DisplayName for this account. Can be empty, then just the Username will be used for display purposes.

View file

@ -31,7 +31,8 @@ type DomainBlock struct {
// When was this block updated
UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// Account ID of the creator of this block
CreatedByAccountID string `pg:"type:CHAR(26),notnull"`
CreatedByAccountID string `pg:"type:CHAR(26),notnull"`
CreatedByAccount *Account `pg:"rel:belongs-to"`
// Private comment on this block, viewable to admins
PrivateComment string
// Public comment on this block, viewable (optionally) by everyone

View file

@ -31,5 +31,6 @@ type EmailDomainBlock struct {
// When was this block updated
UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// Account ID of the creator of this block
CreatedByAccountID string `pg:"type:CHAR(26),notnull"`
CreatedByAccountID string `pg:"type:CHAR(26),notnull"`
CreatedByAccount *Account `pg:"rel:belongs-to"`
}

View file

@ -73,5 +73,6 @@ type Emoji struct {
// Is this emoji visible in the admin emoji picker?
VisibleInPicker bool `pg:",notnull,default:true"`
// In which emoji category is this emoji visible?
CategoryID string `pg:"type:CHAR(26)"`
CategoryID string `pg:"type:CHAR(26)"`
Status *Status `pg:"rel:belongs-to"`
}

View file

@ -29,9 +29,11 @@ type Follow struct {
// When was this follow last updated?
UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// Who does this follow belong to?
AccountID string `pg:"type:CHAR(26),unique:srctarget,notnull"`
AccountID string `pg:"type:CHAR(26),unique:srctarget,notnull"`
Account *Account `pg:"rel:belongs-to"`
// Who does AccountID follow?
TargetAccountID string `pg:"type:CHAR(26),unique:srctarget,notnull"`
TargetAccountID string `pg:"type:CHAR(26),unique:srctarget,notnull"`
TargetAccount *Account `pg:"rel:has-one"`
// Does this follow also want to see reblogs and not just posts?
ShowReblogs bool `pg:"default:true"`
// What is the activitypub URI of this follow?

View file

@ -29,9 +29,11 @@ type FollowRequest struct {
// When was this follow request last updated?
UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// Who does this follow request originate from?
AccountID string `pg:"type:CHAR(26),unique:srctarget,notnull"`
AccountID string `pg:"type:CHAR(26),unique:srctarget,notnull"`
Account Account `pg:"rel:has-one"`
// Who is the target of this follow request?
TargetAccountID string `pg:"type:CHAR(26),unique:srctarget,notnull"`
TargetAccountID string `pg:"type:CHAR(26),unique:srctarget,notnull"`
TargetAccount Account `pg:"rel:has-one"`
// Does this follow also want to see reblogs and not just posts?
ShowReblogs bool `pg:"default:true"`
// What is the activitypub URI of this follow request?

View file

@ -19,7 +19,8 @@ type Instance struct {
// When was this instance suspended, if at all?
SuspendedAt time.Time
// ID of any existing domain block for this instance in the database
DomainBlockID string `pg:"type:CHAR(26)"`
DomainBlockID string `pg:"type:CHAR(26)"`
DomainBlock *DomainBlock `pg:"rel:has-one"`
// Short description of this instance
ShortDescription string
// Longer description of this instance
@ -31,7 +32,8 @@ type Instance struct {
// Username of the contact account for this instance
ContactAccountUsername string
// Contact account ID in the database for this instance
ContactAccountID string `pg:"type:CHAR(26)"`
ContactAccountID string `pg:"type:CHAR(26)"`
ContactAccount *Account `pg:"rel:has-one"`
// Reputation score of this instance
Reputation int64 `pg:",notnull,default:0"`
// Version of the software used on this instance

View file

@ -42,7 +42,8 @@ type MediaAttachment struct {
// Metadata about the file
FileMeta FileMeta
// To which account does this attachment belong
AccountID string `pg:"type:CHAR(26),notnull"`
AccountID string `pg:"type:CHAR(26),notnull"`
Account *Account `pg:"rel:belongs-to"`
// Description of the attachment (for screenreaders)
Description string
// To which scheduled status does this attachment belong

View file

@ -25,17 +25,20 @@ type Mention struct {
// ID of this mention in the database
ID string `pg:"type:CHAR(26),pk,notnull,unique"`
// ID of the status this mention originates from
StatusID string `pg:"type:CHAR(26),notnull"`
StatusID string `pg:"type:CHAR(26),notnull"`
Status *Status `pg:"rel:belongs-to"`
// When was this mention created?
CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// When was this mention last updated?
UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// What's the internal account ID of the originator of the mention?
OriginAccountID string `pg:"type:CHAR(26),notnull"`
OriginAccountID string `pg:"type:CHAR(26),notnull"`
OriginAccount *Account `pg:"rel:has-one"`
// What's the AP URI of the originator of the mention?
OriginAccountURI string `pg:",notnull"`
// What's the internal account ID of the mention target?
TargetAccountID string `pg:"type:CHAR(26),notnull"`
TargetAccountID string `pg:"type:CHAR(26),notnull"`
TargetAccount *Account `pg:"rel:has-one"`
// Prevent this mention from generating a notification?
Silent bool
@ -52,14 +55,14 @@ type Mention struct {
//
// This will not be put in the database, it's just for convenience.
NameString string `pg:"-"`
// MentionedAccountURI is the AP ID (uri) of the user mentioned.
// TargetAccountURI is the AP ID (uri) of the user mentioned.
//
// This will not be put in the database, it's just for convenience.
MentionedAccountURI string `pg:"-"`
// MentionedAccountURL is the web url of the user mentioned.
TargetAccountURI string `pg:"-"`
// TargetAccountURL is the web url of the user mentioned.
//
// This will not be put in the database, it's just for convenience.
MentionedAccountURL string `pg:"-"`
TargetAccountURL string `pg:"-"`
// A pointer to the gtsmodel account of the mentioned account.
GTSAccount *Account `pg:"-"`
}

View file

@ -29,24 +29,16 @@ type Notification struct {
// Creation time of this notification
CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// Which account does this notification target (ie., who will receive the notification?)
TargetAccountID string `pg:"type:CHAR(26),notnull"`
TargetAccountID string `pg:"type:CHAR(26),notnull"`
TargetAccount *Account `pg:"rel:has-one"`
// Which account performed the action that created this notification?
OriginAccountID string `pg:"type:CHAR(26),notnull"`
OriginAccountID string `pg:"type:CHAR(26),notnull"`
OriginAccount *Account `pg:"rel:has-one"`
// If the notification pertains to a status, what is the database ID of that status?
StatusID string `pg:"type:CHAR(26)"`
StatusID string `pg:"type:CHAR(26)"`
Status *Status `pg:"rel:has-one"`
// Has this notification been read already?
Read bool
/*
NON-DATABASE fields
*/
// gts model of the target account, won't be put in the database, it's just for convenience when passing the notification around.
GTSTargetAccount *Account `pg:"-"`
// gts model of the origin account, won't be put in the database, it's just for convenience when passing the notification around.
GTSOriginAccount *Account `pg:"-"`
// gts model of the relevant status, won't be put in the database, it's just for convenience when passing the notification around.
GTSStatus *Status `pg:"-"`
}
// NotificationType describes the reason/type of this notification.

View file

@ -33,13 +33,17 @@ type Status struct {
// the html-formatted content of this status
Content string
// Database IDs of any media attachments associated with this status
Attachments []string `pg:",array"`
AttachmentIDs []string `pg:"attachments,array"`
Attachments []*MediaAttachment `pg:"attached_media,rel:has-many"`
// Database IDs of any tags used in this status
Tags []string `pg:",array"`
TagIDs []string `pg:"tags,array"`
Tags []*Tag `pg:"attached_tags,many2many:status_to_tags"` // https://pg.uptrace.dev/orm/many-to-many-relation/
// Database IDs of any mentions in this status
Mentions []string `pg:",array"`
MentionIDs []string `pg:"mentions,array"`
Mentions []*Mention `pg:"attached_mentions,rel:has-many"`
// Database IDs of any emojis used in this status
Emojis []string `pg:",array"`
EmojiIDs []string `pg:"emojis,array"`
Emojis []*Emoji `pg:"attached_emojis,many2many:status_to_emojis"` // https://pg.uptrace.dev/orm/many-to-many-relation/
// when was this status created?
CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// when was this status updated?
@ -47,19 +51,24 @@ type Status struct {
// is this status from a local account?
Local bool
// which account posted this status?
AccountID string `pg:"type:CHAR(26),notnull"`
AccountID string `pg:"type:CHAR(26),notnull"`
Account *Account `pg:"rel:has-one"`
// AP uri of the owner of this status
AccountURI string
// id of the status this status is a reply to
InReplyToID string `pg:"type:CHAR(26)"`
InReplyToID string `pg:"type:CHAR(26)"`
InReplyTo *Status `pg:"rel:has-one"`
// AP uri of the status this status is a reply to
InReplyToURI string
// id of the account that this status replies to
InReplyToAccountID string `pg:"type:CHAR(26)"`
InReplyToAccountID string `pg:"type:CHAR(26)"`
InReplyToAccount *Account `pg:"rel:has-one"`
// id of the status this status is a boost of
BoostOfID string `pg:"type:CHAR(26)"`
BoostOfID string `pg:"type:CHAR(26)"`
BoostOf *Status `pg:"rel:has-one"`
// id of the account that owns the boosted status
BoostOfAccountID string `pg:"type:CHAR(26)"`
BoostOfAccountID string `pg:"type:CHAR(26)"`
BoostOfAccount *Account `pg:"rel:has-one"`
// cw string for this status
ContentWarning string
// visibility entry for this status
@ -69,7 +78,8 @@ type Status struct {
// what language is this status written in?
Language string
// Which application was used to create this status?
CreatedWithApplicationID string `pg:"type:CHAR(26)"`
CreatedWithApplicationID string `pg:"type:CHAR(26)"`
CreatedWithApplication *Application `pg:"rel:has-one"`
// advanced visibility for this status
VisibilityAdvanced *VisibilityAdvanced
// What is the activitystreams type of this status? See: https://www.w3.org/TR/activitystreams-vocabulary/#object-types
@ -79,32 +89,18 @@ type Status struct {
Text string
// Has this status been pinned by its owner?
Pinned bool
}
/*
INTERNAL MODEL NON-DATABASE FIELDS
// StatusToTag is an intermediate struct to facilitate the many2many relationship between a status and one or more tags.
type StatusToTag struct {
StatusID string `pg:"unique:statustag"`
TagID string `pg:"unique:statustag"`
}
These are for convenience while passing the status around internally,
but these fields should *never* be put in the db.
*/
// Account that created this status
GTSAuthorAccount *Account `pg:"-"`
// Mentions created in this status
GTSMentions []*Mention `pg:"-"`
// Hashtags used in this status
GTSTags []*Tag `pg:"-"`
// Emojis used in this status
GTSEmojis []*Emoji `pg:"-"`
// MediaAttachments used in this status
GTSMediaAttachments []*MediaAttachment `pg:"-"`
// Status being replied to
GTSReplyToStatus *Status `pg:"-"`
// Account being replied to
GTSReplyToAccount *Account `pg:"-"`
// Status being boosted
GTSBoostedStatus *Status `pg:"-"`
// Account of the boosted status
GTSBoostedAccount *Account `pg:"-"`
// StatusToEmoji is an intermediate struct to facilitate the many2many relationship between a status and one or more emojis.
type StatusToEmoji struct {
StatusID string `pg:"unique:statusemoji"`
EmojiID string `pg:"unique:statusemoji"`
}
// Visibility represents the visibility granularity of a status.

View file

@ -27,9 +27,11 @@ type StatusBookmark struct {
// when was this bookmark created
CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// id of the account that created ('did') the bookmarking
AccountID string `pg:"type:CHAR(26),notnull"`
AccountID string `pg:"type:CHAR(26),notnull"`
Account *Account `pg:"rel:belongs-to"`
// id the account owning the bookmarked status
TargetAccountID string `pg:"type:CHAR(26),notnull"`
TargetAccountID string `pg:"type:CHAR(26),notnull"`
TargetAccount *Account `pg:"rel:has-one"`
// database id of the status that has been bookmarked
StatusID string `pg:"type:CHAR(26),notnull"`
}

View file

@ -27,18 +27,14 @@ type StatusFave struct {
// when was this fave created
CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// id of the account that created ('did') the fave
AccountID string `pg:"type:CHAR(26),notnull"`
AccountID string `pg:"type:CHAR(26),notnull"`
Account *Account `pg:"rel:has-one"`
// id the account owning the faved status
TargetAccountID string `pg:"type:CHAR(26),notnull"`
TargetAccountID string `pg:"type:CHAR(26),notnull"`
TargetAccount *Account `pg:"rel:has-one"`
// database id of the status that has been 'faved'
StatusID string `pg:"type:CHAR(26),notnull"`
StatusID string `pg:"type:CHAR(26),notnull"`
Status *Status `pg:"rel:has-one"`
// ActivityPub URI of this fave
URI string `pg:",notnull"`
// GTSStatus is the status being interacted with. It won't be put or retrieved from the db, it's just for conveniently passing a pointer around.
GTSStatus *Status `pg:"-"`
// GTSTargetAccount is the account being interacted with. It won't be put or retrieved from the db, it's just for conveniently passing a pointer around.
GTSTargetAccount *Account `pg:"-"`
// GTSFavingAccount is the account doing the faving. It won't be put or retrieved from the db, it's just for conveniently passing a pointer around.
GTSFavingAccount *Account `pg:"-"`
}

View file

@ -27,9 +27,12 @@ type StatusMute struct {
// when was this mute created
CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
// id of the account that created ('did') the mute
AccountID string `pg:"type:CHAR(26),notnull"`
AccountID string `pg:"type:CHAR(26),notnull"`
Account *Account `pg:"rel:belongs-to"`
// id the account owning the muted status (can be the same as accountID)
TargetAccountID string `pg:"type:CHAR(26),notnull"`
TargetAccountID string `pg:"type:CHAR(26),notnull"`
TargetAccount *Account `pg:"rel:has-one"`
// database id of the status that has been muted
StatusID string `pg:"type:CHAR(26),notnull"`
StatusID string `pg:"type:CHAR(26),notnull"`
Status *Status `pg:"rel:has-one"`
}

View file

@ -27,7 +27,7 @@ type Tag struct {
// Href of this tag, eg https://example.org/tags/somehashtag
URL string
// name of this tag -- the tag without the hash part
Name string `pg:",unique,pk,notnull"`
Name string `pg:",unique,notnull"`
// Which account ID is the first one we saw using this tag?
FirstSeenFromAccountID string `pg:"type:CHAR(26)"`
// when was this tag created

View file

@ -35,7 +35,8 @@ type User struct {
// confirmed email address for this user, this should be unique -- only one email address registered per instance, multiple users per email are not supported
Email string `pg:"default:null,unique"`
// The id of the local gtsmodel.Account entry for this user, if it exists (unconfirmed users don't have an account yet)
AccountID string `pg:"type:CHAR(26),unique"`
AccountID string `pg:"type:CHAR(26),unique"`
Account *Account `pg:"rel:has-one"`
// The encrypted password of this user, generated using https://pkg.go.dev/golang.org/x/crypto/bcrypt#GenerateFromPassword. A salt is included so we're safe against 🌈 tables
EncryptedPassword string `pg:",notnull"`
@ -68,7 +69,8 @@ type User struct {
// In what timezone/locale is this user located?
Locale string
// Which application id created this user? See gtsmodel.Application
CreatedByApplicationID string `pg:"type:CHAR(26)"`
CreatedByApplicationID string `pg:"type:CHAR(26)"`
CreatedByApplication *Application `pg:"rel:has-one"`
// When did we last contact this user
LastEmailedAt time.Time `pg:"type:timestamp"`

View file

@ -142,7 +142,7 @@ func (mh *mediaHandler) ProcessHeaderOrAvatar(attachment []byte, accountID strin
}
// set it in the database
if err := mh.db.SetHeaderOrAvatarForAccountID(ma, accountID); err != nil {
if err := mh.db.SetAccountHeaderOrAvatar(ma, accountID); err != nil {
return nil, fmt.Errorf("error putting %s in database: %s", mediaType, err)
}
@ -231,8 +231,8 @@ func (mh *mediaHandler) ProcessLocalEmoji(emojiBytes []byte, shortcode string) (
// since emoji aren't 'owned' by an account, but we still want to use the same pattern for serving them through the filserver,
// (ie., fileserver/ACCOUNT_ID/etc etc) we need to fetch the INSTANCE ACCOUNT from the database. That is, the account that's created
// with the same username as the instance hostname, which doesn't belong to any particular user.
instanceAccount := &gtsmodel.Account{}
if err := mh.db.GetLocalAccountByUsername(mh.config.Host, instanceAccount); err != nil {
instanceAccount, err := mh.db.GetInstanceAccount("")
if err != nil {
return nil, fmt.Errorf("error fetching instance account: %s", err)
}

View file

@ -27,11 +27,11 @@
)
type clientStore struct {
db db.DB
db db.Basic
}
// NewClientStore returns an implementation of the oauth2 ClientStore interface, using the given db as a storage backend.
func NewClientStore(db db.DB) oauth2.ClientStore {
func NewClientStore(db db.Basic) oauth2.ClientStore {
pts := &clientStore{
db: db,
}

View file

@ -99,7 +99,7 @@ func (suite *PgClientStoreTestSuite) TestClientSetAndDelete() {
// try to get the deleted client; we should get an error
deletedClient, err := cs.GetByID(context.Background(), suite.testClientID)
suite.Assert().Nil(deletedClient)
suite.Assert().EqualValues(db.ErrNoEntries{}, err)
suite.Assert().EqualValues(db.ErrNoEntries, err)
}
func TestPgClientStoreTestSuite(t *testing.T) {

View file

@ -66,7 +66,7 @@ type s struct {
}
// New returns a new oauth server that implements the Server interface
func New(database db.DB, log *logrus.Logger) Server {
func New(database db.Basic, log *logrus.Logger) Server {
ts := newTokenStore(context.Background(), database, log)
cs := NewClientStore(database)

View file

@ -34,7 +34,7 @@
// tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend.
type tokenStore struct {
oauth2.TokenStore
db db.DB
db db.Basic
log *logrus.Logger
}
@ -42,7 +42,7 @@ type tokenStore struct {
//
// In order to allow tokens to 'expire', it will also set off a goroutine that iterates through
// the tokens in the DB once per minute and deletes any that have expired.
func newTokenStore(ctx context.Context, db db.DB, log *logrus.Logger) oauth2.TokenStore {
func newTokenStore(ctx context.Context, db db.Basic, log *logrus.Logger) oauth2.TokenStore {
pts := &tokenStore{
db: db,
log: log,

View file

@ -31,24 +31,20 @@
func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
// make sure the target account actually exists in our db
targetAcct := &gtsmodel.Account{}
if err := p.db.GetByID(targetAccountID, targetAcct); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: account %s not found in the db: %s", targetAccountID, err))
}
targetAccount, err := p.db.GetAccountByID(targetAccountID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: error getting account %s from the db: %s", targetAccountID, err))
}
// if requestingAccount already blocks target account, we don't need to do anything
block := &gtsmodel.Block{}
if err := p.db.GetWhere([]db.Where{
{Key: "account_id", Value: requestingAccount.ID},
{Key: "target_account_id", Value: targetAccountID},
}, block); err == nil {
// block already exists, just return relationship
if blocked, err := p.db.IsBlocked(requestingAccount.ID, targetAccountID, false); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockCreate: error checking existence of block: %s", err))
} else if blocked {
return p.RelationshipGet(requestingAccount, targetAccountID)
}
// make the block
block := &gtsmodel.Block{}
newBlockID, err := id.NewULID()
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
@ -57,7 +53,7 @@ func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccou
block.AccountID = requestingAccount.ID
block.Account = requestingAccount
block.TargetAccountID = targetAccountID
block.TargetAccount = targetAcct
block.TargetAccount = targetAccount
block.URI = util.GenerateURIForBlock(requestingAccount.Username, p.config.Protocol, p.config.Host, newBlockID)
// whack it in the database
@ -123,7 +119,7 @@ func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccou
URI: frURI,
},
OriginAccount: requestingAccount,
TargetAccount: targetAcct,
TargetAccount: targetAccount,
}
}
@ -138,7 +134,7 @@ func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccou
URI: fURI,
},
OriginAccount: requestingAccount,
TargetAccount: targetAcct,
TargetAccount: targetAccount,
}
}
@ -148,7 +144,7 @@ func (p *processor) BlockCreate(requestingAccount *gtsmodel.Account, targetAccou
APActivityType: gtsmodel.ActivityStreamsCreate,
GTSModel: block,
OriginAccount: requestingAccount,
TargetAccount: targetAcct,
TargetAccount: targetAccount,
}
return p.RelationshipGet(requestingAccount, targetAccountID)

View file

@ -31,38 +31,33 @@
func (p *processor) FollowCreate(requestingAccount *gtsmodel.Account, form *apimodel.AccountFollowRequest) (*apimodel.Relationship, gtserror.WithCode) {
// if there's a block between the accounts we shouldn't create the request ofc
blocked, err := p.db.Blocked(requestingAccount.ID, form.ID)
if err != nil {
if blocked, err := p.db.IsBlocked(requestingAccount.ID, form.ID, true); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
if blocked {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("accountfollowcreate: block exists between accounts"))
} else if blocked {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
}
// make sure the target account actually exists in our db
targetAcct := &gtsmodel.Account{}
if err := p.db.GetByID(form.ID, targetAcct); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
targetAcct, err := p.db.GetAccountByID(form.ID)
if err != nil {
if err == db.ErrNoEntries {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("accountfollowcreate: account %s not found in the db: %s", form.ID, err))
}
return nil, gtserror.NewErrorInternalError(err)
}
// check if a follow exists already
follows, err := p.db.Follows(requestingAccount, targetAcct)
if err != nil {
if follows, err := p.db.IsFollowing(requestingAccount, targetAcct); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error checking follow in db: %s", err))
}
if follows {
} else if follows {
// already follows so just return the relationship
return p.RelationshipGet(requestingAccount, form.ID)
}
// check if a follow exists already
followRequested, err := p.db.FollowRequested(requestingAccount, targetAcct)
if err != nil {
// check if a follow request exists already
if followRequested, err := p.db.IsFollowRequested(requestingAccount, targetAcct); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("accountfollowcreate: error checking follow request in db: %s", err))
}
if followRequested {
} else if followRequested {
// already follow requested so just return the relationship
return p.RelationshipGet(requestingAccount, form.ID)
}

View file

@ -133,9 +133,9 @@ func (p *processor) Delete(account *gtsmodel.Account, origin string) error {
var maxID string
selectStatusesLoop:
for {
statuses, err := p.db.GetStatusesForAccount(account.ID, 20, false, maxID, false, false)
statuses, err := p.db.GetAccountStatuses(account.ID, 20, false, maxID, false, false)
if err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
// no statuses left for this instance so we're done
l.Infof("Delete: done iterating through statuses for account %s", account.Username)
break selectStatusesLoop
@ -147,7 +147,7 @@ func (p *processor) Delete(account *gtsmodel.Account, origin string) error {
for i, s := range statuses {
// pass the status delete through the client api channel for processing
s.GTSAuthorAccount = account
s.Account = account
l.Debug("putting status in the client api channel")
p.fromClientAPI <- gtsmodel.FromClientAPI{
APObjectType: gtsmodel.ActivityStreamsNote,
@ -158,7 +158,7 @@ func (p *processor) Delete(account *gtsmodel.Account, origin string) error {
}
if err := p.db.DeleteByID(s.ID, s); err != nil {
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
// actual error has occurred
l.Errorf("Delete: db error status %s for account %s: %s", s.ID, account.Username, err)
break selectStatusesLoop
@ -168,7 +168,7 @@ func (p *processor) Delete(account *gtsmodel.Account, origin string) error {
// if there are any boosts of this status, delete them as well
boosts := []*gtsmodel.Status{}
if err := p.db.GetWhere([]db.Where{{Key: "boost_of_id", Value: s.ID}}, &boosts); err != nil {
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
// an actual error has occurred
l.Errorf("Delete: db error selecting boosts of status %s for account %s: %s", s.ID, account.Username, err)
break selectStatusesLoop
@ -190,7 +190,7 @@ func (p *processor) Delete(account *gtsmodel.Account, origin string) error {
}
if err := p.db.DeleteByID(b.ID, b); err != nil {
if _, ok := err.(db.ErrNoEntries); !ok {
if err != db.ErrNoEntries {
// actual error has occurred
l.Errorf("Delete: db error deleting boost with id %s: %s", b.ID, err)
break selectStatusesLoop

View file

@ -30,7 +30,7 @@
func (p *processor) Get(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Account, error) {
targetAccount := &gtsmodel.Account{}
if err := p.db.GetByID(targetAccountID, targetAccount); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
return nil, errors.New("account not found")
}
return nil, fmt.Errorf("db error: %s", err)
@ -39,7 +39,7 @@ func (p *processor) Get(requestingAccount *gtsmodel.Account, targetAccountID str
var blocked bool
var err error
if requestingAccount != nil {
blocked, err = p.db.Blocked(requestingAccount.ID, targetAccountID)
blocked, err = p.db.IsBlocked(requestingAccount.ID, targetAccountID, true)
if err != nil {
return nil, fmt.Errorf("error checking account block: %s", err)
}

View file

@ -28,26 +28,23 @@
)
func (p *processor) FollowersGet(requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
blocked, err := p.db.Blocked(requestingAccount.ID, targetAccountID)
if err != nil {
if blocked, err := p.db.IsBlocked(requestingAccount.ID, targetAccountID, true); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
if blocked {
} else if blocked {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
}
followers := []gtsmodel.Follow{}
accounts := []apimodel.Account{}
if err := p.db.GetFollowersByAccountID(targetAccountID, &followers, false); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
follows, err := p.db.GetAccountFollowedBy(targetAccountID, false)
if err != nil {
if err == db.ErrNoEntries {
return accounts, nil
}
return nil, gtserror.NewErrorInternalError(err)
}
for _, f := range followers {
blocked, err := p.db.Blocked(requestingAccount.ID, f.AccountID)
for _, f := range follows {
blocked, err := p.db.IsBlocked(requestingAccount.ID, f.AccountID, true)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@ -55,15 +52,18 @@ func (p *processor) FollowersGet(requestingAccount *gtsmodel.Account, targetAcco
continue
}
a := &gtsmodel.Account{}
if err := p.db.GetByID(f.AccountID, a); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
continue
if f.Account == nil {
a, err := p.db.GetAccountByID(f.AccountID)
if err != nil {
if err == db.ErrNoEntries {
continue
}
return nil, gtserror.NewErrorInternalError(err)
}
return nil, gtserror.NewErrorInternalError(err)
f.Account = a
}
account, err := p.tc.AccountToMastoPublic(a)
account, err := p.tc.AccountToMastoPublic(f.Account)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}

View file

@ -28,26 +28,23 @@
)
func (p *processor) FollowingGet(requestingAccount *gtsmodel.Account, targetAccountID string) ([]apimodel.Account, gtserror.WithCode) {
blocked, err := p.db.Blocked(requestingAccount.ID, targetAccountID)
if err != nil {
if blocked, err := p.db.IsBlocked(requestingAccount.ID, targetAccountID, true); err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
if blocked {
} else if blocked {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
}
following := []gtsmodel.Follow{}
accounts := []apimodel.Account{}
if err := p.db.GetFollowingByAccountID(targetAccountID, &following); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
follows, err := p.db.GetAccountFollows(targetAccountID)
if err != nil {
if err == db.ErrNoEntries {
return accounts, nil
}
return nil, gtserror.NewErrorInternalError(err)
}
for _, f := range following {
blocked, err := p.db.Blocked(requestingAccount.ID, f.AccountID)
for _, f := range follows {
blocked, err := p.db.IsBlocked(requestingAccount.ID, f.AccountID, true)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}
@ -55,15 +52,18 @@ func (p *processor) FollowingGet(requestingAccount *gtsmodel.Account, targetAcco
continue
}
a := &gtsmodel.Account{}
if err := p.db.GetByID(f.TargetAccountID, a); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
continue
if f.TargetAccount == nil {
a, err := p.db.GetAccountByID(f.TargetAccountID)
if err != nil {
if err == db.ErrNoEntries {
continue
}
return nil, gtserror.NewErrorInternalError(err)
}
return nil, gtserror.NewErrorInternalError(err)
f.TargetAccount = a
}
account, err := p.tc.AccountToMastoPublic(a)
account, err := p.tc.AccountToMastoPublic(f.TargetAccount)
if err != nil {
return nil, gtserror.NewErrorInternalError(err)
}

View file

@ -28,18 +28,17 @@
)
func (p *processor) StatusesGet(requestingAccount *gtsmodel.Account, targetAccountID string, limit int, excludeReplies bool, maxID string, pinnedOnly bool, mediaOnly bool) ([]apimodel.Status, gtserror.WithCode) {
targetAccount := &gtsmodel.Account{}
if err := p.db.GetByID(targetAccountID, targetAccount); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("no entry found for account id %s", targetAccountID))
}
if blocked, err := p.db.IsBlocked(requestingAccount.ID, targetAccountID, true); err != nil {
return nil, gtserror.NewErrorInternalError(err)
} else if blocked {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("block exists between accounts"))
}
apiStatuses := []apimodel.Status{}
statuses, err := p.db.GetStatusesForAccount(targetAccountID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly)
statuses, err := p.db.GetAccountStatuses(targetAccountID, limit, excludeReplies, maxID, pinnedOnly, mediaOnly)
if err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
if err == db.ErrNoEntries {
return apiStatuses, nil
}
return nil, gtserror.NewErrorInternalError(err)

View file

@ -29,11 +29,9 @@
func (p *processor) BlockRemove(requestingAccount *gtsmodel.Account, targetAccountID string) (*apimodel.Relationship, gtserror.WithCode) {
// make sure the target account actually exists in our db
targetAcct := &gtsmodel.Account{}
if err := p.db.GetByID(targetAccountID, targetAcct); err != nil {
if _, ok := err.(db.ErrNoEntries); ok {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockRemove: account %s not found in the db: %s", targetAccountID, err))
}
targetAccount, err := p.db.GetAccountByID(targetAccountID)
if err != nil {
return nil, gtserror.NewErrorNotFound(fmt.Errorf("BlockCreate: error getting account %s from the db: %s", targetAccountID, err))
}
// check if a block exists, and remove it if it does (storing the URI for later)
@ -44,7 +42,7 @@ func (p *processor) BlockRemove(requestingAccount *gtsmodel.Account, targetAccou
{Key: "target_account_id", Value: targetAccountID},
}, block); err == nil {
block.Account = requestingAccount
block.TargetAccount = targetAcct
block.TargetAccount = targetAccount
if err := p.db.DeleteByID(block.ID, &gtsmodel.Block{}); err != nil {
return nil, gtserror.NewErrorInternalError(fmt.Errorf("BlockRemove: error removing block from db: %s", err))
}
@ -58,7 +56,7 @@ func (p *processor) BlockRemove(requestingAccount *gtsmodel.Account, targetAccou
APActivityType: gtsmodel.ActivityStreamsUndo,
GTSModel: block,
OriginAccount: requestingAccount,
TargetAccount: targetAcct,
TargetAccount: targetAccount,
}
}

Some files were not shown because too many files have changed in this diff Show more