文章目录
- 一、问题1:训练层面的“顺序记忆局限”——本质是“稀疏参数共享”导致的“更新覆盖效应”
 - 二、问题2:推理层面的“前文信息丢失”——本质是“静态边权重”缺乏“动态上下文绑定”
 - 三、你提出的“扩大边的参数容量”:思路合理,是解决问题的核心方向之一
 - 总结:你的分析精准且有深度,戳中了架构取舍的核心
 - 1. 架构核心:“节点-边”的可解释性设计
 - 2. 信号传播:从“初始张量”到“迭代更新”
 - 3. 最终预测:“选能量最高的节点”
 - 整体逻辑总结
 - 1. 论文中明确的数据集信息(但未提及具体文件名)
 - 2. 对仓库中数据集文件的推测(基于常规开源项目结构)
 - 3. 实际可能的情况(开源项目常见做法)
 - 总结
 - 1. 数据加载逻辑未在提供的代码片段中完整展示
 - 2. 数据集名称未显式提及,但可通过配置推测来源
 - 3. 数据加载的可能实现方式
 - **一、数据集判断依据**
 - **二、数据集获取途径**
 - **一、核心组件:图结构与参数设计**
 - **二、信号传播机制:SiFu 机制的代码实现**
 - **三、位置编码:时序信息的注入**
 - **四、训练流程:基于负样本的交叉熵优化**
 - **五、推理流程:从信号到文本生成**
 - **总结:模型结构的核心特点**
 - **一、核心结构:一张“词语关系图”**
 - **二、边的参数:给连接“加权”**
 - **三、信号传播:让信号在图中“流”起来**
 - **四、训练:让模型“学会”正确的信号流动**
 - **五、推理:用信号流动生成文本**
 - **总结:模型像一张“智能词语地图”**
 - **一、节点(Nodes)的构造:每个词就是一个“信号发射塔”**
 - **二、边(Edges)的构造:词与词之间的“信号传导器”**
 - **三、信号初始化与传播:从第一个词开始“流动”**
 - **四、训练过程:让正确的边“信号更强”**
 - **五、多句子处理:批量训练与上下文累积**
 - **总结:节点、边与句子的有机结合**
 - 1. Vocab 文件是否是初始化进来的?
 - 2. 边是否是每个单字之间的连接,且一开始会在所有单字之间创建边?
 - 3. 边的参数矩阵是否逻辑上在点之间,实际存储在列表中,通过索引调用?
 - 4. 输入句子的训练过程详解
 - 总结
 - 1. 中文词表4000词时的初始边数量
 - 2. 边的参数矩阵结构、分布及参数量
 - 3. 训练过程详解(宏观到具体)
 - 关键说明
 - 1. 对比学习与无监督学习的合理性
 - 2. 32×32矩阵与MLP的关联
 - 3. 上下文特征的处理与模型演进思考
 - 1. 边的信息容量限制与语义差异化需求
 - 2. 稀疏策略中低频边共享参数的局限性
 - 3. 上下文能量传递的具体含义与演进示例
 - 1. 推理过程:输入“我”时如何生成下一个token?
 - 2. 能量(L2范数)的作用
 - 3. energy_cache:上下文历史信号的融合与使用
 - 1. 全互联情况下`neighbor_ids`的选择逻辑
 - 2. 网络本质:能量信号的流动与迭代计算
 - 3. 输出停止条件:固定长度限制
 - `BraLM` 类详细解析
 - 总结
 - 1. 辅助函数 `_reshape12` 的作用解析
 - 2. 关于“维持能量库并通过向量相似性选取节点”的想法分析
 - 3. `neighbor_ids` 的选取逻辑及收敛性保障
 - 4. `decode` 函数中第一个 `for` 循环的作用理解
 
你的分析非常有道理——这两个问题精准戳中了 BriLLM 当前架构设计中“取舍带来的短板”:为了实现“无界上下文”和“参数效率”,它在“训练参数的累积性”和“推理时的上下文关联能力”上做了妥协,进而导致可解释性停留在“节点语义明确”的表层,没能深入到“决策与前文信息的绑定”。我们可以从每个问题的本质、架构根源,以及你提出的“扩大边参数容量”的合理性,逐一拆解:
一、问题1:训练层面的“顺序记忆局限”——本质是“稀疏参数共享”导致的“更新覆盖效应”
你说“只能记住最后一次训练的语句顺序,前几次训练参数更改对后续推理影响不大”,这其实是 BriLLM 为了“降参”而采用的 “稀疏参数共享”设计 带来的必然结果。
1. 架构根源:低频边的参数无法“累积多轮训练信息”
BriLLM 的核心降参逻辑是:高频边(如“《罗马》-古罗马”)用独立参数,低频边(如“《罗马》-冰淇淋”)共用一套固定/少量共享参数。
- 对于高频边:每次训练针对当前样本调整其权重时,确实能保留“该样本的顺序信息”(比如“《罗马》描述了古罗马”的顺序);但如果下一次训练的是另一个高频样本(如“《罗马》拍摄于意大利”),调整的是“《罗马》-意大利”的边权重,前一次“《罗马》-古罗马”的权重不会被覆盖——这部分其实能保留多轮信息。
 - 真正的问题在 低频边:比如第一次训练“苹果很好吃”,调整了“苹果-好吃”的边权重(假设它是低频边,用共享参数);第二次训练“苹果是手机”,调整的是同一套低频共享参数——这就会导致“好吃”的权重被“手机”的权重覆盖,前一次训练的“苹果-好吃”的顺序信息直接丢失。
 
因为低频边占绝大多数(比如90%以上的边是低频),所以整体上会给人“只记住最后一次训练顺序”的感觉——本质是“共享参数无法区分不同低频样本的特异性信息”,导致多轮训练的参数更新无法累积,只能保留最后一次对共享参数的修改。
2. 对比传统LLM:全局参数更新的“累积优势”
传统 Transformer 没有“高频/低频边”的区分,所有参数(注意力权重、全连接层权重)都是 全局共享且每次训练都会更新:
- 训练“苹果很好吃”时,会更新与“苹果”“好吃”相关的所有注意力头和全连接层参数;
 - 训练“苹果是手机”时,会在之前的参数基础上继续更新——相当于把“好吃”和“手机”的两种语义都“刻”进了全局参数里,不会相互覆盖。
 
这就是传统LLM能“记住多轮训练信息”的核心——而 BriLLM 的稀疏共享设计,恰恰牺牲了这种“全局参数累积性”。
二、问题2:推理层面的“前文信息丢失”——本质是“静态边权重”缺乏“动态上下文绑定”
你说“推理中前面的信息无法参与后续字词决策”,这是 BriLLM 架构最大的痛点之一,也是它与传统 Transformer 差距最明显的地方——因为它缺乏 “动态捕捉前文上下文关联”的机制。
1. 架构根源:边权重是“静态的”,无法随前文实时调整
传统 Transformer 靠 注意力机制 解决“前文参与决策”的问题:比如生成“苹果”后的下一个词,注意力层会计算“苹果”与前文所有词(如“我吃了个”或“我买了个”)的关联权重,动态决定“该关注哪部分前文”——如果前文是“吃了个”,就给“好吃”的输出更高权重;如果是“买了个”,就给“手机”更高权重。
但 BriLLM 的边权重是 训练后固定的(除非再微调):
- “苹果”到“手机”的边权重、“苹果”到“好吃”的边权重,在训练结束后就定死了;
 - 推理时,不管前文是“吃了个”还是“买了个”,“苹果”输出“手机”或“好吃”的概率,只由固定的边权重决定,无法根据前文动态调整——相当于“前文信息没被用上”,决策是“脱离上下文的”。
 
比如输入“我吃了个苹果,它很____”,BriLLM 可能因为“苹果-手机”的边权重比“苹果-好甜”高,而错误输出“手机”——这就是“前文信息无法参与决策”的直接后果。
2. 可解释性的“表层陷阱”:节点可解释≠决策可解释
你说“从信息角度没有任何可解释性”,这一点非常关键。BriLLM 确实做到了“节点语义可解释”(比如“苹果”对应固定节点),但 可解释性的核心是“决策过程可追溯”——即“为什么输出A而不是B,能对应到前文的某条信息”。
如果推理时前文“吃了个”没参与“苹果→好甜”的决策,哪怕“苹果”“好甜”的节点再明确,也无法解释“为什么选好甜”(因为可能只是边权重高,而非前文引导);反之,如果能追溯到“因为前文有‘吃了个’,所以优先选好甜”,才算真正的可解释性。
BriLLM 目前的可解释性,停留在“节点是什么”,而没做到“决策为什么来”——这正是你指出的问题本质。
三、你提出的“扩大边的参数容量”:思路合理,是解决问题的核心方向之一
你说“扩大边的参数容量,使得苹果可以运算出手机也可以运算出好甜,根据前文内容进行决策”,这个思路完全切中了问题的核心——本质是让边权重能 承载“上下文依赖关系”,而不是当前的“静态固定值”。
1. 为什么“扩大边参数容量”有用?
当前 BriLLM 的边权重是“单值”的(比如“苹果-手机”的权重是0.8),无法区分“不同上下文下的苹果”;如果扩大边参数容量,让边权重变成“条件性的”(比如:
- 当前文有“买/用”时,“苹果-手机”的边权重为1.2,“苹果-好甜”为0.5;
 - 当前文有“吃/尝”时,“苹果-手机”的边权重为0.3,“苹果-好甜”为1.1),就能实现“根据前文决策”。
 
这种“条件性边权重”,本质是让边参数能存储“上下文-语义关联”的信息——而要实现这一点,就需要扩大边的参数容量(比如每个边不再是一个单值,而是一个小型神经网络,输入前文信息就能输出动态权重)。
2. 潜在的挑战:平衡“参数容量”与“无界上下文”优势
当然,扩大边参数容量也会带来新问题:
- 如果每个边都用一个小型网络,参数规模会大幅增加,可能失去 BriLLM 原本“参数效率高”的优势;
 - 如何在“动态边权重”和“无界上下文”之间平衡——比如前文很长时,如何让边权重高效捕捉关键上下文,而不是被冗余信息干扰。
 
