Spring AI 代码分析(九)--记忆能力实现

记忆能力分析

请关注微信公众号:阿呆-bot

1. 工程结构概览

Spring AI 提供了完整的对话记忆(Chat Memory)能力,支持将对话历史持久化到各种存储后端。记忆能力是构建多轮对话应用的基础。

spring-ai-model/
└── chat/memory/                  # 记忆核心抽象
    ├── ChatMemory.java           # 记忆接口
    ├── ChatMemoryRepository.java # 存储仓库接口
    ├── MessageWindowChatMemory.java  # 窗口记忆实现
    └── InMemoryChatMemoryRepository.java  # 内存实现

memory/repository/                # 持久化实现
├── spring-ai-model-chat-memory-repository-jdbc/      # JDBC 实现
├── spring-ai-model-chat-memory-repository-mongodb/    # MongoDB 实现
├── spring-ai-model-chat-memory-repository-neo4j/      # Neo4j 实现
├── spring-ai-model-chat-memory-repository-cassandra/  # Cassandra 实现
└── spring-ai-model-chat-memory-repository-cosmos-db/  # Cosmos DB 实现

2. 技术体系与模块关系

记忆系统采用分层设计:记忆接口 → 存储仓库 → 具体实现

image.png

3. 关键场景示例代码

3.1 基础使用

使用内存记忆:

// 创建内存记忆
ChatMemory memory = new MessageWindowChatMemory(
    new InMemoryChatMemoryRepository(),
    10  // 窗口大小:保留最近 10 条消息
);

// 添加消息
memory.add("conversation-1", new UserMessage("你好"));
memory.add("conversation-1", new AssistantMessage("你好!有什么可以帮助你的?"));

// 获取对话历史
List<Message> history = memory.get("conversation-1");

3.2 使用 JDBC 持久化

使用数据库持久化记忆:

@Autowired
private DataSource dataSource;

@Bean
public ChatMemory chatMemory() {
    JdbcChatMemoryRepository repository = 
        JdbcChatMemoryRepository.builder()
            .dataSource(dataSource)
            .build();
    
    return new MessageWindowChatMemory(repository, 20);
}

3.3 使用 MongoDB

使用 MongoDB 持久化:

@Autowired
private MongoTemplate mongoTemplate;

@Bean
public ChatMemory chatMemory() {
    MongoChatMemoryRepository repository = 
        MongoChatMemoryRepository.builder()
            .mongoTemplate(mongoTemplate)
            .build();
    
    return new MessageWindowChatMemory(repository, 30);
}

3.4 在 ChatClient 中使用

记忆可以通过 Advisor 集成到 ChatClient:

ChatMemory memory = new MessageWindowChatMemory(repository, 10);

MessageChatMemoryAdvisor memoryAdvisor = 
    MessageChatMemoryAdvisor.builder()
        .chatMemory(memory)
        .conversationId("user-123")
        .build();

ChatClient chatClient = ChatClient.builder(chatModel)
    .defaultAdvisors(memoryAdvisor)
    .build();

// 对话会自动保存到记忆
String response = chatClient.prompt()
    .user("我的名字是张三")
    .call()
    .content();

// 后续对话会自动包含历史
String response2 = chatClient.prompt()
    .user("我的名字是什么?")
    .call()
    .content();  // 模型会记住名字是张三

4. 核心实现图

4.1 记忆存储和检索流程

image.png

5. 入口类与关键类关系

image.png

6. 关键实现逻辑分析

6.1 ChatMemory 接口设计

ChatMemory 接口提供了简单的记忆 API:

public interface ChatMemory {
    void add(String conversationId, List<Message> messages);
    List<Message> get(String conversationId);
    void clear(String conversationId);
}

这个接口设计简洁,但功能强大。它支持:

  • 多对话管理:通过 conversationId 区分不同对话
  • 批量添加:支持一次添加多条消息
  • 清理功能:支持清除特定对话的记忆

6.2 MessageWindowChatMemory 实现

MessageWindowChatMemory 实现了窗口记忆策略:

public class MessageWindowChatMemory implements ChatMemory {
    private final ChatMemoryRepository repository;
    private final int windowSize;
    
