mybatis

点击查看
import lombok.Data;

import java.sql.*;
import java.util.*;
 
public class CodeGenerator {

    // Database connection settings: JDBC URL plus the credentials handed to DriverManager in main.
    // NOTE(review): the username/password values look like placeholders — confirm before running.
    private static final String DB_URL = "jdbc:mysql://localhost:3306/test";
    private static final String DB_USERNAME = "DB_USERNAME";
    private static final String DB_PASSWORD = "DB_PASSWORD";

    // Code generation settings: base package of the generated sources, the value written into the
    // generated @author tag, and the output directories for Java sources and mapper XML files.
    private static final String BASE_PACKAGE = "com.example";
    private static final String AUTHOR = "CodeGenerator";
    private static final String OUTPUT_DIR = "./src/main/java/";
    private static final String XML_OUTPUT_DIR = "./src/main/resources/mapper/";

    /**
     * Entry point: opens a database connection and, for every table in the hard-coded list,
     * generates the entity class, the mapper interface and the mapper XML.
     *
     * <p>Any exception raised along the way is caught here and its stack trace is printed.
     *
     * @param args command-line arguments (unused)
     */
    public static void main(String[] args) {
        try {
            // Load the MySQL JDBC driver class explicitly.
            Class.forName("com.mysql.cj.jdbc.Driver");

            try (Connection conn = DriverManager.getConnection(DB_URL, DB_USERNAME, DB_PASSWORD)) {
                System.out.println("数据库连接成功!");

                // Tables to process. To cover every table of the catalog instead,
                // use getTableNames(conn) here.
                List<String> targets = Arrays.asList("demo");

                for (int idx = 0; idx < targets.size(); idx++) {
                    String table = targets.get(idx);
                    System.out.println("开始生成表 " + table + " 的代码...");

                    // Read the table structure once, then feed it to the three generators.
                    TableInfo meta = getTableInfo(conn, table);
                    generateEntity(meta);
                    generateMapperInterface(meta);
                    generateMapperXml(meta);

                    System.out.println("表 " + table + " 的代码生成完成!");
                }
            }
        } catch (Exception failure) {
            failure.printStackTrace();
        }
    }

    /**
     * Lists the tables of the connection's current catalog, skipping every table whose name
     * contains "schema", "temp" or "backup" (compared in lower case).
     *
     * @param connection open JDBC connection whose metadata is queried
     * @return names of the tables that passed the filter, in the order the driver returned them
     * @throws SQLException if the metadata lookup fails
     */
    private static List<String> getTableNames(Connection connection) throws SQLException {
        final String[] wantedTypes = {"TABLE"};
        List<String> accepted = new ArrayList<>();

        try (ResultSet rs = connection.getMetaData()
                .getTables(connection.getCatalog(), null, null, wantedTypes)) {
            while (rs.next()) {
                String name = rs.getString("TABLE_NAME");
                String lowered = name.toLowerCase();

                // Skip tables that look like system, temporary or backup tables.
                boolean excluded = lowered.contains("schema")
                        || lowered.contains("temp")
                        || lowered.contains("backup");
                if (excluded) {
                    continue;
                }
                accepted.add(name);
            }
        }

        return accepted;
    }

    /**
     * Reads the column and primary-key metadata of one table into a {@code TableInfo}.
     *
     * <p>{@code DatabaseMetaData.getColumns} interprets its table argument as a search pattern, so
     * a name containing {@code _} or {@code %} can also match other tables (for example
     * {@code user_info} also matches {@code userXinfo}). Rows that belong to a different table are
     * skipped so that foreign columns never end up in the generated code.
     *
     * @param connection open JDBC connection whose metadata is queried
     * @param tableName  literal name of the table to describe
     * @return table description carrying class/object names, columns, primary keys and the flags
     *         that tell the entity generator which imports are needed
     * @throws SQLException if a metadata lookup fails
     */
    private static TableInfo getTableInfo(Connection connection, String tableName) throws SQLException {
        TableInfo tableInfo = new TableInfo();
        tableInfo.setTableName(tableName);
        tableInfo.setClassName(convertToCamelCase(tableName, true));
        tableInfo.setObjectName(convertToCamelCase(tableName, false));

        DatabaseMetaData metaData = connection.getMetaData();

        // Column metadata.
        try (ResultSet columns = metaData.getColumns(connection.getCatalog(), null, tableName, null)) {
            List<ColumnInfo> columnList = new ArrayList<>();
            boolean hasBigDecimal = false;
            boolean hasLocalDateTime = false;

            while (columns.next()) {
                // The table argument is a LIKE-style pattern; drop rows of other tables it matched.
                String owningTable = columns.getString("TABLE_NAME");
                if (owningTable != null && !owningTable.equalsIgnoreCase(tableName)) {
                    continue;
                }

                String typeName = columns.getString("TYPE_NAME");
                // REMARKS may be null; normalize so consumers can safely call isEmpty() on it.
                String remarks = columns.getString("REMARKS");

                ColumnInfo column = new ColumnInfo();
                column.setColumnName(columns.getString("COLUMN_NAME"));
                column.setPropertyName(convertToCamelCase(column.getColumnName(), false));
                column.setJdbcType(convertJdbcType(typeName));
                column.setJavaType(convertJavaType(typeName, columns.getInt("COLUMN_SIZE")));
                column.setRemarks(remarks == null ? "" : remarks);
                column.setNullable(columns.getInt("NULLABLE") == DatabaseMetaData.columnNullable);

                if ("BigDecimal".equals(column.getJavaType())) {
                    hasBigDecimal = true;
                }
                // A single flag covers all three java.time types used by the generator.
                if ("LocalDateTime".equals(column.getJavaType())
                        || "LocalDate".equals(column.getJavaType())
                        || "LocalTime".equals(column.getJavaType())) {
                    hasLocalDateTime = true;
                }

                columnList.add(column);
            }

            tableInfo.setColumns(columnList);
            tableInfo.setHasBigDecimal(hasBigDecimal);
            tableInfo.setHasLocalDateTime(hasLocalDateTime);
        }

        // Primary-key metadata (getPrimaryKeys takes a literal table name, not a pattern).
        try (ResultSet primaryKeys = metaData.getPrimaryKeys(connection.getCatalog(), null, tableName)) {
            List<String> primaryKeyList = new ArrayList<>();
            while (primaryKeys.next()) {
                primaryKeyList.add(primaryKeys.getString("COLUMN_NAME"));
            }
            tableInfo.setPrimaryKeys(primaryKeyList);
        }

        return tableInfo;
    }

