单元测试(go)

项目demo地址:go-test

目前只描述了简单的方法,文档持续更新中(通常在周末更新,平常上班)...

本文主要针对golang语言的单元测试工具,博客内容也会涉及一些单元相关的内容

什么是单元测试:单元测试是软件测试体系中最基础、最核心的测试类型,它聚焦于对软件系统中最小的 “可测试单元” 进行独立验证,确保该单元的功能符合预期设计。

简单描述下前因后果:工作需要对项目代码系统化执行单元测试,要求覆盖率达到95%以上,因为不同人的开发风格和代码习惯,外加项目框架和架构的一些要求。单元测试,这个东西一般情况都会让人很痛苦,至于因为啥,我相信看到我这篇博客的各位,都是有不同程度的感同身受的,我这里介绍三种单测工具,基础的单元测试书写和使用就不过多赘述了。

一、单元测试的核心方式

注意:这块作为扩展,如需直接了解工具可忽略这部分

单元测试的实现方式可从核心分类维度展开,结合具体落地实践和技术选型,不同方式适用于不同场景和项目规模。

1、按测试编写时机划分

(核心流程维度)

从开发流程角度区分的两种核心方式,决定了单元测试与业务代码的协作关系。

1.后写单元测试

(传统方式,最常用)

  1. 核心定义:先编写业务功能代码,待功能实现完成后,再针对性补写对应的单元测试用例,验证已实现的代码逻辑是否符合预期。
  2. 适用场景:大部分传统开发场景、快速迭代的小型需求、开发者对 TDD 模式不熟悉的项目。
  3. 优势:符合开发者 “先实现功能再验证” 的直觉,上手成本低,无需提前设计详细的测试用例。
  4. 劣势:可能遗漏部分边界场景的测试,且容易因业务代码耦合度高,导致测试难以编写(后期补测时,修改代码解耦的成本更高)。

2.测试驱动开发

(TDD,Test-Driven Development,进阶方式)

  1. 核心定义:遵循 “先写测试,再写业务代码,最后重构” 的循环流程,测试用例先定义好被测单元的预期行为(输入、输出、异常场景),再编写满足测试用例的业务代码,最终优化代码结构。
  2. 核心流程(红 - 绿 - 重构循环)
    1. 红(Red):编写一个失败的测试用例(此时业务代码未实现,测试必然失败);
    2. 绿(Green):编写最少的业务代码,仅满足让该测试用例通过(不追求代码优雅,只保证功能达标);
    3. 重构(Refactor):在测试用例保驾护航的前提下,优化业务代码的结构、可读性、性能等,确保重构后测试用例仍能通过。
  3. 适用场景:对代码质量要求高的核心模块、复杂业务逻辑、需要长期维护的大型项目。
  4. 优势
    • 强制开发者提前梳理需求和接口设计,减少后期需求偏差;
    • 测试用例覆盖率更高,天然覆盖正常、边界、异常场景;
    • 重构无风险,测试用例作为 “安全网”,确保重构不破坏原有功能;
    • 代码耦合度更低,因为先写测试会倒逼开发者设计可测试的代码(如依赖接口而非具体实现)。

2、按依赖处理方式划分

(技术实现维度)

这是单元测试落地的核心技术维度,决定了如何隔离外部依赖,保证测试的独立性。

1. 基于 Mock/Stub 的单元测试

(主流方式)

  • 核心定义:当被测单元依赖外部资源(数据库、RPC 服务、HTTP 接口、文件系统等)时,通过 ** 模拟(Mock)桩(Stub)** 实现替代真实依赖,预设返回值或行为,从而脱离外部环境限制,专注测试业务逻辑。

  • Mock vs Stub 区别(通俗理解)

    类型 核心特征 适用场景
    Stub 仅预设固定返回值,无行为验证 只需依赖返回值完成业务逻辑测试
    Mock 不仅预设返回值,还可验证依赖方法是否被调用、调用次数、参数是否正确 需要验证业务逻辑对依赖的调用行为
  • 实现方式

    1. 手动编写 Mock/Stub(简单场景):如之前 Go 示例中,手动实现UserDB接口的MockUserDB,预设返回值;
    2. 工具自动生成 Mock(复杂场景):Go 生态的gomock+mockgen、Java 生态的Mockito、Python 生态的unittest.mock,可根据接口自动生成 Mock 代码,支持更灵活的行为验证。
  • 优势:测试独立、快速、可复现,不受外部资源状态影响,能覆盖各种异常场景(如依赖服务报错、超时)。

  • Go 语言工具示例(gomock)

    (1)先安装工具:

    go get github.com/golang/mock/gomock
    
    go install github.com/golang/mock/mockgen@latest
    

    (2)自动生成 Mock 代码,无需手动编写,支持验证调用行为。

2. 真实依赖单元测试

(小众场景)

  • 核心定义:不使用模拟实现,直接使用真实的外部依赖(如真实数据库、真实 HTTP 服务)进行单元测试,验证被测单元与真实依赖的协作是否正常。
  • 适用场景:依赖逻辑简单、真实依赖易于搭建和控制(如本地轻量数据库 SQLite)、对依赖协作正确性要求极高的场景。
  • 优势:测试结果更贴近生产环境,能发现与真实依赖协作的潜在问题(如 SQL 语法错误、接口参数不匹配)。
  • 劣势
    • 测试执行速度慢,依赖外部资源启动和初始化;
    • 测试结果不稳定,受外部资源状态影响(如数据库数据被修改导致测试失败);
    • 测试环境搭建复杂,需要统一管理依赖配置(如数据库连接信息、服务地址)。
  • 示例:测试数据库查询函数时,直接连接本地测试 MySQL 数据库,预先插入测试数据,执行查询后验证结果,最后清理测试数据。

3、个人理解

上面的描述是大模型系统化生成的内容,下面是博主自行整理的,至于为什么会有这样一段赘述,是和下面的工具有些关联

单元测试两种开发方式:

1.方式一

先开发业务代码,后写单元测试代码(常用)

  1. interface单元测试

    1. 核心优势:

      • 完全解耦外部依赖,实现 “纯净” 测试
      • 灵活覆盖全量业务场景,无测试死角
      • 测试执行效率极高,支持高频执行
      • 为代码重构提供 “安全网”,降低重构风险
      • 倒逼良好的代码设计,提升代码质量
      • 可验证依赖方法的调用行为(进阶优势)
    2. 缺点:

      • 增加前期开发成本,引入额外代码量
      • 存在 “过度抽象” 的风险
      • 无法验证真实依赖的协作正确性
      • Mock 与真实实现可能存在 “行为不一致”
      • 对简单场景 “杀鸡用牛刀”,性价比不高
    3. 总结:

      如果代码开发的时候考虑到需要进行单元测试功能开发,可以直接在业务功能开发时进行单元测试的预先埋点处理,做好接口的开发,不过一般情况下大家的开发习惯都不会考虑单元测试这种情况,这时候在想要回去处理单元测试,interface这种方式就会极为麻烦和笨重,单测时间成本成指数级增长。

  2. 使用单元测试工具

    1. 内置核心工具:testing包(基础基石)

    2. 工具包(具体使用方法和功能下面介绍)

      • 接口测试工具:httptest
      • 数据库测试工具:go-sqlmock
      • 打桩测试工具:gomonkey
    3. 优点:

      在业务逻辑代码开发完成后几乎可以不调整原始逻辑代码进行单元测试代码开发

2.方式二

先开发单元测试代码,后写逻辑代码(很少见,不介绍)

二、单元测试命令

go test -v

运行当前包下单元测试;-v 打印详细日志

go test -run "^$"

运行单元测试函数

go test -v 文件名_test.go 业务文件.go

运行单文件单元测试函数

go test ./...

运行整个目录的单元测试文件,包括子目录下的单元测试文件

go test -cover

覆盖率统计

go test ./... -cover

整个目录覆盖率统计,包括子目录

go test -coverprofile=cover.out

当前目录,执行测试 + 生成覆盖率统计的【原始数据文件】cover.out (核心基础,必须先执行)

go tool cover -func=cover.out

当前目录,以【纯文本、按函数】展示覆盖率详情(终端直接看,快速统计)

go tool cover -html=cover.out -o cover.html

当前目录,生成【可视化HTML报告文件】cover.html(最实用,精准看哪行代码未覆盖)

go test -coverprofile=cover.out ./...

整个目录操作,包括子目录

go test -run "^$" -gcflags=all=-l

