|  | // Copyright 2011 The Go Authors. All rights reserved. | 
|  | // Use of this source code is governed by a BSD-style | 
|  | // license that can be found in the LICENSE file. | 
|  |  | 
|  | package sql | 
|  |  | 
|  | import ( | 
|  | "context" | 
|  | "database/sql/driver" | 
|  | "errors" | 
|  | "fmt" | 
|  | "io" | 
|  | "reflect" | 
|  | "sort" | 
|  | "strconv" | 
|  | "strings" | 
|  | "sync" | 
|  | "testing" | 
|  | "time" | 
|  | ) | 
|  |  | 
|  | // fakeDriver is a fake database that implements Go's driver.Driver | 
|  | // interface, just for testing. | 
|  | // | 
|  | // It speaks a query language that's semantically similar to but | 
|  | // syntactically different and simpler than SQL.  The syntax is as | 
|  | // follows: | 
|  | // | 
|  | //   WIPE | 
|  | //   CREATE|<tablename>|<col>=<type>,<col>=<type>,... | 
|  | //     where types are: "string", [u]int{8,16,32,64}, "bool" | 
|  | //   INSERT|<tablename>|col=val,col2=val2,col3=? | 
|  | //   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=? | 
|  | //   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2 | 
|  | // | 
|  | // Any of these can be preceded by PANIC|<method>|, to cause the | 
|  | // named method on fakeStmt to panic. | 
|  | // | 
|  | // Any of these can be proceeded by WAIT|<duration>|, to cause the | 
|  | // named method on fakeStmt to sleep for the specified duration. | 
|  | // | 
|  | // Multiple of these can be combined when separated with a semicolon. | 
|  | // | 
|  | // When opening a fakeDriver's database, it starts empty with no | 
|  | // tables. All tables and data are stored in memory only. | 
|  | type fakeDriver struct { | 
|  | mu         sync.Mutex // guards 3 following fields | 
|  | openCount  int        // conn opens | 
|  | closeCount int        // conn closes | 
|  | waitCh     chan struct{} | 
|  | waitingCh  chan struct{} | 
|  | dbs        map[string]*fakeDB | 
|  | } | 
|  |  | 
|  | type fakeConnector struct { | 
|  | name string | 
|  |  | 
|  | waiter func(context.Context) | 
|  | closed bool | 
|  | } | 
|  |  | 
|  | func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) { | 
|  | conn, err := fdriver.Open(c.name) | 
|  | conn.(*fakeConn).waiter = c.waiter | 
|  | return conn, err | 
|  | } | 
|  |  | 
|  | func (c *fakeConnector) Driver() driver.Driver { | 
|  | return fdriver | 
|  | } | 
|  |  | 
|  | func (c *fakeConnector) Close() error { | 
|  | if c.closed { | 
|  | return errors.New("fakedb: connector is closed") | 
|  | } | 
|  | c.closed = true | 
|  | return nil | 
|  | } | 
|  |  | 
|  | type fakeDriverCtx struct { | 
|  | fakeDriver | 
|  | } | 
|  |  | 
|  | var _ driver.DriverContext = &fakeDriverCtx{} | 
|  |  | 
|  | func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) { | 
|  | return &fakeConnector{name: name}, nil | 
|  | } | 
|  |  | 
// fakeDB is an in-memory database. fakeDriver.getDB hands the same
// *fakeDB to every connection opened with the same name.
type fakeDB struct {
	name string

	mu       sync.Mutex        // held by wipe, createTable and columnType around tables
	tables   map[string]*table // nil until createTable runs; wipe resets it to nil
	badConn  bool              // toggle flipped by fakeConn.isBad on badConn connections
	allowAny bool              // when true, checkSubsetTypes accepts any argument type
}
|  |  | 
// fakeError is the error type fakeConn, fakeStmt and fakeTx return for
// simulated failures. Every use in this file wraps driver.ErrBadConn,
// which callers can detect with errors.Is through Unwrap.
type fakeError struct {
	Message string
	Wrapped error
}

// Error returns Message. When Message is empty it falls back to the
// wrapped error's text, so an error that only wraps another never
// stringifies as "".
func (err fakeError) Error() string {
	if err.Message == "" && err.Wrapped != nil {
		return err.Wrapped.Error()
	}
	return err.Message
}

// Unwrap returns the wrapped error, enabling errors.Is and errors.As.
func (err fakeError) Unwrap() error {
	return err.Wrapped
}
|  |  | 
// table is a single in-memory table. colname and coltype are parallel
// slices describing the columns; mu is held while rows are read or
// appended.
type table struct {
	mu      sync.Mutex
	colname []string
	coltype []string
	rows    []*row
}

// columnIndex returns the position of the first column called name, or
// -1 when the table has no such column.
func (t *table) columnIndex(name string) int {
	found := -1
	for i := 0; found < 0 && i < len(t.colname); i++ {
		if t.colname[i] == name {
			found = i
		}
	}
	return found
}

// row is one stored row; cols must be the same size as its table's
// colname and coltype.
type row struct {
	cols []any
}

