package server

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/superseriousbusiness/oauth2/v4"
	"github.com/superseriousbusiness/oauth2/v4/errors"
)

// NewDefaultServer creates an authorization server that uses the
// configuration returned by NewConfig.
func NewDefaultServer(manager oauth2.Manager) *Server {
	return NewServer(NewConfig(), manager)
}

// NewServer creates an authorization server from the given configuration and
// manager.
//
// The returned server reads client credentials with ClientBasicHandler and
// denies every user and password authorization with errors.ErrAccessDenied
// until the corresponding handlers are replaced by the caller.
func NewServer(cfg *Config, manager oauth2.Manager) *Server {
	srv := &Server{
		Config:  cfg,
		Manager: manager,
	}

	// Default handlers: deny by default so that a server that was not fully
	// configured never authorizes anyone.
	srv.ClientInfoHandler = ClientBasicHandler
	srv.UserAuthorizationHandler = func(w http.ResponseWriter, r *http.Request) (string, error) {
		return "", errors.ErrAccessDenied
	}
	srv.PasswordAuthorizationHandler = func(username, password string) (string, error) {
		return "", errors.ErrAccessDenied
	}

	return srv
}

// Server provides the authorization server.
//
// Config and Manager are required. ClientInfoHandler,
// UserAuthorizationHandler and PasswordAuthorizationHandler are called
// unconditionally by the request handlers and are given defaults by
// NewServer; every other handler is optional and is skipped when nil.
type Server struct {
	Config                       *Config
	Manager                      oauth2.Manager
	ClientInfoHandler            ClientInfoHandler
	ClientAuthorizedHandler      ClientAuthorizedHandler
	ClientScopeHandler           ClientScopeHandler
	UserAuthorizationHandler     UserAuthorizationHandler
	PasswordAuthorizationHandler PasswordAuthorizationHandler
	RefreshingValidationHandler  RefreshingValidationHandler
	RefreshingScopeHandler       RefreshingScopeHandler
	ResponseErrorHandler         ResponseErrorHandler
	InternalErrorHandler         InternalErrorHandler
	ExtensionFieldsHandler       ExtensionFieldsHandler
	AccessTokenExpHandler        AccessTokenExpHandler
	AuthorizeScopeHandler        AuthorizeScopeHandler
}

// redirectError reports err to the client by redirecting to the request's
// redirect URI with the error fields attached. When req is nil there is no
// redirect target, so err is returned to the caller unchanged.
func (s *Server) redirectError(w http.ResponseWriter, req *AuthorizeRequest, err error) error {
	if req == nil {
		return err
	}
	// Status code and headers are not used for redirects.
	data, _, _ := s.GetErrorData(err)
	return s.redirect(w, req, data)
}

// redirect writes a 302 response whose Location header is the request's
// redirect URI extended with data (see GetRedirectURI).
func (s *Server) redirect(w http.ResponseWriter, req *AuthorizeRequest, data map[string]interface{}) error {
	uri, err := s.GetRedirectURI(req, data)
	if err != nil {
		return err
	}

	w.Header().Set("Location", uri)
	w.WriteHeader(http.StatusFound)
	return nil
}

// tokenError writes err to w as a JSON token error response, using the
// status code and headers produced by GetErrorData.
func (s *Server) tokenError(w http.ResponseWriter, err error) error {
	data, statusCode, header := s.GetErrorData(err)
	return s.token(w, data, header, statusCode)
}

// token writes data to w as a JSON body with caching disabled.
//
// Each key of header is copied to the response using the first value only.
// The first positive statusCode is used as the response status; otherwise
// the status is 200. The returned error is the JSON encoder's error.
func (s *Server) token(w http.ResponseWriter, data map[string]interface{}, header http.Header, statusCode ...int) error {
	w.Header().Set("Content-Type", "application/json;charset=UTF-8")
	w.Header().Set("Cache-Control", "no-store")
	w.Header().Set("Pragma", "no-cache")

	for key := range header {
		w.Header().Set(key, header.Get(key))
	}

	status := http.StatusOK
	if len(statusCode) > 0 && statusCode[0] > 0 {
		status = statusCode[0]
	}

	w.WriteHeader(status)
	return json.NewEncoder(w).Encode(data)
}