-gcflags=all=-l 禁用所有包的内联优化,gomonkey官方推荐的打桩必须使用,不过这个和覆盖率-cover参数有冲突,推荐覆盖率计算和单测执行分开使用

go test -run "^$" -cover -gcflags=all=-l -covermode=atomic

-covermode=atomic 可以强制禁用内内联优化和执行覆盖率,不过有时不会生效

三、单元测试工具

主要介绍三个工具:httptest、go-sqlmock、gomonkey

1、httptest

介绍:Go 内置标准库net/http/httptest,核心用途用于测试net/http构建的HTTP服务(如API接口、Web服务等),它可以模拟HTTP请求发送和HTTP响应的接收,无需启动真实的HTTP服务器即可完成接口测试,极大提升了测试的便捷性和执行效率

优点:

  • 无需启动真实服务器:无需调用 http.ListenAndServe 启动端口监听,直接测试 HTTP 处理器(Handler/HandlerFunc),测试执行更高效。
  • 脱离网络依赖:模拟 HTTP 请求与响应的完整生命周期,不受网络波动、端口占用等外部因素影响,测试结果稳定可复现。
  • 精准捕获响应细节:可完整获取响应状态码、响应头、响应体等所有信息,便于精准断言验证。
  • 支持两种核心测试场景:
    • 测试 HTTP 处理器(直接调用 ServeHTTP,最常用)
    • 启动临时测试服务器(模拟真实服务端,用于客户端测试或集成测试)

1.安装

内置工具可以直接使用

2.使用示例

相关代码在gitee代码仓库的示例代码中,仓库地址请看博客开头

blog.go

package httptest_demo

import (
	"errors"
	"fmt"
	"io"
	"net/http"
)

func SearchHttp(targetURL string) (interface{}, error) {
	resp, err := http.Get(targetURL)
	if err != nil {
		errMsg := fmt.Sprintf("发送 GET 请求失败:%v", err)
		fmt.Println(errMsg)
		return nil, errors.New(errMsg)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		errMsg := fmt.Sprintf("请求失败,状态码:%d,状态信息:%s", resp.StatusCode, resp.Status)
		fmt.Println(errMsg)
		return nil, errors.New(errMsg)
	}

	bodyBytes, _ := io.ReadAll(resp.Body)
	return string(bodyBytes), nil
}

blog_test.go

package httptest_demo

import (
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
)

func TestSearchHttp(t *testing.T) {
	// --------------- 1. 定义所有测试场景的表驱动用例(循环遍历执行) ---------------
	testCases := []struct {
		name             string        // 用例名称
		prepareFunc      func() string // 前置准备:创建模拟服务器/构造URL,返回待请求的URL
		expectedErr      bool          // 是否预期返回错误
		errContains      string        // 预期错误信息包含的关键字(非空则验证)
		expectedNonEmpty bool          // 正常场景下,是否预期返回非空字符串
	}{
		// 场景1:正常请求(200 OK,响应体正常)
		{
			name: "正常请求-状态码200",
			prepareFunc: func() string {
				// 启动模拟HTTP服务器,返回200和测试响应体
				mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					mockBody := "<!DOCTYPE html><html><title>百度一下</title></html>"
					w.WriteHeader(http.StatusOK)
					_, _ = w.Write([]byte(mockBody))
				}))
				// 关键:将mockServer放入测试上下文,确保后续能关闭(避免资源泄露)
				t.Cleanup(func() { mockServer.Close() })
				return mockServer.URL
			},
			expectedErr:      false,
			errContains:      "",
			expectedNonEmpty: true,
		},
		// 场景2:请求失败(无效URL,模拟网络异常)
		{
			name: "异常场景-无效URL请求失败",
			prepareFunc: func() string {
				// 返回一个无效的URL,触发http.Get请求失败
				return "http://invalid-xxx-url-12345/"
			},
			expectedErr:      true,
			errContains:      "发送 GET 请求失败",
			expectedNonEmpty: false,
		},
		// 场景3:状态码非200(模拟404 Not Found)
		{
			name: "异常场景-状态码404",
			prepareFunc: func() string {
				// 启动模拟HTTP服务器,返回404状态码
				mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					w.WriteHeader(http.StatusNotFound)
					_, _ = w.Write([]byte("页面不存在"))
				}))
				t.Cleanup(func() { mockServer.Close() })
				return mockServer.URL
			},
			expectedErr:      true,
			errContains:      "请求失败,状态码:404",
			expectedNonEmpty: false,
		},
		// 场景4:状态码非200(模拟500服务器内部错误)
		{
			name: "异常场景-状态码500",
			prepareFunc: func() string {
				// 启动模拟HTTP服务器,返回500状态码
				mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					w.WriteHeader(http.StatusInternalServerError)
					_, _ = w.Write([]byte("服务器内部错误"))
				}))
				t.Cleanup(func() { mockServer.Close() })
				return mockServer.URL
			},
			expectedErr:      true,
			errContains:      "请求失败,状态码:500",
			expectedNonEmpty: false,
		},
	}

	// --------------- 2. 循环遍历所有测试用例,统一执行验证 ---------------
	for _, tc := range testCases {
		// 循环内使用t.Run创建子用例(便于精准定位失败场景,不影响其他用例)
		t.Run(tc.name, func(t *testing.T) {
			// 步骤1:执行前置准备,获取待请求的URL
			targetURL := tc.prepareFunc()

			// 步骤2:调用被测函数
			result, err := SearchHttp(targetURL)

			// 步骤3:统一断言验证
			// 3.1 验证错误是否符合预期
			if (err != nil) != tc.expectedErr {
				t.Fatalf("错误预期不符:预期是否错误[%t],实际是否错误[%t],错误信息[%v]",
					tc.expectedErr, err != nil, err)
			}

			// 3.2 若预期错误,验证错误信息是否包含指定关键字
			if tc.expectedErr && tc.errContains != "" {
				if !strings.Contains(err.Error(), tc.errContains) {
					t.Errorf("错误信息不符:预期包含[%s],实际错误[%v]", tc.errContains, err)
				}
			}

			// 3.3 验证返回值是否符合预期
			if tc.expectedErr {
				// 异常场景:预期返回nil
				if result != nil {
					t.Errorf("异常场景预期返回nil,实际返回[%v],类型[%T]", result, result)
				}
			} else {
				// 正常场景:验证返回值是string类型,且非空(若预期非空)
				resultStr, ok := result.(string)
				if !ok {
					t.Fatalf("正常场景预期返回string类型,实际返回[%T]", result)
				}
				if tc.expectedNonEmpty && len(resultStr) == 0 {
					t.Error("正常场景预期返回非空字符串,实际返回空字符串")
				}
			}
		})
	}
}

命令行执行命令

go test -cover

结果:

PS D:\wyl\workspace\go\tracer\logic\httptest_demo> go test -cover         
发送 GET 请求失败:Get "http://invalid-xxx-url-12345/": dial tcp: lookup invalid-xxx-url-12345: no such host
请求失败,状态码:404,状态信息:404 Not Found
请求失败,状态码:500,状态信息:500 Internal Server Error
PASS
coverage: 100.0% of statements
ok      tracer/logic/httptest_demo      1.683s

2、go-sqlmock

介绍:gosqlmock是一个用于模拟数据库 /sql 驱动的库,核心作用是在不依赖真实数据库实例的情况下,对数据库相关逻辑进行单元测试,避免测试过程中操作真实数据、产生脏数据或依赖数据库服务可用性。

优点:

  • 解除真实数据库依赖,保证测试独立、稳定、无脏数据
  • 精准控制数据库行为,覆盖常规 / 异常全量测试场景
  • 兼容 database/sql 标准库和主流 ORM,无侵入式集成
  • 严格验证预期行为,提升测试准确性,发现隐藏问题
  • 轻量级无冗余,内存级执行,测试性能优异
  • 支持正则匹配,灵活适配复杂 SQL 场景

1.安装

github地址

go get github.com/DATA-DOG/go-sqlmock

2.使用示例

相关代码在gitee代码仓库的示例代码中,仓库地址请看博客开头

(1)查询mock

price_policy.go

package model

import (
	"gorm.io/gorm"
)

type PricePolicy struct {
	gorm.Model
	Catogory      string `gorm:"type:varchar(64)" json:"catogory" label:"收费类型"`
	Title         string `gorm:"type:varchar(64)" json:"title" label:"标题"`
	Price         uint64 `gorm:"type:int(5)" json:"httptest_demo" label:"价格"`
	ProjectNum    uint64 `json:"project_num" label:"项目数量"`
	ProjectMember uint64 `json:"project_member" label:"项目成员人数"`
	ProjectSpace  uint64 `json:"project_space" label:"每个项目空间" help_text:"单位是M"`
	PerFileSize   uint64 `json:"per_file_size" label:"单文件大小" help_text:"单位是M"`
}

