会话记忆持久化

Spring AI 对于 ChatMemory 的持久化存储,有两种方案

方案一:使用MessageWindowChatMemory进行持久化

问题:其中的 add 方法会把所有的消息进行合并,然后再调用 repository 的 saveAll 方法,这样想要存储,只能采用下面两种做法之一:

  • 获取processedMessages中的最后一个,然后进行存储
  • 删除数据库中该 chatId 原有的所有数据,然后进行存储

方案二:直接自己重写 ChatMemory 和 ChatMemoryRepository

IMyChatMemoryRepository

参考了 ChatMemoryRepository,对于 findByConversationId 方法添加了参数 maxMessages,作用是获取数据库中会话记录时,可以通过 limit 限制返回的消息条数

import org.springframework.ai.chat.messages.Message;
import java.util.List;

/**
 * Storage contract for chat messages grouped by conversation id.
 *
 * <p>Mirrors the shape of Spring AI's {@code ChatMemoryRepository}, with an extra
 * {@code maxMessages} argument on {@link #findByConversationId(String, Integer)}.
 */
public interface IMyChatMemoryRepository {
    /**
     * Lists the ids of the conversations known to the store.
     *
     * @return the conversation ids
     */
    List<String> findConversationIds();

    /**
     * Loads the messages stored for one conversation.
     *
     * @param conversationId id of the conversation to read
     * @param maxMessages    upper bound on the number of messages to load; intended to be
     *                       applied as a query limit by the implementation
     *                       (NOTE(review): handling of {@code null} is implementation-defined — confirm)
     * @return the stored messages for the conversation
     */
    List<Message> findByConversationId(String conversationId, Integer maxMessages);

    /**
     * Persists the given messages under the conversation id.
     *
     * @param conversationId id of the conversation the messages belong to
     * @param messages       messages to persist
     */
    void saveAll(String conversationId, List<Message> messages);

    /**
     * Removes the messages stored for the conversation.
     *
     * @param conversationId id of the conversation to delete
     */
    void deleteByConversationId(String conversationId);
}

MyChatMemoryRepository

由于 MyChatMemory 调用 saveAll 方法时传递过来的只是最新消息的列表,所以在 repository 中直接进行存储即可

import com.tang.springaialibabademo.mapper.ChatMessageMapper;
import com.tang.springaialibabademo.model.ChatMessageEntity;
import jakarta.annotation.Resource;
import org.jetbrains.annotations.NotNull;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.stereotype.Component;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

/**
 * Database-backed {@link IMyChatMemoryRepository} that converts between Spring AI
 * {@link Message} objects and {@link ChatMessageEntity} rows via {@link ChatMessageMapper}.
 *
 * <p>Only the message type and text are persisted; the type is stored as the enum
 * constant name (for example {@code "USER"}), which is exactly what
 * {@link #findByConversationId(String, Integer)} switches on when reading rows back.
 */
@Component
public class MyChatMemoryRepository implements IMyChatMemoryRepository {

    @Resource
    private ChatMessageMapper chatMessageMapper;

    /**
     * Lists the conversation ids returned by the mapper.
     *
     * @return the conversation ids
     */
    @NotNull
    @Override
    public List<String> findConversationIds() {
        return chatMessageMapper.selectAllChatId();
    }

    /**
     * Loads the stored rows for a conversation and rebuilds the matching message objects.
     *
     * @param conversationId id of the conversation to read
     * @param maxMessages    forwarded unchanged to the mapper query
     * @return one message per stored row, in the order the mapper returned the rows
     * @throws IllegalArgumentException if a row has a missing or unsupported message type
     */
    @NotNull
    @Override
    public List<Message> findByConversationId(@NotNull String conversationId, Integer maxMessages) {
        return chatMessageMapper.selectByChatId(conversationId, maxMessages).stream()
                .map(MyChatMemoryRepository::toMessage)
                .collect(Collectors.toList());
    }

    /**
     * Inserts one row per message; does nothing when the list is empty.
     *
     * @param conversationId id of the conversation the messages belong to
     * @param messages       messages to persist
     */
    @Override
    public void saveAll(@NotNull String conversationId, @NotNull List<Message> messages) {
        if (messages.isEmpty()) {
            return;
        }
        List<ChatMessageEntity> entityList = messages.stream()
                .map(message -> toEntity(conversationId, message))
                .collect(Collectors.toList());
        chatMessageMapper.insertBatch(entityList);
    }

    /**
     * Deletes the rows stored for the conversation.
     *
     * @param conversationId id of the conversation to delete
     */
    @Override
    public void deleteByConversationId(@NotNull String conversationId) {
        chatMessageMapper.deleteByChatId(conversationId);
    }

    /** Rebuilds the message object that corresponds to a stored row. */
    private static Message toMessage(ChatMessageEntity entity) {
        String messageType = entity.getMessageType();
        // A string switch on null throws a bare NullPointerException; report it like any other bad type.
        if (messageType == null) {
            throw new IllegalArgumentException("Unknown message type: " + messageType);
        }
        String content = entity.getMessage();
        return switch (messageType) {
            case "USER" -> new UserMessage(content);
            case "ASSISTANT" -> new AssistantMessage(content);
            case "SYSTEM" -> new SystemMessage(content);
            default -> throw new IllegalArgumentException("Unknown message type: " + messageType);
        };
    }

