// Mirrored from https://github.com/superseriousbusiness/gotosocial.git
// (synced 2024-12-05 10:02:47 +00:00; 357 lines, 11 KiB, Go).
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
|
||
|
|
||
|
package proto
|
||
|
|
||
|
import (
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"reflect"
|
||
|
|
||
|
"google.golang.org/protobuf/encoding/protowire"
|
||
|
"google.golang.org/protobuf/proto"
|
||
|
"google.golang.org/protobuf/reflect/protoreflect"
|
||
|
"google.golang.org/protobuf/reflect/protoregistry"
|
||
|
"google.golang.org/protobuf/runtime/protoiface"
|
||
|
"google.golang.org/protobuf/runtime/protoimpl"
|
||
|
)
|
||
|
|
||
|
type (
|
||
|
// ExtensionDesc represents an extension descriptor and
|
||
|
// is used to interact with an extension field in a message.
|
||
|
//
|
||
|
// Variables of this type are generated in code by protoc-gen-go.
|
||
|
ExtensionDesc = protoimpl.ExtensionInfo
|
||
|
|
||
|
// ExtensionRange represents a range of message extensions.
|
||
|
// Used in code generated by protoc-gen-go.
|
||
|
ExtensionRange = protoiface.ExtensionRangeV1
|
||
|
|
||
|
// Deprecated: Do not use; this is an internal type.
|
||
|
Extension = protoimpl.ExtensionFieldV1
|
||
|
|
||
|
// Deprecated: Do not use; this is an internal type.
|
||
|
XXX_InternalExtensions = protoimpl.ExtensionFields
|
||
|
)
|
||
|
|
||
|
var (
	// ErrMissingExtension is returned by GetExtension when the requested
	// extension is not present in the message and has no default value.
	ErrMissingExtension = errors.New("proto: missing extension")

	// errNotExtendable is returned when a message cannot carry extensions:
	// it is nil or invalid, or its descriptor declares no extension ranges.
	errNotExtendable = errors.New("proto: not an extendable proto.Message")
)
|
||
|
|
||
|
// HasExtension reports whether the extension field is present in m
|
||
|
// either as an explicitly populated field or as an unknown field.
|
||
|
func HasExtension(m Message, xt *ExtensionDesc) (has bool) {
|
||
|
mr := MessageReflect(m)
|
||
|
if mr == nil || !mr.IsValid() {
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
// Check whether any populated known field matches the field number.
|
||
|
xtd := xt.TypeDescriptor()
|
||
|
if isValidExtension(mr.Descriptor(), xtd) {
|
||
|
has = mr.Has(xtd)
|
||
|
} else {
|
||
|
mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
|
||
|
has = int32(fd.Number()) == xt.Field
|
||
|
return !has
|
||
|
})
|
||
|
}
|
||
|
|
||
|
// Check whether any unknown field matches the field number.
|
||
|
for b := mr.GetUnknown(); !has && len(b) > 0; {
|
||
|
num, _, n := protowire.ConsumeField(b)
|
||
|
has = int32(num) == xt.Field
|
||
|
b = b[n:]
|
||
|
}
|
||
|
return has
|
||
|
}
|
||
|
|
||
|
// ClearExtension removes the extension field from m
|
||
|
// either as an explicitly populated field or as an unknown field.
|
||
|
func ClearExtension(m Message, xt *ExtensionDesc) {
|
||
|
mr := MessageReflect(m)
|
||
|
if mr == nil || !mr.IsValid() {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
xtd := xt.TypeDescriptor()
|
||
|
if isValidExtension(mr.Descriptor(), xtd) {
|
||
|
mr.Clear(xtd)
|
||
|
} else {
|
||
|
mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
|
||
|
if int32(fd.Number()) == xt.Field {
|
||
|
mr.Clear(fd)
|
||
|
return false
|
||
|
}
|
||
|
return true
|
||
|
})
|
||
|
}
|
||
|
clearUnknown(mr, fieldNum(xt.Field))
|
||
|
}
|
||
|
|
||
|
// ClearAllExtensions clears all extensions from m.
|
||
|
// This includes populated fields and unknown fields in the extension range.
|
||
|
func ClearAllExtensions(m Message) {
|
||
|
mr := MessageReflect(m)
|
||
|
if mr == nil || !mr.IsValid() {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
mr.Range(func(fd protoreflect.FieldDescriptor, _ protoreflect.Value) bool {
|
||
|
if fd.IsExtension() {
|
||
|
mr.Clear(fd)
|
||
|
}
|
||
|
return true
|
||
|
})
|
||
|
clearUnknown(mr, mr.Descriptor().ExtensionRanges())
|
||
|
}
|
||
|
|
||
|
// GetExtension retrieves a proto2 extended field from m.
|
||
|
//
|
||
|
// If the descriptor is type complete (i.e., ExtensionDesc.ExtensionType is non-nil),
|
||
|
// then GetExtension parses the encoded field and returns a Go value of the specified type.
|
||
|
// If the field is not present, then the default value is returned (if one is specified),
|
||
|
// otherwise ErrMissingExtension is reported.
|
||
|
//
|
||
|
// If the descriptor is type incomplete (i.e., ExtensionDesc.ExtensionType is nil),
|
||
|
// then GetExtension returns the raw encoded bytes for the extension field.
|
||
|
func GetExtension(m Message, xt *ExtensionDesc) (interface{}, error) {
|
||
|
mr := MessageReflect(m)
|
||
|
if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
|
||
|
return nil, errNotExtendable
|
||
|
}
|
||
|
|
||
|
// Retrieve the unknown fields for this extension field.
|
||
|
var bo protoreflect.RawFields
|
||
|
for bi := mr.GetUnknown(); len(bi) > 0; {
|
||
|
num, _, n := protowire.ConsumeField(bi)
|
||
|
if int32(num) == xt.Field {
|
||
|
bo = append(bo, bi[:n]...)
|
||
|
}
|
||
|
bi = bi[n:]
|
||
|
}
|
||
|
|
||
|
// For type incomplete descriptors, only retrieve the unknown fields.
|
||
|
if xt.ExtensionType == nil {
|
||
|
return []byte(bo), nil
|
||
|
}
|
||
|
|
||
|
// If the extension field only exists as unknown fields, unmarshal it.
|
||
|
// This is rarely done since proto.Unmarshal eagerly unmarshals extensions.
|
||
|
xtd := xt.TypeDescriptor()
|
||
|
if !isValidExtension(mr.Descriptor(), xtd) {
|
||
|
return nil, fmt.Errorf("proto: bad extended type; %T does not extend %T", xt.ExtendedType, m)
|
||
|
}
|
||
|
if !mr.Has(xtd) && len(bo) > 0 {
|
||
|
m2 := mr.New()
|
||
|
if err := (proto.UnmarshalOptions{
|
||
|
Resolver: extensionResolver{xt},
|
||
|
}.Unmarshal(bo, m2.Interface())); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if m2.Has(xtd) {
|
||
|
mr.Set(xtd, m2.Get(xtd))
|
||
|
clearUnknown(mr, fieldNum(xt.Field))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Check whether the message has the extension field set or a default.
|
||
|
var pv protoreflect.Value
|
||
|
switch {
|
||
|
case mr.Has(xtd):
|
||
|
pv = mr.Get(xtd)
|
||
|
case xtd.HasDefault():
|
||
|
pv = xtd.Default()
|
||
|
default:
|
||
|
return nil, ErrMissingExtension
|
||
|
}
|
||
|
|
||
|
v := xt.InterfaceOf(pv)
|
||
|
rv := reflect.ValueOf(v)
|
||
|
if isScalarKind(rv.Kind()) {
|
||
|
rv2 := reflect.New(rv.Type())
|
||
|
rv2.Elem().Set(rv)
|
||
|
v = rv2.Interface()
|
||
|
}
|
||
|
return v, nil
|
||
|
}
|
||
|
|
||
|
// extensionResolver is a custom extension resolver that stores a single
|
||
|
// extension type that takes precedence over the global registry.
|
||
|
type extensionResolver struct{ xt protoreflect.ExtensionType }
|
||
|
|
||
|
func (r extensionResolver) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
|
||
|
if xtd := r.xt.TypeDescriptor(); xtd.FullName() == field {
|
||
|
return r.xt, nil
|
||
|
}
|
||
|
return protoregistry.GlobalTypes.FindExtensionByName(field)
|
||
|
}
|
||
|
|
||
|
func (r extensionResolver) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
|
||
|
if xtd := r.xt.TypeDescriptor(); xtd.ContainingMessage().FullName() == message && xtd.Number() == field {
|
||
|
return r.xt, nil
|
||
|
}
|
||
|
return protoregistry.GlobalTypes.FindExtensionByNumber(message, field)
|
||
|
}
|
||
|
|
||
|
// GetExtensions returns a list of the extensions values present in m,
|
||
|
// corresponding with the provided list of extension descriptors, xts.
|
||
|
// If an extension is missing in m, the corresponding value is nil.
|
||
|
func GetExtensions(m Message, xts []*ExtensionDesc) ([]interface{}, error) {
|
||
|
mr := MessageReflect(m)
|
||
|
if mr == nil || !mr.IsValid() {
|
||
|
return nil, errNotExtendable
|
||
|
}
|
||
|
|
||
|
vs := make([]interface{}, len(xts))
|
||
|
for i, xt := range xts {
|
||
|
v, err := GetExtension(m, xt)
|
||
|
if err != nil {
|
||
|
if err == ErrMissingExtension {
|
||
|
continue
|
||
|
}
|
||
|
return vs, err
|
||
|
}
|
||
|
vs[i] = v
|
||
|
}
|
||
|
return vs, nil
|
||
|
}
|
||
|
|
||
|
// SetExtension sets an extension field in m to the provided value.
|
||
|
func SetExtension(m Message, xt *ExtensionDesc, v interface{}) error {
|
||
|
mr := MessageReflect(m)
|
||
|
if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
|
||
|
return errNotExtendable
|
||
|
}
|
||
|
|
||
|
rv := reflect.ValueOf(v)
|
||
|
if reflect.TypeOf(v) != reflect.TypeOf(xt.ExtensionType) {
|
||
|
return fmt.Errorf("proto: bad extension value type. got: %T, want: %T", v, xt.ExtensionType)
|
||
|
}
|
||
|
if rv.Kind() == reflect.Ptr {
|
||
|
if rv.IsNil() {
|
||
|
return fmt.Errorf("proto: SetExtension called with nil value of type %T", v)
|
||
|
}
|
||
|
if isScalarKind(rv.Elem().Kind()) {
|
||
|
v = rv.Elem().Interface()
|
||
|
}
|
||
|
}
|
||
|
|
||
|
xtd := xt.TypeDescriptor()
|
||
|
if !isValidExtension(mr.Descriptor(), xtd) {
|
||
|
return fmt.Errorf("proto: bad extended type; %T does not extend %T", xt.ExtendedType, m)
|
||
|
}
|
||
|
mr.Set(xtd, xt.ValueOf(v))
|
||
|
clearUnknown(mr, fieldNum(xt.Field))
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// SetRawExtension inserts b into the unknown fields of m.
|
||
|
//
|
||
|
// Deprecated: Use Message.ProtoReflect.SetUnknown instead.
|
||
|
func SetRawExtension(m Message, fnum int32, b []byte) {
|
||
|
mr := MessageReflect(m)
|
||
|
if mr == nil || !mr.IsValid() {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// Verify that the raw field is valid.
|
||
|
for b0 := b; len(b0) > 0; {
|
||
|
num, _, n := protowire.ConsumeField(b0)
|
||
|
if int32(num) != fnum {
|
||
|
panic(fmt.Sprintf("mismatching field number: got %d, want %d", num, fnum))
|
||
|
}
|
||
|
b0 = b0[n:]
|
||
|
}
|
||
|
|
||
|
ClearExtension(m, &ExtensionDesc{Field: fnum})
|
||
|
mr.SetUnknown(append(mr.GetUnknown(), b...))
|
||
|
}
|
||
|
|
||
|
// ExtensionDescs returns a list of extension descriptors found in m,
|
||
|
// containing descriptors for both populated extension fields in m and
|
||
|
// also unknown fields of m that are in the extension range.
|
||
|
// For the later case, an type incomplete descriptor is provided where only
|
||
|
// the ExtensionDesc.Field field is populated.
|
||
|
// The order of the extension descriptors is undefined.
|
||
|
func ExtensionDescs(m Message) ([]*ExtensionDesc, error) {
|
||
|
mr := MessageReflect(m)
|
||
|
if mr == nil || !mr.IsValid() || mr.Descriptor().ExtensionRanges().Len() == 0 {
|
||
|
return nil, errNotExtendable
|
||
|
}
|
||
|
|
||
|
// Collect a set of known extension descriptors.
|
||
|
extDescs := make(map[protoreflect.FieldNumber]*ExtensionDesc)
|
||
|
mr.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
|
||
|
if fd.IsExtension() {
|
||
|
xt := fd.(protoreflect.ExtensionTypeDescriptor)
|
||
|
if xd, ok := xt.Type().(*ExtensionDesc); ok {
|
||
|
extDescs[fd.Number()] = xd
|
||
|
}
|
||
|
}
|
||
|
return true
|
||
|
})
|
||
|
|
||
|
// Collect a set of unknown extension descriptors.
|
||
|
extRanges := mr.Descriptor().ExtensionRanges()
|
||
|
for b := mr.GetUnknown(); len(b) > 0; {
|
||
|
num, _, n := protowire.ConsumeField(b)
|
||
|
if extRanges.Has(num) && extDescs[num] == nil {
|
||
|
extDescs[num] = nil
|
||
|
}
|
||
|
b = b[n:]
|
||
|
}
|
||
|
|
||
|
// Transpose the set of descriptors into a list.
|
||
|
var xts []*ExtensionDesc
|
||
|
for num, xt := range extDescs {
|
||
|
if xt == nil {
|
||
|
xt = &ExtensionDesc{Field: int32(num)}
|
||
|
}
|
||
|
xts = append(xts, xt)
|
||
|
}
|
||
|
return xts, nil
|
||
|
}
|
||
|
|
||
|
// isValidExtension reports whether xtd is a valid extension descriptor for md.
|
||
|
func isValidExtension(md protoreflect.MessageDescriptor, xtd protoreflect.ExtensionTypeDescriptor) bool {
|
||
|
return xtd.ContainingMessage() == md && md.ExtensionRanges().Has(xtd.Number())
|
||
|
}
|
||
|
|
||
|
// isScalarKind reports whether k is a protobuf scalar kind (except bytes).
// It exists for historical reasons: v1 represents scalar extension values
// as *T, whereas v2 represents them as T.
func isScalarKind(k reflect.Kind) bool {
	switch k {
	case reflect.Bool, reflect.String,
		reflect.Int32, reflect.Int64,
		reflect.Uint32, reflect.Uint64,
		reflect.Float32, reflect.Float64:
		return true
	}
	return false
}
|
||
|
|
||
|
// clearUnknown removes unknown fields from m where remover.Has reports true.
|
||
|
func clearUnknown(m protoreflect.Message, remover interface {
|
||
|
Has(protoreflect.FieldNumber) bool
|
||
|
}) {
|
||
|
var bo protoreflect.RawFields
|
||
|
for bi := m.GetUnknown(); len(bi) > 0; {
|
||
|
num, _, n := protowire.ConsumeField(bi)
|
||
|
if !remover.Has(num) {
|
||
|
bo = append(bo, bi[:n]...)
|
||
|
}
|
||
|
bi = bi[n:]
|
||
|
}
|
||
|
if bi := m.GetUnknown(); len(bi) != len(bo) {
|
||
|
m.SetUnknown(bo)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
type fieldNum protoreflect.FieldNumber
|
||
|
|
||
|
func (n1 fieldNum) Has(n2 protoreflect.FieldNumber) bool {
|
||
|
return protoreflect.FieldNumber(n1) == n2
|
||
|
}
|