// GetAllBlog 查询所有博客信息
func GetAllBlog() PricePolicy {
	var allBlog PricePolicy
	DB.Find(&allBlog)
	return allBlog
}

// TypeBlog 根据类型查找博客
func TypeBlog(tyb string) PricePolicy {
	var typeBlog PricePolicy
	DB.Model(&PricePolicy{}).Where("type=?", tyb).Find(&typeBlog)
	return typeBlog
}

// TopBlog 置顶博客查询
func TopBlog(top string) PricePolicy {
	var topBlog PricePolicy
	DB.Model(&PricePolicy{}).Where("top=?", top).Find(&topBlog)
	return topBlog
}

price_policy_test.go

package model

import (
	"github.com/DATA-DOG/go-sqlmock"
	"github.com/stretchr/testify/assert"
	"gorm.io/driver/mysql"
	"gorm.io/gorm"
	"testing"
	"time"
)

// TestGetAllBlog GetAllBlog 函数单元测试
func TestGetAllBlog(t *testing.T) {
	// 步骤1:创建 sqlmock 模拟连接(内存级,无真实数据库依赖)
	// sqlmock.New() 返回 mockDB(*sql.DB)、mock(sqlmock.Sqlmock)、error
	mockSqlDB, mock, err := sqlmock.New()
	assert.NoError(t, err, "创建 sqlmock 连接失败")
	defer mockSqlDB.Close() // 测试结束关闭模拟连接

	// 步骤2:将 sqlmock 连接适配为 GORM 可用的 DB 实例
	// 关键:使用 gorm mysql 驱动,传入 mock 的 *sql.DB 实例
	gormDB, err := gorm.Open(mysql.New(mysql.Config{
		Conn:                      mockSqlDB, // 绑定 sqlmock 的连接
		SkipInitializeWithVersion: true,      // 跳过 MySQL 版本检测(模拟连接无需版本信息)
	}), &gorm.Config{})
	assert.NoError(t, err, "GORM 绑定 sqlmock 连接失败")

	// 步骤3:替换全局 DB 为 mock 的 GORM DB(核心:让业务函数使用 mock 连接)
	DB = gormDB

	// 步骤4:构造模拟返回数据(与 PricePolicy 字段对应,需包含 gorm.Model 的默认字段)
	expectedPolicy := PricePolicy{
		Model: gorm.Model{
			ID:        1,
			CreatedAt: time.Time{}, // 测试中可忽略时间字段,若需精确匹配可赋值 time.Time 实例
			UpdatedAt: time.Time{},
			DeletedAt: gorm.DeletedAt{},
		},
		Catogory:      "个人版",
		Title:         "基础收费套餐",
		Price:         99,
		ProjectNum:    5,
		ProjectMember: 10,
		ProjectSpace:  1024,
		PerFileSize:   50,
	}

	// 步骤5:设置 sqlmock 预期(关键:匹配 GORM 自动生成的 SQL 语句)
	// GORM 的 Find(&allBlog) 会生成 SELECT * FROM `price_policies` 语句(表名默认是结构体小写复数)
	// 使用正则匹配,忽略无关空格和潜在的字段顺序差异
	rows := sqlmock.NewRows([]string{
		"id", "created_at", "updated_at", "deleted_at",
		"catogory", "title", "httptest_demo", "project_num",
		"project_member", "project_space", "per_file_size",
	}).AddRow(
		expectedPolicy.ID, expectedPolicy.CreatedAt, expectedPolicy.UpdatedAt, expectedPolicy.DeletedAt,
		expectedPolicy.Catogory, expectedPolicy.Title, expectedPolicy.Price, expectedPolicy.ProjectNum,
		expectedPolicy.ProjectMember, expectedPolicy.ProjectSpace, expectedPolicy.PerFileSize,
	)

	// 预设查询预期:匹配 GORM 生成的 SELECT 语句
	mock.ExpectQuery("^SELECT \\* FROM `price_policies`").
		WillReturnRows(rows) // 设置查询返回的模拟数据

	// 步骤6:执行待测试函数
	_ = GetAllBlog()

	// 步骤7:验证结果
	// 关键:验证所有 sqlmock 预期都已被执行(无遗漏、无多余操作)
	assert.NoError(t, mock.ExpectationsWereMet(), "存在未满足的 sqlmock 预期")
}

命令行执行命令

go test -run "^TestGetAllBlog$" -cover

结果:

PS D:\wyl\workspace\go\tracer\model> go test -run "^TestGetAllBlog$" -cover           
PASS
coverage: 13.6% of statements
ok      tracer/model    0.082s
(2)增删改mock

这个需要注意,gorm在执行增删改动作底层使用了事务操作,所以代码中没有使用到事务时,在mock中也要mock事务操作

user.go

package model

import (
	"gorm.io/gorm"
)

// UserInfo 用户表

type UserInfo struct {
	gorm.Model
	UserName string `gorm:"type:varchar(32);unique" json:"user_name" label:"用户名"`
	Password string `gorm:"size:60" json:"password" label:"密码"`
	Phone    string `gorm:"size:11;unique" json:"phone" label:"手机号"`
	Email    string `gorm:"size:32;unique" json:"email" label:"邮箱"`
}

// GetAllUser 查询所有用户信息
func GetAllUser() (users []UserInfo, err error) {
	err = DB.Model(&UserInfo{}).Find(&users).Error
	return
}

func UpdateUserPhone(id int64, phone string) (err error) {
	err = DB.Model(&UserInfo{}).Where("id = ?", id).Updates(map[string]interface{}{
		"phone": phone,
	}).Error
	return
}

user_test.go

package model

import (
	"errors"
	"github.com/DATA-DOG/go-sqlmock"
	"github.com/stretchr/testify/assert"
	"gorm.io/driver/mysql"
	"gorm.io/gorm"
	"testing"
)

// TestUpdateUserPhone_success 测试场景1:更新手机号【成功】- 正常更新匹配ID的用户手机号
func TestUpdateUserPhone_success(t *testing.T) {
	// 步骤1:创建 sqlmock 模拟连接(内存级,无真实数据库依赖)
	// sqlmock.New() 返回 mockDB(*sql.DB)、mock(sqlmock.Sqlmock)、error
	mockSqlDB, mock, err := sqlmock.New()
	assert.NoError(t, err, "创建 sqlmock 连接失败")
	defer mockSqlDB.Close() // 测试结束关闭模拟连接

	// 步骤2:将 sqlmock 连接适配为 GORM 可用的 DB 实例
	// 关键:使用 gorm mysql 驱动,传入 mock 的 *sql.DB 实例
	gormDB, err := gorm.Open(mysql.New(mysql.Config{
		Conn:                      mockSqlDB, // 绑定 sqlmock 的连接
		SkipInitializeWithVersion: true,      // 跳过 MySQL 版本检测(模拟连接无需版本信息)
	}), &gorm.Config{})
	assert.NoError(t, err, "GORM 绑定 sqlmock 连接失败")

	// 步骤3:替换全局 DB 为 mock 的 GORM DB(核心:让业务函数使用 mock 连接)
	DB = gormDB

	// 测试入参
	testID := int64(1)
	testPhone := "13800138000"

	// 核心mock断言:匹配GORM生成的update语句
	// ^ 匹配开头  $ 匹配结尾  \? 是sql占位符的正则转义
	mock.ExpectBegin()
	mock.ExpectExec("^UPDATE `user_infos` SET `phone`=\\?,`updated_at`=\\? WHERE id = \\? AND `user_infos`.`deleted_at` IS NULL$").
		WithArgs(testPhone, sqlmock.AnyArg(), testID). // phone=入参值, updated_at是gorm自动填充用任意值匹配, id=入参值
		WillReturnResult(sqlmock.NewResult(testID, 1)) // 返回执行结果:影响行数1行
	mock.ExpectCommit()

	// 执行待测试的业务函数
	err = UpdateUserPhone(testID, testPhone)

	// 断言:执行无错误
	if err != nil {
		t.Errorf("更新手机号失败,预期无错误,实际错误:%v", err)
	}
}