// GetRedirectURI builds the redirect URI for req with state and data attached.
//
// The parameters start from the query already present on req.RedirectURI;
// "state" is set when req.State is non-empty and every entry of data is
// added using its fmt.Sprint representation. For the oauth2.Code response
// type the parameters are written to the query string. For oauth2.Token the
// query string is cleared and the parameters are placed in the fragment. For
// any other response type the parsed URI is returned without the additions.
func (s *Server) GetRedirectURI(req *AuthorizeRequest, data map[string]interface{}) (string, error) {
	u, err := url.Parse(req.RedirectURI)
	if err != nil {
		return "", err
	}

	q := u.Query()
	if req.State != "" {
		q.Set("state", req.State)
	}
	for k, v := range data {
		q.Set(k, fmt.Sprint(v))
	}

	switch req.ResponseType {
	case oauth2.Code:
		u.RawQuery = q.Encode()
	case oauth2.Token:
		u.RawQuery = ""
		// url.URL escapes Fragment again when it is serialized, so the
		// encoded parameters are unescaped before being assigned.
		fragment, err := url.QueryUnescape(q.Encode())
		if err != nil {
			return "", err
		}
		u.Fragment = fragment
	}

	return u.String(), nil
}

// CheckResponseType reports whether rt is listed in
// Config.AllowedResponseTypes.
func (s *Server) CheckResponseType(rt oauth2.ResponseType) bool {
	for _, art := range s.Config.AllowedResponseTypes {
		if art == rt {
			return true
		}
	}
	return false
}

// CheckCodeChallengeMethod reports whether ccm is listed in
// Config.AllowedCodeChallengeMethods.
func (s *Server) CheckCodeChallengeMethod(ccm oauth2.CodeChallengeMethod) bool {
	for _, c := range s.Config.AllowedCodeChallengeMethods {
		if c == ccm {
			return true
		}
	}
	return false
}

// ValidationAuthorizeRequest validates an authorization request and returns
// the parsed AuthorizeRequest.
//
// It returns:
//   - errors.ErrInvalidRequest when the method is neither GET nor POST or
//     client_id is empty;
//   - errors.ErrUnsupportedResponseType when response_type has an empty
//     String();
//   - errors.ErrUnauthorizedClient when the response type is not allowed;
//   - errors.ErrCodeChallengeRquired when Config.ForcePKCE is set and no
//     code_challenge was sent;
//   - errors.ErrInvalidCodeChallengeLen when code_challenge is present but
//     not 43 to 128 bytes long;
//   - errors.ErrUnsupportedCodeChallengeMethod when the challenge method is
//     not allowed.
func (s *Server) ValidationAuthorizeRequest(r *http.Request) (*AuthorizeRequest, error) {
	redirectURI := r.FormValue("redirect_uri")
	clientID := r.FormValue("client_id")
	if (r.Method != http.MethodGet && r.Method != http.MethodPost) || clientID == "" {
		return nil, errors.ErrInvalidRequest
	}

	resType := oauth2.ResponseType(r.FormValue("response_type"))
	if resType.String() == "" {
		return nil, errors.ErrUnsupportedResponseType
	}
	if !s.CheckResponseType(resType) {
		return nil, errors.ErrUnauthorizedClient
	}

	cc := r.FormValue("code_challenge")
	if cc == "" && s.Config.ForcePKCE {
		return nil, errors.ErrCodeChallengeRquired
	}
	if cc != "" && (len(cc) < 43 || len(cc) > 128) {
		return nil, errors.ErrInvalidCodeChallengeLen
	}

	ccm := oauth2.CodeChallengeMethod(r.FormValue("code_challenge_method"))
	// Default to the plain method when the client did not send one.
	if ccm == "" {
		ccm = oauth2.CodeChallengePlain
	}
	// NOTE(review): CodeChallengeMethod.String is defined outside this file.
	// Any method for which it returns "" skips the allow-list check below;
	// confirm that is intended.
	if ccm.String() != "" && !s.CheckCodeChallengeMethod(ccm) {
		return nil, errors.ErrUnsupportedCodeChallengeMethod
	}

	req := &AuthorizeRequest{
		RedirectURI:         redirectURI,
		ResponseType:        resType,
		ClientID:            clientID,
		State:               r.FormValue("state"),
		Scope:               r.FormValue("scope"),
		Request:             r,
		CodeChallenge:       cc,
		CodeChallengeMethod: ccm,
	}
	return req, nil
}

