Go源码解读-rpc包

服务端代码

一个简单的rpc server示例如下:

package main

import (
	"log"
	"net"
	"net/http"
	"net/rpc"

	"github.com/monoxy/rpc/common"
)

func main() {
	server := rpc.NewServer()
	server.Register(new(common.Embed))
	lis, err := net.Listen("tcp", ":1234")
	if err != nil {
		log.Fatalf("list err: %v", err)
	}
	http.Serve(lis, server)
}

我们通过rpc.NewServer拿到Server结构体,表示rpc server,其定义为:

type Server struct {
	serviceMap sync.Map   // map[string]*service
	reqLock    sync.Mutex // 对freeReq提供锁保护
	freeReq    *Request
	respLock   sync.Mutex // 对freeResp提供锁保护
	freeResp   *Response
}

提供服务的实体抽象为service(一般与一个接收器对应,接收器可能是结构体或指针):

type service struct {
	name   string                 // 服务名称,没有指定便为结构体的导出名称
	rcvr   reflect.Value          // receiver of methods for the service
	typ    reflect.Type           // type of the receiver
	method map[string]*methodType // registered methods
}

serviceMap是一个同步sync.Map,以serviceName为key,Serice为value。通过server.Register将相应的服务,也就是service注册到当前的server中去。Register的代码和详细注释如下:

func (server *Server) register(rcvr interface{}, name string, useName bool) error {
	s := new(service)
	s.typ = reflect.TypeOf(rcvr)
	s.rcvr = reflect.ValueOf(rcvr)
	sname := reflect.Indirect(s.rcvr).Type().Name()
	if useName { // 如果指定service name,则使用name参数
		sname = name
	}
	if sname == "" {
		s := "rpc.Register: no service name for type " + s.typ.String()
		log.Print(s)
		return errors.New(s)
	}
	if !token.IsExported(sname) && !useName { //  service name必须是可导出的,即首字母大写
		s := "rpc.Register: type " + sname + " is not exported"
		log.Print(s)
		return errors.New(s)
	}
	s.name = sname

	// 调用suitableMethods,将接收器中的导出方法注册到s.method变量中
	s.method = suitableMethods(s.typ, true)

	if len(s.method) == 0 {
		str := ""

		// 这里兼容reciever是指针,但注册的对象是结构体的情况,通过reflect.PtrTo去相应的结构体指针中查找方法并注册
		method := suitableMethods(reflect.PtrTo(s.typ), false)
		if len(method) != 0 {
			str = "rpc.Register: type " + sname + " has no exported methods of suitable type (hint: pass a pointer to value of that type)"
		} else {
			str = "rpc.Register: type " + sname + " has no exported methods of suitable type"
		}
		log.Print(str)
		return errors.New(str)
	}
	
    // 当service的method注册完成后,将service整体注册到server的serviceMap变量中,注意这里使用了sync.Map
	if _, dup := server.serviceMap.LoadOrStore(sname, s); dup {
		return errors.New("rpc: service already defined: " + sname)
	}
	return nil
}

再来看suitabaleMethod,这个函数的作用是遍历接收器中的方法,这里就不列举源码,简单来说,对每个方法判断以下条件是否满足:

  • 方法名可导出,即首字母大写
  • 方法有三个入参,接收器本身、请求参数(*args)、响应参数(*reply)
  • 请求参数和响应参数均可导出
  • 响应参数必须为指针类型
  • 方法只有一个出参error

如果满足,则构造methodType对象,将其注册到service.method

注册后方法后,需要提供服务了,rpc包实现了tcp和http两种方式。开关的示例代码即是采用了http server。显然我们的rpc server需要提供ServerHTTP,以实现http.Handler接口。

func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	// 客户端需要发起CONNECT请求,才能进行后续的rpc调用
    if req.Method != "CONNECT" {
		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
		w.WriteHeader(http.StatusMethodNotAllowed)
		io.WriteString(w, "405 must CONNECT\n")
		return
	}
    // 注意这里通过Hijack劫持了http协议,拿到conn连接
	conn, _, err := w.(http.Hijacker).Hijack()
	if err != nil {
		log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
		return
	}
	io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
	server.ServeConn(conn)
}

server.ServeConn里表示对每一条tcp连接的处理过程,首先构建gob编解码器,再将流程转到server.ServeCodec中处理