func TestUpdateUserPhone_error(t *testing.T) {
	// 步骤1:创建 sqlmock 模拟连接(内存级,无真实数据库依赖)
	// sqlmock.New() 返回 mockDB(*sql.DB)、mock(sqlmock.Sqlmock)、error
	mockSqlDB, mock, err := sqlmock.New()
	assert.NoError(t, err, "创建 sqlmock 连接失败")
	defer mockSqlDB.Close() // 测试结束关闭模拟连接

	// 步骤2:将 sqlmock 连接适配为 GORM 可用的 DB 实例
	// 关键:使用 gorm mysql 驱动,传入 mock 的 *sql.DB 实例
	gormDB, err := gorm.Open(mysql.New(mysql.Config{
		Conn:                      mockSqlDB, // 绑定 sqlmock 的连接
		SkipInitializeWithVersion: true,      // 跳过 MySQL 版本检测(模拟连接无需版本信息)
	}), &gorm.Config{})
	assert.NoError(t, err, "GORM 绑定 sqlmock 连接失败")

	// 步骤3:替换全局 DB 为 mock 的 GORM DB(核心:让业务函数使用 mock 连接)
	DB = gormDB

	// 测试入参
	testID := int64(1)
	testPhone := "13800138000"

	// 核心mock断言:匹配GORM生成的update语句
	// ^ 匹配开头  $ 匹配结尾  \? 是sql占位符的正则转义
	mock.ExpectBegin()
	mock.ExpectExec("^UPDATE `user_infos` SET `phone`=\\?,`updated_at`=\\? WHERE id = \\? AND `user_infos`.`deleted_at` IS NULL$").
		WithArgs(testPhone, sqlmock.AnyArg(), testID). // phone=入参值, updated_at是gorm自动填充用任意值匹配, id=入参值
		WillReturnError(errors.New("更新手机号失败"))
	mock.ExpectRollback()

	// 执行待测试的业务函数
	_ = UpdateUserPhone(testID, testPhone)
}

命令行执行命令

go test -run "^TestUpdateUserPhone_success$" -cover

结果:

PS D:\wyl\workspace\go\tracer\model> go test -run "^TestUpdateUserPhone_success$" -cover         
PASS
coverage: 9.1% of statements
ok      tracer/model    0.082s

go test -run "^TestUpdateUserPhone_error$" -cover

结果:

2026/01/11 17:05:36 D:/wyl/workspace/go/tracer/model/user.go:24 更新手机号失败
[0.506ms] [rows:0] UPDATE `user_infos` SET `phone`='13800138000',`updated_at`='2026-01-11 17:05:36.624' WHERE id = 1 AND `user_infos`.`deleted_at` IS NULL
PASS
coverage: 9.1% of statements
ok      tracer/model    0.086s

3、gomonkey

介绍:gomonkey是一款强大的运行时打桩(Mock)工具/动态 Mock 工具,能够在不修改源代码的前提下,对函数、方法、全局变量等进行动态替换,广泛用于单元测试场景。

工具很全面,可以针对数据库,外部请求http接口,变量和结构体等打桩,不过个人认为对于http接口和数据库还是使用上面的两个方法要方便一些

优点:

  • 无侵入式打桩,无需修改业务代码
  • 功能全面,支持函数、方法、全局变量等多种打桩场景
  • 支持私有成员打桩,适配遗留项目
  • 轻量级易用,API 简洁,兼容主流框架
  • 灵活控制打桩生命周期,精准适配测试需求
  • x86_64 (Intel/AMD) 架构下,功能基本完整、稳定,是生产级 Mock 工具

致命缺陷:

  • 对Windows、Mac、arm架构系统支持很不友好
    • 在 ARM64 (aarch64) 架构下,存在 大量核心功能失效、运行时崩溃、Mock 无效果 的缺陷
    • 缺陷是 gomonkey 的底层实现原理 导致的,而非 Go 语言 / ARM64 的兼容性问题,官方至今未彻底修复
  • Go1.18 + 版本 中更严重,Go1.17 及以下 ARM64 版本问题稍少,但依然存在关键缺陷

1.安装

github地址

go get github.com/agiledragon/gomonkey/v2

2.方法简介

官方推荐命令(禁用内联优化):go test -gcflags=all=-l

gomonkey有两种调用形式:

  • 全局函数调用
  • 结构体方法调用

区别:

  • 归属关系完全不同:前者包级别全局函数,不属于任何结构体,直接通过包名调用即可;后者结构体的指针接收者成员方法,属于结构体的一部分,必须通过 Patches 结构体的实例对象(指针)才能调用。
  • 调用前置条件不同:前者调用前,不需要手动创建Patches实例,全局函数内部帮你自动创建,直接调用即可;后者调用前,必须先手动创建一个Patches实例
  • 底层执行逻辑 - 最核心的调用链路不同:
    • 全局函数的执行链路(一行顶三步)
      • 步骤1:内部调用 create() 创建一个全新的空 Patches 实例
      • 步骤2:立刻调用这个新实例的 成员方法 ApplyMethodFunc
      • 步骤3:返回这个实例对象本身(链式调用基础)
      • 补充:源码中的create()等价于NewPatches(),都是初始化空的 Patches 结构体
    • 结构体成员方法的执行链路
      • 根据入参target(结构体 / 结构体指针 /reflect.Type)和methodName,反射获取目标方法;
      • 通过funcToMethod做「普通函数 → 结构体方法」的转换(核心逻辑,后面讲);
      • 调用ApplyCore执行底层的内存指令改写(打桩核心:函数地址跳转);
      • 返回Patches实例本身,支持链式调用。
      • 补充:是真实的核心实现,做了所有的实际工作

共同点:

  • 最终实现的业务功能完全一致:都是为「结构体的指定方法」打桩,替换原方法的执行逻辑;
  • 入参校验、底层打桩逻辑完全一致:全局函数只是转发调用,所有的校验(方法是否存在、函数签名是否匹配)、内存指令改写,都是结构体成员方法做的;
  • 都支持链式调用:返回值都是*Patches,都可以继续追加.ApplyXXX()系列方法;
  • 都需要手动调用Reset()还原桩:不管是全局函数返回的实例,还是手动创建的实例,最终都要调用Reset(),否则会导致后续测试被污染。

方法目录:

官方包源码文件,地址

package gomonkey

import (
"fmt"
"reflect"
"syscall"
"unsafe"

"github.com/agiledragon/gomonkey/v2/creflect"
)

type Patches struct {
originals    map[uintptr][]byte
targets      map[uintptr]uintptr
values       map[reflect.Value]reflect.Value
valueHolders map[reflect.Value]reflect.Value
}

type Params []interface{}
type OutputCell struct {
Values Params
Times  int
}

func ApplyFunc(target, double interface{}) *Patches {
return create().ApplyFunc(target, double)
}

func ApplyMethod(target interface{}, methodName string, double interface{}) *Patches {
return create().ApplyMethod(target, methodName, double)
}

func ApplyMethodFunc(target interface{}, methodName string, doubleFunc interface{}) *Patches {
return create().ApplyMethodFunc(target, methodName, doubleFunc)
}

func ApplyPrivateMethod(target interface{}, methodName string, double interface{}) *Patches {
return create().ApplyPrivateMethod(target, methodName, double)
}

func ApplyGlobalVar(target, double interface{}) *Patches {
return create().ApplyGlobalVar(target, double)
}

func ApplyFuncVar(target, double interface{}) *Patches {
return create().ApplyFuncVar(target, double)
}

func ApplyFuncSeq(target interface{}, outputs []OutputCell) *Patches {
return create().ApplyFuncSeq(target, outputs)
}

func ApplyMethodSeq(target interface{}, methodName string, outputs []OutputCell) *Patches {
return create().ApplyMethodSeq(target, methodName, outputs)
}

func ApplyFuncVarSeq(target interface{}, outputs []OutputCell) *Patches {
return create().ApplyFuncVarSeq(target, outputs)
}

func ApplyFuncReturn(target interface{}, output ...interface{}) *Patches {
return create().ApplyFuncReturn(target, output...)
}

func ApplyMethodReturn(target interface{}, methodName string, output ...interface{}) *Patches {
return create().ApplyMethodReturn(target, methodName, output...)
}

func ApplyFuncVarReturn(target interface{}, output ...interface{}) *Patches {
return create().ApplyFuncVarReturn(target, output...)
}

func create() *Patches {
return &Patches{originals: make(map[uintptr][]byte), targets: map[uintptr]uintptr{},
values: make(map[reflect.Value]reflect.Value), valueHolders: make(map[reflect.Value]reflect.Value)}
}

func NewPatches() *Patches {
return create()
}

func (this *Patches) Origin(fn func()) {
for target, bytes := range this.originals {
modifyBinary(target, bytes)
}
fn()
for target, targetPtr := range this.targets {
code := buildJmpDirective(targetPtr)
modifyBinary(target, code)
}
}

