Spark(三)模仿Spark实现Java发送一个类到服务端(Executor)运行

我们知道Spark可以把一个函数发送到Executor,Executor会加载这个函数,然后在JVM中运行它。本文用一个简单的例子模拟了这个过程。

工程和类的关系

我们一共有3个maven工程:

  1. remotecall-base,包含一个Task接口,Task接口有一个方法是run。
  2. remotecall-client,依赖remotecall-base工程,包含Task接口的一个实现类ClientTask。
  3. remotecall-server,依赖remotecall-base工程。

这3个工程的关系是:client会把ClientTask类的字节码发送给server,server读到字节码后会把ClientTask加载起来,然后执行它的run方法。需要注意的细节是:当client修改了ClientTask后,发送给server的字节码也会随之变化,而如果server已经加载过这个类,再次加载就会抛出重复加载类的异常。对于这个问题,我使用了Thread的Context ClassLoader(通过getContextClassLoader方法获取):每个请求都会创建一条全新的线程,每条线程都有一个全新的ClassLoader,因此每次请求都会由新的ClassLoader生成一个全新的ClientTask类。这样既可以避免新的ClientTask被忽略,也避免了重复加载类的异常。

remotecall-base

定义了client和server都知晓的Task接口:

package com.github.ralgond.remotecall.base;

/**
 * Unit of work known to both the client and the server modules.
 *
 * <p>Implementations carry the logic to execute; callers only depend on this
 * interface and invoke {@link #run()}.
 */
public interface Task {
	/** Executes the work carried by this task. */
	void run();
}

remotecall-client

依赖于remotecall-base工程,定义了client知晓而server并不知晓的ClientTask:

package com.github.ralgond.remotecall.client;

import com.github.ralgond.remotecall.base.*;

/**
 * {@link Task} implementation that lives only in the client module.
 *
 * <p>Running it prints a fixed greeting to standard output.
 */
public class ClientTask implements Task {

	/** Writes the greeting line to {@code System.out}. */
	@Override
	public void run() {
		final String greeting = "Hello World";
		System.out.println(greeting);
	}
}

Client的Main类如下:

package com.github.ralgond.remotecall.client;

import java.io.BufferedInputStream;
import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.net.Socket;
import java.net.URL;

/**
 * Client entry point: locates the compiled {@link ClientTask} class file on the
 * classpath and sends it to the server at 127.0.0.1:55551.
 *
 * <p>Wire format, in order: one byte holding the length of the class name in
 * bytes, the class name bytes, then the raw class-file bytes.
 */
public class Main {

	/**
	 * Upper bound for the encoded class name. The length travels in a single
	 * byte; staying at or below 127 keeps the value unambiguous whether the
	 * receiver interprets that byte as signed or unsigned.
	 */
	private static final int MAX_NAME_BYTES = Byte.MAX_VALUE;

	/**
	 * Reads the bytecode of {@link ClientTask} and ships it to the server.
	 *
	 * @param args command-line arguments (unused)
	 */
	public static void main(String args[]) {
		String className1 = ClientTask.class.getName();
		String className2 = className1.replace('.', '/');
		String classFile = "/" + className2 + ".class";
		URL url = Main.class.getResource(classFile);
		System.out.println(url);
		if (url == null) {
			// getResource returns null when the resource is not on the classpath.
			System.err.println("Cannot locate class file on the classpath: " + classFile);
			return;
		}

		// Encode once so the length prefix always matches the bytes actually sent.
		byte[] nameBytes = className1.getBytes(java.nio.charset.StandardCharsets.UTF_8);
		if (nameBytes.length > MAX_NAME_BYTES) {
			System.err.println("Class name too long for the one-byte length prefix: " + className1);
			return;
		}

		try {
			byte[] bytecode = readAll(url);
			System.out.println(bytecode.length);
			send(nameBytes, bytecode);
		} catch (IOException e) {
			e.printStackTrace();
		}
	}

	/**
	 * Reads every byte available behind {@code url}.
	 *
	 * <p>Opening the stream through the URL (instead of treating its path as a
	 * file name) also works for URL-encoded paths and for classes packaged in a jar.
	 *
	 * @param url location of the class file
	 * @return the complete content
	 * @throws IOException if the resource cannot be opened or read
	 */
	private static byte[] readAll(URL url) throws IOException {
		try (BufferedInputStream in = new BufferedInputStream(url.openStream());
				ByteArrayOutputStream bos = new ByteArrayOutputStream()) {
			byte[] buffer = new byte[1024];
			int len;
			while (-1 != (len = in.read(buffer, 0, buffer.length))) {
				bos.write(buffer, 0, len);
			}
			return bos.toByteArray();
		}
	}

	/**
	 * Connects to the server and writes the length-prefixed class name followed
	 * by the class-file bytes. The socket is closed even when a write fails.
	 *
	 * @param nameBytes encoded class name, at most {@link #MAX_NAME_BYTES} bytes
	 * @param bytecode  raw class-file bytes
	 * @throws IOException if connecting or writing fails
	 */
	private static void send(byte[] nameBytes, byte[] bytecode) throws IOException {
		try (Socket socket = new Socket("127.0.0.1", 55551);
				OutputStream outputStream = socket.getOutputStream()) {
			outputStream.write(nameBytes.length);
			outputStream.write(nameBytes);
			outputStream.write(bytecode);
			outputStream.flush();

			// Pause before closing the connection, preserved from the original version.
			try {
				Thread.sleep(3000);
			} catch (InterruptedException e) {
				// Restore the flag so callers can still observe the interruption.
				Thread.currentThread().interrupt();
			}
		}
	}
}

remotecall-server

server并不知晓ClientTask,它需要读取来自client的字节数组,然后将其中的类信息加载到内存中。它的Main2类代码如下:

