mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2024-12-05 10:02:47 +00:00
181 lines
6.1 KiB
Go
181 lines
6.1 KiB
Go
|
package runtime
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"errors"
|
||
|
"io"
|
||
|
"net/http"
|
||
|
|
||
|
"google.golang.org/grpc/codes"
|
||
|
"google.golang.org/grpc/grpclog"
|
||
|
"google.golang.org/grpc/status"
|
||
|
)
|
||
|
|
||
|
// ErrorHandlerFunc is the signature used to configure error handling.
|
||
|
type ErrorHandlerFunc func(context.Context, *ServeMux, Marshaler, http.ResponseWriter, *http.Request, error)
|
||
|
|
||
|
// StreamErrorHandlerFunc is the signature used to configure stream error handling.
|
||
|
type StreamErrorHandlerFunc func(context.Context, error) *status.Status
|
||
|
|
||
|
// RoutingErrorHandlerFunc is the signature used to configure error handling for routing errors.
|
||
|
type RoutingErrorHandlerFunc func(context.Context, *ServeMux, Marshaler, http.ResponseWriter, *http.Request, int)
|
||
|
|
||
|
// HTTPStatusError is the error to use when needing to provide a different HTTP status code for an error
// passed to the DefaultRoutingErrorHandler.
type HTTPStatusError struct {
	// HTTPStatus is the HTTP status code that DefaultHTTPErrorHandler replies
	// with instead of the one derived from the gRPC code of Err.
	HTTPStatus int
	// Err is the wrapped error. It may be nil.
	Err error
}