func (this *Patches) ApplyFunc(target, double interface{}) *Patches {
t := reflect.ValueOf(target)
d := reflect.ValueOf(double)
return this.ApplyCore(t, d)
}

func (this *Patches) ApplyMethod(target interface{}, methodName string, double interface{}) *Patches {
m, ok := castRType(target).MethodByName(methodName)
if !ok {
panic("retrieve method by name failed")
}
d := reflect.ValueOf(double)
return this.ApplyCore(m.Func, d)
}

func (this *Patches) ApplyMethodFunc(target interface{}, methodName string, doubleFunc interface{}) *Patches {
m, ok := castRType(target).MethodByName(methodName)
if !ok {
panic("retrieve method by name failed")
}
d := funcToMethod(m.Type, doubleFunc)
return this.ApplyCore(m.Func, d)
}

func (this *Patches) ApplyPrivateMethod(target interface{}, methodName string, double interface{}) *Patches {
m, ok := creflect.MethodByName(castRType(target), methodName)
if !ok {
panic("retrieve method by name failed")
}
d := reflect.ValueOf(double)
return this.ApplyCoreOnlyForPrivateMethod(m, d)
}

func (this *Patches) ApplyGlobalVar(target, double interface{}) *Patches {
t := reflect.ValueOf(target)
if t.Type().Kind() != reflect.Ptr {
panic("target is not a pointer")
}

this.values[t] = reflect.ValueOf(t.Elem().Interface())
d := reflect.ValueOf(double)
t.Elem().Set(d)
return this
}

func (this *Patches) ApplyFuncVar(target, double interface{}) *Patches {
t := reflect.ValueOf(target)
d := reflect.ValueOf(double)
if t.Type().Kind() != reflect.Ptr {
panic("target is not a pointer")
}
this.check(t.Elem(), d)
return this.ApplyGlobalVar(target, double)
}

func (this *Patches) ApplyFuncSeq(target interface{}, outputs []OutputCell) *Patches {
funcType := reflect.TypeOf(target)
t := reflect.ValueOf(target)
d := getDoubleFunc(funcType, outputs)
return this.ApplyCore(t, d)
}

func (this *Patches) ApplyMethodSeq(target interface{}, methodName string, outputs []OutputCell) *Patches {
m, ok := castRType(target).MethodByName(methodName)
if !ok {
panic("retrieve method by name failed")
}
d := getDoubleFunc(m.Type, outputs)
return this.ApplyCore(m.Func, d)
}

func (this *Patches) ApplyFuncVarSeq(target interface{}, outputs []OutputCell) *Patches {
t := reflect.ValueOf(target)
if t.Type().Kind() != reflect.Ptr {
panic("target is not a pointer")
}
if t.Elem().Kind() != reflect.Func {
panic("target is not a func")
}

funcType := reflect.TypeOf(target).Elem()
double := getDoubleFunc(funcType, outputs).Interface()
return this.ApplyGlobalVar(target, double)
}

func (this *Patches) ApplyFuncReturn(target interface{}, returns ...interface{}) *Patches {
funcType := reflect.TypeOf(target)
t := reflect.ValueOf(target)
outputs := []OutputCell{{Values: returns, Times: -1}}
d := getDoubleFunc(funcType, outputs)
return this.ApplyCore(t, d)
}

func (this *Patches) ApplyMethodReturn(target interface{}, methodName string, returns ...interface{}) *Patches {
m, ok := reflect.TypeOf(target).MethodByName(methodName)
if !ok {
panic("retrieve method by name failed")
}

outputs := []OutputCell{{Values: returns, Times: -1}}
d := getDoubleFunc(m.Type, outputs)
return this.ApplyCore(m.Func, d)
}

func (this *Patches) ApplyFuncVarReturn(target interface{}, returns ...interface{}) *Patches {
t := reflect.ValueOf(target)
if t.Type().Kind() != reflect.Ptr {
panic("target is not a pointer")
}
if t.Elem().Kind() != reflect.Func {
panic("target is not a func")
}

funcType := reflect.TypeOf(target).Elem()
outputs := []OutputCell{{Values: returns, Times: -1}}
double := getDoubleFunc(funcType, outputs).Interface()
return this.ApplyGlobalVar(target, double)
}

func (this *Patches) Reset() {
for target, bytes := range this.originals {
modifyBinary(target, bytes)
delete(this.originals, target)
}

for target, variable := range this.values {
target.Elem().Set(variable)
}
}

func (this *Patches) ApplyCore(target, double reflect.Value) *Patches {
this.check(target, double)
assTarget := *(*uintptr)(getPointer(target))
original := replace(assTarget, uintptr(getPointer(double)))
if _, ok := this.originals[assTarget]; !ok {
this.originals[assTarget] = original
}
this.targets[assTarget] = uintptr(getPointer(double))
this.valueHolders[double] = double
return this
}

func (this *Patches) ApplyCoreOnlyForPrivateMethod(target unsafe.Pointer, double reflect.Value) *Patches {
if double.Kind() != reflect.Func {
panic("double is not a func")
}
assTarget := *(*uintptr)(target)
original := replace(assTarget, uintptr(getPointer(double)))
if _, ok := this.originals[assTarget]; !ok {
this.originals[assTarget] = original
}
this.targets[assTarget] = uintptr(getPointer(double))
this.valueHolders[double] = double
return this
}

func (this *Patches) check(target, double reflect.Value) {
if target.Kind() != reflect.Func {
panic("target is not a func")
}

if double.Kind() != reflect.Func {
panic("double is not a func")
}

targetType := target.Type()
doubleType := double.Type()

if targetType.NumIn() < doubleType.NumIn() ||
targetType.NumOut() != doubleType.NumOut() ||
(targetType.NumIn() == doubleType.NumIn() && targetType.IsVariadic() != doubleType.IsVariadic()) {
panic(fmt.Sprintf("target type(%s) and double type(%s) are different", target.Type(), double.Type()))
}

for i, size := 0, doubleType.NumIn(); i < size; i++ {
targetIn := targetType.In(i)
doubleIn := doubleType.In(i)

if targetIn.AssignableTo(doubleIn) {
 continue
}

panic(fmt.Sprintf("target type(%s) and double type(%s) are different", target.Type(), double.Type()))
}

for i, size := 0, doubleType.NumOut(); i < size; i++ {
targetOut := targetType.Out(i)
doubleOut := doubleType.Out(i)

if targetOut.AssignableTo(doubleOut) {
 continue
}

panic(fmt.Sprintf("target type(%s) and double type(%s) are different", target.Type(), double.Type()))
}
}

func replace(target, double uintptr) []byte {
code := buildJmpDirective(double)
bytes := entryAddress(target, len(code))
original := make([]byte, len(bytes))
copy(original, bytes)
modifyBinary(target, code)
return original
}

func getDoubleFunc(funcType reflect.Type, outputs []OutputCell) reflect.Value {
if funcType.NumOut() != len(outputs[0].Values) {
panic(fmt.Sprintf("func type has %v return values, but only %v values provided as double",
 funcType.NumOut(), len(outputs[0].Values)))
}

needReturn := false
slice := make([]Params, 0)
for _, output := range outputs {
if output.Times == -1 {
 needReturn = true
 slice = []Params{output.Values}
 break
}
t := 0
if output.Times <= 1 {
 t = 1
} else {
 t = output.Times
}
for j := 0; j < t; j++ {
 slice = append(slice, output.Values)
}
}

i := 0
lenOutputs := len(slice)
return reflect.MakeFunc(funcType, func(_ []reflect.Value) []reflect.Value {
if needReturn {
 return GetResultValues(funcType, slice[0]...)
}
if i < lenOutputs {
 i++
 return GetResultValues(funcType, slice[i-1]...)
}
panic("double seq is less than call seq")
})
}

func GetResultValues(funcType reflect.Type, results ...interface{}) []reflect.Value {
var resultValues []reflect.Value
for i, r := range results {
var resultValue reflect.Value
if r == nil {
 resultValue = reflect.Zero(funcType.Out(i))
} else {
 v := reflect.New(funcType.Out(i))
 v.Elem().Set(reflect.ValueOf(r))
 resultValue = v.Elem()
}
resultValues = append(resultValues, resultValue)
}
return resultValues
}

type funcValue struct {
_ uintptr
p unsafe.Pointer
}

func getPointer(v reflect.Value) unsafe.Pointer {
return (*funcValue)(unsafe.Pointer(&v)).p
}

