// Copyright 2016 The Xorm Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package xorm import ( "bufio" "database/sql" "fmt" "io" "os" "strings" "xorm.io/xorm/internal/utils" "xorm.io/xorm/schemas" ) // Ping test if database is ok func (session *Session) Ping() error { if session.isAutoClose { defer session.Close() } session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName()) return session.DB().PingContext(session.ctx) } // CreateTable create a table according a bean func (session *Session) CreateTable(bean interface{}) error { if session.isAutoClose { defer session.Close() } return session.createTable(bean) } func (session *Session) createTable(bean interface{}) error { if err := session.statement.SetRefBean(bean); err != nil { return err } sqlStrs := session.statement.GenCreateTableSQL() for _, s := range sqlStrs { _, err := session.exec(s) if err != nil { return err } } return nil } // CreateIndexes create indexes func (session *Session) CreateIndexes(bean interface{}) error { if session.isAutoClose { defer session.Close() } return session.createIndexes(bean) } func (session *Session) createIndexes(bean interface{}) error { if err := session.statement.SetRefBean(bean); err != nil { return err } sqls := session.statement.GenIndexSQL() for _, sqlStr := range sqls { _, err := session.exec(sqlStr) if err != nil { return err } } return nil } // CreateUniques create uniques func (session *Session) CreateUniques(bean interface{}) error { if session.isAutoClose { defer session.Close() } return session.createUniques(bean) } func (session *Session) createUniques(bean interface{}) error { if err := session.statement.SetRefBean(bean); err != nil { return err } sqls := session.statement.GenUniqueSQL() for _, sqlStr := range sqls { _, err := session.exec(sqlStr) if err != nil { return err } } return nil } // DropIndexes drop indexes func (session *Session) DropIndexes(bean interface{}) error { if session.isAutoClose { defer session.Close() } return session.dropIndexes(bean) } func (session *Session) dropIndexes(bean interface{}) error { if err := session.statement.SetRefBean(bean); err != nil { return err } sqls := session.statement.GenDelIndexSQL() for _, sqlStr := range sqls { _, err := session.exec(sqlStr) if err != nil { return err } } return nil } // DropTable drop table will drop table if exist, if drop failed, it will return error func (session *Session) DropTable(beanOrTableName interface{}) error { if session.isAutoClose { defer session.Close() } return session.dropTable(beanOrTableName) } func (session *Session) dropTable(beanOrTableName interface{}) error { tableName := session.engine.TableName(beanOrTableName) sqlStr, checkIfExist := session.engine.dialect.DropTableSQL(session.engine.TableName(tableName, true)) if !checkIfExist { exist, err := session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName) if err != nil { return err } checkIfExist = exist } if checkIfExist { _, err := session.exec(sqlStr) return err } return nil } // IsTableExist if a table is exist func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error) { if session.isAutoClose { defer session.Close() } tableName := session.engine.TableName(beanOrTableName) return session.isTableExist(tableName) } func (session *Session) isTableExist(tableName string) (bool, error) { return session.engine.dialect.IsTableExist(session.getQueryer(), session.ctx, tableName) } // IsTableEmpty if table have any records func (session *Session) IsTableEmpty(bean interface{}) (bool, error) { if session.isAutoClose { defer session.Close() } return session.isTableEmpty(session.engine.TableName(bean)) } func (session *Session) isTableEmpty(tableName string) (bool, error) { var total int64 sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(session.engine.TableName(tableName, true))) err := session.queryRow(sqlStr).Scan(&total) if err != nil { if err == sql.ErrNoRows { err = nil } return true, err } return total == 0, nil } // find if index is exist according cols func (session *Session) isIndexExist2(tableName string, cols []string, unique bool) (bool, error) { indexes, err := session.engine.dialect.GetIndexes(session.getQueryer(), session.ctx, tableName) if err != nil { return false, err } for _, index := range indexes { if utils.SliceEq(index.Cols, cols) { if unique { return index.Type == schemas.UniqueType, nil } return index.Type == schemas.IndexType, nil } } return false, nil } func (session *Session) addColumn(colName string) error { col := session.statement.RefTable.GetColumn(colName) sql := session.engine.dialect.AddColumnSQL(session.statement.TableName(), col) _, err := session.exec(sql) return err } func (session *Session) addIndex(tableName, idxName string) error { index := session.statement.RefTable.Indexes[idxName] sqlStr := session.engine.dialect.CreateIndexSQL(tableName, index) _, err := session.exec(sqlStr) return err } func (session *Session) addUnique(tableName, uqeName string) error { index := session.statement.RefTable.Indexes[uqeName] sqlStr := session.engine.dialect.CreateIndexSQL(tableName, index) _, err := session.exec(sqlStr) return err } // Sync2 synchronize structs to database tables func (session *Session) Sync2(beans ...interface{}) error { engine := session.engine if session.isAutoClose { session.isAutoClose = false defer session.Close() } tables, err := engine.dialect.GetTables(session.getQueryer(), session.ctx) if err != nil { return err } session.autoResetStatement = false defer func() { session.autoResetStatement = true session.resetStatement() }() for _, bean := range beans { v := utils.ReflectValue(bean) table, err := engine.tagParser.ParseWithCache(v) if err != nil { return err } var tbName string if len(session.statement.AltTableName) > 0 { tbName = session.statement.AltTableName } else { tbName = engine.TableName(bean) } tbNameWithSchema := engine.tbNameWithSchema(tbName) var oriTable *schemas.Table for _, tb := range tables { if strings.EqualFold(engine.tbNameWithSchema(tb.Name), engine.tbNameWithSchema(tbName)) { oriTable = tb break } } // this is a new table if oriTable == nil { err = session.StoreEngine(session.statement.StoreEngine).createTable(bean) if err != nil { return err } err = session.createUniques(bean) if err != nil { return err } err = session.createIndexes(bean) if err != nil { return err } continue } // this will modify an old table if err = engine.loadTableInfo(oriTable); err != nil { return err } // check columns for _, col := range table.Columns() { var oriCol *schemas.Column for _, col2 := range oriTable.Columns() { if strings.EqualFold(col.Name, col2.Name) { oriCol = col2 break } } // column is not exist on table if oriCol == nil { session.statement.RefTable = table session.statement.SetTableName(tbNameWithSchema) if err = session.addColumn(col.Name); err != nil { return err } continue } err = nil expectedType := engine.dialect.SQLType(col) curType := engine.dialect.SQLType(oriCol) if expectedType != curType { if expectedType == schemas.Text && strings.HasPrefix(curType, schemas.Varchar) { // currently only support mysql & postgres if engine.dialect.URI().DBType == schemas.MYSQL || engine.dialect.URI().DBType == schemas.POSTGRES { engine.logger.Infof("Table %s column %s change type from %s to %s\n", tbNameWithSchema, col.Name, curType, expectedType) _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) } else { engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n", tbNameWithSchema, col.Name, curType, expectedType) } } else if strings.HasPrefix(curType, schemas.Varchar) && strings.HasPrefix(expectedType, schemas.Varchar) { if engine.dialect.URI().DBType == schemas.MYSQL { if oriCol.Length < col.Length { engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", tbNameWithSchema, col.Name, oriCol.Length, col.Length) _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) } } } else { if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') { engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s", tbNameWithSchema, col.Name, curType, expectedType) } } } else if expectedType == schemas.Varchar { if engine.dialect.URI().DBType == schemas.MYSQL { if oriCol.Length < col.Length { engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n", tbNameWithSchema, col.Name, oriCol.Length, col.Length) _, err = session.exec(engine.dialect.ModifyColumnSQL(tbNameWithSchema, col)) } } } if col.Default != oriCol.Default { switch { case col.IsAutoIncrement: // For autoincrement column, don't check default case (col.SQLType.Name == schemas.Bool || col.SQLType.Name == schemas.Boolean) && ((strings.EqualFold(col.Default, "true") && oriCol.Default == "1") || (strings.EqualFold(col.Default, "false") && oriCol.Default == "0")): default: engine.logger.Warnf("Table %s Column %s db default is %s, struct default is %s", tbName, col.Name, oriCol.Default, col.Default) } } if col.Nullable != oriCol.Nullable { engine.logger.Warnf("Table %s Column %s db nullable is %v, struct nullable is %v", tbName, col.Name, oriCol.Nullable, col.Nullable) } if err != nil { return err } } var foundIndexNames = make(map[string]bool) var addedNames = make(map[string]*schemas.Index) for name, index := range table.Indexes { var oriIndex *schemas.Index for name2, index2 := range oriTable.Indexes { if index.Equal(index2) { oriIndex = index2 foundIndexNames[name2] = true break } } if oriIndex != nil { if oriIndex.Type != index.Type { sql := engine.dialect.DropIndexSQL(tbNameWithSchema, oriIndex) _, err = session.exec(sql) if err != nil { return err } oriIndex = nil } } if oriIndex == nil { addedNames[name] = index } } for name2, index2 := range oriTable.Indexes { if _, ok := foundIndexNames[name2]; !ok { sql := engine.dialect.DropIndexSQL(tbNameWithSchema, index2) _, err = session.exec(sql) if err != nil { return err } } } for name, index := range addedNames { if index.Type == schemas.UniqueType { session.statement.RefTable = table session.statement.SetTableName(tbNameWithSchema) err = session.addUnique(tbNameWithSchema, name) } else if index.Type == schemas.IndexType { session.statement.RefTable = table session.statement.SetTableName(tbNameWithSchema) err = session.addIndex(tbNameWithSchema, name) } if err != nil { return err } } // check all the columns which removed from struct fields but left on database tables. for _, colName := range oriTable.ColumnsSeq() { if table.GetColumn(colName) == nil { engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableName(oriTable.Name, true), colName) } } } return nil } // ImportFile SQL DDL file func (session *Session) ImportFile(ddlPath string) ([]sql.Result, error) { file, err := os.Open(ddlPath) if err != nil { return nil, err } defer file.Close() return session.Import(file) } // Import SQL DDL from io.Reader func (session *Session) Import(r io.Reader) ([]sql.Result, error) { var results []sql.Result var lastError error scanner := bufio.NewScanner(r) var inSingleQuote bool semiColSpliter := func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { return 0, nil, nil } for i, b := range data { if b == '\'' { inSingleQuote = !inSingleQuote } if !inSingleQuote && b == ';' { return i + 1, data[0:i], nil } } // If we're at EOF, we have a final, non-terminated line. Return it. if atEOF { return len(data), data, nil } // Request more data. return 0, nil, nil } scanner.Split(semiColSpliter) for scanner.Scan() { query := strings.Trim(scanner.Text(), " \t\n\r") if len(query) > 0 { result, err := session.Exec(query) results = append(results, result) if err != nil { return nil, err } } } return results, lastError }