Compare commits

..

1 commit

Author SHA1 Message Date
Victor Dyotte 90e153ada0
Merge 40c33ccc49 into 99f535f99b 2024-10-07 17:55:26 +02:00
40 changed files with 431 additions and 891 deletions

View file

@ -569,7 +569,6 @@ For example, the following json object `Reject`s the attempt of `@someone@somewh
```json ```json
{ {
"@context": "https://www.w3.org/ns/activitystreams",
"actor": "https://example.org/users/post_author", "actor": "https://example.org/users/post_author",
"to": "https://somewhere.else.example.org/users/someone", "to": "https://somewhere.else.example.org/users/someone",
"id": "https://example.org/users/post_author/activities/reject/01J0K2YXP9QCT5BE1JWQSAM3B6", "id": "https://example.org/users/post_author/activities/reject/01J0K2YXP9QCT5BE1JWQSAM3B6",
@ -592,12 +591,7 @@ For example, the following json object `Accept`s the attempt of `@someone@somewh
```json ```json
{ {
"@context": "https://www.w3.org/ns/activitystreams",
"actor": "https://example.org/users/post_author", "actor": "https://example.org/users/post_author",
"cc": [
"https://www.w3.org/ns/activitystreams#Public",
"https://example.org/users/post_author/followers"
],
"to": "https://somewhere.else.example.org/users/someone", "to": "https://somewhere.else.example.org/users/someone",
"id": "https://example.org/users/post_author/activities/reject/01J0K2YXP9QCT5BE1JWQSAM3B6", "id": "https://example.org/users/post_author/activities/reject/01J0K2YXP9QCT5BE1JWQSAM3B6",
"object": "https://somewhere.else.example.org/users/someone/statuses/01J17XY2VXGMNNPH1XR7BG2524", "object": "https://somewhere.else.example.org/users/someone/statuses/01J17XY2VXGMNNPH1XR7BG2524",
@ -607,9 +601,6 @@ For example, the following json object `Accept`s the attempt of `@someone@somewh
If this happens, `@someone@somewhere.else.example.org` (and their instance) should consider the interaction as having been approved / accepted. The instance can then feel free to distribute the interaction `Activity` to all of the recipients targed by `to`, `cc`, etc, with the additional property `approvedBy` ([see below](#approvedby)). If this happens, `@someone@somewhere.else.example.org` (and their instance) should consider the interaction as having been approved / accepted. The instance can then feel free to distribute the interaction `Activity` to all of the recipients targed by `to`, `cc`, etc, with the additional property `approvedBy` ([see below](#approvedby)).
!!! Note
In the above example, actor `https://example.org/users/post_author` addresses the `Accept` activity not just to the interacting actor `https://somewhere.else.example.org/users/someone`, but to their followers collection as well (and, implicitly, to the public). This allows followers of `https://example.org/users/post_author` on other servers to also mark the interaction as accepted, and to show the interaction alongside the interacted-with post.
### Validating presence in a Followers / Following collection ### Validating presence in a Followers / Following collection
If an `Actor` interacting with an `Object` (via `Like`, `inReplyTo`, or `Announce`) is permitted to do that interaction based on their presence in a `Followers` or `Following` collection in the `always` field of an interaction policy, then their server should *still* wait for an `Accept` to be received from the server of the target account, before distributing the interaction more widely with the `approvedBy` property set to the URI of the `Accept`. If an `Actor` interacting with an `Object` (via `Like`, `inReplyTo`, or `Announce`) is permitted to do that interaction based on their presence in a `Followers` or `Following` collection in the `always` field of an interaction policy, then their server should *still* wait for an `Accept` to be received from the server of the target account, before distributing the interaction more widely with the `approvedBy` property set to the URI of the `Accept`.

2
go.mod
View file

@ -44,7 +44,7 @@ require (
github.com/miekg/dns v1.1.62 github.com/miekg/dns v1.1.62
github.com/minio/minio-go/v7 v7.0.77 github.com/minio/minio-go/v7 v7.0.77
github.com/mitchellh/mapstructure v1.5.0 github.com/mitchellh/mapstructure v1.5.0
github.com/ncruces/go-sqlite3 v0.19.0 github.com/ncruces/go-sqlite3 v0.18.4
github.com/oklog/ulid v1.3.1 github.com/oklog/ulid v1.3.1
github.com/prometheus/client_golang v1.20.4 github.com/prometheus/client_golang v1.20.4
github.com/spf13/cobra v1.8.1 github.com/spf13/cobra v1.8.1

4
go.sum
View file

@ -434,8 +434,8 @@ github.com/moul/http2curl v1.0.0 h1:dRMWoAtb+ePxMlLkrCbAqh4TlPHXvoGUSQ323/9Zahs=
github.com/moul/http2curl v1.0.0/go.mod h1:8UbvGypXm98wA/IqH45anm5Y2Z6ep6O31QGOAZ3H0fQ= github.com/moul/http2curl v1.0.0/go.mod h1:8UbvGypXm98wA/IqH45anm5Y2Z6ep6O31QGOAZ3H0fQ=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA=
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/ncruces/go-sqlite3 v0.19.0 h1:yebbD/cP8Gf+7nKoUin2ATjnqJK2VvyS30d3xsjRp5k= github.com/ncruces/go-sqlite3 v0.18.4 h1:Je8o3y33MDwPYY/Cacas8yCsuoUzpNY/AgoSlN2ekyE=
github.com/ncruces/go-sqlite3 v0.19.0/go.mod h1:yL4ZNWGsr1/8pcLfpPW1RT1WFdvyeHonrgIwwi4rvkg= github.com/ncruces/go-sqlite3 v0.18.4/go.mod h1:4HLag13gq1k10s4dfGBhMfRVsssJRT9/5hYqVM9RUYo=
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M= github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M=

View file

@ -77,10 +77,6 @@
// See https://www.w3.org/TR/activitystreams-vocabulary/#microsyntaxes // See https://www.w3.org/TR/activitystreams-vocabulary/#microsyntaxes
// and https://www.w3.org/TR/activitystreams-vocabulary/#dfn-tag // and https://www.w3.org/TR/activitystreams-vocabulary/#dfn-tag
TagHashtag = "Hashtag" TagHashtag = "Hashtag"
// Not in the AS spec, just used internally to indicate
// that we don't *yet* know what type of Object something is.
ObjectUnknown = "Unknown"
) )
// isActivity returns whether AS type name is of an Activity (NOT IntransitiveActivity). // isActivity returns whether AS type name is of an Activity (NOT IntransitiveActivity).

View file

@ -527,9 +527,8 @@ func (d *Dereferencer) enrichStatus(
// serve statuses with the `approved_by` field, but we // serve statuses with the `approved_by` field, but we
// might have marked a status as pre-approved on our side // might have marked a status as pre-approved on our side
// based on the author's inclusion in a followers/following // based on the author's inclusion in a followers/following
// collection, or by providing pre-approval URI on the bare // collection. By carrying over previously-set values we
// status passed to RefreshStatus. By carrying over previously // can avoid marking such statuses as "pending" again.
// set values we can avoid marking such statuses as "pending".
// //
// If a remote has in the meantime retracted its approval, // If a remote has in the meantime retracted its approval,
// the next call to 'isPermittedStatus' will catch that. // the next call to 'isPermittedStatus' will catch that.

View file

@ -113,17 +113,33 @@ func (d *Dereferencer) isPermittedStatus(
func (d *Dereferencer) isPermittedReply( func (d *Dereferencer) isPermittedReply(
ctx context.Context, ctx context.Context,
requestUser string, requestUser string,
reply *gtsmodel.Status, status *gtsmodel.Status,
) (bool, error) { ) (bool, error) {
var ( var (
replyURI = reply.URI // Definitely set. statusURI = status.URI // Definitely set.
inReplyToURI = reply.InReplyToURI // Definitely set. inReplyToURI = status.InReplyToURI // Definitely set.
inReplyTo = reply.InReplyTo // Might not be set. inReplyTo = status.InReplyTo // Might not yet be set.
acceptIRI = reply.ApprovedByURI // Might not be set.
) )
// Check if we have a stored interaction request for parent status. // Check if status with this URI has previously been rejected.
parentReq, err := d.state.DB.GetInteractionRequestByInteractionURI( req, err := d.state.DB.GetInteractionRequestByInteractionURI(
gtscontext.SetBarebones(ctx),
statusURI,
)
if err != nil && !errors.Is(err, db.ErrNoEntries) {
err := gtserror.Newf("db error getting interaction request: %w", err)
return false, err
}
if req != nil && req.IsRejected() {
// This status has been
// rejected reviously, so
// it's not permitted now.
return false, nil
}
// Check if replied-to status has previously been rejected.
req, err = d.state.DB.GetInteractionRequestByInteractionURI(
gtscontext.SetBarebones(ctx), gtscontext.SetBarebones(ctx),
inReplyToURI, inReplyToURI,
) )
@ -132,78 +148,71 @@ func (d *Dereferencer) isPermittedReply(
return false, err return false, err
} }
// Check if we have a stored interaction request for this reply. if req != nil && req.IsRejected() {
thisReq, err := d.state.DB.GetInteractionRequestByInteractionURI( // This status's parent was rejected, so
gtscontext.SetBarebones(ctx), // implicitly this reply should be rejected too.
replyURI, //
) // We know already that we haven't inserted
if err != nil && !errors.Is(err, db.ErrNoEntries) { // a rejected interaction request for this
err := gtserror.Newf("db error getting interaction request: %w", err) // status yet so do it before returning.
return false, err id := id.NewULID()
}
parentRejected := (parentReq != nil && parentReq.IsRejected()) // To ensure the Reject chain stays coherent,
thisRejected := (thisReq != nil && thisReq.IsRejected()) // borrow fields from the up-thread rejection.
// This collapses the chain beyond the first
// rejected reply and allows us to avoid derefing
// further replies we already know we don't want.
statusID := req.StatusID
targetAccountID := req.TargetAccountID
if parentRejected { // As nobody is actually Rejecting the reply
// If this status's parent was rejected, // directly, but it's an implicit Reject coming
// implicitly this reply should be too; // from our internal logic, don't bother setting
// there's nothing more to check here. // a URI (it's not a required field anyway).
return false, d.unpermittedByParent( uri := ""
ctx,
reply, rejection := &gtsmodel.InteractionRequest{
thisReq, ID: id,
parentReq, StatusID: statusID,
) TargetAccountID: targetAccountID,
} InteractingAccountID: status.AccountID,
InteractionURI: statusURI,
InteractionType: gtsmodel.InteractionReply,
URI: uri,
RejectedAt: time.Now(),
}
err := d.state.DB.PutInteractionRequest(ctx, rejection)
if err != nil && !errors.Is(err, db.ErrAlreadyExists) {
return false, gtserror.Newf("db error putting pre-rejected interaction request: %w", err)
}
// Parent wasn't rejected. Check if this
// reply itself was rejected previously.
//
// If it was, and it doesn't now claim to
// be approved, then we should just reject it
// again, as nothing's changed since last time.
if thisRejected && acceptIRI == "" {
// Nothing changed,
// still rejected.
return false, nil return false, nil
} }
// This reply wasn't rejected previously, or
// it was rejected previously and now claims
// to be approved. Continue permission checks.
if inReplyTo == nil { if inReplyTo == nil {
// If we didn't have the replied-to status // We didn't have the replied-to status in
// in our database (yet), we can't check // our database (yet) so we can't know if
// right now if this reply is permitted. // this reply is permitted or not. For now
// // just return true; worst-case, the status
// For now, just return permitted if reply // sticks around on the instance for a couple
// was not explicitly rejected before; worst- // hours until we try to dereference it again
// case, the reply stays on the instance for // and realize it should be forbidden.
// a couple hours until we try to deref it return true, nil
// again and realize it should be forbidden.
return !thisRejected, nil
} }
// We have the replied-to status; ensure it's fully populated.
if err := d.state.DB.PopulateStatus(ctx, inReplyTo); err != nil {
return false, gtserror.Newf("error populating status %s: %w", reply.ID, err)
}
// Make sure replied-to status is not
// a boost wrapper, and make sure it's
// actually visible to the requester.
if inReplyTo.BoostOfID != "" { if inReplyTo.BoostOfID != "" {
// We do not permit replies // We do not permit replies to
// to boost wrapper statuses. // boost wrapper statuses. (this
// shouldn't be able to happen).
log.Info(ctx, "rejecting reply to boost wrapper status") log.Info(ctx, "rejecting reply to boost wrapper status")
return false, nil return false, nil
} }
// Check visibility of local
// inReplyTo to replying account.
if inReplyTo.IsLocal() { if inReplyTo.IsLocal() {
visible, err := d.visFilter.StatusVisible(ctx, visible, err := d.visFilter.StatusVisible(ctx,
reply.Account, status.Account,
inReplyTo, inReplyTo,
) )
if err != nil { if err != nil {
@ -218,26 +227,9 @@ func (d *Dereferencer) isPermittedReply(
} }
} }
// If this reply claims to be approved, // Check interaction policy of inReplyTo.
// validate this by dereferencing the
// Accept and checking the return value.
// No further checks are required.
if acceptIRI != "" {
return d.isPermittedByAcceptIRI(
ctx,
requestUser,
reply,
inReplyTo,
thisReq,
acceptIRI,
)
}
// Status doesn't claim to be approved.
// Check interaction policy of inReplyTo
// to see if it doesn't require approval.
replyable, err := d.intFilter.StatusReplyable(ctx, replyable, err := d.intFilter.StatusReplyable(ctx,
reply.Account, status.Account,
inReplyTo, inReplyTo,
) )
if err != nil { if err != nil {
@ -246,250 +238,93 @@ func (d *Dereferencer) isPermittedReply(
} }
if replyable.Forbidden() { if replyable.Forbidden() {
// Reply is not permitted according to policy. // Reply is not permitted.
// //
// Either insert a pre-rejected interaction // Insert a pre-rejected interaction request
// req into the db, or update the existing // into the db and return. This ensures that
// one, and return. This ensures that replies // replies to this now-rejected status aren't
// to this rejected reply also aren't permitted. // inadvertently permitted.
return false, d.rejectedByPolicy( id := id.NewULID()
ctx, rejection := &gtsmodel.InteractionRequest{
reply, ID: id,
inReplyTo, StatusID: inReplyTo.ID,
thisReq, TargetAccountID: inReplyTo.AccountID,
) InteractingAccountID: status.AccountID,
} InteractionURI: statusURI,
InteractionType: gtsmodel.InteractionReply,
// Reply is permitted according to the interaction URI: uris.GenerateURIForReject(inReplyTo.Account.Username, id),
// policy set on the replied-to status (if any). RejectedAt: time.Now(),
}
if !replyable.MatchedOnCollection() { err := d.state.DB.PutInteractionRequest(ctx, rejection)
// If we didn't match on a collection, if err != nil && !errors.Is(err, db.ErrAlreadyExists) {
// then we don't require an acceptIRI, return false, gtserror.Newf("db error putting pre-rejected interaction request: %w", err)
// and we don't need to send an Accept;
// just permit the reply full stop.
return true, nil
}
// Reply is permitted, but match was made based
// on inclusion in a followers/following collection.
//
// If the status is ours, mark it as PreApproved
// so the processor knows to create and send out
// an Accept for it immediately.
if inReplyTo.IsLocal() {
reply.PendingApproval = util.Ptr(true)
reply.PreApproved = true
return true, nil
}
// For replies to remote statuses, which matched
// on a followers/following collection, but did not
// include an acceptIRI, we should just drop it.
// It's possible we'll get an Accept for it later
// and we can check everything again.
return false, nil
}
// unpermittedByParent marks the given reply as rejected
// based on the fact that its parent was rejected.
//
// This will create a rejected interaction request for
// the status in the db, if one didn't exist already,
// or update an existing interaction request instead.
func (d *Dereferencer) unpermittedByParent(
ctx context.Context,
reply *gtsmodel.Status,
thisReq *gtsmodel.InteractionRequest,
parentReq *gtsmodel.InteractionRequest,
) error {
if thisReq != nil && thisReq.IsRejected() {
// This interaction request is
// already marked as rejected,
// there's nothing more to do.
return nil
}
if thisReq != nil {
// Before we return, ensure interaction
// request is marked as rejected.
thisReq.RejectedAt = time.Now()
thisReq.AcceptedAt = time.Time{}
err := d.state.DB.UpdateInteractionRequest(
ctx,
thisReq,
"rejected_at",
"accepted_at",
)
if err != nil {
return gtserror.Newf("db error updating interaction request: %w", err)
} }
return nil
}
// We haven't stored a rejected interaction
// request for this status yet, do it now.
rejectID := id.NewULID()
// To ensure the Reject chain stays coherent,
// borrow fields from the up-thread rejection.
// This collapses the chain beyond the first
// rejected reply and allows us to avoid derefing
// further replies we already know we don't want.
inReplyToID := parentReq.StatusID
targetAccountID := parentReq.TargetAccountID
// As nobody is actually Rejecting the reply
// directly, but it's an implicit Reject coming
// from our internal logic, don't bother setting
// a URI (it's not a required field anyway).
uri := ""
rejection := &gtsmodel.InteractionRequest{
ID: rejectID,
StatusID: inReplyToID,
TargetAccountID: targetAccountID,
InteractingAccountID: reply.AccountID,
InteractionURI: reply.URI,
InteractionType: gtsmodel.InteractionReply,
URI: uri,
RejectedAt: time.Now(),
}
err := d.state.DB.PutInteractionRequest(ctx, rejection)
if err != nil && !errors.Is(err, db.ErrAlreadyExists) {
return gtserror.Newf("db error putting pre-rejected interaction request: %w", err)
}
return nil
}
// isPermittedByAcceptIRI checks whether the given acceptIRI
// permits the given reply to the given inReplyTo status.
// If yes, then thisReq will be updated to reflect the
// acceptance, if it's not nil.
func (d *Dereferencer) isPermittedByAcceptIRI(
ctx context.Context,
requestUser string,
reply *gtsmodel.Status,
inReplyTo *gtsmodel.Status,
thisReq *gtsmodel.InteractionRequest,
acceptIRI string,
) (bool, error) {
permitted, err := d.isValidAccept(
ctx,
requestUser,
acceptIRI,
reply.URI,
inReplyTo.AccountURI,
)
if err != nil {
// Error dereferencing means we couldn't
// get the Accept right now or it wasn't
// valid, so we shouldn't store this status.
err := gtserror.Newf("undereferencable ApprovedByURI: %w", err)
return false, err
}
if !permitted {
// It's a no from
// us, squirt.
return false, nil return false, nil
} }
// Reply is permitted by this Accept. if replyable.Permitted() &&
// If it was previously rejected or !replyable.MatchedOnCollection() {
// pending approval, clear that now. // Replier is permitted to do this
reply.PendingApproval = util.Ptr(false) // interaction, and didn't match on
if thisReq != nil { // a collection so we don't need to
thisReq.URI = acceptIRI // do further checking.
thisReq.AcceptedAt = time.Now() return true, nil
thisReq.RejectedAt = time.Time{}
err := d.state.DB.UpdateInteractionRequest(
ctx,
thisReq,
"uri",
"accepted_at",
"rejected_at",
)
if err != nil {
return false, gtserror.Newf("db error updating interaction request: %w", err)
}
} }
// All good! // Replier is permitted to do this
// interaction pending approval, or
// permitted but matched on a collection.
//
// Check if we can dereference
// an Accept that grants approval.
if status.ApprovedByURI == "" {
// Status doesn't claim to be approved.
//
// For replies to local statuses that's
// fine, we can put it in the DB pending
// approval, and continue processing it.
//
// If permission was granted based on a match
// with a followers or following collection,
// we can mark it as PreApproved so the processor
// sends an accept out for it immediately.
//
// For replies to remote statuses, though
// we should be polite and just drop it.
if inReplyTo.IsLocal() {
status.PendingApproval = util.Ptr(true)
status.PreApproved = replyable.MatchedOnCollection()
return true, nil
}
return false, nil
}
// Status claims to be approved, check
// this by dereferencing the Accept and
// inspecting the return value.
if err := d.validateApprovedBy(
ctx,
requestUser,
status.ApprovedByURI,
statusURI,
inReplyTo.AccountURI,
); err != nil {
// Error dereferencing means we couldn't
// get the Accept right now or it wasn't
// valid, so we shouldn't store this status.
log.Errorf(ctx, "undereferencable ApprovedByURI: %v", err)
return false, nil
}
// Status has been approved.
status.PendingApproval = util.Ptr(false)
return true, nil return true, nil
} }
func (d *Dereferencer) rejectedByPolicy(
ctx context.Context,
reply *gtsmodel.Status,
inReplyTo *gtsmodel.Status,
thisReq *gtsmodel.InteractionRequest,
) error {
var (
rejectID string
rejectURI string
)
if thisReq != nil {
// Reuse existing ID.
rejectID = thisReq.ID
} else {
// Generate new ID.
rejectID = id.NewULID()
}
if inReplyTo.IsLocal() {
// If this a reply to one of our statuses
// we should generate a URI for the Reject,
// else just use an implicit (empty) URI.
rejectURI = uris.GenerateURIForReject(
inReplyTo.Account.Username,
rejectID,
)
}
if thisReq != nil {
// Before we return, ensure interaction
// request is marked as rejected.
thisReq.RejectedAt = time.Now()
thisReq.AcceptedAt = time.Time{}
thisReq.URI = rejectURI
err := d.state.DB.UpdateInteractionRequest(
ctx,
thisReq,
"rejected_at",
"accepted_at",
"uri",
)
if err != nil {
return gtserror.Newf("db error updating interaction request: %w", err)
}
return nil
}
// We haven't stored a rejected interaction
// request for this status yet, do it now.
rejection := &gtsmodel.InteractionRequest{
ID: rejectID,
StatusID: inReplyTo.ID,
TargetAccountID: inReplyTo.AccountID,
InteractingAccountID: reply.AccountID,
InteractionURI: reply.URI,
InteractionType: gtsmodel.InteractionReply,
URI: rejectURI,
RejectedAt: time.Now(),
}
err := d.state.DB.PutInteractionRequest(ctx, rejection)
if err != nil && !errors.Is(err, db.ErrAlreadyExists) {
return gtserror.Newf("db error putting pre-rejected interaction request: %w", err)
}
return nil
}
func (d *Dereferencer) isPermittedBoost( func (d *Dereferencer) isPermittedBoost(
ctx context.Context, ctx context.Context,
requestUser string, requestUser string,
@ -583,22 +418,18 @@ func (d *Dereferencer) isPermittedBoost(
// Boost claims to be approved, check // Boost claims to be approved, check
// this by dereferencing the Accept and // this by dereferencing the Accept and
// inspecting the return value. // inspecting the return value.
permitted, err := d.isValidAccept( if err := d.validateApprovedBy(
ctx, ctx,
requestUser, requestUser,
status.ApprovedByURI, status.ApprovedByURI,
status.URI, status.URI,
boostOf.AccountURI, boostOf.AccountURI,
) ); err != nil {
if err != nil {
// Error dereferencing means we couldn't // Error dereferencing means we couldn't
// get the Accept right now or it wasn't // get the Accept right now or it wasn't
// valid, so we shouldn't store this status. // valid, so we shouldn't store this status.
err := gtserror.Newf("undereferencable ApprovedByURI: %w", err) log.Errorf(ctx, "undereferencable ApprovedByURI: %v", err)
return false, err
}
if !permitted {
return false, nil return false, nil
} }
@ -607,59 +438,43 @@ func (d *Dereferencer) isPermittedBoost(
return true, nil return true, nil
} }
// isValidAccept dereferences the activitystreams Accept at the // validateApprovedBy dereferences the activitystreams Accept at
// specified IRI, and checks the Accept for validity against the // the specified IRI, and checks the Accept for validity against
// provided expectedObject and expectedActor. // the provided expectedObject and expectedActor.
// //
// Will return either (true, nil) if everything looked OK, an error // Will return either nil if everything looked OK, or an error if
// if something went wrong internally during deref, or (false, nil) // something went wrong during deref, or if the dereffed Accept
// if the dereferenced Accept did not meet expectations. // did not meet expectations.
func (d *Dereferencer) isValidAccept( func (d *Dereferencer) validateApprovedBy(
ctx context.Context, ctx context.Context,
requestUser string, requestUser string,
acceptIRIStr string, // Eg., "https://example.org/users/someone/accepts/01J2736AWWJ3411CPR833F6D03" approvedByURIStr string, // Eg., "https://example.org/users/someone/accepts/01J2736AWWJ3411CPR833F6D03"
expectObjectURIStr string, // Eg., "https://some.instance.example.org/users/someone_else/statuses/01J27414TWV9F7DC39FN8ABB5R" expectObjectURIStr string, // Eg., "https://some.instance.example.org/users/someone_else/statuses/01J27414TWV9F7DC39FN8ABB5R"
expectActorURIStr string, // Eg., "https://example.org/users/someone" expectActorURIStr string, // Eg., "https://example.org/users/someone"
) (bool, error) { ) error {
l := log. approvedByURI, err := url.Parse(approvedByURIStr)
WithContext(ctx).
WithField("acceptIRI", acceptIRIStr)
acceptIRI, err := url.Parse(acceptIRIStr)
if err != nil { if err != nil {
// Real returnable error. err := gtserror.Newf("error parsing approvedByURI: %w", err)
err := gtserror.Newf("error parsing acceptIRI: %w", err) return err
return false, err
} }
// Don't make calls to the Accept IRI // Don't make calls to the remote if it's blocked.
// if it's blocked, just return false. if blocked, err := d.state.DB.IsDomainBlocked(ctx, approvedByURI.Host); blocked || err != nil {
blocked, err := d.state.DB.IsDomainBlocked(ctx, acceptIRI.Host) err := gtserror.Newf("domain %s is blocked", approvedByURI.Host)
if err != nil { return err
// Real returnable error.
err := gtserror.Newf("error checking domain block: %w", err)
return false, err
}
if blocked {
l.Info("Accept host is blocked")
return false, nil
} }
tsport, err := d.transportController.NewTransportForUsername(ctx, requestUser) tsport, err := d.transportController.NewTransportForUsername(ctx, requestUser)
if err != nil { if err != nil {
// Real returnable error.
err := gtserror.Newf("error creating transport: %w", err) err := gtserror.Newf("error creating transport: %w", err)
return false, err return err
} }
// Make the call to resolve into an Acceptable. // Make the call to resolve into an Acceptable.
// Log any error encountered here but don't rsp, err := tsport.Dereference(ctx, approvedByURI)
// return it as it's not *our* error.
rsp, err := tsport.Dereference(ctx, acceptIRI)
if err != nil { if err != nil {
l.Errorf("error dereferencing Accept: %v", err) err := gtserror.Newf("error dereferencing %s: %w", approvedByURIStr, err)
return false, nil return err
} }
acceptable, err := ap.ResolveAcceptable(ctx, rsp.Body) acceptable, err := ap.ResolveAcceptable(ctx, rsp.Body)
@ -668,71 +483,66 @@ func (d *Dereferencer) isValidAccept(
_ = rsp.Body.Close() _ = rsp.Body.Close()
if err != nil { if err != nil {
l.Errorf("error resolving to Accept: %v", err) err := gtserror.Newf("error resolving Accept %s: %w", approvedByURIStr, err)
return false, err return err
} }
// Extract the URI/ID of the Accept. // Extract the URI/ID of the Accept.
acceptID := ap.GetJSONLDId(acceptable) acceptURI := ap.GetJSONLDId(acceptable)
acceptIDStr := acceptID.String() acceptURIStr := acceptURI.String()
// Check whether input URI and final returned URI // Check whether input URI and final returned URI
// have changed (i.e. we followed some redirects). // have changed (i.e. we followed some redirects).
rspURL := rsp.Request.URL rspURL := rsp.Request.URL
rspURLStr := rspURL.String() rspURLStr := rspURL.String()
if rspURLStr != acceptIRIStr { switch {
// If rspURLStr != acceptIRIStr, make sure final case rspURLStr == approvedByURIStr:
// response URL is at least on the same host as
// what we expected (ie., we weren't redirected
// across domains), and make sure it's the same
// as the ID of the Accept we were returned.
switch {
case rspURL.Host != acceptIRI.Host:
l.Errorf(
"final deref host %s did not match acceptIRI host",
rspURL.Host,
)
return false, nil
case acceptIDStr != rspURLStr: // i.e. from here, rspURLStr != approvedByURIStr.
l.Errorf( //
"final deref uri %s did not match returned Accept ID %s", // Make sure it's at least on the same host as
rspURLStr, acceptIDStr, // what we expected (ie., we weren't redirected
) // across domains), and make sure it's the same
return false, nil // as the ID of the Accept we were returned.
} case rspURL.Host != approvedByURI.Host:
return gtserror.Newf(
"final dereference host %s did not match approvedByURI host %s",
rspURL.Host, approvedByURI.Host,
)
case acceptURIStr != rspURLStr:
return gtserror.Newf(
"final dereference uri %s did not match returned Accept ID/URI %s",
rspURLStr, acceptURIStr,
)
} }
// Response is superficially OK,
// check in more detail now.
// Extract the actor IRI and string from Accept. // Extract the actor IRI and string from Accept.
actorIRIs := ap.GetActorIRIs(acceptable) actorIRIs := ap.GetActorIRIs(acceptable)
actorIRI, actorIRIStr := extractIRI(actorIRIs) actorIRI, actorIRIStr := extractIRI(actorIRIs)
switch { switch {
case actorIRIStr == "": case actorIRIStr == "":
l.Error("Accept missing actor IRI") err := gtserror.New("missing Accept actor IRI")
return false, nil return gtserror.SetMalformed(err)
// Ensure the Accept Actor is on // Ensure the Accept Actor is who we expect
// the instance hosting the Accept. // it to be, and not someone else trying to
case actorIRI.Host != acceptID.Host: // do an Accept for an interaction with a
l.Errorf( // statusable they don't own.
"actor %s not on the same host as Accept", case actorIRI.Host != acceptURI.Host:
actorIRIStr, return gtserror.Newf(
"Accept Actor %s was not the same host as Accept %s",
actorIRIStr, acceptURIStr,
) )
return false, nil
// Ensure the Accept Actor is who we expect // Ensure the Accept Actor is who we expect
// it to be, and not someone else trying to // it to be, and not someone else trying to
// do an Accept for an interaction with a // do an Accept for an interaction with a
// statusable they don't own. // statusable they don't own.
case actorIRIStr != expectActorURIStr: case actorIRIStr != expectActorURIStr:
l.Errorf( return gtserror.Newf(
"actor %s was not the same as expected actor %s", "Accept Actor %s was not the same as expected actor %s",
actorIRIStr, expectActorURIStr, actorIRIStr, expectActorURIStr,
) )
return false, nil
} }
// Extract the object IRI string from Accept. // Extract the object IRI string from Accept.
@ -740,22 +550,20 @@ func (d *Dereferencer) isValidAccept(
_, objectIRIStr := extractIRI(objectIRIs) _, objectIRIStr := extractIRI(objectIRIs)
switch { switch {
case objectIRIStr == "": case objectIRIStr == "":
l.Error("missing Accept object IRI") err := gtserror.New("missing Accept object IRI")
return false, nil return gtserror.SetMalformed(err)
// Ensure the Accept Object is what we expect // Ensure the Accept Object is what we expect
// it to be, ie., it's Accepting the interaction // it to be, ie., it's Accepting the interaction
// we need it to Accept, and not something else. // we need it to Accept, and not something else.
case objectIRIStr != expectObjectURIStr: case objectIRIStr != expectObjectURIStr:
l.Errorf( return gtserror.Newf(
"resolved Accept object IRI %s was not the same as expected object %s", "resolved Accept Object uri %s was not the same as expected object %s",
objectIRIStr, expectObjectURIStr, objectIRIStr, expectObjectURIStr,
) )
return false, nil
} }
// Everything looks OK. return nil
return true, nil
} }
// extractIRI is shorthand to extract the first IRI // extractIRI is shorthand to extract the first IRI

View file

@ -24,7 +24,6 @@
"github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/activity/streams/vocab"
"github.com/superseriousbusiness/gotosocial/internal/ap" "github.com/superseriousbusiness/gotosocial/internal/ap"
"github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/gtserror" "github.com/superseriousbusiness/gotosocial/internal/gtserror"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -69,20 +68,6 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA
return gtserror.NewErrorBadRequest(errors.New(text), text) return gtserror.NewErrorBadRequest(errors.New(text), text)
} }
// Ensure requester is the same as the
// Actor of the Accept; you can't Accept
// something on someone else's behalf.
actorURI, err := ap.ExtractActorURI(accept)
if err != nil {
const text = "Accept had empty or invalid actor property"
return gtserror.NewErrorBadRequest(errors.New(text), text)
}
if requestingAcct.URI != actorURI.String() {
const text = "Accept actor and requesting account were not the same"
return gtserror.NewErrorBadRequest(errors.New(text), text)
}
// Iterate all provided objects in the activity, // Iterate all provided objects in the activity,
// handling the ones we know how to handle. // handling the ones we know how to handle.
for _, object := range ap.ExtractObjects(accept) { for _, object := range ap.ExtractObjects(accept) {
@ -123,6 +108,18 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA
return err return err
} }
// ACCEPT STATUS (reply/boost)
case uris.IsStatusesPath(objIRI):
if err := f.acceptStatusIRI(
ctx,
activityID.String(),
objIRI.String(),
receivingAcct,
requestingAcct,
); err != nil {
return err
}
// ACCEPT LIKE // ACCEPT LIKE
case uris.IsLikePath(objIRI): case uris.IsLikePath(objIRI):
if err := f.acceptLikeIRI( if err := f.acceptLikeIRI(
@ -135,20 +132,9 @@ func (f *federatingDB) Accept(ctx context.Context, accept vocab.ActivityStreamsA
return err return err
} }
// ACCEPT OTHER (reply? boost?) // UNHANDLED
//
// Don't check on IsStatusesPath
// as this may be a remote status.
default: default:
if err := f.acceptOtherIRI( log.Debugf(ctx, "unhandled iri type: %s", objIRI)
ctx,
activityID,
objIRI,
receivingAcct,
requestingAcct,
); err != nil {
return err
}
} }
} }
} }
@ -290,91 +276,39 @@ func (f *federatingDB) acceptFollowIRI(
return nil return nil
} }
func (f *federatingDB) acceptOtherIRI( func (f *federatingDB) acceptStatusIRI(
ctx context.Context, ctx context.Context,
activityID *url.URL, activityID string,
objectIRI *url.URL, objectIRI string,
receivingAcct *gtsmodel.Account, receivingAcct *gtsmodel.Account,
requestingAcct *gtsmodel.Account, requestingAcct *gtsmodel.Account,
) error { ) error {
// See if we can get a status from the db. // Lock on this potential status
status, err := f.state.DB.GetStatusByURI(ctx, objectIRI.String()) // URI as we may be updating it.
unlock := f.state.FedLocks.Lock(objectIRI)
defer unlock()
// Get the status from the db.
status, err := f.state.DB.GetStatusByURI(ctx, objectIRI)
if err != nil && !errors.Is(err, db.ErrNoEntries) { if err != nil && !errors.Is(err, db.ErrNoEntries) {
err := gtserror.Newf("db error getting status: %w", err) err := gtserror.Newf("db error getting status: %w", err)
return gtserror.NewErrorInternalError(err) return gtserror.NewErrorInternalError(err)
} }
if status != nil { if status == nil {
// We had a status stored with this // We didn't have a status with
// objectIRI, proceed to accept it. // this URI, so nothing to do.
return f.acceptStoredStatus( // Just return.
ctx,
activityID,
status,
receivingAcct,
requestingAcct,
)
}
if objectIRI.Host == config.GetHost() ||
objectIRI.Host == config.GetAccountDomain() {
// Claims to be Accepting something of ours,
// but we don't have a status stored for this
// URI, so most likely it's been deleted in
// the meantime, just bail.
return nil return nil
} }
// This must be an Accept of a remote Activity if !status.IsLocal() {
// or Object. Ensure relevance of this message // We don't process Accepts of statuses
// by checking that receiver follows requester. // that weren't created on our instance.
following, err := f.state.DB.IsFollowing( // Just return.
ctx,
receivingAcct.ID,
requestingAcct.ID,
)
if err != nil {
err := gtserror.Newf("db error checking following: %w", err)
return gtserror.NewErrorInternalError(err)
}
if !following {
// If we don't follow this person, and
// they're not Accepting something we know
// about, then we don't give a good goddamn.
return nil return nil
} }
// This may be a reply, or it may be a boost,
// we can't know yet without dereferencing it,
// but let the processor worry about that.
apObjectType := ap.ObjectUnknown
// Pass to the processor and let them handle side effects.
f.state.Workers.Federator.Queue.Push(&messages.FromFediAPI{
APObjectType: apObjectType,
APActivityType: ap.ActivityAccept,
APIRI: activityID,
APObject: objectIRI,
Receiving: receivingAcct,
Requesting: requestingAcct,
})
return nil
}
func (f *federatingDB) acceptStoredStatus(
ctx context.Context,
activityID *url.URL,
status *gtsmodel.Status,
receivingAcct *gtsmodel.Account,
requestingAcct *gtsmodel.Account,
) error {
// Lock on this status URI
// as we may be updating it.
unlock := f.state.FedLocks.Lock(status.URI)
defer unlock()
pendingApproval := util.PtrOrValue(status.PendingApproval, false) pendingApproval := util.PtrOrValue(status.PendingApproval, false)
if !pendingApproval { if !pendingApproval {
// Status doesn't need approval or it's // Status doesn't need approval or it's
@ -383,6 +317,14 @@ func (f *federatingDB) acceptStoredStatus(
return nil return nil
} }
// Make sure the creator of the original status
// is the same as the inbox processing the Accept;
// this also ensures the status is local.
if status.AccountID != receivingAcct.ID {
const text = "status author account and inbox account were not the same"
return gtserror.NewErrorUnprocessableEntity(errors.New(text), text)
}
// Make sure the target of the interaction (reply/boost) // Make sure the target of the interaction (reply/boost)
// is the same as the account doing the Accept. // is the same as the account doing the Accept.
if status.BoostOfAccountID != requestingAcct.ID && if status.BoostOfAccountID != requestingAcct.ID &&
@ -393,7 +335,7 @@ func (f *federatingDB) acceptStoredStatus(
// Mark the status as approved by this Accept URI. // Mark the status as approved by this Accept URI.
status.PendingApproval = util.Ptr(false) status.PendingApproval = util.Ptr(false)
status.ApprovedByURI = activityID.String() status.ApprovedByURI = activityID
if err := f.state.DB.UpdateStatus( if err := f.state.DB.UpdateStatus(
ctx, ctx,
status, status,

View file

@ -69,29 +69,25 @@ type InteractionRequest struct {
Like *StatusFave `bun:"-"` // Not stored in DB. Only set if InteractionType = InteractionLike. Like *StatusFave `bun:"-"` // Not stored in DB. Only set if InteractionType = InteractionLike.
Reply *Status `bun:"-"` // Not stored in DB. Only set if InteractionType = InteractionReply. Reply *Status `bun:"-"` // Not stored in DB. Only set if InteractionType = InteractionReply.
Announce *Status `bun:"-"` // Not stored in DB. Only set if InteractionType = InteractionAnnounce. Announce *Status `bun:"-"` // Not stored in DB. Only set if InteractionType = InteractionAnnounce.
URI string `bun:",nullzero,unique"` // ActivityPub URI of the Accept (if accepted) or Reject (if rejected). Null/empty if currently neither accepted not rejected.
AcceptedAt time.Time `bun:"type:timestamptz,nullzero"` // If interaction request was accepted, time at which this occurred. AcceptedAt time.Time `bun:"type:timestamptz,nullzero"` // If interaction request was accepted, time at which this occurred.
RejectedAt time.Time `bun:"type:timestamptz,nullzero"` // If interaction request was rejected, time at which this occurred. RejectedAt time.Time `bun:"type:timestamptz,nullzero"` // If interaction request was rejected, time at which this occurred.
// ActivityPub URI of the Accept (if accepted) or Reject (if rejected).
// Field may be empty if currently neither accepted not rejected, or if
// acceptance/rejection was implicit (ie., not resulting from an Activity).
URI string `bun:",nullzero,unique"`
} }
// IsHandled returns true if interaction // IsHandled returns true if interaction
// request has been neither accepted or rejected. // request has been neither accepted or rejected.
func (ir *InteractionRequest) IsPending() bool { func (ir *InteractionRequest) IsPending() bool {
return !ir.IsAccepted() && !ir.IsRejected() return ir.URI == "" && ir.AcceptedAt.IsZero() && ir.RejectedAt.IsZero()
} }
// IsAccepted returns true if this // IsAccepted returns true if this
// interaction request has been accepted. // interaction request has been accepted.
func (ir *InteractionRequest) IsAccepted() bool { func (ir *InteractionRequest) IsAccepted() bool {
return !ir.AcceptedAt.IsZero() return ir.URI != "" && !ir.AcceptedAt.IsZero()
} }
// IsRejected returns true if this // IsRejected returns true if this
// interaction request has been rejected. // interaction request has been rejected.
func (ir *InteractionRequest) IsRejected() bool { func (ir *InteractionRequest) IsRejected() bool {
return !ir.RejectedAt.IsZero() return ir.URI != "" && !ir.RejectedAt.IsZero()
} }

View file

@ -20,7 +20,6 @@
import ( import (
"context" "context"
"errors" "errors"
"net/url"
"time" "time"
"codeberg.org/gruf/go-kv" "codeberg.org/gruf/go-kv"
@ -145,10 +144,6 @@ func (p *Processor) ProcessFromFediAPI(ctx context.Context, fMsg *messages.FromF
// ACCEPT (pending) ANNOUNCE // ACCEPT (pending) ANNOUNCE
case ap.ActivityAnnounce: case ap.ActivityAnnounce:
return p.fediAPI.AcceptAnnounce(ctx, fMsg) return p.fediAPI.AcceptAnnounce(ctx, fMsg)
// ACCEPT (remote) REPLY or ANNOUNCE
case ap.ObjectUnknown:
return p.fediAPI.AcceptRemoteStatus(ctx, fMsg)
} }
// REJECT SOMETHING // REJECT SOMETHING
@ -828,60 +823,6 @@ func (p *fediAPI) AcceptReply(ctx context.Context, fMsg *messages.FromFediAPI) e
return nil return nil
} }
func (p *fediAPI) AcceptRemoteStatus(ctx context.Context, fMsg *messages.FromFediAPI) error {
// See if we can accept a remote
// status we don't have stored yet.
objectIRI, ok := fMsg.APObject.(*url.URL)
if !ok {
return gtserror.Newf("%T not parseable as *url.URL", fMsg.APObject)
}
acceptIRI := fMsg.APIRI
if acceptIRI == nil {
return gtserror.New("acceptIRI was nil")
}
// Assume we're accepting a status; create a
// barebones status for dereferencing purposes.
bareStatus := &gtsmodel.Status{
URI: objectIRI.String(),
ApprovedByURI: acceptIRI.String(),
}
// Call RefreshStatus() to process the provided
// barebones status and insert it into the database,
// if indeed it's actually a status URI we can fetch.
//
// This will also check whether the given AcceptIRI
// actually grants permission for this status.
status, _, err := p.federate.RefreshStatus(ctx,
fMsg.Receiving.Username,
bareStatus,
nil, nil,
)
if err != nil {
return gtserror.Newf("error processing accepted status %s: %w", bareStatus.URI, err)
}
// No error means it was indeed a remote status, and the
// given acceptIRI permitted it. Timeline and notify it.
if err := p.surface.timelineAndNotifyStatus(ctx, status); err != nil {
log.Errorf(ctx, "error timelining and notifying status: %v", err)
}
// Interaction counts changed on the interacted status;
// uncache the prepared version from all timelines.
if status.InReplyToID != "" {
p.surface.invalidateStatusFromTimelines(ctx, status.InReplyToID)
}
if status.BoostOfID != "" {
p.surface.invalidateStatusFromTimelines(ctx, status.BoostOfID)
}
return nil
}
func (p *fediAPI) AcceptAnnounce(ctx context.Context, fMsg *messages.FromFediAPI) error { func (p *fediAPI) AcceptAnnounce(ctx context.Context, fMsg *messages.FromFediAPI) error {
boost, ok := fMsg.GTSModel.(*gtsmodel.Status) boost, ok := fMsg.GTSModel.(*gtsmodel.Status)
if !ok { if !ok {

View file

@ -1988,16 +1988,6 @@ func (c *Converter) InteractionReqToASAccept(
return nil, gtserror.Newf("invalid interacting account uri: %w", err) return nil, gtserror.Newf("invalid interacting account uri: %w", err)
} }
publicIRI, err := url.Parse(pub.PublicActivityPubIRI)
if err != nil {
return nil, gtserror.Newf("invalid public uri: %w", err)
}
followersIRI, err := url.Parse(req.TargetAccount.FollowersURI)
if err != nil {
return nil, gtserror.Newf("invalid followers uri: %w", err)
}
// Set id to the URI of // Set id to the URI of
// interaction request. // interaction request.
ap.SetJSONLDId(accept, acceptID) ap.SetJSONLDId(accept, acceptID)
@ -2013,9 +2003,6 @@ func (c *Converter) InteractionReqToASAccept(
// of interaction URI. // of interaction URI.
ap.AppendTo(accept, toIRI) ap.AppendTo(accept, toIRI)
// Cc to the actor's followers, and to Public.
ap.AppendCc(accept, publicIRI, followersIRI)
return accept, nil return accept, nil
} }
@ -2047,16 +2034,6 @@ func (c *Converter) InteractionReqToASReject(
return nil, gtserror.Newf("invalid interacting account uri: %w", err) return nil, gtserror.Newf("invalid interacting account uri: %w", err)
} }
publicIRI, err := url.Parse(pub.PublicActivityPubIRI)
if err != nil {
return nil, gtserror.Newf("invalid public uri: %w", err)
}
followersIRI, err := url.Parse(req.TargetAccount.FollowersURI)
if err != nil {
return nil, gtserror.Newf("invalid followers uri: %w", err)
}
// Set id to the URI of // Set id to the URI of
// interaction request. // interaction request.
ap.SetJSONLDId(reject, rejectID) ap.SetJSONLDId(reject, rejectID)
@ -2072,8 +2049,5 @@ func (c *Converter) InteractionReqToASReject(
// of interaction URI. // of interaction URI.
ap.AppendTo(reject, toIRI) ap.AppendTo(reject, toIRI)
// Cc to the actor's followers, and to Public.
ap.AppendCc(reject, publicIRI, followersIRI)
return reject, nil return reject, nil
} }

View file

@ -1181,10 +1181,6 @@ func (suite *InternalToASTestSuite) TestInteractionReqToASAccept() {
suite.Equal(`{ suite.Equal(`{
"@context": "https://www.w3.org/ns/activitystreams", "@context": "https://www.w3.org/ns/activitystreams",
"actor": "http://localhost:8080/users/the_mighty_zork", "actor": "http://localhost:8080/users/the_mighty_zork",
"cc": [
"https://www.w3.org/ns/activitystreams#Public",
"http://localhost:8080/users/the_mighty_zork/followers"
],
"id": "http://localhost:8080/users/the_mighty_zork/accepts/01J1AKMZ8JE5NW0ZSFTRC1JJNE", "id": "http://localhost:8080/users/the_mighty_zork/accepts/01J1AKMZ8JE5NW0ZSFTRC1JJNE",
"object": "https://fossbros-anonymous.io/users/foss_satan/statuses/01J1AKRRHQ6MDDQHV0TP716T2K", "object": "https://fossbros-anonymous.io/users/foss_satan/statuses/01J1AKRRHQ6MDDQHV0TP716T2K",
"to": "http://fossbros-anonymous.io/users/foss_satan", "to": "http://fossbros-anonymous.io/users/foss_satan",

View file

@ -31,6 +31,7 @@ type Blob struct {
// //
// https://sqlite.org/c3ref/blob_open.html // https://sqlite.org/c3ref/blob_open.html
func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) { func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) {
c.checkInterrupt()
defer c.arena.mark()() defer c.arena.mark()()
blobPtr := c.arena.new(ptrlen) blobPtr := c.arena.new(ptrlen)
dbPtr := c.arena.string(db) dbPtr := c.arena.string(db)
@ -42,7 +43,6 @@ func (c *Conn) OpenBlob(db, table, column string, row int64, write bool) (*Blob,
flags = 1 flags = 1
} }
c.checkInterrupt(c.handle)
r := c.call("sqlite3_blob_open", uint64(c.handle), r := c.call("sqlite3_blob_open", uint64(c.handle),
uint64(dbPtr), uint64(tablePtr), uint64(columnPtr), uint64(dbPtr), uint64(tablePtr), uint64(columnPtr),
uint64(row), flags, uint64(blobPtr)) uint64(row), flags, uint64(blobPtr))

View file

@ -284,10 +284,7 @@ func walCallback(ctx context.Context, mod api.Module, _, pDB, zSchema uint32, pa
// //
// https://sqlite.org/c3ref/autovacuum_pages.html // https://sqlite.org/c3ref/autovacuum_pages.html
func (c *Conn) AutoVacuumPages(cb func(schema string, dbPages, freePages, bytesPerPage uint) uint) error { func (c *Conn) AutoVacuumPages(cb func(schema string, dbPages, freePages, bytesPerPage uint) uint) error {
var funcPtr uint32 funcPtr := util.AddHandle(c.ctx, cb)
if cb != nil {
funcPtr = util.AddHandle(c.ctx, cb)
}
r := c.call("sqlite3_autovacuum_pages_go", uint64(c.handle), uint64(funcPtr)) r := c.call("sqlite3_autovacuum_pages_go", uint64(c.handle), uint64(funcPtr))
return c.error(r) return c.error(r)
} }

View file

@ -24,7 +24,7 @@ type Conn struct {
pending *Stmt pending *Stmt
stmts []*Stmt stmts []*Stmt
timer *time.Timer timer *time.Timer
busy func(context.Context, int) bool busy func(int) bool
log func(xErrorCode, string) log func(xErrorCode, string)
collation func(*Conn, string) collation func(*Conn, string)
wal func(*Conn, string, int) error wal func(*Conn, string, int) error
@ -38,20 +38,14 @@ type Conn struct {
handle uint32 handle uint32
} }
// Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE] and [OPEN_URI]. // Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE], [OPEN_URI] and [OPEN_NOFOLLOW].
func Open(filename string) (*Conn, error) { func Open(filename string) (*Conn, error) {
return newConn(context.Background(), filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI) return newConn(filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI|OPEN_NOFOLLOW)
}
// OpenContext is like [Open] but includes a context,
// which is used to interrupt the process of opening the connectiton.
func OpenContext(ctx context.Context, filename string) (*Conn, error) {
return newConn(ctx, filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI)
} }
// OpenFlags opens an SQLite database file as specified by the filename argument. // OpenFlags opens an SQLite database file as specified by the filename argument.
// //
// If none of the required flags are used, a combination of [OPEN_READWRITE] and [OPEN_CREATE] is used. // If none of the required flags is used, a combination of [OPEN_READWRITE] and [OPEN_CREATE] is used.
// If a URI filename is used, PRAGMA statements to execute can be specified using "_pragma": // If a URI filename is used, PRAGMA statements to execute can be specified using "_pragma":
// //
// sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)") // sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)")
@ -61,33 +55,25 @@ func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
if flags&(OPEN_READONLY|OPEN_READWRITE|OPEN_CREATE) == 0 { if flags&(OPEN_READONLY|OPEN_READWRITE|OPEN_CREATE) == 0 {
flags |= OPEN_READWRITE | OPEN_CREATE flags |= OPEN_READWRITE | OPEN_CREATE
} }
return newConn(context.Background(), filename, flags) return newConn(filename, flags)
} }
type connKey struct{} type connKey struct{}
func newConn(ctx context.Context, filename string, flags OpenFlag) (res *Conn, _ error) { func newConn(filename string, flags OpenFlag) (conn *Conn, err error) {
err := ctx.Err() sqlite, err := instantiateSQLite()
if err != nil {
return nil, err
}
c := &Conn{interrupt: ctx}
c.sqlite, err = instantiateSQLite()
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { defer func() {
if res == nil { if conn == nil {
c.Close() sqlite.close()
c.sqlite.close()
} else {
c.interrupt = context.Background()
} }
}() }()
c.ctx = context.WithValue(c.ctx, connKey{}, c) c := &Conn{sqlite: sqlite}
c.arena = c.newArena(1024) c.arena = c.newArena(1024)
c.ctx = context.WithValue(c.ctx, connKey{}, c)
c.handle, err = c.openDB(filename, flags) c.handle, err = c.openDB(filename, flags)
if err == nil { if err == nil {
err = initExtensions(c) err = initExtensions(c)
@ -112,7 +98,6 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
return 0, err return 0, err
} }
c.call("sqlite3_progress_handler_go", uint64(handle), 100)
if flags|OPEN_URI != 0 && strings.HasPrefix(filename, "file:") { if flags|OPEN_URI != 0 && strings.HasPrefix(filename, "file:") {
var pragmas strings.Builder var pragmas strings.Builder
if _, after, ok := strings.Cut(filename, "?"); ok { if _, after, ok := strings.Cut(filename, "?"); ok {
@ -124,7 +109,6 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
} }
} }
if pragmas.Len() != 0 { if pragmas.Len() != 0 {
c.checkInterrupt(handle)
pragmaPtr := c.arena.string(pragmas.String()) pragmaPtr := c.arena.string(pragmas.String())
r := c.call("sqlite3_exec", uint64(handle), uint64(pragmaPtr), 0, 0, 0) r := c.call("sqlite3_exec", uint64(handle), uint64(pragmaPtr), 0, 0, 0)
if err := c.sqlite.error(r, handle, pragmas.String()); err != nil { if err := c.sqlite.error(r, handle, pragmas.String()); err != nil {
@ -134,6 +118,7 @@ func (c *Conn) openDB(filename string, flags OpenFlag) (uint32, error) {
} }
} }
} }
c.call("sqlite3_progress_handler_go", uint64(handle), 100)
return handle, nil return handle, nil
} }
@ -175,10 +160,10 @@ func (c *Conn) Close() error {
// //
// https://sqlite.org/c3ref/exec.html // https://sqlite.org/c3ref/exec.html
func (c *Conn) Exec(sql string) error { func (c *Conn) Exec(sql string) error {
c.checkInterrupt()
defer c.arena.mark()() defer c.arena.mark()()
sqlPtr := c.arena.string(sql) sqlPtr := c.arena.string(sql)
c.checkInterrupt(c.handle)
r := c.call("sqlite3_exec", uint64(c.handle), uint64(sqlPtr), 0, 0, 0) r := c.call("sqlite3_exec", uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
return c.error(r, sql) return c.error(r, sql)
} }
@ -316,7 +301,8 @@ func (c *Conn) ReleaseMemory() error {
return c.error(r) return c.error(r)
} }
// GetInterrupt gets the context set with [Conn.SetInterrupt]. // GetInterrupt gets the context set with [Conn.SetInterrupt],
// or nil if none was set.
func (c *Conn) GetInterrupt() context.Context { func (c *Conn) GetInterrupt() context.Context {
return c.interrupt return c.interrupt
} }
@ -336,11 +322,9 @@ func (c *Conn) GetInterrupt() context.Context {
// //
// https://sqlite.org/c3ref/interrupt.html // https://sqlite.org/c3ref/interrupt.html
func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) { func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
old = c.interrupt // Is it the same context?
c.interrupt = ctx if ctx == c.interrupt {
return ctx
if ctx == old || ctx.Done() == old.Done() {
return old
} }
// A busy SQL statement prevents SQLite from ignoring an interrupt // A busy SQL statement prevents SQLite from ignoring an interrupt
@ -349,29 +333,32 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
defer c.arena.mark()() defer c.arena.mark()()
stmtPtr := c.arena.new(ptrlen) stmtPtr := c.arena.new(ptrlen)
loopPtr := c.arena.string(`WITH RECURSIVE c(x) AS (VALUES(0) UNION ALL SELECT x FROM c) SELECT x FROM c`) loopPtr := c.arena.string(`WITH RECURSIVE c(x) AS (VALUES(0) UNION ALL SELECT x FROM c) SELECT x FROM c`)
c.call("sqlite3_prepare_v3", uint64(c.handle), uint64(loopPtr), math.MaxUint64, c.call("sqlite3_prepare_v3", uint64(c.handle), uint64(loopPtr), math.MaxUint64, 0, uint64(stmtPtr), 0)
uint64(PREPARE_PERSISTENT), uint64(stmtPtr), 0)
c.pending = &Stmt{c: c} c.pending = &Stmt{c: c}
c.pending.handle = util.ReadUint32(c.mod, stmtPtr) c.pending.handle = util.ReadUint32(c.mod, stmtPtr)
} }
if old.Done() != nil && ctx.Err() == nil { old = c.interrupt
c.interrupt = ctx
if old != nil && old.Done() != nil && (ctx == nil || ctx.Err() == nil) {
c.pending.Reset() c.pending.Reset()
} }
if ctx.Done() != nil { if ctx != nil && ctx.Done() != nil {
c.pending.Step() c.pending.Step()
} }
return old return old
} }
func (c *Conn) checkInterrupt(handle uint32) { func (c *Conn) checkInterrupt() {
if c.interrupt.Err() != nil { if c.interrupt != nil && c.interrupt.Err() != nil {
c.call("sqlite3_interrupt", uint64(handle)) c.call("sqlite3_interrupt", uint64(c.handle))
} }
} }
func progressCallback(ctx context.Context, mod api.Module, _ uint32) (interrupt uint32) { func progressCallback(ctx context.Context, mod api.Module, pDB uint32) (interrupt uint32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.interrupt.Err() != nil { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB &&
c.interrupt != nil && c.interrupt.Err() != nil {
interrupt = 1 interrupt = 1
} }
return interrupt return interrupt
@ -386,8 +373,9 @@ func (c *Conn) BusyTimeout(timeout time.Duration) error {
return c.error(r) return c.error(r)
} }
func timeoutCallback(ctx context.Context, mod api.Module, count, tmout int32) (retry uint32) { func timeoutCallback(ctx context.Context, mod api.Module, pDB uint32, count, tmout int32) (retry uint32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.interrupt.Err() == nil { if c, ok := ctx.Value(connKey{}).(*Conn); ok &&
(c.interrupt == nil || c.interrupt.Err() == nil) {
const delays = "\x01\x02\x05\x0a\x0f\x14\x19\x19\x19\x32\x32\x64" const delays = "\x01\x02\x05\x0a\x0f\x14\x19\x19\x19\x32\x32\x64"
const totals = "\x00\x01\x03\x08\x12\x21\x35\x4e\x67\x80\xb2\xe4" const totals = "\x00\x01\x03\x08\x12\x21\x35\x4e\x67\x80\xb2\xe4"
const ndelay = int32(len(delays) - 1) const ndelay = int32(len(delays) - 1)
@ -403,7 +391,7 @@ func timeoutCallback(ctx context.Context, mod api.Module, count, tmout int32) (r
if delay = min(delay, tmout-prior); delay > 0 { if delay = min(delay, tmout-prior); delay > 0 {
delay := time.Duration(delay) * time.Millisecond delay := time.Duration(delay) * time.Millisecond
if c.interrupt.Done() == nil { if c.interrupt == nil || c.interrupt.Done() == nil {
time.Sleep(delay) time.Sleep(delay)
return 1 return 1
} }
@ -426,7 +414,7 @@ func timeoutCallback(ctx context.Context, mod api.Module, count, tmout int32) (r
// BusyHandler registers a callback to handle [BUSY] errors. // BusyHandler registers a callback to handle [BUSY] errors.
// //
// https://sqlite.org/c3ref/busy_handler.html // https://sqlite.org/c3ref/busy_handler.html
func (c *Conn) BusyHandler(cb func(ctx context.Context, count int) (retry bool)) error { func (c *Conn) BusyHandler(cb func(count int) (retry bool)) error {
var enable uint64 var enable uint64
if cb != nil { if cb != nil {
enable = 1 enable = 1
@ -440,12 +428,9 @@ func (c *Conn) BusyHandler(cb func(ctx context.Context, count int) (retry bool))
} }
func busyCallback(ctx context.Context, mod api.Module, pDB uint32, count int32) (retry uint32) { func busyCallback(ctx context.Context, mod api.Module, pDB uint32, count int32) (retry uint32) {
if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.busy != nil { if c, ok := ctx.Value(connKey{}).(*Conn); ok && c.handle == pDB && c.busy != nil &&
interrupt := c.interrupt (c.interrupt == nil || c.interrupt.Err() == nil) {
if interrupt == nil { if c.busy(int(count)) {
interrupt = context.Background()
}
if interrupt.Err() == nil && c.busy(interrupt, int(count)) {
retry = 1 retry = 1
} }
} }

View file

@ -1,4 +1,4 @@
//go:build go1.23 //go:build (go1.23 || goexperiment.rangefunc) && !vet
package sqlite3 package sqlite3

View file

@ -1,4 +1,4 @@
//go:build !go1.23 //go:build !(go1.23 || goexperiment.rangefunc) || vet
package sqlite3 package sqlite3

View file

@ -40,14 +40,14 @@
// When using a custom time struct, you'll have to implement // When using a custom time struct, you'll have to implement
// [database/sql/driver.Valuer] and [database/sql.Scanner]. // [database/sql/driver.Valuer] and [database/sql.Scanner].
// //
// The Value method should ideally encode to a time [format] supported by SQLite. // The Value method should ideally serialise to a time [format] supported by SQLite.
// This ensures SQL date and time functions work as they should, // This ensures SQL date and time functions work as they should,
// and that your schema works with other SQLite tools. // and that your schema works with other SQLite tools.
// [sqlite3.TimeFormat.Encode] may help. // [sqlite3.TimeFormat.Encode] may help.
// //
// The Scan method needs to take into account that the value it receives can be of differing types. // The Scan method needs to take into account that the value it receives can be of differing types.
// It can already be a [time.Time], if the driver decoded the value according to "_timefmt" rules. // It can already be a [time.Time], if the driver decoded the value according to "_timefmt" rules.
// Or it can be a: string, int64, float64, []byte, or nil, // Or it can be a: string, int64, float64, []byte, nil,
// depending on the column type and what whoever wrote the value. // depending on the column type and what whoever wrote the value.
// [sqlite3.TimeFormat.Decode] may help. // [sqlite3.TimeFormat.Decode] may help.
// //
@ -202,19 +202,19 @@ func (n *connector) Driver() driver.Driver {
return n.driver return n.driver
} }
func (n *connector) Connect(ctx context.Context) (res driver.Conn, err error) { func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) {
c := &conn{ c := &conn{
txLock: n.txLock, txLock: n.txLock,
tmRead: n.tmRead, tmRead: n.tmRead,
tmWrite: n.tmWrite, tmWrite: n.tmWrite,
} }
c.Conn, err = sqlite3.OpenContext(ctx, n.name) c.Conn, err = sqlite3.Open(n.name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() { defer func() {
if res == nil { if err != nil {
c.Close() c.Close()
} }
}() }()
@ -239,7 +239,6 @@ func (n *connector) Connect(ctx context.Context) (res driver.Conn, err error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer s.Close()
if s.Step() && s.ColumnBool(0) { if s.Step() && s.ColumnBool(0) {
c.readOnly = '1' c.readOnly = '1'
} else { } else {
@ -467,7 +466,6 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
defer s.Stmt.Conn().SetInterrupt(old) defer s.Stmt.Conn().SetInterrupt(old)
err = s.Stmt.Exec() err = s.Stmt.Exec()
s.Stmt.ClearBindings()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -490,7 +488,7 @@ func (s *stmt) setupBindings(args []driver.NamedValue) (err error) {
if arg.Name == "" { if arg.Name == "" {
ids = append(ids, arg.Ordinal) ids = append(ids, arg.Ordinal)
} else { } else {
for _, prefix := range [...]string{":", "@", "$"} { for _, prefix := range []string{":", "@", "$"} {
if id := s.Stmt.BindIndex(prefix + arg.Name); id != 0 { if id := s.Stmt.BindIndex(prefix + arg.Name); id != 0 {
ids = append(ids, id) ids = append(ids, id)
} }
@ -524,9 +522,9 @@ func (s *stmt) setupBindings(args []driver.NamedValue) (err error) {
default: default:
panic(util.AssertErr()) panic(util.AssertErr())
} }
if err != nil { }
return err if err != nil {
} return err
} }
} }
return nil return nil
@ -597,11 +595,10 @@ func (r *rows) Close() error {
func (r *rows) Columns() []string { func (r *rows) Columns() []string {
if r.names == nil { if r.names == nil {
count := r.Stmt.ColumnCount() count := r.Stmt.ColumnCount()
names := make([]string, count) r.names = make([]string, count)
for i := range names { for i := range r.names {
names[i] = r.Stmt.ColumnName(i) r.names[i] = r.Stmt.ColumnName(i)
} }
r.names = names
} }
return r.names return r.names
} }
@ -609,29 +606,26 @@ func (r *rows) Columns() []string {
func (r *rows) loadTypes() { func (r *rows) loadTypes() {
if r.nulls == nil { if r.nulls == nil {
count := r.Stmt.ColumnCount() count := r.Stmt.ColumnCount()
nulls := make([]bool, count) r.nulls = make([]bool, count)
types := make([]string, count) r.types = make([]string, count)
for i := range nulls { for i := range r.nulls {
if col := r.Stmt.ColumnOriginName(i); col != "" { if col := r.Stmt.ColumnOriginName(i); col != "" {
types[i], _, nulls[i], _, _, _ = r.Stmt.Conn().TableColumnMetadata( r.types[i], _, r.nulls[i], _, _, _ = r.Stmt.Conn().TableColumnMetadata(
r.Stmt.ColumnDatabaseName(i), r.Stmt.ColumnDatabaseName(i),
r.Stmt.ColumnTableName(i), r.Stmt.ColumnTableName(i),
col) col)
} }
} }
r.nulls = nulls
r.types = types
} }
} }
func (r *rows) declType(index int) string { func (r *rows) declType(index int) string {
if r.types == nil { if r.types == nil {
count := r.Stmt.ColumnCount() count := r.Stmt.ColumnCount()
types := make([]string, count) r.types = make([]string, count)
for i := range types { for i := range r.types {
types[i] = strings.ToUpper(r.Stmt.ColumnDeclType(i)) r.types[i] = strings.ToUpper(r.Stmt.ColumnDeclType(i))
} }
r.types = types
} }
return r.types[index] return r.types[index]
} }
@ -671,23 +665,27 @@ func (r *rows) Next(dest []driver.Value) error {
for i := range dest { for i := range dest {
if t, ok := r.decodeTime(i, dest[i]); ok { if t, ok := r.decodeTime(i, dest[i]); ok {
dest[i] = t dest[i] = t
continue
}
if s, ok := dest[i].(string); ok {
t, ok := maybeTime(s)
if ok {
dest[i] = t
}
} }
} }
return err return err
} }
func (r *rows) decodeTime(i int, v any) (_ time.Time, ok bool) { func (r *rows) decodeTime(i int, v any) (_ time.Time, ok bool) {
switch v := v.(type) { switch r.tmRead {
case int64, float64: case sqlite3.TimeFormatDefault, time.RFC3339Nano:
// handled by maybeTime
return
}
switch v.(type) {
case int64, float64, string:
// could be a time value // could be a time value
case string:
if r.tmWrite != "" && r.tmWrite != time.RFC3339 && r.tmWrite != time.RFC3339Nano {
break
}
t, ok := maybeTime(v)
if ok {
return t, true
}
default: default:
return return
} }

View file

@ -9,7 +9,6 @@ The following optional features are compiled in:
- [JSON](https://sqlite.org/json1.html) - [JSON](https://sqlite.org/json1.html)
- [R*Tree](https://sqlite.org/rtree.html) - [R*Tree](https://sqlite.org/rtree.html)
- [GeoPoly](https://sqlite.org/geopoly.html) - [GeoPoly](https://sqlite.org/geopoly.html)
- [Spellfix1](https://sqlite.org/spellfix1.html)
- [soundex](https://sqlite.org/lang_corefunc.html#soundex) - [soundex](https://sqlite.org/lang_corefunc.html#soundex)
- [stat4](https://sqlite.org/compile.html#enable_stat4) - [stat4](https://sqlite.org/compile.html#enable_stat4)
- [base64](https://github.com/sqlite/sqlite/blob/master/ext/misc/base64.c) - [base64](https://github.com/sqlite/sqlite/blob/master/ext/misc/base64.c)

View file

@ -14,7 +14,7 @@ trap 'rm -f sqlite3.tmp' EXIT
-o sqlite3.wasm "$ROOT/sqlite3/main.c" \ -o sqlite3.wasm "$ROOT/sqlite3/main.c" \
-I"$ROOT/sqlite3" \ -I"$ROOT/sqlite3" \
-mexec-model=reactor \ -mexec-model=reactor \
-matomics -msimd128 -mmutable-globals -mmultivalue \ -matomics -msimd128 -mmutable-globals \
-mbulk-memory -mreference-types \ -mbulk-memory -mreference-types \
-mnontrapping-fptoint -msign-ext \ -mnontrapping-fptoint -msign-ext \
-fno-stack-protector -fno-stack-clash-protection \ -fno-stack-protector -fno-stack-clash-protection \

View file

@ -51,7 +51,6 @@ sqlite3_create_collation_go
sqlite3_create_function_go sqlite3_create_function_go
sqlite3_create_module_go sqlite3_create_module_go
sqlite3_create_window_function_go sqlite3_create_window_function_go
sqlite3_data_count
sqlite3_database_file_object sqlite3_database_file_object
sqlite3_db_cacheflush sqlite3_db_cacheflush
sqlite3_db_config sqlite3_db_config

Binary file not shown.

View file

@ -33,23 +33,16 @@ func (c *Conn) CollationNeeded(cb func(db *Conn, name string)) error {
// one or more unknown collating sequences. // one or more unknown collating sequences.
func (c Conn) AnyCollationNeeded() error { func (c Conn) AnyCollationNeeded() error {
r := c.call("sqlite3_anycollseq_init", uint64(c.handle), 0, 0) r := c.call("sqlite3_anycollseq_init", uint64(c.handle), 0, 0)
if err := c.error(r); err != nil { return c.error(r)
return err
}
c.collation = nil
return nil
} }
// CreateCollation defines a new collating sequence. // CreateCollation defines a new collating sequence.
// //
// https://sqlite.org/c3ref/create_collation.html // https://sqlite.org/c3ref/create_collation.html
func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error { func (c *Conn) CreateCollation(name string, fn func(a, b []byte) int) error {
var funcPtr uint32
defer c.arena.mark()() defer c.arena.mark()()
namePtr := c.arena.string(name) namePtr := c.arena.string(name)
if fn != nil { funcPtr := util.AddHandle(c.ctx, fn)
funcPtr = util.AddHandle(c.ctx, fn)
}
r := c.call("sqlite3_create_collation_go", r := c.call("sqlite3_create_collation_go",
uint64(c.handle), uint64(namePtr), uint64(funcPtr)) uint64(c.handle), uint64(namePtr), uint64(funcPtr))
return c.error(r) return c.error(r)
@ -59,12 +52,9 @@ funcPtr = util.AddHandle(c.ctx, fn)
// //
// https://sqlite.org/c3ref/create_function.html // https://sqlite.org/c3ref/create_function.html
func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn ScalarFunction) error { func (c *Conn) CreateFunction(name string, nArg int, flag FunctionFlag, fn ScalarFunction) error {
var funcPtr uint32
defer c.arena.mark()() defer c.arena.mark()()
namePtr := c.arena.string(name) namePtr := c.arena.string(name)
if fn != nil { funcPtr := util.AddHandle(c.ctx, fn)
funcPtr = util.AddHandle(c.ctx, fn)
}
r := c.call("sqlite3_create_function_go", r := c.call("sqlite3_create_function_go",
uint64(c.handle), uint64(namePtr), uint64(nArg), uint64(c.handle), uint64(namePtr), uint64(nArg),
uint64(flag), uint64(funcPtr)) uint64(flag), uint64(funcPtr))
@ -81,13 +71,10 @@ funcPtr = util.AddHandle(c.ctx, fn)
// //
// https://sqlite.org/c3ref/create_function.html // https://sqlite.org/c3ref/create_function.html
func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error { func (c *Conn) CreateWindowFunction(name string, nArg int, flag FunctionFlag, fn func() AggregateFunction) error {
var funcPtr uint32
defer c.arena.mark()() defer c.arena.mark()()
namePtr := c.arena.string(name)
if fn != nil {
funcPtr = util.AddHandle(c.ctx, fn)
}
call := "sqlite3_create_aggregate_function_go" call := "sqlite3_create_aggregate_function_go"
namePtr := c.arena.string(name)
funcPtr := util.AddHandle(c.ctx, fn)
if _, ok := fn().(WindowFunction); ok { if _, ok := fn().(WindowFunction); ok {
call = "sqlite3_create_window_function_go" call = "sqlite3_create_window_function_go"
} }
@ -197,12 +184,11 @@ func callbackAggregate(db *Conn, pAgg, pApp uint32) (AggregateFunction, uint32)
// We need to create the aggregate. // We need to create the aggregate.
fn := util.GetHandle(db.ctx, pApp).(func() AggregateFunction)() fn := util.GetHandle(db.ctx, pApp).(func() AggregateFunction)()
handle := util.AddHandle(db.ctx, fn)
if pAgg != 0 { if pAgg != 0 {
handle := util.AddHandle(db.ctx, fn)
util.WriteUint32(db.mod, pAgg, handle) util.WriteUint32(db.mod, pAgg, handle)
return fn, handle
} }
return fn, 0 return fn, handle
} }
func callbackArgs(db *Conn, arg []Value, pArg uint32) { func callbackArgs(db *Conn, arg []Value, pArg uint32) {

View file

@ -1,13 +1,10 @@
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0=
golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4= golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4=
golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk= golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk=
golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8= golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8=
golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=

View file

@ -47,7 +47,7 @@ func (m *mmappedMemory) Reallocate(size uint64) []byte {
// Commit additional memory up to new bytes. // Commit additional memory up to new bytes.
err := unix.Mprotect(m.buf[com:new], unix.PROT_READ|unix.PROT_WRITE) err := unix.Mprotect(m.buf[com:new], unix.PROT_READ|unix.PROT_WRITE)
if err != nil { if err != nil {
return nil panic(err)
} }
// Update committed memory. // Update committed memory.

View file

@ -56,7 +56,7 @@ func (m *virtualMemory) Reallocate(size uint64) []byte {
// Commit additional memory up to new bytes. // Commit additional memory up to new bytes.
_, err := windows.VirtualAlloc(m.addr, uintptr(new), windows.MEM_COMMIT, windows.PAGE_READWRITE) _, err := windows.VirtualAlloc(m.addr, uintptr(new), windows.MEM_COMMIT, windows.PAGE_READWRITE)
if err != nil { if err != nil {
return nil panic(err)
} }
// Update committed memory. // Update committed memory.

View file

@ -26,7 +26,6 @@ func ExportFuncVI[T0 i32](mod wazero.HostModuleBuilder, name string, fn func(con
type funcVII[T0, T1 i32] func(context.Context, api.Module, T0, T1) type funcVII[T0, T1 i32] func(context.Context, api.Module, T0, T1)
func (fn funcVII[T0, T1]) Call(ctx context.Context, mod api.Module, stack []uint64) { func (fn funcVII[T0, T1]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[1] // prevent bounds check on every slice access
fn(ctx, mod, T0(stack[0]), T1(stack[1])) fn(ctx, mod, T0(stack[0]), T1(stack[1]))
} }
@ -40,7 +39,6 @@ func ExportFuncVII[T0, T1 i32](mod wazero.HostModuleBuilder, name string, fn fun
type funcVIII[T0, T1, T2 i32] func(context.Context, api.Module, T0, T1, T2) type funcVIII[T0, T1, T2 i32] func(context.Context, api.Module, T0, T1, T2)
func (fn funcVIII[T0, T1, T2]) Call(ctx context.Context, mod api.Module, stack []uint64) { func (fn funcVIII[T0, T1, T2]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[2] // prevent bounds check on every slice access
fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2])) fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]))
} }
@ -54,7 +52,6 @@ func ExportFuncVIII[T0, T1, T2 i32](mod wazero.HostModuleBuilder, name string, f
type funcVIIII[T0, T1, T2, T3 i32] func(context.Context, api.Module, T0, T1, T2, T3) type funcVIIII[T0, T1, T2, T3 i32] func(context.Context, api.Module, T0, T1, T2, T3)
func (fn funcVIIII[T0, T1, T2, T3]) Call(ctx context.Context, mod api.Module, stack []uint64) { func (fn funcVIIII[T0, T1, T2, T3]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[3] // prevent bounds check on every slice access
fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3])) fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]))
} }
@ -68,7 +65,6 @@ func ExportFuncVIIII[T0, T1, T2, T3 i32](mod wazero.HostModuleBuilder, name stri
type funcVIIIII[T0, T1, T2, T3, T4 i32] func(context.Context, api.Module, T0, T1, T2, T3, T4) type funcVIIIII[T0, T1, T2, T3, T4 i32] func(context.Context, api.Module, T0, T1, T2, T3, T4)
func (fn funcVIIIII[T0, T1, T2, T3, T4]) Call(ctx context.Context, mod api.Module, stack []uint64) { func (fn funcVIIIII[T0, T1, T2, T3, T4]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[4] // prevent bounds check on every slice access
fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4])) fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4]))
} }
@ -82,7 +78,6 @@ func ExportFuncVIIIII[T0, T1, T2, T3, T4 i32](mod wazero.HostModuleBuilder, name
type funcVIIIIJ[T0, T1, T2, T3 i32, T4 i64] func(context.Context, api.Module, T0, T1, T2, T3, T4) type funcVIIIIJ[T0, T1, T2, T3 i32, T4 i64] func(context.Context, api.Module, T0, T1, T2, T3, T4)
func (fn funcVIIIIJ[T0, T1, T2, T3, T4]) Call(ctx context.Context, mod api.Module, stack []uint64) { func (fn funcVIIIIJ[T0, T1, T2, T3, T4]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[4] // prevent bounds check on every slice access
fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4])) fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4]))
} }
@ -109,7 +104,6 @@ func ExportFuncII[TR, T0 i32](mod wazero.HostModuleBuilder, name string, fn func
type funcIII[TR, T0, T1 i32] func(context.Context, api.Module, T0, T1) TR type funcIII[TR, T0, T1 i32] func(context.Context, api.Module, T0, T1) TR
func (fn funcIII[TR, T0, T1]) Call(ctx context.Context, mod api.Module, stack []uint64) { func (fn funcIII[TR, T0, T1]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[1] // prevent bounds check on every slice access
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]))) stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1])))
} }
@ -123,7 +117,6 @@ func ExportFuncIII[TR, T0, T1 i32](mod wazero.HostModuleBuilder, name string, fn
type funcIIII[TR, T0, T1, T2 i32] func(context.Context, api.Module, T0, T1, T2) TR type funcIIII[TR, T0, T1, T2 i32] func(context.Context, api.Module, T0, T1, T2) TR
func (fn funcIIII[TR, T0, T1, T2]) Call(ctx context.Context, mod api.Module, stack []uint64) { func (fn funcIIII[TR, T0, T1, T2]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[2] // prevent bounds check on every slice access
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]))) stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2])))
} }
@ -137,7 +130,6 @@ func ExportFuncIIII[TR, T0, T1, T2 i32](mod wazero.HostModuleBuilder, name strin
type funcIIIII[TR, T0, T1, T2, T3 i32] func(context.Context, api.Module, T0, T1, T2, T3) TR type funcIIIII[TR, T0, T1, T2, T3 i32] func(context.Context, api.Module, T0, T1, T2, T3) TR
func (fn funcIIIII[TR, T0, T1, T2, T3]) Call(ctx context.Context, mod api.Module, stack []uint64) { func (fn funcIIIII[TR, T0, T1, T2, T3]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[3] // prevent bounds check on every slice access
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]))) stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3])))
} }
@ -151,7 +143,6 @@ func ExportFuncIIIII[TR, T0, T1, T2, T3 i32](mod wazero.HostModuleBuilder, name
type funcIIIIII[TR, T0, T1, T2, T3, T4 i32] func(context.Context, api.Module, T0, T1, T2, T3, T4) TR type funcIIIIII[TR, T0, T1, T2, T3, T4 i32] func(context.Context, api.Module, T0, T1, T2, T3, T4) TR
func (fn funcIIIIII[TR, T0, T1, T2, T3, T4]) Call(ctx context.Context, mod api.Module, stack []uint64) { func (fn funcIIIIII[TR, T0, T1, T2, T3, T4]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[4] // prevent bounds check on every slice access
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4]))) stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4])))
} }
@ -165,7 +156,6 @@ func ExportFuncIIIIII[TR, T0, T1, T2, T3, T4 i32](mod wazero.HostModuleBuilder,
type funcIIIIIII[TR, T0, T1, T2, T3, T4, T5 i32] func(context.Context, api.Module, T0, T1, T2, T3, T4, T5) TR type funcIIIIIII[TR, T0, T1, T2, T3, T4, T5 i32] func(context.Context, api.Module, T0, T1, T2, T3, T4, T5) TR
func (fn funcIIIIIII[TR, T0, T1, T2, T3, T4, T5]) Call(ctx context.Context, mod api.Module, stack []uint64) { func (fn funcIIIIIII[TR, T0, T1, T2, T3, T4, T5]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[5] // prevent bounds check on every slice access
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4]), T5(stack[5]))) stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]), T4(stack[4]), T5(stack[5])))
} }
@ -179,7 +169,6 @@ func ExportFuncIIIIIII[TR, T0, T1, T2, T3, T4, T5 i32](mod wazero.HostModuleBuil
type funcIIIIJ[TR, T0, T1, T2 i32, T3 i64] func(context.Context, api.Module, T0, T1, T2, T3) TR type funcIIIIJ[TR, T0, T1, T2 i32, T3 i64] func(context.Context, api.Module, T0, T1, T2, T3) TR
func (fn funcIIIIJ[TR, T0, T1, T2, T3]) Call(ctx context.Context, mod api.Module, stack []uint64) { func (fn funcIIIIJ[TR, T0, T1, T2, T3]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[3] // prevent bounds check on every slice access
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3]))) stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]), T2(stack[2]), T3(stack[3])))
} }
@ -193,7 +182,6 @@ func ExportFuncIIIIJ[TR, T0, T1, T2 i32, T3 i64](mod wazero.HostModuleBuilder, n
type funcIIJ[TR, T0 i32, T1 i64] func(context.Context, api.Module, T0, T1) TR type funcIIJ[TR, T0 i32, T1 i64] func(context.Context, api.Module, T0, T1) TR
func (fn funcIIJ[TR, T0, T1]) Call(ctx context.Context, mod api.Module, stack []uint64) { func (fn funcIIJ[TR, T0, T1]) Call(ctx context.Context, mod api.Module, stack []uint64) {
_ = stack[1] // prevent bounds check on every slice access
stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1]))) stack[0] = uint64(fn(ctx, mod, T0(stack[0]), T1(stack[1])))
} }