    /**
     * Generates the Lombok-annotated entity class for a table and writes it to the
     * {@code entity} sub-package under the Java output directory.
     *
     * <p>Each field gets a Javadoc comment taken from the column remark; when the remark is
     * missing ({@code null} or empty) the column name is used instead.
     *
     * @param tableInfo table description providing the class name, table name and columns
     */
    private static void generateEntity(TableInfo tableInfo) {
        StringBuilder sb = new StringBuilder();

        // Package declaration.
        sb.append("package ").append(BASE_PACKAGE).append(".entity;\n\n");

        // Import statements.
        generateEntityImports(sb, tableInfo);

        // Class-level Javadoc.
        sb.append("/**\n");
        sb.append(" * ").append(tableInfo.getTableName()).append(" 实体类\n");
        sb.append(" * \n");
        sb.append(" * @author ").append(AUTHOR).append("\n");
        sb.append(" * @date ").append(new java.util.Date()).append("\n");
        sb.append(" */\n");

        // Lombok annotations.
        sb.append("@Data\n");
        sb.append("@Builder\n");
        sb.append("@NoArgsConstructor\n");
        sb.append("@AllArgsConstructor\n");

        // Class header.
        sb.append("public class ").append(tableInfo.getClassName()).append(" implements Serializable {\n");
        sb.append("    private static final long serialVersionUID = 1L;\n\n");

        // One documented private field per column.
        for (ColumnInfo column : tableInfo.getColumns()) {
            // The remark can be null when the column has no comment; fall back to the column name.
            String remarks = column.getRemarks();
            String description = (remarks == null || remarks.isEmpty()) ? column.getColumnName() : remarks;
            if (description != null) {
                // A "*/" inside the remark would terminate the generated Javadoc early.
                description = description.replace("*/", "* /");
            }

            sb.append("    /**\n");
            sb.append("     * ").append(description).append("\n");
            sb.append("     */\n");
            sb.append("    private ").append(column.getJavaType()).append(" ").append(column.getPropertyName()).append(";\n\n");
        }

        sb.append("}\n");

        // Write the source file.
        writeToFile(OUTPUT_DIR + BASE_PACKAGE.replace('.', '/') + "/entity/",
                tableInfo.getClassName() + ".java", sb.toString());
    }

    /**
     * Appends the import section of a generated entity class: {@code Serializable}, the optional
     * {@code BigDecimal} and {@code java.time} imports, a blank line, then the Lombok imports.
     *
     * @param sb        buffer receiving the generated source
     * @param tableInfo table description whose flags decide which optional imports are emitted
     */
    private static void generateEntityImports(StringBuilder sb, TableInfo tableInfo) {
        List<String> jdkTypes = new ArrayList<>();
        jdkTypes.add("java.io.Serializable");
        if (tableInfo.isHasBigDecimal()) {
            jdkTypes.add("java.math.BigDecimal");
        }
        if (tableInfo.isHasLocalDateTime()) {
            // The flag covers all three java.time types, so all three are imported together.
            jdkTypes.addAll(Arrays.asList(
                    "java.time.LocalDateTime", "java.time.LocalDate", "java.time.LocalTime"));
        }
        for (String type : jdkTypes) {
            sb.append("import ").append(type).append(";\n");
        }
        sb.append("\n");

        String[] lombokTypes = {"Data", "Builder", "NoArgsConstructor", "AllArgsConstructor"};
        for (String annotation : lombokTypes) {
            sb.append("import lombok.").append(annotation).append(";\n");
        }
        sb.append("\n");
    }

    /**
     * Generates the MyBatis mapper interface for a table (CRUD, batch and query methods) and
     * writes it to the {@code mapper} sub-package under the Java output directory.
     *
     * @param tableInfo table description providing the class name, table name and primary keys
     */
    private static void generateMapperInterface(TableInfo tableInfo) {
        final String entity = tableInfo.getClassName();
        // Shared parameter-list tails for the single-record and the batch methods.
        final String recordParam = "(" + entity + " record);\n\n";
        final String batchParam = "(@Param(\"list\") List<" + entity + "> list);\n\n";

        StringBuilder src = new StringBuilder();

        // Package declaration and imports.
        src.append("package ").append(BASE_PACKAGE).append(".mapper;\n\n")
                .append("import ").append(BASE_PACKAGE).append(".entity.").append(entity).append(";\n")
                .append("import org.apache.ibatis.annotations.Mapper;\n")
                .append("import org.apache.ibatis.annotations.Param;\n")
                .append("import java.util.List;\n\n");

        // Interface-level Javadoc.
        src.append("/**\n")
                .append(" * ").append(tableInfo.getTableName()).append(" Mapper接口\n")
                .append(" * \n")
                .append(" * @author ").append(AUTHOR).append("\n")
                .append(" * @date ").append(new java.util.Date()).append("\n")
                .append(" */\n");

        // Interface header.
        src.append("@Mapper\n")
                .append("public interface ").append(entity).append("Mapper {\n\n");

        // Basic CRUD methods.
        src.append("    int deleteByPrimaryKey(").append(getPrimaryKeyParams(tableInfo)).append(");\n\n");
        for (String method : new String[]{"insert", "insertSelective"}) {
            src.append("    int ").append(method).append(recordParam);
        }
        src.append("    ").append(entity).append(" selectByPrimaryKey(")
                .append(getPrimaryKeyParams(tableInfo)).append(");\n\n");
        for (String method : new String[]{"updateByPrimaryKey", "updateByPrimaryKeySelective"}) {
            src.append("    int ").append(method).append(recordParam);
        }

        // Batch insert and batch update methods.
        String[] batchMethods = {
            "batchInsert", "batchInsertSelective",
            "batchUpdate", "batchUpdateSelective", "batchUpdateCaseWhen"
        };
        for (String method : batchMethods) {
            src.append("    int ").append(method).append(batchParam);
        }

        // Query methods.
        src.append("    List<").append(entity).append("> selectAll();\n\n");
        src.append("    List<").append(entity).append("> selectByCondition(")
                .append(entity).append(" condition);\n\n");
        src.append("    long countByCondition(").append(entity).append(" condition);\n\n");

        src.append("}\n");

        // Write the source file.
        String targetDir = OUTPUT_DIR + BASE_PACKAGE.replace('.', '/') + "/mapper/";
        writeToFile(targetDir, entity + "Mapper.java", src.toString());
    }

    /**
     * Generates the complete MyBatis mapper XML for a table and writes it to the XML output
     * directory. The document holds the result map, the column-list fragments and one statement
     * per method of the generated mapper interface.
     *
     * @param tableInfo table description shared by all statement generators
     */
    private static void generateMapperXml(TableInfo tableInfo) {
        final String mapperName = tableInfo.getClassName() + "Mapper";
        StringBuilder xml = new StringBuilder();

        // XML prolog, DOCTYPE and the <mapper> root bound to the generated interface.
        xml.append("<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n")
                .append("<!DOCTYPE mapper PUBLIC \"-//mybatis.org//DTD Mapper 3.0//EN\" \n")
                .append("    \"http://mybatis.org/dtd/mybatis-3-mapper.dtd\">\n\n")
                .append("<mapper namespace=\"").append(BASE_PACKAGE).append(".mapper.")
                .append(mapperName).append("\">\n\n");

        // Result map and reusable column-list fragments.
        generateResultMap(xml, tableInfo);
        generateBaseColumnList(xml, tableInfo);
        generateBaseColumnListSql(xml, tableInfo);

        // Single-record CRUD statements.
        generateInsert(xml, tableInfo);
        generateInsertSelective(xml, tableInfo);
        generateSelectByPrimaryKey(xml, tableInfo);
        generateUpdateByPrimaryKey(xml, tableInfo);
        generateUpdateByPrimaryKeySelective(xml, tableInfo);
        generateDeleteByPrimaryKey(xml, tableInfo);

        // Batch statements.
        generateBatchInsert(xml, tableInfo);
        generateBatchInsertSelective(xml, tableInfo);
        generateBatchUpdate(xml, tableInfo);
        generateBatchUpdateSelective(xml, tableInfo);
        generateBatchUpdateCaseWhen(xml, tableInfo);

        // Query statements.
        generateSelectAll(xml, tableInfo);
        generateSelectByCondition(xml, tableInfo);
        generateCountByCondition(xml, tableInfo);

        xml.append("</mapper>");

        // Write the XML file.
        writeToFile(XML_OUTPUT_DIR, mapperName + ".xml", xml.toString());
    }

