自己动手写一个spring
首先创建一个上下文对象 ApplicationContext：
package com.alin.teach;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.net.URI;
import java.net.URL;
import java.nio.file.*;
import java.nio.file.attribute.BasicFileAttributes;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * A minimal IoC container: scans a package for classes annotated with
 * {@code Component}, instantiates each one through the constructor recorded in
 * its {@link BeanDefinition}, injects the fields listed by the definition, and
 * finally invokes each bean's post-construct callback.
 *
 * <p>Not designed for concurrent use: the internal maps are plain
 * {@link HashMap}s with no synchronization.
 */
public class ApplicationContext {

    /** Fully wired beans, keyed by bean name. */
    private final Map<String, Object> ioc = new HashMap<>();

    /** Bean definitions discovered during scanning, keyed by bean name. */
    private final Map<String, BeanDefinition> beanDefinitionMap = new HashMap<>();

    /**
     * Beans that have been instantiated but whose fields are still being injected.
     * Exposing these early references is what lets circular dependencies resolve.
     */
    private final Map<String, Object> loadingIoc = new HashMap<>();

    /**
     * Creates a context and immediately initializes it from the given package.
     *
     * @param packageName dot-separated package to scan, e.g. {@code com.alin.teach}
     * @throws Exception if scanning the package fails
     */
    public ApplicationContext(String packageName) throws Exception {
        initContext(packageName);
    }

    /**
     * Looks up a bean by name, creating it on demand if it is defined but not yet built.
     *
     * @param beanName the bean name
     * @return the bean, or {@code null} if no bean with that name is defined
     */
    public Object getBean(String beanName) {
        Object bean = this.ioc.get(beanName);
        if (bean != null) {
            return bean;
        }
        BeanDefinition beanDefinition = this.beanDefinitionMap.get(beanName);
        // Unknown name: report "no such bean" instead of failing with a NullPointerException.
        return beanDefinition == null ? null : createBean(beanDefinition);
    }

    /**
     * Looks up a bean whose type is assignable to {@code beanClass}.
     * If several definitions match, the one returned depends on the iteration
     * order of the underlying {@link HashMap}.
     *
     * @param beanClass the required type
     * @param <T>       the required type
     * @return a matching bean, or {@code null} if none is defined
     */
    public <T> T getBean(Class<T> beanClass) {
        return this.beanDefinitionMap.values().stream()
                .filter(bd -> beanClass.isAssignableFrom(bd.getBeanType()))
                .findFirst()
                .map(this::createBean)
                .map(beanClass::cast)
                .orElse(null);
    }

    /**
     * Returns every bean whose type is assignable to {@code beanClass}.
     *
     * @param beanClass the required type
     * @param <T>       the required type
     * @return an unmodifiable list of matching beans; empty if none match
     */
    public <T> List<T> getBeans(Class<T> beanClass) {
        return this.beanDefinitionMap.values().stream()
                .filter(bd -> beanClass.isAssignableFrom(bd.getBeanType()))
                .map(this::createBean)
                .map(beanClass::cast)
                .toList();
    }

    /**
     * Scans the package, registers a definition for every component class,
     * creates all beans, and then runs the post-construct callbacks.
     *
     * @param packageName dot-separated package to scan
     * @throws Exception if scanning the package fails
     */
    public void initContext(String packageName) throws Exception {
        scannerPackage(packageName).stream()
                .filter(this::scanCreate)
                .forEach(this::createWrapper);
        // Create all beans in one pass.
        this.beanDefinitionMap.values().forEach(this::createBean);
        // Invoke the post-construct hooks once every bean is wired.
        doPostConstruct();
    }

    /** Returns whether the scanned class should be managed, i.e. carries {@code Component}. */
    private boolean scanCreate(Class<?> beanClass) {
        return beanClass.isAnnotationPresent(Component.class);
    }

    /**
     * Returns the bean for the definition: the finished instance if present,
     * otherwise the early reference of a bean currently being wired,
     * otherwise a newly created instance.
     */
    private Object createBean(BeanDefinition beanDefinition) {
        String name = beanDefinition.getName();
        if (ioc.containsKey(name)) {
            return ioc.get(name);
        }
        if (loadingIoc.containsKey(name)) {
            return loadingIoc.get(name);
        }
        return doCreateBean(beanDefinition);
    }

    /**
     * Instantiates the bean, injects its fields and publishes it.
     *
     * @throws RuntimeException if the constructor cannot be invoked or injection fails
     */
    private Object doCreateBean(BeanDefinition beanDefinition) {
        String name = beanDefinition.getName();
        Constructor<?> constructor = beanDefinition.getConstructor();
        Object bean;
        try {
            bean = constructor.newInstance();
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException("Failed to instantiate bean '" + name + "'", e);
        }
        // Expose the early reference only while wiring, so a circular dependency can
        // pick it up; publish to the finished-bean map only after injection succeeded.
        this.loadingIoc.put(name, bean);
        try {
            // Field injection.
            autoWiredBean(bean, beanDefinition);
            this.ioc.put(name, bean);
        } finally {
            this.loadingIoc.remove(name);
        }
        return bean;
    }

    /**
     * Injects every field listed by the definition with a bean of the field's type.
     * A field for which no bean is defined is left untouched.
     */
    private void autoWiredBean(Object bean, BeanDefinition beanDefinition) {
        for (Field field : beanDefinition.getField()) {
            Object fieldBean = getBean(field.getType());
            if (fieldBean == null) {
                // Skip only this field; the remaining fields must still be injected.
                continue;
            }
            try {
                field.set(bean, fieldBean);
            } catch (IllegalAccessException e) {
                throw new RuntimeException(
                        "Failed to inject field '" + field.getName() + "' of bean '"
                                + beanDefinition.getName() + "'", e);
            }
        }
    }

