SpringAI-ChatMemory

Official documentation: https://docs.spring.io/spring-ai/reference/api/chat-memory.html

1. Overview

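LLMs are stateless: they do not retain any information from previous interactions. To address this, Spring AI provides a chat memory feature that stores and retrieves information across multiple interactions with the LLM.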

2. Interfaces and classes

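  • ChatMemory is the top-level interface. It provides three methods: add() to store messages, get() to retrieve them, and clear() to delete them.
  • MessageWindowChatMemory is the default ChatMemory implementation. It manages message storage and retrieval by maintaining a message window of a configured maximum size. When the number of messages exceeds the maximum, older messages are removed while system messages are kept. The default window size is 20 messages.
  • ChatMemoryRepository is the interface for storing and retrieving conversation messages. Its methods are: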
public interface ChatMemoryRepository {
    List<String> findConversationIds();

    List<Message> findByConversationId(String conversationId);

    void saveAll(String conversationId, List<Message> messages);

    void deleteByConversationId(String conversationId);
}
  • The default ChatMemoryRepository implementation is InMemoryChatMemoryRepository, which keeps messages in memory.
    To store messages over JDBC instead, add the following dependency; the implementation class is JdbcChatMemoryRepository:
    <dependency>
        <groupId>org.springframework.ai</groupId>
        <artifactId>spring-ai-starter-model-chat-memory-repository-jdbc</artifactId>
    </dependency>

Configuration:

spring.ai.chat.memory.repository.jdbc.platform = mysql

spring.datasource.url=jdbc:mysql://localhost:3306/test_db?useUnicode=true&characterEncoding=utf-8&useSSL=false&serverTimezone=UTC
spring.datasource.username=root
spring.datasource.password=root
spring.datasource.driver-class-name=com.mysql.cj.jdbc.Driver

Java configuration:

    @Autowired
    private DateTimeTools dateTimeTools;           // tool object registered via defaultTools below
    @Autowired
    private JdbcChatMemoryRepository jdbcChatMemoryRepository;
    @Autowired
    private MilvusVectorStore milvusVectorStore;   // not used by the beans below

    // Chat memory backed by JDBC, keeping at most 10 messages per conversation
    @Bean
    public ChatMemory chatMemory() {
        return MessageWindowChatMemory.builder()
                .chatMemoryRepository(jdbcChatMemoryRepository)
                .maxMessages(10)
                .build();
    }

    // Every request made through this ChatClient goes through the memory advisor
    @Bean
    public ChatClient chatClient(OllamaChatModel ollamaChatModel) {
        return ChatClient.builder(ollamaChatModel)
                .defaultTools(dateTimeTools)
                .defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory()).build())
                .build();
    }
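The window behavior of MessageWindowChatMemory can be illustrated with a simplified plain-Java sketch. It assumes a hypothetical Msg record in place of Spring AI's Message and keeps everything in a ConcurrentHashMap, roughly like InMemoryChatMemoryRepository. It is not the library's code: the real class delegates persistence to a ChatMemoryRepository and has additional handling for system messages that is omitted here.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class WindowMemorySketch {

    // Hypothetical stand-in for Spring AI's Message: a message type plus its text.
    record Msg(String type, String text) {}

    // Messages per conversation ID, kept in memory (similar in spirit to InMemoryChatMemoryRepository).
    private final Map<String, List<Msg>> store = new ConcurrentHashMap<>();
    private final int maxMessages;

    WindowMemorySketch(int maxMessages) {
        this.maxMessages = maxMessages;
    }

    // Append new messages, then evict the oldest non-system messages until the window fits.
    void add(String conversationId, List<Msg> newMessages) {
        List<Msg> all = new ArrayList<>(store.getOrDefault(conversationId, List.of()));
        all.addAll(newMessages);
        int toRemove = all.size() - maxMessages;
        List<Msg> kept = new ArrayList<>();
        for (Msg m : all) {
            if (toRemove > 0 && !"SYSTEM".equals(m.type())) {
                toRemove--; // drop this older USER/ASSISTANT message
            } else {
                kept.add(m); // system messages and the newest messages survive
            }
        }
        store.put(conversationId, kept);
    }

    List<Msg> get(String conversationId) {
        return List.copyOf(store.getOrDefault(conversationId, List.of()));
    }

    void clear(String conversationId) {
        store.remove(conversationId);
    }

    public static void main(String[] args) {
        WindowMemorySketch memory = new WindowMemorySketch(3);
        memory.add("c1", List.of(new Msg("SYSTEM", "You are a helpful assistant")));
        memory.add("c1", List.of(new Msg("USER", "hi"), new Msg("ASSISTANT", "hello")));
        memory.add("c1", List.of(new Msg("USER", "how are you?")));
        // "hi" is evicted; the SYSTEM message, "hello" and "how are you?" remain
        System.out.println(memory.get("c1"));
    }
}
```

Because eviction skips system messages, a conversation's instructions survive however long the chat runs, which matches the behavior described for MessageWindowChatMemory.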

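3. Chat memory Advisors

Spring AI provides the following built-in Advisors:

  • MessageChatMemoryAdvisor. This Advisor manages conversation memory using the provided ChatMemory implementation. On each interaction, it retrieves the conversation history from memory and includes it in the prompt as a collection of messages.

The source code shows that it loads the chat memory by conversation ID, places those messages in front of the current prompt's messages, and then saves the new user message to memory. Key code: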

public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) {
    String conversationId = getConversationId(chatClientRequest.context(), this.defaultConversationId);

    // 1. Retrieve the chat memory for the current conversation.
    List<Message> memoryMessages = this.chatMemory.get(conversationId);

    // 2. Advise the request messages list.
    List<Message> processedMessages = new ArrayList<>(memoryMessages);
    processedMessages.addAll(chatClientRequest.prompt().getInstructions());

    // 3. Create a new request with the advised messages.
    ChatClientRequest processedChatClientRequest = chatClientRequest.mutate()
        .prompt(chatClientRequest.prompt().mutate().messages(processedMessages).build())
        .build();

    // 4. Add the new user message to the conversation memory.
    UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();
    this.chatMemory.add(conversationId, userMessage);

    return processedChatClientRequest;
}
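  • PromptChatMemoryAdvisor. This Advisor manages conversation memory using the provided ChatMemory implementation. On each interaction, it retrieves the conversation history from memory and appends it to the system prompt as plain text.

The source code shows that after loading the chat memory by conversation ID, it converts the USER and ASSISTANT messages into a single string, renders that string together with the original system message through a system prompt template, and puts the result into the prompt. Key code: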

public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) {
    String conversationId = getConversationId(chatClientRequest.context(), this.defaultConversationId);
    // 1. Retrieve the chat memory for the current conversation.
    List<Message> memoryMessages = this.chatMemory.get(conversationId);
    logger.debug("[PromptChatMemoryAdvisor.before] Memory before processing for conversationId={}: {}",
            conversationId, memoryMessages);

    // 2. Process memory messages as a string.
    String memory = memoryMessages.stream()
        .filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT)
        .map(m -> m.getMessageType() + ":" + m.getText())
        .collect(Collectors.joining(System.lineSeparator()));

    // 3. Augment the system message.
    SystemMessage systemMessage = chatClientRequest.prompt().getSystemMessage();
    String augmentedSystemText = this.systemPromptTemplate
        .render(Map.of("instructions", systemMessage.getText(), "memory", memory));

    // 4. Create a new request with the augmented system message.
    ChatClientRequest processedChatClientRequest = chatClientRequest.mutate()
        .prompt(chatClientRequest.prompt().augmentSystemMessage(augmentedSystemText))
        .build();

    // 5. Add all user messages from the current prompt to memory (after system
    // message is generated)
    // 4. Add the new user message to the conversation memory.
    UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();
    this.chatMemory.add(conversationId, userMessage);

    return processedChatClientRequest;
}
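  • VectorStoreChatMemoryAdvisor. This Advisor manages conversation memory using the provided VectorStore implementation. On each interaction, it retrieves the conversation history from the vector store and appends it to the system message as plain text.

The source code shows that it runs a similarity search in the vector store, using the user message as the query and filtering by conversation ID. It then renders the matching documents together with the system message through a system prompt template, puts the result into the prompt, and writes the current user message back to the vector store. Key code: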

public ChatClientRequest before(ChatClientRequest request, AdvisorChain advisorChain) {
    String conversationId = this.getConversationId(request.context(), this.defaultConversationId);
    String query = request.prompt().getUserMessage() != null ? request.prompt().getUserMessage().getText() : "";
    int topK = this.getChatMemoryTopK(request.context());
    String filter = "conversationId=='" + conversationId + "'";
    SearchRequest searchRequest = SearchRequest.builder().query(query).topK(topK).filterExpression(filter).build();
    List<Document> documents = this.vectorStore.similaritySearch(searchRequest);
    String longTermMemory = documents == null ? ""
            : documents.stream().map(Document::getText).collect(Collectors.joining(System.lineSeparator()));
    SystemMessage systemMessage = request.prompt().getSystemMessage();
    String augmentedSystemText = this.systemPromptTemplate
            .render(Map.of("instructions", systemMessage.getText(), "long_term_memory", longTermMemory));
    ChatClientRequest processedChatClientRequest = request.mutate()
            .prompt(request.prompt().augmentSystemMessage(augmentedSystemText))
            .build();
    UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();
    if (userMessage != null) {
        this.vectorStore.write(this.toDocuments(List.of(userMessage), conversationId));
    }

    return processedChatClientRequest;
}
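To make the difference between these strategies concrete, here is a plain-Java sketch that mimics how MessageChatMemoryAdvisor and PromptChatMemoryAdvisor assemble the prompt, plus the conversation filter string built by VectorStoreChatMemoryAdvisor. It has no Spring dependency: Msg and Type are hypothetical stand-ins for Spring AI's message types, and the "MEMORY:" layout is an assumed template, not the library's default system prompt template.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

public class MemoryAdvisorSketch {

    enum Type { SYSTEM, USER, ASSISTANT }

    // Hypothetical stand-in for Spring AI's Message.
    record Msg(Type type, String text) {}

    // MessageChatMemoryAdvisor style: history messages go in front of the current prompt messages.
    static List<Msg> asMessageList(List<Msg> memory, List<Msg> instructions) {
        List<Msg> processed = new ArrayList<>(memory);
        processed.addAll(instructions);
        return processed;
    }

    // PromptChatMemoryAdvisor style: USER/ASSISTANT history becomes "TYPE:text" lines merged
    // into the system text. The "MEMORY:" layout is an assumed template, and "\n" replaces
    // System.lineSeparator() so the output is identical on every OS.
    static String asSystemText(List<Msg> memory, String systemText) {
        String memoryText = memory.stream()
                .filter(m -> m.type() == Type.USER || m.type() == Type.ASSISTANT)
                .map(m -> m.type() + ":" + m.text())
                .collect(Collectors.joining("\n"));
        return systemText + "\n\nMEMORY:\n" + memoryText;
    }

    // VectorStoreChatMemoryAdvisor style: the similarity search is restricted to one conversation.
    static String conversationFilter(String conversationId) {
        return "conversationId=='" + conversationId + "'";
    }

    public static void main(String[] args) {
        List<Msg> memory = List.of(new Msg(Type.USER, "My name is Tom"), new Msg(Type.ASSISTANT, "Hi Tom"));
        List<Msg> current = List.of(new Msg(Type.SYSTEM, "You are a helpful assistant"),
                new Msg(Type.USER, "What is my name?"));
        System.out.println(asMessageList(memory, current));
        System.out.println(asSystemText(memory, "You are a helpful assistant"));
        System.out.println(conversationFilter("c1"));
    }
}
```

The first approach keeps each historical message as its own message with its role intact; the second flattens the history into the system message, so roles survive only as "USER:" and "ASSISTANT:" prefixes in the text.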

4. Implementation

A chat memory demo built with JdbcChatMemoryRepository and MessageChatMemoryAdvisor.

POM dependency:

<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-starter-model-chat-memory-repository-jdbc</artifactId>
</dependency>

application.properties:

spring.ai.chat.memory.repository.jdbc.platform = mysql
# Whether to run the schema initialization script
spring.ai.chat.memory.repository.jdbc.initialize-schema = always
# Location of the schema script
spring.ai.chat.memory.repository.jdbc.schema=classpath:schema/memory-mysql.sql

spring.datasource.url=jdbc:mysql://localhost:3306/test_db?useUnicode=true&characterEncoding=utf-8&useSSL=false&serverTimezone=UTC
spring.datasource.username=root
spring.datasource.password=root
spring.datasource.driver-class-name=com.mysql.cj.jdbc.Driver

SQL script:

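create table spring_ai_chat_memory
(
    id              int auto_increment
        primary key,
    -- conversation IDs are strings (e.g. ChatReq.memoryId), so use varchar rather than int
    conversation_id varchar(64) null comment 'conversation ID',
    type            varchar(64) null comment 'message type',
    content         text        null,
    timestamp       timestamp   null
);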

Configuration class:

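import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.ai.chat.memory.repository.jdbc.JdbcChatMemoryRepository;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
public class OllamaConfig {

    // Provided by the JDBC chat memory repository starter
    @Autowired
    private JdbcChatMemoryRepository jdbcChatMemoryRepository;

    // Keep at most the 10 most recent messages per conversation (system messages are preserved)
    @Bean
    public ChatMemory chatMemory() {
        return MessageWindowChatMemory.builder()
                .chatMemoryRepository(jdbcChatMemoryRepository)
                .maxMessages(10)
                .build();
    }

    // Every request through this ChatClient loads and saves history via the memory advisor
    @Bean
    public ChatClient chatClient(OllamaChatModel ollamaChatModel) {
        return ChatClient.builder(ollamaChatModel)
                .defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory()).build())
                .build();
    }

}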

Request body:

@Data
public class ChatReq {
    private String memoryId;
    private String message;
}

Controller:

    @PostMapping("chat")
    public Flux<String> chat(@RequestBody ChatReq chatReq) {
        Prompt prompt = new Prompt(chatReq.getMessage(), OllamaChatOptions.builder().model("qwen3:8b").disableThinking().build());

        return chatClient.prompt(prompt)
                .advisors(a -> a.param(ChatMemory.CONVERSATION_ID, chatReq.getMemoryId()))
                .stream().content();
    }

Data stored in the database:
[Screenshot of the database records]

posted @ 2025-12-18 19:10  0xCAFEBABE_001