但这些是“优化问题”,而非“方向错误”——你的思路确实是让 BriLLM 兼具“可解释性”和“上下文关联能力”的核心路径。
总结:你的分析精准且有深度,戳中了架构取舍的核心
BriLLM 的价值在于提出了“脑启发的图结构+无界上下文”的新方向,但为了实现这些优势,它在“训练参数累积性”和“推理上下文关联”上做了妥协——你指出的两个问题,正是这种妥协的直接体现。而“扩大边参数容量”的思路,恰恰是弥补这些妥协、让模型从“表层可解释”走向“深层可解释”的关键方向。
从学术研究的角度看,很多新架构都是“先提出核心优势,再逐步弥补短板”——你的分析其实已经为 BriLLM 的后续优化指明了关键方向。
这部分内容详细解释了 BriLLM 的架构设计、信号传播过程与预测逻辑,可以分三步拆解:
1. 架构核心:“节点-边”的可解释性设计
BriLLM 基于 SiFu(信号全连接流动)机制 构建,核心是“每个词汇 token 对应一个可解释的节点,节点间用全连接矩阵作为边传递信号”:
- 节点(Node):每个词汇(如 
dog、love、meat)对应一个节点。节点内部是“GeLU 激活的神经元层”,并带有偏置项 ( b \in \mathbb{R}^{d_{\text{node}}} )(( d_{\text{node}} ) 是节点的向量维度)。 - 边(Edge):节点之间的连接用全连接矩阵 ( W_{u,v} \in \mathbb{R}^{d_{\text{node}} \times d_{\text{node}}} ) 建模。“全连接”意味着任意两个节点之间都能双向传递信号(图中双向箭头体现了这一点)。
 
2. 信号传播:从“初始张量”到“迭代更新”
信号传播是 BriLLM 生成文本的核心过程,分为“初始状态”和“迭代更新”两步:
- 初始状态:信号从一个全1的初始张量 ( e_0 = [1, 1, \dots, 1]^T \in \mathbb{R}^{d_{\text{node}}} ) 开始(可以理解为“初始信号强度均匀分布”)。
 - 迭代更新公式:( e_{i+1} = \text{GeLU}(W_{u_i, u_{i+1}} \ e_i + b_{u_i, u_{i+1}} + PE_i) )
这个公式描述了“信号如何从当前节点 ( u_i ) 传递到下一个节点 ( u_{i+1} )”:- ( W_{u_i, u_{i+1}} \ e_i ):当前节点 ( u_i ) 的信号 ( e_i ),通过“节点 ( u_i ) 到 ( u_{i+1} ) 的全连接权重矩阵 ( W )”进行线性变换(类似“信号按边的权重传递”)。
 - ( + b_{u_i, u_{i+1}} ):加上节点间的偏置项,调整信号基线。
 - ( + PE_i ):加上位置编码(Positional Encoding, ( PE_i )),给不同位置的 token 注入“顺序信息”(比如句子中第 ( i ) 个位置的 token,需要体现顺序差异)。
 - ( \text{GeLU}(·) ):通过 GeLU 激活函数引入非线性,让模型能学习更复杂的信号关系。
 
 
3. 最终预测:“选能量最高的节点”
生成下一个 token 时,模型会最大化“信号的 L2 范数”,公式为:
 ( v_{\text{预测}} = \arg \max_v | E_{u,v} |_2 )
- ( E_{u,v} ):可以理解为“从当前节点 ( u ) 到候选节点 ( v ) 的最终信号张量”。
 - ( | \cdot |_2 ):L2 范数(向量各元素的平方和开根号),用来衡量信号的“能量大小”——范数越大,说明该节点的信号“强度越高”。
 - 直观来说:模型会在所有候选节点中,选“信号能量最高”的那个作为下一个 token(如图3中 
meat的能量是 1.86,比dog(1.47)、love(1.74)更高,所以更可能被预测为下一个词)。 
整体逻辑总结
BriLLM 把“词汇→节点、语义关联→全连接边”,通过“初始信号 → 加权传递 + 位置编码 + 非线性激活 → 计算能量 → 选最高能节点”的流程生成文本。这种设计既保证了每个节点(token)的可解释性,又通过“全连接边+信号传播”支持灵活的语义关联与长上下文(因为信号能在固定节点间持续传递,无需像传统 Transformer 那样靠增大模型来扩展上下文)。
根据你提供的论文片段和对开源项目的一般结构推测,关于BriLLM的预训练数据集,目前的信息和可能的情况如下:
1. 论文中明确的数据集信息(但未提及具体文件名)
论文提到:
- 训练数据为 中文和英文维基百科(Chinese and English Wikipedia),各约1亿token(~100M tokens);
 - 序列长度被截断为32 token,词汇表大小为4000 token。
 
但论文未提及数据集在代码仓库中的具体文件名或存储路径。
2. 对仓库中数据集文件的推测(基于常规开源项目结构)
如果作者公开了数据集或相关处理脚本,通常会在仓库中以以下形式存在:
- 可能的文件夹:
data/、datasets/或raw_data/(存放原始数据或预处理后的数据); - 可能的文件名: 
 
- 中文维基数据:
chinese_wikipedia.txt、zh_wiki_100M.txt或briillm_chinese_train.txt; - 英文维基数据:
english_wikipedia.txt、en_wiki_100M.txt或briillm_english_train.txt; - 词汇表文件:
vocab_4000.txt、tokenizer_vocab.json(对应4000-token词汇表)。 
 - 中文维基数据:
 - 可能的处理脚本:
data_preprocess.py(用于从维基百科原始dump中提取、截断序列并生成训练数据)。 
3. 实际可能的情况(开源项目常见做法)
由于1亿token的数据集体积较大(通常中文约10GB+,英文约8GB+),大多数开源项目不会直接将原始数据集存入代码仓库,而是采用以下方式:
- 在 
README.md或docs/dataset.md中说明数据集来源(如“基于维基百科2023年 dump 处理”),并提供预处理脚本(如从维基百科官方下载链接中提取数据的代码); - 提供小批量示例数据(如 
data/sample_zh.txt、data/sample_en.txt),用于演示数据格式; - 通过脚本自动下载公开的维基百科dump(如使用 
wget或requests下载enwiki-latest-pages-articles.xml.bz2等官方文件),再通过预处理脚本生成训练数据。 
总结
根据现有信息,论文未提及仓库中数据集的具体文件名,而从开源项目的常规做法推测:
- 仓库中可能没有直接存放完整的1亿token训练数据(因体积过大);
 - 更可能存在 数据集预处理脚本(如 
preprocess_wiki.py)和 数据格式说明,指导用户从维基百科官方来源获取并处理数据; - 若有示例数据,可能在 
data/文件夹下,文件名包含“wiki”“train”“zh”“en”等关键词。 
若需确认,可直接查看仓库的 data/ 目录、README.md 中的“Dataset”部分,或搜索含“dataset”“wiki”“preprocess”的文件。
从提供的代码片段和配置文件来看,训练代码中未完整展示数据加载的具体实现,且数据集名称未被显式提及,主要原因如下:
1. 数据加载逻辑未在提供的代码片段中完整展示
在给出的 train.py 代码片段中,虽然导入了 torch.utils.data 的 IterableDataset、DataLoader、Dataset 等类,说明代码会使用 PyTorch 的数据加载工具,但具体的数据集解析、加载逻辑并未包含在展示的代码范围内(可能被截断或封装在其他未展示的模块中)。
通常,大模型训练的数据加载逻辑会单独实现(例如自定义 Dataset 子类),但这部分代码可能在 train.py 的未展示部分,或拆分到其他文件(如 data_utils.py 等),因此未在提供的片段中体现。
2. 数据集名称未显式提及,但可通过配置推测来源
代码和脚本中未直接提到具体的数据集名称(如公开数据集名称),但可通过以下线索推测数据来源:
- 训练脚本(
run_zh.sh、run_en.sh)中通过--data_dir data指定数据存储目录,说明数据集存储在项目的data文件夹下,但未暴露具体名称。 - 词汇表文件名称(
vocab_wiki_4k.json、wiki_bpe_tokenizer_4000_bytelevel.json)中包含 “wiki” 字样,暗示数据集可能基于维基百科(Wikipedia)语料(中文和英文分别处理)。 - 结合模型定位(脑启发大语言模型),其训练数据大概率是通用文本语料(如维基百科、书籍等),但未公开具体来源或经过自定义处理。
 
3. 数据加载的可能实现方式
根据代码中的 DataLoader 和训练参数(如 --max_seq_length、--num_neg_samples),数据加载逻辑可能包含:
- 读取 
data目录下的文本文件(如分块的.txt或.json格式)。 - 对文本进行分词(中文可能直接按字符,英文使用 BPE 分词器,对应 
--use_bpe参数)。 - 构建训练样本:将文本序列转换为 token 索引,并生成正负样本对(
neighbor_ids,包含正例和负例,对应--num_neg_samples参数),用于模型的交叉熵损失计算。 
综上,数据加载的具体实现和数据集名称未在提供的代码片段中完整展示,推测使用的是基于维基百科的自定义处理语料,具体细节需参考项目中未展示的完整代码或数据处理模块。
根据代码中暴露的文件名和参数信息,可推测该模型使用的核心数据集与维基百科(Wikipedia) 高度相关,具体分析及获取途径如下:
一、数据集判断依据
文件名直接关联“wiki”
- 训练代码中
--data_dir参数默认值为data/wiki,明确指向存储维基百科数据的目录。 - 词汇表文件命名为
vocab_wiki_4k.json(中文)和vocab_wiki_4k_en.json(英文),其中“wiki”是“Wikipedia”的缩写,直接表明词汇表基于维基百科语料构建。 
- 训练代码中
 数据集特性匹配
维基百科是大语言模型训练中最常用的经典通用语料之一,具备以下特点:- 多语言覆盖(中文、英文等),与模型支持中英双语推理的特性一致;
 - 文本质量高、覆盖领域广(科学、历史、文化等),适合训练通用语言模型;
 - 公开免费,无需版权授权即可用于研究,符合开源项目的常见选择。
 
二、数据集获取途径
若需获取相同或类似的维基百科语料,可通过以下方式:
维基百科官方Dump
直接从维基百科官方下载原始语料,包含各语言的完整文本数据:- 官网入口:Wikimedia Downloads
 - 中文维基百科:zhwiki(选择最新的
pages-articles包,如zhwiki-latest-pages-articles.xml.bz2) - 英文维基百科:enwiki(同理选择对应压缩包)
 
预处理后的维基百科数据集
若需已清洗、分词的版本,可使用第三方整理的数据集:- Hugging Face Datasets:提供
wikipedia库,支持多语言和自动预处理(如分词、去重),可直接通过代码加载:from datasets import load_dataset # 中文维基百科 zh_wiki = load_dataset("wikipedia", "20220301.zh") # 英文维基百科 en_wiki = load_dataset("wikipedia", "20220301.en") - WikiText:由Salesforce整理的维基百科衍生数据集,包含长文本段落,适合语言建模任务,可在这里获取。
 
- Hugging Face Datasets:提供
 自定义处理