// GetAuthorizeToken generates the authorization token (code) for req.
//
// When ClientAuthorizedHandler is set, the client must be allowed to use the
// grant type implied by the response type (implicit for oauth2.Token,
// authorization code otherwise), else errors.ErrUnauthorizedClient is
// returned. When ClientScopeHandler is set, the client must be allowed the
// requested scope, else errors.ErrInvalidScope is returned. Handler errors
// are returned unchanged.
func (s *Server) GetAuthorizeToken(ctx context.Context, req *AuthorizeRequest) (oauth2.TokenInfo, error) {
	// Check that the client is allowed to use the grant type.
	if fn := s.ClientAuthorizedHandler; fn != nil {
		gt := oauth2.AuthorizationCode
		if req.ResponseType == oauth2.Token {
			gt = oauth2.Implicit
		}

		allowed, err := fn(req.ClientID, gt)
		if err != nil {
			return nil, err
		}
		if !allowed {
			return nil, errors.ErrUnauthorizedClient
		}
	}

	// Check that the client is allowed the requested scope.
	if fn := s.ClientScopeHandler; fn != nil {
		allowed, err := fn(req.ClientID, req.Scope)
		if err != nil {
			return nil, err
		}
		if !allowed {
			return nil, errors.ErrInvalidScope
		}
	}

	tgr := &oauth2.TokenGenerateRequest{
		ClientID:            req.ClientID,
		UserID:              req.UserID,
		RedirectURI:         req.RedirectURI,
		Scope:               req.Scope,
		AccessTokenExp:      req.AccessTokenExp,
		Request:             req.Request,
		CodeChallenge:       req.CodeChallenge,
		CodeChallengeMethod: req.CodeChallengeMethod,
	}
	return s.Manager.GenerateAuthToken(ctx, req.ResponseType, tgr)
}

// GetAuthorizeData returns the response data for an authorization request:
// only the "code" field for the oauth2.Code response type, and the full
// token data (see GetTokenData) for any other response type.
func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) map[string]interface{} {
	if rt == oauth2.Code {
		return map[string]interface{}{
			"code": ti.GetCode(),
		}
	}
	return s.GetTokenData(ti)
}

// HandleAuthorizeRequest handles an authorization request end to end:
// validation, user authorization, optional scope and expiry overrides, token
// generation and the final redirect.
//
// Validation, user-authorization and token-generation errors are reported
// through a redirect when a parsed request is available; errors from
// AuthorizeScopeHandler, AccessTokenExpHandler and the client lookup are
// returned to the caller directly.
func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request) error {
	ctx := r.Context()

	req, err := s.ValidationAuthorizeRequest(r)
	if err != nil {
		return s.redirectError(w, req, err)
	}

	// User authorization.
	userID, err := s.UserAuthorizationHandler(w, r)
	if err != nil {
		return s.redirectError(w, req, err)
	}
	if userID == "" {
		// No user and no error: nothing more is written here.
		// NOTE(review): presumably the handler has already responded
		// (for example with a login redirect) — confirm against the handlers.
		return nil
	}
	req.UserID = userID

	// Let the application override the scope of the authorization.
	if fn := s.AuthorizeScopeHandler; fn != nil {
		scope, err := fn(w, r)
		if err != nil {
			return err
		}
		if scope != "" {
			req.Scope = scope
		}
	}

	// Let the application specify the expiration time of the access token.
	if fn := s.AccessTokenExpHandler; fn != nil {
		exp, err := fn(w, r)
		if err != nil {
			return err
		}
		req.AccessTokenExp = exp
	}

	ti, err := s.GetAuthorizeToken(ctx, req)
	if err != nil {
		return s.redirectError(w, req, err)
	}

	// If the redirect URI is empty, the default domain provided by the
	// client is used.
	if req.RedirectURI == "" {
		client, err := s.Manager.GetClient(ctx, req.ClientID)
		if err != nil {
			return err
		}
		req.RedirectURI = client.GetDomain()
	}

	return s.redirect(w, req, s.GetAuthorizeData(req.ResponseType, ti))
}