package com.github.ralgond.remotecall.server;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.ServerSocket;
import java.net.Socket;

import com.github.ralgond.remotecall.base.Task;

/**
 * Server entry point: accepts connections on port 55551, reads a payload made of
 * a one-byte class-name length, the class name and raw class-file bytes, defines
 * the class in a fresh class loader and runs it as a {@link Task}.
 *
 * <p>SECURITY NOTE: whatever bytecode the peer sends is defined and executed in
 * this JVM without authentication or sandboxing. That is the point of the demo,
 * but it must not be exposed to untrusted networks.
 */
public class Main2 {

	/** TCP port the server listens on. */
	private static final int PORT = 55551;

	/** Class loader able to define a class directly from a byte range. */
	static class MyClassLoader extends ClassLoader {
		/**
		 * Returns the class already loaded by this loader under {@code className},
		 * or defines a new class from the given bytes.
		 *
		 * @param className binary name used for the already-loaded lookup
		 * @param b         buffer containing the class-file bytes
		 * @param offset    start of the class-file bytes in {@code b}
		 * @param length    number of class-file bytes
		 * @return the found or newly defined class
		 */
		public Class<?> findOrLoad(String className, byte[] b, int offset, int length) {
			Class<?> clz = this.findLoadedClass(className);

			if (clz == null) {
				System.out.println("defineClass");
				// A null name lets the class name be taken from the bytes themselves.
				clz = this.defineClass(null, b, offset, length);
			} else {
				System.out.println("foundClass");
			}

			return clz;
		}
	}

	/**
	 * Worker thread owning a brand-new {@link MyClassLoader}, so every request
	 * defines its own copy of the received class.
	 */
	static class MyThread extends Thread {
		byte[] bytecode;

		private final MyClassLoader classLoader = new MyClassLoader();

		/**
		 * @param bytecode full payload: length byte, class name, class-file bytes
		 */
		public MyThread(byte[] bytecode) {
			this.setContextClassLoader(classLoader);
			this.bytecode = bytecode;
		}

		/** Parses the payload, defines the class and runs it as a {@link Task}. */
		@Override
		public void run() {
			if (bytecode == null || bytecode.length == 0) {
				System.err.println("Empty payload, nothing to run");
				return;
			}

			// The length byte is unsigned on the wire; mask to avoid negative values.
			int classNameLen = bytecode[0] & 0xFF;
			int classOffset = 1 + classNameLen;
			if (classOffset >= bytecode.length) {
				System.err.println("Truncated payload: class name length " + classNameLen
						+ " leaves no class bytes in " + bytecode.length + " byte(s)");
				return;
			}
			String className = new String(bytecode, 1, classNameLen,
					java.nio.charset.StandardCharsets.UTF_8);

			try {
				Class<?> clz = classLoader.findOrLoad(className, bytecode, classOffset,
						bytecode.length - classOffset);
				if (!Task.class.isAssignableFrom(clz)) {
					System.err.println("Received class does not implement Task: " + clz.getName());
					return;
				}
				Task t = (Task) clz.getDeclaredConstructor().newInstance();
				t.run();
			} catch (ReflectiveOperationException e) {
				// No accessible no-arg constructor, or the constructor itself failed.
				e.printStackTrace();
			} catch (LinkageError e) {
				// Malformed or otherwise unloadable class bytes.
				System.err.println("Rejected class bytes for " + className);
				e.printStackTrace();
			}
		}
	}

	/**
	 * Listens on {@link #PORT} and handles connections one at a time until the
	 * listening socket fails or the calling thread is interrupted.
	 *
	 * @param args command-line arguments (unused)
	 */
	public void run(String args[]) {
		// Bind once: re-binding per request leaves gaps in which clients are refused.
		try (ServerSocket serverSocket = new ServerSocket(PORT)) {
			while (true) {
				// A failure to accept ends the loop; failures on one connection do not.
				Socket socket = serverSocket.accept();
				if (!handle(socket)) {
					return;
				}
			}
		} catch (IOException e) {
			e.printStackTrace();
		}
	}

	/**
	 * Reads one payload from {@code socket}, runs it on a dedicated thread and
	 * waits for that thread to finish. The socket is always closed.
	 *
	 * @param socket accepted client connection
	 * @return {@code false} if the calling thread was interrupted and serving
	 *         should stop, {@code true} otherwise
	 */
	private boolean handle(Socket socket) {
		try (Socket connection = socket;
				InputStream inputStream = connection.getInputStream()) {
			byte[] bytecode = readAll(inputStream);
			System.out.println(bytecode.length);

			MyThread t = new MyThread(bytecode);
			t.start();
			t.join();
			return true;
		} catch (IOException e) {
			e.printStackTrace();
			return true;
		} catch (InterruptedException e) {
			Thread.currentThread().interrupt();
			return false;
		}
	}

	/**
	 * Drains {@code inputStream} until end of stream.
	 *
	 * @param inputStream stream to read
	 * @return all bytes read
	 * @throws IOException if reading fails
	 */
	private static byte[] readAll(InputStream inputStream) throws IOException {
		ByteArrayOutputStream out = new ByteArrayOutputStream();
		byte[] chunk = new byte[2048];
		int readLength;
		while ((readLength = inputStream.read(chunk)) != -1) {
			out.write(chunk, 0, readLength);
		}
		return out.toByteArray();
	}

	/**
	 * Starts the server.
	 *
	 * @param args command-line arguments, passed through to {@link #run(String[])}
	 */
	public static void main(String args[]) {
		Main2 m = new Main2();
		m.run(args);
	}
}
posted @ 2021-06-28 18:31  ralgo  阅读(141)  评论(0)    收藏  举报