/**
 * Compiles a Java source string in memory with the system Java compiler and loads the
 * resulting class through a private class loader.
 */
public class CustomJavaCompiler {

    /** Matches a {@code package a.b.c;} declaration; group 1 is the package name. */
    private static final Pattern PACKAGE_PATTERN = Pattern.compile("\\bpackage\\s+([^\\s;]+)\\s*;");

    /**
     * Matches a {@code class Name} declaration; group 1 is the simple class name. The lookahead
     * requires the name to be followed by a body, a type-parameter list or an
     * extends/implements/permits clause, so {@code class Foo{}}, {@code class Foo<T>} and
     * {@code class Foo extends Bar} are all recognized while prose such as "class handles x" is not.
     */
    private static final Pattern CLASS_PATTERN = Pattern.compile(
            "\\bclass\\s+(\\p{javaJavaIdentifierStart}\\p{javaJavaIdentifierPart}*)"
                    + "(?=\\s*[{<]|\\s+(?:extends|implements|permits)\\b)");

    // Source code to compile.
    private final String sourceCode;
    // Fully qualified class name derived from the source.
    private final String fullClassName;
    // The system Java compiler; null when none is available (see getCompilerMessage()).
    private final JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
    // Compiled bytecode (key: class name as reported by javac, value: in-memory class file).
    private final Map<String, ByteJavaFileObject> javaFileObjectMap = new ConcurrentHashMap<>();
    // Messages (errors, warnings) produced during compilation.
    private final DiagnosticCollector<JavaFileObject> diagnosticsCollector = new DiagnosticCollector<>();
    // Duration of the last compilation, in milliseconds.
    private long compilerTime;

    /**
     * Creates a compiler for the given source; the class name is derived from the source up front.
     *
     * @param sourceCode the Java source of a single compilation unit
     */
    public CustomJavaCompiler(String sourceCode) {
        this.sourceCode = sourceCode;
        this.fullClassName = getFullClassName(sourceCode);
    }

    /**
     * Compiles the source string in memory. On failure, details are available from
     * {@link #getCompilerMessage()}.
     *
     * @return true if compilation succeeded; false if it failed or no system compiler is available
     */
    public boolean compiler() {
        if (compiler == null) {
            return false;
        }

        long startNanos = System.nanoTime();
        boolean success = false;
        StandardJavaFileManager standardFileManager =
                compiler.getStandardFileManager(diagnosticsCollector, null, null);
        // Route class-file output into javaFileObjectMap instead of the file system. Closing the
        // forwarding manager delegates to (and thus closes) the wrapped standard file manager.
        try (JavaFileManager javaFileManager = new StringJavaFileManage(standardFileManager)) {
            JavaFileObject javaFileObject = new StringJavaFileObject(fullClassName, sourceCode);
            JavaCompiler.CompilationTask task = compiler.getTask(
                    null, javaFileManager, diagnosticsCollector, null, null, Arrays.asList(javaFileObject));
            success = task.call();
        } catch (IOException ignored) {
            // Only close() throws IOException here. The bytecode is already held in memory, so a
            // failure to release file-manager resources does not change the compilation result.
        } finally {
            // Measured after task.call() so the time covers the actual compilation.
            compilerTime = (System.nanoTime() - startNanos) / 1_000_000L;
        }
        return success;
    }

