javaer to go之简单的ORM封装

身为一个做企业级开发的javaer,习惯使用hibernate、ibatis等ORM框架来操作数据库。虽然也发现golang也有ORM框架,像beego ORM等。

为了熟悉golang的一些特性,我还是觉得自己封装一个ORM。

1、struct与interface简单说明

golang是一门面向过程的语言,所以它本身是没有像java那样的类与对象的概念。但golang中提供了struct与interface。甚至通过struct与interface结合,可以模拟类与对象的各种方式。

什么是interface,golang的interface与java的interface是不是一回事呢?

简单的说,interface是一组method的组合,我们通过interface来定义对象的一组行为。

一个类型如果拥有一接口的所有方法,那这个类型实现了这一接口。而不是像java那样使用implements关键字。

所有类型都实现了一个空接口,也就是我们可以通过空接口类型传入(返回)所有类型参数。类似java的Object是所有类的父类一样,空接口interface{}是golang中所有类型的父接口。

  • 简单的struct
type Object struct {
    Id              int       
    name          string
    Age   int  
}
  • 简单的interface
type IObject interface {
    SetName(name string)
    GetName() string
}
  • interface的实现
func (this *Object) SetName(name string)() {
    this.name = name
}

func (this *Object) GetName() string {
    return this.name
}

2、反射机制简单说明

与java一样,golang也提供了反射工具。

golang的reflect包主要有两个数据类型Type和Value。Type就是定义的类型的一个数据类型,Value是值的类型。
golang的反射就是把类型变量(对象),转成Type或Value,然后在运行时,对其操作的过程。

reflect的TypeOf与ValueOf可以把任意类型对象转成转成Type或Value。

// TypeOf returns the reflection Type that represents the dynamic type of i.
// If i is a nil interface value, TypeOf returns nil.
func TypeOf(i interface{}) Type {
    eface := *(*emptyInterface)(unsafe.Pointer(&i))
    return toType(eface.typ)
}
// ValueOf returns a new Value initialized to the concrete value
// stored in the interface i.  ValueOf(nil) returns the zero Value.
func ValueOf(i interface{}) Value {
    if i == nil {
        return Value{}
    }

    // TODO: Maybe allow contents of a Value to live on the stack.
    // For now we make the contents always escape to the heap.  It
    // makes life easier in a few places (see chanrecv/mapassign
    // comment below).
    escapes(i)

    return unpackEface(i)
}

3、struct实现Entity与Annotation

Java的ORM实现,如Hibernate,是通过xml配置文件或者annotation来把对象与表关联起来的。

当然,golang也可以使用xml的方式来做关联。但xml配置文件使用起来比较麻烦,但golang却没有annotation这样的东西。

还好golang的struct中,除了变量名和类型之外,还可以选择性的增加一些tag:tag可以在类型的后面,用双引号(double quote)或重音(backquote/grave accent)表示的字符串。这些符号能被用来做文档或重要的标签。

tag里面的内容在正常编程中没有作用。只有在使用反射的时候才有作用。我们可以在运行时,对tag进行解析,以达到类似注解说明的效果。

我们的Entity:

type Entity struct {
    Id   int       `table:"table_name" column:"id"`
    Time time.Time `column:"time"`
    Name string    `column:"day"`
    Age  string    `column:"age"`
}

上面的struct中,类型后面的就是tag。与Hibernate一样,我们的ORM希望我们的表都有主键。这里写的ORM,只支持简单的主键。

  • tag中的column对应的是表的列
  • tag中的table描述用于主键字段,它即代表着主键,也说明了实体映射的表。

有了tag的描述后,我们就可以把ORM的O与R mapping起来了。

4、单例模式的数据库连接

var sqlDB *sql.DB

func Open() *sql.DB {
    if sqlDB == nil {
        db, err := sql.Open("mysql", "用户名:密码@/数据库名?charset=utf8")
        if err != nil {
            fmt.Println(err.Error())
        }
        sqlDB = db
    }

    return sqlDB
}

5、数据库操作基类(BaseDao)的定义

type IBaseDao interface {
    Init()
    Save(data interface{}) error
    Update(data interface{}) error
    SaveOrUpdate(data interface{}) error
    SaveAll(datas list.List) error
    Delete(data interface{}) error
    Find(sql string) (*list.List, error)
    FindOne(sql string) (interface{}, error)
}