// memToucher is implemented by objects that write to a private field on
// every operation, so the race detector can spot unsynchronized use.
type memToucher interface {
	touchMem()
}
|  |  | 
|  | type fakeConn struct { | 
|  | db *fakeDB // where to return ourselves to | 
|  |  | 
|  | currTx *fakeTx | 
|  |  | 
|  | // Every operation writes to line to enable the race detector | 
|  | // check for data races. | 
|  | line int64 | 
|  |  | 
|  | // Stats for tests: | 
|  | mu          sync.Mutex | 
|  | stmtsMade   int | 
|  | stmtsClosed int | 
|  | numPrepare  int | 
|  |  | 
|  | // bad connection tests; see isBad() | 
|  | bad       bool | 
|  | stickyBad bool | 
|  |  | 
|  | skipDirtySession bool // tests that use Conn should set this to true. | 
|  |  | 
|  | // dirtySession tests ResetSession, true if a query has executed | 
|  | // until ResetSession is called. | 
|  | dirtySession bool | 
|  |  | 
|  | // The waiter is called before each query. May be used in place of the "WAIT" | 
|  | // directive. | 
|  | waiter func(context.Context) | 
|  | } | 
|  |  | 
|  | func (c *fakeConn) touchMem() { | 
|  | c.line++ | 
|  | } | 
|  |  | 
|  | func (c *fakeConn) incrStat(v *int) { | 
|  | c.mu.Lock() | 
|  | *v++ | 
|  | c.mu.Unlock() | 
|  | } | 
|  |  | 
// fakeTx is the driver.Tx returned by fakeConn.Begin. Commit and
// Rollback clear the owning connection's currTx.
type fakeTx struct {
	c *fakeConn
}

// boundCol is one WHERE-clause column of a SELECT together with the
// placeholder it is compared against.
type boundCol struct {
	Column      string // column name in the table
	Placeholder string // "?" for positional, or "?name" for a named argument
	Ordinal     int    // 1-based position among the statement's placeholders
}
|  |  | 
// fakeStmt is a prepared statement parsed from the fakedb query
// language. A query with several ';'-separated statements is prepared as
// a chain of fakeStmts linked through next.
type fakeStmt struct {
	memToucher // PrepareContext sets this to the owning conn
	c          *fakeConn
	q          string // just for debugging

	cmd   string        // WIPE, CREATE, INSERT, NOSERT or SELECT
	table string        // table the statement operates on
	panic string        // method to panic in, from a PANIC|<method>| prefix
	wait  time.Duration // delay from a WAIT|<duration>| prefix

	next *fakeStmt // used for returning multiple results.

	closed bool // set by Close; Exec and Query then return errClosed

	colName      []string // used by CREATE, INSERT, SELECT (selected columns)
	colType      []string // used by CREATE
	colValue     []any    // used by INSERT: pre-bound values already converted to driver types, or "?"/"?name" placeholder strings
	placeholders int      // used by INSERT/SELECT: number of ? params

	whereCol []boundCol // used by SELECT (all placeholders)

	placeholderConverter []driver.ValueConverter // used by INSERT
}
|  |  | 
// fdriver is the shared fake driver instance used throughout the tests.
var fdriver driver.Driver = &fakeDriver{}

// init registers fdriver under the driver name "test".
func init() {
	Register("test", fdriver)
}
|  |  | 
// contains reports whether y appears anywhere in list.
func contains(list []string, y string) bool {
	for i := range list {
		if list[i] == y {
			return true
		}
	}
	return false
}
|  |  | 
|  | type Dummy struct { | 
|  | driver.Driver | 
|  | } | 
|  |  | 
|  | func TestDrivers(t *testing.T) { | 
|  | unregisterAllDrivers() | 
|  | Register("test", fdriver) | 
|  | Register("invalid", Dummy{}) | 
|  | all := Drivers() | 
|  | if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") { | 
|  | t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all) | 
|  | } | 
|  | } | 
|  |  | 
// hookOpenErr lets tests simulate connection failures: when fn is set,
// fakeDriver.Open calls it first and fails with any error it returns.
var hookOpenErr struct {
	sync.Mutex
	fn func() error
}

