Golang 中 mock 库的实现

mock 库的地址: https://github.com/golang/mock

mock 库算是 go 项目中编写单元测试时的必备库了,它分为两个模块

  1. mockgen: 可以根据接口来生成单元测试代码
  2. gomock: 利用 mockgen 生成测试代码来实现打桩 (stub) 功能

其实之前对这个库就有一些好奇,这次趁着五一在家隔离,所以看了看 mock 库的实现。

好奇点

首先列举一下我好奇的点,之后围绕着这些点在代码中寻找答案。

  1. mockgen 是如何根据接口生成代码的?
  2. gomock 是怎样将 Expect() 中指定的参数 (ArgsIn) 与执行时接收到参数进行匹配的?
  3. gomock 如何做到在执行接口时的返回在 Expect() 中指定的返回值?
  4. gomock 是怎么判断某个方法的执行次数与预期不符的?

mockgen 是如何根据接口生成代码的?

首先把 mock 库 clone 到本地,寻找 mockgen 的入口函数 (main) ,在 mockgen/mockgen.go 文件中。

我常用的利用 mockgen 生成代码的命令为为

mockgen -destination ../examplemock/example_mock.go -package examplemock -source example.go IExample

-destination 指定了 mock 代码的目标路径 (destination)

-package 指定了 mock 代码所在的包名 (packageOut)

-source 指定了源文件名 (source)

IExample 指定了接口名 (srcInterfaces)。

main() 函数的逻辑为如下几步

  1. parse flag
  2. parse 源文件得到 model.Package 对象
  3. 创建目标文件,以及文件句柄
  4. 创建代码生成器 generator 对象,并给对象内的字段赋值
  5. 利用generator 对象生成代码
  6. 将生成的代码输出到目标文件中
func main() {
	// parse flag
	flag.Parse()

	var pkg *model.Package
	var err error
	if *source != "" {
		// parse 源文件得到 pkg 对象
		pkg, err = sourceMode(*source)
	} else {
		// ...
	}
	// ...

	// 创建目标文件,以及文件句柄
	dst := os.Stdout
	if len(*destination) > 0 {
		if err := os.MkdirAll(filepath.Dir(*destination), os.ModePerm); err != nil {
			log.Fatalf("Unable to create directory: %v", err)
		}
		// ...
	}

	outputPackageName := *packageOut

	// ...
	// 创建代码生成器 generator 对象,并给对象内的字段赋值
	g := new(generator)
	if *source != "" {
		g.filename = *source
	} else {
		// ...
	}
	g.destination = *destination

	// ...
	// 利用 g (generator 对象) 生成代码
	if err := g.Generate(pkg, outputPackageName, outputPackagePath); err != nil {
		log.Fatalf("Failed generating mock: %v", err)
	}
	// 将生成的代码输出到文件中
	if _, err := dst.Write(g.Output()); err != nil {
		log.Fatalf("Failed writing to destination: %v", err)
	}
}
model.Package 是如何生成的?

先看看 Package struct 的定义:

// mockgen/model/model.go
// Package is a Go package. It may be a subset.
type Package struct {
	Name       string
	PkgPath    string
	Interfaces []*Interface
	DotImports []string
}
// Interface is a Go interface.
type Interface struct {
	Name    string
	Methods []*Method
}
// Method is a single method of an interface.
type Method struct {
	Name     string
	In, Out  []*Parameter
	Variadic *Parameter // may be nil
}
// Parameter is an argument or return parameter of a method.
type Parameter struct {
	Name string // may be empty
	Type Type
}

Package 对象是通过 sourceMode() 函数完成的,它内部用到了 go 标准库 parse 包的方法,parse 包我就不去深究了,估计是一些 AST 语法树相关的知识。这里我获取到的启发就是,如果以后碰到需要分析 go 文件的需求,可以使用标准库提供的 parse 包。

import (
	// ...
	"go/parser"
	"go/token"
	// ...
)

// sourceMode generates mocks via source file.
func sourceMode(source string) (*model.Package, error) {
	// ...
	fs := token.NewFileSet()
	file, err := parser.ParseFile(fs, source, nil, 0)
	// ...
}
generator 是如何生成代码的?

generator 对象通过 Generate 这个方法来生成代码

