推荐
关注
TOP
Message

Gin Gorm 实现CRUD 增删改查

ORM简介

对象关系映射(Object Relational Mapping,简称ORM)模式是一种为了解决面向对象与关系数据库(如mysql数据库)存在的互不匹配的现象的技术。简单的说,ORM是通过使用描述对象和数据库之间映射的元数据,将程序中的对象自动持久化到关系数据库中

安装

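# 可选:本文实际使用 MySQL,SQLite 驱动并非必需
go get -u gorm.io/driver/sqlite
go get -u github.com/gin-gonic/gin
# 安装 MySQL 驱动
go get -u gorm.io/driver/mysql
# 安装 gorm 包
go get -u gorm.io/gorm
# 安装 viper(用于读取 config/application.yml 配置)
go get -u github.com/spf13/viper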

注意

GORM 是一款老牌的国产 Golang ORM 框架,支持主流关系型数据库。它提供中文文档,适合新人入门,在国内使用较多。最新版本为 2.x,相比 1.x 有较大改动。

注意:Gorm最新地址为https://github.com/go-gorm/gorm,之前https://github.com/jinzhu/gorm地址为v1旧版本

Gorm最新源码地址:https://github.com/go-gorm/gorm

V1版本地址:https://github.com/jinzhu/gorm

中文文档地址:https://gorm.io/zh_CN/

准备工作

Response 统一返回restful格式的数据

package response

import (
	"github.com/gin-gonic/gin"
	"net/http"
)

// Response writes a JSON envelope with an explicit HTTP status.
//
// The body is {"httpStatus": httpStatus, "code": httpStatus, "data": data, "msg": msg}.
// "code" mirrors the HTTP status so that clients reading the "code" field
// (as written by Success, Fail and UnprocessableEntity) also work for this
// helper; "httpStatus" is kept for backward compatibility.
//
//   - context:    gin request context the response is written to
//   - httpStatus: HTTP status used for the status line and the body fields
//   - data:       payload placed under "data" (may be nil)
//   - msg:        human-readable message placed under "msg"
func Response(context *gin.Context, httpStatus int, data gin.H, msg string) {
	context.JSON(httpStatus, gin.H{
		"httpStatus": httpStatus,
		"code":       httpStatus,
		"data":       data,
		"msg":        msg,
	})
}

// Success replies with HTTP 200 and business code 200, wrapping data and msg.
func Success(context *gin.Context, data gin.H, msg string) {
	payload := gin.H{"code": http.StatusOK, "msg": msg, "data": data}
	context.JSON(http.StatusOK, payload)
}

// Fail reports a business-level failure. The HTTP status is still 200; the
// failure is signalled by code 400 in the JSON body alongside data and msg.
func Fail(context *gin.Context, data gin.H, msg string) {
	const failCode = http.StatusBadRequest
	context.JSON(http.StatusOK, gin.H{"msg": msg, "data": data, "code": failCode})
}

// UnprocessableEntity replies with HTTP 422 and repeats 422 as the body code.
func UnprocessableEntity(context *gin.Context, data gin.H, msg string) {
	status := http.StatusUnprocessableEntity
	context.JSON(status, gin.H{"code": status, "data": data, "msg": msg})
}

定义数据表模型(Post)

package model

// Post is the gorm model for an article; NewPostController passes it to
// AutoMigrate, so the table columns follow these fields and gorm tags.
// Time fields use the package's Time type, which serialises to JSON as
// "2006-01-02 15:04:05".
type Post struct {
	// ID is the primary key by gorm's naming convention.
	ID        uint   `json:"id"`
	TitleDate Time   `json:"title_date"`
	TitleName string `json:"title_name" gorm:"type:varchar(50);not null"`
	Content   string `json:"content" gorm:"type:text;not null"`
	// NOTE(review): the JSON keys are "create_at"/"update_at" (no "d"), unlike
	// the field names — confirm clients expect these spellings.
	CreatedAt Time   `json:"create_at"`
	UpdatedAt Time   `json:"update_at"`
}

重写时间格式

package model

import (
	"database/sql/driver"
	"fmt"
	"time"
)

// timeFormat is the wall-clock layout used for JSON and String output.
const timeFormat = "2006-01-02 15:04:05"

// timezone is the IANA zone name used by (Time).local.
const timezone = "Asia/Shanghai"

// Time wraps time.Time so that it (de)serialises as "2006-01-02 15:04:05" in
// JSON while still being exchanged with database/sql as a time.Time.
type Time time.Time

// MarshalJSON renders t as a quoted timeFormat string. The zero Time renders
// as "0001-01-01 00:00:00".
func (t Time) MarshalJSON() ([]byte, error) {
	b := make([]byte, 0, len(timeFormat)+2)
	b = append(b, '"')
	b = time.Time(t).AppendFormat(b, timeFormat)
	b = append(b, '"')
	return b, nil
}

// UnmarshalJSON parses a quoted timeFormat string in the local time zone.
//
// Following the encoding/json convention, the JSON literal null is a no-op.
// On a parse error t is left unchanged and the error is returned, rather than
// being overwritten with the zero time.
func (t *Time) UnmarshalJSON(data []byte) error {
	if string(data) == "null" {
		return nil
	}
	parsed, err := time.ParseInLocation(`"`+timeFormat+`"`, string(data), time.Local)
	if err != nil {
		return err
	}
	*t = Time(parsed)
	return nil
}

// String formats t with timeFormat, so fmt output matches the JSON text.
func (t Time) String() string {
	return time.Time(t).Format(timeFormat)
}

// local returns t converted to the timezone zone. If the zone database cannot
// be loaded it falls back to time.Local: time.Time.In panics on a nil
// *Location, which is what LoadLocation returns on error.
// NOTE(review): not referenced by any code visible here.
func (t Time) local() time.Time {
	loc, err := time.LoadLocation(timezone)
	if err != nil {
		loc = time.Local
	}
	return time.Time(t).In(loc)
}

// Value implements driver.Valuer: the zero Time is written as SQL NULL, any
// other value is handed to the driver as a time.Time. IsZero is used instead of
// comparing UnixNano, whose result is undefined for the year-1 zero time.
func (t Time) Value() (driver.Value, error) {
	ti := time.Time(t)
	if ti.IsZero() {
		return nil, nil
	}
	return ti, nil
}

// Scan implements sql.Scanner. SQL NULL (which Value writes for the zero Time)
// scans back to the zero Time, and a time.Time is stored as is. Any other
// source type is rejected; the DSN built in InitDB sets parseTime=true so the
// driver delivers DATETIME columns as time.Time.
func (t *Time) Scan(v interface{}) error {
	switch value := v.(type) {
	case nil:
		*t = Time{}
		return nil
	case time.Time:
		*t = Time(value)
		return nil
	default:
		return fmt.Errorf("can not convert %v to timestamp", v)
	}
}

连接数据库

package common

import (
	"GinDemo/model"
	"fmt"
	"github.com/spf13/viper"
	"gorm.io/driver/mysql"
	"gorm.io/gorm"
	"log"
	"net/url"
)

// DB is the package-level gorm handle; InitDB assigns it and GetDb returns it.
var DB *gorm.DB

// InitDB builds a MySQL DSN from the "datasource.*" viper keys, opens a gorm
// connection into the package-level DB and returns it. An error from
// gorm.Open terminates the process via log.Fatal.
//
// NOTE(review): username/password are inserted verbatim; a password containing
// DSN delimiters such as '/' would produce a malformed DSN.
func InitDB() *gorm.DB {
	setting := func(key string) string {
		return viper.GetString("datasource." + key)
	}

	// loc is query-escaped so zone names containing '/' (e.g. Asia/Shanghai)
	// survive inside the DSN query string.
	dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=%s&parseTime=true&loc=%s",
		setting("username"), setting("password"),
		setting("host"), setting("port"),
		setting("database"), setting("charset"),
		url.QueryEscape(setting("loc")),
	)

	var err error
	if DB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{}); err != nil {
		log.Fatal(err)
	}
	fmt.Printf("连接成功:%v\n", DB)

	return DB
}

// GetDb returns the handle stored in DB by InitDB (nil if InitDB has not run).
func GetDb() (db *gorm.DB) {
	db = DB
	return
}

在 config/application.yml 中写好配置:

# Application configuration, loaded by InitConfig from config/application.yml.
server:
  # HTTP listen port, read as server.port in main.
  port: 1016
datasource:
  # Not read by InitDB (its driverName line is commented out); the driver is MySQL.
  driverName: mysql
  host: 127.0.0.1
  port: 3306
  database: ginInessential
  username: root
  # NOTE(review): plaintext credential in a config file; consider injecting it from the environment.
  password: admin*123
  # NOTE(review): MySQL "utf8" is the 3-byte utf8mb3; use utf8mb4 if text may contain emoji.
  charset: utf8
  # Time zone for parseTime conversion; InitDB URL-escapes it into the DSN.
  loc: Asia/Shanghai

写完之后,需要在 main 函数中初始化配置:


// InitConfig loads config/application.yml (relative to the current working
// directory) into viper.
//
// Failures are logged rather than fatal, so keys that have fallbacks (such as
// server.port in main) keep working. Without the log, a missing file would
// only surface later as a connection error from InitDB, which reads its
// settings from viper.
func InitConfig() {
	workDir, err := os.Getwd()
	if err != nil {
		log.Printf("InitConfig: getting working directory: %v", err)
		// Fall back to a path relative to the working directory.
		workDir = "."
	}
	viper.SetConfigName("application")
	viper.SetConfigType("yml")
	viper.AddConfigPath(filepath.Join(workDir, "config"))

	if err := viper.ReadInConfig(); err != nil {
		log.Printf("InitConfig: reading config: %v", err)
	}
}

Crud 增删改查

封装DB对象

// IPostController is the set of gin handlers serving the /posts resource.
type IPostController interface {
	// Create inserts a post built from the request body.
	Create(ctx *gin.Context)
	// Select returns the post identified by the :id path parameter.
	Select(ctx *gin.Context)
	// PageList returns one page of posts plus the total count.
	PageList(ctx *gin.Context)
	// Put updates the post identified by the :id path parameter.
	Put(ctx *gin.Context)
	// Delete removes the post identified by the :id path parameter.
	Delete(ctx *gin.Context)
}

// PostController implements IPostController; it holds the database handle
// shared by all post handlers.
type PostController struct {
	// DB is the gorm connection used by every handler (filled by NewPostController).
	DB *gorm.DB
}

// 所需参数
type PostRequest struct {
	TitleName uint   `json:"title_name" binding:"required,max=10"`
	TitleDate string `json:"title_date" binding:"required"`
	Content   string `json:"content" binding:"required"`
}

// NewPostController returns a PostController bound to common.GetDb(). It first
// auto-migrates the model.Post table; a migration error is logged (not fatal)
// and the controller is still returned.
func NewPostController() IPostController {
	db := common.GetDb()
	if err := db.AutoMigrate(&model.Post{}); err != nil {
		log.Printf("auto-migrating model.Post: %v", err)
	}
	return PostController{DB: db}
}

此时 `return PostController{DB: db}` 处会报错(编辑器中标红),因为 NewPostController 的返回值类型是 IPostController,
而 PostController 结构体还没有实现接口要求的增、删、改、查和分页这五个方法。

分页

我们先来实现分页功能

// PageList returns one page of posts ordered newest first, plus the total
// number of posts, as {"data": posts, "total": total}.
//
// Query parameters (read from the URL query string):
//   - pageNum:  1-based page index, default 1; unparsable or < 1 becomes 1
//   - pageSize: rows per page, default 20; unparsable or < 1 becomes 20
//
// Database errors are logged and reported with "查询失败" instead of being
// ignored.
func (p PostController) PageList(ctx *gin.Context) {
	pageNum, err := strconv.Atoi(ctx.DefaultQuery("pageNum", "1"))
	if err != nil || pageNum < 1 {
		pageNum = 1
	}
	pageSize, err := strconv.Atoi(ctx.DefaultQuery("pageSize", "20"))
	if err != nil || pageSize < 1 {
		pageSize = 20
	}

	// gorm stores CreatedAt in the created_at column; there is no "created"
	// column, so ordering by it made the query fail.
	var posts []model.Post
	if err := p.DB.Order("created_at desc").
		Offset((pageNum - 1) * pageSize).
		Limit(pageSize).
		Find(&posts).Error; err != nil {
		log.Println(err)
		response.Fail(ctx, nil, "查询失败")
		return
	}

	var total int64
	if err := p.DB.Model(&model.Post{}).Count(&total).Error; err != nil {
		log.Println(err)
		response.Fail(ctx, nil, "查询失败")
		return
	}

	response.Success(ctx, gin.H{
		"data":  posts,
		"total": total,
	}, "查询成功")
}

添加

// Create binds a PostRequest, inserts it as a model.Post and responds with the
// stored post under "文章".
//
// A bind/validation failure responds with "数据验证错误". A database error is
// logged and reported with "创建失败" instead of panicking.
func (p PostController) Create(ctx *gin.Context) {
	var requestPost PostRequest
	if err := ctx.ShouldBind(&requestPost); err != nil {
		log.Println(err)
		response.Fail(ctx, gin.H{"数据": requestPost}, "数据验证错误")
		return
	}

	post := model.Post{
		TitleName: requestPost.TitleName,
		TitleDate: requestPost.TitleDate,
		Content:   requestPost.Content,
	}
	if err := p.DB.Create(&post).Error; err != nil {
		log.Println(err)
		response.Fail(ctx, nil, "创建失败")
		return
	}
	response.Success(ctx, gin.H{
		"文章": post,
	}, "创建成功")
}

删除

// Delete removes the post identified by the :id path parameter and responds
// with the deleted post.
//
// "文章不存在" is returned for a malformed id or a missing row; "删除失败" is
// returned when the delete itself fails.
func (p PostController) Delete(ctx *gin.Context) {
	// Parse the id as an unsigned integer before handing it to gorm: a
	// non-numeric string passed to First is used as a raw SQL condition, so the
	// unchecked path segment allowed SQL injection.
	postID, err := strconv.ParseUint(ctx.Params.ByName("id"), 10, 64)
	if err != nil {
		response.Fail(ctx, nil, "文章不存在")
		return
	}

	var post model.Post
	if err := p.DB.First(&post, postID).Error; err != nil {
		response.Fail(ctx, nil, "文章不存在")
		return
	}

	if err := p.DB.Delete(&post).Error; err != nil {
		response.Fail(ctx, nil, "删除失败")
		return
	}
	response.Success(ctx, gin.H{"post": post}, "删除成功")
}

修改

// Put updates the post identified by the :id path parameter with the fields
// of a PostRequest body and responds with the post.
//
// Responses: "数据验证错误" on bind/validation failure, "文章不存在" for a
// malformed id or a missing row, "更新失败" on a database error.
func (p PostController) Put(ctx *gin.Context) {
	var requestPost PostRequest
	if err := ctx.ShouldBind(&requestPost); err != nil {
		response.Fail(ctx, nil, "数据验证错误")
		return
	}

	// Parse the id as an unsigned integer before handing it to gorm: a
	// non-numeric string passed to First is used as a raw SQL condition.
	postID, err := strconv.ParseUint(ctx.Params.ByName("id"), 10, 64)
	if err != nil {
		response.Fail(ctx, nil, "文章不存在")
		return
	}

	var post model.Post
	if err := p.DB.First(&post, postID).Error; err != nil {
		response.Fail(ctx, nil, "文章不存在")
		return
	}

	// NOTE: gorm's Updates with a struct skips zero-valued fields.
	if err := p.DB.Model(&post).Updates(requestPost).Error; err != nil {
		response.Fail(ctx, nil, "更新失败")
		return
	}
	response.Success(ctx, gin.H{"post": post}, "更新成功")
}

查询

// Select responds with the post identified by the :id path parameter, or
// "文章不存在" for a malformed id or a missing row.
func (p PostController) Select(ctx *gin.Context) {
	// Parse the id as an unsigned integer before handing it to gorm: a
	// non-numeric string passed to First is used as a raw SQL condition.
	postID, err := strconv.ParseUint(ctx.Params.ByName("id"), 10, 64)
	if err != nil {
		response.Fail(ctx, nil, "文章不存在")
		return
	}

	// model.Post declares no Category association, so Preload("Category")
	// made every lookup fail; load the row directly.
	var post model.Post
	if err := p.DB.First(&post, postID).Error; err != nil {
		response.Fail(ctx, nil, "文章不存在")
		return
	}
	response.Success(ctx, gin.H{"post": post}, "查询成功")
}

全部代码

package controller

import (
	"GinDemo/common"
	"GinDemo/model"
	"GinDemo/response"
	"github.com/gin-gonic/gin"
	"gorm.io/gorm"
	"log"
	"strconv"
)

// IPostController is the set of gin handlers serving the /posts resource.
type IPostController interface {
	// Create inserts a post built from the request body.
	Create(ctx *gin.Context)
	// Select returns the post identified by the :id path parameter.
	Select(ctx *gin.Context)
	// PageList returns one page of posts plus the total count.
	PageList(ctx *gin.Context)
	// Put updates the post identified by the :id path parameter.
	Put(ctx *gin.Context)
	// Delete removes the post identified by the :id path parameter.
	Delete(ctx *gin.Context)
}
// PostController implements IPostController; it holds the database handle
// shared by all post handlers.
type PostController struct {
	// DB is the gorm connection used by every handler (filled by NewPostController).
	DB *gorm.DB
}

// NewPostController returns a PostController bound to common.GetDb(). It first
// auto-migrates the model.Post table; a migration error is logged (not fatal)
// and the controller is still returned.
func NewPostController() IPostController {
	db := common.GetDb()
	if err := db.AutoMigrate(&model.Post{}); err != nil {
		log.Printf("auto-migrating model.Post: %v", err)
	}
	return PostController{DB: db}
}

type PostRequest struct {
	TitleName uint   `json:"title_name" binding:"required,max=10"`
	TitleDate string `json:"title_date" binding:"required"`
	Content   string `json:"content" binding:"required"`
}

// PageList returns one page of posts ordered newest first, plus the total
// number of posts, as {"data": posts, "total": total}.
//
// Query parameters (read from the URL query string):
//   - pageNum:  1-based page index, default 1; unparsable or < 1 becomes 1
//   - pageSize: rows per page, default 20; unparsable or < 1 becomes 20
//
// Database errors are logged and reported with "查询失败" instead of being
// ignored.
func (p PostController) PageList(ctx *gin.Context) {
	pageNum, err := strconv.Atoi(ctx.DefaultQuery("pageNum", "1"))
	if err != nil || pageNum < 1 {
		pageNum = 1
	}
	pageSize, err := strconv.Atoi(ctx.DefaultQuery("pageSize", "20"))
	if err != nil || pageSize < 1 {
		pageSize = 20
	}

	// gorm stores CreatedAt in the created_at column; there is no "created"
	// column, so ordering by it made the query fail.
	var posts []model.Post
	if err := p.DB.Order("created_at desc").
		Offset((pageNum - 1) * pageSize).
		Limit(pageSize).
		Find(&posts).Error; err != nil {
		log.Println(err)
		response.Fail(ctx, nil, "查询失败")
		return
	}

	var total int64
	if err := p.DB.Model(&model.Post{}).Count(&total).Error; err != nil {
		log.Println(err)
		response.Fail(ctx, nil, "查询失败")
		return
	}

	response.Success(ctx, gin.H{
		"data":  posts,
		"total": total,
	}, "查询成功")
}

// Create binds a PostRequest, inserts it as a model.Post and responds with the
// stored post under "文章".
//
// A bind/validation failure responds with "数据验证错误". A database error is
// logged and reported with "创建失败" instead of panicking.
func (p PostController) Create(ctx *gin.Context) {
	var requestPost PostRequest
	if err := ctx.ShouldBind(&requestPost); err != nil {
		log.Println(err)
		response.Fail(ctx, gin.H{"数据": requestPost}, "数据验证错误")
		return
	}

	post := model.Post{
		TitleName: requestPost.TitleName,
		TitleDate: requestPost.TitleDate,
		Content:   requestPost.Content,
	}
	if err := p.DB.Create(&post).Error; err != nil {
		log.Println(err)
		response.Fail(ctx, nil, "创建失败")
		return
	}
	response.Success(ctx, gin.H{
		"文章": post,
	}, "创建成功")
}

// Delete removes the post identified by the :id path parameter and responds
// with the deleted post.
//
// "文章不存在" is returned for a malformed id or a missing row; "删除失败" is
// returned when the delete itself fails.
func (p PostController) Delete(ctx *gin.Context) {
	// Parse the id as an unsigned integer before handing it to gorm: a
	// non-numeric string passed to First is used as a raw SQL condition, so the
	// unchecked path segment allowed SQL injection.
	postID, err := strconv.ParseUint(ctx.Params.ByName("id"), 10, 64)
	if err != nil {
		response.Fail(ctx, nil, "文章不存在")
		return
	}

	var post model.Post
	if err := p.DB.First(&post, postID).Error; err != nil {
		response.Fail(ctx, nil, "文章不存在")
		return
	}

	if err := p.DB.Delete(&post).Error; err != nil {
		response.Fail(ctx, nil, "删除失败")
		return
	}
	response.Success(ctx, gin.H{"post": post}, "删除成功")
}

// Put updates the post identified by the :id path parameter with the fields
// of a PostRequest body and responds with the post.
//
// Responses: "数据验证错误" on bind/validation failure, "文章不存在" for a
// malformed id or a missing row, "更新失败" on a database error.
func (p PostController) Put(ctx *gin.Context) {
	var requestPost PostRequest
	if err := ctx.ShouldBind(&requestPost); err != nil {
		response.Fail(ctx, nil, "数据验证错误")
		return
	}

	// Parse the id as an unsigned integer before handing it to gorm: a
	// non-numeric string passed to First is used as a raw SQL condition.
	postID, err := strconv.ParseUint(ctx.Params.ByName("id"), 10, 64)
	if err != nil {
		response.Fail(ctx, nil, "文章不存在")
		return
	}

	var post model.Post
	if err := p.DB.First(&post, postID).Error; err != nil {
		response.Fail(ctx, nil, "文章不存在")
		return
	}

	// NOTE: gorm's Updates with a struct skips zero-valued fields.
	if err := p.DB.Model(&post).Updates(requestPost).Error; err != nil {
		response.Fail(ctx, nil, "更新失败")
		return
	}
	response.Success(ctx, gin.H{"post": post}, "更新成功")
}

// Select responds with the post identified by the :id path parameter, or
// "文章不存在" for a malformed id or a missing row.
func (p PostController) Select(ctx *gin.Context) {
	// Parse the id as an unsigned integer before handing it to gorm: a
	// non-numeric string passed to First is used as a raw SQL condition.
	postID, err := strconv.ParseUint(ctx.Params.ByName("id"), 10, 64)
	if err != nil {
		response.Fail(ctx, nil, "文章不存在")
		return
	}

	// model.Post declares no Category association, so Preload("Category")
	// made every lookup fail; load the row directly.
	var post model.Post
	if err := p.DB.First(&post, postID).Error; err != nil {
		response.Fail(ctx, nil, "文章不存在")
		return
	}
	response.Success(ctx, gin.H{"post": post}, "查询成功")
}

Gin 中间件

异常中间件


package middleware

import (
	"GinDemo/response"
	"fmt"
	"github.com/gin-gonic/gin"
)

// RecoveryMiddleware converts a panic in a downstream handler into a Fail
// response whose msg is the panic value.
//
// The panic is also written to gin.DefaultErrorWriter so it leaves a trace in
// the server output, and the handler chain is aborted so nothing else runs for
// the request.
func RecoveryMiddleware() gin.HandlerFunc {
	return func(context *gin.Context) {
		defer func() {
			if err := recover(); err != nil {
				fmt.Fprintf(gin.DefaultErrorWriter, "[RecoveryMiddleware] panic recovered: %v\n", err)
				response.Fail(context, nil, fmt.Sprint(err))
				context.Abort()
			}
		}()

		context.Next()
	}
}

跨域中间件

package middleware

import (
	"github.com/gin-gonic/gin"
	"net/http"
)

// CORSMiddleware sets permissive CORS headers on every response. Preflight
// (OPTIONS) requests are answered immediately with 204 No Content; all other
// requests continue down the chain.
//
// NOTE(review): browsers refuse credentialed (cookie) requests when
// Access-Control-Allow-Origin is "*" while Allow-Credentials is "true"; echo a
// whitelisted Origin instead if credentials are needed.
func CORSMiddleware() gin.HandlerFunc {
	corsHeaders := [][2]string{
		{"Access-Control-Allow-Origin", "*"},
		{"Access-Control-Allow-Headers", "Content-Type,AccessToken,X-CSRF-Token, Authorization, Token, x-token"},
		{"Access-Control-Allow-Methods", "POST, GET, OPTIONS, DELETE, PATCH, PUT"},
		{"Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers, Content-Type"},
		{"Access-Control-Allow-Credentials", "true"},
	}

	return func(context *gin.Context) {
		for _, kv := range corsHeaders {
			context.Header(kv[0], kv[1])
		}

		if context.Request.Method != http.MethodOptions {
			context.Next()
			return
		}
		context.AbortWithStatus(http.StatusNoContent)
	}
}

用户验证中间件

本文的 CRUD 示例中没有用到这个中间件,但我之前写的代码里有,这里也单独拿出来讲讲。
它的作用就是验证 Token 是否正确。
由于后面还涉及 JWT 和返回用户信息等较多内容,这里简单看看逻辑即可。

package middleware

import (
	"GinDemo/common"
	"GinDemo/dto"
	"GinDemo/response"
	"fmt"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
)

// AuthMiddleware authenticates requests carrying an
// "Authorization: Bearer <token>" header.
//
// The token is parsed with common.ParseToken and its expiry is compared with
// the current time. On success, the value built by dto.ToUserDto is stored in
// the context under "userInfo" and the next handler runs.
//
// Every failure aborts the chain:
//   - a missing or non-Bearer header yields HTTP 401 ("权限验证错误");
//   - a malformed header, an invalid token or an expired token yields HTTP 422.
func AuthMiddleware() gin.HandlerFunc {
	return func(context *gin.Context) {
		tokenString := context.GetHeader("Authorization")
		// Log only the header length: printing the bearer token itself would
		// write credentials into the server logs.
		fmt.Printf("Authorization header length: %d\n", len(tokenString))
		// Reject an empty header or one that does not start with "Bearer".
		if tokenString == "" || !strings.HasPrefix(tokenString, "Bearer") {
			response.Response(context, http.StatusUnauthorized, nil, "权限验证错误")

			context.Abort()
			return
		}

		// The header must split into exactly "Bearer" and a token. Either
		// condition failing is a format error; checking with || also rejects a
		// bare "Bearer" here instead of panicking on tokenSlice[1] below.
		tokenSlice := strings.SplitN(tokenString, " ", 2)
		if len(tokenSlice) != 2 || tokenSlice[0] != "Bearer" {
			response.UnprocessableEntity(context, nil, "token格式错误")

			context.Abort() // stop the handler chain
			return
		}
		// Validate the token.
		claims, ok := common.ParseToken(tokenSlice[1])
		if !ok {
			response.UnprocessableEntity(context, nil, "token不正确")
			context.Abort() // stop the handler chain
			return
		}
		// Reject expired tokens.
		if time.Now().Unix() > claims.StandardClaims.ExpiresAt {
			response.UnprocessableEntity(context, nil, "token过期")
			context.Abort() // stop the handler chain
			return
		}

		/*
			// return all data
			context.Set("userInfo", tokenStruck)
			context.Next()

			// return partial data
			var user model.User
			db := common.GetDb()
			db.First(&user, tokenStruck.UserID)

			// if the user does not exist
			if user.ID == 0 {
				context.JSON(http.StatusUnauthorized, gin.H{
					"code": 401,
					"msg":  "用户不存在",
				})
				context.Abort()
				return
			}

			// return partial data
			context.Set("name", user.Name)
			context.Set("telephone", user.Telephone)
			context.Next()
		*/

		userInfo := dto.ToUserDto(claims, context)
		// Expose the user info to downstream handlers.
		context.Set("userInfo", userInfo)
		context.Next()

	}
}

路由


package main

import (
	"GinDemo/controller"
	"GinDemo/middleware"

	"github.com/gin-gonic/gin"
)

// CollectRoute installs the CORS and recovery middleware globally, registers
// the /posts routes backed by a PostController and returns r for chaining.
func CollectRoute(r *gin.Engine) *gin.Engine {
	r.Use(middleware.CORSMiddleware(), middleware.RecoveryMiddleware())

	group := r.Group("/posts")
	// group.Use(middleware.AuthMiddleware())
	posts := controller.NewPostController()

	routes := []struct {
		method  string
		path    string
		handler gin.HandlerFunc
	}{
		{"POST", "", posts.Create},
		{"DELETE", ":id", posts.Delete},
		{"PUT", ":id", posts.Put},
		{"GET", ":id", posts.Select},
		{"POST", "page/list", posts.PageList},
	}
	for _, rt := range routes {
		group.Handle(rt.method, rt.path, rt.handler)
	}

	return r
}

调用

package main

import (
	"log"
	"os"
	"path/filepath"

	"GinDemo/common"
	"github.com/gin-gonic/gin"
	"github.com/spf13/viper"
)

// main loads configuration, connects to the database, registers the routes
// and starts the HTTP server on ":<server.port>". When server.port is empty,
// Run is called without an address and gin picks its default.
//
// A server start-up failure (for example, the port is already in use) is
// fatal instead of being silently discarded.
func main() {
	// Load config/application.yml into viper.
	InitConfig()

	// Open the database; InitDB exits the process if the connection fails.
	common.InitDB()

	r := CollectRoute(gin.Default())

	var addr []string
	if port := viper.GetString("server.port"); port != "" {
		addr = append(addr, ":"+port)
	}
	if err := r.Run(addr...); err != nil {
		log.Fatalf("http server: %v", err)
	}
}

// InitConfig loads config/application.yml (relative to the current working
// directory) into viper.
//
// Failures are logged rather than fatal, so keys that have fallbacks (such as
// server.port in main) keep working. Without the log, a missing file would
// only surface later as a connection error from InitDB, which reads its
// settings from viper.
func InitConfig() {
	workDir, err := os.Getwd()
	if err != nil {
		log.Printf("InitConfig: getting working directory: %v", err)
		// Fall back to a path relative to the working directory.
		workDir = "."
	}
	viper.SetConfigName("application")
	viper.SetConfigType("yml")
	viper.AddConfigPath(filepath.Join(workDir, "config"))

	if err := viper.ReadInConfig(); err != nil {
		log.Printf("InitConfig: reading config: %v", err)
	}
}
posted @ 2022-10-27 15:41  始識  阅读(518)  评论(0)    收藏  举报