会话记忆持久化
Spring AI 中 ChatMemory 的持久化存储有两种方案:
方案一:使用MessageWindowChatMemory进行持久化
问题:MessageWindowChatMemory 的 add 方法会先把数据库中已有的消息与新消息合并成 processedMessages,然后再调用 repository 的 saveAll 方法。这样一来,repository 想要正确存储,只能二选一:
- 获取 processedMessages 中的最后一个,然后进行存储;
- 删除数据库中该 chatId 原有的所有数据,然后重新存储全部消息。
方案二:自己实现 ChatMemory 和 ChatMemoryRepository(即下文的 MyChatMemory 与 MyChatMemoryRepository)
IMyChatMemoryRepository
参考 Spring AI 的 ChatMemoryRepository,为 findByConversationId 方法增加了参数 maxMessages,用于在从数据库查询会话记录时进行 limit,限制返回的消息条数。
import org.springframework.ai.chat.messages.Message;
import java.util.List;
/**
 * Persistence contract for chat memory, modeled on Spring AI's {@code ChatMemoryRepository}.
 *
 * <p>The one deliberate difference is {@link #findByConversationId(String, Integer)}: it takes a
 * {@code maxMessages} argument so an implementation can limit how many stored messages it loads.
 */
public interface IMyChatMemoryRepository {

    /**
     * Lists the identifiers of every conversation that has stored messages.
     *
     * @return the known conversation ids
     */
    List<String> findConversationIds();

    /**
     * Loads the stored messages of a single conversation.
     *
     * @param conversationId id of the conversation to read
     * @param maxMessages upper bound on how many messages to load; how {@code null} is treated is
     *     left to the implementation (NOTE(review): confirm the intended meaning of null)
     * @return the messages of that conversation
     */
    List<Message> findByConversationId(String conversationId, Integer maxMessages);

    /**
     * Persists the given messages under a conversation id.
     *
     * @param conversationId id of the conversation the messages belong to
     * @param messages messages to store
     */
    void saveAll(String conversationId, List<Message> messages);

    /**
     * Removes every stored message of a conversation.
     *
     * @param conversationId id of the conversation to clear
     */
    void deleteByConversationId(String conversationId);
}
MyChatMemoryRepository
由于 chatMemory 调用 saveAll 方法时传递过来的只是最新消息的列表,所以在 repository 中直接插入即可,无需先删除旧数据。
import com.tang.springaialibabademo.mapper.ChatMessageMapper;
import com.tang.springaialibabademo.model.ChatMessageEntity;
import jakarta.annotation.Resource;
import org.jetbrains.annotations.NotNull;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
/**
 * Database-backed {@link IMyChatMemoryRepository} that stores one {@link ChatMessageEntity} row
 * per message (conversation id, message type name, message text) via {@link ChatMessageMapper}.
 *
 * <p>Only message types that {@link #findByConversationId(String, Integer)} can rebuild
 * (USER, ASSISTANT, SYSTEM) are written. Anything else is skipped on save, so a stored
 * conversation can always be read back.
 */
@Component
public class MyChatMemoryRepository implements IMyChatMemoryRepository {

    /**
     * Message type names (as produced by {@code message.getMessageType().toString()}) that
     * {@link #toMessage(ChatMessageEntity)} knows how to turn back into a {@link Message}.
     */
    private static final List<String> RESTORABLE_TYPES = List.of("USER", "ASSISTANT", "SYSTEM");

    @Resource
    private ChatMessageMapper chatMessageMapper;

    /**
     * Returns the chat ids known to the mapper.
     *
     * @return all stored conversation ids
     */
    @NotNull
    @Override
    public List<String> findConversationIds() {
        return chatMessageMapper.selectAllChatId();
    }

    /**
     * Loads the messages of a conversation and converts each row back into a Spring AI message.
     *
     * <p>Row order and how {@code maxMessages} is applied are decided by
     * {@code ChatMessageMapper.selectByChatId} (NOTE(review): confirm it returns the most recent
     * messages in chronological order).
     *
     * @param conversationId id of the conversation to read
     * @param maxMessages limit passed straight through to the mapper
     * @return the converted messages
     * @throws IllegalArgumentException if a stored row has a missing or unsupported message type
     */
    @NotNull
    @Override
    public List<Message> findByConversationId(@NotNull String conversationId, Integer maxMessages) {
        List<ChatMessageEntity> chatMessageEntities =
                chatMessageMapper.selectByChatId(conversationId, maxMessages);
        return chatMessageEntities.stream()
                .map(MyChatMemoryRepository::toMessage)
                .collect(Collectors.toList());
    }