func (server *Server) ServeConn(conn io.ReadWriteCloser) {
	buf := bufio.NewWriter(conn)
	srv := &gobServerCodec{
		rwc:    conn,
		dec:    gob.NewDecoder(conn),
		enc:    gob.NewEncoder(buf),
		encBuf: buf,
	}
	server.ServeCodec(srv)
}

func (server *Server) ServeCodec(codec ServerCodec) {
	sending := new(sync.Mutex)
	wg := new(sync.WaitGroup)
	for {
        // 从codec中读取每一次rpc请求的service,方法,请求参数和响应参数等信息
		service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
		if err != nil {
			if debugLog && err != io.EOF {
				log.Println("rpc:", err)
			}
			if !keepReading {
				break
			}
			// send a response if we actually managed to read a header.
			if req != nil {
				server.sendResponse(sending, req, invalidRequest, codec, err.Error())
				server.freeRequest(req)
			}
			continue
		}
		wg.Add(1)
        // 将相关的参数传入到call中执行,函数内部利用反射调用接收体的对应方法,将处理结果写回到response中
		go service.call(server, sending, wg, mtype, req, argv, replyv, codec)
	}
	// 在关闭codec之前,等全部的service.call协程退出
	wg.Wait()
	codec.Close()
}

server.readRequest从codec中解码每一次的rpc请求,一定是按先解码请求头,再解码请求体的顺序,分别对应codec.ReadRequestHeadercodec.ReadRequestBody两个函数。

func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, keepReading bool, err error) {
	service, mtype, req, keepReading, err = server.readRequestHeader(codec)
	if err != nil {
		if !keepReading {
			return
		}
		// 读取request header出错,则丢弃剩余的body数据
		codec.ReadRequestBody(nil)
		return
	}

	// 解码请求参数
	argIsValue := false // if true, need to indirect before calling.
	if mtype.ArgType.Kind() == reflect.Ptr {
		argv = reflect.New(mtype.ArgType.Elem()) // 通过relfect.New构造对应类型的指针
	} else {
		argv = reflect.New(mtype.ArgType) // 通过relfect.New构造对应类型的指针
		argIsValue = true
	}
    // 此时argv的内部value为指针,通过argv。Interface()拿到指针变量
	if err = codec.ReadRequestBody(argv.Interface()); err != nil {
		return
	}
    // 这里,如果请求参数原本是值类型,也要通过argv。Elem()将指针还原为值类型
	if argIsValue {
		argv = argv.Elem()
	}
	
    // 返回参数一定是指针类型,所以直接通过reflect.New构造
	replyv = reflect.New(mtype.ReplyType.Elem())
	
    // 针对map和slice两类参数,提前make
	switch mtype.ReplyType.Elem().Kind() {
	case reflect.Map:
		replyv.Elem().Set(reflect.MakeMap(mtype.ReplyType.Elem()))
	case reflect.Slice:
		replyv.Elem().Set(reflect.MakeSlice(mtype.ReplyType.Elem(), 0, 0))
	}
	return
}

func (server *Server) readRequestHeader(codec ServerCodec) (svc *service, mtype *methodType, req *Request, keepReading bool, err error) {
	// 从server中取一个request指针
	req = server.getRequest()
	err = codec.ReadRequestHeader(req) // 将gob字节流解码到req上
	if err != nil {
		req = nil
		if err == io.EOF || err == io.ErrUnexpectedEOF {
			return
		}
		err = errors.New("rpc: server cannot decode request: " + err.Error())
		return
	}

	// We read the header successfully. If we see an error now,
	// we can still recover and move on to the next request.
	keepReading = true
	
    // 请求头中的ServiceName一定是"Service.Method"这种格式,注意service中可能也包含.,我们需要取最后一个.作为分隔符。
	dot := strings.LastIndex(req.ServiceMethod, ".")
	if dot < 0 {
		err = errors.New("rpc: service/method request ill-formed: " + req.ServiceMethod)
		return
	}
	serviceName := req.ServiceMethod[:dot]
	methodName := req.ServiceMethod[dot+1:]

	// 凭借serviceName去serviceMap中查找对应的service
	svci, ok := server.serviceMap.Load(serviceName)
	if !ok {
		err = errors.New("rpc: can't find service " + req.ServiceMethod)
		return
	}
	svc = svci.(*service)
	mtype = svc.method[methodName]
	if mtype == nil {
		err = errors.New("rpc: can't find method " + req.ServiceMethod)
	}
	return
}

客户端代码