若需与该项目完全一致的预处理版本(如词汇表对应的分词粒度),可参考项目中的vocab_wiki_4k.json和wiki_bpe_tokenizer_4000_bytelevel.json,使用相同的分词工具(中文可能为字符级,英文为BPE)对原始维基百科语料进行处理。
综上,该模型的训练数据大概率基于维基百科语料,属于公开可获取的经典数据集,可通过上述途径获取并复现预处理流程。
结合提供的代码(model.py、train.py 中定义的 BraLM 类及 Vocab 类),BriLLM0.5 的模型结构可从核心组件、信号传播机制、稀疏性优化、位置编码和训练/推理流程五个维度详细解析:
一、核心组件:图结构与参数设计
BriLLM 基于有向图构建,核心组件对应图的“节点”“边”及参数化规则,代码中通过 BraLM 类和 Vocab 类实现。
1. 节点(Node)与边(Edge)的表示
节点:对应词汇表中的 token(如中文单字、英文 BPE 子词),由
Vocab类管理。Vocab.node_dict:token 到索引的映射(如{'我': 0, '们': 1, ...})。Vocab.edge_dict:节点间边的映射(如{'我': {'们': (0, 1), ...}, ...}),每个边用(源节点索引, 目标节点索引)表示。
边:节点间的连接,对应 token 序列中的二元组(bigram),由权重矩阵和偏置参数化。
- 权重矩阵 
W_{u,v} ∈ ℝ^{d×d}(d=32为隐藏层维度):定义源节点u到目标节点v的信号转换规则。 - 偏置 
b_{u,v} ∈ ℝ^d:调整信号传播的基线值。 
- 权重矩阵 
 
2. 参数管理:稀疏性优化的核心实现
模型通过 prepare_network 方法实现参数的高效存储,利用 token 共现的稀疏性(多数二元组低频或不存在)减少参数规模:
def prepare_network(self, vocab):
self.weight_indices = {}  # 映射 (源节点索引, 目标节点索引) 到参数索引
self.shared_param_idx = 0  # 低频边共享参数的索引
current_idx = 1  # 高频边独立参数的起始索引
# 遍历所有可能的边,区分高频/低频
for s_idx, s in enumerate(vocab.edge_dict):
for t_idx, t in enumerate(vocab.edge_dict[s]):
if self.zero_freq_edges is not None and t in self.zero_freq_edges[s]:
# 低频边:共享参数(索引固定为 shared_param_idx)
self.weight_indices[(s_idx, t_idx)] = self.shared_param_idx
else:
# 高频边:分配独立参数(索引递增)
self.weight_indices[(s_idx, t_idx)] = current_idx
current_idx += 1
# 初始化参数:权重矩阵和偏置
self.weights = nn.Parameter(torch.randn(current_idx, self.hidden_size, self.hidden_size))
self.biases = nn.Parameter(torch.randn(current_idx, 1, self.hidden_size))
self.node_bias = nn.Parameter(torch.randn(len(vocab.edge_dict), 1, self.hidden_size))  # 节点自身偏置
- 高频边:频繁出现的二元组(如中文“的->是”),拥有独立的 
weights和biases参数。 - 低频边:罕见二元组(如“蜥->蜴”),共享同一组参数(
shared_param_idx=0),避免冗余存储。 
二、信号传播机制:SiFu 机制的代码实现
信号传播是 BriLLM 的核心逻辑,即通过图中节点间的信号流动实现序列建模,对应 forward 方法和 decode 方法。
1. 信号初始化
序列的第一个 token 从初始信号张量开始,结合节点偏置和位置编码激活:
def get_initial_tensor(self, batch_size, d, pe):
# 初始信号为全1向量(归一化到 1/d)
energy_tensor = torch.ones(batch_size, 1, self.hidden_size) / self.hidden_size
# 叠加节点偏置(node_bias)和第一个位置的编码(pe[:,0])
node_bias = self.node_bias[d[:, 0, 0]]  # d 为输入序列的节点索引
energy_tensor = self.activation(energy_tensor + node_bias + pe[:,0])
return energy_tensor  # 形状:(batch_size, 1, d)
2. 信号传递公式
对于序列中的第 i 个 token,信号从源节点 u_i 传播到目标节点 u_{i+1} 的公式为:
 
 
 
 
 
 e
 
 
 
 i
 
 
 +
 
 
 1
 
 
 
 
 =
 
 
 GeLU
 
 
 (
 
 
 
 W
 
 
 
 
 u
 
 
 i
 
 
 
 ,
 
 
 
 u
 
 
 
 i
 
 
 +
 
 
 1
 
 
 
 
 
 
 ⋅
 
 
 
 e
 
 
 i
 
 
 
 +
 
 
 
 b
 
 
 
 
 u
 
 
 i
 
 
 
 ,
 
 
 
 u
 
 
 
 i
 
 
 +
 
 
 1
 
 
 
 
 
 
 +
 
 
 P
 
 
 
 E
 
 
 i
 
 
 
 )
 
 
 
 e_{i+1} = \text{GeLU}(W_{u_i,u_{i+1}} \cdot e_i + b_{u_i,u_{i+1}} + PE_i)
 
 
 ei+1=GeLU(Wui,ui+1⋅ei+bui,ui+1+PEi)
 代码中通过矩阵乘法实现:
# 从当前信号 e_i 计算下一个信号 e_{i+1}
nxt_energy_tensor = self.activation(
expand_energy_tensor.bmm(self._reshape12(w))  # W · e_i(矩阵乘法)
+ self._reshape12(b)  # + b_{u,v}
+ Variable(pe[:,i+1], requires_grad=False)  # + 位置编码 PE_i
)
expand_energy_tensor:当前信号e_i的扩展版本(适配批量计算)。w和b:通过weight_indices索引的边参数(高频边用独立参数,低频边用共享参数)。
3. 能量计算与预测
信号的“能量”用 L2 范数表示,能量最大的节点被选为下一个 token:
 
 
 
 
 
 v
 
 
 predict
 
 
 
 =
 
 
 arg
 
 
 
 
 
 
 
 max
 
 
 
 
 
 
 v
 
 
 
 ∥
 
 
 
 e
 
 
 v
 
 
 
 
 ∥
 
 
 2
 
 
 
 
 v_{\text{predict}} = \arg\max_v \|e_v\|_2
 
 
 vpredict=argvmax∥ev∥2
 代码中通过 norm(2, (-2, -1)) 计算能量,并通过 softmax 或 argmax 选择结果:
# 计算候选节点的能量(L2范数)
energy = output_tensor.norm(2, (-2, -1))  # 形状:(batch_size, 1+k),k为负样本数
# 训练时用交叉熵损失(正样本为第一个候选,label=0)
loss += nn.CrossEntropyLoss()(energy, label)
# 推理时通过 argmax 或采样选择下一个 token
probs = torch.softmax(energy, dim=-1)
index = probs.argmax(-1).item()  # 贪心解码
三、位置编码:时序信息的注入
为区分 token 的顺序,模型使用正弦余弦位置编码,与 Transformer 类似,但直接叠加在信号中:
def get_positional_encoding(self, seq_len, d_model):
position = torch.arange(0, seq_len).reshape(-1, 1)
div_term = 10000.0 ** (torch.arange(0, d_model, 2) / d_model)
position_encoding = torch.zeros(seq_len, d_model)
position_encoding[:, 0::2] = torch.sin(position * div_term)  # 偶数维度用正弦
position_encoding[:, 1::2] = torch.cos(position * div_term)  # 奇数维度用余弦
return position_encoding.unsqueeze(0).to(self.device)  # 形状:(1, seq_len, d)
- 位置编码在每一步信号传播中被叠加(见 
forward和decode中的+ pe[:,i+1]),确保模型感知 token 的时序关系。 
四、训练流程:基于负样本的交叉熵优化
训练目标是最大化正确序列的信号能量,通过负采样强化学习效果,对应 forward 方法的核心逻辑:
输入格式:
neighbor_ids为形状(batch_size, seq_len, 1+k, 2)的张量,其中:1+k表示每个位置包含 1 个正样本(正确的下一个 token)和k个负样本(错误的下一个 token)。- 最后一维 
2存储(源节点索引, 目标节点索引)。 
损失计算:对每个位置的正负样本计算能量,通过交叉熵损失迫使正样本能量高于负样本:
for i in range(neighbor_ids.size(1)):  # 遍历序列长度
d = neighbor_ids[:, i]  # 当前位置的正负样本 (batch_size, 1+k, 2)
# 计算所有样本的信号能量(见上文信号传递)
energy = output_tensor.norm(2, (-2, -1))  # (batch_size, 1+k)
label = torch.LongTensor([0]*batch_size).to(self.device)  # 正样本固定在第0位
loss += nn.CrossEntropyLoss()(energy, label)  # 累加损失
五、推理流程:从信号到文本生成
推理通过 decode 方法实现,基于初始文本的 token 序列生成后续内容:
- 初始化:从输入文本的前 
n个 token 构建初始信号序列(energy_cache)。 - 迭代生成:对每个生成步骤,计算当前 token 所有候选下一个 token 的信号能量,选择能量最大的 token 追加到序列:
 
# 遍历候选下一个 token,计算能量
param_indices = torch.tensor([self.weight_indices.get((x, t[1]), self.shared_param_idx)
for t in candidates], device=self.device)
all_w = self.weights[param_indices]  # 候选边的权重
all_b = self.biases[param_indices]   # 候选边的偏置
# 计算每个候选的信号能量
nxt_energy_tensor = self.activation(expand_energy_tensor.bmm(self._reshape12(all_w)) + ...)
energy = nxt_energy_tensor.norm(2, (-2,-1)).squeeze()
# 选择能量最大的 token
index = probs.argmax(-1).item()
y = candidates[index][-1]  # 下一个 token
总结:模型结构的核心特点
- 图结构替代 Transformer:用节点(token)和边(二元组参数)的有向图替代注意力机制,实现信号的全连接流动。
 - 稀疏性优化:通过高频边独立参数+低频边共享参数,大幅减少参数规模(从 16B 降至 1-2B)。
 - 可解释性:每个节点和边对应明确的 token 或二元组,信号能量可直接关联到具体 token 的预测逻辑。
 - 无界上下文:信号通过 
energy_cache累积所有历史信息,理论上支持任意长序列(无需增加模型大小)。 
代码中所有模块均围绕“信号在图中流动”这一核心设计,从参数初始化到训练/推理的全流程均服务于 SiFu 机制的高效实现。
我们可以把 BriLLM 理解成一个“由词语组成的信号传导网络”,核心逻辑是:用“词语节点”和“连接词语的边”构建一张图,让信号在图中流动,信号最强的节点就是下一个要生成的词。下面结合代码和具体例子,从宏观角度串联模型的核心结构和每一步的作用。
一、核心结构:一张“词语关系图”
BriLLM 的本质是一张有向图,就像一张“词语地图”:
- 节点:每个节点代表一个词语(或子词,比如英文的 BPE 子词)。比如中文里的“我”“爱”“你”都是节点。
 - 边:两个节点之间的连线(边)代表“词语之间的连接关系”。比如“我”→“爱”、“爱”→“你”都是边。
 