func (g *generator) p(format string, args ...interface{}) {
	fmt.Fprintf(&g.buf, g.indent+format+"\n", args...)
}

func (g *generator) Generate(pkg *model.Package, outputPkgName string, outputPackagePath string) error {
	// ...
	// 生成包名和 import 语句
	g.p("package %v", outputPkgName)
	g.p("")
	g.p("import (")
	g.in()
	for pkgPath, pkgName := range g.packageMap {
		if pkgPath == outputPackagePath {
			continue
		}
		g.p("%v %q", pkgName, pkgPath)
	}
	for _, pkgPath := range pkg.DotImports {
		g.p(". %q", pkgPath)
	}
	g.out()
	g.p(")")

	// 生成接口
	for _, intf := range pkg.Interfaces {
		if err := g.GenerateMockInterface(intf, outputPackagePath); err != nil {
			return err
		}
	}

	return nil
}

GenerateMockInterface 生成接口的 mock 代码,主要包括 mock struct 的定义以及 struct 的一些方法,基本都是一些 for range + Fprintf() 操作了。

gomock 的主流程

下面 gomock 相关的代码我们通过一个简单的单元测试来剖析:

func TestCallExample(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	e := NewMockExample(ctrl)
	e.EXPECT().someMethod(gomock.Any()).Return("it works!")

	fmt.Println(e.someMethod("test"))
}

Example 接口的定义为:

// Example is an interface with a non exported method
type Example interface {
	someMethod(string) string
}
MockExample 与 MockExampleMockRecorder 的定义

MockExample 与 MockExampleMockRecorder 两者互相套娃

// MockExample is a mock of Example interface.
type MockExample struct {
	ctrl     *gomock.Controller
	recorder *MockExampleMockRecorder
}

// MockExampleMockRecorder is the mock recorder for MockExample.
type MockExampleMockRecorder struct {
	mock *MockExample
}

理论上可以省掉 MockExampleMockRecorder 这个 struct 的定义,直接这样定义:

type MockExample struct {
	ctrl *gomock.Controller
	this *MockExample
}

对这么做的原因是因为我们需要两个同名的方法,这两个方法的作用不一样,一个为打桩,一个为真正执行。

MockExample 有一个 someMethod 方法,用于真正执行。

MockExampleMockRecorder 也得有一个 someMethod 方法,用于打桩。

显然无法为 MockExample 这个 struct 定义两个同名的方法。

e.EXPECT().someMethod(gomock.Any()).Return("it works!")

这句代码是三个方法的组合,分别为 EXPECT(), someMethod() 和 Return(),咱们一个一个分析

EXPECT()
// MockExample is a mock of Example interface.
type MockExample struct {
	ctrl     *gomock.Controller
	recorder *MockExampleMockRecorder
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockExample) EXPECT() *MockExampleMockRecorder {
	return m.recorder
}

Expect() 仅仅就是返回了内部的 MockExampleMockRecorder 成员。

someMethod()
// someMethod indicates an expected call of someMethod.
func (mr *MockExampleMockRecorder) someMethod(arg0 interface{}) *gomock.Call {
	// ...
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "someMethod", reflect.TypeOf((*MockExample)(nil).someMethod), arg0)
}

// RecordCallWithMethodType is called by a mock. It should not be called by user code.
func (ctrl *Controller) RecordCallWithMethodType(receiver interface{}, method string, methodType reflect.Type, args ...interface{}) *Call {
	// ...
	call := newCall(ctrl.T, receiver, method, methodType, args...)
	// ...
	ctrl.expectedCalls.Add(call)

	return call
}

someMethod() 调用了 Controller.RecordCallWithMethodType(),RecordCallWithMethodType() 内构造了一个 Call 对象,然后将 Call 对象放入了 expectedCalls 内

看看 Controller 对象的定义

type Controller struct {
	// ...
	expectedCalls *callSet
	finished      bool
}
type callSet struct {
	// Calls that are still expected.
	expected map[callSetKey][]*Call
	// Calls that have been exhausted.
	exhausted map[callSetKey][]*Call
}
// callSetKey is the key in the maps in callSet
type callSetKey struct {
	receiver interface{}
	fname    string
}

Controller.expectedCalls 类型为 callSet,内部有两个 map,expected 和 exhausted,map 的 key 是一个 struct(其实我们很少用 struct 作为 map 的 key),callSetKey 的作用就是唯一识别出一个方法。