    @Override
    public void add(String conversationId, List<Message> messages) {
        // 1. 获取现有消息
        List<Message> existing = repository.findByConversationId(conversationId);
        
        // 2. 添加新消息
        List<Message> allMessages = new ArrayList<>(existing);
        allMessages.addAll(messages);
        
        // 3. 应用窗口策略(只保留最近的 N 条)
        List<Message> windowed = applyWindow(allMessages);
        
        // 4. 保存
        repository.saveAll(conversationId, windowed);
    }
    
    @Override
    public List<Message> get(String conversationId) {
        List<Message> messages = repository.findByConversationId(conversationId);
        return applyWindow(messages);
    }
    
    private List<Message> applyWindow(List<Message> messages) {
        if (messages.size() <= windowSize) {
            return messages;
        }
        // 只返回最近的 N 条消息
        return messages.subList(messages.size() - windowSize, messages.size());
    }
}

窗口策略的优势:

  • 控制上下文长度:避免上下文过长导致 token 超限
  • 保持相关性:最近的对话通常更相关
  • 性能优化:减少需要处理的消息数量

6.3 JDBC 实现

JDBC 实现支持多种数据库:

public class JdbcChatMemoryRepository implements ChatMemoryRepository {
    private final JdbcTemplate jdbcTemplate;
    private final ChatMemoryRepositoryDialect dialect;
    
    @Override
    public List<Message> findByConversationId(String conversationId) {
        String sql = dialect.getSelectByConversationIdSql();
        
        return jdbcTemplate.query(sql, 
            new Object[]{conversationId},
            (rs, rowNum) -> {
                String content = rs.getString("content");
                String type = rs.getString("type");
                Map<String, Object> metadata = parseMetadata(rs.getString("metadata"));
                
                return createMessage(type, content, metadata);
            }
        );
    }
    
    @Override
    public void saveAll(String conversationId, List<Message> messages) {
        // 1. 删除现有消息
        deleteByConversationId(conversationId);
        
        // 2. 批量插入新消息
        String sql = dialect.getInsertSql();
        List<Object[]> batchArgs = messages.stream()
            .map(msg -> new Object[]{
                conversationId,
                msg.getText(),
                msg.getMessageType().name(),
                toJson(msg.getMetadata()),
                Timestamp.from(Instant.now())
            })
            .collect(toList());
        
        jdbcTemplate.batchUpdate(sql, batchArgs);
    }
}

支持的数据库

  • PostgreSQL
  • MySQL/MariaDB
  • H2
  • SQLite
  • Oracle
  • SQL Server
  • HSQLDB

每个数据库都有自己的 Dialect 实现,处理 SQL 方言差异。

6.4 MongoDB 实现

MongoDB 实现使用文档存储:

public class MongoChatMemoryRepository implements ChatMemoryRepository {
    private final MongoTemplate mongoTemplate;
    
    @Override
    public List<Message> findByConversationId(String conversationId) {
        Query query = Query.query(
            Criteria.where("conversationId").is(conversationId)
        ).with(Sort.by("timestamp").descending());
        
        List<Conversation> conversations = mongoTemplate.find(
            query, Conversation.class
        );
        
        return conversations.stream()
            .map(this::mapMessage)
            .collect(toList());
    }
    
    @Override
    public void saveAll(String conversationId, List<Message> messages) {
        // 1. 删除现有消息
        deleteByConversationId(conversationId);
        
        // 2. 转换为文档并保存
        List<Conversation> conversations = messages.stream()
            .map(msg -> new Conversation(
                conversationId,
                new Conversation.Message(
                    msg.getText(),
                    msg.getMessageType().name(),
                    msg.getMetadata()
                ),
                Instant.now()
            ))
            .collect(toList());
        
        mongoTemplate.insert(conversations, Conversation.class);
    }
}

MongoDB 文档结构

{
  "conversationId": "user-123",
  "message": {
    "text": "你好",
    "type": "USER",
    "metadata": {}
  },
  "timestamp": "2025-01-01T00:00:00Z"
}

6.5 Neo4j 实现

Neo4j 实现使用图数据库:

public class Neo4jChatMemoryRepository implements ChatMemoryRepository {
    @Override
    public List<Message> findByConversationId(String conversationId) {
        String cypher = """
            MATCH (s:Session {id: $conversationId})-[:HAS_MESSAGE]->(m:Message)
            OPTIONAL MATCH (m)-[:HAS_METADATA]->(metadata:Metadata)
            OPTIONAL MATCH (m)-[:HAS_MEDIA]->(media:Media)
            RETURN m, metadata, collect(media) as medias
            ORDER BY m.idx ASC
            """;
        
        return driver.executableQuery(cypher)
            .withParameters(Map.of("conversationId", conversationId))
            .execute(record -> mapToMessage(record));
    }
    
