[chore]: Bump github.com/gin-gonic/gin from 1.8.1 to 1.8.2 (#1286)

Bumps [github.com/gin-gonic/gin](https://github.com/gin-gonic/gin) from 1.8.1 to 1.8.2.
- [Release notes](https://github.com/gin-gonic/gin/releases)
- [Changelog](https://github.com/gin-gonic/gin/blob/master/CHANGELOG.md)
- [Commits](https://github.com/gin-gonic/gin/compare/v1.8.1...v1.8.2)

---
updated-dependencies:
- dependency-name: github.com/gin-gonic/gin
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
This commit is contained in:
dependabot[bot] 2022-12-27 08:29:42 +00:00 committed by GitHub
parent abd594b71f
commit b966d3b157
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
45 changed files with 1196 additions and 11185 deletions

8
go.mod
View file

@ -21,7 +21,7 @@ require (
github.com/gin-contrib/cors v1.4.0 github.com/gin-contrib/cors v1.4.0
github.com/gin-contrib/gzip v0.0.6 github.com/gin-contrib/gzip v0.0.6
github.com/gin-contrib/sessions v0.0.5 github.com/gin-contrib/sessions v0.0.5
github.com/gin-gonic/gin v1.8.1 github.com/gin-gonic/gin v1.8.2
github.com/go-fed/httpsig v1.1.0 github.com/go-fed/httpsig v1.1.0
github.com/go-playground/validator/v10 v10.11.1 github.com/go-playground/validator/v10 v10.11.1
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
@ -52,7 +52,7 @@ require (
golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90 golang.org/x/crypto v0.0.0-20220829220503-c86fa9a7ed90
golang.org/x/exp v0.0.0-20220613132600-b0d781184e0d golang.org/x/exp v0.0.0-20220613132600-b0d781184e0d
golang.org/x/image v0.2.0 golang.org/x/image v0.2.0
golang.org/x/net v0.0.0-20221014081412-f15817d10f9b golang.org/x/net v0.4.0
golang.org/x/oauth2 v0.0.0-20221014153046-6fdb5e3db783 golang.org/x/oauth2 v0.0.0-20221014153046-6fdb5e3db783
golang.org/x/text v0.5.0 golang.org/x/text v0.5.0
gopkg.in/mcuadros/go-syslog.v2 v2.3.0 gopkg.in/mcuadros/go-syslog.v2 v2.3.0
@ -118,7 +118,7 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml v1.9.5 // indirect github.com/pelletier/go-toml v1.9.5 // indirect
github.com/pelletier/go-toml/v2 v2.0.5 // indirect github.com/pelletier/go-toml/v2 v2.0.6 // indirect
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b // indirect github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b // indirect
@ -137,7 +137,7 @@ require (
github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
golang.org/x/mod v0.6.0-dev.0.20220907135952-02c991387e35 // indirect golang.org/x/mod v0.6.0-dev.0.20220907135952-02c991387e35 // indirect
golang.org/x/sys v0.2.0 // indirect golang.org/x/sys v0.3.0 // indirect
golang.org/x/tools v0.1.12 // indirect golang.org/x/tools v0.1.12 // indirect
google.golang.org/appengine v1.6.7 // indirect google.golang.org/appengine v1.6.7 // indirect
google.golang.org/protobuf v1.28.1 // indirect google.golang.org/protobuf v1.28.1 // indirect

17
go.sum
View file

@ -208,8 +208,9 @@ github.com/gin-contrib/sessions v0.0.5 h1:CATtfHmLMQrMNpJRgzjWXD7worTh7g7ritsQfm
github.com/gin-contrib/sessions v0.0.5/go.mod h1:vYAuaUPqie3WUSsft6HUlCjlwwoJQs97miaG2+7neKY= github.com/gin-contrib/sessions v0.0.5/go.mod h1:vYAuaUPqie3WUSsft6HUlCjlwwoJQs97miaG2+7neKY=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.8.1 h1:4+fr/el88TOO3ewCmQr8cx/CtZ/umlIRIs5M4NTNjf8=
github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR9tTTk= github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR9tTTk=
github.com/gin-gonic/gin v1.8.2 h1:UzKToD9/PoFj/V4rvlKqTRKnQYyz8Sc1MJlv4JHPtvY=
github.com/gin-gonic/gin v1.8.2/go.mod h1:qw5AYuDrzRTnhvusDsrov+fDIxp9Dleuu12h8nfB398=
github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q= github.com/go-errors/errors v1.0.1/go.mod h1:f4zRHt4oKfwPJE5k8C9vpYG+aDHdBFUsgrm6/TyX73Q=
github.com/go-errors/errors v1.0.2/go.mod h1:psDX2osz5VnTOnFWbDeWwS7yejl+uV3FEWEp4lssFEs= github.com/go-errors/errors v1.0.2/go.mod h1:psDX2osz5VnTOnFWbDeWwS7yejl+uV3FEWEp4lssFEs=
github.com/go-errors/errors v1.1.1/go.mod h1:psDX2osz5VnTOnFWbDeWwS7yejl+uV3FEWEp4lssFEs= github.com/go-errors/errors v1.1.1/go.mod h1:psDX2osz5VnTOnFWbDeWwS7yejl+uV3FEWEp4lssFEs=
@ -498,8 +499,8 @@ github.com/orcaman/writerseeker v0.0.0-20200621085525-1d3f536ff85e/go.mod h1:nBd
github.com/pelletier/go-toml v1.9.5 h1:4yBQzkHv+7BHq2PQUZF3Mx0IYxG7LsP222s7Agd3ve8= github.com/pelletier/go-toml v1.9.5 h1:4yBQzkHv+7BHq2PQUZF3Mx0IYxG7LsP222s7Agd3ve8=
github.com/pelletier/go-toml v1.9.5/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= github.com/pelletier/go-toml v1.9.5/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c=
github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo= github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
github.com/pelletier/go-toml/v2 v2.0.5 h1:ipoSadvV8oGUjnUbMub59IDPPwfxF694nG/jwbMiyQg= github.com/pelletier/go-toml/v2 v2.0.6 h1:nrzqCb7j9cDFj2coyLNLaZuJTLjWjlaz6nvTvIwycIU=
github.com/pelletier/go-toml/v2 v2.0.5/go.mod h1:OMHamSCAODeSsVrwwvcJOaoN0LIUIaFVNZzmWyNfXas= github.com/pelletier/go-toml/v2 v2.0.6/go.mod h1:eumQOmlWiOPt5WriQQqoM5y18pDHwha2N+QD+EUNTek=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
@ -785,8 +786,8 @@ golang.org/x/net v0.0.0-20220607020251-c690dde0001d/go.mod h1:XRhObCWvk6IyKnWLug
golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/net v0.0.0-20221014081412-f15817d10f9b h1:tvrvnPFcdzp294diPnrdZZZ8XUt2Tyj7svb7X52iDuU= golang.org/x/net v0.4.0 h1:Q5QPcMlvfxFTAPV0+07Xz/MpK9NTXu2VDUuy0FeMfaU=
golang.org/x/net v0.0.0-20221014081412-f15817d10f9b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/net v0.4.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@ -903,12 +904,12 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A= golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.3.0 h1:qoo4akIqOcDME5bhc/NgxUdovd6BSS2uMsVjB56q1xI=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

View file

@ -1,5 +1,16 @@
# Gin ChangeLog # Gin ChangeLog
## Gin v1.8.2
### Bugs
* fix(route): redirectSlash bug ([#3227]((https://github.com/gin-gonic/gin/pull/3227)))
* fix(engine): missing route params for CreateTestContext ([#2778]((https://github.com/gin-gonic/gin/pull/2778))) ([#2803]((https://github.com/gin-gonic/gin/pull/2803)))
### Security
* Fix the GO-2022-1144 vulnerability ([#3432]((https://github.com/gin-gonic/gin/pull/3432)))
## Gin v1.8.1 ## Gin v1.8.1
### ENHANCEMENTS ### ENHANCEMENTS

View file

@ -203,7 +203,7 @@ func New() *Engine {
} }
engine.RouterGroup.engine = engine engine.RouterGroup.engine = engine
engine.pool.New = func() any { engine.pool.New = func() any {
return engine.allocateContext() return engine.allocateContext(engine.maxParams)
} }
return engine return engine
} }
@ -225,8 +225,8 @@ func (engine *Engine) Handler() http.Handler {
return h2c.NewHandler(engine, h2s) return h2c.NewHandler(engine, h2s)
} }
func (engine *Engine) allocateContext() *Context { func (engine *Engine) allocateContext(maxParams uint16) *Context {
v := make(Params, 0, engine.maxParams) v := make(Params, 0, maxParams)
skippedNodes := make([]skippedNode, 0, engine.maxSections) skippedNodes := make([]skippedNode, 0, engine.maxSections)
return &Context{engine: engine, params: &v, skippedNodes: &skippedNodes} return &Context{engine: engine, params: &v, skippedNodes: &skippedNodes}
} }

View file

@ -9,7 +9,15 @@
// CreateTestContext returns a fresh engine and context for testing purposes // CreateTestContext returns a fresh engine and context for testing purposes
func CreateTestContext(w http.ResponseWriter) (c *Context, r *Engine) { func CreateTestContext(w http.ResponseWriter) (c *Context, r *Engine) {
r = New() r = New()
c = r.allocateContext() c = r.allocateContext(0)
c.reset()
c.writermem.reset(w)
return
}
// CreateTestContextOnly returns a fresh context base on the engine for testing purposes
func CreateTestContextOnly(w http.ResponseWriter, r *Engine) (c *Context) {
c = r.allocateContext(r.maxParams)
c.reset() c.reset()
c.writermem.reset(w) c.writermem.reset(w)
return return

View file

@ -107,7 +107,8 @@ func countSections(path string) uint16 {
type nodeType uint8 type nodeType uint8
const ( const (
root nodeType = iota + 1 static nodeType = iota
root
param param
catchAll catchAll
) )
@ -173,6 +174,7 @@ func (n *node) addRoute(path string, handlers HandlersChain) {
child := node{ child := node{
path: n.path[i:], path: n.path[i:],
wildChild: n.wildChild, wildChild: n.wildChild,
nType: static,
indices: n.indices, indices: n.indices,
children: n.children, children: n.children,
handlers: n.handlers, handlers: n.handlers,
@ -604,6 +606,11 @@ func (n *node) getValue(path string, params *Params, skippedNodes *[]skippedNode
return return
} }
if path == "/" && n.nType == static {
value.tsr = true
return
}
// No handle found. Check if a handle for this path + a // No handle found. Check if a handle for this path + a
// trailing slash exists for trailing slash recommendation // trailing slash exists for trailing slash recommendation
for i, c := range []byte(n.indices) { for i, c := range []byte(n.indices) {

View file

@ -5,4 +5,4 @@
package gin package gin
// Version is the current gin framework's version. // Version is the current gin framework's version.
const Version = "v1.8.1" const Version = "v1.8.2"

View file

@ -140,6 +140,17 @@ fmt.Println(string(b))
[marshal]: https://pkg.go.dev/github.com/pelletier/go-toml/v2#Marshal [marshal]: https://pkg.go.dev/github.com/pelletier/go-toml/v2#Marshal
## Unstable API
This API does not yet follow the backward compatibility guarantees of this
library. They provide early access to features that may have rough edges or an
API subject to change.
### Parser
Parser is the unstable API that allows iterative parsing of a TOML document at
the AST level. See https://pkg.go.dev/github.com/pelletier/go-toml/v2/unstable.
## Benchmarks ## Benchmarks
Execution time speedup compared to other Go TOML libraries: Execution time speedup compared to other Go TOML libraries:

View file

@ -5,6 +5,8 @@
"math" "math"
"strconv" "strconv"
"time" "time"
"github.com/pelletier/go-toml/v2/unstable"
) )
func parseInteger(b []byte) (int64, error) { func parseInteger(b []byte) (int64, error) {
@ -32,7 +34,7 @@ func parseLocalDate(b []byte) (LocalDate, error) {
var date LocalDate var date LocalDate
if len(b) != 10 || b[4] != '-' || b[7] != '-' { if len(b) != 10 || b[4] != '-' || b[7] != '-' {
return date, newDecodeError(b, "dates are expected to have the format YYYY-MM-DD") return date, unstable.NewParserError(b, "dates are expected to have the format YYYY-MM-DD")
} }
var err error var err error
@ -53,7 +55,7 @@ func parseLocalDate(b []byte) (LocalDate, error) {
} }
if !isValidDate(date.Year, date.Month, date.Day) { if !isValidDate(date.Year, date.Month, date.Day) {
return LocalDate{}, newDecodeError(b, "impossible date") return LocalDate{}, unstable.NewParserError(b, "impossible date")
} }
return date, nil return date, nil
@ -64,7 +66,7 @@ func parseDecimalDigits(b []byte) (int, error) {
for i, c := range b { for i, c := range b {
if c < '0' || c > '9' { if c < '0' || c > '9' {
return 0, newDecodeError(b[i:i+1], "expected digit (0-9)") return 0, unstable.NewParserError(b[i:i+1], "expected digit (0-9)")
} }
v *= 10 v *= 10
v += int(c - '0') v += int(c - '0')
@ -97,7 +99,7 @@ func parseDateTime(b []byte) (time.Time, error) {
} else { } else {
const dateTimeByteLen = 6 const dateTimeByteLen = 6
if len(b) != dateTimeByteLen { if len(b) != dateTimeByteLen {
return time.Time{}, newDecodeError(b, "invalid date-time timezone") return time.Time{}, unstable.NewParserError(b, "invalid date-time timezone")
} }
var direction int var direction int
switch b[0] { switch b[0] {
@ -106,11 +108,11 @@ func parseDateTime(b []byte) (time.Time, error) {
case '+': case '+':
direction = +1 direction = +1
default: default:
return time.Time{}, newDecodeError(b[:1], "invalid timezone offset character") return time.Time{}, unstable.NewParserError(b[:1], "invalid timezone offset character")
} }
if b[3] != ':' { if b[3] != ':' {
return time.Time{}, newDecodeError(b[3:4], "expected a : separator") return time.Time{}, unstable.NewParserError(b[3:4], "expected a : separator")
} }
hours, err := parseDecimalDigits(b[1:3]) hours, err := parseDecimalDigits(b[1:3])
@ -118,7 +120,7 @@ func parseDateTime(b []byte) (time.Time, error) {
return time.Time{}, err return time.Time{}, err
} }
if hours > 23 { if hours > 23 {
return time.Time{}, newDecodeError(b[:1], "invalid timezone offset hours") return time.Time{}, unstable.NewParserError(b[:1], "invalid timezone offset hours")
} }
minutes, err := parseDecimalDigits(b[4:6]) minutes, err := parseDecimalDigits(b[4:6])
@ -126,7 +128,7 @@ func parseDateTime(b []byte) (time.Time, error) {
return time.Time{}, err return time.Time{}, err
} }
if minutes > 59 { if minutes > 59 {
return time.Time{}, newDecodeError(b[:1], "invalid timezone offset minutes") return time.Time{}, unstable.NewParserError(b[:1], "invalid timezone offset minutes")
} }
seconds := direction * (hours*3600 + minutes*60) seconds := direction * (hours*3600 + minutes*60)
@ -139,7 +141,7 @@ func parseDateTime(b []byte) (time.Time, error) {
} }
if len(b) > 0 { if len(b) > 0 {
return time.Time{}, newDecodeError(b, "extra bytes at the end of the timezone") return time.Time{}, unstable.NewParserError(b, "extra bytes at the end of the timezone")
} }
t := time.Date( t := time.Date(
@ -160,7 +162,7 @@ func parseLocalDateTime(b []byte) (LocalDateTime, []byte, error) {
const localDateTimeByteMinLen = 11 const localDateTimeByteMinLen = 11
if len(b) < localDateTimeByteMinLen { if len(b) < localDateTimeByteMinLen {
return dt, nil, newDecodeError(b, "local datetimes are expected to have the format YYYY-MM-DDTHH:MM:SS[.NNNNNNNNN]") return dt, nil, unstable.NewParserError(b, "local datetimes are expected to have the format YYYY-MM-DDTHH:MM:SS[.NNNNNNNNN]")
} }
date, err := parseLocalDate(b[:10]) date, err := parseLocalDate(b[:10])
@ -171,7 +173,7 @@ func parseLocalDateTime(b []byte) (LocalDateTime, []byte, error) {
sep := b[10] sep := b[10]
if sep != 'T' && sep != ' ' && sep != 't' { if sep != 'T' && sep != ' ' && sep != 't' {
return dt, nil, newDecodeError(b[10:11], "datetime separator is expected to be T or a space") return dt, nil, unstable.NewParserError(b[10:11], "datetime separator is expected to be T or a space")
} }
t, rest, err := parseLocalTime(b[11:]) t, rest, err := parseLocalTime(b[11:])
@ -195,7 +197,7 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) {
// check if b matches to have expected format HH:MM:SS[.NNNNNN] // check if b matches to have expected format HH:MM:SS[.NNNNNN]
const localTimeByteLen = 8 const localTimeByteLen = 8
if len(b) < localTimeByteLen { if len(b) < localTimeByteLen {
return t, nil, newDecodeError(b, "times are expected to have the format HH:MM:SS[.NNNNNN]") return t, nil, unstable.NewParserError(b, "times are expected to have the format HH:MM:SS[.NNNNNN]")
} }
var err error var err error
@ -206,10 +208,10 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) {
} }
if t.Hour > 23 { if t.Hour > 23 {
return t, nil, newDecodeError(b[0:2], "hour cannot be greater 23") return t, nil, unstable.NewParserError(b[0:2], "hour cannot be greater 23")
} }
if b[2] != ':' { if b[2] != ':' {
return t, nil, newDecodeError(b[2:3], "expecting colon between hours and minutes") return t, nil, unstable.NewParserError(b[2:3], "expecting colon between hours and minutes")
} }
t.Minute, err = parseDecimalDigits(b[3:5]) t.Minute, err = parseDecimalDigits(b[3:5])
@ -217,10 +219,10 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) {
return t, nil, err return t, nil, err
} }
if t.Minute > 59 { if t.Minute > 59 {
return t, nil, newDecodeError(b[3:5], "minutes cannot be greater 59") return t, nil, unstable.NewParserError(b[3:5], "minutes cannot be greater 59")
} }
if b[5] != ':' { if b[5] != ':' {
return t, nil, newDecodeError(b[5:6], "expecting colon between minutes and seconds") return t, nil, unstable.NewParserError(b[5:6], "expecting colon between minutes and seconds")
} }
t.Second, err = parseDecimalDigits(b[6:8]) t.Second, err = parseDecimalDigits(b[6:8])
@ -229,7 +231,7 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) {
} }
if t.Second > 60 { if t.Second > 60 {
return t, nil, newDecodeError(b[6:8], "seconds cannot be greater 60") return t, nil, unstable.NewParserError(b[6:8], "seconds cannot be greater 60")
} }
b = b[8:] b = b[8:]
@ -242,7 +244,7 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) {
for i, c := range b[1:] { for i, c := range b[1:] {
if !isDigit(c) { if !isDigit(c) {
if i == 0 { if i == 0 {
return t, nil, newDecodeError(b[0:1], "need at least one digit after fraction point") return t, nil, unstable.NewParserError(b[0:1], "need at least one digit after fraction point")
} }
break break
} }
@ -266,7 +268,7 @@ func parseLocalTime(b []byte) (LocalTime, []byte, error) {
} }
if precision == 0 { if precision == 0 {
return t, nil, newDecodeError(b[:1], "nanoseconds need at least one digit") return t, nil, unstable.NewParserError(b[:1], "nanoseconds need at least one digit")
} }
t.Nanosecond = frac * nspow[precision] t.Nanosecond = frac * nspow[precision]
@ -289,24 +291,24 @@ func parseFloat(b []byte) (float64, error) {
} }
if cleaned[0] == '.' { if cleaned[0] == '.' {
return 0, newDecodeError(b, "float cannot start with a dot") return 0, unstable.NewParserError(b, "float cannot start with a dot")
} }
if cleaned[len(cleaned)-1] == '.' { if cleaned[len(cleaned)-1] == '.' {
return 0, newDecodeError(b, "float cannot end with a dot") return 0, unstable.NewParserError(b, "float cannot end with a dot")
} }
dotAlreadySeen := false dotAlreadySeen := false
for i, c := range cleaned { for i, c := range cleaned {
if c == '.' { if c == '.' {
if dotAlreadySeen { if dotAlreadySeen {
return 0, newDecodeError(b[i:i+1], "float can have at most one decimal point") return 0, unstable.NewParserError(b[i:i+1], "float can have at most one decimal point")
} }
if !isDigit(cleaned[i-1]) { if !isDigit(cleaned[i-1]) {
return 0, newDecodeError(b[i-1:i+1], "float decimal point must be preceded by a digit") return 0, unstable.NewParserError(b[i-1:i+1], "float decimal point must be preceded by a digit")
} }
if !isDigit(cleaned[i+1]) { if !isDigit(cleaned[i+1]) {
return 0, newDecodeError(b[i:i+2], "float decimal point must be followed by a digit") return 0, unstable.NewParserError(b[i:i+2], "float decimal point must be followed by a digit")
} }
dotAlreadySeen = true dotAlreadySeen = true
} }
@ -317,12 +319,12 @@ func parseFloat(b []byte) (float64, error) {
start = 1 start = 1
} }
if cleaned[start] == '0' && isDigit(cleaned[start+1]) { if cleaned[start] == '0' && isDigit(cleaned[start+1]) {
return 0, newDecodeError(b, "float integer part cannot have leading zeroes") return 0, unstable.NewParserError(b, "float integer part cannot have leading zeroes")
} }
f, err := strconv.ParseFloat(string(cleaned), 64) f, err := strconv.ParseFloat(string(cleaned), 64)
if err != nil { if err != nil {
return 0, newDecodeError(b, "unable to parse float: %w", err) return 0, unstable.NewParserError(b, "unable to parse float: %w", err)
} }
return f, nil return f, nil
@ -336,7 +338,7 @@ func parseIntHex(b []byte) (int64, error) {
i, err := strconv.ParseInt(string(cleaned), 16, 64) i, err := strconv.ParseInt(string(cleaned), 16, 64)
if err != nil { if err != nil {
return 0, newDecodeError(b, "couldn't parse hexadecimal number: %w", err) return 0, unstable.NewParserError(b, "couldn't parse hexadecimal number: %w", err)
} }
return i, nil return i, nil
@ -350,7 +352,7 @@ func parseIntOct(b []byte) (int64, error) {
i, err := strconv.ParseInt(string(cleaned), 8, 64) i, err := strconv.ParseInt(string(cleaned), 8, 64)
if err != nil { if err != nil {
return 0, newDecodeError(b, "couldn't parse octal number: %w", err) return 0, unstable.NewParserError(b, "couldn't parse octal number: %w", err)
} }
return i, nil return i, nil
@ -364,7 +366,7 @@ func parseIntBin(b []byte) (int64, error) {
i, err := strconv.ParseInt(string(cleaned), 2, 64) i, err := strconv.ParseInt(string(cleaned), 2, 64)
if err != nil { if err != nil {
return 0, newDecodeError(b, "couldn't parse binary number: %w", err) return 0, unstable.NewParserError(b, "couldn't parse binary number: %w", err)
} }
return i, nil return i, nil
@ -387,12 +389,12 @@ func parseIntDec(b []byte) (int64, error) {
} }
if len(cleaned) > startIdx+1 && cleaned[startIdx] == '0' { if len(cleaned) > startIdx+1 && cleaned[startIdx] == '0' {
return 0, newDecodeError(b, "leading zero not allowed on decimal number") return 0, unstable.NewParserError(b, "leading zero not allowed on decimal number")
} }
i, err := strconv.ParseInt(string(cleaned), 10, 64) i, err := strconv.ParseInt(string(cleaned), 10, 64)
if err != nil { if err != nil {
return 0, newDecodeError(b, "couldn't parse decimal number: %w", err) return 0, unstable.NewParserError(b, "couldn't parse decimal number: %w", err)
} }
return i, nil return i, nil
@ -409,11 +411,11 @@ func checkAndRemoveUnderscoresIntegers(b []byte) ([]byte, error) {
} }
if b[start] == '_' { if b[start] == '_' {
return nil, newDecodeError(b[start:start+1], "number cannot start with underscore") return nil, unstable.NewParserError(b[start:start+1], "number cannot start with underscore")
} }
if b[len(b)-1] == '_' { if b[len(b)-1] == '_' {
return nil, newDecodeError(b[len(b)-1:], "number cannot end with underscore") return nil, unstable.NewParserError(b[len(b)-1:], "number cannot end with underscore")
} }
// fast path // fast path
@ -435,7 +437,7 @@ func checkAndRemoveUnderscoresIntegers(b []byte) ([]byte, error) {
c := b[i] c := b[i]
if c == '_' { if c == '_' {
if !before { if !before {
return nil, newDecodeError(b[i-1:i+1], "number must have at least one digit between underscores") return nil, unstable.NewParserError(b[i-1:i+1], "number must have at least one digit between underscores")
} }
before = false before = false
} else { } else {
@ -449,11 +451,11 @@ func checkAndRemoveUnderscoresIntegers(b []byte) ([]byte, error) {
func checkAndRemoveUnderscoresFloats(b []byte) ([]byte, error) { func checkAndRemoveUnderscoresFloats(b []byte) ([]byte, error) {
if b[0] == '_' { if b[0] == '_' {
return nil, newDecodeError(b[0:1], "number cannot start with underscore") return nil, unstable.NewParserError(b[0:1], "number cannot start with underscore")
} }
if b[len(b)-1] == '_' { if b[len(b)-1] == '_' {
return nil, newDecodeError(b[len(b)-1:], "number cannot end with underscore") return nil, unstable.NewParserError(b[len(b)-1:], "number cannot end with underscore")
} }
// fast path // fast path
@ -476,10 +478,10 @@ func checkAndRemoveUnderscoresFloats(b []byte) ([]byte, error) {
switch c { switch c {
case '_': case '_':
if !before { if !before {
return nil, newDecodeError(b[i-1:i+1], "number must have at least one digit between underscores") return nil, unstable.NewParserError(b[i-1:i+1], "number must have at least one digit between underscores")
} }
if i < len(b)-1 && (b[i+1] == 'e' || b[i+1] == 'E') { if i < len(b)-1 && (b[i+1] == 'e' || b[i+1] == 'E') {
return nil, newDecodeError(b[i+1:i+2], "cannot have underscore before exponent") return nil, unstable.NewParserError(b[i+1:i+2], "cannot have underscore before exponent")
} }
before = false before = false
case '+', '-': case '+', '-':
@ -488,15 +490,15 @@ func checkAndRemoveUnderscoresFloats(b []byte) ([]byte, error) {
before = false before = false
case 'e', 'E': case 'e', 'E':
if i < len(b)-1 && b[i+1] == '_' { if i < len(b)-1 && b[i+1] == '_' {
return nil, newDecodeError(b[i+1:i+2], "cannot have underscore after exponent") return nil, unstable.NewParserError(b[i+1:i+2], "cannot have underscore after exponent")
} }
cleaned = append(cleaned, c) cleaned = append(cleaned, c)
case '.': case '.':
if i < len(b)-1 && b[i+1] == '_' { if i < len(b)-1 && b[i+1] == '_' {
return nil, newDecodeError(b[i+1:i+2], "cannot have underscore after decimal point") return nil, unstable.NewParserError(b[i+1:i+2], "cannot have underscore after decimal point")
} }
if i > 0 && b[i-1] == '_' { if i > 0 && b[i-1] == '_' {
return nil, newDecodeError(b[i-1:i], "cannot have underscore before decimal point") return nil, unstable.NewParserError(b[i-1:i], "cannot have underscore before decimal point")
} }
cleaned = append(cleaned, c) cleaned = append(cleaned, c)
default: default:
@ -542,3 +544,7 @@ func daysIn(m int, year int) int {
func isLeap(year int) bool { func isLeap(year int) bool {
return year%4 == 0 && (year%100 != 0 || year%400 == 0) return year%4 == 0 && (year%100 != 0 || year%400 == 0)
} }
func isDigit(r byte) bool {
return r >= '0' && r <= '9'
}

View file

@ -6,6 +6,7 @@
"strings" "strings"
"github.com/pelletier/go-toml/v2/internal/danger" "github.com/pelletier/go-toml/v2/internal/danger"
"github.com/pelletier/go-toml/v2/unstable"
) )
// DecodeError represents an error encountered during the parsing or decoding // DecodeError represents an error encountered during the parsing or decoding
@ -55,25 +56,6 @@ func (s *StrictMissingError) String() string {
type Key []string type Key []string
// internal version of DecodeError that is used as the base to create a
// DecodeError with full context.
type decodeError struct {
highlight []byte
message string
key Key // optional
}
func (de *decodeError) Error() string {
return de.message
}
func newDecodeError(highlight []byte, format string, args ...interface{}) error {
return &decodeError{
highlight: highlight,
message: fmt.Errorf(format, args...).Error(),
}
}
// Error returns the error message contained in the DecodeError. // Error returns the error message contained in the DecodeError.
func (e *DecodeError) Error() string { func (e *DecodeError) Error() string {
return "toml: " + e.message return "toml: " + e.message
@ -105,12 +87,12 @@ func (e *DecodeError) Key() Key {
// highlight can be freely deallocated. // highlight can be freely deallocated.
// //
//nolint:funlen //nolint:funlen
func wrapDecodeError(document []byte, de *decodeError) *DecodeError { func wrapDecodeError(document []byte, de *unstable.ParserError) *DecodeError {
offset := danger.SubsliceOffset(document, de.highlight) offset := danger.SubsliceOffset(document, de.Highlight)
errMessage := de.Error() errMessage := de.Error()
errLine, errColumn := positionAtEnd(document[:offset]) errLine, errColumn := positionAtEnd(document[:offset])
before, after := linesOfContext(document, de.highlight, offset, 3) before, after := linesOfContext(document, de.Highlight, offset, 3)
var buf strings.Builder var buf strings.Builder
@ -140,7 +122,7 @@ func wrapDecodeError(document []byte, de *decodeError) *DecodeError {
buf.Write(before[0]) buf.Write(before[0])
} }
buf.Write(de.highlight) buf.Write(de.Highlight)
if len(after) > 0 { if len(after) > 0 {
buf.Write(after[0]) buf.Write(after[0])
@ -158,7 +140,7 @@ func wrapDecodeError(document []byte, de *decodeError) *DecodeError {
buf.WriteString(strings.Repeat(" ", len(before[0]))) buf.WriteString(strings.Repeat(" ", len(before[0])))
} }
buf.WriteString(strings.Repeat("~", len(de.highlight))) buf.WriteString(strings.Repeat("~", len(de.Highlight)))
if len(errMessage) > 0 { if len(errMessage) > 0 {
buf.WriteString(" ") buf.WriteString(" ")
@ -183,7 +165,7 @@ func wrapDecodeError(document []byte, de *decodeError) *DecodeError {
message: errMessage, message: errMessage,
line: errLine, line: errLine,
column: errColumn, column: errColumn,
key: de.key, key: de.Key,
human: buf.String(), human: buf.String(),
} }
} }

View file

@ -1,51 +0,0 @@
package ast
type Reference int
const InvalidReference Reference = -1
func (r Reference) Valid() bool {
return r != InvalidReference
}
type Builder struct {
tree Root
lastIdx int
}
func (b *Builder) Tree() *Root {
return &b.tree
}
func (b *Builder) NodeAt(ref Reference) *Node {
return b.tree.at(ref)
}
func (b *Builder) Reset() {
b.tree.nodes = b.tree.nodes[:0]
b.lastIdx = 0
}
func (b *Builder) Push(n Node) Reference {
b.lastIdx = len(b.tree.nodes)
b.tree.nodes = append(b.tree.nodes, n)
return Reference(b.lastIdx)
}
func (b *Builder) PushAndChain(n Node) Reference {
newIdx := len(b.tree.nodes)
b.tree.nodes = append(b.tree.nodes, n)
if b.lastIdx >= 0 {
b.tree.nodes[b.lastIdx].next = newIdx - b.lastIdx
}
b.lastIdx = newIdx
return Reference(b.lastIdx)
}
func (b *Builder) AttachChild(parent Reference, child Reference) {
b.tree.nodes[parent].child = int(child) - int(parent)
}
func (b *Builder) Chain(from Reference, to Reference) {
b.tree.nodes[from].next = int(to) - int(from)
}

View file

@ -0,0 +1,42 @@
package characters
var invalidAsciiTable = [256]bool{
0x00: true,
0x01: true,
0x02: true,
0x03: true,
0x04: true,
0x05: true,
0x06: true,
0x07: true,
0x08: true,
// 0x09 TAB
// 0x0A LF
0x0B: true,
0x0C: true,
// 0x0D CR
0x0E: true,
0x0F: true,
0x10: true,
0x11: true,
0x12: true,
0x13: true,
0x14: true,
0x15: true,
0x16: true,
0x17: true,
0x18: true,
0x19: true,
0x1A: true,
0x1B: true,
0x1C: true,
0x1D: true,
0x1E: true,
0x1F: true,
// 0x20 - 0x7E Printable ASCII characters
0x7F: true,
}
func InvalidAscii(b byte) bool {
return invalidAsciiTable[b]
}

View file

@ -1,4 +1,4 @@
package toml package characters
import ( import (
"unicode/utf8" "unicode/utf8"
@ -32,7 +32,7 @@ func (u utf8Err) Zero() bool {
// 0x9 => tab, ok // 0x9 => tab, ok
// 0xA - 0x1F => invalid // 0xA - 0x1F => invalid
// 0x7F => invalid // 0x7F => invalid
func utf8TomlValidAlreadyEscaped(p []byte) (err utf8Err) { func Utf8TomlValidAlreadyEscaped(p []byte) (err utf8Err) {
// Fast path. Check for and skip 8 bytes of ASCII characters per iteration. // Fast path. Check for and skip 8 bytes of ASCII characters per iteration.
offset := 0 offset := 0
for len(p) >= 8 { for len(p) >= 8 {
@ -48,7 +48,7 @@ func utf8TomlValidAlreadyEscaped(p []byte) (err utf8Err) {
} }
for i, b := range p[:8] { for i, b := range p[:8] {
if invalidAscii(b) { if InvalidAscii(b) {
err.Index = offset + i err.Index = offset + i
err.Size = 1 err.Size = 1
return return
@ -62,7 +62,7 @@ func utf8TomlValidAlreadyEscaped(p []byte) (err utf8Err) {
for i := 0; i < n; { for i := 0; i < n; {
pi := p[i] pi := p[i]
if pi < utf8.RuneSelf { if pi < utf8.RuneSelf {
if invalidAscii(pi) { if InvalidAscii(pi) {
err.Index = offset + i err.Index = offset + i
err.Size = 1 err.Size = 1
return return
@ -106,11 +106,11 @@ func utf8TomlValidAlreadyEscaped(p []byte) (err utf8Err) {
} }
// Return the size of the next rune if valid, 0 otherwise. // Return the size of the next rune if valid, 0 otherwise.
func utf8ValidNext(p []byte) int { func Utf8ValidNext(p []byte) int {
c := p[0] c := p[0]
if c < utf8.RuneSelf { if c < utf8.RuneSelf {
if invalidAscii(c) { if InvalidAscii(c) {
return 0 return 0
} }
return 1 return 1
@ -140,47 +140,6 @@ func utf8ValidNext(p []byte) int {
return size return size
} }
var invalidAsciiTable = [256]bool{
0x00: true,
0x01: true,
0x02: true,
0x03: true,
0x04: true,
0x05: true,
0x06: true,
0x07: true,
0x08: true,
// 0x09 TAB
// 0x0A LF
0x0B: true,
0x0C: true,
// 0x0D CR
0x0E: true,
0x0F: true,
0x10: true,
0x11: true,
0x12: true,
0x13: true,
0x14: true,
0x15: true,
0x16: true,
0x17: true,
0x18: true,
0x19: true,
0x1A: true,
0x1B: true,
0x1C: true,
0x1D: true,
0x1E: true,
0x1F: true,
// 0x20 - 0x7E Printable ASCII characters
0x7F: true,
}
func invalidAscii(b byte) bool {
return invalidAsciiTable[b]
}
// acceptRange gives the range of valid values for the second byte in a UTF-8 // acceptRange gives the range of valid values for the second byte in a UTF-8
// sequence. // sequence.
type acceptRange struct { type acceptRange struct {

View file

@ -1,8 +1,6 @@
package tracker package tracker
import ( import "github.com/pelletier/go-toml/v2/unstable"
"github.com/pelletier/go-toml/v2/internal/ast"
)
// KeyTracker is a tracker that keeps track of the current Key as the AST is // KeyTracker is a tracker that keeps track of the current Key as the AST is
// walked. // walked.
@ -11,19 +9,19 @@ type KeyTracker struct {
} }
// UpdateTable sets the state of the tracker with the AST table node. // UpdateTable sets the state of the tracker with the AST table node.
func (t *KeyTracker) UpdateTable(node *ast.Node) { func (t *KeyTracker) UpdateTable(node *unstable.Node) {
t.reset() t.reset()
t.Push(node) t.Push(node)
} }
// UpdateArrayTable sets the state of the tracker with the AST array table node. // UpdateArrayTable sets the state of the tracker with the AST array table node.
func (t *KeyTracker) UpdateArrayTable(node *ast.Node) { func (t *KeyTracker) UpdateArrayTable(node *unstable.Node) {
t.reset() t.reset()
t.Push(node) t.Push(node)
} }
// Push the given key on the stack. // Push the given key on the stack.
func (t *KeyTracker) Push(node *ast.Node) { func (t *KeyTracker) Push(node *unstable.Node) {
it := node.Key() it := node.Key()
for it.Next() { for it.Next() {
t.k = append(t.k, string(it.Node().Data)) t.k = append(t.k, string(it.Node().Data))
@ -31,7 +29,7 @@ func (t *KeyTracker) Push(node *ast.Node) {
} }
// Pop key from stack. // Pop key from stack.
func (t *KeyTracker) Pop(node *ast.Node) { func (t *KeyTracker) Pop(node *unstable.Node) {
it := node.Key() it := node.Key()
for it.Next() { for it.Next() {
t.k = t.k[:len(t.k)-1] t.k = t.k[:len(t.k)-1]

View file

@ -5,7 +5,7 @@
"fmt" "fmt"
"sync" "sync"
"github.com/pelletier/go-toml/v2/internal/ast" "github.com/pelletier/go-toml/v2/unstable"
) )
type keyKind uint8 type keyKind uint8
@ -150,23 +150,23 @@ func (s *SeenTracker) setExplicitFlag(parentIdx int) {
// CheckExpression takes a top-level node and checks that it does not contain // CheckExpression takes a top-level node and checks that it does not contain
// keys that have been seen in previous calls, and validates that types are // keys that have been seen in previous calls, and validates that types are
// consistent. // consistent.
func (s *SeenTracker) CheckExpression(node *ast.Node) error { func (s *SeenTracker) CheckExpression(node *unstable.Node) error {
if s.entries == nil { if s.entries == nil {
s.reset() s.reset()
} }
switch node.Kind { switch node.Kind {
case ast.KeyValue: case unstable.KeyValue:
return s.checkKeyValue(node) return s.checkKeyValue(node)
case ast.Table: case unstable.Table:
return s.checkTable(node) return s.checkTable(node)
case ast.ArrayTable: case unstable.ArrayTable:
return s.checkArrayTable(node) return s.checkArrayTable(node)
default: default:
panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind)) panic(fmt.Errorf("this should not be a top level node type: %s", node.Kind))
} }
} }
func (s *SeenTracker) checkTable(node *ast.Node) error { func (s *SeenTracker) checkTable(node *unstable.Node) error {
if s.currentIdx >= 0 { if s.currentIdx >= 0 {
s.setExplicitFlag(s.currentIdx) s.setExplicitFlag(s.currentIdx)
} }
@ -219,7 +219,7 @@ func (s *SeenTracker) checkTable(node *ast.Node) error {
return nil return nil
} }
func (s *SeenTracker) checkArrayTable(node *ast.Node) error { func (s *SeenTracker) checkArrayTable(node *unstable.Node) error {
if s.currentIdx >= 0 { if s.currentIdx >= 0 {
s.setExplicitFlag(s.currentIdx) s.setExplicitFlag(s.currentIdx)
} }
@ -267,7 +267,7 @@ func (s *SeenTracker) checkArrayTable(node *ast.Node) error {
return nil return nil
} }
func (s *SeenTracker) checkKeyValue(node *ast.Node) error { func (s *SeenTracker) checkKeyValue(node *unstable.Node) error {
parentIdx := s.currentIdx parentIdx := s.currentIdx
it := node.Key() it := node.Key()
@ -297,26 +297,26 @@ func (s *SeenTracker) checkKeyValue(node *ast.Node) error {
value := node.Value() value := node.Value()
switch value.Kind { switch value.Kind {
case ast.InlineTable: case unstable.InlineTable:
return s.checkInlineTable(value) return s.checkInlineTable(value)
case ast.Array: case unstable.Array:
return s.checkArray(value) return s.checkArray(value)
} }
return nil return nil
} }
func (s *SeenTracker) checkArray(node *ast.Node) error { func (s *SeenTracker) checkArray(node *unstable.Node) error {
it := node.Children() it := node.Children()
for it.Next() { for it.Next() {
n := it.Node() n := it.Node()
switch n.Kind { switch n.Kind {
case ast.InlineTable: case unstable.InlineTable:
err := s.checkInlineTable(n) err := s.checkInlineTable(n)
if err != nil { if err != nil {
return err return err
} }
case ast.Array: case unstable.Array:
err := s.checkArray(n) err := s.checkArray(n)
if err != nil { if err != nil {
return err return err
@ -326,7 +326,7 @@ func (s *SeenTracker) checkArray(node *ast.Node) error {
return nil return nil
} }
func (s *SeenTracker) checkInlineTable(node *ast.Node) error { func (s *SeenTracker) checkInlineTable(node *unstable.Node) error {
if pool.New == nil { if pool.New == nil {
pool.New = func() interface{} { pool.New = func() interface{} {
return &SeenTracker{} return &SeenTracker{}

View file

@ -4,6 +4,8 @@
"fmt" "fmt"
"strings" "strings"
"time" "time"
"github.com/pelletier/go-toml/v2/unstable"
) )
// LocalDate represents a calendar day in no specific timezone. // LocalDate represents a calendar day in no specific timezone.
@ -75,7 +77,7 @@ func (d LocalTime) MarshalText() ([]byte, error) {
func (d *LocalTime) UnmarshalText(b []byte) error { func (d *LocalTime) UnmarshalText(b []byte) error {
res, left, err := parseLocalTime(b) res, left, err := parseLocalTime(b)
if err == nil && len(left) != 0 { if err == nil && len(left) != 0 {
err = newDecodeError(left, "extra characters") err = unstable.NewParserError(left, "extra characters")
} }
if err != nil { if err != nil {
return err return err
@ -109,7 +111,7 @@ func (d LocalDateTime) MarshalText() ([]byte, error) {
func (d *LocalDateTime) UnmarshalText(data []byte) error { func (d *LocalDateTime) UnmarshalText(data []byte) error {
res, left, err := parseLocalDateTime(data) res, left, err := parseLocalDateTime(data)
if err == nil && len(left) != 0 { if err == nil && len(left) != 0 {
err = newDecodeError(left, "extra characters") err = unstable.NewParserError(left, "extra characters")
} }
if err != nil { if err != nil {
return err return err

View file

@ -12,6 +12,8 @@
"strings" "strings"
"time" "time"
"unicode" "unicode"
"github.com/pelletier/go-toml/v2/internal/characters"
) )
// Marshal serializes a Go value as a TOML document. // Marshal serializes a Go value as a TOML document.
@ -437,7 +439,7 @@ func (enc *Encoder) encodeString(b []byte, v string, options valueOptions) []byt
func needsQuoting(v string) bool { func needsQuoting(v string) bool {
// TODO: vectorize // TODO: vectorize
for _, b := range []byte(v) { for _, b := range []byte(v) {
if b == '\'' || b == '\r' || b == '\n' || invalidAscii(b) { if b == '\'' || b == '\r' || b == '\n' || characters.InvalidAscii(b) {
return true return true
} }
} }

View file

@ -1,9 +1,9 @@
package toml package toml
import ( import (
"github.com/pelletier/go-toml/v2/internal/ast"
"github.com/pelletier/go-toml/v2/internal/danger" "github.com/pelletier/go-toml/v2/internal/danger"
"github.com/pelletier/go-toml/v2/internal/tracker" "github.com/pelletier/go-toml/v2/internal/tracker"
"github.com/pelletier/go-toml/v2/unstable"
) )
type strict struct { type strict struct {
@ -12,10 +12,10 @@ type strict struct {
// Tracks the current key being processed. // Tracks the current key being processed.
key tracker.KeyTracker key tracker.KeyTracker
missing []decodeError missing []unstable.ParserError
} }
func (s *strict) EnterTable(node *ast.Node) { func (s *strict) EnterTable(node *unstable.Node) {
if !s.Enabled { if !s.Enabled {
return return
} }
@ -23,7 +23,7 @@ func (s *strict) EnterTable(node *ast.Node) {
s.key.UpdateTable(node) s.key.UpdateTable(node)
} }
func (s *strict) EnterArrayTable(node *ast.Node) { func (s *strict) EnterArrayTable(node *unstable.Node) {
if !s.Enabled { if !s.Enabled {
return return
} }
@ -31,7 +31,7 @@ func (s *strict) EnterArrayTable(node *ast.Node) {
s.key.UpdateArrayTable(node) s.key.UpdateArrayTable(node)
} }
func (s *strict) EnterKeyValue(node *ast.Node) { func (s *strict) EnterKeyValue(node *unstable.Node) {
if !s.Enabled { if !s.Enabled {
return return
} }
@ -39,7 +39,7 @@ func (s *strict) EnterKeyValue(node *ast.Node) {
s.key.Push(node) s.key.Push(node)
} }
func (s *strict) ExitKeyValue(node *ast.Node) { func (s *strict) ExitKeyValue(node *unstable.Node) {
if !s.Enabled { if !s.Enabled {
return return
} }
@ -47,27 +47,27 @@ func (s *strict) ExitKeyValue(node *ast.Node) {
s.key.Pop(node) s.key.Pop(node)
} }
func (s *strict) MissingTable(node *ast.Node) { func (s *strict) MissingTable(node *unstable.Node) {
if !s.Enabled { if !s.Enabled {
return return
} }
s.missing = append(s.missing, decodeError{ s.missing = append(s.missing, unstable.ParserError{
highlight: keyLocation(node), Highlight: keyLocation(node),
message: "missing table", Message: "missing table",
key: s.key.Key(), Key: s.key.Key(),
}) })
} }
func (s *strict) MissingField(node *ast.Node) { func (s *strict) MissingField(node *unstable.Node) {
if !s.Enabled { if !s.Enabled {
return return
} }
s.missing = append(s.missing, decodeError{ s.missing = append(s.missing, unstable.ParserError{
highlight: keyLocation(node), Highlight: keyLocation(node),
message: "missing field", Message: "missing field",
key: s.key.Key(), Key: s.key.Key(),
}) })
} }
@ -88,7 +88,7 @@ func (s *strict) Error(doc []byte) error {
return err return err
} }
func keyLocation(node *ast.Node) []byte { func keyLocation(node *unstable.Node) []byte {
k := node.Key() k := node.Key()
hasOne := k.Next() hasOne := k.Next()

View file

@ -6,9 +6,9 @@
"time" "time"
) )
var timeType = reflect.TypeOf(time.Time{}) var timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
var textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem() var textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
var textUnmarshalerType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem() var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
var mapStringInterfaceType = reflect.TypeOf(map[string]interface{}{}) var mapStringInterfaceType = reflect.TypeOf(map[string]interface{}(nil))
var sliceInterfaceType = reflect.TypeOf([]interface{}{}) var sliceInterfaceType = reflect.TypeOf([]interface{}(nil))
var stringType = reflect.TypeOf("") var stringType = reflect.TypeOf("")

View file

@ -12,16 +12,16 @@
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/pelletier/go-toml/v2/internal/ast"
"github.com/pelletier/go-toml/v2/internal/danger" "github.com/pelletier/go-toml/v2/internal/danger"
"github.com/pelletier/go-toml/v2/internal/tracker" "github.com/pelletier/go-toml/v2/internal/tracker"
"github.com/pelletier/go-toml/v2/unstable"
) )
// Unmarshal deserializes a TOML document into a Go value. // Unmarshal deserializes a TOML document into a Go value.
// //
// It is a shortcut for Decoder.Decode() with the default options. // It is a shortcut for Decoder.Decode() with the default options.
func Unmarshal(data []byte, v interface{}) error { func Unmarshal(data []byte, v interface{}) error {
p := parser{} p := unstable.Parser{}
p.Reset(data) p.Reset(data)
d := decoder{p: &p} d := decoder{p: &p}
@ -101,7 +101,7 @@ func (d *Decoder) Decode(v interface{}) error {
return fmt.Errorf("toml: %w", err) return fmt.Errorf("toml: %w", err)
} }
p := parser{} p := unstable.Parser{}
p.Reset(b) p.Reset(b)
dec := decoder{ dec := decoder{
p: &p, p: &p,
@ -115,7 +115,7 @@ func (d *Decoder) Decode(v interface{}) error {
type decoder struct { type decoder struct {
// Which parser instance in use for this decoding session. // Which parser instance in use for this decoding session.
p *parser p *unstable.Parser
// Flag indicating that the current expression is stashed. // Flag indicating that the current expression is stashed.
// If set to true, calling nextExpr will not actually pull a new expression // If set to true, calling nextExpr will not actually pull a new expression
@ -157,7 +157,7 @@ func (d *decoder) typeMismatchError(toml string, target reflect.Type) error {
return fmt.Errorf("toml: cannot decode TOML %s into a Go value of type %s", toml, target) return fmt.Errorf("toml: cannot decode TOML %s into a Go value of type %s", toml, target)
} }
func (d *decoder) expr() *ast.Node { func (d *decoder) expr() *unstable.Node {
return d.p.Expression() return d.p.Expression()
} }
@ -208,12 +208,12 @@ func (d *decoder) FromParser(v interface{}) error {
err := d.fromParser(r) err := d.fromParser(r)
if err == nil { if err == nil {
return d.strict.Error(d.p.data) return d.strict.Error(d.p.Data())
} }
var e *decodeError var e *unstable.ParserError
if errors.As(err, &e) { if errors.As(err, &e) {
return wrapDecodeError(d.p.data, e) return wrapDecodeError(d.p.Data(), e)
} }
return err return err
@ -234,16 +234,16 @@ func (d *decoder) fromParser(root reflect.Value) error {
Rules for the unmarshal code: Rules for the unmarshal code:
- The stack is used to keep track of which values need to be set where. - The stack is used to keep track of which values need to be set where.
- handle* functions <=> switch on a given ast.Kind. - handle* functions <=> switch on a given unstable.Kind.
- unmarshalX* functions need to unmarshal a node of kind X. - unmarshalX* functions need to unmarshal a node of kind X.
- An "object" is either a struct or a map. - An "object" is either a struct or a map.
*/ */
func (d *decoder) handleRootExpression(expr *ast.Node, v reflect.Value) error { func (d *decoder) handleRootExpression(expr *unstable.Node, v reflect.Value) error {
var x reflect.Value var x reflect.Value
var err error var err error
if !(d.skipUntilTable && expr.Kind == ast.KeyValue) { if !(d.skipUntilTable && expr.Kind == unstable.KeyValue) {
err = d.seen.CheckExpression(expr) err = d.seen.CheckExpression(expr)
if err != nil { if err != nil {
return err return err
@ -251,16 +251,16 @@ func (d *decoder) handleRootExpression(expr *ast.Node, v reflect.Value) error {
} }
switch expr.Kind { switch expr.Kind {
case ast.KeyValue: case unstable.KeyValue:
if d.skipUntilTable { if d.skipUntilTable {
return nil return nil
} }
x, err = d.handleKeyValue(expr, v) x, err = d.handleKeyValue(expr, v)
case ast.Table: case unstable.Table:
d.skipUntilTable = false d.skipUntilTable = false
d.strict.EnterTable(expr) d.strict.EnterTable(expr)
x, err = d.handleTable(expr.Key(), v) x, err = d.handleTable(expr.Key(), v)
case ast.ArrayTable: case unstable.ArrayTable:
d.skipUntilTable = false d.skipUntilTable = false
d.strict.EnterArrayTable(expr) d.strict.EnterArrayTable(expr)
x, err = d.handleArrayTable(expr.Key(), v) x, err = d.handleArrayTable(expr.Key(), v)
@ -269,7 +269,7 @@ func (d *decoder) handleRootExpression(expr *ast.Node, v reflect.Value) error {
} }
if d.skipUntilTable { if d.skipUntilTable {
if expr.Kind == ast.Table || expr.Kind == ast.ArrayTable { if expr.Kind == unstable.Table || expr.Kind == unstable.ArrayTable {
d.strict.MissingTable(expr) d.strict.MissingTable(expr)
} }
} else if err == nil && x.IsValid() { } else if err == nil && x.IsValid() {
@ -279,14 +279,14 @@ func (d *decoder) handleRootExpression(expr *ast.Node, v reflect.Value) error {
return err return err
} }
func (d *decoder) handleArrayTable(key ast.Iterator, v reflect.Value) (reflect.Value, error) { func (d *decoder) handleArrayTable(key unstable.Iterator, v reflect.Value) (reflect.Value, error) {
if key.Next() { if key.Next() {
return d.handleArrayTablePart(key, v) return d.handleArrayTablePart(key, v)
} }
return d.handleKeyValues(v) return d.handleKeyValues(v)
} }
func (d *decoder) handleArrayTableCollectionLast(key ast.Iterator, v reflect.Value) (reflect.Value, error) { func (d *decoder) handleArrayTableCollectionLast(key unstable.Iterator, v reflect.Value) (reflect.Value, error) {
switch v.Kind() { switch v.Kind() {
case reflect.Interface: case reflect.Interface:
elem := v.Elem() elem := v.Elem()
@ -339,13 +339,13 @@ func (d *decoder) handleArrayTableCollectionLast(key ast.Iterator, v reflect.Val
case reflect.Array: case reflect.Array:
idx := d.arrayIndex(true, v) idx := d.arrayIndex(true, v)
if idx >= v.Len() { if idx >= v.Len() {
return v, fmt.Errorf("toml: cannot decode array table into %s at position %d", v.Type(), idx) return v, fmt.Errorf("%s at position %d", d.typeMismatchError("array table", v.Type()), idx)
} }
elem := v.Index(idx) elem := v.Index(idx)
_, err := d.handleArrayTable(key, elem) _, err := d.handleArrayTable(key, elem)
return v, err return v, err
default: default:
return reflect.Value{}, fmt.Errorf("toml: cannot decode array table into a %s", v.Type()) return reflect.Value{}, d.typeMismatchError("array table", v.Type())
} }
} }
@ -353,7 +353,7 @@ func (d *decoder) handleArrayTableCollectionLast(key ast.Iterator, v reflect.Val
// evaluated like a normal key, but if it returns a collection, it also needs to // evaluated like a normal key, but if it returns a collection, it also needs to
// point to the last element of the collection. Unless it is the last part of // point to the last element of the collection. Unless it is the last part of
// the key, then it needs to create a new element at the end. // the key, then it needs to create a new element at the end.
func (d *decoder) handleArrayTableCollection(key ast.Iterator, v reflect.Value) (reflect.Value, error) { func (d *decoder) handleArrayTableCollection(key unstable.Iterator, v reflect.Value) (reflect.Value, error) {
if key.IsLast() { if key.IsLast() {
return d.handleArrayTableCollectionLast(key, v) return d.handleArrayTableCollectionLast(key, v)
} }
@ -390,7 +390,7 @@ func (d *decoder) handleArrayTableCollection(key ast.Iterator, v reflect.Value)
case reflect.Array: case reflect.Array:
idx := d.arrayIndex(false, v) idx := d.arrayIndex(false, v)
if idx >= v.Len() { if idx >= v.Len() {
return v, fmt.Errorf("toml: cannot decode array table into %s at position %d", v.Type(), idx) return v, fmt.Errorf("%s at position %d", d.typeMismatchError("array table", v.Type()), idx)
} }
elem := v.Index(idx) elem := v.Index(idx)
_, err := d.handleArrayTable(key, elem) _, err := d.handleArrayTable(key, elem)
@ -400,7 +400,7 @@ func (d *decoder) handleArrayTableCollection(key ast.Iterator, v reflect.Value)
return d.handleArrayTable(key, v) return d.handleArrayTable(key, v)
} }
func (d *decoder) handleKeyPart(key ast.Iterator, v reflect.Value, nextFn handlerFn, makeFn valueMakerFn) (reflect.Value, error) { func (d *decoder) handleKeyPart(key unstable.Iterator, v reflect.Value, nextFn handlerFn, makeFn valueMakerFn) (reflect.Value, error) {
var rv reflect.Value var rv reflect.Value
// First, dispatch over v to make sure it is a valid object. // First, dispatch over v to make sure it is a valid object.
@ -518,7 +518,7 @@ func (d *decoder) handleKeyPart(key ast.Iterator, v reflect.Value, nextFn handle
// HandleArrayTablePart navigates the Go structure v using the key v. It is // HandleArrayTablePart navigates the Go structure v using the key v. It is
// only used for the prefix (non-last) parts of an array-table. When // only used for the prefix (non-last) parts of an array-table. When
// encountering a collection, it should go to the last element. // encountering a collection, it should go to the last element.
func (d *decoder) handleArrayTablePart(key ast.Iterator, v reflect.Value) (reflect.Value, error) { func (d *decoder) handleArrayTablePart(key unstable.Iterator, v reflect.Value) (reflect.Value, error) {
var makeFn valueMakerFn var makeFn valueMakerFn
if key.IsLast() { if key.IsLast() {
makeFn = makeSliceInterface makeFn = makeSliceInterface
@ -530,10 +530,10 @@ func (d *decoder) handleArrayTablePart(key ast.Iterator, v reflect.Value) (refle
// HandleTable returns a reference when it has checked the next expression but // HandleTable returns a reference when it has checked the next expression but
// cannot handle it. // cannot handle it.
func (d *decoder) handleTable(key ast.Iterator, v reflect.Value) (reflect.Value, error) { func (d *decoder) handleTable(key unstable.Iterator, v reflect.Value) (reflect.Value, error) {
if v.Kind() == reflect.Slice { if v.Kind() == reflect.Slice {
if v.Len() == 0 { if v.Len() == 0 {
return reflect.Value{}, newDecodeError(key.Node().Data, "cannot store a table in a slice") return reflect.Value{}, unstable.NewParserError(key.Node().Data, "cannot store a table in a slice")
} }
elem := v.Index(v.Len() - 1) elem := v.Index(v.Len() - 1)
x, err := d.handleTable(key, elem) x, err := d.handleTable(key, elem)
@ -560,7 +560,7 @@ func (d *decoder) handleKeyValues(v reflect.Value) (reflect.Value, error) {
var rv reflect.Value var rv reflect.Value
for d.nextExpr() { for d.nextExpr() {
expr := d.expr() expr := d.expr()
if expr.Kind != ast.KeyValue { if expr.Kind != unstable.KeyValue {
// Stash the expression so that fromParser can just loop and use // Stash the expression so that fromParser can just loop and use
// the right handler. // the right handler.
// We could just recurse ourselves here, but at least this gives a // We could just recurse ourselves here, but at least this gives a
@ -587,7 +587,7 @@ func (d *decoder) handleKeyValues(v reflect.Value) (reflect.Value, error) {
} }
type ( type (
handlerFn func(key ast.Iterator, v reflect.Value) (reflect.Value, error) handlerFn func(key unstable.Iterator, v reflect.Value) (reflect.Value, error)
valueMakerFn func() reflect.Value valueMakerFn func() reflect.Value
) )
@ -599,11 +599,11 @@ func makeSliceInterface() reflect.Value {
return reflect.MakeSlice(sliceInterfaceType, 0, 16) return reflect.MakeSlice(sliceInterfaceType, 0, 16)
} }
func (d *decoder) handleTablePart(key ast.Iterator, v reflect.Value) (reflect.Value, error) { func (d *decoder) handleTablePart(key unstable.Iterator, v reflect.Value) (reflect.Value, error) {
return d.handleKeyPart(key, v, d.handleTable, makeMapStringInterface) return d.handleKeyPart(key, v, d.handleTable, makeMapStringInterface)
} }
func (d *decoder) tryTextUnmarshaler(node *ast.Node, v reflect.Value) (bool, error) { func (d *decoder) tryTextUnmarshaler(node *unstable.Node, v reflect.Value) (bool, error) {
// Special case for time, because we allow to unmarshal to it from // Special case for time, because we allow to unmarshal to it from
// different kind of AST nodes. // different kind of AST nodes.
if v.Type() == timeType { if v.Type() == timeType {
@ -613,7 +613,7 @@ func (d *decoder) tryTextUnmarshaler(node *ast.Node, v reflect.Value) (bool, err
if v.CanAddr() && v.Addr().Type().Implements(textUnmarshalerType) { if v.CanAddr() && v.Addr().Type().Implements(textUnmarshalerType) {
err := v.Addr().Interface().(encoding.TextUnmarshaler).UnmarshalText(node.Data) err := v.Addr().Interface().(encoding.TextUnmarshaler).UnmarshalText(node.Data)
if err != nil { if err != nil {
return false, newDecodeError(d.p.Raw(node.Raw), "%w", err) return false, unstable.NewParserError(d.p.Raw(node.Raw), "%w", err)
} }
return true, nil return true, nil
@ -622,7 +622,7 @@ func (d *decoder) tryTextUnmarshaler(node *ast.Node, v reflect.Value) (bool, err
return false, nil return false, nil
} }
func (d *decoder) handleValue(value *ast.Node, v reflect.Value) error { func (d *decoder) handleValue(value *unstable.Node, v reflect.Value) error {
for v.Kind() == reflect.Ptr { for v.Kind() == reflect.Ptr {
v = initAndDereferencePointer(v) v = initAndDereferencePointer(v)
} }
@ -633,32 +633,32 @@ func (d *decoder) handleValue(value *ast.Node, v reflect.Value) error {
} }
switch value.Kind { switch value.Kind {
case ast.String: case unstable.String:
return d.unmarshalString(value, v) return d.unmarshalString(value, v)
case ast.Integer: case unstable.Integer:
return d.unmarshalInteger(value, v) return d.unmarshalInteger(value, v)
case ast.Float: case unstable.Float:
return d.unmarshalFloat(value, v) return d.unmarshalFloat(value, v)
case ast.Bool: case unstable.Bool:
return d.unmarshalBool(value, v) return d.unmarshalBool(value, v)
case ast.DateTime: case unstable.DateTime:
return d.unmarshalDateTime(value, v) return d.unmarshalDateTime(value, v)
case ast.LocalDate: case unstable.LocalDate:
return d.unmarshalLocalDate(value, v) return d.unmarshalLocalDate(value, v)
case ast.LocalTime: case unstable.LocalTime:
return d.unmarshalLocalTime(value, v) return d.unmarshalLocalTime(value, v)
case ast.LocalDateTime: case unstable.LocalDateTime:
return d.unmarshalLocalDateTime(value, v) return d.unmarshalLocalDateTime(value, v)
case ast.InlineTable: case unstable.InlineTable:
return d.unmarshalInlineTable(value, v) return d.unmarshalInlineTable(value, v)
case ast.Array: case unstable.Array:
return d.unmarshalArray(value, v) return d.unmarshalArray(value, v)
default: default:
panic(fmt.Errorf("handleValue not implemented for %s", value.Kind)) panic(fmt.Errorf("handleValue not implemented for %s", value.Kind))
} }
} }
func (d *decoder) unmarshalArray(array *ast.Node, v reflect.Value) error { func (d *decoder) unmarshalArray(array *unstable.Node, v reflect.Value) error {
switch v.Kind() { switch v.Kind() {
case reflect.Slice: case reflect.Slice:
if v.IsNil() { if v.IsNil() {
@ -729,7 +729,7 @@ func (d *decoder) unmarshalArray(array *ast.Node, v reflect.Value) error {
return nil return nil
} }
func (d *decoder) unmarshalInlineTable(itable *ast.Node, v reflect.Value) error { func (d *decoder) unmarshalInlineTable(itable *unstable.Node, v reflect.Value) error {
// Make sure v is an initialized object. // Make sure v is an initialized object.
switch v.Kind() { switch v.Kind() {
case reflect.Map: case reflect.Map:
@ -746,7 +746,7 @@ func (d *decoder) unmarshalInlineTable(itable *ast.Node, v reflect.Value) error
} }
return d.unmarshalInlineTable(itable, elem) return d.unmarshalInlineTable(itable, elem)
default: default:
return newDecodeError(itable.Data, "cannot store inline table in Go type %s", v.Kind()) return unstable.NewParserError(itable.Data, "cannot store inline table in Go type %s", v.Kind())
} }
it := itable.Children() it := itable.Children()
@ -765,7 +765,7 @@ func (d *decoder) unmarshalInlineTable(itable *ast.Node, v reflect.Value) error
return nil return nil
} }
func (d *decoder) unmarshalDateTime(value *ast.Node, v reflect.Value) error { func (d *decoder) unmarshalDateTime(value *unstable.Node, v reflect.Value) error {
dt, err := parseDateTime(value.Data) dt, err := parseDateTime(value.Data)
if err != nil { if err != nil {
return err return err
@ -775,7 +775,7 @@ func (d *decoder) unmarshalDateTime(value *ast.Node, v reflect.Value) error {
return nil return nil
} }
func (d *decoder) unmarshalLocalDate(value *ast.Node, v reflect.Value) error { func (d *decoder) unmarshalLocalDate(value *unstable.Node, v reflect.Value) error {
ld, err := parseLocalDate(value.Data) ld, err := parseLocalDate(value.Data)
if err != nil { if err != nil {
return err return err
@ -792,28 +792,28 @@ func (d *decoder) unmarshalLocalDate(value *ast.Node, v reflect.Value) error {
return nil return nil
} }
func (d *decoder) unmarshalLocalTime(value *ast.Node, v reflect.Value) error { func (d *decoder) unmarshalLocalTime(value *unstable.Node, v reflect.Value) error {
lt, rest, err := parseLocalTime(value.Data) lt, rest, err := parseLocalTime(value.Data)
if err != nil { if err != nil {
return err return err
} }
if len(rest) > 0 { if len(rest) > 0 {
return newDecodeError(rest, "extra characters at the end of a local time") return unstable.NewParserError(rest, "extra characters at the end of a local time")
} }
v.Set(reflect.ValueOf(lt)) v.Set(reflect.ValueOf(lt))
return nil return nil
} }
func (d *decoder) unmarshalLocalDateTime(value *ast.Node, v reflect.Value) error { func (d *decoder) unmarshalLocalDateTime(value *unstable.Node, v reflect.Value) error {
ldt, rest, err := parseLocalDateTime(value.Data) ldt, rest, err := parseLocalDateTime(value.Data)
if err != nil { if err != nil {
return err return err
} }
if len(rest) > 0 { if len(rest) > 0 {
return newDecodeError(rest, "extra characters at the end of a local date time") return unstable.NewParserError(rest, "extra characters at the end of a local date time")
} }
if v.Type() == timeType { if v.Type() == timeType {
@ -828,7 +828,7 @@ func (d *decoder) unmarshalLocalDateTime(value *ast.Node, v reflect.Value) error
return nil return nil
} }
func (d *decoder) unmarshalBool(value *ast.Node, v reflect.Value) error { func (d *decoder) unmarshalBool(value *unstable.Node, v reflect.Value) error {
b := value.Data[0] == 't' b := value.Data[0] == 't'
switch v.Kind() { switch v.Kind() {
@ -837,13 +837,13 @@ func (d *decoder) unmarshalBool(value *ast.Node, v reflect.Value) error {
case reflect.Interface: case reflect.Interface:
v.Set(reflect.ValueOf(b)) v.Set(reflect.ValueOf(b))
default: default:
return newDecodeError(value.Data, "cannot assign boolean to a %t", b) return unstable.NewParserError(value.Data, "cannot assign boolean to a %t", b)
} }
return nil return nil
} }
func (d *decoder) unmarshalFloat(value *ast.Node, v reflect.Value) error { func (d *decoder) unmarshalFloat(value *unstable.Node, v reflect.Value) error {
f, err := parseFloat(value.Data) f, err := parseFloat(value.Data)
if err != nil { if err != nil {
return err return err
@ -854,13 +854,13 @@ func (d *decoder) unmarshalFloat(value *ast.Node, v reflect.Value) error {
v.SetFloat(f) v.SetFloat(f)
case reflect.Float32: case reflect.Float32:
if f > math.MaxFloat32 { if f > math.MaxFloat32 {
return newDecodeError(value.Data, "number %f does not fit in a float32", f) return unstable.NewParserError(value.Data, "number %f does not fit in a float32", f)
} }
v.SetFloat(f) v.SetFloat(f)
case reflect.Interface: case reflect.Interface:
v.Set(reflect.ValueOf(f)) v.Set(reflect.ValueOf(f))
default: default:
return newDecodeError(value.Data, "float cannot be assigned to %s", v.Kind()) return unstable.NewParserError(value.Data, "float cannot be assigned to %s", v.Kind())
} }
return nil return nil
@ -886,7 +886,7 @@ func init() {
} }
} }
func (d *decoder) unmarshalInteger(value *ast.Node, v reflect.Value) error { func (d *decoder) unmarshalInteger(value *unstable.Node, v reflect.Value) error {
i, err := parseInteger(value.Data) i, err := parseInteger(value.Data)
if err != nil { if err != nil {
return err return err
@ -967,20 +967,20 @@ func (d *decoder) unmarshalInteger(value *ast.Node, v reflect.Value) error {
return nil return nil
} }
func (d *decoder) unmarshalString(value *ast.Node, v reflect.Value) error { func (d *decoder) unmarshalString(value *unstable.Node, v reflect.Value) error {
switch v.Kind() { switch v.Kind() {
case reflect.String: case reflect.String:
v.SetString(string(value.Data)) v.SetString(string(value.Data))
case reflect.Interface: case reflect.Interface:
v.Set(reflect.ValueOf(string(value.Data))) v.Set(reflect.ValueOf(string(value.Data)))
default: default:
return newDecodeError(d.p.Raw(value.Raw), "cannot store TOML string into a Go %s", v.Kind()) return unstable.NewParserError(d.p.Raw(value.Raw), "cannot store TOML string into a Go %s", v.Kind())
} }
return nil return nil
} }
func (d *decoder) handleKeyValue(expr *ast.Node, v reflect.Value) (reflect.Value, error) { func (d *decoder) handleKeyValue(expr *unstable.Node, v reflect.Value) (reflect.Value, error) {
d.strict.EnterKeyValue(expr) d.strict.EnterKeyValue(expr)
v, err := d.handleKeyValueInner(expr.Key(), expr.Value(), v) v, err := d.handleKeyValueInner(expr.Key(), expr.Value(), v)
@ -994,7 +994,7 @@ func (d *decoder) handleKeyValue(expr *ast.Node, v reflect.Value) (reflect.Value
return v, err return v, err
} }
func (d *decoder) handleKeyValueInner(key ast.Iterator, value *ast.Node, v reflect.Value) (reflect.Value, error) { func (d *decoder) handleKeyValueInner(key unstable.Iterator, value *unstable.Node, v reflect.Value) (reflect.Value, error) {
if key.Next() { if key.Next() {
// Still scoping the key // Still scoping the key
return d.handleKeyValuePart(key, value, v) return d.handleKeyValuePart(key, value, v)
@ -1004,7 +1004,7 @@ func (d *decoder) handleKeyValueInner(key ast.Iterator, value *ast.Node, v refle
return reflect.Value{}, d.handleValue(value, v) return reflect.Value{}, d.handleValue(value, v)
} }
func (d *decoder) handleKeyValuePart(key ast.Iterator, value *ast.Node, v reflect.Value) (reflect.Value, error) { func (d *decoder) handleKeyValuePart(key unstable.Iterator, value *unstable.Node, v reflect.Value) (reflect.Value, error) {
// contains the replacement for v // contains the replacement for v
var rv reflect.Value var rv reflect.Value

View file

@ -1,4 +1,4 @@
package ast package unstable
import ( import (
"fmt" "fmt"
@ -7,13 +7,16 @@
"github.com/pelletier/go-toml/v2/internal/danger" "github.com/pelletier/go-toml/v2/internal/danger"
) )
// Iterator starts uninitialized, you need to call Next() first. // Iterator over a sequence of nodes.
//
// Starts uninitialized, you need to call Next() first.
// //
// For example: // For example:
// //
// it := n.Children() // it := n.Children()
// for it.Next() { // for it.Next() {
// it.Node() // n := it.Node()
// // do something with n
// } // }
type Iterator struct { type Iterator struct {
started bool started bool
@ -32,42 +35,31 @@ func (c *Iterator) Next() bool {
} }
// IsLast returns true if the current node of the iterator is the last // IsLast returns true if the current node of the iterator is the last
// one. Subsequent call to Next() will return false. // one. Subsequent calls to Next() will return false.
func (c *Iterator) IsLast() bool { func (c *Iterator) IsLast() bool {
return c.node.next == 0 return c.node.next == 0
} }
// Node returns a copy of the node pointed at by the iterator. // Node returns a pointer to the node pointed at by the iterator.
func (c *Iterator) Node() *Node { func (c *Iterator) Node() *Node {
return c.node return c.node
} }
// Root contains a full AST. // Node in a TOML expression AST.
// //
// It is immutable once constructed with Builder. // Depending on Kind, its sequence of children should be interpreted
type Root struct { // differently.
nodes []Node //
} // - Array have one child per element in the array.
// - InlineTable have one child per key-value in the table (each of kind
// Iterator over the top level nodes. // InlineTable).
func (r *Root) Iterator() Iterator { // - KeyValue have at least two children. The first one is the value. The rest
it := Iterator{} // make a potentially dotted key.
if len(r.nodes) > 0 { // - Table and ArrayTable's children represent a dotted key (same as
it.node = &r.nodes[0] // KeyValue, but without the first node being the value).
} //
return it // When relevant, Raw describes the range of bytes this node is refering to in
} // the input document. Use Parser.Raw() to retrieve the actual bytes.
func (r *Root) at(idx Reference) *Node {
return &r.nodes[idx]
}
// Arrays have one child per element in the array. InlineTables have
// one child per key-value pair in the table. KeyValues have at least
// two children. The first one is the value. The rest make a
// potentially dotted key. Table and Array table have one child per
// element of the key they represent (same as KeyValue, but without
// the last node being the value).
type Node struct { type Node struct {
Kind Kind Kind Kind
Raw Range // Raw bytes from the input. Raw Range // Raw bytes from the input.
@ -80,13 +72,13 @@ type Node struct {
child int // 0 if no child child int // 0 if no child
} }
// Range of bytes in the document.
type Range struct { type Range struct {
Offset uint32 Offset uint32
Length uint32 Length uint32
} }
// Next returns a copy of the next node, or an invalid Node if there // Next returns a pointer to the next node, or nil if there is no next node.
// is no next node.
func (n *Node) Next() *Node { func (n *Node) Next() *Node {
if n.next == 0 { if n.next == 0 {
return nil return nil
@ -96,9 +88,9 @@ func (n *Node) Next() *Node {
return (*Node)(danger.Stride(ptr, size, n.next)) return (*Node)(danger.Stride(ptr, size, n.next))
} }
// Child returns a copy of the first child node of this node. Other // Child returns a pointer to the first child node of this node. Other children
// children can be accessed calling Next on the first child. Returns // can be accessed calling Next on the first child. Returns an nil if this Node
// an invalid Node if there is none. // has no child.
func (n *Node) Child() *Node { func (n *Node) Child() *Node {
if n.child == 0 { if n.child == 0 {
return nil return nil
@ -113,9 +105,9 @@ func (n *Node) Valid() bool {
return n != nil return n != nil
} }
// Key returns the child nodes making the Key on a supported // Key returns the children nodes making the Key on a supported node. Panics
// node. Panics otherwise. They are guaranteed to be all be of the // otherwise. They are guaranteed to be all be of the Kind Key. A simple key
// Kind Key. A simple key would return just one element. // would return just one element.
func (n *Node) Key() Iterator { func (n *Node) Key() Iterator {
switch n.Kind { switch n.Kind {
case KeyValue: case KeyValue:

View file

@ -0,0 +1,71 @@
package unstable
// root contains a full AST.
//
// It is immutable once constructed with Builder.
type root struct {
nodes []Node
}
// Iterator over the top level nodes.
func (r *root) Iterator() Iterator {
it := Iterator{}
if len(r.nodes) > 0 {
it.node = &r.nodes[0]
}
return it
}
func (r *root) at(idx reference) *Node {
return &r.nodes[idx]
}
type reference int
const invalidReference reference = -1
func (r reference) Valid() bool {
return r != invalidReference
}
type builder struct {
tree root
lastIdx int
}
func (b *builder) Tree() *root {
return &b.tree
}
func (b *builder) NodeAt(ref reference) *Node {
return b.tree.at(ref)
}
func (b *builder) Reset() {
b.tree.nodes = b.tree.nodes[:0]
b.lastIdx = 0
}
func (b *builder) Push(n Node) reference {
b.lastIdx = len(b.tree.nodes)
b.tree.nodes = append(b.tree.nodes, n)
return reference(b.lastIdx)
}
func (b *builder) PushAndChain(n Node) reference {
newIdx := len(b.tree.nodes)
b.tree.nodes = append(b.tree.nodes, n)
if b.lastIdx >= 0 {
b.tree.nodes[b.lastIdx].next = newIdx - b.lastIdx
}
b.lastIdx = newIdx
return reference(b.lastIdx)
}
func (b *builder) AttachChild(parent reference, child reference) {
b.tree.nodes[parent].child = int(child) - int(parent)
}
func (b *builder) Chain(from reference, to reference) {
b.tree.nodes[from].next = int(to) - int(from)
}

View file

@ -0,0 +1,3 @@
// Package unstable provides APIs that do not meet the backward compatibility
// guarantees yet.
package unstable

View file

@ -1,25 +1,26 @@
package ast package unstable
import "fmt" import "fmt"
// Kind represents the type of TOML structure contained in a given Node.
type Kind int type Kind int
const ( const (
// meta // Meta
Invalid Kind = iota Invalid Kind = iota
Comment Comment
Key Key
// top level structures // Top level structures
Table Table
ArrayTable ArrayTable
KeyValue KeyValue
// containers values // Containers values
Array Array
InlineTable InlineTable
// values // Values
String String
Bool Bool
Float Float
@ -30,6 +31,7 @@
DateTime DateTime
) )
// String implementation of fmt.Stringer.
func (k Kind) String() string { func (k Kind) String() string {
switch k { switch k {
case Invalid: case Invalid:

View file

@ -1,50 +1,108 @@
package toml package unstable
import ( import (
"bytes" "bytes"
"fmt"
"unicode" "unicode"
"github.com/pelletier/go-toml/v2/internal/ast" "github.com/pelletier/go-toml/v2/internal/characters"
"github.com/pelletier/go-toml/v2/internal/danger" "github.com/pelletier/go-toml/v2/internal/danger"
) )
type parser struct { // ParserError describes an error relative to the content of the document.
builder ast.Builder //
ref ast.Reference // It cannot outlive the instance of Parser it refers to, and may cause panics
// if the parser is reset.
type ParserError struct {
Highlight []byte
Message string
Key []string // optional
}
// Error is the implementation of the error interface.
func (e *ParserError) Error() string {
return e.Message
}
// NewParserError is a convenience function to create a ParserError
//
// Warning: Highlight needs to be a subslice of Parser.data, so only slices
// returned by Parser.Raw are valid candidates.
func NewParserError(highlight []byte, format string, args ...interface{}) error {
return &ParserError{
Highlight: highlight,
Message: fmt.Errorf(format, args...).Error(),
}
}
// Parser scans over a TOML-encoded document and generates an iterative AST.
//
// To prime the Parser, first reset it with the contents of a TOML document.
// Then, process all top-level expressions sequentially. See Example.
//
// Don't forget to check Error() after you're done parsing.
//
// Each top-level expression needs to be fully processed before calling
// NextExpression() again. Otherwise, calls to various Node methods may panic if
// the parser has moved on the next expression.
//
// For performance reasons, go-toml doesn't make a copy of the input bytes to
// the parser. Make sure to copy all the bytes you need to outlive the slice
// given to the parser.
//
// The parser doesn't provide nodes for comments yet, nor for whitespace.
type Parser struct {
data []byte data []byte
builder builder
ref reference
left []byte left []byte
err error err error
first bool first bool
} }
func (p *parser) Range(b []byte) ast.Range { // Data returns the slice provided to the last call to Reset.
return ast.Range{ func (p *Parser) Data() []byte {
return p.data
}
// Range returns a range description that corresponds to a given slice of the
// input. If the argument is not a subslice of the parser input, this function
// panics.
func (p *Parser) Range(b []byte) Range {
return Range{
Offset: uint32(danger.SubsliceOffset(p.data, b)), Offset: uint32(danger.SubsliceOffset(p.data, b)),
Length: uint32(len(b)), Length: uint32(len(b)),
} }
} }
func (p *parser) Raw(raw ast.Range) []byte { // Raw returns the slice corresponding to the bytes in the given range.
func (p *Parser) Raw(raw Range) []byte {
return p.data[raw.Offset : raw.Offset+raw.Length] return p.data[raw.Offset : raw.Offset+raw.Length]
} }
func (p *parser) Reset(b []byte) { // Reset brings the parser to its initial state for a given input. It wipes an
// reuses internal storage to reduce allocation.
func (p *Parser) Reset(b []byte) {
p.builder.Reset() p.builder.Reset()
p.ref = ast.InvalidReference p.ref = invalidReference
p.data = b p.data = b
p.left = b p.left = b
p.err = nil p.err = nil
p.first = true p.first = true
} }
//nolint:cyclop // NextExpression parses the next top-level expression. If an expression was
func (p *parser) NextExpression() bool { // successfully parsed, it returns true. If the parser is at the end of the
// document or an error occurred, it returns false.
//
// Retrieve the parsed expression with Expression().
func (p *Parser) NextExpression() bool {
if len(p.left) == 0 || p.err != nil { if len(p.left) == 0 || p.err != nil {
return false return false
} }
p.builder.Reset() p.builder.Reset()
p.ref = ast.InvalidReference p.ref = invalidReference
for { for {
if len(p.left) == 0 || p.err != nil { if len(p.left) == 0 || p.err != nil {
@ -73,15 +131,18 @@ func (p *parser) NextExpression() bool {
} }
} }
func (p *parser) Expression() *ast.Node { // Expression returns a pointer to the node representing the last successfully
// parsed expresion.
func (p *Parser) Expression() *Node {
return p.builder.NodeAt(p.ref) return p.builder.NodeAt(p.ref)
} }
func (p *parser) Error() error { // Error returns any error that has occured during parsing.
func (p *Parser) Error() error {
return p.err return p.err
} }
func (p *parser) parseNewline(b []byte) ([]byte, error) { func (p *Parser) parseNewline(b []byte) ([]byte, error) {
if b[0] == '\n' { if b[0] == '\n' {
return b[1:], nil return b[1:], nil
} }
@ -91,14 +152,14 @@ func (p *parser) parseNewline(b []byte) ([]byte, error) {
return rest, err return rest, err
} }
return nil, newDecodeError(b[0:1], "expected newline but got %#U", b[0]) return nil, NewParserError(b[0:1], "expected newline but got %#U", b[0])
} }
func (p *parser) parseExpression(b []byte) (ast.Reference, []byte, error) { func (p *Parser) parseExpression(b []byte) (reference, []byte, error) {
// expression = ws [ comment ] // expression = ws [ comment ]
// expression =/ ws keyval ws [ comment ] // expression =/ ws keyval ws [ comment ]
// expression =/ ws table ws [ comment ] // expression =/ ws table ws [ comment ]
ref := ast.InvalidReference ref := invalidReference
b = p.parseWhitespace(b) b = p.parseWhitespace(b)
@ -136,7 +197,7 @@ func (p *parser) parseExpression(b []byte) (ast.Reference, []byte, error) {
return ref, b, nil return ref, b, nil
} }
func (p *parser) parseTable(b []byte) (ast.Reference, []byte, error) { func (p *Parser) parseTable(b []byte) (reference, []byte, error) {
// table = std-table / array-table // table = std-table / array-table
if len(b) > 1 && b[1] == '[' { if len(b) > 1 && b[1] == '[' {
return p.parseArrayTable(b) return p.parseArrayTable(b)
@ -145,12 +206,12 @@ func (p *parser) parseTable(b []byte) (ast.Reference, []byte, error) {
return p.parseStdTable(b) return p.parseStdTable(b)
} }
func (p *parser) parseArrayTable(b []byte) (ast.Reference, []byte, error) { func (p *Parser) parseArrayTable(b []byte) (reference, []byte, error) {
// array-table = array-table-open key array-table-close // array-table = array-table-open key array-table-close
// array-table-open = %x5B.5B ws ; [[ Double left square bracket // array-table-open = %x5B.5B ws ; [[ Double left square bracket
// array-table-close = ws %x5D.5D ; ]] Double right square bracket // array-table-close = ws %x5D.5D ; ]] Double right square bracket
ref := p.builder.Push(ast.Node{ ref := p.builder.Push(Node{
Kind: ast.ArrayTable, Kind: ArrayTable,
}) })
b = b[2:] b = b[2:]
@ -174,12 +235,12 @@ func (p *parser) parseArrayTable(b []byte) (ast.Reference, []byte, error) {
return ref, b, err return ref, b, err
} }
func (p *parser) parseStdTable(b []byte) (ast.Reference, []byte, error) { func (p *Parser) parseStdTable(b []byte) (reference, []byte, error) {
// std-table = std-table-open key std-table-close // std-table = std-table-open key std-table-close
// std-table-open = %x5B ws ; [ Left square bracket // std-table-open = %x5B ws ; [ Left square bracket
// std-table-close = ws %x5D ; ] Right square bracket // std-table-close = ws %x5D ; ] Right square bracket
ref := p.builder.Push(ast.Node{ ref := p.builder.Push(Node{
Kind: ast.Table, Kind: Table,
}) })
b = b[1:] b = b[1:]
@ -199,15 +260,15 @@ func (p *parser) parseStdTable(b []byte) (ast.Reference, []byte, error) {
return ref, b, err return ref, b, err
} }
func (p *parser) parseKeyval(b []byte) (ast.Reference, []byte, error) { func (p *Parser) parseKeyval(b []byte) (reference, []byte, error) {
// keyval = key keyval-sep val // keyval = key keyval-sep val
ref := p.builder.Push(ast.Node{ ref := p.builder.Push(Node{
Kind: ast.KeyValue, Kind: KeyValue,
}) })
key, b, err := p.parseKey(b) key, b, err := p.parseKey(b)
if err != nil { if err != nil {
return ast.InvalidReference, nil, err return invalidReference, nil, err
} }
// keyval-sep = ws %x3D ws ; = // keyval-sep = ws %x3D ws ; =
@ -215,12 +276,12 @@ func (p *parser) parseKeyval(b []byte) (ast.Reference, []byte, error) {
b = p.parseWhitespace(b) b = p.parseWhitespace(b)
if len(b) == 0 { if len(b) == 0 {
return ast.InvalidReference, nil, newDecodeError(b, "expected = after a key, but the document ends there") return invalidReference, nil, NewParserError(b, "expected = after a key, but the document ends there")
} }
b, err = expect('=', b) b, err = expect('=', b)
if err != nil { if err != nil {
return ast.InvalidReference, nil, err return invalidReference, nil, err
} }
b = p.parseWhitespace(b) b = p.parseWhitespace(b)
@ -237,12 +298,12 @@ func (p *parser) parseKeyval(b []byte) (ast.Reference, []byte, error) {
} }
//nolint:cyclop,funlen //nolint:cyclop,funlen
func (p *parser) parseVal(b []byte) (ast.Reference, []byte, error) { func (p *Parser) parseVal(b []byte) (reference, []byte, error) {
// val = string / boolean / array / inline-table / date-time / float / integer // val = string / boolean / array / inline-table / date-time / float / integer
ref := ast.InvalidReference ref := invalidReference
if len(b) == 0 { if len(b) == 0 {
return ref, nil, newDecodeError(b, "expected value, not eof") return ref, nil, NewParserError(b, "expected value, not eof")
} }
var err error var err error
@ -259,8 +320,8 @@ func (p *parser) parseVal(b []byte) (ast.Reference, []byte, error) {
} }
if err == nil { if err == nil {
ref = p.builder.Push(ast.Node{ ref = p.builder.Push(Node{
Kind: ast.String, Kind: String,
Raw: p.Range(raw), Raw: p.Range(raw),
Data: v, Data: v,
}) })
@ -277,8 +338,8 @@ func (p *parser) parseVal(b []byte) (ast.Reference, []byte, error) {
} }
if err == nil { if err == nil {
ref = p.builder.Push(ast.Node{ ref = p.builder.Push(Node{
Kind: ast.String, Kind: String,
Raw: p.Range(raw), Raw: p.Range(raw),
Data: v, Data: v,
}) })
@ -287,22 +348,22 @@ func (p *parser) parseVal(b []byte) (ast.Reference, []byte, error) {
return ref, b, err return ref, b, err
case 't': case 't':
if !scanFollowsTrue(b) { if !scanFollowsTrue(b) {
return ref, nil, newDecodeError(atmost(b, 4), "expected 'true'") return ref, nil, NewParserError(atmost(b, 4), "expected 'true'")
} }
ref = p.builder.Push(ast.Node{ ref = p.builder.Push(Node{
Kind: ast.Bool, Kind: Bool,
Data: b[:4], Data: b[:4],
}) })
return ref, b[4:], nil return ref, b[4:], nil
case 'f': case 'f':
if !scanFollowsFalse(b) { if !scanFollowsFalse(b) {
return ref, nil, newDecodeError(atmost(b, 5), "expected 'false'") return ref, nil, NewParserError(atmost(b, 5), "expected 'false'")
} }
ref = p.builder.Push(ast.Node{ ref = p.builder.Push(Node{
Kind: ast.Bool, Kind: Bool,
Data: b[:5], Data: b[:5],
}) })
@ -324,7 +385,7 @@ func atmost(b []byte, n int) []byte {
return b[:n] return b[:n]
} }
func (p *parser) parseLiteralString(b []byte) ([]byte, []byte, []byte, error) { func (p *Parser) parseLiteralString(b []byte) ([]byte, []byte, []byte, error) {
v, rest, err := scanLiteralString(b) v, rest, err := scanLiteralString(b)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
@ -333,19 +394,19 @@ func (p *parser) parseLiteralString(b []byte) ([]byte, []byte, []byte, error) {
return v, v[1 : len(v)-1], rest, nil return v, v[1 : len(v)-1], rest, nil
} }
func (p *parser) parseInlineTable(b []byte) (ast.Reference, []byte, error) { func (p *Parser) parseInlineTable(b []byte) (reference, []byte, error) {
// inline-table = inline-table-open [ inline-table-keyvals ] inline-table-close // inline-table = inline-table-open [ inline-table-keyvals ] inline-table-close
// inline-table-open = %x7B ws ; { // inline-table-open = %x7B ws ; {
// inline-table-close = ws %x7D ; } // inline-table-close = ws %x7D ; }
// inline-table-sep = ws %x2C ws ; , Comma // inline-table-sep = ws %x2C ws ; , Comma
// inline-table-keyvals = keyval [ inline-table-sep inline-table-keyvals ] // inline-table-keyvals = keyval [ inline-table-sep inline-table-keyvals ]
parent := p.builder.Push(ast.Node{ parent := p.builder.Push(Node{
Kind: ast.InlineTable, Kind: InlineTable,
}) })
first := true first := true
var child ast.Reference var child reference
b = b[1:] b = b[1:]
@ -356,7 +417,7 @@ func (p *parser) parseInlineTable(b []byte) (ast.Reference, []byte, error) {
b = p.parseWhitespace(b) b = p.parseWhitespace(b)
if len(b) == 0 { if len(b) == 0 {
return parent, nil, newDecodeError(previousB[:1], "inline table is incomplete") return parent, nil, NewParserError(previousB[:1], "inline table is incomplete")
} }
if b[0] == '}' { if b[0] == '}' {
@ -371,7 +432,7 @@ func (p *parser) parseInlineTable(b []byte) (ast.Reference, []byte, error) {
b = p.parseWhitespace(b) b = p.parseWhitespace(b)
} }
var kv ast.Reference var kv reference
kv, b, err = p.parseKeyval(b) kv, b, err = p.parseKeyval(b)
if err != nil { if err != nil {
@ -394,7 +455,7 @@ func (p *parser) parseInlineTable(b []byte) (ast.Reference, []byte, error) {
} }
//nolint:funlen,cyclop //nolint:funlen,cyclop
func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) { func (p *Parser) parseValArray(b []byte) (reference, []byte, error) {
// array = array-open [ array-values ] ws-comment-newline array-close // array = array-open [ array-values ] ws-comment-newline array-close
// array-open = %x5B ; [ // array-open = %x5B ; [
// array-close = %x5D ; ] // array-close = %x5D ; ]
@ -405,13 +466,13 @@ func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) {
arrayStart := b arrayStart := b
b = b[1:] b = b[1:]
parent := p.builder.Push(ast.Node{ parent := p.builder.Push(Node{
Kind: ast.Array, Kind: Array,
}) })
first := true first := true
var lastChild ast.Reference var lastChild reference
var err error var err error
for len(b) > 0 { for len(b) > 0 {
@ -421,7 +482,7 @@ func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) {
} }
if len(b) == 0 { if len(b) == 0 {
return parent, nil, newDecodeError(arrayStart[:1], "array is incomplete") return parent, nil, NewParserError(arrayStart[:1], "array is incomplete")
} }
if b[0] == ']' { if b[0] == ']' {
@ -430,7 +491,7 @@ func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) {
if b[0] == ',' { if b[0] == ',' {
if first { if first {
return parent, nil, newDecodeError(b[0:1], "array cannot start with comma") return parent, nil, NewParserError(b[0:1], "array cannot start with comma")
} }
b = b[1:] b = b[1:]
@ -439,7 +500,7 @@ func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) {
return parent, nil, err return parent, nil, err
} }
} else if !first { } else if !first {
return parent, nil, newDecodeError(b[0:1], "array elements must be separated by commas") return parent, nil, NewParserError(b[0:1], "array elements must be separated by commas")
} }
// TOML allows trailing commas in arrays. // TOML allows trailing commas in arrays.
@ -447,7 +508,7 @@ func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) {
break break
} }
var valueRef ast.Reference var valueRef reference
valueRef, b, err = p.parseVal(b) valueRef, b, err = p.parseVal(b)
if err != nil { if err != nil {
return parent, nil, err return parent, nil, err
@ -472,7 +533,7 @@ func (p *parser) parseValArray(b []byte) (ast.Reference, []byte, error) {
return parent, rest, err return parent, rest, err
} }
func (p *parser) parseOptionalWhitespaceCommentNewline(b []byte) ([]byte, error) { func (p *Parser) parseOptionalWhitespaceCommentNewline(b []byte) ([]byte, error) {
for len(b) > 0 { for len(b) > 0 {
var err error var err error
b = p.parseWhitespace(b) b = p.parseWhitespace(b)
@ -501,7 +562,7 @@ func (p *parser) parseOptionalWhitespaceCommentNewline(b []byte) ([]byte, error)
return b, nil return b, nil
} }
func (p *parser) parseMultilineLiteralString(b []byte) ([]byte, []byte, []byte, error) { func (p *Parser) parseMultilineLiteralString(b []byte) ([]byte, []byte, []byte, error) {
token, rest, err := scanMultilineLiteralString(b) token, rest, err := scanMultilineLiteralString(b)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
@ -520,7 +581,7 @@ func (p *parser) parseMultilineLiteralString(b []byte) ([]byte, []byte, []byte,
} }
//nolint:funlen,gocognit,cyclop //nolint:funlen,gocognit,cyclop
func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, []byte, error) { func (p *Parser) parseMultilineBasicString(b []byte) ([]byte, []byte, []byte, error) {
// ml-basic-string = ml-basic-string-delim [ newline ] ml-basic-body // ml-basic-string = ml-basic-string-delim [ newline ] ml-basic-body
// ml-basic-string-delim // ml-basic-string-delim
// ml-basic-string-delim = 3quotation-mark // ml-basic-string-delim = 3quotation-mark
@ -551,11 +612,11 @@ func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, []byte, er
if !escaped { if !escaped {
str := token[startIdx:endIdx] str := token[startIdx:endIdx]
verr := utf8TomlValidAlreadyEscaped(str) verr := characters.Utf8TomlValidAlreadyEscaped(str)
if verr.Zero() { if verr.Zero() {
return token, str, rest, nil return token, str, rest, nil
} }
return nil, nil, nil, newDecodeError(str[verr.Index:verr.Index+verr.Size], "invalid UTF-8") return nil, nil, nil, NewParserError(str[verr.Index:verr.Index+verr.Size], "invalid UTF-8")
} }
var builder bytes.Buffer var builder bytes.Buffer
@ -635,13 +696,13 @@ func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, []byte, er
builder.WriteRune(x) builder.WriteRune(x)
i += 8 i += 8
default: default:
return nil, nil, nil, newDecodeError(token[i:i+1], "invalid escaped character %#U", c) return nil, nil, nil, NewParserError(token[i:i+1], "invalid escaped character %#U", c)
} }
i++ i++
} else { } else {
size := utf8ValidNext(token[i:]) size := characters.Utf8ValidNext(token[i:])
if size == 0 { if size == 0 {
return nil, nil, nil, newDecodeError(token[i:i+1], "invalid character %#U", c) return nil, nil, nil, NewParserError(token[i:i+1], "invalid character %#U", c)
} }
builder.Write(token[i : i+size]) builder.Write(token[i : i+size])
i += size i += size
@ -651,7 +712,7 @@ func (p *parser) parseMultilineBasicString(b []byte) ([]byte, []byte, []byte, er
return token, builder.Bytes(), rest, nil return token, builder.Bytes(), rest, nil
} }
func (p *parser) parseKey(b []byte) (ast.Reference, []byte, error) { func (p *Parser) parseKey(b []byte) (reference, []byte, error) {
// key = simple-key / dotted-key // key = simple-key / dotted-key
// simple-key = quoted-key / unquoted-key // simple-key = quoted-key / unquoted-key
// //
@ -662,11 +723,11 @@ func (p *parser) parseKey(b []byte) (ast.Reference, []byte, error) {
// dot-sep = ws %x2E ws ; . Period // dot-sep = ws %x2E ws ; . Period
raw, key, b, err := p.parseSimpleKey(b) raw, key, b, err := p.parseSimpleKey(b)
if err != nil { if err != nil {
return ast.InvalidReference, nil, err return invalidReference, nil, err
} }
ref := p.builder.Push(ast.Node{ ref := p.builder.Push(Node{
Kind: ast.Key, Kind: Key,
Raw: p.Range(raw), Raw: p.Range(raw),
Data: key, Data: key,
}) })
@ -681,8 +742,8 @@ func (p *parser) parseKey(b []byte) (ast.Reference, []byte, error) {
return ref, nil, err return ref, nil, err
} }
p.builder.PushAndChain(ast.Node{ p.builder.PushAndChain(Node{
Kind: ast.Key, Kind: Key,
Raw: p.Range(raw), Raw: p.Range(raw),
Data: key, Data: key,
}) })
@ -694,9 +755,9 @@ func (p *parser) parseKey(b []byte) (ast.Reference, []byte, error) {
return ref, b, nil return ref, b, nil
} }
func (p *parser) parseSimpleKey(b []byte) (raw, key, rest []byte, err error) { func (p *Parser) parseSimpleKey(b []byte) (raw, key, rest []byte, err error) {
if len(b) == 0 { if len(b) == 0 {
return nil, nil, nil, newDecodeError(b, "expected key but found none") return nil, nil, nil, NewParserError(b, "expected key but found none")
} }
// simple-key = quoted-key / unquoted-key // simple-key = quoted-key / unquoted-key
@ -711,12 +772,12 @@ func (p *parser) parseSimpleKey(b []byte) (raw, key, rest []byte, err error) {
key, rest = scanUnquotedKey(b) key, rest = scanUnquotedKey(b)
return key, key, rest, nil return key, key, rest, nil
default: default:
return nil, nil, nil, newDecodeError(b[0:1], "invalid character at start of key: %c", b[0]) return nil, nil, nil, NewParserError(b[0:1], "invalid character at start of key: %c", b[0])
} }
} }
//nolint:funlen,cyclop //nolint:funlen,cyclop
func (p *parser) parseBasicString(b []byte) ([]byte, []byte, []byte, error) { func (p *Parser) parseBasicString(b []byte) ([]byte, []byte, []byte, error) {
// basic-string = quotation-mark *basic-char quotation-mark // basic-string = quotation-mark *basic-char quotation-mark
// quotation-mark = %x22 ; " // quotation-mark = %x22 ; "
// basic-char = basic-unescaped / escaped // basic-char = basic-unescaped / escaped
@ -744,11 +805,11 @@ func (p *parser) parseBasicString(b []byte) ([]byte, []byte, []byte, error) {
// validate the string and return a direct reference to the buffer. // validate the string and return a direct reference to the buffer.
if !escaped { if !escaped {
str := token[startIdx:endIdx] str := token[startIdx:endIdx]
verr := utf8TomlValidAlreadyEscaped(str) verr := characters.Utf8TomlValidAlreadyEscaped(str)
if verr.Zero() { if verr.Zero() {
return token, str, rest, nil return token, str, rest, nil
} }
return nil, nil, nil, newDecodeError(str[verr.Index:verr.Index+verr.Size], "invalid UTF-8") return nil, nil, nil, NewParserError(str[verr.Index:verr.Index+verr.Size], "invalid UTF-8")
} }
i := startIdx i := startIdx
@ -795,13 +856,13 @@ func (p *parser) parseBasicString(b []byte) ([]byte, []byte, []byte, error) {
builder.WriteRune(x) builder.WriteRune(x)
i += 8 i += 8
default: default:
return nil, nil, nil, newDecodeError(token[i:i+1], "invalid escaped character %#U", c) return nil, nil, nil, NewParserError(token[i:i+1], "invalid escaped character %#U", c)
} }
i++ i++
} else { } else {
size := utf8ValidNext(token[i:]) size := characters.Utf8ValidNext(token[i:])
if size == 0 { if size == 0 {
return nil, nil, nil, newDecodeError(token[i:i+1], "invalid character %#U", c) return nil, nil, nil, NewParserError(token[i:i+1], "invalid character %#U", c)
} }
builder.Write(token[i : i+size]) builder.Write(token[i : i+size])
i += size i += size
@ -813,7 +874,7 @@ func (p *parser) parseBasicString(b []byte) ([]byte, []byte, []byte, error) {
func hexToRune(b []byte, length int) (rune, error) { func hexToRune(b []byte, length int) (rune, error) {
if len(b) < length { if len(b) < length {
return -1, newDecodeError(b, "unicode point needs %d character, not %d", length, len(b)) return -1, NewParserError(b, "unicode point needs %d character, not %d", length, len(b))
} }
b = b[:length] b = b[:length]
@ -828,19 +889,19 @@ func hexToRune(b []byte, length int) (rune, error) {
case 'A' <= c && c <= 'F': case 'A' <= c && c <= 'F':
d = uint32(c - 'A' + 10) d = uint32(c - 'A' + 10)
default: default:
return -1, newDecodeError(b[i:i+1], "non-hex character") return -1, NewParserError(b[i:i+1], "non-hex character")
} }
r = r*16 + d r = r*16 + d
} }
if r > unicode.MaxRune || 0xD800 <= r && r < 0xE000 { if r > unicode.MaxRune || 0xD800 <= r && r < 0xE000 {
return -1, newDecodeError(b, "escape sequence is invalid Unicode code point") return -1, NewParserError(b, "escape sequence is invalid Unicode code point")
} }
return rune(r), nil return rune(r), nil
} }
func (p *parser) parseWhitespace(b []byte) []byte { func (p *Parser) parseWhitespace(b []byte) []byte {
// ws = *wschar // ws = *wschar
// wschar = %x20 ; Space // wschar = %x20 ; Space
// wschar =/ %x09 ; Horizontal tab // wschar =/ %x09 ; Horizontal tab
@ -850,24 +911,24 @@ func (p *parser) parseWhitespace(b []byte) []byte {
} }
//nolint:cyclop //nolint:cyclop
func (p *parser) parseIntOrFloatOrDateTime(b []byte) (ast.Reference, []byte, error) { func (p *Parser) parseIntOrFloatOrDateTime(b []byte) (reference, []byte, error) {
switch b[0] { switch b[0] {
case 'i': case 'i':
if !scanFollowsInf(b) { if !scanFollowsInf(b) {
return ast.InvalidReference, nil, newDecodeError(atmost(b, 3), "expected 'inf'") return invalidReference, nil, NewParserError(atmost(b, 3), "expected 'inf'")
} }
return p.builder.Push(ast.Node{ return p.builder.Push(Node{
Kind: ast.Float, Kind: Float,
Data: b[:3], Data: b[:3],
}), b[3:], nil }), b[3:], nil
case 'n': case 'n':
if !scanFollowsNan(b) { if !scanFollowsNan(b) {
return ast.InvalidReference, nil, newDecodeError(atmost(b, 3), "expected 'nan'") return invalidReference, nil, NewParserError(atmost(b, 3), "expected 'nan'")
} }
return p.builder.Push(ast.Node{ return p.builder.Push(Node{
Kind: ast.Float, Kind: Float,
Data: b[:3], Data: b[:3],
}), b[3:], nil }), b[3:], nil
case '+', '-': case '+', '-':
@ -898,7 +959,7 @@ func (p *parser) parseIntOrFloatOrDateTime(b []byte) (ast.Reference, []byte, err
return p.scanIntOrFloat(b) return p.scanIntOrFloat(b)
} }
func (p *parser) scanDateTime(b []byte) (ast.Reference, []byte, error) { func (p *Parser) scanDateTime(b []byte) (reference, []byte, error) {
// scans for contiguous characters in [0-9T:Z.+-], and up to one space if // scans for contiguous characters in [0-9T:Z.+-], and up to one space if
// followed by a digit. // followed by a digit.
hasDate := false hasDate := false
@ -941,30 +1002,30 @@ func (p *parser) scanDateTime(b []byte) (ast.Reference, []byte, error) {
} }
} }
var kind ast.Kind var kind Kind
if hasTime { if hasTime {
if hasDate { if hasDate {
if hasTz { if hasTz {
kind = ast.DateTime kind = DateTime
} else { } else {
kind = ast.LocalDateTime kind = LocalDateTime
} }
} else { } else {
kind = ast.LocalTime kind = LocalTime
} }
} else { } else {
kind = ast.LocalDate kind = LocalDate
} }
return p.builder.Push(ast.Node{ return p.builder.Push(Node{
Kind: kind, Kind: kind,
Data: b[:i], Data: b[:i],
}), b[i:], nil }), b[i:], nil
} }
//nolint:funlen,gocognit,cyclop //nolint:funlen,gocognit,cyclop
func (p *parser) scanIntOrFloat(b []byte) (ast.Reference, []byte, error) { func (p *Parser) scanIntOrFloat(b []byte) (reference, []byte, error) {
i := 0 i := 0
if len(b) > 2 && b[0] == '0' && b[1] != '.' && b[1] != 'e' && b[1] != 'E' { if len(b) > 2 && b[0] == '0' && b[1] != '.' && b[1] != 'e' && b[1] != 'E' {
@ -990,8 +1051,8 @@ func (p *parser) scanIntOrFloat(b []byte) (ast.Reference, []byte, error) {
} }
} }
return p.builder.Push(ast.Node{ return p.builder.Push(Node{
Kind: ast.Integer, Kind: Integer,
Data: b[:i], Data: b[:i],
}), b[i:], nil }), b[i:], nil
} }
@ -1013,40 +1074,40 @@ func (p *parser) scanIntOrFloat(b []byte) (ast.Reference, []byte, error) {
if c == 'i' { if c == 'i' {
if scanFollowsInf(b[i:]) { if scanFollowsInf(b[i:]) {
return p.builder.Push(ast.Node{ return p.builder.Push(Node{
Kind: ast.Float, Kind: Float,
Data: b[:i+3], Data: b[:i+3],
}), b[i+3:], nil }), b[i+3:], nil
} }
return ast.InvalidReference, nil, newDecodeError(b[i:i+1], "unexpected character 'i' while scanning for a number") return invalidReference, nil, NewParserError(b[i:i+1], "unexpected character 'i' while scanning for a number")
} }
if c == 'n' { if c == 'n' {
if scanFollowsNan(b[i:]) { if scanFollowsNan(b[i:]) {
return p.builder.Push(ast.Node{ return p.builder.Push(Node{
Kind: ast.Float, Kind: Float,
Data: b[:i+3], Data: b[:i+3],
}), b[i+3:], nil }), b[i+3:], nil
} }
return ast.InvalidReference, nil, newDecodeError(b[i:i+1], "unexpected character 'n' while scanning for a number") return invalidReference, nil, NewParserError(b[i:i+1], "unexpected character 'n' while scanning for a number")
} }
break break
} }
if i == 0 { if i == 0 {
return ast.InvalidReference, b, newDecodeError(b, "incomplete number") return invalidReference, b, NewParserError(b, "incomplete number")
} }
kind := ast.Integer kind := Integer
if isFloat { if isFloat {
kind = ast.Float kind = Float
} }
return p.builder.Push(ast.Node{ return p.builder.Push(Node{
Kind: kind, Kind: kind,
Data: b[:i], Data: b[:i],
}), b[i:], nil }), b[i:], nil
@ -1075,11 +1136,11 @@ func isValidBinaryRune(r byte) bool {
func expect(x byte, b []byte) ([]byte, error) { func expect(x byte, b []byte) ([]byte, error) {
if len(b) == 0 { if len(b) == 0 {
return nil, newDecodeError(b, "expected character %c but the document ended here", x) return nil, NewParserError(b, "expected character %c but the document ended here", x)
} }
if b[0] != x { if b[0] != x {
return nil, newDecodeError(b[0:1], "expected character %c", x) return nil, NewParserError(b[0:1], "expected character %c", x)
} }
return b[1:], nil return b[1:], nil

View file

@ -1,4 +1,6 @@
package toml package unstable
import "github.com/pelletier/go-toml/v2/internal/characters"
func scanFollows(b []byte, pattern string) bool { func scanFollows(b []byte, pattern string) bool {
n := len(pattern) n := len(pattern)
@ -54,16 +56,16 @@ func scanLiteralString(b []byte) ([]byte, []byte, error) {
case '\'': case '\'':
return b[:i+1], b[i+1:], nil return b[:i+1], b[i+1:], nil
case '\n', '\r': case '\n', '\r':
return nil, nil, newDecodeError(b[i:i+1], "literal strings cannot have new lines") return nil, nil, NewParserError(b[i:i+1], "literal strings cannot have new lines")
} }
size := utf8ValidNext(b[i:]) size := characters.Utf8ValidNext(b[i:])
if size == 0 { if size == 0 {
return nil, nil, newDecodeError(b[i:i+1], "invalid character") return nil, nil, NewParserError(b[i:i+1], "invalid character")
} }
i += size i += size
} }
return nil, nil, newDecodeError(b[len(b):], "unterminated literal string") return nil, nil, NewParserError(b[len(b):], "unterminated literal string")
} }
func scanMultilineLiteralString(b []byte) ([]byte, []byte, error) { func scanMultilineLiteralString(b []byte) ([]byte, []byte, error) {
@ -98,39 +100,39 @@ func scanMultilineLiteralString(b []byte) ([]byte, []byte, error) {
i++ i++
if i < len(b) && b[i] == '\'' { if i < len(b) && b[i] == '\'' {
return nil, nil, newDecodeError(b[i-3:i+1], "''' not allowed in multiline literal string") return nil, nil, NewParserError(b[i-3:i+1], "''' not allowed in multiline literal string")
} }
return b[:i], b[i:], nil return b[:i], b[i:], nil
} }
case '\r': case '\r':
if len(b) < i+2 { if len(b) < i+2 {
return nil, nil, newDecodeError(b[len(b):], `need a \n after \r`) return nil, nil, NewParserError(b[len(b):], `need a \n after \r`)
} }
if b[i+1] != '\n' { if b[i+1] != '\n' {
return nil, nil, newDecodeError(b[i:i+2], `need a \n after \r`) return nil, nil, NewParserError(b[i:i+2], `need a \n after \r`)
} }
i += 2 // skip the \n i += 2 // skip the \n
continue continue
} }
size := utf8ValidNext(b[i:]) size := characters.Utf8ValidNext(b[i:])
if size == 0 { if size == 0 {
return nil, nil, newDecodeError(b[i:i+1], "invalid character") return nil, nil, NewParserError(b[i:i+1], "invalid character")
} }
i += size i += size
} }
return nil, nil, newDecodeError(b[len(b):], `multiline literal string not terminated by '''`) return nil, nil, NewParserError(b[len(b):], `multiline literal string not terminated by '''`)
} }
func scanWindowsNewline(b []byte) ([]byte, []byte, error) { func scanWindowsNewline(b []byte) ([]byte, []byte, error) {
const lenCRLF = 2 const lenCRLF = 2
if len(b) < lenCRLF { if len(b) < lenCRLF {
return nil, nil, newDecodeError(b, "windows new line expected") return nil, nil, NewParserError(b, "windows new line expected")
} }
if b[1] != '\n' { if b[1] != '\n' {
return nil, nil, newDecodeError(b, `windows new line should be \r\n`) return nil, nil, NewParserError(b, `windows new line should be \r\n`)
} }
return b[:lenCRLF], b[lenCRLF:], nil return b[:lenCRLF], b[lenCRLF:], nil
@ -165,11 +167,11 @@ func scanComment(b []byte) ([]byte, []byte, error) {
if i+1 < len(b) && b[i+1] == '\n' { if i+1 < len(b) && b[i+1] == '\n' {
return b[:i+1], b[i+1:], nil return b[:i+1], b[i+1:], nil
} }
return nil, nil, newDecodeError(b[i:i+1], "invalid character in comment") return nil, nil, NewParserError(b[i:i+1], "invalid character in comment")
} }
size := utf8ValidNext(b[i:]) size := characters.Utf8ValidNext(b[i:])
if size == 0 { if size == 0 {
return nil, nil, newDecodeError(b[i:i+1], "invalid character in comment") return nil, nil, NewParserError(b[i:i+1], "invalid character in comment")
} }
i += size i += size
@ -192,17 +194,17 @@ func scanBasicString(b []byte) ([]byte, bool, []byte, error) {
case '"': case '"':
return b[:i+1], escaped, b[i+1:], nil return b[:i+1], escaped, b[i+1:], nil
case '\n', '\r': case '\n', '\r':
return nil, escaped, nil, newDecodeError(b[i:i+1], "basic strings cannot have new lines") return nil, escaped, nil, NewParserError(b[i:i+1], "basic strings cannot have new lines")
case '\\': case '\\':
if len(b) < i+2 { if len(b) < i+2 {
return nil, escaped, nil, newDecodeError(b[i:i+1], "need a character after \\") return nil, escaped, nil, NewParserError(b[i:i+1], "need a character after \\")
} }
escaped = true escaped = true
i++ // skip the next character i++ // skip the next character
} }
} }
return nil, escaped, nil, newDecodeError(b[len(b):], `basic string not terminated by "`) return nil, escaped, nil, NewParserError(b[len(b):], `basic string not terminated by "`)
} }
func scanMultilineBasicString(b []byte) ([]byte, bool, []byte, error) { func scanMultilineBasicString(b []byte) ([]byte, bool, []byte, error) {
@ -243,27 +245,27 @@ func scanMultilineBasicString(b []byte) ([]byte, bool, []byte, error) {
i++ i++
if i < len(b) && b[i] == '"' { if i < len(b) && b[i] == '"' {
return nil, escaped, nil, newDecodeError(b[i-3:i+1], `""" not allowed in multiline basic string`) return nil, escaped, nil, NewParserError(b[i-3:i+1], `""" not allowed in multiline basic string`)
} }
return b[:i], escaped, b[i:], nil return b[:i], escaped, b[i:], nil
} }
case '\\': case '\\':
if len(b) < i+2 { if len(b) < i+2 {
return nil, escaped, nil, newDecodeError(b[len(b):], "need a character after \\") return nil, escaped, nil, NewParserError(b[len(b):], "need a character after \\")
} }
escaped = true escaped = true
i++ // skip the next character i++ // skip the next character
case '\r': case '\r':
if len(b) < i+2 { if len(b) < i+2 {
return nil, escaped, nil, newDecodeError(b[len(b):], `need a \n after \r`) return nil, escaped, nil, NewParserError(b[len(b):], `need a \n after \r`)
} }
if b[i+1] != '\n' { if b[i+1] != '\n' {
return nil, escaped, nil, newDecodeError(b[i:i+2], `need a \n after \r`) return nil, escaped, nil, NewParserError(b[i:i+2], `need a \n after \r`)
} }
i++ // skip the \n i++ // skip the \n
} }
} }
return nil, escaped, nil, newDecodeError(b[len(b):], `multiline basic string not terminated by """`) return nil, escaped, nil, NewParserError(b[len(b):], `multiline basic string not terminated by """`)
} }

View file

@ -605,7 +605,10 @@ func (z *Tokenizer) readComment() {
z.data.end = z.data.start z.data.end = z.data.start
} }
}() }()
for dashCount := 2; ; {
var dashCount int
beginning := true
for {
c := z.readByte() c := z.readByte()
if z.err != nil { if z.err != nil {
// Ignore up to two dashes at EOF. // Ignore up to two dashes at EOF.
@ -620,7 +623,7 @@ func (z *Tokenizer) readComment() {
dashCount++ dashCount++
continue continue
case '>': case '>':
if dashCount >= 2 { if dashCount >= 2 || beginning {
z.data.end = z.raw.end - len("-->") z.data.end = z.raw.end - len("-->")
return return
} }
@ -638,6 +641,7 @@ func (z *Tokenizer) readComment() {
} }
} }
dashCount = 0 dashCount = 0
beginning = false
} }
} }

View file

@ -109,6 +109,7 @@ func (s h2cHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if http2VerboseLogs { if http2VerboseLogs {
log.Printf("h2c: error h2c upgrade: %v", err) log.Printf("h2c: error h2c upgrade: %v", err)
} }
w.WriteHeader(http.StatusInternalServerError)
return return
} }
defer conn.Close() defer conn.Close()
@ -167,7 +168,10 @@ func h2cUpgrade(w http.ResponseWriter, r *http.Request) (_ net.Conn, settings []
return nil, nil, errors.New("h2c: connection does not support Hijack") return nil, nil, errors.New("h2c: connection does not support Hijack")
} }
body, _ := io.ReadAll(r.Body) body, err := io.ReadAll(r.Body)
if err != nil {
return nil, nil, err
}
r.Body = io.NopCloser(bytes.NewBuffer(body)) r.Body = io.NopCloser(bytes.NewBuffer(body))
conn, rw, err := hijacker.Hijack() conn, rw, err := hijacker.Hijack()

View file

@ -27,7 +27,14 @@ func buildCommonHeaderMaps() {
"accept-language", "accept-language",
"accept-ranges", "accept-ranges",
"age", "age",
"access-control-allow-credentials",
"access-control-allow-headers",
"access-control-allow-methods",
"access-control-allow-origin", "access-control-allow-origin",
"access-control-expose-headers",
"access-control-max-age",
"access-control-request-headers",
"access-control-request-method",
"allow", "allow",
"authorization", "authorization",
"cache-control", "cache-control",
@ -53,6 +60,7 @@ func buildCommonHeaderMaps() {
"link", "link",
"location", "location",
"max-forwards", "max-forwards",
"origin",
"proxy-authenticate", "proxy-authenticate",
"proxy-authorization", "proxy-authorization",
"range", "range",
@ -68,6 +76,8 @@ func buildCommonHeaderMaps() {
"vary", "vary",
"via", "via",
"www-authenticate", "www-authenticate",
"x-forwarded-for",
"x-forwarded-proto",
} }
commonLowerHeader = make(map[string]string, len(common)) commonLowerHeader = make(map[string]string, len(common))
commonCanonHeader = make(map[string]string, len(common)) commonCanonHeader = make(map[string]string, len(common))
@ -85,3 +95,11 @@ func lowerHeader(v string) (lower string, ascii bool) {
} }
return asciiToLower(v) return asciiToLower(v)
} }
func canonicalHeader(v string) string {
buildCommonHeaderMapsOnce()
if s, ok := commonCanonHeader[v]; ok {
return s
}
return http.CanonicalHeaderKey(v)
}

View file

@ -116,6 +116,11 @@ func (e *Encoder) SetMaxDynamicTableSize(v uint32) {
e.dynTab.setMaxSize(v) e.dynTab.setMaxSize(v)
} }
// MaxDynamicTableSize returns the current dynamic header table size.
func (e *Encoder) MaxDynamicTableSize() (v uint32) {
return e.dynTab.maxSize
}
// SetMaxDynamicTableSizeLimit changes the maximum value that can be // SetMaxDynamicTableSizeLimit changes the maximum value that can be
// specified in SetMaxDynamicTableSize to v. By default, it is set to // specified in SetMaxDynamicTableSize to v. By default, it is set to
// 4096, which is the same size of the default dynamic header table // 4096, which is the same size of the default dynamic header table

188
vendor/golang.org/x/net/http2/hpack/static_table.go generated vendored Normal file
View file

@ -0,0 +1,188 @@
// go generate gen.go
// Code generated by the command above; DO NOT EDIT.
package hpack
var staticTable = &headerFieldTable{
evictCount: 0,
byName: map[string]uint64{
":authority": 1,
":method": 3,
":path": 5,
":scheme": 7,
":status": 14,
"accept-charset": 15,
"accept-encoding": 16,
"accept-language": 17,
"accept-ranges": 18,
"accept": 19,
"access-control-allow-origin": 20,
"age": 21,
"allow": 22,
"authorization": 23,
"cache-control": 24,
"content-disposition": 25,
"content-encoding": 26,
"content-language": 27,
"content-length": 28,
"content-location": 29,
"content-range": 30,
"content-type": 31,
"cookie": 32,
"date": 33,
"etag": 34,
"expect": 35,
"expires": 36,
"from": 37,
"host": 38,
"if-match": 39,
"if-modified-since": 40,
"if-none-match": 41,
"if-range": 42,
"if-unmodified-since": 43,
"last-modified": 44,
"link": 45,
"location": 46,
"max-forwards": 47,
"proxy-authenticate": 48,
"proxy-authorization": 49,
"range": 50,
"referer": 51,
"refresh": 52,
"retry-after": 53,
"server": 54,
"set-cookie": 55,
"strict-transport-security": 56,
"transfer-encoding": 57,
"user-agent": 58,
"vary": 59,
"via": 60,
"www-authenticate": 61,
},
byNameValue: map[pairNameValue]uint64{
{name: ":authority", value: ""}: 1,
{name: ":method", value: "GET"}: 2,
{name: ":method", value: "POST"}: 3,
{name: ":path", value: "/"}: 4,
{name: ":path", value: "/index.html"}: 5,
{name: ":scheme", value: "http"}: 6,
{name: ":scheme", value: "https"}: 7,
{name: ":status", value: "200"}: 8,
{name: ":status", value: "204"}: 9,
{name: ":status", value: "206"}: 10,
{name: ":status", value: "304"}: 11,
{name: ":status", value: "400"}: 12,
{name: ":status", value: "404"}: 13,
{name: ":status", value: "500"}: 14,
{name: "accept-charset", value: ""}: 15,
{name: "accept-encoding", value: "gzip, deflate"}: 16,
{name: "accept-language", value: ""}: 17,
{name: "accept-ranges", value: ""}: 18,
{name: "accept", value: ""}: 19,
{name: "access-control-allow-origin", value: ""}: 20,
{name: "age", value: ""}: 21,
{name: "allow", value: ""}: 22,
{name: "authorization", value: ""}: 23,
{name: "cache-control", value: ""}: 24,
{name: "content-disposition", value: ""}: 25,
{name: "content-encoding", value: ""}: 26,
{name: "content-language", value: ""}: 27,
{name: "content-length", value: ""}: 28,
{name: "content-location", value: ""}: 29,
{name: "content-range", value: ""}: 30,
{name: "content-type", value: ""}: 31,
{name: "cookie", value: ""}: 32,
{name: "date", value: ""}: 33,
{name: "etag", value: ""}: 34,
{name: "expect", value: ""}: 35,
{name: "expires", value: ""}: 36,
{name: "from", value: ""}: 37,
{name: "host", value: ""}: 38,
{name: "if-match", value: ""}: 39,
{name: "if-modified-since", value: ""}: 40,
{name: "if-none-match", value: ""}: 41,
{name: "if-range", value: ""}: 42,
{name: "if-unmodified-since", value: ""}: 43,
{name: "last-modified", value: ""}: 44,
{name: "link", value: ""}: 45,
{name: "location", value: ""}: 46,
{name: "max-forwards", value: ""}: 47,
{name: "proxy-authenticate", value: ""}: 48,
{name: "proxy-authorization", value: ""}: 49,
{name: "range", value: ""}: 50,
{name: "referer", value: ""}: 51,
{name: "refresh", value: ""}: 52,
{name: "retry-after", value: ""}: 53,
{name: "server", value: ""}: 54,
{name: "set-cookie", value: ""}: 55,
{name: "strict-transport-security", value: ""}: 56,
{name: "transfer-encoding", value: ""}: 57,
{name: "user-agent", value: ""}: 58,
{name: "vary", value: ""}: 59,
{name: "via", value: ""}: 60,
{name: "www-authenticate", value: ""}: 61,
},
ents: []HeaderField{
{Name: ":authority", Value: "", Sensitive: false},
{Name: ":method", Value: "GET", Sensitive: false},
{Name: ":method", Value: "POST", Sensitive: false},
{Name: ":path", Value: "/", Sensitive: false},
{Name: ":path", Value: "/index.html", Sensitive: false},
{Name: ":scheme", Value: "http", Sensitive: false},
{Name: ":scheme", Value: "https", Sensitive: false},
{Name: ":status", Value: "200", Sensitive: false},
{Name: ":status", Value: "204", Sensitive: false},
{Name: ":status", Value: "206", Sensitive: false},
{Name: ":status", Value: "304", Sensitive: false},
{Name: ":status", Value: "400", Sensitive: false},
{Name: ":status", Value: "404", Sensitive: false},
{Name: ":status", Value: "500", Sensitive: false},
{Name: "accept-charset", Value: "", Sensitive: false},
{Name: "accept-encoding", Value: "gzip, deflate", Sensitive: false},
{Name: "accept-language", Value: "", Sensitive: false},
{Name: "accept-ranges", Value: "", Sensitive: false},
{Name: "accept", Value: "", Sensitive: false},
{Name: "access-control-allow-origin", Value: "", Sensitive: false},
{Name: "age", Value: "", Sensitive: false},
{Name: "allow", Value: "", Sensitive: false},
{Name: "authorization", Value: "", Sensitive: false},
{Name: "cache-control", Value: "", Sensitive: false},
{Name: "content-disposition", Value: "", Sensitive: false},
{Name: "content-encoding", Value: "", Sensitive: false},
{Name: "content-language", Value: "", Sensitive: false},
{Name: "content-length", Value: "", Sensitive: false},
{Name: "content-location", Value: "", Sensitive: false},
{Name: "content-range", Value: "", Sensitive: false},
{Name: "content-type", Value: "", Sensitive: false},
{Name: "cookie", Value: "", Sensitive: false},
{Name: "date", Value: "", Sensitive: false},
{Name: "etag", Value: "", Sensitive: false},
{Name: "expect", Value: "", Sensitive: false},
{Name: "expires", Value: "", Sensitive: false},
{Name: "from", Value: "", Sensitive: false},
{Name: "host", Value: "", Sensitive: false},
{Name: "if-match", Value: "", Sensitive: false},
{Name: "if-modified-since", Value: "", Sensitive: false},
{Name: "if-none-match", Value: "", Sensitive: false},
{Name: "if-range", Value: "", Sensitive: false},
{Name: "if-unmodified-since", Value: "", Sensitive: false},
{Name: "last-modified", Value: "", Sensitive: false},
{Name: "link", Value: "", Sensitive: false},
{Name: "location", Value: "", Sensitive: false},
{Name: "max-forwards", Value: "", Sensitive: false},
{Name: "proxy-authenticate", Value: "", Sensitive: false},
{Name: "proxy-authorization", Value: "", Sensitive: false},
{Name: "range", Value: "", Sensitive: false},
{Name: "referer", Value: "", Sensitive: false},
{Name: "refresh", Value: "", Sensitive: false},
{Name: "retry-after", Value: "", Sensitive: false},
{Name: "server", Value: "", Sensitive: false},
{Name: "set-cookie", Value: "", Sensitive: false},
{Name: "strict-transport-security", Value: "", Sensitive: false},
{Name: "transfer-encoding", Value: "", Sensitive: false},
{Name: "user-agent", Value: "", Sensitive: false},
{Name: "vary", Value: "", Sensitive: false},
{Name: "via", Value: "", Sensitive: false},
{Name: "www-authenticate", Value: "", Sensitive: false},
},
}

View file

@ -96,8 +96,7 @@ func (t *headerFieldTable) evictOldest(n int) {
// meaning t.ents is reversed for dynamic tables. Hence, when t is a dynamic // meaning t.ents is reversed for dynamic tables. Hence, when t is a dynamic
// table, the return value i actually refers to the entry t.ents[t.len()-i]. // table, the return value i actually refers to the entry t.ents[t.len()-i].
// //
// All tables are assumed to be a dynamic tables except for the global // All tables are assumed to be a dynamic tables except for the global staticTable.
// staticTable pointer.
// //
// See Section 2.3.3. // See Section 2.3.3.
func (t *headerFieldTable) search(f HeaderField) (i uint64, nameValueMatch bool) { func (t *headerFieldTable) search(f HeaderField) (i uint64, nameValueMatch bool) {
@ -125,81 +124,6 @@ func (t *headerFieldTable) idToIndex(id uint64) uint64 {
return k + 1 return k + 1
} }
// http://tools.ietf.org/html/draft-ietf-httpbis-header-compression-07#appendix-B
var staticTable = newStaticTable()
var staticTableEntries = [...]HeaderField{
{Name: ":authority"},
{Name: ":method", Value: "GET"},
{Name: ":method", Value: "POST"},
{Name: ":path", Value: "/"},
{Name: ":path", Value: "/index.html"},
{Name: ":scheme", Value: "http"},
{Name: ":scheme", Value: "https"},
{Name: ":status", Value: "200"},
{Name: ":status", Value: "204"},
{Name: ":status", Value: "206"},
{Name: ":status", Value: "304"},
{Name: ":status", Value: "400"},
{Name: ":status", Value: "404"},
{Name: ":status", Value: "500"},
{Name: "accept-charset"},
{Name: "accept-encoding", Value: "gzip, deflate"},
{Name: "accept-language"},
{Name: "accept-ranges"},
{Name: "accept"},
{Name: "access-control-allow-origin"},
{Name: "age"},
{Name: "allow"},
{Name: "authorization"},
{Name: "cache-control"},
{Name: "content-disposition"},
{Name: "content-encoding"},
{Name: "content-language"},
{Name: "content-length"},
{Name: "content-location"},
{Name: "content-range"},
{Name: "content-type"},
{Name: "cookie"},
{Name: "date"},
{Name: "etag"},
{Name: "expect"},
{Name: "expires"},
{Name: "from"},
{Name: "host"},
{Name: "if-match"},
{Name: "if-modified-since"},
{Name: "if-none-match"},
{Name: "if-range"},
{Name: "if-unmodified-since"},
{Name: "last-modified"},
{Name: "link"},
{Name: "location"},
{Name: "max-forwards"},
{Name: "proxy-authenticate"},
{Name: "proxy-authorization"},
{Name: "range"},
{Name: "referer"},
{Name: "refresh"},
{Name: "retry-after"},
{Name: "server"},
{Name: "set-cookie"},
{Name: "strict-transport-security"},
{Name: "transfer-encoding"},
{Name: "user-agent"},
{Name: "vary"},
{Name: "via"},
{Name: "www-authenticate"},
}
func newStaticTable() *headerFieldTable {
t := &headerFieldTable{}
t.init()
for _, e := range staticTableEntries[:] {
t.addEntry(e)
}
return t
}
var huffmanCodes = [256]uint32{ var huffmanCodes = [256]uint32{
0x1ff8, 0x1ff8,
0x7fffd8, 0x7fffd8,

View file

@ -98,6 +98,19 @@ type Server struct {
// the HTTP/2 spec's recommendations. // the HTTP/2 spec's recommendations.
MaxConcurrentStreams uint32 MaxConcurrentStreams uint32
// MaxDecoderHeaderTableSize optionally specifies the http2
// SETTINGS_HEADER_TABLE_SIZE to send in the initial settings frame. It
// informs the remote endpoint of the maximum size of the header compression
// table used to decode header blocks, in octets. If zero, the default value
// of 4096 is used.
MaxDecoderHeaderTableSize uint32
// MaxEncoderHeaderTableSize optionally specifies an upper limit for the
// header compression table used for encoding request headers. Received
// SETTINGS_HEADER_TABLE_SIZE settings are capped at this limit. If zero,
// the default value of 4096 is used.
MaxEncoderHeaderTableSize uint32
// MaxReadFrameSize optionally specifies the largest frame // MaxReadFrameSize optionally specifies the largest frame
// this server is willing to read. A valid value is between // this server is willing to read. A valid value is between
// 16k and 16M, inclusive. If zero or otherwise invalid, a // 16k and 16M, inclusive. If zero or otherwise invalid, a
@ -170,6 +183,20 @@ func (s *Server) maxConcurrentStreams() uint32 {
return defaultMaxStreams return defaultMaxStreams
} }
func (s *Server) maxDecoderHeaderTableSize() uint32 {
if v := s.MaxDecoderHeaderTableSize; v > 0 {
return v
}
return initialHeaderTableSize
}
func (s *Server) maxEncoderHeaderTableSize() uint32 {
if v := s.MaxEncoderHeaderTableSize; v > 0 {
return v
}
return initialHeaderTableSize
}
// maxQueuedControlFrames is the maximum number of control frames like // maxQueuedControlFrames is the maximum number of control frames like
// SETTINGS, PING and RST_STREAM that will be queued for writing before // SETTINGS, PING and RST_STREAM that will be queued for writing before
// the connection is closed to prevent memory exhaustion attacks. // the connection is closed to prevent memory exhaustion attacks.
@ -394,7 +421,6 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
advMaxStreams: s.maxConcurrentStreams(), advMaxStreams: s.maxConcurrentStreams(),
initialStreamSendWindowSize: initialWindowSize, initialStreamSendWindowSize: initialWindowSize,
maxFrameSize: initialMaxFrameSize, maxFrameSize: initialMaxFrameSize,
headerTableSize: initialHeaderTableSize,
serveG: newGoroutineLock(), serveG: newGoroutineLock(),
pushEnabled: true, pushEnabled: true,
sawClientPreface: opts.SawClientPreface, sawClientPreface: opts.SawClientPreface,
@ -424,12 +450,13 @@ func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
sc.flow.add(initialWindowSize) sc.flow.add(initialWindowSize)
sc.inflow.add(initialWindowSize) sc.inflow.add(initialWindowSize)
sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
sc.hpackEncoder.SetMaxDynamicTableSizeLimit(s.maxEncoderHeaderTableSize())
fr := NewFramer(sc.bw, c) fr := NewFramer(sc.bw, c)
if s.CountError != nil { if s.CountError != nil {
fr.countError = s.CountError fr.countError = s.CountError
} }
fr.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil) fr.ReadMetaHeaders = hpack.NewDecoder(s.maxDecoderHeaderTableSize(), nil)
fr.MaxHeaderListSize = sc.maxHeaderListSize() fr.MaxHeaderListSize = sc.maxHeaderListSize()
fr.SetMaxReadFrameSize(s.maxReadFrameSize()) fr.SetMaxReadFrameSize(s.maxReadFrameSize())
sc.framer = fr sc.framer = fr
@ -559,9 +586,9 @@ type serverConn struct {
streams map[uint32]*stream streams map[uint32]*stream
initialStreamSendWindowSize int32 initialStreamSendWindowSize int32
maxFrameSize int32 maxFrameSize int32
headerTableSize uint32
peerMaxHeaderListSize uint32 // zero means unknown (default) peerMaxHeaderListSize uint32 // zero means unknown (default)
canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case
canonHeaderKeysSize int // canonHeader keys size in bytes
writingFrame bool // started writing a frame (on serve goroutine or separate) writingFrame bool // started writing a frame (on serve goroutine or separate)
writingFrameAsync bool // started a frame on its own goroutine but haven't heard back on wroteFrameCh writingFrameAsync bool // started a frame on its own goroutine but haven't heard back on wroteFrameCh
needsFrameFlush bool // last frame write wasn't a flush needsFrameFlush bool // last frame write wasn't a flush
@ -622,7 +649,9 @@ type stream struct {
resetQueued bool // RST_STREAM queued for write; set by sc.resetStream resetQueued bool // RST_STREAM queued for write; set by sc.resetStream
gotTrailerHeader bool // HEADER frame for trailers was seen gotTrailerHeader bool // HEADER frame for trailers was seen
wroteHeaders bool // whether we wrote headers (not status 100) wroteHeaders bool // whether we wrote headers (not status 100)
readDeadline *time.Timer // nil if unused
writeDeadline *time.Timer // nil if unused writeDeadline *time.Timer // nil if unused
closeErr error // set before cw is closed
trailer http.Header // accumulated trailers trailer http.Header // accumulated trailers
reqTrailer http.Header // handler's Request.Trailer reqTrailer http.Header // handler's Request.Trailer
@ -738,6 +767,13 @@ func (sc *serverConn) condlogf(err error, format string, args ...interface{}) {
} }
} }
// maxCachedCanonicalHeadersKeysSize is an arbitrarily-chosen limit on the size
// of the entries in the canonHeader cache.
// This should be larger than the size of unique, uncommon header keys likely to
// be sent by the peer, while not so high as to permit unreasonable memory usage
// if the peer sends an unbounded number of unique header keys.
const maxCachedCanonicalHeadersKeysSize = 2048
func (sc *serverConn) canonicalHeader(v string) string { func (sc *serverConn) canonicalHeader(v string) string {
sc.serveG.check() sc.serveG.check()
buildCommonHeaderMapsOnce() buildCommonHeaderMapsOnce()
@ -753,14 +789,10 @@ func (sc *serverConn) canonicalHeader(v string) string {
sc.canonHeader = make(map[string]string) sc.canonHeader = make(map[string]string)
} }
cv = http.CanonicalHeaderKey(v) cv = http.CanonicalHeaderKey(v)
// maxCachedCanonicalHeaders is an arbitrarily-chosen limit on the number of size := 100 + len(v)*2 // 100 bytes of map overhead + key + value
// entries in the canonHeader cache. This should be larger than the number if sc.canonHeaderKeysSize+size <= maxCachedCanonicalHeadersKeysSize {
// of unique, uncommon header keys likely to be sent by the peer, while not
// so high as to permit unreasonable memory usage if the peer sends an unbounded
// number of unique header keys.
const maxCachedCanonicalHeaders = 32
if len(sc.canonHeader) < maxCachedCanonicalHeaders {
sc.canonHeader[v] = cv sc.canonHeader[v] = cv
sc.canonHeaderKeysSize += size
} }
return cv return cv
} }
@ -862,6 +894,7 @@ func (sc *serverConn) serve() {
{SettingMaxFrameSize, sc.srv.maxReadFrameSize()}, {SettingMaxFrameSize, sc.srv.maxReadFrameSize()},
{SettingMaxConcurrentStreams, sc.advMaxStreams}, {SettingMaxConcurrentStreams, sc.advMaxStreams},
{SettingMaxHeaderListSize, sc.maxHeaderListSize()}, {SettingMaxHeaderListSize, sc.maxHeaderListSize()},
{SettingHeaderTableSize, sc.srv.maxDecoderHeaderTableSize()},
{SettingInitialWindowSize, uint32(sc.srv.initialStreamRecvWindowSize())}, {SettingInitialWindowSize, uint32(sc.srv.initialStreamRecvWindowSize())},
}, },
}) })
@ -869,7 +902,9 @@ func (sc *serverConn) serve() {
// Each connection starts with initialWindowSize inflow tokens. // Each connection starts with initialWindowSize inflow tokens.
// If a higher value is configured, we add more tokens. // If a higher value is configured, we add more tokens.
sc.sendWindowUpdate(nil) if diff := sc.srv.initialConnRecvWindowSize() - initialWindowSize; diff > 0 {
sc.sendWindowUpdate(nil, int(diff))
}
if err := sc.readPreface(); err != nil { if err := sc.readPreface(); err != nil {
sc.condlogf(err, "http2: server: error reading preface from client %v: %v", sc.conn.RemoteAddr(), err) sc.condlogf(err, "http2: server: error reading preface from client %v: %v", sc.conn.RemoteAddr(), err)
@ -946,6 +981,8 @@ func (sc *serverConn) serve() {
} }
case *startPushRequest: case *startPushRequest:
sc.startPush(v) sc.startPush(v)
case func(*serverConn):
v(sc)
default: default:
panic(fmt.Sprintf("unexpected type %T", v)) panic(fmt.Sprintf("unexpected type %T", v))
} }
@ -1459,6 +1496,21 @@ func (sc *serverConn) processFrame(f Frame) error {
sc.sawFirstSettings = true sc.sawFirstSettings = true
} }
// Discard frames for streams initiated after the identified last
// stream sent in a GOAWAY, or all frames after sending an error.
// We still need to return connection-level flow control for DATA frames.
// RFC 9113 Section 6.8.
if sc.inGoAway && (sc.goAwayCode != ErrCodeNo || f.Header().StreamID > sc.maxClientStreamID) {
if f, ok := f.(*DataFrame); ok {
if sc.inflow.available() < int32(f.Length) {
return sc.countError("data_flow", streamError(f.Header().StreamID, ErrCodeFlowControl))
}
sc.sendWindowUpdate(nil, int(f.Length)) // conn-level
}
return nil
}
switch f := f.(type) { switch f := f.(type) {
case *SettingsFrame: case *SettingsFrame:
return sc.processSettings(f) return sc.processSettings(f)
@ -1501,9 +1553,6 @@ func (sc *serverConn) processPing(f *PingFrame) error {
// PROTOCOL_ERROR." // PROTOCOL_ERROR."
return sc.countError("ping_on_stream", ConnectionError(ErrCodeProtocol)) return sc.countError("ping_on_stream", ConnectionError(ErrCodeProtocol))
} }
if sc.inGoAway && sc.goAwayCode != ErrCodeNo {
return nil
}
sc.writeFrame(FrameWriteRequest{write: writePingAck{f}}) sc.writeFrame(FrameWriteRequest{write: writePingAck{f}})
return nil return nil
} }
@ -1565,6 +1614,9 @@ func (sc *serverConn) closeStream(st *stream, err error) {
panic(fmt.Sprintf("invariant; can't close stream in state %v", st.state)) panic(fmt.Sprintf("invariant; can't close stream in state %v", st.state))
} }
st.state = stateClosed st.state = stateClosed
if st.readDeadline != nil {
st.readDeadline.Stop()
}
if st.writeDeadline != nil { if st.writeDeadline != nil {
st.writeDeadline.Stop() st.writeDeadline.Stop()
} }
@ -1586,10 +1638,18 @@ func (sc *serverConn) closeStream(st *stream, err error) {
if p := st.body; p != nil { if p := st.body; p != nil {
// Return any buffered unread bytes worth of conn-level flow control. // Return any buffered unread bytes worth of conn-level flow control.
// See golang.org/issue/16481 // See golang.org/issue/16481
sc.sendWindowUpdate(nil) sc.sendWindowUpdate(nil, p.Len())
p.CloseWithError(err) p.CloseWithError(err)
} }
if e, ok := err.(StreamError); ok {
if e.Cause != nil {
err = e.Cause
} else {
err = errStreamClosed
}
}
st.closeErr = err
st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc st.cw.Close() // signals Handler's CloseNotifier, unblocks writes, etc
sc.writeSched.CloseStream(st.id) sc.writeSched.CloseStream(st.id)
} }
@ -1632,7 +1692,6 @@ func (sc *serverConn) processSetting(s Setting) error {
} }
switch s.ID { switch s.ID {
case SettingHeaderTableSize: case SettingHeaderTableSize:
sc.headerTableSize = s.Val
sc.hpackEncoder.SetMaxDynamicTableSize(s.Val) sc.hpackEncoder.SetMaxDynamicTableSize(s.Val)
case SettingEnablePush: case SettingEnablePush:
sc.pushEnabled = s.Val != 0 sc.pushEnabled = s.Val != 0
@ -1686,16 +1745,6 @@ func (sc *serverConn) processSettingInitialWindowSize(val uint32) error {
func (sc *serverConn) processData(f *DataFrame) error { func (sc *serverConn) processData(f *DataFrame) error {
sc.serveG.check() sc.serveG.check()
id := f.Header().StreamID id := f.Header().StreamID
if sc.inGoAway && (sc.goAwayCode != ErrCodeNo || id > sc.maxClientStreamID) {
// Discard all DATA frames if the GOAWAY is due to an
// error, or:
//
// Section 6.8: After sending a GOAWAY frame, the sender
// can discard frames for streams initiated by the
// receiver with identifiers higher than the identified
// last stream.
return nil
}
data := f.Data() data := f.Data()
state, st := sc.state(id) state, st := sc.state(id)
@ -1734,7 +1783,7 @@ func (sc *serverConn) processData(f *DataFrame) error {
// sendWindowUpdate, which also schedules sending the // sendWindowUpdate, which also schedules sending the
// frames. // frames.
sc.inflow.take(int32(f.Length)) sc.inflow.take(int32(f.Length))
sc.sendWindowUpdate(nil) // conn-level sc.sendWindowUpdate(nil, int(f.Length)) // conn-level
if st != nil && st.resetQueued { if st != nil && st.resetQueued {
// Already have a stream error in flight. Don't send another. // Already have a stream error in flight. Don't send another.
@ -1752,7 +1801,7 @@ func (sc *serverConn) processData(f *DataFrame) error {
return sc.countError("data_flow", streamError(id, ErrCodeFlowControl)) return sc.countError("data_flow", streamError(id, ErrCodeFlowControl))
} }
sc.inflow.take(int32(f.Length)) sc.inflow.take(int32(f.Length))
sc.sendWindowUpdate(nil) // conn-level sc.sendWindowUpdate(nil, int(f.Length)) // conn-level
st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes)) st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes))
// RFC 7540, sec 8.1.2.6: A request or response is also malformed if the // RFC 7540, sec 8.1.2.6: A request or response is also malformed if the
@ -1770,7 +1819,7 @@ func (sc *serverConn) processData(f *DataFrame) error {
if len(data) > 0 { if len(data) > 0 {
wrote, err := st.body.Write(data) wrote, err := st.body.Write(data)
if err != nil { if err != nil {
sc.sendWindowUpdate32(nil, int32(f.Length)-int32(wrote)) sc.sendWindowUpdate(nil, int(f.Length)-wrote)
return sc.countError("body_write_err", streamError(id, ErrCodeStreamClosed)) return sc.countError("body_write_err", streamError(id, ErrCodeStreamClosed))
} }
if wrote != len(data) { if wrote != len(data) {
@ -1838,19 +1887,27 @@ func (st *stream) copyTrailersToHandlerRequest() {
} }
} }
// onReadTimeout is run on its own goroutine (from time.AfterFunc)
// when the stream's ReadTimeout has fired.
func (st *stream) onReadTimeout() {
// Wrap the ErrDeadlineExceeded to avoid callers depending on us
// returning the bare error.
st.body.CloseWithError(fmt.Errorf("%w", os.ErrDeadlineExceeded))
}
// onWriteTimeout is run on its own goroutine (from time.AfterFunc) // onWriteTimeout is run on its own goroutine (from time.AfterFunc)
// when the stream's WriteTimeout has fired. // when the stream's WriteTimeout has fired.
func (st *stream) onWriteTimeout() { func (st *stream) onWriteTimeout() {
st.sc.writeFrameFromHandler(FrameWriteRequest{write: streamError(st.id, ErrCodeInternal)}) st.sc.writeFrameFromHandler(FrameWriteRequest{write: StreamError{
StreamID: st.id,
Code: ErrCodeInternal,
Cause: os.ErrDeadlineExceeded,
}})
} }
func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error { func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
sc.serveG.check() sc.serveG.check()
id := f.StreamID id := f.StreamID
if sc.inGoAway {
// Ignore.
return nil
}
// http://tools.ietf.org/html/rfc7540#section-5.1.1 // http://tools.ietf.org/html/rfc7540#section-5.1.1
// Streams initiated by a client MUST use odd-numbered stream // Streams initiated by a client MUST use odd-numbered stream
// identifiers. [...] An endpoint that receives an unexpected // identifiers. [...] An endpoint that receives an unexpected
@ -1953,6 +2010,9 @@ func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
// (in Go 1.8), though. That's a more sane option anyway. // (in Go 1.8), though. That's a more sane option anyway.
if sc.hs.ReadTimeout != 0 { if sc.hs.ReadTimeout != 0 {
sc.conn.SetReadDeadline(time.Time{}) sc.conn.SetReadDeadline(time.Time{})
if st.body != nil {
st.readDeadline = time.AfterFunc(sc.hs.ReadTimeout, st.onReadTimeout)
}
} }
go sc.runHandler(rw, req, handler) go sc.runHandler(rw, req, handler)
@ -2021,9 +2081,6 @@ func (sc *serverConn) checkPriority(streamID uint32, p PriorityParam) error {
} }
func (sc *serverConn) processPriority(f *PriorityFrame) error { func (sc *serverConn) processPriority(f *PriorityFrame) error {
if sc.inGoAway {
return nil
}
if err := sc.checkPriority(f.StreamID, f.PriorityParam); err != nil { if err := sc.checkPriority(f.StreamID, f.PriorityParam); err != nil {
return err return err
} }
@ -2322,39 +2379,24 @@ func (sc *serverConn) noteBodyReadFromHandler(st *stream, n int, err error) {
func (sc *serverConn) noteBodyRead(st *stream, n int) { func (sc *serverConn) noteBodyRead(st *stream, n int) {
sc.serveG.check() sc.serveG.check()
sc.sendWindowUpdate(nil) // conn-level sc.sendWindowUpdate(nil, n) // conn-level
if st.state != stateHalfClosedRemote && st.state != stateClosed { if st.state != stateHalfClosedRemote && st.state != stateClosed {
// Don't send this WINDOW_UPDATE if the stream is closed // Don't send this WINDOW_UPDATE if the stream is closed
// remotely. // remotely.
sc.sendWindowUpdate(st) sc.sendWindowUpdate(st, n)
} }
} }
// st may be nil for conn-level // st may be nil for conn-level
func (sc *serverConn) sendWindowUpdate(st *stream) { func (sc *serverConn) sendWindowUpdate(st *stream, n int) {
sc.serveG.check() sc.serveG.check()
var n int32
if st == nil {
if avail, windowSize := sc.inflow.available(), sc.srv.initialConnRecvWindowSize(); avail > windowSize/2 {
return
} else {
n = windowSize - avail
}
} else {
if avail, windowSize := st.inflow.available(), sc.srv.initialStreamRecvWindowSize(); avail > windowSize/2 {
return
} else {
n = windowSize - avail
}
}
// "The legal range for the increment to the flow control // "The legal range for the increment to the flow control
// window is 1 to 2^31-1 (2,147,483,647) octets." // window is 1 to 2^31-1 (2,147,483,647) octets."
// A Go Read call on 64-bit machines could in theory read // A Go Read call on 64-bit machines could in theory read
// a larger Read than this. Very unlikely, but we handle it here // a larger Read than this. Very unlikely, but we handle it here
// rather than elsewhere for now. // rather than elsewhere for now.
const maxUint31 = 1<<31 - 1 const maxUint31 = 1<<31 - 1
for n >= maxUint31 { for n > maxUint31 {
sc.sendWindowUpdate32(st, maxUint31) sc.sendWindowUpdate32(st, maxUint31)
n -= maxUint31 n -= maxUint31
} }
@ -2474,7 +2516,15 @@ type responseWriterState struct {
type chunkWriter struct{ rws *responseWriterState } type chunkWriter struct{ rws *responseWriterState }
func (cw chunkWriter) Write(p []byte) (n int, err error) { return cw.rws.writeChunk(p) } func (cw chunkWriter) Write(p []byte) (n int, err error) {
n, err = cw.rws.writeChunk(p)
if err == errStreamClosed {
// If writing failed because the stream has been closed,
// return the reason it was closed.
err = cw.rws.stream.closeErr
}
return n, err
}
func (rws *responseWriterState) hasTrailers() bool { return len(rws.trailers) > 0 } func (rws *responseWriterState) hasTrailers() bool { return len(rws.trailers) > 0 }
@ -2668,23 +2718,85 @@ func (rws *responseWriterState) promoteUndeclaredTrailers() {
} }
} }
func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
st := w.rws.stream
if !deadline.IsZero() && deadline.Before(time.Now()) {
// If we're setting a deadline in the past, reset the stream immediately
// so writes after SetWriteDeadline returns will fail.
st.onReadTimeout()
return nil
}
w.rws.conn.sendServeMsg(func(sc *serverConn) {
if st.readDeadline != nil {
if !st.readDeadline.Stop() {
// Deadline already exceeded, or stream has been closed.
return
}
}
if deadline.IsZero() {
st.readDeadline = nil
} else if st.readDeadline == nil {
st.readDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onReadTimeout)
} else {
st.readDeadline.Reset(deadline.Sub(time.Now()))
}
})
return nil
}
func (w *responseWriter) SetWriteDeadline(deadline time.Time) error {
st := w.rws.stream
if !deadline.IsZero() && deadline.Before(time.Now()) {
// If we're setting a deadline in the past, reset the stream immediately
// so writes after SetWriteDeadline returns will fail.
st.onWriteTimeout()
return nil
}
w.rws.conn.sendServeMsg(func(sc *serverConn) {
if st.writeDeadline != nil {
if !st.writeDeadline.Stop() {
// Deadline already exceeded, or stream has been closed.
return
}
}
if deadline.IsZero() {
st.writeDeadline = nil
} else if st.writeDeadline == nil {
st.writeDeadline = time.AfterFunc(deadline.Sub(time.Now()), st.onWriteTimeout)
} else {
st.writeDeadline.Reset(deadline.Sub(time.Now()))
}
})
return nil
}
func (w *responseWriter) Flush() { func (w *responseWriter) Flush() {
w.FlushError()
}
func (w *responseWriter) FlushError() error {
rws := w.rws rws := w.rws
if rws == nil { if rws == nil {
panic("Header called after Handler finished") panic("Header called after Handler finished")
} }
var err error
if rws.bw.Buffered() > 0 { if rws.bw.Buffered() > 0 {
if err := rws.bw.Flush(); err != nil { err = rws.bw.Flush()
// Ignore the error. The frame writer already knows.
return
}
} else { } else {
// The bufio.Writer won't call chunkWriter.Write // The bufio.Writer won't call chunkWriter.Write
// (writeChunk with zero bytes, so we have to do it // (writeChunk with zero bytes, so we have to do it
// ourselves to force the HTTP response header and/or // ourselves to force the HTTP response header and/or
// final DATA frame (with END_STREAM) to be sent. // final DATA frame (with END_STREAM) to be sent.
rws.writeChunk(nil) _, err = chunkWriter{rws}.Write(nil)
if err == nil {
select {
case <-rws.stream.cw:
err = rws.stream.closeErr
default:
}
}
} }
return err
} }
func (w *responseWriter) CloseNotify() <-chan bool { func (w *responseWriter) CloseNotify() <-chan bool {

View file

@ -16,6 +16,7 @@
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/fs"
"log" "log"
"math" "math"
mathrand "math/rand" mathrand "math/rand"
@ -117,6 +118,28 @@ type Transport struct {
// to mean no limit. // to mean no limit.
MaxHeaderListSize uint32 MaxHeaderListSize uint32
// MaxReadFrameSize is the http2 SETTINGS_MAX_FRAME_SIZE to send in the
// initial settings frame. It is the size in bytes of the largest frame
// payload that the sender is willing to receive. If 0, no setting is
// sent, and the value is provided by the peer, which should be 16384
// according to the spec:
// https://datatracker.ietf.org/doc/html/rfc7540#section-6.5.2.
// Values are bounded in the range 16k to 16M.
MaxReadFrameSize uint32
// MaxDecoderHeaderTableSize optionally specifies the http2
// SETTINGS_HEADER_TABLE_SIZE to send in the initial settings frame. It
// informs the remote endpoint of the maximum size of the header compression
// table used to decode header blocks, in octets. If zero, the default value
// of 4096 is used.
MaxDecoderHeaderTableSize uint32
// MaxEncoderHeaderTableSize optionally specifies an upper limit for the
// header compression table used for encoding request headers. Received
// SETTINGS_HEADER_TABLE_SIZE settings are capped at this limit. If zero,
// the default value of 4096 is used.
MaxEncoderHeaderTableSize uint32
// StrictMaxConcurrentStreams controls whether the server's // StrictMaxConcurrentStreams controls whether the server's
// SETTINGS_MAX_CONCURRENT_STREAMS should be respected // SETTINGS_MAX_CONCURRENT_STREAMS should be respected
// globally. If false, new TCP connections are created to the // globally. If false, new TCP connections are created to the
@ -170,6 +193,19 @@ func (t *Transport) maxHeaderListSize() uint32 {
return t.MaxHeaderListSize return t.MaxHeaderListSize
} }
func (t *Transport) maxFrameReadSize() uint32 {
if t.MaxReadFrameSize == 0 {
return 0 // use the default provided by the peer
}
if t.MaxReadFrameSize < minMaxFrameSize {
return minMaxFrameSize
}
if t.MaxReadFrameSize > maxFrameSize {
return maxFrameSize
}
return t.MaxReadFrameSize
}
func (t *Transport) disableCompression() bool { func (t *Transport) disableCompression() bool {
return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression) return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression)
} }
@ -292,10 +328,11 @@ type ClientConn struct {
lastActive time.Time lastActive time.Time
lastIdle time.Time // time last idle lastIdle time.Time // time last idle
// Settings from peer: (also guarded by wmu) // Settings from peer: (also guarded by wmu)
maxFrameSize uint32 maxFrameSize uint32
maxConcurrentStreams uint32 maxConcurrentStreams uint32
peerMaxHeaderListSize uint64 peerMaxHeaderListSize uint64
initialWindowSize uint32 peerMaxHeaderTableSize uint32
initialWindowSize uint32
// reqHeaderMu is a 1-element semaphore channel controlling access to sending new requests. // reqHeaderMu is a 1-element semaphore channel controlling access to sending new requests.
// Write to reqHeaderMu to lock it, read from it to unlock. // Write to reqHeaderMu to lock it, read from it to unlock.
@ -501,6 +538,15 @@ func authorityAddr(scheme string, authority string) (addr string) {
return net.JoinHostPort(host, port) return net.JoinHostPort(host, port)
} }
var retryBackoffHook func(time.Duration) *time.Timer
func backoffNewTimer(d time.Duration) *time.Timer {
if retryBackoffHook != nil {
return retryBackoffHook(d)
}
return time.NewTimer(d)
}
// RoundTripOpt is like RoundTrip, but takes options. // RoundTripOpt is like RoundTrip, but takes options.
func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) { func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Response, error) {
if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) { if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) {
@ -526,11 +572,14 @@ func (t *Transport) RoundTripOpt(req *http.Request, opt RoundTripOpt) (*http.Res
} }
backoff := float64(uint(1) << (uint(retry) - 1)) backoff := float64(uint(1) << (uint(retry) - 1))
backoff += backoff * (0.1 * mathrand.Float64()) backoff += backoff * (0.1 * mathrand.Float64())
d := time.Second * time.Duration(backoff)
timer := backoffNewTimer(d)
select { select {
case <-time.After(time.Second * time.Duration(backoff)): case <-timer.C:
t.vlogf("RoundTrip retrying after failure: %v", err) t.vlogf("RoundTrip retrying after failure: %v", err)
continue continue
case <-req.Context().Done(): case <-req.Context().Done():
timer.Stop()
err = req.Context().Err() err = req.Context().Err()
} }
} }
@ -668,6 +717,20 @@ func (t *Transport) expectContinueTimeout() time.Duration {
return t.t1.ExpectContinueTimeout return t.t1.ExpectContinueTimeout
} }
func (t *Transport) maxDecoderHeaderTableSize() uint32 {
if v := t.MaxDecoderHeaderTableSize; v > 0 {
return v
}
return initialHeaderTableSize
}
func (t *Transport) maxEncoderHeaderTableSize() uint32 {
if v := t.MaxEncoderHeaderTableSize; v > 0 {
return v
}
return initialHeaderTableSize
}
func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) { func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) {
return t.newClientConn(c, t.disableKeepAlives()) return t.newClientConn(c, t.disableKeepAlives())
} }
@ -708,15 +771,19 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
}) })
cc.br = bufio.NewReader(c) cc.br = bufio.NewReader(c)
cc.fr = NewFramer(cc.bw, cc.br) cc.fr = NewFramer(cc.bw, cc.br)
if t.maxFrameReadSize() != 0 {
cc.fr.SetMaxReadFrameSize(t.maxFrameReadSize())
}
if t.CountError != nil { if t.CountError != nil {
cc.fr.countError = t.CountError cc.fr.countError = t.CountError
} }
cc.fr.ReadMetaHeaders = hpack.NewDecoder(initialHeaderTableSize, nil) maxHeaderTableSize := t.maxDecoderHeaderTableSize()
cc.fr.ReadMetaHeaders = hpack.NewDecoder(maxHeaderTableSize, nil)
cc.fr.MaxHeaderListSize = t.maxHeaderListSize() cc.fr.MaxHeaderListSize = t.maxHeaderListSize()
// TODO: SetMaxDynamicTableSize, SetMaxDynamicTableSizeLimit on
// henc in response to SETTINGS frames?
cc.henc = hpack.NewEncoder(&cc.hbuf) cc.henc = hpack.NewEncoder(&cc.hbuf)
cc.henc.SetMaxDynamicTableSizeLimit(t.maxEncoderHeaderTableSize())
cc.peerMaxHeaderTableSize = initialHeaderTableSize
if t.AllowHTTP { if t.AllowHTTP {
cc.nextStreamID = 3 cc.nextStreamID = 3
@ -731,9 +798,15 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
{ID: SettingEnablePush, Val: 0}, {ID: SettingEnablePush, Val: 0},
{ID: SettingInitialWindowSize, Val: transportDefaultStreamFlow}, {ID: SettingInitialWindowSize, Val: transportDefaultStreamFlow},
} }
if max := t.maxFrameReadSize(); max != 0 {
initialSettings = append(initialSettings, Setting{ID: SettingMaxFrameSize, Val: max})
}
if max := t.maxHeaderListSize(); max != 0 { if max := t.maxHeaderListSize(); max != 0 {
initialSettings = append(initialSettings, Setting{ID: SettingMaxHeaderListSize, Val: max}) initialSettings = append(initialSettings, Setting{ID: SettingMaxHeaderListSize, Val: max})
} }
if maxHeaderTableSize != initialHeaderTableSize {
initialSettings = append(initialSettings, Setting{ID: SettingHeaderTableSize, Val: maxHeaderTableSize})
}
cc.bw.Write(clientPreface) cc.bw.Write(clientPreface)
cc.fr.WriteSettings(initialSettings...) cc.fr.WriteSettings(initialSettings...)
@ -1075,7 +1148,7 @@ func (cc *ClientConn) closeForLostPing() {
func commaSeparatedTrailers(req *http.Request) (string, error) { func commaSeparatedTrailers(req *http.Request) (string, error) {
keys := make([]string, 0, len(req.Trailer)) keys := make([]string, 0, len(req.Trailer))
for k := range req.Trailer { for k := range req.Trailer {
k = http.CanonicalHeaderKey(k) k = canonicalHeader(k)
switch k { switch k {
case "Transfer-Encoding", "Trailer", "Content-Length": case "Transfer-Encoding", "Trailer", "Content-Length":
return "", fmt.Errorf("invalid Trailer key %q", k) return "", fmt.Errorf("invalid Trailer key %q", k)
@ -1612,7 +1685,7 @@ func (cs *clientStream) writeRequestBody(req *http.Request) (err error) {
var sawEOF bool var sawEOF bool
for !sawEOF { for !sawEOF {
n, err := body.Read(buf[:len(buf)]) n, err := body.Read(buf)
if hasContentLen { if hasContentLen {
remainLen -= int64(n) remainLen -= int64(n)
if remainLen == 0 && err == nil { if remainLen == 0 && err == nil {
@ -1915,7 +1988,7 @@ func (cc *ClientConn) encodeHeaders(req *http.Request, addGzipHeader bool, trail
// Header list size is ok. Write the headers. // Header list size is ok. Write the headers.
enumerateHeaders(func(name, value string) { enumerateHeaders(func(name, value string) {
name, ascii := asciiToLower(name) name, ascii := lowerHeader(name)
if !ascii { if !ascii {
// Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header
// field names have to be ASCII characters (just as in HTTP/1.x). // field names have to be ASCII characters (just as in HTTP/1.x).
@ -1968,7 +2041,7 @@ func (cc *ClientConn) encodeTrailers(trailer http.Header) ([]byte, error) {
} }
for k, vv := range trailer { for k, vv := range trailer {
lowKey, ascii := asciiToLower(k) lowKey, ascii := lowerHeader(k)
if !ascii { if !ascii {
// Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header // Skip writing invalid headers. Per RFC 7540, Section 8.1.2, header
// field names have to be ASCII characters (just as in HTTP/1.x). // field names have to be ASCII characters (just as in HTTP/1.x).
@ -2301,7 +2374,7 @@ func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFra
Status: status + " " + http.StatusText(statusCode), Status: status + " " + http.StatusText(statusCode),
} }
for _, hf := range regularFields { for _, hf := range regularFields {
key := http.CanonicalHeaderKey(hf.Name) key := canonicalHeader(hf.Name)
if key == "Trailer" { if key == "Trailer" {
t := res.Trailer t := res.Trailer
if t == nil { if t == nil {
@ -2309,7 +2382,7 @@ func (rl *clientConnReadLoop) handleResponse(cs *clientStream, f *MetaHeadersFra
res.Trailer = t res.Trailer = t
} }
foreachHeaderElement(hf.Value, func(v string) { foreachHeaderElement(hf.Value, func(v string) {
t[http.CanonicalHeaderKey(v)] = nil t[canonicalHeader(v)] = nil
}) })
} else { } else {
vv := header[key] vv := header[key]
@ -2414,7 +2487,7 @@ func (rl *clientConnReadLoop) processTrailers(cs *clientStream, f *MetaHeadersFr
trailer := make(http.Header) trailer := make(http.Header)
for _, hf := range f.RegularFields() { for _, hf := range f.RegularFields() {
key := http.CanonicalHeaderKey(hf.Name) key := canonicalHeader(hf.Name)
trailer[key] = append(trailer[key], hf.Value) trailer[key] = append(trailer[key], hf.Value)
} }
cs.trailer = trailer cs.trailer = trailer
@ -2760,8 +2833,10 @@ func (rl *clientConnReadLoop) processSettingsNoWrite(f *SettingsFrame) error {
cc.cond.Broadcast() cc.cond.Broadcast()
cc.initialWindowSize = s.Val cc.initialWindowSize = s.Val
case SettingHeaderTableSize:
cc.henc.SetMaxDynamicTableSize(s.Val)
cc.peerMaxHeaderTableSize = s.Val
default: default:
// TODO(bradfitz): handle more settings? SETTINGS_HEADER_TABLE_SIZE probably.
cc.vlogf("Unhandled Setting: %v", s) cc.vlogf("Unhandled Setting: %v", s)
} }
return nil return nil
@ -2985,7 +3060,11 @@ func (gz *gzipReader) Read(p []byte) (n int, err error) {
} }
func (gz *gzipReader) Close() error { func (gz *gzipReader) Close() error {
return gz.body.Close() if err := gz.body.Close(); err != nil {
return err
}
gz.zerr = fs.ErrClosed
return nil
} }
type errorReader struct{ err error } type errorReader struct{ err error }

View file

@ -0,0 +1,30 @@
// Code generated by cmd/cgo -godefs; DO NOT EDIT.
// cgo -godefs defs_openbsd.go
package socket
type iovec struct {
Base *byte
Len uint64
}
type msghdr struct {
Name *byte
Namelen uint32
Iov *iovec
Iovlen uint32
Control *byte
Controllen uint32
Flags int32
}
type cmsghdr struct {
Len uint32
Level int32
Type int32
}
const (
sizeofIovec = 0x10
sizeofMsghdr = 0x30
)

View file

@ -0,0 +1,30 @@
// Code generated by cmd/cgo -godefs; DO NOT EDIT.
// cgo -godefs defs_openbsd.go
package socket
type iovec struct {
Base *byte
Len uint64
}
type msghdr struct {
Name *byte
Namelen uint32
Iov *iovec
Iovlen uint32
Control *byte
Controllen uint32
Flags int32
}
type cmsghdr struct {
Len uint32
Level int32
Type int32
}
const (
sizeofIovec = 0x10
sizeofMsghdr = 0x30
)

BIN
vendor/golang.org/x/net/publicsuffix/data/children generated vendored Normal file

Binary file not shown.

BIN
vendor/golang.org/x/net/publicsuffix/data/nodes generated vendored Normal file

Binary file not shown.

1
vendor/golang.org/x/net/publicsuffix/data/text generated vendored Normal file

File diff suppressed because one or more lines are too long

View file

@ -101,10 +101,10 @@ func PublicSuffix(domain string) (publicSuffix string, icann bool) {
break break
} }
u := uint32(nodeValue(f) >> (nodesBitsTextOffset + nodesBitsTextLength)) u := uint32(nodes.get(f) >> (nodesBitsTextOffset + nodesBitsTextLength))
icannNode = u&(1<<nodesBitsICANN-1) != 0 icannNode = u&(1<<nodesBitsICANN-1) != 0
u >>= nodesBitsICANN u >>= nodesBitsICANN
u = children[u&(1<<nodesBitsChildren-1)] u = children.get(u & (1<<nodesBitsChildren - 1))
lo = u & (1<<childrenBitsLo - 1) lo = u & (1<<childrenBitsLo - 1)
u >>= childrenBitsLo u >>= childrenBitsLo
hi = u & (1<<childrenBitsHi - 1) hi = u & (1<<childrenBitsHi - 1)
@ -154,18 +154,9 @@ func find(label string, lo, hi uint32) uint32 {
return notFound return notFound
} }
func nodeValue(i uint32) uint64 {
off := uint64(i * (nodesBits / 8))
return uint64(nodes[off])<<32 |
uint64(nodes[off+1])<<24 |
uint64(nodes[off+2])<<16 |
uint64(nodes[off+3])<<8 |
uint64(nodes[off+4])
}
// nodeLabel returns the label for the i'th node. // nodeLabel returns the label for the i'th node.
func nodeLabel(i uint32) string { func nodeLabel(i uint32) string {
x := nodeValue(i) x := nodes.get(i)
length := x & (1<<nodesBitsTextLength - 1) length := x & (1<<nodesBitsTextLength - 1)
x >>= nodesBitsTextLength x >>= nodesBitsTextLength
offset := x & (1<<nodesBitsTextOffset - 1) offset := x & (1<<nodesBitsTextOffset - 1)
@ -189,3 +180,24 @@ func EffectiveTLDPlusOne(domain string) (string, error) {
} }
return domain[1+strings.LastIndex(domain[:i], "."):], nil return domain[1+strings.LastIndex(domain[:i], "."):], nil
} }
type uint32String string
func (u uint32String) get(i uint32) uint32 {
off := i * 4
return (uint32(u[off])<<24 |
uint32(u[off+1])<<16 |
uint32(u[off+2])<<8 |
uint32(u[off+3]))
}
type uint40String string
func (u uint40String) get(i uint32) uint64 {
off := uint64(i * (nodesBits / 8))
return uint64(u[off])<<32 |
uint64(u[off+1])<<24 |
uint64(u[off+2])<<16 |
uint64(u[off+3])<<8 |
uint64(u[off+4])
}

File diff suppressed because it is too large Load diff

View file

@ -7,9 +7,11 @@
package execabs package execabs
import "strings" import (
"errors"
"os/exec"
)
func isGo119ErrDot(err error) bool { func isGo119ErrDot(err error) bool {
// TODO: return errors.Is(err, exec.ErrDot) return errors.Is(err, exec.ErrDot)
return strings.Contains(err.Error(), "current directory")
} }

View file

@ -367,6 +367,7 @@ func NewCallbackCDecl(fn interface{}) uintptr {
//sys IsWindowUnicode(hwnd HWND) (isUnicode bool) = user32.IsWindowUnicode //sys IsWindowUnicode(hwnd HWND) (isUnicode bool) = user32.IsWindowUnicode
//sys IsWindowVisible(hwnd HWND) (isVisible bool) = user32.IsWindowVisible //sys IsWindowVisible(hwnd HWND) (isVisible bool) = user32.IsWindowVisible
//sys GetGUIThreadInfo(thread uint32, info *GUIThreadInfo) (err error) = user32.GetGUIThreadInfo //sys GetGUIThreadInfo(thread uint32, info *GUIThreadInfo) (err error) = user32.GetGUIThreadInfo
//sys GetLargePageMinimum() (size uintptr)
// Volume Management Functions // Volume Management Functions
//sys DefineDosDevice(flags uint32, deviceName *uint16, targetPath *uint16) (err error) = DefineDosDeviceW //sys DefineDosDevice(flags uint32, deviceName *uint16, targetPath *uint16) (err error) = DefineDosDeviceW

View file

@ -252,6 +252,7 @@ func errnoErr(e syscall.Errno) error {
procGetFileType = modkernel32.NewProc("GetFileType") procGetFileType = modkernel32.NewProc("GetFileType")
procGetFinalPathNameByHandleW = modkernel32.NewProc("GetFinalPathNameByHandleW") procGetFinalPathNameByHandleW = modkernel32.NewProc("GetFinalPathNameByHandleW")
procGetFullPathNameW = modkernel32.NewProc("GetFullPathNameW") procGetFullPathNameW = modkernel32.NewProc("GetFullPathNameW")
procGetLargePageMinimum = modkernel32.NewProc("GetLargePageMinimum")
procGetLastError = modkernel32.NewProc("GetLastError") procGetLastError = modkernel32.NewProc("GetLastError")
procGetLogicalDriveStringsW = modkernel32.NewProc("GetLogicalDriveStringsW") procGetLogicalDriveStringsW = modkernel32.NewProc("GetLogicalDriveStringsW")
procGetLogicalDrives = modkernel32.NewProc("GetLogicalDrives") procGetLogicalDrives = modkernel32.NewProc("GetLogicalDrives")
@ -2180,6 +2181,12 @@ func GetFullPathName(path *uint16, buflen uint32, buf *uint16, fname **uint16) (
return return
} }
func GetLargePageMinimum() (size uintptr) {
r0, _, _ := syscall.Syscall(procGetLargePageMinimum.Addr(), 0, 0, 0, 0)
size = uintptr(r0)
return
}
func GetLastError() (lasterr error) { func GetLastError() (lasterr error) {
r0, _, _ := syscall.Syscall(procGetLastError.Addr(), 0, 0, 0, 0) r0, _, _ := syscall.Syscall(procGetLastError.Addr(), 0, 0, 0, 0)
if r0 != 0 { if r0 != 0 {

11
vendor/modules.txt vendored
View file

@ -134,7 +134,7 @@ github.com/gin-contrib/sessions/memstore
# github.com/gin-contrib/sse v0.1.0 # github.com/gin-contrib/sse v0.1.0
## explicit; go 1.12 ## explicit; go 1.12
github.com/gin-contrib/sse github.com/gin-contrib/sse
# github.com/gin-gonic/gin v1.8.1 # github.com/gin-gonic/gin v1.8.2
## explicit; go 1.18 ## explicit; go 1.18
github.com/gin-gonic/gin github.com/gin-gonic/gin
github.com/gin-gonic/gin/binding github.com/gin-gonic/gin/binding
@ -323,12 +323,13 @@ github.com/oklog/ulid
# github.com/pelletier/go-toml v1.9.5 # github.com/pelletier/go-toml v1.9.5
## explicit; go 1.12 ## explicit; go 1.12
github.com/pelletier/go-toml github.com/pelletier/go-toml
# github.com/pelletier/go-toml/v2 v2.0.5 # github.com/pelletier/go-toml/v2 v2.0.6
## explicit; go 1.16 ## explicit; go 1.16
github.com/pelletier/go-toml/v2 github.com/pelletier/go-toml/v2
github.com/pelletier/go-toml/v2/internal/ast github.com/pelletier/go-toml/v2/internal/characters
github.com/pelletier/go-toml/v2/internal/danger github.com/pelletier/go-toml/v2/internal/danger
github.com/pelletier/go-toml/v2/internal/tracker github.com/pelletier/go-toml/v2/internal/tracker
github.com/pelletier/go-toml/v2/unstable
# github.com/pkg/errors v0.9.1 # github.com/pkg/errors v0.9.1
## explicit ## explicit
github.com/pkg/errors github.com/pkg/errors
@ -686,7 +687,7 @@ golang.org/x/image/webp
# golang.org/x/mod v0.6.0-dev.0.20220907135952-02c991387e35 # golang.org/x/mod v0.6.0-dev.0.20220907135952-02c991387e35
## explicit; go 1.17 ## explicit; go 1.17
golang.org/x/mod/semver golang.org/x/mod/semver
# golang.org/x/net v0.0.0-20221014081412-f15817d10f9b # golang.org/x/net v0.4.0
## explicit; go 1.17 ## explicit; go 1.17
golang.org/x/net/bpf golang.org/x/net/bpf
golang.org/x/net/context golang.org/x/net/context
@ -707,7 +708,7 @@ golang.org/x/net/publicsuffix
## explicit; go 1.17 ## explicit; go 1.17
golang.org/x/oauth2 golang.org/x/oauth2
golang.org/x/oauth2/internal golang.org/x/oauth2/internal
# golang.org/x/sys v0.2.0 # golang.org/x/sys v0.3.0
## explicit; go 1.17 ## explicit; go 1.17
golang.org/x/sys/cpu golang.org/x/sys/cpu
golang.org/x/sys/execabs golang.org/x/sys/execabs