type BaseDao struct {
    EntityType reflect.Type
    sqlDB      *sql.DB
    tableName     string            //表名
    pk            string            //主键
    columnToField map[string]string //字段名:属性名
    fieldToColumn map[string]string //属性名:字段名
}

IBaseDao接口中,Init是接口实现时的初始化。除了Init外都是数据库的DML操作,就是所谓的增删改查。

BaseDao的EntityType属性类似Java中的泛型对象(如下面的clazz),用于获取操作实体对象的属性与标签。

public class BaseDao<T,PK extends Serializable> extends HibernateDaoSupport implements IBaseDao<T,PK> {

    private Class<T> clazz;

    public BaseDao() {
        ParameterizedType type = (ParameterizedType) this.getClass().getGenericSuperclass();
        clazz = (Class<T>) type.getActualTypeArguments()[0];
    }


    //...
}

6、BaseDao的初始化过程

BaseDao的初始化过程,就是通过解析struct的tag,来把struct的变量与表的列一一对应起来的过程。

  1. 我们可以在后面的使用中,通过columnToField键值对,获取列对应的变量名
  2. 通过fieldToColumn键值对,获取变量名对应的列
  3. 通过pk变量获取主键的列
//初始化
func (this *BaseDao) Init() {
    this.columnToField = make(map[string]string)
    this.fieldToColumn = make(map[string]string)

    types := this.EntityType

    for i := 0; i < types.NumField(); i++ {
        typ := types.Field(i)
        tag := typ.Tag

        if len(tag) > 0 {
            column := tag.Get("column")
            name := typ.Name
            this.columnToField[column] = name
            this.fieldToColumn[name] = column

            if len(tag.Get("table")) > 0 {
                this.tableName = tag.Get("table")
                this.pk = column
            }
        }
    }
}

7、增加

  • 预处理的insert语句封装以及占位符代表的列(有序)
func (this *BaseDao) insertPrepareSQL() (fieldNames list.List, sql string) {
    names := new(bytes.Buffer)
    values := new(bytes.Buffer)

    i := 0

    for column, fieldName := range this.columnToField {

        if i != 0 {
            names.WriteString(",")
            values.WriteString(",")
        }
        fieldNames.PushBack(fieldName)
        names.WriteString(column)
        values.WriteString("?")
        i++
    }
    sql = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", this.tableName, names.String(), values.String())
    return
}
  • 占位符的数据(有序)
//预处理占位符的数据
func (this *BaseDao) prepareValues(data interface{}, fieldNames list.List) []interface{} {
    values := make([]interface{}, len(this.columnToField))
    object := reflect.ValueOf(data).Elem()

    i := 0
    for e := fieldNames.Front(); e != nil; e = e.Next() {
        name := e.Value.(string)
        field := object.FieldByName(name)
        values[i] = this.fieldValue(field)
        i++

    }

    return values
}

//reflect.Value获取值
func (this *BaseDao) fieldValue(v reflect.Value) interface{} {
    if !v.IsValid() {
        return nil
    }

    switch v.Kind() {
    case reflect.String:
        return v.String()
    case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
        return v.Uint()
    case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
        return v.Int()
    case reflect.Float32, reflect.Float64:
        return v.Float()
    case reflect.Struct:
        switch v.Type().String() {
        case "time.Time":
            m := v.MethodByName("Format")
            rets := m.Call([]reflect.Value{reflect.ValueOf(timeFormate)})
            t := rets[0].String()
            return t
        default:
            return nil
        }
    default:
        return nil
    }
}
  • 增加记录
//增加单个记录
func (this *BaseDao) Save(data interface{}) error {
    columns, sql := this.insertPrepareSQL()

    stmt, err := Open().Prepare(sql)
    args := this.prepareValues(data, columns)
    fmt.Println(sql, " ", args)
    _, err = stmt.Exec(args...)
    if err != nil {
        panic(err.Error())
    }
    return err
}

//增加多个记录
func (this *BaseDao) SaveAll(datas list.List) error {
    if datas.Len() == 0 {
        return nil
    }
    columns, sql := this.insertPrepareSQL()

    stmt, err := Open().Prepare(sql)
    if err != nil {
        panic(err.Error())
    }

    for e := datas.Front(); e != nil; e = e.Next() {
        args := this.prepareValues(e.Value, columns)
        fmt.Println(sql, " ", args)
        _, err = stmt.Exec(args...)
        if err != nil {
            panic(err.Error())
        }
    }

    return err
}

