package pgproto3 import ( "encoding/binary" "errors" "fmt" "io" ) // Frontend acts as a client for the PostgreSQL wire protocol version 3. type Frontend struct { cr ChunkReader w io.Writer // Backend message flyweights authenticationOk AuthenticationOk authenticationCleartextPassword AuthenticationCleartextPassword authenticationMD5Password AuthenticationMD5Password authenticationSASL AuthenticationSASL authenticationSASLContinue AuthenticationSASLContinue authenticationSASLFinal AuthenticationSASLFinal backendKeyData BackendKeyData bindComplete BindComplete closeComplete CloseComplete commandComplete CommandComplete copyBothResponse CopyBothResponse copyData CopyData copyInResponse CopyInResponse copyOutResponse CopyOutResponse copyDone CopyDone dataRow DataRow emptyQueryResponse EmptyQueryResponse errorResponse ErrorResponse functionCallResponse FunctionCallResponse noData NoData noticeResponse NoticeResponse notificationResponse NotificationResponse parameterDescription ParameterDescription parameterStatus ParameterStatus parseComplete ParseComplete readyForQuery ReadyForQuery rowDescription RowDescription portalSuspended PortalSuspended bodyLen int msgType byte partialMsg bool authType uint32 } // NewFrontend creates a new Frontend. func NewFrontend(cr ChunkReader, w io.Writer) *Frontend { return &Frontend{cr: cr, w: w} } // Send sends a message to the backend. func (f *Frontend) Send(msg FrontendMessage) error { _, err := f.w.Write(msg.Encode(nil)) return err } func translateEOFtoErrUnexpectedEOF(err error) error { if err == io.EOF { return io.ErrUnexpectedEOF } return err } // Receive receives a message from the backend. The returned message is only valid until the next call to Receive. func (f *Frontend) Receive() (BackendMessage, error) { if !f.partialMsg { header, err := f.cr.Next(5) if err != nil { return nil, translateEOFtoErrUnexpectedEOF(err) } f.msgType = header[0] f.bodyLen = int(binary.BigEndian.Uint32(header[1:])) - 4 f.partialMsg = true } msgBody, err := f.cr.Next(f.bodyLen) if err != nil { return nil, translateEOFtoErrUnexpectedEOF(err) } f.partialMsg = false var msg BackendMessage switch f.msgType { case '1': msg = &f.parseComplete case '2': msg = &f.bindComplete case '3': msg = &f.closeComplete case 'A': msg = &f.notificationResponse case 'c': msg = &f.copyDone case 'C': msg = &f.commandComplete case 'd': msg = &f.copyData case 'D': msg = &f.dataRow case 'E': msg = &f.errorResponse case 'G': msg = &f.copyInResponse case 'H': msg = &f.copyOutResponse case 'I': msg = &f.emptyQueryResponse case 'K': msg = &f.backendKeyData case 'n': msg = &f.noData case 'N': msg = &f.noticeResponse case 'R': var err error msg, err = f.findAuthenticationMessageType(msgBody) if err != nil { return nil, err } case 's': msg = &f.portalSuspended case 'S': msg = &f.parameterStatus case 't': msg = &f.parameterDescription case 'T': msg = &f.rowDescription case 'V': msg = &f.functionCallResponse case 'W': msg = &f.copyBothResponse case 'Z': msg = &f.readyForQuery default: return nil, fmt.Errorf("unknown message type: %c", f.msgType) } err = msg.Decode(msgBody) return msg, err } // Authentication message type constants. // See src/include/libpq/pqcomm.h for all // constants. const ( AuthTypeOk = 0 AuthTypeCleartextPassword = 3 AuthTypeMD5Password = 5 AuthTypeSCMCreds = 6 AuthTypeGSS = 7 AuthTypeGSSCont = 8 AuthTypeSSPI = 9 AuthTypeSASL = 10 AuthTypeSASLContinue = 11 AuthTypeSASLFinal = 12 ) func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, error) { if len(src) < 4 { return nil, errors.New("authentication message too short") } f.authType = binary.BigEndian.Uint32(src[:4]) switch f.authType { case AuthTypeOk: return &f.authenticationOk, nil case AuthTypeCleartextPassword: return &f.authenticationCleartextPassword, nil case AuthTypeMD5Password: return &f.authenticationMD5Password, nil case AuthTypeSCMCreds: return nil, errors.New("AuthTypeSCMCreds is unimplemented") case AuthTypeGSS: return nil, errors.New("AuthTypeGSS is unimplemented") case AuthTypeGSSCont: return nil, errors.New("AuthTypeGSSCont is unimplemented") case AuthTypeSSPI: return nil, errors.New("AuthTypeSSPI is unimplemented") case AuthTypeSASL: return &f.authenticationSASL, nil case AuthTypeSASLContinue: return &f.authenticationSASLContinue, nil case AuthTypeSASLFinal: return &f.authenticationSASLFinal, nil default: return nil, fmt.Errorf("unknown authentication type: %d", f.authType) } } // GetAuthType returns the authType used in the current state of the frontend. // See SetAuthType for more information. func (f *Frontend) GetAuthType() uint32 { return f.authType }