自己动手写一个简易版 Spring（IoC 容器）

首先创建上下文对象 ApplicationContext：

package com.alin.teach;

import java.io.File;
import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.net.URI;
import java.net.URL;
import java.nio.file.*;
import java.nio.file.attribute.BasicFileAttributes;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * A minimal IoC container.
 *
 * <p>It scans a package for classes annotated with {@code @Component} and instantiates each one
 * through its no-arg constructor. It then injects the {@code @AutoWired} fields and finally
 * invokes each bean's {@code @PostConstructor} method.
 *
 * <p>Beans are created once per name and cached. Only class directories on the file system are
 * scanned: {@link Paths#get(URI)} cannot open {@code jar:} URIs without a mounted file system.
 * Not thread-safe: all state lives in plain {@link HashMap}s.
 */
public class ApplicationContext {

    /** Fully initialised beans, keyed by bean name. */
    private final Map<String, Object> ioc = new HashMap<>();

    /** Bean metadata collected during the package scan, keyed by bean name. */
    private final Map<String, BeanDefinition> beanDefinitionMap = new HashMap<>();

    /**
     * Beans that are instantiated but whose fields are still being injected.
     *
     * <p>Exposing them early lets two beans reference each other through circular field injection.
     */
    private final Map<String, Object> loadingIoc = new HashMap<>();


    /**
     * Creates the context and eagerly initialises every component under {@code packageName}.
     *
     * @param packageName dotted package name to scan, e.g. {@code "com.alin.teach"}
     * @throws Exception if the package cannot be located or scanned
     */
    public ApplicationContext(String packageName) throws Exception {
        initContext(packageName);
    }

    /**
     * Returns the bean registered under {@code beanName}, creating it first if necessary.
     *
     * @param beanName bean name; may be {@code null}
     * @return the bean, or {@code null} if no bean definition has that name
     */
    public Object getBean(String beanName) {
        Object bean = this.ioc.get(beanName);
        if (bean != null) {
            return bean;
        }
        BeanDefinition beanDefinition = this.beanDefinitionMap.get(beanName);
        // Unknown name: report "no bean" instead of failing inside createBean(null).
        return beanDefinition == null ? null : createBean(beanDefinition);
    }


    /**
     * Returns a bean whose type is {@code beanClass} or a subtype of it.
     *
     * <p>If several beans match, which one is returned is unspecified: it depends on
     * {@link HashMap} iteration order.
     *
     * @param beanClass requested type
     * @return a matching bean, or {@code null} if none is registered
     */
    public <T> T getBean(Class<T> beanClass) {
        String beanName = this.beanDefinitionMap.values().stream()
                .filter(bd -> beanClass.isAssignableFrom(bd.getBeanType()))
                .map(BeanDefinition::getName)
                .findFirst()
                .orElse(null);
        if (beanName == null) {
            return null;
        }
        return beanClass.cast(getBean(beanName));
    }


    /**
     * Returns every bean whose type is {@code beanClass} or a subtype of it.
     *
     * @param beanClass requested type
     * @return an unmodifiable list of matching beans; empty if none match
     */
    public <T> List<T> getBeans(Class<T> beanClass) {
        return this.beanDefinitionMap.values().stream()
                .filter(bd -> beanClass.isAssignableFrom(bd.getBeanType()))
                .map(BeanDefinition::getName)
                .map(this::getBean)
                .map(beanClass::cast)
                .toList();

    }

    /**
     * Scans {@code packageName}, registers every {@code @Component}, creates all beans and runs
     * their post-construct callbacks.
     *
     * @param packageName dotted package name to scan
     * @throws Exception if the package cannot be located or scanned
     */
    public void initContext(String packageName) throws Exception {
        // 1. Register a BeanDefinition for every @Component class under the package.
        scannerPackage(packageName).stream()
                .filter(this::scanCreate)
                .forEach(this::createWrapper);

        // 2. Create all beans eagerly.
        this.beanDefinitionMap.values().forEach(this::createBean);

        // 3. Invoke the @PostConstructor hooks only after every bean has been wired.
        doPostConstruct();
    }


    /** Returns {@code true} if {@code beanClass} should become a bean (annotated {@code @Component}). */
    private boolean scanCreate(Class<?> beanClass) {
        return beanClass.isAnnotationPresent(Component.class);
    }

    /**
     * Returns the cached bean for {@code beanDefinition}, creating it if it does not exist yet.
     *
     * <p>A bean that is currently being created is returned from {@link #loadingIoc}. This breaks
     * circular field dependencies.
     */
    private Object createBean(BeanDefinition beanDefinition) {
        String name = beanDefinition.getName();
        if (ioc.containsKey(name)) {
            return ioc.get(name);
        }
        if (loadingIoc.containsKey(name)) {
            return loadingIoc.get(name);
        }
        return doCreateBean(beanDefinition);

    }

    /**
     * Instantiates a bean, injects its fields and publishes it to {@link #ioc}.
     *
     * @throws IllegalStateException if the constructor cannot be invoked
     */
    private Object doCreateBean(BeanDefinition beanDefinition) {
        String name = beanDefinition.getName();
        Constructor<?> constructor = beanDefinition.getConstructor();
        try {
            Object bean = constructor.newInstance();
            // Expose the half-built bean so that a circular reference back to it resolves to
            // this very instance instead of triggering a second creation.
            this.loadingIoc.put(name, bean);
            // Field injection.
            autoWiredBean(bean, beanDefinition);
            // Only a fully injected bean becomes visible in the bean cache.
            this.ioc.put(name, bean);
            return bean;
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("Failed to create bean '" + name + "'", e);
        } finally {
            // The early reference is no longer needed, whether creation succeeded or not.
            this.loadingIoc.remove(name);
        }
    }

    /**
     * Injects every {@code @AutoWired} field of {@code bean} with a bean of the field's type.
     *
     * <p>Fields with no matching bean are left unchanged.
     */
    private void autoWiredBean(Object bean, BeanDefinition beanDefinition) {
        for (Field field : beanDefinition.getField()) {
            Object fieldBean = getBean(field.getType());
            if (fieldBean == null) {
                // Skip only this field; the remaining fields must still be injected.
                continue;
            }
            try {
                field.set(bean, fieldBean);
            } catch (IllegalAccessException e) {
                throw new IllegalStateException("Cannot inject field " + field, e);
            }
        }
    }


    /** Invokes the {@code @PostConstructor} method, if any, of every registered bean. */
    private void doPostConstruct() {
        this.beanDefinitionMap.values().forEach(beanDefinition -> {
            Object bean = getBean(beanDefinition.getName());
            Method method = beanDefinition.getMethod();
            if (method != null) {
                try {
                    method.invoke(bean);
                } catch (Exception e) {
                    throw new IllegalStateException(
                            "Post-construct callback failed for bean '" + beanDefinition.getName() + "'", e);
                }
            }
        });

    }


    /** Builds the {@link BeanDefinition} for {@code beanClass} and registers it by name. */
    private BeanDefinition createWrapper(Class<?> beanClass) {
        try {
            BeanDefinition beanDefinition = new BeanDefinition(beanClass);
            this.beanDefinitionMap.put(beanDefinition.getName(), beanDefinition);
            return beanDefinition;
        } catch (NoSuchMethodException e) {
            throw new IllegalStateException(beanClass.getName() + " has no no-arg constructor", e);
        }
    }

    /**
     * Loads every {@code .class} file found under {@code packageName}, including sub-packages.
     *
     * @param packageName dotted package name; the empty string scans the classpath root directory
     * @return the loaded classes
     * @throws IllegalArgumentException if the package directory is not on the classpath
     * @throws Exception if the directory cannot be walked
     */
    private List<Class<?>> scannerPackage(String packageName) throws Exception {
        // Class-loader resource names always use '/', independent of the platform's File.separator.
        String resourceName = packageName.replace('.', '/');
        URL resource = this.getClass().getClassLoader().getResource(resourceName);
        if (resource == null) {
            throw new IllegalArgumentException("Package not found on classpath: " + packageName);
        }
        URI uri = resource.toURI();
        Path root = Paths.get(uri);

        List<Class<?>> list = new ArrayList<>();

        Files.walkFileTree(root, new SimpleFileVisitor<>() {
            @Override
            public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException {
                if (file.getFileName().toString().endsWith(".class")) {
                    // Derive the class name from the path *relative to the package root*. Searching
                    // the absolute path for the package name can match a parent directory.
                    list.add(loadScannedClass(packageName, root.relativize(file)));
                }
                // Keep walking the whole directory tree.
                return FileVisitResult.CONTINUE;
            }
        });

        return list;
    }

    /**
     * Converts a path relative to the package root (e.g. {@code sub/Foo.class}) into a class name
     * (e.g. {@code com.alin.teach.sub.Foo}) and loads it.
     */
    private static Class<?> loadScannedClass(String packageName, Path relativePath) {
        StringBuilder className = new StringBuilder(packageName);
        for (Path segment : relativePath) {
            if (className.length() > 0) {
                className.append('.');
            }
            className.append(segment);
        }
        className.setLength(className.length() - ".class".length());
        try {
            return Class.forName(className.toString());
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException("Cannot load scanned class " + className, e);
        }
    }


}

接着定义 BeanDefinition 类，用于保存每个 Bean 的元数据（名称、构造器、初始化回调和待注入字段）：

package com.alin.teach;

import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;

/**
 * Reflection metadata for one {@code @Component} class.
 *
 * <p>It holds the bean name, the no-arg constructor used to instantiate the class, its
 * {@code @PostConstructor} method (if any) and its {@code @AutoWired} fields.
 */
public class BeanDefinition {

    /** Bean name: {@code @Component.name()} when non-empty, otherwise the class's simple name. */
    private final String name;

    /** The component class itself. */
    private final Class<?> beanType;

    /** No-arg constructor, made accessible so non-public classes and constructors also work. */
    private final Constructor<?> constructor;

    /** A declared method annotated with {@code @PostConstructor}, or {@code null} if none. */
    private final Method method;

    /** Declared fields annotated with {@code @AutoWired}; unmodifiable. */
    private final List<Field> field;

    /** Returns the bean name. */
    public String getName() {
        return name;
    }

    /**
     * Collects the metadata of {@code beanClass}.
     *
     * <p>Only members declared directly on {@code beanClass} are inspected; superclasses are not.
     * If several methods carry {@code @PostConstructor}, an arbitrary one is chosen.
     *
     * @param beanClass a class annotated with {@code @Component}
     * @throws NoSuchMethodException if {@code beanClass} declares no no-arg constructor
     * @throws IllegalArgumentException if {@code beanClass} is not annotated with {@code @Component}
     */
    public BeanDefinition(Class<?> beanClass) throws NoSuchMethodException {
        Component component = beanClass.getAnnotation(Component.class);
        if (component == null) {
            throw new IllegalArgumentException(beanClass.getName() + " is not annotated with @Component");
        }
        this.name = component.name().isEmpty() ? beanClass.getSimpleName() : component.name();
        this.beanType = beanClass;
        // getConstructor() only sees public constructors. The implicit constructor of a
        // package-private class is itself package-private and would be missed.
        this.constructor = makeAccessible(beanClass.getDeclaredConstructor());
        this.method = Arrays.stream(beanClass.getDeclaredMethods())
                .filter(m -> m.isAnnotationPresent(PostConstructor.class))
                .findFirst()
                .map(BeanDefinition::makeAccessible)
                .orElse(null);

        this.field = Arrays.stream(beanClass.getDeclaredFields())
                .filter(f -> f.isAnnotationPresent(AutoWired.class))
                .map(BeanDefinition::makeAccessible)
                .toList();
    }

    /** Disables access checks on {@code member} (so private members can be used) and returns it. */
    private static <T extends java.lang.reflect.AccessibleObject> T makeAccessible(T member) {
        member.setAccessible(true);
        return member;
    }

    /** Returns the accessible no-arg constructor. */
    public Constructor<?> getConstructor() {
        return constructor;
    }

    /** Returns the {@code @PostConstructor} method, or {@code null} if the class has none. */
    public Method getMethod() {
        return method;
    }

    /** Returns the {@code @AutoWired} fields (unmodifiable, possibly empty). */
    public List<Field> getField() {
        return field;
    }

    /** Returns the component class. */
    public Class<?> getBeanType() {
        return beanType;
    }
}

posted @ 2025-04-06 21:50  alinnnn  阅读(5)  评论(0)    收藏  举报