    /** Builds the row to insert for a single message of the given conversation. */
    private static ChatMessageEntity toEntity(String conversationId, Message message) {
        ChatMessageEntity entity = new ChatMessageEntity();
        entity.setChatId(conversationId);
        // name() is final on every enum, so the stored value always matches the case labels in toMessage.
        entity.setMessageType(message.getMessageType().name());
        entity.setMessage(message.getText());
        return entity;
    }
}

MyChatMemory

参考了 MessageWindowChatMemory,使用 builder 进行对象的创建

import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.util.Assert;

import java.util.List;

/**
 * {@link ChatMemory} that hands every operation straight to a {@link MyChatMemoryRepository}.
 *
 * <p>{@code add} stores exactly the messages it receives, {@code get} reads at most
 * {@code maxMessages} messages back, and {@code clear} deletes the conversation.
 * Instances are created through {@link #builder()}.
 */
@Slf4j
public class MyChatMemory implements ChatMemory {
    private static final int DEFAULT_MAX_MESSAGES = 20;

    private final MyChatMemoryRepository repository;
    private final Integer messageLimit;

    /**
     * Starts building a new instance.
     *
     * @return a builder whose message limit defaults to {@code DEFAULT_MAX_MESSAGES}
     */
    public static Builder builder() {
        return new Builder();
    }

    private MyChatMemory(MyChatMemoryRepository repository, Integer messageLimit) {
        Assert.notNull(repository, "chatMemoryRepository cannot be null");
        Assert.isTrue(messageLimit > 0, "maxMessages must be greater than 0");

        this.repository = repository;
        this.messageLimit = messageLimit;
    }

    /**
     * Stores the given messages for the conversation.
     *
     * @param conversationId non-blank conversation id
     * @param messages       messages to store; must not be null or contain null elements
     */
    @Override
    public void add(@NotNull String conversationId, @NotNull List<Message> messages) {
        Assert.hasText(conversationId, "conversationId cannot be null or empty");
        Assert.notNull(messages, "messages cannot be null");
        Assert.noNullElements(messages, "messages cannot contain null elements");

        log.info("add messages to conversationId: {}, messages: {}", conversationId, messages);
        this.repository.saveAll(conversationId, messages);
    }

    /**
     * Reads the stored messages of the conversation, passing the configured limit to the repository.
     *
     * @param conversationId non-blank conversation id
     * @return the messages returned by the repository
     */
    @NotNull
    @Override
    public List<Message> get(@NotNull String conversationId) {
        Assert.hasText(conversationId, "conversationId cannot be null or empty");

        log.info("get messages from conversationId: {}", conversationId);
        List<Message> history = this.repository.findByConversationId(conversationId, this.messageLimit);
        return history;
    }

    /**
     * Deletes everything stored for the conversation.
     *
     * @param conversationId non-blank conversation id
     */
    @Override
    public void clear(@NotNull String conversationId) {
        Assert.hasText(conversationId, "conversationId cannot be null or empty");

        log.info("clear messages from conversationId: {}", conversationId);
        this.repository.deleteByConversationId(conversationId);
    }

    /** Fluent builder for {@link MyChatMemory}. */
    public static final class Builder {
        private MyChatMemoryRepository repositoryToUse;
        private int limitToUse = DEFAULT_MAX_MESSAGES;

        private Builder() {
        }

        /**
         * Sets the repository the memory delegates to.
         *
         * @param myChatMemoryRepository repository to use; validated when {@link #build()} runs
         * @return this builder
         */
        public Builder chatMemoryRepository(MyChatMemoryRepository myChatMemoryRepository) {
            this.repositoryToUse = myChatMemoryRepository;
            return this;
        }

        /**
         * Sets how many messages {@code get} asks the repository for.
         *
         * @param maxMessages message limit; must be greater than 0 when {@link #build()} runs
         * @return this builder
         */
        public Builder maxMessages(int maxMessages) {
            this.limitToUse = maxMessages;
            return this;
        }

        /**
         * Creates the configured memory.
         *
         * @return a new {@link MyChatMemory}
         * @throws IllegalArgumentException if no repository was set or the limit is not positive
         */
        public MyChatMemory build() {
            return new MyChatMemory(this.repositoryToUse, this.limitToUse);
        }
    }
}

ChatMessageEntity

对于存储的实体类定义如下

/**
 * Row of the {@code chat_message} table: one persisted chat message.
 */
@Data
@Table(value = "chat_message")
public class ChatMessageEntity {
    /** Primary key; {@code KeyType.Auto} — presumably generated by the database, confirm against the schema. */
    @Id(keyType = KeyType.Auto)
    private Long id;
    /** Text content of the message (column {@code message}). */
    @Column(value = "message")
    private String message;
    /** Id of the conversation the message belongs to (column {@code chat_id}). */
    @Column(value = "chat_id")
    private String chatId;
    /** Message type as a string, e.g. {@code "USER"}, {@code "ASSISTANT"} or {@code "SYSTEM"} (column {@code message_type}). */
    @Column(value = "message_type")
    private String messageType;
}

调用

在调用的类中,使用构造方法

 public TestAgent(ChatModel dashscopeChatMode,
                     MyChatMemoryRepository myChatMemoryRepository) {

        MyChatMemory myChatMemory = MyChatMemory.builder()
                .chatMemoryRepository(myChatMemoryRepository)
                .build();
        this.chatClient = ChatClient.builder(dashscopeChatMode)
                .defaultAdvisors(
                        MessageChatMemoryAdvisor.builder(myChatMemory).build()
                )
                .build();
    }
posted @ 2025-12-24 15:49  棠仔517890027  阅读(3)  评论(0)    收藏  举报