    /**
     * Appends the {@code BaseResultMap} element: primary-key columns become {@code <id>} entries,
     * every other column becomes a {@code <result>} entry.
     *
     * @param sb        buffer receiving the generated XML
     * @param tableInfo table description providing the class name, columns and primary keys
     */
    private static void generateResultMap(StringBuilder sb, TableInfo tableInfo) {
        String entityType = BASE_PACKAGE + ".entity." + tableInfo.getClassName();
        sb.append("    <resultMap id=\"BaseResultMap\" type=\"").append(entityType).append("\">\n");

        for (ColumnInfo col : tableInfo.getColumns()) {
            // Only the element name differs between key and non-key columns.
            String tag = tableInfo.getPrimaryKeys().contains(col.getColumnName()) ? "id" : "result";
            sb.append("        <").append(tag)
                    .append(" column=\"").append(col.getColumnName())
                    .append("\" property=\"").append(col.getPropertyName())
                    .append("\" jdbcType=\"").append(col.getJdbcType())
                    .append("\" />\n");
        }

        sb.append("    </resultMap>\n\n");
    }

    /**
     * Appends the {@code Base_Column_List} SQL fragment: all column names separated by
     * {@code ", "}, with a line break after every fifth column.
     *
     * @param sb        buffer receiving the generated XML
     * @param tableInfo table description providing the columns
     */
    private static void generateBaseColumnList(StringBuilder sb, TableInfo tableInfo) {
        sb.append("    <sql id=\"Base_Column_List\">\n").append("        ");

        List<ColumnInfo> cols = tableInfo.getColumns();
        final int total = cols.size();
        int position = 0;
        for (ColumnInfo col : cols) {
            position++;
            sb.append(col.getColumnName());
            if (position == total) {
                // No separator after the last column.
                break;
            }
            sb.append(", ");
            if (position % 5 == 0) {
                // Wrap after every fifth column to keep the fragment readable.
                sb.append("\n        ");
            }
        }

        sb.append("\n    </sql>\n\n");
    }

    /**
     * Appends the {@code Base_Column_List_Sql} SQL fragment: one column name per line,
     * separated by commas.
     *
     * @param sb        buffer receiving the generated XML
     * @param tableInfo table description providing the columns
     */
    private static void generateBaseColumnListSql(StringBuilder sb, TableInfo tableInfo) {
        StringJoiner lines = new StringJoiner(",\n");
        for (ColumnInfo col : tableInfo.getColumns()) {
            lines.add("        " + col.getColumnName());
        }

        sb.append("    <sql id=\"Base_Column_List_Sql\">\n")
                .append(lines)
                .append("\n    </sql>\n\n");
    }

    /**
     * Appends the {@code insert} statement that writes every column of the table.
     *
     * @param sb        buffer receiving the generated XML
     * @param tableInfo table description providing the table name, class name and columns
     */
    private static void generateInsert(StringBuilder sb, TableInfo tableInfo) {
        // Collect the column names and their matching placeholders in one pass.
        StringJoiner columnNames = new StringJoiner(", ");
        StringJoiner placeholders = new StringJoiner(", ");
        for (ColumnInfo col : tableInfo.getColumns()) {
            columnNames.add(col.getColumnName());
            placeholders.add("#{" + col.getPropertyName() + ",jdbcType=" + col.getJdbcType() + "}");
        }

        sb.append("    <insert id=\"insert\" parameterType=\"").append(BASE_PACKAGE)
                .append(".entity.").append(tableInfo.getClassName()).append("\">\n")
                .append("        insert into ").append(tableInfo.getTableName()).append(" (")
                .append(columnNames)
                .append(")\n        values (")
                .append(placeholders)
                .append(")\n    </insert>\n\n");
    }

    /**
     * Appends the {@code insertSelective} statement, which inserts only the properties that are
     * set. String properties must be non-null and non-empty; all other properties must be non-null.
     *
     * @param sb        buffer receiving the generated XML
     * @param tableInfo table description providing the table name, class name and columns
     */
    private static void generateInsertSelective(StringBuilder sb, TableInfo tableInfo) {
        // Build the column part and the value part side by side; both share the same <if> guard.
        StringBuilder columnItems = new StringBuilder();
        StringBuilder valueItems = new StringBuilder();
        for (ColumnInfo col : tableInfo.getColumns()) {
            String prop = col.getPropertyName();
            String guard = "String".equals(col.getJavaType())
                    ? prop + " != null and " + prop + " != ''"
                    : prop + " != null";
            String openIf = "            <if test=\"" + guard + "\">\n";
            String closeIf = "            </if>\n";

            columnItems.append(openIf)
                    .append("                ").append(col.getColumnName()).append(",\n")
                    .append(closeIf);
            valueItems.append(openIf)
                    .append("                #{").append(prop).append(",jdbcType=")
                    .append(col.getJdbcType()).append("},\n")
                    .append(closeIf);
        }

        sb.append("    <insert id=\"insertSelective\" parameterType=\"").append(BASE_PACKAGE)
                .append(".entity.").append(tableInfo.getClassName()).append("\">\n");
        sb.append("        insert into ").append(tableInfo.getTableName()).append("\n");
        sb.append("        <trim prefix=\"(\" suffix=\")\" suffixOverrides=\",\">\n")
                .append(columnItems)
                .append("        </trim>\n");
        sb.append("        <trim prefix=\"values (\" suffix=\")\" suffixOverrides=\",\">\n")
                .append(valueItems)
                .append("        </trim>\n");
        sb.append("    </insert>\n\n");
    }

    /**
     * Appends the {@code selectByPrimaryKey} statement, which reads one row by its primary key.
     *
     * @param sb        buffer receiving the generated XML
     * @param tableInfo table description providing the table name and primary keys
     */
    private static void generateSelectByPrimaryKey(StringBuilder sb, TableInfo tableInfo) {
        sb.append("    <select id=\"selectByPrimaryKey\" resultMap=\"BaseResultMap\">\n")
                .append("        select \n")
                .append("        <include refid=\"Base_Column_List\" />\n")
                .append("        from ").append(tableInfo.getTableName()).append("\n")
                .append("        where ").append(generatePrimaryKeyWhereClause(tableInfo)).append("\n")
                .append("    </select>\n\n");
    }

