import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.JarURLConnection;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.*;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
 * Utilities for discovering classes on the classpath by package name.
 *
 * <p>Exploded class directories ({@code file:} URLs) and jar files ({@code jar:} URLs) are
 * scanned; resources reached through any other protocol are ignored.
 */
public class PackageUtil {
    private static final Logger log = Logger.getLogger(PackageUtil.class.getName());

    /** Length of the {@code ".class"} suffix stripped to turn a file name into a class name. */
    private static final int CLASS_SUFFIX_LENGTH = ".class".length();

    /**
     * Scans the given packages, including all of their sub-packages, for classes.
     *
     * <p>Resources are looked up through the thread context class loader (falling back to this
     * class's own loader, then the system loader, when it is {@code null}), and classes are loaded
     * and initialized through that same loader. Classes that fail to load are logged and skipped.
     *
     * @param packageNames fully qualified package names; {@code ""} denotes the unnamed package
     * @return the distinct classes found, in no particular order
     * @throws UncheckedIOException if the class loader or a jar file cannot be read
     */
    public static List<Class<?>> scanClasses(String... packageNames) {
        ClassLoader loader = resolveClassLoader();
        // Classes found so far; a set because overlapping packages can yield the same class twice.
        Set<Class<?>> scannedClasses = new HashSet<>();
        // Resources already scanned, so the same package root is never processed twice.
        Set<String> scannedResources = new HashSet<>();
        for (String packageName : packageNames) {
            String path = packageName.replace('.', '/');
            Enumeration<URL> resources;
            try {
                resources = loader.getResources(path);
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
            while (resources.hasMoreElements()) {
                URL resource = resources.nextElement();
                String file = resource.getFile();
                if (scannedResources.contains(file) || scannedResources.contains(file + "/")) {
                    continue;
                }
                scannedResources.add(file);
                if (resource.getProtocol().equals("file")) {
                    scanClassesInClasspath(packageName, resource, loader, scannedClasses);
                } else if (resource.getProtocol().equals("jar")) {
                    scanClassesInJar(packageName, resource, loader, scannedClasses);
                }
            }
        }
        return new ArrayList<>(scannedClasses);
    }

    /**
     * Loads every class file directly inside the directory {@code resource} and recurses into
     * its sub-directories as sub-packages of {@code packageName}.
     */
    private static void scanClassesInClasspath(String packageName, URL resource, ClassLoader loader,
                                               Set<Class<?>> scannedClasses) {
        File[] children = toFile(resource).listFiles();
        if (children == null) {
            // Missing, not a directory, or unreadable: nothing to scan.
            return;
        }
        for (File child : children) {
            String fileName = child.getName();
            if (child.isDirectory()) {
                scannedClasses.addAll(scanClasses(qualify(packageName, fileName)));
            } else if (isLoadableClassFile(fileName)) {
                String simpleName = fileName.substring(0, fileName.length() - CLASS_SUFFIX_LENGTH);
                loadInto(qualify(packageName, simpleName), loader, scannedClasses);
            }
            // Any other file (properties, XML, ...) is a plain resource, not a package.
        }
    }

    /**
     * Loads every class in the jar behind {@code resource} whose package is {@code packageName}
     * or one of its sub-packages.
     */
    private static void scanClassesInJar(String packageName, URL resource, ClassLoader loader,
                                         Set<Class<?>> scannedClasses) {
        JarFile jarFile;
        try {
            jarFile = ((JarURLConnection) resource.openConnection()).getJarFile();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        // Deliberately not closed: with URL caching enabled (the default) the JarFile returned by
        // JarURLConnection may be shared with other users of the same jar URL.
        Enumeration<JarEntry> entries = jarFile.entries();
        while (entries.hasMoreElements()) {
            String entryName = entries.nextElement().getName();
            if (!isLoadableClassFile(entryName)) {
                continue;
            }
            String className = entryName.substring(0, entryName.length() - CLASS_SUFFIX_LENGTH)
                    .replace('/', '.');
            // The jar is iterated in full, so entries outside the requested package tree are skipped.
            if (isInPackageTree(className, packageName)) {
                loadInto(className, loader, scannedClasses);
            }
        }
    }

    /** Returns the class loader used for both resource lookup and class loading; never null. */
    private static ClassLoader resolveClassLoader() {
        ClassLoader loader = Thread.currentThread().getContextClassLoader();
        if (loader == null) {
            loader = PackageUtil.class.getClassLoader();
        }
        return loader != null ? loader : ClassLoader.getSystemClassLoader();
    }

    /**
     * Converts a {@code file:} URL to a {@link File}, decoding percent-escapes such as {@code %20}.
     * Falls back to the raw URL path when the URL is not a valid hierarchical file URI.
     */
    private static File toFile(URL url) {
        try {
            return new File(url.toURI());
        } catch (URISyntaxException | IllegalArgumentException e) {
            return new File(url.getFile());
        }
    }

    /** True for {@code .class} files, excluding the {@code module-info} and {@code package-info} descriptors. */
    private static boolean isLoadableClassFile(String name) {
        return name.endsWith(".class")
                && !name.endsWith("module-info.class")
                && !name.endsWith("package-info.class");
    }

    /**
     * True if {@code className} lives in {@code packageName} or a sub-package of it. The trailing
     * dot prevents {@code com.foo} from matching classes in the sibling package {@code com.foobar}.
     */
    private static boolean isInPackageTree(String className, String packageName) {
        return packageName.isEmpty() || className.startsWith(packageName + '.');
    }

    /** Joins a package name and a simple name, handling the unnamed package {@code ""}. */
    private static String qualify(String packageName, String simpleName) {
        return packageName.isEmpty() ? simpleName : packageName + '.' + simpleName;
    }

    /** Loads and initializes {@code className} through {@code loader}; logs and skips it on failure. */
    private static void loadInto(String className, ClassLoader loader, Set<Class<?>> scannedClasses) {
        try {
            scannedClasses.add(Class.forName(className, true, loader));
        } catch (ClassNotFoundException | LinkageError e) {
            log.log(Level.WARNING, "扫描到class[%s]时异常,跳过此类。异常信息:%s".formatted(className, e.getMessage()), e);
        }
    }
}