View file

@ -35,22 +35,17 @@ func DelHandle(ctx context.Context, id uint32) error {
s := ctx.Value(moduleKey{}).(*moduleState) s := ctx.Value(moduleKey{}).(*moduleState)
a := s.handles[^id] a := s.handles[^id]
s.handles[^id] = nil s.handles[^id] = nil
if l := uint32(len(s.handles)); l == ^id { s.holes++
s.handles = s.handles[:l-1]
} else {
s.holes++
}
if c, ok := a.(io.Closer); ok { if c, ok := a.(io.Closer); ok {
return c.Close() return c.Close()
} }
return nil return nil
} }
func AddHandle(ctx context.Context, a any) uint32 { func AddHandle(ctx context.Context, a any) (id uint32) {
if a == nil { if a == nil {
panic(NilErr) panic(NilErr)
} }
s := ctx.Value(moduleKey{}).(*moduleState) s := ctx.Value(moduleKey{}).(*moduleState)
// Find an empty slot. // Find an empty slot.

View file

@ -3,7 +3,6 @@
import ( import (
"bytes" "bytes"
"math" "math"
"reflect"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -14,9 +13,6 @@
// Quote escapes and quotes a value // Quote escapes and quotes a value
// making it safe to embed in SQL text. // making it safe to embed in SQL text.
// Strings with embedded NUL characters are truncated.
//
// https://sqlite.org/lang_corefunc.html#quote
func Quote(value any) string { func Quote(value any) string {
switch v := value.(type) { switch v := value.(type) {
case nil: case nil:
@ -46,8 +42,8 @@ func Quote(value any) string {
return "'" + v.Format(time.RFC3339Nano) + "'" return "'" + v.Format(time.RFC3339Nano) + "'"
case string: case string:
if i := strings.IndexByte(v, 0); i >= 0 { if strings.IndexByte(v, 0) >= 0 {
v = v[:i] break
} }
buf := make([]byte, 2+len(v)+strings.Count(v, "'")) buf := make([]byte, 2+len(v)+strings.Count(v, "'"))
@ -61,13 +57,13 @@ func Quote(value any) string {
buf[i] = b buf[i] = b
i += 1 i += 1
} }
buf[len(buf)-1] = '\'' buf[i] = '\''
return unsafe.String(&buf[0], len(buf)) return unsafe.String(&buf[0], len(buf))
case []byte: case []byte:
buf := make([]byte, 3+2*len(v)) buf := make([]byte, 3+2*len(v))
buf[1] = '\''
buf[0] = 'x' buf[0] = 'x'
buf[1] = '\''
i := 2 i := 2
for _, b := range v { for _, b := range v {
const hex = "0123456789ABCDEF" const hex = "0123456789ABCDEF"
@ -75,50 +71,26 @@ func Quote(value any) string {
buf[i+1] = hex[b%16] buf[i+1] = hex[b%16]
i += 2 i += 2
} }
buf[len(buf)-1] = '\'' buf[i] = '\''
return unsafe.String(&buf[0], len(buf)) return unsafe.String(&buf[0], len(buf))
case ZeroBlob: case ZeroBlob:
if v > ZeroBlob(1e9-3)/2 {
break
}
buf := bytes.Repeat([]byte("0"), int(3+2*int64(v))) buf := bytes.Repeat([]byte("0"), int(3+2*int64(v)))
buf[1] = '\''
buf[0] = 'x' buf[0] = 'x'
buf[1] = '\''
buf[len(buf)-1] = '\'' buf[len(buf)-1] = '\''
return unsafe.String(&buf[0], len(buf)) return unsafe.String(&buf[0], len(buf))
} }
v := reflect.ValueOf(value)
k := v.Kind()
if k == reflect.Interface || k == reflect.Pointer {
if v.IsNil() {
return "NULL"
}
v = v.Elem()
k = v.Kind()
}
switch {
case v.CanInt():
return strconv.FormatInt(v.Int(), 10)
case v.CanUint():
return strconv.FormatUint(v.Uint(), 10)
case v.CanFloat():
return Quote(v.Float())
case k == reflect.Bool:
return Quote(v.Bool())
case k == reflect.String:
return Quote(v.String())
case (k == reflect.Slice || k == reflect.Array && v.CanAddr()) &&
v.Type().Elem().Kind() == reflect.Uint8:
return Quote(v.Bytes())
}
panic(util.ValueErr) panic(util.ValueErr)
} }
// QuoteIdentifier escapes and quotes an identifier // QuoteIdentifier escapes and quotes an identifier
// making it safe to embed in SQL text. // making it safe to embed in SQL text.
// Strings with embedded NUL characters panic.
func QuoteIdentifier(id string) string { func QuoteIdentifier(id string) string {
if strings.IndexByte(id, 0) >= 0 { if strings.IndexByte(id, 0) >= 0 {
panic(util.ValueErr) panic(util.ValueErr)
@ -135,6 +107,6 @@ func QuoteIdentifier(id string) string {
buf[i] = b buf[i] = b
i += 1 i += 1
} }
buf[len(buf)-1] = '"' buf[i] = '"'
return unsafe.String(&buf[0], len(buf)) return unsafe.String(&buf[0], len(buf))
} }

View file

@ -131,7 +131,7 @@ func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error {
err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_LENGTH) err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_LENGTH)
} }
if len(sql) != 0 { if sql != nil {
if r := sqlt.call("sqlite3_error_offset", uint64(handle)); r != math.MaxUint32 { if r := sqlt.call("sqlite3_error_offset", uint64(handle)); r != math.MaxUint32 {
err.sql = sql[0][r:] err.sql = sql[0][r:]
} }
@ -301,7 +301,7 @@ func (a *arena) string(s string) uint32 {
func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder { func exportCallbacks(env wazero.HostModuleBuilder) wazero.HostModuleBuilder {
util.ExportFuncII(env, "go_progress_handler", progressCallback) util.ExportFuncII(env, "go_progress_handler", progressCallback)
util.ExportFuncIII(env, "go_busy_timeout", timeoutCallback) util.ExportFuncIIII(env, "go_busy_timeout", timeoutCallback)
util.ExportFuncIII(env, "go_busy_handler", busyCallback) util.ExportFuncIII(env, "go_busy_handler", busyCallback)
util.ExportFuncII(env, "go_commit_hook", commitCallback) util.ExportFuncII(env, "go_commit_hook", commitCallback)
util.ExportFuncVI(env, "go_rollback_hook", rollbackCallback) util.ExportFuncVI(env, "go_rollback_hook", rollbackCallback)

View file

@ -30,13 +30,12 @@ func (s *Stmt) Close() error {
} }
r := s.c.call("sqlite3_finalize", uint64(s.handle)) r := s.c.call("sqlite3_finalize", uint64(s.handle))
stmts := s.c.stmts for i := range s.c.stmts {
for i := range stmts { if s == s.c.stmts[i] {
if s == stmts[i] { l := len(s.c.stmts) - 1
l := len(stmts) - 1 s.c.stmts[i] = s.c.stmts[l]
stmts[i] = stmts[l] s.c.stmts[l] = nil
stmts[l] = nil s.c.stmts = s.c.stmts[:l]
s.c.stmts = stmts[:l]
break break
} }
} }
@ -106,7 +105,7 @@ func (s *Stmt) Busy() bool {
// //
// https://sqlite.org/c3ref/step.html // https://sqlite.org/c3ref/step.html
func (s *Stmt) Step() bool { func (s *Stmt) Step() bool {
s.c.checkInterrupt(s.c.handle) s.c.checkInterrupt()
r := s.c.call("sqlite3_step", uint64(s.handle)) r := s.c.call("sqlite3_step", uint64(s.handle))
switch r { switch r {
case _ROW: case _ROW:
@ -377,15 +376,6 @@ func (s *Stmt) BindValue(param int, value Value) error {
return s.c.error(r) return s.c.error(r)
} }
// DataCount resets the number of columns in a result set.
//
// https://sqlite.org/c3ref/data_count.html
func (s *Stmt) DataCount() int {
r := s.c.call("sqlite3_data_count",
uint64(s.handle))
return int(int32(r))
}
// ColumnCount returns the number of columns in a result set. // ColumnCount returns the number of columns in a result set.
// //
// https://sqlite.org/c3ref/column_count.html // https://sqlite.org/c3ref/column_count.html
@ -640,7 +630,7 @@ func (s *Stmt) Columns(dest []any) error {
defer s.c.arena.mark()() defer s.c.arena.mark()()
count := uint64(len(dest)) count := uint64(len(dest))
typePtr := s.c.arena.new(count) typePtr := s.c.arena.new(count)
dataPtr := s.c.arena.new(count * 8) dataPtr := s.c.arena.new(8 * count)
r := s.c.call("sqlite3_columns_go", r := s.c.call("sqlite3_columns_go",
uint64(s.handle), count, uint64(typePtr), uint64(dataPtr)) uint64(s.handle), count, uint64(typePtr), uint64(dataPtr))
@ -649,31 +639,26 @@ func (s *Stmt) Columns(dest []any) error {
} }
types := util.View(s.c.mod, typePtr, count) types := util.View(s.c.mod, typePtr, count)
// Avoid bounds checks on types below.
if len(types) != len(dest) {
panic(util.AssertErr())
}
for i := range dest { for i := range dest {
switch types[i] { switch types[i] {
case byte(INTEGER): case byte(INTEGER):
dest[i] = int64(util.ReadUint64(s.c.mod, dataPtr)) dest[i] = int64(util.ReadUint64(s.c.mod, dataPtr+8*uint32(i)))
continue
case byte(FLOAT): case byte(FLOAT):
dest[i] = util.ReadFloat64(s.c.mod, dataPtr) dest[i] = util.ReadFloat64(s.c.mod, dataPtr+8*uint32(i))
continue
case byte(NULL): case byte(NULL):
dest[i] = nil dest[i] = nil
default: continue
ptr := util.ReadUint32(s.c.mod, dataPtr+0) }
len := util.ReadUint32(s.c.mod, dataPtr+4) ptr := util.ReadUint32(s.c.mod, dataPtr+8*uint32(i)+0)
buf := util.View(s.c.mod, ptr, uint64(len)) len := util.ReadUint32(s.c.mod, dataPtr+8*uint32(i)+4)
if types[i] == byte(TEXT) { buf := util.View(s.c.mod, ptr, uint64(len))
dest[i] = string(buf) if types[i] == byte(TEXT) {
} else { dest[i] = string(buf)
dest[i] = buf } else {
} dest[i] = buf
} }
dataPtr += 8
} }
return nil return nil
} }

View file

@ -138,9 +138,6 @@ func (f TimeFormat) Encode(t time.Time) any {
// //
// https://sqlite.org/lang_datefunc.html // https://sqlite.org/lang_datefunc.html
func (f TimeFormat) Decode(v any) (time.Time, error) { func (f TimeFormat) Decode(v any) (time.Time, error) {
if t, ok := v.(time.Time); ok {
return t, nil
}
switch f { switch f {
// Numeric formats. // Numeric formats.
case TimeFormatJulianDay: case TimeFormatJulianDay:

View file

@ -3,6 +3,7 @@
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"math/rand" "math/rand"
"runtime" "runtime"
"strconv" "strconv"
@ -135,21 +136,23 @@ type Savepoint struct {
// //
// https://sqlite.org/lang_savepoint.html // https://sqlite.org/lang_savepoint.html
func (c *Conn) Savepoint() Savepoint { func (c *Conn) Savepoint() Savepoint {
name := callerName() // Names can be reused; this makes catching bugs more likely.
if name == "" { name := saveptName() + "_" + strconv.Itoa(int(rand.Int31()))
name = "sqlite3.Savepoint"
}
// Names can be reused, but this makes catching bugs more likely.
name = QuoteIdentifier(name + "_" + strconv.Itoa(int(rand.Int31())))
err := c.txnExecInterrupted("SAVEPOINT " + name) err := c.txnExecInterrupted(fmt.Sprintf("SAVEPOINT %q;", name))
if err != nil { if err != nil {
panic(err) panic(err)
} }
return Savepoint{c: c, name: name} return Savepoint{c: c, name: name}
} }
func callerName() (name string) { func saveptName() (name string) {
defer func() {
if name == "" {
name = "sqlite3.Savepoint"
}
}()
var pc [8]uintptr var pc [8]uintptr
n := runtime.Callers(3, pc[:]) n := runtime.Callers(3, pc[:])
if n <= 0 { if n <= 0 {
@ -186,7 +189,7 @@ func (s Savepoint) Release(errp *error) {
if s.c.GetAutocommit() { // There is nothing to commit. if s.c.GetAutocommit() { // There is nothing to commit.
return return
} }
*errp = s.c.Exec("RELEASE " + s.name) *errp = s.c.Exec(fmt.Sprintf("RELEASE %q;", s.name))
if *errp == nil { if *errp == nil {
return return
} }
@ -198,8 +201,10 @@ func (s Savepoint) Release(errp *error) {
return return
} }
// ROLLBACK and RELEASE even if interrupted. // ROLLBACK and RELEASE even if interrupted.
err := s.c.txnExecInterrupted("ROLLBACK TO " + err := s.c.txnExecInterrupted(fmt.Sprintf(`
s.name + "; RELEASE " + s.name) ROLLBACK TO %[1]q;
RELEASE %[1]q;
`, s.name))
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -212,7 +217,7 @@ func (s Savepoint) Release(errp *error) {
// https://sqlite.org/lang_transaction.html // https://sqlite.org/lang_transaction.html
func (s Savepoint) Rollback() error { func (s Savepoint) Rollback() error {
// ROLLBACK even if interrupted. // ROLLBACK even if interrupted.
return s.c.txnExecInterrupted("ROLLBACK TO " + s.name) return s.c.txnExecInterrupted(fmt.Sprintf("ROLLBACK TO %q;", s.name))
} }
func (c *Conn) txnExecInterrupted(sql string) error { func (c *Conn) txnExecInterrupted(sql string) error {

View file

@ -15,7 +15,7 @@ func OpenFile(name string, flag int, perm fs.FileMode) (*os.File, error) {
if name == "" { if name == "" {
return nil, &os.PathError{Op: "open", Path: name, Err: ENOENT} return nil, &os.PathError{Op: "open", Path: name, Err: ENOENT}
} }
r, e := syscallOpen(name, flag|O_CLOEXEC, uint32(perm.Perm())) r, e := syscallOpen(name, flag, uint32(perm.Perm()))
if e != nil { if e != nil {
return nil, &os.PathError{Op: "open", Path: name, Err: e} return nil, &os.PathError{Op: "open", Path: name, Err: e}
} }

View file

@ -19,18 +19,17 @@ func (vfsOS) FullPathname(path string) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
return path, testSymlinks(filepath.Dir(path)) fi, err := os.Lstat(path)
}
func testSymlinks(path string) error {
p, err := filepath.EvalSymlinks(path)
if err != nil { if err != nil {
return err if errors.Is(err, fs.ErrNotExist) {
return path, nil
}
return "", err
} }
if p != path { if fi.Mode()&fs.ModeSymlink != 0 {
return _OK_SYMLINK err = _OK_SYMLINK
} }
return nil return path, err
} }
func (vfsOS) Delete(path string, syncDir bool) error { func (vfsOS) Delete(path string, syncDir bool) error {
@ -75,7 +74,7 @@ func (vfsOS) Open(name string, flags OpenFlag) (File, OpenFlag, error) {
} }
func (vfsOS) OpenFilename(name *Filename, flags OpenFlag) (File, OpenFlag, error) { func (vfsOS) OpenFilename(name *Filename, flags OpenFlag) (File, OpenFlag, error) {
oflags := _O_NOFOLLOW var oflags int
if flags&OPEN_EXCLUSIVE != 0 { if flags&OPEN_EXCLUSIVE != 0 {
oflags |= os.O_EXCL oflags |= os.O_EXCL
} }

View file

@ -43,8 +43,7 @@ func Create(name string, data []byte) {
} }
// Convert data from WAL/2 to rollback journal. // Convert data from WAL/2 to rollback journal.
if len(data) >= 20 && (false || if len(data) >= 20 && (data[18] == 2 && data[19] == 2 ||
data[18] == 2 && data[19] == 2 ||
data[18] == 3 && data[19] == 3) { data[18] == 3 && data[19] == 3) {
data[18] = 1 data[18] = 1
data[19] = 1 data[19] = 1

View file

@ -7,8 +7,6 @@
"os" "os"
) )
const _O_NOFOLLOW = 0
func osAccess(path string, flags AccessFlag) error { func osAccess(path string, flags AccessFlag) error {
fi, err := os.Stat(path) fi, err := os.Stat(path)
if err != nil { if err != nil {
@ -36,12 +34,3 @@ func osAccess(path string, flags AccessFlag) error {
} }
return nil return nil
} }
func osSetMode(file *os.File, modeof string) error {
fi, err := os.Stat(modeof)
if err != nil {
return err
}
file.Chmod(fi.Mode())
return nil
}

View file

@ -0,0 +1,14 @@
//go:build !unix || sqlite3_nosys
package vfs
import "os"
func osSetMode(file *os.File, modeof string) error {
fi, err := os.Stat(modeof)
if err != nil {
return err
}
file.Chmod(fi.Mode())
return nil
}

View file

@ -9,8 +9,6 @@
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
const _O_NOFOLLOW = unix.O_NOFOLLOW
func osAccess(path string, flags AccessFlag) error { func osAccess(path string, flags AccessFlag) error {
var access uint32 // unix.F_OK var access uint32 // unix.F_OK
switch flags { switch flags {

View file

@ -57,12 +57,9 @@ func CreateModule[T VTab](db *Conn, name string, create, connect VTabConstructor
flags |= VTAB_SHADOWTABS flags |= VTAB_SHADOWTABS
} }
var modulePtr uint32
defer db.arena.mark()() defer db.arena.mark()()
namePtr := db.arena.string(name) namePtr := db.arena.string(name)
if connect != nil { modulePtr := util.AddHandle(db.ctx, module[T]{create, connect})
modulePtr = util.AddHandle(db.ctx, module[T]{create, connect})
}
r := db.call("sqlite3_create_module_go", uint64(db.handle), r := db.call("sqlite3_create_module_go", uint64(db.handle),
uint64(namePtr), uint64(flags), uint64(modulePtr)) uint64(namePtr), uint64(flags), uint64(modulePtr))
return db.error(r) return db.error(r)
@ -355,9 +352,8 @@ func (idx *IndexInfo) load() {
idx.OrderBy = make([]IndexOrderBy, util.ReadUint32(mod, ptr+8)) idx.OrderBy = make([]IndexOrderBy, util.ReadUint32(mod, ptr+8))
constraintPtr := util.ReadUint32(mod, ptr+4) constraintPtr := util.ReadUint32(mod, ptr+4)
constraint := idx.Constraint
for i := range idx.Constraint { for i := range idx.Constraint {
constraint[i] = IndexConstraint{ idx.Constraint[i] = IndexConstraint{
Column: int(int32(util.ReadUint32(mod, constraintPtr+0))), Column: int(int32(util.ReadUint32(mod, constraintPtr+0))),
Op: IndexConstraintOp(util.ReadUint8(mod, constraintPtr+4)), Op: IndexConstraintOp(util.ReadUint8(mod, constraintPtr+4)),
Usable: util.ReadUint8(mod, constraintPtr+5) != 0, Usable: util.ReadUint8(mod, constraintPtr+5) != 0,
@ -366,9 +362,8 @@ func (idx *IndexInfo) load() {
} }
orderByPtr := util.ReadUint32(mod, ptr+12) orderByPtr := util.ReadUint32(mod, ptr+12)
orderBy := idx.OrderBy for i := range idx.OrderBy {
for i := range orderBy { idx.OrderBy[i] = IndexOrderBy{
orderBy[i] = IndexOrderBy{
Column: int(int32(util.ReadUint32(mod, orderByPtr+0))), Column: int(int32(util.ReadUint32(mod, orderByPtr+0))),
Desc: util.ReadUint8(mod, orderByPtr+4) != 0, Desc: util.ReadUint8(mod, orderByPtr+4) != 0,
} }

2
vendor/modules.txt vendored
View file

@ -518,7 +518,7 @@ github.com/modern-go/reflect2
# github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 # github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822
## explicit ## explicit
github.com/munnerz/goautoneg github.com/munnerz/goautoneg
# github.com/ncruces/go-sqlite3 v0.19.0 # github.com/ncruces/go-sqlite3 v0.18.4
## explicit; go 1.21 ## explicit; go 1.21
github.com/ncruces/go-sqlite3 github.com/ncruces/go-sqlite3
github.com/ncruces/go-sqlite3/driver github.com/ncruces/go-sqlite3/driver