    /**
     * Appends the {@code updateByPrimaryKey} statement, which overwrites every non-key column of
     * the row identified by the primary key.
     *
     * @param sb        buffer receiving the generated XML
     * @param tableInfo table description providing the table name, class name, columns and keys
     */
    private static void generateUpdateByPrimaryKey(StringBuilder sb, TableInfo tableInfo) {
        sb.append("    <update id=\"updateByPrimaryKey\" parameterType=\"").append(BASE_PACKAGE)
                .append(".entity.").append(tableInfo.getClassName()).append("\">\n");
        sb.append("        update ").append(tableInfo.getTableName()).append("\n");
        sb.append("        set \n");

        // One "column = #{property}" assignment per non-key column, comma-separated.
        StringJoiner assignments = new StringJoiner(",\n");
        for (ColumnInfo col : tableInfo.getColumns()) {
            if (tableInfo.getPrimaryKeys().contains(col.getColumnName())) {
                continue;
            }
            assignments.add("            " + col.getColumnName() + " = #{" + col.getPropertyName()
                    + ",jdbcType=" + col.getJdbcType() + "}");
        }
        sb.append(assignments).append("\n");

        sb.append("        where ").append(generatePrimaryKeyWhereClause(tableInfo)).append("\n");
        sb.append("    </update>\n\n");
    }

    /**
     * Appends the {@code updateByPrimaryKeySelective} statement, which updates only the non-key
     * properties that are set. String properties must be non-null and non-empty; all other
     * properties must be non-null.
     *
     * @param sb        buffer receiving the generated XML
     * @param tableInfo table description providing the table name, class name, columns and keys
     */
    private static void generateUpdateByPrimaryKeySelective(StringBuilder sb, TableInfo tableInfo) {
        sb.append("    <update id=\"updateByPrimaryKeySelective\" parameterType=\"").append(BASE_PACKAGE)
                .append(".entity.").append(tableInfo.getClassName()).append("\">\n")
                .append("        update ").append(tableInfo.getTableName()).append("\n")
                .append("        <set>\n");

        for (ColumnInfo col : tableInfo.getColumns()) {
            if (tableInfo.getPrimaryKeys().contains(col.getColumnName())) {
                // Key columns identify the row and are never assigned.
                continue;
            }
            String prop = col.getPropertyName();
            String guard = "String".equals(col.getJavaType())
                    ? prop + " != null and " + prop + " != ''"
                    : prop + " != null";

            sb.append("            <if test=\"").append(guard).append("\">\n");
            sb.append("                ").append(col.getColumnName()).append(" = #{").append(prop)
                    .append(",jdbcType=").append(col.getJdbcType()).append("},\n");
            sb.append("            </if>\n");
        }

        sb.append("        </set>\n")
                .append("        where ").append(generatePrimaryKeyWhereClause(tableInfo)).append("\n")
                .append("    </update>\n\n");
    }

    /**
     * Appends the {@code deleteByPrimaryKey} statement, which removes one row by its primary key.
     *
     * @param sb        buffer receiving the generated XML
     * @param tableInfo table description providing the table name and primary keys
     */
    private static void generateDeleteByPrimaryKey(StringBuilder sb, TableInfo tableInfo) {
        sb.append("    <delete id=\"deleteByPrimaryKey\">\n")
                .append("        delete from ").append(tableInfo.getTableName()).append("\n")
                .append("        where ").append(generatePrimaryKeyWhereClause(tableInfo)).append("\n")
                .append("    </delete>\n\n");
    }

    /**
     * Appends the {@code batchInsert} statement: a single multi-row insert that writes every
     * column for each element of the {@code list} parameter.
     *
     * @param sb        buffer receiving the generated XML
     * @param tableInfo table description providing the table name and columns
     */
    private static void generateBatchInsert(StringBuilder sb, TableInfo tableInfo) {
        // One placeholder per column, each on its own line inside the row tuple.
        StringJoiner rowValues = new StringJoiner(", \n");
        for (ColumnInfo col : tableInfo.getColumns()) {
            rowValues.add("               #{item." + col.getPropertyName()
                    + ",jdbcType=" + col.getJdbcType() + "}");
        }

        sb.append("    <insert id=\"batchInsert\" parameterType=\"java.util.List\">\n")
                .append("        insert into ").append(tableInfo.getTableName()).append(" (\n")
                .append("            <include refid=\"Base_Column_List\" />\n")
                .append("        ) values\n")
                .append("        <foreach collection=\"list\" item=\"item\" separator=\",\">\n")
                .append("            (\n")
                .append(rowValues)
                .append("\n            )\n")
                .append("        </foreach>\n")
                .append("    </insert>\n\n");
    }

    /**
     * Appends the {@code batchInsertSelective} statement: a multi-row insert whose column list is
     * derived from the properties that are set on the first list element.
     *
     * <p>The same first-element guard is used for the column list and for every row's values.
     * Guarding the values by each row's own properties would let a row produce more or fewer
     * values than there are columns whenever its null pattern differs from the first element,
     * which yields an invalid statement. With a shared guard, a row that lacks a value for a
     * selected column simply binds {@code NULL} (the {@code jdbcType} is always specified).
     *
     * <p>String properties of the first element must be non-null and non-empty to be selected;
     * all other properties must be non-null.
     *
     * @param sb        buffer receiving the generated XML
     * @param tableInfo table description providing the table name and columns
     */
    private static void generateBatchInsertSelective(StringBuilder sb, TableInfo tableInfo) {
        sb.append("    <insert id=\"batchInsertSelective\" parameterType=\"java.util.List\">\n");
        sb.append("        insert into ").append(tableInfo.getTableName()).append("\n");
        sb.append("        <trim prefix=\"(\" suffix=\")\" suffixOverrides=\",\">\n");

        // Column list: the set properties of the first element act as the column template.
        for (ColumnInfo column : tableInfo.getColumns()) {
            sb.append("            <if test=\"").append(firstElementGuard(column)).append("\">\n");
            sb.append("                ").append(column.getColumnName()).append(",\n");
            sb.append("            </if>\n");
        }

        sb.append("        </trim>\n");
        sb.append("        values\n");
        sb.append("        <foreach collection=\"list\" item=\"item\" separator=\",\">\n");
        sb.append("            <trim prefix=\"(\" suffix=\")\" suffixOverrides=\",\">\n");

        // Row values: same guard as the column list so every row matches the column count.
        for (ColumnInfo column : tableInfo.getColumns()) {
            sb.append("                <if test=\"").append(firstElementGuard(column)).append("\">\n");
            sb.append("                    #{item.").append(column.getPropertyName()).append(",jdbcType=")
                    .append(column.getJdbcType()).append("},\n");
            sb.append("                </if>\n");
        }

        sb.append("            </trim>\n");
        sb.append("        </foreach>\n");
        sb.append("    </insert>\n\n");
    }

