手写RPC框架实现
手写RPC框架实现
参考视频:自己动手实现RPC框架
1、什么是RPC?
- RPC, 英文全名remote procedure call 即远程过程掉调用
- 就是说一个应用部署在A服务器上,想要调用B服务器上应用提供的方法
- 由于不在一个内存空间,不能直接调用,需要通过网络来表达调用的语义和传达调用的数据
- RPC就是要像调用本地的函数一样去调用远程函数
2、RPC架构
服务提供者(RPC Server): 运行在服务器端,提供服务接口定义与服务实现类。
注册中心(Registry): 运行在服务器端,负责将本地服务发布成远程服务,管理远程服务,提供给服务消费者使用。
服务消费者(RPC Client): 运行在客户端,通过远程代理对象调用远程服务。
3、动手实现RPC框架
实现RPC框架主要分成6个模块实现,分别是序列化模块、server模块、client模块、网络传输模块、协议模块、通用模块。
3.1 通用模块实现
通用模块主要提供各个模块公用的一个反射工具类,用来根据Class对象创建类的实例对象、获得类的公共方法以及调用指定对象的指定方法。
ReflectionUtils.java
import java.lang.invoke.MethodHandle;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* @author Wenbo
* @version 1.0
* @program
* @description 反射工具类,用于创建实例和调用方法实现。
* @date 2022/5/15 15:52
*/
public class ReflectionUtils {
/**
* 根据clazz创建对象
* @param clazz 带创建对象的类
* @param <T> 对象类型
* @return 创建好对象
*/
public static <T> T newInstance(Class<T> clazz){
try{
return clazz.newInstance();
}catch (Exception e){
throw new IllegalStateException();
}
}
/**
* 获取某个类clazz的公有方法
* @param clazz
* @return 当前类声明的公有方法
*/
public static Method[] getPublicMethods(Class clazz){
Method[] methods = clazz.getDeclaredMethods();
List<Method> pmethods = new ArrayList<>();
for(Method m : methods){
if(Modifier.isPublic(m.getModifiers())){
pmethods.add(m);
}
}
return pmethods.toArray(new Method[0]);
}
/**
*
* @param obj 被调用方法的对象
* @param method 被调用的方法
* @param args 方法的参数
* @return 返回结果
*/
public static Object invoke(Object obj, Method method, Object... args) {
try{
return method.invoke(obj, args);
}catch (Exception e){
throw new IllegalStateException();
}
}
}
3.2 序列化模块实现
序列化模块主要就是当客户端调用服务端的接口方法时,需要传入调用的具体方法和参数,因此需要通过底层的网络协议(如HTTP、TCP),传输到服务端,所以需要将数据序列化后才能在网络中进行传输,服务端接收后需要对传输过来的数据进行反序列化拿到具体的值。序列化的方式有很多,如java原生序列化、json序列化、Protobuff序列化等。
JSONDecoder.java
import com.alibaba.fastjson.JSON;
/**
* @author Wenbo
* @version 1.0
* @program
* @description fastjson实现反序列化
* @date 2022/5/15 16:11
*/
public class JSONDecoder implements Decoder{
@Override
public <T> T decode(byte[] bytes, Class<T> clazz) {
return JSON.parseObject(bytes, clazz);
}
}
JSONEecoder.java
import com.alibaba.fastjson.JSON;
/**
* @author Wenbo
* @version 1.0
* @program
* @description fastjson实现序列化
* @date 2022/5/15 16:11
*/
public class JSONEncoder implements Encoder{
@Override
public byte[] encode(Object obj) {
return JSON.toJSONBytes(obj);
}
}
3.3 协议模块
协议模块主要就是规定了客户端传输的数据内容、服务端响应的内容、网络端点的信息,客户端传输的数据内容主要包括请求的信息(一个实体类,其中包括服务名、调用的具体函数名、参数类型、返回类型、版本号)、调用参数值。服务端响应的内容主要包括相应编码(成功或失败)、服务端信息以及调用的服务端函数返回的数据。网络端点的信息就包括主机名和端口号。
Peer.java
import lombok.AllArgsConstructor;
import lombok.Data;
/**
* @author Wenbo
* @version 1.0
* @program
* @description 表示网络传输的一个端点
* @date 2022/5/15 15:09
*/
@Data
@AllArgsConstructor
public class Peer {
private String host;
private int port;
}
Request.java
import lombok.Data;
/**
* @author Wenbo
* @version 1.0
* @program
* @description 请求服务
* @date 2022/5/15 15:15
*/
@Data
public class Request {
private ServiceDescriptor serviceDescriptor; // 请求的信息,其中包括服务名(接口名)、调用的具体函数名、参数类型、返回类型、版本号
private Object[] parameters;
}
Response.java
import lombok.Data;
/**
* @author Wenbo
* @version 1.0
* @program
* @description 表示RPC的返回
* @date 2022/5/15 15:15
*/
@Data
public class Response {
/**
* 服务返回编码,0 成功,非0 失败
*/
private int code = 0;
/**
* 具体的错误信息
*/
private String message = "ok";
/**
* 返回的数据
*/
private Object data;
}
ServiceDescriptor.java
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.lang.reflect.Method;
import java.util.Arrays;
/**
* @author Wenbo
* @version 1.0
* @program
* @description 服务的描述,在服务端注册服务和具体实现类的时候需要用到map,所以需要重写hashcode 和 equals。 一个服务接口可以有多个不同版本的实现类。所以加入了version字段。
* @date 2022/5/15 15:15
*/
@Data
@AllArgsConstructor
@NoArgsConstructor
public class ServiceDescriptor {
public String clazz; // 服务接口
private String method; // 方法名
private String returnType; // 返回类型
private String[] parameterTypes;// 参数类型
private String version;// 服务版本号
// 不带版本号的接口描述
public static ServiceDescriptor from(Class clazz, Method method){
ServiceDescriptor sdp = new ServiceDescriptor();
sdp.setClazz(clazz.getName());
sdp.setMethod(method.getName());
sdp.setReturnType(method.getReturnType().getName());
Class[] parameterClasses = method.getParameterTypes();
String[] parameterTypes = new String[parameterClasses.length];
for (int i = 0; i < parameterClasses.length; i++) {
parameterTypes[i] = parameterClasses[i].getName();
}
sdp.setParameterTypes(parameterTypes);
return sdp;
}
// 带版本号的接口描述
public static ServiceDescriptor from(Class clazz,String version, Method method){
ServiceDescriptor sdp = new ServiceDescriptor();
sdp.setClazz(clazz.getName());
sdp.setMethod(method.getName());
sdp.setReturnType(method.getReturnType().getName());
sdp.setVersion(version);
Class[] parameterClasses = method.getParameterTypes();
String[] parameterTypes = new String[parameterClasses.length];
for (int i = 0; i < parameterClasses.length; i++) {
parameterTypes[i] = parameterClasses[i].getName();
}
sdp.setParameterTypes(parameterTypes);
return sdp;
}
@Override
public boolean equals(Object o){
if (this == o) return true;
if(o == null || this.getClass() != o.getClass()) return false;
ServiceDescriptor that = (ServiceDescriptor) o;
return this.toString().equals(that.toString());
}
@Override
public int hashCode(){
return toString().hashCode();
}
public String toString(){
return "clazz=" + clazz + ", method" +method + ",returnType = " + returnType + ",parametersType=" + Arrays.toString(parameterTypes);
}
}
3.4 网络传输模块
网络传输模块主要就是指定主机之间采用什么网络协议进行通信,这里采用的是HTTP传输协议,客户端向服务端发起连接请求,并传输数据,服务端利用jetty内嵌服务器接收请求,并对相应的请求处理,调用具体的服务,将服务运行结果返回给消费端。
HTTPTransportClient.java
import com.rpc.Peer;
import org.apache.commons.io.IOUtils;
import java.io.IOException;
import java.io.InputStream;
import java.net.HttpURLConnection;
import java.net.MalformedURLException;
import java.net.URL;
/**
* @author Wenbo
* @version 1.0
* @program
* @description 基于HTTP实现的客户端
* @date 2022/5/15 19:38
*/
public class HTTPTransportClient implements TransportClient{
private String url;
@Override
public void connect(Peer peer) {
this.url = "http://" + peer.getHost() + ":" + peer.getPort();
}
@Override
public InputStream write(InputStream data) throws IOException {
HttpURLConnection connection = (HttpURLConnection) new URL(url).openConnection();
connection.setDoOutput(true);
connection.setDoInput(true);
connection.setUseCaches(false);
connection.connect();
IOUtils.copy(data, connection.getOutputStream());
int resultCode = connection.getResponseCode();
if(resultCode == HttpURLConnection.HTTP_OK){
return connection.getInputStream();
}else{
return connection.getErrorStream();
}
}
@Override
public void close() {
}
}
HTTPTransportServer.java
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.servlet.ServletHolder;
import lombok.extern.slf4j.Slf4j;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.concurrent.ConcurrentHashMap;
/**
* @author Wenbo
* @version 1.0
* @program
* @description 基于http短连接实现的服务器,就是dubbo中的HTTPserver
* @date 2022/5/15 19:18
*/
@Slf4j
public class HTTPTransportServer implements TransportServer{
private RequestHandler handler; // 请求处理器
private Server server;
@Override
public void init(int port, RequestHandler handler) {
this.handler = handler;
this.server = new Server(port);
// servlet接收请求
ServletContextHandler ctx = new ServletContextHandler();
server.setHandler(ctx);
ServletHolder holder = new ServletHolder(new RequestServlet());
ctx.addServlet(holder, "/*"); // 对所有的http请求处理
}
@Override
public void start() {
try {
server.start();
server.join();
}catch (Exception e){
log.error(e.getMessage(),e);
}
}
@Override
public void stop() {
try {
server.stop();
} catch (Exception e) {
log.error(e.getMessage(),e);
}
}
class RequestServlet extends HttpServlet{
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws IOException {
log.info("客户端连接");
InputStream in = req.getInputStream();
OutputStream out = resp.getOutputStream();
if(handler != null){
handler.onRequest(in, out);
}
out.flush();
}
}
}
RequestHandler.java
import java.io.InputStream;
import java.io.OutputStream;
/**
* @author Wenbo
* @version 1.0
* @program
* @description 处理网络请求的handler
* @date 2022/5/15 16:33
*/
public interface RequestHandler {
void onRequest(InputStream recive, OutputStream toResp);
}
3.5 消费端模块
消费端模块采用动态代理的方式,生成服务的代理对象,这样可以使得消费端直接调用具体的方法,而不用去关心底层的细节,在动态代理的Invoker里面去实现真实的服务具体调用。同时还实现了负载均衡,简单的利用随机选择主机来获取连接。
RandomTransportSelector.java
import com.rpc.Peer;
import com.rpc.ReflectionUtils;
import com.rpc.transport.TransportClient;
import lombok.extern.slf4j.Slf4j;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
/**
* @author Wenbo
* @version 1.0
* @program
* @description 负载均衡选择器, 随机选择一个连接进行传输
* @date 2022/5/15 20:10
*/
@Slf4j
public class RandomTransportSelector implements TransportSelector{
private List<TransportClient> clients;
public RandomTransportSelector(){
this.clients = new ArrayList<>();
}
@Override
public synchronized void init(List<Peer> peers, int count, Class<? extends TransportClient> clazz) {
count = Math.max(count, 1);
for(Peer peer: peers){
for(int i = 0;i < count;i++){
TransportClient client = ReflectionUtils.newInstance(clazz);
client.connect(peer);
clients.add(client);
}
log.info("connect server:{}",peer);
}
}
@Override
public synchronized TransportClient select() {
int i = new Random().nextInt(clients.size());
return clients.remove(i);
}
@Override
public void release(TransportClient client) {
clients.add(client);
}
@Override
public void close() {
for(TransportClient client : clients){
client.close();
}
clients.clear();
}
}
RemoteInvoker.java
import com.rpc.*;
import com.rpc.transport.TransportClient;
import org.apache.commons.io.IOUtils;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
/**
* @author Wenbo
* @version 1.0
* @program
* @description 动态代理的InvocationHandler 的实现类。
* @date 2022/5/15 20:32
*/
public class RemoteInvoker implements InvocationHandler {
private Class clazz;
private Encoder encoder;
private Decoder decoder;
private TransportSelector selector;
public RemoteInvoker(Class clazz, Encoder encoder, Decoder decoder, TransportSelector selector){
this.clazz = clazz;
this.encoder = encoder;
this.decoder = decoder;
this.selector = selector;
}
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
Request request = new Request();
request.setServiceDescriptor(ServiceDescriptor.from(clazz, method));
request.setParameters(args);
Response resp = invokeRemote(request);
if(resp == null || resp.getCode() != 0){
throw new IllegalStateException("fail to invoke remote:" +resp);
}
return resp.getData();
}
private Response invokeRemote(Request request) {
Response response = null;
TransportClient client = null;
try{
client = selector.select();
byte[] encode = encoder.encode(request);
InputStream write = client.write(new ByteArrayInputStream(encode));
byte[] bytes = IOUtils.readFully(write, write.available());
response = decoder.decode(bytes, Response.class);
}catch (IOException e){
response = new Response();
response.setCode(1);
response.setMessage("Rpc error"+e.getClass()+":"+e.getMessage());
}finally {
if(client!=null){
selector.release(client);
}
}
return response;
}
}
RpcClient.java
import com.rpc.Decoder;
import com.rpc.Encoder;
import com.rpc.ReflectionUtils;
import java.lang.reflect.Proxy;
/**
* @author Wenbo
* @version 1.0
* @program
* @description 动态代理, 生成服务的代理对象
* @date 2022/5/15 20:20
*/
public class RpcClient {
private RpcClientConfig config;
private Encoder encoder;
private Decoder decoder;
private TransportSelector selector;
public RpcClient(){
this(new RpcClientConfig());
}
public RpcClient(RpcClientConfig config){
this.config = config;
this.encoder = ReflectionUtils.newInstance(this.config.getEncoderClass());
this.decoder = ReflectionUtils.newInstance(this.config.getDecoderClass());
this.selector = ReflectionUtils.newInstance(this.config.getTransportSelector());
this.selector.init(this.config.getServers(), this.config.getConnectCount(), this.config.getTransportClass());
}
public <T> T getProxy(Class<T> clazz){
return (T) Proxy.newProxyInstance(getClass().getClassLoader(), new Class[]{clazz}, new RemoteInvoker(clazz,encoder,decoder,selector));
}
}
RpcClientConfig.java
import com.rpc.*;
import com.rpc.transport.HTTPTransportClient;
import com.rpc.transport.TransportClient;
import lombok.AllArgsConstructor;
import lombok.Data;
import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.List;
/**
* @author Wenbo
* @version 1.0
* @program
* @description
* @date 2022/5/15 20:06
*/
@Data
public class RpcClientConfig {
public Class<? extends TransportClient> transportClass = HTTPTransportClient.class;
public Class<? extends Encoder> encoderClass = JSONEncoder.class;
public Class<? extends Decoder> decoderClass = JSONDecoder.class;
public Class<? extends TransportSelector> transportSelector = RandomTransportSelector.class;
private int connectCount = 1;
private List<Peer> servers = Arrays.asList(new Peer("127.0.0.1", 3000));
}
3.6 服务端模块
服务端模块将所有的服务注册到到一个map里面,键为服务的描述,ServiceDescriptor对象,值为实现类ServiceInstance对象,这样当获取到消费端传来的ServiceDescriptor对象的时候可以找到对应的具体实现类。因为ServiceDescriptor类作为Map的键,所以需要重写ServiceDescriptor类的hashcode() 和 equals()方法。当服务端找到具体的实现类后,利用动态代理反射的机制去调用具体的方法,并将结果返回给消费端。
RpcServer.java
import com.rpc.*;
import com.rpc.transport.RequestHandler;
import com.rpc.transport.TransportServer;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
/**
* @author Wenbo
* @version 1.0
* @program
* @description
* @date 2022/5/15 21:17
*/
@Slf4j
public class RpcServer {
private RpcServiceConfig config;
private TransportServer net;
private Encoder encoder;
private Decoder decoder;
private ServiceManager serviceManager;
private ServiceInvoker serviceInvoker;
private RequestHandler handler = new RequestHandler() {
Response resp = new Response();
@Override
public void onRequest(InputStream recive, OutputStream toResp) {// 这里其实就是HttpServerHandler
try{
byte[] inBytes = IOUtils.readFully(recive, recive.available());
Request request = decoder.decode(inBytes, Request.class);
log.info("get request: {}",request);
ServiceInstance sis = serviceManager.lookup(request);
Object res = serviceInvoker.invoke(sis, request);
resp.setData(res);
}catch (IOException e){
log.warn(e.getMessage(),e);
resp.setCode(1);
resp.setMessage("RpcServer get error:"+e.getClass().getName()+":"+e.getMessage());
}finally {
byte[] outBytes = encoder.encode(resp);
try{
toResp.write(outBytes);
log.info("response client");
}catch (IOException e) {
e.printStackTrace();
}
}
}
};
public RpcServer(RpcServiceConfig config){
this.config = config;
// net
this.net = ReflectionUtils.newInstance(config.getTransportClass());
this.net.init(config.getPort(), this.handler);
// codec
this.encoder = ReflectionUtils.newInstance(config.getEncoderClass());
this.decoder = ReflectionUtils.newInstance(config.getDecoderClass());
// Service
this.serviceManager = new ServiceManager();
this.serviceInvoker = new ServiceInvoker();
}
public void start(){
this.net.start();
}
public void stop(){
this.net.stop();
}
public <T> void register(Class<T> interfaceClass, T bean){
serviceManager.register(interfaceClass, bean);
}
public RpcServer(){
this(new RpcServiceConfig());
}
}
RpcServiceConfig.java
import com.rpc.Decoder;
import com.rpc.Encoder;
import com.rpc.JSONDecoder;
import com.rpc.JSONEncoder;
import com.rpc.transport.HTTPTransportServer;
import com.rpc.transport.TransportServer;
import lombok.Data;
/**
* @author Wenbo
* @version 1.0
* @program
* @description 服务配置类
* @date 2022/5/15 21:13
*/
@Data
public class RpcServiceConfig {
// 网络协议
private Class<? extends TransportServer> transportClass = HTTPTransportServer.class;
// 序列化
private Class<? extends Encoder> encoderClass = JSONEncoder.class;
// 反序列化
private Class<? extends Decoder> decoderClass = JSONDecoder.class;
private int port = 3000;
}
ServiceInstance.java
import lombok.AllArgsConstructor;
import lombok.Data;
import java.lang.reflect.Method;
/**
* @author Wenbo
* @version 1.0
* @program
* @description 表示一个具体的服务
* @date 2022/5/15 20:52
*/
@Data
@AllArgsConstructor
public class ServiceInstance {
private Object target;
private Method method;
}
ServiceInvoker.java
import com.rpc.ReflectionUtils;
import com.rpc.Request;
/**
* @author Wenbo
* @version 1.0
* @program
* @description 调用具体的服务中的方法
* @date 2022/5/15 21:19
*/
public class ServiceInvoker {
public Object invoke(ServiceInstance serviceInstance, Request request){
return ReflectionUtils.invoke(serviceInstance.getTarget(),
serviceInstance.getMethod(),
request.getParameters());
}
}
ServiceManager.java
import com.rpc.ReflectionUtils;
import com.rpc.Request;
import com.rpc.ServiceDescriptor;
import lombok.extern.slf4j.Slf4j;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
* @author Wenbo
* @version 1.0
* @program
* @description 服务管理类, 管理rpc暴露的服务
* @date 2022/5/15 20:51
*/
@Slf4j
public class ServiceManager { // 这个其实就是LocalRegister方法, 将服务名和实现类存储在map里面, 可以根据服务名(接口名)找对应的实现类。
private Map<ServiceDescriptor, ServiceInstance> services;
public ServiceManager(){
services = new ConcurrentHashMap<>();
}
public <T> void register(Class<T> interfaceClass, T bean){ // bean是实现类的对象
Method[] methods = ReflectionUtils.getPublicMethods(interfaceClass);
for(Method method : methods){
ServiceInstance sis = new ServiceInstance(bean, method);
ServiceDescriptor sdp = ServiceDescriptor.from(interfaceClass, method);
services.put(sdp, sis);
// System.out.println("方法个数" + services.size());
log.info("register service: {} {}",sdp.getClazz(), sdp.getMethod());
}
}
// 支持版本号的rpc
public <T> void register(Class<T> interfaceClass,String version, T bean){
Method[] methods = ReflectionUtils.getPublicMethods(interfaceClass);
for(Method method:methods){
ServiceInstance sis = new ServiceInstance(bean, method);
ServiceDescriptor sdp = ServiceDescriptor.from(interfaceClass,version,method);
services.put(sdp,sis);
log.info("register service: {} {}",sdp.getClazz(), sdp.getMethod());
}
}
public ServiceInstance lookup(Request request){
ServiceDescriptor sdp = request.getServiceDescriptor();
return services.get(sdp);
}
}
最终执行结果
服务端启动:

消费端调用:


浙公网安备 33010602011771号