// setHookOpenErr installs fn as the open-failure hook; nil disables it.
func setHookOpenErr(fn func() error) {
	hookOpenErr.Lock()
	hookOpenErr.fn = fn
	hookOpenErr.Unlock()
}
|  |  | 
|  | // Supports dsn forms: | 
|  | //    <dbname> | 
|  | //    <dbname>;<opts>  (only currently supported option is `badConn`, | 
|  | //                      which causes driver.ErrBadConn to be returned on | 
|  | //                      every other conn.Begin()) | 
|  | func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { | 
|  | hookOpenErr.Lock() | 
|  | fn := hookOpenErr.fn | 
|  | hookOpenErr.Unlock() | 
|  | if fn != nil { | 
|  | if err := fn(); err != nil { | 
|  | return nil, err | 
|  | } | 
|  | } | 
|  | parts := strings.Split(dsn, ";") | 
|  | if len(parts) < 1 { | 
|  | return nil, errors.New("fakedb: no database name") | 
|  | } | 
|  | name := parts[0] | 
|  |  | 
|  | db := d.getDB(name) | 
|  |  | 
|  | d.mu.Lock() | 
|  | d.openCount++ | 
|  | d.mu.Unlock() | 
|  | conn := &fakeConn{db: db} | 
|  |  | 
|  | if len(parts) >= 2 && parts[1] == "badConn" { | 
|  | conn.bad = true | 
|  | } | 
|  | if d.waitCh != nil { | 
|  | d.waitingCh <- struct{}{} | 
|  | <-d.waitCh | 
|  | d.waitCh = nil | 
|  | d.waitingCh = nil | 
|  | } | 
|  | return conn, nil | 
|  | } | 
|  |  | 
|  | func (d *fakeDriver) getDB(name string) *fakeDB { | 
|  | d.mu.Lock() | 
|  | defer d.mu.Unlock() | 
|  | if d.dbs == nil { | 
|  | d.dbs = make(map[string]*fakeDB) | 
|  | } | 
|  | db, ok := d.dbs[name] | 
|  | if !ok { | 
|  | db = &fakeDB{name: name} | 
|  | d.dbs[name] = db | 
|  | } | 
|  | return db | 
|  | } | 
|  |  | 
|  | func (db *fakeDB) wipe() { | 
|  | db.mu.Lock() | 
|  | defer db.mu.Unlock() | 
|  | db.tables = nil | 
|  | } | 
|  |  | 
|  | func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error { | 
|  | db.mu.Lock() | 
|  | defer db.mu.Unlock() | 
|  | if db.tables == nil { | 
|  | db.tables = make(map[string]*table) | 
|  | } | 
|  | if _, exist := db.tables[name]; exist { | 
|  | return fmt.Errorf("fakedb: table %q already exists", name) | 
|  | } | 
|  | if len(columnNames) != len(columnTypes) { | 
|  | return fmt.Errorf("fakedb: create table of %q len(names) != len(types): %d vs %d", | 
|  | name, len(columnNames), len(columnTypes)) | 
|  | } | 
|  | db.tables[name] = &table{colname: columnNames, coltype: columnTypes} | 
|  | return nil | 
|  | } | 
|  |  | 
|  | // must be called with db.mu lock held | 
|  | func (db *fakeDB) table(table string) (*table, bool) { | 
|  | if db.tables == nil { | 
|  | return nil, false | 
|  | } | 
|  | t, ok := db.tables[table] | 
|  | return t, ok | 
|  | } | 
|  |  | 
|  | func (db *fakeDB) columnType(table, column string) (typ string, ok bool) { | 
|  | db.mu.Lock() | 
|  | defer db.mu.Unlock() | 
|  | t, ok := db.table(table) | 
|  | if !ok { | 
|  | return | 
|  | } | 
|  | for n, cname := range t.colname { | 
|  | if cname == column { | 
|  | return t.coltype[n], true | 
|  | } | 
|  | } | 
|  | return "", false | 
|  | } | 
|  |  | 
|  | func (c *fakeConn) isBad() bool { | 
|  | if c.stickyBad { | 
|  | return true | 
|  | } else if c.bad { | 
|  | if c.db == nil { | 
|  | return false | 
|  | } | 
|  | // alternate between bad conn and not bad conn | 
|  | c.db.badConn = !c.db.badConn | 
|  | return c.db.badConn | 
|  | } else { | 
|  | return false | 
|  | } | 
|  | } | 
|  |  | 
|  | func (c *fakeConn) isDirtyAndMark() bool { | 
|  | if c.skipDirtySession { | 
|  | return false | 
|  | } | 
|  | if c.currTx != nil { | 
|  | c.dirtySession = true | 
|  | return false | 
|  | } | 
|  | if c.dirtySession { | 
|  | return true | 
|  | } | 
|  | c.dirtySession = true | 
|  | return false | 
|  | } | 
|  |  | 
|  | func (c *fakeConn) Begin() (driver.Tx, error) { | 
|  | if c.isBad() { | 
|  | return nil, fakeError{Wrapped: driver.ErrBadConn} | 
|  | } | 
|  | if c.currTx != nil { | 
|  | return nil, errors.New("fakedb: already in a transaction") | 
|  | } | 
|  | c.touchMem() | 
|  | c.currTx = &fakeTx{c: c} | 
|  | return c.currTx, nil | 
|  | } | 
|  |  | 
|  | var hookPostCloseConn struct { | 
|  | sync.Mutex | 
|  | fn func(*fakeConn, error) | 
|  | } | 
|  |  | 
|  | func setHookpostCloseConn(fn func(*fakeConn, error)) { | 
|  | hookPostCloseConn.Lock() | 
|  | defer hookPostCloseConn.Unlock() | 
|  | hookPostCloseConn.fn = fn | 
|  | } | 
|  |  | 
|  | var testStrictClose *testing.T | 
|  |  | 
|  | // setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close | 
|  | // fails to close. If nil, the check is disabled. | 
|  | func setStrictFakeConnClose(t *testing.T) { | 
|  | testStrictClose = t | 
|  | } | 
|  |  | 
|  | func (c *fakeConn) ResetSession(ctx context.Context) error { | 
|  | c.dirtySession = false | 
|  | c.currTx = nil | 
|  | if c.isBad() { | 
|  | return fakeError{Message: "Reset Session: bad conn", Wrapped: driver.ErrBadConn} | 
|  | } | 
|  | return nil | 
|  | } | 
|  |  | 
|  | var _ driver.Validator = (*fakeConn)(nil) | 
|  |  | 
|  | func (c *fakeConn) IsValid() bool { | 
|  | return !c.isBad() | 
|  | } | 
|  |  | 
|  | func (c *fakeConn) Close() (err error) { | 
|  | drv := fdriver.(*fakeDriver) | 
|  | defer func() { | 
|  | if err != nil && testStrictClose != nil { | 
|  | testStrictClose.Errorf("failed to close a test fakeConn: %v", err) | 
|  | } | 
|  | hookPostCloseConn.Lock() | 
|  | fn := hookPostCloseConn.fn | 
|  | hookPostCloseConn.Unlock() | 
|  | if fn != nil { | 
|  | fn(c, err) | 
|  | } | 
|  | if err == nil { | 
|  | drv.mu.Lock() | 
|  | drv.closeCount++ | 
|  | drv.mu.Unlock() | 
|  | } | 
|  | }() | 
|  | c.touchMem() | 
|  | if c.currTx != nil { | 
|  | return errors.New("fakedb: can't close fakeConn; in a Transaction") | 
|  | } | 
|  | if c.db == nil { | 
|  | return errors.New("fakedb: can't close fakeConn; already closed") | 
|  | } | 
|  | if c.stmtsMade > c.stmtsClosed { | 
|  | return errors.New("fakedb: can't close; dangling statement(s)") | 
|  | } | 
|  | c.db = nil | 
|  | return nil | 
|  | } | 
|  |  | 
// checkSubsetTypes verifies that every argument value has one of the
// driver.Value subset types: int64, float64, bool, nil, []byte, string
// or time.Time. When allowAny is true, any type is accepted.
func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error {
	if allowAny {
		return nil
	}
	for _, arg := range args {
		switch arg.Value.(type) {
		case int64, float64, bool, nil, []byte, string, time.Time:
			continue
		}
		return fmt.Errorf("fakedb: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
	}
	return nil
}
|  |  | 
|  | func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) { | 
|  | // Ensure that ExecContext is called if available. | 
|  | panic("ExecContext was not called.") | 
|  | } | 
|  |  | 
|  | func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { | 
|  | // This is an optional interface, but it's implemented here | 
|  | // just to check that all the args are of the proper types. | 
|  | // ErrSkip is returned so the caller acts as if we didn't | 
|  | // implement this at all. | 
|  | err := checkSubsetTypes(c.db.allowAny, args) | 
|  | if err != nil { | 
|  | return nil, err | 
|  | } | 
|  | return nil, driver.ErrSkip | 
|  | } | 
|  |  | 
|  | func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) { | 
|  | // Ensure that ExecContext is called if available. | 
|  | panic("QueryContext was not called.") | 
|  | } | 
|  |  | 
|  | func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { | 
|  | // This is an optional interface, but it's implemented here | 
|  | // just to check that all the args are of the proper types. | 
|  | // ErrSkip is returned so the caller acts as if we didn't | 
|  | // implement this at all. | 
|  | err := checkSubsetTypes(c.db.allowAny, args) | 
|  | if err != nil { | 
|  | return nil, err | 
|  | } | 
|  | return nil, driver.ErrSkip | 
|  | } | 
|  |  | 
// errf formats msg with args and prefixes the result with "fakedb: ".
func errf(msg string, args ...any) error {
	text := fmt.Sprintf(msg, args...)
	return errors.New("fakedb: " + text)
}
|  |  | 
|  | // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=? | 
|  | // (note that where columns must always contain ? marks, | 
|  | //  just a limitation for fakedb) | 
|  | func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) { | 
|  | if len(parts) != 3 { | 
|  | stmt.Close() | 
|  | return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) | 
|  | } | 
|  | stmt.table = parts[0] | 
|  |  | 
|  | stmt.colName = strings.Split(parts[1], ",") | 
|  | for n, colspec := range strings.Split(parts[2], ",") { | 
|  | if colspec == "" { | 
|  | continue | 
|  | } | 
|  | nameVal := strings.Split(colspec, "=") | 
|  | if len(nameVal) != 2 { | 
|  | stmt.Close() | 
|  | return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) | 
|  | } | 
|  | column, value := nameVal[0], nameVal[1] | 
|  | _, ok := c.db.columnType(stmt.table, column) | 
|  | if !ok { | 
|  | stmt.Close() | 
|  | return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column) | 
|  | } | 
|  | if !strings.HasPrefix(value, "?") { | 
|  | stmt.Close() | 
|  | return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark", | 
|  | stmt.table, column) | 
|  | } | 
|  | stmt.placeholders++ | 
|  | stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders}) | 
|  | } | 
|  | return stmt, nil | 
|  | } | 
|  |  | 
|  | // parts are table|col=type,col2=type2 | 
|  | func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) { | 
|  | if len(parts) != 2 { | 
|  | stmt.Close() | 
|  | return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts)) | 
|  | } | 
|  | stmt.table = parts[0] | 
|  | for n, colspec := range strings.Split(parts[1], ",") { | 
|  | nameType := strings.Split(colspec, "=") | 
|  | if len(nameType) != 2 { | 
|  | stmt.Close() | 
|  | return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) | 
|  | } | 
|  | stmt.colName = append(stmt.colName, nameType[0]) | 
|  | stmt.colType = append(stmt.colType, nameType[1]) | 
|  | } | 
|  | return stmt, nil | 
|  | } | 
|  |  | 
|  | // parts are table|col=?,col2=val | 
|  | func (c *fakeConn) prepareInsert(ctx context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) { | 
|  | if len(parts) != 2 { | 
|  | stmt.Close() | 
|  | return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) | 
|  | } | 
|  | stmt.table = parts[0] | 
|  | for n, colspec := range strings.Split(parts[1], ",") { | 
|  | nameVal := strings.Split(colspec, "=") | 
|  | if len(nameVal) != 2 { | 
|  | stmt.Close() | 
|  | return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) | 
|  | } | 
|  | column, value := nameVal[0], nameVal[1] | 
|  | ctype, ok := c.db.columnType(stmt.table, column) | 
|  | if !ok { | 
|  | stmt.Close() | 
|  | return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column) | 
|  | } | 
|  | stmt.colName = append(stmt.colName, column) | 
|  |  | 
|  | if !strings.HasPrefix(value, "?") { | 
|  | var subsetVal any | 
|  | // Convert to driver subset type | 
|  | switch ctype { | 
|  | case "string": | 
|  | subsetVal = []byte(value) | 
|  | case "blob": | 
|  | subsetVal = []byte(value) | 
|  | case "int32": | 
|  | i, err := strconv.Atoi(value) | 
|  | if err != nil { | 
|  | stmt.Close() | 
|  | return nil, errf("invalid conversion to int32 from %q", value) | 
|  | } | 
|  | subsetVal = int64(i) // int64 is a subset type, but not int32 | 
|  | case "table": // For testing cursor reads. | 
|  | c.skipDirtySession = true | 
|  | vparts := strings.Split(value, "!") | 
|  |  | 
|  | substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ","))) | 
|  | if err != nil { | 
|  | return nil, err | 
|  | } | 
|  | cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{}) | 
|  | substmt.Close() | 
|  | if err != nil { | 
|  | return nil, err | 
|  | } | 
|  | subsetVal = cursor | 
|  | default: | 
|  | stmt.Close() | 
|  | return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype) | 
|  | } | 
|  | stmt.colValue = append(stmt.colValue, subsetVal) | 
|  | } else { | 
|  | stmt.placeholders++ | 
|  | stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype)) | 
|  | stmt.colValue = append(stmt.colValue, value) | 
|  | } | 
|  | } | 
|  | return stmt, nil | 
|  | } | 
|  |  | 
|  | // hook to simulate broken connections | 
|  | var hookPrepareBadConn func() bool | 
|  |  | 
|  | func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { | 
|  | panic("use PrepareContext") | 
|  | } | 
|  |  | 
|  | func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { | 
|  | c.numPrepare++ | 
|  | if c.db == nil { | 
|  | panic("nil c.db; conn = " + fmt.Sprintf("%#v", c)) | 
|  | } | 
|  |  | 
|  | if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) { | 
|  | return nil, fakeError{Message: "Preapre: Sticky Bad", Wrapped: driver.ErrBadConn} | 
|  | } | 
|  |  | 
|  | c.touchMem() | 
|  | var firstStmt, prev *fakeStmt | 
|  | for _, query := range strings.Split(query, ";") { | 
|  | parts := strings.Split(query, "|") | 
|  | if len(parts) < 1 { | 
|  | return nil, errf("empty query") | 
|  | } | 
|  | stmt := &fakeStmt{q: query, c: c, memToucher: c} | 
|  | if firstStmt == nil { | 
|  | firstStmt = stmt | 
|  | } | 
|  | if len(parts) >= 3 { | 
|  | switch parts[0] { | 
|  | case "PANIC": | 
|  | stmt.panic = parts[1] | 
|  | parts = parts[2:] | 
|  | case "WAIT": | 
|  | wait, err := time.ParseDuration(parts[1]) | 
|  | if err != nil { | 
|  | return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err) | 
|  | } | 
|  | parts = parts[2:] | 
|  | stmt.wait = wait | 
|  | } | 
|  | } | 
|  | cmd := parts[0] | 
|  | stmt.cmd = cmd | 
|  | parts = parts[1:] | 
|  |  | 
|  | if c.waiter != nil { | 
|  | c.waiter(ctx) | 
|  | if err := ctx.Err(); err != nil { | 
|  | return nil, err | 
|  | } | 
|  | } | 
|  |  | 
|  | if stmt.wait > 0 { | 
|  | wait := time.NewTimer(stmt.wait) | 
|  | select { | 
|  | case <-wait.C: | 
|  | case <-ctx.Done(): | 
|  | wait.Stop() | 
|  | return nil, ctx.Err() | 
|  | } | 
|  | } | 
|  |  | 
|  | c.incrStat(&c.stmtsMade) | 
|  | var err error | 
|  | switch cmd { | 
|  | case "WIPE": | 
|  | // Nothing | 
|  | case "SELECT": | 
|  | stmt, err = c.prepareSelect(stmt, parts) | 
|  | case "CREATE": | 
|  | stmt, err = c.prepareCreate(stmt, parts) | 
|  | case "INSERT": | 
|  | stmt, err = c.prepareInsert(ctx, stmt, parts) | 
|  | case "NOSERT": | 
|  | // Do all the prep-work like for an INSERT but don't actually insert the row. | 
|  | // Used for some of the concurrent tests. | 
|  | stmt, err = c.prepareInsert(ctx, stmt, parts) | 
|  | default: | 
|  | stmt.Close() | 
|  | return nil, errf("unsupported command type %q", cmd) | 
|  | } | 
|  | if err != nil { | 
|  | return nil, err | 
|  | } | 
|  | if prev != nil { | 
|  | prev.next = stmt | 
|  | } | 
|  | prev = stmt | 
|  | } | 
|  | return firstStmt, nil | 
|  | } | 
|  |  | 
|  | func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { | 
|  | if s.panic == "ColumnConverter" { | 
|  | panic(s.panic) | 
|  | } | 
|  | if len(s.placeholderConverter) == 0 { | 
|  | return driver.DefaultParameterConverter | 
|  | } | 
|  | return s.placeholderConverter[idx] | 
|  | } | 
|  |  | 
|  | func (s *fakeStmt) Close() error { | 
|  | if s.panic == "Close" { | 
|  | panic(s.panic) | 
|  | } | 
|  | if s.c == nil { | 
|  | panic("nil conn in fakeStmt.Close") | 
|  | } | 
|  | if s.c.db == nil { | 
|  | panic("in fakeStmt.Close, conn's db is nil (already closed)") | 
|  | } | 
|  | s.touchMem() | 
|  | if !s.closed { | 
|  | s.c.incrStat(&s.c.stmtsClosed) | 
|  | s.closed = true | 
|  | } | 
|  | if s.next != nil { | 
|  | s.next.Close() | 
|  | } | 
|  | return nil | 
|  | } | 
|  |  | 
|  | var errClosed = errors.New("fakedb: statement has been closed") | 
|  |  | 
|  | // hook to simulate broken connections | 
|  | var hookExecBadConn func() bool | 
|  |  | 
|  | func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { | 
|  | panic("Using ExecContext") | 
|  | } | 
|  |  | 
|  | var errFakeConnSessionDirty = errors.New("fakedb: session is dirty") | 
|  |  | 
|  | func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { | 
|  | if s.panic == "Exec" { | 
|  | panic(s.panic) | 
|  | } | 
|  | if s.closed { | 
|  | return nil, errClosed | 
|  | } | 
|  |  | 
|  | if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) { | 
|  | return nil, fakeError{Message: "Exec: Sticky Bad", Wrapped: driver.ErrBadConn} | 
|  | } | 
|  | if s.c.isDirtyAndMark() { | 
|  | return nil, errFakeConnSessionDirty | 
|  | } | 
|  |  | 
|  | err := checkSubsetTypes(s.c.db.allowAny, args) | 
|  | if err != nil { | 
|  | return nil, err | 
|  | } | 
|  | s.touchMem() | 
|  |  | 
|  | if s.wait > 0 { | 
|  | time.Sleep(s.wait) | 
|  | } | 
|  |  | 
|  | select { | 
|  | default: | 
|  | case <-ctx.Done(): | 
|  | return nil, ctx.Err() | 
|  | } | 
|  |  | 
|  | db := s.c.db | 
|  | switch s.cmd { | 
|  | case "WIPE": | 
|  | db.wipe() | 
|  | return driver.ResultNoRows, nil | 
|  | case "CREATE": | 
|  | if err := db.createTable(s.table, s.colName, s.colType); err != nil { | 
|  | return nil, err | 
|  | } | 
|  | return driver.ResultNoRows, nil | 
|  | case "INSERT": | 
|  | return s.execInsert(args, true) | 
|  | case "NOSERT": | 
|  | // Do all the prep-work like for an INSERT but don't actually insert the row. | 
|  | // Used for some of the concurrent tests. | 
|  | return s.execInsert(args, false) | 
|  | } | 
|  | return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd) | 
|  | } | 
|  |  | 
|  | // When doInsert is true, add the row to the table. | 
|  | // When doInsert is false do prep-work and error checking, but don't | 
|  | // actually add the row to the table. | 
|  | func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) { | 
|  | db := s.c.db | 
|  | if len(args) != s.placeholders { | 
|  | panic("error in pkg db; should only get here if size is correct") | 
|  | } | 
|  | db.mu.Lock() | 
|  | t, ok := db.table(s.table) | 
|  | db.mu.Unlock() | 
|  | if !ok { | 
|  | return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) | 
|  | } | 
|  |  | 
|  | t.mu.Lock() | 
|  | defer t.mu.Unlock() | 
|  |  | 
|  | var cols []any | 
|  | if doInsert { | 
|  | cols = make([]any, len(t.colname)) | 
|  | } | 
|  | argPos := 0 | 
|  | for n, colname := range s.colName { | 
|  | colidx := t.columnIndex(colname) | 
|  | if colidx == -1 { | 
|  | return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname) | 
|  | } | 
|  | var val any | 
|  | if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") { | 
|  | if strvalue == "?" { | 
|  | val = args[argPos].Value | 
|  | } else { | 
|  | // Assign value from argument placeholder name. | 
|  | for _, a := range args { | 
|  | if a.Name == strvalue[1:] { | 
|  | val = a.Value | 
|  | break | 
|  | } | 
|  | } | 
|  | } | 
|  | argPos++ | 
|  | } else { | 
|  | val = s.colValue[n] | 
|  | } | 
|  | if doInsert { | 
|  | cols[colidx] = val | 
|  | } | 
|  | } | 
|  |  | 
|  | if doInsert { | 
|  | t.rows = append(t.rows, &row{cols: cols}) | 
|  | } | 
|  | return driver.RowsAffected(1), nil | 
|  | } | 
|  |  | 
|  | // hook to simulate broken connections | 
|  | var hookQueryBadConn func() bool | 
|  |  | 
|  | func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { | 
|  | panic("Use QueryContext") | 
|  | } | 
|  |  | 
|  | func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { | 
|  | if s.panic == "Query" { | 
|  | panic(s.panic) | 
|  | } | 
|  | if s.closed { | 
|  | return nil, errClosed | 
|  | } | 
|  |  | 
|  | if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) { | 
|  | return nil, fakeError{Message: "Query: Sticky Bad", Wrapped: driver.ErrBadConn} | 
|  | } | 
|  | if s.c.isDirtyAndMark() { | 
|  | return nil, errFakeConnSessionDirty | 
|  | } | 
|  |  | 
|  | err := checkSubsetTypes(s.c.db.allowAny, args) | 
|  | if err != nil { | 
|  | return nil, err | 
|  | } | 
|  |  | 
|  | s.touchMem() | 
|  | db := s.c.db | 
|  | if len(args) != s.placeholders { | 
|  | panic("error in pkg db; should only get here if size is correct") | 
|  | } | 
|  |  | 
|  | setMRows := make([][]*row, 0, 1) | 
|  | setColumns := make([][]string, 0, 1) | 
|  | setColType := make([][]string, 0, 1) | 
|  |  | 
|  | for { | 
|  | db.mu.Lock() | 
|  | t, ok := db.table(s.table) | 
|  | db.mu.Unlock() | 
|  | if !ok { | 
|  | return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) | 
|  | } | 
|  |  | 
|  | if s.table == "magicquery" { | 
|  | if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" { | 
|  | if args[0].Value == "sleep" { | 
|  | time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond) | 
|  | } | 
|  | } | 
|  | } | 
|  | if s.table == "tx_status" && s.colName[0] == "tx_status" { | 
|  | txStatus := "autocommit" | 
|  | if s.c.currTx != nil { | 
|  | txStatus = "transaction" | 
|  | } | 
|  | cursor := &rowsCursor{ | 
|  | parentMem: s.c, | 
|  | posRow:    -1, | 
|  | rows: [][]*row{ | 
|  | { | 
|  | { | 
|  | cols: []any{ | 
|  | txStatus, | 
|  | }, | 
|  | }, | 
|  | }, | 
|  | }, | 
|  | cols: [][]string{ | 
|  | { | 
|  | "tx_status", | 
|  | }, | 
|  | }, | 
|  | colType: [][]string{ | 
|  | { | 
|  | "string", | 
|  | }, | 
|  | }, | 
|  | errPos: -1, | 
|  | } | 
|  | return cursor, nil | 
|  | } | 
|  |  | 
|  | t.mu.Lock() | 
|  |  | 
|  | colIdx := make(map[string]int) // select column name -> column index in table | 
|  | for _, name := range s.colName { | 
|  | idx := t.columnIndex(name) | 
|  | if idx == -1 { | 
|  | t.mu.Unlock() | 
|  | return nil, fmt.Errorf("fakedb: unknown column name %q", name) | 
|  | } | 
|  | colIdx[name] = idx | 
|  | } | 
|  |  | 
|  | mrows := []*row{} | 
|  | rows: | 
|  | for _, trow := range t.rows { | 
|  | // Process the where clause, skipping non-match rows. This is lazy | 
|  | // and just uses fmt.Sprintf("%v") to test equality. Good enough | 
|  | // for test code. | 
|  | for _, wcol := range s.whereCol { | 
|  | idx := t.columnIndex(wcol.Column) | 
|  | if idx == -1 { | 
|  | t.mu.Unlock() | 
|  | return nil, fmt.Errorf("fakedb: invalid where clause column %q", wcol) | 
|  | } | 
|  | tcol := trow.cols[idx] | 
|  | if bs, ok := tcol.([]byte); ok { | 
|  | // lazy hack to avoid sprintf %v on a []byte | 
|  | tcol = string(bs) | 
|  | } | 
|  | var argValue any | 
|  | if wcol.Placeholder == "?" { | 
|  | argValue = args[wcol.Ordinal-1].Value | 
|  | } else { | 
|  | // Assign arg value from placeholder name. | 
|  | for _, a := range args { | 
|  | if a.Name == wcol.Placeholder[1:] { | 
|  | argValue = a.Value | 
|  | break | 
|  | } | 
|  | } | 
|  | } | 
|  | if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) { | 
|  | continue rows | 
|  | } | 
|  | } | 
|  | mrow := &row{cols: make([]any, len(s.colName))} | 
|  | for seli, name := range s.colName { | 
|  | mrow.cols[seli] = trow.cols[colIdx[name]] | 
|  | } | 
|  | mrows = append(mrows, mrow) | 
|  | } | 
|  |  | 
|  | var colType []string | 
|  | for _, column := range s.colName { | 
|  | colType = append(colType, t.coltype[t.columnIndex(column)]) | 
|  | } | 
|  |  | 
|  | t.mu.Unlock() | 
|  |  | 
|  | setMRows = append(setMRows, mrows) | 
|  | setColumns = append(setColumns, s.colName) | 
|  | setColType = append(setColType, colType) | 
|  |  | 
|  | if s.next == nil { | 
|  | break | 
|  | } | 
|  | s = s.next | 
|  | } | 
|  |  | 
|  | cursor := &rowsCursor{ | 
|  | parentMem: s.c, | 
|  | posRow:    -1, | 
|  | rows:      setMRows, | 
|  | cols:      setColumns, | 
|  | colType:   setColType, | 
|  | errPos:    -1, | 
|  | } | 
|  | return cursor, nil | 
|  | } | 
|  |  | 
|  | func (s *fakeStmt) NumInput() int { | 
|  | if s.panic == "NumInput" { | 
|  | panic(s.panic) | 
|  | } | 
|  | return s.placeholders | 
|  | } | 
|  |  | 
|  | // hook to simulate broken connections | 
|  | var hookCommitBadConn func() bool | 
|  |  | 
|  | func (tx *fakeTx) Commit() error { | 
|  | tx.c.currTx = nil | 
|  | if hookCommitBadConn != nil && hookCommitBadConn() { | 
|  | return fakeError{Message: "Commit: Hook Bad Conn", Wrapped: driver.ErrBadConn} | 
|  | } | 
|  | tx.c.touchMem() | 
|  | return nil | 
|  | } | 
|  |  | 
|  | // hook to simulate broken connections | 
|  | var hookRollbackBadConn func() bool | 
|  |  | 
|  | func (tx *fakeTx) Rollback() error { | 
|  | tx.c.currTx = nil | 
|  | if hookRollbackBadConn != nil && hookRollbackBadConn() { | 
|  | return fakeError{Message: "Rollback: Hook Bad Conn", Wrapped: driver.ErrBadConn} | 
|  | } | 
|  | tx.c.touchMem() | 
|  | return nil | 
|  | } | 
|  |  | 
|  | type rowsCursor struct { | 
|  | parentMem memToucher | 
|  | cols      [][]string | 
|  | colType   [][]string | 
|  | posSet    int | 
|  | posRow    int | 
|  | rows      [][]*row | 
|  | closed    bool | 
|  |  | 
|  | // errPos and err are for making Next return early with error. | 
|  | errPos int | 
|  | err    error | 
|  |  | 
|  | // a clone of slices to give out to clients, indexed by the | 
|  | // original slice's first byte address.  we clone them | 
|  | // just so we're able to corrupt them on close. | 
|  | bytesClone map[*byte][]byte | 
|  |  | 
|  | // Every operation writes to line to enable the race detector | 
|  | // check for data races. | 
|  | // This is separate from the fakeConn.line to allow for drivers that | 
|  | // can start multiple queries on the same transaction at the same time. | 
|  | line int64 | 
|  | } | 
|  |  | 
|  | func (rc *rowsCursor) touchMem() { | 
|  | rc.parentMem.touchMem() | 
|  | rc.line++ | 
|  | } | 
|  |  | 
|  | func (rc *rowsCursor) Close() error { | 
|  | rc.touchMem() | 
|  | rc.parentMem.touchMem() | 
|  | rc.closed = true | 
|  | return nil | 
|  | } | 
|  |  | 
|  | func (rc *rowsCursor) Columns() []string { | 
|  | return rc.cols[rc.posSet] | 
|  | } | 
|  |  | 
// ColumnTypeScanType returns the Go type suited to scanning column index of
// the current result set. The type is derived from the column's declared
// fakedb type name.
func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
	typeName := rc.colType[rc.posSet][index]
	return colTypeToReflectType(typeName)
}
|  |  | 
|  | var rowsCursorNextHook func(dest []driver.Value) error | 
|  |  | 
|  | func (rc *rowsCursor) Next(dest []driver.Value) error { | 
|  | if rowsCursorNextHook != nil { | 
|  | return rowsCursorNextHook(dest) | 
|  | } | 
|  |  | 
|  | if rc.closed { | 
|  | return errors.New("fakedb: cursor is closed") | 
|  | } | 
|  | rc.touchMem() | 
|  | rc.posRow++ | 
|  | if rc.posRow == rc.errPos { | 
|  | return rc.err | 
|  | } | 
|  | if rc.posRow >= len(rc.rows[rc.posSet]) { | 
|  | return io.EOF // per interface spec | 
|  | } | 
|  | for i, v := range rc.rows[rc.posSet][rc.posRow].cols { | 
|  | // TODO(bradfitz): convert to subset types? naah, I | 
|  | // think the subset types should only be input to | 
|  | // driver, but the sql package should be able to handle | 
|  | // a wider range of types coming out of drivers. all | 
|  | // for ease of drivers, and to prevent drivers from | 
|  | // messing up conversions or doing them differently. | 
|  | dest[i] = v | 
|  |  | 
|  | if bs, ok := v.([]byte); ok { | 
|  | if rc.bytesClone == nil { | 
|  | rc.bytesClone = make(map[*byte][]byte) | 
|  | } | 
|  | clone, ok := rc.bytesClone[&bs[0]] | 
|  | if !ok { | 
|  | clone = make([]byte, len(bs)) | 
|  | copy(clone, bs) | 
|  | rc.bytesClone[&bs[0]] = clone | 
|  | } | 
|  | dest[i] = clone | 
|  | } | 
|  | } | 
|  | return nil | 
|  | } | 
|  |  | 
|  | func (rc *rowsCursor) HasNextResultSet() bool { | 
|  | rc.touchMem() | 
|  | return rc.posSet < len(rc.rows)-1 | 
|  | } | 
|  |  | 
// NextResultSet advances to the following result set and resets the row
// position, so the next call to Next reads that set's first row. It returns
// io.EOF when no further result set exists.
func (rc *rowsCursor) NextResultSet() error {
	rc.touchMem()
	if !rc.HasNextResultSet() {
		return io.EOF // Per interface spec.
	}
	rc.posSet++
	rc.posRow = -1
	return nil
}
|  |  | 
|  | // fakeDriverString is like driver.String, but indirects pointers like | 
|  | // DefaultValueConverter. | 
|  | // | 
|  | // This could be surprising behavior to retroactively apply to | 
|  | // driver.String now that Go1 is out, but this is convenient for | 
|  | // our TestPointerParamsAndScans. | 
|  | // | 
|  | type fakeDriverString struct{} | 
|  |  | 
|  | func (fakeDriverString) ConvertValue(v any) (driver.Value, error) { | 
|  | switch c := v.(type) { | 
|  | case string, []byte: | 
|  | return v, nil | 
|  | case *string: | 
|  | if c == nil { | 
|  | return nil, nil | 
|  | } | 
|  | return *c, nil | 
|  | } | 
|  | return fmt.Sprintf("%v", v), nil | 
|  | } | 
|  |  | 
|  | type anyTypeConverter struct{} | 
|  |  | 
|  | func (anyTypeConverter) ConvertValue(v any) (driver.Value, error) { | 
|  | return v, nil | 
|  | } | 
|  |  | 
// converterForType returns the driver.ValueConverter applied to arguments
// bound to a fakedb column declared with type typ. The "null"-prefixed
// types wrap their converter in driver.Null so nil is accepted. Most plain
// types wrap theirs in driver.NotNull. It panics on an unknown type name.
func converterForType(typ string) driver.ValueConverter {
	switch typ {
	case "bool":
		return driver.Bool
	case "nullbool":
		return driver.Null{Converter: driver.Bool}
	case "int32":
		return driver.Int32
	case "string":
		return driver.NotNull{Converter: fakeDriverString{}}
	case "nullstring":
		return driver.Null{Converter: fakeDriverString{}}
	case "byte", "int16", "int64", "float64", "datetime":
		// TODO(coopernurse): add type-specific converters for int64/float64.
		return driver.NotNull{Converter: driver.DefaultParameterConverter}
	case "nullbyte", "nullint16", "nullint32", "nullint64", "nullfloat64", "nulldatetime":
		// TODO(coopernurse): add type-specific converters for nullint64/nullfloat64.
		return driver.Null{Converter: driver.DefaultParameterConverter}
	case "any":
		return anyTypeConverter{}
	}
	panic("invalid fakedb column type of " + typ)
}
|  |  | 
// colTypeToReflectType maps a fakedb column type name to the Go type a
// caller should scan that column into; it backs
// rowsCursor.ColumnTypeScanType. Each "null"-prefixed type maps to the
// matching Null* wrapper. Every type name accepted by converterForType is
// handled, so declaring a column type never leads to a panic here.
// It panics on an unknown type name.
func colTypeToReflectType(typ string) reflect.Type {
	switch typ {
	case "bool":
		return reflect.TypeOf(false)
	case "nullbool":
		return reflect.TypeOf(NullBool{})
	case "byte":
		return reflect.TypeOf(byte(0))
	case "nullbyte":
		return reflect.TypeOf(NullByte{})
	case "int16":
		return reflect.TypeOf(int16(0))
	case "nullint16":
		return reflect.TypeOf(NullInt16{})
	case "int32":
		return reflect.TypeOf(int32(0))
	case "nullint32":
		return reflect.TypeOf(NullInt32{})
	case "string":
		return reflect.TypeOf("")
	case "nullstring":
		return reflect.TypeOf(NullString{})
	case "int64":
		return reflect.TypeOf(int64(0))
	case "nullint64":
		return reflect.TypeOf(NullInt64{})
	case "float64":
		return reflect.TypeOf(float64(0))
	case "nullfloat64":
		return reflect.TypeOf(NullFloat64{})
	case "datetime":
		return reflect.TypeOf(time.Time{})
	case "nulldatetime":
		return reflect.TypeOf(NullTime{})
	case "any":
		// TypeOf(nil) would be nil; go through a pointer to get the
		// interface type itself.
		return reflect.TypeOf(new(any)).Elem()
	}
	panic("invalid fakedb column type of " + typ)
}