// Error implements the error interface.
// It returns the message of the wrapped error. When Err is nil there is no
// message to forward, so it falls back to the standard text for HTTPStatus
// (which is empty for unknown codes) rather than dereferencing a nil error.
func (e *HTTPStatusError) Error() string {
	if e.Err == nil {
		return http.StatusText(e.HTTPStatus)
	}
	return e.Err.Error()
}
|
||
|
|
||
|
// HTTPStatusFromCode converts a gRPC error code into the corresponding HTTP response status.
|
||
|
// See: https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto
|
||
|
func HTTPStatusFromCode(code codes.Code) int {
|
||
|
switch code {
|
||
|
case codes.OK:
|
||
|
return http.StatusOK
|
||
|
case codes.Canceled:
|
||
|
return http.StatusRequestTimeout
|
||
|
case codes.Unknown:
|
||
|
return http.StatusInternalServerError
|
||
|
case codes.InvalidArgument:
|
||
|
return http.StatusBadRequest
|
||
|
case codes.DeadlineExceeded:
|
||
|
return http.StatusGatewayTimeout
|
||
|
case codes.NotFound:
|
||
|
return http.StatusNotFound
|
||
|
case codes.AlreadyExists:
|
||
|
return http.StatusConflict
|
||
|
case codes.PermissionDenied:
|
||
|
return http.StatusForbidden
|
||
|
case codes.Unauthenticated:
|
||
|
return http.StatusUnauthorized
|
||
|
case codes.ResourceExhausted:
|
||
|
return http.StatusTooManyRequests
|
||
|
case codes.FailedPrecondition:
|
||
|
// Note, this deliberately doesn't translate to the similarly named '412 Precondition Failed' HTTP response status.
|
||
|
return http.StatusBadRequest
|
||
|
case codes.Aborted:
|
||
|
return http.StatusConflict
|
||
|
case codes.OutOfRange:
|
||
|
return http.StatusBadRequest
|
||
|
case codes.Unimplemented:
|
||
|
return http.StatusNotImplemented
|
||
|
case codes.Internal:
|
||
|
return http.StatusInternalServerError
|
||
|
case codes.Unavailable:
|
||
|
return http.StatusServiceUnavailable
|
||
|
case codes.DataLoss:
|
||
|
return http.StatusInternalServerError
|
||
|
}
|
||
|
|
||
|
grpclog.Infof("Unknown gRPC error code: %v", code)
|
||
|
return http.StatusInternalServerError
|
||
|
}
|
||
|
|
||
|
// HTTPError uses the mux-configured error handler.
|
||
|
func HTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, err error) {
|
||
|
mux.errorHandler(ctx, mux, marshaler, w, r, err)
|
||
|
}
|
||
|
|
||
|
// DefaultHTTPErrorHandler is the default error handler.
|
||
|
// If "err" is a gRPC Status, the function replies with the status code mapped by HTTPStatusFromCode.
|
||
|
// If "err" is a HTTPStatusError, the function replies with the status code provide by that struct. This is
|
||
|
// intended to allow passing through of specific statuses via the function set via WithRoutingErrorHandler
|
||
|
// for the ServeMux constructor to handle edge cases which the standard mappings in HTTPStatusFromCode
|
||
|
// are insufficient for.
|
||
|
// If otherwise, it replies with http.StatusInternalServerError.
|
||
|
//
|
||
|
// The response body written by this function is a Status message marshaled by the Marshaler.
|
||
|
func DefaultHTTPErrorHandler(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, err error) {
|
||
|
// return Internal when Marshal failed
|
||
|
const fallback = `{"code": 13, "message": "failed to marshal error message"}`
|
||
|
|
||
|
var customStatus *HTTPStatusError
|
||
|
if errors.As(err, &customStatus) {
|
||
|
err = customStatus.Err
|
||
|
}
|
||
|
|
||
|
s := status.Convert(err)
|
||
|
pb := s.Proto()
|
||
|
|
||
|
w.Header().Del("Trailer")
|
||
|
w.Header().Del("Transfer-Encoding")
|
||
|
|
||
|
contentType := marshaler.ContentType(pb)
|
||
|
w.Header().Set("Content-Type", contentType)
|
||
|
|
||
|
if s.Code() == codes.Unauthenticated {
|
||
|
w.Header().Set("WWW-Authenticate", s.Message())
|
||
|
}
|
||
|
|
||
|
buf, merr := marshaler.Marshal(pb)
|
||
|
if merr != nil {
|
||
|
grpclog.Infof("Failed to marshal error message %q: %v", s, merr)
|
||
|
w.WriteHeader(http.StatusInternalServerError)
|
||
|
if _, err := io.WriteString(w, fallback); err != nil {
|
||
|
grpclog.Infof("Failed to write response: %v", err)
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
md, ok := ServerMetadataFromContext(ctx)
|
||
|
if !ok {
|
||
|
grpclog.Infof("Failed to extract ServerMetadata from context")
|
||
|
}
|
||
|
|
||
|
handleForwardResponseServerMetadata(w, mux, md)
|
||
|
|
||
|
// RFC 7230 https://tools.ietf.org/html/rfc7230#section-4.1.2
|
||
|
// Unless the request includes a TE header field indicating "trailers"
|
||
|
// is acceptable, as described in Section 4.3, a server SHOULD NOT
|
||
|
// generate trailer fields that it believes are necessary for the user
|
||
|
// agent to receive.
|
||
|
doForwardTrailers := requestAcceptsTrailers(r)
|
||
|
|
||
|
if doForwardTrailers {
|
||
|
handleForwardResponseTrailerHeader(w, md)
|
||
|
w.Header().Set("Transfer-Encoding", "chunked")
|
||
|
}
|
||
|
|
||
|
st := HTTPStatusFromCode(s.Code())
|
||
|
if customStatus != nil {
|
||
|
st = customStatus.HTTPStatus
|
||
|
}
|
||
|
|
||
|
w.WriteHeader(st)
|
||
|
if _, err := w.Write(buf); err != nil {
|
||
|
grpclog.Infof("Failed to write response: %v", err)
|
||
|
}
|
||
|
|
||
|
if doForwardTrailers {
|
||
|
handleForwardResponseTrailer(w, md)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func DefaultStreamErrorHandler(_ context.Context, err error) *status.Status {
|
||
|
return status.Convert(err)
|
||
|
}
|
||
|
|
||
|
// DefaultRoutingErrorHandler is our default handler for routing errors.
|
||
|
// By default http error codes mapped on the following error codes:
|
||
|
// NotFound -> grpc.NotFound
|
||
|
// StatusBadRequest -> grpc.InvalidArgument
|
||
|
// MethodNotAllowed -> grpc.Unimplemented
|
||
|
// Other -> grpc.Internal, method is not expecting to be called for anything else
|
||
|
func DefaultRoutingErrorHandler(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, httpStatus int) {
|
||
|
sterr := status.Error(codes.Internal, "Unexpected routing error")
|
||
|
switch httpStatus {
|
||
|
case http.StatusBadRequest:
|
||
|
sterr = status.Error(codes.InvalidArgument, http.StatusText(httpStatus))
|
||
|
case http.StatusMethodNotAllowed:
|
||
|
sterr = status.Error(codes.Unimplemented, http.StatusText(httpStatus))
|
||
|
case http.StatusNotFound:
|
||
|
sterr = status.Error(codes.NotFound, http.StatusText(httpStatus))
|
||
|
}
|
||
|
mux.errorHandler(ctx, mux, marshaler, w, r, sterr)
|
||
|
}
|