newCall() 的实现可以先放在一旁,这里我们需要记住的就是 someMethod 方法生成了一个 Call 对象,并放入了 Controller.callSet 内,同时把 Call 对象的指针给返回了。

Return()
// Call represents an expected call to a mock.
type Call struct {
	// ...
	// actions are called when this Call is called. Each action gets the args and
	// can set the return values by returning a non-nil slice. Actions run in the
	// order they are created.
	actions []func([]interface{}) []interface{}
}

func (c *Call) Return(rets ...interface{}) *Call {
	// ...
	c.addAction(func([]interface{}) []interface{} {
		return rets
	})

	return c
}

func (c *Call) addAction(action func([]interface{}) []interface{}) {
	c.actions = append(c.actions, action)
}

所以 Return 是将 rets 封装成了一个函数,并把这个函数放入了 Call.actions 里面。

e.someMethod("test")

接下来看看 e.someMethod() 是如何执行的。

func (m *MockExample) someMethod(arg0 string) string {
	// ...
	ret := m.ctrl.Call(m, "someMethod", arg0)
	ret0, _ := ret[0].(string)
	return ret0
}

e.someMethod() 调用了 ctrl.Call(),并将 receiver,方法名和参数传递了进去。

Controller.Call()
// Call is called by a mock. It should not be called by user code.
func (ctrl *Controller) Call(receiver interface{}, method string, args ...interface{}) []interface{} {
   // ...
   actions := func() []func([]interface{}) []interface{} {
      // ...
	  // 从 expectedCalls 找到之前注册的方法
      expected, err := ctrl.expectedCalls.FindMatch(receiver, method, args)
      // ...
      actions := expected.call()
      if expected.exhausted() {
          // 如果方法的次数用尽 (exhausted) 了,就 Remove 掉
         ctrl.expectedCalls.Remove(expected)
      }
      // 返回了 actions, 与上文中 Return() 里面的 action 联动
      return actions
   }()

   var rets []interface{}
   // 执行 actions 得到返回值 rets
   for _, action := range actions {
      if r := action(args); r != nil {
         rets = r
      }
   }

   // 返回
   return rets
}

Controller.Call() 主要做了如下几步操作:

  1. 调用 expectedCalls.FindMatch() 方法找到之前注册的方法,与上文提到的 someMethod() 里面对方法的注册联动。FindMatch() 的实现后文会说。

  2. 执行 expected.call(),它的作用很简单,就是增加一次调用次数

    func (c *Call) call() []func([]interface{}) []interface{} {
    	c.numCalls++
    	return c.actions
    }
    
  3. 通过 expected.exhausted() 检查调用次数是否用尽 (感觉 exhausted 这个方法名取的挺好的,exhausted 是精疲力竭的意思,精疲力竭 ≈ 用尽)

    func (c *Call) exhausted() bool {
    	return c.numCalls >= c.maxCalls
    }
    
  4. 如果次数用尽了,就通过 Remove() 将 Call 对象从 callSet.expected 移动到 callSet.exhausted 中

    func (cs callSet) Remove(call *Call) {
    	key := callSetKey{call.receiver, call.method}
    	calls := cs.expected[key]
    	for i, c := range calls {
    		if c == call {
    			// maintain order for remaining calls
    			cs.expected[key] = append(calls[:i], calls[i+1:]...)
    			cs.exhausted[key] = append(cs.exhausted[key], call)
    			break
    		}
    	}
    }
    
  5. 执行 action() ,得到返回值

ctrl.Finish()

ctrl.Finish() 会调用 ctrl.finish(),ctrl.finish() 会找到执行次数没有达标的方法,并返回错误。

func (ctrl *Controller) Finish() {
	// ...
	ctrl.finish(false, err)
}

func (ctrl *Controller) finish(cleanup bool, panicErr interface{}) {
	// ...
	// Check that all remaining expected calls are satisfied.
	failures := ctrl.expectedCalls.Failures()
	for _, call := range failures {
		ctrl.T.Errorf("missing call(s) to %v", call)
	}
	if len(failures) != 0 {
		if !cleanup {
			ctrl.T.Fatalf("aborting test due to missing call(s)")
			return
		}
		ctrl.T.Errorf("aborting test due to missing call(s)")
	}
}