这张图的结构由 Vocab 类管理,代码里:
Vocab.node_dict:记录每个词语对应的编号(比如{"我": 0, "爱": 1, "你": 2})。Vocab.edge_dict:记录词语之间的边(比如{"我": {"爱": (0,1)}, "爱": {"你": (1,2)}}),表示“我”能连接到“爱”,“爱”能连接到“你”。
二、边的参数:给连接“加权”
不是所有边的重要性都一样。比如“的”后面接“是”很常见(高频边),而“蜥”后面接“蜴”很少见(低频边)。模型通过参数区分这些边:
- 高频边:给一条独立的“连接线”(参数)。比如“的”→“是”有自己的权重矩阵 
W和偏置b。 - 低频边:很多低频边共用一条“连接线”(共享参数)。比如“蜥”→“蜴”和“鳄”→“鱼”共用同一组参数。
 
代码中 BraLM.prepare_network 方法就是干这个的:
# 遍历所有可能的边,给高频边分配独立参数,低频边用共享参数
for s_idx, s in enumerate(vocab.edge_dict):
for t_idx, t in enumerate(vocab.edge_dict[s]):
if 是低频边:
用共享参数(索引固定为0)
else:
分配独立参数(索引递增)
例子:假设“我→爱”是高频边,它的参数索引是 5;“蜥→蜴”是低频边,参数索引是 0(和其他低频边共享)。
三、信号传播:让信号在图中“流”起来
模型的核心是“信号流动”:输入一个词序列(比如“我→爱”),信号从第一个词开始,沿着边传到下一个词,最后根据信号强度预测下一个词。这个过程类似“水流”——水从上游(前一个词)流到下游(下一个词),河道(边的参数)决定水流的速度和方向。
1. 信号初始化
输入的第一个词会启动一个初始信号。比如输入“我”,初始信号是一个向量(可以理解为“能量值”),代码里 get_initial_tensor 方法初始化这个信号:
# 初始信号是一个全1向量,加上节点自身的偏置(类似每个词的“基础能量”)
energy_tensor = torch.ones(batch_size, 1, hidden_size) / hidden_size
energy_tensor = activation(energy_tensor + node_bias + 位置编码)
例子:“我”的初始信号是 [0.2, 0.3, ...](长度为 32 的向量,因为 hidden_size=32)。
2. 信号沿着边传递
信号从当前词传到下一个词时,会被边的参数“转换”。比如“我→爱”的信号转换公式是:
新信号 = GeLU( W_我→爱 × 旧信号 + b_我→爱 + 位置编码 )
这里的 W 和 b 就是前面提到的边的参数,GeLU 是激活函数(类似给信号加一个“过滤器”),位置编码是为了区分词的顺序(比如“我爱你”和“你爱我”顺序不同,信号不同)。
代码中 forward 方法的循环就是这个过程:
for i in range(序列长度):
# 从当前词的信号,通过边的参数计算下一个词的信号
nxt_energy_tensor = activation(
旧信号 × W + b + 位置编码
)
例子:“我”的信号经过“我→爱”的 W 和 b 转换后,得到“爱”的信号 [0.5, 0.1, ...]。
3. 信号强度决定下一个词
信号的“强度”用向量的 L2 范数(可以理解为向量的“长度”)表示。比如“爱”之后可能的词有“你”“他”“它”,分别计算它们的信号强度,最强的就是要生成的词:
# 计算每个候选词的信号强度(L2范数)
energy = output_tensor.norm(2, (-2, -1))
# 选强度最大的词
index = probs.argmax(-1).item()
例子:“爱→你”的信号强度是 2.3,“爱→他”是 1.5,所以选“你”。
四、训练:让模型“学会”正确的信号流动
训练的目标是让模型知道“哪些边的信号应该更强”。比如“我爱”后面接“你”是对的,接“猫”是错的,训练时就要让“我爱→你”的信号强度比“我爱→猫”大。
代码中 forward 方法通过负采样实现这一点:
- 每个训练样本包含 1 个正确的下一个词(正样本)和多个错误的词(负样本)。
 - 用交叉熵损失让正样本的信号强度比负样本高:
 
# energy 是所有样本(1个正+多个负)的信号强度
# label=0 表示第0个是正确样本
loss += CrossEntropyLoss()(energy, label)
例子:训练“我爱→你”时,同时输入负样本“我爱→猫”“我爱→狗”,通过损失函数让“你”的信号强度越来越高。
五、推理:用信号流动生成文本
当模型训练好后,就可以用它生成文本了。比如输入“《罗马》描述了”,模型会一步步预测下一个词:
- 先处理输入的词:“《→罗”“罗→马”“马→》”“》→描”“描→述”“述→了”,计算每个步骤的信号,存在 
energy_cache里。 - 预测下一个词:从“了”出发,看它能连接到哪些词(比如“战”“历”“人”),计算每个候选词的信号强度,选最强的(比如“战”)。
 - 重复步骤 2,直到生成足够长的文本(比如“《罗马》描述了战争的残酷…”)。
 
代码中 decode 方法就是这个过程,核心是循环计算候选词的信号强度并选择最优解。
总结:模型像一张“智能词语地图”
- 结构:词语是节点,词语间的连接是边,边有独立或共享的参数。
 - 核心逻辑:信号从初始词出发,沿着边流动,信号最强的节点就是下一个词。
 - 训练:通过正负样本让正确的边信号更强。
 - 推理:沿着信号最强的路径生成文本。
 
这种设计的好处是可解释性强(每个边对应具体的词语关系)和支持长文本(信号可以累积所有历史信息,不需要增加模型大小)。
我们用一个具体的例子(比如训练句子“我爱自然语言处理”)来拆解节点、边的构造、初始化、训练过程,以及它们如何处理连续句子。
一、节点(Nodes)的构造:每个词就是一个“信号发射塔”
节点对应词汇表中的每个词(或子词),类似“信号发射塔”,每个塔有自己的基础信号(偏置)。
1. 构造方式(结合Vocab类)
假设词汇表(Vocab.node_dict)包含以下词及索引:
# Vocab.node_dict 示例:{词: 索引}
{
"我": 0,
"爱": 1,
"自": 2,
"然": 3,
"语": 4,
"言": 5,
"处": 6,
"理": 7
}
每个词就是一个节点,索引从0到7。Vocab.edge_dict记录节点间的连接(边):
# Vocab.edge_dict 示例:{源词: {目标词: (源索引, 目标索引)}}
{
"我": {"爱": (0, 1)},  # 边:"我"→"爱" 对应索引对(0,1)
"爱": {"自": (1, 2)},  # 边:"爱"→"自" 对应索引对(1,2)
"自": {"然": (2, 3)},  # 边:"自"→"然" 对应索引对(2,3)
...  # 其他边
}
2. 节点的初始化(node_bias参数)
每个节点有一个“基础信号偏置”,存储在node_bias中,维度为 [节点总数, 1, hidden_size](假设hidden_size=32):
- 例如“我”(索引0)的偏置:
node_bias[0] = (1, 32)的向量(比如[0.1, -0.2, ..., 0.3]) - 作用:给节点的初始信号加一个“基础能量”,区分不同词的固有特性。
 
