SpringAI-ChatMemory
Official documentation: https://docs.spring.io/spring-ai/reference/api/chat-memory.html
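1. Description
LLMs are stateless: they do not retain any information from previous interactions. To address this, Spring AI provides chat memory, which stores and retrieves information across multiple interactions with an LLM.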
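2. Interfaces and classes
- ChatMemory is the top-level interface. It provides three operations, each keyed by a conversation ID: add() to store messages, get() to retrieve them, and clear() to delete them.
- MessageWindowChatMemory, the default ChatMemory implementation, manages message storage and retrieval by maintaining a message window up to a configured maximum size. When the number of messages exceeds the maximum, older messages are removed while system messages are preserved. The default window size is 20 messages.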
- ChatMemoryRepository is the interface for storing and querying conversation messages. Its methods:
public interface ChatMemoryRepository {
    List<String> findConversationIds();
    List<Message> findByConversationId(String conversationId);
    void saveAll(String conversationId, List<Message> messages);
    void deleteByConversationId(String conversationId);
}
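- InMemoryChatMemoryRepository, the default ChatMemoryRepository implementation, stores messages in memory.
To store messages through JDBC instead, add the following dependency. The implementation class is JdbcChatMemoryRepository: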
<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-starter-model-chat-memory-repository-jdbc</artifactId>
</dependency>
Configuration:
spring.ai.chat.memory.repository.jdbc.platform = mysql
spring.datasource.url=jdbc:mysql://localhost:3306/test_db?useUnicode=true&characterEncoding=utf-8&useSSL=false&serverTimezone=UTC
spring.datasource.username=root
spring.datasource.password=root
spring.datasource.driver-class-name=com.mysql.cj.jdbc.Driver
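Java configuration:
// DateTimeTools is the author's own tool class (not shown). MilvusVectorStore is injected but not used in this snippet
@Autowired
private DateTimeTools dateTimeTools;
@Autowired
private JdbcChatMemoryRepository jdbcChatMemoryRepository;
@Autowired
private MilvusVectorStore milvusVectorStore;

// Window memory backed by the JDBC repository, keeping at most 10 messages per conversation
@Bean
public ChatMemory chatMemory() {
    return MessageWindowChatMemory.builder()
            .chatMemoryRepository(jdbcChatMemoryRepository)
            .maxMessages(10)
            .build();
}

// Inject the ChatMemory bean and attach it to the ChatClient through MessageChatMemoryAdvisor
@Bean
public ChatClient chatClient(OllamaChatModel ollamaChatModel, ChatMemory chatMemory) {
    return ChatClient.builder(ollamaChatModel)
            .defaultTools(dateTimeTools)
            .defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).build())
            .build();
}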
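The window behavior of MessageWindowChatMemory and its split from the repository can be modeled in plain Java with no framework dependencies. This is a simplified sketch, not Spring AI's source. Msg, Type, InMemoryRepo and WindowMemory are hypothetical names. InMemoryRepo mirrors the four ChatMemoryRepository methods. The real MessageWindowChatMemory has additional handling that this sketch omits.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class WindowMemorySketch {

    // Hypothetical stand-ins for Spring AI's MessageType and Message
    enum Type { SYSTEM, USER, ASSISTANT }

    record Msg(Type type, String text) {}

    // Mirrors the four ChatMemoryRepository methods, storing everything in a map
    static class InMemoryRepo {
        private final Map<String, List<Msg>> store = new ConcurrentHashMap<>();

        List<String> findConversationIds() {
            return new ArrayList<>(store.keySet());
        }

        List<Msg> findByConversationId(String conversationId) {
            return store.getOrDefault(conversationId, List.of());
        }

        void saveAll(String conversationId, List<Msg> messages) {
            store.put(conversationId, List.copyOf(messages));
        }

        void deleteByConversationId(String conversationId) {
            store.remove(conversationId);
        }
    }

    // Simplified window memory: add/get/clear on top of the repository
    static class WindowMemory {
        private final InMemoryRepo repo;
        private final int maxMessages;

        WindowMemory(InMemoryRepo repo) {
            this(repo, 20); // same default window size as MessageWindowChatMemory
        }

        WindowMemory(InMemoryRepo repo, int maxMessages) {
            this.repo = repo;
            this.maxMessages = maxMessages;
        }

        void add(String conversationId, Msg message) {
            List<Msg> all = new ArrayList<>(repo.findByConversationId(conversationId));
            all.add(message);
            int toRemove = all.size() - maxMessages;
            List<Msg> kept = new ArrayList<>();
            for (Msg m : all) {
                if (toRemove > 0 && m.type() != Type.SYSTEM) {
                    toRemove--; // evict the oldest non-system message
                } else {
                    kept.add(m); // system messages and newer messages survive
                }
            }
            repo.saveAll(conversationId, kept);
        }

        List<Msg> get(String conversationId) {
            return repo.findByConversationId(conversationId);
        }

        void clear(String conversationId) {
            repo.deleteByConversationId(conversationId);
        }
    }

    public static void main(String[] args) {
        WindowMemory memory = new WindowMemory(new InMemoryRepo(), 3);
        memory.add("c1", new Msg(Type.SYSTEM, "You are a helpful assistant."));
        memory.add("c1", new Msg(Type.USER, "q1"));
        memory.add("c1", new Msg(Type.ASSISTANT, "a1"));
        memory.add("c1", new Msg(Type.USER, "q2"));
        // With a window of 3, q1 is evicted while the system message is kept
        System.out.println(memory.get("c1"));
    }
}
```

The eviction loop skips system messages, so the window trims only conversation turns. This matches the documented rule that system messages are preserved when the window overflows.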
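3. Chat memory advisors
Spring AI provides the following built-in advisors:
- MessageChatMemoryAdvisor. This advisor manages conversation memory using the provided ChatMemory implementation. On each interaction, it retrieves the conversation history from memory and includes it in the prompt as a collection of messages.
The source code shows that it looks up the chat memory by conversation ID and places it in the prompt as a collection of messages, ahead of the current request's messages. Key code: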
public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) {
    String conversationId = getConversationId(chatClientRequest.context(), this.defaultConversationId);
    // 1. Retrieve the chat memory for the current conversation.
    List<Message> memoryMessages = this.chatMemory.get(conversationId);
    // 2. Advise the request messages list.
    List<Message> processedMessages = new ArrayList<>(memoryMessages);
    processedMessages.addAll(chatClientRequest.prompt().getInstructions());
    // 3. Create a new request with the advised messages.
    ChatClientRequest processedChatClientRequest = chatClientRequest.mutate()
            .prompt(chatClientRequest.prompt().mutate().messages(processedMessages).build())
            .build();
    // 4. Add the new user message to the conversation memory.
    UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();
    this.chatMemory.add(conversationId, userMessage);
    return processedChatClientRequest;
}
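- PromptChatMemoryAdvisor. This advisor manages conversation memory using the provided ChatMemory implementation. On each interaction, it retrieves the conversation history from memory and appends it to the system prompt as plain text.
The source code shows that it retrieves the chat memory by conversation ID and keeps only the USER and ASSISTANT messages. It then converts them to a string and renders that string, together with the system message text, into the system prompt template. Key code: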
public ChatClientRequest before(ChatClientRequest chatClientRequest, AdvisorChain advisorChain) {
    String conversationId = getConversationId(chatClientRequest.context(), this.defaultConversationId);
    // 1. Retrieve the chat memory for the current conversation.
    List<Message> memoryMessages = this.chatMemory.get(conversationId);
    logger.debug("[PromptChatMemoryAdvisor.before] Memory before processing for conversationId={}: {}",
            conversationId, memoryMessages);
    // 2. Process memory messages as a string.
    String memory = memoryMessages.stream()
            .filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT)
            .map(m -> m.getMessageType() + ":" + m.getText())
            .collect(Collectors.joining(System.lineSeparator()));
    // 3. Augment the system message.
    SystemMessage systemMessage = chatClientRequest.prompt().getSystemMessage();
    String augmentedSystemText = this.systemPromptTemplate
            .render(Map.of("instructions", systemMessage.getText(), "memory", memory));
    // 4. Create a new request with the augmented system message.
    ChatClientRequest processedChatClientRequest = chatClientRequest.mutate()
            .prompt(chatClientRequest.prompt().augmentSystemMessage(augmentedSystemText))
            .build();
    // 5. Add the new user message to the conversation memory (after the system message is generated).
    UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();
    this.chatMemory.add(conversationId, userMessage);
    return processedChatClientRequest;
}
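- VectorStoreChatMemoryAdvisor. This advisor manages conversation memory using the provided VectorStore implementation. On each interaction, it retrieves conversation history from the vector store and appends it to the system message as plain text.
The source code shows that it runs a similarity search against the vector store, using the user message as the query and filtering by conversation ID. It renders the results, together with the system message text, into the system prompt template, then writes the current user message to the vector store. Key code: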
public ChatClientRequest before(ChatClientRequest request, AdvisorChain advisorChain) {
    String conversationId = this.getConversationId(request.context(), this.defaultConversationId);
    // The current user message is the similarity-search query
    String query = request.prompt().getUserMessage() != null ? request.prompt().getUserMessage().getText() : "";
    int topK = this.getChatMemoryTopK(request.context());
    // Restrict the search to documents that belong to this conversation
    String filter = "conversationId=='" + conversationId + "'";
    SearchRequest searchRequest = SearchRequest.builder().query(query).topK(topK).filterExpression(filter).build();
    List<Document> documents = this.vectorStore.similaritySearch(searchRequest);
    String longTermMemory = documents == null ? ""
            : documents.stream().map(Document::getText).collect(Collectors.joining(System.lineSeparator()));
    SystemMessage systemMessage = request.prompt().getSystemMessage();
    String augmentedSystemText = this.systemPromptTemplate
            .render(Map.of("instructions", systemMessage.getText(), "long_term_memory", longTermMemory));
    ChatClientRequest processedChatClientRequest = request.mutate()
            .prompt(request.prompt().augmentSystemMessage(augmentedSystemText))
            .build();
    // Store the current user message in the vector store for later retrieval
    UserMessage userMessage = processedChatClientRequest.prompt().getUserMessage();
    if (userMessage != null) {
        this.vectorStore.write(this.toDocuments(List.of(userMessage), conversationId));
    }
    return processedChatClientRequest;
}
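The three advisors differ mainly in where the history lands in the prompt. The following framework-free Java sketch reproduces just that prompt-building step. Msg, Type and every method name are hypothetical. augmentSystemText uses an invented template in place of the advisors' real systemPromptTemplate, and no vector search is performed. Only the conversation filter string is reproduced.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

public class AdvisorStrategiesSketch {

    // Hypothetical stand-ins for Spring AI's MessageType and Message
    enum Type { SYSTEM, USER, ASSISTANT, TOOL }

    record Msg(Type type, String text) {}

    // MessageChatMemoryAdvisor style: history first, then the current prompt's messages
    static List<Msg> asMessages(List<Msg> memory, List<Msg> instructions) {
        List<Msg> out = new ArrayList<>(memory);
        out.addAll(instructions);
        return out;
    }

    // PromptChatMemoryAdvisor style: keep USER/ASSISTANT only and render each as "TYPE:text"
    static String asMemoryText(List<Msg> memory) {
        return memory.stream()
                .filter(m -> m.type() == Type.USER || m.type() == Type.ASSISTANT)
                .map(m -> m.type() + ":" + m.text())
                .collect(Collectors.joining(System.lineSeparator()));
    }

    // Invented template standing in for the advisor's systemPromptTemplate
    static String augmentSystemText(String instructions, String memoryText) {
        return instructions + System.lineSeparator() + "MEMORY:" + System.lineSeparator() + memoryText;
    }

    // VectorStoreChatMemoryAdvisor style: the metadata filter that scopes the search to one conversation
    static String conversationFilter(String conversationId) {
        return "conversationId=='" + conversationId + "'";
    }

    public static void main(String[] args) {
        List<Msg> memory = List.of(
                new Msg(Type.USER, "My name is Alice."),
                new Msg(Type.ASSISTANT, "Hi Alice."),
                new Msg(Type.TOOL, "tool output"));
        List<Msg> current = List.of(new Msg(Type.USER, "What is my name?"));

        // Message style: the model sees four separate messages
        System.out.println(asMessages(memory, current));
        // Prompt style: the model sees one system text containing the flattened history (TOOL dropped)
        System.out.println(augmentSystemText("You are a helpful assistant.", asMemoryText(memory)));
        // Vector style: the filter expression used for the similarity search
        System.out.println(conversationFilter("42"));
    }
}
```

Keeping history as separate messages preserves the roles for the model. Flattening it into the system text works with models or templates that handle a single system block better, at the cost of losing the message structure.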
4. Implementation
A chat memory demo built with JdbcChatMemoryRepository and MessageChatMemoryAdvisor.
pom dependency:
<dependency>
    <groupId>org.springframework.ai</groupId>
    <artifactId>spring-ai-starter-model-chat-memory-repository-jdbc</artifactId>
</dependency>
Properties configuration:
spring.ai.chat.memory.repository.jdbc.platform = mysql
# Whether to run the schema initialization script (always / embedded / never)
spring.ai.chat.memory.repository.jdbc.initialize-schema = always
# Location of the schema script
spring.ai.chat.memory.repository.jdbc.schema=classpath:schema/memory-mysql.sql
spring.datasource.url=jdbc:mysql://localhost:3306/test_db?useUnicode=true&characterEncoding=utf-8&useSSL=false&serverTimezone=UTC
spring.datasource.username=root
spring.datasource.password=root
spring.datasource.driver-class-name=com.mysql.cj.jdbc.Driver
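SQL script:
create table spring_ai_chat_memory
(
    id              int auto_increment primary key,
    -- conversation IDs are Strings in Spring AI (ChatReq.memoryId is a String), so store them as varchar
    conversation_id varchar(36) null comment 'conversation ID',
    type            varchar(64) null comment 'message type',
    content         text        null,
    timestamp       timestamp   null
);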
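Configuration class:
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.ai.chat.memory.repository.jdbc.JdbcChatMemoryRepository;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
@Configuration
public class OllamaConfig {
    @Autowired
    private JdbcChatMemoryRepository jdbcChatMemoryRepository;
    // Window memory persisted through JDBC, keeping at most 10 messages per conversation
    @Bean
    public ChatMemory chatMemory() {
        return MessageWindowChatMemory.builder()
                .chatMemoryRepository(jdbcChatMemoryRepository)
                .maxMessages(10)
                .build();
    }
    // ChatClient that loads and saves conversation history through MessageChatMemoryAdvisor
    @Bean
    public ChatClient chatClient(OllamaChatModel ollamaChatModel, ChatMemory chatMemory) {
        return ChatClient.builder(ollamaChatModel)
                .defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).build())
                .build();
    }
}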
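Request body:
@Data // Lombok generates getters and setters
public class ChatReq {
    // Conversation ID, passed to the advisor as ChatMemory.CONVERSATION_ID
    private String memoryId;
    // The user's message
    private String message;
}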
Controller:
@PostMapping("chat")
public Flux<String> chat(@RequestBody ChatReq chatReq) {
    // Per-request Ollama options: the qwen3:8b model with thinking disabled
    Prompt prompt = new Prompt(chatReq.getMessage(),
            OllamaChatOptions.builder().model("qwen3:8b").disableThinking().build());
    return chatClient.prompt(prompt)
            // memoryId becomes the conversation ID the memory advisor reads from and writes to
            .advisors(a -> a.param(ChatMemory.CONVERSATION_ID, chatReq.getMemoryId()))
            .stream()
            .content();
}
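To check that memory works, send two requests with the same memoryId and see whether the second answer uses the first turn. The following client is a sketch that makes several assumptions. The application is taken to run on http://localhost:8080 with no context path or class-level mapping in front of "chat". The memoryId "1" and the messages are made up. ChatMemoryClient and buildRequest are hypothetical names. The JSON field names come from ChatReq.

```java
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

public class ChatMemoryClient {

    // Assumption: local app on port 8080, controller mapped directly at /chat
    static final String ENDPOINT = "http://localhost:8080/chat";

    // Minimal escaping for quotes and backslashes only; not a full JSON encoder
    static String escape(String s) {
        return s.replace("\\", "\\\\").replace("\"", "\\\"");
    }

    // Builds a POST whose JSON body matches ChatReq {memoryId, message}
    static HttpRequest buildRequest(String memoryId, String message) {
        String json = "{\"memoryId\":\"" + escape(memoryId) + "\",\"message\":\"" + escape(message) + "\"}";
        return HttpRequest.newBuilder(URI.create(ENDPOINT))
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(json))
                .build();
    }

    public static void main(String[] args) throws Exception {
        HttpClient client = HttpClient.newHttpClient();
        // Same memoryId on both calls, so the second request is sent with the first turn in its history
        for (String message : new String[] {"My name is Alice.", "What is my name?"}) {
            HttpResponse<String> response =
                    client.send(buildRequest("1", message), HttpResponse.BodyHandlers.ofString());
            System.out.println(response.body());
        }
    }
}
```

The raw response body is printed. How the streamed Flux chunks are framed depends on the web stack and the Accept header. After both calls, the spring_ai_chat_memory table should contain rows for conversation "1".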
Data in the database:

This article is from cnblogs. Author: 0xCAFEBABE_001. Please credit the original link when reposting: https://www.cnblogs.com/0xcafebabe001/p/19368591