// Failures returns the calls that are not satisfied.
func (cs callSet) Failures() []*Call {
	failures := make([]*Call, 0, len(cs.expected))
	for _, calls := range cs.expected {
		for _, call := range calls {
			if !call.satisfied() {
				failures = append(failures, call)
			}
		}
	}
	return failures
}

// Returns true if the minimum number of calls have been made.
func (c *Call) satisfied() bool {
	return c.numCalls >= c.minCalls
}

此外,在调用 NewController() 函数时,也会将 ctrl.finish() 注册到 Test.Cleanup() 中。

func NewController(t TestReporter) *Controller {
	// ...
    ctrl := &Controller{
        T:             h,
        expectedCalls: newCallSet(),
	}
	if c, ok := isCleanuper(ctrl.T); ok {
		c.Cleanup(func() {
			ctrl.T.Helper()
			ctrl.finish(true, nil)	// 这里
		})
	}

	return ctrl
}

到此为止,我们已经知道 gomock 的主流程了,可以解答最初的疑问了。

gomock 是怎样将 Expect() 中指定的参数 (ArgsIn) 与执行时接收到参数进行匹配的?

这里就需要涉及到 FindMatch() 了。

callSet.FindMatch()
func (cs callSet) FindMatch(receiver interface{}, method string, args []interface{}) (*Call, error) {
	key := callSetKey{receiver, method}

	// Search through the expected calls.
	expected := cs.expected[key]
	var callsErrors bytes.Buffer
	for _, call := range expected {
		err := call.matches(args)
		if err != nil {
			_, _ = fmt.Fprintf(&callsErrors, "\n%v", err)
		} else {
			return call, nil
		}
	}

	// If we haven't found a match then search through the exhausted calls so we
	// get useful error messages.
	exhausted := cs.exhausted[key]
	for _, call := range exhausted {
		if err := call.matches(args); err != nil {
			_, _ = fmt.Fprintf(&callsErrors, "\n%v", err)
			continue
		}
		_, _ = fmt.Fprintf(
			&callsErrors, "all expected calls for method %q have been exhausted", method,
		)
	}

	if len(expected)+len(exhausted) == 0 {
		_, _ = fmt.Fprintf(&callsErrors, "there are no expected calls of the method %q for that receiver", method)
	}

	return nil, errors.New(callsErrors.String())
}

注释说的很明白了,FindMatch() 会先在 expected 内寻找,调用 call.matches(args) 判断参数是否匹配,如果找不到,就会返回 error,之所以额外在 exhausted 内也寻找一次,是为了得到一些有用的 error 信息。

Call.matches()
type Call struct {
	// ...
	args       []Matcher    // the args
}

type Matcher interface {
	// Matches returns whether x is a match.
	Matches(x interface{}) bool

	// String describes what the matcher matches.
	String() string
}

func (c *Call) matches(args []interface{}) error {
    // IsVariadic() 用户判断方法的参数是否是可变的 (最后一个参数是 ... 的形式)
    // 这里我们仅讨论非可变参数的情况
	if !c.methodType.IsVariadic() {
		if len(args) != len(c.args) {
			return fmt.Errorf("expected call at %s has the wrong number of arguments. Got: %d, want: %d",
				c.origin, len(args), len(c.args))
		}
        // 利用 Matches() 方法逐个比较注册的参数 (c.args) 与实际到达的参数是否一致
		for i, m := range c.args {
			if !m.Matches(args[i]) {
				return fmt.Errorf(
					"expected call at %s doesn't match the argument at index %d.\nGot: %v\nWant: %v",
					c.origin, i, formatGottenArg(m, args[i]), m,
				)
			}
		}
	} else {
		// ...
	}
	// ...
	// Check that the call is not exhausted.
	if c.exhausted() {
		return fmt.Errorf("expected call at %s has already been called the max number of times", c.origin)
	}

	return nil
}
入参的注册

参数的匹配是利用 Matcher.Matches() 完成的,但我我们却很少为 Args 实现 Matcher() 接口,我们最终用的是传一个普通的类型,例如 int,string,struct 等。

其实 gomock 在 newCall() 为我们做了从普通类型到 Matcher 接口的转换。