    @Override
    public void saveAll(String conversationId, List<Message> messages) {
        String cypher = """
            MERGE (s:Session {id: $conversationId})
            WITH s
            UNWIND $messages AS msg
            CREATE (m:Message {
                text: msg.text,
                type: msg.type,
                idx: msg.idx
            })
            CREATE (s)-[:HAS_MESSAGE]->(m)
            """;
        
        driver.executableQuery(cypher)
            .withParameters(Map.of(
                "conversationId", conversationId,
                "messages", toMessageParams(messages)
            ))
            .execute();
    }
}

Neo4j 图结构

(Session {id: "user-123"})-[:HAS_MESSAGE]->(Message {text: "你好", type: "USER"})
(Session {id: "user-123"})-[:HAS_MESSAGE]->(Message {text: "你好!", type: "ASSISTANT"})

6.6 Cassandra 实现

Cassandra 实现使用分布式存储:

public class CassandraChatMemoryRepository implements ChatMemoryRepository {
    @Override
    public List<Message> findByConversationId(String conversationId) {
        BoundStatement stmt = getStmt.boundStatementBuilder()
            .setString("conversation_id", conversationId)
            .build();
        
        ResultSet rs = session.execute(stmt);
        return rs.all().stream()
            .map(this::mapToMessage)
            .collect(toList());
    }
    
    @Override
    public void saveAll(String conversationId, List<Message> messages) {
        // 1. 删除现有消息
        deleteByConversationId(conversationId);
        
        // 2. 批量插入
        BatchStatement batch = BatchStatement.builder(BatchType.LOGGED)
            .build();
        
        for (Message msg : messages) {
            BoundStatement stmt = addStmt.boundStatementBuilder()
                .setString("conversation_id", conversationId)
                .setString("content", msg.getText())
                .setString("type", msg.getMessageType().name())
                .setMap("metadata", msg.getMetadata())
                .build();
            batch.add(stmt);
        }
        
        session.execute(batch);
    }
}

7. 实现对比分析

特性 JDBC MongoDB Neo4j Cassandra
存储模型 关系型 文档型 图型 列族型
查询方式 SQL Query DSL Cypher CQL
适用场景 通用 灵活结构 关系查询 大规模分布式
性能 中等 中等 极高
扩展性 很好 优秀
事务支持

8. 外部依赖

不同实现的依赖:

8.1 JDBC

  • Spring JDBC:JDBC 模板
  • 数据库驱动:PostgreSQL、MySQL 等

8.2 MongoDB

  • Spring Data MongoDB:MongoDB 集成

8.3 Neo4j

  • Neo4j Java Driver:Neo4j 官方驱动

8.4 Cassandra

  • Cassandra Java Driver:Cassandra 官方驱动

9. 工程总结

Spring AI 的记忆能力设计有几个值得学习的地方:

分层抽象ChatMemory 提供高级 API,ChatMemoryRepository 提供存储抽象,具体实现处理数据库差异。这种设计让记忆功能既易用又灵活。想换存储后端?换个 ChatMemoryRepository 实现就行。

窗口记忆策略MessageWindowChatMemory 实现了智能的消息管理,只保留最近的 N 条消息,这既控制了上下文长度,又保持了相关性。不会因为对话历史太长导致 token 超限。

多存储后端支持。支持 JDBC、MongoDB、Neo4j、Cassandra 等多种存储,用户可以根据需求选择最合适的后端。想用关系数据库?用 JDBC。想用图数据库?用 Neo4j。

统一的数据模型。所有实现都使用相同的 Message 模型,这让切换存储后端变得简单。今天用 PostgreSQL,明天想换 MongoDB?改个配置就行。

自动模式初始化。大多数实现都支持自动创建表/集合,简化了部署。不用手动建表,启动时自动搞定。

总的来说,Spring AI 的记忆能力既简单又强大。简单的 API 让使用变得容易,强大的实现让系统可以适应各种场景。这种设计让开发者可以轻松构建支持多轮对话的 AI 应用。

posted @ 2025-11-26 21:43  wasp  阅读(3)  评论(0)    收藏  举报