func entryAddress(p uintptr, l int) []byte {
return *(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{Data: p, Len: l, Cap: l}))
}

func pageStart(ptr uintptr) uintptr {
return ptr & ^(uintptr(syscall.Getpagesize() - 1))
}

func funcToMethod(funcType reflect.Type, doubleFunc interface{}) reflect.Value {
rf := reflect.TypeOf(doubleFunc)
if rf.Kind() != reflect.Func {
panic("doubleFunc is not a func")
}
vf := reflect.ValueOf(doubleFunc)
return reflect.MakeFunc(funcType, func(in []reflect.Value) []reflect.Value {
if funcType.IsVariadic() {
 return vf.CallSlice(in[1:])
} else {
 return vf.Call(in[1:])
}
})
}

func castRType(val interface{}) reflect.Type {
if rTypeVal, ok := val.(reflect.Type); ok {
return rTypeVal
}
return reflect.TypeOf(val)
}

3.使用示例

示例(1)使用方法一全局函数调用,其余使用结构体方法调用。

(1)函数打桩方法

gomonkey.ApplyFunc()

相关代码在gitee代码仓库的示例代码中,仓库地址请看博客开头

参数示例:

// 第一个参数:函数名
// 第二个参数:打桩函数,入参和出参要保持和被打桩函数保持一致
func ApplyFunc(target, double interface{}) *Patches {
	return create().ApplyFunc(target, double)
}

user.go

package func_demo

import "tracer/model/sqlmock_demo"

// GetUserInfo 查询所有用户信息
func GetUserInfo() (interface{}, error) {
	obj, err := sqlmock_demo.GetAllUser()
	if err != nil {
		return nil, err
	}
	return obj, nil
}

user_test.go

package func_demo

import (
	"errors"
	"fmt"
	"github.com/agiledragon/gomonkey/v2"
	"gorm.io/gorm"
	"testing"
	"tracer/model"
	"tracer/model/sqlmock_demo"
)

// TestUserInfo 单个函数通过循环覆盖所有测试场景
func TestUserInfo(t *testing.T) {
	// 1. 定义测试用例结构体:封装输入(打桩参数)和预期输出
	type testCase struct {
		name          string                  // 用例名称,便于排查错误
		mockUsers     []sqlmock_demo.UserInfo // 打桩 GetAllUser 返回的用户列表
		mockErr       error                   // 打桩 GetAllUser 返回的错误
		expectedErr   error                   // 预期 UserInfoDao 返回的错误
		expectedNil   bool                    // 预期 UserInfoDao 返回的数据是否为 nil
		expectedCount int                     // 预期返回的用户数量(正常场景有效)
	}

	// 2. 构造所有测试用例(正常场景 + 异常场景)
	testCases := []testCase{
		{
			name: "正常场景-返回2个用户",
			mockUsers: []sqlmock_demo.UserInfo{
				{
					Model:    gorm.Model{ID: 1},
					UserName: "zhangsan",
					Password: "123456",
					Phone:    "13800138000",
					Email:    "zhangsan@test.com",
				},
				{
					Model:    gorm.Model{ID: 2},
					UserName: "lisi",
					Password: "654321",
					Phone:    "13900139000",
					Email:    "lisi@test.com",
				},
			},
			mockErr:       nil,
			expectedErr:   nil,
			expectedNil:   false,
			expectedCount: 2,
		},
		{
			name:          "正常场景-返回空用户列表",
			mockUsers:     []sqlmock_demo.UserInfo{},
			mockErr:       nil,
			expectedErr:   nil,
			expectedNil:   false,
			expectedCount: 0,
		},
		{
			name:          "异常场景-GORM记录不存在错误",
			mockUsers:     nil,
			mockErr:       gorm.ErrRecordNotFound,
			expectedErr:   gorm.ErrRecordNotFound,
			expectedNil:   true,
			expectedCount: 0,
		},
		{
			name:          "异常场景-自定义查询错误",
			mockUsers:     nil,
			mockErr:       errors.New("数据库连接超时"),
			expectedErr:   errors.New("数据库连接超时"),
			expectedNil:   true,
			expectedCount: 0,
		},
	}

	// 3. 循环执行所有测试用例
	for _, tc := range testCases {
		model.InitDb()
		// t.Run:为每个用例创建独立的测试上下文,互不干扰,便于定位用例错误
		t.Run(tc.name, func(t *testing.T) {
			// 步骤1:对 GetAllUser 进行动态打桩(每个用例独立打桩,避免相互影响)
			// 使用ApplyFunc打桩跨包函数
			patches := gomonkey.ApplyFunc(sqlmock_demo.GetAllUser, func() ([]sqlmock_demo.UserInfo, error) {
				// 返回当前用例预设的模拟数据和错误
				return tc.mockUsers, tc.mockErr
			})
			defer patches.Reset() // 每个用例执行完毕后重置打桩,避免污染其他用例

			// 步骤2:执行待测试函数 GetUserInfo
			_, err := GetUserInfo()
			if err != nil {
				fmt.Println(err)
			}
		})
	}
}

命令行执行命令

go test -cover -gcflags=all=-l -covermode=atomic

结果:

PS D:\wyl\workspace\go\tracer\logic\func_demo> go test -cover
PASS
coverage: 75.0% of statements
ok      tracer/logic/func_demo  0.088s

如果报错,这个问题是数据库中不存在表:

Error 1146 (42S02): Table 'tracer.user_info' doesn't exist
(2)结构体方法打桩方法

gomonkey.ApplyMethod()、gomonkey.ApplyMethodFunc()

区别:

  • 匹配方式不同ApplyMethod 是名称匹配,ApplyMethodFunc 是函数本体匹配
  • 传参核心不同ApplyMethod 必须传 reflect.Type+方法名字符串ApplyMethodFunc 直接传 原方法函数不需要反射
  • 底层逻辑不同ApplyMethod 是「反射查找方法」,ApplyMethodFunc 是「直接绑定方法函数」,后者性能更高

gomonkey.ApplyMethod() 参数示例:

// 第一个参数:要打桩的方法所属的类型,通过 reflect.TypeOf(实例) 获取,区分值接收者 / 指针接收者
// 第二个参数:要打桩的方法名,字符串格式、大小写敏感,必须和原方法名完全一致
// 第三个参数:打桩方法,入参和出参要保持和被打桩方法保持一致,但需注意需要额外传入结构体类型且必须是第一个参数
func ApplyMethod(target interface{}, methodName string, double interface{}) *Patches {
	return create().ApplyMethod(target, methodName, double)
}

gomonkey.ApplyMethodFunc() 参数示例:

// 第一个参数:要打桩的方法所属的类型,通过 reflect.TypeOf(实例) 获取,区分值接收者 / 指针接收者
// 第二个参数:要打桩的方法名,字符串格式、大小写敏感,必须和原方法名完全一致
// 第三个参数:打桩方法,入参和出参要保持和被打桩方法保持一致,不需要额外参数
func ApplyMethodFunc(target interface{}, methodName string, doubleFunc interface{}) *Patches {
	return create().ApplyMethodFunc(target, methodName, doubleFunc)
}

注意:可以不用写reflect.TypeOf(实例),如果了解方法所属类型,可以直接写方法类型,而不用reflect.TypeOf()在获取一次

gomonkey.ApplyMethod()

method_demo.go

type MethodDemo struct {
}

func (m MethodDemo) MethodDemo(ret string) {
	fmt.Println("MethodDemo:", ret)
}

method_demo_test.go

// ========== 基于 gomonkey.ApplyMethod() 的单元测试 ==========
func TestMethodDemo(t *testing.T) {
	// 1. 初始化结构体实例(当前结构体无成员变量,直接实例化即可)
	md := MethodDemo{}
	patches := gomonkey.NewPatches()
	// 铁律:延迟撤销打桩,防止污染其他测试用例,必写!
	defer patches.Reset()

	// 2. 核心:使用 gomonkey.ApplyMethod() 对【值接收者方法】打桩
	// 第一个参数:reflect.TypeOf(实例) → 因为是值接收者,直接传值类型实例即可
	// 第二个参数:被打桩的方法名字符串(严格和原方法名一致,大小写敏感)
	// 第三个参数:mock桩函数 → 入参/返回值 必须和原方法完全一致
	patches.ApplyMethod(
		// 可以直接写 MethodDemo{}
		reflect.TypeOf(md),
		"MethodDemo",
		func(m MethodDemo, ret string) {
			// 自定义的mock逻辑,替代原方法的 fmt.Println 逻辑
			t.Log("mock执行成功,入参ret:", ret)
		},
	)

	// 3. 调用原方法,验证打桩是否生效
	md.MethodDemo("hello gomonkey")
}