// ValidationTokenRequest validates a token request and returns its grant
// type together with the token generation request built from the form.
//
// It returns errors.ErrInvalidRequest for a disallowed HTTP method (POST is
// always allowed, GET only with Config.AllowGetAccessRequest), for a missing
// code_verifier when Config.ForcePKCE is set, and for missing grant-specific
// parameters. It returns errors.ErrUnsupportedGrantType when grant_type has
// an empty String(), and errors.ErrInvalidGrant when the password handler
// returns an empty user ID. Errors from ClientInfoHandler and
// PasswordAuthorizationHandler are returned unchanged.
func (s *Server) ValidationTokenRequest(r *http.Request) (oauth2.GrantType, *oauth2.TokenGenerateRequest, error) {
	methodAllowed := r.Method == http.MethodPost ||
		(s.Config.AllowGetAccessRequest && r.Method == http.MethodGet)
	if !methodAllowed {
		return "", nil, errors.ErrInvalidRequest
	}

	gt := oauth2.GrantType(r.FormValue("grant_type"))
	if gt.String() == "" {
		return "", nil, errors.ErrUnsupportedGrantType
	}

	// NOTE(review): with ForcePKCE set the verifier is required for every
	// grant type, not only authorization_code — confirm that is intended.
	codeVer := r.FormValue("code_verifier")
	if s.Config.ForcePKCE && codeVer == "" {
		return "", nil, errors.ErrInvalidRequest
	}

	clientID, clientSecret, err := s.ClientInfoHandler(r)
	if err != nil {
		return "", nil, err
	}

	tgr := &oauth2.TokenGenerateRequest{
		ClientID:     clientID,
		ClientSecret: clientSecret,
		Request:      r,
	}

	switch gt {
	case oauth2.AuthorizationCode:
		tgr.RedirectURI = r.FormValue("redirect_uri")
		tgr.Code = r.FormValue("code")
		if tgr.RedirectURI == "" || tgr.Code == "" {
			return "", nil, errors.ErrInvalidRequest
		}
		tgr.CodeVerifier = codeVer
	case oauth2.PasswordCredentials:
		tgr.Scope = r.FormValue("scope")
		username, password := r.FormValue("username"), r.FormValue("password")
		if username == "" || password == "" {
			return "", nil, errors.ErrInvalidRequest
		}

		userID, err := s.PasswordAuthorizationHandler(username, password)
		if err != nil {
			return "", nil, err
		}
		if userID == "" {
			return "", nil, errors.ErrInvalidGrant
		}
		tgr.UserID = userID
	case oauth2.ClientCredentials:
		tgr.Scope = r.FormValue("scope")
	case oauth2.Refreshing:
		tgr.Refresh = r.FormValue("refresh_token")
		tgr.Scope = r.FormValue("scope")
		if tgr.Refresh == "" {
			return "", nil, errors.ErrInvalidRequest
		}
	}
	return gt, tgr, nil
}

// CheckGrantType reports whether gt is listed in Config.AllowedGrantTypes.
func (s *Server) CheckGrantType(gt oauth2.GrantType) bool {
	for _, agt := range s.Config.AllowedGrantTypes {
		if agt == gt {
			return true
		}
	}
	return false
}