8、修改

  • 预处理的update语句封装以及占位符代表的列
//实体转update sql语句
func (this *BaseDao) updatePrepareSQL() (fieldNames list.List, sql string) {
    //UPDATE 表名称 SET 列名称 = 新值 WHERE 列名称 = 某值
    sets := new(bytes.Buffer)

    i := 0

    for column, fieldName := range this.columnToField {
        if strings.EqualFold(column, this.pk) {
            continue
        }
        if i != 0 {
            sets.WriteString(",")
        }

        fieldNames.PushBack(fieldName)
        sets.WriteString(column)
        sets.WriteString("=?")

        i++
    }
    fieldNames.PushBack(this.columnToField[this.pk])
    sql = fmt.Sprintf("UPDATE %s SET %s WHERE %s=?", this.tableName, sets.String(), this.pk)
    return
}
  • 更新记录
    和增加记录一样,获取到update语句及占位符的列名后,还需要根据列名获取到列的值(占位符的数据)。
//更新一个实体
func (this *BaseDao) Update(data interface{}) error {
    columns, sql := this.updatePrepareSQL()

    stmt, err := Open().Prepare(sql)
    args := this.prepareValues(data, columns)

    fmt.Println(sql, " ", args)
    _, err = stmt.Exec(args...)
    if err != nil {
        panic(err.Error())
    }
    return err
}

9、增加或修改

根据实体对象的ID是否有值来判断是保存还是更新。

//保存或更新一个实体,根据主键是否有值
func (this *BaseDao) SaveOrUpdate(data interface{}) error {

    if this.isPkValue(data) {
        return this.Update(data)
    } else {
        return this.Save(data)
    }

}

//主键是否有值
func (this *BaseDao) isPkValue(data interface{}) bool {
    object := reflect.ValueOf(data).Elem()
    pkName := this.columnToField[this.pk]
    pkValue := object.FieldByName(pkName)

    if !pkValue.IsValid() {
        return false
    }

    switch pkValue.Kind() {
    case reflect.String:
        if len(pkValue.String()) > 0 {
            return true
        }
    case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
        if pkValue.Uint() != 0 {
            return true
        }
    case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
        if pkValue.Int() != 0 {
            return true
        }
    }

    return false
}

10、删除

  • delete语句封装

根据主键来删除数据

//实体转delete sql语句
func (this *BaseDao) deleteSQL(data interface{}) string {
    //DELETE FROM 表名称 WHERE 列名称 = 值
    object := reflect.ValueOf(data).Elem()
    fieldValue := object.FieldByName(this.pk)
    pkValue := this.valueToString(fieldValue)
    return fmt.Sprintf("DELETE FROM  %s WHERE %s=%s", this.tableName, this.pk, pkValue)
}

//reflect.Value转字符串
func (this *BaseDao) valueToString(v reflect.Value) string {
    values := new(bytes.Buffer)
    switch v.Kind() {
    case reflect.String:
        values.WriteString("'")
        values.WriteString(v.String())
        values.WriteString("'")
    case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
        values.WriteString(fmt.Sprintf("%d", v.Uint()))
    case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
        values.WriteString(fmt.Sprintf("%d", v.Int()))
    case reflect.Float32, reflect.Float64:
        values.WriteString(fmt.Sprintf("%f", v.Float()))
    case reflect.Struct:
        switch v.Type().String() {
        case "time.Time":
            m := v.MethodByName("Format")
            rets := m.Call([]reflect.Value{reflect.ValueOf(timeFormate)})
            t := rets[0].String()
            values.WriteString("'")
            values.WriteString(t)
            values.WriteString("'")
        default:
            values.WriteString("null")
        }
    default:
        values.WriteString("null")
    }

    return values.String()
}
  • 删除记录
//删除一个实体
func (this *BaseDao) Delete(data interface{}) error {
    _, err := Open().Exec(this.deleteSQL(data))
    return err
}

11、查询

  • 通过select语句获取到记录
