mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2024-11-22 11:46:40 +00:00
Oauth/token (#7)
* add host and protocol options * some fiddling * tidying up and comments * tick off /oauth/token * tidying a bit * tidying * go mod tidy * allow attaching middleware to server * add middleware * more user friendly * add comments * comments * store account + app * tidying * lots of restructuring * lint + tidy
This commit is contained in:
parent
4194f8d88f
commit
aa9ce272dc
|
@ -6,7 +6,7 @@
|
||||||
* [ ] /api/v1/apps/verify_credentials GET (Verify an application works)
|
* [ ] /api/v1/apps/verify_credentials GET (Verify an application works)
|
||||||
* [x] /oauth/authorize GET (Show authorize page to user)
|
* [x] /oauth/authorize GET (Show authorize page to user)
|
||||||
* [x] /oauth/authorize POST (Get an oauth access code for an app/user)
|
* [x] /oauth/authorize POST (Get an oauth access code for an app/user)
|
||||||
* [ ] /oauth/token POST (Obtain a user-level access token)
|
* [x] /oauth/token POST (Obtain a user-level access token)
|
||||||
* [ ] /oauth/revoke POST (Revoke a user-level access token)
|
* [ ] /oauth/revoke POST (Revoke a user-level access token)
|
||||||
* [x] /auth/sign_in GET (Show form for user signin)
|
* [x] /auth/sign_in GET (Show form for user signin)
|
||||||
* [x] /auth/sign_in POST (Validate username and password and sign user in)
|
* [x] /auth/sign_in POST (Validate username and password and sign user in)
|
||||||
|
|
|
@ -58,6 +58,18 @@ func main() {
|
||||||
Value: "",
|
Value: "",
|
||||||
EnvVars: []string{envNames.ConfigPath},
|
EnvVars: []string{envNames.ConfigPath},
|
||||||
},
|
},
|
||||||
|
&cli.StringFlag{
|
||||||
|
Name: flagNames.Host,
|
||||||
|
Usage: "Hostname to use for the server (eg., example.org, gotosocial.whatever.com)",
|
||||||
|
Value: "localhost",
|
||||||
|
EnvVars: []string{envNames.Host},
|
||||||
|
},
|
||||||
|
&cli.StringFlag{
|
||||||
|
Name: flagNames.Protocol,
|
||||||
|
Usage: "Protocol to use for the REST api of the server (only use http for debugging and tests!)",
|
||||||
|
Value: "https",
|
||||||
|
EnvVars: []string{envNames.Protocol},
|
||||||
|
},
|
||||||
|
|
||||||
// DATABASE FLAGS
|
// DATABASE FLAGS
|
||||||
&cli.StringFlag{
|
&cli.StringFlag{
|
||||||
|
|
|
@ -28,6 +28,17 @@ logLevel: "info"
|
||||||
# Default: "gotosocial"
|
# Default: "gotosocial"
|
||||||
applicationName: "gotosocial"
|
applicationName: "gotosocial"
|
||||||
|
|
||||||
|
# String. Hostname/domain to use for the server. Defaults to localhost for local testing,
|
||||||
|
# but you should *definitely* change this when running for real, or your server won't work at all.
|
||||||
|
# Examples: ["example.org","some.server.com"]
|
||||||
|
# Default: "localhost"
|
||||||
|
host: "localhost"
|
||||||
|
|
||||||
|
# String. Protocol to use for the server. Only change to http for local testing!
|
||||||
|
# Options: ["http","https"]
|
||||||
|
# Default: "https"
|
||||||
|
protocol: "https"
|
||||||
|
|
||||||
# Config pertaining to the Gotosocial database connection
|
# Config pertaining to the Gotosocial database connection
|
||||||
db:
|
db:
|
||||||
# String. Database type.
|
# String. Database type.
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -10,7 +10,7 @@ require (
|
||||||
github.com/go-pg/pg/v10 v10.8.0
|
github.com/go-pg/pg/v10 v10.8.0
|
||||||
github.com/golang/mock v1.4.4 // indirect
|
github.com/golang/mock v1.4.4 // indirect
|
||||||
github.com/google/uuid v1.2.0
|
github.com/google/uuid v1.2.0
|
||||||
github.com/gotosocial/oauth2/v4 v4.2.1-0.20210318133800-45d321d259b3
|
github.com/gotosocial/oauth2/v4 v4.2.1-0.20210316171520-7b12112bbb88
|
||||||
github.com/onsi/ginkgo v1.15.0 // indirect
|
github.com/onsi/ginkgo v1.15.0 // indirect
|
||||||
github.com/onsi/gomega v1.10.5 // indirect
|
github.com/onsi/gomega v1.10.5 // indirect
|
||||||
github.com/sirupsen/logrus v1.8.0
|
github.com/sirupsen/logrus v1.8.0
|
||||||
|
|
4
go.sum
4
go.sum
|
@ -103,8 +103,8 @@ github.com/gorilla/sessions v1.1.3 h1:uXoZdcdA5XdXF3QzuSlheVRUvjl+1rKY7zBXL68L9R
|
||||||
github.com/gorilla/sessions v1.1.3/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w=
|
github.com/gorilla/sessions v1.1.3/go.mod h1:8KCfur6+4Mqcc6S0FEfKuN15Vl5MgXW92AE8ovaJD0w=
|
||||||
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
|
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
|
||||||
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
github.com/gotosocial/oauth2/v4 v4.2.1-0.20210318133800-45d321d259b3 h1:CKRz5d7mRum+UMR88Ue33tCYcej14WjUsB59C02DDqY=
|
github.com/gotosocial/oauth2/v4 v4.2.1-0.20210316171520-7b12112bbb88 h1:YJ//HmHOYJ4srm/LA6VPNjNisneMbY6TTM1xttV/ZQU=
|
||||||
github.com/gotosocial/oauth2/v4 v4.2.1-0.20210318133800-45d321d259b3/go.mod h1:zl5kwHf/atRUrY5yOyDnk49Us1Ygs0BzdW4jKAgoiP8=
|
github.com/gotosocial/oauth2/v4 v4.2.1-0.20210316171520-7b12112bbb88/go.mod h1:zl5kwHf/atRUrY5yOyDnk49Us1Ygs0BzdW4jKAgoiP8=
|
||||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||||
github.com/imkira/go-interpol v1.1.0 h1:KIiKr0VSG2CUW1hl1jpiyuzuJeKUUpC8iM1AIE7N1Vk=
|
github.com/imkira/go-interpol v1.1.0 h1:KIiKr0VSG2CUW1hl1jpiyuzuJeKUUpC8iM1AIE7N1Vk=
|
||||||
github.com/imkira/go-interpol v1.1.0/go.mod h1:z0h2/2T3XF8kyEPpRgJ3kmNv+C43p+I/CoI+jC3w2iA=
|
github.com/imkira/go-interpol v1.1.0/go.mod h1:z0h2/2T3XF8kyEPpRgJ3kmNv+C43p+I/CoI+jC3w2iA=
|
||||||
|
|
|
@ -1,87 +0,0 @@
|
||||||
/*
|
|
||||||
GoToSocial
|
|
||||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
|
||||||
|
|
||||||
This program is free software: you can redistribute it and/or modify
|
|
||||||
it under the terms of the GNU Affero General Public License as published by
|
|
||||||
the Free Software Foundation, either version 3 of the License, or
|
|
||||||
(at your option) any later version.
|
|
||||||
|
|
||||||
This program is distributed in the hope that it will be useful,
|
|
||||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
||||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
||||||
GNU Affero General Public License for more details.
|
|
||||||
|
|
||||||
You should have received a copy of the GNU Affero General Public License
|
|
||||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package api
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
"github.com/gin-contrib/sessions"
|
|
||||||
"github.com/gin-contrib/sessions/memstore"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/gotosocial/gotosocial/internal/config"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Server interface {
|
|
||||||
AttachHandler(method string, path string, handler gin.HandlerFunc)
|
|
||||||
// AttachMiddleware(handler gin.HandlerFunc)
|
|
||||||
GetAPIGroup() *gin.RouterGroup
|
|
||||||
Start()
|
|
||||||
Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
type AddsRoutes interface {
|
|
||||||
AddRoutes(s Server) error
|
|
||||||
}
|
|
||||||
|
|
||||||
type server struct {
|
|
||||||
APIGroup *gin.RouterGroup
|
|
||||||
logger *logrus.Logger
|
|
||||||
engine *gin.Engine
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *server) GetAPIGroup() *gin.RouterGroup {
|
|
||||||
return s.APIGroup
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *server) Start() {
|
|
||||||
// todo: start gracefully
|
|
||||||
if err := s.engine.Run(); err != nil {
|
|
||||||
s.logger.Panicf("server error: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *server) Stop() {
|
|
||||||
// todo: shut down gracefully
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *server) AttachHandler(method string, path string, handler gin.HandlerFunc) {
|
|
||||||
if method == "ANY" {
|
|
||||||
s.engine.Any(path, handler)
|
|
||||||
} else {
|
|
||||||
s.engine.Handle(method, path, handler)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func New(config *config.Config, logger *logrus.Logger) Server {
|
|
||||||
engine := gin.New()
|
|
||||||
store := memstore.NewStore([]byte("authentication-key"), []byte("encryption-keyencryption-key----"))
|
|
||||||
engine.Use(sessions.Sessions("gotosocial-session", store))
|
|
||||||
cwd, _ := os.Getwd()
|
|
||||||
tmPath := filepath.Join(cwd, fmt.Sprintf("%s*", config.TemplateConfig.BaseDir))
|
|
||||||
logger.Debugf("loading templates from %s", tmPath)
|
|
||||||
engine.LoadHTMLGlob(tmPath)
|
|
||||||
return &server{
|
|
||||||
APIGroup: engine.Group("/api").Group("/v1"),
|
|
||||||
logger: logger,
|
|
||||||
engine: engine,
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -29,6 +29,8 @@
|
||||||
type Config struct {
|
type Config struct {
|
||||||
LogLevel string `yaml:"logLevel"`
|
LogLevel string `yaml:"logLevel"`
|
||||||
ApplicationName string `yaml:"applicationName"`
|
ApplicationName string `yaml:"applicationName"`
|
||||||
|
Host string `yaml:"host"`
|
||||||
|
Protocol string `yaml:"protocol"`
|
||||||
DBConfig *DBConfig `yaml:"db"`
|
DBConfig *DBConfig `yaml:"db"`
|
||||||
TemplateConfig *TemplateConfig `yaml:"template"`
|
TemplateConfig *TemplateConfig `yaml:"template"`
|
||||||
}
|
}
|
||||||
|
@ -97,6 +99,14 @@ func (c *Config) ParseCLIFlags(f KeyedFlags) {
|
||||||
c.ApplicationName = f.String(fn.ApplicationName)
|
c.ApplicationName = f.String(fn.ApplicationName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.Host == "" || f.IsSet(fn.Host) {
|
||||||
|
c.Host = f.String(fn.Host)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Protocol == "" || f.IsSet(fn.Protocol) {
|
||||||
|
c.Protocol = f.String(fn.Protocol)
|
||||||
|
}
|
||||||
|
|
||||||
// db flags
|
// db flags
|
||||||
if c.DBConfig.Type == "" || f.IsSet(fn.DbType) {
|
if c.DBConfig.Type == "" || f.IsSet(fn.DbType) {
|
||||||
c.DBConfig.Type = f.String(fn.DbType)
|
c.DBConfig.Type = f.String(fn.DbType)
|
||||||
|
@ -142,6 +152,8 @@ type Flags struct {
|
||||||
LogLevel string
|
LogLevel string
|
||||||
ApplicationName string
|
ApplicationName string
|
||||||
ConfigPath string
|
ConfigPath string
|
||||||
|
Host string
|
||||||
|
Protocol string
|
||||||
DbType string
|
DbType string
|
||||||
DbAddress string
|
DbAddress string
|
||||||
DbPort string
|
DbPort string
|
||||||
|
@ -158,6 +170,8 @@ func GetFlagNames() Flags {
|
||||||
LogLevel: "log-level",
|
LogLevel: "log-level",
|
||||||
ApplicationName: "application-name",
|
ApplicationName: "application-name",
|
||||||
ConfigPath: "config-path",
|
ConfigPath: "config-path",
|
||||||
|
Host: "host",
|
||||||
|
Protocol: "protocol",
|
||||||
DbType: "db-type",
|
DbType: "db-type",
|
||||||
DbAddress: "db-address",
|
DbAddress: "db-address",
|
||||||
DbPort: "db-port",
|
DbPort: "db-port",
|
||||||
|
@ -175,6 +189,8 @@ func GetEnvNames() Flags {
|
||||||
LogLevel: "GTS_LOG_LEVEL",
|
LogLevel: "GTS_LOG_LEVEL",
|
||||||
ApplicationName: "GTS_APPLICATION_NAME",
|
ApplicationName: "GTS_APPLICATION_NAME",
|
||||||
ConfigPath: "GTS_CONFIG_PATH",
|
ConfigPath: "GTS_CONFIG_PATH",
|
||||||
|
Host: "GTS_HOST",
|
||||||
|
Protocol: "GTS_PROTOCOL",
|
||||||
DbType: "GTS_DB_TYPE",
|
DbType: "GTS_DB_TYPE",
|
||||||
DbAddress: "GTS_DB_ADDRESS",
|
DbAddress: "GTS_DB_ADDRESS",
|
||||||
DbPort: "GTS_DB_PORT",
|
DbPort: "GTS_DB_PORT",
|
||||||
|
|
|
@ -28,9 +28,10 @@
|
||||||
|
|
||||||
// Initialize will initialize the database given in the config for use with GoToSocial
|
// Initialize will initialize the database given in the config for use with GoToSocial
|
||||||
var Initialize action.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error {
|
var Initialize action.GTSAction = func(ctx context.Context, c *config.Config, log *logrus.Logger) error {
|
||||||
db, err := New(ctx, c, log)
|
// db, err := New(ctx, c, log)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
return err
|
// return err
|
||||||
}
|
// }
|
||||||
return db.CreateSchema(ctx)
|
return nil
|
||||||
|
// return db.CreateSchema(ctx)
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,30 +30,47 @@
|
||||||
|
|
||||||
const dbTypePostgres string = "POSTGRES"
|
const dbTypePostgres string = "POSTGRES"
|
||||||
|
|
||||||
// DB provides methods for interacting with an underlying database (for now, just postgres).
|
// DB provides methods for interacting with an underlying database or other storage mechanism (for now, just postgres).
|
||||||
// The function mapping lines up with the DB interface described in go-fed.
|
|
||||||
// See here: https://github.com/go-fed/activity/blob/master/pub/database.go
|
|
||||||
type DB interface {
|
type DB interface {
|
||||||
/*
|
// Federation returns an interface that's compatible with go-fed, for performing federation storage/retrieval functions.
|
||||||
GO-FED DATABASE FUNCTIONS
|
// See: https://pkg.go.dev/github.com/go-fed/activity@v1.0.0/pub?utm_source=gopls#Database
|
||||||
*/
|
Federation() pub.Database
|
||||||
pub.Database
|
|
||||||
|
|
||||||
/*
|
// CreateTable creates a table for the given interface
|
||||||
ANY ADDITIONAL DESIRED FUNCTIONS
|
CreateTable(i interface{}) error
|
||||||
*/
|
|
||||||
|
|
||||||
// CreateSchema should populate the database with the required tables
|
// DropTable drops the table for the given interface
|
||||||
CreateSchema(context.Context) error
|
DropTable(i interface{}) error
|
||||||
|
|
||||||
// Stop should stop and close the database connection cleanly, returning an error if this is not possible
|
// Stop should stop and close the database connection cleanly, returning an error if this is not possible
|
||||||
Stop(context.Context) error
|
Stop(ctx context.Context) error
|
||||||
|
|
||||||
// IsHealthy should return nil if the database connection is healthy, or an error if not
|
// IsHealthy should return nil if the database connection is healthy, or an error if not
|
||||||
IsHealthy(context.Context) error
|
IsHealthy(ctx context.Context) error
|
||||||
|
|
||||||
|
// GetByID gets one entry by its id.
|
||||||
|
GetByID(id string, i interface{}) error
|
||||||
|
|
||||||
|
// GetWhere gets one entry where key = value
|
||||||
|
GetWhere(key string, value interface{}, i interface{}) error
|
||||||
|
|
||||||
|
// GetAll gets all entries of interface type i
|
||||||
|
GetAll(i interface{}) error
|
||||||
|
|
||||||
|
// Put stores i
|
||||||
|
Put(i interface{}) error
|
||||||
|
|
||||||
|
// Update by id updates i with id id
|
||||||
|
UpdateByID(id string, i interface{}) error
|
||||||
|
|
||||||
|
// Delete by id removes i with id id
|
||||||
|
DeleteByID(id string, i interface{}) error
|
||||||
|
|
||||||
|
// Delete where deletes i where key = value
|
||||||
|
DeleteWhere(key string, value interface{}, i interface{}) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// New returns a new database service that satisfies the Service interface and, by extension,
|
// New returns a new database service that satisfies the DB interface and, by extension,
|
||||||
// the go-fed database interface described here: https://github.com/go-fed/activity/blob/master/pub/database.go
|
// the go-fed database interface described here: https://github.com/go-fed/activity/blob/master/pub/database.go
|
||||||
func New(ctx context.Context, c *config.Config, log *logrus.Logger) (DB, error) {
|
func New(ctx context.Context, c *config.Config, log *logrus.Logger) (DB, error) {
|
||||||
switch strings.ToUpper(c.DBConfig.Type) {
|
switch strings.ToUpper(c.DBConfig.Type) {
|
||||||
|
|
137
internal/db/pg-fed.go
Normal file
137
internal/db/pg-fed.go
Normal file
|
@ -0,0 +1,137 @@
|
||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/go-fed/activity/pub"
|
||||||
|
"github.com/go-fed/activity/streams"
|
||||||
|
"github.com/go-fed/activity/streams/vocab"
|
||||||
|
"github.com/go-pg/pg/v10"
|
||||||
|
)
|
||||||
|
|
||||||
|
type postgresFederation struct {
|
||||||
|
locks *sync.Map
|
||||||
|
conn *pg.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPostgresFederation(conn *pg.DB) pub.Database {
|
||||||
|
return &postgresFederation{
|
||||||
|
locks: new(sync.Map),
|
||||||
|
conn: conn,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
GO-FED DB INTERFACE-IMPLEMENTING FUNCTIONS
|
||||||
|
*/
|
||||||
|
func (pf *postgresFederation) Lock(ctx context.Context, id *url.URL) error {
|
||||||
|
// Before any other Database methods are called, the relevant `id`
|
||||||
|
// entries are locked to allow for fine-grained concurrency.
|
||||||
|
|
||||||
|
// Strategy: create a new lock, if stored, continue. Otherwise, lock the
|
||||||
|
// existing mutex.
|
||||||
|
mu := &sync.Mutex{}
|
||||||
|
mu.Lock() // Optimistically lock if we do store it.
|
||||||
|
i, loaded := pf.locks.LoadOrStore(id.String(), mu)
|
||||||
|
if loaded {
|
||||||
|
mu = i.(*sync.Mutex)
|
||||||
|
mu.Lock()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pf *postgresFederation) Unlock(ctx context.Context, id *url.URL) error {
|
||||||
|
// Once Go-Fed is done calling Database methods, the relevant `id`
|
||||||
|
// entries are unlocked.
|
||||||
|
|
||||||
|
i, ok := pf.locks.Load(id.String())
|
||||||
|
if !ok {
|
||||||
|
return errors.New("missing an id in unlock")
|
||||||
|
}
|
||||||
|
mu := i.(*sync.Mutex)
|
||||||
|
mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pf *postgresFederation) InboxContains(ctx context.Context, inbox *url.URL, id *url.URL) (bool, error) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pf *postgresFederation) GetInbox(ctx context.Context, inboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pf *postgresFederation) SetInbox(ctx context.Context, inbox vocab.ActivityStreamsOrderedCollectionPage) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pf *postgresFederation) Owns(ctx context.Context, id *url.URL) (owns bool, err error) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pf *postgresFederation) ActorForOutbox(ctx context.Context, outboxIRI *url.URL) (actorIRI *url.URL, err error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pf *postgresFederation) ActorForInbox(ctx context.Context, inboxIRI *url.URL) (actorIRI *url.URL, err error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pf *postgresFederation) OutboxForInbox(ctx context.Context, inboxIRI *url.URL) (outboxIRI *url.URL, err error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pf *postgresFederation) Exists(ctx context.Context, id *url.URL) (exists bool, err error) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pf *postgresFederation) Get(ctx context.Context, id *url.URL) (value vocab.Type, err error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pf *postgresFederation) Create(ctx context.Context, asType vocab.Type) error {
|
||||||
|
t, err := streams.NewTypeResolver()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := t.Resolve(ctx, asType); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
asType.GetTypeName()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pf *postgresFederation) Update(ctx context.Context, asType vocab.Type) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pf *postgresFederation) Delete(ctx context.Context, id *url.URL) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pf *postgresFederation) GetOutbox(ctx context.Context, outboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pf *postgresFederation) SetOutbox(ctx context.Context, outbox vocab.ActivityStreamsOrderedCollectionPage) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pf *postgresFederation) NewID(ctx context.Context, t vocab.Type) (id *url.URL, err error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pf *postgresFederation) Followers(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pf *postgresFederation) Following(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pf *postgresFederation) Liked(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
|
@ -22,30 +22,26 @@
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-fed/activity/streams"
|
"github.com/go-fed/activity/pub"
|
||||||
"github.com/go-fed/activity/streams/vocab"
|
|
||||||
"github.com/go-pg/pg/extra/pgdebug"
|
"github.com/go-pg/pg/extra/pgdebug"
|
||||||
"github.com/go-pg/pg/v10"
|
"github.com/go-pg/pg/v10"
|
||||||
"github.com/go-pg/pg/v10/orm"
|
"github.com/go-pg/pg/v10/orm"
|
||||||
"github.com/gotosocial/gotosocial/internal/config"
|
"github.com/gotosocial/gotosocial/internal/config"
|
||||||
"github.com/gotosocial/gotosocial/internal/gtsmodel"
|
"github.com/gotosocial/gotosocial/internal/gtsmodel"
|
||||||
"github.com/gotosocial/oauth2/v4"
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// postgresService satisfies the DB interface
|
||||||
type postgresService struct {
|
type postgresService struct {
|
||||||
config *config.DBConfig
|
config *config.DBConfig
|
||||||
conn *pg.DB
|
conn *pg.DB
|
||||||
log *logrus.Entry
|
log *logrus.Entry
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
locks *sync.Map
|
federationDB pub.Database
|
||||||
tokenStore oauth2.TokenStore
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// newPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface.
|
// newPostgresService returns a postgresService derived from the provided config, which implements the go-fed DB interface.
|
||||||
|
@ -102,36 +98,20 @@ func newPostgresService(ctx context.Context, c *config.Config, log *logrus.Entry
|
||||||
return nil, errors.New("db connection timeout")
|
return nil, errors.New("db connection timeout")
|
||||||
}
|
}
|
||||||
|
|
||||||
// acc := model.StubAccount()
|
|
||||||
// if _, err := conn.Model(acc).Returning("id").Insert(); err != nil {
|
|
||||||
// cancel()
|
|
||||||
// return nil, fmt.Errorf("db insert error: %s", err)
|
|
||||||
// }
|
|
||||||
// log.Infof("created account with id %s", acc.ID)
|
|
||||||
|
|
||||||
// note := &model.Note{
|
|
||||||
// Visibility: &model.Visibility{
|
|
||||||
// Local: true,
|
|
||||||
// },
|
|
||||||
// CreatedAt: time.Now(),
|
|
||||||
// UpdatedAt: time.Now(),
|
|
||||||
// }
|
|
||||||
// if _, err := conn.WithContext(ctx).Model(note).Returning("id").Insert(); err != nil {
|
|
||||||
// cancel()
|
|
||||||
// return nil, fmt.Errorf("db insert error: %s", err)
|
|
||||||
// }
|
|
||||||
// log.Infof("created note with id %s", note.ID)
|
|
||||||
|
|
||||||
// we can confidently return this useable postgres service now
|
// we can confidently return this useable postgres service now
|
||||||
return &postgresService{
|
return &postgresService{
|
||||||
config: c.DBConfig,
|
config: c.DBConfig,
|
||||||
conn: conn,
|
conn: conn,
|
||||||
log: log,
|
log: log,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
locks: &sync.Map{},
|
federationDB: newPostgresFederation(conn),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ps *postgresService) Federation() pub.Database {
|
||||||
|
return ps.federationDB
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
HANDY STUFF
|
HANDY STUFF
|
||||||
*/
|
*/
|
||||||
|
@ -187,118 +167,6 @@ func derivePGOptions(c *config.Config) (*pg.Options, error) {
|
||||||
return options, nil
|
return options, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
GO-FED DB INTERFACE-IMPLEMENTING FUNCTIONS
|
|
||||||
*/
|
|
||||||
func (ps *postgresService) Lock(ctx context.Context, id *url.URL) error {
|
|
||||||
// Before any other Database methods are called, the relevant `id`
|
|
||||||
// entries are locked to allow for fine-grained concurrency.
|
|
||||||
|
|
||||||
// Strategy: create a new lock, if stored, continue. Otherwise, lock the
|
|
||||||
// existing mutex.
|
|
||||||
mu := &sync.Mutex{}
|
|
||||||
mu.Lock() // Optimistically lock if we do store it.
|
|
||||||
i, loaded := ps.locks.LoadOrStore(id.String(), mu)
|
|
||||||
if loaded {
|
|
||||||
mu = i.(*sync.Mutex)
|
|
||||||
mu.Lock()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ps *postgresService) Unlock(ctx context.Context, id *url.URL) error {
|
|
||||||
// Once Go-Fed is done calling Database methods, the relevant `id`
|
|
||||||
// entries are unlocked.
|
|
||||||
|
|
||||||
i, ok := ps.locks.Load(id.String())
|
|
||||||
if !ok {
|
|
||||||
return errors.New("missing an id in unlock")
|
|
||||||
}
|
|
||||||
mu := i.(*sync.Mutex)
|
|
||||||
mu.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ps *postgresService) InboxContains(ctx context.Context, inbox *url.URL, id *url.URL) (bool, error) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ps *postgresService) GetInbox(ctx context.Context, inboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ps *postgresService) SetInbox(ctx context.Context, inbox vocab.ActivityStreamsOrderedCollectionPage) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ps *postgresService) Owns(ctx context.Context, id *url.URL) (owns bool, err error) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ps *postgresService) ActorForOutbox(ctx context.Context, outboxIRI *url.URL) (actorIRI *url.URL, err error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ps *postgresService) ActorForInbox(ctx context.Context, inboxIRI *url.URL) (actorIRI *url.URL, err error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ps *postgresService) OutboxForInbox(ctx context.Context, inboxIRI *url.URL) (outboxIRI *url.URL, err error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ps *postgresService) Exists(ctx context.Context, id *url.URL) (exists bool, err error) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ps *postgresService) Get(ctx context.Context, id *url.URL) (value vocab.Type, err error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ps *postgresService) Create(ctx context.Context, asType vocab.Type) error {
|
|
||||||
t, err := streams.NewTypeResolver()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := t.Resolve(ctx, asType); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
asType.GetTypeName()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ps *postgresService) Update(ctx context.Context, asType vocab.Type) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ps *postgresService) Delete(ctx context.Context, id *url.URL) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ps *postgresService) GetOutbox(ctx context.Context, outboxIRI *url.URL) (inbox vocab.ActivityStreamsOrderedCollectionPage, err error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ps *postgresService) SetOutbox(ctx context.Context, outbox vocab.ActivityStreamsOrderedCollectionPage) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ps *postgresService) NewID(ctx context.Context, t vocab.Type) (id *url.URL, err error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ps *postgresService) Followers(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ps *postgresService) Following(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ps *postgresService) Liked(ctx context.Context, actorIRI *url.URL) (followers vocab.ActivityStreamsCollection, err error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
EXTRA FUNCTIONS
|
EXTRA FUNCTIONS
|
||||||
*/
|
*/
|
||||||
|
@ -338,6 +206,46 @@ func (ps *postgresService) IsHealthy(ctx context.Context) error {
|
||||||
return ps.conn.Ping(ctx)
|
return ps.conn.Ping(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ps *postgresService) TokenStore() oauth2.TokenStore {
|
func (ps *postgresService) CreateTable(i interface{}) error {
|
||||||
return ps.tokenStore
|
return ps.conn.Model(i).CreateTable(&orm.CreateTableOptions{
|
||||||
|
IfNotExists: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *postgresService) DropTable(i interface{}) error {
|
||||||
|
return ps.conn.Model(i).DropTable(&orm.DropTableOptions{
|
||||||
|
IfExists: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *postgresService) GetByID(id string, i interface{}) error {
|
||||||
|
return ps.conn.Model(i).Where("id = ?", id).Select()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *postgresService) GetWhere(key string, value interface{}, i interface{}) error {
|
||||||
|
return ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Select()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *postgresService) GetAll(i interface{}) error {
|
||||||
|
return ps.conn.Model(i).Select()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *postgresService) Put(i interface{}) error {
|
||||||
|
_, err := ps.conn.Model(i).Insert(i)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *postgresService) UpdateByID(id string, i interface{}) error {
|
||||||
|
_, err := ps.conn.Model(i).OnConflict("(id) DO UPDATE").Insert()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *postgresService) DeleteByID(id string, i interface{}) error {
|
||||||
|
_, err := ps.conn.Model(i).Where("id = ?", id).Delete()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ps *postgresService) DeleteWhere(key string, value interface{}, i interface{}) error {
|
||||||
|
_, err := ps.conn.Model(i).Where(fmt.Sprintf("%s = ?", key), value).Delete()
|
||||||
|
return err
|
||||||
}
|
}
|
|
@ -16,5 +16,5 @@
|
||||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// package email provides a service for interacting with an SMTP server
|
// Package email provides a service for interacting with an SMTP server
|
||||||
package email
|
package email
|
||||||
|
|
|
@ -30,11 +30,13 @@
|
||||||
"github.com/gotosocial/gotosocial/internal/db"
|
"github.com/gotosocial/gotosocial/internal/db"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// New returns a go-fed compatible federating actor
|
||||||
func New(db db.DB) pub.FederatingActor {
|
func New(db db.DB) pub.FederatingActor {
|
||||||
fa := &API{}
|
fa := &API{}
|
||||||
return pub.NewFederatingActor(fa, fa, db, fa)
|
return pub.NewFederatingActor(fa, fa, db.Federation(), fa)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// API implements several go-fed interfaces in one convenient location
|
||||||
type API struct {
|
type API struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -38,9 +38,9 @@
|
||||||
return fmt.Errorf("error creating dbservice: %s", err)
|
return fmt.Errorf("error creating dbservice: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := dbService.CreateSchema(ctx); err != nil {
|
// if err := dbService.CreateSchema(ctx); err != nil {
|
||||||
return fmt.Errorf("error creating dbschema: %s", err)
|
// return fmt.Errorf("error creating dbschema: %s", err)
|
||||||
}
|
// }
|
||||||
|
|
||||||
// catch shutdown signals from the operating system
|
// catch shutdown signals from the operating system
|
||||||
sigs := make(chan os.Signal, 1)
|
sigs := make(chan os.Signal, 1)
|
||||||
|
|
|
@ -22,10 +22,10 @@
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"github.com/go-fed/activity/pub"
|
"github.com/go-fed/activity/pub"
|
||||||
"github.com/gotosocial/gotosocial/internal/api"
|
|
||||||
"github.com/gotosocial/gotosocial/internal/cache"
|
"github.com/gotosocial/gotosocial/internal/cache"
|
||||||
"github.com/gotosocial/gotosocial/internal/config"
|
"github.com/gotosocial/gotosocial/internal/config"
|
||||||
"github.com/gotosocial/gotosocial/internal/db"
|
"github.com/gotosocial/gotosocial/internal/db"
|
||||||
|
"github.com/gotosocial/gotosocial/internal/router"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Gotosocial interface {
|
type Gotosocial interface {
|
||||||
|
@ -33,11 +33,11 @@ type Gotosocial interface {
|
||||||
Stop(context.Context) error
|
Stop(context.Context) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(db db.DB, cache cache.Cache, clientAPI api.Server, federationAPI pub.FederatingActor, config *config.Config) (Gotosocial, error) {
|
func New(db db.DB, cache cache.Cache, apiRouter router.Router, federationAPI pub.FederatingActor, config *config.Config) (Gotosocial, error) {
|
||||||
return &gotosocial{
|
return &gotosocial{
|
||||||
db: db,
|
db: db,
|
||||||
cache: cache,
|
cache: cache,
|
||||||
clientAPI: clientAPI,
|
apiRouter: apiRouter,
|
||||||
federationAPI: federationAPI,
|
federationAPI: federationAPI,
|
||||||
config: config,
|
config: config,
|
||||||
}, nil
|
}, nil
|
||||||
|
@ -46,7 +46,7 @@ func New(db db.DB, cache cache.Cache, clientAPI api.Server, federationAPI pub.Fe
|
||||||
type gotosocial struct {
|
type gotosocial struct {
|
||||||
db db.DB
|
db db.DB
|
||||||
cache cache.Cache
|
cache cache.Cache
|
||||||
clientAPI api.Server
|
apiRouter router.Router
|
||||||
federationAPI pub.FederatingActor
|
federationAPI pub.FederatingActor
|
||||||
config *config.Config
|
config *config.Config
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// package gtsmodel contains types used *internally* by GoToSocial and added/removed/selected from the database.
|
// Package gtsmodel contains types used *internally* by GoToSocial and added/removed/selected from the database.
|
||||||
// These types should never be serialized and/or sent out via public APIs, as they contain sensitive information.
|
// These types should never be serialized and/or sent out via public APIs, as they contain sensitive information.
|
||||||
// The annotation used on these structs is for handling them via the go-pg ORM. See here: https://pg.uptrace.dev/models/
|
// The annotation used on these structs is for handling them via the go-pg ORM. See here: https://pg.uptrace.dev/models/
|
||||||
package gtsmodel
|
package gtsmodel
|
||||||
|
|
|
@ -18,13 +18,38 @@
|
||||||
|
|
||||||
package gtsmodel
|
package gtsmodel
|
||||||
|
|
||||||
|
import "github.com/gotosocial/gotosocial/pkg/mastotypes"
|
||||||
|
|
||||||
|
// Application represents an application that can perform actions on behalf of a user.
|
||||||
|
// It is used to authorize tokens etc, and is associated with an oauth client id in the database.
|
||||||
type Application struct {
|
type Application struct {
|
||||||
ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"`
|
// id of this application in the db
|
||||||
Name string
|
ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"`
|
||||||
Website string
|
// name of the application given when it was created (eg., 'tusky')
|
||||||
RedirectURI string `json:"redirect_uri"`
|
Name string
|
||||||
ClientID string `json:"client_id"`
|
// website for the application given when it was created (eg., 'https://tusky.app')
|
||||||
ClientSecret string `json:"client_secret"`
|
Website string
|
||||||
Scopes string `json:"scopes"`
|
// redirect uri requested by the application for oauth2 flow
|
||||||
VapidKey string `json:"vapid_key"`
|
RedirectURI string
|
||||||
|
// id of the associated oauth client entity in the db
|
||||||
|
ClientID string
|
||||||
|
// secret of the associated oauth client entity in the db
|
||||||
|
ClientSecret string
|
||||||
|
// scopes requested when this app was created
|
||||||
|
Scopes string
|
||||||
|
// a vapid key generated for this app when it was created
|
||||||
|
VapidKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToMastotype returns this application as a mastodon api type, ready for serialization
|
||||||
|
func (a *Application) ToMastotype() *mastotypes.Application {
|
||||||
|
return &mastotypes.Application{
|
||||||
|
ID: a.ID,
|
||||||
|
Name: a.Name,
|
||||||
|
Website: a.Website,
|
||||||
|
RedirectURI: a.RedirectURI,
|
||||||
|
ClientID: a.ClientID,
|
||||||
|
ClientSecret: a.ClientSecret,
|
||||||
|
VapidKey: a.VapidKey,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,25 +20,44 @@
|
||||||
|
|
||||||
import "time"
|
import "time"
|
||||||
|
|
||||||
|
// Status represents a user-created 'post' or 'status' in the database, either remote or local
|
||||||
type Status struct {
|
type Status struct {
|
||||||
ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"`
|
// id of the status in the database
|
||||||
URI string `pg:",unique"`
|
ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"`
|
||||||
URL string `pg:",unique"`
|
// uri at which this status is reachable
|
||||||
Content string
|
URI string `pg:",unique"`
|
||||||
CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
|
// web url for viewing this status
|
||||||
UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
|
URL string `pg:",unique"`
|
||||||
Local bool
|
// the html-formatted content of this status
|
||||||
AccountID string
|
Content string
|
||||||
InReplyToID string
|
// when was this status created?
|
||||||
BoostOfID string
|
CreatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
|
||||||
|
// when was this status updated?
|
||||||
|
UpdatedAt time.Time `pg:"type:timestamp,notnull,default:now()"`
|
||||||
|
// is this status from a local account?
|
||||||
|
Local bool
|
||||||
|
// which account posted this status?
|
||||||
|
AccountID string
|
||||||
|
// id of the status this status is a reply to
|
||||||
|
InReplyToID string
|
||||||
|
// id of the status this status is a boost of
|
||||||
|
BoostOfID string
|
||||||
|
// cw string for this status
|
||||||
ContentWarning string
|
ContentWarning string
|
||||||
Visibility *Visibility
|
// visibility entry for this status
|
||||||
|
Visibility *Visibility
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Visibility represents the visibility granularity of a status. It is a combination of flags.
|
||||||
type Visibility struct {
|
type Visibility struct {
|
||||||
Direct bool
|
// Is this status viewable as a direct message?
|
||||||
|
Direct bool
|
||||||
|
// Is this status viewable to followers?
|
||||||
Followers bool
|
Followers bool
|
||||||
Local bool
|
// Is this status viewable on the local timeline?
|
||||||
Unlisted bool
|
Local bool
|
||||||
Public bool
|
// Is this status boostable but not shown on public timelines?
|
||||||
|
Unlisted bool
|
||||||
|
// Is this status shown on public and federated timelines?
|
||||||
|
Public bool
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,4 +16,22 @@
|
||||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package api
|
package account
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gotosocial/gotosocial/internal/module"
|
||||||
|
"github.com/gotosocial/gotosocial/internal/router"
|
||||||
|
)
|
||||||
|
|
||||||
|
type accountModule struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
// New returns a new account module
|
||||||
|
func New() module.ClientAPIModule {
|
||||||
|
return &accountModule{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Route attaches all routes from this module to the given router
|
||||||
|
func (m *accountModule) Route(r router.Router) error {
|
||||||
|
return nil
|
||||||
|
}
|
29
internal/module/module.go
Normal file
29
internal/module/module.go
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
/*
|
||||||
|
GoToSocial
|
||||||
|
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU Affero General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Affero General Public License
|
||||||
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Package module is basically a wrapper for a lot of modules (in subdirectories) that satisfy the ClientAPIModule interface.
|
||||||
|
package module
|
||||||
|
|
||||||
|
import "github.com/gotosocial/gotosocial/internal/router"
|
||||||
|
|
||||||
|
// ClientAPIModule represents a chunk of code (usually contained in a single package) that adds a set
|
||||||
|
// of functionalities and side effects to a router, by mapping routes and handlers onto it--in other words, a REST API ;)
|
||||||
|
// A ClientAPIMpdule corresponds roughly to one main path of the gotosocial REST api, for example /api/v1/accounts/ or /oauth/
|
||||||
|
type ClientAPIModule interface {
|
||||||
|
Route(s router.Router) error
|
||||||
|
}
|
|
@ -1,3 +1,5 @@
|
||||||
# oauth
|
# oauth
|
||||||
|
|
||||||
This package provides uses the [GoToSocial oauth2](https://github.com/gotosocial/oauth2) module (forked from [go-oauth2](https://github.com/go-oauth2/oauth2)) to provide [oauth2](https://www.oauth.com/) server functionality to the GoToSocial APIs.
|
This package provides uses the [GoToSocial oauth2](https://github.com/gotosocial/oauth2) module (forked from [go-oauth2](https://github.com/go-oauth2/oauth2)) to provide [oauth2](https://www.oauth.com/) functionality to the GoToSocial client API.
|
||||||
|
|
||||||
|
It also provides a handler/middleware for attaching to the Gin engine for validating authenticated users.
|
|
@ -22,55 +22,47 @@
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/go-pg/pg/v10"
|
"github.com/gotosocial/gotosocial/internal/db"
|
||||||
"github.com/gotosocial/oauth2/v4"
|
"github.com/gotosocial/oauth2/v4"
|
||||||
"github.com/gotosocial/oauth2/v4/models"
|
"github.com/gotosocial/oauth2/v4/models"
|
||||||
)
|
)
|
||||||
|
|
||||||
type pgClientStore struct {
|
type clientStore struct {
|
||||||
conn *pg.DB
|
db db.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPGClientStore(conn *pg.DB) oauth2.ClientStore {
|
func newClientStore(db db.DB) oauth2.ClientStore {
|
||||||
pts := &pgClientStore{
|
pts := &clientStore{
|
||||||
conn: conn,
|
db: db,
|
||||||
}
|
}
|
||||||
return pts
|
return pts
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pcs *pgClientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) {
|
func (cs *clientStore) GetByID(ctx context.Context, clientID string) (oauth2.ClientInfo, error) {
|
||||||
poc := &oauthClient{
|
poc := &oauthClient{
|
||||||
ID: clientID,
|
ID: clientID,
|
||||||
}
|
}
|
||||||
if err := pcs.conn.WithContext(ctx).Model(poc).Where("id = ?", poc.ID).Select(); err != nil {
|
if err := cs.db.GetByID(clientID, poc); err != nil {
|
||||||
return nil, fmt.Errorf("error in clientstore getbyid searching for client %s: %s", clientID, err)
|
return nil, fmt.Errorf("database error: %s", err)
|
||||||
}
|
}
|
||||||
return models.New(poc.ID, poc.Secret, poc.Domain, poc.UserID), nil
|
return models.New(poc.ID, poc.Secret, poc.Domain, poc.UserID), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pcs *pgClientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error {
|
func (cs *clientStore) Set(ctx context.Context, id string, cli oauth2.ClientInfo) error {
|
||||||
poc := &oauthClient{
|
poc := &oauthClient{
|
||||||
ID: cli.GetID(),
|
ID: cli.GetID(),
|
||||||
Secret: cli.GetSecret(),
|
Secret: cli.GetSecret(),
|
||||||
Domain: cli.GetDomain(),
|
Domain: cli.GetDomain(),
|
||||||
UserID: cli.GetUserID(),
|
UserID: cli.GetUserID(),
|
||||||
}
|
}
|
||||||
_, err := pcs.conn.WithContext(ctx).Model(poc).OnConflict("(id) DO UPDATE").Insert()
|
return cs.db.UpdateByID(id, poc)
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error in clientstore set: %s", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pcs *pgClientStore) Delete(ctx context.Context, id string) error {
|
func (cs *clientStore) Delete(ctx context.Context, id string) error {
|
||||||
poc := &oauthClient{
|
poc := &oauthClient{
|
||||||
ID: id,
|
ID: id,
|
||||||
}
|
}
|
||||||
_, err := pcs.conn.WithContext(ctx).Model(poc).Where("id = ?", poc.ID).Delete()
|
return cs.db.DeleteByID(id, poc)
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error in clientstore delete: %s", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type oauthClient struct {
|
type oauthClient struct {
|
|
@ -1,11 +1,28 @@
|
||||||
|
/*
|
||||||
|
GoToSocial
|
||||||
|
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU Affero General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Affero General Public License
|
||||||
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
package oauth
|
package oauth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/go-pg/pg/v10"
|
"github.com/gotosocial/gotosocial/internal/config"
|
||||||
"github.com/go-pg/pg/v10/orm"
|
"github.com/gotosocial/gotosocial/internal/db"
|
||||||
"github.com/gotosocial/oauth2/v4/models"
|
"github.com/gotosocial/oauth2/v4/models"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
|
@ -13,7 +30,7 @@
|
||||||
|
|
||||||
type PgClientStoreTestSuite struct {
|
type PgClientStoreTestSuite struct {
|
||||||
suite.Suite
|
suite.Suite
|
||||||
conn *pg.DB
|
db db.DB
|
||||||
testClientID string
|
testClientID string
|
||||||
testClientSecret string
|
testClientSecret string
|
||||||
testClientDomain string
|
testClientDomain string
|
||||||
|
@ -32,31 +49,55 @@ func (suite *PgClientStoreTestSuite) SetupSuite() {
|
||||||
|
|
||||||
// SetupTest creates a postgres connection and creates the oauth_clients table before each test
|
// SetupTest creates a postgres connection and creates the oauth_clients table before each test
|
||||||
func (suite *PgClientStoreTestSuite) SetupTest() {
|
func (suite *PgClientStoreTestSuite) SetupTest() {
|
||||||
suite.conn = pg.Connect(&pg.Options{})
|
log := logrus.New()
|
||||||
if err := suite.conn.Ping(context.Background()); err != nil {
|
log.SetLevel(logrus.TraceLevel)
|
||||||
logrus.Panicf("db connection error: %s", err)
|
c := config.Empty()
|
||||||
|
c.DBConfig = &config.DBConfig{
|
||||||
|
Type: "postgres",
|
||||||
|
Address: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
User: "postgres",
|
||||||
|
Password: "postgres",
|
||||||
|
Database: "postgres",
|
||||||
|
ApplicationName: "gotosocial",
|
||||||
}
|
}
|
||||||
if err := suite.conn.Model(&oauthClient{}).CreateTable(&orm.CreateTableOptions{
|
db, err := db.New(context.Background(), c, log)
|
||||||
IfNotExists: true,
|
if err != nil {
|
||||||
}); err != nil {
|
logrus.Panicf("error creating database connection: %s", err)
|
||||||
logrus.Panicf("db connection error: %s", err)
|
}
|
||||||
|
|
||||||
|
suite.db = db
|
||||||
|
|
||||||
|
models := []interface{}{
|
||||||
|
&oauthClient{},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, m := range models {
|
||||||
|
if err := suite.db.CreateTable(m); err != nil {
|
||||||
|
logrus.Panicf("db connection error: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TearDownTest drops the oauth_clients table and closes the pg connection after each test
|
// TearDownTest drops the oauth_clients table and closes the pg connection after each test
|
||||||
func (suite *PgClientStoreTestSuite) TearDownTest() {
|
func (suite *PgClientStoreTestSuite) TearDownTest() {
|
||||||
if err := suite.conn.Model(&oauthClient{}).DropTable(&orm.DropTableOptions{}); err != nil {
|
models := []interface{}{
|
||||||
logrus.Panicf("drop table error: %s", err)
|
&oauthClient{},
|
||||||
}
|
}
|
||||||
if err := suite.conn.Close(); err != nil {
|
for _, m := range models {
|
||||||
|
if err := suite.db.DropTable(m); err != nil {
|
||||||
|
logrus.Panicf("error dropping table: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := suite.db.Stop(context.Background()); err != nil {
|
||||||
logrus.Panicf("error closing db connection: %s", err)
|
logrus.Panicf("error closing db connection: %s", err)
|
||||||
}
|
}
|
||||||
suite.conn = nil
|
suite.db = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (suite *PgClientStoreTestSuite) TestClientStoreSetAndGet() {
|
func (suite *PgClientStoreTestSuite) TestClientStoreSetAndGet() {
|
||||||
// set a new client in the store
|
// set a new client in the store
|
||||||
cs := NewPGClientStore(suite.conn)
|
cs := newClientStore(suite.db)
|
||||||
if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil {
|
if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil {
|
||||||
suite.FailNow(err.Error())
|
suite.FailNow(err.Error())
|
||||||
}
|
}
|
||||||
|
@ -74,7 +115,7 @@ func (suite *PgClientStoreTestSuite) TestClientStoreSetAndGet() {
|
||||||
|
|
||||||
func (suite *PgClientStoreTestSuite) TestClientSetAndDelete() {
|
func (suite *PgClientStoreTestSuite) TestClientSetAndDelete() {
|
||||||
// set a new client in the store
|
// set a new client in the store
|
||||||
cs := NewPGClientStore(suite.conn)
|
cs := newClientStore(suite.db)
|
||||||
if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil {
|
if err := cs.Set(context.Background(), suite.testClientID, models.New(suite.testClientID, suite.testClientSecret, suite.testClientDomain, suite.testClientUserID)); err != nil {
|
||||||
suite.FailNow(err.Error())
|
suite.FailNow(err.Error())
|
||||||
}
|
}
|
510
internal/module/oauth/oauth.go
Normal file
510
internal/module/oauth/oauth.go
Normal file
|
@ -0,0 +1,510 @@
|
||||||
|
/*
|
||||||
|
GoToSocial
|
||||||
|
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU Affero General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Affero General Public License
|
||||||
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Package oauth is a module that provides oauth functionality to a router.
|
||||||
|
// It adds the following paths:
|
||||||
|
// /api/v1/apps
|
||||||
|
// /auth/sign_in
|
||||||
|
// /oauth/token
|
||||||
|
// /oauth/authorize
|
||||||
|
// It also includes the oauthTokenMiddleware, which can be attached to a router to authenticate every request by Bearer token.
|
||||||
|
package oauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"github.com/gin-contrib/sessions"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/gotosocial/gotosocial/internal/db"
|
||||||
|
"github.com/gotosocial/gotosocial/internal/gtsmodel"
|
||||||
|
"github.com/gotosocial/gotosocial/internal/module"
|
||||||
|
"github.com/gotosocial/gotosocial/internal/router"
|
||||||
|
"github.com/gotosocial/gotosocial/pkg/mastotypes"
|
||||||
|
"github.com/gotosocial/oauth2/v4"
|
||||||
|
"github.com/gotosocial/oauth2/v4/errors"
|
||||||
|
"github.com/gotosocial/oauth2/v4/manage"
|
||||||
|
"github.com/gotosocial/oauth2/v4/server"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
appsPath = "/api/v1/apps"
|
||||||
|
authSignInPath = "/auth/sign_in"
|
||||||
|
oauthTokenPath = "/oauth/token"
|
||||||
|
oauthAuthorizePath = "/oauth/authorize"
|
||||||
|
)
|
||||||
|
|
||||||
|
// oauthModule is an oauth2 oauthModule that satisfies the ClientAPIModule interface
|
||||||
|
type oauthModule struct {
|
||||||
|
oauthManager *manage.Manager
|
||||||
|
oauthServer *server.Server
|
||||||
|
db db.DB
|
||||||
|
log *logrus.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
type login struct {
|
||||||
|
Email string `form:"username"`
|
||||||
|
Password string `form:"password"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// New returns a new oauth module
|
||||||
|
func New(ts oauth2.TokenStore, cs oauth2.ClientStore, db db.DB, log *logrus.Logger) module.ClientAPIModule {
|
||||||
|
manager := manage.NewDefaultManager()
|
||||||
|
manager.MapTokenStorage(ts)
|
||||||
|
manager.MapClientStorage(cs)
|
||||||
|
manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg)
|
||||||
|
sc := &server.Config{
|
||||||
|
TokenType: "Bearer",
|
||||||
|
// Must follow the spec.
|
||||||
|
AllowGetAccessRequest: false,
|
||||||
|
// Support only the non-implicit flow.
|
||||||
|
AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
|
||||||
|
// Allow:
|
||||||
|
// - Authorization Code (for first & third parties)
|
||||||
|
AllowedGrantTypes: []oauth2.GrantType{
|
||||||
|
oauth2.AuthorizationCode,
|
||||||
|
},
|
||||||
|
AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{oauth2.CodeChallengePlain},
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := server.NewServer(sc, manager)
|
||||||
|
srv.SetInternalErrorHandler(func(err error) *errors.Response {
|
||||||
|
log.Errorf("internal oauth error: %s", err)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
srv.SetResponseErrorHandler(func(re *errors.Response) {
|
||||||
|
log.Errorf("internal response error: %s", re.Error)
|
||||||
|
})
|
||||||
|
|
||||||
|
m := &oauthModule{
|
||||||
|
oauthManager: manager,
|
||||||
|
oauthServer: srv,
|
||||||
|
db: db,
|
||||||
|
log: log,
|
||||||
|
}
|
||||||
|
|
||||||
|
m.oauthServer.SetUserAuthorizationHandler(m.userAuthorizationHandler)
|
||||||
|
m.oauthServer.SetClientInfoHandler(server.ClientFormHandler)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// Route satisfies the RESTAPIModule interface
|
||||||
|
func (m *oauthModule) Route(s router.Router) error {
|
||||||
|
s.AttachHandler(http.MethodPost, appsPath, m.appsPOSTHandler)
|
||||||
|
|
||||||
|
s.AttachHandler(http.MethodGet, authSignInPath, m.signInGETHandler)
|
||||||
|
s.AttachHandler(http.MethodPost, authSignInPath, m.signInPOSTHandler)
|
||||||
|
|
||||||
|
s.AttachHandler(http.MethodPost, oauthTokenPath, m.tokenPOSTHandler)
|
||||||
|
|
||||||
|
s.AttachHandler(http.MethodGet, oauthAuthorizePath, m.authorizeGETHandler)
|
||||||
|
s.AttachHandler(http.MethodPost, oauthAuthorizePath, m.authorizePOSTHandler)
|
||||||
|
|
||||||
|
s.AttachMiddleware(m.oauthTokenMiddleware)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
MAIN HANDLERS -- serve these through a server/router
|
||||||
|
*/
|
||||||
|
|
||||||
|
// appsPOSTHandler should be served at https://example.org/api/v1/apps
|
||||||
|
// It is equivalent to: https://docs.joinmastodon.org/methods/apps/
|
||||||
|
func (m *oauthModule) appsPOSTHandler(c *gin.Context) {
|
||||||
|
l := m.log.WithField("func", "AppsPOSTHandler")
|
||||||
|
l.Trace("entering AppsPOSTHandler")
|
||||||
|
|
||||||
|
form := &mastotypes.ApplicationPOSTRequest{}
|
||||||
|
if err := c.ShouldBind(form); err != nil {
|
||||||
|
c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// permitted length for most fields
|
||||||
|
permittedLength := 64
|
||||||
|
// redirect can be a bit bigger because we probably need to encode data in the redirect uri
|
||||||
|
permittedRedirect := 256
|
||||||
|
|
||||||
|
// check lengths of fields before proceeding so the user can't spam huge entries into the database
|
||||||
|
if len(form.ClientName) > permittedLength {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("client_name must be less than %d bytes", permittedLength)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(form.Website) > permittedLength {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("website must be less than %d bytes", permittedLength)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(form.RedirectURIs) > permittedRedirect {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("redirect_uris must be less than %d bytes", permittedRedirect)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(form.Scopes) > permittedLength {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("scopes must be less than %d bytes", permittedLength)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// set default 'read' for scopes if it's not set, this follows the default of the mastodon api https://docs.joinmastodon.org/methods/apps/
|
||||||
|
var scopes string
|
||||||
|
if form.Scopes == "" {
|
||||||
|
scopes = "read"
|
||||||
|
} else {
|
||||||
|
scopes = form.Scopes
|
||||||
|
}
|
||||||
|
|
||||||
|
// generate new IDs for this application and its associated client
|
||||||
|
clientID := uuid.NewString()
|
||||||
|
clientSecret := uuid.NewString()
|
||||||
|
vapidKey := uuid.NewString()
|
||||||
|
|
||||||
|
// generate the application to put in the database
|
||||||
|
app := >smodel.Application{
|
||||||
|
Name: form.ClientName,
|
||||||
|
Website: form.Website,
|
||||||
|
RedirectURI: form.RedirectURIs,
|
||||||
|
ClientID: clientID,
|
||||||
|
ClientSecret: clientSecret,
|
||||||
|
Scopes: scopes,
|
||||||
|
VapidKey: vapidKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
// chuck it in the db
|
||||||
|
if err := m.db.Put(app); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// now we need to model an oauth client from the application that the oauth library can use
|
||||||
|
oc := &oauthClient{
|
||||||
|
ID: clientID,
|
||||||
|
Secret: clientSecret,
|
||||||
|
Domain: form.RedirectURIs,
|
||||||
|
UserID: "", // This client isn't yet associated with a specific user, it's just an app client right now
|
||||||
|
}
|
||||||
|
|
||||||
|
// chuck it in the db
|
||||||
|
if err := m.db.Put(oc); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// done, return the new app information per the spec here: https://docs.joinmastodon.org/methods/apps/
|
||||||
|
c.JSON(http.StatusOK, app.ToMastotype())
|
||||||
|
}
|
||||||
|
|
||||||
|
// signInGETHandler should be served at https://example.org/auth/sign_in.
|
||||||
|
// The idea is to present a sign in page to the user, where they can enter their username and password.
|
||||||
|
// The form will then POST to the sign in page, which will be handled by SignInPOSTHandler
|
||||||
|
func (m *oauthModule) signInGETHandler(c *gin.Context) {
|
||||||
|
m.log.WithField("func", "SignInGETHandler").Trace("serving sign in html")
|
||||||
|
c.HTML(http.StatusOK, "sign-in.tmpl", gin.H{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// signInPOSTHandler should be served at https://example.org/auth/sign_in.
|
||||||
|
// The idea is to present a sign in page to the user, where they can enter their username and password.
|
||||||
|
// The handler will then redirect to the auth handler served at /auth
|
||||||
|
func (m *oauthModule) signInPOSTHandler(c *gin.Context) {
|
||||||
|
l := m.log.WithField("func", "SignInPOSTHandler")
|
||||||
|
s := sessions.Default(c)
|
||||||
|
form := &login{}
|
||||||
|
if err := c.ShouldBind(form); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
l.Tracef("parsed form: %+v", form)
|
||||||
|
|
||||||
|
userid, err := m.validatePassword(form.Email, form.Password)
|
||||||
|
if err != nil {
|
||||||
|
c.String(http.StatusForbidden, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.Set("userid", userid)
|
||||||
|
if err := s.Save(); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
l.Trace("redirecting to auth page")
|
||||||
|
c.Redirect(http.StatusFound, oauthAuthorizePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tokenPOSTHandler should be served as a POST at https://example.org/oauth/token
|
||||||
|
// The idea here is to serve an oauth access token to a user, which can be used for authorizing against non-public APIs.
|
||||||
|
// See https://docs.joinmastodon.org/methods/apps/oauth/#obtain-a-token
|
||||||
|
func (m *oauthModule) tokenPOSTHandler(c *gin.Context) {
|
||||||
|
l := m.log.WithField("func", "TokenPOSTHandler")
|
||||||
|
l.Trace("entered TokenPOSTHandler")
|
||||||
|
if err := m.oauthServer.HandleTokenRequest(c.Writer, c.Request); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// authorizeGETHandler should be served as GET at https://example.org/oauth/authorize
|
||||||
|
// The idea here is to present an oauth authorize page to the user, with a button
|
||||||
|
// that they have to click to accept. See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user
|
||||||
|
func (m *oauthModule) authorizeGETHandler(c *gin.Context) {
|
||||||
|
l := m.log.WithField("func", "AuthorizeGETHandler")
|
||||||
|
s := sessions.Default(c)
|
||||||
|
|
||||||
|
// UserID will be set in the session by AuthorizePOSTHandler if the caller has already gone through the authentication flow
|
||||||
|
// If it's not set, then we don't know yet who the user is, so we need to redirect them to the sign in page.
|
||||||
|
userID, ok := s.Get("userid").(string)
|
||||||
|
if !ok || userID == "" {
|
||||||
|
l.Trace("userid was empty, parsing form then redirecting to sign in page")
|
||||||
|
if err := parseAuthForm(c, l); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
} else {
|
||||||
|
c.Redirect(http.StatusFound, authSignInPath)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// We can use the client_id on the session to retrieve info about the app associated with the client_id
|
||||||
|
clientID, ok := s.Get("client_id").(string)
|
||||||
|
if !ok || clientID == "" {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "no client_id found in session"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
app := >smodel.Application{
|
||||||
|
ClientID: clientID,
|
||||||
|
}
|
||||||
|
if err := m.db.GetWhere("client_id", app.ClientID, app); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("no application found for client id %s", clientID)})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// we can also use the userid of the user to fetch their username from the db to greet them nicely <3
|
||||||
|
user := >smodel.User{
|
||||||
|
ID: userID,
|
||||||
|
}
|
||||||
|
if err := m.db.GetByID(user.ID, user); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
acct := >smodel.Account{
|
||||||
|
ID: user.AccountID,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.db.GetByID(acct.ID, acct); err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally we should also get the redirect and scope of this particular request, as stored in the session.
|
||||||
|
redirect, ok := s.Get("redirect_uri").(string)
|
||||||
|
if !ok || redirect == "" {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "no redirect_uri found in session"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
scope, ok := s.Get("scope").(string)
|
||||||
|
if !ok || scope == "" {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "no scope found in session"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// the authorize template will display a form to the user where they can get some information
|
||||||
|
// about the app that's trying to authorize, and the scope of the request.
|
||||||
|
// They can then approve it if it looks OK to them, which will POST to the AuthorizePOSTHandler
|
||||||
|
l.Trace("serving authorize html")
|
||||||
|
c.HTML(http.StatusOK, "authorize.tmpl", gin.H{
|
||||||
|
"appname": app.Name,
|
||||||
|
"appwebsite": app.Website,
|
||||||
|
"redirect": redirect,
|
||||||
|
"scope": scope,
|
||||||
|
"user": acct.Username,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// authorizePOSTHandler should be served as POST at https://example.org/oauth/authorize
|
||||||
|
// At this point we assume that the user has A) logged in and B) accepted that the app should act for them,
|
||||||
|
// so we should proceed with the authentication flow and generate an oauth token for them if we can.
|
||||||
|
// See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user
|
||||||
|
func (m *oauthModule) authorizePOSTHandler(c *gin.Context) {
|
||||||
|
l := m.log.WithField("func", "AuthorizePOSTHandler")
|
||||||
|
s := sessions.Default(c)
|
||||||
|
|
||||||
|
// At this point we know the user has said 'yes' to allowing the application and oauth client
|
||||||
|
// work for them, so we can set the
|
||||||
|
|
||||||
|
// We need to retrieve the original form submitted to the authorizeGEThandler, and
|
||||||
|
// recreate it on the request so that it can be used further by the oauth2 library.
|
||||||
|
// So first fetch all the values from the session.
|
||||||
|
forceLogin, ok := s.Get("force_login").(string)
|
||||||
|
if !ok {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing force_login"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
responseType, ok := s.Get("response_type").(string)
|
||||||
|
if !ok || responseType == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing response_type"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
clientID, ok := s.Get("client_id").(string)
|
||||||
|
if !ok || clientID == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing client_id"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
redirectURI, ok := s.Get("redirect_uri").(string)
|
||||||
|
if !ok || redirectURI == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing redirect_uri"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
scope, ok := s.Get("scope").(string)
|
||||||
|
if !ok {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing scope"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userID, ok := s.Get("userid").(string)
|
||||||
|
if !ok {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing userid"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// we're done with the session so we can clear it now
|
||||||
|
s.Clear()
|
||||||
|
|
||||||
|
// now set the values on the request
|
||||||
|
values := url.Values{}
|
||||||
|
values.Set("force_login", forceLogin)
|
||||||
|
values.Set("response_type", responseType)
|
||||||
|
values.Set("client_id", clientID)
|
||||||
|
values.Set("redirect_uri", redirectURI)
|
||||||
|
values.Set("scope", scope)
|
||||||
|
values.Set("userid", userID)
|
||||||
|
c.Request.Form = values
|
||||||
|
l.Tracef("values on request set to %+v", c.Request.Form)
|
||||||
|
|
||||||
|
// and proceed with authorization using the oauth2 library
|
||||||
|
if err := m.oauthServer.HandleAuthorizeRequest(c.Writer, c.Request); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
MIDDLEWARE
|
||||||
|
*/
|
||||||
|
|
||||||
|
// oauthTokenMiddleware
|
||||||
|
func (m *oauthModule) oauthTokenMiddleware(c *gin.Context) {
|
||||||
|
l := m.log.WithField("func", "ValidatePassword")
|
||||||
|
l.Trace("entering OauthTokenMiddleware")
|
||||||
|
if ti, err := m.oauthServer.ValidationBearerToken(c.Request); err == nil {
|
||||||
|
l.Tracef("authenticated user %s with bearer token, scope is %s", ti.GetUserID(), ti.GetScope())
|
||||||
|
c.Set("authenticated_user", ti.GetUserID())
|
||||||
|
|
||||||
|
} else {
|
||||||
|
l.Trace("continuing with unauthenticated request")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
SUB-HANDLERS -- don't serve these directly, they should be attached to the oauth2 server or used inside handler funcs
|
||||||
|
*/
|
||||||
|
|
||||||
|
// validatePassword takes an email address and a password.
|
||||||
|
// The goal is to authenticate the password against the one for that email
|
||||||
|
// address stored in the database. If OK, we return the userid (a uuid) for that user,
|
||||||
|
// so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db.
|
||||||
|
func (m *oauthModule) validatePassword(email string, password string) (userid string, err error) {
|
||||||
|
l := m.log.WithField("func", "ValidatePassword")
|
||||||
|
|
||||||
|
// make sure an email/password was provided and bail if not
|
||||||
|
if email == "" || password == "" {
|
||||||
|
l.Debug("email or password was not provided")
|
||||||
|
return incorrectPassword()
|
||||||
|
}
|
||||||
|
|
||||||
|
// first we select the user from the database based on email address, bail if no user found for that email
|
||||||
|
gtsUser := >smodel.User{}
|
||||||
|
|
||||||
|
if err := m.db.GetWhere("email", email, gtsUser); err != nil {
|
||||||
|
l.Debugf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err)
|
||||||
|
return incorrectPassword()
|
||||||
|
}
|
||||||
|
|
||||||
|
// make sure a password is actually set and bail if not
|
||||||
|
if gtsUser.EncryptedPassword == "" {
|
||||||
|
l.Warnf("encrypted password for user %s was empty for some reason", gtsUser.Email)
|
||||||
|
return incorrectPassword()
|
||||||
|
}
|
||||||
|
|
||||||
|
// compare the provided password with the encrypted one from the db, bail if they don't match
|
||||||
|
if err := bcrypt.CompareHashAndPassword([]byte(gtsUser.EncryptedPassword), []byte(password)); err != nil {
|
||||||
|
l.Debugf("password hash didn't match for user %s during login attempt: %s", gtsUser.Email, err)
|
||||||
|
return incorrectPassword()
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we've made it this far the email/password is correct, so we can just return the id of the user.
|
||||||
|
userid = gtsUser.ID
|
||||||
|
l.Tracef("returning (%s, %s)", userid, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// incorrectPassword is just a little helper function to use in the ValidatePassword function
|
||||||
|
func incorrectPassword() (string, error) {
|
||||||
|
return "", errors.New("password/email combination was incorrect")
|
||||||
|
}
|
||||||
|
|
||||||
|
// userAuthorizationHandler gets the user's ID from the 'userid' field of the request form,
|
||||||
|
// or redirects to the /auth/sign_in page, if this key is not present.
|
||||||
|
func (m *oauthModule) userAuthorizationHandler(w http.ResponseWriter, r *http.Request) (userID string, err error) {
|
||||||
|
l := m.log.WithField("func", "UserAuthorizationHandler")
|
||||||
|
userID = r.FormValue("userid")
|
||||||
|
if userID == "" {
|
||||||
|
return "", errors.New("userid was empty, redirecting to sign in page")
|
||||||
|
}
|
||||||
|
l.Tracef("returning userID %s", userID)
|
||||||
|
return userID, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseAuthForm parses the OAuthAuthorize form in the gin context, and stores
|
||||||
|
// the values in the form into the session.
|
||||||
|
func parseAuthForm(c *gin.Context, l *logrus.Entry) error {
|
||||||
|
s := sessions.Default(c)
|
||||||
|
|
||||||
|
// first make sure they've filled out the authorize form with the required values
|
||||||
|
form := &mastotypes.OAuthAuthorize{}
|
||||||
|
if err := c.ShouldBind(form); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
l.Tracef("parsed form: %+v", form)
|
||||||
|
|
||||||
|
// these fields are *required* so check 'em
|
||||||
|
if form.ResponseType == "" || form.ClientID == "" || form.RedirectURI == "" {
|
||||||
|
return errors.New("missing one of: response_type, client_id or redirect_uri")
|
||||||
|
}
|
||||||
|
|
||||||
|
// set default scope to read
|
||||||
|
if form.Scope == "" {
|
||||||
|
form.Scope = "read"
|
||||||
|
}
|
||||||
|
|
||||||
|
// save these values from the form so we can use them elsewhere in the session
|
||||||
|
s.Set("force_login", form.ForceLogin)
|
||||||
|
s.Set("response_type", form.ResponseType)
|
||||||
|
s.Set("client_id", form.ClientID)
|
||||||
|
s.Set("redirect_uri", form.RedirectURI)
|
||||||
|
s.Set("scope", form.Scope)
|
||||||
|
return s.Save()
|
||||||
|
}
|
191
internal/module/oauth/oauth_test.go
Normal file
191
internal/module/oauth/oauth_test.go
Normal file
|
@ -0,0 +1,191 @@
|
||||||
|
/*
|
||||||
|
GoToSocial
|
||||||
|
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU Affero General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Affero General Public License
|
||||||
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package oauth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/gotosocial/gotosocial/internal/config"
|
||||||
|
"github.com/gotosocial/gotosocial/internal/db"
|
||||||
|
"github.com/gotosocial/gotosocial/internal/gtsmodel"
|
||||||
|
"github.com/gotosocial/gotosocial/internal/router"
|
||||||
|
"github.com/gotosocial/oauth2/v4"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/suite"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OauthTestSuite struct {
|
||||||
|
suite.Suite
|
||||||
|
tokenStore oauth2.TokenStore
|
||||||
|
clientStore oauth2.ClientStore
|
||||||
|
db db.DB
|
||||||
|
testAccount *gtsmodel.Account
|
||||||
|
testApplication *gtsmodel.Application
|
||||||
|
testUser *gtsmodel.User
|
||||||
|
testClient *oauthClient
|
||||||
|
config *config.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
|
||||||
|
func (suite *OauthTestSuite) SetupSuite() {
|
||||||
|
c := config.Empty()
|
||||||
|
// we're running on localhost without https so set the protocol to http
|
||||||
|
c.Protocol = "http"
|
||||||
|
// just for testing
|
||||||
|
c.Host = "localhost:8080"
|
||||||
|
// because go tests are run within the test package directory, we need to fiddle with the templateconfig
|
||||||
|
// basedir in a way that we wouldn't normally have to do when running the binary, in order to make
|
||||||
|
// the templates actually load
|
||||||
|
c.TemplateConfig.BaseDir = "../../../web/template/"
|
||||||
|
c.DBConfig = &config.DBConfig{
|
||||||
|
Type: "postgres",
|
||||||
|
Address: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
User: "postgres",
|
||||||
|
Password: "postgres",
|
||||||
|
Database: "postgres",
|
||||||
|
ApplicationName: "gotosocial",
|
||||||
|
}
|
||||||
|
suite.config = c
|
||||||
|
|
||||||
|
encryptedPassword, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Panicf("error encrypting user pass: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
acctID := uuid.NewString()
|
||||||
|
|
||||||
|
suite.testAccount = >smodel.Account{
|
||||||
|
ID: acctID,
|
||||||
|
Username: "test_user",
|
||||||
|
}
|
||||||
|
suite.testUser = >smodel.User{
|
||||||
|
EncryptedPassword: string(encryptedPassword),
|
||||||
|
Email: "user@example.org",
|
||||||
|
AccountID: acctID,
|
||||||
|
}
|
||||||
|
suite.testClient = &oauthClient{
|
||||||
|
ID: "a-known-client-id",
|
||||||
|
Secret: "some-secret",
|
||||||
|
Domain: fmt.Sprintf("%s://%s", c.Protocol, c.Host),
|
||||||
|
}
|
||||||
|
suite.testApplication = >smodel.Application{
|
||||||
|
Name: "a test application",
|
||||||
|
Website: "https://some-application-website.com",
|
||||||
|
RedirectURI: "http://localhost:8080",
|
||||||
|
ClientID: "a-known-client-id",
|
||||||
|
ClientSecret: "some-secret",
|
||||||
|
Scopes: "read",
|
||||||
|
VapidKey: uuid.NewString(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupTest creates a postgres connection and creates the oauth_clients table before each test
|
||||||
|
func (suite *OauthTestSuite) SetupTest() {
|
||||||
|
|
||||||
|
log := logrus.New()
|
||||||
|
log.SetLevel(logrus.TraceLevel)
|
||||||
|
db, err := db.New(context.Background(), suite.config, log)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Panicf("error creating database connection: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
suite.db = db
|
||||||
|
|
||||||
|
models := []interface{}{
|
||||||
|
&oauthClient{},
|
||||||
|
&oauthToken{},
|
||||||
|
>smodel.User{},
|
||||||
|
>smodel.Account{},
|
||||||
|
>smodel.Application{},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, m := range models {
|
||||||
|
if err := suite.db.CreateTable(m); err != nil {
|
||||||
|
logrus.Panicf("db connection error: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
suite.tokenStore = newTokenStore(context.Background(), suite.db, logrus.New())
|
||||||
|
suite.clientStore = newClientStore(suite.db)
|
||||||
|
|
||||||
|
if err := suite.db.Put(suite.testAccount); err != nil {
|
||||||
|
logrus.Panicf("could not insert test account into db: %s", err)
|
||||||
|
}
|
||||||
|
if err := suite.db.Put(suite.testUser); err != nil {
|
||||||
|
logrus.Panicf("could not insert test user into db: %s", err)
|
||||||
|
}
|
||||||
|
if err := suite.db.Put(suite.testClient); err != nil {
|
||||||
|
logrus.Panicf("could not insert test client into db: %s", err)
|
||||||
|
}
|
||||||
|
if err := suite.db.Put(suite.testApplication); err != nil {
|
||||||
|
logrus.Panicf("could not insert test application into db: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// TearDownTest drops the oauth_clients table and closes the pg connection after each test
|
||||||
|
func (suite *OauthTestSuite) TearDownTest() {
|
||||||
|
models := []interface{}{
|
||||||
|
&oauthClient{},
|
||||||
|
&oauthToken{},
|
||||||
|
>smodel.User{},
|
||||||
|
>smodel.Account{},
|
||||||
|
>smodel.Application{},
|
||||||
|
}
|
||||||
|
for _, m := range models {
|
||||||
|
if err := suite.db.DropTable(m); err != nil {
|
||||||
|
logrus.Panicf("error dropping table: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := suite.db.Stop(context.Background()); err != nil {
|
||||||
|
logrus.Panicf("error closing db connection: %s", err)
|
||||||
|
}
|
||||||
|
suite.db = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (suite *OauthTestSuite) TestAPIInitialize() {
|
||||||
|
log := logrus.New()
|
||||||
|
log.SetLevel(logrus.TraceLevel)
|
||||||
|
|
||||||
|
r, err := router.New(suite.config, log)
|
||||||
|
if err != nil {
|
||||||
|
suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
api := New(suite.tokenStore, suite.clientStore, suite.db, log)
|
||||||
|
if err := api.Route(r); err != nil {
|
||||||
|
suite.FailNow(fmt.Sprintf("error mapping routes onto router: %s", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
go r.Start()
|
||||||
|
time.Sleep(60 * time.Second)
|
||||||
|
// http://localhost:8080/oauth/authorize?client_id=a-known-client-id&response_type=code&redirect_uri=http://localhost:8080&scope=read
|
||||||
|
// curl -v -F client_id=a-known-client-id -F client_secret=some-secret -F redirect_uri=http://localhost:8080 -F code=[ INSERT CODE HERE ] -F grant_type=authorization_code localhost:8080/oauth/token
|
||||||
|
// curl -v -H "Authorization: Bearer [INSERT TOKEN HERE]" http://localhost:8080
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOauthTestSuite(t *testing.T) {
|
||||||
|
suite.Run(t, new(OauthTestSuite))
|
||||||
|
}
|
|
@ -24,31 +24,31 @@
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-pg/pg/v10"
|
"github.com/gotosocial/gotosocial/internal/db"
|
||||||
"github.com/gotosocial/oauth2/v4"
|
"github.com/gotosocial/oauth2/v4"
|
||||||
"github.com/gotosocial/oauth2/v4/models"
|
"github.com/gotosocial/oauth2/v4/models"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// pgTokenStore is an implementation of oauth2.TokenStore, which uses Postgres as a storage backend.
|
// tokenStore is an implementation of oauth2.TokenStore, which uses our db interface as a storage backend.
|
||||||
type pgTokenStore struct {
|
type tokenStore struct {
|
||||||
oauth2.TokenStore
|
oauth2.TokenStore
|
||||||
conn *pg.DB
|
db db.DB
|
||||||
log *logrus.Logger
|
log *logrus.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPGTokenStore returns a token store, using postgres, that satisfies the oauth2.TokenStore interface.
|
// newTokenStore returns a token store that satisfies the oauth2.TokenStore interface.
|
||||||
//
|
//
|
||||||
// In order to allow tokens to 'expire' (not really a thing in Postgres world), it will also set off a
|
// In order to allow tokens to 'expire', it will also set off a goroutine that iterates through
|
||||||
// goroutine that iterates through the tokens in the DB once per minute and deletes any that have expired.
|
// the tokens in the DB once per minute and deletes any that have expired.
|
||||||
func NewPGTokenStore(ctx context.Context, conn *pg.DB, log *logrus.Logger) oauth2.TokenStore {
|
func newTokenStore(ctx context.Context, db db.DB, log *logrus.Logger) oauth2.TokenStore {
|
||||||
pts := &pgTokenStore{
|
pts := &tokenStore{
|
||||||
conn: conn,
|
db: db,
|
||||||
log: log,
|
log: log,
|
||||||
}
|
}
|
||||||
|
|
||||||
// set the token store to clean out expired tokens once per minute, or return if we're done
|
// set the token store to clean out expired tokens once per minute, or return if we're done
|
||||||
go func(ctx context.Context, pts *pgTokenStore, log *logrus.Logger) {
|
go func(ctx context.Context, pts *tokenStore, log *logrus.Logger) {
|
||||||
cleanloop:
|
cleanloop:
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
@ -67,22 +67,22 @@ func NewPGTokenStore(ctx context.Context, conn *pg.DB, log *logrus.Logger) oauth
|
||||||
}
|
}
|
||||||
|
|
||||||
// sweep clears out old tokens that have expired; it should be run on a loop about once per minute or so.
|
// sweep clears out old tokens that have expired; it should be run on a loop about once per minute or so.
|
||||||
func (pts *pgTokenStore) sweep() error {
|
func (pts *tokenStore) sweep() error {
|
||||||
// select *all* tokens from the db
|
// select *all* tokens from the db
|
||||||
// todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way.
|
// todo: if this becomes expensive (ie., there are fucking LOADS of tokens) then figure out a better way.
|
||||||
var tokens []oauthToken
|
tokens := new([]*oauthToken)
|
||||||
if err := pts.conn.Model(&tokens).Select(); err != nil {
|
if err := pts.db.GetAll(tokens); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// iterate through and remove expired tokens
|
// iterate through and remove expired tokens
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
for _, pgt := range tokens {
|
for _, pgt := range *tokens {
|
||||||
// The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So:
|
// The zero value of a time.Time is 00:00 january 1 1970, which will always be before now. So:
|
||||||
// we only want to check if a token expired before now if the expiry time is *not zero*;
|
// we only want to check if a token expired before now if the expiry time is *not zero*;
|
||||||
// ie., if it's been explicity set.
|
// ie., if it's been explicity set.
|
||||||
if !pgt.CodeExpiresAt.IsZero() && pgt.CodeExpiresAt.Before(now) || !pgt.RefreshExpiresAt.IsZero() && pgt.RefreshExpiresAt.Before(now) || !pgt.AccessExpiresAt.IsZero() && pgt.AccessExpiresAt.Before(now) {
|
if !pgt.CodeExpiresAt.IsZero() && pgt.CodeExpiresAt.Before(now) || !pgt.RefreshExpiresAt.IsZero() && pgt.RefreshExpiresAt.Before(now) || !pgt.AccessExpiresAt.IsZero() && pgt.AccessExpiresAt.Before(now) {
|
||||||
if _, err := pts.conn.Model(&pgt).Delete(); err != nil {
|
if err := pts.db.DeleteByID(pgt.ID, &pgt); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -93,68 +93,61 @@ func (pts *pgTokenStore) sweep() error {
|
||||||
|
|
||||||
// Create creates and store the new token information.
|
// Create creates and store the new token information.
|
||||||
// For the original implementation, see https://github.com/gotosocial/oauth2/blob/master/store/token.go#L34
|
// For the original implementation, see https://github.com/gotosocial/oauth2/blob/master/store/token.go#L34
|
||||||
func (pts *pgTokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {
|
func (pts *tokenStore) Create(ctx context.Context, info oauth2.TokenInfo) error {
|
||||||
t, ok := info.(*models.Token)
|
t, ok := info.(*models.Token)
|
||||||
if !ok {
|
if !ok {
|
||||||
return errors.New("info param was not a models.Token")
|
return errors.New("info param was not a models.Token")
|
||||||
}
|
}
|
||||||
_, err := pts.conn.WithContext(ctx).Model(oauthTokenToPGToken(t)).Insert()
|
if err := pts.db.Put(oauthTokenToPGToken(t)); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error in tokenstore create: %s", err)
|
return fmt.Errorf("error in tokenstore create: %s", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveByCode deletes a token from the DB based on the Code field
|
// RemoveByCode deletes a token from the DB based on the Code field
|
||||||
func (pts *pgTokenStore) RemoveByCode(ctx context.Context, code string) error {
|
func (pts *tokenStore) RemoveByCode(ctx context.Context, code string) error {
|
||||||
_, err := pts.conn.Model(&oauthToken{}).Where("code = ?", code).Delete()
|
return pts.db.DeleteWhere("code", code, &oauthToken{})
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error in tokenstore removebycode: %s", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveByAccess deletes a token from the DB based on the Access field
|
// RemoveByAccess deletes a token from the DB based on the Access field
|
||||||
func (pts *pgTokenStore) RemoveByAccess(ctx context.Context, access string) error {
|
func (pts *tokenStore) RemoveByAccess(ctx context.Context, access string) error {
|
||||||
_, err := pts.conn.Model(&oauthToken{}).Where("access = ?", access).Delete()
|
return pts.db.DeleteWhere("access", access, &oauthToken{})
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error in tokenstore removebyaccess: %s", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveByRefresh deletes a token from the DB based on the Refresh field
|
// RemoveByRefresh deletes a token from the DB based on the Refresh field
|
||||||
func (pts *pgTokenStore) RemoveByRefresh(ctx context.Context, refresh string) error {
|
func (pts *tokenStore) RemoveByRefresh(ctx context.Context, refresh string) error {
|
||||||
_, err := pts.conn.Model(&oauthToken{}).Where("refresh = ?", refresh).Delete()
|
return pts.db.DeleteWhere("refresh", refresh, &oauthToken{})
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error in tokenstore removebyrefresh: %s", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByCode selects a token from the DB based on the Code field
|
// GetByCode selects a token from the DB based on the Code field
|
||||||
func (pts *pgTokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
|
func (pts *tokenStore) GetByCode(ctx context.Context, code string) (oauth2.TokenInfo, error) {
|
||||||
pgt := &oauthToken{}
|
pgt := &oauthToken{
|
||||||
if err := pts.conn.Model(pgt).Where("code = ?", code).Select(); err != nil {
|
Code: code,
|
||||||
return nil, fmt.Errorf("error in tokenstore getbycode: %s", err)
|
}
|
||||||
|
if err := pts.db.GetWhere("code", code, pgt); err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
return pgTokenToOauthToken(pgt), nil
|
return pgTokenToOauthToken(pgt), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByAccess selects a token from the DB based on the Access field
|
// GetByAccess selects a token from the DB based on the Access field
|
||||||
func (pts *pgTokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
|
func (pts *tokenStore) GetByAccess(ctx context.Context, access string) (oauth2.TokenInfo, error) {
|
||||||
pgt := &oauthToken{}
|
pgt := &oauthToken{
|
||||||
if err := pts.conn.Model(pgt).Where("access = ?", access).Select(); err != nil {
|
Access: access,
|
||||||
return nil, fmt.Errorf("error in tokenstore getbyaccess: %s", err)
|
}
|
||||||
|
if err := pts.db.GetWhere("access", access, pgt); err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
return pgTokenToOauthToken(pgt), nil
|
return pgTokenToOauthToken(pgt), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByRefresh selects a token from the DB based on the Refresh field
|
// GetByRefresh selects a token from the DB based on the Refresh field
|
||||||
func (pts *pgTokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
|
func (pts *tokenStore) GetByRefresh(ctx context.Context, refresh string) (oauth2.TokenInfo, error) {
|
||||||
pgt := &oauthToken{}
|
pgt := &oauthToken{
|
||||||
if err := pts.conn.Model(pgt).Where("refresh = ?", refresh).Select(); err != nil {
|
Refresh: refresh,
|
||||||
return nil, fmt.Errorf("error in tokenstore getbyrefresh: %s", err)
|
}
|
||||||
|
if err := pts.db.GetWhere("refresh", refresh, pgt); err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
return pgTokenToOauthToken(pgt), nil
|
return pgTokenToOauthToken(pgt), nil
|
||||||
}
|
}
|
||||||
|
@ -174,6 +167,7 @@ func (pts *pgTokenStore) GetByRefresh(ctx context.Context, refresh string) (oaut
|
||||||
// As such, manual translation is always required between oauthToken and the gotosocial *model.Token. The helper functions oauthTokenToPGToken
|
// As such, manual translation is always required between oauthToken and the gotosocial *model.Token. The helper functions oauthTokenToPGToken
|
||||||
// and pgTokenToOauthToken can be used for that.
|
// and pgTokenToOauthToken can be used for that.
|
||||||
type oauthToken struct {
|
type oauthToken struct {
|
||||||
|
ID string `pg:"type:uuid,default:gen_random_uuid(),pk,notnull"`
|
||||||
ClientID string
|
ClientID string
|
||||||
UserID string
|
UserID string
|
||||||
RedirectURI string
|
RedirectURI string
|
|
@ -1,446 +0,0 @@
|
||||||
/*
|
|
||||||
GoToSocial
|
|
||||||
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
|
||||||
|
|
||||||
This program is free software: you can redistribute it and/or modify
|
|
||||||
it under the terms of the GNU Affero General Public License as published by
|
|
||||||
the Free Software Foundation, either version 3 of the License, or
|
|
||||||
(at your option) any later version.
|
|
||||||
|
|
||||||
This program is distributed in the hope that it will be useful,
|
|
||||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
||||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
||||||
GNU Affero General Public License for more details.
|
|
||||||
|
|
||||||
You should have received a copy of the GNU Affero General Public License
|
|
||||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package oauth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
|
|
||||||
"github.com/gin-contrib/sessions"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/go-pg/pg/v10"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
"github.com/gotosocial/gotosocial/internal/api"
|
|
||||||
"github.com/gotosocial/gotosocial/internal/gtsmodel"
|
|
||||||
"github.com/gotosocial/gotosocial/pkg/mastotypes"
|
|
||||||
"github.com/gotosocial/oauth2/v4"
|
|
||||||
"github.com/gotosocial/oauth2/v4/errors"
|
|
||||||
"github.com/gotosocial/oauth2/v4/manage"
|
|
||||||
"github.com/gotosocial/oauth2/v4/server"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
)
|
|
||||||
|
|
||||||
type API struct {
|
|
||||||
manager *manage.Manager
|
|
||||||
server *server.Server
|
|
||||||
conn *pg.DB
|
|
||||||
log *logrus.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
type login struct {
|
|
||||||
Email string `form:"username"`
|
|
||||||
Password string `form:"password"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type code struct {
|
|
||||||
Code string `form:"code"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func New(ts oauth2.TokenStore, cs oauth2.ClientStore, conn *pg.DB, log *logrus.Logger) *API {
|
|
||||||
manager := manage.NewDefaultManager()
|
|
||||||
manager.MapTokenStorage(ts)
|
|
||||||
manager.MapClientStorage(cs)
|
|
||||||
manager.SetAuthorizeCodeTokenCfg(manage.DefaultAuthorizeCodeTokenCfg)
|
|
||||||
sc := &server.Config{
|
|
||||||
TokenType: "Bearer",
|
|
||||||
// Must follow the spec.
|
|
||||||
AllowGetAccessRequest: false,
|
|
||||||
// Support only the non-implicit flow.
|
|
||||||
AllowedResponseTypes: []oauth2.ResponseType{oauth2.Code},
|
|
||||||
// Allow:
|
|
||||||
// - Authorization Code (for first & third parties)
|
|
||||||
// - Refreshing Tokens
|
|
||||||
//
|
|
||||||
// Deny:
|
|
||||||
// - Resource owner secrets (password grant)
|
|
||||||
// - Client secrets
|
|
||||||
AllowedGrantTypes: []oauth2.GrantType{
|
|
||||||
oauth2.AuthorizationCode,
|
|
||||||
oauth2.Refreshing,
|
|
||||||
},
|
|
||||||
AllowedCodeChallengeMethods: []oauth2.CodeChallengeMethod{
|
|
||||||
oauth2.CodeChallengePlain,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
srv := server.NewServer(sc, manager)
|
|
||||||
srv.SetInternalErrorHandler(func(err error) *errors.Response {
|
|
||||||
log.Errorf("internal oauth error: %s", err)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
srv.SetResponseErrorHandler(func(re *errors.Response) {
|
|
||||||
log.Errorf("internal response error: %s", re.Error)
|
|
||||||
})
|
|
||||||
|
|
||||||
api := &API{
|
|
||||||
manager: manager,
|
|
||||||
server: srv,
|
|
||||||
conn: conn,
|
|
||||||
log: log,
|
|
||||||
}
|
|
||||||
|
|
||||||
api.server.SetUserAuthorizationHandler(api.UserAuthorizationHandler)
|
|
||||||
api.server.SetClientInfoHandler(server.ClientFormHandler)
|
|
||||||
return api
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *API) AddRoutes(s api.Server) error {
|
|
||||||
s.AttachHandler(http.MethodPost, "/api/v1/apps", a.AppsPOSTHandler)
|
|
||||||
|
|
||||||
s.AttachHandler(http.MethodGet, "/auth/sign_in", a.SignInGETHandler)
|
|
||||||
s.AttachHandler(http.MethodPost, "/auth/sign_in", a.SignInPOSTHandler)
|
|
||||||
|
|
||||||
s.AttachHandler(http.MethodPost, "/oauth/token", a.TokenPOSTHandler)
|
|
||||||
|
|
||||||
s.AttachHandler(http.MethodGet, "/oauth/authorize", a.AuthorizeGETHandler)
|
|
||||||
s.AttachHandler(http.MethodPost, "/oauth/authorize", a.AuthorizePOSTHandler)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func incorrectPassword() (string, error) {
|
|
||||||
return "", errors.New("password/email combination was incorrect")
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
MAIN HANDLERS -- serve these through a server/router
|
|
||||||
*/
|
|
||||||
|
|
||||||
// AppsPOSTHandler should be served at https://example.org/api/v1/apps
|
|
||||||
// It is equivalent to: https://docs.joinmastodon.org/methods/apps/
|
|
||||||
func (a *API) AppsPOSTHandler(c *gin.Context) {
|
|
||||||
l := a.log.WithField("func", "AppsPOSTHandler")
|
|
||||||
l.Trace("entering AppsPOSTHandler")
|
|
||||||
|
|
||||||
form := &mastotypes.ApplicationPOSTRequest{}
|
|
||||||
if err := c.ShouldBind(form); err != nil {
|
|
||||||
c.JSON(http.StatusUnprocessableEntity, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// permitted length for most fields
|
|
||||||
permittedLength := 64
|
|
||||||
// redirect can be a bit bigger because we probably need to encode data in the redirect uri
|
|
||||||
permittedRedirect := 256
|
|
||||||
|
|
||||||
// check lengths of fields before proceeding so the user can't spam huge entries into the database
|
|
||||||
if len(form.ClientName) > permittedLength {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("client_name must be less than %d bytes", permittedLength)})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(form.Website) > permittedLength {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("website must be less than %d bytes", permittedLength)})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(form.RedirectURIs) > permittedRedirect {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("redirect_uris must be less than %d bytes", permittedRedirect)})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(form.Scopes) > permittedLength {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("scopes must be less than %d bytes", permittedLength)})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// set default 'read' for scopes if it's not set
|
|
||||||
var scopes string
|
|
||||||
if form.Scopes == "" {
|
|
||||||
scopes = "read"
|
|
||||||
} else {
|
|
||||||
scopes = form.Scopes
|
|
||||||
}
|
|
||||||
|
|
||||||
// generate new IDs for this application and its associated client
|
|
||||||
clientID := uuid.NewString()
|
|
||||||
clientSecret := uuid.NewString()
|
|
||||||
vapidKey := uuid.NewString()
|
|
||||||
|
|
||||||
// generate the application to put in the database
|
|
||||||
app := >smodel.Application{
|
|
||||||
Name: form.ClientName,
|
|
||||||
Website: form.Website,
|
|
||||||
RedirectURI: form.RedirectURIs,
|
|
||||||
ClientID: clientID,
|
|
||||||
ClientSecret: clientSecret,
|
|
||||||
Scopes: scopes,
|
|
||||||
VapidKey: vapidKey,
|
|
||||||
}
|
|
||||||
|
|
||||||
// chuck it in the db
|
|
||||||
if _, err := a.conn.Model(app).Insert(); err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// now we need to model an oauth client from the application that the oauth library can use
|
|
||||||
oc := &oauthClient{
|
|
||||||
ID: clientID,
|
|
||||||
Secret: clientSecret,
|
|
||||||
Domain: form.RedirectURIs,
|
|
||||||
UserID: "", // This client isn't yet associated with a specific user, it's just an app client right now
|
|
||||||
}
|
|
||||||
|
|
||||||
// chuck it in the db
|
|
||||||
if _, err := a.conn.Model(oc).Insert(); err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// done, return the new app information per the spec here: https://docs.joinmastodon.org/methods/apps/
|
|
||||||
c.JSON(http.StatusOK, app)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SignInGETHandler should be served at https://example.org/auth/sign_in.
|
|
||||||
// The idea is to present a sign in page to the user, where they can enter their username and password.
|
|
||||||
// The form will then POST to the sign in page, which will be handled by SignInPOSTHandler
|
|
||||||
func (a *API) SignInGETHandler(c *gin.Context) {
|
|
||||||
a.log.WithField("func", "SignInGETHandler").Trace("serving sign in html")
|
|
||||||
c.HTML(http.StatusOK, "sign-in.tmpl", gin.H{})
|
|
||||||
}
|
|
||||||
|
|
||||||
// SignInPOSTHandler should be served at https://example.org/auth/sign_in.
|
|
||||||
// The idea is to present a sign in page to the user, where they can enter their username and password.
|
|
||||||
// The handler will then redirect to the auth handler served at /auth
|
|
||||||
func (a *API) SignInPOSTHandler(c *gin.Context) {
|
|
||||||
l := a.log.WithField("func", "SignInPOSTHandler")
|
|
||||||
s := sessions.Default(c)
|
|
||||||
form := &login{}
|
|
||||||
if err := c.ShouldBind(form); err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
l.Tracef("parsed form: %+v", form)
|
|
||||||
|
|
||||||
userid, err := a.ValidatePassword(form.Email, form.Password)
|
|
||||||
if err != nil {
|
|
||||||
c.String(http.StatusForbidden, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
s.Set("username", userid)
|
|
||||||
if err := s.Save(); err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
l.Trace("redirecting to auth page")
|
|
||||||
c.Redirect(http.StatusFound, "/oauth/authorize")
|
|
||||||
}
|
|
||||||
|
|
||||||
// TokenPOSTHandler should be served as a POST at https://example.org/oauth/token
|
|
||||||
// The idea here is to serve an oauth access token to a user, which can be used for authorizing against non-public APIs.
|
|
||||||
// See https://docs.joinmastodon.org/methods/apps/oauth/#obtain-a-token
|
|
||||||
func (a *API) TokenPOSTHandler(c *gin.Context) {
|
|
||||||
l := a.log.WithField("func", "TokenHandler")
|
|
||||||
l.Trace("entered token handler, will now go to server.HandleTokenRequest")
|
|
||||||
if err := a.server.HandleTokenRequest(c.Writer, c.Request); err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// AuthorizeGETHandler should be served as GET at https://example.org/oauth/authorize
|
|
||||||
// The idea here is to present an oauth authorize page to the user, with a button
|
|
||||||
// that they have to click to accept. See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user
|
|
||||||
func (a *API) AuthorizeGETHandler(c *gin.Context) {
|
|
||||||
l := a.log.WithField("func", "AuthorizeGETHandler")
|
|
||||||
s := sessions.Default(c)
|
|
||||||
|
|
||||||
// Username will be set in the session by AuthorizePOSTHandler if the caller has already gone through the authentication flow
|
|
||||||
// If it's not set, then we don't know yet who the user is, so we need to redirect them to the sign in page.
|
|
||||||
v := s.Get("username")
|
|
||||||
if username, ok := v.(string); !ok || username == "" {
|
|
||||||
l.Trace("username was empty, parsing form then redirecting to sign in page")
|
|
||||||
|
|
||||||
// first make sure they've filled out the authorize form with the required values
|
|
||||||
form := &mastotypes.OAuthAuthorize{}
|
|
||||||
if err := c.ShouldBind(form); err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
l.Tracef("parsed form: %+v", form)
|
|
||||||
|
|
||||||
// these fields are *required* so check 'em
|
|
||||||
if form.ResponseType == "" || form.ClientID == "" || form.RedirectURI == "" {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing one of: response_type, client_id or redirect_uri"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// save these values from the form so we can use them elsewhere in the session
|
|
||||||
s.Set("force_login", form.ForceLogin)
|
|
||||||
s.Set("response_type", form.ResponseType)
|
|
||||||
s.Set("client_id", form.ClientID)
|
|
||||||
s.Set("redirect_uri", form.RedirectURI)
|
|
||||||
s.Set("scope", form.Scope)
|
|
||||||
if err := s.Save(); err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// send them to the sign in page so we can tell who they are
|
|
||||||
c.Redirect(http.StatusFound, "/auth/sign_in")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if we have a code already. If we do, it means the user used urn:ietf:wg:oauth:2.0:oob as their redirect URI
|
|
||||||
// and were sent here, which means they just want the code displayed so they can use it out of band.
|
|
||||||
code := &code{}
|
|
||||||
if err := c.Bind(code); err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// the authorize template will either:
|
|
||||||
// 1. Display the code to the user if they're already authorized and were redirected here because they selected urn:ietf:wg:oauth:2.0:oob.
|
|
||||||
// 2. Display a form where they can get some information about the app that's trying to authorize, and approve it, which will then go to AuthorizePOSTHandler
|
|
||||||
l.Trace("serving authorize html")
|
|
||||||
c.HTML(http.StatusOK, "authorize.tmpl", gin.H{
|
|
||||||
"code": code.Code,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// AuthorizePOSTHandler should be served as POST at https://example.org/oauth/authorize
|
|
||||||
// The idea here is to present an oauth authorize page to the user, with a button
|
|
||||||
// that they have to click to accept. See here: https://docs.joinmastodon.org/methods/apps/oauth/#authorize-a-user
|
|
||||||
func (a *API) AuthorizePOSTHandler(c *gin.Context) {
|
|
||||||
l := a.log.WithField("func", "AuthorizePOSTHandler")
|
|
||||||
s := sessions.Default(c)
|
|
||||||
|
|
||||||
v := s.Get("username")
|
|
||||||
if username, ok := v.(string); !ok || username == "" {
|
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "you are not signed in"})
|
|
||||||
}
|
|
||||||
|
|
||||||
values := url.Values{}
|
|
||||||
|
|
||||||
if v, ok := s.Get("force_login").(string); !ok {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing force_login"})
|
|
||||||
return
|
|
||||||
} else {
|
|
||||||
values.Add("force_login", v)
|
|
||||||
}
|
|
||||||
|
|
||||||
if v, ok := s.Get("response_type").(string); !ok {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing response_type"})
|
|
||||||
return
|
|
||||||
} else {
|
|
||||||
values.Add("response_type", v)
|
|
||||||
}
|
|
||||||
|
|
||||||
if v, ok := s.Get("client_id").(string); !ok {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing client_id"})
|
|
||||||
return
|
|
||||||
} else {
|
|
||||||
values.Add("client_id", v)
|
|
||||||
}
|
|
||||||
|
|
||||||
if v, ok := s.Get("redirect_uri").(string); !ok {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing redirect_uri"})
|
|
||||||
return
|
|
||||||
} else {
|
|
||||||
// todo: explain this little hack
|
|
||||||
if v == "urn:ietf:wg:oauth:2.0:oob" {
|
|
||||||
v = "http://localhost:8080/oauth/authorize"
|
|
||||||
}
|
|
||||||
values.Add("redirect_uri", v)
|
|
||||||
}
|
|
||||||
|
|
||||||
if v, ok := s.Get("scope").(string); !ok {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing scope"})
|
|
||||||
return
|
|
||||||
} else {
|
|
||||||
values.Add("scope", v)
|
|
||||||
}
|
|
||||||
|
|
||||||
if v, ok := s.Get("username").(string); !ok {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "session missing username"})
|
|
||||||
return
|
|
||||||
} else {
|
|
||||||
values.Add("username", v)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Request.Form = values
|
|
||||||
l.Tracef("values on request set to %+v", c.Request.Form)
|
|
||||||
|
|
||||||
if err := s.Save(); err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := a.server.HandleAuthorizeRequest(c.Writer, c.Request); err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
SUB-HANDLERS -- don't serve these directly, they should be attached to the oauth2 server
|
|
||||||
*/
|
|
||||||
|
|
||||||
// PasswordAuthorizationHandler takes a username (in this case, we use an email address)
|
|
||||||
// and a password. The goal is to authenticate the password against the one for that email
|
|
||||||
// address stored in the database. If OK, we return the userid (a uuid) for that user,
|
|
||||||
// so that it can be used in further Oauth flows to generate a token/retreieve an oauth client from the db.
|
|
||||||
func (a *API) ValidatePassword(email string, password string) (userid string, err error) {
|
|
||||||
l := a.log.WithField("func", "PasswordAuthorizationHandler")
|
|
||||||
|
|
||||||
// make sure an email/password was provided and bail if not
|
|
||||||
if email == "" || password == "" {
|
|
||||||
l.Debug("email or password was not provided")
|
|
||||||
return incorrectPassword()
|
|
||||||
}
|
|
||||||
|
|
||||||
// first we select the user from the database based on email address, bail if no user found for that email
|
|
||||||
gtsUser := >smodel.User{}
|
|
||||||
if err := a.conn.Model(gtsUser).Where("email = ?", email).Select(); err != nil {
|
|
||||||
l.Debugf("user %s was not retrievable from db during oauth authorization attempt: %s", email, err)
|
|
||||||
return incorrectPassword()
|
|
||||||
}
|
|
||||||
|
|
||||||
// make sure a password is actually set and bail if not
|
|
||||||
if gtsUser.EncryptedPassword == "" {
|
|
||||||
l.Warnf("encrypted password for user %s was empty for some reason", gtsUser.Email)
|
|
||||||
return incorrectPassword()
|
|
||||||
}
|
|
||||||
|
|
||||||
// compare the provided password with the encrypted one from the db, bail if they don't match
|
|
||||||
if err := bcrypt.CompareHashAndPassword([]byte(gtsUser.EncryptedPassword), []byte(password)); err != nil {
|
|
||||||
l.Debugf("password hash didn't match for user %s during login attempt: %s", gtsUser.Email, err)
|
|
||||||
return incorrectPassword()
|
|
||||||
}
|
|
||||||
|
|
||||||
// If we've made it this far the email/password is correct, so we can just return the id of the user.
|
|
||||||
userid = gtsUser.ID
|
|
||||||
l.Tracef("returning (%s, %s)", userid, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// UserAuthorizationHandler gets the user's ID from the 'username' field of the request form,
|
|
||||||
// or redirects to the /auth/sign_in page, if this key is not present.
|
|
||||||
func (a *API) UserAuthorizationHandler(w http.ResponseWriter, r *http.Request) (userID string, err error) {
|
|
||||||
l := a.log.WithField("func", "UserAuthorizationHandler")
|
|
||||||
userID = r.FormValue("username")
|
|
||||||
if userID == "" {
|
|
||||||
l.Trace("username was empty, redirecting to sign in page")
|
|
||||||
http.Redirect(w, r, "/auth/sign_in", http.StatusFound)
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
l.Tracef("returning (%s, %s)", userID, err)
|
|
||||||
return userID, err
|
|
||||||
}
|
|
|
@ -1,133 +0,0 @@
|
||||||
package oauth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/go-pg/pg/v10"
|
|
||||||
"github.com/go-pg/pg/v10/orm"
|
|
||||||
"github.com/gotosocial/gotosocial/internal/api"
|
|
||||||
"github.com/gotosocial/gotosocial/internal/config"
|
|
||||||
"github.com/gotosocial/gotosocial/internal/gtsmodel"
|
|
||||||
"github.com/gotosocial/oauth2/v4"
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"github.com/stretchr/testify/suite"
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
)
|
|
||||||
|
|
||||||
type OauthTestSuite struct {
|
|
||||||
suite.Suite
|
|
||||||
tokenStore oauth2.TokenStore
|
|
||||||
clientStore oauth2.ClientStore
|
|
||||||
conn *pg.DB
|
|
||||||
testAccount *gtsmodel.Account
|
|
||||||
testUser *gtsmodel.User
|
|
||||||
testClient *oauthClient
|
|
||||||
config *config.Config
|
|
||||||
}
|
|
||||||
|
|
||||||
const ()
|
|
||||||
|
|
||||||
// SetupSuite sets some variables on the suite that we can use as consts (more or less) throughout
|
|
||||||
func (suite *OauthTestSuite) SetupSuite() {
|
|
||||||
encryptedPassword, err := bcrypt.GenerateFromPassword([]byte("test-password"), bcrypt.DefaultCost)
|
|
||||||
if err != nil {
|
|
||||||
logrus.Panicf("error encrypting user pass: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
suite.testAccount = >smodel.Account{}
|
|
||||||
suite.testUser = >smodel.User{
|
|
||||||
EncryptedPassword: string(encryptedPassword),
|
|
||||||
Email: "user@localhost",
|
|
||||||
AccountID: "some-account-id-it-doesn't-matter-really-since-this-user-doesn't-actually-have-an-account!",
|
|
||||||
}
|
|
||||||
suite.testClient = &oauthClient{
|
|
||||||
ID: "a-known-client-id",
|
|
||||||
Secret: "some-secret",
|
|
||||||
Domain: "http://localhost:8080",
|
|
||||||
}
|
|
||||||
|
|
||||||
// because go tests are run within the test package directory, we need to fiddle with the templateconfig
|
|
||||||
// basedir in a way that we wouldn't normally have to do when running the binary, in order to make
|
|
||||||
// the templates actually load
|
|
||||||
c := config.Empty()
|
|
||||||
c.TemplateConfig.BaseDir = "../../web/template/"
|
|
||||||
suite.config = c
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetupTest creates a postgres connection and creates the oauth_clients table before each test
|
|
||||||
func (suite *OauthTestSuite) SetupTest() {
|
|
||||||
suite.conn = pg.Connect(&pg.Options{})
|
|
||||||
if err := suite.conn.Ping(context.Background()); err != nil {
|
|
||||||
logrus.Panicf("db connection error: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
models := []interface{}{
|
|
||||||
&oauthClient{},
|
|
||||||
&oauthToken{},
|
|
||||||
>smodel.User{},
|
|
||||||
>smodel.Account{},
|
|
||||||
>smodel.Application{},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, m := range models {
|
|
||||||
if err := suite.conn.Model(m).CreateTable(&orm.CreateTableOptions{
|
|
||||||
IfNotExists: true,
|
|
||||||
}); err != nil {
|
|
||||||
logrus.Panicf("db connection error: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
suite.tokenStore = NewPGTokenStore(context.Background(), suite.conn, logrus.New())
|
|
||||||
suite.clientStore = NewPGClientStore(suite.conn)
|
|
||||||
|
|
||||||
if _, err := suite.conn.Model(suite.testUser).Insert(); err != nil {
|
|
||||||
logrus.Panicf("could not insert test user into db: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := suite.conn.Model(suite.testClient).Insert(); err != nil {
|
|
||||||
logrus.Panicf("could not insert test client into db: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// TearDownTest drops the oauth_clients table and closes the pg connection after each test
|
|
||||||
func (suite *OauthTestSuite) TearDownTest() {
|
|
||||||
models := []interface{}{
|
|
||||||
&oauthClient{},
|
|
||||||
&oauthToken{},
|
|
||||||
>smodel.User{},
|
|
||||||
>smodel.Account{},
|
|
||||||
>smodel.Application{},
|
|
||||||
}
|
|
||||||
for _, m := range models {
|
|
||||||
if err := suite.conn.Model(m).DropTable(&orm.DropTableOptions{}); err != nil {
|
|
||||||
logrus.Panicf("drop table error: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := suite.conn.Close(); err != nil {
|
|
||||||
logrus.Panicf("error closing db connection: %s", err)
|
|
||||||
}
|
|
||||||
suite.conn = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (suite *OauthTestSuite) TestAPIInitialize() {
|
|
||||||
log := logrus.New()
|
|
||||||
log.SetLevel(logrus.TraceLevel)
|
|
||||||
|
|
||||||
r := api.New(suite.config, log)
|
|
||||||
api := New(suite.tokenStore, suite.clientStore, suite.conn, log)
|
|
||||||
if err := api.AddRoutes(r); err != nil {
|
|
||||||
suite.FailNow(fmt.Sprintf("error initializing api: %s", err))
|
|
||||||
}
|
|
||||||
go r.Start()
|
|
||||||
time.Sleep(30 * time.Second)
|
|
||||||
// http://localhost:8080/oauth/authorize?client_id=a-known-client-id&response_type=code&redirect_uri=https://example.org
|
|
||||||
// http://localhost:8080/oauth/authorize?client_id=a-known-client-id&response_type=code&redirect_uri=urn:ietf:wg:oauth:2.0:oob
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOauthTestSuite(t *testing.T) {
|
|
||||||
suite.Run(t, new(OauthTestSuite))
|
|
||||||
}
|
|
120
internal/router/router.go
Normal file
120
internal/router/router.go
Normal file
|
@ -0,0 +1,120 @@
|
||||||
|
/*
|
||||||
|
GoToSocial
|
||||||
|
Copyright (C) 2021 GoToSocial Authors admin@gotosocial.org
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU Affero General Public License as published by
|
||||||
|
the Free Software Foundation, either version 3 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU Affero General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU Affero General Public License
|
||||||
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package router
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/gin-contrib/sessions"
|
||||||
|
"github.com/gin-contrib/sessions/memstore"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gotosocial/gotosocial/internal/config"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Router provides the REST interface for gotosocial, using gin.
|
||||||
|
type Router interface {
|
||||||
|
// Attach a gin handler to the router with the given method and path
|
||||||
|
AttachHandler(method string, path string, handler gin.HandlerFunc)
|
||||||
|
// Attach a gin middleware to the router that will be used globally
|
||||||
|
AttachMiddleware(handler gin.HandlerFunc)
|
||||||
|
// Start the router
|
||||||
|
Start()
|
||||||
|
// Stop the router
|
||||||
|
Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// router fulfils the Router interface using gin and logrus
|
||||||
|
type router struct {
|
||||||
|
logger *logrus.Logger
|
||||||
|
engine *gin.Engine
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the router nicely
|
||||||
|
func (s *router) Start() {
|
||||||
|
// todo: start gracefully
|
||||||
|
if err := s.engine.Run(); err != nil {
|
||||||
|
s.logger.Panicf("server error: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop shuts down the router nicely
|
||||||
|
func (s *router) Stop() {
|
||||||
|
// todo: shut down gracefully
|
||||||
|
}
|
||||||
|
|
||||||
|
// AttachHandler attaches the given gin.HandlerFunc to the router with the specified method and path.
|
||||||
|
// If the path is set to ANY, then the handlerfunc will be used for ALL methods at its given path.
|
||||||
|
func (s *router) AttachHandler(method string, path string, handler gin.HandlerFunc) {
|
||||||
|
if method == "ANY" {
|
||||||
|
s.engine.Any(path, handler)
|
||||||
|
} else {
|
||||||
|
s.engine.Handle(method, path, handler)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AttachMiddleware attaches a gin middleware to the router that will be used globally
|
||||||
|
func (s *router) AttachMiddleware(middleware gin.HandlerFunc) {
|
||||||
|
s.engine.Use(middleware)
|
||||||
|
}
|
||||||
|
|
||||||
|
// New returns a new Router with the specified configuration, using the given logrus logger.
|
||||||
|
func New(config *config.Config, logger *logrus.Logger) (Router, error) {
|
||||||
|
engine := gin.New()
|
||||||
|
|
||||||
|
// create a new session store middleware
|
||||||
|
store, err := sessionStore()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error creating session store: %s", err)
|
||||||
|
}
|
||||||
|
engine.Use(sessions.Sessions("gotosocial-session", store))
|
||||||
|
|
||||||
|
// load html templates for use by the router
|
||||||
|
cwd, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error getting current working directory: %s", err)
|
||||||
|
}
|
||||||
|
tmPath := filepath.Join(cwd, fmt.Sprintf("%s*", config.TemplateConfig.BaseDir))
|
||||||
|
logger.Debugf("loading templates from %s", tmPath)
|
||||||
|
engine.LoadHTMLGlob(tmPath)
|
||||||
|
|
||||||
|
return &router{
|
||||||
|
logger: logger,
|
||||||
|
engine: engine,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// sessionStore returns a new session store with a random auth and encryption key.
|
||||||
|
// This means that cookies using the store will be reset if gotosocial is restarted!
|
||||||
|
func sessionStore() (memstore.Store, error) {
|
||||||
|
auth := make([]byte, 32)
|
||||||
|
crypt := make([]byte, 32)
|
||||||
|
|
||||||
|
if _, err := rand.Read(auth); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if _, err := rand.Read(crypt); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return memstore.NewStore(auth, crypt), nil
|
||||||
|
}
|
|
@ -2,7 +2,7 @@
|
||||||
<html lang="en">
|
<html lang="en">
|
||||||
<head>
|
<head>
|
||||||
<meta charset="UTF-8" />
|
<meta charset="UTF-8" />
|
||||||
<title>Auth</title>
|
<title>GoToSocial Authorization</title>
|
||||||
<link
|
<link
|
||||||
rel="stylesheet"
|
rel="stylesheet"
|
||||||
href="//maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css"
|
href="//maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css"
|
||||||
|
@ -11,13 +11,13 @@
|
||||||
<script src="//maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script>
|
<script src="//maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script>
|
||||||
</head>
|
</head>
|
||||||
|
|
||||||
{{if len .code | eq 0 }}
|
|
||||||
<body>
|
<body>
|
||||||
<div class="container">
|
<div class="container">
|
||||||
<div class="jumbotron">
|
<div class="jumbotron">
|
||||||
<form action="/oauth/authorize" method="POST">
|
<form action="/oauth/authorize" method="POST">
|
||||||
<h1>Authorize</h1>
|
<h1>Hi {{.user}}!</h1>
|
||||||
<p>The client would like to perform actions on your behalf.</p>
|
<p>Application <b>{{.appname}}</b> {{if len .appwebsite | eq 0 | not}}({{.appwebsite}}) {{end}}would like to perform actions on your behalf, with scope <em>{{.scope}}</em>.</p>
|
||||||
|
<p>The application will redirect to {{.redirect}} to continue.</p>
|
||||||
<p>
|
<p>
|
||||||
<button
|
<button
|
||||||
type="submit"
|
type="submit"
|
||||||
|
@ -31,14 +31,4 @@
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</body>
|
</body>
|
||||||
{{else}}
|
|
||||||
<body>
|
|
||||||
<div class="container">
|
|
||||||
<div class="jumbotron">
|
|
||||||
{{.code}}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</body>
|
|
||||||
{{end}}
|
|
||||||
|
|
||||||
</html>
|
</html>
|
||||||
|
|
Loading…
Reference in a new issue