// GetAccessToken issues an access token for the given grant type.
//
// It returns errors.ErrUnauthorizedClient when the grant type is not allowed
// by the configuration or by ClientAuthorizedHandler,
// errors.ErrInvalidScope when a scope or refresh validation handler rejects
// the request, errors.ErrInvalidGrant for invalid authorization codes, code
// challenges and invalid or expired refresh tokens, and
// errors.ErrUnsupportedGrantType for a grant type this method does not
// handle. Other handler and manager errors are returned unchanged.
func (s *Server) GetAccessToken(ctx context.Context, gt oauth2.GrantType, tgr *oauth2.TokenGenerateRequest) (oauth2.TokenInfo, error) {
	if !s.CheckGrantType(gt) {
		return nil, errors.ErrUnauthorizedClient
	}

	if fn := s.ClientAuthorizedHandler; fn != nil {
		allowed, err := fn(tgr.ClientID, gt)
		if err != nil {
			return nil, err
		}
		if !allowed {
			return nil, errors.ErrUnauthorizedClient
		}
	}

	switch gt {
	case oauth2.AuthorizationCode:
		ti, err := s.Manager.GenerateAccessToken(ctx, gt, tgr)
		if err != nil {
			switch err {
			case errors.ErrInvalidAuthorizeCode, errors.ErrInvalidCodeChallenge, errors.ErrMissingCodeChallenge:
				return nil, errors.ErrInvalidGrant
			case errors.ErrInvalidClient:
				return nil, errors.ErrInvalidClient
			default:
				return nil, err
			}
		}
		return ti, nil

	case oauth2.PasswordCredentials, oauth2.ClientCredentials:
		if fn := s.ClientScopeHandler; fn != nil {
			allowed, err := fn(tgr.ClientID, tgr.Scope)
			if err != nil {
				return nil, err
			}
			if !allowed {
				return nil, errors.ErrInvalidScope
			}
		}
		return s.Manager.GenerateAccessToken(ctx, gt, tgr)

	case oauth2.Refreshing:
		// Invalid and expired refresh tokens are both reported to the
		// client as an invalid grant.
		asGrantError := func(err error) error {
			if err == errors.ErrInvalidRefreshToken || err == errors.ErrExpiredRefreshToken {
				return errors.ErrInvalidGrant
			}
			return err
		}

		// Both optional handlers need the stored refresh token; load it
		// lazily and at most once.
		var stored oauth2.TokenInfo
		loadRefresh := func() (oauth2.TokenInfo, error) {
			if stored != nil {
				return stored, nil
			}
			rti, err := s.Manager.LoadRefreshToken(ctx, tgr.Refresh)
			if err != nil {
				return nil, asGrantError(err)
			}
			stored = rti
			return rti, nil
		}

		// Check the requested scope against the scope of the stored token.
		if scope, scopeFn := tgr.Scope, s.RefreshingScopeHandler; scope != "" && scopeFn != nil {
			rti, err := loadRefresh()
			if err != nil {
				return nil, err
			}

			allowed, err := scopeFn(scope, rti.GetScope())
			if err != nil {
				return nil, err
			}
			if !allowed {
				return nil, errors.ErrInvalidScope
			}
		}

		if validationFn := s.RefreshingValidationHandler; validationFn != nil {
			rti, err := loadRefresh()
			if err != nil {
				return nil, err
			}

			allowed, err := validationFn(rti)
			if err != nil {
				return nil, err
			}
			if !allowed {
				return nil, errors.ErrInvalidScope
			}
		}

		ti, err := s.Manager.RefreshAccessToken(ctx, tgr)
		if err != nil {
			return nil, asGrantError(err)
		}
		return ti, nil
	}

	return nil, errors.ErrUnsupportedGrantType
}