    /**
     * Inserts the given messages as new rows; existing rows of the conversation are kept.
     *
     * <p>Messages whose type cannot be restored by {@link #findByConversationId(String, Integer)}
     * (for example TOOL) are skipped. Storing them would make every later read of this
     * conversation throw.
     *
     * @param conversationId id of the conversation the messages belong to
     * @param messages messages to append; if nothing remains after filtering, no insert is issued
     */
    @Override
    public void saveAll(@NotNull String conversationId, @NotNull List<Message> messages) {
        List<ChatMessageEntity> entityList = new ArrayList<>(messages.size());
        for (Message message : messages) {
            String messageType = message.getMessageType().toString();
            if (!RESTORABLE_TYPES.contains(messageType)) {
                continue;
            }
            ChatMessageEntity entity = new ChatMessageEntity();
            entity.setChatId(conversationId);
            entity.setMessageType(messageType);
            entity.setMessage(message.getText());
            entityList.add(entity);
        }
        // Avoid a batch insert with no rows (the original returned early for empty input too).
        if (entityList.isEmpty()) {
            return;
        }
        chatMessageMapper.insertBatch(entityList);
    }

    /**
     * Deletes every stored row of a conversation.
     *
     * @param conversationId id of the conversation to clear
     */
    @Override
    public void deleteByConversationId(@NotNull String conversationId) {
        chatMessageMapper.deleteByChatId(conversationId);
    }

    /** Rebuilds the Spring AI message represented by one stored row. */
    private static Message toMessage(ChatMessageEntity entity) {
        String messageType = entity.getMessageType();
        String content = entity.getMessage();
        if (messageType == null) {
            // A String switch on null throws a bare NullPointerException; report it like any bad type.
            throw new IllegalArgumentException("Unknown message type: null");
        }
        return switch (messageType) {
            case "USER" -> new UserMessage(content);
            case "ASSISTANT" -> new AssistantMessage(content);
            case "SYSTEM" -> new SystemMessage(content);
            default -> throw new IllegalArgumentException("Unknown message type: " + messageType);
        };
    }
}
MyChatMemory
参考 MessageWindowChatMemory,使用 builder 模式创建对象。
import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.util.Assert;
import java.util.List;
/**
 * {@link ChatMemory} that hands messages straight to an {@link IMyChatMemoryRepository}.
 *
 * <p>Unlike Spring AI's {@code MessageWindowChatMemory}, {@link #add(String, List)} does not
 * merge new messages with the stored history before saving. Only the messages passed in are
 * given to {@code saveAll}. The window size is applied when reading: {@link #get(String)}
 * passes {@code maxMessages} to the repository. Create instances with {@link #builder()}.
 */
@Slf4j
public class MyChatMemory implements ChatMemory {

    private static final int DEFAULT_MAX_MESSAGES = 20;

    /** Typed as the interface so any implementation (or a test double) can be plugged in. */
    private final IMyChatMemoryRepository myChatMemoryRepository;

    private final int maxMessages;

    private MyChatMemory(IMyChatMemoryRepository myChatMemoryRepository, int maxMessages) {
        Assert.notNull(myChatMemoryRepository, "chatMemoryRepository cannot be null");
        Assert.isTrue(maxMessages > 0, "maxMessages must be greater than 0");
        this.myChatMemoryRepository = myChatMemoryRepository;
        this.maxMessages = maxMessages;
    }

    /**
     * Persists the given messages for a conversation.
     *
     * @param conversationId non-blank conversation id
     * @param messages messages to store; must not be null or contain null elements
     * @throws IllegalArgumentException if any argument check fails
     */
    @Override
    public void add(@NotNull String conversationId, @NotNull List<Message> messages) {
        Assert.hasText(conversationId, "conversationId cannot be null or empty");
        Assert.notNull(messages, "messages cannot be null");
        Assert.noNullElements(messages, "messages cannot contain null elements");
        // Message bodies are user/assistant content: keep them out of INFO-level logs.
        log.info("add {} messages to conversationId: {}", messages.size(), conversationId);
        log.debug("add messages to conversationId: {}, messages: {}", conversationId, messages);
        myChatMemoryRepository.saveAll(conversationId, messages);
    }

