用Golang手写一个RPC,理解RPC原理

代码结构

.
├── client.go
├── coder.go
├── coder_test.go
├── rpc_test.go
├── server.go
├── session.go
└── session_test.go

代码

client.go

package rpc

import (
	"net"
	"reflect"
)

// rpc 客户端实现

// 抽象客户端方法
type Client struct {
	conn net.Conn
}

// client构造方法
func NewClient(conn net.Conn) *Client {
	return &Client{conn: conn}
}

// 客户端调用服务端rpc实现
// client.RpcCall("login", &req)
func (c *Client) RpcCall(name string, fpr interface{}) {
	// 反射获取函数原型
	fn := reflect.ValueOf(fpr).Elem()
	// 客户端逻辑的实现
	f := func(args []reflect.Value) (results []reflect.Value) {
		// 从匿名函数中构建请求参数
		inArgs := make([]interface{}, 0, len(args))
		for _, v := range args {
			inArgs = append(inArgs, v.Interface())
		}
		// 组装rpc data请求数据
		reqData := RpcData{Name: name, Args: inArgs}
		// 进行数据编码
		reqByteData, err := encode(reqData)
		if err != nil {
			return
		}
		// 创建session 对象
		session := NewSession(c.conn)
		// 客户端发送数据
		err = session.Write(reqByteData)
		if err != nil {
			return
		}
		// 读取客户端数据
		rspByteData, err := session.Read()
		if err != nil {
			return
		}
		// 数据进行解码
		rspData, err := decode(rspByteData)
		if err != nil {
			return
		}
		// 处理服务端返回的数据结果
		outArgs := make([]reflect.Value, 0, len(rspData.Args))
		for i, v := range rspData.Args {
			// 数据特殊情况处理
			if v == nil {
				// reflect.Zero() 返回某类型的零值的value
				// .Out()返回函数输出的参数类型
				// 得到具体第几个位置的参数的零值
				outArgs = append(outArgs, reflect.Zero(fn.Type().Out(i)))
				continue
			}
			outArgs = append(outArgs, reflect.ValueOf(v))
		}

		return outArgs
	}

	// 函数原型到调用的关键,需要2个参数
	// 参数1:函数原型,是Type类型
	// 参数2:返回类型是Value类型
	// 简单理解:参数1是函数原型,参数2是客户端逻辑
	v := reflect.MakeFunc(fn.Type(), f)
	fn.Set(v)
}

coder.go

package rpc

import (
	"bytes"
	"encoding/gob"
	"fmt"
)

// 对传输的数据进行编解码
// 使用Golang自带的一个数据结构序列化编码/解码工具 gob

// 定义rpc数据交互式数据传输格式
type RpcData struct {
	Name string        // 调用方法名
	Args []interface{} // 调用和返回的参数列表
}

// 编码
func encode(data RpcData) ([]byte, error) {
	// gob进行编码
	var buf bytes.Buffer
	// 得到字节编码器
	encoder := gob.NewEncoder(&buf)
	// 进行编码
	if err := encoder.Encode(data); err != nil {
		fmt.Printf("gob encode failed, err: %v\n", err)
		return nil, err
	}
	return buf.Bytes(), nil
}

// 解码
func decode(data []byte) (RpcData, error) {
	// 得到字节解码器
	buf := bytes.NewBuffer(data)
	decoder := gob.NewDecoder(buf)
	// 解码数据
	var rd RpcData
	if err := decoder.Decode(&rd); err != nil {
		fmt.Printf("gob decode failed, err: %v\n", err)
		return rd, err
	}
	return rd, nil
}

server.go

package rpc

import (
	"net"
	"reflect"
)

// rpc 服务端实现

// 抽象服务端
type Server struct {
	add   string                   // 连接地址
	funcs map[string]reflect.Value // 存储方法名和方法的对应关系,服务注册
}

// server 构造方法
func NewServer(addr string) *Server {
	return &Server{add: addr, funcs: make(map[string]reflect.Value)}
}

// 注册接口
func (s *Server) Register(name string, fc interface{}) {
	if _, ok := s.funcs[name]; ok {
		return
	}
	s.funcs[name] = reflect.ValueOf(fc)
}

func (s *Server) Run() (err error) {
	listener, err := net.Listen("tcp", s.add)
	if err != nil {
		return
	}
	for {
		// 监听连接
		conn, err := listener.Accept()
		if err != nil {
			conn.Close()
			continue
		}
		// 创建会话
		session := NewSession(conn)
		// 读取会话请求数据
		reqData, err := session.Read()
		if err != nil {
			conn.Close()
			continue
		}
		// 数据解码
		rpcReqData, err := decode(reqData)
		// 获取客户端要调用的方法
		fc, ok := s.funcs[rpcReqData.Name];
		if !ok {
			conn.Close()
			continue
		}
		// 获取请求的参数列表
		args := make([]reflect.Value, 0, len(rpcReqData.Args))
		for _, v := range rpcReqData.Args {
			args = append(args, reflect.ValueOf(v))
		}
		// 调用
		callReslut := fc.Call(args)
		// 处理调用返回的数据结果
		rargs := make([]interface{}, 0, len(callReslut))
		for _, rv := range callReslut {
			rargs = append(rargs, rv.Interface())
		}
		// 构建返回的rpc数据
		rpcRspData := RpcData{Name: rpcReqData.Name, Args: rargs}
		// 返回数据进行编码
		rspData, err := encode(rpcRspData)
		if err != nil {
			conn.Close()
			continue
		}
		err = session.Write(rspData)
		if err != nil {
			conn.Close()
			continue
		}
	}
	return
}