    /**
     * Builds the OGNL test that decides whether a column takes part in a selective batch insert,
     * based on the first element of the {@code list} parameter.
     */
    private static String firstElementGuard(ColumnInfo column) {
        String property = "list[0]." + column.getPropertyName();
        if ("String".equals(column.getJavaType())) {
            // Strings additionally have to be non-empty.
            return property + " != null and " + property + " != ''";
        }
        return property + " != null";
    }

    /**
     * Appends the {@code batchUpdate} statement: one full update per list element, with the
     * individual updates separated by {@code ;}.
     *
     * <p>NOTE(review): sending several statements in one call presumably needs multi-statement
     * support enabled on the JDBC connection — confirm for the target driver.
     *
     * @param sb        buffer receiving the generated XML
     * @param tableInfo table description providing the table name, columns and primary keys
     */
    private static void generateBatchUpdate(StringBuilder sb, TableInfo tableInfo) {
        sb.append("    <update id=\"batchUpdate\" parameterType=\"java.util.List\">\n")
                .append("        <foreach collection=\"list\" item=\"item\" separator=\";\">\n")
                .append("            update ").append(tableInfo.getTableName()).append("\n")
                .append("            set \n");

        // One "column = #{item.property}" assignment per non-key column, comma-separated.
        StringJoiner assignments = new StringJoiner(",\n");
        for (ColumnInfo col : tableInfo.getColumns()) {
            if (tableInfo.getPrimaryKeys().contains(col.getColumnName())) {
                continue;
            }
            assignments.add("                " + col.getColumnName() + " = #{item." + col.getPropertyName()
                    + ",jdbcType=" + col.getJdbcType() + "}");
        }
        sb.append(assignments).append("\n");

        sb.append("            where ").append(generateBatchUpdatePrimaryKeyWhereClause(tableInfo)).append("\n")
                .append("        </foreach>\n")
                .append("    </update>\n\n");
    }

    /**
     * Appends the {@code batchUpdateSelective} statement: one update per list element that only
     * assigns the non-key properties set on that element. String properties must be non-null and
     * non-empty; all other properties must be non-null.
     *
     * @param sb        buffer receiving the generated XML
     * @param tableInfo table description providing the table name, columns and primary keys
     */
    private static void generateBatchUpdateSelective(StringBuilder sb, TableInfo tableInfo) {
        sb.append("    <update id=\"batchUpdateSelective\" parameterType=\"java.util.List\">\n")
                .append("        <foreach collection=\"list\" item=\"item\" separator=\";\">\n")
                .append("            update ").append(tableInfo.getTableName()).append("\n")
                .append("            <set>\n");

        for (ColumnInfo col : tableInfo.getColumns()) {
            if (tableInfo.getPrimaryKeys().contains(col.getColumnName())) {
                // Key columns identify the row and are never assigned.
                continue;
            }
            String itemProp = "item." + col.getPropertyName();
            String guard = "String".equals(col.getJavaType())
                    ? itemProp + " != null and " + itemProp + " != ''"
                    : itemProp + " != null";

            sb.append("                <if test=\"").append(guard).append("\">\n");
            sb.append("                    ").append(col.getColumnName()).append(" = #{").append(itemProp)
                    .append(",jdbcType=").append(col.getJdbcType()).append("},\n");
            sb.append("                </if>\n");
        }

        sb.append("            </set>\n")
                .append("            where ").append(generateBatchUpdatePrimaryKeyWhereClause(tableInfo)).append("\n")
                .append("        </foreach>\n")
                .append("    </update>\n\n");
    }

    /**
     * Builds the WHERE condition used inside the batch-update {@code foreach}: one
     * {@code column = #{item.property,jdbcType=...}} term per primary-key column, joined with
     * {@code and}.
     *
     * @param tableInfo table description providing the primary keys
     * @return the condition text, or {@code "1=0"} when the table has no primary key so that the
     *         generated update matches no row
     */
    private static String generateBatchUpdatePrimaryKeyWhereClause(TableInfo tableInfo) {
        List<String> keyColumns = tableInfo.getPrimaryKeys();
        if (keyColumns.isEmpty()) {
            // Without a primary key there is no safe row selector; emit an always-false condition.
            return "1=0";
        }

        StringJoiner conditions = new StringJoiner(" and ");
        for (String keyColumn : keyColumns) {
            String keyProperty = convertToCamelCase(keyColumn, false);
            String keyJdbcType = getJdbcTypeForColumn(tableInfo, keyColumn);
            conditions.add(keyColumn + " = #{item." + keyProperty + ",jdbcType=" + keyJdbcType + "}");
        }
        return conditions.toString();
    }

    /**
     * Appends the {@code selectAll} statement, which reads every row of the table.
     *
     * @param sb        buffer receiving the generated XML
     * @param tableInfo table description providing the table name
     */
    private static void generateSelectAll(StringBuilder sb, TableInfo tableInfo) {
        sb.append("    <select id=\"selectAll\" resultMap=\"BaseResultMap\">\n")
                .append("        select \n")
                .append("        <include refid=\"Base_Column_List\" />\n")
                .append("        from ").append(tableInfo.getTableName()).append("\n")
                .append("    </select>\n\n");
    }

    /**
     * Appends the {@code selectByCondition} statement: an equality filter on every property that
     * is set on the condition object. String properties must be non-null and non-empty; all other
     * properties must be non-null.
     *
     * @param sb        buffer receiving the generated XML
     * @param tableInfo table description providing the table name, class name and columns
     */
    private static void generateSelectByCondition(StringBuilder sb, TableInfo tableInfo) {
        sb.append("    <select id=\"selectByCondition\" parameterType=\"").append(BASE_PACKAGE)
                .append(".entity.").append(tableInfo.getClassName()).append("\" resultMap=\"BaseResultMap\">\n")
                .append("        select \n")
                .append("        <include refid=\"Base_Column_List\" />\n")
                .append("        from ").append(tableInfo.getTableName()).append("\n")
                .append("        <where>\n");

        for (ColumnInfo col : tableInfo.getColumns()) {
            String prop = col.getPropertyName();
            String guard = "String".equals(col.getJavaType())
                    ? prop + " != null and " + prop + " != ''"
                    : prop + " != null";

            sb.append("            <if test=\"").append(guard).append("\">\n");
            sb.append("                and ").append(col.getColumnName()).append(" = #{").append(prop)
                    .append(",jdbcType=").append(col.getJdbcType()).append("}\n");
            sb.append("            </if>\n");
        }

        sb.append("        </where>\n")
                .append("    </select>\n\n");
    }

