package mssql import ( "bytes" "context" "encoding/binary" "fmt" "math" "reflect" "strings" "time" "github.com/denisenkom/go-mssqldb/internal/decimal" ) type Bulk struct { // ctx is used only for AddRow and Done methods. // This could be removed if AddRow and Done accepted // a ctx field as well, which is available with the // database/sql call. ctx context.Context cn *Conn metadata []columnStruct bulkColumns []columnStruct columnsName []string tablename string numRows int headerSent bool Options BulkOptions Debug bool } type BulkOptions struct { CheckConstraints bool FireTriggers bool KeepNulls bool KilobytesPerBatch int RowsPerBatch int Order []string Tablock bool } type DataValue interface{} const ( sqlDateFormat = "2006-01-02" sqlTimeFormat = "2006-01-02 15:04:05.999999999Z07:00" ) func (cn *Conn) CreateBulk(table string, columns []string) (_ *Bulk) { b := Bulk{ctx: context.Background(), cn: cn, tablename: table, headerSent: false, columnsName: columns} b.Debug = false return &b } func (cn *Conn) CreateBulkContext(ctx context.Context, table string, columns []string) (_ *Bulk) { b := Bulk{ctx: ctx, cn: cn, tablename: table, headerSent: false, columnsName: columns} b.Debug = false return &b } func (b *Bulk) sendBulkCommand(ctx context.Context) (err error) { //get table columns info err = b.getMetadata(ctx) if err != nil { return err } //match the columns for _, colname := range b.columnsName { var bulkCol *columnStruct for _, m := range b.metadata { if m.ColName == colname { bulkCol = &m break } } if bulkCol != nil { if bulkCol.ti.TypeId == typeUdt { //send udt as binary bulkCol.ti.TypeId = typeBigVarBin } b.bulkColumns = append(b.bulkColumns, *bulkCol) b.dlogf("Adding column %s %s %#x", colname, bulkCol.ColName, bulkCol.ti.TypeId) } else { return fmt.Errorf("Column %s does not exist in destination table %s", colname, b.tablename) } } //create the bulk command //columns definitions var col_defs bytes.Buffer for i, col := range b.bulkColumns { if i != 0 { col_defs.WriteString(", ") } col_defs.WriteString("[" + col.ColName + "] " + makeDecl(col.ti)) } //options var with_opts []string if b.Options.CheckConstraints { with_opts = append(with_opts, "CHECK_CONSTRAINTS") } if b.Options.FireTriggers { with_opts = append(with_opts, "FIRE_TRIGGERS") } if b.Options.KeepNulls { with_opts = append(with_opts, "KEEP_NULLS") } if b.Options.KilobytesPerBatch > 0 { with_opts = append(with_opts, fmt.Sprintf("KILOBYTES_PER_BATCH = %d", b.Options.KilobytesPerBatch)) } if b.Options.RowsPerBatch > 0 { with_opts = append(with_opts, fmt.Sprintf("ROWS_PER_BATCH = %d", b.Options.RowsPerBatch)) } if len(b.Options.Order) > 0 { with_opts = append(with_opts, fmt.Sprintf("ORDER(%s)", strings.Join(b.Options.Order, ","))) } if b.Options.Tablock { with_opts = append(with_opts, "TABLOCK") } var with_part string if len(with_opts) > 0 { with_part = fmt.Sprintf("WITH (%s)", strings.Join(with_opts, ",")) } query := fmt.Sprintf("INSERT BULK %s (%s) %s", b.tablename, col_defs.String(), with_part) stmt, err := b.cn.PrepareContext(ctx, query) if err != nil { return fmt.Errorf("Prepare failed: %s", err.Error()) } b.dlogf(query) _, err = stmt.(*Stmt).ExecContext(ctx, nil) if err != nil { return err } b.headerSent = true var buf = b.cn.sess.buf buf.BeginPacket(packBulkLoadBCP, false) // Send the columns metadata. columnMetadata := b.createColMetadata() _, err = buf.Write(columnMetadata) return } // AddRow immediately writes the row to the destination table. // The arguments are the row values in the order they were specified. func (b *Bulk) AddRow(row []interface{}) (err error) { if !b.headerSent { err = b.sendBulkCommand(b.ctx) if err != nil { return } } if len(row) != len(b.bulkColumns) { return fmt.Errorf("Row does not have the same number of columns than the destination table %d %d", len(row), len(b.bulkColumns)) } bytes, err := b.makeRowData(row) if err != nil { return } _, err = b.cn.sess.buf.Write(bytes) if err != nil { return } b.numRows = b.numRows + 1 return } func (b *Bulk) makeRowData(row []interface{}) ([]byte, error) { buf := new(bytes.Buffer) buf.WriteByte(byte(tokenRow)) var logcol bytes.Buffer for i, col := range b.bulkColumns { if b.Debug { logcol.WriteString(fmt.Sprintf(" col[%d]='%v' ", i, row[i])) } param, err := b.makeParam(row[i], col) if err != nil { return nil, fmt.Errorf("bulkcopy: %s", err.Error()) } if col.ti.Writer == nil { return nil, fmt.Errorf("no writer for column: %s, TypeId: %#x", col.ColName, col.ti.TypeId) } err = col.ti.Writer(buf, param.ti, param.buffer) if err != nil { return nil, fmt.Errorf("bulkcopy: %s", err.Error()) } } b.dlogf("row[%d] %s\n", b.numRows, logcol.String()) return buf.Bytes(), nil } func (b *Bulk) Done() (rowcount int64, err error) { if b.headerSent == false { //no rows had been sent return 0, nil } var buf = b.cn.sess.buf buf.WriteByte(byte(tokenDone)) binary.Write(buf, binary.LittleEndian, uint16(doneFinal)) binary.Write(buf, binary.LittleEndian, uint16(0)) // curcmd if b.cn.sess.loginAck.TDSVersion >= verTDS72 { binary.Write(buf, binary.LittleEndian, uint64(0)) //rowcount 0 } else { binary.Write(buf, binary.LittleEndian, uint32(0)) //rowcount 0 } buf.FinishPacket() tokchan := make(chan tokenStruct, 5) go processResponse(b.ctx, b.cn.sess, tokchan, nil) var rowCount int64 for token := range tokchan { switch token := token.(type) { case doneStruct: if token.Status&doneCount != 0 { rowCount = int64(token.RowCount) } if token.isError() { return 0, token.getError() } case error: return 0, b.cn.checkBadConn(token) } } return rowCount, nil } func (b *Bulk) createColMetadata() []byte { buf := new(bytes.Buffer) buf.WriteByte(byte(tokenColMetadata)) // token binary.Write(buf, binary.LittleEndian, uint16(len(b.bulkColumns))) // column count for i, col := range b.bulkColumns { if b.cn.sess.loginAck.TDSVersion >= verTDS72 { binary.Write(buf, binary.LittleEndian, uint32(col.UserType)) // usertype, always 0? } else { binary.Write(buf, binary.LittleEndian, uint16(col.UserType)) } binary.Write(buf, binary.LittleEndian, uint16(col.Flags)) writeTypeInfo(buf, &b.bulkColumns[i].ti) if col.ti.TypeId == typeNText || col.ti.TypeId == typeText || col.ti.TypeId == typeImage { tablename_ucs2 := str2ucs2(b.tablename) binary.Write(buf, binary.LittleEndian, uint16(len(tablename_ucs2)/2)) buf.Write(tablename_ucs2) } colname_ucs2 := str2ucs2(col.ColName) buf.WriteByte(uint8(len(colname_ucs2) / 2)) buf.Write(colname_ucs2) } return buf.Bytes() } func (b *Bulk) getMetadata(ctx context.Context) (err error) { stmt, err := b.cn.prepareContext(ctx, "SET FMTONLY ON") if err != nil { return } _, err = stmt.ExecContext(ctx, nil) if err != nil { return } // Get columns info. stmt, err = b.cn.prepareContext(ctx, fmt.Sprintf("select * from %s SET FMTONLY OFF", b.tablename)) if err != nil { return } rows, err := stmt.QueryContext(ctx, nil) if err != nil { return fmt.Errorf("get columns info failed: %v", err) } b.metadata = rows.(*Rows).cols if b.Debug { for _, col := range b.metadata { b.dlogf("col: %s typeId: %#x size: %d scale: %d prec: %d flags: %d lcid: %#x\n", col.ColName, col.ti.TypeId, col.ti.Size, col.ti.Scale, col.ti.Prec, col.Flags, col.ti.Collation.LcidAndFlags) } } return rows.Close() } func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error) { res.ti.Size = col.ti.Size res.ti.TypeId = col.ti.TypeId if val == nil { res.ti.Size = 0 return } switch col.ti.TypeId { case typeInt1, typeInt2, typeInt4, typeInt8, typeIntN: var intvalue int64 switch val := val.(type) { case int: intvalue = int64(val) case int32: intvalue = int64(val) case int64: intvalue = val default: err = fmt.Errorf("mssql: invalid type for int column: %T", val) return } res.buffer = make([]byte, res.ti.Size) if col.ti.Size == 1 { res.buffer[0] = byte(intvalue) } else if col.ti.Size == 2 { binary.LittleEndian.PutUint16(res.buffer, uint16(intvalue)) } else if col.ti.Size == 4 { binary.LittleEndian.PutUint32(res.buffer, uint32(intvalue)) } else if col.ti.Size == 8 { binary.LittleEndian.PutUint64(res.buffer, uint64(intvalue)) } case typeFlt4, typeFlt8, typeFltN: var floatvalue float64 switch val := val.(type) { case float32: floatvalue = float64(val) case float64: floatvalue = val case int: floatvalue = float64(val) case int64: floatvalue = float64(val) default: err = fmt.Errorf("mssql: invalid type for float column: %T %s", val, val) return } if col.ti.Size == 4 { res.buffer = make([]byte, 4) binary.LittleEndian.PutUint32(res.buffer, math.Float32bits(float32(floatvalue))) } else if col.ti.Size == 8 { res.buffer = make([]byte, 8) binary.LittleEndian.PutUint64(res.buffer, math.Float64bits(floatvalue)) } case typeNVarChar, typeNText, typeNChar: switch val := val.(type) { case string: res.buffer = str2ucs2(val) case []byte: res.buffer = val default: err = fmt.Errorf("mssql: invalid type for nvarchar column: %T %s", val, val) return } res.ti.Size = len(res.buffer) case typeVarChar, typeBigVarChar, typeText, typeChar, typeBigChar: switch val := val.(type) { case string: res.buffer = []byte(val) case []byte: res.buffer = val default: err = fmt.Errorf("mssql: invalid type for varchar column: %T %s", val, val) return } res.ti.Size = len(res.buffer) case typeBit, typeBitN: if reflect.TypeOf(val).Kind() != reflect.Bool { err = fmt.Errorf("mssql: invalid type for bit column: %T %s", val, val) return } res.ti.TypeId = typeBitN res.ti.Size = 1 res.buffer = make([]byte, 1) if val.(bool) { res.buffer[0] = 1 } case typeDateTime2N: switch val := val.(type) { case time.Time: res.buffer = encodeDateTime2(val, int(col.ti.Scale)) res.ti.Size = len(res.buffer) case string: var t time.Time if t, err = time.Parse(sqlTimeFormat, val); err != nil { return res, fmt.Errorf("bulk: unable to convert string to date: %v", err) } res.buffer = encodeDateTime2(t, int(col.ti.Scale)) res.ti.Size = len(res.buffer) default: err = fmt.Errorf("mssql: invalid type for datetime2 column: %T %s", val, val) return } case typeDateTimeOffsetN: switch val := val.(type) { case time.Time: res.buffer = encodeDateTimeOffset(val, int(col.ti.Scale)) res.ti.Size = len(res.buffer) case string: var t time.Time if t, err = time.Parse(sqlTimeFormat, val); err != nil { return res, fmt.Errorf("bulk: unable to convert string to date: %v", err) } res.buffer = encodeDateTimeOffset(t, int(col.ti.Scale)) res.ti.Size = len(res.buffer) default: err = fmt.Errorf("mssql: invalid type for datetimeoffset column: %T %s", val, val) return } case typeDateN: switch val := val.(type) { case time.Time: res.buffer = encodeDate(val) res.ti.Size = len(res.buffer) case string: var t time.Time if t, err = time.ParseInLocation(sqlDateFormat, val, time.UTC); err != nil { return res, fmt.Errorf("bulk: unable to convert string to date: %v", err) } res.buffer = encodeDate(t) res.ti.Size = len(res.buffer) default: err = fmt.Errorf("mssql: invalid type for date column: %T %s", val, val) return } case typeDateTime, typeDateTimeN, typeDateTim4: var t time.Time switch val := val.(type) { case time.Time: t = val case string: if t, err = time.Parse(sqlTimeFormat, val); err != nil { return res, fmt.Errorf("bulk: unable to convert string to date: %v", err) } default: err = fmt.Errorf("mssql: invalid type for datetime column: %T %s", val, val) return } if col.ti.Size == 4 { res.buffer = encodeDateTim4(t) res.ti.Size = len(res.buffer) } else if col.ti.Size == 8 { res.buffer = encodeDateTime(t) res.ti.Size = len(res.buffer) } else { err = fmt.Errorf("mssql: invalid size of column %d", col.ti.Size) } // case typeMoney, typeMoney4, typeMoneyN: case typeDecimal, typeDecimalN, typeNumeric, typeNumericN: prec := col.ti.Prec scale := col.ti.Scale var dec decimal.Decimal switch v := val.(type) { case int: dec = decimal.Int64ToDecimalScale(int64(v), 0) case int8: dec = decimal.Int64ToDecimalScale(int64(v), 0) case int16: dec = decimal.Int64ToDecimalScale(int64(v), 0) case int32: dec = decimal.Int64ToDecimalScale(int64(v), 0) case int64: dec = decimal.Int64ToDecimalScale(int64(v), 0) case float32: dec, err = decimal.Float64ToDecimalScale(float64(v), scale) case float64: dec, err = decimal.Float64ToDecimalScale(float64(v), scale) case string: dec, err = decimal.StringToDecimalScale(v, scale) default: return res, fmt.Errorf("unknown value for decimal: %T %#v", v, v) } if err != nil { return res, err } dec.SetPrec(prec) var length byte switch { case prec <= 9: length = 4 case prec <= 19: length = 8 case prec <= 28: length = 12 default: length = 16 } buf := make([]byte, length+1) // first byte length written by typeInfo.writer res.ti.Size = int(length) + 1 // second byte sign if !dec.IsPositive() { buf[0] = 0 } else { buf[0] = 1 } ub := dec.UnscaledBytes() l := len(ub) if l > int(length) { err = fmt.Errorf("decimal out of range: %s", dec) return res, err } // reverse the bytes for i, j := 1, l-1; j >= 0; i, j = i+1, j-1 { buf[i] = ub[j] } res.buffer = buf case typeBigVarBin, typeBigBinary: switch val := val.(type) { case []byte: res.ti.Size = len(val) res.buffer = val default: err = fmt.Errorf("mssql: invalid type for Binary column: %T %s", val, val) return } case typeGuid: switch val := val.(type) { case []byte: res.ti.Size = len(val) res.buffer = val default: err = fmt.Errorf("mssql: invalid type for Guid column: %T %s", val, val) return } default: err = fmt.Errorf("mssql: type %x not implemented", col.ti.TypeId) } return } func (b *Bulk) dlogf(format string, v ...interface{}) { if b.Debug { b.cn.sess.log.Printf(format, v...) } }