    /**
     * Loads the compiled class. Each call uses a fresh class loader, so repeated calls return
     * distinct {@code Class} objects.
     *
     * @return the compiled class, or null if it could not be loaded (typically because
     *         compilation has not run or failed)
     */
    public Class<?> getCompilerClass() {
        StringClassLoader scl = new StringClassLoader();
        try {
            return scl.findClass(fullClassName);
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Returns the messages produced during compilation.
     *
     * @return compiler messages (errors and warnings), one per line; or a configuration hint when
     *         no system compiler is available
     */
    public String getCompilerMessage() {
        if (compiler == null) {
            return "JRE环境未配置(请复制JDK路径下lib目录内的tools.jar到JRE路径下lib目录里)";
        }

        StringBuilder sb = new StringBuilder();
        for (Diagnostic<? extends JavaFileObject> diagnostic : diagnosticsCollector.getDiagnostics()) {
            sb.append(diagnostic.toString()).append("\r\n");
        }
        return sb.toString();
    }

    /**
     * @return duration of the last {@link #compiler()} call in milliseconds (0 before any call)
     */
    public long getCompilerTime() {
        return compilerTime;
    }

    /**
     * Derives the fully qualified name of the first class declared in the source.
     *
     * @param sourceCode the Java source
     * @return {@code package.ClassName}; just {@code package.} when no class declaration is found;
     *         just {@code ClassName} when there is no package; or "" when neither is found
     */
    public static String getFullClassName(String sourceCode) {
        String className = "";
        Matcher packageMatcher = PACKAGE_PATTERN.matcher(sourceCode);
        if (packageMatcher.find()) {
            className = packageMatcher.group(1) + ".";
        }

        Matcher classMatcher = CLASS_PATTERN.matcher(sourceCode);
        if (classMatcher.find()) {
            className += classMatcher.group(1);
        }
        return className;
    }

    /**
     * In-memory source file object holding the code to compile.
     */
    private static class StringJavaFileObject extends SimpleJavaFileObject {
        // The source code waiting to be compiled.
        private final String contents;

        public StringJavaFileObject(String className, String contents) {
            // javac checks (via isNameCompatible on this URI) that a public class is declared in a
            // file of the same name, so the path mirrors the class name: a.b.C -> /a/b/C.java.
            super(URI.create("string:///" + className.replaceAll("\\.", "/") + Kind.SOURCE.extension), Kind.SOURCE);
            this.contents = contents;
        }

        // javac reads the source text through this method.
        @Override
        public CharSequence getCharContent(boolean ignoreEncodingErrors) throws IOException {
            return contents;
        }
    }

    /**
     * In-memory class file object that captures the bytecode javac writes.
     */
    private static class ByteJavaFileObject extends SimpleJavaFileObject {
        // Receives the compiled bytecode.
        private ByteArrayOutputStream outPutStream;

        public ByteJavaFileObject(String className, Kind kind) {
            // Use the extension of the actual kind (".class" for class output), not ".java".
            super(URI.create("string:///" + className.replaceAll("\\.", "/") + kind.extension), kind);
        }

        // Called by javac (through StringJavaFileManage) to write the compiled bytecode.
        @Override
        public OutputStream openOutputStream() {
            outPutStream = new ByteArrayOutputStream();
            return outPutStream;
        }

        // Used by the class loader; only valid after javac has written to this object.
        public byte[] getCompiledBytes() {
            return outPutStream.toByteArray();
        }
    }

    /**
     * File manager that redirects compiler output into {@code javaFileObjectMap}.
     */
    private class StringJavaFileManage extends ForwardingJavaFileManager<JavaFileManager> {
        StringJavaFileManage(JavaFileManager fileManager) {
            super(fileManager);
        }

        // Hands javac an in-memory object for every output file and records it by class name.
        @Override
        public JavaFileObject getJavaFileForOutput(Location location, String className,
                JavaFileObject.Kind kind, FileObject sibling) throws IOException {
            ByteJavaFileObject javaFileObject = new ByteJavaFileObject(className, kind);
            javaFileObjectMap.put(className, javaFileObject);
            return javaFileObject;
        }
    }

    /**
     * Class loader that defines classes from the in-memory bytecode, falling back to the system
     * class loader for anything that was not compiled here.
     */
    private class StringClassLoader extends ClassLoader {
        @Override
        protected Class<?> findClass(String name) throws ClassNotFoundException {
            ByteJavaFileObject fileObject = javaFileObjectMap.get(name);
            if (fileObject != null) {
                byte[] bytes = fileObject.getCompiledBytes();
                return defineClass(name, bytes, 0, bytes.length);
            }
            try {
                return ClassLoader.getSystemClassLoader().loadClass(name);
            } catch (Exception e) {
                return super.findClass(name);
            }
        }
    }

    /**
     * Reflectively invokes a public method on {@code object}.
     *
     * @param object     the receiver
     * @param methodName the method name
     * @param classes    parameter types (null or empty for a no-arg method)
     * @param args       arguments
     * @param <T>        expected return type; the result is cast unchecked to it
     * @return the method's return value
     * @throws Exception any reflection failure or exception thrown by the method
     */
    @SuppressWarnings("unchecked") // caller chooses T; a mismatch surfaces as ClassCastException at the call site
    public static <T> T invokeMethod(Object object, String methodName, Class<?>[] classes, Object... args)
            throws Exception {
        Method method = object.getClass().getMethod(methodName, classes);
        return (T) method.invoke(object, args);
    }

    public static void main(String[] args) {
        String code = "public class Test {\n" +
                " public static int runEveryTick() {\n" +
                "\t\tfor(int i=0; i < 2; i++){\n" +
                "\t\t\t System.out.println(10);\n" +
                "\t\t}\n" +
                "\t\treturn 1;" +
                " }\n" +
                "}";
        CustomJavaCompiler compiler = new CustomJavaCompiler(code);
        boolean res = compiler.compiler();
        if (res) {
            System.out.println("compilerSuccess:" + compiler.getCompilerMessage());
            System.out.println("compilerTime:" + compiler.getCompilerTime());
            try {
                Class<?> clz = compiler.getCompilerClass();
                // Class.newInstance() is deprecated; go through the no-arg constructor instead.
                int ret = invokeMethod(clz.getDeclaredConstructor().newInstance(), "runEveryTick", null);
                System.out.println("ret:" + ret);
            } catch (Exception e) {
                e.printStackTrace();
            }
        } else {
            System.out.println("compilerFailed:" + compiler.getCompilerMessage());
        }
    }

}