    /**
     * Appends the {@code countByCondition} statement: a row count with an equality filter on every
     * property that is set on the condition object. String properties must be non-null and
     * non-empty; all other properties must be non-null.
     *
     * @param sb        buffer receiving the generated XML
     * @param tableInfo table description providing the table name, class name and columns
     */
    private static void generateCountByCondition(StringBuilder sb, TableInfo tableInfo) {
        sb.append("    <select id=\"countByCondition\" parameterType=\"").append(BASE_PACKAGE)
                .append(".entity.").append(tableInfo.getClassName()).append("\" resultType=\"java.lang.Long\">\n")
                .append("        select count(*) from ").append(tableInfo.getTableName()).append("\n")
                .append("        <where>\n");

        for (ColumnInfo col : tableInfo.getColumns()) {
            String prop = col.getPropertyName();
            String guard = "String".equals(col.getJavaType())
                    ? prop + " != null and " + prop + " != ''"
                    : prop + " != null";

            sb.append("            <if test=\"").append(guard).append("\">\n");
            sb.append("                and ").append(col.getColumnName()).append(" = #{").append(prop)
                    .append(",jdbcType=").append(col.getJdbcType()).append("}\n");
            sb.append("            </if>\n");
        }

        sb.append("        </where>\n")
                .append("    </select>\n\n");
    }

    /**
     * Appends the {@code batchUpdateCaseWhen} statement: a single UPDATE in which every non-key
     * column is assigned through a {@code CASE ... WHEN ... END} expression built from the list.
     *
     * <p>For each list element the column receives the element's value when the property is set
     * and keeps its current value otherwise. String properties count as set when they are
     * non-null and non-empty; all other properties when they are non-null.
     *
     * @param sb        buffer receiving the generated XML
     * @param tableInfo table description providing the table name, columns and primary keys
     */
    private static void generateBatchUpdateCaseWhen(StringBuilder sb, TableInfo tableInfo) {
        sb.append("    <update id=\"batchUpdateCaseWhen\" parameterType=\"java.util.List\">\n")
                .append("        UPDATE ").append(tableInfo.getTableName()).append("\n")
                .append("        <trim prefix=\"SET\" suffixOverrides=\",\">\n");

        // One CASE expression per non-key column.
        for (ColumnInfo col : tableInfo.getColumns()) {
            String columnName = col.getColumnName();
            if (tableInfo.getPrimaryKeys().contains(columnName)) {
                continue;
            }

            // The "set" / "not set" tests are complementary; strings also treat '' as not set.
            String itemProp = "item." + col.getPropertyName();
            boolean textual = "String".equals(col.getJavaType());
            String presentTest = textual
                    ? itemProp + " != null and " + itemProp + " != ''"
                    : itemProp + " != null";
            String absentTest = textual
                    ? itemProp + " == null or " + itemProp + " == ''"
                    : itemProp + " == null";

            sb.append("            <trim prefix=\"").append(columnName).append("=CASE\" suffix=\"END,\">\n");
            sb.append("                <foreach collection=\"list\" item=\"item\" index=\"index\">\n");

            // Property set: assign the new value for the matching row.
            sb.append("                    <if test=\"").append(presentTest).append("\">\n");
            sb.append("                        WHEN ").append(generatePrimaryKeyConditionForCase(tableInfo, "item"))
                    .append(" THEN #{").append(itemProp).append("}\n");
            sb.append("                    </if>\n");

            // Property not set: keep the column's current value for the matching row.
            sb.append("                    <if test=\"").append(absentTest).append("\">\n");
            sb.append("                        WHEN ").append(generatePrimaryKeyConditionForCase(tableInfo, "item"))
                    .append(" THEN ").append(columnName).append("\n");
            sb.append("                    </if>\n");

            sb.append("                </foreach>\n");
            sb.append("            </trim>\n");
        }

        sb.append("        </trim>\n")
                .append("        WHERE ").append(generateBatchInCondition(tableInfo)).append("\n")
                .append("    </update>\n\n");
    }

    /**
     * Builds the primary-key match used inside a {@code CASE WHEN} branch.
     *
     * <p>Each primary-key column contributes {@code column=#{itemName.property}}, where the
     * property name is the lower-camel-case form of the column name; several key columns are
     * joined with {@code AND}. When the table has no primary key the always-false condition
     * {@code 1=0} is returned (the same fallback {@code generatePrimaryKeyWhereClause} uses), so
     * the generated {@code WHEN ... THEN} stays syntactically valid instead of becoming
     * {@code WHEN  THEN}.
     *
     * @param tableInfo table metadata providing the primary-key column names
     * @param itemName  name of the {@code <foreach>} item variable the placeholders refer to
     * @return the condition text, or {@code 1=0} when there are no primary keys
     */
    private static String generatePrimaryKeyConditionForCase(TableInfo tableInfo, String itemName) {
        List<String> pkColumns = tableInfo.getPrimaryKeys();
        if (pkColumns.isEmpty()) {
            // Nothing can identify a row; never match rather than emit an empty condition.
            return "1=0";
        }

        StringBuilder condition = new StringBuilder();
        for (String pkColumn : pkColumns) {
            if (condition.length() > 0) {
                condition.append(" AND ");
            }
            condition.append(pkColumn)
                    .append("=#{").append(itemName).append(".")
                    .append(convertToCamelCase(pkColumn, false))
                    .append("}");
        }
        return condition.toString();
    }

    /**
     * Builds the {@code IN} condition that restricts a batch update to the rows in the list.
     *
     * <ul>
     *   <li>Single primary key: {@code pk IN (<foreach> #{item.pk} </foreach>)}, the same text
     *       as before.</li>
     *   <li>Composite primary key: a row-constructor comparison
     *       {@code (pk1, pk2) IN (<foreach> (#{item.pk1}, #{item.pk2}) </foreach>)}. Filtering on
     *       the first key column only would also match rows that are not in the batch, and a
     *       {@code CASE} without {@code ELSE} would then set their columns to {@code NULL}.</li>
     *   <li>No primary key: the always-false condition {@code 1=0} (the same fallback
     *       {@code generatePrimaryKeyWhereClause} uses) instead of an
     *       {@code IndexOutOfBoundsException}.</li>
     * </ul>
     *
     * @param tableInfo table metadata providing the primary-key column names
     * @return the condition text to place after {@code WHERE}
     */
    private static String generateBatchInCondition(TableInfo tableInfo) {
        List<String> pkColumns = tableInfo.getPrimaryKeys();
        if (pkColumns.isEmpty()) {
            return "1=0";
        }

        // Build "col1, col2" and "#{item.col1}, #{item.col2}" side by side.
        StringBuilder columnList = new StringBuilder();
        StringBuilder valueList = new StringBuilder();
        for (String pkColumn : pkColumns) {
            if (columnList.length() > 0) {
                columnList.append(", ");
                valueList.append(", ");
            }
            columnList.append(pkColumn);
            valueList.append("#{item.").append(convertToCamelCase(pkColumn, false)).append("}");
        }

        // Row-constructor syntax needs parentheses; a single key keeps the plain form.
        boolean composite = pkColumns.size() > 1;
        String left = composite ? "(" + columnList + ")" : columnList.toString();
        String right = composite ? "(" + valueList + ")" : valueList.toString();

        StringBuilder condition = new StringBuilder();
        condition.append(left).append(" IN (\n");
        condition.append("            <foreach collection=\"list\" item=\"item\" separator=\",\">\n");
        condition.append("                ").append(right).append("\n");
        condition.append("            </foreach>\n");
        condition.append("        )");
        return condition.toString();
    }