rows, err := Open().Query(sql)
if err != nil {
    fmt.Println(err.Error())
}
  • 获取列名
defer rows.Close()
columns, err := rows.Columns()
if err != nil {
    panic(err.Error())
}
  • 根据列名的顺序,把一行的值封装到一个EntityType对象中
//对一条查询结果进行封装
func (this *BaseDao) parseQuery(columns []string, values []interface{}) interface{} {
    obj := reflect.New(this.EntityType).Interface()
    typ := reflect.ValueOf(obj).Elem()

    for i, col := range values {
        if col != nil {
            name := this.columnToField[columns[i]]
            field := typ.FieldByName(name)

            this.parseQueryColumn(field, string(col.([]byte)))
        }
    }

    return obj
}

//单个属性赋值
func (this *BaseDao) parseQueryColumn(field reflect.Value, s string) {
    switch field.Kind() {
    case reflect.String:
        field.SetString(s)
    case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
        v, _ := strconv.ParseUint(s, 10, 0)
        field.SetUint(v)
    case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
        v, _ := strconv.ParseInt(s, 10, 0)
        field.SetInt(v)
    case reflect.Float32:
        v, _ := strconv.ParseFloat(s, 32)
        field.SetFloat(v)
    case reflect.Float64:
        v, _ := strconv.ParseFloat(s, 64)
        field.SetFloat(v)
    case reflect.Struct:
        switch field.Type().String() {
        case "time.Time":
            v, _ := time.Parse(timeFormate, s)
            field.Set(reflect.ValueOf(v))
        }
    default:

    }
}
  • 把多个EntityType对象放到一个集合中
data := list.New() //创建一个新的list

for rows.Next() {
    err = rows.Scan(scanArgs...)
    if err != nil {
        panic(err.Error())
    }

    obj := this.parseQuery(columns, values)
    data.PushBack(obj)
}
  • 查询单条记录
    封装sql的时候加上limit限制,返回数据时,返回集合中的第一条(也是唯一的一条)。
if isOne && !strings.Contains(sql, "limit") {
    sql = sql + " limit 1"
}
var data interface{}
if datas.Len() > 0 {
    data = datas.Front().Value
}
  • 完整的查询代码
//根据SQL查询多条记录
func (this *BaseDao) Find(sql string) (*list.List, error) {
    return this.query(sql, false)
}

//根据SQL查询一条记录,如果找到不数据,data会返回nil
func (this *BaseDao) FindOne(sql string) (interface{}, error) {
    datas, err := this.query(sql, true)

    var data interface{}
    if datas.Len() > 0 {
        data = datas.Front().Value
    }
    return data, err

}

//根据SQL查询
func (this *BaseDao) query(sql string, isOne bool) (*list.List, error) {
    if isOne && !strings.Contains(sql, "limit") {
        sql = sql + " limit 1"
    }
    rows, err := Open().Query(sql)
    if err != nil {
        fmt.Println(err.Error())
    }

    defer rows.Close()
    columns, err := rows.Columns()
    if err != nil {
        panic(err.Error())
    }

    //构造scanArgs、values两个数组,scanArgs的每个值指向values相应值的地址
    values := make([]interface{}, len(columns))
    scanArgs := make([]interface{}, len(values))
    for i := range values {
        scanArgs[i] = &values[i]
    }

    data := list.New() //创建一个新的list

    for rows.Next() {
        err = rows.Scan(scanArgs...)
        if err != nil {
            panic(err.Error())
        }

        obj := this.parseQuery(columns, values)
        data.PushBack(obj)
    }
    return data, err
}

//对一条查询结果进行封装
func (this *BaseDao) parseQuery(columns []string, values []interface{}) interface{} {
    obj := reflect.New(this.EntityType).Interface()
    typ := reflect.ValueOf(obj).Elem()

    for i, col := range values {
        if col != nil {
            name := this.columnToField[columns[i]]
            field := typ.FieldByName(name)

            this.parseQueryColumn(field, string(col.([]byte)))
        }
    }

    return obj
}

