Compare commits

..

1 commit

Author SHA1 Message Date
Daenney 749fc5d116
Merge 9acf9ddc75 into 45e1609377 2024-11-07 12:14:46 +00:00
145 changed files with 1613 additions and 6056 deletions

View file

@ -1,7 +1,5 @@
# Storage # Storage
When configuring an object storage backend, the `storage-s3-endpoint` **must not** include the bucket name. That's what `s3-bucket-name` is for. Using subfolders in a bucket isn't currently supported.
## Settings ## Settings
```yaml ```yaml

53
go.mod
View file

@ -2,24 +2,8 @@ module github.com/superseriousbusiness/gotosocial
go 1.22.2 go 1.22.2
// Replace modernc/sqlite with our version that fixes the concurrency INTERRUPT issue
replace modernc.org/sqlite => gitlab.com/NyaaaWhatsUpDoc/sqlite v1.33.1-concurrency-workaround replace modernc.org/sqlite => gitlab.com/NyaaaWhatsUpDoc/sqlite v1.33.1-concurrency-workaround
// Below pin otel libraries to v1.29.0 until we can figure out issues
replace go.opentelemetry.io/otel => go.opentelemetry.io/otel v1.29.0
replace go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc => go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.29.0
replace go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp => go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.29.0
replace go.opentelemetry.io/otel/metric => go.opentelemetry.io/otel/metric v1.29.0
replace go.opentelemetry.io/otel/sdk => go.opentelemetry.io/otel/sdk v1.29.0
replace go.opentelemetry.io/otel/sdk/metric => go.opentelemetry.io/otel/sdk/metric v1.29.0
replace go.opentelemetry.io/otel/trace => go.opentelemetry.io/otel/trace v1.29.0
require ( require (
codeberg.org/gruf/go-bytes v1.0.2 codeberg.org/gruf/go-bytes v1.0.2
codeberg.org/gruf/go-bytesize v1.0.3 codeberg.org/gruf/go-bytesize v1.0.3
@ -73,26 +57,26 @@ require (
github.com/tetratelabs/wazero v1.8.1 github.com/tetratelabs/wazero v1.8.1
github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80
github.com/ulule/limiter/v3 v3.11.2 github.com/ulule/limiter/v3 v3.11.2
github.com/uptrace/bun v1.2.5 github.com/uptrace/bun v1.2.1
github.com/uptrace/bun/dialect/pgdialect v1.2.5 github.com/uptrace/bun/dialect/pgdialect v1.2.1
github.com/uptrace/bun/dialect/sqlitedialect v1.2.5 github.com/uptrace/bun/dialect/sqlitedialect v1.2.1
github.com/uptrace/bun/extra/bunotel v1.2.5 github.com/uptrace/bun/extra/bunotel v1.2.1
github.com/wagslane/go-password-validator v0.3.0 github.com/wagslane/go-password-validator v0.3.0
github.com/yuin/goldmark v1.7.8 github.com/yuin/goldmark v1.7.8
go.opentelemetry.io/otel v1.32.0 go.opentelemetry.io/otel v1.29.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.32.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.29.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.32.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.29.0
go.opentelemetry.io/otel/exporters/prometheus v0.51.0 go.opentelemetry.io/otel/exporters/prometheus v0.51.0
go.opentelemetry.io/otel/metric v1.32.0 go.opentelemetry.io/otel/metric v1.29.0
go.opentelemetry.io/otel/sdk v1.32.0 go.opentelemetry.io/otel/sdk v1.29.0
go.opentelemetry.io/otel/sdk/metric v1.32.0 go.opentelemetry.io/otel/sdk/metric v1.29.0
go.opentelemetry.io/otel/trace v1.32.0 go.opentelemetry.io/otel/trace v1.29.0
go.uber.org/automaxprocs v1.6.0 go.uber.org/automaxprocs v1.6.0
golang.org/x/crypto v0.29.0 golang.org/x/crypto v0.28.0
golang.org/x/image v0.22.0 golang.org/x/image v0.21.0
golang.org/x/net v0.31.0 golang.org/x/net v0.30.0
golang.org/x/oauth2 v0.23.0 golang.org/x/oauth2 v0.23.0
golang.org/x/text v0.20.0 golang.org/x/text v0.19.0
gopkg.in/mcuadros/go-syslog.v2 v2.3.0 gopkg.in/mcuadros/go-syslog.v2 v2.3.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
modernc.org/sqlite v0.0.0-00010101000000-000000000000 modernc.org/sqlite v0.0.0-00010101000000-000000000000
@ -197,7 +181,6 @@ require (
github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.59.1 // indirect github.com/prometheus/common v0.59.1 // indirect
github.com/prometheus/procfs v0.15.1 // indirect github.com/prometheus/procfs v0.15.1 // indirect
github.com/puzpuzpuz/xsync/v3 v3.4.0 // indirect
github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b // indirect github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect
@ -218,7 +201,7 @@ require (
github.com/toqueteos/webbrowser v1.2.0 // indirect github.com/toqueteos/webbrowser v1.2.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect github.com/ugorji/go/codec v1.2.12 // indirect
github.com/uptrace/opentelemetry-go-extra/otelsql v0.3.2 // indirect github.com/uptrace/opentelemetry-go-extra/otelsql v0.2.4 // indirect
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
go.mongodb.org/mongo-driver v1.14.0 // indirect go.mongodb.org/mongo-driver v1.14.0 // indirect
@ -228,8 +211,8 @@ require (
golang.org/x/arch v0.8.0 // indirect golang.org/x/arch v0.8.0 // indirect
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 // indirect golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 // indirect
golang.org/x/mod v0.18.0 // indirect golang.org/x/mod v0.18.0 // indirect
golang.org/x/sync v0.9.0 // indirect golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.27.0 // indirect golang.org/x/sys v0.26.0 // indirect
golang.org/x/tools v0.22.0 // indirect golang.org/x/tools v0.22.0 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect

50
go.sum generated
View file

@ -469,8 +469,6 @@ github.com/prometheus/common v0.59.1 h1:LXb1quJHWm1P6wq/U824uxYi4Sg0oGvNeUm1z5dJ
github.com/prometheus/common v0.59.1/go.mod h1:GpWM7dewqmVYcd7SmRaiWVe9SSqjf0UrwnYnpEZNuT0= github.com/prometheus/common v0.59.1/go.mod h1:GpWM7dewqmVYcd7SmRaiWVe9SSqjf0UrwnYnpEZNuT0=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/puzpuzpuz/xsync/v3 v3.4.0 h1:DuVBAdXuGFHv8adVXjWWZ63pJq+NRXOWVXlKDBZ+mJ4=
github.com/puzpuzpuz/xsync/v3 v3.4.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b h1:aUNXCGgukb4gtY99imuIeoh8Vr0GSwAlYxPAhqZrpFc= github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b h1:aUNXCGgukb4gtY99imuIeoh8Vr0GSwAlYxPAhqZrpFc=
github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b/go.mod h1:wTPjTepVu7uJBYgZ0SdWHQlIas582j6cn2jgk4DDdlg= github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b/go.mod h1:wTPjTepVu7uJBYgZ0SdWHQlIas582j6cn2jgk4DDdlg=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
@ -580,16 +578,16 @@ github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65E
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/ulule/limiter/v3 v3.11.2 h1:P4yOrxoEMJbOTfRJR2OzjL90oflzYPPmWg+dvwN2tHA= github.com/ulule/limiter/v3 v3.11.2 h1:P4yOrxoEMJbOTfRJR2OzjL90oflzYPPmWg+dvwN2tHA=
github.com/ulule/limiter/v3 v3.11.2/go.mod h1:QG5GnFOCV+k7lrL5Y8kgEeeflPH3+Cviqlqa8SVSQxI= github.com/ulule/limiter/v3 v3.11.2/go.mod h1:QG5GnFOCV+k7lrL5Y8kgEeeflPH3+Cviqlqa8SVSQxI=
github.com/uptrace/bun v1.2.5 h1:gSprL5xiBCp+tzcZHgENzJpXnmQwRM/A6s4HnBF85mc= github.com/uptrace/bun v1.2.1 h1:2ENAcfeCfaY5+2e7z5pXrzFKy3vS8VXvkCag6N2Yzfk=
github.com/uptrace/bun v1.2.5/go.mod h1:vkQMS4NNs4VNZv92y53uBSHXRqYyJp4bGhMHgaNCQpY= github.com/uptrace/bun v1.2.1/go.mod h1:cNg+pWBUMmJ8rHnETgf65CEvn3aIKErrwOD6IA8e+Ec=
github.com/uptrace/bun/dialect/pgdialect v1.2.5 h1:dWLUxpjTdglzfBks2x+U2WIi+nRVjuh7Z3DLYVFswJk= github.com/uptrace/bun/dialect/pgdialect v1.2.1 h1:ceP99r03u+s8ylaDE/RzgcajwGiC76Jz3nS2ZgyPQ4M=
github.com/uptrace/bun/dialect/pgdialect v1.2.5/go.mod h1:stwnlE8/6x8cuQ2aXcZqwDK/d+6jxgO3iQewflJT6C4= github.com/uptrace/bun/dialect/pgdialect v1.2.1/go.mod h1:mv6B12cisvSc6bwKm9q9wcrr26awkZK8QXM+nso9n2U=
github.com/uptrace/bun/dialect/sqlitedialect v1.2.5 h1:liDvMaIWrN8DrHcxVbviOde/VDss9uhcqpcTSL3eJjc= github.com/uptrace/bun/dialect/sqlitedialect v1.2.1 h1:IprvkIKUjEjvt4VKpcmLpbMIucjrsmUPJOSlg19+a0Q=
github.com/uptrace/bun/dialect/sqlitedialect v1.2.5/go.mod h1:Mw6IDL/jNUL5ozcREAezOJSZ9Jm4LJlfoaXxBEfNBlM= github.com/uptrace/bun/dialect/sqlitedialect v1.2.1/go.mod h1:mMQf4NUpgY8bnOanxGmxNiHCdALOggS4cZ3v63a9D/o=
github.com/uptrace/bun/extra/bunotel v1.2.5 h1:kkuuTbrG9d5leYZuSBKhq2gtq346lIrxf98Mig2y128= github.com/uptrace/bun/extra/bunotel v1.2.1 h1:5oTy3Jh7Q1bhCd5vnPszBmJgYouw+PuuZ8iSCm+uNCQ=
github.com/uptrace/bun/extra/bunotel v1.2.5/go.mod h1:rCHLszRZwppWE9cGDodO2FCI1qCrLwDjONp38KD3bA8= github.com/uptrace/bun/extra/bunotel v1.2.1/go.mod h1:SWW3HyjiXPYM36q0QSpdtTP8v21nWHnTCxu4lYkpO90=
github.com/uptrace/opentelemetry-go-extra/otelsql v0.3.2 h1:ZjUj9BLYf9PEqBn8W/OapxhPjVRdC6CsXTdULHsyk5c= github.com/uptrace/opentelemetry-go-extra/otelsql v0.2.4 h1:x3omFAG2XkvWFg1hvXRinY2ExAL1Aacl7W9ZlYjo6gc=
github.com/uptrace/opentelemetry-go-extra/otelsql v0.3.2/go.mod h1:O8bHQfyinKwTXKkiKNGmLQS7vRsqRxIQTFZpYpHK3IQ= github.com/uptrace/opentelemetry-go-extra/otelsql v0.2.4/go.mod h1:qMKJr5fTnY0p7hqCQMNrAk62bCARWR5rAbTrGUFRuh4=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.14.0/go.mod h1:ol1PCaL0dX20wC0htZ7sYCsvCYmrouYra0zHzaclZhE= github.com/valyala/fasthttp v1.14.0/go.mod h1:ol1PCaL0dX20wC0htZ7sYCsvCYmrouYra0zHzaclZhE=
@ -668,8 +666,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4=
golang.org/x/crypto v0.29.0 h1:L5SG1JTTXupVV3n6sUqMTeWbjAyfPwoda2DLX8J8FrQ= golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
golang.org/x/crypto v0.29.0/go.mod h1:+F4F4N5hv6v38hfeYwTdx20oUvLLc+QfrE9Ax9HtgRg= golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@ -684,8 +682,8 @@ golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 h1:LfspQV/FYTatPTr/3HzIcmiUF
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc= golang.org/x/exp v0.0.0-20240222234643-814bf88cf225/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/image v0.22.0 h1:UtK5yLUzilVrkjMAZAZ34DXGpASN8i8pj8g+O+yd10g= golang.org/x/image v0.21.0 h1:c5qV36ajHpdj4Qi0GnE0jUc/yuo33OLFaa0d+crTD5s=
golang.org/x/image v0.22.0/go.mod h1:9hPFhljd4zZ1GNSIZJ49sqbp45GKK9t6w+iXvGqZUz4= golang.org/x/image v0.21.0/go.mod h1:vUbsLavqK/W303ZroQQVKQ+Af3Yl6Uz1Ppu5J/cLz78=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
@ -739,8 +737,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY=
golang.org/x/net v0.31.0 h1:68CPQngjLL0r2AlUKiSxtQFKvzRVbnzLwMUn5SzcLHo= golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM= golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@ -758,8 +756,8 @@ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.9.0 h1:fEo0HyrW1GIgZdpbhCRO0PkJajUS5H9IFUztCgEo2jQ= golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
golang.org/x/sync v0.9.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@ -800,13 +798,13 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc=
golang.org/x/term v0.26.0 h1:WEQa6V3Gja/BhNxg540hBip/kkaYtRg3cxg4oXSw4AU= golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24=
golang.org/x/term v0.26.0/go.mod h1:Si5m1o57C5nBNQo5z1iq+XDijt21BDBDp2bK0QI8e3E= golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@ -814,8 +812,8 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug= golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4= golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=

View file

@ -36,7 +36,6 @@
"github.com/superseriousbusiness/gotosocial/internal/paging" "github.com/superseriousbusiness/gotosocial/internal/paging"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util" "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices"
"github.com/uptrace/bun" "github.com/uptrace/bun"
"github.com/uptrace/bun/dialect" "github.com/uptrace/bun/dialect"
) )
@ -87,7 +86,7 @@ func(uncached []string) ([]*gtsmodel.Account, error) {
// Reorder the statuses by their // Reorder the statuses by their
// IDs to ensure in correct order. // IDs to ensure in correct order.
getID := func(a *gtsmodel.Account) string { return a.ID } getID := func(a *gtsmodel.Account) string { return a.ID }
xslices.OrderBy(accounts, ids, getID) util.OrderBy(accounts, ids, getID)
if gtscontext.Barebones(ctx) { if gtscontext.Barebones(ctx) {
// no need to fully populate. // no need to fully populate.

View file

@ -22,7 +22,7 @@
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices" "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -169,7 +169,7 @@ func(uncached []string) ([]*gtsmodel.Token, error) {
// Reoroder the tokens by their // Reoroder the tokens by their
// IDs to ensure in correct order. // IDs to ensure in correct order.
getID := func(t *gtsmodel.Token) string { return t.ID } getID := func(t *gtsmodel.Token) string { return t.ID }
xslices.OrderBy(tokens, tokenIDs, getID) util.OrderBy(tokens, tokenIDs, getID)
return tokens, nil return tokens, nil
} }

View file

@ -31,7 +31,7 @@
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/paging" "github.com/superseriousbusiness/gotosocial/internal/paging"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices" "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
"github.com/uptrace/bun/dialect" "github.com/uptrace/bun/dialect"
) )
@ -209,7 +209,7 @@ func(accountID string, uncached []string) ([]*gtsmodel.Conversation, error) {
// Reorder the conversations by their last status IDs to ensure correct order. // Reorder the conversations by their last status IDs to ensure correct order.
getID := func(b *gtsmodel.Conversation) string { return b.ID } getID := func(b *gtsmodel.Conversation) string { return b.ID }
xslices.OrderBy(conversations, conversationLastStatusIDs, getID) util.OrderBy(conversations, conversationLastStatusIDs, getID)
if gtscontext.Barebones(ctx) { if gtscontext.Barebones(ctx) {
// no need to fully populate. // no need to fully populate.
@ -558,7 +558,7 @@ func (c *conversationDB) DeleteStatusFromConversations(ctx context.Context, stat
// Invalidate cache entries. // Invalidate cache entries.
updatedConversationIDs = append(updatedConversationIDs, deletedConversationIDs...) updatedConversationIDs = append(updatedConversationIDs, deletedConversationIDs...)
updatedConversationIDs = xslices.Deduplicate(updatedConversationIDs) updatedConversationIDs = util.Deduplicate(updatedConversationIDs)
c.state.Caches.DB.Conversation.InvalidateIDs("ID", updatedConversationIDs) c.state.Caches.DB.Conversation.InvalidateIDs("ID", updatedConversationIDs)
return nil return nil

View file

@ -31,7 +31,7 @@
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/paging" "github.com/superseriousbusiness/gotosocial/internal/paging"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices" "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
"github.com/uptrace/bun/dialect" "github.com/uptrace/bun/dialect"
) )
@ -597,7 +597,7 @@ func(uncached []string) ([]*gtsmodel.Emoji, error) {
// Reorder the emojis by their // Reorder the emojis by their
// IDs to ensure in correct order. // IDs to ensure in correct order.
getID := func(e *gtsmodel.Emoji) string { return e.ID } getID := func(e *gtsmodel.Emoji) string { return e.ID }
xslices.OrderBy(emojis, ids, getID) util.OrderBy(emojis, ids, getID)
if gtscontext.Barebones(ctx) { if gtscontext.Barebones(ctx) {
// no need to fully populate. // no need to fully populate.
@ -661,7 +661,7 @@ func(uncached []string) ([]*gtsmodel.EmojiCategory, error) {
// Reorder the categories by their // Reorder the categories by their
// IDs to ensure in correct order. // IDs to ensure in correct order.
getID := func(c *gtsmodel.EmojiCategory) string { return c.ID } getID := func(c *gtsmodel.EmojiCategory) string { return c.ID }
xslices.OrderBy(categories, ids, getID) util.OrderBy(categories, ids, getID)
return categories, nil return categories, nil
} }

View file

@ -27,7 +27,7 @@
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices" "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -99,7 +99,7 @@ func(uncached []string) ([]*gtsmodel.Filter, error) {
} }
// Put the filter structs in the same order as the filter IDs. // Put the filter structs in the same order as the filter IDs.
xslices.OrderBy(filters, filterIDs, func(filter *gtsmodel.Filter) string { return filter.ID }) util.OrderBy(filters, filterIDs, func(filter *gtsmodel.Filter) string { return filter.ID })
if gtscontext.Barebones(ctx) { if gtscontext.Barebones(ctx) {
return filters, nil return filters, nil

View file

@ -26,7 +26,7 @@
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices" "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -140,7 +140,7 @@ func(uncached []string) ([]*gtsmodel.FilterKeyword, error) {
} }
// Put the filter keyword structs in the same order as the filter keyword IDs. // Put the filter keyword structs in the same order as the filter keyword IDs.
xslices.OrderBy(filterKeywords, filterKeywordIDs, func(filterKeyword *gtsmodel.FilterKeyword) string { util.OrderBy(filterKeywords, filterKeywordIDs, func(filterKeyword *gtsmodel.FilterKeyword) string {
return filterKeyword.ID return filterKeyword.ID
}) })

View file

@ -25,7 +25,7 @@
"github.com/superseriousbusiness/gotosocial/internal/gtscontext" "github.com/superseriousbusiness/gotosocial/internal/gtscontext"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices" "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -116,7 +116,7 @@ func(uncached []string) ([]*gtsmodel.FilterStatus, error) {
} }
// Put the filter status structs in the same order as the filter status IDs. // Put the filter status structs in the same order as the filter status IDs.
xslices.OrderBy(filterStatuses, filterStatusIDs, func(filterStatus *gtsmodel.FilterStatus) string { util.OrderBy(filterStatuses, filterStatusIDs, func(filterStatus *gtsmodel.FilterStatus) string {
return filterStatus.ID return filterStatus.ID
}) })

View file

@ -29,7 +29,7 @@
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/paging" "github.com/superseriousbusiness/gotosocial/internal/paging"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices" "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -113,7 +113,7 @@ func(uncached []string) ([]*gtsmodel.InteractionRequest, error) {
// Reorder the requests by their // Reorder the requests by their
// IDs to ensure in correct order. // IDs to ensure in correct order.
getID := func(r *gtsmodel.InteractionRequest) string { return r.ID } getID := func(r *gtsmodel.InteractionRequest) string { return r.ID }
xslices.OrderBy(requests, ids, getID) util.OrderBy(requests, ids, getID)
if gtscontext.Barebones(ctx) { if gtscontext.Barebones(ctx) {
// no need to fully populate. // no need to fully populate.

View file

@ -31,7 +31,7 @@
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/paging" "github.com/superseriousbusiness/gotosocial/internal/paging"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices" "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -333,7 +333,7 @@ func(uncached []string) ([]*gtsmodel.List, error) {
// Reorder the lists by their // Reorder the lists by their
// IDs to ensure in correct order. // IDs to ensure in correct order.
getID := func(l *gtsmodel.List) string { return l.ID } getID := func(l *gtsmodel.List) string { return l.ID }
xslices.OrderBy(lists, ids, getID) util.OrderBy(lists, ids, getID)
if gtscontext.Barebones(ctx) { if gtscontext.Barebones(ctx) {
// no need to fully populate. // no need to fully populate.
@ -387,12 +387,12 @@ func (l *listDB) PutListEntries(ctx context.Context, entries []*gtsmodel.ListEnt
} }
// Collect unique list IDs from the provided list entries. // Collect unique list IDs from the provided list entries.
listIDs := xslices.Collate(entries, func(e *gtsmodel.ListEntry) string { listIDs := util.Collate(entries, func(e *gtsmodel.ListEntry) string {
return e.ListID return e.ListID
}) })
// Collect unique follow IDs from the provided list entries. // Collect unique follow IDs from the provided list entries.
followIDs := xslices.Collate(entries, func(e *gtsmodel.ListEntry) string { followIDs := util.Collate(entries, func(e *gtsmodel.ListEntry) string {
return e.FollowID return e.FollowID
}) })
@ -441,7 +441,7 @@ func (l *listDB) DeleteAllListEntriesByFollows(ctx context.Context, followIDs ..
} }
// Deduplicate IDs before invalidate. // Deduplicate IDs before invalidate.
listIDs = xslices.Deduplicate(listIDs) listIDs = util.Deduplicate(listIDs)
// Invalidate all related list entry caches. // Invalidate all related list entry caches.
l.invalidateEntryCaches(ctx, listIDs, followIDs) l.invalidateEntryCaches(ctx, listIDs, followIDs)

View file

@ -28,7 +28,7 @@
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/paging" "github.com/superseriousbusiness/gotosocial/internal/paging"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices" "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -78,7 +78,7 @@ func(uncached []string) ([]*gtsmodel.MediaAttachment, error) {
// Reorder the media by their // Reorder the media by their
// IDs to ensure in correct order. // IDs to ensure in correct order.
getID := func(m *gtsmodel.MediaAttachment) string { return m.ID } getID := func(m *gtsmodel.MediaAttachment) string { return m.ID }
xslices.OrderBy(media, ids, getID) util.OrderBy(media, ids, getID)
return media, nil return media, nil
} }

View file

@ -28,7 +28,7 @@
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices" "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -91,7 +91,7 @@ func(uncached []string) ([]*gtsmodel.Mention, error) {
// Reorder the mentions by their // Reorder the mentions by their
// IDs to ensure in correct order. // IDs to ensure in correct order.
getID := func(m *gtsmodel.Mention) string { return m.ID } getID := func(m *gtsmodel.Mention) string { return m.ID }
xslices.OrderBy(mentions, ids, getID) util.OrderBy(mentions, ids, getID)
if gtscontext.Barebones(ctx) { if gtscontext.Barebones(ctx) {
// no need to fully populate. // no need to fully populate.

View file

@ -29,7 +29,7 @@
"github.com/superseriousbusiness/gotosocial/internal/id" "github.com/superseriousbusiness/gotosocial/internal/id"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices" "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -130,7 +130,7 @@ func(uncached []string) ([]*gtsmodel.Notification, error) {
// Reorder the notifs by their // Reorder the notifs by their
// IDs to ensure in correct order. // IDs to ensure in correct order.
getID := func(n *gtsmodel.Notification) string { return n.ID } getID := func(n *gtsmodel.Notification) string { return n.ID }
xslices.OrderBy(notifs, ids, getID) util.OrderBy(notifs, ids, getID)
if gtscontext.Barebones(ctx) { if gtscontext.Barebones(ctx) {
// no need to fully populate. // no need to fully populate.

View file

@ -29,7 +29,7 @@
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices" "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -315,7 +315,7 @@ func(uncached []string) ([]*gtsmodel.PollVote, error) {
// Reorder the poll votes by their // Reorder the poll votes by their
// IDs to ensure in correct order. // IDs to ensure in correct order.
getID := func(v *gtsmodel.PollVote) string { return v.ID } getID := func(v *gtsmodel.PollVote) string { return v.ID }
xslices.OrderBy(votes, voteIDs, getID) util.OrderBy(votes, voteIDs, getID)
if gtscontext.Barebones(ctx) { if gtscontext.Barebones(ctx) {
// no need to fully populate. // no need to fully populate.

View file

@ -27,7 +27,7 @@
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices" "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -127,7 +127,7 @@ func(uncached []string) ([]*gtsmodel.Block, error) {
// Reorder the blocks by their // Reorder the blocks by their
// IDs to ensure in correct order. // IDs to ensure in correct order.
getID := func(b *gtsmodel.Block) string { return b.ID } getID := func(b *gtsmodel.Block) string { return b.ID }
xslices.OrderBy(blocks, ids, getID) util.OrderBy(blocks, ids, getID)
if gtscontext.Barebones(ctx) { if gtscontext.Barebones(ctx) {
// no need to fully populate. // no need to fully populate.

View file

@ -28,7 +28,7 @@
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices" "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -103,7 +103,7 @@ func(uncached []string) ([]*gtsmodel.Follow, error) {
// Reorder the follows by their // Reorder the follows by their
// IDs to ensure in correct order. // IDs to ensure in correct order.
getID := func(f *gtsmodel.Follow) string { return f.ID } getID := func(f *gtsmodel.Follow) string { return f.ID }
xslices.OrderBy(follows, ids, getID) util.OrderBy(follows, ids, getID)
if gtscontext.Barebones(ctx) { if gtscontext.Barebones(ctx) {
// no need to fully populate. // no need to fully populate.
@ -376,7 +376,7 @@ func (r *relationshipDB) DeleteAccountFollows(ctx context.Context, accountID str
} }
// Gather the follow IDs that were deleted for removing related list entries. // Gather the follow IDs that were deleted for removing related list entries.
followIDs := xslices.Gather(nil, deleted, func(follow *gtsmodel.Follow) string { followIDs := util.Gather(nil, deleted, func(follow *gtsmodel.Follow) string {
return follow.ID return follow.ID
}) })

View file

@ -28,7 +28,7 @@
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices" "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -103,7 +103,7 @@ func(uncached []string) ([]*gtsmodel.FollowRequest, error) {
// Reorder the requests by their // Reorder the requests by their
// IDs to ensure in correct order. // IDs to ensure in correct order.
getID := func(f *gtsmodel.FollowRequest) string { return f.ID } getID := func(f *gtsmodel.FollowRequest) string { return f.ID }
xslices.OrderBy(follows, ids, getID) util.OrderBy(follows, ids, getID)
if gtscontext.Barebones(ctx) { if gtscontext.Barebones(ctx) {
// no need to fully populate. // no need to fully populate.

View file

@ -28,7 +28,7 @@
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/paging" "github.com/superseriousbusiness/gotosocial/internal/paging"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices" "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
"github.com/uptrace/bun/dialect" "github.com/uptrace/bun/dialect"
) )
@ -109,7 +109,7 @@ func(uncached []string) ([]*gtsmodel.UserMute, error) {
// Reorder the mutes by their // Reorder the mutes by their
// IDs to ensure in correct order. // IDs to ensure in correct order.
getID := func(b *gtsmodel.UserMute) string { return b.ID } getID := func(b *gtsmodel.UserMute) string { return b.ID }
xslices.OrderBy(mutes, ids, getID) util.OrderBy(mutes, ids, getID)
if gtscontext.Barebones(ctx) { if gtscontext.Barebones(ctx) {
// no need to fully populate. // no need to fully populate.

View file

@ -29,7 +29,7 @@
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices" "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -76,7 +76,7 @@ func(uncached []string) ([]*gtsmodel.Status, error) {
// Reorder the statuses by their // Reorder the statuses by their
// IDs to ensure in correct order. // IDs to ensure in correct order.
getID := func(s *gtsmodel.Status) string { return s.ID } getID := func(s *gtsmodel.Status) string { return s.ID }
xslices.OrderBy(statuses, ids, getID) util.OrderBy(statuses, ids, getID)
if gtscontext.Barebones(ctx) { if gtscontext.Barebones(ctx) {
// no need to fully populate. // no need to fully populate.

View file

@ -28,7 +28,7 @@
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices" "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -95,7 +95,7 @@ func(uncached []string) ([]*gtsmodel.StatusBookmark, error) {
// Reorder the bookmarks by their // Reorder the bookmarks by their
// IDs to ensure in correct order. // IDs to ensure in correct order.
getID := func(b *gtsmodel.StatusBookmark) string { return b.ID } getID := func(b *gtsmodel.StatusBookmark) string { return b.ID }
xslices.OrderBy(bookmarks, ids, getID) util.OrderBy(bookmarks, ids, getID)
// Populate all loaded bookmarks, removing those we fail // Populate all loaded bookmarks, removing those we fail
// to populate (removes needing so many later nil checks). // to populate (removes needing so many later nil checks).

View file

@ -31,7 +31,7 @@
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices" "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -155,7 +155,7 @@ func(uncached []string) ([]*gtsmodel.StatusFave, error) {
// Reorder the statuses by their // Reorder the statuses by their
// IDs to ensure in correct order. // IDs to ensure in correct order.
getID := func(f *gtsmodel.StatusFave) string { return f.ID } getID := func(f *gtsmodel.StatusFave) string { return f.ID }
xslices.OrderBy(faves, faveIDs, getID) util.OrderBy(faves, faveIDs, getID)
if gtscontext.Barebones(ctx) { if gtscontext.Barebones(ctx) {
// no need to fully populate. // no need to fully populate.
@ -339,7 +339,7 @@ func (s *statusFaveDB) DeleteStatusFaves(ctx context.Context, targetAccountID st
} }
// Deduplicate determined status IDs. // Deduplicate determined status IDs.
statusIDs = xslices.Deduplicate(statusIDs) statusIDs = util.Deduplicate(statusIDs)
// Invalidate any cached status faves for this status ID. // Invalidate any cached status faves for this status ID.
s.state.Caches.DB.StatusFave.InvalidateIDs("ID", statusIDs) s.state.Caches.DB.StatusFave.InvalidateIDs("ID", statusIDs)

View file

@ -28,7 +28,7 @@
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/paging" "github.com/superseriousbusiness/gotosocial/internal/paging"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices" "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/uptrace/bun" "github.com/uptrace/bun"
) )
@ -102,7 +102,7 @@ func(uncached []string) ([]*gtsmodel.Tag, error) {
// Reorder the tags by their // Reorder the tags by their
// IDs to ensure in correct order. // IDs to ensure in correct order.
getID := func(t *gtsmodel.Tag) string { return t.ID } getID := func(t *gtsmodel.Tag) string { return t.ID }
xslices.OrderBy(tags, ids, getID) util.OrderBy(tags, ids, getID)
return tags, nil return tags, nil
} }
@ -301,5 +301,5 @@ func (t *tagDB) GetAccountIDsFollowingTagIDs(ctx context.Context, tagIDs []strin
// Accounts might be following multiple tags in list, // Accounts might be following multiple tags in list,
// but we only want to return each account once. // but we only want to return each account once.
return xslices.Deduplicate(accountIDs), nil return util.Deduplicate(accountIDs), nil
} }

View file

@ -35,7 +35,7 @@
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/uris" "github.com/superseriousbusiness/gotosocial/internal/uris"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices" "github.com/superseriousbusiness/gotosocial/internal/util"
) )
type errOtherIRIBlocked struct { type errOtherIRIBlocked struct {
@ -162,7 +162,7 @@ func (f *Federator) PostInboxRequestBodyHook(ctx context.Context, r *http.Reques
// OtherIRIs will likely contain some // OtherIRIs will likely contain some
// duplicate entries now, so remove them. // duplicate entries now, so remove them.
otherIRIs = xslices.DeduplicateFunc(otherIRIs, otherIRIs = util.DeduplicateFunc(otherIRIs,
(*url.URL).String, // serialized URL is 'key()' (*url.URL).String, // serialized URL is 'key()'
) )

View file

@ -22,7 +22,7 @@
"strings" "strings"
"time" "time"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices" "github.com/superseriousbusiness/gotosocial/internal/util"
) )
// Conversation represents direct messages between the owner account and a set of other accounts. // Conversation represents direct messages between the owner account and a set of other accounts.
@ -62,7 +62,7 @@ type Conversation struct {
// ConversationOtherAccountsKey creates an OtherAccountsKey from a list of OtherAccountIDs. // ConversationOtherAccountsKey creates an OtherAccountsKey from a list of OtherAccountIDs.
func ConversationOtherAccountsKey(otherAccountIDs []string) string { func ConversationOtherAccountsKey(otherAccountIDs []string) string {
otherAccountIDs = xslices.Deduplicate(otherAccountIDs) otherAccountIDs = util.Deduplicate(otherAccountIDs)
slices.Sort(otherAccountIDs) slices.Sort(otherAccountIDs)
return strings.Join(otherAccountIDs, ",") return strings.Join(otherAccountIDs, ",")
} }

View file

@ -22,11 +22,11 @@
"fmt" "fmt"
"log/syslog" "log/syslog"
"os" "os"
"slices"
"strings" "strings"
"time" "time"
"codeberg.org/gruf/go-kv" "codeberg.org/gruf/go-kv"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices"
) )
var ( var (
@ -412,10 +412,7 @@ func logf(ctx context.Context, depth int, lvl LEVEL, fields []kv.Field, s string
buf.B = append(buf.B, lvlstrs[lvl]...) buf.B = append(buf.B, lvlstrs[lvl]...)
buf.B = append(buf.B, ' ') buf.B = append(buf.B, ' ')
if ctx != nil && len(ctxhooks) > 0 { if ctx != nil {
// Ensure fields have space for hooks (+1 for below).
fields = xslices.GrowJust(fields, len(ctxhooks)+1)
// Pass context through hooks. // Pass context through hooks.
for _, hook := range ctxhooks { for _, hook := range ctxhooks {
fields = hook(ctx, fields) fields = hook(ctx, fields)
@ -423,8 +420,9 @@ func logf(ctx context.Context, depth int, lvl LEVEL, fields []kv.Field, s string
} }
if s != "" { if s != "" {
// Append message (if given) as final log field. // Append message to log fields.
fields = xslices.AppendJust(fields, kv.Field{ fields = slices.Grow(fields, 1)
fields = append(fields, kv.Field{
K: "msg", V: fmt.Sprintf(s, a...), K: "msg", V: fmt.Sprintf(s, a...),
}) })
} }

View file

@ -27,7 +27,7 @@
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices" "github.com/superseriousbusiness/gotosocial/internal/util"
) )
func (p *Processor) Alias( func (p *Processor) Alias(
@ -137,8 +137,8 @@ type uri struct {
// Dedupe URIs + accounts, in case someone // Dedupe URIs + accounts, in case someone
// provided both an account URL and an // provided both an account URL and an
// account URI above, for the same account. // account URI above, for the same account.
account.AlsoKnownAsURIs = xslices.Deduplicate(account.AlsoKnownAsURIs) account.AlsoKnownAsURIs = util.Deduplicate(account.AlsoKnownAsURIs)
account.AlsoKnownAs = xslices.DeduplicateFunc( account.AlsoKnownAs = util.DeduplicateFunc(
account.AlsoKnownAs, account.AlsoKnownAs,
func(a *gtsmodel.Account) string { func(a *gtsmodel.Account) string {
return a.URI return a.URI

View file

@ -36,7 +36,6 @@
"github.com/superseriousbusiness/gotosocial/internal/typeutils" "github.com/superseriousbusiness/gotosocial/internal/typeutils"
"github.com/superseriousbusiness/gotosocial/internal/uris" "github.com/superseriousbusiness/gotosocial/internal/uris"
"github.com/superseriousbusiness/gotosocial/internal/util" "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices"
) )
// Create processes the given form to create a new status, returning the api model representation of that status if it's OK. // Create processes the given form to create a new status, returning the api model representation of that status if it's OK.
@ -537,9 +536,9 @@ func (p *Processor) processContent(ctx context.Context, parseMention gtsmodel.Pa
} }
// Gather all the database IDs from each of the gathered status mentions, tags, and emojis. // Gather all the database IDs from each of the gathered status mentions, tags, and emojis.
status.MentionIDs = xslices.Gather(nil, status.Mentions, func(mention *gtsmodel.Mention) string { return mention.ID }) status.MentionIDs = util.Gather(nil, status.Mentions, func(mention *gtsmodel.Mention) string { return mention.ID })
status.TagIDs = xslices.Gather(nil, status.Tags, func(tag *gtsmodel.Tag) string { return tag.ID }) status.TagIDs = util.Gather(nil, status.Tags, func(tag *gtsmodel.Tag) string { return tag.ID })
status.EmojiIDs = xslices.Gather(nil, status.Emojis, func(emoji *gtsmodel.Emoji) string { return emoji.ID }) status.EmojiIDs = util.Gather(nil, status.Emojis, func(emoji *gtsmodel.Emoji) string { return emoji.ID })
if status.ContentWarning != "" && len(status.AttachmentIDs) > 0 { if status.ContentWarning != "" && len(status.AttachmentIDs) > 0 {
// If a content-warning is set, and // If a content-warning is set, and

View file

@ -36,7 +36,7 @@
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
"github.com/superseriousbusiness/gotosocial/internal/log" "github.com/superseriousbusiness/gotosocial/internal/log"
"github.com/superseriousbusiness/gotosocial/internal/uris" "github.com/superseriousbusiness/gotosocial/internal/uris"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices" "github.com/superseriousbusiness/gotosocial/internal/util"
) )
// AccountToAS converts a gts model account into an activity streams person, suitable for federation // AccountToAS converts a gts model account into an activity streams person, suitable for federation
@ -1819,7 +1819,7 @@ func populateValuesForProp[T ap.WithIRI](
// Deduplicate the iri strings to // Deduplicate the iri strings to
// make sure we're not parsing + adding // make sure we're not parsing + adding
// the same string multiple times. // the same string multiple times.
iriStrs = xslices.Deduplicate(iriStrs) iriStrs = util.Deduplicate(iriStrs)
// Append them to the property. // Append them to the property.
for _, iriStr := range iriStrs { for _, iriStr := range iriStrs {

View file

@ -15,53 +15,12 @@
// You should have received a copy of the GNU Affero General Public License // You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>. // along with this program. If not, see <http://www.gnu.org/licenses/>.
package xslices package util
import ( import (
"slices" "slices"
) )
// GrowJust increases slice capacity to guarantee
// extra room 'size', where in the case that it does
// need to allocate more it ONLY allocates 'size' extra.
// This is different to typical slices.Grow behaviour,
// which simply guarantees extra through append() which
// may allocate more than necessary extra size.
func GrowJust[T any](in []T, size int) []T {
if cap(in)-len(in) < size {
// Reallocate enough for in + size.
in2 := make([]T, len(in), len(in)+size)
_ = copy(in2, in)
in = in2
}
return in
}
// AppendJust appends extra elements to slice,
// ONLY allocating at most len(extra) elements. This
// is different to the typical append behaviour which
// will append extra, in a manner to reduce the need
// for new allocations on every call to append.
func AppendJust[T any](in []T, extra ...T) []T {
l := len(in)
if cap(in)-l < len(extra) {
// Reallocate enough for + extra.
in2 := make([]T, l+len(extra))
_ = copy(in2, in)
in = in2
} else {
// Reslice for + extra.
in = in[:l+len(extra)]
}
// Copy extra into slice.
_ = copy(in[l:], extra)
return in
}
// Deduplicate deduplicates entries in the given slice. // Deduplicate deduplicates entries in the given slice.
func Deduplicate[T comparable](in []T) []T { func Deduplicate[T comparable](in []T) []T {
var ( var (

View file

@ -15,90 +15,22 @@
// You should have received a copy of the GNU Affero General Public License // You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>. // along with this program. If not, see <http://www.gnu.org/licenses/>.
package xslices_test package util_test
import ( import (
"math/rand"
"net/url" "net/url"
"slices" "slices"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/superseriousbusiness/gotosocial/internal/util"
"github.com/superseriousbusiness/gotosocial/internal/util/xslices"
) )
func TestGrowJust(t *testing.T) { var (
for _, l := range []int{0, 2, 4, 8, 16, 32, 64} { testURLSlice = []*url.URL{}
for _, x := range []int{0, 2, 4, 8, 16, 32, 64} { )
s := make([]int, l, l+x)
for _, g := range []int{0, 2, 4, 8, 16, 32, 64} {
s2 := xslices.GrowJust(s, g)
// Slice length should not be different.
assert.Equal(t, len(s), len(s2))
switch {
// If slice already has capacity for
// 'g' then it should not be changed.
case cap(s) >= len(s)+g:
assert.Equal(t, cap(s), cap(s2))
// Else, returned slice should only
// have capacity for original length
// plus extra elements, NOTHING MORE.
default:
assert.Equal(t, cap(s2), len(s)+g)
}
}
}
}
}
func TestAppendJust(t *testing.T) {
for _, l := range []int{0, 2, 4, 8, 16, 32, 64} {
for _, x := range []int{0, 2, 4, 8, 16, 32, 64} {
s := make([]int, l, l+x)
// Randomize slice.
for i := range s {
s[i] = rand.Int()
}
for _, a := range []int{0, 2, 4, 8, 16, 32, 64} {
toAppend := make([]int, a)
// Randomize appended vals.
for i := range toAppend {
toAppend[i] = rand.Int()
}
s2 := xslices.AppendJust(s, toAppend...)
// Slice length should be as expected.
assert.Equal(t, len(s)+a, len(s2))
// Slice contents should be as expected.
assert.Equal(t, append(s, toAppend...), s2)
switch {
// If slice already has capacity for
// 'toAppend' then it should not change.
case cap(s) >= len(s)+a:
assert.Equal(t, cap(s), cap(s2))
// Else, returned slice should only
// have capacity for original length
// plus extra elements, NOTHING MORE.
default:
assert.Equal(t, len(s)+a, cap(s2))
}
}
}
}
}
func TestGather(t *testing.T) { func TestGather(t *testing.T) {
out := xslices.Gather(nil, []*url.URL{ out := util.Gather(nil, []*url.URL{
{Scheme: "https", Host: "google.com", Path: "/some-search"}, {Scheme: "https", Host: "google.com", Path: "/some-search"},
{Scheme: "http", Host: "example.com", Path: "/robots.txt"}, {Scheme: "http", Host: "example.com", Path: "/robots.txt"},
}, (*url.URL).String) }, (*url.URL).String)
@ -109,7 +41,7 @@ func TestGather(t *testing.T) {
t.Fatal("unexpected gather output") t.Fatal("unexpected gather output")
} }
out = xslices.Gather([]string{ out = util.Gather([]string{
"starting input string", "starting input string",
"another starting input", "another starting input",
}, []*url.URL{ }, []*url.URL{
@ -127,7 +59,7 @@ func TestGather(t *testing.T) {
} }
func TestGatherIf(t *testing.T) { func TestGatherIf(t *testing.T) {
out := xslices.GatherIf(nil, []string{ out := util.GatherIf(nil, []string{
"hello world", "hello world",
"not hello world", "not hello world",
"hello world", "hello world",
@ -141,7 +73,7 @@ func TestGatherIf(t *testing.T) {
t.Fatal("unexpected gatherif output") t.Fatal("unexpected gatherif output")
} }
out = xslices.GatherIf([]string{ out = util.GatherIf([]string{
"starting input string", "starting input string",
"another starting input", "another starting input",
}, []string{ }, []string{

View file

@ -1,15 +0,0 @@
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib
# Test binary, built with `go test -c`
*.test
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
# Dependency directories (remove the comment below to include it)
# vendor/

View file

@ -1,133 +0,0 @@
# xsync benchmarks
If you're interested in `MapOf` comparison with some of the popular concurrent hash maps written in Go, check [this](https://github.com/cornelk/hashmap/pull/70) and [this](https://github.com/alphadose/haxmap/pull/22) PRs.
The below results were obtained for xsync v2.3.1 on a c6g.metal EC2 instance (64 CPU, 128GB RAM) running Linux and Go 1.19.3. I'd like to thank [@felixge](https://github.com/felixge) who kindly ran the benchmarks.
The following commands were used to run the benchmarks:
```bash
$ go test -run='^$' -cpu=1,2,4,8,16,32,64 -bench . -count=30 -timeout=0 | tee bench.txt
$ benchstat bench.txt | tee benchstat.txt
```
The below sections contain some of the results. Refer to [this gist](https://gist.github.com/puzpuzpuz/e62e38e06feadecfdc823c0f941ece0b) for the complete output.
Please note that `MapOf` got a number of optimizations since v2.3.1, so the current result is likely to be different.
### Counter vs. atomic int64
```
name time/op
Counter 27.3ns ± 1%
Counter-2 27.2ns ±11%
Counter-4 15.3ns ± 8%
Counter-8 7.43ns ± 7%
Counter-16 3.70ns ±10%
Counter-32 1.77ns ± 3%
Counter-64 0.96ns ±10%
AtomicInt64 7.60ns ± 0%
AtomicInt64-2 12.6ns ±13%
AtomicInt64-4 13.5ns ±14%
AtomicInt64-8 12.7ns ± 9%
AtomicInt64-16 12.8ns ± 8%
AtomicInt64-32 13.0ns ± 6%
AtomicInt64-64 12.9ns ± 7%
```
Here `time/op` stands for average time spent on operation. If you divide `10^9` by the result in nanoseconds per operation, you'd get the throughput in operations per second. Thus, the ideal theoretical scalability of a concurrent data structure implies that the reported `time/op` decreases proportionally with the increased number of CPU cores. On the contrary, if the measured time per operation increases when run on more cores, it means performance degradation.
### MapOf vs. sync.Map
1,000 `[int, int]` entries with a warm-up, 100% Loads:
```
IntegerMapOf_WarmUp/reads=100% 24.0ns ± 0%
IntegerMapOf_WarmUp/reads=100%-2 12.0ns ± 0%
IntegerMapOf_WarmUp/reads=100%-4 6.02ns ± 0%
IntegerMapOf_WarmUp/reads=100%-8 3.01ns ± 0%
IntegerMapOf_WarmUp/reads=100%-16 1.50ns ± 0%
IntegerMapOf_WarmUp/reads=100%-32 0.75ns ± 0%
IntegerMapOf_WarmUp/reads=100%-64 0.38ns ± 0%
IntegerMapStandard_WarmUp/reads=100% 55.3ns ± 0%
IntegerMapStandard_WarmUp/reads=100%-2 27.6ns ± 0%
IntegerMapStandard_WarmUp/reads=100%-4 16.1ns ± 3%
IntegerMapStandard_WarmUp/reads=100%-8 8.35ns ± 7%
IntegerMapStandard_WarmUp/reads=100%-16 4.24ns ± 7%
IntegerMapStandard_WarmUp/reads=100%-32 2.18ns ± 6%
IntegerMapStandard_WarmUp/reads=100%-64 1.11ns ± 3%
```
1,000 `[int, int]` entries with a warm-up, 99% Loads, 0.5% Stores, 0.5% Deletes:
```
IntegerMapOf_WarmUp/reads=99% 31.0ns ± 0%
IntegerMapOf_WarmUp/reads=99%-2 16.4ns ± 1%
IntegerMapOf_WarmUp/reads=99%-4 8.42ns ± 0%
IntegerMapOf_WarmUp/reads=99%-8 4.41ns ± 0%
IntegerMapOf_WarmUp/reads=99%-16 2.38ns ± 2%
IntegerMapOf_WarmUp/reads=99%-32 1.37ns ± 4%
IntegerMapOf_WarmUp/reads=99%-64 0.85ns ± 2%
IntegerMapStandard_WarmUp/reads=99% 121ns ± 1%
IntegerMapStandard_WarmUp/reads=99%-2 109ns ± 3%
IntegerMapStandard_WarmUp/reads=99%-4 115ns ± 4%
IntegerMapStandard_WarmUp/reads=99%-8 114ns ± 2%
IntegerMapStandard_WarmUp/reads=99%-16 105ns ± 2%
IntegerMapStandard_WarmUp/reads=99%-32 97.0ns ± 3%
IntegerMapStandard_WarmUp/reads=99%-64 98.0ns ± 2%
```
1,000 `[int, int]` entries with a warm-up, 75% Loads, 12.5% Stores, 12.5% Deletes:
```
IntegerMapOf_WarmUp/reads=75%-reads 46.2ns ± 1%
IntegerMapOf_WarmUp/reads=75%-reads-2 36.7ns ± 2%
IntegerMapOf_WarmUp/reads=75%-reads-4 22.0ns ± 1%
IntegerMapOf_WarmUp/reads=75%-reads-8 12.8ns ± 2%
IntegerMapOf_WarmUp/reads=75%-reads-16 7.69ns ± 1%
IntegerMapOf_WarmUp/reads=75%-reads-32 5.16ns ± 1%
IntegerMapOf_WarmUp/reads=75%-reads-64 4.91ns ± 1%
IntegerMapStandard_WarmUp/reads=75%-reads 156ns ± 0%
IntegerMapStandard_WarmUp/reads=75%-reads-2 177ns ± 1%
IntegerMapStandard_WarmUp/reads=75%-reads-4 197ns ± 1%
IntegerMapStandard_WarmUp/reads=75%-reads-8 221ns ± 2%
IntegerMapStandard_WarmUp/reads=75%-reads-16 242ns ± 1%
IntegerMapStandard_WarmUp/reads=75%-reads-32 258ns ± 1%
IntegerMapStandard_WarmUp/reads=75%-reads-64 264ns ± 1%
```
### MPMCQueue vs. Go channels
Concurrent producers and consumers (1:1), queue/channel size 1,000, some work done by both producers and consumers:
```
QueueProdConsWork100 252ns ± 0%
QueueProdConsWork100-2 206ns ± 5%
QueueProdConsWork100-4 136ns ±12%
QueueProdConsWork100-8 110ns ± 6%
QueueProdConsWork100-16 108ns ± 2%
QueueProdConsWork100-32 102ns ± 2%
QueueProdConsWork100-64 101ns ± 0%
ChanProdConsWork100 283ns ± 0%
ChanProdConsWork100-2 406ns ±21%
ChanProdConsWork100-4 549ns ± 7%
ChanProdConsWork100-8 754ns ± 7%
ChanProdConsWork100-16 828ns ± 7%
ChanProdConsWork100-32 810ns ± 8%
ChanProdConsWork100-64 832ns ± 4%
```
### RBMutex vs. sync.RWMutex
The writer locks on each 100,000 iteration with some work in the critical section for both readers and the writer:
```
RBMutexWorkWrite100000 146ns ± 0%
RBMutexWorkWrite100000-2 73.3ns ± 0%
RBMutexWorkWrite100000-4 36.7ns ± 0%
RBMutexWorkWrite100000-8 18.6ns ± 0%
RBMutexWorkWrite100000-16 9.83ns ± 3%
RBMutexWorkWrite100000-32 5.53ns ± 0%
RBMutexWorkWrite100000-64 4.04ns ± 3%
RWMutexWorkWrite100000 121ns ± 0%
RWMutexWorkWrite100000-2 128ns ± 1%
RWMutexWorkWrite100000-4 124ns ± 2%
RWMutexWorkWrite100000-8 101ns ± 1%
RWMutexWorkWrite100000-16 92.9ns ± 1%
RWMutexWorkWrite100000-32 89.9ns ± 1%
RWMutexWorkWrite100000-64 88.4ns ± 1%
```

View file

@ -1,201 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View file

@ -1,166 +0,0 @@
[![GoDoc reference](https://img.shields.io/badge/godoc-reference-blue.svg)](https://pkg.go.dev/github.com/puzpuzpuz/xsync/v3)
[![GoReport](https://goreportcard.com/badge/github.com/puzpuzpuz/xsync/v3)](https://goreportcard.com/report/github.com/puzpuzpuz/xsync/v3)
[![codecov](https://codecov.io/gh/puzpuzpuz/xsync/branch/main/graph/badge.svg)](https://codecov.io/gh/puzpuzpuz/xsync)
# xsync
Concurrent data structures for Go. Aims to provide more scalable alternatives for some of the data structures from the standard `sync` package, but not only.
Covered with tests following the approach described [here](https://puzpuzpuz.dev/testing-concurrent-code-for-fun-and-profit).
## Benchmarks
Benchmark results may be found [here](BENCHMARKS.md). I'd like to thank [@felixge](https://github.com/felixge) who kindly ran the benchmarks on a beefy multicore machine.
Also, a non-scientific, unfair benchmark comparing Java's [j.u.c.ConcurrentHashMap](https://docs.oracle.com/en/java/javase/17/docs/api/java.base/java/util/concurrent/ConcurrentHashMap.html) and `xsync.MapOf` is available [here](https://puzpuzpuz.dev/concurrent-map-in-go-vs-java-yet-another-meaningless-benchmark).
## Usage
The latest xsync major version is v3, so `/v3` suffix should be used when importing the library:
```go
import (
"github.com/puzpuzpuz/xsync/v3"
)
```
*Note for pre-v3 users*: v1 and v2 support is discontinued, so please upgrade to v3. While the API has some breaking changes, the migration should be trivial.
### Counter
A `Counter` is a striped `int64` counter inspired by the `j.u.c.a.LongAdder` class from the Java standard library.
```go
c := xsync.NewCounter()
// increment and decrement the counter
c.Inc()
c.Dec()
// read the current value
v := c.Value()
```
Works better in comparison with a single atomically updated `int64` counter in high contention scenarios.
### Map
A `Map` is like a concurrent hash table-based map. It follows the interface of `sync.Map` with a number of valuable extensions like `Compute` or `Size`.
```go
m := xsync.NewMap()
m.Store("foo", "bar")
v, ok := m.Load("foo")
s := m.Size()
```
`Map` uses a modified version of Cache-Line Hash Table (CLHT) data structure: https://github.com/LPD-EPFL/CLHT
CLHT is built around the idea of organizing the hash table in cache-line-sized buckets, so that on all modern CPUs update operations complete with minimal cache-line transfer. Also, `Get` operations are obstruction-free and involve no writes to shared memory, hence no mutexes or any other sort of locks. Due to this design, in all considered scenarios `Map` outperforms `sync.Map`.
One important difference with `sync.Map` is that only string keys are supported. That's because Golang standard library does not expose the built-in hash functions for `interface{}` values.
`MapOf[K, V]` is an implementation with parametrized key and value types. While it's still a CLHT-inspired hash map, `MapOf`'s design is quite different from `Map`. As a result, less GC pressure and fewer atomic operations on reads.
```go
m := xsync.NewMapOf[string, string]()
m.Store("foo", "bar")
v, ok := m.Load("foo")
```
Apart from CLHT, `MapOf` borrows ideas from Java's `j.u.c.ConcurrentHashMap` (immutable K/V pair structs instead of atomic snapshots) and C++'s `absl::flat_hash_map` (meta memory and SWAR-based lookups). It also has more dense memory layout when compared with `Map`. Long story short, `MapOf` should be preferred over `Map` when possible.
An important difference with `Map` is that `MapOf` supports arbitrary `comparable` key types:
```go
type Point struct {
x int32
y int32
}
m := NewMapOf[Point, int]()
m.Store(Point{42, 42}, 42)
v, ok := m.Load(point{42, 42})
```
Both maps use the built-in Golang's hash function which has DDOS protection. This means that each map instance gets its own seed number and the hash function uses that seed for hash code calculation. However, for smaller keys this hash function has some overhead. So, if you don't need DDOS protection, you may provide a custom hash function when creating a `MapOf`. For instance, Murmur3 finalizer does a decent job when it comes to integers:
```go
m := NewMapOfWithHasher[int, int](func(i int, _ uint64) uint64 {
h := uint64(i)
h = (h ^ (h >> 33)) * 0xff51afd7ed558ccd
h = (h ^ (h >> 33)) * 0xc4ceb9fe1a85ec53
return h ^ (h >> 33)
})
```
When benchmarking concurrent maps, make sure to configure all of the competitors with the same hash function or, at least, take hash function performance into the consideration.
### MPMCQueue
A `MPMCQueue` is a bounded multi-producer multi-consumer concurrent queue.
```go
q := xsync.NewMPMCQueue(1024)
// producer inserts an item into the queue
q.Enqueue("foo")
// optimistic insertion attempt; doesn't block
inserted := q.TryEnqueue("bar")
// consumer obtains an item from the queue
item := q.Dequeue() // interface{} pointing to a string
// optimistic obtain attempt; doesn't block
item, ok := q.TryDequeue()
```
`MPMCQueueOf[I]` is an implementation with parametrized item type. It is available for Go 1.19 or later.
```go
q := xsync.NewMPMCQueueOf[string](1024)
q.Enqueue("foo")
item := q.Dequeue() // string
```
The queue is based on the algorithm from the [MPMCQueue](https://github.com/rigtorp/MPMCQueue) C++ library which in its turn references D.Vyukov's [MPMC queue](https://www.1024cores.net/home/lock-free-algorithms/queues/bounded-mpmc-queue). According to the following [classification](https://www.1024cores.net/home/lock-free-algorithms/queues), the queue is array-based, fails on overflow, provides causal FIFO, has blocking producers and consumers.
The idea of the algorithm is to allow parallelism for concurrent producers and consumers by introducing the notion of tickets, i.e. values of two counters, one per producers/consumers. An atomic increment of one of those counters is the only noticeable contention point in queue operations. The rest of the operation avoids contention on writes thanks to the turn-based read/write access for each of the queue items.
In essence, `MPMCQueue` is a specialized queue for scenarios where there are multiple concurrent producers and consumers of a single queue running on a large multicore machine.
To get the optimal performance, you may want to set the queue size to be large enough, say, an order of magnitude greater than the number of producers/consumers, to allow producers and consumers to progress with their queue operations in parallel most of the time.
### RBMutex
A `RBMutex` is a reader-biased reader/writer mutual exclusion lock. The lock can be held by many readers or a single writer.
```go
mu := xsync.NewRBMutex()
// reader lock calls return a token
t := mu.RLock()
// the token must be later used to unlock the mutex
mu.RUnlock(t)
// writer locks are the same as in sync.RWMutex
mu.Lock()
mu.Unlock()
```
`RBMutex` is based on a modified version of BRAVO (Biased Locking for Reader-Writer Locks) algorithm: https://arxiv.org/pdf/1810.01553.pdf
The idea of the algorithm is to build on top of an existing reader-writer mutex and introduce a fast path for readers. On the fast path, reader lock attempts are sharded over an internal array based on the reader identity (a token in the case of Golang). This means that readers do not contend over a single atomic counter like it's done in, say, `sync.RWMutex` allowing for better scalability in terms of cores.
Hence, by the design `RBMutex` is a specialized mutex for scenarios, such as caches, where the vast majority of locks are acquired by readers and write lock acquire attempts are infrequent. In such scenarios, `RBMutex` should perform better than the `sync.RWMutex` on large multicore machines.
`RBMutex` extends `sync.RWMutex` internally and uses it as the "reader bias disabled" fallback, so the same semantics apply. The only noticeable difference is in the reader tokens returned from the `RLock`/`RUnlock` methods.
Apart from blocking methods, `RBMutex` also has methods for optimistic locking:
```go
mu := xsync.NewRBMutex()
if locked, t := mu.TryRLock(); locked {
// critical reader section...
mu.RUnlock(t)
}
if mu.TryLock() {
// critical writer section...
mu.Unlock()
}
```
## License
Licensed under MIT.

View file

@ -1,99 +0,0 @@
package xsync
import (
"sync"
"sync/atomic"
)
// pool for P tokens
var ptokenPool sync.Pool
// a P token is used to point at the current OS thread (P)
// on which the goroutine is run; exact identity of the thread,
// as well as P migration tolerance, is not important since
// it's used to as a best effort mechanism for assigning
// concurrent operations (goroutines) to different stripes of
// the counter
type ptoken struct {
idx uint32
//lint:ignore U1000 prevents false sharing
pad [cacheLineSize - 4]byte
}
// A Counter is a striped int64 counter.
//
// Should be preferred over a single atomically updated int64
// counter in high contention scenarios.
//
// A Counter must not be copied after first use.
type Counter struct {
stripes []cstripe
mask uint32
}
type cstripe struct {
c int64
//lint:ignore U1000 prevents false sharing
pad [cacheLineSize - 8]byte
}
// NewCounter creates a new Counter instance.
func NewCounter() *Counter {
nstripes := nextPowOf2(parallelism())
c := Counter{
stripes: make([]cstripe, nstripes),
mask: nstripes - 1,
}
return &c
}
// Inc increments the counter by 1.
func (c *Counter) Inc() {
c.Add(1)
}
// Dec decrements the counter by 1.
func (c *Counter) Dec() {
c.Add(-1)
}
// Add adds the delta to the counter.
func (c *Counter) Add(delta int64) {
t, ok := ptokenPool.Get().(*ptoken)
if !ok {
t = new(ptoken)
t.idx = runtime_fastrand()
}
for {
stripe := &c.stripes[t.idx&c.mask]
cnt := atomic.LoadInt64(&stripe.c)
if atomic.CompareAndSwapInt64(&stripe.c, cnt, cnt+delta) {
break
}
// Give a try with another randomly selected stripe.
t.idx = runtime_fastrand()
}
ptokenPool.Put(t)
}
// Value returns the current counter value.
// The returned value may not include all of the latest operations in
// presence of concurrent modifications of the counter.
func (c *Counter) Value() int64 {
v := int64(0)
for i := 0; i < len(c.stripes); i++ {
stripe := &c.stripes[i]
v += atomic.LoadInt64(&stripe.c)
}
return v
}
// Reset resets the counter to zero.
// This method should only be used when it is known that there are
// no concurrent modifications of the counter.
func (c *Counter) Reset() {
for i := 0; i < len(c.stripes); i++ {
stripe := &c.stripes[i]
atomic.StoreInt64(&stripe.c, 0)
}
}

View file

@ -1,873 +0,0 @@
package xsync
import (
"fmt"
"math"
"runtime"
"strings"
"sync"
"sync/atomic"
"unsafe"
)
type mapResizeHint int
const (
mapGrowHint mapResizeHint = 0
mapShrinkHint mapResizeHint = 1
mapClearHint mapResizeHint = 2
)
const (
// number of Map entries per bucket; 3 entries lead to size of 64B
// (one cache line) on 64-bit machines
entriesPerMapBucket = 3
// threshold fraction of table occupation to start a table shrinking
// when deleting the last entry in a bucket chain
mapShrinkFraction = 128
// map load factor to trigger a table resize during insertion;
// a map holds up to mapLoadFactor*entriesPerMapBucket*mapTableLen
// key-value pairs (this is a soft limit)
mapLoadFactor = 0.75
// minimal table size, i.e. number of buckets; thus, minimal map
// capacity can be calculated as entriesPerMapBucket*defaultMinMapTableLen
defaultMinMapTableLen = 32
// minimum counter stripes to use
minMapCounterLen = 8
// maximum counter stripes to use; stands for around 4KB of memory
maxMapCounterLen = 32
)
var (
topHashMask = uint64((1<<20)-1) << 44
topHashEntryMasks = [3]uint64{
topHashMask,
topHashMask >> 20,
topHashMask >> 40,
}
)
// Map is like a Go map[string]interface{} but is safe for concurrent
// use by multiple goroutines without additional locking or
// coordination. It follows the interface of sync.Map with
// a number of valuable extensions like Compute or Size.
//
// A Map must not be copied after first use.
//
// Map uses a modified version of Cache-Line Hash Table (CLHT)
// data structure: https://github.com/LPD-EPFL/CLHT
//
// CLHT is built around idea to organize the hash table in
// cache-line-sized buckets, so that on all modern CPUs update
// operations complete with at most one cache-line transfer.
// Also, Get operations involve no write to memory, as well as no
// mutexes or any other sort of locks. Due to this design, in all
// considered scenarios Map outperforms sync.Map.
//
// One important difference with sync.Map is that only string keys
// are supported. That's because Golang standard library does not
// expose the built-in hash functions for interface{} values.
type Map struct {
totalGrowths int64
totalShrinks int64
resizing int64 // resize in progress flag; updated atomically
resizeMu sync.Mutex // only used along with resizeCond
resizeCond sync.Cond // used to wake up resize waiters (concurrent modifications)
table unsafe.Pointer // *mapTable
minTableLen int
growOnly bool
}
type mapTable struct {
buckets []bucketPadded
// striped counter for number of table entries;
// used to determine if a table shrinking is needed
// occupies min(buckets_memory/1024, 64KB) of memory
size []counterStripe
seed uint64
}
type counterStripe struct {
c int64
//lint:ignore U1000 prevents false sharing
pad [cacheLineSize - 8]byte
}
type bucketPadded struct {
//lint:ignore U1000 ensure each bucket takes two cache lines on both 32 and 64-bit archs
pad [cacheLineSize - unsafe.Sizeof(bucket{})]byte
bucket
}
type bucket struct {
next unsafe.Pointer // *bucketPadded
keys [entriesPerMapBucket]unsafe.Pointer
values [entriesPerMapBucket]unsafe.Pointer
// topHashMutex is a 2-in-1 value.
//
// It contains packed top 20 bits (20 MSBs) of hash codes for keys
// stored in the bucket:
// | key 0's top hash | key 1's top hash | key 2's top hash | bitmap for keys | mutex |
// | 20 bits | 20 bits | 20 bits | 3 bits | 1 bit |
//
// The least significant bit is used for the mutex (TTAS spinlock).
topHashMutex uint64
}
type rangeEntry struct {
key unsafe.Pointer
value unsafe.Pointer
}
// MapConfig defines configurable Map/MapOf options.
type MapConfig struct {
sizeHint int
growOnly bool
}
// WithPresize configures new Map/MapOf instance with capacity enough
// to hold sizeHint entries. The capacity is treated as the minimal
// capacity meaning that the underlying hash table will never shrink
// to a smaller capacity. If sizeHint is zero or negative, the value
// is ignored.
func WithPresize(sizeHint int) func(*MapConfig) {
return func(c *MapConfig) {
c.sizeHint = sizeHint
}
}
// WithGrowOnly configures new Map/MapOf instance to be grow-only.
// This means that the underlying hash table grows in capacity when
// new keys are added, but does not shrink when keys are deleted.
// The only exception to this rule is the Clear method which
// shrinks the hash table back to the initial capacity.
func WithGrowOnly() func(*MapConfig) {
return func(c *MapConfig) {
c.growOnly = true
}
}
// NewMap creates a new Map instance configured with the given
// options.
func NewMap(options ...func(*MapConfig)) *Map {
c := &MapConfig{
sizeHint: defaultMinMapTableLen * entriesPerMapBucket,
}
for _, o := range options {
o(c)
}
m := &Map{}
m.resizeCond = *sync.NewCond(&m.resizeMu)
var table *mapTable
if c.sizeHint <= defaultMinMapTableLen*entriesPerMapBucket {
table = newMapTable(defaultMinMapTableLen)
} else {
tableLen := nextPowOf2(uint32((float64(c.sizeHint) / entriesPerMapBucket) / mapLoadFactor))
table = newMapTable(int(tableLen))
}
m.minTableLen = len(table.buckets)
m.growOnly = c.growOnly
atomic.StorePointer(&m.table, unsafe.Pointer(table))
return m
}
// NewMapPresized creates a new Map instance with capacity enough to hold
// sizeHint entries. The capacity is treated as the minimal capacity
// meaning that the underlying hash table will never shrink to
// a smaller capacity. If sizeHint is zero or negative, the value
// is ignored.
//
// Deprecated: use NewMap in combination with WithPresize.
func NewMapPresized(sizeHint int) *Map {
return NewMap(WithPresize(sizeHint))
}
func newMapTable(minTableLen int) *mapTable {
buckets := make([]bucketPadded, minTableLen)
counterLen := minTableLen >> 10
if counterLen < minMapCounterLen {
counterLen = minMapCounterLen
} else if counterLen > maxMapCounterLen {
counterLen = maxMapCounterLen
}
counter := make([]counterStripe, counterLen)
t := &mapTable{
buckets: buckets,
size: counter,
seed: makeSeed(),
}
return t
}
// Load returns the value stored in the map for a key, or nil if no
// value is present.
// The ok result indicates whether value was found in the map.
func (m *Map) Load(key string) (value interface{}, ok bool) {
table := (*mapTable)(atomic.LoadPointer(&m.table))
hash := hashString(key, table.seed)
bidx := uint64(len(table.buckets)-1) & hash
b := &table.buckets[bidx]
for {
topHashes := atomic.LoadUint64(&b.topHashMutex)
for i := 0; i < entriesPerMapBucket; i++ {
if !topHashMatch(hash, topHashes, i) {
continue
}
atomic_snapshot:
// Start atomic snapshot.
vp := atomic.LoadPointer(&b.values[i])
kp := atomic.LoadPointer(&b.keys[i])
if kp != nil && vp != nil {
if key == derefKey(kp) {
if uintptr(vp) == uintptr(atomic.LoadPointer(&b.values[i])) {
// Atomic snapshot succeeded.
return derefValue(vp), true
}
// Concurrent update/remove. Go for another spin.
goto atomic_snapshot
}
}
}
bptr := atomic.LoadPointer(&b.next)
if bptr == nil {
return
}
b = (*bucketPadded)(bptr)
}
}
// Store sets the value for a key.
func (m *Map) Store(key string, value interface{}) {
m.doCompute(
key,
func(interface{}, bool) (interface{}, bool) {
return value, false
},
false,
false,
)
}
// LoadOrStore returns the existing value for the key if present.
// Otherwise, it stores and returns the given value.
// The loaded result is true if the value was loaded, false if stored.
func (m *Map) LoadOrStore(key string, value interface{}) (actual interface{}, loaded bool) {
return m.doCompute(
key,
func(interface{}, bool) (interface{}, bool) {
return value, false
},
true,
false,
)
}
// LoadAndStore returns the existing value for the key if present,
// while setting the new value for the key.
// It stores the new value and returns the existing one, if present.
// The loaded result is true if the existing value was loaded,
// false otherwise.
func (m *Map) LoadAndStore(key string, value interface{}) (actual interface{}, loaded bool) {
return m.doCompute(
key,
func(interface{}, bool) (interface{}, bool) {
return value, false
},
false,
false,
)
}
// LoadOrCompute returns the existing value for the key if present.
// Otherwise, it computes the value using the provided function and
// returns the computed value. The loaded result is true if the value
// was loaded, false if stored.
//
// This call locks a hash table bucket while the compute function
// is executed. It means that modifications on other entries in
// the bucket will be blocked until the valueFn executes. Consider
// this when the function includes long-running operations.
func (m *Map) LoadOrCompute(key string, valueFn func() interface{}) (actual interface{}, loaded bool) {
return m.doCompute(
key,
func(interface{}, bool) (interface{}, bool) {
return valueFn(), false
},
true,
false,
)
}
// Compute either sets the computed new value for the key or deletes
// the value for the key. When the delete result of the valueFn function
// is set to true, the value will be deleted, if it exists. When delete
// is set to false, the value is updated to the newValue.
// The ok result indicates whether value was computed and stored, thus, is
// present in the map. The actual result contains the new value in cases where
// the value was computed and stored. See the example for a few use cases.
//
// This call locks a hash table bucket while the compute function
// is executed. It means that modifications on other entries in
// the bucket will be blocked until the valueFn executes. Consider
// this when the function includes long-running operations.
func (m *Map) Compute(
key string,
valueFn func(oldValue interface{}, loaded bool) (newValue interface{}, delete bool),
) (actual interface{}, ok bool) {
return m.doCompute(key, valueFn, false, true)
}
// LoadAndDelete deletes the value for a key, returning the previous
// value if any. The loaded result reports whether the key was
// present.
func (m *Map) LoadAndDelete(key string) (value interface{}, loaded bool) {
return m.doCompute(
key,
func(value interface{}, loaded bool) (interface{}, bool) {
return value, true
},
false,
false,
)
}
// Delete deletes the value for a key.
func (m *Map) Delete(key string) {
m.doCompute(
key,
func(value interface{}, loaded bool) (interface{}, bool) {
return value, true
},
false,
false,
)
}
func (m *Map) doCompute(
key string,
valueFn func(oldValue interface{}, loaded bool) (interface{}, bool),
loadIfExists, computeOnly bool,
) (interface{}, bool) {
// Read-only path.
if loadIfExists {
if v, ok := m.Load(key); ok {
return v, !computeOnly
}
}
// Write path.
for {
compute_attempt:
var (
emptyb *bucketPadded
emptyidx int
hintNonEmpty int
)
table := (*mapTable)(atomic.LoadPointer(&m.table))
tableLen := len(table.buckets)
hash := hashString(key, table.seed)
bidx := uint64(len(table.buckets)-1) & hash
rootb := &table.buckets[bidx]
lockBucket(&rootb.topHashMutex)
// The following two checks must go in reverse to what's
// in the resize method.
if m.resizeInProgress() {
// Resize is in progress. Wait, then go for another attempt.
unlockBucket(&rootb.topHashMutex)
m.waitForResize()
goto compute_attempt
}
if m.newerTableExists(table) {
// Someone resized the table. Go for another attempt.
unlockBucket(&rootb.topHashMutex)
goto compute_attempt
}
b := rootb
for {
topHashes := atomic.LoadUint64(&b.topHashMutex)
for i := 0; i < entriesPerMapBucket; i++ {
if b.keys[i] == nil {
if emptyb == nil {
emptyb = b
emptyidx = i
}
continue
}
if !topHashMatch(hash, topHashes, i) {
hintNonEmpty++
continue
}
if key == derefKey(b.keys[i]) {
vp := b.values[i]
if loadIfExists {
unlockBucket(&rootb.topHashMutex)
return derefValue(vp), !computeOnly
}
// In-place update/delete.
// We get a copy of the value via an interface{} on each call,
// thus the live value pointers are unique. Otherwise atomic
// snapshot won't be correct in case of multiple Store calls
// using the same value.
oldValue := derefValue(vp)
newValue, del := valueFn(oldValue, true)
if del {
// Deletion.
// First we update the value, then the key.
// This is important for atomic snapshot states.
atomic.StoreUint64(&b.topHashMutex, eraseTopHash(topHashes, i))
atomic.StorePointer(&b.values[i], nil)
atomic.StorePointer(&b.keys[i], nil)
leftEmpty := false
if hintNonEmpty == 0 {
leftEmpty = isEmptyBucket(b)
}
unlockBucket(&rootb.topHashMutex)
table.addSize(bidx, -1)
// Might need to shrink the table.
if leftEmpty {
m.resize(table, mapShrinkHint)
}
return oldValue, !computeOnly
}
nvp := unsafe.Pointer(&newValue)
if assertionsEnabled && vp == nvp {
panic("non-unique value pointer")
}
atomic.StorePointer(&b.values[i], nvp)
unlockBucket(&rootb.topHashMutex)
if computeOnly {
// Compute expects the new value to be returned.
return newValue, true
}
// LoadAndStore expects the old value to be returned.
return oldValue, true
}
hintNonEmpty++
}
if b.next == nil {
if emptyb != nil {
// Insertion into an existing bucket.
var zeroedV interface{}
newValue, del := valueFn(zeroedV, false)
if del {
unlockBucket(&rootb.topHashMutex)
return zeroedV, false
}
// First we update the value, then the key.
// This is important for atomic snapshot states.
topHashes = atomic.LoadUint64(&emptyb.topHashMutex)
atomic.StoreUint64(&emptyb.topHashMutex, storeTopHash(hash, topHashes, emptyidx))
atomic.StorePointer(&emptyb.values[emptyidx], unsafe.Pointer(&newValue))
atomic.StorePointer(&emptyb.keys[emptyidx], unsafe.Pointer(&key))
unlockBucket(&rootb.topHashMutex)
table.addSize(bidx, 1)
return newValue, computeOnly
}
growThreshold := float64(tableLen) * entriesPerMapBucket * mapLoadFactor
if table.sumSize() > int64(growThreshold) {
// Need to grow the table. Then go for another attempt.
unlockBucket(&rootb.topHashMutex)
m.resize(table, mapGrowHint)
goto compute_attempt
}
// Insertion into a new bucket.
var zeroedV interface{}
newValue, del := valueFn(zeroedV, false)
if del {
unlockBucket(&rootb.topHashMutex)
return newValue, false
}
// Create and append a bucket.
newb := new(bucketPadded)
newb.keys[0] = unsafe.Pointer(&key)
newb.values[0] = unsafe.Pointer(&newValue)
newb.topHashMutex = storeTopHash(hash, newb.topHashMutex, 0)
atomic.StorePointer(&b.next, unsafe.Pointer(newb))
unlockBucket(&rootb.topHashMutex)
table.addSize(bidx, 1)
return newValue, computeOnly
}
b = (*bucketPadded)(b.next)
}
}
}
func (m *Map) newerTableExists(table *mapTable) bool {
curTablePtr := atomic.LoadPointer(&m.table)
return uintptr(curTablePtr) != uintptr(unsafe.Pointer(table))
}
func (m *Map) resizeInProgress() bool {
return atomic.LoadInt64(&m.resizing) == 1
}
func (m *Map) waitForResize() {
m.resizeMu.Lock()
for m.resizeInProgress() {
m.resizeCond.Wait()
}
m.resizeMu.Unlock()
}
func (m *Map) resize(knownTable *mapTable, hint mapResizeHint) {
knownTableLen := len(knownTable.buckets)
// Fast path for shrink attempts.
if hint == mapShrinkHint {
if m.growOnly ||
m.minTableLen == knownTableLen ||
knownTable.sumSize() > int64((knownTableLen*entriesPerMapBucket)/mapShrinkFraction) {
return
}
}
// Slow path.
if !atomic.CompareAndSwapInt64(&m.resizing, 0, 1) {
// Someone else started resize. Wait for it to finish.
m.waitForResize()
return
}
var newTable *mapTable
table := (*mapTable)(atomic.LoadPointer(&m.table))
tableLen := len(table.buckets)
switch hint {
case mapGrowHint:
// Grow the table with factor of 2.
atomic.AddInt64(&m.totalGrowths, 1)
newTable = newMapTable(tableLen << 1)
case mapShrinkHint:
shrinkThreshold := int64((tableLen * entriesPerMapBucket) / mapShrinkFraction)
if tableLen > m.minTableLen && table.sumSize() <= shrinkThreshold {
// Shrink the table with factor of 2.
atomic.AddInt64(&m.totalShrinks, 1)
newTable = newMapTable(tableLen >> 1)
} else {
// No need to shrink. Wake up all waiters and give up.
m.resizeMu.Lock()
atomic.StoreInt64(&m.resizing, 0)
m.resizeCond.Broadcast()
m.resizeMu.Unlock()
return
}
case mapClearHint:
newTable = newMapTable(m.minTableLen)
default:
panic(fmt.Sprintf("unexpected resize hint: %d", hint))
}
// Copy the data only if we're not clearing the map.
if hint != mapClearHint {
for i := 0; i < tableLen; i++ {
copied := copyBucket(&table.buckets[i], newTable)
newTable.addSizePlain(uint64(i), copied)
}
}
// Publish the new table and wake up all waiters.
atomic.StorePointer(&m.table, unsafe.Pointer(newTable))
m.resizeMu.Lock()
atomic.StoreInt64(&m.resizing, 0)
m.resizeCond.Broadcast()
m.resizeMu.Unlock()
}
func copyBucket(b *bucketPadded, destTable *mapTable) (copied int) {
rootb := b
lockBucket(&rootb.topHashMutex)
for {
for i := 0; i < entriesPerMapBucket; i++ {
if b.keys[i] != nil {
k := derefKey(b.keys[i])
hash := hashString(k, destTable.seed)
bidx := uint64(len(destTable.buckets)-1) & hash
destb := &destTable.buckets[bidx]
appendToBucket(hash, b.keys[i], b.values[i], destb)
copied++
}
}
if b.next == nil {
unlockBucket(&rootb.topHashMutex)
return
}
b = (*bucketPadded)(b.next)
}
}
func appendToBucket(hash uint64, keyPtr, valPtr unsafe.Pointer, b *bucketPadded) {
for {
for i := 0; i < entriesPerMapBucket; i++ {
if b.keys[i] == nil {
b.keys[i] = keyPtr
b.values[i] = valPtr
b.topHashMutex = storeTopHash(hash, b.topHashMutex, i)
return
}
}
if b.next == nil {
newb := new(bucketPadded)
newb.keys[0] = keyPtr
newb.values[0] = valPtr
newb.topHashMutex = storeTopHash(hash, newb.topHashMutex, 0)
b.next = unsafe.Pointer(newb)
return
}
b = (*bucketPadded)(b.next)
}
}
func isEmptyBucket(rootb *bucketPadded) bool {
b := rootb
for {
for i := 0; i < entriesPerMapBucket; i++ {
if b.keys[i] != nil {
return false
}
}
if b.next == nil {
return true
}
b = (*bucketPadded)(b.next)
}
}
// Range calls f sequentially for each key and value present in the
// map. If f returns false, range stops the iteration.
//
// Range does not necessarily correspond to any consistent snapshot
// of the Map's contents: no key will be visited more than once, but
// if the value for any key is stored or deleted concurrently, Range
// may reflect any mapping for that key from any point during the
// Range call.
//
// It is safe to modify the map while iterating it, including entry
// creation, modification and deletion. However, the concurrent
// modification rule apply, i.e. the changes may be not reflected
// in the subsequently iterated entries.
func (m *Map) Range(f func(key string, value interface{}) bool) {
var zeroEntry rangeEntry
// Pre-allocate array big enough to fit entries for most hash tables.
bentries := make([]rangeEntry, 0, 16*entriesPerMapBucket)
tablep := atomic.LoadPointer(&m.table)
table := *(*mapTable)(tablep)
for i := range table.buckets {
rootb := &table.buckets[i]
b := rootb
// Prevent concurrent modifications and copy all entries into
// the intermediate slice.
lockBucket(&rootb.topHashMutex)
for {
for i := 0; i < entriesPerMapBucket; i++ {
if b.keys[i] != nil {
bentries = append(bentries, rangeEntry{
key: b.keys[i],
value: b.values[i],
})
}
}
if b.next == nil {
unlockBucket(&rootb.topHashMutex)
break
}
b = (*bucketPadded)(b.next)
}
// Call the function for all copied entries.
for j := range bentries {
k := derefKey(bentries[j].key)
v := derefValue(bentries[j].value)
if !f(k, v) {
return
}
// Remove the reference to avoid preventing the copied
// entries from being GCed until this method finishes.
bentries[j] = zeroEntry
}
bentries = bentries[:0]
}
}
// Clear deletes all keys and values currently stored in the map.
func (m *Map) Clear() {
table := (*mapTable)(atomic.LoadPointer(&m.table))
m.resize(table, mapClearHint)
}
// Size returns current size of the map.
func (m *Map) Size() int {
table := (*mapTable)(atomic.LoadPointer(&m.table))
return int(table.sumSize())
}
func derefKey(keyPtr unsafe.Pointer) string {
return *(*string)(keyPtr)
}
func derefValue(valuePtr unsafe.Pointer) interface{} {
return *(*interface{})(valuePtr)
}
func lockBucket(mu *uint64) {
for {
var v uint64
for {
v = atomic.LoadUint64(mu)
if v&1 != 1 {
break
}
runtime.Gosched()
}
if atomic.CompareAndSwapUint64(mu, v, v|1) {
return
}
runtime.Gosched()
}
}
func unlockBucket(mu *uint64) {
v := atomic.LoadUint64(mu)
atomic.StoreUint64(mu, v&^1)
}
func topHashMatch(hash, topHashes uint64, idx int) bool {
if topHashes&(1<<(idx+1)) == 0 {
// Entry is not present.
return false
}
hash = hash & topHashMask
topHashes = (topHashes & topHashEntryMasks[idx]) << (20 * idx)
return hash == topHashes
}
func storeTopHash(hash, topHashes uint64, idx int) uint64 {
// Zero out top hash at idx.
topHashes = topHashes &^ topHashEntryMasks[idx]
// Chop top 20 MSBs of the given hash and position them at idx.
hash = (hash & topHashMask) >> (20 * idx)
// Store the MSBs.
topHashes = topHashes | hash
// Mark the entry as present.
return topHashes | (1 << (idx + 1))
}
func eraseTopHash(topHashes uint64, idx int) uint64 {
return topHashes &^ (1 << (idx + 1))
}
func (table *mapTable) addSize(bucketIdx uint64, delta int) {
cidx := uint64(len(table.size)-1) & bucketIdx
atomic.AddInt64(&table.size[cidx].c, int64(delta))
}
func (table *mapTable) addSizePlain(bucketIdx uint64, delta int) {
cidx := uint64(len(table.size)-1) & bucketIdx
table.size[cidx].c += int64(delta)
}
func (table *mapTable) sumSize() int64 {
sum := int64(0)
for i := range table.size {
sum += atomic.LoadInt64(&table.size[i].c)
}
return sum
}
// MapStats is Map/MapOf statistics.
//
// Warning: map statistics are intented to be used for diagnostic
// purposes, not for production code. This means that breaking changes
// may be introduced into this struct even between minor releases.
type MapStats struct {
// RootBuckets is the number of root buckets in the hash table.
// Each bucket holds a few entries.
RootBuckets int
// TotalBuckets is the total number of buckets in the hash table,
// including root and their chained buckets. Each bucket holds
// a few entries.
TotalBuckets int
// EmptyBuckets is the number of buckets that hold no entries.
EmptyBuckets int
// Capacity is the Map/MapOf capacity, i.e. the total number of
// entries that all buckets can physically hold. This number
// does not consider the load factor.
Capacity int
// Size is the exact number of entries stored in the map.
Size int
// Counter is the number of entries stored in the map according
// to the internal atomic counter. In case of concurrent map
// modifications this number may be different from Size.
Counter int
// CounterLen is the number of internal atomic counter stripes.
// This number may grow with the map capacity to improve
// multithreaded scalability.
CounterLen int
// MinEntries is the minimum number of entries per a chain of
// buckets, i.e. a root bucket and its chained buckets.
MinEntries int
// MinEntries is the maximum number of entries per a chain of
// buckets, i.e. a root bucket and its chained buckets.
MaxEntries int
// TotalGrowths is the number of times the hash table grew.
TotalGrowths int64
// TotalGrowths is the number of times the hash table shrinked.
TotalShrinks int64
}
// ToString returns string representation of map stats.
func (s *MapStats) ToString() string {
var sb strings.Builder
sb.WriteString("MapStats{\n")
sb.WriteString(fmt.Sprintf("RootBuckets: %d\n", s.RootBuckets))
sb.WriteString(fmt.Sprintf("TotalBuckets: %d\n", s.TotalBuckets))
sb.WriteString(fmt.Sprintf("EmptyBuckets: %d\n", s.EmptyBuckets))
sb.WriteString(fmt.Sprintf("Capacity: %d\n", s.Capacity))
sb.WriteString(fmt.Sprintf("Size: %d\n", s.Size))
sb.WriteString(fmt.Sprintf("Counter: %d\n", s.Counter))
sb.WriteString(fmt.Sprintf("CounterLen: %d\n", s.CounterLen))
sb.WriteString(fmt.Sprintf("MinEntries: %d\n", s.MinEntries))
sb.WriteString(fmt.Sprintf("MaxEntries: %d\n", s.MaxEntries))
sb.WriteString(fmt.Sprintf("TotalGrowths: %d\n", s.TotalGrowths))
sb.WriteString(fmt.Sprintf("TotalShrinks: %d\n", s.TotalShrinks))
sb.WriteString("}\n")
return sb.String()
}
// Stats returns statistics for the Map. Just like other map
// methods, this one is thread-safe. Yet it's an O(N) operation,
// so it should be used only for diagnostics or debugging purposes.
func (m *Map) Stats() MapStats {
stats := MapStats{
TotalGrowths: atomic.LoadInt64(&m.totalGrowths),
TotalShrinks: atomic.LoadInt64(&m.totalShrinks),
MinEntries: math.MaxInt32,
}
table := (*mapTable)(atomic.LoadPointer(&m.table))
stats.RootBuckets = len(table.buckets)
stats.Counter = int(table.sumSize())
stats.CounterLen = len(table.size)
for i := range table.buckets {
nentries := 0
b := &table.buckets[i]
stats.TotalBuckets++
for {
nentriesLocal := 0
stats.Capacity += entriesPerMapBucket
for i := 0; i < entriesPerMapBucket; i++ {
if atomic.LoadPointer(&b.keys[i]) != nil {
stats.Size++
nentriesLocal++
}
}
nentries += nentriesLocal
if nentriesLocal == 0 {
stats.EmptyBuckets++
}
if b.next == nil {
break
}
b = (*bucketPadded)(atomic.LoadPointer(&b.next))
stats.TotalBuckets++
}
if nentries < stats.MinEntries {
stats.MinEntries = nentries
}
if nentries > stats.MaxEntries {
stats.MaxEntries = nentries
}
}
return stats
}

View file

@ -1,694 +0,0 @@
package xsync
import (
"fmt"
"math"
"sync"
"sync/atomic"
"unsafe"
)
const (
// number of MapOf entries per bucket; 5 entries lead to size of 64B
// (one cache line) on 64-bit machines
entriesPerMapOfBucket = 5
defaultMeta uint64 = 0x8080808080808080
metaMask uint64 = 0xffffffffff
defaultMetaMasked uint64 = defaultMeta & metaMask
emptyMetaSlot uint8 = 0x80
)
// MapOf is like a Go map[K]V but is safe for concurrent
// use by multiple goroutines without additional locking or
// coordination. It follows the interface of sync.Map with
// a number of valuable extensions like Compute or Size.
//
// A MapOf must not be copied after first use.
//
// MapOf uses a modified version of Cache-Line Hash Table (CLHT)
// data structure: https://github.com/LPD-EPFL/CLHT
//
// CLHT is built around idea to organize the hash table in
// cache-line-sized buckets, so that on all modern CPUs update
// operations complete with at most one cache-line transfer.
// Also, Get operations involve no write to memory, as well as no
// mutexes or any other sort of locks. Due to this design, in all
// considered scenarios MapOf outperforms sync.Map.
//
// MapOf also borrows ideas from Java's j.u.c.ConcurrentHashMap
// (immutable K/V pair structs instead of atomic snapshots)
// and C++'s absl::flat_hash_map (meta memory and SWAR-based
// lookups).
type MapOf[K comparable, V any] struct {
totalGrowths int64
totalShrinks int64
resizing int64 // resize in progress flag; updated atomically
resizeMu sync.Mutex // only used along with resizeCond
resizeCond sync.Cond // used to wake up resize waiters (concurrent modifications)
table unsafe.Pointer // *mapOfTable
hasher func(K, uint64) uint64
minTableLen int
growOnly bool
}
type mapOfTable[K comparable, V any] struct {
buckets []bucketOfPadded
// striped counter for number of table entries;
// used to determine if a table shrinking is needed
// occupies min(buckets_memory/1024, 64KB) of memory
size []counterStripe
seed uint64
}
// bucketOfPadded is a CL-sized map bucket holding up to
// entriesPerMapOfBucket entries.
type bucketOfPadded struct {
//lint:ignore U1000 ensure each bucket takes two cache lines on both 32 and 64-bit archs
pad [cacheLineSize - unsafe.Sizeof(bucketOf{})]byte
bucketOf
}
type bucketOf struct {
meta uint64
entries [entriesPerMapOfBucket]unsafe.Pointer // *entryOf
next unsafe.Pointer // *bucketOfPadded
mu sync.Mutex
}
// entryOf is an immutable map entry.
type entryOf[K comparable, V any] struct {
key K
value V
}
// NewMapOf creates a new MapOf instance configured with the given
// options.
func NewMapOf[K comparable, V any](options ...func(*MapConfig)) *MapOf[K, V] {
return NewMapOfWithHasher[K, V](defaultHasher[K](), options...)
}
// NewMapOfWithHasher creates a new MapOf instance configured with
// the given hasher and options. The hash function is used instead
// of the built-in hash function configured when a map is created
// with the NewMapOf function.
func NewMapOfWithHasher[K comparable, V any](
hasher func(K, uint64) uint64,
options ...func(*MapConfig),
) *MapOf[K, V] {
c := &MapConfig{
sizeHint: defaultMinMapTableLen * entriesPerMapOfBucket,
}
for _, o := range options {
o(c)
}
m := &MapOf[K, V]{}
m.resizeCond = *sync.NewCond(&m.resizeMu)
m.hasher = hasher
var table *mapOfTable[K, V]
if c.sizeHint <= defaultMinMapTableLen*entriesPerMapOfBucket {
table = newMapOfTable[K, V](defaultMinMapTableLen)
} else {
tableLen := nextPowOf2(uint32((float64(c.sizeHint) / entriesPerMapOfBucket) / mapLoadFactor))
table = newMapOfTable[K, V](int(tableLen))
}
m.minTableLen = len(table.buckets)
m.growOnly = c.growOnly
atomic.StorePointer(&m.table, unsafe.Pointer(table))
return m
}
// NewMapOfPresized creates a new MapOf instance with capacity enough
// to hold sizeHint entries. The capacity is treated as the minimal capacity
// meaning that the underlying hash table will never shrink to
// a smaller capacity. If sizeHint is zero or negative, the value
// is ignored.
//
// Deprecated: use NewMapOf in combination with WithPresize.
func NewMapOfPresized[K comparable, V any](sizeHint int) *MapOf[K, V] {
return NewMapOf[K, V](WithPresize(sizeHint))
}
func newMapOfTable[K comparable, V any](minTableLen int) *mapOfTable[K, V] {
buckets := make([]bucketOfPadded, minTableLen)
for i := range buckets {
buckets[i].meta = defaultMeta
}
counterLen := minTableLen >> 10
if counterLen < minMapCounterLen {
counterLen = minMapCounterLen
} else if counterLen > maxMapCounterLen {
counterLen = maxMapCounterLen
}
counter := make([]counterStripe, counterLen)
t := &mapOfTable[K, V]{
buckets: buckets,
size: counter,
seed: makeSeed(),
}
return t
}
// Load returns the value stored in the map for a key, or zero value
// of type V if no value is present.
// The ok result indicates whether value was found in the map.
func (m *MapOf[K, V]) Load(key K) (value V, ok bool) {
table := (*mapOfTable[K, V])(atomic.LoadPointer(&m.table))
hash := m.hasher(key, table.seed)
h1 := h1(hash)
h2w := broadcast(h2(hash))
bidx := uint64(len(table.buckets)-1) & h1
b := &table.buckets[bidx]
for {
metaw := atomic.LoadUint64(&b.meta)
markedw := markZeroBytes(metaw^h2w) & metaMask
for markedw != 0 {
idx := firstMarkedByteIndex(markedw)
eptr := atomic.LoadPointer(&b.entries[idx])
if eptr != nil {
e := (*entryOf[K, V])(eptr)
if e.key == key {
return e.value, true
}
}
markedw &= markedw - 1
}
bptr := atomic.LoadPointer(&b.next)
if bptr == nil {
return
}
b = (*bucketOfPadded)(bptr)
}
}
// Store sets the value for a key.
func (m *MapOf[K, V]) Store(key K, value V) {
m.doCompute(
key,
func(V, bool) (V, bool) {
return value, false
},
false,
false,
)
}
// LoadOrStore returns the existing value for the key if present.
// Otherwise, it stores and returns the given value.
// The loaded result is true if the value was loaded, false if stored.
func (m *MapOf[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
return m.doCompute(
key,
func(V, bool) (V, bool) {
return value, false
},
true,
false,
)
}
// LoadAndStore returns the existing value for the key if present,
// while setting the new value for the key.
// It stores the new value and returns the existing one, if present.
// The loaded result is true if the existing value was loaded,
// false otherwise.
func (m *MapOf[K, V]) LoadAndStore(key K, value V) (actual V, loaded bool) {
return m.doCompute(
key,
func(V, bool) (V, bool) {
return value, false
},
false,
false,
)
}
// LoadOrCompute returns the existing value for the key if present.
// Otherwise, it computes the value using the provided function and
// returns the computed value. The loaded result is true if the value
// was loaded, false if stored.
//
// This call locks a hash table bucket while the compute function
// is executed. It means that modifications on other entries in
// the bucket will be blocked until the valueFn executes. Consider
// this when the function includes long-running operations.
func (m *MapOf[K, V]) LoadOrCompute(key K, valueFn func() V) (actual V, loaded bool) {
return m.doCompute(
key,
func(V, bool) (V, bool) {
return valueFn(), false
},
true,
false,
)
}
// Compute either sets the computed new value for the key or deletes
// the value for the key. When the delete result of the valueFn function
// is set to true, the value will be deleted, if it exists. When delete
// is set to false, the value is updated to the newValue.
// The ok result indicates whether value was computed and stored, thus, is
// present in the map. The actual result contains the new value in cases where
// the value was computed and stored. See the example for a few use cases.
//
// This call locks a hash table bucket while the compute function
// is executed. It means that modifications on other entries in
// the bucket will be blocked until the valueFn executes. Consider
// this when the function includes long-running operations.
func (m *MapOf[K, V]) Compute(
key K,
valueFn func(oldValue V, loaded bool) (newValue V, delete bool),
) (actual V, ok bool) {
return m.doCompute(key, valueFn, false, true)
}
// LoadAndDelete deletes the value for a key, returning the previous
// value if any. The loaded result reports whether the key was
// present.
func (m *MapOf[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
return m.doCompute(
key,
func(value V, loaded bool) (V, bool) {
return value, true
},
false,
false,
)
}
// Delete deletes the value for a key.
func (m *MapOf[K, V]) Delete(key K) {
m.doCompute(
key,
func(value V, loaded bool) (V, bool) {
return value, true
},
false,
false,
)
}
func (m *MapOf[K, V]) doCompute(
key K,
valueFn func(oldValue V, loaded bool) (V, bool),
loadIfExists, computeOnly bool,
) (V, bool) {
// Read-only path.
if loadIfExists {
if v, ok := m.Load(key); ok {
return v, !computeOnly
}
}
// Write path.
for {
compute_attempt:
var (
emptyb *bucketOfPadded
emptyidx int
)
table := (*mapOfTable[K, V])(atomic.LoadPointer(&m.table))
tableLen := len(table.buckets)
hash := m.hasher(key, table.seed)
h1 := h1(hash)
h2 := h2(hash)
h2w := broadcast(h2)
bidx := uint64(len(table.buckets)-1) & h1
rootb := &table.buckets[bidx]
rootb.mu.Lock()
// The following two checks must go in reverse to what's
// in the resize method.
if m.resizeInProgress() {
// Resize is in progress. Wait, then go for another attempt.
rootb.mu.Unlock()
m.waitForResize()
goto compute_attempt
}
if m.newerTableExists(table) {
// Someone resized the table. Go for another attempt.
rootb.mu.Unlock()
goto compute_attempt
}
b := rootb
for {
metaw := b.meta
markedw := markZeroBytes(metaw^h2w) & metaMask
for markedw != 0 {
idx := firstMarkedByteIndex(markedw)
eptr := b.entries[idx]
if eptr != nil {
e := (*entryOf[K, V])(eptr)
if e.key == key {
if loadIfExists {
rootb.mu.Unlock()
return e.value, !computeOnly
}
// In-place update/delete.
// We get a copy of the value via an interface{} on each call,
// thus the live value pointers are unique. Otherwise atomic
// snapshot won't be correct in case of multiple Store calls
// using the same value.
oldv := e.value
newv, del := valueFn(oldv, true)
if del {
// Deletion.
// First we update the hash, then the entry.
newmetaw := setByte(metaw, emptyMetaSlot, idx)
atomic.StoreUint64(&b.meta, newmetaw)
atomic.StorePointer(&b.entries[idx], nil)
rootb.mu.Unlock()
table.addSize(bidx, -1)
// Might need to shrink the table if we left bucket empty.
if newmetaw == defaultMeta {
m.resize(table, mapShrinkHint)
}
return oldv, !computeOnly
}
newe := new(entryOf[K, V])
newe.key = key
newe.value = newv
atomic.StorePointer(&b.entries[idx], unsafe.Pointer(newe))
rootb.mu.Unlock()
if computeOnly {
// Compute expects the new value to be returned.
return newv, true
}
// LoadAndStore expects the old value to be returned.
return oldv, true
}
}
markedw &= markedw - 1
}
if emptyb == nil {
// Search for empty entries (up to 5 per bucket).
emptyw := metaw & defaultMetaMasked
if emptyw != 0 {
idx := firstMarkedByteIndex(emptyw)
emptyb = b
emptyidx = idx
}
}
if b.next == nil {
if emptyb != nil {
// Insertion into an existing bucket.
var zeroedV V
newValue, del := valueFn(zeroedV, false)
if del {
rootb.mu.Unlock()
return zeroedV, false
}
newe := new(entryOf[K, V])
newe.key = key
newe.value = newValue
// First we update meta, then the entry.
atomic.StoreUint64(&emptyb.meta, setByte(emptyb.meta, h2, emptyidx))
atomic.StorePointer(&emptyb.entries[emptyidx], unsafe.Pointer(newe))
rootb.mu.Unlock()
table.addSize(bidx, 1)
return newValue, computeOnly
}
growThreshold := float64(tableLen) * entriesPerMapOfBucket * mapLoadFactor
if table.sumSize() > int64(growThreshold) {
// Need to grow the table. Then go for another attempt.
rootb.mu.Unlock()
m.resize(table, mapGrowHint)
goto compute_attempt
}
// Insertion into a new bucket.
var zeroedV V
newValue, del := valueFn(zeroedV, false)
if del {
rootb.mu.Unlock()
return newValue, false
}
// Create and append a bucket.
newb := new(bucketOfPadded)
newb.meta = setByte(defaultMeta, h2, 0)
newe := new(entryOf[K, V])
newe.key = key
newe.value = newValue
newb.entries[0] = unsafe.Pointer(newe)
atomic.StorePointer(&b.next, unsafe.Pointer(newb))
rootb.mu.Unlock()
table.addSize(bidx, 1)
return newValue, computeOnly
}
b = (*bucketOfPadded)(b.next)
}
}
}
func (m *MapOf[K, V]) newerTableExists(table *mapOfTable[K, V]) bool {
curTablePtr := atomic.LoadPointer(&m.table)
return uintptr(curTablePtr) != uintptr(unsafe.Pointer(table))
}
func (m *MapOf[K, V]) resizeInProgress() bool {
return atomic.LoadInt64(&m.resizing) == 1
}
func (m *MapOf[K, V]) waitForResize() {
m.resizeMu.Lock()
for m.resizeInProgress() {
m.resizeCond.Wait()
}
m.resizeMu.Unlock()
}
func (m *MapOf[K, V]) resize(knownTable *mapOfTable[K, V], hint mapResizeHint) {
knownTableLen := len(knownTable.buckets)
// Fast path for shrink attempts.
if hint == mapShrinkHint {
if m.growOnly ||
m.minTableLen == knownTableLen ||
knownTable.sumSize() > int64((knownTableLen*entriesPerMapOfBucket)/mapShrinkFraction) {
return
}
}
// Slow path.
if !atomic.CompareAndSwapInt64(&m.resizing, 0, 1) {
// Someone else started resize. Wait for it to finish.
m.waitForResize()
return
}
var newTable *mapOfTable[K, V]
table := (*mapOfTable[K, V])(atomic.LoadPointer(&m.table))
tableLen := len(table.buckets)
switch hint {
case mapGrowHint:
// Grow the table with factor of 2.
atomic.AddInt64(&m.totalGrowths, 1)
newTable = newMapOfTable[K, V](tableLen << 1)
case mapShrinkHint:
shrinkThreshold := int64((tableLen * entriesPerMapOfBucket) / mapShrinkFraction)
if tableLen > m.minTableLen && table.sumSize() <= shrinkThreshold {
// Shrink the table with factor of 2.
atomic.AddInt64(&m.totalShrinks, 1)
newTable = newMapOfTable[K, V](tableLen >> 1)
} else {
// No need to shrink. Wake up all waiters and give up.
m.resizeMu.Lock()
atomic.StoreInt64(&m.resizing, 0)
m.resizeCond.Broadcast()
m.resizeMu.Unlock()
return
}
case mapClearHint:
newTable = newMapOfTable[K, V](m.minTableLen)
default:
panic(fmt.Sprintf("unexpected resize hint: %d", hint))
}
// Copy the data only if we're not clearing the map.
if hint != mapClearHint {
for i := 0; i < tableLen; i++ {
copied := copyBucketOf(&table.buckets[i], newTable, m.hasher)
newTable.addSizePlain(uint64(i), copied)
}
}
// Publish the new table and wake up all waiters.
atomic.StorePointer(&m.table, unsafe.Pointer(newTable))
m.resizeMu.Lock()
atomic.StoreInt64(&m.resizing, 0)
m.resizeCond.Broadcast()
m.resizeMu.Unlock()
}
func copyBucketOf[K comparable, V any](
b *bucketOfPadded,
destTable *mapOfTable[K, V],
hasher func(K, uint64) uint64,
) (copied int) {
rootb := b
rootb.mu.Lock()
for {
for i := 0; i < entriesPerMapOfBucket; i++ {
if b.entries[i] != nil {
e := (*entryOf[K, V])(b.entries[i])
hash := hasher(e.key, destTable.seed)
bidx := uint64(len(destTable.buckets)-1) & h1(hash)
destb := &destTable.buckets[bidx]
appendToBucketOf(h2(hash), b.entries[i], destb)
copied++
}
}
if b.next == nil {
rootb.mu.Unlock()
return
}
b = (*bucketOfPadded)(b.next)
}
}
// Range calls f sequentially for each key and value present in the
// map. If f returns false, range stops the iteration.
//
// Range does not necessarily correspond to any consistent snapshot
// of the Map's contents: no key will be visited more than once, but
// if the value for any key is stored or deleted concurrently, Range
// may reflect any mapping for that key from any point during the
// Range call.
//
// It is safe to modify the map while iterating it, including entry
// creation, modification and deletion. However, the concurrent
// modification rule apply, i.e. the changes may be not reflected
// in the subsequently iterated entries.
func (m *MapOf[K, V]) Range(f func(key K, value V) bool) {
var zeroPtr unsafe.Pointer
// Pre-allocate array big enough to fit entries for most hash tables.
bentries := make([]unsafe.Pointer, 0, 16*entriesPerMapOfBucket)
tablep := atomic.LoadPointer(&m.table)
table := *(*mapOfTable[K, V])(tablep)
for i := range table.buckets {
rootb := &table.buckets[i]
b := rootb
// Prevent concurrent modifications and copy all entries into
// the intermediate slice.
rootb.mu.Lock()
for {
for i := 0; i < entriesPerMapOfBucket; i++ {
if b.entries[i] != nil {
bentries = append(bentries, b.entries[i])
}
}
if b.next == nil {
rootb.mu.Unlock()
break
}
b = (*bucketOfPadded)(b.next)
}
// Call the function for all copied entries.
for j := range bentries {
entry := (*entryOf[K, V])(bentries[j])
if !f(entry.key, entry.value) {
return
}
// Remove the reference to avoid preventing the copied
// entries from being GCed until this method finishes.
bentries[j] = zeroPtr
}
bentries = bentries[:0]
}
}
// Clear deletes all keys and values currently stored in the map.
func (m *MapOf[K, V]) Clear() {
table := (*mapOfTable[K, V])(atomic.LoadPointer(&m.table))
m.resize(table, mapClearHint)
}
// Size returns current size of the map.
func (m *MapOf[K, V]) Size() int {
table := (*mapOfTable[K, V])(atomic.LoadPointer(&m.table))
return int(table.sumSize())
}
func appendToBucketOf(h2 uint8, entryPtr unsafe.Pointer, b *bucketOfPadded) {
for {
for i := 0; i < entriesPerMapOfBucket; i++ {
if b.entries[i] == nil {
b.meta = setByte(b.meta, h2, i)
b.entries[i] = entryPtr
return
}
}
if b.next == nil {
newb := new(bucketOfPadded)
newb.meta = setByte(defaultMeta, h2, 0)
newb.entries[0] = entryPtr
b.next = unsafe.Pointer(newb)
return
}
b = (*bucketOfPadded)(b.next)
}
}
func (table *mapOfTable[K, V]) addSize(bucketIdx uint64, delta int) {
cidx := uint64(len(table.size)-1) & bucketIdx
atomic.AddInt64(&table.size[cidx].c, int64(delta))
}
func (table *mapOfTable[K, V]) addSizePlain(bucketIdx uint64, delta int) {
cidx := uint64(len(table.size)-1) & bucketIdx
table.size[cidx].c += int64(delta)
}
func (table *mapOfTable[K, V]) sumSize() int64 {
sum := int64(0)
for i := range table.size {
sum += atomic.LoadInt64(&table.size[i].c)
}
return sum
}
func h1(h uint64) uint64 {
return h >> 7
}
func h2(h uint64) uint8 {
return uint8(h & 0x7f)
}
// Stats returns statistics for the MapOf. Just like other map
// methods, this one is thread-safe. Yet it's an O(N) operation,
// so it should be used only for diagnostics or debugging purposes.
func (m *MapOf[K, V]) Stats() MapStats {
stats := MapStats{
TotalGrowths: atomic.LoadInt64(&m.totalGrowths),
TotalShrinks: atomic.LoadInt64(&m.totalShrinks),
MinEntries: math.MaxInt32,
}
table := (*mapOfTable[K, V])(atomic.LoadPointer(&m.table))
stats.RootBuckets = len(table.buckets)
stats.Counter = int(table.sumSize())
stats.CounterLen = len(table.size)
for i := range table.buckets {
nentries := 0
b := &table.buckets[i]
stats.TotalBuckets++
for {
nentriesLocal := 0
stats.Capacity += entriesPerMapOfBucket
for i := 0; i < entriesPerMapOfBucket; i++ {
if atomic.LoadPointer(&b.entries[i]) != nil {
stats.Size++
nentriesLocal++
}
}
nentries += nentriesLocal
if nentriesLocal == 0 {
stats.EmptyBuckets++
}
if b.next == nil {
break
}
b = (*bucketOfPadded)(atomic.LoadPointer(&b.next))
stats.TotalBuckets++
}
if nentries < stats.MinEntries {
stats.MinEntries = nentries
}
if nentries > stats.MaxEntries {
stats.MaxEntries = nentries
}
}
return stats
}

View file

@ -1,137 +0,0 @@
package xsync
import (
"runtime"
"sync/atomic"
"unsafe"
)
// A MPMCQueue is a bounded multi-producer multi-consumer concurrent
// queue.
//
// MPMCQueue instances must be created with NewMPMCQueue function.
// A MPMCQueue must not be copied after first use.
//
// Based on the data structure from the following C++ library:
// https://github.com/rigtorp/MPMCQueue
type MPMCQueue struct {
cap uint64
head uint64
//lint:ignore U1000 prevents false sharing
hpad [cacheLineSize - 8]byte
tail uint64
//lint:ignore U1000 prevents false sharing
tpad [cacheLineSize - 8]byte
slots []slotPadded
}
type slotPadded struct {
slot
//lint:ignore U1000 prevents false sharing
pad [cacheLineSize - unsafe.Sizeof(slot{})]byte
}
type slot struct {
turn uint64
item interface{}
}
// NewMPMCQueue creates a new MPMCQueue instance with the given
// capacity.
func NewMPMCQueue(capacity int) *MPMCQueue {
if capacity < 1 {
panic("capacity must be positive number")
}
return &MPMCQueue{
cap: uint64(capacity),
slots: make([]slotPadded, capacity),
}
}
// Enqueue inserts the given item into the queue.
// Blocks, if the queue is full.
func (q *MPMCQueue) Enqueue(item interface{}) {
head := atomic.AddUint64(&q.head, 1) - 1
slot := &q.slots[q.idx(head)]
turn := q.turn(head) * 2
for atomic.LoadUint64(&slot.turn) != turn {
runtime.Gosched()
}
slot.item = item
atomic.StoreUint64(&slot.turn, turn+1)
}
// Dequeue retrieves and removes the item from the head of the queue.
// Blocks, if the queue is empty.
func (q *MPMCQueue) Dequeue() interface{} {
tail := atomic.AddUint64(&q.tail, 1) - 1
slot := &q.slots[q.idx(tail)]
turn := q.turn(tail)*2 + 1
for atomic.LoadUint64(&slot.turn) != turn {
runtime.Gosched()
}
item := slot.item
slot.item = nil
atomic.StoreUint64(&slot.turn, turn+1)
return item
}
// TryEnqueue inserts the given item into the queue. Does not block
// and returns immediately. The result indicates that the queue isn't
// full and the item was inserted.
func (q *MPMCQueue) TryEnqueue(item interface{}) bool {
head := atomic.LoadUint64(&q.head)
for {
slot := &q.slots[q.idx(head)]
turn := q.turn(head) * 2
if atomic.LoadUint64(&slot.turn) == turn {
if atomic.CompareAndSwapUint64(&q.head, head, head+1) {
slot.item = item
atomic.StoreUint64(&slot.turn, turn+1)
return true
}
} else {
prevHead := head
head = atomic.LoadUint64(&q.head)
if head == prevHead {
return false
}
}
runtime.Gosched()
}
}
// TryDequeue retrieves and removes the item from the head of the
// queue. Does not block and returns immediately. The ok result
// indicates that the queue isn't empty and an item was retrieved.
func (q *MPMCQueue) TryDequeue() (item interface{}, ok bool) {
tail := atomic.LoadUint64(&q.tail)
for {
slot := &q.slots[q.idx(tail)]
turn := q.turn(tail)*2 + 1
if atomic.LoadUint64(&slot.turn) == turn {
if atomic.CompareAndSwapUint64(&q.tail, tail, tail+1) {
item = slot.item
ok = true
slot.item = nil
atomic.StoreUint64(&slot.turn, turn+1)
return
}
} else {
prevTail := tail
tail = atomic.LoadUint64(&q.tail)
if tail == prevTail {
return
}
}
runtime.Gosched()
}
}
func (q *MPMCQueue) idx(i uint64) uint64 {
return i % q.cap
}
func (q *MPMCQueue) turn(i uint64) uint64 {
return i / q.cap
}

View file

@ -1,150 +0,0 @@
//go:build go1.19
// +build go1.19
package xsync
import (
"runtime"
"sync/atomic"
"unsafe"
)
// A MPMCQueueOf is a bounded multi-producer multi-consumer concurrent
// queue. It's a generic version of MPMCQueue.
//
// MPMCQueue instances must be created with NewMPMCQueueOf function.
// A MPMCQueueOf must not be copied after first use.
//
// Based on the data structure from the following C++ library:
// https://github.com/rigtorp/MPMCQueue
type MPMCQueueOf[I any] struct {
cap uint64
head uint64
//lint:ignore U1000 prevents false sharing
hpad [cacheLineSize - 8]byte
tail uint64
//lint:ignore U1000 prevents false sharing
tpad [cacheLineSize - 8]byte
slots []slotOfPadded[I]
}
type slotOfPadded[I any] struct {
slotOf[I]
// Unfortunately, proper padding like the below one:
//
// pad [cacheLineSize - (unsafe.Sizeof(slotOf[I]{}) % cacheLineSize)]byte
//
// won't compile, so here we add a best-effort padding for items up to
// 56 bytes size.
//lint:ignore U1000 prevents false sharing
pad [cacheLineSize - unsafe.Sizeof(atomic.Uint64{})]byte
}
type slotOf[I any] struct {
// atomic.Uint64 is used here to get proper 8 byte alignment on
// 32-bit archs.
turn atomic.Uint64
item I
}
// NewMPMCQueueOf creates a new MPMCQueueOf instance with the given
// capacity.
func NewMPMCQueueOf[I any](capacity int) *MPMCQueueOf[I] {
if capacity < 1 {
panic("capacity must be positive number")
}
return &MPMCQueueOf[I]{
cap: uint64(capacity),
slots: make([]slotOfPadded[I], capacity),
}
}
// Enqueue inserts the given item into the queue.
// Blocks, if the queue is full.
func (q *MPMCQueueOf[I]) Enqueue(item I) {
head := atomic.AddUint64(&q.head, 1) - 1
slot := &q.slots[q.idx(head)]
turn := q.turn(head) * 2
for slot.turn.Load() != turn {
runtime.Gosched()
}
slot.item = item
slot.turn.Store(turn + 1)
}
// Dequeue retrieves and removes the item from the head of the queue.
// Blocks, if the queue is empty.
func (q *MPMCQueueOf[I]) Dequeue() I {
var zeroedI I
tail := atomic.AddUint64(&q.tail, 1) - 1
slot := &q.slots[q.idx(tail)]
turn := q.turn(tail)*2 + 1
for slot.turn.Load() != turn {
runtime.Gosched()
}
item := slot.item
slot.item = zeroedI
slot.turn.Store(turn + 1)
return item
}
// TryEnqueue inserts the given item into the queue. Does not block
// and returns immediately. The result indicates that the queue isn't
// full and the item was inserted.
func (q *MPMCQueueOf[I]) TryEnqueue(item I) bool {
head := atomic.LoadUint64(&q.head)
for {
slot := &q.slots[q.idx(head)]
turn := q.turn(head) * 2
if slot.turn.Load() == turn {
if atomic.CompareAndSwapUint64(&q.head, head, head+1) {
slot.item = item
slot.turn.Store(turn + 1)
return true
}
} else {
prevHead := head
head = atomic.LoadUint64(&q.head)
if head == prevHead {
return false
}
}
runtime.Gosched()
}
}
// TryDequeue retrieves and removes the item from the head of the
// queue. Does not block and returns immediately. The ok result
// indicates that the queue isn't empty and an item was retrieved.
func (q *MPMCQueueOf[I]) TryDequeue() (item I, ok bool) {
tail := atomic.LoadUint64(&q.tail)
for {
slot := &q.slots[q.idx(tail)]
turn := q.turn(tail)*2 + 1
if slot.turn.Load() == turn {
if atomic.CompareAndSwapUint64(&q.tail, tail, tail+1) {
var zeroedI I
item = slot.item
ok = true
slot.item = zeroedI
slot.turn.Store(turn + 1)
return
}
} else {
prevTail := tail
tail = atomic.LoadUint64(&q.tail)
if tail == prevTail {
return
}
}
runtime.Gosched()
}
}
func (q *MPMCQueueOf[I]) idx(i uint64) uint64 {
return i % q.cap
}
func (q *MPMCQueueOf[I]) turn(i uint64) uint64 {
return i / q.cap
}

View file

@ -1,188 +0,0 @@
package xsync
import (
"runtime"
"sync"
"sync/atomic"
"time"
)
// slow-down guard
const nslowdown = 7
// pool for reader tokens
var rtokenPool sync.Pool
// RToken is a reader lock token.
type RToken struct {
slot uint32
//lint:ignore U1000 prevents false sharing
pad [cacheLineSize - 4]byte
}
// A RBMutex is a reader biased reader/writer mutual exclusion lock.
// The lock can be held by an many readers or a single writer.
// The zero value for a RBMutex is an unlocked mutex.
//
// A RBMutex must not be copied after first use.
//
// RBMutex is based on a modified version of BRAVO
// (Biased Locking for Reader-Writer Locks) algorithm:
// https://arxiv.org/pdf/1810.01553.pdf
//
// RBMutex is a specialized mutex for scenarios, such as caches,
// where the vast majority of locks are acquired by readers and write
// lock acquire attempts are infrequent. In such scenarios, RBMutex
// performs better than sync.RWMutex on large multicore machines.
//
// RBMutex extends sync.RWMutex internally and uses it as the "reader
// bias disabled" fallback, so the same semantics apply. The only
// noticeable difference is in reader tokens returned from the
// RLock/RUnlock methods.
type RBMutex struct {
rslots []rslot
rmask uint32
rbias int32
inhibitUntil time.Time
rw sync.RWMutex
}
type rslot struct {
mu int32
//lint:ignore U1000 prevents false sharing
pad [cacheLineSize - 4]byte
}
// NewRBMutex creates a new RBMutex instance.
func NewRBMutex() *RBMutex {
nslots := nextPowOf2(parallelism())
mu := RBMutex{
rslots: make([]rslot, nslots),
rmask: nslots - 1,
rbias: 1,
}
return &mu
}
// TryRLock tries to lock m for reading without blocking.
// When TryRLock succeeds, it returns true and a reader token.
// In case of a failure, a false is returned.
func (mu *RBMutex) TryRLock() (bool, *RToken) {
if t := mu.fastRlock(); t != nil {
return true, t
}
// Optimistic slow path.
if mu.rw.TryRLock() {
if atomic.LoadInt32(&mu.rbias) == 0 && time.Now().After(mu.inhibitUntil) {
atomic.StoreInt32(&mu.rbias, 1)
}
return true, nil
}
return false, nil
}
// RLock locks m for reading and returns a reader token. The
// token must be used in the later RUnlock call.
//
// Should not be used for recursive read locking; a blocked Lock
// call excludes new readers from acquiring the lock.
func (mu *RBMutex) RLock() *RToken {
if t := mu.fastRlock(); t != nil {
return t
}
// Slow path.
mu.rw.RLock()
if atomic.LoadInt32(&mu.rbias) == 0 && time.Now().After(mu.inhibitUntil) {
atomic.StoreInt32(&mu.rbias, 1)
}
return nil
}
func (mu *RBMutex) fastRlock() *RToken {
if atomic.LoadInt32(&mu.rbias) == 1 {
t, ok := rtokenPool.Get().(*RToken)
if !ok {
t = new(RToken)
t.slot = runtime_fastrand()
}
// Try all available slots to distribute reader threads to slots.
for i := 0; i < len(mu.rslots); i++ {
slot := t.slot + uint32(i)
rslot := &mu.rslots[slot&mu.rmask]
rslotmu := atomic.LoadInt32(&rslot.mu)
if atomic.CompareAndSwapInt32(&rslot.mu, rslotmu, rslotmu+1) {
if atomic.LoadInt32(&mu.rbias) == 1 {
// Hot path succeeded.
t.slot = slot
return t
}
// The mutex is no longer reader biased. Roll back.
atomic.AddInt32(&rslot.mu, -1)
rtokenPool.Put(t)
return nil
}
// Contention detected. Give a try with the next slot.
}
}
return nil
}
// RUnlock undoes a single RLock call. A reader token obtained from
// the RLock call must be provided. RUnlock does not affect other
// simultaneous readers. A panic is raised if m is not locked for
// reading on entry to RUnlock.
func (mu *RBMutex) RUnlock(t *RToken) {
if t == nil {
mu.rw.RUnlock()
return
}
if atomic.AddInt32(&mu.rslots[t.slot&mu.rmask].mu, -1) < 0 {
panic("invalid reader state detected")
}
rtokenPool.Put(t)
}
// TryLock tries to lock m for writing without blocking.
func (mu *RBMutex) TryLock() bool {
if mu.rw.TryLock() {
if atomic.LoadInt32(&mu.rbias) == 1 {
atomic.StoreInt32(&mu.rbias, 0)
for i := 0; i < len(mu.rslots); i++ {
if atomic.LoadInt32(&mu.rslots[i].mu) > 0 {
// There is a reader. Roll back.
atomic.StoreInt32(&mu.rbias, 1)
mu.rw.Unlock()
return false
}
}
}
return true
}
return false
}
// Lock locks m for writing. If the lock is already locked for
// reading or writing, Lock blocks until the lock is available.
func (mu *RBMutex) Lock() {
mu.rw.Lock()
if atomic.LoadInt32(&mu.rbias) == 1 {
atomic.StoreInt32(&mu.rbias, 0)
start := time.Now()
for i := 0; i < len(mu.rslots); i++ {
for atomic.LoadInt32(&mu.rslots[i].mu) > 0 {
runtime.Gosched()
}
}
mu.inhibitUntil = time.Now().Add(time.Since(start) * nslowdown)
}
}
// Unlock unlocks m for writing. A panic is raised if m is not locked
// for writing on entry to Unlock.
//
// As with RWMutex, a locked RBMutex is not associated with a
// particular goroutine. One goroutine may RLock (Lock) a RBMutex and
// then arrange for another goroutine to RUnlock (Unlock) it.
func (mu *RBMutex) Unlock() {
mu.rw.Unlock()
}

View file

@ -1,66 +0,0 @@
package xsync
import (
"math/bits"
"runtime"
_ "unsafe"
)
// test-only assert()-like flag
var assertionsEnabled = false
const (
// cacheLineSize is used in paddings to prevent false sharing;
// 64B are used instead of 128B as a compromise between
// memory footprint and performance; 128B usage may give ~30%
// improvement on NUMA machines.
cacheLineSize = 64
)
// nextPowOf2 computes the next highest power of 2 of 32-bit v.
// Source: https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
func nextPowOf2(v uint32) uint32 {
if v == 0 {
return 1
}
v--
v |= v >> 1
v |= v >> 2
v |= v >> 4
v |= v >> 8
v |= v >> 16
v++
return v
}
func parallelism() uint32 {
maxProcs := uint32(runtime.GOMAXPROCS(0))
numCores := uint32(runtime.NumCPU())
if maxProcs < numCores {
return maxProcs
}
return numCores
}
//go:noescape
//go:linkname runtime_fastrand runtime.fastrand
func runtime_fastrand() uint32
func broadcast(b uint8) uint64 {
return 0x101010101010101 * uint64(b)
}
func firstMarkedByteIndex(w uint64) int {
return bits.TrailingZeros64(w) >> 3
}
// SWAR byte search: may produce false positives, e.g. for 0x0100,
// so make sure to double-check bytes found by this function.
func markZeroBytes(w uint64) uint64 {
return ((w - 0x0101010101010101) & (^w) & 0x8080808080808080)
}
func setByte(w uint64, b uint8, idx int) uint64 {
shift := idx << 3
return (w &^ (0xff << shift)) | (uint64(b) << shift)
}

View file

@ -1,77 +0,0 @@
package xsync
import (
"reflect"
"unsafe"
)
// makeSeed creates a random seed.
func makeSeed() uint64 {
var s1 uint32
for {
s1 = runtime_fastrand()
// We use seed 0 to indicate an uninitialized seed/hash,
// so keep trying until we get a non-zero seed.
if s1 != 0 {
break
}
}
s2 := runtime_fastrand()
return uint64(s1)<<32 | uint64(s2)
}
// hashString calculates a hash of s with the given seed.
func hashString(s string, seed uint64) uint64 {
if s == "" {
return seed
}
strh := (*reflect.StringHeader)(unsafe.Pointer(&s))
return uint64(runtime_memhash(unsafe.Pointer(strh.Data), uintptr(seed), uintptr(strh.Len)))
}
//go:noescape
//go:linkname runtime_memhash runtime.memhash
func runtime_memhash(p unsafe.Pointer, h, s uintptr) uintptr
// defaultHasher creates a fast hash function for the given comparable type.
// The only limitation is that the type should not contain interfaces inside
// based on runtime.typehash.
func defaultHasher[T comparable]() func(T, uint64) uint64 {
var zero T
if reflect.TypeOf(&zero).Elem().Kind() == reflect.Interface {
return func(value T, seed uint64) uint64 {
iValue := any(value)
i := (*iface)(unsafe.Pointer(&iValue))
return runtime_typehash64(i.typ, i.word, seed)
}
} else {
var iZero any = zero
i := (*iface)(unsafe.Pointer(&iZero))
return func(value T, seed uint64) uint64 {
return runtime_typehash64(i.typ, unsafe.Pointer(&value), seed)
}
}
}
// how interface is represented in memory
type iface struct {
typ uintptr
word unsafe.Pointer
}
// same as runtime_typehash, but always returns a uint64
// see: maphash.rthash function for details
func runtime_typehash64(t uintptr, p unsafe.Pointer, seed uint64) uint64 {
if unsafe.Sizeof(uintptr(0)) == 8 {
return uint64(runtime_typehash(t, p, uintptr(seed)))
}
lo := runtime_typehash(t, p, uintptr(seed))
hi := runtime_typehash(t, p, uintptr(seed>>32))
return uint64(hi)<<32 | uint64(lo)
}
//go:noescape
//go:linkname runtime_typehash runtime.typehash
func runtime_typehash(t uintptr, p unsafe.Pointer, h uintptr) uintptr

View file

@ -1,70 +1,3 @@
## [1.2.5](https://github.com/uptrace/bun/compare/v1.2.3...v1.2.5) (2024-10-26)
### Bug Fixes
* allow Limit() without Order() with MSSQL ([#1009](https://github.com/uptrace/bun/issues/1009)) ([1a46ddc](https://github.com/uptrace/bun/commit/1a46ddc0d3ca0bdc60ca8be5ad1886799d14c8b0))
* copy bytes in mapModel.Scan ([#1030](https://github.com/uptrace/bun/issues/1030)) ([#1032](https://github.com/uptrace/bun/issues/1032)) ([39fda4e](https://github.com/uptrace/bun/commit/39fda4e3d341e59e4955f751cb354a939e57c1b1))
* fix issue with has-many join and pointer fields ([#950](https://github.com/uptrace/bun/issues/950)) ([#983](https://github.com/uptrace/bun/issues/983)) ([cbc5177](https://github.com/uptrace/bun/commit/cbc517792ba6cdcef1828f3699d3d4dfe3c5e0eb))
* restore explicit column: name override ([#984](https://github.com/uptrace/bun/issues/984)) ([169f258](https://github.com/uptrace/bun/commit/169f258a9460cad451f3025d2ef8df1bbd42a003))
* return column option back ([#1036](https://github.com/uptrace/bun/issues/1036)) ([a3ccbea](https://github.com/uptrace/bun/commit/a3ccbeab39151d3eed6cb245fe15cfb5d71ba557))
* sql.NullString mistaken as custom struct ([#1019](https://github.com/uptrace/bun/issues/1019)) ([87c77b8](https://github.com/uptrace/bun/commit/87c77b8911f2035b0ee8ea96356a2c7600b5b94d))
* typos ([#1026](https://github.com/uptrace/bun/issues/1026)) ([760de7d](https://github.com/uptrace/bun/commit/760de7d0fad15dc761475670a4dde056aef9210d))
### Features
* add transaction isolation level support to pgdriver ([#1034](https://github.com/uptrace/bun/issues/1034)) ([3ef44ce](https://github.com/uptrace/bun/commit/3ef44ce1cdd969a21b76d6c803119cf12c375cb0))
### Performance Improvements
* refactor SelectQuery.ScanAndCount to optimize performance when there is no limit and offset ([#1035](https://github.com/uptrace/bun/issues/1035)) ([8638613](https://github.com/uptrace/bun/commit/86386135897485bbada6c50ec9a2743626111433))
## [1.2.4](https://github.com/uptrace/bun/compare/v1.2.3...v1.2.4) (2024-10-26)
### Bug Fixes
* allow Limit() without Order() with MSSQL ([#1009](https://github.com/uptrace/bun/issues/1009)) ([1a46ddc](https://github.com/uptrace/bun/commit/1a46ddc0d3ca0bdc60ca8be5ad1886799d14c8b0))
* copy bytes in mapModel.Scan ([#1030](https://github.com/uptrace/bun/issues/1030)) ([#1032](https://github.com/uptrace/bun/issues/1032)) ([39fda4e](https://github.com/uptrace/bun/commit/39fda4e3d341e59e4955f751cb354a939e57c1b1))
* return column option back ([#1036](https://github.com/uptrace/bun/issues/1036)) ([a3ccbea](https://github.com/uptrace/bun/commit/a3ccbeab39151d3eed6cb245fe15cfb5d71ba557))
* sql.NullString mistaken as custom struct ([#1019](https://github.com/uptrace/bun/issues/1019)) ([87c77b8](https://github.com/uptrace/bun/commit/87c77b8911f2035b0ee8ea96356a2c7600b5b94d))
* typos ([#1026](https://github.com/uptrace/bun/issues/1026)) ([760de7d](https://github.com/uptrace/bun/commit/760de7d0fad15dc761475670a4dde056aef9210d))
### Features
* add transaction isolation level support to pgdriver ([#1034](https://github.com/uptrace/bun/issues/1034)) ([3ef44ce](https://github.com/uptrace/bun/commit/3ef44ce1cdd969a21b76d6c803119cf12c375cb0))
### Performance Improvements
* refactor SelectQuery.ScanAndCount to optimize performance when there is no limit and offset ([#1035](https://github.com/uptrace/bun/issues/1035)) ([8638613](https://github.com/uptrace/bun/commit/86386135897485bbada6c50ec9a2743626111433))
## [1.2.3](https://github.com/uptrace/bun/compare/v1.2.2...v1.2.3) (2024-08-31)
## [1.2.2](https://github.com/uptrace/bun/compare/v1.2.1...v1.2.2) (2024-08-29)
### Bug Fixes
* gracefully handle empty hstore in pgdialect ([#1010](https://github.com/uptrace/bun/issues/1010)) ([2f73d8a](https://github.com/uptrace/bun/commit/2f73d8a8e16c8718ebfc956036d9c9a01a0888bc))
* number each unit test ([#974](https://github.com/uptrace/bun/issues/974)) ([b005dc2](https://github.com/uptrace/bun/commit/b005dc2a9034715c6f59dcfc8e76aa3b85df38ab))
### Features
* add ModelTableExpr to TruncateTableQuery ([#969](https://github.com/uptrace/bun/issues/969)) ([7bc330f](https://github.com/uptrace/bun/commit/7bc330f152cf0d9dc30956478e2731ea5816f012))
## [1.2.1](https://github.com/uptrace/bun/compare/v1.2.0...v1.2.1) (2024-04-02) ## [1.2.1](https://github.com/uptrace/bun/compare/v1.2.0...v1.2.1) (2024-04-02)
@ -81,7 +14,7 @@
### Features ### Features
* Allow overriding of Warn and Deprecated loggers ([#952](https://github.com/uptrace/bun/issues/952)) ([0e9d737](https://github.com/uptrace/bun/commit/0e9d737e4ca2deb86930237ee32a39cf3f7e8157)) * Allow overiding of Warn and Deprecated loggers ([#952](https://github.com/uptrace/bun/issues/952)) ([0e9d737](https://github.com/uptrace/bun/commit/0e9d737e4ca2deb86930237ee32a39cf3f7e8157))
* enable SNI ([#953](https://github.com/uptrace/bun/issues/953)) ([4071ffb](https://github.com/uptrace/bun/commit/4071ffb5bcb1b233cda239c92504d8139dcf1d2f)) * enable SNI ([#953](https://github.com/uptrace/bun/issues/953)) ([4071ffb](https://github.com/uptrace/bun/commit/4071ffb5bcb1b233cda239c92504d8139dcf1d2f))
* **idb:** add NewMerge method to IDB ([#966](https://github.com/uptrace/bun/issues/966)) ([664e2f1](https://github.com/uptrace/bun/commit/664e2f154f1153d2a80cd062a5074f1692edaee7)) * **idb:** add NewMerge method to IDB ([#966](https://github.com/uptrace/bun/issues/966)) ([664e2f1](https://github.com/uptrace/bun/commit/664e2f154f1153d2a80cd062a5074f1692edaee7))
@ -167,7 +100,7 @@
### Bug Fixes ### Bug Fixes
* add support for inserting values with Unicode encoding for mssql dialect ([e98c6c0](https://github.com/uptrace/bun/commit/e98c6c0f033b553bea3bbc783aa56c2eaa17718f)) * add support for inserting values with unicode encoding for mssql dialect ([e98c6c0](https://github.com/uptrace/bun/commit/e98c6c0f033b553bea3bbc783aa56c2eaa17718f))
* fix relation tag ([a3eedff](https://github.com/uptrace/bun/commit/a3eedff49700490d4998dcdcdc04f554d8f17166)) * fix relation tag ([a3eedff](https://github.com/uptrace/bun/commit/a3eedff49700490d4998dcdcdc04f554d8f17166))
@ -203,7 +136,7 @@
### Bug Fixes ### Bug Fixes
* adding dialect override for append-bool ([#695](https://github.com/uptrace/bun/issues/695)) ([338f2f0](https://github.com/uptrace/bun/commit/338f2f04105ad89e64530db86aeb387e2ad4789e)) * addng dialect override for append-bool ([#695](https://github.com/uptrace/bun/issues/695)) ([338f2f0](https://github.com/uptrace/bun/commit/338f2f04105ad89e64530db86aeb387e2ad4789e))
* don't call hooks twice for whereExists ([9057857](https://github.com/uptrace/bun/commit/90578578e717f248e4b6eb114c5b495fd8d4ed41)) * don't call hooks twice for whereExists ([9057857](https://github.com/uptrace/bun/commit/90578578e717f248e4b6eb114c5b495fd8d4ed41))
* don't lock migrations when running Migrate and Rollback ([69a7354](https://github.com/uptrace/bun/commit/69a7354d987ff2ed5338c9ef5f4ce320724299ab)) * don't lock migrations when running Migrate and Rollback ([69a7354](https://github.com/uptrace/bun/commit/69a7354d987ff2ed5338c9ef5f4ce320724299ab))
* **query:** make WhereDeleted compatible with ForceDelete ([299c3fd](https://github.com/uptrace/bun/commit/299c3fd57866aaecd127a8f219c95332898475db)), closes [#673](https://github.com/uptrace/bun/issues/673) * **query:** make WhereDeleted compatible with ForceDelete ([299c3fd](https://github.com/uptrace/bun/commit/299c3fd57866aaecd127a8f219c95332898475db)), closes [#673](https://github.com/uptrace/bun/issues/673)
@ -371,7 +304,7 @@ recommended to upgrade to v1.0.24 before upgrading to v1.1.x.
- append slice values - append slice values
([4a65129](https://github.com/uptrace/bun/commit/4a651294fb0f1e73079553024810c3ead9777311)) ([4a65129](https://github.com/uptrace/bun/commit/4a651294fb0f1e73079553024810c3ead9777311))
- check for nils when appending driver.Value - check for nils when appeding driver.Value
([7bb1640](https://github.com/uptrace/bun/commit/7bb1640a00fceca1e1075fe6544b9a4842ab2b26)) ([7bb1640](https://github.com/uptrace/bun/commit/7bb1640a00fceca1e1075fe6544b9a4842ab2b26))
- cleanup soft deletes for mssql - cleanup soft deletes for mssql
([e72e2c5](https://github.com/uptrace/bun/commit/e72e2c5d0a85f3d26c3fa22c7284c2de1dcfda8e)) ([e72e2c5](https://github.com/uptrace/bun/commit/e72e2c5d0a85f3d26c3fa22c7284c2de1dcfda8e))
@ -390,7 +323,7 @@ recommended to upgrade to v1.0.24 before upgrading to v1.1.x.
### Deprecated ### Deprecated
In the coming v1.1.x release, Bun will stop automatically adding `,pk,autoincrement` options on In the comming v1.1.x release, Bun will stop automatically adding `,pk,autoincrement` options on
`ID int64/int32` fields. This version (v1.0.23) only prints a warning when it encounters such `ID int64/int32` fields. This version (v1.0.23) only prints a warning when it encounters such
fields, but the code will continue working as before. fields, but the code will continue working as before.
@ -508,7 +441,7 @@ In v1.1.x, such options as `,nopk` and `,allowzero` will not be necessary and wi
([693f1e1](https://github.com/uptrace/bun/commit/693f1e135999fc31cf83b99a2530a695b20f4e1b)) ([693f1e1](https://github.com/uptrace/bun/commit/693f1e135999fc31cf83b99a2530a695b20f4e1b))
- add model embedding via embed:prefix\_ - add model embedding via embed:prefix\_
([9a2cedc](https://github.com/uptrace/bun/commit/9a2cedc8b08fa8585d4bfced338bd0a40d736b1d)) ([9a2cedc](https://github.com/uptrace/bun/commit/9a2cedc8b08fa8585d4bfced338bd0a40d736b1d))
- change the default log output to stderr - change the default logoutput to stderr
([4bf5773](https://github.com/uptrace/bun/commit/4bf577382f19c64457cbf0d64490401450954654)), ([4bf5773](https://github.com/uptrace/bun/commit/4bf577382f19c64457cbf0d64490401450954654)),
closes [#349](https://github.com/uptrace/bun/issues/349) closes [#349](https://github.com/uptrace/bun/issues/349)

View file

@ -15,7 +15,7 @@ go_mod_tidy:
echo "go mod tidy in $${dir}"; \ echo "go mod tidy in $${dir}"; \
(cd "$${dir}" && \ (cd "$${dir}" && \
go get -u ./... && \ go get -u ./... && \
go mod tidy); \ go mod tidy -go=1.21); \
done done
fmt: fmt:

View file

@ -1,4 +1,4 @@
# SQL-first Golang ORM for PostgreSQL, MySQL, MSSQL, SQLite and Oracle # SQL-first Golang ORM for PostgreSQL, MySQL, MSSQL, and SQLite
[![build workflow](https://github.com/uptrace/bun/actions/workflows/build.yml/badge.svg)](https://github.com/uptrace/bun/actions) [![build workflow](https://github.com/uptrace/bun/actions/workflows/build.yml/badge.svg)](https://github.com/uptrace/bun/actions)
[![PkgGoDev](https://pkg.go.dev/badge/github.com/uptrace/bun)](https://pkg.go.dev/github.com/uptrace/bun) [![PkgGoDev](https://pkg.go.dev/badge/github.com/uptrace/bun)](https://pkg.go.dev/github.com/uptrace/bun)
@ -19,7 +19,6 @@
[MySQL](https://bun.uptrace.dev/guide/drivers.html#mysql) (including MariaDB), [MySQL](https://bun.uptrace.dev/guide/drivers.html#mysql) (including MariaDB),
[MSSQL](https://bun.uptrace.dev/guide/drivers.html#mssql), [MSSQL](https://bun.uptrace.dev/guide/drivers.html#mssql),
[SQLite](https://bun.uptrace.dev/guide/drivers.html#sqlite). [SQLite](https://bun.uptrace.dev/guide/drivers.html#sqlite).
[Oracle](https://bun.uptrace.dev/guide/drivers.html#oracle).
- [ORM-like](/example/basic/) experience using good old SQL. Bun supports structs, map, scalars, and - [ORM-like](/example/basic/) experience using good old SQL. Bun supports structs, map, scalars, and
slices of map/structs/scalars. slices of map/structs/scalars.
- [Bulk inserts](https://bun.uptrace.dev/guide/query-insert.html). - [Bulk inserts](https://bun.uptrace.dev/guide/query-insert.html).

View file

@ -22,10 +22,6 @@
AfterScanRowHook = schema.AfterScanRowHook AfterScanRowHook = schema.AfterScanRowHook
) )
func SafeQuery(query string, args ...interface{}) schema.QueryWithArgs {
return schema.SafeQuery(query, args)
}
type BeforeSelectHook interface { type BeforeSelectHook interface {
BeforeSelect(ctx context.Context, query *SelectQuery) error BeforeSelect(ctx context.Context, query *SelectQuery) error
} }
@ -74,7 +70,7 @@ type AfterDropTableHook interface {
AfterDropTable(ctx context.Context, query *DropTableQuery) error AfterDropTable(ctx context.Context, query *DropTableQuery) error
} }
// SetLogger overwrites default Bun logger. // SetLogger overwriters default Bun logger.
func SetLogger(logger internal.Logging) { func SetLogger(logger internal.Logging) {
internal.SetLogger(logger) internal.SetLogger(logger)
} }

View file

@ -12,8 +12,6 @@ func (n Name) String() string {
return "mysql" return "mysql"
case MSSQL: case MSSQL:
return "mssql" return "mssql"
case Oracle:
return "oracle"
default: default:
return "invalid" return "invalid"
} }
@ -25,5 +23,4 @@ func (n Name) String() string {
SQLite SQLite
MySQL MySQL
MSSQL MSSQL
Oracle
) )

View file

@ -2,9 +2,12 @@
import ( import (
"database/sql/driver" "database/sql/driver"
"encoding/hex"
"fmt" "fmt"
"reflect" "reflect"
"strconv"
"time" "time"
"unicode/utf8"
"github.com/uptrace/bun/dialect" "github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/schema" "github.com/uptrace/bun/schema"
@ -29,10 +32,316 @@
sliceTimeType = reflect.TypeOf([]time.Time(nil)) sliceTimeType = reflect.TypeOf([]time.Time(nil))
) )
func appendTime(buf []byte, tm time.Time) []byte { func arrayAppend(fmter schema.Formatter, b []byte, v interface{}) []byte {
return tm.UTC().AppendFormat(buf, "2006-01-02 15:04:05.999999-07:00") switch v := v.(type) {
case int64:
return strconv.AppendInt(b, v, 10)
case float64:
return dialect.AppendFloat64(b, v)
case bool:
return dialect.AppendBool(b, v)
case []byte:
return arrayAppendBytes(b, v)
case string:
return arrayAppendString(b, v)
case time.Time:
return fmter.Dialect().AppendTime(b, v)
default:
err := fmt.Errorf("pgdialect: can't append %T", v)
return dialect.AppendError(b, err)
}
} }
func arrayAppendStringValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
return arrayAppendString(b, v.String())
}
func arrayAppendBytesValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
return arrayAppendBytes(b, v.Bytes())
}
func arrayAppendDriverValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
iface, err := v.Interface().(driver.Valuer).Value()
if err != nil {
return dialect.AppendError(b, err)
}
return arrayAppend(fmter, b, iface)
}
//------------------------------------------------------------------------------
func (d *Dialect) arrayAppender(typ reflect.Type) schema.AppenderFunc {
kind := typ.Kind()
switch kind {
case reflect.Ptr:
if fn := d.arrayAppender(typ.Elem()); fn != nil {
return schema.PtrAppender(fn)
}
case reflect.Slice, reflect.Array:
// ok:
default:
return nil
}
elemType := typ.Elem()
if kind == reflect.Slice {
switch elemType {
case stringType:
return appendStringSliceValue
case intType:
return appendIntSliceValue
case int64Type:
return appendInt64SliceValue
case float64Type:
return appendFloat64SliceValue
case timeType:
return appendTimeSliceValue
}
}
appendElem := d.arrayElemAppender(elemType)
if appendElem == nil {
panic(fmt.Errorf("pgdialect: %s is not supported", typ))
}
return func(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
kind := v.Kind()
switch kind {
case reflect.Ptr, reflect.Slice:
if v.IsNil() {
return dialect.AppendNull(b)
}
}
if kind == reflect.Ptr {
v = v.Elem()
}
b = append(b, '\'')
b = append(b, '{')
ln := v.Len()
for i := 0; i < ln; i++ {
elem := v.Index(i)
b = appendElem(fmter, b, elem)
b = append(b, ',')
}
if v.Len() > 0 {
b[len(b)-1] = '}' // Replace trailing comma.
} else {
b = append(b, '}')
}
b = append(b, '\'')
return b
}
}
func (d *Dialect) arrayElemAppender(typ reflect.Type) schema.AppenderFunc {
if typ.Implements(driverValuerType) {
return arrayAppendDriverValue
}
switch typ.Kind() {
case reflect.String:
return arrayAppendStringValue
case reflect.Slice:
if typ.Elem().Kind() == reflect.Uint8 {
return arrayAppendBytesValue
}
}
return schema.Appender(d, typ)
}
func appendStringSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
ss := v.Convert(sliceStringType).Interface().([]string)
return appendStringSlice(b, ss)
}
func appendStringSlice(b []byte, ss []string) []byte {
if ss == nil {
return dialect.AppendNull(b)
}
b = append(b, '\'')
b = append(b, '{')
for _, s := range ss {
b = arrayAppendString(b, s)
b = append(b, ',')
}
if len(ss) > 0 {
b[len(b)-1] = '}' // Replace trailing comma.
} else {
b = append(b, '}')
}
b = append(b, '\'')
return b
}
func appendIntSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
ints := v.Convert(sliceIntType).Interface().([]int)
return appendIntSlice(b, ints)
}
func appendIntSlice(b []byte, ints []int) []byte {
if ints == nil {
return dialect.AppendNull(b)
}
b = append(b, '\'')
b = append(b, '{')
for _, n := range ints {
b = strconv.AppendInt(b, int64(n), 10)
b = append(b, ',')
}
if len(ints) > 0 {
b[len(b)-1] = '}' // Replace trailing comma.
} else {
b = append(b, '}')
}
b = append(b, '\'')
return b
}
func appendInt64SliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
ints := v.Convert(sliceInt64Type).Interface().([]int64)
return appendInt64Slice(b, ints)
}
func appendInt64Slice(b []byte, ints []int64) []byte {
if ints == nil {
return dialect.AppendNull(b)
}
b = append(b, '\'')
b = append(b, '{')
for _, n := range ints {
b = strconv.AppendInt(b, n, 10)
b = append(b, ',')
}
if len(ints) > 0 {
b[len(b)-1] = '}' // Replace trailing comma.
} else {
b = append(b, '}')
}
b = append(b, '\'')
return b
}
func appendFloat64SliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
floats := v.Convert(sliceFloat64Type).Interface().([]float64)
return appendFloat64Slice(b, floats)
}
func appendFloat64Slice(b []byte, floats []float64) []byte {
if floats == nil {
return dialect.AppendNull(b)
}
b = append(b, '\'')
b = append(b, '{')
for _, n := range floats {
b = dialect.AppendFloat64(b, n)
b = append(b, ',')
}
if len(floats) > 0 {
b[len(b)-1] = '}' // Replace trailing comma.
} else {
b = append(b, '}')
}
b = append(b, '\'')
return b
}
//------------------------------------------------------------------------------
func arrayAppendBytes(b []byte, bs []byte) []byte {
if bs == nil {
return dialect.AppendNull(b)
}
b = append(b, `"\\x`...)
s := len(b)
b = append(b, make([]byte, hex.EncodedLen(len(bs)))...)
hex.Encode(b[s:], bs)
b = append(b, '"')
return b
}
func arrayAppendString(b []byte, s string) []byte {
b = append(b, '"')
for _, r := range s {
switch r {
case 0:
// ignore
case '\'':
b = append(b, "''"...)
case '"':
b = append(b, '\\', '"')
case '\\':
b = append(b, '\\', '\\')
default:
if r < utf8.RuneSelf {
b = append(b, byte(r))
break
}
l := len(b)
if cap(b)-l < utf8.UTFMax {
b = append(b, make([]byte, utf8.UTFMax)...)
}
n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r)
b = b[:l+n]
}
}
b = append(b, '"')
return b
}
func appendTimeSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
ts := v.Convert(sliceTimeType).Interface().([]time.Time)
return appendTimeSlice(fmter, b, ts)
}
func appendTimeSlice(fmter schema.Formatter, b []byte, ts []time.Time) []byte {
if ts == nil {
return dialect.AppendNull(b)
}
b = append(b, '\'')
b = append(b, '{')
for _, t := range ts {
b = append(b, '"')
b = t.UTC().AppendFormat(b, "2006-01-02 15:04:05.999999-07:00")
b = append(b, '"')
b = append(b, ',')
}
if len(ts) > 0 {
b[len(b)-1] = '}' // Replace trailing comma.
} else {
b = append(b, '}')
}
b = append(b, '\'')
return b
}
//------------------------------------------------------------------------------
var mapStringStringType = reflect.TypeOf(map[string]string(nil)) var mapStringStringType = reflect.TypeOf(map[string]string(nil))
func (d *Dialect) hstoreAppender(typ reflect.Type) schema.AppenderFunc { func (d *Dialect) hstoreAppender(typ reflect.Type) schema.AppenderFunc {

View file

@ -2,16 +2,9 @@
import ( import (
"database/sql" "database/sql"
"database/sql/driver"
"encoding/hex"
"fmt" "fmt"
"reflect" "reflect"
"strconv"
"time"
"unicode/utf8"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema" "github.com/uptrace/bun/schema"
) )
@ -27,7 +20,7 @@ type ArrayValue struct {
// //
// For struct fields you can use array tag: // For struct fields you can use array tag:
// //
// Emails []string `bun:",array"` // Emails []string `bun:",array"`
func Array(vi interface{}) *ArrayValue { func Array(vi interface{}) *ArrayValue {
v := reflect.ValueOf(vi) v := reflect.ValueOf(vi)
if !v.IsValid() { if !v.IsValid() {
@ -70,576 +63,3 @@ func (a *ArrayValue) Value() interface{} {
} }
return nil return nil
} }
//------------------------------------------------------------------------------
func (d *Dialect) arrayAppender(typ reflect.Type) schema.AppenderFunc {
kind := typ.Kind()
switch kind {
case reflect.Ptr:
if fn := d.arrayAppender(typ.Elem()); fn != nil {
return schema.PtrAppender(fn)
}
case reflect.Slice, reflect.Array:
// continue below
default:
return nil
}
elemType := typ.Elem()
if kind == reflect.Slice {
switch elemType {
case stringType:
return appendStringSliceValue
case intType:
return appendIntSliceValue
case int64Type:
return appendInt64SliceValue
case float64Type:
return appendFloat64SliceValue
case timeType:
return appendTimeSliceValue
}
}
appendElem := d.arrayElemAppender(elemType)
if appendElem == nil {
panic(fmt.Errorf("pgdialect: %s is not supported", typ))
}
return func(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
kind := v.Kind()
switch kind {
case reflect.Ptr, reflect.Slice:
if v.IsNil() {
return dialect.AppendNull(b)
}
}
if kind == reflect.Ptr {
v = v.Elem()
}
b = append(b, "'{"...)
ln := v.Len()
for i := 0; i < ln; i++ {
elem := v.Index(i)
if i > 0 {
b = append(b, ',')
}
b = appendElem(fmter, b, elem)
}
b = append(b, "}'"...)
return b
}
}
func (d *Dialect) arrayElemAppender(typ reflect.Type) schema.AppenderFunc {
if typ.Implements(driverValuerType) {
return arrayAppendDriverValue
}
switch typ.Kind() {
case reflect.String:
return arrayAppendStringValue
case reflect.Slice:
if typ.Elem().Kind() == reflect.Uint8 {
return arrayAppendBytesValue
}
}
return schema.Appender(d, typ)
}
func arrayAppend(fmter schema.Formatter, b []byte, v interface{}) []byte {
switch v := v.(type) {
case int64:
return strconv.AppendInt(b, v, 10)
case float64:
return dialect.AppendFloat64(b, v)
case bool:
return dialect.AppendBool(b, v)
case []byte:
return arrayAppendBytes(b, v)
case string:
return arrayAppendString(b, v)
case time.Time:
return fmter.Dialect().AppendTime(b, v)
default:
err := fmt.Errorf("pgdialect: can't append %T", v)
return dialect.AppendError(b, err)
}
}
func arrayAppendStringValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
return arrayAppendString(b, v.String())
}
func arrayAppendBytesValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
return arrayAppendBytes(b, v.Bytes())
}
func arrayAppendDriverValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
iface, err := v.Interface().(driver.Valuer).Value()
if err != nil {
return dialect.AppendError(b, err)
}
return arrayAppend(fmter, b, iface)
}
func appendStringSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
ss := v.Convert(sliceStringType).Interface().([]string)
return appendStringSlice(b, ss)
}
func appendStringSlice(b []byte, ss []string) []byte {
if ss == nil {
return dialect.AppendNull(b)
}
b = append(b, '\'')
b = append(b, '{')
for _, s := range ss {
b = arrayAppendString(b, s)
b = append(b, ',')
}
if len(ss) > 0 {
b[len(b)-1] = '}' // Replace trailing comma.
} else {
b = append(b, '}')
}
b = append(b, '\'')
return b
}
func appendIntSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
ints := v.Convert(sliceIntType).Interface().([]int)
return appendIntSlice(b, ints)
}
func appendIntSlice(b []byte, ints []int) []byte {
if ints == nil {
return dialect.AppendNull(b)
}
b = append(b, '\'')
b = append(b, '{')
for _, n := range ints {
b = strconv.AppendInt(b, int64(n), 10)
b = append(b, ',')
}
if len(ints) > 0 {
b[len(b)-1] = '}' // Replace trailing comma.
} else {
b = append(b, '}')
}
b = append(b, '\'')
return b
}
func appendInt64SliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
ints := v.Convert(sliceInt64Type).Interface().([]int64)
return appendInt64Slice(b, ints)
}
func appendInt64Slice(b []byte, ints []int64) []byte {
if ints == nil {
return dialect.AppendNull(b)
}
b = append(b, '\'')
b = append(b, '{')
for _, n := range ints {
b = strconv.AppendInt(b, n, 10)
b = append(b, ',')
}
if len(ints) > 0 {
b[len(b)-1] = '}' // Replace trailing comma.
} else {
b = append(b, '}')
}
b = append(b, '\'')
return b
}
func appendFloat64SliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
floats := v.Convert(sliceFloat64Type).Interface().([]float64)
return appendFloat64Slice(b, floats)
}
func appendFloat64Slice(b []byte, floats []float64) []byte {
if floats == nil {
return dialect.AppendNull(b)
}
b = append(b, '\'')
b = append(b, '{')
for _, n := range floats {
b = dialect.AppendFloat64(b, n)
b = append(b, ',')
}
if len(floats) > 0 {
b[len(b)-1] = '}' // Replace trailing comma.
} else {
b = append(b, '}')
}
b = append(b, '\'')
return b
}
func appendTimeSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
ts := v.Convert(sliceTimeType).Interface().([]time.Time)
return appendTimeSlice(fmter, b, ts)
}
func appendTimeSlice(fmter schema.Formatter, b []byte, ts []time.Time) []byte {
if ts == nil {
return dialect.AppendNull(b)
}
b = append(b, '\'')
b = append(b, '{')
for _, t := range ts {
b = append(b, '"')
b = appendTime(b, t)
b = append(b, '"')
b = append(b, ',')
}
if len(ts) > 0 {
b[len(b)-1] = '}' // Replace trailing comma.
} else {
b = append(b, '}')
}
b = append(b, '\'')
return b
}
//------------------------------------------------------------------------------
func arrayScanner(typ reflect.Type) schema.ScannerFunc {
kind := typ.Kind()
switch kind {
case reflect.Ptr:
if fn := arrayScanner(typ.Elem()); fn != nil {
return schema.PtrScanner(fn)
}
case reflect.Slice, reflect.Array:
// ok:
default:
return nil
}
elemType := typ.Elem()
if kind == reflect.Slice {
switch elemType {
case stringType:
return scanStringSliceValue
case intType:
return scanIntSliceValue
case int64Type:
return scanInt64SliceValue
case float64Type:
return scanFloat64SliceValue
}
}
scanElem := schema.Scanner(elemType)
return func(dest reflect.Value, src interface{}) error {
dest = reflect.Indirect(dest)
if !dest.CanSet() {
return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type())
}
kind := dest.Kind()
if src == nil {
if kind != reflect.Slice || !dest.IsNil() {
dest.Set(reflect.Zero(dest.Type()))
}
return nil
}
if kind == reflect.Slice {
if dest.IsNil() {
dest.Set(reflect.MakeSlice(dest.Type(), 0, 0))
} else if dest.Len() > 0 {
dest.Set(dest.Slice(0, 0))
}
}
b, err := toBytes(src)
if err != nil {
return err
}
p := newArrayParser(b)
nextValue := internal.MakeSliceNextElemFunc(dest)
for p.Next() {
elem := p.Elem()
elemValue := nextValue()
if err := scanElem(elemValue, elem); err != nil {
return fmt.Errorf("scanElem failed: %w", err)
}
}
return p.Err()
}
}
func scanStringSliceValue(dest reflect.Value, src interface{}) error {
dest = reflect.Indirect(dest)
if !dest.CanSet() {
return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type())
}
slice, err := decodeStringSlice(src)
if err != nil {
return err
}
dest.Set(reflect.ValueOf(slice))
return nil
}
func decodeStringSlice(src interface{}) ([]string, error) {
if src == nil {
return nil, nil
}
b, err := toBytes(src)
if err != nil {
return nil, err
}
slice := make([]string, 0)
p := newArrayParser(b)
for p.Next() {
elem := p.Elem()
slice = append(slice, string(elem))
}
if err := p.Err(); err != nil {
return nil, err
}
return slice, nil
}
func scanIntSliceValue(dest reflect.Value, src interface{}) error {
dest = reflect.Indirect(dest)
if !dest.CanSet() {
return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type())
}
slice, err := decodeIntSlice(src)
if err != nil {
return err
}
dest.Set(reflect.ValueOf(slice))
return nil
}
func decodeIntSlice(src interface{}) ([]int, error) {
if src == nil {
return nil, nil
}
b, err := toBytes(src)
if err != nil {
return nil, err
}
slice := make([]int, 0)
p := newArrayParser(b)
for p.Next() {
elem := p.Elem()
if elem == nil {
slice = append(slice, 0)
continue
}
n, err := strconv.Atoi(bytesToString(elem))
if err != nil {
return nil, err
}
slice = append(slice, n)
}
if err := p.Err(); err != nil {
return nil, err
}
return slice, nil
}
func scanInt64SliceValue(dest reflect.Value, src interface{}) error {
dest = reflect.Indirect(dest)
if !dest.CanSet() {
return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type())
}
slice, err := decodeInt64Slice(src)
if err != nil {
return err
}
dest.Set(reflect.ValueOf(slice))
return nil
}
func decodeInt64Slice(src interface{}) ([]int64, error) {
if src == nil {
return nil, nil
}
b, err := toBytes(src)
if err != nil {
return nil, err
}
slice := make([]int64, 0)
p := newArrayParser(b)
for p.Next() {
elem := p.Elem()
if elem == nil {
slice = append(slice, 0)
continue
}
n, err := strconv.ParseInt(bytesToString(elem), 10, 64)
if err != nil {
return nil, err
}
slice = append(slice, n)
}
if err := p.Err(); err != nil {
return nil, err
}
return slice, nil
}
func scanFloat64SliceValue(dest reflect.Value, src interface{}) error {
dest = reflect.Indirect(dest)
if !dest.CanSet() {
return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type())
}
slice, err := scanFloat64Slice(src)
if err != nil {
return err
}
dest.Set(reflect.ValueOf(slice))
return nil
}
func scanFloat64Slice(src interface{}) ([]float64, error) {
if src == -1 {
return nil, nil
}
b, err := toBytes(src)
if err != nil {
return nil, err
}
slice := make([]float64, 0)
p := newArrayParser(b)
for p.Next() {
elem := p.Elem()
if elem == nil {
slice = append(slice, 0)
continue
}
n, err := strconv.ParseFloat(bytesToString(elem), 64)
if err != nil {
return nil, err
}
slice = append(slice, n)
}
if err := p.Err(); err != nil {
return nil, err
}
return slice, nil
}
func toBytes(src interface{}) ([]byte, error) {
switch src := src.(type) {
case string:
return stringToBytes(src), nil
case []byte:
return src, nil
default:
return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src)
}
}
//------------------------------------------------------------------------------
func arrayAppendBytes(b []byte, bs []byte) []byte {
if bs == nil {
return dialect.AppendNull(b)
}
b = append(b, `"\\x`...)
s := len(b)
b = append(b, make([]byte, hex.EncodedLen(len(bs)))...)
hex.Encode(b[s:], bs)
b = append(b, '"')
return b
}
func arrayAppendString(b []byte, s string) []byte {
b = append(b, '"')
for _, r := range s {
switch r {
case 0:
// ignore
case '\'':
b = append(b, "''"...)
case '"':
b = append(b, '\\', '"')
case '\\':
b = append(b, '\\', '\\')
default:
if r < utf8.RuneSelf {
b = append(b, byte(r))
break
}
l := len(b)
if cap(b)-l < utf8.UTFMax {
b = append(b, make([]byte, utf8.UTFMax)...)
}
n := utf8.EncodeRune(b[l:l+utf8.UTFMax], r)
b = b[:l+n]
}
}
b = append(b, '"')
return b
}

View file

@ -2,92 +2,132 @@
import ( import (
"bytes" "bytes"
"encoding/hex"
"fmt" "fmt"
"io" "io"
) )
type arrayParser struct { type arrayParser struct {
p pgparser *streamParser
err error
elem []byte
err error
} }
func newArrayParser(b []byte) *arrayParser { func newArrayParser(b []byte) *arrayParser {
p := new(arrayParser) p := &arrayParser{
streamParser: newStreamParser(b, 1),
if len(b) < 2 || b[0] != '{' || b[len(b)-1] != '}' { }
p.err = fmt.Errorf("pgdialect: can't parse array: %q", b) if len(b) < 2 || b[0] != '{' || b[len(b)-1] != '}' {
return p p.err = fmt.Errorf("bun: can't parse array: %q", b)
} }
p.p.Reset(b[1 : len(b)-1])
return p return p
} }
func (p *arrayParser) Next() bool { func (p *arrayParser) NextElem() ([]byte, error) {
if p.err != nil { if p.err != nil {
return false return nil, p.err
}
p.err = p.readNext()
return p.err == nil
}
func (p *arrayParser) Err() error {
if p.err != io.EOF {
return p.err
}
return nil
}
func (p *arrayParser) Elem() []byte {
return p.elem
}
func (p *arrayParser) readNext() error {
ch := p.p.Read()
if ch == 0 {
return io.EOF
} }
switch ch { c, err := p.readByte()
if err != nil {
return nil, err
}
switch c {
case '}': case '}':
return io.EOF return nil, io.EOF
case '"': case '"':
b, err := p.p.ReadSubstring(ch) b, err := p.readSubstring()
if err != nil { if err != nil {
return err return nil, err
} }
if p.p.Peek() == ',' { if p.peek() == ',' {
p.p.Advance() p.skipNext()
} }
p.elem = b return b, nil
return nil
case '[', '(':
rng, err := p.p.ReadRange(ch)
if err != nil {
return err
}
if p.p.Peek() == ',' {
p.p.Advance()
}
p.elem = rng
return nil
default: default:
lit := p.p.ReadLiteral(ch) b := p.readSimple()
if bytes.Equal(lit, []byte("NULL")) { if bytes.Equal(b, []byte("NULL")) {
lit = nil b = nil
} }
if p.p.Peek() == ',' { if p.peek() == ',' {
p.p.Advance() p.skipNext()
} }
p.elem = lit return b, nil
return nil
} }
} }
func (p *arrayParser) readSimple() []byte {
p.unreadByte()
if i := bytes.IndexByte(p.b[p.i:], ','); i >= 0 {
b := p.b[p.i : p.i+i]
p.i += i
return b
}
b := p.b[p.i : len(p.b)-1]
p.i = len(p.b) - 1
return b
}
func (p *arrayParser) readSubstring() ([]byte, error) {
c, err := p.readByte()
if err != nil {
return nil, err
}
p.buf = p.buf[:0]
for {
if c == '"' {
break
}
next, err := p.readByte()
if err != nil {
return nil, err
}
if c == '\\' {
switch next {
case '\\', '"':
p.buf = append(p.buf, next)
c, err = p.readByte()
if err != nil {
return nil, err
}
default:
p.buf = append(p.buf, '\\')
c = next
}
continue
}
if c == '\'' && next == '\'' {
p.buf = append(p.buf, next)
c, err = p.readByte()
if err != nil {
return nil, err
}
continue
}
p.buf = append(p.buf, c)
c = next
}
if bytes.HasPrefix(p.buf, []byte("\\x")) && len(p.buf)%2 == 0 {
data := p.buf[2:]
buf := make([]byte, hex.DecodedLen(len(data)))
n, err := hex.Decode(buf, data)
if err != nil {
return nil, err
}
return buf[:n], nil
}
return p.buf, nil
}

View file

@ -1 +1,302 @@
package pgdialect package pgdialect
import (
"fmt"
"io"
"reflect"
"strconv"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)
func arrayScanner(typ reflect.Type) schema.ScannerFunc {
kind := typ.Kind()
switch kind {
case reflect.Ptr:
if fn := arrayScanner(typ.Elem()); fn != nil {
return schema.PtrScanner(fn)
}
case reflect.Slice, reflect.Array:
// ok:
default:
return nil
}
elemType := typ.Elem()
if kind == reflect.Slice {
switch elemType {
case stringType:
return scanStringSliceValue
case intType:
return scanIntSliceValue
case int64Type:
return scanInt64SliceValue
case float64Type:
return scanFloat64SliceValue
}
}
scanElem := schema.Scanner(elemType)
return func(dest reflect.Value, src interface{}) error {
dest = reflect.Indirect(dest)
if !dest.CanSet() {
return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type())
}
kind := dest.Kind()
if src == nil {
if kind != reflect.Slice || !dest.IsNil() {
dest.Set(reflect.Zero(dest.Type()))
}
return nil
}
if kind == reflect.Slice {
if dest.IsNil() {
dest.Set(reflect.MakeSlice(dest.Type(), 0, 0))
} else if dest.Len() > 0 {
dest.Set(dest.Slice(0, 0))
}
}
b, err := toBytes(src)
if err != nil {
return err
}
p := newArrayParser(b)
nextValue := internal.MakeSliceNextElemFunc(dest)
for {
elem, err := p.NextElem()
if err != nil {
if err == io.EOF {
break
}
return err
}
elemValue := nextValue()
if err := scanElem(elemValue, elem); err != nil {
return err
}
}
return nil
}
}
func scanStringSliceValue(dest reflect.Value, src interface{}) error {
dest = reflect.Indirect(dest)
if !dest.CanSet() {
return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type())
}
slice, err := decodeStringSlice(src)
if err != nil {
return err
}
dest.Set(reflect.ValueOf(slice))
return nil
}
func decodeStringSlice(src interface{}) ([]string, error) {
if src == nil {
return nil, nil
}
b, err := toBytes(src)
if err != nil {
return nil, err
}
slice := make([]string, 0)
p := newArrayParser(b)
for {
elem, err := p.NextElem()
if err != nil {
if err == io.EOF {
break
}
return nil, err
}
slice = append(slice, string(elem))
}
return slice, nil
}
func scanIntSliceValue(dest reflect.Value, src interface{}) error {
dest = reflect.Indirect(dest)
if !dest.CanSet() {
return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type())
}
slice, err := decodeIntSlice(src)
if err != nil {
return err
}
dest.Set(reflect.ValueOf(slice))
return nil
}
func decodeIntSlice(src interface{}) ([]int, error) {
if src == nil {
return nil, nil
}
b, err := toBytes(src)
if err != nil {
return nil, err
}
slice := make([]int, 0)
p := newArrayParser(b)
for {
elem, err := p.NextElem()
if err != nil {
if err == io.EOF {
break
}
return nil, err
}
if elem == nil {
slice = append(slice, 0)
continue
}
n, err := strconv.Atoi(bytesToString(elem))
if err != nil {
return nil, err
}
slice = append(slice, n)
}
return slice, nil
}
func scanInt64SliceValue(dest reflect.Value, src interface{}) error {
dest = reflect.Indirect(dest)
if !dest.CanSet() {
return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type())
}
slice, err := decodeInt64Slice(src)
if err != nil {
return err
}
dest.Set(reflect.ValueOf(slice))
return nil
}
func decodeInt64Slice(src interface{}) ([]int64, error) {
if src == nil {
return nil, nil
}
b, err := toBytes(src)
if err != nil {
return nil, err
}
slice := make([]int64, 0)
p := newArrayParser(b)
for {
elem, err := p.NextElem()
if err != nil {
if err == io.EOF {
break
}
return nil, err
}
if elem == nil {
slice = append(slice, 0)
continue
}
n, err := strconv.ParseInt(bytesToString(elem), 10, 64)
if err != nil {
return nil, err
}
slice = append(slice, n)
}
return slice, nil
}
func scanFloat64SliceValue(dest reflect.Value, src interface{}) error {
dest = reflect.Indirect(dest)
if !dest.CanSet() {
return fmt.Errorf("bun: Scan(non-settable %s)", dest.Type())
}
slice, err := scanFloat64Slice(src)
if err != nil {
return err
}
dest.Set(reflect.ValueOf(slice))
return nil
}
func scanFloat64Slice(src interface{}) ([]float64, error) {
if src == -1 {
return nil, nil
}
b, err := toBytes(src)
if err != nil {
return nil, err
}
slice := make([]float64, 0)
p := newArrayParser(b)
for {
elem, err := p.NextElem()
if err != nil {
if err == io.EOF {
break
}
return nil, err
}
if elem == nil {
slice = append(slice, 0)
continue
}
n, err := strconv.ParseFloat(bytesToString(elem), 64)
if err != nil {
return nil, err
}
slice = append(slice, n)
}
return slice, nil
}
func toBytes(src interface{}) ([]byte, error) {
switch src := src.(type) {
case string:
return stringToBytes(src), nil
case []byte:
return src, nil
default:
return nil, fmt.Errorf("bun: got %T, wanted []byte or string", src)
}
}

View file

@ -89,17 +89,9 @@ func (d *Dialect) onField(field *schema.Field) {
if field.Tag.HasOption("array") || strings.HasSuffix(field.UserSQLType, "[]") { if field.Tag.HasOption("array") || strings.HasSuffix(field.UserSQLType, "[]") {
field.Append = d.arrayAppender(field.StructField.Type) field.Append = d.arrayAppender(field.StructField.Type)
field.Scan = arrayScanner(field.StructField.Type) field.Scan = arrayScanner(field.StructField.Type)
return
} }
if field.Tag.HasOption("multirange") { if field.DiscoveredSQLType == sqltype.HSTORE {
field.Append = d.arrayAppender(field.StructField.Type)
field.Scan = arrayScanner(field.StructField.Type)
return
}
switch field.DiscoveredSQLType {
case sqltype.HSTORE:
field.Append = d.hstoreAppender(field.StructField.Type) field.Append = d.hstoreAppender(field.StructField.Type)
field.Scan = hstoreScanner(field.StructField.Type) field.Scan = hstoreScanner(field.StructField.Type)
} }

View file

@ -3,98 +3,140 @@
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"io"
) )
type hstoreParser struct { type hstoreParser struct {
p pgparser *streamParser
err error
key string
value string
err error
} }
func newHStoreParser(b []byte) *hstoreParser { func newHStoreParser(b []byte) *hstoreParser {
p := new(hstoreParser) p := &hstoreParser{
if len(b) != 0 && (len(b) < 6 || b[0] != '"') { streamParser: newStreamParser(b, 0),
p.err = fmt.Errorf("pgdialect: can't parse hstore: %q", b) }
return p if len(b) < 6 || b[0] != '"' {
p.err = fmt.Errorf("bun: can't parse hstore: %q", b)
} }
p.p.Reset(b)
return p return p
} }
func (p *hstoreParser) Next() bool { func (p *hstoreParser) NextKey() (string, error) {
if p.err != nil { if p.err != nil {
return false return "", p.err
}
p.err = p.readNext()
return p.err == nil
}
func (p *hstoreParser) Err() error {
if p.err != io.EOF {
return p.err
}
return nil
}
func (p *hstoreParser) Key() string {
return p.key
}
func (p *hstoreParser) Value() string {
return p.value
}
func (p *hstoreParser) readNext() error {
if !p.p.Valid() {
return io.EOF
} }
if err := p.p.Skip('"'); err != nil { err := p.skipByte('"')
return err
}
key, err := p.p.ReadUnescapedSubstring('"')
if err != nil { if err != nil {
return err return "", err
}
p.key = string(key)
if err := p.p.SkipPrefix([]byte("=>")); err != nil {
return err
} }
ch, err := p.p.ReadByte() key, err := p.readSubstring()
if err != nil { if err != nil {
return err return "", err
} }
switch ch { const separator = "=>"
case '"':
value, err := p.p.ReadUnescapedSubstring(ch) for i := range separator {
err = p.skipByte(separator[i])
if err != nil { if err != nil {
return err return "", err
} }
p.skipComma() }
p.value = string(value)
return nil return string(key), nil
}
func (p *hstoreParser) NextValue() (string, error) {
if p.err != nil {
return "", p.err
}
c, err := p.readByte()
if err != nil {
return "", err
}
switch c {
case '"':
value, err := p.readSubstring()
if err != nil {
return "", err
}
if p.peek() == ',' {
p.skipNext()
}
if p.peek() == ' ' {
p.skipNext()
}
return string(value), nil
default: default:
value := p.p.ReadLiteral(ch) value := p.readSimple()
if bytes.Equal(value, []byte("NULL")) { if bytes.Equal(value, []byte("NULL")) {
p.value = "" value = nil
} }
p.skipComma()
return nil if p.peek() == ',' {
p.skipNext()
}
return string(value), nil
} }
} }
func (p *hstoreParser) skipComma() { func (p *hstoreParser) readSimple() []byte {
if p.p.Peek() == ',' { p.unreadByte()
p.p.Advance()
} if i := bytes.IndexByte(p.b[p.i:], ','); i >= 0 {
if p.p.Peek() == ' ' { b := p.b[p.i : p.i+i]
p.p.Advance() p.i += i
return b
} }
b := p.b[p.i:len(p.b)]
p.i = len(p.b)
return b
}
func (p *hstoreParser) readSubstring() ([]byte, error) {
c, err := p.readByte()
if err != nil {
return nil, err
}
p.buf = p.buf[:0]
for {
if c == '"' {
break
}
next, err := p.readByte()
if err != nil {
return nil, err
}
if c == '\\' {
switch next {
case '\\', '"':
p.buf = append(p.buf, next)
c, err = p.readByte()
if err != nil {
return nil, err
}
default:
p.buf = append(p.buf, '\\')
c = next
}
continue
}
p.buf = append(p.buf, c)
c = next
}
return p.buf, nil
} }

View file

@ -2,6 +2,7 @@
import ( import (
"fmt" "fmt"
"io"
"reflect" "reflect"
"github.com/uptrace/bun/schema" "github.com/uptrace/bun/schema"
@ -57,11 +58,25 @@ func decodeMapStringString(src interface{}) (map[string]string, error) {
m := make(map[string]string) m := make(map[string]string)
p := newHStoreParser(b) p := newHStoreParser(b)
for p.Next() { for {
m[p.Key()] = p.Value() key, err := p.NextKey()
} if err != nil {
if err := p.Err(); err != nil { if err == io.EOF {
return nil, err break
}
return nil, err
}
value, err := p.NextValue()
if err != nil {
if err == io.EOF {
break
}
return nil, err
}
m[key] = value
} }
return m, nil return m, nil
} }

View file

@ -1,240 +0,0 @@
package pgdialect
import (
"bytes"
"database/sql"
"encoding/hex"
"fmt"
"io"
"time"
"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/internal/parser"
"github.com/uptrace/bun/schema"
)
type MultiRange[T any] []Range[T]
type Range[T any] struct {
Lower, Upper T
LowerBound, UpperBound RangeBound
}
type RangeBound byte
const (
RangeBoundInclusiveLeft RangeBound = '['
RangeBoundInclusiveRight RangeBound = ']'
RangeBoundExclusiveLeft RangeBound = '('
RangeBoundExclusiveRight RangeBound = ')'
)
func NewRange[T any](lower, upper T) Range[T] {
return Range[T]{
Lower: lower,
Upper: upper,
LowerBound: RangeBoundInclusiveLeft,
UpperBound: RangeBoundExclusiveRight,
}
}
var _ sql.Scanner = (*Range[any])(nil)
func (r *Range[T]) Scan(anySrc any) (err error) {
src := anySrc.([]byte)
if len(src) == 0 {
return io.ErrUnexpectedEOF
}
r.LowerBound = RangeBound(src[0])
src = src[1:]
src, err = scanElem(&r.Lower, src)
if err != nil {
return err
}
if len(src) == 0 {
return io.ErrUnexpectedEOF
}
if ch := src[0]; ch != ',' {
return fmt.Errorf("got %q, wanted %q", ch, ',')
}
src = src[1:]
src, err = scanElem(&r.Upper, src)
if err != nil {
return err
}
if len(src) == 0 {
return io.ErrUnexpectedEOF
}
r.UpperBound = RangeBound(src[0])
src = src[1:]
if len(src) > 0 {
return fmt.Errorf("unread data: %q", src)
}
return nil
}
var _ schema.QueryAppender = (*Range[any])(nil)
func (r *Range[T]) AppendQuery(fmt schema.Formatter, buf []byte) ([]byte, error) {
buf = append(buf, byte(r.LowerBound))
buf = appendElem(buf, r.Lower)
buf = append(buf, ',')
buf = appendElem(buf, r.Upper)
buf = append(buf, byte(r.UpperBound))
return buf, nil
}
func appendElem(buf []byte, val any) []byte {
switch val := val.(type) {
case time.Time:
buf = append(buf, '"')
buf = appendTime(buf, val)
buf = append(buf, '"')
return buf
default:
panic(fmt.Errorf("unsupported range type: %T", val))
}
}
func scanElem(ptr any, src []byte) ([]byte, error) {
switch ptr := ptr.(type) {
case *time.Time:
src, str, err := readStringLiteral(src)
if err != nil {
return nil, err
}
tm, err := internal.ParseTime(internal.String(str))
if err != nil {
return nil, err
}
*ptr = tm
return src, nil
default:
panic(fmt.Errorf("unsupported range type: %T", ptr))
}
}
func readStringLiteral(src []byte) ([]byte, []byte, error) {
p := newParser(src)
if err := p.Skip('"'); err != nil {
return nil, nil, err
}
str, err := p.ReadSubstring('"')
if err != nil {
return nil, nil, err
}
src = p.Remaining()
return src, str, nil
}
//------------------------------------------------------------------------------
type pgparser struct {
parser.Parser
buf []byte
}
func newParser(b []byte) *pgparser {
p := new(pgparser)
p.Reset(b)
return p
}
func (p *pgparser) ReadLiteral(ch byte) []byte {
p.Unread()
lit, _ := p.ReadSep(',')
return lit
}
func (p *pgparser) ReadUnescapedSubstring(ch byte) ([]byte, error) {
return p.readSubstring(ch, false)
}
func (p *pgparser) ReadSubstring(ch byte) ([]byte, error) {
return p.readSubstring(ch, true)
}
func (p *pgparser) readSubstring(ch byte, escaped bool) ([]byte, error) {
ch, err := p.ReadByte()
if err != nil {
return nil, err
}
p.buf = p.buf[:0]
for {
if ch == '"' {
break
}
next, err := p.ReadByte()
if err != nil {
return nil, err
}
if ch == '\\' {
switch next {
case '\\', '"':
p.buf = append(p.buf, next)
ch, err = p.ReadByte()
if err != nil {
return nil, err
}
default:
p.buf = append(p.buf, '\\')
ch = next
}
continue
}
if escaped && ch == '\'' && next == '\'' {
p.buf = append(p.buf, next)
ch, err = p.ReadByte()
if err != nil {
return nil, err
}
continue
}
p.buf = append(p.buf, ch)
ch = next
}
if bytes.HasPrefix(p.buf, []byte("\\x")) && len(p.buf)%2 == 0 {
data := p.buf[2:]
buf := make([]byte, hex.DecodedLen(len(data)))
n, err := hex.Decode(buf, data)
if err != nil {
return nil, err
}
return buf[:n], nil
}
return p.buf, nil
}
func (p *pgparser) ReadRange(ch byte) ([]byte, error) {
p.buf = p.buf[:0]
p.buf = append(p.buf, ch)
for p.Valid() {
ch = p.Read()
p.buf = append(p.buf, ch)
if ch == ']' || ch == ')' {
break
}
}
return p.buf, nil
}

View file

@ -1,7 +1,6 @@
package pgdialect package pgdialect
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"net" "net"
"reflect" "reflect"
@ -28,6 +27,14 @@
pgTypeSerial = "SERIAL" // 4 byte autoincrementing integer pgTypeSerial = "SERIAL" // 4 byte autoincrementing integer
pgTypeBigSerial = "BIGSERIAL" // 8 byte autoincrementing integer pgTypeBigSerial = "BIGSERIAL" // 8 byte autoincrementing integer
// Character Types
pgTypeChar = "CHAR" // fixed length string (blank padded)
pgTypeText = "TEXT" // variable length string without limit
// JSON Types
pgTypeJSON = "JSON" // text representation of json data
pgTypeJSONB = "JSONB" // binary representation of json data
// Binary Data Types // Binary Data Types
pgTypeBytea = "BYTEA" // binary string pgTypeBytea = "BYTEA" // binary string
) )
@ -36,7 +43,6 @@
ipType = reflect.TypeOf((*net.IP)(nil)).Elem() ipType = reflect.TypeOf((*net.IP)(nil)).Elem()
ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem() ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem()
jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem() jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem()
nullStringType = reflect.TypeOf((*sql.NullString)(nil)).Elem()
) )
func (d *Dialect) DefaultVarcharLen() int { func (d *Dialect) DefaultVarcharLen() int {
@ -72,14 +78,12 @@ func fieldSQLType(field *schema.Field) string {
func sqlType(typ reflect.Type) string { func sqlType(typ reflect.Type) string {
switch typ { switch typ {
case nullStringType: // typ.Kind() == reflect.Struct, test for exact match
return sqltype.VarChar
case ipType: case ipType:
return pgTypeInet return pgTypeInet
case ipNetType: case ipNetType:
return pgTypeCidr return pgTypeCidr
case jsonRawMessageType: case jsonRawMessageType:
return sqltype.JSONB return pgTypeJSONB
} }
sqlType := schema.DiscoverSQLType(typ) sqlType := schema.DiscoverSQLType(typ)
@ -89,16 +93,16 @@ func sqlType(typ reflect.Type) string {
} }
switch typ.Kind() { switch typ.Kind() {
case reflect.Map, reflect.Struct: // except typ == nullStringType, see above case reflect.Map, reflect.Struct:
if sqlType == sqltype.VarChar { if sqlType == sqltype.VarChar {
return sqltype.JSONB return pgTypeJSONB
} }
return sqlType return sqlType
case reflect.Array, reflect.Slice: case reflect.Array, reflect.Slice:
if typ.Elem().Kind() == reflect.Uint8 { if typ.Elem().Kind() == reflect.Uint8 {
return pgTypeBytea return pgTypeBytea
} }
return sqltype.JSONB return pgTypeJSONB
} }
return sqlType return sqlType

View file

@ -0,0 +1,60 @@
package pgdialect
import (
"fmt"
"io"
)
type streamParser struct {
b []byte
i int
buf []byte
}
func newStreamParser(b []byte, start int) *streamParser {
return &streamParser{
b: b,
i: start,
}
}
func (p *streamParser) valid() bool {
return p.i < len(p.b)
}
func (p *streamParser) skipByte(skip byte) error {
c, err := p.readByte()
if err != nil {
return err
}
if c == skip {
return nil
}
p.unreadByte()
return fmt.Errorf("got %q, wanted %q", c, skip)
}
func (p *streamParser) readByte() (byte, error) {
if p.valid() {
c := p.b[p.i]
p.i++
return c, nil
}
return 0, io.EOF
}
func (p *streamParser) unreadByte() {
p.i--
}
func (p *streamParser) peek() byte {
if p.valid() {
return p.b[p.i]
}
return 0
}
func (p *streamParser) skipNext() {
p.i++
}

View file

@ -2,5 +2,5 @@
// Version is the current release version. // Version is the current release version.
func Version() string { func Version() string {
return "1.2.5" return "1.2.1"
} }

View file

@ -2,5 +2,5 @@
// Version is the current release version. // Version is the current release version.
func Version() string { func Version() string {
return "1.2.5" return "1.2.1"
} }

View file

@ -1,3 +1,3 @@
# OpenTelemetry instrumentation for Bun # OpenTelemetry instrumentation for Bun
See [example](../../example/opentelemetry) for details. See [example](../example/opentelemetry) for details.

View file

@ -1,4 +1,3 @@
//go:build !appengine
// +build !appengine // +build !appengine
package bunotel package bunotel
@ -6,15 +5,14 @@
import "unsafe" import "unsafe"
func bytesToString(b []byte) string { func bytesToString(b []byte) string {
if len(b) == 0 { return *(*string)(unsafe.Pointer(&b))
return ""
}
return unsafe.String(&b[0], len(b))
} }
func stringToBytes(s string) []byte { func stringToBytes(s string) []byte {
if s == "" { return *(*[]byte)(unsafe.Pointer(
return []byte{} &struct {
} string
return unsafe.Slice(unsafe.StringData(s), len(s)) Cap int
}{s, len(s)},
))
} }

View file

@ -2,8 +2,6 @@
import ( import (
"bytes" "bytes"
"fmt"
"io"
"strconv" "strconv"
"github.com/uptrace/bun/internal" "github.com/uptrace/bun/internal"
@ -24,43 +22,23 @@ func NewString(s string) *Parser {
return New(internal.Bytes(s)) return New(internal.Bytes(s))
} }
func (p *Parser) Reset(b []byte) {
p.b = b
p.i = 0
}
func (p *Parser) Valid() bool { func (p *Parser) Valid() bool {
return p.i < len(p.b) return p.i < len(p.b)
} }
func (p *Parser) Remaining() []byte { func (p *Parser) Bytes() []byte {
return p.b[p.i:] return p.b[p.i:]
} }
func (p *Parser) ReadByte() (byte, error) {
if p.Valid() {
ch := p.b[p.i]
p.Advance()
return ch, nil
}
return 0, io.ErrUnexpectedEOF
}
func (p *Parser) Read() byte { func (p *Parser) Read() byte {
if p.Valid() { if p.Valid() {
ch := p.b[p.i] c := p.b[p.i]
p.Advance() p.Advance()
return ch return c
} }
return 0 return 0
} }
func (p *Parser) Unread() {
if p.i > 0 {
p.i--
}
}
func (p *Parser) Peek() byte { func (p *Parser) Peek() byte {
if p.Valid() { if p.Valid() {
return p.b[p.i] return p.b[p.i]
@ -72,25 +50,19 @@ func (p *Parser) Advance() {
p.i++ p.i++
} }
func (p *Parser) Skip(skip byte) error { func (p *Parser) Skip(skip byte) bool {
ch := p.Peek() if p.Peek() == skip {
if ch == skip {
p.Advance() p.Advance()
return nil return true
} }
return fmt.Errorf("got %q, wanted %q", ch, skip) return false
} }
func (p *Parser) SkipPrefix(skip []byte) error { func (p *Parser) SkipBytes(skip []byte) bool {
if !bytes.HasPrefix(p.b[p.i:], skip) { if len(skip) > len(p.b[p.i:]) {
return fmt.Errorf("got %q, wanted prefix %q", p.b, skip) return false
} }
p.i += len(skip) if !bytes.Equal(p.b[p.i:p.i+len(skip)], skip) {
return nil
}
func (p *Parser) CutPrefix(skip []byte) bool {
if !bytes.HasPrefix(p.b[p.i:], skip) {
return false return false
} }
p.i += len(skip) p.i += len(skip)

View file

@ -1,4 +1,3 @@
//go:build !appengine
// +build !appengine // +build !appengine
package internal package internal
@ -7,16 +6,15 @@
// String converts byte slice to string. // String converts byte slice to string.
func String(b []byte) string { func String(b []byte) string {
if len(b) == 0 { return *(*string)(unsafe.Pointer(&b))
return ""
}
return unsafe.String(&b[0], len(b))
} }
// Bytes converts string to byte slice. // Bytes converts string to byte slice.
func Bytes(s string) []byte { func Bytes(s string) []byte {
if s == "" { return *(*[]byte)(unsafe.Pointer(
return []byte{} &struct {
} string
return unsafe.Slice(unsafe.StringData(s), len(s)) Cap int
}{s, len(s)},
))
} }

View file

@ -96,6 +96,10 @@ func (m *Migrations) Discover(fsys fs.FS) error {
} }
migration := m.getOrCreateMigration(name) migration := m.getOrCreateMigration(name)
if err != nil {
return err
}
migration.Comment = comment migration.Comment = comment
migrationFunc := NewSQLMigrationFunc(fsys, path) migrationFunc := NewSQLMigrationFunc(fsys, path)

View file

@ -362,10 +362,7 @@ func (m *Migrator) MarkUnapplied(ctx context.Context, migration *Migration) erro
} }
func (m *Migrator) TruncateTable(ctx context.Context) error { func (m *Migrator) TruncateTable(ctx context.Context) error {
_, err := m.db.NewTruncateTable(). _, err := m.db.NewTruncateTable().TableExpr(m.table).Exec(ctx)
Model((*Migration)(nil)).
ModelTableExpr(m.table).
Exec(ctx)
return err return err
} }

View file

@ -1,7 +1,6 @@
package bun package bun
import ( import (
"bytes"
"context" "context"
"database/sql" "database/sql"
"reflect" "reflect"
@ -83,8 +82,6 @@ func (m *mapModel) Scan(src interface{}) error {
return m.scanRaw(src) return m.scanRaw(src)
case reflect.Slice: case reflect.Slice:
if scanType.Elem().Kind() == reflect.Uint8 { if scanType.Elem().Kind() == reflect.Uint8 {
// Reference types such as []byte are only valid until the next call to Scan.
src := bytes.Clone(src.([]byte))
return m.scanRaw(src) return m.scanRaw(src)
} }
} }

View file

@ -24,7 +24,7 @@ type hasManyModel struct {
func newHasManyModel(j *relationJoin) *hasManyModel { func newHasManyModel(j *relationJoin) *hasManyModel {
baseTable := j.BaseModel.Table() baseTable := j.BaseModel.Table()
joinModel := j.JoinModel.(*sliceTableModel) joinModel := j.JoinModel.(*sliceTableModel)
baseValues := baseValues(joinModel, j.Relation.BasePKs) baseValues := baseValues(joinModel, j.Relation.BaseFields)
if len(baseValues) == 0 { if len(baseValues) == 0 {
return nil return nil
} }
@ -92,9 +92,9 @@ func (m *hasManyModel) Scan(src interface{}) error {
return err return err
} }
for _, f := range m.rel.JoinPKs { for _, f := range m.rel.JoinFields {
if f.Name == field.Name { if f.Name == field.Name {
m.structKey = append(m.structKey, indirectFieldValue(field.Value(m.strct))) m.structKey = append(m.structKey, field.Value(m.strct).Interface())
break break
} }
} }
@ -103,7 +103,6 @@ func (m *hasManyModel) Scan(src interface{}) error {
} }
func (m *hasManyModel) parkStruct() error { func (m *hasManyModel) parkStruct() error {
baseValues, ok := m.baseValues[internal.NewMapKey(m.structKey)] baseValues, ok := m.baseValues[internal.NewMapKey(m.structKey)]
if !ok { if !ok {
return fmt.Errorf( return fmt.Errorf(
@ -144,19 +143,7 @@ func baseValues(model TableModel, fields []*schema.Field) map[internal.MapKey][]
func modelKey(key []interface{}, strct reflect.Value, fields []*schema.Field) []interface{} { func modelKey(key []interface{}, strct reflect.Value, fields []*schema.Field) []interface{} {
for _, f := range fields { for _, f := range fields {
key = append(key, indirectFieldValue(f.Value(strct))) key = append(key, f.Value(strct).Interface())
} }
return key return key
} }
// indirectFieldValue return the field value dereferencing the pointer if necessary.
// The value is then used as a map key.
func indirectFieldValue(field reflect.Value) interface{} {
if field.Kind() != reflect.Ptr {
return field.Interface()
}
if field.IsNil() {
return nil
}
return field.Elem().Interface()
}

View file

@ -24,7 +24,7 @@ type m2mModel struct {
func newM2MModel(j *relationJoin) *m2mModel { func newM2MModel(j *relationJoin) *m2mModel {
baseTable := j.BaseModel.Table() baseTable := j.BaseModel.Table()
joinModel := j.JoinModel.(*sliceTableModel) joinModel := j.JoinModel.(*sliceTableModel)
baseValues := baseValues(joinModel, j.Relation.BasePKs) baseValues := baseValues(joinModel, baseTable.PKs)
if len(baseValues) == 0 { if len(baseValues) == 0 {
return nil return nil
} }
@ -83,21 +83,27 @@ func (m *m2mModel) Scan(src interface{}) error {
column := m.columns[m.scanIndex] column := m.columns[m.scanIndex]
m.scanIndex++ m.scanIndex++
// Base pks must come first. field, ok := m.table.FieldMap[column]
if m.scanIndex <= len(m.rel.M2MBasePKs) { if !ok {
return m.scanM2MColumn(column, src) return m.scanM2MColumn(column, src)
} }
if field, ok := m.table.FieldMap[column]; ok { if err := field.ScanValue(m.strct, src); err != nil {
return field.ScanValue(m.strct, src) return err
} }
_, err := m.scanColumn(column, src) for _, fk := range m.rel.M2MBaseFields {
return err if fk.Name == field.Name {
m.structKey = append(m.structKey, field.Value(m.strct).Interface())
break
}
}
return nil
} }
func (m *m2mModel) scanM2MColumn(column string, src interface{}) error { func (m *m2mModel) scanM2MColumn(column string, src interface{}) error {
for _, field := range m.rel.M2MBasePKs { for _, field := range m.rel.M2MBaseFields {
if field.Name == column { if field.Name == column {
dest := reflect.New(field.IndirectType).Elem() dest := reflect.New(field.IndirectType).Elem()
if err := field.Scan(dest, src); err != nil { if err := field.Scan(dest, src); err != nil {

View file

@ -242,7 +242,7 @@ func (m *structTableModel) ScanRows(ctx context.Context, rows *sql.Rows) (int, e
n++ n++
// And discard the rest. This is especially important for SQLite3, which can return // And discard the rest. This is especially important for SQLite3, which can return
// a row like it was inserted successfully and then return an actual error for the next row. // a row like it was inserted sucessfully and then return an actual error for the next row.
// See issues/100. // See issues/100.
for rows.Next() { for rows.Next() {
n++ n++

View file

@ -1,6 +1,6 @@
{ {
"name": "gobun", "name": "gobun",
"version": "1.2.5", "version": "1.2.1",
"main": "index.js", "main": "index.js",
"repository": "git@github.com:uptrace/bun.git", "repository": "git@github.com:uptrace/bun.git",
"author": "Vladimir Mihailenco <vladimir.webdev@gmail.com>", "author": "Vladimir Mihailenco <vladimir.webdev@gmail.com>",

View file

@ -8,7 +8,6 @@
"fmt" "fmt"
"time" "time"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/dialect/feature" "github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/internal" "github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema" "github.com/uptrace/bun/schema"
@ -419,11 +418,7 @@ func (q *baseQuery) _appendTables(
} else { } else {
b = fmter.AppendQuery(b, string(q.table.SQLNameForSelects)) b = fmter.AppendQuery(b, string(q.table.SQLNameForSelects))
if withAlias && q.table.SQLAlias != q.table.SQLNameForSelects { if withAlias && q.table.SQLAlias != q.table.SQLNameForSelects {
if q.db.dialect.Name() == dialect.Oracle { b = append(b, " AS "...)
b = append(b, ' ')
} else {
b = append(b, " AS "...)
}
b = append(b, q.table.SQLAlias...) b = append(b, q.table.SQLAlias...)
} }
} }

View file

@ -538,11 +538,6 @@ func (q *SelectQuery) appendQuery(
if count && !cteCount { if count && !cteCount {
b = append(b, "count(*)"...) b = append(b, "count(*)"...)
} else { } else {
// MSSQL: allows Limit() without Order() as per https://stackoverflow.com/a/36156953
if q.limit > 0 && len(q.order) == 0 && fmter.Dialect().Name() == dialect.MSSQL {
b = append(b, "0 AS _temp_sort, "...)
}
b, err = q.appendColumns(fmter, b) b, err = q.appendColumns(fmter, b)
if err != nil { if err != nil {
return nil, err return nil, err
@ -569,8 +564,8 @@ func (q *SelectQuery) appendQuery(
return nil, err return nil, err
} }
for _, join := range q.joins { for _, j := range q.joins {
b, err = join.AppendQuery(fmter, b) b, err = j.AppendQuery(fmter, b)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -798,12 +793,6 @@ func (q *SelectQuery) appendOrder(fmter schema.Formatter, b []byte) (_ []byte, e
return b, nil return b, nil
} }
// MSSQL: allows Limit() without Order() as per https://stackoverflow.com/a/36156953
if q.limit > 0 && fmter.Dialect().Name() == dialect.MSSQL {
return append(b, " ORDER BY _temp_sort"...), nil
}
return b, nil return b, nil
} }
@ -867,57 +856,52 @@ func (q *SelectQuery) Exec(ctx context.Context, dest ...interface{}) (res sql.Re
} }
func (q *SelectQuery) Scan(ctx context.Context, dest ...interface{}) error { func (q *SelectQuery) Scan(ctx context.Context, dest ...interface{}) error {
_, err := q.scanResult(ctx, dest...)
return err
}
func (q *SelectQuery) scanResult(ctx context.Context, dest ...interface{}) (sql.Result, error) {
if q.err != nil { if q.err != nil {
return nil, q.err return q.err
} }
model, err := q.getModel(dest) model, err := q.getModel(dest)
if err != nil { if err != nil {
return nil, err return err
} }
if q.table != nil { if q.table != nil {
if err := q.beforeSelectHook(ctx); err != nil { if err := q.beforeSelectHook(ctx); err != nil {
return nil, err return err
} }
} }
if err := q.beforeAppendModel(ctx, q); err != nil { if err := q.beforeAppendModel(ctx, q); err != nil {
return nil, err return err
} }
queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes()) queryBytes, err := q.AppendQuery(q.db.fmter, q.db.makeQueryBytes())
if err != nil { if err != nil {
return nil, err return err
} }
query := internal.String(queryBytes) query := internal.String(queryBytes)
res, err := q.scan(ctx, q, query, model, true) res, err := q.scan(ctx, q, query, model, true)
if err != nil { if err != nil {
return nil, err return err
} }
if n, _ := res.RowsAffected(); n > 0 { if n, _ := res.RowsAffected(); n > 0 {
if tableModel, ok := model.(TableModel); ok { if tableModel, ok := model.(TableModel); ok {
if err := q.selectJoins(ctx, tableModel.getJoins()); err != nil { if err := q.selectJoins(ctx, tableModel.getJoins()); err != nil {
return nil, err return err
} }
} }
} }
if q.table != nil { if q.table != nil {
if err := q.afterSelectHook(ctx); err != nil { if err := q.afterSelectHook(ctx); err != nil {
return nil, err return err
} }
} }
return res, nil return nil
} }
func (q *SelectQuery) beforeSelectHook(ctx context.Context) error { func (q *SelectQuery) beforeSelectHook(ctx context.Context) error {
@ -962,16 +946,6 @@ func (q *SelectQuery) Count(ctx context.Context) (int, error) {
} }
func (q *SelectQuery) ScanAndCount(ctx context.Context, dest ...interface{}) (int, error) { func (q *SelectQuery) ScanAndCount(ctx context.Context, dest ...interface{}) (int, error) {
if q.offset == 0 && q.limit == 0 {
// If there is no limit and offset, we can use a single query to get the count and scan
if res, err := q.scanResult(ctx, dest...); err != nil {
return 0, err
} else if n, err := res.RowsAffected(); err != nil {
return 0, err
} else {
return int(n), nil
}
}
if _, ok := q.conn.(*DB); ok { if _, ok := q.conn.(*DB); ok {
return q.scanAndCountConc(ctx, dest...) return q.scanAndCountConc(ctx, dest...)
} }

View file

@ -9,7 +9,6 @@
"strconv" "strconv"
"strings" "strings"
"github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/dialect/feature" "github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/dialect/sqltype" "github.com/uptrace/bun/dialect/sqltype"
"github.com/uptrace/bun/internal" "github.com/uptrace/bun/internal"
@ -166,7 +165,7 @@ func (q *CreateTableQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []by
b = append(b, field.SQLName...) b = append(b, field.SQLName...)
b = append(b, " "...) b = append(b, " "...)
b = q.appendSQLType(b, field) b = q.appendSQLType(b, field)
if field.NotNull && q.db.dialect.Name() != dialect.Oracle { if field.NotNull {
b = append(b, " NOT NULL"...) b = append(b, " NOT NULL"...)
} }
@ -247,11 +246,7 @@ func (q *CreateTableQuery) appendSQLType(b []byte, field *schema.Field) []byte {
return append(b, field.CreateTableSQLType...) return append(b, field.CreateTableSQLType...)
} }
if q.db.dialect.Name() == dialect.Oracle { b = append(b, sqltype.VarChar...)
b = append(b, "VARCHAR2"...)
} else {
b = append(b, sqltype.VarChar...)
}
b = append(b, "("...) b = append(b, "("...)
b = strconv.AppendInt(b, int64(q.varchar), 10) b = strconv.AppendInt(b, int64(q.varchar), 10)
b = append(b, ")"...) b = append(b, ")"...)
@ -302,9 +297,9 @@ func (q *CreateTableQuery) appendFKConstraintsRel(fmter schema.Formatter, b []by
b, err = q.appendFK(fmter, b, schema.QueryWithArgs{ b, err = q.appendFK(fmter, b, schema.QueryWithArgs{
Query: "(?) REFERENCES ? (?) ? ?", Query: "(?) REFERENCES ? (?) ? ?",
Args: []interface{}{ Args: []interface{}{
Safe(appendColumns(nil, "", rel.BasePKs)), Safe(appendColumns(nil, "", rel.BaseFields)),
rel.JoinTable.SQLName, rel.JoinTable.SQLName,
Safe(appendColumns(nil, "", rel.JoinPKs)), Safe(appendColumns(nil, "", rel.JoinFields)),
Safe(rel.OnUpdate), Safe(rel.OnUpdate),
Safe(rel.OnDelete), Safe(rel.OnDelete),
}, },

View file

@ -57,11 +57,6 @@ func (q *TruncateTableQuery) TableExpr(query string, args ...interface{}) *Trunc
return q return q
} }
func (q *TruncateTableQuery) ModelTableExpr(query string, args ...interface{}) *TruncateTableQuery {
q.modelTableName = schema.SafeQuery(query, args)
return q
}
//------------------------------------------------------------------------------ //------------------------------------------------------------------------------
func (q *TruncateTableQuery) ContinueIdentity() *TruncateTableQuery { func (q *TruncateTableQuery) ContinueIdentity() *TruncateTableQuery {

View file

@ -70,11 +70,11 @@ func (j *relationJoin) manyQuery(q *SelectQuery) *SelectQuery {
} }
func (j *relationJoin) manyQueryCompositeIn(where []byte, q *SelectQuery) *SelectQuery { func (j *relationJoin) manyQueryCompositeIn(where []byte, q *SelectQuery) *SelectQuery {
if len(j.Relation.JoinPKs) > 1 { if len(j.Relation.JoinFields) > 1 {
where = append(where, '(') where = append(where, '(')
} }
where = appendColumns(where, j.JoinModel.Table().SQLAlias, j.Relation.JoinPKs) where = appendColumns(where, j.JoinModel.Table().SQLAlias, j.Relation.JoinFields)
if len(j.Relation.JoinPKs) > 1 { if len(j.Relation.JoinFields) > 1 {
where = append(where, ')') where = append(where, ')')
} }
where = append(where, " IN ("...) where = append(where, " IN ("...)
@ -83,7 +83,7 @@ func (j *relationJoin) manyQueryCompositeIn(where []byte, q *SelectQuery) *Selec
where, where,
j.JoinModel.rootValue(), j.JoinModel.rootValue(),
j.JoinModel.parentIndex(), j.JoinModel.parentIndex(),
j.Relation.BasePKs, j.Relation.BaseFields,
) )
where = append(where, ")"...) where = append(where, ")"...)
q = q.Where(internal.String(where)) q = q.Where(internal.String(where))
@ -104,8 +104,8 @@ func (j *relationJoin) manyQueryMulti(where []byte, q *SelectQuery) *SelectQuery
where, where,
j.JoinModel.rootValue(), j.JoinModel.rootValue(),
j.JoinModel.parentIndex(), j.JoinModel.parentIndex(),
j.Relation.BasePKs, j.Relation.BaseFields,
j.Relation.JoinPKs, j.Relation.JoinFields,
j.JoinModel.Table().SQLAlias, j.JoinModel.Table().SQLAlias,
) )
@ -175,10 +175,10 @@ func (j *relationJoin) m2mQuery(q *SelectQuery) *SelectQuery {
q = q.Model(m2mModel) q = q.Model(m2mModel)
index := j.JoinModel.parentIndex() index := j.JoinModel.parentIndex()
baseTable := j.BaseModel.Table()
if j.Relation.M2MTable != nil { if j.Relation.M2MTable != nil {
// We only need base pks to park joined models to the base model. fields := append(j.Relation.M2MBaseFields, j.Relation.M2MJoinFields...)
fields := j.Relation.M2MBasePKs
b := make([]byte, 0, len(fields)) b := make([]byte, 0, len(fields))
b = appendColumns(b, j.Relation.M2MTable.SQLAlias, fields) b = appendColumns(b, j.Relation.M2MTable.SQLAlias, fields)
@ -193,7 +193,7 @@ func (j *relationJoin) m2mQuery(q *SelectQuery) *SelectQuery {
join = append(join, " AS "...) join = append(join, " AS "...)
join = append(join, j.Relation.M2MTable.SQLAlias...) join = append(join, j.Relation.M2MTable.SQLAlias...)
join = append(join, " ON ("...) join = append(join, " ON ("...)
for i, col := range j.Relation.M2MBasePKs { for i, col := range j.Relation.M2MBaseFields {
if i > 0 { if i > 0 {
join = append(join, ", "...) join = append(join, ", "...)
} }
@ -202,13 +202,13 @@ func (j *relationJoin) m2mQuery(q *SelectQuery) *SelectQuery {
join = append(join, col.SQLName...) join = append(join, col.SQLName...)
} }
join = append(join, ") IN ("...) join = append(join, ") IN ("...)
join = appendChildValues(fmter, join, j.BaseModel.rootValue(), index, j.Relation.BasePKs) join = appendChildValues(fmter, join, j.BaseModel.rootValue(), index, baseTable.PKs)
join = append(join, ")"...) join = append(join, ")"...)
q = q.Join(internal.String(join)) q = q.Join(internal.String(join))
joinTable := j.JoinModel.Table() joinTable := j.JoinModel.Table()
for i, m2mJoinField := range j.Relation.M2MJoinPKs { for i, m2mJoinField := range j.Relation.M2MJoinFields {
joinField := j.Relation.JoinPKs[i] joinField := j.Relation.JoinFields[i]
q = q.Where("?.? = ?.?", q = q.Where("?.? = ?.?",
joinTable.SQLAlias, joinField.SQLName, joinTable.SQLAlias, joinField.SQLName,
j.Relation.M2MTable.SQLAlias, m2mJoinField.SQLName) j.Relation.M2MTable.SQLAlias, m2mJoinField.SQLName)
@ -310,13 +310,13 @@ func (j *relationJoin) appendHasOneJoin(
b = append(b, " ON "...) b = append(b, " ON "...)
b = append(b, '(') b = append(b, '(')
for i, baseField := range j.Relation.BasePKs { for i, baseField := range j.Relation.BaseFields {
if i > 0 { if i > 0 {
b = append(b, " AND "...) b = append(b, " AND "...)
} }
b = j.appendAlias(fmter, b) b = j.appendAlias(fmter, b)
b = append(b, '.') b = append(b, '.')
b = append(b, j.Relation.JoinPKs[i].SQLName...) b = append(b, j.Relation.JoinFields[i].SQLName...)
b = append(b, " = "...) b = append(b, " = "...)
b = j.appendBaseAlias(fmter, b) b = j.appendBaseAlias(fmter, b)
b = append(b, '.') b = append(b, '.')
@ -367,13 +367,13 @@ func appendChildValues(
} }
// appendMultiValues is an alternative to appendChildValues that doesn't use the sql keyword ID // appendMultiValues is an alternative to appendChildValues that doesn't use the sql keyword ID
// but instead uses old style ((k1=v1) AND (k2=v2)) OR (...) conditions. // but instead use a old style ((k1=v1) AND (k2=v2)) OR (...) of conditions.
func appendMultiValues( func appendMultiValues(
fmter schema.Formatter, b []byte, v reflect.Value, index []int, baseFields, joinFields []*schema.Field, joinTable schema.Safe, fmter schema.Formatter, b []byte, v reflect.Value, index []int, baseFields, joinFields []*schema.Field, joinTable schema.Safe,
) []byte { ) []byte {
// This is based on a mix of appendChildValues and query_base.appendColumns // This is based on a mix of appendChildValues and query_base.appendColumns
// These should never mismatch in length but nice to know if it does // These should never missmatch in length but nice to know if it does
if len(joinFields) != len(baseFields) { if len(joinFields) != len(baseFields) {
panic("not reached") panic("not reached")
} }

View file

@ -7,9 +7,9 @@
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/puzpuzpuz/xsync/v3"
"github.com/uptrace/bun/dialect" "github.com/uptrace/bun/dialect"
"github.com/uptrace/bun/dialect/sqltype" "github.com/uptrace/bun/dialect/sqltype"
"github.com/uptrace/bun/extra/bunjson" "github.com/uptrace/bun/extra/bunjson"
@ -51,7 +51,7 @@
reflect.UnsafePointer: nil, reflect.UnsafePointer: nil,
} }
var appenderCache = xsync.NewMapOf[reflect.Type, AppenderFunc]() var appenderMap sync.Map
func FieldAppender(dialect Dialect, field *Field) AppenderFunc { func FieldAppender(dialect Dialect, field *Field) AppenderFunc {
if field.Tag.HasOption("msgpack") { if field.Tag.HasOption("msgpack") {
@ -67,7 +67,7 @@ func FieldAppender(dialect Dialect, field *Field) AppenderFunc {
} }
if fieldType.Kind() != reflect.Ptr { if fieldType.Kind() != reflect.Ptr {
if reflect.PointerTo(fieldType).Implements(driverValuerType) { if reflect.PtrTo(fieldType).Implements(driverValuerType) {
return addrAppender(appendDriverValue) return addrAppender(appendDriverValue)
} }
} }
@ -79,14 +79,14 @@ func FieldAppender(dialect Dialect, field *Field) AppenderFunc {
} }
func Appender(dialect Dialect, typ reflect.Type) AppenderFunc { func Appender(dialect Dialect, typ reflect.Type) AppenderFunc {
if v, ok := appenderCache.Load(typ); ok { if v, ok := appenderMap.Load(typ); ok {
return v return v.(AppenderFunc)
} }
fn := appender(dialect, typ) fn := appender(dialect, typ)
if v, ok := appenderCache.LoadOrStore(typ, fn); ok { if v, ok := appenderMap.LoadOrStore(typ, fn); ok {
return v return v.(AppenderFunc)
} }
return fn return fn
} }
@ -99,10 +99,10 @@ func appender(dialect Dialect, typ reflect.Type) AppenderFunc {
return appendTimeValue return appendTimeValue
case timePtrType: case timePtrType:
return PtrAppender(appendTimeValue) return PtrAppender(appendTimeValue)
case ipType:
return appendIPValue
case ipNetType: case ipNetType:
return appendIPNetValue return appendIPNetValue
case ipType, netipPrefixType, netipAddrType:
return appendStringer
case jsonRawMessageType: case jsonRawMessageType:
return appendJSONRawMessageValue return appendJSONRawMessageValue
} }
@ -123,7 +123,7 @@ func appender(dialect Dialect, typ reflect.Type) AppenderFunc {
} }
if kind != reflect.Ptr { if kind != reflect.Ptr {
ptr := reflect.PointerTo(typ) ptr := reflect.PtrTo(typ)
if ptr.Implements(queryAppenderType) { if ptr.Implements(queryAppenderType) {
return addrAppender(appendQueryAppenderValue) return addrAppender(appendQueryAppenderValue)
} }
@ -247,15 +247,16 @@ func appendTimeValue(fmter Formatter, b []byte, v reflect.Value) []byte {
return fmter.Dialect().AppendTime(b, tm) return fmter.Dialect().AppendTime(b, tm)
} }
func appendIPValue(fmter Formatter, b []byte, v reflect.Value) []byte {
ip := v.Interface().(net.IP)
return fmter.Dialect().AppendString(b, ip.String())
}
func appendIPNetValue(fmter Formatter, b []byte, v reflect.Value) []byte { func appendIPNetValue(fmter Formatter, b []byte, v reflect.Value) []byte {
ipnet := v.Interface().(net.IPNet) ipnet := v.Interface().(net.IPNet)
return fmter.Dialect().AppendString(b, ipnet.String()) return fmter.Dialect().AppendString(b, ipnet.String())
} }
func appendStringer(fmter Formatter, b []byte, v reflect.Value) []byte {
return fmter.Dialect().AppendString(b, v.Interface().(fmt.Stringer).String())
}
func appendJSONRawMessageValue(fmter Formatter, b []byte, v reflect.Value) []byte { func appendJSONRawMessageValue(fmter Formatter, b []byte, v reflect.Value) []byte {
bytes := v.Bytes() bytes := v.Bytes()
if bytes == nil { if bytes == nil {

View file

@ -118,7 +118,7 @@ func (BaseDialect) AppendJSON(b, jsonb []byte) []byte {
case '\000': case '\000':
continue continue
case '\\': case '\\':
if p.CutPrefix([]byte("u0000")) { if p.SkipBytes([]byte("u0000")) {
b = append(b, `\\u0000`...) b = append(b, `\\u0000`...)
} else { } else {
b = append(b, '\\') b = append(b, '\\')

View file

@ -4,7 +4,6 @@
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
"net" "net"
"net/netip"
"reflect" "reflect"
"time" "time"
) )
@ -15,8 +14,6 @@
timeType = timePtrType.Elem() timeType = timePtrType.Elem()
ipType = reflect.TypeOf((*net.IP)(nil)).Elem() ipType = reflect.TypeOf((*net.IP)(nil)).Elem()
ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem() ipNetType = reflect.TypeOf((*net.IPNet)(nil)).Elem()
netipPrefixType = reflect.TypeOf((*netip.Prefix)(nil)).Elem()
netipAddrType = reflect.TypeOf((*netip.Addr)(nil)).Elem()
jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem() jsonRawMessageType = reflect.TypeOf((*json.RawMessage)(nil)).Elem()
driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()

View file

@ -13,25 +13,21 @@
) )
type Relation struct { type Relation struct {
// Base and Join can be explained with this query: Type int
// Field *Field
// SELECT * FROM base_table JOIN join_table JoinTable *Table
BaseFields []*Field
Type int JoinFields []*Field
Field *Field OnUpdate string
JoinTable *Table OnDelete string
BasePKs []*Field Condition []string
JoinPKs []*Field
OnUpdate string
OnDelete string
Condition []string
PolymorphicField *Field PolymorphicField *Field
PolymorphicValue string PolymorphicValue string
M2MTable *Table M2MTable *Table
M2MBasePKs []*Field M2MBaseFields []*Field
M2MJoinPKs []*Field M2MJoinFields []*Field
} }
// References returns true if the table to which the Relation belongs needs to declare a foreign key constraint to create the relation. // References returns true if the table to which the Relation belongs needs to declare a foreign key constraint to create the relation.

View file

@ -8,9 +8,9 @@
"reflect" "reflect"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/puzpuzpuz/xsync/v3"
"github.com/vmihailenco/msgpack/v5" "github.com/vmihailenco/msgpack/v5"
"github.com/uptrace/bun/dialect/sqltype" "github.com/uptrace/bun/dialect/sqltype"
@ -53,7 +53,7 @@ func init() {
} }
} }
var scannerCache = xsync.NewMapOf[reflect.Type, ScannerFunc]() var scannerMap sync.Map
func FieldScanner(dialect Dialect, field *Field) ScannerFunc { func FieldScanner(dialect Dialect, field *Field) ScannerFunc {
if field.Tag.HasOption("msgpack") { if field.Tag.HasOption("msgpack") {
@ -72,14 +72,14 @@ func FieldScanner(dialect Dialect, field *Field) ScannerFunc {
} }
func Scanner(typ reflect.Type) ScannerFunc { func Scanner(typ reflect.Type) ScannerFunc {
if v, ok := scannerCache.Load(typ); ok { if v, ok := scannerMap.Load(typ); ok {
return v return v.(ScannerFunc)
} }
fn := scanner(typ) fn := scanner(typ)
if v, ok := scannerCache.LoadOrStore(typ, fn); ok { if v, ok := scannerMap.LoadOrStore(typ, fn); ok {
return v return v.(ScannerFunc)
} }
return fn return fn
} }
@ -111,7 +111,7 @@ func scanner(typ reflect.Type) ScannerFunc {
} }
if kind != reflect.Ptr { if kind != reflect.Ptr {
ptr := reflect.PointerTo(typ) ptr := reflect.PtrTo(typ)
if ptr.Implements(scannerType) { if ptr.Implements(scannerType) {
return addrScanner(scanScanner) return addrScanner(scanScanner)
} }

View file

@ -74,7 +74,16 @@ type structField struct {
Table *Table Table *Table
} }
func (table *Table) init(dialect Dialect, typ reflect.Type, canAddr bool) { func newTable(
dialect Dialect, typ reflect.Type, seen map[reflect.Type]*Table, canAddr bool,
) *Table {
if table, ok := seen[typ]; ok {
return table
}
table := new(Table)
seen[typ] = table
table.dialect = dialect table.dialect = dialect
table.Type = typ table.Type = typ
table.ZeroValue = reflect.New(table.Type).Elem() table.ZeroValue = reflect.New(table.Type).Elem()
@ -88,7 +97,7 @@ func (table *Table) init(dialect Dialect, typ reflect.Type, canAddr bool) {
table.Fields = make([]*Field, 0, typ.NumField()) table.Fields = make([]*Field, 0, typ.NumField())
table.FieldMap = make(map[string]*Field, typ.NumField()) table.FieldMap = make(map[string]*Field, typ.NumField())
table.processFields(typ, canAddr) table.processFields(typ, seen, canAddr)
hooks := []struct { hooks := []struct {
typ reflect.Type typ reflect.Type
@ -100,15 +109,28 @@ func (table *Table) init(dialect Dialect, typ reflect.Type, canAddr bool) {
{afterScanRowHookType, afterScanRowHookFlag}, {afterScanRowHookType, afterScanRowHookFlag},
} }
typ = reflect.PointerTo(table.Type) typ = reflect.PtrTo(table.Type)
for _, hook := range hooks { for _, hook := range hooks {
if typ.Implements(hook.typ) { if typ.Implements(hook.typ) {
table.flags = table.flags.Set(hook.flag) table.flags = table.flags.Set(hook.flag)
} }
} }
return table
} }
func (t *Table) processFields(typ reflect.Type, canAddr bool) { func (t *Table) init() {
for _, field := range t.relFields {
t.processRelation(field)
}
t.relFields = nil
}
func (t *Table) processFields(
typ reflect.Type,
seen map[reflect.Type]*Table,
canAddr bool,
) {
type embeddedField struct { type embeddedField struct {
prefix string prefix string
index []int index []int
@ -150,7 +172,7 @@ type embeddedField struct {
continue continue
} }
subtable := t.dialect.Tables().InProgress(sfType) subtable := newTable(t.dialect, sfType, seen, canAddr)
for _, subfield := range subtable.allFields { for _, subfield := range subtable.allFields {
embedded = append(embedded, embeddedField{ embedded = append(embedded, embeddedField{
@ -184,7 +206,7 @@ type embeddedField struct {
t.TypeName, sf.Name, fieldType.Kind())) t.TypeName, sf.Name, fieldType.Kind()))
} }
subtable := t.dialect.Tables().InProgress(fieldType) subtable := newTable(t.dialect, fieldType, seen, canAddr)
for _, subfield := range subtable.allFields { for _, subfield := range subtable.allFields {
embedded = append(embedded, embeddedField{ embedded = append(embedded, embeddedField{
prefix: prefix, prefix: prefix,
@ -207,7 +229,7 @@ type embeddedField struct {
} }
t.StructMap[field.Name] = &structField{ t.StructMap[field.Name] = &structField{
Index: field.Index, Index: field.Index,
Table: t.dialect.Tables().InProgress(field.IndirectType), Table: newTable(t.dialect, field.IndirectType, seen, canAddr),
} }
} }
} }
@ -401,10 +423,6 @@ func (t *Table) newField(sf reflect.StructField, tag tagparser.Tag) *Field {
sqlName = tag.Name sqlName = tag.Name
} }
if s, ok := tag.Option("column"); ok {
sqlName = s
}
for name := range tag.Options { for name := range tag.Options {
if !isKnownFieldOption(name) { if !isKnownFieldOption(name) {
internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, sf.Name, name) internal.Warn.Printf("%s.%s has unknown tag option: %q", t.TypeName, sf.Name, name)
@ -472,13 +490,6 @@ func (t *Table) newField(sf reflect.StructField, tag tagparser.Tag) *Field {
//--------------------------------------------------------------------------------------- //---------------------------------------------------------------------------------------
func (t *Table) initRelations() {
for _, field := range t.relFields {
t.processRelation(field)
}
t.relFields = nil
}
func (t *Table) processRelation(field *Field) { func (t *Table) processRelation(field *Field) {
if rel, ok := field.Tag.Option("rel"); ok { if rel, ok := field.Tag.Option("rel"); ok {
t.initRelation(field, rel) t.initRelation(field, rel)
@ -566,7 +577,7 @@ func (t *Table) belongsToRelation(field *Field) *Relation {
joinColumn := joinColumns[i] joinColumn := joinColumns[i]
if f := t.FieldMap[baseColumn]; f != nil { if f := t.FieldMap[baseColumn]; f != nil {
rel.BasePKs = append(rel.BasePKs, f) rel.BaseFields = append(rel.BaseFields, f)
} else { } else {
panic(fmt.Errorf( panic(fmt.Errorf(
"bun: %s belongs-to %s: %s must have column %s", "bun: %s belongs-to %s: %s must have column %s",
@ -575,7 +586,7 @@ func (t *Table) belongsToRelation(field *Field) *Relation {
} }
if f := joinTable.FieldMap[joinColumn]; f != nil { if f := joinTable.FieldMap[joinColumn]; f != nil {
rel.JoinPKs = append(rel.JoinPKs, f) rel.JoinFields = append(rel.JoinFields, f)
} else { } else {
panic(fmt.Errorf( panic(fmt.Errorf(
"bun: %s belongs-to %s: %s must have column %s", "bun: %s belongs-to %s: %s must have column %s",
@ -586,17 +597,17 @@ func (t *Table) belongsToRelation(field *Field) *Relation {
return rel return rel
} }
rel.JoinPKs = joinTable.PKs rel.JoinFields = joinTable.PKs
fkPrefix := internal.Underscore(field.GoName) + "_" fkPrefix := internal.Underscore(field.GoName) + "_"
for _, joinPK := range joinTable.PKs { for _, joinPK := range joinTable.PKs {
fkName := fkPrefix + joinPK.Name fkName := fkPrefix + joinPK.Name
if fk := t.FieldMap[fkName]; fk != nil { if fk := t.FieldMap[fkName]; fk != nil {
rel.BasePKs = append(rel.BasePKs, fk) rel.BaseFields = append(rel.BaseFields, fk)
continue continue
} }
if fk := t.FieldMap[joinPK.Name]; fk != nil { if fk := t.FieldMap[joinPK.Name]; fk != nil {
rel.BasePKs = append(rel.BasePKs, fk) rel.BaseFields = append(rel.BaseFields, fk)
continue continue
} }
@ -629,7 +640,7 @@ func (t *Table) hasOneRelation(field *Field) *Relation {
baseColumns, joinColumns := parseRelationJoin(join) baseColumns, joinColumns := parseRelationJoin(join)
for i, baseColumn := range baseColumns { for i, baseColumn := range baseColumns {
if f := t.FieldMap[baseColumn]; f != nil { if f := t.FieldMap[baseColumn]; f != nil {
rel.BasePKs = append(rel.BasePKs, f) rel.BaseFields = append(rel.BaseFields, f)
} else { } else {
panic(fmt.Errorf( panic(fmt.Errorf(
"bun: %s has-one %s: %s must have column %s", "bun: %s has-one %s: %s must have column %s",
@ -639,7 +650,7 @@ func (t *Table) hasOneRelation(field *Field) *Relation {
joinColumn := joinColumns[i] joinColumn := joinColumns[i]
if f := joinTable.FieldMap[joinColumn]; f != nil { if f := joinTable.FieldMap[joinColumn]; f != nil {
rel.JoinPKs = append(rel.JoinPKs, f) rel.JoinFields = append(rel.JoinFields, f)
} else { } else {
panic(fmt.Errorf( panic(fmt.Errorf(
"bun: %s has-one %s: %s must have column %s", "bun: %s has-one %s: %s must have column %s",
@ -650,17 +661,17 @@ func (t *Table) hasOneRelation(field *Field) *Relation {
return rel return rel
} }
rel.BasePKs = t.PKs rel.BaseFields = t.PKs
fkPrefix := internal.Underscore(t.ModelName) + "_" fkPrefix := internal.Underscore(t.ModelName) + "_"
for _, pk := range t.PKs { for _, pk := range t.PKs {
fkName := fkPrefix + pk.Name fkName := fkPrefix + pk.Name
if f := joinTable.FieldMap[fkName]; f != nil { if f := joinTable.FieldMap[fkName]; f != nil {
rel.JoinPKs = append(rel.JoinPKs, f) rel.JoinFields = append(rel.JoinFields, f)
continue continue
} }
if f := joinTable.FieldMap[pk.Name]; f != nil { if f := joinTable.FieldMap[pk.Name]; f != nil {
rel.JoinPKs = append(rel.JoinPKs, f) rel.JoinFields = append(rel.JoinFields, f)
continue continue
} }
@ -709,7 +720,7 @@ func (t *Table) hasManyRelation(field *Field) *Relation {
} }
if f := t.FieldMap[baseColumn]; f != nil { if f := t.FieldMap[baseColumn]; f != nil {
rel.BasePKs = append(rel.BasePKs, f) rel.BaseFields = append(rel.BaseFields, f)
} else { } else {
panic(fmt.Errorf( panic(fmt.Errorf(
"bun: %s has-many %s: %s must have column %s", "bun: %s has-many %s: %s must have column %s",
@ -718,7 +729,7 @@ func (t *Table) hasManyRelation(field *Field) *Relation {
} }
if f := joinTable.FieldMap[joinColumn]; f != nil { if f := joinTable.FieldMap[joinColumn]; f != nil {
rel.JoinPKs = append(rel.JoinPKs, f) rel.JoinFields = append(rel.JoinFields, f)
} else { } else {
panic(fmt.Errorf( panic(fmt.Errorf(
"bun: %s has-many %s: %s must have column %s", "bun: %s has-many %s: %s must have column %s",
@ -727,7 +738,7 @@ func (t *Table) hasManyRelation(field *Field) *Relation {
} }
} }
} else { } else {
rel.BasePKs = t.PKs rel.BaseFields = t.PKs
fkPrefix := internal.Underscore(t.ModelName) + "_" fkPrefix := internal.Underscore(t.ModelName) + "_"
if isPolymorphic { if isPolymorphic {
polymorphicColumn = fkPrefix + "type" polymorphicColumn = fkPrefix + "type"
@ -736,12 +747,12 @@ func (t *Table) hasManyRelation(field *Field) *Relation {
for _, pk := range t.PKs { for _, pk := range t.PKs {
joinColumn := fkPrefix + pk.Name joinColumn := fkPrefix + pk.Name
if fk := joinTable.FieldMap[joinColumn]; fk != nil { if fk := joinTable.FieldMap[joinColumn]; fk != nil {
rel.JoinPKs = append(rel.JoinPKs, fk) rel.JoinFields = append(rel.JoinFields, fk)
continue continue
} }
if fk := joinTable.FieldMap[pk.Name]; fk != nil { if fk := joinTable.FieldMap[pk.Name]; fk != nil {
rel.JoinPKs = append(rel.JoinPKs, fk) rel.JoinFields = append(rel.JoinFields, fk)
continue continue
} }
@ -841,12 +852,12 @@ func (t *Table) m2mRelation(field *Field) *Relation {
} }
leftRel := m2mTable.belongsToRelation(leftField) leftRel := m2mTable.belongsToRelation(leftField)
rel.BasePKs = leftRel.JoinPKs rel.BaseFields = leftRel.JoinFields
rel.M2MBasePKs = leftRel.BasePKs rel.M2MBaseFields = leftRel.BaseFields
rightRel := m2mTable.belongsToRelation(rightField) rightRel := m2mTable.belongsToRelation(rightField)
rel.JoinPKs = rightRel.JoinPKs rel.JoinFields = rightRel.JoinFields
rel.M2MJoinPKs = rightRel.BasePKs rel.M2MJoinFields = rightRel.BaseFields
return rel return rel
} }
@ -907,7 +918,6 @@ func isKnownFieldOption(name string) bool {
"array", "array",
"hstore", "hstore",
"composite", "composite",
"multirange",
"json_use_number", "json_use_number",
"msgpack", "msgpack",
"notnull", "notnull",

View file

@ -4,24 +4,22 @@
"fmt" "fmt"
"reflect" "reflect"
"sync" "sync"
"github.com/puzpuzpuz/xsync/v3"
) )
type Tables struct { type Tables struct {
dialect Dialect dialect Dialect
tables sync.Map
mu sync.Mutex mu sync.RWMutex
tables *xsync.MapOf[reflect.Type, *Table] seen map[reflect.Type]*Table
inProgress map[reflect.Type]*tableInProgress
inProgress map[reflect.Type]*Table
} }
func NewTables(dialect Dialect) *Tables { func NewTables(dialect Dialect) *Tables {
return &Tables{ return &Tables{
dialect: dialect, dialect: dialect,
tables: xsync.NewMapOf[reflect.Type, *Table](), seen: make(map[reflect.Type]*Table),
inProgress: make(map[reflect.Type]*Table), inProgress: make(map[reflect.Type]*tableInProgress),
} }
} }
@ -32,26 +30,58 @@ func (t *Tables) Register(models ...interface{}) {
} }
func (t *Tables) Get(typ reflect.Type) *Table { func (t *Tables) Get(typ reflect.Type) *Table {
return t.table(typ, false)
}
func (t *Tables) InProgress(typ reflect.Type) *Table {
return t.table(typ, true)
}
func (t *Tables) table(typ reflect.Type, allowInProgress bool) *Table {
typ = indirectType(typ) typ = indirectType(typ)
if typ.Kind() != reflect.Struct { if typ.Kind() != reflect.Struct {
panic(fmt.Errorf("got %s, wanted %s", typ.Kind(), reflect.Struct)) panic(fmt.Errorf("got %s, wanted %s", typ.Kind(), reflect.Struct))
} }
if v, ok := t.tables.Load(typ); ok { if v, ok := t.tables.Load(typ); ok {
return v return v.(*Table)
} }
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock()
if v, ok := t.tables.Load(typ); ok { if v, ok := t.tables.Load(typ); ok {
return v t.mu.Unlock()
return v.(*Table)
} }
table := t.InProgress(typ) var table *Table
table.initRelations()
inProgress := t.inProgress[typ]
if inProgress == nil {
table = newTable(t.dialect, typ, t.seen, false)
inProgress = newTableInProgress(table)
t.inProgress[typ] = inProgress
} else {
table = inProgress.table
}
t.mu.Unlock()
if allowInProgress {
return table
}
if !inProgress.init() {
return table
}
t.mu.Lock()
delete(t.inProgress, typ)
t.tables.Store(typ, table)
t.mu.Unlock()
t.dialect.OnTable(table) t.dialect.OnTable(table)
for _, field := range table.FieldMap { for _, field := range table.FieldMap {
if field.UserSQLType == "" { if field.UserSQLType == "" {
field.UserSQLType = field.DiscoveredSQLType field.UserSQLType = field.DiscoveredSQLType
@ -61,27 +91,15 @@ func (t *Tables) Get(typ reflect.Type) *Table {
} }
} }
t.tables.Store(typ, table)
return table
}
func (t *Tables) InProgress(typ reflect.Type) *Table {
if table, ok := t.inProgress[typ]; ok {
return table
}
table := new(Table)
t.inProgress[typ] = table
table.init(t.dialect, typ, false)
return table return table
} }
func (t *Tables) ByModel(name string) *Table { func (t *Tables) ByModel(name string) *Table {
var found *Table var found *Table
t.tables.Range(func(typ reflect.Type, table *Table) bool { t.tables.Range(func(key, value interface{}) bool {
if table.TypeName == name { t := value.(*Table)
found = table if t.TypeName == name {
found = t
return false return false
} }
return true return true
@ -91,12 +109,34 @@ func (t *Tables) ByModel(name string) *Table {
func (t *Tables) ByName(name string) *Table { func (t *Tables) ByName(name string) *Table {
var found *Table var found *Table
t.tables.Range(func(typ reflect.Type, table *Table) bool { t.tables.Range(func(key, value interface{}) bool {
if table.Name == name { t := value.(*Table)
found = table if t.Name == name {
found = t
return false return false
} }
return true return true
}) })
return found return found
} }
type tableInProgress struct {
table *Table
initOnce sync.Once
}
func newTableInProgress(table *Table) *tableInProgress {
return &tableInProgress{
table: table,
}
}
func (inp *tableInProgress) init() bool {
var inited bool
inp.initOnce.Do(func() {
inp.table.init()
inited = true
})
return inited
}

View file

@ -60,7 +60,7 @@ func zeroChecker(typ reflect.Type) IsZeroerFunc {
kind := typ.Kind() kind := typ.Kind()
if kind != reflect.Ptr { if kind != reflect.Ptr {
ptr := reflect.PointerTo(typ) ptr := reflect.PtrTo(typ)
if ptr.Implements(isZeroerType) { if ptr.Implements(isZeroerType) {
return addrChecker(isZeroInterface) return addrChecker(isZeroInterface)
} }

View file

@ -2,5 +2,5 @@
// Version is the current release version. // Version is the current release version.
func Version() string { func Version() string {
return "1.2.5" return "1.2.1"
} }

View file

@ -2,9 +2,8 @@
# database/sql instrumentation for OpenTelemetry Go # database/sql instrumentation for OpenTelemetry Go
[OpenTelemetry database/sql](https://uptrace.dev/get/instrument/opentelemetry-database-sql.html) [database/sql OpenTelemetry instrumentation](https://uptrace.dev/getinstrument/opentelemetry-database-sql.html)
instrumentation records database queries (including `Tx` and `Stmt` queries) and reports `DBStats` records database queries (including `Tx` and `Stmt` queries) and reports `DBStats` metrics.
metrics.
## Installation ## Installation

View file

@ -2,5 +2,5 @@
// Version is the current release version. // Version is the current release version.
func Version() string { func Version() string {
return "0.3.2" return "0.2.4"
} }

View file

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build (!arm64 && !s390x && !ppc64 && !ppc64le) || !gc || purego //go:build (!arm64 && !s390x && !ppc64le) || !gc || purego
package chacha20 package chacha20

View file

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build gc && !purego && (ppc64 || ppc64le) //go:build gc && !purego
package chacha20 package chacha20

View file

@ -19,7 +19,7 @@
// The differences in this and the original implementation are // The differences in this and the original implementation are
// due to the calling conventions and initialization of constants. // due to the calling conventions and initialization of constants.
//go:build gc && !purego && (ppc64 || ppc64le) //go:build gc && !purego
#include "textflag.h" #include "textflag.h"
@ -36,68 +36,32 @@
// for VPERMXOR // for VPERMXOR
#define MASK R18 #define MASK R18
DATA consts<>+0x00(SB)/4, $0x61707865 DATA consts<>+0x00(SB)/8, $0x3320646e61707865
DATA consts<>+0x04(SB)/4, $0x3320646e DATA consts<>+0x08(SB)/8, $0x6b20657479622d32
DATA consts<>+0x08(SB)/4, $0x79622d32 DATA consts<>+0x10(SB)/8, $0x0000000000000001
DATA consts<>+0x0c(SB)/4, $0x6b206574 DATA consts<>+0x18(SB)/8, $0x0000000000000000
DATA consts<>+0x10(SB)/4, $0x00000001 DATA consts<>+0x20(SB)/8, $0x0000000000000004
DATA consts<>+0x14(SB)/4, $0x00000000 DATA consts<>+0x28(SB)/8, $0x0000000000000000
DATA consts<>+0x18(SB)/4, $0x00000000 DATA consts<>+0x30(SB)/8, $0x0a0b08090e0f0c0d
DATA consts<>+0x1c(SB)/4, $0x00000000 DATA consts<>+0x38(SB)/8, $0x0203000106070405
DATA consts<>+0x20(SB)/4, $0x00000004 DATA consts<>+0x40(SB)/8, $0x090a0b080d0e0f0c
DATA consts<>+0x24(SB)/4, $0x00000000 DATA consts<>+0x48(SB)/8, $0x0102030005060704
DATA consts<>+0x28(SB)/4, $0x00000000 DATA consts<>+0x50(SB)/8, $0x6170786561707865
DATA consts<>+0x2c(SB)/4, $0x00000000 DATA consts<>+0x58(SB)/8, $0x6170786561707865
DATA consts<>+0x30(SB)/4, $0x0e0f0c0d DATA consts<>+0x60(SB)/8, $0x3320646e3320646e
DATA consts<>+0x34(SB)/4, $0x0a0b0809 DATA consts<>+0x68(SB)/8, $0x3320646e3320646e
DATA consts<>+0x38(SB)/4, $0x06070405 DATA consts<>+0x70(SB)/8, $0x79622d3279622d32
DATA consts<>+0x3c(SB)/4, $0x02030001 DATA consts<>+0x78(SB)/8, $0x79622d3279622d32
DATA consts<>+0x40(SB)/4, $0x0d0e0f0c DATA consts<>+0x80(SB)/8, $0x6b2065746b206574
DATA consts<>+0x44(SB)/4, $0x090a0b08 DATA consts<>+0x88(SB)/8, $0x6b2065746b206574
DATA consts<>+0x48(SB)/4, $0x05060704 DATA consts<>+0x90(SB)/8, $0x0000000100000000
DATA consts<>+0x4c(SB)/4, $0x01020300 DATA consts<>+0x98(SB)/8, $0x0000000300000002
DATA consts<>+0x50(SB)/4, $0x61707865 DATA consts<>+0xa0(SB)/8, $0x5566774411223300
DATA consts<>+0x54(SB)/4, $0x61707865 DATA consts<>+0xa8(SB)/8, $0xddeeffcc99aabb88
DATA consts<>+0x58(SB)/4, $0x61707865 DATA consts<>+0xb0(SB)/8, $0x6677445522330011
DATA consts<>+0x5c(SB)/4, $0x61707865 DATA consts<>+0xb8(SB)/8, $0xeeffccddaabb8899
DATA consts<>+0x60(SB)/4, $0x3320646e
DATA consts<>+0x64(SB)/4, $0x3320646e
DATA consts<>+0x68(SB)/4, $0x3320646e
DATA consts<>+0x6c(SB)/4, $0x3320646e
DATA consts<>+0x70(SB)/4, $0x79622d32
DATA consts<>+0x74(SB)/4, $0x79622d32
DATA consts<>+0x78(SB)/4, $0x79622d32
DATA consts<>+0x7c(SB)/4, $0x79622d32
DATA consts<>+0x80(SB)/4, $0x6b206574
DATA consts<>+0x84(SB)/4, $0x6b206574
DATA consts<>+0x88(SB)/4, $0x6b206574
DATA consts<>+0x8c(SB)/4, $0x6b206574
DATA consts<>+0x90(SB)/4, $0x00000000
DATA consts<>+0x94(SB)/4, $0x00000001
DATA consts<>+0x98(SB)/4, $0x00000002
DATA consts<>+0x9c(SB)/4, $0x00000003
DATA consts<>+0xa0(SB)/4, $0x11223300
DATA consts<>+0xa4(SB)/4, $0x55667744
DATA consts<>+0xa8(SB)/4, $0x99aabb88
DATA consts<>+0xac(SB)/4, $0xddeeffcc
DATA consts<>+0xb0(SB)/4, $0x22330011
DATA consts<>+0xb4(SB)/4, $0x66774455
DATA consts<>+0xb8(SB)/4, $0xaabb8899
DATA consts<>+0xbc(SB)/4, $0xeeffccdd
GLOBL consts<>(SB), RODATA, $0xc0 GLOBL consts<>(SB), RODATA, $0xc0
#ifdef GOARCH_ppc64
#define BE_XXBRW_INIT() \
LVSL (R0)(R0), V24 \
VSPLTISB $3, V25 \
VXOR V24, V25, V24 \
#define BE_XXBRW(vr) VPERM vr, vr, V24, vr
#else
#define BE_XXBRW_INIT()
#define BE_XXBRW(vr)
#endif
//func chaCha20_ctr32_vsx(out, inp *byte, len int, key *[8]uint32, counter *uint32) //func chaCha20_ctr32_vsx(out, inp *byte, len int, key *[8]uint32, counter *uint32)
TEXT ·chaCha20_ctr32_vsx(SB),NOSPLIT,$64-40 TEXT ·chaCha20_ctr32_vsx(SB),NOSPLIT,$64-40
MOVD out+0(FP), OUT MOVD out+0(FP), OUT
@ -130,8 +94,6 @@ TEXT ·chaCha20_ctr32_vsx(SB),NOSPLIT,$64-40
// Clear V27 // Clear V27
VXOR V27, V27, V27 VXOR V27, V27, V27
BE_XXBRW_INIT()
// V28 // V28
LXVW4X (CONSTBASE)(R11), VS60 LXVW4X (CONSTBASE)(R11), VS60
@ -337,11 +299,6 @@ loop_vsx:
VADDUWM V8, V18, V8 VADDUWM V8, V18, V8
VADDUWM V12, V19, V12 VADDUWM V12, V19, V12
BE_XXBRW(V0)
BE_XXBRW(V4)
BE_XXBRW(V8)
BE_XXBRW(V12)
CMPU LEN, $64 CMPU LEN, $64
BLT tail_vsx BLT tail_vsx
@ -370,11 +327,6 @@ loop_vsx:
VADDUWM V9, V18, V8 VADDUWM V9, V18, V8
VADDUWM V13, V19, V12 VADDUWM V13, V19, V12
BE_XXBRW(V0)
BE_XXBRW(V4)
BE_XXBRW(V8)
BE_XXBRW(V12)
CMPU LEN, $64 CMPU LEN, $64
BLT tail_vsx BLT tail_vsx
@ -382,8 +334,8 @@ loop_vsx:
LXVW4X (INP)(R8), VS60 LXVW4X (INP)(R8), VS60
LXVW4X (INP)(R9), VS61 LXVW4X (INP)(R9), VS61
LXVW4X (INP)(R10), VS62 LXVW4X (INP)(R10), VS62
VXOR V27, V0, V27
VXOR V27, V0, V27
VXOR V28, V4, V28 VXOR V28, V4, V28
VXOR V29, V8, V29 VXOR V29, V8, V29
VXOR V30, V12, V30 VXOR V30, V12, V30
@ -402,11 +354,6 @@ loop_vsx:
VADDUWM V10, V18, V8 VADDUWM V10, V18, V8
VADDUWM V14, V19, V12 VADDUWM V14, V19, V12
BE_XXBRW(V0)
BE_XXBRW(V4)
BE_XXBRW(V8)
BE_XXBRW(V12)
CMPU LEN, $64 CMPU LEN, $64
BLT tail_vsx BLT tail_vsx
@ -434,11 +381,6 @@ loop_vsx:
VADDUWM V11, V18, V8 VADDUWM V11, V18, V8
VADDUWM V15, V19, V12 VADDUWM V15, V19, V12
BE_XXBRW(V0)
BE_XXBRW(V4)
BE_XXBRW(V8)
BE_XXBRW(V12)
CMPU LEN, $64 CMPU LEN, $64
BLT tail_vsx BLT tail_vsx
@ -466,9 +408,9 @@ loop_vsx:
done_vsx: done_vsx:
// Increment counter by number of 64 byte blocks // Increment counter by number of 64 byte blocks
MOVWZ (CNT), R14 MOVD (CNT), R14
ADD BLOCKS, R14 ADD BLOCKS, R14
MOVWZ R14, (CNT) MOVD R14, (CNT)
RET RET
tail_vsx: tail_vsx:

View file

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build (!amd64 && !ppc64le && !ppc64 && !s390x) || !gc || purego //go:build (!amd64 && !ppc64le && !s390x) || !gc || purego
package poly1305 package poly1305

View file

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build gc && !purego && (ppc64 || ppc64le) //go:build gc && !purego
package poly1305 package poly1305

View file

@ -2,25 +2,15 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build gc && !purego && (ppc64 || ppc64le) //go:build gc && !purego
#include "textflag.h" #include "textflag.h"
// This was ported from the amd64 implementation. // This was ported from the amd64 implementation.
#ifdef GOARCH_ppc64le
#define LE_MOVD MOVD
#define LE_MOVWZ MOVWZ
#define LE_MOVHZ MOVHZ
#else
#define LE_MOVD MOVDBR
#define LE_MOVWZ MOVWBR
#define LE_MOVHZ MOVHBR
#endif
#define POLY1305_ADD(msg, h0, h1, h2, t0, t1, t2) \ #define POLY1305_ADD(msg, h0, h1, h2, t0, t1, t2) \
LE_MOVD (msg)( R0), t0; \ MOVD (msg), t0; \
LE_MOVD (msg)(R24), t1; \ MOVD 8(msg), t1; \
MOVD $1, t2; \ MOVD $1, t2; \
ADDC t0, h0, h0; \ ADDC t0, h0, h0; \
ADDE t1, h1, h1; \ ADDE t1, h1, h1; \
@ -60,6 +50,10 @@
ADDE t3, h1, h1; \ ADDE t3, h1, h1; \
ADDZE h2 ADDZE h2
DATA ·poly1305Mask<>+0x00(SB)/8, $0x0FFFFFFC0FFFFFFF
DATA ·poly1305Mask<>+0x08(SB)/8, $0x0FFFFFFC0FFFFFFC
GLOBL ·poly1305Mask<>(SB), RODATA, $16
// func update(state *[7]uint64, msg []byte) // func update(state *[7]uint64, msg []byte)
TEXT ·update(SB), $0-32 TEXT ·update(SB), $0-32
MOVD state+0(FP), R3 MOVD state+0(FP), R3
@ -72,8 +66,6 @@ TEXT ·update(SB), $0-32
MOVD 24(R3), R11 // r0 MOVD 24(R3), R11 // r0
MOVD 32(R3), R12 // r1 MOVD 32(R3), R12 // r1
MOVD $8, R24
CMP R5, $16 CMP R5, $16
BLT bytes_between_0_and_15 BLT bytes_between_0_and_15
@ -102,7 +94,7 @@ flush_buffer:
// Greater than 8 -- load the rightmost remaining bytes in msg // Greater than 8 -- load the rightmost remaining bytes in msg
// and put into R17 (h1) // and put into R17 (h1)
LE_MOVD (R4)(R21), R17 MOVD (R4)(R21), R17
MOVD $16, R22 MOVD $16, R22
// Find the offset to those bytes // Find the offset to those bytes
@ -126,7 +118,7 @@ just1:
BLT less8 BLT less8
// Exactly 8 // Exactly 8
LE_MOVD (R4), R16 MOVD (R4), R16
CMP R17, $0 CMP R17, $0
@ -141,7 +133,7 @@ less8:
MOVD $0, R22 // shift count MOVD $0, R22 // shift count
CMP R5, $4 CMP R5, $4
BLT less4 BLT less4
LE_MOVWZ (R4), R16 MOVWZ (R4), R16
ADD $4, R4 ADD $4, R4
ADD $-4, R5 ADD $-4, R5
MOVD $32, R22 MOVD $32, R22
@ -149,7 +141,7 @@ less8:
less4: less4:
CMP R5, $2 CMP R5, $2
BLT less2 BLT less2
LE_MOVHZ (R4), R21 MOVHZ (R4), R21
SLD R22, R21, R21 SLD R22, R21, R21
OR R16, R21, R16 OR R16, R21, R16
ADD $16, R22 ADD $16, R22

View file

@ -5,10 +5,6 @@
// Package sha3 implements the SHA-3 fixed-output-length hash functions and // Package sha3 implements the SHA-3 fixed-output-length hash functions and
// the SHAKE variable-output-length hash functions defined by FIPS-202. // the SHAKE variable-output-length hash functions defined by FIPS-202.
// //
// All types in this package also implement [encoding.BinaryMarshaler],
// [encoding.BinaryAppender] and [encoding.BinaryUnmarshaler] to marshal and
// unmarshal the internal state of the hash.
//
// Both types of hash function use the "sponge" construction and the Keccak // Both types of hash function use the "sponge" construction and the Keccak
// permutation. For a detailed specification see http://keccak.noekeon.org/ // permutation. For a detailed specification see http://keccak.noekeon.org/
// //

View file

@ -48,52 +48,33 @@ func init() {
crypto.RegisterHash(crypto.SHA3_512, New512) crypto.RegisterHash(crypto.SHA3_512, New512)
} }
const (
dsbyteSHA3 = 0b00000110
dsbyteKeccak = 0b00000001
dsbyteShake = 0b00011111
dsbyteCShake = 0b00000100
// rateK[c] is the rate in bytes for Keccak[c] where c is the capacity in
// bits. Given the sponge size is 1600 bits, the rate is 1600 - c bits.
rateK256 = (1600 - 256) / 8
rateK448 = (1600 - 448) / 8
rateK512 = (1600 - 512) / 8
rateK768 = (1600 - 768) / 8
rateK1024 = (1600 - 1024) / 8
)
func new224Generic() *state { func new224Generic() *state {
return &state{rate: rateK448, outputLen: 28, dsbyte: dsbyteSHA3} return &state{rate: 144, outputLen: 28, dsbyte: 0x06}
} }
func new256Generic() *state { func new256Generic() *state {
return &state{rate: rateK512, outputLen: 32, dsbyte: dsbyteSHA3} return &state{rate: 136, outputLen: 32, dsbyte: 0x06}
} }
func new384Generic() *state { func new384Generic() *state {
return &state{rate: rateK768, outputLen: 48, dsbyte: dsbyteSHA3} return &state{rate: 104, outputLen: 48, dsbyte: 0x06}
} }
func new512Generic() *state { func new512Generic() *state {
return &state{rate: rateK1024, outputLen: 64, dsbyte: dsbyteSHA3} return &state{rate: 72, outputLen: 64, dsbyte: 0x06}
} }
// NewLegacyKeccak256 creates a new Keccak-256 hash. // NewLegacyKeccak256 creates a new Keccak-256 hash.
// //
// Only use this function if you require compatibility with an existing cryptosystem // Only use this function if you require compatibility with an existing cryptosystem
// that uses non-standard padding. All other users should use New256 instead. // that uses non-standard padding. All other users should use New256 instead.
func NewLegacyKeccak256() hash.Hash { func NewLegacyKeccak256() hash.Hash { return &state{rate: 136, outputLen: 32, dsbyte: 0x01} }
return &state{rate: rateK512, outputLen: 32, dsbyte: dsbyteKeccak}
}
// NewLegacyKeccak512 creates a new Keccak-512 hash. // NewLegacyKeccak512 creates a new Keccak-512 hash.
// //
// Only use this function if you require compatibility with an existing cryptosystem // Only use this function if you require compatibility with an existing cryptosystem
// that uses non-standard padding. All other users should use New512 instead. // that uses non-standard padding. All other users should use New512 instead.
func NewLegacyKeccak512() hash.Hash { func NewLegacyKeccak512() hash.Hash { return &state{rate: 72, outputLen: 64, dsbyte: 0x01} }
return &state{rate: rateK1024, outputLen: 64, dsbyte: dsbyteKeccak}
}
// Sum224 returns the SHA3-224 digest of the data. // Sum224 returns the SHA3-224 digest of the data.
func Sum224(data []byte) (digest [28]byte) { func Sum224(data []byte) (digest [28]byte) {

View file

@ -4,15 +4,6 @@
package sha3 package sha3
import (
"crypto/subtle"
"encoding/binary"
"errors"
"unsafe"
"golang.org/x/sys/cpu"
)
// spongeDirection indicates the direction bytes are flowing through the sponge. // spongeDirection indicates the direction bytes are flowing through the sponge.
type spongeDirection int type spongeDirection int
@ -23,13 +14,16 @@
spongeSqueezing spongeSqueezing
) )
type state struct { const (
a [1600 / 8]byte // main state of the hash // maxRate is the maximum size of the internal buffer. SHAKE-256
// currently needs the largest buffer.
maxRate = 168
)
// a[n:rate] is the buffer. If absorbing, it's the remaining space to XOR type state struct {
// into before running the permutation. If squeezing, it's the remaining // Generic sponge components.
// output to produce before running the permutation. a [25]uint64 // main state of the hash
n, rate int rate int // the number of bytes of state to use
// dsbyte contains the "domain separation" bits and the first bit of // dsbyte contains the "domain separation" bits and the first bit of
// the padding. Sections 6.1 and 6.2 of [1] separate the outputs of the // the padding. Sections 6.1 and 6.2 of [1] separate the outputs of the
@ -45,6 +39,10 @@ type state struct {
// Extendable-Output Functions (May 2014)" // Extendable-Output Functions (May 2014)"
dsbyte byte dsbyte byte
i, n int // storage[i:n] is the buffer, i is only used while squeezing
storage [maxRate]byte
// Specific to SHA-3 and SHAKE.
outputLen int // the default output size in bytes outputLen int // the default output size in bytes
state spongeDirection // whether the sponge is absorbing or squeezing state spongeDirection // whether the sponge is absorbing or squeezing
} }
@ -63,7 +61,7 @@ func (d *state) Reset() {
d.a[i] = 0 d.a[i] = 0
} }
d.state = spongeAbsorbing d.state = spongeAbsorbing
d.n = 0 d.i, d.n = 0, 0
} }
func (d *state) clone() *state { func (d *state) clone() *state {
@ -71,25 +69,22 @@ func (d *state) clone() *state {
return &ret return &ret
} }
// permute applies the KeccakF-1600 permutation. // permute applies the KeccakF-1600 permutation. It handles
// any input-output buffering.
func (d *state) permute() { func (d *state) permute() {
var a *[25]uint64 switch d.state {
if cpu.IsBigEndian { case spongeAbsorbing:
a = new([25]uint64) // If we're absorbing, we need to xor the input into the state
for i := range a { // before applying the permutation.
a[i] = binary.LittleEndian.Uint64(d.a[i*8:]) xorIn(d, d.storage[:d.rate])
} d.n = 0
} else { keccakF1600(&d.a)
a = (*[25]uint64)(unsafe.Pointer(&d.a)) case spongeSqueezing:
} // If we're squeezing, we need to apply the permutation before
// copying more output.
keccakF1600(a) keccakF1600(&d.a)
d.n = 0 d.i = 0
copyOut(d, d.storage[:d.rate])
if cpu.IsBigEndian {
for i := range a {
binary.LittleEndian.PutUint64(d.a[i*8:], a[i])
}
} }
} }
@ -97,36 +92,53 @@ func (d *state) permute() {
// the multi-bitrate 10..1 padding rule, and permutes the state. // the multi-bitrate 10..1 padding rule, and permutes the state.
func (d *state) padAndPermute() { func (d *state) padAndPermute() {
// Pad with this instance's domain-separator bits. We know that there's // Pad with this instance's domain-separator bits. We know that there's
// at least one byte of space in the sponge because, if it were full, // at least one byte of space in d.buf because, if it were full,
// permute would have been called to empty it. dsbyte also contains the // permute would have been called to empty it. dsbyte also contains the
// first one bit for the padding. See the comment in the state struct. // first one bit for the padding. See the comment in the state struct.
d.a[d.n] ^= d.dsbyte d.storage[d.n] = d.dsbyte
d.n++
for d.n < d.rate {
d.storage[d.n] = 0
d.n++
}
// This adds the final one bit for the padding. Because of the way that // This adds the final one bit for the padding. Because of the way that
// bits are numbered from the LSB upwards, the final bit is the MSB of // bits are numbered from the LSB upwards, the final bit is the MSB of
// the last byte. // the last byte.
d.a[d.rate-1] ^= 0x80 d.storage[d.rate-1] ^= 0x80
// Apply the permutation // Apply the permutation
d.permute() d.permute()
d.state = spongeSqueezing d.state = spongeSqueezing
d.n = d.rate
copyOut(d, d.storage[:d.rate])
} }
// Write absorbs more data into the hash's state. It panics if any // Write absorbs more data into the hash's state. It panics if any
// output has already been read. // output has already been read.
func (d *state) Write(p []byte) (n int, err error) { func (d *state) Write(p []byte) (written int, err error) {
if d.state != spongeAbsorbing { if d.state != spongeAbsorbing {
panic("sha3: Write after Read") panic("sha3: Write after Read")
} }
written = len(p)
n = len(p)
for len(p) > 0 { for len(p) > 0 {
x := subtle.XORBytes(d.a[d.n:d.rate], d.a[d.n:d.rate], p) if d.n == 0 && len(p) >= d.rate {
d.n += x // The fast path; absorb a full "rate" bytes of input and apply the permutation.
p = p[x:] xorIn(d, p[:d.rate])
p = p[d.rate:]
keccakF1600(&d.a)
} else {
// The slow path; buffer the input until we can fill the sponge, and then xor it in.
todo := d.rate - d.n
if todo > len(p) {
todo = len(p)
}
d.n += copy(d.storage[d.n:], p[:todo])
p = p[todo:]
// If the sponge is full, apply the permutation. // If the sponge is full, apply the permutation.
if d.n == d.rate { if d.n == d.rate {
d.permute() d.permute()
}
} }
} }
@ -144,14 +156,14 @@ func (d *state) Read(out []byte) (n int, err error) {
// Now, do the squeezing. // Now, do the squeezing.
for len(out) > 0 { for len(out) > 0 {
n := copy(out, d.storage[d.i:d.n])
d.i += n
out = out[n:]
// Apply the permutation if we've squeezed the sponge dry. // Apply the permutation if we've squeezed the sponge dry.
if d.n == d.rate { if d.i == d.rate {
d.permute() d.permute()
} }
x := copy(out, d.a[d.n:d.rate])
d.n += x
out = out[x:]
} }
return return
@ -171,74 +183,3 @@ func (d *state) Sum(in []byte) []byte {
dup.Read(hash) dup.Read(hash)
return append(in, hash...) return append(in, hash...)
} }
const (
magicSHA3 = "sha\x08"
magicShake = "sha\x09"
magicCShake = "sha\x0a"
magicKeccak = "sha\x0b"
// magic || rate || main state || n || sponge direction
marshaledSize = len(magicSHA3) + 1 + 200 + 1 + 1
)
func (d *state) MarshalBinary() ([]byte, error) {
return d.AppendBinary(make([]byte, 0, marshaledSize))
}
func (d *state) AppendBinary(b []byte) ([]byte, error) {
switch d.dsbyte {
case dsbyteSHA3:
b = append(b, magicSHA3...)
case dsbyteShake:
b = append(b, magicShake...)
case dsbyteCShake:
b = append(b, magicCShake...)
case dsbyteKeccak:
b = append(b, magicKeccak...)
default:
panic("unknown dsbyte")
}
// rate is at most 168, and n is at most rate.
b = append(b, byte(d.rate))
b = append(b, d.a[:]...)
b = append(b, byte(d.n), byte(d.state))
return b, nil
}
func (d *state) UnmarshalBinary(b []byte) error {
if len(b) != marshaledSize {
return errors.New("sha3: invalid hash state")
}
magic := string(b[:len(magicSHA3)])
b = b[len(magicSHA3):]
switch {
case magic == magicSHA3 && d.dsbyte == dsbyteSHA3:
case magic == magicShake && d.dsbyte == dsbyteShake:
case magic == magicCShake && d.dsbyte == dsbyteCShake:
case magic == magicKeccak && d.dsbyte == dsbyteKeccak:
default:
return errors.New("sha3: invalid hash state identifier")
}
rate := int(b[0])
b = b[1:]
if rate != d.rate {
return errors.New("sha3: invalid hash state function")
}
copy(d.a[:], b)
b = b[len(d.a):]
n, state := int(b[0]), spongeDirection(b[1])
if n > d.rate {
return errors.New("sha3: invalid hash state")
}
d.n = n
if state != spongeAbsorbing && state != spongeSqueezing {
return errors.New("sha3: invalid hash state")
}
d.state = state
return nil
}

View file

@ -16,12 +16,9 @@
// [2] https://doi.org/10.6028/NIST.SP.800-185 // [2] https://doi.org/10.6028/NIST.SP.800-185
import ( import (
"bytes"
"encoding/binary" "encoding/binary"
"errors"
"hash" "hash"
"io" "io"
"math/bits"
) )
// ShakeHash defines the interface to hash functions that support // ShakeHash defines the interface to hash functions that support
@ -53,33 +50,41 @@ type cshakeState struct {
initBlock []byte initBlock []byte
} }
func bytepad(data []byte, rate int) []byte { // Consts for configuring initial SHA-3 state
out := make([]byte, 0, 9+len(data)+rate-1) const (
out = append(out, leftEncode(uint64(rate))...) dsbyteShake = 0x1f
out = append(out, data...) dsbyteCShake = 0x04
if padlen := rate - len(out)%rate; padlen < rate { rate128 = 168
out = append(out, make([]byte, padlen)...) rate256 = 136
} )
return out
func bytepad(input []byte, w int) []byte {
// leftEncode always returns max 9 bytes
buf := make([]byte, 0, 9+len(input)+w)
buf = append(buf, leftEncode(uint64(w))...)
buf = append(buf, input...)
padlen := w - (len(buf) % w)
return append(buf, make([]byte, padlen)...)
} }
func leftEncode(x uint64) []byte { func leftEncode(value uint64) []byte {
// Let n be the smallest positive integer for which 2^(8n) > x. var b [9]byte
n := (bits.Len64(x) + 7) / 8 binary.BigEndian.PutUint64(b[1:], value)
if n == 0 { // Trim all but last leading zero bytes
n = 1 i := byte(1)
for i < 8 && b[i] == 0 {
i++
} }
// Return n || x with n as a byte and x an n bytes in big-endian order. // Prepend number of encoded bytes
b := make([]byte, 9) b[i-1] = 9 - i
binary.BigEndian.PutUint64(b[1:], x) return b[i-1:]
b = b[9-n-1:]
b[0] = byte(n)
return b
} }
func newCShake(N, S []byte, rate, outputLen int, dsbyte byte) ShakeHash { func newCShake(N, S []byte, rate, outputLen int, dsbyte byte) ShakeHash {
c := cshakeState{state: &state{rate: rate, outputLen: outputLen, dsbyte: dsbyte}} c := cshakeState{state: &state{rate: rate, outputLen: outputLen, dsbyte: dsbyte}}
c.initBlock = make([]byte, 0, 9+len(N)+9+len(S)) // leftEncode returns max 9 bytes
// leftEncode returns max 9 bytes
c.initBlock = make([]byte, 0, 9*2+len(N)+len(S))
c.initBlock = append(c.initBlock, leftEncode(uint64(len(N))*8)...) c.initBlock = append(c.initBlock, leftEncode(uint64(len(N))*8)...)
c.initBlock = append(c.initBlock, N...) c.initBlock = append(c.initBlock, N...)
c.initBlock = append(c.initBlock, leftEncode(uint64(len(S))*8)...) c.initBlock = append(c.initBlock, leftEncode(uint64(len(S))*8)...)
@ -106,30 +111,6 @@ func (c *state) Clone() ShakeHash {
return c.clone() return c.clone()
} }
func (c *cshakeState) MarshalBinary() ([]byte, error) {
return c.AppendBinary(make([]byte, 0, marshaledSize+len(c.initBlock)))
}
func (c *cshakeState) AppendBinary(b []byte) ([]byte, error) {
b, err := c.state.AppendBinary(b)
if err != nil {
return nil, err
}
b = append(b, c.initBlock...)
return b, nil
}
func (c *cshakeState) UnmarshalBinary(b []byte) error {
if len(b) <= marshaledSize {
return errors.New("sha3: invalid hash state")
}
if err := c.state.UnmarshalBinary(b[:marshaledSize]); err != nil {
return err
}
c.initBlock = bytes.Clone(b[marshaledSize:])
return nil
}
// NewShake128 creates a new SHAKE128 variable-output-length ShakeHash. // NewShake128 creates a new SHAKE128 variable-output-length ShakeHash.
// Its generic security strength is 128 bits against all attacks if at // Its generic security strength is 128 bits against all attacks if at
// least 32 bytes of its output are used. // least 32 bytes of its output are used.
@ -145,11 +126,11 @@ func NewShake256() ShakeHash {
} }
func newShake128Generic() *state { func newShake128Generic() *state {
return &state{rate: rateK256, outputLen: 32, dsbyte: dsbyteShake} return &state{rate: rate128, outputLen: 32, dsbyte: dsbyteShake}
} }
func newShake256Generic() *state { func newShake256Generic() *state {
return &state{rate: rateK512, outputLen: 64, dsbyte: dsbyteShake} return &state{rate: rate256, outputLen: 64, dsbyte: dsbyteShake}
} }
// NewCShake128 creates a new instance of cSHAKE128 variable-output-length ShakeHash, // NewCShake128 creates a new instance of cSHAKE128 variable-output-length ShakeHash,
@ -162,7 +143,7 @@ func NewCShake128(N, S []byte) ShakeHash {
if len(N) == 0 && len(S) == 0 { if len(N) == 0 && len(S) == 0 {
return NewShake128() return NewShake128()
} }
return newCShake(N, S, rateK256, 32, dsbyteCShake) return newCShake(N, S, rate128, 32, dsbyteCShake)
} }
// NewCShake256 creates a new instance of cSHAKE256 variable-output-length ShakeHash, // NewCShake256 creates a new instance of cSHAKE256 variable-output-length ShakeHash,
@ -175,7 +156,7 @@ func NewCShake256(N, S []byte) ShakeHash {
if len(N) == 0 && len(S) == 0 { if len(N) == 0 && len(S) == 0 {
return NewShake256() return NewShake256()
} }
return newCShake(N, S, rateK512, 64, dsbyteCShake) return newCShake(N, S, rate256, 64, dsbyteCShake)
} }
// ShakeSum128 writes an arbitrary-length digest of data into hash. // ShakeSum128 writes an arbitrary-length digest of data into hash.

40
vendor/golang.org/x/crypto/sha3/xor.go generated vendored Normal file
View file

@ -0,0 +1,40 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package sha3
import (
"crypto/subtle"
"encoding/binary"
"unsafe"
"golang.org/x/sys/cpu"
)
// xorIn xors the bytes in buf into the state.
func xorIn(d *state, buf []byte) {
if cpu.IsBigEndian {
for i := 0; len(buf) >= 8; i++ {
a := binary.LittleEndian.Uint64(buf)
d.a[i] ^= a
buf = buf[8:]
}
} else {
ab := (*[25 * 64 / 8]byte)(unsafe.Pointer(&d.a))
subtle.XORBytes(ab[:], ab[:], buf)
}
}
// copyOut copies uint64s to a byte buffer.
func copyOut(d *state, b []byte) {
if cpu.IsBigEndian {
for i := 0; len(b) >= 8; i++ {
binary.LittleEndian.PutUint64(b, d.a[i])
b = b[8:]
}
} else {
ab := (*[25 * 64 / 8]byte)(unsafe.Pointer(&d.a))
copy(b, ab[:])
}
}

View file

@ -555,7 +555,6 @@ type initiateMsg struct {
} }
gotMsgExtInfo := false gotMsgExtInfo := false
gotUserAuthInfoRequest := false
for { for {
packet, err := c.readPacket() packet, err := c.readPacket()
if err != nil { if err != nil {
@ -586,9 +585,6 @@ type initiateMsg struct {
if msg.PartialSuccess { if msg.PartialSuccess {
return authPartialSuccess, msg.Methods, nil return authPartialSuccess, msg.Methods, nil
} }
if !gotUserAuthInfoRequest {
return authFailure, msg.Methods, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
}
return authFailure, msg.Methods, nil return authFailure, msg.Methods, nil
case msgUserAuthSuccess: case msgUserAuthSuccess:
return authSuccess, nil, nil return authSuccess, nil, nil
@ -600,7 +596,6 @@ type initiateMsg struct {
if err := Unmarshal(packet, &msg); err != nil { if err := Unmarshal(packet, &msg); err != nil {
return authFailure, nil, err return authFailure, nil, err
} }
gotUserAuthInfoRequest = true
// Manually unpack the prompt/echo pairs. // Manually unpack the prompt/echo pairs.
rest := msg.Prompts rest := msg.Prompts

Some files were not shown because too many files have changed in this diff Show more