Spring AI Alibaba 项目源码学习(二)-Graph 定义与描述分析
Graph 定义与描述分析
请关注公众号:阿呆-bot
概述
本文档分析 spring-ai-alibaba-graph-core 模块中 Graph 的定义和描述机制,包括入口类、关键类关系、核心实现代码和设计模式。
入口类说明
StateGraph - Graph 定义入口
StateGraph 是定义工作流图的主要入口类,用于构建包含节点(Node)和边(Edge)的状态图。
核心职责:
- 管理图中的节点集合(Nodes)
- 管理图中的边集合(Edges)
- 提供节点和边的添加方法
- 支持条件路由和子图
- 编译为可执行的 CompiledGraph
关键代码:
public class StateGraph {
/**
* Constant representing the END of the graph.
*/
public static final String END = "__END__";
/**
* Constant representing the START of the graph.
*/
public static final String START = "__START__";
/**
* Constant representing the ERROR of the graph.
*/
public static final String ERROR = "__ERROR__";
/**
* Constant representing the NODE_BEFORE of the graph.
*/
public static final String NODE_BEFORE = "__NODE_BEFORE__";
/**
* Constant representing the NODE_AFTER of the graph.
*/
public static final String NODE_AFTER = "__NODE_AFTER__";
/**
* Collection of nodes in the graph.
*/
final Nodes nodes = new Nodes();
/**
* Collection of edges in the graph.
*/
final Edges edges = new Edges();
/**
* Factory for providing key strategies.
*/
private KeyStrategyFactory keyStrategyFactory;
/**
* Name of the graph.
*/
private String name;
/**
* Serializer for the state.
*/
private final StateSerializer stateSerializer;
构造函数:
public StateGraph(String name, KeyStrategyFactory keyStrategyFactory, PlainTextStateSerializer stateSerializer) {
this.name = name;
this.keyStrategyFactory = keyStrategyFactory;
this.stateSerializer = stateSerializer;
}
public StateGraph(KeyStrategyFactory keyStrategyFactory, PlainTextStateSerializer stateSerializer) {
this.keyStrategyFactory = keyStrategyFactory;
this.stateSerializer = stateSerializer;
}
/**
* Constructs a StateGraph with the specified name, key strategy factory, and SpringAI
* state serializer.
* @param name the name of the graph
* @param keyStrategyFactory the factory for providing key strategies
* @param stateSerializer the SpringAI state serializer to use
*/
public StateGraph(String name, KeyStrategyFactory keyStrategyFactory, SpringAIStateSerializer stateSerializer) {
this.name = name;
this.keyStrategyFactory = keyStrategyFactory;
this.stateSerializer = stateSerializer;
}
/**
* Constructs a StateGraph with the specified key strategy factory and SpringAI state
* serializer.
* @param keyStrategyFactory the factory for providing key strategies
* @param stateSerializer the SpringAI state serializer to use
*/
public StateGraph(KeyStrategyFactory keyStrategyFactory, SpringAIStateSerializer stateSerializer) {
this.keyStrategyFactory = keyStrategyFactory;
this.stateSerializer = stateSerializer;
}
public StateGraph(String name, KeyStrategyFactory keyStrategyFactory) {
this.name = name;
this.keyStrategyFactory = keyStrategyFactory;
this.stateSerializer = new JacksonSerializer();
}
/**
* Constructs a StateGraph with the provided key strategy factory.
* @param keyStrategyFactory the factory for providing key strategies
*/
public StateGraph(KeyStrategyFactory keyStrategyFactory) {
this.keyStrategyFactory = keyStrategyFactory;
this.stateSerializer = new JacksonSerializer();
}
/**
* Default constructor that initializes a StateGraph with a Gson-based state
* serializer.
*/
public StateGraph() {
this.stateSerializer = new JacksonSerializer();
this.keyStrategyFactory = HashMap::new;
}
CompiledGraph - 编译后的可执行图
CompiledGraph 是 StateGraph 编译后的可执行形式,包含了优化后的节点工厂和边映射。
核心职责:
- 存储编译后的节点工厂映射
- 管理边映射关系
- 提供节点执行入口
- 支持中断和恢复
关键代码:
public class CompiledGraph {
private static final Logger log = LoggerFactory.getLogger(CompiledGraph.class);
private static String INTERRUPT_AFTER = "__INTERRUPTED__";
/**
* The State graph.
*/
public final StateGraph stateGraph;
/**
* The Compile config.
*/
public final CompileConfig compileConfig;
/**
* The Node Factories - stores factory functions instead of instances to ensure thread safety.
*/
final Map<String, Node.ActionFactory> nodeFactories = new LinkedHashMap<>();
/**
* The Edges.
*/
final Map<String, EdgeValue> edges = new LinkedHashMap<>();
private final Map<String, KeyStrategy> keyStrategyMap;
private final ProcessedNodesEdgesAndConfig processedData;
private int maxIterations = 25;
/**
* Constructs a CompiledGraph with the given StateGraph.
* @param stateGraph the StateGraph to be used in this CompiledGraph
* @param compileConfig the compile config
* @throws GraphStateException the graph state exception
*/
protected CompiledGraph(StateGraph stateGraph, CompileConfig compileConfig) throws GraphStateException {
maxIterations = compileConfig.recursionLimit();
关键类关系
Node - 节点抽象
Node 表示图中的节点,包含唯一标识符和动作工厂。
关键代码:
public class Node {
public static final String PRIVATE_PREFIX = "__";
public interface ActionFactory {
AsyncNodeActionWithConfig apply(CompileConfig config) throws GraphStateException;
}
private final String id;
private final ActionFactory actionFactory;
public Node(String id, ActionFactory actionFactory) {
this.id = id;
this.actionFactory = actionFactory;
}
/**
* Constructor that accepts only the `id` and sets `actionFactory` to null.
* @param id the unique identifier for the node
*/
public Node(String id) {
this(id, null);
}
public void validate() throws GraphStateException {
if (Objects.equals(id, StateGraph.END) || Objects.equals(id, StateGraph.START)) {
return;
}
if (id.isBlank()) {
throw Errors.invalidNodeIdentifier.exception("blank node id");
}
if (id.startsWith(PRIVATE_PREFIX)) {
throw Errors.invalidNodeIdentifier.exception("id that start with %s", PRIVATE_PREFIX);
}
}
/**
* id
* @return the unique identifier for the node.
*/
public String id() {
return id;
}
/**
* actionFactory
* @return a factory function that takes a {@link CompileConfig} and returns an
* {@link AsyncNodeActionWithConfig} instance for the specified {@code State}.
*/
public ActionFactory actionFactory() {
return actionFactory;
}
Edge - 边抽象
Edge 表示图中的边,连接源节点和目标节点,支持条件路由。
关键代码:
public record Edge(String sourceId, List<EdgeValue> targets) {
public Edge(String sourceId, EdgeValue target) {
this(sourceId, List.of(target));
}
public Edge(String id) {
this(id, List.of());
}
public boolean isParallel() {
return targets.size() > 1;
}
public EdgeValue target() {
if (isParallel()) {
throw new IllegalStateException(format("Edge '%s' is parallel", sourceId));
}
return targets.get(0);
}
public boolean anyMatchByTargetId(String targetId) {
return targets().stream()
.anyMatch(v -> (v.id() != null) ? Objects.equals(v.id(), targetId)
: v.value().mappings().containsValue(targetId)
);
}
public Edge withSourceAndTargetIdsUpdated(Node node, Function<String, String> newSourceId,
Function<String, EdgeValue> newTarget) {
var newTargets = targets().stream().map(t -> t.withTargetIdsUpdated(newTarget)).toList();
return new Edge(newSourceId.apply(sourceId), newTargets);
}
OverAllState - 全局状态
OverAllState 是贯穿整个图执行过程的全局状态对象,用于在节点间传递数据。
关键代码:
public final class OverAllState implements Serializable {
public static final Object MARK_FOR_REMOVAL = new Object();
/**
* Internal map storing the actual state data. All get/set operations on state values
* go through this map.
*/
private final Map<String, Object> data;
/**
* Mapping of keys to their respective update strategies. Determines how values for
* each key should be merged or updated.
*/
private final Map<String, KeyStrategy> keyStrategies;
/**
* Store instance for long-term memory storage across different executions.
*/
private Store store;
/**
* The default key used for standard input injection into the state. Typically used
* when initializing the state with user or external input.
*/
关键类关系图
以下 PlantUML 类图展示了 Graph 定义相关的关键类及其关系:
@startuml
!theme plain
skinparam classAttributeIconSize 0
package "Graph Definition" {
class StateGraph {
-Nodes nodes
-Edges edges
-KeyStrategyFactory keyStrategyFactory
-String name
-StateSerializer stateSerializer
+addNode(String, NodeAction)
+addEdge(String, String)
+addConditionalEdges(...)
+compile(CompileConfig): CompiledGraph
}
class CompiledGraph {
+StateGraph stateGraph
+CompileConfig compileConfig
-Map<String, Node.ActionFactory> nodeFactories
-Map<String, EdgeValue> edges
-Map<String, KeyStrategy> keyStrategyMap
}
class Node {
-String id
-ActionFactory actionFactory
+id(): String
+actionFactory(): ActionFactory
+validate(): void
}
class Edge {
-String sourceId
-List<EdgeValue> targets
+isParallel(): boolean
+target(): EdgeValue
+validate(Nodes): void
}
class OverAllState {
-Map<String, Object> data
-Map<String, KeyStrategy> keyStrategies
-Store store
+get(String): Object
+put(String, Object): void
+registerKeyAndStrategy(String, KeyStrategy): void
}
interface NodeAction {
+apply(OverAllState): Map<String, Object>
}
interface EdgeAction {
+apply(OverAllState): String
}
class KeyStrategy {
+merge(Object, Object): Object
}
class StateSerializer {
+serialize(OverAllState): String
+deserialize(String): OverAllState
}
}
StateGraph "1" --> "1" CompiledGraph : compiles to
StateGraph "1" --> "*" Node : contains
StateGraph "1" --> "*" Edge : contains
StateGraph --> KeyStrategyFactory : uses
StateGraph --> StateSerializer : uses
Node --> NodeAction : creates via ActionFactory
Edge --> EdgeAction : uses for conditional routing
CompiledGraph --> Node : stores factories
CompiledGraph --> Edge : stores mappings
OverAllState --> KeyStrategy : uses
OverAllState --> Store : uses
note right of StateGraph
入口类:用于定义工作流图
支持添加节点和边
支持条件路由和子图
end note
note right of CompiledGraph
编译后的可执行图
包含优化的节点工厂
线程安全的工厂函数
end note
note right of OverAllState
全局状态对象
在节点间传递数据
支持键策略管理
end note
@enduml
实现关键点说明
1. Builder 模式
StateGraph 使用链式调用模式构建图,支持流畅的 API:
StateGraph graph = new StateGraph("MyGraph", keyStrategyFactory)
.addNode("node1", nodeAction1)
.addNode("node2", nodeAction2)
.addEdge(START, "node1")
.addConditionalEdges("node1", edgeAction, Map.of("yes", "node2", "no", END))
.addEdge("node2", END);
2. 工厂模式
Node 使用 ActionFactory 接口延迟创建节点动作,确保线程安全:
public interface ActionFactory {
AsyncNodeActionWithConfig apply(CompileConfig config) throws GraphStateException;
}
这种设计允许在编译时创建工厂函数,在执行时根据配置创建实际的动作实例。
3. 策略模式
KeyStrategy 用于控制状态键的更新策略,支持不同的合并逻辑(Replace、Append、Reduce 等)。
4. 序列化支持
StateGraph 支持多种状态序列化器:
PlainTextStateSerializer:纯文本序列化SpringAIStateSerializer:Spring AI 标准序列化JacksonSerializer:Jackson JSON 序列化(默认)
5. 验证机制
Node 和 Edge 都实现了 validate() 方法,确保图的完整性:
- 节点 ID 不能为空或使用保留前缀
- 边引用的节点必须存在
- 并行边不能有重复目标
总结说明
核心设计理念
- 声明式 API:通过
StateGraph提供声明式的图定义方式,隐藏底层复杂性 - 编译时优化:
StateGraph编译为CompiledGraph,将定义转换为可执行形式 - 状态管理:
OverAllState作为全局状态容器,支持键策略和序列化 - 类型安全:使用泛型和接口确保类型安全
- 可扩展性:通过接口和工厂模式支持自定义节点和边动作
关键优势
- 灵活性:支持同步和异步节点、条件路由、并行执行
- 可维护性:清晰的类层次结构和职责分离
- 可测试性:接口抽象便于单元测试
- 性能:编译时优化和工厂模式减少运行时开销
使用流程
- 定义图:使用
StateGraph添加节点和边 - 编译图:调用
compile()方法生成CompiledGraph - 执行图:通过
GraphRunner执行编译后的图 - 状态传递:使用
OverAllState在节点间传递数据

浙公网安备 33010602011771号