//单个属性赋值
func (this *BaseDao) parseQueryColumn(field reflect.Value, s string) {
    switch field.Kind() {
    case reflect.String:
        field.SetString(s)
    case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
        v, _ := strconv.ParseUint(s, 10, 0)
        field.SetUint(v)
    case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
        v, _ := strconv.ParseInt(s, 10, 0)
        field.SetInt(v)
    case reflect.Float32:
        v, _ := strconv.ParseFloat(s, 32)
        field.SetFloat(v)
    case reflect.Float64:
        v, _ := strconv.ParseFloat(s, 64)
        field.SetFloat(v)
    case reflect.Struct:
        switch field.Type().String() {
        case "time.Time":
            v, _ := time.Parse(timeFormate, s)
            field.Set(reflect.ValueOf(v))
        }
    default:

    }
}

12、完整的ORM代码

package db

import (
    "bytes"
    "container/list"
    "database/sql"
    "fmt"
    "reflect"
    "strconv"
    "strings"
    "time"

    _ "github.com/go-sql-driver/mysql"
)

const (
    timeFormate = "2006-01-02 15:04:05"
)

var sqlDB *sql.DB

func Open() *sql.DB {
    if sqlDB == nil {
        db, err := sql.Open("mysql", "root:test@/time_table?charset=utf8")
        if err != nil {
            fmt.Println(err.Error())
        }
        sqlDB = db
    }

    return sqlDB
}

type IBaseDao interface {
    Init()
    Save(data interface{}) error
    Update(data interface{}) error
    SaveOrUpdate(data interface{}) error
    SaveAll(datas list.List) error
    Delete(data interface{}) error
    Find(sql string) (*list.List, error)
    FindOne(sql string) (interface{}, error)
}

type BaseDao struct {
    EntityType    reflect.Type
    sqlDB         *sql.DB
    tableName     string            //表名
    pk            string            //主键
    columnToField map[string]string //字段名:属性名
    fieldToColumn map[string]string //属性名:字段名
}

//初始化
func (this *BaseDao) Init() {
    this.columnToField = make(map[string]string)
    this.fieldToColumn = make(map[string]string)

    types := this.EntityType

    for i := 0; i < types.NumField(); i++ {
        typ := types.Field(i)
        tag := typ.Tag

        if len(tag) > 0 {
            column := tag.Get("column")
            name := typ.Name
            this.columnToField[column] = name
            this.fieldToColumn[name] = column

            if len(tag.Get("table")) > 0 {
                this.tableName = tag.Get("table")
                this.pk = column
            }
        }
    }
}

//增加单列
func (this *BaseDao) Save(data interface{}) error {
    columns, sql := this.insertPrepareSQL()

    stmt, err := Open().Prepare(sql)
    args := this.prepareValues(data, columns)
    fmt.Println(sql, " ", args)
    _, err = stmt.Exec(args...)
    if err != nil {
        panic(err.Error())
    }
    return err
}

//集合保存
func (this *BaseDao) SaveAll(datas list.List) error {
    if datas.Len() == 0 {
        return nil
    }
    columns, sql := this.insertPrepareSQL()

    stmt, err := Open().Prepare(sql)
    if err != nil {
        panic(err.Error())
    }

    for e := datas.Front(); e != nil; e = e.Next() {
        args := this.prepareValues(e.Value, columns)
        fmt.Println(sql, " ", args)
        _, err = stmt.Exec(args...)
        if err != nil {
            panic(err.Error())
        }
    }

    return err
}

//更新一个实体
func (this *BaseDao) Update(data interface{}) error {
    columns, sql := this.updatePrepareSQL()

    stmt, err := Open().Prepare(sql)
    args := this.prepareValues(data, columns)

    fmt.Println(sql, " ", args)
    _, err = stmt.Exec(args...)
    if err != nil {
        panic(err.Error())
    }
    return err
}

//实体转update sql语句
func (this *BaseDao) updatePrepareSQL() (fieldNames list.List, sql string) {
    //UPDATE 表名称 SET 列名称 = 新值 WHERE 列名称 = 某值
    sets := new(bytes.Buffer)

    i := 0

    for column, fieldName := range this.columnToField {
        if strings.EqualFold(column, this.pk) {
            continue
        }
        if i != 0 {
            sets.WriteString(",")
        }

        fieldNames.PushBack(fieldName)
        sets.WriteString(column)
        sets.WriteString("=?")

        i++
    }
    fieldNames.PushBack(this.columnToField[this.pk])
    sql = fmt.Sprintf("UPDATE %s SET %s WHERE %s=?", this.tableName, sets.String(), this.pk)
    return
}