// GetTokenData returns the token response fields for ti.
//
// "access_token", "token_type" and "expires_in" (whole seconds) are always
// present; "scope" and "refresh_token" are included only when non-empty.
// Fields returned by ExtensionFieldsHandler are added unless they would
// overwrite one of the fields above.
func (s *Server) GetTokenData(ti oauth2.TokenInfo) map[string]interface{} {
	data := map[string]interface{}{
		"access_token": ti.GetAccess(),
		"token_type":   s.Config.TokenType,
		"expires_in":   int64(ti.GetAccessExpiresIn() / time.Second),
	}

	if scope := ti.GetScope(); scope != "" {
		data["scope"] = scope
	}

	if refresh := ti.GetRefresh(); refresh != "" {
		data["refresh_token"] = refresh
	}

	if fn := s.ExtensionFieldsHandler; fn != nil {
		for k, v := range fn(ti) {
			// Extension fields never replace the standard fields.
			if _, exists := data[k]; exists {
				continue
			}
			data[k] = v
		}
	}
	return data
}

// HandleTokenRequest handles a token request end to end and writes either
// the token data or a token error response to w.
func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) error {
	ctx := r.Context()

	gt, tgr, err := s.ValidationTokenRequest(r)
	if err != nil {
		return s.tokenError(w, err)
	}

	ti, err := s.GetAccessToken(ctx, gt, tgr)
	if err != nil {
		return s.tokenError(w, err)
	}

	return s.token(w, s.GetTokenData(ti), nil)
}

// GetErrorData converts err into error response fields, an HTTP status code
// and response headers.
//
// Errors that have an entry in errors.Descriptions are reported with that
// description and the status code from errors.StatusCodes. Any other error
// is passed to InternalErrorHandler when set; if that yields no response
// error, errors.ErrServerError is reported instead. ResponseErrorHandler,
// when set, can then adjust the response. A non-positive status code is
// replaced by 500.
func (s *Server) GetErrorData(err error) (map[string]interface{}, int, http.Header) {
	var re errors.Response
	if v, ok := errors.Descriptions[err]; ok {
		re.Error = err
		re.Description = v
		re.StatusCode = errors.StatusCodes[err]
	} else {
		if fn := s.InternalErrorHandler; fn != nil {
			if v := fn(err); v != nil {
				re = *v
			}
		}

		if re.Error == nil {
			re.Error = errors.ErrServerError
			re.Description = errors.Descriptions[errors.ErrServerError]
			re.StatusCode = errors.StatusCodes[errors.ErrServerError]
		}
	}

	if fn := s.ResponseErrorHandler; fn != nil {
		fn(&re)
	}

	data := make(map[string]interface{})
	if err := re.Error; err != nil {
		data["error"] = err.Error()
	}

	if v := re.ErrorCode; v != 0 {
		data["error_code"] = v
	}

	if v := re.Description; v != "" {
		data["error_description"] = v
	}

	if v := re.URI; v != "" {
		data["error_uri"] = v
	}

	statusCode := http.StatusInternalServerError
	if v := re.StatusCode; v > 0 {
		statusCode = v
	}

	return data, statusCode, re.Header
}

// BearerAuth extracts the bearer token from r.
//
// The token is taken from an Authorization header that starts with
// "Bearer " and otherwise from the "access_token" form value. The boolean
// result reports whether a non-empty token was found.
func (s *Server) BearerAuth(r *http.Request) (string, bool) {
	const prefix = "Bearer "

	var token string
	if auth := r.Header.Get("Authorization"); strings.HasPrefix(auth, prefix) {
		token = auth[len(prefix):]
	} else {
		// The form is only consulted when the header carries no bearer token.
		token = r.FormValue("access_token")
	}

	return token, token != ""
}

// ValidationBearerToken validates the bearer token carried by r and returns
// the access token information loaded from the manager. It returns
// errors.ErrInvalidAccessToken when the request carries no token.
//
// See https://tools.ietf.org/html/rfc6750.
func (s *Server) ValidationBearerToken(r *http.Request) (oauth2.TokenInfo, error) {
	ctx := r.Context()

	accessToken, ok := s.BearerAuth(r)
	if !ok {
		return nil, errors.ErrInvalidAccessToken
	}

	return s.Manager.LoadAccessToken(ctx, accessToken)
}