2024-11-08 13:51:23 +00:00
|
|
|
package pgdialect
|
|
|
|
|
|
|
|
import (
|
|
|
|
"database/sql"
|
|
|
|
"fmt"
|
|
|
|
"io"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
"github.com/uptrace/bun/internal"
|
|
|
|
"github.com/uptrace/bun/schema"
|
|
|
|
)
|
|
|
|
|
|
|
|
// MultiRange is a slice of Range values.
// NOTE(review): presumably intended for PostgreSQL multirange columns; this
// file defines no Scan/AppendQuery for it — confirm where it is handled.
type MultiRange[T any] []Range[T]

// Range is a generic range with two endpoints. Each side carries the bracket
// character that marks it as inclusive or exclusive, as it appears in the
// range's text form.
type Range[T any] struct {
	Lower      T
	Upper      T
	LowerBound RangeBound
	UpperBound RangeBound
}

// RangeBound is the bracket character written at one end of a range.
type RangeBound byte

// Bracket characters for the left and right ends of a range. Square brackets
// include the endpoint; parentheses exclude it.
const (
	RangeBoundInclusiveLeft  RangeBound = '['
	RangeBoundExclusiveLeft  RangeBound = '('
	RangeBoundInclusiveRight RangeBound = ']'
	RangeBoundExclusiveRight RangeBound = ')'
)

// NewRange returns the half-open range [lower,upper): the lower endpoint is
// included and the upper endpoint is excluded.
func NewRange[T any](lower, upper T) Range[T] {
	var rng Range[T]
	rng.Lower, rng.Upper = lower, upper
	rng.LowerBound, rng.UpperBound = RangeBoundInclusiveLeft, RangeBoundExclusiveRight
	return rng
}
|
|
|
|
|
|
|
|
var _ sql.Scanner = (*Range[any])(nil)
|
|
|
|
|
|
|
|
func (r *Range[T]) Scan(anySrc any) (err error) {
|
2025-01-14 14:23:28 +00:00
|
|
|
src, ok := anySrc.([]byte)
|
|
|
|
if !ok {
|
|
|
|
return fmt.Errorf("pgdialect: Range can't scan %T", anySrc)
|
|
|
|
}
|
2024-11-08 13:51:23 +00:00
|
|
|
|
|
|
|
if len(src) == 0 {
|
|
|
|
return io.ErrUnexpectedEOF
|
|
|
|
}
|
|
|
|
r.LowerBound = RangeBound(src[0])
|
|
|
|
src = src[1:]
|
|
|
|
|
|
|
|
src, err = scanElem(&r.Lower, src)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
if len(src) == 0 {
|
|
|
|
return io.ErrUnexpectedEOF
|
|
|
|
}
|
|
|
|
if ch := src[0]; ch != ',' {
|
|
|
|
return fmt.Errorf("got %q, wanted %q", ch, ',')
|
|
|
|
}
|
|
|
|
src = src[1:]
|
|
|
|
|
|
|
|
src, err = scanElem(&r.Upper, src)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
if len(src) == 0 {
|
|
|
|
return io.ErrUnexpectedEOF
|
|
|
|
}
|
|
|
|
r.UpperBound = RangeBound(src[0])
|
|
|
|
src = src[1:]
|
|
|
|
|
|
|
|
if len(src) > 0 {
|
|
|
|
return fmt.Errorf("unread data: %q", src)
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
var _ schema.QueryAppender = (*Range[any])(nil)
|
|
|
|
|
|
|
|
func (r *Range[T]) AppendQuery(fmt schema.Formatter, buf []byte) ([]byte, error) {
|
|
|
|
buf = append(buf, byte(r.LowerBound))
|
|
|
|
buf = appendElem(buf, r.Lower)
|
|
|
|
buf = append(buf, ',')
|
|
|
|
buf = appendElem(buf, r.Upper)
|
|
|
|
buf = append(buf, byte(r.UpperBound))
|
|
|
|
return buf, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func scanElem(ptr any, src []byte) ([]byte, error) {
|
|
|
|
switch ptr := ptr.(type) {
|
|
|
|
case *time.Time:
|
|
|
|
src, str, err := readStringLiteral(src)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
tm, err := internal.ParseTime(internal.String(str))
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
*ptr = tm
|
|
|
|
|
|
|
|
return src, nil
|
2025-01-14 14:23:28 +00:00
|
|
|
|
|
|
|
case sql.Scanner:
|
|
|
|
src, str, err := readStringLiteral(src)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
if err := ptr.Scan(str); err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
return src, nil
|
|
|
|
|
2024-11-08 13:51:23 +00:00
|
|
|
default:
|
|
|
|
panic(fmt.Errorf("unsupported range type: %T", ptr))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func readStringLiteral(src []byte) ([]byte, []byte, error) {
|
|
|
|
p := newParser(src)
|
|
|
|
|
|
|
|
if err := p.Skip('"'); err != nil {
|
|
|
|
return nil, nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
str, err := p.ReadSubstring('"')
|
|
|
|
if err != nil {
|
|
|
|
return nil, nil, err
|
|
|
|
}
|
|
|
|
|
|
|
|
src = p.Remaining()
|
|
|
|
return src, str, nil
|
|
|
|
}
|