//保存或更新一个实体,根据主键是否有值
func (this *BaseDao) SaveOrUpdate(data interface{}) error {

    if this.isPkValue(data) {
        return this.Update(data)
    } else {
        return this.Save(data)
    }

}

//主键是否有值
func (this *BaseDao) isPkValue(data interface{}) bool {
    object := reflect.ValueOf(data).Elem()
    pkName := this.columnToField[this.pk]
    pkValue := object.FieldByName(pkName)

    if !pkValue.IsValid() {
        return false
    }

    switch pkValue.Kind() {
    case reflect.String:
        if len(pkValue.String()) > 0 {
            return true
        }
    case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
        if pkValue.Uint() != 0 {
            return true
        }
    case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
        if pkValue.Int() != 0 {
            return true
        }
    }

    return false
}

//预处理占位符的数据
func (this *BaseDao) prepareValues(data interface{}, fieldNames list.List) []interface{} {
    values := make([]interface{}, len(this.columnToField))
    object := reflect.ValueOf(data).Elem()

    i := 0
    for e := fieldNames.Front(); e != nil; e = e.Next() {
        name := e.Value.(string)
        field := object.FieldByName(name)
        values[i] = this.fieldValue(field)
        i++

    }

    return values
}

//reflect.Value获取值
func (this *BaseDao) fieldValue(v reflect.Value) interface{} {
    if !v.IsValid() {
        return nil
    }

    switch v.Kind() {
    case reflect.String:
        return v.String()
    case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
        return v.Uint()
    case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
        return v.Int()
    case reflect.Float32, reflect.Float64:
        return v.Float()
    case reflect.Struct:
        switch v.Type().String() {
        case "time.Time":
            m := v.MethodByName("Format")
            rets := m.Call([]reflect.Value{reflect.ValueOf(timeFormate)})
            t := rets[0].String()
            return t
        default:
            return nil
        }
    default:
        return nil
    }
}

func (this *BaseDao) insertPrepareSQL() (fieldNames list.List, sql string) {
    names := new(bytes.Buffer)
    values := new(bytes.Buffer)

    i := 0

    for column, fieldName := range this.columnToField {

        if i != 0 {
            names.WriteString(",")
            values.WriteString(",")
        }
        fieldNames.PushBack(fieldName)
        names.WriteString(column)
        values.WriteString("?")
        i++
    }
    sql = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", this.tableName, names.String(), values.String())
    return
}

//删除一个实体
func (this *BaseDao) Delete(data interface{}) error {
    _, err := Open().Exec(this.deleteSQL(data))
    return err
}

//实体转delete sql语句
func (this *BaseDao) deleteSQL(data interface{}) string {
    //DELETE FROM 表名称 WHERE 列名称 = 值
    object := reflect.ValueOf(data).Elem()
    fieldValue := object.FieldByName(this.pk)
    pkValue := this.valueToString(fieldValue)
    return fmt.Sprintf("DELETE FROM  %s WHERE %s=%s", this.tableName, this.pk, pkValue)
}

//reflect.Value转字符串
func (this *BaseDao) valueToString(v reflect.Value) string {
    values := new(bytes.Buffer)
    switch v.Kind() {
    case reflect.String:
        values.WriteString("'")
        values.WriteString(v.String())
        values.WriteString("'")
    case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
        values.WriteString(fmt.Sprintf("%d", v.Uint()))
    case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
        values.WriteString(fmt.Sprintf("%d", v.Int()))
    case reflect.Float32, reflect.Float64:
        values.WriteString(fmt.Sprintf("%f", v.Float()))
    case reflect.Struct:
        switch v.Type().String() {
        case "time.Time":
            m := v.MethodByName("Format")
            rets := m.Call([]reflect.Value{reflect.ValueOf(timeFormate)})
            t := rets[0].String()
            values.WriteString("'")
            values.WriteString(t)
            values.WriteString("'")
        default:
            values.WriteString("null")
        }
    default:
        values.WriteString("null")
    }

    return values.String()
}

//根据SQL查询多条记录
func (this *BaseDao) Find(sql string) (*list.List, error) {
    return this.query(sql, false)
}

