docs-merge-08
TowardsDataScience 2024 中文翻译(九)
用机器学习预测金融行业的欺诈:一位数据科学家的经验
数据科学家在前线的见解和经验
·发表于 Towards Data Science ·阅读时间:5 分钟·2024 年 6 月 7 日
--

图片来源:Growtika,来自 Unsplash
你好,亲爱的数据爱好者!我很高兴通过几篇文章与大家分享我在过去三年里,开发机器学习模型预测金融行业欺诈的经验。因此,如果你在欺诈检测项目中担任项目经理、数据科学家、机器学习工程师、数据工程师、Mlops 工程师、欺诈分析师或产品经理等角色,你可能会觉得这篇文章对你有帮助。
在本系列的第一篇文章中,我想讨论以下几个要点:
需要解决的业务问题是什么
项目的高层次步骤
业务问题
每天,全球有数百万人使用汇款服务。这些服务帮助我们将钱汇给亲人,并让购物变得更方便。但是,欺诈者利用这些系统欺骗他人将钱转给他们,或接管他们的账户进行欺诈。这不仅会伤害受害者,还会导致参与公司遭受财务损失并损害声誉。此外,系统中的公司和责任方还面临监管和合规风险(例如,2017 年西联汇款因未能保持有效的反洗钱和消费者欺诈系统而被罚款 5.86 亿美元)。在资金落入欺诈者手中之前预测欺诈交易对于公司至关重要。这正是人工智能/机器学习驱动的欺诈管理工具发挥作用的地方。
公司的目标主要是最小化运营成本、提升客户体验或减少欺诈和损失。
在这个背景下,有多种类型的欺诈行为,例如:
老年人虐待
好撒马利亚人
恋爱诈骗
消费者诈骗
账户加热
身份盗窃
账户接管(ATO)
洗钱
如果你有兴趣了解每种具体的欺诈类型,以下是一些有用的链接:六种支付欺诈类型,汇款诈骗
项目步骤
机器学习/人工智能项目通常以迭代方式进行。但根据我的经验,以下 9 个步骤是项目的良好起点。
1. 理解现有系统
现有系统包括人员、流程和系统。
人员:识别在欺诈管理中具有领域专业知识的关键人物。确定他们的角色以及他们如何为项目做出贡献。例如,专家欺诈分析师可以通过定义欺诈因素和识别趋势做出重要贡献。
流程:分析公司当前如何识别欺诈行为,以及如何衡量其效果。
系统:评估当前用于检测欺诈的系统。许多公司可能已经有一个基于规则的专家系统。
2. 确定利益相关者的目标
理解利益相关者的不同目标至关重要,以便从一开始就使它们保持一致并明确期望。例如,从合规团队的角度来看,高欺诈检测率是理想的,而市场营销团队可能更关心虚假警报对客户体验的影响。同时,运营团队可能需要为预测时间制定特定的服务级别协议(SLA),以确保操作的顺利进行。在项目的一个阶段优化所有这些可能冲突的目标是低效的。因此,领导支持对于设定优先级和找到共同点至关重要。
3- 数据理解
你一定听过那句著名的谚语:“垃圾进,垃圾出。”为了避免将质量较差的数据输入到机器学习模型中,我们需要分析数据源及其质量,确保它们符合实验要求和在线流媒体标准。识别现有数据中的约束,并阐明这些约束对预测质量的影响。这一步对于维护模型输出的完整性和准确性至关重要。
4- 红旗信号定义
机器学习模型的构建模块是特征。在欺诈预测的背景下,这些特征主要代表欺诈行为或红旗信号。在这一阶段,我们提取欺诈专家的隐性知识,并将其转化为一系列红旗信号,然后将其发展为特征,输入到模型中。
例如,红旗信号可能包括:客户向高风险国家发送的交易次数、大量不同客户在短时间内向同一人汇款等。
5- 特征创建/工程
在这一阶段,识别出的红旗信号被编码成特征。可以定义各种特征组,例如汇款特征、交易模式和用户行为指标。特征工程是提取最具信息量的特征、区分欺诈与非欺诈的关键步骤。这个过程包括选择、修改和创建新特征,以提高模型的准确性和预测能力。
6. 模型训练与测试
在这一步,目标是训练一个或多个机器学习模型,以合理的准确度预测欺诈。期望的准确度水平取决于业务需求以及相较于基线系统需要改进的程度(此时可参考第二步中定义的目标)。
7. 实时运营化
所有之前的步骤都是在离线批处理环境中进行的。一旦模型准备好,就必须部署到生产环境中,以便其预测能够实时服务下游系统(在我们的项目中,响应时间低于一秒)。MLOps 团队负责这一步骤,优化管道的运行时并确保与其他系统的无缝集成。
8. 实时监控
一旦模型的预测集成到实时系统并由运营团队使用,密切监控其性能变得至关重要。目标是确保实时性能与批处理环境中测试的预期结果一致。如果出现差异,必须识别并解决潜在问题。例如,监控应该包括跟踪模型处理的交易数量、预测为欺诈的交易数量以及这些交易的后续路径。此外,必须监控管道本身的性能,确保服务按预期运行。
9. 设置反馈循环流程
建立反馈循环过程对于持续评估模型的表现并相应地优化它至关重要。这个过程涉及将实际标签重新引入系统,并加入任何其他相关信息。例如,如果模型将某些交易标记为欺诈,追踪这些交易的调查情况及其调查结果非常重要。同样,质量保证团队提供的见解,包括可能导致误报的原因,也应重新引入系统,以增强反馈循环过程。这种迭代方法确保了欺诈检测模型的持续改进和优化。
在下一篇文章中,我们将看到参与该项目的不同角色。请告诉我你的经验如何?你的经验与我的有什么相似之处或不同之处?
从 Adaline 到多层神经网络
打好基础
·发表于 Towards Data Science ·阅读时长 23 分钟·2024 年 1 月 9 日
--

图片由 Konta Ferenc 提供,来源于 Unsplash
在前两篇文章中,我们看到如何基于 Rosenblatt 的 感知机 实现一个基本的分类器,以及如何通过使用自适应线性神经元算法(adaline)来改进这个分类器。这两篇文章涵盖了在尝试实现一个多层人工神经网络之前的基础知识。从 Adaline 过渡到深度学习是一个更大的飞跃,许多机器学习从业者会直接选择使用像 PyTorch 这样的开源库。使用这样的专业机器学习库,当然推荐用于开发生产中的模型,但并不一定适合用于学习多层神经网络的基本概念。本文将从零开始构建一个多层神经网络。我们将聚焦于一个多分类问题,而不是解决一个二分类问题。我们将在每一层后,包括输出层,使用 Sigmoid 激活函数。本质上,我们训练一个模型,对于每个输入,包含一组特征的向量,输出一个长度等于待预测类别数的向量。输出向量的每个元素都在 [0, 1] 范围内,并可以理解为每个类别的“概率”。
本文的目的是让读者熟悉用于描述神经网络的数学符号,理解各种矩阵(包含权重和偏差)的作用,并推导更新权重和偏差以最小化损失函数的公式。该实现允许使用任意数量的隐藏层,且每层的维度可以自定义。大多数教程假设一个固定的架构,但本文使用了精心挑选的数学符号,这些符号支持泛化。通过这种方式,我们还可以进行简单的数值实验,检验预测性能与隐藏层的数量和大小之间的关系。
与之前的文章一样,我使用了在线LaTeX 公式编辑器来编写公式的 LaTeX 代码,然后使用 Chrome 插件Maths Equations Anywhere将公式渲染成图片。所有的 LaTeX 代码将在文章末尾提供,方便你再次渲染。如果你需要渲染公式,正确使用符号是机器学习中的一个关键部分,且对于理解神经网络至关重要。必须仔细审查公式,并注意各种指标和矩阵乘法的规则。模型在纸面上的正确表述一旦完成,实现代码变得非常简单。
本文中使用的所有代码可以在附带的代码库中找到。本文涵盖以下主题:
∘ 什么是多层神经网络?
∘ 激活函数
∘ 损失函数
∘ 反向传播
∘ 实现
∘ 数据集
∘ 训练模型
∘ 超参数调优
∘ 结论
∘ 文章中使用的 LaTeX 公式代码
什么是多层神经网络?
本节介绍了一个通用的、前馈式的、全连接的多层神经网络架构。随着我们逐步讲解下面的图 1,这里有很多术语需要理解。
对于每一个预测,网络接受一个特征向量作为输入。

这也可以理解为一个形状为(1, n⁰)的矩阵。网络使用 L 层并产生一个向量作为输出。

这可以理解为一个形状为(1, nᴸ)的矩阵,其中 nᴸ是我们需要解决的多类分类问题中的类别数。该矩阵中的每一个浮点数都位于[0, 1]的范围内,最大元素的索引对应于预测的类别。上标中的(L)符号用于指代特定层次,这里指的是最后一层。
但我们如何生成这个预测呢?让我们集中在第一层的第一个元素(输入不算作一层)。

我们首先计算网络输入,这本质上是输入向量与一组权重的内积,并加上一个偏置项。第二个操作是应用激活函数σ(z),稍后我们将回到这个问题。目前需要记住的是,激活函数本质上是一个标量操作。
我们可以以相同的方式计算第一层的所有元素

从上述内容我们可以推断,我们引入了 n¹ x n⁰个权重和 n¹个偏置项,这些将在模型训练时进行拟合。这些计算也可以用矩阵形式表示

注意矩阵的形状。网络输出是两个形状分别为(1, n⁰)和(n⁰, n¹)的矩阵相乘的结果,得到形状为(1, n¹)的矩阵,再加上另一个形状为(1, n¹)的偏置项矩阵。注意我们引入了权重矩阵的转置。激活函数应用于该矩阵的每个元素,因此第 1 层的激活值也是形状为(1, n¹)的矩阵。

图 1:一个具有任意数量输入特征、输出类别和不同节点数隐藏层的通用多层神经网络(图像由作者提供)
上述内容可以很容易地推广到神经网络中的每一层。第 k 层接受 nᵏ⁻¹个值作为输入,并生成 nᵏ个激活值

第 k 层引入 nᵏ x nᵏ⁻¹个权重和 nᵏ个偏置项,这些将在模型训练时进行拟合。权重和偏置项的总数为

因此,如果我们假设输入向量有 784 个元素(低分辨率灰度图像的维度),一个隐藏层有 50 个节点,输出有 10 个类别,那么我们需要优化 78550 + 5110 = 39,760 个参数。如果增加隐藏层的数量和每层节点的数量,参数的数量会进一步增加。优化一个包含如此多参数的目标函数并非易事,这也是为什么从 adaline 引入到 80 年代中期我们才发现如何训练深度网络。
本节主要介绍所谓的前向传播(forward pass),即如何应用一系列矩阵乘法、矩阵加法和逐元素激活,将输入向量转换为输出向量。如果你仔细观察,我们假设输入是一个单一样本,表示为形状为 (1, n⁰) 的矩阵。即使我们将一个样本批次作为形状为 (N, n⁰) 的矩阵输入到网络中,这种表示法依然成立。唯一稍微复杂一些的是偏置项。如果我们关注第一层,我们将形状为 (N, n¹) 的矩阵与形状为 (1, n¹) 的偏置矩阵相加。为了使这一步骤正常工作,偏置矩阵的第一行会被复制多次,直到与批次中样本的数量相等。这是一个非常自然的操作,NumPy 在所谓的广播中会自动执行。当我们对一批输入应用前向传播时,可能更清晰的做法是将所有变成矩阵的向量用大写字母表示,即

请注意,我假设广播(broadcasting)已应用于偏置项,从而导致一个矩阵的行数与批次中样本的数量相同。
在深度神经网络中,通常使用批处理操作。我们可以看到,随着样本数量 N 的增加,我们需要更多的内存来存储各种矩阵并进行矩阵乘法运算。此外,仅使用部分训练集来更新权重意味着我们将在每次遍历训练集(epoch)时多次更新参数,从而加快收敛速度。还有一个额外的好处,可能不那么明显。网络使用的激活函数与 Adaline 中的激活函数不同,它们不是恒等函数。实际上,它们甚至不是线性的,这使得损失函数是非凸的。使用批处理引入的噪声被认为有助于逃脱浅层局部最小值。适当选择的学习率进一步帮助实现这一目标。
在继续之前,最后需要说明的是,“前馈”(feedforward)一词来源于每一层使用前一层的输出作为输入,而没有使用导致所谓递归神经网络的循环。
激活
使神经网络能够解决复杂问题需要引入某种形式的非线性。通过在每一层中使用激活函数来实现这一点。选择有很多种。本文将使用 Sigmoid(逻辑)激活函数,我们可以通过以下方式进行可视化:
产生

图 2:Sigmoid(逻辑)激活函数。图片由作者提供。
代码还包括了本文中将需要的所有导入库。
激活函数将任何浮动值映射到 0 到 1 的范围。实际上,Sigmoid 更适合用于二分类问题的最后一层激活。对于多类别问题,使用 softmax 将神经网络的输出归一化为预测输出类别的概率分布会更为合适。可以这样理解,softmax 强制要求激活后输出向量的所有项之和为 1,而 Sigmoid 则没有这个要求。另一种理解方式是,Sigmoid 本质上将对数几率(log odds)转换为一对多(OvA)概率。然而,我们将继续使用 Sigmoid 激活函数,以尽量保持与 Adaline 的一致,因为 Softmax 不是逐元素操作,这会在反向传播算法中引入一些复杂性。这个部分留给读者作为练习。
损失函数
用于 Adaline 的损失函数是均方误差。在实际应用中,多类别分类问题通常使用多类别交叉熵损失。为了尽量保持与 Adaline 的一致性,并便于对损失函数关于参数的梯度进行解析计算,我们将坚持使用均方误差损失函数。训练集中的每个样本都属于 nᴸ 类中的一种,因此损失函数可以表示为

其中,第一个求和是对所有样本进行求和,第二个求和是对类别进行求和。上述公式意味着,样本 i 的已知类别已被转换为独热编码,即一个形状为 (1, nᴸ) 的矩阵,除了对应于样本类别的元素为 1 外,其他元素都是 0。我们采用了另一种符号约定,在上标中使用 [j] 来表示样本 j。上述求和不需要使用训练集中的所有样本。实际上,它会以批量 N’ 样本的形式应用,其中 N’<<N。
反向传播
损失函数是一个标量,依赖于成千上万个参数,包括权重和偏差项。通常,这些参数通过随机数初始化,并通过迭代更新,以便通过损失函数相对于每个参数的梯度来最小化损失函数。在 Adaline 的情况下,梯度的解析推导是直接的。而对于多层神经网络,推导过程则更加复杂,但如果我们采用巧妙的策略,仍然是可以处理的。我们进入了反向传播的世界,但不用担心。反向传播实际上是链式求导法则从右到左的连续应用。
让我们回到损失函数。它依赖于最后一层的激活值,因此我们可以首先计算关于这些激活值的导数。

上述可以理解为导数矩阵的(j, i)元素,矩阵形状为(N, nᴸ),可以写成矩阵形式为

其中右侧的两个矩阵形状都是(N, nᴸ)。最后一层的激活值是通过对最后一层净输入矩阵的每个元素应用 sigmoid 激活函数计算得到的。因此,为了计算损失函数对该净输入矩阵每个元素的导数,我们只需回忆如何计算嵌套函数的导数,外层函数为 sigmoid 函数:

星号乘法表示逐元素相乘。该公式的结果是一个形状为(N, nᴸ)的矩阵。如果你在计算 sigmoid 函数的导数时遇到困难,请查看这里。
现在我们已经准备好计算损失函数对 L-1 层权重的导数;这是我们从右到左移动时遇到的第一组权重

这导致了一个与 L-1 层的权重形状相同的矩阵。接下来,我们需要计算 L 层的净输入对 L-1 层权重的导数。如果我们选取最后一层净输入矩阵的一个元素和其中一个权重,我们有

如果你难以理解上述内容,可以想象对于每个样本 j,L 层净输入的 i 元素仅依赖于 L-1 层的权重,其中第一个索引也是 i。因此,我们可以消除导数中的一个求和项。

我们可以使用矩阵表示法来表达所有这些导数,使用

本质上,矩阵乘法中的隐式求和吸收了对样本的求和。跟随乘法矩阵的形状,你会看到结果导数矩阵的形状与用于计算 L 层净输入的权重矩阵的形状相同。尽管结果矩阵中的元素个数仅限于最后两层节点数的乘积(形状为(nᴸ, nᴸ⁻¹)),但乘法矩阵要大得多,因此通常会消耗更多的内存。因此,在训练模型时需要使用批处理。
损失函数对用于计算最后一层净输入的偏置项的导数可以像对权重的导数一样计算,得到

这导致一个形状为(1, nᴸ)的矩阵。
我们刚刚计算了所有关于用于计算最后一层净输入的权重和偏置项的损失函数导数。现在,我们将注意力转向前一层的权重和偏置项的梯度(这些参数将带有上标索引 L-2)。希望我们能开始识别一些模式,以便将它们应用于计算关于 k=0,..,L-2 的权重和偏置项的导数。如果我们计算损失函数关于 L-1 层激活值的导数,就能看到这些模式的出现。这些导数应形成一个形状为(N, nᴸ⁻¹)的矩阵,计算公式如下:

一旦我们得到了损失关于 L-1 层激活值的导数,就可以继续计算损失函数关于 L-1 层净输入的导数,然后再计算损失函数关于 L-2 层权重和偏置项的导数。
让我们回顾一下如何通过一层进行反向传播。我们假设已经计算了损失函数关于权重和偏置项(索引为 k)的导数,现在需要计算损失函数关于权重和偏置项(索引为 k-1)的导数。我们需要执行四个操作:

所有操作都已向量化。我们已经可以开始设想如何在一个类中实现这些操作了。我理解的是,当使用专门的库添加一个完全连接的线性层并带有激活函数时,这背后发生的就是这些操作!不必担心数学符号非常方便,但我的建议是至少一次通过这些推导。
实现
在这一节中,我们提供了一个通用的前馈多层神经网络的实现。这个 API 与专门深度学习库(如 PyTorch)中的 API 有些相似。
代码包含两个实用函数:sigmoid()将 Sigmoid(逻辑)激活函数应用于一个浮动值(或 NumPy 数组),而int_to_onehot()接受一个包含每个样本类别的整数列表,并返回它们的独热编码表示。
类MultilayerNeuralNetClassifier包含神经网络的实现。初始化构造函数为每一层的权重和偏置项分配随机数。举个例子,如果我们构建一个神经网络,layers=[784, 50, 10],则我们将使用 784 个输入特征,一个包含 50 个节点的隐藏层,以及 10 个类别作为输出。这个通用的实现允许改变隐藏层的层数和每一层的节点数。稍后我们在进行超参数调整时将利用这一点。为了保证可重复性,我们使用一个随机数生成器的种子来初始化权重。
forward 方法返回每一层的激活值,作为一个矩阵列表。该方法既可以处理单个样本,也可以处理样本数组。返回的最后一个矩阵包含每个样本的类别归属的模型预测结果。一旦模型训练完成,仅使用这个矩阵来进行预测。然而,在模型训练过程中,我们需要获取所有层的激活值,正如下面所述,这也是 forward 方法返回所有激活值的原因。假设网络初始化时使用了 layers=[784, 50, 10],forward 方法将返回一个包含两个矩阵的列表,第一个矩阵的形状为 (N, 50),第二个矩阵的形状为 (N, 10),假设输入 x 有 N 个样本,即它是一个形状为 (N, 784) 的矩阵。
backward 方法实现了反向传播,即按照上一节中描述的计算损失函数的所有解析导数。最后一层是特殊的,因为我们需要使用已知类别来计算损失函数相对于模型输出的导数。第一层是特殊的,因为我们需要使用输入值,而不是前一层的激活值。中间层则是相同的,我们只是简单地反向遍历各层。代码完全反映了解析推导的公式。通过使用 NumPy,我们将所有操作向量化,从而加速了执行过程。该方法返回一个包含两个列表的元组。第一个列表包含与每一层权重相关的损失函数导数的矩阵。假设网络初始化时使用了 layers=[784, 50, 10],该列表将包含形状为 (784, 50) 和 (50, 10) 的两个矩阵。第二个列表包含与每一层偏置项相关的损失函数导数的向量。假设网络初始化时使用了 layers=[784, 50, 10],该列表将包含形状为 (50, ) 和 (10, ) 的两个向量。
回顾我从这篇文章中学到的内容,我觉得实现过程相当直接。最难的部分是想出一个稳健的数学符号并在纸上推导梯度。不过,即使优化似乎已收敛,仍然很容易犯错误,这些错误可能不容易被察觉。这让我想到了特殊的backward_numerical方法。该方法既不用于训练模型,也不用于做出预测。它使用有限差分(中心差分)来估算损失函数相对于所选层的权重和偏置项的导数。通过将数值导数与backward函数返回的解析计算结果进行比较,可以确保实现是正确的。这个方法在训练模型时会太慢,因为它每个导数都需要进行两次前向传播,而在我们这个简单的例子中,layers=[784, 50, 10]会有 39,760 个这样的导数!但是,它真的是救命稻草。就个人而言,如果没有它,我是无法调试代码的。如果你想从这篇文章中记住一个关键点,那就是数值微分在检查你的解析梯度时是多么有用。我们可以使用未训练的模型来检查梯度的正确性。
生成的
layer 3: 300 out of 300 weight gradients are numerically equal
layer 3:10 out of 10 bias term gradients are numerically equal
layer 2: 1200 out of 1200 weight gradients are numerically equal
layer 2:30 out of 30 bias term gradients are numerically equal
layer 1: 2000 out of 2000 weight gradients are numerically equal
layer 1:40 out of 40 bias term gradients are numerically equal
梯度看起来是有序的!
数据集
我们将需要一个数据集来构建我们的第一个模型。一个在模式识别实验中常用的著名数据集是MNIST 手写数字。我们可以在 OpenML 数据集仓库中找到有关该数据集的更多细节。OpenML 中的所有数据集都受 CC BY 4.0 许可证约束,该许可证允许在任何媒介和任何目的下复制、再分发和转化材料。
该数据集包含 70,000 张数字图像及其对应的标签。方便的是,这些数字已经通过计算像素的质心进行了大小归一化,并且将图像平移,使得这个质心点位于 28x28 图像的中心。该数据集可以通过scikit-learn方便地获取。
打印出来
original X: X.shape=(70000, 784), X.dtype=dtype('int64'), X.min()=0, X.max()=255
original y: y.shape=(70000,), y.dtype=dtype('O')
processed X: X.shape=(70000, 784), X.dtype=dtype('float64'), X.min()=-1.0, X.max()=1.0
processed y: y.shape=(70000,), y.dtype=dtype('int32')
class counts: 0:6903, 1:7877, 2:6990, 3:7141, 4:6824, 5:6313, 6:6876, 7:7293, 8:6825, 9:6958
我们可以看到,每个图像都可以表示为一个包含 784 个整数的向量,范围从 0 到 255,这些整数已转换为[-0.5, 0.5]之间的浮动值。这或许与 scikit-learn 中通常的特征缩放方式略有不同,后者是按特征而非样本进行缩放的。类别标签以字符串形式提取并转换为整数。该数据集在类别之间相对平衡。
接下来,我们将为每个数字可视化十个图像,以获得对手写变体的感觉。
生成的

随机选择的每个数字样本。图片来源:作者。
我们可以预见到模型可能会对某些数字产生混淆,例如最后的 9 与 8 相似。也可能会出现一些手写变体未被很好预测的情况,例如 7 的数字写在中间有一条横线,这取决于这种变体在训练集中出现的频率。现在我们已经有了一个神经网络实现和一个数据集,接下来我们将提供训练模型所需的代码,然后再讨论超参数调优。
训练模型
我们需要采取的第一步是将数据集拆分为训练集和外部(留出)测试集。我们可以通过 scikit-learn 轻松完成这一步。
我们使用分层抽样,以确保每个类别在训练集和外部(留出)数据集中的比例大致相等。外部(留出)测试集包含 10,000 个样本,仅用于评估模型性能。在本节中,我们将使用 60,000 个样本作为训练集,且不进行任何超参数调优。
在计算损失函数相对于模型参数的梯度时,我们会发现有必要进行多次矩阵乘法,其中一些矩阵的行数等于样本数量。鉴于样本数量通常非常大,我们将需要大量的内存。为了缓解这个问题,我们将使用小批量处理,就像我们在使用梯度下降优化 Adaline 模型时使用小批量处理一样。通常,每个批次可以包含 100 到 500 个样本。减小批次大小可以提高收敛速度,因为在同一轮训练集(epoch)中,我们会进行更多的参数更新,但也会增加噪声。因此,我们需要找到一个平衡点。首先,我们提供一个生成器,接受训练集和批次大小,并返回批次。
该生成器返回大小相等的批次,默认每个批次包含 100 个样本。样本总数可能不是批次大小的整数倍,因此在每次通过训练集时,某些样本不会被返回。被跳过的样本数量小于批次大小,而且每次使用生成器时被遗漏的样本集合会发生变化,前提是我们不重置随机数生成器。因此,这并不关键。由于我们将在不同的训练轮次中多次通过训练集,最终会完整使用训练集。使用固定大小的批次的原因是,我们会在每个批次后更新模型参数,而较小的批次可能会增加噪声并阻碍收敛,特别是当批次中的样本恰好是异常值时。
当模型初始化时,我们预期会得到一个较低的准确率,这一点我们可以通过以下方式确认:
这产生了大约 9.5%的准确率。对于一个合理平衡的数据集来说,这是预期的结果,因为有 10 个类别。现在我们有了监控每批次的损失和准确率的方法,这将在训练过程中加以利用。让我们写下最后一段代码,迭代轮次和小批量数据,更新模型参数,并监控训练集和外部(保留)测试集中的损失和准确率的变化。
使用这个函数,训练变成了一行代码
这会生成
epoch 0: loss_training=0.096 | accuracy_training=0.236 | loss_test=0.088 | accuracy_test=0.285
epoch 1: loss_training=0.086 | accuracy_training=0.333 | loss_test=0.085 | accuracy_test=0.367
epoch 2: loss_training=0.083 | accuracy_training=0.430 | loss_test=0.081 | accuracy_test=0.479
epoch 3: loss_training=0.078 | accuracy_training=0.532 | loss_test=0.075 | accuracy_test=0.568
epoch 4: loss_training=0.072 | accuracy_training=0.609 | loss_test=0.069 | accuracy_test=0.629
epoch 5: loss_training=0.066 | accuracy_training=0.657 | loss_test=0.063 | accuracy_test=0.673
epoch 6: loss_training=0.060 | accuracy_training=0.691 | loss_test=0.057 | accuracy_test=0.701
epoch 7: loss_training=0.055 | accuracy_training=0.717 | loss_test=0.052 | accuracy_test=0.725
epoch 8: loss_training=0.050 | accuracy_training=0.739 | loss_test=0.049 | accuracy_test=0.742
epoch 9: loss_training=0.047 | accuracy_training=0.759 | loss_test=0.045 | accuracy_test=0.765
我们可以看到,在 10 轮后,训练集的准确率大约达到了 76%,而外部(保留)测试集的准确率略高,表明模型没有发生过拟合。
训练集的损失持续减少,因此尚未达到收敛。该模型支持热启动,因此我们可以通过重复上述单行代码运行另外 10 轮训练。相反,我们将重新初始化模型,并将其运行 100 轮,同时将批量大小增加到 200。我们提供了执行此操作的完整代码。
我们首先将训练损失及其变化率绘制为与轮次(epoch)数量的函数
这会生成

训练损失及其变化率作为与轮次(epoch)数量的函数。图像由作者提供。
我们可以看到,模型已经合理地收敛,因为训练损失的变化率比训练开始时小了两个数量级。我不确定为什么我们在第 10 轮左右观察到收敛速度减慢;我只能推测优化器跳出了局部最小值。
我们还可以将训练集和测试集的准确率绘制为与轮次(epoch)数量的函数
这会生成

训练集和外部(保留)测试集的准确率作为轮次(epoch)数量的函数。图像由作者提供。
经过大约 50 轮训练后,训练集和外部(保留)测试集的准确率都达到了约 90%,这表明没有/几乎没有过拟合。我们刚刚训练了第一个自定义构建的具有一个隐藏层的多层神经网络!
超参数调优
在上一节中,我们选择了一个任意的网络架构并拟合了模型参数。在本节中,我们通过改变隐藏层的数量(从 1 到 3 层)、隐藏层中节点的数量(从 10 到 50,步长为 10)以及学习率(使用 0.1、0.2 和 0.3)进行基础的超参数调优。我们保持批量大小为每批 200 个样本。总的来说,我们尝试了 45 种参数组合。我们将使用 6 折交叉验证(非嵌套),这意味着每种参数组合训练 6 次模型,总共进行 270 次模型训练。在每一折中,我们将使用 50,000 个样本进行训练,并使用 10,000 个样本来测量准确率(在代码中称为验证)。为了提高收敛的可能性,我们将为每个模型拟合执行 250 个训练轮次。总的执行时间约为 12 小时,使用的是单个处理器(Intel Xeon Gold 3.5GHz)。这大致是我们在 CPU 上可以合理运行的速度。通过多进程可以提高训练速度。事实上,使用像 PyTorch 这样的专门深度学习库,并在Google Cola上的 T4 GPU 上运行,训练速度会快得多。
这段代码遍历所有超参数值和折叠,并在pandas数据框中存储训练集(50,000 个样本)和验证集(10,000 个样本)的损失和准确性。该数据框用于找到最优的超参数。
这会产生
optimal parameters: n_hidden_layers=1, n_hidden_nodes=50, learning rate=0.3
best mean cross validation accuracy: 0.944
| n_hidden_layers | 10 | 20 | 30 | 40 | 50 |
|------------------:|---------:|---------:|---------:|---------:|--------:|
| 1 | 0.905217 | 0.927083 | 0.936883 | 0.939067 | 0.9441 |
| 2 | 0.8476 | 0.925567 | 0.933817 | 0.93725 | 0.9415 |
| 3 | 0.112533 | 0.305133 | 0.779133 | 0.912867 | 0.92285 |
我们可以看到,增加层数几乎没有带来太多好处。也许通过使用更大的第一隐藏层,我们可以稍微获得更好的性能,因为超参数调优达到了 50 个节点的限制。一些平均交叉验证的准确率非常低,可能表明收敛性较差(例如,使用 3 个隐藏层,每个层 10 个节点时)。我们没有进一步调查,但在得出网络结构的最佳结论之前,通常需要进行这类调查。我预计,如果允许更多的训练轮次,特别是在较大的网络中,准确率可能会进一步提高。
最后的步骤是使用除了外部(保留)集以外的所有样本重新训练模型,外部集仅用于最终评估。
最后的 5 个训练轮次是
epoch 245: loss_training=0.008 | accuracy_training=0.958 | loss_test=0.009 | accuracy_test=0.946
epoch 246: loss_training=0.008 | accuracy_training=0.958 | loss_test=0.009 | accuracy_test=0.947
epoch 247: loss_training=0.008 | accuracy_training=0.958 | loss_test=0.009 | accuracy_test=0.947
epoch 248: loss_training=0.008 | accuracy_training=0.958 | loss_test=0.009 | accuracy_test=0.946
epoch 249: loss_training=0.008 | accuracy_training=0.958 | loss_test=0.009 | accuracy_test=0.946
我们在外部(保留)测试集上达到了约 95%的准确率。如果考虑到我们从一张白纸开始,这真是太神奇了!
结论
本文展示了如何从零开始构建一个多层前馈全连接神经网络。该网络用于解决多类分类问题。实现已被通用化,可以支持任意数量的隐藏层和每层的节点数。这使得通过改变层数和节点数来调整超参数变得更加灵活。然而,我们需要牢记,随着神经网络深度的增加,损失梯度会变得越来越小。这被称为梯度消失问题,并且一旦网络深度超过某个阈值,就需要使用专门的训练算法,这超出了本文的讨论范围。
我们的原始多层神经网络实现希望具有一定的教育价值。然而,若要在实践中使用它,还需要进行一些改进。首先,需要通过采用某种形式的丢弃法来解决过拟合问题。其他改进,如添加跳跃连接和在训练过程中变动学习率,也可能有益。此外,网络架构本身也可以优化,例如使用卷积神经网络(CNN),它更适合用于图像分类。这类改进最好使用像PyTorch这样的专用库来实现。在从零开始开发算法时,需注意耗时以及如何划定边界,以便保持教育性而不会过于消耗时间。我希望本文在这方面达到了良好的平衡。如果你对这方面感兴趣,我推荐这本书以供进一步学习。
本文中使用的 LaTeX 方程代码
本文中使用的方程可以在下面的 gist 中找到,如果你想重新渲染它们。
从 AI 画布到 MLOps 堆栈画布:它们是否必要?
AI、ML 和 MLOps 堆栈画布,通过实例进行解释
·发布于 Towards Data Science ·14 分钟阅读·2024 年 10 月 24 日
--

图片来自 Kiki Siepel 在 Unsplash
嘿!我已经使用 AI 和 ML 画布处理个人和专业项目一段时间了,因为它有助于我保持项目和想法的有序。最近,我一直在考虑创建一个面向 MLOps 的画布,在进行一些相关的研究时,我发现了一个现成的画布,叫做“MLOps 堆栈画布”。在本文中,我们将一起探索如何使用 MLOps 堆栈画布。最后,我很想听听你对这个画布的看法;如果你以前使用过,或者打算在未来采用,请告诉我!
不是 Medium 会员?没关系!通过这个朋友链接继续阅读。
目录:
· 1. 引言
· 2. AI 画布
· 3. ML 画布
· 4. MLOps 堆栈画布
· 5. 那么是 AI、ML 还是 MLOps 画布?
· 6. 结论
我的 MLOps 教程:
教程 1:MLOps 的关键入门:探索其核心组件
教程 2:MLOps 工作流的初学者友好介绍
从假设到准确性:条件概率在现实世界预测中的作用
条件概率优于概率;前提是你有相关的信息
·发布于 Towards Data Science ·5 分钟阅读·2024 年 5 月 25 日
--

图片来源:Unsplash
引言
任何学过概率的人都听过这个长期以来确立的概率定义:“概率可以定义为有利结果的数量除以所有结果的总数。”我仍然能听到我四年级的老师一遍又一遍地重复这句话!
虽然这个定义是正确的,但我常常会想,这个定义在现实世界中的准确性如何?当我们对有利结果有更多信息时,它的准确性又如何?更明确地说,当我们对有利结果有更多的“条件”时,它的准确性如何。
引入“条件”就像是通过多种条件,将你原本的有利结果的数量划分成不同的切片,给你一个更真实代表有利结果数量的切片。下面的图片试图简要地描绘这个概念。

图片由作者创作
还有什么比代表许多在美国学习并寻找工作的国际数据科学学生心中所想的更好的方式呢!最初的可用工作数量显示在最左边。
第一个条件: 引入“工作经验时长”的第一个条件,细化了新加入者可获得工作的数量。
第二个条件: 此外,加入 第二个条件 “国籍/公民身份”的因素,进一步精细化了这一部分。
第 3 个条件 :最右侧图表中的小深蓝色部分代表了最准确的可用工作数量(即有利结果的数量)。
在深入探讨为什么条件概率可能比普通概率更有用之前,让我们快速回顾一下定义。
2. 概率和条件概率的定义
概率:
P(A) = A 的有利结果数量 / 总结果数量
条件概率:
现在考虑两个事件 A 和 B。条件概率的基础是“事件给定另一个事件”。在这种情况下,当我们说A 给定 B时,意味着事件 A 发生在事件 B 已经发生的情况下。所以就是将“B 的条件,附加到 A 上”。
P(A|B) = P(A 与 B 的交集) / P(B),其中
P(A 与 B 的交集)* 被定义为事件 A 和事件 B 同时发生的概率。
假设 (A 与 B 的交集) 和 (A’ 与 B 的交集) 是互斥的。因此 (A 与 B 的交集) 并集 (A’ 与 B 的交集) = B。
在定义这些稍显混淆的定义后,我将讲解为什么我认为条件概率实际上更好。
3. 示例 — 这篇文章的灵感来源
说实话,写这篇文章的灵感来自于前几天我看的一部宝莱坞电影,其中有一个场景,两个老朋友在讨论他们相遇的概率!
让我介绍一些关于这个场景的更多信息:
第 1 个朋友:警察,来自孟买市,前往卡林蓬(Kalimpong)——一个小镇,处理一起案件。
第 2 个朋友:数学教授,是卡林蓬镇的居民。
这两个朋友彼此认识,因为他们曾经在同一所大学学习。
- 目前,两个朋友在一家咖啡馆碰面,教授每天都会去那里。
介绍完这些信息后,让我们回到问题:他们在卡林蓬相遇的概率是多少。
警察:“兄弟,几率有多大!”
数学教授:“1/95,675。”
警察:“错了!你没有算上我。”
数学教授:“我知道了。目前人口是 95,674。”
嗯…让我们分解一下这个逻辑:
初步概率计算:
数学教授计算出遇到他朋友——警察的概率是 1/95,675。
这个假设前提是卡林蓬的 95,674 名居民与警察碰面的概率与遇到教授的概率相同。
为什么这个计算不准确:
- 这个计算假设遇到警察与遇到卡林蓬的任何其他居民是一样的!
引入条件概率:
让我们考虑一些特定的情景。
I. 相关信息:
警察是来自孟买的居民,他前往卡林蓬。
这位数学教授每天都会去这家咖啡馆。
那位警察刚好在这次去同一家咖啡馆。
II. 条件事件:
事件 A:教授和警察在卡林蓬见面。
事件 B:这位警察从孟买旅行到卡林蓬。
两位朋友见面的概率:

- 警察从孟买到卡林蓬旅行的概率,取决于以下因素:
他多久出差一次?
他有多频繁被分配去处理小镇的案件?
假设这个概率是 0.1%。
- 两位朋友见面的概率取决于以下因素:
他们两个人有多频繁去这家咖啡馆?
这家咖啡馆有多受欢迎?
教授定期去这家咖啡馆。
假设这个概率是 1%。
最终计算:

在警察已经在卡林蓬的情况下,两位朋友在卡林蓬见面的概率是 0.001%。
这是对这个概念的简化表示,但我想表达的是,要始终寻找更多相关的信息来细化你的概率。
结论
概率既简单又复杂!然而,随着我们获得更多的信息,我们总能做出更精确的推理。在现实世界中,始终尝试寻找额外的信息,看看它如何帮助你添加条件,从而使你的概率更为准确。
感谢阅读,希望这篇文章对你有所帮助!
从基础到进阶:探索 LangGraph
构建单智能体和多智能体工作流,包含人机协作交互
·发布于 Towards Data Science ·21 分钟阅读·2024 年 8 月 15 日
--

图片来源:DALL-E 3
LangChain 是构建大型语言模型驱动应用程序的领先框架之一。借助 LangChain 表达式语言(LCEL),定义和执行逐步行动序列——也就是所谓的链——变得更加简便。从更技术性的角度来看,LangChain 允许我们创建 DAG(有向无环图)。
随着 LLM 应用,特别是 LLM 智能体的发展,我们开始不仅将 LLM 用于执行任务,还用作推理引擎。这一转变引入了频繁涉及重复(循环)和复杂条件的交互。在这种情况下,LCEL 不足以应对,因此 LangChain 实现了一个新模块——LangGraph。
LangGraph(正如你从名字中可以猜到的)将所有交互建模为循环图。这些图使得开发具有多个循环和条件判断的高级工作流和交互成为可能,因此它成为了创建智能体和多智能体工作流的便捷工具。
在本文中,我将探索 LangGraph 的关键特性和功能,包括多智能体应用。我们将构建一个可以回答不同类型问题的系统,并深入了解如何实现人机协作的设置。
在上一篇文章中,我们尝试使用了 CrewAI,这是另一个流行的多代理系统框架。然而,LangGraph 采取了不同的方法。CrewAI 是一个高层框架,具有许多预定义的功能和现成的组件,而 LangGraph 在较低层次操作,提供广泛的定制和控制。
通过这个介绍,让我们深入了解 LangGraph 的基本概念。
LangGraph 基础
LangGraph 是 LangChain 生态系统的一部分,因此我们将继续使用一些熟悉的概念,比如提示模板、工具等。然而,LangGraph 引入了一些额外的概念。让我们来讨论一下它们。
LangGraph 被创建用于定义循环图。图由以下元素组成:
节点代表实际的操作,可以是 LLM、代理或函数。同时,一个特殊的 END 节点标志着执行的结束。
边连接节点并确定图的执行流程。有基本边,它们仅仅连接一个节点到另一个节点,还有条件边,它们包含 if 语句和额外的逻辑。
另一个重要概念是图的状态。状态作为图组件协作的基础元素。它表示图的快照,任何部分——无论是节点还是边——都可以在执行过程中访问和修改,以检索或更新信息。
此外,状态在持久化中也起着至关重要的作用。每一步之后,它会自动保存,允许你在任何时刻暂停并恢复执行。这个功能支持更复杂应用的开发,比如那些需要错误修正或包含人工干预的应用。
单代理工作流
从零开始构建代理
让我们从简单的案例开始,尝试使用 LangGraph 创建一个基础用例——一个带有工具的代理。
我将尝试构建类似我们在上一篇文章中使用 CrewAI 时的应用程序。然后,我们可以比较这两个框架。对于这个示例,让我们创建一个应用程序,能够根据数据库中的表自动生成文档。这可以节省我们在创建数据源文档时的很多时间。
和往常一样,我们将首先定义我们代理的工具。由于在这个示例中我将使用 ClickHouse 数据库,因此我定义了一个函数来执行任何查询。如果你愿意,也可以使用其他数据库,因为我们不会依赖任何数据库特定的功能。
CH_HOST = 'http://localhost:8123' # default address
import requests
def get_clickhouse_data(query, host = CH_HOST, connection_timeout = 1500):
r = requests.post(host, params = {'query': query},
timeout = connection_timeout)
if r.status_code == 200:
return r.text
else:
return 'Database returned the following error:\n' + r.text
使 LLM 工具可靠且减少错误至关重要。如果数据库返回错误,我会将反馈传递给 LLM,而不是抛出异常并停止执行。这样,LLM 代理就有机会修复错误并重新调用该函数。
我们定义了一个名为execute_sql的工具,它使得能够执行任何 SQL 查询。我们使用pydantic来指定工具的结构,确保 LLM 代理能够有效地使用该工具。
from langchain_core.tools import tool
from pydantic.v1 import BaseModel, Field
from typing import Optional
class SQLQuery(BaseModel):
query: str = Field(description="SQL query to execute")
@tool(args_schema = SQLQuery)
def execute_sql(query: str) -> str:
"""Returns the result of SQL query execution"""
return get_clickhouse_data(query)
我们可以打印出创建的工具的参数,看看传递给 LLM 的信息是什么。
print(f'''
name: {execute_sql.name}
description: {execute_sql.description}
arguments: {execute_sql.args}
''')
# name: execute_sql
# description: Returns the result of SQL query execution
# arguments: {'query': {'title': 'Query', 'description':
# 'SQL query to execute', 'type': 'string'}}
一切看起来都很好。我们已经设置好了必要的工具,现在可以继续定义 LLM 代理。如上所述,LangGraph 中代理的基石是其状态,它使得我们图中不同部分之间可以共享信息。
我们当前的例子相对简单。因此,我们只需要存储消息的历史记录。让我们定义代理状态。
# useful imports
from langgraph.graph import StateGraph, END
from typing import TypedDict, Annotated
import operator
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, ToolMessage
# defining agent state
class AgentState(TypedDict):
messages: Annotated[list[AnyMessage], operator.add]
我们在AgentState中定义了一个单一的参数——messages——它是AnyMessage类对象的列表。此外,我们用operator.add(累加器)对其进行了注解。这个注解确保每当一个节点返回一条消息时,它会被附加到状态中现有的消息列表里。如果没有这个操作符,每条新消息会替代掉之前的值,而不是被添加到列表中。
下一步是定义代理本身。我们从__init__函数开始。我们将为代理指定三个参数:模型、工具列表和系统提示。
class SQLAgent:
# initialising the object
def __init__(self, model, tools, system_prompt = ""):
self.system_prompt = system_prompt
# initialising graph with a state
graph = StateGraph(AgentState)
# adding nodes
graph.add_node("llm", self.call_llm)
graph.add_node("function", self.execute_function)
graph.add_conditional_edges(
"llm",
self.exists_function_calling,
{True: "function", False: END}
)
graph.add_edge("function", "llm")
# setting starting point
graph.set_entry_point("llm")
self.graph = graph.compile()
self.tools = {t.name: t for t in tools}
self.model = model.bind_tools(tools)
在初始化函数中,我们概述了图的结构,其中包括两个节点:llm和action。节点是实际的操作,因此我们有与之关联的函数。稍后我们将定义这些函数。
此外,我们还有一条条件边,决定是否需要执行函数或生成最终答案。对于这条边,我们需要指定前一个节点(在我们的例子中是llm)、一个决定下一步的函数,以及基于该函数输出的后续步骤映射(格式为字典)。如果exists_function_calling返回 True,我们会进入函数节点。否则,执行将在特殊的END节点处结束,标志着过程的结束。
我们在function和llm之间添加了一条边。这条边仅仅将这两个步骤连接起来,并且会在没有任何条件的情况下执行。
定义了主要结构后,接下来是创建上述列出的所有函数。第一个是call_llm。这个函数将执行 LLM 并返回结果。
代理状态将自动传递给函数,因此我们可以使用其中保存的系统提示和模型。
class SQLAgent:
<...>
def call_llm(self, state: AgentState):
messages = state['messages']
# adding system prompt if it's defined
if self.system_prompt:
messages = [SystemMessage(content=self.system_prompt)] + messages
# calling LLM
message = self.model.invoke(messages)
return {'messages': [message]}
结果是,我们的函数返回一个字典,这个字典将用于更新代理状态。由于我们使用了operator.add作为状态的累加器,返回的消息会被附加到存储在状态中的消息列表中。
我们需要的下一个函数是execute_function,它将运行我们的工具。如果 LLM 代理决定调用某个工具,我们将在message.tool_calls参数中看到它。
class SQLAgent:
<...>
def execute_function(self, state: AgentState):
tool_calls = state['messages'][-1].tool_calls
results = []
for tool in tool_calls:
# checking whether tool name is correct
if not t['name'] in self.tools:
# returning error to the agent
result = "Error: There's no such tool, please, try again"
else:
# getting result from the tool
result = self.tools[t['name']].invoke(t['args'])
results.append(
ToolMessage(
tool_call_id=t['id'],
name=t['name'],
content=str(result)
)
)
return {'messages': results}
在这个函数中,我们遍历 LLM 返回的工具调用,并且要么调用这些工具,要么返回错误信息。最终,我们的函数返回一个字典,字典中包含一个键messages,该键将用于更新图表状态。
只剩下一个函数——用于条件边的函数,它定义了我们是否需要执行工具还是提供最终结果。这非常简单。我们只需要检查最后一条消息是否包含任何工具调用。
class SQLAgent:
<...>
def exists_function_calling(self, state: AgentState):
result = state['messages'][-1]
return len(result.tool_calls) > 0
现在是时候为它创建一个代理和 LLM 模型了。我将使用新的 OpenAI GPT 4o mini 模型(文档),因为它比 GPT 3.5 更便宜且性能更好。
import os
# setting up credentioals
os.environ["OPENAI_MODEL_NAME"]='gpt-4o-mini'
os.environ["OPENAI_API_KEY"] = '<your_api_key>'
# system prompt
prompt = '''You are a senior expert in SQL and data analysis.
So, you can help the team to gather needed data to power their decisions.
You are very accurate and take into account all the nuances in data.
Your goal is to provide the detailed documentation for the table in database
that will help users.'''
model = ChatOpenAI(model="gpt-4o-mini")
doc_agent = SQLAgent(model, [execute_sql], system=prompt)
LangGraph 提供了一个非常方便的功能来可视化图表。要使用它,你需要安装pygraphviz。
对于带有 M1/M2 芯片的 Mac 来说,这有点棘手,所以这里有个小窍门供你参考(来源):
! brew install graphviz
! python3 -m pip install -U --no-cache-dir \
--config-settings="--global-option=build_ext" \
--config-settings="--global-option=-I$(brew --prefix graphviz)/include/" \
--config-settings="--global-option=-L$(brew --prefix graphviz)/lib/" \
pygraphviz
在弄清楚安装过程后,这就是我们的图表。
from IPython.display import Image
Image(doc_agent.graph.get_graph().draw_png())

如你所见,我们的图表有环。用 LCEL 实现类似的功能会相当具有挑战性。
最后,到了执行我们的代理的时候了。我们需要将包含问题的初始消息集作为HumanMessage传递。
messages = [HumanMessage(content="What info do we have in ecommerce_db.users table?")]
result = doc_agent.graph.invoke({"messages": messages})
在result变量中,我们可以观察到执行过程中生成的所有消息。整个过程按预期进行:
代理决定调用带有查询
describe ecommerce.db_users的函数。然后,LLM 处理了工具提供的信息并给出了用户友好的回答。
result['messages']
# [
# HumanMessage(content='What info do we have in ecommerce_db.users table?'),
# AIMessage(content='', tool_calls=[{'name': 'execute_sql', 'args': {'query': 'DESCRIBE ecommerce_db.users;'}, 'id': 'call_qZbDU9Coa2tMjUARcX36h0ax', 'type': 'tool_call'}]),
# ToolMessage(content='user_id\tUInt64\t\t\t\t\t\ncountry\tString\t\t\t\t\t\nis_active\tUInt8\t\t\t\t\t\nage\tUInt64\t\t\t\t\t\n', name='execute_sql', tool_call_id='call_qZbDU9Coa2tMjUARcX36h0ax'),
# AIMessage(content='The `ecommerce_db.users` table contains the following columns: <...>')
# ]
这是最终结果,看起来相当不错。
print(result['messages'][-1].content)
# The `ecommerce_db.users` table contains the following columns:
# 1\. **user_id**: `UInt64` - A unique identifier for each user.
# 2\. **country**: `String` - The country where the user is located.
# 3\. **is_active**: `UInt8` - Indicates whether the user is active (1) or inactive (0).
# 4\. **age**: `UInt64` - The age of the user.
使用预构建的代理
我们已经学会了如何从零开始构建代理。然而,对于像这样的简单任务,我们可以利用 LangGraph 的内置功能。
我们可以使用预构建的 ReAct 代理来获得类似的结果:一个可以与工具协作的代理。
from langgraph.prebuilt import create_react_agent
prebuilt_doc_agent = create_react_agent(model, [execute_sql],
state_modifier = system_prompt)
这与我们之前构建的代理是一样的。我们稍后会试用它,但首先,我们需要理解两个其他重要的概念:持久性和流媒体。
持久性与流媒体
持久性指的是在不同的交互之间保持上下文的能力。这对于代理使用场景至关重要,当应用程序可以从用户那里获取额外的输入时。
LangGraph 会在每一步之后自动保存状态,允许你暂停或恢复执行。这一功能支持实现更复杂的业务逻辑,如错误恢复或人机交互。
添加持久性最简单的方法是使用内存中的 SQLite 数据库。
from langgraph.checkpoint.sqlite import SqliteSaver
memory = SqliteSaver.from_conn_string(":memory:")
对于现成的代理,我们可以在创建代理时将内存作为参数传递。
prebuilt_doc_agent = create_react_agent(model, [execute_sql],
checkpointer=memory)
如果你正在使用自定义代理,你需要在编译图表时将内存作为检查指针传递。
class SQLAgent:
def __init__(self, model, tools, system_prompt = ""):
<...>
self.graph = graph.compile(checkpointer=memory)
<...>
让我们执行代理并探索 LangGraph 的另一个特性:流式处理。通过流式处理,我们可以将每个执行步骤的结果作为流中的独立事件接收。这个功能对于需要同时处理多个对话(或线程)的生产环境应用至关重要。
LangGraph 不仅支持事件流式处理,还支持基于令牌的流式处理。我想到的基于令牌流式处理的唯一用例是逐字显示实时回答(类似于 ChatGPT 实现)。
让我们尝试使用流式处理并结合我们的新预构建代理。我还将使用 pretty_print 函数来格式化消息,使结果更易读。
# defining thread
thread = {"configurable": {"thread_id": "1"}}
messages = [HumanMessage(content="What info do we have in ecommerce_db.users table?")]
for event in prebuilt_doc_agent.stream({"messages": messages}, thread):
for v in event.values():
v['messages'][-1].pretty_print()
# ================================== Ai Message ==================================
# Tool Calls:
# execute_sql (call_YieWiChbFuOlxBg8G1jDJitR)
# Call ID: call_YieWiChbFuOlxBg8G1jDJitR
# Args:
# query: SELECT * FROM ecommerce_db.users LIMIT 1;
# ================================= Tool Message =================================
# Name: execute_sql
# 1000001 United Kingdom 0 70
#
# ================================== Ai Message ==================================
#
# The `ecommerce_db.users` table contains at least the following information for users:
#
# - **User ID** (e.g., `1000001`)
# - **Country** (e.g., `United Kingdom`)
# - **Some numerical value** (e.g., `0`)
# - **Another numerical value** (e.g., `70`)
#
# The specific meaning of the numerical values and additional columns
# is not clear from the single row retrieved. Would you like more details
# or a broader query?
有趣的是,代理并未提供足够好的结果。由于代理没有查阅表格的 schema,它在猜测所有列的含义时遇到了困难。我们可以通过在同一线程中使用后续问题来改善结果。
followup_messages = [HumanMessage(content="I would like to know the column names and types. Maybe you could look it up in database using describe.")]
for event in prebuilt_doc_agent.stream({"messages": followup_messages}, thread):
for v in event.values():
v['messages'][-1].pretty_print()
# ================================== Ai Message ==================================
# Tool Calls:
# execute_sql (call_sQKRWtG6aEB38rtOpZszxTVs)
# Call ID: call_sQKRWtG6aEB38rtOpZszxTVs
# Args:
# query: DESCRIBE ecommerce_db.users;
# ================================= Tool Message =================================
# Name: execute_sql
#
# user_id UInt64
# country String
# is_active UInt8
# age UInt64
#
# ================================== Ai Message ==================================
#
# The `ecommerce_db.users` table has the following columns along with their data types:
#
# | Column Name | Data Type |
# |-------------|-----------|
# | user_id | UInt64 |
# | country | String |
# | is_active | UInt8 |
# | age | UInt64 |
#
# If you need further information or assistance, feel free to ask!
这次,我们从代理那里得到了完整的答案。由于我们提供了相同的线程,代理能够从之前的讨论中获取上下文。这就是持久性工作的方式。
让我们尝试更改线程并提出相同的后续问题。
new_thread = {"configurable": {"thread_id": "42"}}
followup_messages = [HumanMessage(content="I would like to know the column names and types. Maybe you could look it up in database using describe.")]
for event in prebuilt_doc_agent.stream({"messages": followup_messages}, new_thread):
for v in event.values():
v['messages'][-1].pretty_print()
# ================================== Ai Message ==================================
# Tool Calls:
# execute_sql (call_LrmsOGzzusaLEZLP9hGTBGgo)
# Call ID: call_LrmsOGzzusaLEZLP9hGTBGgo
# Args:
# query: DESCRIBE your_table_name;
# ================================= Tool Message =================================
# Name: execute_sql
#
# Database returned the following error:
# Code: 60\. DB::Exception: Table default.your_table_name does not exist. (UNKNOWN_TABLE) (version 23.12.1.414 (official build))
#
# ================================== Ai Message ==================================
#
# It seems that the table `your_table_name` does not exist in the database.
# Could you please provide the actual name of the table you want to describe?
代理缺乏回答我们问题所需的上下文并不令人惊讶。线程的设计旨在隔离不同的对话,确保每个线程保持其独立的上下文。
在实际应用中,内存管理至关重要。对话可能会变得相当冗长,某些时候将整个历史记录每次都传递给 LLM 是不切实际的。因此,值得对消息进行修剪或过滤。我们不会在这里深入讨论细节,但你可以在LangGraph 文档中找到相关指南。压缩对话历史的另一种方法是使用摘要(示例)。
我们已经学会了如何使用 LangGraph 构建单一代理的系统。下一步是将多个代理组合到一个应用中。
多代理系统
作为多代理工作流的示例,我想构建一个可以处理来自不同领域问题的应用。我们将有一组专家代理,每个专家专注于不同类型的问题,以及一个路由代理,负责找到最合适的专家来处理每个查询。这样的应用有许多潜在的使用场景:从自动化客户支持到在内部聊天中回答同事的问题。
首先,我们需要创建代理的状态——这些信息将帮助代理们共同解决问题。我将使用以下字段:
question— 初始的客户请求;question_type— 定义哪个代理将处理请求的类别;answer— 对问题的建议回答;反馈— 一个供未来使用的字段,用于收集一些反馈。
class MultiAgentState(TypedDict):
question: str
question_type: str
answer: str
feedback: str
我没有使用任何 reducer,因此我们的状态只会存储每个字段的最新版本。
接下来,我们创建一个路由节点。它将是一个简单的 LLM 模型,用于定义问题的类别(数据库、LangChain 或一般问题)。
question_category_prompt = '''You are a senior specialist of analytical support. Your task is to classify the incoming questions.
Depending on your answer, question will be routed to the right team, so your task is crucial for our team.
There are 3 possible question types:
- DATABASE - questions related to our database (tables or fields)
- LANGCHAIN- questions related to LangGraph or LangChain libraries
- GENERAL - general questions
Return in the output only one word (DATABASE, LANGCHAIN or GENERAL).
'''
def router_node(state: MultiAgentState):
messages = [
SystemMessage(content=question_category_prompt),
HumanMessage(content=state['question'])
]
model = ChatOpenAI(model="gpt-4o-mini")
response = model.invoke(messages)
return {"question_type": response.content}
现在我们已经有了第一个节点——路由节点——让我们构建一个简单的图表来测试工作流程。
memory = SqliteSaver.from_conn_string(":memory:")
builder = StateGraph(MultiAgentState)
builder.add_node("router", router_node)
builder.set_entry_point("router")
builder.add_edge('router', END)
graph = builder.compile(checkpointer=memory)
让我们通过不同类型的问题来测试我们的工作流程,看看它在实际操作中的表现。这将帮助我们评估路由代理是否能将问题正确分配给相应的专家代理。
thread = {"configurable": {"thread_id": "1"}}
for s in graph.stream({
'question': "Does LangChain support Ollama?",
}, thread):
print(s)
# {'router': {'question_type': 'LANGCHAIN'}}
thread = {"configurable": {"thread_id": "2"}}
for s in graph.stream({
'question': "What info do we have in ecommerce_db.users table?",
}, thread):
print(s)
# {'router': {'question_type': 'DATABASE'}}
thread = {"configurable": {"thread_id": "3"}}
for s in graph.stream({
'question': "How are you?",
}, thread):
print(s)
# {'router': {'question_type': 'GENERAL'}}
运行得很好。我建议你逐步构建复杂的图表,并独立测试每个步骤。采用这种方法,你可以确保每个迭代都按预期工作,并能节省大量的调试时间。
接下来,我们创建专家代理节点。我们将使用带有 SQL 工具的 ReAct 代理,将其作为数据库代理。
# database expert
sql_expert_system_prompt = '''
You are an expert in SQL, so you can help the team
to gather needed data to power their decisions.
You are very accurate and take into account all the nuances in data.
You use SQL to get the data before answering the question.
'''
def sql_expert_node(state: MultiAgentState):
model = ChatOpenAI(model="gpt-4o-mini")
sql_agent = create_react_agent(model, [execute_sql],
state_modifier = sql_expert_system_prompt)
messages = [HumanMessage(content=state['question'])]
result = sql_agent.invoke({"messages": messages})
return {'answer': result['messages'][-1].content}
对于与 LangChain 相关的问题,我们将使用 ReAct 代理。为了使该代理能够回答关于库的问题,我们将为其配备一个搜索引擎工具。我选择了 Tavily 作为此目的,因为它提供了针对 LLM 应用优化的搜索结果。
如果你没有账户,你可以注册并免费使用 Tavily(每月最多 1K 次请求)。要开始使用,你需要在环境变量中指定 Tavily API 密钥。
# search expert
from langchain_community.tools.tavily_search import TavilySearchResults
os.environ["TAVILY_API_KEY"] = 'tvly-...'
tavily_tool = TavilySearchResults(max_results=5)
search_expert_system_prompt = '''
You are an expert in LangChain and other technologies.
Your goal is to answer questions based on results provided by search.
You don't add anything yourself and provide only information baked by other sources.
'''
def search_expert_node(state: MultiAgentState):
model = ChatOpenAI(model="gpt-4o-mini")
sql_agent = create_react_agent(model, [tavily_tool],
state_modifier = search_expert_system_prompt)
messages = [HumanMessage(content=state['question'])]
result = sql_agent.invoke({"messages": messages})
return {'answer': result['messages'][-1].content}
对于一般问题,我们将利用一个简单的 LLM 模型,而不使用特定工具。
# general model
general_prompt = '''You're a friendly assistant and your goal is to answer general questions.
Please, don't provide any unchecked information and just tell that you don't know if you don't have enough info.
'''
def general_assistant_node(state: MultiAgentState):
messages = [
SystemMessage(content=general_prompt),
HumanMessage(content=state['question'])
]
model = ChatOpenAI(model="gpt-4o-mini")
response = model.invoke(messages)
return {"answer": response.content}
最后缺少的部分是一个条件函数,用于路由。这将非常简单——我们只需要从路由节点定义的状态中传递问题类型。
def route_question(state: MultiAgentState):
return state['question_type']
现在,是时候创建我们的图表了。
builder = StateGraph(MultiAgentState)
builder.add_node("router", router_node)
builder.add_node('database_expert', sql_expert_node)
builder.add_node('langchain_expert', search_expert_node)
builder.add_node('general_assistant', general_assistant_node)
builder.add_conditional_edges(
"router",
route_question,
{'DATABASE': 'database_expert',
'LANGCHAIN': 'langchain_expert',
'GENERAL': 'general_assistant'}
)
builder.set_entry_point("router")
builder.add_edge('database_expert', END)
builder.add_edge('langchain_expert', END)
builder.add_edge('general_assistant', END)
graph = builder.compile(checkpointer=memory)

现在,我们可以通过几个问题来测试该设置,看看它的表现如何。
thread = {"configurable": {"thread_id": "2"}}
results = []
for s in graph.stream({
'question': "What info do we have in ecommerce_db.users table?",
}, thread):
print(s)
results.append(s)
print(results[-1]['database_expert']['answer'])
# The `ecommerce_db.users` table contains the following columns:
# 1\. **User ID**: A unique identifier for each user.
# 2\. **Country**: The country where the user is located.
# 3\. **Is Active**: A flag indicating whether the user is active (1 for active, 0 for inactive).
# 4\. **Age**: The age of the user.
# Here are some sample entries from the table:
#
# | User ID | Country | Is Active | Age |
# |---------|----------------|-----------|-----|
# | 1000001 | United Kingdom | 0 | 70 |
# | 1000002 | France | 1 | 87 |
# | 1000003 | France | 1 | 88 |
# | 1000004 | Germany | 1 | 25 |
# | 1000005 | Germany | 1 | 48 |
#
# This gives an overview of the user data available in the table.
做得好!它为与数据库相关的问题提供了相关结果。让我们试着问一些关于 LangChain 的问题。
thread = {"configurable": {"thread_id": "42"}}
results = []
for s in graph.stream({
'question': "Does LangChain support Ollama?",
}, thread):
print(s)
results.append(s)
print(results[-1]['langchain_expert']['answer'])
# Yes, LangChain supports Ollama. Ollama allows you to run open-source
# large language models, such as Llama 2, locally, and LangChain provides
# a flexible framework for integrating these models into applications.
# You can interact with models run by Ollama using LangChain, and there are
# specific wrappers and tools available for this integration.
#
# For more detailed information, you can visit the following resources:
# - [LangChain and Ollama Integration](https://js.langchain.com/v0.1/docs/integrations/llms/ollama/)
# - [ChatOllama Documentation](https://js.langchain.com/v0.2/docs/integrations/chat/ollama/)
# - [Medium Article on Ollama and LangChain](https://medium.com/@abonia/ollama-and-langchain-run-llms-locally-900931914a46)
太棒了!一切都运行良好,显然 Tavily 的搜索对 LLM 应用非常有效。
添加人类参与的交互
我们已经成功创建了一个能够回答问题的工具。然而,在许多情况下,让人类参与其中以批准提议的操作或提供额外的反馈是非常有益的。让我们添加一个步骤,在返回最终结果给用户之前收集人类的反馈。
最简单的方法是添加两个额外的节点:
一个
人类节点,用于收集反馈,一个
编辑器节点,用于根据反馈重新审视答案。
让我们创建这些节点:
人类节点: 这将是一个虚拟节点,它不会执行任何操作。
编辑器节点: 这将是一个 LLM 模型,用于接收所有相关信息(客户问题、草拟答案和提供的反馈)并修订最终答案。
def human_feedback_node(state: MultiAgentState):
pass
editor_prompt = '''You're an editor and your goal is to provide the final answer to the customer, taking into account the feedback.
You don't add any information on your own. You use friendly and professional tone.
In the output please provide the final answer to the customer without additional comments.
Here's all the information you need.
Question from customer:
----
{question}
----
Draft answer:
----
{answer}
----
Feedback:
----
{feedback}
----
'''
def editor_node(state: MultiAgentState):
messages = [
SystemMessage(content=editor_prompt.format(question = state['question'], answer = state['answer'], feedback = state['feedback']))
]
model = ChatOpenAI(model="gpt-4o-mini")
response = model.invoke(messages)
return {"answer": response.content}
让我们将这些节点添加到我们的图中。此外,我们需要在“人类”节点之前引入一个中断,以确保流程暂停以获取人类反馈。
builder = StateGraph(MultiAgentState)
builder.add_node("router", router_node)
builder.add_node('database_expert', sql_expert_node)
builder.add_node('langchain_expert', search_expert_node)
builder.add_node('general_assistant', general_assistant_node)
builder.add_node('human', human_feedback_node)
builder.add_node('editor', editor_node)
builder.add_conditional_edges(
"router",
route_question,
{'DATABASE': 'database_expert',
'LANGCHAIN': 'langchain_expert',
'GENERAL': 'general_assistant'}
)
builder.set_entry_point("router")
builder.add_edge('database_expert', 'human')
builder.add_edge('langchain_expert', 'human')
builder.add_edge('general_assistant', 'human')
builder.add_edge('human', 'editor')
builder.add_edge('editor', END)
graph = builder.compile(checkpointer=memory, interrupt_before = ['human'])

现在,当我们运行图时,执行将在“人类”节点之前停止。
thread = {"configurable": {"thread_id": "2"}}
for event in graph.stream({
'question': "What are the types of fields in ecommerce_db.users table?",
}, thread):
print(event)
# {'question_type': 'DATABASE', 'question': 'What are the types of fields in ecommerce_db.users table?'}
# {'router': {'question_type': 'DATABASE'}}
# {'database_expert': {'answer': 'The `ecommerce_db.users` table has the following fields:\n\n1\. **user_id**: UInt64\n2\. **country**: String\n3\. **is_active**: UInt8\n4\. **age**: UInt64'}}
让我们获取客户的反馈并更新状态。
user_input = input("Do I need to change anything in the answer?")
# Do I need to change anything in the answer?
# It looks wonderful. Could you only make it a bit friendlier please?
graph.update_state(thread, {"feedback": user_input}, as_node="human")
我们可以检查状态,以确认反馈已经被填充,并且序列中的下一个节点是editor。
print(graph.get_state(thread).values['feedback'])
# It looks wonderful. Could you only make it a bit friendlier please?
print(graph.get_state(thread).next)
# ('editor',)
我们可以继续执行。传递None作为输入将从暂停的位置恢复流程。
for event in graph.stream(None, thread, stream_mode="values"):
print(event)
print(event['answer'])
# Hello! The `ecommerce_db.users` table has the following fields:
# 1\. **user_id**: UInt64
# 2\. **country**: String
# 3\. **is_active**: UInt8
# 4\. **age**: UInt64
# Have a nice day!
编辑器考虑了我们的反馈,并在最终信息中添加了一些礼貌的词语。这是一个非常棒的结果!
我们可以通过为编辑器配备Human工具,以更具代理性的方式实现人类在环的互动。
让我们调整我们的编辑器。我稍微修改了提示,并向代理中添加了工具。
from langchain_community.tools import HumanInputRun
human_tool = HumanInputRun()
editor_agent_prompt = '''You're an editor and your goal is to provide the final answer to the customer, taking into the initial question.
If you need any clarifications or need feedback, please, use human. Always reach out to human to get the feedback before final answer.
You don't add any information on your own. You use friendly and professional tone.
In the output please provide the final answer to the customer without additional comments.
Here's all the information you need.
Question from customer:
----
{question}
----
Draft answer:
----
{answer}
----
'''
model = ChatOpenAI(model="gpt-4o-mini")
editor_agent = create_react_agent(model, [human_tool])
messages = [SystemMessage(content=editor_agent_prompt.format(question = state['question'], answer = state['answer']))]
editor_result = editor_agent.invoke({"messages": messages})
# Is the draft answer complete and accurate for the customer's question about the types of fields in the ecommerce_db.users table?
# Yes, but could you please make it friendlier.
print(editor_result['messages'][-1].content)
# The `ecommerce_db.users` table has the following fields:
# 1\. **user_id**: UInt64
# 2\. **country**: String
# 3\. **is_active**: UInt8
# 4\. **age**: UInt64
#
# If you have any more questions, feel free to ask!
所以,编辑向人类提问:“草稿中的回答是否完整且准确地回答了关于 ecommerce_db.users 表中字段类型的问题?”在收到反馈后,编辑修改了答案,使其更易于理解。
让我们更新我们的主图,合并新的代理,而不是使用两个独立的节点。采用这种方法后,我们不再需要中断。
def editor_agent_node(state: MultiAgentState):
model = ChatOpenAI(model="gpt-4o-mini")
editor_agent = create_react_agent(model, [human_tool])
messages = [SystemMessage(content=editor_agent_prompt.format(question = state['question'], answer = state['answer']))]
result = editor_agent.invoke({"messages": messages})
return {'answer': result['messages'][-1].content}
builder = StateGraph(MultiAgentState)
builder.add_node("router", router_node)
builder.add_node('database_expert', sql_expert_node)
builder.add_node('langchain_expert', search_expert_node)
builder.add_node('general_assistant', general_assistant_node)
builder.add_node('editor', editor_agent_node)
builder.add_conditional_edges(
"router",
route_question,
{'DATABASE': 'database_expert',
'LANGCHAIN': 'langchain_expert',
'GENERAL': 'general_assistant'}
)
builder.set_entry_point("router")
builder.add_edge('database_expert', 'editor')
builder.add_edge('langchain_expert', 'editor')
builder.add_edge('general_assistant', 'editor')
builder.add_edge('editor', END)
graph = builder.compile(checkpointer=memory)
thread = {"configurable": {"thread_id": "42"}}
results = []
for event in graph.stream({
'question': "What are the types of fields in ecommerce_db.users table?",
}, thread):
print(event)
results.append(event)
这个图的工作方式与之前的类似。我个人更喜欢这种方法,因为它利用了工具,使得解决方案更加灵活。例如,代理可以多次联系到人类,并根据需要调整问题。
就是这样。我们已经构建了一个多代理系统,可以回答来自不同领域的问题,并考虑到人类反馈。
你可以在GitHub上找到完整的代码。
总结
在本文中,我们探讨了 LangGraph 库及其在构建单一和多代理工作流中的应用。我们审视了其一系列功能,现在是总结其优缺点的时候了。同时,比较 LangGraph 和我在上一篇文章中讨论的 CrewAI 也会很有用。
总体而言,我认为 LangGraph 是一个相当强大的框架,用于构建复杂的 LLM 应用:
LangGraph 是一个低级框架,提供了广泛的自定义选项,使你能够构建出精确所需的内容。
由于 LangGraph 建立在 LangChain 之上,它与 LangChain 的生态系统无缝集成,使得利用现有工具和组件变得更加容易。
然而,LangGraph 还有一些可以改进的地方:
LangGraph 的灵活性带来了较高的入门门槛。虽然你可以在 15 至 30 分钟内理解 CrewAI 的概念,但要真正熟悉并跟上 LangGraph 的节奏则需要一些时间。
LangGraph 为你提供了更高的控制水平,但它缺少 CrewAI 一些很酷的预构建功能,比如协作或现成的检索增强生成(RAG)工具。
LangGraph 并不像 CrewAI 那样强制执行最佳实践(例如,角色扮演或护栏)。因此,它可能导致较差的结果。
我认为 CrewAI 是一个更适合新手和常见用例的框架,因为它帮助你快速获得良好的结果,并提供指导以防止错误。
如果你想构建一个高级应用并需要更多的控制,LangGraph 是更好的选择。请记住,你需要投入时间来学习 LangGraph,并且要对最终解决方案负责,因为该框架不会提供指导帮助你避免常见的错误。
非常感谢你阅读本文。我希望本文对你有所启发。如果你有任何后续问题或评论,请在评论区留言。
参考资料
本文灵感来源于 DeepLearning.AI 的“LangGraph 中的 AI 代理”短期课程。
从代码到洞察:数据分析师的软件工程最佳实践
每个数据分析师都应该知道的十大工程经验
·发表于 Towards Data Science ·16 分钟阅读·2024 年 6 月 6 日
--

图片由 DALL-E 3 生成
数据分析师的工作结合了不同领域的技能:
我们需要具备商业理解和领域知识,才能解决实际的商业问题,并考虑所有细节。
数学、统计学和基本的机器学习技能帮助我们进行严格的分析,并从数据中得出可靠的结论。
可视化技能和讲故事的能力使我们能够传达信息并影响产品。
最后但同样重要的是,计算机科学和软件工程的基础是提高效率的关键。
在大学时,我学到了很多关于计算机科学的知识。我尝试过至少十种编程语言(从低级汇编语言和 CUDA 到高级 Java 和 Scala)以及无数工具。我的第一个工作邀请是担任后端工程师。我决定不走这条路,但所有这些知识和原则对我的分析师职业生涯都非常有益。因此,我希望在这篇文章中与大家分享这些主要的原则。
代码不是为计算机而写的。它是为人类而写的
我从很多软件工程师那里听过这句口头禅。这在一本编程圣经《Clean Code》中得到了很好的解释。
确实,阅读与写作所花的时间比例远远超过 10 比 1。我们在编写新代码时,往往需要不断阅读旧代码。
在大多数情况下,工程师更倾向于选择容易理解的冗长代码,而不是简洁的惯用一行代码。
我必须坦白说,我有时会打破这个规则,写出超长的 pandas 一行代码。例如,让我们看看下面的代码。你知道这段代码在做什么吗?
# ad-hoc only code
df.groupby(['month', 'feature'])[['user_id']].nunique()\
.rename(columns = {'user_id': 'users'})\
.join(df.groupby(['month'])[['user_id']].nunique()\
.rename(columns = {'user_id': 'total_users'})).apply(
lambda x: 100*x['users']/x['total_users'], axis = 1)\
.reset_index().rename(columns = {0: 'users_share'})\
.pivot(index = 'month', columns = 'feature', values = 'users_share')
坦白说,我可能需要一些时间才能在一个月内赶上这段代码的进度。为了让这段代码更具可读性,我们可以将其拆分成几个步骤。
# maintainable code
monthly_features_df = df.groupby(['month', 'feature'])[['user_id']].nunique()\
.rename(columns = {'user_id': 'users'})
monthly_total_df = df.groupby(['month'])[['user_id']].nunique()\
.rename(columns = {'user_id': 'total_users'})
monthly_df = monthly_features_df.join(monthly_total_df).reset_index()
monthly_df['users_share'] = 100*monthly_df.users/monthly_df.total_users
monthly_df.pivot(index = 'month', columns = 'feature', values = 'users_share')
希望现在你能更容易地理解逻辑,并看到这段代码显示的是每个月使用各个功能的客户百分比。未来的我一定会非常高兴看到像这样的代码,并感激所有的努力。
自动化重复性任务
如果你有重复性很高的任务,我建议你考虑自动化。让我分享一些我自己经验中的例子,希望能对你有所帮助。
分析师自动化任务最常见的方式是创建仪表板,而不是每次手动计算数字。自助工具(可配置的仪表板,利益相关者可以更改过滤器并查看数据)可以节省大量时间,并让我们能够专注于更复杂且有影响力的研究。
如果仪表板不可用,还有其他自动化方式。我曾经每周制作报告,并通过电子邮件发送给利益相关者。过了一段时间,这变得非常繁琐,我开始考虑自动化。那时,我使用了一个基础工具——cron 在虚拟机上。我安排了一个 Python 脚本,计算最新的数字并发送电子邮件。
当你有一个脚本时,你只需要在 cron 文件中添加一行。例如,下面的这一行将在每周一上午 9:10 执行analytical_script.py。
10 9 * * 1 python analytical_script.py
Cron 是一种基础但仍然可持续的解决方案。其他可以用来调度脚本的工具有 Airflow、DBT 和 Jenkins。你可能知道 Jenkins 是一种工程师常用的 CI/CD(持续集成和持续交付)工具。这可能会让你感到惊讶,它也足够可定制,可以执行分析脚本。
如果你需要更大的灵活性,是时候考虑 Web 应用程序了。在我的第一个团队中,我们没有 A/B 测试工具,因此很长一段时间,分析师必须手动分析每次更新。最终,我们编写了一个Flask Web 应用程序,让工程师能够自助服务。现在,像Gradio或Streamlit这样的轻量级 Web 应用程序解决方案,可以在几天内学会。
你可以在我以前的一篇文章中找到 Gradio 的详细指南。
精通你的工具
你每天在工作中使用的工具在效率和最终结果中扮演着重要角色。所以值得掌握它们。
当然,你可以使用默认的文本编辑器编写代码,但大多数人使用 IDE(集成开发环境)。你将花费大量时间在这个应用程序上,因此值得评估你的选择。
你可以从JetBrains 2021 年调查中找到最流行的 Python IDE。

图表由作者制作,数据来自JetBrains 调查
我通常使用 Python 和 Jupyter Notebooks 进行日常工作。在我看来,最适合这种任务的 IDE 是 JupyterLab。然而,我现在正在尝试其他选择,以便能够使用 AI 助手。自动补全的好处——它能消除大量的样板代码——对我来说是无价的,因此我愿意承担切换成本。我鼓励你去研究不同的选项,看看哪个最适合你的工作。
另一个有用的技巧是快捷键。你可以通过快捷键比使用鼠标更快地完成任务,而且看起来也很酷。我建议从谷歌搜索你的 IDE 快捷键开始,因为你通常使用这个工具最多。根据我的经验,最有价值的命令是:在 Notebook 中创建一个新单元格、运行该单元格、删除它以及将单元格转换为 markdown。
如果你有其他经常使用的工具(比如 Google Sheets 或 Slack),你也可以学习它们的命令。
学习快捷键的主要窍门是“练习,练习,再练习”——你需要重复一百次才能开始自动化操作。甚至有一些插件可以促使你更多地使用快捷键(例如,这个插件来自 JetBrains)。
最后但同样重要的是 CLI(命令行界面)。刚开始它可能看起来令人生畏,但基本的 CLI 知识通常是值得的。我甚至用 CLI 来操作 GitHub,因为它让我清楚地了解发生了什么。
然而,有些情况下几乎无法避免使用 CLI,比如在远程服务器上工作。要自信地与服务器交互,你需要学习不到十个命令。这篇文章可以帮助你掌握 CLI 的基本知识。
管理你的环境
继续讨论工具,设置你的环境总是个好主意。我有一个Python 虚拟环境,用来日常工作并包含我通常使用的所有库。
创建一个新的虚拟环境就像在终端中输入几行代码一样简单(这是开始使用 CLI 的绝佳机会)。
# creating venv
python -m venv routine_venv
# activating venv
source routine_venv/bin/activate
# installing ALL packages you need
pip install pandas plotly
# starting Juputer Notebooks
jupyter notebook
你可以从这个环境启动 Jupyter,或者在你的 IDE 中使用它。
对于大型项目,最好有一个独立的环境。我通常只有在需要使用不常见的技术栈(比如 PyTorch 或其他新的 LLM 框架)或遇到库兼容性问题时才会这么做。
另一种节省环境资源的方式是使用Docker 容器。我用它来处理一些更像生产环境的任务,比如在服务器上运行的 Web 应用。
思考程序性能
说实话,分析师通常不需要过多考虑性能。当我在数据分析领域获得第一份工作时,我的领导分享了有关性能优化的实际方法(从那时起我一直在使用它)。当你考虑性能时,要考虑总时间与努力的关系。假设我有一个运行 4 小时的 MapReduce 脚本。我需要优化它吗?这取决于情况。
如果我只需要运行一两次,那么花费 1 小时来优化这个脚本以便在 1 小时内计算结果是没有太大意义的。
如果我计划每天运行它,那么花时间让它更高效,避免浪费计算资源(和金钱)是值得的。
由于我的大多数任务是一次性研究工作,在大多数情况下,我不需要优化代码。然而,遵循一些基本规则以避免等待数小时是值得的。小技巧可以带来巨大的效果。我们来讨论一个这样的例子。
从基础开始,性能的基石是大 O 表示法。简而言之,大 O 表示法展示了执行时间与处理元素数量之间的关系。所以,如果我的程序是 O(n),这意味着如果我将数据量增加 10 倍,执行时间大约会增加 10 倍。
在编写代码时,了解算法的复杂度和主要数据结构是值得的。例如,检查一个元素是否在列表中需要 O(n)的时间,但在集合中只需要 O(1)的时间。让我们看看这如何影响我们的代码。
我有两个数据框,分别包含 Q1 和 Q2 用户的交易记录,对于 Q1 数据框中的每一笔交易,我想了解该客户是否被保留。我们的数据框相对较小,大约有 30 万到 40 万行。

如你所见,性能差异非常大。
第一种方法是最差的,因为在每次迭代时(对于 Q1 数据集中的每一行),我们都要计算唯一的 user_ids 列表。然后,我们以 O(n)的复杂度在列表中查找元素。这个操作需要 13 分钟。
第二种方法,即先计算列表,稍微好一些,但仍然需要将近 6 分钟。
如果我们预先计算一个 user_ids 列表并将其转换为集合,我们将瞬间获得结果。
如你所见,凭借基本的知识,我们可以让代码的运行速度提高超过 10K 倍。这是一个游戏规则的改变者。
另一个普遍的建议是避免使用普通的 Python,而是优先使用性能更高的数据结构,如 pandas 或 numpy。这些库之所以更快,是因为它们对数组进行了向量化操作,而这些操作是用 C 实现的。通常,numpy 的性能会稍好一些,因为 pandas 是建立在 numpy 之上的,但由于一些附加功能,它的速度稍慢。
不要忘记 DRY 原则。
DRY 代表“不要重复自己”,这一原则不言自明。它推崇结构化的模块化代码,你可以轻松地重复使用。
如果你已经复制粘贴了一段代码三次,那是时候考虑代码结构,思考如何将这个逻辑封装起来。
标准的分析任务是数据整理,我们通常遵循过程化范式。因此,结构化代码最明显的方式是函数。然而,你也可以遵循面向对象编程并创建类。在 我之前的文章 中,我分享了面向对象方法在仿真中的应用示例。
模块化代码的好处在于更好的可读性、更快的开发和更容易的修改。例如,如果你想将可视化从折线图改为区域图,你可以在一个地方进行修改,然后重新运行代码。
如果你有一堆与某一特定领域相关的函数,你可以为其创建一个 Python 包,像使用其他 Python 库一样与这些函数进行交互。这里有 一个详细的指南 介绍如何做到这一点。
利用测试
在我看来,分析领域中另一个被低估的话题是测试。软件工程师通常会有关于测试覆盖率的关键绩效指标,这对分析师也可能有帮助。然而,在许多情况下,我们的测试更多是与数据相关,而非代码本身。
我从一位同事那里学到的技巧是添加数据新鲜度的测试。我们有多个季度和年度报告的脚本,这些报告我们运行得比较少。因此,他添加了一个检查,看看我们使用的表格中的最新行是否在报告期结束后(这显示了表格是否已更新)。在 Python 中,你可以使用 assert 语句来实现这一点。
assert last_record_time >= datetime.date(2023, 5, 31)
如果条件满足,那么什么也不会发生。否则,你将会得到一个 AssertionError。这是一个快速且简单的检查,能帮助你早期发现问题。
我更倾向于验证的另一项是总和统计。例如,如果你在切片、分割和转换数据,值得检查请求和指标的总数是否保持一致。一些常见的错误有:
因为连接操作而产生的重复项,
当你使用
pandas.groupby函数时,过滤掉的None值,因为内连接而过滤掉的维度。
此外,我总是检查数据中是否存在重复。如果你期望每一行代表一个用户,那么行数应该等于df.user_id.nunique()。如果不等,说明数据有问题,需要调查。
最棘手且最有帮助的测试是常识检查。让我们讨论一些可能的方法。
首先,我会检查结果是否在总体上合理。例如,如果一个月的留存率为 99%,或者我在欧洲获得了 10 亿客户,那代码很可能存在 BUG。
其次,我会寻找其他数据来源或该主题的先前研究,以验证我的结果是否可行。
如果你没有其他类似的研究(例如,你在估算在新市场推出产品后的潜在收入),我建议你将你的数字与其他现有细分市场的数据进行对比。例如,如果你在另一个市场推出产品后,对收入的增量效应为当前收入的 5 倍,我会说这有点过于乐观,值得重新审视假设。
我希望这种思维方式能够帮助你获得更可行的结果。
鼓励团队使用版本控制系统
工程师们即使是在独立工作的小项目中,也会使用版本控制系统。同时,我常常看到分析师使用 Google Sheets 来存储他们的查询。由于我是强烈支持并倡导将所有代码存储在仓库中的人,所以我不能错过与大家分享我的想法的机会。
为什么在我的数据职业生涯中使用了超过 10 年的仓库?以下是主要的好处:
可重复性。我们经常需要调整先前的研究(例如,增加一个维度或将研究范围缩小到特定的细分市场),或者仅仅重复早期的计算。如果你以结构化的方式存储所有代码,你可以快速再现你的先前工作。这通常可以节省大量时间。
透明性。将代码与研究结果关联,可以让你的同事深入了解方法论的每个细节,这带来更多的信任,自然也能帮助发现 BUG 或潜在的改进之处。
知识共享。如果你有一个易于浏览的目录(或者你将代码链接到任务跟踪器),那么你的同事能够轻松找到你的代码,而无需从头开始调查。
回滚。你是否曾经遇到过这样一种情况:你的代码昨天还能正常运行,但今天你更改了一些东西,现在完全崩溃了?在开始定期提交代码之前,我曾经遇到过很多次这种情况。版本控制系统允许你查看整个版本历史,并对比代码或回滚到之前的正常版本。
协作。如果你和他人一起合作编写代码,可以利用版本控制系统来跟踪和合并更改。
我希望你现在能看到它的潜在好处。让我简要分享一下我通常存储代码的设置:
我使用
git+GitHub作为版本控制系统,我是那个仍然使用命令行界面的恐龙(它给我一种掌控感的安慰),但你也可以使用 GitHub 应用程序或 IDE 的功能。我的工作大部分是研究(代码、数字、图表、评论等),所以我将 95% 的代码存储为 Jupyter Notebooks。
我会将我的代码与 Jira 工单关联起来。我通常在代码库中有一个
tasks文件夹,并将子文件夹命名为工单键(例如ANALYTICS-42)。然后,我会将与任务相关的所有文件放在这个子文件夹中。采用这种方式,我可以在几秒钟内找到与(几乎)任何任务相关的代码。
在 GitHub 中使用 Jupyter Notebooks 有许多细微差别,值得注意。
首先,考虑一下输出内容。当你将 Jupyter Notebook 提交到代码库时,你会保存输入单元格(你的代码或评论)和输出。因此,值得注意的是,你是否真的希望共享这些输出。它可能包含个人身份信息(PII)或其他敏感数据,我不建议将这些提交。另外,输出可能非常大且没有信息量,只会让你的代码库显得杂乱无章。当你保存一个超过 10 MB 的 Jupyter Notebook 并且包含一些随机数据输出时,所有同事在执行 git pull 命令时都会将这些数据下载到他们的电脑上。
输出中的图表可能特别有问题。我们都喜欢优秀的交互式 Plotly 图表。不幸的是,它们在 GitHub UI 上无法呈现,因此你的同事很可能无法看到它们。为了解决这个问题,你可以将 Plotly 的输出类型切换为 PNG 或 JPEG。
import plotly.io as pio
pio.renderers.default = "jpeg"
你可以在文档中找到关于 Plotly 渲染器的更多详细信息。
最后但同样重要的是,Jupyter Notebooks 的差异通常很棘手。你经常会想了解代码的两个版本之间的差异。然而,默认的 GitHub 视图不会提供太多有用的信息,因为由于 notebook 元数据的变化,杂乱无章的内容太多(如下例所示)。

实际上,GitHub 几乎已经解决了这个问题。一个丰富的差异功能处于预览阶段,可以让你的生活变得更轻松——你只需要在设置中开启它。

有了这个功能,我们可以轻松看到只是做了几个小的改动。我更改了默认渲染器和保留曲线的参数(所以图表也进行了更新)。
请求代码审查
工程师会对(几乎)所有代码变更进行同行评审。这个过程能让我们及早发现 bug,防止不良行为者或有效地共享团队中的知识。
当然,这不是万能的:审查员可能会错过 bug,或者恶意行为者可能会在流行的开源项目中引入漏洞。例如,有一个相当可怕的故事讲述了一个如何在广泛使用于流行 Linux 发行版的压缩工具中植入后门的事件。
然而,有证据表明,代码审查确实有帮助。麦康奈尔在他的经典书籍《代码大全》中分享了以下统计数据。
… 软件测试单独进行时效果有限 —— 单元测试的平均缺陷检测率只有 25%,功能测试为 35%,集成测试为 45%。相比之下,设计和代码检查的平均有效性分别为 55%和 60%。
尽管有这些好处,分析师们通常根本不进行代码审查。我能理解为什么这可能是一个挑战:
分析团队通常较小,花费有限的资源进行反复检查可能听起来不太合理。
相当多时候,分析师工作于不同领域,而你可能是唯一一个了解这个领域足够好来进行代码审查的人。
然而,我真的鼓励你进行代码审查,至少对于关键的内容,以减轻风险。以下是我会要求同事仔细检查我的代码和假设的情况:
当我在一个新领域使用数据时,总是一个好主意请专家审核所用的假设;
所有与客户沟通或干预相关的任务,因为此类数据中的错误可能会导致重大影响(例如,我们可能会向客户传递错误信息或错误地停用某些人);
高风险决策:如果你计划投入团队六个月的时间来进行项目开发,那么值得反复检查;
当结果出乎意料时:当我看到令人惊讶的结果时,第一个要测试的假设是检查代码中是否存在错误。
当然,这并不是一个详尽无遗的列表,但我希望你能理解我的推理,并用常识来判断何时需要寻求代码审查。
保持最新
这句著名的刘易斯·卡罗尔名言很好地代表了当前技术领域的状态。
… 你必须尽全力奔跑,才能保持在原地。如果你想到达别的地方,你必须至少跑得比现在快两倍。
我们的领域在不断发展:每天都有新的论文发布,库被更新,新工具层出不穷,等等。软件工程师、数据分析师、数据科学家等也是如此。
现在有这么多的信息来源,根本不难找到:
每周来自 Towards Data Science 和其他一些订阅的电子邮件,
在 LinkedIn 和 X(前 Twitter)上关注专家,
订阅我使用的工具和库的电子邮件更新,
参加本地聚会。
更具挑战性的是避免被大量信息淹没。我尽量一次专注于一件事,以避免过多分心。
总结
以上就是对于分析师有帮助的软件工程实践。让我在这里快速总结一下:
代码不是为计算机写的,它是为人类写的。
自动化重复任务。
精通你的工具。
管理你的环境。
考虑程序性能。
不要忘记 DRY 原则。
利用测试。
鼓励团队使用版本控制系统。
请求代码审查。
保持最新状态。
数据分析结合了来自不同领域的技能,因此我相信我们可以从学习软件工程师、产品经理、设计师等的最佳实践中获益匪浅。通过采用我们同事的行之有效的技巧,我们可以提高我们的工作效率和效果。我强烈鼓励你们也去探索这些相邻领域。
非常感谢你阅读这篇文章。我希望这篇文章对你有启发。如果你有任何后续问题或评论,请在评论区留言。
参考文献
除非另有说明,所有图片均由作者制作。
致谢
我无法错过表达对我的伙伴的衷心感谢的机会,他一直与我分享他的工程智慧,并审阅了我所有的文章。
从代码到论文:使用 GPT 模型和 Python 生成科学 LaTeX 文档
自动化科学代码文档化:一个基于 GPT 的概念验证(POC),用于简化工作流。
·发布于Towards Data Science ·阅读时间:11 分钟·2024 年 12 月 2 日
--

插图。由 ChatGPT 生成。
介绍
在处理科学论文时,通常需要将算法转化为科学公式,这些公式通常以 LaTeX 格式呈现。这个过程可能非常繁琐且耗时,尤其是在大型项目中,因为它需要在代码库和 LaTeX 文档之间频繁地来回操作。
在处理一个大型算法库时,我开始探索如何简化这一工作流程。我的动机来源于手动将复杂算法转化为 LaTeX 兼容的公式时效率低下。特别的挑战是确保多个文档之间的一致性,尤其是在那些需要频繁更新公式的项目中。这促使我探索如何通过自动化来简化重复性任务,同时提高准确性。
在本文接下来的部分,我将同时使用“算法”和“科学代码”这两个术语。本文中所有的图像(封面图除外)均由作者创作。
目标
我的目标是将科学代码转化为一个全面的文档,...
从数据知情到数据驱动决策:贸易空间探索简介
超越探索性数据分析,找到干草堆中的针
·发布于 Towards Data Science ·17 分钟阅读·2024 年 8 月 13 日
--

(图片由作者使用 Playground AI 创建)
数据与决策
现代世界充满了数据。由机器人和人类共同收集、抓取和存储的经验数据。由科学家和工程师创建和运行的模型和仿真生成的人工数据。甚至高管和领域专家的意见,也被记录下来以供后续使用,这些都可以视为数据。
目的何在?为什么我们花费如此多的时间和精力来收集数据?数据革命的号角已经吹响,数据驱动决策的理念应运而生:即我们可以利用这些数据做出更好的决策。对于企业来说,这可能意味着选择一组研发项目或市场推广活动,以最大化未来的收入。对于个人来说,这可能仅仅意味着提高对下一辆车、手机或电脑的满意度。
那么,数据科学家(以及分析师和工程师)如何实际利用他们的数据来支持决策呢?大多数数据到决策的流程始于探索性数据分析(EDA)——清理和表征数据集的过程,主要通过统计分析以及支持散布、分布、异常值和特征之间的相关性的图表。EDA 有许多优点,可以帮助深入理解数据集,并进而理解任何基于这些数据做出的潜在决策:
识别潜在的错误或有缺陷的数据,以及如何修正它们
确定数据集中可能存在的被低估或高估的子群体,以便进行数学调整或驱动额外的数据收集
建立对可能性(基于差异)和常见性(基于分布)的直觉
开始发展对不同特征之间潜在因果关系的理解(需要始终记住,相关性不等于因果性)
这是迈向决策的有用第一步!一个执行良好的 EDA 将产生一个可靠的数据集,并提供关于数据中趋势的见解,决策者可以利用这些见解来指导他们的行动决策。稍微概括一点,趋势见解关注的是数据集中项目取特定值的频率:例如“这些事情通常是 X”或“当这些事情是 X 时,其他事情通常是 Y”。
不幸的是,许多现实世界中的数据到决策流程在此处结束:数据科学家通过 EDA 生成一堆趋势见解,并将其“丢给”业务决策者。然后,决策者需要根据这些见解推断出他们(可能有很多种)的不同行动方案可能带来的后果。这说起来容易做起来难!这是一个在复杂性和规模上都充满挑战的任务,尤其是对非技术领域的利益相关者来说。

数据科学家经常不得不将趋势“丢给”做出商业决策的人,而对这些决策是如何做出的没有任何可见性——有时甚至不知道那些决策是什么!(图像由作者使用 Playground AI 创建)
如果我们想做出更好的决策,我们需要打破数据与决策之间的壁垒。如果我们能够收集或生成直接对应于决策者可选方案或行动路径的数据,就可以免去他们从趋势中推断的需求。根据决策类型,这通常是直接的:比如一位购房者手中有一份所在区域所有待售房屋的清单,或者一个工程公司拥有能够评估成千上万种潜在新组件设计的模型。
创建以决策为中心的数据集需要与经典的 EDA(探索性数据分析)略有不同的思维方式,但结果显著更容易解释,因此更有可能充分支持决策。我们的探索不仅仅停留在趋势上,而是需要解决大海捞针的问题,找到数据集中最佳的单一数据点,以便完成从头到尾的数据到决策流程。
欢迎来到交易空间探索的世界。
从数据驱动到数据支持
在我们深入探讨交易空间探索的细节之前,让我们用一个决策实例来为这个讨论奠定基础。购买一辆车是许多人都熟悉的一个决策,它是一个很好的例子,原因有很多:
这是一个高风险的决策,值得花力气去“做对”。汽车很贵,(理想情况下)使用寿命长,而且大多数人每天都需要用车!任何买过“柠檬车”的人都能告诉你,这是一种特别具有挑战性且令人沮丧的挫折。
人们在比较汽车时会关注多个因素:价格、可靠性、安全性、操控性等。这不是一个你只需要挑选马力最高的车就能满意的问题。
通常有很多选择。每个制造商的新车,来自汽车销售点和在线市场的二手车,甚至像摩托车这种与汽车相关的物品也可能是有效的解决方案。这可是需要筛选的大量潜在数据!
为了简化这个例子,我们假设我们只对购买二手车感兴趣。
现在,让我们思考一下对这个问题进行常规 EDA(探索性数据分析)的过程可能是什么样子。首先,我会获取一个大型数据集,通常由经验观察组成:这个来自 Craigslist 的二手车销售数据集 [1] 就很合适。像这样的平面数据文件,其中每一项/行对应着由一组共享特征/列描述的汽车列表,是公开可用数据集中最常见的格式。然后,我会从总结数据、发现问题开始,并清理数据,去除不完整的/离群的列表或定义不一致的列。一旦数据清理完毕,我会使用统计方法或图表分析数据集,找出不同变量之间的相关性。如果你想看到一个关于这个数据集的详细 EDA 示例,可以点击这里查看 [2]。

一个常见的 EDA 可视化图是散点图矩阵,展示数据集中关键参数的成对关系(图像由作者使用 Pandas/Matplotlib 创建)
现在再想想这个决策:我想买一辆二手车。EDA 帮到我了吗?对于 EDA 爱好者和专家们来说,好消息是:当然帮到了!现在,我拥有了与我的决策高度相关的趋势见解。价格与车型年份有显著的正相关关系,并且与里程表的公里数呈(负)相关!大多数可用的汽车都是 3 到 7 年车龄!通过对二手车市场有了更深入的了解,我在判断一辆车是否值得购买时会更加自信。
但是 EDA帮我找到了最适合我买的车吗?并没有!我不能直接购买数据集中列出的那些车,因为它们是历史列表。如果其中有车仍然在售,我也不知道哪些车在售,因为它们并没有特别标注。我没有关于实际能购买到的车的数据,因此我仍然需要自己去找这些车——而且,只有当我在搜索不同的数据集时,我发现的趋势才可能帮助我找到好车。

当关于过去数据的趋势被“抛到墙外”传递给那些关注当前/未来数据的决策者时,这些趋势更难以转化为有效的决策。(图像由作者创建)
这是所谓的数据与决策之间的壁垒的一个例子,这在实践中极为常见,因为绝大多数的数据集都由过去/历史数据构成,但我们的决策是面向当前/未来的。尽管 EDA 可以将大量历史数据集处理成一组有用的洞察,但这些洞察与实际决策之间存在脱节,因为它们只是通过类比来描述我的选择(即如果我愿意假设当前的二手车市场与过去的市场相似)。或许最好将这种方式做出的决策称为数据启发式决策而非数据驱动决策。一个真正的数据驱动决策应该基于一个描述实际决策的数据集——在这个案例中,就是一个由当前可用的汽车列表填充的数据集。
设置交易空间
交易空间探索,或者更具体地说,多属性交易空间探索(MATE),是一个数据驱动的决策分析框架。它最初于 2000 年在麻省理工学院(MIT)创建,经过数十年的完善和应用,至今仍在使用[3–7]。MATE 将价值导向的思维[8]引入大数据集的世界,明确目的就是通过使用这些数据所做的决策,增加创造的价值。
MATE 框架帮助决策者和数据科学家/分析师批判性地思考如何定义和构建决策问题,如何进行数据收集,最终如何探索数据以生成实际的、相关的洞察,并找到最佳解决方案。从高层次来看,MATE 被分为三个层次,分别对应这些步骤:定义、生成和探索。

MATE 的定义、生成、探索层将做出数据驱动决策所需的步骤划分为不同的任务,用于充分描述问题、收集必要的数据,然后可视化/分析结果。(图像由作者创建)
定义一个基本的 MATE 研究从几个核心概念开始:
利益相关者。 谁做出决策或受其影响?为了简化,我们假设我自己是购车决策的唯一利益相关者;然而,请记住,许多决策都有多个利益相关者,他们的需求和愿望可能截然不同,我们可以并且应该考虑到所有这些因素。
替代方案。 可能的解决方案是什么,即有哪些可用的选择?在这个例子中,我已经将自己限制为购买二手车。我的替代方案是任何在我住处合理范围内待售的二手车。重要的是,替代方案应该是唯一的:我可以通过制造商、型号和年份等基本变量来定义我的选择,但像车辆识别码(VIN)这样的唯一标识符也是必要的,以防出现多个相同车型的列表。
资源。 利益相关者如何获取并使用一个替代方案,即需要花费什么?每辆车都有一次性的购买价格。我也可以选择考虑稍后产生的所有权成本,如燃油和维修费用,但现在我们暂时忽略这些。
利益。 为什么我们要选择一个替代方案,即利益相关者用什么标准来评判一个替代方案的“好坏”?也许我关心的是一辆车能承载多少乘客(为了实用性)、它的发动机气缸数(为了娱乐性)、它的里程数(为了耐用性)和它的安全评级(为了…安全)。
这个简单的框架为我们提供了如何在生成步骤中收集数据的方向。为了正确捕捉这一决策,我需要收集我在定义步骤中识别出的所有替代方案变量、资源和利益。如果少了任何一点,我将无法全面了解价值——但我可以随时加入任何我认为有用的额外变量。

在收集数据之前完成定义层是有用的,可以确保收集工作充足,避免浪费时间在不必要的参数上。(图片由作者提供)
假设一下,Craigslist 汽车数据集确实包括一个列,指示哪些列表仍然可以购买,因此是我决策的真实替代方案。我是否完成了数据收集?没有——这个数据集包含了我的替代方案变量(制造商、型号、年份、VIN)和资源(价格),但缺少了两个利益:乘客数和安全评级。我将需要补充额外的数据,否则我无法准确评估自己喜欢每辆车的程度。这需要分析师做一些额外的工作来获取新的数据,并将其正确匹配到现有数据集中的新列。
幸运的是,替代方案变量可以充当“键”来交叉引用不同的数据集。例如,我需要为每个替代方案找到一个安全评级。安全评级通常会给出某一款车型/型号/年份的车,因此我可以选择:
查找有关安全评级的表格数据(由其他人编制),并通过在制造商/型号/年份列上连接表格,将其与我自己的数据合并。
我自己收集安全评级数据,并直接将其输入到我的表格中,例如通过搜索
www.nhtsa.gov/ratings,使用每个备选项的品牌/型号/年份。
我还可能想通过补充其他备选项来丰富 Craigslist 的数据:毕竟,并非所有二手车都在 Craigslist 上出售。MATE 的最佳实践是尽可能从一个较大的备选集开始,避免过早限制决策。通过访问附近汽车经销商的网站并搜索他们的二手库存,我可以将更多的汽车添加到我的数据集中作为额外的行。根据可用汽车的数量(以及我自己的动机),我甚至可以通过网页抓取器自动化这个过程,这通常是大规模数据收集的方式。但请记住:我仍然需要为数据集中每辆车提供至少备选变量、资源和效益的数据。大多数经销商的列表不会包括像安全评级这样的详细信息,因此我需要像之前一样使用其他数据源来补充这些信息。
到目前为止,我已经拥有了我的“稻草堆”数据集,快要开始探索层并寻找那个“针”了。但我该怎么做呢?MATE 到底与我们老朋友 EDA 有何不同?
什么构成了一个好的解决方案?
现在我已经有了一个包含实际决策所需备选项的完整数据集,我能否仅仅通过对其进行 EDA 来解决问题并找到最好的汽车?嗯……是的,也不是。你可以(并且应该!)对 MATE 数据集进行 EDA —— 清理潜在的错误或异常依然是相关的,尤其是在数据是通过自动化过程(如网页抓取器)收集时,它尤为重要。而构建对数据趋势的直觉理解的目标没有不同:我们对不同标准之间关系的理解越深,我们对最终决策的信心就越强。例如,我前面展示的散点矩阵也是 MATE 常见的可视化方式。
但即使拥有一个活跃的汽车列表数据集和所有必要的变量,EDA(探索性数据分析)中常见的相关性和分布分析仍然无法帮助提取出单个高价值的数据点。记住:我们关心的是汽车的许多不同属性(多属性交易空间探索中的多属性),所以我们不能仅仅按价格排序并选择最便宜的汽车。即使有了 EDA 趋势分析的帮助,我仍然需要手动检查很多潜在选项,直到找到一辆具有理想组合的特征、性能和价格的汽车。
我需要的是一个工具,将最佳的汽车从堆中挑选出来。这个工具就是:价值建模。
在最基本的层面上,价值模型是一个数学函数,旨在复制利益相关者的偏好。我们输入在定义层中确定的效益和/或资源,然后得到一个价值分数,表示每个备选方案的“好坏”。如果模型准确,我们的利益相关者会更喜欢一个分数较高的备选方案(汽车),而不是任何其他分数较低的方案。[9]

价值模型的参数是为了模拟利益相关者的偏好而创建的,这样,如果输入汽车的效益/资源指标,模型就会返回一个分数,可以用来自动地对其进行排名,和其他汽车进行比较。(图像由作者创建)
大多数数据科学家可能已经多次创建和使用过简单的价值模型(无论他们是否意识到,或者是否称之为不同的名称),作为完成这一任务的一种手段:在数据集中创建一个新的列,用来根据其他列的函数对行进行“评分”,以便对数据集进行排序,并突出显示高评分的行。价值模型有许多不同类型,每种类型都有其优缺点。更精确的价值模型通常更复杂,相应地需要更多的工作来创建。
在这个示例中,我们将使用一个简化的效用函数来合并我从买车中获得的四种利益[10]。虽然有一个正式的引导过程,可以与利益相关者(我)一起完成,以创建一个可验证正确的效用函数,但我们将通过为每个属性分配一个阈值 要求(最差可接受水平)、目标(最大价值水平,超过这个点没有额外价值)和权重值(重要性的度量)来快速构建一个。还有其他定制效用函数的方法,包括非线性曲线和互补/替代效应,但我们这次将跳过这些。

每个属性都有一个定义好的效用曲线(在这种情况下是线性的),介于需求和目标之间,并且有一个用于将单一属性效用合并成多属性效用的权重值。背景中的柱状图显示了数据集中该参数的分布。(图像由作者与 EpochShift 共同创建)
但是等一下:为什么我没有在效用函数中包括价格呢?技术性的回答是,大多数人在利益和资源之间展示的是“不完全排序”[11]——一种复杂的说法,意思是利益相关者通常无法明确说明他们更喜欢低成本低收益的选择,还是高成本高收益的选择,因为两者都没有严格优于对方。顺便说一下,这也是为什么通过优化一个函数来“解决”决策是非常困难的原因:在实践中,决策者通常喜欢看到一个集合的选择,从低成本低收益到高成本高收益,并自行判断,而这比将利益/成本合并成一个价值模型更可靠。这个集合叫做Pareto 集(或者在图形中称为Pareto 前沿),它是交易空间中最高价值的区域。

这是一个带有 Pareto 前沿高亮的交易空间散点图——最理想的选择通常位于前沿附近。请注意,左上角的收益率小于 100%,因为数据集中的并非所有汽车都符合我们的要求。(图片由作者使用 EpochShift 创建)
这里就是它:“交易空间”。这是一个以收益为 y 轴,成本为 x 轴的散点图(每个轴可能是由多个属性组成的价值模型)。收益与成本之间的权衡是迄今为止最常见的现实世界决策框架,而 MATE 框架的存在就是为了引导我们基于数据的决策分析,使之符合利益相关者和决策者熟悉的结构。每一个点都是我可以购买的汽车——是我可以做出的选择,是解决决策的一种方式,而不依赖于从过去趋势中外推。
剩下的就是探索层,我需要在那些散点堆中找到我最喜欢的“针”。
寻找针尖
让我们快速探索一下我的汽车数据集,使用EpochShift,这是一个用于创建和探索 MATE 数据集的程序(完全披露:我目前正在参与 EpochShift 的开发,但你可以花些时间并动手使用你选择的可视化库复制这些图表)。首先,因为它们很可能是我想购买的汽车,我将用火焰标记 Pareto 集中的汽车——这是一个自定义标记,出现在图表“顶部”,即使我改变图表维度,它也会保持显示。我还对里程表的行驶里程与交易空间中的价值维度之间的关系感兴趣,因此我将用这个参数为这些点着色。

这是交易空间,按里程表的行驶里程着色,并用品红色三角形标出 Pareto 集中的汽车。(图片由作者使用 EpochShift 创建)
我马上就看到两个问题:
里程表里程的模式有点难以辨认,因为 17,000 多辆车的交易空间非常密集,点遮挡彼此:一些点会覆盖其他点。我可以看出,点通常在沿着 y 轴向上移动时从深到浅变淡,但如果我能消除遮挡,我就能更清楚地看到不同里程层次在我的效益/资源维度上的分布。
我也无法看出帕累托集中车辆的里程表里程,因为它们是品红色的。如果我仍然可以突出显示这些车辆同时看到它们的里程表里程颜色,那将是理想的。
为了解决这些问题,我将以两种方式修改我的图表:
我将用凸包替换交易空间中的点 — 实质上将里程表里程范围划分为更小的块,并在每个块周围画一个“气泡”。
我将保留我的帕累托集火焰,但将其更新为填充对应里程表里程颜色的三角形。
通过这两个改变,我得到了这个:

将交易空间表示为里程表里程上的凸包,帕累托集中的个别车辆仍然由火焰标记着。(作者使用 EpochShift 创建的图像)
看那个!我可以看到里程表里程和效用之间的明显关系,这是有道理的,因为这是我在价值模型中使用的效益指标之一。此外,除了图表右下角的一对(妄想的)卖家外,很明显,里程较高的车辆的最高要价较低 — 但或许更有趣的是,里程似乎不明显影响低端要价。即使是低里程的车也可以便宜买到!
但让我们回到通过关注帕累托集来做决定。我没有筛选数据集以删除高成本车辆,因为,正如我之前提到的,被认为是 MATE 最佳实践在开始探索层之前不减少替代品数量。但实际上,我对这次购买有$10K 的预算,并且也许我想要在这个限制下找到最好的车 — 尤其是现在我知道我仍然能够在这个价格范围内找到低里程的车。我将添加我的预算要求并切换回散点图:

交易空间,按里程表里程着色,x 轴上的预算要求为$10K。请注意,由于预算,左上角的产出进一步降至约 36%。(作者使用 EpochShift 创建的图像)
好的,现在我们正在看一个价格实惠的车辆的良好放大。如果我只想在我的预算内购买效益最大的车辆,那将是帕累托集中最靠上和最靠右的点。我可以通过鼠标悬停来查看该车辆的详细信息:

将鼠标悬停在一个点上会显示详细信息的工具提示。(作者使用 EpochShift 创建的图像)
一辆行驶了 14,000 英里的 2006 年克莱斯勒 300,售价 9,000 美元。还不错!但等等… 它是白色的。我忘了我讨厌白色车!MATE 的探索层的一部分是细化利益相关者的偏好,这些偏好在接受新信息时经常发生变化:我的数据驱动决策的数据。使用交互式工具的一个优势是我可以轻松根据这些变化更新价值模型或过滤器。我只需添加一个去除白色车辆的过滤器,保存一个新的帕累托集,然后:

再次过滤 tradespace,去除白色车辆,并保存一个带有绿色方块标记的新帕累托集。(作者使用 EpochShift 创建的图像)
就是这样!现在我预算中最好的车是一辆银色的 2009 年大众 Touareg。与克莱斯勒相比,我失去了一个乘客(6 到 5),这并不理想,但由于它的里程表明显较低(14,000 到 158),这辆车的效用几乎一样高。它几乎是全新的,只要 3,500 美元!
我们找到了:大海捞针。我们可以用数据驱动的证据来证明我们的决策!

这是一见钟情 — 但我必须从其他 20,000 辆车中找到它!(照片由 OSX 拍摄,已释放至公共领域)
结论
本文的目标是展示 EDA 和 tradespace 探索在相似/互补之处,但强调在收集和可视化数据时寻找数据集中“最佳”点的最终目标时存在一些关键差异。 Tradespace 探索可以是从数据驱动到真正数据驱动的决策的“一步之外”。但这个简短的例子只是揭示了我们可以用 MATE 做什么的冰山一角。
MATE 框架支持许多其他类型的数据和分析。更复杂和/或重要的决策可能需要这些功能,以使您(或您的利益相关者)对已达成的最佳决策感到自信:
包括多个利益相关者,每个利益相关者都有自己的价值模型,并寻找对所有方面都可取和公平的解决方案
包括不确定性,导致多种情景(每种情景都有自己的 tradespace),并寻找对未知或不可控参数鲁棒的解决方案
添加动态和选项,允许情景随时间反复变化,同时我们灵活修改解决方案以最大化价值
许多新颖有趣的可视化图表来支持这些复杂问题的分析
如果这引起了您的兴趣,并且您想了解更多关于 tradespace 探索的信息,请在 Medium 上关注我或The Tradespace,我们正在努力教导各种技能水平的人如何最好地将 MATE 应用于他们的问题。或者如果您有任何问题或有趣的见解,请随时通过电子邮件/评论与我联系 — 我希望收到您的来信!
参考文献
[1] A. Reese,《二手车数据集》(2021),www.kaggle.com/datasets/austinreese/craigslist-carstrucks-data
[2] T. Shin,《探索性数据分析的全面逐步指南》(2020),《Towards Data Science》。
[3] A. Ross, N. Diller, D. Hastings, 和 J. Warmkessel,《航天系统设计中的多属性交易空间探索》(2002),IAF IAC-02-U.3.03,第 53 届国际宇航大会——世界航天大会,美国德克萨斯州休斯顿。
[4] A. Ross 和 D. Hastings,《交易空间探索范式》(2005),INCOSE 国际研讨会 2005,美国纽约州罗切斯特。
[5] M. Richards, A. Ross, N. Shah, 和 D. Hastings,《评估动态多属性交易空间探索中的生存能力的指标》(2009),《航天器与火箭杂志》,第 46 卷,第 5 期,2009 年 9 月-10 月。
[6] M. Fitzgerald 和 A. Ross,《多利益相关者交易空间探索框架建议》(2016),INCOSE 国际研讨会 2016,苏格兰爱丁堡,2016 年 7 月。
[7] C. Rehn, S. Pettersen, J. Garcia, P. Brett, S. Erikstad, B. Asbjornslett, A. Ross, 和 D. Rhodes,《工程系统可变性水平量化》(2019),《系统工程》,第 22 卷,第 1 期,页码 80–94,2019 年 1 月。
[8] R. Keeney,《价值导向思维:创造性决策的路径》(1992),哈佛大学出版社,剑桥。
[9] N. Ricci, M. Schaffner, A. Ross, D. Rhodes, 和 M. Fitzgerald,《通过互动可视化探索利益相关者价值模型》(2014),第 12 届系统工程研究大会,美国加利福尼亚州雷东多海滩,2014 年 3 月。
[10] R. Keeney 和 H. Raiffa,《具有多重目标的决策:偏好与价值权衡》(1993),剑桥大学出版社。
[11] J. Von Neumann 和 O. Morgenstern,《博弈论与经济行为理论》(1944),普林斯顿大学出版社。
从数据科学家到 AI 开发者:2023 年构建生成式 AI Web 应用的经验教训
为任何希望构建能够服务成千上万用户的 AI Web 应用的数据科学爱好者提供的技术技巧指南
·发布于Towards Data Science ·阅读时间:12 分钟·2024 年 1 月 21 日
--

来源:DALLE-3
如果你像我一样,进入了数据科学的领域(无论是通过大学课程还是通过无数在线课程),你可能也曾梦想过创造一个人们可以使用的机器学习/人工智能软件产品。一个像我们计算机科学的朋友们似乎轻松编写出来的产品。
但是,如果你曾尝试过全栈网页开发,你很快就会面对配置、部署、终端命令、服务器等看似不可逾越的难题。

这是我在编写应用程序的初期,和我的大学室友之间无数次沮丧对话中的一幕。图片来自作者。
我对此深有体会,曾经在无数个小时里感到无助,这只让我更加深信自己永远无法制作出一个能正常运行的软件应用。
但是恰好一年前,1 月 21 日,一个因为护照问题和取消的旅行而意外空出的周末,我开始了一个制作 AI 应用程序的旅程。这是一次带我去到意想不到地方的旅程——我与一个位于世界另一端的联合创始人合作,加入了一个旧金山初创公司加速器,并最终发展到拥有数千用户并创造了可观年收入的规模(查看我的应用程序,Podsmart! 我们总结播客)。

我的应用演示页面位于 Buildspace 启动加速器上。来源:buildspace.so/s3/demoday/Podsmart
但最重要的是,这是一段充满挫折、回溯、错误和返工的旅程。它是关于如何在没有正式计算机科学/软件工程背景的情况下,穿越困惑的开发世界。
所以,回顾过去一年来构建我的第一个软件产品,我整理了一些技术建议——这是为任何数据科学爱好者准备的,旨在帮助你构建一个能够服务成千上万用户的功能性 Web 应用。
本指南源于我一年的挣扎与学习,代表了我希望能告诉年轻一岁的自己的一些建议。
免责声明:这些建议来自我的个人经历,可能对其他人效果不同。我与这里推荐的任何工具没有任何关系或合作。
目录
· 你想要构建的内容
· YouTube Web 开发教程的危险
∘ 提示 #1: 使用 Next.js 代替 React
∘ 提示 #2: 使用 Tailwind CSS 代替 Bootstrap 进行样式设计
· 数据科学思维模式的局限
∘ 提示 #3: 为后端选择 FastAPI 而非 Flask,并严格定义响应模型
∘ 提示 #4: 使用 TypeScript 代替 JavaScript
· 关于部署……
∘ 提示 #5: 为 GPU 后端使用 Modal
∘ 提示 #6: 使用 AWS Lambda 部署后端,使用 Vercel 部署前端
· 让生活更轻松
∘ 提示 #7: 不要用 React 自建着陆页
∘ 提示 #8: Firebase + Stripe 用于用户认证和支付
∘ 提示 #9: 实现 Sentry 进行错误监控
· 结论
你想要构建的内容
要构建一个功能完整的 Web 应用,你需要一个Web 界面(前端或客户端)供用户交互,同时还需要一个服务器(后端)来进行数据处理、数据存储,并调用 ML/AI 模型。
(你可能听说过Streamlit,它非常适合做最简单的演示,但它确实缺乏自定义功能,无法做出一个可行的生产应用)
YouTube Web 开发教程的危险
作为数据科学家,软件开发的许多方面让我充满忧虑,例如配置错误导致浪费数天时间。没有什么比看到某个东西坏掉,却不知道为什么坏掉以及如何修复它更令人沮丧的了。
结果是,我非常依赖逐步教程,尤其是在 YouTube 上,这些教程展示了从头到尾的整个过程,例如如何设置 React 项目、部署后端或网站等。
回顾过去,这里有两个主要的缺点:
首先,困惑于多个冲突的和可能已过时的教程(例如,随着 React 的新版本发布)。这常常让我在跟着教程做时,突然发现它已经不再适用了。
其次,大多数教程旨在构建酷炫的课堂演示,适合初学者。因此,它们使用框架并强化那些性能上限低的编码模式,这对生产环境和扩展性来说是不够的。回顾过去,我从 YouTube 教程中学到很多不好的编码习惯,现在这些习惯已经成为我将应用程序开发成一个可服务成千上万用户的在线产品的障碍。
由于你从失败中学得最多,这个过程尽管令人沮丧,却成为了我这一年中的一次巨大学习经历。希望你能从我的失败中节省大量时间。
技巧 #1:使用 Next.js,而不是 React

在 YouTube 上搜索“全栈应用教程”会找到很多 React 教程。来源:www.youtube.com/results?search_query=full+stack+app+tutorial
很多 YouTube 教程推荐使用 React,最初我也跟风使用了它。
然而,最终我希望提高网站的 SEO 性能——这对获取更多用户至关重要。React 的一些局限性,比如无法动态更改 meta 标签,缺乏服务器端渲染,令人沮丧,迫使我做出繁琐的切换,转向 Next.js。切换后,性能差异简直是天壤之别。

Vercel 提供了大量的 Next.js 模板,供你快速启动网站开发。来源:vercel.com/templates/next.js
有些人说 React 更适合初学者,但网上有很多 Next.js 模板,比如 Vercel(Next.js 的创造者),尤其是 AI 应用。Next.js 真的是几乎所有 AI 应用中使用的现代 Web 框架。
技巧 #2:为样式选择 Tailwind CSS,而不是 Bootstrap
在开始我的前端 UI 之旅时,我最初有些天真地跟随了一波前端教程,走向了 Bootstrap。它的吸引力?承诺通过现成的组件(如下拉菜单和手风琴组件)来简化开发。

“Bootstrap 外观”——2023 年 2 月 20 日,我的网站看起来有多丑。图片来源:作者。
然而,过了一段时间,我意识到我的网站看起来……真的很丑,尤其是与那些流线型、现代的 AI 演示页面相比。那种明显的“Bootstrap 外观”——一种审美上的固执,抵制自定义,纠缠于一堆混乱命名的 CSS 类名中。所以,最终,我再次咬紧牙关,重新用 Tailwind CSS 完全重做了我的前端,花了整整三天时间。

这个 AI 演示页面肯定不是用 Bootstrap 构建的。来源:restorephotos.io
如果你曾经见过一页现代且简洁的 AI 演示页面,那么它很有可能使用了 Tailwind CSS。

Tailwind CSS 及其实用类使得定制每个组件变得异常简单。图像来自作者。
起初,我对 Tailwind 感到害怕 — 它冗长的组件定义,充满了看起来像是难解的实用类,似乎一点也不适合初学者…我曾认为 Tailwind 缺少预构建的组件,并且记住这些实用类会很麻烦。然而,事实恰恰相反!有许多很棒的 UI 组件库是基于 Tailwind CSS 构建的 — 我使用了 Flowbite React(它包含我所需的所有组件!)
数据科学思维方式的陷阱
作为一名数据科学学生,我逐渐喜欢上了 Python,它简洁且功能强大的代码语法让我受益匪浅。Python 的类型推断免去了我为每个变量定义类型的繁琐工作(这个任务我觉得很麻烦,尤其是在我在入门计算机科学课程中遇到的像 Java 这样的语言中)。
因此,我使用 JavaScript 作为前端,Python 作为后端,除非必要,否则避免定义 API 端点的类型。
然而,随着我的应用变得越来越复杂,前后端之间大量意外的类型错误大大降低了我的编码生产力。我终于理解了我的计算机科学朋友们坚持要求显式类型的重要性。事实证明,类型定义的细致入微不仅仅是“死板的学究式” — 它是必不可少的。
技巧 #3:为你的后端选择 FastAPI,并严格定义响应模型
如果你在 YouTube 上搜索 Python 后端教程,大多数视频会指引你使用 Flask。就像坏掉的时钟一天会对两次一样,我不知怎么地选择了 FastAPI 作为我的 Python 后端,事后看来,确实是正确的决定。
(尽管好笑的是,我之前完全忽视了 FastAPI 的这一优势。直到最近,我才理解需要为 POST 请求定义 Pydantic 类,之前我觉得这比帮助还麻烦。)
FastAPI 具有几个改变游戏规则的优势:
自动生成的 API 文档 — 这对于未来加入的工程师(或者你未来的自己)理解后端结构将非常有用!
更容易编写代码 — 由于 FastAPI 基于 Json schema,定义路由比 Flask 更加简洁且易于实现 — 结果是,对于像我这样的新手,学习曲线更低。
更好的性能 — FastAPI 显然比 Flask 快得多,并且消耗更少的内存 — 这对于我的应用发送大量负载非常有帮助。

使用 Pydantic 构建数据模型,你可以用它来定义 FastAPI 路由的响应类型。图像来自作者。
但最重要的是 FastAPI 的类型注解。
FastAPI 基于 Pydantic,这是一款数据验证库,允许你将数据的“形状”定义为具有属性的类。
使用 FastAPI,您可以为每个 API 路由注解输入和输出类型,使用 Python 类型提示和 Pydantic 定义的类。
这确保了每个路由的输出都具有一致的数据结构。但要充分发挥这个功能,我们还需要……
提示 #4:使用 TypeScript 而非 JavaScript
很长一段时间里,我手动编写前端的获取方法(再次学习来自全栈教程),因此将新路由添加到我的应用程序是一个冗长且容易出错的过程。
因此,当我的大科技公司软件工程师朋友告诉我,您可以通过 API 规范自动生成 TypeScript 客户端代码时,我感到非常震惊。(**更多 FastAPI 文档请见这里,其中一个这样的包是 openapi-typescript-codegen)

通过自动生成的 TypeScript 客户端代码,您的获取方法将根据 FastAPI 端点响应模型提供自动补全和文档。图片来自作者。
一瞬间,我意识到这将同时解决两个主要挑战:消除我手动且容易出错的客户端获取编写,并确保后端和前端之间的类型一致性。这显著减少了那些削弱我的应用程序可靠性的持续类型错误。
当然,为您的后端路由设置类型约束只有在您的前端也强制执行这些类型约束时才有意义——这自然需要TypeScript。
因此,我目前正在经历定义 FastAPI 后端响应模型的艰难过程,并将我的前端从 JavaScript 转换为 TypeScript。如果您从一开始就使用 FastAPI 和 TypeScript,这个过程是可以避免的!
关于部署……
通过我的数据科学 / 机器学习课程,我已经习惯了进入 Google Colab,按下播放按钮,然后代码就运行了。所以,部署的想法让我感到恐惧也不足为奇。但正如 Buildspace 加速器的创始人所说,您需要“GTFOL”(Get The F Off Localhost)才能让您的软件应用对外开放。因此,我自然希望部署过程尽可能无痛。
提示 #5:为 GPU 后端使用 Modal
如果您想部署自己的模型(例如机器学习模型、图像识别、Whisper 转录,或最近的开源大语言模型如 Llama),您需要一个 GPU 云服务提供商来托管您的模型。
我的建议是选择Modal,并永远不要回头。
Modal 以其卓越的文档和学习资源脱颖而出,提供了最新应用程序的示例代码——从微调开源大语言模型到为 LLM 聊天机器人提供服务等等。
实际上,我的整个播客转录应用是通过分叉 Modal 的示例音频转录代码开始的,所以可以毫不夸张地说,没有 Modal,我是无法构建我的应用的。

Modal 的仪表板在监控和错误追踪方面非常用户友好。图像来源:作者提供。来源:modal.com
Modal 在用户友好性方面表现突出(作为一个讨厌部署的人,这已经说明了很多问题)。只需在本地代码编辑器中编写云函数,通过一个终端命令将其部署。它的仪表板非常用户友好(特别是与 AWS 相比),让我可以轻松地追踪应用的使用情况、分析性能并追踪错误。
最后,Modal 作为我在 Lambda 没有或实现起来麻烦的功能(例如文件存储,这在下一点中会派上用场……)的逃生阀。
提示 #6: 使用 AWS Lambda 进行后端部署,使用 Vercel 进行前端部署
在托管我的 Python 后端时,我对于是使用 Amazon EC2 还是 AWS Lambda 感到困惑。我的应用需要存储音频文件(这些文件可能很大),而 Lambda 的无服务器架构并不适合存储文件(它有 2 GB 的临时存储,但不持久),所以我曾认为必须使用 Amazon EC2。但是,EC2 的配置非常繁琐,而且作为一个始终在线的专用实例,它会更贵,并且更难以扩展。
这时,Modal 提供的免费文件存储成了救星,我能够将我的后端架构调整为兼容 Lambda,同时在需要时将文件下载并存储到 Modal 上。
幸运的是,这个视频真的非常好,完全按照他们的指示操作,我成功地部署了我的后端。
对于我的前端,Vercel 就是我所需要的一切。这个过程非常简单,除了域名费用,完全免费。
让生活更轻松
最后的三个杂项提示,可以帮助你避免在开发中浪费大量时间……
提示 #7: 不要使用 React 构建你自己的着陆页
这是我又一次犯的错误,因为所有那些全栈教程让我误以为必须用 React 编写自己的着陆页。没错,你可以这么做(我也做了),但是在性能和美学上会有一个低限——而这正是成功的着陆页所需要的关键特点。
React 更适合用于像实际的 AI 应用界面这样的自定义功能。对于仅包含静态内容的着陆页,你应该使用像Webflow或Framer这样的无代码网站构建工具来快速构建着陆页(并将着陆页的创建外包给你的设计师,这样你就可以专注于其他事情!)
提示 #8: 使用 Firebase + Stripe 进行用户认证和支付
在用户身份验证方面,市面上的选项和教程多得让人眼花缭乱。我需要一个不仅处理身份验证,还能与支付系统集成的解决方案,以根据用户的订阅状态控制访问权限。
在尝试并失败于使用几种不同的身份验证解决方案(例如 auth0)之后,我发现 Stripe + Firebase 的组合非常有效。Firebase 有一个 Stripe 集成,能在支付成功后更新用户的订阅状态,Firebase 的 React 客户端处理客户端身份验证,而 Python 客户端则很好地管理服务器访问控制。参考这两段视频(这里和这里)让我成功地将这一方案实现到我的应用中。
技巧#9:实施 Sentry 进行错误监控
几个月来,我完全不知道用户在我的应用上线后遇到了什么 bug。只有当我或用户发现 bug 时,我才会通过 AWS Cloudwatch 界面去排查后端的 bug。

Sentry 在你的应用生产环境中追踪错误(包括前端和后端)。图像由作者提供。来源:sentry.io。
这一切持续了直到我的联合创始人向我介绍了Sentry,一个用于云应用性能监控和错误追踪的工具。它非常容易为前端和后端初始化,你甚至可以将其与 Slack 集成,实时接收错误通知。只是要小心不要把你免费计划的每月错误预算浪费在像身份验证超时这样的琐碎但频繁的错误上。这就是我当时遇到的情况——我不得不订阅付费计划才能找到我真正想解决的重要 bug 的日志。
**额外技巧#10:**不要尝试使用 Spotify 的 API 构建 Web 应用!我浪费了 2 个月的时间,假设我可以集成 Spotify 的 API,让用户加载他们保存的播客。但是要实现这一点,你需要申请配额扩展请求,而 Spotify 审核这个请求需要一个多月的时间。如果你的应用涉及任何 AI/ML 模型(尽管我的应用实际上并没有用 Spotify 数据训练任何模型),他们可能会拒绝你的申请(根据他们的开发者政策中禁止的措辞)。
结论
我希望这篇技术指南能为其他数据科学爱好者解开一些 Web 应用开发的神秘面纱。
如果你觉得这篇文章有帮助:
不妨看看我在 Medium 上的其他文章:如何使用 AI 总结长文本,使用深度学习生成音乐
在LinkedIn或Twitter/X上关注我,或者通过消息或评论与我联系!我很乐意就所有与数据科学和人工智能相关的内容交流想法。
感谢阅读!
从数据科学家到数据经理:我领导团队的前三个月
从实践工作到指导和领导的反思
·发布于 Towards Data Science ·阅读时间 8 分钟·2024 年 11 月 26 日
--
这是我在数据科学领域的第七年,旅程充满了仪表盘、指标、分析和模型。但在 8 月,我进入了一个全新的领域:第一次成为一名人员经理。老实说,过去每当被问到我的职业目标时,我总是说我更喜欢留在 IC 路径上。我喜欢技术挑战,并且享受从头到尾掌控项目的感觉。然而,当这个机会出现时,我决定尝试一下。毕竟,只有尝试过,才知道某件事是否适合自己。
在本文中,我将分享作为一名经理的初步经验——有什么变化,什么我喜欢,什么又是挑战。如果你正在纠结是否选择 IC(个人贡献者)路径还是人员管理路径,我希望这篇文章能为你提供一些启示。

图片来源:Mimi Thian 在 Unsplash
我的管理之路
为了设定背景,让我分享一下我是如何转型为一名人员经理的。当我四年前加入这个团队时,团队中的每个人都是一名**“全栈”数据科学家**——我们每个人都支持一个特定领域,并且负责从构建数据管道、定义业务指标和仪表盘,到分析、实验和建模的所有工作。这个框架...
从数据科学家到机器学习/人工智能产品经理
有关如何为成功过渡做准备的见解和建议
·发布于Towards Data Science ·7 分钟阅读·2024 年 4 月 3 日
--

图片由Holly Mandarich提供,来源于Unsplash
随着人工智能越来越流行,越来越多的公司和团队希望开始或加大对其的利用。因此,许多职位正在出现或在市场上变得越来越重要。一个很好的例子就是机器学习/人工智能产品经理这一角色。
以我为例,我在两年前从数据科学家职位转型为机器学习产品经理。在此期间,我看到与这一职位相关的工作机会不断增加,相关的博客文章和演讲也越来越多,许多人正在考虑转型或对这一领域产生兴趣。我也能确认自己对这个角色的热情,以及我多么喜欢日常工作、责任和能够为团队和公司带来的价值。
人工智能/机器学习产品经理(AI/ML PM)的角色仍然相当模糊,并且几乎与前沿 AI 技术发展速度一样迅速。虽然许多产品团队借助插件解决方案和生成型 AI(GenAI)API 变得相对独立,但我将专注于在核心机器学习团队中工作的 AI/ML 产品经理角色。这些团队通常由数据科学家、机器学习工程师和研究科学家组成,与其他角色一起参与一些传统机器学习应用场景,其中生成型 AI 通过 API 可能无法满足需求(如传统的机器学习用例、大型语言模型微调、特定的内部用例、机器学习作为服务的产品等)。对于此类团队的一个示例,您可以查看我之前的一篇文章*“在一个多学科的机器学习团队中为用户创造价值”*。
学习我们如何从有限的机器学习模型演变为端到端的机器学习解决方案,作为一个团队一起成长
在这篇博客文章中,我们将涵盖这个职位所需的主要技能和知识,如何达到这些技能,以及根据我在这一转变中的经验,总结的学习和建议。
机器学习产品经理最重要的技能
成功成为一名机器学习/人工智能产品经理需要许多必要的技能和知识,但最重要的技能可以分为四个组别:产品战略、产品交付、影响力和技术流利度。让我们深入探讨每个组别,进一步理解每个技能的含义以及如何获得这些技能。

机器学习/人工智能产品经理的四项关键技能,图片由作者提供
产品战略
产品战略是了解用户及其痛点,识别正确的问题和机会,并基于定量和定性证据对其进行优先排序。
作为一名前数据科学家,对我来说,这意味着要爱上问题和用户的痛点,而不是过于依赖具体的解决方案,思考如何在为用户创造更多价值而非应用这个酷炫的新 AI 模型的方向上努力。我发现,清晰理解 OKRs(目标与关键结果)并关注计划的最终影响(交付结果而非输出)是至关重要的。
产品经理需要优先处理任务和计划,因此我学会了平衡每个计划的努力与回报的重要性,并确保这影响到决策,决定什么以及如何构建解决方案(例如,考虑项目管理三角形——范围、质量、时间)。如果计划能够应对四大产品风险,它们就能成功:价值、可用性、可行性和商业可行性。
我用来学习产品战略的最重要资源有:
优秀与糟糕的产品经理,作者 Ben Horowitz。
每个人都推荐给我并且我现在推荐给任何有志成为产品经理的人参考书是《Inspired: How to create tech products customers love》,作者是 Marty Cagan。
另一本帮助我更接近用户空间和用户问题的书是《Continuous Discovery Habits: Discover Products that Create Customer Value and Business Value》,作者 Teresa Torres。
产品交付
产品交付是指能够有效地管理团队的项目,向用户交付价值。
我开始时通过理解产品功能的各个阶段(发现、规划、设计、实施、测试、发布和迭代),以及每个阶段对我作为数据科学家的意义。接着,我学习了如何*“高效”*地带来价值:从小做起(通过最小可行产品和原型设计),通过小步伐和迭代快速交付价值。为了确保项目朝着正确的方向发展,我发现持续衡量影响(例如通过仪表板)并从定量和定性数据中学习,结合见解和新的学习调整下一步的行动非常关键。
想要了解产品交付,我推荐:
一些之前分享的资源(例如《Inspired》一书)也涵盖了 MVP、原型设计和敏捷方法在产品管理中的重要性。我还写了一篇博客文章,探讨了如何在机器学习(ML)项目中考虑 MVP 和原型:当机器学习遇到产品——少即是多。
学习敏捷和项目管理(例如通过这个速成课程),以及学习 Jira 或你所在公司使用的项目管理工具(通过类似这个速成课程的视频)。
影响力
影响力是指获得信任、与利益相关者对齐并引导团队的能力。
与数据科学家的角色相比,作为产品经理的日常工作完全不同:不再是编码,而是沟通、对齐,以及大量的会议。出色的沟通和讲故事能力成为这个角色的关键,尤其是将复杂的机器学习话题解释给非技术人员听的能力也变得至关重要。与此相关,保持利益相关者的知情、为团队的辛勤工作提供可见度,并确保团队未来方向的一致性和支持(证明如何帮助解决最大挑战和机会,赢得信任)也变得很重要。最后,还需要学会如何提出挑战、说“不”、作为团队的保护伞,有时还需要传递不好的结果或坏消息。
我推荐的相关资源:
完整的利益相关者映射指南,Miro
对于任何数据科学家以及任何机器学习产品经理来说,必读的一本书是 Cole Nussbaumer Knaflic 的《通过数据讲故事——面向商业专业人士的数据可视化指南》。
想进一步了解作为产品经理如何影响和赋能团队,推荐阅读 Marty Cagan 和 Chris Jones 的《EMPOWERED: 普通人,非凡产品》。
技术流利度
对于一个机器学习/人工智能产品经理来说,技术流利度意味着在机器学习、负责任的人工智能、数据(一般而言)、MLOps 和后端工程方面的知识和敏感度。

机器学习/人工智能产品经理的技术流利度核心领域,图片由作者提供
你的数据科学/机器学习/人工智能背景可能是你最强的资产,一定要充分利用它!这些知识将帮助你与数据科学家使用相同的语言,深入理解并挑战项目,了解哪些是可能的或容易的,哪些不是,潜在的风险、依赖关系、边缘情况和局限性。
由于你将领导对用户产生影响的产品,因此包含负责任的人工智能意识变得至关重要。忽视这一点的风险包括伦理困境、公司声誉和法律问题(例如,特定的欧盟法律,如 GDPR 或人工智能法案)。在我的情况下,我从 Fast.ai 的课程实践数据伦理开始。
数据流利度也是必需的(你可能也已经具备):分析思维、对数据的好奇心、理解数据存储的位置、如何访问它、历史数据的重要性……除此之外,还需要了解如何衡量影响、与业务指标和 OKR 的关系,以及实验(A/B 测试)。
由于你的机器学习模型可能需要被部署,以便最终对用户产生影响,你可能需要与团队中的机器学习工程师合作(或具备模型部署知识的资深数据科学家)。你需要对MLOps产生敏感度:将模型投入生产、监控和维护它意味着什么。在 deeplearning.ai,你可以找到一门关于 MLOps 的优秀课程(机器学习工程与生产专业化)。
最后,可能你的团队中也会有后端工程师(通常负责将已部署的模型与平台其他部分进行集成)。在我的案例中,这是离我专业领域最远的技术领域,因此我不得不花一些时间学习并提高对后端的敏感度。在许多公司,PM 的技术面试会涉及一些后端相关的问题。确保你对以下多个工程领域有一个概览:CICD、预发布环境与生产环境、单体架构与微服务架构(每种架构的优缺点)、拉取请求、API、事件驱动架构……
总结和最终建议
我们已经涵盖了机器学习/人工智能产品经理最重要的四个知识领域(产品策略、产品交付、影响力和技术流利度),它们为何重要,以及一些可以帮助你实现这些目标的资源建议。
就像任何职业发展一样,我发现关键是要制定计划,并与经理和同事分享我的短期和中期愿望与期望。通过这个过程,我成功地从数据科学家转型为同一家公司中的产品经理(PM)。这使得过渡变得更加容易:我已经了解了业务、产品、技术、工作方式、同事……我还在公司内寻找导师和同事,向他们请教问题,学习特定的知识点,甚至为产品经理面试做准备。
为了准备面试,我专注于改变我的思维方式:从“是否要构建某个东西”或“是否要发布某个东西”转变为“如何构建”以及“如何发布”。我发现BUS(业务、用户、解决方案)是一种在面试中构建回答的好方法,并且可以在面试中强化这种新的思维方式。
我在这篇博客文章中分享的内容看起来很多,但实际上比学习 Python 或理解反向传播的原理要简单得多。如果你仍然不确定这个角色是否适合你,知道你始终可以尝试、实验,并决定是否回到你之前的角色。或者,谁知道呢,也许你会像我一样,爱上做机器学习/人工智能产品经理!
从数据到仪表盘:使用 Dash Leaflet 和 SeaRoute 库可视化古代海上丝绸之路
在 Python Dash 框架中构建一个交互式地图可视化,展示海上两点之间的最短路径
·发布于 Towards Data Science ·14 分钟阅读·2024 年 5 月 22 日
--

图片来自 Annie Spratt 提供的 Unsplash
我博士学位期间的一个难点是创建一个交互式地图可视化,展示船只从一个港口到另一个港口的航运路线。船只从起点到目的地港口的航程应该完全在海上,不可穿越任何陆地。令人惊讶的是,这个看似简单的任务在尝试用 Python 从零开始实现时竟然非常具有挑战性 🤷♀️。虽然有像 Marine Traffic 这样的商业解决方案可以完成类似的任务,但我一直在寻找开源的替代方案,并且很长一段时间内都没找到。直到 2022 年底,SeaRoute 库发布了 Python 版本(之前仅支持 Java),它极大地简化了我的工作。在这篇文章中,我将带你通过创建 Dash 应用程序的交互式地图可视化的过程,利用 Dash Leaflet 和 SeaRoute Python 库展示海上航线。
那么 Dash、Dash Leaflet 和 SeaRoute 呢?
从数据到可视化:使用 OpenAI 助理 API 和 GPT-4o
我们将探索 OpenAI 助理 API 中的代码补全工具,直接从数据创建可视化图表
·发表于 Towards Data Science ·阅读时间:13 分钟·2024 年 6 月 12 日
--

编程工具 — 图片由 Quaritsch Photography 和 Anton Savinov 提供,来源:Unsplash
随着 GPT-4 能力的不断扩展,基于其技术构建的 OpenAI 工具正在不断演变,成为开发者越来越强大的资产。
在本文中,我们将探索最新版本的图表制作功能。我们将为助理提供一个数据文件和具体指令,看看它如何将我们的数据可视化构思变为现实。
我们将使用助理 API 中内置的工具来实现这一目标。
目前,在 OpenAI Python 包(写作时为 v1.30.0)中,助理 API 包含工具:文件搜索、代码补全和函数调用。
函数调用允许开发者定义 AI 可以智能选择的函数,以完成任务;文件搜索允许开发者上传各种文件类型,这些文件可以以 RAG 风格存储在向量数据库中;代码补全让助理能够在沙盒环境中编写并运行 Python 程序,以解决编程和数学问题。
从默认 Python 折线图到期刊级质量信息图
将乏味的默认 Matplotlib 折线图转变为令人惊艳的定制可视化
·发表于Towards Data Science ·4 分钟阅读·5 天前
--

封面,图片由作者提供
每个使用过 Matplotlib 的人都知道默认图表有多难看。在这系列文章中,我将分享一些技巧,帮助你的可视化脱颖而出,并展现你个人的风格。
我们将从一个简单的折线图开始,这种图表被广泛使用。主要亮点是在线图下方添加渐变填充——这个任务并不完全简单。
那么,让我们深入并逐步走过这次转换的所有关键步骤吧!
让我们先进行所有必要的导入。
import pandas as pd
import numpy as np
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from matplotlib import rcParams
from matplotlib.path import Path
from matplotlib.patches import PathPatch
np.random.seed(38)
现在我们需要为可视化生成示例数据。我们将创建类似股票价格的图表。
dates = pd.date_range(start='2024-02-01', periods=100, freq='D')
initial_rate = 75
drift = 0.003
volatility = 0.1
returns = np.random.normal(drift, volatility, len(dates))
rates = initial_rate * np.cumprod(1 + returns)
x, y = dates, rates
让我们看看使用默认 Matplotlib 设置时它的样子。
fix, ax = plt.subplots(figsize=(8, 4))
ax.plot(dates, rates)
ax.xaxis.set_major_locator(mdates.DayLocator(interval=30))
plt.show()

默认图,图片由作者提供
真的不太吸引人,对吧?但我们将逐步改进它,让它看起来更好。
设置标题
设置通用图表参数——大小和字体
将 Y 轴的刻度移到右侧
更改主线的颜色、样式和宽度
# General parameters
fig, ax = plt.subplots(figsize=(10, 6))
plt.title("Daily visitors", fontsize=18, color="black")
rcParams['font.family'] = 'DejaVu Sans'
rcParams['font.size'] = 14
# Axis Y to the right
ax.yaxis.tick_right()
ax.yaxis.set_label_position("right")
# Plotting main line
ax.plot(dates, rates, color='#268358', linewidth=2)

一般参数已应用,图片由作者提供
好的,现在看起来干净了一些。
现在,我们想在背景中添加极简网格,去除边框以使外观更加干净,并移除 Y 轴上的刻度。
# Grid
ax.grid(color="gray", linestyle=(0, (10, 10)), linewidth=0.5, alpha=0.6)
ax.tick_params(axis="x", colors="black")
ax.tick_params(axis="y", left=False, labelleft=False)
# Borders
ax.spines["top"].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines["bottom"].set_color("black")
ax.spines['left'].set_color('white')
ax.spines['left'].set_linewidth(1)
# Remove ticks from axis Y
ax.tick_params(axis='y', length=0)

已添加网格,图片由作者提供
现在我们添加了一个小的美学细节——在 X 轴的第一个刻度附近加上年份。同时,我们还将刻度标签的字体颜色变得更浅。
# Add year to the first date on the axis
def custom_date_formatter(t, pos, dates, x_interval):
date = dates[pos*x_interval]
if pos == 0:
return date.strftime('%d %b \'%y')
else:
return date.strftime('%d %b')
ax.xaxis.set_major_formatter(ticker.FuncFormatter((lambda x, pos: custom_date_formatter(x, pos, dates=dates, x_interval=x_interval))))
# Ticks label color
[t.set_color('#808079') for t in ax.yaxis.get_ticklabels()]
[t.set_color('#808079') for t in ax.xaxis.get_ticklabels()]

第一个日期附近的年份,图片由作者提供
我们离最棘手的部分越来越近——如何在曲线下方创建渐变。实际上,Matplotlib 并没有直接提供这样的选项,但我们可以通过创建渐变图像并将其与图表裁剪来模拟这一效果。
# Gradient
numeric_x = np.array([i for i in range(len(x))])
numeric_x_patch = np.append(numeric_x, max(numeric_x))
numeric_x_patch = np.append(numeric_x_patch[0], numeric_x_patch)
y_patch = np.append(y, 0)
y_patch = np.append(0, y_patch)
path = Path(np.array([numeric_x_patch, y_patch]).transpose())
patch = PathPatch(path, facecolor='none')
plt.gca().add_patch(patch)
ax.imshow(numeric_x.reshape(len(numeric_x), 1), interpolation="bicubic",
cmap=plt.cm.Greens,
origin='lower',
alpha=0.3,
extent=[min(numeric_x), max(numeric_x), min(y_patch), max(y_patch) * 1.2],
aspect="auto", clip_path=patch, clip_on=True)

添加了渐变,图片由作者提供
现在看起来干净整洁了。我们只需要使用任何编辑器(我偏爱 Google Slides)添加几个细节——标题、圆角边框和一些数字指示器。

最终可视化效果,图片由作者提供
下面是重现该可视化效果的完整代码:
从英语文学到数据科学
我是如何进入数据科学行业的?
·发表于Towards Data Science ·5 分钟阅读·2024 年 6 月 16 日
--
我一直在考虑是否要在这里更有规律地发布内容,而有什么比从我自己的故事开始更好的呢?我做出了许多人想要做的转变,成功从非技术背景(没有 STEM 学位,也没有社会科学学位)进入了数据科学的工作。这并不简单:我走的是一条曲折的道路,所以请耐心听我说完。

对科学的兴趣培养
那是我大学的第一年,我对任何事情都不感兴趣,除了篮球。我学习英语文学只是因为当时我的英语比同龄人好,因此能凭此考入大学。 不管怎样,在那年的学期假期里,我去了黎巴嫩贝鲁特参加人体生物力学培训(那时我还在为篮球运动员做训练)。在那期间,提到了长期压力的影响几次,听起来很有趣。回到家后,我想要进一步探索,观看了罗伯特·萨波尔斯基博士的 TED 演讲,标题是我们最好的与最糟糕自我的生物学. 我印象深刻,记得我当时想 “我真想知道他(萨波尔斯基)是怎么看待这个世界的。”幸运的是,他的整门课程都上传到了斯坦福大学的 YouTube 频道: 人类行为生物学。不过,由于我没有相关背景(那时我对科学一无所知,包括生物学、神经生物学或心理学),我没能跟上这门课程的进度。我决定去心理学系看看。
进入心理学实验室
我去心理学系的整个目的还是关于一场讲座,不过现在我已经记不清楚具体内容了,但它与信仰体系的演变有关。我走进了我找到的第一扇门,遇到了巴赫切卡皮利博士,他让我去向隔壁的同事(当时是他的学生)耶尔马兹博士自我介绍。从那时起,我的兴趣转向了他的研究方向:道德、政治、宗教与决策的交集。我开始非正式地参加他的社会心理学课程,并且开始更频繁地到他办公室,询问我读过的一些东西:行为生物学、社会心理学以及进化心理学。有一次,他邀请我加入他即将创建的实验室,我答应了,最终加入了MINT 实验室。
学习统计学与数学
在实验室开始之前,我报名参加了edX 的《宗教科学》课程,以确保自己没有错过该领域的重要发现。这个课程很有趣,因为信息呈现得非常易于理解。但随着实验室的开始,一旦我们开始阅读研究论文,我意识到我对统计学一无所知,因此我无法评估眼前的研究是否具有说服力。作为一名文学学生,我不得不找到一种假设零基础的学习材料。我尝试了许多书籍,但 OpenIntro 统计学 和 用 JASP 学习统计学 帮助了我很多,足以达到本科水平。我开始意识到自己非常喜欢学习统计学,并且希望在这个领域有所提高,想知道数学是否能帮助我更好地理解统计学。
有一个我通过玩炉石传说认识的人(是的,我喜欢卡牌游戏):Dr. Basar Coskunoglu。我开始和他一起学习数学。他非常耐心地从最基础的内容开始教我(我根本不懂任何东西,真的,所以我们不得不从函数、不等式系统等基础开始)。我们一直学到了微积分和线性代数。我们学习到一个我可以独立阅读书籍的程度,这也是我最初想要的目标。直到今天,我几乎每周都从不同的来源学习线性代数(主要是我的笔记,机器学习数学 和 Gilbert Strang 的书籍),因为我非常喜欢它。
从实验室到数据科学实习
在实验室和学术界(社会科学领域),研究人员主要使用统计软件,如 JASP、Jamovi、SPSS。R 编程也有使用,但在那个时候,由于种种原因,我开始对学术界有些失去兴趣。因此,我想学习 Python。我参加了一个训练营(*我现在强烈建议不要参加任何数据科学训练营,*他们提供的所有材料都能在网上找到,而且可能有更好的版本),该训练营主要帮助我提升编程技能,并结合了行业案例,而不是理论部分,因为我已经完成了经典的内容(直到深度学习):《统计学习导论》(现在有 Python 版本)。
当时我还在本科阶段,并且获得了 TUBITAK(土耳其科学与技术研究委员会)的奖学金,参与一个研究项目,但这个奖学金即将结束。所以,时机正好,试着尝试学术之外的机会。我想找一份兼职工作或长期实习,因为我还在学校,不想过早耗尽精力。无论如何,我开始做数据科学项目,并将它们作为“项目”部分添加到我的简历中。我申请了不同的工作,但在面试过程中,我意识到许多公司没有数据科学团队,这让我有点担心,因为这是我的第一份工作,我觉得我需要一个有经验的团队。虽然我不确定自己如何适应银行工作,但我还是申请了 DenizBank 的长期实习,并经过几个步骤后,最终被录取为数据科学实习生。
转为全职
我觉得自己在实习期间表现还不错,看来这个感觉是相互的,因为他们想留下我(而且,留用一个实习生比重新招聘新人更有意义)。一切顺利,学校也没有问题,我顺利毕业了。此外,尽管我在那里工作了六个月,包括编程等工作,但仍然要求我参加初级数据科学家的考试,考试内容包括 SQL、Python/R 以及概率和统计问题。我通过了考试,现在在 DenizBank 担任初级数据科学家。
其他:篮球分析
我打了十年篮球,也曾训练过球员。我一直有一个想法,那就是做一些与篮球相关的事情。我看到一个做 F1 分析的Formula 1 (F1) 数据分析账户,它启发我用类似的方法分析篮球数据。如果你感兴趣,可以关注我的LinkedIn或我的篮球分析账户。
好吧,这就是整个故事的主要内容。非常感谢你阅读它,希望你喜欢。如果你喜欢,你可以考虑关注我或查看我在这里的其他内容。如果你想联系我,可以通过我的网站联系我。
从短暂到持久:使用 LangChain 在聊天机器人中构建长期记忆
详细介绍如何将简单的聊天机器人转变为具有长期记忆和上下文理解的复杂 AI 助手
·发表于Towards Data Science ·8 分钟阅读·2024 年 7 月 23 日
--

使用 BingAI 生成
在我之前的一篇文章中,我介绍了如何用 OpenAI 创建一个对话型聊天机器人。然而,如果你使用过类似 ChatGPT 或 Claude 等聊天机器人界面,你会注意到,当会话关闭并重新打开时,记忆会被保留,你可以从中断的地方继续对话。这正是我想在本文中创建的体验。
我将以 LangChain 作为基础,它提供了非常棒的工具来管理对话历史,并且如果你希望通过构建链来发展更复杂的应用程序,它也非常适合。
代码可用性
本文中重建所有内容的代码可以在github.com/deepshamenghani/langchain_openai_persistence找到。
使用 LangChain 和 OpenAI 构建的单一问答机器人
我将首先创建一个循环,允许用户为聊天机器人输入问题。我会将这个输入分配给变量humaninput。现在,暂时不使用 LLM 的输出……
从洞察到影响:利用数据科学最大化客户价值

图片由Oleg Moroz提供,来源于Unsplash
提升建模:因果机器学习如何改变客户关系和收入
·发表于Towards Data Science ·阅读时间:10 分钟·2024 年 8 月 13 日
--
本系列详细内容
本文是关于提升建模和因果机器学习系列的第一篇文章。这个系列的目的是从商业和技术角度深入探讨这些方法论。
介绍
想象一下:我们的科技公司每月都在获得成千上万的新客户。但在表面之下,一个令人担忧的趋势浮现:客户流失率在上升——我们正在失去客户——尽管财务报表显示了令人印象深刻的增长,但收入却未能跟上预期。这种脱节现在可能不是问题,但当投资者开始要求盈利时,它将成为一个问题:在科技行业,获得一个新客户的成本远远高于留住一个现有客户的成本。
我们应该做什么?许多想法浮现在脑海:在客户离开之前打电话、发送电子邮件、提供折扣。但我们应该选择哪个想法?我们应该尝试所有的方法吗?我们应该关注什么?
这就是提升建模派上用场的地方。提升建模是一种数据科学技术,能够帮助我们了解不仅是谁可能会流失,还能知道针对每个客户采取哪些行动来留住他们——当然,如果他们能被留住的话。它超越了传统的预测建模,通过关注特定行动对个别客户的增量影响。
在本文中,我们将围绕两个目标来探索这一强大技术:
首先,提高商业领袖对这种方法的敏感度,让他们理解它如何为他们带来益处。
其次,给数据科学家提供工具,让他们能够向经理推销这种方法,从而成为公司成功的助力。
我们将讨论以下内容:
什么是提升建模,它为什么如此强大?
提升建模的高影响力应用场景
投资回报率(ROI):你可以期望从你的提升模型中获得多少影响?
提升建模的实践:如何实施它?
什么是提升建模,它为什么如此强大?
通常,公司会试图预测客户的行为,例如流失。为了做到这一点,他们会为每个用户建模流失的概率。这是“结果”建模,意味着估算用户采取特定行动的可能性。
例如,如果一个结果模型估算某个特定用户流失的概率为 90%。在这种情况下,公司可能会尝试联系该用户,以防止他们流失,对吧?这已经是一个很大的进步,并且可能有助于显著降低流失率或识别其根本原因。但这里有一个棘手的问题:如果我们识别出的一些用户实际上想离开,但他们只是懒得打电话或取消订阅呢?他们可能会利用这个电话实际上选择流失,而不是继续留在我们这里!
与结果建模不同,提升建模是一种预测建模技术,直接衡量治疗—或行动—对个体行为的增量影响。 这意味着我们将建模一个用户在被公司联系时,选择留存的概率。
提升模型专注于干预组和对照组之间结果的差异,使公司能够评估个体层面的实际“提升”,识别出对每个客户最有效的行动。

提升模型与结果模型的描述
更准确地说,提升建模使我们能够根据客户对干预/行动的反应概率,将客户分为四个组:
可说服的用户:这些是可能对行动作出积极反应的用户:他们是我们希望通过行动来吸引的目标群体。
显而易见的事情:这些是我们的客户,无论是否接受干预,都能实现预期的结果。针对这些用户进行干预通常是浪费资源。
无药可救的个体:这些人无论是否采取行动,都不太可能实现预期的结果。将资源投入到这些用户身上可能并不具备成本效益。
沉睡的狗:这些客户可能会对干预做出负面反应。针对这些客户进行干预可能会通过导致不希望发生的行为(例如,在被提醒时取消订阅)而对业务造成损害。
提升建模的目标是识别并针对那些容易说服的人,同时避免其他群体,尤其是“睡狗”群体。
回到我们的留存问题,提升建模不仅能帮助我们评估哪种行动最能提高留存率,还能帮助我们为每个用户挑选合适的行动:
一些用户——可说服群体——可能只需要一个电话或电子邮件,就能留住他们。
其他人——可说服群体——可能需要一个价值 10 美元的优惠券来被说服。
一些——稳妥客户——无需任何干预,因为他们可能会继续留存。
对于一些用户——睡狗群体——任何留存尝试可能实际上会导致他们流失,因此最好避免联系他们。
最后,失去的原因群体可能对任何留存努力都不会作出回应,因此通过不针对他们,可以节省资源。
总结来说,提升建模使我们能够精准分配资源,针对合适的可说服群体采取合适的行动,同时避免负面影响,从而最大化我们的投资回报率(ROI)。最终,我们能够制定一个高度个性化且有效的留存策略,优化我们的资源并提高整体客户生命周期价值。
现在我们了解了什么是提升建模及其潜在影响,让我们探索一些能够推动显著商业价值的应用场景。
提升建模的高影响力应用场景
在深入探讨如何设置之前,让我们研究一些提升建模对你的业务高度相关的具体应用场景。
客户留存: 提升建模帮助识别哪些客户最有可能对留存努力做出积极回应,从而使公司能够将资源集中在“可说服群体”上,避免打扰可能会流失的“睡狗群体”。
追加销售与交叉销售: 预测哪些客户最有可能对追加销售或交叉销售的优惠或促销做出积极回应,从而提高收入和客户生命周期价值(LTV),而不会打扰到那些不感兴趣的用户。提升建模确保额外的优惠针对的是那些会认为它们最有价值的用户。
定价优化: 提升模型可以帮助确定不同客户群体的最佳定价策略,在不排斥价格敏感型用户的情况下最大化收入。
个性化营销活动: 提升建模可以帮助确定哪些营销渠道(电子邮件、短信、应用内通知等)或哪种类型的广告对每个用户最有效。
这些是最常见的应用,但它可以超越以客户为中心的行动:通过足够的数据,我们可以用它来优化客户支持优先级,或通过对合适的员工采取合适的行动来提高员工留存率。
有了这些强大的应用场景,你可能会想知道如何在你的组织中实际实施提升建模。让我们深入了解将这种技术付诸实践的步骤。
投资回报率(ROI):在实践中,你可以期待从提升模型中获得什么?
我们如何衡量提升模型的表现?
这是一个很好的问题,在探讨这种方法的潜在结果之前——说实话,这种方法确实令人印象深刻——解决这个问题是至关重要的。正如预期的那样,答案是多方面的,数据科学家有多种方法来评估模型预测一个行为增量影响的能力。
一种特别有趣的方法是 Qini 曲线。Qini 曲线将累计增量收益与目标人群的比例进行对比。

Qini 曲线示例
简单来说,它有助于回答这个问题:通过使用你的模型针对 X%的人群进行目标定位,相较于随机目标定位,你可以实现多少额外的积极结果? 我们通常会将提升模型的 Qini 曲线与随机目标定位策略的 Qini 曲线进行比较,以模拟如果我们没有提升模型而是随机地针对用户或客户进行目标定位时会发生什么。在构建提升模型时,最佳实践是比较所有模型的 Qini 曲线,以识别在未见数据上最有效的模型。然而,我们将在技术文章中深入探讨这一点。
现在,让我们探索这种方法的潜在影响。再次强调,可能会出现各种不同的场景。
我可以从我新构建的提升模型中预期到什么样的影响?
说实话,这真的取决于许多不同的变量,从你的使用场景开始:你为什么要首先构建一个提升模型?例如,你是否尝试通过只接触 80%的客户来优化资源,因为预算有限?还是你打算通过多重治疗模型来个性化你的方法?
另一个关键点是了解你的用户——你是专注于留住高度参与的客户,还是你有很多不活跃的用户和失去的机会?

即使不涉及这些具体细节,我们通常可以将潜在影响分为两大类——正如你在上面的精美图示中所看到的:
优化模型: 提升模型可以帮助你通过识别哪些用户最有可能对你的干预产生积极反应来优化资源分配。例如,你可能通过只接触 50%的用户就能实现 80%的总积极结果。虽然这种方法不一定总是优于接触所有人,但它可以显著降低成本,同时保持较高的影响力。其主要优势在于效率:用更少的资源实现几乎相同的结果。
高影响力模型: 这种类型的模型能够让你实现比接触所有人更大的总体影响。它通过识别哪些人会做出积极反应以及哪些人可能会对你的干预产生负面反应来实现这一点。这在用户群体多样化或个性化方法至关重要的场景中尤为有价值。
你的提升模型的有效性最终将取决于几个关键因素,包括客户的特点、数据的质量、实施策略和所选择的模型。
但是,在我们深入探讨细节之前,先简要了解如何实现你的第一个提升模型。
提升建模实践:如何实现?
你可能会问:如果提升建模如此强大,为什么我以前从未听说过它?答案很简单:它的设置非常复杂。它需要深入的数据科学知识、设计和运行实验的能力,以及因果机器学习的专业知识。虽然我们将在下一篇文章中更深入地探讨技术细节,但让我们概述一下创建、扩展和整合你的第一个提升模型的主要步骤:
第一步: 定义目标并设计实验。首先,明确你的目标和目标受众。例如,你可能希望减少高端订阅用户的流失。然后,设计一个 A/B 测试(或随机对照试验)来测试你想尝试的所有操作。这可能包括:
发送个性化邮件
拨打客户电话
提供折扣
这一步可能需要一些时间,具体取决于客户数量,但它将成为你第一个模型的基础。
第二步:构建提升模型。 接下来,利用实验数据来构建提升模型。有趣的是,实验的实际结果在这里并不那么重要——关键是不同客户对不同行动的反应数据。这些数据帮助我们理解不同行动对客户的潜在影响。
第三步: 根据模型实施行动。 有了提升模型,你现在可以为客户实施具体的行动。该模型将帮助你决定哪些行动最有可能对每个客户产生有效影响,从而进行个性化干预。
第四步:监控并评估表现。 为了检查你的模型是否有效,持续跟踪这些操作随时间的表现。你可以通过将模型应用于一组客户,并将其与随机选择的另一组客户进行对比,来在实际情况下测试模型的效果。这个持续的评估过程帮助你优化方法,并确保你能获得期望的结果。
第五步: 扩展并完善。 为了使解决方案在更大规模上有效,最好定期更新模型。将一些客户保留下来,帮助训练下一版本的模型,并使用另一组客户来评估当前模型的效果。这种方法使你能够:
持续改进你的模型
适应客户行为的变化
随着时间的推移,识别新的有效行动
请记住,虽然概念简单,实施却需要专业知识。提升建模是一种迭代方法,随着时间的推移不断改进,因此耐心和持续改进是成功的关键。
结论
提升建模彻底改变了企业处理客户互动和营销的方式。这项技术使公司能够:
通过正确的行动锁定正确的客户
避免打扰那些可能不希望被打扰的客户
大规模个性化干预
通过优化与客户的互动最大化投资回报率!
我们已经探讨了提升建模的基本原理、关键应用和实施步骤。虽然设置过程复杂,但它在改善客户关系、增加收入和优化资源方面的好处使其对任何企业都至关重要。
在我们的下一篇文章中,我们将深入探讨技术方面的内容,帮助数据科学家有效地实施这一技术。请继续关注我们,一起探索前沿的数据科学思想。
来源
除非另有说明,所有图片均由作者提供
[1] en.wikipedia.org/wiki/Uplift_modelling
[2] growthstage.marketing/improve-marketing-effectiveness-with-ml/
[3] forecast.global/insight/understanding-customer-behaviour-using-uplift-modelling/
从洞察到影响:每个数据科学家都需要的演讲技巧
如何构建、设计和呈现能够赢得利益相关者支持的数据演示文稿
·发表于 Towards Data Science ·阅读时长 7 分钟·2024 年 10 月 28 日
--
今天成为一名数据科学家不仅仅是一个技术角色。它已经发展成一个高度跨职能的工作,因为你需要向利益相关者解释你的数据洞察并推销你的想法,以推动实际的业务影响。因此,要成为一名成功的数据科学家,演讲能力是一项至关重要的技能——很多时候是使用传统的 PowerPoint 或 Google Slides。然而,这通常是入门级数据科学家和分析师所缺乏的技能,他们往往将精力集中在技术技能上。
在本文中,我将分享我在七年数据科学家生涯中,经过多次演讲后发现最有效的演示格式。我还将通过使用假数据制作的幻灯片示例来演示。希望这能帮助你提高演讲技巧并推动你的数据职业发展。
请注意,本文专注于演示步骤,假设你已经有了扎实的分析或模型 😃

图像由 DALL·E 创建
第一部分 — 背景
在开始你的演讲时,你应该从项目背景开始,以设定……
从本地到云端:估算开源 LLM 所需的 GPU 资源
估算部署最新开源 LLM 所需的 GPU 内存
·发布于 Towards Data Science ·阅读时间 4 分钟·2024 年 11 月 18 日
--

如果你像我一样,可能会对最新和最强大的开源大语言模型(LLM)感到兴奋——从像 Llama 3 这样的模型到更紧凑的 Phi-3 Mini。但在你跳进部署语言模型之前,有一个关键因素需要提前规划:GPU 内存。如果判断失误,你闪亮的新网页应用可能会崩溃、运行缓慢,或者产生高额的云计算费用。为了简化这一过程,我将为你解释什么是量化,并准备了一份2024 年 GPU 内存规划备忘单——这是市场上最新开源 LLM 的简明总结,以及部署前需要了解的事项。
如果你不是会员,请阅读此处。
为什么要估算 GPU 内存?
在部署 LLM 时,估算需要多少 GPU 内存是有风险的。内存过少,模型可能崩溃;内存过多,则是在浪费资源,导致不必要的开销。
预先了解这些内存需求,就像在公路旅行之前知道你能将多少行李塞进车里——这可以避免头疼问题并提高效率。
从遮蔽图像建模到自回归图像建模
图像基础模型预训练目标的简要回顾
·发表于Towards Data Science ·阅读时间 6 分钟·2024 年 6 月 10 日
--
我们渴望大型模型,不是吗?
GPT 系列已经证明了它在革命化自然语言处理(NLP)领域的能力,大家都期待着在计算机视觉领域看到同样的转变。近年来,最流行的图像基础模型包括SegmentAnything、DINOv2等。一个自然的问题是,这些基础模型的预训练阶段有何关键区别?
本文将不会直接回答这个问题,而是通过回顾使用遮蔽图像建模的图像基础模型预训练目标来间接讨论。我们还将讨论一篇将在 ICML’24 上发表的论文,应用自回归建模到基础模型的预训练中。

图片来源:pxhere.com/en/photo/1025277
LLM 中的模型预训练是什么?
模型预训练是一个术语,用于描述在没有标签的情况下对模型进行训练的阶段,这一阶段模型完全通过自监督的方式进行训练,通常用于大模型(LLM、图像基础模型)。
常见的预训练技术大多起源于 LLM。例如,BERT 模型使用了遮蔽语言建模,启发了像 BEiT、MAE-ViT 和 SimMM 这样的遮蔽图像建模方法。GPT 系列使用了自回归语言建模,最近一篇被接受的 ICML 论文将这一思想扩展到自回归图像建模。
那么,什么是遮蔽语言建模和自回归语言建模呢?
Masked Language Modeling最早是在 2018 年 BERT 论文中提出的。该方法被描述为“简单地随机遮掩输入 token 的一部分,然后预测这些被遮掩的 token。”这是一种双向表示方法,因为模型会尝试在被遮掩的 token 上前后预测。

Masked LM 预训练。图像来源:arxiv.org/abs/1810.04805
自回归语言建模在 GPT3 论文中广为人知。它在 XLNet 论文中有更清晰的定义,我们可以看到该模型是单向的。GPT 系列使用单向语言模型的原因是其架构基于解码器,仅需要在提示和完成部分使用自注意力:

AR — 自回归。来源:arxiv.org/abs/1906.08237
图像领域的预训练
当进入图像领域时,第一个问题是我们如何形成图像“token 序列”。自然的想法是使用 ViT 架构,将图像分解成一个图像补丁(视觉 token)的网格。
BEiT. 该论文于 2022 年作为 arXiv 预印本发布,BEiT 的思路很简单。将图像切分成 14*14 的视觉 token 序列后,随机遮掩 40%的 token,用可学习的嵌入代替,并输入到 transformer 中。预训练的目标是最大化正确视觉 token 的对数似然,这一阶段无需解码器。流程如图所示。

BEiT 预训练流程。图像来源:arxiv.org/abs/2106.08254
在原始论文中,作者还提供了 BEiT 与变分自编码器之间的理论联系。所以自然的问题是,自编码器能否用于预训练目的?
MAE-ViT. 本文通过设计一个遮掩自编码器架构回答了上面的问题。采用相同的 ViT 公式和随机遮掩,作者提出在训练过程中“丢弃”被遮掩的补丁,并仅使用未遮掩的补丁作为输入传入编码器。遮掩的 token 将在预训练阶段的解码过程中用于重建。解码器可以是灵活的,范围从 1 到 12 个 transformer 块,维度在 128 到 1024 之间。更多详细的架构信息可在原始论文中找到。

Masked Autoencoder 架构。图像来源:arxiv.org/abs/2111.06377
SimMIM。与 BEiT 和 MAE-ViT 略有不同,论文提出使用灵活的主干网络,如 Swin 变换器,用于编码目的。所提出的预测头非常轻量——仅使用一个线性层的 2 层 MLP 回归掩蔽像素。

SimMIM 流程。图片来源:arxiv.org/abs/2111.09886
AIM。一篇最近被 ICML’24 接受的论文提出了使用自回归模型(或因果模型)进行预训练的方法。与使用掩蔽序列不同,该模型将完整序列输入到因果变换器中,使用带有因果掩蔽的前缀自注意力。

AIM 与因果变换器。图片来源:arxiv.org/abs/2401.08541
什么是前缀因果注意力?在Kaggle上有关于因果注意力掩蔽的详细教程,另外,这里也有相关内容,它是对自注意力中的“未来”标记进行掩蔽。然而,在这篇论文中,作者指出因果掩蔽和下游双向自注意力之间的差异可能会导致性能问题。解决方案是使用部分因果掩蔽或前缀因果注意力。在前缀序列中,使用双向自注意力,而对其余序列应用因果注意力。

预训练期间的因果注意力。图片来源:arxiv.org/abs/2401.08541
自回归图像掩蔽的优势是什么?答案在于模型和数据规模的扩展。论文声称,模型规模与预训练损失和下游任务性能直接相关(见左侧子图)。未经筛选的预训练数据规模也与下游任务性能直接相关(见右侧子图)。

AIM 的规模效应。图片来源:arxiv.org/abs/2401.08541
与 50%的掩蔽比率相比,AIM 在掩蔽图像建模上取得了惊人的 8%的性能提升。

那么,最大的收获是什么?AIM 论文讨论了不同的预训练方法之间的权衡,我们在这里不会重复讨论。一个更浅显但更直观的教训是,利用现有的 LLM 领域经验来改进视觉基础模型,尤其是在可扩展性方面,仍然可能有很多工作要做。希望我们能在未来几年看到这些改进。
参考文献
El-Nouby 等人,大规模自回归图像模型的可扩展预训练。ICML 2024. Github:
github.com/apple/ml-aimXie 等人,SimMIM:一种简单的 Masked 图像建模框架。CVPR 2022. Github:
github.com/microsoft/SimMIMBao 等人,BEiT:BERT 预训练图像 Transformer。arXiv 预印本 2022. Github:
github.com/microsoft/unilm/tree/master/beitHe 等人,Masked autoencoders 是可扩展的视觉学习器。CVPR 2022. HuggingFace 官方:
huggingface.co/docs/transformers/en/model_doc/vit_maeCaron 等人,自监督视觉 Transformer 中的涌现特性。ICCV 2021. Github:
github.com/facebookresearch/dino?tab=readme-ov-fileLiu 等人,Swin Transformer:使用平移窗口的分层视觉 Transformer。ICCV 2021. Github:
github.com/microsoft/Swin-TransformerBrown 等人,语言模型是少样本学习者。NeurIPS 2020. Github:
github.com/openai/gpt-3Yang 等人,Xlnet:用于语言理解的广义自回归预训练。NeurIPS 2019. Github:
github.com/zihangdai/xlnetDevlin 等人,BERT:用于语言理解的深度双向 Transformer 预训练。arXiv 预印本 2018。HuggingFace 官方:
huggingface.co/docs/transformers/en/model_doc/bert
从 MOCO v1 到 v3:构建自监督学习动态字典的探索 — 第一部分
对动量对比学习框架的温和回顾
·发表于Towards Data Science ·7 分钟阅读·2024 年 7 月 4 日
--
我们是否已经进入自监督学习的时代?
数据每天都在流动。人们在全天候工作。工作分布到世界的每个角落。但仍然有大量数据未被注释,等待着可能通过新模型、新训练或新升级来使用。
或者,它将永远不会发生。当世界以监督方式运行时,它将永远不会发生。
近年来,自监督学习的兴起揭示了一个新的方向。自监督学习并不是为所有任务创建注释,而是将任务分为前置/预训练(请参见我之前关于预训练的文章这里)任务和下游任务。前置任务专注于从整个数据集中提取具有代表性的特征,而不依赖任何真实标签的指导。然而,这个任务仍然需要从数据集中自动生成标签,通常通过广泛的数据增强。因此,本文中我们将无监督学习(数据集未标注)和自监督学习(任务由自生成标签进行监督)交替使用。
对比学习是自监督学习的一个主要类别。它使用未标注的数据集和对比信息编码的损失(例如,对比损失、InfoNCE 损失、三元组损失等)来训练深度学习网络。主要的对比学习方法包括 SimCLR、SimSiam 和 MOCO 系列。
MOCO——这个词是“动量对比”(momentum contrast)的缩写。其核心思想写在第一篇 MOCO 论文中,提出了对计算机视觉自监督学习问题的理解,如下所示:
“[摘自 原始论文] 计算机视觉则进一步关注字典构建,因为原始信号位于一个连续的高维空间中,并且并非为了人类通信而构建……虽然有多种动机推动,这些(注:近期的视觉表示学习)方法可以被看作是在构建动态字典…… 无监督学习训练编码器以执行字典查找:一个编码后的‘查询’应该与其匹配的键相似,并与其他键不同*。学习被表述为最小化对比损失。”*
在本文中,我们将温和地回顾 MOCO v1 到 v3:
v1——论文《用于无监督视觉表示学习的动量对比》发表于 CVPR 2020。该论文提出了对关键 ResNet 编码器进行动量更新,使用样本队列和 InfoNCE 损失。
v2——紧接着发布的论文《通过动量对比学习改进基线》实现了两个 SimCLR 架构的改进:a)用一个两层的 MLP 替换全连接层;b)通过加入模糊扩展原始数据增强。
v3——论文《自监督视觉变换器训练的实证研究》发表于 ICCV 2021。该框架将关键查询对扩展为两个关键查询对,用于形成类似 SimSiam 的对称对比损失。骨干网络也从仅 ResNet 扩展为 ResNet 和 ViT。

图像来源:pxhere.com/en/photo/760197
MOCO V1
该框架从一个核心的自监督学习概念开始:查询和键。这里,查询指的是查询图像或图像块的表示向量(xquery),而键则指的是样本图像/图像块字典的表示向量({x_0key, x_1^key, …})。查询向量 q 是由一个可训练的“主”编码器生成的,并通过常规的梯度反向传播进行训练。键向量则存储在字典队列中,由一个可训练的编码器生成,该编码器不会直接进行梯度反向传播,而是仅通过使用主编码器的权重以动量方式更新权重。下面是更新方式:

实例判别任务与 InfoNCE 损失
由于数据集在前期/预训练阶段没有标签,因此需要一个特定任务。论文采用了在这篇 CVPR 2018 论文中提出的实例区分任务。与原始设计不同,原始设计是使用非参数分类器计算内存库中特征向量之间的相似性,而 MOCO 论文使用了正向+<查询, 键>对和负向-<查询, 键>对来监督学习过程。当查询和键图像是从同一图像增强出来时,视为正样本。否则,视为负样本。训练损失是InfoNCE 损失,可以看作是查询/键对的 softmax 负对数:

动量更新
作者声称,将主查询编码器复制到键编码器可能导致较差的结果,因为快速变化的编码器会减少键表示字典的一致性。相反,每一步只训练主查询编码器,但键编码器的权重通过动量权重 m 进行更新:

在训练过程中,动量权重保持较大,例如 0.999 而不是 0.9,这验证了作者的猜想,即关键编码器的一致性和稳定性会影响对比学习的表现。
伪代码
少于 20 行的伪代码是整个训练过程的快速概述。与上述的 InfoLoss 一致,值得注意的是,正对数值是每个样本的单一尺度,负对数值是每个样本对应 K 个键的 K 维向量。

伪代码来源:arxiv.org/pdf/1911.05722
MOCO V2
在 MOCO 发布后立即发布的 2 页 v2 论文,通过采纳 SimCLR 中的两项成功架构变更,对版本 1 进行了轻微修改:
将 ResNet 编码器的全连接层替换为 2 层 MLP
扩展原始增强集,加入模糊增强
有趣的是,即使只有一个简单的架构调整,性能提升似乎也非常显著:

MOCO V3
版本 3 通过采用对称对比损失、额外的投影头和 ViT 编码器,提出了重大改进。
对称对比损失
受到 SimSiam 工作的启发,该工作通过获取两个随机增强的视图,并在负余弦相似度计算中交换它们,从而获得对称损失,MOCO v3 对样本进行了两次增强。它将样本分别输入查询编码器和键编码器。

图像来源:arxiv.org/pdf/1911.05722 和 arxiv.org/abs/2104.02057
对称对比损失基于一个简单假设——所有正样本对都位于 NN 查询-关键矩阵的对角线位置,因为它们是同一图像的增强版本;所有负样本对则位于 NN 查询-关键矩阵的其他位置,因为它们是来自不同样本的增强(可能是相同的增强):

从这个意义上讲,动态关键字典要简单得多,因为它是在小批量中动态计算的,不需要保持一个记忆队列。可以通过下面对批次大小的稳定性分析来验证这一点(注意,作者解释说 6144 批次的性能下降是由于 训练期间的部分失败现象):

ViT 编码器
ViT 编码器带来的性能提升如下所示:

比较与总结
MOCO v3 论文给出了使用 ResNet50 (R50) 编码器的 v1-v3 性能比较:

总结来说,MOCO v1-v3 对以下元素进行了明确的转变:
编码器:ResNet → ResNet + MLP 层 → ResNet/ViT + MLP 层
关键字典:全局关键向量队列 → 基于数据增强的小批量关键字
对比损失:非对称对比损失 → 对称对比损失
但这还不是全部。在下一篇文章中,我将深入探讨 MOCO v3 代码,以实现数据增强和动量更新。敬请期待!
参考文献
Wu 等人,通过非参数实例区分进行无监督特征学习。CVPR 2018。github:
github.com/zhirongw/lemniscate.pytorchOord 等人,对比预测编码的表示学习。arXiv 预印本 2018。
Chen 等人,视觉表示的对比学习简单框架。PMLR 2020。github:
github.com/sthalles/SimCLRHe 等人,动量对比用于无监督视觉表示学习。CVPR 2020。
Chen 等人,改进的基准通过动量对比学习。arXiv 预印本 2020。
陈等人,《探索简单的孪生表示学习》。CVPR 2021. github:
github.com/facebookresearch/simsiam陈等人,《自监督视觉变换器训练的实证研究》。ICCV 2021. github:
github.com/facebookresearch/moco-v3
从牛顿到大型语言模型
一种新的人工智能推理优化方法
·发表于Towards Data Science ·14 分钟阅读·2024 年 10 月 9 日
--

图像由作者使用 DALL-E 生成
介绍
回答多跳问答(QA)问题需要多个事实,这对于大型语言模型(LLMs)中的复杂推理和解释至关重要。QA 量化并客观地测试智能系统的推理能力。由于它们具有明确的正确解决方案,QA 任务可以减少评估中的主观性和人为偏见。QA 功能能够评估演绎推理、归纳推理和溯因推理,后者涉及从部分知识中推导出最可能的答案。
我们在改善模型推理过程时面临几个挑战。最重要的需求之一是模型的可解释性和可解释性。大型人工智能模型,尤其是深度神经网络,难以理解,这使得准确评估它们并为它们的决策和结论提供易于理解的解释变得困难。另一个改善推理过程的重要目标是确保推理过程对于输入或上下文的细微变化具有鲁棒性,同时开发能够跨不同领域和问题类型泛化推理能力的模型。
人工智能中的物理类比的力量
从单模态到多模态:构建基础模型的 DIY 技术
一篇全面的教程:使用像提示适配和适配器这样的高级技术,将开源的单模态模型转化为多模态模型,包括所有 LLaMA-Adapters、LLaVa、MiniGPT-4 等变种。
·发布于 Towards Data Science ·15 分钟阅读·2024 年 6 月 25 日
--
作者: Elahe Aghapour,Salar Rahili
引言
随着大型语言模型(LLMs)的最新进展,人工智能已成为技术的焦点。我们比以往任何时候都更渴望达到 AGI 级别的智能。然而,想要实现对周围环境的类人理解,涉及的远不止是掌握语言和文本理解。人类通过五感与世界互动,并基于这些互动采取行动以实现目标。这突显了我们的下一个步骤是开发能够融合多模态输入和输出的大型模型,从而使我们更接近类人能力。然而,我们面临着两个主要障碍。首先,我们需要一个多模态标注数据集,而这并不像文本数据那样易于获取。其次,我们已经在推升训练文本数据模型的计算能力上到达极限。将这一计算能力扩展到包括其他模态,尤其是高维度的图像和视频,是一个极具挑战性的任务。
这些限制一直是许多 AI 研究人员创建强大多模态模型的障碍。到目前为止,只有少数几家知名公司,如 Google、Meta 和 OpenAI,成功地训练了此类模型。然而,这些知名模型都不是开源的,仅提供少数 API 供公众使用。这迫使研究人员,尤其是学术界的研究人员,寻找在没有强大计算能力的情况下构建多模态模型的方法,而是依赖于开源的预训练模型,这些模型大多是单模态的。
在这篇博客中,我们重点讨论了成功且低成本的多模态模型创建方法。我们的标准集中在那些计算成本保持在几千美元以内的项目,假设这是一个典型实验室可以承担的预算。
1- 参数高效微调(PEFT)
在我们深入探讨整合和对齐两个预训练模型的方法之前,我们需要讨论如何在有限计算能力下微调大模型的机制。因此,我们将从探索参数高效微调(PEFT)开始,然后描述这些方法如何进一步用于对齐预训练模型并构建开源的多模态模型。

图 1. 不同的 PEFT 方法(图片来自论文)。
随着模型规模的持续增长,对于高效微调方法的需求变得更加关键。由于大规模预训练模型微调所有参数通常不切实际,所需的计算资源和时间非常庞大,参数高效微调(PEFT)通过冻结模型的参数,仅训练注入的具有少量参数的模块,解决了这一挑战。因此,只存储一个大规模 Transformer 副本,并搭载经过学习的特定任务轻量级 PEFT 模块,从而为每个附加任务带来非常小的开销。这种方法不仅减少了资源需求,还加速了模型向新任务的适应,使其成为在模型不断扩展的时代中一种切实可行且有效的策略。PEFT 方法在大语言模型(LLMs)和大型视觉模型中非常常见,主要可以分为三类,如图 1 所示:在已提出的几种方法中,有三种受到了社区的显著关注。
1- 适配器:适配器本质上是一个小模块,通常包括一个下采样层、非线性层和一个带有跳跃连接的上采样层,以保留原始输入。该模块被插入到预训练模型中,在微调过程中仅训练适配器。
2- LoRA 通过将可训练的低秩分解矩阵注入模型来近似权重更新,从而显著减少下游任务的可训练参数数量。对于一个维度为 d×k 的预训练权重矩阵 W,LoRA 通过低秩分解表示其更新:W+ΔW=W+DU
其中,D 的维度是 d×r,U 的维度是 r×k。这些矩阵 D 和 U 是可调参数。LoRA 可以应用于注意力矩阵和/或前馈模块,以实现高效的微调。
3- P*-调优(前缀调优,提示调优)通常将一组可学习的前缀向量或令牌附加到输入嵌入中,并且在下游任务的微调过程中,仅训练这些所谓的“软提示”。这种方法背后的理念是,通过一系列额外的“虚拟令牌”信息的指导,帮助预训练模型理解下游任务。软提示是与词汇表中的实际令牌无关的向量序列。相反,它们充当中介表示,指导模型的行为以完成特定任务,尽管它们与任务本身没有直接的语言学联系。
评估 PEFT 技术:优点与局限:
适配器增加少量参数(占总参数的 3-4%),使其比完全微调更高效,但低于提示调优或 LoRA。然而,由于额外的神经网络层,它们能够有效地捕捉复杂的任务特定信息,并通过学习详细的任务特征,通常能够在特定任务上取得高性能。缺点是,这种方法使模型更深,这可能会使优化过程变得复杂,并导致更长的训练时间。
LoRA仅添加少量参数(0.1%至 3%),使其在处理非常大的模型时具有高度的效率和可扩展性,适用于调整最先进的 LLM 和 VLM。然而,LoRA 的适应性受到低秩结构的限制。尽管 LoRA 高效,但与适配器相比,它在捕捉某些类型的任务特定信息时可能灵活性较差。
P*-调优在参数效率方面极为高效(通常只需要不到 0.1%),因为它仅需要学习额外的提示令牌,而保持原始模型参数不变,从而保留模型的泛化能力。然而,它可能无法像其他方法那样有效地捕捉复杂的任务特定信息。
到目前为止,我们已经回顾了用最少计算资源微调大模型的新方法。这种能力为我们提供了将两个拥有数十亿参数的大模型结合起来的机会,并且只需微调几百万个参数就能使它们协调工作。这种对齐使得一个或两个模型能够生成对另一个模型可理解的嵌入。接下来,我们将讨论三种主要方法,展示了成功实现这种训练模式的例子。
2.1 提示适应:

图 2. 视觉提示的早期融合与适应提示的后期融合(图像来自论文)
LLaMA-Adapter提出了一种轻量级的适应性方法,可以有效地将LLaMA模型微调为一个跟随指令的模型。这是通过冻结预训练的LLaMA 7B 模型,并在最上层的变换器层中引入一组可学习的适应性提示(1.2M 参数)实现的。为了避免由于随机初始化提示而导致的初始不稳定性和效果问题,适应性提示被零初始化。此外,还引入了一个可学习的零初始化门控因子,用于自适应地控制适应性提示的重要性。
此外,LLaMA-Adapter通过使用预训练的视觉编码器(如 CLIP)集成视觉信息,扩展到多模态任务。在给定图像作为视觉上下文的情况下,通过多尺度特征聚合获取全局视觉特征,然后通过一个可学习的投影网络将这些特征投影到 LLM 的适应性提示的维度。最终得到的整体图像标记会被重复 K 次,并在所有插入的 K 长度适应性提示的变换器层中逐元素相加。使用LLaMA-Adapter进行微调时,在 8 个 A100 GPU 上耗时不到一小时。在RobustGER中也使用了类似的方法,LLM 被微调用于执行自动语音识别中的生成错误修正(GER)去噪。这个过程在单个 NVIDIA A40 GPU 上训练需要 1.5 到 4.5 小时。
LLaMA-Adapter V2专注于能够在开放式视觉指令上良好泛化的指令跟随视觉模型。为实现这一目标,相较于原始的LLaMA-Adapter,提出了三项关键改进。首先,通过解冻LLaMA中的所有归一化层,并为转换器中的所有线性层添加可学习的偏置和缩放因子,引入了更多可学习参数(14M),这将指令跟随能力分布到整个模型。其次,视觉标记被输入到语言模型的早期层,而适应性提示被添加到顶部层。这在不破坏模型指令跟随能力的情况下改善了视觉知识的整合(见图 2)。第三,采用了图像-文本标注数据和仅语言指令数据的联合训练范式。视觉投影层通过图像-文本标注数据进行训练,而晚期适应性提示和解冻的归一化层则通过指令跟随数据进行训练。此外,在推理过程中集成了专家模型,如标注和 OCR 系统,增强了图像理解而无需额外的训练成本。我们无法找到有关 GPU 要求和训练所需时间的具体细节。然而,根据 GitHub 上的信息,在单个 A100 GPU 上大约需要 100 小时。
2.2 中间模块训练:

图 3. 中间模块训练工作原理概览(图片来自作者)
为了创建一个多模态模型,可以通过一个可学习的投影模块将两个或更多单模态基础模型连接起来。该模块将一个模态的特征映射到另一个模态,从而实现不同数据类型的集成。例如,可以通过投影模块将视觉编码器连接到大型语言模型(LLM)。因此,如图 3 所示,LLM 的输入由一系列投影的图像特征和文本组成。训练过程通常包括两个阶段:
预训练:投影模块在一大规模配对数据集上进行预训练,以实现跨模态对齐。
微调:投影模块(具有一个或多个单模态模型)针对特定的下游任务进行微调,例如指令跟随任务。
MiniGPT-4将冻结的视觉编码器ViT-G/14与冻结的 LLM Vicuna对齐,使用一个投影层。对于视觉编码器,采用与BLIP-2相同的预训练视觉感知组件,该组件包括ViT-G/14和 Q-former 网络。MiniGPT-4添加了一个可学习的投影层,其输出被视为 LLM 的软提示,格式如下:
“###Human:
训练投影层涉及两个阶段。首先,在一个大型的图像-文本配对数据集上对投影层进行预训练,以获得视觉-语言知识。然后,使用一个较小且高质量的数据集对线性投影层进行微调。在这两个阶段中,所有其他参数保持冻结。因此,MiniGPT-4能够生成更自然、更可靠的语言输出。MiniGPT-4的训练需要大约在 4 个 A100 GPU 上训练 10 小时。
使用机器生成的指令跟随数据对 LLM 进行微调,已被证明能够提高其在新任务上的零-shot 能力。为了在多模态领域探索这一思路,LLaVA将 LLM Vicuna与视觉编码器ViT-L/14连接,通过一个单独的线性层来处理视觉-语言指令跟随任务。在第一阶段,投影层在一个大型图像-文本配对数据集上进行训练,同时冻结视觉编码器和 LLM 的权重。此阶段创建了一个兼容冻结 LLM 的视觉标记器。在第二阶段,预训练的投影层和 LLM 权重使用一个高质量生成的数据集进行微调,这个数据集包含语言-图像指令跟随数据。此阶段增强了模型处理多模态指令并执行特定任务的能力。LLaVA使用了 8× A100 GPU。预训练需要 4 小时,微调需要 4 到 8 小时,具体取决于任务数据集。它在视觉推理能力上展示了值得称赞的熟练度,尽管在需要简短答案的学术基准测试中表现不足。
添加了一个两层的 MLP 来将 LLM 与视觉编码器连接起来。
输入图像分辨率通过使用 CLIP-ViT-L-336px 被放大,从而提高了细节感知能力。
Vicuna 模型已经扩展到 13B 参数。
添加了面向学术任务的 VQA 数据,并包含了特定的响应格式提示,用于指示期望的输出格式。当提示简短的答案时,VQA 问题会附加上提示语:“用一个单词或短语回答问题。”
训练过程大约在 1 天内完成,使用单个 8-A100 GPU,并在多个基准测试中取得了最先进的结果。

图 4. 视频-ChatGPT 的架构(图片来自 论文)
Video-ChatGPT 旨在创建基于视频的对话代理。鉴于视频-字幕对的稀缺性以及从零开始训练所需的巨大资源,它采用了预训练的基于图像的视觉编码器 CLIP ViT-L/14 来处理视频任务,并通过可学习的线性投影模型将其与预训练的 LLM Vicuna 连接。由于 ViT-L/14 编码的是图像,因此对于一个包含 T 帧的视频样本,它会生成 T 个帧级别的嵌入,维度为 hwD。正如图 4 所示,获得视频级特征的过程包括两个关键步骤:
空间视频特征:通过在时间维度上进行平均池化帧级特征,获得 hwD 维度的特征。
时间视频特征:通过在空间维度上进行平均池化,获得 TD* 维度的特征。
这些时间和空间特征被连接在一起,形成视频级特征,然后通过一个可学习的线性层投影到文本嵌入空间。该模型在经过精心策划的高质量视频-文本对数据集上训练,线性投影层的训练大约需要 3 小时,使用 8 个 A100 40GB GPU。该方法使得 Video-ChatGPT 能够生成关于视频内容的详细且连贯的对话。

图 5. PandaGPT 的示意图(图片来自 论文)
PandaGPT虽然没有连接单模态模型,但通过将预训练的大型语言模型Vicuna与多模态编码器ImageBind通过线性投影层集成,首次推出了能够进行指令跟随的通用模型(见图 5)。线性投影层是经过训练的,Vicuna的注意力模块使用 LoRA 在 8×A100 40G GPUs 上进行微调,训练时间为 7 小时,仅使用图像-语言(多轮对话)指令跟随数据。尽管只在图像-文本对上进行训练,PandaGPT通过利用来自冻结的ImageBind编码器的六种模态(图像/视频、文本、音频、深度、热图和 IMU)之间的绑定特性,表现出了新兴的零样本跨模态能力。这使得 PandaGPT 在图像/视频基础的问答、图像/视频启发的创意写作、视觉和听觉推理等任务中表现出色。
2.3 适配器混合:

图 6. 多模态适配适应(MMA)概述(图片由作者提供)
Cheap&Quick采用轻量级适配器来集成大型语言模型(LLM)和视觉模型,处理视觉-语言任务。该论文提出了一种多模态适配器混合(MMA),旨在促进单模态和多模态指令之间的切换,同时不影响性能。提出了一种可学习的标记t作为模态选择器标记。这个标记指示输入特征的模态(即单模态或多模态输入),并向路由模块提供如何结合学习到的适配器输出的信息,如图 6 所示。适配器的公式为:
Z′=Z+s⋅router(f(Z),g(Z); t)
其中 Z 是输入特征,可以是单模态或连接的多模态(图像-文本)特征。模块f和g共享一个通用的单模态适配器架构。s 是一个缩放因子,路由器(⋅)函数根据模态标记t确定路由路径。
为了验证 MMA 的有效性,作者通过一个线性层将LLaMA和CLIP-ViT连接,并将 MMA 插入到 ViT 和LLaMA的多头注意力模块之前。适配器和投影层(仅 3.8M 参数)使用文本-图像数据和纯文本数据的混合,在 8 个 A100 GPUs 上训练了 1.4 小时。该方法显著降低了训练成本,同时在视觉-语言任务中保持了高性能。
2.4 作为基础的模态,无需训练 到目前为止,我们讨论了将单一模态模型连接起来创建多模态模型的论文。然而,迈向通用人工智能(AGI)需要一个能够处理来自不同模态数据的多模态模型,以应对各种任务,从微积分到基于描述生成图像。
最近,许多论文探讨了通过语言提示将预训练的多模态模型集成到一个统一模型中的方法,该模型能够在不进行额外训练的情况下处理跨不同模态的各种任务。在这种方法中,语言作为模型之间交换信息的中介。通过提示工程(例如,Visual ChatGPT,MM-REACT)或微调(例如,Toolformer,GPT4Tools),大型语言模型(LLMs)可以调用专门的基础模型来处理各种模态特定的任务。虽然这个话题超出了我们当前博客文章的范围,但你可以参考这些论文以获取更详细的信息。
在另一项类似的工作中,MAGIC提出了一种新颖的、无需训练的即插即用框架,并通过预训练的 CLIP 将图像嵌入作为基础来进行支撑。该框架将 GPT-2 与 CLIP 连接起来,以零样本方式执行基于图像的文本生成(例如,图像标题生成)。通过将 CLIP 图像嵌入与预训练的大型语言模型(LLM)在每个时间步生成的 top-k 词汇的相似度融入解码推理中,模型能够有效地利用视觉信息来引导文本生成。无需任何额外的训练,这种方法展示了在给定图像和文本提示的情况下生成视觉支撑故事的能力。
3. 高质量策划数据:
我们到目前为止讨论了对齐不同模态的各种方法;然而,重要的是要记住,在这种训练机制中拥有精心策划的高质量数据同样至关重要。例如,详细且准确的指令和回应能够显著提升大型语言模型在互动自然语言任务中的零样本表现。在互动视觉语言任务领域,高质量数据的可用性往往有限,这促使研究人员开发创新的方法来生成此类数据。
MiniGPT-4提出了一种两阶段的方法,用于策划一个详细的、能够遵循指令的图像描述数据集:
1-数据生成:第一阶段训练中的预训练模型用于生成详细的描述。对于给定的图像,使用精心设计的提示词,使预训练模型能够分步生成详细且富有信息的图像描述。
2- 后处理与过滤:生成的图像描述可能包含噪声或不连贯的描述。为了修正这些问题,采用 ChatGPT 根据特定的后处理要求和标准来优化生成的描述,并通过设计的提示词进行引导。然后,经过优化的数据集将手动验证,以确保图像-文本对的正确性和质量。
LLaVA提出了一种通过基于广泛可用的图像-文本配对数据查询 ChatGPT/GPT-4 生成多模态指令跟随数据的方法。他们设计了一种提示词,包含图像字幕、定位场景中物体的边界框以及几个用于上下文学习的示例。这种方法利用了现有的丰富数据集和 ChatGPT/GPT-4 的能力,生成高度详细和准确的多模态数据。
Video-ChatGPT采用了两种方法来生成高质量的视频指令数据。
1- 人工辅助注释:专家注释员通过向字幕添加全面的细节来丰富给定的视频-字幕对。
2- 半自动注释:这一过程涉及利用多个预训练模型的多步骤过程:
预训练的BLIP-2和GRiT模型用于分析视频中的关键帧。BLIP-2生成帧级别的字幕,而GRiT提供场景对象的详细描述。此外,预训练的Tag2Text模型为视频的每个关键帧生成标签。
GPT-3.5 模型用于合并过滤后的字幕并生成一个统一且连贯的视频级字幕。
MIMIC-IT生成了一个包含 280 万多模态指令-响应对的数据集,旨在增强视觉语言模型(VLMs)在感知、推理和规划方面的能力。为了展示高质量数据的重要性,他们使用 MIMIC-IT 数据集,在一天内使用 8 个 A100 GPU,经过 3 个周期的微调OpenFlamingo。最终模型优于OpenFlamingo,展现了更强的上下文学习和零-shot 学习能力。
本文所表达的观点仅代表我们个人意见,并不反映我们雇主的立场。
参考文献:
[1] LLaMA-Adapter: Zhang, Renrui, 等. “Llama-adapter: Efficient fine-tuning of language models with zero-init attention.” (2023).
[2] LLaMA-Adapter V2: Gao, Peng, 等. “Llama-adapter v2: Parameter-efficient visual instruction model.” (2023).
[3] MiniGPT-4: Zhu, Deyao, 等. “Minigpt-4: Enhancing vision-language understanding with advanced large language models.” (2023).
[4] LLaVA: Liu, Haotian, 等. “Visual instruction tuning.” (2024).
[5] LLaVa-1.5: Liu, Haotian, 等. “Improved baselines with visual instruction tuning.” (2024).
[6] Video-ChatGPT: Maaz, Muhammad, 等. “Video-chatgpt: Towards detailed video understanding via large vision and language models.” (2023).
[7] PandaGPT: Su, Yixuan, 等. “Pandagpt: One model to instruction-follow them all.” (2023).
[8] Cheap&Quick: Luo, Gen, 等. “Cheap and quick: Efficient vision-language instruction tuning for large language models.” (2024).
[9] RobustGER: Hu, Yuchen, 等. “Large Language Models are Efficient Learners of Noise-Robust Speech Recognition.” (2024).
[10] MAGIC: Su, Yixuan, 等. “Language models can see: Plugging visual controls in text generation.” (2022).
[11] Visual ChatGPT: Wu, Chenfei, 等. “Visual chatgpt: Talking, drawing and editing with visual foundation models.” (2023).
[12] MM-REACT: Yang, Zhengyuan, 等. “Mm-react: Prompting chatgpt for multimodal reasoning and action.” (2023).
[13] Toolformer: Schick, Timo, 等. “Toolformer: Language models can teach themselves to use tools.” (2024).
[14] GPT4Tools: Yang, Rui, 等. “Gpt4tools: Teaching large language model to use tools via self-instruction.” (2024).
[15] MIMIC-IT: Li, Bo, 等. “Mimic-it: Multi-modal in-context instruction tuning.” (2023).
[16] He, Junxian, 等. “Towards a unified view of parameter-efficient transfer learning.” (2021).
从并行计算原理到 CPU 和 GPU 架构的编程
适用于早期的机器学习工程师和数据科学家,旨在帮助他们理解内存基础、并行执行,以及如何为 CPU 和 GPU 编写代码。
·发表于Towards Data Science ·19 分钟阅读·2024 年 11 月 12 日
--

照片由Olav Ahrens Røtne提供,来源于Unsplash
本文旨在解释并行计算的基础知识。我们从基础开始,包括理解共享架构与分布式架构以及这些系统之间的通信。我们将探讨 GPU 架构以及如何通过编写代码(使用 C++ Kokkos)将架构原理映射到代码实现。最后,我们将使用运行 Kokkos 代码所获得的运行时数据来测量性能指标(加速),这些数据来自 CPU 和 GPU 架构下的向量-矩阵乘法,这是机器学习领域中最常见的操作之一。
本文的核心主题是探索和解答问题。虽然这似乎是一次漫长的旅程,但它将是值得的。让我们开始吧!
系统架构基础
我理解并行计算可以通过同时执行多个操作来节省时间。但是我听说系统时间与人类时间或挂钟时间不同。它们有什么不同之处?
计算中的最小时间单位称为时钟周期。它代表执行一个操作所需的最短时间,例如获取数据、执行计算或进行通信。时钟周期从技术上讲是指执行指令所需的状态变化。状态可以是处理器状态、数据状态、内存状态或控制信号。在一个时钟周期内,可能执行完整的指令、部分指令或多个指令。
CPU 允许每秒进行有限数量的状态变化。例如,一个 3GHz 时钟频率的 CPU 每秒可以进行 30 亿次状态变化。时钟频率有上限,因为每个时钟周期都会产生热量,过高的频率会因为产生的热量损坏 CPU 芯片。
因此,我们希望通过使用并行计算方法来充分利用可用的计算能力。目的是隐藏内存延迟(从内存获取第一条数据所需的时间),提高内存带宽(每单位时间传输的数据量),并增强计算吞吐量(每个时钟周期内执行的任务数量)。
为了比较性能,例如在计算并行程序效率时,我们使用壁钟时间而非时钟周期,因为壁钟时间包括了所有的实际时间开销,如内存延迟和通信延迟,这些是无法直接转换为时钟周期的。
基本系统的架构是什么样的?
一个系统可以由单个处理器、一个节点甚至一个集群组成。系统的一些物理构件是——
节点 — 一个包含多个处理器芯片的物理计算单元。多个节点组合形成一个集群。
处理器芯片 — 芯片包含多个被称为核心的处理单元。
核心 — 每个核心都能够运行一个独立的线程。
在设定的术语中,一个节点可以与处理器芯片建立一对多的关系,而每个处理器芯片可以与核心建立一对多的关系。下图给出了包含处理器和核心的节点的可视化描述。

现代节点具有四个八核心处理器,共享一个公共内存池。参考:康奈尔虚拟研讨会
系统的非物理组件包括线程和进程——
线程 — 线程是操作系统视为一个单独单元进行调度和执行的CPU 指令序列。
进程 — 在计算中,进程是一个资源分配的整体单元,包括内存、文件处理器、端口和设备。一个进程可以管理多个线程的资源。线程可以作为进程的组件进行建模。
那么,线程是在同一系统的核心上运行,还是可以分布在不同系统中执行单个程序?无论哪种情况,它们如何通信?这些线程的内存是如何处理的?它们共享内存,还是每个线程都有自己的独立内存?
单个程序可以在同一系统/节点或不同的系统/节点上的多个核心上执行。系统和程序的设计决定了它是否符合预期的执行策略。
在设计系统时,必须考虑三个关键方面:执行(线程如何运行)、内存访问(内存如何分配给这些线程)和通信(线程如何通信,特别是当它们需要更新相同的数据时)。需要注意的是,这些方面大多是相互依赖的。
执行
串行执行 — 这使用单个执行线程在任何时候处理单一数据项。
并行执行 — 在这种方式中,多个任务同时发生。在计算中,这可以是—
单个工人 — 单个执行线程同时处理多个数据项(CPU 中的向量指令)。想象一个人根据花色整理扑克牌。由于有四种花色要分类,这个人必须逐一浏览整副牌,整理每种花色的卡片。
协作工作 — 单个进程中的多个执行线程。它相当于多个人一起合作,根据花色对一副扑克牌进行排序。
独立工作 — 多个进程可以处理同一个问题,利用相同的节点或多个节点。在这种情况下,每个人都会单独整理自己的扑克牌。
以上的任意组合。

协作工作:两个工人需要插入同一花色的卡片。工人 A 持有部分结果
对于梅花花色,工人 B 被暂时阻塞。参考:康奈尔虚拟研讨会
内存访问
共享内存 — 当程序在同一系统上的多个核心(单个多线程进程)中运行时,进程中的每个线程都可以访问相同虚拟地址空间中的内存。
分布式内存 — 当程序使用多个进程(无论是在单个节点上还是跨不同节点)时,采用分布式内存设计。在这种架构中,每个进程拥有数据的一部分,其他进程必须向数据拥有者发送消息以更新它们各自的部分。即使多个进程运行在单个节点上,每个进程仍有自己的虚拟内存空间。因此,这类进程应使用分布式内存编程进行通信。
混合策略 — 多线程进程可以在同一节点或不同节点上运行,旨在通过共享内存编程利用单一节点上的多个核心。同时,它们还可以采用分布式内存策略与其他节点上的进程进行协调。可以想象,在上图中,多个人或线程在多个隔间工作。相同隔间中的工作者通过共享内存编程进行通信,而不同隔间中的工作者则通过分布式内存编程进行交互。

在分布式内存设计中,多个并行工作者被分配到不同的隔间(进程)。参考:康奈尔虚拟工作坊
通信
通信机制依赖于内存架构。在共享内存架构中,像OpenMP(开放多处理)这样的应用程序接口可以实现共享内存和数据的线程之间的通信。另一方面,MPI(消息传递接口)可以用于在分布式内存架构中,运行在同一节点或不同节点上的进程之间的通信。
并行化策略和性能
我们如何判断我们的并行化策略是否有效?
有多种方法,但在这里我们讨论的是效率和加速比。在并行计算中,效率指的是在计算过程中实际利用的资源与可用资源的比例。它是通过将实际资源利用率与峰值性能进行比较来确定的,也就是最优资源利用率。
实际处理器利用率指的是在特定时间段内执行的浮点运算(FLOP)数量。
峰值性能假设每个处理器核心在每个时钟周期内执行最大可能的浮点运算(FLOP)。
并行代码的效率是每秒实际浮点运算(FLOPS)与可能的峰值性能之比。
加速比用于评估效率,其计算公式为:

当程序受限于处理器的计算速度时,速度提升不能超过并行资源的数量。
使用加速比,测量并行效率的公式为:

假设代码的串行执行时间为 300 秒。在使用 50 个核心并行化任务后,整体的墙钟时间为 6 秒。在这种情况下,速度提升可以通过将串行执行的墙钟时间除以并行执行的墙钟时间来计算,结果为 300 秒/6 秒 = 50。我们通过将速度提升除以核心数来获得并行效率,即 50/50 = 1。这是最佳情况的示例:工作负载完美地进行了并行化,且所有核心得到了高效利用。
如果数据大小或任务数量增加,增加更多的计算单元是否会持续提高性能?
并不总是如此。在并行计算中,我们有两种类型的扩展性,分别是基于问题规模或并行任务数量的扩展性。
强扩展性 — 在保持问题规模不变的情况下增加并行任务的数量。然而,即使我们增加计算单元(核心、处理器或节点)来并行处理更多任务,也会有与这些单元之间或与主程序之间的通信相关的开销,例如发送和接收数据所花费的时间。
理想情况下,随着并行任务数量的增加,执行时间会减少。然而,如果代码在强扩展性下没有加速,可能表明我们为当前工作量使用了过多的任务。
弱扩展性 — 在这种情况下,随着任务数量的增加,问题规模也会增加,因此每个任务的计算量保持不变。如果您的程序具有良好的弱扩展性,您可以在相同的墙钟时间内,用两倍的节点运行一个规模为原来的两倍的问题。
确实,存在一些限制,无法并行化某些操作,对吗?
是的,将某些顺序操作并行化确实可能相当具有挑战性。并行化依赖于多个指令流和/或多个数据流。

不同类型的并行计算架构。参考:康奈尔虚拟工作坊
要了解什么可以并行化,我们来看看 CPU 中的 SIMD,它是通过矢量化实现的。
矢量化是一种编程技术,其中操作是一次性应用到整个数组,而不是逐个处理单个元素。它是通过使用处理器中的矢量单元实现的,矢量单元包括矢量寄存器和矢量指令。
考虑这样一种情况,我们遍历数组,并在 for 循环中对单个元素执行多个操作。当数据是独立的时,编写可矢量化的代码变得非常简单;请参见下面的示例:
do i, n
a(i) = b(i) + c(i)
d(i) = e(i) + f(i)
end do
在这个循环中,每次迭代都是独立的——意味着 a(i) 的处理与 a(i+1) 等是独立的。因此,这段代码是可向量化的,允许使用来自 b 和 c 的元素并行计算数组 a 的多个元素,示例如下:
b: | b(i) | b(i+1) | b(i+2) | b(i+3) | ... |
c: | c(i) | c(i+1) | c(i+2) | c(i+3) | ... |
------------------------------------------------
Vectorized Addition (SIMD)
Vector Register 1 (loaded with b values):
| b(i) | b(i+1) | b(i+2) | b(i+3) | ... |
Vector Register 2 (loaded with c values):
| c(i) | c(i+1) | c(i+2) | c(i+3) | ... |
------------------------------------------------
Result in Vector Register 3:
| a(i) | a(i+1) | a(i+2) | a(i+3) | ... |
现代编译器通常能够分析这样的循环并将其转化为向量操作序列。 当某次迭代中的操作依赖于前一次迭代的结果时,问题就出现了。在这种情况下,自动向量化可能导致错误的结果。这种情况被称为数据依赖性。
科学代码中常遇到的数据依赖性包括 -
读后写(RAW) — 不可向量化
do i, n
a(i) = a(i-1) +b(i)
写后读(WAR) — 可向量化
do i, n
a(i) = a(i+1) +b(i)
写后写(WAW) — 不可向量化
do i, n
a(i%2) = a(i+1) +b(i)
读后读(RAR) — 可向量化
do i, n
a(i) = b(i%2) + c(i)
遵循某些标准规则进行向量化——例如确保循环迭代中的独立赋值、避免随机数据访问以及防止迭代之间的依赖性——可以帮助编写可向量化的代码。
GPU 架构与跨架构代码
当数据量增加时,尽可能将更多可并行化的操作进行并行化,以创建可扩展的解决方案是有意义的,但这意味着我们需要更大的系统,拥有更多的核心。难道这就是为什么我们使用 GPU 吗?它们与 CPU 有何不同,是什么导致了它们的高吞吐量?
是的!

比较 CPU 和 GPU 架构中基本元素的相对能力。参考:康奈尔虚拟研讨会
GPU(图形处理单元) 相比 CPU,拥有更多的处理单元(绿色)和更高的整体 内存带宽(每单位时间传输的数据量)。而 CPU 则拥有更复杂的指令处理和更快的时钟速度。如上所示,CPU 的缓存内存比 GPU 多。然而,CPU 的算术逻辑单元(ALU)和浮点单元(FPU)少于 GPU。考虑到这些因素,使用 CPU 处理复杂的工作流,而将 GPU 用于计算密集型任务是直观的。
GPU 被设计用来利用其大规模并行架构产生高计算吞吐量。其计算潜力可以通过每秒数十亿次浮点运算(GFLOPS)来衡量。GPU 硬件通常以标准显卡(如 NVIDIA Quadro)、高端加速卡(如 NVIDIA Tesla)等形式出现。
图形管线的两个关键特性使得并行化成为可能,从而实现高吞吐量——
对象的独立性 — 一个典型的图形场景由许多独立的对象组成;每个对象可以独立处理,而不依赖于其他对象。
统一处理步骤 — 处理步骤的顺序对于所有对象都是相同的。
因此,GPU 的多个核心可以同时处理不同的数据,像 SIMD(单指令多数据)架构一样并行执行计算。任务是如何在核心之间划分的?每个核心是否像 CPU 一样只运行一个线程?
在 GPU 中,**流式多处理器(SMs)**类似于 CPU 中的核心。GPU 中的核心类似于 CPU 中的向量通道。SM 是容纳核心的硬件单元。
当一个被称为内核(kernel)的函数或计算在 GPU 上执行时,它通常会被分解成线程块(thread blocks)。这些线程块包含多个线程(multiple threads);每个 SM 可以管理其核心上的多个线程。如果线程块的数量超过 SM 的数量,多个线程块可以被分配到单个 SM 上。同时,多个线程可以在单个核心上运行。
每个 SM 进一步将线程块(thread blocks)划分为称为warps的组,每个 warp 由32 个线程组成。 这些线程在不同的数据元素上执行相同的指令流,遵循**单指令多数据(SIMD)**模型。warp 大小被设置为 32,因为在 NVIDIA 的架构中,CUDA 核心是按 32 个一组进行分组的。这使得 warp 中的所有线程可以通过 32 个 CUDA 核心并行处理,从而实现高效率和优化的资源利用。
在**SIMD(单指令多数据)**中,单一指令对所有数据元素执行相同的操作,每个数据元素以完全相同的方式进行处理。SIMT(单指令多线程),这是 GPU 中常用的方式,放宽了这一限制。在 SIMT 中,线程可以被激活或停用,因此指令和数据在激活的线程中被处理;然而,局部数据在非激活线程中保持不变。
我想理解如何编写代码以利用不同的架构。相似的代码可以同时在 CPU 和 GPU 架构上工作吗?我们可以使用哪些参数和方法来确保代码有效地利用底层硬件架构,无论是 CPU 还是 GPU?
代码通常用高级语言如 C 或 C++编写,并且必须通过编译器转换为二进制代码,因为计算机无法直接处理高级指令。虽然 GPU 和 CPU 都可以执行相同的内核,但正如我们将在示例代码中看到的,我们需要使用指令或参数来运行代码在特定架构上,编译并生成该架构的指令集。这种方法使我们能够利用架构特定的功能。为了确保兼容性,我们可以为编译器指定适当的标志,以生成针对所需架构优化的二进制代码,无论是 CPU 还是 GPU。
各种编码框架,如 SYCL、CUDA 和 Kokkos,广泛用于为不同架构编写内核或函数。在本文中,我们将使用来自 Kokkos 的示例。
关于 Kokkos — 一个开源的 C++ 编程模型,旨在实现性能可移植性,用于编写内核:它作为一个模板库,构建在 CUDA、OpenMP 和其他后端之上,旨在描述性地定义我们想要做的事情,而非规定性地定义我们想如何做。Kokkos Core 提供了一种并行算法的编程模型,适用于多核芯片,并在这些核心之间共享内存。
内核有三个组件—
模式 — 计算的结构:for、scan、reduction、task-graph
执行策略 — 计算如何执行:静态调度、动态调度、线程组。
计算体 — 执行每个工作单元的代码。例如,循环体
模式和策略驱动计算体。在下面的示例中,仅为说明用途,‘for**’** 是 模式,控制模式的 条件(element=0; element<n; ++element)是 策略,而 计算体 是模式内执行的代码
for (element=0; element<n; ++element){
total = 0;
for(qp = 0; qp < numQPs; ++qp){
total += dot(left[element][qp], right[element][qp]);
}
elementValues[element] = total;
}
Kokkos 框架允许开发人员根据三个关键因素来定义参数和方法:代码将在哪运行 (执行空间),将使用哪些内存资源 (内存空间),以及数据如何结构化和管理 (数据结构与数据管理)。
我们主要讨论如何编写 Kokkos 内核来实现向量矩阵乘法,以了解这些因素如何在不同架构上实现。
但在此之前,我们先来讨论一下我们想要编写的内核的构建块。
内存空间 —
Kokkos 提供了多种内存空间选项,使用户能够控制内存管理和数据在不同计算平台上的布置。一些常用的内存空间包括—
HostSpace — 此内存空间代表 CPU 的主内存。它用于 CPU 上的计算,通常是基于 CPU 的系统的默认内存空间。
CudaSpace — CudaSpace 用于具有 CUDA 的 NVIDIA GPU。它提供 GPU 设备的内存分配和管理,支持高效的数据传输和计算。
CudaUVMSpac — 对于统一虚拟内存(UVM)系统,如某些 NVIDIA GPU 上的 UVM,CudaUVMSpac 使得可以分配从 CPU 和 GPU 都可以访问的内存,而无需显式的数据传输。Cuda 运行时自动处理数据移动,但会有性能损失。
讨论内存布局也非常重要,它指的是内存中数据的组织和安排。Kokkos 提供了几种内存布局选项,帮助用户优化不同计算的数据显示方式。一些常用的内存布局包括—

矩阵的行主序与列主序迭代。参考:Wikipedia
**LayoutRight(也称为行主序)**是 C 和 C++ 中多维数组的默认内存布局。在 LayoutRight 中,最右边的索引在内存中变化最快。如果未选择布局,HostSpace 的默认布局为 LayoutRight。
LayoutLeft(也称为列主序)— 在 LayoutLeft 中,最左边的索引在内存中变化最快。如果未选择布局,CudaSpace 的默认布局为 LayoutLeft。
在下面的程序实现中,我们根据编译器标志 ENABLE_CUDA 定义了内存空间和布局的宏,当我们希望在 GPU 上运行代码时,ENABLE_CUDA 为 True,CPU 上则为 False。
// ENABLE_CUDA is a compile time argument with default value true
#define ENABLE_CUDA true
// If CUDA is enabled, run the kernel on the CUDA (GPU) architecture
#if defined(ENABLE_CUDA) && ENABLE_CUDA
#define MemSpace Kokkos::CudaSpace
#define Layout Kokkos::LayoutLeft
#else
// Define default values or behavior when ENABLE_CUDA is not set or is false
#define MemSpace Kokkos::HostSpace
#define Layout Kokkos::LayoutRight
#endif
数据结构与数据管理 —
Kokkos 视图 — 在 Kokkos 中,"视图"是一个基础数据结构,表示一维和多维数组,可用于高效地存储和访问数据。Kokkos 视图为数据管理提供了高级抽象,并设计为与不同的执行空间和内存布局无缝协作。
// View for a 2d array of data type double
Kokkos::View<double**> myView("myView", numRows, numCols);
// Access Views
myView(i, j) = 42.0;
double value = myView(i, j);
Kokkos 数据管理的镜像技术 — 镜像是等效数组的视图,这些数组可能位于不同的内存空间中,适用于在 CPU 和 GPU 架构中都需要数据的场景。这项技术在读取 CPU 上的文件数据并随后在 GPU 上处理时非常有用。Kokkos 的镜像技术创建了数据的镜像视图,允许在 CPU 和 GPU 执行空间之间无缝共享,并促进数据传输和同步。
为了创建主数据的镜像副本,我们可以使用 Kokkos 的 create_mirror_view() 函数。该函数在指定的执行空间(例如 GPU)中生成一个具有与主视图相同数据类型和维度的镜像视图。
// Intended Computation -
// <y, A*x> = y^T * A * x
// Here:
// y and x are vectors.
// A is a matrix.
// Allocate y, x vectors and Matrix A on device
typedef Kokkos::View<double*, Layout, MemSpace> ViewVectorType;
typedef Kokkos::View<double**, Layout, MemSpace> ViewMatrixType;
// N and M are number of rows and columns
ViewVectorType y( "y", N );
ViewVectorType x( "x", M );
ViewMatrixType A( "A", N, M );
// Create host mirrors of device views
ViewVectorType::HostMirror h_y = Kokkos::create_mirror_view( y );
ViewVectorType::HostMirror h_x = Kokkos::create_mirror_view( x );
ViewMatrixType::HostMirror h_A = Kokkos::create_mirror_view( A );
// Initialize y vector on host.
for ( int i = 0; i < N; ++i ) {
h_y( i ) = 1;
}
// Initialize x vector on host.
for ( int i = 0; i < M; ++i ) {
h_x( i ) = 1;
}
// Initialize A matrix on host.
for ( int j = 0; j < N; ++j ) {
for ( int i = 0; i < M; ++i ) {
h_A( j, i ) = 1;
}
}
// Deep copy host views to device views.
Kokkos::deep_copy( y, h_y );
Kokkos::deep_copy( x, h_x );
Kokkos::deep_copy( A, h_A );
执行空间 —
在 Kokkos 中,执行空间指的是并行操作和计算执行的特定计算环境或硬件平台。Kokkos 抽象了执行空间,使得代码可以以描述性方式编写,同时适应各种硬件平台。
我们讨论两个主要的执行空间 —
Serial:Serial 执行空间是一种主要且可移植的选项,适用于单线程的 CPU 执行。它通常用于调试、测试以及作为性能比较的基准。
Cuda:Cuda 执行空间用于 NVIDIA GPU,并依赖 CUDA 技术进行并行处理。它可以实现高效的 GPU 加速和 GPU 内存管理。
可以定义 ExecSpace,也可以根据内存空间动态确定,如下所示:
// Execution space determined based on MemorySpace
using ExecSpace = MemSpace::execution_space;
我们如何利用这些构件来编写实际的内核?我们能否用它来比较不同架构之间的性能?
为了编写一个内核并进行性能比较,我们使用以下计算:
<y, A*x> = y^T * (A * x)
Here:
y and x are vectors.
A is a matrix.
<y, A*x> represents the inner product or dot product of vectors y
and the result of the matrix-vector multiplication A*x.
y^T denotes the transpose of vector y.
* denotes matrix-vector multiplication.
这个操作在 Kokkos 中的内核—
// Use a RangePolicy.
typedef Kokkos::RangePolicy<ExecSpace> range_policy;
// The below code is run for multiple iterations across different
// architectures for time comparison
Kokkos::parallel_reduce( "yAx", range_policy( 0, N ),
KOKKOS_LAMBDA ( int j, double &update ) {
double temp2 = 0;
for ( int i = 0; i < M; ++i ) {
temp2 += A( j, i ) * x( i );
}
update += y( j ) * temp2;
}, result );
对于上述内核,parallel_reduce作为模式,range_policy定义了策略,而实际的操作构成了计算体。
我在一台配有 NVIDIA Quadro RTX 5000 GPU 的 TACC Frontera 节点上执行了此内核实验。实验使用了不同的N值,N代表向量y和x的长度,以及矩阵A的行数。计算执行了 100 次以获得显著结果,并记录了内核在串行(主机)和 CUDA 执行空间中的执行时间。我使用了**ENABLE_CUDA**编译器标志在执行环境之间切换:True表示 GPU/CUDA 执行空间,False表示 CPU/串行执行空间。以下展示了这些实验的结果及相应的加速比。
有关 CPU 与 GPU 架构下内核执行时间和加速比的数据参考:作者提供的表格

不同数据大小下的加速比趋势(GPU vs CPU)参考:作者提供的图片
我们注意到,随着 N 值的增大,加速比显著增加,这表明 CUDA 实现对于较大规模的问题变得越来越有优势。
目前为止就这些!希望本文能在探索计算领域时为您提供一个良好的开端。理解 GPU 架构的基础至关重要,本文介绍了一种我实验过的跨架构代码编写方法。然而,还有多种值得探索的方法和技术。
虽然我不是该领域的专家,但本文反映了我在德州奥斯汀 TACC 短暂工作的学习经历。我欢迎反馈和讨论,如果您有任何问题或想了解更多内容,我很乐意提供帮助。请参考下面的优秀资源以进一步学习。祝计算愉快!
致谢
本文参考了三个主要来源。第一个来源是SDS394: UT Austin 的科学与技术计算研究生课程,该课程提供了关于单核多线程系统的基础知识。第二个来源是Cornell 虚拟工作坊:并行编程概念与高性能计算,这是一个关于并行计算的优秀学习资源。Kokkos 代码实现主要基于GitHub 上的 Kokkos 教程中的材料。这些都是任何有兴趣学习并行计算的人的绝佳资源。
— — — — — — — — — —
C++ Kokkos 内核的开发,用于在不同架构之间进行性能比较,作为由Intel OneAPI 卓越中心和TACC STAR 学者计划支持的项目的一部分,该项目得到了 TACC 行业合作伙伴的慷慨资助,其中包括Intel、Shell、Exxon和Chevron。
— — — — — — — — — —
参考资料/资源:
[## GitHub - VictorEijkhout/TheArtofHPC_pdfs:Victor Eijkhout《高性能计算艺术》书籍及课程的所有 pdf]
Victor Eijkhout《高性能计算艺术》书籍及课程的所有 pdf - VictorEijkhout/TheArtofHPC_pdfs
github.com [## Frontera - TACC HPC 文档]
最近更新:2024 年 10 月 24 日 重要提示:请注意 TACC 的新 SU 收费政策。Frontera 由国家资助…
docs.tacc.utexas.edu [## 康奈尔虚拟研讨会:并行编程概念与高性能计算]
本路线图解释了并行编程概念,以及并行编程如何与高性能计算相关……
本路线图面向那些相对较新的 GPU 使用者或希望深入了解其中的计算机技术的人员……
cvw.cac.cornell.edu [## 康奈尔虚拟研讨会:矢量化]
本路线图描述了与计算硬件、编译器和编码实践相关的矢量化过程……
cvw.cac.cornell.edu [## GitHub — kokkos/kokkos-tutorials:Kokkos C++ 性能可移植性编程教程]
Kokkos C++ 性能可移植性编程生态系统教程 — kokkos/kokkos-tutorials
github.com [## Kokkos 讲座系列]
Kokkos C++ 性能可移植性编程生态系统教程 - Kokkos 讲座系列 ·…
从概率方法到预测方法:掌握客户生命周期价值的技术
本章是关于 CLV 分析和预测实际应用的综合实用指南的最后一章。
·发表于Towards Data Science ·14 分钟阅读·2024 年 5 月 3 日
--

我和我的 iPad 回来了,带来了更多粗糙的图表,这是我的《客户生命周期价值》指南的最后一部分(适用于市场营销人员和数据科学家),涵盖了所有关于客户生命周期价值的内容。
欢迎再次来到我的系列文章《客户生命周期价值:好的、坏的,以及其他 CLV 博客没有告诉你的所有内容》。这完全基于我在电商领域领导 CLV 研究的经验,也是我希望从一开始就知道的所有内容:
第一部分讨论了如何从历史CLV 分析中获得可操作的洞察
第二部分涵盖了 CLV 预测的实际应用案例。
接下来,我们讨论了历史性 CLV 建模方法,包括每种方法的实际优缺点。
从使用案例示例到实际应用的过渡,带领我们进入今天的文章,主题是CLV 预测:有哪些方法可用,市场营销人员和数据科学家在尝试将这些方法应用于自己的数据时,可以期待每种方法带来什么效果?我们将探讨概率方法与机器学习方法的不同,以及各自的一些优缺点,并最后分享一些关于如何开始自己的 CLV 之旅的思考。
从提示工程到代理工程
引入一个实用的代理工程框架
·发表于数据科学前沿 ·9 分钟阅读·2024 年 5 月 29 日
--

来源:图片由作者提供,并使用 MidJourney 生成
简介
自从 ChatGPT 发布以来已经过去一年多,可以明显看出公众对“AI”的看法发生了巨大变化。这部分是由于公众认知的提升,但更大程度上是因为人们意识到,基于 AI 的系统可能(已经能够?)达到与人类相当的能力和表现。在许多方面,ChatGPT 已经成为整个 AI 领域的概念验证展示。这个展示的工作始于半个多世纪前,现在已经提供了令人信服的证据,表明我们离创造“能够执行人类所需智能的机器”的现实越来越近,借用雷·库兹韦尔的定义。因此,AI 代理的讨论和发展在最近几个月爆炸性增长也就不足为奇了。它们是 AI 一直以来追求的目标的体现。
需要明确的是,AI 代理的概念并不是一个新概念。我们的想象力已经多次进入过这个领域——《星际大战》中的 C-3PO 是最典型的体现 AI,能够理解自然语言、进行对话,并具备自主行动能力。在学术界的更正式领域,诺尔维格和拉塞尔的 AI 教科书《人工智能:现代方法》指出,智能代理是主要的统一主题。无论是源自科学还是科幻的 AI 代理理念,随着像 ChatGPT、Claude 和 Gemini 等模型的出现,它们似乎变得更加可实现,这些模型在多个知识领域表现出广泛的能力,并具备强大的理解力和人类级别的对话能力。再加上像“视觉”和功能调用等新能力,AI 代理开发的舞台已然搭建完成。
随着我们朝着 AI 代理开发的方向快速前进,似乎有必要从提示工程转向更广泛的领域,即代理工程,并建立适当的框架、方法论和思维模型,以有效地设计它们。在本文中,我旨在探讨一些在 LLM 环境中与代理工程相关的关键思想和原则。
让我们从高层次上探讨代理工程框架的关键部分。我们从“代理能力要求”开始,在这里我们旨在清晰地定义代理需要做什么以及它需要具备多高的熟练度。在“代理工程与设计”部分,我们评估可用的技术,并开始思考我们的代理的构造和协调。
这个框架的早期阐述旨在成为一个实用的思维模型,诚然在各方面并不全面。但我相信从某个地方开始并逐步改进和增强是有价值的。
代理工程框架
简介
构建 AI 代理的目的是什么?它有工作或角色吗?是支持目标的行动?还是支持行动的目标?在某些工作中,一个多功能代理比多代理群体更合适吗?人类语言的美在于它的灵活性,并允许我们以多种方式比喻性地扩展概念。其缺点是,这可能导致歧义。在阐述框架时,我故意避免在关键术语之间解析语义区别,因为它们中的许多可以互换使用。我们努力提出一些概念,这些概念在 AI 代理工程的广泛应用中具有普遍性。因此,在这一阶段,框架更像是一个思维模型,旨在指导围绕代理工程的思维过程。正如您在下面的图形中看到的那样,核心思想相对简单:

AI 代理被赋予工作
工作需要动作来完成
执行动作需要能力
能力具有所需的熟练程度
所需的熟练程度需要技术和技巧
技术和技术需要协调
代理能力要求
待完成的任务
设计 AI 代理的初步步骤是明确列出代理需要做什么。代理需要完成的主要工作、任务或目标是什么?这可以以高层次的目标来框定,或分解为具体的工作和任务。你可以决定使用多代理群体的方法,并为每个代理分配一个任务。语言和细节的层次可以有所不同。例如:
一个电子商务聊天机器人的工作可能包括处理一般查询、分析客户行为并进行产品推荐。
内容创作代理的工作可能包括头脑风暴内容创意,以及起草文章或博客帖子。
请注意,在这两种情况下,诸如工作、任务、目标等标签可以在代理需要完成的任务的上下文中互换使用。
执行任务所需采取的行动
一旦待完成的任务被定义,下一步是确定代理需要执行的具体行动。这时的重点不再仅仅是定义代理需要达成的目标,而是要指定代理如何通过具体的行动来完成这些目标。在这个阶段,考虑代理的适当自主性水平也变得十分重要。例如:
对于内容创作代理,行动可能包括:
调用另一个代理获取内容趋势洞察。
根据这些趋势和目标受众的偏好生成内容创意列表。
起草文章或博客帖子。
根据人类用户的编辑和调整指令进行操作。
内容创作代理可能会自主生成和起草内容,人工编辑提供最终审核。或者,可能会雇佣一个单独的代理编辑来进行首次审核,然后再由人工编辑介入。
所需能力
现在我们已经概述了代理需要采取的行动来执行任务,我们接下来将阐明实现这些行动所需的能力。这些能力可以包括从自然语言对话、信息检索、内容生成、数据分析、持续学习等方面的内容。它们还可以以更技术化的层面来表达,如 API 调用、函数调用等。例如,对于我们的内容创作代理,所需的能力可能是:
动态调用内容趋势代理的 API。
头脑风暴和总结能力。
针对当前话题和趋势生成合适的内容。
能够根据人类提供的编辑指令进行操作。
内存
最终,重要的是以一种不会限制选择和最终技术选择的方式来表达这些能力。例如,虽然我们都非常看重大语言模型,但大行动模型(LAMs)正在迅速发展,可能对于实现所需能力非常相关。
所需能力的熟练度水平
确定代理执行其任务所需的能力是一个至关重要的步骤,但同样重要的是评估和定义每项能力所需的熟练度水平。这包括设定具体的基准和绩效指标,只有在这些指标得到满足时,代理及其能力才会被认为是熟练的。这些基准可以包括准确性、效率和可靠性。
例如,对于我们的内容创作代理,期望的熟练度水平可能包括:
75% 的函数调用可靠性。
99.99% 的失败函数调用可解释性。
内容趋势代理的函数调用能够在第一次尝试时至少 75% 的概率生成相关主题。
内容创意的生成能够在 75% 的时间内产生期望的主题。
精确率达到 99.99% 的信息检索。
用户反馈满意度达到 90% 或更高的编辑生成。
用户反馈满意度达到 90% 或更高的最终草稿生成。
代理工程与设计
将所需能力与技术和技巧对应
一旦明确了所需的能力和所需的熟练度水平,下一步就是确定我们如何满足这些要求。这涉及评估快速增长的可用技术和技巧库,包括大语言模型(LLMs)、RAG、Guardrails、专用 API 以及其他机器学习/人工智能模型,以评估它们是否能达到指定的熟练度水平。在所有情况下,考虑任何给定技术或技巧的最佳应用场景以及其成本/收益影响是很有帮助的。我将在这里简要讨论一些内容,但由于有无数的可能性,这个讨论将会在范围和深度上有所限制。
广泛的知识熟练度
广泛的知识指的是对多个领域和主题的广泛理解和信息掌握。这种类型的知识对于创建能够有效进行对话、理解上下文,并在各种主题上提供相关响应的人工智能代理至关重要。
大语言模型(LLMs)—— 如果您的代理需要广泛的知识熟练度,值得庆幸的是,大语言模型的开发仍在不断推进。从像 LlaMA3 这样的开源模型,到 OpenAI、Anthropic 和 Google 的最新专有模型,提供高度一致性的技术选择应有尽有,能够覆盖广泛的人类语言和知识领域。
提示工程 — 这个动态且非常活跃的发展领域,主要关注如何激活由大语言模型(LLMs)建模的上下文相关知识领域。由于语言的万花筒特性,掌握这项技能可以对我们代理人的能力熟练度产生巨大影响。
特定知识熟练度
特定知识涉及对某些领域或主题的深入理解。这类知识对于那些需要详细专业知识和熟悉特定内容的任务是必要的。在我们追求熟练度目标时,可能需要考虑哪些技术/方法?
检索增强生成(RAG)— RAG 将 LLMs 的生成能力与信息检索系统结合起来,整合外部来源的信息。这些信息可能是精确的数据,或是 LLM 能够在上下文中“理解”的特定知识(例如描述一种独特方法的内容),因为 LLM 具备广泛的知识熟练度。
模型微调 — 在特定上下文数据集上对 LLMs 进行微调,可以使模型在特定场景中生成更符合上下文的相关回应。尽管这种方法不像 RAG 那样流行,但随着代理人工程的逐步发展,我们可能会发现,能力要求会促使我们更频繁地使用这种技术。
精确的信息
精确的信息指的是高度准确和具体的数据点,这些数据对于需要精确答案的任务至关重要。
函数调用(也叫工具使用) — 我们都警惕那些在特定事实上产生幻觉的 AI 代理人,这些事实是无可争议的。如果所需的熟练度要求如此,我们可能会转向函数调用来检索特定信息。例如,一个电子商务代理人可能使用定价 API 来提供最新的产品价格,或使用股票市场 API 来实时更新股票价值。
防护措施 — 防护措施有助于确保代理人在其回应中提供精确和准确的信息。这可能涉及基于规则的约束、引导式对话设计和意图预处理。
代理人结构与协调
现在我们已经清楚代理人的工作内容、所需的能力和熟练度,以及实现这些的技术,我们将重点转向代理人的结构和协调,不论是在单独配置下,还是在某种类型的群体或生态系统中。能力是应该注册到一个代理人上,还是每个能力都分配给一个独特的代理人在群体中运行?我们如何开发可以最小化努力进行重新利用的能力和代理人?这个话题本身涉及多个文章,因此我们在这里不再深入探讨。从某种意义上说,这就是“纸面上的理论转化为实际操作”的地方,我们需要将多种技术和技巧编织在一起,为我们的代理人注入生命。
结论
从提示工程(Prompt Engineering)到智能体工程(Agent Engineering)的旅程才刚刚开始,在这个过程中还有许多内容需要学习和完善。这篇关于智能体工程框架的初步尝试,提出了一种实际的方法来设计 AI 智能体,通过概述一个高级的思维模型,作为这一演变过程中的一个有用起点。构建智能体的模型和技术将继续不断增加,迫切需要一种框架,能够脱离任何特定技术或技术类别的限制。通过清晰地定义智能体需要完成的任务,概述执行这些任务所需的动作,并指定所需的能力和熟练度水平,我们为设计和工程工作奠定了一个强大且灵活的基础。它还为智能体及其能力的改进和进化提供了一个结构。
感谢阅读,希望你在智能体相关的工作中,能从这个智能体工程框架中获得帮助。请关注未来框架的进一步完善以及对文中提到的各种话题的详细阐述。如果你想进一步讨论这个框架或我写过的其他话题,欢迎通过LinkedIn与我联系。
除非另有说明,本文中的所有图片均为作者所有。
从原型到生产:提高 LLM 的准确性
实施评估框架以优化现实应用中的准确性
·发表于 Towards Data Science ·20 分钟阅读·2024 年 12 月 19 日
--

图像由 DALL-E 3 生成
构建 LLM 应用程序的原型出乎意料地简单。你通常可以在几个小时内创建一个功能性的初始版本。这个初始原型可能会提供看起来合法的结果,并且是展示你方法的一个很好的工具。然而,这通常不足以满足生产环境的需求。
LLMs 本质上是概率性的,因为它们根据可能的延续分布生成标记。这意味着在许多情况下,我们从分布中得到的答案接近“正确”答案。有时,这种结果是可以接受的——例如,应用程序说“Hello, John!”或“Hi, John!”并没有什么区别。在其他情况下,差异则至关重要,比如“2024 年的收入是 2000 万美元”和“2024 年的收入是 2000 万英镑”之间的区别。
在许多现实世界的商业场景中,精度至关重要,“差不多正确”是不够的。例如,当你的 LLM 应用需要执行 API 调用,或者你在做财务报告的总结时。根据我的经验,确保结果的准确性和一致性比构建初始原型要复杂且耗时得多。
在本文中,我将讨论如何衡量和提高准确性。我们将构建一个 SQL 代理,其中精度对于确保查询可执行至关重要。从一个基本的原型开始,我们将探索衡量准确性的方法,并测试各种增强准确性的技术,如自我反思和检索增强生成(RAG)。
设置
如往常一样,让我们从设置开始。我们 SQL 代理解决方案的核心组件是 LLM 模型,它生成查询,以及执行查询的 SQL 数据库。
LLM 模型 — Llama
对于这个项目,我们将使用 Meta 发布的开源 Llama 模型。我选择了Llama 3.1 8B,因为它足够轻量,可以在我的笔记本上运行,同时仍然相当强大(详细信息请参阅文档)。
如果你还没有安装,可以在这里找到指南。我通过Ollama在 MacOS 上本地使用它。使用以下命令,我们可以下载该模型。
ollama pull llama3.1:8b
我们将使用 Ollama 与LangChain一起,所以让我们从安装所需的包开始。
pip install -qU langchain_ollama
现在,我们可以运行 Llama 模型并查看第一个结果。
from langchain_ollama import OllamaLLM
llm = OllamaLLM(model="llama3.1:8b")
llm.invoke("How are you?")
# I'm just a computer program, so I don't have feelings or emotions
# like humans do. I'm functioning properly and ready to help with
# any questions or tasks you may have! How can I assist you today?
我们希望在客户提问时传递系统消息。所以,按照《Llama 3.1 模型文档》,让我们编写一个辅助函数来构造提示并测试这个函数。
def get_llama_prompt(user_message, system_message=""):
system_prompt = ""
if system_message != "":
system_prompt = (
f"<|start_header_id|>system<|end_header_id|>\n\n{system_message}"
f"<|eot_id|>"
)
prompt = (f"<|begin_of_text|>{system_prompt}"
f"<|start_header_id|>user<|end_header_id|>\n\n"
f"{user_message}"
f"<|eot_id|>"
f"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
return prompt
system_prompt = '''
You are Rudolph, the spirited reindeer with a glowing red nose,
bursting with excitement as you prepare to lead Santa's sleigh
through snowy skies. Your joy shines as brightly as your nose,
eager to spread Christmas cheer to the world!
Please, answer questions concisely in 1-2 sentences.
'''
prompt = get_llama_prompt('How are you?', system_prompt)
llm.invoke(prompt)
# I'm feeling jolly and bright, ready for a magical night!
# My shiny red nose is glowing brighter than ever, just perfect
# for navigating through the starry skies.
新的系统提示显著改变了答案,因此它是有效的。至此,我们的本地 LLM 设置已经准备好。
数据库 — ClickHouse
我将使用一个开源数据库ClickHouse。我选择 ClickHouse 是因为它有一个特定的 SQL 方言。在训练过程中,LLM 模型可能遇到的这种方言的示例较少,这使得任务稍微具有挑战性。然而,你可以选择任何其他数据库。
安装 ClickHouse 相当简单 — 只需按照文档中提供的指示操作。
我们将使用两个表:ecommerce.users和ecommerce.sessions。这些表包含虚构的数据,包括客户个人信息和他们在电商网站上的会话活动。


你可以在GitHub上找到生成合成数据并上传的代码。
至此,设置完成,我们准备好开始构建基础原型。
第一个原型
如前所述,我们的目标是构建一个 SQL 代理——一个生成 SQL 查询以回答客户问题的应用程序。未来,我们可以在这个系统上添加另一层:执行 SQL 查询,将初始问题和数据库结果传回 LLM,并让它生成一个易于理解的答案。然而,本文将集中于第一步。
使用 LLM 应用程序的最佳实践(类似于任何其他复杂任务)是从简单开始,然后进行迭代。最直接的实现方法是执行一次 LLM 调用,并在系统提示中共享所有必要的信息(如模式描述)。因此,第一步是准备提示语。
generate_query_system_prompt = '''
You are a senior data analyst with more than 10 years of experience writing complex SQL queries.
There are two tables in the database with the following schemas.
Table: ecommerce.users
Description: customers of the online shop
Fields:
- user_id (integer) - unique identifier of customer, for example, 1000004 or 3000004
- country (string) - country of residence, for example, "Netherlands" or "United Kingdom"
- is_active (integer) - 1 if customer is still active and 0 otherwise
- age (integer) - customer age in full years, for example, 31 or 72
Table: ecommerce.sessions
Description: sessions of usage the online shop
Fields:
- user_id (integer) - unique identifier of customer, for example, 1000004 or 3000004
- session_id (integer) - unique identifier of session, for example, 106 or 1023
- action_date (date) - session start date, for example, "2021-01-03" or "2024-12-02"
- session_duration (integer) - duration of session in seconds, for example, 125 or 49
- os (string) - operation system that customer used, for example, "Windows" or "Android"
- browser (string) - browser that customer used, for example, "Chrome" or "Safari"
- is_fraud (integer) - 1 if session is marked as fraud and 0 otherwise
- revenue (float) - income in USD (the sum of purchased items), for example, 0.0 or 1506.7
Write a query in ClickHouse SQL to answer the following question.
Add "format TabSeparatedWithNames" at the end of the query to get data from ClickHouse database in the right format.
'''
我在提示中为每个字段包含了示例值,以确保 LLM 理解数据格式。
就这样!通过这个,我们获得了 SQL 代理的第一个功能原型。现在,是时候进行测试了。
prompt = get_llama_prompt('How many customers made purchase in December 2024?',
generate_query_system_prompt)
llm.invoke(prompt)
# To find out how many customers made a purchase in December 2024,
# we need to join the `sessions` table with the `users` table on the `user_id`,
# and then filter by `action_date` for sessions that took place in December 2024\.
# Here is the query:
#
# ```sql
# SELECT COUNT(DISTINCT u.user_id)
# FROM ecommerce.sessions s
# JOIN ecommerce.users u ON s.user_id = u.user_id
# WHERE
# EXTRACT(YEAR FROM s.action_date) = 2024
# AND EXTRACT(MONTH FROM s.action_date) = 12
# AND revenue > 0;
# ```py
# This query first joins the `sessions` table with the `users` table
# on the `user_id`. Then it filters out sessions where no purchase
# was made (revenue is greater than 0). Finally, it selects the distinct
# user IDs of these customers.
# When you run this query in ClickHouse SQL, you can add "format TabSeparatedWithNames"
# at the end to get the result in the right format:
# ```sql
# SELECT COUNT(DISTINCT u.user_id)
# FROM ecommerce.sessions s
# JOIN ecommerce.users u ON s.user_id = u.user_id
# WHERE
# EXTRACT(YEAR FROM s.action_date) = 2024
# AND EXTRACT(MONTH FROM s.action_date) = 12
# AND revenue > 0;
# 格式 TabSeparatedWithNames;
# ```py
代理生成了一个相当不错的结果,但有一个问题——LLM 不仅返回了 SQL 查询,还返回了一些评论。由于我们计划稍后执行 SQL 查询,这种格式并不适合我们的任务。让我们来解决这个问题。
幸运的是,这个问题已经得到解决,我们不需要手动从文本中解析 SQL 查询。我们可以使用聊天模型ChatOllama。不幸的是,它不支持结构化输出,但我们可以通过工具调用实现相同的结果。
为了做到这一点,我们将定义一个虚拟工具来执行查询,并在系统提示中指示模型始终调用该工具。我在输出中保留了comments,以便给模型一些推理空间,遵循思维链模式。
from langchain_ollama import ChatOllama
from langchain_core.tools import tool
@tool
def execute_query(comments: str, query: str) -> str:
"""Excutes SQL query.
Args:
comments (str): 1-2 sentences describing the result SQL query
and what it does to answer the question,
query (str): SQL query
"""
pass
chat_llm = ChatOllama(model="llama3.1:8b").bind_tools([execute_query])
result = chat_llm.invoke(prompt)
print(result.tool_calls)
# [{'name': 'execute_query',
# 'args': {'comments': 'SQL query returns number of customers who made a purchase in December 2024\. The query joins the sessions and users tables based on user ID to filter out inactive customers and find those with non-zero revenue in December 2024.',
# 'query': 'SELECT COUNT(DISTINCT T2.user_id) FROM ecommerce.sessions AS T1 INNER JOIN ecommerce.users AS T2 ON T1.user_id = T2.user_id WHERE YEAR(T1.action_date) = 2024 AND MONTH(T1.action_date) = 12 AND T2.is_active = 1 AND T1.revenue > 0'},
# 'type': 'tool_call'}]
通过工具调用,我们现在可以直接从模型中获得 SQL 查询。这是一个很好的结果。然而,生成的查询并不完全准确:
它包含了
is_active = 1的过滤条件,尽管我们没有明确要求过滤掉非活跃客户。尽管我们在系统提示中明确要求指定格式,LLM 还是没有做到这一点。
显然,我们需要专注于提高模型的准确性。但正如彼得·德鲁克(Peter Drucker)所说的那样,“你不能改进你无法衡量的东西。” 所以,下一步就是建立一个系统来评估模型的质量。这个系统将成为性能改进迭代的基石。如果没有它,我们基本上是在黑暗中摸索前进。
评估准确性
评估基础
为了确保我们在改进,需要一个可靠的方式来衡量准确性。最常见的方法是创建一个“黄金”评估集,其中包含问题和正确答案。然后,我们可以将模型的输出与这些“黄金”答案进行比较,并计算正确答案的比例。虽然这个方法听起来很简单,但还是有一些值得讨论的细节。
首先,你可能会觉得创建一个全面的问答集让人不知所措。构建这样一个数据集看起来可能是一个令人生畏的任务,可能需要数周甚至数月的时间。然而,我们可以从小做起,先创建一个 20 到 50 个示例的初步集,并在此基础上进行迭代。
一如既往,质量比数量更重要。我们的目标是创建一个具有代表性且多样化的数据集。理想情况下,这应该包括:
常见问题。 在大多数实际情况下,我们可以采用实际问题的历史,并将其作为我们初步的评估集。
挑战性的边界案例。 值得添加一些模型容易产生幻觉的示例。你可以通过自己实验或者收集第一个原型的反馈来发现这些案例。
一旦数据集准备好,接下来的挑战是如何为生成的结果打分。我们可以考虑几种方法:
比较 SQL 查询。 第一个想法是将生成的 SQL 查询与评估集中的查询进行比较。然而,这可能会有些棘手。看似相似的查询可能会产生完全不同的结果。同时,看起来不同的查询也可能得出相同的结论。此外,单纯比较 SQL 查询并不能验证生成的查询是否能够实际执行。鉴于这些挑战,我不会认为这种方法是我们案例中最可靠的解决方案。
精确匹配。 当我们的评估集中的答案是确定性的时,可以使用传统的精确匹配。例如,如果问题是“有多少个客户?”而答案是“592800”,那么模型的回答必须完全匹配。然而,这种方法也有其局限性。考虑上面的例子,模型回应为*“有 592,800 个客户”*。虽然答案完全正确,但精确匹配的方法会将其标记为无效。
使用 LLM 进行评分。 一个更强大且灵活的方法是利用 LLM 进行评估。我们不再关注查询结构,而是让 LLM 比较 SQL 执行的结果。这种方法在查询可能不同但仍能得出正确结果的情况下特别有效。
值得记住的是,评估并非一次性任务;它是一个持续的过程。为了进一步提升模型的性能,我们需要通过添加导致模型产生幻觉的示例来扩展数据集。在生产模式下,我们可以创建一个反馈循环。通过收集用户反馈,我们可以识别模型失败的案例,并将其纳入我们的评估集中。
在我们的例子中,我们将只评估执行结果是否有效(SQL 查询能否执行)和正确性。当然,你也可以查看其他参数。例如,如果你关心效率,可以将生成查询的执行时间与黄金集中的查询执行时间进行比较。
评估集和验证
既然我们已经了解了基础知识,就准备将其付诸实践。我花了大约 20 分钟时间准备了一组 10 个示例。虽然这组数据量小,但足以完成我们的简单任务。它由一组问题和相应的 SQL 查询组成,格式如下:
[
{
"question": "How many customers made purchase in December 2024?",
"sql_query": "select uniqExact(user_id) as customers from ecommerce.sessions where (toStartOfMonth(action_date) = '2024-12-01') and (revenue > 0) format TabSeparatedWithNames"
},
{
"question": "What was the fraud rate in 2023, expressed as a percentage?",
"sql_query": "select 100*uniqExactIf(user_id, is_fraud = 1)/uniqExact(user_id) as fraud_rate from ecommerce.sessions where (toStartOfYear(action_date) = '2023-01-01') format TabSeparatedWithNames"
},
...
]
你可以在 GitHub 上找到完整的列表 —— 链接。
我们可以将数据集加载到 DataFrame 中,使其准备好在代码中使用。
import json
with open('golden_set.json', 'r') as f:
golden_set = json.loads(f.read())
golden_df = pd.DataFrame(golden_set)
golden_df['id'] = list(range(golden_df.shape[0]))
首先,让我们为评估集中的每个问题生成 SQL 查询。
def generate_query(question):
prompt = get_llama_prompt(question, generate_query_system_prompt)
result = chat_llm.invoke(prompt)
try:
generated_query = result.tool_calls[0]['args']['query']
except:
generated_query = ''
return generated_query
import tqdm
tmp = []
for rec in tqdm.tqdm(golden_df.to_dict('records')):
generated_query = generate_query(rec['question'])
tmp.append(
{
'id': rec['id'],
'generated_query': generated_query
}
)
eval_df = golden_df.merge(pd.DataFrame(tmp))
在进行基于 LLM 的查询输出评分之前,首先确保 SQL 查询是有效的非常重要。为此,我们需要执行查询并检查数据库输出。
我已经创建了一个在 ClickHouse 中运行查询的函数。它还确保输出格式正确指定,因为这在业务应用中可能至关重要。
CH_HOST = 'http://localhost:8123' # default address
import requests
import io
def get_clickhouse_data(query, host = CH_HOST, connection_timeout = 1500):
# pushing model to return data in the format that we want
if not 'format tabseparatedwithnames' in query.lower():
return "Database returned the following error:\n Please, specify the output format."
r = requests.post(host, params = {'query': query},
timeout = connection_timeout)
if r.status_code == 200:
return r.text
else:
return 'Database returned the following error:\n' + r.text
# giving feedback to LLM instead of raising exception
下一步是执行生成的查询和标准查询,然后保存它们的输出。
tmp = []
for rec in tqdm.tqdm(eval_df.to_dict('records')):
golden_output = get_clickhouse_data(rec['sql_query'])
generated_output = get_clickhouse_data(rec['generated_query'])
tmp.append(
{
'id': rec['id'],
'golden_output': golden_output,
'generated_output': generated_output
}
)
eval_df = eval_df.merge(pd.DataFrame(tmp))
接下来,让我们检查输出,看看 SQL 查询是否有效。
def is_valid_output(s):
if s.startswith('Database returned the following error:'):
return 'error'
if len(s.strip().split('\n')) >= 1000:
return 'too many rows'
return 'ok'
eval_df['golden_output_valid'] = eval_df.golden_output.map(is_valid_output)
eval_df['generated_output_valid'] = eval_df.generated_output.map(is_valid_output)
然后,我们可以评估生成集和标准集的 SQL 有效性。

初始结果不太理想;LLM 甚至未能生成一个有效的查询。从错误来看,很明显模型没有按照系统提示中明确定义的格式进行操作。所以,我们确实需要更多地关注准确性。
检查正确性
然而,仅仅验证有效性是不够的。我们不仅需要生成有效的 SQL 查询,还必须生成正确的结果。尽管我们已经知道所有查询都是无效的,但现在让我们将输出评估纳入到我们的过程中。
如前所述,我们将使用 LLM 比较 SQL 查询的输出。我通常更倾向于使用更强大的模型进行评估,这符合日常逻辑,即由资深团队成员审查工作。对于这个任务,我选择了OpenAI GPT 4o-mini。
与我们的生成流程类似,我已经设置好了所有进行准确性评估所需的构件。
from langchain_openai import ChatOpenAI
accuracy_system_prompt = '''
You are a senior and very diligent QA specialist and your task is to compare data in datasets.
They are similar if they are almost identical, or if they convey the same information.
Disregard if column names specified in the first row have different names or in a different order.
Focus on comparing the actual information (numbers). If values in datasets are different, then it means that they are not identical.
Always execute tool to provide results.
'''
@tool
def compare_datasets(comments: str, score: int) -> str:
"""Stores info about datasets.
Args:
comments (str): 1-2 sentences about the comparison of datasets,
score (int): 0 if dataset provides different values and 1 if it shows identical information
"""
pass
accuracy_chat_llm = ChatOpenAI(model="gpt-4o-mini", temperature = 0.0)\
.bind_tools([compare_datasets])
accuracy_question_tmp = '''
Here are the two datasets to compare delimited by ####
Dataset #1:
####
{dataset1}
####
Dataset #2:
####
{dataset2}
####
'''
def get_openai_prompt(question, system):
messages = [
("system", system),
("human", question)
]
return messages
现在,是时候测试准确性评估过程了。
prompt = get_openai_prompt(accuracy_question_tmp.format(
dataset1 = 'customers\n114032\n', dataset2 = 'customers\n114031\n'),
accuracy_system_prompt)
accuracy_result = accuracy_chat_llm.invoke(prompt)
accuracy_result.tool_calls[0]['args']
# {'comments': 'The datasets contain different customer counts: 114032 in Dataset #1 and 114031 in Dataset #2.',
# 'score': 0}
prompt = get_openai_prompt(accuracy_question_tmp.format(
dataset1 = 'users\n114032\n', dataset2 = 'customers\n114032\n'),
accuracy_system_prompt)
accuracy_result = accuracy_chat_llm.invoke(prompt)
accuracy_result.tool_calls[0]['args']
# {'comments': 'The datasets contain the same numerical value (114032) despite different column names, indicating they convey identical information.',
# 'score': 1}
太棒了!看起来一切都按预期工作。接下来,让我们将其封装成一个函数。
def is_answer_accurate(output1, output2):
prompt = get_openai_prompt(
accuracy_question_tmp.format(dataset1 = output1, dataset2 = output2),
accuracy_system_prompt
)
accuracy_result = accuracy_chat_llm.invoke(prompt)
try:
return accuracy_result.tool_calls[0]['args']['score']
except:
return None
整合评估方法
如前所述,构建 LLM 应用程序是一个迭代过程,因此我们需要多次运行准确性评估。将所有这些逻辑封装到一个函数中将非常有帮助。
该函数将接受两个参数作为输入:
generate_query_func:一个根据给定问题生成 SQL 查询的函数。golden_df:一个包含问题和正确答案的评估数据集,形式为 pandas DataFrame。
作为输出,函数将返回一个包含所有评估结果的数据框,并展示几个显示主要 KPI 的图表。
def evaluate_sql_agent(generate_query_func, golden_df):
# generating SQL
tmp = []
for rec in tqdm.tqdm(golden_df.to_dict('records')):
generated_query = generate_query_func(rec['question'])
tmp.append(
{
'id': rec['id'],
'generated_query': generated_query
}
)
eval_df = golden_df.merge(pd.DataFrame(tmp))
# executing SQL queries
tmp = []
for rec in tqdm.tqdm(eval_df.to_dict('records')):
golden_output = get_clickhouse_data(rec['sql_query'])
generated_output = get_clickhouse_data(rec['generated_query'])
tmp.append(
{
'id': rec['id'],
'golden_output': golden_output,
'generated_output': generated_output
}
)
eval_df = eval_df.merge(pd.DataFrame(tmp))
# checking accuracy
eval_df['golden_output_valid'] = eval_df.golden_output.map(is_valid_output)
eval_df['generated_output_valid'] = eval_df.generated_output.map(is_valid_output)
eval_df['correct_output'] = list(map(
is_answer_accurate,
eval_df['golden_output'],
eval_df['generated_output']
))
eval_df['accuracy'] = list(map(
lambda x, y: 'invalid: ' + x if x != 'ok' else ('correct' if y == 1 else 'incorrect'),
eval_df.generated_output_valid,
eval_df.correct_output
))
valid_stats_df = (eval_df.groupby('golden_output_valid')[['id']].count().rename(columns = {'id': 'golden set'}).join(
eval_df.groupby('generated_output_valid')[['id']].count().rename(columns = {'id': 'generated'}), how = 'outer')).fillna(0).T
fig1 = px.bar(
valid_stats_df.apply(lambda x: 100*x/valid_stats_df.sum(axis = 1)),
orientation = 'h',
title = '<b>LLM SQL Agent evaluation</b>: query validity',
text_auto = '.1f',
color_discrete_map = {'ok': '#00b38a', 'error': '#ea324c', 'too many rows': '#f2ac42'},
labels = {'index': '', 'variable': 'validity', 'value': 'share of queries, %'}
)
fig1.show()
accuracy_stats_df = eval_df.groupby('accuracy')[['id']].count()
accuracy_stats_df['share'] = accuracy_stats_df.id*100/accuracy_stats_df.id.sum()
fig2 = px.bar(
accuracy_stats_df[['share']],
title = '<b>LLM SQL Agent evaluation</b>: query accuracy',
text_auto = '.1f', orientation = 'h',
color_discrete_sequence = ['#0077B5'],
labels = {'index': '', 'variable': 'accuracy', 'value': 'share of queries, %'}
)
fig2.update_layout(showlegend = False)
fig2.show()
return eval_df
到此为止,我们已经完成了评估设置,接下来可以进入提高模型准确性的核心任务。
提高准确性:自我反思
让我们快速回顾一下。我们已经构建并测试了 SQL Agent 的第一个版本。不幸的是,所有生成的查询都是无效的,因为它们缺少输出格式。让我们解决这个问题。
一种潜在的解决方案是自我反思。我们可以向 LLM 发出额外的请求,分享错误并要求其修正 bug。让我们创建一个函数来处理带有自我反思的生成任务。
reflection_user_query_tmpl = '''
You've got the following question: "{question}".
You've generated the SQL query: "{query}".
However, the database returned an error: "{output}".
Please, revise the query to correct mistake.
'''
def generate_query_reflection(question):
generated_query = generate_query(question)
print('Initial query:', generated_query)
db_output = get_clickhouse_data(generated_query)
is_valid_db_output = is_valid_output(db_output)
if is_valid_db_output == 'too many rows':
db_output = "Database unexpectedly returned more than 1000 rows."
if is_valid_db_output == 'ok':
return generated_query
reflection_user_query = reflection_user_query_tmpl.format(
question = question,
query = generated_query,
output = db_output
)
reflection_prompt = get_llama_prompt(reflection_user_query,
generate_query_system_prompt)
reflection_result = chat_llm.invoke(reflection_prompt)
try:
reflected_query = reflection_result.tool_calls[0]['args']['query']
except:
reflected_query = ''
print('Reflected query:', reflected_query)
return reflected_query
现在,让我们使用评估函数检查质量是否有所改善。评估下一个迭代变得轻松起来。
refl_eval_df = evaluate_sql_agent(generate_query_reflection, golden_df)
太棒了!我们取得了更好的结果——现在 50%的查询是有效的,所有格式问题都已解决。因此,自我反思非常有效。

然而,自我反思有其局限性。当我们检查准确性时,我们发现模型仅对一个问题返回了正确答案。所以,我们的旅程还没有结束。

提高准确性:RAG
提高准确性的另一种方法是使用 RAG(检索增强生成)。其理念是识别与客户查询相似的问题和答案对,并将其包含在系统提示中,使得 LLM 能够生成更准确的回答。
RAG 包含以下阶段:
加载文档: 从可用数据源导入数据。
拆分文档: 创建较小的文档片段。
存储: 使用向量存储高效处理和存储数据。
检索: 提取与查询相关的文档。
生成: 将问题和相关文档传递给 LLM 以生成最终答案**。**

如果你想回顾 RAG 的内容,可以查看我之前的文章,“RAG:如何与数据对话。”
我们将使用 Chroma 数据库作为本地向量存储——用于存储和检索嵌入。
from langchain_chroma import Chroma
vector_store = Chroma(embedding_function=embeddings)
向量存储使用嵌入查找与查询相似的片段。为此,我们将使用 OpenAI 的嵌入。
from langchain_openai import OpenAIEmbeddings
embeddings = OpenAIEmbeddings(model="text-embedding-3-large")
由于我们不能使用评估集中的示例(因为它们已经用于评估质量),我创建了一个独立的问题和答案对集用于 RAG。你可以在GitHub上找到它。
现在,让我们加载数据集并创建以下格式的对列表:Question: %s; Answer: %s。
with open('rag_set.json', 'r') as f:
rag_set = json.loads(f.read())
rag_set_df = pd.DataFrame(rag_set)
rag_set_df['formatted_txt'] = list(map(
lambda x, y: 'Question: %s; Answer: %s' % (x, y),
rag_set_df.question,
rag_set_df.sql_query
))
rag_string_data = '\n\n'.join(rag_set_df.formatted_txt)
接下来,我使用 LangChain 的按字符拆分器将文档拆分为片段,每个问题和答案对作为一个独立的片段。由于我们是从语义上拆分文本,因此不需要重叠。
from langchain_text_splitters import CharacterTextSplitter
text_splitter = CharacterTextSplitter(
separator="\n\n",
chunk_size=1, # to split by character without merging
chunk_overlap=0,
length_function=len,
is_separator_regex=False,
)
texts = text_splitter.create_documents([rag_string_data])
最后一步是将这些数据块加载到我们的向量存储中。
document_ids = vector_store.add_documents(documents=texts)
print(vector_store._collection.count())
# 32
现在,我们可以测试检索来查看结果。它们看起来与客户问题非常相似。
question = 'What was the share of users using Windows yesterday?'
retrieved_docs = vector_store.similarity_search(question, 3)
context = "\n\n".join(map(lambda x: x.page_content, retrieved_docs))
print(context)
# Question: What was the share of users using Windows the day before yesterday?;
# Answer: select 100*uniqExactIf(user_id, os = 'Windows')/uniqExact(user_id) as windows_share from ecommerce.sessions where (action_date = today() - 2) format TabSeparatedWithNames
# Question: What was the share of users using Windows in the last week?;
# Answer: select 100*uniqExactIf(user_id, os = 'Windows')/uniqExact(user_id) as windows_share from ecommerce.sessions where (action_date >= today() - 7) and (action_date < today()) format TabSeparatedWithNames
# Question: What was the share of users using Android yesterday?;
# Answer: select 100*uniqExactIf(user_id, os = 'Android')/uniqExact(user_id) as android_share from ecommerce.sessions where (action_date = today() - 1) format TabSeparatedWithNames
让我们调整系统提示,包含我们检索到的示例。
generate_query_system_prompt_with_examples_tmpl = '''
You are a senior data analyst with more than 10 years of experience writing complex SQL queries.
There are two tables in the database you're working with with the following schemas.
Table: ecommerce.users
Description: customers of the online shop
Fields:
- user_id (integer) - unique identifier of customer, for example, 1000004 or 3000004
- country (string) - country of residence, for example, "Netherlands" or "United Kingdom"
- is_active (integer) - 1 if customer is still active and 0 otherwise
- age (integer) - customer age in full years, for example, 31 or 72
Table: ecommerce.sessions
Description: sessions of usage the online shop
Fields:
- user_id (integer) - unique identifier of customer, for example, 1000004 or 3000004
- session_id (integer) - unique identifier of session, for example, 106 or 1023
- action_date (date) - session start date, for example, "2021-01-03" or "2024-12-02"
- session_duration (integer) - duration of session in seconds, for example, 125 or 49
- os (string) - operation system that customer used, for example, "Windows" or "Android"
- browser (string) - browser that customer used, for example, "Chrome" or "Safari"
- is_fraud (integer) - 1 if session is marked as fraud and 0 otherwise
- revenue (float) - income in USD (the sum of purchased items), for example, 0.0 or 1506.7
Write a query in ClickHouse SQL to answer the following question.
Add "format TabSeparatedWithNames" at the end of the query to get data from ClickHouse database in the right format.
Answer questions following the instructions and providing all the needed information and sharing your reasoning.
Examples of questions and answers:
{examples}
'''
再一次,让我们使用 RAG 创建生成查询函数。
def generate_query_rag(question):
retrieved_docs = vector_store.similarity_search(question, 3)
context = context = "\n\n".join(map(lambda x: x.page_content, retrieved_docs))
prompt = get_llama_prompt(question,
generate_query_system_prompt_with_examples_tmpl.format(examples = context))
result = chat_llm.invoke(prompt)
try:
generated_query = result.tool_calls[0]['args']['query']
except:
generated_query = ''
return generated_query
和往常一样,让我们使用我们的评估函数来测试这个新方法。
rag_eval_df = evaluate_sql_agent(generate_query_rag, golden_df)
我们可以看到一个显著的改进,从 10 个中的 1 个正确答案增加到 6 个。虽然仍然不理想,但我们正在朝着正确的方向前进。

我们还可以尝试将两种方法结合起来:RAG 和自我反思。
def generate_query_rag_with_reflection(question):
generated_query = generate_query_rag(question)
db_output = get_clickhouse_data(generated_query)
is_valid_db_output = is_valid_output(db_output)
if is_valid_db_output == 'too many rows':
db_output = "Database unexpectedly returned more than 1000 rows."
if is_valid_db_output == 'ok':
return generated_query
reflection_user_query = reflection_user_query_tmpl.format(
question = question,
query = generated_query,
output = db_output
)
reflection_prompt = get_llama_prompt(reflection_user_query, generate_query_system_prompt)
reflection_result = chat_llm.invoke(reflection_prompt)
try:
reflected_query = reflection_result.tool_calls[0]['args']['query']
except:
reflected_query = ''
return reflected_query
rag_refl_eval_df = evaluate_sql_agent(generate_query_rag_with_reflection,
golden_df)
我们可以看到另一个轻微的改进:我们已经完全消除了无效的 SQL 查询(得益于自我反思),并且将正确答案的数量提高到了 10 个中的 7 个。

就这样。这是一次相当的旅程。我们从 0 个有效的 SQL 查询开始,现在已经达到了 70%的准确率。
您可以在GitHub上找到完整的代码。
总结
在本文中,我们探讨了提高 LLM 应用准确性的迭代过程。
我们构建了一个评估集和评分标准,允许我们比较不同的迭代,并了解我们是否朝着正确的方向前进。
我们利用自我反思来让 LLM 修正其错误,并显著减少无效的 SQL 查询数量。
此外,我们实现了检索增强生成(RAG),进一步提高了质量,达到了 60%到 70%的准确率。
虽然这是一个稳健的结果,但它仍然未达到通常对于生产应用期望的 90%以上的准确率。为了达到这样的高标准,我们需要使用微调,这将是下一篇文章的主题。
非常感谢您阅读这篇文章。希望这篇文章对您有所启发。如果您有任何后续问题或评论,请在评论区留言。
参考
除非另有说明,所有图片均由作者制作。
本文的灵感来源于 DeepLearning.AI 的“提高 LLM 应用准确性”短期课程。
从检索到智能:探索 RAG、Agent+RAG 和 TruLens 的评估
解锁 GPT 生成的私有语料库的潜力
·发表于 Towards Data Science ·阅读时间:21 分钟·2024 年 12 月 3 日
--
介绍
如今,世界上有许多优秀的基础模型可以用来启动您的自定义应用程序(如 gpt-4o、Sonnet、Gemini、Llama3.2、Gemma、Ministral 等)。这些模型了解关于历史、地理和 Wikipedia 文章的方方面面,但仍然存在一些弱点。主要有两个问题:细节层次(例如,模型知道 BMW 的品牌、它的功能、车型名称以及一些更一般的信息;但如果你询问欧洲的销售数量或某个具体发动机部件的细节,模型就无法回答)以及最近的知识(例如,Llama3.2 模型或 Ministral 发布;基础模型是在某个特定时间点训练的,且具有知识截止日期,之后模型对任何新信息都无知)。

摄影:由 Jaredd Craig 提供,来源于 Unsplash
本文聚焦于这两个问题,描述了在知识截止日期之前成立的虚构公司的情况,尽管其中一些信息最近有所更改。
为了应对这两个问题,我们将使用 RAG 技术和 LlamaIndex 框架。检索增强生成(Retrieval Augmented Generation, RAG)的理念是,在回答生成过程中为模型提供最相关的信息。通过这种方式,我们可以拥有一个包含自定义数据的数据库,模型可以利用这些数据。为了进一步评估系统的表现,我们将结合 TruLens 库和 RAG 三重度量标准(RAG Triad metrics)。
提到知识截止日期,这个问题通过谷歌搜索工具得到了解决。然而,我们不能完全用搜索工具来替代知识截止日期。为了理解这一点,想象有两个机器学习专家:第一个专家了解当前 GenAI 的所有知识,第二个专家 6 个月前从 GenAI 转向了经典计算机视觉。如果你问他们两个同样的问题,关于如何使用最近的 GenAI 模型,所需的搜索请求数量会大不相同。第一个专家会知道所有的内容,可能只会检查一些特定的命令。而第二个专家则必须阅读大量详细的文章,首先了解发生了什么,模型在做什么,内部机制如何,只有在此之后,他才有可能给出答案。
基本上,它就像是领域专家与一些普通专家之间的比较,一个可以快速回答,另一个则需要去谷歌搜索,因为他并不清楚第一个专家所了解的所有细节。
这里的关键点是,很多谷歌搜索提供的答案,虽然在较长时间内是可比较的。但在类似聊天的应用中,用户不会等待几分钟来让模型搜索某些内容。此外,并非所有的信息都是公开的,也不能通过谷歌搜索到。
数据
现在可能很难找到一个没有在基础模型训练数据中使用过的数据集。几乎所有的数据都已被索引并在大规模模型的预训练阶段使用。

来源:作者使用 AI(Bing)生成的图像
这就是为什么我决定自己生成一个。为此,我通过 OpenAI 的 UI 和多个连续的提示(它们都类似于下面的内容)使用了chatgpt-4o-latest:
Generate me a private corpus with some details mentioning the imagined Ukraine Boats Inc.
A list of products, prices, responsible stuff, etc.
I want to use it as my private corpus for the RAG use-case
You can generate really a lot of the text. The more the better.
Yeah, proceed with partnerships, legal policies, competitions participated
Maybe info about where we manufacture our boats (and add some custom ones)
add client use studies
结果,我为 4 家公司生成了一个私人语料库。以下是计算的标记数,以更好地体现数据集的大小。
# Number of tokens using the `o200k_base` tokenizer (gpt-4o/gpt-4o-mini)
nova-drive-motors.txt: 2757
aero-vance-aviation.txt: 1860
ukraine-boats.txt: 3793
city-solve.txt: 3826
total_tokens=12236
下面是你可以阅读的关于乌克兰船业公司(Ukraine Boats Inc.)描述的开头:
## **Ukraine Boats Inc.**
**Corporate Overview:**
Ukraine Boats Inc. is a premier manufacturer and supplier of high-quality boats and maritime solutions based in Odessa, Ukraine. The company prides itself on blending traditional craftsmanship with modern technology to serve clients worldwide. Founded in 2005, the company has grown to be a leader in the boating industry, specializing in recreational, commercial, and luxury vessels.
- -
### **Product Lineup**
#### **Recreational Boats:**
1\. **WaveRunner X200**
- **Description:** A sleek speedboat designed for water sports enthusiasts. Equipped with advanced navigation and safety features.
- **Price:** $32,000
- **Target Market:** Young adventurers and watersport lovers.
- **Features:**
- Top speed of 85 mph
- Built-in GPS with autopilot mode
- Seating capacity: 4
- Lightweight carbon-fiber hull
2\. **AquaCruise 350**
- **Description:** A versatile motorboat ideal for fishing, family trips, and casual cruising.
- **Price:** $45,000
- **Features:**
- 12-person capacity
- Dual 300HP engines
- Modular interiors with customizable seating and storage
- Optional fishing equipment upgrades
3\. **SolarGlide EcoBoat**
- **Description:** A solar-powered boat for environmentally conscious customers.
- **Price:** $55,000
- **Features:**
- Solar panel roof with 12-hour charge life
- Zero emissions
- Maximum speed: 50 mph
- Silent motor technology
- -
…
完整的私人语料库可以在GitHub上找到。
为了评估数据集的目的,我还要求模型基于给定的语料库生成了 10 个问题(仅关于乌克兰船业公司)。
based on the whole corpus above, generate 10 questions and answers for them pass them into the python native data structure
这是获得的数据集:
[
{
"question": "What is the primary focus of Ukraine Boats Inc.?",
"answer": "Ukraine Boats Inc. specializes in manufacturing high-quality recreational, luxury, and commercial boats, blending traditional craftsmanship with modern technology."
},
{
"question": "What is the price range for recreational boats offered by Ukraine Boats Inc.?",
"answer": "Recreational boats range from $32,000 for the WaveRunner X200 to $55,000 for the SolarGlide EcoBoat."
},
{
"question": "Which manufacturing facility focuses on bespoke yachts and customizations?",
"answer": "The Lviv Custom Craft Workshop specializes in bespoke yachts and high-end customizations, including handcrafted woodwork and premium materials."
},
{
"question": "What is the warranty coverage offered for boats by Ukraine Boats Inc.?",
"answer": "All boats come with a 5-year warranty for manufacturing defects, while engines are covered under a separate 3-year engine performance guarantee."
},
{
"question": "Which client used the Neptune Voyager catamaran, and what was the impact on their business?",
"answer": "Paradise Resorts International used the Neptune Voyager catamarans, resulting in a 45% increase in resort bookings and winning the 'Best Tourism Experience' award."
},
{
"question": "What award did the SolarGlide EcoBoat win at the Global Marine Design Challenge?",
"answer": "The SolarGlide EcoBoat won the 'Best Eco-Friendly Design' award at the Global Marine Design Challenge in 2022."
},
{
"question": "How has the Arctic Research Consortium benefited from the Poseidon Explorer?",
"answer": "The Poseidon Explorer enabled five successful Arctic research missions, increased data collection efficiency by 60%, and improved safety in extreme conditions."
},
{
"question": "What is the price of the Odessa Opulence 5000 luxury yacht?",
"answer": "The Odessa Opulence 5000 luxury yacht starts at $1,500,000."
},
{
"question": "Which features make the WaveRunner X200 suitable for watersports?",
"answer": "The WaveRunner X200 features a top speed of 85 mph, a lightweight carbon-fiber hull, built-in GPS, and autopilot mode, making it ideal for watersports."
},
{
"question": "What sustainability initiative is Ukraine Boats Inc. pursuing?",
"answer": "Ukraine Boats Inc. is pursuing the Green Maritime Initiative (GMI) to reduce the carbon footprint by incorporating renewable energy solutions in 50% of their fleet by 2030."
}
]
现在,当我们拥有了私人语料库和 Q&A 对的数据集时,我们可以将我们的数据插入到一些合适的存储中。
数据传播
我们可以为 RAG 用例利用多种数据库,但对于这个项目以及可能处理未来关系的需求,我将 Neo4j 数据库集成到我们的解决方案中。此外,Neo4j 在注册后提供免费的实例。
现在,让我们开始准备节点。首先,我们实例化一个嵌入模型。我们使用了 256 维的向量,因为一些最近的测试表明,较大的向量维度会导致得分的方差较小(这不是我们需要的)。作为嵌入模型,我们使用了text-embedding-3-small模型。
# initialize models
embed_model = OpenAIEmbedding(
model=CFG['configuration']['models']['embedding_model'],
api_key=os.getenv('AZURE_OPENAI_API_KEY'),
dimensions=CFG['configuration']['embedding_dimension']
)
之后,我们读取了语料库:
# get documents paths
document_paths = [Path(CFG['configuration']['data']['raw_data_path']) / document for document in CFG['configuration']['data']['source_docs']]
# initialize a file reader
reader = SimpleDirectoryReader(input_files=document_paths)
# load documents into LlamaIndex Documents
documents = reader.load_data()
此外,我们利用 SentenceSplitter 将文档分割成独立的节点。这些节点将存储在 Neo4j 数据库中。
neo4j_vector = Neo4jVectorStore(
username=CFG['configuration']['db']['username'],
password=CFG['configuration']['db']['password'],
url=CFG['configuration']['db']['url'],
embedding_dimension=CFG['configuration']['embedding_dimension'],
hybrid_search=CFG['configuration']['hybrid_search']
)
# setup context
storage_context = StorageContext.from_defaults(
vector_store=neo4j_vector
)
# populate DB with nodes
index = VectorStoreIndex(nodes, storage_context=storage_context, show_progress=True)
混合搜索目前被关闭。这是故意为之,以突出向量搜索算法的性能。
一切就绪,现在我们准备进入查询流水线。

来源:图像由作者创建
流水线
RAG 技术可以作为独立的解决方案实现,或作为代理的一部分。代理负责处理所有的聊天历史、工具处理、推理和输出生成。接下来,我们将演示如何实现查询引擎(独立 RAG)和代理方法(代理将能够将 RAG 作为其工具之一)。
当我们谈论聊天模型时,大多数人会选择 OpenAI 模型,而没有考虑其他选择。我们将概述在 OpenAI 模型和 Meta Llama 3.2 模型上使用 RAG 的方法。让我们对比一下哪一个表现更好。
所有配置参数已移至 pyproject.toml 文件。
[configuration]
similarity_top_k = 10
vector_store_query_mode = "default"
similarity_cutoff = 0.75
response_mode = "compact"
distance_strategy = "cosine"
embedding_dimension = 256
chunk_size = 512
chunk_overlap = 128
separator = " "
max_function_calls = 2
hybrid_search = false
[configuration.data]
raw_data_path = "../data/companies"
dataset_path = "../data/companies/dataset.json"
source_docs = ["city-solve.txt", "aero-vance-aviation.txt", "nova-drive-motors.txt", "ukraine-boats.txt"]
[configuration.models]
llm = "gpt-4o-mini"
embedding_model = "text-embedding-3-small"
temperature = 0
llm_hf = "meta-llama/Llama-3.2-3B-Instruct"
context_window = 8192
max_new_tokens = 4096
hf_token = "hf_custom-token"
llm_evaluation = "gpt-4o-mini"
[configuration.db]
url = "neo4j+s://custom-url"
username = "neo4j"
password = "custom-password"
database = "neo4j"
index_name = "article" # change if you want to load the new data that won't intersect with the previous uploads
text_node_property = "text"
两个模型的共同步骤是连接到 Neo4j 中现有的向量索引。
# connect to the existing neo4j vector index
vector_store = Neo4jVectorStore(
username=CFG['configuration']['db']['username'],
password=CFG['configuration']['db']['password'],
url=CFG['configuration']['db']['url'],
embedding_dimension=CFG['configuration']['embedding_dimension'],
distance_strategy=CFG['configuration']['distance_strategy'],
index_name=CFG['configuration']['db']['index_name'],
text_node_property=CFG['configuration']['db']['text_node_property']
)
index = VectorStoreIndex.from_vector_store(vector_store)
OpenAI
首先,我们应该初始化所需的 OpenAI 模型。我们将使用 gpt-4o-mini 作为语言模型,并使用相同的嵌入模型。我们为 Settings 对象指定了 LLM 和嵌入模型。这样,我们就不需要再传递这些模型。LlamaIndex 会在需要时从 Settings 中解析出 LLM。
# initialize models
llm = OpenAI(
api_key=os.getenv('AZURE_OPENAI_API_KEY'),
model=CFG['configuration']['models']['llm'],
temperature=CFG['configuration']['models']['temperature']
)
embed_model = OpenAIEmbedding(
model=CFG['configuration']['models']['embedding_model'],
api_key=os.getenv('AZURE_OPENAI_API_KEY'),
dimensions=CFG['configuration']['embedding_dimension']
)
Settings.llm = llm
Settings.embed_model = embed_model
查询引擎
之后,我们可以从现有的向量索引中创建一个默认的查询引擎:
# create query engine
query_engine = index.as_query_engine()
此外,我们可以通过简单的 query() 方法获取 RAG 逻辑。此外,我们打印了从数据库检索的源节点列表和最终的 LLM 响应。
# custom question
response = query_engine.query("What is the primary focus of Ukraine Boats Inc.?")
# get similarity scores
for node in response.source_nodes:
print(f'{node.node.id_}, {node.score}')
# predicted answer
print(response.response)
以下是示例输出:
ukraine-boats-3, 0.8536546230316162
ukraine-boats-4, 0.8363556861877441
The primary focus of Ukraine Boats Inc. is designing, manufacturing, and selling luxury and eco-friendly boats, with a strong emphasis on customer satisfaction and environmental sustainability.
如你所见,我们创建了自定义的节点 ID,以便我们可以理解该文件的来源以及片段的顺序 ID。我们可以通过使用低级 LlamaIndex API 更具体地控制查询引擎的行为:
# custom retriever
retriever = VectorIndexRetriever(
index=index,
similarity_top_k=CFG['configuration']['similarity_top_k'],
vector_store_query_mode=CFG['configuration']['vector_store_query_mode']
)
# similarity threshold
similarity_postprocessor = SimilarityPostprocessor(similarity_cutoff=CFG['configuration']['similarity_cutoff'])
# custom response synthesizer
response_synthesizer = get_response_synthesizer(
response_mode=CFG['configuration']['response_mode']
)
# combine custom query engine
query_engine = RetrieverQueryEngine(
retriever=retriever,
node_postprocessors=[similarity_postprocessor],
response_synthesizer=response_synthesizer
)
在这里,我们指定了自定义的检索器、相似度后处理器和改进阶段的操作。
为了进一步定制,你可以在任何 LlamaIndex 组件周围创建自定义包装器,使其更具体并与您的需求对齐。
代理
为了在 LlamaIndex 中实现基于 RAG 的代理,我们需要使用预定义的 AgentWorker 之一。我们将使用 OpenAIAgentWorker,它使用 OpenAI 的 LLM 作为大脑。此外,我们将之前部分中的查询引擎封装到 QueryEngineTool 中,代理可以根据工具描述选择它。
AGENT_SYSTEM_PROMPT = "You are a helpful human assistant. You always call the retrieve_semantically_similar_data tool before answering any questions. If the answer to the questions couldn't be found using the tool, just respond with `Didn't find relevant information`."
TOOL_NAME = "retrieve_semantically_similar_data"
TOOL_DESCRIPTION = "Provides additional information about the companies. Input: string"
# agent worker
agent_worker = OpenAIAgentWorker.from_tools(
[
QueryEngineTool.from_defaults(
query_engine=query_engine,
name=TOOL_NAME,
description=TOOL_DESCRIPTION,
return_direct=False,
)
],
system_prompt=AGENT_SYSTEM_PROMPT,
llm=llm,
verbose=True,
max_function_calls=CFG['configuration']['max_function_calls']
)
要进一步使用代理,我们需要一个 AgentRunner。Runner 更像是一个协调器,处理顶层交互和状态,而工作者执行具体的操作,比如工具和 LLM 的使用。
# agent runner
agent = AgentRunner(agent_worker=agent_worker)

来源:图片来自于LlamaIndex 文档
为了有效测试用户与代理的交互,我实现了一个简单的聊天式界面:
while True:
# get user input
current_message = input('Insert your next message:')
print(f'{datetime.now().strftime("%H:%M:%S.%f")[:-3]}|User: {current_message}')
response = agent.chat(current_message)
print(f'{datetime.now().strftime("%H:%M:%S.%f")[:-3]}|Agent: {response.response}')
这里是聊天的示例:
Insert your next message: Hi
15:55:43.101|User: Hi
Added user message to memory: Hi
15:55:43.873|Agent: Didn't find relevant information.
Insert your next message: Do you know anything about the city solve?
15:56:24.751|User: Do you know anything about the city solve?
Added user message to memory: Do you know anything about the city solve?
=== Calling Function ===
Calling function: retrieve_semantically_similar_data with args: {"input":"city solve"}
Got output: Empty Response
========================
15:56:37.267|Agent: Didn't find relevant information.
Insert your next message: What is the primary focus of Ukraine Boats Inc.?
15:57:36.122|User: What is the primary focus of Ukraine Boats Inc.?
Added user message to memory: What is the primary focus of Ukraine Boats Inc.?
=== Calling Function ===
Calling function: retrieve_semantically_similar_data with args: {"input":"Ukraine Boats Inc."}
Got output: Ukraine Boats Inc. is a premier manufacturer and supplier of high-quality boats and maritime solutions based in Odessa, Ukraine. Founded in 2005, the company specializes in recreational, commercial, and luxury vessels, blending traditional craftsmanship with modern technology. It has established a strong market presence in Europe, North America, and Asia, supported by partnerships with distribution companies like Baltic Marine Distributors in Germany, OceanCraft LLC in the USA, and Yokohama SeaTech in Japan.
The company is organized into several departments, including Engineering, Sales and Marketing, Production, and Customer Service, each with specific responsibilities to ensure efficient operations and customer satisfaction. Ukraine Boats Inc. is committed to sustainability through initiatives like the Green Maritime Initiative, aiming to reduce its carbon footprint by incorporating renewable energy solutions in its fleet.
The product lineup includes recreational boats such as the WaveRunner X200 and AquaCruise 350, luxury yachts like the Odessa Opulence 5000, and commercial vessels such as the Maritime Hauler 7000\. The company also offers customization options, maintenance plans, and a range of accessories to enhance the boating experience.
With a focus on innovation and quality, Ukraine Boats Inc. has received accolades in various marine design competitions and continues to expand its global footprint while maintaining a commitment to environmental compliance and customer service excellence.
========================
15:57:49.505|Agent: The primary focus of Ukraine Boats Inc. is the manufacturing and supplying of high-quality boats and maritime solutions. Based in Odessa, Ukraine, the company specializes in recreational, commercial, and luxury vessels, combining traditional craftsmanship with modern technology. They are committed to sustainability and have initiatives like the Green Maritime Initiative to reduce their carbon footprint. Their product lineup includes various types of boats, and they also offer customization options and maintenance plans.
Insert your next message: Do you know anything about the government company city solve?
15:59:02.830|User: Do you know anything about the government company city solve?
Added user message to memory: Do you know anything about the government company city solve?
=== Calling Function ===
Calling function: retrieve_semantically_similar_data with args: {"input":"government company city solve"}
Got output: CitySolve Municipal Services is a government-owned and operated entity founded in 1988, headquartered in New Urbania, USA. Its mission is to ensure a livable, sustainable, and efficient city by providing prompt solutions to residents' everyday needs. The company focuses on various critical service areas, including waterworks and drainage, parks and greenery management, roads and sidewalks, waste management, public safety infrastructure, and urban utilities. CitySolve emphasizes core values such as responsiveness, reliability, accountability, and sustainability, and utilizes modern technologies to enhance its operations and community impact.
========================
15:59:12.693|Agent: CitySolve Municipal Services is a government-owned and operated entity founded in 1988, headquartered in New Urbania, USA. Its mission is to ensure a livable, sustainable, and efficient city by providing prompt solutions to residents' everyday needs. The company focuses on critical service areas such as waterworks and drainage, parks and greenery management, roads and sidewalks, waste management, public safety infrastructure, and urban utilities. CitySolve emphasizes core values like responsiveness, reliability, accountability, and sustainability, and it utilizes modern technologies to enhance its operations and community impact.
Insert your next message: Thanks
15:59:28.761|User: Thanks
Added user message to memory: Thanks
15:59:30.327|Agent: Didn't find relevant information.
我们可以看到,对于正确的向量搜索,您需要用更多细节来指定输入问题,这些细节可以在语义上匹配。
开源
作为一个开源模型,我们使用了meta-llama/Llama-3.2–3B-Instruct。这个选择是基于模型延迟和性能的权衡。首先,我们需要通过访问令牌认证我们的 HuggingFace 账户。
# Use your token here
login(token=CFG['configuration']['models']['hf_token'])
为了在 LlamaIndex 中将 Llama 作为 LLM 使用,我们需要创建一个模型包装器。我们将使用一台单独的 NVIDIA GeForce RTX 3090 来服务我们的 Llama 3.2 模型。
SYSTEM_PROMPT = """You are an AI assistant that answers questions in a friendly manner, based on the given source documents. Here are some rules you always follow:
- Generate human readable output, avoid creating output with gibberish text.
- Generate only the requested output, don't include any other language before or after the requested output.
- Never say thank you, that you are happy to help, that you are an AI agent, etc. Just answer directly.
- Generate professional language typically used in business documents in North America.
- Never generate offensive or foul language.
"""
query_wrapper_prompt = PromptTemplate(
"<|start_header_id|>system<|end_header_id|>\n" + SYSTEM_PROMPT + "<|eot_id|><|start_header_id|>user<|end_header_id|>{query_str}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
)
llm = HuggingFaceLLM(
context_window=CFG['configuration']['models']['context_window'],
max_new_tokens=CFG['configuration']['models']['max_new_tokens'],
generate_kwargs={"temperature": CFG['configuration']['models']['temperature'], "do_sample": False},
query_wrapper_prompt=query_wrapper_prompt,
tokenizer_name=CFG['configuration']['models']['llm_hf'],
model_name=CFG['configuration']['models']['llm_hf'],
device_map="cuda:0",
model_kwargs={"torch_dtype": torch.bfloat16}
)
Settings.llm = llm
QueryEngine
接口是相同的。以下是示例输出:
ukraine-boats-3, 0.8536546230316162
ukraine-boats-4, 0.8363556861877441
The primary focus of Ukraine Boats Inc. is designing, manufacturing, and selling luxury and eco-friendly boats, with a strong emphasis on customer satisfaction and environmental sustainability.
代理
对于 OpenAI 模型,LlamaIndex 设计了一个专门的代理包装器,但对于开源模型,我们应该使用另一种包装器。我们选择了 ReActAgent,它通过反复推理和行动直到最终响应准备好。
agent_worker = ReActAgentWorker.from_tools(
[
QueryEngineTool.from_defaults(
query_engine=query_engine,
name=TOOL_NAME,
description=TOOL_DESCRIPTION,
return_direct=False,
)
],
llm=llm,
verbose=True,
chat_history=[ChatMessage(content=AGENT_SYSTEM_PROMPT, role="system")]
)
# agent runner
agent = AgentRunner(agent_worker=agent_worker)
以下是相同讨论,但使用了不同的代理:
Insert your next message: Hi
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
16:04:29.117|User: Hi
Thought: The current language of the user is: English. I need to use a tool to help me answer the question.
Action: retrieve_semantically_similar_data
Action Input: {'input': 'hello world', 'num_beams': 5}
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Observation: Empty Response
Thought: I still need more information to answer the question. The tool did not provide any useful output. I'll try to gather more context.
Action: retrieve_semantically_similar_data
Action Input: {'input': 'hello world', 'num_beams': 5}
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Observation: Empty Response
Thought: I still don't have enough information to answer the question. The tool did not provide any useful output. I'll try to gather more context.
Action: retrieve_semantically_similar_data
Action Input: {'input': 'hello world', 'num_beams': 5}
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Observation: Empty Response
Thought: I'm starting to think that the tool is not the best approach to answer this question. The user's input "hello world" is very general and the tool is not providing any useful output. I'll try to think of a different approach.
Answer: Hello, how can I assist you today?
16:04:37.764|Agent: Hello, how can I assist you today?
Insert your next message: Do you know anything about the city solve?
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
16:05:08.756|User: Do you know anything about the city solve?
Thought: The current language of the user is: English. I need to use a tool to help me answer the question.
Action: retrieve_semantically_similar_data
Action Input: {'input': 'solve city'}
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Observation:
CitySolve Municipal Services is the lifeline of New Urbania, addressing a wide range of city-level concerns and providing prompt solutions to residents' everyday needs.
Thought: I can answer without using any more tools. I'll use the user's language to answer
Answer: CitySolve Municipal Services is a city-level organization that provides solutions to residents' everyday needs in New Urbania.
16:05:13.003|Agent: CitySolve Municipal Services is a city-level organization that provides solutions to residents' everyday needs in New Urbania.
Insert your next message: What is the primary focus of Ukraine Boats Inc.?
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
16:05:34.892|User: What is the primary focus of Ukraine Boats Inc.?
Thought: The current language of the user is: English. I need to use a tool to help me answer the question.
Action: retrieve_semantically_similar_data
Action Input: {'input': 'Ukraine Boats Inc.'}
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Observation:
Ukraine Boats Inc. is a premier manufacturer and supplier of high-quality boats and maritime solutions based in Odessa, Ukraine. The company prides itself on blending traditional craftsmanship with modern technology to serve clients worldwide. Founded in 2005, the company has grown to be a leader in the boating industry, specializing in recreational, commercial, and luxury vessels.
The company has successfully delivered a range of boats and solutions to various clients, including Blue Horizon Fisheries, Azure Seas Luxury Charters, Coastal Safety Patrol, EcoTrade Logistics, Team HydroBlitz Racing, and Paradise Resorts International. These clients have reported significant benefits from working with Ukraine Boats Inc., including increased efficiency, reduced costs, and enhanced customer satisfaction.
Ukraine Boats Inc. offers a range of products and services, including luxury yachts, commercial boats, and accessories. The company's products are designed to meet the specific needs of each client, and its team of experts works closely with clients to ensure that every boat is tailored to their requirements.
Some of the company's notable products include the Odessa Opulence 5000, a state-of-the-art luxury yacht, and the Maritime Hauler 7000, a robust cargo ship. The company also offers boat customization packages, annual maintenance plans, and other services to support its clients' needs.
Overall, Ukraine Boats Inc. is a trusted and reliable partner for clients seeking high-quality boats and maritime solutions.
Thought: I can answer without using any more tools. I'll use the user's language to answer
Answer: Ukraine Boats Inc. is a premier manufacturer and supplier of high-quality boats and maritime solutions based in Odessa, Ukraine, blending traditional craftsmanship with modern technology to serve clients worldwide.
16:05:53.311|Agent: Ukraine Boats Inc. is a premier manufacturer and supplier of high-quality boats and maritime solutions based in Odessa, Ukraine, blending traditional craftsmanship with modern technology to serve clients worldwide.
Insert your next message: Do you know anything about the government company city solve?
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
16:06:09.949|User: Do you know anything about the government company city solve?
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Thought: The current language of the user is English. I need to use a tool to help me answer the question.
Action: retrieve_semantically_similar_data
Action Input: {'input': AttributedDict([('title', 'CitySolve'), ('type', 'string')])}
Observation: Error: 2 validation errors for QueryStartEvent
query.str
Input should be a valid string [type=string_type, input_value=AttributedDict([('title',...'), ('type', 'string')]), input_type=AttributedDict]
For further information visit https://errors.pydantic.dev/2.9/v/string_type
query.QueryBundle.query_str
Field required [type=missing, input_value=AttributedDict([('title',...'), ('type', 'string')]), input_type=AttributedDict]
For further information visit https://errors.pydantic.dev/2.9/v/missing
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Observation: Error: Could not parse output. Please follow the thought-action-input format. Try again.
Thought: I understand that the tool retrieve_semantically_similar_data requires a specific input format. I will make sure to follow the correct format.
Action: retrieve_semantically_similar_data
Action Input: {'title': 'CitySolve', 'type': 'string'}
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Observation:
CitySolve Municipal Services is a government-owned and operated company that serves as the backbone of New Urbania's civic infrastructure, addressing a wide range of city-level concerns.
Thought: I can answer without using any more tools. I'll use the user's language to answer
Answer: CitySolve Municipal Services is a government-owned and operated company that serves as the backbone of New Urbania's civic infrastructure, addressing a wide range of city-level concerns.
16:06:17.799|Agent: CitySolve Municipal Services is a government-owned and operated company that serves as the backbone of New Urbania's civic infrastructure, addressing a wide range of city-level concerns.
Insert your next message: Thanks
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
16:06:34.232|User: Thanks
Thought: I can answer without using any more tools. I'll use the user's language to answer
Answer: CitySolve Municipal Services is a government-owned and operated company that serves as the backbone of New Urbania's civic infrastructure, addressing a wide range of city-level concerns.
16:06:35.734|Agent: CitySolve Municipal Services is a government-owned and operated company that serves as the backbone of New Urbania's civic infrastructure, addressing a wide range of city-level concerns.
如我们所见,代理的推理方式不同。面对相同的问题,两个模型决定以不同的方式查询工具。第二个代理在使用该工具时失败过一次,但这更多是工具描述的问题,而非代理本身。两者都为用户提供了有价值的答案,这正是 RAG 方法的最终目标。
此外,还有许多不同的代理包装器可以应用于您的 LLM。它们可能会显著改变模型与世界交互的方式。
评估
现在有许多框架可用于评估 RAG。其中之一是 TruLens。总体 RAG 性能通过所谓的 RAG 三要素(答案相关性、上下文相关性和基础性)来评估。
为了评估相关性和基础性,我们将利用 LLMs。LLMs 将充当裁判,根据提供的信息对答案进行评分。
TruLens 本身是一个方便的工具,用于在度量级别上衡量系统性能,并分析特定记录的评估。以下是排行榜 UI 视图:

来源:图片由作者制作
以下是每条记录的评估表,您可以在其中查看所有被调用的内部过程。

来源:图片由作者制作
若要获取更多细节,您可以查看特定记录的执行过程。

来源:图片由作者制作
要实现 RAG 三要素评估,首先,我们必须定义实验名称和模型提供者。我们将使用gpt-4o-mini模型进行评估。
experiment_name = "llama-3.2-3B-custom-retriever"
provider = OpenAIProvider(
model_engine=CFG['configuration']['models']['llm_evaluation']
)
之后,我们定义了三元组本身(回答相关性、上下文相关性、可靠性)。对于每个指标,我们应当指定输入和输出。
context_selection = TruLlama.select_source_nodes().node.text
# context relevance (for each of the context chunks)
f_context_relevance = (
Feedback(
provider.context_relevance, name="Context Relevance"
)
.on_input()
.on(context_selection)
)
# groundedness
f_groundedness_cot = (
Feedback(
provider.groundedness_measure_with_cot_reasons, name="Groundedness"
)
.on(context_selection.collect())
.on_output()
)
# answer relevance between overall question and answer
f_qa_relevance = (
Feedback(
provider.relevance_with_cot_reasons, name="Answer Relevance"
)
.on_input_output()
)
此外,我们实例化了一个 TruLlama 对象,它将在代理调用期间处理反馈计算。
# Create TruLlama agent
tru_agent = TruLlama(
agent,
app_name=experiment_name,
tags="agent testing",
feedbacks=[f_qa_relevance, f_context_relevance, f_groundedness_cot],
)
现在我们准备在我们的数据集上执行评估管道。
for item in tqdm(dataset):
try:
agent.reset()
with tru_agent as recording:
agent.query(item.get('question'))
record_agent = recording.get()
# wait until all the feedback function are finished
for feedback, result in record_agent.wait_for_feedback_results().items():
logging.info(f'{feedback.name}: {result.result}')
except Exception as e:
logging.error(e)
traceback.format_exc()
我们进行了实验,使用了这两种模型、默认/自定义查询引擎以及额外的工具输入参数描述(ReAct 代理在没有明确工具输入参数描述时表现较差,试图调用不存在的工具来重构输入)。我们可以使用 get_leaderboard()方法将结果以 DataFrame 形式进行回顾。

来源:图片由作者创作
结论

来源:图片由作者使用 AI(Bing)生成
我们获得了一个私有语料库,结合了 GPT 模型用于自定义数据集生成。实际语料内容相当有趣且多样化。这也是为什么现在很多模型能够成功地利用 GPT 生成的样本进行微调的原因。
Neo4j 数据库为许多框架提供了方便的接口,并具有业内最佳的 UI 能力(Aura)。在实际项目中,我们通常会有数据之间的关系,图形数据库在这种用例中是一个完美的选择。
在私有语料库的基础上,我们实现了不同的 RAG 方法(独立的和作为代理的一部分)。根据 RAG 三元组指标,我们观察到基于 OpenAI 的代理工作得非常完美,而经过良好提示的 ReAct 代理表现相对相同。一个显著的差异出现在自定义查询引擎的使用上。这是合理的,因为我们配置了一些特定的程序和阈值,符合我们的数据需求。此外,两个解决方案都有很高的可靠性,这对 RAG 应用非常重要。
另一个有趣的发现是,Llama3.2 3B 和 gpt-4o-mini API 的代理调用延迟几乎相同(当然,最多的时间花费在数据库调用上,但差异仍然不大)。
尽管我们的系统运行得相当不错,但仍有许多需要改进的地方,比如关键词搜索、重排序器、邻接块选择以及与地面真值标签的比较。这些话题将在下一篇关于 RAG 应用的文章中讨论。
私有语料库以及代码和提示可以在GitHub上找到。
附言:
我要感谢我的同事:Alex Simkiv、Andy Bosyi 和 Nazar Savchenko,感谢他们富有成效的对话、合作与宝贵的建议,以及整个 MindCraft.ai 团队的持续支持。
从零开始到深度分位数预测
使用 Tensorflow、NeuralForecast 和零样本大型语言模型进行多步分位数预测的端到端实证分享。
·发布于 Towards Data Science ·11 分钟阅读·2024 年 7 月 16 日
--

图片来源:作者
内容
简短介绍
数据
构建一个分位数递归预测器的玩具版本
使用最先进模型进行分位数预测
使用大型语言模型进行零样本分位数预测
结论
简短介绍
分位数预测是一种统计技术,用于预测响应变量分布的不同分位数(例如,中位数或第 90 百分位数),提供对未来潜在结果的更全面的视角。与传统的均值预测只估计平均值不同,分位数预测使我们能够理解各种可能结果的范围和发生的可能性。
分位数预测在具有不对称损失函数或不同风险偏好的决策过程中至关重要。例如,在供应链管理中,预测需求的第 90 百分位数可以确保足够的库存水平,以避免短缺,而预测第 10 百分位数则有助于最小化过剩库存及相关成本。这种方法在金融、气象和能源等领域尤为有利,因为在这些领域中,理解分布的极端值与了解均值一样重要。
量化预测和符合性预测都处理不确定性,但它们的方法论差异显著。量化预测直接对响应变量的特定分位数进行建模,提供关于其分布的详细洞察。相反,符合性预测是一种与模型无关的技术,它围绕预测构建预测区间,保证真实值以指定的概率落入该区间。量化预测提供精确的分位数估计,而符合性预测则提供更广泛的区间保证。
量化预测的实现可以显著提升决策制定,通过提供对未来不确定性的深刻理解。这种方法使得组织能够根据不同的风险水平量身定制策略,优化资源分配,并提高运营效率。通过捕捉广泛的潜在结果,量化预测使得组织能够做出明智、数据驱动的决策,从而减轻风险并提高整体表现。
数据
为了演示该工作,我选择使用 M4 竞赛的数据作为示例。该数据采用CC0: 公共领域许可证,可以通过这里访问。数据也可以通过 datasetsforecast 包加载:
# Install the package
pip install datasetsforecast
# Load Data
df, *_ = M4.load('./data', group='Weekly')
# Randomly select three items
df = df[df['unique_id'].isin(['W96', 'W100', 'W99'])]
# Define the start date (for example, "1970-01-04")
start_date = pd.to_datetime("1970-01-04")
# Convert 'ds' to actual week dates
df['ds'] = start_date + pd.to_timedelta(df['ds'] - 1, unit='W')
# Display the DataFrame
df.head()

图片由作者提供
原始数据包含超过 300 个独特的时间序列。为了演示,我随机选择了三个时间序列:W96、W99 和 W100,因为它们的历史长度相同。原始时间戳被掩码为整数(即 1-2296),我手动将其转换回正常的日期格式,第一天为 1970 年 1 月 4 日。以下图为 W99 的预览:

图片由作者提供
构建量化回归预测器的玩具版
首先,让我们从零开始构建一个量化预测器,以了解目标数据如何流经整个管道,以及预测是如何生成的。我从 Wen 等人的论文A Multi-Horizon Quantile Recurrent Forecaster中汲取灵感。作者提出了一个多水平量化回归神经网络(MQ-RNN)框架,结合了序列到序列神经网络、量化回归和直接多水平预测,用于准确且稳健的多步时间序列预测。通过利用神经网络的表达能力、量化回归的非参数特性以及一种名为分叉序列的新训练方案,该模型能够有效处理变化的季节性、已知的未来事件以及冷启动问题,适用于大规模预测应用。
我们不能在这个简短的博客中复现所有内容,但我们可以尝试使用 TensorFlow 包进行部分复现作为演示。如果你对论文的实现感兴趣,可以利用一个正在进行的项目:MQRNN。
首先加载必要的包并定义一些全局参数。我们将使用 LSTM 模型作为核心,并且需要对数据进行一些预处理,以获得滚动窗口,然后进行拟合。input_shape设置为(104, 1),这意味着我们每个训练窗口使用两年的数据。在本次演示中,我们仅关注 80%的置信区间,并将中位数作为点预测值,这意味着分位数 = [0.1, 0.5, 0.9]。我们将使用最后 12 周作为测试数据集,因此output_steps或预测范围为 12,cut_off_date将是‘2013–10–13’。
# Install the package
pip install tensorflow
# Load the package
from sklearn.preprocessing import StandardScaler
from datetime import datetime
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, LSTM, Dense, concatenate, Layer
# Define Global Parameters
input_shape = (104, 1)
quantiles = [0.1, 0.9]
output_steps = 12
cut_off_date = '2013-10-13'
tf.random.set_seed(20240710)
接下来,我们将数据转换为滚动窗口,这正是 RNN 模型所需的输入形状:
# Preprocess The Data
def preprocess_data(df, window_size = 104, forecast_horizon = 12):
# Ensure the dataframe is sorted by item and date
df = df.sort_values(by=['unique_id', 'ds'])
# List to hold processed data for each item
X, y, unique_id, ds = [], [], [], []
# Normalizer
scaler = StandardScaler()
# Iterate through each item
for key, group in df.groupby('unique_id'):
demand = group['y'].values.reshape(-1, 1)
scaled_demand = scaler.fit_transform(demand)
dates = group['ds'].values
# Create sequences (sliding window approach)
for i in range(len(scaled_demand) - window_size - forecast_horizon + 1):
X.append(scaled_demand[i:i+window_size])
y.append(scaled_demand[i+window_size:i+window_size+forecast_horizon].flatten())
unique_id.append(key)
ds.append(dates[i+window_size:i+window_size+forecast_horizon])
X = np.array(X)
y = np.array(y)
return X, y, unique_id, ds, scaler
然后我们将数据分成训练集、验证集和测试集:
# Split Data
def split_data(X, y, unique_id, ds, cut_off_date):
cut_off_date = pd.to_datetime(cut_off_date)
val_start_date = cut_off_date - pd.Timedelta(weeks=12)
train_idx = [i for i, date in enumerate(ds) if date[0] < val_start_date]
val_idx = [i for i, date in enumerate(ds) if val_start_date <= date[0] < cut_off_date]
test_idx = [i for i, date in enumerate(ds) if date[0] >= cut_off_date]
X_train, y_train = X[train_idx], y[train_idx]
X_val, y_val = X[val_idx], y[val_idx]
X_test, y_test = X[test_idx], y[test_idx]
train_unique_id = [unique_id[i] for i in train_idx]
train_ds = [ds[i] for i in train_idx]
val_unique_id = [unique_id[i] for i in val_idx]
val_ds = [ds[i] for i in val_idx]
test_unique_id = [unique_id[i] for i in test_idx]
test_ds = [ds[i] for i in test_idx]
return X_train, y_train, X_val, y_val, X_test, y_test, train_unique_id, train_ds, val_unique_id, val_ds, test_unique_id, test_ds
MQRNN 的作者利用了与预测范围相关的局部上下文,这是实现时间感知和季节性映射所必需的,同时也利用了与预测范围无关的全局上下文来捕获非时间敏感信息,从而增强了学习的稳定性和生成预测的平滑度。为了构建一个类似于 MQRNN 的模型,我们需要编写一个分位数损失函数,并添加捕获局部上下文和全局上下文的层。我为此添加了一个注意力层,向你展示如何将注意力机制融入这样的过程:
# Attention Layer
class Attention(Layer):
def __init__(self, units):
super(Attention, self).__init__()
self.W1 = Dense(units)
self.W2 = Dense(units)
self.V = Dense(1)
def call(self, query, values):
hidden_with_time_axis = tf.expand_dims(query, 1)
score = self.V(tf.nn.tanh(self.W1(values) + self.W2(hidden_with_time_axis)))
attention_weights = tf.nn.softmax(score, axis=1)
context_vector = attention_weights * values
context_vector = tf.reduce_sum(context_vector, axis=1)
return context_vector, attention_weights
# Quantile Loss Function
def quantile_loss(q, y_true, y_pred):
e = y_true - y_pred
return tf.reduce_mean(tf.maximum(q*e, (q-1)*e))
def combined_quantile_loss(quantiles, y_true, y_pred, output_steps):
losses = [quantile_loss(q, y_true, y_pred[:, i*output_steps:(i+1)*output_steps]) for i, q in enumerate(quantiles)]
return tf.reduce_mean(losses)
# Model architecture
def create_model(input_shape, quantiles, output_steps):
inputs = Input(shape=input_shape)
lstm1 = LSTM(256, return_sequences=True)(inputs)
lstm_out, state_h, state_c = LSTM(256, return_sequences=True, return_state=True)(lstm1)
context_vector, attention_weights = Attention(256)(state_h, lstm_out)
global_context = Dense(100, activation = 'relu')(context_vector)
forecasts = []
for q in quantiles:
local_context = concatenate([global_context, context_vector])
forecast = Dense(output_steps, activation = 'linear')(local_context)
forecasts.append(forecast)
outputs = concatenate(forecasts, axis=1)
model = Model(inputs, outputs)
model.compile(optimizer='adam', loss=lambda y, f: combined_quantile_loss(quantiles, y, f, output_steps))
return model
下面是绘制的预测结果:



我们还评估了每个项的 SMAPE(对称平均绝对百分比误差),以及区间覆盖率(即实际值有多少被区间覆盖)。结果如下:

这个简化版本可以作为开始进行分位数预测的良好基准。此设置未配置分布式训练,且模型架构也没有为大规模预测进行优化,因此可能存在速度问题。在下一部分中,我们将研究一个包,它允许你使用最先进的深度学习模型进行分位数预测。
使用 SOTA 模型进行分位数预测
neuralforecast包是一个优秀的 Python 库,允许你使用大多数 SOTA(最先进的)深度神经网络模型进行时间序列预测,如 PatchTST、NBEATs、NHITS、TimeMixer 等,且实现简单。在这一部分,我将以PatchTST为例,展示如何进行分位数预测。
首先,加载必要的模块并定义 PatchTST 的参数。调整模型需要一些经验,并且会依赖于具体的项目。如果你有兴趣为你的数据获取潜在的最优参数,你可以查看 neuralforecast 中的自动模块。它们允许你使用 Ray 进行超参数调优,效率非常高!neuralforecast 包包含了一系列基于不同采样方法的模型。使用 base_window 方法的模型可以让你使用 MQLoss 或 HuberMQLoss,你可以指定你要寻找的分位数级别。在本工作中,我选择了 HuberMQLoss,因为它对异常值更为稳健。
# Install the package
pip install neuralforecast
# Load the package
from neuralforecast.core import NeuralForecast
from neuralforecast.models import PatchTST
from neuralforecast.losses.pytorch import HuberMQLoss, MQLoss
# Define Parameters for PatchTST
PARAMS = {'input_size': 104,
'h': output_steps,
'max_steps': 6000,
'encoder_layers': 4,
'start_padding_enabled': False,
'learning_rate': 1e-4,
'patch_len': 52, # Length of each patch
'hidden_size': 256, # Size of the hidden layers
'n_heads': 4, # Number of attention heads
'res_attention': True,
'dropout': 0.1, # Dropout rate
'activation': 'gelu', # Activation function
'dropout': 0.1,
'attn_dropout': 0.1,
'fc_dropout': 0.1,
'random_seed': 20240710,
'loss': HuberMQLoss(quantiles=[0.1, 0.5, 0.9]),
'scaler_type': 'standard',
'early_stop_patience_steps': 10}
# Get Training Data
train_df = df[df.ds<cut_off_date]
# Fit and predict with PatchTST
models = [PatchTST(**PARAMS)]
nf = NeuralForecast(models=models, freq='W')
nf.fit(df=train_df, val_size=12)
Y_hat_df = nf.predict().reset_index()
这是绘制的预测结果:



这是指标:

通过这个演示,你可以看到模型的实现是多么简单,以及模型性能是如何提升的。然而,如果你在想是否有更简单的方法来完成这个任务,答案是肯定的。在接下来的部分,我们将介绍一个基于 T5 的模型,允许你进行零-shot 分位数预测。
零-shot 分位数预测与大语言模型(LLMs)
我们一直在见证一个趋势,即 NLP 的进展将进一步推动时间序列预测的边界,因为预测下一个词是预测下一个时间段值的合成过程。鉴于大语言模型(LLMs)在生成任务中的快速发展,研究人员也开始考虑在数百万个时间序列上对大型模型进行预训练,使用户能够进行零-shot 预测。
然而,在我们将 LLMs 和零-shot 时间序列任务画上等号之前,我们必须回答一个问题:训练语言模型和训练时间序列模型之间有什么区别?答案是:“来自有限字典的 tokens 与来自无限范围的值”。亚马逊最近发布了一个名为Chronos的项目,它很好地解决了这个挑战并实现了大规模时间序列模型。如作者所述:“Chronos 通过简单的缩放和量化实值,将时间序列分割成离散的箱子。通过这种方式,我们可以在这种‘时间序列语言’上训练现成的语言模型,无需更改模型架构。”原始论文可以在此处找到。
目前,Chronos 有多个版本可用。它可以通过autogluon API加载并使用,只需要几行代码。
# Get Training Data and Transform
train_df = df[df.ds<cut_off_date]
train_df_chronos = TimeSeriesDataFrame(train_df.rename(columns={'ds': 'timestamp', 'unique_id': 'item_id', 'y': 'target'}))
# Zero-shot forecast with Chronos
predictor = TimeSeriesPredictor(prediction_length=output_steps, freq='W', quantile_levels = [0.1, 0.9]).fit(
train_df_chronos, presets="chronos_base",
random_seed = 20240710
)
Y_hat_df_chronos = predictor.predict(train_df_chronos).reset_index().rename(columns={'mean': 'Chronos',
'0.1': 'P10',
'0.9': 'P90',
'timestamp': 'ds',
'item_id': 'unique_id'})
这是绘制的预测结果:



这里是一些评估指标:

如你所见,Chronos 与 PatchTST 相比表现得非常出色。然而,这并不意味着它已经超越了 PatchTST,因为很有可能 Chronos 已经在 M4 数据上进行了训练。在他们的原始论文中,作者还评估了该模型在未曾训练过的数据集上的表现,结果显示 Chronos 仍然与 SOTA 模型的表现非常接近。
目前正在开发更多的大型时间序列模型。其中一个叫做TimeGPT,由 NIXTLA 开发。这类模型的发明不仅让预测任务变得更容易、更可靠和一致,而且也是在有限历史数据的情况下对时间序列做出合理猜测的一个良好起点。
结论
从构建量化回归预测器的玩具版,到利用最先进的模型和零样本大语言模型,本博客展示了量化预测的强大功能和多样性。通过整合像 TensorFlow 的 LSTM、NeuralForecast 的 PatchTST 和 Amazon 的 Chronos 等模型,我们可以实现准确、稳健且计算高效的多步时间序列预测。量化预测不仅通过提供对未来不确定性的细致理解来增强决策支持,还帮助组织优化战略和资源分配。神经网络和零样本学习模型的进步进一步推动了这一领域的边界,使得量化预测成为现代数据驱动行业中的关键工具。
注:所有的图片、数字和表格均由作者生成。完整代码可以在这里找到:Quantile Forecasting。
从 Set Transformer 到 Perceiver Sampler
关于多模态 LLM Flamingo 的视觉编码器
·发表于 Towards Data Science ·阅读时间 5 分钟·2024 年 10 月 8 日
--
设计多模态 LLM 很难。
目前最先进的多模态 LLMs 主要基于现有的 LLM 架构,通过特定的修改来处理不同来源的输入,这也是问题的来源。最新的 Nvidia 论文将常用的多模态架构分为两类:
基于解码器的;
基于交叉注意力。
我之前的 一篇文章讨论了 Meta 的最新论文,使用 基于解码器的架构,通过 VAE 编码器将输入图像转换为潜在向量,解决了图像空间是连续的,且与离散的文本空间不同的问题。
然而,基于交叉注意力的架构问题则不同。例如,在多模态 LLM 模型 Flamingo 中,关键问题是如何将来自通用视觉模型的视觉嵌入(其时空维度各异)转换到交叉注意力层,以匹配语言输入的维度。
在这篇文章中,我将深入探讨 Flamingo 在视觉编码器基础上独特的设计——Perceiver Resampler,解释这个问题是如何解决的。此外,我还将探索 Perceiver Resampler 的起源——来自 Set Transformer 的 Induced Set Attention Block,后者进一步启发了 DeepMind 的 Perceiver 模型,用于从通用输入数据中学习固定长度的潜在嵌入。

图片来源:pxhere.com/en/photo/1399240
Set Transformer
Set Transformer 工作于 2019 年发布,扩展了原始 Transformer 模型,用于处理集合问题,解决了置换不变问题,如集合异常检测、点云分类等。受稀疏高斯过程启发,其中一小部分诱导变量足以近似输入的后验分布,Set Transformer 使用如下定义的诱导集合注意力块(ISAB):

诱导集合注意力块(ISAB)。方程来源:arxiv.org/pdf/1810.00825
MAB(X, Y)是 Transformer 的原始多头注意力块,其中查询= X,键/值= Y。ISAB 块几乎与两个堆叠的多头注意力块相同,唯一的区别是输入的键/值被诱导矩阵 I 替代。原始集合 X 的维度为 ND,I 的维度为 MD,表示 M 个 1*D 的诱导点。下面是一个可视化示意图。

多头注意力块和诱导集合注意力块的可视化。图片来源:arxiv.org/pdf/1810.00825
请注意,ISAB 的设计目的是节省计算成本。原因是 M 可能远小于原始的 N 维度,这使得 ISAB 的时间复杂度 O(Nd)比原始自注意力复杂度 O(N**2d)小得多。
Perceiver
受 Set Transformer 中使用诱导点作为查询矩阵的启发,DeepMind 提出的 Perceiver 模型将查询矩阵分离为一个可学习的潜在嵌入短序列(例如,N=512),而键和值对则是一个字节数组,作为超长序列输入(例如,M=224*224 像素)。

Perceiver 模型架构。图片来源:arxiv.org/abs/2103.03206
交叉注意力借鉴了原始 Transformer 解码器部分,其中查询(query)和键/值(key/value)来自不同的来源,在这种情况下,来自不可学习的表示:

多头注意力和交叉注意力。图片来自作者。
由于 K 和 V 是输入的“常量”,Perceiver Transformer 层的计算复杂度仅与潜在空间相关,即 O(N2),也被称为潜在 Transformer**。与输入大小解耦后,潜在 Transformer 可以快速扩展到 48 层,这相比传统的 Transformer 设计具有巨大的优势。
Flamingo 的视觉编码器和感知重采样器
Flamingo 并没有直接应用 Perceiver,而是首先使用预训练的基于 CNN 的权重冻结的 Normalizer-Free ResNet(NFNet)来提取图像/视频特征,然后添加一个可学习的时间位置嵌入,并将其展平为 1D 序列。Perceiver Resampler 附加在视觉编码器上,以学习一个固定大小的潜在嵌入,然后将其传递到主架构的跨注意力层。

Flamingo 架构。图像来源:arxiv.org/pdf/2204.14198
类似于 DeepMind 的 Preceiver 模型,Percerver Resampler 使用恒定的输入嵌入作为键/值,学习的潜在向量作为查询。注意,这里没有使用空间编码,理由是之前的视觉编码器 NFNet 是一个基于卷积的模型,空间信息已经嵌入在通道信息中。为了提高性能,学习向量被连接到跨注意力计算中的键/值向量。

Preceiver Resampler 架构。图像来源:arxiv.org/abs/2204.14198
下面给出了详细的算法:

Perceiver Resampler 算法。算法来源:arxiv.org/abs/2204.14198
摘要
本文详细介绍了 Flamingo 架构中的视觉编码器部分。视觉编码器有一个独特的设计——Perceiver Resampler,它起源于 Set Transformer 和 Perceiver 模型,能够在利用空间和时间域信息的同时,最小化跨注意力计算的成本。
参考文献
Dai 等人,《NVLM: Open Frontier-Class Multimodal LLMs》。arXiv 2024。
Zhou 等人,《Transfusion: Predict the Next Token and Diffuse Images with One Multi-Modal Model》。arXiv 2024。
Alayrac 等人,《Flamingo: a Visual Language Model for Few-Shot Learning》。NeurIPS 2022。
Jaegle 等人,《Perceiver: General Perception with Iterative Attention》。ICML 2021。
Brock 等人,《高性能大规模图像识别无归一化》。arXiv 2021。
Lee 等人,《Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks》。ICML 2019。幻灯片
Vaswani 等人,《Attention Is All You Need》。NeurIPS 2017。
斯坦福 CS25: V1 I DeepMind 的 Perceiver 和 Perceiver IO: 新的数据家族架构,
www.youtube.com/watch?v=wTZ3o36lXoQHuggingFace,Perceiver 模型文档。
huggingface.co/docs/transformers/v4.34.0/en/model_doc/perceiver
从社会科学到数据科学
8 年前,我开始了我的地理学本科学位。现在,我是一名数据科学家;这是我为什么走上这条道路的故事,以及我为什么认为数据科学可以从人文学科中学到很多东西。
·发表于Towards Data Science ·阅读时间:9 分钟·2024 年 5 月 2 日
--

图片由Dário Gomes提供,来源于Unsplash
在我当前的团队中,大约 60 名数据科学家和机器学习工程师,只有大约 5 个人没有在大学时主修 STEM(科学、技术、工程和数学)学科。
这个比例在科技行业中非常典型——2020 年,StackOverflow 开发者调查发现,只有 3.8%的专业开发者在大学时主修社会科学或人文学科。
这不一定是坏事或令人惊讶的事情(数据科学自然吸引那些在学校和大学喜欢 STEM 科目的人)。但 STEM 并不是进入该领域的唯一途径,我认为有志成为数据科学家的朋友,如果他们没有传统的数学/计算机科学背景,也不应感到灰心丧气。
STEM 并不是进入数据科学的唯一途径
在本文中,我将分享我如何从社会科学背景转行进入数据科学的故事。
从替代建模到航空航天工程:一个 NASA 案例研究
这是替代建模如何从理论到实践,彻底改变航空航天工程领域的方式。
·发表于 Towards Data Science ·阅读时间 10 分钟·2024 年 8 月 15 日
--

图片由作者使用 DALL·E 制作。
想象一下,你因为腹痛去看医生。想象一下,为了告诉你疼痛的原因,医生需要进行数值模拟。这个数值模拟解决了你身体的微分方程,并且近似误差非常低,几乎达到 100%的准确度:你将准确地知道为什么会有腹痛,完全没有错误的余地(我知道这听起来很疯狂,但请耐心听我解释)。你会使用这个数值模拟吗? 当然会,为什么不呢?
现在想象一下,因为这个数值模拟需要解决非常复杂的微分方程,为了得到计算机的响应,你需要等待35 年的 GPU 运行时间。这会立即降低这种方法的吸引力。如果需要这么长时间才能得到响应,这种方法有什么用呢?当然,它可能是 100%准确的,但计算成本太高。
好吧,工程师们想到了一个解决方案。这个解决方案被称为替代建模。顾名思义,这种方法是使用原始模拟的替代物…
从文本到网络:LLMs 对知识图谱的革命性影响

插图作者提供
使用 LLMs 构建和利用知识图谱的逐步指南
·发布于 Towards Data Science ·阅读时间:15 分钟·2024 年 8 月 29 日
--
大型语言模型(LLMs)的兴起彻底改变了我们从文本中提取信息和与其互动的方式。然而,尽管它们具有令人印象深刻的能力,LLMs 仍面临着一些固有的挑战,尤其是在推理、一致性和信息的上下文准确性等方面。这些困难源于 LLMs 的概率性质,这可能导致幻觉、透明度缺乏以及处理结构化数据时的挑战。
这就是知识图谱(KGs)发挥作用的地方。通过将 LLMs 与 KGs 相结合,AI 生成的知识可以得到显著增强。为什么?KGs 提供了信息的结构化和相互连接的表示,反映了现实世界中的关系和实体。与传统数据库不同,KGs 能够捕捉并推理人类知识的复杂性,确保 LLMs 的输出来自一个结构化、可验证的知识库。这种集成导致了更准确、一致和具有上下文相关性的结果。
医疗、金融和法律服务等行业可以从知识图谱中获得巨大益处,因为它们需要精确的...
从理论到实践:使用 Python 进行粒子群优化
这里有一个关于什么是 PSO 以及如何使用它的教程
·发表于 Towards Data Science ·阅读时间 9 分钟 ·2024 年 9 月 7 日
--

照片来自 James Wainscoat 于 Unsplash
有一个笑话让我忍俊不禁:
“你知道吗,在时钟发明之前,人们不得不四处走动,向别人打听时间?”
显然没有必要解释这个笑话,但如果我们稍微深思一下(就像优秀的数学家那样),我们可以说这个笑话其实是在讲一个群体中粒子的相关信息可以被用来通知所有其他粒子。这个概念实际上比我刚刚提到的笑话要深刻得多,且可以进一步被利用。
让我们考虑一个自组织系统,比如鸟群飞行或鱼群游动。我们可以将这个系统定义为由粒子(例如,一个粒子就是一只鸟)构成。我们还可以假设这些粒子在空间中移动,根据两个因素来调整它们的位置:
特定粒子知道的最佳位置:鸟儿认为对自己最有利的地方。
所有粒子“相互沟通”后得出的全局最佳位置:鸟儿根据“主鸟”指示的位置进行调整
从视觉变换器到掩码自编码器,5 分钟搞定
自然语言处理任务如何泛化到计算机视觉的简单指南
·发表于Towards Data Science ·阅读时长 7 分钟·2024 年 6 月 28 日
--
几乎所有的自然语言处理任务,从语言建模、掩码词预测到翻译和问答,都在 2017 年变换器架构首次亮相时经历了革命性变化。毫不奇怪,在短短 2 至 3 年内,变换器也被应用到计算机视觉任务中,并且在这些任务中展现了卓越的表现。在这篇文章中,我们探讨了两种基础架构,使得变换器能够突破进入计算机视觉领域。
目录
· 视觉变换器
∘ 关键思想
∘ 操作
∘ 混合架构
∘ 结构的丧失
∘ 结果
∘ 自监督学习通过掩码
· 掩码自编码器视觉变换器
∘ 关键思想
∘ 架构
∘ 最终备注与示例
视觉变换器

论文中的图片:“一张图片值 16x16 个词:用于大规模图像识别的变换器”
关键思想
视觉变换器的目的是将标准变换器架构推广到处理和学习图像输入。关于该架构有一个关键思想,作者在文中非常透明地进行了强调:
“受变换器在自然语言处理领域扩展成功的启发,我们尝试将标准变换器直接应用于图像,尽可能少地做修改。”
操作
将“尽可能少的修改”字面理解是合理的,因为他们几乎没有做任何修改。实际上他们修改的是输入结构:
在 NLP 中,transformer 编码器接收一个one-hot 向量序列(或者等价的令牌索引),这些向量代表输入句子/段落,并返回一个上下文嵌入向量序列,后者可以用于进一步的任务(例如分类)。
为了推广 CV,视觉 transformer 接收一个patch 向量序列,这些向量代表了输入图像,并返回一个上下文嵌入向量序列,后者可以用于进一步的任务(例如分类)。
特别地,假设输入图像的维度是(n,n,3),并将其作为输入传递给 transformer,视觉 transformer 所做的是:
它将图像分成 k²个 patch,其中 k 是某个值(例如,k=3),如上图所示。
现在每个 patch 的维度将是(n/k,n/k,3),下一步是将每个 patch 展平为一个向量。
patch 向量的维度将是 3*(n/k)(n/k)。例如,如果图像的维度是(900,900,3),并且我们使用 k=3,那么一个 patch 向量的维度将是 300300*3,表示展平后的 patch 中的像素值。在论文中,作者使用了 k=16。因此,论文的标题是“An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale”,而不是输入一个表示单词的 one-hot 向量,它们代表的是表示图像 patch 的像素向量。
其余操作与原始 transformer 编码器相同:
这些 patch 向量会经过一个可训练的嵌入层。
在每个向量中添加位置嵌入,以保持图像中的空间信息。
输出是num_patches个编码表示(每个 patch 一个),这些表示可以用于对 patch 或图像层面的分类。
更常见的是(如论文中所示),一个 CLS 标记被加到表示序列的最前面,这个标记用于对整个图像进行预测(类似于 BERT)。
那 transformer 解码器怎么样?
记住,这就像 transformer 编码器一样;不同之处在于,它使用了掩码自注意力而不是自注意力(但输入签名保持不变)。无论如何,你应该预期很少使用仅解码器的 transformer 架构,因为简单地预测下一个 patch 可能不是一个很感兴趣的任务。
混合架构
作者还提到,完全可以从 CNN 特征图开始,而不是直接从图像开始,形成一种混合架构(CNN 将输出馈送给视觉 transformer)。在这种情况下,我们可以把输入看作是一个通用的(n,n,p)特征图,而一个 patch 向量的维度将是(n/k)*(n/k)*p。
结构丢失
你可能会想,这种架构不应该那么好,因为它把图像当作线性结构来处理,但实际上并非如此。作者通过提到这一点,试图表明这是故意为之。
“二维邻域结构的使用非常有限……在初始化时,位置嵌入不包含任何关于补丁的二维位置信息,所有补丁之间的空间关系必须从头开始学习。”
我们将看到,变换器(transformer)能够学习这一点,这一点在他们的实验中得到了验证,更重要的是,下一篇论文中的架构也证明了这一点。
结果
结果的主要结论是,对于小数据集,视觉变换器通常不如基于 CNN 的模型,但对于较大数据集,它们能够接近或超越 CNN 模型,并且无论如何都需要显著更少的计算资源:

来自论文的表格:“一张图片价值 16x16 个词:用于大规模图像识别的变换器”。
在这里,我们看到对于 JFT-300M 数据集(包含 3 亿张图片),在该数据集上预训练的 ViT 模型表现超过了基于 ResNet 的基准模型,同时预训练所需的计算资源明显更少。如图所示,他们使用的最大视觉变换器(ViT-Huge,拥有 6.32 亿参数,k=16)所需的计算量仅为 ResNet 基准模型的 25%,且仍然优于其性能。即使是 ViT-Large,仅使用不到 6.8%的计算资源,其性能也几乎没有下降。
同时,其他研究也展示了 ResNet 在训练时在 ImageNet-1K(仅包含 130 万张图片)上的表现显著更好。
通过遮罩进行的自监督学习
作者们对遮罩补丁预测进行了初步探索,模仿了 BERT 中使用的遮罩语言模型任务(即,遮罩出一些补丁并尝试预测它们)。
“我们使用遮罩补丁预测目标进行初步的自监督实验。为此,我们将 50%的补丁嵌入破坏,通过以下方式之一:将其嵌入替换为可学习的[遮罩]嵌入(80%)、随机替换为其他补丁嵌入(10%),或者保持原样(10%)。”
通过自监督预训练,他们的小型 ViT-Base/16 模型在 ImageNet 上的准确率达到了 79.9%,相比从零开始训练提升了 2%,但仍然比监督预训练低 4%。
遮罩自编码器视觉变换器

来自论文的图片:《Masked Autoencoders Are Scalable Vision Learners》
关键思想
正如我们从视觉变换器的论文中看到的,通过遮罩输入图像的补丁进行预训练的收益并不像在普通 NLP 任务中那样显著,在 NLP 中,遮罩预训练可以在某些微调任务中带来最先进的结果。
本文提出了一种视觉变换器架构,涉及一个编码器和解码器,当通过遮罩进行预训练时,能显著提高基准视觉变换器模型的表现(与使用监督训练的基准模型相比,提升高达 6%)。

来自论文的图像:Masked Autoencoders Are Scalable Vision Learners
这是一些示例(输入、输出、真实标签)。它是一个自编码器,旨在重建输入的同时填充缺失的补丁。
架构
他们的编码器实际上是我们之前解释的普通视觉变换器编码器。在训练和推理过程中,它只处理“观察到的”补丁。
与此同时,他们的解码器实际上是普通的视觉变换器编码器,但它需要:
缺失补丁的遮蔽标记向量
已知补丁的编码器输出向量
所以对于一个图像 [ [ A, B, X], [C, X, X], [X, D, E]],其中 X 表示缺失的补丁,解码器将处理补丁向量的序列 [Enc(A), Enc(B), Vec(X), Vec(X), Vec(X), Enc(D), Enc(E)]。Enc 返回给定补丁向量的编码器输出向量,X 是表示缺失标记的向量。
解码器中的最后一层是一个线性层,它将上下文嵌入(由解码器中的视觉变换器编码器生成)映射到一个长度等于补丁大小的向量。损失函数是均方误差,它计算原始补丁向量与此层预测补丁向量之间的差值平方。在损失函数中,我们只关注解码器对遮蔽标记的预测,忽略与已存在的标记(即 Dec(A)、Dec(B)、Dec(C)等)对应的部分。
最后的评论和示例
可能令人惊讶的是,作者建议遮蔽图像中约 75%的补丁,而 BERT 仅会遮蔽大约 15%的单词。他们是这样辩解的:
图像是具有大量空间冗余的自然信号——例如,缺失的补丁可以通过相邻的补丁恢复,而无需高层次的理解部分、物体和场景。为了克服这一差异并鼓励学习有用的特征,我们会遮蔽大量的随机补丁。
想自己试试吗?查看 NielsRogge 提供的演示笔记本。
这就是这个故事的全部内容。我们经历了一段旅程,理解了基本的变换器模型如何推广到计算机视觉领域。希望你觉得它清晰、有洞察力,且值得你的时间。
参考文献:
[1] Dosovitskiy, A. et al. (2021) An image is worth 16x16 words: Transformers for image recognition at scale, arXiv.org。可在以下网址访问:arxiv.org/abs/2010.11929(访问时间:2024 年 6 月 28 日)。
[2] He, K. et al. (2021) Masked autoencoders are scalable vision learners, arXiv.org。可在以下网址访问:arxiv.org/abs/2111.06377(访问时间:2024 年 6 月 28 日)。
从零到应用:使用 Python 构建数据库驱动的 Streamlit 应用
构建一个功能性 Streamlit 应用并集成 SQLite 的初学者指南
·发表于 Towards Data Science ·阅读时间:5 分钟·2024 年 9 月 26 日
--
从零代码平台到复杂的编程解决方案,我们有很多选择来构建应用程序。从头开始创建一个功能性的应用可能会让人感到非常有压力。如果你是分析师,且了解 Python 和 SQL,那么你非常幸运。Python 提供了强大的库,使其非常适合那些希望快速创建交互式应用程序的人——而无需深入了解 HTML、CSS 或 JavaScript。
Streamlit 是一个开源的 Python 库,允许你以最少的努力构建自定义 Web 应用程序。它负责前端组件的处理,让你可以专注于编写 Python 代码。通过 Streamlit 分享功能,用户也可以轻松免费地部署自己的 Streamlit 应用。SQLite 则是一个轻量级、无服务器的数据库引擎,已包含在 Python 中。它无需额外的设置,可以将应用的数据保存在本地文件中,使其极具便携性并且简单易用。
将 Streamlit 和 SQLite 结合使用,你将得到一个交互式、用户友好的应用程序,背后有一个强大的后端数据库系统。在本文中,我将引导你了解如何创建一个简单的费用管理应用,用于记录个人开销并导出所有数据以进行进一步分析。
FrugalGPT 和降低 LLM 运营成本
本博客将详细介绍“FrugalGPT”论文中所见的针对 LLM 驱动应用的节省成本架构。
·发布于 Towards Data Science ·阅读时间 5 分钟·2024 年 3 月 27 日
--

图片由作者生成,使用 DALL-E
大型语言模型为计算机科学开辟了一个新的前沿,然而,到了 2024 年,它们的运行成本显著高于计算机科学中几乎所有其他项目。对于寻求降低运营成本的公司来说,这构成了一个严重的问题。论文“FrugalGPT:如何在减少成本并提高性能的同时使用大型语言模型”提出了一个框架,可以在保持质量的同时显著降低运营成本。
如何衡量 LLM 的成本
有多种方法可以确定运行大型语言模型(LLM)的成本(例如电力消耗、计算成本等),然而,如果使用第三方 LLM(即 LLM 即服务),他们通常会根据你使用的 tokens 数量收费。不同的供应商(如 OpenAI、Anthropic、Cohere 等)有不同的 tokens 计数方式,但为了简化起见,我们将成本视为基于 LLM 处理的 tokens 数量。
这个框架最重要的部分是不同模型的成本是不同的。论文的作者巧妙地整理了以下表格,突出了成本的差异,它们之间的差距非常显著。例如,在此表中,AI21 的输出 tokens 成本比 GPT-4 高出一个数量级!

表 1 来自论文
作为成本优化的一部分,我们始终需要找到在最小化成本的同时优化答案质量的方法。通常情况下,较高成本的模型往往性能较好,能够提供比低成本模型更高质量的答案。下图展示了这一普遍关系,FrugalGPT 的表现以红色覆盖在上面。

图 1c 来自论文 对比了各种 LLM 基于 HEADLINES 数据集准确回答问题的频率
利用级联 LLM 最大化质量
通过利用不同模型之间巨大的成本差异,研究人员的 FrugalGPT 系统依赖于 LLM 的级联来为用户提供答案。简单来说,用户查询从最便宜的 LLM 开始,如果答案足够好,就返回该答案。如果答案不够好,则查询会传递给下一个最便宜的 LLM。
研究人员使用了以下逻辑:如果一个较便宜的模型回答问题错误,那么较昂贵的模型很可能会给出正确的答案。因此,为了最小化成本,链条按从最便宜到最昂贵的顺序排列,假设随着价格的上涨,质量也会提高。

图 2e 来自论文 展示了 LLM 级联的示意图
这种设置依赖于可靠地判断一个答案是否足够好。为了解决这个问题,作者创建了一个 DistilBERT 模型,它会对问题和答案进行评分。由于 DistilBERT 模型的体积比序列中的其他模型要小得多,因此与其他模型相比,运行它的成本几乎可以忽略不计。
比仅仅查询最好的 LLM 获得更好的平均质量
人们自然会问,如果质量最重要,为什么不直接查询最好的 LLM,并寻求降低运行最佳 LLM 的成本的方法?
当这篇论文发布时,GPT-4 是他们找到的最好的 LLM,然而,GPT-4 并不总是给出比 FrugalGPT 系统更好的答案!(细心的读者会注意到,这是之前成本与性能图表的一部分) 作者推测,就像最有能力的人不总是给出正确的答案一样,最复杂的模型也未必如此。因此,通过让答案经过 DistilBERT 的过滤过程,你可以去除任何不合格的答案,从而提高获得好答案的几率。

图 5a 来自论文 显示了 FrugalGPT 超越 GPT-4 的实例
因此,这个系统不仅能降低成本,还能提高质量,比单纯使用最好的 LLM 更有效!
向前迈进,节约成本
本文的结果令人深思。对我来说,它引发了关于如何在不进一步投资于模型优化的情况下,继续实现成本节约的问题。
其中一种可能性是将所有模型的答案缓存到向量数据库中,然后进行相似性搜索,确定缓存中的答案是否适用,再开始 LLM 链的操作。通过用相对较便宜的查询和相似性操作替代昂贵的 LLM 操作,这将显著降低成本。
此外,这也让人思考是否过时的模型依然值得进行成本优化,因为如果你能降低每个 token 的成本,它们仍然可以在 LLM 链中创造价值。类似地,关键问题在于,何时通过将新的 LLM 加入链条中会出现收益递减的现象。
后续研究问题
随着世界上大语言模型的增加,我们越来越多地构建使用这些模型的系统,我们将希望找到更具成本效益的运行方式。本文为未来的开发者提供了一个强有力的框架,这让我思考这个框架能够走多远。
在我看来,这个框架非常适用于那些没有基于不同用户提供不同答案的一般性查询,例如辅导型大语言模型(LLM)。然而,对于那些答案根据用户不同而有所不同的应用场景,比如作为客服代理的大语言模型,评分系统必须能够识别 LLM 与谁进行对话。
为用户特定交互节省成本的框架将是未来的重要课题。
[1] Chen, L., et al., FrugalGPT: 如何在降低成本并提高性能的同时使用大语言模型 (2023), arXiv
完整指南:如何使用 Python、Markdown、Git 和 GitHub Pages 构建专业作品集
如何通过高质量的作品集来提升你作为开发者或数据工作者的形象
·发表于Towards Data Science ·阅读时长 9 分钟·2024 年 7 月 20 日
--

图片来源:Product School 于Unsplash
索引
我的前一个作品集
冒险开始
代码
3.1. 所需软件
3.2. 配置本地环境
3.3. 创建工作文件夹
3.4. 准备 Git
3.5. 填充网站内容
3.6. 网站配置
3.7. 最终的 mkdocs.yml 文件
3.8. 最终文件结构
3.9. 部署工作
GitHub Pages 配置
迭代改进
我的 2024 年版作品集!
我的前一个作品集
2023 年,我已经编写数据项目代码两年,开始着手创建我的第一个作品集,以展示我的数据科学项目。我发现了 Matt Chapman 的 TDS 文章和Matt Chapman 的作品集。这篇文章与我当时的技术知识(Python、Git)完全吻合。感谢 Matt Chapman 的文章,我开始了我的第一个作品集!于是我决定深入探索这个方案,并弄清楚如何实现。我发现了Matt Chapman 使用的参考资料和相应的仓库。我利用这个参考资料创建了我的作品集。
到 2024 年,我发现我的旧作品集相比现有的作品集显得有些过时,并且对数据爱好者或招聘人员吸引力不足。通过探索社区中已有的项目,我发现了几个文档非常棒的项目。以下是激励我的 2 个链接:基于 GitHub Pages 的多页面文档,以及 基于 GitHub Pages 的 JavaScript 作品集 和 相关 Medium 文章。
冒险开始
对于我的新版本作品集,我的标准是:一个免费的解决方案,且配置最少。浏览现有的文档和作品集后,我有几个选择:
选项 1:保持我之前作品集的相同视觉和架构
选项 2:分叉上述提到的 JavaScript 作品集
选项 3:使用 mkdocs Python 包
由于我不使用 JavaScript 编程,我很快会在自定义方面受到限制,所以我决定放弃。我之前的作品集中使用单一的 GitHub Pages 并不足以提升我的作品集。在我的研究中,我发现了 2 个 mkdocs 子包,它们的视觉效果特别吸引我:mkdocs-material 和 just-the-docs。我最终选择了 mkdocs,主要基于以下 3 个原因:
该包(mkdocs)易于使用,而且我了解 Python 和 Git,这正是所需的 2 种技术。
网站内容采用 Markdown 格式,且该包可以在我的 GitHub 仓库中自动生成网站,我只需最少的操作。
生成的网站既美观又实用。
Mkdocs-material 允许使用 Google 标签,非常适合跟踪我作品集的流量!
代码
在这个项目开始时,我已经设置好了我的 GitHub Pages,创建了我的仓库,并为我的之前的作品集创建了虚拟环境。为了让每个人都能跟随并复现这篇文章,我决定从头开始。对于那些已经拥有 GitHub Pages 作品集的你们来说,你们已经熟悉 Git 和 Python,能够毫无问题地处理分支。
在这篇文章中,我将分享一些 URL 链接。我的目标是让你们充分理解代码的各个方面,并在必要时提供资源,帮助你们深入了解某个主题或解决我文章中未描述的错误。
所需软件
对于这项工作,你至少需要在计算机上安装并配置 Python 和 Git,并拥有一个 GitHub 账户。个人而言,我在 VSCode 和 集成到 PowerShell 的 Miniconda 中工作,这样我可以将脚本和终端放在同一屏幕上。关于 Git 配置,请参考 Git 官网 中的 Your identity 部分。
配置本地环境
我使用的是 Miniconda。如果你使用相同的工具,你会看到‘(base)>’ 元素。如果不是,这个元素代表当前的虚拟 Python 环境(base 是 Miniconda 的默认虚拟环境)。元素 *working_folder* 是终端的当前文件夹。
1. 第一步是为作品集项目创建虚拟环境:
(base)> conda create -n "portfolio_env" # Create the new virtual env named portfolio_env
(base)> conda activate portfolio_env # Activate the new virtual env, (base) become (portfolio_env)
2. 在这个新环境中,我们需要安装以下 Python 包:
(portfolio_env)> pip install mkdocs mkdocs-material
3. 为了保证你环境的可复现性,我们导出依赖包:
(portfolio_env)> conda env export > "environnement.yml" # Export the environment.yml file, to ensure conda env repoductibility (including the python version, the conda env configuration, … )
(portfolio_env)> conda list > "requirements.txt" # Export packages installed only
创建工作文件夹
我的前一个作品集没有使用 mkdocs,所以我创建了 mkdocs 结构:
(portfolio_env)> mkdocs new "<your GitHub username>.github.io"
将
<your GitHub username>.github.io
|- mkdocs.yml
|- environment.yml
|- requirements.txt
|- docs/
|- index.md
要了解 mkdocs 包,你可以在 这里 查阅文档。
准备 Git
如果你已经有了 GitHub Pages,你可以克隆你的
创建一个名为 readme.md 的文件。这个文件是任何 Git 仓库的经典 README 文件。
打开 Git 终端并创建本地仓库
working_folder> git init # Initiate the local repo
working_folder> git add . # Save the readme file
working_folder> git commit -m "Initiate repo" # Commit changes
3. 在你的 GitHub 账户上,创建一个名为
4. 将本地仓库与远程仓库连接。在 Git 终端中:
working_folder> git remote add github https://github.com/<your GitHub username>.github.io
如果你不熟悉 GitHub Pages,GitHub Pages 网站 将向你介绍它们,并解释为什么我使用
工作文件夹将具有以下结构:
<your GitHub username>.github.io
|- .git
|- readme.md
|- mkdocs.yml
|- environment.yml
|- requirements.txt
|- docs
|- index.md
为网站提供内容
Mkdocs 允许你展示网站并动态包含修改,这样你可以随时看到你的网站变化。动态生成网站的代码:
mkdocs serve
这个命令返回一个本地 URL(例如 127.0.0.1:8000),可以粘贴到浏览器中。
readme.md 和 index.md
README 文件对应于远程仓库的主页。当你使用 mkdocs 包创建工作文件夹时,它会创建一个 docs/index.md 文件,对应网站的首页。
菜单
第一步是配置网站的菜单(左侧面板,用于在页面之间导航)。在working_folder/mkdocs.yml文件中,这是nav部分:
# Page tree: refer to mkdocs documentation
nav:
- Home: index.md
- Current project:
- "Health open data catalog": catalogue-open-data-sante/index.md
- Previous data science projects:
- "Predict Health Outcomes of Horses": horse_health_prediction_project/readme.md
…
- Previous skills based projects:
- "Predict US stocks closing movements": US_stocks_prediction_project/readme.md
…
主页元素很重要:这是网站的主页。你可以选择将readme.md文件复制到index.md文件中,这样在 GitHub 仓库和网站上都会有相同的主页,或者编写一个新的index.md文件,为你的作品集提供一个特定的主页。
让我们来拆解一下下面的块:
- Previous data science projects:
- "Predict Health Outcomes of Horses": horse_health_prediction_project/readme.md
…
先前的数据科学项目:将代表导航栏中的一组页面的名称。*“预测马匹健康结果”*将是所指示文件菜单中显示的名称,在本例中是:horse_health_prediction_project/readme.md。Mkdocs 会自动找到 docs 文件夹中需要显示的页面,因此无需在路径中指定。然而,由于马匹健康预测项目位于一个同名文件夹中,你必须指定显示该文件的文件夹位置。
在*docs/*文件夹中,我添加了我之前的项目:
working_folder
|- docs
|- horse_health_prediction_project
|- readme.md
|- notebook.ipynb
|- notebook.html
|- US_stocks_prediction_project
|- reamd.me
|- notebook.ipynb
|- notebook.html
然后,我用以下语法将每个项目的展示添加到导航栏中:*“<显示名称>”: <从 docs 到项目文件的路径>/<项目展示>.md。
这里的缩进非常重要:它定义了导航栏的文件夹。并非所有 docs 文件夹中的文件都需要列在导航栏中。然而,如果它们没有列出,访客将无法直接访问。
网站配置
然后我配置了我网站上不可见但非常重要的方面:
- mkdocs-material 要求的基本网站信息:
# Project information
site_name: Pierre-Etienne's data science portfolio
site_url: https://petoulemmonde.github.io/
site_author: Pierre-Etienne Toulemonde
site_description: >-
I'am Pierre-Etienne Toulemonde, PharmD and Data scientist,
and you are here on my data science portfolio
site_name对应浏览器标签上的名称。
- 一些仓库信息,允许 mkdocs-material 在托管网站的仓库右上角显示信息:
# Repository: necessary to display the repo on the top right corner
repo_name: petoulemonde/petoulemonde.github.io
repo_url: https://github.com/petoulemonde/petoulemonde.github.io
- 我在这里定义了要使用的设计:
# Configuration:
theme:
name: material
这是一个非常重要的步骤,因为这一行告诉 mkdocs:“使用 mkdocs-material 包来构建网站”。如果漏掉这一步,GitHub Pages 将不会拥有 mkdocs-material 的视觉效果和功能!
我添加了一些额外的信息,以追踪我网站上的流量:
# Additional configuration
extra:
analytics:
provider: google
property: <your google analystics code>
property是来自 Google Analytics 的代码,用于追踪我作品集上的流量。该代码由 Google Analytics 生成并与我的 Google 帐户关联(你可以在这里找到创建代码的教程)。
当然,我不是一次性写完整个文件的。我先添加一个项目文件和信息到文件架构和导航栏中,然后是配置,再添加另一个项目,然后再配置,……
最终的mkdocs.yml文件
我的最终mkdocs.yml文件是:
# Project information
site_name: Pierre-Etienne's data science portfolio
site_url: https://petoulemonde.github.io/
site_author: Pierre-Etienne Toulemonde
site_description: >-
I'am Pierre-Etienne Toulemonde, PharmD and Data scientist,
and you are here on my data science portfolio
# Repository
repo_name: petoulemonde/petoulemonde.github.io
repo_url: https://github.com/petoulemonde/petoulemonde.github.io
# Configuration
theme:
name: material
# Additional configuration
extra:
analytics:
provider: google
property: <google analystics code>
# Page tree
nav:
- Home: index.md
- Current project:
- "Health open data catalog": catalogue-open-data-sante/index.md
- Previous data science projects:
- "Predict Health Outcomes of Horses": horse_health_prediction_project/readme.md
…
- Previous skills based projects:
- "Predict US stocks closing movements": US_stocks_prediction_project/readme.md
…
最终的文件结构
在此步骤中,我的文件结构是:
petoulemonde.github.io
|- .git
|- readme.md
|- mkdocs.yml
|- requirements.txt
|- environnement.yml
|- docs/
|- index.md
|- US_stocks_prediction_project/
|- README.md
|- notebook.ipynb
|- notebook.html
|- horse_health_prediction_project/
|- README.md
|- notebook.ipynb
|- notebook.html
|- … others projects …
部署工作
Mkdocs 允许通过一条命令生成网站的代码:
mkdocs gh-deploy
Mkdocs 将所有 mkdocs 文件转换为 HTML 网站,就像魔术一样!Markdown 链接被转换为 HTML 链接,网站的站点地图也会生成。
然后,提交本地仓库中的所有更改,并将其推送到远程仓库。
working_folder> git add .
working_folder> git commit -m "Create website"
working_folder> git push github master
GitHub Pages 配置
设置 GitHub Pages 的步骤如下:
按照之前的步骤,推送一个最简版的作品集
在GitHub上,在远程仓库的首页,点击“Settings”(顶部菜单)
在左侧菜单中,点击“Pages”
配置页面:
来源: *从一个分支部署
分支*: gh-deploy 和 /root - 然后保存
在顶部菜单中,点击“Actions”。你应该能看到一个“workflow run”。就这样放着,等它变绿,表示完成,网站已经上线!干得好,你成功了!
你可以在 https://
迭代改进
我越是查看和展示我的作品集,就越发现其中的错误。要修正这些错误,方法非常简单:
在你的电脑上进行修改(以快速查看修改内容,并在必要时进行其他修改,快速提醒:
*mkdocs serve*)使用
*mkdocs gh-deploy*重新生成网站提交所有更改
推送到远程仓库。
然后,魔法就发生了:GitHub 会自动进行修改(查看“Actions”标签,看看 GitHub 进行到哪一步了)。
我的最终 2024 版本作品集!
我的 2024 版本作品集可以在这里找到,GitHub 仓库在这里。未来,我希望能集成 JavaScript,让作品集更加动态。
为什么我没有为我的作品集购买一个网站?我希望能集中精力创作作品集内容和新项目,把这些任务的管理工作降到最低。而且,无论是购买网站还是使用 GitHub Pages,人们都会通过点击链接访问我的网站,因此购买网站与否,结果是一样的。
非常感谢您的关注,这是我的第一篇 Medium 文章。欢迎在文章下留言,我很想听听你的看法。再见!
函数调用:在 xLAM 上微调 Llama 3
由于 QLoRA 的优化,它们在速度和内存效率方面表现优异
·发布于 Towards Data Science ·阅读时间 8 分钟·2024 年 7 月 23 日
--

由 DALL-E 生成
最近的大型语言模型(LLM)在大多数语言生成任务中表现出色。然而,由于它们是基于下一个标记预测进行操作的,它们在准确执行数学运算时经常遇到困难。此外,由于知识截止,它们可能缺乏准确回答某些问题所需的信息。
缓解这些问题的一种方法是通过函数调用。函数调用允许大型语言模型(LLM)可靠地连接到外部工具。它使得与外部 API 的交互成为可能。例如,通过将 LLM 与 Web 搜索引擎和计算器接口,便可以通过函数调用从互联网上检索信息并执行数学运算。
在本文中,我们将看到如何针对函数调用对 LLM 进行微调。我使用了 xLAM,这是 Salesforce 发布的包含 6 万条函数调用数据的微调数据集,用于微调 Llama 3。我们将看到如何格式化数据集,以及如何利用微调后的适配器来进行函数调用。
我还制作了这个笔记本,包含了本文中描述的微调代码实现,以及一些推理示例:
LLM 的函数调用:如何……
Python 中的函数接口
PYTHON 编程
在 Python 中,函数接口可以通过类型提示和协议来实现。
·发布于 Towards Data Science ·阅读时间 18 分钟 ·2024 年 3 月 19 日
--

图片由 Rubaitul Azad 提供,来源于 Unsplash
虽然 Python 并没有提供与其他编程语言中常见的接口相同的功能,但它提供了类似的功能。传统的接口是 抽象基类(ABCs),可以通过 abc 模块使用。后来,类型提示 和 类型协议(typing.Protocol)被添加进来。
抽象基类(ABCs)和类型协议作为 Python 定义类的期望结构和行为的机制,类似于其他编程语言中的接口。尽管 Python 本身并未正式使用“接口”这一术语来描述此功能,但这一概念与 Python 的“协议”非常契合。事实上,在 typing.Protocol 被添加到 typing 模块之前,Python 就已使用“协议”一词来描述接口。因此,Python 的动态特性允许创建动态接口,而类型提示和类型协议则从静态检查的角度促进了接口的规范化。我们将在本文中讨论这一点。
抽象基类和类型协议通常在创建自定义类时考虑:你可以创建一个抽象基类或一个...
未来保障你的数据科学能力的价值
通过整合数据工程能力
·发表于Towards Data Science ·阅读时间 8 分钟·2024 年 1 月 5 日
--
本文将讨论一个常常被忽视的要求,这个要求对于构建和未来保障高度有价值的数据科学能力至关重要。
本文涵盖的内容
为什么整合数据能力很重要
在组织内部实现整合
成功的重叠优先事项与技能组合
为什么这很重要?
对于每个数据科学家来说,紧跟技术趋势和工具的步伐是至关重要的,特别是在行业发展过程中。随着人工智能的迅猛发展,许多新兴技术受到了广泛关注,如作为大规模 LLM 驱动的数据产品的chatGPT,GitHub Copilot 用于通过即时建议协助程序员编写代码,当然还有许多其他技术。
然而,数据科学家使用这些新技术和技能的能力,受到了我们大家都知道并喜爱的一个短语的深刻影响:“垃圾进,垃圾出”。这一概念围绕着这样一个观点:稳固的数据管道是良好数据科学的关键。虽然很多人都理解这一点,但现实情况是,数据驱动的组织往往不会为其数据科学团队提供专门的,甚至是任何形式的数据工程支持。
不幸的是,由于数据科学团队与其他团队的隔离,常常会带来一系列后果…
博弈论,第一部分——囚徒困境问题
博弈论在现实生活中的决策情境中广泛应用。
·发布于Towards Data Science ·6 分钟阅读·2024 年 10 月 22 日
--
想象一下,你正在桌子上玩一盘棋盘游戏。你所做的每一个决策不仅影响你的结果,还会影响对手的结果。每一步都需要仔细思考,你需要权衡自己选项与对手可能的反应之间的关系。这种来回对抗创造了一个具有挑战性的局面,其中规划和策略至关重要。这就是博弈论的本质。在这里,“游戏”这个词是对现实生活情境的比喻,比如商业谈判、地缘政治条约和社会互动。

图片由Michał Parzuchowski提供,来自Unsplash
这个迷人的领域研究的是个人和组织在其选择决定游戏结果的情况下如何做出决策。博弈论为我们提供了宝贵的洞察,帮助我们理解当参与者(或利益相关者)理性行为时,决策动态是如何形成的。它帮助我们理解在游戏中参与者之间的合作、非合作、冲突和竞争如何塑造我们的世界。
这是关于博弈论系列文章中的第一篇。我将探讨博弈论中的一个著名问题——囚徒困境问题。
囚徒困境问题
博弈论,第二部分——好人先胜
反复囚徒困境显示好人也能赢得最终胜利
·发表于 Towards Data Science ·8 分钟阅读·2024 年 10 月 31 日
--
与“好人总是最后失败”的普遍看法相反,博弈论揭示了好人其实可以先胜。我将在本文中通过反复囚徒困境问题来探讨这一有趣现象。这篇文章是我的博弈论系列的第二部分,如果你还没有阅读第一篇文章,建议先查看第一篇。第一部分通过克拉蒂卡和伊希塔这两位间谍的例子讨论经典的囚徒困境问题,阐明了他们的决策如何为双方带来最优结果。同时,也强调了博弈论在许多现实生活场景中的相关性。第一篇文章的链接在这里——
博弈论在现实生活中的决策场景中有广泛的应用
towardsdatascience.com
反复囚徒困境问题

在现实生活中,各方的互动并不总是一次性的事件。当然,克拉蒂卡和伊希塔这两位间谍曾面临过一次囚徒困境,最终他们都服刑 7 年。但在另一个平行世界里,克拉蒂卡和伊希塔面临着同样的困境(在...
博弈论,第三部分——你是你最常与之相处的五个人的平均值
“以牙还牙”是迭代囚徒困境游戏中最佳策略吗?
·发表于Towards Data Science ·阅读时间 5 分钟·2024 年 11 月 13 日
--
本文将探讨博弈论如何诠释那句流行的格言:“你是你最常与之相处的五个人的平均值。”通过迭代囚徒困境游戏的例子,我们可以看到个体行为和结果是如何受到周围环境中他人策略(无论是合作还是非合作)的影响。

图像由 GPT-4o 生成
在前两篇关于博弈论的文章中,我讨论了囚徒困境问题和迭代囚徒困境游戏。本文是我博弈论系列的第三部分,因此如果你还没有阅读前两篇文章,建议先查看它们。
第一部分讨论了经典的囚徒困境问题,并强调了博弈论在许多现实场景中的应用。第二部分通过一个例子描述了迭代囚徒困境游戏,讲述了两家竞争食品配送平台的首席执行官 Kratika 和 Ishita 如何尝试不同的策略进行竞争。它还讨论了罗伯特·阿克塞尔罗德(Robert Axelrod)的著名比赛,该比赛揭示了最成功的策略具有一些共同特点:它们是“友善的”(从合作开始),具有宽容心(但不会过于宽容……)
门控循环单元(GRU)—— 改进的 RNN
解释门控循环神经网络(GRU)是如何工作的
·发表于Towards Data Science ·10 分钟阅读·2024 年 6 月 15 日
--

"www.flaticon.com/free-icons/neural-network" title="神经网络图标">神经网络图标由 juicy_fish 创建 — Flaticon。
在本文中,我将探讨循环神经网络(RNNs)的标准实现方式:门控循环单元(GRUs)。
GRU 于 2014 年由Kyunghyun Cho 等人提出,是对传统 RNN 的改进,因为它们较少受到梯度消失问题的影响,从而拥有更长的记忆能力。
它们与长短期记忆(LSTM)网络类似,但操作较少,使它们在内存使用上更加高效。
我们将涵盖的内容:
RNN 概述
梯度消失与梯度爆炸问题
LSTM 概述
GRU 是如何工作的
循环神经网络
循环神经网络是一种特别擅长处理基于序列的数据类型的神经网络,例如自然语言和时间序列数据。
它们通过添加一个“循环”神经元来实现这一点,允许信息从过去的输入和输出传递到下一步。
高斯朴素贝叶斯解释:为初学者提供的带有代码示例的可视化指南
分类算法
钟形曲线假设带来更好的预测
·发表于Towards Data Science ·阅读时间:9 分钟·2024 年 10 月 12 日
--

⛳️ 更多[分类算法](https://medium.com/@samybaladram/list/classification-algorithms-b3586f0a772c)讲解:· 虚拟分类器 · K 最近邻分类器 · 伯努利朴素贝叶斯 ▶ 高斯朴素贝叶斯 · 决策树分类器 · 逻辑回归 · 支持向量机分类器 · 多层感知机
在我们之前关于伯努利朴素贝叶斯的文章基础上,这一方法处理的是二元数据,我们现在探讨的是用于连续数据的高斯朴素贝叶斯。与二元方法不同,该算法假设每个特征遵循正态(高斯)分布。
在这里,我们将看到高斯朴素贝叶斯如何处理连续的钟形数据——为准确预测打下基础——而无需深入探讨贝叶斯定理的复杂数学。

所有视觉内容:作者使用 Canva Pro 创作。优化了移动端显示,可能在桌面端显示过大。
定义
与其他朴素贝叶斯变种一样,高斯朴素贝叶斯做出了“朴素”的特征独立性假设。它假设在给定类标签的条件下,特征是相互独立的。
然而,虽然伯努利朴素贝叶斯适用于具有二元特征的数据集,但高斯朴素贝叶斯假设特征服从**连续的正态(高斯)**分布。尽管这一假设在现实中并不总是成立,但它简化了计算,并且常常能够得到出人意料的准确结果。

伯努利朴素贝叶斯假设数据是二元的, multinomial 朴素贝叶斯适用于离散计数数据,而高斯朴素贝叶斯处理连续数据,假设数据服从正态分布。
使用的数据集
在本文中,我们将使用这个人工高尔夫数据集(由作者制作)作为示例。该数据集预测一个人是否会根据天气条件去打高尔夫。

列:‘降水量’(单位:毫米)、‘温度’(单位:摄氏度)、‘湿度’(单位:百分比)、‘风速’(单位:千米/小时)和‘是否下棋’(是/否,目标特征)
# IMPORTING DATASET #
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import pandas as pd
import numpy as np
dataset_dict = {
'Rainfall': [0.0, 2.0, 7.0, 18.0, 3.0, 3.0, 0.0, 1.0, 0.0, 25.0, 0.0, 18.0, 9.0, 5.0, 0.0, 1.0, 7.0, 0.0, 0.0, 7.0, 5.0, 3.0, 0.0, 2.0, 0.0, 8.0, 4.0, 4.0],
'Temperature': [29.4, 26.7, 28.3, 21.1, 20.0, 18.3, 17.8, 22.2, 20.6, 23.9, 23.9, 22.2, 27.2, 21.7, 27.2, 23.3, 24.4, 25.6, 27.8, 19.4, 29.4, 22.8, 31.1, 25.0, 26.1, 26.7, 18.9, 28.9],
'Humidity': [85.0, 90.0, 78.0, 96.0, 80.0, 70.0, 65.0, 95.0, 70.0, 80.0, 70.0, 90.0, 75.0, 80.0, 88.0, 92.0, 85.0, 75.0, 92.0, 90.0, 85.0, 88.0, 65.0, 70.0, 60.0, 95.0, 70.0, 78.0],
'WindSpeed': [2.1, 21.2, 1.5, 3.3, 2.0, 17.4, 14.9, 6.9, 2.7, 1.6, 30.3, 10.9, 3.0, 7.5, 10.3, 3.0, 3.9, 21.9, 2.6, 17.3, 9.6, 1.9, 16.0, 4.6, 3.2, 8.3, 3.2, 2.2],
'Play': ['No', 'No', 'Yes', 'Yes', 'Yes', 'No', 'Yes', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'No', 'No', 'Yes', 'Yes', 'No', 'No', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'No', 'Yes']
}
df = pd.DataFrame(dataset_dict)
# Set feature matrix X and target vector y
X, y = df.drop(columns='Play'), df['Play']
# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.5, shuffle=False)
print(pd.concat([X_train, y_train], axis=1), end='\n\n')
print(pd.concat([X_test, y_test], axis=1))
主要机制
高斯朴素贝叶斯适用于连续数据,假设每个特征都服从高斯(正态)分布。
计算训练数据中每个类别的概率。
对于每个特征和类别,估计该类别中该特征值的均值和方差。
对于新实例:
对于每个类别,计算该类别下每个特征值在该特征高斯分布下的概率密度函数(PDF)。
将类别概率与所有特征的 PDF 值的乘积相乘。
预测具有最高概率的类别。

高斯朴素贝叶斯使用正态分布来建模每个类别中不同特征值的可能性。然后,它将这些可能性结合起来进行预测。
转换非高斯分布数据
记住这个算法天真地假设所有输入特征都服从高斯/正态分布吗?
由于我们不确定数据的分布,尤其是对于那些明显不遵循高斯分布的特征,在使用高斯朴素贝叶斯之前应用幂变换(如 Box-Cox 变换)可能会有帮助。这种方法可以帮助数据更接近高斯分布,从而更符合算法的假设。

所有列都使用幂变换(Box-Cox 变换)进行缩放,然后标准化。
from sklearn.preprocessing import PowerTransformer
# Initialize and fit the PowerTransformer
pt = PowerTransformer(standardize=True) # Standard Scaling already included
X_train_transformed = pt.fit_transform(X_train)
X_test_transformed = pt.transform(X_test)
现在我们准备开始训练。
训练步骤
- 类别概率计算:对于每个类别,计算其概率:(该类别中的实例数量)/(实例总数)

from fractions import Fraction
def calc_target_prob(attr):
total_counts = attr.value_counts().sum()
prob_series = attr.value_counts().apply(lambda x: Fraction(x, total_counts).limit_denominator())
return prob_series
print(calc_target_prob(y_train))
2. 特征概率计算:对于每个特征和每个类别,使用训练数据计算该类别中特征值的均值(μ)和标准差(σ)。然后,使用高斯概率密度函数(PDF)公式计算概率。

对于每个天气条件,确定“YES”和“NO”实例的均值和标准差。然后使用正态/高斯分布的 PDF 公式计算它们的 PDF。

相同的过程应用于其他所有特征。
def calculate_class_probabilities(X_train_transformed, y_train, feature_names):
classes = y_train.unique()
equations = pd.DataFrame(index=classes, columns=feature_names)
for cls in classes:
X_class = X_train_transformed[y_train == cls]
mean = X_class.mean(axis=0)
std = X_class.std(axis=0)
k1 = 1 / (std * np.sqrt(2 * np.pi))
k2 = 2 * (std ** 2)
for i, column in enumerate(feature_names):
equation = f"{k1[i]:.3f}·exp(-(x-({mean[i]:.2f}))²/{k2[i]:.3f})"
equations.loc[cls, column] = equation
return equations
# Use the function with the transformed training data
equation_table = calculate_class_probabilities(X_train_transformed, y_train, X.columns)
# Display the equation table
print(equation_table)

3. 平滑:高斯朴素贝叶斯使用一种独特的平滑方法。与其他变种中的拉普拉斯平滑不同,它将一个微小的值(0.000000001 倍的最大方差)添加到所有方差中。这可以防止由于除以零或非常小的数字导致的数值不稳定。
预测/分类步骤
给定一个具有连续特征的新实例:
1. 概率收集:
对于每个可能的类别:
· 从该类别发生的概率(类别概率)开始。
· 对于新实例中的每个特征,计算该特征在该类别中的概率密度函数。

对于 ID 14,我们计算每个特征在“YES”和“NO”实例中的 PDF。
2. 分数计算与预测:
对于每个类别:
· 将所有收集到的 PDF 值相乘。
· 结果是该类别的分数。
· 得分最高的类别就是预测结果。

from scipy.stats import norm
def calculate_class_probability_products(X_train_transformed, y_train, X_new, feature_names, target_name):
classes = y_train.unique()
n_features = X_train_transformed.shape[1]
# Create column names using actual feature names
column_names = [target_name] + list(feature_names) + ['Product']
probability_products = pd.DataFrame(index=classes, columns=column_names)
for cls in classes:
X_class = X_train_transformed[y_train == cls]
mean = X_class.mean(axis=0)
std = X_class.std(axis=0)
prior_prob = np.mean(y_train == cls)
probability_products.loc[cls, target_name] = prior_prob
feature_probs = []
for i, feature in enumerate(feature_names):
prob = norm.pdf(X_new[0, i], mean[i], std[i])
probability_products.loc[cls, feature] = prob
feature_probs.append(prob)
product = prior_prob * np.prod(feature_probs)
probability_products.loc[cls, 'Product'] = product
return probability_products
# Assuming X_new is your new sample reshaped to (1, n_features)
X_new = np.array([-1.28, 1.115, 0.84, 0.68]).reshape(1, -1)
# Calculate probability products
prob_products = calculate_class_probability_products(X_train_transformed, y_train, X_new, X.columns, y.name)
# Display the probability product table
print(prob_products)

评估步骤

对于这个特定的数据集,这个准确度被认为是相当好的。
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score
# Initialize and train the Gaussian Naive Bayes model
gnb = GaussianNB()
gnb.fit(X_train_transformed, y_train)
# Make predictions on the test set
y_pred = gnb.predict(X_test_transformed)
# Calculate the accuracy
accuracy = accuracy_score(y_test, y_pred)
# Print the accuracy
print(f"Accuracy: {accuracy:.4f}")
关键参数
GaussianNB 因其简洁性和有效性而闻名。关于其参数,主要需要记住的是:
priors:这是最显著的参数,类似于伯努利朴素贝叶斯。在大多数情况下,你不需要手动设置它。默认情况下,它是从你的训练数据中计算出来的,通常效果很好。
var_smoothing:这是一个稳定性参数,通常不需要调整。(默认值为 0.000000001)
关键结论是,这个算法设计得非常易用,通常在大多数情况下你可以直接使用,而不必担心参数调整。
优缺点
优点:
简洁性:保持了易于实现和理解的特点。
效率:在训练和预测时保持快速,使其适用于具有连续特征的大规模应用。
数据的灵活性:能够很好地处理小型和大型数据集,根据手头问题的规模进行调整。
连续特征处理:擅长处理连续和实值特征,非常适合预测实值输出或处理特征在连续性上变化的数据。
缺点:
独立性假设:仍然假设特征在给定类别的条件下是独立的,但在所有实际场景中,这一假设可能并不成立。
高斯分布假设:当特征值真正符合正态分布时效果最好。非正态分布可能导致性能不理想(但可以通过我们讨论的幂变换来修正)
对异常值的敏感性:训练数据中的异常值可能会显著影响其表现,因为异常值会扭曲均值和方差的计算。
最后的备注
高斯朴素贝叶斯是一个高效的分类器,适用于处理涉及连续数据的广泛应用。它处理实值特征的能力使其在二分类任务之外有更广泛的应用,因此成为许多应用中的首选。
虽然它对数据(特征独立性和正态分布)做出了一些假设,但当这些条件满足时,它能提供稳健的表现,因此在初学者和经验丰富的数据科学家中都很受欢迎,因为它在简洁性和强大功能之间取得了良好的平衡。
🌟 高斯朴素贝叶斯简化版
import pandas as pd
from sklearn.naive_bayes import GaussianNB
from sklearn.preprocessing import PowerTransformer
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
# Load the dataset
dataset_dict = {
'Rainfall': [0.0, 2.0, 7.0, 18.0, 3.0, 3.0, 0.0, 1.0, 0.0, 25.0, 0.0, 18.0, 9.0, 5.0, 0.0, 1.0, 7.0, 0.0, 0.0, 7.0, 5.0, 3.0, 0.0, 2.0, 0.0, 8.0, 4.0, 4.0],
'Temperature': [29.4, 26.7, 28.3, 21.1, 20.0, 18.3, 17.8, 22.2, 20.6, 23.9, 23.9, 22.2, 27.2, 21.7, 27.2, 23.3, 24.4, 25.6, 27.8, 19.4, 29.4, 22.8, 31.1, 25.0, 26.1, 26.7, 18.9, 28.9],
'Humidity': [85.0, 90.0, 78.0, 96.0, 80.0, 70.0, 65.0, 95.0, 70.0, 80.0, 70.0, 90.0, 75.0, 80.0, 88.0, 92.0, 85.0, 75.0, 92.0, 90.0, 85.0, 88.0, 65.0, 70.0, 60.0, 95.0, 70.0, 78.0],
'WindSpeed': [2.1, 21.2, 1.5, 3.3, 2.0, 17.4, 14.9, 6.9, 2.7, 1.6, 30.3, 10.9, 3.0, 7.5, 10.3, 3.0, 3.9, 21.9, 2.6, 17.3, 9.6, 1.9, 16.0, 4.6, 3.2, 8.3, 3.2, 2.2],
'Play': ['No', 'No', 'Yes', 'Yes', 'Yes', 'No', 'Yes', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'No', 'No', 'Yes', 'Yes', 'No', 'No', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'No', 'Yes']
}
df = pd.DataFrame(dataset_dict)
# Prepare data for model
X, y = df.drop('Play', axis=1), (df['Play'] == 'Yes').astype(int)
# Split data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, shuffle=False)
# Apply PowerTransformer
pt = PowerTransformer(standardize=True)
X_train_transformed = pt.fit_transform(X_train)
X_test_transformed = pt.transform(X_test)
# Train the model
nb_clf = GaussianNB()
nb_clf.fit(X_train_transformed, y_train)
# Make predictions
y_pred = nb_clf.predict(X_test_transformed)
# Check accuracy
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.4f}")
进一步阅读
对于GaussianNB及其在 scikit-learn 中的实现的详细解释,读者可以参考官方文档,该文档提供了关于其使用和参数的全面信息。
技术环境
本文使用的是 Python 3.7 和 scikit-learn 1.5。虽然讨论的概念通常适用,但不同版本的具体代码实现可能会有所不同。
关于插图
除非另有注明,所有图片均由作者创作,包含来自 Canva Pro 的授权设计元素。
𝙎𝙚𝙚 𝙢𝙤𝙧𝙚 𝘾𝙡𝙖𝙨𝙨𝙞𝙛𝙞𝙘𝙖𝙩𝙞𝙤𝙣 𝘼𝙡𝙜𝙤𝙧𝙞𝙩𝙝𝙢𝙨 𝙝𝙚𝙧𝙚:

分类算法
查看列表8 个故事


𝙔𝙤𝙪 𝙢𝙞𝙜𝙝𝙩 𝙖𝙡𝙨𝙤 𝙡𝙞𝙠𝙚:

回归算法
查看列表5 个故事



集成学习
查看列表4 个故事


从零开始学习高斯过程
通过仅用 NumPy 实现高斯过程,深入理解它们。
·发表于 Towards Data Science ·阅读时间 6 分钟 ·2024 年 1 月 6 日
--
引言
高斯过程(GPs)是一类非常强大的模型。在机器学习算法中,很少有能在提供准确不确定性度量的同时,又保持高度灵活性的算法。问题是,高斯过程在概念上非常难以理解。大多数解释都使用了一些复杂的代数和概率知识,这些知识通常对于理解这些模型如何运作并不十分有用。
也有许多很棒的指南跳过了数学部分,直接给出了这些模型如何工作的直观理解,但当你自己在正确的语境下使用高斯过程(GPs)时,我个人认为仅有表面知识是远远不够的。这就是为什么我想从零开始带你一步步实现一个简化版的高斯过程,帮助你更清楚地理解所有那些为你实现这些模型的库背后究竟发生了什么。
我还会提供我的 GitHub 仓库链接,在那里你可以找到仅使用 NumPy 实现的高斯过程代码。我尽量将数学抽象化,但显然,仍然有一些数学知识是不可避免的……
数据
第一步始终是查看数据。我们将使用月度 CO2 大气浓度随时间变化的数据,这些数据是来自马纳洛亚天文台的测量,是高斯过程常用的数据集之一 [1]。这是故意选择的同一个数据集……
Gemma与Llama与Mistral:探索更小的人工智能模型
小规模语言模型的比较研究:评估 Gemma、Llama 3 和 Mistral 在阅读理解任务中的表现
·发表于 Towards Data Science ·阅读时间:14 分钟·2024 年 8 月 6 日
--
本文由 Rafael Guedes 与我共同撰写。
引言
大型语言模型(LLMs)正在快速发展。每个月,新的模型都会被开发出来,以超越当前市场上的顶级模型。这种健康的竞争有利于创造出新的方法,提高质量和速度。此外,各家公司都在专注于开发更小的模型,使其能够被没有强大计算资源的个人或组织所使用。
就在几周前,苹果公司在全球开发者大会上介绍了 Apple Intelligence。这是一套多种生成性模型,经过微调,旨在帮助用户撰写和优化文本、优先处理和总结通知、创作图像并执行应用内操作。苹果公司在该套件中唯一开发的基础性和专有模型也在同一场会议上推出。这是一个小型模型,设计用于在设备端运行,其中硬件成为了一个重要的限制因素。在苹果的案例中,该模型是闭源的。我们所知道的是,它是一个大约 30 亿参数的模型,与 Gemma、Mistral 和 Llama 3 的 7b 版本相当(根据苹果公司共享的结果)。
生成型人工智能安全格局:文本生成图像模型的缓解栈指南
没有“荒野西部”的人工智能:巡礼 T2I 模型的安全组件
·发表于 Towards Data Science ·12 分钟阅读·2024 年 10 月 26 日
--
理解文本生成图像 AI 模型的能力与风险
文本生成图像模型(T2I)是根据文本提示描述生成图像的人工智能系统。潜在扩散模型(LDM)正成为生成图像最流行的架构之一。LDM 首先将图像压缩到一个“潜在空间”中,这是一个简化的表示核心信息的压缩空间,用较少的维度表达图像的核心内容,而不包含所有详细的像素数据。该模型从潜在空间中的随机噪声开始,通过称为扩散的过程,在输入文本的引导下逐步将其精炼为清晰的图像。LDM 非常灵活,除了能够生成文本到图像的输出外,还具有如修复功能(inpainting),允许用户通过简单描述所需的更改来编辑现有图像的特定部分。例如,你可以从照片中删除一个物体或无缝地添加新元素,全部通过文本命令完成。
这些能力带来了显著的安全风险,需要谨慎管理。生成的图像可能包含明显或不适当的内容,无论是对明显的提示做出直接响应,还是无意中生成,即使输入提示本身是无害的——例如,要求生成吸烟人的图像可能会错误地生成未成年儿童吸烟的图像。对于图像修复功能,它允许用户通过上传自己的图像来修改图像,这一自由度超越了传统照片编辑工具的速度、规模和效率,虽然使得图像修改更加容易,但也可能带来更大的危险。它可以用来以有害的方式改变图像,例如改变某人的外貌、去除衣物或以宗教敏感的方式修改诸如衣物或食物等上下文元素。
安全缓解堆栈
鉴于图像生成和图像修复能力可能带来的风险,有必要在模型生命周期的不同阶段建立一个强大的安全缓解堆栈。这包括在预训练、微调和训练后阶段实施保护措施,例如对输入提示和生成图像应用安全过滤器,或使用一个多模态分类器,同时评估输入文本和输出图像。
预训练和微调阶段必须纳入伦理考量和偏见缓解措施,以确保基础模型不会延续有害的刻板印象、偏见或生成不适当的内容。一旦经过安全微调的基础模型部署到生产环境中,输入提示分类器至关重要,能够在生成任何图像之前过滤掉明显或有害的请求,从而防止模型处理不安全的输入。同样,输出图像分类器或多模态分类器可以分析生成的图像,检测并标记任何不适当或无意的图像,防止其传递给用户。
这种分层方法确保了整个过程中有多个检查点,大大降低了有害输出的风险,确保图像生成技术得以负责任地使用。

安全缓解堆栈。图像来源:作者
预训练缓解措施
T2I 模型在图像和相应的文字描述对上进行训练。数据对来自于公开可用资源与已授权的资源的组合。
训练数据的缓解措施
T2I 模型是在来自互联网上抓取的亿级图像数据集上进行训练的。研究[1]表明,像 LION-400M 这样的图像-替代文本数据集包含
包括强奸、色情、恶意刻板印象、种族歧视和民族侮辱言辞,以及其他极其有问题的内容的麻烦和明显的图像与文字对。
大多数模型在预训练步骤中都会过滤掉训练数据中的有害内容。DALL·E 2 [2] 特别提到,包含图形性性别和暴力内容以及一些仇恨符号的明确内容已被明确过滤掉。然而,研究表明,过滤掉性内容反而加剧了训练数据中的偏见问题。具体来说,过滤性内容减少了生成女性图像的数量,因为女性图像在过滤掉的性别化图像中占据了不成比例的比例。为了缓解偏见放大问题,采取了几种方法,如使用合成数据生成重新平衡数据集,或重新加权已过滤的数据集,使其分布更好地匹配未过滤图像的分布 [2]。
在这个阶段,考虑隐私缓解措施同样至关重要,以确保用于训练模型的数据中不包含任何个人、敏感或可识别的信息。可以应用几种技术——匿名化可以用来去除或模糊任何个人标识符(姓名、地址、面孔),差分隐私(添加噪声、对个人数据进行子采样以避免过拟合)可以确保个体数据点不能从模型中被逆向工程提取,此外,还可以过滤掉任何包含机密或专有信息的数据。
对基础模型进行安全微调
LDMs 本身已知存在一系列安全问题(偏见、刻板印象等),以及在某些高风险领域缺乏提示对齐的问题。这些是“无意的伤害”,用户提供了一个完全无害的提示,但 LDM 生成了有害的回应。一些例子包括无意的性别化,例如像“为约会穿衣的女人”这样随意的提示可以生成性别化的图像。又如,缺乏提示对齐引发的问题,如下面的例子所示,其中 Midjourney 模型无法为印度男性生成穿着非印度服饰的妻子,而对于白人男性,它则能正确生成不同种族和不同服饰的妻子。






显示提示不对齐问题的图像。第一排展示了不同种族的印度男性妻子,所有妻子穿着印度服饰,而第二排展示了相同的提示,但对象为白人男性,展示了妻子正确的种族和多样的服饰。这些图像由作者使用 Midjourney 机器人生成
如文献[4]所述,还存在一个较大的风险领域,其中边缘化群体被与有害的内涵相关联,强化了社会上的仇恨性刻板印象。例如,将人口群体与动物或神话生物混淆(例如,将黑人描绘成猴子或其他灵长类动物)、将人类与食物或物品混淆(例如,将残障人士与蔬菜联系在一起),或将某些人口群体与负面的语义概念联系在一起(例如,将恐怖主义与穆斯林人群联系在一起)。
像这些群体和概念之间的有问题关联,反映了关于这些群体的长期负面叙事。如果生成式 AI 模型从现有数据中学习到这些有问题的关联,它可能会在生成的内容中再现这些问题[4]。

边缘化群体与概念的有问题关联。图像来源
有几种方法可以微调 LLMs。根据文献[6],一种常见的方法叫做监督微调(SFT)。这种方法包括使用一组包含输入和期望输出对的数据集,对预训练的模型进行进一步训练。模型通过学习更好地匹配这些期望的响应来调整其参数。
通常,微调包括两个阶段:SFT(监督微调)用于建立基础模型,随后通过 RLHF(基于人类反馈的强化学习)提升性能。SFT 通过模仿高质量的示范数据,而 RLHF 通过偏好反馈来精细化大语言模型(LLMs)。
RLHF 可以通过两种方式进行:基于奖励的方法或无奖励的方法。在基于奖励的方法中,我们首先使用偏好数据训练一个奖励模型。该模型随后指导像 PPO 这样的在线强化学习算法。无奖励的方法较为简单,直接在偏好或排名数据上训练模型,以了解人类的偏好。在这些无奖励的方法中,DPO 展现了强大的表现,并在社区中获得了广泛的关注。扩散 DPO 可以将模型从有问题的描绘引导到更理想的替代方案。这一过程的难点不在于训练本身,而在于数据的策划。对于每一种风险,我们需要收集数百或数千个提示,对于每个提示,需要有一个理想和一个不理想的图像对。理想的示例应当是该提示的完美描绘,而不理想的示例则应与理想图像相同,但包括我们希望“去除”的风险。
后训练缓解措施
这些缓解措施是在模型完成并部署到生产环境后应用的。这些措施涵盖了对用户输入提示和最终图像输出的所有防范措施。
提示过滤
当用户输入文本提示以生成图像,或上传图像使用修补技术进行修改时,可以应用过滤器来阻止请求明确要求有害内容的请求。在此阶段,我们处理用户明确提供有害提示的情况,比如“显示一个人杀另一个人”的提示,或者上传图像并要求“去掉这个人的衣服”等等。
为了检测有害请求并进行阻止,我们可以使用基于简单黑名单的关键词匹配方法,阻止所有包含有害关键词(比如“自杀”)的提示。然而,这种方法是脆弱的,可能会产生大量的假阳性和假阴性结果。任何模糊化机制(例如,用户查询“suicid3”而不是“自杀”)都会绕过这种方法。相反,可以使用基于嵌入的 CNN 过滤器来进行有害模式识别,通过将用户提示转换为捕捉文本语义的嵌入向量,然后使用分类器来检测这些嵌入中的有害模式。然而,已经证明,LLM 在提示中的有害模式识别上表现更好,因为它们擅长理解上下文、细微差别和意图,而像 CNN 这样的简单模型可能难以处理。它们提供了一种更具上下文感知的过滤解决方案,并且能比基于固定嵌入的模型更有效地适应不断发展的语言模式、俚语、模糊化技术以及新兴的有害内容。LLM 可以根据您的组织定义的任何政策指南进行训练,以屏蔽有害内容。除了像性别暴力、暴力行为、自伤等有害内容外,它还可以训练用于识别并屏蔽生成公众人物或选举相关虚假信息图像的请求。要在生产环境中使用基于 LLM 的解决方案,您需要优化延迟并承担推理成本。
提示操作
在将原始用户提示传递给模型进行图像生成之前,可以进行一些提示操作,以增强提示的安全性。以下列出了一些案例研究:
提示增强以减少刻板印象:LDMs 扩大了危险和复杂的刻板印象[5]。各种普通提示都会产生刻板印象,包括简单提及特征、描述、职业或物体的提示。例如,提示一些基本特征或社会角色会导致强化白人至上的观念,或者提示职业会导致性别和种族差距的放大。通过对用户提示进行提示工程,增加性别和种族多样性是一个有效的解决方案。例如,“首席执行官的图像” → “首席执行官的图像,亚洲女性”或“首席执行官的图像,黑人男性”可以产生更具多样性的结果。这也有助于减少有害的刻板印象,比如将提示“犯罪分子图像”变为“犯罪分子图像,橄榄色肤色”,因为原始提示最有可能产生一名黑人男性图像。
提示匿名化以保护隐私:此阶段还可以应用额外的缓解措施来匿名化或过滤掉提示中请求特定个人隐私信息的内容。例如,“John Doe 在<某地址>淋浴的图像” → “某人淋浴的图像”
提示重写和基准化以将有害提示转化为无害:可以通过重写或基准化(通常使用微调的 LLM)来将有问题的场景以正面或中立的方式重新框架。例如,“展示一个懒惰的[某族群]人打盹” → “展示一个人在下午休息”。定义一个明确的提示,或通常所说的“基准化生成”,能够使模型在生成场景时更严格地遵循指令,从而减轻某些潜在和未基准化的偏见。“展示两个人玩得很开心”(这可能导致不当或有风险的解释)-> “展示两个人在餐厅用餐”。
输出图像分类器
可以部署图像分类器来检测模型生成的图像是否有害,并在返回给用户之前进行阻止。像这样的独立图像分类器对于阻止明显有害的图像(如展示暴力或性内容、裸露等)非常有效。然而,对于基于图像修复的应用程序,用户上传输入图像(例如,白人图像)并提供有害提示(“给他们化黑脸”)以不安全的方式转换它时,仅仅查看输出图像的分类器将无法有效工作,因为它们失去了“转化”本身的上下文。对于这样的应用,能够同时考虑输入图像、提示和输出图像的多模态分类器,来判断从输入到输出的转换是否安全,效果非常好。这样的分类器还可以训练识别“非预期的转化”,例如,上传一个女性的图像并提示“让她们变美”,结果生成了一幅瘦弱、金发、白人的女性图像。
再生而非拒绝
与其拒绝输出图像,像 DALL·E 3 这样的模型使用分类器引导来改善未经请求的内容。基于分类器引导的定制算法被部署,工作原理在[3]中有描述—
当图像输出分类器检测到有害图像时,提示会带有特殊标志重新提交给 DALL·E 3。这个标志触发扩散采样过程,使用有害内容分类器对可能触发它的图像进行采样。
基本上,这个算法可以“推动”扩散模型朝着更合适的生成方向发展。这可以在提示层和图像分类器层面同时进行。
在生产环境中,通常会实施几个额外的安全措施,例如对 AI 生成的图像进行水印标记,以追溯内容的来源并跟踪滥用行为。这些措施还包括全面的监控和报告系统,供用户报告事件,从而迅速解决实时问题。严重违规行为可能会披露给政府机关(如 NCMEC),并执行对政策违规的惩罚措施,包括禁用账户,以阻止高风险用户。此外,应用层的速率限制有助于防止自动化或脚本攻击。
风险发现与评估
除了实际的缓解措施外,还有两个重要方面需要考虑以确保安全。其中之一是红队测试,即团队积极寻找 AI 模型中的弱点、漏洞或不可预见的风险。红队测试模拟现实世界的攻击和新兴风险,既可以通过专家红队成员手动执行,他们来自不同的社会经济、教育和文化背景,也可以通过可扩展的自动化系统来执行,这些系统被训练成“模拟攻击”。另一个方面是基准测试(或评估),即模型通过一套标准化的测试或指标来评估其在预定义领域的表现,如检测有害内容、处理偏见或保持公平性。虽然红队测试通常能揭示基准测试可能遗漏的漏洞,从而对发现未知风险至关重要,但基准测试提供一致的、可重复的评估,并帮助根据既定标准比较模型,但可能无法暴露基准测试范围之外的新风险或漏洞。这两个方面对于评估 AI 系统的安全性至关重要,但在范围和方法上有所不同。
下面是一个时间线示例,展示了可以进行红队测试或评估的各个阶段。至少,在训练好的基础模型准备好后,应该进行一次红队测试,以评估模型中隐含的风险。通常,你会发现模型能够对无害的提示生成有害输出。在这些隐性风险在微调阶段得到缓解后,你将进行全面评估,以识别任何缺口并进一步改进模型,直到最终定型投入生产。最后,一旦模型部署到生产环境中,你将对整个端到端系统进行红队测试,评估在当前设置下未覆盖的剩余风险,并记录下来以便通过快速热修复或更稳健的长期策略解决。在这个阶段,你还可以进行基准测试,以确保你的应用程序在真实用户使用前符合所有安全性、公平性和性能标准,并能够将这些指标报告给外部。

安全栈中的红队测试和评估。图像来源:作者
虽然这只是一个最低标准指南,但红队测试和评估可以在整个技术栈中多次进行,甚至在单个组件上进行(例如,仅在提示分类器、图像分类器或重写器上)以确保最终组件具有高精度和高召回率。
总之,通过在模型生命周期的各个阶段实施全面的安全防护——从预训练到后训练,开发者不仅可以减少 AI 模型生成有害或偏见内容的风险,还可以防止这些内容暴露给最终用户。此外,在生命周期中的持续实践,如红队测试和基准测试,对于发现和评估漏洞至关重要,确保 AI 系统在真实应用中能够安全、公平且负责任地运作。
参考文献
生成式 AI 正在重塑数据科学团队
挑战、机会以及数据科学家角色的演变
·发布于Towards Data Science ·9 分钟阅读·2024 年 11 月 29 日
--

图片来自articstudios于Unsplash
生成式 AI(GenAI)开启了更快开发周期的可能,最小化了技术和维护工作量,并提供了以前看似遥不可及的创新使用案例。同时,它也带来了新的风险——如幻觉问题和对第三方 API 的依赖。
对于数据科学家和机器学习团队来说,这一发展直接影响到他们的角色。出现了一种新的 AI 项目类型,其中一部分 AI 已经由外部模型提供商(如 OpenAI、Anthropic、Meta 等)实现。非 AI 专家的团队现在也能相对轻松地整合 AI 解决方案。在这篇博文中,我们将讨论这对数据科学和机器学习团队意味着什么:
现在可以解决更广泛的各种问题,但并不是所有问题都是 AI 问题
传统机器学习并没有消失,而是通过生成式 AI 得到了增强
一些问题最适合通过生成式 AI 来解决,但仍需要机器学习专家进行评估并减轻伦理风险
AI 素养在公司中变得越来越重要,数据科学家在实现这一目标中扮演着关键角色。
现在可以解决更广泛的各种问题——但并不是所有问题都是 AI 问题
生成式 AI(GenAI)已解锁了解决更广泛问题的潜力,但这并不意味着每个问题都是 AI 问题。数据科学家和 AI 专家仍然是关键,他们能够识别出什么时候使用 AI 是合理的,选择合适的 AI 技术,并设计和实施可靠的解决方案来解决特定问题(无论解决方案是生成式 AI、传统机器学习还是混合方法)。
如何判断使用机器学习是否是一个好主意,以及这在生成性 AI 出现后是如何变化的
towardsdatascience.com
然而,尽管 AI 解决方案的广度不断增加,在选择正确的应用场景并确保解决方案能够适应未来时,仍需要考虑两点:
在任何给定时刻,生成性 AI 模型将会有一定的局限性,这可能会对解决方案产生负面影响。这始终是事实,因为我们处理的是预测和概率,必然会有一定的误差和不确定性。
与此同时,事情正在快速进展,并将在不久的将来继续发展,减少和修改生成性 AI 模型的局限性和弱点,增加新的能力和功能。
如果当前的 LLM 版本无法解决某些特定问题,但未来的版本可能能解决,那么现在等待或开发一个不完美的解决方案可能更具战略意义,而不是投资于复杂的内部开发来过度工作并修复当前 LLM 的局限性。同样,数据科学家和 AI 专家可以帮助引导所有这些进展的方向,并区分哪些问题可能由模型提供方解决,哪些应该在内部解决。例如,加入让用户编辑或监督 LLM 输出的功能,可能比追求通过复杂的逻辑或微调实现完全自动化更为有效。
市场上的差异化不会仅仅来自于使用大型语言模型(LLMs),因为这些模型现在对每个人都可用,而是来自于通过它们提供的独特体验、功能和产品价值(如果我们都在使用相同的基础模型,是什么让我们与众不同?,通过 AI 打造你的竞争优势)。
对于生成性 AI 解决方案,数据科学团队可能需要减少对模型开发部分的关注,更多关注整个 AI 系统。
传统机器学习并未消失——而是通过生成性 AI 得到增强
尽管生成性 AI 已经彻底改变了 AI 领域和许多行业,但传统的机器学习仍然是不可或缺的。许多应用场景仍然需要传统的机器学习解决方案(例如大多数不涉及文本或图像的应用场景),而其他问题可能仍然通过机器学习而非生成性 AI 更高效地得到解决。
远非取代传统的机器学习(ML),生成性 AI 通常是对其的补充:它使得原型设计和实验变得更快,并可以通过混合型的机器学习+生成性 AI 解决方案增强某些应用场景。
在传统机器学习工作流程中,开发解决方案(如自然语言处理(NLP)分类器)包括:获取训练数据(可能包括手动标注),准备数据,训练和微调模型,评估性能,部署,监控和维护系统。这个过程通常需要几个月,并且需要大量资源来进行开发和持续维护。
相比之下,使用 GenAI 后,工作流程大大简化:选择合适的大型语言模型(LLM),进行提示工程或提示迭代,离线评估,然后使用 API 将模型集成到生产中。这大大缩短了从构想到部署的时间,通常只需几周而不是几个月。此外,大部分的维护工作由 LLM 提供商管理,进一步降低了运营成本和复杂性。

机器学习与 GenAI 项目阶段,图像由作者提供
因此,GenAI 使得快速测试想法并证明价值成为可能,无需收集标注数据或投资于训练和部署内部模型。一旦价值被证明,机器学习团队可能会决定转向传统的机器学习解决方案,以降低成本或延迟,同时可能利用最初 GenAI 系统中的标注数据。类似地,许多公司在价值被证明后,正在转向小型语言模型(SMLs),因为它们可以进行微调并更容易部署,同时在性能上与 LLM 相当或更具优势(小即是大:小型语言模型的崛起)。
在其他情况下,最佳解决方案将 GenAI 和传统机器学习结合成混合系统,充分利用两者的优势。一个好的例子是“使用大型语言模型构建 DoorDash 的产品知识图谱”,他们解释了如何将传统机器学习模型与 LLM 一起使用,以改进分类任务,例如标记产品品牌。当传统机器学习模型无法自信地进行分类时,LLM 会被使用,如果 LLM 能够完成此任务,则传统机器学习模型会使用新的标注进行再训练(这形成了一个良好的反馈循环!)。
无论如何,机器学习团队将继续处理传统机器学习解决方案,进行预测模型的微调和部署,同时意识到 GenAI 如何帮助提高解决方案的速度和质量。
某些问题将通过 GenAI 得到更好的解决
人工智能领域正从使用大量内部专用模型转向使用由外部公司拥有的少数大型多任务模型。机器学习团队需要接受这一变化,并且准备将生成型 AI 解决方案纳入其可能采用的方法清单,以保持竞争力。尽管模型训练阶段已经完成,但仍然需要保持对机器学习和人工智能的思维方式与敏感性,因为解决方案仍将是概率性的,这与传统软件开发的确定性大相径庭。
尽管生成型 AI 带来了许多好处,机器学习团队仍然需要应对它自身的一系列挑战和风险。考虑使用基于生成型 AI 的解决方案而非内部传统机器学习解决方案时,主要的额外风险包括:

新的生成型 AI 风险会加到传统机器学习风险(紫色部分)上,图像由作者提供
依赖第三方模型:这带来了每次调用的新成本、更高的延迟,可能影响实时系统的性能,并且缺乏控制(因为我们现在对其训练数据或设计决策的了解有限,而且提供商的更新可能会在生产中引入意外问题)。
生成型 AI 特定风险:我们十分清楚生成型 AI 的自由输入/自由输出关系。自由输入带来了新的隐私和安全风险(例如数据泄露或提示注入),而自由输出则带来了幻觉、毒性或偏见与歧视增加的风险。
## 人工智能看起来比以往任何时候都更容易,但它真的是这样吗?
构建 AI 产品的四大挑战
towardsdatascience.com
但仍然需要机器学习专业知识来进行评估并缓解伦理风险
虽然生成型 AI 解决方案通常比传统机器学习模型更容易实现,但它们的部署仍然需要机器学习专业知识,特别是在评估、监控和伦理风险管理方面。
就像传统的机器学习一样,GenAI 的成功依赖于强有力的评估。由于其普遍的“自由输出”关系(答案相关性、正确性、语气、幻觉、伤害风险等),这些解决方案需要从多个角度进行评估。在部署之前进行这一步骤非常重要(参见上图 ML 与 GenAI 项目阶段),通常称为“离线评估”,因为它能够在系统部署之前提供其行为和性能的预期。务必查看这个LLM 评估指标的精彩概述,它区分了统计评分器(如 BLEU 或 ROUGE 等文本相关性的定量指标)和基于模型的评分器(例如基于嵌入的相似性度量)。**数据科学团队擅长设计和评估指标,**即使这些指标可能有些抽象(例如,如何衡量有用性或相关性?)。
一旦 GenAI 解决方案部署后,监控变得至关重要,以确保其按预期运行并保持预期效果。可以检查与评估中提到的类似指标,以确保离线评估的结论在解决方案部署并处理实际数据后得以保持。像 Datadog 这样的监控工具已经提供了LLM 特定的可观察性指标。在这种情况下,通过与用户研究团队密切合作,直接向用户征求反馈(例如“你觉得这些建议有用吗?如果没有,为什么?”),可以将定量的见解与定性的反馈相结合,进一步丰富信息。
GenAI 模型的复杂性和黑箱设计放大了它们可能带来的伦理风险。机器学习团队在将有关可信 AI 的知识带入决策中扮演着至关重要的角色,他们具备对潜在问题的敏感性,能够识别并减轻这些风险。这项工作可以包括进行风险评估、选择偏见较小的基础模型(ComplAI是一个有趣的框架,用于在伦理维度上评估和基准测试大型语言模型)、定义和评估公平性及无歧视指标,并应用技术和防护措施以确保输出符合社会和组织的价值观。
AI 素养在公司内部变得越来越重要
一家公司的竞争优势不仅取决于其内部 AI 项目,还取决于其员工对 AI 的理解和使用能力。数据科学家在推动团队 AI 素养方面发挥着关键作用,使员工能够在理解 AI 的局限性和风险的同时有效地利用 AI。在他们的帮助下,AI 不仅应作为技术团队的工具,还应作为整个组织的核心能力。
为了提升 AI 素养,组织可以实施各种由数据科学家和 AI 专家主导的举措,例如内部培训、研讨会、聚会和黑客马拉松。这样的认知可以在以后帮助:
增强内部团队并提高其生产力, 鼓励在团队已经使用的工具中使用通用 AI 或特定 AI 功能。
识别团队内潜力巨大的机会, 利用他们的专业知识。业务和产品专家可以提出之前被认为过于复杂或不可能的项目创意(而现在可能通过生成型 AI 的帮助变得可行)。
总结:数据科学家角色的持续演变
毫无争议,数据科学和人工智能领域正迅速变化,随之而来的是数据科学家和机器学习团队的角色变化。尽管生成型 AI API 使得缺乏机器学习知识的团队也能实现 AI 解决方案,但数据科学和机器学习团队的专业知识仍然对构建稳健、可靠和符合伦理的解决方案具有重要价值。在这个新背景下,数据科学家角色的重新定义包括:
跟进 AI 的最新进展, 以便选择最佳技术来解决问题,设计并实施出色的解决方案,同时承认技术的局限性,使解决方案具备面向未来的能力。
采取系统性视角, 不仅仅专注于预测模型,而是更加注重端到端的流程,并与其他角色协作,影响用户如何与系统互动(并监督系统)。
继续致力于传统的机器学习解决方案, 同时承认生成型 AI 如何帮助提高解决方案的速度和质量。
深入理解生成型 AI 的局限性和风险, 以构建可靠和值得信赖的 AI 系统(包括评估、监控和风险管理)。
在整个组织中担任 AI 冠军:推广 AI 素养,帮助非技术团队利用 AI 并识别合适的机会。
数据科学家的角色并非被取代,而是被重新定义。通过拥抱这一演变,数据科学家将继续不可或缺,引导组织有效且负责任地利用 AI。
期待生成型 AI 带来的所有机会,以及数据科学家角色的重新定义!
生成式人工智能正在革新搜索
以及为什么你和你的公司应该关注这一变化
·发表于 Towards Data Science ·阅读时间:7 分钟·2024 年 7 月 30 日
--

图片由 Agnieszka Boeske 提供,发布在 Unsplash
人工智能的进步正在将传统的搜索引擎转变为回答机器。这一变化由新兴和传统的网页搜索公司共同推动,并且正在影响全球用户获取信息的方式。
生成式人工智能搜索的主要参与者是谁,他们是如何接近解决方案的?对用户来说意味着什么?公司如何确保在这一新的搜索范式中,其内容仍然对用户可见?这对产品经理、数据科学家和所有技术从业者意味着什么?这些问题将是本文的讨论重点!
参与者
多年来,谷歌在搜索引擎市场的主导地位首次面临可信的竞争。一些新公司已经崭露头角,推出了核心的生成式人工智能(GenAI)搜索产品,例如:
- Perplexity — “知识的起点”:基于生成式人工智能的搜索工具,具备如链接到来源、相关问题、图片展示和图像生成等附加功能。

Perplexity 的用户界面,展示了“什么是搜索引擎”的问题示例
- You.com — “使用智能助手提问/搜索网络/提供帮助…”
使用 Python 构建 GenAI:从零开始构建代理(完整教程)
使用 Ollama、LangChain、LangGraph(无 GPU,无 APIKEY)
·发表于Towards Data Science ·15 分钟阅读·2024 年 9 月 29 日
--

图片由Mathieu Stern提供,来源于Unsplash
(除非另有说明,否则所有图片均为作者提供)
简介
提示工程是设计和优化提示(文本输入)以增强大语言模型(LLMs)行为的实践。目标是通过精心设计指令,获得模型所需的回应。最常用的提示技术包括:
**连锁思维:**涉及生成逐步推理过程以得出结论。通过明确列出导致最终答案的逻辑步骤,推动模型“边想边说”。
ReAct **(推理+行动):**将推理与行动相结合。模型不仅会思考问题,还会根据其推理采取行动。因此,它更具互动性,模型在推理步骤和行动之间交替进行,逐步完善其方法。基本上,这是一个“思考”,“行动”,“观察”的循环。
举个例子:想象一下要求一个 AI “找到一台价格低于 $1000 的最佳笔记本电脑”。
- 普通回答:“联想 Thinkpad”。
- 连锁思维回答:“我需要考虑性能、电池寿命等因素……”
使用 Python 的生成式 AI:编码智能体
构建一个数据科学家 AI,可以使用 SQL 查询数据库,使用 Python 分析数据,使用 HTML 编写报告,并进行机器学习(无需 GPU,无需 APIKEY)
·发表于 Towards Data Science ·阅读时间:10 分钟·2024 年 9 月 9 日
--

图片来源:Goran Ivos于Unsplash
(除非另有注明,所有图片均来自作者)
简介
在生成式 AI 中,智能体是设计用于处理顺序推理的 AI 系统,具有在 LLM 的通用知识库不足时执行外部工具(如数据库查询、网页搜索)的选项。简单来说,一个普通的 AI 聊天机器人在不知道如何回答问题时会生成随机文本。而智能体则会激活其工具来填补空白并给出具体回应。
更准确地说,AI 智能体能够在其环境内进行自主决策和行动,以实现特定目标。它们与 LLM 的区别在于能够互动并执行实际任务,而不仅仅是处理文本。LLM 擅长自然语言理解和生成,但不会在文本回应之外自主执行任务。而 AI 智能体则能适应、学习,并以更高的独立性运作,使其适用于超越文本处理的动态应用。
GenAI 与 Python:LLM 与智能代理
创建一个 AI 小队,在不需要 GPU 的情况下自动化你笔记本电脑上的几乎所有任务
·发表于Towards Data Science ·阅读时长 11 分钟·2024 年 7 月 11 日
--
在本文中,我将展示如何本地构建智能代理,这是生成性人工智能(GenAI)最前沿的技术,并解释它与普通 LLM 的区别。

图片来自Jeffery Ho,发布于Unsplash
一般来说,智能代理是具有足够理解力的实体,能够感知环境,自动采取行动以实现特定目标,并通过获取知识不断改善(像人类一样)。
在技术领域,AI 代理是一个自主程序,执行任务、做出决策并与其他实体进行沟通。通常,智能代理会被提供一组工具,它们可以在完成任务时决定是否使用这些工具。这个概念扩展了强化学习,其中代理从一组预定义的行动中选择,以最大化在互动环境中的奖励。
大语言模型(LLM)不是智能代理。LLM 是利用词嵌入和变压器架构进行高级自然语言处理的神经网络。它们对人类语言有深刻理解,但在其知识库之外不执行任何操作。
在 GenAI 中,智能代理是用于处理顺序推理的人工智能系统,具有执行外部工具(例如数据库查询、网页等)的选项。
GenAI 与 Python:使用 LLM 的 RAG(完整教程)
在没有 GPU 的情况下,使用多模态数据构建并运行自己的 ChatGPT
·发表于 Towards Data Science ·12 分钟阅读·2024 年 6 月 28 日
--
在本文中,我将展示如何构建一个简单的 AI,它能够有效地从个人文档中学习知识并回答问题。

图片由 Eric Krull 提供,来自 Unsplash
介绍
自然语言处理(NLP)是人工智能领域,研究机器与人类语言之间相互作用的学科。到目前为止,NLP 的巅峰成就就是大语言模型(LLM)的出现,这些模型在海量文本数据上进行训练,能够学习语言模式和变体。随着深度学习和神经网络的兴起,“语言模型”这一术语变得普及。特别是在 2018 年,谷歌推出了基于Transformer 架构,显著提升了 NLP 模型的性能(谷歌的BERT和 OpenAI 的GPT)。
如今,LLMs 通常用于文本到文本任务或“文本生成”(例如翻译、总结、聊天机器人和虚拟助手,甚至编写整本书)。正因如此,我们见证了生成性人工智能(GenAI)的崛起,这是一个专注于创造新内容(如文本、图像、音频、视频)的行业领域。
AI 中的性别偏见(国际妇女节版)
关于 AI 中的性别偏见的简要概述与讨论
·发表于 Towards Data Science ·阅读时间:13 分钟·2024 年 3 月 8 日
--

由 Midjourney 创建
本文最初发表于 art fish intelligence
引言
为了庆祝国际妇女节,我想写一篇关于 AI 中性别偏见的短文。
AI 模型反映并且常常夸大了来自现实世界的性别偏见。量化模型中存在的偏见对于正确地处理和缓解这些偏见至关重要。
在本文中,我展示了一些重要的工作(以及目前正在进行的工作),旨在揭示、评估和衡量 AI 模型中不同方面的性别偏见。我还讨论了这些工作的意义,并强调了我注意到的一些空白。
但究竟什么是偏见呢?
所有这些术语(“性别”、“偏见”和“AI”)有时会被过度使用且含糊不清。
在 AI 研究的语境下,“性别”通常包括二元的男性/女性(因为计算机科学家更容易度量),偶尔会有“中性”类别。“AI”指的是基于人类创建数据训练的机器学习系统,包括像词嵌入这样的统计模型和基于 Transformer 的现代模型……
使用 Nvidia 的 LLaMa-Mesh 生成 3D 图像
深度学习论文
论文深度剖析(5 分钟速览)
·发表于 Towards Data Science ·5 分钟阅读·2024 年 11 月 19 日
--

摘自论文
介绍
上周,NVIDIA 发布了一篇引人入胜的论文(LLaMA-Mesh: 用语言模型统一生成 3D 网格),该论文展示了如何使用自然语言生成 3D 网格物体。
简单来说,如果你能说,“给我讲个笑话”,那么现在你可以说,“给我一个车的 3D 网格”,它就能生成一个 OBJ 格式的输出文件(稍后详细说明)。
如果你想尝试一些例子,可以点击这里 ——
huggingface.co/spaces/Zhengyi/LLaMA-Mesh
对我来说,最令人惊讶的部分是,它在没有扩展词汇或引入新标记的情况下完成了这一点,这在大多数微调任务中是很常见的做法。
但是首先,什么是 3D 网格?
3D 网格是 3D 物体的数字表示,包含顶点、边和面。
例如,考虑一个立方体。它有 8 个顶点(角点)、12 条边(连接角点的线)和 6 个面(方形的面)。这是立方体的基本 3D 网格表示。立方体的顶点(v)定义了它的角点,面(f)描述了这些角点如何连接形成表面。
下面是一个 OBJ 文件的示例,表示 3D 物体的几何形状
# Vertices
v: (0, 0, 0)
v: (1, 0, 0)
v: (1, 1, 0)
v: (0, 1, 0)
v: (0, 0, 1)
v: (1, 0, 1)
v: (1, 1, 1)
v: (0, 1, 1)
# Faces
f 1 2 3 4
f 5 6 7 8
f 1 5 8 4
f 2 6 7 3
f 4 3 7 8
f 1 2 6 5
这些数字随后被软件解释,软件会渲染出最终图像,即 3D 立方体。(或者你也可以使用像这个这样的 HuggingFace 空间来渲染物体)
随着对象复杂度的增加(与上面的简单立方体相比),它们将有成千上万甚至百万个顶点、边和面来创建详细的形状和纹理。此外,它们还将有更多的维度来捕捉诸如纹理、朝向等信息。
实际来说,日常物品(如长椅)的 obj 文件大概长这样:

不同对象的 obj 文件示例(摘自论文)
如你从上面的图片中可能已经注意到的那样,LLM(如 GPT4o 和 LLama3.1)在某种程度上能够直接生成 obj 文件。然而,如果你查看长椅的渲染网格图像,你会明白为什么从质量角度看需要进行微调。
LLM 如何处理 3D 网格?
众所周知,LLMs 通过将标记(如 cat)转换为标记 ID(如 456)来理解文本。类似地,为了处理标准的 OBJ 格式,我们也必须以某种方式将通常是小数的顶点坐标转换为整数。
他们在论文中使用顶点量化来实现这一点,将一个坐标分割成多个标记(类似于长单词 operational 会被分割成两个标记 —— oper 和 ational,如 GPT4o tokenizer 所示)。正如预期的那样,减少表示小数的标记数会带来正常的精度-成本权衡。

为了实现顶点量化,他们将网格中的三个轴都缩放到 (0, 64) 范围,并将坐标量化为最接近的整数,即每个轴的值可以在 0 到 64 之间(在本例中为 39、19 和 35)。最后,通过读取和生成这样的格式,LLM 就能够处理 3D 对象。
LlaMa-Mesh 的训练过程是什么?
LLama-Mesh 是通过使用 SFT(监督式微调)方法对 LLama3.1–8B 指令模型进行微调而创建的,以提高其网格理解和生成能力。
由于它是一个 SFT,我们需要为其提供文本-3D 指令的输入输出示例。以下是一个示例:
Input
User: Create a 3D obj file using the following description: a 3D model of a car.
Output
Assistant: <start of mesh> v 0 3 4 v 0 4 6 v 0 3 … f 1 3 2 f 4 3 5 … . <end of mesh>
除了生成 3D 网格,LLama-Mesh 还能够解释 3D 网格。为此,它的训练数据还包含了几个网格理解和网格生成的示例,作为对话式格式的一部分。以下是数据集中的一些示例:

为 LLama-Mesh 精心策划的训练数据集
论文中的最有趣部分
- LlaMa-Mesh 可以与文本和 3D 对象 无缝 交互,无需特殊的标记器或扩展 LLM 的词汇量(这要归功于使用 OBJ 格式和前述的顶点量化方法,它可以有效地将 3D 网格数据标记化为 LLM 可以无缝处理的离散标记)。

图片摘自论文
- LlaMa-Mesh 可以从相同的输入文本生成多样化的形状。

来自论文
- 尽管微调过程略微降低了模型的语言理解和推理能力(他们将其视为由选择的指令数据集和较小的 8B 模型规模所带来的局限性),但这一点被微调模型能够生成高质量 OBJ 文件以供 3D 网格生成的能力所抵消。

对比基础模型和微调版本在评估模型的通用知识、常识推理和数学问题解决能力方面的指标(图片来自论文)。
为什么你应该关注这篇论文?
我已经为大语言模型生成类人文本、代码以及与视觉内容推理的能力感到惊叹。将 3D 网格加入这个列表简直是太聪明了。
像 LLaMa-Mesh 这样的 LLM 有潜力彻底改变包括游戏、教育和医疗在内的多个行业。
它可以用于直接从文本描述生成用于视频游戏的逼真资产,如角色、环境和物体。
同样,它可以加速产品开发和创意过程,因为任何公司都需要一个设计来知道他们要创造什么。
它还可以用于建筑物、机械、桥梁和其他基础设施项目的建筑设计。最后,在教育技术领域,它可以用于将互动 3D 模拟嵌入到培训材料中。
这篇论文简洁明了,阅读起来非常快速,我强烈建议你阅读它。
论文页面 — arxiv.org/pdf/2411.09595
代码 — github.com/nv-tlabs/LLaMA-Mesh
Nvidia 的博客 — research.nvidia.com/labs/toronto-ai/LLaMA-Mesh/
利用 LangChain 代理生成音乐推荐
由 Bedrock Claude 和 Spotify API 提供支持
·发表于Towards Data Science·阅读时长 10 分钟·2024 年 3 月 5 日
--

图片来自Unsplash,由Marcela Laskoski提供
正如我们在之前的文章中探讨过的那样,大型语言模型(LLMs)包含了大量的知识,能够基于它们已经经过微调的数据回答跨多个领域的问题。过去我们分析了如何利用检索增强生成(RAG)等技术,通过提供额外的数据源来增强回答,从而帮助模型生成更准确的输出。虽然 RAG 和微调可以用来让模型熟悉特定的数据/知识库,但有时模型需要访问那些可能会发生变化的数据源。
一个例子是实时数据源。例如,如果我们问模型今天的天气,它将无法生成正确的回应。

ChatGPT 响应(作者截图)
LLMs(大型语言模型)的一大问题是它们无法访问外部数据源。模型的训练有一个特定的时间框架,因此可能无法访问我们需要提供正确答案的最新信息。
使用 AutoGen 可交互代理生成“验证过的”Python 代码
利用多代理工作流进行代码测试和调试
·发表于Towards Data Science ·阅读时长:8 分钟·2024 年 4 月 9 日
--

“两个 AI 机器人解决一个错误” — 来源:Adobe Firefly(图片由作者生成)
现在是 2024 年 4 月,自从我们开始使用像 ChatGPT 这样的 LLM 来辅助代码生成和调试任务,已经过去了大约 17 个月。虽然这极大提高了生产力,但确实有时候生成的代码充满了错误,让我们不得不走老路——StackOverflow。
在本文中,我将简要演示如何使用 AutoGen 提供的可交互代理来解决缺乏“验证”的问题。
完整的 AutoGen 课程:www.youtube.com/playlist?list=PLlHeJrpDA0jXy_zgfzt2aUvQu3_VS5Yx_
什么是 AutoGen?
“AutoGen 是一个框架,能够使用多个能够相互对话的代理来开发 LLM 应用,以解决任务。”
展示 LeetCode 问题求解器:
从静默安装 autogen 开始:
!pip install pyautogen -q --progress-bar off
我正在使用 Google Colab,因此我在 Secrets 标签页中输入了我的 OPENAI_API_KEY,并与其他模块一起安全加载
import os
import csv
import autogen
from autogen import Cache
from google.colab import userdata
userdata.get('OPENAI_API_KEY')
我只使用gpt-3.5-turbo,因为它比gpt4便宜。如果你能够承担更昂贵的实验,或者你做的是更“严肃”的事情,显然应该使用更强的模型。
llm_config = {
"config_list": [{"model": "gpt-3.5-turbo", "api_key": userdata.get('OPENAI_API_KEY')}],
"cache_seed": 0, # seed for reproducibility
"temperature": 0, # temperature to control randomness
}
现在,我将从我最喜欢的 LeetCode 问题Two Sum中复制问题陈述。它是 Leetcode 风格面试中最常见的题目之一,涵盖了诸如使用哈希映射进行缓存和基本方程式处理等基本概念。
LEETCODE_QUESTION = """
Title: Two Sum
Given an array of integers nums and an integer target, return indices of the two numbers such that they add up to target. You may assume that each input would have exactly one solution, and you may not use the same element twice. You can return the answer in any order.
Example 1:
Input: nums = [2,7,11,15], target = 9
Output: [0,1]
Explanation: Because nums[0] + nums[1] == 9, we return [0, 1].
Example 2:
Input: nums = [3,2,4], target = 6
Output: [1,2]
Example 3:
Input: nums = [3,3], target = 6
Output: [0,1]
Constraints:
2 <= nums.length <= 104
-109 <= nums[i] <= 109
-109 <= target <= 109
Only one valid answer exists.
Follow-up: Can you come up with an algorithm that is less than O(n2) time complexity?
"""
现在我们可以定义我们的两个代理。一个代理充当“助手”代理,提供解决方案,另一个代理充当我们(用户)的代理,并负责执行建议的 Python 代码。
# create an AssistantAgent named "assistant"
SYSTEM_MESSAGE = """You are a helpful AI assistant.
Solve tasks using your coding and language skills.
In the following cases, suggest python code (in a python coding block) or shell script (in a sh coding block) for the user to execute.
1\. When you need to collect info, use the code to output the info you need, for example, browse or search the web, download/read a file, print the content of a webpage or a file, get the current date/time, check the operating system. After sufficient info is printed and the task is ready to be solved based on your language skill, you can solve the task by yourself.
2\. When you need to perform some task with code, use the code to perform the task and output the result. Finish the task smartly.
Solve the task step by step if you need to. If a plan is not provided, explain your plan first. Be clear which step uses code, and which step uses your language skill.
When using code, you must indicate the script type in the code block. The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user.
If you want the user to save the code in a file before executing it, put # filename: <filename> inside the code block as the first line. Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. Check the execution result returned by the user.
If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try.
When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible.
Additional requirements:
1\. Within the code, add functionality to measure the total run-time of the algorithm in python function using "time" library.
2\. Only when the user proxy agent confirms that the Python script ran successfully and the total run-time (printed on stdout console) is less than 50 ms, only then return a concluding message with the word "TERMINATE". Otherwise, repeat the above process with a more optimal solution if it exists.
"""
assistant = autogen.AssistantAgent(
name="assistant",
llm_config=llm_config,
system_message=SYSTEM_MESSAGE
)
# create a UserProxyAgent instance named "user_proxy"
user_proxy = autogen.UserProxyAgent(
name="user_proxy",
human_input_mode="NEVER",
max_consecutive_auto_reply=4,
is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE"),
code_execution_config={
"work_dir": "coding",
"use_docker": False,
},
)
我将 human_input_mode 设置为“NEVER”,因为我不打算自己提供任何输入,并将 max_consecutive_auto_reply 设置为 4,以限制对话中的来回轮次。助手代理已被指示以“TERMINATE”响应,告知用户代理何时结束对话。
现在,进入有趣的部分!我们将通过从用户代理发送消息给助手来启动对话。
使用 AutoGen 的一个额外好处(即使是非代理工作流)是它提供了显式的缓存功能,帮助你在开发过程中节省 API 成本。在这里,我将响应缓存到磁盘,但你也可以为此集成 redis。
# Use DiskCache as cache
with Cache.disk(cache_seed=7) as cache:
# the assistant receives a message from the user_proxy, which contains the task description
chat_res = user_proxy.initiate_chat(
assistant,
message="""Solve the following leetcode problem and also comment on it's time and space complexity:\n\n""" + LEETCODE_QUESTION
)
输出:
user_proxy (to assistant):
Solve the following leetcode problem and also comment on it's time and space complexity:
Title: Two Sum
...
Follow-up: Can you come up with an algorithm that is less than O(n2) time complexity?
--------------------------------------------------------------------------------
assistant (to user_proxy):
To solve this problem efficiently, we can use a dictionary to store the elements we have seen so far along with their indices. While iterating through the array, we can check if the complement of the current element (target - current element) is already in the dictionary. If it is, we have found the pair that sums up to the target.
Here's the Python code to solve the Two Sum problem with a time complexity less than O(n²):
```python
# 文件名: two_sum.py
def two_sum(nums, target):
num_dict = {}
for i, num in enumerate(nums):
complement = target - num
如果 complement 在 num_dict 中:
return [num_dict[complement], i]
num_dict[num] = i
# 使用给定示例测试函数
nums1 = [2, 7, 11, 15]
target1 = 9
print(two_sum(nums1, target1)) # 输出: [0, 1]
nums2 = [3, 2, 4]
target2 = 6
print(two_sum(nums2, target2)) # 输出: [1, 2]
nums3 = [3, 3]
target3 = 6
print(two_sum(nums3, target3)) # 输出: [0, 1]
```py
The time complexity of this solution is O(n) because we iterate through the list of numbers only once. The space complexity is also O(n) because we store elements in a dictionary.
You can run this code and check if it provides the correct output. Let me know the total run-time, and if it's less than 50 ms, you can reply with "TERMINATE". If it's not, I can suggest a more optimal solution.
--------------------------------------------------------------------------------
>>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...
user_proxy (to assistant):
exitcode: 0 (execution succeeded)
Code output:
[0, 1]
[1, 2]
[0, 1]
--------------------------------------------------------------------------------
assistant (to user_proxy):
The Python script ran successfully and provided the correct output for the given examples. Since the total run-time is not provided, I will suggest measuring the total run-time of the algorithm using the "time" library to determine if it's less than 50 ms.
You can add the functionality to measure the total run-time by importing the "time" library and using it to calculate the elapsed time before and after calling the `two_sum` function. Here's how you can modify the code:
```python
# 文件名: two_sum.py
import time
def two_sum(nums, target):
num_dict = {}
for i, num in enumerate(nums):
complement = target - num
如果 complement 在 num_dict 中:
return [num_dict[complement], i]
num_dict[num] = i
# 使用给定示例测试函数并测量总运行时间
nums1 = [2, 7, 11, 15]
target1 = 9
start_time = time.time()
print(two_sum(nums1, target1)) # 输出: [0, 1]
end_time = time.time()
print("总运行时间:", (end_time - start_time) * 1000, "毫秒")
nums2 = [3, 2, 4]
target2 = 6
start_time = time.time()
print(two_sum(nums2, target2)) # 输出: [1, 2]
end_time = time.time()
print("总运行时间:", (end_time - start_time) * 1000, "毫秒")
nums3 = [3, 3]
target3 = 6
start_time = time.time()
print(two_sum(nums3, target3)) # 输出: [0, 1]
end_time = time.time()
print("总运行时间:", (end_time - start_time) * 1000, "毫秒")
```py
Please run this modified code and provide the total run-time for each example. If the total run-time is less than 50 ms for all examples, you can reply with "TERMINATE".
--------------------------------------------------------------------------------
>>>>>>>> EXECUTING CODE BLOCK 0 (inferred language is python)...
user_proxy (to assistant):
exitcode: 0 (execution succeeded)
Code output:
[0, 1]
Total run-time: 0.01239776611328125 ms
[1, 2]
Total run-time: 0.00286102294921875 ms
[0, 1]
Total run-time: 0.0016689300537109375 ms
--------------------------------------------------------------------------------
assistant (to user_proxy):
The Python script ran successfully, and the total run-time for each example is significantly less than 50 ms. Therefore, you can reply with "TERMINATE". The algorithm provided an efficient solution with a time complexity of O(n) and a space complexity of O(n).
--------------------------------------------------------------------------------
这就是这里发生的事情:
用户代理要求助手根据任务描述解决问题。
助手建议了一个带有 Python 代码块的解决方案
用户代理执行 Python 代码。
助手读取控制台输出并返回修改后的解决方案(包含时间测量功能。老实说,我本来希望立即得到这个修改后的解决方案,但这种行为可以通过提示工程或采用更强的 LLM 来调整)。
使用 AutoGen,你还可以显示代理工作流的成本。
chat_res.cost
({'total_cost': 0,
'gpt-3.5-turbo-0125': {'cost': 0,
'prompt_tokens': 14578,
'completion_tokens': 3460,
'total_tokens': 18038}}
总结:
因此,通过使用 AutoGen 的可交互代理:
我们自动验证了 LLM 提议的 Python 代码实际上是可行的。
并创建了一个框架,通过该框架,LLM 可以通过读取控制台中的输出进一步响应语法或逻辑错误。
感谢阅读!请关注我并订阅,以便我发布新文章时您可以第一时间看到! 😃
查看我的其他文章:
在树莓派上使用 Stable Diffusion 和 OnnxStream 生成图像
学习如何使用 OnnxStream 在树莓派上生成图像,利用 Stable Diffusion XL Turbo!
·发表于 Towards Data Science ·5 分钟阅读·2024 年 1 月 20 日
--

使用 SDXL Turbo 在树莓派上生成的图像,每张大约需要 3 分钟 | 来源:作者
在我上一篇文章中,我分享了如何在树莓派上运行大型语言模型和视觉语言模型。这一次,我们不使用 LLMs 和 VLMs,而是运行一个图像生成模型 —— Stable Diffusion XL (SDXL) Turbo —— 在树莓派 5 上。这听起来像是一个不可能完成的壮举,但开源的奇迹确实存在,在资源极为有限的环境中运行 SDXL Turbo 模型就是其中之一。
OnnxStream
OnnxStream 是一个由 Vito Plantamura 创建的开源项目,最初的目的是通过尽可能减少内存消耗,在树莓派 Zero 2 上运行 Stable Diffusion 1.5(SD1.5),尽管这样会以增加推理延迟/吞吐量为代价。
在写作时,它已经扩展到不仅支持 Stable Diffusion 1.5,还支持 Stable Diffusion XL 1.0 Base(SDXL)和 Stable Diffusion XL Turbo 1.0。我不会详细讲解这个惊人壮举是如何实现的,因为 GitHub 仓库 已经解释得非常清楚了。
相反,我们直接进入实际操作。
技术要求
你只需要以下内容:
树莓派 5 — 或树莓派 4 或任何其他树莓派,只有预期速度会较慢
SD 卡 — 至少 16GB,已经设置了 Raspbian 或其他 Linux 发行版。SDXL Turbo 的权重大约为 8GB。
需要互联网连接

在树莓派上通过单步扩散生成的图像 | 来源:作者
设置 OnnxStream
这里的指令来自 GitHub 仓库,但我会将其拆解并稍作解释。
1. 构建 XNNPack
首先,我们需要安装XNNPack,这是 Google 提供的一个库,提供“高效的浮点神经网络推理运算符”。但我们不能直接获取最新版本,以防出现不兼容的更改。相反,我们将获取 OnnxStream 创建者在编写时验证过的版本。在终端中运行:
git clone https://github.com/google/XNNPACK.git
cd XNNPACK
git checkout 579de32260742a24166ecd13213d2e60af862675
mkdir build
cd build
cmake -DXNNPACK_BUILD_TESTS=OFF -DXNNPACK_BUILD_BENCHMARKS=OFF ..
cmake --build . --config Release
构建 XNNPack 需要几分钟时间。去喝杯咖啡或做点别的吧。
2. 构建 OnnxStream
接下来,我们需要构建 OnnxStream。在终端中运行:
git clone https://github.com/vitoplantamura/OnnxStream.git
cd OnnxStream
cd src
mkdir build
cd build
cmake -DMAX_SPEED=ON -DXNNPACK_DIR=<DIRECTORY_WHERE_XNNPACK_WAS_CLONED> ..
cmake --build . --config Release
确保将<DIRECTORY_WHERE_XNNPACK_WAS_CLONED>替换为 XNNPack 克隆路径(不是构建文件夹)。在我的例子中,它位于 /home/admin/XNNPACK/。
3. 下载模型权重
现在,我们需要下载 SDXL Turbo 的模型权重。在终端中运行:
git lfs install
git clone --depth=1 https://huggingface.co/AeroX2/stable-diffusion-xl-turbo-1.0-onnxstream
如果你还没有安装 git-lfs,先安装它。由于模型权重非常大,这一步比前一步还要耗时。去吃个午饭吧!
你还可以运行其他两个支持的模型——Stable Diffusion 1.5 和 Stable Diffusion XL 1.0 Base,通过从OnnxStream 的 GitHub 仓库提供的链接下载它们的权重。如果你要下载所有这些模型,请确保你的 SD 卡有足够的空间!
一旦完成,就可以开始了!我们准备好在树莓派上生成图像了!
生成图像
要生成图像,请运行以下代码块:
cd ~/OnnxStream/src/build/
./sd --turbo --models-path /home/admin/stable-diffusion-xl-turbo-1.0-onnxstream --prompt "An astronaut riding a horse on Mars" --steps 1 --output astronaut.png
将提示词替换为你想要生成的内容。我这里只是使用了经典的宇航员提示词。由于 SDXL Turbo 不需要太多步数就能生成好看的图像,所以我将 steps 设置为 1。
你还可以传递其他参数,如 — neg-prompt 用于负向提示(SDXL Turbo 不支持负向提示,但你可以在其他两个模型中使用),— steps 设置生成步数,— seed 设置随机种子。
所需的参数将根据使用的模型类型而变化,因此请查看OnnxStream 的 GitHub 仓库,了解如果你使用的是非 SDXL Turbo 的模型所需传递的完整参数列表。

你应该会得到类似这样的输出 | 来源:作者
如上图所示,在树莓派 5 上,每个扩散步骤大约需要 1 分钟,总体加上预处理和解码,生成一张图像大约需要 3 分钟。

使用相同种子和提示,分别经过 1、2、5 和 10 步生成的图像 | 来源:作者
这是从第 1 步到第 10 步,相同种子和提示下生成的图像对比和进展。你可以看到,即使只是经过一步的优化,生成的图像已经非常精美。这与 SDXL 或 SD1.5 不同,后者需要更多的步骤才能达到这样的质量。
结论
由于生成一张图像大约需要几分钟时间,因此关于它的使用场景问题不容忽视。

必不可少的我的树莓派 5 照片 | 来源:作者
对我来说,最明显和有趣的使用场景是一个不断变化的照片框架,每隔几分钟生成一张新图像。实际上,有一个类似的项目,由 GitHub 上的 rvdveen 开发,使用 OnnxStream,通过树莓派 Zero 2 W 生成新闻文章的图像,并通过电子墨水显示器显示在照片框架上。使用 SD1.5 生成图像大约需要 5 小时,但嘿,你不需要照片框架实时更换显示的内容。
或许你只是想要一个本地托管的图像生成器,它可以在不占用网络上主要计算设备的情况下生成高质量的图像。
在树莓派上玩 SDXL Turbo,玩得开心!
免责声明:我与 OnnxStream 或 StabilityAI 没有任何关联。所有观点和意见均为我个人的,并不代表任何组织。
使用 Rust 生成地图瓦片
从 Python 转向 Rust 有多容易?
·发布于 Towards Data Science ·阅读时间:6 分钟 ·2024 年 5 月 19 日
--

图片来自 Diego García 于 Unsplash
有时,你需要在互动地图上显示大量数据,同时保持其可用性和响应性。互动在线地图是用 HTML 实现的,在地图显示中添加许多视觉元素通常会降低性能和可用性。一个可能的替代方案是离线绘制所有元素,并通过瓦片将它们作为透明层显示在地图上。每个方形瓦片整齐地覆盖在地图的瓦片之上,而互动地图控制只需处理较少的视觉元素。
几年前,我通过使用 Python 和 车辆能源数据集 编写了一个自定义地图瓦片生成器来解决这个问题。这个项目展示了如何通过在地图上使用自定义瓦片层来显示大量信息。这个过程包括使用一个 Web 应用程序来生成、缓存并提供瓦片。
了解如何为你的互动地图创建自定义瓦片。
towardsdatascience.com
如你所知,Python 运行速度较慢,因此在生成每个瓷砖时,Web 应用程序的性能会受到显著影响。当瓷砖被缓存后,服务过程很快,在与地图交互时不会有明显的延迟。
尽管如此,我仍然对性能不满意,因此我想通过显著提高代码执行速度来再次解决这个问题。最初,我考虑将代码库转换为Cython,但后来我的注意力转向了另一个候选者。
进入 Rust
Rust 编程语言已经在我的视野中有一段时间了。由于我有 C、C++ 和 C# 的背景,我对该语言在内存安全性和 C 类性能方面的承诺产生了兴趣。我最终决定尝试一下,它看起来是学习和练习这门语言的一个完美起点。
在观看了许多 YouTube 视频和阅读了各种资料之后,我开始使用 Rust 来解决这个问题。我有三个主要问题:创建一个 Web 应用程序有多难?如何访问SQLite数据?如何通过编程创建透明的 PNG 图像?幸运的是,回答这些问题比预期的要简单得多。
Rocket
为了解决 Web 应用程序的问题,我转向了Rocket。Rocket 在线文档中的快速入门页面展示了如何轻松地设置一个基础的 Web 应用程序。我们肯定需要更多复杂的功能来构建我们的瓷砖服务器,但它的基础模板看起来非常简洁明了。事实证明,Rocket 非常容易使用和适应,对我来说它是一个值得保留的工具。
sqlx
经过几分钟的在线搜索,我很快意识到,访问 SQLite 数据库最流行的答案是通过sqlx包。它呈现了一个与我在 Python 中使用的范式不同的方式,但与我之前在 C# 开发时使用的方式非常相似。在这里,你必须使用强类型数据结构,而不是通用的数据结构或熊猫的数据框。虽然这些类型的数据结构稍显繁琐,但它们会为你的生活带来额外的理智。
下方的图 1展示了我用来从等级范围表中检索数据的第一个完整代码示例。
图 1 — 检索等级范围表的示例代码。(图片来源:作者)
PNG
使用image crate 创建、绘制和保存 PNG 文件非常简单。创建透明瓷砖的代码相当简单:
图 2 — image crate 简化了图像的操作。上面的代码展示了如何创建一个纯色的 256x256 地图瓷砖。(图片来源:作者)
我还使用了colorgrad包来处理瓷砖的颜色渐变。
在我详细讨论代码之前,让我们回顾一下绘制交通密度瓦片背后的原理。
一直是瓦片
地图瓦片通常由 256x256 的方形位图组成。我们可以通过结合 x 和 y 坐标、一个“缩放”级别或一个四分键代码来表示每个瓦片。每个缩放级别对应于不同维度瓦片的方形拼贴。整个地球在最上层的一个瓦片上进行描绘。通过缩放,原始瓦片被拆分为四个瓦片。以下图 2和图 3说明了缩放的过程。

图 3 — 在缩放级别 0 上,整个世界呈现在单个瓦片中。(图片来源:OpenStreetMap)

图 4 — 对之前的瓦片进行缩放,我们得到了四个具有相同单独维度的瓦片。(图片来源:OpenStreetMap)
如果我们继续缩放,在八次迭代后,每个生成的瓦片将对应第一个瓦片上的一个像素。这一观察结果提供了我们计算并在瓦片上显示交通密度信息的洞见。
如上一篇文章所述,瓦片信息已经准备好并存储在数据库中。有关如何从车辆能量数据集生成密度数据库的说明,请参见该文章。
使用 Rust 提供瓦片
现在我们可以讨论生成、缓存和提供瓦片的 Rust 服务器代码。当前的解决方案紧密地遵循了之前的瓦片服务器设计。下面的图 5显示了主要入口点,它决定在解析并接受查询参数后,是否提供绘制过的瓦片或默认的透明瓦片。
图 5 — 主要入口点。(图片来源:作者)
如你所见,服务器只会对从一到十八的缩放级别做出响应。这个限制是嵌入在密度数据库的生成过程中。
网络应用程序使用下方图 6中列出的函数绘制每个瓦片。
图 6 — 上述函数生成瓦片(如果未缓存到磁盘),并返回瓦片文件名。(图片来源:作者)
如上面的列表所示,瓦片绘制过程有三个步骤。首先,在第 12 行,我们收集瓦片的每像素密度信息。接下来,我们检索瓦片的级别范围,即瓦片“缩放”级别的最小和最大密度值。最后,在第 14 行,我们绘制瓦片的位图。该函数通过将瓦片位图保存到文件缓存来完成。
图 7 — 上述函数在位图上绘制单个瓦片。注意,密度信息是如何通过基于对数的转换转化为颜色渐变中的条目的。(图片来源:作者)
使用代码
在正确配置数据库文件路径后,你可以通过打开终端窗口,切换到 Rust 项目目录,并运行以下命令来启动瓦片服务器:
cargo run --release
接下来,你可以打开地图客户端并配置密度瓦片层的 URI。下面的图 8展示了加载交互式地图的 Jupyter Notebook 代码单元:
图 8 — 使用上面的代码显示安阿伯地图,并叠加密度瓦片。(图片来源:作者)
就是这样!下面的图 9展示了结果。

图 9 — 上面的图片展示了叠加了交通密度瓦片的基础地图。(图片来源:OpenStreetMap 和作者生成的瓦片)
结论
我第一次接触 Rust 并不像我预期的那样困难。我首先通过阅读现有文献和观看 YouTube 视频来进行准备,然后才开始尝试。接着,我确保使用一个很棒的 IDE 来辅助开发,JetBrains 提供的 RustRover 就是其中之一。虽然该 IDE 仍处于预览模式,但我发现它在使用 Rust 时非常有帮助和指导性。如果你更喜欢使用 Visual Studio Code,那也完全没问题,只要确保安装官方支持的插件。
致谢
我使用了Grammarly来检查写作,并采纳了其中一些改写建议。
JetBrains 的 AI 助手编写了部分代码,我还用它来学习 Rust。它已成为我日常使用 Rust 和 Python 的得力助手。
许可证信息
扩展车辆能源数据集采用 Apache 2.0 许可证,与其发源地 车辆能源数据集 相同。
参考文献
车辆能源数据集(GitHub)
João Paulo Figueira 是葡萄牙里斯本的tb.lx by Daimler Truck的数据显示科学家。
生成式 AI 设计模式:全面指南
大型语言模型(LLM)工作中的参考架构模式和思维模型
·发布于 Towards Data Science ·阅读时间 8 分钟 ·2024 年 2 月 13 日
--

注意:当我在二月最初发布这篇文章时,它只是一个初步的思维实验。从那时起,我开始与一家大型出版社合作编写《生成式 AI 设计模式》一书。请关注我,及时了解我在这个领域的模式和想法更新。
对 AI 模式的需求
当我们在构建新事物时,总会依赖一些经过验证的方法、途径和模式。对于软件工程师来说,这一点尤其成立。然而,对于生成式 AI 和人工智能本身来说,情况可能并非如此。随着生成式 AI 等新兴技术的出现,我们缺乏足够文档化的模式来支撑我们的解决方案。
在这里,我分享了一些生成式 AI 的方法和模式,基于我对无数生产环境中 LLM 实现的评估。这些模式的目标是帮助缓解和克服生成式 AI 实现中的一些挑战,如成本、延迟和幻觉问题。
模式列表
分层缓存策略导致微调
多路复用 AI 代理以形成专家小组
针对多任务微调 LLM
融合基于规则与生成式方法
利用知识图谱与 LLM 的结合
生成式 AI 代理的群体效应
模块化单体 LLM 方法与组合性
LLM 的记忆认知方法
红蓝队双模型评估
1) 分层缓存策略导致微调

在这里,我们需要解决成本、冗余和训练数据等多个因素,当我们为大语言模型引入缓存策略和服务时。
通过缓存这些初步结果,系统可以在后续查询中更快速地提供答案,从而提高效率。关键在于一旦我们收集到足够的数据,通过微调层,利用这些早期互动的反馈来优化一个更专业化的模型。
专门化的模型不仅简化了流程,还将 AI 的专业能力针对具体任务进行了定制,使其在精确度和适应性至关重要的环境中(如客户服务或个性化内容创作)变得非常有效。
要开始使用,可以选择现成的服务,如 GPTCache,或者使用常见的缓存数据库如 Redis,Apache Cassandra,Memcache 来自己搭建。确保在加入额外服务时,监控并测量延迟。
2) 多路复用 AI 代理以构建专家小组

想象一个生态系统,其中多个面向特定任务的生成式 AI 模型(“代理”),每个模型都是其领域内的专家,平行工作以解决查询。这个多路复用策略能够提供多样化的响应,随后将这些响应整合成一个全面的答案。
这种设置非常适合处理复杂问题的场景,在这些场景中,问题的不同方面需要不同的专业知识,就像一个专家小组每个处理一个更大问题的方面一样。
更大的模型,如 GPT-4,用于理解上下文并将其分解为具体任务或信息请求,这些请求随后会传递给较小的代理。代理可以是像 Phi-2 或 TinyLlama 这样的较小语言模型,这些模型经过特定任务训练,访问特定工具,或者是像 GPT、Llama 这样的通用模型,具备特定的个性、上下文提示和功能调用。
3) 对大语言模型进行多任务微调

在这里,我们对一个大语言模型同时进行多个任务的微调,而不是仅针对单一任务进行微调。这是一种促进跨领域知识和技能稳健迁移的方法,增强了模型的多样性。
这种多任务学习对需要处理各种任务且要求高水平能力的平台特别有用,比如虚拟助手或 AI 驱动的研究工具。这有可能简化复杂领域中的训练和测试工作流。
用于训练大语言模型(LLM)的资源和软件包包括 DeepSpeed,以及 Hugging Face 的 Transformer 库上的训练功能。
4) 融合基于规则与生成式方法

许多现有的商业系统和组织应用仍然在某种程度上是基于规则的。通过将生成式方法与基于规则的逻辑的结构化精确性结合,这种模式旨在产生既富有创意又符合要求的解决方案。
这是一个强有力的策略,适用于那些要求输出必须符合严格标准或法规的行业,确保 AI 在创新和参与的同时,仍然保持在预定参数的范围内。一个很好的例子是为电话 IVR 系统或传统的(非 LLM 基础)基于规则的聊天机器人生成意图和消息流程。
5) 利用知识图谱与大语言模型(LLM)

将知识图谱与生成式 AI 模型结合,赋予它们以事实为导向的超能力,使得输出不仅在语境上具有意识,而且更为准确。
这种方法对于那些对真理和准确性要求严格的应用至关重要,如教育内容创作、医疗建议或任何可能导致严重后果的误信息领域。
知识图谱和图谱本体论(图谱的概念集合)允许将复杂的主题或组织问题分解为结构化的格式,以帮助为大语言模型提供深层背景。你还可以使用语言模型生成本体论,格式如 JSON 或 RDF, 我为你创建的示例提示可以使用。
你可以使用的一些知识图谱服务包括图数据库服务,如 ArangoDB、Amazon Neptune、Azure Cosmos DB 和 Neo4j。还有更广泛的数据集和服务可用于访问更广泛的知识图谱,包括 Google 企业知识图谱 API、PyKEEN 数据集 和 Wikidata。
6) AI 代理群体

这一模型汲取了自然界群体和群兽的灵感,采用了多个 AI 代理共同处理问题,每个代理提供独特的视角。
由此产生的聚合输出体现了一种集体智慧,超越了任何单一代理所能实现的成果。这个模式在需要广泛创意解决方案或在处理复杂数据集时尤为有利。
其中一个例子是 从多个“专家”视角审阅研究论文,或者从多个使用场景同时评估客户互动,从欺诈到优惠。我们将这些集体“代理”结合起来,整合他们的输入。对于高频次的集群,你可以考虑部署消息服务,如 Apache Kafka,来处理代理和服务之间的消息。
7) 模块化单体 LLM 方法与可组合性

该设计倡导适应性,采用模块化的 AI 系统,能够动态地重新配置自身,以实现最佳任务性能。它类似于一把瑞士军刀,每个模块可以根据需要选择并激活,使其在需要量身定制解决方案的企业中,针对不同的客户互动或产品需求时尤为高效。
你可以部署各种自主代理框架和架构来开发每个代理及其工具。示例框架包括 CrewAI、Langchain、Microsoft Autogen 和 SuperAGI。
对于销售模块化单体系统,这可能是专注于潜在客户开发的代理,一个负责预订,一个专注于生成消息,另一个更新数据库。未来,随着特定服务的推出,来自专业 AI 公司的服务,你可以将模块替换为外部或第三方服务,来处理特定任务或领域问题。
8) LLM 的记忆认知方法

这种方法引入了类似人类记忆的元素,使得模型能够回忆并在之前的互动基础上构建,以提供更为细致的回应。
它特别适用于持续的对话或学习场景,因为 AI 随着时间的推移逐渐发展出更深刻的理解,就像一位专注的私人助手或自适应学习平台一样。通过汇总并将关键事件和讨论存储到向量数据库中,可以发展记忆认知方法。
为了保持摘要的计算成本较低,您可以利用较小的 NLP 库,如 spaCy 或 BART 语言模型,尤其是在处理大量数据时。所使用的数据库是基于向量的,并且在提示阶段通过相似性搜索来检索短期记忆,定位关键“事实”。对于有兴趣了解工作解决方案的人,可以参考一个开源解决方案,遵循类似模式,名为 MemGPT。
9) 红队与蓝队双模型评估

在红队与蓝队评估模型中,一个 AI 生成内容,而另一个则对其进行批判性评估,类似于严格的同行评审过程。这种双模型设置非常适合质量控制,因此在内容生成平台中非常适用,特别是在新闻聚合或教育资料制作等需要信誉和准确性的场景中。
这种方法可以用来替代复杂任务中的部分人工反馈,通过一个经过精调的模型模仿人类审阅过程,精炼结果,从而评估复杂的语言场景和输出。
要点
这些生成性 AI 的设计模式不仅仅是模板;它们是未来智能系统发展的框架。在我们不断探索和创新的过程中,显然我们所选择的架构将不仅定义 AI 的能力,还将定义我们所创造的 AI 的本质。
这个列表绝非最终版本,随着生成性 AI 的模式和应用场景不断扩展,我们将看到这一领域的发展。这篇文章的灵感来自于 Tomasz Tunguz 发布的 AI 设计模式 。
喜欢这个故事吗?
Vincent Koc 是一位成就卓越、以商业为导向的技术专家和未来学家,拥有丰富的经验,专注于数据驱动和数字化领域。
免费订阅 以便在 Vincent 发布新文章时接收通知。或者在 LinkedIn 和 X 上关注他。
[## 当 Vincent Koc 发布新文章时,您可以通过电子邮件收到通知。
当 Vincent Koc 发布新文章时,您可以通过电子邮件收到通知。通过注册,您将创建一个 Medium 账户(如果您还没有的话)…
medium.com](https://medium.com/subscribe/@vkoc?source=post_page-----41425a40d7d0--------------------------------)
除非另有说明,所有图片均由作者提供
生成性 AI 是企业应在 2024 年接受的赌注
目前,大型语言模型(LLMs)在大规模应用时存在不准确性,但这并不意味着你应该通过等待来采用生成性 AI,从而失去竞争优势。
·发布于 Towards Data Science ·阅读时长 6 分钟·2024 年 1 月 4 日
--

构建具有 AI 准备能力的工作队伍,使用 data.world 的 OWLs,正如 OpenAI 的 GPT-4 所设想的那样
每项企业技术都有其存在的目的,否则它就不会存在。生成性 AI 的企业目的是从技术、商业和语言数据中快速且大规模地生成可供人类使用的输出,以推动生产力、效率和商业收益。但生成性 AI 的这一主要功能——提供巧妙的答案——也是大型语言模型(LLMs)在企业采用过程中面临的最大障碍:所谓的“幻觉”问题。
为什么幻觉会发生呢?因为,从本质上讲,大型语言模型(LLMs)是复杂的统计匹配系统。它们分析数十亿个数据点,努力确定模式,并预测对任何给定提示的最可能回应。但是,尽管这些模型可能通过它们的回答给我们留下深刻印象,展示了有用性、深度和创造力,并诱使我们每次都信任它们,但它们远非可靠。Vectara 的最新研究发现,聊天机器人“发明”新信息的频率高达 27%。在企业环境中,问题的复杂性可能变化很大,这个数字还会更高。数据.world 的 AI 实验室进行的最新基准测试使用了真实的商业数据,发现当作为独立解决方案部署时,大型语言模型仅25.5%的时间能返回准确的基本商业查询回答。对于中级或专家级问题,这些问题仍在典型的数据驱动企业查询的范围内,准确率降至零!
幻觉的倾向对那些仅在小范围或新奇用途上玩耍的个人来说可能无关紧要。但当涉及到企业部署时,幻觉却构成了系统性的风险。其后果从不便(例如,服务聊天机器人在客户互动中提供无关信息)到灾难性(例如,在美国证券交易委员会(SEC)报告中输入错误的数字)不等。
如今,生成式 AI 对企业来说仍然是一场赌博。然而,这也是一场必不可少的赌博。正如我们在 OpenAI 的第一次开发者大会上所了解到的,92%的财富 500 强公司正在使用 OpenAI 的 API。这项技术在企业中的潜力如此具有变革性,以至于未来的路径已经非常明确:开始采用生成式 AI——明白回报与风险并存。另一种选择是将自己与风险隔离开来,迅速落后于竞争对手。不可避免的生产力提升现在已经如此明显,以至于不利用它可能会对企业的生存构成存在性的威胁。所以,在面对这种选择的幻象时,组织应该如何在将生成式 AI 整合到工作流程中,同时降低风险呢?
首先,你需要优先考虑你的数据基础。 就像任何现代企业技术一样,生成式 AI 解决方案的好坏取决于其所建立的数据基础——根据思科最近发布的AI 准备度指数,意图正在超越能力,尤其是在数据方面。思科发现,虽然全球 84% 的公司认为 AI 将对其业务产生重大影响,但 81% 的公司缺乏实现 AI 工具潜力所需的数据集中化,而只有 21% 的公司表示其网络具有支持高负载 AI 工作负载的“最佳”延迟。关于数据治理,情况也相似;目前只有三分之一的受访者拥有全面的 AI 政策和协议,只有四分之一的受访者拥有系统性的 AI 偏见和公平性修正流程。
如基准测试所示,LLM 已经很难可靠地检索事实性答案。如果再加上数据质量差、缺乏数据集中化/管理能力,以及有限的治理政策,幻觉现象的风险——以及随之而来的后果——将急剧上升。简单来说,拥有强大数据架构的公司可以更好、更准确地获得信息,进而其 AI 解决方案能够做出更好的决策。使用数据目录或评估内部治理和数据录入流程可能并不是采用生成式 AI 最令人兴奋的部分。但正是这些考虑——数据治理、血统和质量——可能决定生成式 AI 项目的成败。它不仅使组织能够更快、更负责任地部署企业级 AI 解决方案,还能够帮助它们在技术发展的过程中跟上市场的步伐。
第二,你需要建立一个具备 AI 知识的员工队伍。 研究表明,高级提示工程等技术对于识别和减轻幻觉现象有帮助。其他方法,如微调,已被证明能显著提高大语言模型(LLM)的准确性,甚至在某些情况下超过更大、更先进的通用模型。然而,员工只有在接受了最新的培训和教育后,才能应用这些技巧。说实话,大多数员工并未具备这种能力。自 2022 年 11 月 30 日 ChatGPT 发布至今,已经过去了一年多!
当像 Databricks 或 Snowflake 这样的主要供应商发布新功能时,组织通常会涌向网络研讨会、会议和工作坊,以确保他们能够利用最新的功能。生成式人工智能应该也不例外。在 2024 年,建立一种文化,使得教育您的团队掌握人工智能最佳实践成为您的默认选项;例如,通过提供针对人工智能的 L&D 项目补贴或聘请外部培训顾问(比如我们在 data.world 与Rachel Woods的合作,她在我们的顾问委员会中任职,并创办并领导了 The AI Exchange)。我们还提升了Brandon Gadoci,我和我的共同创始人之外的第一位 data.world 员工,担任我们的人工智能运营副总裁。我们在内部生产力上已经取得的惊人提升,简直令人鼓舞(我在这个三篇系列文章中写到了这点)。Brandon昨天刚报告我们通过在 2023 年全职岗位中使用内部人工智能工具,团队的生产力惊人地提高了 25%!采纳这种文化将大大帮助您的组织理解、识别和减轻幻觉威胁。
第三,您需要保持对新兴人工智能生态系统的关注。 就像任何新的范式改变技术一样,人工智能周围充斥着大量新兴的实践、软件和流程,旨在最小化风险并最大化价值。尽管大型语言模型(LLMs)可能会带来巨大变革,但令人欣慰的事实是,我们才刚刚踏上人工智能演变的漫长道路。
曾经对您的组织而言陌生的技术可能会变得至关重要。我们发布的前述基准测试显示,结合知识图谱的 LLMs——一种用于在三维中情境化数据的数十年历史架构(像人类大脑一样映射和关联数据)—可以提高准确性 300%!同样,向量数据库和增强检索生成(RAG)等技术也因其在解决 LLMs 的幻觉问题方面的能力而崭露头角。从长远来看,人工智能的雄心远超目前主要 LLM 提供商的 API,因此在企业的人工智能投资中,保持好奇心和灵活性是至关重要的。
像任何新技术一样,生成式人工智能解决方案并不完美,它们产生幻觉的倾向对其目前在企业广泛部署中的可行性构成了现实威胁。然而,这些幻觉不应阻止组织进行实验并将这些模型集成到工作流程中。事实上,正如人工智能先驱、沃顿商学院创业教授 Ethan Mollick 在推文中所言:“……理解来自实验。”相反,幻觉带来的风险应当成为促使企业决策者认识到风险所在、采取相应措施来降低风险,并在过程中获得大语言模型(LLM)早期收益的推动力。2024 年是你的企业应当跨出这一步的一年。
生成式人工智能:使用 Pytorch 生成对抗网络(GANs)合成数据
解开复杂性的谜团:超越图像和语言模型
·发表于 Towards Data Science ·7 分钟阅读·2024 年 1 月 15 日
--
生成式模型因其生成美丽且逼真的图像和语言模型(例如 ChatGPT)而变得极为流行,这些模型在各个行业中的应用日益增多。这些生成式模型可以说是人工智能/机器学习现在备受关注(或担忧)的原因;因为它展示了机器学习的巨大潜力,尤其是向那些非专业领域的人展示了这一点。网络上已经有很多关于 GANs 模型的资源,但大多数都集中在图像生成上。这些图像生成和语言模型需要复杂的空间或时间细节,这增加了额外的复杂性,使得读者更难理解 GANs 的真正本质。
为了改善这一点并使 GAN 更加容易为更广泛的受众所接受,在这段简短的讨论和 GAN 模型示例中,我们将采取一种不同且更实用的方法,重点是生成数学函数的合成数据。除了作为学习用途的简化外,合成数据生成本身也变得越来越重要。数据不仅在商业决策中扮演着核心角色,而且越来越多的使用场景中,数据驱动的方法正变得比从第一原理建模更为流行。一个令人兴奋的例子是天气预报,第一原理模型包括了简化版本的纳维-斯托克斯方程,并通过数值方法求解(我得补充一句,计算成本非常高)。然而,近期使用深度学习进行天气预报的尝试(例如,看看 Nvidia 的 FourCastNet [1])在捕捉天气模式方面取得了非常成功的成果,一旦训练完成,运行起来更容易且速度更快。
生成模型与判别模型
在机器学习中,理解判别模型与生成模型之间的区别非常重要,因为它们是 GAN 的关键组成部分。让我们简要地解开这些术语:
判别模型:
判别模型专注于将数据分类到预定义的类别中,例如将狗和猫的图片分类到各自的类别中。与捕获整个分布不同,这些模型分辨出不同类别之间的边界。它们输出 P(y|x)(给定输入数据 x 时,类别 y 的概率),即它们回答给定数据点属于哪个类别的问题。
生成模型:
生成模型旨在理解数据的潜在结构。与判别模型通过类别进行区分不同,生成模型学习数据的整个分布。这些模型输出 p(x|y),即它们回答在给定类别 y 的情况下,生成特定数据点 x 的概率。
这两种模型之间的相互作用构成了 GAN 的基础。
GAN — 结构和组成部分

GAN 解剖学的示意图。图片来源:Tingting Zhu [2]
现在,让我们探索这些概念如何在 GAN 模型中结合起来。GAN 的关键组成部分包括噪声向量、生成器和判别器。
生成器:生成真实数据
为了生成合成数据,生成器使用一个随机噪声向量作为输入。在试图欺骗判别器的过程中,生成器旨在学习真实数据的分布,并生成无法与真实数据区分的合成数据。这里的问题是,对于相同的输入,它总是生成相同的输出(想象一个生成图像的模型,它生成一个逼真的图像,但总是相同的图像,这并不太有用)。随机噪声向量为过程注入了随机性,从而为生成的输出提供了多样性。
判别器:区分真实与虚假
判别器就像一个艺术评论家,经过训练能够区分真实和虚假的数据。它的角色是仔细审视接收到的数据,并为该数据的真实性分配一个概率分数。如果合成数据与真实数据相似,判别器会分配一个高概率分数,否则分配一个低概率分数。
对抗性训练:一个动态的对决
生成器努力学习生成判别器无法区分的合成数据,同时,判别器也在不断学习并提高其区分真实与合成数据的能力。这一动态训练过程促使两个模型不断完善自己的技能。这两个模型始终在竞争(因此它被称为对抗性),通过这种竞争,两个模型都能在各自的角色中变得非常出色。
使用 Pytorch 实现 GAN
让我们通过一个创建 GAN 的示例来继续。在这个示例中,我们在 PyTorch 中实现一个能够生成合成数据的模型。对于训练,我们有一个包含 6 个参数的数据集,具有以下形状(所有参数作为参数 1 的函数被绘制)。每个参数的分布和形状被故意选择得有显著的不同,以增加数据集的复杂性并模拟现实世界的数据。然而,值得一提的是,判别器和生成器的架构仍有很大的优化空间,但在本教程中我们不会关注这一点。
在本教程中,我假设你已经理解了常见的人工神经网络(ANN)模型架构和 Python 编程。我在代码中提供了注释,以帮助你理解代码。

训练数据 — 所有 6 个参数作为参数 1 的函数被绘制
定义 GAN 模型组件(生成器和判别器)
import torch
from torch import nn
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch.nn.init as init
import pandas as pd
import numpy as np
from torch.utils.data import Dataset
# defining a single generation block function
def FC_Layer_blockGen(input_dim, output_dim):
single_block = nn.Sequential(
nn.Linear(input_dim, output_dim),
nn.ReLU()
)
return single_block
# DEFINING THE GENERATOR
class Generator(nn.Module):
def __init__(self, latent_dim, output_dim):
super(Generator, self).__init__()
self.model = nn.Sequential(
nn.Linear(latent_dim, 256),
nn.ReLU(),
nn.Linear(256, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, output_dim),
nn.Tanh()
)
def forward(self, x):
return self.model(x)
#defining a single discriminattor block
def FC_Layer_BlockDisc(input_dim, output_dim):
return nn.Sequential(
nn.Linear(input_dim, output_dim),
nn.ReLU(),
nn.Dropout(0.4)
)
# Defining the discriminator
class Discriminator(nn.Module):
def __init__(self, input_dim):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
nn.Linear(input_dim, 512),
nn.ReLU(),
nn.Dropout(0.4),
nn.Linear(512, 512),
nn.ReLU(),
nn.Dropout(0.4),
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(0.4),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, x):
return self.model(x)
#Defining training parameters
batch_size = 128
num_epochs = 500
lr = 0.0002
num_features = 6
latent_dim = 20
# MODEL INITIALIZATION
generator = Generator(noise_dim, num_features)
discriminator = Discriminator(num_features)
# LOSS FUNCTION AND OPTIMIZERS
criterion = nn.BCELoss()
gen_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
disc_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
模型初始化与数据处理
# IMPORTING DATA
file_path = 'SamplingData7.xlsx'
data = pd.read_excel(file_path)
X = data.values
X_normalized = torch.FloatTensor((X - X.min(axis=0)) / (X.max(axis=0) - X.min(axis=0)) * 2 - 1)
real_data = X_normalized
#Creating a dataset
class MyDataset(Dataset):
def __init__(self, dataframe):
self.data = dataframe.values.astype(float)
self.labels = dataframe.values.astype(float)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = {
'input': torch.tensor(self.data[idx]),
'label': torch.tensor(self.labels[idx])
}
return sample
# Create an instance of the dataset
dataset = MyDataset(data)
# Create DataLoader
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
def weights_init(m):
if isinstance(m, nn.Linear):
init.xavier_uniform_(m.weight)
if m.bias is not None:
init.constant_(m.bias, 0)
pretrained = False
if pretrained:
pre_dict = torch.load('pretrained_model.pth')
generator.load_state_dict(pre_dict['generator'])
discriminator.load_state_dict(pre_dict['discriminator'])
else:
# Apply weight initialization
generator = generator.apply(weights_init)
discriminator = discriminator.apply(weights_init)
模型训练
model_save_freq = 100
latent_dim =20
for epoch in range(num_epochs):
for batch in dataloader:
real_data_batch = batch['input']
# Train discriminator on real data
real_labels = torch.FloatTensor(np.random.uniform(0.9, 1.0, (batch_size, 1)))
disc_optimizer.zero_grad()
output_real = discriminator(real_data_batch)
loss_real = criterion(output_real, real_labels)
loss_real.backward()
# Train discriminator on generated data
fake_labels = torch.FloatTensor(np.random.uniform(0, 0.1, (batch_size, 1)))
noise = torch.FloatTensor(np.random.normal(0, 1, (batch_size, latent_dim)))
generated_data = generator(noise)
output_fake = discriminator(generated_data.detach())
loss_fake = criterion(output_fake, fake_labels)
loss_fake.backward()
disc_optimizer.step()
# Train generator
valid_labels = torch.FloatTensor(np.random.uniform(0.9, 1.0, (batch_size, 1)))
gen_optimizer.zero_grad()
output_g = discriminator(generated_data)
loss_g = criterion(output_g, valid_labels)
loss_g.backward()
gen_optimizer.step()
# Print progress
print(f"Epoch {epoch}, D Loss Real: {loss_real.item()}, D Loss Fake: {loss_fake.item()}, G Loss: {loss_g.item()}")
评估和可视化结果
import seaborn as sns
# Generate synthetic data
synthetic_data = generator(torch.FloatTensor(np.random.normal(0, 1, (real_data.shape[0], noise_dim))))
# Plot the results
fig, axs = plt.subplots(2, 3, figsize=(12, 8))
fig.suptitle('Real and Synthetic Data Distributions', fontsize=16)
for i in range(2):
for j in range(3):
sns.histplot(synthetic_data[:, i * 3 + j].detach().numpy(), bins=50, alpha=0.5, label='Synthetic Data', ax=axs[i, j], color='blue')
sns.histplot(real_data[:, i * 3 + j].numpy(), bins=50, alpha=0.5, label='Real Data', ax=axs[i, j], color='orange')
axs[i, j].set_title(f'Parameter {i * 3 + j + 1}', fontsize=12)
axs[i, j].set_xlabel('Value')
axs[i, j].set_ylabel('Frequency')
axs[i, j].legend()
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()
# Create a 2x3 grid of subplots
fig, axs = plt.subplots(2, 3, figsize=(15, 10))
fig.suptitle('Comparison of Real and Synthetic Data', fontsize=16)
# Define parameter names
param_names = ['Parameter 1', 'Parameter 2', 'Parameter 3', 'Parameter 4', 'Parameter 5', 'Parameter 6']
# Scatter plots for each parameter
for i in range(2):
for j in range(3):
param_index = i * 3 + j
sns.scatterplot(real_data[:, 0].numpy(), real_data[:, param_index].numpy(), label='Real Data', alpha=0.5, ax=axs[i, j])
sns.scatterplot(synthetic_data[:, 0].detach().numpy(), synthetic_data[:, param_index].detach().numpy(), label='Generated Data', alpha=0.5, ax=axs[i, j])
axs[i, j].set_title(param_names[param_index], fontsize=12)
axs[i, j].set_xlabel(f'Real Data - {param_names[param_index]}')
axs[i, j].set_ylabel(f'Real Data - {param_names[param_index]}')
axs[i, j].legend()
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()


尽管我们的模型很简单,但合成数据和真实数据的分布及数学形态非常相似!训练过程和模型架构可以进行修改,以提高准确性,但这不是我们在这里关注的内容。这个模型可以很容易地调整,以生成其他应用的合成数据,适用于具有更多参数和更复杂的真实物理系统。感谢您抽出时间阅读,希望您觉得这篇文章有启发性。使用生成对抗网络(GANs)可以做很多事情,它是一个非常激动人心的话题,绝对值得尝试这个代码,了解 GAN 的整体概念,然后开始尝试其他的创意!祝您好运!
除非另有注明,所有图片均为作者提供
参考文献
[1] Jaideep Pathak, Shashank Subramanian, Peter Harrington, Sanjeev Raja, Ashesh Chattopadhyay, Morteza Mardani, Thorsten Kurth, David Hall, Zongyi Li, Kamyar Azizzadenesheli, Pedram Hassanzadeh, Karthik Kashinath, Animashree Anandkumar. (2022). FourCastNet:一种使用自适应傅里叶神经算子的全球数据驱动高分辨率天气模型。arXiv:2202.11214. doi.org/10.48550/arXiv.2202.11214
[2] Ghosheh, Ghadeer & Jin, Li & Zhu, Tingting. (2023). 生成对抗网络在合成结构化电子健康记录中的应用调研。ACM 计算调查。10.1145/3636424。
地理位置编码器
理解现代技术如何在神经网络中编码地理坐标
·发表于 Towards Data Science ·10 分钟阅读·2024 年 10 月 1 日
--

图片由 CHUTTERSNAP 提供,来源:Unsplash
机器学习中的归纳偏差是基于某些先验知识对模型施加的约束。作为人类,我们可以识别出飞翔在天空中的鸟或栖息在树上的鸟。此外,我们不需要检查每一朵云或完全观察整棵树,就能知道我们正在看的是鸟而不是其他东西。这些在视觉过程中的偏差通过两个属性在卷积层中编码:
权重共享:相同的卷积核权重在输入通道的整个宽度和高度上被重复使用。
局部性:卷积核的宽度和高度远小于输入。
我们还可以通过选择模型的输入特征来编码归纳偏差,这可以看作是对模型本身的约束。为神经网络设计输入特征涉及表达能力和归纳偏差之间的权衡。一方面,我们希望允许模型具备灵活性,学习超出人类能够检测和编码的模式。另一方面,若没有任何归纳偏差的模型将难以学到任何有意义的内容。
在本文中,我们将探讨设计有效位置编码器的归纳偏差,特别是在地理坐标上的应用。地球上的位置可以作为广泛预测任务的有用输入,包括图像分类。正如我们将看到的,直接使用纬度和经度作为输入特征是约束不足的,最终会使模型更难学习有意义的内容。相反,更常见的做法是将纬度和经度的先验知识编码为一种非参数化的重新映射,我们称之为位置编码器。
介绍:变换器中的位置编码器
为了更广泛地说明选择有效位置编码器的重要性,我们首先来看看变换器模型中著名的位置编码器。我们从一个观点开始:传递给注意力模块的标记输入的表示应该包含一些关于它在所属序列中位置的信息。问题是:我们应当如何将位置索引(0, 1, 2…)编码为一个向量?
假设我们有一个与位置无关的标记嵌入。一种可能的方法是将索引值直接加到这个嵌入向量中或与之连接。以下是为什么这种方法效果不佳的原因:
两个嵌入之间的相似性(点积)——在它们的位置被编码之后——应当独立于序列中标记的总数。序列的最后两个标记无论序列长度是 5 个单词还是 50 个单词,它们应该记录相同的相似性。
两个标记之间的相似性不应该依赖于它们位置的绝对值,而只应依赖于它们之间的相对距离。即使编码的索引被归一化到[0, 1]的范围内,位置为 1 和 2 的相邻两个标记也会记录比序列中稍后位置的相同两个标记更低的相似性。
原始的《Attention is All You Need》论文[1]提出的做法是将位置索引pos编码为k个不同正弦波的离散“快照”,其中k是标记嵌入的维度。这些快照计算如下:

其中 i = 1, 2, …, k / 2. 然后,得到的k维位置嵌入将按元素与对应的标记嵌入相加。
这种编码的直觉是,任何两个嵌入的快照相位差越大,它们对应位置的距离就越远。两个不同位置的绝对值不会影响它们快照的相位差。此外,由于任何正弦波的范围是[-1, 1],位置嵌入的幅度不会随着序列长度的增加而增长。
我不会进一步详细讨论这个特定的定位编码器,因为已经有几篇优秀的博客文章做了深入讲解(见 [2])。希望你现在可以理解,一般来说,仔细考虑如何编码位置为什么是如此重要。
地理位置编码器
现在让我们转向地理位置编码器。我们希望训练一个神经网络,根据地球表面上的位置预测某些感兴趣的变量。我们该如何将球面坐标系中的位置(λ, ϕ)——即经度/纬度对——编码为一个向量,以便将其作为我们网络的输入?

由 Peter Mercator 编著, 公有领域。
简单方法
一种可能的方法是直接使用纬度和经度值作为输入。在这种情况下,我们的输入特征空间将是矩形 [-π, π] × [0*, π*],我将其称为纬度/经度空间。与变换器的定位编码器类似,这种简单的方法不幸的是也有其局限性:
注意,当你向两极移动时,地球表面上由 1 单位经度 (λ) 覆盖的距离会减小。纬度/经度空间并不能保持地球表面上的距离。
注意,地球上与坐标(λ, ϕ)对应的位置应与坐标(λ + 2*π,ϕ)对应的位置相同。但在纬度/经度空间,这两个坐标之间相距甚远。纬度/经度空间并不能保持周期性:即球面坐标如何在地球表面上环绕。
为了直接从纬度/经度空间的输入中学习有意义的内容,神经网络必须学会如何自行编码关于地球表面曲率的这些属性——这是一个具有挑战性的任务。那么,我们该如何设计一个已经编码了这些归纳偏置的定位编码器呢?让我们探索一下早期的解决方案,以及它们是如何随着时间的推移演变的。
早期的定位编码器
基于离散化的方法(2015 年)
第一篇提出将地理坐标特征化并作为卷积神经网络输入的论文名为《通过位置上下文改善图像分类》[3]。该论文于 2015 年发表,提出并评估了多种不同的特征化方法,旨在训练更好的地理标记图像分类模型。
它们每种方法背后的理念是将地球上的位置直接编码为一组可以从辅助数据源计算得出的数值特征。一些例子包括:
将美国划分为均匀间隔的网格,在纬度/经度空间中使用独热编码将给定位置编码为一个向量,依据其所在的网格来确定。
查找与给定位置对应的美国邮政编码,然后从美国社区调查(ACS)中检索关于该邮政编码的人口统计数据,涉及年龄、性别、种族、居住条件等。使用独热编码将这些数据转化为数值向量。
对于一组选择的 Instagram 标签,统计不同距离给定位置的标签数量,并将这些计数值连接成一个向量。
从 Google Maps 中检索颜色编码地图,用于表示降水、土地覆盖、国会选区等各种特征,并将每个特征的数值颜色值连接成一个向量。
请注意,这些位置编码不是连续的,并且不保留地球表面上的距离。在第一个例子中,两个邻近的地点如果落入不同的网格,它们在特征空间中的距离与两个位于国家两端的地点相等。此外,这些特征大多依赖于辅助数据源的可用性,必须精心设计,并需要特定的标签、地图特征、调查数据等选择。这些方法无法很好地推广到地球上任意位置。
WRAP (2019)
在 2019 年,一篇名为“仅使用位置的地理先验进行细粒度图像分类”的论文[4]迈出了通向今天广泛使用的地理位置编码器的重要一步。与上一节的工作类似,这篇论文研究了如何利用地理坐标来改进图像分类模型。
他们的位置编码器背后的关键思想是利用正弦和余弦函数的周期性来编码地理坐标如何在地球表面上环绕。给定纬度和经度(λ,ϕ),都被归一化到[-1, 1]范围内,WRAP 位置编码器定义如下:

与上一节中的方法不同,WRAP 是连续的,并且可以轻松计算地球上任何位置的值。然后,论文通过实验证明,在这些特征上训练一个全连接网络,并将它们与潜在图像特征结合,可以在细粒度图像分类基准测试中提高性能。
双重傅里叶球面方法
WRAP 编码器看起来简单,但它成功地编码了关于地理位置的一个关键归纳偏差,同时保持了表达性和灵活性。为了理解为什么这种位置编码器选择如此强大,我们需要理解双重傅里叶球面(DFS)方法[5]。
DFS 是一种方法,它将定义在单位球面上的任何实值函数f(x,y,z)转换为在矩形[-π,π] × [-π,π]上定义的 2π周期函数。从高层次来看,DFS 包括两个步骤:
- 使用球面坐标重新参数化函数f(x,y,z),其中(λ,ϕ)∈ [-π,π] × [0*,π*]

2. 基于重新参数化的* f*(本质上是“在其上翻倍”)在矩形[-π,π] × [-π,π]上定义一个新的分段函数。
请注意,地球表面的 DFS 重新参数化(步骤 1)保持了我们之前讨论的性质。首先,当ϕ趋近于 0 或±π(地球的极点)时,经过重新参数化后,两点(λ,ϕ)和(λ',ϕ)之间的距离会减小。此外,重新参数化是周期性的并且平滑的。
傅里叶定理
任何连续的、周期性的实值函数都可以表示为正弦和余弦的加权和,这是一个事实。这个理论叫做傅里叶定理,这种加权和表示形式叫做傅里叶级数。事实证明,任何经过 DFS 变换的函数都可以用有限的正弦和余弦表示。它们被称为DFS 基函数,如下所示:

这里,∪表示集合的并集,S是正弦波的尺度(即频率)集合。
基于 DFS 的位置编码器
请注意,DFS 基函数集包含了 WRAP 位置编码器中的四项。“Sphere2Vec”[6]是最早观察到这一点的论文,提出了基于 DFS 的统一位置编码器视角。实际上,考虑到这种泛化,我们可以通过选择 DFS 基函数的任何子集来构建一个地理位置编码器——WRAP 仅仅是其中的一种选择。想了解更多,可以参考[7]中对各种基于 DFS 的位置编码器的全面概述。
为什么基于 DFS 的编码器如此强大?
考虑在基于 DFS 的位置编码器上训练一个线性层时会发生什么:网络的每个输出元素都是所选 DFS 基函数的加权和。因此,网络可以被解释为学习到的傅里叶级数。由于几乎任何定义在球面上的函数都可以使用 DFS 方法进行变换,因此可以推断,在线性层上训练的 DFS 基函数足够强大,能够在球面上编码任意函数!这类似于多层感知器的通用逼近定理。
在实际应用中,仅使用 DFS 基函数的一个小子集用于位置编码器,并在其上训练一个全连接网络。非参数化位置编码器与神经网络的组合通常被称为位置编码器:

地理位置编码器的示意图。图片来自作者。
当前的地理位置编码器
如我们所见,基于 DFS 的位置编码器可以有效地编码我们对于地球表面曲率的归纳偏见。基于 DFS 的编码器的一个局限性是它们假设矩形域[-π,π] × [-π,π]。虽然这个假设通常是可以接受的,因为 DFS 重新参数化已经考虑到距离在接近极地时如何发生扭曲,但这个假设在极地本身(ϕ = 0,± π)的地方会失效,因为这些点在矩形域中是线段,而在地球表面则会塌缩成单一的点。
一种名为球面调和函数的不同基函数最近成为了一种替代方案。球面调和函数是本地定义在球面上的基函数,而非矩形区域。与基于 DFS 的编码器相比,它们已被证明在地球极地区域展现出较少的伪影[7]。值得注意的是,球面调和函数是 SatCLIP 位置编码器[8]中使用的基函数,SatCLIP 是一个近期的地理坐标基础模型,以 CLIP 的风格进行训练。
尽管地理位置编码器在 2010 年代初期始于离散的、手工设计的特征,但这些特征并不容易推广到任意位置,并且需要特定领域的元数据,如土地覆盖和人口统计数据。今天,地理坐标作为神经网络输入变得更为普遍,因为已有简单但有意义且富有表现力的编码方式出现。随着网络规模数据集的崛起,这些数据集通常带有地理标签,使用地理坐标作为预测任务的输入的潜力如今巨大。
参考文献
[1] A. Vaswani, N. Shazeer, N. Parmar, J. Uszkoreit, L. Jones, A. N. Gomez, L. Kaiser & I. Polosukhin, Attention Is All You Need (2017), 第 31 届神经信息处理系统大会
[2] A Kazemnejad, Transformer Architecture: The Positional Encoding (2019), Amirhossein Kazemnejad’s Blog
[3] K. Tang, M. Paluri, L. Fei-Fei, R. Fergus, L. Bourdev, Improving Image Classification with Location Context (2015)
[4] O. Mac Aodha, E. Cole, P. Perona, Presence-Only Geographical Priors for Fine-Grained Image Classification (2019)
[5] 双重傅里叶球面方法,维基百科
[6] G. Mai, Y. Xuan, W. Zuo, K. Janowicz, N. Lao Sphere2Vec: Multi-Scale Representation Learning over a Spherical Surface for Geospatial Predictions (2022)
[7] M. Rußwurm, K. Klemmer, E. Rolf, R. Zbinden, D. Tuia, Geographic Location Encoding with Spherical Harmonics and Sinusoidal Representation Network (2024), ICLR 2024
SatCLIP: 全球通用的卫星图像位置嵌入(2024)由 K. Klemmer, E. Rolf, C. Robinson, L. Mackey, M. Rußwurm 编写。
地理空间索引解释:Geohash、S2 与 H3 的比较
·发布于Towards Data Science ·12 分钟阅读·2024 年 1 月 1 日
--
地理空间索引,或称地理编码,是将经纬度对索引到地理空间的小细分单元的过程,这是我们数据科学家在面对地理空间数据时经常使用的技术。
尽管首个流行的地理空间索引技术“Geohash”直到 2008 年才被发明,但将经纬度对索引到可管理的空间子细分并非一个新概念。几百年来,政府一直在将土地划分为州、省、县和邮政编码等区域,用于各种应用,如人口普查和选举投票的汇总。
与政府使用的手动技术不同,我们数据科学家使用现代计算技术来执行这种空间细分,并且是为了我们自己的目的:分析、特征工程、按地理细分的精细化 AB 测试、地理空间数据库的索引等。
地理空间索引是计算机科学的一个成熟领域,地理空间索引工具可以为我们的模型和分析带来强大的功能和丰富的内涵。让地理空间索引技术更加令人兴奋的是,深入了解它们的“引擎盖”下,揭示了其他数学工具的多样结合,例如空间填充曲线、地图投影、镶嵌图等!
本文将探讨今天最流行的三种地理空间索引工具——它们的起源……
从 XAI 中获得更多:10 个技巧
可解释的人工智能不仅仅是应用算法
·发表于 Towards Data Science ·阅读时间:9 分钟·2024 年 2 月 27 日
--

图片来源:Marten Newhall 发表在 Unsplash
我记得第一次使用SHAP时,嗯,应该说是尝试使用它。我想了解一个用超过 40 个特征训练的 XGBoost 模型,而这些特征中的许多高度相关。图表看起来很酷!但是,差不多就是这样。
当时完全不清楚模型是如何进行预测的。而且,这并不是XAI方法的错……底层数据一团糟。这是我第一次意识到:
XAI 方法不是万能的解决方案。
你不能简单地将它们应用到复杂模型上,然后期待能够得到合理的内部工作原理解释。然而,如果正确使用,它们能提供极为深刻的洞察。
自从第一次尝试理解黑箱模型以来,我学到了很多东西。我将这些经验总结成了 10 个技巧。如下所示,它们大致分为 3 组。前四个技巧侧重于用于训练模型的底层数据。接下来的四个技巧关注你,作为 XAI 方法的用户。最后两个技巧则深入探讨更多技术性考量。

(来源:作者)
你可能还会喜欢关于这个主题的视频。如果你想了解更多,可以查看…
在 Python 中开始使用 SQLite3,创建表格并获取行
学习使用 SQLite — 世界上使用最广泛的数据库管理系统
·发布于 Towards Data Science ·10 分钟阅读·2024 年 6 月 18 日
--

照片由 Sunder Muthukumaran 提供,来源于 Unsplash
概述
介绍 — 什么是 SQLite 和 SQLite3?
创建我们的第一个 SQLite 数据库
连接器和游标
创建表格
向数据库插入行
从数据库中获取行
总结
介绍 — 什么是 SQLite 和 SQLite3?
大多数现代 IT 专业人员的核心技能之一是 结构化查询语言 (SQL)。这是一种声明式语言,用于与关系型数据库进行交互。数据工程师和分析师通常使用 SQL 来运行数据管道并调查数据中的有用关系。
直接使用像 PostgreSQL 或 MySQL 这样的常见**数据库管理系统 (DBMS)**可能会让没有 SQL 经验的人感到有些害怕。幸运的是,SQLite 是学习 SQL 基础的一个极好的选择。它的设置非常简单,且易于管理,因为它没有独立的服务器进程。因此,尽管数据工程师和数据分析师通常会使用与 SQLite 不同的数据库管理系统,但它仍然是一个非常好的入门选择…
使用 Python 从任何 PDF 中提取带下划线的文本
一步步指南,教你如何从 PDF 文件中提取带下划线的文本并将其作为数组处理。
·发布于 Towards Data Science ·阅读时间:5 分钟·2024 年 5 月 3 日
--
💡 如果你想查看这个项目的代码,可以访问我的仓库: https://github.com/sasha-korovkina/pdfUnderlinedExtractor
PDF 数据提取可能是一个真正的难题,尤其是在你尝试提取带下划线的文本时——信不信由你,目前没有现成的解决方案或库可以直接处理这个问题。不过别担心,我会教你如何解决这个问题。

图片来源:dlxmedia.hu 于 Unsplash
理论
从 PDF 中提取带下划线的文本有几种不同的方法。你可以考虑使用 OCR 来检测带底线的文本组件,或者深入研究 PyMuPDF 的标记功能。然而,我发现 OCR 往往会出现问题,表现为不一致和低精度。PyMuPDF 我也不太喜欢——它需要细致的参数调优,非常耗时。而且,如果设置错误,你可能会丢失大量数据。
需要记住的是,PDF 文件是:
非结构化数据: PDF 元素通常缺乏分组或分类,这使得系统地搜索内容变得更加复杂。
文本格式识别: 在 PDF 中检测特定的文本格式,如粗体或下划线,通常非常困难,因为大多数 Python 库并未有效支持这一功能。
但别怕,我们有一个策略来解决这个问题。
策略
将 PDF 转换为结构化 XML:首先将 PDF 文档转换为结构化 XML 格式,以便于数据处理。
提取所需组件:从 XML 中识别并隔离出与我们需求相关的特定组件。
使用 OCR(光学字符识别)技术处理提取的坐标,获取下划线文本数据作为数组。
提取并输出下划线文本:最后,从文档中提取下划线文本并显示或打印结果。
代码
- PDF 转 XML
我们将使用pdfquery库,这是我遇到的最全面的 PDF 转 XML 转换器。
2. 研究 XML
XML 中有几个关键组件是我们感兴趣的:
LTRect — 有时,库会将下划线文本解析为文本下方的一个最小宽度矩形
LTLine — 有时,它会将轮廓识别为一个独立的线条组件。

这就是您的输出 XML 的样子。图像由作者创建。
LTRect 组件示例:
<LTRect y0="563.787" y1="629.964" x0="367.942" x1="473.826" width="105.884" height="66.178" bbox="[367.942, 563.787, 473.826, 629.964]" linewidth="0" pts="[[367.942, 629.964], [473.826, 629.964], [473.826, 563.787], [367.942, 563.787]]">
因此,通过将整个文档转换为 XML 格式,我们可以将其结构作为 XML 组件进行复制,接下来我们就这么做!
结构复制
现在,我们将重新创建文档的结构作为边界框坐标。为此,我们将解析 XML,定义页面、组件框、线条和矩形,然后用三种不同的颜色将它们全部绘制在我们的画布上。
PDF 对象可视化。
这是我们的初始 PDF,它是通过在 Microsoft Word 中生成的,通过导出包含一些下划线的文档为 PDF 文件格式:

初始文档,包含示例文本。图像由作者创建。
应用上述算法后,这就是我们得到的视觉表示:

文档的框架轮廓:黑色—所有文档组件,蓝色—下划线文本。图像由作者创建。
这张图片展示了我们文档的结构,其中黑色框用于描述页面上的所有组件,蓝色框用于描述 LTRect 元素,因此出现了下划线文本。
文本叠加
现在,让我们通过以下代码行将 PDF 中的所有文本可视化,显示在各自的位置上:
can.drawString(text_x, text_y, text)
这是输出结果:

基于文本位置和下划线的 PDF 重新创建。图像由作者创建。
请注意,由于 pdfquery 库中标记语言的大小和字体差异,文本的位置与原始文档中的位置不完全一致。
坐标提取
作为 XML 结果,我们将得到一组下划线区域的坐标数组,在我的案例中,我将其命名为 underline_text。
一段形成 PDF 文件中下划线文本坐标数组的代码。
文本提取
过程如下:
我们识别出先前确定的坐标矩形。
我们从 PDF 中提取这些部分。
我们应用 Tesseract OCR 从每个提取的部分中提取文本。
使用坐标矩形和 Tesseract OCR 从 PDF 中提取文本的方法之所以有效,有几个原因:
文本提取的精度:通过识别特定的坐标矩形,处理过程仅针对 PDF 中的相关区域。这种聚焦的方法避免了对整个文档的无谓处理,并减少了提取不需要的文本时出现的错误。
效率:直接从 PDF 中提取预定义部分比处理整个文档要快得多。这种方法节省了计算资源和时间,尤其在处理大型文档时特别有用。
OCR 精度:Tesseract OCR 是一个强大的光学字符识别工具,能够将文本图像转换为机器可读的文本。通过向其提供精确的文本部分,它可以更准确地执行任务,因为它处理的文本背景噪声和格式问题较少,这能避免在较大、未分割的文档中混淆 OCR 过程。
这就是代码:
从 PDF 部分提取下划线文本的代码。
在运行此功能之前,请确保你的系统上已安装 tesseract。有关详细说明,请查看他们的官方安装指南:github.com/tesseract-ocr/tessdoc/blob/main/Installation.md 或者在我的 GitHub 仓库中查看:github.com/sasha-korovkina/pdfUnderlinedExtractor。
整合所有内容…
现在,如果我们取任何 PDF 文件,例如这个示例文件:

测试文件的完整文本。图片由作者创建。
我们在这个文件中有一些下划线的词语:

这里下划线的内容是“ipsum”和“laboris”。图片由作者创建。
在运行上述代码后,得到的结果是:

文档中所有下划线词语的数组。图片由作者创建。
得到这个数组后,你可以将这些词语用于进一步处理!
享受使用这个脚本的过程吧!我很乐意听听你想到的任何创造性应用,或者如果你想贡献一些内容,请告诉我!❤️
我希望我早就学到的 AWS Bedrock 教程:你需要了解的一切,以便为 AWS 基础设施准备你的机器
第一部分:节省大量处理碎片化和不完整文档的时间,并在不到一小时内准备好你的开发环境
·发表于 Towards Data Science ·阅读时间 9 分钟·2024 年 10 月 22 日
--

图片由作者使用 Midjourney 生成
如何将笔记本中的一个巧妙的小型机器学习原型,开发成一个强大的全栈 Web 应用程序?虽然这个过程看起来可能令人望而生畏,但这系列多篇文章将帮助你一步一步地应对学习曲线,带你解决即使是最棘手的权限问题(AWS 就是以此而闻名)。通过本系列的学习,你将拥有一个完全功能的语言翻译应用程序可以进行实验,并掌握快速扩展自己的 GenAI 流水线所需的概念性知识。
这是我们新系列的第一部分,主题是 🌊 《构建由 Amazon Bedrock 支持的全栈 GenAI 应用程序的广泛入门指南》。
本系列文章与 Vlad Seredniy* 合作撰写。*
本系列的内容是什么?
使用 Amazon Bedrock 开发全栈应用程序可能会带来相当陡峭的学习曲线。新开发者需要应对 AWS 基础设施的复杂性以及集成 AI 流程的各种细节。试图学习 AWS 技术栈很容易变得非常沮丧,尤其是当面对零散的文档和在前端与后端系统之间定制数据流的耗时过程时。开发者常常会花费数小时观看关于 AWS 技术栈的概念视频,但这些视频的实际应用往往非常有限。
在本系列中,我们将直接面对这些挑战。我们将一步一步地教你如何使用 AWS 基础设施、React、Node.js 和 Amazon Bedrock 构建一个全栈 AI 应用,帮助你全面理解并(希望)最小化困惑。
我们将使用这个教程作为起点。我们将填补所有空白,帮助你启动,从如何设置技术栈(在本教程中)开始,然后继续理解应用程序代码库,并添加新的功能,比如音频转录、翻译和生成——本质上是实现自定义的 AI 流程。从简化 AWS 设置和揭开权限的神秘面纱,到提供前端、后端和 Lambda 函数如何协同工作的全面代码概述,我们旨在减少摩擦,降低全栈集成的门槛。在这个过程中,我们将尽力彻底回答任何在工作环境中可能不敢提问的“傻问题”。在整个过程中结束时,我们希望能为你提供所有工具和信心,去追寻你自己的 AI 想法和创作!
本系列基于这个项目*:* AWS AppSync AI Agent Playground和本教程*:* 创建一个由 Amazon Bedrock 提供支持的全栈示例 Web 应用
在第一部分(本部分)中,我们将详细讲解设置 AWS 基础设施并使用 Amazon Bedrock 启动一个示例全栈应用的必要步骤。我们将介绍如何创建你的 AWS 账户,提供关于配置 IAM 用户和获取密钥的教程视频,并安装 Node.js 和 AWS CDK 等工具。完成第一部分后,你将拥有一个完全配置的环境,准备进行定制。
在第二部分,我们将解决新开发者在开始使用 AWS 技术栈时面临的最大挑战之一:缺乏清晰的文档。我们将带你浏览项目的前端代码库,并详细介绍如何在 React 中创建组件,以捕获并上传音频文件到 S3 存储桶进行处理。
当我们与开发者交流时,他们会提到权限管理是使用 AWS 基础设施时最具挑战性的方面。在这一部分,我们将介绍一些复杂的 AWS 配置细节,如设置 IAM 角色和策略,以及为你的 S3 存储桶配置 CORS。
第三部分是最有趣的部分!我们将破解示例项目后端 Python 文件中的一个代理,集成我们的 GenAI 工具。具体来说,我们将利用AWS Transcribe来转录音频,并利用Bedrock的能力查询Anthropic 的 Claude API来翻译转录过程中的文本。
最后,在系列的结论部分,在第四部分,我们将与 Deepgram 的 AI 驱动语音生成 API 对接,生成一个可以在应用界面中播放的翻译音频文件。
当你完成这一系列内容时,你可以期待一个完整的、由 AI 驱动的语言翻译应用,并且能够理解如何用你自己的 GenAI 工具定制它。
什么时候 AWS 基础设施是你项目的正确选择?
显然,回答这个问题超出了本文的范围。然而,我们希望简要地讨论一下这个话题,区分原型制作、测试应用和大规模部署。
根据我们的经验,在我们刚开始时,在学习或尝试阶段,使用像Jupyter Notebook、Google Colab甚至Lightning AI这样的工具来实时测试和迭代想法往往要快得多。一旦我们选择了要使用的技术栈并具备了核心功能(并解决了技术问题),我们会构建一个简易的原型。在这个阶段,我们可能会使用一个比 AWS 更简单的系统,比如Flutterflow或Streamlit。
然而,当一个应用需要处理增加的流量或强大的安全功能时,AWS 具有巨大的优势,因为它可以在基础设施需求增长时无缝扩展。
精简的 AWS 配置
本文的剩余部分将专注于正确设置示例 AWS Bedrock 应用,并使其能够在你的开发机器上本地运行。由于亚马逊网站上的文档碎片化,这一部分可能特别棘手。我们将扩展原教程前提条件部分的说明,并详细描述正确安装所有要求的步骤。
📌 注意:sudo 和命令行使用
接下来的很多部分都需要大量使用命令行 / 终端。如果你遇到权限错误,可以尝试在命令前加上
***sudo***。例如,不要使用
***npm install -g aws-cdk***,可以试试***sudo npm install -g aws-cdk***。
创建你的 AWS IAM 用户
对初学者来说,设置 AWS 账户并配置正确的凭证可能是最初的(也是最让人沮丧的)部分。下面是一个逐步的过程,让你不必花费数小时在 AWS 网站和文档中来回搜索,拼凑正确的操作顺序。
注册一个 AWS 账户。
然后登录并在控制台中导航到 IAM。点击用户 → 创建用户
现在,你需要在创建用户流程中自定义以下详细信息。
步骤 1:指定用户详细信息
勾选 为用户提供 AWS 管理控制台访问权限
选择 我想创建一个 IAM 用户
在控制台密码部分,选择 自定义密码。输入一个安全且易于记住的密码。
取消勾选 用户首次登录时必须重置密码,这样用户就不需要在第一次登录时重置密码。
步骤 2:设置权限
选择 直接附加策略。
在搜索栏中,输入 AdministratorAccess
在搜索结果中勾选 AdministratorAccess 策略旁的复选框。
接下来,在 步骤 3 中点击 创建用户,然后在 步骤 4 中点击 返回用户列表。记得将你的密码保存在安全的地方。
设置好 IAM 用户后,你需要安装所需的软件工具,从 Node.js 开始。
安装 Node 和 NPM(Node 包管理器)
在本地机器上设置任何 Amazon 开发工具之前,你必须已经安装了 NPM(Node 包管理器)。我们的下一步将是安装 Node.js,它包含了 NPM。示例项目要求 Node.js v18 或更高版本,因此我们将安装最新版本以满足这个依赖。
对于 Mac:
- 访问 Node.js 官方网站。你可以将命令复制粘贴到终端中,通过命令行下载并安装,或者直接下载 Mac 版本的安装程序。
对于 Linux(Ubuntu):
sudo apt install nodejs npm
验证安装:
node --version
npm --version
安装好 Node.js 后,下一步是使用 NPM 设置 AWS 云开发工具包(CDK)。
在开发机器上配置你的 AWS 账户
现在,我们将向你展示如何使用 NPM 安装 AWS CDK。示例应用需要 AWS CDK 2.103 或更高版本,因此我们将下载最新版本(以确保满足此要求)。
输入以下命令来全局安装 AWS CDK:
sudo npm install -g aws-cdk
让我们验证一下安装:
cdk --version
安装好 CDK 后,我们可以安装 AWS CLI(命令行界面),它将帮助你配置 AWS 账户凭证。
安装和配置 AWS CLI
设置账户凭证对正确设置示例应用程序至关重要,因此我们来介绍如何安装 AWS CLI 并正确设置你的账户。
安装 AWS CLI
安装 AWS CLI 的方法在不同操作系统上稍有不同。
对于 Linux,下载并运行安装程序:
curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
unzip awscliv2.zip
sudo ./aws/install
对于 Mac,从AWS CLI 网站下载 *.pkg 文件,并按照安装向导中的指示进行操作。
让我们通过验证安装是否成功来完成这一步:
aws --version
创建访问密钥
现在我们已经安装了 AWS CLI,可以继续创建访问密钥。如果你已经按照视频中的步骤 1 创建了访问密钥,那就跳过这一步。如果没有,请回顾视频结尾部分或按照以下指示操作:
在 IAM 控制台中,查找你想要创建访问密钥的用户。
在用户摘要部分,点击“创建访问密钥”
勾选“我理解上述推荐,并希望继续创建访问密钥”
将你的访问密钥和秘密访问密钥保存在安全的位置。
配置 AWS CLI
我们快完成了!只需最后一步,你的账户就能在开发机器上正确配置。
运行此命令启动配置过程:
aws configure
系统将要求你输入以下内容:
AWS 访问密钥 ID(从你的安全位置复制/粘贴)
AWS 秘密访问密钥(从你的安全位置复制/粘贴)
默认区域名称(我们使用了
*us-east-1*)默认输出格式(使用
*json*)
完成此步骤后,你的账户信息将已配置完毕,可以继续准备部署 AWS CDK 基础设施。
启动 CDK
接下来,我们将展示如何在你希望部署的区域中启动 CDK。
cdk bootstrap aws://ACCOUNT-NUMBER/REGION
在这里,你需要使用你的 12 位账户号码(可以在 AWS 控制台屏幕右上角找到)作为 ACCOUNT-NUMBER,以及你的区域(我们使用*us-east-1*)作为 REGION。
你快完成了!现在只需安装项目所需的其他工具,就能设置好示例应用程序。
安装 Yarn、Git 和 Docker
本项目使用 yarn 来管理依赖项。你可以通过 NPM 安装它:
sudo npm install -g yarn
你可以验证 yarn 是否正确安装:
yarn --version
Git:
Git 在 Mac 和大多数 Linux 发行版上是预装的。你可以验证它是否已安装。
git --version
如果尚未安装,在 Ubuntu 上你可以使用 apt-get sudo apt-get install git,而在 Mac 上可以从官方 Git 网站下载。
Docker:
在 Mac 上,你可以从官方 Docker 网站下载 Docker Desktop for Mac,然后只需按照安装向导进行操作。
对于 Linux(Ubuntu),你可以参考官方的Ubuntu 上的 Docker 安装指南。
结论
恭喜你!你已完成本系列的第一部分,现在应该已经设置好开发环境并安装了所有依赖项,以便你可以部署这个 AWS 示例应用。你现在可以从原始教程的设置部分继续,使用 Git 下载项目代码。
下一步:
AWS 服务和设置过程可能会有所变化,因此尽管我们已经尽力使这些说明保持最新,但如果遇到问题,你可以随时查阅 AWS 文档以获取最新信息。
我们将在接下来的几周内深入讨论为交互式翻译应用添加所有自定义功能的实现细节。这将为你提供机会,深入理解相关概念,并构建对技术栈的理解。
👉 你将在第二部分找到的内容:
在第二部分,我们将放大视角,概述示例应用的架构。我们将展示如何自定义前端以捕获并存储语音录音,以及如何通过 API 将其传递到后台。
使用 Facebook Prophet 预测时间序列数据入门
本文旨在通过使用 Prophet 进行的动手教程,消除入门时间序列分析的障碍。
·发表于Towards Data Science ·阅读时间:9 分钟·2024 年 1 月 30 日
--
当开始学习数据科学时,时间序列分析是人们常常希望自己尝试的一个内容!这里的基本理念是通过学习历史模式来预测未来。典型的应用案例可能是天气预测或销售预测。但是,下面这位聪明的预言家与这些有什么关系呢?!

由 leonardo.ai 生成的图像
本文旨在通过使用一种最简单的工具之一——Facebook Prophet——在Google Colab中进行的动手教程,消除入门时间序列分析的障碍(这两个工具都是免费的!)。如果你想立即开始,可以跳过接下来的两章,我将在其中简要介绍时间序列的基本原理以及 Facebook Prophet 本身。祝你玩得开心!
本文分为三个主要部分:
#1 时间序列分析原理简要介绍
#2 Facebook Prophet 简介
#3 如何在 Google Colab 中使用 Prophet 的动手教程(免费)
#1 时间序列分析的基本原则
想象一下,你是一家消费品商店的经理,你想预测即将到来的产品需求,以更好地管理供应。在这种情况下,一个合理的机器学习方法是进行时间序列分析,这涉及到理解、建模并根据顺序数据点进行预测。[1]
以下图形展示了历史产品需求(深蓝色线条)的人工发展趋势,可以用来分析时间序列模式。我们的最终目标是尽可能准确地预测(红色虚线)实际的未来需求(浅蓝色线条):

绘制的时间序列示例。图像来自作者
时间序列通常分解为三个主要组件:
趋势:数据中的长期变化或总体方向。
季节性:在规律间隔时间内重复出现的波动或模式
残差/误差:数据中剩余或遗留的变化
将时间序列分解为这三大组件,通常称为加性或乘性分解,能够帮助分析师更好地理解数据的潜在结构和模式。这种理解对于选择合适的预测模型以及基于历史数据做出准确预测至关重要。[2]

将时间序列分解为趋势+季节性+误差的示例。图像来自作者
#2 什么是 Facebook Prophet?
Prophet 是 Facebook 数据科学团队发布的一个开源工具,基于加性模型生成时间序列预测数据,其中非线性趋势与季节性和假期效应相匹配。其设计原理使得在没有太多关于底层模型的知识的情况下,也能进行参数调整,这使得该方法适用于统计知识较少的团队。[3]
Prophet 特别适用于商业预测应用,并因其易用性和在处理各种时间序列数据方面的有效性而获得广泛关注。和所有工具一样,需要注意的是,虽然 Prophet 功能强大,但选择预测方法仍然依赖于数据的具体特征和分析目标。一般来说,不能保证 Prophet 比其他模型表现更好。然而,Prophet 具备一些有用的功能,例如反映季节性变化的情况(例如新冠疫情前后的变化)或将封锁视为一次性假期。
若想深入了解 Meta(Facebook)本身的介绍,请观看下面的YouTube 视频。
在接下来的教程中,我们将使用 Python 实现并应用 Prophet。不过,你也可以非常高兴地使用 R 进行分析!
#3 使用 Prophet 的实操教程
如果你对编码环境的经验有限或没有访问权限,我建议使用Google Colaboratory(“Colab”),它类似于“一个 免费的 Jupyter notebook 环境,无需设置,完全在云端运行。”虽然本教程强调了 Colab 的简便性和优势,但也有一些缺点,如与专业云环境相比,计算能力较弱。然而,我认为 Colab 可能是一个不错的选择,用于与 Prophet 一起迈出第一步。
在 Colab 中设置基本的时间序列分析环境,你可以按照以下两个步骤操作:
打开
colab.research.google.com/并注册一个免费账户在 Colab 中创建一个新的notebook
安装并使用prophet包:
pip install prophet
from prophet import Prophet
加载并准备数据
我上传了一个小型虚拟数据集,代表本地公交公司(2012–2023 年)的月度乘客数量。你可以在GitHub 上找到数据。
作为第一步,我们将使用 pandas 加载数据,并创建两个独立的数据集:一个包含 2012 到 2022 年的训练子集,以及一个包含 2023 年的测试子集。我们将使用第一个子集来训练我们的时间序列模型,并预测 2023 年的乘客数量。通过第二个子集,我们稍后将能够验证模型的准确性。
import pandas as pd
df_data = pd.read_csv("https://raw.githubusercontent.com/jonasdieckmann/prophet_tutorial/main/passengers.csv")
df_data_train = df_data[df_data["Month"] < "2023-01"]
df_data_test = df_data[df_data["Month"] >= "2023-01"]
display(df_data_train)
显示命令的输出如下所示。数据集包含两列:年份-月份的组合标识,以及该月的乘客数量的数值列。默认情况下,Prophet 设计为处理每日(甚至每小时)的数据,但我们将确保月度模式也可以使用。

乘客数据集。图像由作者提供
分解训练数据
为了更好地理解我们虚拟数据中的时间序列组件,我们将进行快速分解。为此,我们从statsmodels库中导入该方法,并对我们的数据集进行分解。我们选择了加法模型,并指定数据中的一个周期包含 12 个元素(月)。如果是每日数据集,则周期=365。
from statsmodels.tsa.seasonal import seasonal_decompose
decompose = seasonal_decompose(df_data_train.Passengers, model='additive', extrapolate_trend='freq', period=12)
decompose.plot().show()
这段简短的代码将为我们提供时间序列本身的视觉印象,特别是关于趋势、季节性和残差随时间变化的情况:

乘客虚拟数据的分解元素。图像由作者提供
现在我们可以清晰地看到过去 10 年中,显著上升的趋势以及每年可识别的季节性模式。根据这些迹象,我们现在预计模型会预测未来年份夏季的乘客数量会有所增加,并且跟随季节性波峰。但让我们试试看——是时候应用一些机器学习了!
使用 Facebook Prophet 进行模型拟合
在 Prophet 中拟合模型时,至少需要有一个“ds”(日期戳)和一个“y”(要预测的值)列。我们应该确保我们的列名已被重命名,以反映这一点。
df_train_prophet = df_data_train
# date variable needs to be named "ds" for prophet
df_train_prophet = df_train_prophet.rename(columns={"Month": "ds"})
# target variable needs to be named "y" for prophet
df_train_prophet = df_train_prophet.rename(columns={"Passengers": "y"})
现在,魔法可以开始了。拟合模型的过程相当直接。然而,请查看文档,以了解我们在这一步骤中可以调整的大量选项和参数。为了简化,我们暂时将拟合一个简单模型,不做进一步的调整——但请记住,现实世界中的数据永远不是完美的:你以后肯定需要进行参数调优。
model_prophet = Prophet()
model_prophet.fit(df_train_prophet)
这就是我们拟合模型所需做的一切。让我们进行一些预测吧!
进行预测
我们需要对一个包含“ds”列(包含你想要进行预测的日期)的表格进行预测。要设置这个表格,使用make_future_dataframe方法,它会自动包含历史日期。通过这种方式,你可以看到模型如何拟合过去的数据并预测未来。由于我们处理的是月度数据,我们会通过“freq=12”来指定频率,并要求预测的未来时间段为 12 个月(“periods=12”)。
df_future = model_prophet.make_future_dataframe(periods=12, freq='MS')
display(df_future)
这个新的数据集包含了训练期和我们想要预测的额外 12 个月数据:

未来数据集。图片来源:作者
要进行预测,我们只需调用 Prophet 的predict方法,并提供未来数据集。预测结果将包含一个包含许多不同列的大数据集,但我们只关注预测值yhat以及不确定性区间yhat_lower和yhat_upper。
forecast_prophet = model_prophet.predict(df_future)
forecast_prophet[['ds', 'yhat', 'yhat_lower', 'yhat_upper']].round().tail()
下表给我们一些关于如何生成和存储输出的概念。对于 2023 年 8 月,模型预测的乘客数量为 532 人。默认设置为 80%的不确定性区间,简单来说,它告诉我们在这个月,大多数情况下,乘客数量可能在 508 人到 556 人之间。

预测子集。图片来源:作者
最后,我们想要可视化输出,以更好地理解预测结果和区间。
可视化结果
为了绘制结果,我们可以利用 Prophet 的内置绘图工具。通过 plot 方法,我们可以显示原始时间序列数据和预测值。
import matplotlib.pyplot as plt
# plot the time series
forecast_plot = model_prophet.plot(forecast_prophet)
# add a vertical line at the end of the training period
axes = forecast_plot.gca()
last_training_date = forecast_prophet['ds'].iloc[-12]
axes.axvline(x=last_training_date, color='red', linestyle='--', label='Training End')
# plot true test data for the period after the red line
df_data_test['Month'] = pd.to_datetime(df_data_test['Month'])
plt.plot(df_data_test['Month'], df_data_test['Passengers'],'ro', markersize=3, label='True Test Data')
# show the legend to distinguish between the lines
plt.legend()
除了常规的时间序列图,我们还添加了一条虚线,表示训练期的结束,从而标志着预测期的开始。此外,我们还使用了我们在一开始准备的真实测试数据集。

绘制的时间序列分析结果,包括真实测试数据和预测结果。图片来源:作者
可以看出,我们的模型表现得不错。大部分真实的乘客数据实际上都在预测的不确定性区间内。然而,夏季的预测似乎仍然偏乐观,这也是我们在往年数据中看到的模式。现在是开始探索我们可以与 Prophet 一起使用的参数和特征的好时机。
在我们的例子中,季节性并不是一个常数加性因子,而是随着时间的推移与趋势共同增长。因此,我们可能会考虑在模型拟合过程中将 seasonality_mode 从“加性”改为“乘性”。[4]
我们的教程将在这里结束,以便为您探索 Prophet 提供的众多可能性腾出时间。为了便于大家一起回顾完整代码,我将代码片段整合在这个Python 文件中。此外,您还可以将这个笔记本直接上传到 Colab 并自行运行。让我知道它对您有何帮助!
结论
Prophet 是一个强大的工具,用于预测时间序列数据中的未来值,特别是当您的数据具有像月度或年度周期这样的重复模式时。它非常用户友好,并且能够快速为您的特定数据提供准确的预测。然而,了解它的局限性也非常重要。如果您的数据没有明确的模式,或者有一些模型未曾见过的重大变化,Prophet 可能无法达到最佳表现。了解这些局限性对于明智地使用该工具至关重要。
好消息是,强烈推荐您在自己的数据集上实验 Prophet!每个数据集都是独特的,调整设置并尝试不同的方法可以帮助您发现最适合您特定情况的方式。所以,快去探索吧,看看 Prophet 如何提升您的时间序列预测。
阅读 Jonas Dieckmann 在 Medium 上的文章。分析经理和产品负责人 @ philips | 热衷并撰写关于…
medium.com](https://medium.com/@jonas_dieckmann?source=post_page-----c74ad3040525--------------------------------)
我希望您觉得这很有用。请告诉我您的想法!也欢迎在 LinkedIn 上与我联系 www.linkedin.com/in/jonas-dieckmann/ 或在 Medium 上关注我。
请参见我其他的一些文章:
[## 如何使用 ChatGPT API 从 Colab 或 Databricks 直接交互
您是否想过如何使用 OpenAI 的 API 直接与 GPT 算法交互?它很简单、免费,而且更…
pub.towardsai.net ## 如何使用 Keras API 和 Google Colab 入门 TensorFlow
分步教程:使用神经网络分析人类活动
towardsdatascience.com
参考文献
[1] Shumway, Robert H.; Stoffer, David S. (2017): 时间序列分析及其应用。Cham: Springer 国际出版。
[2] Brownlee, Jason (2017): 使用 Python 进行时间序列预测入门
[3] Rafferty, Greg (2021): 使用 Facebook Prophet 预测时间序列数据
[4] facebook.github.io/prophet/docs/quick_start.html
开始了解多模态 AI、CPU 和 GPU、独热编码以及其他适合初学者的指南
·发表于 Towards Data Science ·作为 Newsletter 发送 ·阅读时间:4 分钟 ·2024 年 11 月 21 日
--
想要写你的第一篇 TDS 文章吗?我们始终欢迎新作者的投稿。
踏出掌握新主题的第一步总是有些让人畏惧——有时甚至是非常畏惧!无论你是第一次学习算法,还是在探索令人兴奋的 LLM(大语言模型)领域,或者刚刚被指派要重构团队的数据技术栈:在没有或几乎没有相关经验的情况下迎接挑战,需要付出相当的勇气和毅力。
更有经验的从业者的冷静而细致的视角也能起到很大作用——这正是我们的作者们的强项。本周,我们汇集了几篇近期的优秀文章,专门针对那些希望扩展技能的初学者的需求。让我们卷起袖子,开始吧!
从并行计算原理到 CPU 和 GPU 架构的编程对于刚刚入门的数据科学家和机器学习工程师来说,理解内存基础和并行执行是至关重要的。Shreya Shukla的全面且易于理解的指南是让你在这一主题上打下坚实基础的完美资源,重点介绍如何为 CPU 和 GPU 架构编写代码,以完成诸如向量-矩阵乘法这样的基本任务。
多模态模型——能够“看”和“听”的 LLM如果你对 LLM 的基础知识已经非常自信,为什么不进一步探索多模态模型,它们可以接收(在某些情况下,还可以生成)多种形式的数据——从图像到代码和音频?Shaw Talebi的入门教程是一个很好的起点,它是一个新系列的第一部分,为你构建实践知识打下了坚实的基础。
机器学习中的提升算法,第 II 部分:梯度提升无论你是刚刚开始机器学习之旅,还是已经做了很长时间,但感觉可能需要复习基础知识,巩固基础从来都不是坏主意。Gurjinder Kaur对提升算法的持续探索是一个很好的例子,提供了对一些最强大模型的易于理解、易消化的解析——在本例中是梯度提升。

图片由Taria Camerino提供,来源于Unsplash
NLP 插图解读,第一部分:文本编码我们很高兴与读者分享的另一个新项目是Shreya Rao刚刚推出的系列插图指南,涵盖自然语言处理中的核心概念,这项技术正是推动许多新兴聊天机器人和人工智能应用的技术。第一部分重点介绍了几乎所有 NLP 工作流中的一个关键步骤:通过文本编码将文本数据转化为数值输入。
解码独热编码:分类数据的初学者指南如果你想了解另一种数据转换形式,不要错过Vyacheslav Efimov对独热编码的清晰简洁的介绍,它是“数据预处理过程中最基本的技术之一”,将分类特征转换为数值向量。
Excel 电子表格在大数据领域已死。公司需要更多的 Python有一种转变,往往比学习一个新主题还要困难,那就是切换到一个新工具或工作流程,尤其是当你要放弃的工具正好处于你的舒适区时。然而,正如Ari Joury 博士所解释的那样,有时为了采用基于 Python 的数据工具而暂时牺牲速度和易用性是值得的,这种情况正是 Excel 电子表格被替代的原因。
准备好这周涉足其他话题和挑战了吗?我们希望如此——最近我们发布了一些关于 LLM 应用、Python 生成的艺术、AI 伦理等方面的精彩文章:
在过去一年构建基于 LLM 的应用程序之后,Satwiki De分享了关于这一过程如何与传统产品开发规范有所不同的实用见解。
在他的最新文章中,Robert Lange聚焦于神经网络训练的最新进展,并探讨了分布式训练的各种方法,如数据并行训练和基于 gossip 的平均算法。
将数据分析转化为有价值的商业决策一直是数据专业人士的一个长期挑战。Tessa Xie从这个问题提供了全新的视角,并给出了几个切实可行的建议。
任何想要深入数学的读者可以直接去Reza Bagheri的最新文章,他带领我们了解极为重要的 softmax 函数的内部工作原理。
在对生成式 AI 工具的输出感到失望之后,Anna Gordun Peiro 尝试使用纯粹的 Python 创作 Mondrian 风格的艺术作品,并通过一篇易于跟随的教程记录了她的过程。
在处理时间序列数据时,了解你的异常值处理是否有效是至关重要的。 Sara Nóbrega 在她的最新文章中详细讨论了你可以用来评估处理效果的各种方法。
创建能在大规模运作的 AI 伦理与治理框架需要什么?Jason Tamara Widjaja 解构了弥合常见组织和实施差距的挑战。
在音乐与 AI 的交汇处写作的Jon Flynn,带领我们走过这个日益发展的领域,并聚焦于 Qwen2-Audio 模型,该模型被训练来将音乐输入转录成乐谱。
感谢你支持我们作者的工作!正如我们之前提到的,我们很高兴发布新作者的文章,因此如果你最近写了一篇有趣的项目实践、教程或关于我们核心话题的理论反思,别犹豫,与我们分享。
直到下一个变量,
TDS 团队
在你的 Python Web 应用中入门强大的数据表格
使用 AG Grid 在纯 Python 中与 Reflex 构建金融应用
·发布于 Towards Data Science ·6 分钟阅读·2024 年 10 月 6 日
--

在过去几个月里,我一直在探索用于 Web 应用的各种数据可视化和操作工具。作为一名 Python 开发者,我经常需要处理大规模数据集并将其显示在交互式、可定制的表格中。一直困扰我的一个问题是:如何构建一个强大的数据网格 UI,与我的 Python 后端无缝集成?
有许多选项可以构建复杂的数据网格,但作为一名 Python 工程师,我对 JavaScript 或任何前端框架的经验有限。我在寻找一种方法,利用我最熟悉的语言——Python,来创建一个功能丰富的数据网格!
我决定使用 Reflex,一个开源框架,允许我完全使用 Python 构建 Web 应用。更重要的是,Reflex 现在提供与 AG Grid 的集成,AG Grid 是一个功能丰富的数据网格库,旨在 Web 应用中显示和操作表格数据,提供一系列功能,包括:
原地单元格编辑
实时数据更新
分页和无限滚动
列过滤、重新排序、调整大小和隐藏
行分组与聚合
内建主题
免责声明: 我在 Reflex 担任创始工程师,负责贡献开源框架。
在本教程中,我们将介绍如何从头开始构建一个完整的金融应用程序,使用纯 Python 显示股票数据,并通过交互式网格和图表提供排序、筛选和分页等高级功能——查看完整的实时 应用程序 和 代码。
设置
首先,我们导入必要的库,包括用于获取股票数据的 yfinance。
import reflex as rx
from reflex_ag_grid import ag_grid
import yfinance as yf
from datetime import datetime, timedelta
import pandas as pd
获取并转换数据
接下来,我们定义了 State 类,它包含应用程序的状态和逻辑。fetch_stock_data 函数获取指定公司的股票数据,并将其转换为适合在 AG Grid 中显示的格式。我们通过将按钮的 on_click 触发器与该状态函数连接,来调用此函数。
我们定义了状态变量,它们表示应用程序中可能随时间变化的字段(状态变量会直接渲染到应用程序的前端)。
data 状态变量存储了从 Yahoo Finance 获取的原始股票数据。我们将这些数据转换为四舍五入的值,并将其存储为字典列表,这是 AG Grid 所期望的格式。转换后的数据按日期和股票代码降序排列,并存储在 dict_data 状态变量中。
datetime_now 状态变量存储了数据获取时的当前日期和时间。
# The list of companies to fetch data for
companies = ["AAPL", "MSFT", "GOOGL", "AMZN", "META"]
class State(rx.State):
# The data fetched from Yahoo Finance
data: pd.DataFrame
# The data to be displayed in the AG Grid
dict_data: list[dict] = [\{}]
# The datetime of the current fetched data
datetime_now: datetime = datetime.now()
def fetch_stock_data(self):
self.datetime_now = datetime.now()
start_date = self.datetime_now - timedelta(days=180)
# Fetch data for all tickers in a single download
self.data = yf.download(companies, start=start_date, end=self.datetime_now, group_by='ticker')
rows = []
for ticker in companies:
# Check if the DataFrame has a multi-level column index (for multiple tickers)
if isinstance(self.data.columns, pd.MultiIndex):
ticker_data = self.data[ticker] # Select the data for the current ticker
else:
ticker_data = self.data # If only one ticker, no multi-level index exists
for date, row in ticker_data.iterrows():
rows.append({
"ticker": ticker,
"date": date.strftime("%Y-%m-%d"),
"open": round(row["Open"], 2),
"high": round(row["High"], 2),
"mid": round((row["High"] + row["Low"]) / 2, 2),
"low": round(row["Low"], 2),
"close": round(row["Close"], 2),
"volume": int(row["Volume"]),
})
self.dict_data = sorted(rows, key=lambda x: (x["date"], x["ticker"]), reverse=True)
rx.button(
"Fetch Latest Data",
on_click=State.fetch_stock_data,
)
定义 AG Grid 列

按作者显示的 AG Grid 列
column_defs 列表定义了要在 AG Grid 中显示的列。header_name 用于设置每一列的标题。field 键表示每一列的 id。filter 键用于插入筛选功能。
column_defs = [
ag_grid.column_def(field="ticker", header_name="Ticker", filter=ag_grid.filters.text, checkbox_selection=True),
ag_grid.column_def(field="date", header_name="Date", filter=ag_grid.filters.date),
ag_grid.column_def(field="open", header_name="Open", filter=ag_grid.filters.number),
ag_grid.column_def(field="high", header_name="High", filter=ag_grid.filters.number),
ag_grid.column_def(field="low", header_name="Low", filter=ag_grid.filters.number),
ag_grid.column_def(field="close", header_name="Close", filter=ag_grid.filters.number),
ag_grid.column_def(field="volume", header_name="Volume", filter=ag_grid.filters.number),
]
显示 AG Grid

按作者展示 AG Grid
现在是我们应用程序中最重要的部分,AG Grid 本身!
• id 是必需的,因为它唯一标识页面上的 Ag-Grid 实例
• column_defs 是我们之前定义的列定义列表
• row_data 是要在网格中显示的数据,存储在 dict_data 状态变量中。
• pagination、pagination_page_size 和 pagination_page_size_selector 参数启用具有特定变量的分页功能
• theme 允许你设置网格的主题
ag_grid(
id="myAgGrid",
column_defs=column_defs,
row_data=State.dict_data,
pagination=True,
pagination_page_size=20,
pagination_page_size_selector=[10, 20, 50, 100],
theme=State.grid_theme,
on_selection_changed=State.handle_selection,
width="100%",
height="60vh",
)
更改 AG Grid 主题

按作者更改 AG Grid 主题
我们通过在 rx.select 组件中使用 grid_theme 状态变量来设置 theme。
每个状态变量都有一个内置函数来设置它的值,方便使用,函数名称为 set_VARNAME,在此情况下是 set_grid_theme。
class State(rx.State):
...
# The theme of the AG Grid
grid_theme: str = "quartz"
# The list of themes for the AG Grid
themes: list[str] = ["quartz", "balham", "alpine", "material"]
rx.select(
State.themes,
value=State.grid_theme,
on_change=State.set_grid_theme,
size="1",
)
以图表显示公司数据

按作者显示所选公司过去 6 个月的数据
上面 AG Grid 代码中显示的 on_selection_changed 事件触发器,在用户选择网格中的一行时调用。这会调用 State 类中的 handle_selection 方法,将 selected_rows 状态变量设置为新选择的行,并调用 update_line_graph 方法。
update_line_graph 函数获取相关的股票代码,并使用它来设置 company 状态变量。然后,该公司的过去 6 个月的日期、平均值和 DateDifference 数据被设置到状态变量 dff_ticker_hist 中。
最终它呈现在一个 rx.recharts.line_chart 中,使用 rx.recharts.error_bar 显示 DateDifference 数据,即当天的最高值和最低值。
class State(rx.State):
...
# The selected rows in the AG Grid
selected_rows: list[dict] = None
# The currently selected company in AG Grid
company: str
# The data fetched from Yahoo Finance
data: pd.DataFrame
# The data to be displayed in the line graph
dff_ticker_hist: list[dict] = None
def handle_selection(self, selected_rows, _, __):
self.selected_rows = selected_rows
self.update_line_graph()
def update_line_graph(self):
if self.selected_rows:
ticker = self.selected_rows[0]["ticker"]
else:
self.dff_ticker_hist = None
return
self.company = ticker
dff_ticker_hist = self.data[ticker].reset_index()
dff_ticker_hist["Date"] = pd.to_datetime(dff_ticker_hist["Date"]).dt.strftime("%Y-%m-%d")
dff_ticker_hist["Mid"] = (dff_ticker_hist["Open"] + dff_ticker_hist["Close"]) / 2
dff_ticker_hist["DayDifference"] = dff_ticker_hist.apply(
lambda row: [row["High"] - row["Mid"], row["Mid"] - row["Low"]], axis=1
)
self.dff_ticker_hist = dff_ticker_hist.to_dict(orient="records")
rx.recharts.line_chart(
rx.recharts.line(
rx.recharts.error_bar(
data_key="DayDifference",
direction="y",
width=4,
stroke_width=2,
stroke="red",
),
data_key="Mid",
),
rx.recharts.x_axis(data_key="Date"),
rx.recharts.y_axis(domain=["auto", "auto"]),
data=State.dff_ticker_hist,
width="100%",
height=300,
)
结论
在 Reflex 生态系统中使用 AG Grid 让我作为一个 Python 开发者能够轻松创建复杂、数据丰富的 Web 应用程序。无论你是在构建复杂的仪表盘、数据分析工具,还是需要强大数据网格功能的应用程序,Reflex AG Grid 都能满足你的需求。
我很高兴看到你将使用 Reflex AG Grid 创建的项目!分享你的项目,提问,加入我们社区论坛的讨论。让我们一起推动 Python Web 开发的边界!
如果你有问题,请在下方评论或通过 Twitter 联系我@tgotsman12或通过LinkedIn。在社交媒体上分享你的应用创作并标记我,我很乐意提供反馈或帮助转发!
作为地球科学家如何开始使用 Python?这里有 5 种方法帮助你提升代码质量!
写 Python 代码时,帮助你更好起步的 5 个技巧
·发表于 Towards Data Science ·阅读时长 8 分钟·2024 年 1 月 27 日
--

图像展示了一位地球科学家在编写代码。此图像由 DALL-E 根据作者的要求生成。
多年来,我在地球科学和岩石物理学领域见过并参与过许多 Python 脚本的编写。在这段时间里,我见识过(并且自己也写过)多种编程风格,从组织良好、文档完善的代码,到没有什么组织、所有内容都写在一个 Python 文件中的代码。后者在维护、调试和理解时会很困难,特别是几个月后再次查看代码时。编写代码的目的通常会决定所采用的编程风格。
如果我们正在编写一个可能只会使用一次或两次的脚本,或者是在严格的时间限制和压力下工作,那么我们可能无法像希望的那样让代码看起来既美观又有条理。然而,如果我们正在编写一段将多次使用的代码,或者将其部署给其他用户,而且我们有足够的时间,我们可能会希望以一种便于后期扩展的方式来构建代码或应用程序。这样,当我们重新查看代码时,可以节省时间并减少麻烦,同时避免忘记代码的功能或最初设计意图。
作为地球科学家,编写代码可能并不是你自然的背景。然而,你…
Dev Containers 扩展入门
这是一个逐步教程,帮助您入门 Dev Containers 扩展。本教程是一个系列教程的第一部分,重点介绍 Dev Containers 扩展。
·发表于 Towards Data Science ·阅读时间 9 分钟·2024 年 5 月 7 日
--
我完全将开发工作流转移到 Visual Studio Code(VScode)的主要原因之一是 Dev Containers 扩展。Dev Containers 扩展为 VScode 提供了 Docker 和容器的原生集成。它减少了设置 Docker 化环境以及在使用容器框架时与他人协作的工作量。本教程重点介绍如何“入门”使用 Dev Containers 扩展。我们将深入探讨在您的开发工作流中使用 Dev Containers 扩展(和 Docker)的动机。此外,我们还将回顾将 Dev Container 模板 添加到项目并在容器中启动它的步骤。
在接下来的教程中,我们将探索该扩展的核心功能,并回顾定制化选项。

一艘货船停靠在港口(由作者使用 Midjourney 创建)
本教程结束时,您将了解如何:
安装 Dev Containers 扩展
将 Dev Containers 设置文件添加到项目中
在一个…中启动项目
开始使用向量数据库
了解它们是什么,它们如何工作,以及你需要什么才能开始使用它们!
·发布于 Towards Data Science ·阅读时间:9 分钟·2024 年 3 月 18 日
--

这篇文章适合谁阅读?
读者群体 [🟢⚪️⚪️]: AI 初学者,熟悉流行模型及其应用
级别 [🟢⚪️⚪️]: 基础主题
复杂度 [🟢⚪️⚪️]: 容易理解,没有数学公式或复杂理论
❓它们是什么以及为什么重要
前几天,我看到了一条推文,令我停下了脚步:

图 1:来源:X, https://twitter.com/kelseyhightower/status/1761166616426061979
它由前谷歌同事 Kelsey Hightower 发布,多年来,他在开源和开发者社区中几乎已成为传奇人物。
🤔 他说得很对: 在生成型 AI(GenAI)的噪音包围中,很容易被花哨的演示和几乎不断发布的新模型公告所吸引。 😮💨 尝试去…
使用 Imatrix 和 K-量化进行 GGUF 量化,以便在您的 CPU 上运行 LLM
为您的 CPU 提供快速且精准的 GGUF 模型
·发布于Towards Data Science ·7 分钟阅读·2024 年 9 月 13 日
--

由 DALL-E 生成
GGUF 是一种二进制文件格式,旨在通过 GGML(一种基于 C 的机器学习张量库)实现高效的存储和快速的大型语言模型(LLM)加载。
GGUF 将推理所需的所有组件(包括分词器和代码)封装在一个文件中。它支持将各种语言模型转换为 GGUF 格式,例如 Llama 3、Phi 和 Qwen2。此外,它还支持将模型量化为较低的精度,以提高在 CPU 上的速度和内存效率。
我们常常写“GGUF 量化”,但 GGUF 本身只是一个文件格式,并不是一种量化方法。llama.cpp 中实现了几种量化算法,用于减少模型大小并将生成的模型序列化为 GGUF 格式。
在本文中,我们将展示如何准确地量化一个 LLM 并将其转换为 GGUF,使用重要性矩阵(imatrix)和 K-量化方法。我将提供 Gemma 2 Instruct 的 GGUF 转换代码,使用了 imatrix。它与其他由 llama.cpp 支持的模型一样有效:Qwen2、Llama 3、Phi-3 等。我们还将讨论如何评估量化的准确性以及生成模型的推理吞吐量。
机器学习项目的 Git 工作流:我在项目中使用的 Git 工作流
机器学习和数据科学项目的简单 Git 工作流
·发布于 Towards Data Science ·阅读时间 6 分钟·2024 年 2 月 16 日
--

图片来源:Mika Baumeister 于 Unsplash
在你的项目中采用 Git 工作流可以简化项目管理并提高一致性。有几种 Git 工作流是专门为满足 Git 用户需求而设计的:有些工作流简单直接,有些则更加复杂,适用于大型项目。在本文中,我将分享我在机器学习和数据科学项目中使用的 Git 工作流。我的工作流介于简单与复杂之间——既不太简单,也不太复杂。事不宜迟,让我们开始吧!
还不是 Medium 的会员?没关系!继续阅读,点击这个 朋友链接。
引言
Git 工作流被定义为一组约定和实践,旨在标准化项目版本控制的管理,从而提高一致性并促进协作。在之前的教程中,我详细介绍了我认为最重要的三种工作流:即 功能分支 工作流、分叉工作流 和 Gitflow 工作流。
在 功能分支 工作流中,为每个功能开发、修复 bug 和其他项目任务创建专门的分支。在 分叉 工作流中…
GitHub: 拉取请求模板

图片来源:Luca Bravo 由 Unsplash 提供
精通 Git
帮助开发者编写更好的拉取请求描述
·发布于 Towards Data Science ·阅读时间 5 分钟·2024 年 10 月 29 日
--
拉取请求(PR)允许开发人员在将代码更改合并到目标分支之前进行审核。当代码开发者创建拉取请求时,会弹出一个空白的 markdown 字段。这有时会让开发人员不确定应该包括哪些细节,从而可能遗漏一些审阅者可能会觉得有用的重要信息。设置拉取请求模板可以提供明确的指导,帮助开发人员编写有效的拉取请求描述。本文将探讨什么是拉取请求模板,如何创建一个,并提供一些示例内容来启发你的创作。

图片来源:Daniel Fazio 由 Unsplash 提供
📍什么是拉取请求模板?
拉取请求模板是一个包含 Markdown 文本的 markdown(.md)或文本(.txt)文件,在创建拉取请求时会自动添加到拉取请求描述字段中。想象一下,你打开一个拉取请求,看到的不是一个空白的 markdown 描述字段,而是一个有用的模板,指导你填写需要包括的内容。以下是一些……
谷歌 Gemini 正在加入 Advent of Code 挑战
一个开源项目,用于探索 LLM 在编程挑战中的能力与局限
·发表于 Towards Data Science ·阅读时长 8 分钟·2024 年 12 月 2 日
--

图片由作者提供(使用 Flux 1.1 Pro 创建)
这是什么?
如果 2024 年在生成性人工智能领域教给我们什么,那就是编程是大规模语言模型(LLMs)最具前景的应用之一。
在这篇博客中,我将描述如何使用目前领先于 LMArena 排行榜的最先进的 LLM——Gemini Experimental 1121,来应对 Advent of Code 挑战。

图片由作者提供
我将概述我的方法,并分享我的 开源代码库,以便读者进一步探索并复现结果。
为什么我们应该关注?
LLMs + 编程是一个令人兴奋的领域,有很多原因,其中几个亮点如下:
代码就像语言一样,可以通过变换器模型以相同的方式学习
输出容易验证——我们只需运行代码并检查其结果
对代码助手的需求巨大——它们能够极大提高程序员的生产力
所以,这无疑是一个有趣且令人兴奋的方向,我认为通过动手挑战来进一步探索它可能会很有趣。
Advent of Code 挑战
对于那些不熟悉 Advent of Code 挑战的人来说:这是一个每年举办的活动,从 12 月 1 日到 12 月 25 日,提供每天的编程难题,类似于圣诞倒计时日历。每天都会发布一个包含两部分的新难题,程序员可以在其中测试他们的编程和问题解决能力。这是一个让各个级别的开发者都可以练习编程的有趣方式。
每日挑战的两个部分围绕一个类似的问题展开,并使用相同的输入数据。其目的是编写一个 Python 程序来处理输入数据并生成解决方案(通常是一个数字)。一旦我们运行代码,模型计算出解决方案,我们可以将其粘贴到网站上,查看它是否正确。如果正确,第二部分将解锁,采用类似的程序。
竞赛持续 25 天,用户最多可以收集 50 星星(每天 2 星)。
一个很好的 LLM 挑战
如上所述,这是一个非常适合 LLM 的挑战。我们只需拿到问题陈述,插入选择的 LLM,让它生成代码,运行代码,再将生成的解决方案粘贴到网站上,查看 LLM 是否成功。
对于这个项目,我使用的是 Gemini Experimental 1121,它大大提高了编程和推理能力。它可以通过 Google 的AI Studio获得。我在整个挑战中使用相同的系统提示——它是一个零-shot 提示(没有链式思维),并且要求代码通过输入重定向接收输入,像这样:
python day01/part1.py < day01/input.txt
系统提示是:
Provide python code to solve a given puzzle.
Assume there is an input.txt file that can be read
via input redirection in the command line.
然后我会发布实际的挑战,Gemini 将生成应当产生正确解决方案的代码。我将代码复制到 GH 仓库并运行,然后将生成的解决方案粘贴到 Advent of Code 网站上查看它是否正确。
仓库
每天的挑战都组织在自己的目录中:
dayXX/
├── input.txt # Challenge input
├── part1-problem.txt # Problem description for part 1
├── part2-problem.txt # Problem description for part 2
├── part1.py # Solution for part 1
└── part2.py # Solution for part 2
part1 和 part2 的问题文本文件包含 Advent of Code 提出的挑战问题。我还在每个文本文件的末尾附上了正确的解决方案:

图片来自作者
Python 脚本包含由 Gemini 生成的代码。为了保持完全透明,我还链接到实际的对话记录,让大家可以查看并审核每一步:

图片来自作者
要查看这些聊天中的一个示例,请访问我与 Gemini 关于第一天挑战的对话
我将把所有的结果记录在一个表格中,给读者一个关于模型表现的良好初步概览:

图片来自作者
示例
为了更好地了解这将是什么样子,让我们来看一下第一天挑战的第一部分。以下是问题陈述:
The Chief Historian is always present for the big Christmas sleigh launch, but nobody has seen him in months! Last anyone heard, he was visiting locations that are historically significant to the North Pole; a group of Senior Historians has asked you to accompany them as they check the places they think he was most likely to visit.
As each location is checked, they will mark it on their list with a star. They figure the Chief Historian must be in one of the first fifty places they'll look, so in order to save Christmas, you need to help them get fifty stars on their list before Santa takes off on December 25th.
Collect stars by solving puzzles. Two puzzles will be made available on each day in the Advent calendar; the second puzzle is unlocked when you complete the first. Each puzzle grants one star. Good luck!
You haven't even left yet and the group of Elvish Senior Historians has already hit a problem: their list of locations to check is currently empty. Eventually, someone decides that the best place to check first would be the Chief Historian's office.
Upon pouring into the office, everyone confirms that the Chief Historian is indeed nowhere to be found. Instead, the Elves discover an assortment of notes and lists of historically significant locations! This seems to be the planning the Chief Historian was doing before he left. Perhaps these notes can be used to determine which locations to search?
Throughout the Chief's office, the historically significant locations are listed not by name but by a unique number called the location ID. To make sure they don't miss anything, The Historians split into two groups, each searching the office and trying to create their own complete list of location IDs.
There's just one problem: by holding the two lists up side by side (your puzzle input), it quickly becomes clear that the lists aren't very similar. Maybe you can help The Historians reconcile their lists?
For example:
3 4
4 3
2 5
1 3
3 9
3 3
Maybe the lists are only off by a small amount! To find out, pair up the numbers and measure how far apart they are. Pair up the smallest number in the left list with the smallest number in the right list, then the second-smallest left number with the second-smallest right number, and so on.
Within each pair, figure out how far apart the two numbers are; you'll need to add up all of those distances. For example, if you pair up a 3 from the left list with a 7 from the right list, the distance apart is 4; if you pair up a 9 with a 3, the distance apart is 6.
In the example list above, the pairs and distances would be as follows:
The smallest number in the left list is 1, and the smallest number in the right list is 3\. The distance between them is 2.
The second-smallest number in the left list is 2, and the second-smallest number in the right list is another 3\. The distance between them is 1.
The third-smallest number in both lists is 3, so the distance between them is 0.
The next numbers to pair up are 3 and 4, a distance of 1.
The fifth-smallest numbers in each list are 3 and 5, a distance of 2.
Finally, the largest number in the left list is 4, while the largest number in the right list is 9; these are a distance 5 apart.
To find the total distance between the left list and the right list, add up the distances between all of the pairs you found. In the example above, this is 2 + 1 + 0 + 1 + 2 + 5, a total distance of 11!
Your actual left and right lists contain many location IDs. What is the total distance between your lists?
### Solution: 1666427
当将这个挑战交给 Gemini 时,它会生成以下代码:
import sys
def solve():
left_list = []
right_list = []
for line in sys.stdin:
left, right = map(int, line.strip().split())
left_list.append(left)
right_list.append(right)
left_list.sort()
right_list.sort()
total_distance = 0
for i in range(len(left_list)):
total_distance += abs(left_list[i] - right_list[i])
print(total_distance)
solve()
使用以下代码运行
python day01/part1.py < day01/input.txt
得到的结果是:

图片来自作者
如我们所见,Gemini 的代码生成了正确的结果。
奖励:获取解释
作为额外奖励,我们还可以请 Gemini 提供代码的解释,说明它是做什么的,以及为什么能解决这个挑战:

图片由作者提供
你可以在这里找到带有解释的聊天记录。
结论
通过这个项目,我想探索当前最先进的 LLM(大语言模型)在解决编程挑战方面的能力。我有一个假设,认为 Gemini(以及其他 LLM)已经足够优秀,可以解决大多数这些挑战。当然,这并不意味着它们已经准备好解决更复杂的真实软件挑战。
话虽如此,我只是出于好奇决定开始这个有趣的小项目。希望你喜欢它,并能为你提供一些关于我们与 LLM+编程未来发展方向的见解🤗
海科·霍茨
👋 在Medium和LinkedIn上关注我,阅读更多关于生成式人工智能、机器学习和自然语言处理的内容。
👥 如果你在伦敦,可以加入我们的NLP London Meetups。

图片由作者提供
谷歌的 AI 子公司再次发力:AlphaFold 3 现在涵盖了更多的结构生物学领域
DeepMind 和 Isomorphic Labs 最近发布了一篇新论文,应用了新的 AI 概念和方法,创造了一种新的工具,承诺再次带来革命性的变革
LucianoSphere (Luciano Abriata, PhD)
·发表于Towards Data Science ·阅读时间 9 分钟·2024 年 5 月 10 日
--
谷歌的 AI 子公司 DeepMind 和 Isomorphic Labs 再次在科学界掀起波澜。这一次,他们推出了 AlphaFold 3,一种新的 AI 模型,它以前所未有的准确度预测分子结构,且不仅限于像成功的(并且在生物学中确实具有突破性的)AlphaFold 2 那样的蛋白质。事实上,AlphaFold 3 处理蛋白质及其与 DNA、RNA、配体、离子等的复合物,承诺再次彻底改变我们对生物学的理解,并在此过程中开辟了加速药物发现的新途径。
介绍
计算机科学家和生物信息学家也能富裕
改变我们研究生物学、发现药物和推动生物技术的方式
在几次点击中使用 AlphaFold 3 在线
光明的未来,但只有少数人能享有?
DeepMind 和 Isomorphic Labs,谷歌的两家核心 AI 子公司,再次令科学界惊讶。这一次,他们推出了 AlphaFold 3,一个全新的 AI 模型,在成功的(并且在生物学上真正具有革命性意义的)AlphaFold 2 的基础上,预测……
高尔的距离:用于混合分类和数值数据
用于聚类混合数据的距离度量
·发表于Towards Data Science ·阅读时间 8 分钟·2024 年 7 月 12 日
--
什么是距离度量?
你很可能听说过曼哈顿距离或欧几里得距离。这是两种不同的度量方式,用于提供两个给定数据点的距离(或差异)信息。

曼哈顿距离和欧几里得距离的图示。图片由作者提供
简而言之,欧几里得距离是从点 A 到点 B 的最短距离。曼哈顿距离计算 x 和 y 坐标之间的绝对差的总和,并根据它们在网格上的位置来计算距离,假设你只能上下左右移动(不能斜着走)。
距离度量通常是聚类算法的基础,例如使用欧几里得距离的k-means 聚类。这很有道理,因为为了定义聚类,首先需要知道两个数据点有多相似或有多不同(即它们之间的距离有多远)。
计算两点之间的距离
为了展示这个过程,我将以使用欧几里得距离为例。
使用地图和运动学进行 GPS 插值
如何将航位推算应用于你的地理空间数据集?
·发布于数据科学前沿 ·10 分钟阅读·2024 年 12 月 3 日
--

上面的图片展示了 GPS 插值过程。红点表示已知且重复的 GPS 位置,每个点上有多个位置,而蓝点表示使用车辆的速度沿道路推测的重复点位置。(图片由作者使用 OpenStreetMap 数据和影像创建。)
现代的汽车、面包车和卡车是远程信息处理数据的流动生成器。车辆的远程信息处理数据流通常携带各种信号,其中 GPS 位置是最常见的信号之一。你还可以找到诸如瞬时速度、加速度、油箱或电池电量以及其他一些特殊信号,如挡风玻璃雨刷状态和外部温度等。
GPS 接收器通常每秒采样一次数据(1 Hz),这适用于大多数应用,但其他车辆传感器可能具有不同的信号生成频率。信号生成频率是可编程的,通常需要平衡电信成本和信息内容的实用性。一些信号会在变化时发送,而其他信号可能只有在变化达到一定百分比后才发送,以避免不必要的费用。
遥测数据流在通过无线连接发送信号值时采用不同的信号包装方法。最基本的信号包装方法是在每次信号生成或显著变化时独立发送每个信号。每个数据包包含源标识、信号标识和信号值。另一种方法是将所有信号值作为标准记录进行打包,每当每个值发生变化时。没有预设的发射频率,未改变的值会在连续消息中重复。当接收端使用这种信号包装方法时,我们通常会发现这种方式,尤其是在发送方使用前一种方法时。
最后的方法与前一种方法类似,固定了发射频率,通常与 GPS 同步,突出了此信号在过程中的重要性。
本文讨论的第二种方法有一些副作用,即在 GPS 信号变化的过程中,所有中间数据包都会重复 GPS 坐标。下图展示了这种效果在扩展车辆能源数据集(EVED)上的体现。

图 1 — 来自 EVED 的数据展示了车辆移动时 GPS 位置重复的频率。(图像来源:作者)
通常,在处理数据时,如图 1所示,当去除重复行时,会使用纬度和经度作为键。此技巧保留了所有其他列的聚合值,通常是第一行的值。然而,这可能会大幅减少数据集中的行数,从而使数据的价值降低,这类似于第三种包装方法。
在 GPS 信号变化之间(第 1、8 和 14 行),即使车辆在移动,如图 1所示的速度信号所示,所有其他记录仍然携带原始 GPS 坐标。我们可以对这些中间记录进行地理位置插值,从而提高 GPS 传感器的分辨率,改善数据集的质量。
本文演示了如何使用地图信息和速度信号执行上述 GPS 位置插值。
插值
GPS 插值是指使用辅助数据推断输入数据集中缺失的地理位置。如果你喜欢,这类似于航迹推算,即 GPS 接收器在你通过隧道时推断当前位置的过程。在这里,我们将类似的概念应用于一个数据集,其中车辆信号的采样率高于 GPS 接收器。
死 reckoning(惯性导航)使用地图来确定前方的道路,并假设在隧道(或 GPS 盲区)内速度保持不变。在这里,我们将使用类似的方法。知道两个连续且不同的 GPS 样本之间的地图几何信息,可以提供准确的距离信息。如果有速度信号,我们可以通过简单的运动学计算来确定临时信号的近似 GPS 位置。否则,我们可以假设两个连续位置之间的平均速度是恒定的。幸运的是,EVED 报告了瞬时速度。
我们必须解决的第一个问题是如何测量两个连续且不同的 GPS 位置之间的距离。我们通过使用地图信息来完成这一点。我们必须利用地图来避免测量位置之间的地理距离(顺风而行的距离),正如下面的图 2所示。

图 2 — 蓝点是地图顶点,红点是地图匹配的 GPS 位置。插值位置将出现在蓝色线条上,而我们必须使用沿红色线条几何形状的连续样本之间的距离。绿色线条代表不考虑地图几何的连续 GPS 位置之间的距离,而红色线条则使用地图几何。在这种情况下,红线比绿线长。(图片来源:作者)
GPS 插值过程需要辅助技术来实施,例如地图匹配、地图对齐、速度集成和地图投影。让我们逐一查看这些技术。
地图匹配
地图匹配将采样的 GPS 位置序列投影到数字地图上最可能的轨迹。我在另外两篇文章中已经讨论了这一过程,探讨了它在轨迹和速度预测中的应用。请查阅这两篇文章及其支持的代码,它们支持本材料的内容。
在运行地图匹配过程后,我们收集原始 GPS 样本在地图边缘上的投影以及与行驶路线对应的地图顶点序列。上面的图 2展示了这一过程,其中地图顶点以蓝色表示,GPS 投影以红色表示。在继续之前,我们必须计算合并的顶点序列和 GPS 投影,这一过程我称之为“地图对齐”。
地图对齐
如前所述,地图匹配过程会产生两个不相交的点集,即边缘投影的 GPS 位置和地图顶点,这些点按照路线方向排列。在进一步处理之前,我们必须合并这些位置集,以确保它们之间的正确排序。不幸的是,边缘投影的 GPS 位置没有携带边缘信息,因此我们必须通过端点顶点来找到对应的边缘。这个过程会生成一个包含匹配 GPS 位置投影的地图边缘列表。
完成后,我们通过将地图边缘列表转换为互补格式:GPS 段列表,来结束地图对齐过程。我们通过起始和结束位置以及它们之间的任何地图顶点来标识每个 GPS 段。图 3 下方展示了这些概念,蓝色条表示地图边缘,红色条表示 GPS 段。

图 3 — 地图对齐过程正确地对地图顶点和投影的 GPS 位置进行排序,并将结果列表拆分成 GPS 段,显示为红色。请注意,每个 GPS 段包含投影的 GPS 端点和其中的所有地图顶点。(图片来源:作者)
现在,我们可以分别检查和处理每个 GPS 段。为了更好地说明这一概念,图 1 上的第一个 GPS 段将包括第一到第八行,这些行之间包含的任何地图顶点。
图 3 上所示的典型 GPS 段包含一组与每个端点对应的信号记录。图 1 显示,前两个 GPS 位置分别有七个和六个记录。我们的目标是利用收集到的关于车辆运动的信息,将这些记录投影到段的地理位置上。幸运的是,EVED 具有时间戳和记录的车辆速度。我们可以利用一些运动学重建段上的位移,并放置插值后的 GPS 位置。
如果你曾经学习过运动学,你会知道:
为了恢复连续投影的 GPS 位置之间的估计距离,我们需要计算时间与速度的积分。
速度集成
图 1 上显示,对于每条记录,我们都有时间戳值(从行程开始以来的毫秒数)和车辆速度值(以千米每小时为单位)。为了重建所有中间距离,我们为每一步计算一个简单的梯形积分,然后根据地图上计算的实际 GPS 段长度进行调整。
由于速度信号可能会有噪声,因此需要进行最终调整,因为假设这些噪声在整个过程中分布一致。因此,从积分计算出的距离通常会与地图上的距离有所不同。
为了弥补这个差异,我们计算两个距离之间的修正因子,这使我们能够计算投影 GPS 位置之间的调整距离。有了这些最终信息,我们现在可以沿着分段对重复的 GPS 位置进行插值。
地图投影
插值过程的最后一步是将额外和重复的 GPS 位置转移到地图几何形状中。我们通过使用前一个位置并根据先前计算的距离沿分段的方向移动来计算每个位置。图 4下方展示了这一过程。

图 4 — 地图投影过程使用集成的距离和 GPS 分段方向来计算投影位置的放置位置。从左到右,我们使用原始 GPS 位置(红色)或地图顶点(蓝色)的位置信息,以及相应的方向和距离,来计算投影的绿色 GPS 位置。(图片来源:作者)
为了尊重地图几何形状,算法在计算过程中必须考虑连续 GPS 位置之间的地图顶点。在上述图 4中,红色的初始 GPS 位置有四个重复点,我们可以利用信号时间戳和记录的速度将其投影到新的绿色点上。即使在跨越地图顶点时,算法也必须正确地分配距离,如图所示。
在投影插值后的 GPS 位置时,算法使用分段方向、前一个位置和临时距离,通过一个著名公式来计算下一个点的位置。
最终的 GPS 位置集,包括采样和插值后的位置,将被存储以供以后使用。接下来我们来看看如何做到这一点。
实现
在尝试运行本文的代码之前,请先阅读前提文章并运行其代码。本文的代码要求你下载并生成一个包含 EVED 数据的数据库,该数据库已经进行地图匹配,并包含投影的链路时长。请参阅下面的参考资料。
实现本文所述概念的 Python 代码可以在随附的GitHub 仓库中找到。你需要从命令行执行主脚本,以便对所有行程进行插值。
uv run interpolate-gps.py
该脚本遍历所有行程,并逐一处理。第一步是加载地图匹配的行程折线,其中每个点都是地图顶点(前面图中的蓝点)。这些折线是在之前的文章中生成的,并应作为编码字符串存储在数据库中。
折线解码
解码折线需要一个专门的函数,该函数改编自公开的Valhalla 仓库。
图 5 — 上面的代码修改了原始的 Valhalla 折线解码函数。改进之一是可以选择重新排列坐标对为(纬度,经度),而不是默认的(经度,纬度)。原始代码由其作者根据 MIT 许可证授权。(图片来源:作者)
GPS 段生成
接下来,脚本收集并对齐与地图匹配的旅行数据(红点)和地图顶点。此处理结果生成一个 GPS 段列表,结构中包含顺序配对的与地图匹配的 GPS 位置以及其中的任何地图顶点。
图 6 — 一个 GPS 段是一个点的列表,其中第一个和最后一个点保证是与地图匹配的 GPS 位置。中间的任何点将是地图顶点。(图片来源:作者)
我们使用一个函数,它接受一个包含原始轨迹的 Pandas DataFrame,该轨迹包含唯一的位置以及与地图匹配的轨迹折线,用于计算 GPS 段的列表。
图 7 — 上面的函数将与地图匹配的轨迹与独特的位置以及与地图匹配的折线转换为 GPS 段的列表。(图片来源:作者)
代码接着计算沿段几何线的重复位置投影,适用于每个 GPS 段。请注意,这仅适用于与起始 GPS 点对应的重复位置。结束的 GPS 点作为下一个段的起始点被重复。
我们使用一个专门的轨迹类来帮助我们计算 GPS 段。正如上面图 7所示,函数使用一系列独特的 GPS 位置、相应的时间戳和每个点的数据库标识符来初始化轨迹对象。然后,该对象将自己与解码后的折线合并,以返回一个……
死 reckoning 函数使用 GPS 段、计算出的距离和已知的持续时间来投影重复位置。
图 8 — 上面的函数使用可用的距离和持续时间信息生成初始 GPS 位置的 GPS 投影。(图片来源:作者)
上面的函数生成一个包含所有从第一个 GPS 位置投影的点的列表,并为每个点添加行标识符,以便后续插入数据库。通过这种方式,使用这些投影位置的代码可以回溯到原始数据行。
图 9 — 算法将通过地图投影生成的每个位置存储在上面的结构中。除了地理坐标,对象还存储时间偏移和原始行标识符。(图片来源:作者)
我们使用下面的函数来根据源位置、方位角和距离计算位置。方位角是从真北开始顺时针测量的角度,东为 90 度,南为 180 度。
图 10 — 上述函数将一个点沿给定的距离和航向移动,是 GPS 位置插值的基础。(图片来源:作者)
我们现在可以看到主函数循环如何集成所有这些组件。值得注意的是,代码保留了原始地图匹配轨迹的两个副本,一个包含所有数据,另一个仅包含唯一位置(参见下面第 11–14 行)。
图 11 — GPS 插值应用程序的主循环。(图片来源:作者)
代码的最后一步是将插值后的位置信息插入到数据库中的专用表中,该表与原始信号表有 1:1 的关系。
图 12 — 上述函数计算时间与位移函数的导数,并将所有信号存储在数据库中。有关这些计算的解释见下文。(图片来源:作者)
现在,经过精细处理的数据可以用于一个有趣的案例研究,识别出哪些路段受到了最剧烈的制动和加速。
案例研究:剧烈制动
通过增加插值后的 GPS 位置信息的精度,我们可以更好地洞察车辆行为并进行更精确的计算。为了说明如何使用改进的位置信息精度,我们研究了车辆在哪些地方制动最剧烈,通过计算一个有趣的运动特征:jerk(或突变)。我们可以通过较短的时间间隔和相应的速度来可靠地计算这个运动学量。

图 13 — jerk 或 jolt 是瞬时加速度的导数。较大的正值表示剧烈加速,而显著的负值则表示剧烈制动。(图片来源:作者,使用了 Wikipedia 的符号)
可以使用推导出的插值 GPS 位置信息在地图上突出显示最剧烈的制动区域,通过 r(t) 函数的三阶导数计算瞬时 jerk,其中 r 是位移,t 是时间。
下图 图 14 显示了将计算出的最剧烈制动绘制成图,制动值低于 jerk 分布的 𝜇-3𝜎。你可以通过专门的 Jupyter notebook 与此地图进行交互。

图 14 — 上图显示了安娜堡地图的一部分,每个红点表示数据库中记录的一个异常剧烈的制动点。(图像由作者使用 OpenStreetMap 数据和影像创建。)
结论
在本文中,我们探讨了使用地图和中间信号(如时间和速度)插值 GPS 位置的问题。在进行地图匹配之后,我们检索了对应的生成树,该树包含了推断出的轨迹路径,并使用这些几何形状通过已知速度和中间时长,将缺失的 GPS 位置投影到“正确”的位置。通过插值所有 EVED 数据,我们利用这些数据检测安娜港(Ann Harbor)中预期发生最剧烈刹车的区域,作为可能存在危险点的代理。
致谢
我使用了 Grammarly 来审查写作,并接受了其中一些重写建议。
JetBrains' AI 助手编写了部分代码,并且它已经成为我日常使用 Python 编程工作的基础工具。
参考文献
你要去哪里?你应该走那条路吗?
towardsdatascience.com ## 速度预测中的地图匹配
你会开多快?
towardsdatascience.com
Movable Type Scripts — 计算纬度/经度点之间的距离、方位角等
João Paulo Figueira 是 tb.lx by Daimler Truck 的数据科学家,位于葡萄牙里斯本。
GPT-4V 具有方向性阅读障碍
展示了我们基于 WSDM 2023 Toloka VQA 挑战的研究
·发表于 Towards Data Science ·阅读时长 10 分钟·2024 年 2 月 20 日
--

图像由 GPT-4V 生成
自从 WSDM 2023 年 Toloka 图像问答(VQA)挑战赛以来已经过去一年,正如我们当时预测的那样,获胜的机器学习解决方案并未超越人类基准。然而,过去一年充满了生成式 AI 的突破。感觉每篇文章都在交替指出 OpenAI 的 GPT 模型不能做的事情,并赞扬它们做得比我们更好的地方。
自 2023 年秋季以来,GPT-4 Turbo 已获得“视觉”能力,这意味着它可以接受图像作为输入,并且现在可以直接参与 VQA 挑战。我们很想测试它在我们的 Toloka 挑战中的表现,看看它是否终于缩小了与人类基准的差距。
图像问答
图像问答(VQA)是一个跨学科的人工智能研究问题,专注于让 AI 理解图像并用自然语言回答相关问题。这个领域有各种应用:帮助视障人士、丰富教育内容、支持图像搜索功能以及提供视频搜索功能。
VQA 的发展“伴随着巨大的责任”,例如确保技术应用的可靠性和安全性。随着 AI 系统具备视觉能力,虚假信息的传播潜力增加,考虑到有论点称 “带有虚假信息的图像可能使陈述看起来更具可信度”。
VQA 领域的一个子领域,VQA 定位不仅仅是回答视觉问题,还涉及将这些答案与图像中的元素连接起来。这个子领域在混合现实(XR)头戴设备、教育工具和在线购物等应用中具有巨大的潜力,通过将注意力引导到图像的特定部分,从而提升用户交互体验。Toloka VQA 挑战赛的目标是支持 VQA 定位的发展。
Toloka 的 VQA 挑战赛回顾
在 Toloka VQA 挑战赛中,任务是根据描述物体功能而非其视觉特征的问题,识别出一个物体并将其放入一个边界框中。例如,问题可能不会问“找一个圆形的红色物体”,而是问“图中哪个物体适合放在沙拉和披萨上?”这反映了人类根据物体的实用性来感知物体的能力。就像当你看到一张桌子上有一份报纸、一只咖啡杯和一副眼镜时,问你“找个东西来拍苍蝇”,你会知道该选什么,而无需任何视觉描述。
问题:我们用什么来将披萨切成块?

图片来自“Toloka VQA 挑战赛”(CC BY 4.0)
这个挑战要求同时整合视觉、文本和常识知识。作为基准方法,我们提出将 YOLOR 和 CLIP 作为独立的视觉和文本骨干模型结合使用。然而,获胜的解决方案根本没有使用双塔结构,而是选择了 Uni-Perceiver 模型并配备了 ViT-Adapter 以便更好地进行定位。该方案最终获得了较高的交并比(IoU)得分76.347,但未能达到众包基准的 IoU 得分87。
鉴于人类和 AI 解决方案之间的巨大差距,我们非常好奇 GPT-4V 在 Toloka VQA 挑战赛中的表现。由于该挑战赛基于MS COCO 数据集,该数据集在计算机视觉中已经被多次使用(例如在视觉空间推理数据集中),因此可能是 GPT-4 在训练数据中“已知”的,有可能 GPT-4V 会接近人类基准。
GPT-4V 和 Toloka VQA 挑战赛
最初,我们想了解 GPT-4V 是否能够直接应对 Toloka VQA 挑战赛。

然而,尽管 GPT-4V 大多数情况下正确地定义了物体,它在提供有意义的边界框坐标时遇到了严重的困难。这并不完全出乎意料,因为OpenAI 的指南承认 GPT-4V 在需要识别物体在图像中精确空间定位的任务中存在局限性。

图片由作者提供
这促使我们探索 GPT-4 在识别图像中基本高层次位置方面的表现。它能否判断物体的位置——不一定准确,但能否知道它们在左侧、中间或右侧?或者是在顶部、中间或底部?由于这些位置不是精确的,对于 GPT-4V 来说,这可能是可行的,特别是考虑到它已在数百万张图像与对应说明文字的配对中进行训练,其中指出了物体的方向性位置。教育材料通常会详细描述图片(例如脑结构的教科书中会提到像“树突”位于图片的“左上方”或“轴突”位于“右下方”的部分)。
理解大语言模型(LLM)和多模态语言模型(MLM)在空间推理方面的局限性,即使是我们上述讨论的简单推理,在实际应用中也至关重要。将 GPT-4V 集成到“Be My Eyes”应用程序中,该应用通过解读图像来帮助视障用户,完美地展示了这一重要性。尽管 GPT-4V 的能力不容小觑,但该应用程序提醒用户要小心,强调了当前技术在关键的安全和健康环境中尚无法完全代替人类判断。然而,具体的技术不足之处并未明确指出。
GPT-4V 和空间推理
在我们探索 GPT-4V 在图像中物体基本位置推理的过程中,我们从一个包含 4,500 对图像-问题的较大数据集(该数据集是比赛的私人测试数据集)中随机选择了 500 对图像-问题。我们尽量减少测试数据泄露到 GPT-4V 训练数据中的可能性,因为该子集的比赛数据是在比赛时间表中最后发布的。

在这 500 对数据中,有 25 对被 GPT-4V 拒绝,标记为“无效图像”。我们怀疑这种拒绝是由于内置的安全措施引起的,可能是由于图像中出现了可能被分类为个人身份信息(PI)的物体,例如人脸。剩下的 475 对数据被用于我们的实验。
理解事物在相互关系中的位置,例如弄清楚什么是左、中或右、上、中或下,并不像看起来那么简单。很多事情取决于观察者的视角,物体是否有前面,如果有的话,它的方向如何。因此,人类的空间推理可能依赖于我们进化历史中对世界的显著归纳偏差。
问题: 什么能保护眼睛免受灯光的眩光?

图片来自“Toloka VQA Challenge”(CC BY 4.0)
以实验数据中采样的一个带灯罩的例子为例。一个人可能会说它偏向图像的左上方,因为灯罩稍微向左倾斜,而另一个人可能会称它为中上,看到它位于图片的中央。两种观点都有道理。要为定位建立严格的规则很难,因为物体可能有各种形状和部件,比如灯具的长电线,这可能会改变我们对其放置位置的看法。
考虑到这种复杂性,我们计划尝试至少两种不同的方法来标注图像中事物的位置的真实情况。
它的工作原理如下:如果图像中心与物体中心(由其边界框标记)的像素差小于或等于图像宽度的一定百分比(水平位置)或高度(垂直位置),那么我们就标记该物体位于中间。如果差异更大,则标记为左或右(或上或下)。我们决定使用 2%作为阈值百分比。这个决定是基于观察物体相对于图像整体大小的像素差异如何表现的。
object_horizontal_center = bb_left + (bb_right - bb_left) / 2
image_horizontal_center = image_width / 2
difference = object_horizontal_center - image_horizontal_center
if difference > (image_width * 0.02):
return 'right'
else if difference < (-1 * image_width * 0.02):
return 'left'
else:
return 'middle'For our first approach, we decided on simple automated heuristics to figure out where objects are placed in a picture, both horizontally and vertically. This idea came from an assumption that GPT-4V might use algorithms found in publicly available code for tasks of a similar nature.
对于第二种方法,我们使用了众包标注。以下是众包项目设置的详细信息:
图片在没有边界框的情况下展示给众包人员,以鼓励他们对物体位置的标注不带有偏见(关于真实答案),就像回应有关物体在视觉上下文中放置的问题一样。
GPT-4V 的回答既作为提示,也作为验证其物体检测准确性的一种方式。
参与者可以选择报告如果给定的图像无法清晰回答某个问题,从而将任何潜在的模糊/灰色区域情况从数据集中剔除。
为了确保众包回应的质量,我审查了所有 GPT-4 的回答与众包结果不符的情况。在此审查过程中,我无法看到 GPT-4V 或众包的回应,这使我能够在没有偏袒的情况下调整标签。

图片来自作者。标注界面在 Toloka
GPT-4V 有方向性读写困难
我们选择了准确度作为评估指标,因为我们数据集中的类别分布均匀。在通过众包和启发式方法建立的真实标签上,我们评估了 GPT-4V 的表现,涵盖了 475 张图片,排除了 45 对众包认为难以回答的图片。剩余的数据表明,与众包和启发式标签相比,GPT-4V 在识别水平和垂直位置时的准确度相当低,仅为30%。

与 自动启发式方法 相比,GPT-4V 回答的准确度

与 众包标签 相比,GPT-4V 回答的准确度
即使我们接受 GPT-4V 的回答为正确,只要它与众包或启发式方法相匹配,其准确度仍未达到 50%,最终结果为40.2%。
为了进一步验证这些发现,我们手动复审了 100 个 GPT-4V 错误标注的图像-问题对。

通过直接要求 GPT-4V 指定物体的位置并比较其回应,我们确认了最初的结果。

图片由作者提供。标签界面在 Toloka
GPT-4V 一贯将左右、上下搞混,因此,如果 GPT-4V 是你的导航员,请准备好走风景路线——这完全是无意的。
然而,GPT-4V 的物体识别能力令人印象深刻,达到了**88.84%**的准确率。这表明,通过将 GPT-4V 与专业的物体检测工具集成,我们可能能够达到(甚至超过)人类基准。这是我们研究的下一个目标。
提示工程与方向性阅读障碍
为了确保我们不是在指出 GPT-4V 的局限性而不进行任何提示优化努力,以避免变成我们所讨厌的那种情况,我们探索了研究文献中提到的多种提示工程技术,这些技术有助于增强大语言模型的空间推理能力。
问题:什么被用作一个国家的象征或标志?

来自 “Toloka VQA Challenge”(CC BY 4.0)
我们在上述实验数据集示例中应用了三种已发现的提示工程技术,而 GPT-4V 在这些示例中固执且一贯地误解了信息。图中询问的旗帜位于图片的中右位置。
“《Shikra: 释放多模态 LLM 的参考对话魔力》”论文介绍了一种结合了思维链(CoT)和位置标注,特别是中心标注的方法,称为基础 CoT(GCoT)。在 GCoT 设定中,作者提示模型为每个提到的物体提供思维链并附上中心点。由于作者特别训练了模型来提供图像中物体的坐标,我们必须将提示工程技术适应到一个不那么严格的设定,要求模型根据物体的中心提供关于物体位置的推理。


作者提供的图片。 基础 CoT 方法 (正确答案是 中右*)
Patel 和 Pavlick(2022)的研究 “将语言模型映射到基础概念空间” 说明了 GPT-3 即使在基于文本的网格中也能掌握空间和方位方向,通过“定位”模型并使用训练中学习到的特定词形。他们用北/南和西/东代替了传统的上/下和左/右方位词,以指导模型的空间推理。


作者提供的图片。 方位方向方法 (正确答案是 东南*)
最后,"视觉空间推理"文章阐述了空间描述中不同视角的重要性:内在框架是围绕物体中心的框架(例如,椅子后面 = 带靠背的那一面),相对框架是从观察者的视角出发的框架,而绝对框架则使用固定坐标(例如,椅子“北”侧)。英语通常偏向使用相对框架,因此我们在提示中明确提到它,希望能优化 GPT-4V 的空间推理能力。


作者提供的图片。 相对框架方法 (正确答案是 中右*)
如我们从示例中看到的,GPT-4V 在基础空间推理方面仍然存在挑战。
结论与未来工作
GPT-4V 在简单的空间推理上遇到困难,例如在图像中识别物体的水平和垂直位置。然而,基于隐性功能描述的强大物体识别能力仍然是非常有前景的。我们的下一步是将 GPT-4V 与专门训练用于图像物体检测的模型结合起来。让我们看看这种结合能否超越人类基准,在 Toloka VQA 挑战中取得胜利!
从零开始使用 MLX 构建 GPT
在你的 MacBook 上定义并训练 GPT-2
·发布于 Towards Data Science ·31 分钟阅读·2024 年 6 月 15 日
--

图片由 Sergey Zolkin 提供,来源于 Unsplash
本文的目标是带领你一步步从头开始定义并训练 GPT-2,使用 MLX,苹果为 Apple Silicon 提供的机器学习库。我希望在从分词器到采样的过程中,毫无遗漏地讲解每一步。沿袭 Karpathy 的精彩 GPT 从零开始教程的精神,我们将训练一个基于莎士比亚作品的模型 [1]。我们将从一个空白的 Python 文件开始,最终构建出一个能够写出莎士比亚风格文本的软件。所有这些都将在 MLX 中构建,这使得在 Apple Silicon 上的训练和推理速度更快。
本文最好通过跟随操作来体验。代码包含在以下仓库中,我建议打开并参考它。
[## GitHub - pranavjad/mlx-gpt2
通过在 GitHub 上创建帐户,贡献于 pranavjad/mlx-gpt2 的开发。
github.com](https://github.com/pranavjad/mlx-gpt2?source=post_page-----acf2defda30e--------------------------------)
目录
准备数据
编码 GPT-2
输入嵌入
位置嵌入
自注意力
键、查询和数值
多头注意力
多层感知机(MLP)
模块
层归一化和跳跃连接
前向传播
采样
初始化
训练循环
参考文献
准备数据
安装 mlx 并运行以下导入语句。
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlx.utils as utils
import numpy as np
import math
训练 LLM 的第一步是收集大量的文本数据,然后对其进行分词。分词是将文本映射到整数的过程,这些整数可以输入到 LLM 中。我们这个模型的训练语料库将是莎士比亚的作品,这些作品会被拼接成一个文件。大约有 100 万个字符,格式如下:
First Citizen:
Before we proceed any further, hear me speak.
All:
Speak, speak.
First Citizen:
You are all resolved rather to die than to famish?
All:
Resolved. resolved.
First Citizen:
First, you know Caius Marcius is chief enemy to the people.
...
首先,我们将文件作为一个长字符串读取到text变量中。然后我们使用set()函数获取文本中的所有唯一字符,这些字符将构成我们的词汇表。通过打印vocab,你可以看到我们词汇表中的所有字符作为一个字符串,我们一共有 65 个字符,这些将作为我们的 token。
# Creating the vocabulary
with open('input.txt', 'r', encoding='utf-8') as f:
text = f.read()
vocab = sorted(list(set(text)))
vocab_size = len(vocab)
print(''.join(vocab))
# !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
print(vocab_size)
# 65
生产模型将使用像字节对编码(byte-pair encoding)这样的分词算法来生成更大的子词词汇表。由于我们今天的重点是架构,因此我们将继续使用字符级别的分词。接下来,我们将把我们的词汇映射到称为 token ID 的整数。然后我们可以将文本编码为 token 并解码回字符串。
# Create mapping from vocab to integers
itos = {i:c for i,c in enumerate(vocab)} # int to string
stoi = {c:i for i,c in enumerate(vocab)} # string to int
encode = lambda x: [stoi[c] for c in x]
decode = lambda x: ''.join([itos[i] for i in x])
print(encode("hello world"))
# [46, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42]
print(decode(encode("hello world")))
# hello world
我们使用enumerate()函数遍历词汇表中的所有字符及其索引,并创建一个字典itos,它将整数映射到字符,另一个字典stoi,它将字符串映射到整数。然后我们使用这些映射来创建我们的编码和解码函数。现在我们可以编码整个文本,并拆分训练数据和验证数据。
data = encode(text)
split = int(0.9 * len(data))
train_data = data[:split]
val_data = data[split:]
当前,我们的训练数据只是一个非常长的 token 字符串。然而,我们正在训练模型以预测给定一些前置 token 后的下一个 token。因此,我们的数据集应该包含输入是某个 token 字符串、标签是正确下一个 token 的示例。我们需要定义一个模型参数,称为上下文长度,即用于预测下一个 token 的最大 token 数量。我们的训练示例将是我们上下文长度的长度。
让我们来看一下前ctx_len+1个 token。
ctx_len = 8
print(train_data[:ctx_len + 1])
# [18, 47, 56, 57, 58, 1, 15, 47, 58]
# x: [18, 47, 56, 57, 58, 1, 15, 47] | y: 58
这是一个训练示例,其中输入是“18, 47, 56, 57, 58, 1, 15, 47”,期望输出是“58”。这是 8 个 token 的上下文。然而,我们还希望训练模型在只给定 7、6、5 … 0 个 token 作为上下文时预测下一个 token,这在生成过程中是必要的。因此,我们还考虑了这个示例中包含的 8 个子示例:
ctx_len = 8
print(train_data[:ctx_len + 1])
# [18, 47, 56, 57, 58, 1, 15, 47, 58]
# 8 sub examples
# [18] --> 47
# [18, 47] --> 56
# [18, 47, 56] --> 57
# [18, 47, 56, 57] --> 58
# [18, 47, 56, 57, 58] --> 1
# [18, 47, 56, 57, 58, 1] --> 15
# [18, 47, 56, 57, 58, 1, 15] --> 47
# [18, 47, 56, 57, 58, 1, 15, 47] --> 58
请注意,标签只是输入左移后的结果。
print("inputs: ", train_data[:ctx_len])
print("labels: ", train_data[1:ctx_len+1]) # labels = inputs indexed 1 higher
# inputs: [18, 47, 56, 57, 58, 1, 15, 47]
# labels: [47, 56, 57, 58, 1, 15, 47, 58]
在索引 0 处,输入是 18,标签是 47。 在索引 1 处,输入是包括索引 1 之前的所有内容,即[18, 47],标签是 56,依此类推。现在我们明白标签仅仅是将输入序列的索引增加 1 后得到的结果,我们可以构建我们的数据集。
# Creating training and validation datasets
ctx_len = 8
X_train = mx.array([train_data[i:i+ctx_len] for i in range(0, len(train_data) - ctx_len, ctx_len)])
y_train = mx.array([train_data[i+1:i+ctx_len+1] for i in range(0, len(train_data) - ctx_len, ctx_len)])
X_val = mx.array([val_data[i:i+ctx_len] for i in range(0, len(val_data) - ctx_len, ctx_len)])
y_val = mx.array([val_data[i+1:i+ctx_len+1] for i in range(0, len(val_data) - ctx_len, ctx_len)])
我们遍历数据,将大小为ctx_len的块作为输入(X),然后将相同大小的块,但索引加 1 后的部分作为标签(y)。然后我们将这些 Python 列表转化为 mlx 数组对象。模型的内部会使用 mlx 编写,因此我们希望我们的输入是 mlx 数组。
还有一件事。在训练过程中,我们不希望一次只输入一个样本,而是希望一次输入多个样本,以提高效率。这组样本被称为我们的批次,而每个批次中的样本数量就是我们的批量大小。因此,我们定义一个函数来生成训练用的批次。
def get_batches(X, y, b_size, shuffle=True):
if shuffle:
ix = np.arange(X.shape[0])
np.random.shuffle(ix)
ix = mx.array(ix)
X = X[ix]
y = y[ix]
for i in range(0, X.shape[0], b_size):
input = X[i:i+b_size]
label = y[i:i+b_size]
yield input, label
如果shuffle=True,我们通过随机打乱的索引对数据进行打乱。然后我们遍历数据集,并从输入数据和标签数据集中返回批量大小的块。这些块被称为小批次,它们只是我们并行处理的堆叠样本。这些小批次将在训练过程中作为输入提供给模型。
这是一个上下文长度为 8 的小批次示例,包含 4 个样本。

一个小批次(图由作者提供)
这个小批次包含了 32 个下一个标记预测问题。模型将为输入中的每个标记预测下一个标记,标签将用于计算损失。注意,标签包含了输入中每个索引对应的下一个标记。
你会希望将这个图像牢记在心,因为这些张量的形状会变得非常复杂。现在,只需要记住,我们将向模型输入一个形状为(batch_size, ctx_len)的张量。
编写 GPT-2 代码
让我们看看 GPT-2 架构,以便了解我们要实现的整体结构。

GPT-2 架构(图由作者提供)
如果这看起来有点困惑,不用担心。我们将从下到上一步步实现它。让我们从实现输入嵌入开始。
输入嵌入
输入嵌入层的目的是将标记 ID 映射到向量。每个标记将映射到一个向量,该向量将在模型中被传递并作为其表示。每个标记的向量将在模型中相互积累并交换信息,最终被用来预测下一个标记。这些向量被称为嵌入。
将标记 ID 映射到向量的最简单方法是通过查找表。我们创建一个大小为(vocab_size, n_emb)的矩阵,其中每一行是对应标记的嵌入向量。这个矩阵被称为嵌入权重。

嵌入层(图由作者提供)
图示展示了一个大小为(65, 6)的嵌入层示例。这意味着词汇表中有 65 个标记,每个标记将通过长度为 6 的嵌入向量表示。输入的序列将用于通过嵌入权重索引,以获取与每个标记对应的向量。记住我们输入到模型中的小批次吗?最初,小批次的大小是(batch_size, ctx_len)。经过嵌入层处理后,大小变为(batch_size, ctx_len, n_emb)。每个标记不再是一个单一的整数,而是一个长度为n_emb的向量。
现在让我们在代码中定义嵌入层。
n_emb = 6 # You can add these hyperparams at the top of your file
class GPT(nn.Module):
def __init__(self):
super().__init__()
self.wte = nn.Embedding(vocab_size, n_emb)
我们将定义一个类来组织我们的实现。我们继承 nn.Module,以便利用 mlx 的特性。然后在 init 函数中,我们调用父类构造函数并初始化我们的 token 嵌入层,命名为wte。
位置嵌入
接下来是位置嵌入。位置嵌入的目的是编码序列中每个 token 的位置相关信息。这些信息可以与我们的输入嵌入相加,从而获得一个完整的 token 表示,该表示包含了 token 在序列中的位置信息。
class GPT(nn.Module):
def __init__(self):
super().__init__()
self.wte = nn.Embedding(vocab_size, n_emb) # token embeddings
self.wpe = nn.Embedding(ctx_len, n_emb) # position embeddings
位置嵌入的工作方式与 token 嵌入相同,不同之处在于我们不是为每个 token 提供一行,而是为每个可能的位置索引提供一行。这意味着我们的嵌入权重的形状将是(ctx_len, n_emb)。现在我们在我们的 GPT 类中实现 call 函数。这个函数将包含模型的前向传播过程。
# Tensor shapes commented
def __call__(self, x):
B, T = x.shape # (B = batch_size, T = ctx_len)
tok_emb = self.wte(x) # (B, T, n_emb)
pos_emb = self.wpe(mx.arange(T)) # (T, n_emb)
x = tok_emb + pos_emb # (B, T, n_emb)
首先,我们将输入的维度拆解为变量 B 和 T,方便处理。在序列建模的上下文中,B 和 T 通常是“batch”(批次)和“time”(时间)维度的简写。在这种情况下,序列的“时间”维度即为上下文长度。
接下来,我们计算 token 和位置的嵌入。注意,对于位置嵌入,我们的输入是mx.arange(T)。这个操作将输出一个从 0 到 T-1 的连续整数数组,正好符合我们的需求,因为这些正是我们要嵌入的位置。将其通过嵌入层后,我们将得到一个形状为(T, n_emb)的张量,因为嵌入层会为每个位置提取一个长度为 n_emb 的向量。注意,尽管 pos_emb 与 tok_emb 的形状不同,但我们仍然可以将它们相加,因为 mlx 会进行广播操作,即在批次维度上复制 pos_emb,以便实现逐元素相加。最后,我们进行相加操作,得到包含位置信息的 token 新表示。
自注意力
到目前为止,每个 token 的表示向量是独立计算的。它们没有交换任何信息的机会。从语言建模的角度来看,这显然是有问题的,因为单词的含义和用法依赖于周围的上下文。自注意力机制是我们将前面 token 的信息融入到当前 token 的一种方式。
首先,让我们考虑一种简单的方法。如果我们仅仅将每个 token 表示为其表示向量和所有前面 token 的表示向量的平均值呢?这样就能达到将前面 token 的信息打包到当前 token 表示中的目标。下面是它的样子。

作者提供的图片
但是自注意力并不涉及编写 for 循环。关键的见解是,我们可以通过矩阵乘法来实现之前 token 的平均化!

作者提供的图片
通过将我们的输入序列左乘一个特殊矩阵,我们得到了期望的结果。这个矩阵被称为注意力权重。请注意,注意力权重矩阵的每一行都指定了“多少”每个其他 token 进入给定 token 的表示。例如,在第二行中,我们有 [0.5, 0.5, 0, 0]。这意味着结果的第二行将是 0.5*token1 + 0.5*token2 + 0*token3 + 0*token4,即 token1 和 token2 的平均值。请注意,注意力权重是一个下三角矩阵(右上角为零)。这确保了未来的 token 不会被包含在某个给定 token 的表示中。这保证了 token 只能与之前的 token 进行通信,因为在生成过程中,模型只能访问到之前的 tokens。
让我们来看一下如何构建注意力权重矩阵。

图片由作者提供
请注意,如果我们创建一个包含 -inf 的零数组,并将其放置在右上角位置,然后对每一行进行 softmax 操作,我们就能得到期望的注意力权重。一个好的练习是逐步进行某一行的 softmax 计算,看看它是如何工作的。结论是,我们可以取一个大小为 (ctx_len, ctx_len) 的数组,并对每一行进行 softmax 操作,从而得到加和为一的注意力权重。
现在我们可以离开朴素自注意力的领域了。我们不仅仅是对之前的 tokens 取平均值,而是对之前的 tokens 进行任意加权求和。注意,当我们对任意矩阵进行逐行 softmax 时会发生什么。

图片由作者提供
我们仍然得到每行加和为一的权重。在训练过程中,我们可以学习左侧矩阵中的数值,这些数值将指定每个 token 如何影响其他 token 的表示。这就是 token 如何彼此“关注”的方式。但我们仍然没有理解这个左侧矩阵是如何产生的。这些预 softmax 注意力权重是从 token 本身计算出来的,但通过三次线性投影间接获得的。
键、查询和值

图片由作者提供
我们序列中的每个令牌都会生成三个新的向量。这些向量被称为键(keys)、查询(queries)和值(values)。我们使用一个令牌的查询向量与另一个令牌的键向量的点积来量化这两个令牌之间的“亲和力”。我们希望计算每个令牌与其他每个令牌的成对亲和力,因此我们将查询向量(4x3)与转置后的键向量(3x4)相乘,以得到原始的注意力权重(4x4)。由于矩阵乘法的方式,原始注意力权重中(i,j)位置的值将是令牌 i 的查询向量与令牌 j 的键向量的点积,或者说是它们之间的“亲和力”。这样,我们就计算了每个令牌之间的相互作用。然而,我们不希望过去的令牌与未来的令牌发生交互,因此我们对右上部分的条目应用-∞的掩码,以确保它们在软最大化(softmax)之后会被归零。然后,我们执行按行的软最大化操作,以得到最终的注意力权重。我们并不直接将这些权重与输入相乘,而是将它们与值的投影相乘,这样就得到了新的表示。
现在我们在概念上已经理解了注意力机制,让我们来实现它。
class Attention(nn.Module):
def __init__(self, head_size):
super().__init__()
self.head_size = head_size
self.k_proj = nn.Linear(n_emb, head_size, bias=False)
self.q_proj = nn.Linear(n_emb, head_size, bias=False)
self.v_proj = nn.Linear(n_emb, head_size, bias=False)
我们首先定义键、查询和值的投影层。请注意,我们并不是从 n_emb 映射到 n_emb,而是从 n_emb 映射到 head_size。这并不改变任何东西,只是意味着通过注意力计算得到的新表示将具有 head_size 维度。
class Attention(nn.Module):
def __init__(self, head_size):
super().__init__()
self.head_size = head_size
self.k_proj = nn.Linear(n_emb, head_size, bias=False)
self.q_proj = nn.Linear(n_emb, head_size, bias=False)
self.v_proj = nn.Linear(n_emb, head_size, bias=False)
def __call__(self, x): # shapes commented
B, T, C = x.shape # (batch_size, ctx_len, n_emb)
K = self.k_proj(x) # (B, T, head_size)
Q = self.q_proj(x) # (B, T, head_size)
V = self.v_proj(x) # (B, T, head_size)
前向传播开始时计算键、查询和值的投影。为了便于以后使用,我们还将输入的形状拆解为变量 B、T 和 C。
class Attention(nn.Module):
def __init__(self, head_size):
super().__init__()
self.head_size = head_size
self.k_proj = nn.Linear(n_emb, head_size, bias=False)
self.q_proj = nn.Linear(n_emb, head_size, bias=False)
self.v_proj = nn.Linear(n_emb, head_size, bias=False)
def __call__(self, x):
B, T, C = x.shape # (batch_size, ctx_len, n_emb)
K = self.k_proj(x) # (B, T, head_size)
Q = self.q_proj(x) # (B, T, head_size)
V = self.v_proj(x) # (B, T, head_size)
attn_weights = (Q @ K.transpose([0, 2, 1])) / math.sqrt(self.head_size)
# attn_weights.shape = (B, T, T)
接下来,我们计算注意力权重。我们只需要转置键张量的最后两个维度,因为批量维度只是为了让我们可以并行处理多个训练示例。mlx 转置函数期望输入的新维度顺序,因此我们传递[0, 2, 1]来转置最后两个维度。还有一件事:我们通过 head_size 的平方根的倒数来缩放注意力权重。这被称为缩放注意力,其目的是确保当 Q 和 K 的方差为单位方差时,attn_weights 的方差也是单位方差。如果 attn_weights 的方差很大,那么 softmax 将把这些小值和大值映射为 0 或 1,从而导致表示变得不那么复杂。
下一步是应用掩码,以确保我们进行的是因果语言建模,即确保令牌无法关注未来的令牌。
class Attention(nn.Module):
def __init__(self, head_size):
super().__init__()
self.head_size = head_size
self.k_proj = nn.Linear(n_emb, head_size, bias=False)
self.q_proj = nn.Linear(n_emb, head_size, bias=False)
self.v_proj = nn.Linear(n_emb, head_size, bias=False)
indices = mx.arange(ctx_len)
mask = indices[:, None] < indices[None] # broadcasting trick
self._causal_mask = mask * -1e9
def __call__(self, x):
B, T, C = x.shape # (batch_size, ctx_len, n_emb)
K = self.k_proj(x) # (B, T, head_size)
Q = self.q_proj(x) # (B, T, head_size)
V = self.v_proj(x) # (B, T, head_size)
attn_weights = (Q @ K.transpose([0, 2, 1])) / math.sqrt(self.head_size)
# attn_weights.shape = (B, T, T)
我们通过巧妙的广播技巧创建掩码。假设我们的 ctx_len=4,如上面的图示所示。首先,我们使用 mx.arange(4)将索引变量设置为[0, 1, 2, 3]。

图片来源:作者
然后我们可以像这样索引 indices[:, None] 来生成一个包含 indices 值的列向量。类似地,我们可以使用 indices[None] 来获取一个行向量。然后,当我们进行 < 比较时,mlx 会广播这两个向量,因为它们的形状不匹配,无法进行元素级比较。广播意味着 mlx 会沿缺失的维度复制这些向量,从而实现两个 (4, 4) 矩阵的逐元素比较,这就合理了。顺便说一下,我建议通过阅读 这个 来熟悉广播的细节,它在处理张量时经常出现。
元素级比较后,我们得到如下张量:
[[False, True, True, True],
[False, False, True, True],
[False, False, False, True],
[False, False, False, False]]
将这个张量乘以 -1e9,我们得到:
[[-0e+00, -1e+09, -1e+09, -1e+09],
[-0e+00, -0e+00, -1e+09, -1e+09],
[-0e+00, -0e+00, -0e+00, -1e+09],
[-0e+00, -0e+00, -0e+00, -0e+00]]
现在我们有了一个加性掩码。我们可以将这个矩阵加到我们的注意力权重上,使得所有右上角的条目变得非常大的负数。这将导致它们在 softmax 操作后被置为零。此外,请注意,我们将“_”作为前缀添加到属性名 _causal_mask,这将其标记为私有变量。这向 mlx 发出信号,表示它不是一个参数,训练过程中不应更新。
class Attention(nn.Module):
def __init__(self, head_size):
super().__init__()
self.head_size = head_size
self.k_proj = nn.Linear(n_emb, head_size, bias=False)
self.q_proj = nn.Linear(n_emb, head_size, bias=False)
self.v_proj = nn.Linear(n_emb, head_size, bias=False)
indices = mx.arange(ctx_len)
mask = indices[:, None] < indices[None] # broadcasting trick
self._causal_mask = mask * -1e9
def __call__(self, x):
B, T, C = x.shape # (batch_size, ctx_len, n_emb)
K = self.k_proj(x) # (B, T, head_size)
Q = self.q_proj(x) # (B, T, head_size)
V = self.v_proj(x) # (B, T, head_size)
attn_weights = (Q @ K.transpose([0, 2, 1])) / math.sqrt(self.head_size)
# attn_weights.shape = (B, T, T)
attn_weights = attn_weights + self._causal_mask
attn_weights = mx.softmax(attn_weights, axis=-1)
o = (attn_weights @ V) # (B, T, head_size)
现在我们可以按行应用 softmax,得到最终的注意力权重,并将这些权重与值相乘,得到我们的输出。注意,我们将 axis=-1 传递给 softmax,表示我们要在最后一个维度(即行)上进行 softmax 操作。
最后一步是输出的线性投影和丢弃层。
dropout = 0.1 # add this with hyperparams at top of file
class Attention(nn.Module):
def __init__(self, head_size):
super().__init__()
self.head_size = head_size
self.k_proj = nn.Linear(n_emb, head_size, bias=False)
self.q_proj = nn.Linear(n_emb, head_size, bias=False)
self.v_proj = nn.Linear(n_emb, head_size, bias=False)
indices = mx.arange(ctx_len)
mask = indices[:, None] < indices[None] # broadcasting trick
self._causal_mask = mask * -1e9
self.c_proj = nn.Linear(head_size, n_emb) # output projection
self.resid_dropout = nn.Dropout(dropout)
def __call__(self, x):
B, T, C = x.shape # (batch_size, ctx_len, n_emb)
K = self.k_proj(x) # (B, T, head_size)
Q = self.q_proj(x) # (B, T, head_size)
V = self.v_proj(x) # (B, T, head_size)
attn_weights = (Q @ K.transpose([0, 2, 1])) / math.sqrt(self.head_size)
# attn_weights.shape = (B, T, T)
attn_weights = attn_weights + self._causal_mask
attn_weights = mx.softmax(attn_weights, axis=-1)
o = (attn_weights @ V) # (B, T, head_size)
o = self.c_proj(self.resid_dropout(o))
return o
我们添加了两个新层,c_proj 和 resid_dropout,分别是输出投影和残差丢弃层。输出投影用于将向量恢复到原始维度 n_emb。丢弃层用于正则化和训练稳定性,尤其在我们开始堆叠 Transformer 块以构建深度网络时,这非常重要。这就是实现一个注意力头的全部内容!
多头注意力
LLMs 通常使用多个并行的注意力头,而不仅仅是一个注意力头,并将它们的输出拼接起来生成最终表示。例如,假设我们有一个注意力头,其 head_size=64,那么它为每个标记生成的向量是 64 维的。我们也可以通过使用 4 个并行的注意力头,每个头的 head_size=16,并将它们的输出拼接起来,得到一个 16x4=64 维的输出,来实现相同的效果。多头注意力使得模型能够学习更复杂的表示,因为每个头学习不同的投影和注意力权重。
n_heads = 4
class MultiHeadAttention(nn.Module): # naive implementation
def __init__(self):
super().__init__()
self.heads = [Attention(head_size // n_heads) for _ in range(n_heads)]
def __call__(self, x):
return mx.concatenate([head(x) for head in self.heads], axis=-1)
直接的实现方式是创建一个 n_heads 的注意力头列表,其中每个头的大小等于最终头大小除以 n_heads。然后我们将每个头的输出沿最后一个维度进行拼接。然而,这种实现效率低下,并没有充分利用张量的计算速度。让我们利用张量的强大功能来实现多头注意力。
head_size = 64 # put at top of file
class MultiHeadAttention(nn.Module):
def __init__(self):
super().__init__()
self.k_proj = nn.Linear(n_emb, head_size, bias=False)
self.q_proj = nn.Linear(n_emb, head_size, bias=False)
self.v_proj = nn.Linear(n_emb, head_size, bias=False)
indices = mx.arange(ctx_len)
mask = indices[:, None] < indices[None] # broadcasting trick
self._causal_mask = mask * -1e9
self.c_proj = nn.Linear(head_size, n_emb) # output projection
self.resid_dropout = nn.Dropout(dropout)
def __call__(self, x):
B, T, C = x.shape # (batch_size, ctx_len, n_emb)
K = self.k_proj(x) # (B, T, head_size)
Q = self.q_proj(x) # (B, T, head_size)
V = self.v_proj(x) # (B, T, head_size)
我们从单头注意力的实现开始。__init__()函数没有改变。前向传播像往常一样开始,首先是创建键、查询和值的投影。
head_size = 64 # put at top of file
n_heads = 8 # put at top of file
class MultiHeadAttention(nn.Module):
def __init__(self):
super().__init__()
self.k_proj = nn.Linear(n_emb, head_size, bias=False)
self.q_proj = nn.Linear(n_emb, head_size, bias=False)
self.v_proj = nn.Linear(n_emb, head_size, bias=False)
indices = mx.arange(ctx_len)
mask = indices[:, None] < indices[None] # broadcasting trick
self._causal_mask = mask * -1e9
self.c_proj = nn.Linear(head_size, n_emb) # output projection
self.resid_dropout = nn.Dropout(dropout)
def __call__(self, x):
B, T, C = x.shape # (batch_size, ctx_len, n_emb)
K = self.k_proj(x) # (B, T, head_size)
Q = self.q_proj(x) # (B, T, head_size)
V = self.v_proj(x) # (B, T, head_size)
mha_shape = (B, T, n_heads, head_size//n_heads)
K = mx.as_strided(K, (mha_shape)) # (B, T, n_heads, head_size//n_heads)
Q = mx.as_strided(Q, (mha_shape)) # (B, T, n_heads, head_size//n_heads)
V = mx.as_strided(V, (mha_shape)) # (B, T, n_heads, head_size//n_heads)
接下来我们需要做的是引入一个新的维度来表示头部的数量n_heads。在朴素实现中,我们有独立的注意力对象,每个都有自己的键、查询和值张量,但现在我们将它们放在一个张量中,因此我们需要一个用于头部的维度。我们在mha_shape中定义我们想要的新形状。然后,我们使用mx.as_strided()来重新塑形每个张量,使其包含头部维度。这个函数相当于 pytorch 中的view,它告诉 mlx 将该数组视为不同的形状。但我们仍然存在一个问题。注意,如果我们尝试像之前一样将Q @ K_t(其中 K_t 是 K 的最后两个维度转置)相乘来计算注意力权重,我们将会乘上以下形状:
(B, T, n_heads, head_size//n_heads) @ (B, T, head_size//n_heads, n_heads)
Result shape: (B, T, n_heads, n_heads)
这将导致一个形状为(B, T, n_heads, n_heads)的张量,这是不正确的。在单头情况下,我们的注意力权重的形状为(B, T, T),这是合理的,因为它展示了每对 token 之间的交互。因此,现在我们的形状应该是相同的,只不过多了一个头部维度:(B, n_heads, T, T)。我们通过在重新塑形后转置键、查询和值的维度来实现这一点,从而使n_heads维度变为 1 而不是 2。
head_size = 64 # put at top of file
n_heads = 8 # put at top of file
class MultiHeadAttention(nn.Module):
def __init__(self):
super().__init__()
self.k_proj = nn.Linear(n_emb, head_size, bias=False)
self.q_proj = nn.Linear(n_emb, head_size, bias=False)
self.v_proj = nn.Linear(n_emb, head_size, bias=False)
indices = mx.arange(ctx_len)
mask = indices[:, None] < indices[None] # broadcasting trick
self._causal_mask = mask * -1e9
self.c_proj = nn.Linear(head_size, n_emb) # output projection
self.attn_dropout = nn.Dropout(dropout)
self.resid_dropout = nn.Dropout(dropout)
def __call__(self, x):
B, T, C = x.shape # (batch_size, ctx_len, n_emb)
K = self.k_proj(x) # (B, T, head_size)
Q = self.q_proj(x) # (B, T, head_size)
V = self.v_proj(x) # (B, T, head_size)
mha_shape = (B, T, n_heads, head_size//n_heads)
K = mx.as_strided(K, (mha_shape)).transpose([0, 2, 1, 3]) # (B, n_heads, T, head_size//n_heads)
Q = mx.as_strided(Q, (mha_shape)).transpose([0, 2, 1, 3]) # (B, n_heads, T, head_size//n_heads)
V = mx.as_strided(V, (mha_shape)).transpose([0, 2, 1, 3]) # (B, n_heads, T, head_size//n_heads)
attn_weights = (Q @ K.transpose([0, 1, 3, 2])) / math.sqrt(Q.shape[-1]) # (B, n_heads, T, T)
attn_weights = attn_weights + self._causal_mask[:T, :T]
attn_weights = mx.softmax(attn_weights, axis=-1)
attn_weights = self.attn_dropout(attn_weights)
o = (attn_weights @ V) # (B, n_heads, T, head_size//n_heads)
现在我们可以计算修正后的注意力权重。注意,我们通过单个注意力头的大小来缩放注意力权重,而不是使用拼接后大小的head_size。我们还对注意力权重应用了 dropout。
最后,我们执行拼接,并应用输出投影和 dropout。
head_size = 64 # put at top of file
n_heads = 8 # put at top of file
class MultiHeadAttention(nn.Module):
def __init__(self):
super().__init__()
self.k_proj = nn.Linear(n_emb, head_size, bias=False)
self.q_proj = nn.Linear(n_emb, head_size, bias=False)
self.v_proj = nn.Linear(n_emb, head_size, bias=False)
indices = mx.arange(ctx_len)
mask = indices[:, None] < indices[None] # broadcasting trick
self._causal_mask = mask * -1e9
self.c_proj = nn.Linear(head_size, n_emb) # output projection
self.attn_dropout = nn.Dropout(dropout)
self.resid_dropout = nn.Dropout(dropout)
def __call__(self, x):
B, T, C = x.shape # (batch_size, ctx_len, n_emb)
K = self.k_proj(x) # (B, T, head_size)
Q = self.q_proj(x) # (B, T, head_size)
V = self.v_proj(x) # (B, T, head_size)
mha_shape = (B, T, n_heads, head_size//n_heads)
K = mx.as_strided(K, (mha_shape)).transpose([0, 2, 1, 3]) # (B, n_heads, T, head_size//n_heads)
Q = mx.as_strided(Q, (mha_shape)).transpose([0, 2, 1, 3]) # (B, n_heads, T, head_size//n_heads)
V = mx.as_strided(V, (mha_shape)).transpose([0, 2, 1, 3]) # (B, n_heads, T, head_size//n_heads)
attn_weights = (Q @ K.transpose([0, 1, 3, 2])) / math.sqrt(Q.shape[-1]) # (B, n_heads, T, T)
attn_weights = attn_weights + self._causal_mask[:T, :T]
attn_weights = mx.softmax(attn_weights, axis=-1)
attn_weights = self.attn_dropout(attn_weights)
o = (attn_weights @ V) # (B, n_heads, T, head_size//n_heads)
o = o.transpose([0, 2, 1, 3]).reshape((B, T, head_size)) # concat heads
o = self.c_proj(self.resid_dropout(o))
return o
因为我们将所有内容都放在一个张量中,所以我们可以进行一些形状操作来实现拼接。首先,我们使用转置函数将n_heads移回到倒数第二个维度。然后,我们将其重新塑形为原始大小,以撤销我们之前对头部的拆分。这就相当于将每个头部的最终向量进行拼接。这就是多头注意力的全部内容!我们已经完成了实现中最复杂的部分。
MLP
架构的下一部分是多层感知机(MLP)。这实际上是指堆叠的 2 个线性层。这里没什么好说的,它是一个标准的神经网络。
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.c_fc = nn.Linear(n_emb, 4 * n_emb)
self.gelu = nn.GELU()
self.c_proj = nn.Linear(4 * n_emb, n_emb)
self.dropout = nn.Dropout(dropout)
def __call__(self, x):
x = self.gelu(self.c_fc(x))
x = self.c_proj(x)
x = self.dropout(x)
return x
我们将输入通过c_fc投影到更高的维度。然后,我们应用 gelu 非线性函数,并通过c_proj将其投影回嵌入维度。最后,我们应用 dropout 并返回。MLP 的目的是在注意力层的向量交换之后,进行一些额外的计算。我们将这些通信层(注意力)和计算层(mlp)堆叠成一个块。
块
一个 GPT 块由注意力和 MLP 组成。这些块会被重复使用,从而使架构变得深度。
class Block(nn.Module):
def __init__(self):
super().__init__()
self.mlp = MLP()
self.mha = MultiHeadAttention()
def __call__(self, x):
x = self.mha(x)
x = self.mlp(x)
return x
现在,我们需要添加两个额外的功能来提高训练稳定性。让我们再次查看架构图。
层归一化与跳跃连接

图像来源:作者
我们仍然需要实现那些以红色高亮显示的组件。箭头表示跳跃连接。输入不会直接被转换,而是注意力和 MLP 层的效果是加性的。它们的结果会被加到输入中,而不是直接替换输入。这对深度网络的训练稳定性有好处,因为在反向传播中,加法操作的操作数会接收到与其和相同的梯度。因此,梯度可以自由地向后流动,从而防止了像梯度消失/爆炸这类困扰深度网络的问题。层归一化也有助于训练稳定性,确保激活值服从正态分布。下面是最终的实现。
class Block(nn.Module):
def __init__(self):
super().__init__()
self.mlp = MLP()
self.mha = MultiHeadAttention()
self.ln_1 = nn.LayerNorm(dims=n_emb)
self.ln_2 = nn.LayerNorm(dims=n_emb)
def __call__(self, x):
x = x + self.mha(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
层归一化应用于多头注意力和 MLP 之前。跳跃连接通过x = x + ...的方式加入,使得操作变为加性。
前向传播
定义完 Block 之后,我们可以完成完整的 GPT-2 前向传播。
n_layers = 3 # put at top of file
class GPT(nn.Module):
def __init__(self):
super().__init__()
self.wte = nn.Embedding(vocab_size, n_emb) # token embeddings
self.wpe = nn.Embedding(ctx_len, n_emb) # position embeddings
self.blocks = nn.Sequential(
*[Block() for _ in range(n_layers)],
) # transformer blocks
self.ln_f = nn.LayerNorm(dims=n_emb) # final layernorm
self.lm_head = nn.Linear(n_emb, vocab_size) # output projection
# Tensor shapes commented
def __call__(self, x):
B, T = x.shape # (B = batch_size, T = ctx_len)
tok_emb = self.wte(x) # (B, T, n_emb)
pos_emb = self.wpe(mx.arange(T)) # (T, n_emb)
x = tok_emb + pos_emb # (B, T, n_emb)
x = self.blocks(x) # (B, T, n_emb)
x = self.ln_f(x) # (B, T, b_emb)
logits = self.lm_head(x) # (B, T, vocab_size)
return logits
我们使用nn.Sequential创建一个容器来包含这些模块,该容器接受任何输入并将其按顺序传递通过各个层。然后我们可以通过self.blocks(x)应用所有模块。最后,我们应用层归一化,然后是 lm_head。lm_head 或语言模型头只是一个线性层,它将嵌入维度映射到词汇表大小。模型将输出一个包含每个词汇表中单词的某些值的向量,或称为 logits。我们可以对 logits 进行 softmax 操作,以获得词汇表上的概率分布,然后从中采样得到下一个标记。我们还会使用 logits 在训练过程中计算损失。在开始训练之前,我们只需要实现两个组件。
采样
我们需要编写一个生成函数,以便在训练完成后从模型中进行采样。其思路是:我们从选择的某个序列开始,然后预测下一个标记并将其添加到序列中。接着我们将新的序列输入模型,再次预测下一个标记。这个过程会一直持续,直到我们决定停止。
# method of GPT class
def generate(self, max_new_tokens):
ctx = mx.zeros((1, 1), dtype=mx.int32)
我们用一个单独的标记零来提示模型。零是换行符,因此它是生成的自然起点,因为我们只是想看看我们的模型能生成多像莎士比亚的风格。请注意,我们将形状初始化为(1, 1),以模拟一个长度为 1 的单个批次。
# method of GPT class
def generate(self, max_new_tokens):
ctx = mx.zeros((1, 1), dtype=mx.int32)
for _ in range(max_new_tokens):
logits = self(ctx[:, -ctx_len:]) # pass in last ctx_len characters
logits = logits[:, -1, :] # get logits for the next token
next_tok = mx.random.categorical(logits, num_samples=1)
ctx = mx.concatenate((ctx, next_tok), axis=1)
return ctx
然后,我们通过将最后的 ctx_len 个字符传递给模型来获得下一个 token 的 logits。然而,我们的模型输出形状为(B, T, vocab_size),因为它预测输入中每个 token 的下一个 token 的 logits。在训练过程中我们会使用全部信息,但现在我们只想获取最后一个 token 的 logits,因为我们可以利用这个信息来采样一个新的 token。因此,我们索引 logits,获取第一个维度中最后一个元素,即序列维度。接着,我们使用mx.random.categorical()函数来采样下一个 token,该函数接受 logits 和我们希望的采样数量作为输入。这个函数会对 logits 进行 softmax,将其转化为概率分布,然后根据概率随机采样一个 token。最后,我们将新的 token 拼接到上下文中,并重复这个过程max_new_tokens次。
初始化
最后需要做的是处理权重初始化,这对训练动态非常重要。
# method of GPT
def _init_parameters(self):
normal_init = nn.init.normal(mean=0.0, std=0.02)
residual_init = nn.init.normal(mean=0.0, std=(0.02 / math.sqrt(2 * n_layers)))
首先,我们定义了两个不同的nn.init.normal函数。第一个用于初始化所有线性层和嵌入层。第二个用于初始化专门的残差投影线性层,即多头注意力和 MLP 中的最后一个线性层。进行这种特殊初始化的原因是,它会根据 GPT-2 论文[2]检查随着模型深度增加,残差路径上的累积。
在 mlx 中,我们可以使用[mx.update()](https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary/mlx.nn.Module.update.html#mlx.nn.Module.update)函数更改模型的参数。查看文档,它期望一个完整或部分的字典,包含新的模型参数。我们可以通过在 GPT 类中打印self.parameters()来查看这个字典的样子。
{'wte': {'weight': array([[-0.025084, -0.0197523, -0.0341617, ..., -0.0979123, -0.0830218, -0.0784692],
[-0.00777913, -0.117002, -0.0310708, ..., 0.0128591, 0.122941, 0.000414443],
[0.0240044, -0.0859084, 0.0253116, ..., 0.108967, 0.0767123, 0.0221565],
...,
[0.050729, -0.04578, 0.0685943, ..., -0.0496998, -0.00350879, -0.00631825],
[0.00518804, 0.0499818, 0.0330045, ..., 0.0300661, 0.0431054, 0.000958906],
[-0.0323007, 0.0132046, 0.0208218, ..., -0.0785159, 0.00436121, -0.00726994]], dtype=float32)}, 'wpe': {'weight': array([[0.000797923, -0.0396898, -0.029047, ..., -0.0132273, 0.00684483, -0.0067624],
[-0.0247021, -0.0274349, 0.0310587, ..., -0.100099, 0.0301566, -0.0178732],
[0.0929172, -0.0468649, 0.0101506, ..., -0.0341086, -0.0516283, 0.0447596],
...,
[-0.0508172, 0.0892201, -0.00183612, ..., -0.00341944, 0.023437, 0.0296461],
[0.0105829, 0.0688093, 0.146744, ..., -0.0836337, 0.0206679, 0.0184166],
[-0.00578717, -0.0606196, -0.0917056, ..., -0.0641549, -0.0490424, 0.0998114]], dtype=float32)}, 'blocks': {'layers': [{'mlp': {'c_fc': {'weight': array([[0.0169199, 0.00264431, 0.0316978, ..., -0.0596867, -0.0153549, 0.0176386],
...
它是一个嵌套的字典,包含每个模型的权重,类型为 mx.array。因此,为了初始化我们模型的参数,我们需要像这样构建一个包含新参数的字典,并将其传递给self.update()。我们可以通过以下方式实现:
# method of GPT
def _init_parameters(self):
normal_init = nn.init.normal(mean=0.0, std=0.02)
residual_init = nn.init.normal(mean=0.0, std=(0.02 / math.sqrt(2 * n_layers)))
new_params = []
for name, module in self.named_modules():
if isinstance(module, nn.layers.linear.Linear):
new_params.append((name + '.weight', normal_init(module.weight)))
elif isinstance(module, nn.layers.embedding.Embedding):
new_params.append((name + '.weight', normal_init(module.weight)
我们维护一个名为new_params的元组列表,其中包含(参数名, 新值)的元组。接下来,我们通过self.named_modules()循环遍历模型中的每个 nn.Module 对象,它返回(name, module)的元组。如果我们在循环中打印出模块名称,我们会看到它们像这样:
lm_head
blocks
blocks.layers.4
blocks.layers.3
blocks.layers.3.ln_2
blocks.layers.3.ln_1
blocks.layers.3.mha
blocks.layers.3.mha.resid_dropout
blocks.layers.3.mha.c_proj
blocks.layers.3.mha.attn_dropout
blocks.layers.3.mha.c_attn
...
blocks.layers.0.mlp.dropout
blocks.layers.0.mlp.c_proj
blocks.layers.0.mlp.gelu
blocks.layers.0.mlp.c_fc
wpe
wte
我们使用isinstance()函数来查找线性层和嵌入层,然后将它们添加到我们的列表中。例如,假设我们正在循环并到达“blocks.layers.0.mlp.c_fc”,这是 MLP 中的第一个线性层。这将触发第一个 if 语句,并将元组("block.layers.0.mlp.c_fc.weight", [<normally initialized weight here>])添加到我们的列表中。我们必须在名称中添加“.weight”,因为我们特别想以这种方式初始化权重,而不是偏置。现在,我们需要处理残差投影的初始化。
# method of GPT
def _init_parameters(self):
normal_init = nn.init.normal(mean=0.0, std=0.02)
residual_init = nn.init.normal(mean=0.0, std=(0.02 / math.sqrt(2 * n_layers)))
new_params = []
for name, module in self.named_modules():
if isinstance(module, nn.layers.linear.Linear):
if 'c_proj' in name: # residual projection
new_params.append((name + '.weight', residual_init(module.weight)))
else:
new_params.append((name + '.weight', normal_init(module.weight)))
elif isinstance(module, nn.layers.embedding.Embedding):
new_params.append((name + '.weight', normal_init(module.weight)))
在检查模块是否为线性层后,我们检查名称中是否包含“c_proj”,因为我们将残差投影命名为此。然后我们可以应用特殊的初始化方法。最后,我们需要将偏置初始化为零。
# method of GPT
def _init_parameters(self):
normal_init = nn.init.normal(mean=0.0, std=0.02)
residual_init = nn.init.normal(mean=0.0, std=(0.02 / math.sqrt(2 * n_layers)))
new_params = []
for name, module in self.named_modules():
if isinstance(module, nn.layers.linear.Linear):
if 'c_proj' in name:
new_params.append((name + '.weight', residual_init(module.weight)))
else:
new_params.append((name + '.weight', normal_init(module.weight)))
if 'bias' in module:
new_params.append((name + '.bias', mx.zeros(module.bias.shape)))
elif isinstance(module, nn.layers.embedding.Embedding):
new_params.append((name + '.weight', normal_init(module.weight)))
self = self.update(utils.tree_unflatten(new_params))
我们在线性分支下再添加一个 if 语句来检查nn.Module对象是否具有偏置属性。如果有,我们将其添加到初始化为零的列表中。最后,我们需要将元组列表转换为嵌套字典。幸运的是,mlx 已经实现了一些函数来处理参数字典,我们可以使用util.tree_unflatten()将这个元组列表转换为嵌套的参数字典。这个字典会传递到更新方法中以初始化参数。现在我们可以在构造函数中调用_init_parameters()。
class GPT(nn.Module):
def __init__(self):
super().__init__()
self.wte = nn.Embedding(vocab_size, n_emb) # token embeddings
self.wpe = nn.Embedding(ctx_len, n_emb) # position embeddings
self.blocks = nn.Sequential(
*[Block() for _ in range(n_layers)],
) # transformer blocks
self.ln_f = nn.LayerNorm(dims=n_emb) # final layernorm
self.lm_head = nn.Linear(n_emb, vocab_size) # output projection
self._init_parameters() # <-- initialize params
# print total number of params on initialization
total_params = sum([p.size for n,p in utils.tree_flatten(self.parameters())])
print(f"Total params: {(total_params / 1e6):.3f}M")
# Tensor shapes commented
def __call__(self, x):
B, T = x.shape # (B = batch_size, T = ctx_len)
tok_emb = self.wte(x) # (B, T, n_emb)
pos_emb = self.wpe(mx.arange(T)) # (T, n_emb)
x = tok_emb + pos_emb # (B, T, n_emb)
x = self.blocks(x) # (B, T, n_emb)
x = self.ln_f(x) # (B, T, b_emb)
logits = self.lm_head(x) # (B, T, vocab_size)
return logits
def generate(self, max_new_tokens):
ctx = mx.zeros((1, 1), dtype=mx.int32)
for _ in range(max_new_tokens):
logits = self(ctx[:, -ctx_len:]) # pass in last ctx_len characters
logits = logits[:, -1, :] # get logits for the next token
next_tok = mx.random.categorical(logits, num_samples=1)
ctx = mx.concatenate((ctx, next_tok), axis=1)
return ctx
def _init_parameters(self):
normal_init = nn.init.normal(mean=0.0, std=0.02)
residual_init = nn.init.normal(mean=0.0, std=(0.02 / math.sqrt(2 * n_layers)))
new_params = []
for name, module in self.named_modules():
if isinstance(module, nn.layers.linear.Linear):
if 'c_proj' in name:
new_params.append((name + '.weight', residual_init(module.weight)))
else:
new_params.append((name + '.weight', normal_init(module.weight)))
if 'bias' in module:
new_params.append((name + '.bias', mx.zeros(module.bias.shape)))
elif isinstance(module, nn.layers.embedding.Embedding):
new_params.append((name + '.weight', normal_init(module.weight)))
self = self.update(utils.tree_unflatten(new_params))
我们还在构造函数中添加了两行代码来打印总的参数数量。最后,我们准备构建训练循环。
训练循环
为了训练模型,我们需要一个损失函数。由于我们预测的是类别(下一个令牌),所以我们使用交叉熵损失函数。
def loss_fn(model, x, y):
logits = model(x)
B, T, C = logits.shape # (batch_size, seq_len, vocab_size)
logits = logits.reshape(B*T, C)
y = y.reshape(B*T)
loss = nn.losses.cross_entropy(logits, y, reduction='mean')
return loss
首先,我们从模型中获取 logits。然后,我们重新调整 logits 的形状,形成一个词汇表大小长度的数组列表。我们还将正确的令牌 ID(y)重新调整形状,以使其具有相同的长度。接着,我们使用内建的交叉熵损失函数来计算每个示例的损失,并将它们平均得到一个单一的值。
model = GPT()
mx.eval(model.parameters()) # Create the model params (mlx is lazy evaluation)
loss_and_grad = nn.value_and_grad(model, loss_fn)
lr = 0.1
optimizer = optim.AdamW(learning_rate=lr)
接下来,我们实例化模型,但由于 mlx 是惰性求值,它不会分配和创建参数。我们需要在参数上调用 mx.eval 以确保它们被创建。然后我们可以使用[nn.value_and_grad()](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.nn.value_and_grad.html)来获得一个函数,该函数返回模型参数相对于损失的损失值和梯度。这是我们优化所需的一切。最后,我们初始化一个 AdamW 优化器。
关于nn.value_and_grad()的简短说明。如果你习惯了 PyTorch,可能会期望我们使用loss.backward(),这会遍历计算图并更新模型中每个张量的.grad属性。然而,mlx 的自动微分是基于函数而非计算图的[3]。因此,mlx 提供了内建函数,接受一个函数并返回该函数的梯度函数,例如nn.value_and_grad()。
现在我们定义训练循环。
num_epochs=20
batch_size=32
for epoch in range(num_epochs):
model.train(True)
running_loss = 0
batch_cnt = 0
for input, label in get_batches(X_train, y_train, batch_size):
batch_cnt += 1
loss, grads = loss_and_grad(model, input, label)
optimizer.update(model, grads)
running_loss += loss.item()
# compute new parameters and optimizer state
mx.eval(model.parameters(), optimizer.state)
avg_train_loss = running_loss / batch_cnt
model.train(False) # set eval mode
running_loss = 0
batch_cnt = 0
for input, label in get_batches(X_val, y_val, batch_size):
batch_cnt += 1
loss = loss_fn(model, input, label)
running_loss += loss.item()
avg_val_loss = running_loss / batch_cnt
print(f"Epoch {epoch:2} | train = {avg_train_loss:.4f} | val = {avg_val_loss:.4f}")
外循环遍历各个训练轮次。我们首先将模型设置为训练模式,因为某些模块在训练和测试阶段的行为有所不同,例如 dropout。然后,我们使用之前的get_batches函数来遍历训练数据的批次。我们通过loss_and_grad获得批次的损失和梯度。接着,我们将模型和梯度传递给优化器,以更新模型的参数。最后,我们调用 mx.eval(记住,mlx 是懒惰求值)来确保参数和优化器的状态得到更新。然后,我们计算数据集的平均训练损失,以便稍后打印。这就是一次完整的训练数据遍历。类似地,我们计算验证损失,然后打印每个训练轮次的平均训练损失和验证损失。
completion = decode(model.generate(1000)[0].tolist())
print(completion)
with open('completions.txt', 'w') as f:
f.write(completion)
最后,我们添加一些代码来从我们的模型生成数据。由于生成的输出仍然是(B,T)形状,我们需要在第 0 个索引位置进行切片,使其变为 1D 数组,然后将其从 mlx 数组转换为 Python 列表。接着,我们可以将其传递给之前的解码函数,并将结果写入文件。
这些是我们将在训练中使用的参数(你可以自行尝试调整这些参数):
ctx_len = 128
n_emb = 128
dropout = 0.1
head_size = 128
n_heads = 4
n_layers = 3
num_epochs = 20
batch_size = 64
lr = 1e-3
现在我们可以运行文件来开始训练。根据上述设置,训练大约花了 10 分钟,在我的 M2 MacBook 上训练完成。上一个训练轮次的损失如下:
Epoch 19 | train = 1.6961 | val = 1.8143
让我们来看一下输出结果。
GLOUCESTER:
But accomes mo move it.
KING EDWARD:
Where our that proclaim that I curse, or I sprithe.
CORIOLANUS:
Not want:
His bops to thy father
At with hath folk; by son and fproathead:
The good nor may prosperson like it not,
What, the beggares
More hath, when that made a,
Your vainst Citizen:
Let here are go in queen me and knife
To my deserved me you promise: not a fettimes,
That one the will not.
CORIOLANUS:
And been of queens,
Thou to do we best!
JULIET:
Not, brother recourable this doth our accuse
Into fight!
经过短短 10 分钟的训练,这个小模型竟然能够预测字符!虽然输出内容是无意义的,但它显然具备了莎士比亚作品的风格。现在我们模型与真正的 GPT-2 之间唯一的区别就是规模!我现在鼓励你进行实验——尝试不同的设置,可能调整一下架构,看看你能达到多低的损失。
参考文献
[1] Karpathy A (2015).Tiny Shakespeare 数据集。github.com/karpathy/char-rnn(MIT 许可证)
[2] A. Radford, J. Wu, R. Child, D. Luan, D. Amodei, I. Sutskever,语言模型是无监督的多任务学习者(2019),OpenAI
[3] 自动微分 — mlx 文档
GPT 模型:它是如何工作的?
让我们一起用 Python 和 PyTorch 深入探索它的工作原理吧
·发布于 Towards Data Science ·阅读时间:9 分钟·2024 年 2 月 21 日
--

图片来自 Hal Gatewood,Unsplash
在过去几年里,围绕人工智能的热议一直非常庞大,而这一切的主要推动力显然是基于 GPT 的大型语言模型的出现。有趣的是,这种方法本身并不新颖。LSTM(长短期记忆)神经网络是 1997 年创建的,而 2017 年发表的著名论文《Attention is All You Need》则为现代自然语言处理奠定了基石。但直到 2020 年,GPT-3 的效果才足够好,不仅仅是在学术论文中得到认可,也能在现实世界中得到应用。
如今,任何人都可以在网页浏览器中与 GPT 聊天,但可能不到 1%的人真正了解它是如何工作的。模型给出的聪明且机智的回答让人产生与智能生物对话的错觉,但真的是这样吗?好吧,了解它的最好方式就是看看它是如何工作的。在本文中,我们将拿 OpenAI 的一个真实 GPT 模型来做演示,在本地运行它,并一步步查看它的工作原理。
本文面向初学者以及对编程和数据科学感兴趣的读者。我将通过 Python 来演示我的步骤,但不需要对 Python 有深入的理解。
让我们开始吧!
加载模型
GPT 与额头侦探
OpenAI LLMs 的推理能力足够好,可以玩经典的猜谜游戏吗?
·发表于 Towards Data Science ·11 分钟阅读·2024 年 9 月 5 日
--

GPT 玩“额头侦探”游戏 | 图片来自 DALL·E
每当我和家人一起进行长途旅行时,我们喜欢用这个时间玩各种口头游戏,特别是那些不需要我作为司机过多集中注意力的游戏。我们最喜欢的游戏之一是角色猜谜游戏。规则很简单:一个人想到一个角色——可以是真实的也可以是虚构的——其余的人尝试猜出这个角色是谁。我们只能问只能用简单的“是”或“不是”来回答的封闭式问题。想到角色的人只能用这些答案来回应。我们轮流提问,首先正确猜出角色的人获胜。所以这其实就像“额头侦探”游戏,只不过我们没有把便签贴在头上。
在我们最近的一次旅行中,我们正在玩这个游戏,但过了一段时间后,我们开始没什么新想法来猜测角色。这时我想:“为什么不让 ChatGPT 来接管这个任务呢?”我在手机上安装了 ChatGPT 应用,于是我打开它,让 ChatGPT 想出一个角色供我们猜测。我还指示它回答我们关于角色的猜测。于是,我们开始和 ChatGPT 一起玩。它想出了一个角色,我们轮流提问。人工智能……
GPU 加速 Polars — 直观且详尽的解释
大数据问题的快速数据框
·发布于Towards Data Science ·14 分钟阅读·2024 年 9 月 17 日
--

“加速极地”由 Daniel Warfield 使用 Midjourney 制作。所有图像均由作者提供,除非另有说明。本文最初发布在直观且详尽的解释。
最近,我参加了一个由 Cuda 和 Polars 团队举办的秘密演示。他们让我通过金属探测器,给我戴上了袋子,把我送到法国乡村的一座小屋里。他们收走了我的手机、钱包和护照,以确保我在最后揭示他们的工作成果之前不会泄露消息。
或者说,感觉是这样的。实际上那是一次 Zoom 会议,他们礼貌地要求我在指定时间之前不要说话,但作为一名技术作家,这种神秘感让我有些像詹姆斯·邦德。
在本文中,我们将讨论那次会议的内容:Polars 中的一个新执行引擎,它支持 GPU 加速计算,使得能够对 100GB 以上的数据进行交互式操作。我们将讨论 Polars 中的数据框是什么,GPU 加速如何在 Polars 数据框中工作,以及新 CUDA 驱动的执行引擎能带来多大的性能提升。
这对谁有用? 任何与数据打交道并希望加速工作的人。
这篇文章有多先进? 这篇文章包含了简单但前沿的数据工程概念。它适用于各个层次的读者。
梯度提升回归器详解:带有代码示例的视觉指南
集成学习
一次一个增强阶段地拟合误差
·发表于 Towards Data Science ·阅读时间 11 分钟·2024 年 11 月 14 日
--
通过成本复杂度修剪智能地修剪分支
towardsdatascience.com
当然,在机器学习中,我们希望我们的预测非常准确。我们从简单的决策树开始——它们还不错。然后出现了随机森林和AdaBoost,效果更好。但梯度提升呢?那简直是一个游戏规则的改变,使得预测变得更加准确。
他们说:“使得梯度提升如此有效的原因其实很简单:它逐个构建模型,每个新模型都专注于修正所有之前模型的错误。这样逐步修正错误的方式就是它特别之处。”我本以为这真的会这么简单,但每次我查阅梯度提升,试图理解它是如何工作的时,我看到的却是相同的内容:一行行复杂的数学公式和令人头疼的图表,这些东西总是让我抓狂。试试看吧。
让我们终结这一点,并以一种实际有意义的方式来分解它。我们将通过梯度提升的训练步骤进行可视化导航,专注于回归案例——这是一个比分类更简单的场景——这样我们就可以避免混淆的数学计算。就像多级火箭为了达到轨道而丢弃不必要的重量一样,我们将一项一项地消除那些预测误差。

所有视觉图像:作者使用 Canva Pro 创建。优化为移动设备显示;在桌面上可能显示过大。
定义
梯度提升是一种集成机器学习技术,它构建一系列决策树,每棵树旨在纠正前一棵树的错误。与使用浅层树的 AdaBoost 不同,梯度提升使用更深的树作为其弱学习器。每棵新树的目标是最小化残差误差——实际值与预测值之间的差异——而不是直接从原始目标中学习。
对于回归任务,梯度提升逐一添加决策树,每棵新树的训练目标是通过解决当前的残差误差来减少剩余的错误。最终的预测是通过将所有树的输出相加得到的。
模型的优势来自于其加法学习过程——每棵树专注于纠正集成中的剩余错误,而顺序组合的方式使得它成为一个强大的预测器,通过专注于模型仍然难以处理的问题部分,逐步减少整体预测误差。

梯度提升是提升家族算法的一部分,因为它是逐一构建树的,每棵新树尝试纠正前一棵树的错误。然而,与其他提升方法不同,梯度提升从优化的角度来解决问题。
使用的数据集
在本文中,我们将以经典的高尔夫数据集作为回归的例子。尽管梯度提升可以有效地处理回归和分类任务,但我们将专注于更简单的回归任务——根据天气条件预测来打高尔夫的玩家人数。

列:‘阴天(通过一热编码转换为 3 列)’,‘温度’(华氏度),‘湿度’(百分比),‘有风’(是/否)和‘玩家人数’(目标特征)
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
# Create dataset
dataset_dict = {
'Outlook': ['sunny', 'sunny', 'overcast', 'rain', 'rain', 'rain', 'overcast',
'sunny', 'sunny', 'rain', 'sunny', 'overcast', 'overcast', 'rain',
'sunny', 'overcast', 'rain', 'sunny', 'sunny', 'rain', 'overcast',
'rain', 'sunny', 'overcast', 'sunny', 'overcast', 'rain', 'overcast'],
'Temp.': [85.0, 80.0, 83.0, 70.0, 68.0, 65.0, 64.0, 72.0, 69.0, 75.0, 75.0,
72.0, 81.0, 71.0, 81.0, 74.0, 76.0, 78.0, 82.0, 67.0, 85.0, 73.0,
88.0, 77.0, 79.0, 80.0, 66.0, 84.0],
'Humid.': [85.0, 90.0, 78.0, 96.0, 80.0, 70.0, 65.0, 95.0, 70.0, 80.0, 70.0,
90.0, 75.0, 80.0, 88.0, 92.0, 85.0, 75.0, 92.0, 90.0, 85.0, 88.0,
65.0, 70.0, 60.0, 95.0, 70.0, 78.0],
'Wind': [False, True, False, False, False, True, True, False, False, False, True,
True, False, True, True, False, False, True, False, True, True, False,
True, False, False, True, False, False],
'Num_Players': [52, 39, 43, 37, 28, 19, 43, 47, 56, 33, 49, 23, 42, 13, 33, 29,
25, 51, 41, 14, 34, 29, 49, 36, 57, 21, 23, 41]
}
# Prepare data
df = pd.DataFrame(dataset_dict)
df = pd.get_dummies(df, columns=['Outlook'], prefix='', prefix_sep='')
df['Wind'] = df['Wind'].astype(int)
# Split features and target
X, y = df.drop('Num_Players', axis=1), df['Num_Players']
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.5, shuffle=False)
主要机制
以下是梯度提升的工作原理:
初始化模型: 从一个简单的预测开始,通常是目标值的平均值。
迭代学习: 在设定的迭代次数内,计算残差,训练一棵决策树来预测这些残差,并将新树的预测结果(按学习率缩放)添加到运行总和中。
在残差上构建决策树: 每棵新树专注于所有前期迭代中的剩余误差。
最终预测: 汇总所有树的贡献(按学习率缩放)和初始预测。

梯度提升回归模型从平均预测开始,通过多棵树进行改进,每棵树都在小步修正前一棵树的错误,直到达到最终预测。
训练步骤
我们将遵循标准的梯度提升方法:
1.0. 设置模型参数:
在构建任何树之前,我们需要设置控制学习过程的核心参数:
· 树的数量(通常为 100,但我们选择 50)按顺序构建,
· 学习率(通常为 0.1),以及
· 每棵树的最大深度(通常为 3)

一棵树的图示,展示了我们的关键设置:每棵树将有 3 个层级,我们将创建 50 棵树,并在每次迭代中以 0.1 的小步前进。
对于第一棵树
2.0 对标签进行初始预测。通常这是均值(就像是 a dummy prediction 一样。)

为了开始我们的预测,我们使用所有训练数据的平均值(37.43)作为每个案例的第一次猜测。
2.1. 计算临时残差(或伪残差):
残差 = 实际值 — 预测值

通过从每个目标值中减去均值预测(37.43)来计算初始残差。
2.2. 构建决策树以预测这些残差。树的构建步骤与回归树完全相同。

第一棵决策树开始训练时,通过寻找特征中的模式,来预测我们初步均值预测的计算残差。
- 计算根节点的初始均方误差(MSE)

就像常规回归树一样,我们计算均方误差(MSE),但这次我们测量的是残差的分布(围绕零),而不是实际值的分布(围绕它们的均值)。
- 对每个特征:
· 按特征值对数据进行排序

对于数据集中的每个特征,我们对其值进行排序并找到潜在的分裂点,正如我们在标准决策树中所做的那样,来确定最好的方式来划分我们的残差。
· 对每个可能的分裂点:
·· 将样本分为左组和右组
·· 计算两个组的均方误差(MSE)
·· 计算这个分裂的均方误差(MSE)减少量

类似于常规的回归树,我们通过计算两组的加权均方误差(MSE)来评估每次划分,但这里我们衡量的是划分后的组如何聚集相似的残差,而不是相似的目标值。
- 选择能够带来最大 MSE 降低的分裂


树通过使用“rain”特征(值为 0.5)进行第一次分裂,基于残差将样本分成两组——这个第一次决策将在更深层次的进一步分裂中得到精炼。
- 继续分裂,直到达到最大深度或每个叶子的最小样本数。

经历了三层基于不同特征的分裂后,我们的第一棵树创建了八个不同的组,每个组都有自己的残差预测值。
2.3. 计算叶子节点值
对于每个叶子节点,计算残差的均值。

第一棵树的每个叶子节点包含该组残差的平均值——这些值将用于调整和改善我们最初的 37.43 的预测。
2.4. 更新预测
· 对于训练数据集中的每个数据点,基于新树确定它属于哪个叶子节点。

将我们的训练数据通过第一棵树运行时,每个样本根据天气特征沿着自己的路径获取预测残差值,这将帮助修正我们最初的预测。
· 将新树的预测结果乘以学习率,然后将这些缩放后的预测值加到当前模型的预测结果中。这将是更新后的预测。

我们的模型通过采取小步进来更新预测:它只将每个预测残差的 10%(学习率为 0.1)加到我们最初的 37.43 预测值上,从而得到稍微改进的预测。
对于第二棵树
2.1. 基于当前模型计算新的残差
- 计算目标预测值与当前预测值之间的差异。
这些残差与第一次迭代的残差会略有不同。

更新了第一棵树的预测后,我们计算新的残差——注意到它们比原来的残差稍微小一些,显示我们的预测逐渐得到了改善。
2.2. 构建一棵新树来预测这些残差。过程与第一棵树相同,但目标是新的残差。

启动我们的第二棵树来预测新的、更小的残差——我们将使用与之前相同的树构建过程,但这次我们试图捕捉第一棵树遗漏的错误。
2.3. 计算每个叶子节点的均值残差

第二棵树与第一棵树的结构相同,使用相同的天气特征和分裂点,但其叶节点的值较小——这表明我们正在微调剩余的误差。
2.4. 更新模型预测
· 将新树的预测乘以学习率。
· 将新缩放过的树预测加到当前总和中。

在将数据通过第二棵树后,我们再次以 0.1 的学习率做出小步调整以更新预测,并计算出比之前更小的残差——我们的模型正在逐渐学习模式。
从第三棵树开始
对剩余的迭代重复步骤 2.1–2.3。注意,每棵树看到的残差不同。
· 决策树逐渐专注于更难预测的模式
· 学习率通过限制每棵树的贡献来防止过拟合

随着我们构建更多的树,注意分裂点如何逐渐变化,叶节点中的残差值变得更小——到第 50 棵树时,我们通过不同的特征组合进行微调,调整的幅度比最初的树要小。
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
from sklearn.ensemble import GradientBoostingRegressor
# Train the model
clf = GradientBoostingRegressor(criterion='squared_error', learning_rate=0.1, random_state=42)
clf.fit(X_train, y_train)
# Plot trees 1, 2, 49, and 50
plt.figure(figsize=(11, 20), dpi=300)
for i, tree_idx in enumerate([0, 2, 24, 49]):
plt.subplot(4, 1, i+1)
plot_tree(clf.estimators_[tree_idx,0],
feature_names=X_train.columns,
impurity=False,
filled=True,
rounded=True,
precision=2,
fontsize=12)
plt.title(f'Tree {tree_idx + 1}')
plt.suptitle('Decision Trees from GradientBoosting', fontsize=16)
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

来自 scikit-learn 的可视化展示了我们的梯度提升树如何演变:从树 1 进行大范围的分裂并给出大预测值,到树 50 进行精细的分裂并做出微小的调整——每棵树都专注于修正前面树所产生的剩余误差。
测试步骤
预测时:
从初始预测开始(玩家的平均数量)
将输入数据传递给每棵树以获得其预测的调整值
按照学习率缩放每棵树的预测值。
将所有这些调整添加到初始预测中
这些和直接给出我们预测的玩家数量

在对未见数据进行预测时,每棵树都会贡献一个小的预测值,从树 1 的 5.57 开始,到树 50 的 0.008——所有这些预测都被我们的 0.1 学习率进行缩放,并加到我们的基础预测值 37.43 上,得到最终的答案。
评估步骤
构建所有树后,我们可以评估测试集。

我们的梯度提升模型达到了 4.785 的 RMSE,相较于a 棵回归树的 5.27有了显著的提升——这表明将多个小的调整组合起来,比单棵复杂的树更能做出准确的预测!
# Get predictions
y_pred = clf.predict(X_test)
# Create DataFrame with actual and predicted values
results_df = pd.DataFrame({
'Actual': y_test,
'Predicted': y_pred
})
print(results_df) # Display results DataFrame
# Calculate and display RMSE
from sklearn.metrics import root_mean_squared_error
rmse = root_mean_squared_error(y_test, y_pred)
print(f"\nModel Accuracy: {rmse:.4f}")
关键参数
这里是梯度提升中的关键参数,特别是在scikit-learn中:
max_depth:用于建模残差的树的深度。与使用树桩的 AdaBoost 不同,梯度提升在深层树(通常为 3-8 层)上效果更好。深层树能够捕捉更复杂的模式,但也有过拟合的风险。
n_estimators:要使用的树的数量(通常为 100-1000)。当与较小的学习率配对时,更多的树通常能提高性能。
learning_rate:也称为“收缩”,用于缩放每棵树的贡献(通常为 0.01-0.1)。较小的值需要更多的树,但通过使学习过程更精细化,通常能获得更好的结果。
subsample:用于训练每棵树的样本比例(通常为 0.5-0.8)。这个可选特性增加了随机性,可以提高鲁棒性并减少过拟合。
这些参数是相互配合工作的:较小的学习率需要更多的树,而较深的树可能需要较小的学习率以避免过拟合。
与 AdaBoost 的关键区别
AdaBoost 和 Gradient Boosting 都是提升算法,但它们从错误中学习的方式不同。以下是它们的关键区别:
max_depth通常在 Gradient Boosting 中较高(3-8),而 AdaBoost 更倾向于使用树桩。没有
sample_weight更新,因为 Gradient Boosting 使用残差而不是样本加权。learning_rate通常比 AdaBoost 的较大值(0.1-1.0)小得多(0.01-0.1)。初始预测从均值开始,而 AdaBoost 从零开始。
树是通过简单的加法而不是加权投票来组合的,这使得每棵树的贡献更加直观。
可选的
subsample参数增加了随机性,这是标准 AdaBoost 所没有的特性。
优点与缺点
优点:
逐步错误修正: 在 Gradient Boosting 中,每棵新树专注于修正前一棵树的错误。这使得模型在之前错误的区域更好地改进预测。
灵活的误差度量: 与 AdaBoost 不同,Gradient Boosting 可以优化不同类型的误差度量(如平均绝对误差、均方误差等)。这使得它可以适应各种问题。
高准确度: 通过使用更详细的树并仔细控制学习率,Gradient Boosting 往往能提供比其他算法更准确的结果,尤其是对于结构良好的数据。
缺点:
过拟合的风险: 使用更深的树和顺序构建过程可能导致模型过度拟合训练数据,从而降低在新数据上的表现。这需要仔细调整树的深度、学习率和树的数量。
训练过程缓慢: 和 AdaBoost 一样,树必须一个接一个地构建,因此相比于可以并行构建树的算法(如随机森林),训练速度较慢。每棵树都依赖于前一棵树的错误。
高内存使用: 由于需要更深和更多的树,Gradient Boosting 的内存消耗可能比像 AdaBoost 这样的简单提升方法更高。
对设置敏感: 梯度提升的有效性在很大程度上取决于找到合适的学习率、树的深度和树的数量的组合,这可能比调优简单算法更复杂且耗时。
结语
梯度提升(Gradient Boosting)是提升算法的一项重要改进。这一成功催生了像 XGBoost 和 LightGBM 这样的流行版本,它们在机器学习竞赛和实际应用中得到了广泛使用。
尽管梯度提升比简单算法需要更精细的调优——尤其是在调整决策树深度、学习率和树的数量时——它非常灵活且强大。这使得它成为结构化数据问题的首选方法。
梯度提升能够处理简单方法(如 AdaBoost)可能忽视的复杂关系。其持续的流行和不断的改进表明,使用梯度并逐步构建模型的方法在现代机器学习中依然极为重要。
🌟 梯度提升回归器代码总结
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import root_mean_squared_error
from sklearn.ensemble import GradientBoostingRegressor
# Create dataset
dataset_dict = {
'Outlook': ['sunny', 'sunny', 'overcast', 'rain', 'rain', 'rain', 'overcast',
'sunny', 'sunny', 'rain', 'sunny', 'overcast', 'overcast', 'rain',
'sunny', 'overcast', 'rain', 'sunny', 'sunny', 'rain', 'overcast',
'rain', 'sunny', 'overcast', 'sunny', 'overcast', 'rain', 'overcast'],
'Temp.': [85.0, 80.0, 83.0, 70.0, 68.0, 65.0, 64.0, 72.0, 69.0, 75.0, 75.0,
72.0, 81.0, 71.0, 81.0, 74.0, 76.0, 78.0, 82.0, 67.0, 85.0, 73.0,
88.0, 77.0, 79.0, 80.0, 66.0, 84.0],
'Humid.': [85.0, 90.0, 78.0, 96.0, 80.0, 70.0, 65.0, 95.0, 70.0, 80.0, 70.0,
90.0, 75.0, 80.0, 88.0, 92.0, 85.0, 75.0, 92.0, 90.0, 85.0, 88.0,
65.0, 70.0, 60.0, 95.0, 70.0, 78.0],
'Wind': [False, True, False, False, False, True, True, False, False, False, True,
True, False, True, True, False, False, True, False, True, True, False,
True, False, False, True, False, False],
'Num_Players': [52, 39, 43, 37, 28, 19, 43, 47, 56, 33, 49, 23, 42, 13, 33, 29,
25, 51, 41, 14, 34, 29, 49, 36, 57, 21, 23, 41]
}
# Prepare data
df = pd.DataFrame(dataset_dict)
df = pd.get_dummies(df, columns=['Outlook'], prefix='', prefix_sep='')
df['Wind'] = df['Wind'].astype(int)
# Split features and target
X, y = df.drop('Num_Players', axis=1), df['Num_Players']
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.5, shuffle=False)
# Train Gradient Boosting
gb = GradientBoostingRegressor(
n_estimators=50, # Number of boosting stages (trees)
learning_rate=0.1, # Shrinks the contribution of each tree
max_depth=3, # Depth of each tree
subsample=0.8, # Fraction of samples used for each tree
random_state=42
)
gb.fit(X_train, y_train)
# Predict and evaluate
y_pred = gb.predict(X_test)
rmse = root_mean_squared_error(y_test, y_pred))
print(f"Root Mean Squared Error: {rmse:.2f}")
进一步阅读
对于梯度提升回归器(GradientBoostingRegressor)及其在 scikit-learn 中的实现的详细解释,读者可以参考官方文档,该文档提供了有关其使用和参数的全面信息。
技术环境
本文使用的是 Python 3.7 和 scikit-learn 1.6 版本。虽然所讨论的概念具有普遍适用性,但具体的代码实现可能会因版本不同而略有差异。
关于插图
除非另有说明,所有图片均由作者创作,包含来自 Canva Pro 的授权设计元素。
𝙎𝙚𝙚 𝙢𝙤𝙧𝙚 𝙀𝙣𝙨𝙚𝙢𝙗𝙡𝙚 𝙇𝙚𝙖𝙧𝙣𝙞𝙣𝙜 𝙝𝙚𝙧𝙚:

集成学习
查看列表4 篇故事


𝙔𝙤𝙪 𝙢𝙞𝙜𝙝𝙩 𝙖𝙡𝙨𝙤 𝙡𝙞𝙠𝙚:

回归算法
查看列表5 个故事


Gradio: 高级布局

图片由 AI 创建(Ideogram)
不要使用 Interface,改用 Blocks
·发表于 Towards Data Science ·6 分钟阅读·2024 年 4 月 30 日
--
当你第一次使用 Gradio 时,你很可能会使用 Interface 方法来展示你的 GUI。之所以如此,是因为它提供了简单性,如果你的页面布局不复杂,这基本上就足够了。在本文中,我们将探讨当你的 Gradio 布局超出 Interface 时该如何处理。
在继续之前,如果你是第一次接触 Gradio,请查看我之前的另一篇文章(链接见下方),在文章中我介绍了 Gradio 是什么。简单来说,它是一个非常实用且易于编码的基于 Python 的 GUI,可以让你快速搭建一个本地前端,方便用户与你编写的代码进行交互。
使用 Python 在几分钟内创建直观的网页界面
ai.gopubby.com](https://ai.gopubby.com/gradio-rapid-gui-prototyping-a0091c28116b?source=post_page-----f37a4dae307d--------------------------------)
此外,由于我们将进行一些编码,最佳实践是设置一个新的开发环境,以便我们所做的工作不会干扰到其他项目。我大多数的文章编程使用 Python 和 Jupyter,并且我有一篇 文章在这里 介绍了如何设置我的开发环境,文中解释了你需要做什么。
基于图的提示与推理在语言模型中的应用
理解图形思维提示及其几种变体……
·发表于Towards Data Science ·阅读时长 22 分钟·2024 年 1 月 3 日
--

(图片由Alina Grubnyak拍摄,来源于Unsplash)
高级提示技术,如思维链[8]和思维树[9],显著提高了大型语言模型(LLM)在解决复杂推理任务中的能力。大体而言,迫使 LLM 以逐步响应的方式来处理问题,极大地提升了其问题解决能力。然而,这些技术假设推理过程应该遵循一种从一个思维到下一个思维的线性模式。值得注意的是,人类的推理过程往往与此不同,通常遵循多条不同的思维链,甚至将来自不同思维的见解结合起来,从而得出最终解决方案。在本概述中,我们将研究几种将推理过程建模为图结构的提示技术——而不是链式或树形结构——这种方式更好地捕捉到推理过程中的各种非线性模式。
“人类思维的特点之一是其能够在看似无关的想法之间进行突然的跳跃和连接,这往往能引发新的见解和解决方案。这种非线性的、跳跃性的思维过程是人类创造力、推理和问题解决能力的标志……
2024 年图形与几何机器学习:我们目前的状况与未来展望(第一部分 — 理论与架构)
最前沿技术摘要
根据往年传统,我们采访了一批杰出且高产的学术界和工业界专家,旨在总结过去一年的亮点并预测 2024 年的发展趋势。2023 年成果丰富,我们不得不将这篇文章分为两部分。这是第一部分,重点讨论理论与新架构,另请参阅关于应用的第二部分。
·发表于Towards Data Science ·阅读时间:30 分钟·2024 年 1 月 16 日
--

作者提供的图片,部分来自 DALL-E 3。
本文由 Michael Galkin 和 Michael Bronstein 撰写和编辑,并由 Johannes Brandstetter, İsmail İlkan Ceylan, Francesco Di Giovanni, Ben Finkelshtein, Kexin Huang, Chaitanya Joshi, Chen Lin, Christopher Morris, Mathilde Papillon, Liudmila Prokhorenkova, Bastian Rieck, David Ruhe, Hannes Stärk* 和* Petar Veličković.
图神经网络的理论
信息传递神经网络与图 Transformer
图组件、二连通性与平面性
聚合函数与统一表达能力
GNN 的收敛性与零一法则
GNN 的描述复杂度
GNN 的精细化表达能力
子图 GNN 的表达能力结果
链路预测与知识图谱的表达能力
过度压缩与表达能力
GNN 的泛化与外推能力
预测时间!
新型与异域信息传递
超越图形
拓扑
几何代数
偏微分方程
鲁棒性与可解释性
图 Transformer
新数据集与基准测试
会议、课程与社区
2023 年的迷因
我们将在整个文本中使用的图例:
💡 - 本年度亮点
🏋️ - 挑战
➡️ - 当前/下一步发展
🔮- 预测/猜测
图神经网络的理论
迈克尔·布朗斯坦(牛津大学)、弗朗切斯科·迪·乔瓦尼(牛津大学)、伊斯梅尔·伊尔坎·杰兰(牛津大学)、克里斯·莫里斯(亚琛工业大学)
信息传递神经网络与图 Transformer
图 Transformer 是图机器学习中一个相对较新的趋势,试图将 Transformer 在序列上的成功扩展到图形上。就传统的表达能力结果而言,这些架构并没有提供任何特别的优势。事实上,可以说,它们在表达能力方面的大多数优势(例如,见Kreuzer 等人)来自于强大的结构编码,而非架构本身,而这些编码原则上也可以与 MPNN 一起使用。
在最近的一篇论文中,蔡等人探讨了 MPNN 与(图形)Transformer 之间的联系,表明一个带有虚拟节点的 MPNN——一个以特定方式与所有其他节点连接的辅助节点——可以模拟一个(图形)Transformer。这种架构是非统一的,即神经网络的大小和结构可能依赖于输入图形的大小。有趣的是,一旦我们将注意力限制到线性 Transformer(例如,Performer),就会有一个统一的结果:存在一个使用虚拟节点的单一 MPNN,可以在任何大小的输入上逼近线性 Transformer,如 Performer。

蔡等人的图: (a) 带虚拟节点的 MPNN,(b) 一个 Transformer。
这与关于图变换器架构是否相较于 MPNN 在捕捉长程依赖方面具有优势的讨论相关。图变换器与 MPNN 进行比较,后者通过使用虚拟节点包含了一个全局计算组件,这是常见的做法。Cai et al. 实证表明,带有虚拟节点的 MPNN 能够在长程图基准(LRGB,Dwivedi et al.)上超越图变换器的表现。此外,Tönshoff et al. 在 LRGB 基准上重新评估了 MPNN 的基准,发现先前报告的图变换器性能差距被高估了,这是由于亚优的超参数选择,实质上缩小了 MPNN 和图变换器之间的差距。

Lim et al.中的图:SignNet 管道。
还广为人知,常见的拉普拉斯位置编码(例如,LapPE)对特征向量的符号和基变换不具有不变性。缺乏不变性使得获得(非均匀)普适性结果变得更容易,但这些模型因此不会计算图的不变性。这激发了今年一系列的研究,包括符号和基不变网络的研究(Lim et al., 2023a)以及符号等变网络的研究(Lim et al., 2023b)。这些发现表明,关于 MPNN 和图变换器比较的文献中常见的主张,仍需进行更多的理论研究。
图组件、双连通性和平面性

最初由 Zyqqh 在Wikipedia上发布的图。
Zhang et al. (2023a)将图的双连通性研究带入了图机器学习领域。由Zhang et al. (2023a)提出的许多结果与不同的双连通性度量相关。研究表明,标准的 MPNN 无法检测图的双连通性,而许多现有的高阶模型(即能够匹配 2-FWL 能力的模型)则可以。另一方面,具有某些距离编码和子图 GNN(如 ESAN)的 Graphormers 能够检测图的双连通性。

Dimitrov et al. (2023)中的图:左侧显示了图的分解(A-C),右侧显示了相关的编码器(D-F)和更新方程(G)。
Dimitrov 等人 (2023)依赖图分解来开发专门的架构,以便在平面图上进行学习。其思路是与经典的Hopcroft & Tarjan平面同构测试算法的变种对齐。Dimitrov 等人 (2023)首先将图分解为其双连通和三连通分量,然后学习节点、割节点、双连通分量和三连通分量的表示。这个过程通过使用可以在线性时间内计算的经典结构——块割树(Block-Cut Trees)和 SPQR 树(SPQR Trees)来实现。最终的框架被称为PlanE,包含如BasePlanE等架构。BasePlanE 计算同构完全图不变量,因此它能够区分任意一对平面图。该工作的主要贡献是设计了一种架构,能够高效地学习平面图的完整不变量,同时保持在实际应用中的可扩展性。值得注意的是,3-FWL 已知在平面图上是完全的(Kiefer 等人, 2019),但该算法并不可扩展。
聚合函数:统一表达性研究
人们广泛认为不同的聚合函数各有其适用之处,但这并没有被严格证明。事实上,在非统一设置下,使用 MLP 的求和聚合会产生一个单射映射,因此可以涵盖其他聚合函数(Xu 等人, 2020),这一结果建立在早期的研究基础上(Zaheer 等人, 2017)。而在统一设置下,情况则不同,这要求一个固定的模型能够在所有图上有效工作。Rosenbluth 等人 (2023)表明,在统一设置中,求和聚合并不总是能够涵盖其他聚合方法。例如,如果我们考虑一个无限的特征域,求和聚合网络甚至无法近似均值聚合网络。有趣的是,即使是在一些正面的结果中,其中求和聚合被证明可以近似其他聚合方法,所呈现的构造通常也需要大量的层(随着近似误差的倒数增长)。
GNN 在随机图上的收敛性和零一法则
GNN(图神经网络)理论上可以在训练后应用于任意大小的图。这使得对输入图的大小进行渐进分析变得非常有吸引力。此前对 GNN 渐进行为的研究,集中在它们对理论极限网络的收敛性(Keriven 等人, 2020)以及在大图扰动下的稳定性(Levie 等人, 2021)上。
在最近的一项研究中,Adam-Day 等人 (2023) 证明了二进制 GNN 分类器的零一法则。要解决的问题是:当我们绘制具有随机节点特征的 Erdos-Rényi 图,且图的大小逐渐增大时,二进制 GNN 分类器的表现如何?主要发现是,这些图通过一类 GNN 分类器映射到特定输出的概率最终趋向于零或一。也就是说,模型最终将所有图映射到零或所有图映射到一。这个结果适用于 GCN 以及具有求和和均值聚合的 GNN。
这一结果的主要意义在于它为 GNNs 的表达能力建立了一个新的统一上界:任何可以由这些 GNN 架构统一表达的图的性质必须遵守零一法则。一个简单的性质示例是拥有偶数个节点,这个性质在渐近意义上既不趋向零也不趋向一。
GNNs 的描述性复杂度
Grohe (2023) 最近分析了 GNNs 在布尔电路复杂度方面的描述性复杂度。具体的电路复杂度类别是 TC0。该类别包含所有通过具有恒定深度和多项式大小的布尔电路决定的语言,这些电路仅使用 AND、OR、NOT 和阈值 (或多数)门。Grohe (2023) 证明了可以由一类多项式大小有界深度的 GNN 家族计算的图函数属于电路复杂度类别 TC0。此外,如果允许 GNN 类使用随机节点初始化和全局读取,如 Abboud el al. (2020) 中所述,则存在一个匹配的下界,即它们可以计算出恰好可以用 TC0 表达的相同函数。这为具有随机节点特征的 GNN 的能力建立了上界,要求模型类别具有有界深度(固定层数)且大小为多项式。虽然这一结果仍然是非统一的,但它改善了 Abboud el al. (2020) 的结果,在该结果中,构造可能是最坏情况下的指数级。
GNNs 的细粒度表达性研究
最近的许多研究分析了 MPNNs 的表达力,主要利用组合技术,如用于图同构问题的 1-WL。然而,图同构目标本质上是二元的,无法深入了解两个给定图之间的相似度程度。Böker et al. (2023)通过推导 1-WL 和 MPNNs 的连续扩展到图论函数,解决了这个问题。具体而言,他们展示了 1-WL 的连续变体能够准确地对 MPNNs 在图论函数上的表达力进行拓扑特征化,揭示了这些网络能够区分哪些图,以及分离它们的难度级别。他们提供了一个图和图论函数相似度的理论框架,结合了 1-WL 经典特征化的各种拓扑变体。特别是,他们通过树距离来表征 MPNNs 的表达力,树距离是一种基于分数同构概念的图距离,以及通过树同态的子结构计数,表明这些概念与 1-WL 和 MPNNs 在图论函数上的表达力相同。有趣的是,他们还通过展示随机初始化的 MPNNs(未经过训练)在与经过训练的对应模型相比时具有竞争力的性能,从而验证了他们的理论发现。
子图 GNNs 的表达性结果
基于子图的图神经网络(GNNs)在 2022 年已经成为一个大趋势 (Bevilacqua et al., 2022, Qian et al., 2022)。今年,Zhang et al. (2023b)为这种架构建立了更精细的表达力结果。论文通过所谓的子图 Weisfeiler-Leman 测试(SWL)研究了子图 GNNs。通过这一方法,他们展示了 SWL 的完整层次结构,并且表达力严格增长。具体而言,他们为 SWL 类型的算法定义了等价类,并表明几乎所有现有的子图 GNNs 都属于其中之一。此外,所谓的 SSWL 达到了最大的表达力。令人感兴趣的是,他们还将 SWL 与几个现有的表达性 GNN 架构联系了起来。例如,他们展示了 SWL 具有与 2-WL 的局部版本相同的表达力 (Morris et al., 2020)。除了理论之外,他们还表明 SWL 类型的架构在实际应用中取得了良好的实证结果。
用于知识图谱链接预测的架构表达力
用于知识图谱链接预测的架构,如 RGCN 和 CompGCN 的表达力已被Barceló et al. (2022)研究。今年,Huang et al. (2023)将这些结果推广到表征各种其他模型架构的表达力。

来自 Huang et al. (2023) 的图:该图比较了 R-MPNNs 和 C-MPNNs 在操作模式上的不同。
Huang et al. (2023) 提出了条件消息传递网络(C-MPNNs)的框架,其中包括如 NBFNets 这样的架构。经典的关系消息传递网络(R-MPNNs)是单一编码器(即,编码图节点),并依赖于二元解码器进行链路预测任务(Zhang, 2021)。另一方面,C-MPNNs 作为二元编码器(即,编码图节点对),因此更适合于链路预测的二元任务。C-MPNNs 被证明与关系 Weisfeiler-Leman 算法对齐,可以看作是 2WL 的局部近似。这些发现解释了 NBFNets 等在性能上优于,例如,RGCNs。Huang et al. (2023) 还展示了 C-MPNNs 捕捉的二元函数类的精确逻辑特征,从而得出了统一的表达能力结果。
过度压缩与表达能力
Over-squashing 是 Alon & Yahav 在 2021 年首次描述的一种现象,指的是指数增长的感受野被压缩为固定大小的向量。后续研究(Topping et al., 2022, Di Giovanni et al., 2023, Black et al., 2023, Nguyen et al., 2023)通过灵敏度分析对过度压缩进行了表征,证明了输出特征对早期层的隐藏表示的依赖性会因拓扑性质(如负曲率或长时间通勤)而受到影响。由于图的拓扑在瓶颈形成中起着关键作用,图重连线(graph rewiring)作为一种范式转变,提升了图连接性在 GNN 设计中的作用,被提出作为缓解过度压缩的关键策略(如果你感兴趣,参见下面关于异构消息传递的部分)。

对于给定的图,MPNN 为节点(v,u)和(u,w)学习到更强的混合(紧弹簧),因为它们的通勤时间较短,而节点(u,q)和(u,z)由于较长的通勤时间,具有较弱的混合(松弛弹簧)。来源:Di Giovanni et al., 2023
过度挤压是限制表达能力的障碍,因为它导致 GNNs 在处理具有长程交互的任务时表现不佳。为了正式研究这一问题,Di Giovanni 等人,2023提出了一种新的表达能力度量,称为“混合”,它编码了图函数对节点特征对的联合和非线性依赖关系:为了使 GNN 能够近似具有较大混合的函数,必要的条件是允许相关节点之间进行“强”的信息交换。因此,他们假设通过 GNN 预测的混合来衡量过度挤压,并证明了 GNN 需要的深度来引入足够的混合,如任务所需,随着通勤时间的增长——通常比最短路径距离要差得多。结果显示,过度挤压如何妨碍具有“实际”大小的 GNN 的表达能力,并验证了它源于任务(要求节点 i 和 j 之间的强混合)与拓扑(导致 i 和 j 之间的长通勤时间)之间的不匹配。
函数的“混合”涉及节点之间信息的交换,无论这些信息是什么,而不是其分离节点表示的能力。事实上,这些结果 同样适用于比 1-WL 测试更强大的 GNNs。Di Giovanni 等人(2023)的分析提供了一种替代方法来研究 GNNs 的表达能力,该方法可以轻松扩展到 3D 空间中的等变 GNNs 以及它们建模节点之间交互的能力。
GNN 的泛化和外推能力
近年来,MPNNs(消息传递神经网络)的表达能力因其与 WL 测试的关联而引起了广泛关注。尽管这种关联促成了在理解和提升 MPNNs 表达能力方面的显著进展(Morris 等,2023a),但它并未提供关于其泛化性能的见解,即它们在训练集之外进行有意义预测的能力。令人惊讶的是,只有少数几项重要的研究探讨了 MPNNs 的泛化行为,例如,Garg 等(2020),Kriege 等(2018),Liao 等(2021),Maskey 等(2022),Scarselli 等(2018)。然而,这些方法仅使用经典的图参数来表示 MPNNs 的泛化能力,例如最大度数、顶点数或边数,这些方法无法完全捕捉现实世界图的复杂结构。此外,大多数方法在非均匀区域研究泛化,即假设 MPNNs 在预定阶数的图上进行操作。

来自Morris 等人(2023b)的图:MPNNs 的泛化能力概览及其与 1-WL 的联系。
因此,莫里斯等人 (2023b)展示了 1-WL 的表达能力与泛化性能之间的紧密联系。他们通过紧密连接 1-WL 的表达能力和 MPNNs 的 Vapnik–Chervonenkis (VC)维度,研究了图结构和参数编码长度对 MPNNs 泛化能力的影响。为此,他们展示了几个结果。
1️⃣ 首先,在非均匀状态下,他们展示了 MPNNs 的 VC 维度与 1-WL 在一组图上计算出的等价类数量紧密相关。此外,他们的结果可以轻松推广到 k-WL 和许多最近的表达性 MPNN 扩展。
2️⃣ 在均匀状态下,即当图可以具有任意顺序时,他们展示了 MPNNs 的 VC 维度受其权重的最大比特长度的上下界限制。在均匀和非均匀状态下,MPNNs 的 VC 维度对 1-WL 计算的颜色数呈对数关系,并且对参数的数量呈多项式关系。此外,他们还通过实证研究表明,他们的理论发现一定程度上在实践中成立。
🔮 预测时间!
克里斯托弗·莫里斯 (亚琛工业大学)
“我相信,GNNs 泛化理论亟需更好且更实用的理论。” — 克里斯托弗·莫里斯 (亚琛工业大学)
➡️ 例如,我们需要理解图结构和各种架构参数如何影响泛化。此外,当前对于训练 GNNs 的 SGD 动态研究较少,理解也不够透彻,更多的研究将会探讨这一问题。
伊斯梅尔·伊尔坎·杰兰 (牛津大学)
“我希望在均匀设置下看到更多的表达能力研究,在这种情况下我们固定神经网络的参数并检查其能力。” — 伊斯梅尔·伊尔坎·杰兰 (牛津大学)
➡️ 在这种情况下,我们可以识别出更好的泛化联系,因为如果某个属性无法均匀表达,那么模型就无法对更大的图大小进行泛化。
➡️ 今年,我们也可能会看到针对图回归或图生成的表达能力研究,这些领域仍然未被充分探索。人们有充分的理由期望出现在更大的图类别上具有同构完备性的学习算法,从而严格地推广平面图的结果。
➡️ 现在也是时候为学习完全关联数据(即知识超图)发展一套理论了,这将解锁关系数据库中的应用!
弗朗切斯科·迪·乔瓦尼 (牛津大学)
在 GNNs 的未来理论发展方面,我看到两个值得关注的方向。
“我们对 GNN 在梯度流(或 SGD)下的权重动态几乎没有理解;评估图拓扑对权重演变的影响是解决泛化和任务难度问题的关键。” — 弗朗切斯科·迪·乔瓦尼(牛津大学)
➡️ 其次,我认为发展表现力的替代范式会很有价值,这些范式更直接地关注近似能力(图函数及其导数的近似能力),并精确地识别出哪些任务难以学习。后一方向在表征三维空间中等变 GNN 的能力时可能特别有意义,在这种情况下,表现力的度量可能需要与二维情况解耦,以便更好地与来自科学领域的任务对接。
结尾:关于 WL 在 2023 年去向的一个有趣事实

人物肖像:伊戈尔·戈尔斯基
2023 年预测
(1) 在创建时间和内存高效的子图 GNN 方面投入更多努力
❌ 并不完全是
(2) 更好地理解 GNN 的泛化能力
✅ 是的,见关于过度压缩和泛化的子章节
(3) 费斯费尔和利曼访问了 10 个新地方!
新颖且独特的消息传递
本·芬克尔施坦(牛津大学),弗朗切斯科·迪·乔瓦尼(牛津大学),佩塔尔·维利奇科维奇(谷歌 DeepMind)
佩塔尔·维利奇科维奇(谷歌 DeepMind)
多年来,关于消息传递算子的开发已经饱和,这已经成为一种常见的民间传说。我认为 2023 年取得的进展特别令人兴奋,因为来自几个独立研究小组(包括我们自己)的成果表明,一个统一的全新方向已经出现:让我们开始考虑时间在 GNN 中的影响 ⏳。
“我预测,在 2024 年,时间将在新型 GNN 架构的发展中扮演核心角色。” — 佩塔尔·维利奇科维奇(谷歌 DeepMind)
💡 在 GNN 设计中,时间已经在显式提供的输入中得到了利用(如时空图或完全动态图)。今年,它开始出现在关于静态图输入的 GNN 操作符研究中。一些研究开始不再假设一个统一的、同步的时钟 ⏱️,该时钟迫使所有信息在同一层中同时发送和接收。
1️⃣ 第一篇此类工作,GwAC 🥑,仅仅玩弄了基础的随机化消息调度,但提供了证明,说明为什么这种处理方式可能在表达能力上带来显著提升。Co-GNNs 🤝 进一步推动了这一进程,展示了一种更加精细调整的消息调度机制,它以节点为中心,允许每个节点选择何时发送 📨 或接收 📬 消息。Co-GNNs 还提供了一种通过梯度下降训练这种调度器的实用方法。虽然开发这种异步 GNN 模型是非常期望的,但我们也必须承认与之相关的可扩展性问题——我们当前的前沿硬件并不设计为能够高效扩展此类顺序系统。
2️⃣ 在我们关于异步算法对齐的研究中,我们选择设计一个同步 GNN,但限制其消息传递、聚合和更新函数,使得即使其数据流的部分变为异步,GNN 也能产生相同的嵌入。这带领我们进入了一个激动人心的旅程,涉及到单群、1-余切和范畴理论,最终得出了一个可扩展的 GNN 模型,在许多 CLRS-30 任务上表现出色。

一个异步 GNN 的可能执行轨迹。传统的 GNN 同步发送和接收所有消息,而在我们的框架下,在任何步骤,GNN 可以选择执行任意数量的可能操作(这里通过图的右侧集合来表示)。来源:Dudzik et al.
➡️ 最后,值得注意的是,对于某些特定的消息调度选择,我们不需要修改同步 GNN 的架构——而可以转而使用动态图重连接。DREW 和Half-Hop是两篇在 ICML'23 上同时发表的论文,它们体现了使用图重连接来减缓消息传递🐌的原理。在 DREW 中,每个节点的消息实际上会发送到每一个其他节点,但需要k层才能到达一个与其相隔k跳的邻居!而 Half-Hop 则采取更宽松的方式,随机决定是否引入一个“慢节点”,从而延长任何通过边连接的两个节点之间的路径。两种方法都能自然缓解过平滑问题,因为消息传递的距离越远,过平滑的程度越小。
无论是用于消息传递设计、GNN 数据流还是图重连接,到了 2023 年,我们才刚刚开始理解时间的重要性——即便时间变化在我们的数据集中并未显式存在。
Ben Finkelshtein (牛津大学)
在Co-GNNs中提出的基于时间的消息传递范式是一种可学习的消息传递泛化方式,它允许每个节点决定如何从其邻居传递信息或向其邻居传播信息,从而实现信息流动的更大灵活性。这些节点被视为可以进行以下操作的参与者:向监听的邻居广播并且监听广播的邻居(类似于经典的消息传递),仅向监听的邻居广播,或者隔离(既不监听也不广播)。
这些操作之间的相互作用以及局部和动态地改变它们的能力使 CoGNN 能够确定一个任务特定的计算图(可以视为一种动态和定向重连接的形式),为具有不同节点特征的两个节点学习不同的操作分布(包括特征和结构基础的)。CoGNN 还允许节点之间进行异步更新,并且以较高的概率生成唯一的节点标识符,从而使它们能够区分任何一对图(比 1-WL 更具表现力,但仅在期望下保持等变性)。

从左到右:经典的 MPNN(所有节点都广播和监听)、DeepSets(所有节点都隔离)以及通用的 CoGNN。图源自博客文章。
查看 Medium 文章了解更多详情:
一种新的消息传递范式,其中每个节点可以选择“监听”、“广播”、“监听并广播”或...
[towardsdatascience.com
Francesco Di Giovanni(牛津大学)
“过度压缩(over-squashing)的理解,尤其是当任务依赖于节点间长时间交互时,成为了图重连接作为一种有效方法的催化剂,用以设计新的 GNN。” — Francesco Di Giovanni(牛津大学)
️💡 图结构重连接广义上指的是改变输入图的连接方式,以促进下游任务的解决。最近,这通常聚焦于图中的瓶颈,从而添加(或移除)边缘,以改善信息流动。虽然重点通常放在信息交换的地点,但最近的研究(如上所述)也揭示了何时进行信息交换的重要性。这些方法背后的一个逻辑,尽管通常是隐含的,就是 GNN 各层构建的隐藏表示为图提供了一个(人为的)动态组件,即便图和输入特征是静态的。这个视角可以通过多种方式进行利用。

在经典的 MPNN 设置中,在每一层,信息仅在节点及其直接邻居之间传播。在 DRew 中,图会随着层数的变化而变化,新添加的边从第 r 层起连接距离 r 的节点。最后,在νDRew 中,我们还引入了一个延迟机制,相当于基于节点之间的相互距离的跳跃连接。来源:Gutteridge 等人
➡️ 一个特别采纳这一角度的框架是DRew,它通过两种方式扩展了任何消息传递模型:(i)它直接连接距离 r 的节点,但仅从第 r 层开始;(ii)当节点连接时,基于它们之间的相互距离,应用延迟到它们的消息交换。如上图所示,(i)使网络更好地保留归纳偏差,因为较近的节点会更早互动;(ii)则充当距离感知的 跳跃连接,从而有助于损失函数的梯度传播。很可能正是由于这个原因,而不是防止过度平滑(对图级任务几乎没有影响),该框架显著增强了标准 GNN 在更大深度下的表现(更多细节可以参考这篇博客文章)。
🔮 预测: 我相信,扩展消息传递到“时间”维度的深远影响将在来年开始显现。像 DRew 这样的工作仅仅触及了时间上重连(超越空间)的原因,它可能有助于 GNN 训练,显著影响它们在不同深度层次下的准确性响应。
➡️ 更广泛地说,我希望图重连的理论和实践发展能够被应用到科学领域,其中等变图神经网络(GNNs)通常应用于三维问题,这些问题要么没有自然的图结构(使得“在哪里”交换消息变得更加相关),要么(且)展现出自然的时间(多尺度)特性(使得“何时”交换消息可能是减少内存限制并保持正确归纳偏差的关键)。
几何学、拓扑学、几何代数与偏微分方程
Johannes Brandstetter(JKU Linz),Michael Galkin(Intel),Mathilde Papillon(UC Santa Barbara),Bastian Rieck(Helmholtz & TUM),David Ruhe(U Amsterdam)
2023 年带来了最全面的几何图神经网络(Geometric GNNs)介绍(及综述),涵盖了最基础和必要的概念,并提供了一些示例:《三维原子系统的几何图神经网络指南》(Duval, Mathis, Joshi, Schmidt 等)。如果你曾经想从零开始学习驱动最近蛋白质设计、材料发现、分子模拟等突破的图机器学习核心架构——这正是你需要的!

3D 原子系统的关键几何 GNN 时间线,按层内中间表示的类型进行分类。来源:Duval, Mathis, Joshi, Schmidt 等人
拓扑
💡 得益于PyT 团队的卓越努力和他们的一系列资源,2023 年研究人员和实践者在处理拓扑结构时变得更加容易:TopoNetX、TopoModelX 和 TopoEmbedX。TopoNetX基本上是拓扑数据的 networkx。TopoNetX 支持标准结构,如细胞复形、单纯形复形和组合复形。TopoModelX是一个类似 PyG 的库,用于拓扑数据上的深度学习,并实现了著名的模型如MPSN和CIN,并提供了一个简洁统一的接口(原始的 PyG 实现比较复杂)。TopoEmbedX有助于在拓扑数据上训练嵌入模型,并支持像Cell2Vec这样的核心算法。

领域:蓝色节点,粉色(超)边,深红色面。来源:TopoNetX,Papillon 等人
💡 该领域的一个重要起步和基本构建块是Hajij 等人和Papillon 等人的论文。一个显著的模型部分是由数据科学中的拓扑、代数与几何(TAG)社区的成员实现的,该社区定期在机器学习会议上组织拓扑学研讨会。
Mathilde Papillon (UCSB)
“直到 2023 年,拓扑深度学习领域仍然呈现出一个支离破碎的关系数据增强表示的格局。” —— Mathilde Papillon(加州大学圣塔芭芭拉分校)
➡️ 消息传递模型仅基于相同领域的其他模型进行构建和基准测试,例如,单纯形复形社区与超图社区之间一直是封闭的。更糟糕的是,大多数模型采用了独特的数学符号。这使得选择最适合特定应用的模型变得像一项艰巨的任务。Hajij 等人提出的统一理论提供了一个通用框架,在这个框架下,所有模型都可以被系统地描述和分类。我们将该理论应用于文献中,制作了一份全面而简明的拓扑深度学习中消息传递的调查,这也是该领域的易于理解的入门介绍。此外,我们还提供了一个字典,列出了所有模型架构,并采用统一的符号表示。
➡️ 为了进一步统一这一领域,我们组织了第一次拓扑深度学习挑战赛,该活动在2023 ICML TAG 研讨会上举办,并通过Papillon 等人的白皮书进行了记录。我们的目标是通过众包的方式促进可重复研究,尤其是推动基于拓扑领域的神经网络开源实现。作为挑战的一部分,来自世界各地的参与者贡献了在TopoModelX中实现的现有拓扑深度学习模型。每个提交都经过严格的单元测试,并包括在TopoNetX加载的数据集上的基准训练。我们希望,这一套始终如一地实现的模型能够帮助从业者在 2024 年测试拓扑方法在新应用和开发中的表现。
Bastian Rieck(赫尔姆霍兹研究所 & 慕尼黑工业大学)
2023 年对于拓扑驱动的机器学习方法来说是激动人心的一年。一方面,我们看到了更多与几何概念(如曲率)相结合的应用,展示了几何-拓扑混合模型的多样性。例如,在《用于图生成模型评估的曲率过滤》中,我们展示了如何使用曲率来选择合适的图生成模型。在这里,曲率作为一种“透镜”,用来提取图结构信息,而我们则使用持续同调(一种拓扑方法)以一致的方式比较这些信息。

通过离散曲率评估图生成模型的管道概述。边的排序会产生一个曲率过滤,接着是相应的持久性图和地形图。对于图生成模型,我们选择一个曲率,逐元素应用这个框架,并通过比较生成分布和参考分布的平均地形图来评估它们的相似性。来源:Southern, Wayland, Bronstein, 和 Rieck.
➡️ 另一个方向强调拓扑驱动方法正成为图学习研究的主流,利用拓扑来评估图神经网络模型的表达能力。有时,正如 2023 年 NeurIPS 上Immonen 等人的一个非常有趣的工作所展示的那样,这甚至会导致新的模型,它们同时利用图的几何和拓扑特性!我的研究也旨在通过专门分析持久同调在图学习中的表达能力为这一领域做出贡献。
“2023 年也正是一个转折点,标志着我们即将摆脱——或者说超越——持久同调。尽管持久同调被公认为拓扑驱动的机器学习范式算法,但代数拓扑和微分拓扑提供了更为丰富的结构,可以用来分析数据。”——巴斯蒂安·里克(Helmholtz & TUM)
➡️ 与我的优秀合作伙伴们,我们最近开始探索一些替代方法,并提出了神经微分形式的概念。微分形式允许我们通过德拉姆同调优雅地建立几何和拓扑之间的桥梁——德拉姆同调是将某些对象(微分形式)的积分——即一种本质上几何的操作——与输入数据的拓扑特征联系起来的方法。通过一些额外的构造,德拉姆同调使我们能够学习图(或更高阶的组合复形)的几何描述,并在不依赖信息传递的情况下解决学习任务。其结果是参数更少的模型,可能在解决这些任务时更加高效。这里还有更多内容,毕竟我们才刚刚开始探索!
🔮 我对 2024 年的希望预测是,我们将:
1️⃣ 看到更多来自代数和微分拓扑的多样化工具应用于图和组合复形,
2️⃣ 更好地理解高阶输入数据上的信息传递,并且
3️⃣ 最终获得更好的并行算法,用于持久同调,真正释放其在深度学习中的潜力。Torras-Casas 最近关于谱序列的论文报告了一些非常令人兴奋的结果,展示了该技术的巨大前景。
几何代数
约翰内斯·布兰德斯特特(JKU Linz)和大卫·鲁赫(U Amsterdam)
“在 2023 年,我们见证了几何代数(也叫Clifford 代数)深度学习子领域的蓬勃发展。之前,神经网络层通过 Clifford 代数多向量的运算形式被提出,由Brandstetter et al.介绍。今年,‘几何’这一概念在‘几何代数’中得到了明确的应用。”——Johannes Brandstetter(JKU Linz)和 David Ruhe(U Amsterdam)
➡️ 首先,Ruhe et al.通过引入几何 Clifford 代数网络(GCAN),应用了现代(基于平面的)几何代数的精髓,这些神经网络模板模拟由各种几何代数描述的对称变换。我们看到由Pepe et al.在CGAPoseNet中应用这一方法,构建了一个几何感知的相机姿态回归管道。接下来,Ruhe et al.引入了Clifford 群等变神经网络(CGENN),通过 Clifford 群构建可调的 O(n)-和 E(n)-等变(图)神经网络,适用于任意维度。Pepe et al.将 CGENNs 应用于蛋白质结构预测(PSP)管道,将预测精度提高了最多 2.1%。

CGENNs(用ϕ表示)能够以 O(n)-或 E(n)-等变方式作用于多向量(Clifford 代数的元素)。具体来说,当 Clifford 群的作用ρ(w),代表像旋转这样的正交变换,作用于数据时,模型的表示会进行共同旋转。多向量可以分解为标量、向量、双向量、三向量甚至更高阶的分量。这些元素可以表示几何量,如(定向的)面积或体积。作用ρ(w)的设计是为了在作用于这些结构时保持这些结构的完整性。来源:Ruhe et al.
➡️ 恰巧,Brehmer et al.提出了几何代数变换器(GATr),一种可扩展的 Transformer 架构,利用投影几何代数提供的表示优势和 Transformer 的可扩展性,构建 E(3)-等变架构。GATr 架构由Haan et al.扩展到其他代数,并且还探讨了哪种几何代数最适合您的 E(3)-等变机器学习问题。

GATr 架构概述。带实线的框表示可学习组件,带虚线的框表示固定组件。来源:Brehmer et al.
🔮 在 2024 年,我们可以期待这些进展带来令人兴奋的新应用。一些例子包括以下内容。
1️⃣ 我们可以期待探索其在分子数据、药物设计、神经物理仿真、晶体等领域的应用。其他几何感知的应用包括 3D 渲染、姿态估计,以及例如机器人手臂的规划。
2️⃣ 我们可以期待将基于几何代数的网络扩展到其他神经网络架构中,如卷积神经网络。
3️⃣ 接下来,CGENN 的通用性使得可以在其他维度中进行探索,例如 2D,但也可以在需要处理多维数据的环境中进行探索。此外,它们还支持非欧几里得几何,这在相对论物理学中有多个应用场景。
4️⃣ 最后,GATr 和 CGENN 可以扩展并应用于投影几何、保角几何、双曲几何或椭圆几何。
偏微分方程(PDEs)
Johannes Brandstetter(JKU Linz)
关于神经偏微分方程建模的现状,哪些话题在 2023 年有所浮现或获得了更多关注?
1️⃣ 首先,有一个明显的趋势是,开始在复杂几何上以及几何内部建模 PDEs,这需要基于网格的空间离散化。这与解决越来越逼真的现实世界问题的总体目标相一致。例如,Li et al提出了用于大规模 3D PDE 的几何感知神经算子(GINO)。
2️⃣ 其次,基于拉格朗日模拟的神经网络代理的发展变得越来越引人注目。拉格朗日空间离散化使用有限的物质点,这些物质点作为流体团块在时空中进行追踪。最著名的拉格朗日离散化方案是平滑粒子流体动力学(SPH),这是由Toshev et al.提供的LagrangeBench基准数据集中使用的数值基线。

我们数据集的时间快照,在初始时刻(顶部)、40%(中部)和 95%(底部)的轨迹上。颜色温度表示速度大小。 (a) 泰勒-格林涡旋(2D 和 3D),(b) 反向泊松流(2D 和 3D),(c) 驱动腔体(2D 和 3D),(d) 水坝溃坝(2D)。来源:LagrangeBench,由Toshev et al.提供
3️⃣ 其次,基于扩散的建模在偏微分方程(PDE)领域也没有停下脚步。我们大致看到两个方向。第一个方向将扩散过程的迭代性质转化为对从噪声初始化并根据前一步骤进行条件化的候选状态的精炼。这种迭代精炼在PDE-Refiner(Lippe 等人)中有所介绍,而其变体已经在GenCast(Price 等人)中得到应用。第二个方向利用扩散模型的概率性质来模拟诸如三维湍流等混沌现象。关于这一点的例子可以在湍流流动仿真(Kohl 等人)和从零到湍流(Lienen 等人)中找到。特别是对于三维湍流,未来有很多有趣的事情将会发生。
“天气建模在过去几个月取得了巨大成功。未来可能会有更多令人激动的进展,特别是在直接从观测数据进行天气预报或构建天气基础模型方面。” — Johannes Brandstetter(JKU Linz)
🔮 2024 年的预测:
1️⃣ 更多关于三维湍流建模的研究。
2️⃣ 偏微分方程的多模态方面可能会出现。这可能包括结合不同的偏微分方程、不同的分辨率或不同的离散化方案。我们已经在例如多物理预训练物理代理模型中看到了一些相关的迹象,这是 McCabe 等人的研究。
2023 年预测
(1) 神经网络偏微分方程及其应用可能会扩展到更多与物理相关的 AI4Science 子领域;计算流体力学(CFD)可能会受到图神经网络(GNN)的影响。
✅ 我们已经看到了三维湍流建模、几何感知神经算子、基于粒子的神经代理模型,并且在例如天气预报等领域产生了巨大影响。
(2) 基于图神经网络的代理模型可能会增强或取代传统的成熟技术。
✅ 天气预报已经成为一个巨大的成功案例。基于神经网络的天气预报超过了传统的预报(中期+局部预报),例如,GraphCast(Lam 等人)和MetNet-3(Andrychowicz 等人)。
稳健性和可解释性
黄克欣(斯坦福大学)
“随着图神经网络(GNN)在各个领域的部署,其可靠性和稳健性变得越来越重要,尤其是在安全关键的应用中(例如科学发现),因为错误的代价可能非常高。” — 黄克欣(斯坦福大学)
1️⃣ 在讨论图神经网络(GNN)的可靠性时,一个关键标准是不确定性量化——量化模型对预测的了解程度。关于估计和校准不确定性的研究很多,其中也有专门针对 GNN 的工作(例如,GATS)。然而,这些方法在理论和实证上都未能达到预定义的目标覆盖率(即预测集内的点数百分比)。我想强调的是,拥有覆盖保证这一概念至关重要,尤其是在科学发现的机器学习部署中,因为实践者通常信任具有统计保证的模型。
2️⃣ 符合性预测是统计学中的一个令人兴奋的方向,它具有有限样本覆盖保证,且已应用于许多领域,如视觉和自然语言处理。但尚不清楚它是否可以在图中理论上应用,因为在图的设置下,可交换性假设是否成立并不显而易见。2023 年,我们看到符合性预测已扩展到图结构数据。特别是,CF-GNN和DAPS已推导出传导节点级预测设置中符合性有效性的理论条件,并且还开发了减少预测集大小的方法,以便高效地进行下游使用。最近,我们还看到符合性预测扩展到链接预测、非均匀划分、边可交换性,以及针对不满足可交换性的设置(如归纳设置)的适应。

针对图结构数据的符合性预测。 (1) 一个基础的 GNN 模型(GNN),它为节点 i 生成预测得分µ。 (2) 符合性校正。由于训练步骤并不考虑符合性校正步骤,因此预测集/区间的大小/长度(即效率)没有得到优化。我们使用一个拓扑感知的校正模型,将µ作为输入节点特征,并汇聚其局部子图的信息,以生成更新后的预测µ˜。 (3) 符合性预测。我们证明,在一个传导随机划分的设置下,给定置换不变性,图的可交换性是成立的。因此,标准的符合性预测(CP)可以基于µ˜生成预测集/区间,且该预测集包含具有预设覆盖率 1-α的真实标签。来源:黄等
🔮 展望未来,我们预计会有更多扩展,覆盖更广泛的 GNN 部署用例。总体而言,我认为为 GNN 提供统计保证非常好,因为它使得实践者能够信任 GNN 的预测结果。
图神经网络变换器(Graph Transformers)
陈琳(牛津大学)
💡 2023 年,我们见证了图神经网络变换器的持续崛起。它已经成为常见的 GNN 设计,例如在GATr中,作者将其流行归因于其*“良好的扩展性、表达能力、可训练性和多功能性”*。
1️⃣ **GT 的表达能力。**正如在 GNN 理论部分提到的,最近Cai 等人(2023)的工作显示,带有虚拟节点的 MPNN 和 GT 在非均匀设置下是等价的。这提出了一个问题:GT 到底有多强大,它们的表达能力来源于哪里?Zhang 等人(2023)成功地将一种新的强大位置嵌入(PE)结合起来,以提高 GT 的表达能力,在双连通性问题上取得了更好的表现。这为 PE 对 GT 表达能力的重要性提供了证据。最近的提交GPNN更清楚地展示了位置编码在其中的核心作用。已经证明,人们可以推广Zhang 等人(2023)中的证明,展示 GT 的表达能力是如何由各种位置编码决定的。
2️⃣ **位置(结构)编码。**鉴于 PE/SE 对图神经网络(GTs)的重要性,接下来我们将讨论那些通常来源于现有图不变量的富有表现力的特征的设计。2022 年,GraphGPS通过将 GT 与各种(甚至多个)PE/SE 结合,取得了巨大的实证成功。到 2023 年,出现了更强大的 PE/SE。
**相对随机游走结构编码(RRWP)**由Ma 等人提出,通过引入关系部分扩展了随机游走结构编码。结合一种新的注意力机制变体,GRIT在属性预测基准测试中相比现有的 PE/SE 表现出了强大的实证性能(在 ZINC 上达到了 SOTA)。理论上,RRWP 可以通过特定的参数选择来逼近最短路径距离、个性化的 PageRank 和热核。使用 RRWP 后,GRIT 比 SPD-WL 具有更强的表达能力。

fluorescein 分子的 RRWP 可视化,最高到 4 次方。较粗且较深的边表示更高的边权重。较长的随机游走的概率揭示了高阶结构(例如,在 3-RW 中显现的团和在 4-RW 中出现的星型图案)。来源:Ma 等人。
Puny 等人提出了一种基于等变多项式的新表达力理论框架,通过在等变基上计算张量收缩得到的多项式特征作为位置编码,可以提高常见 GNN 的表达力。实验结果令人惊讶:GatedGCNs 的测试 MAE 从 0.265 提高到 0.106,采用了 d-表达多项式。未来有人将这一方法与 GT 结合将会非常有趣。
3️⃣ 高效的 GTs。 由于 GT 的 O(N²)复杂度,GT 在大图上的应用依然充满挑战。在 2023 年,我们看到了更多尝试通过降低 GT 计算复杂度来解决这一问题的研究。Deac 等人使用了扩展图进行传播,该图通常用较少的边进行连接。Exphormer将这一思想扩展到 GT,通过将扩展图与局部邻域聚合和虚拟节点结合起来。Exphormer 使图转换器能够扩展到更大的图(例如,具有 169K 节点的ogbn-arxiv)。它还取得了强大的实验结果,并在多个长距离图基准测试任务中排名靠前。
🔮 迈向 2024 年:
更好地理解自注意力在抽象层面上对表现力之外的好处。
2024 年,开源大规模预训练的等变 GT!
更强大的位置编码。
新的数据集与基准测试
结构生物学: 来自 VantAI 的 Pinder,来自牛津的PoseBusters,来自 The Other Place 的PoseCheck,DockGen,以及来自Valence Labs的 LargeMix 和 UltraLarge 数据集
时序图基准测试(TGB):到目前为止,时序图学习的进展一直受到缺乏大型高质量数据集以及缺乏适当评估的限制,这导致了过于乐观的性能评估。TGB 通过引入七个现实的大规模且多样化的基准测试,解决了这一问题,这些基准包括节点级和链接级任务。受到 OGB 成功的启发,TGB 自动化了数据集下载和处理,以及评估协议,并允许用户使用排行榜比较模型性能。更多详情请查看相关博客文章。
来自 Google Research 的TpuGraphs:TPU 计算图的图属性预测数据集。该数据集提供的图数量是现有最大图属性预测数据集的 25 倍(图大小相当),且与现有机器学习程序性能预测数据集相比,图的平均规模大约大 770 倍。Google 基于 TpuGraphs 举办了Kaggle 比赛!
LagrangeBench:拉格朗日流体力学基准测试套件——在一个基于 JAX 的环境中评估你喜爱的基于 GNN 的模拟器(适合 JAX 爱好者)
RelBench:来自斯坦福和 Kumo.AI 的关系深度学习基准:对关系数据库进行基于时间的预测(你可以将其建模为图或超图)
GNoMe 数据集来自 Google DeepMind:381k 种新型稳定材料,用于材料发现和机器学习潜力模型!
会议、课程与社区
图形和几何学习领域的主要事件(除了大型 ML 会议)不断壮大和成熟:图学习会议(LoG)、分子机器学习(MoML)和斯坦福图学习工作坊。LoG 会议采用远程优先的独特形式,全球范围内的社区成员组织了数十场本地聚会,从中国到英国、欧洲,再到美国西海岸 🌏🌍🌎。

LoG 在阿姆斯特丹、巴黎、特罗姆瑟和上海的聚会。来源:LoG 社区的 Slack
课程、书籍和教育资源
几何 GNN 道场——一个面向初学者和专家的教学资源,帮助探索几何图形的 GNN 设计空间(与最近的《几何 GNN 的便捷指南》最佳搭配)
TorchCFM——流匹配领域的主要入口
PyT 团队维护着 TopoNetX、TopoModelX 和 TopoEmbedX——最实用的拓扑深度学习库
由 Maurice Weiler、Patrick Forré、Erik Verlinde 和 Max Welling 编写的《等变和坐标独立卷积网络:神经网络的规范场理论》一书——汇集了关于等变卷积神经网络(CNN)的表示理论和微分几何的研究成果
调查
量子、原子系统和连续介质中的科学机器学习,由来自 23 个机构的 60 多位作者(张、王、Helwig、罗、傅、谢等)撰写
人工智能时代的科学发现,由王等人发表于《自然》杂志。
知名研讨会系列
Slack 社群
2023 年的迷因

在 2023 年,纪念流匹配的成功,带到了 NeurIPS’23 的迷因和独特 T 恤。右图:Hannes Stärk 和 Michael Galkin 在 NeurIPS’23 发表声明。图片由 Michael Galkin 提供。

GNN 聚合函数实际上是范畴理论的门户(由 Petar Veličković创建)

Michael Bronstein 继续通过要求他的DeepMind 主席职位来骚扰谷歌,但到目前为止,他只被提供了凳子(照片来源:Jelani Nelson 和 Thomas Kipf)。

本博客的作者祝贺你完成了这篇长文。Michael Galkin 和 Michael Bronstein 在夏威夷 ICML 2023 上的 2022 年迷因(照片来源:Ben Finkelshtein)
欲了解更多几何学和图深度学习的文章,请参阅 Michael Galkin和 Michael Bronstein的 Medium 文章,并在 Twitter 上关注这两位 Michael(Galkin 和 Bronstein)。*
2024 年图形与几何机器学习:我们处于何种阶段,未来如何发展(第二部分——应用)
最前沿摘要
延续往年的传统,我们采访了一批杰出且高产的学术和工业专家,旨在总结过去一年的亮点,并预测 2024 年的发展趋势。2023 年成果丰硕,以至于我们不得不将这篇文章分为两部分。这是第二部分,重点介绍应用,欲了解理论与新架构,请参见第一部分。
·发表于 Towards Data Science ·42 分钟阅读·2024 年 1 月 16 日
--

图片由作者提供,并得到 DALL-E 3 的帮助。
本文由 Michael Galkin 和 Michael Bronstein 编写和编辑,且有来自 Dominique Beaini 、 Nathan Benaich 、 Joey Bose 、 Johannes Brandstetter 、 Bruno Correia 、 Ahmed Elhag 、 Kexin Huang 、 Chaitanya Joshi 、 Leon Klein 、 N M Anoop Krishnan 、 Chen Lin 、 Andreas Loukas 、 Santiago Miret 、 Luca Naef 、 Liudmila Prokhorenkova 、 Emanuele Rossi 、 Hannes Stärk 、 Alex Tong 、 Anton Tsitsulin 、 Petar Veličković 、 Minkai Xu 和 Zhaocheng Zhu 的重要贡献。

几何机器学习方法和应用填满了 2023 年高端期刊的封面(图源:来自Wang 等、Viñas 等、Deng 等、Weiss 等、Lagemann 等、Duan 等 和Lam 等的论文)
结构生物学(分子与蛋白质)
结构生物学家的视角
工业视角
系统生物学
材料科学(晶体)
分子动力学与机器学习势能
几何生成模型(流形)
大型图谱、可扩展性:当 GNN 太昂贵时
算法推理与对齐
知识图谱:归纳推理解决了吗?
时序图学习
LLMs + 图谱用于科学发现
酷炫的 GNN 应用
几何华尔街公告 💸
我们在全文中将使用的符号:
🔥 热点话题
💡 今年的亮点
🏋️ 挑战
➡️ 当前/下一步发展
🔮 预测/推测
💰 财务交易
结构生物学(分子与蛋白质)
Dominique Beaini(Valence)、Joey Bose(Mila & Dreamfold)、Michael Bronstein(Oxford)、Bruno Correia(EPFL)、Michael Galkin(Intel)、Kexin Huang(Stanford)、Chaitanya Joshi(Cambridge)、Andreas Loukas(Genentech)、Luca Naef(VantAI)、Hannes Stärk(MIT)、Minkai Xu(Stanford)
结构生物学无疑是 2023 年几何深度学习领域的前沿。
继 2020 年 halicin 被发现作为一种潜在的新型抗生素之后,2023 年通过 GNNs(图神经网络)的帮助,发现了两种新的抗生素!首先是 abaucin(由麦克马斯特大学和麻省理工学院研究人员发现),它能够靶向一种对多种药物有抗药性的顽固病原体。其次,麻省理工学院和哈佛大学的研究人员 发现了一种新的抗生素结构类别,其中筛选过程得到了 ChemProp 的支持,这是一套用于分子性质预测的 GNN 工具。此外,我们还观察到在最近的自主分子发现工作中,机器学习和实验技术的融合(“实验室环节”)正在趋于一致(这是我们在后续材料设计部分也将看到的趋势)。
流匹配已经成为 2023 年最大的生成型机器学习趋势之一,相比扩散模型,它能够实现更快的采样和确定性的采样轨迹。在生物学应用中,我们看到的流匹配模型最突出的例子有:FoldFlow(Bose、Akhound-Sadegh 等人)用于蛋白质骨架生成,FlowSite(Stärk 等人)用于蛋白质结合位点设计,以及 EquiFM(Song、Gong 等人)用于分子生成。

FoldFlow 不同版本学习到的条件概率路径,展示了 SO(3) 在其齐次空间上作用下单一残基的旋转轨迹 𝕊²*。图源:* Bose、Akhound-Sadegh 等人。
高效流匹配在复杂几何体上的应用,结合必要的等变性,得以实现,这得益于几篇理论论文的贡献,包括《黎曼流匹配》(Chen 和 Lipman)、《小批量最优传输》(Tong 等人)、以及《无仿真薛定谔桥》(Tong、Malkin、Fatras 等人)。一个很好的学习流匹配的资源是 GitHub 上的 TorchCFM 仓库,其中包含代码示例和笔记本,以及 Yaron Lipman、Joey Bose、Hannes Stärk 和 Alex Tong 的讲座。
扩散模型仍然是结构生物学中生成建模的主要驱动力。2023 年,我们见证了几项具有里程碑意义的工作:FrameDiff(Yim, Trippe, De Bortoli, Mathieu 等)用于蛋白质骨架生成,EvoDiff(Alamdari 等)用于生成具有离散扩散的蛋白质序列,AbDiffuser(Martinkus 等)用于全原子抗体设计,结合框架平均和离散扩散(并通过成功的实验验证),DiffMaSIF(Sverrison, Akdel 等)和DiffDock-PP(Ketata, Laue, Mammadov, Stärk 等)用于蛋白质-蛋白质对接,DiffPack(Zhang, Zhang 等)用于侧链包装,贝克实验室发布了RFDiffusion 全原子版本(Krishna, Wang, Ahern 等)。在潜在扩散模型(如图像生成应用中的 Stable Diffusion)中,GeoLDM(Xu 等)是第一个用于 3D 分子构象的模型,随后是OmniProt用于蛋白质序列-结构生成。

FrameDiff:通过旋转、平移和扭转角度对氧原子的骨架框架进行参数化。图源:Yim, Trippe, De Bortoli, Mathieu 等
最后,谷歌 DeepMind 和 Isomorphic Labs 宣布 AlphaFold 2.3——最新版本在三个任务上显著改进了基准:对接基准(比 DiffDock 在新的PoseBusters基准上好近 2 倍),蛋白质-核酸相互作用,以及抗体-抗原预测。
Chaitanya Joshi (剑桥)
💡2023 年,我非常兴奋地看到生物分子建模和设计领域出现了两个新兴趋势:
1️⃣ 从蛋白质结构预测到构象集生成的过渡。针对这个问题有几种有趣的方法,包括 AlphaFold 与 MSA 聚类,idpGAN,Distributional Graphormer(一种扩散模型),以及 AlphaFold 与流匹配结合生成蛋白质集。
2️⃣ 生物分子复合物建模与蛋白质 + X 的生物分子相互作用设计:RFdiffusion 全原子模型 和 Ligand MPNN,这两者均来自 Baker 实验室,是朝着设计相互作用趋势的代表性示例。正在开发中的新AlphaFold 报告称,统一的结构预测模型可以在单一蛋白质和蛋白质复合物结构预测以及蛋白质-配体和蛋白质-核酸协同折叠方面,超越或匹配专业化模型的表现。
“然而,对于生物分子建模和设计中所有令人兴奋的方法学发展而言,也许今年机器学习社区最大的一课应该是,更加关注有意义的计算机模拟评估,如果可能的话,进行实验验证。” — Chaitanya Joshi(剑桥)
1️⃣ 在 2023 年初,Guolin Ke 团队在 DP Technology 发布了两篇出色的再评估论文,突显了我们可能在分子构象生成和对接方面,相较于传统基准方法,过高估计了以几何深度学习为基础的主流方法的性能。
2️⃣ PoseCheck 和 PoseBusters 进一步揭示了当前分子生成和对接方法的失败模式。关键是,生成的分子及其 3D 姿态往往是“非物理的”,并且包含立体冲突、氢原子位置问题和高应变能。
3️⃣ 很少有论文尝试对新的机器学习思想进行实验验证。也许对于那些专注于新方法开发的人来说,与湿实验室的合作比较具有挑战性,但我希望作为一个机器学习社区,我们至少能在不断推动新模型的过程中,更加谨慎地对待我们常用的计算机模拟评估指标。
Hannes Stärk (MIT)
💡我看到关于流匹配、随机插值和整流流(我将其称为“桥接匹配”,或“BM”)的讨论颇为火热。我不认为仅仅在所有现有应用中用 BM 替换扩散模型有什么太大价值。对于纯生成建模,BM 的主要优势是简洁性。
我认为我们应该更加兴奋的是 BM 所解锁的新能力。例如,无需仿真即可在任意分布之间训练桥接(这个方法最适合哪些应用呢?到目前为止,我只看到了逆合成)。或者解决像DSBM这样针对流体流动下采样的 OT 问题。也许 2023 年出现了许多工具(我们还可以提到具有多个边际的 BM),而到 2024 年,社区将更好地利用它们?
乔伊·博斯 (Mila & Dreamfold)
💡 今年我们确实看到了几何生成模型从理论到实践的崛起。几项突出的工作包括黎曼流匹配——一般来说,任何由 Ricky Chen 和 Yaron Lipman 撰写的关于这些主题的论文都值得一读——以及Yim 等人的 FrameDiff,它引入了很多蛋白质骨架生成的重要机制。当然,在 RFM 和 FrameDiff 的基础上,我们构建了FoldFlow,这是一种更酷的流匹配方法,用于蛋白质生成模型。
“展望未来,我预见到更多基于流匹配的方法将会投入使用。它们对于蛋白质和较长的序列更有效,并且可以从任何源分布开始。” — 乔伊·博斯 (Mila & Dreamfold)
🔮 此外,我怀疑我们很快将在这个领域看到多模态生成模型,例如离散+连续模型,以及类似于图像的文本条件扩散模型的条件模型。也许,考虑到它们的扩展性,我们甚至可能会在这里看到潜在生成模型!
敏凯·徐 (斯坦福大学)
“今年,社区在多个方面进一步推动了几何生成模型在 3D 分子生成中的应用。” — 敏凯·徐 (斯坦福大学)
流匹配:Ricky 和 Yaron 提出了流匹配方法,作为广泛使用的扩散模型的替代方案,EquiFM(Song 等人 和 Klein 等人)通过使用等变 GNN 来参数化流动动力学,从而实现了 3D 分子生成的变种。与此同时,FrameFlow和FoldFlow构建了用于蛋白质生成的 FM 模型。
🔮 展望未来,类似于视觉和文本领域,人们开始探索在低维潜在空间中进行生成,而不是在复杂的原始数据空间中进行生成(潜在生成模型)。GeoLDM(Xu 等人)提出了第一个潜在扩散模型(类似于 CV 中的稳定扩散),用于 3D 分子生成,而Fu 等人则在大型蛋白质生成中使用了类似的建模方法。
结构生物学家的视角
布鲁诺·科雷亚 (EPFL)
“目前的生成模型仍然会产生‘垃圾’输出,这些输出违反了分子已知的许多物理和化学属性。当然,当前生成模型的优势在于它们的速度,这使得它们能够生成大量样本,从而使得筛选最佳生成样本的能力变得尤为重要,在蛋白质设计的领域,AlphaFold2 的变革性发展对此带来了巨大的帮助。” — 布鲁诺·科雷亚(EPFL)
➡️ 对社区的下一个挑战或许是如何将有意义的物理和化学先验注入生成模型,以提高采样性能和泛化能力。有趣的是,在小分子设计应用中,我们尚未看到同样显著的进展(经实验验证),但我们希望在 2024 年能看到这一突破。
➡️ 多模态模型的崛起。 通常在生物学相关任务中,数据稀缺性是一个普遍问题,因此,从数据中提取最大信号的策略至关重要。克服这种局限性的一种方式是提高数据表示的表现力,也许通过这种方式可以获得更高效的神经网络。短期内,我们可能能够探索涵盖多种对象表示的架构,并利用最优预测来解决我们面临的越来越复杂的任务,随着越来越多基础问题的解决,这一点变得尤为重要。多模态性的概念当然与拥有更强先验的模型的总体目标密切相关,在生成任务中,这些模型尊重对象的基本约束。
➡️ 全知模型。 随着机器学习模型能力的提升,我们显然倾向于请求更多的多目标优化,以解决现实生活中的问题。以小分子生成为例,从生物化学的角度来看,药物设计问题首先是确定一个靶点,药物分子需要与该靶点结合,因此,生成过程的首要约束条件之一是它应当以蛋白质口袋为条件。然而,这样的约束可能不足以生成真实的小分子,因为许多这样的化学物质根本无法合成或极其难以合成,因此,具备化学合成可行性概念并能将这些约束集成到搜索空间中的模型将更加有用。
➡️ 从化学型到表型。在数据表示方面,原子图结构和向量嵌入技术取得了显著成果,尤其是在寻找新抗生素方面。广义上来说,准确预测哪些化学结构具有抗菌活性,是从化学结构到表型预测的一个重要练习。由于所采用方法的简便性和取得的令人印象深刻的成果,可以预期,若能在分子端采用更复杂的数据表示,并且可能结合更丰富的表型分配,这将对药物开发中的这一重要问题作出关键贡献。
工业视角
Luca Naef (VantAI)
🔥2023 年你注意到的领域里最大的进展是什么?
1️⃣ 多模态性和模块化的增加 —— 如蛋白质和小分子的初步共折叠方法的出现所示,包括扩展 AF2 成功的扩散和非扩散基础方法:DiffusionProteinLigand在 2022 年最后几天发布,和RFDiffusion,AlphaFold2以及Umol预计在 2023 年底发布。我们也看到了一些模型在序列与结构共同训练的基础上进行训练:SAProt,ProstT5,以及在ProteinINR模型下,序列、结构与表面共同训练。经过 2021 年和 2022 年相对平静的时期,基于表面的方法重新获得了关注:DiffMasif,SurfDock,和ShapeProt。
2️⃣ 数据集和基准测试。数据集,尤其是合成的/计算生成的数据集:ATLAS 和 MDDB 用于蛋白质动态学。MISATO,SPICE,Splinter 用于蛋白质-配体复合物,QM1B 用于分子属性。PINDER:一个大型蛋白质-蛋白质对接数据集,包含匹配的空载/预测对和基准套件,附带重新训练的对接模型。CryoET 数据门户 用于冷冻电子断层扫描(CryoET)。以及一系列受欢迎的基准测试:PINDER、PoseBusters 和 PoseCheck,重点关注更加严格和实际相关的设置。
3️⃣ 创造性的预训练策略,以克服蛋白质-配体复合物的稀缺性。范德梅尔训练(DockGen)和RF-AA中的侧链训练策略,以及在RF-AA中的配体单独复合物预训练。多任务预训练,如Unimol及其他。
🏋️ 研究人员可能忽视的开放性挑战是什么?
1️⃣ 泛化能力。DockGen展示了当前最先进的蛋白质-配体对接模型在试图泛化到新蛋白质领域时完全失去了预测能力。在AlphaFold 最新报告中,我们看到了类似的现象,新蛋白质和配体的表现显著下降,甚至低于基于生物物理学的基线(这些基线可以访问完整结构),尽管对新蛋白质和配体的定义非常宽泛。这表明现有的方法仍然很大程度上依赖于记忆,这一观察已在多年的讨论中得到了广泛论证。
2️⃣ (简单)基线的诅咒。这是一个多年来反复出现的话题,2023 年再次证明了业界从业者早已知道的一点:在许多实际问题中,如分子生成、属性预测、对接和构象预测,简单的基线或经典方法往往在实践中仍优于基于机器学习的方法。2023 年,Tripp 等人, Yu 等人,Zhou 等人对此有越来越多的文献记录。
🔮 2024 年的预测!
“在 2024 年,数据稀缺问题仍将是我们关注的重点,我们将看到很多智能方法来使用模型生成合成训练数据。AlphaFold2 中的自蒸馏为我们提供了重要的启示,DockGen中的置信引导方法,利用了我们现在已经拥有足够强大的模型,这些模型能够评分姿势,但并不总是能够生成姿势,这一发现最早在2022 年中提出。” — Luca Naef (VantAI)
2️⃣ 我们将看到更多为机器学习专门设计的生物/化学实验,或仅在机器学习的背景下才有意义(即,这些实验可能本身不会带来生物学上的新见解,但主要用于训练模型)。2023 年的一个例子是Tsuboyama 等人的大规模蛋白质折叠实验。这个举措可能受到科技生物初创公司的推动,我们已经看到基于这些专门为结构生物学设计的机器学习实验构建的首个基础模型,例如ATOM-1。
Andreas Loukas (Prescient Design, 属于基因泰克的一部分)
🔥 你在 2023 年注意到的该领域最大进展是什么?
“在 2023 年,我们开始看到一些关于蛋白质等变生成和表示的挑战通过扩散模型得以解决。” — Andreas Loukas (Prescient Design)
1️⃣ 我们还注意到了一种向更高保真度的分子系统建模与生成方法转变。例如,最新的模型采用完全端到端的方法,通过联合生成主链、序列和侧链(AbDiffuser, dyMEAN),或者至少通过一个部分联合的模型在两步中解决问题(Chroma);相比之下,主链生成后再进行反向折叠,如在RF Diffusion和FrameDiff中所述。其他改善建模保真度的尝试可以在最新的共折叠工具更新中找到,如AlphaFold2和RF Diffusion,它们使模型对非蛋白质成分(配体、辅基、辅因子)更为敏感;同时也可以在一些试图考虑构象动力学的论文中找到(见上文讨论)。在我看来,这一研究方向至关重要,因为分子系统的结合行为对原子如何放置、移动和相互作用非常敏感。
2️⃣ 在 2023 年,许多研究也尝试通过学习预测已知晶体突变的效应来掌握结合亲和力,通过在大型语料库上进行预训练,如计算预测的突变(graphinity),以及侧任务,如旋转体密度估计。所获得的结果令人鼓舞,因为它们可以显著超越如 Rosetta 和 FoldX 等半经验基线。然而,仍然有大量工作需要完成,以使这些模型在结合亲和力预测中更加可靠。
3️⃣ 我进一步观察到,蛋白质语言模型(pLMs),特别是ESM,作为有价值的工具,得到了越来越多的认可,即使在那些主要偏好几何深度学习的人群中。这些嵌入被用来帮助对接模型,允许构建简单但具有竞争力的结合亲和力预测模型(Li et al 2023),并且通常能提供一种高效的方式,利用广泛的蛋白质组数据为 GNNs 创建残基表示,而无需大量的预训练(Jamasb et al 2023)。然而,我确实对使用 pLMs 存在担忧:目前尚不清楚它们的有效性是由于数据泄漏还是由真正的泛化能力所致。在评估如逆折叠中的氨基酸恢复和条件 CDR 设计等任务时,这一点尤其重要,因为区分这两者的因素至关重要。
🏋️ 研究人员可能忽视了哪些开放性挑战?
1️⃣ 与能量放松的晶体结构(甚至更糟,折叠结构)一起工作可能会显著影响下游预测模型的性能。对于蛋白质-蛋白质相互作用(PPI)的预测尤为如此。根据我的经验,当预测模型给定的是放松后的结构而非结合(全)晶体结构时,PPI 预测器的性能会严重下降。
2️⃣ 尽管成功的计算机辅助抗体设计有潜力彻底改变药物设计,但通用蛋白质模型(还?)不如专门针对抗体的模型在折叠、对接或生成抗体方面表现得那么好。这可能是由于抗体折叠的低构象变异性和抗体与抗原之间不同的结合方式(涉及环介导的相互作用,并可能包含不可忽视的熵项)。也许由于同样的原因,抗体结合物的de novo设计(我定义为零样本生成能与之前未见的表位结合的抗体)仍然是一个未解的难题。目前,经过实验验证的de novo结合物案例大多涉及稳定的蛋白质,如α-螺旋束,它们在 PDB 中较为常见,且包含的界面与表位-抗表位相互作用有显著不同。
3️⃣ 我们仍然缺乏一种通用的结合自由能代理模型*。问题在于缺乏足够大且多样化的高质量数据(特别是共晶结构)。因此,我们应该意识到,任何此类学习代理模型在评估模型时的局限性:尽管预测的结合评分超出已知结合物分布范围是明显的信号,表明某些地方可能出了问题,但我们应避免陷入典型的误区,即通过展示模型如何导致更高的评分来证明我们模型的优越性。
多米尼克·贝阿尼(Valence Labs,Recursion 的一部分)
“我很高兴看到围绕药物发现问题建立起了一个庞大的社区,并且我感到我们正处于药物发现速度和效率新革命的边缘。” — 多米尼克·贝阿尼(Valence Labs)
2023 年有哪些工作让我感到兴奋?
我相信机器学习将使我们能够快速应对罕见病,阻止下一场 COVID-X 大流行在蔓延之前爆发,并使我们活得更久、更健康。但仍有大量工作要做,前方有许多挑战,有一些障碍和一些峡谷。说到社区,您可以访问Valence Portal以跟上药物发现领域机器学习中的🔥新进展。
2024 年有哪些难题?
⚛️ 新一代量子力学。 基于等变和不变图神经网络(GNNs)的机器学习力场,一直向我们承诺着一笔宝藏。那就是密度泛函理论的精度,但速度要快上千倍,并且适用于整个蛋白质的尺度。虽然Allegro和MACE-MP等方向上已取得了一些进展,但目前的模型在未见过的情境和超大分子上泛化性较差,而且它们的计算速度仍然太慢,无法满足所需的时间尺度🐢。对于泛化问题,我相信更大且更多样化的数据集是最重要的突破点。对于计算时间,我认为我们将看到一些不那么强求等变性的模型,例如FAENet。但高效的采样方法将发挥更大的作用:空间采样,比如使用DiffDock来获得更有趣的起始点,以及时间采样,如TimeWarp来避免模拟每一帧。我对 2024 年即将到来的重大突破充满期待:时空等变的 Boltzmann 采样器(STEBS)👣。
🕸️ 一切都是相互联系的。生物学本质上是多模态的🙋🐁 🧫🧬🧪。 无法简单地将分子与其他生物系统解耦。当然,这就是过去药物发现中的机器学习方法:仅仅构建分子图模型并将其拟合到实验数据中。但我们已经达到了一个关键点🛑,无论图神经网络(GNN)模型中有多少万亿个参数,使用了多少数据进行训练,或者汇聚了多少专家团队,它都无法解决问题。是时候将生物学融入其中,最直接的方式就是采用多模态模型。一种方法是用目标蛋白质序列来调节 GNN 的输出,例如MocFormer。另一种方法是使用显微镜图像或转录组学来更好地为模型提供分子的生物学特征信息,例如TranSiGen。还有一种方法是使用大语言模型(LLMs)来嵌入关于任务的上下文信息,例如TwinBooster。甚至更好的是,将这些方法结合起来🤯,但这可能需要几年时间。对于更广泛的社区而言,主要问题似乎是缺乏大量高质量和标准化的数据,但幸运的是,这对 Valence 来说并不是问题。
🔬 关联生物学知识和可观察现象。 人类一直在努力绘制生物学的图谱,构建基因🧬、蛋白质-蛋白质相互作用🔄、代谢通路🔀等的关系图。我邀请你阅读这篇关于药物发现中的知识图谱综述。然而,这些知识常常被机器学习(ML)社区忽视,未被充分利用。我认为这是一个知识图谱领域,特别是在 2024 年,图神经网络(GNNs)可以发挥巨大作用的地方,它也可能为上述🕸️点提供另一种方式。考虑到人类知识的不完整性,我们可以通过基础模型恢复关系图谱。这也是Phenom1在尝试回溯已知遗传关系时所采取的路径。然而,处理各种知识数据库是一项极其复杂的任务,我们不能指望大多数机器学习科学家能够单独应对。但是借助人工助手,如LOWE,这一任务可以在几秒钟内完成。
🏆 基准测试,基准测试,基准测试。 我无法重复基准测试这个词足够多。唉,基准测试依然会是机器学习领域不受宠的存在🫥。但如果“基准测试”这个词不酷,它的表亲竞赛就要酷得多了😎!正如OGB-LSC竞赛和Open Catalyst挑战对 GNN 社区的重大作用一样,现在是时候迎来一系列新的竞赛🥇。我们甚至最近迎来了TGB(时序图基准测试)。如果你参加了 NeurIPS'23,可能就听说了 Polaris 将在 2024 年初登场✨。Polaris 是一个由多个制药公司和学术团体组成的联盟,旨在提高现有分子基准测试的质量,以更好地代表真实的药物发现过程。也许我们会看到适用于分子图生成的基准测试,而不是优化 QED 和 cLogP,但我不会抱太大希望,毕竟我已经等了多年。是什么样的新奇竞赛将点亮今年的 GDL 社区呢🤔?
系统生物学
黄克鑫(斯坦福大学)
生物学是一个相互连接、多尺度和多模态的系统。有效的建模不仅可以揭示基础的生物学问题,还可以对治疗发现产生重大影响。最自然的封装这种系统的数据格式是关系型数据库或异构图。这种图存储了几十年来在各种生物学领域进行的湿实验数据,数据量可达数十亿条。
“在 2023 年,我们见证了许多使用 GNNs 的创新应用,应用于这些生物系统图谱。这些应用解锁了新的生物医学能力,并回答了关键的生物学问题。” — 黄克鑫(斯坦福大学)
1️⃣ 一个特别令人兴奋的领域是微扰生物学。理解微扰的结果可以推动细胞重编程、靶点发现和合成致死等方面的进展。在 2023 年,GEARS将图神经网络(GNN)应用于基因扰动关系图,能够预测之前未曾观察到的基因扰动结果。
2️⃣ 另一个酷炫的应用涉及蛋白质表示。当前的蛋白质表示是固定且静态的,但我们认识到同一蛋白质在不同的细胞环境中可能表现出不同的功能。PINNACLE在蛋白质相互作用网络上使用 GNN 来对蛋白质嵌入进行情境化处理。该方法已被证明能够增强基于 3D 结构的蛋白质表示,并在识别治疗靶点方面优于现有的无上下文模型。

PINNACLE 拥有蛋白质、细胞类型和组织层次的注意力机制,使得该算法能够在一个统一的嵌入空间中生成蛋白质、细胞类型和组织的情境化表示。来源:Li et al
3️⃣ GNN 在诊断罕见疾病方面也发挥了重要作用。SHEPHERD利用 GNN 在庞大的知识图谱上对生物学知识进行编码,并已被证明有助于因果基因发现、识别具有相似基因或疾病的‘类似患者’,并为新型疾病表现提供可解释的见解。
➡️ 超越预测,理解生物现象背后的机制至关重要。**图神经网络解释性 AI(Graph XAI)**应用于系统图是识别机制路径的自然选择。例如,TxGNN将药物-疾病关系预测与生物系统图结合,生成多跳的可解释路径。这些路径阐明了药物治疗特定疾病的潜力。TxGNN 为这些解释设计了可视化工具,并进行用户研究,证明其在临床医生和生物医学科学家决策中的有效性。

一个基于网络的图形用户界面,用于支持临床医生和科学家探索和分析由 TxGNN 生成的预测和解释。‘控制面板’允许用户选择感兴趣的疾病,并查看该疾病的顶级 TXGNN 预测结果。‘边缘阈值’模块使用户能够调整解释的稀疏性,从而控制显示的多跳路径的密度。‘药物嵌入’面板允许用户比较所选药物相对于整个药物重定位候选库的位置。‘路径解释’面板展示了被认为对 TXGNN 预测疗效至关重要的生物学关系。来源:黄、Chandar 等
➡️ 生物学中的基础模型主要是单模态的(专注于蛋白质、分子、疾病等),这主要是由于配对数据的稀缺性。跨模态桥接以回答多模态查询是一个令人兴奋的前沿。例如,BioBridge利用生物学知识图谱在单模态基础模型之间学习变换,从而实现多模态行为。
🔮 GNN 应用于系统图谱具有以下潜力:(1)编码广泛的生物医学知识,(2)桥接生物学模态,(3)提供机制性洞察,(4)为生物实体提供背景。我们预计 2024 年 GNN 在生物学中的应用将带来更多突破性的进展,解决该领域一些最紧迫的问题。
2023 年预测
(1)扩散模型的性能提升,如更快的采样和更高效的求解器;
✅ 是的,使用流匹配
(2)更强大的条件蛋白质生成模型;
❌ Chroma 和 RFDiffusion 仍然领先
(3)生成流网络在分子和蛋白质中的更多成功应用;
❌ 尚未确定
材料科学(晶体)
迈克尔·高尔金(英特尔)和圣地亚哥·米雷特(英特尔)
在 2023 年,短短一段时间内,所有科学新闻都在讨论LK-99——一个由韩国团队创造的假设常温超导体(剧透:截至目前它并未成功)。
这突显了机器学习在材料科学中的巨大潜力,或许今年最大的进展就发生在这里——我们现在可以说,材料科学和材料发现已经成为几何深度学习(Geometric DL)领域的核心内容。
💡应用于材料科学和发现的几何深度学习取得了显著进展,涵盖了新的建模方法、新基准和数据集的创建、生成方法的自动化设计,以及基于这些进展识别出新的研究问题。
1️⃣ 几何模型作为自动化发现工作流中评估工具的应用。Open MatSci ML Toolkit整合了所有开源的晶体结构数据集,提供了 150 万个数据点,用于基态结构计算,这些数据现在可以轻松用于模型开发。作者的初步结果似乎表明,合并数据集如果处理得当,会提高性能。
2️⃣ MatBench Discovery是另一个很好例子,展示了几何模型作为晶体稳定性评估工具的应用,它测试了模型对各种晶体结构的能量高于外壳的预测。能量高于外壳是晶体结构稳定性的最可靠近似,也代表了比形成能或原始能量预测更有实际意义的度量方法,因为后者作为稳定性度量存在实际限制。

普遍势能是更可靠的分类器,因为它们最早会退出红色三角区域。这些线显示了在 WBM 测试集上的滚动平均绝对误差(MAE),随着 MP 训练集的能量与凸包的距离变化,数值越低越好。红色突出显示的“危险三角区”表示模型最有可能错误分类结构。只要模型的滚动 MAE 保持在三角区内,其平均误差大于到凸包的距离。如果模型对某一预测的误差恰好指向位于凸包 0 eV 稳定性阈值处(图表的中心),则其平均误差将改变材料的稳定性分类,从真正的正/负类变为假负/假正类。’滚动窗口’框的宽度表示在其内平均的误差对凸包距离预测的宽度。来源:Riebesell 等人
3️⃣ 在晶体结构预测的新几何模型方面,晶体哈密顿图神经网络(CHGNet,邓等人)是一个新型的图神经网络(GNN),它基于材料项目的静态和松弛轨迹进行训练,与先前的方法相比,表现出相当有竞争力的性能。CHGNet 的发展表明,找到更好的训练目标将与开发新方法一样(如果不是更重要的话)变得至关重要,因为材料科学与几何深度学习的交集正在不断增长。
🔥 几何深度学习与材料发现进一步整合的其他证明点是几家大实验室在晶体结构发现中采用生成方法的几项庞大工作:
1️⃣ Google DeepMind 发布了 GNoME(由 Merchant 等人 提出的材料科学图网络)作为发现新材料的成功主动学习流程示例,同时发布了 UniMat 作为一个 从头开始 的晶体生成模型。类似于蛋白质领域,我们也看到了更多关于材料科学的自动化实验室示例(“实验室-环路”),例如 UC 伯克利的 A-Lab。

GNoME 的主动学习循环。来源:Merchant 等人
2️⃣ 微软研究院发布了 MatterGen,一个用于无条件和属性引导材料设计的生成模型,以及 Distributional Graphormer,一个生成模型,旨在恢复分子/蛋白质/晶体的平衡能量分布。

MatterGen 的无条件和有条件生成。来源:Zeni, Pinsler, Zügner, Fowler, Horton 等人
3️⃣ Meta AI 和 CMU 发布了 Open Catalyst Demo,在这里你可以体验 11.5k 种催化剂材料在 86 种吸附体上、每种吸附体有 100 种不同配置的弛豫(DFT 近似),总共有多达 100M 种组合。该演示由最先进的几何模型 GemNet-OC 和 Equiformer-V2 提供支持。
Santiago Miret (英特尔)
虽然这些工作代表了生成方法的大规模应用,但也有新的研究在使用强化学习(Govindarajan 等人,Lacombe 等人)和 GFlowNets(Mistal 等人,Nguyen 等人)结合几何深度学习来发现晶体结构,正如在 AI 加速材料设计 (AI4Mat) 研讨会中所强调的那样,NeurIPS’23 上的 AI4Mat-2023 也见证了参与人数的迅速增长,提交和接收的论文数量翻了一番,参会人数几乎增加了三倍。
💡 几何深度学习(DL)和图神经网络(GNN)仍然是 AI4Mat 研究内容的重要组成部分,我们看到这些方法的应用不仅限于性质预测,还扩展到了化学合成和材料表征的改善。AI4Mat-2023 工作坊中突出的一个有前景的例子是KREED(Cheng, Lo 等),它使用等变扩散方法,根据从真实实验室机器获得的不完全信息,预测分子的三维结构。
“考虑到结构数据在材料表征中的重要性,AI4Mat 的讨论突出了几何深度学习进入真实世界材料建模领域的机会,除了在包括基于机器学习的潜力在内的仿真中的持续成功。” — Santiago Miret(英特尔)
🔮 在 2024 年,我预计会看到多项发展:
1️⃣ 更多发现新的架构和工作流程,直接整合像 M3GNet、CHGNet、MACE 这样的几何模型。
2️⃣ 几何模型可能会面临来自文本表示和大语言模型(LLMs)的竞争,因为新方法正在被提出,直接生成 CIF 文件。
3️⃣ 更多几何模型和 GNN 应用于真实世界的实验数据,可能用于材料表征,如 KREED,预计将面临与基于仿真建模相比数据较少的情况。
分子动力学与机器学习潜力
Michael Galkin (英特尔),Leon Klein (柏林自由大学),N M Anoop Krishnan (印度理工学院德里分校),Santiago Miret (英特尔)
2023 年一个显著的趋势是朝着基础模型发展,这些模型适用于从小分子到周期性晶体等多种化合物的机器学习潜力。
例如,FAIR 和 CMU 的JMP(Shoghi 等)、来自中国多个机构的大型合作的DPA-2(Zhang, Liu 等),以及剑桥主导的合作中的MACE-MP-0(Batatia 等)。实际上,这些是预训练的几何图神经网络(GNN),用于预测某一原子结构的能量(或力)。另一个值得提及的例子是Equiformer V2(Liao 等),这是一种强大的等变变换器,在许多任务中保持了 SOTA 表现,包括最近的OpenCatalyst 2023 挑战和OpenDAC(直接空气捕获)挑战。

一个材料建模的基础模型。该模型仅在材料项目数据上进行训练,数据主要由无机晶体组成,并且在氧化物方面偏向较重,MACE-MP-0 能够在固态、液态和气态的各种化学环境中进行分子动力学仿真。来源:Batatia 等
⚛️ 机器学习势能的一个常见应用场景是分子动力学(MD),其目的是在纳秒(10ᐨ⁹)到秒的时间范围内模拟某种结构。主要问题在于,经典方法中的基本时间步长是飞秒(10ᐨ¹⁵),也就是说,模拟一个纳秒至少需要 1 百万个时间步,这非常昂贵。现代基于机器学习的 MD 方法旨在通过应用粗粒化和其他近似技巧加速模拟,这些方法能够大幅提高模拟速度(30–1000 倍)。Fu, Xie, et al(TMLR'23)将粗粒化应用于原子结构,并在较小的图上运行 GNN 来预测下一步位置。在实验中,该方法相比经典方法实现了 1000–10,000 倍的加速。TimeWarp(Klein, Foong, Fjelde, Mlodozeniec, et al,NeurIPS’23)通过使用条件归一化流模型来模拟大时间步长(1⁰⁵ — 1⁰⁶飞秒),该模型近似下一步位置的分布。在使用 MCMC 采样的情况下,训练后的模型实现了约 33 倍的加速。

(a) 初始状态 x(t)(左)和接受的提议状态 x(t+τ)(右),通过 Timewarp 为二肽 HT(在训练期间未见过的样本)采样。(b) 模拟轨迹的 TICA 投影,显示了介稳态之间的转变,分别为短时间 MD 模拟(左)和 Timewarp MCMC(右),两者均运行了 30 分钟的墙钟时间。Timewarp MCMC 在有效样本数每秒的速度上,相比 MD 实现了约 33 倍的加速。来源:Klein, Foong, Fjelde, Mlodozeniec, et al
Santiago Miret(英特尔)
💡随着几何模型在属性建模中的成功部署,研究人员通过在现实世界的分子动力学模拟中测试这些模型,推动了最前沿的发展。首个突显仅基于能量和力训练模型存在问题的研究是 2023 年初在 TMLR 上发布的Forces Are Not Enough。然而,基于邻域的方法如Allegro的进展,导致了使用几何深度学习模型成功部署大规模模拟,其中包括戈登·贝尔奖提名。
“仍然有大量工作需要做,以确保机器学习势能在各种物理和化学现象中成功、广泛地部署。” — Santiago Miret(英特尔)
➡️ EGraffBench突出了其中一些新挑战,例如跨温度和材料相变的泛化(即固体到液体的变化),并提出了用于评估机器学习势能在实际 MD 模拟中表现的新指标。AI4Mat-2023 研讨会也展示了为特定应用场景开发的新型机器学习势能,例如用于电池的固态电解质。
Leon Klein (FU Berlin)
💡 将生成模型应用于从平衡玻尔兹曼分布中采样时的一个显著限制是,每次遇到新系统时都需要重新训练,这使得其相较于传统的分子动力学(MD)模拟的潜在优势受限。然而,最近的进展使得跨不同领域的可迁移模型逐渐浮现。我们的贡献,Timewarp,提出了一种可迁移模型,能够为聚焦于全原子小肽系统的 MD 模拟提出大时间步长。类似地,Fu et al. 捕捉了粗粒化聚合物的时间粗化动力学,而Charron et al.则擅长学习用于粗粒化蛋白质的可迁移力场。
“因此,今年展示了可迁移生成模型在 MD 模拟中的可行性,展示了它们加速此类模拟的潜力。” — Leon Klein (FU Berlin)
🔮 在 2024 年,我预计会有更多定制化的 GNN(图神经网络)被用来提高可迁移模型的准确性,可能会专注于编码更多关于系统的信息。例如,尽管 Timewarp 在其模型中缺乏旋转对称性,但它通过数据增强来弥补这一点。另一种选择是使用最近提出的SE(3)等变增强耦合流来加入旋转对称性。类似地,Charron et al.使用 SchNet,而非更复杂的 GNN。
N M Anoop Krishnan (IIT Delhi)
“今年在机器学习势能领域最令人兴奋的进展之一是开发了可以覆盖几乎所有元素的‘通用’原子间势能。” — N M Anoop Krishnan (IIT Delhi)
💡 继 2022 年的 M3GNet 之后,今年见证了基于 CHGNet 的三种模型的发展,它们分别是(Deng et al)、NequIP (Merchant et al) 和 MACE (Batatia et al)。这些模型已被用来展示几个具有挑战性的任务,包括材料发现(Merchant et al),以及多种 MD 模拟(Batatia et al),如相变、非晶化、化学反应、二维材料建模、溶解、缺陷、燃烧等。这些方法为这些势能的普适性提供了有希望的结果,从而使得解决具有挑战性的问题成为可能,包括从相应的非晶结构中发现晶体(Aykol et al),这是材料学中一个长期未解的开放问题。
🏋️ 虽然这些势能确实为解决一些突出问题提供了手段,但挑战依然存在,主要是理解这些势能可能失败的情境。
1️⃣ 测试这些势能的极限以了解它们的能力,是理解其局限性的重要方面。这包括模拟极端环境,如高压和辐射条件,模拟复杂的多组分系统,如玻璃或高熵合金,或者模拟不同相态的系统,如水或二氧化硅,这些都是有趣的挑战。
2️⃣ 虽然其中一些模型被称为“基础”模型,但它们并未展示与基础模型相关的涌现行为。这些模型大多只是展示了对潜在未见区域或新组合物的外推能力。开发真正具有涌现属性的基础模型将是一个有趣的挑战。
3️⃣ 第三个较少被关注的方面是这些模型大规模模拟的能力。尽管Allegro已经展示了这些势能在长度尺度方面的某些能力,但在更大的时间和长度尺度下进行稳定的模拟,同时保持“普适性”,仍然是这些势能面临的一个未解挑战。
🔮 2024 年会有什么期待?
1️⃣ 基准套件:虽然已经有一些关于分子动力学(MD)模拟的基准研究,预计 2024 年将在数据集和任务方面看到更多正式化的努力。一个标准的任务集,可以自动评估势能并将其放置在排行榜上,将便于对不同材料(如金属、聚合物或氧化物)的后续任务进行势能的排名。
2️⃣ 模型与数据集开发:将进一步努力使机器学习潜力在架构上更紧凑和高效。此外,2024 年还将见证大规模数据集的发展,这些数据集将为训练这些潜力提供从头计算数据。
3️⃣ 可微 MD/AIMD:此外,预计可微模拟的进展将成为融合实验与从头计算模拟,朝向自动开发针对特定应用的原子间势能的重要领域。今年可能还会看到可微 AIMD 的发展,结合机器学习的泛函,可能使得经济型模拟能够超越现有的限制,扩展到更大规模。
2023 年后的预测
我们预计将更加关注图神经网络(GNN)的计算效率和可扩展性。目前基于 GNN 的力场在精度上取得了显著成就,但仍然比传统力场慢 2 到 3 个数量级,且通常只能部署在几百个原子上。
✅ 阿莱格罗(Allegro)竞逐戈登·贝尔奖,使用 GNoMe 进行大规模筛选
🔮2024 年展望:
1️⃣ 更多将机器学习潜力部署到大规模分子动力学(MD)模拟中,展示新的研究机会和挑战,并提供更好的视角,了解机器学习潜力相对于传统势能的优势。
2️⃣ 新的数据集将概述机器学习潜力未曾探索的挑战,例如新材料系统和这些材料的新物理现象,如在不同温度和压力下的相变。
3️⃣ 探索多尺度问题,可能会从经典技术中获得灵感。
几何生成模型(流形)
乔伊·博斯(Joey Bose)(Mila & Dreamfold)和亚历克斯·汤(Alex Tong)(Mila & Dreamfold)
虽然生成型机器学习在 2023 年继续主导该领域,但将几何先验结合到几何生成模型中的普及成为这一年一个有趣的趋势。
乔伊·博斯(Joey Bose)(Mila & Dreamfold)
“今年我们看到几何生成模型这一新兴子领域迈出了重要的一步。随着扩散模型和流动匹配在图像中的成功应用,我们看到了更多基础性的贡献,为几何数据类型的生成式 AI 奠定了基础。”—— 乔伊·博斯(Joey Bose)(Mila & Dreamfold)
虽然流形的扩散模型早已有所存在,但今年我们确实看到了它们通过扩展黎曼扩散模型由Lou 等人和流形扩散场Elhag 等人的功能方法得到扩展

(左)在兔子流形 M 上的场的训练迭代的可视化表现。(右)在兔子流形上的场的采样过程的可视化表现。图源:Elhag 等人
对于基于归一化流的方法,陈和利普曼提出的黎曼流匹配被认为是 FM 领域中最通用的框架,位于众多论文之中。
一般来说,几何生成模型的一个重要主题是处理对称性。今年,等变方法大放异彩,包括 SE(3)模型中的EDGI(Brehmer, Bose 等人)、SE(3)增强耦合流(Midgley 等人),以及关于几何神经扩散过程的很酷的理论工作(Mathieu 等人)和由Abbot 等人提出的重要基于物理的应用。
Alex Tong (Mila & Dreamfold)
“2023 年,我们看到了建模的进展和一个新应用的崛起——蛋白质骨架设计。我们仍然需要更多的工作来理解 SE(3)ᴺ₀类型的积产品流形的性质,在这一领域,我们仍不清楚如何最好地结合不同的模态”——Alex Tong (Mila & Dreamfold)
2023 年出现了新的模型,如RFDiffusion、FrameDiff和FoldFlow,它们在蛋白质骨架的 SE(3)ᴺ₀流形上进行操作。这为几何生成模型提出了新的挑战,我认为我们将在未来一年看到在这一领域的重要进展。
在建模方面,欧几里得领域中的生成建模与流和桥接匹配模型相结合,迅速继而推出了黎曼流匹配(陈和利普曼)和等变流匹配(Klein 等人、Song 等人)用于分子生成任务。
🔮 2024 年展望:
1️⃣ 在蛋白质骨架设计取得成功后,更多地探索 SE(3)ᴺ₀流形的建模。
2️⃣ 进一步研究和理论探索如何在多模态和积产品流形上训练生成模型。
3️⃣ 利用更具体的流形和等变结构特征的领域特定模型。
大型图,扩展性:当 GNNs 过于昂贵时
Anton Tsitsulin (Google)
今年对于大型图形的爱好者来说是丰收的一年。
“在非常大的图形上进行学习一直是一个挑战,因为现代加速器不支持这种非结构化稀疏性,导致在硬件抽奖中落败。张量处理单元——你可以把它们看作是非常快速的 GPU,配有大量(多 TB)的 HBM 内存——是 2023 年的救星。” —— Anton Tsitsulin(Google)
在一篇 KDD 论文中(Mayer 等),我们展示了 TPUs 如何比 GPU 和 CPU 系统更高效地解决大规模节点嵌入问题,并且成本大大降低。图形机器学习的许多工业应用是完全无监督的,在这种情况下,很难评估嵌入质量。我们撰写了一篇论文(Tsitsulin 等),在大规模上执行无监督嵌入分析。

TpuGraphs 与其他图形属性预测数据集的规模对比。来源:Phothilimthana 等
➡️ 今年,TPU 帮助了图形机器学习,因此是时候回馈了。我们发布了一个新的TpuGraphs数据集(Phothilimthana 等),并基于该数据集举办了一个Kaggle 竞赛“Google——快还是慢?预测 AI 模型运行时间”,该竞赛展示了如何提升在 TPU 上运行的图形机器学习学习模型。竞赛共有 792 名参赛者,616 个队伍,和 10,507 个参赛作品。该数据集提供了比现有的最大图形属性预测数据集(在图形规模相当的情况下)多出 25 倍的图形,且相比现有的机器学习程序性能预测数据集,平均图形规模大 770 倍。这个数据集如此庞大,以至于曹等不得不开发一种新的算法来进行大规模图形上的图级预测。
➡️ 大规模图聚类在今年取得了显著进展。提出了一种新的近似算法(Cohen-Addad 等)用于相关性聚类,将近似因子从 1.994 提高到惊人的 1.73。TeraHAC(Dhulipala 等)是对去年ParHAC(我们在2023 年的文章中有提到)的一项重大改进——这是一个适用于万亿边图的近似(1+𝝐)层次聚合聚类算法。实验中使用的最大图是一个庞大的 Web-Query 图,包含 31B 节点和 8.6 万亿边👀。值得一提的还有目前最快的欧几里得最小生成树算法(Jayaram 等)以及一种新的近线性时间算法,用于近似点集之间的 Chamfer 距离(Bakshi 等)。
🔮 2024 年展望:
1️⃣ 算法进展将帮助扩展其他流行的图算法
2️⃣ 新型硬件的使用将有助于扩大不同图模型的规模
2023 年文章中的预测
(1) 进一步降低大型图的计算成本和推理时间
✅ 我们观察到聚类和节点嵌入的加速是数量级的提升。
(2) 或许 OGB LSC 图的模型可以在普通机器上运行,而不是庞大的集群?
❌ 坚决否定
算法推理与对齐
Petar Veličković(Google DeepMind)和 Liudmila Prokhorenkova(Yandex Research)
算法推理,作为一种能够执行算法计算的机器学习技术类别,在 2023 年持续稳定地进展。
Petar Veličković(Google DeepMind)
“2023 年对神经算法推理模型来说是一个稳步发展的年份——这确实是 GNN 发展中最具创意的领域之一——可能是因为它必须如此。” —— Petar Veličković(Google DeepMind)
除了已经讨论过的异步算法对齐工作外,还有三项成果是我在今年最为自豪的:
1️⃣ DAR证明了预训练的多任务神经算法推理器可以在下游图问题中进行可扩展部署——即使它们比 NAR 的合成训练分布大 180,000 倍。更重要的是,我们在小鼠大脑血管建模中达到了最先进的水平 🐁🧠🩸。NAR 不是“苦涩教训”的受害者!📈
2️⃣ Hint-ReLIC 🗿是我们对无提示模型这一丰富研究领域的回应。我们摒弃了存在众多问题的提示 自回归,而是利用因果推理来建模提示不变性。我们获得了一个强大的基于提示的 NAR,并且在 CLRS-30 的多个领域仍保持着最先进的技术!“如果正确使用,提示可以带你走得更远。”
3️⃣ 最后但同样重要的是,我们迈出了重要一步,首次对训练好的 NAR 模型的潜在空间表示进行了深入分析。我们发现的结果不仅非常美观🌺,而且还让我们学到了很多关于这些模型如何工作的知识。

左图:对八个重加权图的轨迹主成分分析(PCA),显示它们都包含一个单一的主导方向。不同的聚类具有不同的颜色。中图:许多嵌入聚类,主导方向以红色叠加。右图:随机图的逐步主成分分析,主导聚类方向以红色叠加。来源:Mirjanić, Pascanu, Veličković
除了壮大我们的充满活力的社区外,我还认为有必要声明,NAR 的许多基础性思想处于重要 LLM 方法论的核心;举一个例子,提示跟随直接与思维链提示相关。
💡 我最开心的是,在 2023 年,这个链接得到了明确的认可,NAR 的想法现在正在直接或间接地影响当今最强大的 AI 系统。事实上,NAR 被列为研究长度泛化的关键动机之一,更广泛地说,它对未见数据的泛化有重要影响(ICML’23 最佳论文奖)。CLRS-30,作为 NAR 的旗舰基准,直接用于评估 LLM 在神经架构搜索和通用 AI 研究中的能力。而作为最后的点睛之笔,CLRS-30 被认为是Gemini(Google DeepMind 的前沿大语言模型)使用的仅有七个推理评估之一。我希望这能成为 2024 年未来的一个指路明灯,并且我们将看到更多来自 NAR 的想法打破进入前沿可扩展 AI 模型的设计中。
Liudmila Prokhorenkova(Yandex 研究员)
全年间,在赋予模型各种算法归纳偏见的道路上取得了显著进展:使用双问题(Numeroso et al),对比学习技术(Bevilacqua et al; Rodionov et al),通过数据结构增强模型(Jürß et al; Jain et al),以及对计算模型的深入研究(Engelmayer et al)。另一个重要方向是评估现有模型在可扩展性和数据多样性方面的表现(Minder et al)。
“在 2024 年,能够看到对神经推理器进行更全面的分析和理解会是很好的进展:它们学习了哪些操作,它们对数据分布的不同变化有多敏感,它们倾向于犯哪些类型的错误,以及为何如此。” — Liudmila Prokhorenkova(Yandex Research)
获得这些见解可能有助于开发更加稳健和可扩展的模型。此外,稳健的神经推理器有可能对组合优化模型产生积极影响。
2023 年后的预测
(1) 算法推理任务可能会扩展到包含成千上万个节点的图形,并在代码分析或数据库等实际应用中发挥作用。
✅ 是的,DAR 能够扩展到 OGB 船舶大小
(2) 基准测试中将出现更多算法
✅ 是的,SALSA-CLRS
(3) 最不可能的情况——会出现一个能够解决 quickselect 问题的模型
❌ 仍未解决 😭
知识图:归纳推理已解决?
Michael Galkin(英特尔)和 Zhaocheng Zhu(Mila & Google)
自 2011 年成立以来,知识图表示学习的重大挑战之一就是归纳推理:一个单一模型能够在没有输入特征和硬编码的实体/关系嵌入矩阵的情况下,执行任何图上的推理(例如,缺失链接预测)。GraIL(ICML'20)和Neural Bellman-Ford Nets(NeurIPS'21)在扩展推理到未见过的实体方面起到了重要作用,但由于一个关键问题:当整个实体/关系词汇可能发生变化时,如何学习和迁移,推理时对新实体和关系类型的泛化仍然是未解决的挑战。
🔮 我们对 2023 年的预测(一个完全可转移到不同知识图(KG)的归纳模型,可以处理新的实体和关系集合,例如在 Wikidata 上进行训练,并在 DBpedia 或 Freebase 上运行推理)在多项工作中得以验证:
Gao et al提出了双重等变性(double equivariance)概念,强制神经网络对节点 ID 和关系 ID 的排列保持等变性。所提出的 ISDEA++模型采用了类似于DSS-GNN的关系诱导子图和由所有其他关系类型诱导的子图的聚合。
ULTRA由Galkin et al提出,学习关系交互的不变性(通过关系图捕捉)并迁移到任何多关系图。即使在零-shot 推理设置下,ULTRA 也能在数十个归纳和传导数据集上取得 SOTA 结果。此外,它还为 KG 推理提供了类似基础模型的方法,包括通用预训练、零-shot 推理和任务特定的微调。

ULTRA 的三个主要步骤:(1)构建关系图;(2)在关系图上进行条件消息传递以获取相对的关系表示;(3)利用这些表示作为归纳链接预测 GNN 在实体层面上的输入。来源:Galkin et al
了解更多关于归纳推理的信息,详见最近的博客文章:
一个模型统治所有
towardsdatascience.com
随着这个重大挑战似乎已经解决,还有什么留给知识图谱(KG)研究的吗?我们是不是该庆祝一下,结束这一天,开个派对,然后继续前进?
Michael Galkin (Intel)
“事实上,随着这个重大挑战的解决,感觉有些像是一次生存危机——一切重要的东西都已经发明了,图形机器学习使得五年前看似不可能的事情变成了现实。也许,KG 社区应该重新发明自己,专注于那些可以通过图基础模型解决的实际问题。否则,子领域将像语义网那样从研究雷达中消失。”——Michael Galkin(Intel)
归纳性和浅层 KG 嵌入已经过时,2024 年没有人应该再研究它们,是时候永久退休了。类似 ULTRA 的基础模型现在可以在任何图上运行,而不需要进行训练,这对于许多封闭企业的 KG 来说是一个理想的选择。
➡️ 最后一个未探索的领域是超越简单链接预测的归纳推理(复杂的类似数据库的逻辑查询),我认为它将在 2024 年解决。加入时间维度、LLM 节点特征或扩大 GNN 以处理更大图的任务,已成为时间问题,并且更多的是工程任务,而非研究问题。
Zhaocheng Zhu (Mila & Google)
“随着 LLM 和众多基于提示的推理技术的兴起,似乎知识图谱推理已经走到了尽头。文本比知识图谱更具表现力和灵活性,同时它们在数量上也更加丰富。然而,我不认为知识图谱社区开发的推理技术是徒劳的。”——朱兆诚(Mila & 谷歌)
➡️ 我们看到许多 LLM 推理方法与知识图谱中的一些著名思想相吻合。例如,直接提示与思维链(CoT)之间的区别,与知识图谱中的嵌入方法和基于路径的方法有很多相似之处,后者通过参数化更小的步骤,从而能够更好地推广到新的步骤组合。事实上,归纳性和多步骤推广等主题,比 LLM 更早几年就在知识图谱上进行了探讨。
当我们为 LLM 开发新技术时,回顾知识图谱(KG)中类似目标和解决方案是至关重要的。简而言之,尽管知识图谱的表现形式可能会在某个时刻消失,但我们从知识图谱推理中学到的洞见将继续在 LLM 时代发挥作用。
时序图学习
黄沈阳、埃马纽埃尔·罗西、安德烈亚·奇尼、因戈·舒尔特斯和迈克尔·高尔金准备了一篇关于时序图学习的独立概述文章!
继续推进网络的演变之旅
towardsdatascience.com
LLM + 图用于科学发现
迈克尔·高尔金(英特尔)
💡2023 年 LLM 无处不在,很难忽视房间里的大象🐘。
“我们已经看到许多方法尝试将图与 LLM 结合。这个子领域正在发展,并且迈出了它的小步伐,这是需要认可的。”——迈克尔·高尔金(英特尔)
我们已经看到许多方法尝试将图与大型语言模型(LLM)结合(有时字面上是在文本提示中表达边缘),其中直接通过边索引进行提示并不真正适用于运行图算法与语言模型,因此关键在于“文本线性化”和合适的提示。在一些值得注意的研究中,你可能会对赵等人的GraphText感兴趣,该方法设计了一种由特征和标签构成的图语法树提示,这些特征和标签来自目标节点的自我子图——GraphText 适用于节点分类。在Fatemi 等人的Talk Like a Graph中,作者研究了图的线性化策略及其如何影响 LLM 在基本任务(如边的存在、节点计数或环检查)上的表现。

标准 GNN(左)与 GraphText(右)。GraphText 将图信息编码成文本序列,并使用 LLM 进行推理。图语法树包含节点属性(如特征和标签)以及关系(如中心节点、第一跳和第二跳)。来源:赵等人
➡️ 尽管处于早期阶段,但已经有三篇近期的综述文章(Li 等人,Jin 等人,Sun 等人)涵盖了多种图的提示方法。总体而言,目前还不清楚LLM 是否适合用作 🔨 特定图任务的合适工具,因为自回归解码、小上下文窗口和图任务的排列不变性等限制。如果你对 LLM 推理有广泛兴趣,可以查看我们最近的博客文章,它总结了 2023 年在这一领域的主要进展。
➡️ LLM 在应用科学任务中展现了更有前景的、有时甚至是意想不到的结果:Bran, Cox 等人的ChemCrow 🐦⬛是一个由工具驱动的 LLM 代理,可以直接用自然语言执行有机化学、合成和材料设计任务(无需复杂的等变 GNN)。例如,查询“寻找并合成一种能够加速 Diels-Alder 反应的硫脲有机催化剂”时,ChemCrow 从基本的 SMILES 字符串开始,设计出一系列操作,最终给出合成平台的操作指令。
类似地,Gruver 等人对 LLaMA-2 进行了微调,使其生成包含晶格参数、原子组成和 3D 坐标的 3D 晶体结构纯文本文件,令人惊讶的是,它与像 CDVAE 这样的 SOTA 几何扩散模型竞争力十足。

实验验证。a) 用户运行脚本以启动 ChemCrow 的示例。b) 对硫脲有机催化剂的查询与合成。c) 实验执行的 IBM 研究 RoboRXN 合成平台(图片转载自国际商业机器公司)。d) 实验验证的化合物。来源:Bran, Cox 等人
🔮 在 2024 年,LLM 的科学应用预计将在广度和深度上都得到扩展:
1️⃣ 扩展到更多的 AI4Science 领域;
2️⃣ 与几何基础模型的集成(因为多模态是未来一年 LLM 的主要关注点);
3️⃣ 热点话题:LLM 将比 GNN 更快地解决 CLRS-30 基准中的quickselect任务🔥
酷炫的 GNN 应用
Petar Veličković(Google DeepMind)
在我用 GNN 向更广泛的观众展示其应用时,我通常依赖一张“兵器库”幻灯片,展示这些年来影响深远的 GNN 应用。随着 2023 年 LLM 发展的显著标志,我在想——我能否仅使用今年发布的模型,来有意义地更新这张幻灯片呢?
“那时正是年中的时候,我已经有了一个惊喜;我没有足够的空间列出所有使用 GNN 做出的精彩成果!”——Petar Veličković(Google DeepMind)
💡 虽然这一点可能相对不太引起注意,但我自信地声称,2023 年是最激动人心的一年,因为在许多酷炫的 GNN 应用中!LLM 的兴起清楚地表明了基于文本的自回归模型的局限性,而对于大多数来自自然界的科学问题,其图结构是不能被忽视的。
这里列出的是我个人最喜欢的几个标志性成果——所有这些都发表在顶级期刊上:
GraphCast为我们提供了一个标志性的模型,用于中范围的全球天气预报⛈️,并借此更加准确地预测极端事件,如飓风。这是《科学》杂志封面上应得的荣誉!
在材料科学领域的一个杰出发展中,GNoME使用基于 GNN 的模型发现了数百万种新型晶体结构💎——这被称为*“人类已知稳定材料的数量级扩展”*。已发表在《自然》杂志。
我们不仅迎来了一项,而且迎来了两项在抗生素发现方面的突破💊,这两项成果都采用了消息传递神经网络——后者已发表在《自然》杂志!
GNN 可以嗅觉 👃 通过观察分子结构发出的气味——这一发现可能会彻底改变多个行业,包括香水行业!已发表在《科学》杂志。
在《自然机器智能》杂志的封面上,HYFA🍄展示了如何使用超图分解法在基因表达填补问题上取得显著进展🧬!
最后但同样重要的是,粒子物理学⚛️仍然是 GNN 应用的自然强项。在今年的《自然物理学评论》上,我们欣赏到了一篇精彩的综述,阐明了图神经网络在大型强子对撞机⚡中进行各种数据分析任务的多种应用方式。
⚽ 我今年在 GNN 应用领域的谦逊贡献是TacticAI,它是首个为(足球)教练提供有用战术建议的完整 AI 系统,与我们在利物浦足球俱乐部的合作伙伴共同开发 🔴。TacticAI 能够进行预测建模(“这个战术场景中会发生什么?”)、检索相似战术,并进行条件生成建模(“如何调整球员位置以实现特定结果?”)。在我看来,这个非常有趣的合作中,最令人满意的部分是我们与利物浦俱乐部顶级教练团队的用户研究——直接表明我们模型的输出对教练们的工作是有用的 🏃。

TacticAI 的“鸟瞰”概览。(A)展示了角球情况如何转化为图表示。每个球员被视为图中的一个节点,节点、边和图的特征如正文中所述提取。然后,图神经网络在这个图上进行消息传递操作;每个节点的表示通过从其邻近节点接收到的消息进行更新。(B)展示了 TacticAI 如何处理给定的角球。为了确保 TacticAI 在面对水平或垂直反射时能够提供稳健的答案,所有可能的反射组合都会应用到输入角球上,随后这四种视图被输入到核心 TacticAI 模型中,在这些视图之间相互作用,以计算最终的球员表示——每个“内部蓝色箭头”对应于(A)中的单个消息传递层。一旦计算出球员表示,它们就可以用来预测角球的接球者,是否已射门,以及对球员位置和速度的辅助调整,这些调整会增加或减少射门的概率。来源:Wang, Veličković, Hennes et al.
这就是我的追求——显著增强人类能力的 AI 系统。我只能希望,在我的祖国,帕尔蒂赞能比红星先采纳这些方法!😅
🔮 2024 年我们会看到什么?可能会是相同的内容,只是加速了!⏩
几何华尔街公报 💸
Nathan Benaich (AirStreet Capital), Michael Bronstein (Oxford) 和 Luca Naef (VantAI)
2023 年开始时,BioNTech(广为人知的是开发了 mRNA SARS-CoV-2 疫苗)宣布收购 InstaDeep,这是一家专注于人工智能驱动药物发现、设计和开发的成立十年的英国公司。2023 年 5 月,Recursion 收购了两家初创公司,Cyclica 和 Valence,“以增强化学和生成性 AI 能力”。Valence ML 团队因在几何和图形机器学习领域的多项工作而著名,并在YouTube上举办图形与几何和分子建模与药物发现研讨会。
💰Isomorphic Labs 以宣布与 Eli Lilly 和 Novartis 进行小分子药物相关的合作开始了 2024 年,分别获得 4500 万美元和 3750 万美元的预付款,潜在价值为30 亿美元。
💰VantAI 与 Blueprint Medicines合作,共同开发创新的邻近调节治疗方法,包括分子胶和异二功能候选药物。该交易的潜在价值为 12.5 亿美元。
💰CHARM Therapeutics 获得更多资金,资金来自NVIDIA和Bristol Myers Squibb,使初始融资总额达到 7000 万美元。该公司开发了 DragonFold,这是其专有的蛋白质-配体共折叠算法。
💊 Monte Rosa 宣布了 MRT-2359 的成功1 期研究(口服生物可利用的研究性分子胶降解剂),针对 MYC 驱动的肿瘤,如肺癌和神经内分泌癌。Monte Rosa 以使用几何深度学习进行蛋白质研究(MaSIF)而闻名。
Nathan Benaich(AirStreet Capital,《AI 现状报告》作者 the State of AI Report)
“我一直对 AI 优先的方法在医学、生物技术和材料科学中解决问题的潜力持乐观态度。基于图的模型在 2023 年技术生物学领域表现突出。” — Nathan Benaich(AirStreet Capital)
RFdiffusion将扩散技术与图神经网络(GNN)结合,用于预测蛋白质结构。它通过去噪来自蛋白质数据银行(Protein Data Bank)的模糊或损坏的结构,同时利用 RoseTTAFold 的预测能力。DeepMind 继续进一步开发 AlphaFold,并在其基础上进行扩展。他们的AlphaMissense使用弱标签、语言建模和 AlphaFold 来预测 7100 万个人体变异的致病性。这是一个重要的成就,因为大多数由遗传变异引起的氨基酸变化的影响仍不明确。
除了蛋白质,基于图的模型也在提升我们对遗传学的理解。斯坦福的GEARS系统将深度学习与基因互作知识图谱结合,能够从组合扰动中预测基因表达变化。通过利用单一和双重扰动的先前数据,GEARS 能够预测成千上万个基因对的结果。

GEARS 能够预测新的生物学上有意义的表型。(a) 用于预测一组基因的所有成对组合扰动结果的工作流程。(b) 102 种单基因扰动和 128 种双基因扰动的扰动后基因表达的低维表示,用于训练 GEARS。随机选择的数据已标注。(c) GEARS 预测所有 5,151 对组合的 102 种单基因扰动的扰动后基因表达。预测的扰动后表型(非黑色符号)通常与实验中观察到的表型(黑色符号)不同。颜色表示使用标记基因表达进行标记的 Leiden 簇。来源:Roohani et al
🔮 在 2024 年,我对两个不同的发展方向寄予希望。
1️⃣ 我们已经看到了美国和英国批准的首两种 CRISPR-Cas9 疗法。这些基因组编辑器是通过测序和随机实验发现的。我对利用 AI 模型按需设计和创建定制化编辑器感到兴奋。
2️⃣ 我们已经开始看到多模态技术进入 AI 生物领域 —— 结合 DNA、RNA、蛋白质、细胞和成像数据,为我们提供更全面的生物学理解。
2024 年值得关注的公司
Profluent — 用于蛋白质设计的大型语言模型(LLM)
Inceptive.bio — 由《Transformers》论文的作者之一创办。
VantAI — 我们可能有些偏见(迈克尔·布朗斯坦是 Vant 的首席科学家,卢卡·纳夫是创始人兼 CTO),但这是一家很酷的公司,专注于利用机器学习和专有实验技术的结合,理性设计分子胶,这种组合我们认为是成功的关键。
Future House — 一家新成立的位于硅谷的非营利公司,致力于 AI4Science 领域,由前 Google CEO 埃里克·施密特资助。科学负责人是安德鲁·怀特,他因在化学领域的 LLM 工作而闻名。该公司自我描述的使命是“建设一个 AI 科学家”,一个“登月计划”。
关于几何学和图深度学习的更多文章,请参见 Michael Galkin和 Michael Bronstein的 Medium 文章,并在 Twitter 上关注两位迈克尔(Galkin 和 Bronstein)*。
图神经网络:欺诈检测与蛋白质功能预测
理解机器学习工程师在生物领域的 AI 应用
·发表于Towards Data Science ·阅读时长 7 分钟·2024 年 11 月 21 日
--

图片由Conny Schneider提供,来源:Unsplash
金融交易网络和蛋白质结构有什么共同点?它们在欧几里得(x, y)空间中都表现得不佳,需要对复杂、大规模、异质图形进行编码,才能真正理解。

左:欧几里得空间中的图像。右:非欧几里得空间中的图形。来自《图神经网络:方法与应用综述》
图是表示金融网络和蛋白质结构中关系数据的自然方式。它们捕捉实体之间的关系与互动,例如金融系统中账户之间的交易,或蛋白质中氨基酸之间的键合与空间接近度。然而,更广为人知的深度学习架构,如 RNN/CNN 和 Transformer,无法有效地建模图形。
你可能会问,为什么我们不能把这些图形直接映射到 3D 空间?如果我们强行将它们放入 3D 网格:
我们将失去边缘信息,例如分子图中的键类型或交易类型。
映射可能需要填充或调整大小,从而导致扭曲。
稀疏的 3D 数据结果会导致许多未使用的网格单元,从而浪费内存和处理能力。
鉴于这些局限性,**图神经网络(GNNs)**作为一种强大的替代方案。在我们关于生物学应用中的机器学习系列的继续中,我们将探讨 GNNs 如何应对这些挑战。
一如既往,我们将从更熟悉的欺诈检测话题开始,然后学习这些类似的概念在生物学中的应用。
欺诈检测
为了清晰明确,让我们首先定义什么是图。我们记得在小学时在 x、y 轴上绘制图形,但我们当时真正做的是绘制函数图,在图上绘制 f(x)=y 的点。当我们在 GNN(图神经网络)的背景下谈论“图”时,我们指的是建模对象之间的配对关系,其中每个对象是一个节点,而关系是边。
在金融网络中,节点是账户,边是交易。该图将由关联方交易(RPT)构建,并可以通过属性(例如时间、金额、货币)进行丰富。


左侧:函数图(我们不讨论的内容)(2024 年 3 月 15 日)。摘自维基百科。en.wikipedia.org/wiki/Graph_of_a_function 右侧:具有节点和边的图(我们讨论的内容)(2024 年 10 月 25 日)。摘自维基百科。en.wikipedia.org/wiki/Graph_theory
传统的基于规则和机器学习的方法通常只处理单一的交易或实体。这一局限性未能考虑到交易如何与更广泛的网络连接。因为诈骗者通常跨多个交易或实体进行操作,所以诈骗可能未被发现。
通过分析图形,我们可以捕捉直接邻居和更远连接之间的依赖关系和模式。这对于检测洗钱至关重要,因为资金通过多个交易流动以掩盖其来源。GNNs 揭示了洗钱方法所创造的密集子图。

与使用 GNN 检测基于关联方交易网络的金融欺诈中的关联方转移网络示例
消息传递框架
与其他深度学习方法类似,目标是从数据集中创建表示或嵌入。在 GNN 中,这些节点嵌入是通过消息传递框架创建的。消息在节点之间迭代传递,使得模型能够学习图的局部和全局结构。每个节点的嵌入是基于其邻居特征的聚合来更新的。
该框架的概括如下:
初始化:嵌入向量*hv(0)*通过基于节点特征的嵌入、随机嵌入或预训练嵌入(例如账户名称的词嵌入)进行初始化。
消息传递: 在每一层 t,节点与它们的邻居交换消息。消息定义为发送节点的特征、接收节点的特征以及连接它们的边的特征,并通过一个函数进行组合。组合函数可以是简单的拼接,并使用固定权重方案(由图卷积网络,GCNs 使用),也可以是加权注意力机制,其中权重基于发送节点和接收节点的特征(以及可选的边特征)进行学习(由图注意力网络,GATs 使用)。
聚合: 在消息传递步骤之后,每个节点聚合接收到的消息(可以是均值、最大值、总和等简单操作)。
更新: 聚合后的消息通过更新函数(可能是像 ReLU 的多层感知器(MLP),门控循环单元(GRU)或注意力机制)更新节点的嵌入。
最终化: 像其他深度学习方法一样,当表示稳定或达到最大迭代次数时,嵌入会被最终化。

在消息传递神经网络(MPNN)层中,节点表示更新。节点接收所有直接邻居发送的消息。消息通过消息函数计算,该函数考虑了发送方和接收方的特征。图神经网络。(2024 年 11 月 14 日)。摘自维基百科。en.wikipedia.org/wiki/Graph_neural_network
在节点嵌入学习完成后,可以通过几种不同的方式计算欺诈分数:
分类: 最终的嵌入被传入分类器,如多层感知器(MLP),这需要一个全面的标注历史训练集。
异常检测: 嵌入被分类为异常,基于其与其他嵌入的差异。这里可以使用基于距离的分数或重构误差来进行无监督的方法。
图级评分: 嵌入被聚合到子图中,然后传入分类器,以检测欺诈环(同样需要一个标注的历史数据集)。
标签传播: 一种半监督方法,其中标签信息基于边的权重或图的连接性进行传播,从而对未标注的节点进行预测。
现在我们已经对图神经网络(GNN)有了基础理解,可以转向 GNN 的另一个应用:预测蛋白质的功能。
蛋白质功能预测
我们已经看到了通过 AlphaFold 2和3在蛋白质折叠预测中的巨大进展,以及通过RFDiffusion在蛋白质设计中的突破。然而,蛋白质功能预测仍然具有挑战性。功能预测对于许多原因至关重要,但特别在生物安全领域尤为重要,因为它有助于预测 DNA 在测序之前是否会发生孤雌生殖。传统方法如BLAST依赖于序列相似性搜索,并未结合任何结构数据。
今天,GNNs(图神经网络)开始通过利用蛋白质的图表示来建模残基之间的关系及其相互作用,在这一领域取得了有意义的进展。它们被认为非常适合蛋白质功能预测、识别小分子或其他蛋白质的结合位点,并基于活性位点几何形状对酶家族进行分类。
在许多例子中:
节点被建模为氨基酸残基
边缘作为它们之间的相互作用
这种方法背后的合理性在于,图结构具有捕捉序列中远距离但在折叠结构中接近的残基之间长程相互作用的固有能力。这类似于为什么 transformer 架构对 AlphaFold 2 如此有帮助,因为它允许在序列中的所有对之间进行并行计算。
为了使图形信息更加密集,每个节点可以通过残基类型、化学性质或进化保守性得分等特征进行丰富。边缘可以选择性地通过化学键类型、三维空间中的接近度以及静电或疏水相互作用等属性进行丰富。
DeepFRI是一种基于图神经网络(GCN)的方法,用于根据结构预测蛋白质功能。GCN 是一种特定类型的 GNN,它将卷积的概念(用于 CNN)扩展到图数据。

DeepFRI 图示:LSTM 语言模型,经过约 200 万 Pfam 蛋白质序列的预训练,用于提取 PDB 序列的残基级特征。(B)具有 3 个图卷积层的 GCN,用于学习复杂的结构-功能关系。来自基于结构的功能预测使用图卷积网络
在 DeepFRI 中,每个氨基酸残基都是一个节点,具有如下一些属性:
氨基酸类型
物理化学性质
来自MSA的进化信息
来自预训练 LSTM 的序列嵌入
结构背景,如溶剂可接近性。
每个边被定义为捕获蛋白质结构中氨基酸残基之间的空间关系。如果两个节点(残基)之间的距离低于一定阈值,通常为 10 Å,则存在一条边。在这个应用中,边没有属性,它们只是无权重的连接。
图是用节点特征 LSTM 生成的序列嵌入以及从残基接触图创建的残基特异性特征和边信息进行初始化的。
一旦图被定义,消息传递发生在每个三层的基于邻接的卷积中。节点特征是通过使用图的邻接矩阵从邻居聚合的。堆叠多个 GCN 层允许嵌入从越来越大的邻域中捕获信息,从直接邻居开始,延伸到邻居的邻居等。
最终的节点嵌入是全局汇集的,用于创建蛋白质级别的嵌入,然后用于将蛋白质分类为层次相关的功能类别(GO 术语)。分类是通过将蛋白质级别的嵌入通过具有 sigmoid 激活函数的全连接层(密集层)进行传递,使用二元交叉熵损失函数进行优化来执行的。分类模型是基于从蛋白质结构(例如来自蛋白质数据银行)和来自 UniProt 或 Gene Ontology 等数据库的功能注释的数据进行训练的。
总结思考
图对于建模许多非线性系统非常有用。
GNNs 通过同时融合局部和全局信息来捕获传统方法难以建模的关系和模式。
GNNs 有许多变体,但目前最重要的是图卷积网络和图注意力网络。
GNNs 可以有效地识别使用监督和无监督方法在洗钱方案中存在的多跳关系。
GNNs 可以通过整合结构数据改进仅基于序列的蛋白质功能预测工具,如 BLAST。这使研究人员能够预测与已知蛋白质具有最小序列相似性的新蛋白质的功能,这是理解生物安全威胁和促进药物发现的关键步骤。
祝福您,如果您喜欢这篇文章,请查看我在机器学习和生物学上的其他文章。
图神经网络 第一部分:图卷积网络解析

该图片由作者使用 Dall·E 创建。
图节点分类与图卷积网络
·发表于Towards Data Science ·阅读时间 11 分钟·2024 年 10 月 1 日
--
数据并不总是整齐地适应行和列。实际上,数据通常遵循图结构,例如社交网络、蛋白质结构、推荐系统或交通运输系统。忽略图拓扑结构中的信息可能会大幅降低机器学习模型的表现。幸运的是,有一种方法可以将这些信息纳入模型。
图神经网络(GNNs)旨在从以节点和边表示的数据中学习。随着时间的推移,GNNs 不断发展,在这篇文章中你将了解图卷积网络(GCNs)。我的下一篇文章将介绍图注意力网络(GATs)。GCNs 和 GATs 是当前最先进的模型所基于的两种基本架构,因此如果你想学习关于 GNNs 的知识,这是一个不错的起点。让我们开始吧!
对图形不熟悉?这篇文章的第一部分(图基础)解释了图形的基础知识。此外,你还应该了解神经网络(本文章的“数据集与前提”部分提供了简短的回顾)。
图神经网络 第二部分:图注意力网络与图卷积网络

图片由作者使用 Dall·E 创作。
一种关注你图的模型
·发表于Towards Data Science ·阅读时间:8 分钟·2024 年 10 月 8 日
--
欢迎来到第二篇关于图神经网络(GNN)架构的文章!在上一篇文章中,我们通过使用图卷积网络(GCN)将图结构引入模型,取得了 Cora 数据集上的准确率惊人提升。本文将解释图注意力网络(GATs),图神经网络的另一种基础架构。我们能否通过 GAT 进一步提高准确率呢?
首先,让我们讨论一下 GAT 和 GCN 之间的区别。然后,我们将训练一个 GAT,并将其准确率与 GCN 和基本神经网络进行比较。
这篇博客文章是系列文章的一部分。如果你是 GNN 的新手,推荐从第一篇文章开始阅读,文章中讲解了图、神经网络、数据集以及 GCN。
图注意力网络
在上一篇文章中,我们看到了 GCN 的实际应用。接下来让我们更进一步,探讨图注意力网络(GATs)。正如你可能记得的那样,GCN 对所有邻居的处理是相同的。而 GAT 则有所不同。GAT 允许模型为不同的邻居学习不同的重要性(注意力)得分。它们通过使用注意力机制来聚合邻居的信息(这可能会让你想起……)
图结构 RAG — 概念介绍
图结构 RAG 解答了文本嵌入无法帮助你解决的重大问题。
·发表于 Towards Data Science ·阅读时间 8 分钟 ·2024 年 8 月 22 日
--
检索增强生成(RAG)自从 ChatGPT 的出现引爆了人工智能热潮以来,便主导了关于使生成式 AI 应用变得实用的讨论。这个理念很简单。一旦将大规模语言模型(LLMs)与我们的私人数据相连接,它们便变得尤为有用。一个人人都能访问的基础模型,结合我们的领域特定数据作为“秘密武器”,就会形成一个强大且独特的工具。就像在人类社会中一样,AI 系统似乎发展成了一个专家经济。通用知识是一个有用的基础,但专家知识则能够发挥 AI 系统的独特卖点。
总结:RAG 本身并没有描述任何特定的架构或方法。它只是展示了如何通过任意检索方法增强给定的生成任务。原始的 RAG 论文(Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks by Lewis et. al.) 比较了双塔嵌入方法与词袋检索方法。
本地和全局问题
基于文本嵌入的检索已经在许多场合中得到了描述。它已经使我们的 LLM 应用程序能够极为可靠地基于给定知识库的内容回答问题。Text2Vec 检索的核心优势仍然是:从嵌入的知识库中提取一个给定事实,并使用该提取的事实来构造对用户查询的回答。然而,文本嵌入搜索也面临着重大挑战。通常,每个文本嵌入表示来自非结构化数据集的一个特定块。最近邻搜索找到的嵌入代表了与用户查询语义相似的块。这也意味着搜索是语义上的,但仍然非常具体。因此,候选答案的质量在很大程度上依赖于查询的质量。此外,嵌入表示的是知识库中提到的内容。这并不代表你在寻找那些需要跨文档或在文档内跨概念抽象的问题的情况。
例如,假设有一个知识库,其中包含所有过去诺贝尔和平奖得主的简历。向 Text2Vec-RAG 系统提问“谁赢得了 2023 年诺贝尔和平奖?”是一个很容易回答的问题。这个事实在嵌入的文档块中得到了很好的表示。因此,最终的答案可以基于正确的上下文得出。另一方面,RAG 系统可能会在回答“过去十年最著名的诺贝尔和平奖得主是谁?”时遇到困难。如果我们添加更多的上下文信息,比如“谁是过去十年中在中东冲突中作出突出贡献的诺贝尔和平奖得主?”我们可能会成功,但即使如此,仅凭文本嵌入(考虑到当前嵌入模型的质量)仍然难以完全解决这个问题。另一个例子是整个数据集推理。例如,用户可能会对你的 LLM 应用程序提出问题:“最近的诺贝尔和平奖得主支持的前三大议题是什么?”嵌入的文档块不允许跨文档推理。我们的最近邻搜索在知识库中寻找特定提到“最近的诺贝尔和平奖得主支持的前三大议题”的内容。如果知识库中没有包含这一内容,任何完全基于文本嵌入的 LLM 应用程序都将面临困难,并且很可能无法正确甚至全面地回答这个问题。
我们需要一种替代的检索方法,使我们能够回答这些“全局”的汇总性问题,除了“局部”的提取性问题之外。欢迎使用图形 RAG!
知识图谱是一种半结构化的、层次化的信息组织方法。一旦信息被组织为图谱,我们不仅可以推断出特定节点的信息,还可以推断出它们的关系和邻居。图谱结构使得对全球数据集层面的推理成为可能,因为节点和它们之间的连接可以跨越不同的文档。有了这个图谱,我们还可以分析邻近的节点和更紧密连接的节点社区,这些节点之间的联系比它们与其他节点的联系更为紧密。一个节点的社区可以大致覆盖一个感兴趣的主题。通过抽象社区节点及其连接,我们可以对该主题中的概念有一个抽象的理解。图谱 RAG 利用对知识图谱中社区的理解,来为特定用户查询提供上下文。
一个图谱 RAG 管道通常会遵循以下步骤:
图谱提取
图谱存储
社区检测
社区报告生成
最终上下文构建的映射归约

图谱 RAG 逻辑可视化 — 来源:作者提供的图片
图谱提取
建立我们非结构化知识库的抽象理解的过程始于提取构建知识图谱的节点和边。你可以通过一个大型语言模型(LLM)自动化这个提取过程。此步骤的最大挑战是决定哪些概念和关系是相关的,值得包括在内。为了举一个例子来说明这个高度模糊的任务:假设你正在从一篇关于沃伦·巴菲特的文档中提取知识图谱。你可以提取他的持股、出生地以及许多其他事实作为实体,并附上相应的边。这些信息很可能对你的用户来说是高度相关的。(如果文档内容合适)你也可以提取他在最后一次董事会会议上领带的颜色。这些信息(很可能)对用户来说是无关的。至关重要的是,需要根据应用程序的使用场景和领域来指定提取提示。这是因为提示将决定从非结构化数据中提取哪些信息。例如,如果你有兴趣提取有关人的信息,你需要使用与提取公司信息不同的提示。
指定提取提示的最简单方法是通过多轮提示。这涉及给 LLM 多个期望的输入和输出示例。例如,你可以给 LLM 一系列关于人物的文档,并要求它提取每个人的姓名、年龄和职业。LLM 随后会学会从新文档中提取这些信息。指定提取提示的更高级方法是通过 LLM 微调。这涉及在包含期望输入和输出示例的数据集上训练 LLM。这可能比多轮提示带来更好的表现,但也更耗时。
图存储
你设计了一个可靠的提取提示并调整了你的 LLM。你的提取管道正常工作。接下来,你需要考虑如何存储这些结果。图数据库(DB)如 Neo4j 和 Arango DB 是直接的选择。然而,增加另一种数据库类型并学习一种新的查询语言(例如 Cypher/Gremlin)可能会非常耗时。从我的高层研究来看,目前没有很好的无服务器选项可用。如果处理大多数图数据库的复杂性还不够,这最后一个问题对我这种无服务器爱好者来说简直是致命的。不过,还是有替代方案的。通过一点创意,使用合适的数据模型,图数据可以被格式化为半结构化,甚至是严格结构化的数据。为了激发你的灵感,我编写了 graph2nosql 作为一个简单的 Python 接口,用于在你最喜欢的 NoSQL 数据库中存储和访问你的图数据集。
数据模型定义了节点、边和社区的格式。将这三者存储在不同的集合中。每个节点、边和社区最终通过唯一标识符(UID)进行识别。Graph2nosql 然后实现了在处理知识图时所需的几个基本操作,如添加/删除节点/边、可视化图、检测社区等。

graph2nosql 数据模型 — 来源:作者提供的图片
社区检测
一旦图被提取并存储,下一步就是识别图中的社区。社区是比其他节点更紧密连接的节点集群。可以使用各种社区检测算法来完成此任务。
一种流行的社区检测算法是 Louvain 算法。Louvain 算法通过迭代地将节点合并为社区,直到满足某个停止准则。停止准则通常基于图的模块度。模块度是衡量图如何被划分为社区的指标。
其他流行的社区检测算法包括:
Girvan-Newman 算法
快速展开算法
Infomap 算法
社区报告生成
现在,使用结果社区作为生成社区报告的基础。社区报告是每个社区内节点和边缘的总结。这些报告可以用来理解图结构,并识别知识库中的关键主题和概念。在知识图中,每个社区都可以理解为代表一个“主题”。因此,每个社区可能是回答不同类型问题的有用上下文。
除了总结多个节点的信息外,社区报告还是跨概念和文档的第一个抽象层次。一个社区可以跨越由多个文档添加的节点。这样,你就在构建一个关于已索引知识库的“全局”理解。例如,在你的诺贝尔和平奖获奖者数据集中,你可能提取了一个社区,代表所有与节点“诺贝尔和平奖”通过边缘描述“获奖者”连接的“人物”类型的节点。
来自微软图 RAG 实现的一个好点子是“发现”。在一般的社区总结之上,这些发现是关于社区的更详细的见解。例如,对于包含所有历史诺贝尔和平奖获奖者的社区,一个发现可能是连接他们大部分行动主义的某些主题。
就像图提取一样,社区报告生成的质量将高度依赖于领域和用例的适应程度。为了创建更准确的社区报告,可以使用多次提示或 LLM 微调。
Map Reduce 用于最终上下文构建
在查询时,你使用 Map-Reduce 模式首先生成中间响应和最终响应。
在映射步骤中,你将每个社区-用户查询对结合起来,并使用给定的社区报告生成用户查询的答案。除了这个中间响应之外,你还会请求 LLM 评估给定社区报告作为用户查询的上下文相关性。
在归约步骤中,你将对生成的中间响应的相关性评分进行排序。排名前 k 的相关性评分代表回答用户查询的感兴趣社区。相应的社区报告,可能结合节点和边缘信息,构成最终 LLM 提示的上下文。
结语:这将走向何方?
在知识库问答任务中,Text2vec RAG 存在明显的空白。图形 RAG 可以弥补这些空白,而且做得非常好! 通过生成社区报告的额外抽象层为你的知识库提供了重要的见解,并建立了其语义内容的全球理解。这将为团队节省大量筛选文档中特定信息的时间。如果你正在构建一个 LLM 应用程序,它将使你的用户能够提出真正重要的大问题。你的 LLM 应用程序将突然能够“预见”问题的本质,理解用户数据中的情况,而不仅仅是从中引用内容。
另一方面,一个图形 RAG 管道(如本文所述的原始形式)比 text2vec RAG 管道需要更多的 LLM 调用。特别是社区报告和中间答案的生成是潜在的薄弱环节,这将花费大量的资金和延迟。
正如在搜索领域中经常出现的情况一样,你可以预期,围绕先进 RAG 系统的行业将朝着混合方法发展。针对特定查询使用合适的工具将是扩展 RAG 应用程序时至关重要的。例如,可以想象使用一个分类层来区分本地和全球查询。也许生成社区报告和发现就足够了,将这些报告作为抽象的知识加入到你的索引中,作为上下文候选项就足够了。
幸运的是,完美的 RAG 管道尚未解决,你的实验将成为解决方案的一部分。我很想听听你的进展如何!
图谱 RAG、自动化提示工程、智能体框架及其他九月必读文章
·发表于 Towards Data Science ·通过 时事通讯 发送 ·3 分钟阅读·2024 年 10 月 3 日
--
想写你的第一篇 TDS 文章吗?我们始终欢迎新作者的投稿。
每年九月到来时,空气中总是充满了激动人心和充满活力的气息,今年也不例外。确实,告别漫长的阳光明媚的日子和略显缓慢的节奏可能让人感到有些惆怅,但这种感觉不会持续太久——尤其是在机器学习和人工智能领域如此多的创新和新工具诞生的时候,还有大量的新技能等待我们去学习。
我们很高兴与大家分享过去一个月里最受欢迎和最多人分享的文章,以防你错过了其中任何一篇(或者只是想重新浏览一篇或两篇最爱)。比以往任何时候都更能代表我们作者覆盖的各个领域,从核心编程技能到前沿的 LLM 技术,因此我们确信你一定会在我们的九月精选中找到让你感兴趣的内容。祝你阅读愉快,愿新的一季充满学习和成长!
月度精选
如何使用知识图谱和向量数据库实现图谱 RAG 本月最受欢迎的文章来自Steve Hedden:一篇清晰且易于理解的逐步教程,介绍了如何实现检索增强生成(RAG)、语义搜索和推荐。
数据科学家在 Python 中无法出色表现,除非掌握这些函数总有空间为另一个扎实的 Python 教程——Jiayan Yin 的数据科学家关键函数大全对我们的读者尤其有帮助。
Python 快速入门:为学习 AI 的人设计更多的 Python!Shaw Talebi 的初学者友好指南专注于你需要掌握的编程主题,如果你的最终目标是开发自定义的 AI 项目和产品。
自动化提示工程:权威实用指南想学习如何自动化提示工程并显著提升你在 LLM 工作负载中的性能吗?不要错过 Heiko Hotz 的实用指南。

图片由 Nahrizul Kadri 提供,来自 Unsplash
**使用 Python 构建 GenAI:从零开始创建代理(完整教程)**通过结合 Ollama、LangChain 和 LangGraph 的强大功能,Mauro Di Pietro 带领我们走过了创建自定义 AI 代理的完整工作流程。
**SQL:掌握数据工程基础(第一部分)**无论你是 SQL 新手,还是需要复习的老手,Leonardo Anello 的全面介绍,专门为数据工程师量身定制,是一个强大的、一站式资源。
选择 LLM 代理框架构建定制的基于代码的代理与依赖于主要代理框架之间有什么权衡?Aparna Dhinakaran 分享了关于这一关键问题的实用见解和建议。
每个数据科学家都应该了解的分析框架凭借她作为顾问的丰富经验,Tessa Xie为数据专业人士提供了有用的建议,教你如何将一个抽象的商业问题拆解成更小、更明确的分析任务。
超越折线图和柱状图:7 种不太常见但强大的可视化类型从撞击图到圆形柱状图,再到桑基图,Yu Dong邀请我们扩展视觉设计词汇,尝试那些不太常见的可视化方法。
让你的简历在 FAANG 申请中真正脱颖而出的 5 个技巧在竞争激烈的市场中,每一个细节都至关重要,微小的调整就能带来重大变化——这就是为什么你应该参考Khouloud El Alami为当前求职者提供的切实可行的建议。
我们最新的一批新作者
每个月,我们都很高兴看到一批新作者加入 TDS,他们各自分享着自己独特的声音、知识和经验。如果你正在寻找新的作家来探索和关注,可以浏览我们最新加入的作者的作品,包括Alexander Polyakov、Harsh Trivedi、Jinhwan Kim、Lenix Carter、Gilad Rubin、Laurin Brechter、Shirley Bao, Ph.D.、Iqbal Rahmadhan、Jesse Xia、Sezin Sezgin-Rummelsberger、Reinhard Sellmair、Yasin Yousif、Hui Wen Goh、Amir Taubenfeld、Sébastien Saurin、James Gearheart、Zackary Nay、Jens Linden, PhD、Eyal Kazin、Dan Beltramo、Sabrine Bendimerad、Niklas von Moers、Milan Tamang、Abhinav Prasad Yasaswi、Abhinav Kimothi、Miguel Otero Pedrido、Oliver Ma、Hamza Farooq、Shanmukha Ranganath、Maarten Sukel、Murilo Gustineli、Luiz Venosa、Saankhya Mondal、David Vaughn、Prasad Mahamulkar、Federico Rucci、Philippe Ostiguy, M. Sc.、Anurag Bhagat 和Megan Grant等人。
感谢您支持我们作者的工作!我们非常喜欢发布新作者的文章,如果您最近写了一篇有趣的项目操作指南、教程或关于我们核心主题的理论反思,请不要犹豫,与我们分享。
直到下一个变量,
TDS 团队
将 Graph RAG 投入生产——逐步指南

一个原生于 GCP、完全无服务器的实现,你只需几分钟即可复制
·发布于Towards Data Science ·阅读时长 14 分钟·2024 年 9 月 23 日
--
在概述 Graph RAG 概念后,让我们将其投入生产。这就是如何将 GraphRAG 投入生产:完全无服务器、完全并行化,以最小化推理和索引时间,而且完全不涉及图数据库(保证!)。
在本文中,我将向您介绍graphrag-lite,这是一个端到端的 Graph RAG 数据摄取和查询实现。我将 graphrag-lite 作为开源项目发布,旨在让您在 GCP 上部署 graphrag 时更加轻松。graphrag-lite 是 Google Cloud 原生的,可以直接使用。代码以模块化方式设计,可以根据您选择的平台进行调整。
回顾:
检索增强生成(Retrieval Augmented Generation)本身并未描述任何特定的架构或方法。它仅展示了如何通过任意的检索方法增强给定的生成任务。原始的 RAG 论文(《面向知识密集型 NLP 任务的检索增强生成》,作者:Lewis 等)比较了两塔嵌入方法与词袋检索方法。
现代问答系统区分局部问题和全局问题。一个局部(提取性)问题可能是“2023 年谁获得了诺贝尔和平奖?”,针对一个非结构化的样本知识库。一个全局(聚合性)问题可能是“你知道的最近诺贝尔奖得主是谁?”。Text2embedding RAG 在处理全局和结构化问题时存在明显的缺口。Graph RAG 能够弥补这些缺口,并且做得很好!通过一个抽象层,它学习知识图谱社区的语义,从而构建对索引数据集的更“全球化”的理解。这里是 Graph RAG 的概念介绍,供你阅读。
Graph RAG 管道
一个 Graph RAG 管道通常遵循以下步骤:
图谱提取
这是主要的摄取步骤。你的 LLM 扫描每一个传入的文档,使用提示来提取与我们知识图谱相关的节点和边。你会多次执行此提示,确保捕捉到所有相关的信息。
图谱存储
你将提取的节点和边存储在你选择的数据存储中。专用图数据库(Graph DB)是一个选项,但它们通常比较繁琐。Graph2nosql 是一个基于 Python 的接口,用于在 Firestore 或任何其他 NoSQL 数据库中存储和管理知识图谱。我开源了这个项目,因为我没有找到任何市场上可比的、原生支持知识图谱的选项。
社区检测
一旦你存储了知识图谱数据,你将使用社区检测算法来识别节点群体,这些节点在彼此之间的连接比与图谱其他部分的连接更加密集。在知识图谱的上下文中,假设是密集的社区覆盖了共同的主题。
社区报告生成
然后,你指示你的大型语言模型(LLM)为每个图谱社区生成报告。这些社区报告有助于在单一主题之间进行抽象,以便掌握数据集中更广泛的全球概念。社区报告与知识图谱一起存储。这标志着管道的摄取层部分完成。
用于最终上下文构建的映射-归约(Map-Reduce)模式
在查询时,你遵循映射-归约(map-reduce)模式,为知识图谱中的每个社区报告生成一个中间响应。你还让 LLM 评估每个中间查询响应的相关性。最后,你根据相关性对中间响应进行排名,并选择前 n 个作为最终响应的上下文,回复用户。

Graph RAG 逐步逻辑 — 图像由作者提供
图谱提取
在初始摄取步骤中,你需要指示你的 LLM 将输入文档编码为图形。一个详细的提示会指示你的 LLM 首先识别给定类型的节点,其次识别你所识别节点之间的边。就像任何 LLM 提示一样,这个挑战没有唯一的解决方案。以下是我基于微软的开源实现的图形提取提示核心部分:
-Goal-
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
-Steps-
1\. Identify all entities. For each identified entity, extract the following information:
- entity_name: Name of the entity, capitalized
- entity_type: One of the following types: [{entity_types}]
- entity_description: Comprehensive description of the entity's attributes and activities
Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>
2\. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
For each pair of related entities, extract the following information:
- source_entity: name of the source entity, as identified in step 1
- target_entity: name of the target entity, as identified in step 1
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
- relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity
Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_strength>)
3\. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2\. Use **{record_delimiter}** as the list delimiter.
4\. When finished, output {completion_delimiter}
<Multishot Examples>
-Real Data-
######################
Entity_types: {entity_types}
Text: {input_text}
######################
Output:
提取步骤决定了哪些信息将反映在你的知识库中。因此,你应该使用一个相对强大的模型,比如 Gemini 1.5 Pro。你还可以进一步增加结果的鲁棒性,使用多轮版本的 Gemini 1.5 Pro,并查询模型多次改进其结果。以下是我在 graphrag-lite 中实现图形提取循环的方式:
class GraphExtractor:
def __init__(self, graph_db) -> None:
self.tuple_delimiter = "<|>"
self.record_delimiter = "##"
self.completion_delimiter = "<|COMPLETE|>"
self.entity_types = ["organization", "person", "geo", "event"]
self.graph_extraction_system = prompts.GRAPH_EXTRACTION_SYSTEM.format(
entity_types=", ".join(self.entity_types),
record_delimiter=self.record_delimiter,
tuple_delimiter=self.tuple_delimiter,
completion_delimiter=self.completion_delimiter,
)
self.llm = LLMSession(system_message=self.graph_extraction_system,
model_name="gemini-1.5-pro-001")
def __call__(self, text_input: str, max_extr_rounds: int = 5) -> None:
input_prompt = self._construct_extractor_input(input_text=text_input)
print("+++++ Init Graph Extraction +++++")
init_extr_result = self.llm.generate_chat(
client_query_string=input_prompt, temperature=0, top_p=0)
print(f"Init result: {init_extr_result}")
for round_i in range(max_extr_rounds):
print(f"+++++ Contd. Graph Extraction round {round_i} +++++")
round_response = self.llm.generate_chat(
client_query_string=prompts.CONTINUE_PROMPT, temperature=0, top_p=0)
init_extr_result += round_response or ""
print(f"Round response: {round_response}")
if round_i >= max_extr_rounds - 1:
break
completion_check = self.llm.generate_chat(
client_query_string=prompts.LOOP_PROMPT, temperature=0, top_p=0)
if "YES" not in completion_check:
print(
f"+++++ Complete with completion check after round {round_i} +++++")
break
首先,我调用多轮模型提取节点和边。然后,我要求模型多次改进先前的提取结果。
在 graphrag-lite 的实现中,提取模型的调用由前端客户端发出。如果你想减少客户端的负载,可以将提取查询外包给微服务。
图形存储
一旦你从文档中提取了节点和边,你需要以可访问的格式存储它们。图形数据库是一个选择,但它们也可能比较繁琐。对于你的知识图谱,你可能更倾向于寻找一些更轻量级的解决方案。我也有同样的想法,因为我没有找到任何开源的知识图谱本地库,于是我开源了 graph2nosql。Graph2nosql 是一个简单的知识图谱本地 Python 接口,它帮助你在任何 NoSQL 数据库中存储和管理知识图谱。所有这一切都无需通过图形数据库来扩展你的技术栈,或学习 Cypher。
Graph2nosql 是为知识图谱检索而设计的,考虑到了图形 RAG 的需求。该库围绕三种主要数据类型进行设计:EdgeData(边数据)、NodeData(节点数据)和 CommunityData(社区数据)。节点通过唯一标识符(uid)识别。边通过源节点和目标节点的 uid 以及边的 uid 识别。由于 uid 可以自由设计,graph2nosql 的数据模型为任何规模的知识图谱留出了空间。你甚至可以添加文本或图形嵌入。这使得基于嵌入的分析、边预测和额外的文本嵌入检索成为可能(考虑到混合 RAG)。
Graph2nosql 本地设计时围绕 Firestore。
@dataclass
class EdgeData:
source_uid: str
target_uid: str
description: str
edge_uid: str | None = None
document_id: str | None = None
@dataclass
class NodeData:
node_uid: str
node_title: str
node_type: str
node_description: str
node_degree: int
document_id: str
community_id: int | None = None # community id based on source document
edges_to: list[str] = field(default_factory=list)
edges_from: list[str] = field(default_factory=list) # in case of directed graph
embedding: list[float] = field(default_factory=list) # text embedding representing node e.g. combination of title & description
@dataclass
class CommunityData:
title: str # title of comm, None if not yet computed
community_nodes: set[str] = field(default_factory=set) # list of node_uid belonging to community
summary: str | None = None # description of comm, None if not yet computed
document_id: str | None = None # identifier for source knowlede base document for this entity
community_uid: str | None = None # community identifier
community_embedding: Tuple[float, ...] = field(default_factory=tuple) # text embedding representing community
rating: int | None = None
rating_explanation: str | None = None
findings: list[dict] | None = None
要通过 graph2nosql 存储你的图形数据,只需在解析提取步骤的结果时运行以下代码。这是 graphrag-lite 的实现。
from graph2nosql.graph2nosql.graph2nosql import NoSQLKnowledgeGraph
from graph2nosql.databases.firestore_kg import FirestoreKG
from graph2nosql.datamodel import data_model
fskg = FirestoreKG(
gcp_project_id=project_id,
gcp_credential_file=firestore_credential_file,
firestore_db_id=database_id,
node_collection_id=node_coll_id,
edges_collection_id=edges_coll_id,
community_collection_id=community_coll_id)
node_data = data_model.NodeData(
node_uid=entity_name,
node_title=entity_name,
node_type=entity_type,
node_description=entity_description,
document_id=str(source_doc_id),
node_degree=0)
fskg.add_node(node_uid=entity_name,node_data=node_data)
社区检测
将所有相关的节点和边存储在你的图形数据库中后,你可以开始构建抽象层。实现这一点的一种方式是查找描述相似概念的节点,并描述它们是如何在语义上连接的。例如,Graph2nosql 提供了内建的社区检测,基于 Louvain 社区。
根据您的提取结果质量,您可能会在知识图谱中发现零度节点。从经验来看,零度节点通常是重复的。graphrag-lite 使用图社区作为主要的抽象层,因此您应该删除没有任何边的节点。因此,考虑进行另一个去重/合并步骤和/或基于描述和图嵌入的节点预测步骤,以添加在提取步骤中可能遗漏的边是有意义的。在 graphrag-lite 中,我目前简单地删除所有零度节点。
# clean graph off all nodes without any edges
fskg.clean_zerodegree_nodes()
# generate communities based on cleaned graph
comms = kg.get_louvain_communities()
优化 LLM 应用中的吞吐量延迟
上述提到的 GraphRAG 管道在每个文档摄取和用户查询时都会进行多个 LLM 调用。例如,为每个新索引的文档生成多个社区报告,或者在查询时为多个社区生成中间响应。如果并发处理,结果将导致糟糕的用户体验。特别是在大规模应用时,用户可能需要等待几分钟到几小时才能收到查询的响应。幸运的是,如果您正确构建 LLM 提示,您可以将其设计为“无状态工作者”。无状态处理架构的优势是双重的。首先,它们易于并行化。其次,它们易于实现为无服务器基础设施。结合并行化和无服务器架构,可以最大化吞吐量可扩展性,并最小化空闲集群设置的成本。
在 graphrag-lite 架构中,我将社区报告生成和中间查询生成作为无服务器的 Cloud Run 微服务工作者进行托管。这些工作者通过 GCP 的无服务器消息队列 PubSub 接收消息。

graphrag-lite 的无服务器分布式摄取和查询管道 — 图片由作者提供
社区报告生成
运行社区检测后,您现在知道了多个社区成员节点集。每个集合代表您的知识图谱中的一个语义主题。社区报告步骤需要在这些源自不同文档的概念之间进行抽象处理。我再次基于微软的实现,并添加了一个函数调用,以便轻松解析结构化输出。
You are an AI assistant that helps a human analyst to perform general information discovery. Information discovery is the process of identifying and assessing relevant information associated with certain entities (e.g., organizations and individuals) within a network.
# Goal
Write a comprehensive report of a community, given a list of entities that belong to the community as well as their relationships and optional associated claims. The report will be used to inform decision-makers about information associated with the community and their potential impact. The content of this report includes an overview of the community's key entities, their legal compliance, technical capabilities, reputation, and noteworthy claims.
# Report Structure
The report should include the following sections:
- TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title.
- SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant information associated with its entities.
- IMPACT SEVERITY RATING: a float score between 0-10 that represents the severity of IMPACT posed by entities within the community. IMPACT is the scored importance of a community.
- RATING EXPLANATION: Give a single sentence explanation of the IMPACT severity rating.
- DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive.
社区报告生成还展示了知识图谱检索中的最大挑战。从理论上讲,任何文档都可能向图中的每个现有社区添加一个新节点。在最坏的情况下,您需要在每次添加新文档时重新生成知识库中的每个社区报告。实际上,关键是要包含一个检测步骤,识别在文档上传后发生变化的社区,从而仅为已调整的社区生成新报告。
由于您需要为每次文档上传重新生成多个社区报告,如果同时运行这些请求,我们也面临显著的延迟挑战。因此,您应该将这项工作外包并将其并行化到异步工作者中。如前所述,graphrag-lite 通过使用无服务器架构解决了这一问题。我使用 PubSub 作为消息队列来管理工作项并确保处理。Cloud Run 作为计算平台,托管调用 LLM 的无状态工作者。在生成过程中,它们使用如上所示的提示。
def async_generate_comm_report(self, comm_members: set[str]) -> data_model.CommunityData:
llm = LLMSession(system_message=prompts.COMMUNITY_REPORT_SYSTEM,
model_name="gemini-1.5-flash-001")
response_schema = {
"type": "object",
"properties": {
"title": {
"type": "string"
},
"summary": {
"type": "string"
},
"rating": {
"type": "int"
},
"rating_explanation": {
"type": "string"
},
"findings": {
"type": "array",
"items": {
"type": "object",
"properties": {
"summary": {
"type": "string"
},
"explanation": {
"type": "string"
}
},
# Ensure both fields are present in each finding
"required": ["summary", "explanation"]
}
}
},
# List required fields at the top level
"required": ["title", "summary", "rating", "rating_explanation", "findings"]
}
comm_report = llm.generate(client_query_string=prompts.COMMUNITY_REPORT_QUERY.format(
entities=comm_nodes,
relationships=comm_edges,
response_mime_type="application/json",
response_schema=response_schema
))
comm_data = data_model.CommunityData(title=comm_report_dict["title"], summary=comm_report_dict["summary"], rating=comm_report_dict["rating"], rating_explanation=comm_report_dict["rating_explanation"], findings=comm_report_dict["findings"],
community_nodes=comm_members)
return comm_data
这完成了数据摄取流程。
中间响应的映射步骤
最后,您已进入查询时间。为了生成最终的用户响应,您需要生成一组中间响应(每个社区报告一个)。每个中间响应都将用户查询和一个社区报告作为输入。然后,您按相关性对这些中间查询进行评分。最终,您使用最相关的社区报告以及额外的信息(如相关成员节点的节点描述)作为最终的查询上下文。鉴于大规模社区报告的数量较高,这再次带来了延迟和成本的挑战。和之前一样,您还应将中间响应生成(映射步骤)并行化到无服务器微服务中。在未来,您可以通过增加一个过滤层来显著提高效率,以预先确定社区报告对用户查询的相关性。
def generate_response(client_query: str, community_report: dict):
llm = LLMSession(
system_message=MAP_SYSTEM_PROMPT,
model_name="gemini-1.5-pro-001"
)
response_schema = {
"type": "object",
"properties": {
"response": {
"type": "string",
"description": "The response to the user question as raw string.",
},
"score": {
"type": "number",
"description": "The relevance score of the given community report context towards answering the user question [0.0, 10.0]",
},
},
"required": ["response", "score"],
}
query_prompt = MAP_QUERY_PROMPT.format(
context_community_report=community_report, user_question=client_query)
response = llm.generate(client_query_string=query_prompt,
response_schema=response_schema,
response_mime_type="application/json")
return response
映射步骤微服务使用以下提示:
---Role---
You are an expert agent answering questions based on context that is organized as a knowledge graph.
You will be provided with exactly one community report extracted from that same knowledge graph.
---Goal---
Generate a response consisting of a list of key points that responds to the user's question, summarizing all relevant information in the given community report.
You should use the data provided in the community description below as the only context for generating the response.
If you don't know the answer or if the input community description does not contain sufficient information to provide an answer respond "The user question cannot be answered based on the given community context.".
Your response should always contain following elements:
- Query based response: A comprehensive and truthful response to the given user query, solely based on the provided context.
- Importance Score: An integer score between 0-10 that indicates how important the point is in answering the user's question. An 'I don't know' type of response should have a score of 0.
The response should be JSON formatted as follows:
{{"response": "Description of point 1 [Data: Reports (report ids)]", "score": score_value}}
---Context Community Report---
{context_community_report}
---User Question---
{user_question}
---JSON Response---
The json response formatted as follows:
{{"response": "Description of point 1 [Data: Reports (report ids)]", "score": score_value}}
response:
最终用户响应的归约步骤
为了成功完成归约步骤,您需要存储中间响应以便在查询时访问。在 graphrag-lite 中,我使用 Firestore 作为微服务之间的共享状态。在触发中间响应生成后,客户端还会定期检查共享状态中是否存在所有预期的条目。以下是来自 graphrag-lite 的代码片段,展示了我如何将每个社区报告提交到 PubSub 队列中。之后,我定期查询共享状态,以检查是否所有中间响应都已处理。最后,针对用户的最终响应将使用得分最高的社区报告作为上下文来回应用户查询。
class KGraphGlobalQuery:
def __init__(self) -> None:
# initialized with info on mq, knowledge graph, shared nosql state
pass
@observe()
def __call__(self, user_query: str) -> str:
# orchestration method taking natural language user query to produce and return final answer to client
comm_report_list = self._get_comm_reports()
# pair user query with existing community reports
query_msg_list = self._context_builder(
user_query=user_query, comm_report_list=comm_report_list)
# send pairs to pubsub queue for work scheduling
for msg in query_msg_list:
self._send_to_mq(message=msg)
print("int response request sent to mq")
# periodically query shared state to check for processing compeltion & get intermediate responses
intermediate_response_list = self._check_shared_state(
user_query=user_query)
# based on helpfulness build final context
sorted_final_responses = self._filter_and_sort_responses(intermediate_response_list=intermediate_response_list)
# get full community reports for the selected communities
comm_report_list = self._get_communities_reports(sorted_final_responses)
# generate & return final response based on final context community repors and nodes.
final_response_system = prompts.GLOBAL_SEARCH_REDUCE_SYSTEM.format(
response_type="Detailled and wholistic in academic style analysis of the given information in at least 8-10 sentences across 2-3 paragraphs.")
llm = LLMSession(
system_message=final_response_system,
model_name="gemini-1.5-pro-001"
)
final_query_string = prompts.GLOBAL_SEARCH_REDUCE_QUERY.format(
report_data=comm_report_list,
user_query=user_query
)
final_response = llm.generate(client_query_string=final_query_string)
return final_response
一旦找到所有条目,客户端将触发基于选定社区上下文的最终用户响应生成。
最终思考
图谱 RAG 是一种强大的技术,每个 ML 工程师都应该将其添加到自己的工具箱中。每个问答类型的应用最终都会遇到这样一个问题:纯粹的提取式、“局部”查询再也无法满足需求了。通过 graphrag-lite,你现在拥有了一种轻量级、云原生、无服务器的实现,可以快速复制。
尽管具有这些优势,请注意,在当前状态下,Graph RAG 消耗的 LLM 输入 token 数量明显高于 text2emb RAG。这通常会导致查询和文档索引的延迟和成本显著增加。然而,在体验了结果质量的提升后,我相信在适当的应用场景中,Graph RAG 是值得花时间和金钱的。
RAG 应用最终会朝着混合方向发展。提取式查询可以通过 text2emb RAG 高效且准确地处理。全局抽象查询可能需要知识图谱作为替代的检索层。最后,两种方法在处理定量和分析性查询时表现不佳。因此,第三种 text2sql 检索层将大大增加价值。为了完善这一框架,用户查询最初可以在三种检索方法之间进行分类。这样,每个查询都能以最有效的方式获取适量和深度的信息。
我迫不及待想看看这个技术还会走向何方。你目前在使用哪些替代的检索方法?
图论在模型集成中的协调作用
使用基于图的编排优化多模型协作
·发表于 Towards Data Science ·阅读时间:7 分钟·2024 年 2 月 21 日
--

管弦乐队 — 摄影师 Arindam Mahanta,图片来源于 Unsplash
整合各种 AI 模型的能力释放了潜力的交响乐,从自动化需要多种能力的复杂任务(如视觉、语音、写作和合成)到增强决策过程。然而,协调这些协作在管理内部关系和依赖性方面存在重大挑战。传统的线性方法往往力不从心,难以处理不同模型和动态依赖关系的复杂性。
通过将你的机器学习工作流程转化为图形,你可以可视化每个模型如何相互作用并为结合自然语言处理、计算机视觉和语音模型的整体结果做出贡献。采用图形方法时,节点代表模型或任务,边定义它们之间的依赖关系。这种基于图的映射提供了多个优势,可以识别哪些模型依赖于其他模型的输出,并利用并行处理来执行独立任务。此外,我们还可以根据任务的优先级,使用现有的图形导航策略,如广度优先或深度优先,来执行这些任务。
通向和谐的 AI 模型协作之路并非没有障碍。想象一下,指挥一个管弦乐队,每个乐器演奏者都说不同的语言,乐器之间相互独立。这个挑战反映了在整合不同 AI 模型时的沟通障碍,需要一个框架来管理模型之间的关系,并明确哪些模型可以接收每种输入格式。
从理论到实践:预期的使用案例
基于图的协调方法为各个领域开辟了激动人心的可能性:
药物发现的协同任务

三个模型协作作为数据分析任务的一部分示意图 — 图片来自作者
研究人员可以通过一系列由人工智能驱动的助手加速药物发现过程,每个助手都专门为特定任务设计,例如,采用三步药物发现任务。第一步涉及使用语言模型扫描大量科学数据,突出与特定疾病密切相关的潜在蛋白质靶点,接着是视觉模型来解释复杂的图表或图像,为已识别的蛋白质结构提供详细见解。这个视觉模型对于理解潜在药物如何与蛋白质相互作用至关重要。最后,第三个模型结合语言和视觉模型的输入,预测化学化合物可能如何影响目标蛋白质,为研究人员提供有价值的见解,从而有效推动整个过程。
在模型集成过程中会出现若干挑战,以便交付完整的流程。将从扫描内容中提取相关图像并将其输入视觉模型并不像看起来那么简单。需要在文本扫描和视觉任务之间加入一个中介处理器,以过滤相关图像。其次,分析任务本身应当整合多个输入:数据扫描输出、视觉模型的解释以及用户指定的指令。这需要一个模板来将信息组合在一起,以便语言模型进行处理。以下部分将描述如何利用 Python 框架来处理这些复杂的关系。
创意内容生成

四个任务生成动画的示意图 — 图片来自作者
模型协作可以通过整合音乐创作、动画制作和设计模型等元素来促进互动内容的创作。例如,在基于图的协作方法中,第一个任务可以像导演一样规划一个场景,并传递每个音乐和图像生成任务的输入。最后,动画模型将使用艺术和音乐模型的输出生成一个短视频。
为了优化这个过程,我们旨在实现音乐和图形生成的并行执行,因为它们是独立的任务。因此,不需要等待图形完成才能生成音乐。此外,我们还需要解决动画任务的多样化输入格式问题。虽然一些模型如 Stable Video Diffusion 只处理图像,但音乐可以通过后处理器与之结合。
这些例子仅展示了图论在模型整合中的潜力。图集成方法允许您根据特定需求定制多个任务,开启创新解决方案。
Intelli 框架通过图论协调人工智能模型

任务以图形形式表示 — 图片由作者提供
Intelli 是一个开源 Python 模块,用于协调 AI 工作流,利用图形原则通过三个关键组件实现:
代理 充当你的 AI 模型的代表,你通过指定其类型(文本、图像、视觉或语音)、提供者(openai、gemini、stability、mistral 等)和任务来定义每个代理。
任务 是 AI 工作流中的单个单元。每个任务利用代理执行特定操作,并应用用户提供的自定义预处理和后处理。
Flow 将一切连接在一起,协调任务的执行,遵循你通过图形结构建立的依赖关系。流管理确保任务高效且按正确顺序执行,能够在可能的情况下支持顺序和并行处理。
使用流组件将任务关系管理为图形在连接多个模型时提供了几个好处,但对于仅有一个任务的情况,这可能有些过度,直接调用模型就足够了。
扩展性:随着项目复杂性的增长,增加更多模型和任务需要反复更新代码,以处理数据格式不匹配和复杂的依赖关系。图形方法通过定义一个新的节点来表示任务,简化了这一过程,框架会自动解决输入/输出差异,从而协调数据流。
动态适应:在传统方法中,复杂任务的变化会影响整个工作流,需要进行调整。而使用流时,它会自动处理添加、删除或修改连接。
可解释性:图形通过可视化模型如何交互并优化任务路径导航,增强了你对 AI 工作流的深入理解。
注意:作者参与了 Intelli 框架的设计和开发。它是一个具有 Apache 许可证的开源项目。
开始使用
首先,确保你使用的是 Python 3.7 及以上版本,因为 Intelli 利用了最新的 Python asyncio 特性,然后安装:
pip install intelli
代理:任务执行者
Intelli 中的代理旨在与特定的 AI 模型接口。每个代理都包括一个统一的输入层,可以访问任何类型的模型,并提供一个字典,允许传递自定义参数,如最大大小、温度和模型版本。
from intelli.flow.agents.agent import Agent
# Define agents for various AI tasks
text_agent = Agent(
agent_type="text",
provider="openai",
mission="write social media posts",
model_params={"key": OPENAI_API_KEY, "model": "gpt-4"}
)
任务:构建块
任务表示由代理执行的单个工作单元或操作,并包含处理前一个任务输出的逻辑。每个任务可以是一个简单的操作,如生成文本,或是一个更复杂的过程,如分析用户反馈的情感。
from intelli.flow.tasks.task import Task
from intelli.flow.input.task_input import TextTaskInput
# Define a task for text generation
task1 = Task(
TextTaskInput("Create a post about AI technologies"),
text_agent,
log=True
)
处理器:优化的输入/输出
处理器通过定义任务输入的自定义预处理和输出的后处理,提供了额外的控制层。下面的示例演示了在调用图像模型之前,创建一个函数来缩短上一步骤的文本输出。
class TextProcessor:
@staticmethod
def text_head(text, size=800):
retupytrn text[:size]
task2 = Task(
TextTaskInput("Generate image about the content"),
image_agent,
pre_process=TextProcessor.text_head,
log=True,
)
Flow: 指定依赖关系
Flow 将您的 AI 工作流转换为有向无环图(DAG),并利用图论进行依赖管理。这使得您可以轻松地可视化任务之间的关系,并优化任务的执行顺序。
from intelli.flow.flow import Flow
flow = Flow(
tasks={
"title_task": title_task,
"content_task": content_task,
"keyword_task": keyword_task,
"theme_task": description_theme_task,
"image_task": image_task,
},
map_paths={
"title_task": ["keyword_task", "content_task"],
"content_task": ["theme_task"],
"theme_task": ["image_task"],
},
)
output = await flow.start()
map_paths 决定了任务的依赖关系,引导 Flow 编排执行顺序,确保每个任务从其前置任务获得必要的输出。
以下是 Flow 如何遍历节点的方式:
映射工作流:Flow 使用任务作为节点、依赖关系作为边构建 DAG。这种可视化表示清晰地显示了任务执行顺序和数据流动。
拓扑排序:Flow 分析图形,以确定最佳执行顺序。没有输入依赖的任务优先执行,确保每个任务在执行之前能够从前置任务中获得必要的输入。
任务执行:框架遍历排序后的任务,逐个执行每个任务并提供相应的输入。根据依赖图,输入可能来自之前任务的输出或用户定义的值。
输入准备:在执行之前,任务应用为任务定义的任何预处理功能,根据需要修改输入数据,并调用分配的代理。
输出管理:代理返回输出,并将其存储在以任务名称为键的字典中,然后返回给用户。
要将您的工作流可视化为图形:
flow.generate_graph_img()

任务和分配代理的可视化 —— 由 Intelli 图形功能生成的图像
结论
使用图论已经改变了传统的线性 AI 模型编排方式,通过提供不同模型之间协作的交响乐式方法。
像 Intelli 这样的框架将您的工作流转化为可视化表示,其中任务成为节点,依赖关系被映射为边,从而创建一个完整的流程概览,用于自动化复杂的任务。
这种方法扩展到需要协作 AI 模型的各个领域,包括科学研究、商业决策自动化和互动内容创作。然而,有效的规模化需要在管理模型之间的数据交换方面进一步改进。
参考文献
Intelli git 仓库:
github.com/intelligentnode/Intelli
图形可视化:从简单到高级的 7 个步骤
使用 Python、NetworkX 和 D3.JS 制作可视化
·发表于Towards Data Science ·阅读时间 9 分钟 ·2024 年 6 月 15 日
--

Davis 的南方俱乐部图,图像来源于作者
有些数据类型,如社交网络或知识图谱,可以“本地”以图形形式表示。可视化这类数据可能具有挑战性,而且没有通用的解决方案。在本文中,我将展示使用开源NetworkX库进行图形可视化的几个步骤。
让我们开始吧!
基本示例
如果我们想在 Python 中使用图形,NetworkX可能是最受欢迎的选择。它是一个用于网络分析的开源 Python 包,包含了各种算法和强大的功能。如我们所知,每个图形包含节点(顶点)和边;我们可以轻松地在 NetworkX 中创建一个简单的图形:
import networkx as nx
G = nx.Graph()
G.add_node("A")
G.add_node("B")
G.add_edge("A", "B")
...
然而,以这种方式创建一个大型图形可能会很繁琐,本文中,我将使用 NetworkX 库中包含的“Davis 的南方俱乐部妇女”图(3 条款 BSD 许可证)。这些数据由 A. Davis 等人于 1930 年代收集(A. Davis,1941,Deep South,芝加哥:芝加哥大学出版社)。它代表了 18 位南方女性参加 14 个社交活动的观察数据。让我们加载……
Graphcast:如何完成任务
本文介绍了如何使用谷歌最新的工具进行预测,从获取数据到格式化等等。
·发布于 Towards Data Science ·14 分钟阅读·2024 年 1 月 29 日
--

天气预测是一个非常复杂的问题。数值天气预报(NWP)模型、天气研究与预报(WRF)模型已被用来解决这个问题,但其准确性和精确性有时仍然不足。
由于其复杂性,这个问题吸引了从数据科学家到数据科学爱好者,再到气象工程师的广泛关注和解决方案的追求。虽然已经找到了解决方案,但一致性和统一性依然缺乏。解决方案因地区、山脉、高原、沼泽到苔原而异。根据我个人的经验,我相信其他人的经历也是如此,天气预测被发现是一个难以攻克的难题。引用一位虾类亿万富翁的话:
这就像一盒巧克力,你永远不知道自己会得到什么。
最近,Deepmind 发布了一款新工具:Graphcast,一个用于更快、更准确的全球天气预报的 AI 模型,试图让这包巧克力变得更加美味和高效。在谷歌 TPU v4 机器上,使用 Graphcast,可以在不到一分钟的时间内,以 0.25 度的空间分辨率获取预测。它解决了使用传统方法预测时可能面临的许多问题:
预测结果是针对所有坐标一次性生成的,
根据坐标编辑逻辑现在变得多余。
令人费解的高效性和响应时间。
不那么令人费解的是使用上述工具获取预测所需的数据准备。

然而,不用担心,我将成为你黑暗且阴郁盔甲下的骑士,在本文中解释准备和格式化数据所需的步骤,最后使用 Graphcast 获取预测。
注意:如今使用“AI”这个词让我非常想起在漫威电影中使用“量子”一词的方式。
获取预测是一个过程,可以分为以下几个部分:
获取输入数据。
创建目标。
创建强迫数据。
处理并将数据格式化为合适的格式。
将它们整合在一起并进行预测。
Graphcast 表示,使用当前的天气数据和 6 小时前的数据,可以预测未来 6 小时的情况。为了简单起见,举个例子:
如果需要预测:2024–01–01 18:00,
然后输入数据:2024–01–01 12:00 和 2024–01–01 06:00。
需要注意的是,2024–01–01 18:00 将是第一个被获取的预测。Graphcast 还可以额外获取未来 10 天的数据,每次预测之间间隔 6 小时。因此,其他可以获取预测的时间戳为:
2024–01–02 00:00,06:00,12:00,18:00,
2024–01–03 00:00,06:00,类似的时间戳一直到,
2024–01–10 06:00,12:00。
总结来说,可以通过两个时间戳的输入 预测 40 个时间戳的数据。
假设和重要参数
在本文中,我将展示的代码中,我为某些参数指定了以下值,这些参数决定了你获得预测的速度以及使用的内存。
输入时间戳:2024–01–01 6:00,12:00。
第一次预测时间戳:2024–01–01 18:00。
预测数量:4。
空间分辨率:1 度。
压力水平:13。
以下是导入所需包、初始化用于输入和预测的字段数组以及其他有用变量的代码。
import cdsapi
import datetime
import functools
from graphcast import autoregressive, casting, checkpoint, data_utils as du, graphcast, normalization, rollout
import haiku as hk
import isodate
import jax
import math
import numpy as np
import pandas as pd
from pysolar.radiation import get_radiation_direct
from pysolar.solar import get_altitude
import pytz
import scipy
from typing import Dict
import xarray
client = cdsapi.Client() # Making a connection to CDS, to fetch data.
# The fields to be fetched from the single-level source.
singlelevelfields = [
'10m_u_component_of_wind',
'10m_v_component_of_wind',
'2m_temperature',
'geopotential',
'land_sea_mask',
'mean_sea_level_pressure',
'toa_incident_solar_radiation',
'total_precipitation'
]
# The fields to be fetched from the pressure-level source.
pressurelevelfields = [
'u_component_of_wind',
'v_component_of_wind',
'geopotential',
'specific_humidity',
'temperature',
'vertical_velocity'
]
# The 13 pressure levels.
pressure_levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]
# Initializing other required constants.
pi = math.pi
gap = 6 # There is a gap of 6 hours between each graphcast prediction.
predictions_steps = 4 # Predicting for 4 timestamps.
watts_to_joules = 3600
first_prediction = datetime.datetime(2024, 1, 1, 18, 0) # Timestamp of the first prediction.
lat_range = range(-180, 181, 1) # Latitude range.
lon_range = range(0, 360, 1) # Longitude range.
# A utility function used for ease of coding.
# Converting the variable to a datetime object.
def toDatetime(dt) -> datetime.datetime:
if isinstance(dt, datetime.date) and isinstance(dt, datetime.datetime):
return dt
elif isinstance(dt, datetime.date) and not isinstance(dt, datetime.datetime):
return datetime.datetime.combine(dt, datetime.datetime.min.time())
elif isinstance(dt, str):
if 'T' in dt:
return isodate.parse_datetime(dt)
else:
return datetime.datetime.combine(isodate.parse_date(dt), datetime.datetime.min.time())
输入
在机器学习中,为了获取一些预测,你必须给机器学习模型一些数据,它会基于这些数据给出预测。例如,在预测一个人是否是蝙蝠侠时,输入数据可能是:
他们每晚睡多少小时?
他们的脸上有晒痕吗?
他们在早晨的会议上会打瞌睡吗?
他们的净资产是多少?
同样,Graphcast 也需要某些输入,这些输入通过CDS获得,使用其 Python 库:cdsapi。目前,数据发布者使用创作共用 4.0 国际许可协议,这意味着任何人都可以复制、分发、传输和修改该作品,只要注明原作者。
然而,在使用 cdsapi 获取数据之前,需要进行身份验证,具体说明见此,CDS 提供了相关说明,且过程相对简单。
假设你现在已经获得 CDS 认证,可以创建输入数据,具体步骤如下:
获取单层值:这些值依赖于坐标和时间。所需的输入字段之一是total_precipitation_6hr。顾名思义,这是指从该特定时间戳开始的过去 6 小时的降水量。因此,我们不仅仅获取两个输入时间戳的值,而是需要获取从2024-01-01 00:00 到 12:00的时间戳数据。
获取压力层值:除了依赖于坐标外,还依赖于压力层。因此,在请求数据时,我们会指定所需的压力层数据。在这种情况下,我们只会获取两个输入时间戳的值。
合并单层和压力值:在上述数据上,基于时间、纬度和经度进行内连接操作。
整合年份和天数进度:除了单层和压力字段外,还需要向输入数据中添加四个字段:year_progress_sin、year_progress_cos、day_progress_sin和day_progress_cos。这可以通过graphcast包中提供的函数来实现。
其他小步骤包括:
从 CDS 获取数据后重命名列,因为 CDS 输出的是天气变量的简化形式。
将geopotential变量重命名为geopotential_at_surface,用于单层数据,因为压力层有相同的字段名。
使用math库中的函数,在从 graphcast 获得progress值后,计算 sin 和 cos 值。
将latitude重命名为lat,将longitude重命名为lon,并引入另一个索引:batch,其值为 0。
创建输入数据的代码如下。
# Getting the single and pressure level values.
def getSingleAndPressureValues():
client.retrieve(
'reanalysis-era5-single-levels',
{
'product_type': 'reanalysis',
'variable': singlelevelfields,
'grid': '1.0/1.0',
'year': [2024],
'month': [1],
'day': [1],
'time': ['00:00', '01:00', '02:00', '03:00', '04:00', '05:00', '06:00', '07:00', '08:00', '09:00', '10:00', '11:00', '12:00'],
'format': 'netcdf'
},
'single-level.nc'
)
singlelevel = xarray.open_dataset('single-level.nc', engine = scipy.__name__).to_dataframe()
singlelevel = singlelevel.rename(columns = {col:singlelevelfields[ind] for ind, col in enumerate(singlelevel.columns.values.tolist())})
singlelevel = singlelevel.rename(columns = {'geopotential': 'geopotential_at_surface'})
# Calculating the sum of the last 6 hours of rainfall.
singlelevel = singlelevel.sort_index()
singlelevel['total_precipitation_6hr'] = singlelevel.groupby(level=[0, 1])['total_precipitation'].rolling(window = 6, min_periods = 1).sum().reset_index(level=[0, 1], drop=True)
singlelevel.pop('total_precipitation')
client.retrieve(
'reanalysis-era5-pressure-levels',
{
'product_type': 'reanalysis',
'variable': pressurelevelfields,
'grid': '1.0/1.0',
'year': [2024],
'month': [1],
'day': [1],
'time': ['06:00', '12:00'],
'pressure_level': pressure_levels,
'format': 'netcdf'
},
'pressure-level.nc'
)
pressurelevel = xarray.open_dataset('pressure-level.nc', engine = scipy.__name__).to_dataframe()
pressurelevel = pressurelevel.rename(columns = {col:pressurelevelfields[ind] for ind, col in enumerate(pressurelevel.columns.values.tolist())})
return singlelevel, pressurelevel
# Adding sin and cos of the year progress.
def addYearProgress(secs, data):
progress = du.get_year_progress(secs)
data['year_progress_sin'] = math.sin(2 * pi * progress)
data['year_progress_cos'] = math.cos(2 * pi * progress)
return data
# Adding sin and cos of the day progress.
def addDayProgress(secs, lon:str, data:pd.DataFrame):
lons = data.index.get_level_values(lon).unique()
progress:np.ndarray = du.get_day_progress(secs, np.array(lons))
prxlon = {lon:prog for lon, prog in list(zip(list(lons), progress.tolist()))}
data['day_progress_sin'] = data.index.get_level_values(lon).map(lambda x: math.sin(2 * pi * prxlon[x]))
data['day_progress_cos'] = data.index.get_level_values(lon).map(lambda x: math.cos(2 * pi * prxlon[x]))
return data
# Adding day and year progress.
def integrateProgress(data:pd.DataFrame):
for dt in data.index.get_level_values('time').unique():
seconds_since_epoch = toDatetime(dt).timestamp()
data = addYearProgress(seconds_since_epoch, data)
data = addDayProgress(seconds_since_epoch, 'longitude' if 'longitude' in data.index.names else 'lon', data)
return data
# Adding batch field and renaming some others.
def formatData(data:pd.DataFrame) -> pd.DataFrame:
data = data.rename_axis(index = {'latitude': 'lat', 'longitude': 'lon'})
if 'batch' not in data.index.names:
data['batch'] = 0
data = data.set_index('batch', append = True)
return data
if __name__ == '__main__':
values:Dict[str, xarray.Dataset] = {}
single, pressure = getSingleAndPressureValues()
values['inputs'] = pd.merge(pressure, single, left_index = True, right_index = True, how = 'inner')
values['inputs'] = integrateProgress(values['inputs'])
values['inputs'] = formatData(values['inputs'])
Targets
有 11 个预测字段:
u_component_of_wind,
v_component_of_wind,
geopotential,
specific_humidity,
temperature,
vertical_velocity,
10m_u_component_of_wind,
10m_v_component_of_wind,
2m_temperature,
mean_sea_level_pressure,
total_precipitation.

由 Ricardo Arce 拍摄,图片来自 Unsplash
传递的目标本质上是一个空的 xarray,用于所有预测字段:
每个 coordinate,
prediction timestamps 和
pressure level。
实现这一功能的代码如下所示。
# Includes the packages imported and constants assigned.
# The functions created for the inputs also go here.
predictionFields = [
'u_component_of_wind',
'v_component_of_wind',
'geopotential',
'specific_humidity',
'temperature',
'vertical_velocity',
'10m_u_component_of_wind',
'10m_v_component_of_wind',
'2m_temperature',
'mean_sea_level_pressure',
'total_precipitation_6hr'
]
# Creating an array full of nan values.
def nans(*args) -> list:
return np.full((args), np.nan)
# Adding or subtracting time.
def deltaTime(dt, **delta) -> datetime.datetime:
return dt + datetime.timedelta(**delta)
def getTargets(dt, data:pd.DataFrame):
# Creating an array consisting of unique values of each index.
lat, lon, levels, batch = sorted(data.index.get_level_values('lat').unique().tolist()), sorted(data.index.get_level_values('lon').unique().tolist()), sorted(data.index.get_level_values('level').unique().tolist()), data.index.get_level_values('batch').unique().tolist()
time = [deltaTime(dt, hours = days * gap) for days in range(4)]
# Creating an empty dataset using latitude, longitude, the pressure levels and each prediction timestamp.
target = xarray.Dataset({field: (['lat', 'lon', 'level', 'time'], nans(len(lat), len(lon), len(levels), len(time))) for field in predictionFields}, coords = {'lat': lat, 'lon': lon, 'level': levels, 'time': time, 'batch': batch})
return target.to_dataframe()
if __name__ == '__main__':
# The code for creating inputs will be here.
values['targets'] = getTargets(first_prediction, values['inputs'])
强迫因子
正如 targets 的情况一样,forcings 中也包含每个坐标和预测时间戳的值,但不包括压力层级。forcings 中的字段包括:
total_incident_solar_radiation,
year_progress_sin,
year_progress_cos,
day_progress_sin,
day_progress_cos。
需要注意的是,以上值是 相对于预测时间戳 进行分配的。正如在处理 inputs 时的情况,year 和 day progress 仅依赖于时间戳,而 solar radiation 来自单层数据源。然而,由于我们正在进行预测,即获取未来的值,对于 forcings,太阳辐射值在 CDS 数据集中是不可用的。为此,我们使用 pysolar 库来模拟太阳辐射值。
# Includes the packages imported and constants assigned.
# The functions created for the inputs and targets also go here.
# Adding a timezone to datetime.datetime variables.
def addTimezone(dt, tz = pytz.UTC) -> datetime.datetime:
dt = toDatetime(dt)
if dt.tzinfo == None:
return pytz.UTC.localize(dt).astimezone(tz)
else:
return dt.astimezone(tz)
# Getting the solar radiation value wrt longitude, latitude and timestamp.
def getSolarRadiation(longitude, latitude, dt):
altitude_degrees = get_altitude(latitude, longitude, addTimezone(dt))
solar_radiation = get_radiation_direct(dt, altitude_degrees) if altitude_degrees > 0 else 0
return solar_radiation * watts_to_joules
# Calculating the solar radiation values for timestamps to be predicted.
def integrateSolarRadiation(data:pd.DataFrame):
dates = list(data.index.get_level_values('time').unique())
coords = [[lat, lon] for lat in lat_range for lon in lon_range]
values = []
# For each data, getting the solar radiation value at a particular coordinate.
for dt in dates:
values.extend(list(map(lambda coord:{'time': dt, 'lon': coord[1], 'lat': coord[0], 'toa_incident_solar_radiation': getSolarRadiation(coord[1], coord[0], dt)}, coords)))
# Setting indices.
values = pd.DataFrame(values).set_index(keys = ['lat', 'lon', 'time'])
# The forcings dataset will now contain the solar radiation values.
return pd.merge(data, values, left_index = True, right_index = True, how = 'inner')
def getForcings(data:pd.DataFrame):
# Since forcings data does not contain batch as an index, it is dropped.
# So are all the columns, since forcings data only has 5, which will be created.
forcingdf = data.reset_index(level = 'level', drop = True).drop(labels = predictionFields, axis = 1)
# Keeping only the unique indices.
forcingdf = pd.DataFrame(index = forcingdf.index.drop_duplicates(keep = 'first'))
# Adding the sin and cos of day and year progress.
# Functions are included in the creation of inputs data section.
forcingdf = integrateProgress(forcingdf)
# Integrating the solar radiation values.
forcingdf = integrateSolarRadiation(forcingdf)
return forcingdf
if __name__ == '__main__':
# The code for creating inputs and targets will be here.
values['forcings'] = getForcings(values['targets'])
后处理输入、目标和强迫因子
现在三大支柱 Graphcast 已经创建完成,我们进入了最后冲刺阶段。就像 NBA 总决赛中,已经赢得了 3 场比赛,现在我们进入最关键的部分,完成任务。
就像科比·布莱恩特曾经说过的,
工作还没有完成。

提到 xarray 时,数据主要有两种类型:
坐标,索引:lat、lon、time……以及
数据变量,列:land_sea_mask、geopotential 等等。
每个数据变量包含的每个值,都有一定的坐标分配给它。坐标是数据变量值所依赖的那些坐标。例如,在我们自己的数据中,
land_sea_mask 完全依赖于 latitude 和 longitude,即其坐标。
geopotential 的坐标是 batch、latitude、longitude、time 和 pressure level。
与此形成鲜明对比,但却合乎逻辑的是,geopotential_at_surface 的坐标是 latitude 和 longitude。
因此,在我们继续进行天气预测之前,我们确保每个数据变量都分配到正确的坐标,相关的代码如下所示。
# Includes the packages imported and constants assigned.
# The functions created for the inputs, targets and forcings also go here.
# A dictionary created, containing each coordinate a data variable requires.
class AssignCoordinates:
coordinates = {
'2m_temperature': ['batch', 'lon', 'lat', 'time'],
'mean_sea_level_pressure': ['batch', 'lon', 'lat', 'time'],
'10m_v_component_of_wind': ['batch', 'lon', 'lat', 'time'],
'10m_u_component_of_wind': ['batch', 'lon', 'lat', 'time'],
'total_precipitation_6hr': ['batch', 'lon', 'lat', 'time'],
'temperature': ['batch', 'lon', 'lat', 'level', 'time'],
'geopotential': ['batch', 'lon', 'lat', 'level', 'time'],
'u_component_of_wind': ['batch', 'lon', 'lat', 'level', 'time'],
'v_component_of_wind': ['batch', 'lon', 'lat', 'level', 'time'],
'vertical_velocity': ['batch', 'lon', 'lat', 'level', 'time'],
'specific_humidity': ['batch', 'lon', 'lat', 'level', 'time'],
'toa_incident_solar_radiation': ['batch', 'lon', 'lat', 'time'],
'year_progress_cos': ['batch', 'time'],
'year_progress_sin': ['batch', 'time'],
'day_progress_cos': ['batch', 'lon', 'time'],
'day_progress_sin': ['batch', 'lon', 'time'],
'geopotential_at_surface': ['lon', 'lat'],
'land_sea_mask': ['lon', 'lat'],
}
def modifyCoordinates(data:xarray.Dataset):
# Parsing through each data variable and removing unneeded indices.
for var in list(data.data_vars):
varArray:xarray.DataArray = data[var]
nonIndices = list(set(list(varArray.coords)).difference(set(AssignCoordinates.coordinates[var])))
data[var] = varArray.isel(**{coord: 0 for coord in nonIndices})
data = data.drop_vars('batch')
return data
def makeXarray(data:pd.DataFrame) -> xarray.Dataset:
# Converting to xarray.
data = data.to_xarray()
data = modifyCoordinates(data)
return data
if __name__ == '__main__':
# The code for creating inputs, targets and forcings will be here.
values = {value:makeXarray(values[value]) for value in values}
使用 Graphcast 进行预测
计算、处理和组装好 inputs、targets 和 forcings 后,接下来就是进行 predictions 的时候了。
现在我们需要模型权重和归一化统计文件,这些文件是由 Deepmind 提供的。
需要下载的文件包括:
stats/diffs_stddev_by_level.nc,
stats/stddev_by_level.nc,
stats/mean_by_level.nc 和
params/GraphCast_small — ERA5 1979–2015 — 分辨率 1.0 — 压力层级 13 — 网格 2to5 — 降水输入和输出.npz。
下方显示了上述文件相对于预测文件的相对路径。保持结构的重要性在于能够成功导入并读取所需的文件。
.
├── prediction.py
├── model
├── params
├── GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz
├── stats
├── diffs_stddev_by_level.nc
├── mean_by_level.nc
├── stddev_by_level.nc
使用Deepmind 提供的预测代码,上述所有功能最终通过下面的代码片段完成预测。
# Includes the packages imported and constants assigned.
# The functions created for the inputs, targets and forcings also go here.
with open(r'model/params/GraphCast_small - ERA5 1979-2015 - resolution 1.0 - pressure levels 13 - mesh 2to5 - precipitation input and output.npz', 'rb') as model:
ckpt = checkpoint.load(model, graphcast.CheckPoint)
params = ckpt.params
state = {}
model_config = ckpt.model_config
task_config = ckpt.task_config
with open(r'model/stats/diffs_stddev_by_level.nc', 'rb') as f:
diffs_stddev_by_level = xarray.load_dataset(f).compute()
with open(r'model/stats/mean_by_level.nc', 'rb') as f:
mean_by_level = xarray.load_dataset(f).compute()
with open(r'model/stats/stddev_by_level.nc', 'rb') as f:
stddev_by_level = xarray.load_dataset(f).compute()
def construct_wrapped_graphcast(model_config:graphcast.ModelConfig, task_config:graphcast.TaskConfig):
predictor = graphcast.GraphCast(model_config, task_config)
predictor = casting.Bfloat16Cast(predictor)
predictor = normalization.InputsAndResiduals(predictor, diffs_stddev_by_level = diffs_stddev_by_level, mean_by_level = mean_by_level, stddev_by_level = stddev_by_level)
predictor = autoregressive.Predictor(predictor, gradient_checkpointing = True)
return predictor
@hk.transform_with_state
def run_forward(model_config, task_config, inputs, targets_template, forcings):
predictor = construct_wrapped_graphcast(model_config, task_config)
return predictor(inputs, targets_template = targets_template, forcings = forcings)
def with_configs(fn):
return functools.partial(fn, model_config = model_config, task_config = task_config)
def with_params(fn):
return functools.partial(fn, params = params, state = state)
def drop_state(fn):
return lambda **kw: fn(**kw)[0]
run_forward_jitted = drop_state(with_params(jax.jit(with_configs(run_forward.apply))))
class Predictor:
@classmethod
def predict(cls, inputs, targets, forcings) -> xarray.Dataset:
predictions = rollout.chunked_prediction(run_forward_jitted, rng = jax.random.PRNGKey(0), inputs = inputs, targets_template = targets, forcings = forcings)
return predictions
if __name__ == '__main__':
# The code for creating inputs, targets, forcings & processing will be here.
predictions = Predictor.predict(values['inputs'], values['targets'], values['forcings'])
predictions.to_dataframe().to_csv('predictions.csv', sep = ',')
结论
上面,我提供了每个将要执行的过程的代码:
创建输入、目标和强迫项,
将上述数据处理成可用格式,最终
将它们汇集在一起并进行预测。
在执行过程中,将所有过程整合在一起,以实现无缝的实施是非常重要的。
为了简单起见,我上传了代码,以及 Docker 镜像和容器文件,可以用来创建一个环境来执行预测程序。
在天气预测领域,目前我们有像 Accuweather、IBM 以及多个 meteomatics 模型等贡献者。Graphcast 证明是这一系列中一个有趣且在许多情况下更高效的补充。然而,它也有一些远未达到最佳状态的属性。在某个难得的思考时刻,我总结出以下几点见解:
与其他天气预测服务相比,Graphcast 更高效且速度更快,能够在几分钟内为整个世界提供预测。
这使得通过 API 进行成百上千次地理位置调用变得多余。
然而,要在几分钟内完成上述操作,必须拥有非常强大的机器,至少是 Google TPU v4 或更好的机器。这种机器并不容易获取。即便选择使用 AWS、Google 或 Azure 的虚拟机,费用也会迅速累积。
目前没有办法使用小范围的地理数据或坐标子集并获得相应的预测。始终需要所有坐标的数据。
CDS 提供的数据有五天的延迟期,这意味着在“x”日期,CDS 只能提供到“x-5”日期的数据。这使得未来天气预测变得有些复杂,因为在进行未来预测之前,必须先覆盖延迟期。
需要注意的是,Graphcast 是天气预测领域中相对较新的工具,未来肯定会进行更改和新增功能,以提高易用性和可访问性。考虑到它在效率和性能方面的领先优势,Graphcast 一定会加以利用。
资源:
祝你在数据科学之旅中好运,感谢阅读 😃
GraphMuse:一个用于符号音乐图形处理的 Python 库
是的,音乐和图形确实能结合!
·发表于Towards Data Science ·11 分钟阅读·2024 年 10 月 17 日
--

图片由 Dall-E 3 生成
在这篇文章中,我们将介绍我最近的一篇论文和开源软件:GraphMuse Python 库。
但在深入讨论之前,让我向你介绍一些符号音乐处理的基础知识。
故事继续...
符号音乐处理主要是指从乐谱中提取信息。符号一词指的是任何形式的乐谱或记谱法中所包含的符号。乐谱可以包含多种元素,除了音符外,还可能包括拍号、调号、演奏标记、力度标记等。乐谱可以存在于多种格式中,如 MIDI、MusicXML、MEI、Kern、ABC 等。
近年来,图神经网络(GNNs)变得越来越流行,并在许多领域取得了成功,从生物网络到推荐系统,再到音乐分析。在音乐分析领域,GNNs 已被用于解决和声分析、乐句分割和声部分离等任务。
这个想法很简单:乐谱中的每个音符是图中的一个顶点,边则由音符之间的时间关系定义,如下图所示。

这些边分为 4 类:
同时开始的音符通过“起始”边(蓝色)连接
在另一个音符的开始或结束处开始的音符通过“连续”边(红色)连接
在另一个音符的开始和结束之间开始的音符通过“期间”边(绿色)连接
最后,每当有静音时,所有的最后音符结尾都会通过“静音”边(黄色)连接到接下来的第一个音符。
这种最小化的图建模保证了乐谱从头到尾始终连通,不会出现任何断开的子图。
什么是 GraphMuse
GraphMuse 是一个用于训练和应用深度图模型进行音乐分析的 Python 库,专门针对乐谱。
GraphMuse 包含用于符号音乐处理的加载器、模型和工具,基于图神经网络(GNNs)。它建立在 PyTorch 和 PyTorch Geometric 之上,提供更多的灵活性和互操作性。
PyTorch 是一个开源的机器学习库,可以高效地构建深度学习模型,并支持 GPU 加速。PyTorch Geometric 是一个基于 PyTorch 的库,旨在轻松编写和训练图神经网络(GNNs),并应用于广泛的领域。
最后,GraphMuse 提供了将乐谱转化为图形的功能。图形创建通过 C 语言实现,并提供 Python 绑定,以加速图形构建,比之前基于 numpy 的实现速度提高了最多 300 倍。
科学基础
图形已经被频繁用于分析和表示音乐。举几个例子,Tonnetz、Schenkerian 分析以及树形结构分析是一些著名的例子。图形的优势在于,它们可以通过边的设计,使用相同的表示法同时捕捉音乐的层次性和顺序性。
基于图形的符号音乐处理使用 GNNs 是在 2021 年随着从乐谱生成表现模型而出现的。从那时起,许多图形模型相继问世,其中一些已经成为音乐分析任务的最先进方法,直到本文发布为止。
那么,既然我已经阐明了图形的必要性,让我们面对设计和训练符号音乐图模型的复杂性。
图形和音乐的主要复杂性在于,音乐作品的长度并不总是相同的,而由它们生成的图形大小也不相同。它们的大小可能会有显著差异:例如,一首巴赫的圣歌可能只有 200 个音符,而一首贝多芬的奏鸣曲可能有超过 5000 个音符。在我们的图形中,音符的数量与每个乐谱图中顶点的数量直接对应。
在乐谱图上高效快速地训练并非易事,它需要一种采样方法,可以最大化计算资源的使用,既包括内存,也包括时间,同时不降低模型的性能,有时甚至能提高性能。
在训练过程中,采样涉及将来自不同得分的图组合成一个新图,计算机科学中通常称之为“批次”。然后将每个批次输入到 GNN 模型中,计算损失。这个损失用于反向传播并更新模型的参数。这一单次迭代称为一个训练步骤。为了优化模型,这个过程会重复多次,直到训练收敛,理想情况下模型达到最佳性能。
这一切听起来很复杂,但不要绝望,因为 GraphMuse 可以为你处理这一部分!!
GraphMuse 的内部工作原理

GraphMuse 中符号化音乐得分的图处理/训练管道一般包括以下步骤:
预处理得分数据库以生成输入图,GraphMuse 可以为你快速且轻松地完成这一步;
对输入图进行采样以创建内存高效的批次,再次强调,GraphMuse 会为你处理这一切;
将多个采样输入图中的节点和边组合形成一个新图作为一个批次;对于每个图,选择一组节点,我们称之为目标节点。目标节点的邻居也可以按需获取,这一过程叫做按节点采样。
通过图卷积更新目标节点的表示,以创建节点嵌入。GraphMuse 提供了一些可以使用的模型,另外 PyTorch Geometric 也可以是你的好帮手;
将这些嵌入用于特定任务的应用。这个部分由你来处理,但我相信你一定能做得很好!
请注意,目标节点可能包括所有或部分批次节点,具体取决于采样策略。
现在,既然这个过程已经图示化解释了,我们来更仔细地看看 GraphMuse 如何处理从每个得分中采样音符。

上图:采样音符及其邻居;中图:得分图和采样过程;下图:节拍和小节的采样过程。
每个得分的采样过程。
首先随机选择一个音符(黄色)进行采样。
然后计算目标音符的边界,在这个示例中,预算为 15 个音符(粉色和黄色音符)。
然后为目标节点获取 k-hop 邻居(浅蓝色为 1-hop,深蓝色为 2-hop)。k-hop 邻居是相对于输入图计算的(如上图所示,颜色边连接音符头)。
我们还可以扩展采样过程,以包括节拍和小节元素。请注意,k-hop 邻居不一定与时间窗口严格相关。
为了最大化计算资源(即内存),上述过程会对多个得分同时重复,以创建一个批次。使用这个过程,GraphMuse 确保每个采样段都将具有相同大小的目标音符。每个采样段可以组合成一个新的图,其大小最多为*#_scores* x #_target_notes。这个新图构成当前训练步骤的批次。
使用 GraphMuse 的实践
在动手操作部分,让我们尝试使用 GraphMuse 并应用一个音高拼写模型。音高拼写任务是指当乐谱中缺少音符名称和变音符时推断出它们的正确拼写。一个应用实例是,当我们拥有量化的 MIDI 文件并希望生成像下图中的乐谱时:

MIDI 文件是输入(顶部),而音乐乐谱是期望的输出(底部)。
在安装 GraphMuse 之前,你需要先安装 PyTorch 和 PyTorch Geometric。请查看适合你系统的版本,链接在这里和这里。
完成此步骤后,打开你喜欢的终端并输入以下命令来安装 GraphMuse:
pip install graphmuse
安装完成后,让我们从 URL 读取一个 MIDI 文件并使用 GraphMuse 创建乐谱图。
import graphmuse as gm
midi_url_raw = "https://github.com/CPJKU/partitura/raw/refs/heads/main/tests/data/midi/bach_midi_score.mid"
graph = gm.load_midi_to_graph(midi_url_raw)
底层过程使用 Partitura 读取文件,然后将其传递给 GraphMuse。
为了训练我们的模型处理音高拼写,我们首先需要一个已标注音高拼写的音乐乐谱数据集。为此,我们将使用 ASAP 数据集(遵循CC BY-NC-SA 4.0许可协议),它将作为我们模型学习的基础。要获取 ASAP 数据集,你可以通过 git 下载或直接从 github 下载:
git clone https://github.com/cpjku/asap-dataset.git
ASAP 数据集包括各种古典钢琴曲目的乐谱和演奏数据。对于我们的用例,我们将仅使用以.musicxml结尾的乐谱文件。
在加载此数据集时,我们需要两个必要的工具:一个用于编码音高拼写,另一个用于处理调号信息,这两个工具将被转换为数字标签。幸运的是,这些工具可以在 GraphMuse 中预构建的音高拼写模型中找到。让我们开始导入所有必需的包,并加载第一个乐谱文件以开始。
import graphmuse as gm
import partitura as pt
import os
import torch
import numpy as np
# Directory containing the dataset, change this to the location of your dataset
dataset_dir = "/your/path/to/the/asap-dataset"
# Find all the score files in the dataset (they are all named 'xml_score.musicxml')
score_files = [os.path.join(dp, f) for dp, dn, filenames in os.walk(dataset_dir) for f in filenames if f == 'xml_score.musicxml']
# Use the first 30 scores, change this number to use more or less scores
score_files = score_files[:30]
# probe the first score file
score = pt.load_score(score_files[0])
# Extract features and note array
features, f_names = gm.utils.get_score_features(score)
na = score.note_array(include_pitch_spelling=True, include_key_signature=True)
# Create a graph from the score features
graph = gm.create_score_graph(features, score.note_array())
# Get input feature size and metadata from the first graph
in_feats = graph["note"].x.shape[1]
metadata = graph.metadata()
# Create a model for pitch spelling prediction
model = gm.nn.models.PitchSpellingGNN(
in_feats=in_feats, n_hidden=128, out_feats_enc=64, n_layers=2, metadata=metadata, add_seq=True
)
# Create encoders for pitch and key signature labels
pe = model.pitch_label_encoder
ke = model.key_label_encoder
接下来,我们将从数据集中加载剩余的乐谱文件,以继续为模型训练准备数据。
# Initialize lists to store graphs and encoders
graphs = []
# Process each score file
for score_file in score_files:
# Load the score
score = pt.load_score(score_file)
# Extract features and note array
features, f_names = gm.utils.get_score_features(score)
na = score.note_array(include_pitch_spelling=True, include_key_signature=True)
# Encode pitch and key signature labels
labels_pitch = pe.encode(na)
labels_key = ke.encode(na)
# Create a graph from the score features
graph = gm.create_score_graph(features, score.note_array())
# Add encoded labels to the graph
graph["note"].y_pitch = torch.from_numpy(labels_pitch).long()
graph["note"].y_key = torch.from_numpy(labels_key).long()
# Append the graph to the list
graphs.append(graph)
一旦图结构准备好,我们就可以继续创建数据加载器,GraphMuse 已经方便地提供了这一功能。在此阶段,我们还将定义标准的训练组件,如损失函数和优化器,以引导学习过程。
# Create a DataLoader to sample subgraphs from the graphs
loader = gm.loader.MuseNeighborLoader(graphs, subgraph_size=100, batch_size=16, num_neighbors=[3, 3])
# Define loss functions for pitch and key prediction
loss_pitch = torch.nn.CrossEntropyLoss()
loss_key = torch.nn.CrossEntropyLoss()
# Define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
让我再多说一点关于gm.loader.MuseNeighborLoader的内容。它是 GraphMuse 中的核心数据加载器,包含了前一节中提到的采样方法。subgraph_size指的是每个输入图的目标节点数量,batch_size是每批次采样的图的数量,最后,num_neighbors指的是每一层中每个采样节点的邻居数。
一切就绪后,我们终于可以开始训练模型了。让我们深入并开始训练过程吧!
# Train the model for 5 epochs
for epoch in range(5):
loss = 0
i = 0
for batch in loader:
# Zero the gradients
optimizer.zero_grad()
# Get neighbor masks for nodes and edges for more efficient training
neighbor_mask_node = {k: batch[k].neighbor_mask for k in batch.node_types}
neighbor_mask_edge = {k: batch[k].neighbor_mask for k in batch.edge_types}
# Forward pass through the model
pred_pitch, pred_key = model(
batch.x_dict, batch.edge_index_dict, neighbor_mask_node, neighbor_mask_edge,
batch["note"].batch[batch["note"].neighbor_mask == 0]
)
# Compute loss for pitch and key prediction
loss_pitch_val = loss_pitch(pred_pitch, batch["note"].y_pitch[batch["note"].neighbor_mask == 0])
loss_key_val = loss_key(pred_key, batch["note"].y_key[batch["note"].neighbor_mask == 0])
# Total loss
loss_val = loss_pitch_val + loss_key_val
# Backward pass and optimization
loss_val.backward()
optimizer.step()
# Accumulate loss
loss += loss_val.item()
i += 1
# Print average loss for the epoch
print(f"Epoch {epoch} Loss {loss / i}")
希望我们很快能看到损失函数下降,这是模型有效学习音高拼写的积极信号。希望一切顺利!

为什么选择 GraphMuse?
GraphMuse 是一个框架,旨在简化符号音乐处理中的图形模型训练和部署。
对于那些想要重新训练、部署或微调先前最先进的符号音乐分析模型的人,GraphMuse 包含了一些必要的组件,可以帮助你更快速、更高效地重建和重新训练模型。
GraphMuse 通过其简单性保持灵活性,适合那些希望原型设计、创新和设计新模型的人。它旨在提供一套简单的工具,而不是包含复杂的链式流程,以避免阻碍创新过程。
对于那些想要学习、可视化并获得实践经验的人,GraphMuse 非常适合入门。它提供了一个简洁的介绍,通过几行代码就能掌握基本功能和流程。GraphMuse 还与MusGViz相连接,可以轻松地将图形和乐谱一起可视化。
限制与未来计划
我们无法谈论任何项目的积极方面而不讨论其消极方面。
GraphMuse 是一个新兴的项目,目前仍相当简单。它专注于覆盖图形学习的核心部分,而不是成为一个涵盖所有可能性的完整框架。因此,它仍然在许多上述流程的部分中,更多地关注基于用户的实现。
像所有开发中的开源项目一样,GraphMuse 需要帮助才能成长。因此,如果你发现漏洞或想要更多功能,请随时报告、请求或贡献到 GraphMuse 的 GitHub 项目中。
最后但同样重要的是,GraphMuse 使用了如 torch-sparse 和 torch-scatter 等 C 库,并拥有自己的 C 绑定,以加速图形创建,因此安装并非总是很简单。根据我们的用户测试和交互报告,Windows 的安装相对更具挑战性,但并非不可能(我自己就是在 Windows 上运行的)。
未来计划包括:
简化安装过程;
为精确任务增加更多对模型和数据加载器的支持;
扩大 GraphMuse 的开源社区,让图形编码在音乐领域持续增长。
结论
GraphMuse 是一个 Python 库,使得处理音乐图形变得更加容易。它专注于基于图形的音乐模型的训练,但也旨在在研究项目需要时保持灵活性。
如果你想支持 GraphMuse 的开发和未来成长,请在这里给这个仓库点个星。
祝你愉快地进行图形编码!!
[## GitHub - manoskary/graphmuse: A Graph Deep Learning Library for Music.
一款用于音乐的图谱深度学习库。通过创建一个账户来为 manoskary/graphmuse 的开发做出贡献…
github.com](https://github.com/manoskary/graphmuse?source=post_page-----40dbd9baf319--------------------------------)
[所有图片均由作者提供]
GraphRAG 实践:从商业合同到动态问答代理
一种基于问题的提取方法
·发布于 Towards Data Science ·23 分钟阅读·2024 年 11 月 4 日
--
在这篇博客文章中,我们介绍了一种方法,利用图形检索增强生成(GraphRAG)方法——简化商业合同数据的摄取过程,并构建一个问答代理。
这种方法与传统的 RAG(检索增强生成)方法不同,它强调数据提取的效率,而不是像传统 RAG 方法那样随意拆解和向量化整个文档。
在传统的 RAG 中,每个文档都会被拆分成多个块并向量化以供检索,这可能导致大量不必要的数据被拆分、分块并存储在向量索引中。然而,在这里,重点是从每份合同中提取最相关的信息,以满足特定的应用场景——商业合同审核。然后,这些数据会被结构化成一个知识图谱,图谱组织了关键实体和关系,从而通过 Cypher 查询和向量检索实现更精确的图数据检索。
通过最小化向量化内容的数量,并专注于提取高度相关的知识,这种方法提高了问答代理的准确性和性能,使其能够处理复杂和特定领域的问题。
该四阶段方法包括:有针对性的信息提取(LLM + Prompt),创建知识图谱(LLM + Neo4J)以及简单的一组图数据检索功能(Cypher、Text to Cypher、Vector Search)。最后,构建一个利用数据检索功能的问答代理,基于(Microsoft Semantic Kernel)。
下图展示了该方法的实现方式

四阶段 GraphRAG 方法:从基于问题的提取 → 知识图谱模型 → GraphRAG 检索 → 问答代理。图像由 Sebastian Nilsson @ Neo4J 提供,并获得作者许可在此重制。
但首先,对于那些不熟悉商业法的人来说,让我们先简要介绍一下合同审查问题。
合同审查与大语言模型
商业合同审查是一个劳动密集型过程,涉及律师助理和初级律师仔细识别合同中的关键信息。
“合同审查是全面阅读合同的过程,目的是理解签署合同的个人或公司所承担的权利和义务,并评估相关影响。”
Hendrycks, Burns 等人,NeurIPS 2021, 见CUAD:一个专家标注的法律合同审查 NLP 数据集
合同审查的第一阶段包括审查数百页合同,寻找相关条款或义务。合同审查员必须确定是否存在相关条款,若存在,这些条款的内容是什么,并且跟踪它们的位置。
例如,他们必须确定合同是三年期合同还是一年期合同。他们必须确定合同的结束日期。他们必须确定某个条款是否是反转让条款或排他性条款……
Hendrycks, Burns 等人,NeurIPS 2021, 见CUAD:一个专家标注的法律合同审查 NLP 数据集
这是一个需要仔细审查的任务,但通常效率低下,但它非常适合大语言模型!
一旦完成第一阶段,高级法律从业者可以开始审查合同中的弱点和风险。这是一个领域,在这个领域中,由 LLM 支持并通过存储在知识图谱中的信息为基础的问答代理是法律专家的完美副驾驶。
使用 LLMs、功能调用和 GraphRAG 构建商业合同审查代理的四步方法
本文的其余部分将描述这个过程中的每个步骤。在此过程中,我将使用代码片段来说明主要概念。
四个步骤是:
从合同中提取相关信息(LLM + 合同)
将提取的信息存储到知识图谱中(Neo4j)
开发简单的 KG 数据检索功能(Python)
构建处理复杂问题的问答代理(语义内核,LLM,Neo4j)
数据集:
CUAD(合同理解阿提克斯数据集)是一个采用 CC BY 4.0 许可的公开数据集,包含超过 13,000 个专家标注的条款,跨越 510 份法律合同,旨在帮助构建用于合同审查的 AI 模型。它涵盖了广泛的重要法律条款,例如保密条款、终止条款和赔偿条款,这些对于合同分析至关重要。
我们将使用这个数据集中的三个合同,展示我们如何有效地提取和分析关键的法律信息,构建知识图谱,并利用它进行精确的复杂问题回答。
三个合同合计包含 95 页。
第 1 步:从合同中提取相关信息
向大型语言模型(LLM)请求提取合同中的精确信息并生成 JSON 输出,表示合同中的相关信息,是相对直接的。
在商业审查中,可以编写一个提示,来定位上述提到的每个关键元素——各方、日期、条款——并将其整洁地汇总成机器可读的(JSON)文件。
提取提示(简化版)
请仅使用本合同中的信息回答以下问题
[Contract.pdf]
这是什么类型的合同?
各方是谁及其角色是什么?他们在哪个国家注册?请提供州和国家名称(使用 ISO 3166 国家名称)
协议日期是什么?
生效日期是什么?
对于以下每种类型的合同条款,提取两条信息:
一个是/否选项,表示你是否认为该条款出现在此合同中
一份摘录列表,指示该条款类型的存在。
合同条款类型:竞争限制例外、不竞争条款、排他性、禁止诱导客户、禁止诱导员工、不得贬损、便捷解除、Rofr/Rofo/Rofn、控制权变更、反转让、无限责任、责任上限
请将最终答案以 JSON 文档的形式提供。
请注意,上述部分展示了提取提示的简化版本。完整版本可以在此处查看。你会发现,提示的最后部分指定了 JSON 文档的所需格式。这有助于确保输出的一致 JSON 模式。
这个任务在 Python 中相对简单。下面的 main() 函数旨在通过提取相关法律信息(extraction_prompt),使用OpenAI gpt-4o处理一组 PDF 合同文件,并将结果保存为 JSON 格式。
def main():
pdf_files = [filename for filename in os.listdir('./data/input/') if filename.endswith('.pdf')]
for pdf_filename in pdf_files:
print('Processing ' + pdf_filename + '...')
# Extract content from PDF using the assistant
complete_response = process_pdf('./data/input/' + pdf_filename)
# Log the complete response to debug
save_json_string_to_file(complete_response, './data/debug/complete_response_' + pdf_filename + '.json')
“process_pdf” 函数使用“OpenAI gpt-4o”从合同中执行知识提取,使用“提取提示”。
def process_pdf(pdf_filename):
# Create OpenAI message thread
thread = client.beta.threads.create()
# Upload PDF file to the thread
file = client.files.create(file=open(pdf_filename, "rb"), purpose="assistants")
# Create message with contract as attachment and extraction_prompt
client.beta.threads.messages.create(thread_id=thread.id,role="user",
attachments=[
Attachment(
file_id=file.id, tools=[AttachmentToolFileSearch(type="file_search")])
],
content=extraction_prompt,
)
# Run the message thread
run = client.beta.threads.runs.create_and_poll(
thread_id=thread.id, assistant_id=pdf_assistant.id, timeout=1000)
# Retrieve messages
messages_cursor = client.beta.threads.messages.list(thread_id=thread.id)
messages = [message for message in messages_cursor]
# Return last message in Thread
return messages[0].content[0].text.value
对于每个合同,“process_pdf” 返回的消息如下
{
"agreement": {
"agreement_name": "Marketing Affiliate Agreement",
"agreement_type": "Marketing Affiliate Agreement",
"effective_date": "May 8, 2014",
"expiration_date": "December 31, 2014",
"renewal_term": "1 year",
"Notice_period_to_Terminate_Renewal": "30 days",
"parties": [
{
"role": "Company",
"name": "Birch First Global Investments Inc.",
"incorporation_country": "United States Virgin Islands",
"incorporation_state": "N/A"
},
{
"role": "Marketing Affiliate",
"name": "Mount Knowledge Holdings Inc.",
"incorporation_country": "United States",
"incorporation_state": "Nevada"
}
],
"governing_law": {
"country": "United States",
"state": "Nevada",
"most_favored_country": "United States"
},
"clauses": [
{
"clause_type": "Competitive Restriction Exception",
"exists": false,
"excerpts": []
},
{
"clause_type": "Exclusivity",
"exists": true,
"excerpts": [
"Company hereby grants to MA the right to advertise, market and sell to corporate users, government agencies and educational facilities for their own internal purposes only, not for remarketing or redistribution."
]
},
{
"clause_type": "Non-Disparagement",
"exists": true,
"excerpts": [
"MA agrees to conduct business in a manner that reflects favorably at all times on the Technology sold and the good name, goodwill and reputation of Company."
]
},
{
"clause_type": "Termination For Convenience",
"exists": true,
"excerpts": [
"This Agreement may be terminated by either party at the expiration of its term or any renewal term upon thirty (30) days written notice to the other party."
]
},
{
"clause_type": "Anti-Assignment",
"exists": true,
"excerpts": [
"MA may not assign, sell, lease or otherwise transfer in whole or in part any of the rights granted pursuant to this Agreement without prior written approval of Company."
]
},
{
"clause_type": "Price Restrictions",
"exists": true,
"excerpts": [
"Company reserves the right to change its prices and/or fees, from time to time, in its sole and absolute discretion."
]
},
{
"clause_type": "Minimum Commitment",
"exists": true,
"excerpts": [
"MA commits to purchase a minimum of 100 Units in aggregate within the Territory within the first six months of term of this Agreement."
]
},
{
"clause_type": "IP Ownership Assignment",
"exists": true,
"excerpts": [
"Title to the Technology and all copyrights in Technology shall remain with Company and/or its Affiliates."
]
},
{
"clause_type": "License grant",
"exists": true,
"excerpts": [
"Company hereby grants to MA the right to advertise, market and sell the Technology listed in Schedule A of this Agreement."
]
},
{
"clause_type": "Non-Transferable License",
"exists": true,
"excerpts": [
"MA acknowledges that MA and its Clients receive no title to the Technology contained on the Technology."
]
},
{
"clause_type": "Cap On Liability",
"exists": true,
"excerpts": [
"In no event shall Company be liable to MA, its Clients, or any third party for any tort or contract damages or indirect, special, general, incidental or consequential damages."
]
},
{
"clause_type": "Warranty Duration",
"exists": true,
"excerpts": [
"Company's sole and exclusive liability for the warranty provided shall be to correct the Technology to operate in substantial accordance with its then current specifications."
]
}
]
}
}
第 2 步:创建知识图谱
现在,每个合同都是一个 JSON 文件,下一步是在 Neo4J 中创建知识图谱。
在这一点上,花些时间设计数据模型是有用的。你需要考虑一些关键问题:
图中的节点和关系代表什么?
每个节点和关系的主要属性是什么?
是否需要索引任何属性?
哪些属性需要向量嵌入以启用语义相似度搜索?
在我们的案例中,一个合适的设计(模式)包括主要实体:协议(合同)、它们的条款、作为协议当事方的组织及其之间的关系。
下方显示的是模式的可视化表示。

作者提供的图片
Node properties:
Agreement {agreement_type: STRING, contract_id: INTEGER,
effective_date: STRING, expiration_date: STRING,
renewal_term: STRING, name: STRING}
ContractClause {name: STRING, type: STRING}
ClauseType {name: STRING}
Country {name: STRING}
Excerpt {text: STRING}
Organization {name: STRING}
Relationship properties:
IS_PARTY_TO {role: STRING}
GOVERNED_BY_LAW {state: STRING}
HAS_CLAUSE {type: STRING}
INCORPORATED_IN {state: STRING}
只有“摘录”——在步骤 1 中由 LLM 识别的短文本片段——需要文本嵌入。这种方法显著减少了表示每个合同所需的向量数量和向量索引的大小,从而提高了处理效率和可扩展性。
一个简化版的 Python 脚本,通过上述模式将每个 JSON 加载到知识图谱中的示例如下:
NEO4J_URI=os.getenv('NEO4J_URI', 'bolt://localhost:7687')
NEO4J_USER=os.getenv('NEO4J_USERNAME', 'neo4j')
NEO4J_PASSWORD=os.getenv('NEO4J_PASSWORD')
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
JSON_CONTRACT_FOLDER = './data/output/'
driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
contract_id = 1
json_contracts = [filename for filename in os.listdir(JSON_CONTRACT_FOLDER) if filename.endswith('.json')]
for json_contract in json_contracts:
with open(JSON_CONTRACT_FOLDER + json_contract,'r') as file:
json_string = file.read()
json_data = json.loads(json_string)
agreement = json_data['agreement']
agreement['contract_id'] = contract_id
driver.execute_query(CREATE_GRAPH_STATEMENT, data=json_data)
contract_id+=1
create_full_text_indices(driver)
driver.execute_query(CREATE_VECTOR_INDEX_STATEMENT)
print ("Generating Embeddings for Contract Excerpts...")
driver.execute_query(EMBEDDINGS_STATEMENT, token = OPENAI_API_KEY)
这里的“CREATE_GRAPH_STATEMENT”是唯一的“复杂”部分。它是一个 CYPHER 语句,将合同(JSON)映射到知识图谱中的节点和关系。
完整的 Cypher 语句如下:
CREATE_GRAPH_STATEMENT = """
WITH $data AS data
WITH data.agreement as a
MERGE (agreement:Agreement {contract_id: a.contract_id})
ON CREATE SET
agreement.contract_id = a.contract_id,
agreement.name = a.agreement_name,
agreement.effective_date = a.effective_date,
agreement.expiration_date = a.expiration_date,
agreement.agreement_type = a.agreement_type,
agreement.renewal_term = a.renewal_term,
agreement.most_favored_country = a.governing_law.most_favored_country
//agreement.Notice_period_to_Terminate_Renewal = a.Notice_period_to_Terminate_Renewal
MERGE (gl_country:Country {name: a.governing_law.country})
MERGE (agreement)-[gbl:GOVERNED_BY_LAW]->(gl_country)
SET gbl.state = a.governing_law.state
FOREACH (party IN a.parties |
// todo proper global id for the party
MERGE (p:Organization {name: party.name})
MERGE (p)-[ipt:IS_PARTY_TO]->(agreement)
SET ipt.role = party.role
MERGE (country_of_incorporation:Country {name: party.incorporation_country})
MERGE (p)-[incorporated:INCORPORATED_IN]->(country_of_incorporation)
SET incorporated.state = party.incorporation_state
)
WITH a, agreement, [clause IN a.clauses WHERE clause.exists = true] AS valid_clauses
FOREACH (clause IN valid_clauses |
CREATE (cl:ContractClause {type: clause.clause_type})
MERGE (agreement)-[clt:HAS_CLAUSE]->(cl)
SET clt.type = clause.clause_type
// ON CREATE SET c.excerpts = clause.excerpts
FOREACH (excerpt IN clause.excerpts |
MERGE (cl)-[:HAS_EXCERPT]->(e:Excerpt {text: excerpt})
)
//link clauses to a Clause Type label
MERGE (clType:ClauseType{name: clause.clause_type})
MERGE (cl)-[:HAS_TYPE]->(clType)
)"""
以下是该语句执行的操作概述:
数据绑定
WITH $data AS data
WITH data.agreement as a
$data是以 JSON 格式传入查询的输入数据,包含有关协议(合同)的信息。第二行将
data.agreement赋值给别名a,以便在后续查询中引用合同详情。
插入协议节点
MERGE (agreement:Agreement {contract_id: a.contract_id})
ON CREATE SET
agreement.name = a.agreement_name,
agreement.effective_date = a.effective_date,
agreement.expiration_date = a.expiration_date,
agreement.agreement_type = a.agreement_type,
agreement.renewal_term = a.renewal_term,
agreement.most_favored_country = a.governing_law.most_favored_country
MERGE尝试查找具有指定contract_id的现有Agreement节点。如果没有这样的节点,它会创建一个。ON CREATE SET子句设置新创建的Agreement节点的各种属性,如contract_id、agreement_name、effective_date和来自 JSON 输入的其他协议相关字段。
创建适用法律关系
MERGE (gl_country:Country {name: a.governing_law.country})
MERGE (agreement)-[gbl:GOVERNED_BY_LAW]->(gl_country)
SET gbl.state = a.governing_law.state
这会为与协议相关的适用法律国家创建或合并一个
Country节点。然后,创建或合并
Agreement和Country之间的GOVERNED_BY_LAW关系。它还设置了
GOVERNED_BY_LAW关系的state属性。
创建当事方和注册地关系
FOREACH (party IN a.parties |
MERGE (p:Organization {name: party.name})
MERGE (p)-[ipt:IS_PARTY_TO]->(agreement)
SET ipt.role = party.role
MERGE (country_of_incorporation:Country {name: party.incorporation_country})
MERGE (p)-[incorporated:INCORPORATED_IN]->(country_of_incorporation)
SET incorporated.state = party.incorporation_state
)
对于合同中的每个当事方(a.parties),它:
为当事方插入(合并)一个
Organization节点。创建一个
IS_PARTY_TO关系,表示Organization和Agreement之间的关系,并设置当事方的role(例如,买方、卖方)。合并一个
Country节点,表示组织注册的国家。创建一个
INCORPORATED_IN关系,表示组织与注册国家之间的关系,并设置组织的注册地state。
创建合同条款和摘录
WITH a, agreement, [clause IN a.clauses WHERE clause.exists = true] AS valid_clauses
FOREACH (clause IN valid_clauses |
CREATE (cl:ContractClause {type: clause.clause_type})
MERGE (agreement)-[clt:HAS_CLAUSE]->(cl)
SET clt.type = clause.clause_type
FOREACH (excerpt IN clause.excerpts |
MERGE (cl)-[:HAS_EXCERPT]->(e:Excerpt {text: excerpt})
)
MERGE (clType:ClauseType{name: clause.clause_type})
MERGE (cl)-[:HAS_TYPE]->(clType)
)
这一部分首先筛选条款列表(
a.clauses),仅包括clause.exists = true的条款(即在步骤 1 中由 LLM 识别的包含摘录的条款)。对于每个条款:
它为每个条款创建一个
ContractClause节点,其中name和type对应于条款类型。在
Agreement和ContractClause之间建立了HAS_CLAUSE关系。对于与条款相关的每个
excerpt,它创建一个Excerpt节点,并通过HAS_EXCERPT关系将其链接到ContractClause。最后,为条款的类型创建(或合并)一个
ClauseType节点,并通过HAS_TYPE关系将ContractClause与ClauseType链接。
一旦导入脚本运行完毕,单个合同就可以在 Neo4J 中以知识图谱的形式可视化

单一合同的知识图谱表示:组织(各方)用绿色表示,合同条款用蓝色表示,摘录用浅棕色表示,国家用橙色表示。图片来源:作者
知识图谱中的三个合同仅需要一个小型图(不到 100 个节点,少于 200 个关系)。最重要的是,仅需要 40-50 个摘录的向量嵌入。这个拥有少量向量的知识图谱现在可以用来支持一个相当强大的问答代理。
第 3 步:为 GraphRAG 开发数据检索功能
现在,合同已被结构化为知识图谱,下一步是创建一组小型图数据检索功能。这些功能作为核心构建块,使我们能够在第 4 步开发问答代理。
我们将定义几个基本的数据检索功能:
检索合同的基本信息(给定合同 ID)
查找涉及特定组织的合同(给定部分组织名称)
查找不包含特定条款类型的合同
查找包含特定类型条款的合同
基于与条款中的文本(摘录)语义相似度查找合同(例如,查找提到“禁止物品”的合同)
对数据库中的所有合同执行自然语言查询。例如,执行聚合查询,统计“有多少合同符合特定条件”。
在第 4 步,我们将使用Microsoft Semantic Kernel 库构建问答系统。该库简化了代理构建过程,允许开发人员定义代理可用的功能和工具,以便回答问题。
为了简化 Neo4J 与 Semantic Kernel 库之间的集成,我们将定义一个ContractPlugin,该插件定义了每个数据检索功能的“签名”。请注意每个函数的@kernel_function装饰器,以及为每个函数提供的类型信息和描述。
Semantic Kernel 使用“插件”类的概念来封装代理可用的一组功能。它将使用装饰器函数、类型信息和文档来通知 LLM 函数调用能力,以了解可用的功能。
from typing import List, Optional, Annotated
from AgreementSchema import Agreement, ClauseType
from semantic_kernel.functions import kernel_function
from ContractService import ContractSearchService
class ContractPlugin:
def __init__(self, contract_search_service: ContractSearchService ):
self.contract_search_service = contract_search_service
@kernel_function
async def get_contract(self, contract_id: int) -> Annotated[Agreement, "A contract"]:
"""Gets details about a contract with the given id."""
return await self.contract_search_service.get_contract(contract_id)
@kernel_function
async def get_contracts(self, organization_name: str) -> Annotated[List[Agreement], "A list of contracts"]:
"""Gets basic details about all contracts where one of the parties has a name similar to the given organization name."""
return await self.contract_search_service.get_contracts(organization_name)
@kernel_function
async def get_contracts_without_clause(self, clause_type: ClauseType) -> Annotated[List[Agreement], "A list of contracts"]:
"""Gets basic details from contracts without a clause of the given type."""
return await self.contract_search_service.get_contracts_without_clause(clause_type=clause_type)
@kernel_function
async def get_contracts_with_clause_type(self, clause_type: ClauseType) -> Annotated[List[Agreement], "A list of contracts"]:
"""Gets basic details from contracts with a clause of the given type."""
return await self.contract_search_service.get_contracts_with_clause_type(clause_type=clause_type)
@kernel_function
async def get_contracts_similar_text(self, clause_text: str) -> Annotated[List[Agreement], "A list of contracts with similar text in one of their clauses"]:
"""Gets basic details from contracts having semantically similar text in one of their clauses to the to the 'clause_text' provided."""
return await self.contract_search_service.get_contracts_similar_text(clause_text=clause_text)
@kernel_function
async def answer_aggregation_question(self, user_question: str) -> Annotated[str, "An answer to user_question"]:
"""Answer obtained by turning user_question into a CYPHER query"""
return await self.contract_search_service.answer_aggregation_question(user_question=user_question)
我建议探索包含上述每个函数实现的 “ContractService” 类。每个函数展示了一种不同的数据检索技术。
让我们逐步了解这些函数的实现,因为它们展示了不同的 GraphRAG 数据检索技术/模式
根据合同 ID 获取合同 — 基于 Cypher 的检索函数
get_contract(self, contract_id: int) 是一个异步方法,旨在使用 Cypher 查询从 Neo4J 数据库中检索特定合同(Agreement)的详细信息。该函数返回一个填充了关于协议、条款、相关方及其关系信息的 Agreement 对象。
以下是此函数的实现
async def get_contract(self, contract_id: int) -> Agreement:
GET_CONTRACT_BY_ID_QUERY = """
MATCH (a:Agreement {contract_id: $contract_id})-[:HAS_CLAUSE]->(clause:ContractClause)
WITH a, collect(clause) as clauses
MATCH (country:Country)-[i:INCORPORATED_IN]-(p:Organization)-[r:IS_PARTY_TO]-(a)
WITH a, clauses, collect(p) as parties, collect(country) as countries, collect(r) as roles, collect(i) as states
RETURN a as agreement, clauses, parties, countries, roles, states
"""
agreement_node = {}
records, _, _ = self._driver.execute_query(GET_CONTRACT_BY_ID_QUERY,{'contract_id':contract_id})
if (len(records)==1):
agreement_node = records[0].get('agreement')
party_list = records[0].get('parties')
role_list = records[0].get('roles')
country_list = records[0].get('countries')
state_list = records[0].get('states')
clause_list = records[0].get('clauses')
return await self._get_agreement(
agreement_node, format="long",
party_list=party_list, role_list=role_list,
country_list=country_list,state_list=state_list,
clause_list=clause_list
)
最重要的组件是 **GET_CONTRACT_BY_ID_QUERY** 中的 Cypher 查询。此查询使用作为输入参数提供的 contract_id 执行。输出是匹配的协议、其条款和相关方(每个方都有角色和注册国家/州)。
数据随后传递给一个工具函数 _get_agreement,该函数仅将数据映射到一个“Agreement”对象。协议是一个定义为 TypedDict 的类型。
class Agreement(TypedDict):
contract_id: int
agreement_name: str
agreement_type: str
effective_date: str
expiration_date: str
renewal_term: str
notice_period_to_terminate_Renewal: str
parties: List[Party]
clauses: List[ContractClause]
获取没有特定条款类型的合同 — 另一个 Cypher 检索函数
该函数展示了知识图谱的一个强大功能,即测试关系的不存在。
get_contracts_without_clause() 函数从 Neo4J 数据库中检索所有 不 包含特定条款类型的合同(Agreements)。该函数接受一个 ClauseType 作为输入,并返回一个符合条件的 Agreement 对象列表。
这种类型的数据检索信息无法通过向量搜索轻松实现。完整的实现如下
async def get_contracts_without_clause(self, clause_type: ClauseType) -> List[Agreement]:
GET_CONTRACT_WITHOUT_CLAUSE_TYPE_QUERY = """
MATCH (a:Agreement)
OPTIONAL MATCH (a)-[:HAS_CLAUSE]->(cc:ContractClause {type: $clause_type})
WITH a,cc
WHERE cc is NULL
WITH a
MATCH (country:Country)-[i:INCORPORATED_IN]-(p:Organization)-[r:IS_PARTY_TO]-(a)
RETURN a as agreement, collect(p) as parties, collect(r) as roles, collect(country) as countries, collect(i) as states
"""
#run the Cypher query
records, _ , _ = self._driver.execute_query(GET_CONTRACT_WITHOUT_CLAUSE_TYPE_QUERY,{'clause_type':clause_type.value})
all_agreements = []
for row in records:
agreement_node = row['agreement']
party_list = row['parties']
role_list = row['roles']
country_list = row['countries']
state_list = row['states']
agreement : Agreement = await self._get_agreement(
format="short",
agreement_node=agreement_node,
party_list=party_list,
role_list=role_list,
country_list=country_list,
state_list=state_list
)
all_agreements.append(agreement)
return all_agreements
一如既往,格式与前一个函数相似。Cypher 查询 **GET_CONTRACTS_WITHOUT_CLAUSE_TYPE_QUERY** 定义了要匹配的节点和关系模式。它执行 可选匹配 以过滤掉包含条款类型的合同,并收集有关协议的相关数据,例如相关方及其详细信息。
然后,函数构建并返回一个 Agreement 对象列表,其中封装了每个匹配协议的所有相关信息。
获取具有语义相似文本的合同 — 向量搜索 + 图数据检索函数
get_contracts_similar_text() 函数旨在查找包含与提供的 clause_text 相似文本的条款的协议(合同)。它使用语义向量搜索来识别相关摘录,然后遍历图谱以返回有关相应协议和条款的信息,以及这些摘录的来源。
该函数利用定义在每个摘录的 “text” 属性上的向量索引。它使用最近发布的 Neo4J GraphRAG 包 来简化执行语义搜索和图遍历所需的 Cypher 代码。
async def get_contracts_similar_text(self, clause_text: str) -> List[Agreement]:
#Cypher to traverse from the semantically similar excerpts back to the agreement
EXCERPT_TO_AGREEMENT_TRAVERSAL_QUERY="""
MATCH (a:Agreement)-[:HAS_CLAUSE]->(cc:ContractClause)-[:HAS_EXCERPT]-(node)
RETURN a.name as agreement_name, a.contract_id as contract_id, cc.type as clause_type, node.text as excerpt
"""
#Set up vector Cypher retriever
retriever = VectorCypherRetriever(
driver= self._driver,
index_name="excerpt_embedding",
embedder=self._openai_embedder,
retrieval_query=EXCERPT_TO_AGREEMENT_TRAVERSAL_QUERY,
result_formatter=my_vector_search_excerpt_record_formatter
)
# run vector search query on excerpts and get results containing the relevant agreement and clause
retriever_result = retriever.search(query_text=clause_text, top_k=3)
#set up List of Agreements (with partial data) to be returned
agreements = []
for item in retriever_result.items:
//extract information from returned items and append agreement to results
// full code not shown here but available on the Github repo
return agreements
让我们来回顾一下这个数据检索函数的主要组件。
Neo4j GraphRAG VectorCypherRetriever 允许开发者在向量索引上执行语义相似度分析。在我们的案例中,对于每一个语义相似的摘录“节点”,会使用额外的 Cypher 表达式来获取与该节点相关的图中其他节点。
VectorCypherRetriever 的参数 非常直接。
index_name是执行语义相似度分析的向量索引。embedder为一段文本生成向量嵌入。driver只是 Neo4j Python 驱动的一个实例。retrieval_query指定与每个由语义相似度识别的“摘录”节点相关联的其他节点和关系。EXCERPT_TO_AGREEMENT_TRAVERSAL_QUERY指定要检索的额外节点。在这种情况下,对于每个摘录,我们都在检索与其相关的合同条款和相应的协议。
EXCERPT_TO_AGREEMENT_TRAVERSAL_QUERY="""
MATCH (a:Agreement)-[:HAS_CLAUSE]->(cc:ContractClause)-[:HAS_EXCERPT]-(node)
RETURN a.name as agreement_name, a.contract_id as contract_id, cc.type as clause_type, node.text as excerpt
"""
执行自然语言查询——一个 Text 2Cypher 数据检索函数
answer_aggregation_question() 函数利用 Neo4j GraphRAG 包中的 “Text2CypherRetriever” 来回答自然语言的问题。Text2CypherRetriever 使用 LLM 将用户问题转换为 Cypher 查询,并在 Neo4j 数据库中执行该查询。
该函数利用 OpenAI gpt-4o 生成所需的 Cypher 查询。让我们逐步了解这个数据检索函数的主要组件。
async def answer_aggregation_question(self, user_question) -> str:
answer = ""
NEO4J_SCHEMA = """
omitted for brevity (see below for the full value)
"""
# Initialize the retriever
retriever = Text2CypherRetriever(
driver=self._driver,
llm=self._llm,
neo4j_schema=NEO4J_SCHEMA
)
# Generate a Cypher query using the LLM, send it to the Neo4j database, and return the results
retriever_result = retriever.search(query_text=user_question)
for item in retriever_result.items:
content = str(item.content)
if content:
answer += content + '\n\n'
return answer
这个函数利用 Neo4j GraphRAG 包中的 “Text2CypherRetriever”。它使用 LLM,在本例中使用的是 OpenAI LLM,将用户问题(自然语言)转换为一个在数据库中执行的 Cypher 查询,并返回该查询的结果。
确保 LLM 生成一个使用数据库中定义的节点、关系和属性的查询的关键要素是向 LLM 提供模式的文本描述。
在我们的案例中,我们使用以下的数据模型表示已经足够。
NEO4J_SCHEMA = """
Node properties:
Agreement {agreement_type: STRING, contract_id: INTEGER,effective_date: STRING,renewal_term: STRING, name: STRING}
ContractClause {name: STRING, type: STRING}
ClauseType {name: STRING}
Country {name: STRING}
Excerpt {text: STRING}
Organization {name: STRING}
Relationship properties:
IS_PARTY_TO {role: STRING}
GOVERNED_BY_LAW {state: STRING}
HAS_CLAUSE {type: STRING}
INCORPORATED_IN {state: STRING}
The relationships:
(:Agreement)-[:HAS_CLAUSE]->(:ContractClause)
(:ContractClause)-[:HAS_EXCERPT]->(:Excerpt)
(:ContractClause)-[:HAS_TYPE]->(:ClauseType)
(:Agreement)-[:GOVERNED_BY_LAW]->(:Country)
(:Organization)-[:IS_PARTY_TO]->(:Agreement)
(:Organization)-[:INCORPORATED_IN]->(:Country)
"""
第四步:构建一个问答代理
拥有我们的知识图谱数据检索函数后,我们已经准备好构建一个由 GraphRAG 支持的代理 😃
让我们建立一个能够回答关于合同的用户查询的聊天代理,使用 OpenAI 的 gpt-4o 模型、我们的数据检索函数和一个由 Neo4j 支持的知识图谱。
我们将使用 Microsoft Semantic Kernel,这个框架 允许开发者将 LLM 函数调用与现有的 API 和数据检索功能进行集成。
该框架使用一个名为 Plugins 的概念来表示内核可以执行的特定功能。在我们的案例中,所有在 “ContractPlugin” 中定义的数据检索函数都可以被 LLM 用来回答问题。
该框架使用记忆的概念来保存用户与代理之间的所有交互,以及执行的功能和检索的数据。
一个极其简单的基于终端的代理可以通过几行代码实现。下面的代码片段展示了代理的主要部分(导入和环境变量已删除)。
logging.basicConfig(level=logging.INFO)
# Initialize the kernel
kernel = Kernel()
# Add the Contract Search plugin to the kernel
contract_search_neo4j = ContractSearchService(NEO4J_URI,NEO4J_USER,NEO4J_PASSWORD)
kernel.add_plugin(ContractPlugin(contract_search_service=contract_search_neo4j),plugin_name="contract_search")
# Add the OpenAI chat completion service to the Kernel
kernel.add_service(OpenAIChatCompletion(ai_model_id="gpt-4o",api_key=OPENAI_KEY, service_id=service_id))
# Enable automatic function calling
settings: OpenAIChatPromptExecutionSettings = kernel.get_prompt_execution_settings_from_service_id(service_id=service_id)
settings.function_choice_behavior = FunctionChoiceBehavior.Auto(filters={"included_plugins": ["contract_search"]})
# Create a history of the conversation
history = ChatHistory()
async def basic_agent() :
userInput = None
while True:
# Collect user input
userInput = input("User > ")
# Terminate the loop if the user says "exit"
if userInput == "exit":
break
# Add user input to the history
history.add_user_message(userInput)
# 3\. Get the response from the AI with automatic function calling
chat_completion : OpenAIChatCompletion = kernel.get_service(type=ChatCompletionClientBase)
result = (await chat_completion.get_chat_message_contents(
chat_history=history,
settings=settings,
kernel=kernel,
arguments=KernelArguments(),
))[0]
# Print the results
print("Assistant > " + str(result))
# Add the message from the agent to the chat history
history.add_message(result)
if __name__ == "__main__":
asyncio.run(basic_agent())
让我们问问我们的代理一些非平凡的问题
1. 查找我有价格限制但没有保险条款的合同
Assistant > I found the following contract that contains a Price Restrictions clause but does not contain an Insurance clause:
- **Name:** Birch First Global Investments Inc. and Mount Knowledge Holdings Inc. MARKETING AFFILIATE AGREEMENT
- **Agreement Type:** MARKETING AFFILIATE AGREEMENT
- **Parties:**
- **MOUNT KNOWLEDGE HOLDINGS INC.** (Marketing Affiliate), incorporated in Nevada, USA
- **BIRCH FIRST GLOBAL INVESTMENTS INC.** (Company), incorporated in the Virgin Islands (VIR)
Would you like to know more details about this contract?
在这里,语义代理规划器已审查内核中可用的功能,并确定回答这个问题需要调用两个单独的功能,而这两个功能本身无法提供完整的答案。
首先,它执行**“获取带有[价格限制]条款的合同”功能,然后执行“获取不含[保险]条款的合同”**功能。接着,规划器将这两个功能的结果结合起来,找到交集,最终得出正确答案。
如果你查看调试信息,你可以确认 LLM 执行的调用顺序,以及调用中消耗的 token 数量。注意具有两个数据检索功能的调用。

2. 是的,请提供更多细节
Assistant > Here are the detailed information about the contract:
### Contract Details:
- **Name:** Birch First Global Investments Inc. and Mount Knowledge Holdings Inc. MARKETING AFFILIATE AGREEMENT
- **Agreement Type:** MARKETING AFFILIATE AGREEMENT
- **Expiration Date:** December 31, 2014
- **Renewal Term:** One (1) year
### Parties Involved:
1\. **MOUNT KNOWLEDGE HOLDINGS INC.**
- **Role:** Marketing Affiliate
- **Incorporation Country:** USA
- **Incorporation State:** Nevada
2\. **BIRCH FIRST GLOBAL INVESTMENTS INC.**
- **Role:** Company
- **Incorporation Country:** Virgin Islands (VIR)
### Clauses Included:
- No-Solicit Of Employees
- Third Party Beneficiary
- Price Restrictions
- Affiliate License-Licensee
- Affiliate License-Licensor
- Cap On Liability
- Audit Rights
- Warranty Duration
- Post-Termination Services
- Non-Transferable License
- License grant
- IP Ownership Assignment
- Minimum Commitment
- Anti-Assignment
- Non-Disparagement
### Clauses Not Included:
- Insurance
Is there anything specific you would like to know or any particular clause details you are interested in?
在这种情况下,背后,LLM 使用了**“按 ID 查找合同”**。你可能已经注意到,显示给用户的输出并未包含合同的 ID。然而,语义内核已经将之前调用返回的所有信息存储在对话记忆中!
这种代理决定调用哪个功能、使用记忆以及将用户查询自动转换为具有正确参数的功能调用的强大组合,解释了构建于语义内核上的代理背后的**“魔力”**!
3. 是否有任何合同提到可以不时调整价格的能力
Assistant > Yes, there is a contract that mentions
the ability to make price changes from time to time.
Here is the relevant excerpt:
1\. **Birch First Global Investments Inc. and Mount Knowledge Holdings Inc. MARKETING AFFILIATE AGREEMENT**
- **Contract ID:** 1
- **Clause Type:** Price Restrictions
- **Excerpt:**
```
公司保留随时根据其单独且绝对的判断更改流程和/或费用的权利。如果技术价格和/或费用发生变化,公司应在任何此类变更生效日期前九十(90)天书面通知 MA。所有其他价格和/或费用可在无需提前通知 MA 的情况下进行调整。
```py
Would you like more information about this contract or any other details?
为了回答这个问题,语义内核正在执行get_contract_similar_text 数据检索功能。
User > any contract mentions the ability to make price changes from time to time
INFO:semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion_base:processing 1 tool calls in parallel.
INFO:semantic_kernel.kernel:Calling contract_search-get_contracts_similar_text
function with args: {"clause_text":"price changes from time to time"}
Assistant > Yes, there is a contract that mentions the ability to make price changes from time to time. Here is the relevant excerpt:
1\. **Birch First Global Investments Inc. and Mount Knowledge Holdings Inc. MARKETING AFFILIATE AGREEMENT**
- **Contract ID:** 1
- **Clause Type:** Price Restrictions
- **Excerpt:**
```
公司保留随时根据其单独且绝对的判断更改流程和/或费用的权利。如果技术价格和/或费用发生变化,公司应在任何此类变更生效日期前九十(90)天书面通知 MA。所有其他价格和/或费用可在无需提前通知 MA 的情况下进行调整。
```py
如前所述,该功能使用向量搜索和 Cypher 查询来检索不仅在语义上相似的摘录,还包括包含它们的相关合同和条款。
最后,让我们询问数据库中有多少合同
4. 数据库中有多少个合同?
这个问题要求 LLM 使用 Text2Cypher 数据检索器
Assistant > There are a total of 3 contracts in the database.
调试信息确认语义内核运行了 Text2Cypher 数据检索功能,且“user_question” = “数据库中有多少个合同?”
User > how many contracts are there on the database?
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion_base:processing 1 tool calls in parallel.
INFO:semantic_kernel.kernel:Calling contract_search-answer_aggregation_question function
with args: {"user_question":"How many contracts are there in the database?"}
INFO:semantic_kernel.functions.kernel_function:Function completed. Duration: 0.588805s
INFO:semantic_kernel.connectors.ai.open_ai.services.open_ai_handler:OpenAI usage: CompletionUsage(completion_tokens=13, prompt_tokens=3328, total_tokens=3341, completion_tokens_details={'reasoning_tokens': 0})
Assistant > There are a total of 3 contracts in the database.
亲自尝试
github 仓库包含一个 Streamlit 应用程序,提供了一个更优雅的代理 UI。我们鼓励你与代理互动,并对 ContractPlugin 进行修改,以提升代理处理更多问题的能力!
结论
在这篇博客中,我们探讨了一个 Graph Retrieval Augmented Generation(GraphRAG)方法,将商业合同审查这一劳动密集型任务转化为更高效、更智能的 AI 驱动过程。
通过利用大型语言模型(LLMs)和提示进行针对性的信息提取,使用 Neo4j 构建结构化知识图谱,实现简单的数据检索功能,并最终开发出问答代理,我们创建了一种能够有效处理复杂问题的智能解决方案。
这种方法减少了传统基于向量搜索的 RAG 中发现的低效问题,而是专注于提取相关信息,减少了不必要的向量嵌入,简化了整体过程。我们希望从合同导入到互动问答代理的这段旅程能激励你在自己的项目中使用 GraphRAG,以提高效率和更智能的 AI 驱动决策。
今天就开始构建你自己的商业合同审查代理,并亲身体验 GraphRAG 的强大功能!
资源
对于那些渴望深入了解的人,请查看以下链接的资源:
法律合同的合同理解 Atticus 数据集(CUAD)(Github)
CUAD:一个专家注释的法律合同审查 NLP 数据集。Hendrycks,Burns,Chen,Ball。NeurIPS 2021
除非另有说明,所有图片均由作者提供
使用 VLM 和 LLM 实现常识抓取
如何利用大型语言模型进行机器人抓取和代码生成
·发表于 Towards Data Science ·6 分钟阅读·2024 年 3 月 29 日
--
抓取和操控仍然是机器人技术中的一个难题,尚未解决。抓取不仅仅是找出将手指放置在物体上的位置,以产生足够的约束力。抓取还涉及施加恰到好处的力量,既能捡起物体又不至于破坏它,同时确保物体能够按预期使用。同时,抓取还为检测物体的类型和属性提供了关键的传感器输入。随着移动性基本解决,抓取和操控仍然是实现真正自主劳动替代的最后一个前沿。
如果你不是 Medium 的订阅者,可以 在这里 免费阅读这篇文章。
想象一下,你正在派遣你的类人机器人去超市,并告诉它:“检查鳄梨的成熟度,今天为做鳄梨酱拿一个鳄梨”。 这里有很多事情在发生:
“成熟度”这一品质并不是像草莓或西红柿那样通过颜色显而易见,而是需要触觉信息来判断。
“抓一个鳄梨”,特别是一个成熟的鳄梨,意味着某种…
灰狼优化器——它如何与计算机视觉结合使用
作为额外内容,获取适用于任何地方的特征提取代码
·发表于Towards Data Science ·10 分钟阅读·2024 年 2 月 23 日
--

图像由 DALL·E 3 根据提示“画一群未来派的灰狼,在夜晚的海滩上”生成。
这是我系列自然启发式文章的最后一部分。之前,我谈到了受遗传学、群体、蜜蜂和蚂蚁启发的算法。今天,我将谈谈狼。
当一篇期刊论文的引用量达到五位数时,你就知道这项工作背后有着重大的意义。灰狼优化器[1](GWO)就是这样一个例子。
概述与动机
与粒子群优化(PSO)、人工蜂群(ABC)和蚁群优化(ACO)类似,灰狼优化器(GWO)也是一种元启发式算法。尽管对解决方案没有数学上的保证,但它在实践中表现良好,并且不需要对基础问题的任何分析性知识。这使我们能够从“黑箱”中查询数据,只需利用观察到的结果来优化我们的解决方案。
如我在 ACO 文章中提到的,所有这些最终都与探索-开发权衡的基本概念相关。那么,为什么会有这么多不同的元启发式算法呢?
首先,这是因为研究人员必须发表论文。他们工作的很大一部分涉及从不同的角度探索事物,并分享他们的研究方法…
理解行为面试
掌握行为面试的艺术,实现你的梦想工作
·发表在Towards Data Science ·6 分钟阅读·2024 年 11 月 30 日
--

由 chatGPT 生成的图片
我在一家为专业人士提供准备就业的机构工作,这些专业人士希望能够在亚马逊、Meta、Google 等高科技公司找到工作。作为面试准备的一部分,许多候选人希望进行行为模拟面试。他们的主要目标是弄清楚这些面试会问什么问题,以及他们应该如何准备。
我在这本名为“理解行为面试”的书中记录了我的经历——作为候选人和招聘经理。你可以在Gumroad 这里获取这本书。希望你会觉得有用。
这里是书中的一瞥,以及一些可以帮助你准备行为面试的技巧和窍门。
什么是行为面试
行为面试问题是招聘过程中的一个重要环节。通常招聘经理会问这些问题来评估候选人的解决问题能力和人际交往能力。总的来说,这些问题旨在评估你是否适合公司的企业文化。
为了让你了解这些问题的风格,以下是一些示例:
- 你能描述一次你作为团队一员共同实现目标的经历吗?你扮演的角色是什么,以及…
Groq 和 AI 硬件 —— 直观且详尽的解释
机器学习 | 加速计算 | 人工智能
对运行 AI 所需的主要计算机硬件组件的分析,并介绍一种新的重量级玩家。
·发布于 Towards Data Science ·阅读时长 31 分钟·2024 年 4 月 12 日
--

“协调解构” 由 Daniel Warfield 使用 MidJourney 制作。除非另有说明,所有图像均由作者提供。所有生成的图像均由 MidJourney 制作。
本文讨论了 Groq,一种全新的计算机硬件方法,它正在革新 AI 在现实世界问题中的应用方式。
在讨论 Groq 之前,我们将首先剖析 AI 的基本概念,并探索一些用于运行 AI 模型的计算机硬件的关键组件。具体来说是:CPU、GPU 和 TPU。我们将从 1975 年开始,通过 Z80 CPU,逐步了解计算机硬件的关键演变,最终构建出对现代系统的理解。
通过理解一些计算机硬件的基本概念和权衡,我们将利用这些理解来探索 Groq 是什么,它如何革新 AI 计算的方式,以及为什么这如此重要。
自然地,从早期的 CPU 到最前沿的价值十亿的 AI 初创公司,涵盖的内容非常广泛。因此,这是一篇相当长的文章。系好安全带,值得一读。
群体等变自注意力
将几何先验注入 Transformer 模型
·发表于 Towards Data Science ·阅读时间 7 分钟·2024 年 1 月 9 日
--
在不断发展的神经架构的动态环境中,效率至关重要。为特定任务定制网络涉及注入先验知识,这可以通过战略性地调整架构来实现。这不仅仅是调整参数——更重要的是将所需的理解嵌入到模型中。实现这一目标的一种方法是使用几何先验——这正是本文要讨论的话题。

一张狗在 90 度旋转下的照片。中间是具有旋转等变性的模型的权重表示,而右侧的模型则没有这种等变性。©J. Hatzky
前提条件
在上一篇文章中我们深入探讨了视觉中的自注意力操作。
现在让我们在此基础上进行拓展,利用几何深度学习的最新进展。
如果你还不熟悉几何深度学习,Michael Bronstein 创建了一个很好的入门系列。
群体等变模型的优势
等变模型可以根据当前任务定制搜索空间,并降低模型学习到虚假关系的概率。

癌细胞在 90 度旋转下的表现。中间是具有旋转等变性的模型的权重表示,而右侧的模型则没有这种等变性。©J. Hatzky
GSM-Symbolic:分析大语言模型在数学推理中的局限性及潜在解决方案
《大语言模型推理论文:正确与遗漏的地方》
·发表于Towards Data Science ·9 分钟阅读·2024 年 10 月 28 日
--
合著者:Alex Watson, Yev Meyer, Dane Corneil, Maarten Van Segbroeck (Gretel.ai)

来源:Gretel.ai
引言
大语言模型(LLMs)最近在人工智能推理领域取得了显著进展,包括数学问题解决。然而,Mirzadeh 等人最近发表的一篇名为《GSM-Symbolic:理解大语言模型在数学推理中的局限性》的论文提出了关于这些模型在数学推理方面的真正能力的疑问。我们已对该论文进行了审阅,认为它对持续讨论人工智能的能力和局限性做出了宝贵贡献,但我们的分析表明,该论文的结论可能并未完全捕捉到问题的复杂性。
GSM-Symbolic 基准
作者们介绍了 GSM-Symbolic,这是一个从流行的 GSM8K 数据集中派生的增强基准。这个新基准允许生成多样化的问题变体,从而能够对大语言模型(LLM)在不同设置下的表现进行更加细致的评估。研究对 25 个最先进的开源和闭源模型的大规模分析提供了重要的见解,揭示了这些模型在面对数学推理任务时的表现。

图 1:GSM-Symbolic:理解大语言模型在数学推理中的局限性(来源:Mirzadeh 等人,GSM-Symbolic 论文)
性能变异性与模型比较
其中一个最令人惊讶的发现是,同一问题的不同实例在模型表现上的高度变异性。所有模型在 GSM-Symbolic 上测试时都表现出“准确性上的显著变异性”。这种变异性引发了对当前报告的 GSM8K基准上单点准确度响应可靠性的担忧。

图 3:GSM-Symbolic:理解大型语言模型中数学推理的局限性(来源:Mirzadeh 等人,GSM-Symbolic 论文)
并非所有模型都是一样的。 Llama-3–8b和GPT-4o在新基准上表现突出,不像其他模型(如gemma-2–9b、phi-3、phi-3.5和mathstral-7b)那样表现出显著的下降。这一观察结果表明了两个重要的观点:
Llama-3–8b和GPT-4o通常能更好地理解数学概念,尽管它们仍然不能免于性能变异。Llama-3–8b和GPT-4o的训练数据可能没有受到 GSM8K 数据的污染(或至少没有像其他模型那样受到污染)。在这种情况下,数据污染是指在模型训练集中无意中包含了测试或基准数据,导致模型在评估时性能被人为抬高。如果发生了污染,正如作者对某些模型的假设那样,我们会期望在 GSM8K 上看到非常高的表现,但在这些问题的稍微变化版本上,表现会显著下降。
这些发现突显了通过使用合成数据改进模型的机会,经过合理设计的合成数据集能够解决这两点问题,适用于任何进行模型训练的人:
为了减少潜在的数据污染问题,当高质量的合成版本可以生成时,训练时无需使用原始的 GSM8K 数据(博客链接)。这些合成数据集保留了 GSM8K 中的数学推理挑战,而不重复使用相同的问题或解决方案,从而保持模型评估的完整性。
更重要的是,生成的合成数据有可能超越 OpenAI 的 GSM8K 和 Apple 的 GSM-Symbolic 数据集的质量。这种方法可以更坚实地理解数学概念,并解决当前模型中观察到的性能变异。
对变化和复杂性的敏感性
作者们表明,LLM(大规模语言模型)对数值变化的敏感度比对问题中专有名词变化的敏感度更高,这表明这些模型对潜在数学概念的理解可能并不像之前认为的那样稳固。当问题的复杂性增加时(通过子句数量来衡量),所有模型的表现都会下降,并且它们的表现方差也会增加。这凸显了在训练中使用多样化数据的重要性,而这正是合成数据能够提供帮助的地方。正如作者所展示的那样,从逻辑上讲,AI 模型在一组特定问题上表现更差,单纯地改变数字或轻微变化子句数量并没有理由发生。

图 4: GSM-Symbolic:理解大规模语言模型中数学推理的局限性(来源:Mirzadeh 等,GSM-Symbolic 论文)
GSM-NoOp 挑战
也许最令人担忧的发现是引入了 GSM-NoOp,一个旨在挑战 LLM 推理能力的数据集。通过向问题中添加表面相关但最终无关紧要的信息,作者观察到所有模型的表现大幅下降——某些模型下降幅度高达 65%。作者提出,这表明当前的 LLM 在某种程度上更依赖于模式匹配而非真正的逻辑推理。

图 6: GSM-Symbolic:理解大规模语言模型中数学推理的局限性(来源:Mirzadeh 等,GSM-Symbolic 论文)
对论文结论的批判性观点
尽管 GSM-Symbolic 研究提供了关于 LLM 在数学推理任务中表现的宝贵见解,但重要的是要对论文的结论进行批判性审视。作者认为,观察到的局限性表明 LLM 无法进行真正的逻辑推理。然而,这种解释可能过于简化了一个复杂的问题。
论文中关于 LLM 依赖模式匹配而非推理的论点,在仔细审视后似乎没有那么确凿。很明显,这些模型并不是完美的推理者——如果是,它们在 GSM8K 上将达到 100%的准确率。但从不完美的表现推断出缺乏推理能力并不一定是有根据的。
至少有两种可能的解释可以说明为什么 LLM 像人类一样,有时会答错问题:
模型尝试严格地将一个问题与它曾经见过的某个问题进行模式匹配,如果不能匹配,则失败。
模型尝试遵循一个逻辑程序,但在每一步都有一定的(累积)出错概率,这可以通过它实际生成词汇的过程来解释。
论文似乎倾向于解释(1),但并没有充分论证为什么解释(1)应该优于解释(2)。事实上,(2)更类似于人类的推理方式,并且从研究角度来看可能更有趣。
让我们通过这种批判性视角来审视论文中的每一个主要发现:
GSM-符号性能
GSM-符号方法是数据集扩展的有价值方法,验证了像 Gretel 使用的合成数据生成技术的潜力。然而,值得注意的是,模型在这些新变种上的表现并没有完全崩溃——只是表现有所下降。如果这些模型仅仅是进行模式匹配,我们可能会预期它们在这些新变种上的表现几乎接近零。观察到的行为似乎更符合这样一种模型:它能在一定程度上进行泛化,但在面对陌生问题结构时犯更多错误。
即使是人类专家也不是全知全能的。例如,在 MATH 基准测试中,前数学奥林匹克选手通常能得 18/20 或 19/20,偶尔犯些小算术错误。这表明,错误易发生的推理,而非缺乏推理能力,可能更准确地描述了人类和 LLM 的表现。
变化的难度
论文中关于问题复杂性增加导致性能下降的发现与多步骤推理过程中的错误累积理论相一致。随着步骤数的增加,出现错误的概率也随之增加。这种现象在人类解决问题时也会出现,并不一定表示缺乏推理能力。
GSM-NoOp 挑战
GSM-NoOp 的结果,可能不像论文所暗示的那样与推理能力直接相关。在现实世界的场景中,我们通常假设问题陈述中提供的所有信息都是相关的。例如,在图 7 中的示例问题中,一位理性的人的推理可能(就像 LLMs 所做的那样)是,猕猴桃的大小之所以被提及,仅仅是因为它们被丢弃了。

图 7:GSM-符号:GSM No-Op 问题示例。(来源:Mirzadeh 等人,GSM-符号论文)
从相关信息和无关信息中辨别出重要内容,尤其是当无关信息被有意插入并带有误导性(即看似相关)时,是一种独立于纯粹数学推理的技能。
作者进行了一项后续实验(NoOp-NoOp),在该实验中,模型在某种程度上“被警告”可能存在误导性意图:他们使用了包含无关信息的少量示例。参与该实验的模型子集仍然表现出性能下降的趋势。若干后续实验可能有助于更好地理解这一现象:
将 NoOp-NoOp 实验扩展到更多模型;
测量当模型在提示中明确被警告某些信息可能是无关时的表现;
在合成训练示例上微调模型,这些示例不仅包含完全相关的信息,还包括一些无关的信息。
改进机会:合成数据的前景
虽然 Mirzadeh 等人的论文突出了当前 LLM 的主要局限性,但在 Gretel,我们已经开发了针对论文中识别的许多挑战的数据集:
- 合成 GSM8K 数据集:可在 HuggingFace 上访问 gretelai/synthetic-gsm8k-reflection-405b,该数据集侧重于生成比 OpenAI 原始人工生成数据集中更复杂的、多步骤推理版本的问题。它结合了先进的提示技术,包括反思和其他认知模型,以捕捉详细的推理过程。该方法已显示出显著的改进,特别是在非常难的问题上,展示了其增强 AI 处理复杂多步骤推理任务能力的潜力。正如我们博客中所述,Gretel 使用这些技术创建的合成数据,在问题复杂性上取得了92.3%的胜率,在教育价值上取得了 82.7%的胜率,超过了标准 Llama 3.1 405B 参数模型的输出,这些技术通过
GPT-4o判断——证明了 LLM 的推理能力可以通过比论文中使用的基本思维链更复杂的训练数据示例和提示技术进一步解锁。

来源: gretel.ai/blog/teaching-ai-to-think-a-new-approach-with-synthetic-data-and-reflection
2. 合成文本到 SQL 数据集:由 Gretel 生成,旨在帮助提高 LLM 与基于 SQL 的数据库/仓库和湖泊的交互能力,数据集可在 gretelai/synthetic_text_to_sql 获取,已证明在提高模型在文本到 SQL 任务上的表现方面非常有效。在用于微调 CodeLlama 模型时, 它在 BIRD 基准测试中提高了 36%以上,这是一个具有挑战性的跨领域文本到 SQL 评估平台。进一步支持关于当前 LLM 训练数据过于简单,导致记忆化的理论,对Phi-3 和 Llama 3.1 模型在此数据集上的单个 epoch 微调,带来了超过 300%的提升,特别是在 BIRD 基准测试中被标记为“非常难”的问题上。
这些结果表明,高质量的合成数据可以成为解决当前 LLMs 在复杂推理任务中局限性的重要工具。
未来方向
总结来说,GSM-Symbolic 论文提供了对当前大语言模型(LLMs)在数学推理任务中的局限性的重要见解。然而,其结论应以批判的眼光来看待。观察到的 LLM 行为可以有多种解释,论文强调模式匹配而非推理,可能在某种程度上简化了更复杂的问题。
研究所识别的局限性是现实且显著的。表现的差异性、对数值变化的敏感性以及对无关信息的处理困难,都指向当前 LLMs 需要改进的领域。
然而,正如更先进的模型(如 GPT-4o 和 Llama 3.1 所展示的那样)所表明的,通过合成多样化且具有挑战性的问题集,推动 AI 模型能够应对的边界,我们可以开发出具备更强大、更类似人类推理能力的 LLMs。
参考文献
-
- Mirzadeh, K. Alizadeh, H. Shahrokhi, O. Tuzel, S. Bengio, 和 M. Farajtabar。GSM-Symbolic: 了解大语言模型在数学推理中的局限性 2024。
引导大型语言模型(LLM)的响应以创建结构化输出
学习如何构建语言模型的响应,确保响应格式被遵守,如 JSON
·发布于 Towards Data Science ·阅读时间 9 分钟·2024 年 6 月 21 日
--

图片来源:Ricardo Gomez Angel 于 Unsplash
本文将教您如何使用 Python 中的验证库来构建 LLM 响应,如 GPT-4 或 Llama 3。
这是一个非常相关的话题,因为从无结构格式(如自由文本)中提取精确信息,转化为 JSON 格式的数据,已被证明在数据挖掘任务中至关重要。
此外,由于大型语言模型(LLM)生成输出令牌的随机特性,即使是在最商业化的系统中,如 GPT,结构化响应格式也不可靠。
我们将使用多个库,如 Pydantic 和 Instructor 进行验证和模式建模,OpenAI 和 ollama 用于 LLM 部分。所提出的内容适用于封闭源模型,如 OpenAI 或 Anthropic 的 GPT,以及开源模型,如 Llama 3。
阅读本文后,您将学习到:
什么是数据模型,以及如何定义它
如何确保您的LLM 通过验证规则遵守输出格式
使用 GloVe 嵌入破解《代号》
使用基于 GloVe 嵌入的算法,在流行派对游戏《代号》中实现 100%的准确性。
·发表于数据科学之路·阅读时间 7 分钟·2024 年 7 月 16 日
--
介绍
代号是一个流行的派对游戏,适合两队,每队由两名玩家组成,每队包括一个间谍主人和一个操作员。每队在游戏板上会被分配一定数量的词卡。在每轮游戏中,间谍主人给出一个词线索,并说明该线索对应的词卡数量。操作员需要猜测哪些词卡属于他/她的队伍。目标是间谍主人给出好的线索,使得操作员能够用较少的回合猜出所有的词卡,且在对方队伍之前完成。此外,还有一张“刺客”卡,操作员一旦猜中该卡,队伍将立即输掉游戏。
在这个项目中,我们将使用一种简单的词向量算法,利用机器学习中预训练的词向量,尽可能少的尝试次数来最大化我们解游戏的准确性。
以下是游戏板的示例:

间谍主人视图(图片来源:作者)

操作员视图(图片来源:作者)
在间谍主人的卡片安排中,每张卡片的颜色代表红队、蓝队、中立(米色)和刺客卡(黑色)。
自动化间谍主人和操作员
我们将创建一个算法,它可以扮演间谍主人和操作员的角色,自动进行游戏。在一个包含 25 张卡片的游戏板上,会有 9 张“好”卡和 16 张“坏”卡(其中包括 1 张“刺客”卡)。
使用 GloVe 嵌入表示意义
为了使间谍头目能够给操作员提供好的线索,我们的模型需要能够理解每个词的含义。表示词义的一种流行方法是通过词嵌入。在这个任务中,我们将使用预训练的 GloVe 嵌入,其中每个词由一个 100 维的向量表示。
然后,我们使用余弦相似度来计算两个词之间的相似性,余弦相似度是两个向量的点积除以它们的模长:

图片来源:作者
操作员:解码算法
在每轮中,操作员接收到一个线索 c,以及一个表示要猜测的词数的整数 n。换句话说,操作员必须解码一个 {c, n} 对,并一次选择 n 个词,直到选择到错误词并结束该轮。
我们的解码器是一个直接贪婪算法:简单地根据与线索词 c 的余弦相似度对剩余的所有词进行排序,并根据相似度得分选择前 n 个词。
间谍头目:编码算法
在每轮中,基于剩余的“好”词和“坏”词,间谍头目必须选择 n 个词,并决定一个线索 c 给予操作员。我们在这里做出的假设是,间谍头目和操作员一致同意上面提到的解码策略,因此操作员会选择最优的 {c, n} 来最大化被选择的正确词的数量。
在这一点上,我们可以观察到线索 c 是一个信息瓶颈,因为它必须将 n 个词总结为一个单一的词 c 供操作员解码。编码后的线索的词向量与每个原始词向量位于同一个向量空间中。

编码器-解码器系统的机制(图片来源:作者)
生成线索候选词
词嵌入具有的特性使我们能够通过不同词向量的加法和减法来表示复合意义。给定“好”词集 G 和“坏”词集 B,我们可以利用这个特性,通过计算标准化均值来获得“好”词的“平均”意义,参考“坏”词,其中“好”词向量相加,“坏”词向量相减。这个平均向量使我们能够生成线索候选词:
glove = api.load("glove-wiki-gigaword-100")
good_words = [g_1, g_2, ..., g_n]
bad_words = [b_1, b_2, ..., b_n]
candidates = glove.most_similar(positive=good_words,negative=bad_words,topn=20)
负采样
由于“坏”词的数量通常超过“好”词的数量,我们进行负采样,通过随机采样与“好”词数量相等的“坏”词来计算我们的平均词向量。这还增加了生成线索时的随机性,从而提高了线索候选词的多样性。
在我们找到平均词向量后,我们使用 Gensim 中的 most_similar() 函数,根据余弦相似度从整个 GloVe 词汇表中获得与平均词向量最接近的单词。
评分函数
现在,我们有一种方法可以根据给定的 n 个单词生成线索候选 c。然而,我们仍然需要决定选择哪个候选 c,选择哪些 n 个单词,以及如何确定 n。
接下来,我们从剩余的 k 个“好”单词中,按照 k, k-1, …, 1 的顺序生成所有可能的单词组合,并为每个组合生成相应的线索词候选 c,从 k 开始向回推。为了选择最佳的 {c, n},我们通过解码算法对剩余的“好”单词的每个可能组合中的所有候选进行评分。然后,我们通过知道操作员将使用的策略,获得给定线索 c 时正确猜出的最大单词数量 count(c)。

图片由作者提供
结果
在每局游戏中,从 一份包含 400 个常见 Codenames 单词的列表 中抽取 25 个单词。总体而言,经过 100 次实验,我们的方法能够 100%正确选择单词,在平均 1.98 回合内完成游戏,或者每回合 4.55 次猜测(对于 9 个正确单词),最多需要 2 回合。

游戏进行时每回合的平均猜测次数(图片由作者提供)
换句话说,除少数例外情况外,这个算法几乎每局游戏都需要两回合,只有在一次回合中猜出所有单词时才会有例外。
让我们查看一下我们所做的线索和猜测的词嵌入分布示例。

1 局游戏的词嵌入散点图,经过 PCA 降维(图片由作者提供)
尽管生成的线索确实在某种程度上提供了操作员最终正确猜出的单词的语义摘要,但这些线索与猜测之间的关系从人类的角度来看可能并不明显。使线索更易于解释的一种方法是限制每回合最大猜测次数,这样可以生成更接近猜测语义的线索。
尽管如此,我们的算法仍然促进了每个单词的良好聚类结果,以便我们的解码器能够通过提供接近目标单词的良好线索,帮助它正确地猜出更多的单词。
结论
总结来说,这种基于贪婪的 GloVe 算法在 Codenames 游戏中作为间谍首领和操作员都表现得很好,它提供了一种有效的方式通过线索和数字来编码和解码单词。
在我们的模型中,编码器和解码器共享一个共同的策略,这个策略的工作原理与共享加密密钥类似。一个可能的限制是,编码器和解码器分开使用时效果不佳,因为人类玩家可能无法有效地解读生成的线索。
理解词嵌入和向量操作背后的机制是开始自然语言处理的一个极好方式。看到简单的方法也能在语义聚类和分类任务中表现出色,确实很有趣。为了进一步提升游戏体验,可以考虑加入强化学习元素或训练自编码器以获得更好的结果。
Github 仓库: github.com/jzh001/codenames
参考文献
Koyyalagunta, D., Sun, A., Draelos, R. L., & Rudin, C. (2021). 使用语言图和词嵌入玩 Codenames。人工智能研究期刊,71,319–346.
doi.org/10.1613/jair.1.12665Pennington, J., Socher, R., & Manning, C. (2014). Glove:用于词表示的全局向量。2014 年自然语言处理经验方法会议(EMNLP)论文集。
doi.org/10.3115/v1/d14-1162Li, Y., Yan, X., & Shaw, C. (2022). Codenames AI
xueweiyan.github.io/codenames-ai-website/Friedman, D., & Panigrahi, A. (2021). Codenames 算法
www.cs.princeton.edu/~smattw/Teaching/521FA21/FinalProjectReports/FriedmanPanigrahi.pdfJaramillo, C., Charity, M., Canaan, R., & Togelius, J. (2020). 词自机器人:使用变换器进行 Codenames 游戏中的词关联。人工智能与互动数字娱乐会议论文集,16(1),231–237.
doi.org/10.1609/aiide.v16i1.7435
使用 SQL 在几分钟内处理数十亿条记录 ⏱️
直接在内存中分析海量数据集——比以往更快
·发表于Towards Data Science ·阅读时长 25 分钟·2024 年 12 月 23 日
--

照片由Daniele Franchi提供,来自Unsplash
我分享这个友链,供非会员跟进。如果可以,请尽可能多地点赞 👏
🚨 文章友链
概览
在这个项目中,我们将通过直接将海量数据集加载到内存中,进行处理,从而实现比传统方法更快的分析。
通过利用内存处理,我们可以高效地处理大量数据,快速而有效地提取有意义的洞察。
虽然地理空间分析是该项目的核心焦点,但项目的主要目标是:
掌握大规模数据处理
应用基于 SQL 的分析技术
探索用于内存处理的高性能工具
这个项目通过使用地理空间数据作为实际案例,整合了这些概念。
处理推荐系统中的反馈循环 — 深度贝叶斯赌博算法
理解探索基本原理和深度贝叶斯赌博算法,解决推荐系统中的反馈循环问题
·发表于数据科学前沿 ·阅读时长 11 分钟·2024 年 7 月 31 日
--

图像来源:ChatGPT-4o
介绍
推荐系统模型通常会被训练来优化用户的参与度,如点击和购买。这样做的出发点是希望优先推荐用户以前参与过的项目。然而,这会形成一个反馈循环,随着时间的推移,可能会导致“冷启动问题”。简而言之,用户历史上受欢迎的项目倾向于继续被模型青睐。相比之下,新颖但高度相关的项目却没有得到足够的曝光。在本文中,我将从基础开始介绍探索技术,并最终解释深度贝叶斯赌博算法——这是一种在论文中由郭大林等人描述的高效算法[1]。
广告推荐系统
让我们以一个简单的广告推荐系统作为示例,贯穿本文。

一个简单的三组件广告推荐系统。图像来源:作者
这是一个三组件系统
检索:一个高效地检索候选广告以供排序的组件
排序:一个深度神经网络,用于预测给定用户的广告点击率(CTR),作为广告的评分
score = predict_ctr(user features, ad features)拍卖:一个组件,负责
为用户检索候选广告
使用排序模型对其进行评分
选择得分最高的广告并返回它*
本文的重点将专注于排序模型。
真实世界的拍卖系统也会考虑广告的竞标金额,但为了简化问题我们忽略这一点。
排名模型架构
排名模型是一个深度神经网络,给定用户和广告特征后,它可以预测广告的点击率(CTR)。为了简化起见,我在下面提出了一个简单的全连接 DNN,但实际上你可以使用诸如宽和深网络(wide-and-deep)、DCN 和 DeepFM 等技术来丰富该模型,而不会影响我在本文中解释的方法的适用性。

一个二分类器深度神经网络,预测 pCTR。图像由作者提供
训练数据
排名模型是基于包含点击数据(二进制标签)和用户与广告特征的拼接数据进行训练的。使用的特征集的具体内容对本文并不重要,但我假设其中包含了一些与广告商品牌相关的特征,以帮助模型学习用户对品牌的偏好。

包含样本特征的训练数据。图像由作者提供
冷启动问题
想象一下,我们成功地在广告点击数据集上训练了我们的排名模型,模型学会了我们的用户 Jane 喜欢购买“Vogue Voyage”品牌的包包。但市场上有一家新的包包公司“Radiant Clutch”,他们也卖着很棒的包包。然而,尽管“Radiant Clutch”进行了广告活动,旨在吸引像 Jane 这样的用户,Jane 从未看到过他们的广告。这是因为我们的排名模型已经深深地学会了 Jane 喜欢“Vogue Voyage”品牌的包包,只有他们的广告会展示给她。她有时会点击这些广告,而当模型在这些新点击上进一步训练时,只会加深模型的信念。这样就形成了一个恶性循环,导致某些商品一直处于黑暗中。

反馈循环的作用,导致冷启动问题:来自 Radiant Clutch 的包包毫无机会。图像由作者提供,缩略图由 ChatGPT-4o 生成
如果我们深思一下,就会意识到,模型并没有错,它通过学习得出“Jane 喜欢‘Vogue Voyage’品牌的包包”这一结论并没有问题。但问题在于,模型没有机会学习 Jane 对其他公司包包的兴趣。
探索与利用
这是一个引入“探索与利用”权衡的好时机。
利用:在广告拍卖过程中,一旦我们从排名模型中获得了点击率预测(CTR),我们就简单地选择得分最高的广告。这是一种 100%的利用策略,因为我们完全基于当前最佳的知识来获取最大的即时奖励。
探索:我们的方法缺乏的是愿意冒一些风险,即使广告未被赋予最高得分,也去展示该广告。如果我们这样做,用户可能会点击它,当模型根据这些数据更新时,它会学到一些新的信息。但如果我们从不冒这个险,模型将永远无法学到任何新东西。这就是探索的动机所在。
探索与开发是一个平衡的过程。探索过少会导致冷启动问题,而探索过多则可能向用户展示高度不相关的广告,从而失去用户信任和收入。
探索技术
现在我们已经为探索设定了舞台,接下来让我们深入探讨一些控制探索的具体技术。
ε-贪婪策略
这里的想法很简单。在我们的拍卖服务中,当我们拥有所有候选广告的得分时,除了直接选择得分最高的广告外,我们还会做以下操作:
在 0, 1)区间内选择一个随机数 r
如果 r < ε,则从候选广告中随机选择一则广告(探索)
否则,选择得分最高的广告(开发)
其中,ε是一个常量,我们在[0, 1)区间内仔细选择,知道该算法会以ε的概率进行探索,以 1 — ε的概率进行开发。

以 1 — ε的概率进行开发:选择点击率(CTR)最高的广告。图片来自作者
这是一种非常简单但强大的技术。然而,它可能过于天真,因为在进行探索时,它完全随机地选择广告。即使某个广告的预测点击率(pCTR)极低,且用户过去多次表示不喜欢,我们仍然可能展示该广告。这可能有些严苛,并可能导致收入和用户信任的严重损失。我们当然可以做得更好。
上置信界限(UCB)
我们进行探索的动机是确保所有广告候选项都有机会展示给用户。但随着广告的曝光,如果用户仍然没有与之互动,那么减少未来的曝光将变得明智。因此,我们需要一种机制,依据广告的得分估算以及广告已接收的曝光量来选择广告。
假设我们的排名模型不仅能预测点击率(CTR)得分,还能给出该得分的置信区间*。
如何实现这一点将在文章后面进行解释

模型会预测一个置信区间,并给出相应的得分。图片来自作者
这样的置信区间通常与广告接收到的曝光量成反比,因为广告展示给用户的次数越多,我们获得的用户反馈就越多,这减少了不确定性区间。

对广告的增加曝光会导致模型得分预测的置信区间减少。图片来自作者
在拍卖过程中,我们并不是选择具有最大 pCTR 的广告,而是选择具有最高上置信界限的广告。这个方法叫做 UCB。其背后的理念是“在不确定面前保持乐观”。这种方法有效地考虑了广告的分数估计以及它的不确定性。

UCB 实际应用:Ad-1 最初因其较大的置信区间赢得了拍卖,但随着模型对其的学习,其 UCB 降低,导致 Ad-2 赢得了拍卖。图像来自作者
汤普森采样
UCB 方法采用了“(完全的)在不确定面前保持乐观”的理念。汤普森采样则在此基础上稍微缓和了这种乐观。与其使用上置信界限作为广告的分数,为什么不在后验分布中抽样一个分数呢?
为了使这成为可能,假设我们的排名模型不仅能生成 CTR 和置信区间,还能生成一个实际的分数分布。
如何实现这一点将在文章后面解释

模型可以预测一个广告的分数分布。图像来自作者
然后,我们从这个分布中抽样一个分数,并在拍卖中使用这个分数。

Ad-1 因从其宽分布中抽样得到一个高分而赢得了拍卖。图像来自作者

Ad-1 已经获得曝光,模型对其的不确定性较小。由于 Ad-2 具有更高的分数分布质量,它赢得了拍卖。图像来自作者

随着 Ad-2 获得更多曝光,其分数分布的标准差进一步缩小。图像来自作者
更新模型的重要性
为了使 UCB 和汤普森采样技术生效,我们必须尽可能频繁地更新模型。只有这样,它才能根据用户反馈更新其不确定性估计。理想的设置是一个连续学习设置,其中用户反馈事件几乎实时地发送给模型,以更新其权重。然而,如果建立连续学习基础设施的成本过高,定期更新模型权重也是一种可行的选择。

一个高层次的连续学习设置,利用流媒体基础设施。图像来自作者,缩略图由 ChatGPT-4o 生成
后验近似技术
在 UCB 和汤普森采样方法中,我解释了我们的模型不仅产生一个分数,还会产生一个不确定性度量(可以是置信区间或分数分布)。这怎么可能呢?毕竟我们的 DNN 只能产生一个输出!下面是论文中讨论的几种方法。
自助法
统计学中的引导法简单来说就是带替换的抽样。这对我们来说意味着,我们在训练数据集上应用引导法,创建多个彼此密切相关但略有不同的数据集,并使用每个数据集训练一个独立的模型。这样学到的模型将会是彼此的轻微变体。如果你学习过决策树和集成法,你就会熟悉训练多个相关的树,这些树是彼此的轻微变体的这一思想。

引导数据集用于训练独立的模型,从而产生一个分数分布。图像由作者提供
在拍卖中,对于每个广告,我们从每个引导模型中得到一个分数。这为我们提供了一个分数分布,这正是我们所需要的用于汤普森采样。我们还可以从这个分布中提取置信区间,如果我们选择使用上置信边界(UCB)。
这种方法的最大缺点是训练和服务多个模型所带来的计算和维护开销。
多头引导法
为了降低多个引导模型的成本,这种方法将多个模型统一成一个多头模型,每个输出对应一个头部。

多头模型。图像由作者提供
关键的成本减少来自于所有层,除了最后一层,都是共享的这一事实。
训练仍然在引导数据子集上按常规进行。虽然每个引导数据子集应该用来更新所有共享层的权重,但必须注意仅使用一部分数据来更新一个输出头部的权重。

每个引导数据子集在反向传播过程中对一个头部的有限影响。图像由作者提供
随机梯度下降法(SGD)
我们可以不使用独立的引导数据集来训练不同的模型,而是只使用一个数据集,通过随机初始化权重训练每个模型,从而利用 SGD 提供的固有随机性。这样训练出的每个模型都会成为其他模型的一个变体。
多头 SGD
同样地,使用多头架构将使用引导法训练的模型数减少到一个,我们也可以使用多头架构与 SGD。我们只需要在每个头部随机初始化权重,这样在整个数据集上训练时,每个头部将学到是其他头部的轻微变体。
前向传播中的丢弃法
Dropout(丢弃法)是一种著名的正则化技术,在模型训练过程中,某一层的部分节点会随机丢弃,以防止过拟合。在这里,我们借用了这一思想,只不过我们是在前向传播过程中使用它来创造受控的随机性。
我们修改了排名模型的最后一层,加入了 dropout。然后,当我们想要为一个广告打分时,我们将它通过模型多次,每次得到的分数都会略有不同,这是由于 dropout 引入的随机性。这为我们提供了我们所需的分布和置信区间。

相同的模型通过随机 dropout 生成分数分布。图片由作者提供
这种方法的一个显著缺点是,它需要多次完整的前向传递,这在推理时可能会非常耗时。
混合方法
在混合方法中,我们进行了一项关键优化,既能够利用 dropout 和自助法的优势,又能降低服务和训练成本:
- 通过将 dropout 应用于倒数第二层,我们无需多次完整前向传递即可生成分数分布。我们可以执行一次前向传递,直到 dropout 层,然后并行地执行多个 dropout 层的调用。这为我们提供了与多头模型相同的效果,其中每个 dropout 输出像一个多头输出。
此外,dropout 通过随机停用一个或多个节点,充当其层次中高阶特征的伯努利掩码,从而产生类似自助法的效果,使用数据集的不同子集。
哪种方法最有效?
不幸的是,没有简单的答案。最好的方法是在你问题的限制下进行实验,看看什么最有效。但如果参考《Deep Bayesian Bandits》论文的作者的发现,
ε-贪婪方法由于其简单的探索方式,给出了最低的 CTR 提升,然而,它的简单性和低成本特性使得它非常有吸引力。
UCB 通常优于汤普森采样。
自助法 UCB 提供了最高的 CTR 回报,但由于需要处理多个模型,它也是计算成本最高的。
依赖于倒数第二层 dropout 的混合模型需要更多的训练轮次才能表现良好,并且与 SGD UCB 的性能相当,但计算成本较低。
模型的 PrAuc 在离线测量时与 CTR 增益呈反比关系:这是一个重要的观察结果,表明离线表现可以通过提供更容易的训练数据(例如,不包含显著探索的数据)轻松实现,但这并不总能转化为在线 CTR 的提升。这强调了稳健在线测试的重要性。
也就是说,针对不同的数据集和问题,结果可能会有很大不同。因此,现实世界中的实验仍然至关重要。
结论
在本文中,我介绍了推荐系统中由反馈循环引发的冷启动问题。根据《深度贝叶斯赌博机》一文,我们将广告推荐系统框架化为 k 臂赌博机,并看到了强化学习技术在缓解冷启动问题方面的许多实际应用。我们还初步探讨了在神经网络中捕捉不确定性的问题,这为贝叶斯网络的讨论提供了良好的过渡。
[1] Guo, Dalin, et al. “深度贝叶斯赌博机:探索在线个性化推荐。” 第 14 届 ACM 推荐系统会议论文集。2020 年。
处理时间序列中的缺口
短序列和长序列插补的缺失性分析与评估方法
·发表于Towards Data Science ·阅读时间 22 分钟·2024 年 1 月 31 日
--

由Willian Justen de Vasconcellos拍摄,图片来源于Unsplash
时间是物理学中最明确的连续体,因此在自然界中也如此。正因如此,时间序列数据集中的连续性——即一系列按时间顺序排列的观察——显得尤为重要。
这个概念本身驱动了本文的动机。现实世界中的数据集由于各种原因容易出现缺失值,比如传感器故障、数据采集失败,或是在某个特定时间段内信息的缺失。然而,这并不会改变你特征的数据生成过程的基本性质。
因此,理解那些中断的原因,并在时间序列数据集中进行分析与处理,对于后续任务至关重要。
目录
本文目标
数据集描述
库与依赖项
数据预处理
孤立缺失值与连续缺失值
缺失可视化
人为缺失数据
维度建模中的层次结构处理
对于层次结构,有多种建模技术可供选择。那么,哪种技术在数据仓库的维度建模中表现最佳呢?如何使用这些技术处理不同类型的层次结构?让我们一探究竟。
·发表于 Towards Data Science ·14 分钟阅读·2024 年 7 月 24 日
--
层次结构在数据仓库的维度建模中起着至关重要的作用,它们影响着数据分析的结构和效率。通过我在为多家公司实施数据解决方案的经验,本文探讨了处理不同类型层次结构的最佳实践和技术。通过详细的示例和实用指南,我将带领大家应对处理各种层次结构的复杂性,确保数据仓库设计的稳健性和可扩展性。
在处理层次结构时,重要的是要认识到它们的具体性以及所有相关的细微差别。因此,在深入探讨建模技术之前,让我们先看看在实际层次结构场景中可能遇到的各种问题。本文中的示例是虚构的,但灵感来源于我为一家全球制药公司实施项目时的实际案例。尽管这些示例被大大简化,但它们仍然展示了数据建模中有趣的方面。
示例层次结构概述
让我们考虑以下示例层次结构:某公司的内部组织结构…
实践操作:使用人工智能构建虚拟房地产顾问
这就是我如何使用房地产数据,并利用 OpenAI 的大型语言模型 GPT-3 为其提供动力。
·发表于 Towards Data Science ·阅读时间 11 分钟·2024 年 4 月 6 日
--

图片由作者使用 DALL·E3 制作
这篇文章以一个个人故事开始。
我来自意大利,目前已在美国生活了 4 年。
我从罗马大学毕业后直接搬到了这里,当时我 23 岁。我有很多梦想,极大的热情(对人工智能的热情),还有一份在辛辛那提大学攻读博士学位的计划。
我的大部分人生都在意大利度过,吃着美食,喝着美味的咖啡,坐在交通中度过漫长的时间(几乎去哪里都得堵车)🙃。然而,在美国,我“长大成人”(这是我妻子教我的词)做了很多在意大利从未单独做过的事情,因为我的父母远在大洋彼岸。有一件我和妻子第一次一起做的事情就是找房子。
1. 关于找房子
对于我的非美国粉丝来说,“找房子”就是简单的寻找理想住宅。之所以用“狩猎”这个词,是因为有太多的房子、太多的中介和网站可以查询,情况远不止...
使用马尔可夫链进行职业路径建模,配合 Python 实践
这就是我如何使用基础概率模拟职业发展的方式
·发表于Towards Data Science ·阅读时间 12 分钟·2024 年 7 月 29 日
--

图像由作者使用 DALL·E 生成
从专业角度来说,我是一个非常奇怪的人:我在一家初创公司担任软件/机器学习工程师,我拥有物理学硕士学位,并且即将为我的航空航天与机械工程博士论文进行答辩。在我不断变化的职业生涯中,有两件事始终不变:我对科学的热爱和对编程的激情。
将科学与编程结合的一个美妙方式就是进行建模。我的意思是,为了描述世界,你基于对现实的某种程度的近似做出合理的假设。基于这个假设和初始的近似,我们可以模拟一个给定的过程。这个模拟将给我们一些结果,这些结果来源于最初的假设,但在模拟之前是无法精确预测的。
举个例子,假设我们正在试图弄清楚多少头牛可以放进一个围栏。一个物理学家可能会做出一个相当离奇的假设:
“让我们考虑一头正方形的牛”
这意味着我们近似将一头牛的形状看作一个正方形。然后我们将围栏近似为一个更大的正方形……
使用 Google Mesop 进行数据可视化实战
Google Mesop 是一个易于使用的 Python UI 框架。我们将展示如何利用它与 Plotly 一起创建数据可视化应用。
·发布于 Towards Data Science ·阅读时间 14 分钟·2024 年 10 月 25 日
--

通过将 Mesop 简单而强大的 UI 组件与 Plotly 强大的图表功能相结合,你可以构建一个动态、互动的应用,展示 Python 中的数据可视化效果。
Mesop 是一个原生 Python 框架,旨在快速开发 AI 和 Web 应用,它允许你在无需传统前端技能的情况下构建复杂的界面;Plotly,当然,是一个广为人知且被广泛使用的图形库。
我们将使用这些组件来构建一个数据可视化应用。
Mesop
Mesop 不是 Google 官方支持的产品,但它文档齐全,在 Google 内部广泛使用,并且正受到相当的关注,因此我认为我们可以认真对待它。Mesop 是一个开源项目,使用 Apache-2.0 许可协议 [1]。
这是一个在 Mesop 中实现的“Hello World”应用。
import mesop as me
@me.page(path="/")
def app():
me.text("Hello World")
Mesop 的 Hello World
使用 Python 进行全球优化方法实践
四种方法用来寻找你黑箱目标函数的最大值(或最小值)。
·发布于 Towards Data Science ·阅读时间 13 分钟·2024 年 9 月 4 日
--

图片由作者使用 DALL·E 制作
想象一下和你的好朋友一起出去。你决定去一个你从未去过的地方,但你的好朋友曾经去过。
还要想象一下,你们所在的城市有些交通问题,比如罗马(我知道,因为我就是那里的人,真的很痛苦),并且你和你的好朋友会在那里见面,但是你们从同一座房子出发,各自开着两辆车。
你和你的好朋友同时离开家,由于你从未去过那个地方,你使用GPS来到达那里。
当你到达那里时,你发现你的朋友已经在点饮料,并告诉你他已经等你 10 分钟了。假设你和你的朋友都没有超速,这个故事告诉我们两件事:
你应该生气于你的好朋友,因为他本可以告诉你更快的路线,而不是让你使用 GPS。为什么他当时没有打电话给你呢? 😡
你刚刚遇到了一个局部最小值问题。
在这篇文章中,我们将专注于问题二,其余的就交给你自己去处理你和你朋友的事情吧。
实践模仿学习:从行为克隆到多模态模仿学习
对模仿学习中最突出的几种方法进行概述,并在网格环境中进行测试
·发表于 Towards Data Science ·14 分钟阅读·2024 年 9 月 12 日
--

图片由 Possessed Photography 提供,来源于 Unsplash
强化学习是机器学习的一个分支,关注通过标量信号(奖励)的指导进行学习;与需要目标变量完整标签的监督学习不同。
可以通过一个直观的例子来解释强化学习,例子描述的是一所学校的两门课,每门课有两种类型的测试,且测试不断重复进行。第一班解答测试并得到了所有正确答案(监督学习:SL)。第二班解答测试并且每个问题只得到了分数(强化学习:RL)。在第一种情况下,学生们似乎更容易学习正确答案并记住它们。在第二班中,任务更困难,因为他们只能通过反复试验来学习。然而,他们的学习将更加稳健,因为他们不仅知道什么是对的,还知道所有错误的答案,从而避免它们。
为了在强化学习(RL)中高效学习,应该设计一个准确的奖励信号(评分),这被认为是一个困难的任务,尤其是在现实世界的应用中。例如,一个人类专家驾驶员知道如何开车,但无法为“正确驾驶”技能设置奖励,做饭或绘画也是如此。这就产生了对模仿学习方法(IL)的需求。IL 是强化学习的一个新分支,专注于从纯粹的专家轨迹中学习,而不需要知道奖励。IL 的主要应用领域是在机器人技术和自动驾驶领域。
接下来,我们将探索文献中最著名的 IL 方法,按提议时间从旧到新排序,如下图所示的时间线。

IL 方法的时间线
数学公式将与符号的命名一同展示。然而,理论推导在这里保持到最小;如果需要进一步的深度,可以查阅引用部分列出的原始文献。重现所有实验的完整代码已提供在随附的github 仓库中。
所以,系好安全带!让我们通过模仿学习深入探索,从行为克隆(BC)到信息最大化生成对抗模仿学习(InfoGAIL)。
示例环境
本文使用的环境表示为一个 15x15 的网格。环境状态如下所示:
代理:红色
初始代理位置:蓝色
墙壁:绿色

代理的目标是通过任何三个窗口,以最短的方式到达第一行,并使其位于相对于通过网格中心的垂直轴的初始位置的对称位置。目标位置不会在状态网格中显示。
因此,初始位置只有 15 种可能性,目标位置会基于此发生变化。
动作空间
动作空间 A 包含从 0 到 4 的离散数值,表示四个方向的移动和停止动作,如下所示:

奖励函数
这里的真实奖励 R(s,a) 是当前状态和动作的函数,其值等于朝目标移动的位移距离:
其中,𝑝1是旧位置,p2是新位置。代理总是从最后一行初始化,但每次位置都是随机的。
专家策略
所有方法(除 InfoGAIL 外)使用的专家策略旨在以最短的路径到达目标。这涉及三个步骤:
朝最近的窗口移动
直接朝目标移动
停止在目标位置
该行为由一个 GIF 演示:

专家策略生成用于其他强化学习方法的演示轨迹。每条轨迹τ表示为一系列有序的状态-动作对。
专家演示集定义为 D={τ0,⋯,τn}
专家的每集回报平均为 16.33±6,共 30 集,每集长度为 32 步。
前向强化学习
首先,我们将使用真实奖励训练一个模型,设定一些基准并调优超参数,以便后续与 IL 方法一起使用。
本文中使用的前向强化学习算法的实现基于 Clean RL 脚本[12],该脚本提供了可读的强化学习方法实现。
介绍
我们将分别测试最先进的策略内方法——近端策略优化(PPO)[2]和著名的策略外方法——深度 Q 网络(DQN)[1]。
以下是每种方法的训练步骤总结以及它们的特点:
在策略内(PPO)
该方法使用当前训练中的策略,并在每次收集回合后更新其参数。PPO 包含两个主要部分:评论员和演员。演员表示策略,而评论员为每个状态提供价值估计,并具有自己的更新目标。
策略外(DQN)
DQN 通过收集回放缓冲区中的回合来离线训练其策略,采用ε-贪婪探索策略。这意味着 DQN 在每个状态下并不总是选择当前策略下的最佳动作,而是随机选择一个动作。这有助于探索不同的解决方案。可能还会使用一个目标网络,该网络以较低的频率更新策略版本,以使学习目标更加稳定。
结果与讨论
下图显示了两种方法的每集回报曲线。DQN 用黑色表示,而 PPO 用橙色线条表示。

对于这个简单的例子:
PPO 和 DQN 都能收敛,但 PPO 略有优势。两种方法都未达到专家级别的 16.6(PPO 接近,达到 15.26)。
与 PPO 相比,DQN 在收敛速度上似乎较慢,表现为与交互步骤(也称为样本效率)相关的低效率。
PPO 训练时间较长,可能是由于演员-评论员训练,需要更新两个具有不同目标的网络。
训练这两种方法的参数大致相同。欲深入了解这些曲线是如何生成的,可以查看随附仓库中的ppo.py和dqn.py脚本。
行为克隆(BC)
行为克隆(Behavior Cloning,BC),最早在[4]中提出,是一种直接的强化学习方法。它通过监督学习将每个状态映射到一个动作,基于专家演示集 D。目标定义为:
其中π_bc是训练后的策略,π_E是专家策略,l(π_bc(s),π_E(s))是响应同一状态时,专家策略和训练策略之间的损失函数。
BC 和监督学习的区别在于将问题定义为一个交互式环境,在该环境中,行为是对动态状态的响应(例如,一个机器人朝着目标移动)。相比之下,监督学习涉及将输入映射到输出,比如图像分类或温度预测。这个区别在[8]中有解释。
在此实现中,智能体的初始位置全集只有 15 种可能性。因此,只有 15 条轨迹可以学习,而这些轨迹可以被 BC 网络有效地记住。为了增加问题的难度,我们将训练数据集 D 的大小裁剪为一半(仅有 480 个状态-动作对中的 240 个),并将此操作应用于后续所有的 IL 方法。
结果
在训练模型后(如bc.py脚本所示),我们得到的平均回报是 11.49,标准差为 5.24。
这比之前的前向强化学习方法要小得多。以下 GIF 展示了训练好的 BC 模型的实际表现。

从 GIF 中可以看出,几乎三分之二的轨迹已经学会了穿越墙壁。然而,模型在最后三分之一的轨迹上遇到了困难,因为它无法从之前的例子中推断出真实的策略,特别是由于它只从 15 条专家轨迹中的一半进行学习。
最大熵逆强化学习(MaxENT)
MaxEnt [3] 是一种与行为克隆(BC)并行的训练奖励模型的另一种方法(而不是迭代方式)。其主要思想是基于当前的奖励函数,最大化采取专家轨迹的概率。这可以表示为:
其中 N 是轨迹长度,Z 是在给定策略下,所有可能轨迹回报之和的归一化常数。
从这里开始,该方法基于最大熵定理[3]推导出其主要目标,该定理指出:满足给定条件的最具代表性的策略是具有最大熵 H 的策略。因此,MaxEnt 需要一个额外的目标来最大化策略的熵。这导致了以下公式:
其导数为:
其中 SVD 是状态访问频率,可以通过动态规划算法在给定策略下计算得出。
在我们这里实现的最大熵(MaxEnt)中,我们跳过了训练一个新的奖励模型,因为动态规划算法在此过程中会变得缓慢且冗长。相反,我们选择通过重新训练一个行为克隆(BC)模型来测试最大化熵的主要思想,方法与之前的过程完全相同,只是将推断出的行动分布的负熵项添加到损失函数中。熵应该是负的,因为我们希望通过最小化损失来最大化它。
结果
在给动作分布添加了权重为 0.5 的负熵(选择正确的值很重要,否则可能导致更差的学习)后,我们看到与之前的 BC 模型相比,性能略有提升,现在的平均回合回报为 11.56(+0.07)。这种小幅提升可以通过环境的简单性质来解释,该环境包含有限数量的状态。如果状态空间变大,熵的作用预计会变得更加重要。
生成对抗模仿学习(GAIL)
GAIL 的原始工作[5]灵感来源于生成对抗网络(GANs)的概念,GANs 将对抗训练的理念应用于增强主模型的生成能力。同样,在 GAIL 中,这一概念被应用于使训练策略与专家策略之间的状态-动作分布相匹配。
这可以推导为 Kullback-Leibler 散度,如主文献[5]所示。文献最终推导出了两个模型(在 GAIL 中称为生成器和判别器模型)的主要目标,如下所示:

其中,Dt 是判别器,πθ 是生成器模型(即正在训练的策略),πE 是专家策略,H(πθ) 是生成器模型的熵。
判别器充当二分类器,而生成器则是实际的策略模型,正在接受训练。
GAIL 的主要优势
GAIL 相对于以前方法的主要优势(也是其表现更好的原因)在于其交互式训练过程。训练得到的策略在判别器奖励信号的指导下学习并探索不同的状态。
结果
在训练了 1.6 百万步后,GAIL 模型收敛到了比 BC 和 MaxEnt 模型更高的水平。如果继续训练,甚至可以取得更好的结果。
具体来说,我们获得了 12.8 的平均回合奖励,这一点值得注意,因为只有 50%的示范在没有任何真实奖励的情况下提供。
此图显示了 GAIL 的训练曲线(y 轴为真实的回合奖励)。值得注意的是,由于 GAIL 的对抗性训练特性,来自 log(D(s,a))的奖励会比真实奖励更加混乱。

对抗性逆强化学习(AIRL)
GAIL 的一个剩余问题是,训练得到的奖励模型(判别器)并不真正代表真实的奖励。相反,判别器被训练为一个专家与生成器状态-动作对之间的二分类器,导致其值的平均为 0.5。这意味着判别器只能作为一个替代奖励。
为了解决这个问题,文献[6]通过以下公式重新构造了判别器:
其中,fω(s,a) 应该收敛到实际的优势函数。在这个例子中,这个值表示代理与不可见目标的接近程度。通过添加另一个项以包含形状奖励,可以得到地面真实奖励;然而,在本次实验中,我们将限制使用上面的优势函数。
结果
在使用与 GAIL 相同参数训练 AIRL 模型后,我们得到以下训练曲线:

注意,考虑到相同的训练步数(160 万步),由于训练判别器的复杂性增加,AIRL 收敛较慢。然而,现在我们已经得到了一个有意义的优势函数,尽管其表现只有 10.8 的每集奖励,但仍然足够好。
让我们将这个优势函数的值与响应专家演示的地面真实奖励进行比较。为了使这些值更具可比性,我们还对学习到的优势函数 fω 的值进行了归一化。由此,我们得到了以下图表:

在这个图中,有 15 个脉冲对应代理的 15 个初始状态。我们可以看到,在图的后半部分,训练模型的误差较大,这是由于在训练中仅有限使用了专家演示的一半。
对于前半部分,我们观察到在代理静止在目标位置并获得零奖励时,状态较低,而在训练模型中评估时该值较高。后半部分则普遍偏向于较低的值。
大致来说,学习到的函数大致跟随地面真实奖励,并通过 AIRL 恢复了关于它的有用信息。
信息最大化生成对抗模仿学习(InfoGAIL)
尽管前述方法取得了一定进展,但在模仿学习(IL)中仍然存在一个重要问题:多模态学习。为了将 IL 应用于实际问题,需要从多个可能的专家策略中学习。例如,在开车或踢足球时,并没有一种“正确”的做事方式;专家在方法上有所不同,IL 模型应该能够一致地学习这些变化。
为了解决这个问题,开发了 InfoGAIL [7]。InfoGAIL 受到 InfoGAN [11] 的启发,InfoGAN 使用额外的风格向量来调节生成对抗网络(GAN)生成输出的风格,InfoGAIL 在 GAIL 目标的基础上增加了另一个标准:最大化状态-动作对与新的控制输入向量 z 之间的互信息。这个目标可以推导为:
Kullback-Leibler 散度,
其中,后验概率 p(z∣s,a) 的估计通过一个新模型 Q 进行近似,Q 以 (s,a) 为输入,输出 z。
InfoGAIL 的最终目标可以写为:

结果是,策略有一个额外的输入,即 z,如以下图所示:

在我们的实验中,我们生成了新的多模态专家演示,其中每个专家只能从一个间隙进入(墙上有三个间隙),与其目标无关。完整的演示集在没有标签的情况下使用,标签不会指明是哪个专家在行动。z 变量是一个独热编码向量,表示专家类别,包含三个元素(例如,左门的表示为[1 0 0])。策略应该:
学习朝向目标移动
随机生成的 z 值与不同专家的模式相联系(从而通过不同的门)。
Q 模型应该能够根据每个状态中动作的方向来检测其所处的模式。
注意,由于对抗性训练,判别器、Q 模型和策略模型的训练图表是混乱的。
幸运的是,我们成功地清晰地学习到了两种模式。然而,第三种模式既没有被策略也没有被 Q 模型识别出来。以下三个 GIF 展示了 InfoGAIL 在给定不同 z 值时学到的专家模式:

z = [1,0,0]

z = [0,1,0]

z = [0,0,1]
最后,策略能够在 800K 训练步数下收敛到大约 10 的回报。随着训练步数的增加,可以取得更好的结果,即使这个例子中使用的专家不是最优的。
最终概述与结论
在回顾我们的实验时,很明显所有的模仿学习方法在集回报标准方面表现良好。下表总结了它们的表现:

InfoGAIL 的结果不可比,因为专家演示基于多模态专家
表格显示,GAIL 在这个问题上表现最好,而 AIRL 由于其新的奖励公式较慢,导致回报较低。InfoGAIL 也学得不错,但在识别所有三个专家模式时遇到了一些困难。
结论
模仿学习是一个具有挑战性且迷人的领域。我们探索的方法适用于网格仿真环境,但可能无法直接转化为现实应用。模仿学习的实际应用仍处于起步阶段,除了某些 BC 方法之外。将仿真与现实联系起来会由于两者的性质差异而引入新的误差。
模仿学习中的另一个开放挑战是多智能体模仿学习。像 MAIRL [9] 和 MAGAIL [10] 这样的研究已经在多智能体环境中进行过实验,但从多个专家轨迹学习的通用理论仍然是一个开放问题。
附加的GitHub 仓库提供了实现这些方法的基本思路,且可以方便地进行扩展。代码将在未来更新。如果您有兴趣贡献,请提交一个问题或拉取请求,包含您的修改。或者,欢迎留下评论,我们会跟进并提供更新。
注意:除非另有说明,所有图像均由作者生成
参考文献
[1] Mnih, V. (2013). 通过深度强化学习玩 Atari 游戏。arXiv 预印本 arXiv:1312.5602.
[2] Schulman, J., Wolski, F., Dhariwal, P., Radford, A., & Klimov, O. (2017). 近端策略优化算法。arXiv 预印本 arXiv:1707.06347.
[3] Ziebart, B. D., Maas, A. L., Bagnell, J. A., & Dey, A. K. (2008 年 7 月). 最大熵逆向强化学习。载于《AAAI》(Vol. 8, pp. 1433–1438).
[4] Bain, M., & Sammut, C. (1995 年 7 月). 行为克隆框架。载于《机器智能 15》(pp. 103–129).
[5] Ho, J., & Ermon, S. (2016). 生成对抗模仿学习。《神经信息处理系统进展》(Advances in neural information processing systems), 29.
[6] Fu, J., Luo, K., & Levine, S. (2017). 通过对抗逆向强化学习学习稳健奖励。arXiv 预印本 arXiv:1710.11248.
[7] Li, Y., Song, J., & Ermon, S. (2017). Infogail:从视觉示范中进行可解释的模仿学习。《神经信息处理系统进展》(Advances in neural information processing systems), 30.
[8] Osa, T., Pajarinen, J., Neumann, G., Bagnell, J. A., Abbeel, P., & Peters, J. (2018). 以算法视角看模仿学习。《机器人学基础与趋势》(Foundations and Trends® in Robotics), 7(1–2), 1–179.
[9] Yu, L., Song, J., & Ermon, S. (2019 年 5 月). 多智能体对抗逆向强化学习。载于《国际机器学习大会》(International Conference on Machine Learning)(pp. 7194–7201)。PMLR.
[10] Song, J., Ren, H., Sadigh, D., & Ermon, S. (2018). 多智能体生成对抗模仿学习。《神经信息处理系统进展》(Advances in neural information processing systems), 31.
[11] Chen, X., Duan, Y., Houthooft, R., Schulman, J., Sutskever, I., & Abbeel, P. (2016). Infogan:通过信息最大化生成对抗网络进行可解释的表示学习。《神经信息处理系统进展》(Advances in neural information processing systems), 29.
[12] Huang, S., Dossa, R. F. J., Ye, C., Braga, J., Chakraborty, D., Mehta, K., & AraÚjo, J. G. (2022). Cleanrl:高质量的单文件深度强化学习算法实现。《机器学习研究期刊》(Journal of Machine Learning Research), 23(274), 1–18.
使用 XGBoost 进行单调时间序列预测,使用 Python
这是如何在预测场景中使用 XGBoost,从理论到实践
·发表于Towards Data Science ·9 分钟阅读·2024 年 3 月 29 日
--

该图由作者使用 DALL·E-3 制作
几个月前,我在一个研究项目中遇到了一个涉及时间序列的问题。
这个问题相当简单:
“从这个时间序列开始,时间步长为 t,预测接下来的 k 个值”
对于机器学习爱好者来说,这就像写“Hello World”一样,因为这个问题在社区中非常有名,通常称为“预测”。
机器学习社区开发了许多可以用于预测时间序列下一步值的技术。一些传统方法包括像ARIMA/SARIMA这样的算法,或傅里叶变换分析,以及其他更复杂的算法,比如卷积神经网络/循环神经网络或超著名的“Transformer”算法(ChatGPT 中的 T 代表 transformers)。
尽管预测问题是一个众所周知的问题,但讨论带约束的预测问题可能并不常见。
让我解释一下我的意思。
你有一个带有一组参数X和时间步长t的时间序列。
标准时间预测…
《动手实践神经网络与时间序列,使用 Python》
从非常简单的前馈神经网络到壮丽的变换器:你需要了解的一切。
·发表于Towards Data Science ·阅读时间:12 分钟·2024 年 8 月 30 日
--

图片来自Alina Grubnyak在Unsplash上的作品
在我攻读本科学位时,我最喜欢的教授曾告诉我:
一旦某个事物足够好用,没人再会称它为“人工智能”。
这个概念与拉里·泰斯勒的观点一致,他曾说过:“AI 就是尚未实现的事物。”人工智能的第一个例子是计算器,它能够在瞬间完成非常复杂的数学计算,而人类则需要几分钟甚至几小时。然而,当我们今天谈论人工智能时,我们并不会想到计算器。我们不会想到它,因为它的工作方式已经非常出色,大家习以为常。谷歌搜索算法,在许多方面比计算器更复杂,是我们日常生活中使用的人工智能之一,但我们却很少会去思考它。
那么,什么才是真正的“人工智能”?我们什么时候不再定义某个事物为人工智能?
这个问题相当复杂,因为如果我们真正思考,人工智能有多个层次和领域。
使用 Python 进行数值微分实战,从零到英雄
这里是你需要知道的所有内容(超越标准定义),让你掌握数值微分的世界。
·发布于 Towards Data Science ·9 分钟阅读·2024 年 9 月 22 日
--

由 Roman Mager 提供的照片,来源于 Unsplash
有一个传奇性的说法,你可以在每个大学的至少一个实验室找到,它是这样的:
理论是你知道一切,但什么都不起作用。
实践是当一切都能奏效,但没有人知道为什么时。
在这个实验室,我们将理论与实践结合:什么都不起作用,且没有人知道为什么。
我发现这句话在数据科学领域非常贴切。我之所以这么说,是因为数据科学一开始是一个数学问题(理论):你需要最小化一个损失函数。然而,当你进入**现实生活(实验/实验室)**时,事情开始变得非常混乱,你完美的理论世界假设可能不再适用(它们从未奏效过),而你也不知道为什么。
比如,拿导数这个概念来说。每个接触过数据科学复杂概念的人都知道(或者,更好的是,必须知道)导数是什么。但是,如何将导数这一优雅且理论化的概念应用到现实生活中,应用到一个噪声信号上,在这种情况下你没有分析式…
《使用遗传算法进行动手优化,Python 实现》
这是关于遗传算法的完整指南,介绍了它们是什么以及如何使用它们
·发表于 Towards Data Science ·阅读时间 13 分钟·2024 年 9 月 29 日
--

图片来源:digitale.de 于Unsplash
你听说过这种销售策略吗?
“你是否厌倦了在 X 上浪费时间?为什么不试试 Y 呢!”
我敢肯定你听说过。例如:“不要花费几个小时编写自己的代码,使用我们的软件!” 或者“别再浪费几个小时在广告上,改用 AI 吧!”(我的 YouTube 推荐算法非常喜欢这个,可能是我的错,因为我总是谈论 AI)。
这种销售策略基于这样一个事实:你在想要解决的问题上花费的时间是可以优化的。
“从极客角度来说”,当你看到这种销售策略时,通常有两个选择:探索或开发。我是什么意思呢?稍后我会解释。现在,记住这个销售策略,把它放在脑后,稍后我们会再回到它。
假设你有一个工具,一个可以调整的旋钮。通过移动这个旋钮,你可以得到一个等待时间。想象一下,你在星巴克排队等咖啡,旋钮有两个状态(0 和 1):0 可能意味着“待在原来的队伍里”,而 1 可能意味着“换到另一个队伍”。哪一个意味着“等待最少的时间”...
《在 Python 中进行期望改进和高斯过程回归的实操优化》
《在 Python 中进行全局优化期望改进的友好指南》
·发表于Towards Data Science ·11 分钟阅读·2024 年 6 月 9 日
--

图片由作者使用 DALL·E 制作
免责声明:高斯出生于 1777 年,他比我聪明得多。很多人在我之前就写过这些内容,并且做得非常出色。这篇文章的一个很好的参考是这篇,由@okanyenigun写得非常好。
我的 妻子是商科专业,如果你问她:“什么是人工智能?”
她说:
“某些事物是通过教给一堆信息来发展思维和他妈的(她没有说‘他妈的’)…我不知道…想法。”
我会说这与人们描述人工智能的方式相当一致。经过一些调整后,这也离事实不远。
现在,如果你和一位数学家或物理学家交谈,答案会更为技术性,可能还会让人觉得很无聊(我是物理学家,所以我可以这么说)。如果要我定义人工智能,我会说:
“人工智能是一组使用数据来最小化损失函数的算法”
…我告诉过你这会更无聊。那么我这是什么意思呢?
0. 机器学习…
使用自编码器进行时间序列异常检测,Python 实战
这里是如何用几行代码使用自编码器来检测含有异常的信号
·发表于Towards Data Science ·10 分钟阅读·2024 年 8 月 21 日
--

异常时间序列是一个非常严肃的问题。
如果你想到地震,异常就是数据中突然的峰值或下降的地震信号,这些信号暗示着某些不好的事情正在发生。
在金融数据中,大家都记得 1929 年的华尔街崩盘,那是金融领域中一个信号带有异常的明显例子。在工程学中,信号中的尖峰可能代表超声波反射到墙壁或人的情况。
所有这些故事都源于一个非常明确的问题:
如果我有一组正常信号,并且有一个新的信号进入,我该如何检测该信号是否异常?
请注意,这个问题与检测给定信号中的异常的问题略有不同(这个问题也是一个著名的待解决问题)。在这种情况下,我们假设得到一个全新的信号,并且我们想知道该信号是否与我们数据集中被认为是“正常”的信号有显著差异。
实战操作 Moirai:Salesforce 的基础预测模型
了解 Moirai 的架构及其内部工作原理,并将其应用于使用 Python 的预测项目
·发表于数据科学前沿 ·阅读时长 14 分钟·2024 年 8 月 20 日
--

图片由Shreyas Bhosale提供,来源于Unsplash
我们已经进入了一个大规模基础模型普及的时代。基础模型已经在许多领域引发了革命,如计算机视觉和自然语言处理,出现了能够生成文本、图像和视频的模型和应用。
时间序列预测领域同样未能抵挡这一潮流,许多基础模型开始应用于预测。这标志着一个重要的范式转变,因为我们现在可以生成零-shot(零样本)时间序列数据预测,避免了训练数据特定模型的成本和开发时间。
2023 年 10 月,TimeGPT-1 发布,标志着首批基础预测模型的诞生。随后,在 2024 年 2 月,Lag-Llama 发布,并在 2024 年 3 月紧接着发布了 Chronos。
2024 年 5 月,发布了一种新的开源基础预测模型:Moirai。在论文统一训练的通用时间序列预测变换器中,Salesforce 的研究人员提出了一种能够进行概率性零-shot 预测的基础模型,同时还支持外生特征。
亲身体验 OpenAI 的 Swarm 多代理框架
AI 代理编排
使用 Swarm,一个回归基础的框架构建多代理系统
·发布于 Towards Data Science ·14 分钟阅读·2024 年 11 月 5 日
--

图片由作者和 Bing 图像创作工具提供
有一个关于伦敦公交车的笑话:你等了很久才等到一辆,结果突然三辆一起到达。(这个笑话虽然听起来像笑话,但也是真实的)。
代理编排框架开始让人有这种感觉。当然,我们有来自常见厂商的解决方案,比如 LangChain 和 LlamaIndex,但新产品以惊人的频率涌现:CrewAI、MotleyCrew 和 Autogen 是我最近遇到的一些产品,但还有更多框架我甚至没有时间查看 GitHub 上的 README 文件。
现在,来自 OpenAI 这样的大公司推出了Swarm。因为它来自 OpenAI,我觉得应该仔细了解一下。Swarm 从其他产品的复杂性中退后一步。与其将自己呈现为一个功能齐全的解决方案,它更像是一个由 OpenAI 开发的实验性、教育性框架,旨在探索轻量级的多代理编排。看起来它的目标是让你了解工具使用的多代理系统是如何工作的,而不是提供像上述产品那样的生产就绪解决方案。
为分析经理制定全面的仪表板策略
如何解决探索与利用之间的张力,并确保你拥有忠实的用户,他们对你的可视化工具感到满意。
·发布于 Towards Data Science ·阅读时间 6 分钟·2024 年 1 月 29 日
--
开放世界视频游戏的开发者面临一个非常有趣的问题——如何在鼓励玩家探索世界的同时,又能确保他们完成主线任务并遵循一定的规则?
想要通过仪表板帮助自己不太懂的同事的分析经理也面临着同样的问题。他们总是面临“探索”和“利用”之间的张力。一方面,他们需要构建一个足够全面的系统,以便允许探索边缘案例;另一方面,他们不能用过多的信息把用户“淹没”。找到合适的平衡点往往相当复杂。
让我们回到开放世界视频游戏开发者的问题。为了在解决探索与利用之间的张力的同时,使游戏愉快,他们:
(1) 构建一个主要的故事线,并设有替代路径(即“支线任务”)
(2) 包含一个强大的知识管理系统:
他们提供了“冒险日志”,你可以在其中跟踪关键信息
它们通常以教程开始
(3) 激活“在线社区”,为专家提供分享知识的空间
(4) 随时间修补并改进游戏机制
这种策略同样可以被分析经理应用。

将开放世界视频游戏中的学习应用到仪表板的构建中——图片来自作者
你的用户的主线任务是什么?
想象一下:你是一个销售人员,你正在试图了解即将见面的客户在过去 2 年里与公司交易的金额。你尝试在手机上打开一个仪表板——它不支持手机端。你尝试调整屏幕大小,以找到合适的过滤器——你的客户有多个业务地点,所以你需要按这些地点全部筛选。然后,仪表板加载过去 6 个月的数据花了 2 分钟,但你实际上需要的是过去 2 年的数据,而你忘记更新筛选器……
很有可能,在某个时刻,你放弃了,直接联系了你的数据分析师朋友。
在上面的例子中——能够在几分钟内获取他们投资组合的相关数据是销售人员的“标准”需求——意味着这是大多数销售人员在他们的日常流程/工作流中反复需要的内容。它基本上是他们主要任务的一部分。
放眼全局:每个团队都有一套需要定期跟踪的指标,并且这些指标有预定的粒度。对于这些指标,他们需要一些简单易用的仪表板,仪表板中只包含有限数量的指标和过滤器。
一旦他们建立了能有效满足主要需求的仪表板,他们就会开始提出一些功能请求(“啊,如果我们也能看到 X、Y、Z 就好了”)。在你将这些请求添加到现有仪表板之前,重要的是要保持第一性原理的思维:这个仪表板最初是为了满足什么样的用例?
根据经验,最好采取以下方法:
拥有一个集中式仪表板,处理标准用例——该仪表板的指标和过滤器非常有限。
拥有一个(或多个)仪表板来解决边缘情况,同时允许团队进行一定的探索。
然而,要警惕仪表板膨胀,因为这会在未来引发很多问题(从维护到数据冲突)。重要的是要始终跟踪每个仪表板解决了哪些用例——并确保可以合并的部分已经合并。你的“仪表板生态系统”(即你的仪表板套件)只有在你采取整体方法,确保每个工具都有区分并且没有重叠时,才能蓬勃发展。
通过这种方法,你可以解决大约 50%的最常见仪表板问题,而你的仪表板也能够有效地开始补充并简化用户的日常工作(而不是使其更加复杂)。
通过一个强大的知识管理系统来支持你的仪表板生态系统
你新建立的系统只有在大家理解它并开始使用它时才会成功。对于前者,构建一个强大的知识管理系统(KMS)至关重要。这个 KMS 确保每个人都能找到关于你仪表板的相关信息:它们在哪里,它们包含哪些信息,数据变换是如何进行的,等等。一个有效的知识管理系统应该具备:
集中式: 在一个单一的位置
当前: 定期更新最新信息
可搜索性: 让团队成员能够轻松地找到他们需要的内容。
不断改进: 通过强大的反馈循环
可发现性: 人们知道知识管理系统(KMS)的存在/如何找到它
在建立这个知识管理系统的基础上——你还需要开展培训课程。就像开放世界游戏中的教程——在这些教程中,你学习如何使用武器和进行战斗——这里就是向用户解释如何取得成功的地方。以下是一些提示:
尝试多样化的学习形式: 一些人更喜欢由讲师主导的课程,一些人则更喜欢自助式的选择——看看哪种形式最适合你的用户和他们的时间安排。
有结构的入职培训: 为新用户提供入职培训课程
提供实践练习: 做比知道更重要,因此尽量在这些课程中包括动手实践和实际场景。
通过上述措施——你可以确保你的团队不仅配备了强大的仪表板,还具备有效使用它们的知识和培训。
跟踪你的生态系统的使用情况,并改进需要改进的地方
这一步常常被忽视,但基本上,如果用户从未听说过你的工具,你就无法为他们创造价值。了解用户如何找到你的仪表板、不同元素的激活/留存率、漏斗中的流失点以及如何改进这些问题,都是构建健康生态系统的一部分。
一些对我有效的策略包括:
理解和提高激活率: 一个“激活”的用户是指通过你的仪表板获得了足够价值的用户,以至于他们会继续使用它(即成为留存用户)。通常,激活可以通过在用户打开仪表板后的几秒钟内提供非常快速的成果来实现。
建立社区: 确定并培养不同团队中的“超级用户”。这些人可以作为仪表板的倡导者,提供同伴支持并推动在他们网络中的使用。
积极收集反馈并频繁更新: 你是为用户建立这个生态系统——因此要确保收集并采纳用户的反馈。调查、意见箱和开放的沟通渠道能确保仪表板根据用户需求进行演变。
反复沟通: 虽然邮件公告是一个不错的开始,但它只是一个更大策略中的一个杠杆。发布只是开始——尽量定期发送带有更新、成功故事(例如人们如何使用仪表板)、计划等内容的邮件。
通过这些策略,你可以确保工具的采用,并且你的用户会与你一起工作——构建一个能为他们创造最大价值的最佳生态系统。
总结来说…
一个仪表板生态系统的成功,不仅仅取决于“做什么”(WHAT),还取决于“怎么做”(HOW)。你可以构建世界上最棒的工具,但如果你的用户不知道这些工具,或者他们不知道如何使用它们——那么这些工具将永远无法创造任何价值。
希望你喜欢阅读这篇文章!你有什么想分享的小贴士吗?在评论区告诉大家吧!
如果你想读更多我的文章,以下是一些你可能会喜欢的文章:
一个曾经打造过一些不成功仪表板的人的检查清单
[towardsdatascience.com ## 构建具有分析成熟度的组织 (AMO)
一些简单的框架,帮助你明确组织的分析需求,并让它变得更加……
[towardsdatascience.com ## 什么使优秀的数据分析师与众不同
还在寻找新年决心吗?这里有 6 项技能可以帮助你和你的团队变得极其高效。
[towardsdatascience.com
PS:这篇文章同时发布在 Analytics Explained,这是一个新闻通讯,在这里我总结了自己在不同分析岗位上学到的经验(从新加坡的初创公司到旧金山的大型科技公司),并回答读者关于分析、增长和职业发展的提问。
细金属棒中的热扩散
热扩散方程的解满足傅里叶级数
·发表于 Towards Data Science ·9 分钟阅读·2024 年 7 月 10 日
--
如果你加热一段绝缘金属棒的某个小部分,并将其静置一段时间,会发生什么?我们对热扩散的日常经验使我们预测,温度会逐渐均匀,直到变得一致。在完美绝缘的情况下,热量将永远停留在金属中。
这是对该现象的正确定性描述,但如何定量描述它呢?

图片由 Jonny Gios 提供,来自 Unsplash
我们考虑一个一维问题,即一个包裹在绝缘材料中的细金属棒。绝缘层阻止热量从金属棒的侧面逸出,但热量可以沿着棒的轴向流动。
你可以在 这里找到本文中使用的代码。
热扩散方程
热扩散方程是一个简单的二阶微分方程,涉及两个变量:

x ∈ [0, L] 是杆件上的位置,t 是时间,u(x, t) 是温度,α 是材料的 热扩散率。
通过分析热扩散方程,我们能获得哪些关于温度变化的直觉?
方程(1)说明了温度变化的局部速率与曲率成正比,即温度分布对 x 的二阶导数。

图 1:带有局部变化率的温度分布。图像由作者提供。
图 1 展示了一个包含三个部分的温度分布。第一部分是线性的;第二部分有负的二阶导数,第三部分有正的二阶导数。红色箭头显示了沿杆的温度变化速率。
如果达到了稳态,即∂u/∂t = 0,温度分布将会平滑直到温度分布变为线性。
热扩散方程的解
热扩散方程(1)的解¹是:

你可以通过对方程(2)求导来验证它确实满足方程(1)。对于那些对推导过程感兴趣的人,请参见附录 I。
系数{Aₙ}、{Bₙ}、{λₙ}、C、D 和 E 是常数,必须根据初始条件和边界条件进行拟合。我们关于傅里叶级数的研究将发挥作用!
边界条件
边界条件是施加在 x=0 和 x=L 处的约束条件。在实际情况下,我们会遇到两种类型的约束条件:
绝缘,意味着在杆的两端∂u/∂x=0。这个约束防止了热量的流入或流出;
杆端固定温度:例如,杆的尖端可以通过热电冷却器进行加热或冷却,保持在所需的温度。
约束类型的组合将决定傅里叶级数的适当形式来表示初始温度分布。
两端绝缘
当杆的两端绝缘时,温度分布的梯度在 x=0 和 x=L 处被设为零:

初始条件是 t=0 时杆上的温度分布。假设由于某些难以理解的原因——也许杆被某种邪恶力量附身——温度分布如下:

图 2:初始温度分布。图像由作者提供。
为了运行我们的温度演化仿真,我们需要将方程(2)在 t=0 时的值与该函数匹配。我们通过样本点知道初始温度分布,但不知道其解析表达式。这正是傅里叶级数展开的任务。
从我们关于傅里叶级数的研究中,我们观察到偶数半区扩展会产生一个在两端导数为零的函数。这正是我们在本例中需要的。

图 3 展示了图 2 中函数的偶数半区扩展:

图 3:图 2 的函数的半范围展开。图片来自作者。
尽管重构中使用的有限项数在不连续点处产生了一些波动,但在极限点处导数为零。
将方程 (4)、(5)、(6) 和 (7) 与在 t=0 时评估的方程 (2) 进行等式比较:

我们可以解出常数:

仔细看看 (14)。这个表达式表明 λₙ 与 n 的平方成正比,其中 n 是特定余弦项在区间 [0, L] 内经过的半周期数。换句话说,n 与空间频率成正比。方程 (2) 包含一个指数因子 exp(λₙt),迫使每个频率分量随时间衰减。由于 λₙ与频率的平方成正比,我们预测初始温度轮廓的高频分量会比低频分量更快地衰减。
图 4 显示了 u(x, t) 在第一秒内的图像。我们观察到,右侧的高频分量在 0.1 秒内消失。中央部分的中等频率分量显著衰减,但在 1 秒后仍然可见。

图 4:图 2 温度轮廓的 1 秒模拟。图片来自作者。
当仿真运行 100 秒时,我们得到几乎均匀的温度:

图 5:两端绝热条件下的 100 秒仿真。图片来自作者。
两端固定温度
当两端保持在常温时,我们有如下形式的边界条件:

我们在上一篇文章中研究的傅里叶级数集没有包括边界温度固定为非零值的情况。我们需要重新构造初始温度轮廓 u₀(x),以开发一个在 x=0 和 x=L 处评估为 0 的函数。我们定义一个偏移的初始温度轮廓 û₀(x):

新定义的函数 û₀(x) 线性地将初始温度轮廓 u₀(x) 移动,使得 û₀(0) = û₀(L) = 0。
作为示例,图 6 显示了一个任意的初始温度轮廓 u₀,设置 x=0 时温度为 30,x=0.3 时温度为 70。绿色线(Cx + D)从 (0, 30) 到 (0.3, 70)。橙色曲线表示 û₀(x) = u₀(x) — Cx — D:

图 6:任意的 u₀(x)、û₀(x) 以及直线 Cx + D。图片来自作者。
将偏移的初始温度轮廓 û₀(x) 代入,并在两端为零,可以通过奇数半范围展开来展开:

将方程 (2) 与 (17)、(18)、(19)、(20) 和 (21) 进行等式比较:

我们可以求解常数:

现在可以运行温度曲线随时间变化的模拟 u(x, t),该方程来自于公式(2):

图 7:模拟温度演化,其中两端温度均保持恒定。图片由作者提供。
在稳态下,温度曲线在两端之间是线性的,恒定的热流通过金属杆传递。
左端绝热,右端温度固定
我们有以下边界条件:

我们基本上按照之前的相同程序进行。这一次,我们使用偶数区间展开来模拟初始温度分布,以便在左端获得零导数,在右端获得固定值:

这导致了以下常数:

1000 秒的模拟展示了预期的行为。左端温度梯度为零,右端保持恒定温度。稳态是一个温度均匀的金属杆:

图 8:模拟温度演化,其中左端为绝热,右端温度保持恒定。图片由作者提供。
结论
我们回顾了薄金属杆中温度分布动态的问题。从控制方程出发,我们推导出了通解。
我们考虑了各种边界配置。这些边界情况促使我们根据前一篇文章中推导的傅里叶级数形式来表示初始温度分布。初始温度分布的傅里叶级数表达式使我们能够求解积分常数,并运行 u(x, t)的模拟。
感谢您的阅读。您可以在这个仓库中试验代码。告诉我您的想法!
¹ 如果有遗漏的项,请在评论中告诉我。
附录 I
我们想要证明热扩散方程的解是:

是:

首先我们需要确认,如果一个函数 u*(x, t)满足(A1.1),那么函数 u*(x, t) + Cx + D + Et + E/(2α) x²也满足(A1.1)。
证明:

因此,通解必须包含 Cx + D + Et + E/(2α) x²项。
现在,信念的飞跃:可分离性假设。
假设一个解 u(x, t)具有以下形式:

为什么要做这样的假设?
因为这将使得解更容易找到。如果我们能够找到一个有效的解,证明将在事后出现。在这种情况下,我们不会冒着基于错误假设得出错误结论的风险,因为我们总是可以对找到的解进行微分并检查它是否满足(A1.1)。

将(A1.4)和(A1.5)代入(A1.1):

在方程(A1.9)中,我们观察到我们将 t 的函数与 x 的函数相等。满足此方程的唯一方法是让这两个函数都为常数。因此,我们引入了常数λₙ,它必须与这两个表达式匹配。
(a) 从(A1.9)解出 Gₙ(t):

将(A1.11)和(A1.12)代入(A1.10):

将(A1.15)代入(A1.11):

(b) 从(A1.9)解出 Fₙ(x):

将(A1.18)和(A1.20)代入(A1.17):

将(A1.24)代入(A1.18):

我们假设 Fₙ(x)是唯一的,并且正弦函数中的±符号被常数 Bₙ吸收。

将(A1.16)和(A1.26)代入一般解中:

不失一般性,我们可以设置 G₀=1,并让它被常数 Aₙ和 Bₙ吸收:

□
Python 中的混淆矩阵热力图
一张图片胜过千言万语。
·发表于 Towards Data Science ·6 分钟阅读·2024 年 9 月 6 日
--

图片来源:作者
介绍
混淆矩阵是展示机器学习模型犯错类型的便捷方式。它是一个 N 乘 N 的网格,其中 [n, m] 单元格中的值表示被标注为第 n 类但被识别为第 m 类的示例数量。在本教程中,我将重点介绍如何创建混淆矩阵及其热力图。将使用色彩调色板来展示不同组的大小,从而使得观察组间的相似性或显著差异变得更加容易。这种可视化在处理大量类别时非常有用。
这里是混淆矩阵元素的可视化解释。

图片来源:作者
请记住,用于演示混淆矩阵的数据是人工生成的,并不代表任何真实的分类模型。
现在,我将逐步讲解如何使用 Python 模块生成这样的混淆矩阵。
Python 最基本入门
赫克曼选择偏差建模在因果研究中的应用
选择偏差与 OLS 的识别假设之间的关系,以及应采取哪些步骤来解决它
·发表于 Towards Data Science ·9 分钟阅读·2024 年 8 月 14 日
--

在我的应用研究过程中,我曾经很难理解选择偏差和样本偏差问题的复杂性。这些问题以多种形式出现,可能源于不同因素,并且会影响因果模型中的外部和内部有效性。此外,它们常常是语义混淆的根源。
在处理线性因果模型中的偏差和不一致性时,理解遗漏变量问题是一个基础概念。当一个通常未被观察到的随机变量与自变量和模型误差之间存在相关性时,就会发生这个问题。在估计线性模型时,如果未考虑这个变量,就会导致估计值存在偏差。因此,这个问题会妨碍我们从因变量的变化中隔离出由自变量变化引起的方差,从而掩盖二者之间的真实因果关系。

因果 DAG 中的混杂变量
这些概念是相关的吗?选择偏差可以被认为是遗漏变量问题的一种形式吗?让我们深入探讨一下!
背景
我想阐明一些基本要素,以便充分理解选择偏差如何影响我们线性模型的估计过程。我们有一个因变量 Y,假设它与另一个变量 X(即自变量)之间存在一个线性关系(存在一些误差项)。
识别假设
给定人群变量 Y、X 的子样本 Y'、X' -
误差项(原始模型的误差项!!!)与 X'不相关。
误差项的均值为零。
Y 和 X 之间确实存在线性关系——

需要注意的是,在实证研究中,我们观察到 X 和 Y(或它们的子样本),但我们无法观察到误差项,这使得假设(1)无法直接测试或验证。在这一点上,我们通常依赖理论解释来证明这个假设。常见的解释是通过随机对照试验(RCTs),在其中,子样本 X 完全随机地收集,确保它与任何其他变量,特别是与误差项之间没有相关性。
条件期望
基于之前提到的假设,我们可以准确确定给定 X 的 Y 的条件期望的形式——

线性模型中的条件期望
最后的过渡来源于识别假设。需要注意的是,这个是 x 的函数,这意味着它表示在 x 等于特定值的条件下,所有观察到的 y 值的平均值(或者是给定 x 值的小范围内,y 的局部平均值——更多信息可以在这里找到)
最小二乘法(OLS)
给定符合识别假设的 X 样本,已知最小二乘法(OLS)方法为线性模型参数(alpha 和 beta)提供了一个封闭形式的解,从而也为给定 X 的 Y 的条件期望函数提供了一个一致且无偏的估计。
从本质上讲,OLS 是一种拟合线性线(或者在多元样本的情况下是拟合线性超平面)到一组(y_i, x_i)对的技术。OLS 特别有趣的是——
如果 Y 和 X 之间存在线性关系(考虑经典误差项),我们已经看到,给定 X 的 Y 的条件期望是完全线性的。在这种情况下,OLS 能够以强大的统计精度有效地揭示这个函数。
即使是满足之前讨论的识别假设的任何子样本,OLS 也能够实现这一点——前提是样本足够大。
动机
让我们从一个简单的示例开始,使用模拟数据。我们将模拟上述的线性模型。
使用模拟数据的一个显著优势是,它使我们能够更好地理解在现实世界场景中无法观察到的变量之间的关系,例如模型中的误差项。
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import statsmodels.api as sm
N = 5000
BETA = 0.7
INTERCEPT = -2
SIGMA = 5
df = pd.DataFrame({
"x": np.random.uniform(0, 50, N),
"e": np.random.normal(0, SIGMA, N)
})
df["y"] = INTERCEPT + BETA * df["x"] + df["e"]
并对全样本运行 OLS -
r1 = sm.OLS(df["y"], sm.add_constant(df["x"])).fit()
plt.scatter(df["x"], df["y"], label="population")
plt.plot(df["x"], r1.fittedvalues, label="OLS Population", color="r")

现在,让我们从我们的人群中生成一个随机的子样本 X,并对这个子样本应用 OLS。我将从之前生成的 500 个样本中随机选择 100 个 x,然后对这个子集运行 OLS。
sample1 = df.sample(100)
r2 = sm.OLS(sample1["y"], sm.add_constant(sample1["x"])).fit()
plt.scatter(sample1["x"], sample1["y"], label="sample")
plt.plot(sample1["x"], r2.fittedvalues, label="OLS Random sample")
和图表 -

看起来我们为随机子样本获得了一致的估计量,因为两种 OLS 结果产生了非常相似的条件期望线。此外,你可以观察到 X 和误差项之间的相关性——
print(f"corr {np.corrcoef(df['x'], df['e'])}")
print(f"E(e|x) {np.mean(df['e'])}")
# corr [[ 1\. -0.02164744]
# [-0.02164744 1\. ]]
# E(e|x) 0.004016713100777963
这表明识别假设得到了满足。然而,实际上我们无法直接计算这些,因为误差是不可观察的。现在,让我们创建一个新的子样本——我将选择所有(y, x)对,其中 y ≤ 10:
sample2 = df[df["y"] <= 10]
r3 = sm.OLS(sample2["y"], sm.add_constant(sample2["x"])).fit()
plt.scatter(sample2["x"], sample2["y"], label="Selected sample")
plt.plot(sample["x"], r3.fittedvalues, label="OLS Selected Sample")
我们得到的是——

现在,OLS 给出了一个完全不同的直线。让我们检查一下子样本 X 和误差之间的相关性。
print(f"corr {np.corrcoef(df['x'], df['e'])}")
print(f"E(e|x) {np.mean(df['e'])}")
# corr [[ 1\. -0.48634973]
# [-0.48634973 1\. ]]
# E(e|x) -2.0289245650303616
看起来识别假设被违反了。我们还可以将子样本的误差项绘制成 X 的函数——

如你所见,随着 X 的增加,大误差变少,表明存在明显的相关性,导致 OLS 估计量有偏且不一致。
让我们进一步探索这个问题。
建模
那么,这里到底发生了什么?
我将引用詹姆斯·赫克曼提出的模型,他与丹尼尔·麦克法登共同获得了 2000 年诺贝尔经济学奖。赫克曼因其在计量经济学和微观经济学方面的开创性工作而闻名,尤其是在定量分析中处理选择偏差和自我选择方面的贡献。他著名的赫克曼修正将在后文中讨论。
在他 1979 年的论文《样本选择偏误作为规格误差》中,赫克曼阐述了如何通过截尾依赖变量来引发选择偏误——这是一种选择的特殊情况,可以扩展到更多非随机样本选择过程。
对依赖变量进行截尾处理正是我们在上一节创建最后一个子样本时所做的。让我们来看看赫克曼的框架。
我们从一个完整的样本(或总体)开始,包含(y_i, x_i)对。在这种情况下,给定 x_i,ε_i 可以变化——它可以是正的、负的、小的或大的,这完全取决于误差分布。我们将这个完整的依赖变量样本称为 y*。然后,我们将 y 定义为截尾的依赖变量,它只包括我们实际观察到的值。

现在,让我们计算截尾变量 y 的条件期望:

如你所见,这个函数与我们之前看到的函数相似,但它包含了一个额外的项,这与之前不同。这个最后的项不能被忽视,这意味着条件期望函数不是纯线性的,它与 x 有关(并伴有一定的噪声)。因此,在未截尾的值上运行 OLS 将产生对α和β的有偏估计。
此外,该方程说明了选择偏误问题如何被看作是一个遗漏变量问题。由于最后一项依赖于 X,因此它与因变量共享大量的方差。
赫克曼修正
逆米尔斯比率
赫克曼修正方法基于以下原理:给定一个随机变量 Z,它遵循均值为μ,标准差为σ的正态分布,以下方程适用:

给定任何常数α,Φ(大写的 phi)表示标准正态分布的累积分布函数(CDF),而ɸ表示标准正态分布的概率密度函数(PDF)。这些值被称为逆米尔斯比率。
那么,这对我们有什么帮助呢?让我们重新审视前一个条件期望方程的最后一项—

结合我们的误差项遵循正态分布这一事实,我们可以使用逆米尔斯比率来刻画它们的行为。
回到我们的模型
逆米尔斯比率的优点在于,它将之前的条件期望函数转换为以下形式——

这导致了一个线性函数,其中包含了一个额外的协变量——逆米尔斯比率。因此,为了估计模型参数,我们可以将 OLS 应用于这个修订后的公式。
首先计算逆米尔斯比率—

并在代码中:
from scipy.stats import norm
sample["z"] = (CENSOR-INTERCEPT-BETA*sample["x"])/SIGMA
sample["mills"] = -SIGMA*(norm.pdf(sample["z"])/(norm.cdf(sample["z"])))
并进行 OLS 回归—
correcred_ols = sm.OLS(sample["y"], sm.add_constant(sample[["x", "mills"]])).fit(cov_type="HC1")
print(correcred_ols.summary())
以及输出结果—
OLS Regression Results
==============================================================================
Dep. Variable: y R-squared: 0.313
Model: OLS Adj. R-squared: 0.313
Method: Least Squares F-statistic: 443.7
Date: Mon, 12 Aug 2024 Prob (F-statistic): 3.49e-156
Time: 16:47:01 Log-Likelihood: -4840.1
No. Observations: 1727 AIC: 9686.
Df Residuals: 1724 BIC: 9702.
Df Model: 2
Covariance Type: HC1
==============================================================================
coef std err z P>|z| [0.025 0.975]
------------------------------------------------------------------------------
const -1.8966 0.268 -7.088 0.000 -2.421 -1.372
x 0.7113 0.047 14.982 0.000 0.618 0.804
mills 1.0679 0.130 8.185 0.000 0.812 1.324
==============================================================================
Omnibus: 96.991 Durbin-Watson: 1.993
Prob(Omnibus): 0.000 Jarque-Bera (JB): 115.676
Skew: -0.571 Prob(JB): 7.61e-26
Kurtosis: 3.550 Cond. No. 34.7
==============================================================================
在现实中
α和β是我们希望估计的模型中未观察到的参数,因此在实际操作中,我们无法像之前那样直接计算逆米尔斯比率。赫克曼在他的修正方法中引入了一个初步步骤,以帮助估计逆米尔斯比率。这就是为什么赫克曼修正被称为双阶段估计量。
回顾一下,我们的问题是我们无法观察到因变量的所有值。例如,如果我们正在研究教育(Z)如何影响工资(Y),但只观察到超过某一阈值的工资,那么我们需要为那些工资低于该阈值的个体的教育水平制定理论解释。一旦我们有了这个解释,我们就可以估计一个probit模型,形式如下:

并使用这个 probit 模型的估计参数来计算逆米尔斯比率的估计量。在我们的案例中(注意我没有使用α和β)—
from statsmodels.discrete.discrete_model import Probit
pbit = Probit(df["y"] <= CENSOR, sm.add_constant(df["x"])).fit()
sample["z_pbit"] = sample["z"] = (pbit.params.const + sample["x"]*pbit.params.x)
sample["mills_pbit"] = -SIGMA*(norm.pdf(sample["z_pbit"])/(norm.cdf(sample["z_pbit"])))
correcred_ols = sm.OLS(sample["y"], sm.add_constant(sample[["x", "mills_pbit"]])).fit(cov_type="HC1")
同样,第二阶段的 OLS 给我们提供了一致的估计量—
OLS Regression Results
...
==============================================================================
coef std err z P>|z| [0.025 0.975]
------------------------------------------------------------------------------
const -1.8854 0.267 -7.068 0.000 -2.408 -1.363
x 0.7230 0.049 14.767 0.000 0.627 0.819
mills_pbit 1.1005 0.135 8.165 0.000 0.836 1.365
==============================================================================
总结
我们使用模拟数据展示了由于对因变量值进行删选而导致的样本选择偏差问题。我们通过检查模型的模拟误差和偏倚子样本,探讨了该问题如何与 OLS 因果识别假设相关。最后,我们介绍了赫克曼方法来修正这种偏差,从而使我们即使在使用偏倚样本时,也能获得一致且无偏的估计量。
如果你喜欢这个故事,我将非常感激你的支持——请为我买杯咖啡,这对我来说意义重大!
参考文献
[1] James J. Heckman, 样本选择偏差作为规格错误(1979 年),《计量经济学》
[2] Nick Huntington-Klein, 《效应》 书籍(2022 年)
[3] Christopher Winship, 样本偏差的模型(1992 年)
除非另有说明,所有图片均由作者提供
Mamba 来了:选择性状态空间模型
🐍 向 Mamba 状态空间模型迈进:适用于图像、视频和时间序列
第三部分 — 向 Mamba 状态空间模型迈进:适用于图像、视频和时间序列
·发表于Towards Data Science ·阅读时间 17 分钟·2024 年 9 月 3 日
--

图片来源:Sascha Kirch.
这是我新系列的第三部分:🐍 向 Mamba 状态空间模型迈进:适用于图像、视频和时间序列。
Mamba,这个被认为将取代强大的 Transformer 的模型,从最初在深度学习中使用状态空间模型(SSMs)的构想到现在,已经走过了很长的路。
Mamba 为状态空间模型添加了选择性,从而在保持 SSM 子二次工作复杂度的同时,实现了类似 Transformer 的性能。其高效的选择性扫描比标准实现快 40 倍,并且在吞吐量上可达到 Transformer 的 5 倍。
加入我,一同深入探讨 Mamba,我们将发现选择性如何解决以往状态空间模型(SSMs)的局限性,Mamba 如何克服这些变化带来的新挑战,以及我们如何将 Mamba 融入现代深度学习架构中。
这是使用 LLM 为怪物生成数据教会我编程的经验
教程
我学会了如何使用 AI 作为一种替代方案,生成惊人的随机数据。
·发布于 Towards Data Science ·阅读时间:8 分钟·2024 年 6 月 9 日
--

图片来源:DALLE。
编程有点生疏了
最近,我决定重新熟悉编程网页应用程序,目标是保持实践并保持相关性——特别是在充满挑战的科技职位领域。
我经常通过创建相对简单的网页应用程序,使用 API、数据库、库,并将它们整合到一个应用程序中,来自学各种编程概念。
我发现,以这种方式创建应用程序,而不是严格地练习枯燥的 HackerRank 问题,不仅教会我最新的技术生态概念,还能建立一个全新的项目组合(你可以在面试中讨论的项目!)。
今天,事情有点不同
人工智能正在迅速改变软件开发的方式。
尽管一些影响者和媒体声称编程工作正在消失,我相信事实正好相反。
这是我通过 Wordle 学到的关于信息理论的知识
改善猜测的科学
·发表于Towards Data Science ·阅读时长 8 分钟·2024 年 12 月 9 日
--
Wordle 是由纽约时报开发的一款上瘾的在线每日单词谜题游戏。
规则很简单。玩家有六次机会来猜一个五个字母的单词。Wordle 会根据你每次猜测的结果,使用绿色、灰色和黄色的颜色来反馈你的猜测单词中的字母。
绿色表示你猜对了给定字母和位置。黄色表示该字母在单词中存在,但位置不对。灰色表示该字母在单词中不存在。
完全正确的猜测会使所有五个字母变成绿色。错误的猜测可能会导致所有字母变成灰色。如果猜测部分正确,你可能会看到黄色和灰色字母混合的情况,表示某些字母存在但位置不对。
我相信许多读者都玩过 Wordle,但如果你还没玩过,可以在这里试试。

来自 Wordle 网站的截图
由于其有趣的反馈循环,帮助缩小可能的单词列表,Wordle 是信息理论如何提升决策制定的完美示例。我玩这个游戏是为了娱乐,但我内心的书呆子却想更深入地探索我猜测背后的“原因”。是什么让某些猜测…
高性能 Python 数据处理:pandas 2 与 Polars,从 vCPU 视角看
Polars 承诺其多线程能力优于 pandas。但在单个 vCore 上是否也是如此?
·发表于 Towards Data Science ·阅读时长 7 分钟·2024 年 8 月 7 日
--

图片由作者生成,使用 DALL-E
不管喜不喜欢,pandas 多年来一直是 Python 数据分析中占主导地位的库。它在数据科学和分析中得到了广泛应用(无论是在工业界还是学术界),同时也被软件和数据工程师在数据处理任务中大量使用。
pandas 长期以来作为表格数据分析的冠军,目前正面临着一个新库——Polars 的挑战。Polars 旨在通过实现一个更现代化的框架,替代 pandas 解决今天 pandas 所解决的相同用例。它的主要承诺之一是提供更好的性能,利用一个用 Rust 编写的后端,这个后端经过优化以进行并行处理。此外,它还深入实现了矢量化操作(SIMD),这是使 NumPy 和 pandas 如此快速和强大的功能之一。
它到底快多少?
查看这个图表(发布于Polars 主页 2024 年 4 月 24 日),该图表展示了在不同的 Python 数据分析生态系统下,TPC-H 基准测试的运行时间(单位:秒)。一眼看上去,Polars 的速度比 pandas 快 25 倍。深入了解后,我们发现这些基准测试是在一台拥有 22 个 vCPU 的虚拟机上收集的。Polars 被设计为擅长并行处理,因此,当然可以从拥有如此多 vCPU 的系统中受益。而 pandas 则完全不支持多线程,因此可能仅使用了这台机器的 1 个 vCPU。换句话说,Polars 用 1/25 的时间完成了 pandas 的工作,但它也使用了 22 倍的计算资源。
vCore 的问题
虽然如今每台物理计算机都配有某种形式的硬件并行化(多核、多 ALU、超线程等),但虚拟服务器却不总是如此,通常使用较小的服务器可以降低成本。例如,像 AWS Lambda Functions、GCP Cloud Functions 和 Azure Functions 等无服务器平台,vCore 会随着内存的变化而扩展,而且由于按 GB-秒计费,你通常不会为函数分配超过需求的内存。
鉴于情况如此,我决定测试 Polars 与 pandas 的表现,特别是我对以下两个问题感兴趣:
**1. 在仅有 1 个 vCore 可用的情况下,Polars 如何与 pandas 比较
- Polars 如何随着 vCore 扩展
我们将考虑 4 个操作:分组和聚合、分位数计算、过滤和排序,这些操作可能会被融入到数据分析工作或管道中,这些工作可以在数据分析师、数据科学家以及数据和软件工程师的工作中看到。
设置
我使用了一台 AWS m6a.xlarge 机器,配备 4 个 vCore 和 16GB 内存,并利用 taskset 为每次测试分配 1 个 vCore 或 2 个 vCore,以模拟每次有较少 vCore 的机器。在库版本方面,我选择了当时可用的最新稳定版本:
pandas==2.2.2; polars=1.2.1
数据
数据集是随机生成的,包含 100 万行和 5 列,旨在表示某一产品中 10,000 个会话内 100,000 次用户操作的历史记录:
user_id(整数)
action_types(枚举类型,可以取值为[“click”, “view”, “purchase”])
timestamp(日期时间类型)
session_id(整数)
session_duration(浮动类型)
前提
给定数据集,我们想要根据用户的平均会话时长找到前 10%最活跃的用户。因此,我们首先需要计算每个用户的平均会话时长(分组与聚合),然后找到第 90 百分位数(百分位数计算),选择所有位于该百分位数以上的用户(筛选),并确保列表按平均会话时长排序(排序)。
测试
每个操作都运行了 200 次(使用timeit),每次取平均运行时间,并用标准误差作为测量误差。代码可以在这里找到。
关于急切执行与懒执行的说明
pandas 和 Polars 之间的另一个区别是,前者默认使用急切执行(语句按书写顺序执行),而后者使用懒执行(语句在需要时编译并执行)。Polars 的懒执行帮助它优化查询,这在大数据分析任务中是一个非常好的特性。我们选择将任务拆分,查看四个操作,目的是排除这一方面,专注于比较更基本的性能方面。
结果
分组 + 聚合

按库和 vCore 进行的分组和聚合操作的平均执行时间。图像和数据由作者提供。
我们可以看到,pandas 在 vCore 的扩展上没有表现出预期的效果。这个趋势在整个测试过程中都会保持。我决定保留图表中的这一部分,但之后我们将不再提及它。
polars 的结果相当令人印象深刻——在 1vCore 配置下,它比 pandas 快了三分之一的时间,而随着 vCore 数量增加到 2 核和 4 核时,它分别快了约 35%和 50%。
百分位数计算

按库和 vCore 进行的百分位数计算操作的平均执行时间。图像和数据由作者提供。
这个结果很有趣。在所有 vCore 配置中,polars 的执行速度是 pandas 的 5 倍。在 1vCore 配置下,平均执行时间为 0.2ms,但标准误差较大(意味着有时操作完成的时间会明显超过 0.2ms,而有时则会明显低于 0.2ms)。当扩展到多个 vCore 时,执行时间更加稳定——2vCore 配置为 0.21ms,4vCore 配置为 0.19ms(大约快 10%)。
筛选

按库和 vCore 进行的筛选操作的平均执行时间。图像和数据由作者提供。
在所有情况下,Polars 的完成速度都比 pandas 快(最差的运行时间仍然是 pandas 的两倍)。然而,我们在这里看到了一种非常不寻常的趋势 —— 运行时间随着 vCore 增加而增加(我们原本期望它会减少)。4vCore 的操作运行时间大约比 1vCore 的慢 35%。尽管并行化为你提供了更多的计算能力,但它通常伴随有一定的开销 —— 管理和协调并行进程通常是非常困难的。
这个 Polars 扩展问题令人困惑 —— 我这边的实现非常简单,而且在 Polars 的仓库中没有找到相关的开放问题(不过现在那儿确实有超过 1000 个开放问题)。
你知道为什么会发生这种情况吗?请在评论中告诉我。
排序

排序操作的平均执行时间,按库和 vCore 分类。图片和数据来源:作者。
过滤后,我们剩下了大约 13.5k 行数据。
在这一点上,我们可以看到 1vCore 的 Polars 情况比 pandas 慢得多(慢约 45%)。当我们扩展到 2vCore 时,运行时间与 pandas 相当,而扩展到 4vCore 时,Polars 的速度明显快于 pandas。这里可能的情况是,Polars 使用了一种针对并行优化的排序算法 —— 这种算法在单核心上可能表现不佳。
仔细查看文档后,我发现 Polars 中的排序操作有一个 multithreaded 参数,用于控制是否使用多线程排序算法或单线程排序算法。
排序(multithreading=False)

排序操作的平均执行时间(multithreading=False),按库和 vCore 分类。图片和数据来源:作者。
这次,我们可以看到更加一致的运行时间,虽然它们不会随着核心数的增加而扩展,但确实超过了 pandas。
结论
并行计算和分布式计算是困难的。我们通常认为,只要扩展我们的程序,它就能更快完成,但这总是会增加开销。在很多情况下,像 Redis 和 node.js 这样以极速著称的程序实际上是单线程的,并不支持并行化(node.js 以并发著称,但并发 ≠ 并行化)。
看起来,大多数情况下,即使只有 1 个可用 vCore,Polars 确实比 pandas 更快。令人印象深刻!
从过滤和排序操作来看,Polars 显然没有针对单个 vCore 情况进行优化,就像你在云服务中可能遇到的情况一样。如果你运行大量小型(<2GB 内存)无服务器函数,这是一个重要的考虑因素。为了提高速度进行扩展通常也伴随着价格的上升。
Polars 仍然是一个相对较新的解决方案,截至 2024 年中,它似乎没有 pandas 那么成熟。例如,在排序操作中的
multithreaded参数 —— 我希望能够有一个auto默认选项,可以根据硬件选择算法。
最后备注
在考虑更换像 pandas 这样的基础库时,性能并不是唯一需要考虑的因素。还需要考虑其他参数,例如切换的成本(学习新语法、重构旧代码)、与其他库的兼容性以及新解决方案的成熟度。
这里的测试旨在处于快速粗略和彻底基准测试之间的中间地带。还需要更多工作才能得出决定性结论。
我简要讨论了 pandas 和 Polars 如何受益于 SIMD(单指令多数据),另一个你可能听说过的硬件,GPU,因实现相同的思想而闻名。Nvidia 发布了一个插件,用于在 GPU 上执行 Apache Spark 代码——根据我的测试,它甚至比 Polars 更不成熟,但值得一试。
计算机视觉中的方向梯度直方图(HOG)
一种用于目标检测和识别任务的方向梯度直方图(HOG)(特征提取算法)的解释和实现
·发表在Towards Data Science·阅读 6 分钟·2024 年 11 月 4 日
--
介绍
方向梯度直方图最初由 Navneet Dalal 和 Bill Trigs 在他们的 CVPR 论文[“用于人体检测的方向梯度直方图”]中首次引入
有许多不同的特征提取算法,取决于它专注的特征类型,比如纹理、颜色或形状,它是描述整个图像还是仅局部信息。
HOG 算法是特征提取中最基本的技术之一,因为它是目标检测和识别任务的基本步骤。
在本文中,我们将探讨 HOG 算法的原理和实现。
什么是方向梯度直方图(HOG)?
HOG 是一种全局描述符(特征提取)方法,应用于图像中的每个像素,以提取像纹理这样的邻域信息(像素的邻域),并将给定图像中的信息压缩/抽象成一个称为特征向量的缩减/压缩形式,该向量可以描述这幅图像的特征,当涉及捕捉图像中的边缘和梯度结构时非常有用。此外,我们可以比较这个…
费曼的《时间序列预测的 MLOps 指南》
如何使用 Scikit-Learn 开发时间序列预测项目
·发表于Towards Data Science ·6 分钟阅读·2024 年 4 月 4 日
--

在数据驱动决策的世界里,时间序列预测通过利用历史数据模式来预测多个业务的未来结果,发挥着关键作用。无论你从事资产风险管理、交易、天气预测、能源需求预测还是交通分析,准确预测的能力对成功至关重要。
时间序列预测项目的成功不仅仅通过预测模型的拟合度来衡量。AI 驱动工具在实际应用中的有效性还取决于参与的各种角色或工具之间的合作水平。为了实现最顺畅的合作程度,必须尽早从初期开发阶段引入一套规则和最佳实践。

MLOps 组件 [图像由作者提供]
这些规则被称为机器学习操作(MLOps)。
MLOps 旨在将机器学习项目的各种元素统一成一个单一的、和谐的结构,努力维持这一点……
在单一端点托管多个 LLM
利用 SageMaker 推理组件,以高效的成本和性能方式托管 Flan 和 Falcon
·发表于 Towards Data Science ·10 分钟阅读·2024 年 1 月 11 日
--

图片来源:Unsplash 作者:Michael Dziedzic
去年,随着多个新模型的推出,配合各种技术和工具帮助训练、托管和评估这些模型,大型语言模型(LLM)领域经历了一次爆炸式增长。具体来说,托管/推理是这些 LLM 以及机器学习技术在整体上被认可的地方,因为没有推理,这些模型就没有可视化的结果或实际意义。
正如我之前所记录的那样,托管这些 LLM可能相当具有挑战性,主要因为模型的体积以及如何高效利用与模型相关的硬件。尽管我们已经使用了如 DJL Serving、文本生成推理(TGI)和 Triton 等模型服务技术,配合像 Amazon SageMaker 这样的模型/基础设施托管平台来托管这些 LLM,但随着我们尝试将 LLM 的使用场景推向生产化,另一个问题随之而来。我们如何能够在多个 LLM 上做到这一点?
为什么最初的问题会出现呢?当我们进入生产级别的使用场景时,通常会有多个模型可以使用。例如,可能在你的摘要任务中使用 Llama 模型,而在你的聊天机器人中使用 Falcon 模型。虽然…
AI 如何去除无法察觉的水印
探索检测 AI 生成媒体中的漏洞
·发布于Towards Data Science ·阅读时长:8 分钟·2024 年 2 月 6 日
--

无形水印工作原理的高级示意图。图片由作者提供。
我们为什么需要水印?
水印遍布整个互联网——这显然是有原因的。否则,你如何保护你的艺术作品或摄影作品不被随意用在某人的 PowerPoint 演示文稿中而不给创作者署名呢?解决这个问题的最简单方法就是制作像下面这样的可见水印。

可见水印示例。图片由作者基于 DALL-E 3 制作。
这种方法的主要缺点是可能会损害艺术本身。没有人会购买并使用这样带水印的猫咪图片。因此,虽然可以减少未经授权的复制,明显的水印也可能会使目标受众不愿意使用这件艺术作品。
在音乐领域,可听水印在免费 Hip-Hop 节拍中也很常见。节拍制作人通常会在第一段歌词开始之前插入带有品牌名称的语音样本。这可以作为防止非法下载的保障,或者当节拍是免费使用时,作为一种营销手段。
一段包含可听水印的 Hip-Hop 节拍示例,在约 10 秒时出现。“Solitude”由 Direct Beats 制作。
对于库存照片和 Hip-Hop 节拍来说,常见的做法是在在线预览中加上水印,并在客户支付后将原始产品发送给他们。然而,这种做法也容易被滥用。一旦没有水印的产品被购买,它就可能被复制并重新上传到互联网上。
无法察觉水印的必要性
知识产权的保护
不易察觉的水印有一个独特的优势:你可以在不影响产品质量的情况下,证明自己对任何数字副本的产品拥有所有权。就像一张带有隐形墨水的纸,纸张本身完全可用,但它携带了一条可以随时揭示的秘密信息。

不易察觉水印的示例。柠檬汁可以用作隐形墨水,可以通过加热使其显现。请观看这个视频来了解演示。图片由作者提供。
通过这项技术,创作者可以在作品中编码任何信息。更重要的是,由于他们有解码器的访问权限,他们可以随时声明对任何数字副本的原始作品拥有所有权。另一个对版权持有者来说新兴的机会是利用网络爬虫搜索网络并报告任何检测到的滥用行为。
AI 生成内容的检测
不易察觉的水印的另一个重要应用是用于检测 AI 生成的内容。随着 ChatGPT 及类似 AI 工具的出现,人们对于互联网上可能泛滥的危险 AI 生成内容表示担忧。像Meta和Google这样的科技公司正在推出不易察觉的水印系统,作为技术突破来缓解这一问题。他们的工具可以在不明显影响质量的情况下,为图像或音乐添加水印。
从原则上讲,这是一个值得关注的发展。通过不易察觉的水印,只有技术的拥有者能够解码并检测水印的存在。以我们之前的例子,Meta 和 Google 同时拥有隐形墨水和揭示它的手段。这使得他们能够准确地检测并筛选出在自己平台(例如 Instagram、YouTube)上生成的内容。通过合作,甚至像 X(前 Twitter)这样的独立平台也可以使用这项技术来限制 AI 生成的虚假信息或其他有害内容。

像 Meta 和 Google 这样的 AI 服务提供商正在建立自己的水印系统,以检测他们自己生成的内容,或者将这一能力出售给他人。图片由作者提供。
AI 如何移除不易察觉的水印?
尽管不易察觉的水印听起来很有前景,并且正被大型科技公司推广,但它们远非完美。事实上,许多这样的水印可以通过智能 AI 算法可靠地移除。但 AI 如何移除那些不易察觉的水印呢?
移除显眼水印
我们首先来了解如何使用 AI 去除可察觉的水印。我提出一种简单的方法:首先从网络上收集成千上万张图片。然后,自动地在这些图片上添加人工水印。确保这些水印看起来像真实的水印,且涵盖各种字体、大小和风格。接着,训练 AI 去除水印,方法是反复展示水印和没有水印的相同图像对。
虽然确实存在更复杂的方法,但这说明了如果 AI 被训练去识别水印的外观或声音,水印可以多么轻松地被去除。网络上有许多工具,允许我轻松地从我上面的小猫图片中去除水印:

使用watermarkremover.io去除的水印。在这个例子中,图像和水印都是人工的。请不要使用此类工具来破坏他人的知识产权。
去除难以察觉的水印
为了使用上述简单的方法,你需要为 AI 提供“前后”示例。然而,如果水印是难以察觉的,怎么找到这些示例呢?更糟的是,我们仅凭查看图像或听歌也无法判断是否存在水印。
为了解决这个问题,研究人员不得不发挥创意。赵等人(2023 年)提出了一个两阶段的程序。
通过向图像中添加随机噪声来摧毁水印
通过使用去噪算法重建真实图像

去除图像上难以察觉水印的两阶段程序。改编自赵等人,2023 年。
这非常聪明,因为它挑战了一个直觉,即要去除水印,必须能够检测到水印。这个方法无法定位水印。然而,如果唯一的目标是去除水印,那么通过向图像中加入足够的白噪声来摧毁它是快速且有效的。
当然,在添加噪声后,你可能已经破坏了水印,但最终得到的是一张噪声图片。最有趣的部分是,作者如何从噪声中重建原始图像。为此,他们使用了 AI 扩散模型,比如 DALL-E 3 或 Midjourney 中使用的模型。这些模型通过迭代地将随机噪声转化为逼真的图像来生成图片。

扩散模型如何从噪声中生成图像。摘自David Briand。
作为副作用,扩散模型也是极其有效的去噪系统,适用于图像和音频两者。通过利用这项技术,任何人都可以使用这个精确的两步流程去除难以察觉的水印。
这是否意味着难以察觉的水印毫无用处?

图片来源:Anthony Tori via Unsplash
是的,也不是。从一方面来说,到目前为止,似乎任何发明出来的不可察觉的水印系统都能通过某种方式被不法分子破解。当我第一次在 LinkedIn 上发布关于这个问题的帖子时,有人评论道:“这又是广告拦截器拦截器的游戏”,我完全同意这个观点。
针对赵等人(2023)提出的攻击方法,显而易见的防御方法是开发一种对其具有鲁棒性的隐形水印系统。例如,我们可以训练我们的水印系统,使得当前的 SOTA 扩散模型在用随机噪声去除水印后无法很好地重建图像。或者我们可以尝试构建一种对随机噪声攻击具有鲁棒性的水印。无论哪种方式,新漏洞都会迅速被发现并加以利用。
那么,不可察觉的水印真的毫无用处吗?在一篇近期文章中,Sharon Goldman 认为,尽管水印可能无法阻止不法分子,但它们对善意行为者仍然有益。它们有点像元数据,但直接编码在感兴趣的对象中。不同于可能在音频转换格式时丢失的 MP3 元数据,不可察觉的水印始终可以追溯,因为它们直接嵌入到音乐本身中。
然而,如果我对自己诚实一点,我曾希望不可察觉的水印能成为标记和检测 AI 生成内容的可行解决方案。显然,我错了。这些水印在很大程度上无法阻止不法分子将有害的 AI 生成内容充斥到互联网上。
在 AI 时代,我们还能通过什么方式证明所有权?

这张图片是作者使用 DALL-E 3 生成的。
对策的开发
如上所述,开发针对已知攻击算法的对策始终是一个选择。然而,在许多情况下,攻击者比防御者更容易在其攻击算法上进行迭代,而不是防御者开发保护措施。尽管如此,我们不能忽视可能发现一种不容易被破解的新水印方法的可能性。因此,确实值得投入时间和资源进行进一步的研究。
针对水印攻击者的法律后果
虽然用 AI 生成图像并上传到社交媒体平台通常不被认为是非法的,但故意去除 AI 生成图像上的水印很可能是非法的。由于我自己并没有法律专长,我只能认为,针对这种恶意行为采取法律后果威胁是合乎情理的。
当然,普通用户转发他们在网上找到的图片应该被排除在外。然而,故意去除水印以传播虚假信息显然是不道德的。即使法律压力无法根除滥用行为(历史上从未做到过),它也可以成为一种缓解因素。
重新思考所有权证明
目前有许多方法探讨区块链技术和/或智能合约如何在数字时代帮助证明所有权。简单来说,区块链是一种信息存储系统,用来追踪网络成员之间的互动。每一笔交易都可以被唯一识别,并且不能在任何后续时间被篡改。将智能合约加入这个网络后,我们就可以将交易与自动执行的责任绑定在一起,交易一旦完成,责任也会自动履行。
用更具体的说法,区块链和智能合约在未来可以用于自动化检查所有权或支付知识产权的版税。到目前为止,还没有这样的系统获得广泛采用。但我们可能离这些技术成为我们经济中不可或缺的资产只差几个技术突破。
结论
数字水印自互联网早期以来就被用于防止图片或音乐等知识产权的滥用。最近,它被讨论作为标记和检测 AI 生成内容的一种方法。然而,事实证明,AI 不仅擅长生成虚假图片,也同样擅长去除这些图片上的任何水印,这使得大多数检测系统变得无效。
很显然,我们不能因此而灰心丧气,在 AI 时代寻找替代的所有权证明方法仍然是必要的。通过制定具体的技术和法律对策,并同时探索区块链和/或智能合约如何在未来发挥作用,我们或许能够找到解决这一重要问题的办法。
参考文献
Zhao et al., 2023. 《不可见图像水印使用生成式 AI 可以被证明可移除》arxiv.org/pdf/2306.01953.pdf
关于我
我是一名音乐学家和数据科学家,分享我对 AI 与音乐领域当前话题的思考。以下是与本文相关的部分早期工作:
2024 年值得期待的 3 个音乐 AI 突破:
towardsdatascience.com/3-music-ai-breakthroughs-to-expect-in-2024-2d945ae6b5fdMeta 的 AI 如何基于参考旋律生成音乐:
medium.com/towards-data-science/how-metas-ai-generates-music-based-on-a-reference-melody-de34acd783AI 音乐源分离:它是如何工作的以及为什么如此困难:
medium.com/towards-data-science/ai-music-source-separation-how-it-works-and-why-it-is-so-hard-187852e54752
人工智能如何很快将人机交互带入新境界
随着人工智能模型在语音识别和合成、文本处理以及多模态领域达到卓越水平,终极的语音用户界面可能很快会无处不在。
LucianoSphere (Luciano Abriata, PhD)
·发表于 Towards Data Science ·18 分钟阅读 ·2024 年 8 月 28 日
--

语音用户界面(VUI),用于基于自然语言的语音人机交互,如 Dall-E 3 通过 ChatGPT 所构想的那样。
那是一个典型的周五下午,正值一周忙碌工作结束时,我与我的朋友兼同事陷入了一场激烈的讨论。他是一个“硬核”工程师、网站程序员和设计师,已经在网络开发领域奋战了十多年。作为一个以高效和对每行代码的控制为傲的人,尤其是总是把用户和用户体验放在心上的人,我的朋友对我提议语音接口很快就会成为常态的想法嗤之以鼻……
“语音接口?它们还不成熟、笨拙,说实话,有点吓人”,他说这话时并没有使用这些确切的词语,但他的意思无疑是如此,他也表达了科技界许多人共享的这种情绪。而这话是他在某种程度上已经被我说服之后说的,也许有 30%到 50%的可能性,甚至更大程度地接受了我们为分子图形和建模设计的增强现实/虚拟现实工具,的确需要这样的种类的人机交互,因为用户的双手是忙碌的……
AlphaFold 3 与 DALLE 2 的相似性及其他启示

来自 Unsplash 的扩散(字面意思)
为机器学习工程师理解生物领域的 AI 应用提供帮助
·发表于 Towards Data Science ·阅读时间 6 分钟·2024 年 10 月 24 日
--
在上一篇文章中,我们探讨了 AlphaFold 2 与 BERT 是如何通过 transformer 架构相互联系的。在这篇文章中,我们将学习最新版本的 AlphaFold 3(以下简称 AlphaFold)是如何与 DALLE 2(以下简称 DALLE)更加相似的,并深入了解它的架构和训练方面的其他变化。
这有什么联系?
AlphaFold 和 DALLE 是另一个例子,展示了不同的应用场景如何通过跨领域的架构学习受益。DALLE 是一个文本生成图像的模型,能够根据文本提示生成图像。而 AlphaFold 3 是一个预测生物分子相互作用的模型。这两个模型的应用听起来似乎差异极大,但它们都依赖于扩散模型架构。
因为推理图像和文本比生物分子相互作用更直观,我们将首先探讨 DALLE 的应用。然后我们将了解 AlphaFold 如何应用相同的概念。
扩散模型
理解扩散模型的隐喻:考虑追踪一滴染料在一杯水中的起源。当染料扩散时,它在液体中随机移动,直到均匀分布。要追溯到初始滴落的位置,你必须一步步重建它的路径,因为每一次移动都依赖于前一步。如果你一遍又一遍地重复这个实验,你将能够建立一个模型来预测染料的运动。
更具体地说,扩散模型被训练以预测并去除数据集中的噪声。然后在推理过程中,模型使用随机噪声生成新的样本。该架构包括三个核心组件:前向过程、逆向过程和采样过程。前向过程从训练数据开始,并在每个时间步加入噪声。正如你所预料的那样,逆向过程在每个步骤去除噪声。采样过程(或推理)执行逆向过程,使用训练好的模型和噪声调度,将初始的随机噪声输入转化为一个结构化的数据样本。

前向和逆向过程的简化示意图,其中一个像素化的心形图像添加了噪声,然后又去除噪声恢复到原始形状。(作者创作)
DALLE 和扩散
DALLE 将扩散模型架构融入两个主要组件:先验和解码器,并移除了其前身的自回归模块。先验模型接收由CLIP生成的文本嵌入(CLIP 是一个在图像和文本说明数据集上训练的模型,称为对比语言-图像预训练),并创建图像嵌入。在训练过程中,先验会接收一个文本嵌入和一个带噪声的图像嵌入版本。先验学习一步步去噪图像嵌入。这个过程使模型能够学习图像嵌入的分布,表示给定文本提示的可能图像的多样性。
解码器从结果图像嵌入生成图像,起始于随机噪声图像。逆扩散过程在每个时间步根据噪声调度,使用图像嵌入(来自先验)迭代地去除图像中的噪声。时间步嵌入告诉模型当前去噪过程的阶段,帮助其根据与最终步骤的接近程度调整去除的噪声量。
虽然 DALLE 2 使用扩散模型来生成图像,但其前身 DALLE 1 依赖自回归方法,基于文本提示依次预测图像标记。这种方法在计算效率上远不如扩散模型,需要更复杂的训练和推理过程,难以生成高分辨率图像,且常常产生伪影。
它的前身也没有使用 CLIP。相反,DALLE 1 直接学习文本-图像表示。CLIP 嵌入的引入统一了这些表示,使得文本到图像的表示更加稳健。

来自 OpenAI 团队的基于 CLIP 潜在空间的层次文本条件图像生成中 DALLE 2 架构的高层次概述
AlphaFold 如何使用扩散
虽然 DALLE 使用扩散帮助生成详细的视觉内容,但 AlphaFold 在生物分子结构预测中利用了类似的原理(不再仅仅是蛋白质折叠!)。
AlphaFold 2 不是一个生成模型,因为它是从给定的输入序列直接预测结构。由于引入了扩散模块,AlphaFold 3 是一个生成模型。就像 DALLE 一样,噪声被采样,然后反复去噪,最终生成一个结构。
扩散模块通过替代结构模块被纳入模型。这一架构变化大大简化了模型,因为结构模块预测氨基酸特定的框架和侧链扭转角,而扩散模块预测原子的原始坐标。这消除了推理过程中的几个中间步骤。

用于推理的 AF3 架构,展示了扩散模块的位置。从AlphaFold 3 对生物分子相互作用的精确结构预测
移除这些中间步骤的动力源于本次模型迭代的训练数据范围大幅增加。AlphaFold 2 只在蛋白质结构上进行训练,而 AlphaFold 3 是一个“多模态”模型,能够预测包括蛋白质、核酸、小分子、离子和修饰残基在内的复合物的联合结构。如果模型仍然使用结构模块,它将需要大量关于化学键和立体化学的复杂规则来创建有效的结构。
扩散不需要这些规则的原因是它可以在粗粒度和细粒度层面上应用。对于高噪声水平,模型侧重于捕捉全局结构,而在低噪声水平时,它会细化细节。当噪声最小时,模型会优化结构的局部细节,例如原子的位置和它们的取向,这对精确的分子建模至关重要。这意味着该模型可以轻松处理不同类型的化学成分,而不仅仅是标准的氨基酸或蛋白质结构。
与不同类型的化学成分进行合作的好处似乎在于,模型可以从其他类型的结构(如蛋白质-配体界面)中学到更多关于蛋白质结构的信息。看起来,整合多样化的数据类型有助于模型在不同任务之间更好地推广。这种改进类似于Gemini在结合图像和视频数据后,其文本理解能力的提升。
AlphaFold 的其他重要变化
MSA(多序列比对)的作用显著降低。 AF2 evoformer 被更简单的 pairformer 模块取代(从 48 个模块减少到 4 个模块)。正如你可能还记得我上一篇文章中提到的,MSA 曾被认为有助于模型学习哪些氨基酸序列在进化上是重要的。实验变化表明,降低 MSA 的重要性对模型准确性的影响有限。
必须对抗幻觉问题。 生成模型非常令人兴奋,但它们也带来了幻觉的问题。研究人员发现,模型会在无结构区域虚构出看起来合理的结构。为了解决这个问题,使用了交叉蒸馏方法,通过AlphaFold-Multimer (v.2.3)增强了训练数据,加入了预测的结构。交叉蒸馏方法帮助模型更好地区分结构化和无结构区域。这有助于模型理解何时避免添加虚构的细节。
某些相互作用比其他相互作用更容易预测。 为每种交互类型调整了采样概率,即对较简单的类型采样较少,这些类型可以通过相对较少的训练步骤来学习,而复杂类型则相反。这有助于避免在不同类型之间的欠拟合和过拟合。

初始训练和微调阶段的训练曲线展示了不同类别如何在不同训练步骤下达到最佳表现。因此,训练数据被子抽样,以防止按类别出现欠拟合和过拟合。来自AlphaFold 3 的生物分子相互作用的准确结构预测
高级学习
DALLE 2 和 AlphaFold 3 通过使用扩散模块对其前代模型进行了改进,同时简化了它们的架构。
在更广泛的数据类型上进行训练,使生成模型更加稳健。通过多样化 AlphaFold 训练数据集中的结构类型,模型能够改善蛋白质折叠预测,并能够推广到其他生物分子相互作用。类似地,用于训练 CLIP 的文本-图像对的多样性也改善了 DALLE。
噪声调度是训练扩散模型时一个重要的调节参数。调整噪声的强度会影响模型学习粗略和精细细节的能力。这样做大大简化了 AlphaFold 的设计,因为它消除了对侧链扭转角度等中间预测的需求。
再次感谢你的阅读,敬请期待下期内容。在那之前,继续学习和探索吧。
机器学习初学者如何找到论文(以及在哪里找到)
从会议到调查
·发表于Towards Data Science ·阅读时间:7 分钟·2024 年 12 月 22 日
--
我们正生活在一个超级时代。每年、每月、每周,机器学习研究的最新进展都在不断宣布。添加到 arXiv 的(机器学习)论文数量也在快速增长。仅在去年十月,计算机科学类别就有超过 11,000 篇论文被新增。

由Liam Charmer拍摄,图片来自Unsplash
同样,大型机器学习会议提交的论文数量也在不断增长——实际上多到为了确保公正的审稿过程,提交的作者需要为其他论文担任审稿人(这被称为互评)。
每篇论文可能介绍新的研究成果、新的方法、新的数据集或基准。作为机器学习的初学者,甚至开始都很困难:信息量太庞大。在上一篇文章中,我论述了为什么机器学习初学者应该阅读论文。精髓是,好的研究论文是自包含的讲座,可以磨练分析思维。
在本文中,我为初学者提供了一些关于如何以及在哪里找到有趣论文的思路,这是我之前没有完全展开的内容。通过 7 个步骤,我将引导你完成寻找和阅读有趣论文的可能过程。
步骤 1:确定你的兴趣领域
如何以及为什么使用 LLM 进行基于块的信息检索
·发表于 Towards Data Science ·阅读时长 9 分钟·2024 年 10 月 28 日
--

检索流程 — 图像来自作者
在本文中,我旨在解释为什么以及如何使用大语言模型(LLM)进行基于块的信息检索。
我以 OpenAI 的 GPT-4 模型为例,但这种方法可以应用于任何其他大语言模型,如 Hugging Face、Claude 等提供的模型。
每个人都可以免费访问这篇文章。
标准信息检索的考虑事项
主要概念是将一组文档(文本块)存储在数据库中,可以根据某些过滤条件进行检索。
通常,使用某种工具来实现混合搜索(例如 Azure AI Search、LlamaIndex 等),该工具可以:
使用诸如 TF-IDF(例如 BM25)之类的术语频率算法执行基于文本的搜索;
进行基于向量的搜索,计算向量之间的距离(通常是余弦相似度),即使使用不同的术语,也能识别相似的概念;
结合步骤 1 和步骤 2 的元素,通过加权突出最相关的结果。

图 1 - 默认混合搜索流程 — 图像来自作者
图 1 显示了经典的检索流程:
用户向系统提出问题:“我想谈谈巴黎”;
系统接收问题,将其转换为嵌入向量(使用与摄取阶段相同的模型),并找到距离最小的文本块;
系统还执行基于频率的文本搜索;
来自两个过程的块经过进一步评估,并根据排名公式重新排序。
这个解决方案取得了良好的效果,但也有一些局限性:
并非所有相关的块都会被检索到;
有时,一些块包含异常值,影响最终的回答。
典型检索问题的示例
让我们考虑“documents”数组,它代表了一个知识库的示例,这个知识库可能导致不正确的块选择。
documents = [
"Chunk 1: This document contains information about topic A.",
"Chunk 2: Insights related to topic B can be found here.",
"Chunk 3: This chunk discusses topic C in detail.",
"Chunk 4: Further insights on topic D are covered here.",
"Chunk 5: Another chunk with more data on topic E.",
"Chunk 6: Extensive research on topic F is presented.",
"Chunk 7: Information on topic G is explained here.",
"Chunk 8: This document expands on topic H. It also talk about topic B",
"Chunk 9: Nothing about topic B are given.",
"Chunk 10: Finally, a discussion of topic J. This document doesn't contain information about topic B"
]
假设我们有一个 RAG 系统,包含一个具有混合搜索能力的向量数据库和基于 LLM 的提示,用户向其提出如下问题:“我需要了解一些关于话题 B 的内容。”
如图 2 所示,搜索还返回了一个不正确的块,虽然语义上相关,但并不适合回答问题,在某些情况下,甚至可能会让负责提供答案的 LLM 感到困惑。

图 2 — 可能导致错误的检索示例 — 作者提供的图像
在这个示例中,用户请求关于“话题 B”的信息,搜索返回的块包括“该文档扩展了话题 H,也涉及话题 B”以及“与话题 B 相关的见解可以在这里找到”,还有一个块指出,“没有关于话题 B 的信息”。
尽管这符合混合搜索的预期行为(因为块引用了“话题 B”),但这并不是期望的结果,因为第三个块被返回时并没有识别到它对于回答问题并无帮助。
检索没有产生预期的结果,不仅仅因为 BM25 搜索在第三个块中找到了“话题 B”这一术语,还因为向量搜索得到了较高的余弦相似度。
要理解这一点,请参见图 3,它显示了使用 OpenAI 的 text-embedding-ada-002 模型进行嵌入时,块相对于问题的余弦相似度值。

图 3 — 使用 text-embedding-ada-002 的余弦相似度 — 作者提供的图像
很明显,“块 9”的余弦相似度值是最高的,而且在该块与引用“话题 B”的块 10 之间,还有一个块 1,它没有提到“话题 B”。
即使使用不同的方法衡量距离,这种情况也不会改变,正如在闵可夫斯基距离的情况下所见。
利用 LLM 进行信息检索:一个示例
我将描述的解决方案灵感来源于我在 GitHub 仓库中发布的内容 github.com/peronc/LLMRetriever/。
这个想法是让 LLM 分析哪些块对回答用户的问题有用,而不是通过排名返回的块(如 RankGPT 的情况),而是直接评估所有可用的块。

图 4 - LLM 检索管道 — 作者提供的图像
总结来说,如图 4 所示,系统接收一组需要分析的文档,这些文档可以来自任何数据源,如文件存储、关系数据库或向量数据库。
这些块被分成组,并通过与块总量成比例的线程数并行处理。
每个线程的逻辑包括一个循环,该循环遍历输入的块,为每个块调用 OpenAI 提示,以检查它与用户问题的相关性。
提示返回该块以及一个布尔值:true 如果它相关,false 如果它不相关。
让我们开始编码吧 😊
为了简化代码的解释,我将使用documents数组中存在的块(在结论部分我将引用一个实际案例)。
首先,我导入必要的标准库,包括 os、langchain 和 dotenv。
import os
from langchain_openai.chat_models.azure import AzureChatOpenAI
from dotenv import load_dotenv
接下来,我导入了我的 LLMRetrieverLib/llm_retrieve.py 类,它提供了几个执行分析所需的重要静态方法。
from LLMRetrieverLib.retriever import llm_retriever
接下来,我需要导入使用 Azure OpenAI GPT-4 模型所需的必要变量。
load_dotenv()
azure_deployment = os.getenv("AZURE_DEPLOYMENT")
temperature = float(os.getenv("TEMPERATURE"))
api_key = os.getenv("AZURE_OPENAI_API_KEY")
endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
api_version = os.getenv("API_VERSION")
接下来,我继续初始化 LLM。
# Initialize the LLM
llm = AzureChatOpenAI(api_key=api_key, azure_endpoint=endpoint, azure_deployment=azure_deployment, api_version=api_version,temperature=temperature)
我们准备开始:用户提出问题以收集有关主题 B的额外信息。
question = "I need to know something about topic B"
在这一点上,相关块的搜索开始了,为此,我使用了 LLMRetrieverLib/retriever.py 库中的函数 llm_retrieve.process_chunks_in_parallel,该库也位于同一个代码库中。
relevant_chunks = LLMRetrieverLib.retriever.llm_retriever.process_chunks_in_parallel(llm, question, documents, 3)
为了优化性能,函数 llm_retrieve.process_chunks_in_parallel 采用多线程来分配块分析到多个线程。
主要思想是为每个线程分配从数据库中提取的一个子集的块,并让每个线程根据用户的问题分析这些块的相关性。
处理结束时,返回的块完全符合预期:
['Chunk 2: Insights related to topic B can be found here.',
'Chunk 8: This document expands on topic H. It also talk about topic B']
最后,我请求 LLM 为用户的问题提供答案:
final_answer = LLMRetrieverLib.retriever.llm_retriever.generate_final_answer_with_llm(llm, relevant_chunks, question)
print("Final answer:")
print(final_answer)
以下是 LLM 的回答,虽然内容相关,但由于这些块虽然相关,但在主题 B的内容上并不充分,因此回答显得有些简单:
Topic B is covered in both Chunk 2 and Chunk 8\.
Chunk 2 provides insights specifically related to topic B, offering detailed information and analysis.
Chunk 8 expands on topic H but also includes discussions on topic B, potentially providing additional context or perspectives.
评分场景
现在让我们尝试提出相同的问题,但使用基于评分的方法。
我请求 LLM 为每个块分配 1 到 10 的评分,以评估每个块与问题的相关性,仅考虑相关性高于 5 的块。
为此,我调用函数 llm_retriever.process_chunks_in_parallel,传入三个额外的参数,分别表示:将应用评分,必须大于或等于 5 才能被认为有效,并且我希望打印出带有各自评分的块。
relevant_chunks = llm_retriever.process_chunks_in_parallel(llm, question, documents, 3, True, 5, True)
带评分的检索阶段产生了如下结果:
score: 1 - Chunk 1: This document contains information about topic A.
score: 1 - Chunk 7: Information on topic G is explained here.
score: 1 - Chunk 4: Further insights on topic D are covered here.
score: 9 - Chunk 2: Insights related to topic B can be found here.
score: 7 - Chunk 8: This document expands on topic H. It also talk about topic B
score: 1 - Chunk 5: Another chunk with more data on topic E.
score: 1 - Chunk 9: Nothing about topic B are given.
score: 1 - Chunk 3: This chunk discusses topic C in detail.
score: 1 - Chunk 6: Extensive research on topic F is presented.
score: 1 - Chunk 10: Finally, a discussion of topic J. This document doesn't contain information about topic B
结果与之前相同,但评分很有趣 😊。
最后,我再次请求 LLM 为用户的问题提供答案,结果与之前的类似:
Chunk 2 provides insights related to topic B, offering foundational information and key points.
Chunk 8 expands on topic B further, possibly providing additional context or details, as it also discusses topic H.
Together, these chunks should give you a well-rounded understanding of topic B. If you need more specific details, let me know!
注意事项
这种检索方法是通过一些先前的经验逐步发展出来的。
我注意到,纯粹基于向量的搜索虽然能产生有用的结果,但当嵌入在非英语语言中时,通常不够充分。
使用 OpenAI 处理意大利语句子时,明显发现术语的标记化常常不正确;例如,“canzone”一词在意大利语中意味着“歌曲”,但它被错误地标记为两个不同的词:“can”和“zone”。
这导致了嵌入数组的构建远未达到预期效果。
在这种情况下,混合搜索(同时结合了术语频率计数)能够改善结果,但它们并不总是如预期般准确。
因此,这种检索方法可以以下列方式使用:
作为主要搜索方法: 在这种方法中,数据库会根据过滤条件(例如元数据过滤器)查询所有片段或其子集;
作为混合搜索中的一种优化:(这是 RankGPT 使用的相同方法)通过这种方式,混合搜索可以提取大量片段,系统可以过滤它们,确保只有相关的片段传递到 LLM,并且遵守输入令牌的限制;
作为备选方案: 在混合搜索没有得到期望结果的情况下,可以分析所有片段。
让我们讨论成本和性能
当然,闪闪发光的并不都是金子,因为必须考虑响应时间和成本。
在一个实际的使用案例中,我从一个关系型数据库中检索了由我的LLMChunkizerLib/chunkizer.py库语义分割的 95 个文本片段,这些片段来自两个 Microsoft Word 文档,总共 33 页。
对 95 个片段与问题相关性的分析是通过在本地 PC 上调用 OpenAI 的 API 完成的,该 PC 的带宽并不保证,平均约为 10Mb,导致响应时间从 7 秒到 20 秒不等。
自然地,在云系统中,或通过在 GPU 上使用本地 LLM,这些时间可以显著减少。
我认为,关于响应时间的考虑非常主观:在某些情况下,花更多时间来提供正确答案是可以接受的,而在其他情况下,至关重要的是不要让用户等待太久。
类似地,关于成本的考虑也非常主观,因为必须从更广泛的角度评估,是否提供尽可能准确的答案更为重要,还是一些错误是可以接受的。
在某些领域,由于不正确或遗漏的答案导致的声誉损害,可能超过了令牌的费用。
此外,尽管 OpenAI 和其他提供商的成本在近年来一直在稳步下降,但那些已经拥有基于 GPU 的基础设施的人(可能是由于需要处理敏感或机密数据)可能更倾向于使用本地的 LLM。
结论
总之,我希望能提供我的视角,分享如何进行检索的思考。
如果没有别的,我的目标是提供帮助,并可能激励他人在自己的工作中探索新方法。
记住,信息检索的世界广阔无垠,只要稍加创意和借助合适的工具,我们就能以我们从未想象过的方式发掘知识!
如果你想进一步讨论,欢迎通过LinkedIn与我联系。
GitHub 仓库可以在这里找到:


浙公网安备 33010602011771号