mirror of
https://github.com/superseriousbusiness/gotosocial.git
synced 2025-01-23 14:50:21 +00:00
365 lines
6.2 KiB
Go
365 lines
6.2 KiB
Go
|
// Copyright 2024 The Go Authors. All rights reserved.
|
||
|
// Use of this source code is governed by a BSD-style
|
||
|
// license that can be found in the LICENSE file.
|
||
|
|
||
|
// Helper code for parsing a protocol buffer
|
||
|
|
||
|
package protolazy
|
||
|
|
||
|
import (
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"io"
|
||
|
|
||
|
"google.golang.org/protobuf/encoding/protowire"
|
||
|
)
|
||
|
|
||
|
// BufferReader reads protobuf wire-format data from an in-memory byte slice.
//
// Buf holds the encoded bytes and Pos is the index in Buf of the next byte to
// be read. The Decode* and Skip* methods advance Pos as they consume input.
type BufferReader struct {
	Buf []byte // encoded bytes being read
	Pos int    // offset in Buf of the next unread byte
}
|
||
|
|
||
|
// NewBufferReader creates a new BufferRead from a protobuf
|
||
|
func NewBufferReader(buf []byte) BufferReader {
|
||
|
return BufferReader{Buf: buf, Pos: 0}
|
||
|
}
|
||
|
|
||
|
// Package-level error values used when decoding from a BufferReader.
var (
	// errOutOfBounds has the message "protobuf decoding: out of bounds".
	// NOTE(review): not referenced by the code shown here; confirm it is
	// still used elsewhere in the package.
	errOutOfBounds = errors.New("protobuf decoding: out of bounds")

	// errOverflow is returned when a varint is too long, or encodes a value
	// too large, for the integer type being decoded.
	errOverflow = errors.New("proto: integer overflow")
)
|
||
|
|
||
|
func (b *BufferReader) DecodeVarintSlow() (x uint64, err error) {
|
||
|
i := b.Pos
|
||
|
l := len(b.Buf)
|
||
|
|
||
|
for shift := uint(0); shift < 64; shift += 7 {
|
||
|
if i >= l {
|
||
|
err = io.ErrUnexpectedEOF
|
||
|
return
|
||
|
}
|
||
|
v := b.Buf[i]
|
||
|
i++
|
||
|
x |= (uint64(v) & 0x7F) << shift
|
||
|
if v < 0x80 {
|
||
|
b.Pos = i
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// The number is too large to represent in a 64-bit value.
|
||
|
err = errOverflow
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// decodeVarint decodes a varint at the current position
|
||
|
func (b *BufferReader) DecodeVarint() (x uint64, err error) {
|
||
|
i := b.Pos
|
||
|
buf := b.Buf
|
||
|
|
||
|
if i >= len(buf) {
|
||
|
return 0, io.ErrUnexpectedEOF
|
||
|
} else if buf[i] < 0x80 {
|
||
|
b.Pos++
|
||
|
return uint64(buf[i]), nil
|
||
|
} else if len(buf)-i < 10 {
|
||
|
return b.DecodeVarintSlow()
|
||
|
}
|
||
|
|
||
|
var v uint64
|
||
|
// we already checked the first byte
|
||
|
x = uint64(buf[i]) & 127
|
||
|
i++
|
||
|
|
||
|
v = uint64(buf[i])
|
||
|
i++
|
||
|
x |= (v & 127) << 7
|
||
|
if v < 128 {
|
||
|
goto done
|
||
|
}
|
||
|
|
||
|
v = uint64(buf[i])
|
||
|
i++
|
||
|
x |= (v & 127) << 14
|
||
|
if v < 128 {
|
||
|
goto done
|
||
|
}
|
||
|
|
||
|
v = uint64(buf[i])
|
||
|
i++
|
||
|
x |= (v & 127) << 21
|
||
|
if v < 128 {
|
||
|
goto done
|
||
|
}
|
||
|
|
||
|
v = uint64(buf[i])
|
||
|
i++
|
||
|
x |= (v & 127) << 28
|
||
|
if v < 128 {
|
||
|
goto done
|
||
|
}
|
||
|
|
||
|
v = uint64(buf[i])
|
||
|
i++
|
||
|
x |= (v & 127) << 35
|
||
|
if v < 128 {
|
||
|
goto done
|
||
|
}
|
||
|
|
||
|
v = uint64(buf[i])
|
||
|
i++
|
||
|
x |= (v & 127) << 42
|
||
|
if v < 128 {
|
||
|
goto done
|
||
|
}
|
||
|
|
||
|
v = uint64(buf[i])
|
||
|
i++
|
||
|
x |= (v & 127) << 49
|
||
|
if v < 128 {
|
||
|
goto done
|
||
|
}
|
||
|
|
||
|
v = uint64(buf[i])
|
||
|
i++
|
||
|
x |= (v & 127) << 56
|
||
|
if v < 128 {
|
||
|
goto done
|
||
|
}
|
||
|
|
||
|
v = uint64(buf[i])
|
||
|
i++
|
||
|
x |= (v & 127) << 63
|
||
|
if v < 128 {
|
||
|
goto done
|
||
|
}
|
||
|
|
||
|
return 0, errOverflow
|
||
|
|
||
|
done:
|
||
|
b.Pos = i
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// decodeVarint32 decodes a varint32 at the current position
|
||
|
func (b *BufferReader) DecodeVarint32() (x uint32, err error) {
|
||
|
i := b.Pos
|
||
|
buf := b.Buf
|
||
|
|
||
|
if i >= len(buf) {
|
||
|
return 0, io.ErrUnexpectedEOF
|
||
|
} else if buf[i] < 0x80 {
|
||
|
b.Pos++
|
||
|
return uint32(buf[i]), nil
|
||
|
} else if len(buf)-i < 5 {
|
||
|
v, err := b.DecodeVarintSlow()
|
||
|
return uint32(v), err
|
||
|
}
|
||
|
|
||
|
var v uint32
|
||
|
// we already checked the first byte
|
||
|
x = uint32(buf[i]) & 127
|
||
|
i++
|
||
|
|
||
|
v = uint32(buf[i])
|
||
|
i++
|
||
|
x |= (v & 127) << 7
|
||
|
if v < 128 {
|
||
|
goto done
|
||
|
}
|
||
|
|
||
|
v = uint32(buf[i])
|
||
|
i++
|
||
|
x |= (v & 127) << 14
|
||
|
if v < 128 {
|
||
|
goto done
|
||
|
}
|
||
|
|
||
|
v = uint32(buf[i])
|
||
|
i++
|
||
|
x |= (v & 127) << 21
|
||
|
if v < 128 {
|
||
|
goto done
|
||
|
}
|
||
|
|
||
|
v = uint32(buf[i])
|
||
|
i++
|
||
|
x |= (v & 127) << 28
|
||
|
if v < 128 {
|
||
|
goto done
|
||
|
}
|
||
|
|
||
|
return 0, errOverflow
|
||
|
|
||
|
done:
|
||
|
b.Pos = i
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// skipValue skips a value in the protobuf, based on the specified tag
|
||
|
func (b *BufferReader) SkipValue(tag uint32) (err error) {
|
||
|
wireType := tag & 0x7
|
||
|
switch protowire.Type(wireType) {
|
||
|
case protowire.VarintType:
|
||
|
err = b.SkipVarint()
|
||
|
case protowire.Fixed64Type:
|
||
|
err = b.SkipFixed64()
|
||
|
case protowire.BytesType:
|
||
|
var n uint32
|
||
|
n, err = b.DecodeVarint32()
|
||
|
if err == nil {
|
||
|
err = b.Skip(int(n))
|
||
|
}
|
||
|
case protowire.StartGroupType:
|
||
|
err = b.SkipGroup(tag)
|
||
|
case protowire.Fixed32Type:
|
||
|
err = b.SkipFixed32()
|
||
|
default:
|
||
|
err = fmt.Errorf("Unexpected wire type (%d)", wireType)
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// skipGroup skips a group with the specified tag. It executes efficiently using a tag stack
|
||
|
func (b *BufferReader) SkipGroup(tag uint32) (err error) {
|
||
|
tagStack := make([]uint32, 0, 16)
|
||
|
tagStack = append(tagStack, tag)
|
||
|
var n uint32
|
||
|
for len(tagStack) > 0 {
|
||
|
tag, err = b.DecodeVarint32()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
switch protowire.Type(tag & 0x7) {
|
||
|
case protowire.VarintType:
|
||
|
err = b.SkipVarint()
|
||
|
case protowire.Fixed64Type:
|
||
|
err = b.Skip(8)
|
||
|
case protowire.BytesType:
|
||
|
n, err = b.DecodeVarint32()
|
||
|
if err == nil {
|
||
|
err = b.Skip(int(n))
|
||
|
}
|
||
|
case protowire.StartGroupType:
|
||
|
tagStack = append(tagStack, tag)
|
||
|
case protowire.Fixed32Type:
|
||
|
err = b.SkipFixed32()
|
||
|
case protowire.EndGroupType:
|
||
|
if protoFieldNumber(tagStack[len(tagStack)-1]) == protoFieldNumber(tag) {
|
||
|
tagStack = tagStack[:len(tagStack)-1]
|
||
|
} else {
|
||
|
err = fmt.Errorf("end group tag %d does not match begin group tag %d at pos %d",
|
||
|
protoFieldNumber(tag), protoFieldNumber(tagStack[len(tagStack)-1]), b.Pos)
|
||
|
}
|
||
|
}
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// skipVarint effiently skips a varint
|
||
|
func (b *BufferReader) SkipVarint() (err error) {
|
||
|
i := b.Pos
|
||
|
|
||
|
if len(b.Buf)-i < 10 {
|
||
|
// Use DecodeVarintSlow() to check for buffer overflow, but ignore result
|
||
|
if _, err := b.DecodeVarintSlow(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
if b.Buf[i] < 0x80 {
|
||
|
goto out
|
||
|
}
|
||
|
i++
|
||
|
|
||
|
if b.Buf[i] < 0x80 {
|
||
|
goto out
|
||
|
}
|
||
|
i++
|
||
|
|
||
|
if b.Buf[i] < 0x80 {
|
||
|
goto out
|
||
|
}
|
||
|
i++
|
||
|
|
||
|
if b.Buf[i] < 0x80 {
|
||
|
goto out
|
||
|
}
|
||
|
i++
|
||
|
|
||
|
if b.Buf[i] < 0x80 {
|
||
|
goto out
|
||
|
}
|
||
|
i++
|
||
|
|
||
|
if b.Buf[i] < 0x80 {
|
||
|
goto out
|
||
|
}
|
||
|
i++
|
||
|
|
||
|
if b.Buf[i] < 0x80 {
|
||
|
goto out
|
||
|
}
|
||
|
i++
|
||
|
|
||
|
if b.Buf[i] < 0x80 {
|
||
|
goto out
|
||
|
}
|
||
|
i++
|
||
|
|
||
|
if b.Buf[i] < 0x80 {
|
||
|
goto out
|
||
|
}
|
||
|
i++
|
||
|
|
||
|
if b.Buf[i] < 0x80 {
|
||
|
goto out
|
||
|
}
|
||
|
return errOverflow
|
||
|
|
||
|
out:
|
||
|
b.Pos = i + 1
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// skip skips the specified number of bytes
|
||
|
func (b *BufferReader) Skip(n int) (err error) {
|
||
|
if len(b.Buf) < b.Pos+n {
|
||
|
return io.ErrUnexpectedEOF
|
||
|
}
|
||
|
b.Pos += n
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// skipFixed64 skips a fixed64
|
||
|
func (b *BufferReader) SkipFixed64() (err error) {
|
||
|
return b.Skip(8)
|
||
|
}
|
||
|
|
||
|
// skipFixed32 skips a fixed32
|
||
|
func (b *BufferReader) SkipFixed32() (err error) {
|
||
|
return b.Skip(4)
|
||
|
}
|
||
|
|
||
|
// skipBytes skips a set of bytes
|
||
|
func (b *BufferReader) SkipBytes() (err error) {
|
||
|
n, err := b.DecodeVarint32()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return b.Skip(int(n))
|
||
|
}
|
||
|
|
||
|
// Done returns whether we are at the end of the protobuf
|
||
|
func (b *BufferReader) Done() bool {
|
||
|
return b.Pos == len(b.Buf)
|
||
|
}
|
||
|
|
||
|
// Remaining returns how many bytes remain
|
||
|
func (b *BufferReader) Remaining() int {
|
||
|
return len(b.Buf) - b.Pos
|
||
|
}
|