命令行执行命令

go test -run "^TestMethodDemo$" -cover

结果:

PS D:\wyl\workspace\go\tracer\logic\method_demo> go test -run "^TestMethodDemo$" -cover       
MethodDemo: hello gomonkey
PASS
coverage: 50.0% of statements
ok      tracer/logic/method_demo        0.303s
gomonkey.ApplyMethodFunc()

method_func_demo.go

type MethodFuncDemo struct {
}

func (m MethodFuncDemo) MethodFuncDemo(ret string) {
	fmt.Println("MethodFuncDemo:", ret)
}

method_func_demo_test.go

// ==========  【结构体值类型绑定】对应的单元测试(纯值类型,无任何指针语法) ==========
func TestMethodFuncDemo(t *testing.T) {
	// 1. 初始化【结构体值类型实例】 核心✅ 无指针&,纯结构体类型绑定
	mfd := MethodFuncDemo{}
	patches := gomonkey.NewPatches()
	// 铁律:延迟撤销打桩,防止污染其他测试用例,必写!
	defer patches.Reset()

	// 2. 核心:gomonkey.ApplyMethodFunc 三参数打桩【结构体值类型绑定】
	// 三参数固定规则:值类型 = 值实例 + 方法名字符串 + 值类型桩函数
	patches.ApplyMethodFunc(
		mfd,              // 参数1:结构体值类型实例(核心,纯值绑定)
		"MethodFuncDemo", // 参数2:方法名字符串(和值接收者方法名一致)
		func(ret string) { // 参数3:桩函数【无*号,纯结构体值类型入参】✅必匹配
			// 桩函数第一个入参必须是:纯结构体类型 MethodFuncDemo,无任何指针
			t.Log("✅ 结构体值类型绑定打桩生效,入参ret = ", ret)
		},
	)

	// 3. 调用【值接收者方法】,验证结构体值类型绑定打桩结果
	mfd.MethodFuncDemo("hello gomonkey 结构体值类型绑定")
}

命令行执行命令

go test -run "^TestMethodFuncDemo$" -gcflags=all=-l

结果:

PS D:\wyl\workspace\go\tracer\logic\method_demo> go test -run "^TestMethodFuncDemo$"
MethodFuncDemo: hello gomonkey 结构体值类型绑定
PASS
ok      tracer/logic/method_demo        0.243s

PS D:\wyl\workspace\go\tracer\logic\method_demo> go test -run "^TestMethodFuncDemo$" -gcflags=all=-l
PASS
ok      tracer/logic/method_demo        0.219s
(3)指针接收者方法打桩方法

gomonkey.ApplyMethod()、gomonkey.ApplyMethodFunc()

区别:

  • 匹配方式不同ApplyMethod 是名称匹配,ApplyMethodFunc 是函数本体匹配
  • 传参核心不同ApplyMethod 必须传 reflect.Type+方法名字符串ApplyMethodFunc 直接传 原方法函数不需要反射
  • 底层逻辑不同ApplyMethod 是「反射查找方法」,ApplyMethodFunc 是「直接绑定方法函数」,后者性能更高

gomonkey.ApplyMethod() 参数示例:

// 第一个参数:要打桩的方法所属的类型,通过 reflect.TypeOf(实例) 获取,区分值接收者 / 指针接收者
// 第二个参数:要打桩的方法名,字符串格式、大小写敏感,必须和原方法名完全一致
// 第三个参数:打桩方法,入参和出参要保持和被打桩方法保持一致,但需注意需要额外传入结构体类型且必须是第一个参数
func ApplyMethod(target interface{}, methodName string, double interface{}) *Patches {
	return create().ApplyMethod(target, methodName, double)
}

gomonkey.ApplyMethodFunc() 参数示例:

// 第一个参数:要打桩的方法所属的类型,通过 reflect.TypeOf(实例) 获取,区分值接收者 / 指针接收者
// 第二个参数:要打桩的方法名,字符串格式、大小写敏感,必须和原方法名完全一致
// 第三个参数:打桩方法,入参和出参要保持和被打桩方法保持一致,不需要额外参数
func ApplyMethodFunc(target interface{}, methodName string, doubleFunc interface{}) *Patches {
	return create().ApplyMethodFunc(target, methodName, doubleFunc)
}

注意:可以不用写reflect.TypeOf(实例),如果了解方法所属类型,可以直接写方法类型,而不用reflect.TypeOf()在获取一次

gomonkey.ApplyMethod()

method_demo.go

type MethodDemo struct {
}

func (m MethodDemo) MethodDemo(ret string) {
	fmt.Println("MethodDemo:", ret)
}

func (m *MethodDemo) MethodPointerDemo(ret string) {
	fmt.Println("MethodPointerDemo:", ret)
}

method_demo_test.go

// ========== 基于 gomonkey.ApplyMethod() 的单元测试【指针接收者专用写法】 ==========
func TestMethodPointerDemo(t *testing.T) {
	// 1. 初始化【指针类型】的结构体实例 (必须是指针,和方法接收者对应)
	md := &MethodDemo{}
	patches := gomonkey.NewPatches()
	// 铁律:延迟撤销打桩,防止污染其他测试用例,必写!
	defer patches.Reset()

	// 2. 核心:gomonkey.ApplyMethod() 打桩【指针接收者方法】
	// 第一个参数:reflect.TypeOf(md) 传入指针实例,获取 *MethodDemo 的反射类型 【必须传指针实例】
	// 第二个参数:方法名字符串,严格大小写一致
	// 第三个参数:mock桩函数,入参规则严格匹配
	patches.ApplyMethod(
		// 可以直接写 &MethodDemo{}
		reflect.TypeOf(md),
		"MethodPointerDemo",
		// mock桩函数规则:
		// ① 第一个入参:必须是【指针接收者】 *MethodDemo ,和原方法一致
		// ② 第二个入参:原方法的入参 ret string,和原方法一致
		// ③ 原方法无返回值,桩函数也必须无返回值
		func(m *MethodDemo, ret string) {
			// 自定义mock逻辑,替代原方法的 fmt.Println 逻辑
			t.Log("✅ 指针方法打桩生效,入参ret = ", ret)
		},
	)

	// 3. 调用被打桩的指针方法,验证mock是否生效
	md.MethodPointerDemo("hello gomonkey 指针接收者")
}

命令行执行命令

go test -run "^TestMethodPointerDemo$" -gcflags=all=-l

结果:

PS D:\wyl\workspace\go\tracer\logic\method_demo> go test -run "^TestMethodPointerDemo$" -cover
MethodPointerDemo: hello gomonkey 指针接收者
PASS
coverage: 25.0% of statements
ok      tracer/logic/method_demo        0.277s

PS D:\wyl\workspace\go\tracer\logic\method_demo> go test -run "^TestMethodPointerDemo$" -gcflags=all=-l                         
PASS
ok      tracer/logic/method_demo        0.266s
gomonkey.ApplyMethodFunc()

method_func_demo.go

type MethodFuncDemo struct {
}

func (m MethodFuncDemo) MethodFuncDemo(ret string) {
	fmt.Println("MethodFuncDemo:", ret)
}

func (m *MethodFuncDemo) MethodFuncPrinterDemo(ret string) {
	fmt.Println("MethodFuncPrinterDemo:", ret)
}

method_func_demo_test.go

// ========== 指针接收者 对应的 ApplyMethodFunc 单元测试 ==========
func TestMethodFuncPointerDemo(t *testing.T) {
	// 1. 初始化指针类型的结构体实例
	mfd := &MethodFuncDemo{}
	patches := gomonkey.NewPatches()
	// 铁律:延迟撤销打桩,防止污染其他测试用例,必写!
	defer patches.Reset()

	// 2. 核心:gomonkey.ApplyMethodFunc 打桩【指针接收者方法】
	// ✅ 第一个参数:直接传【指针接收者的方法本体】 语法固定:(*结构体名).方法名
	patches.ApplyMethodFunc(
		// 重点:指针接收者的方法本体写法
		mfd,
		"MethodFuncPrinterDemo",
		func(m *MethodFuncDemo, ret string) {
			// 桩函数第一个入参必须是 指针类型 *MethodFuncDemo
			t.Log("✅ MethodFuncPrinterDemo 指针接收者打桩生效,入参ret = ", ret)
		},
	)

	// 3. 调用指针方法
	mfd.MethodFuncPrinterDemo("hello gomonkey 指针接收者")
}

