ORM 中 insert 操作如何获取数据库生成的主键 id 值

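MyBatis 中的实现：推荐直接在 Mapper 方法上使用 `@Options(useGeneratedKeys = true, keyProperty = "...")`，由 MyBatis 通过 JDBC 的 getGeneratedKeys 回填主键。下面第二段自定义 LanguageDriver 拼接 `SELECT LAST_INSERT_ID()` 的写法只是一种思路示意，并不能正确回填主键，不建议使用。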


/**
 * MyBatis mapper for {@code User} rows.
 */
public interface UserMapper extends BaseMapper<User> {
    /**
     * Inserts a user and writes the database-generated key back into {@code user.id}.
     *
     * <p>{@code useGeneratedKeys = true} asks MyBatis to read the key produced by the INSERT, and
     * {@code keyProperty = "user.id"} names where to store it. The {@code user.} prefix is needed
     * because the argument is bound under the name {@code "user"} via {@code @Param}, which is also
     * why the SQL refers to {@code #{user.username}} / {@code #{user.password}}.
     *
     * @param user the user to insert; its {@code id} is populated after the call
     * @return the number of affected rows
     */
    @Insert("INSERT INTO user (username, password) VALUES (#{user.username}, #{user.password})")
    @Options(useGeneratedKeys = true, keyProperty = "user.id")
    int insertUser(@Param("user") User user);
}

/**
 * Custom MyBatis language driver that tries to obtain the generated primary key by appending
 * {@code SELECT LAST_INSERT_ID()} to the SQL of INSERT statements.
 *
 * <p>NOTE(review): as written this does not achieve its goal and should not be used; prefer
 * {@code @Options(useGeneratedKeys = true, keyProperty = "...")} or {@code @SelectKey}:
 * <ul>
 *   <li>{@code OptionsBoundSql.getSql()} appends the text with no statement separator, producing
 *       {@code INSERT ... VALUES (...) SELECT LAST_INSERT_ID()}, which is expected to be a syntax
 *       error on MySQL.</li>
 *   <li>Even if a second statement did execute, nothing in this class reads its result or writes
 *       the value to a key property of the parameter object.</li>
 *   <li>In MyBatis 3, {@code BoundSql} is a concrete class, not an interface, so
 *       {@code implements BoundSql} is not expected to compile; the remaining delegating methods are
 *       also elided. Confirm against the MyBatis version in use.</li>
 * </ul>
 */
public class OptionsLanguageDriver extends XMLLanguageDriver implements LanguageDriver {

    /**
     * Builds the statement's SqlSource and, for INSERT statements, wraps it in
     * {@code OptionsSqlSource}.
     *
     * @param configuration MyBatis configuration
     * @param script        the parsed statement XML node
     * @param parameterType the statement's parameter type
     * @return the wrapped SqlSource for INSERTs, otherwise the standard one
     */
    @Override
    public SqlSource createSqlSource(Configuration configuration, XNode script, Class<?> parameterType) {
        // Build the SqlSource the standard XML driver would produce.
        SqlSource originalSqlSource = super.createSqlSource(configuration, script, parameterType);

        // Decide whether the statement is an INSERT.
        // NOTE(review): this runs while the statement is still being parsed, so its MappedStatement is
        // most likely not registered yet, and the "id" attribute is not namespace-qualified;
        // getMappedStatement may therefore throw. Verify before relying on this lookup.
        MappedStatement mappedStatement = configuration.getMappedStatement(script.getStringAttribute("id"));
        if (mappedStatement.getSqlCommandType() == SqlCommandType.INSERT) {
            // Wrap the original SqlSource so every BoundSql it produces goes through OptionsBoundSql.
            return new OptionsSqlSource(originalSqlSource);
        }

        return originalSqlSource;
    }

    /** SqlSource decorator that wraps each produced BoundSql in {@code OptionsBoundSql}. */
    private static class OptionsSqlSource implements SqlSource {
        private final SqlSource delegate;

        public OptionsSqlSource(SqlSource delegate) {
            this.delegate = delegate;
        }

        @Override
        public BoundSql getBoundSql(Object parameterObject) {
            // Produce the BoundSql from the wrapped source.
            BoundSql originalBoundSql = delegate.getBoundSql(parameterObject);

            // Wrap it so getSql() returns the modified statement text.
            return new OptionsBoundSql(originalBoundSql);
        }
    }

    /** BoundSql decorator whose SQL text has {@code SELECT LAST_INSERT_ID()} appended. */
    private static class OptionsBoundSql implements BoundSql {
        private final BoundSql delegate;

        public OptionsBoundSql(BoundSql delegate) {
            this.delegate = delegate;
        }

        @Override
        public String getSql() {
            // Original SQL text of the statement.
            String originalSql = delegate.getSql();

            // Append the key query. NOTE(review): there is no ';' separator and no code that consumes
            // this query's result, so the generated key is never written back anywhere.
            return originalSql + " SELECT LAST_INSERT_ID()";
        }

        // Delegating implementations of the remaining methods are omitted from this sample.
    }
}

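基于 JdbcTemplate 的实现（注意：`LAST_INSERT_ID()` 的值是按数据库连接维护的。若它与 INSERT 不在同一连接上执行，就可能取到错误的值或 0，例如没有事务时 JdbcTemplate 每次都会从连接池重新取连接。更稳妥的做法是使用 `KeyHolder` 配合 `Statement.RETURN_GENERATED_KEYS`，在同一条语句上读取生成的主键）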


/**
 * Generic insert; the auto-increment column must carry the {@code Pk} annotation.
 *
 * <p>The generated key is read back from the INSERT statement itself (JDBC
 * {@code RETURN_GENERATED_KEYS} through a Spring {@code KeyHolder}). A follow-up
 * {@code SELECT LAST_INSERT_ID()} is unreliable: that value is tracked per connection, and outside a
 * transaction {@code JdbcTemplate} may execute the second query on a different pooled connection.
 *
 * @param t          entity to insert; its {@code Pk} field is set to the generated key
 * @param ignoreNull whether to skip fields whose value is {@code null}
 * @return number of affected rows
 * @throws PrimaryKeyMissingException if the entity has no {@code Pk} field (checked before any SQL runs)
 */
@SneakyThrows
protected Integer insert(T t, Boolean ignoreNull) {
    // Validate first so a missing @Pk does not leave an already-inserted row behind.
    Field pkField = getPkField(t);
    if (pkField == null) {
        throw new PrimaryKeyMissingException();
    }

    String table = getTableName(t);

    List<Field> filterField = getField(t, ignoreNull);

    List<String> columnList = getColumns(filterField);

    String columns = StrUtil.join(Const.SEPARATOR_COMMA, columnList);

    // Placeholders: one "?" per column
    String params = StrUtil.repeatAndJoin("?", columnList.size(), Const.SEPARATOR_COMMA);

    // Values, in the same order as the columns
    Object[] values = filterField.stream().map(field -> ReflectUtil.getFieldValue(t, field)).toArray();

    String sql = StrUtil.format("INSERT INTO {table} ({columns}) VALUES ({params})", Dict.create().set("table", table).set("columns", columns).set("params", params));
    log.debug("【执行SQL】SQL:{}", sql);
    log.debug("【执行SQL】参数:{}", JSONUtil.toJsonStr(values));

    org.springframework.jdbc.support.KeyHolder keyHolder = new org.springframework.jdbc.support.GeneratedKeyHolder();
    int update = jdbcTemplate.update(connection -> {
        java.sql.PreparedStatement ps = connection.prepareStatement(sql, java.sql.Statement.RETURN_GENERATED_KEYS);
        // Same argument binding that jdbcTemplate.update(sql, values) performs internally.
        new org.springframework.jdbc.core.ArgumentPreparedStatementSetter(values).setValues(ps);
        return ps;
    }, keyHolder);

    // NOTE(review): KeyHolder.getKey() expects exactly one generated key column (as MySQL returns);
    // drivers that return every column may need getKeys() instead.
    Number generatedKey = keyHolder.getKey();
    if (generatedKey != null) {
        // Match the field's declared type; Field.set rejects a Long for an Integer/int field.
        Class<?> pkType = pkField.getType();
        Object pkValue;
        if (pkType == Integer.class || pkType == int.class) {
            pkValue = generatedKey.intValue();
        } else {
            pkValue = generatedKey.longValue();
        }
        pkField.setAccessible(true);
        pkField.set(t, pkValue);
    }
    return update;
}

封装 JdbcTemplate

BaseDao<T, P>


@Slf4j
public class BaseDao<T, P> {
    private final JdbcTemplate jdbcTemplate;
    private final Class<T> clazz;

    @SuppressWarnings(value = "unchecked")
    public BaseDao(JdbcTemplate jdbcTemplate) {
        this.jdbcTemplate = jdbcTemplate;
        clazz = (Class<T>) ((ParameterizedType) getClass().getGenericSuperclass()).getActualTypeArguments()[0];
    }

    /**
     * 通用插入,自增列需要添加 {@link Pk} 注解
     *
     * @param t          对象
     * @param ignoreNull 是否忽略 null 值
     * @return 操作的行数
     */
    @SneakyThrows
    protected Integer insert(T t, Boolean ignoreNull) {
        String table = getTableName(t);

        List<Field> filterField = getField(t, ignoreNull);

        List<String> columnList = getColumns(filterField);

        String columns = StrUtil.join(Const.SEPARATOR_COMMA, columnList);

        // 构造占位符
        String params = StrUtil.repeatAndJoin("?", columnList.size(), Const.SEPARATOR_COMMA);

        // 构造值
        Object[] values = filterField.stream().map(field -> ReflectUtil.getFieldValue(t, field)).toArray();

        String sql = StrUtil.format("INSERT INTO {table} ({columns}) VALUES ({params})", Dict.create().set("table", table).set("columns", columns).set("params", params));
        log.debug("【执行SQL】SQL:{}", sql);
        log.debug("【执行SQL】参数:{}", JSONUtil.toJsonStr(values));
        int update = jdbcTemplate.update(sql, values);
        Long generatedId = jdbcTemplate.queryForObject("SELECT LAST_INSERT_ID()", Long.class);
        Field pkField = getPkField(t);
        if (pkField == null) {
            throw new PrimaryKeyMissingException();
        }
        pkField.setAccessible(true);
        pkField.set(t, generatedId);
        return update;
    }

    /**
     * 通用根据主键删除
     *
     * @param pk 主键
     * @return 影响行数
     */
    protected Integer deleteById(P pk) {
        String tableName = getTableName();
        String sql = StrUtil.format("DELETE FROM {table} where id = ?", Dict.create().set("table", tableName));
        log.debug("【执行SQL】SQL:{}", sql);
        log.debug("【执行SQL】参数:{}", pk);
        return jdbcTemplate.update(sql, pk);
    }

    /**
     * 通用根据主键更新,自增列需要添加 {@link Pk} 注解
     *
     * @param t          对象
     * @param pk         主键
     * @param ignoreNull 是否忽略 null 值
     * @return 操作的行数
     */
    protected Integer updateById(T t, P pk, Boolean ignoreNull) {
        String tableName = getTableName(t);

        List<Field> filterField = getField(t, ignoreNull);

        List<String> columnList = getColumns(filterField);

        List<String> columns = columnList.stream().map(s -> StrUtil.appendIfMissing(s, " = ?")).collect(Collectors.toList());
        String params = StrUtil.join(Const.SEPARATOR_COMMA, columns);

        // 构造值
        List<Object> valueList = filterField.stream().map(field -> ReflectUtil.getFieldValue(t, field)).collect(Collectors.toList());
        valueList.add(pk);

        Object[] values = ArrayUtil.toArray(valueList, Object.class);

        String sql = StrUtil.format("UPDATE {table} SET {params} where id = ?", Dict.create().set("table", tableName).set("params", params));
        log.debug("【执行SQL】SQL:{}", sql);
        log.debug("【执行SQL】参数:{}", JSONUtil.toJsonStr(values));
        return jdbcTemplate.update(sql, values);
    }

    /**
     * 通用根据主键查询单条记录
     *
     * @param pk 主键
     * @return 单条记录
     */
    public T findOneById(P pk) {
        String tableName = getTableName();
        String sql = StrUtil.format("SELECT * FROM {table} where id = ?", Dict.create().set("table", tableName));
        RowMapper<T> rowMapper = new BeanPropertyRowMapper<>(clazz);
        log.debug("【执行SQL】SQL:{}", sql);
        log.debug("【执行SQL】参数:{}", JSONUtil.toJsonStr(pk));
        return jdbcTemplate.queryForObject(sql, new Object[]{pk}, rowMapper);
    }

    /**
     * 根据对象查询
     *
     * @param t 查询条件
     * @return 对象列表
     */
    public List<T> findByExample(T t) {
        String tableName = getTableName(t);
        List<Field> filterField = getField(t, true);
        List<String> columnList = getColumns(filterField);

        List<String> columns = columnList.stream().map(s -> " and " + s + " = ? ").collect(Collectors.toList());

        String where = StrUtil.join(" ", columns);
        // 构造值
        Object[] values = filterField.stream().map(field -> ReflectUtil.getFieldValue(t, field)).toArray();

        String sql = StrUtil.format("SELECT * FROM {table} where 1=1 {where}", Dict.create().set("table", tableName).set("where", StrUtil.isBlank(where) ? "" : where));
        log.debug("【执行SQL】SQL:{}", sql);
        log.debug("【执行SQL】参数:{}", JSONUtil.toJsonStr(values));
        List<Map<String, Object>> maps = jdbcTemplate.queryForList(sql, values);
        List<T> ret = CollUtil.newArrayList();
        maps.forEach(map -> ret.add(BeanUtil.fillBeanWithMap(map, ReflectUtil.newInstance(clazz), true, false)));
        return ret;
    }

    /**
     * 获取表名
     *
     * @param t 对象
     * @return 表名
     */
    private String getTableName(T t) {
        Table tableAnnotation = t.getClass().getAnnotation(Table.class);
        if (ObjectUtil.isNotNull(tableAnnotation)) {
            return StrUtil.format("`{}`", tableAnnotation.name());
        } else {
            return StrUtil.format("`{}`", t.getClass().getName().toLowerCase());
        }
    }

    /**
     * 获取表名
     *
     * @return 表名
     */
    private String getTableName() {
        Table tableAnnotation = clazz.getAnnotation(Table.class);
        if (ObjectUtil.isNotNull(tableAnnotation)) {
            return StrUtil.format("`{}`", tableAnnotation.name());
        } else {
            return StrUtil.format("`{}`", clazz.getName().toLowerCase());
        }
    }

    /**
     * 获取列
     *
     * @param fieldList 字段列表
     * @return 列信息列表
     */
    private List<String> getColumns(List<Field> fieldList) {
        // 构造列
        List<String> columnList = CollUtil.newArrayList();
        for (Field field : fieldList) {
            Column columnAnnotation = field.getAnnotation(Column.class);
            String columnName;
            if (ObjectUtil.isNotNull(columnAnnotation)) {
                columnName = columnAnnotation.name();
            } else {
                columnName = field.getName();
            }
            columnList.add(StrUtil.format("`{}`", columnName));
        }
        return columnList;
    }

    /**
     * 获取字段列表 {@code 过滤数据库中不存在的字段,以及自增列}
     *
     * @param t          对象
     * @param ignoreNull 是否忽略空值
     * @return 字段列表
     */
    private List<Field> getField(T t, Boolean ignoreNull) {
        // 获取所有字段,包含父类中的字段
        Field[] fields = ReflectUtil.getFields(t.getClass());

        // 过滤数据库中不存在的字段,以及自增列
        List<Field> filterField;
        Stream<Field> fieldStream = CollUtil.toList(fields).stream().filter(field -> ObjectUtil.isNull(field.getAnnotation(Ignore.class)) || ObjectUtil.isNull(field.getAnnotation(Pk.class)));

        // 是否过滤字段值为null的字段
        if (ignoreNull) {
            filterField = fieldStream.filter(field -> ObjectUtil.isNotNull(ReflectUtil.getFieldValue(t, field))).collect(Collectors.toList());
        } else {
            filterField = fieldStream.collect(Collectors.toList());
        }
        return filterField;
    }

    private Field getPkField(T t) {
        // 获取所有字段,包含父类中的字段
        Field[] fields = ReflectUtil.getFields(t.getClass());
        Optional<Field> first = CollUtil.toList(fields).stream().filter(field -> ObjectUtil.isNotNull(field.getAnnotation(Pk.class))).findFirst();
        boolean present =first.isPresent();
        if (present) {
            return first.get();
        }
        return null;
    }

}

UserDao



/**
 * Spring-managed DAO for {@code User} rows keyed by a {@code Long} primary key.
 *
 * <p>Every method is a user-named facade over the matching generic operation inherited from
 * {@code BaseDao}; inserts and updates leave {@code null}-valued fields out of the SQL.
 */
@Repository
public class UserDao extends BaseDao<User, Long> {

    /** Value passed as the inherited {@code ignoreNull} flag for inserts and updates. */
    private static final Boolean SKIP_NULL_FIELDS = Boolean.TRUE;

    /**
     * Creates the DAO around the injected template.
     *
     * @param jdbcTemplate template that executes the generated SQL
     */
    @Autowired
    public UserDao(JdbcTemplate jdbcTemplate) {
        super(jdbcTemplate);
    }

    /**
     * Persists a user, omitting {@code null} fields from the column list.
     *
     * @param user user to save
     * @return number of affected rows
     */
    public Integer insert(User user) {
        return insert(user, SKIP_NULL_FIELDS);
    }

    /**
     * Removes the user with the given primary key.
     *
     * @param id primary key
     * @return number of affected rows
     */
    public Integer delete(Long id) {
        return deleteById(id);
    }

    /**
     * Writes the non-{@code null} fields of {@code user} to the row identified by {@code id}.
     *
     * @param user source of the new values
     * @param id   primary key of the row to update
     * @return number of affected rows
     */
    public Integer update(User user, Long id) {
        return updateById(user, id, SKIP_NULL_FIELDS);
    }

    /**
     * Loads a single user by primary key.
     *
     * @param id primary key
     * @return the matching user
     */
    public User selectById(Long id) {
        return findOneById(id);
    }

    /**
     * Lists users matching the non-{@code null} fields of the given example.
     *
     * @param user example holding the query conditions
     * @return matching users
     */
    public List<User> selectUserList(User user) {
        return findByExample(user);
    }
}
posted @ 2023-08-23 14:10  linzm14  阅读(8)  评论(0)  编辑  收藏  举报