简单的客户端示例如下:

package main

import (
	"fmt"
	"log"
	"net/rpc"

	"github.com/monoxy/rpc/common"
)

func main() {
	client, err := rpc.DialHTTP("tcp", "127.0.0.1:1234")
	if err != nil {
		log.Fatal(err)
	}
	args := common.Args{A: 1, B: 2}
	reply := new(common.Reply)
	err = client.Call("Embed.Add", args, reply)
	fmt.Println(reply)
	fmt.Println(err)
}

针对获取rpc.Client,rpc包提供了两类方法,DialHTTP和Dial,前者以http协议的形式进行初始化,后者为tcp协议。

DialHTTP的使用前提是服务端需要提供http服务。DialHTTP中,以默认rpc路由为参数调用DialHTTPPath。可以看到DialHTTPPath内部首先对服务端发送了一次CONNET请求,如果成功,则将conn构造rpc.Client对象并返回。

我们可以发现,DialHTTP相比于Dial,实际上就多发起了一次CONNECT请求。

func DialHTTP(network, address string) (*Client, error) {
	return DialHTTPPath(network, address, DefaultRPCPath)
}

// DialHTTPPath connects to an HTTP RPC server
// at the specified network address and path.
func DialHTTPPath(network, address, path string) (*Client, error) {
	var err error
	conn, err := net.Dial(network, address)
	if err != nil {
		return nil, err
	}
	io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n\n")

	// 需要成功的HTTP响应,才能切换到rpc协议
	resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
	if err == nil && resp.Status == connected {
		return NewClient(conn), nil
	}
	if err == nil {
		err = errors.New("unexpected HTTP response: " + resp.Status)
	}
	conn.Close()
	return nil, &net.OpError{
		Op:   "dial-http",
		Net:  network + " " + address,
		Addr: nil,
		Err:  err,
	}
}

NewClient的参数只有一个conn,先基于conn构造客户端gob解码器,再将其作为参数传入NewClientWithCodec。构造完成后,开启一个读协程,读取服务端返回的数据。

func NewClient(conn io.ReadWriteCloser) *Client {
	encBuf := bufio.NewWriter(conn)
	client := &gobClientCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(encBuf), encBuf}
	return NewClientWithCodec(client)
}

// NewClientWithCodec is like NewClient but uses the specified
// codec to encode requests and decode responses.
func NewClientWithCodec(codec ClientCodec) *Client {
	client := &Client{
		codec:   codec,
		pending: make(map[uint64]*Call),
	}
    // 注意这里开启了读协程
	go client.input()
	return client
}

// 读协程
func (client *Client) input() {
	var err error
	var response Response
	for err == nil {
		response = Response{}
		err = client.codec.ReadResponseHeader(&response)
		if err != nil {
			break
		}
		seq := response.Seq
		client.mutex.Lock()
		call := client.pending[seq]
		delete(client.pending, seq)
		client.mutex.Unlock()

		switch {
		case call == nil:
            // call为空,表明在执行client.Call或者client.Go发送请求时已经出现失败,将对应的call删除,那么接下的处理就是丢弃对应的body,只获取错误信息
			err = client.codec.ReadResponseBody(nil)
			if err != nil {
				err = errors.New("reading error body: " + err.Error())
			}
		case response.Error != "":
            // 服务端返回了错误,将错误信息返回给call,并终止循环。那么所有后续的处于pending的call将收到同样的错误信息
			call.Error = ServerError(response.Error)
			err = client.codec.ReadResponseBody(nil)
			if err != nil {
				err = errors.New("reading error body: " + err.Error())
			}
			call.done()
		default:
			err = client.codec.ReadResponseBody(call.Reply)
			if err != nil {
				call.Error = errors.New("reading body " + err.Error())
			}
			call.done()
		}
	}
	// 终止pending状态的call
	client.reqMutex.Lock()
	client.mutex.Lock()
	client.shutdown = true
	closing := client.closing
	if err == io.EOF {
		if closing {
			err = ErrShutdown
		} else {
			err = io.ErrUnexpectedEOF
		}
	}
	for _, call := range client.pending {
		call.Error = err
		call.done()
	}
	client.mutex.Unlock()
	client.reqMutex.Unlock()
	if debugLog && err != io.EOF && !closing {
		log.Println("rpc: client protocol error:", err)
	}
}
posted @ 2021-05-16 22:54  g2012  阅读(288)  评论(0编辑  收藏  举报