// Copyright 2015 go-swagger maintainers // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package middleware import ( "encoding" "encoding/base64" "fmt" "io" "net/http" "reflect" "strconv" "github.com/go-openapi/errors" "github.com/go-openapi/spec" "github.com/go-openapi/strfmt" "github.com/go-openapi/swag" "github.com/go-openapi/validate" "github.com/go-openapi/runtime" ) const defaultMaxMemory = 32 << 20 const ( typeString = "string" typeArray = "array" ) var textUnmarshalType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem() func newUntypedParamBinder(param spec.Parameter, spec *spec.Swagger, formats strfmt.Registry) *untypedParamBinder { binder := new(untypedParamBinder) binder.Name = param.Name binder.parameter = ¶m binder.formats = formats if param.In != "body" { binder.validator = validate.NewParamValidator(¶m, formats) } else { binder.validator = validate.NewSchemaValidator(param.Schema, spec, param.Name, formats) } return binder } type untypedParamBinder struct { parameter *spec.Parameter formats strfmt.Registry Name string validator validate.EntityValidator } func (p *untypedParamBinder) Type() reflect.Type { return p.typeForSchema(p.parameter.Type, p.parameter.Format, p.parameter.Items) } func (p *untypedParamBinder) typeForSchema(tpe, format string, items *spec.Items) reflect.Type { switch tpe { case "boolean": return reflect.TypeOf(true) case typeString: if tt, ok := p.formats.GetType(format); ok { return tt } return reflect.TypeOf("") case "integer": switch format { case "int8": return reflect.TypeOf(int8(0)) case "int16": return reflect.TypeOf(int16(0)) case "int32": return reflect.TypeOf(int32(0)) case "int64": return reflect.TypeOf(int64(0)) default: return reflect.TypeOf(int64(0)) } case "number": switch format { case "float": return reflect.TypeOf(float32(0)) case "double": return reflect.TypeOf(float64(0)) } case typeArray: if items == nil { return nil } itemsType := p.typeForSchema(items.Type, items.Format, items.Items) if itemsType == nil { return nil } return reflect.MakeSlice(reflect.SliceOf(itemsType), 0, 0).Type() case "file": return reflect.TypeOf(&runtime.File{}).Elem() case "object": return reflect.TypeOf(map[string]interface{}{}) } return nil } func (p *untypedParamBinder) allowsMulti() bool { return p.parameter.In == "query" || p.parameter.In == "formData" } func (p *untypedParamBinder) readValue(values runtime.Gettable, target reflect.Value) ([]string, bool, bool, error) { name, in, cf, tpe := p.parameter.Name, p.parameter.In, p.parameter.CollectionFormat, p.parameter.Type if tpe == typeArray { if cf == "multi" { if !p.allowsMulti() { return nil, false, false, errors.InvalidCollectionFormat(name, in, cf) } vv, hasKey, _ := values.GetOK(name) return vv, false, hasKey, nil } v, hk, hv := values.GetOK(name) if !hv { return nil, false, hk, nil } d, c, e := p.readFormattedSliceFieldValue(v[len(v)-1], target) return d, c, hk, e } vv, hk, _ := values.GetOK(name) return vv, false, hk, nil } func (p *untypedParamBinder) Bind(request *http.Request, routeParams RouteParams, consumer runtime.Consumer, target reflect.Value) error { // fmt.Println("binding", p.name, "as", p.Type()) switch p.parameter.In { case "query": data, custom, hasKey, err := p.readValue(runtime.Values(request.URL.Query()), target) if err != nil { return err } if custom { return nil } return p.bindValue(data, hasKey, target) case "header": data, custom, hasKey, err := p.readValue(runtime.Values(request.Header), target) if err != nil { return err } if custom { return nil } return p.bindValue(data, hasKey, target) case "path": data, custom, hasKey, err := p.readValue(routeParams, target) if err != nil { return err } if custom { return nil } return p.bindValue(data, hasKey, target) case "formData": var err error var mt string mt, _, e := runtime.ContentType(request.Header) if e != nil { // because of the interface conversion go thinks the error is not nil // so we first check for nil and then set the err var if it's not nil err = e } if err != nil { return errors.InvalidContentType("", []string{"multipart/form-data", "application/x-www-form-urlencoded"}) } if mt != "multipart/form-data" && mt != "application/x-www-form-urlencoded" { return errors.InvalidContentType(mt, []string{"multipart/form-data", "application/x-www-form-urlencoded"}) } if mt == "multipart/form-data" { if err = request.ParseMultipartForm(defaultMaxMemory); err != nil { return errors.NewParseError(p.Name, p.parameter.In, "", err) } } if err = request.ParseForm(); err != nil { return errors.NewParseError(p.Name, p.parameter.In, "", err) } if p.parameter.Type == "file" { file, header, ffErr := request.FormFile(p.parameter.Name) if ffErr != nil { if p.parameter.Required { return errors.NewParseError(p.Name, p.parameter.In, "", ffErr) } return nil } target.Set(reflect.ValueOf(runtime.File{Data: file, Header: header})) return nil } if request.MultipartForm != nil { data, custom, hasKey, rvErr := p.readValue(runtime.Values(request.MultipartForm.Value), target) if rvErr != nil { return rvErr } if custom { return nil } return p.bindValue(data, hasKey, target) } data, custom, hasKey, err := p.readValue(runtime.Values(request.PostForm), target) if err != nil { return err } if custom { return nil } return p.bindValue(data, hasKey, target) case "body": newValue := reflect.New(target.Type()) if !runtime.HasBody(request) { if p.parameter.Default != nil { target.Set(reflect.ValueOf(p.parameter.Default)) } return nil } if err := consumer.Consume(request.Body, newValue.Interface()); err != nil { if err == io.EOF && p.parameter.Default != nil { target.Set(reflect.ValueOf(p.parameter.Default)) return nil } tpe := p.parameter.Type if p.parameter.Format != "" { tpe = p.parameter.Format } return errors.InvalidType(p.Name, p.parameter.In, tpe, nil) } target.Set(reflect.Indirect(newValue)) return nil default: return errors.New(500, fmt.Sprintf("invalid parameter location %q", p.parameter.In)) } } func (p *untypedParamBinder) bindValue(data []string, hasKey bool, target reflect.Value) error { if p.parameter.Type == typeArray { return p.setSliceFieldValue(target, p.parameter.Default, data, hasKey) } var d string if len(data) > 0 { d = data[len(data)-1] } return p.setFieldValue(target, p.parameter.Default, d, hasKey) } func (p *untypedParamBinder) setFieldValue(target reflect.Value, defaultValue interface{}, data string, hasKey bool) error { //nolint:gocyclo tpe := p.parameter.Type if p.parameter.Format != "" { tpe = p.parameter.Format } if (!hasKey || (!p.parameter.AllowEmptyValue && data == "")) && p.parameter.Required && p.parameter.Default == nil { return errors.Required(p.Name, p.parameter.In, data) } ok, err := p.tryUnmarshaler(target, defaultValue, data) if err != nil { return errors.InvalidType(p.Name, p.parameter.In, tpe, data) } if ok { return nil } defVal := reflect.Zero(target.Type()) if defaultValue != nil { defVal = reflect.ValueOf(defaultValue) } if tpe == "byte" { if data == "" { if target.CanSet() { target.SetBytes(defVal.Bytes()) } return nil } b, err := base64.StdEncoding.DecodeString(data) if err != nil { b, err = base64.URLEncoding.DecodeString(data) if err != nil { return errors.InvalidType(p.Name, p.parameter.In, tpe, data) } } if target.CanSet() { target.SetBytes(b) } return nil } switch target.Kind() { //nolint:exhaustive // we want to check only types that map from a swagger parameter case reflect.Bool: if data == "" { if target.CanSet() { target.SetBool(defVal.Bool()) } return nil } b, err := swag.ConvertBool(data) if err != nil { return err } if target.CanSet() { target.SetBool(b) } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: if data == "" { if target.CanSet() { rd := defVal.Convert(reflect.TypeOf(int64(0))) target.SetInt(rd.Int()) } return nil } i, err := strconv.ParseInt(data, 10, 64) if err != nil { return errors.InvalidType(p.Name, p.parameter.In, tpe, data) } if target.OverflowInt(i) { return errors.InvalidType(p.Name, p.parameter.In, tpe, data) } if target.CanSet() { target.SetInt(i) } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: if data == "" { if target.CanSet() { rd := defVal.Convert(reflect.TypeOf(uint64(0))) target.SetUint(rd.Uint()) } return nil } u, err := strconv.ParseUint(data, 10, 64) if err != nil { return errors.InvalidType(p.Name, p.parameter.In, tpe, data) } if target.OverflowUint(u) { return errors.InvalidType(p.Name, p.parameter.In, tpe, data) } if target.CanSet() { target.SetUint(u) } case reflect.Float32, reflect.Float64: if data == "" { if target.CanSet() { rd := defVal.Convert(reflect.TypeOf(float64(0))) target.SetFloat(rd.Float()) } return nil } f, err := strconv.ParseFloat(data, 64) if err != nil { return errors.InvalidType(p.Name, p.parameter.In, tpe, data) } if target.OverflowFloat(f) { return errors.InvalidType(p.Name, p.parameter.In, tpe, data) } if target.CanSet() { target.SetFloat(f) } case reflect.String: value := data if value == "" { value = defVal.String() } // validate string if target.CanSet() { target.SetString(value) } case reflect.Ptr: if data == "" && defVal.Kind() == reflect.Ptr { if target.CanSet() { target.Set(defVal) } return nil } newVal := reflect.New(target.Type().Elem()) if err := p.setFieldValue(reflect.Indirect(newVal), defVal, data, hasKey); err != nil { return err } if target.CanSet() { target.Set(newVal) } default: return errors.InvalidType(p.Name, p.parameter.In, tpe, data) } return nil } func (p *untypedParamBinder) tryUnmarshaler(target reflect.Value, defaultValue interface{}, data string) (bool, error) { if !target.CanSet() { return false, nil } // When a type implements encoding.TextUnmarshaler we'll use that instead of reflecting some more if reflect.PtrTo(target.Type()).Implements(textUnmarshalType) { if defaultValue != nil && len(data) == 0 { target.Set(reflect.ValueOf(defaultValue)) return true, nil } value := reflect.New(target.Type()) if err := value.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(data)); err != nil { return true, err } target.Set(reflect.Indirect(value)) return true, nil } return false, nil } func (p *untypedParamBinder) readFormattedSliceFieldValue(data string, target reflect.Value) ([]string, bool, error) { ok, err := p.tryUnmarshaler(target, p.parameter.Default, data) if err != nil { return nil, true, err } if ok { return nil, true, nil } return swag.SplitByFormat(data, p.parameter.CollectionFormat), false, nil } func (p *untypedParamBinder) setSliceFieldValue(target reflect.Value, defaultValue interface{}, data []string, hasKey bool) error { sz := len(data) if (!hasKey || (!p.parameter.AllowEmptyValue && (sz == 0 || (sz == 1 && data[0] == "")))) && p.parameter.Required && defaultValue == nil { return errors.Required(p.Name, p.parameter.In, data) } defVal := reflect.Zero(target.Type()) if defaultValue != nil { defVal = reflect.ValueOf(defaultValue) } if !target.CanSet() { return nil } if sz == 0 { target.Set(defVal) return nil } value := reflect.MakeSlice(reflect.SliceOf(target.Type().Elem()), sz, sz) for i := 0; i < sz; i++ { if err := p.setFieldValue(value.Index(i), nil, data[i], hasKey); err != nil { return err } } target.Set(value) return nil }