    /** Invokes the post-construct callback of every bean whose definition declares one. */
    private void doPostConstruct() {
        for (BeanDefinition beanDefinition : this.beanDefinitionMap.values()) {
            Method method = beanDefinition.getMethod();
            if (method == null) {
                continue;
            }
            Object bean = getBean(beanDefinition.getName());
            try {
                method.invoke(bean);
            } catch (ReflectiveOperationException e) {
                throw new RuntimeException(
                        "Post-construct callback failed for bean '" + beanDefinition.getName() + "'", e);
            }
        }
    }

    /** Builds the definition for a component class and registers it under its bean name. */
    private BeanDefinition createWrapper(Class<?> beanClass) {
        try {
            BeanDefinition beanDefinition = new BeanDefinition(beanClass);
            this.beanDefinitionMap.put(beanDefinition.getName(), beanDefinition);
            return beanDefinition;
        } catch (NoSuchMethodException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Loads every {@code .class} file found under the package's classpath directory,
     * including sub-packages.
     *
     * <p>NOTE(review): the resource URI is opened with {@link Paths#get(URI)}, so this
     * presumably only works for packages on the plain file system (not inside a jar)
     * unless a matching file system is already open — confirm before relying on it.
     *
     * @param packageName dot-separated package to scan
     * @return the loaded classes
     * @throws IllegalArgumentException if the package cannot be found on the classpath
     * @throws Exception                if walking the directory fails
     */
    private List<Class<?>> scannerPackage(String packageName) throws Exception {
        // Class-loader resource names are always '/'-separated, whatever the platform.
        URL resource = this.getClass().getClassLoader().getResource(packageName.replace('.', '/'));
        if (resource == null) {
            throw new IllegalArgumentException("Package not found on classpath: " + packageName);
        }
        Path root = Paths.get(resource.toURI());
        String separator = root.getFileSystem().getSeparator();
        String suffix = ".class";
        List<Class<?>> list = new ArrayList<>();
        Files.walkFileTree(root, new SimpleFileVisitor<Path>() {
            @Override
            public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException {
                String relative = root.relativize(file).toString();
                if (relative.endsWith(suffix)) {
                    // Derive the class name from the path relative to the package root, so
                    // directories above the root can never be mistaken for the package name.
                    String simplePath = relative.substring(0, relative.length() - suffix.length());
                    String relativeName = simplePath.replace(separator, ".");
                    String className = packageName.isEmpty()
                            ? relativeName
                            : packageName + "." + relativeName;
                    try {
                        list.add(Class.forName(className));
                    } catch (ClassNotFoundException e) {
                        throw new RuntimeException(e);
                    }
                }
                // Keep walking through all directories.
                return FileVisitResult.CONTINUE;
            }
        });
        return list;
    }
}
BeanDefinition 类：
package com.alin.teach;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;
/**
 * Immutable description of a managed bean: its name, type, no-argument
 * constructor, optional post-construct callback and the fields to inject.
 */
public class BeanDefinition {

    /** Bean name: the {@code Component} name if non-empty, otherwise the simple class name. */
    private final String name;

    /** The bean's class. */
    private final Class<?> beanType;

    /** Public no-argument constructor used to instantiate the bean. */
    private final Constructor<?> constructor;

    /** First declared method annotated with {@code PostConstructor}, or {@code null}. */
    private final Method method;

    /** Declared fields annotated with {@code AutoWired}, already made accessible. */
    private final List<Field> field;

    /** Returns the bean name. */
    public String getName() {
        return name;
    }

    /**
     * Builds the definition by reflecting over the given class.
     *
     * @param beanClass the component class
     * @throws IllegalArgumentException if the class is not annotated with {@code Component}
     * @throws NoSuchMethodException    if the class has no public no-argument constructor
     */
    public BeanDefinition(Class<?> beanClass) throws NoSuchMethodException {
        Component component = beanClass.getAnnotation(Component.class);
        if (component == null) {
            // getAnnotation returns null when the annotation is absent; fail with a clear message.
            throw new IllegalArgumentException(
                    beanClass.getName() + " is not annotated with @Component");
        }
        this.name = component.name().isEmpty() ? beanClass.getSimpleName() : component.name();
        this.beanType = beanClass;
        this.constructor = beanClass.getConstructor();

        // Select the annotated members first, then open up only those for reflective access.
        Method callback = Arrays.stream(beanClass.getDeclaredMethods())
                .filter(m -> m.isAnnotationPresent(PostConstructor.class))
                .findFirst()
                .orElse(null);
        if (callback != null) {
            callback.setAccessible(true);
        }
        this.method = callback;

        List<Field> injectable = Arrays.stream(beanClass.getDeclaredFields())
                .filter(f -> f.isAnnotationPresent(AutoWired.class))
                .toList();
        injectable.forEach(f -> f.setAccessible(true));
        this.field = injectable;
    }

    /** Returns the no-argument constructor used to instantiate the bean. */
    public Constructor<?> getConstructor() {
        return constructor;
    }

    /** Returns the post-construct callback, or {@code null} if the class declares none. */
    public Method getMethod() {
        return method;
    }

    /** Returns the unmodifiable list of fields to inject. */
    public List<Field> getField() {
        return field;
    }

    /** Returns the bean's class. */
    public Class<?> getBeanType() {
        return beanType;
    }
}
浙公网安备 33010602011771号