import org.springframework.core.io.DefaultResourceLoader;
import org.springframework.core.io.Resource;
import org.springframework.core.io.ResourceLoader;
import org.springframework.core.io.support.ResourcePatternUtils;
import org.springframework.core.type.classreading.MetadataReader;
import org.springframework.core.type.classreading.MetadataReaderFactory;
import org.springframework.core.type.classreading.SimpleMetadataReaderFactory;
import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.bind.annotation.RestController;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
/**
 * Scans packages on the classpath for classes annotated directly with {@link RestController}.
 */
public class App {
    private static final Logger logger = LoggerFactory.getLogger(App.class);

    /** Ant-style pattern, relative to a package directory, matching every class file beneath it. */
    private static final String RESOURCE_PATTERN = "**/*.class";

    /**
     * Resource location prefix that makes the pattern resolver search every classpath root (all
     * directories and jars) rather than only the first root that contains the package directory.
     */
    private static final String CLASSPATH_ALL_URL_PREFIX = "classpath*:";

    public static void main(String[] args) throws Exception {
        System.out.println(scanPackages("com.g2.xx.trade.svr.rest".split(",")));
    }

    /**
     * Collects the {@code @RestController} classes found under each of the given base packages.
     *
     * <p>Entries are trimmed and blank entries are skipped, so input such as {@code "a, b"} split on
     * commas scans {@code a} and {@code b}. A package whose scan fails with an {@link IOException}
     * is logged (with its stack trace) and skipped; the remaining packages are still scanned.
     *
     * @param basePackages dot-separated package names; {@code null} is treated as empty
     * @return the matching classes in scan order, never {@code null}
     */
    static List<Class<?>> scanPackages(String[] basePackages) {
        List<Class<?>> candidates = new ArrayList<Class<?>>();
        if (basePackages == null) {
            return candidates;
        }
        for (String rawPackage : basePackages) {
            if (!StringUtils.hasText(rawPackage)) {
                continue;
            }
            String pkg = StringUtils.trimWhitespace(rawPackage);
            try {
                candidates.addAll(findCandidateClasses(pkg));
            } catch (IOException e) {
                // Pass the exception as the last argument so SLF4J logs the stack trace too.
                logger.error("扫描指定注解@RestController的基础包{}时出现异常", pkg, e);
            }
        }
        return candidates;
    }

    /**
     * Finds the classes under {@code basePackage} (including sub-packages, across all classpath
     * roots) that carry the {@link RestController} annotation.
     *
     * @param basePackage dot-separated package name, e.g. {@code com.example.web}
     * @return the matching classes; classes that cannot be loaded are logged and skipped
     * @throws IOException if the classpath resources cannot be listed or a class file cannot be read
     */
    private static List<Class<?>> findCandidateClasses(String basePackage) throws IOException {
        logger.debug("开始扫描指定包{}下的所有类", basePackage);
        List<Class<?>> candidates = new ArrayList<Class<?>>();
        String packageSearchPath = CLASSPATH_ALL_URL_PREFIX
                + replaceDotByDelimiter(basePackage) + '/' + RESOURCE_PATTERN;
        ResourceLoader resourceLoader = new DefaultResourceLoader();
        MetadataReaderFactory readerFactory = new SimpleMetadataReaderFactory(resourceLoader);
        Resource[] resources = ResourcePatternUtils.getResourcePatternResolver(resourceLoader)
                .getResources(packageSearchPath);
        for (Resource resource : resources) {
            MetadataReader reader = readerFactory.getMetadataReader(resource);
            Class<?> candidateClass = transform(reader.getClassMetadata().getClassName());
            // Class#getAnnotation only sees @RestController itself; a custom annotation that is
            // merely meta-annotated with @RestController is not detected here.
            if (candidateClass == null || candidateClass.getAnnotation(RestController.class) == null) {
                continue;
            }
            candidates.add(candidateClass);
            logger.debug("扫描到@RestController注解基础类:{}", candidateClass.getName());
        }
        return candidates;
    }

    /**
     * Converts a dot-separated package name into a slash-separated resource path.
     *
     * @param path package name such as {@code com.example.web}
     * @return the same name with every {@code .} replaced by {@code /}
     */
    private static String replaceDotByDelimiter(String path) {
        return StringUtils.replace(path, ".", "/");
    }

    /**
     * Loads the named class with this class's class loader.
     *
     * @param className fully qualified class name
     * @return the loaded class, or {@code null} if it cannot be found or linked (the failure is logged)
     */
    private static Class<?> transform(String className) {
        try {
            return ClassUtils.forName(className, App.class.getClassLoader());
        } catch (ClassNotFoundException e) {
            logger.error("未找到指定类:{}", className, e);
        } catch (LinkageError e) {
            // e.g. NoClassDefFoundError when one of the class's dependencies is missing from the
            // classpath: skip this class instead of aborting the whole scan.
            logger.error("未找到指定类:{}", className, e);
        }
        return null;
    }
}