    /**
     * Loads at most {@code maxMessages} stored messages of a conversation (the limit is
     * enforced by the repository).
     *
     * @param conversationId non-blank conversation id
     * @return the repository's messages for that conversation
     * @throws IllegalArgumentException if {@code conversationId} is blank
     */
    @NotNull
    @Override
    public List<Message> get(@NotNull String conversationId) {
        Assert.hasText(conversationId, "conversationId cannot be null or empty");
        log.info("get messages from conversationId: {}", conversationId);
        return myChatMemoryRepository.findByConversationId(conversationId, maxMessages);
    }

    /**
     * Deletes all stored messages of a conversation.
     *
     * @param conversationId non-blank conversation id
     * @throws IllegalArgumentException if {@code conversationId} is blank
     */
    @Override
    public void clear(@NotNull String conversationId) {
        Assert.hasText(conversationId, "conversationId cannot be null or empty");
        log.info("clear messages from conversationId: {}", conversationId);
        myChatMemoryRepository.deleteByConversationId(conversationId);
    }

    /** Starts building a {@link MyChatMemory}; {@code maxMessages} defaults to 20. */
    public static Builder builder() {
        return new Builder();
    }

    /** Builder for {@link MyChatMemory}; a repository must be supplied before {@link #build()}. */
    public static final class Builder {

        private IMyChatMemoryRepository myChatMemoryRepository;

        private int maxMessages = DEFAULT_MAX_MESSAGES;

        private Builder() {
        }

        /** Sets the repository used for all reads and writes (required). */
        public Builder chatMemoryRepository(IMyChatMemoryRepository myChatMemoryRepository) {
            this.myChatMemoryRepository = myChatMemoryRepository;
            return this;
        }

        /** Sets the read limit passed to the repository; must be greater than 0. */
        public Builder maxMessages(int maxMessages) {
            this.maxMessages = maxMessages;
            return this;
        }

        /**
         * Creates the memory.
         *
         * @throws IllegalArgumentException if no repository was set or {@code maxMessages <= 0}
         */
        public MyChatMemory build() {
            return new MyChatMemory(this.myChatMemoryRepository, this.maxMessages);
        }
    }
}
ChatMessageEntity
用于存储消息的实体类定义如下:
/**
 * One persisted chat message, mapped to the {@code chat_message} table.
 *
 * <p>NOTE(review): the {@code @Table}/{@code @Id(keyType = KeyType.Auto)}/{@code @Column}
 * annotations look like MyBatis-Flex; their imports are not shown here, so confirm.
 * {@code @Data} (Lombok) generates the getters/setters used by MyChatMemoryRepository.
 */
@Data
@Table(value = "chat_message")
public class ChatMessageEntity {
// Auto-generated primary key.
// NOTE(review): if ChatMessageMapper orders a conversation's rows by id, this defines message order; confirm.
@Id(keyType = KeyType.Auto)
private Long id;
// Message text; MyChatMemoryRepository fills it from Message#getText().
@Column(value = "message")
private String message;
// Conversation id the message belongs to.
@Column(value = "chat_id")
private String chatId;
// Message type name written by MyChatMemoryRepository (message.getMessageType().toString());
// only "USER", "ASSISTANT" and "SYSTEM" are converted back when reading.
@Column(value = "message_type")
private String messageType;
}
调用
在调用方的类中,通过构造方法创建 ChatMemory 并注册到 ChatClient:
public TestAgent(ChatModel dashscopeChatMode,
MyChatMemoryRepository myChatMemoryRepository) {
MyChatMemory myChatMemory = MyChatMemory.builder()
.chatMemoryRepository(myChatMemoryRepository)
.build();
this.chatClient = ChatClient.builder(dashscopeChatMode)
.defaultAdvisors(
MessageChatMemoryAdvisor.builder(myChatMemory).build()
)
.build();
}

浙公网安备 33010602011771号