package runtime

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"net/textproto"
	"regexp"
	"strings"

	"github.com/grpc-ecosystem/grpc-gateway/v2/internal/httprule"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/grpclog"
	"google.golang.org/grpc/health/grpc_health_v1"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/proto"
)

// UnescapingMode defines the behavior of ServeMux when unescaping path parameters.
type UnescapingMode int

const (
	// UnescapingModeLegacy is the default V2 behavior: routing is performed on
	// r.URL.Path, i.e. on a path string that has already been fully unescaped
	// before any pattern matching takes place.
	UnescapingModeLegacy UnescapingMode = iota

	// UnescapingModeAllExceptReserved unescapes all path parameters except RFC 6570
	// reserved characters.
	UnescapingModeAllExceptReserved

	// UnescapingModeAllExceptSlash unescapes URL path parameters except path
	// separators, which will be left as "%2F".
	UnescapingModeAllExceptSlash

	// UnescapingModeAllCharacters unescapes all URL path parameters.
	UnescapingModeAllCharacters

	// UnescapingModeDefault is the default escaping type.
	// TODO(v3): default this to UnescapingModeAllExceptReserved per grpc-httpjson-transcoding's
	// reference implementation
	UnescapingModeDefault = UnescapingModeLegacy
)

// encodedPathSplitter splits a path on a literal "/" as well as on the
// percent-encoded separator "%2F". The match is case-sensitive, so a
// lower-case "%2f" is not treated as a separator. ServeHTTP only uses it in
// UnescapingModeAllCharacters.
var encodedPathSplitter = regexp.MustCompile("(/|%2F)")

// A HandlerFunc handles a specific pair of path pattern and HTTP method.
// pathParams holds the values captured from the request path by the pattern.
type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string)

// A Middleware handler wraps another HandlerFunc to do some pre- and/or
// post-processing of the request. This is used as an alternative to gRPC
// interceptors when using the direct-to-implementation registration methods.
// It is generally recommended to use gRPC client or server interceptors
// instead where possible.
type Middleware func(HandlerFunc) HandlerFunc

// ServeMux is a request multiplexer for grpc-gateway.
// It matches http requests to patterns and invokes the corresponding handler.
type ServeMux struct {
	// handlers maps HTTP method to a list of handlers. Handle prepends new
	// entries, so the most recently registered handler for a method is tried
	// first by ServeHTTP.
	handlers map[string][]handler

	// middlewares are wrapped around every HandlerFunc at the time it is
	// registered through Handle; populated by WithMiddlewares.
	middlewares []Middleware

	// forwardResponseOptions are the callbacks appended by
	// WithForwardResponseOption and exposed via GetForwardResponseOptions.
	forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error

	// forwardResponseRewriter is set by WithForwardResponseRewriter;
	// NewServeMux installs a pass-through default that returns the message unchanged.
	forwardResponseRewriter ForwardResponseRewriter

	// marshalers is initialized by NewServeMux via makeMarshalerMIMERegistry.
	marshalers marshalerRegistry

	// incomingHeaderMatcher defaults to DefaultHeaderMatcher when no
	// WithIncomingHeaderMatcher option supplies one.
	incomingHeaderMatcher HeaderMatcherFunc

	// outgoingHeaderMatcher defaults to defaultOutgoingHeaderMatcher when left nil.
	outgoingHeaderMatcher HeaderMatcherFunc

	// outgoingTrailerMatcher defaults to defaultOutgoingTrailerMatcher when left nil.
	outgoingTrailerMatcher HeaderMatcherFunc

	// metadataAnnotators are appended by WithMetadata.
	metadataAnnotators []func(context.Context, *http.Request) metadata.MD

	// errorHandler defaults to DefaultHTTPErrorHandler; replaced by WithErrorHandler.
	errorHandler ErrorHandlerFunc

	// streamErrorHandler defaults to DefaultStreamErrorHandler; replaced by
	// WithStreamErrorHandler.
	streamErrorHandler StreamErrorHandlerFunc

	// routingErrorHandler defaults to DefaultRoutingErrorHandler; replaced by
	// WithRoutingErrorHandler. ServeHTTP calls it with an HTTP status code.
	routingErrorHandler RoutingErrorHandlerFunc

	// disablePathLengthFallback, when true, makes isPathLengthFallback always
	// report false; set by WithDisablePathLengthFallback.
	disablePathLengthFallback bool

	// unescapingMode defaults to UnescapingModeDefault; replaced by WithUnescapingMode.
	unescapingMode UnescapingMode
}