    /**
     * Builds the body of a {@code WHERE} clause that matches a row by its primary key.
     *
     * <p>Each key column contributes {@code column = #{property,jdbcType=TYPE}}; for composite
     * keys the entries are placed on separate lines and joined with {@code and}. A table without
     * a primary key yields the always-false condition {@code 1=0}.
     *
     * @param tableInfo table metadata providing primary keys and column types
     * @return the condition text (without the {@code WHERE} keyword)
     */
    private static String generatePrimaryKeyWhereClause(TableInfo tableInfo) {
        List<String> keyColumns = tableInfo.getPrimaryKeys();
        if (keyColumns.isEmpty()) {
            // No key to match on: never select any row.
            return "1=0";
        }

        StringBuilder clause = new StringBuilder();
        String separator = "";
        for (String keyColumn : keyColumns) {
            clause.append(separator)
                    .append(keyColumn)
                    .append(" = #{").append(convertToCamelCase(keyColumn, false))
                    .append(",jdbcType=").append(getJdbcTypeForColumn(tableInfo, keyColumn))
                    .append("}");
            // Line break plus indent after an entry, then the indented "and" before the next one.
            separator = "\n          " + "          and ";
        }
        return clause.toString();
    }

    /**
     * Builds the Java parameter list for mapper methods that take the primary key.
     *
     * <ul>
     *   <li>No primary key: {@code Long id}.</li>
     *   <li>Single key: {@code Type property}.</li>
     *   <li>Composite key: one {@code @Param("property") Type property} per key column,
     *       separated by commas.</li>
     * </ul>
     *
     * @param tableInfo table metadata providing primary keys and column types
     * @return the parameter list text
     */
    private static String getPrimaryKeyParams(TableInfo tableInfo) {
        List<String> keyColumns = tableInfo.getPrimaryKeys();
        if (keyColumns.isEmpty()) {
            return "Long id";
        }

        // Only composite keys need @Param so the parameters can be told apart by name.
        boolean composite = keyColumns.size() != 1;
        StringBuilder signature = new StringBuilder();
        for (String keyColumn : keyColumns) {
            String property = convertToCamelCase(keyColumn, false);
            String type = getJavaTypeForColumn(tableInfo, keyColumn);

            if (signature.length() > 0) {
                signature.append(", ");
            }
            if (composite) {
                signature.append("@Param(\"").append(property).append("\") ");
            }
            signature.append(type).append(" ").append(property);
        }
        return signature.toString();
    }

    /**
     * Converts an underscore-separated name to camel case.
     *
     * <p>Empty segments (leading, trailing or doubled underscores) are skipped. The first
     * non-empty segment is lower-cased unless {@code firstCharUpper} is set; every other segment
     * is capitalised. Deciding "first" by the first emitted segment rather than by array index
     * means {@code _user_name} yields {@code userName}, not {@code UserName}. Case conversion
     * uses {@link Locale#ROOT} so the result does not depend on the JVM's default locale
     * (for example the Turkish dotless i).
     *
     * @param name           underscore-separated name; may be {@code null} or empty
     * @param firstCharUpper {@code true} for UpperCamelCase, {@code false} for lowerCamelCase
     * @return the camel-case name, or {@code name} itself when it is {@code null} or empty
     */
    private static String convertToCamelCase(String name, boolean firstCharUpper) {
        if (name == null || name.isEmpty()) {
            return name;
        }

        StringBuilder result = new StringBuilder();
        for (String part : name.split("_")) {
            if (part.isEmpty()) {
                continue;
            }

            // "First" means the first segment actually written, not index 0 of the split.
            boolean firstSegment = result.length() == 0;
            if (firstSegment && !firstCharUpper) {
                result.append(part.toLowerCase(Locale.ROOT));
            } else {
                result.append(Character.toUpperCase(part.charAt(0)))
                        .append(part.substring(1).toLowerCase(Locale.ROOT));
            }
        }

        return result.toString();
    }

    /**
     * Maps a database column type name to the JDBC type name written into
     * {@code jdbcType=...} attributes.
     *
     * <p>The name is trimmed and upper-cased with {@link Locale#ROOT}, and only its leading
     * keyword is matched, so modifiers and length specifications such as {@code INT UNSIGNED}
     * or {@code DECIMAL(10,2)} resolve to their base type instead of falling through to
     * {@code VARCHAR}. {@code BINARY} and {@code VARBINARY} map to the JDBC types of the same
     * name, in line with {@code convertJavaType}, which maps them to {@code byte[]}.
     *
     * @param dbType database type name; NOTE(review): presumably the {@code TYPE_NAME} value
     *               from {@code DatabaseMetaData.getColumns} — confirm against the caller
     * @return the JDBC type name; {@code VARCHAR} for {@code null} or unrecognised input
     */
    private static String convertJdbcType(String dbType) {
        if (dbType == null) {
            return "VARCHAR";
        }

        // Keep only the leading keyword: "INT UNSIGNED" -> "INT", "DECIMAL(10,2)" -> "DECIMAL".
        String normalized = dbType.trim().toUpperCase(Locale.ROOT);
        String baseType = normalized.split("[\\s(]", 2)[0];

        switch (baseType) {
            case "INT":
            case "INTEGER":
            case "MEDIUMINT":
                return "INTEGER";
            case "BIGINT":
                return "BIGINT";
            case "VARCHAR":
            case "CHAR":
            case "TEXT":
            case "TINYTEXT":
            case "MEDIUMTEXT":
            case "LONGTEXT":
                return "VARCHAR";
            case "DATETIME":
            case "TIMESTAMP":
                return "TIMESTAMP";
            case "DATE":
                return "DATE";
            case "TIME":
                return "TIME";
            case "DECIMAL":
            case "NUMERIC":
                return "DECIMAL";
            case "BIT":
                return "BIT";
            case "TINYINT":
                return "TINYINT";
            case "SMALLINT":
                return "SMALLINT";
            case "FLOAT":
                return "FLOAT";
            case "DOUBLE":
                return "DOUBLE";
            case "BINARY":
                return "BINARY";
            case "VARBINARY":
                return "VARBINARY";
            case "BLOB":
            case "LONGBLOB":
            case "MEDIUMBLOB":
            case "TINYBLOB":
                return "BLOB";
            default:
                return "VARCHAR";
        }
    }