session.go

package rpc

import (
	"encoding/binary"
	"fmt"
	"io"
	"net"
)

// 处理连接会话

// 会话对象结构体
type Session struct {
	conn net.Conn
}

// 传输数据存储方式
// 字节数组, 添加4个字节的头,用来存储数据的长度

// 会话构造函数
func NewSession(conn net.Conn) *Session {
	return &Session{conn: conn}
}

// 从连接中读取数据
func (s *Session) Read() (data []byte, err error) {
	// 读取数据header数据
	header := make([]byte, 4)
	_, err = s.conn.Read(header)
	if err != nil {
		fmt.Printf("read conn header data failed, err: %v\n", err)
		return
	}
	// 读取body数据
	hlen := binary.BigEndian.Uint32(header)
	data = make([]byte, hlen)
	_, err = io.ReadFull(s.conn, data)
	if err != nil {
		fmt.Printf("read conn body data failed, err: %v\n", err)
		return
	}
	return
}

// 向连接中写入数据
func (s *Session) Write(data []byte) (err error) {
	// 创建数据字节切片
	buf := make([]byte, 4+len(data))
	// 向header写入数据长度
	binary.BigEndian.PutUint32(buf[:4], uint32(len(data)))
	// 写入body内容
	copy(buf[4:], data)
	// 写入连接数据
	_, err = s.conn.Write(buf)
	if err != nil {
		fmt.Printf("write conn data failed, err: %v\n", err)
		return
	}
	return
}

coder_test.go

package rpc

import (
	"testing"
)

func TestCoder(t *testing.T) {
	rd := RpcData{
		Name: "login",
		Args: []interface{}{"zhangsan", "zs123"},
	}

	eData, err := encode(rd)
	if err != nil {
		t.Error(err)
		return
	}
	t.Logf("gob 编码后数据长度: %d\n", len(eData))

	dData, err := decode(eData)
	if err != nil {
		t.Error(err)
		return
	}
	t.Logf("%#v\n", dData)
}

session_test.go

package rpc

import (
	"net"
	"sync"
	"testing"
)

func TestSession(t *testing.T) {
	addr := ":8080"
	test_data := "my is test data"
	var wg sync.WaitGroup
	wg.Add(2)
	// 写数据
	go func() {
		defer wg.Done()
		listener, err := net.Listen("tcp", addr)
		if err != nil {
			t.Fatal(err)
			return
		}
		conn, _ := listener.Accept()
		s := NewSession(conn)
		data, err := s.Read()
		if err != nil {
			t.Error(err)
			return
		}
		t.Log(string(data))
	}()

	// 读数据
	go func() {
		defer wg.Done()
		conn, err := net.Dial("tcp", addr)
		if err != nil {
			t.Fatal(err)
			return
		}
		s := NewSession(conn)
		err = s.Write([]byte(test_data))
		if err != nil {
			return
		}
		t.Log("写入数据成功")
		return
	}()

	wg.Wait()
}

rpc_test.go

package rpc

import (
	"encoding/gob"
	"fmt"
	"net"
	"testing"
)

// rpc 客户端和服务端测试

// 定义一个服务端结构体
// 定义一个方法
// 通过调用rpc方法查询用户的信息

type User struct {
	Name string
	Age  int
}

// 定义查询用户的方法
// 通过用户id查询用户数据
func queryUser(id int) (User, error) {
	// 造一些查询user的假数据
	users := make(map[int]User)
	users[0] = User{"user01", 22}
	users[1] = User{"user02", 23}
	users[2] = User{"user03", 24}
	if u, ok := users[id]; ok {
		return u, nil
	}
	return User{}, fmt.Errorf("%d id not found", id)

}

func TestRpc(t *testing.T) {
	// 给gob注册类型
	gob.Register(User{})

	addr := ":8080"

	// 创建服务端
	server := NewServer(addr)
	// 注册服务
	server.Register("queryUser", queryUser)
	// 启动服务端
	go server.Run()

	// 创建客户端连接
	conn, err := net.Dial("tcp", addr)
	if err != nil {
		return
	}
	// 创客户端
	client := NewClient(conn)
	// 定义函数调用原型
	var query func(int) (User, error)
	// 客户端调用rpc
	client.RpcCall("queryUser", &query)
	// 得到返回结果
	user, err := query(1)
	if err != nil {
		t.Error(err)
		return
	}
	fmt.Printf("%#v\n", user)
}
posted @ 2020-04-05 17:40  ZhiChao&  阅读(1596)  评论(0编辑  收藏  举报