// ServeMuxOption is an option that can be given to a ServeMux on construction.
// NewServeMux applies options in the order they are passed.
type ServeMuxOption func(*ServeMux)

// ForwardResponseRewriter is the signature of a function that is capable of rewriting messages
// before they are forwarded in a unary, stream, or error response.
type ForwardResponseRewriter func(ctx context.Context, response proto.Message) (any, error)

// WithForwardResponseRewriter returns a ServeMuxOption that allows for implementers to insert logic
// that can rewrite the final response before it is forwarded.
//
// The response rewriter function is called during unary message forwarding, stream message
// forwarding and when errors are being forwarded.
//
// The option overwrites any previously configured rewriter, including the
// pass-through default installed by NewServeMux.
//
// NOTE: Using this option will likely make what is generated by `protoc-gen-openapiv2` incorrect.
// Since this option involves making runtime changes to the response shape or type.
func WithForwardResponseRewriter(fwdResponseRewriter ForwardResponseRewriter) ServeMuxOption {
	return func(sm *ServeMux) {
		sm.forwardResponseRewriter = fwdResponseRewriter
	}
}

// WithForwardResponseOption returns a ServeMuxOption representing the forwardResponseOption.
//
// forwardResponseOption is an option that will be called on the relevant context.Context,
// http.ResponseWriter, and proto.Message before every forwarded response.
// // The message may be nil in the case where just a header is being sent. func WithForwardResponseOption(forwardResponseOption func(context.Context, http.ResponseWriter, proto.Message) error) ServeMuxOption { return func(serveMux *ServeMux) { serveMux.forwardResponseOptions = append(serveMux.forwardResponseOptions, forwardResponseOption) } } // WithUnescapingMode sets the escaping type. See the definitions of UnescapingMode // for more information. func WithUnescapingMode(mode UnescapingMode) ServeMuxOption { return func(serveMux *ServeMux) { serveMux.unescapingMode = mode } } // WithMiddlewares sets server middleware for all handlers. This is useful as an alternative to gRPC // interceptors when using the direct-to-implementation registration methods and cannot rely // on gRPC interceptors. It's recommended to use gRPC interceptors instead if possible. func WithMiddlewares(middlewares ...Middleware) ServeMuxOption { return func(serveMux *ServeMux) { serveMux.middlewares = append(serveMux.middlewares, middlewares...) } } // SetQueryParameterParser sets the query parameter parser, used to populate message from query parameters. // Configuring this will mean the generated OpenAPI output is no longer correct, and it should be // done with careful consideration. func SetQueryParameterParser(queryParameterParser QueryParameterParser) ServeMuxOption { return func(serveMux *ServeMux) { currentQueryParser = queryParameterParser } } // HeaderMatcherFunc checks whether a header key should be forwarded to/from gRPC context. type HeaderMatcherFunc func(string) (string, bool) // DefaultHeaderMatcher is used to pass http request headers to/from gRPC context. This adds permanent HTTP header // keys (as specified by the IANA, e.g: Accept, Cookie, Host) to the gRPC metadata with the grpcgateway- prefix. If you want to know which headers are considered permanent, you can view the isPermanentHTTPHeader function. 
// HTTP headers that start with 'Grpc-Metadata-' are mapped to gRPC metadata after removing the prefix 'Grpc-Metadata-'. // Other headers are not added to the gRPC metadata. func DefaultHeaderMatcher(key string) (string, bool) { switch key = textproto.CanonicalMIMEHeaderKey(key); { case isPermanentHTTPHeader(key): return MetadataPrefix + key, true case strings.HasPrefix(key, MetadataHeaderPrefix): return key[len(MetadataHeaderPrefix):], true } return "", false } func defaultOutgoingHeaderMatcher(key string) (string, bool) { return fmt.Sprintf("%s%s", MetadataHeaderPrefix, key), true } func defaultOutgoingTrailerMatcher(key string) (string, bool) { return fmt.Sprintf("%s%s", MetadataTrailerPrefix, key), true } // WithIncomingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for incoming request to gateway. // // This matcher will be called with each header in http.Request. If matcher returns true, that header will be // passed to gRPC context. To transform the header before passing to gRPC context, matcher should return the modified header. func WithIncomingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption { for _, header := range fn.matchedMalformedHeaders() { grpclog.Warningf("The configured forwarding filter would allow %q to be sent to the gRPC server, which will likely cause errors. See https://github.com/grpc/grpc-go/pull/4803#issuecomment-986093310 for more information.", header) } return func(mux *ServeMux) { mux.incomingHeaderMatcher = fn } } // matchedMalformedHeaders returns the malformed headers that would be forwarded to gRPC server. func (fn HeaderMatcherFunc) matchedMalformedHeaders() []string { if fn == nil { return nil } headers := make([]string, 0) for header := range malformedHTTPHeaders { out, accept := fn(header) if accept && isMalformedHTTPHeader(out) { headers = append(headers, out) } } return headers } // WithOutgoingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for outgoing response from gateway. 
// // This matcher will be called with each header in response header metadata. If matcher returns true, that header will be // passed to http response returned from gateway. To transform the header before passing to response, // matcher should return the modified header. func WithOutgoingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption { return func(mux *ServeMux) { mux.outgoingHeaderMatcher = fn } } // WithOutgoingTrailerMatcher returns a ServeMuxOption representing a headerMatcher for outgoing response from gateway. // // This matcher will be called with each header in response trailer metadata. If matcher returns true, that header will be // passed to http response returned from gateway. To transform the header before passing to response, // matcher should return the modified header. func WithOutgoingTrailerMatcher(fn HeaderMatcherFunc) ServeMuxOption { return func(mux *ServeMux) { mux.outgoingTrailerMatcher = fn } } // WithMetadata returns a ServeMuxOption for passing metadata to a gRPC context. // // This can be used by services that need to read from http.Request and modify gRPC context. A common use case // is reading token from cookie and adding it in gRPC context. func WithMetadata(annotator func(context.Context, *http.Request) metadata.MD) ServeMuxOption { return func(serveMux *ServeMux) { serveMux.metadataAnnotators = append(serveMux.metadataAnnotators, annotator) } } // WithErrorHandler returns a ServeMuxOption for configuring a custom error handler. // // This can be used to configure a custom error response. func WithErrorHandler(fn ErrorHandlerFunc) ServeMuxOption { return func(serveMux *ServeMux) { serveMux.errorHandler = fn } } // WithStreamErrorHandler returns a ServeMuxOption that will use the given custom stream // error handler, which allows for customizing the error trailer for server-streaming // calls. // // For stream errors that occur before any response has been written, the mux's // ErrorHandler will be invoked. 
However, once data has been written, the errors must // be handled differently: they must be included in the response body. The response body's // final message will include the error details returned by the stream error handler. func WithStreamErrorHandler(fn StreamErrorHandlerFunc) ServeMuxOption { return func(serveMux *ServeMux) { serveMux.streamErrorHandler = fn } } // WithRoutingErrorHandler returns a ServeMuxOption for configuring a custom error handler to handle http routing errors. // // Method called for errors which can happen before gRPC route selected or executed. // The following error codes: StatusMethodNotAllowed StatusNotFound StatusBadRequest func WithRoutingErrorHandler(fn RoutingErrorHandlerFunc) ServeMuxOption { return func(serveMux *ServeMux) { serveMux.routingErrorHandler = fn } } // WithDisablePathLengthFallback returns a ServeMuxOption for disable path length fallback. func WithDisablePathLengthFallback() ServeMuxOption { return func(serveMux *ServeMux) { serveMux.disablePathLengthFallback = true } } // WithHealthEndpointAt returns a ServeMuxOption that will add an endpoint to the created ServeMux at the path specified by endpointPath. // When called the handler will forward the request to the upstream grpc service health check (defined in the // gRPC Health Checking Protocol). // // See here https://grpc-ecosystem.github.io/grpc-gateway/docs/operations/health_check/ for more information on how // to setup the protocol in the grpc server. // // If you define a service as query parameter, this will also be forwarded as service in the HealthCheckRequest. 
func WithHealthEndpointAt(healthCheckClient grpc_health_v1.HealthClient, endpointPath string) ServeMuxOption { return func(s *ServeMux) { // error can be ignored since pattern is definitely valid _ = s.HandlePath( http.MethodGet, endpointPath, func(w http.ResponseWriter, r *http.Request, _ map[string]string, ) { _, outboundMarshaler := MarshalerForRequest(s, r) resp, err := healthCheckClient.Check(r.Context(), &grpc_health_v1.HealthCheckRequest{ Service: r.URL.Query().Get("service"), }) if err != nil { s.errorHandler(r.Context(), s, outboundMarshaler, w, r, err) return } w.Header().Set("Content-Type", "application/json") if resp.GetStatus() != grpc_health_v1.HealthCheckResponse_SERVING { switch resp.GetStatus() { case grpc_health_v1.HealthCheckResponse_NOT_SERVING, grpc_health_v1.HealthCheckResponse_UNKNOWN: err = status.Error(codes.Unavailable, resp.String()) case grpc_health_v1.HealthCheckResponse_SERVICE_UNKNOWN: err = status.Error(codes.NotFound, resp.String()) } s.errorHandler(r.Context(), s, outboundMarshaler, w, r, err) return } _ = outboundMarshaler.NewEncoder(w).Encode(resp) }) } } // WithHealthzEndpoint returns a ServeMuxOption that will add a /healthz endpoint to the created ServeMux. // // See WithHealthEndpointAt for the general implementation. func WithHealthzEndpoint(healthCheckClient grpc_health_v1.HealthClient) ServeMuxOption { return WithHealthEndpointAt(healthCheckClient, "/healthz") } // NewServeMux returns a new ServeMux whose internal mapping is empty. 
func NewServeMux(opts ...ServeMuxOption) *ServeMux { serveMux := &ServeMux{ handlers: make(map[string][]handler), forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0), forwardResponseRewriter: func(ctx context.Context, response proto.Message) (any, error) { return response, nil }, marshalers: makeMarshalerMIMERegistry(), errorHandler: DefaultHTTPErrorHandler, streamErrorHandler: DefaultStreamErrorHandler, routingErrorHandler: DefaultRoutingErrorHandler, unescapingMode: UnescapingModeDefault, } for _, opt := range opts { opt(serveMux) } if serveMux.incomingHeaderMatcher == nil { serveMux.incomingHeaderMatcher = DefaultHeaderMatcher } if serveMux.outgoingHeaderMatcher == nil { serveMux.outgoingHeaderMatcher = defaultOutgoingHeaderMatcher } if serveMux.outgoingTrailerMatcher == nil { serveMux.outgoingTrailerMatcher = defaultOutgoingTrailerMatcher } return serveMux } // Handle associates "h" to the pair of HTTP method and path pattern. func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) { if len(s.middlewares) > 0 { h = chainMiddlewares(s.middlewares)(h) } s.handlers[meth] = append([]handler{{pat: pat, h: h}}, s.handlers[meth]...) } // HandlePath allows users to configure custom path handlers. // refer: https://grpc-ecosystem.github.io/grpc-gateway/docs/operations/inject_router/ func (s *ServeMux) HandlePath(meth string, pathPattern string, h HandlerFunc) error { compiler, err := httprule.Parse(pathPattern) if err != nil { return fmt.Errorf("parsing path pattern: %w", err) } tp := compiler.Compile() pattern, err := NewPattern(tp.Version, tp.OpCodes, tp.Pool, tp.Verb) if err != nil { return fmt.Errorf("creating new pattern: %w", err) } s.Handle(meth, pattern, h) return nil } // ServeHTTP dispatches the request to the first handler whose pattern matches to r.Method and r.URL.Path. 
func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx := r.Context() path := r.URL.Path if !strings.HasPrefix(path, "/") { _, outboundMarshaler := MarshalerForRequest(s, r) s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusBadRequest) return } // TODO(v3): remove UnescapingModeLegacy if s.unescapingMode != UnescapingModeLegacy && r.URL.RawPath != "" { path = r.URL.RawPath } if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && s.isPathLengthFallback(r) { if err := r.ParseForm(); err != nil { _, outboundMarshaler := MarshalerForRequest(s, r) sterr := status.Error(codes.InvalidArgument, err.Error()) s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr) return } r.Method = strings.ToUpper(override) } var pathComponents []string // since in UnescapeModeLegacy, the URL will already have been fully unescaped, if we also split on "%2F" // in this escaping mode we would be double unescaping but in UnescapingModeAllCharacters, we still do as the // path is the RawPath (i.e. unescaped). That does mean that the behavior of this function will change its default // behavior when the UnescapingModeDefault gets changed from UnescapingModeLegacy to UnescapingModeAllExceptReserved if s.unescapingMode == UnescapingModeAllCharacters { pathComponents = encodedPathSplitter.Split(path[1:], -1) } else { pathComponents = strings.Split(path[1:], "/") } lastPathComponent := pathComponents[len(pathComponents)-1] for _, h := range s.handlers[r.Method] { // If the pattern has a verb, explicitly look for a suffix in the last // component that matches a colon plus the verb. This allows us to // handle some cases that otherwise can't be correctly handled by the // former LastIndex case, such as when the verb literal itself contains // a colon. This should work for all cases that have run through the // parser because we know what verb we're looking for, however, there // are still some cases that the parser itself cannot disambiguate. 
See // the comment there if interested. var verb string patVerb := h.pat.Verb() idx := -1 if patVerb != "" && strings.HasSuffix(lastPathComponent, ":"+patVerb) { idx = len(lastPathComponent) - len(patVerb) - 1 } if idx == 0 { _, outboundMarshaler := MarshalerForRequest(s, r) s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusNotFound) return } comps := make([]string, len(pathComponents)) copy(comps, pathComponents) if idx > 0 { comps[len(comps)-1], verb = lastPathComponent[:idx], lastPathComponent[idx+1:] } pathParams, err := h.pat.MatchAndEscape(comps, verb, s.unescapingMode) if err != nil { var mse MalformedSequenceError if ok := errors.As(err, &mse); ok { _, outboundMarshaler := MarshalerForRequest(s, r) s.errorHandler(ctx, s, outboundMarshaler, w, r, &HTTPStatusError{ HTTPStatus: http.StatusBadRequest, Err: mse, }) } continue } s.handleHandler(h, w, r, pathParams) return } // if no handler has found for the request, lookup for other methods // to handle POST -> GET fallback if the request is subject to path // length fallback. // Note we are not eagerly checking the request here as we want to return the // right HTTP status code, and we need to process the fallback candidates in // order to do that. 
for m, handlers := range s.handlers { if m == r.Method { continue } for _, h := range handlers { var verb string patVerb := h.pat.Verb() idx := -1 if patVerb != "" && strings.HasSuffix(lastPathComponent, ":"+patVerb) { idx = len(lastPathComponent) - len(patVerb) - 1 } comps := make([]string, len(pathComponents)) copy(comps, pathComponents) if idx > 0 { comps[len(comps)-1], verb = lastPathComponent[:idx], lastPathComponent[idx+1:] } pathParams, err := h.pat.MatchAndEscape(comps, verb, s.unescapingMode) if err != nil { var mse MalformedSequenceError if ok := errors.As(err, &mse); ok { _, outboundMarshaler := MarshalerForRequest(s, r) s.errorHandler(ctx, s, outboundMarshaler, w, r, &HTTPStatusError{ HTTPStatus: http.StatusBadRequest, Err: mse, }) } continue } // X-HTTP-Method-Override is optional. Always allow fallback to POST. // Also, only consider POST -> GET fallbacks, and avoid falling back to // potentially dangerous operations like DELETE. if s.isPathLengthFallback(r) && m == http.MethodGet { if err := r.ParseForm(); err != nil { _, outboundMarshaler := MarshalerForRequest(s, r) sterr := status.Error(codes.InvalidArgument, err.Error()) s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr) return } s.handleHandler(h, w, r, pathParams) return } _, outboundMarshaler := MarshalerForRequest(s, r) s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusMethodNotAllowed) return } } _, outboundMarshaler := MarshalerForRequest(s, r) s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusNotFound) } // GetForwardResponseOptions returns the ForwardResponseOptions associated with this ServeMux. 
func (s *ServeMux) GetForwardResponseOptions() []func(context.Context, http.ResponseWriter, proto.Message) error { return s.forwardResponseOptions } func (s *ServeMux) isPathLengthFallback(r *http.Request) bool { return !s.disablePathLengthFallback && r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded" } type handler struct { pat Pattern h HandlerFunc } func (s *ServeMux) handleHandler(h handler, w http.ResponseWriter, r *http.Request, pathParams map[string]string) { h.h(w, r.WithContext(withHTTPPattern(r.Context(), h.pat)), pathParams) } func chainMiddlewares(mws []Middleware) Middleware { return func(next HandlerFunc) HandlerFunc { for i := len(mws); i > 0; i-- { next = mws[i-1](next) } return next } }