diff --git a/sqlx/scanner.go b/sqlx/scanner.go index 5b1b296..60adeff 100644 --- a/sqlx/scanner.go +++ b/sqlx/scanner.go @@ -15,6 +15,8 @@ package sqlx import ( + "bytes" + "database/sql" "errors" "fmt" "reflect" @@ -24,6 +26,7 @@ var ( ErrNoMoreRows = errors.New("ekit: 已读取完") errInvalidArgument = errors.New("ekit: 参数非法") _ Scanner = &sqlRowsScanner{} + bytesType = reflect.TypeOf([]byte("")) ) // Scanner 用于简化sql.Rows包中的Scan操作 @@ -54,6 +57,7 @@ func NewSQLRowsScanner(r Rows) (Scanner, error) { for i, columnType := range columnTypes { typ := columnType.ScanType() for typ.Kind() == reflect.Pointer { + // 兼容 sqlite,理论上来说其他 driver 不应该命中这个分支 typ = typ.Elem() } columnValuePointers[i] = reflect.New(typ).Interface() @@ -84,7 +88,12 @@ func (s *sqlRowsScanner) Scan() ([]any, error) { func (s *sqlRowsScanner) columnValues() []any { values := make([]any, len(s.columnValuePointers)) for i := 0; i < len(s.columnValuePointers); i++ { - values[i] = reflect.ValueOf(s.columnValuePointers[i]).Elem().Interface() + val := reflect.ValueOf(s.columnValuePointers[i]).Elem().Interface() + // sql.RawBytes 存在内存共享的问题,所以需要执行复制 + if rawBytes, ok := val.(sql.RawBytes); ok { + val = sql.RawBytes(bytes.Clone(rawBytes)) + } + values[i] = val } return values }