package pq import ( "bufio" "context" "crypto/md5" "crypto/sha256" "database/sql" "database/sql/driver" "encoding/binary" "errors" "fmt" "io" "net" "os" "os/user" "path" "path/filepath" "strconv" "strings" "time" "unicode" "github.com/lib/pq/oid" "github.com/lib/pq/scram" ) // Common error types var ( ErrNotSupported = errors.New("pq: Unsupported command") ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction") ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server") ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less") ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly") errUnexpectedReady = errors.New("unexpected ReadyForQuery") errNoRowsAffected = errors.New("no RowsAffected available after the empty statement") errNoLastInsertID = errors.New("no LastInsertId available after the empty statement") ) // Driver is the Postgres database driver. type Driver struct{} // Open opens a new connection to the database. name is a connection string. // Most users should only use it through database/sql package from the standard // library. func (d *Driver) Open(name string) (driver.Conn, error) { return Open(name) } func init() { sql.Register("postgres", &Driver{}) } type parameterStatus struct { // server version in the same format as server_version_num, or 0 if // unavailable serverVersion int // the current location based on the TimeZone value of the session, if // available currentLocation *time.Location } type transactionStatus byte const ( txnStatusIdle transactionStatus = 'I' txnStatusIdleInTransaction transactionStatus = 'T' txnStatusInFailedTransaction transactionStatus = 'E' ) func (s transactionStatus) String() string { switch s { case txnStatusIdle: return "idle" case txnStatusIdleInTransaction: return "idle in transaction" case txnStatusInFailedTransaction: return "in a failed transaction" default: errorf("unknown transactionStatus %d", s) } panic("not reached") } // Dialer is the dialer interface. It can be used to obtain more control over // how pq creates network connections. type Dialer interface { Dial(network, address string) (net.Conn, error) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) } // DialerContext is the context-aware dialer interface. type DialerContext interface { DialContext(ctx context.Context, network, address string) (net.Conn, error) } type defaultDialer struct { d net.Dialer } func (d defaultDialer) Dial(network, address string) (net.Conn, error) { return d.d.Dial(network, address) } func (d defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() return d.DialContext(ctx, network, address) } func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { return d.d.DialContext(ctx, network, address) } type conn struct { c net.Conn buf *bufio.Reader namei int scratch [512]byte txnStatus transactionStatus txnFinish func() // Save connection arguments to use during CancelRequest. dialer Dialer opts values // Cancellation key data for use with CancelRequest messages. processID int secretKey int parameterStatus parameterStatus saveMessageType byte saveMessageBuffer []byte // If true, this connection is bad and all public-facing functions should // return ErrBadConn. bad bool // If set, this connection should never use the binary format when // receiving query results from prepared statements. Only provided for // debugging. disablePreparedBinaryResult bool // Whether to always send []byte parameters over as binary. Enables single // round-trip mode for non-prepared Query calls. binaryParameters bool // If true this connection is in the middle of a COPY inCopy bool // If not nil, notices will be synchronously sent here noticeHandler func(*Error) // If not nil, notifications will be synchronously sent here notificationHandler func(*Notification) // GSSAPI context gss GSS } // Handle driver-side settings in parsed connection string. func (cn *conn) handleDriverSettings(o values) (err error) { boolSetting := func(key string, val *bool) error { if value, ok := o[key]; ok { if value == "yes" { *val = true } else if value == "no" { *val = false } else { return fmt.Errorf("unrecognized value %q for %s", value, key) } } return nil } err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult) if err != nil { return err } return boolSetting("binary_parameters", &cn.binaryParameters) } func (cn *conn) handlePgpass(o values) { // if a password was supplied, do not process .pgpass if _, ok := o["password"]; ok { return } filename := os.Getenv("PGPASSFILE") if filename == "" { // XXX this code doesn't work on Windows where the default filename is // XXX %APPDATA%\postgresql\pgpass.conf // Prefer $HOME over user.Current due to glibc bug: golang.org/issue/13470 userHome := os.Getenv("HOME") if userHome == "" { user, err := user.Current() if err != nil { return } userHome = user.HomeDir } filename = filepath.Join(userHome, ".pgpass") } fileinfo, err := os.Stat(filename) if err != nil { return } mode := fileinfo.Mode() if mode&(0x77) != 0 { // XXX should warn about incorrect .pgpass permissions as psql does return } file, err := os.Open(filename) if err != nil { return } defer file.Close() scanner := bufio.NewScanner(io.Reader(file)) hostname := o["host"] ntw, _ := network(o) port := o["port"] db := o["dbname"] username := o["user"] // From: https://github.com/tg/pgpass/blob/master/reader.go getFields := func(s string) []string { fs := make([]string, 0, 5) f := make([]rune, 0, len(s)) var esc bool for _, c := range s { switch { case esc: f = append(f, c) esc = false case c == '\\': esc = true case c == ':': fs = append(fs, string(f)) f = f[:0] default: f = append(f, c) } } return append(fs, string(f)) } for scanner.Scan() { line := scanner.Text() if len(line) == 0 || line[0] == '#' { continue } split := getFields(line) if len(split) != 5 { continue } if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) { o["password"] = split[4] return } } } func (cn *conn) writeBuf(b byte) *writeBuf { cn.scratch[0] = b return &writeBuf{ buf: cn.scratch[:5], pos: 1, } } // Open opens a new connection to the database. dsn is a connection string. // Most users should only use it through database/sql package from the standard // library. func Open(dsn string) (_ driver.Conn, err error) { return DialOpen(defaultDialer{}, dsn) } // DialOpen opens a new connection to the database using a dialer. func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) { c, err := NewConnector(dsn) if err != nil { return nil, err } c.dialer = d return c.open(context.Background()) } func (c *Connector) open(ctx context.Context) (cn *conn, err error) { // Handle any panics during connection initialization. Note that we // specifically do *not* want to use errRecover(), as that would turn any // connection errors into ErrBadConns, hiding the real error message from // the user. defer errRecoverNoErrBadConn(&err) o := c.opts cn = &conn{ opts: o, dialer: c.dialer, } err = cn.handleDriverSettings(o) if err != nil { return nil, err } cn.handlePgpass(o) cn.c, err = dial(ctx, c.dialer, o) if err != nil { return nil, err } err = cn.ssl(o) if err != nil { if cn.c != nil { cn.c.Close() } return nil, err } // cn.startup panics on error. Make sure we don't leak cn.c. panicking := true defer func() { if panicking { cn.c.Close() } }() cn.buf = bufio.NewReader(cn.c) cn.startup(o) // reset the deadline, in case one was set (see dial) if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { err = cn.c.SetDeadline(time.Time{}) } panicking = false return cn, err } func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) { network, address := network(o) // Zero or not specified means wait indefinitely. if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { seconds, err := strconv.ParseInt(timeout, 10, 0) if err != nil { return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err) } duration := time.Duration(seconds) * time.Second // connect_timeout should apply to the entire connection establishment // procedure, so we both use a timeout for the TCP connection // establishment and set a deadline for doing the initial handshake. // The deadline is then reset after startup() is done. deadline := time.Now().Add(duration) var conn net.Conn if dctx, ok := d.(DialerContext); ok { ctx, cancel := context.WithTimeout(ctx, duration) defer cancel() conn, err = dctx.DialContext(ctx, network, address) } else { conn, err = d.DialTimeout(network, address, duration) } if err != nil { return nil, err } err = conn.SetDeadline(deadline) return conn, err } if dctx, ok := d.(DialerContext); ok { return dctx.DialContext(ctx, network, address) } return d.Dial(network, address) } func network(o values) (string, string) { host := o["host"] if strings.HasPrefix(host, "/") { sockPath := path.Join(host, ".s.PGSQL."+o["port"]) return "unix", sockPath } return "tcp", net.JoinHostPort(host, o["port"]) } type values map[string]string // scanner implements a tokenizer for libpq-style option strings. type scanner struct { s []rune i int } // newScanner returns a new scanner initialized with the option string s. func newScanner(s string) *scanner { return &scanner{[]rune(s), 0} } // Next returns the next rune. // It returns 0, false if the end of the text has been reached. func (s *scanner) Next() (rune, bool) { if s.i >= len(s.s) { return 0, false } r := s.s[s.i] s.i++ return r, true } // SkipSpaces returns the next non-whitespace rune. // It returns 0, false if the end of the text has been reached. func (s *scanner) SkipSpaces() (rune, bool) { r, ok := s.Next() for unicode.IsSpace(r) && ok { r, ok = s.Next() } return r, ok } // parseOpts parses the options from name and adds them to the values. // // The parsing code is based on conninfo_parse from libpq's fe-connect.c func parseOpts(name string, o values) error { s := newScanner(name) for { var ( keyRunes, valRunes []rune r rune ok bool ) if r, ok = s.SkipSpaces(); !ok { break } // Scan the key for !unicode.IsSpace(r) && r != '=' { keyRunes = append(keyRunes, r) if r, ok = s.Next(); !ok { break } } // Skip any whitespace if we're not at the = yet if r != '=' { r, ok = s.SkipSpaces() } // The current character should be = if r != '=' || !ok { return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes)) } // Skip any whitespace after the = if r, ok = s.SkipSpaces(); !ok { // If we reach the end here, the last value is just an empty string as per libpq. o[string(keyRunes)] = "" break } if r != '\'' { for !unicode.IsSpace(r) { if r == '\\' { if r, ok = s.Next(); !ok { return fmt.Errorf(`missing character after backslash`) } } valRunes = append(valRunes, r) if r, ok = s.Next(); !ok { break } } } else { quote: for { if r, ok = s.Next(); !ok { return fmt.Errorf(`unterminated quoted string literal in connection string`) } switch r { case '\'': break quote case '\\': r, _ = s.Next() fallthrough default: valRunes = append(valRunes, r) } } } o[string(keyRunes)] = string(valRunes) } return nil } func (cn *conn) isInTransaction() bool { return cn.txnStatus == txnStatusIdleInTransaction || cn.txnStatus == txnStatusInFailedTransaction } func (cn *conn) checkIsInTransaction(intxn bool) { if cn.isInTransaction() != intxn { cn.bad = true errorf("unexpected transaction status %v", cn.txnStatus) } } func (cn *conn) Begin() (_ driver.Tx, err error) { return cn.begin("") } func (cn *conn) begin(mode string) (_ driver.Tx, err error) { if cn.bad { return nil, driver.ErrBadConn } defer cn.errRecover(&err) cn.checkIsInTransaction(false) _, commandTag, err := cn.simpleExec("BEGIN" + mode) if err != nil { return nil, err } if commandTag != "BEGIN" { cn.bad = true return nil, fmt.Errorf("unexpected command tag %s", commandTag) } if cn.txnStatus != txnStatusIdleInTransaction { cn.bad = true return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus) } return cn, nil } func (cn *conn) closeTxn() { if finish := cn.txnFinish; finish != nil { finish() } } func (cn *conn) Commit() (err error) { defer cn.closeTxn() if cn.bad { return driver.ErrBadConn } defer cn.errRecover(&err) cn.checkIsInTransaction(true) // We don't want the client to think that everything is okay if it tries // to commit a failed transaction. However, no matter what we return, // database/sql will release this connection back into the free connection // pool so we have to abort the current transaction here. Note that you // would get the same behaviour if you issued a COMMIT in a failed // transaction, so it's also the least surprising thing to do here. if cn.txnStatus == txnStatusInFailedTransaction { if err := cn.rollback(); err != nil { return err } return ErrInFailedTransaction } _, commandTag, err := cn.simpleExec("COMMIT") if err != nil { if cn.isInTransaction() { cn.bad = true } return err } if commandTag != "COMMIT" { cn.bad = true return fmt.Errorf("unexpected command tag %s", commandTag) } cn.checkIsInTransaction(false) return nil } func (cn *conn) Rollback() (err error) { defer cn.closeTxn() if cn.bad { return driver.ErrBadConn } defer cn.errRecover(&err) return cn.rollback() } func (cn *conn) rollback() (err error) { cn.checkIsInTransaction(true) _, commandTag, err := cn.simpleExec("ROLLBACK") if err != nil { if cn.isInTransaction() { cn.bad = true } return err } if commandTag != "ROLLBACK" { return fmt.Errorf("unexpected command tag %s", commandTag) } cn.checkIsInTransaction(false) return nil } func (cn *conn) gname() string { cn.namei++ return strconv.FormatInt(int64(cn.namei), 10) } func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) { b := cn.writeBuf('Q') b.string(q) cn.send(b) for { t, r := cn.recv1() switch t { case 'C': res, commandTag = cn.parseComplete(r.string()) case 'Z': cn.processReadyForQuery(r) if res == nil && err == nil { err = errUnexpectedReady } // done return case 'E': err = parseError(r) case 'I': res = emptyRows case 'T', 'D': // ignore any results default: cn.bad = true errorf("unknown response for simple query: %q", t) } } } func (cn *conn) simpleQuery(q string) (res *rows, err error) { defer cn.errRecover(&err) b := cn.writeBuf('Q') b.string(q) cn.send(b) for { t, r := cn.recv1() switch t { case 'C', 'I': // We allow queries which don't return any results through Query as // well as Exec. We still have to give database/sql a rows object // the user can close, though, to avoid connections from being // leaked. A "rows" with done=true works fine for that purpose. if err != nil { cn.bad = true errorf("unexpected message %q in simple query execution", t) } if res == nil { res = &rows{ cn: cn, } } // Set the result and tag to the last command complete if there wasn't a // query already run. Although queries usually return from here and cede // control to Next, a query with zero results does not. if t == 'C' && res.colNames == nil { res.result, res.tag = cn.parseComplete(r.string()) } res.done = true case 'Z': cn.processReadyForQuery(r) // done return case 'E': res = nil err = parseError(r) case 'D': if res == nil { cn.bad = true errorf("unexpected DataRow in simple query execution") } // the query didn't fail; kick off to Next cn.saveMessage(t, r) return case 'T': // res might be non-nil here if we received a previous // CommandComplete, but that's fine; just overwrite it res = &rows{cn: cn} res.rowsHeader = parsePortalRowDescribe(r) // To work around a bug in QueryRow in Go 1.2 and earlier, wait // until the first DataRow has been received. default: cn.bad = true errorf("unknown response for simple query: %q", t) } } } type noRows struct{} var emptyRows noRows var _ driver.Result = noRows{} func (noRows) LastInsertId() (int64, error) { return 0, errNoLastInsertID } func (noRows) RowsAffected() (int64, error) { return 0, errNoRowsAffected } // Decides which column formats to use for a prepared statement. The input is // an array of type oids, one element per result column. func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte) { if len(colTyps) == 0 { return nil, colFmtDataAllText } colFmts = make([]format, len(colTyps)) if forceText { return colFmts, colFmtDataAllText } allBinary := true allText := true for i, t := range colTyps { switch t.OID { // This is the list of types to use binary mode for when receiving them // through a prepared statement. If a type appears in this list, it // must also be implemented in binaryDecode in encode.go. case oid.T_bytea: fallthrough case oid.T_int8: fallthrough case oid.T_int4: fallthrough case oid.T_int2: fallthrough case oid.T_uuid: colFmts[i] = formatBinary allText = false default: allBinary = false } } if allBinary { return colFmts, colFmtDataAllBinary } else if allText { return colFmts, colFmtDataAllText } else { colFmtData = make([]byte, 2+len(colFmts)*2) binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts))) for i, v := range colFmts { binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v)) } return colFmts, colFmtData } } func (cn *conn) prepareTo(q, stmtName string) *stmt { st := &stmt{cn: cn, name: stmtName} b := cn.writeBuf('P') b.string(st.name) b.string(q) b.int16(0) b.next('D') b.byte('S') b.string(st.name) b.next('S') cn.send(b) cn.readParseResponse() st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse() st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult) cn.readReadyForQuery() return st } func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) { if cn.bad { return nil, driver.ErrBadConn } defer cn.errRecover(&err) if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") { s, err := cn.prepareCopyIn(q) if err == nil { cn.inCopy = true } return s, err } return cn.prepareTo(q, cn.gname()), nil } func (cn *conn) Close() (err error) { // Skip cn.bad return here because we always want to close a connection. defer cn.errRecover(&err) // Ensure that cn.c.Close is always run. Since error handling is done with // panics and cn.errRecover, the Close must be in a defer. defer func() { cerr := cn.c.Close() if err == nil { err = cerr } }() // Don't go through send(); ListenerConn relies on us not scribbling on the // scratch buffer of this connection. return cn.sendSimpleMessage('X') } // Implement the "Queryer" interface func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) { return cn.query(query, args) } func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { if cn.bad { return nil, driver.ErrBadConn } if cn.inCopy { return nil, errCopyInProgress } defer cn.errRecover(&err) // Check to see if we can use the "simpleQuery" interface, which is // *much* faster than going through prepare/exec if len(args) == 0 { return cn.simpleQuery(query) } if cn.binaryParameters { cn.sendBinaryModeQuery(query, args) cn.readParseResponse() cn.readBindResponse() rows := &rows{cn: cn} rows.rowsHeader = cn.readPortalDescribeResponse() cn.postExecuteWorkaround() return rows, nil } st := cn.prepareTo(query, "") st.exec(args) return &rows{ cn: cn, rowsHeader: st.rowsHeader, }, nil } // Implement the optional "Execer" interface for one-shot queries func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) { if cn.bad { return nil, driver.ErrBadConn } defer cn.errRecover(&err) // Check to see if we can use the "simpleExec" interface, which is // *much* faster than going through prepare/exec if len(args) == 0 { // ignore commandTag, our caller doesn't care r, _, err := cn.simpleExec(query) return r, err } if cn.binaryParameters { cn.sendBinaryModeQuery(query, args) cn.readParseResponse() cn.readBindResponse() cn.readPortalDescribeResponse() cn.postExecuteWorkaround() res, _, err = cn.readExecuteResponse("Execute") return res, err } // Use the unnamed statement to defer planning until bind // time, or else value-based selectivity estimates cannot be // used. st := cn.prepareTo(query, "") r, err := st.Exec(args) if err != nil { panic(err) } return r, err } func (cn *conn) send(m *writeBuf) { _, err := cn.c.Write(m.wrap()) if err != nil { panic(err) } } func (cn *conn) sendStartupPacket(m *writeBuf) error { _, err := cn.c.Write((m.wrap())[1:]) return err } // Send a message of type typ to the server on the other end of cn. The // message should have no payload. This method does not use the scratch // buffer. func (cn *conn) sendSimpleMessage(typ byte) (err error) { _, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'}) return err } // saveMessage memorizes a message and its buffer in the conn struct. // recvMessage will then return these values on the next call to it. This // method is useful in cases where you have to see what the next message is // going to be (e.g. to see whether it's an error or not) but you can't handle // the message yourself. func (cn *conn) saveMessage(typ byte, buf *readBuf) { if cn.saveMessageType != 0 { cn.bad = true errorf("unexpected saveMessageType %d", cn.saveMessageType) } cn.saveMessageType = typ cn.saveMessageBuffer = *buf } // recvMessage receives any message from the backend, or returns an error if // a problem occurred while reading the message. func (cn *conn) recvMessage(r *readBuf) (byte, error) { // workaround for a QueryRow bug, see exec if cn.saveMessageType != 0 { t := cn.saveMessageType *r = cn.saveMessageBuffer cn.saveMessageType = 0 cn.saveMessageBuffer = nil return t, nil } x := cn.scratch[:5] _, err := io.ReadFull(cn.buf, x) if err != nil { return 0, err } // read the type and length of the message that follows t := x[0] n := int(binary.BigEndian.Uint32(x[1:])) - 4 var y []byte if n <= len(cn.scratch) { y = cn.scratch[:n] } else { y = make([]byte, n) } _, err = io.ReadFull(cn.buf, y) if err != nil { return 0, err } *r = y return t, nil } // recv receives a message from the backend, but if an error happened while // reading the message or the received message was an ErrorResponse, it panics. // NoticeResponses are ignored. This function should generally be used only // during the startup sequence. func (cn *conn) recv() (t byte, r *readBuf) { for { var err error r = &readBuf{} t, err = cn.recvMessage(r) if err != nil { panic(err) } switch t { case 'E': panic(parseError(r)) case 'N': if n := cn.noticeHandler; n != nil { n(parseError(r)) } case 'A': if n := cn.notificationHandler; n != nil { n(recvNotification(r)) } default: return } } } // recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by // the caller to avoid an allocation. func (cn *conn) recv1Buf(r *readBuf) byte { for { t, err := cn.recvMessage(r) if err != nil { panic(err) } switch t { case 'A': if n := cn.notificationHandler; n != nil { n(recvNotification(r)) } case 'N': if n := cn.noticeHandler; n != nil { n(parseError(r)) } case 'S': cn.processParameterStatus(r) default: return t } } } // recv1 receives a message from the backend, panicking if an error occurs // while attempting to read it. All asynchronous messages are ignored, with // the exception of ErrorResponse. func (cn *conn) recv1() (t byte, r *readBuf) { r = &readBuf{} t = cn.recv1Buf(r) return t, r } func (cn *conn) ssl(o values) error { upgrade, err := ssl(o) if err != nil { return err } if upgrade == nil { // Nothing to do return nil } w := cn.writeBuf(0) w.int32(80877103) if err = cn.sendStartupPacket(w); err != nil { return err } b := cn.scratch[:1] _, err = io.ReadFull(cn.c, b) if err != nil { return err } if b[0] != 'S' { return ErrSSLNotSupported } cn.c, err = upgrade(cn.c) return err } // isDriverSetting returns true iff a setting is purely for configuring the // driver's options and should not be sent to the server in the connection // startup packet. func isDriverSetting(key string) bool { switch key { case "host", "port": return true case "password": return true case "sslmode", "sslcert", "sslkey", "sslrootcert": return true case "fallback_application_name": return true case "connect_timeout": return true case "disable_prepared_binary_result": return true case "binary_parameters": return true case "krbsrvname": return true case "krbspn": return true default: return false } } func (cn *conn) startup(o values) { w := cn.writeBuf(0) w.int32(196608) // Send the backend the name of the database we want to connect to, and the // user we want to connect as. Additionally, we send over any run-time // parameters potentially included in the connection string. If the server // doesn't recognize any of them, it will reply with an error. for k, v := range o { if isDriverSetting(k) { // skip options which can't be run-time parameters continue } // The protocol requires us to supply the database name as "database" // instead of "dbname". if k == "dbname" { k = "database" } w.string(k) w.string(v) } w.string("") if err := cn.sendStartupPacket(w); err != nil { panic(err) } for { t, r := cn.recv() switch t { case 'K': cn.processBackendKeyData(r) case 'S': cn.processParameterStatus(r) case 'R': cn.auth(r, o) case 'Z': cn.processReadyForQuery(r) return default: errorf("unknown response for startup: %q", t) } } } func (cn *conn) auth(r *readBuf, o values) { switch code := r.int32(); code { case 0: // OK case 3: w := cn.writeBuf('p') w.string(o["password"]) cn.send(w) t, r := cn.recv() if t != 'R' { errorf("unexpected password response: %q", t) } if r.int32() != 0 { errorf("unexpected authentication response: %q", t) } case 5: s := string(r.next(4)) w := cn.writeBuf('p') w.string("md5" + md5s(md5s(o["password"]+o["user"])+s)) cn.send(w) t, r := cn.recv() if t != 'R' { errorf("unexpected password response: %q", t) } if r.int32() != 0 { errorf("unexpected authentication response: %q", t) } case 7: // GSSAPI, startup if newGss == nil { errorf("kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos if you need Kerberos support)") } cli, err := newGss() if err != nil { errorf("kerberos error: %s", err.Error()) } var token []byte if spn, ok := o["krbspn"]; ok { // Use the supplied SPN if provided.. token, err = cli.GetInitTokenFromSpn(spn) } else { // Allow the kerberos service name to be overridden service := "postgres" if val, ok := o["krbsrvname"]; ok { service = val } token, err = cli.GetInitToken(o["host"], service) } if err != nil { errorf("failed to get Kerberos ticket: %q", err) } w := cn.writeBuf('p') w.bytes(token) cn.send(w) // Store for GSSAPI continue message cn.gss = cli case 8: // GSSAPI continue if cn.gss == nil { errorf("GSSAPI protocol error") } b := []byte(*r) done, tokOut, err := cn.gss.Continue(b) if err == nil && !done { w := cn.writeBuf('p') w.bytes(tokOut) cn.send(w) } // Errors fall through and read the more detailed message // from the server.. case 10: sc := scram.NewClient(sha256.New, o["user"], o["password"]) sc.Step(nil) if sc.Err() != nil { errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) } scOut := sc.Out() w := cn.writeBuf('p') w.string("SCRAM-SHA-256") w.int32(len(scOut)) w.bytes(scOut) cn.send(w) t, r := cn.recv() if t != 'R' { errorf("unexpected password response: %q", t) } if r.int32() != 11 { errorf("unexpected authentication response: %q", t) } nextStep := r.next(len(*r)) sc.Step(nextStep) if sc.Err() != nil { errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) } scOut = sc.Out() w = cn.writeBuf('p') w.bytes(scOut) cn.send(w) t, r = cn.recv() if t != 'R' { errorf("unexpected password response: %q", t) } if r.int32() != 12 { errorf("unexpected authentication response: %q", t) } nextStep = r.next(len(*r)) sc.Step(nextStep) if sc.Err() != nil { errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) } default: errorf("unknown authentication response: %d", code) } } type format int const formatText format = 0 const formatBinary format = 1 // One result-column format code with the value 1 (i.e. all binary). var colFmtDataAllBinary = []byte{0, 1, 0, 1} // No result-column format codes (i.e. all text). var colFmtDataAllText = []byte{0, 0} type stmt struct { cn *conn name string rowsHeader colFmtData []byte paramTyps []oid.Oid closed bool } func (st *stmt) Close() (err error) { if st.closed { return nil } if st.cn.bad { return driver.ErrBadConn } defer st.cn.errRecover(&err) w := st.cn.writeBuf('C') w.byte('S') w.string(st.name) st.cn.send(w) st.cn.send(st.cn.writeBuf('S')) t, _ := st.cn.recv1() if t != '3' { st.cn.bad = true errorf("unexpected close response: %q", t) } st.closed = true t, r := st.cn.recv1() if t != 'Z' { st.cn.bad = true errorf("expected ready for query, but got: %q", t) } st.cn.processReadyForQuery(r) return nil } func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { if st.cn.bad { return nil, driver.ErrBadConn } defer st.cn.errRecover(&err) st.exec(v) return &rows{ cn: st.cn, rowsHeader: st.rowsHeader, }, nil } func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) { if st.cn.bad { return nil, driver.ErrBadConn } defer st.cn.errRecover(&err) st.exec(v) res, _, err = st.cn.readExecuteResponse("simple query") return res, err } func (st *stmt) exec(v []driver.Value) { if len(v) >= 65536 { errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(v)) } if len(v) != len(st.paramTyps) { errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps)) } cn := st.cn w := cn.writeBuf('B') w.byte(0) // unnamed portal w.string(st.name) if cn.binaryParameters { cn.sendBinaryParameters(w, v) } else { w.int16(0) w.int16(len(v)) for i, x := range v { if x == nil { w.int32(-1) } else { b := encode(&cn.parameterStatus, x, st.paramTyps[i]) w.int32(len(b)) w.bytes(b) } } } w.bytes(st.colFmtData) w.next('E') w.byte(0) w.int32(0) w.next('S') cn.send(w) cn.readBindResponse() cn.postExecuteWorkaround() } func (st *stmt) NumInput() int { return len(st.paramTyps) } // parseComplete parses the "command tag" from a CommandComplete message, and // returns the number of rows affected (if applicable) and a string // identifying only the command that was executed, e.g. "ALTER TABLE". If the // command tag could not be parsed, parseComplete panics. func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { commandsWithAffectedRows := []string{ "SELECT ", // INSERT is handled below "UPDATE ", "DELETE ", "FETCH ", "MOVE ", "COPY ", } var affectedRows *string for _, tag := range commandsWithAffectedRows { if strings.HasPrefix(commandTag, tag) { t := commandTag[len(tag):] affectedRows = &t commandTag = tag[:len(tag)-1] break } } // INSERT also includes the oid of the inserted row in its command tag. // Oids in user tables are deprecated, and the oid is only returned when // exactly one row is inserted, so it's unlikely to be of value to any // real-world application and we can ignore it. if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") { parts := strings.Split(commandTag, " ") if len(parts) != 3 { cn.bad = true errorf("unexpected INSERT command tag %s", commandTag) } affectedRows = &parts[len(parts)-1] commandTag = "INSERT" } // There should be no affected rows attached to the tag, just return it if affectedRows == nil { return driver.RowsAffected(0), commandTag } n, err := strconv.ParseInt(*affectedRows, 10, 64) if err != nil { cn.bad = true errorf("could not parse commandTag: %s", err) } return driver.RowsAffected(n), commandTag } type rowsHeader struct { colNames []string colTyps []fieldDesc colFmts []format } type rows struct { cn *conn finish func() rowsHeader done bool rb readBuf result driver.Result tag string next *rowsHeader } func (rs *rows) Close() error { if finish := rs.finish; finish != nil { defer finish() } // no need to look at cn.bad as Next() will for { err := rs.Next(nil) switch err { case nil: case io.EOF: // rs.Next can return io.EOF on both 'Z' (ready for query) and 'T' (row // description, used with HasNextResultSet). We need to fetch messages until // we hit a 'Z', which is done by waiting for done to be set. if rs.done { return nil } default: return err } } } func (rs *rows) Columns() []string { return rs.colNames } func (rs *rows) Result() driver.Result { if rs.result == nil { return emptyRows } return rs.result } func (rs *rows) Tag() string { return rs.tag } func (rs *rows) Next(dest []driver.Value) (err error) { if rs.done { return io.EOF } conn := rs.cn if conn.bad { return driver.ErrBadConn } defer conn.errRecover(&err) for { t := conn.recv1Buf(&rs.rb) switch t { case 'E': err = parseError(&rs.rb) case 'C', 'I': if t == 'C' { rs.result, rs.tag = conn.parseComplete(rs.rb.string()) } continue case 'Z': conn.processReadyForQuery(&rs.rb) rs.done = true if err != nil { return err } return io.EOF case 'D': n := rs.rb.int16() if err != nil { conn.bad = true errorf("unexpected DataRow after error %s", err) } if n < len(dest) { dest = dest[:n] } for i := range dest { l := rs.rb.int32() if l == -1 { dest[i] = nil continue } dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i]) } return case 'T': next := parsePortalRowDescribe(&rs.rb) rs.next = &next return io.EOF default: errorf("unexpected message after execute: %q", t) } } } func (rs *rows) HasNextResultSet() bool { hasNext := rs.next != nil && !rs.done return hasNext } func (rs *rows) NextResultSet() error { if rs.next == nil { return io.EOF } rs.rowsHeader = *rs.next rs.next = nil return nil } // QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be // used as part of an SQL statement. For example: // // tblname := "my_table" // data := "my_data" // quoted := pq.QuoteIdentifier(tblname) // err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data) // // Any double quotes in name will be escaped. The quoted identifier will be // case sensitive when used in a query. If the input string contains a zero // byte, the result will be truncated immediately before it. func QuoteIdentifier(name string) string { end := strings.IndexRune(name, 0) if end > -1 { name = name[:end] } return `"` + strings.Replace(name, `"`, `""`, -1) + `"` } // QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal // to DDL and other statements that do not accept parameters) to be used as part // of an SQL statement. For example: // // exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z") // err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date)) // // Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be // replaced by two backslashes (i.e. "\\") and the C-style escape identifier // that PostgreSQL provides ('E') will be prepended to the string. func QuoteLiteral(literal string) string { // This follows the PostgreSQL internal algorithm for handling quoted literals // from libpq, which can be found in the "PQEscapeStringInternal" function, // which is found in the libpq/fe-exec.c source file: // https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c // // substitute any single-quotes (') with two single-quotes ('') literal = strings.Replace(literal, `'`, `''`, -1) // determine if the string has any backslashes (\) in it. // if it does, replace any backslashes (\) with two backslashes (\\) // then, we need to wrap the entire string with a PostgreSQL // C-style escape. Per how "PQEscapeStringInternal" handles this case, we // also add a space before the "E" if strings.Contains(literal, `\`) { literal = strings.Replace(literal, `\`, `\\`, -1) literal = ` E'` + literal + `'` } else { // otherwise, we can just wrap the literal with a pair of single quotes literal = `'` + literal + `'` } return literal } func md5s(s string) string { h := md5.New() h.Write([]byte(s)) return fmt.Sprintf("%x", h.Sum(nil)) } func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) { // Do one pass over the parameters to see if we're going to send any of // them over in binary. If we are, create a paramFormats array at the // same time. var paramFormats []int for i, x := range args { _, ok := x.([]byte) if ok { if paramFormats == nil { paramFormats = make([]int, len(args)) } paramFormats[i] = 1 } } if paramFormats == nil { b.int16(0) } else { b.int16(len(paramFormats)) for _, x := range paramFormats { b.int16(x) } } b.int16(len(args)) for _, x := range args { if x == nil { b.int32(-1) } else { datum := binaryEncode(&cn.parameterStatus, x) b.int32(len(datum)) b.bytes(datum) } } } func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) { if len(args) >= 65536 { errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args)) } b := cn.writeBuf('P') b.byte(0) // unnamed statement b.string(query) b.int16(0) b.next('B') b.int16(0) // unnamed portal and statement cn.sendBinaryParameters(b, args) b.bytes(colFmtDataAllText) b.next('D') b.byte('P') b.byte(0) // unnamed portal b.next('E') b.byte(0) b.int32(0) b.next('S') cn.send(b) } func (cn *conn) processParameterStatus(r *readBuf) { var err error param := r.string() switch param { case "server_version": var major1 int var major2 int var minor int _, err = fmt.Sscanf(r.string(), "%d.%d.%d", &major1, &major2, &minor) if err == nil { cn.parameterStatus.serverVersion = major1*10000 + major2*100 + minor } case "TimeZone": cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string()) if err != nil { cn.parameterStatus.currentLocation = nil } default: // ignore } } func (cn *conn) processReadyForQuery(r *readBuf) { cn.txnStatus = transactionStatus(r.byte()) } func (cn *conn) readReadyForQuery() { t, r := cn.recv1() switch t { case 'Z': cn.processReadyForQuery(r) return default: cn.bad = true errorf("unexpected message %q; expected ReadyForQuery", t) } } func (cn *conn) processBackendKeyData(r *readBuf) { cn.processID = r.int32() cn.secretKey = r.int32() } func (cn *conn) readParseResponse() { t, r := cn.recv1() switch t { case '1': return case 'E': err := parseError(r) cn.readReadyForQuery() panic(err) default: cn.bad = true errorf("unexpected Parse response %q", t) } } func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc) { for { t, r := cn.recv1() switch t { case 't': nparams := r.int16() paramTyps = make([]oid.Oid, nparams) for i := range paramTyps { paramTyps[i] = r.oid() } case 'n': return paramTyps, nil, nil case 'T': colNames, colTyps = parseStatementRowDescribe(r) return paramTyps, colNames, colTyps case 'E': err := parseError(r) cn.readReadyForQuery() panic(err) default: cn.bad = true errorf("unexpected Describe statement response %q", t) } } } func (cn *conn) readPortalDescribeResponse() rowsHeader { t, r := cn.recv1() switch t { case 'T': return parsePortalRowDescribe(r) case 'n': return rowsHeader{} case 'E': err := parseError(r) cn.readReadyForQuery() panic(err) default: cn.bad = true errorf("unexpected Describe response %q", t) } panic("not reached") } func (cn *conn) readBindResponse() { t, r := cn.recv1() switch t { case '2': return case 'E': err := parseError(r) cn.readReadyForQuery() panic(err) default: cn.bad = true errorf("unexpected Bind response %q", t) } } func (cn *conn) postExecuteWorkaround() { // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores // any errors from rows.Next, which masks errors that happened during the // execution of the query. To avoid the problem in common cases, we wait // here for one more message from the database. If it's not an error the // query will likely succeed (or perhaps has already, if it's a // CommandComplete), so we push the message into the conn struct; recv1 // will return it as the next message for rows.Next or rows.Close. // However, if it's an error, we wait until ReadyForQuery and then return // the error to our caller. for { t, r := cn.recv1() switch t { case 'E': err := parseError(r) cn.readReadyForQuery() panic(err) case 'C', 'D', 'I': // the query didn't fail, but we can't process this message cn.saveMessage(t, r) return default: cn.bad = true errorf("unexpected message during extended query execution: %q", t) } } } // Only for Exec(), since we ignore the returned data func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) { for { t, r := cn.recv1() switch t { case 'C': if err != nil { cn.bad = true errorf("unexpected CommandComplete after error %s", err) } res, commandTag = cn.parseComplete(r.string()) case 'Z': cn.processReadyForQuery(r) if res == nil && err == nil { err = errUnexpectedReady } return res, commandTag, err case 'E': err = parseError(r) case 'T', 'D', 'I': if err != nil { cn.bad = true errorf("unexpected %q after error %s", t, err) } if t == 'I' { res = emptyRows } // ignore any results default: cn.bad = true errorf("unknown %s response: %q", protocolState, t) } } } func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) { n := r.int16() colNames = make([]string, n) colTyps = make([]fieldDesc, n) for i := range colNames { colNames[i] = r.string() r.next(6) colTyps[i].OID = r.oid() colTyps[i].Len = r.int16() colTyps[i].Mod = r.int32() // format code not known when describing a statement; always 0 r.next(2) } return } func parsePortalRowDescribe(r *readBuf) rowsHeader { n := r.int16() colNames := make([]string, n) colFmts := make([]format, n) colTyps := make([]fieldDesc, n) for i := range colNames { colNames[i] = r.string() r.next(6) colTyps[i].OID = r.oid() colTyps[i].Len = r.int16() colTyps[i].Mod = r.int32() colFmts[i] = format(r.int16()) } return rowsHeader{ colNames: colNames, colFmts: colFmts, colTyps: colTyps, } } // parseEnviron tries to mimic some of libpq's environment handling // // To ease testing, it does not directly reference os.Environ, but is // designed to accept its output. // // Environment-set connection information is intended to have a higher // precedence than a library default but lower than any explicitly // passed information (such as in the URL or connection string). func parseEnviron(env []string) (out map[string]string) { out = make(map[string]string) for _, v := range env { parts := strings.SplitN(v, "=", 2) accrue := func(keyname string) { out[keyname] = parts[1] } unsupported := func() { panic(fmt.Sprintf("setting %v not supported", parts[0])) } // The order of these is the same as is seen in the // PostgreSQL 9.1 manual. Unsupported but well-defined // keys cause a panic; these should be unset prior to // execution. Options which pq expects to be set to a // certain value are allowed, but must be set to that // value if present (they can, of course, be absent). switch parts[0] { case "PGHOST": accrue("host") case "PGHOSTADDR": unsupported() case "PGPORT": accrue("port") case "PGDATABASE": accrue("dbname") case "PGUSER": accrue("user") case "PGPASSWORD": accrue("password") case "PGSERVICE", "PGSERVICEFILE", "PGREALM": unsupported() case "PGOPTIONS": accrue("options") case "PGAPPNAME": accrue("application_name") case "PGSSLMODE": accrue("sslmode") case "PGSSLCERT": accrue("sslcert") case "PGSSLKEY": accrue("sslkey") case "PGSSLROOTCERT": accrue("sslrootcert") case "PGREQUIRESSL", "PGSSLCRL": unsupported() case "PGREQUIREPEER": unsupported() case "PGKRBSRVNAME", "PGGSSLIB": unsupported() case "PGCONNECT_TIMEOUT": accrue("connect_timeout") case "PGCLIENTENCODING": accrue("client_encoding") case "PGDATESTYLE": accrue("datestyle") case "PGTZ": accrue("timezone") case "PGGEQO": accrue("geqo") case "PGSYSCONFDIR", "PGLOCALEDIR": unsupported() } } return out } // isUTF8 returns whether name is a fuzzy variation of the string "UTF-8". func isUTF8(name string) bool { // Recognize all sorts of silly things as "UTF-8", like Postgres does s := strings.Map(alnumLowerASCII, name) return s == "utf8" || s == "unicode" } func alnumLowerASCII(ch rune) rune { if 'A' <= ch && ch <= 'Z' { return ch + ('a' - 'A') } if 'a' <= ch && ch <= 'z' || '0' <= ch && ch <= '9' { return ch } return -1 // discard }