//根据SQL查询一条记录,如果找到不数据,data会返回nil
func (this *BaseDao) FindOne(sql string) (interface{}, error) {
    datas, err := this.query(sql, true)

    var data interface{}
    if datas.Len() > 0 {
        data = datas.Front().Value
    }
    return data, err

}

//根据SQL查询
func (this *BaseDao) query(sql string, isOne bool) (*list.List, error) {
    if isOne && !strings.Contains(sql, "limit") {
        sql = sql + " limit 1"
    }
    rows, err := Open().Query(sql)
    if err != nil {
        fmt.Println(err.Error())
    }

    defer rows.Close()
    columns, err := rows.Columns()
    if err != nil {
        panic(err.Error())
    }

    //构造scanArgs、values两个数组,scanArgs的每个值指向values相应值的地址
    values := make([]interface{}, len(columns))
    scanArgs := make([]interface{}, len(values))
    for i := range values {
        scanArgs[i] = &values[i]
    }

    data := list.New() //创建一个新的list

    for rows.Next() {
        err = rows.Scan(scanArgs...)
        if err != nil {
            panic(err.Error())
        }

        obj := this.parseQuery(columns, values)
        data.PushBack(obj)
    }
    return data, err
}

//对一条查询结果进行封装
func (this *BaseDao) parseQuery(columns []string, values []interface{}) interface{} {
    obj := reflect.New(this.EntityType).Interface()
    typ := reflect.ValueOf(obj).Elem()

    for i, col := range values {
        if col != nil {
            name := this.columnToField[columns[i]]
            field := typ.FieldByName(name)

            this.parseQueryColumn(field, string(col.([]byte)))
        }
    }

    return obj
}

//单个属性赋值
func (this *BaseDao) parseQueryColumn(field reflect.Value, s string) {
    switch field.Kind() {
    case reflect.String:
        field.SetString(s)
    case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
        v, _ := strconv.ParseUint(s, 10, 0)
        field.SetUint(v)
    case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
        v, _ := strconv.ParseInt(s, 10, 0)
        field.SetInt(v)
    case reflect.Float32:
        v, _ := strconv.ParseFloat(s, 32)
        field.SetFloat(v)
    case reflect.Float64:
        v, _ := strconv.ParseFloat(s, 64)
        field.SetFloat(v)
    case reflect.Struct:
        switch field.Type().String() {
        case "time.Time":
            v, _ := time.Parse(timeFormate, s)
            field.Set(reflect.ValueOf(v))
        }
    default:

    }
}

12、ORM使用

  • 定义一个Entity
/*entity*/
type Chair struct {
    Id              int       `table:"device_chair" column:"id"`
    Time            time.Time `column:"time"`
    Day             string    `column:"day"`
    Origin          string    `column:"origin"`
    ParentSerial    string    `column:"parent_serial"`
    Serial          string
    RunningNumber   string  `column:"running_number"`
    PointA          float32 `column:"point_a"`
    PointB          float32 `column:"point_b"`
    PointC          float32 `column:"point_c"`
    PointD          float32 `column:"point_d"`
    PointE          float32 `column:"point_e"`
    PointF          float32 `column:"point_f"`
    Voltage         float32 `column:"voltage"`
    SeatingPosition uint8   `column:"seating_position"`
}
  • 通过组合的方式实体BaseDao的继承
/*dao*/
type IChairDao interface {
    db.IBaseDao
}

type chairDao struct {
    db.BaseDao
}
  • 初始化Dao
    这个过程中记得要给BaseDao传入EntityType的值,并调用Init方法进行初始化工作。
var chairDaoImpl IChairDao

func ChairDao() IChairDao {
    if chairDaoImpl == nil {
        chairDaoImpl = &chairDao{db.BaseDao{EntityType: reflect.TypeOf(new(Chair)).Elem()}}
        chairDaoImpl.Init()
    }

    return chairDaoImpl
}
  • Dao使用
    这里使用增加数据为例子
    chair := new(Chair)
    chair.Id = 129986
    chair.SeatingPosition = 10
    chair.Day = chair.Time.Format("2006-01-02")
    err = ChairDao().Save(chair)

13、未完成

  1. 多模式主键的支持
  2. 外键关联的支持
  3. 事务的支持
上一篇:关于linux系统下批量修改文件名和后缀


下一篇:css:使div中的内容在同一行与其横向滚动条