Compare commits

...

25 commits

Author SHA1 Message Date
tobi f09a8027b5
Merge 4be1f780a1 into 2c3f1f4ddb 2024-10-08 09:52:53 -07:00
kim 2c3f1f4ddb
[chore] update go-sqlite3 to v0.19.0 (#3406) 2024-10-08 11:15:09 +02:00
tobi 1e421cb912
[feature] Distribute + ingest Accepts to followers (#3404) 2024-10-08 08:51:13 +00:00
dependabot[bot] 99f535f99b
[chore]: Bump golang.org/x/image from 0.20.0 to 0.21.0 (#3399) 2024-10-07 12:25:52 +00:00
dependabot[bot] 33bd97a535
[chore]: Bump golang.org/x/net from 0.29.0 to 0.30.0 (#3402) 2024-10-07 12:02:26 +00:00
kim bd1866ad8a
update go-ffmpreg to v0.3.1 (pulls in latest wazero too) (#3398) 2024-10-06 20:53:03 +00:00
tobi 02470db5f6
[chore/themes] Tweak colors on new themes (#3397) 2024-10-06 13:05:13 +02:00
tobi c023bd30f3
[bugfix] Only allow boosting post from non-interaction-policy-aware instance if public or unlisted (#3396) 2024-10-05 19:15:02 +02:00
tobi 18e2f69e85
[bugfix] Return 501 (not implemented) if user tries to schedule post (#3395) 2024-10-05 19:14:53 +02:00
tobi f0376635ad
[chore] Change order of error checking after PostInbox (#3394)
Check for malformed errors embedded inside error *first*, then check for gtserror.WithCode.
2024-10-05 17:08:42 +02:00
tobi 5c055afa08
[feature/frontend] Add Moonlight hunt theme (#3393)
* [feature/frontend] Add Moonlight Hunt theme

* make almost see through a bit less see through

* update
2024-10-05 15:12:40 +02:00
tobi c33b1e89c1
[bugfix] Update select of pending interaction requests to account for potential nil URI (#3392) 2024-10-05 12:27:53 +02:00
tobi 36abd568b1
[docs] Make protocol config option really explicit (#3391) 2024-10-05 12:09:58 +02:00
tobi 37a3d224a7
[bugfix] Account for nil reply when serializing int req (#3389) 2024-10-05 11:36:01 +02:00
tobi d3d6e3f920
[bugfix] Don't try to add nil filtered statuses to context (#3388) 2024-10-04 19:23:18 +02:00
tobi 8bd8c6fb45
[bugfix] Include own account in conversation when no other accounts involved (#3387) 2024-10-04 19:22:52 +02:00
kim f550f596fa
[performance] remove the pragma optimize analysis limit on connection close (#3386) 2024-10-04 19:05:42 +02:00
cui fliter 23b6d2cc64
fix: fix slice init length (#3382) 2024-10-03 17:22:26 +00:00
tobi 4be1f780a1 goreleaser deprecation notices 2024-09-02 15:14:27 +02:00
tobi 8db3d6b700 allow overflow in imaging 2024-09-02 15:08:07 +02:00
tobi 666b8bc4f2 Merge branch 'main' into go_123 2024-09-02 14:38:00 +02:00
tobi 7c6c74243b bump go version in go.mod 2024-09-01 17:44:54 +02:00
tobi 75d3fca08c sign 2024-09-01 17:42:45 +02:00
tobi bd4c4d79fe undo silly change 2024-09-01 17:37:17 +02:00
tobi c1543c029b [chore] Bump tooling versions, bump go -> v1.23.0 2024-09-01 17:35:31 +02:00
136 changed files with 2673 additions and 1074 deletions


@@ -12,7 +12,7 @@ steps:
   # We use golangci-lint for linting.
   # See: https://golangci-lint.run/
   - name: lint
-    image: golangci/golangci-lint:v1.57.2
+    image: golangci/golangci-lint:v1.60.3
    volumes:
    - name: go-build-cache
      path: /root/.cache/go-build
@@ -28,7 +28,7 @@ steps:
      - pull_request
  - name: test
-    image: golang:1.22-alpine
+    image: golang:1.23.0-alpine
    volumes:
    - name: go-build-cache
      path: /root/.cache/go-build
@@ -94,7 +94,7 @@ steps:
      - pull_request
  - name: snapshot
-    image: superseriousbusiness/gotosocial-drone-build:0.6.0 # https://github.com/superseriousbusiness/gotosocial-drone-build
+    image: superseriousbusiness/gotosocial-drone-build:0.7.0 # https://github.com/superseriousbusiness/gotosocial-drone-build
    volumes:
    - name: go-build-cache
      path: /root/.cache/go-build
@@ -135,7 +135,7 @@ steps:
      - main
  - name: release
-    image: superseriousbusiness/gotosocial-drone-build:0.6.0 # https://github.com/superseriousbusiness/gotosocial-drone-build
+    image: superseriousbusiness/gotosocial-drone-build:0.7.0 # https://github.com/superseriousbusiness/gotosocial-drone-build
    volumes:
    - name: go-build-cache
      path: /root/.cache/go-build
@@ -194,7 +194,7 @@ clone:
 steps:
  - name: mirror
-    image: superseriousbusiness/gotosocial-drone-build:0.6.0
+    image: superseriousbusiness/gotosocial-drone-build:0.7.0
    environment:
      ORIGIN_REPO: https://github.com/superseriousbusiness/gotosocial
      TARGET_REPO: https://codeberg.org/superseriousbusiness/gotosocial
@@ -207,6 +207,6 @@ steps:
 ---
 kind: signature
-hmac: f4008d87e4e5b67251eb89f255c1224e6ab5818828cab24fc319b8f829176058
+hmac: 9810bf692fb1029c13b0a1e2f556e2306d16f7d3eec9ca6163a0499c147280c1
 ...


@@ -1,4 +1,5 @@
-# https://goreleaser.com
+# Version 2 of GoReleaser: https://goreleaser.com/errors/version/
+version: 2
 project_name: gotosocial
 before:
   # https://goreleaser.com/customization/hooks/
@@ -185,7 +186,7 @@ checksum:
   name_template: 'checksums.txt'
 snapshot:
   # https://goreleaser.com/customization/snapshots/
-  name_template: "{{ incpatch .Version }}-SNAPSHOT"
+  version_template: "{{ incpatch .Version }}-SNAPSHOT"
 source:
   # https://goreleaser.com/customization/source/
   enabled: true


@@ -2,7 +2,7 @@
 # Dockerfile reference: https://docs.docker.com/engine/reference/builder/
 # stage 1: generate up-to-date swagger.yaml to put in the final container
-FROM --platform=${BUILDPLATFORM} golang:1.22-alpine AS swagger
+FROM --platform=${BUILDPLATFORM} golang:1.23.0-alpine AS swagger
 RUN \
 ### Installs goswagger for building swagger definitions inside this container
@@ -28,7 +28,7 @@ RUN yarn --cwd ./web/source install && \
 rm -rf ./web/source
 # stage 3: build the executor container
-FROM --platform=${TARGETPLATFORM} alpine:3.19.1 as executor
+FROM --platform=${TARGETPLATFORM} alpine:3.20.2 as executor
 # switch to non-root user:group for GtS
 USER 1000:1000


@@ -177,6 +177,10 @@ It's also easy for admins to [add their own custom themes](https://docs.gotosoci
   <img src="https://raw.githubusercontent.com/superseriousbusiness/gotosocial/main/docs/assets/theme-midnight-trip.png"/>
   <figcaption>Midnight trip</figcaption>
 </figure>
+<figure>
+  <img src="https://raw.githubusercontent.com/superseriousbusiness/gotosocial/main/docs/assets/theme-moonlight-hunt.png"/>
+  <figcaption>Moonlight hunt</figcaption>
+</figure>
 <hr/>
 <figure>
   <img src="https://raw.githubusercontent.com/superseriousbusiness/gotosocial/main/docs/assets/theme-rainforest.png"/>


@@ -950,7 +950,12 @@ definitions:
       with "direct message" visibility.
     properties:
       accounts:
-        description: Participants in the conversation.
+        description: |-
+          Participants in the conversation.
+          If this is a conversation between no accounts (ie., a self-directed DM),
+          this will include only the requesting account itself. Otherwise, it will
+          include every other account in the conversation *except* the requester.
         items:
           $ref: '#/definitions/account'
         type: array
@@ -8942,7 +8947,7 @@ paths:
          Providing this parameter will cause ScheduledStatus to be returned instead of Status.
          Must be at least 5 minutes in the future.
-         This feature isn't implemented yet.
+         This feature isn't implemented yet; attemping to set it will return 501 Not Implemented.
        in: formData
        name: scheduled_at
        type: string
@@ -9003,6 +9008,8 @@ paths:
          description: not acceptable
        "500":
          description: internal server error
+       "501":
+         description: scheduled_at was set, but this feature is not yet implemented
      security:
        - OAuth2 Bearer:
        - write:statuses
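For context on the new 501 response documented above: the body a client can expect is the one asserted in the handler test added later in this diff (the surrounding request here is hypothetical; only the error body is taken from the test):

```json
{
  "error": "Not Implemented: scheduled_at is not yet implemented"
}
```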

Binary file not shown (new image file, 682 KiB).


@@ -80,10 +80,18 @@ host: "localhost"
 # Default: ""
 account-domain: ""
-# String. Protocol to use for the server. Only change to http for local testing!
-# This should be the protocol part of the URI that your server is actually reachable on. So even if you're
-# running GoToSocial behind a reverse proxy that handles SSL certificates for you, instead of using built-in
-# letsencrypt, it should still be https.
+# String. Protocol over which the server is reachable from the outside world.
+#
+# ONLY CHANGE THIS TO HTTP FOR LOCAL TESTING! IN 99.99% OF CASES YOU SHOULD NOT CHANGE THIS!
+#
+# This should be the protocol part of the URI that your server is actually reachable on.
+# So even if you're running GoToSocial behind a reverse proxy that handles SSL certificates
+# for you, instead of using built-in letsencrypt, it should still be https, not http.
+#
+# Again, ONLY CHANGE THIS TO HTTP FOR LOCAL TESTING! If you set this to `http`, start your instance,
+# and then later change it to `https`, you will have already broken URI generation for any created
+# users on the instance. You should only touch this setting if you 100% know what you're doing.
+#
 # Options: ["http","https"]
 # Default: "https"
 protocol: "https"


@@ -569,6 +569,7 @@ For example, the following json object `Reject`s the attempt of `@someone@somewh
 ```json
 {
+  "@context": "https://www.w3.org/ns/activitystreams",
   "actor": "https://example.org/users/post_author",
   "to": "https://somewhere.else.example.org/users/someone",
   "id": "https://example.org/users/post_author/activities/reject/01J0K2YXP9QCT5BE1JWQSAM3B6",
@@ -591,7 +592,12 @@ For example, the following json object `Accept`s the attempt of `@someone@somewh
 ```json
 {
+  "@context": "https://www.w3.org/ns/activitystreams",
   "actor": "https://example.org/users/post_author",
+  "cc": [
+    "https://www.w3.org/ns/activitystreams#Public",
+    "https://example.org/users/post_author/followers"
+  ],
   "to": "https://somewhere.else.example.org/users/someone",
   "id": "https://example.org/users/post_author/activities/reject/01J0K2YXP9QCT5BE1JWQSAM3B6",
   "object": "https://somewhere.else.example.org/users/someone/statuses/01J17XY2VXGMNNPH1XR7BG2524",
@@ -601,6 +607,9 @@ For example, the following json object `Accept`s the attempt of `@someone@somewh
 If this happens, `@someone@somewhere.else.example.org` (and their instance) should consider the interaction as having been approved / accepted. The instance can then feel free to distribute the interaction `Activity` to all of the recipients targed by `to`, `cc`, etc, with the additional property `approvedBy` ([see below](#approvedby)).
+!!! Note
+    In the above example, actor `https://example.org/users/post_author` addresses the `Accept` activity not just to the interacting actor `https://somewhere.else.example.org/users/someone`, but to their followers collection as well (and, implicitly, to the public). This allows followers of `https://example.org/users/post_author` on other servers to also mark the interaction as accepted, and to show the interaction alongside the interacted-with post.
 ### Validating presence in a Followers / Following collection
 If an `Actor` interacting with an `Object` (via `Like`, `inReplyTo`, or `Announce`) is permitted to do that interaction based on their presence in a `Followers` or `Following` collection in the `always` field of an interaction policy, then their server should *still* wait for an `Accept` to be received from the server of the target account, before distributing the interaction more widely with the `approvedBy` property set to the URI of the `Accept`.
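As a rough sketch of the `approvedBy` flow described above (a hypothetical object, not taken verbatim from the GoToSocial docs; the `inReplyTo` and `approvedBy` URIs are placeholders, and placing the property on the reply object itself is an assumption), the interacting server might redistribute the accepted reply like this once the `Accept` has been received:

```json
{
  "@context": "https://www.w3.org/ns/activitystreams",
  "type": "Note",
  "id": "https://somewhere.else.example.org/users/someone/statuses/01J17XY2VXGMNNPH1XR7BG2524",
  "attributedTo": "https://somewhere.else.example.org/users/someone",
  "inReplyTo": "https://example.org/users/post_author/statuses/01EXAMPLE",
  "approvedBy": "https://example.org/users/post_author/activities/accept/01EXAMPLE"
}
```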


@@ -88,10 +88,18 @@ host: "localhost"
 # Default: ""
 account-domain: ""
-# String. Protocol to use for the server. Only change to http for local testing!
-# This should be the protocol part of the URI that your server is actually reachable on. So even if you're
-# running GoToSocial behind a reverse proxy that handles SSL certificates for you, instead of using built-in
-# letsencrypt, it should still be https.
+# String. Protocol over which the server is reachable from the outside world.
+#
+# ONLY CHANGE THIS TO HTTP FOR LOCAL TESTING! IN 99.99% OF CASES YOU SHOULD NOT CHANGE THIS!
+#
+# This should be the protocol part of the URI that your server is actually reachable on.
+# So even if you're running GoToSocial behind a reverse proxy that handles SSL certificates
+# for you, instead of using built-in letsencrypt, it should still be https, not http.
+#
+# Again, ONLY CHANGE THIS TO HTTP FOR LOCAL TESTING! If you set this to `http`, start your instance,
+# and then later change it to `https`, you will have already broken URI generation for any created
+# users on the instance. You should only touch this setting if you 100% know what you're doing.
+#
 # Options: ["http","https"]
 # Default: "https"
 protocol: "https"

go.mod (18 changed lines)

@@ -1,6 +1,6 @@
 module github.com/superseriousbusiness/gotosocial
-go 1.22.2
+go 1.23
 replace modernc.org/sqlite => gitlab.com/NyaaaWhatsUpDoc/sqlite v1.33.1-concurrency-workaround
@@ -12,7 +12,7 @@ require (
 codeberg.org/gruf/go-debug v1.3.0
 codeberg.org/gruf/go-errors/v2 v2.3.2
 codeberg.org/gruf/go-fastcopy v1.1.3
-codeberg.org/gruf/go-ffmpreg v0.2.6
+codeberg.org/gruf/go-ffmpreg v0.3.1
 codeberg.org/gruf/go-iotools v0.0.0-20240710125620-934ae9c654cf
 codeberg.org/gruf/go-kv v1.6.5
 codeberg.org/gruf/go-list v0.0.0-20240425093752-494db03d641f
@@ -44,7 +44,7 @@ require (
 github.com/miekg/dns v1.1.62
 github.com/minio/minio-go/v7 v7.0.77
 github.com/mitchellh/mapstructure v1.5.0
-github.com/ncruces/go-sqlite3 v0.18.4
+github.com/ncruces/go-sqlite3 v0.19.0
 github.com/oklog/ulid v1.3.1
 github.com/prometheus/client_golang v1.20.4
 github.com/spf13/cobra v1.8.1
@@ -55,7 +55,7 @@ require (
 github.com/superseriousbusiness/oauth2/v4 v4.3.2-SSB.0.20230227143000-f4900831d6c8
 github.com/tdewolff/minify/v2 v2.20.37
 github.com/technologize/otel-go-contrib v1.1.1
-github.com/tetratelabs/wazero v1.8.0
+github.com/tetratelabs/wazero v1.8.1
 github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80
 github.com/ulule/limiter/v3 v3.11.2
 github.com/uptrace/bun v1.2.1
@@ -73,11 +73,11 @@ require (
 go.opentelemetry.io/otel/sdk/metric v1.29.0
 go.opentelemetry.io/otel/trace v1.29.0
 go.uber.org/automaxprocs v1.6.0
-golang.org/x/crypto v0.27.0
+golang.org/x/crypto v0.28.0
-golang.org/x/image v0.20.0
+golang.org/x/image v0.21.0
-golang.org/x/net v0.29.0
+golang.org/x/net v0.30.0
 golang.org/x/oauth2 v0.23.0
-golang.org/x/text v0.18.0
+golang.org/x/text v0.19.0
 gopkg.in/mcuadros/go-syslog.v2 v2.3.0
 gopkg.in/yaml.v3 v3.0.1
 modernc.org/sqlite v0.0.0-00010101000000-000000000000
@@ -213,7 +213,7 @@ require (
 golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 // indirect
 golang.org/x/mod v0.18.0 // indirect
 golang.org/x/sync v0.8.0 // indirect
-golang.org/x/sys v0.25.0 // indirect
+golang.org/x/sys v0.26.0 // indirect
 golang.org/x/tools v0.22.0 // indirect
 google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1 // indirect
 google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect

go.sum (36 changed lines)

@@ -46,8 +46,8 @@ codeberg.org/gruf/go-fastcopy v1.1.3 h1:Jo9VTQjI6KYimlw25PPc7YLA3Xm+XMQhaHwKnM7x
 codeberg.org/gruf/go-fastcopy v1.1.3/go.mod h1:GDDYR0Cnb3U/AIfGM3983V/L+GN+vuwVMvrmVABo21s=
 codeberg.org/gruf/go-fastpath/v2 v2.0.0 h1:iAS9GZahFhyWEH0KLhFEJR+txx1ZhMXxYzu2q5Qo9c0=
 codeberg.org/gruf/go-fastpath/v2 v2.0.0/go.mod h1:3pPqu5nZjpbRrOqvLyAK7puS1OfEtQvjd6342Cwz56Q=
-codeberg.org/gruf/go-ffmpreg v0.2.6 h1:OHlTOF+62/b+VeM3Svg7praweU/NECRIsuhilZLFaO4=
-codeberg.org/gruf/go-ffmpreg v0.2.6/go.mod h1:sViRI0BYK2B8PJw4BrOg7DquPD71mZjDfffRAFcDtvk=
+codeberg.org/gruf/go-ffmpreg v0.3.1 h1:5qE6sHQbLCbQ4RO7ZL4OKZBN4ViAYfDm9ExT8N0ZE7s=
+codeberg.org/gruf/go-ffmpreg v0.3.1/go.mod h1:Ar5nbt3tB2Wr0uoaqV3wDBNwAx+H+AB/mV7Kw7NlZTI=
 codeberg.org/gruf/go-iotools v0.0.0-20240710125620-934ae9c654cf h1:84s/ii8N6lYlskZjHH+DG6jyia8w2mXMZlRwFn8Gs3A=
 codeberg.org/gruf/go-iotools v0.0.0-20240710125620-934ae9c654cf/go.mod h1:zZAICsp5rY7+hxnws2V0ePrWxE0Z2Z/KXcN3p/RQCfk=
 codeberg.org/gruf/go-kv v1.6.5 h1:ttPf0NA8F79pDqBttSudPTVCZmGncumeNIxmeM9ztz0=
@@ -434,8 +434,8 @@ github.com/moul/http2curl v1.0.0 h1:dRMWoAtb+ePxMlLkrCbAqh4TlPHXvoGUSQ323/9Zahs=
 github.com/moul/http2curl v1.0.0/go.mod h1:8UbvGypXm98wA/IqH45anm5Y2Z6ep6O31QGOAZ3H0fQ=
 github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
 github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
-github.com/ncruces/go-sqlite3 v0.18.4 h1:Je8o3y33MDwPYY/Cacas8yCsuoUzpNY/AgoSlN2ekyE=
-github.com/ncruces/go-sqlite3 v0.18.4/go.mod h1:4HLag13gq1k10s4dfGBhMfRVsssJRT9/5hYqVM9RUYo=
+github.com/ncruces/go-sqlite3 v0.19.0 h1:yebbD/cP8Gf+7nKoUin2ATjnqJK2VvyS30d3xsjRp5k=
+github.com/ncruces/go-sqlite3 v0.19.0/go.mod h1:yL4ZNWGsr1/8pcLfpPW1RT1WFdvyeHonrgIwwi4rvkg=
 github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
 github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
 github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M=
@@ -548,8 +548,8 @@ github.com/tdewolff/test v1.0.11-0.20240106005702-7de5f7df4739 h1:IkjBCtQOOjIn03
 github.com/tdewolff/test v1.0.11-0.20240106005702-7de5f7df4739/go.mod h1:XPuWBzvdUzhCuxWO1ojpXsyzsA5bFoS3tO/Q3kFuTG8=
 github.com/technologize/otel-go-contrib v1.1.1 h1:wZH9aSPNWZWIkEh3vfaKfMb15AJ80jJ1aVj/4GZdqIw=
 github.com/technologize/otel-go-contrib v1.1.1/go.mod h1:dCN/wj2WyUO8aFZFdIN+6tfJHImjTML/8r2YVYAy3So=
-github.com/tetratelabs/wazero v1.8.0 h1:iEKu0d4c2Pd+QSRieYbnQC9yiFlMS9D+Jr0LsRmcF4g=
-github.com/tetratelabs/wazero v1.8.0/go.mod h1:yAI0XTsMBhREkM/YDAK/zNou3GoiAce1P6+rp/wQhjs=
+github.com/tetratelabs/wazero v1.8.1 h1:NrcgVbWfkWvVc4UtT4LRLDf91PsOzDzefMdwhLfA550=
+github.com/tetratelabs/wazero v1.8.1/go.mod h1:yAI0XTsMBhREkM/YDAK/zNou3GoiAce1P6+rp/wQhjs=
 github.com/tidwall/btree v0.0.0-20191029221954-400434d76274 h1:G6Z6HvJuPjG6XfNGi/feOATzeJrfgTNJY+rGrHbA04E=
 github.com/tidwall/btree v0.0.0-20191029221954-400434d76274/go.mod h1:huei1BkDWJ3/sLXmO+bsCNELL+Bp2Kks9OLyQFkzvA8=
 github.com/tidwall/buntdb v1.1.2 h1:noCrqQXL9EKMtcdwJcmuVKSEjqu1ua99RHHgbLTEHRo=
@@ -668,8 +668,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
 golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
 golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4=
-golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A=
-golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
+golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
+golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
 golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@@ -684,8 +684,8 @@ golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 h1:LfspQV/FYTatPTr/3HzIcmiUF
 golang.org/x/exp v0.0.0-20240222234643-814bf88cf225/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc=
 golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
 golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
-golang.org/x/image v0.20.0 h1:7cVCUjQwfL18gyBJOmYvptfSHS8Fb3YUDtfLIZ7Nbpw=
-golang.org/x/image v0.20.0/go.mod h1:0a88To4CYVBAHp5FXJm8o7QbUl37Vd85ply1vyD8auM=
+golang.org/x/image v0.21.0 h1:c5qV36ajHpdj4Qi0GnE0jUc/yuo33OLFaa0d+crTD5s=
+golang.org/x/image v0.21.0/go.mod h1:vUbsLavqK/W303ZroQQVKQ+Af3Yl6Uz1Ppu5J/cLz78=
 golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
 golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
 golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
@@ -739,8 +739,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
 golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
 golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
 golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY=
-golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo=
-golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0=
+golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
+golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
 golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
 golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
 golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
@@ -800,13 +800,13 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc
 golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34=
-golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
+golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
 golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
 golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc=
-golang.org/x/term v0.24.0 h1:Mh5cbb+Zk2hqqXNO7S1iTjEphVL+jb8ZWaqh/g+JWkM=
-golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8=
+golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24=
+golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M=
 golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@@ -814,8 +814,8 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
 golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
-golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224=
-golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
+golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
+golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
 golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
 golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
 golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=


@@ -77,6 +77,10 @@
 	// See https://www.w3.org/TR/activitystreams-vocabulary/#microsyntaxes
 	// and https://www.w3.org/TR/activitystreams-vocabulary/#dfn-tag
 	TagHashtag = "Hashtag"
+	// Not in the AS spec, just used internally to indicate
+	// that we don't *yet* know what type of Object something is.
+	ObjectUnknown = "Unknown"
 )
 // isActivity returns whether AS type name is of an Activity (NOT IntransitiveActivity).


@@ -145,8 +145,8 @@ func validateCreateEmoji(form *apimodel.EmojiCreateRequest) error {
 		return errors.New("no emoji given")
 	}
-	maxSize := config.GetMediaEmojiLocalMaxSize()
-	if form.Image.Size > int64(maxSize) {
+	maxSize := int64(config.GetMediaEmojiLocalMaxSize()) // #nosec G115 -- Already validated.
+	if form.Image.Size > maxSize {
 		return fmt.Errorf("emoji image too large: image is %dKB but size limit for custom emojis is %dKB", form.Image.Size/1024, maxSize/1024)
 	}


@@ -208,8 +208,8 @@ func validateUpdateEmoji(form *apimodel.EmojiUpdateRequest) error {
 	}
 	if hasImage {
-		maxSize := config.GetMediaEmojiLocalMaxSize()
-		if form.Image.Size > int64(maxSize) {
+		maxSize := int64(config.GetMediaEmojiLocalMaxSize()) // #nosec G115 -- Already validated.
+		if form.Image.Size > maxSize {
 			return fmt.Errorf("emoji image too large: image is %dKB but size limit for custom emojis is %dKB", form.Image.Size/1024, maxSize/1024)
 		}
 	}


@@ -181,7 +181,7 @@
 //	Providing this parameter will cause ScheduledStatus to be returned instead of Status.
 //	Must be at least 5 minutes in the future.
 //
-//	This feature isn't implemented yet.
+//	This feature isn't implemented yet; attemping to set it will return 501 Not Implemented.
 //	type: string
 //	in: formData
 //	-
@@ -254,6 +254,8 @@
 //	description: not acceptable
 //	'500':
 //	description: internal server error
+//	'501':
+//	description: scheduled_at was set, but this feature is not yet implemented
 func (m *Module) StatusCreatePOSTHandler(c *gin.Context) {
 	authed, err := oauth.Authed(c, true, true, true, true)
 	if err != nil {
@@ -286,8 +288,8 @@ func (m *Module) StatusCreatePOSTHandler(c *gin.Context) {
 	// }
 	// form.Status += "\n\nsent from " + user + "'s iphone\n"
-	if err := validateNormalizeCreateStatus(form); err != nil {
-		apiutil.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGetV1)
+	if errWithCode := validateStatusCreateForm(form); errWithCode != nil {
+		apiutil.ErrorHandler(c, errWithCode, m.processor.InstanceGetV1)
 		return
 	}
@@ -374,46 +376,61 @@ func parseStatusCreateForm(c *gin.Context) (*apimodel.StatusCreateRequest, error
 	return form, nil
 }
-// validateNormalizeCreateStatus checks the form
-// for disallowed combinations of attachments and
-// overlength inputs.
+// validateStatusCreateForm checks the form for disallowed
+// combinations of attachments, overlength inputs, etc.
 //
 // Side effect: normalizes the post's language tag.
-func validateNormalizeCreateStatus(form *apimodel.StatusCreateRequest) error {
-	hasStatus := form.Status != ""
-	hasMedia := len(form.MediaIDs) != 0
-	hasPoll := form.Poll != nil
+func validateStatusCreateForm(form *apimodel.StatusCreateRequest) gtserror.WithCode {
+	var (
+		chars         = len([]rune(form.Status)) + len([]rune(form.SpoilerText))
+		maxChars      = config.GetStatusesMaxChars()
+		mediaFiles    = len(form.MediaIDs)
+		maxMediaFiles = config.GetStatusesMediaMaxFiles()
+		hasMedia      = mediaFiles != 0
+		hasPoll       = form.Poll != nil
+	)
-	if !hasStatus && !hasMedia && !hasPoll {
-		return errors.New("no status, media, or poll provided")
+	if chars == 0 && !hasMedia && !hasPoll {
+		// Status must contain *some* kind of content.
+		const text = "no status content, content warning, media, or poll provided"
+		return gtserror.NewErrorBadRequest(errors.New(text), text)
 	}
-	if hasMedia && hasPoll {
-		return errors.New("can't post media + poll in same status")
-	}
-	maxChars := config.GetStatusesMaxChars()
-	if length := len([]rune(form.Status)) + len([]rune(form.SpoilerText)); length > maxChars {
-		return fmt.Errorf("status too long, %d characters provided (including spoiler/content warning) but limit is %d", length, maxChars)
-	}
-	maxMediaFiles := config.GetStatusesMediaMaxFiles()
-	if len(form.MediaIDs) > maxMediaFiles {
-		return fmt.Errorf("too many media files attached to status, %d attached but limit is %d", len(form.MediaIDs), maxMediaFiles)
-	}
+	if chars > maxChars {
+		text := fmt.Sprintf(
+			"status too long, %d characters provided (including content warning) but limit is %d",
+			chars, maxChars,
+		)
+		return gtserror.NewErrorBadRequest(errors.New(text), text)
+	}
+	if mediaFiles > maxMediaFiles {
+		text := fmt.Sprintf(
+			"too many media files attached to status, %d attached but limit is %d",
+			mediaFiles, maxMediaFiles,
+		)
+		return gtserror.NewErrorBadRequest(errors.New(text), text)
+	}
 	if form.Poll != nil {
-		if err := validateNormalizeCreatePoll(form); err != nil {
-			return err
+		if errWithCode := validateStatusPoll(form); errWithCode != nil {
+			return errWithCode
 		}
 	}
+	if form.ScheduledAt != "" {
+		const text = "scheduled_at is not yet implemented"
+		return gtserror.NewErrorNotImplemented(errors.New(text), text)
+	}
+	// Validate + normalize
+	// language tag if provided.
 	if form.Language != "" {
-		language, err := validate.Language(form.Language)
+		lang, err := validate.Language(form.Language)
 		if err != nil {
-			return err
+			return gtserror.NewErrorBadRequest(err, err.Error())
 		}
-		form.Language = language
+		form.Language = lang
 	}
 	// Check if the deprecated "federated" field was
@@ -425,9 +442,36 @@ func validateNormalizeCreateStatus(form *apimodel.StatusCreateRequest) error {
 	return nil
 }
-func validateNormalizeCreatePoll(form *apimodel.StatusCreateRequest) error {
-	maxPollOptions := config.GetStatusesPollMaxOptions()
-	maxPollChars := config.GetStatusesPollOptionMaxChars()
+func validateStatusPoll(form *apimodel.StatusCreateRequest) gtserror.WithCode {
+	var (
+		maxPollOptions     = config.GetStatusesPollMaxOptions()
+		pollOptions        = len(form.Poll.Options)
+		maxPollOptionChars = config.GetStatusesPollOptionMaxChars()
+	)
+	if pollOptions == 0 {
+		const text = "poll with no options"
+		return gtserror.NewErrorBadRequest(errors.New(text), text)
+	}
+	if pollOptions > maxPollOptions {
+		text := fmt.Sprintf(
+			"too many poll options provided, %d provided but limit is %d",
+			pollOptions, maxPollOptions,
+		)
+		return gtserror.NewErrorBadRequest(errors.New(text), text)
+	}
+	for _, option := range form.Poll.Options {
+		optionChars := len([]rune(option))
+		if optionChars > maxPollOptionChars {
+			text := fmt.Sprintf(
+				"poll option too long, %d characters provided but limit is %d",
+				optionChars, maxPollOptionChars,
+			)
+			return gtserror.NewErrorBadRequest(errors.New(text), text)
+		}
+	}
 	// Normalize poll expiry if necessary.
 	// If we parsed this as JSON, expires_in
@@ -440,27 +484,15 @@ func validateNormalizeCreatePoll(form *apimodel.StatusCreateRequest) error {
 	case string:
 		expiresIn, err := strconv.Atoi(e)
 		if err != nil {
-			return fmt.Errorf("could not parse expires_in value %s as integer: %w", e, err)
+			text := fmt.Sprintf("could not parse expires_in value %s as integer: %v", e, err)
+			return gtserror.NewErrorBadRequest(errors.New(text), text)
 		}
 		form.Poll.ExpiresIn = expiresIn
 	default:
-		return fmt.Errorf("could not parse expires_in type %T as integer", ei)
+		text := fmt.Sprintf("could not parse expires_in type %T as integer", ei)
+		return gtserror.NewErrorBadRequest(errors.New(text), text)
 	}
 	}
-	if len(form.Poll.Options) == 0 {
-		return errors.New("poll with no options")
-	}
-	if len(form.Poll.Options) > maxPollOptions {
-		return fmt.Errorf("too many poll options provided, %d provided but limit is %d", len(form.Poll.Options), maxPollOptions)
-	}
-	for _, p := range form.Poll.Options {
-		if length := len([]rune(p)); length > maxPollChars {
-			return fmt.Errorf("poll option too long, %d characters provided but limit is %d", length, maxPollChars)
-		}
-	}
 	}


@@ -365,6 +365,25 @@ func (suite *StatusCreateTestSuite) TestPostNewStatusMessedUpIntPolicy() {
 }`, out)
 }
+func (suite *StatusCreateTestSuite) TestPostNewScheduledStatus() {
+	out, recorder := suite.postStatus(map[string][]string{
+		"status":       {"this is a brand new status! #helloworld"},
+		"spoiler_text": {"hello hello"},
+		"sensitive":    {"true"},
+		"visibility":   {string(apimodel.VisibilityMutualsOnly)},
+		"scheduled_at": {"2080-10-04T15:32:02.018Z"},
+	}, "")
+	// We should have 501 from
+	// our call to the function.
+	suite.Equal(http.StatusNotImplemented, recorder.Code)
+	// We should have a helpful error message.
+	suite.Equal(`{
+  "error": "Not Implemented: scheduled_at is not yet implemented"
+}`, out)
+}
 func (suite *StatusCreateTestSuite) TestPostNewStatusMarkdown() {
 	out, recorder := suite.postStatus(map[string][]string{
 		"status": {statusMarkdown},


@@ -160,7 +160,7 @@ type MediaDimensions struct {
 	Duration float32 `json:"duration,omitempty"`
 	// Bitrate of the media in bits per second.
 	// example: 1000000
-	Bitrate int `json:"bitrate,omitempty"`
+	Bitrate uint64 `json:"bitrate,omitempty"`
 	// Size of the media, in the format `[width]x[height]`.
 	// Not set for audio.
 	// example: 1920x1080


@@ -27,6 +27,10 @@ type Conversation struct {
 	// Is the conversation currently marked as unread?
 	Unread bool `json:"unread"`
 	// Participants in the conversation.
+	//
+	// If this is a conversation between no accounts (ie., a self-directed DM),
+	// this will include only the requesting account itself. Otherwise, it will
+	// include every other account in the conversation *except* the requester.
 	Accounts []Account `json:"accounts"`
 	// The last status in the conversation. May be `null`.
 	LastStatus *Status `json:"last_status"`
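To illustrate the participant rule documented in the new comment above, a self-directed DM would serialize roughly like this (a hypothetical, heavily abbreviated API response; real account objects carry many more fields):

```json
{
  "id": "01EXAMPLECONVERSATION",
  "unread": false,
  "accounts": [
    {
      "id": "01EXAMPLEACCOUNT",
      "acct": "the_requester"
    }
  ],
  "last_status": null
}
```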


@@ -220,7 +220,7 @@ func (n *node) getChild(part string) *node {
 	for i < j {
 		// avoid overflow when computing h
-		h := int(uint(i+j) >> 1)
+		h := int(uint(i+j) >> 1) // #nosec G115
 		// i ≤ h < j
 		if n.child[h].part < part {


@@ -25,6 +25,7 @@
 	"encoding/pem"
 	"errors"
 	"fmt"
+	"math"
 	"net/url"
 	"os"
 	"runtime"
@@ -407,13 +408,12 @@ func maxOpenConns() int {
 // deriveBunDBPGOptions takes an application config and returns either a ready-to-use set of options
 // with sensible defaults, or an error if it's not satisfied by the provided config.
 func deriveBunDBPGOptions() (*pgx.ConnConfig, error) {
-	url := config.GetDbPostgresConnectionString()
-	// if database URL is defined, ignore other DB related configuration fields
-	if url != "" {
-		cfg, err := pgx.ParseConfig(url)
-		return cfg, err
+	// If database URL is defined, ignore
+	// other DB-related configuration fields.
+	if url := config.GetDbPostgresConnectionString(); url != "" {
+		return pgx.ParseConfig(url)
 	}
 	// these are all optional, the db adapter figures out defaults
 	address := config.GetDbAddress()
@@ -477,7 +477,10 @@ func deriveBunDBPGOptions() (*pgx.ConnConfig, error) {
 		cfg.Host = address
 	}
 	if port := config.GetDbPort(); port > 0 {
-		cfg.Port = uint16(port)
+		if port > math.MaxUint16 {
+			return nil, errors.New("invalid port, must be in range 1-65535")
+		}
+		cfg.Port = uint16(port) // #nosec G115 -- Just validated above.
 	}
 	if u := config.GetDbUser(); u != "" {
 		cfg.User = u


@@ -302,9 +302,9 @@ func (i *interactionDB) GetInteractionsRequestsForAcct(
 			bun.Ident("interaction_request"),
 		).
 		// Select only interaction requests that
-		// are neither accepted or rejected yet,
-		// ie., without an Accept or Reject URI.
-		Where("? IS NULL", bun.Ident("uri"))
+		// are neither accepted or rejected yet.
+		Where("? IS NULL", bun.Ident("accepted_at")).
+		Where("? IS NULL", bun.Ident("rejected_at"))
 	// Select interactions targeting status.
 	if statusID != "" {


@@ -0,0 +1,57 @@
// GoToSocial
// Copyright (C) GoToSocial Authors admin@gotosocial.org
// SPDX-License-Identifier: AGPL-3.0-or-later
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package migrations
import (
"context"
"github.com/uptrace/bun"
)
func init() {
up := func(ctx context.Context, db *bun.DB) error {
return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
for idx, col := range map[string]string{
"interaction_requests_accepted_at_idx": "accepted_at",
"interaction_requests_rejected_at_idx": "rejected_at",
} {
if _, err := tx.
NewCreateIndex().
Table("interaction_requests").
Index(idx).
Column(col).
IfNotExists().
Exec(ctx); err != nil {
return err
}
}
return nil
})
}
down := func(ctx context.Context, db *bun.DB) error {
return db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error {
return nil
})
}
if err := Migrations.Register(up, down); err != nil {
panic(err)
}
}


@@ -112,7 +112,7 @@ func (c *sqliteConn) Close() (err error) {
 	raw := c.connIface.(sqlite3driver.Conn).Raw()
 	// see: https://www.sqlite.org/pragma.html#pragma_optimize
-	const onClose = "PRAGMA analysis_limit=1000; PRAGMA optimize;"
+	const onClose = "PRAGMA optimize;"
 	_ = raw.Exec(onClose)
 	// Finally, close.


@@ -97,11 +97,11 @@ func() (*media.ProcessingEmoji, error) {
 	}
 	// Get maximum supported remote emoji size.
-	maxsz := config.GetMediaEmojiRemoteMaxSize()
+	maxsz := int64(config.GetMediaEmojiRemoteMaxSize()) // #nosec G115 -- Already validated.
 	// Prepare data function to dereference remote emoji media.
 	data := func(context.Context) (io.ReadCloser, error) {
-		return tsport.DereferenceMedia(ctx, url, int64(maxsz))
+		return tsport.DereferenceMedia(ctx, url, maxsz)
 	}
 	// Create new emoji with prepared info.
@@ -189,11 +189,11 @@ func() (*media.ProcessingEmoji, error) {
 	}
 	// Get maximum supported remote emoji size.
-	maxsz := config.GetMediaEmojiRemoteMaxSize()
+	maxsz := int64(config.GetMediaEmojiRemoteMaxSize()) // #nosec G115 -- Already validated.
 	// Prepare data function to dereference remote emoji media.
 	data := func(context.Context) (io.ReadCloser, error) {
-		return tsport.DereferenceMedia(ctx, url, int64(maxsz))
+		return tsport.DereferenceMedia(ctx, url, maxsz)
 	}
 	// Update emoji with prepared info.
@@ -255,11 +255,11 @@ func() (*media.ProcessingEmoji, error) {
 	}
 	// Get maximum supported remote emoji size.
-	maxsz := config.GetMediaEmojiRemoteMaxSize()
+	maxsz := int64(config.GetMediaEmojiRemoteMaxSize()) // #nosec G115 -- Already validated.
 	// Prepare data function to dereference remote emoji media.
 	data := func(context.Context) (io.ReadCloser, error) {
-		return tsport.DereferenceMedia(ctx, url, int64(maxsz))
+		return tsport.DereferenceMedia(ctx, url, maxsz)
 	}
 	// Recache emoji with prepared info.


@@ -77,14 +77,14 @@ func() (*media.ProcessingMedia, error) {
 	}
 	// Get maximum supported remote media size.
-	maxsz := config.GetMediaRemoteMaxSize()
+	maxsz := int64(config.GetMediaRemoteMaxSize()) // #nosec G115 -- Already validated.
 	// Create media with prepared info.
 	return d.mediaManager.CreateMedia(
 		ctx,
 		accountID,
 		func(ctx context.Context) (io.ReadCloser, error) {
-			return tsport.DereferenceMedia(ctx, url, int64(maxsz))
+			return tsport.DereferenceMedia(ctx, url, maxsz)
 		},
 		info,
 	)
@@ -168,14 +168,14 @@ func() (*media.ProcessingMedia, error) {
 	}
 	// Get maximum supported remote media size.
-	maxsz := config.GetMediaRemoteMaxSize()
+	maxsz := int64(config.GetMediaRemoteMaxSize()) // #nosec G115 -- Already validated.
 	// Recache media with prepared info,
 	// this will also update media in db.
 	return d.mediaManager.CacheMedia(
 		attach,
 		func(ctx context.Context) (io.ReadCloser, error) {
-			return tsport.DereferenceMedia(ctx, url, int64(maxsz))
+			return tsport.DereferenceMedia(ctx, url, maxsz)
 		},
 	), nil
 },


@@ -527,8 +527,9 @@ func (d *Dereferencer) enrichStatus(
 	// serve statuses with the `approved_by` field, but we
 	// might have marked a status as pre-approved on our side
 	// based on the author's inclusion in a followers/following
-	// collection. By carrying over previously-set values we
-	// can avoid marking such statuses as "pending" again.
+	// collection, or by providing pre-approval URI on the bare
+	// status passed to RefreshStatus. By carrying over previously
+	// set values we can avoid marking such statuses as "pending".
 	//
 	// If a remote has in the meantime retracted its approval,
 	// the next call to 'isPermittedStatus' will catch that.


@@ -113,33 +113,17 @@ func (d *Dereferencer) isPermittedStatus(
 func (d *Dereferencer) isPermittedReply(
 	ctx context.Context,
 	requestUser string,
-	status *gtsmodel.Status,
+	reply *gtsmodel.Status,
 ) (bool, error) {
 	var (
-		statusURI    = status.URI          // Definitely set.
-		inReplyToURI = status.InReplyToURI // Definitely set.
-		inReplyTo    = status.InReplyTo    // Might not yet be set.
+		replyURI     = reply.URI           // Definitely set.
+		inReplyToURI = reply.InReplyToURI  // Definitely set.
+		inReplyTo    = reply.InReplyTo     // Might not be set.
+		acceptIRI    = reply.ApprovedByURI // Might not be set.
 	)
-	// Check if status with this URI has previously been rejected.
-	req, err := d.state.DB.GetInteractionRequestByInteractionURI(
-		gtscontext.SetBarebones(ctx),
-		statusURI,
-	)
-	if err != nil && !errors.Is(err, db.ErrNoEntries) {
-		err := gtserror.Newf("db error getting interaction request: %w", err)
-		return false, err
-	}
-	if req != nil && req.IsRejected() {
-		// This status has been
-		// rejected reviously, so
-		// it's not permitted now.
-		return false, nil
-	}
-	// Check if replied-to status has previously been rejected.
-	req, err = d.state.DB.GetInteractionRequestByInteractionURI(
+	// Check if we have a stored interaction request for parent status.
+	parentReq, err := d.state.DB.GetInteractionRequestByInteractionURI(
 		gtscontext.SetBarebones(ctx),
 		inReplyToURI,
 	)
@@ -148,71 +132,78 @@ func (d *Dereferencer) isPermittedReply(
 		return false, err
 	}
-	if req != nil && req.IsRejected() {
-		// This status's parent was rejected, so
-		// implicitly this reply should be rejected too.
-		//
-		// We know already that we haven't inserted
-		// a rejected interaction request for this
-		// status yet so do it before returning.
-		id := id.NewULID()
-		// To ensure the Reject chain stays coherent,
-		// borrow fields from the up-thread rejection.
-		// This collapses the chain beyond the first
-		// rejected reply and allows us to avoid derefing
-		// further replies we already know we don't want.
-		statusID := req.StatusID
-		targetAccountID := req.TargetAccountID
-		// As nobody is actually Rejecting the reply
-		// directly, but it's an implicit Reject coming
-		// from our internal logic, don't bother setting
-		// a URI (it's not a required field anyway).
-		uri := ""
-		rejection := &gtsmodel.InteractionRequest{
-			ID:                   id,
-			StatusID:             statusID,
-			TargetAccountID:      targetAccountID,
-			InteractingAccountID: status.AccountID,
-			InteractionURI:       statusURI,
-			InteractionType:      gtsmodel.InteractionReply,
-			URI:                  uri,
-			RejectedAt:           time.Now(),
-		}
-		err := d.state.DB.PutInteractionRequest(ctx, rejection)
-		if err != nil && !errors.Is(err, db.ErrAlreadyExists) {
-			return false, gtserror.Newf("db error putting pre-rejected interaction request: %w", err)
-		}
+	// Check if we have a stored interaction request for this reply.
+	thisReq, err := d.state.DB.GetInteractionRequestByInteractionURI(
+		gtscontext.SetBarebones(ctx),
+		replyURI,
+	)
+	if err != nil && !errors.Is(err, db.ErrNoEntries) {
+		err := gtserror.Newf("db error getting interaction request: %w", err)
+		return false, err
+	}
+	parentRejected := (parentReq != nil && parentReq.IsRejected())
+	thisRejected := (thisReq != nil && thisReq.IsRejected())
+	if parentRejected {
+		// If this status's parent was rejected,
+		// implicitly this reply should be too;
+		// there's nothing more to check here.
+		return false, d.unpermittedByParent(
+			ctx,
+			reply,
+			thisReq,
+			parentReq,
+		)
+	}
+	// Parent wasn't rejected. Check if this
+	// reply itself was rejected previously.
+	//
+	// If it was, and it doesn't now claim to
+	// be approved, then we should just reject it
+	// again, as nothing's changed since last time.
+	if thisRejected && acceptIRI == "" {
+		// Nothing changed,
+		// still rejected.
 		return false, nil
 	}
+	// This reply wasn't rejected previously, or
+	// it was rejected previously and now claims
+	// to be approved. Continue permission checks.
 	if inReplyTo == nil {
-		// We didn't have the replied-to status in
-		// our database (yet) so we can't know if
-		// this reply is permitted or not. For now
-		// just return true; worst-case, the status
-		// sticks around on the instance for a couple
-		// hours until we try to dereference it again
-		// and realize it should be forbidden.
-		return true, nil
+		// If we didn't have the replied-to status
+		// in our database (yet), we can't check
+		// right now if this reply is permitted.
+		//
+		// For now, just return permitted if reply
+		// was not explicitly rejected before; worst-
+		// case, the reply stays on the instance for
+		// a couple hours until we try to deref it
+		// again and realize it should be forbidden.
+		return !thisRejected, nil
 	}
+	// We have the replied-to status; ensure it's fully populated.
+	if err := d.state.DB.PopulateStatus(ctx, inReplyTo); err != nil {
+		return false, gtserror.Newf("error populating status %s: %w", reply.ID, err)
+	}
+	// Make sure replied-to status is not
+	// a boost wrapper, and make sure it's
+	// actually visible to the requester.
 	if inReplyTo.BoostOfID != "" {
-		// We do not permit replies to
-		// boost wrapper statuses. (this
-		// shouldn't be able to happen).
+		// We do not permit replies
+		// to boost wrapper statuses.
 		log.Info(ctx, "rejecting reply to boost wrapper status")
 		return false, nil
 	}
+	// Check visibility of local
+	// inReplyTo to replying account.
 	if inReplyTo.IsLocal() {
 		visible, err := d.visFilter.StatusVisible(ctx,
-			status.Account,
+			reply.Account,
 			inReplyTo,
 		)
 		if err != nil {
@@ -227,9 +218,26 @@ func (d *Dereferencer) isPermittedReply(
 		}
 	}
-	// Check interaction policy of inReplyTo.
+	// If this reply claims to be approved,
+	// validate this by dereferencing the
+	// Accept and checking the return value.
+	// No further checks are required.
+	if acceptIRI != "" {
+		return d.isPermittedByAcceptIRI(
+			ctx,
+			requestUser,
+			reply,
+			inReplyTo,
+			thisReq,
+			acceptIRI,
+		)
+	}
+	// Status doesn't claim to be approved.
+	// Check interaction policy of inReplyTo
+	// to see if it doesn't require approval.
 	replyable, err := d.intFilter.StatusReplyable(ctx,
-		status.Account,
+		reply.Account,
 		inReplyTo,
 	)
 	if err != nil {
@@ -238,93 +246,250 @@ func (d *Dereferencer) isPermittedReply(
 	}
 	if replyable.Forbidden() {
-		// Reply is not permitted.
+		// Reply is not permitted according to policy.
 		//
-		// Insert a pre-rejected interaction request
-		// into the db and return. This ensures that
+		// Either insert a pre-rejected interaction
+		// req into the db, or update the existing
// replies to this now-rejected status aren't // one, and return. This ensures that replies
// inadvertently permitted. // to this rejected reply also aren't permitted.
id := id.NewULID() return false, d.rejectedByPolicy(
rejection := &gtsmodel.InteractionRequest{ ctx,
ID: id, reply,
StatusID: inReplyTo.ID, inReplyTo,
TargetAccountID: inReplyTo.AccountID, thisReq,
InteractingAccountID: status.AccountID, )
InteractionURI: statusURI,
InteractionType: gtsmodel.InteractionReply,
URI: uris.GenerateURIForReject(inReplyTo.Account.Username, id),
RejectedAt: time.Now(),
}
err := d.state.DB.PutInteractionRequest(ctx, rejection)
if err != nil && !errors.Is(err, db.ErrAlreadyExists) {
return false, gtserror.Newf("db error putting pre-rejected interaction request: %w", err)
}
return false, nil
} }
if replyable.Permitted() && // Reply is permitted according to the interaction
!replyable.MatchedOnCollection() { // policy set on the replied-to status (if any).
// Replier is permitted to do this
// interaction, and didn't match on if !replyable.MatchedOnCollection() {
// a collection so we don't need to // If we didn't match on a collection,
// do further checking. // then we don't require an acceptIRI,
// and we don't need to send an Accept;
// just permit the reply full stop.
return true, nil return true, nil
} }
// Replier is permitted to do this // Reply is permitted, but match was made based
// interaction pending approval, or // on inclusion in a followers/following collection.
// permitted but matched on a collection.
// //
// Check if we can dereference // If the status is ours, mark it as PreApproved
// an Accept that grants approval. // so the processor knows to create and send out
// an Accept for it immediately.
if status.ApprovedByURI == "" { if inReplyTo.IsLocal() {
// Status doesn't claim to be approved. reply.PendingApproval = util.Ptr(true)
// reply.PreApproved = true
// For replies to local statuses that's return true, nil
// fine, we can put it in the DB pending
// approval, and continue processing it.
//
// If permission was granted based on a match
// with a followers or following collection,
// we can mark it as PreApproved so the processor
// sends an accept out for it immediately.
//
// For replies to remote statuses, though
// we should be polite and just drop it.
if inReplyTo.IsLocal() {
status.PendingApproval = util.Ptr(true)
status.PreApproved = replyable.MatchedOnCollection()
return true, nil
}
return false, nil
} }
// Status claims to be approved, check // For replies to remote statuses, which matched
// this by dereferencing the Accept and // on a followers/following collection, but did not
// inspecting the return value. // include an acceptIRI, we should just drop it.
if err := d.validateApprovedBy( // It's possible we'll get an Accept for it later
// and we can check everything again.
return false, nil
}
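
As a rough illustration of the order of checks in the refactored isPermittedReply (plain booleans and strings stand in for the real context, database lookups, and gtsmodel types; this is a simplified sketch, not the actual method):

package main

import "fmt"

// replyDecision summarizes, in simplified form, the order of checks that
// isPermittedReply now performs. Every parameter is an illustrative stand-in
// for a database lookup or model field in the real code.
func replyDecision(parentRejected, thisRejected, haveParent bool, acceptIRI string) string {
    switch {
    case parentRejected:
        // Parent was rejected: implicitly reject this reply too.
        return "reject (parent rejected)"
    case thisRejected && acceptIRI == "":
        // Rejected before and nothing has changed since.
        return "reject (still rejected)"
    case !haveParent:
        // Parent not dereferenced yet: permit for now,
        // unless this reply was already rejected before.
        if thisRejected {
            return "reject (still rejected)"
        }
        return "permit for now (parent unknown)"
    case acceptIRI != "":
        // Claims approval: must be validated by dereferencing the Accept.
        return "validate the Accept IRI"
    default:
        // Fall back to the interaction policy of the replied-to status.
        return "check interaction policy"
    }
}

func main() {
    fmt.Println(replyDecision(false, true, true, ""))
}

The real method additionally persists or updates interaction requests via the helper functions below.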
// unpermittedByParent marks the given reply as rejected
// based on the fact that its parent was rejected.
//
// This will create a rejected interaction request for
// the status in the db, if one didn't exist already,
// or update an existing interaction request instead.
func (d *Dereferencer) unpermittedByParent(
ctx context.Context,
reply *gtsmodel.Status,
thisReq *gtsmodel.InteractionRequest,
parentReq *gtsmodel.InteractionRequest,
) error {
if thisReq != nil && thisReq.IsRejected() {
// This interaction request is
// already marked as rejected,
// there's nothing more to do.
return nil
}
if thisReq != nil {
// Before we return, ensure interaction
// request is marked as rejected.
thisReq.RejectedAt = time.Now()
thisReq.AcceptedAt = time.Time{}
err := d.state.DB.UpdateInteractionRequest(
ctx,
thisReq,
"rejected_at",
"accepted_at",
)
if err != nil {
return gtserror.Newf("db error updating interaction request: %w", err)
}
return nil
}
// We haven't stored a rejected interaction
// request for this status yet, do it now.
rejectID := id.NewULID()
// To ensure the Reject chain stays coherent,
// borrow fields from the up-thread rejection.
// This collapses the chain beyond the first
// rejected reply and allows us to avoid derefing
// further replies we already know we don't want.
inReplyToID := parentReq.StatusID
targetAccountID := parentReq.TargetAccountID
// As nobody is actually Rejecting the reply
// directly, but it's an implicit Reject coming
// from our internal logic, don't bother setting
// a URI (it's not a required field anyway).
uri := ""
rejection := &gtsmodel.InteractionRequest{
ID: rejectID,
StatusID: inReplyToID,
TargetAccountID: targetAccountID,
InteractingAccountID: reply.AccountID,
InteractionURI: reply.URI,
InteractionType: gtsmodel.InteractionReply,
URI: uri,
RejectedAt: time.Now(),
}
err := d.state.DB.PutInteractionRequest(ctx, rejection)
if err != nil && !errors.Is(err, db.ErrAlreadyExists) {
return gtserror.Newf("db error putting pre-rejected interaction request: %w", err)
}
return nil
}
// isPermittedByAcceptIRI checks whether the given acceptIRI
// permits the given reply to the given inReplyTo status.
// If yes, then thisReq will be updated to reflect the
// acceptance, if it's not nil.
func (d *Dereferencer) isPermittedByAcceptIRI(
ctx context.Context,
requestUser string,
reply *gtsmodel.Status,
inReplyTo *gtsmodel.Status,
thisReq *gtsmodel.InteractionRequest,
acceptIRI string,
) (bool, error) {
permitted, err := d.isValidAccept(
ctx, ctx,
requestUser, requestUser,
status.ApprovedByURI, acceptIRI,
statusURI, reply.URI,
inReplyTo.AccountURI, inReplyTo.AccountURI,
); err != nil { )
if err != nil {
// Error dereferencing means we couldn't // Error dereferencing means we couldn't
// get the Accept right now or it wasn't // get the Accept right now or it wasn't
// valid, so we shouldn't store this status. // valid, so we shouldn't store this status.
log.Errorf(ctx, "undereferencable ApprovedByURI: %v", err) err := gtserror.Newf("undereferencable ApprovedByURI: %w", err)
return false, err
}
if !permitted {
// It's a no from
// us, squirt.
return false, nil return false, nil
} }
// Status has been approved. // Reply is permitted by this Accept.
status.PendingApproval = util.Ptr(false) // If it was previously rejected or
// pending approval, clear that now.
reply.PendingApproval = util.Ptr(false)
if thisReq != nil {
thisReq.URI = acceptIRI
thisReq.AcceptedAt = time.Now()
thisReq.RejectedAt = time.Time{}
err := d.state.DB.UpdateInteractionRequest(
ctx,
thisReq,
"uri",
"accepted_at",
"rejected_at",
)
if err != nil {
return false, gtserror.Newf("db error updating interaction request: %w", err)
}
}
// All good!
return true, nil return true, nil
} }
func (d *Dereferencer) rejectedByPolicy(
ctx context.Context,
reply *gtsmodel.Status,
inReplyTo *gtsmodel.Status,
thisReq *gtsmodel.InteractionRequest,
) error {
var (
rejectID string
rejectURI string
)
if thisReq != nil {
// Reuse existing ID.
rejectID = thisReq.ID
} else {
// Generate new ID.
rejectID = id.NewULID()
}
if inReplyTo.IsLocal() {
// If this a reply to one of our statuses
// we should generate a URI for the Reject,
// else just use an implicit (empty) URI.
rejectURI = uris.GenerateURIForReject(
inReplyTo.Account.Username,
rejectID,
)
}
if thisReq != nil {
// Before we return, ensure interaction
// request is marked as rejected.
thisReq.RejectedAt = time.Now()
thisReq.AcceptedAt = time.Time{}
thisReq.URI = rejectURI
err := d.state.DB.UpdateInteractionRequest(
ctx,
thisReq,
"rejected_at",
"accepted_at",
"uri",
)
if err != nil {
return gtserror.Newf("db error updating interaction request: %w", err)
}
return nil
}
// We haven't stored a rejected interaction
// request for this status yet, do it now.
rejection := &gtsmodel.InteractionRequest{
ID: rejectID,
StatusID: inReplyTo.ID,
TargetAccountID: inReplyTo.AccountID,
InteractingAccountID: reply.AccountID,
InteractionURI: reply.URI,
InteractionType: gtsmodel.InteractionReply,
URI: rejectURI,
RejectedAt: time.Now(),
}
err := d.state.DB.PutInteractionRequest(ctx, rejection)
if err != nil && !errors.Is(err, db.ErrAlreadyExists) {
return gtserror.Newf("db error putting pre-rejected interaction request: %w", err)
}
return nil
}
func (d *Dereferencer) isPermittedBoost( func (d *Dereferencer) isPermittedBoost(
ctx context.Context, ctx context.Context,
requestUser string, requestUser string,
@ -418,18 +583,22 @@ func (d *Dereferencer) isPermittedBoost(
// Boost claims to be approved, check // Boost claims to be approved, check
// this by dereferencing the Accept and // this by dereferencing the Accept and
// inspecting the return value. // inspecting the return value.
if err := d.validateApprovedBy( permitted, err := d.isValidAccept(
ctx, ctx,
requestUser, requestUser,
status.ApprovedByURI, status.ApprovedByURI,
status.URI, status.URI,
boostOf.AccountURI, boostOf.AccountURI,
); err != nil { )
if err != nil {
// Error dereferencing means we couldn't // Error dereferencing means we couldn't
// get the Accept right now or it wasn't // get the Accept right now or it wasn't
// valid, so we shouldn't store this status. // valid, so we shouldn't store this status.
log.Errorf(ctx, "undereferencable ApprovedByURI: %v", err) err := gtserror.Newf("undereferencable ApprovedByURI: %w", err)
return false, err
}
if !permitted {
return false, nil return false, nil
} }
@ -438,43 +607,59 @@ func (d *Dereferencer) isPermittedBoost(
return true, nil return true, nil
} }
// validateApprovedBy dereferences the activitystreams Accept at // isValidAccept dereferences the activitystreams Accept at the
// the specified IRI, and checks the Accept for validity against // specified IRI, and checks the Accept for validity against the
// the provided expectedObject and expectedActor. // provided expectedObject and expectedActor.
// //
// Will return either nil if everything looked OK, or an error if // Will return either (true, nil) if everything looked OK, an error
// something went wrong during deref, or if the dereffed Accept // if something went wrong internally during deref, or (false, nil)
// did not meet expectations. // if the dereferenced Accept did not meet expectations.
func (d *Dereferencer) validateApprovedBy( func (d *Dereferencer) isValidAccept(
ctx context.Context, ctx context.Context,
requestUser string, requestUser string,
approvedByURIStr string, // Eg., "https://example.org/users/someone/accepts/01J2736AWWJ3411CPR833F6D03" acceptIRIStr string, // Eg., "https://example.org/users/someone/accepts/01J2736AWWJ3411CPR833F6D03"
expectObjectURIStr string, // Eg., "https://some.instance.example.org/users/someone_else/statuses/01J27414TWV9F7DC39FN8ABB5R" expectObjectURIStr string, // Eg., "https://some.instance.example.org/users/someone_else/statuses/01J27414TWV9F7DC39FN8ABB5R"
expectActorURIStr string, // Eg., "https://example.org/users/someone" expectActorURIStr string, // Eg., "https://example.org/users/someone"
) error { ) (bool, error) {
approvedByURI, err := url.Parse(approvedByURIStr) l := log.
WithContext(ctx).
WithField("acceptIRI", acceptIRIStr)
acceptIRI, err := url.Parse(acceptIRIStr)
if err != nil { if err != nil {
err := gtserror.Newf("error parsing approvedByURI: %w", err) // Real returnable error.
return err err := gtserror.Newf("error parsing acceptIRI: %w", err)
return false, err
} }
// Don't make calls to the remote if it's blocked. // Don't make calls to the Accept IRI
if blocked, err := d.state.DB.IsDomainBlocked(ctx, approvedByURI.Host); blocked || err != nil { // if it's blocked, just return false.
err := gtserror.Newf("domain %s is blocked", approvedByURI.Host) blocked, err := d.state.DB.IsDomainBlocked(ctx, acceptIRI.Host)
return err if err != nil {
// Real returnable error.
err := gtserror.Newf("error checking domain block: %w", err)
return false, err
}
if blocked {
l.Info("Accept host is blocked")
return false, nil
} }
tsport, err := d.transportController.NewTransportForUsername(ctx, requestUser) tsport, err := d.transportController.NewTransportForUsername(ctx, requestUser)
if err != nil { if err != nil {
// Real returnable error.
err := gtserror.Newf("error creating transport: %w", err) err := gtserror.Newf("error creating transport: %w", err)
return err return false, err
} }
// Make the call to resolve into an Acceptable. // Make the call to resolve into an Acceptable.
rsp, err := tsport.Dereference(ctx, approvedByURI) // Log any error encountered here but don't
// return it as it's not *our* error.
rsp, err := tsport.Dereference(ctx, acceptIRI)
if err != nil { if err != nil {
err := gtserror.Newf("error dereferencing %s: %w", approvedByURIStr, err) l.Errorf("error dereferencing Accept: %v", err)
return err return false, nil
} }
acceptable, err := ap.ResolveAcceptable(ctx, rsp.Body) acceptable, err := ap.ResolveAcceptable(ctx, rsp.Body)
@ -483,66 +668,71 @@ func (d *Dereferencer) validateApprovedBy(
_ = rsp.Body.Close() _ = rsp.Body.Close()
if err != nil { if err != nil {
err := gtserror.Newf("error resolving Accept %s: %w", approvedByURIStr, err) l.Errorf("error resolving to Accept: %v", err)
return err return false, err
} }
// Extract the URI/ID of the Accept. // Extract the URI/ID of the Accept.
acceptURI := ap.GetJSONLDId(acceptable) acceptID := ap.GetJSONLDId(acceptable)
acceptURIStr := acceptURI.String() acceptIDStr := acceptID.String()
// Check whether input URI and final returned URI // Check whether input URI and final returned URI
// have changed (i.e. we followed some redirects). // have changed (i.e. we followed some redirects).
rspURL := rsp.Request.URL rspURL := rsp.Request.URL
rspURLStr := rspURL.String() rspURLStr := rspURL.String()
switch { if rspURLStr != acceptIRIStr {
case rspURLStr == approvedByURIStr: // If rspURLStr != acceptIRIStr, make sure final
// response URL is at least on the same host as
// what we expected (ie., we weren't redirected
// across domains), and make sure it's the same
// as the ID of the Accept we were returned.
switch {
case rspURL.Host != acceptIRI.Host:
l.Errorf(
"final deref host %s did not match acceptIRI host",
rspURL.Host,
)
return false, nil
// i.e. from here, rspURLStr != approvedByURIStr. case acceptIDStr != rspURLStr:
// l.Errorf(
// Make sure it's at least on the same host as "final deref uri %s did not match returned Accept ID %s",
// what we expected (ie., we weren't redirected rspURLStr, acceptIDStr,
// across domains), and make sure it's the same )
// as the ID of the Accept we were returned. return false, nil
case rspURL.Host != approvedByURI.Host: }
return gtserror.Newf(
"final dereference host %s did not match approvedByURI host %s",
rspURL.Host, approvedByURI.Host,
)
case acceptURIStr != rspURLStr:
return gtserror.Newf(
"final dereference uri %s did not match returned Accept ID/URI %s",
rspURLStr, acceptURIStr,
)
} }
// Response is superficially OK,
// check in more detail now.
// Extract the actor IRI and string from Accept. // Extract the actor IRI and string from Accept.
actorIRIs := ap.GetActorIRIs(acceptable) actorIRIs := ap.GetActorIRIs(acceptable)
actorIRI, actorIRIStr := extractIRI(actorIRIs) actorIRI, actorIRIStr := extractIRI(actorIRIs)
switch { switch {
case actorIRIStr == "": case actorIRIStr == "":
err := gtserror.New("missing Accept actor IRI") l.Error("Accept missing actor IRI")
return gtserror.SetMalformed(err) return false, nil
// Ensure the Accept Actor is who we expect // Ensure the Accept Actor is on
// it to be, and not someone else trying to // the instance hosting the Accept.
// do an Accept for an interaction with a case actorIRI.Host != acceptID.Host:
// statusable they don't own. l.Errorf(
case actorIRI.Host != acceptURI.Host: "actor %s not on the same host as Accept",
return gtserror.Newf( actorIRIStr,
"Accept Actor %s was not the same host as Accept %s",
actorIRIStr, acceptURIStr,
) )
return false, nil
// Ensure the Accept Actor is who we expect // Ensure the Accept Actor is who we expect
// it to be, and not someone else trying to // it to be, and not someone else trying to
// do an Accept for an interaction with a // do an Accept for an interaction with a
// statusable they don't own. // statusable they don't own.
case actorIRIStr != expectActorURIStr: case actorIRIStr != expectActorURIStr:
return gtserror.Newf( l.Errorf(
"Accept Actor %s was not the same as expected actor %s", "actor %s was not the same as expected actor %s",
actorIRIStr, expectActorURIStr, actorIRIStr, expectActorURIStr,
) )
return false, nil
} }
// Extract the object IRI string from Accept. // Extract the object IRI string from Accept.
@ -550,20 +740,22 @@ func (d *Dereferencer) validateApprovedBy(
_, objectIRIStr := extractIRI(objectIRIs) _, objectIRIStr := extractIRI(objectIRIs)
switch { switch {
case objectIRIStr == "": case objectIRIStr == "":
err := gtserror.New("missing Accept object IRI") l.Error("missing Accept object IRI")
return gtserror.SetMalformed(err) return false, nil
// Ensure the Accept Object is what we expect // Ensure the Accept Object is what we expect
// it to be, ie., it's Accepting the interaction // it to be, ie., it's Accepting the interaction
// we need it to Accept, and not something else. // we need it to Accept, and not something else.
case objectIRIStr != expectObjectURIStr: case objectIRIStr != expectObjectURIStr:
return gtserror.Newf( l.Errorf(
"resolved Accept Object uri %s was not the same as expected object %s", "resolved Accept object IRI %s was not the same as expected object %s",
objectIRIStr, expectObjectURIStr, objectIRIStr, expectObjectURIStr,
) )
return false, nil
} }
return nil // Everything looks OK.
return true, nil
} }
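
To illustrate just the redirect handling added above, here is a minimal self-contained sketch of the same check (hypothetical helper name and plain string inputs; the real method operates on the dereferenced response and the resolved Accept):

package main

import (
    "fmt"
    "net/url"
)

// checkAcceptDeref mirrors, in simplified form, the redirect checks in
// isValidAccept: if the final response URL differs from the requested Accept
// IRI, it must stay on the same host and must equal the Accept's JSON-LD ID.
func checkAcceptDeref(acceptIRI, finalURL, acceptID string) (bool, error) {
    reqURL, err := url.Parse(acceptIRI)
    if err != nil {
        return false, err
    }
    rspURL, err := url.Parse(finalURL)
    if err != nil {
        return false, err
    }
    if finalURL == acceptIRI {
        // No redirects were followed.
        return true, nil
    }
    if rspURL.Host != reqURL.Host {
        // Redirected across domains: not acceptable.
        return false, nil
    }
    if acceptID != finalURL {
        // Returned Accept ID doesn't match where we ended up.
        return false, nil
    }
    return true, nil
}

func main() {
    const iri = "https://example.org/users/someone/accepts/01J2736AWWJ3411CPR833F6D03"
    ok, _ := checkAcceptDeref(iri, iri, iri)
    fmt.Println(ok) // true
}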
// extractIRI is shorthand to extract the first IRI // extractIRI is shorthand to extract the first IRI

View file

@ -170,12 +170,6 @@ func (f *federatingActor) PostInboxScheme(ctx context.Context, w http.ResponseWr
// //
// Post the activity to the Actor's inbox and trigger side effects. // Post the activity to the Actor's inbox and trigger side effects.
if err := f.sideEffectActor.PostInbox(ctx, inboxID, activity); err != nil { if err := f.sideEffectActor.PostInbox(ctx, inboxID, activity); err != nil {
// Check if a function in the federatingDB
// has returned an explicit errWithCode for us.
if errWithCode, ok := err.(gtserror.WithCode); ok {
return false, errWithCode
}
// Check if it's a bad request because the // Check if it's a bad request because the
// object or target props weren't populated, // object or target props weren't populated,
// or we failed parsing activity details. // or we failed parsing activity details.
@ -193,6 +187,12 @@ func (f *federatingActor) PostInboxScheme(ctx context.Context, w http.ResponseWr
return false, gtserror.NewErrorBadRequest(errors.New(text), text) return false, gtserror.NewErrorBadRequest(errors.New(text), text)
} }
// Check if a function in the federatingDB
// has returned an explicit errWithCode for us.
if errWithCode, ok := err.(gtserror.WithCode); ok {
return false, errWithCode
}
// Default: there's been some real error. // Default: there's been some real error.
err := gtserror.Newf("error calling sideEffectActor.PostInbox: %w", err) err := gtserror.Newf("error calling sideEffectActor.PostInbox: %w", err)
return false, gtserror.NewErrorInternalError(err) return false, gtserror.NewErrorInternalError(err)

View file

@ -24,6 +24,7 @@
"github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/ap" "github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -68,6 +69,20 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA
return gtserror.NewErrorBadRequest(errors.New(text), text) return gtserror.NewErrorBadRequest(errors.New(text), text)
} }
// Ensure requester is the same as the
// Actor of the Accept; you can't Accept
// something on someone else's behalf.
actorURI, err := ap.ExtractActorURI(accept)
if err != nil {
const text = "Accept had empty or invalid actor property"
return gtserror.NewErrorBadRequest(errors.New(text), text)
}
if requestingAcct.URI != actorURI.String() {
const text = "Accept actor and requesting account were not the same"
return gtserror.NewErrorBadRequest(errors.New(text), text)
}
// Iterate all provided objects in the activity, // Iterate all provided objects in the activity,
// handling the ones we know how to handle. // handling the ones we know how to handle.
for _, object := range ap.ExtractObjects(accept) { for _, object := range ap.ExtractObjects(accept) {
@ -108,18 +123,6 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA
return err return err
} }
// ACCEPT STATUS (reply/boost)
case uris.IsStatusesPath(objIRI):
if err := f.acceptStatusIRI(
ctx,
activityID.String(),
objIRI.String(),
receivingAcct,
requestingAcct,
); err != nil {
return err
}
// ACCEPT LIKE // ACCEPT LIKE
case uris.IsLikePath(objIRI): case uris.IsLikePath(objIRI):
if err := f.acceptLikeIRI( if err := f.acceptLikeIRI(
@ -132,9 +135,20 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA
return err return err
} }
// UNHANDLED // ACCEPT OTHER (reply? boost?)
//
// Don't check on IsStatusesPath
// as this may be a remote status.
default: default:
log.Debugf(ctx, "unhandled iri type: %s", objIRI) if err := f.acceptOtherIRI(
ctx,
activityID,
objIRI,
receivingAcct,
requestingAcct,
); err != nil {
return err
}
} }
} }
} }
@ -276,39 +290,91 @@ func (f *federatingDB) acceptFollowIRI(
return nil return nil
} }
func (f *federatingDB) acceptStatusIRI( func (f *federatingDB) acceptOtherIRI(
ctx context.Context, ctx context.Context,
activityID string, activityID *url.URL,
objectIRI string, objectIRI *url.URL,
receivingAcct *gtsmodel.Account, receivingAcct *gtsmodel.Account,
requestingAcct *gtsmodel.Account, requestingAcct *gtsmodel.Account,
) error { ) error {
// Lock on this potential status // See if we can get a status from the db.
// URI as we may be updating it. status, err := f.state.DB.GetStatusByURI(ctx, objectIRI.String())
unlock := f.state.FedLocks.Lock(objectIRI)
defer unlock()
// Get the status from the db.
status, err := f.state.DB.GetStatusByURI(ctx, objectIRI)
if err != nil && !errors.Is(err, db.ErrNoEntries) { if err != nil && !errors.Is(err, db.ErrNoEntries) {
err := gtserror.Newf("db error getting status: %w", err) err := gtserror.Newf("db error getting status: %w", err)
return gtserror.NewErrorInternalError(err) return gtserror.NewErrorInternalError(err)
} }
if status == nil { if status != nil {
// We didn't have a status with // We had a status stored with this
// this URI, so nothing to do. // objectIRI, proceed to accept it.
// Just return. return f.acceptStoredStatus(
ctx,
activityID,
status,
receivingAcct,
requestingAcct,
)
}
if objectIRI.Host == config.GetHost() ||
objectIRI.Host == config.GetAccountDomain() {
// Claims to be Accepting something of ours,
// but we don't have a status stored for this
// URI, so most likely it's been deleted in
// the meantime, just bail.
return nil return nil
} }
if !status.IsLocal() { // This must be an Accept of a remote Activity
// We don't process Accepts of statuses // or Object. Ensure relevance of this message
// that weren't created on our instance. // by checking that receiver follows requester.
// Just return. following, err := f.state.DB.IsFollowing(
ctx,
receivingAcct.ID,
requestingAcct.ID,
)
if err != nil {
err := gtserror.Newf("db error checking following: %w", err)
return gtserror.NewErrorInternalError(err)
}
if !following {
// If we don't follow this person, and
// they're not Accepting something we know
// about, then we don't give a good goddamn.
return nil return nil
} }
// This may be a reply, or it may be a boost,
// we can't know yet without dereferencing it,
// but let the processor worry about that.
apObjectType := ap.ObjectUnknown
// Pass to the processor and let them handle side effects.
f.state.Workers.Federator.Queue.Push(&messages.FromFediAPI{
APObjectType: apObjectType,
APActivityType: ap.ActivityAccept,
APIRI: activityID,
APObject: objectIRI,
Receiving: receivingAcct,
Requesting: requestingAcct,
})
return nil
}
func (f *federatingDB) acceptStoredStatus(
ctx context.Context,
activityID *url.URL,
status *gtsmodel.Status,
receivingAcct *gtsmodel.Account,
requestingAcct *gtsmodel.Account,
) error {
// Lock on this status URI
// as we may be updating it.
unlock := f.state.FedLocks.Lock(status.URI)
defer unlock()
pendingApproval := util.PtrOrValue(status.PendingApproval, false) pendingApproval := util.PtrOrValue(status.PendingApproval, false)
if !pendingApproval { if !pendingApproval {
// Status doesn't need approval or it's // Status doesn't need approval or it's
@ -317,14 +383,6 @@ func (f *federatingDB) acceptStatusIRI(
return nil return nil
} }
// Make sure the creator of the original status
// is the same as the inbox processing the Accept;
// this also ensures the status is local.
if status.AccountID != receivingAcct.ID {
const text = "status author account and inbox account were not the same"
return gtserror.NewErrorUnprocessableEntity(errors.New(text), text)
}
// Make sure the target of the interaction (reply/boost) // Make sure the target of the interaction (reply/boost)
// is the same as the account doing the Accept. // is the same as the account doing the Accept.
if status.BoostOfAccountID != requestingAcct.ID && if status.BoostOfAccountID != requestingAcct.ID &&
@ -335,7 +393,7 @@ func (f *federatingDB) acceptStatusIRI(
// Mark the status as approved by this Accept URI. // Mark the status as approved by this Accept URI.
status.PendingApproval = util.Ptr(false) status.PendingApproval = util.Ptr(false)
status.ApprovedByURI = activityID status.ApprovedByURI = activityID.String()
if err := f.state.DB.UpdateStatus( if err := f.state.DB.UpdateStatus(
ctx, ctx,
status, status,

View file

@ -306,7 +306,7 @@ func (f *Filter) StatusBoostable(
status.InteractionPolicy.CanAnnounce, status.InteractionPolicy.CanAnnounce,
) )
// If status is local and has no policy set, // If status has no policy set but it's local,
// check against the default policy for this // check against the default policy for this
// visibility, as we're interaction-policy aware. // visibility, as we're interaction-policy aware.
case *status.Local: case *status.Local:
@ -318,13 +318,21 @@ func (f *Filter) StatusBoostable(
policy.CanAnnounce, policy.CanAnnounce,
) )
// Otherwise, assume the status is from an // Status is from an instance that does not use
// instance that does not use / does not care // or does not care about interaction policies.
// about interaction policies, and just return OK. // We can boost it if it's unlisted or public.
default: case status.Visibility == gtsmodel.VisibilityPublic ||
status.Visibility == gtsmodel.VisibilityUnlocked:
return &gtsmodel.PolicyCheckResult{ return &gtsmodel.PolicyCheckResult{
Permission: gtsmodel.PolicyPermissionPermitted, Permission: gtsmodel.PolicyPermissionPermitted,
}, nil }, nil
// Not permitted by any of the
// above checks, so it's forbidden.
default:
return &gtsmodel.PolicyCheckResult{
Permission: gtsmodel.PolicyPermissionForbidden,
}, nil
} }
} }
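
In other words, a status from an instance that isn't interaction-policy aware, and that has no policy set, is now only boostable when it is public or unlisted. A minimal sketch of that fallthrough, using plain strings rather than the real gtsmodel visibility constants:

package main

import "fmt"

// boostableWithoutPolicy is an illustrative stand-in for the new default
// behaviour: with no interaction policy and a non-local status, only public
// and unlisted (unlocked) visibilities are permitted; everything else is
// forbidden.
func boostableWithoutPolicy(visibility string) bool {
    switch visibility {
    case "public", "unlocked":
        return true
    default:
        return false
    }
}

func main() {
    fmt.Println(boostableWithoutPolicy("public"))    // true
    fmt.Println(boostableWithoutPolicy("followers")) // false
}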

View file

@ -191,6 +191,19 @@ func NewErrorGone(original error, helpText ...string) WithCode {
} }
} }
// NewErrorNotImplemented returns an ErrorWithCode 501 with the given original error and optional help text.
func NewErrorNotImplemented(original error, helpText ...string) WithCode {
safe := http.StatusText(http.StatusNotImplemented)
if helpText != nil {
safe = safe + ": " + strings.Join(helpText, ": ")
}
return withCode{
original: original,
safe: errors.New(safe),
code: http.StatusNotImplemented,
}
}
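
A hedged usage sketch of the new helper: a handler for a feature the server does not support could surface a 501 roughly as below. The helper is re-implemented locally so the example compiles on its own; the real gtserror version also wraps the original error inside its WithCode value.

package main

import (
    "errors"
    "fmt"
    "net/http"
    "strings"
)

// newErrorNotImplemented mirrors the helper added above. The original error
// is accepted (and retained by the real helper for internal logging) but
// unused in this sketch; only the safe, user-facing text is built here.
func newErrorNotImplemented(original error, helpText ...string) (code int, safe error) {
    msg := http.StatusText(http.StatusNotImplemented)
    if len(helpText) != 0 {
        msg = msg + ": " + strings.Join(helpText, ": ")
    }
    return http.StatusNotImplemented, errors.New(msg)
}

func main() {
    code, safe := newErrorNotImplemented(
        errors.New("scheduled_at was set on the request"),
        "scheduled statuses are not supported",
    )
    fmt.Println(code, safe) // 501 Not Implemented: scheduled statuses are not supported
}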
// NewErrorClientClosedRequest returns an ErrorWithCode 499 with the given original error. // NewErrorClientClosedRequest returns an ErrorWithCode 499 with the given original error.
// This error type should only be used when an http caller has already hung up their request. // This error type should only be used when an http caller has already hung up their request.
// See: https://en.wikipedia.org/wiki/List_of_HTTP_status_codes#nginx // See: https://en.wikipedia.org/wiki/List_of_HTTP_status_codes#nginx

View file

@ -69,25 +69,29 @@ type InteractionRequest struct {
Like *StatusFave `bun:"-"` // Not stored in DB. Only set if InteractionType = InteractionLike. Like *StatusFave `bun:"-"` // Not stored in DB. Only set if InteractionType = InteractionLike.
Reply *Status `bun:"-"` // Not stored in DB. Only set if InteractionType = InteractionReply. Reply *Status `bun:"-"` // Not stored in DB. Only set if InteractionType = InteractionReply.
Announce *Status `bun:"-"` // Not stored in DB. Only set if InteractionType = InteractionAnnounce. Announce *Status `bun:"-"` // Not stored in DB. Only set if InteractionType = InteractionAnnounce.
URI string `bun:",nullzero,unique"` // ActivityPub URI of the Accept (if accepted) or Reject (if rejected). Null/empty if currently neither accepted not rejected.
AcceptedAt time.Time `bun:"type:timestamptz,nullzero"` // If interaction request was accepted, time at which this occurred. AcceptedAt time.Time `bun:"type:timestamptz,nullzero"` // If interaction request was accepted, time at which this occurred.
RejectedAt time.Time `bun:"type:timestamptz,nullzero"` // If interaction request was rejected, time at which this occurred. RejectedAt time.Time `bun:"type:timestamptz,nullzero"` // If interaction request was rejected, time at which this occurred.
// ActivityPub URI of the Accept (if accepted) or Reject (if rejected).
// Field may be empty if currently neither accepted nor rejected, or if
// acceptance/rejection was implicit (ie., not resulting from an Activity).
URI string `bun:",nullzero,unique"`
} }
// IsHandled returns true if interaction // IsHandled returns true if interaction
// request has been neither accepted or rejected. // request has been neither accepted or rejected.
// request has been neither accepted or rejected. // request has been neither accepted or rejected.
func (ir *InteractionRequest) IsPending() bool { func (ir *InteractionRequest) IsPending() bool {
return ir.URI == "" && ir.AcceptedAt.IsZero() && ir.RejectedAt.IsZero() return !ir.IsAccepted() && !ir.IsRejected()
} }
// IsAccepted returns true if this // IsAccepted returns true if this
// interaction request has been accepted. // interaction request has been accepted.
func (ir *InteractionRequest) IsAccepted() bool { func (ir *InteractionRequest) IsAccepted() bool {
return ir.URI != "" && !ir.AcceptedAt.IsZero() return !ir.AcceptedAt.IsZero()
} }
// IsRejected returns true if this // IsRejected returns true if this
// interaction request has been rejected. // interaction request has been rejected.
func (ir *InteractionRequest) IsRejected() bool { func (ir *InteractionRequest) IsRejected() bool {
return ir.URI != "" && !ir.RejectedAt.IsZero() return !ir.RejectedAt.IsZero()
} }
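
The practical effect of dropping the URI requirement: an implicit acceptance or rejection (timestamp set, but no Accept/Reject Activity URI) now counts as handled rather than pending. A small self-contained sketch of the new semantics, using a pared-down stand-in for the real gtsmodel type:

package main

import (
    "fmt"
    "time"
)

// interactionRequest is a pared-down stand-in for gtsmodel.InteractionRequest,
// showing only the fields the state helpers look at.
type interactionRequest struct {
    URI        string
    AcceptedAt time.Time
    RejectedAt time.Time
}

func (ir *interactionRequest) IsAccepted() bool { return !ir.AcceptedAt.IsZero() }
func (ir *interactionRequest) IsRejected() bool { return !ir.RejectedAt.IsZero() }
func (ir *interactionRequest) IsPending() bool  { return !ir.IsAccepted() && !ir.IsRejected() }

func main() {
    // Implicit rejection: no Reject URI, but rejected_at is set.
    ir := &interactionRequest{RejectedAt: time.Now()}
    fmt.Println(ir.IsRejected(), ir.IsPending()) // true false
}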

View file

@ -340,14 +340,14 @@ func (c *Client) do(r *Request) (rsp *http.Response, retry bool, err error) {
if u, _ := strconv.ParseUint(after, 10, 32); u != 0 { if u, _ := strconv.ParseUint(after, 10, 32); u != 0 {
// An integer no. of backoff seconds was provided. // An integer no. of backoff seconds was provided.
r.backoff = time.Duration(u) * time.Second r.backoff = time.Duration(u) * time.Second // #nosec G115 -- We clamp backoff below.
} else if at, _ := http.ParseTime(after); !at.Before(now) { } else if at, _ := http.ParseTime(after); !at.Before(now) {
// An HTTP formatted future date-time was provided. // An HTTP formatted future date-time was provided.
r.backoff = at.Sub(now) r.backoff = at.Sub(now)
} }
// Don't let their provided backoff exceed our max. // Don't let their provided backoff exceed our max.
if max := baseBackoff * time.Duration(c.retries); // if max := baseBackoff * time.Duration(c.retries); // #nosec G115 -- We control c.retries.
r.backoff > max { r.backoff > max {
r.backoff = max r.backoff = max
} }
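
For context, the Retry-After handling touched here accepts either an integer number of seconds or an HTTP date, then clamps the result to a maximum backoff. A simplified, self-contained sketch of that logic (the maximum and the function name are illustrative, not the real client internals):

package main

import (
    "fmt"
    "net/http"
    "strconv"
    "time"
)

// parseBackoff accepts a Retry-After value as either an integer number of
// seconds or an HTTP date, and clamps the resulting backoff to max.
func parseBackoff(after string, now time.Time, max time.Duration) time.Duration {
    var backoff time.Duration
    if u, _ := strconv.ParseUint(after, 10, 32); u != 0 {
        // An integer number of backoff seconds was provided.
        backoff = time.Duration(u) * time.Second
    } else if at, _ := http.ParseTime(after); !at.Before(now) {
        // An HTTP-formatted future date-time was provided.
        backoff = at.Sub(now)
    }
    if backoff > max {
        backoff = max
    }
    return backoff
}

func main() {
    now := time.Now()
    fmt.Println(parseBackoff("120", now, 30*time.Second)) // clamped to 30s
}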

View file

@ -21,6 +21,7 @@
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"math"
"os" "os"
"path" "path"
"strconv" "strconv"
@ -556,10 +557,18 @@ func (res *ffprobeResult) Process() (*result, error) {
if p := strings.SplitN(str, "/", 2); len(p) == 2 { if p := strings.SplitN(str, "/", 2); len(p) == 2 {
n, _ := strconv.ParseUint(p[0], 10, 32) n, _ := strconv.ParseUint(p[0], 10, 32)
d, _ := strconv.ParseUint(p[1], 10, 32) d, _ := strconv.ParseUint(p[1], 10, 32)
num, den = uint32(n), uint32(d)
if n > math.MaxUint32 || d > math.MaxUint32 {
return nil, gtserror.Newf("overflowed numerator or denominator")
}
num, den = uint32(n), uint32(d) // #nosec G115 -- Just checked.
} else { } else {
n, _ := strconv.ParseUint(p[0], 10, 32) n, _ := strconv.ParseUint(p[0], 10, 32)
num = uint32(n)
if n > math.MaxUint32 {
return nil, gtserror.Newf("overflowed numerator")
}
num = uint32(n) // #nosec G115 -- Just checked.
} }
// Set final divised framerate. // Set final divised framerate.
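
The new checks guard the narrowing conversion from a parsed 64-bit value into uint32. A minimal sketch of the same bounds-check-then-narrow pattern (parsing with a 64-bit size first so the overflow check is meaningful):

package main

import (
    "fmt"
    "math"
    "strconv"
)

// parseUint32 parses a decimal string and refuses values that would overflow
// uint32, so the final conversion is known to be safe.
func parseUint32(s string) (uint32, error) {
    n, err := strconv.ParseUint(s, 10, 64)
    if err != nil {
        return 0, err
    }
    if n > math.MaxUint32 {
        return 0, fmt.Errorf("value %d overflows uint32", n)
    }
    return uint32(n), nil
}

func main() {
    v, err := parseUint32("30000")
    fmt.Println(v, err) // 30000 <nil>
}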

View file

@ -20,6 +20,7 @@
import ( import (
"context" "context"
"codeberg.org/gruf/go-ffmpreg/wasm"
"github.com/tetratelabs/wazero" "github.com/tetratelabs/wazero"
) )
@ -65,6 +66,6 @@ func (r *runner) Run(ctx context.Context, cmod wazero.CompiledModule, args Args)
// Release slot back to pool on end. // Release slot back to pool on end.
defer func() { r.pool <- struct{}{} }() defer func() { r.pool <- struct{}{} }()
// Pass to main module runner. // Pass to main module runner function.
return run(ctx, cmod, args) return wasm.Run(ctx, runtime, cmod, args)
} }

View file

@ -19,20 +19,18 @@
import ( import (
"context" "context"
"io"
"os" "os"
ffmpeglib "codeberg.org/gruf/go-ffmpreg/embed/ffmpeg" ffmpeglib "codeberg.org/gruf/go-ffmpreg/embed/ffmpeg"
ffprobelib "codeberg.org/gruf/go-ffmpreg/embed/ffprobe" ffprobelib "codeberg.org/gruf/go-ffmpreg/embed/ffprobe"
"codeberg.org/gruf/go-ffmpreg/wasm"
"github.com/tetratelabs/wazero" "github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1" "github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
"github.com/tetratelabs/wazero/sys"
) )
// Use all core features required by ffmpeg / ffprobe // Use all core features required by ffmpeg / ffprobe
// (these should be the same but we OR just in case). // (these should be the same but we OR just in case).
const corefeatures = ffprobelib.CoreFeatures | const corefeatures = wasm.CoreFeatures
ffmpeglib.CoreFeatures
var ( var (
// shared WASM runtime instance. // shared WASM runtime instance.
@ -47,65 +45,7 @@
// configuration options to run an instance // configuration options to run an instance
// of a compiled WebAssembly module that is // of a compiled WebAssembly module that is
// run in a typical CLI manner. // run in a typical CLI manner.
type Args struct { type Args = wasm.Args
// Optional further module configuration function.
// (e.g. to mount filesystem dir, set env vars, etc).
Config func(wazero.ModuleConfig) wazero.ModuleConfig
// Standard FDs.
Stdin io.Reader
Stdout io.Writer
Stderr io.Writer
// CLI args.
Args []string
}
// run will run the given compiled
// WebAssembly module using given args,
// using the global wazero runtime.
func run(
ctx context.Context,
cmod wazero.CompiledModule,
args Args,
) (
uint32, // exit code
error,
) {
// Prefix module name as argv0 to args.
cargs := make([]string, len(args.Args)+1)
copy(cargs[1:], args.Args)
cargs[0] = cmod.Name()
// Create base module config.
modcfg := wazero.NewModuleConfig()
modcfg = modcfg.WithArgs(cargs...)
modcfg = modcfg.WithStdin(args.Stdin)
modcfg = modcfg.WithStdout(args.Stdout)
modcfg = modcfg.WithStderr(args.Stderr)
if args.Config != nil {
// Pass through config fn.
modcfg = args.Config(modcfg)
}
// Instantiate the module from precompiled wasm module data.
mod, err := runtime.InstantiateModule(ctx, cmod, modcfg)
if mod != nil {
// Ensure closed.
_ = mod.Close(ctx)
}
// Try extract exit code.
switch err := err.(type) {
case *sys.ExitError:
return err.ExitCode(), nil
default:
return 0, err
}
}
// compileFfmpeg ensures the ffmpeg WebAssembly has been // compileFfmpeg ensures the ffmpeg WebAssembly has been
// pre-compiled into memory. If already compiled is a no-op. // pre-compiled into memory. If already compiled is a no-op.

View file

@ -399,9 +399,9 @@ func (s *scanner) scan(x1, y1, x2, y2 int, dst []uint8) {
g16 := uint16(s[1]) g16 := uint16(s[1])
b16 := uint16(s[2]) b16 := uint16(s[2])
a16 := uint16(a) a16 := uint16(a)
d[0] = uint8(r16 * 0xff / a16) d[0] = uint8(r16 * 0xff / a16) // #nosec G115 -- Overflow desired.
d[1] = uint8(g16 * 0xff / a16) d[1] = uint8(g16 * 0xff / a16) // #nosec G115 -- Overflow desired.
d[2] = uint8(b16 * 0xff / a16) d[2] = uint8(b16 * 0xff / a16) // #nosec G115 -- Overflow desired.
d[3] = a d[3] = a
} }
j += 4 j += 4
@ -431,9 +431,9 @@ func (s *scanner) scan(x1, y1, x2, y2 int, dst []uint8) {
g32 := uint32(s[2])<<8 | uint32(s[3]) g32 := uint32(s[2])<<8 | uint32(s[3])
b32 := uint32(s[4])<<8 | uint32(s[5]) b32 := uint32(s[4])<<8 | uint32(s[5])
a32 := uint32(s[6])<<8 | uint32(s[7]) a32 := uint32(s[6])<<8 | uint32(s[7])
d[0] = uint8((r32 * 0xffff / a32) >> 8) d[0] = uint8((r32 * 0xffff / a32) >> 8) // #nosec G115 -- Overflow desired.
d[1] = uint8((g32 * 0xffff / a32) >> 8) d[1] = uint8((g32 * 0xffff / a32) >> 8) // #nosec G115 -- Overflow desired.
d[2] = uint8((b32 * 0xffff / a32) >> 8) d[2] = uint8((b32 * 0xffff / a32) >> 8) // #nosec G115 -- Overflow desired.
} }
d[3] = a d[3] = a
j += 4 j += 4
@ -530,9 +530,9 @@ func (s *scanner) scan(x1, y1, x2, y2 int, dst []uint8) {
} }
d := dst[j : j+4 : j+4] d := dst[j : j+4 : j+4]
d[0] = uint8(r) d[0] = uint8(r) // #nosec G115 -- Overflow desired.
d[1] = uint8(g) d[1] = uint8(g) // #nosec G115 -- Overflow desired.
d[2] = uint8(b) d[2] = uint8(b) // #nosec G115 -- Overflow desired.
d[3] = 0xff d[3] = 0xff
iy++ iy++
@ -569,9 +569,9 @@ func (s *scanner) scan(x1, y1, x2, y2 int, dst []uint8) {
d := dst[j : j+4 : j+4] d := dst[j : j+4 : j+4]
switch a16 { switch a16 {
case 0xffff: case 0xffff:
d[0] = uint8(r16 >> 8) d[0] = uint8(r16 >> 8) // #nosec G115 -- Overflow desired.
d[1] = uint8(g16 >> 8) d[1] = uint8(g16 >> 8) // #nosec G115 -- Overflow desired.
d[2] = uint8(b16 >> 8) d[2] = uint8(b16 >> 8) // #nosec G115 -- Overflow desired.
d[3] = 0xff d[3] = 0xff
case 0: case 0:
d[0] = 0 d[0] = 0
@ -579,10 +579,10 @@ func (s *scanner) scan(x1, y1, x2, y2 int, dst []uint8) {
d[2] = 0 d[2] = 0
d[3] = 0 d[3] = 0
default: default:
d[0] = uint8(((r16 * 0xffff) / a16) >> 8) d[0] = uint8(((r16 * 0xffff) / a16) >> 8) // #nosec G115 -- Overflow desired.
d[1] = uint8(((g16 * 0xffff) / a16) >> 8) d[1] = uint8(((g16 * 0xffff) / a16) >> 8) // #nosec G115 -- Overflow desired.
d[2] = uint8(((b16 * 0xffff) / a16) >> 8) d[2] = uint8(((b16 * 0xffff) / a16) >> 8) // #nosec G115 -- Overflow desired.
d[3] = uint8(a16 >> 8) d[3] = uint8(a16 >> 8) // #nosec G115 -- Overflow desired.
} }
j += 4 j += 4
} }
@ -617,7 +617,7 @@ func clampFloat(x float64) uint8 {
return 255 return 255
} }
if v > 0 { if v > 0 {
return uint8(v) return uint8(v) // #nosec G115 -- Just checked.
} }
return 0 return 0
} }

View file

@ -49,9 +49,6 @@ func (m *Manager) RefetchEmojis(ctx context.Context, domain string, dereferenceM
refetchIDs []string refetchIDs []string
) )
// Get max supported remote emoji media size.
maxsz := config.GetMediaEmojiRemoteMaxSize()
// page through emojis 20 at a time, looking for those with missing images // page through emojis 20 at a time, looking for those with missing images
for { for {
// Fetch next block of emojis from database // Fetch next block of emojis from database
@ -111,8 +108,10 @@ func (m *Manager) RefetchEmojis(ctx context.Context, domain string, dereferenceM
continue continue
} }
// Get max supported remote emoji media size.
maxsz := int64(config.GetMediaEmojiRemoteMaxSize()) // #nosec G115 -- Already validated.
dataFunc := func(ctx context.Context) (reader io.ReadCloser, err error) { dataFunc := func(ctx context.Context) (reader io.ReadCloser, err error) {
return dereferenceMedia(ctx, emojiImageIRI, int64(maxsz)) return dereferenceMedia(ctx, emojiImageIRI, maxsz)
} }
processingEmoji, err := m.UpdateEmoji(ctx, emoji, dataFunc, AdditionalEmojiInfo{ processingEmoji, err := m.UpdateEmoji(ctx, emoji, dataFunc, AdditionalEmojiInfo{

View file

@ -462,11 +462,11 @@ func (p *Processor) UpdateAvatar(
gtserror.WithCode, gtserror.WithCode,
) { ) {
// Get maximum supported local media size. // Get maximum supported local media size.
maxsz := config.GetMediaLocalMaxSize() maxsz := int64(config.GetMediaLocalMaxSize()) // #nosec G115 -- Already validated.
// Ensure media within size bounds. // Ensure media within size bounds.
if avatar.Size > int64(maxsz) { if avatar.Size > maxsz {
text := fmt.Sprintf("media exceeds configured max size: %s", maxsz) text := fmt.Sprintf("media exceeds configured max size: %d", maxsz)
return nil, gtserror.NewErrorBadRequest(errors.New(text), text) return nil, gtserror.NewErrorBadRequest(errors.New(text), text)
} }
@ -478,7 +478,7 @@ func (p *Processor) UpdateAvatar(
} }
// Wrap the multipart file reader to ensure is limited to max. // Wrap the multipart file reader to ensure is limited to max.
rc, _, _ := iotools.UpdateReadCloserLimit(mpfile, int64(maxsz)) rc, _, _ := iotools.UpdateReadCloserLimit(mpfile, maxsz)
// Write to instance storage. // Write to instance storage.
return p.c.StoreLocalMedia(ctx, return p.c.StoreLocalMedia(ctx,
@ -507,11 +507,11 @@ func (p *Processor) UpdateHeader(
gtserror.WithCode, gtserror.WithCode,
) { ) {
// Get maximum supported local media size. // Get maximum supported local media size.
maxsz := config.GetMediaLocalMaxSize() maxsz := int64(config.GetMediaLocalMaxSize()) // #nosec G115 -- Already validated.
// Ensure media within size bounds. // Ensure media within size bounds.
if header.Size > int64(maxsz) { if header.Size > maxsz {
text := fmt.Sprintf("media exceeds configured max size: %s", maxsz) text := fmt.Sprintf("media exceeds configured max size: %d", maxsz)
return nil, gtserror.NewErrorBadRequest(errors.New(text), text) return nil, gtserror.NewErrorBadRequest(errors.New(text), text)
} }
@ -523,7 +523,7 @@ func (p *Processor) UpdateHeader(
} }
// Wrap the multipart file reader to ensure is limited to max. // Wrap the multipart file reader to ensure is limited to max.
rc, _, _ := iotools.UpdateReadCloserLimit(mpfile, int64(maxsz)) rc, _, _ := iotools.UpdateReadCloserLimit(mpfile, maxsz)
// Write to instance storage. // Write to instance storage.
return p.c.StoreLocalMedia(ctx, return p.c.StoreLocalMedia(ctx,

View file

@ -45,11 +45,11 @@ func (p *Processor) EmojiCreate(
) (*apimodel.Emoji, gtserror.WithCode) { ) (*apimodel.Emoji, gtserror.WithCode) {
// Get maximum supported local emoji size. // Get maximum supported local emoji size.
maxsz := config.GetMediaEmojiLocalMaxSize() maxsz := int64(config.GetMediaEmojiLocalMaxSize()) // #nosec G115 -- Already validated.
// Ensure media within size bounds. // Ensure media within size bounds.
if form.Image.Size > int64(maxsz) { if form.Image.Size > maxsz {
text := fmt.Sprintf("emoji exceeds configured max size: %s", maxsz) text := fmt.Sprintf("emoji exceeds configured max size: %d", maxsz)
return nil, gtserror.NewErrorBadRequest(errors.New(text), text) return nil, gtserror.NewErrorBadRequest(errors.New(text), text)
} }
@ -61,7 +61,7 @@ func (p *Processor) EmojiCreate(
} }
// Wrap the multipart file reader to ensure is limited to max. // Wrap the multipart file reader to ensure is limited to max.
rc, _, _ := iotools.UpdateReadCloserLimit(mpfile, int64(maxsz)) rc, _, _ := iotools.UpdateReadCloserLimit(mpfile, maxsz)
data := func(context.Context) (io.ReadCloser, error) { data := func(context.Context) (io.ReadCloser, error) {
return rc, nil return rc, nil
} }
@ -441,11 +441,11 @@ func (p *Processor) emojiUpdateModify(
// We can do both at the same time :) // We can do both at the same time :)
// Get maximum supported local emoji size. // Get maximum supported local emoji size.
maxsz := config.GetMediaEmojiLocalMaxSize() maxsz := int64(config.GetMediaEmojiLocalMaxSize()) // #nosec G115 -- Already validated.
// Ensure media within size bounds. // Ensure media within size bounds.
if image.Size > int64(maxsz) { if image.Size > maxsz {
text := fmt.Sprintf("emoji exceeds configured max size: %s", maxsz) text := fmt.Sprintf("emoji exceeds configured max size: %d", maxsz)
return nil, gtserror.NewErrorBadRequest(errors.New(text), text) return nil, gtserror.NewErrorBadRequest(errors.New(text), text)
} }
@ -457,7 +457,7 @@ func (p *Processor) emojiUpdateModify(
} }
// Wrap the multipart file reader to ensure is limited to max. // Wrap the multipart file reader to ensure is limited to max.
rc, _, _ := iotools.UpdateReadCloserLimit(mpfile, int64(maxsz)) rc, _, _ := iotools.UpdateReadCloserLimit(mpfile, maxsz)
data := func(context.Context) (io.ReadCloser, error) { data := func(context.Context) (io.ReadCloser, error) {
return rc, nil return rc, nil
} }

View file

@ -247,6 +247,12 @@ func (p *Processor) GetVisibleAPIStatuses(
continue continue
} }
if apiStatus == nil {
// Status was
// filtered out.
continue
}
// Append converted status to return slice. // Append converted status to return slice.
apiStatuses = append(apiStatuses, *apiStatus) apiStatuses = append(apiStatuses, *apiStatus)
} }

View file

@ -35,11 +35,11 @@
func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, form *apimodel.AttachmentRequest) (*apimodel.Attachment, gtserror.WithCode) { func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, form *apimodel.AttachmentRequest) (*apimodel.Attachment, gtserror.WithCode) {
// Get maximum supported local media size. // Get maximum supported local media size.
maxsz := config.GetMediaLocalMaxSize() maxsz := int64(config.GetMediaLocalMaxSize()) // #nosec G115 -- Already validated.
// Ensure media within size bounds. // Ensure media within size bounds.
if form.File.Size > int64(maxsz) { if form.File.Size > maxsz {
text := fmt.Sprintf("media exceeds configured max size: %s", maxsz) text := fmt.Sprintf("media exceeds configured max size: %d", maxsz)
return nil, gtserror.NewErrorBadRequest(errors.New(text), text) return nil, gtserror.NewErrorBadRequest(errors.New(text), text)
} }
@ -58,7 +58,7 @@ func (p *Processor) Create(ctx context.Context, account *gtsmodel.Account, form
} }
// Wrap the multipart file reader to ensure is limited to max. // Wrap the multipart file reader to ensure is limited to max.
rc, _, _ := iotools.UpdateReadCloserLimit(mpfile, int64(maxsz)) rc, _, _ := iotools.UpdateReadCloserLimit(mpfile, maxsz)
// Create local media and write to instance storage. // Create local media and write to instance storage.
attachment, errWithCode := p.c.StoreLocalMedia(ctx, attachment, errWithCode := p.c.StoreLocalMedia(ctx,

View file

@ -20,6 +20,7 @@
import ( import (
"context" "context"
"errors" "errors"
"net/url"
"time" "time"
"codeberg.org/gruf/go-kv" "codeberg.org/gruf/go-kv"
@ -144,6 +145,10 @@ func (p *Processor) ProcessFromFediAPI(ctx context.Context, fMsg *messages.FromF
// ACCEPT (pending) ANNOUNCE // ACCEPT (pending) ANNOUNCE
case ap.ActivityAnnounce: case ap.ActivityAnnounce:
return p.fediAPI.AcceptAnnounce(ctx, fMsg) return p.fediAPI.AcceptAnnounce(ctx, fMsg)
// ACCEPT (remote) REPLY or ANNOUNCE
case ap.ObjectUnknown:
return p.fediAPI.AcceptRemoteStatus(ctx, fMsg)
} }
// REJECT SOMETHING // REJECT SOMETHING
@ -823,6 +828,60 @@ func (p *fediAPI) AcceptReply(ctx context.Context, fMsg *messages.FromFediAPI) e
return nil return nil
} }
func (p *fediAPI) AcceptRemoteStatus(ctx context.Context, fMsg *messages.FromFediAPI) error {
// See if we can accept a remote
// status we don't have stored yet.
objectIRI, ok := fMsg.APObject.(*url.URL)
if !ok {
return gtserror.Newf("%T not parseable as *url.URL", fMsg.APObject)
}
acceptIRI := fMsg.APIRI
if acceptIRI == nil {
return gtserror.New("acceptIRI was nil")
}
// Assume we're accepting a status; create a
// barebones status for dereferencing purposes.
bareStatus := &gtsmodel.Status{
URI: objectIRI.String(),
ApprovedByURI: acceptIRI.String(),
}
// Call RefreshStatus() to process the provided
// barebones status and insert it into the database,
// if indeed it's actually a status URI we can fetch.
//
// This will also check whether the given AcceptIRI
// actually grants permission for this status.
status, _, err := p.federate.RefreshStatus(ctx,
fMsg.Receiving.Username,
bareStatus,
nil, nil,
)
if err != nil {
return gtserror.Newf("error processing accepted status %s: %w", bareStatus.URI, err)
}
// No error means it was indeed a remote status, and the
// given acceptIRI permitted it. Timeline and notify it.
if err := p.surface.timelineAndNotifyStatus(ctx, status); err != nil {
log.Errorf(ctx, "error timelining and notifying status: %v", err)
}
// Interaction counts changed on the interacted status;
// uncache the prepared version from all timelines.
if status.InReplyToID != "" {
p.surface.invalidateStatusFromTimelines(ctx, status.InReplyToID)
}
if status.BoostOfID != "" {
p.surface.invalidateStatusFromTimelines(ctx, status.BoostOfID)
}
return nil
}
func (p *fediAPI) AcceptAnnounce(ctx context.Context, fMsg *messages.FromFediAPI) error { func (p *fediAPI) AcceptAnnounce(ctx context.Context, fMsg *messages.FromFediAPI) error {
boost, ok := fMsg.GTSModel.(*gtsmodel.Status) boost, ok := fMsg.GTSModel.(*gtsmodel.Status)
if !ok { if !ok {

View file

@ -1988,6 +1988,16 @@ func (c *Converter) InteractionReqToASAccept(
return nil, gtserror.Newf("invalid interacting account uri: %w", err) return nil, gtserror.Newf("invalid interacting account uri: %w", err)
} }
publicIRI, err := url.Parse(pub.PublicActivityPubIRI)
if err != nil {
return nil, gtserror.Newf("invalid public uri: %w", err)
}
followersIRI, err := url.Parse(req.TargetAccount.FollowersURI)
if err != nil {
return nil, gtserror.Newf("invalid followers uri: %w", err)
}
// Set id to the URI of // Set id to the URI of
// interaction request. // interaction request.
ap.SetJSONLDId(accept, acceptID) ap.SetJSONLDId(accept, acceptID)
@ -2003,6 +2013,9 @@ func (c *Converter) InteractionReqToASAccept(
// of interaction URI. // of interaction URI.
ap.AppendTo(accept, toIRI) ap.AppendTo(accept, toIRI)
// Cc to the actor's followers, and to Public.
ap.AppendCc(accept, publicIRI, followersIRI)
return accept, nil return accept, nil
} }
@ -2034,6 +2047,16 @@ func (c *Converter) InteractionReqToASReject(
return nil, gtserror.Newf("invalid interacting account uri: %w", err) return nil, gtserror.Newf("invalid interacting account uri: %w", err)
} }
publicIRI, err := url.Parse(pub.PublicActivityPubIRI)
if err != nil {
return nil, gtserror.Newf("invalid public uri: %w", err)
}
followersIRI, err := url.Parse(req.TargetAccount.FollowersURI)
if err != nil {
return nil, gtserror.Newf("invalid followers uri: %w", err)
}
// Set id to the URI of // Set id to the URI of
// interaction request. // interaction request.
ap.SetJSONLDId(reject, rejectID) ap.SetJSONLDId(reject, rejectID)
@ -2049,5 +2072,8 @@ func (c *Converter) InteractionReqToASReject(
// of interaction URI. // of interaction URI.
ap.AppendTo(reject, toIRI) ap.AppendTo(reject, toIRI)
// Cc to the actor's followers, and to Public.
ap.AppendCc(reject, publicIRI, followersIRI)
return reject, nil return reject, nil
} }

View file

@ -1181,6 +1181,10 @@ func (suite *InternalToASTestSuite) TestInteractionReqToASAccept() {
suite.Equal(`{ suite.Equal(`{
"@context": "https://www.w3.org/ns/activitystreams", "@context": "https://www.w3.org/ns/activitystreams",
"actor": "http://localhost:8080/users/the_mighty_zork", "actor": "http://localhost:8080/users/the_mighty_zork",
"cc": [
"https://www.w3.org/ns/activitystreams#Public",
"http://localhost:8080/users/the_mighty_zork/followers"
],
"id": "http://localhost:8080/users/the_mighty_zork/accepts/01J1AKMZ8JE5NW0ZSFTRC1JJNE", "id": "http://localhost:8080/users/the_mighty_zork/accepts/01J1AKMZ8JE5NW0ZSFTRC1JJNE",
"object": "https://fossbros-anonymous.io/users/foss_satan/statuses/01J1AKRRHQ6MDDQHV0TP716T2K", "object": "https://fossbros-anonymous.io/users/foss_satan/statuses/01J1AKRRHQ6MDDQHV0TP716T2K",
"to": "http://fossbros-anonymous.io/users/foss_satan", "to": "http://fossbros-anonymous.io/users/foss_satan",

View file

@ -647,7 +647,7 @@ func (c *Converter) AttachmentToAPIAttachment(ctx context.Context, media *gtsmod
Size: toAPISize(media.FileMeta.Original.Width, media.FileMeta.Original.Height), Size: toAPISize(media.FileMeta.Original.Width, media.FileMeta.Original.Height),
FrameRate: toAPIFrameRate(media.FileMeta.Original.Framerate), FrameRate: toAPIFrameRate(media.FileMeta.Original.Framerate),
Duration: util.PtrOrZero(media.FileMeta.Original.Duration), Duration: util.PtrOrZero(media.FileMeta.Original.Duration),
Bitrate: int(util.PtrOrZero(media.FileMeta.Original.Bitrate)), Bitrate: util.PtrOrZero(media.FileMeta.Original.Bitrate),
} }
// Copy over local file URL. // Copy over local file URL.
@ -1551,9 +1551,9 @@ func (c *Converter) InstanceToAPIV1Instance(ctx context.Context, i *gtsmodel.Ins
instance.Configuration.Statuses.CharactersReservedPerURL = instanceStatusesCharactersReservedPerURL instance.Configuration.Statuses.CharactersReservedPerURL = instanceStatusesCharactersReservedPerURL
instance.Configuration.Statuses.SupportedMimeTypes = instanceStatusesSupportedMimeTypes instance.Configuration.Statuses.SupportedMimeTypes = instanceStatusesSupportedMimeTypes
instance.Configuration.MediaAttachments.SupportedMimeTypes = media.SupportedMIMETypes instance.Configuration.MediaAttachments.SupportedMimeTypes = media.SupportedMIMETypes
instance.Configuration.MediaAttachments.ImageSizeLimit = int(config.GetMediaRemoteMaxSize()) instance.Configuration.MediaAttachments.ImageSizeLimit = int(config.GetMediaRemoteMaxSize()) // #nosec G115 -- Already validated.
instance.Configuration.MediaAttachments.ImageMatrixLimit = instanceMediaAttachmentsImageMatrixLimit instance.Configuration.MediaAttachments.ImageMatrixLimit = instanceMediaAttachmentsImageMatrixLimit
instance.Configuration.MediaAttachments.VideoSizeLimit = int(config.GetMediaRemoteMaxSize()) instance.Configuration.MediaAttachments.VideoSizeLimit = int(config.GetMediaRemoteMaxSize()) // #nosec G115 -- Already validated.
instance.Configuration.MediaAttachments.VideoFrameRateLimit = instanceMediaAttachmentsVideoFrameRateLimit instance.Configuration.MediaAttachments.VideoFrameRateLimit = instanceMediaAttachmentsVideoFrameRateLimit
instance.Configuration.MediaAttachments.VideoMatrixLimit = instanceMediaAttachmentsVideoMatrixLimit instance.Configuration.MediaAttachments.VideoMatrixLimit = instanceMediaAttachmentsVideoMatrixLimit
instance.Configuration.Polls.MaxOptions = config.GetStatusesPollMaxOptions() instance.Configuration.Polls.MaxOptions = config.GetStatusesPollMaxOptions()
@ -1563,7 +1563,7 @@ func (c *Converter) InstanceToAPIV1Instance(ctx context.Context, i *gtsmodel.Ins
instance.Configuration.Accounts.AllowCustomCSS = config.GetAccountsAllowCustomCSS() instance.Configuration.Accounts.AllowCustomCSS = config.GetAccountsAllowCustomCSS()
instance.Configuration.Accounts.MaxFeaturedTags = instanceAccountsMaxFeaturedTags instance.Configuration.Accounts.MaxFeaturedTags = instanceAccountsMaxFeaturedTags
instance.Configuration.Accounts.MaxProfileFields = instanceAccountsMaxProfileFields instance.Configuration.Accounts.MaxProfileFields = instanceAccountsMaxProfileFields
instance.Configuration.Emojis.EmojiSizeLimit = int(config.GetMediaEmojiLocalMaxSize()) instance.Configuration.Emojis.EmojiSizeLimit = int(config.GetMediaEmojiLocalMaxSize()) // #nosec G115 -- Already validated.
instance.Configuration.OIDCEnabled = config.GetOIDCEnabled() instance.Configuration.OIDCEnabled = config.GetOIDCEnabled()
// URLs // URLs
@ -1695,9 +1695,9 @@ func (c *Converter) InstanceToAPIV2Instance(ctx context.Context, i *gtsmodel.Ins
instance.Configuration.Statuses.CharactersReservedPerURL = instanceStatusesCharactersReservedPerURL instance.Configuration.Statuses.CharactersReservedPerURL = instanceStatusesCharactersReservedPerURL
instance.Configuration.Statuses.SupportedMimeTypes = instanceStatusesSupportedMimeTypes instance.Configuration.Statuses.SupportedMimeTypes = instanceStatusesSupportedMimeTypes
instance.Configuration.MediaAttachments.SupportedMimeTypes = media.SupportedMIMETypes instance.Configuration.MediaAttachments.SupportedMimeTypes = media.SupportedMIMETypes
instance.Configuration.MediaAttachments.ImageSizeLimit = int(config.GetMediaRemoteMaxSize()) instance.Configuration.MediaAttachments.ImageSizeLimit = int(config.GetMediaRemoteMaxSize()) // #nosec G115 -- Already validated.
instance.Configuration.MediaAttachments.ImageMatrixLimit = instanceMediaAttachmentsImageMatrixLimit instance.Configuration.MediaAttachments.ImageMatrixLimit = instanceMediaAttachmentsImageMatrixLimit
instance.Configuration.MediaAttachments.VideoSizeLimit = int(config.GetMediaRemoteMaxSize()) instance.Configuration.MediaAttachments.VideoSizeLimit = int(config.GetMediaRemoteMaxSize()) // #nosec G115 -- Already validated.
instance.Configuration.MediaAttachments.VideoFrameRateLimit = instanceMediaAttachmentsVideoFrameRateLimit instance.Configuration.MediaAttachments.VideoFrameRateLimit = instanceMediaAttachmentsVideoFrameRateLimit
instance.Configuration.MediaAttachments.VideoMatrixLimit = instanceMediaAttachmentsVideoMatrixLimit instance.Configuration.MediaAttachments.VideoMatrixLimit = instanceMediaAttachmentsVideoMatrixLimit
instance.Configuration.Polls.MaxOptions = config.GetStatusesPollMaxOptions() instance.Configuration.Polls.MaxOptions = config.GetStatusesPollMaxOptions()
@ -1707,7 +1707,7 @@ func (c *Converter) InstanceToAPIV2Instance(ctx context.Context, i *gtsmodel.Ins
instance.Configuration.Accounts.AllowCustomCSS = config.GetAccountsAllowCustomCSS() instance.Configuration.Accounts.AllowCustomCSS = config.GetAccountsAllowCustomCSS()
instance.Configuration.Accounts.MaxFeaturedTags = instanceAccountsMaxFeaturedTags instance.Configuration.Accounts.MaxFeaturedTags = instanceAccountsMaxFeaturedTags
instance.Configuration.Accounts.MaxProfileFields = instanceAccountsMaxProfileFields instance.Configuration.Accounts.MaxProfileFields = instanceAccountsMaxProfileFields
instance.Configuration.Emojis.EmojiSizeLimit = int(config.GetMediaEmojiLocalMaxSize()) instance.Configuration.Emojis.EmojiSizeLimit = int(config.GetMediaEmojiLocalMaxSize()) // #nosec G115 -- Already validated.
instance.Configuration.OIDCEnabled = config.GetOIDCEnabled() instance.Configuration.OIDCEnabled = config.GetOIDCEnabled()
// registrations // registrations
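The `#nosec G115` annotations silence gosec's integer-overflow-on-conversion check for these casts from wider config types to int, on the grounds that the values are validated elsewhere. Where that guarantee cannot be assumed, a clamped conversion along these lines avoids the warning without an annotation (a generic sketch, not code from this changeset; needs the math package):

    // toInt converts a uint64 to int without risking overflow on
    // 32-bit platforms, by clamping to the maximum int value.
    func toInt(v uint64) int {
        if v > math.MaxInt {
            return math.MaxInt
        }
        return int(v)
    }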
@ -1832,46 +1832,23 @@ func (c *Converter) NotificationToAPINotification(
func (c *Converter) ConversationToAPIConversation( func (c *Converter) ConversationToAPIConversation(
ctx context.Context, ctx context.Context,
conversation *gtsmodel.Conversation, conversation *gtsmodel.Conversation,
requestingAccount *gtsmodel.Account, requester *gtsmodel.Account,
filters []*gtsmodel.Filter, filters []*gtsmodel.Filter,
mutes *usermute.CompiledUserMuteList, mutes *usermute.CompiledUserMuteList,
) (*apimodel.Conversation, error) { ) (*apimodel.Conversation, error) {
apiConversation := &apimodel.Conversation{ apiConversation := &apimodel.Conversation{
ID: conversation.ID, ID: conversation.ID,
Unread: !*conversation.Read, Unread: !*conversation.Read,
Accounts: []apimodel.Account{},
}
for _, account := range conversation.OtherAccounts {
var apiAccount *apimodel.Account
blocked, err := c.state.DB.IsEitherBlocked(ctx, requestingAccount.ID, account.ID)
if err != nil {
return nil, gtserror.Newf(
"DB error checking blocks between accounts %s and %s: %w",
requestingAccount.ID,
account.ID,
err,
)
}
if blocked || account.IsSuspended() {
apiAccount, err = c.AccountToAPIAccountBlocked(ctx, account)
} else {
apiAccount, err = c.AccountToAPIAccountPublic(ctx, account)
}
if err != nil {
return nil, gtserror.Newf(
"error converting account %s to API representation: %w",
account.ID,
err,
)
}
apiConversation.Accounts = append(apiConversation.Accounts, *apiAccount)
} }
// Populate most recent status in convo;
// can be nil if this status is filtered.
if conversation.LastStatus != nil { if conversation.LastStatus != nil {
var err error var err error
apiConversation.LastStatus, err = c.StatusToAPIStatus( apiConversation.LastStatus, err = c.StatusToAPIStatus(
ctx, ctx,
conversation.LastStatus, conversation.LastStatus,
requestingAccount, requester,
statusfilter.FilterContextNotifications, statusfilter.FilterContextNotifications,
filters, filters,
mutes, mutes,
@ -1885,6 +1862,60 @@ func (c *Converter) ConversationToAPIConversation(
} }
} }
// If no other accounts are involved in this convo,
// just include the requesting account and return.
//
// See: https://github.com/superseriousbusiness/gotosocial/issues/3385#issuecomment-2394033477
otherAcctsLen := len(conversation.OtherAccounts)
if otherAcctsLen == 0 {
apiAcct, err := c.AccountToAPIAccountPublic(ctx, requester)
if err != nil {
err := gtserror.Newf(
"error converting account %s to API representation: %w",
requester.ID, err,
)
return nil, err
}
apiConversation.Accounts = []apimodel.Account{*apiAcct}
return apiConversation, nil
}
// Other accounts are involved in the
// convo. Convert each to API model.
apiConversation.Accounts = make([]apimodel.Account, otherAcctsLen)
for i, account := range conversation.OtherAccounts {
blocked, err := c.state.DB.IsEitherBlocked(ctx,
requester.ID, account.ID,
)
if err != nil {
err := gtserror.Newf(
"db error checking blocks between accounts %s and %s: %w",
requester.ID, account.ID, err,
)
return nil, err
}
// API account model varies depending
// on status of conversation participant.
var apiAcct *apimodel.Account
if blocked || account.IsSuspended() {
apiAcct, err = c.AccountToAPIAccountBlocked(ctx, account)
} else {
apiAcct, err = c.AccountToAPIAccountPublic(ctx, account)
}
if err != nil {
err := gtserror.Newf(
"error converting account %s to API representation: %w",
account.ID, err,
)
return nil, err
}
apiConversation.Accounts[i] = *apiAcct
}
return apiConversation, nil return apiConversation, nil
} }
@ -2680,7 +2711,7 @@ func (c *Converter) InteractionReqToAPIInteractionReq(
} }
var reply *apimodel.Status var reply *apimodel.Status
if req.InteractionType == gtsmodel.InteractionReply { if req.InteractionType == gtsmodel.InteractionReply && req.Reply != nil {
reply, err = c.statusToAPIStatus( reply, err = c.statusToAPIStatus(
ctx, ctx,
req.Reply, req.Reply,

View file

@ -3358,6 +3358,321 @@ func (suite *InternalToFrontendTestSuite) TestIntReqToAPI() {
}`, string(b)) }`, string(b))
} }
func (suite *InternalToFrontendTestSuite) TestConversationToAPISelfConvo() {
var (
ctx = context.Background()
requester = suite.testAccounts["local_account_1"]
lastStatus = suite.testStatuses["local_account_1_status_1"]
filters []*gtsmodel.Filter = nil
mutes *usermute.CompiledUserMuteList = nil
)
convo := &gtsmodel.Conversation{
ID: "01J9C6K86PKZ5GY5WXV94DGH6R",
CreatedAt: testrig.TimeMustParse("2022-06-10T15:22:08Z"),
UpdatedAt: testrig.TimeMustParse("2022-06-10T15:22:08Z"),
AccountID: requester.ID,
Account: requester,
OtherAccounts: nil,
LastStatus: lastStatus,
Read: util.Ptr(true),
}
apiConvo, err := suite.typeconverter.ConversationToAPIConversation(
ctx,
convo,
requester,
filters,
mutes,
)
if err != nil {
suite.FailNow(err.Error())
}
b, err := json.MarshalIndent(apiConvo, "", " ")
if err != nil {
suite.FailNow(err.Error())
}
// No other accounts involved, so we should only
// have our own account in the "accounts" field.
suite.Equal(`{
"id": "01J9C6K86PKZ5GY5WXV94DGH6R",
"unread": false,
"accounts": [
{
"id": "01F8MH1H7YV1Z7D2C8K2730QBF",
"username": "the_mighty_zork",
"acct": "the_mighty_zork",
"display_name": "original zork (he/they)",
"locked": false,
"discoverable": true,
"bot": false,
"created_at": "2022-05-20T11:09:18.000Z",
"note": "\u003cp\u003ehey yo this is my profile!\u003c/p\u003e",
"url": "http://localhost:8080/@the_mighty_zork",
"avatar": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/avatar/original/01F8MH58A357CV5K7R7TJMSH6S.jpg",
"avatar_static": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/avatar/small/01F8MH58A357CV5K7R7TJMSH6S.webp",
"avatar_description": "a green goblin looking nasty",
"header": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/header/original/01PFPMWK2FF0D9WMHEJHR07C3Q.jpg",
"header_static": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/header/small/01PFPMWK2FF0D9WMHEJHR07C3Q.webp",
"header_description": "A very old-school screenshot of the original team fortress mod for quake",
"followers_count": 2,
"following_count": 2,
"statuses_count": 8,
"last_status_at": "2024-01-10T09:24:00.000Z",
"emojis": [],
"fields": [],
"enable_rss": true
}
],
"last_status": {
"id": "01F8MHAMCHF6Y650WCRSCP4WMY",
"created_at": "2021-10-20T10:40:37.000Z",
"in_reply_to_id": null,
"in_reply_to_account_id": null,
"sensitive": true,
"spoiler_text": "introduction post",
"visibility": "public",
"language": "en",
"uri": "http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY",
"url": "http://localhost:8080/@the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY",
"replies_count": 2,
"reblogs_count": 1,
"favourites_count": 1,
"favourited": false,
"reblogged": false,
"muted": false,
"bookmarked": false,
"pinned": false,
"content": "hello everyone!",
"reblog": null,
"application": {
"name": "really cool gts application",
"website": "https://reallycool.app"
},
"account": {
"id": "01F8MH1H7YV1Z7D2C8K2730QBF",
"username": "the_mighty_zork",
"acct": "the_mighty_zork",
"display_name": "original zork (he/they)",
"locked": false,
"discoverable": true,
"bot": false,
"created_at": "2022-05-20T11:09:18.000Z",
"note": "\u003cp\u003ehey yo this is my profile!\u003c/p\u003e",
"url": "http://localhost:8080/@the_mighty_zork",
"avatar": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/avatar/original/01F8MH58A357CV5K7R7TJMSH6S.jpg",
"avatar_static": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/avatar/small/01F8MH58A357CV5K7R7TJMSH6S.webp",
"avatar_description": "a green goblin looking nasty",
"header": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/header/original/01PFPMWK2FF0D9WMHEJHR07C3Q.jpg",
"header_static": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/header/small/01PFPMWK2FF0D9WMHEJHR07C3Q.webp",
"header_description": "A very old-school screenshot of the original team fortress mod for quake",
"followers_count": 2,
"following_count": 2,
"statuses_count": 8,
"last_status_at": "2024-01-10T09:24:00.000Z",
"emojis": [],
"fields": [],
"enable_rss": true
},
"media_attachments": [],
"mentions": [],
"tags": [],
"emojis": [],
"card": null,
"poll": null,
"text": "hello everyone!",
"interaction_policy": {
"can_favourite": {
"always": [
"public",
"me"
],
"with_approval": []
},
"can_reply": {
"always": [
"public",
"me"
],
"with_approval": []
},
"can_reblog": {
"always": [
"public",
"me"
],
"with_approval": []
}
}
}
}`, string(b))
}
func (suite *InternalToFrontendTestSuite) TestConversationToAPI() {
var (
ctx = context.Background()
requester = suite.testAccounts["local_account_1"]
lastStatus = suite.testStatuses["local_account_1_status_1"]
filters []*gtsmodel.Filter = nil
mutes *usermute.CompiledUserMuteList = nil
)
convo := &gtsmodel.Conversation{
ID: "01J9C6K86PKZ5GY5WXV94DGH6R",
CreatedAt: testrig.TimeMustParse("2022-06-10T15:22:08Z"),
UpdatedAt: testrig.TimeMustParse("2022-06-10T15:22:08Z"),
AccountID: requester.ID,
Account: requester,
OtherAccounts: []*gtsmodel.Account{
suite.testAccounts["local_account_2"],
},
LastStatus: lastStatus,
Read: util.Ptr(false),
}
apiConvo, err := suite.typeconverter.ConversationToAPIConversation(
ctx,
convo,
requester,
filters,
mutes,
)
if err != nil {
suite.FailNow(err.Error())
}
b, err := json.MarshalIndent(apiConvo, "", " ")
if err != nil {
suite.FailNow(err.Error())
}
// One other account is involved, so they
	// should be in the "accounts" field and not us.
suite.Equal(`{
"id": "01J9C6K86PKZ5GY5WXV94DGH6R",
"unread": true,
"accounts": [
{
"id": "01F8MH5NBDF2MV7CTC4Q5128HF",
"username": "1happyturtle",
"acct": "1happyturtle",
"display_name": "happy little turtle :3",
"locked": true,
"discoverable": false,
"bot": false,
"created_at": "2022-06-04T13:12:00.000Z",
"note": "\u003cp\u003ei post about things that concern me\u003c/p\u003e",
"url": "http://localhost:8080/@1happyturtle",
"avatar": "",
"avatar_static": "",
"header": "http://localhost:8080/assets/default_header.webp",
"header_static": "http://localhost:8080/assets/default_header.webp",
"followers_count": 1,
"following_count": 1,
"statuses_count": 8,
"last_status_at": "2021-07-28T08:40:37.000Z",
"emojis": [],
"fields": [
{
"name": "should you follow me?",
"value": "maybe!",
"verified_at": null
},
{
"name": "age",
"value": "120",
"verified_at": null
}
],
"hide_collections": true
}
],
"last_status": {
"id": "01F8MHAMCHF6Y650WCRSCP4WMY",
"created_at": "2021-10-20T10:40:37.000Z",
"in_reply_to_id": null,
"in_reply_to_account_id": null,
"sensitive": true,
"spoiler_text": "introduction post",
"visibility": "public",
"language": "en",
"uri": "http://localhost:8080/users/the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY",
"url": "http://localhost:8080/@the_mighty_zork/statuses/01F8MHAMCHF6Y650WCRSCP4WMY",
"replies_count": 2,
"reblogs_count": 1,
"favourites_count": 1,
"favourited": false,
"reblogged": false,
"muted": false,
"bookmarked": false,
"pinned": false,
"content": "hello everyone!",
"reblog": null,
"application": {
"name": "really cool gts application",
"website": "https://reallycool.app"
},
"account": {
"id": "01F8MH1H7YV1Z7D2C8K2730QBF",
"username": "the_mighty_zork",
"acct": "the_mighty_zork",
"display_name": "original zork (he/they)",
"locked": false,
"discoverable": true,
"bot": false,
"created_at": "2022-05-20T11:09:18.000Z",
"note": "\u003cp\u003ehey yo this is my profile!\u003c/p\u003e",
"url": "http://localhost:8080/@the_mighty_zork",
"avatar": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/avatar/original/01F8MH58A357CV5K7R7TJMSH6S.jpg",
"avatar_static": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/avatar/small/01F8MH58A357CV5K7R7TJMSH6S.webp",
"avatar_description": "a green goblin looking nasty",
"header": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/header/original/01PFPMWK2FF0D9WMHEJHR07C3Q.jpg",
"header_static": "http://localhost:8080/fileserver/01F8MH1H7YV1Z7D2C8K2730QBF/header/small/01PFPMWK2FF0D9WMHEJHR07C3Q.webp",
"header_description": "A very old-school screenshot of the original team fortress mod for quake",
"followers_count": 2,
"following_count": 2,
"statuses_count": 8,
"last_status_at": "2024-01-10T09:24:00.000Z",
"emojis": [],
"fields": [],
"enable_rss": true
},
"media_attachments": [],
"mentions": [],
"tags": [],
"emojis": [],
"card": null,
"poll": null,
"text": "hello everyone!",
"interaction_policy": {
"can_favourite": {
"always": [
"public",
"me"
],
"with_approval": []
},
"can_reply": {
"always": [
"public",
"me"
],
"with_approval": []
},
"can_reblog": {
"always": [
"public",
"me"
],
"with_approval": []
}
}
}
}`, string(b))
}
func TestInternalToFrontendTestSuite(t *testing.T) { func TestInternalToFrontendTestSuite(t *testing.T) {
suite.Run(t, new(InternalToFrontendTestSuite)) suite.Run(t, new(InternalToFrontendTestSuite))
} }

View file

@ -618,7 +618,7 @@ func NewTestAccounts() map[string]*gtsmodel.Account {
} }
if diff := len(accountsSorted) - len(preserializedKeys); diff > 0 { if diff := len(accountsSorted) - len(preserializedKeys); diff > 0 {
keyStrings := make([]string, diff) keyStrings := make([]string, 0, diff)
for i := 0; i < diff; i++ { for i := 0; i < diff; i++ {
priv, _ := rsa.GenerateKey(rand.Reader, 2048) priv, _ := rsa.GenerateKey(rand.Reader, 2048)
key, _ := x509.MarshalPKCS8PrivateKey(priv) key, _ := x509.MarshalPKCS8PrivateKey(priv)
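The one-line fix above changes a slice created with a length into one created with only a capacity: `make([]string, diff)` starts with `diff` empty strings already in it, so appending in the loop would double the intended size, whereas `make([]string, 0, diff)` pre-allocates without adding elements. A self-contained illustration (assuming, as the fix implies, that the surrounding loop appends):

    package main

    import "fmt"

    func main() {
        wrong := make([]string, 3)    // len 3: three "" entries up front
        wrong = append(wrong, "key")  // len 4: ["", "", "", "key"]

        right := make([]string, 0, 3) // len 0, cap 3
        right = append(right, "key")  // len 1: ["key"]

        fmt.Println(len(wrong), len(right)) // prints: 4 1
    }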

Binary file not shown.

View file

@ -3,8 +3,6 @@
import ( import (
_ "embed" _ "embed"
"os" "os"
"github.com/tetratelabs/wazero/api"
) )
func init() { func init() {
@ -23,14 +21,5 @@ func init() {
} }
} }
// CoreFeatures is the WebAssembly Core specification
// features this embedded binary was compiled with.
const CoreFeatures = api.CoreFeatureSIMD |
api.CoreFeatureBulkMemoryOperations |
api.CoreFeatureNonTrappingFloatToIntConversion |
api.CoreFeatureMutableGlobal |
api.CoreFeatureReferenceTypes |
api.CoreFeatureSignExtensionOps
//go:embed ffmpeg.wasm //go:embed ffmpeg.wasm
var B []byte var B []byte

Binary file not shown.

View file

@ -3,8 +3,6 @@
import ( import (
_ "embed" _ "embed"
"os" "os"
"github.com/tetratelabs/wazero/api"
) )
func init() { func init() {
@ -23,14 +21,5 @@ func init() {
} }
} }
// CoreFeatures is the WebAssembly Core specification
// features this embedded binary was compiled with.
const CoreFeatures = api.CoreFeatureSIMD |
api.CoreFeatureBulkMemoryOperations |
api.CoreFeatureNonTrappingFloatToIntConversion |
api.CoreFeatureMutableGlobal |
api.CoreFeatureReferenceTypes |
api.CoreFeatureSignExtensionOps
//go:embed ffprobe.wasm //go:embed ffprobe.wasm
var B []byte var B []byte

vendor/codeberg.org/gruf/go-ffmpreg/wasm/instance.go (generated, vendored; new file, 89 lines)
View file

@ -0,0 +1,89 @@
package wasm
import (
"context"
"io"
"unsafe"
"github.com/tetratelabs/wazero"
"github.com/tetratelabs/wazero/api"
"github.com/tetratelabs/wazero/sys"
)
// CoreFeatures are the WebAssembly Core specification
// features our embedded binaries are compiled with.
const CoreFeatures = api.CoreFeatureSIMD |
api.CoreFeatureBulkMemoryOperations |
api.CoreFeatureNonTrappingFloatToIntConversion |
api.CoreFeatureMutableGlobal |
api.CoreFeatureReferenceTypes |
api.CoreFeatureSignExtensionOps
// Args encompasses a common set of
// configuration options often passed to
// wazero.Runtime on module instantiation.
type Args struct {
// Optional further module configuration function.
// (e.g. to mount filesystem dir, set env vars, etc).
Config func(wazero.ModuleConfig) wazero.ModuleConfig
// Standard FDs.
Stdin io.Reader
Stdout io.Writer
Stderr io.Writer
// CLI args.
Args []string
}
// Run will run given compiled WebAssembly module
// within the given runtime, with given arguments.
// Returns the exit code, or error.
func Run(
ctx context.Context,
runtime wazero.Runtime,
module wazero.CompiledModule,
args Args,
) (rc uint32, err error) {
// Prefix arguments with module name.
cargs := make([]string, len(args.Args)+1)
cargs[0] = module.Name()
copy(cargs[1:], args.Args)
// Prepare new module configuration.
modcfg := wazero.NewModuleConfig()
modcfg = modcfg.WithArgs(cargs...)
modcfg = modcfg.WithStdin(args.Stdin)
modcfg = modcfg.WithStdout(args.Stdout)
modcfg = modcfg.WithStderr(args.Stderr)
if args.Config != nil {
// Pass through config fn.
modcfg = args.Config(modcfg)
}
// Instantiate the module from precompiled wasm module data.
mod, err := runtime.InstantiateModule(ctx, module, modcfg)
if !isNil(mod) {
// Ensure closed.
_ = mod.Close(ctx)
}
// Try extract exit code.
switch err := err.(type) {
case *sys.ExitError:
return err.ExitCode(), nil
default:
return 0, err
}
}
// isNil will safely check if 'i' is nil without
// dealing with weird Go interface nil bullshit.
func isNil(i interface{}) bool {
type eface struct{ Type, Data unsafe.Pointer }
return (*(*eface)(unsafe.Pointer(&i))).Data == nil
}
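A hedged sketch of how this new Run helper might be driven; the runtime setup, the WASI instantiation, and the `ffmpeg.B` import name are assumptions for illustration, not code from this changeset:

    ctx := context.Background()

    // Create a wazero runtime with the features the embedded binaries need.
    rt := wazero.NewRuntimeWithConfig(ctx,
        wazero.NewRuntimeConfig().WithCoreFeatures(wasm.CoreFeatures))
    defer rt.Close(ctx)

    // The ffmpreg binaries are WASI programs, so the WASI host module is needed.
    wasi_snapshot_preview1.MustInstantiate(ctx, rt)

    // Compile the embedded binary (the B variable from one of the embed packages above).
    mod, err := rt.CompileModule(ctx, ffmpeg.B)
    if err != nil {
        // handle compile error
    }

    // Run it once with CLI-style arguments; rc is the program's exit code.
    rc, err := wasm.Run(ctx, rt, mod, wasm.Args{
        Stdout: os.Stdout,
        Stderr: os.Stderr,
        Args:   []string{"-version"},
    })
    fmt.Println("exit code:", rc, "err:", err)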

View file

@ -31,7 +31,6 @@ type Blob struct {
// //
// https://sqlite.org/c3ref/blob_open.html // https://sqlite.org/c3ref/blob_open.html
func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) { func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) {
c.checkInterrupt()
defer c.arena.mark()() defer c.arena.mark()()
blobPtr := c.arena.new(ptrlen) blobPtr := c.arena.new(ptrlen)
dbPtr := c.arena.string(db) dbPtr := c.arena.string(db)
@ -43,6 +42,7 @@ func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob,
flags = 1 flags = 1
} }
c.checkInterrupt(c.handle)
r := c.call("sqlite3_blob_open", uint64(c.handle), r := c.call("sqlite3_blob_open", uint64(c.handle),
uint64(dbPtr), uint64(tablePtr), uint64(columnPtr), uint64(dbPtr), uint64(tablePtr), uint64(columnPtr),
uint64(row), flags, uint64(blobPtr)) uint64(row), flags, uint64(blobPtr))

View file

@ -284,7 +284,10 @@ func walCallback(ctx context.Context, mod api.Module, _, pDB, zSchema uint32, pa
// //
// https://sqlite.org/c3ref/autovacuum_pages.html // https://sqlite.org/c3ref/autovacuum_pages.html
func (c *Conn) AutoVacuumPages(cb func(schema string, dbPages, freePages, bytesPerPage uint) uint) error { func (c *Conn) AutoVacuumPages(cb func(schema string, dbPages, freePages, bytesPerPage uint) uint) error {
funcPtr := util.AddHandle(c.ctx, cb) var funcPtr uint32
if cb != nil {
funcPtr = util.AddHandle(c.ctx, cb)
}
r := c.call("sqlite3_autovacuum_pages_go", uint64(c.handle), uint64(funcPtr)) r := c.call("sqlite3_autovacuum_pages_go", uint64(c.handle), uint64(funcPtr))
return c.error(r) return c.error(r)
} }

View file

@ -24,7 +24,7 @@ type Conn struct {
pending *Stmt pending *Stmt
stmts []*Stmt stmts []*Stmt
timer *time.Timer timer *time.Timer
busy func(int) bool busy func(context.Context, int) bool
log func(xErrorCode, string) log func(xErrorCode, string)
collation func(*Conn, string) collation func(*Conn, string)
wal func(*Conn, string, int) error wal func(*Conn, string, int) error
@ -38,14 +38,20 @@ type Conn struct {
handle uint32 handle uint32
} }
// Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE], [OPEN_URI] and [OPEN_NOFOLLOW]. // Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE] and [OPEN_URI].
func Open(filename string) (*Conn, error) { func Open(filename string) (*Conn, error) {
return newConn(filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI|OPEN_NOFOLLOW) return newConn(context.Background(), filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI)
}
// OpenContext is like [Open] but includes a context,
// which is used to interrupt the process of opening the connection.
func OpenContext(ctx context.Context, filename string) (*Conn, error) {
return newConn(ctx, filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI)
} }
// OpenFlags opens an SQLite database file as specified by the filename argument. // OpenFlags opens an SQLite database file as specified by the filename argument.
// //
// If none of the required flags is used, a combination of [OPEN_READWRITE] and [OPEN_CREATE] is used. // If none of the required flags are used, a combination of [OPEN_READWRITE] and [OPEN_CREATE] is used.
// If a URI filename is used, PRAGMA statements to execute can be specified using "_pragma": // If a URI filename is used, PRAGMA statements to execute can be specified using "_pragma":
// //
// sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)") // sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)")
@ -55,25 +61,33 @@ func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
if flags&(OPEN_READONLY|OPEN_READWRITE|OPEN_CREATE) == 0 { if flags&(OPEN_READONLY|OPEN_READWRITE|OPEN_CREATE) == 0 {
flags |= OPEN_READWRITE | OPEN_CREATE flags |= OPEN_READWRITE | OPEN_CREATE
} }
return newConn(filename, flags) return newConn(context.Background(), filename, flags)
} }
type connKey struct{} type connKey struct{}
func newConn(filename string, flags OpenFlag) (conn *Conn, err error) { func newConn(ctx context.Context, filename string, flags OpenFlag) (res *Conn, _ error) {
sqlite, err := instantiateSQLite() err := ctx.Err()
if err != nil {
return nil, err
}
c := &Conn{interrupt: ctx}
c.sqlite, err = instantiateSQLite()
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { defer func() {
if conn == nil { if res == nil {
sqlite.close() c.Close()
c.sqlite.close()
} else {
c.interrupt = context.Background()
} }
}() }()
c := &Conn{sqlite: sqlite}
c.arena = c.newArena(1024)
c.ctx = context.WithValue(c.ctx, connKey{}, c) c.ctx = context.WithValue(c.ctx, connKey{}, c)
c.arena = c.newArena(1024)
c.handle, err = c.openDB(filename, flags) c.handle, err = c.openDB(filename, flags)
if err == nil { if err == nil {
err = initExtensions(c) err = initExtensions(c)
@ -98,6 +112,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
return 0, err return 0, err
} }
c.call("sqlite3_progress_handler_go", uint64(handle), 100)
if flags|OPEN_URI != 0 && strings.HasPrefix(filename, "file:") { if flags|OPEN_URI != 0 && strings.HasPrefix(filename, "file:") {
var pragmas strings.Builder var pragmas strings.Builder
if _, after, ok := strings.Cut(filename, "?"); ok { if _, after, ok := strings.Cut(filename, "?"); ok {
@ -109,6 +124,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
} }
} }
if pragmas.Len() != 0 { if pragmas.Len() != 0 {
c.checkInterrupt(handle)
pragmaPtr := c.arena.string(pragmas.String()) pragmaPtr := c.arena.string(pragmas.String())
r := c.call("sqlite3_exec", uint64(handle), uint64(pragmaPtr), 0, 0, 0) r := c.call("sqlite3_exec", uint64(handle), uint64(pragmaPtr), 0, 0, 0)
if err := c.sqlite.error(r, handle, pragmas.String()); err != nil { if err := c.sqlite.error(r, handle, pragmas.String()); err != nil {
@ -118,7 +134,6 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
} }
} }
} }
c.call("sqlite3_progress_handler_go", uint64(handle), 100)
return handle, nil return handle, nil
} }
@ -160,10 +175,10 @@ func (c *Conn) Close() error {
// //
// https://sqlite.org/c3ref/exec.html // https://sqlite.org/c3ref/exec.html
func (c *Conn) Exec(sql string) error { func (c *Conn) Exec(sql string) error {
c.checkInterrupt()
defer c.arena.mark()() defer c.arena.mark()()
sqlPtr := c.arena.string(sql) sqlPtr := c.arena.string(sql)
c.checkInterrupt(c.handle)
r := c.call("sqlite3_exec", uint64(c.handle), uint64(sqlPtr), 0, 0, 0) r := c.call("sqlite3_exec", uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
return c.error(r, sql) return c.error(r, sql)
} }
@ -301,8 +316,7 @@ func (c *Conn) ReleaseMemory() error {
return c.error(r) return c.error(r)
} }
// GetInterrupt gets the context set with [Conn.SetInterrupt], // GetInterrupt gets the context set with [Conn.SetInterrupt].
// or nil if none was set.
func (c *Conn) GetInterrupt() context.Context { func (c *Conn) GetInterrupt() context.Context {
return c.interrupt return c.interrupt
} }
@ -322,9 +336,11 @@ func (c *Conn) GetInterrupt() context.Context {
// //
// https://sqlite.org/c3ref/interrupt.html // https://sqlite.org/c3ref/interrupt.html
func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) { func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
// Is it the same context? old = c.interrupt
if ctx == c.interrupt { c.interrupt = ctx
return ctx
if ctx == old || ctx.Done() == old.Done() {
return old
} }
// A busy SQL statement prevents SQLite from ignoring an interrupt // A busy SQL statement prevents SQLite from ignoring an interrupt
@ -333,32 +349,29 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
defer c.arena.mark()() defer c.arena.mark()()
stmtPtr := c.arena.new(ptrlen) stmtPtr := c.arena.new(ptrlen)
loopPtr := c.arena.string(`WITH RECURSIVE c(x) AS (VALUES(0) UNION ALL SELECT x FROM c) SELECT x FROM c`) loopPtr := c.arena.string(`WITH RECURSIVE c(x) AS (VALUES(0) UNION ALL SELECT x FROM c) SELECT x FROM c`)
c.call("sqlite3_prepare_v3", uint64(c.handle), uint64(loopPtr), math.MaxUint64, 0, uint64(stmtPtr), 0) c.call("sqlite3_prepare_v3", uint64(c.handle), uint64(loopPtr), math.MaxUint64,
uint64(PREPARE_PERSISTENT), uint64(stmtPtr), 0)
c.pending = &Stmt{c: c} c.pending = &Stmt{c: c}
c.pending.handle = util.ReadUint32(c.mod, stmtPtr) c.pending.handle = util.ReadUint32(c.mod, stmtPtr)
} }
old = c.interrupt if old.Done() != nil && ctx.Err() == nil {
c.interrupt = ctx
if old != nil && old.Done() != nil && (ctx == nil || ctx.Err() == nil) {
c.pending.Reset() c.pending.Reset()
} }
if ctx != nil && ctx.Done() != nil { if ctx.Done() != nil {
c.pending.Step() c.pending.Step()
} }
return old return old
} }
func (c *Conn) checkInterrupt() { func (c *Conn) checkInterrupt(handle uint32) {
if c.interrupt != nil && c.interrupt.Err() != nil { if c.interrupt.Err() != nil {
c.call("sqlite3_interrupt", uint64(c.handle)) c.call("sqlite3_interrupt", uint64(handle))
} }
} }
func progressCallback(ctx context.Context, mod api.Module, pDB uint32) (interrupt uint32) { func progressCallback(ctx context.Context, mod api.Module, _ uint32) (interrupt uint32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.interrupt.Err() != nil {
c.interrupt != nil && c.interrupt.Err() != nil {
interrupt = 1 interrupt = 1
} }
return interrupt return interrupt
@ -373,9 +386,8 @@ func (c *Conn) BusyTimeout(timeout time.Duration) error {
return c.error(r) return c.error(r)
} }
func timeoutCallback(ctx context.Context, mod api.Module, pDB uint32, count, tmout int32) (retry uint32) { func timeoutCallback(ctx context.Context, mod api.Module, count, tmout int32) (retry uint32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.interrupt.Err() == nil {
(c.interrupt == nil || c.interrupt.Err() == nil) {
const delays = "\x01\x02\x05\x0a\x0f\x14\x19\x19\x19\x32\x32\x64" const delays = "\x01\x02\x05\x0a\x0f\x14\x19\x19\x19\x32\x32\x64"
const totals = "\x00\x01\x03\x08\x12\x21\x35\x4e\x67\x80\xb2\xe4" const totals = "\x00\x01\x03\x08\x12\x21\x35\x4e\x67\x80\xb2\xe4"
const ndelay = int32(len(delays) - 1) const ndelay = int32(len(delays) - 1)
@ -391,7 +403,7 @@ func timeoutCallback(ctx context.Context, mod api.Module, pDB uint32, count, tmo
if delay = min(delay, tmout-prior); delay > 0 { if delay = min(delay, tmout-prior); delay > 0 {
delay := time.Duration(delay) * time.Millisecond delay := time.Duration(delay) * time.Millisecond
if c.interrupt == nil || c.interrupt.Done() == nil { if c.interrupt.Done() == nil {
time.Sleep(delay) time.Sleep(delay)
return 1 return 1
} }
@ -414,7 +426,7 @@ func timeoutCallback(ctx context.Context, mod api.Module, pDB uint32, count, tmo
// BusyHandler registers a callback to handle [BUSY] errors. // BusyHandler registers a callback to handle [BUSY] errors.
// //
// https://sqlite.org/c3ref/busy_handler.html // https://sqlite.org/c3ref/busy_handler.html
func (c *Conn) BusyHandler(cb func(count int) (retry bool)) error { func (c *Conn) BusyHandler(cb func(ctx context.Context, count int) (retry bool)) error {
var enable uint64 var enable uint64
if cb != nil { if cb != nil {
enable = 1 enable = 1
@ -428,9 +440,12 @@ func (c *Conn) BusyHandler(cb func(count int) (retry bool)) error {
} }
func busyCallback(ctx context.Context, mod api.Module, pDB uint32, count int32) (retry uint32) { func busyCallback(ctx context.Context, mod api.Module, pDB uint32, count int32) (retry uint32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.busy != nil && if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.busy != nil {
(c.interrupt == nil || c.interrupt.Err() == nil) { interrupt := c.interrupt
if c.busy(int(count)) { if interrupt == nil {
interrupt = context.Background()
}
if interrupt.Err() == nil && c.busy(interrupt, int(count)) {
retry = 1 retry = 1
} }
} }
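Taken together, the changes in this file make the interrupt context non-nil from construction onwards: OpenContext threads a context through the open itself, and SetInterrupt now always has an old context to return. A short usage sketch of the APIs shown above (illustrative; the database name and timeout are made up):

    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()

    // Opening the connection can itself be cancelled via ctx.
    conn, err := sqlite3.OpenContext(ctx, "file:demo.db?_pragma=busy_timeout(10000)")
    if err != nil {
        // handle error
    }
    defer conn.Close()

    // Statements run after this call are interrupted once ctx expires.
    old := conn.SetInterrupt(ctx)
    defer conn.SetInterrupt(old)

    err = conn.Exec(`CREATE TABLE IF NOT EXISTS t (x INTEGER)`)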

View file

@ -1,4 +1,4 @@
//go:build (go1.23 || goexperiment.rangefunc) && !vet //go:build go1.23
package sqlite3 package sqlite3

View file

@ -1,4 +1,4 @@
//go:build !(go1.23 || goexperiment.rangefunc) || vet //go:build !go1.23
package sqlite3 package sqlite3

View file

@ -40,14 +40,14 @@
// When using a custom time struct, you'll have to implement // When using a custom time struct, you'll have to implement
// [database/sql/driver.Valuer] and [database/sql.Scanner]. // [database/sql/driver.Valuer] and [database/sql.Scanner].
// //
// The Value method should ideally serialise to a time [format] supported by SQLite. // The Value method should ideally encode to a time [format] supported by SQLite.
// This ensures SQL date and time functions work as they should, // This ensures SQL date and time functions work as they should,
// and that your schema works with other SQLite tools. // and that your schema works with other SQLite tools.
// [sqlite3.TimeFormat.Encode] may help. // [sqlite3.TimeFormat.Encode] may help.
// //
// The Scan method needs to take into account that the value it receives can be of differing types. // The Scan method needs to take into account that the value it receives can be of differing types.
// It can already be a [time.Time], if the driver decoded the value according to "_timefmt" rules. // It can already be a [time.Time], if the driver decoded the value according to "_timefmt" rules.
// Or it can be a: string, int64, float64, []byte, nil, // Or it can be a: string, int64, float64, []byte, or nil,
// depending on the column type and on what whoever wrote the value stored. // depending on the column type and on what whoever wrote the value stored.
// [sqlite3.TimeFormat.Decode] may help. // [sqlite3.TimeFormat.Decode] may help.
// //
@ -202,19 +202,19 @@ func (n *connector) Driver() driver.Driver {
return n.driver return n.driver
} }
func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) { func (n *connector) Connect(ctx context.Context) (res driver.Conn, err error) {
c := &conn{ c := &conn{
txLock: n.txLock, txLock: n.txLock,
tmRead: n.tmRead, tmRead: n.tmRead,
tmWrite: n.tmWrite, tmWrite: n.tmWrite,
} }
c.Conn, err = sqlite3.Open(n.name) c.Conn, err = sqlite3.OpenContext(ctx, n.name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { defer func() {
if err != nil { if res == nil {
c.Close() c.Close()
} }
}() }()
@ -239,6 +239,7 @@ func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer s.Close()
if s.Step() && s.ColumnBool(0) { if s.Step() && s.ColumnBool(0) {
c.readOnly = '1' c.readOnly = '1'
} else { } else {
@ -466,6 +467,7 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
defer s.Stmt.Conn().SetInterrupt(old) defer s.Stmt.Conn().SetInterrupt(old)
err = s.Stmt.Exec() err = s.Stmt.Exec()
s.Stmt.ClearBindings()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -488,7 +490,7 @@ func (s *stmt) setupBindings(args []driver.NamedValue) (err error) {
if arg.Name == "" { if arg.Name == "" {
ids = append(ids, arg.Ordinal) ids = append(ids, arg.Ordinal)
} else { } else {
for _, prefix := range []string{":", "@", "$"} { for _, prefix := range [...]string{":", "@", "$"} {
if id := s.Stmt.BindIndex(prefix + arg.Name); id != 0 { if id := s.Stmt.BindIndex(prefix + arg.Name); id != 0 {
ids = append(ids, id) ids = append(ids, id)
} }
@ -522,9 +524,9 @@ func (s *stmt) setupBindings(args []driver.NamedValue) (err error) {
default: default:
panic(util.AssertErr()) panic(util.AssertErr())
} }
} if err != nil {
if err != nil { return err
return err }
} }
} }
return nil return nil
@ -595,10 +597,11 @@ func (r *rows) Close() error {
func (r *rows) Columns() []string { func (r *rows) Columns() []string {
if r.names == nil { if r.names == nil {
count := r.Stmt.ColumnCount() count := r.Stmt.ColumnCount()
r.names = make([]string, count) names := make([]string, count)
for i := range r.names { for i := range names {
r.names[i] = r.Stmt.ColumnName(i) names[i] = r.Stmt.ColumnName(i)
} }
r.names = names
} }
return r.names return r.names
} }
@ -606,26 +609,29 @@ func (r *rows) Columns() []string {
func (r *rows) loadTypes() { func (r *rows) loadTypes() {
if r.nulls == nil { if r.nulls == nil {
count := r.Stmt.ColumnCount() count := r.Stmt.ColumnCount()
r.nulls = make([]bool, count) nulls := make([]bool, count)
r.types = make([]string, count) types := make([]string, count)
for i := range r.nulls { for i := range nulls {
if col := r.Stmt.ColumnOriginName(i); col != "" { if col := r.Stmt.ColumnOriginName(i); col != "" {
r.types[i], _, r.nulls[i], _, _, _ = r.Stmt.Conn().TableColumnMetadata( types[i], _, nulls[i], _, _, _ = r.Stmt.Conn().TableColumnMetadata(
r.Stmt.ColumnDatabaseName(i), r.Stmt.ColumnDatabaseName(i),
r.Stmt.ColumnTableName(i), r.Stmt.ColumnTableName(i),
col) col)
} }
} }
r.nulls = nulls
r.types = types
} }
} }
func (r *rows) declType(index int) string { func (r *rows) declType(index int) string {
if r.types == nil { if r.types == nil {
count := r.Stmt.ColumnCount() count := r.Stmt.ColumnCount()
r.types = make([]string, count) types := make([]string, count)
for i := range r.types { for i := range types {
r.types[i] = strings.ToUpper(r.Stmt.ColumnDeclType(i)) types[i] = strings.ToUpper(r.Stmt.ColumnDeclType(i))
} }
r.types = types
} }
return r.types[index] return r.types[index]
} }
@ -665,27 +671,23 @@ func (r *rows) Next(dest []driver.Value) error {
for i := range dest { for i := range dest {
if t, ok := r.decodeTime(i, dest[i]); ok { if t, ok := r.decodeTime(i, dest[i]); ok {
dest[i] = t dest[i] = t
continue
}
if s, ok := dest[i].(string); ok {
t, ok := maybeTime(s)
if ok {
dest[i] = t
}
} }
} }
return err return err
} }
func (r *rows) decodeTime(i int, v any) (_ time.Time, ok bool) { func (r *rows) decodeTime(i int, v any) (_ time.Time, ok bool) {
switch r.tmRead { switch v := v.(type) {
case sqlite3.TimeFormatDefault, time.RFC3339Nano: case int64, float64:
// handled by maybeTime
return
}
switch v.(type) {
case int64, float64, string:
// could be a time value // could be a time value
case string:
if r.tmWrite != "" && r.tmWrite != time.RFC3339 && r.tmWrite != time.RFC3339Nano {
break
}
t, ok := maybeTime(v)
if ok {
return t, true
}
default: default:
return return
} }
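The doc-comment hunk at the top of this file concerns custom time types; a hedged sketch of what implementing it could look like, using an invented type name and assuming sqlite3.TimeFormatDefault's Encode/Decode behave as the comment describes:

    // Stamp stores its value through the driver in a SQLite-friendly time format.
    type Stamp struct{ time.Time }

    // Value implements database/sql/driver.Valuer.
    func (s Stamp) Value() (driver.Value, error) {
        return sqlite3.TimeFormatDefault.Encode(s.Time), nil
    }

    // Scan implements database/sql.Scanner; the incoming value may be a
    // string, int64, float64, []byte or nil, depending on how it was written.
    func (s *Stamp) Scan(v any) error {
        t, err := sqlite3.TimeFormatDefault.Decode(v)
        if err != nil {
            return err
        }
        s.Time = t
        return nil
    }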

View file

@ -9,6 +9,7 @@ The following optional features are compiled in:
- [JSON](https://sqlite.org/json1.html) - [JSON](https://sqlite.org/json1.html)
- [R*Tree](https://sqlite.org/rtree.html) - [R*Tree](https://sqlite.org/rtree.html)
- [GeoPoly](https://sqlite.org/geopoly.html) - [GeoPoly](https://sqlite.org/geopoly.html)
- [Spellfix1](https://sqlite.org/spellfix1.html)
- [soundex](https://sqlite.org/lang_corefunc.html#soundex) - [soundex](https://sqlite.org/lang_corefunc.html#soundex)
- [stat4](https://sqlite.org/compile.html#enable_stat4) - [stat4](https://sqlite.org/compile.html#enable_stat4)
- [base64](https://github.com/sqlite/sqlite/blob/master/ext/misc/base64.c) - [base64](https://github.com/sqlite/sqlite/blob/master/ext/misc/base64.c)

View file

@ -14,7 +14,7 @@ trap 'rm -f sqlite3.tmp' EXIT
-o sqlite3.wasm "$ROOT/sqlite3/main.c" \ -o sqlite3.wasm "$ROOT/sqlite3/main.c" \
-I"$ROOT/sqlite3" \ -I"$ROOT/sqlite3" \
-mexec-model=reactor \ -mexec-model=reactor \
-matomics -msimd128 -mmutable-globals \ -matomics -msimd128 -mmutable-globals -mmultivalue \
-mbulk-memory -mreference-types \ -mbulk-memory -mreference-types \
-mnontrapping-fptoint -msign-ext \ -mnontrapping-fptoint -msign-ext \
-fno-stack-protector -fno-stack-clash-protection \ -fno-stack-protector -fno-stack-clash-protection \

View file

@ -51,6 +51,7 @@ sqlite3_create_collation_go
sqlite3_create_function_go sqlite3_create_function_go
sqlite3_create_module_go sqlite3_create_module_go
sqlite3_create_window_function_go sqlite3_create_window_function_go
sqlite3_data_count
sqlite3_database_file_object sqlite3_database_file_object
sqlite3_db_cacheflush sqlite3_db_cacheflush
sqlite3_db_config sqlite3_db_config

Binary file not shown.

View file

@ -33,16 +33,23 @@ func (c *Conn) CollationNeeded(cb func(db *Conn, name string)) error {
// one or more unknown collating sequences. // one or more unknown collating sequences.
func (c Conn) AnyCollationNeeded() error { func (c Conn) AnyCollationNeeded() error {
r := c.call("sqlite3_anycollseq_init", uint64(c.handle), 0, 0) r := c.call("sqlite3_anycollseq_init", uint64(c.handle), 0, 0)
return c.error(r) if err := c.error(r); err != nil {
return err
}
c.collation = nil
return nil
} }
// CreateCollation defines a new collating sequence. // CreateCollation defines a new collating sequence.
// //
// https://sqlite.org/c3ref/create_collation.html // https://sqlite.org/c3ref/create_collation.html
func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error { func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
var funcPtr uint32
defer c.arena.mark()() defer c.arena.mark()()
namePtr := c.arena.string(name) namePtr := c.arena.string(name)
funcPtr := util.AddHandle(c.ctx, fn) if fn != nil {
funcPtr = util.AddHandle(c.ctx, fn)
}
r := c.call("sqlite3_create_collation_go", r := c.call("sqlite3_create_collation_go",
uint64(c.handle), uint64(namePtr), uint64(funcPtr)) uint64(c.handle), uint64(namePtr), uint64(funcPtr))
return c.error(r) return c.error(r)
@ -52,9 +59,12 @@ funcPtr := util.AddHandle(c.ctx, fn)
// //
// https://sqlite.org/c3ref/create_function.html // https://sqlite.org/c3ref/create_function.html
func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn ScalarFunction) error { func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn ScalarFunction) error {
var funcPtr uint32
defer c.arena.mark()() defer c.arena.mark()()
namePtr := c.arena.string(name) namePtr := c.arena.string(name)
funcPtr := util.AddHandle(c.ctx, fn) if fn != nil {
funcPtr = util.AddHandle(c.ctx, fn)
}
r := c.call("sqlite3_create_function_go", r := c.call("sqlite3_create_function_go",
uint64(c.handle), uint64(namePtr), uint64(nArg), uint64(c.handle), uint64(namePtr), uint64(nArg),
uint64(flag), uint64(funcPtr)) uint64(flag), uint64(funcPtr))
@ -71,10 +81,13 @@ funcPtr := util.AddHandle(c.ctx, fn)
// //
// https://sqlite.org/c3ref/create_function.html // https://sqlite.org/c3ref/create_function.html
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error { func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error {
var funcPtr uint32
defer c.arena.mark()() defer c.arena.mark()()
call := "sqlite3_create_aggregate_function_go"
namePtr := c.arena.string(name) namePtr := c.arena.string(name)
funcPtr := util.AddHandle(c.ctx, fn) if fn != nil {
funcPtr = util.AddHandle(c.ctx, fn)
}
call := "sqlite3_create_aggregate_function_go"
if _, ok := fn().(WindowFunction); ok { if _, ok := fn().(WindowFunction); ok {
call = "sqlite3_create_window_function_go" call = "sqlite3_create_window_function_go"
} }
@ -184,11 +197,12 @@ func callbackAggregate(db *Conn, pAgg, pApp uint32) (AggregateFunction, uint32)
// We need to create the aggregate. // We need to create the aggregate.
fn := util.GetHandle(db.ctx, pApp).(func() AggregateFunction)() fn := util.GetHandle(db.ctx, pApp).(func() AggregateFunction)()
handle := util.AddHandle(db.ctx, fn)
if pAgg != 0 { if pAgg != 0 {
handle := util.AddHandle(db.ctx, fn)
util.WriteUint32(db.mod, pAgg, handle) util.WriteUint32(db.mod, pAgg, handle)
return fn, handle
} }
return fn, handle return fn, 0
} }
func callbackArgs(db *Conn, arg []Value, pArg uint32) { func callbackArgs(db *Conn, arg []Value, pArg uint32) {
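For reference, registering a collation through the API touched above looks roughly like this (illustrative; assumes an open *sqlite3.Conn named conn):

    // Register a case-insensitive collation named "nocase_go".
    // With the change above, a nil fn no longer allocates a callback handle.
    err := conn.CreateCollation("nocase_go", func(a, b []byte) int {
        return strings.Compare(
            strings.ToLower(string(a)),
            strings.ToLower(string(b)),
        )
    })
    if err != nil {
        // handle error
    }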

View file

@ -1,10 +1,13 @@
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4= golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4=
golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk= golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk=
golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8= golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8=
golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=

View file

@ -47,7 +47,7 @@ func (m *mmappedMemory) Reallocate(size uint64) []byte {
// Commit additional memory up to new bytes. // Commit additional memory up to new bytes.
err := unix.Mprotect(m.buf[com:new], unix.PROT_READ|unix.PROT_WRITE) err := unix.Mprotect(m.buf[com:new], unix.PROT_READ|unix.PROT_WRITE)
if err != nil { if err != nil {
panic(err) return nil
} }
// Update committed memory. // Update committed memory.

View file

@ -56,7 +56,7 @@ func (m *virtualMemory) Reallocate(size uint64) []byte {
// Commit additional memory up to new bytes. // Commit additional memory up to new bytes.
_, err := windows.VirtualAlloc(m.addr, uintptr(new), windows.MEM_COMMIT, windows.PAGE_READWRITE) _, err := windows.VirtualAlloc(m.addr, uintptr(new), windows.MEM_COMMIT, windows.PAGE_READWRITE)
if err != nil { if err != nil {
panic(err) return nil
} }
// Update committed memory. // Update committed memory.
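With both platform implementations now returning nil instead of panicking when committing memory fails, callers are expected to treat nil as an allocation failure. A minimal hedged sketch of that check (the surrounding names are assumptions):

    buf := mem.Reallocate(newSize)
    if buf == nil {
        return fmt.Errorf("wasm memory: failed to commit %d bytes", newSize)
    }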

View file

@ -26,6 +26,7 @@ func ExportFuncVI[T0 i32](mod wazero.HostModuleBuilder, name string, fn func(con
type funcVII[T0, T1 i32] func(context.Context, api.Module, T0, T1) type funcVII[T0, T1 i32] func(context.Context, api.Module, T0, T1)
func (fn funcVII[T0, T1]) Call(ctx context.Context, mod api.Module, stack []uint64) { func (fn funcVII[T0, T1]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[1] // prevent bounds check on every slice access
fn(ctx, mod, T0(stack[0]), T1(stack[1])) fn(ctx, mod, T0(stack[0]), T1(stack[1]))
} }
@ -39,6 +40,7 @@ func ExportFuncVII[T0, T1 i32](mod wazero.HostModuleBuilder, name string, fn fun
type funcVIII[T0, T1, T2 i32] func(context.Context, api.Module, T0, T1, T2) type funcVIII[T0, T1, T2 i32] func(context.Context, api.Module, T0, T1, T2)
func (fn funcVIII[T0, T1, T2]) Call(ctx context.Context, mod api.Module, stack []uint64) { func (fn funcVIII[T0, T1, T2]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[2] // prevent bounds check on every slice access
fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2])) fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]))
} }
@ -52,6 +54,7 @@ func ExportFuncVIII[T0, T1, T2 i32](mod wazero.HostModuleBuilder, name string, f
type funcVIIII[T0, T1, T2, T3 i32] func(context.Context, api.Module, T0, T1, T2, T3) type funcVIIII[T0, T1, T2, T3 i32] func(context.Context, api.Module, T0, T1, T2, T3)
func (fn funcVIIII[T0, T1, T2, T3]) Call(ctx context.Context, mod api.Module, stack []uint64) { func (fn funcVIIII[T0, T1, T2, T3]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[3] // prevent bounds check on every slice access
fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3])) fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]))
} }
@ -65,6 +68,7 @@ func ExportFuncVIIII[T0, T1, T2, T3 i32](mod wazero.HostModuleBuilder, name stri
type funcVIIIII[T0, T1, T2, T3, T4 i32] func(context.Context, api.Module, T0, T1, T2, T3, T4) type funcVIIIII[T0, T1, T2, T3, T4 i32] func(context.Context, api.Module, T0, T1, T2, T3, T4)
func (fn funcVIIIII[T0, T1, T2, T3, T4]) Call(ctx context.Context, mod api.Module, stack []uint64) { func (fn funcVIIIII[T0, T1, T2, T3, T4]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[4] // prevent bounds check on every slice access
fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4])) fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4]))
} }
@ -78,6 +82,7 @@ func ExportFuncVIIIII[T0, T1, T2, T3, T4 i32](mod wazero.HostModuleBuilder, name
type funcVIIIIJ[T0, T1, T2, T3 i32, T4 i64] func(context.Context, api.Module, T0, T1, T2, T3, T4) type funcVIIIIJ[T0, T1, T2, T3 i32, T4 i64] func(context.Context, api.Module, T0, T1, T2, T3, T4)
func (fn funcVIIIIJ[T0, T1, T2, T3, T4]) Call(ctx context.Context, mod api.Module, stack []uint64) { func (fn funcVIIIIJ[T0, T1, T2, T3, T4]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[4] // prevent bounds check on every slice access
fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4])) fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4]))
} }
@ -104,6 +109,7 @@ func ExportFuncII[TR, T0 i32](mod wazero.HostModuleBuilder, name string, fn func
type funcIII[TR, T0, T1 i32] func(context.Context, api.Module, T0, T1) TR type funcIII[TR, T0, T1 i32] func(context.Context, api.Module, T0, T1) TR
func (fn funcIII[TR, T0, T1]) Call(ctx context.Context, mod api.Module, stack []uint64) { func (fn funcIII[TR, T0, T1]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[1] // prevent bounds check on every slice access
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]))) stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1])))
} }
@ -117,6 +123,7 @@ func ExportFuncIII[TR, T0, T1 i32](mod wazero.HostModuleBuilder, name string, fn
type funcIIII[TR, T0, T1, T2 i32] func(context.Context, api.Module, T0, T1, T2) TR type funcIIII[TR, T0, T1, T2 i32] func(context.Context, api.Module, T0, T1, T2) TR
func (fn funcIIII[TR, T0, T1, T2]) Call(ctx context.Context, mod api.Module, stack []uint64) { func (fn funcIIII[TR, T0, T1, T2]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[2] // prevent bounds check on every slice access
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]))) stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2])))
} }
@ -130,6 +137,7 @@ func ExportFuncIIII[TR, T0, T1, T2 i32](mod wazero.HostModuleBuilder, name strin
type funcIIIII[TR, T0, T1, T2, T3 i32] func(context.Context, api.Module, T0, T1, T2, T3) TR type funcIIIII[TR, T0, T1, T2, T3 i32] func(context.Context, api.Module, T0, T1, T2, T3) TR
func (fn funcIIIII[TR, T0, T1, T2, T3]) Call(ctx context.Context, mod api.Module, stack []uint64) { func (fn funcIIIII[TR, T0, T1, T2, T3]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[3] // prevent bounds check on every slice access
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]))) stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3])))
} }
@ -143,6 +151,7 @@ func ExportFuncIIIII[TR, T0, T1, T2, T3 i32](mod wazero.HostModuleBuilder, name
type funcIIIIII[TR, T0, T1, T2, T3, T4 i32] func(context.Context, api.Module, T0, T1, T2, T3, T4) TR type funcIIIIII[TR, T0, T1, T2, T3, T4 i32] func(context.Context, api.Module, T0, T1, T2, T3, T4) TR
func (fn funcIIIIII[TR, T0, T1, T2, T3, T4]) Call(ctx context.Context, mod api.Module, stack []uint64) { func (fn funcIIIIII[TR, T0, T1, T2, T3, T4]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[4] // prevent bounds check on every slice access
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4]))) stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4])))
} }
@ -156,6 +165,7 @@ func ExportFuncIIIIII[TR, T0, T1, T2, T3, T4 i32](mod wazero.HostModuleBuilder,
type funcIIIIIII[TR, T0, T1, T2, T3, T4, T5 i32] func(context.Context, api.Module, T0, T1, T2, T3, T4, T5) TR type funcIIIIIII[TR, T0, T1, T2, T3, T4, T5 i32] func(context.Context, api.Module, T0, T1, T2, T3, T4, T5) TR
func (fn funcIIIIIII[TR, T0, T1, T2, T3, T4, T5]) Call(ctx context.Context, mod api.Module, stack []uint64) { func (fn funcIIIIIII[TR, T0, T1, T2, T3, T4, T5]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[5] // prevent bounds check on every slice access
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4]), T5(stack[5]))) stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4]), T5(stack[5])))
} }
@ -169,6 +179,7 @@ func ExportFuncIIIIIII[TR, T0, T1, T2, T3, T4, T5 i32](mod wazero.HostModuleBuil
type funcIIIIJ[TR, T0, T1, T2 i32, T3 i64] func(context.Context, api.Module, T0, T1, T2, T3) TR type funcIIIIJ[TR, T0, T1, T2 i32, T3 i64] func(context.Context, api.Module, T0, T1, T2, T3) TR
func (fn funcIIIIJ[TR, T0, T1, T2, T3]) Call(ctx context.Context, mod api.Module, stack []uint64) { func (fn funcIIIIJ[TR, T0, T1, T2, T3]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[3] // prevent bounds check on every slice access
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]))) stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3])))
} }
@ -182,6 +193,7 @@ func ExportFuncIIIIJ[TR, T0, T1, T2 i32, T3 i64](mod wazero.HostModuleBuilder, n
type funcIIJ[TR, T0 i32, T1 i64] func(context.Context, api.Module, T0, T1) TR type funcIIJ[TR, T0 i32, T1 i64] func(context.Context, api.Module, T0, T1) TR
func (fn funcIIJ[TR, T0, T1]) Call(ctx context.Context, mod api.Module, stack []uint64) { func (fn funcIIJ[TR, T0, T1]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[1] // prevent bounds check on every slice access
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]))) stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1])))
} }
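
The "_ = stack[N]" lines added to these host-function shims are the usual Go bounds-check-elimination hint: one explicit index up front lets the compiler prove the later accesses are in range and drop their checks. A minimal, self-contained sketch of the idiom (the names here are illustrative, not part of wazero):

package main

import "fmt"

// sumFirstFour expects at least four elements. The single access to
// s[3] performs one bounds check, after which the compiler can elide
// the checks on the indexes below.
func sumFirstFour(s []uint64) uint64 {
	_ = s[3] // prevent bounds check on every slice access
	return s[0] + s[1] + s[2] + s[3]
}

func main() {
	fmt.Println(sumFirstFour([]uint64{1, 2, 3, 4, 5})) // 10
}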


@ -35,17 +35,22 @@ func DelHandle(ctx context.Context, id uint32) error {
s := ctx.Value(moduleKey{}).(*moduleState) s := ctx.Value(moduleKey{}).(*moduleState)
a := s.handles[^id] a := s.handles[^id]
s.handles[^id] = nil s.handles[^id] = nil
s.holes++ if l := uint32(len(s.handles)); l == ^id {
s.handles = s.handles[:l-1]
} else {
s.holes++
}
if c, ok := a.(io.Closer); ok { if c, ok := a.(io.Closer); ok {
return c.Close() return c.Close()
} }
return nil return nil
} }
func AddHandle(ctx context.Context, a any) (id uint32) { func AddHandle(ctx context.Context, a any) uint32 {
if a == nil { if a == nil {
panic(NilErr) panic(NilErr)
} }
s := ctx.Value(moduleKey{}).(*moduleState) s := ctx.Value(moduleKey{}).(*moduleState)
// Find an empty slot. // Find an empty slot.
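
DelHandle now trims the slice when the freed handle happens to be the most recently allocated one, and only counts a hole otherwise, so the table does not keep growing under stack-like allocate/free patterns. A toy sketch of the same bookkeeping, independent of the real handle encoding (which stores entries under the complemented id):

package main

import "fmt"

type table struct {
	items []any
	holes int
}

// del frees slot i. If it is the last slot, shrink the slice;
// otherwise record a hole for a later insert to reuse.
func (t *table) del(i int) {
	t.items[i] = nil
	if i == len(t.items)-1 {
		t.items = t.items[:i]
	} else {
		t.holes++
	}
}

func main() {
	t := &table{items: []any{"a", "b", "c"}}
	t.del(2) // last slot: slice shrinks to 2
	t.del(0) // interior slot: becomes a hole
	fmt.Println(len(t.items), t.holes) // 2 1
}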


@ -3,6 +3,7 @@
import ( import (
"bytes" "bytes"
"math" "math"
"reflect"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -13,6 +14,9 @@
// Quote escapes and quotes a value // Quote escapes and quotes a value
// making it safe to embed in SQL text. // making it safe to embed in SQL text.
// Strings with embedded NUL characters are truncated.
//
// https://sqlite.org/lang_corefunc.html#quote
func Quote(value any) string { func Quote(value any) string {
switch v := value.(type) { switch v := value.(type) {
case nil: case nil:
@ -42,8 +46,8 @@ func Quote(value any) string {
return "'" + v.Format(time.RFC3339Nano) + "'" return "'" + v.Format(time.RFC3339Nano) + "'"
case string: case string:
if strings.IndexByte(v, 0) >= 0 { if i := strings.IndexByte(v, 0); i >= 0 {
break v = v[:i]
} }
buf := make([]byte, 2+len(v)+strings.Count(v, "'")) buf := make([]byte, 2+len(v)+strings.Count(v, "'"))
@ -57,13 +61,13 @@ func Quote(value any) string {
buf[i] = b buf[i] = b
i += 1 i += 1
} }
buf[i] = '\'' buf[len(buf)-1] = '\''
return unsafe.String(&buf[0], len(buf)) return unsafe.String(&buf[0], len(buf))
case []byte: case []byte:
buf := make([]byte, 3+2*len(v)) buf := make([]byte, 3+2*len(v))
buf[0] = 'x'
buf[1] = '\'' buf[1] = '\''
buf[0] = 'x'
i := 2 i := 2
for _, b := range v { for _, b := range v {
const hex = "0123456789ABCDEF" const hex = "0123456789ABCDEF"
@ -71,26 +75,50 @@ func Quote(value any) string {
buf[i+1] = hex[b%16] buf[i+1] = hex[b%16]
i += 2 i += 2
} }
buf[i] = '\'' buf[len(buf)-1] = '\''
return unsafe.String(&buf[0], len(buf)) return unsafe.String(&buf[0], len(buf))
case ZeroBlob: case ZeroBlob:
if v > ZeroBlob(1e9-3)/2 {
break
}
buf := bytes.Repeat([]byte("0"), int(3+2*int64(v))) buf := bytes.Repeat([]byte("0"), int(3+2*int64(v)))
buf[0] = 'x'
buf[1] = '\'' buf[1] = '\''
buf[0] = 'x'
buf[len(buf)-1] = '\'' buf[len(buf)-1] = '\''
return unsafe.String(&buf[0], len(buf)) return unsafe.String(&buf[0], len(buf))
} }
v := reflect.ValueOf(value)
k := v.Kind()
if k == reflect.Interface || k == reflect.Pointer {
if v.IsNil() {
return "NULL"
}
v = v.Elem()
k = v.Kind()
}
switch {
case v.CanInt():
return strconv.FormatInt(v.Int(), 10)
case v.CanUint():
return strconv.FormatUint(v.Uint(), 10)
case v.CanFloat():
return Quote(v.Float())
case k == reflect.Bool:
return Quote(v.Bool())
case k == reflect.String:
return Quote(v.String())
case (k == reflect.Slice || k == reflect.Array && v.CanAddr()) &&
v.Type().Elem().Kind() == reflect.Uint8:
return Quote(v.Bytes())
}
panic(util.ValueErr) panic(util.ValueErr)
} }
// QuoteIdentifier escapes and quotes an identifier // QuoteIdentifier escapes and quotes an identifier
// making it safe to embed in SQL text. // making it safe to embed in SQL text.
// Strings with embedded NUL characters panic.
func QuoteIdentifier(id string) string { func QuoteIdentifier(id string) string {
if strings.IndexByte(id, 0) >= 0 { if strings.IndexByte(id, 0) >= 0 {
panic(util.ValueErr) panic(util.ValueErr)
@ -107,6 +135,6 @@ func QuoteIdentifier(id string) string {
buf[i] = b buf[i] = b
i += 1 i += 1
} }
buf[i] = '"' buf[len(buf)-1] = '"'
return unsafe.String(&buf[0], len(buf)) return unsafe.String(&buf[0], len(buf))
} }
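
Two behavioural changes to Quote are visible here: strings containing a NUL byte are now truncated at the NUL rather than rejected, and a reflection fallback accepts named types whose kind is integer, float, bool, string, or byte slice. A hedged usage sketch; the import path of this vendored driver is assumed to be github.com/ncruces/go-sqlite3:

package main

import (
	"fmt"

	"github.com/ncruces/go-sqlite3" // assumed import path of the vendored driver
)

type userID int64 // a named integer type, now handled by the reflect fallback

func main() {
	fmt.Println(sqlite3.Quote("it's"))       // 'it''s'
	fmt.Println(sqlite3.Quote("a\x00b"))     // 'a' (text after the NUL is dropped)
	fmt.Println(sqlite3.Quote(userID(42)))   // 42
	fmt.Println(sqlite3.Quote([]byte{0xCA})) // x'CA'
	fmt.Println(sqlite3.Quote(nil))          // NULL
}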


@ -131,7 +131,7 @@ func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error {
err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_LENGTH) err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_LENGTH)
} }
if sql != nil { if len(sql) != 0 {
if r := sqlt.call("sqlite3_error_offset", uint64(handle)); r != math.MaxUint32 { if r := sqlt.call("sqlite3_error_offset", uint64(handle)); r != math.MaxUint32 {
err.sql = sql[0][r:] err.sql = sql[0][r:]
} }
@ -301,7 +301,7 @@ func (a *arena) string(s string) uint32 {
func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder { func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
util.ExportFuncII(env, "go_progress_handler", progressCallback) util.ExportFuncII(env, "go_progress_handler", progressCallback)
util.ExportFuncIIII(env, "go_busy_timeout", timeoutCallback) util.ExportFuncIII(env, "go_busy_timeout", timeoutCallback)
util.ExportFuncIII(env, "go_busy_handler", busyCallback) util.ExportFuncIII(env, "go_busy_handler", busyCallback)
util.ExportFuncII(env, "go_commit_hook", commitCallback) util.ExportFuncII(env, "go_commit_hook", commitCallback)
util.ExportFuncVI(env, "go_rollback_hook", rollbackCallback) util.ExportFuncVI(env, "go_rollback_hook", rollbackCallback)


@ -30,12 +30,13 @@ func (s *Stmt) Close() error {
} }
r := s.c.call("sqlite3_finalize", uint64(s.handle)) r := s.c.call("sqlite3_finalize", uint64(s.handle))
for i := range s.c.stmts { stmts := s.c.stmts
if s == s.c.stmts[i] { for i := range stmts {
l := len(s.c.stmts) - 1 if s == stmts[i] {
s.c.stmts[i] = s.c.stmts[l] l := len(stmts) - 1
s.c.stmts[l] = nil stmts[i] = stmts[l]
s.c.stmts = s.c.stmts[:l] stmts[l] = nil
s.c.stmts = stmts[:l]
break break
} }
} }
@ -105,7 +106,7 @@ func (s *Stmt) Busy() bool {
// //
// https://sqlite.org/c3ref/step.html // https://sqlite.org/c3ref/step.html
func (s *Stmt) Step() bool { func (s *Stmt) Step() bool {
s.c.checkInterrupt() s.c.checkInterrupt(s.c.handle)
r := s.c.call("sqlite3_step", uint64(s.handle)) r := s.c.call("sqlite3_step", uint64(s.handle))
switch r { switch r {
case _ROW: case _ROW:
@ -376,6 +377,15 @@ func (s *Stmt) BindValue(param int, value Value) error {
return s.c.error(r) return s.c.error(r)
} }
// DataCount resets the number of columns in a result set.
//
// https://sqlite.org/c3ref/data_count.html
func (s *Stmt) DataCount() int {
r := s.c.call("sqlite3_data_count",
uint64(s.handle))
return int(int32(r))
}
// ColumnCount returns the number of columns in a result set. // ColumnCount returns the number of columns in a result set.
// //
// https://sqlite.org/c3ref/column_count.html // https://sqlite.org/c3ref/column_count.html
@ -630,7 +640,7 @@ func (s *Stmt) Columns(dest []any) error {
defer s.c.arena.mark()() defer s.c.arena.mark()()
count := uint64(len(dest)) count := uint64(len(dest))
typePtr := s.c.arena.new(count) typePtr := s.c.arena.new(count)
dataPtr := s.c.arena.new(8 * count) dataPtr := s.c.arena.new(count * 8)
r := s.c.call("sqlite3_columns_go", r := s.c.call("sqlite3_columns_go",
uint64(s.handle), count, uint64(typePtr), uint64(dataPtr)) uint64(s.handle), count, uint64(typePtr), uint64(dataPtr))
@ -639,26 +649,31 @@ func (s *Stmt) Columns(dest []any) error {
} }
types := util.View(s.c.mod, typePtr, count) types := util.View(s.c.mod, typePtr, count)
// Avoid bounds checks on types below.
if len(types) != len(dest) {
panic(util.AssertErr())
}
for i := range dest { for i := range dest {
switch types[i] { switch types[i] {
case byte(INTEGER): case byte(INTEGER):
dest[i] = int64(util.ReadUint64(s.c.mod, dataPtr+8*uint32(i))) dest[i] = int64(util.ReadUint64(s.c.mod, dataPtr))
continue
case byte(FLOAT): case byte(FLOAT):
dest[i] = util.ReadFloat64(s.c.mod, dataPtr+8*uint32(i)) dest[i] = util.ReadFloat64(s.c.mod, dataPtr)
continue
case byte(NULL): case byte(NULL):
dest[i] = nil dest[i] = nil
continue default:
} ptr := util.ReadUint32(s.c.mod, dataPtr+0)
ptr := util.ReadUint32(s.c.mod, dataPtr+8*uint32(i)+0) len := util.ReadUint32(s.c.mod, dataPtr+4)
len := util.ReadUint32(s.c.mod, dataPtr+8*uint32(i)+4) buf := util.View(s.c.mod, ptr, uint64(len))
buf := util.View(s.c.mod, ptr, uint64(len)) if types[i] == byte(TEXT) {
if types[i] == byte(TEXT) { dest[i] = string(buf)
dest[i] = string(buf) } else {
} else { dest[i] = buf
dest[i] = buf }
} }
dataPtr += 8
} }
return nil return nil
} }
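
The new DataCount wraps sqlite3_data_count, which differs from ColumnCount in that it reports 0 whenever the statement has no row available. A hedged sketch of the difference, under the same assumed import path as above:

package main

import (
	"fmt"

	"github.com/ncruces/go-sqlite3" // assumed import path of the vendored driver
)

func main() {
	conn, err := sqlite3.Open(":memory:")
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	stmt, _, err := conn.Prepare(`SELECT 1, 2`)
	if err != nil {
		panic(err)
	}
	defer stmt.Close()

	fmt.Println(stmt.ColumnCount()) // 2: known from the prepared statement
	fmt.Println(stmt.DataCount())   // 0: no row is available yet

	for stmt.Step() {
		fmt.Println(stmt.DataCount()) // 2: a row is available
	}
	fmt.Println(stmt.DataCount()) // 0 again once stepping is done
}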


@ -138,6 +138,9 @@ func (f TimeFormat) Encode(t time.Time) any {
// //
// https://sqlite.org/lang_datefunc.html // https://sqlite.org/lang_datefunc.html
func (f TimeFormat) Decode(v any) (time.Time, error) { func (f TimeFormat) Decode(v any) (time.Time, error) {
if t, ok := v.(time.Time); ok {
return t, nil
}
switch f { switch f {
// Numeric formats. // Numeric formats.
case TimeFormatJulianDay: case TimeFormatJulianDay:
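
Decode now short-circuits when the value is already a time.Time, so a time bound to a statement and read back through the any-typed column APIs no longer has to be re-parsed in the configured format. A small hedged example (same assumed import path):

package main

import (
	"fmt"
	"time"

	"github.com/ncruces/go-sqlite3" // assumed import path of the vendored driver
)

func main() {
	now := time.Now()
	// A value that is already a time.Time is returned unchanged,
	// whatever the TimeFormat would otherwise parse.
	t, err := sqlite3.TimeFormatJulianDay.Decode(now)
	fmt.Println(t.Equal(now), err) // true <nil>
}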


@ -3,7 +3,6 @@
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"math/rand" "math/rand"
"runtime" "runtime"
"strconv" "strconv"
@ -136,23 +135,21 @@ type Savepoint struct {
// //
// https://sqlite.org/lang_savepoint.html // https://sqlite.org/lang_savepoint.html
func (c *Conn) Savepoint() Savepoint { func (c *Conn) Savepoint() Savepoint {
// Names can be reused; this makes catching bugs more likely. name := callerName()
name := saveptName() + "_" + strconv.Itoa(int(rand.Int31())) if name == "" {
name = "sqlite3.Savepoint"
}
// Names can be reused, but this makes catching bugs more likely.
name = QuoteIdentifier(name + "_" + strconv.Itoa(int(rand.Int31())))
err := c.txnExecInterrupted(fmt.Sprintf("SAVEPOINT %q;", name)) err := c.txnExecInterrupted("SAVEPOINT " + name)
if err != nil { if err != nil {
panic(err) panic(err)
} }
return Savepoint{c: c, name: name} return Savepoint{c: c, name: name}
} }
func saveptName() (name string) { func callerName() (name string) {
defer func() {
if name == "" {
name = "sqlite3.Savepoint"
}
}()
var pc [8]uintptr var pc [8]uintptr
n := runtime.Callers(3, pc[:]) n := runtime.Callers(3, pc[:])
if n <= 0 { if n <= 0 {
@ -189,7 +186,7 @@ func (s Savepoint) Release(errp *error) {
if s.c.GetAutocommit() { // There is nothing to commit. if s.c.GetAutocommit() { // There is nothing to commit.
return return
} }
*errp = s.c.Exec(fmt.Sprintf("RELEASE %q;", s.name)) *errp = s.c.Exec("RELEASE " + s.name)
if *errp == nil { if *errp == nil {
return return
} }
@ -201,10 +198,8 @@ func (s Savepoint) Release(errp *error) {
return return
} }
// ROLLBACK and RELEASE even if interrupted. // ROLLBACK and RELEASE even if interrupted.
err := s.c.txnExecInterrupted(fmt.Sprintf(` err := s.c.txnExecInterrupted("ROLLBACK TO " +
ROLLBACK TO %[1]q; s.name + "; RELEASE " + s.name)
RELEASE %[1]q;
`, s.name))
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -217,7 +212,7 @@ func (s Savepoint) Release(errp *error) {
// https://sqlite.org/lang_transaction.html // https://sqlite.org/lang_transaction.html
func (s Savepoint) Rollback() error { func (s Savepoint) Rollback() error {
// ROLLBACK even if interrupted. // ROLLBACK even if interrupted.
return s.c.txnExecInterrupted(fmt.Sprintf("ROLLBACK TO %q;", s.name)) return s.c.txnExecInterrupted("ROLLBACK TO " + s.name)
} }
func (c *Conn) txnExecInterrupted(sql string) error { func (c *Conn) txnExecInterrupted(sql string) error {
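
Savepoint names are now built from the caller's function name, passed through QuoteIdentifier, and concatenated into the SQL directly; fmt.Sprintf with %q produced Go-style quoting, which is not identical to SQL identifier quoting. The calling pattern is unchanged. A function-level sketch under the same import-path assumption, with Exec assumed to have the signature Exec(sql string) error:

// transfer runs two statements inside a savepoint; Release commits on
// success and rolls back if *errp is non-nil when the defer runs.
func transfer(conn *sqlite3.Conn) (err error) {
	savept := conn.Savepoint()
	defer savept.Release(&err)

	if err = conn.Exec(`UPDATE accounts SET balance = balance - 10 WHERE id = 1`); err != nil {
		return err
	}
	return conn.Exec(`UPDATE accounts SET balance = balance + 10 WHERE id = 2`)
}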


@ -15,7 +15,7 @@ func OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
if name == "" { if name == "" {
return nil, &os.PathError{Op: "open", Path: name, Err: ENOENT} return nil, &os.PathError{Op: "open", Path: name, Err: ENOENT}
} }
r, e := syscallOpen(name, flag, uint32(perm.Perm())) r, e := syscallOpen(name, flag|O_CLOEXEC, uint32(perm.Perm()))
if e != nil { if e != nil {
return nil, &os.PathError{Op: "open", Path: name, Err: e} return nil, &os.PathError{Op: "open", Path: name, Err: e}
} }


@ -19,17 +19,18 @@ func (vfsOS) FullPathname(path string) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
fi, err := os.Lstat(path) return path, testSymlinks(filepath.Dir(path))
}
func testSymlinks(path string) error {
p, err := filepath.EvalSymlinks(path)
if err != nil { if err != nil {
if errors.Is(err, fs.ErrNotExist) { return err
return path, nil
}
return "", err
} }
if fi.Mode()&fs.ModeSymlink != 0 { if p != path {
err = _OK_SYMLINK return _OK_SYMLINK
} }
return path, err return nil
} }
func (vfsOS) Delete(path string, syncDir bool) error { func (vfsOS) Delete(path string, syncDir bool) error {
@ -74,7 +75,7 @@ func (vfsOS) Open(name string, flags OpenFlag) (File, OpenFlag, error) {
} }
func (vfsOS) OpenFilename(name *Filename, flags OpenFlag) (File, OpenFlag, error) { func (vfsOS) OpenFilename(name *Filename, flags OpenFlag) (File, OpenFlag, error) {
var oflags int oflags := _O_NOFOLLOW
if flags&OPEN_EXCLUSIVE != 0 { if flags&OPEN_EXCLUSIVE != 0 {
oflags |= os.O_EXCL oflags |= os.O_EXCL
} }


@ -43,7 +43,8 @@ func Create(name string, data []byte) {
} }
// Convert data from WAL/2 to rollback journal. // Convert data from WAL/2 to rollback journal.
if len(data) >= 20 && (data[18] == 2 && data[19] == 2 || if len(data) >= 20 && (false ||
data[18] == 2 && data[19] == 2 ||
data[18] == 3 && data[19] == 3) { data[18] == 3 && data[19] == 3) {
data[18] = 1 data[18] = 1
data[19] = 1 data[19] = 1


@ -7,6 +7,8 @@
"os" "os"
) )
const _O_NOFOLLOW = 0
func osAccess(path string, flags AccessFlag) error { func osAccess(path string, flags AccessFlag) error {
fi, err := os.Stat(path) fi, err := os.Stat(path)
if err != nil { if err != nil {
@ -34,3 +36,12 @@ func osAccess(path string, flags AccessFlag) error {
} }
return nil return nil
} }
func osSetMode(file *os.File, modeof string) error {
fi, err := os.Stat(modeof)
if err != nil {
return err
}
file.Chmod(fi.Mode())
return nil
}


@ -1,14 +0,0 @@
//go:build !unix || sqlite3_nosys
package vfs
import "os"
func osSetMode(file *os.File, modeof string) error {
fi, err := os.Stat(modeof)
if err != nil {
return err
}
file.Chmod(fi.Mode())
return nil
}


@ -9,6 +9,8 @@
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
const _O_NOFOLLOW = unix.O_NOFOLLOW
func osAccess(path string, flags AccessFlag) error { func osAccess(path string, flags AccessFlag) error {
var access uint32 // unix.F_OK var access uint32 // unix.F_OK
switch flags { switch flags {


@ -57,9 +57,12 @@ func CreateModule[T VTab](db *Conn, name string, create, connect VTabConstructor
flags |= VTAB_SHADOWTABS flags |= VTAB_SHADOWTABS
} }
var modulePtr uint32
defer db.arena.mark()() defer db.arena.mark()()
namePtr := db.arena.string(name) namePtr := db.arena.string(name)
modulePtr := util.AddHandle(db.ctx, module[T]{create, connect}) if connect != nil {
modulePtr = util.AddHandle(db.ctx, module[T]{create, connect})
}
r := db.call("sqlite3_create_module_go", uint64(db.handle), r := db.call("sqlite3_create_module_go", uint64(db.handle),
uint64(namePtr), uint64(flags), uint64(modulePtr)) uint64(namePtr), uint64(flags), uint64(modulePtr))
return db.error(r) return db.error(r)
@ -352,8 +355,9 @@ func (idx *IndexInfo) load() {
idx.OrderBy = make([]IndexOrderBy, util.ReadUint32(mod, ptr+8)) idx.OrderBy = make([]IndexOrderBy, util.ReadUint32(mod, ptr+8))
constraintPtr := util.ReadUint32(mod, ptr+4) constraintPtr := util.ReadUint32(mod, ptr+4)
constraint := idx.Constraint
for i := range idx.Constraint { for i := range idx.Constraint {
idx.Constraint[i] = IndexConstraint{ constraint[i] = IndexConstraint{
Column: int(int32(util.ReadUint32(mod, constraintPtr+0))), Column: int(int32(util.ReadUint32(mod, constraintPtr+0))),
Op: IndexConstraintOp(util.ReadUint8(mod, constraintPtr+4)), Op: IndexConstraintOp(util.ReadUint8(mod, constraintPtr+4)),
Usable: util.ReadUint8(mod, constraintPtr+5) != 0, Usable: util.ReadUint8(mod, constraintPtr+5) != 0,
@ -362,8 +366,9 @@ func (idx *IndexInfo) load() {
} }
orderByPtr := util.ReadUint32(mod, ptr+12) orderByPtr := util.ReadUint32(mod, ptr+12)
for i := range idx.OrderBy { orderBy := idx.OrderBy
idx.OrderBy[i] = IndexOrderBy{ for i := range orderBy {
orderBy[i] = IndexOrderBy{
Column: int(int32(util.ReadUint32(mod, orderByPtr+0))), Column: int(int32(util.ReadUint32(mod, orderByPtr+0))),
Desc: util.ReadUint8(mod, orderByPtr+4) != 0, Desc: util.ReadUint8(mod, orderByPtr+4) != 0,
} }
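
Copying idx.Constraint and idx.OrderBy into locals before the loops mirrors the Stmt.Close change above, presumably so the compiler does not have to reload the field through the pointer on every iteration and can hoist the bounds checks. A minimal sketch of the pattern (types here are illustrative):

package main

import "fmt"

type info struct{ constraints []int }

// fill iterates via a local copy of the slice header; writes still go
// through the shared backing array, so the struct sees the results.
func fill(idx *info) {
	constraints := idx.constraints
	for i := range constraints {
		constraints[i] = i
	}
}

func main() {
	idx := &info{constraints: make([]int, 4)}
	fill(idx)
	fmt.Println(idx.constraints) // [0 1 2 3]
}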


@ -35,6 +35,8 @@ type LinearMemory interface {
// Notes: // Notes:
// - To back a shared memory, Reallocate can't change the address of the // - To back a shared memory, Reallocate can't change the address of the
// backing []byte (only its length/capacity may change). // backing []byte (only its length/capacity may change).
// - Reallocate may return nil if fails to grow the LinearMemory. This
// condition may or may not be handled gracefully by the Wasm module.
Reallocate(size uint64) []byte Reallocate(size uint64) []byte
// Free the backing memory buffer. // Free the backing memory buffer.
Free() Free()


@ -1,6 +1,9 @@
package descriptor package descriptor
import "math/bits" import (
"math/bits"
"slices"
)
// Table is a data structure mapping 32 bit descriptor to items. // Table is a data structure mapping 32 bit descriptor to items.
// //
@ -37,23 +40,13 @@ func (t *Table[Key, Item]) Len() (n int) {
return n return n
} }
// grow ensures that t has enough room for n items, potentially reallocating the // grow grows the table by n * 64 items.
// internal buffers if their capacity was too small to hold this many items.
func (t *Table[Key, Item]) grow(n int) { func (t *Table[Key, Item]) grow(n int) {
// Round up to a multiple of 64 since this is the smallest increment due to total := len(t.masks) + n
// using 64 bits masks. t.masks = slices.Grow(t.masks, n)[:total]
n = (n*64 + 63) / 64
if n > len(t.masks) { total = len(t.items) + n*64
masks := make([]uint64, n) t.items = slices.Grow(t.items, n*64)[:total]
copy(masks, t.masks)
items := make([]Item, n*64)
copy(items, t.items)
t.masks = masks
t.items = items
}
} }
// Insert inserts the given item to the table, returning the key that it is // Insert inserts the given item to the table, returning the key that it is
@ -78,13 +71,9 @@ func (t *Table[Key, Item]) Insert(item Item) (key Key, ok bool) {
} }
} }
// No free slot found, grow the table and retry.
offset = len(t.masks) offset = len(t.masks)
n := 2 * len(t.masks) t.grow(1)
if n == 0 {
n = 1
}
t.grow(n)
goto insert goto insert
} }
@ -109,10 +98,10 @@ func (t *Table[Key, Item]) InsertAt(item Item, key Key) bool {
if key < 0 { if key < 0 {
return false return false
} }
if diff := int(key) - t.Len(); diff > 0 { index := uint(key) / 64
if diff := int(index) - len(t.masks) + 1; diff > 0 {
t.grow(diff) t.grow(diff)
} }
index := uint(key) / 64
shift := uint(key) % 64 shift := uint(key) % 64
t.masks[index] |= 1 << shift t.masks[index] |= 1 << shift
t.items[key] = item t.items[key] = item
@ -124,7 +113,8 @@ func (t *Table[Key, Item]) Delete(key Key) {
if key < 0 { // invalid key if key < 0 { // invalid key
return return
} }
if index, shift := key/64, key%64; int(index) < len(t.masks) { if index := uint(key) / 64; int(index) < len(t.masks) {
shift := uint(key) % 64
mask := t.masks[index] mask := t.masks[index]
if (mask & (1 << shift)) != 0 { if (mask & (1 << shift)) != 0 {
var zero Item var zero Item
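
grow is rewritten on top of slices.Grow (Go 1.21+), and its argument changes meaning: n is now a count of 64-bit mask words, so the table always expands in whole multiples of 64 items, and Insert grows by exactly one word instead of doubling. A small illustration of the grow-then-reslice step used here:

package main

import (
	"fmt"
	"slices"
)

func main() {
	masks := make([]uint64, 1)

	// Extend capacity for n more words, then reslice to the new length;
	// the freshly exposed elements are zero because Grow reallocates here.
	n := 2
	total := len(masks) + n
	masks = slices.Grow(masks, n)[:total]

	fmt.Println(len(masks), masks) // 3 [0 0 0]
}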


@ -487,7 +487,7 @@ func (e *engine) setLabelAddress(op *uint64, label label, labelAddressResolution
} }
// ResolveImportedFunction implements wasm.ModuleEngine. // ResolveImportedFunction implements wasm.ModuleEngine.
func (e *moduleEngine) ResolveImportedFunction(index, indexInImportedModule wasm.Index, importedModuleEngine wasm.ModuleEngine) { func (e *moduleEngine) ResolveImportedFunction(index, descFunc, indexInImportedModule wasm.Index, importedModuleEngine wasm.ModuleEngine) {
imported := importedModuleEngine.(*moduleEngine) imported := importedModuleEngine.(*moduleEngine)
e.functions[index] = imported.functions[indexInImportedModule] e.functions[index] = imported.functions[indexInImportedModule]
} }


@ -237,7 +237,7 @@ func (m *moduleEngine) putLocalMemory() {
} }
// ResolveImportedFunction implements wasm.ModuleEngine. // ResolveImportedFunction implements wasm.ModuleEngine.
func (m *moduleEngine) ResolveImportedFunction(index, indexInImportedModule wasm.Index, importedModuleEngine wasm.ModuleEngine) { func (m *moduleEngine) ResolveImportedFunction(index, descFunc, indexInImportedModule wasm.Index, importedModuleEngine wasm.ModuleEngine) {
executableOffset, moduleCtxOffset, typeIDOffset := m.parent.offsets.ImportedFunctionOffset(index) executableOffset, moduleCtxOffset, typeIDOffset := m.parent.offsets.ImportedFunctionOffset(index)
importedME := importedModuleEngine.(*moduleEngine) importedME := importedModuleEngine.(*moduleEngine)
@ -245,12 +245,12 @@ func (m *moduleEngine) ResolveImportedFunction(index, indexInImportedModule wasm
indexInImportedModule -= wasm.Index(len(importedME.importedFunctions)) indexInImportedModule -= wasm.Index(len(importedME.importedFunctions))
} else { } else {
imported := &importedME.importedFunctions[indexInImportedModule] imported := &importedME.importedFunctions[indexInImportedModule]
m.ResolveImportedFunction(index, imported.indexInModule, imported.me) m.ResolveImportedFunction(index, descFunc, imported.indexInModule, imported.me)
return // Recursively resolve the imported function. return // Recursively resolve the imported function.
} }
offset := importedME.parent.functionOffsets[indexInImportedModule] offset := importedME.parent.functionOffsets[indexInImportedModule]
typeID := getTypeIDOf(indexInImportedModule, importedME.module) typeID := m.module.TypeIDs[descFunc]
executable := &importedME.parent.executable[offset] executable := &importedME.parent.executable[offset]
// Write functionInstance. // Write functionInstance.
binary.LittleEndian.PutUint64(m.opaque[executableOffset:], uint64(uintptr(unsafe.Pointer(executable)))) binary.LittleEndian.PutUint64(m.opaque[executableOffset:], uint64(uintptr(unsafe.Pointer(executable))))
@ -261,28 +261,6 @@ func (m *moduleEngine) ResolveImportedFunction(index, indexInImportedModule wasm
m.importedFunctions[index] = importedFunction{me: importedME, indexInModule: indexInImportedModule} m.importedFunctions[index] = importedFunction{me: importedME, indexInModule: indexInImportedModule}
} }
func getTypeIDOf(funcIndex wasm.Index, m *wasm.ModuleInstance) wasm.FunctionTypeID {
source := m.Source
var typeIndex wasm.Index
if funcIndex >= source.ImportFunctionCount {
funcIndex -= source.ImportFunctionCount
typeIndex = source.FunctionSection[funcIndex]
} else {
var cnt wasm.Index
for i := range source.ImportSection {
if source.ImportSection[i].Type == wasm.ExternTypeFunc {
if cnt == funcIndex {
typeIndex = source.ImportSection[i].DescFunc
break
}
cnt++
}
}
}
return m.TypeIDs[typeIndex]
}
// ResolveImportedMemory implements wasm.ModuleEngine. // ResolveImportedMemory implements wasm.ModuleEngine.
func (m *moduleEngine) ResolveImportedMemory(importedModuleEngine wasm.ModuleEngine) { func (m *moduleEngine) ResolveImportedMemory(importedModuleEngine wasm.ModuleEngine) {
importedME := importedModuleEngine.(*moduleEngine) importedME := importedModuleEngine.(*moduleEngine)


@ -5,6 +5,7 @@
import ( import (
"io/fs" "io/fs"
"os" "os"
"path"
experimentalsys "github.com/tetratelabs/wazero/experimental/sys" experimentalsys "github.com/tetratelabs/wazero/experimental/sys"
) )
@ -34,6 +35,11 @@ func (d *dirFS) Chmod(path string, perm fs.FileMode) experimentalsys.Errno {
// Symlink implements the same method as documented on sys.FS // Symlink implements the same method as documented on sys.FS
func (d *dirFS) Symlink(oldName, link string) experimentalsys.Errno { func (d *dirFS) Symlink(oldName, link string) experimentalsys.Errno {
// Creating a symlink with an absolute path string fails with a "not permitted" error.
// https://github.com/WebAssembly/wasi-filesystem/blob/v0.2.0/path-resolution.md#symlinks
if path.IsAbs(oldName) {
return experimentalsys.EPERM
}
// Note: do not resolve `oldName` relative to this dirFS. The link result is always resolved // Note: do not resolve `oldName` relative to this dirFS. The link result is always resolved
// when dereference the `link` on its usage (e.g. readlink, read, etc). // when dereference the `link` on its usage (e.g. readlink, read, etc).
// https://github.com/bytecodealliance/cap-std/blob/v1.0.4/cap-std/src/fs/dir.rs#L404-L409 // https://github.com/bytecodealliance/cap-std/blob/v1.0.4/cap-std/src/fs/dir.rs#L404-L409


@ -269,7 +269,7 @@ func (f *fsFile) Readdir(n int) (dirents []experimentalsys.Dirent, errno experim
if f.reopenDir { // re-open the directory if needed. if f.reopenDir { // re-open the directory if needed.
f.reopenDir = false f.reopenDir = false
if errno = adjustReaddirErr(f, f.closed, f.reopen()); errno != 0 { if errno = adjustReaddirErr(f, f.closed, f.rewindDir()); errno != 0 {
return return
} }
} }
@ -418,19 +418,25 @@ func seek(s io.Seeker, offset int64, whence int) (int64, experimentalsys.Errno)
return newOffset, experimentalsys.UnwrapOSError(err) return newOffset, experimentalsys.UnwrapOSError(err)
} }
// reopenFile allows re-opening a file for reasons such as applying flags or func (f *fsFile) rewindDir() experimentalsys.Errno {
// directory iteration. // Reopen the directory to rewind it.
type reopenFile func() experimentalsys.Errno file, err := f.fs.Open(f.name)
if err != nil {
// compile-time check to ensure fsFile.reopen implements reopenFile. return experimentalsys.UnwrapOSError(err)
var _ reopenFile = (*fsFile)(nil).reopen }
fi, err := file.Stat()
// reopen implements the same method as documented on reopenFile. if err != nil {
func (f *fsFile) reopen() experimentalsys.Errno { return experimentalsys.UnwrapOSError(err)
_ = f.close() }
var err error // Can't check if it's still the same file,
f.file, err = f.fs.Open(f.name) // but is it still a directory, at least?
return experimentalsys.UnwrapOSError(err) if !fi.IsDir() {
return experimentalsys.ENOTDIR
}
// Only update f on success.
_ = f.file.Close()
f.file = file
return 0
} }
// readdirFile allows masking the `Readdir` function on os.File. // readdirFile allows masking the `Readdir` function on os.File.


@ -83,21 +83,12 @@ func (f *osFile) SetAppend(enable bool) (errno experimentalsys.Errno) {
f.flag &= ^experimentalsys.O_APPEND f.flag &= ^experimentalsys.O_APPEND
} }
// Clear any create or trunc flag, as we are re-opening, not re-creating. // appendMode cannot be changed later, so we have to re-open the file
f.flag &= ^(experimentalsys.O_CREAT | experimentalsys.O_TRUNC) // https://github.com/golang/go/blob/go1.23/src/os/file_unix.go#L60
// appendMode (bool) cannot be changed later, so we have to re-open the
// file. https://github.com/golang/go/blob/go1.20/src/os/file_unix.go#L60
return fileError(f, f.closed, f.reopen()) return fileError(f, f.closed, f.reopen())
} }
// compile-time check to ensure osFile.reopen implements reopenFile.
var _ reopenFile = (*osFile)(nil).reopen
func (f *osFile) reopen() (errno experimentalsys.Errno) { func (f *osFile) reopen() (errno experimentalsys.Errno) {
// Clear any create flag, as we are re-opening, not re-creating.
f.flag &= ^experimentalsys.O_CREAT
var ( var (
isDir bool isDir bool
offset int64 offset int64
@ -116,22 +107,47 @@ func (f *osFile) reopen() (errno experimentalsys.Errno) {
} }
} }
_ = f.close() // Clear any create or trunc flag, as we are re-opening, not re-creating.
f.file, errno = OpenFile(f.path, f.flag, f.perm) flag := f.flag &^ (experimentalsys.O_CREAT | experimentalsys.O_TRUNC)
file, errno := OpenFile(f.path, flag, f.perm)
if errno != 0 {
return errno
}
errno = f.checkSameFile(file)
if errno != 0 { if errno != 0 {
return errno return errno
} }
if !isDir { if !isDir {
_, err = f.file.Seek(offset, io.SeekStart) _, err = file.Seek(offset, io.SeekStart)
if err != nil { if err != nil {
_ = file.Close()
return experimentalsys.UnwrapOSError(err) return experimentalsys.UnwrapOSError(err)
} }
} }
// Only update f on success.
_ = f.file.Close()
f.file = file
f.fd = file.Fd()
return 0 return 0
} }
func (f *osFile) checkSameFile(osf *os.File) experimentalsys.Errno {
fi1, err := f.file.Stat()
if err != nil {
return experimentalsys.UnwrapOSError(err)
}
fi2, err := osf.Stat()
if err != nil {
return experimentalsys.UnwrapOSError(err)
}
if os.SameFile(fi1, fi2) {
return 0
}
return experimentalsys.ENOENT
}
// IsNonblock implements the same method as documented on fsapi.File // IsNonblock implements the same method as documented on fsapi.File
func (f *osFile) IsNonblock() bool { func (f *osFile) IsNonblock() bool {
return isNonblock(f) return isNonblock(f)
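
reopen now opens the replacement descriptor first, uses os.SameFile to check that the path still refers to the file it already has open, and only swaps the handle in on success, so a failed or mismatched reopen no longer loses the original file. The core of that check in isolation:

package main

import (
	"fmt"
	"os"
)

// sameFile reports whether an already-open file and a freshly opened
// handle for the same path still refer to the same underlying file.
func sameFile(old, reopened *os.File) (bool, error) {
	fi1, err := old.Stat()
	if err != nil {
		return false, err
	}
	fi2, err := reopened.Stat()
	if err != nil {
		return false, err
	}
	return os.SameFile(fi1, fi2), nil
}

func main() {
	f, err := os.CreateTemp("", "reopen-*")
	if err != nil {
		panic(err)
	}
	defer os.Remove(f.Name())
	defer f.Close()

	g, err := os.Open(f.Name())
	if err != nil {
		panic(err)
	}
	defer g.Close()

	fmt.Println(sameFile(f, g)) // true <nil>
}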


@ -44,9 +44,10 @@ type ModuleEngine interface {
// ResolveImportedFunction is used to add imported functions needed to make this ModuleEngine fully functional. // ResolveImportedFunction is used to add imported functions needed to make this ModuleEngine fully functional.
// - `index` is the function Index of this imported function. // - `index` is the function Index of this imported function.
// - `descFunc` is the type Index in Module.TypeSection of this imported function. It corresponds to Import.DescFunc.
// - `indexInImportedModule` is the function Index of the imported function in the imported module. // - `indexInImportedModule` is the function Index of the imported function in the imported module.
// - `importedModuleEngine` is the ModuleEngine for the imported ModuleInstance. // - `importedModuleEngine` is the ModuleEngine for the imported ModuleInstance.
ResolveImportedFunction(index, indexInImportedModule Index, importedModuleEngine ModuleEngine) ResolveImportedFunction(index, descFunc, indexInImportedModule Index, importedModuleEngine ModuleEngine)
// ResolveImportedMemory is called when this module imports a memory from another module. // ResolveImportedMemory is called when this module imports a memory from another module.
ResolveImportedMemory(importedModuleEngine ModuleEngine) ResolveImportedMemory(importedModuleEngine ModuleEngine)


@ -77,6 +77,7 @@ func NewMemoryInstance(memSec *Memory, allocator experimental.MemoryAllocator, m
if allocator != nil { if allocator != nil {
expBuffer = allocator.Allocate(capBytes, maxBytes) expBuffer = allocator.Allocate(capBytes, maxBytes)
buffer = expBuffer.Reallocate(minBytes) buffer = expBuffer.Reallocate(minBytes)
_ = buffer[:minBytes] // Bounds check that the minimum was allocated.
} else if memSec.IsShared { } else if memSec.IsShared {
// Shared memory needs a fixed buffer, so allocate with the maximum size. // Shared memory needs a fixed buffer, so allocate with the maximum size.
// //
@ -238,12 +239,15 @@ func (m *MemoryInstance) Grow(delta uint32) (result uint32, ok bool) {
return currentPages, true return currentPages, true
} }
// If exceeds the max of memory size, we push -1 according to the spec.
newPages := currentPages + delta newPages := currentPages + delta
if newPages > m.Max || int32(delta) < 0 { if newPages > m.Max || int32(delta) < 0 {
return 0, false return 0, false
} else if m.expBuffer != nil { } else if m.expBuffer != nil {
buffer := m.expBuffer.Reallocate(MemoryPagesToBytesNum(newPages)) buffer := m.expBuffer.Reallocate(MemoryPagesToBytesNum(newPages))
if buffer == nil {
// Allocator failed to grow.
return 0, false
}
if m.Shared { if m.Shared {
if unsafe.SliceData(buffer) != unsafe.SliceData(m.Buffer) { if unsafe.SliceData(buffer) != unsafe.SliceData(m.Buffer) {
panic("shared memory cannot move, this is a bug in the memory allocator") panic("shared memory cannot move, this is a bug in the memory allocator")


@ -446,7 +446,7 @@ func (m *ModuleInstance) resolveImports(ctx context.Context, module *Module) (er
return return
} }
m.Engine.ResolveImportedFunction(i.IndexPerType, imported.Index, importedModule.Engine) m.Engine.ResolveImportedFunction(i.IndexPerType, i.DescFunc, imported.Index, importedModule.Engine)
case ExternTypeTable: case ExternTypeTable:
expected := i.DescTable expected := i.DescTable
importedTable := importedModule.Tables[imported.Index] importedTable := importedModule.Tables[imported.Index]


@ -29,9 +29,7 @@
// # Notes // # Notes
// //
// - This is used for WebAssembly ABI emulating the POSIX `stat` system call. // - This is used for WebAssembly ABI emulating the POSIX `stat` system call.
// Fields included are required for WebAssembly ABI including wasip1 // See https://pubs.opengroup.org/onlinepubs/9699919799/functions/stat.html
// (a.k.a. wasix) and wasi-filesystem (a.k.a. wasip2). See
// https://pubs.opengroup.org/onlinepubs/9699919799/functions/stat.html
// - Fields here are required for WebAssembly ABI including wasip1 // - Fields here are required for WebAssembly ABI including wasip1
// (a.k.a. wasix) and wasi-filesystem (a.k.a. wasip2). // (a.k.a. wasix) and wasi-filesystem (a.k.a. wasip2).
// - This isn't the same as syscall.Stat_t because wazero supports Windows, // - This isn't the same as syscall.Stat_t because wazero supports Windows,


@ -85,9 +85,9 @@ func newCShake(N, S []byte, rate, outputLen int, dsbyte byte) ShakeHash {
// leftEncode returns max 9 bytes // leftEncode returns max 9 bytes
c.initBlock = make([]byte, 0, 9*2+len(N)+len(S)) c.initBlock = make([]byte, 0, 9*2+len(N)+len(S))
c.initBlock = append(c.initBlock, leftEncode(uint64(len(N)*8))...) c.initBlock = append(c.initBlock, leftEncode(uint64(len(N))*8)...)
c.initBlock = append(c.initBlock, N...) c.initBlock = append(c.initBlock, N...)
c.initBlock = append(c.initBlock, leftEncode(uint64(len(S)*8))...) c.initBlock = append(c.initBlock, leftEncode(uint64(len(S))*8)...)
c.initBlock = append(c.initBlock, S...) c.initBlock = append(c.initBlock, S...)
c.Write(bytepad(c.initBlock, c.rate)) c.Write(bytepad(c.initBlock, c.rate))
return &c return &c
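
The cSHAKE fix moves the uint64 conversion before the multiplication: len(N)*8 is evaluated in int first, which can wrap on 32-bit platforms once the function-name or customization string reaches 256 MiB, while converting first keeps the bit-length exact. A small demonstration of the difference, using int32 as a stand-in for a 32-bit int:

package main

import "fmt"

func main() {
	n := int32(500 << 20)      // stand-in for len(N) on a 500 MiB input
	fmt.Println(uint64(n * 8)) // multiply first: wraps, then converts to a huge bogus value
	fmt.Println(uint64(n) * 8) // convert first: 4194304000, the real bit length
}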


@ -510,8 +510,8 @@ func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, err
if err := s.transport.writePacket(Marshal(discMsg)); err != nil { if err := s.transport.writePacket(Marshal(discMsg)); err != nil {
return nil, err return nil, err
} }
authErrs = append(authErrs, discMsg)
return nil, discMsg return nil, &ServerAuthError{Errors: authErrs}
} }
var userAuthReq userAuthRequestMsg var userAuthReq userAuthRequestMsg

vendor/golang.org/x/net/http2/config.go (new file, generated, vendored, +122 lines)

@ -0,0 +1,122 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http2
import (
"math"
"net/http"
"time"
)
// http2Config is a package-internal version of net/http.HTTP2Config.
//
// http.HTTP2Config was added in Go 1.24.
// When running with a version of net/http that includes HTTP2Config,
// we merge the configuration with the fields in Transport or Server
// to produce an http2Config.
//
// Zero valued fields in http2Config are interpreted as in the
// net/http.HTTPConfig documentation.
//
// Precedence order for reconciling configurations is:
//
// - Use the net/http.{Server,Transport}.HTTP2Config value, when non-zero.
// - Otherwise use the http2.{Server.Transport} value.
// - If the resulting value is zero or out of range, use a default.
type http2Config struct {
MaxConcurrentStreams uint32
MaxDecoderHeaderTableSize uint32
MaxEncoderHeaderTableSize uint32
MaxReadFrameSize uint32
MaxUploadBufferPerConnection int32
MaxUploadBufferPerStream int32
SendPingTimeout time.Duration
PingTimeout time.Duration
WriteByteTimeout time.Duration
PermitProhibitedCipherSuites bool
CountError func(errType string)
}
// configFromServer merges configuration settings from
// net/http.Server.HTTP2Config and http2.Server.
func configFromServer(h1 *http.Server, h2 *Server) http2Config {
conf := http2Config{
MaxConcurrentStreams: h2.MaxConcurrentStreams,
MaxEncoderHeaderTableSize: h2.MaxEncoderHeaderTableSize,
MaxDecoderHeaderTableSize: h2.MaxDecoderHeaderTableSize,
MaxReadFrameSize: h2.MaxReadFrameSize,
MaxUploadBufferPerConnection: h2.MaxUploadBufferPerConnection,
MaxUploadBufferPerStream: h2.MaxUploadBufferPerStream,
SendPingTimeout: h2.ReadIdleTimeout,
PingTimeout: h2.PingTimeout,
WriteByteTimeout: h2.WriteByteTimeout,
PermitProhibitedCipherSuites: h2.PermitProhibitedCipherSuites,
CountError: h2.CountError,
}
fillNetHTTPServerConfig(&conf, h1)
setConfigDefaults(&conf, true)
return conf
}
// configFromServer merges configuration settings from h2 and h2.t1.HTTP2
// (the net/http Transport).
func configFromTransport(h2 *Transport) http2Config {
conf := http2Config{
MaxEncoderHeaderTableSize: h2.MaxEncoderHeaderTableSize,
MaxDecoderHeaderTableSize: h2.MaxDecoderHeaderTableSize,
MaxReadFrameSize: h2.MaxReadFrameSize,
SendPingTimeout: h2.ReadIdleTimeout,
PingTimeout: h2.PingTimeout,
WriteByteTimeout: h2.WriteByteTimeout,
}
// Unlike most config fields, where out-of-range values revert to the default,
// Transport.MaxReadFrameSize clips.
if conf.MaxReadFrameSize < minMaxFrameSize {
conf.MaxReadFrameSize = minMaxFrameSize
} else if conf.MaxReadFrameSize > maxFrameSize {
conf.MaxReadFrameSize = maxFrameSize
}
if h2.t1 != nil {
fillNetHTTPTransportConfig(&conf, h2.t1)
}
setConfigDefaults(&conf, false)
return conf
}
func setDefault[T ~int | ~int32 | ~uint32 | ~int64](v *T, minval, maxval, defval T) {
if *v < minval || *v > maxval {
*v = defval
}
}
func setConfigDefaults(conf *http2Config, server bool) {
setDefault(&conf.MaxConcurrentStreams, 1, math.MaxUint32, defaultMaxStreams)
setDefault(&conf.MaxEncoderHeaderTableSize, 1, math.MaxUint32, initialHeaderTableSize)
setDefault(&conf.MaxDecoderHeaderTableSize, 1, math.MaxUint32, initialHeaderTableSize)
if server {
setDefault(&conf.MaxUploadBufferPerConnection, initialWindowSize, math.MaxInt32, 1<<20)
} else {
setDefault(&conf.MaxUploadBufferPerConnection, initialWindowSize, math.MaxInt32, transportDefaultConnFlow)
}
if server {
setDefault(&conf.MaxUploadBufferPerStream, 1, math.MaxInt32, 1<<20)
} else {
setDefault(&conf.MaxUploadBufferPerStream, 1, math.MaxInt32, transportDefaultStreamFlow)
}
setDefault(&conf.MaxReadFrameSize, minMaxFrameSize, maxFrameSize, defaultMaxReadFrameSize)
setDefault(&conf.PingTimeout, 1, math.MaxInt64, 15*time.Second)
}
// adjustHTTP1MaxHeaderSize converts a limit in bytes on the size of an HTTP/1 header
// to an HTTP/2 MAX_HEADER_LIST_SIZE value.
func adjustHTTP1MaxHeaderSize(n int64) int64 {
// http2's count is in a slightly different unit and includes 32 bytes per pair.
// So, take the net/http.Server value and pad it up a bit, assuming 10 headers.
const perFieldOverhead = 32 // per http2 spec
const typicalHeaders = 10 // conservative
return n + typicalHeaders*perFieldOverhead
}

vendor/golang.org/x/net/http2/config_go124.go (new file, generated, vendored, +61 lines)

@ -0,0 +1,61 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.24
package http2
import "net/http"
// fillNetHTTPServerConfig sets fields in conf from srv.HTTP2.
func fillNetHTTPServerConfig(conf *http2Config, srv *http.Server) {
fillNetHTTPConfig(conf, srv.HTTP2)
}
// fillNetHTTPServerConfig sets fields in conf from tr.HTTP2.
func fillNetHTTPTransportConfig(conf *http2Config, tr *http.Transport) {
fillNetHTTPConfig(conf, tr.HTTP2)
}
func fillNetHTTPConfig(conf *http2Config, h2 *http.HTTP2Config) {
if h2 == nil {
return
}
if h2.MaxConcurrentStreams != 0 {
conf.MaxConcurrentStreams = uint32(h2.MaxConcurrentStreams)
}
if h2.MaxEncoderHeaderTableSize != 0 {
conf.MaxEncoderHeaderTableSize = uint32(h2.MaxEncoderHeaderTableSize)
}
if h2.MaxDecoderHeaderTableSize != 0 {
conf.MaxDecoderHeaderTableSize = uint32(h2.MaxDecoderHeaderTableSize)
}
if h2.MaxConcurrentStreams != 0 {
conf.MaxConcurrentStreams = uint32(h2.MaxConcurrentStreams)
}
if h2.MaxReadFrameSize != 0 {
conf.MaxReadFrameSize = uint32(h2.MaxReadFrameSize)
}
if h2.MaxReceiveBufferPerConnection != 0 {
conf.MaxUploadBufferPerConnection = int32(h2.MaxReceiveBufferPerConnection)
}
if h2.MaxReceiveBufferPerStream != 0 {
conf.MaxUploadBufferPerStream = int32(h2.MaxReceiveBufferPerStream)
}
if h2.SendPingTimeout != 0 {
conf.SendPingTimeout = h2.SendPingTimeout
}
if h2.PingTimeout != 0 {
conf.PingTimeout = h2.PingTimeout
}
if h2.WriteByteTimeout != 0 {
conf.WriteByteTimeout = h2.WriteByteTimeout
}
if h2.PermitProhibitedCipherSuites {
conf.PermitProhibitedCipherSuites = true
}
if h2.CountError != nil {
conf.CountError = h2.CountError
}
}

vendor/golang.org/x/net/http2/config_pre_go124.go (new file, generated, vendored, +16 lines)

@ -0,0 +1,16 @@
// Copyright 2024 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.24
package http2
import "net/http"
// Pre-Go 1.24 fallback.
// The Server.HTTP2 and Transport.HTTP2 config fields were added in Go 1.24.
func fillNetHTTPServerConfig(conf *http2Config, srv *http.Server) {}
func fillNetHTTPTransportConfig(conf *http2Config, tr *http.Transport) {}


@ -19,8 +19,9 @@
"bufio" "bufio"
"context" "context"
"crypto/tls" "crypto/tls"
"errors"
"fmt" "fmt"
"io" "net"
"net/http" "net/http"
"os" "os"
"sort" "sort"
@ -237,13 +238,19 @@ func (cw closeWaiter) Wait() {
// Its buffered writer is lazily allocated as needed, to minimize // Its buffered writer is lazily allocated as needed, to minimize
// idle memory usage with many connections. // idle memory usage with many connections.
type bufferedWriter struct { type bufferedWriter struct {
_ incomparable _ incomparable
w io.Writer // immutable group synctestGroupInterface // immutable
bw *bufio.Writer // non-nil when data is buffered conn net.Conn // immutable
bw *bufio.Writer // non-nil when data is buffered
byteTimeout time.Duration // immutable, WriteByteTimeout
} }
func newBufferedWriter(w io.Writer) *bufferedWriter { func newBufferedWriter(group synctestGroupInterface, conn net.Conn, timeout time.Duration) *bufferedWriter {
return &bufferedWriter{w: w} return &bufferedWriter{
group: group,
conn: conn,
byteTimeout: timeout,
}
} }
// bufWriterPoolBufferSize is the size of bufio.Writer's // bufWriterPoolBufferSize is the size of bufio.Writer's
@ -270,7 +277,7 @@ func (w *bufferedWriter) Available() int {
func (w *bufferedWriter) Write(p []byte) (n int, err error) { func (w *bufferedWriter) Write(p []byte) (n int, err error) {
if w.bw == nil { if w.bw == nil {
bw := bufWriterPool.Get().(*bufio.Writer) bw := bufWriterPool.Get().(*bufio.Writer)
bw.Reset(w.w) bw.Reset((*bufferedWriterTimeoutWriter)(w))
w.bw = bw w.bw = bw
} }
return w.bw.Write(p) return w.bw.Write(p)
@ -288,6 +295,38 @@ func (w *bufferedWriter) Flush() error {
return err return err
} }
type bufferedWriterTimeoutWriter bufferedWriter
func (w *bufferedWriterTimeoutWriter) Write(p []byte) (n int, err error) {
return writeWithByteTimeout(w.group, w.conn, w.byteTimeout, p)
}
// writeWithByteTimeout writes to conn.
// If more than timeout passes without any bytes being written to the connection,
// the write fails.
func writeWithByteTimeout(group synctestGroupInterface, conn net.Conn, timeout time.Duration, p []byte) (n int, err error) {
if timeout <= 0 {
return conn.Write(p)
}
for {
var now time.Time
if group == nil {
now = time.Now()
} else {
now = group.Now()
}
conn.SetWriteDeadline(now.Add(timeout))
nn, err := conn.Write(p[n:])
n += nn
if n == len(p) || nn == 0 || !errors.Is(err, os.ErrDeadlineExceeded) {
// Either we finished the write, made no progress, or hit the deadline.
// Whichever it is, we're done now.
conn.SetWriteDeadline(time.Time{})
return n, err
}
}
}
func mustUint31(v int32) uint32 { func mustUint31(v int32) uint32 {
if v < 0 || v > 2147483647 { if v < 0 || v > 2147483647 {
panic("out of range") panic("out of range")


@ -29,6 +29,7 @@
"bufio" "bufio"
"bytes" "bytes"
"context" "context"
"crypto/rand"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
@ -52,10 +53,14 @@
) )
const ( const (
prefaceTimeout = 10 * time.Second prefaceTimeout = 10 * time.Second
firstSettingsTimeout = 2 * time.Second // should be in-flight with preface anyway firstSettingsTimeout = 2 * time.Second // should be in-flight with preface anyway
handlerChunkWriteSize = 4 << 10 handlerChunkWriteSize = 4 << 10
defaultMaxStreams = 250 // TODO: make this 100 as the GFE seems to? defaultMaxStreams = 250 // TODO: make this 100 as the GFE seems to?
// maxQueuedControlFrames is the maximum number of control frames like
// SETTINGS, PING and RST_STREAM that will be queued for writing before
// the connection is closed to prevent memory exhaustion attacks.
maxQueuedControlFrames = 10000 maxQueuedControlFrames = 10000
) )
@ -127,6 +132,22 @@ type Server struct {
// If zero or negative, there is no timeout. // If zero or negative, there is no timeout.
IdleTimeout time.Duration IdleTimeout time.Duration
// ReadIdleTimeout is the timeout after which a health check using a ping
// frame will be carried out if no frame is received on the connection.
// If zero, no health check is performed.
ReadIdleTimeout time.Duration
// PingTimeout is the timeout after which the connection will be closed
// if a response to a ping is not received.
// If zero, a default of 15 seconds is used.
PingTimeout time.Duration
// WriteByteTimeout is the timeout after which a connection will be
// closed if no data can be written to it. The timeout begins when data is
// available to write, and is extended whenever any bytes are written.
// If zero or negative, there is no timeout.
WriteByteTimeout time.Duration
// MaxUploadBufferPerConnection is the size of the initial flow // MaxUploadBufferPerConnection is the size of the initial flow
// control window for each connections. The HTTP/2 spec does not // control window for each connections. The HTTP/2 spec does not
// allow this to be smaller than 65535 or larger than 2^32-1. // allow this to be smaller than 65535 or larger than 2^32-1.
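
The new ReadIdleTimeout, PingTimeout, and WriteByteTimeout server fields bring the server side in line with the health-check and write-timeout options the Transport already had. A hedged configuration sketch: http2.ConfigureServer is the existing x/net/http2 entry point, the timeout values are purely illustrative, and cert.pem/key.pem are placeholder paths.

package main

import (
	"log"
	"net/http"
	"time"

	"golang.org/x/net/http2"
)

func main() {
	srv := &http.Server{Addr: ":8443", Handler: http.DefaultServeMux}

	h2 := &http2.Server{
		ReadIdleTimeout:  30 * time.Second, // send a PING after 30s with no frames received
		PingTimeout:      15 * time.Second, // close the connection if the PING goes unanswered
		WriteByteTimeout: 10 * time.Second, // close if a pending write makes no progress
	}
	if err := http2.ConfigureServer(srv, h2); err != nil {
		log.Fatal(err)
	}
	log.Fatal(srv.ListenAndServeTLS("cert.pem", "key.pem"))
}
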
@ -189,57 +210,6 @@ func (s *Server) afterFunc(d time.Duration, f func()) timer {
return timeTimer{time.AfterFunc(d, f)} return timeTimer{time.AfterFunc(d, f)}
} }
func (s *Server) initialConnRecvWindowSize() int32 {
if s.MaxUploadBufferPerConnection >= initialWindowSize {
return s.MaxUploadBufferPerConnection
}
return 1 << 20
}
func (s *Server) initialStreamRecvWindowSize() int32 {
if s.MaxUploadBufferPerStream > 0 {
return s.MaxUploadBufferPerStream
}
return 1 << 20
}
func (s *Server) maxReadFrameSize() uint32 {
if v := s.MaxReadFrameSize; v >= minMaxFrameSize && v <= maxFrameSize {
return v
}
return defaultMaxReadFrameSize
}
func (s *Server) maxConcurrentStreams() uint32 {
if v := s.MaxConcurrentStreams; v > 0 {
return v
}
return defaultMaxStreams
}
func (s *Server) maxDecoderHeaderTableSize() uint32 {
if v := s.MaxDecoderHeaderTableSize; v > 0 {
return v
}
return initialHeaderTableSize
}
func (s *Server) maxEncoderHeaderTableSize() uint32 {
if v := s.MaxEncoderHeaderTableSize; v > 0 {
return v
}
return initialHeaderTableSize
}
// maxQueuedControlFrames is the maximum number of control frames like
// SETTINGS, PING and RST_STREAM that will be queued for writing before
// the connection is closed to prevent memory exhaustion attacks.
func (s *Server) maxQueuedControlFrames() int {
// TODO: if anybody asks, add a Server field, and remember to define the
// behavior of negative values.
return maxQueuedControlFrames
}
type serverInternalState struct { type serverInternalState struct {
mu sync.Mutex mu sync.Mutex
activeConns map[*serverConn]struct{} activeConns map[*serverConn]struct{}
@ -440,13 +410,15 @@ func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverCon
baseCtx, cancel := serverConnBaseContext(c, opts) baseCtx, cancel := serverConnBaseContext(c, opts)
defer cancel() defer cancel()
http1srv := opts.baseConfig()
conf := configFromServer(http1srv, s)
sc := &serverConn{ sc := &serverConn{
srv: s, srv: s,
hs: opts.baseConfig(), hs: http1srv,
conn: c, conn: c,
baseCtx: baseCtx, baseCtx: baseCtx,
remoteAddrStr: c.RemoteAddr().String(), remoteAddrStr: c.RemoteAddr().String(),
bw: newBufferedWriter(c), bw: newBufferedWriter(s.group, c, conf.WriteByteTimeout),
handler: opts.handler(), handler: opts.handler(),
streams: make(map[uint32]*stream), streams: make(map[uint32]*stream),
readFrameCh: make(chan readFrameResult), readFrameCh: make(chan readFrameResult),
@ -456,9 +428,12 @@ func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverCon
bodyReadCh: make(chan bodyReadMsg), // buffering doesn't matter either way bodyReadCh: make(chan bodyReadMsg), // buffering doesn't matter either way
doneServing: make(chan struct{}), doneServing: make(chan struct{}),
clientMaxStreams: math.MaxUint32, // Section 6.5.2: "Initially, there is no limit to this value" clientMaxStreams: math.MaxUint32, // Section 6.5.2: "Initially, there is no limit to this value"
advMaxStreams: s.maxConcurrentStreams(), advMaxStreams: conf.MaxConcurrentStreams,
initialStreamSendWindowSize: initialWindowSize, initialStreamSendWindowSize: initialWindowSize,
initialStreamRecvWindowSize: conf.MaxUploadBufferPerStream,
maxFrameSize: initialMaxFrameSize, maxFrameSize: initialMaxFrameSize,
pingTimeout: conf.PingTimeout,
countErrorFunc: conf.CountError,
serveG: newGoroutineLock(), serveG: newGoroutineLock(),
pushEnabled: true, pushEnabled: true,
sawClientPreface: opts.SawClientPreface, sawClientPreface: opts.SawClientPreface,
@ -491,15 +466,15 @@ func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverCon
sc.flow.add(initialWindowSize) sc.flow.add(initialWindowSize)
sc.inflow.init(initialWindowSize) sc.inflow.init(initialWindowSize)
sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
sc.hpackEncoder.SetMaxDynamicTableSizeLimit(s.maxEncoderHeaderTableSize()) sc.hpackEncoder.SetMaxDynamicTableSizeLimit(conf.MaxEncoderHeaderTableSize)
fr := NewFramer(sc.bw, c) fr := NewFramer(sc.bw, c)
if s.CountError != nil { if conf.CountError != nil {
fr.countError = s.CountError fr.countError = conf.CountError
} }
fr.ReadMetaHeaders = hpack.NewDecoder(s.maxDecoderHeaderTableSize(), nil) fr.ReadMetaHeaders = hpack.NewDecoder(conf.MaxDecoderHeaderTableSize, nil)
fr.MaxHeaderListSize = sc.maxHeaderListSize() fr.MaxHeaderListSize = sc.maxHeaderListSize()
fr.SetMaxReadFrameSize(s.maxReadFrameSize()) fr.SetMaxReadFrameSize(conf.MaxReadFrameSize)
sc.framer = fr sc.framer = fr
if tc, ok := c.(connectionStater); ok { if tc, ok := c.(connectionStater); ok {
@ -532,7 +507,7 @@ func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverCon
// So for now, do nothing here again. // So for now, do nothing here again.
} }
if !s.PermitProhibitedCipherSuites && isBadCipher(sc.tlsState.CipherSuite) { if !conf.PermitProhibitedCipherSuites && isBadCipher(sc.tlsState.CipherSuite) {
// "Endpoints MAY choose to generate a connection error // "Endpoints MAY choose to generate a connection error
// (Section 5.4.1) of type INADEQUATE_SECURITY if one of // (Section 5.4.1) of type INADEQUATE_SECURITY if one of
// the prohibited cipher suites are negotiated." // the prohibited cipher suites are negotiated."
@ -569,7 +544,7 @@ func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverCon
opts.UpgradeRequest = nil opts.UpgradeRequest = nil
} }
sc.serve() sc.serve(conf)
} }
func serverConnBaseContext(c net.Conn, opts *ServeConnOpts) (ctx context.Context, cancel func()) { func serverConnBaseContext(c net.Conn, opts *ServeConnOpts) (ctx context.Context, cancel func()) {
@ -609,6 +584,7 @@ type serverConn struct {
tlsState *tls.ConnectionState // shared by all handlers, like net/http tlsState *tls.ConnectionState // shared by all handlers, like net/http
remoteAddrStr string remoteAddrStr string
writeSched WriteScheduler writeSched WriteScheduler
countErrorFunc func(errType string)
// Everything following is owned by the serve loop; use serveG.check(): // Everything following is owned by the serve loop; use serveG.check():
serveG goroutineLock // used to verify funcs are on serve() serveG goroutineLock // used to verify funcs are on serve()
@ -628,6 +604,7 @@ type serverConn struct {
streams map[uint32]*stream streams map[uint32]*stream
unstartedHandlers []unstartedHandler unstartedHandlers []unstartedHandler
initialStreamSendWindowSize int32 initialStreamSendWindowSize int32
initialStreamRecvWindowSize int32
maxFrameSize int32 maxFrameSize int32
peerMaxHeaderListSize uint32 // zero means unknown (default) peerMaxHeaderListSize uint32 // zero means unknown (default)
canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case
@ -638,9 +615,14 @@ type serverConn struct {
inGoAway bool // we've started to or sent GOAWAY inGoAway bool // we've started to or sent GOAWAY
inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop
needToSendGoAway bool // we need to schedule a GOAWAY frame write needToSendGoAway bool // we need to schedule a GOAWAY frame write
pingSent bool
sentPingData [8]byte
goAwayCode ErrCode goAwayCode ErrCode
shutdownTimer timer // nil until used shutdownTimer timer // nil until used
idleTimer timer // nil if unused idleTimer timer // nil if unused
readIdleTimeout time.Duration
pingTimeout time.Duration
readIdleTimer timer // nil if unused
// Owned by the writeFrameAsync goroutine: // Owned by the writeFrameAsync goroutine:
headerWriteBuf bytes.Buffer headerWriteBuf bytes.Buffer
@ -655,11 +637,7 @@ func (sc *serverConn) maxHeaderListSize() uint32 {
if n <= 0 { if n <= 0 {
n = http.DefaultMaxHeaderBytes n = http.DefaultMaxHeaderBytes
} }
// http2's count is in a slightly different unit and includes 32 bytes per pair. return uint32(adjustHTTP1MaxHeaderSize(int64(n)))
// So, take the net/http.Server value and pad it up a bit, assuming 10 headers.
const perFieldOverhead = 32 // per http2 spec
const typicalHeaders = 10 // conservative
return uint32(n + typicalHeaders*perFieldOverhead)
} }
func (sc *serverConn) curOpenStreams() uint32 { func (sc *serverConn) curOpenStreams() uint32 {
@ -923,7 +901,7 @@ func (sc *serverConn) notePanic() {
} }
} }
func (sc *serverConn) serve() { func (sc *serverConn) serve(conf http2Config) {
sc.serveG.check() sc.serveG.check()
defer sc.notePanic() defer sc.notePanic()
defer sc.conn.Close() defer sc.conn.Close()
@ -937,18 +915,18 @@ func (sc *serverConn) serve() {
sc.writeFrame(FrameWriteRequest{ sc.writeFrame(FrameWriteRequest{
write: writeSettings{ write: writeSettings{
{SettingMaxFrameSize, sc.srv.maxReadFrameSize()}, {SettingMaxFrameSize, conf.MaxReadFrameSize},
{SettingMaxConcurrentStreams, sc.advMaxStreams}, {SettingMaxConcurrentStreams, sc.advMaxStreams},
{SettingMaxHeaderListSize, sc.maxHeaderListSize()}, {SettingMaxHeaderListSize, sc.maxHeaderListSize()},
{SettingHeaderTableSize, sc.srv.maxDecoderHeaderTableSize()}, {SettingHeaderTableSize, conf.MaxDecoderHeaderTableSize},
{SettingInitialWindowSize, uint32(sc.srv.initialStreamRecvWindowSize())}, {SettingInitialWindowSize, uint32(sc.initialStreamRecvWindowSize)},
}, },
}) })
sc.unackedSettings++ sc.unackedSettings++
// Each connection starts with initialWindowSize inflow tokens. // Each connection starts with initialWindowSize inflow tokens.
// If a higher value is configured, we add more tokens. // If a higher value is configured, we add more tokens.
if diff := sc.srv.initialConnRecvWindowSize() - initialWindowSize; diff > 0 { if diff := conf.MaxUploadBufferPerConnection - initialWindowSize; diff > 0 {
sc.sendWindowUpdate(nil, int(diff)) sc.sendWindowUpdate(nil, int(diff))
} }
@@ -968,11 +946,18 @@ func (sc *serverConn) serve() {
         defer sc.idleTimer.Stop()
     }
 
+    if conf.SendPingTimeout > 0 {
+        sc.readIdleTimeout = conf.SendPingTimeout
+        sc.readIdleTimer = sc.srv.afterFunc(conf.SendPingTimeout, sc.onReadIdleTimer)
+        defer sc.readIdleTimer.Stop()
+    }
+
     go sc.readFrames() // closed by defer sc.conn.Close above
 
     settingsTimer := sc.srv.afterFunc(firstSettingsTimeout, sc.onSettingsTimer)
     defer settingsTimer.Stop()
 
+    lastFrameTime := sc.srv.now()
     loopNum := 0
     for {
         loopNum++

@@ -986,6 +971,7 @@ func (sc *serverConn) serve() {
         case res := <-sc.wroteFrameCh:
             sc.wroteFrame(res)
         case res := <-sc.readFrameCh:
+            lastFrameTime = sc.srv.now()
             // Process any written frames before reading new frames from the client since a
             // written frame could have triggered a new stream to be started.
             if sc.writingFrameAsync {
@@ -1017,6 +1003,8 @@ func (sc *serverConn) serve() {
             case idleTimerMsg:
                 sc.vlogf("connection is idle")
                 sc.goAway(ErrCodeNo)
+            case readIdleTimerMsg:
+                sc.handlePingTimer(lastFrameTime)
             case shutdownTimerMsg:
                 sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr())
                 return

@@ -1039,7 +1027,7 @@ func (sc *serverConn) serve() {
         // If the peer is causing us to generate a lot of control frames,
         // but not reading them from us, assume they are trying to make us
         // run out of memory.
-        if sc.queuedControlFrames > sc.srv.maxQueuedControlFrames() {
+        if sc.queuedControlFrames > maxQueuedControlFrames {
             sc.vlogf("http2: too many control frames in send queue, closing connection")
             return
         }
@@ -1055,12 +1043,39 @@ func (sc *serverConn) serve() {
     }
 }
 
+func (sc *serverConn) handlePingTimer(lastFrameReadTime time.Time) {
+    if sc.pingSent {
+        sc.vlogf("timeout waiting for PING response")
+        sc.conn.Close()
+        return
+    }
+
+    pingAt := lastFrameReadTime.Add(sc.readIdleTimeout)
+    now := sc.srv.now()
+    if pingAt.After(now) {
+        // We received frames since arming the ping timer.
+        // Reset it for the next possible timeout.
+        sc.readIdleTimer.Reset(pingAt.Sub(now))
+        return
+    }
+
+    sc.pingSent = true
+    // Ignore crypto/rand.Read errors: It generally can't fail, and worse case if it does
+    // is we send a PING frame containing 0s.
+    _, _ = rand.Read(sc.sentPingData[:])
+    sc.writeFrame(FrameWriteRequest{
+        write: &writePing{data: sc.sentPingData},
+    })
+    sc.readIdleTimer.Reset(sc.pingTimeout)
+}
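The handlePingTimer addition above, together with the PING ack handling added to processPing further below, gives the server a keepalive mechanism: if no frame arrives within readIdleTimeout the server sends a PING with random payload, and if the ack does not come back before the timer fires again the connection is closed. A small self-contained sketch of the same "re-arm or fire" decision, using plain time values instead of the vendored timer plumbing (names are illustrative):

```go
// Sketch of the keepalive decision implemented by handlePingTimer above.
package main

import (
	"fmt"
	"time"
)

// nextPingAction reports whether a PING should be sent now, or how long to
// wait before checking again, given when the last frame was read.
func nextPingAction(lastFrameRead, now time.Time, readIdleTimeout time.Duration) (sendPing bool, reArm time.Duration) {
	pingAt := lastFrameRead.Add(readIdleTimeout)
	if pingAt.After(now) {
		// Frames arrived since the timer was armed: push the deadline out.
		return false, pingAt.Sub(now)
	}
	return true, 0
}

func main() {
	now := time.Now()
	send, wait := nextPingAction(now.Add(-10*time.Second), now, 30*time.Second)
	fmt.Println(send, wait) // false 20s: traffic was recent, so re-arm instead of pinging
}
```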
 type serverMessage int
 
 // Message values sent to serveMsgCh.
 var (
     settingsTimerMsg    = new(serverMessage)
     idleTimerMsg        = new(serverMessage)
+    readIdleTimerMsg    = new(serverMessage)
     shutdownTimerMsg    = new(serverMessage)
     gracefulShutdownMsg = new(serverMessage)
     handlerDoneMsg      = new(serverMessage)

@@ -1068,6 +1083,7 @@ func (sc *serverConn) serve() {
 func (sc *serverConn) onSettingsTimer() { sc.sendServeMsg(settingsTimerMsg) }
 func (sc *serverConn) onIdleTimer()     { sc.sendServeMsg(idleTimerMsg) }
+func (sc *serverConn) onReadIdleTimer() { sc.sendServeMsg(readIdleTimerMsg) }
 func (sc *serverConn) onShutdownTimer() { sc.sendServeMsg(shutdownTimerMsg) }
 
 func (sc *serverConn) sendServeMsg(msg interface{}) {
@@ -1320,6 +1336,10 @@ func (sc *serverConn) wroteFrame(res frameWriteResult) {
     sc.writingFrame = false
     sc.writingFrameAsync = false
 
+    if res.err != nil {
+        sc.conn.Close()
+    }
+
     wr := res.wr
     if writeEndsStream(wr.write) {
@@ -1594,6 +1614,11 @@ func (sc *serverConn) processFrame(f Frame) error {
 func (sc *serverConn) processPing(f *PingFrame) error {
     sc.serveG.check()
     if f.IsAck() {
+        if sc.pingSent && sc.sentPingData == f.Data {
+            // This is a response to a PING we sent.
+            sc.pingSent = false
+            sc.readIdleTimer.Reset(sc.readIdleTimeout)
+        }
         // 6.7 PING: " An endpoint MUST NOT respond to PING frames
         // containing this flag."
         return nil
@@ -2160,7 +2185,7 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream
     st.cw.Init()
     st.flow.conn = &sc.flow // link to conn-level counter
     st.flow.add(sc.initialStreamSendWindowSize)
-    st.inflow.init(sc.srv.initialStreamRecvWindowSize())
+    st.inflow.init(sc.initialStreamRecvWindowSize)
     if sc.hs.WriteTimeout > 0 {
         st.writeDeadline = sc.srv.afterFunc(sc.hs.WriteTimeout, st.onWriteTimeout)
     }
@@ -3301,7 +3326,7 @@ func (sc *serverConn) countError(name string, err error) error {
     if sc == nil || sc.srv == nil {
         return err
     }
-    f := sc.srv.CountError
+    f := sc.countErrorFunc
     if f == nil {
         return err
     }

@@ -25,7 +25,6 @@
     "net/http"
     "net/http/httptrace"
     "net/textproto"
-    "os"
     "sort"
     "strconv"
     "strings"
@@ -227,40 +226,26 @@ func (t *Transport) contextWithTimeout(ctx context.Context, d time.Duration) (co
 }
 
 func (t *Transport) maxHeaderListSize() uint32 {
-    if t.MaxHeaderListSize == 0 {
+    n := int64(t.MaxHeaderListSize)
+    if t.t1 != nil && t.t1.MaxResponseHeaderBytes != 0 {
+        n = t.t1.MaxResponseHeaderBytes
+        if n > 0 {
+            n = adjustHTTP1MaxHeaderSize(n)
+        }
+    }
+    if n <= 0 {
         return 10 << 20
     }
-    if t.MaxHeaderListSize == 0xffffffff {
+    if n >= 0xffffffff {
         return 0
     }
-    return t.MaxHeaderListSize
-}
-
-func (t *Transport) maxFrameReadSize() uint32 {
-    if t.MaxReadFrameSize == 0 {
-        return 0 // use the default provided by the peer
-    }
-    if t.MaxReadFrameSize < minMaxFrameSize {
-        return minMaxFrameSize
-    }
-    if t.MaxReadFrameSize > maxFrameSize {
-        return maxFrameSize
-    }
-    return t.MaxReadFrameSize
+    return uint32(n)
 }
 
 func (t *Transport) disableCompression() bool {
     return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression)
 }
 
-func (t *Transport) pingTimeout() time.Duration {
-    if t.PingTimeout == 0 {
-        return 15 * time.Second
-    }
-    return t.PingTimeout
-}
-
 // ConfigureTransport configures a net/http HTTP/1 Transport to use HTTP/2.
 // It returns an error if t1 has already been HTTP/2-enabled.
 //
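With this change the client's advertised SETTINGS_MAX_HEADER_LIST_SIZE can also be derived from the wrapped HTTP/1 transport's MaxResponseHeaderBytes, not only from Transport.MaxHeaderListSize. A rough sketch of that fallback order; the padding step is an assumption carried over from the removed server-side constants:

```go
// Sketch of the fallback order in the new maxHeaderListSize above.
package main

import "fmt"

func maxHeaderListSizeSketch(maxHeaderListSize uint32, http1MaxResponseHeaderBytes int64) uint32 {
	n := int64(maxHeaderListSize)
	if http1MaxResponseHeaderBytes != 0 {
		n = http1MaxResponseHeaderBytes
		if n > 0 {
			n = n + 10*32 // assumed HTTP/1 -> HTTP/2 padding: ~10 fields at 32 bytes overhead each
		}
	}
	if n <= 0 {
		return 10 << 20 // default: 10 MiB
	}
	if n >= 0xffffffff {
		return 0 // "unlimited" sentinel
	}
	return uint32(n)
}

func main() {
	fmt.Println(maxHeaderListSizeSketch(0, 0))      // 10485760 (default)
	fmt.Println(maxHeaderListSizeSketch(0, 1<<20))  // 1048896 (derived from the HTTP/1 limit)
	fmt.Println(maxHeaderListSizeSketch(64<<10, 0)) // 65536 (explicit HTTP/2 setting wins)
}
```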
@@ -370,11 +355,14 @@ type ClientConn struct {
     lastActive time.Time
     lastIdle   time.Time // time last idle
 
     // Settings from peer: (also guarded by wmu)
     maxFrameSize                uint32
     maxConcurrentStreams        uint32
     peerMaxHeaderListSize       uint64
     peerMaxHeaderTableSize      uint32
     initialWindowSize           uint32
+    initialStreamRecvWindowSize int32
+    readIdleTimeout             time.Duration
+    pingTimeout                 time.Duration
 
     // reqHeaderMu is a 1-element semaphore channel controlling access to sending new requests.
     // Write to reqHeaderMu to lock it, read from it to unlock.
@@ -499,6 +487,7 @@ func (cs *clientStream) closeReqBodyLocked() {
 }
 
 type stickyErrWriter struct {
+    group   synctestGroupInterface
     conn    net.Conn
     timeout time.Duration
     err     *error
@@ -508,22 +497,9 @@ func (sew stickyErrWriter) Write(p []byte) (n int, err error) {
     if *sew.err != nil {
         return 0, *sew.err
     }
-    for {
-        if sew.timeout != 0 {
-            sew.conn.SetWriteDeadline(time.Now().Add(sew.timeout))
-        }
-        nn, err := sew.conn.Write(p[n:])
-        n += nn
-        if n < len(p) && nn > 0 && errors.Is(err, os.ErrDeadlineExceeded) {
-            // Keep extending the deadline so long as we're making progress.
-            continue
-        }
-        if sew.timeout != 0 {
-            sew.conn.SetWriteDeadline(time.Time{})
-        }
-        *sew.err = err
-        return n, err
-    }
+    n, err = writeWithByteTimeout(sew.group, sew.conn, sew.timeout, p)
+    *sew.err = err
+    return n, err
 }
 
 // noCachedConnError is the concrete type of ErrNoCachedConn, which
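The deadline-extending write loop that used to live in stickyErrWriter.Write has been factored out into a writeWithByteTimeout helper (which also takes the synctest group used by tests). A standalone version of the same "extend the deadline while bytes keep flowing" pattern, reconstructed from the removed lines; the vendored helper is assumed to behave similarly but is not reproduced here:

```go
// Standalone sketch of the byte-timeout write pattern removed above.
package sketch

import (
	"errors"
	"net"
	"os"
	"time"
)

func writeWithByteTimeoutSketch(conn net.Conn, timeout time.Duration, p []byte) (n int, err error) {
	if timeout <= 0 {
		return conn.Write(p)
	}
	for {
		// Re-arm the deadline before every attempt so it bounds time per chunk,
		// not time for the whole buffer.
		conn.SetWriteDeadline(time.Now().Add(timeout))
		nn, err := conn.Write(p[n:])
		n += nn
		if n < len(p) && nn > 0 && errors.Is(err, os.ErrDeadlineExceeded) {
			// Progress was made before the deadline hit: keep going.
			continue
		}
		conn.SetWriteDeadline(time.Time{}) // clear the deadline
		return n, err
	}
}
```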
@@ -758,44 +734,36 @@ func (t *Transport) expectContinueTimeout() time.Duration {
     return t.t1.ExpectContinueTimeout
 }
 
-func (t *Transport) maxDecoderHeaderTableSize() uint32 {
-    if v := t.MaxDecoderHeaderTableSize; v > 0 {
-        return v
-    }
-    return initialHeaderTableSize
-}
-
-func (t *Transport) maxEncoderHeaderTableSize() uint32 {
-    if v := t.MaxEncoderHeaderTableSize; v > 0 {
-        return v
-    }
-    return initialHeaderTableSize
-}
-
 func (t *Transport) NewClientConn(c net.Conn) (*ClientConn, error) {
     return t.newClientConn(c, t.disableKeepAlives())
 }
 
 func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, error) {
+    conf := configFromTransport(t)
     cc := &ClientConn{
         t:                     t,
         tconn:                 c,
         readerDone:            make(chan struct{}),
         nextStreamID:          1,
         maxFrameSize:          16 << 10, // spec default
         initialWindowSize:     65535,    // spec default
-        maxConcurrentStreams:  initialMaxConcurrentStreams, // "infinite", per spec. Use a smaller value until we have received server settings.
-        peerMaxHeaderListSize: 0xffffffffffffffff,          // "infinite", per spec. Use 2^64-1 instead.
-        streams:               make(map[uint32]*clientStream),
-        singleUse:             singleUse,
-        wantSettingsAck:       true,
-        pings:                 make(map[[8]byte]chan struct{}),
-        reqHeaderMu:           make(chan struct{}, 1),
+        initialStreamRecvWindowSize: conf.MaxUploadBufferPerStream,
+        maxConcurrentStreams:        initialMaxConcurrentStreams, // "infinite", per spec. Use a smaller value until we have received server settings.
+        peerMaxHeaderListSize:       0xffffffffffffffff,          // "infinite", per spec. Use 2^64-1 instead.
+        streams:                     make(map[uint32]*clientStream),
+        singleUse:                   singleUse,
+        wantSettingsAck:             true,
+        readIdleTimeout:             conf.SendPingTimeout,
+        pingTimeout:                 conf.PingTimeout,
+        pings:                       make(map[[8]byte]chan struct{}),
+        reqHeaderMu:                 make(chan struct{}, 1),
     }
+    var group synctestGroupInterface
     if t.transportTestHooks != nil {
         t.markNewGoroutine()
         t.transportTestHooks.newclientconn(cc)
         c = cc.tconn
+        group = t.group
     }
     if VerboseLogs {
         t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr())
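Here the per-field getter methods on Transport (maxFrameReadSize, maxDecoderHeaderTableSize, pingTimeout, and so on) are replaced by a single configFromTransport call that yields a consolidated config value. A sketch of the kind of defaulting such a consolidation performs, reconstructed from the helpers this diff deletes; the struct and function names are illustrative, not the vendored API:

```go
// Sketch of option defaulting in one place, mirroring the deleted helpers.
package sketch

import "time"

type http2ConfigSketch struct {
	MaxReadFrameSize          uint32
	MaxDecoderHeaderTableSize uint32
	PingTimeout               time.Duration
}

func configFromOptions(maxReadFrameSize, maxDecoderHeaderTableSize uint32, pingTimeout time.Duration) http2ConfigSketch {
	const (
		minMaxFrameSize        = 1 << 14   // protocol minimum (16 KiB)
		maxFrameSize           = 1<<24 - 1 // protocol maximum
		initialHeaderTableSize = 4096
	)
	c := http2ConfigSketch{
		MaxReadFrameSize:          maxReadFrameSize,
		MaxDecoderHeaderTableSize: maxDecoderHeaderTableSize,
		PingTimeout:               pingTimeout,
	}
	switch {
	case c.MaxReadFrameSize == 0:
		// The deleted helper returned 0 here ("use the peer default");
		// this sketch picks a concrete floor instead.
		c.MaxReadFrameSize = minMaxFrameSize
	case c.MaxReadFrameSize < minMaxFrameSize:
		c.MaxReadFrameSize = minMaxFrameSize
	case c.MaxReadFrameSize > maxFrameSize:
		c.MaxReadFrameSize = maxFrameSize
	}
	if c.MaxDecoderHeaderTableSize == 0 {
		c.MaxDecoderHeaderTableSize = initialHeaderTableSize
	}
	if c.PingTimeout == 0 {
		c.PingTimeout = 15 * time.Second // default from the deleted pingTimeout helper
	}
	return c
}
```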
@@ -807,24 +775,23 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
     // TODO: adjust this writer size to account for frame size +
     // MTU + crypto/tls record padding.
     cc.bw = bufio.NewWriter(stickyErrWriter{
+        group:   group,
         conn:    c,
-        timeout: t.WriteByteTimeout,
+        timeout: conf.WriteByteTimeout,
         err:     &cc.werr,
     })
     cc.br = bufio.NewReader(c)
     cc.fr = NewFramer(cc.bw, cc.br)
-    if t.maxFrameReadSize() != 0 {
-        cc.fr.SetMaxReadFrameSize(t.maxFrameReadSize())
-    }
+    cc.fr.SetMaxReadFrameSize(conf.MaxReadFrameSize)
     if t.CountError != nil {
         cc.fr.countError = t.CountError
     }
-    maxHeaderTableSize := t.maxDecoderHeaderTableSize()
+    maxHeaderTableSize := conf.MaxDecoderHeaderTableSize
     cc.fr.ReadMetaHeaders = hpack.NewDecoder(maxHeaderTableSize, nil)
     cc.fr.MaxHeaderListSize = t.maxHeaderListSize()
     cc.henc = hpack.NewEncoder(&cc.hbuf)
-    cc.henc.SetMaxDynamicTableSizeLimit(t.maxEncoderHeaderTableSize())
+    cc.henc.SetMaxDynamicTableSizeLimit(conf.MaxEncoderHeaderTableSize)
     cc.peerMaxHeaderTableSize = initialHeaderTableSize
 
     if cs, ok := c.(connectionStater); ok {
@@ -834,11 +801,9 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
     initialSettings := []Setting{
         {ID: SettingEnablePush, Val: 0},
-        {ID: SettingInitialWindowSize, Val: transportDefaultStreamFlow},
-    }
-    if max := t.maxFrameReadSize(); max != 0 {
-        initialSettings = append(initialSettings, Setting{ID: SettingMaxFrameSize, Val: max})
+        {ID: SettingInitialWindowSize, Val: uint32(cc.initialStreamRecvWindowSize)},
     }
+    initialSettings = append(initialSettings, Setting{ID: SettingMaxFrameSize, Val: conf.MaxReadFrameSize})
     if max := t.maxHeaderListSize(); max != 0 {
         initialSettings = append(initialSettings, Setting{ID: SettingMaxHeaderListSize, Val: max})
     }
@@ -848,8 +813,8 @@ func (t *Transport) newClientConn(c net.Conn, singleUse bool) (*ClientConn, erro
     cc.bw.Write(clientPreface)
     cc.fr.WriteSettings(initialSettings...)
-    cc.fr.WriteWindowUpdate(0, transportDefaultConnFlow)
-    cc.inflow.init(transportDefaultConnFlow + initialWindowSize)
+    cc.fr.WriteWindowUpdate(0, uint32(conf.MaxUploadBufferPerConnection))
+    cc.inflow.init(conf.MaxUploadBufferPerConnection + initialWindowSize)
     cc.bw.Flush()
     if cc.werr != nil {
         cc.Close()
func (cc *ClientConn) healthCheck() { func (cc *ClientConn) healthCheck() {
pingTimeout := cc.t.pingTimeout() pingTimeout := cc.pingTimeout
// We don't need to periodically ping in the health check, because the readLoop of ClientConn will // We don't need to periodically ping in the health check, because the readLoop of ClientConn will
// trigger the healthCheck again if there is no frame received. // trigger the healthCheck again if there is no frame received.
ctx, cancel := cc.t.contextWithTimeout(context.Background(), pingTimeout) ctx, cancel := cc.t.contextWithTimeout(context.Background(), pingTimeout)
@ -2199,7 +2164,7 @@ type resAndError struct {
func (cc *ClientConn) addStreamLocked(cs *clientStream) { func (cc *ClientConn) addStreamLocked(cs *clientStream) {
cs.flow.add(int32(cc.initialWindowSize)) cs.flow.add(int32(cc.initialWindowSize))
cs.flow.setConnFlow(&cc.flow) cs.flow.setConnFlow(&cc.flow)
cs.inflow.init(transportDefaultStreamFlow) cs.inflow.init(cc.initialStreamRecvWindowSize)
cs.ID = cc.nextStreamID cs.ID = cc.nextStreamID
cc.nextStreamID += 2 cc.nextStreamID += 2
cc.streams[cs.ID] = cs cc.streams[cs.ID] = cs
@ -2345,7 +2310,7 @@ func (cc *ClientConn) countReadFrameError(err error) {
func (rl *clientConnReadLoop) run() error { func (rl *clientConnReadLoop) run() error {
cc := rl.cc cc := rl.cc
gotSettings := false gotSettings := false
readIdleTimeout := cc.t.ReadIdleTimeout readIdleTimeout := cc.readIdleTimeout
var t timer var t timer
if readIdleTimeout != 0 { if readIdleTimeout != 0 {
t = cc.t.afterFunc(readIdleTimeout, cc.healthCheck) t = cc.t.afterFunc(readIdleTimeout, cc.healthCheck)
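On the client side the keepalive knobs remain the Transport's exported ReadIdleTimeout and PingTimeout fields; after this change they are funneled through the consolidated config into the ClientConn before the read loop starts. A minimal usage sketch, assuming an HTTP/2-capable server at the placeholder URL:

```go
// Minimal client keepalive configuration; the URL is a placeholder.
package main

import (
	"log"
	"net/http"
	"time"

	"golang.org/x/net/http2"
)

func main() {
	t2 := &http2.Transport{
		ReadIdleTimeout: 30 * time.Second, // send a PING if no frame arrives for 30s
		PingTimeout:     15 * time.Second, // close the conn if the PING ack doesn't arrive in time
	}
	client := &http.Client{Transport: t2}
	resp, err := client.Get("https://example.com/")
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()
	log.Println(resp.Status)
}
```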

Some files were not shown because too many files have changed in this diff.