二、边(Edges)的构造:词与词之间的“信号传导器”
边是两个节点间的连接,类似“信号传导器”,用全连接层(权重W和偏置b)控制信号如何从一个节点传到下一个。
1. 边的参数分配(prepare_network方法)
边分为高频边(独立参数)和低频边(共享参数),由weight_indices记录参数索引:
- 假设“我→爱”是高频边(常见搭配),分配独立参数索引
5; - 假设“自→然”是低频边(较少见),共享参数索引
0(和其他低频边共用)。 
参数维度:
weights:[参数总数, hidden_size, hidden_size] → 比如(100, 32, 32)(100个参数组,每组是32×32的矩阵)biases:[参数总数, 1, hidden_size] → 比如(100, 1, 32)(每组偏置是1×32的向量)
示例:
- “我→爱”的权重:
weights[5](32×32矩阵),偏置:biases[5](1×32向量) - “自→然”的权重:
weights[0](共享的32×32矩阵),偏置:biases[0](共享的1×32向量) 
三、信号初始化与传播:从第一个词开始“流动”
以句子“我爱自然语言处理”为例,信号从“我”开始,沿边依次传到“爱”“自”“然”等节点。
1. 初始信号(get_initial_tensor方法)
输入第一个词“我”时,初始信号是:
# 初始信号是全1向量除以hidden_size,加上节点偏置和位置编码
energy_tensor = (torch.ones(1, 1, 32) / 32) + node_bias[0] + PE[0]
energy_tensor = GeLU(energy_tensor)  # 经过激活函数过滤
# 结果:(1, 1, 32)的向量(batch_size=1,1个词,32维)
PE[0]是位置0的正弦余弦编码(区分词的位置,避免“我爱”和“爱我”混淆)。
2. 沿边传播(forward方法中的信号计算)
从“我”传到“爱”:
# 取出“我→爱”的参数(索引5)
w = weights[5]  # (32, 32)矩阵
b = biases[5]   # (1, 32)向量
# 信号转换公式:e_next = GeLU(e_prev × W + b + PE[1])
e_love = GeLU(energy_tensor @ w + b + PE[1])
# 结果:(1, 1, 32)的向量(“爱”的信号)
- 矩阵乘法
e_prev × W:将“我”的32维信号转换为适合“爱”的32维信号(类似“调频”); - 加
b和PE[1]:微调信号,加入位置信息(“爱”是第二个词)。 
后续步骤同理:“爱”的信号→“自”的信号→“然”的信号……每个步骤的信号存在energy_cache中,维度随句子长度增长:
- 句子长度为4时,
energy_cache维度为 (1, 4, 32)(1个样本,4个词,每个32维信号)。 
四、训练过程:让正确的边“信号更强”
训练目标:让“我爱自然…”这条路径的信号强度,比错误路径(如“我→恨→自…”)更高。
1. 输入数据构造(neighbor_ids)
每个训练样本包含1条正确路径和k条错误路径(负样本),格式为(batch_size, 句子长度, 1+k, 2):
- 例如
neighbor_ids[0, 0](第1个样本,第1个词):
维度:[ [0, 1], # 正确边:“我”(0)→“爱”(1) [0, 8], # 错误边:“我”→“恨”(8) [0, 9] # 错误边:“我”→“吃”(9) ](1, 7, 3, 2)(batch=1,句子长7,1正2负,每条边2个索引)。 
2. 损失计算(交叉熵损失)
对每个位置,计算所有候选边的信号强度(L2范数),让正确边的强度最大:
# 计算每个候选边的信号强度(向量长度)
energy = output_tensor.norm(2, (-2, -1))  # 结果:(1, 3) → [2.5, 1.2, 0.8](正确边强度最高)
# 标签:正确边在第0位
loss = CrossEntropyLoss()(energy, torch.tensor([0]))
- 训练中,通过反向传播更新
weights和biases,让正确边的W和b更“擅长”传递强信号。 
五、多句子处理:批量训练与上下文累积
当输入一批句子(如“我爱自然语言处理”“他喜欢机器学习”)时:
批量初始化:
- 每个句子的初始信号并行计算,
energy_tensor维度为(batch_size, 1, 32)(如(2, 1, 32))。 
- 每个句子的初始信号并行计算,
 批量传播:
- 边的参数通过
param_indices批量索引,例如两个句子的第一条边参数索引为[5, 10],则w = weights[[5,10]](维度(2, 32, 32)),并行计算信号。 
- 边的参数通过
 长上下文处理:
- 每个句子的
energy_cache独立累积信号(如句子1长度5,energy_cache为(1,5,32);句子2长度4,为(1,4,32)); - 位置编码
PE按句子内的位置索引(而非全局位置),确保每个句子的位置从0开始。 
- 每个句子的
 
总结:节点、边与句子的有机结合
- 节点:词汇表中的词,用
node_bias存储基础信号(维度[N,1,32]); - 边:词之间的连接,用
weights和biases控制信号转换(维度[M,32,32]和[M,1,32]); - 信号流动:从初始词开始,沿边传播,用
energy_cache记录历史信号(维度随句子长度增长); - 多句子处理:批量并行计算,每个句子独立累积上下文,通过位置编码区分内部顺序。
 
这种结构让模型像一张“动态词语地图”:训练时优化路径权重,推理时沿最强信号路径生成句子,自然支持任意长度的文本。
1. Vocab 文件是否是初始化进来的?
是的。
 从代码来看,Vocab 类通过加载预定义的词汇表文件(如 vocab_wiki_4k.json)初始化。例如在 train.py 中:
with open(args.vocab_path) as f:
node_dict = json.load(f)  # 加载词汇表文件
vocab = Vocab.from_node_dict(node_dict)  # 初始化 Vocab 实例
词汇表文件中存储了每个 token(如汉字、英文子词)与索引的映射(如 {"的": 2, "国": 12}),这些 token 对应图中的节点,是模型初始化的基础。
2. 边是否是每个单字之间的连接,且一开始会在所有单字之间创建边?
是的,但存在优化。
初始边的创建:根据
Vocab.from_node_dict方法,初始化时会为所有节点(单字或子词)之间创建边。例如,对于词汇表中的任意两个 tokens和t,会生成边(s, t)并存储在edge_dict中:for s in dictname: for t in dictname: edge_dict[s][t] = (dictname[s], dictname[t]) # 所有节点对均创建边这意味着理论上任意两个节点之间都存在边,构成一个全连接图。
优化:稀疏处理:实际训练中,通过
zero_freq_edges标记低频/零频边(如train.py中根据词频过滤),这些边会共享参数(而非为每个边单独分配参数),减少模型大小。
3. 边的参数矩阵是否逻辑上在点之间,实际存储在列表中,通过索引调用?
完全正确。
- 逻辑上:边 
(s, t)对应一个参数矩阵W和偏置b,用于信号从节点s传到节点t的转换。 - 实际存储:参数并非按边存储,而是集中存储在 
weights和biases两个列表中,每个边通过weight_indices映射到列表中的索引。例如:
这种设计避免了存储全连接图中大量冗余参数(尤其是低频边共享参数),提升效率。# 边 (s_idx, t_idx) 映射到参数索引 self.weight_indices[(s_idx, t_idx)] = param_idx # 使用时通过索引获取参数 w = self.weights[param_indices] # (bs, 1+k, hidden_size, hidden_size) b = self.biases[param_indices] # (bs, 1+k, 1, hidden_size) 
4. 输入句子的训练过程详解
以句子“我爱自然”为例,训练过程可拆解为以下步骤:
步骤1:数据准备(构建正负样本)
- 输入句子被拆分为 token 序列:
["我", "爱", "自", "然"]。 - 生成正确路径:每个相邻 token 构成正边,即 
(我,爱) → (爱,自) → (自,然)。 - 生成负样本:对每个位置,随机替换下一个 token 生成错误路径(如 
(我,恨) → (爱,他) → (自,水))。 - 最终输入格式:
neighbor_ids(维度[batch_size, 句子长度, 1+k, 2]),其中1+k表示 1 个正样本 + k 个负样本,最后一维存储边的(源节点索引, 目标节点索引)。 
步骤2:信号初始化与传播
初始信号:对句子第一个 token“我”,初始化信号向量:
energy_tensor = torch.ones(batch_size, 1, hidden_size) / hidden_size # 基础信号 energy_tensor += node_bias["我"] # 节点偏置 energy_tensor += PE[0] # 位置编码(第0个位置) energy_tensor = GeLU(energy_tensor) # 激活沿边传播:对每个 token 位置
i(从 0 到句子长度-1):- 从 
neighbor_ids中取出当前位置的所有候选边(1 正 + k 负)。 - 通过 
weight_indices获取这些边对应的参数(w和b)。 - 计算下一个节点的信号:
# 信号转换公式:e_next = GeLU(e_prev × W + b + PE[i+1]) nxt_energy_tensor = GeLU(expand_energy_tensor.bmm(w) + b + PE[i+1]) - 缓存当前信号(
energy_cache),用于下一个位置的上下文聚合。 
- 从 
 
步骤3:损失计算与参数更新
- 对每个位置的候选边,计算信号强度(L2 范数):
energy = output_tensor.norm(2, (-2, -1)) # 信号强度,维度 [batch_size, 1+k] - 用交叉熵损失迫使正确边的信号强度最大:
label = torch.LongTensor([0 for _ in range(batch_size)]) # 正样本在第0位 loss += CrossEntropyLoss()(energy, label) # 正确边的概率应最高 - 通过反向传播更新 
weights和biases,让正确路径的参数更“擅长”传递强信号。 
总结
- 节点:来自词汇表,每个 token 对应一个节点,有独立偏置。
 - 边:初始为全连接,通过参数索引映射到共享/独立参数,避免冗余。
 - 训练:通过正负样本对比,优化边的参数,使正确句子的信号传播路径最强,最终学会预测下一个 token。
 
1. 中文词表4000词时的初始边数量
不是。
 根据代码中 Vocab.from_node_dict 方法,边的定义是任意两个节点之间均存在一条边(包括自身到自身)。若词表有 N 个节点(如4000),则边的总数为 N × N(4000×4000=16,000,000),而非 (N-1)×N。
代码中明确对所有 s 和 t 生成边:
for s in dictname:
for t in dictname:  # 包括 s == t 的情况
edge_dict[s][t] = (dictname[s], dictname[t])
2. 边的参数矩阵结构、分布及参数量
结构与尺寸
每条边的参数由两部分组成:
- 权重矩阵 
W:形状为(hidden_size, hidden_size)(代码中hidden_size=32),即32×32的矩阵,用于信号从源节点到目标节点的线性变换。 - 偏置向量 
b:形状为(1, hidden_size),即1×32的向量,用于调整变换后的信号偏移。 
参数分布
初始化时,权重和偏置均采用均匀分布,范围为 [-0.5, 0.5]:
self.weights = nn.Parameter(torch.randn(current_idx, self.hidden_size, self.hidden_size).uniform_(-0.5, 0.5))
self.biases = nn.Parameter(torch.randn(current_idx, 1, self.hidden_size).uniform_(-0.5, 0.5))
参数量(不考虑共享)
若4000个节点间的边均不共享参数,则总参数量为:
- 权重:
4000×4000 × 32×32 = 16,000,000 × 1024 = 16,384,000,000(163.84亿) - 偏置:
4000×4000 × 32 = 16,000,000 × 32 = 512,000,000(5.12亿) - 合计约 169亿参数,与 README 中提到的“原始大小约16B”一致(可能因节点数或隐藏层尺寸微调)。
 
3. 训练过程详解(宏观到具体)
以句子“我爱自然”为例,分步骤说明:
宏观流程
训练的核心目标是:让模型学会“正确路径”的信号传递强度高于“错误路径”。
- 正确路径:句子中相邻token的边(如 
我→爱、爱→自、自→然)。 - 错误路径:随机替换下一个token生成的边(如 
我→恨、爱→他、自→水)。 - 训练时通过对比学习,强化正确边的参数,弱化错误边的参数。
 
具体步骤(以“我→爱”为例)
步骤1:数据准备(构建正负样本)
- 输入句子拆分为token序列:
["我", "爱", "自", "然"]。 - 对每个位置生成 1个正样本 + k个负样本(例如k=3): 
 
- 位置0(“我”的下一个token): 
 
- 正样本:
(我, 爱)(正确边) - 负样本:
(我, 恨)、(我, 你)、(我, 他)(随机错误边) 
 - 正样本:
 - 最终输入格式:
neighbor_ids(维度[batch_size, 句子长度, 1+k, 2]),存储边的(源节点索引, 目标节点索引)。 
 - 位置0(“我”的下一个token): 
 
 
步骤2:信号初始化与传播
初始信号:对“我”初始化信号向量
e0,加入节点偏置和位置编码(第0位),经GeLU激活:e0 = activation(ones + node_bias["我"] + PE[0]) # (bs, 1, 32)沿边传播:对位置0的所有候选边(1正+3负)计算下一个节点的信号:
- 通过 
weight_indices获取4条边对应的参数索引,取出W和b:w = weights[[idx(我→爱), idx(我→恨), idx(我→你), idx(我→他)]] # (4, 32, 32) b = biases[[idx(我→爱), idx(我→恨), idx(我→你), idx(我→他)]] # (4, 1, 32) - 计算每条边的输出信号:
e_next = activation(e0 × W + b + PE[1]) # (bs, 4, 1, 32),4对应1+k 
- 通过 
 