命令行执行命令

go test -run "^TestMethodFuncPointerDemo$" -gcflags=all=-l

结果:

PS D:\wyl\workspace\go\tracer\logic\method_demo> go test -run "^TestMethodFuncPointerDemo$" -cover
MethodFuncPrinterDemo: hello gomonkey 指针接收者
PASS
coverage: 25.0% of statements
ok      tracer/logic/method_demo        0.275s

PS D:\wyl\workspace\go\tracer\logic\method_demo> go test -run "^TestMethodFuncPointerDemo$" -gcflags=all=-l                     
PASS
ok      tracer/logic/method_demo        0.227s
(4)多次调用,顺序返回不同结果

gomonkey.ApplyMethodSeq()(高频实用)方法

处理「结构体方法被多次调用,需要返回不同结果」的核心方法,完美解决:方法第 1 次调用返回 A、第 2 次返回 B、第 N 次返回默认值 / 报错 这类高频业务场景。

这个也和上面(2)和(3)一样支持两种写法,函数和结构体方法,写法类似,这里就只写一个示例了

参数示例:

// 存储单次调用的返回值集合
type Params []interface{}
type OutputCell struct {
	Values Params // 本次要返回的参数,个数/类型必须和原方法返回值完全一致
	Times  int    // 生效次数:-1=永久生效(匹配后后续调用全用这个);>0=生效指定次数
}
// 第一个参数:要打桩的方法所属的类型,通过 reflect.TypeOf(实例) 获取,区分值接收者 / 指针接收者
// 第二个参数:要打桩的方法名,字符串格式、大小写敏感,必须和原方法名完全一致
// 第三个参数:打桩方法返回值切片,入参和出参要保持和被打桩方法保持一致
func (this *Patches) ApplyMethodSeq(target interface{}, methodName string, outputs []OutputCell) *Patches {
	m, ok := castRType(target).MethodByName(methodName)
	if !ok {
		panic("retrieve method by name failed")
	}
	d := getDoubleFunc(m.Type, outputs)
	return this.ApplyCore(m.Func, d)
}

product.go

package method_demo

import "fmt"

// ========== 1. 定义业务结构体和待打桩的方法 ==========
type Product struct {
	Id    int
	Name  string
	Stock int
}

// 结构体公有方法(首字母大写):库存查询,返回【库存数量、错误信息】
func (p *Product) GetStock() (int, error) {
	// 真实业务逻辑:查询数据库/缓存获取库存
	return p.Stock, nil
}

// ========== 2. 业务逻辑:连续调用3次GetStock方法 ==========
// 模拟业务中多次调用结构体方法的场景
func QueryStockMultiTimes(p *Product) {
	// 第1次调用
	stock1, err1 := p.GetStock()
	fmt.Printf("第1次查询库存: %d, err: %v\n", stock1, err1)

	// 第2次调用
	stock2, err2 := p.GetStock()
	fmt.Printf("第2次查询库存: %d, err: %v\n", stock2, err2)

	// 第3次调用
	stock3, err3 := p.GetStock()
	fmt.Printf("第3次查询库存: %d, err: %v\n", stock3, err3)
}

product_test.go

package method_demo

import (
	"errors"
	"github.com/agiledragon/gomonkey/v2"
	"testing"
)

func TestProduct(t *testing.T) {
	// 1. 创建patches实例,必须写defer Reset(),保证打桩还原,无残留
	patches := gomonkey.NewPatches()
	defer patches.Reset()

	// 2. 初始化结构体实例
	prod := &Product{Id: 1, Name: "苹果手机", Stock: 200}
	// 3. 核心:定义序列打桩规则
	stockSeq := []gomonkey.OutputCell{
		{Values: gomonkey.Params{100, nil}, Times: 1},                    // 规则1:第1次调用,返回100,nil
		{Values: gomonkey.Params{0, nil}, Times: 1},                      // 规则2:第2次调用,返回0,nil
		{Values: gomonkey.Params{0, errors.New("库存不足,无法下单")}, Times: 1}, // 规则3:永久生效(-1)这个参数会导致严重bug,避免使用
	}

	// 4. 执行方法序列打桩
	patches.ApplyMethodSeq(prod, "GetStock", stockSeq)

	// 5. 执行业务逻辑,触发多次调用
	QueryStockMultiTimes(prod)
}

命令行执行命令

go test -run "^TestProduct$" -cover -gcflags=all=-l -covermode=atomic

结果:

PS D:\wyl\workspace\go\tracer\logic\method_demo> go test -run "^TestProduct$" -cover -gcflags=all=-l -covermode=atomic
第1次查询库存: 100, err: <nil>
第2次查询库存: 0, err: <nil>
第3次查询库存: 0, err: 库存不足,无法下单
PASS
coverage: 54.5% of statements
ok      tracer/logic/method_demo        0.330s
(5)结构体私有方法打桩方法

gomonkey.ApplyPrivateMethod()

私有方法打桩方案:

  • 方案一:【最优推荐,零侵入、Go 官方标准】测试文件和业务文件放在同一个包下(首选,无任何副作用)
  • 方案二:【妥协方案,少量侵入】把「需要 mock 的私有方法」改为包级私有函数(适合特殊场景)
  • 方案三:【不推荐,侵入性大】把私有方法改为公有方法(万不得已才用)

ApplyPrivateMethodApplyMethod 打私有方法区别:功能上完全一致,同包下都能正常打桩、正常 mock 逻辑。唯一的区别是:

  • ApplyPrivateMethod:语义精准,一看就知道是打「私有方法」,符合官方设计;
  • ApplyMethod:语义模糊,它的设计初衷是打「公有方法」,打私有方法只是兼容生效。

这个也和上面(2)和(3)一样支持两种写法,函数和结构体方法,写法类似,这里就只写一个示例了

参数示例:

// 第一个参数:要打桩的方法所属的类型,通过 reflect.TypeOf(实例) 获取,区分值接收者 / 指针接收者
// 第二个参数:要打桩的方法名,字符串格式、大小写敏感,必须和原方法名完全一致
// 第三个参数:打桩方法,入参和出参要保持和被打桩方法保持一致
func (this *Patches) ApplyPrivateMethod(target interface{}, methodName string, double interface{}) *Patches {
	m, ok := creflect.MethodByName(castRType(target), methodName)
	if !ok {
		panic("retrieve method by name failed")
	}
	d := reflect.ValueOf(double)
	return this.ApplyCoreOnlyForPrivateMethod(m, d)
}

student.go

package private_demo

type Student struct{}

// 公有方法(对外暴露)
func (s Student) Study(course string) bool {
	// 内部调用了私有方法
	if s.checkCourse(course) { // 私有方法,首字母小写
		s.recordLog(course) // 私有方法
		return true
	}
	return false
}

// 私有方法(内部逻辑)
func (s Student) checkCourse(course string) bool {
	return course != ""
}
func (s Student) recordLog(course string) {}

student_test.go

package private_demo

import (
	"github.com/agiledragon/gomonkey/v2"
	"testing"
)

// TODO: TestStudy有问题暂时不确定是因为不支持这种写法还是代码有问题
func TestStudy(t *testing.T) {
	s := &Student{}
	patches := gomonkey.NewPatches()
	defer patches.Reset()

	// ✅ 可以直接用gomonkey打桩【私有方法】checkCourse,零权限问题!
	patches.ApplyMethod(&Student{}, "checkCourse", func(_ *Student, course string) bool {
		return true // mock私有方法返回true
	})

	res := s.Study("math")
	if !res {
		t.Error("测试失败")
	}
}

// 测试代码:指针接收者的私有方法打桩
func TestStudyPrivate(t *testing.T) {
	s := &Student{}
	patches := gomonkey.NewPatches()
	defer patches.Reset()

	patches.ApplyPrivateMethod(
		(*Student)(nil), // ✅ 指针接收者传 (*结构体)(nil)
		"checkCourse",   // 私有方法名不变
		func(_ *Student, course string) bool { // ✅ 桩函数接收者是指针类型
			return true
		},
	)
	res := s.Study("math")
	if !res {
		t.Error("测试失败")
	}
}

命令行执行命令

go test -run "^TestStudyPrivate$" -gcflags=all=-l

结果:

PS D:\wyl\workspace\go\tracer\logic\private_demo> go test -run "^TestStudyPrivate$" -gcflags=all=-l
PASS
ok      tracer/logic/private_demo       0.210s
posted @ 2026-01-03 20:55  HashFlag  阅读(25)  评论(0)    收藏  举报