    /**
     * Maps a database column type name to the Java type name used in generated code
     * (date/time columns use the {@code java.time} types rather than {@code Date}).
     *
     * <p>The name is trimmed and upper-cased with {@link Locale#ROOT}, and only its leading
     * keyword is matched, so {@code DECIMAL(10,2)} or {@code BIGINT UNSIGNED} resolve to their
     * base type instead of falling through to {@code String}. {@code INT UNSIGNED} maps to
     * {@code Long} because its range exceeds {@code Integer}. {@code BOOL}/{@code BOOLEAN}
     * always map to {@code Boolean}; only {@code BIT} depends on {@code size}.
     *
     * @param dbType database type name; NOTE(review): presumably the {@code TYPE_NAME} value
     *               from {@code DatabaseMetaData.getColumns} — confirm against the caller
     * @param size   column size; only consulted for {@code BIT}, where a size of 1 means a
     *               single-bit flag
     * @return the simple Java type name; {@code String} for {@code null} or unrecognised input
     */
    private static String convertJavaType(String dbType, int size) {
        if (dbType == null) {
            return "String";
        }

        // Keep only the leading keyword: "INT UNSIGNED" -> "INT", "DECIMAL(10,2)" -> "DECIMAL".
        String normalized = dbType.trim().toUpperCase(Locale.ROOT);
        String baseType = normalized.split("[\\s(]", 2)[0];
        boolean unsigned = normalized.contains("UNSIGNED");

        switch (baseType) {
            case "INT":
            case "INTEGER":
                // Unsigned 32-bit values do not fit into a Java int.
                return unsigned ? "Long" : "Integer";
            case "TINYINT":
            case "SMALLINT":
            case "MEDIUMINT":
                return "Integer";
            case "BIGINT":
                return "Long";
            case "VARCHAR":
            case "CHAR":
            case "TEXT":
            case "TINYTEXT":
            case "MEDIUMTEXT":
            case "LONGTEXT":
            case "ENUM":
            case "SET":
                return "String";
            case "DATETIME":
            case "TIMESTAMP":
                return "LocalDateTime";
            case "DATE":
                return "LocalDate";
            case "TIME":
                return "LocalTime";
            case "DECIMAL":
            case "NUMERIC":
                return "BigDecimal";
            case "BIT":
                // BIT(1) is a flag; wider bit fields are carried as raw bytes.
                return size == 1 ? "Boolean" : "byte[]";
            case "BOOL":
            case "BOOLEAN":
                return "Boolean";
            case "FLOAT":
                return "Float";
            case "DOUBLE":
                return "Double";
            case "BINARY":
            case "VARBINARY":
            case "BLOB":
            case "TINYBLOB":
            case "MEDIUMBLOB":
            case "LONGBLOB":
                return "byte[]";
            default:
                return "String";
        }
    }

    /**
     * Returns the Java type of the table's primary key.
     *
     * <p>Only the first primary-key column is considered; a table without a primary key
     * yields {@code Long}.
     *
     * @param tableInfo table metadata providing primary keys and column types
     * @return the Java type name of the first primary-key column, or {@code Long}
     */
    private static String getPrimaryKeyJavaType(TableInfo tableInfo) {
        List<String> keyColumns = tableInfo.getPrimaryKeys();
        return keyColumns.isEmpty()
                ? "Long"
                : getJavaTypeForColumn(tableInfo, keyColumns.get(0));
    }

    /**
     * Looks up the Java type recorded for a column.
     *
     * @param tableInfo  table metadata whose columns are searched in order
     * @param columnName database column name to match exactly
     * @return the Java type of the first matching column, or {@code String} when none matches
     */
    private static String getJavaTypeForColumn(TableInfo tableInfo, String columnName) {
        String javaType = "String";
        for (ColumnInfo candidate : tableInfo.getColumns()) {
            if (!candidate.getColumnName().equals(columnName)) {
                continue;
            }
            // First match wins.
            javaType = candidate.getJavaType();
            break;
        }
        return javaType;
    }

    /**
     * Looks up the JDBC type recorded for a column.
     *
     * @param tableInfo  table metadata whose columns are searched in order
     * @param columnName database column name to match exactly
     * @return the JDBC type of the first matching column, or {@code VARCHAR} when none matches
     */
    private static String getJdbcTypeForColumn(TableInfo tableInfo, String columnName) {
        Iterator<ColumnInfo> remaining = tableInfo.getColumns().iterator();
        while (remaining.hasNext()) {
            ColumnInfo current = remaining.next();
            if (current.getColumnName().equals(columnName)) {
                return current.getJdbcType();
            }
        }
        // Unknown column: fall back to the generic text type.
        return "VARCHAR";
    }

    /**
     * Writes {@code content} to {@code fileName} inside {@code directory}, creating the
     * directory (and any missing parents) first and replacing an existing file.
     *
     * <p>The file is always encoded as UTF-8, so non-ASCII text in the generated sources does
     * not depend on the platform's default charset. The target is resolved as a child of
     * {@code directory}, so the result is the same whether or not {@code directory} ends with
     * a path separator. Failures are not propagated: the exception's stack trace is printed
     * and the method returns.
     *
     * @param directory output directory
     * @param fileName  name of the file to create inside {@code directory}
     * @param content   text to write
     */
    private static void writeToFile(String directory, String fileName, String content) {
        try {
            java.io.File dir = new java.io.File(directory);
            // Re-check isDirectory() after mkdirs() in case the directory appeared meanwhile.
            if (!dir.isDirectory() && !dir.mkdirs() && !dir.isDirectory()) {
                throw new java.io.IOException("Cannot create output directory: " + dir);
            }

            java.io.File target = new java.io.File(dir, fileName);
            try (java.io.Writer writer = new java.io.OutputStreamWriter(
                    new java.io.FileOutputStream(target), java.nio.charset.StandardCharsets.UTF_8)) {
                writer.write(content);
            }
            System.out.println("生成文件: " + target.getPath());
        } catch (Exception e) {
            // Best effort: report the failure and let generation of the other files continue.
            e.printStackTrace();
        }
    }

    /**
     * Table metadata consumed by the generator methods (enhanced version).
     * Accessors are generated by Lombok's {@code @Data}.
     */
    @Data
    static class TableInfo {
        // Database table name; written verbatim into the generated SQL.
        private String tableName;
        // NOTE(review): presumably the name of the generated entity class — confirm where it is set.
        private String className;
        // NOTE(review): presumably the variable-style form of className — confirm where it is set.
        private String objectName;
        // Columns of the table; the generators iterate these to emit per-column fragments.
        private List<ColumnInfo> columns;
        // Primary-key column names (database names, matched against ColumnInfo.columnName).
        private List<String> primaryKeys;
        // NOTE(review): presumably set when some column maps to BigDecimal — confirm against the entity generator.
        private boolean hasBigDecimal;
        // NOTE(review): presumably set when some column maps to LocalDateTime — confirm against the entity generator.
        private boolean hasLocalDateTime;
    }

    /**
     * Column metadata consumed by the generator methods.
     * Accessors are generated by Lombok's {@code @Data}.
     */
    @Data
    static class ColumnInfo {
        // Database column name; written verbatim into the generated SQL.
        private String columnName;
        // Property name used in MyBatis #{...} placeholders and <if test="..."> expressions.
        private String propertyName;
        // JDBC type name written into the jdbcType=... attribute of placeholders.
        private String jdbcType;
        // Java type name; the generators compare it with "String" to add empty-string checks.
        private String javaType;
        // NOTE(review): presumably the column comment from the database metadata — confirm.
        private String remarks;
        // NOTE(review): presumably whether the column accepts NULL — confirm.
        private boolean nullable;
    }
}
posted @ 2025-11-16 15:39  凛冬雪夜  阅读(7)  评论(0)    收藏  举报