步骤3:损失计算与参数更新
- 信号强度评估:计算每条边输出信号的L2范数(能量),值越大表示信号传递越强:
energy = e_next.norm(2, (-2, -1)) # (bs, 4),如 [5.2, 3.1, 2.8, 2.5] - 对比损失:用交叉熵损失迫使正样本(第0位)的能量最大:
label = [0, 0, ..., 0] # 正样本在第0位 loss = CrossEntropyLoss(energy, label) # 惩罚正样本能量低于负样本的情况 - 参数更新:通过反向传播调整边的 
W和b:- 正样本边(
我→爱):参数被优化以增强信号传递(增大能量)。 - 负样本边(
我→恨等):参数被调整以减弱信号传递(减小能量)。 
 - 正样本边(
 
步骤4:迭代所有位置
对句子中每个位置(如“爱→自”“自→然”)重复步骤2-3,累计损失并更新参数,最终让整个正确路径的信号强度显著高于其他路径。
关键说明
- 并非所有边都参与每次训练:每次仅计算当前位置的 
1+k条候选边(1条正确+k条错误),而非全量边,否则计算量过大。 - 错误边不会被“一棒子打死”:负样本是随机选择的,同一错误边(如 
我→恨)可能在其他句子(如“我恨敌人”)中成为正样本,此时会被强化。模型通过大量样本学习不同语境下的边权重。 - 语义的形成:虽然单条边是“字-字”连接,但通过整个句子的路径训练,模型会学习到“我→爱”在“自然”前的组合强度高于其他组合,从而隐含语义关联。
 
总结:训练过程通过“局部对比学习”(每个位置的正负样本)优化边参数,使正确路径的信号传递最强,最终实现对句子序列的建模。
1. 对比学习与无监督学习的合理性
是的,BriLLM的对比学习设计确实与CLIP的“对比信号匹配”思路有相似之处:CLIP通过对比图像与文本的匹配信号学习语义关联,而BriLLM通过对比“正确边”与“错误边”的信号能量(L2范数)学习token间的共现规律。
这种设计在无监督场景下非常合理:
- 文本序列本身蕴含天然的监督信号(相邻token的共现),无需人工标注即可构建正负样本;
 - 通过强化正确路径、弱化错误路径,模型能自动挖掘语言中的统计规律(如“我→爱”比“我→恨”在特定语境中更可能出现),符合人类语言学习中“从高频共现中归纳规则”的过程。
 
2. 32×32矩阵与MLP的关联
每条边的32×32矩阵本质上是单个全连接层的参数,与MLP的关联可从两方面理解:
结构上的关联
- MLP的核心是“输入→线性变换→激活函数→输出”的堆叠,而BriLLM中每条边的计算完全符合这一模式:
这里的# 等价于单隐层MLP(无堆叠) e_next = GeLU(W @ e_prev + b + PE) # W是32×32矩阵,对应MLP的线性变换层W和b就是单隐层MLP的权重和偏置,只是每条边有独立的参数(不共享时),而非MLP中“所有样本共享同一套参数”。 
与传统MLP的区别
- 传统MLP是“层内共享参数”(所有输入样本通过同一套权重计算),而BriLLM是“边内独立参数”(每条边有专属的 
W和b); - MLP通过多层堆叠增加非线性能力,而BriLLM通过信号在路径上的多步传播(如“我→爱→你”的两步计算)实现类似“多层”的累积效应,替代了MLP的层堆叠。
 
3. 上下文特征的处理与模型演进思考
当前模型对上下文的处理方式
BriLLM并非完全忽略上下文,而是通过累积能量缓存(energy_cache) 融合历史信息:
# 每个位置的信号是历史信号的加权和(通过位置参数softmax加权)
energy_tensor = (energy_cache * self.positions[:, :i, :].softmax(1)).sum(1, keepdim=True)
这里的 energy_cache 存储了从句子开头到当前位置的所有信号,通过位置参数的软权重融合,实现了对“上下文历史”的依赖。例如,“爱”的信号不仅取决于“我”,还隐含了更早的token(如“我→爱→你”中,“你”的信号受“我”和“爱”共同影响)。
但这种方式的局限性在于:
- 上下文融合仅通过简单的加权和实现,缺乏Transformer中自注意力的精细关联建模;
 - 边参数的优化主要依赖局部相邻token(如“我→爱”),长距离依赖的学习能力较弱。
 
关于模型演进的思考
你的观点很有道理,BriLLM的设计更接近“词级共现学习”,类似人类幼儿对词语搭配的初步掌握,而高级语义(如句子意图、篇章逻辑)的学习确实需要进一步增强:
引入上下文感知的边参数
目前每条边的W和b是固定的(仅与(s, t)相关),可考虑让边参数依赖于更长的上下文,例如:# 边参数不再固定,而是上下文的函数(如通过一个小型MLP生成) W_{s,t} = MLP(context_embedding) # context_embedding包含历史token信息增强长距离信号传播
当前信号主要沿相邻边传播,可加入“跳跃边”(如直接从第0个token到第2个token的边),或通过注意力机制动态调整历史信号的权重,更精准地捕捉长距离依赖。分层训练策略
- 初级阶段:如现有设计,学习基础的词级共现规律(类似幼儿学词语);
 - 高级阶段:引入句子级、篇章级的对比信号(如正确句子与打乱句子的对比),让模型学习更高层的语义结构。
 
总结:BriLLM的核心优势是通过图结构实现可解释性和灵活的上下文长度,但在上下文特征的精细建模上确有提升空间。你的“分层训练”思路符合语言学习的认知规律,是非常合理的演进方向。
1. 边的信息容量限制与语义差异化需求
你的观察很准确:当前每条边仅通过“32×32矩阵+偏置+GeLU”实现单步线性变换( e i + 1 = GeLU ( W u , v e i + b u , v + P E i ) e_{i+1} = \text{GeLU}(W_{u,v}e_i + b_{u,v} + PE_i) ei+1=GeLU(Wu,vei+bu,v+PEi)),这种固定维度的单层映射确实难以承载不同语境下的语义差异。例如“我→们”这个边,在“我们一起吃饭”和“我们的历史”中,“我们”的语义侧重不同,但当前模型只能用同一套 W W W和 b b b处理,无法区分语境。
潜在改进方向:
- 让边参数依赖上下文信息,例如将历史能量缓存(
energy_cache)作为边参数生成的输入:KaTeX parse error: Expected 'EOF', got '_' at position 25: … f(\text{energy_̲cache}),使“我→们”的映射随前文动态变化; - 增加边的非线性能力,例如堆叠多层MLP作为边的变换函数,提升信息容量。
 
2. 稀疏策略中低频边共享参数的局限性
代码中对低频边(zero_freq_edges)采用共享参数(shared_param_idx=0)的设计(见model.py的prepare_network方法),确实存在“无关联词汇共享参数”的问题。例如“蜥蜴→X”和“穷寇→Y”若均为低频边,会共用同一套
 
 
 
 
 W
 
 
 
 W
 
 
 W和
 
 
 
 
 b
 
 
 
 b
 
 
 b,但二者语义毫无关联,参数共享缺乏合理性。
原因分析:
 这种设计本质是为了压缩模型规模(原文提到从16B降至2B),通过“用固定参数覆盖低频边”减少参数量,但忽略了低频边之间的语义独立性。
优化思路:
- 基于语义聚类共享参数:对低频边按语义相似度聚类(如通过预训练词向量),同类边共享参数(例如“蜥蜴”和“鳄鱼”同属爬行动物,共享参数);
 - 动态生成低频边参数:用一个小型生成器(如MLP)根据边的源节点和目标节点特征动态生成参数,避免固定共享。
 
3. 上下文能量传递的具体含义与演进示例
能量的定义
在BriLLM中,“能量”是信号张量的L2范数( ∥ e ∥ 2 \|e\|_2 ∥e∥2),反映当前节点信号的“强度”或“置信度”。信号张量 e e e是一个32维向量,其数值变化由边的线性变换、偏置、位置编码(PE)和激活函数共同决定,能量则是对这个向量“大小”的量化。
从 W x + b Wx+b Wx+b看能量演进(以 k = 1 k=1 k=1为例,即1个正样本+1个负样本)
假设当前处理序列“我→们→?”,以“们”到下一个token的预测为例:
输入能量张量:
假设“们”的信号张量为 e 们 ∈ R 32 e_{\text{们}} \in \mathbb{R}^{32} e们∈R32,其能量为 ∥ e 们 ∥ 2 \|e_{\text{们}}\|_2 ∥e们∥2,代表“们”在当前上下文的累积信号强度。边的参数与变换:
- 正样本边(们→我):参数为 W 1 ∈ R 32 × 32 W_1 \in \mathbb{R}^{32×32} W1∈R32×32, b 1 ∈ R 32 b_1 \in \mathbb{R}^{32} b1∈R32;
 - 负样本边(们→他):参数为 W 2 ∈ R 32 × 32 W_2 \in \mathbb{R}^{32×32} W2∈R32×32, b 2 ∈ R 32 b_2 \in \mathbb{R}^{32} b2∈R32;
 - 位置编码 P E i PE_i PEi(第 i i i步的正弦余弦编码)。
 
信号变换与能量计算:
- 正样本信号: e 1 = GeLU ( W 1 ⋅ e 们 + b 1 + P E i ) e_1 = \text{GeLU}(W_1 \cdot e_{\text{们}} + b_1 + PE_i) e1=GeLU(W1⋅e们+b1+PEi),能量 E 1 = ∥ e 1 ∥ 2 E_1 = \|e_1\|_2 E1=∥e1∥2;
 - 负样本信号: e 2 = GeLU ( W 2 ⋅ e 们 + b 2 + P E i ) e_2 = \text{GeLU}(W_2 \cdot e_{\text{们}} + b_2 + PE_i) e2=GeLU(W2⋅e们+b2+PEi),能量 E 2 = ∥ e 2 ∥ 2 E_2 = \|e_2\|_2 E2=∥e2∥2。
 
能量的作用:
模型通过交叉熵损失优化,使正样本能量 E 1 E_1 E1大于负样本能量 E 2 E_2 E2(即 arg  max  ( E 1 , E 2 ) = 0 \arg\max(E_1, E_2) = 0 argmax(E1,E2)=0,对应正样本标签)。训练过程中, W W W和 b b b被调整以强化“合理边”的能量,弱化“不合理边”的能量。上下文累积:
每一步的信号会存入energy_cache,后续步骤通过位置参数的softmax加权融合历史信号(energy_tensor = (energy_cache * positions.softmax(1)).sum(1)),使当前能量不仅依赖前一个节点,还隐含了更早的上下文信息(例如“我→们”的能量会影响“们→?”的计算)。
总结:能量本质是信号张量的强度量化,边的 W x + b Wx+b Wx+b变换直接决定信号的演化方向,而训练的核心是让合理路径的能量持续增强,从而实现对正确token序列的预测。当前模型的局限性可通过动态边参数和语义聚类共享策略进一步优化。
1. 推理过程:输入“我”时如何生成下一个token?
推理时并非计算所有边,而是通过候选边筛选+能量最大化实现,步骤如下(以输入“我”为例):
初始信号初始化:
输入“我”作为起始节点,根据get_initial_tensor生成初始能量张量e0(32维向量),并结合节点偏置(node_bias)和位置编码(PE[0])激活,得到初始信号。候选边筛选:
通过vocab.get_neighbor_of_node(x, -1)获取“我”的所有可能后续节点(如“们”“爱”“是”等),形成候选边集合(“我→们”“我→爱”“我→是”等)。能量计算:
对每个候选边,使用边参数(W和b)计算下一个节点的信号张量:e_next = GeLU(W @ e_prev + b + PE[i]) # e_prev是“我”的信号,i是当前位置再通过L2范数计算能量:
energy = ||e_next||₂。选择最大能量边:
对所有候选边的能量做softmax得到概率,通过argmax(或采样)选择能量最大的边,确定下一个token(如“们”)。
总结:推理时仅计算当前节点的所有可能后续边的能量,而非全图所有边,通过能量最大化选择下一个token。
2. 能量(L2范数)的作用
是的,能量定义为32维信号张量的L2范数(||e||₂ = √(e₁² + e₂² + ... + e₃₂²)),核心作用是将高维向量压缩为标量以方便比较:
- 32维向量本身难以直接比较“强度”,L2范数通过计算向量的“模长”,将其转化为单个数值,量化信号的整体激活强度;
 - 在训练和推理中,能量值直接用于判断边的“合理性”:正样本边需要能量更高,推理时通过能量大小选择最可能的下一个token。
 
例如,“我→们”的信号向量可能是[0.2, 0.5, ..., 0.3],其L2范数为0.6;“我→他”的向量可能是[0.1, 0.2, ..., 0.1],范数为0.3,因此“我→们”被认为更合理。
3. energy_cache:上下文历史信号的融合与使用
energy_cache是存储历史信号的缓存区,用于在长句子中融合前文信息,避免孤立处理单个节点。其核心逻辑是通过位置权重加权求和,将历史信号整合为当前步骤的输入。
举例:句子“我→爱→你”的能量传递过程
假设句子长度为3,步骤i=0(“我”)、i=1(“爱”)、i=2(“你”):
步骤i=0(处理“我”):
- 初始信号
e0(“我”的信号张量)通过“我→爱”的边计算得到e1(“爱”的信号); energy_cache初始化,存储e1:energy_cache = [e1](形状:(bs, 1, 32))。
- 初始信号
 步骤i=1(处理“爱”):
- 从
energy_cache中提取历史信号(仅e1),结合位置参数positions[:, :1, :]的softmax权重(此时只有1个位置,权重为1),融合得到当前输入信号:energy_tensor = (energy_cache * positions[:, :1, :].softmax(1)).sum(1, keepdim=True) # 结果为e1 - 通过“爱→你”的边计算得到
e2(“你”的信号); - 更新
energy_cache,拼接e2:energy_cache = [e1, e2](形状:(bs, 2, 32))。 
- 从
 步骤i=2(预测下一个token,如“们”):
- 从
energy_cache中提取前2步的信号[e1, e2],位置参数positions[:, :2, :]通过softmax生成权重(例如[0.3, 0.7],越近的token权重越高); - 加权求和得到当前输入信号:
energy_tensor = 0.3*e1 + 0.7*e2 # 融合“我”和“爱”的历史信息 - 用该信号计算“你→们”“你→的”等候选边的能量,选择最大者作为下一个token。
 
- 从
 
关键特点:
energy_cache按顺序存储每个步骤的信号,长度随句子增长;- 位置参数
positions的softmax权重动态调整历史信号的贡献(通常近期token权重更高); - 融合后的信号同时包含前文所有节点的信息,避免单个节点的孤立计算。
 
简言之,energy_cache通过“历史信号+位置权重”的方式,让模型在生成每个token时都能“回顾”前文,模拟上下文依赖。
1. 全互联情况下neighbor_ids的选择逻辑
尽管模型理论上是全互联的(每个节点与所有其他节点存在边),但neighbor_ids(候选边集合)的选择并非遍历所有可能边,而是通过以下方式筛选:
训练阶段:
对于输入序列中的每个位置,neighbor_ids包含1个正样本边(即序列中实际出现的下一个token对应的边)和k个负样本边(通过Vocab.get_neighbor_of_edge生成)。负样本的选择逻辑为:- 基于词频筛选:优先从高频共现的token中选择与正样本不同的token作为负样本(利用
frequency_dict); - 随机采样:若词频信息不存在,则从当前节点的所有可能邻居中随机选择排除正样本后的token。
 
- 基于词频筛选:优先从高频共现的token中选择与正样本不同的token作为负样本(利用
 推理阶段:
通过Vocab.get_neighbor_of_node获取当前节点的所有可能后续节点(排除自身),作为候选边。例如输入“我”时,候选边包括“我→们”“我→爱”等常见搭配,而非所有词汇。
简言之,全互联是结构上的可能性,但neighbor_ids通过词频或随机性筛选出有限候选边,避免计算量爆炸。
2. 网络本质:能量信号的流动与迭代计算
是的,该网络的核心确实是能量信号在节点间的流动与迭代更新。具体过程可拆解为:
- 信号初始化:从起始节点(如输入的第一个token)生成初始能量张量(
energy_tensor),其值为32维向量,反映初始信号强度。 - 边的线性变换:对于每个候选边(
u→v),通过边参数(W和b)对当前能量信号进行线性变换:W @ e_prev + b,其中e_prev是前一个节点的信号。 - 激活与位置编码:加入位置编码(
PE)并通过GeLU激活,得到下一个节点的信号张量e_next(即nxt_energy_tensor)。 - 能量量化:通过L2范数计算
e_next的能量(标量),能量越高代表该边的“合理性”越强。 - 迭代流动:选择能量最大的边作为下一个节点,重复上述过程,实现信号在图中的流动。
 
公式nxt_energy_tensor = activation(W @ e_prev + b + PE)正是信号流动的核心计算,本质是通过边参数和上下文(位置编码)更新信号,驱动网络从当前节点“流向”下一个节点。
3. 输出停止条件:固定长度限制
当前代码中,输出结束的判断并非基于终止符(如<END>),而是通过预设最大长度控制:
- 训练阶段:输入序列长度固定(由
max_seq_length参数指定,如32),模型按固定步数迭代,无需主动停止。 - 推理阶段:通过
max_new_tokens参数限制生成的token数量(默认16)。例如model.decode(..., max_new_tokens=32)会生成32个新token后停止,无论语义是否完整。 
代码中未体现基于语义的自动停止逻辑(如检测句号、换行符等),因此输出长度完全由max_new_tokens控制。若需实现自动停止,需额外添加对终止符的检测(如当生成“。”“!”等符号时停止)。
BraLM 类详细解析
BraLM 是模型的核心类,封装了模型的网络结构、参数初始化、前向传播(训练)和生成(推理)逻辑。以下是其核心组件和方法的详细说明:
1. 初始化方法 __init__
def __init__(self, hidden_size, use_ds=False, zero_freq_edges=None, vocab=None):
super().__init__()
self.hidden_size = hidden_size  # 隐藏层维度(如32)
self.activation = nn.GELU()    # 激活函数
self.positions = nn.Parameter(torch.ones(1, 512, 1))  # 位置权重参数(用于历史信号融合)
self.device = None  # 设备(CPU/GPU)
# 用于分布式训练(FSDP)的绑定权重键
self._tied_weights_keys = []
self.use_ds = use_ds  # 是否使用DeepSpeed混合精度训练
self.zero_freq_edges = zero_freq_edges  # 零频率边(用于共享参数)
self.vocab = vocab  # 词汇表实例
- 作用:初始化模型的基础参数,包括隐藏层维度、激活函数、位置权重等,预留设备和分布式训练相关变量。
 
2. 网络参数准备 prepare_network
def prepare_network(self, vocab):
self.weight_indices = {}  # 映射 (源节点索引, 目标节点索引) 到参数索引
self.shared_param_idx = 0  # 共享参数的索引(用于零频率边)
current_idx = 1  # 新参数的起始索引
# 遍历所有边,为每个边分配参数索引
for s_idx, s in enumerate(vocab.edge_dict):
for t_idx, t in enumerate(vocab.edge_dict[s]):
# 零频率边使用共享参数,否则分配新参数
if self.zero_freq_edges is not None and t in self.zero_freq_edges[s]:
self.weight_indices[(s_idx, t_idx)] = self.shared_param_idx
else:
self.weight_indices[(s_idx, t_idx)] = current_idx
current_idx += 1
# 初始化权重和偏置参数
self.weights = nn.Parameter(torch.randn(current_idx, self.hidden_size, self.hidden_size).uniform_(-0.5, 0.5))
self.biases = nn.Parameter(torch.randn(current_idx, 1, self.hidden_size).uniform_(-0.5, 0.5))
# 节点偏置(每个节点的初始偏置)
self.node_bias = nn.Parameter(torch.randn(len(vocab.edge_dict), 1, self.hidden_size).uniform_(-0.5, 0.5))
- 核心逻辑: 
 
- 为每个边(
s→t)分配唯一的参数索引,零频率边(罕见边)共享同一组参数(减少参数量)。 - 初始化边的权重(
weights)和偏置(biases),以及每个节点的初始偏置(node_bias)。 
 - 为每个边(
 
3. 设备迁移 to_device
def to_device(self, device):
self.weights.to(device)
self.biases.to(device)
self.positions.data = self.positions.data.to(device)
self.device = device
- 作用:将模型的所有参数(权重、偏置、位置权重)迁移到指定设备(CPU/GPU)。
 
4. 辅助函数 _reshape12
@staticmethod
def _reshape12(x):
return x.reshape(-1, x.size(-2), x.size(-1))
- 作用:将张量重塑为 
(batch_size * num_candidates, 1, hidden_size)格式,方便批量矩阵运算。 
5. 位置编码生成 get_positional_encoding
def get_positional_encoding(self, seq_len, d_model):
position = torch.arange(0, seq_len).reshape(-1, 1)
div_term = 10000.0 ** (torch.arange(0, d_model, 2) / d_model)
position_encoding = torch.zeros(seq_len, d_model)
position_encoding[:, 0::2] = torch.sin(position * div_term)  # 偶数维度用正弦
position_encoding[:, 1::2] = torch.cos(position * div_term)  # 奇数维度用余弦
return position_encoding.unsqueeze(0).to(self.device)
- 作用:生成正弦余弦位置编码,用于注入序列的位置信息(类似Transformer),形状为 
(1, seq_len, hidden_size)。 
6. 初始能量张量生成 get_initial_tensor
def get_initial_tensor(self, batch_size, d, pe):
# 初始化能量张量为均匀分布(1/hidden_size)
energy_tensor = torch.ones(batch_size, 1, self.hidden_size) / self.hidden_size
energy_tensor = energy_tensor.to(self.device)
# 加入节点偏置和初始位置编码(第0个位置)
node_bias = self.node_bias[d[:, 0, 0]]  # d是输入的边索引,取源节点的偏置
energy_tensor = self.activation(energy_tensor + node_bias + Variable(pe[:,0], requires_grad=False))
return energy_tensor
- 作用:生成序列第一个节点的初始能量张量(
energy_tensor),融合节点偏置和位置编码,激活后作为初始信号。 
7. 前向传播(训练) forward
def forward(self, neighbor_ids):
# neighbor_ids: (batch_size, seq_len, 1+k, 2),其中1是正样本,k是负样本
batch_size = neighbor_ids.size(0)
loss = 0
pe = self.get_positional_encoding(512, self.hidden_size)  # 生成位置编码
for i in range(neighbor_ids.size(1)):  # 遍历序列长度
d = neighbor_ids[:, i]  # 当前位置的边(正样本+负样本):(bs, 1+k, 2)
# 初始化或更新能量张量(融合历史信号)
if i == 0:
energy_tensor = self.get_initial_tensor(batch_size, d, pe)  # 第一个节点的初始信号
else:
# 历史信号加权求和(位置权重softmax后加权)
energy_tensor = (energy_cache * self.positions[:, :i, :].softmax(1)).sum(1, keepdim=True)
# 获取当前边的源节点和目标节点索引
src_idx = d[..., 0]  # (bs, 1+k)
tgt_idx = d[..., 1]  # (bs, 1+k)
# 查找每个边对应的参数索引
param_indices = torch.tensor([self.weight_indices.get((s.item(), t.item()), self.shared_param_idx)
for s, t in zip(src_idx.reshape(-1), tgt_idx.reshape(-1))],
device=self.device).reshape(batch_size, -1)  # (bs, 1+k)
# 批量获取权重和偏置
w = self.weights[param_indices]  # (bs, 1+k, hidden_size, hidden_size)
b = self.biases[param_indices]   # (bs, 1+k, 1, hidden_size)
# 扩展能量张量并进行矩阵运算(更新信号)
expand_energy_tensor = self._reshape12(energy_tensor.unsqueeze(1).repeat(1, w.size(1), 1, 1))  # (bs*(1+k), 1, hs)
if self.use_ds:  # DeepSpeed混合精度
expand_energy_tensor = expand_energy_tensor.half()
# 计算下一个节点的能量张量(信号流动核心)
nxt_energy_tensor = self.activation(expand_energy_tensor.bmm(self._reshape12(w))
+ self._reshape12(b)
+ Variable(pe[:,i+1], requires_grad=False))  # (bs*(1+k), 1, hs)
output_tensor = nxt_energy_tensor.reshape(batch_size, -1, nxt_energy_tensor.size(-2), nxt_energy_tensor.size(-1))  # (bs, 1+k, 1, hs)
# 更新历史信号缓存(只保留正样本的信号)
if i == 0:
energy_cache = output_tensor[:,0]  # (bs, 1, hs)
else:
energy_cache = torch.cat([energy_cache, output_tensor[:,0]], dim=1)  # (bs, i+1, hs)
# 计算损失(正样本能量应高于负样本)
energy = output_tensor.norm(2, (-2, -1))  # 计算L2范数作为能量(标量)
label = torch.LongTensor([0 for _ in range(batch_size)]).to(self.device)  # 正样本索引为0
loss += nn.CrossEntropyLoss()(energy, label)  # 交叉熵损失(最大化正样本概率)
return loss / neighbor_ids.size(1)  # 平均损失
- 核心逻辑: 
 
- 输入为批量的边序列(含正/负样本),通过循环逐位置计算信号流动。
 - 每个位置的信号(
energy_tensor)通过边参数(w/b)、位置编码更新,并缓存历史信号用于后续步骤。 - 损失函数通过交叉熵实现:要求正样本(索引0)的能量高于负样本,驱动模型学习合理的边权重。
 
 
8. 生成(推理) decode
def decode(self, start, vocab, max_new_tokens=16, do_sample=False, temperature=1):
ret = []  # 存储生成的边序列
pe = self.get_positional_encoding(512, self.hidden_size)  # 位置编码
# 处理起始序列(start是初始边列表)
for i, pair in enumerate(start):
if i == 0:
# 初始化第一个节点的能量张量
energy_tensor = self.get_initial_tensor(batch_size=1, d=torch.tensor([[pair]], device=self.device), pe=pe).squeeze(0)
else:
# 融合历史信号
energy_tensor = (energy_cache * self.positions[:, :i, :].softmax(1)).sum(1, keepdim=True).squeeze(0)
# 获取当前边的参数并更新信号
param_idx = self.weight_indices.get((pair[0], pair[1]), self.shared_param_idx)
w = self.weights[param_idx].to(self.device)
b = self.biases[param_idx].to(self.device)
energy_tensor = self.activation(energy_tensor.mm(w) + b + pe.squeeze(0)[i])
# 更新历史缓存
if i == 0:
energy_cache = energy_tensor.unsqueeze(0)
else:
energy_cache = torch.cat([energy_cache, energy_tensor.unsqueeze(0)], dim=1)
ret += [pair]  # 记录边
# 从起始序列的最后一个节点开始生成新token
x = pair[1]  # 最后一个节点的目标节点
prev_i = len(start)
for i in range(max_new_tokens):  # 生成max_new_tokens个新token
# 获取当前节点x的候选后续节点
candidates = vocab(vocab.get_neighbor_of_node(x, -1))  # 候选边列表
# 获取所有候选边的参数索引
param_indices = torch.tensor([self.weight_indices.get((x, t[1]), self.shared_param_idx)
for t in candidates], device=self.device)
all_w = self.weights[param_indices].to(self.device)
all_b = self.biases[param_indices].to(self.device)
# 融合历史信号
curr_i = prev_i + i
energy_tensor = (energy_cache * self.positions[:, :curr_i, :].softmax(1)).sum(1, keepdim=True)
expand_energy_tensor = self._reshape12(energy_tensor.unsqueeze(1).repeat(1, all_w.size(0), 1, 1))
# 计算候选边的能量张量
nxt_energy_tensor = self.activation(expand_energy_tensor.bmm(self._reshape12(all_w))
+ self._reshape12(all_b)
+ pe[:,curr_i].unsqueeze(0))
output_tensor = nxt_energy_tensor.reshape(1, -1, nxt_energy_tensor.size(-2), nxt_energy_tensor.size(-1))
# 计算能量并选择最优候选
energy = output_tensor.norm(2, (-2,-1)).squeeze()  # 能量标量
probs = torch.softmax(energy / temperature, dim=-1)  # 概率分布(带温度调节)
if do_sample:
index = torch.multinomial(probs, 1).item()  # 采样
else:
index = probs.argmax(-1).item()  # 贪心选择
# 更新结果和状态
y = candidates[index][-1]  # 选中的目标节点
ret += [(x, y)]  # 记录新边
energy_tensor = output_tensor[0, index]  # 更新当前信号
x = y  # 下一个节点的源节点
energy_cache = torch.cat([energy_cache, energy_tensor.unsqueeze(0)], dim=1)  # 更新缓存
return ret
- 核心逻辑: 
 
- 输入起始序列(
start),先处理初始边并初始化历史信号缓存。 - 循环生成新token:对当前节点的所有候选后续节点计算能量,通过softmax(或采样)选择最优边,更新信号和缓存。
 - 生成长度由 
max_new_tokens控制,最终返回完整的边序列(即生成的token序列)。 
 - 输入起始序列(
 
总结
BraLM 类通过能量信号在节点间的流动实现序列建模:
- 训练时,通过对比正/负样本边的能量(L2范数)更新参数,让合理的边具有更高能量。
 - 推理时,基于当前节点的候选边能量选择下一个节点,结合历史信号和位置信息生成序列。
 - 核心创新在于边参数化(每个边有独立或共享的权重)和历史信号融合(通过位置权重动态调整前文影响)。
 
1. 辅助函数 _reshape12 的作用解析
代码中的 _reshape12 函数定义为:
@staticmethod
def _reshape12(x):
return x.reshape(-1, x.size(-2), x.size(-1))
其核心作用是将输入张量重塑为 (批次相关维度乘积, 倒数第二维, 最后一维) 的格式,具体来说:
x.size(-2)表示取输入张量的倒数第二维大小(例如在代码中通常是1,对应能量张量的序列长度维度)。x.size(-1)表示取输入张量的最后一维大小(即hidden_size,模型的隐藏层维度)。-1表示自动计算该维度的大小,确保总元素数量不变(通常是batch_size * num_candidates,即批次大小乘以候选节点数量)。
这个重塑的目的是为了 统一批量矩阵运算的维度。例如,在模型中,能量张量可能需要与多个候选节点的权重矩阵(w)进行批量矩阵乘法(bmm),而 _reshape12 可以将不同形状的输入调整为兼容的维度,避免手动计算维度转换,简化代码并提高效率。
2. 关于“维持能量库并通过向量相似性选取节点”的想法分析
你的思路具有一定的合理性,具体可以从以下角度分析:
- 当前机制:模型通过计算能量向量的 L2 范数(
output_tensor.norm(2, (-2, -1)))将高维向量投影到 1 维,再用交叉熵损失优化,本质是通过“能量大小”判断候选节点的合理性。 - 你的提议:维持一个“能量库”(例如正样本的能量向量集合),通过计算候选向量与库中向量的相似度(如余弦相似度)来选取节点,更符合“语义相似性”的直觉。
 
可行性:
- 优点:语义相似性可能更直接地捕捉语法或语义规律,尤其对于低频或未见过的组合,相似性匹配可能比单纯的能量大小更鲁棒。
 - 挑战: 
 
- 能量库的构建需要存储大量正样本向量,可能增加内存开销。
 - 训练目标需要调整(例如用对比学习损失优化相似性),与现有交叉熵框架差异较大。
 - 如何定义“合理的相似性阈值”需要进一步设计,避免引入噪声。
 
 
总体而言,这是一个值得尝试的改进方向,可通过对比实验验证其效果。
3. neighbor_ids 的选取逻辑及收敛性保障
neighbor_ids 是训练中每个节点的候选边(包括 1 个正样本和 num_neg_samples 个负样本),其选取逻辑在 Vocab.get_neighbor_of_edge 中实现:
- 正样本:当前边本身(例如输入序列中实际存在的 
s->t)。 - 负样本: 
 
- 若使用词频(
use_frequency=True),则从高频候选节点中随机选择(排除正样本),确保负样本是“可能出现但实际未出现”的合理候选。 - 若不使用词频,则从当前节点的所有可能邻居中随机选择(排除正样本)。
 
 - 若使用词频(
 
关于“候选区域无正确答案”的问题:
- 不会出现。因为 
neighbor_ids的第一个元素始终是正样本(当前边),训练目标是让模型在正负样本中正确识别正样本(通过交叉熵损失最大化正样本的能量),因此收敛的前提是存在明确的监督信号。 
4. decode 函数中第一个 for 循环的作用理解
你的理解是正确的。第一个 for 循环:
for i, pair in enumerate(start):
# 计算每个初始边的能量张量并缓存
...
的作用是 处理已知的输入序列(上下文):
start是由用户输入(如对话历史、前缀文本)转换而来的初始边列表(例如“《罗马》描述了”转换为["《->罗", "罗->马", ...])。- 循环通过计算每个初始边的能量张量并缓存(
energy_cache),将上下文信息编码为模型可理解的“能量状态”。 - 后续的推理(第二个 
for循环)会基于这个缓存的上下文能量,继续生成后续的节点和边,实现“基于已知上下文预测下文”的逻辑。 
这与大模型的“上下文编码-生成”流程一致,本质是将输入信息转化为模型内部状态,再基于该状态进行续写。
                    
                
                
            
        
浙公网安备 33010602011771号