JPA利用Java反射机制动态构建sql

第一步:在项目的 pom.xml 中加入 JPA 框架的 Maven 依赖坐标

<!-- Database ORM framework: Spring Data JPA starter.
     No <version> is declared here, so it is presumably managed by the Spring Boot
     parent POM / dependency management of the project (confirm in the full pom.xml). -->
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-data-jpa</artifactId>
</dependency>
JPA的maven坐标

第二步:在项目中的 model 层下创建一个实体类 CityData

/**
 * @author chaoyou
 * @email
 * @date 2019-10-8 17:55
 * @Description City information entity, mapped to the {@code city_data} table.
 *
 * <p>The reflection-based SQL builder in this project walks {@code getDeclaredFields()},
 * so keep the persistent field declarations in a stable order.
 */
@Entity
@Table(name = "city_data")
public class CityData implements Serializable {

    // Explicit serialization version: without it the JVM derives one from the class
    // shape, and any unrelated edit would break previously serialized instances.
    // Static, and carries no JPA annotation, so it is not treated as a column.
    private static final long serialVersionUID = 1L;

    /** Primary key, generated by the database (identity strategy). */
    @Id
    @GeneratedValue(strategy = GenerationType.IDENTITY)
    private Long id;
    /** City code; declared unique. */
    @Column(unique = true)
    private String cityCode;
    /**
     * City name; declared unique.
     * NOTE(review): different provinces can contain places with the same name —
     * confirm this uniqueness constraint is intended.
     */
    @Column(unique = true)
    private String cityName;
    /** Region. */
    @Column
    private String area;
    /** Province or municipality directly under the central government. */
    @Column
    private String province;


    /** Prefecture-level city. */
    @Column
    private String prefectureLevel;
    /** County / county town. */
    @Column
    private String town;
    /**
     * City tier: 1 = provincial capital, 2 = prefecture-level city,
     * 3 = municipality, 4 = county-level city.
     */
    @Column
    private Integer level;
    /**
     * Creation time, column {@code create_time}; {@code updatable = false} keeps it out of
     * provider-generated UPDATE statements. Serialized to JSON as yyyy-MM-dd HH:mm:ss (GMT+8).
     */
    @Column(name = "create_time", updatable = false)
    @JsonFormat(pattern = "yyyy-MM-dd HH:mm:ss", timezone = "GMT+8")
    private Date createTime;
    /**
     * Last update time, column {@code update_time}; {@code insertable = false} keeps it out of
     * provider-generated INSERT statements. Serialized to JSON as yyyy-MM-dd HH:mm:ss (GMT+8).
     */
    @Column(name = "update_time", insertable = false)
    @JsonFormat(pattern = "yyyy-MM-dd HH:mm:ss", timezone = "GMT+8")
    private Date updateTime;
    /** City rank, e.g. T1, T2. */
    @Column
    private String cityRank;

    /** No-arg constructor required by JPA. */
    public CityData() {
    }

    /**
     * Creates an instance carrying only its primary key.
     *
     * @param id primary key value
     */
    public CityData(Long id) {
        this.id = id;
    }

    public Long getId() {
        return id;
    }

    public void setId(Long id) {
        this.id = id;
    }

    public String getCityCode() {
        return cityCode;
    }

    public void setCityCode(String cityCode) {
        this.cityCode = cityCode;
    }

    public String getCityName() {
        return cityName;
    }

    public void setCityName(String cityName) {
        this.cityName = cityName;
    }

    public String getArea() {
        return area;
    }

    public void setArea(String area) {
        this.area = area;
    }

    public String getProvince() {
        return province;
    }

    public void setProvince(String province) {
        this.province = province;
    }

    public String getPrefectureLevel() {
        return prefectureLevel;
    }

    public void setPrefectureLevel(String prefectureLevel) {
        this.prefectureLevel = prefectureLevel;
    }

    public String getTown() {
        return town;
    }

    public void setTown(String town) {
        this.town = town;
    }

    public Integer getLevel() {
        return level;
    }

    public void setLevel(Integer level) {
        this.level = level;
    }

    public Date getCreateTime() {
        return createTime;
    }

    public void setCreateTime(Date createTime) {
        this.createTime = createTime;
    }

    public Date getUpdateTime() {
        return updateTime;
    }

    public void setUpdateTime(Date updateTime) {
        this.updateTime = updateTime;
    }

    public String getCityRank() {
        return cityRank;
    }

    public void setCityRank(String cityRank) {
        this.cityRank = cityRank;
    }
}
CityData实体类

第三步:在项目的 dao 层下创建一个供外部调用的接口 EntryMapper

import java.lang.reflect.Field;
import java.util.List;

/**
 * @author chaoyou
 * @email
 * @date 2020-6-11 14:56
 * @Description Contract for extracting table/column metadata from an entity class
 * and for building basic SQL statements from that metadata.
 */
public interface EntryMapper {

    /**
     * Resolves the database table that the entity class maps to.
     *
     * @param clazz the entity class
     * @return the table name
     */
    String getEntryTableName(Class<?> clazz);

    /**
     * Resolves the primary-key field of the entity class.
     *
     * @param clazz the entity class
     * @return the name of the primary-key field
     */
    String getPKFieldName(Class<?> clazz);

    /**
     * Describes the foreign-key fields of the entity class.
     *
     * @param clazz the entity class
     * @return one implementation-defined description per foreign-key field
     */
    List<String> getFKFieldName(Class<?> clazz);

    /**
     * Lists the persistent column names of the entity class.
     *
     * @param clazz the entity class
     * @return the mapped column names
     */
    List<String> getSequenceName(Class<?> clazz);

    /**
     * Lists the reflective fields backing the persistent columns of the entity class.
     *
     * @param clazz the entity class
     * @return the persistent fields
     */
    List<Field> getFieldList(Class<?> clazz);

    /**
     * Builds a plain INSERT statement with one {@code ?} placeholder per column.
     *
     * @param clazz the entity class
     * @return the INSERT SQL text
     */
    String getSqlToSave(Class<?> clazz);

    /**
     * Builds a plain UPDATE statement that sets every column and filters on {@code field}.
     *
     * @param clazz the entity class
     * @param field the column used in the WHERE clause
     * @return the UPDATE SQL text
     */
    String getSqlToUpdate(Class<?> clazz, String field);
}
实体类操作接口

第四步:在项目的 impl 层下创建 EntryMapper 接口的实现类 JpaEntryMapper

import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;

import javax.persistence.*;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.List;

/**
 * @author chaoyou
 * @email
 * @date 2020-6-11 14:47
 * @Description Derives table/column metadata and basic SQL statements from an entity's
 * JPA annotations via reflection.
 *
 * <p>Only fields declared directly on the given class are inspected
 * ({@link Class#getDeclaredFields()}); inherited fields are ignored. The JVM does not
 * guarantee the order of that array, so generated column order is whatever it returns
 * (declaration order on common JVMs — verify for your runtime).
 */
@Component
public class JpaEntryMapper implements EntryMapper {

    /**
     * Returns the table name declared by the entity's {@code @Table(name = ...)}.
     *
     * @param clazz entity class; must not be {@code null}
     * @return the non-empty {@code @Table} name
     * @throws IllegalArgumentException if {@code clazz} is null or has no {@code @Table}
     * @throws IllegalStateException    if the {@code @Table} name is empty
     */
    @Override
    public String getEntryTableName(Class<?> clazz) {
        Assert.notNull(clazz, "clazz不能为空");
        Table tableAnno = clazz.getAnnotation(Table.class);
        Assert.notNull(tableAnno, "@Table注解未设置");
        Assert.state(StringUtils.isNotEmpty(tableAnno.name()), "@Table 的 name 属性未设置");
        return tableAnno.name();
    }

    /**
     * Returns the Java field name (not the column name) of the first declared field
     * annotated with {@code @Id}.
     *
     * @param clazz entity class; must not be {@code null}
     * @return the field name, or {@code null} if no declared field carries {@code @Id}
     * @throws IllegalArgumentException if {@code clazz} is null
     */
    @Override
    public String getPKFieldName(Class<?> clazz) {
        Assert.notNull(clazz, "clazz不能为空");
        for (Field field : clazz.getDeclaredFields()) {
            if (field.isAnnotationPresent(Id.class)) {
                return field.getName();
            }
        }
        return null;
    }

    /**
     * Describes every declared field annotated with {@code @JoinColumn}.
     *
     * <p>Each entry has the form {@code column-referencedColumn-fieldName}. An empty
     * {@code name} falls back to the field name, and an empty {@code referencedColumnName}
     * falls back to {@code "id"}.
     *
     * @param clazz entity class; must not be {@code null}
     * @return one entry per {@code @JoinColumn} field; empty (mutable) list if there are none
     * @throws IllegalArgumentException if {@code clazz} is null
     */
    @Override
    public List<String> getFKFieldName(Class<?> clazz) {
        Assert.notNull(clazz, "clazz不能为空");
        List<String> fks = new ArrayList<>();
        // Keep scanning after a match: an entity may declare several foreign keys.
        for (Field field : clazz.getDeclaredFields()) {
            JoinColumn joinColumn = field.getAnnotation(JoinColumn.class);
            if (joinColumn == null) {
                continue;
            }
            String column = joinColumn.name().isEmpty() ? field.getName() : joinColumn.name();
            // "id" is this project's convention for an unspecified referenced column;
            // JPA's own default naming rule is derived differently.
            String referenced = joinColumn.referencedColumnName().isEmpty()
                    ? "id"
                    : joinColumn.referencedColumnName();
            fks.add(column + "-" + referenced + "-" + field.getName());
        }
        return fks;
    }

    /**
     * Returns the column names of all persistent fields (see {@link #persistentFields}).
     *
     * <p>A field is listed only if it carries {@code @Column} or {@code @JoinColumn}. That
     * is why an {@code @Id}-only primary key is left out; an {@code @Id} field that also
     * has {@code @Column} IS listed.
     *
     * @param clazz entity class; must not be {@code null}
     * @return column names in declared-field order; never {@code null} (empty when none)
     * @throws IllegalArgumentException if {@code clazz} is null
     */
    @Override
    public List<String> getSequenceName(Class<?> clazz) {
        Assert.notNull(clazz, "clazz不能为空");
        List<String> columns = new ArrayList<>();
        for (Field field : persistentFields(clazz)) {
            columns.add(columnNameOf(field));
        }
        return columns;
    }

    /**
     * Returns the persistent fields, in the same order used by {@link #getSequenceName}.
     *
     * @param clazz entity class; must not be {@code null}
     * @return persistent fields; never {@code null} (empty when none)
     * @throws IllegalArgumentException if {@code clazz} is null
     */
    @Override
    public List<Field> getFieldList(Class<?> clazz) {
        Assert.notNull(clazz, "clazz不能为空");
        return persistentFields(clazz);
    }

    /**
     * Builds {@code insert into <table>(c1, c2, ...) values(?, ?, ...)} over the columns
     * returned by {@link #getSequenceName}.
     *
     * <p>NOTE(review): {@code @Column(insertable = false)} is not honoured, so such columns
     * (e.g. an {@code update_time}) are still included. Callers presumably bind values in
     * {@link #getFieldList} order; filtering here would need a matching change there.
     *
     * @param clazz entity class; must not be {@code null}
     * @return the INSERT SQL text
     * @throws IllegalArgumentException if {@code clazz} is null or lacks {@code @Table}
     * @throws IllegalStateException    if the class has no persistent columns or no table name
     */
    @Override
    public String getSqlToSave(Class<?> clazz) {
        List<String> columns = getSequenceName(clazz);
        Assert.state(!columns.isEmpty(), "实体类没有可持久化的字段");
        List<String> placeholders = new ArrayList<>(columns.size());
        for (int i = 0; i < columns.size(); i++) {
            placeholders.add("?");
        }
        return "insert into " + getEntryTableName(clazz) + "("
                + String.join(", ", columns)
                + ") values(" + String.join(", ", placeholders) + ")";
    }

    /**
     * Builds {@code update <table> set c1=?, c2=?, ... where <field>=?} over the columns
     * returned by {@link #getSequenceName}.
     *
     * <p>{@code field} is concatenated verbatim into the SQL (it is not a bind parameter),
     * so pass only trusted identifiers. NOTE(review): {@code @Column(updatable = false)} is
     * not honoured, so such columns (e.g. a {@code create_time}) are still set.
     *
     * @param clazz entity class; must not be {@code null}
     * @param field column used in the WHERE clause; must not be blank
     * @return the UPDATE SQL text
     * @throws IllegalArgumentException if {@code clazz} is null, lacks {@code @Table},
     *                                  or {@code field} is blank
     * @throws IllegalStateException    if the class has no persistent columns or no table name
     */
    @Override
    public String getSqlToUpdate(Class<?> clazz, String field) {
        Assert.notNull(clazz, "clazz不能为空");
        Assert.hasText(field, "field不能为空");
        List<String> columns = getSequenceName(clazz);
        Assert.state(!columns.isEmpty(), "实体类没有可持久化的字段");
        List<String> assignments = new ArrayList<>(columns.size());
        for (String column : columns) {
            assignments.add(column + "=?");
        }
        return "update " + getEntryTableName(clazz) + " set "
                + String.join(", ", assignments)
                + " where " + field + "=?";
    }

    /** Collects declared fields that are not {@code @Transient} and carry {@code @Column} or {@code @JoinColumn}. */
    private static List<Field> persistentFields(Class<?> clazz) {
        List<Field> result = new ArrayList<>();
        for (Field field : clazz.getDeclaredFields()) {
            if (field.isAnnotationPresent(Transient.class)) {
                continue;
            }
            if (field.isAnnotationPresent(Column.class) || field.isAnnotationPresent(JoinColumn.class)) {
                result.add(field);
            }
        }
        return result;
    }

    /** Maps a persistent field to its column: the annotation's non-empty {@code name}, else the field name. */
    private static String columnNameOf(Field field) {
        Column column = field.getAnnotation(Column.class);
        if (column != null) {
            return column.name().isEmpty() ? field.getName() : column.name();
        }
        // Only reached for fields selected by persistentFields, so @JoinColumn is present.
        JoinColumn joinColumn = field.getAnnotation(JoinColumn.class);
        return joinColumn.name().isEmpty() ? field.getName() : joinColumn.name();
    }
}
接口实现类

第五步:在 test 层下创建一个测试类LCYTest

import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.test.context.ActiveProfiles;
import org.springframework.test.context.junit4.SpringRunner;

import java.lang.reflect.Field;
import java.text.ParseException;
import java.util.Date;
import java.util.List;

/**
 * @author chaoyou
 * @email 
 * @date 2020-8-12 16:44
 * @Description
 * @Reference
 */
@RunWith(SpringRunner.class)
@SpringBootTest
public class LCYTest {

    @Autowired
    private EntryMapper entryMapper;

    @Test
    public  void test09(){
        String entryTableName = entryMapper.getEntryTableName(CityData.class);
        String pkFieldName = entryMapper.getPKFieldName(CityData.class);
        List<String> fkFieldNameList = entryMapper.getFKFieldName(CityData.class);
        List<String> sequenceNameList = entryMapper.getSequenceName(CityData.class);
        List<Field> fieldList = entryMapper.getFieldList(CityData.class);
        String sqlToSave = entryMapper.getSqlToSave(CityData.class);
        String sqlToUpdate = entryMapper.getSqlToUpdate(CityData.class, "id");
        System.out.println("entryTableName:" + entryTableName);
        System.out.println("pkFieldName:" + pkFieldName);
        System.out.println("fkFieldNameList:" + fkFieldNameList);
        System.out.println("sequenceNameList:" + sequenceNameList);
        System.out.println("sqlToSave:" + sqlToSave);
        System.out.println("sqlToUpdate:" + sqlToUpdate);
    }
}
测试类

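第六步:运行测试方法 test09,查看控制台输出。以 CityData 为例,输出大致如下(列的顺序取决于 Class.getDeclaredFields() 的返回顺序,JVM 规范并不保证它与字段声明顺序一致):

entryTableName:city_data
pkFieldName:id
fkFieldNameList:[]
sequenceNameList:[cityCode, cityName, area, province, prefectureLevel, town, level, create_time, update_time, cityRank]
sqlToSave:insert into city_data(cityCode, cityName, area, province, prefectureLevel, town, level, create_time, update_time, cityRank) values(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
sqlToUpdate:update city_data set cityCode=?, cityName=?, area=?, province=?, prefectureLevel=?, town=?, level=?, create_time=?, update_time=?, cityRank=? where id=?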

 

posted @ 2020-08-21 16:39  朝油  阅读(439)  评论(0)  编辑  收藏  举报