// newCall creates a *Call. It requires the method type in order to support
// unexported methods.
func newCall(t TestHelper, receiver interface{}, method string, methodType reflect.Type, args ...interface{}) *Call {
	// ...
	mArgs := make([]Matcher, len(args))
	for i, arg := range args {
		if m, ok := arg.(Matcher); ok {
            // 如果参数实现了 Matcher 接口,那就不用做额外工作了
			mArgs[i] = m
		} else if arg == nil {
			// Handle nil specially so that passing a nil interface value
			// will match the typed nils of concrete args.
            // Nil() 的定义见下面
			mArgs[i] = Nil()
		} else {
            // Eq() 的定义见下面
			mArgs[i] = Eq(arg)
		}
	}

	// ...
}
一些 Matcher 接口的 Implement
nilMatcher
func Nil() Matcher { return nilMatcher{} }

type nilMatcher struct{}

func (nilMatcher) Matches(x interface{}) bool {
   if x == nil {
      return true
   }

   v := reflect.ValueOf(x)
   switch v.Kind() {
   case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
      reflect.Ptr, reflect.Slice:
      return v.IsNil()
   }

   return false
}
eqMatcher
func Eq(x interface{}) Matcher { return eqMatcher{x} }

type eqMatcher struct {
	x interface{}
}

func (e eqMatcher) Matches(x interface{}) bool {
	// In case, some value is nil
	if e.x == nil || x == nil {
		return reflect.DeepEqual(e.x, x)
	}

	// Check if types assignable and convert them to common type
	x1Val := reflect.ValueOf(e.x)
	x2Val := reflect.ValueOf(x)

	if x1Val.Type().AssignableTo(x2Val.Type()) {
		x1ValConverted := x1Val.Convert(x2Val.Type())
		return reflect.DeepEqual(x1ValConverted.Interface(), x2Val.Interface())
	}

	return false
}
anyMatcher
func Any() Matcher { return anyMatcher{} }

type anyMatcher struct{}

func (anyMatcher) Matches(interface{}) bool {
	return true
}

gomock 如何做到在执行接口时的返回在 Expect() 中指定的返回值?

上面提到了,通过在 return 里面注册 actions,在执行接口时 apply actions。

gomock 是怎么判断某个方法的执行次数与预期不符的?

每执行一次,执行次数就 +1,通过 exhaust() 方法来判断调用次数不会超过预期次数,通过 satisfied() 来判断调用次数不会少于预期次数。

exhaust() 会在每次执行方法时被调用。

satisfied() 会在 Finish() 内调用。

意外的收货

在阅读 gomock 的代码时,发现它会在很多地方调用 T.Helper() 这个方法。

在利用 T.log() 打印信息时,如果当前的函数栈调用了 T.Helper(),就不再打印当前函数栈的文件行信息,而是打印该函数栈的上一层信息,例子如下,下面这几个 snippet 在利用 t.Log() (t.Log() 会调用 t.log())打印信息时,输出的行号时不一样的。

package XXXTest

import "testing"

func TestXXX(t *testing.T) {
    // t.Helper()
	Say1(t)
}
func Say1(t *testing.T) {
	// t.Helper()
	Say2(t)

}
func Say2(t *testing.T) {
	// t.Helper()
	t.Log("hello world") // xxx_test.go:16: hello world
}
package XXXTest

import "testing"

func TestXXX(t *testing.T) {
    // t.Helper()
	Say1(t)
}
func Say1(t *testing.T) {
	// t.Helper()
	Say2(t)

}
func Say2(t *testing.T) {
	t.Helper()
	t.Log("hello world") // xxx_test.go:11: hello world
}
package XXXTest

import "testing"

func TestXXX(t *testing.T) {
    // t.Helper()
	Say1(t)
}
func Say1(t *testing.T) {
	t.Helper()
	Say2(t)

}
func Say2(t *testing.T) {
	t.Helper()
	t.Log("hello world") //  xxx_test.go:7: hello world
}
package XXXTest

import "testing"

func TestXXX(t *testing.T) {
    t.Helper()
	Say1(t)
}
func Say1(t *testing.T) {
	t.Helper()
	Say2(t)

}
func Say2(t *testing.T) {
	t.Helper()
	t.Log("hello world") //  xxx_test.go:7: hello world
}
posted @ 2022-05-16 01:40  机智的小小帅  阅读(1415)  评论(0编辑  收藏  举报