docs-merge-02
TowardsDataScience 2024 中文翻译(三)
VLM 简介:计算机视觉模型的未来
使用 VLM 构建一个准确度提高 28%的多模态图像搜索引擎。
·发布于Towards Data Science ·12 分钟阅读·2024 年 11 月 6 日
--
直到最近,AI 模型的范围较窄,只能理解语言或特定图像,但很少同时理解两者。
在这方面,像 GPT 这样的通用语言模型是一次巨大的飞跃,因为我们从专业化的模型转向了通用但更强大的模型。
但即使语言模型有所进步,它们仍然与计算机视觉领域分离,每个领域在孤立的状态下发展,未能弥合这一鸿沟。试想一下,如果你只能听而不能看,或者只能看而不能听,会发生什么。
我的名字是 Roman Isachenko,我是 Yandex 计算机视觉团队的一员。
在本文中,我将讨论视觉语言模型(VLM),我相信它们是复合 AI 系统的未来。
我将解释开发图像搜索多模态神经网络的基础知识和训练过程,并探讨使这一切成为可能的设计原则、挑战和架构。
到最后,我还将展示我们如何使用 AI 驱动的搜索产品来处理图像和文本,以及引入 VLM 后发生了哪些变化。
让我们开始吧!
什么是 VLM?
拥有数十亿甚至数百亿参数的 LLM 已经不再是新鲜事物。
我们到处都能看到它们!
LLM 研究的下一个重点是更多地倾向于开发多模态模型(全能模型)——能够理解和处理多种数据类型的模型。

多模态模型(图像由作者提供)
正如其名称所示,这些模型不仅能处理文本,还能分析图像、视频和音频。
那我们为什么要这么做呢?
万能工匠,什么都能做,什么都不精,往往比某一领域的专家更有优势。
近年来,我们看到一个趋势,即通用方法正在主导特定的方法。
想一想。
今天的基于语言的机器学习模型已经相对先进且具有通用性。一个模型可以进行翻译、摘要、识别语音标签,等等。

通用 NLP 模型(图片由作者提供)
但早些时候,这些模型往往是特定任务的(现在我们也有这种模型,但比以前少了)。
-
一个专门用于翻译的模型。
-
一个专门用于摘要的模型,等等。
换句话说,今天的自然语言处理模型(特别是大型语言模型)可以承担多种任务,这些任务以前需要开发高度特定的解决方案。
第二,这种方法让我们能够以指数级的速度扩展可用于模型训练的数据,这对于有限的文本数据量至关重要。然而,早期的做法是需要任务特定的数据:
-
一个专门的翻译标注数据集。
-
一个专门的摘要数据集,等等。
第三,我们相信训练一个多模态模型能够提升每种数据类型的表现,就像它对人类的表现一样。
对于这篇文章,我们将“黑盒”概念简化为一种情景:模型接收一张图像和一些文本(我们称之为“指令”)作为输入,并仅输出文本(响应)。
结果是,我们得到了一个更简化的过程,如下所示:

一个简化的多模态模型(图片由作者提供)
我们将讨论图像辨识模型,它们能够分析和解释图像所描绘的内容。
在深入技术细节之前,请考虑这些模型能够解决的问题。
以下是一些示例:

任务示例(图片由作者提供)
-
左上角图片:我们让模型描述这张图像。这个要求是用文本指定的。
-
中上部图片:我们让模型解释这张图像。
-
右上角图片:我们让模型解释这张图像,并告诉我们如果我们遵循这个标志,会发生什么。
-
底部图片:这是最复杂的例子。我们给模型一些数学问题。从这些例子中,你可以看到任务范围是如此广泛和多样化。
VLMs 是计算机视觉领域的一个新前沿,它们能够在零样本和单样本模式下解决各种基本的计算机视觉任务(分类、检测、描述)。
虽然 VLMs 可能尚未在每个标准任务中表现得特别出色,但它们正在迅速进步。
现在,让我们了解它们是如何工作的。
VLM 架构
这些模型通常有三个主要组件:

VLM 的简化表示(图片由作者提供)
-
LLM — 一种文本模型(在我们的例子中是 YandexGPT),它不理解图像。
-
图像编码器 — 一种图像模型(CNN 或视觉变换器),它不理解文本。
-
适配器 — 一个充当中介的模型,确保 LLM 和图像编码器能够良好配合。
该流程相当直接:
-
将图像输入图像编码器。
-
将图像编码器的输出转换为某种表示,供适配器使用。
-
将适配器的输出整合到 LLM 中(下面将详细介绍)。
-
当图像被处理时,将文本指令转换为一系列标记并输入 LLM。
关于适配器的更多信息
适配器是模型中最令人兴奋和最重要的部分,因为它准确地促进了 LLM 和图像编码器之间的通信/交互。
有两种类型的适配器:
-
基于提示的适配器
-
基于交叉注意力的适配器
基于提示的适配器最早在 BLIP-2 和 LLaVa 模型中提出。
这个想法很简单且直观,从名字本身就可以看出。
我们将图像编码器的输出(一个向量、一系列向量或一个张量——取决于架构)转换为一系列向量(标记),然后将其输入到 LLM 中。你可以采用一个简单的 MLP 模型,带有几层,并将其作为适配器,结果可能会非常好。
基于交叉注意力的适配器在这方面要复杂一些。
它们已被应用于最近关于 Llama 3.2 和 NVLM 的论文中。
这些适配器旨在将图像编码器的输出转换为可在 LLM 的交叉注意力块中用作键/值矩阵的形式。此类适配器的示例包括像 感知器重采样器 或 Q‑former 这样的变换器架构。

基于提示的适配器(左)和基于交叉注意力的适配器(右)(图像来源:作者)
基于提示的适配器(左)和基于交叉注意力的适配器(右)
两种方法各有优缺点。
目前,基于提示的适配器 提供更好的结果,但它们会削弱 LLM 的大量输入上下文,而这很重要,因为 LLM 的上下文长度有限(目前如此)。
基于交叉注意力的适配器不会削弱 LLM 的上下文,但需要大量的参数才能达到良好的质量。
VLM 训练
架构确定后,让我们深入探讨训练过程。
首先,请注意,VLM 不是从头开始训练的(尽管我们认为这只是时间问题),而是基于预训练的 LLM 和图像编码器构建的。
使用这些预训练模型,我们在多模态文本和图像数据上对 VLM 进行微调。
这个过程包括两个步骤:
-
预训练
-
对齐:SFT + RL(可选)

VLM 训练过程(图像来源:作者)
注意这些阶段如何类似于 LLM 的训练?
这是因为这两个过程在概念上是相似的。让我们简要回顾一下这些阶段。
VLM 预训练
这是我们在这个阶段希望实现的目标:
-
将文本和图像模态连接起来(记住,我们的模型包含一个我们之前没有训练过的适配器)。
-
将世界知识加载到我们的模型中(这些图像包含许多细节,例如 OCR 技能)。
预训练 VLM 时使用了三种类型的数据:
- 交替预训练:这与 LLM 的预训练阶段相似,在这个阶段,我们通过输入网络文档来训练模型执行下一个 token 预测任务。对于 VLM 的预训练,我们选择带有图像的网络文档,训练模型预测文本。这里的关键区别在于,VLM 会同时考虑页面上的文本和图像。这样的数据很容易获得,因此这种类型的预训练不难扩展。然而,数据质量并不高,提升其质量证明是一项艰巨的任务。

交替预训练数据集(图片由作者提供)
图像-文本配对预训练:我们训练模型执行一个特定任务:为图像生成标题。你需要一个包含相关描述的大量图像数据集来完成这项任务。这种方法更为流行,因为许多此类数据集被用于训练其他模型(如文本生成图像、图像到文本检索)。

图像-文本配对预训练数据集(图片由作者提供)
基于指令的预训练:在推理过程中,我们将图像和文本输入模型。为什么不从一开始就用这种方式训练模型呢?这正是基于指令的预训练所做的:它在一个庞大的图像-指令-答案三元组数据集上训练模型,即使数据并不总是完美的。

基于指令的预训练数据集(图片由作者提供)
要正确训练一个 VLM 模型需要多少数据是一个复杂的问题。在这一阶段,所需的数据集大小可以从几百万到几十亿(幸运的是,不是上万亿!)个样本不等。
我们团队使用了基于指令的预训练,样本量为几百万。然而,我们相信交替预训练具有巨大的潜力,我们正在积极朝着这个方向努力。
VLM 对齐
一旦预训练完成,就可以开始对齐阶段了。
它包括 SFT 训练和可选的 RL 阶段。由于我们只有 SFT 阶段,我将重点介绍这个。
尽管如此,最近的论文(例如 这篇 和 这篇)通常在 VLM 上加入 RL 阶段,使用与 LLM 相同的方法(DPO 和方法名称首字母的各种修改)。
无论如何,回到 SFT。
严格来说,这一阶段与基于指令的预训练相似。
区别在于我们注重高质量的数据,具备适当的响应结构、格式化和强大的推理能力。
这意味着模型必须能够理解图像并对其进行推断。理想情况下,它应对没有图像的文本指令做出同样好的响应,因此我们还将添加高质量的纯文本数据。
最终,这个阶段的数据通常在数十万到几百万个样本之间。在我们的案例中,这个数字大约在六位数。
质量评估
让我们讨论一下评估 VLM 质量的方法。我们使用两种方法:
-
计算开源基准上的指标。
-
通过并排(SBS)评估来比较模型,在这种评估中,评估者比较两个模型的响应并选择更好的一个。
第一种方法允许我们在特定的数据子集上测量代理指标(例如分类任务中的准确性)。
然而,由于大多数基准测试都是英语的,因此它们不能用于比较其他语言(如德语、法语、俄语等)训练的模型。
虽然可以使用翻译,但翻译模型引入的错误使得结果不可靠。
第二种方法可以更深入地分析模型,但需要细致(且昂贵)的手动数据标注。
我们的模型是双语的,能够用英语和俄语进行响应。因此,我们可以使用英语开源基准并进行并排比较。
我们信任这种方法并在其中投入了大量资金。以下是我们要求评估者评估的内容:
-
语法
-
可读性
-
完整性
-
与指令的相关性
-
错误(逻辑错误和事实错误)
-
幻觉
我们力求评估模型技能的一个完整且多样的子集。
以下饼图展示了我们 SbS 评估任务的分配情况。

质量评估任务分配(图片来源:作者)
这总结了 VLM 基础知识的概述,以及如何训练一个模型并评估其质量。
流水线架构
今年春天,我们为 Neuro(一个由 AI 驱动的搜索产品)增加了多模态功能,允许用户通过文本和图像提问。
直到最近,它的底层技术还不是真正的多模态。
这里展示的是这个流水线之前的样子。

流水线架构(图片来源:作者)
这个图看起来很复杂,但一旦将其分解成步骤后,实际上是直观的。
以下是这个过程之前的样子
-
用户提交一张图像和一个文本查询。
-
我们将图像发送到我们的视觉搜索引擎,后者会返回关于该图像的大量信息(标签、识别的文本、信息卡片)。
-
我们使用改写器(一个微调的 LLM)结合这些信息和原始查询来构造文本查询。
-
使用改写后的文本查询,我们使用 Yandex 搜索来检索相关文档(或摘录,我们称之为信息上下文)。
-
最后,利用所有这些信息(原始查询、视觉搜索信息、重述的文本查询和信息上下文),我们通过生成器模型(另一个经过微调的 LLM)生成最终响应。
完成!
如你所见,我们过去依赖两个单模态 LLM 和我们的视觉搜索引擎。这个方案在少量查询样本中表现良好,但也有其局限性。
下面是一个例子(虽然有些夸张),说明事情可能如何出错。

两个单模态 LLM 的问题(图像来源:作者)
在这里,重述器接收视觉搜索服务的输出,但根本无法理解用户的原始意图。
反过来,LLM 模型对图像一无所知,生成了一个错误的搜索查询,同时获取了关于哈巴狗和苹果的标签。
为了提高我们的多模态响应质量并允许用户提出更复杂的问题,我们将 VLM 引入了我们的架构中。
更具体地说,我们进行了两项主要修改:
-
我们用 VLM 重述器替代了 LLM 重述器。从本质上讲,我们开始将原始图像与视觉搜索引擎的文本一起输入到重述器中。
-
我们在流程中添加了一个独立的 VLM 字幕生成器。这个模型提供了图像描述,我们将其作为最终生成器的信息上下文。
你可能会问
为什么不让生成器本身基于 VLM 呢?
这是个好主意!
但有一个问题。
我们的生成器训练继承自 Neuro 的文本模型,该模型经常更新。
为了更快、更方便地更新流程,我们引入了一个独立的 VLM 模块,这样做要容易得多。
此外,这个设置同样有效,下面展示了结果:

在 AI 驱动的搜索中使用 VLM(图像来源:作者)
训练 VLM 重述器和 VLM 字幕生成器是两个独立的任务。
为此,我们使用前面提到的 VLM,并对其进行了针对这些特定任务的微调。
微调这些模型需要收集成千上万的独立训练数据集。
我们还必须对基础设施进行重大改进,以提高流程的计算效率。
评估质量
现在,进入重要的问题:
引入 VLM 到这个相对复杂的流程中,真的有改善吗?
简而言之,是的,确实如此!
我们进行了并行测试,以衡量新流程的性能,并将我们之前的 LLM 框架与新的 VLM 框架进行了比较。
这个评估与之前讨论的核心技术评估类似。然而,在这种情况下,我们使用了一组不同的图像和查询,这些查询更贴近用户可能提出的需求。
下面是这个桶中聚类的大致分布。

聚类分布(图像来源:作者)
我们的离线并行评估显示,我们大大提高了最终响应的质量。
VLM 管道显著提高了响应质量,涵盖了更多用户场景。

VLM 与 LLM 在神经网络中的准确性对比(图源:作者)
我们还希望在真实观众中测试这些结果,看看我们的用户是否会注意到我们认为能够改善产品体验的技术变化。
因此,我们进行了一个在线拆分测试,将我们的 LLM 管道与新的 VLM 管道进行了对比。初步结果显示了以下变化:
-
包含图片的指令数量增加了 17%。
-
会话数量(用户连续输入多个查询)增长了 4.5%。
重申一下之前的观点,我们坚信 VLMs 是计算机视觉模型的未来。
VLMs 已经能够解决许多开箱即用的问题。经过一点微调,它们完全可以提供最先进的质量。
感谢阅读!
数据科学的 Docker 直观指南
了解 Docker 的基本概念、常用命令,以及如何将机器学习应用容器化
·发表于Towards Data Science ·阅读时长 7 分钟·2024 年 5 月 13 日
--

在担任数据科学家的工作中,编写可以在任何操作系统中运行且具备所有必要依赖的代码,并准备好在云端部署是非常重要的。尽管做了很多努力,它可能仍然无法正常工作,并且你可能会浪费时间去理解问题所在。
我们可以使用什么工具来避免这些困扰?Docker 是解决您问题的答案。使用 Docker,您可以轻松获得一个强大的数据科学项目环境,而不会让自己陷入困境。
在这篇文章中,我将解释 Docker 的主要概念、最常用的命令,以及一个快速的机器学习应用容器化示例。让我们开始吧!
目录:
-
什么是 Docker?
-
Docker 的基本概念
-
虚拟机与容器
-
设置 Docker
-
将机器学习应用容器化
-
Docker 命令总结
-
Docker 的局限性
什么是 Docker?
直观指南:将 SQL 与 Python 融合以进行数据科学
学习掌握 MySQL 连接器,这是一个使 Python 能够与 MySQL 数据库交互的库
·发表于 Towards Data Science ·10 分钟阅读·2024 年 9 月 21 日
--

照片来源:ThisisEngineering 来自 Unsplash
在我最近的工作经历中,我发现了两个重要的数据管理和分析工具——Python 和 SQL 之间的显著协同效应。
如果你已经沉浸在数据科学的世界中,你可能已经意识到 Python 对于任何数据科学家都是不可或缺的,因为它拥有广泛的库生态系统,支持诸如数据操作、数据可视化和建模等任务。
虽然 Python 在处理复杂数据流程时能够用少量代码实现出色的效果,但 SQL 在高效管理结构化数据、执行查询以及进行读取和修改数据操作方面依然无可匹敌。
在本文中,我将展示一个实际的使用案例,突出展示将 Python 和 SQL 融合以处理 MySQL 数据的好处。让我们开始吧!
目录:
-
什么是关系型数据库?
-
设置 MySQL
-
开始使用 MySQL Workbench
-
将 Python 连接到 MySQL 数据库
-
使用 Python 执行常见 SQL 操作
强化学习的直观介绍,第一部分
以适合初学者的方式探索流行的强化学习环境
·发表于 Towards Data Science ·阅读时长 16 分钟·2024 年 9 月 6 日
--
这是一个关于强化学习概念的系列教程,使用 OpenAI Gymnasium Python 包中的环境进行演示。本文将涵盖理解并实现 Q 学习以解决“Frozen Lake”环境所需的高阶概念。
祝学习愉快 ❤ !

一片微笑的湖(图片由作者拍摄,使用 OpenAI Gymnasium 的 Frozen Lake 环境制作)
让我们通过将强化学习与日常生活中熟悉的例子进行对比,来探索这一领域。
纸牌游戏 — 想象你在玩一场纸牌游戏:当你刚学会这款游戏时,规则可能不清楚。你打出的牌可能不是最优的,使用的策略也可能不完美。随着你玩得更多,也许赢了几局,你会学到什么时候打什么牌,哪些策略比其他策略更好。有时,虚张声势可能更好,但其他时候你应该弃牌;将一张万能牌留到以后使用可能比立即打出它更好。通过一系列的经验和奖励,你能学到最佳的行动方案。你的经验来自于玩游戏,而当你的策略奏效时,你会得到奖励,也许这能带来胜利或新的高分。

一局纸牌游戏(图片由作者从 Google 的纸牌游戏中截图)
经典条件作用 — 通过在喂狗之前按铃,伊凡·巴甫洛夫展示了外部刺激与生理反应之间的联系。狗被条件化为将铃声与食物联系起来,因此它在听到铃声时开始流口水,即使没有食物存在。虽然这严格来说不是强化学习的例子,但通过反复的经验,狗在听到铃声时被奖励食物,最终学会了将二者联系起来。
反馈控制 — 一种在工程学科中应用的控制理论,该理论认为系统的行为可以通过向控制器提供反馈来进行调整。作为反馈控制的一个子集,强化学习需要来自当前环境的反馈来影响我们的行动。通过提供形式为奖励的反馈,我们可以激励代理选择最优的行动方案。
代理、状态与环境
强化学习是一个基于过去经验积累与可量化奖励的学习过程。 在每个例子中,我们展示了我们的经验如何影响我们的行动,以及如何通过强化奖励与反应之间的正向联系来解决某些问题。如果我们能学会将奖励与最优行动联系起来,我们就能推导出一种算法,选择那些带来最高可能奖励的行动。
在强化学习中,“学习者”被称为代理。代理与环境进行交互,通过其行动,根据所收到的奖励来学习什么是“好的”或“坏的”。

强化学习中的反馈循环:代理 -> 行动 -> 环境 -> 奖励,状态(图片来自作者)
为了选择一个行动方案,我们的代理需要一些关于环境的信息,这些信息由状态提供。状态代表了关于环境的当前信息,如位置、速度、时间等。我们的代理并不一定知道当前状态的全部信息。代理在任何给定时间点可获取的信息称为观察值,它包含了状态中某些子集的信息。并非所有状态都是完全可观察的,有些状态可能要求代理只知道环境中发生的事情的一小部分。利用观察值,我们的代理必须根据学习到的经验推测出可能的最佳行动,并尝试选择能够带来最高预期奖励的行动。
在选择了一个行动后,环境将通过提供更新后的状态和奖励来作出反馈。这个奖励将帮助我们判断代理所采取的行动是否是最优的。
马尔可夫决策过程(MDP)
为了更好地表示这个问题,我们可以将其视为马尔可夫决策过程(MDP)。MDP 是一个有向图,其中图中的每条边都有非确定性的属性。在图中的每个可能状态下,我们都有一组可以选择的动作,每个动作都会带来一定的固定回报,并且有一定的转移概率会导致某个后续状态。这意味着,相同的动作每次未必会导致相同的状态,因为从一个状态到另一个状态的转移不仅依赖于动作,还依赖于转移概率。

马尔可夫决策过程的表示(图片来源:作者)
决策模型中的随机性在实际强化学习中是非常有用的,它允许动态环境,其中智能体无法完全控制。像棋类这样的回合制游戏要求对手先走一步,才能轮到你行动。如果对手随机出招,那么棋盘的未来状态是无法保证的,我们的智能体必须在考虑多个可能的未来状态的同时进行决策。当智能体采取某个动作时,下一状态取决于对手的走法,因此由对手可能的走法的概率分布来定义。

动画展示了棋盘的状态也依赖于对手选择的走法(图片来源:作者)
我们的未来状态因此是智能体选择某个动作的概率与对手选择某个动作的转移概率的函数。一般来说,我们可以假设,对于任何环境,从当前状态到后续状态的智能体转移概率由智能体选择某个动作的联合概率和转移到该状态的转移概率表示。
求解 MDP
为了确定最佳行动路径,我们希望为智能体提供大量的经验。通过环境的多次迭代,我们的目标是为智能体提供足够的反馈,使其能够正确地选择最佳行动,尽可能多地选择最佳行动。回想一下我们对强化学习的定义:一种建立在过去经验积累基础上,并伴随可量化回报的学习过程。 在积累了一些经验后,我们希望利用这些经验来更好地选择未来的行动。
我们可以通过利用经验来预测未来状态的预期回报,从而量化我们的经验。随着我们积累更多的经验,我们的预测将变得更加准确,并在经过一定次数的迭代后收敛到真实值。对于我们收到的每一个回报,我们都可以用它来更新我们关于当前状态的一些信息,这样下次遇到该状态时,我们就能更好地估计我们可能会收到的回报。
冰湖问题
让我们考虑一个简单的环境,其中我们的智能体是一个小角色,试图穿越一片冰冻的湖面,表示为一个二维网格。它可以朝四个方向移动:向下、向上、向左或向右。我们的目标是教它从左上角的起始位置移动到地图右下角的结束位置,同时避免冰面上的洞。如果我们的智能体成功到达目的地,我们会给它奖励 +1。对于所有其他情况,智能体将获得 0 奖励,并且如果它掉进一个洞中,探索将立即终止。

冰湖动画(图片来自OpenAI Gymnasium 冰湖文档)
每个状态可以通过它在网格中的坐标位置来表示,起始位置位于左上角,表示为原点 (0, 0),右下角的结束位置表示为 (3, 3)。
最通用的解决方案是应用一些路径寻找算法,以找到从左上角到右下角的最短路径,同时避开冰面上的洞。然而,智能体从一个状态移动到另一个状态的概率并不是确定性的。每次智能体尝试移动时,它有 66% 的机会“滑倒”,并移动到一个随机的相邻状态。换句话说,智能体选择的行动只有 33% 的机会会真正发生。传统的路径寻找算法无法处理引入转移概率的情况。因此,我们需要一个能够处理随机环境的算法,也就是强化学习。
这个问题可以很容易地表示为一个马尔科夫决策过程(MDP),其中我们网格中的每个状态都有一些转移概率,可能移动到任何相邻的状态。为了求解我们的 MDP,我们需要从任何给定状态中找到最优的行动路径。回想一下,如果我们能找到一种方法,准确预测每个状态的未来奖励,我们就可以通过贪婪地选择最高预期奖励的状态来选择最佳路径。我们将把这个预测奖励称为状态值。更正式地,状态值将定义从某个状态开始获得的预期奖励,以及在此之后所有未来状态的预期奖励估计,假设我们始终按照选择最高预期奖励的策略行事。最初,我们的智能体并不知道预期得到什么奖励,因此这个估计可以任意设为 0。
现在让我们定义一种方法,供我们的智能体选择行动:我们将首先用一个表格来存储我们对每个状态的预测状态值估计,表格初始时全部为零。

表示我们网格中每个状态的估计状态值的表格(图片作者提供)
我们的目标是随着我们探索环境来更新这些状态价值的估算。我们越是遍历环境,积累的经验就越多,估算也会变得更加精确。随着估算的改进,我们的状态价值将变得更加准确,并且我们将更好地表示哪些状态会带来更高的奖励,从而使我们能够根据哪个后续状态具有最高的状态价值来选择行动。这一定会奏效,对吧?

我们的 MDP 单一分支的可视化表示(图片来自作者)
状态价值与行动价值
不,抱歉。你可能会注意到的一个直接问题是,单纯根据最高状态价值来选择下一个状态并不可行。当我们查看可能的下一个状态集合时,并没有考虑当前的行动——也就是说,我们从当前状态到达下一个状态时所采取的行动。根据我们对强化学习的定义,代理-环境反馈循环总是包括代理采取某个行动,环境则通过状态和奖励来做出回应。如果我们只看下一个状态的状态价值,我们实际上是在考虑从这些状态开始时我们将获得的奖励,这完全忽视了我们为到达这些状态所采取的行动(及其带来的奖励)。此外,试图选择下一个可能状态中的最大值还假设我们首先能够到达那里。有时候,更为保守一点可以帮助我们更一致地实现最终目标;不过,这超出了本文的讨论范围 :(。
我们希望直接评估我们可用的行动,而不是评估可能的下一个状态集合。如果我们之前的状态价值函数是基于下一个状态的预期奖励,那么我们希望更新这个函数,现在要将从当前状态采取一个行动以到达下一个状态的奖励,以及从那里开始的预期奖励,包含在内。我们将这个新的估算称为行动价值(action-value)。
我们现在可以根据奖励和转移概率正式定义我们的状态价值和行动价值函数。我们将使用期望值来表示奖励和转移概率之间的关系。我们将根据强化学习文献中的标准惯例,分别将状态价值表示为V,将行动价值表示为Q。

状态价值和行动价值的方程式(图片来自作者)
某状态 s[t]的状态价值 V 是从 s[t]开始到未来某状态 s[T]的每个状态的预期奖励 r[t]的总和;某状态 s[t]的行动价值 Q 是从采取某个行动 a[t]到达未来某状态-行动对 s[T],a[T]的每个状态的预期奖励 r[t]的总和。
这个定义实际上不是最准确或最常规的,稍后我们会对其进行改进。然而,它提供了我们所寻求的一个基本思路:对未来奖励的量化度量。
我们的状态值函数 V 是从状态 s 开始,持续移动到提供最高奖励的状态时,最大奖励总和 r 的估计值。我们的动作值函数是通过从某一初始状态采取动作,并持续选择之后提供最高奖励的最优动作,所获得的最大奖励的估计值。在这两种情况下,我们根据预计的奖励选择最优的动作/状态,并不断循环这一过程,直到我们陷入困境或达到目标。
贪心策略与回报
我们选择动作的方法称为 策略。策略是状态的一个函数——给定某一状态,它会输出一个动作。在这种情况下,由于我们希望根据最大化奖励来选择下一个动作,我们的策略可以定义为一个函数,返回从当前状态开始,得到最大动作值(Q 值)的动作,或 argmax。由于我们始终选择最大值,我们将这种特定的策略称为 贪心 策略。我们将我们的策略表示为状态 s 的函数:π(s),其正式定义如下:

策略函数的方程 = 从某一状态 s 得到的最大估计 Q 值所对应的动作(图片来源:作者)
为了简化我们的符号表示,我们还可以定义一个奖励总和的替代项,称为 回报,以及一个状态和动作序列的替代项,称为 轨迹。轨迹,用希腊字母τ(tau)表示,定义如下:

轨迹的表示法:定义为某一状态-动作对的序列,直到某个未来的时间步 T。定义轨迹使我们可以跳过写出整个状态和动作的序列,而用一个单一的变量代替 😛!(图片来源:作者)
由于我们的环境是随机的,因此同样需要考虑轨迹发生的可能性——低概率的轨迹会降低奖励的期望值。(由于我们的期望值是通过将奖励与转移概率相乘得到的,因此低概率的轨迹与高概率的轨迹相比,期望奖励会较低。)这个概率可以通过逐步考虑每个动作和状态发生的概率来推导:在我们的马尔可夫决策过程中(MDP)的每个时间步中,我们将根据策略选择动作,而 resulting 状态将取决于我们选择的动作和转移概率。为了简化,我们将转移概率表示为一个独立的概率分布,它是当前状态和所选动作的函数。因此,某一未来状态发生的条件概率定义为:

从当前状态转移到未来状态的转移概率——对于我们的冰湖问题,我们知道这个值固定在 ~0.33(图示来自作者)
基于我们的策略,某个动作发生的概率仅仅通过将我们的状态传入我们的策略函数来评估。

某个动作被策略选择的概率表达式,给定某个状态(图示来自作者)
我们的策略目前是确定性的,因为它根据最高的期望动作值选择动作。换句话说,具有低动作值的动作永远不会被选择,而具有高 Q 值的动作将始终被选择。这导致了一个伯努利分布在所有可能的动作中。这种策略很少是有利的,正如我们稍后会看到的。
将这些表达式应用到我们的轨迹中,我们可以定义某个轨迹发生的概率为:

发生某个特定轨迹的扩展方程式。请注意,假设每次从相同状态(左上角)开始,s0 的概率固定为 1。(图示来自作者)
为了更清晰地说明,以下是轨迹的原始符号表示:

轨迹的符号表示:定义为一些状态-动作对的序列,直到某个未来时间步 T(图示来自作者)
更简洁地,我们可以写为:

轨迹发生的概率的简洁符号表示(图示来自作者)
定义了轨迹及其概率之后,我们可以替换这些表达式来简化我们对回报及其期望值的定义。回报(奖励的总和),我们将根据惯例定义为G,现在可以表示为:

回报的方程式(图示来自作者)
我们还可以通过引入概率来定义预期回报。既然我们已经定义了轨迹发生的概率,那么预期回报就是

更新后的预期回报方程 = 轨迹发生的概率乘以回报(图片来自作者)
我们现在可以调整价值函数的定义,以包括预期回报。

更新后的状态值和行为值方程(图片来自作者)
这里的主要区别是添加了下标 τ∼π,表示我们的轨迹是通过遵循策略采样得到的(即,我们的行为是基于最大 Q 值来选择的)。我们还去除了下标t以便于清晰。这里再次给出之前的方程,供参考:

状态值和行为值方程(图片来自作者)
折扣回报
所以现在我们有了一个相对明确的回报估计表达式,但在我们开始在环境中迭代之前,仍然有一些问题需要考虑。在我们的冰湖环境中,代理人不太可能无限期地继续探索。总有一天,它会滑倒并掉进一个洞里,导致一局游戏结束。然而,在实际的强化学习环境中,可能没有明确的结束点,训练会持续进行。在这种情况下,假设时间是无限的,预期回报将趋向于无限大,评估状态值和行为值将变得不可能。即使在我们的情况下,为计算回报设定硬性限制通常也没有好处,如果我们设定的限制过高,最终我们得到的回报值也可能是非常大的数字。在这些情况下,确保我们的奖励序列会收敛是很重要的,这可以通过使用折扣因子来实现。这能提高训练过程的稳定性,并确保无论我们考虑多远的未来,回报总是有限的。这样的折扣回报也被称为无限时域折扣回报。
为了在我们的回报方程中加入折扣因素,我们将引入一个新的变量 γ(gamma)来表示折扣因子。

折扣回报方程(图片来自作者)
γ(gamma)必须始终小于 1,否则我们的序列将无法收敛。扩展这个表达式会使这一点更加明显。

展开后的折扣回报方程(图片来自作者)
我们可以看到,随着时间的推移,gamma 将被提高到更高的指数。由于 gamma 小于 1,将其提高到更高的指数只会使其变得更小,从而使未来奖励对整体总和的贡献呈指数级减少。我们可以将这种更新后的回报定义代入我们的价值函数中,尽管由于变量仍然相同,所看到的结果不会发生明显变化。

状态和行动价值的公式,再次复制以便强调(图像来自作者)
探索与利用
我们之前提到过,始终贪婪并不是最好的选择。始终基于最大 Q 值选择行动可能会给我们最大化奖励的最高机会,但这仅在我们最初就有准确的 Q 值估计时成立。为了获得准确的估计,我们需要大量的信息,而我们只能通过尝试新事物——即探索——来获得信息。
当我们基于最高估计的 Q 值选择行动时,我们在利用当前的知识库:我们利用已经积累的经验,试图最大化奖励。当我们基于其他任何指标,甚至是随机选择行动时,我们在探索其他可能性,试图获得更多有用的信息来更新 Q 值估计。在强化学习中,我们希望平衡探索与利用。要正确地利用我们的知识,我们需要拥有知识,而要获得知识,我们必须进行探索。
Epsilon-贪婪策略
我们可以通过将策略从纯贪婪策略转变为epsilon-贪婪策略来平衡探索和利用。epsilon-贪婪策略大部分时间都采取贪婪的行动,概率为 1 - ε,但也有ε的概率进行随机行动。换句话说,我们大部分时间会利用我们的知识来尝试最大化奖励,并且偶尔进行探索以获得更多知识。这不是平衡探索与利用的唯一方法,但它是最简单且最容易实现的一种。
总结
现在我们已经建立了理解强化学习(RL)原理的基础,我们可以继续讨论实际的算法——这将在下一篇文章中进行。目前,我们将概述一个高层次的概述,将所有这些概念结合成一个连贯的伪代码,下一次我们可以深入探讨。
Q 学习
本文的重点是建立理解和实现 Q 学习的基础。Q 学习包含以下步骤:
-
初始化所有行动价值(Q 值)的表格估计,并在我们迭代环境时更新它们。
-
通过从我们的 epsilon-贪婪策略中采样来选择一个行动。
-
收集奖励(如果有的话),并更新我们对行动价值的估计。
-
移动到下一个状态,或者如果掉进坑里或到达目标,则终止。
-
重复步骤 2–4,直到我们的 Q 值估计收敛。
Q 学习是一个迭代过程,我们构建动作值(和预期回报)的估计,或者说“经验”,并利用我们的经验来识别哪些动作最能带来回报。通过多次与环境的交互,这些经验被“学习”,借助这些经验,我们将能够持续达成目标,从而解决我们的 MDP 问题。
词汇表
-
环境 — 智能体不能任意改变的任何事物,也就是它周围的世界
-
状态 — 环境的特定条件
-
观察 — 状态的一部分信息
-
策略 — 给定状态下选择一个动作的函数
-
智能体 — 我们的“学习者”,根据策略在环境中采取行动
-
奖励 — 我们的智能体在执行某些动作后所获得的反馈
-
回报 — 一系列动作的总奖励
-
折扣 — 通过这一过程,我们确保回报不会趋向于无穷大
-
状态值 — 从某个状态开始并按照某个策略继续行动,最终得到的预期回报
-
动作值 — 从某个状态开始并采取某个动作,然后继续根据某个策略行动,最终得到的预期回报
-
轨迹 — 一系列的状态和动作
-
马尔可夫决策过程(MDP)— 我们用来表示强化学习决策问题的模型,也就是具有非确定性边的有向图
-
探索 — 我们如何获得更多知识
-
利用 — 我们如何使用现有的知识库来获得更多的奖励
-
Q 学习 — 一种强化学习算法,我们通过迭代更新 Q 值来获得更好的估计,预测哪些动作能带来更高的预期回报
-
强化学习 — 一种基于过去经验积累和可量化奖励的学习过程
如果你读到这里,考虑给这篇文章留下反馈——我会很感激❤。
参考文献
1 Gymnasium, 冰冻湖(无日期),OpenAI Gymnasium 文档
[2] OpenAI, 深度强化学习入门(无日期),OpenAI
[3] R. Sutton 和 A. Barto, 强化学习:入门(2020),incompleteideas.net/book/RLbook2020.pdf
[4] Spiceworks, 什么是马尔可夫决策过程?(无日期),Spiceworks
[5] IBM, 强化学习(无日期),IBM
弱监督的直观概述
这可能是解决你下一个 NLP 问题的方案。
·发布于Towards Data Science ·8 分钟阅读·2024 年 6 月 29 日
--
在这个故事中,我们介绍并广泛探讨机器学习中的弱监督主题。弱监督是机器学习中的一种学习范式,近年来开始引起了显著关注。简而言之,完全监督要求我们拥有一个训练集(x,y),其中y是x的正确标签;而弱监督假设一个一般情境(x, y’),其中y’不必是正确的(即,它可能是错误的;一个弱标签)。此外,在弱监督中,我们可以有多个弱监督者,因此对于每个示例,可以拥有(x, y’1,y’2,…,y’F),其中每个y’j来自不同的来源,并且可能是错误的。

巨型宽泛无特征怪物,由 DALLE 生成
目录
∘ 问题陈述
∘ 通用框架
∘ 通用架构
∘ Snorkel
∘ 弱监督示例
问题陈述
更实用地说,弱监督旨在解决我所称之为监督式机器学习困境的问题。如果你是一个企业或有一个机器学习新想法的人,你将需要数据。通常,收集大量样本(x1, x2, …, xm)并不困难,有时甚至可以通过编程完成;然而,真正的困境在于,你需要雇佣人工标注员来标记这些数据,并为每个标签支付一定的$Z 费用。问题不仅在于你可能无法确定项目是否值得花这么多钱,而且还在于你可能根本负担不起雇佣标注员的费用,因为这个过程在法律和医学等领域尤其昂贵。
你可能会想,弱监督是如何解决这个问题的?简单来说,与其支付注释员来为你提供标签,不如让他们给你一些通用规则,这些规则在标注数据时有时会不准确(这会节省大量的时间和金钱)。在某些情况下,开发团队甚至可能很轻松地自己搞定这些规则(例如,如果任务不需要专家注释员的话)。
现在让我们思考一个具体的应用场景。你正在构建一个自然语言处理系统,用于屏蔽与敏感信息相关的单词,比如电话号码、姓名和地址。你不需要雇佣人员来标注你收集到的句子中的单词,而是编写一些函数,基于一些规则自动标注数据,比如判断单词是否全为数字(很可能是电话号码,但不一定),判断单词是否以大写字母开头而且不是句首(很可能是名字,但不一定),等等。然后,你可以用这些弱标签数据来训练你的系统。你可能会认为训练出的模型不会比这种标注源更好,但这是不正确的;弱监督模型的设计本意就是要超越标注源的局限,通过识别不确定性并以某种方式进行考虑,从而实现更好的泛化能力。

实验室实验的工程规划论文,来自 DALLE
通用框架
现在让我们正式地看一下弱监督在自然语言处理中的应用框架。
✦ 给定条件
一组F个标注函数 {L1 L2,…,LF},其中Lj是一个弱标签函数,在给定输入x的情况下,任何标注函数Lj都可以是以下任意一种:
-
众包注释员(有时他们的准确性较低)
-
基于距离监督的标签(即,从另一个知识库中提取的标签)
-
弱模型(例如,天生较弱或在其他任务上训练的模型)
-
启发式函数(例如,根据关键字或模式的存在为观察项标注标签,或由领域专家定义)
-
地名词典(例如,根据其在特定列表中的出现为观察项标注标签)
-
在特定提示 P 下调用大型语言模型(最近的研究成果)
-
一般来说,任何能(最好)比随机猜测更好地预测x标签的函数。
通常假设Li可能会选择不提供标签(例如,像“如果单词含有数字,则标注为电话号码,否则不标注”这样的启发式函数)。
假设训练集有 N 个样本,那么在序列分类的情况下,给定的弱标签矩阵等价于一个(N,F)的矩阵。对于长度为 T 的序列分类任务,它是一个(N,T,F)的弱标签矩阵。
✦ 期望
训练一个模型 M,有效利用弱标签数据,并结合任何强标签数据(如果存在的话)。
✦ 常见的自然语言处理任务
-
序列分类(例如,情感分析)或 标记分类(例如,命名实体识别),其中标签函数通常是启发式函数或地名词典。
-
低资源翻译 (x→y),其中标签函数通常是一个较弱的翻译模型(例如,采用反向方向的翻译模型 (y→x) 来增加更多的 (x,y) 翻译对)。
一般架构
对于序列或标记分类任务,文献中最常见的架构通常采用以下形式:

来自论文的图示 WRENCH: 弱监督的综合基准
标签模型学习将标签函数的输出映射到概率性或确定性的标签,这些标签用于训练最终模型。换句话说,它接收上述讨论的 (N,F) 或 (N,T,F) 标签矩阵,并返回 (N) 或 (N,T) 的标签矩阵(这些标签通常是概率性的(即软标签))。
最终模型在此步骤后单独使用,仅是一个普通的分类器,操作的是标签模型生成的软标签(交叉熵损失函数能够处理这种情况)。某些架构使用深度学习将标签模型和最终模型合并。
请注意,一旦我们训练了标签模型,我们就用它为最终模型生成标签,之后我们不再使用标签模型。从这个意义上说,这与堆叠模型是相当不同的,即使标签函数本身是其他机器学习模型。
另一种架构,这是翻译任务中的默认架构(对于序列/标记分类任务则不常见),是根据弱示例(源语言、目标语言)对的质量来加权(通常翻译任务只有一个标签函数,即反向模型,正如前面所讨论)。这种权重可以用于损失函数,使得模型从高质量的示例中学习得更多,而从低质量的示例中学习得更少。在这种情况下的方法尝试设计用于评估特定示例质量的方案。例如,一种方法使用回译 BLEU 分数(即将句子翻译到目标语言,然后再翻译回源语言)来估算这种权重。
Snorkel

来自 Snorkel 的图片:使用弱监督快速创建训练数据
为了查看标签模型如何工作,我们可以看看Snorkel,它无疑是序列分类中弱监督的最基础性工作。

来自论文的方程式
在 Snorkel 中,作者的目标是找到 P(yi|Λ(xi)),其中 Λ(xi) 是第 i 个示例的弱标签向量。显然,一旦找到了这个概率,我们可以将其作为最终模型的软标签(因为正如我们所说,交叉熵损失可以处理软标签)。同样显而易见的是,如果我们有 P(y, Λ(x)),那么我们可以轻松地用它来求得 P(y|Λ(x))。
从上面的方程中我们可以看到,他们使用了与逻辑回归相同的假设来建模 P(y, Λ(x))(Z 用于归一化,类似于 Sigmoid/Softmax)。不同之处在于,代替 w.x,我们有 w.φ(Λ(xi),yi)。特别地,φ(Λ(xi),yi) 是一个维度为 2F+|C| 的向量。F 是之前提到的标注函数的数量;同时,C 是一组相关的标注函数对(因此,|C| 是相关对的数量)。作者提到在另一篇论文中有一种方法来自动构建 C,这里为了简洁起见不再深入探讨。
向量 φ(Λ(xi),yi) 包含:
-
F个二元元素,用于指定每个标注函数是否在给定示例中放弃
-
F个二元元素,用于指定每个标注函数是否等于真实标签 y(此处 y 作为变量,它是分布的输入)
-
C个二元元素用于指定给定此示例时,每一对相关的元素是否做出了相同的投票
他们接着通过解决以下目标来训练这个标签模型(最小化负对数边际似然):

论文中的方程
注意,他们不需要关于 y 的信息,因为这个目标是无论 y 的具体值如何都能解决的,如通过求和所示。如果你仔细观察(去掉负号和对数),你会发现这等同于找到那些最大化任何真实标签的概率的权重。
一旦标签模型训练完成,他们使用该模型生成N个软标签 P(y1|Λ(x1)), P(y2|Λ(x2)),…,P(yN|Λ(xN)),并利用这些标签来训练某些判别模型(即分类器)。
弱监督示例
Snorkel 提供了一个出色的垃圾邮件分类教程在这里。Skweak 是另一个基本的弱监督标注包(以及相关论文),用于标记分类。下面是如何开始使用 Skweak 的示例,如在他们的 GitHub上所示:
首先定义标注函数:
import spacy, re
from skweak import heuristics, gazetteers, generative, utils
### LF 1: heuristic to detect occurrences of MONEY entities
def money_detector(doc):
for tok in doc[1:]:
if tok.text[0].isdigit() and tok.nbor(-1).is_currency:
yield tok.i-1, tok.i+1, "MONEY"
lf1 = heuristics.FunctionAnnotator("money", money_detector)
### LF 2: detection of years with a regex
lf2= heuristics.TokenConstraintAnnotator("years", lambda tok: re.match("(19|20)\d{2}$",
tok.text), "DATE")
### LF 3: a gazetteer with a few names
NAMES = [("Barack", "Obama"), ("Donald", "Trump"), ("Joe", "Biden")]
trie = gazetteers.Trie(NAMES)
lf3 = gazetteers.GazetteerAnnotator("presidents", {"PERSON":trie})
将它们应用于语料库
# We create a corpus (here with a single text)
nlp = spacy.load("en_core_web_sm")
doc = nlp("Donald Trump paid $750 in federal income taxes in 2016")
# apply the labelling functions
doc = lf3(lf2(lf1(doc)))
创建并拟合标签模型
# create and fit the HMM aggregation model
hmm = generative.HMM("hmm", ["PERSON", "DATE", "MONEY"])
hmm.fit([doc]*10)
# once fitted, we simply apply the model to aggregate all functions
doc = hmm(doc)
# we can then visualise the final result (in Jupyter)
utils.display_entities(doc, "hmm")
然后,你当然可以在此基础上使用估算的软标签训练分类器。
在本文中,我们探讨了弱监督所解决的问题,提供了正式的定义,并概述了在这种背景下通常采用的一般架构。我们还深入研究了 Snorkel,这是弱监督中的基础模型之一,并以一个实际示例结束,说明了弱监督如何应用。

由 DALLE 创作的“Jeep Going Away Bye”
希望你觉得这篇文章有用。下次再见,au revoir。
参考文献
1 Zhang, J. 等 (2021) Wrench: 一个全面的弱监督基准测试, arXiv.org。可在以下网址获取:arxiv.org/abs/2109.11377。
[2] Ratner, A. 等 (2017) Snorkel: 使用弱监督快速创建训练数据, arXiv.org。可在以下网址获取:arxiv.org/abs/1711.10160。
[3] NorskRegnesentral (2021) NorskRegnesentral/skweak: Skweak: 一个应用于 NLP 任务的弱监督软件工具包, GitHub。可在以下网址获取:github.com/NorskRegnesentral/skweak。
互信息的直观视角
普通人如何理解关联概念
·发表于Towards Data Science ·8 分钟阅读·2024 年 3 月 13 日
--

最近,我一直在进行一个项目,旨在筛选股市中的变量对,并查看它们是否显示出足够的相关性潜力,以便我们可以深入研究。
在我的研究过程中,我接触了许多不同的方法论;从简单的斯皮尔曼/皮尔逊线性相关,到使用时间延迟嵌入和甚至机器学习技术的更先进的非线性方法。
那时,我偶然发现了这个强大且具有概率性的概念——互信息,它帮助人们衡量两个变量之间的关联/依赖程度。这是模型开发或关联研究中的一个很好的起步工具。
在网上进一步了解时,我意识到虽然有很多出色的数学和统计学解释,但很少有直观的见解来阐明互信息为何及如何起作用。
因此,我们现在在这里,
欢迎来到我的尝试,帮助你拆解并理解这个统计学概念!
教科书定义
互信息是衡量通过观察一个变量可以获得多少关于另一个变量的“信息”的度量。
我相信你在自己的研究中一定见过上述陈述,或者它的变种。但究竟是什么是他们所说的“信息”?它如何告诉我两个变量是相关/依赖的?
当你看到这个公式时,定义变得更加令人生畏:

离散观察下互信息的公式
不要担心!让我们通过案例研究来分解这个概念,变得易于理解。
“人们真的只有在下雨时才用伞吗?”

图片来自zhang dayong在Unsplash上的分享
这是 Bob 说的,你的醉酒朋友,在节日狂欢的夜晚。
他坚持认为人们只有在自己想带伞时才会带伞,而不是因为需要它来遮挡雨水。
你认为这个说法太荒谬了!它挑战了你成长过程中所做的每一次观察;挑战了你骨子里的每一个逻辑观念。
你决定跟踪 Bob,并在他在热带新加坡度假的 5 天里观察他。你想看看他是否真的做到了他所宣称的行为。
你决定使用互信息的概念来进行分析。
Bob 与互信息
我们可以将互信息公式分解为以下几个部分:
x, X 和 y, Y
x和y是我们在数据中看到的个体观察值/数值。X和Y仅是这些个体值的集合。一个很好的例子如下:

离散/二进制观察携伞与天气
假设我们有 Bob 在这个确切顺序中的 5 天观察数据:

在 5 天内对携伞与天气的离散/二进制观察
个体/边际概率

这些仅仅是观察特定x或y在其各自的X和Y值集合中的简单概率。
以x = 1为例:概率简单为0.4(Bob 在假期的 5 天中有 2 天带了伞)。
联合概率

这是从联合概率(X, Y)中观察特定x和y的概率。联合概率(X, Y)仅仅是成对观察的集合。我们根据它们的索引将它们配对。
在我们的案例中,我们根据它们发生的具体日期将 Bob 的观察配对起来。

看完这些对的配对后,你可能会忍不住得出结论:
由于有 80%的时间发生了相等值的配对,这明显意味着人们带伞是因为下雨了!
好吧,我来充当魔鬼代言人,认为这可能只是一个极其巧合的事件:
如果新加坡的降雨几率非常低,并且 Bob 独立地携带雨伞的可能性也同样低(因为他讨厌拿额外的东西),你能看到(0,0)配对观察的几率自然会非常高吗?
那么我们可以做什么来证明这些配对观察结果并非巧合呢?
联合概率与个体概率

我们可以计算这两个概率的比值,从而给我们一个关于“巧合程度”的线索。
在分母中,我们取特定x和特定*y 发生的个体概率的乘积。我们为什么要这么做?
窥探那朴素的硬币抛掷
回想一下你在统计课上学的第一课:计算在 2 次公平抛掷中得到 2 个正面的概率。
-
第一次抛掷 [p(x)]:得到正面的几率是 50%
-
第二次抛掷 [p(y)]:得到正面的几率仍然是 50%,因为结果独立于第一次抛掷发生的事情。
-
上面这 2 次抛掷构成了你的个体概率
-
因此,理论上在 2 次独立抛掷中得到 2 个正面的概率是 0.5 * 0.5* = 0.25*( p(x).p(y))

如果你实际上做了 100 组这样的双重硬币抛掷实验,你很可能会看到你得到(正面, 正面)结果的几率是 25%。这 100 组实验实际上是你的(X, Y)联合概率集!
因此,当你计算联合概率与各自个体概率的比值时,你会得到一个值为1。
这实际上是独立事件的真实期望:特定值对的联合概率正好等于它们各自概率的乘积!就像你在基础统计学中学到的那样。
现在假设你的 100 组实验中,(正面, 正面)的结果出现了 90%的时间。这肯定不是巧合……
你原本预计是 25%,因为你知道这些事件是独立的,但观察到的却是这种期望的极端偏斜。
为了将这种定性感觉转化为数字,概率的比值现在是惊人的3.6 (0.9 / 0.25),也就是说,比我们预期的要频繁 3.6 倍。
因此,我们开始认为也许硬币抛掷并不是独立的。也许第一次抛掷的结果实际上会对第二次抛掷产生一些未解释的影响。也许第一次和第二次抛掷之间存在某种程度的关联/依赖。
这就是互信息试图告诉我们的!
观察的期望值

为了对 Bob 公平,我们不仅仅应该看他错误声明的次数,即计算(0,0)和(1,1)的概率比。
我们还应该计算他声明正确时的概率比,即(0,1)和(1,0)的概率比。
此后,我们可以通过聚合所有 4 种情形的期望值方法,这意味着“取平均”:将每一对观察到的(X, Y)的概率比率聚合起来,然后除以观察的次数。
这就是这两个求和项的目的。对于像我的股市示例这样的连续变量,我们将使用积分来代替。
比率的对数

类似于我们计算掷硬币时得到两个连续正面的概率一样,我们现在也在计算观察到的 5 对的额外概率。
对于掷硬币,我们通过乘法计算每次掷硬币的概率。对于 Bob 来说也是一样:概率之间有乘法效应,从而给我们提供了我们在联合集合中观察到的序列。
使用对数,我们将乘法效应转化为加法效应:

将概率比率转换为它们的对数形式后,我们现在可以简单地按照上面描述的使用对数求和来计算期望值。
随意使用底数为 2、e 或 10 的对数,本文目的并不受影响。
将所有内容汇总

离散观察的互信息公式
现在让我们通过计算互信息来证明 Bob 是错的。我将使用以 e 为底的对数(自然对数)进行计算:

那么,值为 0.223 的结果告诉了我们什么?
首先假设 Bob 是对的,且伞的使用与下雨的存在是独立的:
-
我们知道,联合概率将恰好等于单独概率的乘积。
-
因此,对于每一个 x 和 y 的排列,概率比率 = 1。
-
取对数后,结果等于 0。
-
因此,所有排列的期望值(即互信息)因此为0。
但是,由于我们计算的互信息分数是非零的,因此我们可以证明 Bob 是错的!
超越线性相关
因为互信息是一个概率性关联/依赖度的度量,它同样适用于非线性相关的研究!
举个例子,考虑两个变量 X 和 Y:

计算它们的互信息分数、Spearman 等级相关系数,并绘制图表,我们得到如下结果:

Y 对 X:Y 是 X 的确定性、非线性缩放
仅依赖 Spearman 等级相关系数,我们会认为这两个变量没有任何关系,但我们明明知道它们是确定性相关的(基于我上面的公式)!
非零的互信息分数提示我们需要深入探讨,尽管它没有给出关系的明确形式。
它也足够强大,可以处理严格的线性相关性:

Y 对 X:Y 是 X 的确定性线性变换
所以,如果你在进行 X 对 Y 分析时不确定你期待的相关性是什么,你可以尝试将互信息作为零步来使用!
我对互信息的“外行”定义

通过以上示例和分析,我希望我已经帮助大家直观地理解了互信息是什么以及它是如何工作的。
如果这对你有帮助,我倾向于将互信息总结如下:
互信息为我们提供了 x 和 y 同时发生的额外概率,这一概率是由其他因素引起的,而不仅仅是它们同时发生的机会。
互信息在构建机器学习模型之前的特征选择等领域非常有用,甚至在与文本嵌入一起使用时,能用于文本关联分析。因此,在我们采用它来应对各种用途之前,真正理解它的工作原理是至关重要的。
凭借你新获得的直觉和理解,我相信你将能够发现其他机会,将这一多用途概念应用于其他领域,就像我在股市投资中所做的那样!
一种非传统的训练-测试-验证集划分方法
确保在小数据集划分中的分布一致性
·发表于Towards Data Science ·阅读时长 8 分钟·2024 年 7 月 7 日
--

使用 Microsoft Designer 生成
我们都需要对总体进行抽样,以进行统计分析并获得洞见。当我们这样做时,目标是确保我们的样本分布尽可能接近总体分布。
为此,我们有多种方法:简单随机采样(每个样本在总体中被选中的概率相等),分层采样(将总体划分为子组,并从每个子组中进行采样),聚类采样(将总体分为若干簇,并随机选取整簇),系统采样(每隔 n 个样本选择一个样本),等等。每种方法都有其优点,并根据研究的具体需求和特征进行选择。
在本文中,我们不会专注于采样方法本身,而是会介绍如何利用这些概念将用于机器学习方法的数据集划分为训练-测试-验证集。这些方法适用于所有类型的表格数据。我们将在这里使用 Python 进行演示。
以下是您可能已经了解的一些方法:
1. 简单训练-测试-验证集划分
该方法使用随机采样方法。
示例代码:
from sklearn.model_selection import train_test_split
# Assuming X is your feature set and y is your target variable
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.4, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)
2. 分层训练-测试-验证集划分
该方法确保划分保持与原始数据集相同的类别比例(当然还是采用随机采样),这对于类别不平衡的数据集非常有用。当目标变量不是连续变量时,该方法有效。
from sklearn.model_selection import train_test_split
# Stratified split to maintain class distribution
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.4, stratify=y, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)
3. K 折交叉验证
在 K 折交叉验证中,数据集被切分为k个子集(折)。模型在k-1个折上训练,在剩余的一个折上测试。这个过程会重复k次。
from sklearn.model_selection import KFold, train_test_split
kf = KFold(n_splits=5, shuffle=True, random_state=42)
for train_index, test_index in kf.split(X):
X_train, X_test = X[train_index], X[test_index]
y_train, y_test = y[train_index], y[test_index]
# Further split X_train and y_train into train and validation sets
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)
# Now you have X_train, X_val, X_test, y_train, y_val, y_test for each fold
# You can now train and evaluate your model using these sets
4. 分层 K 折交叉验证
如名称所示,这是分层抽样和 K 折交叉验证的结合。
from sklearn.model_selection import StratifiedKFold, train_test_split
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
for train_index, test_index in skf.split(X, y):
X_train, X_test = X[train_index], X[test_index]
y_train, y_test = y[train_index], y[test_index]
# Further split X_train and y_train into train and validation sets
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, stratify=y_train, random_state=42)
# Now you have X_train, X_val, X_test, y_train, y_val, y_test for each fold
# You can now train and evaluate your model using these sets
完整示例使用:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
# Initialize lists to store the scores for each fold
accuracy_scores = []
precision_scores = []
recall_scores = []
f1_scores = []
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
for train_index, test_index in skf.split(X, y): #y is a categorical target variable
X_train, X_test = X[train_index], X[test_index]
y_train, y_test = y[train_index], y[test_index]
# Further split X_train and y_train into train and validation sets
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, stratify=y_train, random_state=42)
# Train the model
model = LogisticRegression(random_state=42)
model.fit(X_train, y_train)
# Validate the model
y_val_pred = model.predict(X_val)
val_accuracy = accuracy_score(y_val, y_val_pred)
val_precision = precision_score(y_val, y_val_pred, average='weighted')
val_recall = recall_score(y_val, y_val_pred, average='weighted')
val_f1 = f1_score(y_val, y_val_pred, average='weighted')
print(f"Validation Scores - Accuracy: {val_accuracy}, Precision: {val_precision}, Recall: {val_recall}, F1 Score: {val_f1}")
# Test the model
y_test_pred = model.predict(X_test)
test_accuracy = accuracy_score(y_test, y_test_pred)
test_precision = precision_score(y_test, y_test_pred, average='weighted')
test_recall = recall_score(y_test, y_test_pred, average='weighted')
test_f1 = f1_score(y_test, y_test_pred, average='weighted')
# Store the scores
accuracy_scores.append(test_accuracy)
precision_scores.append(test_precision)
recall_scores.append(test_recall)
f1_scores.append(test_f1)
print(f"Test Scores - Accuracy: {test_accuracy}, Precision: {test_precision}, Recall: {test_recall}, F1 Score: {test_f1}")
# Calculate and print the average scores across all folds
print(f"\nAverage Test Scores across all folds - Accuracy: {sum(accuracy_scores) / len(accuracy_scores)}, Precision: {sum(precision_scores) / len(precision_scores)}, Recall: {sum(recall_scores) / len(recall_scores)}, F1 Score: {sum(f1_scores) / len(f1_scores)}")
现在,你可以使用这些方法来切分数据集,但它们有以下局限性:
-
随机训练-测试-验证切分: 这种方法无法确保切分后的数据具有相似的 分布,尤其是在数据集不够大或目标变量存在不平衡时。
-
分层切分: 这种方法仅在目标变量(y)是非连续的时有效。虽然对于连续目标变量也有一些解决方法(例如通过某些条件将连续变量转换为类别变量,比如如果 y ≥ 四分位数 1 → 1,否则为 0),但这些方法仍然不能确保切分后的数据具有相似的分布。
现在,假设你的数据集中总的观测数较小,并且很难确保切分后的数据分布相似。在这种情况下,你可以结合聚类和随机抽样(或分层抽样)。
以下是我在我的问题上如何操作的:
5. 基于聚类的训练-测试-验证切分
在这种方法中,我们首先对数据集进行聚类,然后对每个聚类使用抽样方法来获得数据切分。
例如,使用HDBSCAN:
import hdbscan
from sklearn.metrics import silhouette_score
from sklearn.model_selection import ParameterGrid
import random
random.seed(48) #for regeneration of same results
def get_clusters(df):
to_drop =["cluster_", "ID"]
req_cols = sorted(set(df.columns) - set(to_drop))
X = df[req_cols] #keep only required columns in X
X_std = X.values #no need of scaling the training set for HDBSCAN
# Define parameter grid for HDBSCAN, you can play with this grid accordingly
param_grid = {
'min_cluster_size': list(range(2,20))
#'min_samples': [1, 2, 3]
}
best_score = -1
best_params = None
# Iterate over parameter grid
for params in ParameterGrid(param_grid):
model = hdbscan.HDBSCAN(**params, gen_min_span_tree=True)
cluster_labels = model.fit_predict(X_std)
unique_labels = np.unique(cluster_labels)
if len(unique_labels) > 1: # Check if more than one cluster is formed
silhouette_avg = silhouette_score(X_std, cluster_labels) if len(unique_labels) > 1 else -1
if silhouette_avg > best_score:
best_score = silhouette_avg
best_params = params
if best_params is not None:
print(best_params)
best_model = hdbscan.HDBSCAN(**best_params, gen_min_span_tree=True)
cluster_labels = best_model.fit_predict(X_std) #get cluster labels from best model
df["cluster_"] = [str(i) for i in cluster_labels]
else:
print("HDBSCAN produced only one cluster label. Unable to split the data.")
df["cluster_"] = "0" #when no clusters are found
return df
根据你的问题,你也可以使用其他聚类方法,例如K 均值聚类:
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import StandardScaler
from yellowbrick.cluster import KElbowVisualizer
def get_clusters(df):
to_drop =["cluster_", "ID"]
req_cols = sorted(set(df.columns) - set(to_drop))
X = df[req_cols].values #keep only required columns in X
scaler = StandardScaler()
X_std = scaler.fit_transform(X) #scaling is needed in case of K-Means
model = KMeans()
visualizer = KElbowVisualizer(model, k=(2, 50)) #you can play with the range accordingly
visualizer.fit(X_std)
#visualizer.show()
optimal_n_clusters = visualizer.elbow_value_ #using elbow method to get optimal no. of clusters
kmeans = KMeans(n_clusters=optimal_n_clusters, random_state=42)
kmeans.fit(X_std)
clust_labels = [str(i) for i in kmeans.labels_]
# Evaluate the clustering using silhouette score
silhouette_avg = silhouette_score(X_std, clust_labels)
df["cluster_"] = clust_labels
return df
现在你还可以向数据集中添加粒度层次(任何类别变量),从而获得更精细的聚类,具体如下:
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import StandardScaler
from yellowbrick.cluster import KElbowVisualizer
def get_clusters(df):
# taking animal categorical variable as a level of granularity to split on
grp1 = df.loc[(df['animal']=='cat')]
grp2 = df.loc[(df['animal']=='dog')]
temps = []
for num, temp in enumerate([grp1, grp2]):
to_drop =["cluster_", "ID"]
final_cols = sorted(set(temp.columns) - set(to_drop))
X = temp[final_cols]
X = X.values
scaler = StandardScaler()
X_std = scaler.fit_transform(X) #scaling of variables is needed for K-Means clustering
model = KMeans()
visualizer = KElbowVisualizer(model, k=(2, 50))
visualizer.fit(X_std)
# visualizer.show()
#get optimal no. of clusters, K using elbow method
optimal_n_clusters = visualizer.elbow_value_
kmeans = KMeans(n_clusters=optimal_n_clusters, random_state=42)
kmeans.fit(X_std)
clust_labels = [str(num) + "_" + str(i) for i in kmeans.labels_]
# Evaluate the clustering using silhouette score
silhouette_avg = silhouette_score(X_std, clust_labels)
temp["cluster_"] = clust_labels
temps.append(temp)
df = pd.concat(temps, axis=0)
return df
一旦你从任何聚类方法中获得了聚类标签,你可以使用随机抽样或分层抽样从每个聚类中选择样本。
我们将随机选择索引,然后使用这些索引来选择我们的训练-测试-验证集,如下所示:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
# Assuming df is your DataFrame, "cluster_" is the column with cluster labels,
unique_clusters = df["cluster_"].unique()
train_indices = []
val_indices = []
test_indices = []
for cluster in unique_clusters:
cluster_data = df[df["cluster_"] == cluster]
cluster_indices = cluster_data.index.values
cluster_y = cluster_data['y'].values
if stratify_ == True: #if you have categorical target variable
train_idx, temp_idx, _, temp_y = train_test_split(cluster_indices, cluster_y, test_size=0.4, stratify=cluster_y, random_state=42)
val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, stratify=temp_y, random_state=42)
else:
# Split indices of the current cluster into train and temp (which will be further split into val and test)
train_idx, temp_idx = train_test_split(cluster_indices, test_size=0.4, random_state=42)
val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=42)
train_indices.extend(train_idx)
val_indices.extend(val_idx)
test_indices.extend(test_idx)
# Convert the indices lists to numpy arrays
train_indices = np.array(train_indices)
val_indices = np.array(val_indices)
test_indices = np.array(test_indices)
# Assuming 'X' are the features and 'y' is the target column
X = df.drop(columns=['y', 'cluster_']).values
y = df['y'].values
# Select the corresponding data for train, validation, and test sets
X_train, y_train = X[train_indices], y[train_indices]
X_val, y_val = X[val_indices], y[val_indices]
X_test, y_test = X[test_indices], y[test_indices]
根据我的使用案例,将目标变量 y 排序,然后选择每个第 1、第 2 和第 3个索引作为训练、测试和验证集,分别为(互斥的),即所谓的系统性随机抽样,如下所示:
def get_indices(df):
np.random.seed(seed=48)
total_length = len(df)
sample1_length = int(0.60 * total_length) #you can choose proportion accordingly
remaining_length = total_length - sample1_length
sample2_length = int(remaining_length / 2)
sample3_length = total_length - (sample1_length + sample2_length)
#create an array with range 0 - length of the df
all_indxs = np.array(range(total_length))
# Create arrays of indices divisible by 2 and 3 exclusively
indices_divisible_by_2 = np.array(list(set(np.where(all_indxs % 2 == 0)[0]) - set(np.where(all_indxs % 6 == 0)[0])))
indices_divisible_by_3 = np.array(list(set(np.where(all_indxs % 3 == 0)[0]) - set([0])))
#randomly choose indices divisibly by 2 with sample2_length
sample2_indices = sorted(indices_divisible_by_2[np.random.choice(len(indices_divisible_by_2), size=sample2_length, replace=False)])
try:
sample3_indices = sorted(indices_divisible_by_3[np.random.choice(len(indices_divisible_by_3), size=sample3_length, replace=False)])
except:
sample3_indices = []
sample1_indices = sorted(set(all_indxs) - set(sample2_indices) - set(sample3_indices))
return sample1_indices, sample2_indices, sample3_indices
indices_train = []
indices_test = []
indices_val = []
for num, cluster in enumerate(df['cluster_'].unique()):
temp_df = df[df['cluster_'] == cluster]
sample1_indices, sample2_indices, sample3_indices = get_indices(temp_df)
indices_train.append(list(temp_df.iloc[sample1_indices].index))
indices_test.append(list(temp_df.iloc[sample2_indices].index))
indices_val.append(list(temp_df.iloc[sample3_indices].index))
# to flatten the list of lists containing indices for train,test,val set
indices_train = [x for xs in indices_train for x in xs]
indices_test = [x for xs in indices_test for x in xs]
indices_val = [x for xs in indices_val for x in xs]
def traintestvalsplit(df, id_col, cols_to_drop, cont_var, train_indices, test_indices, val_indices):
train, test, val = df.loc[train_indices], df.loc[test_indices], df.loc[val_indices]
# Split the data into train, validation, and test sets based on indices
X_train = train.drop(cols_to_drop + [cont_var] ,axis=1) #add which columns to drop
X_test = test.drop(cols_to_drop + [cont_var] ,axis=1)
X_val = val.drop(cols_to_drop + [cont_var] ,axis=1)
y_train = train[[cont_var]] #target variable
y_test = test[[cont_var]]
y_val = val[[cont_var]]
train_ids = train[[id_col]] #to preserve the IDs
test_ids = test[[id_col]]
val_ids = val[[id_col]]
print("Train set size:", X_train.shape, len(train_ids))
print("Test set size:", X_test.shape, len(test_ids))
print("Validation set size:", X_val.shape, len(val_ids))
return X_train, X_val, X_test, y_train, y_val, y_test, train_ids, val_ids, test_ids
X_train, X_val, X_test, y_train, y_val, y_test, train_ids, val_ids, test_ids = traintestvalsplit(df, id_col, cols_to_drop, cont_var, train_indices, test_indices, val_indices)
上述将聚类与不同抽样方法结合的方式,在数据集观测数量较少时非常有用,因为它们确保了在训练、测试和验证集之间保持相似的分布。
感谢阅读,希望你觉得这篇文章有帮助!
基于开放数据的优化医疗设施位置的 Python 方法
使用开放数据栈的 Python 教程
·发布于Towards Data Science ·阅读时间:13 分钟·2024 年 6 月 11 日
--

该图像由作者在Midjourney中生成
这项工作是与 Joaquim Gromicho 教授以及 Kai Kaiser共同合作完成的。作者对所有错误和遗漏负责。
根据 2020 年发布的关于全球前往医疗设施的旅行时间地图的研究,43.3%(约 31.6 亿人)无法在一小时内步行到达医疗设施。
准确计算前往医疗设施的旅行时间是评估医疗可达性的基础,尤其是在可达性障碍可能显著影响公共健康结果的地区。这些计算对于资源分配、医疗服务利用、公平的医疗访问以及未来设施的战略规划至关重要。然而,要进行这些计算,需要大量的数据处理,包括医院位置、人口分布以及基于道路网络数据(如 OpenStreetMaps)或 API(如 Google 或 Mapbox)的旅行时间计算。
地理变异性,如不同的地形、道路状况和天气,也会影响计算结果…
面向日全食追逐者的开源数据库
·发表于 Towards Data Science ·阅读时长 11 分钟·2024 年 4 月 14 日
--

图像由 midjourney 生成
开放数据的动机
以免说得太显而易见,数据科学家最大的弱点就是在没有高质量数据的情况下无法实践他们的专业技能。而创建一个高质量的数据集并不简单。这成了通过这一学科增值的最明显障碍。与工程不同,在工程中你可以从第一天开始卷起袖子动手建设,但数据科学家在没有数据的情况下,几乎做不了什么。
在一个大到中型的组织中,这个问题通常通过首先投资数据工程来解决,确保数据流通,数据科学家可以在此基础上进行工作,发挥他们的技能。这些数据集的一个重要特征是它们不是静态的,而是动态的。随着业务的变化,数据不断流入这些数据集,使它们保持动态并不断发展。基于这些数据集构建的数据科学产品也可以不断发展。这形成了一个正反馈循环,一旦人们看到数据科学产品带来的价值,它就会推动进一步投资数据工程,收集更丰富的数据,从而推动更强大的数据科学应用,依此类推。
虽然这个故事在各种组织的闭门会议中反复上演,但我还没看到它在…领域展开。
上下文 Bandit 概述
一种治疗个性化的动态方法
·发表于 Towards Data Science ·19 分钟阅读·2024 年 2 月 2 日
--
大纲
-
介绍
-
何时使用上下文 Bandit
2.1. 上下文 Bandit 与多臂赌博机(MAB)与 A/B 测试
2.2. 上下文 Bandit 与多个多臂赌博机(MAB)
2.3. 上下文 Bandit 与多步强化学习
2.4. 上下文 Bandit 与提升建模
-
上下文 Bandit 中的探索与利用
3.1. ε-贪婪策略
3.2. 上置信界限(UCB)
3.3. 汤普森采样
-
上下文 Bandit 算法步骤
-
上下文 Bandit 中的离线策略评估
5.1. 使用因果推断方法的 OPE
5.2. 使用采样方法的 OPE
-
上下文 Bandit 实践
-
结论
-
致谢
-
参考文献
1. 介绍
想象一下这样的场景:你刚刚开始了一个为期两周的 A/B 测试。然而,仅仅一两天后,越来越明显地发现版本 A 对于某些类型的用户效果更好,而版本 B 对于另一组用户效果更佳。你心想:也许我应该重新引导流量,让用户更多地接触到对他们更有益的版本,而少接触到另一个版本。有没有一种有原则的方法可以实现这一点?
上下文赌博者(Contextual Bandits)是一类一阶强化学习算法,专门为这种治疗个性化问题设计,在这种问题中,我们希望根据哪种治疗对谁有效来动态调整流量。尽管它们在可以实现的效果上极为强大,但它们是数据科学中较少为人知的算法之一,我希望这篇文章能为你提供一个全面的介绍,帮助你了解这个令人惊叹的话题。事不宜迟,让我们直接深入了解吧!
2. 何时使用上下文赌博者
如果你刚刚开始了解上下文赌博者,可能会对上下文赌博者与其他更广为人知的方法(如 A/B 测试)之间的关系感到困惑,以及为什么你可能会选择使用上下文赌博者而不是其他方法。因此,我们将从讨论上下文赌博者与相关方法之间的相似性和差异性开始我们的旅程。
2.1. 上下文赌博者 vs 多臂赌博者 vs A/B 测试
让我们从最基本的 A/B 测试设置开始,该设置将流量以静态的方式分配到治疗组和控制组。例如,一个数据科学家可能决定进行为期两周的 A/B 测试,50%的流量分配给治疗组,50%分配给控制组。这意味着无论我们处于测试的第一天还是最后一天,我们都将以 50%的概率将用户分配给治疗组或控制组。
另一方面,如果数据科学家在这种情况下使用多臂赌博者(MAB)而不是 A/B 测试,那么流量将以动态的方式分配到治疗组和控制组。换句话说,MAB 中的流量分配将随着时间的推移而变化。例如,如果算法在第一天判断治疗组优于控制组,那么流量分配可能会从第一天的 50%治疗和 50%控制调整为第二天的 60%治疗和 40%控制,依此类推。
尽管动态分配流量,MAB 忽略了一个重要事实,那就是并不是所有用户都一样。这意味着对某一类用户有效的治疗方法可能对另一类用户无效。例如,可能出现这种情况:尽管治疗对核心用户更有效,但对休闲用户而言,控制组实际上更好。在这种情况下,即使治疗方法整体上更有效,我们如果将更多核心用户分配给治疗组,将更多休闲用户分配给控制组,实际上可以从我们的应用中获得更多价值。
这正是上下文赌博机(CB)派上用场的地方。尽管多臂赌博机(MAB)只关注治疗组或对照组整体效果如何,总体上是否表现更好,CB 则关注治疗组或对照组对于具有特定特征的用户表现如何。在上下文赌博机中,“上下文”恰恰指的就是这些用户特征,这也是其与多臂赌博机的区别所在。例如,CB 可能会在观察到第一天的数据后,决定将核心用户的治疗分配提高到 60%,而将普通用户的治疗分配降至 40%。换句话说,CB 会根据用户特征(在此例中为核心用户与普通用户)动态调整流量分配。
下表总结了 A/B 测试、多臂赌博机和上下文赌博机的关键区别,接下来的图表将可视化这些概念。
表 1:A/B 测试、多臂赌博机和上下文赌博机的区别

图 1:A/B 测试、多臂赌博机和上下文赌博机中的流量分配

2.2. 上下文赌博机与多个多臂赌博机
到此,你可能会产生这样的想法:上下文赌博机(CB)不过是多个多臂赌博机(MAB)同时运行的集合。事实上,当我们关注的上下文较小(例如,我们只关心一个用户是核心用户还是普通用户)时,我们可以简单地为核心用户运行一个 MAB,为普通用户运行另一个 MAB。然而,当上下文变得庞大(例如核心与普通用户、年龄、国家、上次活跃时间等)时,为每个独特的上下文值运行一个单独的 MAB 变得不切实际。
在这种情况下,上下文赌博机的真正价值体现在通过使用模型来描述不同上下文中的实验条件与我们关注的结果(例如转化率)之间的关系。与逐一列举每个上下文值并将其独立处理不同,使用模型可以让我们共享来自不同上下文的信息,从而处理更大的上下文空间。这个模型的概念将在本文的多个部分进行讨论,因此请继续阅读以了解更多内容。
2.3. 上下文赌博机与多步强化学习
引言中将上下文赌博机(CB)称为一种单步强化学习(RL)算法。那么,单步与多步强化学习到底有什么区别呢?是什么使得 CB 成为单步学习?上下文赌博机与多步强化学习的根本区别在于,在 CB 中,我们假设算法所采取的行动(例如,为特定用户提供治疗或对照组)不会影响系统整体的未来状态。换句话说,状态(或在 CB 中更适当称为“上下文”)影响我们采取的行动,但我们采取的行动不会反过来影响或改变状态。下图总结了这一区别。
图 2:上下文赌博机与多步强化学习

图片由作者提供,灵感来源于 source
一些例子应该能让这个区别更加清晰。假设我们正在构建一个系统,根据用户的年龄决定展示哪些广告。我们预期,不同年龄段的用户可能会觉得不同的广告与他们更相关,这意味着用户的年龄应该影响我们展示给他们的广告。然而,我们展示的广告并不会反过来影响他们的年龄,所以 CB 的单步假设似乎成立。然而,如果我们进一步发现,展示昂贵的广告会消耗我们的库存(并限制未来我们能展示的广告),或者我们今天展示的广告会影响用户是否会再次访问我们的网站,那么单步假设就间接被违反了,因此我们可能需要考虑开发一个完整的强化学习(RL)系统。
需要注意的是:虽然与上下文赌博机相比,多步强化学习更加灵活,但它的实现也更加复杂。因此,如果当前的问题能够准确地被框定为一个单步问题(即使乍一看像是多步问题),上下文赌博机可能是更实际的解决方案。
2.4. 上下文赌博机与提升建模
在继续讨论不同的 CB 算法之前,我还想简要地提及一下 CB 与提升建模之间的关系。提升模型通常是基于 A/B 测试数据构建的,用于发现处理效果(提升)与用户特征之间的关系。然后,可以使用该模型的结果来个性化未来的处理。例如,如果提升模型发现某些用户更可能从某个处理方法中受益,那么未来可能只会将该处理方法提供给这些类型的用户。
给定对提升建模的描述,应该很清楚,CB(上下文赌博机)和提升建模都是个性化问题的解决方案。它们之间的关键区别在于,CB 以一种更动态的方式解决这个问题,个性化发生在即时的过程中,而不是等待 A/B 测试的结果。从概念层面上讲,CB 可以非常宽泛地被看作是 A/B 测试和提升建模同时发生,而不是顺序发生。考虑到本文的重点,我不会进一步讨论提升建模,但有几个很好的资源可以了解更多相关内容,例如[1](https://www.uber.com/blog/research/uplift-modeling-for-multiple-treatments-with-cost-optimization/)。
3. 上下文赌博机中的探索与开发
上文我们讨论了 CB 如何根据某个给定时间点、特定用户群体的治疗组和对照组的表现,动态分配流量。这引出了一个重要的问题:当我们进行这些流量分配调整时,我们希望多么激进?例如,如果在一天的数据之后,我们决定治疗组对美国用户表现更好,是否应该完全停止对美国用户提供对照组?
我相信大多数人都会同意,这个做法是一个糟糕的主意,而且你们是对的。过于激进地改变流量分配的主要问题是,基于不足的数据进行推断可能导致错误的结论。例如,可能第一天的数据实际上并不代表沉睡用户的情况,实际上对照组对他们更好。如果我们在第一天之后就停止向美国用户提供对照组,我们将永远无法了解这种正确的关系。
更好的动态更新流量分配的方法是,在利用(基于目前的数据提供最佳实验条件)和探索(继续为其他实验条件提供服务)之间找到合适的平衡。延续前面的例子,如果第一天的数据表明治疗组对美国用户更好,我们可以在第二天通过更高的概率为这些用户提供治疗,同时仍然为对照组分配一个减少但非零的比例。
在 CB(以及 MAB)中有许多探索策略,还有一些变种试图在探索和利用之间找到合适的平衡。三种常见的策略包括ε-贪婪策略、上置信界限(UCB)和汤普森采样(Thompson sampling)。
3.1. ε-贪婪策略
在这个策略中,我们首先决定哪个实验条件在某个给定时间点对于某个特定用户群体表现更好。最简单的做法是通过比较这些用户在每个实验条件下的目标值(y)的平均值来实现。更正式地说,我们可以通过找出条件d,其值较高,从而决定该组用户的“获胜”条件。

其中,n_dx是我们到目前为止从条件d和上下文x中的用户中获取的样本数量,y_idx是条件d和上下文x中第i个样本的目标值。
在决定了哪个实验条件目前对这些用户是“最好的”之后,我们以1-ε的概率为他们提供该条件(其中ε通常是一个较小的数字,如 0.05),并以ε的概率提供一个随机的实验条件。实际上,我们可能希望动态更新我们的ε,使其在实验开始时较大(此时需要更多的探索),随着我们收集更多数据,ε逐渐变小。
此外,上下文 X 可能是高维的(例如,国家、性别、平台、任期等),因此我们可能需要使用模型来获取这些 y 估计值,从而应对维度灾难。形式上,服务的条件可以通过找到条件 d 使其值更高来决定。

其中 x^T 是一个 m 维的行向量,表示上下文值,θ_d 是一个 m 维的列向量,表示与条件 d 相关的可学习参数。
3.2. 上置信度界限(UCB)
这个策略通过查看不仅是哪个条件具有更高的 y 估计值,还包括我们对该估计值的精确度(或信心),来决定下一个要服务的条件。在一个简单的 MAB 设置中,精确度可以被视为已经服务过多少次某个条件的函数。特别地,一个条件(i)具有较高的平均 y(所以值得利用),或者(ii)还没有被服务过很多次(所以需要更多的探索),更有可能被下一个选择来服务。
我们可以通过跟踪不同条件在不同上下文中被服务的次数,将这一思想推广到 CB 设置。假设在一个简单的设置中,上下文 X 是低维的,这样 CB 可以看作是多个 MAB 的组合运行,我们可以基于哪个条件 d 的值更高来选择下一个要服务的条件。

其中 c 是一个常数(根据我们希望在探索时对估计精确度的重视程度来选择),n_x 是上下文 x 到目前为止被看到的次数。
然而,在大多数情况下,上下文 X 会是高维的,这意味着就像在 ε-贪婪情况下,我们需要使用模型。在这种设置中,如果某个条件 d 的值更高,那么它可以被选择为下一个要服务的条件。

其中 SE(.) 是我们估计的标准误差(或更广泛地说,是量化我们当前对该估计的信心水平的度量)。
请注意,UCB 有多个版本,因此你可能会遇到不同的公式。一种流行的 UCB 方法是 LinUCB,它在一个线性模型框架中形式化了问题(例如,[2])。
3.3. 汤普森采样
将要讨论的第三种也是最后一种探索策略是汤普森采样,它是一种解决探索与利用困境的贝叶斯方法。在这里,我们有一个模型f(D, X; Θ),它根据实验条件D、上下文X和一些可学习的参数Θ返回预测的y值。这个函数让我们可以访问任何条件-上下文对的预期y值的后验分布,从而根据给定上下文下产生最高预期y的概率选择下一个要执行的条件。汤普森采样自然地平衡了探索与利用,因为我们是从后验分布中采样并根据观察结果更新我们的模型。为了让这些概念更加具体,以下是汤普森采样涉及的步骤:

实际上,我们可以使用不同的函数来处理每个实验条件,而不是使用单一的函数(例如,评估f_c(X; Θ_c)和f_t(X; Θ_t),然后选择具有较高值的条件)。此外,更新步骤通常不会在每次样本之后进行,而是在看到一批样本之后进行。有关汤普森采样的更多细节,可以参考[3] [4]。
4. 上下文强盗算法步骤
前一节(尤其是关于汤普森采样的部分)应该已经让你对 CB 算法的步骤有了一个相当清晰的了解。然而,为了完整性,以下是标准 CB 算法的逐步描述:
-
一个新的数据点到达,具有上下文X(例如,一个在美国使用 iOS 设备的核心用户)。
-
给定这个数据点和选择的探索策略(例如,ε-贪婪策略),算法决定为该用户执行哪个条件(例如,治疗或对照)。
-
在条件被执行后,我们观察到结果y(例如,用户是否进行了购买)。
-
在看到新数据后更新(或完全重新训练)步骤 2 中使用的模型。(如前所述,我们通常不是在每次样本后进行更新,而是在看到一批样本后进行更新,以确保更新不那么嘈杂。)
-
重复。
5. 上下文强盗中的离线策略评估
到目前为止,我们只讨论了如何在新数据到来时实现 CB 算法。另一个同样重要的话题是如何使用旧(或已记录)数据来评估 CB 算法。这被称为离线评估或离线策略评估(OPE)。
5.1. 使用因果推断方法进行 OPE
进行 OPE 的一种方法是使用广为人知的因果推断技术,如逆倾向得分(IPS)或双重稳健(DR)方法。因果推断在这里是合适的,因为我们实际上是在尝试估计一个反事实,即如果采用不同的策略为用户提供不同的条件会发生什么。这个话题已经有一篇很好的 Medium 文章介绍[5],所以这里我将简要总结文章的主要思想,并将其应用到我们的讨论中。
以 IPS 为例,进行 OPE 通常要求我们不仅知道(i)使用我们新的 CB 算法为样本分配给定条件的概率,还要知道(ii)在记录数据中样本被分配到给定条件的概率。假设以下是一个假设的记录数据,其中X_1-X_3是上下文,D是实验条件,P_O(D)是将D分配给该用户的概率,y是结果。
表 2:来自 A/B 测试的示例记录数据

如你所见,在这个例子中,P_O(D) 对于D=1始终为 0.6,对于D=0始终为 0.4,无论上下文如何,因此可以假设记录的数据来自一个以 0.6 的概率分配处理的 A/B 测试。现在,如果我们想测试 CB 算法在我们使用 CB 算法而非简单 A/B 测试分配条件的情况下会如何表现,我们可以使用以下公式来获取 CB 的累积y的 IPS 估计值。

其中n是记录数据中的样本数量(此处为 5),P_N(D_i)是如果我们使用新的 CB 算法,给user_i分配记录中的D的概率(此概率将取决于被评估的具体算法)。
一旦我们有了这个估计值,我们可以将其与旧 A/B 测试中观察到的累积y进行比较(在这里为 1+0+0+1+1=3),以决定 CB 是否会产生更高的累积y。
有关使用因果推断方法进行 OPE 的更多信息,请参考本节开头链接的文章。该文章还链接了一个非常好的 GitHub 仓库,其中包含许多 OPE 实现。
这里有一个旁注,本节讨论了因果推断方法作为 OPE 中的一种技术。然而,实际上,人们也可以在 CB 算法运行时应用这些方法,以“去偏”算法在过程中收集的训练数据。我们可能希望将像 IPS 这样的技术应用于训练数据的原因是,生成这些数据的 CB 策略本质上是一个非均匀的随机策略,因此,从中估计因果效应来决定采取什么行动时,使用因果推断方法会更有帮助。如果你想了解更多关于去偏的信息,请参考[6]。
5.2. 使用采样方法进行 OPE
进行 OPE 的另一种方法是通过使用采样方法。特别地,可以使用一种非常简单的重放方法[7]来评估 CB 算法(或任何其他算法),该方法使用来自随机化策略(如 A/B 测试)的日志数据。在最简单的形式下(假设我们使用均匀随机日志策略),该方法的工作原理如下:
-
从日志数据中采样下一个具有上下文 X 的用户。
-
使用新的 CB 算法决定分配给该用户的条件。
-
如果选定的条件与日志数据中的实际条件匹配,则将观察到的 y 添加到累积 y 计数器中。如果不匹配,则忽略该样本。
-
重复直到所有样本都被考虑。
如果日志策略没有均匀随机地分配处理,则需要对该方法稍作修改。作者自己提到的一个修改是使用拒绝采样(例如,[8]),在第 3 步中,我们接受来自多数处理的样本的频率会比少数处理的样本少。或者,我们可以考虑在第 3 步中将观察到的 y 除以倾向性,从而“下调”更频繁的处理,并“上调”较少频繁的处理。
在下一部分中,我使用了一个更简单的方法来评估,它通过引导法进行上采样和下采样,将原始的不均匀数据转化为均匀数据,然后按原样应用该方法。
6. 上下文赌博机实际应用
为了演示上下文赌博机的实际应用,我整理了一个 notebook,该 notebook 生成一个模拟数据集并比较新 A/B、MAB 和 CB 策略在该数据集上的累积 y(或“奖励”)估计。该 notebook 中的许多代码来自一本关于强化学习的精彩书籍的“上下文赌博机”章节[9](如果你想深入了解使用 Python 的强化学习,强烈推荐),以及 James LeDoux 的两篇精彩文章[10] [11],并已根据我们在此讨论的设置进行了调整。
设置非常简单:我们拥有的原始数据来自一个 A/B 测试,该测试以 0.75 的概率分配处理给用户(即非均匀随机)。使用这些随机化的日志数据,我们希望基于它们的累积 y 来评估和比较以下三种策略:
-
一种新的 A/B 策略,它以 0.4 的概率随机将处理分配给用户。
-
一个 MAB 策略,它使用ε-贪心策略决定下一步分配什么治疗,但不考虑上下文 X。
-
一个 CB 策略,它使用ε-贪心策略决定下一步分配什么治疗,同时考虑上下文 X。
我修改了 Li 等人在论文中描述的原始方法,不再直接从模拟数据中采样(在我的例子中,模拟数据是 75%治疗和仅 25%控制),而是首先对治疗案例进行下采样,对控制案例进行上采样(都使用替换),从而获得一个新的数据集,确保治疗和控制各占 50%。
我之所以从一个非50%治疗和 50%控制的数据集开始,是为了展示即使原始数据不是来自一个均匀随机分配治疗和控制的策略,我们仍然可以通过对数据进行上下采样,将其处理成一个 50/50%的数据集,并进行离线评估。如前一节所提到的,上下采样的逻辑类似于拒绝采样以及相关的将观察到的 y 除以倾向的概念。
以下图表比较了上述三种策略(A/B vs MAB vs CB)在其累计 y 值上的表现。
图 3:累计奖励比较

如图所示,累计 y 在 CB 策略下增加最快,A/B 策略最慢,而 MAB 则介于两者之间。虽然这个结果基于一个模拟数据集,但这里观察到的模式仍然可以进行概括。A/B 测试无法获得较高的累计 y 的原因在于,即便看到足够的证据表明治疗总体上优于控制,它仍然没有改变 60/40%的分配。另一方面,虽然 MAB 能够动态更新流量分配,但由于没有基于观察到的上下文 X 个性化治疗与控制的分配,它的表现仍然不如 CB。最终,CB 同时动态变化流量分配,并且个性化治疗,因此表现最好。
7. 结论
恭喜你完成了这篇相对较长的文章!在这篇文章中,我们讨论了很多与上下文赌博机(CB)相关的内容,我希望你读完这篇文章后,能更加理解这种方法在在线实验中的实用性,特别是当治疗需要个性化时。
如果你对学习更多关于上下文臂(或者想深入了解多步强化学习)感兴趣,我强烈推荐阅读Mastering Reinforcement Learning with Python一书,作者是 E. Bilgin。这本书的上下文臂章节让我最终理解这个主题的“啊哈!”时刻,我继续阅读以了解更多关于 RL 的知识。至于离线策略评估,我强烈推荐 E. Conti 和 J. LeDoux 的文章,两者都提供了关于所涉技术的很好的解释并附有代码示例。关于上下文臂中的去偏见问题,A. Bietti、A. Agarwal 和 J. Langford 的论文提供了很好的技术概述。最后,虽然本文专注于在构建上下文臂时使用回归模型,但还有一种名为成本敏感分类的替代方法,你可以通过查阅 A. Agarwal 和 S. Kakade 的这些讲义开始学习[12]。
玩得开心,构建上下文臂!
8. 致谢
我要感谢 Colin Dickens 向我介绍上下文臂,并在这篇文章中提供了宝贵的反馈,感谢 Xinyi Zhang 在写作过程中提供的所有有用反馈,感谢 Jiaqi Gu 对采样方法的富有成效的讨论,以及感谢 Dennis Feehan 鼓励我花时间写这篇文章。
除非另有说明,所有图片均为作者所有。
9. 参考文献
1 Z. Zhao 和 T. Harinen,多治疗方案的提升建模与成本优化(2019),DSAA
[2] Y. Narang,使用 LinUCB 的推荐系统:上下文多臂老丨虎丨机方法(2020),Medium
[3] D. Russo, B. Van Roy, A. Kazerouni, I. Osband, and Z. Wen,Thompson 抽样教程(2018),Foundations and Trends in Machine Learning
[4] B. Shahriari, K. Swersky, Z. Wang, R. Adams, and N. de Freitas,摒弃人类的循环:贝叶斯优化综述(2015),IEEE
[5] E. Conti,离线策略评估:减少,优化 A/B 测试(2021),Medium
[6] A. Bietti, A. Agarwal, and J. Langford,上下文臂竞赛(2021),ArXiv
[7] L. Li, W. Chu, J. Langford, and X. Wang,无偏离线评估基于上下文臂的新闻推荐算法(2011),WSDM
[8] T. Mandel, Y. Liu, E. Brunskill, 和 Z. Popovic, 在线强化学习算法的离线评估 (2016), AAAI
[9] E. Bilgin, 用 Python 精通强化学习 (2020), Packt 出版社
[10] J. LeDoux, 使用回放在 Python 中离线评估多臂老丨虎丨机算法 (2020), LeDoux 个人网站
[11] J. LeDoux, Python 中的多臂老丨虎丨机:Epsilon 贪婪法、UCB1、贝叶斯 UCB 和 EXP3 (2020), LeDoux 个人网站
[12] A. Agarwal 和 S. Kakade, 非策略评估与学习 (2019), 华盛顿大学计算机科学系
LoRA 家族概述
LoRA、DoRA、AdaLoRA、Delta-LoRA 等低秩适应的变体。
·发布于Towards Data Science ·阅读时长 17 分钟·2024 年 3 月 10 日
--

LoRA 有多种不同的形式和变体。照片由Lucas George Wendt提供,来源于Unsplash。
Low-Rank Adaptation (LoRA) 可以被认为是朝着高效训练大型语言模型以完成特定任务的一项重大突破。它今天广泛应用于许多领域,并激发了关于如何改进其主要思想以实现更好的性能或更快速训练模型的研究。
在本文中,我将概述一些 LoRA 的变体,它们承诺在不同方面提高 LoRA 的能力。我将首先解释 LoRA 本身的基本概念,然后介绍LoRA+、VeRA、LoRA-FA、LoRA-drop、AdaLoRA、DoRA 和 Delta-LoRA。我将介绍每种方法的基本概念和主要思想,并展示这些方法如何偏离原始的 LoRA。我会省略技术细节,除非它们对基本概念至关重要,并且也不会详细讨论评估。对于有兴趣的读者,我在文末提供了原始论文的链接。
Lora

LoRA 的主要思想是将两个较小的可调矩阵 A 和 B 添加到预训练的权重矩阵 W 旁边,而不改变 W 的参数。图片来源1。
低秩适应(LoRA)1是一种今天广泛用于训练大型语言模型(LLMs)的技术。大型语言模型具有根据自然语言输入预测自然语言令牌的能力。这是一个令人惊叹的能力,但在解决许多问题时,仅凭此能力往往不够。大多数情况下,你希望在给定的下游任务上训练 LLM,例如对句子进行分类或生成给定问题的答案。实现这一点最直接的方式是微调,即使用目标任务的数据训练 LLM 的部分层。 但这意味着需要训练具有数百万到数十亿参数的非常大模型。
LoRA 提供了一种替代的训练方式,由于大幅减少了参数数量,因此训练速度更快,实施更为简便。除了已经预训练的 LLM 层的参数权重外,LoRA 引入了两个矩阵 A 和 B,这些矩阵被称为适配器,它们的尺寸要小得多。如果原始的参数矩阵 W 的大小是d x d,那么矩阵 A 和 B 的大小分别为d x r和r x d,其中r要小得多(通常小于 100)。参数r被称为秩。也就是说,如果你使用秩为r=16的 LoRA,这些矩阵的形状就是16 x d。秩越高,训练的参数越多。这一方面可能带来更好的性能,另一方面也需要更多的计算时间。
现在我们有了这些新的矩阵 A 和 B,那么它们会发生什么呢?输入给 W 的数据同样也会输入到 BA 中,BA 的输出会加到原始矩阵 W 的输出上。也就是说,你在原有基础上训练了一些参数,并将它们的输出添加到原始预测中,这使得你可以影响模型的行为。你不再训练 W,这也是为什么我们有时会说 W 是冻结的。重要的是,A 和 B 的加法不仅仅发生在最后一层(那样只是增加一层),而是可以应用于神经网络内部的深层。
这就是 LoRA 的主要思想,它的最大优势是,训练的参数比微调时少,但仍能获得相似的性能。这里我还想提到一个技术细节:一开始,矩阵 A 用均值为零但方差较小的随机值进行初始化。矩阵 B 则初始化为一个全零矩阵。这确保了 LoRA 矩阵从一开始就不会以随机的方式改变原始 W 的输出。当 A 和 B 的参数被调优到期望方向时,A 和 B 的更新对 W 的输出应当是一个加法操作,而不是随机改变原始输出。然而,我们稍后会看到,由于不同的原因,某些方法偏离了这一思想。
如前所述,LoRA 在当今的 LLM 中被广泛使用。然而,到现在为止,已经有许多 LoRA 的变种,它们在不同的方面偏离了原始方法,旨在提高速度、性能,或两者兼顾。接下来我将向你介绍其中一些。
LoRA+

LoRA+为两个矩阵 A 和 B 引入了不同的学习率,这里通过参数λ来表示。图片来源于[2]。
LoRA+ [2]通过为矩阵 A 和 B 引入不同的学习率,提供了一种更高效的 LoRA 适配器训练方法。在大多数情况下,当训练神经网络时,只有一个学习率应用于所有的权重矩阵。但是,LoRA+的作者展示了,对于 LoRA 中使用的适配器矩阵,使用单一学习率是次优的。通过将矩阵 B 的学习率设置得比矩阵 A 高得多,训练变得更加高效。
这种方法背后有一个理论依据,主要基于神经网络初始化的数值陷阱,尤其是当模型在神经元数量上变得非常宽时。然而,证明这一点的数学推导相当复杂(如果你对这一点非常感兴趣,可以查看原文[2])。直观上,你可能会认为矩阵 B 在初始化为零后,可以使用比随机初始化的矩阵 A 更大的更新步长。此外,实证研究也证明了这种方法的改进。通过将矩阵 B 的学习率设置为矩阵 A 的 16 倍,作者在像 RoBERTa 或 Llama-7b 这样的模型上获得了小幅度的精度提升(大约 2%),同时将训练时间提高了两倍。
VeRA

VeRA 不训练 A 和 B,而是将其初始化为一个随机投影,随后训练额外的向量 d 和 b。图片来源于[3]。
VeRA(Vector-based Random Matrix Adaptation)[3]中,作者提出了一种方法,极大地减少了 LoRA 适配器的参数规模。与其训练矩阵 A 和 B(这是 LoRA 的核心思想),他们用共享随机权重初始化这些矩阵(即所有层中的矩阵 A 和 B 具有相同的权重),并添加了两个新的向量 d 和 b。接下来只训练这些向量 d 和 b。
你可能会想,这怎么可能有效呢?A 和 B 是随机权重的矩阵。如果它们完全没有经过训练,怎么能为模型的性能做出贡献呢?这种方法基于一个有趣的研究领域——所谓的随机投影。许多研究表明,在大型神经网络中,只有一小部分权重被用来引导模型行为,并产生模型在训练任务上的预期表现。由于随机初始化,模型的某些部分(或子网络)从一开始就更有可能贡献于预期的模型行为。在训练过程中,所有参数都会被训练,因为现在已经知道了哪些子网络是重要的。这使得训练非常昂贵,因为大多数更新的参数并未对模型的预测产生任何价值。
基于这一思想,有一些方法仅训练这些相关的子网络。通过在矩阵后添加投影向量,而不是直接训练子网络本身,也可以获得类似的行为。由于矩阵与向量的乘法,这可以产生与调整矩阵中的一些稀疏参数相同的输出。这正是 VeRA 的作者提出的,通过引入向量 d 和 b 来实现训练,而矩阵 A 和 B 则被冻结。此外,与原始的 LoRa 方法不同,矩阵 B 不再被设为零,而是像矩阵 A 一样随机初始化。
这种方法自然导致的参数数量远小于完整的矩阵 A 和 B。例如,如果你为 GPT-3 引入秩为 16 的 LoRA 层,你将拥有 7550 万个参数。而使用 VeRA 时,你只有 280 万个参数(减少了 97%)。但是,参数数量如此之小,性能如何呢?VeRA 的作者使用了一些常见的基准,如 GLUE 或 E2E,并使用基于 RoBERTa 和 GPT2 Medium 的模型进行了评估。他们的结果表明,VeRA 模型的性能仅比完全微调或使用原始 LoRa 技术的模型略低。
LoRA-FA

LoRA-FA 冻结矩阵 A,只训练矩阵 B。图片来自[4]。
另一种方法,LoRA-FA [4],即带有Frozen-A的 LoRA,朝着与 VeRA 类似的方向发展。在 LoRA-FA 中,矩阵 A 在初始化后被冻结,因此作为一个随机投影存在。不同于添加新向量的是,矩阵 B 在初始化为零后被训练(就像原始 LoRA 一样)。这样可以将参数数量减半,同时保持与普通 LoRA 相当的性能。
LoRa-drop

LoRA-drop 使用 B*A 的输出决定哪些 LoRA 层值得训练。图片来自[5]。
一开始,我解释过,你可以将 LoRA 矩阵添加到神经网络的任何一层。LoRA-drop [5] 提出了一种算法,用来决定哪些层值得通过 LoRA 进行增强,哪些则不值得这么做。即使训练 LoRA 适配器比微调整个模型便宜得多,但你添加的 LoRA 适配器越多,训练的成本仍然会越高。
LoRA-drop 包含两个步骤。第一步,你从数据中抽取一个子集并训练 LoRA 适配器若干迭代。然后,你计算每个 LoRA 适配器的重要性,计算公式为 BAx,其中 A 和 B 是 LoRA 矩阵,x 是输入。这只是 LoRA 适配器的输出,它会加到每个冻结层的输出中。如果这个输出很大,它会显著改变冻结层的行为。如果输出很小,这表明 LoRA 适配器对冻结层的影响微乎其微,完全可以省略。
鉴于这一重要性,你现在可以选择最重要的 LoRA 层。有多种方法可以做到这一点。你可以将重要性值加总,直到达到一个由超参数控制的阈值,或者直接选择前 n 个具有最高重要性的 LoRA 层,其中 n 是固定的。无论哪种方式,在下一步中,你将在整个数据集上进行完整训练(记住,你在前几步使用的是数据的子集),但只针对你刚才选择的那些层。其他层将固定为一组共享的参数,在训练过程中不再改变。
因此,LoRA-drop 算法允许仅使用部分 LoRA 层进行模型训练。作者提供了实证证据,表明与训练所有 LoRA 层相比,准确率变化仅为微不足道,但由于需要训练的参数数量减少,计算时间大大缩短。
AdaLoRA

AdaLoRA 允许动态调整 LoRA 矩阵的秩。照片来自 Hasmik Ghazaryan Olson 上传于 Unsplash。
还有其他方式可以决定哪些 LoRA 参数比其他参数更重要。在这一部分,我将介绍 AdaLoRA [6],它代表了 自适应 LoRA。这里 LoRA 的哪个部分是自适应的?就是 LoRA 矩阵的秩(即大小)。主要问题与上一部分相同:并非每一层都值得添加 LoRA 矩阵 A 和 B,但对于某些层,LoRA 训练可能比其他层更为重要(即可能会对模型行为产生更大的影响)。为了决定这种重要性,AdaLoRA 的作者提出考虑 LoRA 矩阵的奇异值作为其重要性的指标。
这是什么意思呢?首先,我们必须理解,矩阵乘法也可以看作是对向量应用一个函数。在处理神经网络时,这一点是非常明显的:大多数时候你将神经网络视作一个函数,也就是说,你给定一个输入(例如,一组像素值的矩阵),然后得到一个结果(比如,图像的分类)。在幕后,这个函数的应用是通过一系列矩阵乘法来实现的。现在,假设你想要减少矩阵中的参数数量。这样会改变函数的行为,但你希望它的变化尽可能小。一个方法是计算矩阵的特征值,它们告诉你每一行矩阵捕获了多少方差。然后,你可以决定将一些捕获较小方差的行设为零,这些行对函数贡献不大。AdaLoRA 的主要思想就是基于上述的奇异值,它们正是特征值的平方根。也就是说,基于奇异值,AdaLoRA 决定了哪些 LoRA 矩阵的行更重要,哪些可以被省略。这有效地缩小了一些矩阵的秩,这些矩阵中有许多行并没有太大贡献。然而,需要注意的是,与上一节的 LoRA-drop 存在一个重要的区别:在 LoRA-drop 中,层的适配器要么被完全训练,要么根本不训练。而 AdaLoRA 还可以决定保留某些层的适配器,但使用较低的秩。这意味着,最终不同的适配器可以有不同的秩(而在原始的 LoRA 方法中,所有适配器的秩是相同的)。
AdaLoRA 方法还有一些细节,我为了简洁起见省略了它们。不过,我想提到其中的两个:首先,AdaLoRA 方法并不是每次都显式地计算奇异值(因为那样做会非常昂贵),而是通过奇异值分解来分解权重矩阵。这种分解是另一种表示相同信息的方式,但它允许直接获取奇异值,而无需进行昂贵的计算。其次,AdaLoRA 不仅仅依赖奇异值,还考虑了损失函数对某些参数的敏感性。如果将某个参数设为零对损失有很大影响,那么这个参数就被认为具有较高的敏感性。在决定压缩秩时,除了奇异值外,还会考虑一行元素的平均敏感性。
通过比较 AdaLoRA 与标准 LoRA(在相同秩预算下)的结果,可以获得该方法有效性的实证证据。也就是说,这两种方法总参数数量相同,但分布方式不同。在 LoRA 中,所有矩阵的秩相同,而在 AdaLoRA 中,部分矩阵的秩较高,部分较低,最终导致相同数量的参数。在许多场景中,AdaLoRA 比标准 LoRA 方法取得更好的分数,表明它能更好地分配模型的可训练参数,尤其是在对于给定任务至关重要的部分。以下图表展示了 AdaLoRA 如何为给定模型分配秩。如图所示,AdaLoRA 将较高的秩分配给模型后面的层,表明调整这些层更为重要。

在网络的不同层中,LoRA 矩阵被赋予不同的秩。通常,在后面的层中,秩会更高。图片来源:[6]。
DoRA

在 DoRA 中,权重矩阵 W 被分解为大小 m 和方向 V,这两个部分是独立调整的。图片来源:[7]。
改进 LoRA 以获得更好性能的另一种方法是权重-分解低秩适应(Weight-Decomposed Low-Rank Adaption),简称DoRA[7]。DoRA 的核心思想是,每个矩阵都可以分解为大小和方向的乘积。对于二维空间中的一个向量,你可以很容易地可视化这一点:一个向量不过是从零位置出发,指向向量空间中某个点的箭头。通过向量的条目,你可以指定那个点,例如,如果你的空间有 x 和 y 两个维度,你可以说 x=1 和 y=1。或者,你也可以通过指定大小和角度(即方向)来以不同的方式描述这个点,比如 m=√2 和 a=45°。这意味着你从零点开始,沿着 45°方向移动,箭头长度为√2,最终会到达相同的点(x=1,y=1)。
这种大小和方向的分解方法同样可以应用于更高阶的矩阵。DoRA 的作者将此方法应用于描述模型训练步骤中更新的权重矩阵,适用于普通微调训练的模型和使用 LoRA 适配器训练的模型。我们在下图中可以看到这两种技术的对比:

微调和 LoRA 在大小和方向的变化关系上有所不同。图片来源:[7]。
我们看到两个图表,一个是微调模型(左)的图表,一个是使用 LoRA 适配器训练的模型(右)的图表。在 x 轴上,我们看到方向的变化,在 y 轴上,我们看到幅度的变化,图中的每个散点代表模型的一层。两种训练方式之间有一个重要的区别。在左图中,方向更新与幅度更新之间存在较小的负相关,而在右图中,存在更强的正相关关系。你可能会想,哪种方式更好,或者这是否有什么意义。请记住,LoRA 的主要理念是以较少的参数实现与微调相同的性能。这意味着,理想情况下,我们希望 LoRA 的训练与微调共享尽可能多的特性,只要这不会增加成本。如果微调中方向与幅度之间的相关性稍微是负的,这对于 LoRA 来说可能是一个理想的特性,前提是它可以实现。换句话说,如果 LoRA 中方向与幅度的关系与全微调不同,这可能是 LoRA 有时表现不如微调的原因之一。
DoRA 的作者提出了一种方法,通过将预训练矩阵 W 分解为幅度向量 m(大小为1 x d)和方向矩阵 V,独立地训练幅度和方向。然后,方向矩阵 V 通过 B*A 进行增强,正如标准 LoRA 方法所示,m 则保持不变进行训练,因为它只有一个维度。尽管 LoRA 倾向于同时改变幅度和方向(如这两者之间的高度正相关所示),但 DoRA 可以更轻松地单独调整一个,而不影响另一个,或通过负向变化来补偿一个的变化。我们可以看到方向和幅度之间的关系更像是在微调中的关系:

对于 DoRA,幅度和方向之间的关系更像是在微调中的关系。图片来自[7]。
在多个基准测试中,DoRA 在准确性上优于 LoRA。将权重更新分解为幅度和方向可能使 DoRA 执行更接近于微调中的训练,同时仍使用 LoRA 引入的较小参数空间。
Delta-LoRA

Delta-LoRA 并不冻结矩阵 W,而是通过从 B*A 获得的梯度更新它。图片来自[8]。
Delta-LoRA [8]引入了另一个改进 LoRA 的想法。这一次,预训练矩阵 W 再次发挥了作用。记住,LoRA 的主要思想是不要(!)调整预训练矩阵 W,因为那样代价太高(这将是正常的微调)。这就是为什么 LoRA 引入了新的较小矩阵 A 和 B。然而,这些较小的矩阵在学习下游任务时的能力较弱,这也是为什么 LoRA 训练的模型性能通常低于微调模型性能的原因。在训练过程中调整 W 是非常理想的,但我们如何负担得起呢?
Delta-LoRA 的作者提议通过 AB 的梯度来更新矩阵 W,即 AB 在两个连续时间步之间的差异。这个梯度会通过某个超参数λ进行缩放,λ控制新训练对预训练权重的影响程度,然后将其加到 W 中(同时α和 r(秩)是原始 LoRA 设置中的超参数):

W 通过两步之间 AB 的差值来更新。图片来源于[8]。
这引入了更多的参数进行训练,而几乎没有计算开销。我们不需要像在微调中那样为整个矩阵 W 计算梯度,而是利用在 LoRA 训练过程中已经得到的梯度来更新它。作者在多个基准测试中使用 RoBERTA 和 GPT-2 等模型对比了这种方法,并发现其在性能上优于标准的 LoRA 方法。
总结

恭喜你,已经读完了。图片来源:david Griffiths来自Unsplash
我们刚刚看到了多种方法,它们在 LoRA 的核心思想上有所变化,目的是减少计算时间或提高性能(或两者兼具)。最后,我将简要总结这些不同的方法:
-
LoRA引入了低秩矩阵 A 和 B 进行训练,同时预训练的权重矩阵 W 保持冻结。
-
LoRA+建议 B 的学习率远高于 A。
-
VeRA不训练 A 和 B,而是随机初始化它们,并在其上训练新的向量 d 和 b。
-
LoRA-FA仅训练矩阵 B。
-
LoRA-drop使用 B*A 的输出来确定哪些层值得进行训练。
-
AdaLoRA动态地调整 A 和 B 在不同层中的秩,允许在这些层中使用更高的秩,尤其是在期望对模型性能贡献较大的层中。
-
DoRA将 LoRA 适配器分为幅度和方向两个组件,并允许它们更加独立地进行训练。
-
Delta-LoRA通过 A*B 的梯度来改变 W 的权重。
LoRA 及相关方法的研究领域非常丰富且生动,几乎每天都有新的贡献。在本文中,我想解释一些方法的核心思想。当然,这只是其中的一部分,远远不能算作完整的综述。
我希望我能够与您分享一些知识,并可能激发您产生新的想法。正如我们所看到的,LoRA 及相关方法是一个具有巨大潜力的研究领域。预计在提升大规模语言模型训练性能或计算时间方面,新的突破很快就会出现。
参考文献及进一步阅读
这些是本文中解释的概念相关的论文:
-
1 LoRA: Hu, E. J., Shen, Y., Wallis, P., Allen-Zhu, Z., Li, Y., Wang, S., … & Chen, W. (2021). LoRA: 大规模语言模型的低秩适应. arXiv 预印本 arXiv:2106.09685.
-
[2] LoRA+: Hayou, S., Ghosh, N., & Yu, B. (2024). LoRA+: 大规模模型的高效低秩适应. arXiv 预印本 arXiv:2402.12354.
-
[4]: LoRA-FA: Zhang, L., Zhang, L., Shi, S., Chu, X., & Li, B. (2023). Lora-fa: 针对大规模语言模型微调的内存高效低秩适应. arXiv 预印本 arXiv:2308.03303.
-
[5] LoRA-drop: Zhou, H., Lu, X., Xu, W., Zhu, C., & Zhao, T. (2024). LoRA-drop: 基于输出评估的高效 LoRA 参数剪枝. arXiv 预印本 arXiv:2402.07721.
-
[8]: Delta-LoRA: Zi, B., Qi, X., Wang, L., Wang, J., Wong, K. F., & Zhang, L. (2023). Delta-lora: 用低秩矩阵的增量微调高秩参数. arXiv 预印本 arXiv:2309.02411.
对于关于随机投影的一些核心思想,正如在 VeRA 一节中提到的,这篇文章是该领域的主要贡献之一:
若需要更细致地解释 LoRA 和 DoRA,我可以推荐这篇文章:
喜欢这篇文章吗? 关注我 以便接收我未来文章的通知。
一种被不公正遗忘的相关系数
一种适用于日常任务的非线性相关度量
·发表于 Towards Data Science ·阅读时长 6 分钟·2024 年 4 月 30 日
--

图片由作者使用 recraft.ai 创建
传统的相关系数,如皮尔逊ρ,斯皮尔曼或肯德尔τ,仅限于寻找线性或单调关系,并且在识别更复杂的关联结构时存在困难。关于一种新的相关系数 ξ 的近期文章 1,该系数旨在克服这些限制,已引起广泛关注并进行了深入讨论。评论中提出的一个问题是,ξ 相比基于互信息的非线性相关度量,具有哪些特别的优势。在这种辩论中,实验可能胜过千言万语。因此,在这个故事中,我将从多个方面实验性地比较 ξ 与基于互信息的系数 R,这些方面是一个非线性相关度量应满足的特性。基于结果,我强烈建议在大多数需要寻找非线性关联的常规操作中使用 R,而不是 ξ。
要求
让我首先总结并说服你,我们正在寻找的系数应具备的期望特性。我们希望找到一个关联度量 A(x,y),它
-
是非线性的。也就是说,当 x 和 y 独立时,它的值为零;当变量之间存在确切的非线性关系时,如 x = h(t), y=f(t),其中 t 是一个参数时,它的度量值为一;
-
是对称的。也就是说,A(x,y)=A(y,x)。反之则会令人困惑;
-
是一致的。也就是说,当x, y服从双变量正态分布时,它等于线性相关系数ρ,即它是ρ对其他分布的推广。这是因为ρ在实践中广泛使用,我们中的许多人已经培养出了对其数值与关系强度之间联系的直觉。此外,ρ对于标准正态分布有明确的意义,因为它完全定义了标准正态分布;
-
是可扩展的——即使对于具有大量观测值的数据集,也可以在合理的时间内计算相关性;
-
是精确的,即具有低方差的估计器。
下表总结了我的实验结果,其中绿色表示度量具有测试的特性,红色表示相反,橙色稍好于红色。现在让我带你走过这些实验;你可以在这个Github 仓库 [2]中找到它们的代码,使用R编程语言编写。

图片由作者创建
相关系数
我使用以下系数的实现及其配置
-
对于线性相关系数ρ,我使用‘stats’包中的标准函数
cor(); -
对于ξ,我使用‘XICOR’包中的
xicor()函数[3]; -
互信息(MI)取值范围是[0,∞),并且有多种方法可以估计它。因此,对于 R,必须选择(a)要使用的 MI 估计器和(b)将 MI 转换到[0,1]范围内的变换。
目前有基于直方图和基于最近邻的 MI 估计器。尽管许多人仍然使用基于直方图的估计器,但我认为 Kraskov 的最近邻估计器[4]是最好的之一。我将使用它在‘FNN’包中的实现mutinfo()[5],并使用论文中建议的参数k=2。
如果你想了解更多关于这个特定的互信息(MI)估计器,请在评论中写下。
也有几种方法可以将 MI 归一化到[0,1]的区间。我将使用下面的方法,因为它已被证明具有一致性特性,我将在实验中展示这一点。

这个度量 R 称为互信息系数[6]。然而,我注意到人们有时会把它与较新的最大信息系数(MIC)[7]混淆。后者已被证明比一些替代方法差[8],并且缺乏它应有的一些特性[9]。
非线性
在下图中,我为一个具有不同甜甜圈厚度的 10K 个点的数据集计算了所有三个相关系数。如预期的那样,线性相关系数ρ没有捕捉到任何图形中的关系。相反,R正确地确定了x和y是相关的,并且对于右侧图形的数据,R的值为 1,这对应于x和y之间无噪声的关系:x = cos(t)和y = sin(t)。然而,在后者的情况下,系数ξ仅为 0.24。更重要的是,在左侧图中,尽管x和y并非独立,ξ却接近零。

图像由作者创建
对称性
在下图中,我为从不同分布生成的数据集计算了这些量。我得到了ρ(x,y)=ρ(y,x)和R(x,y)=R(y,x),因此对于这些量我只报告一个单一的值。然而,ξ(x,y)和ξ(y,x)则差异很大。这可能是因为y=f(x),但x并不是y的函数。这种行为在现实中可能并不理想,因为解读一个非对称的相关矩阵并不容易。

图像由作者创建
一致性
在这个实验中,我计算了所有系数,对于由具有给定相关系数 0.4、0.7 或 1 的双变量标准正态分布生成的数据集。ρ和R接近真实的相关性,而ξ则不是,即它没有上述定义的一致性特性。

图像由作者创建
可扩展性
为了检查估计量的性能,我生成了不同大小的数据集,这些数据集由两个独立且均匀分布的变量组成。下图展示了计算每个系数所需的时间(单位:毫秒)。当数据集包含 50K 个点时,R的计算速度比ξ慢大约 1000 倍,比ρ慢大约 10000 倍。然而,计算仍然需要大约 10 秒,这在计算适量的相关性时是合理的。考虑到上述提到的R的优势,我建议即使是计算大量相关性时也使用它——只需随机抽样你的数据至大约 10K 个点,此时计算R所需的时间不到一秒。

图像由作者创建
精确度
对于来自同一分布的不同样本,相关系数的估计值会有所不同。如果 x 和 y 之间存在关联,我们希望这些估计值的方差相较于相关系数的均值较小。对于一个度量 A(x,y),可以计算 precision=sd(A)/mean(A),其中 sd 是标准差。该值越小越好。下表展示了在不同大小的数据集上,计算的 precision 值,这些数据集具有不同的维度相关值。ξ 是最不精确的,而 ρ 是最精确的。

图片由作者创建
参考文献
[2] 我在 Github 上的实验
[3] R 的 XICOR 包
[4] Kraskov, A., Stögbauer, H., & Grassberger, P. (2004). 估计互信息。Physical review E, 69(6), 066138.
[5] R 的 FNN 包
[6] Granger, C., & Lin, J. L. (1994). 使用互信息系数识别非线性模型中的滞后。Journal of time series analysis, 15(4), 371–384.
[7] Reshef, D. N., Reshef, Y. A., Finucane, H. K., Grossman, S. R., McVean, G., Turnbaugh, P. J., … & Sabeti, P. C. (2011). 在大数据集中检测新型关联。science, 334(6062), 1518–1524.
[8] Simon, N., & Tibshirani, R. (2014). 对 Reshef 等人《在大数据集中检测新型关联》的评论,发表于《Science》2011 年 12 月 16 日。arXiv 预印本 arXiv:1401.7645.
[9] Kinney, J. B., & Atwal, G. S. (2014). 公平性、互信息和最大信息系数。Proceedings of the National Academy of Sciences, 111(9), 3354–3359.
使用 Friedman 的 H-stat 和 Python 分析特征交互
使用 artemis 包应用 H-stat 并解释成对、总体和未标准化的度量
·发表于Towards Data Science ·8 分钟阅读·2024 年 6 月 21 日
--

(来源:作者)
Friedman 的 h 统计量(h-stat)为复杂的机器学习模型提供了一个强大的观察窗口。具体来说,它帮助我们了解模型是否通过特征之间的交互来进行预测。我们将看到,这种 XAI 方法可以告诉我们某个特征是否与其他任何特征或某个特定特征发生交互。为此,我们将:
-
使用artemis Python 包应用 h-stat。
-
解释输出结果,包括交互热图和条形图。
你可以在GitHub上找到完整的项目。你也可以观看关于该主题的视频。如果你想了解更多,可以查看我的课程——使用 Python 的 XAI。如果你订阅我的新闻通讯,你将获得免费访问权限。
什么是 H-stat?
交互是指特征与目标变量之间的关系依赖于另一个特征的值……
每个数据科学家应该了解的分析框架
为什么我认为我在麦肯锡的经历让我成为了一名更好的数据科学家
·发布于 Towards Data Science ·11 分钟阅读·2024 年 8 月 28 日
--

图片由作者提供(通过 Midjourney 辅助创作)
与许多科技领域的数据科学家不同,我的数据科学职业生涯始于咨询行业,我认为这是我做出的最佳职业选择。不管你怎么评价咨询文化和工作时长,我在麦肯锡的两年里学到了很多,至今每天都从中受益。
作为一名经理,我的工作之一是指导团队中的数据科学家,帮助他们在项目和职业发展方面取得进步。我意识到,初级数据科学家最常面临的挑战通常不是技术或执行部分——那是容易教、容易学的部分。
通常,工作中更抽象/软技能相关的部分是大多数人不知道如何应对的——例如如何将一个抽象的商业问题分解为更小、定义明确的分析,这些分析最终可以带来实际的商业影响。
这些是我作为顾问每天练习的内容,我认为这些经验对数据科学非常有帮助。
为了帮助我的数据科学家同行,我想总结一下我在咨询工作中的收获,让你们能够在不经历同样磨砺的情况下从中受益。
使用 Python 分析门多塔湖冰冻现象
一个可以立即使用的分析准备好的数据集
·发表于 Towards Data Science ·9 分钟阅读·2024 年 6 月 10 日
--

图片由 Dave Hoefler 提供,来自 Unsplash
位于威斯康星州麦迪逊的门多塔湖被认为是世界上最被研究的湖泊之一,拥有自 19 世纪 50 年代以来的冰冻现象记录。由威斯康星大学气候学办公室维护的这一记录,已经取得了令人瞩目的进展,使得数据变得可以下载。然而,仅仅将冰面冻结日期随时间变化的简单图表仍然需要用户进行一些处理才能使其易于阅读。正因如此,我创建了mendotapy,一个开源(MIT 许可)Python 包,它允许以分析准备好的格式提取这些宝贵数据。它内置了一些工具,使得常见的转换操作变得轻松。
从气候角度来看,我觉得这个数据集非常有趣,但我认为任何人都能发现这个数据集的相关性。例如,它可以作为一个玩具数据集,用来练习回归分析、统计预测或谱分析。
在这篇文章中,我将描述这个包,展示一些数据的图表,并总结我在制作这个包时遇到的挑战。生成这些图形的所有代码都可以在本文中找到,包的代码可以在GitHub上找到(欢迎贡献)。
安装和使用该包
分析电动汽车购买的利弊:来自报纸新闻的见解
使用 Mistral LLM 进行全面评估以指导明智决策
·发表于Towards Data Science ·阅读时长 9 分钟·2024 年 7 月 1 日
--

myenergi 在 Unsplash
在今天的市场中,购买电动汽车是一个重要的挑战,也是我们必须仔细评估的购买决策。许多消费者面临选择传统燃油车或电动汽车的不确定性。这一决策至关重要,因为买车涉及重大经济支出。
关于是否购买电动汽车的信息仍然不明确,因为市场中对于电动汽车可能完全普及的前景存在不确定性。此外,许多关于电动汽车的网站和新闻已经过时,提供的信息并未反映实际情况,导致很难做出明智的决策。
因此,本文的目标是提供对电动汽车领域近期新闻的分析。通过分析报纸新闻,我们可以识别出与今天购买电动汽车相关的正面和负面因素。这些近期新闻的分析可以为我们提供对当前情况的更新视角,使我们能够做出基于最新数据的决策,而不是依赖过时的信息做出决策。
使用嵌入模型和大语言模型(LLMs)分析非结构化 PDF 数据
如何将 PDF 转化为可操作的见解
·发布于Towards Data Science ·阅读时长 8 分钟·2024 年 6 月 15 日
--

图片由作者提供
我们生活在一个数据极其杂乱的世界里,几乎每个数据科学家或分析师都知道这一点。几乎可以说,无论你开始哪个数据科学项目,整个项目最困难的部分都是获取、打包和清洗数据,以便能够提供有价值的见解。如今大多数公司和个人面临的问题并不一定是数据的获取(那方面数据很多),而是如何打包和清洗那些非结构化的、杂乱无章的数据。
以公共公司的 SEC 文件中的 PDF 为例。这些文件包含数百页,详细说明了重要的公司信息。此外,新的文件每季度和每年都会创建,从公司上市以来到最新的文件为止。这些文件包含的数据非常庞大,但却不是传统意义上的数据。这些文档包含表格、图表、摘要、解释等大量内容。分析这些文档的问题不是因为它们缺乏信息(当然,有时候确实如此),而是更大的问题在于找到一种快速简便的方式来提取对个体用户最重要的细节。最后这一点可能是最重要的:个体用户。对于我来说…
Polars 查询的构造:Polars 与 SQL 的语法比较
从 Pandas 到 Polars 的轻松过渡——通过在 SQL 中短暂停留。
·发布于 Towards Data Science ·7 分钟阅读·2024 年 3 月 19 日
--
秘密已经揭晓!Polars是当前最火的技术,每个人都想尝一块 😎
我最近写了一篇文章,《我永久转换从 Pandas 到 Polars 的 3 个原因》,因为,嗯,这是选择 Polars 的最常见使用场景——作为 Pandas 的替代品。然而,尽管这是最常见的用例,从 Pandas 到 Polars 的过渡可能会有些奇怪,因为两者在语法上有很大的不同。
在我之前的博客文章中,我讨论了 Pandas 如何迫使用户以面向对象的编程方法进行数据查询,而 Polars 则允许用户以面向数据的编程方法进行数据查询,类似于 SQL。因此,尽管 Polars 通常作为 Pandas 的替代品,但如果你正在尝试学习 Polars,与其将其与 Pandas 进行比较,不如将其与 SQL 进行比较,这可能是一个更容易的起点。本文的目标正是做到这一点:将 Polars 的语法与 SQL 的语法进行比较,为开始使用 Polars 打下基础。
窗口函数的结构
一个被低估的 SQL 操作的理论与实践
·发表于 Towards Data Science ·12 分钟阅读·2024 年 6 月 11 日
--

图片来源:Marcus Woodbridge 于 Unsplash
介绍
IT 领域以其不断变化而闻名,每天都有新的工具、新的框架、新的云服务提供商和新的大语言模型(LLMs)被创造出来。然而,即便在这个忙碌的世界中,一些原则、范式和工具似乎在挑战 现状——“没有什么是永恒不变的”。而在数据领域,没有什么能像 SQL 语言一样,给人留下如此深刻的印象。
自从它在 80 年代诞生以来,SQL 穿越了数据仓库的时代,实现在 Hadoop/数据湖/大数据中如 Hive 的形式,直到今天,作为 Spark API 之一仍然存在。世界变化了很多,但 SQL 不仅依然存在,而且非常重要和广泛应用。
但是 SQL 就像国际象棋,理解基本规则很容易,但要精通却很难!它是一种有着众多可能性、解决同一问题的多种方式、许多函数和关键字的语言,遗憾的是,它还有许多被低估的功能。如果这些功能被更好地了解,可能在构建查询时能为我们带来很大的帮助。
因此,在这篇文章中,我想谈谈我在构建日常查询时发现非常有用的一个不太为人所知的 SQL 特性:窗口函数。
什么是窗口函数
传统的最著名的关系型数据库(如 PostgreSQL、MySQL 和 Oracle)是基于关系代数概念的。在其中,行称为元组(tuples),表格称为关系(relations)。关系是一个元组的集合(从数学意义上讲),也就是说,它们之间没有顺序或连接。因此,表中的行没有默认顺序,并且在一行上执行的计算不会影响也不会受其他行结果的影响。即使是像 ORDER BY 这样的子句,也仅仅对表格进行排序,无法在一行中基于其他行的值进行计算。
简单来说,窗口函数解决了这个问题,扩展了 SQL 功能,允许我们基于其他行的值在一行中执行计算。
理解的基本案例/解剖结构
1-无聚合的聚合
理解窗口函数的最简单例子是能够“无聚合的聚合”。
当我们使用传统的 GROUP BY 进行聚合时,整个表格会被压缩成一个新表,其中每一行代表一个组的元素。使用窗口函数时,不是压缩行,而是可以在同一表格中创建一个包含聚合结果的新列。
例如,如果你需要将支出表中的所有支出加总,传统方法是这样做的:
SELECT SUM(value) AS total FROM myTable
使用窗口函数时,你会做如下操作:
SELECT *, SUM(value) OVER() FROM myTable
-- Note that the window function is defined at column-level
-- in the query
下图显示了结果:

图 1. 传统的 Group By 与窗口函数。
与其创建一个新表,它将在一个新列中返回聚合值。请注意,值是相同的,但表并没有被“汇总”,原始的行被保留——我们只是进行了一个无聚合的聚合 😉
OVER 子句是我们创建窗口函数的标志。该子句定义了计算将在哪些行上进行。上述代码中它是空的,因此它将在所有行上计算 SUM()。
当我们需要基于列的总计(或平均值、最小值、最大值)进行计算时,这非常有用。例如,要计算每项支出相对于总支出的百分比。
在实际案例中,我们可能还希望按某些类别查看详细信息,就像图 2 中的例子那样,其中列出了按部门划分的公司支出。同样,我们可以通过简单的 GROUP BY 来获得每个部门的总支出:
SELECT depto, sum(value) FROM myTable GROUP BY depto
或者在窗口函数中指定 PARTITION 逻辑:
SELECT *, SUM(value) OVER(PARTITION BY depto) FROM myTable
查看结果:

图 2. 传统的 Group By 与窗口函数 II。
这个例子有助于理解为什么该操作被称为“窗口”函数——OVER 子句定义了一组行,函数将在这些行上执行操作,表中的一个“窗口”。
在上面的例子中,SUM() 函数将在由 depto 列(RH 和 SALES)创建的分区中操作 —— 它将分别计算 depto 列中每个项目的“value”列的所有值。该行所属的组(RH 或 SALES)决定了“Total”列中的值。
2 — 时间和排序意识
有时我们需要基于其他行的值来计算某一行中的列的值。一个经典的例子是计算一个国家 GDP 的年增长率,使用当前值和前一年的值来计算。
这类计算,比如需要前一年值、当前行与下一行的差异、一系列中的第一个值等等,都是窗口函数强大功能的体现。事实上,我不知道是否可以通过标准 SQL 命令实现这种行为!可能可以,但会是一个非常复杂的查询……
但是,窗口函数使得这一过程变得简单,见下图(表格记录了某些孩子的身高):

图像 3. 分析函数示例。
SELECT
year, height,
LAG(height) OVER (ORDER BY year) AS height_last_year
FROM myTable
LAG( ‘column’ ) 函数负责引用前一行的‘column’值。你可以把它想象成一个步骤序列:在第二行,考虑第一行的值;在第三行,考虑第二行的值;依此类推……第一行不算(因此是NULL),因为它没有前一个值。
自然地,需要一些排序标准来定义“前一行”是什么。这是窗口函数中的另一个重要概念:分析函数。
与传统 SQL 函数不同,分析函数(如 LAG)认为行之间是有排序的 —— 这个排序是由 OVER() 内的 ORDER BY 子句定义的,也就是说,第一行、第二行、第三行等概念是在 OVER 关键字内定义的。这些函数的主要特点是能够相对于当前行引用其他行:LAG 引用上一行,LEAD 引用下一行,FIRST 引用分区中的第一行,依此类推。
LAG 和 LEAD 的一个优点是,它们都接受第二个参数,即偏移量,指定向前(对于 LEAD)或向后(对于 LAG)查看多少行。
SELECT
LAG(height, 2) OVER (ORDER BY year) as height_two_years_ago,
LAG(height, 3) OVER (ORDER BY year) as height_three_years_ago,
LEAD(height) OVER (ORDER BY year) as height_next_year
FROM ...
而且,使用这些函数进行计算也是完全可行的:
SELECT
100*height/(LAG(height) OVER (ORDER BY year))
AS "annual_growth_%"
FROM ...
3 — 时间意识与聚合
时间与空间其实是统一的 —— 爱因斯坦曾经这么说过,或者类似的话,我不太确定 ¯_(ツ)_/¯
现在我们知道如何进行分区和排序,可以将这两者结合使用!回到之前的例子,假设桌子上有更多的孩子,我们需要计算每个孩子的增长率。非常简单,只需将排序和分区结合起来!我们可以按年份排序,并按孩子的名字进行分区。
SELECT 1-height/LAG(height) OVER (ORDER BY year PARTITION BY name) ...

图像 4. ORDER BY + PARTITION BY
上述查询执行以下操作——按子项对表进行分区,并在每个分区中按年份对值进行排序,随后用当前年份的高度值除以前一年的值(并从结果中减去 1)。
我们正逐渐接近‘窗口’的完整概念!它是一个表切片,是一组按 PARTITION BY 中定义的列分组的行,按 ORDER BY 中的字段排序,其中所有计算仅考虑同一组(分区)中的行以及特定的排序。
4-排名与位置
窗口函数可以分为三类,其中两类我们已经讨论过:聚合函数(COUNT、SUM、AVG、MAX 等)和分析函数(LAG、LEAD、FIRST_VALUE、LAST_VALUE 等)。
第三组是最简单的——排名函数,其中最常用的函数是 row_number(),它返回一个整数,表示行在分组中的位置(基于定义的顺序)。
SELECT row_number() OVER(ORDER BY score)
排名函数,顾名思义,根据行在分组中的位置返回值,该位置由排序标准定义。ROW_NUMBER、RANK 和 NTILE 是最常用的几个函数。

图片 5:排名函数示例
在上面的图片中,行号是根据每个玩家的得分生成的。
…是的,它犯了一个可怕的编程错误,那就是从 1 开始。
5-窗口大小
到目前为止,所有展示的函数在计算结果时都会考虑分区/组中的所有行。例如,第一示例中描述的 SUM()函数会考虑所有部门的行来计算总和。
但可以指定一个更小的窗口大小,也就是计算时需要考虑当前行前后多少行。这是一个有用的功能,用于计算移动平均/滚动窗口。
让我们考虑以下示例,表格中包含某种疾病的每日病例数,我们需要计算考虑当前日期和前两天的平均病例数。请注意,可以使用之前展示的 LAG 函数来解决这个问题:
SELECT
( n_cases + LAG(n_cases, 1) + LAG(n_cases, 2) )/3
OVER (ORDER BY date_reference)
但我们可以使用框架的概念更优雅地实现相同的结果:
SELECT
AVG(n_cases)
OVER (
ORDER BY date_reference
ROWS BETWEEN 2 PRECEDING AND CURRENT ROW
)
上述框架指定了我们必须仅计算前两行(PRECEDING)和当前行的平均值。如果我们希望考虑前一行、当前行和后一行,可以更改框架:
AVG(n_cases)
OVER (
ORDER BY date_reference
ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
)
这就是所谓的框架——限制一个函数作用范围到特定边界的方法。默认情况下(在大多数情况下),窗口函数会考虑以下框架:
ROWS BETWEEN UNBOUDED PRECEDING AND CURRENT ROW
-- ALL THE PREVIOUS ROWS + THE CURRENT ROW

图片 6:探索窗口大小定义
我希望这段介绍能帮助你更好地理解什么是窗口函数,它们是如何工作的,以及它们在实际中的语法。自然,窗口函数可以添加更多的关键字,但我认为这段介绍已经涵盖了你在日常工作中很可能使用的许多命令。接下来,让我们看一些我在日常工作中用来解决问题的有趣实际应用 —— 有些是非常有趣的!
窗口函数的好奇而有趣的使用案例
随时间的累积和
这是使用窗口函数的经典案例之一。
假设有一张表,记录了你每个月的工资,并且你想知道每个月的累计收入(考虑到所有之前的月份),它是这样工作的:

图 7. 实际示例 — 累积和
非常简单,对吧?
在这个查询中,值得注意的一个有趣点是,SUM() 函数会考虑当前行和所有之前的行来计算聚合,就像之前提到的那样。
日志表中事件的持续时间
我最近在我的文章我在 DuckDB 中的第一亿行(数据)中使用了这个方法,在文章中我处理了来自巴西电子投票机的日志数据,如果你对处理大数据量感兴趣,这篇文章值得一读。
总结来说,假设有一个日志表,每个事件都由一个时间戳组成,表示事件的开始时间、事件的名称以及一个唯一标识符。考虑到每个事件只有在前一个事件结束后才开始,我们可以轻松地添加一个表示事件持续时间的列,如下所示:

图 8. 实际示例 — 日志中的事件持续时间
填充缺失值(使用最后一次出现的值)
使用 pandas 进行机器学习时的经典方法!只需进行 fillna、bfill 或者其他方法,就可以填充空值,用最后一个有效值填补缺失值。
如何在 SQL 中做到这一点?很简单!

图 8. 实际示例 — 填充缺失值 I
当我们第一次学习机器学习时,我们经常使用 pandas,并习惯了它们的高级功能。然而,在实际项目中,数据量可能非常庞大,因此我们可能没有运气使用 pandas,而需要切换到 PySpark、Snowflake、Hive+hadoop 等工具 —— 这些工具都可以以某种方式用 SQL 进行操作。因此,我认为学习如何在 SQL 中进行这些数据处理和预处理是非常重要的。
填充缺失值(使用前面行的平均值)
填充空值的稍微复杂一点的方法,但依然简单!

图 8. 实际示例 — 填充缺失值 II
这个例子突出了尽管窗口函数看起来复杂且特殊,但它们可以像普通列一样使用!它们可以包含在 CASE 语句中,可以用它们进行计算等等。我知道的少数几个限制之一是,它们不能直接放在 WHERE 子句中:
SELECT * FROM
WHERE SUM() OVER() > 10 -- This is not possible in postgres
基于一组列的行去重
另一个窗口函数的经典例子!有时我们需要根据一组列来去重表中的行。
当然,在 SQL 中我们有 DISTINCT 子句,但它仅在整行重复时有效。如果一个表中有多行 ID 列相同但其他列的值不同的记录,可以通过以下逻辑来去重:

图 9. 实际例子 — 去重
SELECT *
FROM (
SELECT
ROW_NUMBER() OVER (PARTITION BY id) as row_number
)
WHERE row_number = 1
这个操作还允许数据版本控制!例如,如果我们在系统中每次用户更改姓名时保存一行新记录,并附上更改日期(而不是修改现有记录),我们就可以检索每个用户的当前姓名:
SELECT
*
FROM
(
SELECT
name,
row_number() OVER (PARTITION BY id ORDER BY DATE DESC) AS row_number
FROM myTable
) AS subquery
WHERE row_number = 1
一个组/类别在所有行中的出现百分比
假设有一个列出了各种宠物的表,这些宠物可以是狗、猫或鸟。我们需要为每一行添加一列,表示每种宠物类型占所有宠物总数的百分比。这个任务通过使用不仅仅一个,而是两个窗口函数来解决!

图 10. 实际例子 — 出现百分比
在上面的图像中,为了更具教育性,我添加了两列来表示每个窗口函数的结果,但实际上只创建了最右侧的一列。
那你呢?你有什么有趣的窗口函数案例想要分享吗?请在评论区留下!
结论
我不敢说 SQL 是复古或经典的,因为这些词虽然是褒义的,但指的是过去。对我来说,SQL 是现在的、普遍存在的,而且无疑是任何从事数据领域工作的人必须掌握的语言。
然而,有些问题仅用 SQL 本身解决可能显得复杂,这时,良好的语言知识和对其能力的了解就显得尤为重要。如果没有窗口函数,很多从 Pythonic 视角看待的常见问题将变得非常困难,甚至无法解决。但只要我们知道如何正确使用工具,就能创造奇迹!
希望这篇文章帮助你更好地理解窗口函数是如何工作的,以及它们可以在实际中解决哪些类型的问题。这里展示的所有材料主要基于 PostgreSQL 语法,可能在其他数据库中不一定能立即生效,但最重要的是理解它背后的逻辑。像往常一样,我不是该领域的专家,强烈建议对这个主题感兴趣的朋友深入阅读并多加练习。
感谢阅读!😉
参考文献
所有代码均可在这个 GitHub 仓库中找到。
对这类作品感兴趣吗?请访问我的帖子仓库。
1 使用 PostgreSQL 窗口函数进行数据处理。(n.d.)。Timescale。Link。
[2] Kho, J. (2022 年 6 月 5 日). 高级 SQL 窗口函数简易指南 — 朝向数据科学。Medium。
[3] Markingmyname. (2023 年 11 月 16 日). 分析函数 (Transact-SQL) — SQL Server。Microsoft Learn。
[4] PostgreSQL 教程。(2021 年 4 月 27 日). PostgreSQL 窗口函数:终极指南。Link。
[5] VanMSFT. (2023 年 5 月 23 日). OVER 子句 (Transact-SQL) — SQL Server。Microsoft Learn。
[6] 窗口函数。(n.d.)。SQLite 官方文档。
[7] 窗口函数。(2014 年 7 月 24 日)。PostgreSQL 文档。
本文中的所有图片均由作者制作。
Python 中的动态可视化
如何用 OpenCV 和 Matplotlib 制作动画图表
·发布于 Towards Data Science ·8 分钟阅读·2024 年 11 月 21 日
--

追踪球体轨迹并实时动画显示其垂直位置
在计算机视觉中,一个基本目标是从静态图像或视频序列中提取有意义的信息。为了理解这些信号,通常可视化它们是非常有帮助的。
例如,当在高速公路上追踪单个汽车时,我们可以在它们周围画上边界框;或者在生产线上的传送带上检测到问题时,我们可以使用不同的颜色来标记异常。但如果提取的信息是数值性质的,并且你想要可视化该信号的时间动态呢?
仅将值作为数字显示在屏幕上可能无法提供足够的洞察,尤其是当信号快速变化时。在这些情况下,可视化信号的一种非常好的方法是使用带有时间轴的图表。在这篇文章中,我将展示如何结合OpenCV和Matplotlib的强大功能,创建这种信号的动态实时可视化。
我在这个项目中使用的代码和视频可以在 GitHub 上找到:
[## GitHub - trflorian/ball-tracking-live-plot: 使用 OpenCV 追踪球体并绘制…
使用 OpenCV 追踪球体并用 Matplotlib 绘制轨迹 - trflorian/ball-tracking-live-plot
绘制球体轨迹
让我们来探索一个简单的问题,我录制了一个球垂直投向空中的视频。目标是跟踪视频中的球,并绘制其位置 p(t)、速度 v(t)和加速度 a(t)随时间的变化。

输入视频
我们将参考帧定义为摄像头,并为简单起见,我们仅追踪图像中球体的垂直位置。我们预计位置呈抛物线形状,速度呈线性减少,且加速度保持恒定。

我们预期的图形草图
球体分割
在第一步,我们需要识别视频序列中每一帧的球体。由于摄像头保持静止,检测球体的一个简单方法是使用背景减除模型,并结合颜色模型来去除帧中的手部。
首先,我们通过使用VideoCapture从OpenCV库播放视频片段,并用一个简单的循环显示视频。视频播放到结尾时,我们会重新启动视频。我们还确保以原始帧率回放视频,通过计算视频的 FPS 来确定sleep_time(以毫秒为单位)。最后,记得在结束时释放资源并关闭窗口。
import cv2
cap = cv2.VideoCapture("ball.mp4")
fps = int(cap.get(cv2.CAP_PROP_FPS))
while True:
ret, frame = cap.read()
if not ret:
cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
continue
cv2.imshow("Frame", frame)
sleep_time = 1000 // fps
key = cv2.waitKey(sleep_time) & 0xFF
if key & 0xFF == ord("q"):
break
cap.release()
cv2.destroyAllWindows()

输入视频的可视化
首先,我们要提取一个用于球体的二值分割掩模。这实际上意味着我们希望创建一个掩模,该掩模对于球体的像素是激活的,而对于其他所有像素是非激活的。为此,我将结合两个掩模:一个是运动掩模,另一个是颜色掩模。运动掩模提取运动部分,而颜色掩模主要去除帧中的手部。
对于颜色滤波器,我们可以将图像转换到HSV颜色空间,并选择一个特定的色调范围(20–100),该范围包含球体的绿色颜色,而不包含皮肤色调。我不会对饱和度或亮度值进行过滤,因此我们可以使用完整的范围(0–255)。
# filter based on color
hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
mask_color = cv2.inRange(hsv, (20, 0, 0), (100, 255, 255))
为了创建一个运动掩模,我们可以使用一个简单的背景减除模型。我们使用视频的第一帧作为背景,并将学习率设置为 1。在循环中,我们应用背景模型来获取前景掩模,但通过将学习率设置为 0,不将新帧集成到模型中。
...
# initialize background model
bg_sub = cv2.createBackgroundSubtractorMOG2(varThreshold=50, detectShadows=False)
ret, frame0 = cap.read()
if not ret:
print("Error: cannot read video file")
exit(1)
bg_sub.apply(frame0, learningRate=1.0)
while True:
...
# filter based on motion
mask_fg = bg_sub.apply(frame, learningRate=0)
在下一步中,我们可以将这两个掩模合并,并应用开运算形态学来去除小的噪声,最终得到一个完美的球体分割。
# combine both masks
mask = cv2.bitwise_and(mask_color, mask_fg)
mask = cv2.morphologyEx(
mask, cv2.MORPH_OPEN, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (13, 13))
)

左上角: 视频序列,右上角: 颜色掩模,左下角: 运动掩模,右下角: 组合掩模
跟踪球体
我们剩下的唯一部分是图像中的球体遮罩。为了跟踪球的中心,我首先提取球的轮廓,然后以其边界框的中心作为参考点。如果有噪声通过了我们的遮罩,我会根据大小过滤掉检测到的轮廓,只关注最大的一个。
# find largest contour corresponding to the ball we want to track
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
if len(contours) > 0:
largest_contour = max(contours, key=cv2.contourArea)
x, y, w, h = cv2.boundingRect(largest_contour)
center = (x + w // 2, y + h // 2)
我们还可以在帧中添加一些注释来可视化我们的检测。我将画两个圆,一个表示球的中心,另一个表示球的边缘。
cv2.circle(frame, center, 30, (255, 0, 0), 2)
cv2.circle(frame, center, 2, (255, 0, 0), 2)
为了跟踪球的位置,我们可以使用列表。每当我们检测到球时,我们只需将中心位置添加到列表中。我们还可以通过在跟踪位置列表中的每个片段之间绘制线条来可视化轨迹。
tracked_pos = []
while True:
...
if len(contours) > 0:
...
tracked_pos.append(center)
# draw trajectory
for i in range(1, len(tracked_pos)):
cv2.line(frame, tracked_pos[i - 1], tracked_pos[i], (255, 0, 0), 1)

球的轨迹可视化
创建图表
现在我们可以跟踪球的位置,接下来让我们探索如何使用matplotlib绘制信号。第一步,我们可以先在视频结束时创建最终的图表,然后在第二步再考虑如何实时动画化它。为了显示位置、速度和加速度,我们可以使用三个水平排列的子图:
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(10, 2), dpi=100)
axs[0].set_title("Position")
axs[0].set_ylim(0, 700)
axs[1].set_title("Velocity")
axs[1].set_ylim(-200, 200)
axs[2].set_title("Acceleration")
axs[2].set_ylim(-30, 10)
for ax in axs:
ax.set_xlim(0, 20)
ax.grid(True)
我们只对图像中的 y 位置(数组索引 1)感兴趣,若要得到一个零偏移位置图,我们可以减去第一个位置。
pos0 = tracked_pos[0][1]
pos = np.array([pos0 - pos[1] for pos in tracked_pos])
对于速度,我们可以使用位置差异作为近似值,对于加速度,我们可以使用速度差异。
vel = np.diff(pos)
acc = np.diff(vel)
现在我们可以绘制这三个值:
axs[0].plot(range(len(pos)), pos, c="b")
axs[1].plot(range(len(vel)), vel, c="b")
axs[2].plot(range(len(acc)), acc, c="b")
plt.show()

位置、速度和加速度的静态图
动画化图表
接下来是有趣的部分,我们希望使这个图表动态化!由于我们在 OpenCV 的 GUI 循环中工作,我们不能直接使用matplotlib的show函数,因为这会阻塞循环,导致程序无法运行。相反,我们需要利用一些技巧 ✨
主要思想是将图表绘制到内存中的缓冲区,然后在 OpenCV 窗口中显示这个缓冲区。通过手动调用画布的绘制函数,我们可以强制将图形渲染到缓冲区中。然后,我们可以获取这个缓冲区并将其转换为数组。由于缓冲区是RGB格式,但 OpenCV 使用的是BGR格式,因此我们需要转换颜色顺序。
fig.canvas.draw()
buf = fig.canvas.buffer_rgba()
plot = np.asarray(buf)
plot = cv2.cvtColor(plot, cv2.COLOR_RGB2BGR)
确保axs.plot调用现在位于帧循环内部:
while True:
...
axs[0].plot(range(len(pos)), pos, c="b")
axs[1].plot(range(len(vel)), vel, c="b")
axs[2].plot(range(len(acc)), acc, c="b")
...
现在我们可以使用 OpenCV 的imshow函数简单地显示图表。
cv2.imshow("Plot", plot)

动画图表
然后,瞧,您得到了动画图!然而,您会注意到性能相当低下。每帧重新绘制整个图形非常耗费资源。为了提高性能,我们需要使用blitting。这是一种高级渲染技术,它将图形的静态部分绘制到背景图像中,只重新绘制变化的前景元素。为了设置这一点,我们需要在帧循环之前,首先定义我们三个图形的引用。
pl_pos = axs[0].plot([], [], c="b")[0]
pl_vel = axs[1].plot([], [], c="b")[0]
pl_acc = axs[2].plot([], [], c="b")[0]
然后我们需要在循环之前绘制一次图形的背景,并获取每个坐标轴的背景。
fig.canvas.draw()
bg_axs = [fig.canvas.copy_from_bbox(ax.bbox) for ax in axs]
在循环中,我们现在可以改变每个图形的数据,然后对于每个子图,我们需要恢复区域的背景,绘制新的图形,并调用blit函数来应用更改。
# Update plot data
pl_pos.set_data(range(len(pos)), pos)
pl_vel.set_data(range(len(vel)), vel)
pl_acc.set_data(range(len(acc)), acc)
# Blit Pos
fig.canvas.restore_region(bg_axs[0])
axs[0].draw_artist(pl_pos)
fig.canvas.blit(axs[0].bbox)
# Blit Vel
fig.canvas.restore_region(bg_axs[1])
axs[1].draw_artist(pl_vel)
fig.canvas.blit(axs[1].bbox)
# Blit Acc
fig.canvas.restore_region(bg_axs[2])
axs[2].draw_artist(pl_acc)
fig.canvas.blit(axs[2].bbox)
就这样,绘图速度加快,性能得到了显著提升。

优化后的图形
结论
在这篇文章中,您学会了如何应用简单的计算机视觉技术提取一个移动的前景物体并追踪它的轨迹。接着,我们使用matplotlib和OpenCV创建了一个动画图。这个绘图演示是通过一个玩具示例视频实现的,视频中一个球被垂直抛向空中。然而,这个项目中使用的工具和技术对于各种任务和实际应用都非常有用!完整的源代码可以在我的 GitHub 上找到。希望今天您有所收获,祝编码愉快,保重!
[## GitHub - trflorian/ball-tracking-live-plot: 使用 OpenCV 追踪一个球并绘制轨迹…
使用 OpenCV 追踪一个球并使用 Matplotlib 绘制轨迹 - trflorian/ball-tracking-live-plot
本文中的所有可视化图表均由作者创建。
再次登顶珠穆朗玛
如何在人工智能的难题中取得进展
·发表于 Towards Data Science ·阅读时间 8 分钟·2024 年 11 月 22 日
--

珠穆朗玛峰,当地人称之为萨加玛塔或珠穆朗玛(维基百科)
新技术诞生、成熟,最终被取代。人工智能(AI)也不例外,将遵循这一发展曲线。许多新闻文章已经宣称生成性人工智能(Gen AI)已经进入了“幻灭低谷”:即采用过程中的阶段,早期采用者开始意识到,新技术所承诺的成果比他们想象的要难以实现得多。

Gartner 炒作周期 维基百科
这是正常现象,在生成性人工智能之前也曾多次发生过。比如区块链的兴起与衰退——你在商店买的生菜将通过区块链从农场追踪到餐桌!又比如大数据:你将能够了解客户的一切,以最小的努力为他们提供价值,同时为自己带来利润!
问题在于,这些新技术所解决的每个问题实际上都是非常庞大的。每个问题都是一座独立的“珠穆朗玛峰”。
就像珠穆朗玛峰一样,你不能在一天内登顶。需要几个月甚至几年的准备。途中每个营地都为该位置量身定制。有时,即使是准备最充分的尝试也未必能成功登顶——这并不意味着登山队不具备资格或能力:可能是天气不好,或者他们走错了路线。

1963 年美国珠穆朗玛峰远征(维基百科)
你的生成式人工智能策略应该与爬珠穆朗玛峰的策略相同(不过可以暂时不使用额外的氧气)。
生成式人工智能正在解决的每个问题通常都是一个大难题——输入复杂、输出复杂,并且有复杂的过程将两者连接起来。
记住:在爬山时,大幅度的跃进是危险的。进步实际上是通过沿着一条路径的小步伐逐渐取得的。

不要跳跃——使用梯子(维基百科)
每一步登顶前,都需要在山坡上收集和组织所需的材料。你不想在珠穆朗玛峰的一半高度时没有食物或水。
类似地,你需要训练自己和你的团队在艰险条件下具备在更高海拔上执行任务的体能。
理解正在解决的问题
这不应该意味着“今天的解决方案是什么样的”。现代化努力通常需要替换现有的依赖于变通和妥协的解决方案。理解实际问题至关重要。这个过程的结果到底是从哪里产生价值的?它是如何改善客户体验的?明确界定问题有助于后续定义明确的需求。
关键是要记住,人类在处理模糊的要求时非常擅长。因此,许多人工智能正在解决的“大难题”通常都是这样描述的:
“我们希望使用人工智能来自动化处理我们用于处理所有大客户订单的复杂订单系统!”
听起来很棒!你能描述一下从头到尾的整个过程是如何运作的吗?
“嗯,我们从客户那里收到电子邮件,提取订单信息,然后将这些信息放入我们的订单表单。然后我们将表单上传到订单系统中进行处理。生成式人工智能可以自动化整个过程,对吧?”
如果我们一步一步地构建,当然可以!
上述过程包含了大量的模糊性。期望生成式人工智能能够轻松处理上述过程中的每一个细节,是一个错误。
-
电子邮件的格式是什么?总是相同的格式吗?
-
客户是如何描述订单信息的?他们使用的是口语化的术语吗?还是使用了你的物品编号?
-
客户的订单信息与您的履单系统使用的是否相同?是否存在查找过程?
-
上传期待的格式是什么?文本?PDF?Excel?
-
如果是 Excel 模板,是否有多个工作表?不可写入的单元格?数据验证要求?
生成式 AI 能够 处理所有这些任务 —— 你只需要能够清晰地定义每一步骤。如果你不能清楚地描述一个流程的输入和输出,生成式 AI 很可能不会按照你预期的方式执行。
如果你从自上而下的角度来处理(提示是“你是一个填写订单表单的 AI 代理”),你最终会得到一个大约 50% 准确的过程(老实说,这已经挺不错了!),但结果格式可能与你预期的不一致。问题是,你仍然需要人工检查每一个输出,这样就相当于加倍了工作量。
最小可行产品(MVP):我们不是已经在这里走过一遍了吗?
这并不是什么新鲜事。我们已经做了多年的最小可行产品(MVP)了。你必须从小做起,解决问题中的单一步骤,并在此基础上逐步扩展(通过客户反馈!)。AI 产品和工作流也没有什么不同。先构建立即有用的部分,然后从那里扩展。
我们如何将上述内容应用到订单系统中呢?我们应该将流程中的每个步骤拆解,并在最合适的地方应用生成式 AI:
-
客户发送订单邮件(非结构化输入)
-
订单详情被填写到表单中(结构化输入)
-
表单已格式化并上传到系统中(结构化输出)或者:
-
没有表单,订单手动构建(非结构化输出)
邮件内容通常是非结构化的,这使得在此处应用 AI 成为一个非常好的用例!在这种情况下,询问你的流程负责人:“一个有效的邮件订单应该包含哪些内容?” 客户姓名、账户号码、地址、所请求的商品及商品数量等数据是很好的候选项。为了在处理这些订单时最大化生成式 AI 系统的准确性和韧性,定义 AI 应遵循的数据结构。我将 使用 [pydantic](https://docs.pydantic.dev/latest/) 帮助构建这些结构 如下:
from pydantic import BaseModel
class OrderItem(BaseModel):
ItemName: str
ItemQuantity: int
class EmailOrder(BaseModel):
CustomerName: str
AccountNumber: str
ShippingAddress: str
Items: list[OrderItem]
从这里,我们可以利用这些对象来为我们的 AI 提供结构:
>>> i = OrderItem(ItemName='eggs', ItemQuantity=2)
>>> i
OrderItem(ItemName='eggs', ItemQuantity=2)
>>> i.model_dump_json()
'{"ItemName":"eggs","ItemQuantity":2}'
>>> e = EmailOrder(CustomerName="James", AccountNumber="1234", ShippingAddress="1234 Bayberry Ln", Items=[i])
>>> e.model_dump_json()
'{"CustomerName":"James","AccountNumber":"1234","ShippingAddress":"1234 Bayberry Ln","Items":[{"ItemName":"eggs","ItemQuantity":2}]}'
现在,通过这些示例,你可以使用少量示例提示(few-shot prompting)来提供给你的生成式 AI,并提高准确性。我们将使用 LangChain OutputParsers 来做一些繁重的工作:
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_openai import OpenAI
llm = OpenAI(model="gpt-3.5-turbo-instruct")
template = """
{format_instructions}
<email>
{email_body}
</email>
Instructions:
- Read the email and extract the information in it.
- Respond in the format instructions given above.
Begin!
"""
parser = JsonOutputParser(pydantic_object=EmailOrder)
prompt = PromptTemplate(
template=template,
input_variables=["email_body"],
partial_variables={
"format_instructions": parser.get_format_instructions
},
)
chain = prompt | llm | parser
email_body = "hello i'd like to order 2 eggs. My name is James. My account number is 1234\. My address is 1234 Bayberry Ln. Appreciate it!"
chain.invoke({"email_body": email_body})
实际上发送给 OpenAI 的提示是:
prompt = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
As an example, for the schema {"properties": {"foo": {"title": "Foo", "description": "a list of strings", "type": "array", "items": {"type": "string"}}}, "required": ["foo"]}
the object {"foo": ["bar", "baz"]} is a well-formatted instance of the schema. The object {"properties": {"foo": ["bar", "baz"]}} is not well-formatted.
Here is the output schema:
```{"$defs": {"OrderItem": {"properties": {"ItemName": {"title": "Itemname", "type": "string"}, "ItemQuantity": {"title": "Itemquantity", "type": "integer"}}, "required": ["ItemName", "ItemQuantity"], "title": "OrderItem", "type": "object"}}, "properties": {"CustomerName": {"title": "Customername", "type": "string"}, "AccountNumber": {"title": "Accountnumber", "type": "string"}, "ShippingAddress": {"title": "Shippingaddress", "type": "string"}, "Items": {"items": {"$ref": "#/$defs/OrderItem"}, "title": "Items", "type": "array"}}, "required": ["CustomerName", "AccountNumber", "ShippingAddress", "Items"]}```py
<email>
"hello i'd like to order 2 eggs. My name is James. My account number is 1234\. My address is 1234 Bayberry Ln. Appreciate it!"
</email>
Instructions:
- Read the email and extract the information in it.
- Respond in the format instructions given above.
Begin!"""
当你发送这个提示时,LLM 会按照示例提取信息:
{
"CustomerName": "James",
"AccountNumber": "1234",
"ShippingAddress": "1234 Bayberry Ln",
"Items": [
{
"ItemName": "eggs",
"ItemQuantity": 2
}
]
}
通过使用这个定义明确的格式来处理电子邮件订单,我们可以将这个解析后的对象重新传回 LLM,并要求它确保订单中的所有必需字段都存在。如果没有,我们可以将电子邮件发送给人工求助!
例如,假设所有的 EmailOrders 也需要一个 CompanyName 字段。如果验证规则这么简单,我们可以直接使用pydantic进行验证(无需 AI!)。如果你的使用场景变得更加复杂,输出结果可以通过 LLM 进行处理,以提供更高层次的逻辑。
我们将采取与上述相同的订单,但省略 CompanyName 字段:
>>> class EmailOrder(BaseModel):
... CustomerName: str
... AccountNumber: str
... ShippingAddress: str
... Items: list[OrderItem]
... CompanyName: str
...
>>> e = EmailOrder(CustomerName="James", AccountNumber="1234", ShippingAddress="1234 Bayberry Ln", Items=[i])
Traceback (most recent call last):
File "<python-input-19>", line 1, in <module>
e = EmailOrder(CustomerName="James", AccountNumber="1234", ShippingAddress="1234 Bayberry Ln", Items=[i])
File "/Users/jbarney/.venv/lib/python3.13/site-packages/pydantic/main.py", line 212, in __init__
validated_self = self.__pydantic_validator__.validate_python(data, self_instance=self)
pydantic_core._pydantic_core.ValidationError: 1 validation error for EmailOrder
CompanyName
Field required [type=missing, input_value={'CustomerName': 'James',...ello', ItemQuantity=2)]}, input_type=dict]
Pydantic 在这里为我们做了很多工作,抛出了一个ValidationError。我们的驱动程序可以简单地捕获这个错误,并将电子邮件发送给人工审查员。
当然,一个 LLM 也可以检测到这个错误。我展示这个是为了完整性;通常你会想利用传统编程来进行数据验证:
prompt = """Evaluate that the input object matches the expected schema:
{input}
{schema}
Reply with "True" if it does match and "False" if it does not match.
"""
有了这些,我们现在拥有一个能够轻松处理正确书写的电子邮件订单的系统。更重要的是,我们实现了一个自我管理的过程,当 AI 需要帮助时,会让人工参与。
关键是,我们并没有重写整个订单录入过程!我们已经将这个耗时的过程部分化,并建立了一个集中人类努力的系统,将精力集中在那些能够产生最大差异的领域。未来,我们可以开始修改过程中的其他部分,系统地去除人工劳动力。
登顶珠穆朗玛峰
这种解决复杂问题的迭代方法并不新鲜。所有大问题都需要拆解成其组成部分,才能真正得到解决。
然而,AI 的“魔力”尤其令人信服。鉴于这些模型只需几行输入就能表现得如此强大,很容易希望实现大的突破。与区块链和大数据等技术相比,从创意到令人兴奋的概念验证所需的努力最小。AI 不需要数十台定制配置的服务器来执行一个跨越 18 TB 数据的 Map-Reduce 任务,这个任务花了你 6 个月时间进行迁移。
因此,在构建下一个 AI 解决方案时,请记住这种简单性:一步一步走向顶峰。
在那里见!
蚁群优化——直觉、代码与可视化
它与其他群体算法的不同之处
·发布于Towards Data Science ·10 分钟阅读·2024 年 1 月 21 日
--

该图像由 DALL·E 3 根据提示“在自然森林环境中画出未来的军蚁”创建。
本文是我系列文章的延续,系列灵感来自自然。
之前,我讨论过进化算法(EA)、粒子群优化(PSO)以及人工蜂群算法(ABC)。大自然无处不在,显然人类可以从大自然中学到更多并加以应用。
今天,我们关注蚂蚁。
当我们还是孩子时,我们学到蚂蚁是勤劳且合作的。我们的父母没有教我们的是,蚂蚁集体形成了一个高度复杂的群体,能够有效地相互沟通。
对蚂蚁或信息素(或任何化学物质的扩散)的了解在这里完全不需要。这些仅仅是为了包装而使用的名称。我之前已经展示过,你不需要了解蜜蜂的舞蹈来欣赏或利用 ABC,也不需要学习基因、突变或繁殖机制才能应用进化算法(EA)。
你只需要理解英语,就能掌握直觉,此外还需具备非常基础的数学和 Python 编程技能。虽然我会展示一些数学内容以保证完整性,包括希腊字母符号,但这些内容实际上只是为了补充完整。它……
Apache Beam:数据处理、数据管道、Dataflow 和 Flex 模板
在这篇文章中,我们将探索 Apache Beam,从一个简单的管道到更复杂的管道,使用 GCP Dataflow。让我们学习什么是 PTransform、PCollection、GroupByKey 以及 Dataflow Flex 模板。
·发表于 Towards Data Science ·阅读时间:19 分钟 ·2024 年 2 月 12 日
--

图片由 Faruk Kaymak 提供,来源于 Unsplash
Apache Beam 简介
毫无疑问,数据处理、特征创建、数据传输以及在安全环境中以稳定且高效的计算方式进行所有这些操作,对于当今所有的 AI 任务至关重要。早期,Google 开始开发一个开源项目,旨在同时进行 批处理 和 流处理 数据操作,命名为 Beam。随后,Apache 软件基金会开始为这个项目贡献代码,推动 Apache Beam 的规模化发展。
Apache Beam 的关键特点是其灵活性,使其成为构建数据处理管道的最佳编程 SDK 之一。我认为可以识别出 Apache Beam 中的 4 个主要概念,它们使其成为一个不可或缺的数据工具:
- 批处理/流处理的统一模型:Beam 是一个统一的编程模型,使用相同的 Beam 代码,你可以决定是进行…
Apache Hadoop 和 Apache Spark 用于大数据分析
一份完整的指南,介绍如何使用 Apache Hadoop(HDFS)和 Python 中的 PySpark 库进行大数据分析,以分析 Steam 游戏平台上的游戏评论。
·发布于 Towards Data Science ·14 分钟阅读·2024 年 5 月 8 日
--
全球每年生产超过 100 个泽字节(= 10¹²GB)的数据,因此处理大数据的能力是今天最为关键的技能之一。数据分析本身可以定义为处理大数据并从无止境、指数增长的数据中提取洞察力的能力。Apache Hadoop 和 Apache Spark 是帮助我们解开大型数据集中无限可能性的两个基本工具。Apache Hadoop 通过其分布式文件系统(HDFS)和基于 MapReduce 的数据并行处理,帮助我们简化数据存储和分布式计算。Apache Spark 是一个大数据分析引擎,能够进行 EDA、SQL 分析、流处理、机器学习和图处理,并通过其 API 与主要编程语言兼容。两者结合起来,构成了一个出色的大数据处理环境,并且在大多数情况下,只需要一台个人电脑即可完成!
让我们通过一个简单的分析项目,利用 Apache Spark 在 Python 中实现,来展开大数据和 Apache Hadoop的强大功能。
首先,让我们深入了解如何在 MacOS 上安装 Hadoop 分布式文件系统和 Apache Spark。我使用的是一台配备 M1 芯片的 MacBook Air,操作系统为 macOS Sonoma。
跳转到章节 —
-
安装 Hadoop 分布式文件系统
-
安装 Apache Spark
-
使用 PySpark 进行 Steam 评论分析
-
接下来做什么?
1. 安装 Hadoop 分布式文件系统
感谢 Code With Arjun 分享的精彩文章,帮助我完成了在我的 Mac 上的 Hadoop 安装。我按照他的步骤顺利安装并运行了 Hadoop,下面我也会展示这些步骤。
- a. 安装 HomeBrew
我使用 Homebrew 来在我的 Mac 上安装应用程序,便于操作。可以通过以下代码直接在系统上安装它 —
/bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
安装完成后,你可以运行以下简单代码来验证安装。
brew --version

图 1:作者提供的图片
然而,你可能会遇到一个错误,提示command not found,这是因为 Homebrew 会被安装在不同的位置(如图 2 所示),并且无法从当前目录执行。为了使其正常工作,我们需要为 brew 添加一个路径环境变量,即将 Homebrew 添加到 .bash_profile 中。

图 2:作者提供的图片
你可以通过在命令中使用 Homebrew 的完整路径来避免此步骤,但在后期阶段可能会变得麻烦,所以不推荐这么做!
echo ‘eval “$(/opt/homebrew/bin/brew shellenv)”’ >> /Users/rindhujajohnson/.bash_profile
eval “$(/opt/homebrew/bin/brew shellenv)”
现在,当你尝试运行 brew --version 时,它应该能够正确显示 Homebrew 版本。
- b. 安装 Hadoop
声明!Hadoop 是一个基于 Java 的应用程序,支持 11 版本以下的 Java 开发工具包(JDK),最好是 8 或 11。请在继续之前安装 JDK。
再次感谢 Code With Arjun,感谢他分享的关于在 MacBook M1 上安装 JDK 的视频。
安装 JDK 指南
现在,我们使用 brew 命令在系统上安装 Hadoop。
brew install hadoop
这个命令应该能够无缝地安装 Hadoop。与安装 HomeBrew 时遵循的步骤类似,我们需要在 Hadoop 文件夹中编辑 Java 的路径环境变量。安装版本的 Hadoop 的环境变量设置可以在 HomeBrew 中的 Hadoop 文件夹中找到。你可以使用which hadoop命令来查找 Hadoop 安装文件夹的位置。一旦找到文件夹,你可以在以下位置找到变量设置。以下命令将带你到所需的文件夹以编辑变量设置(请检查你安装的 Hadoop 版本,以避免出错)。
cd /opt/homebrew/Cellar/hadoop/3.3.6/libexec/etc/hadoop
你可以使用 ls 命令查看该文件夹中的文件。我们将编辑 hadoop-env.sh 以确保 Hadoop 在系统上正确运行。

图 3:作者提供的图片
现在,我们需要找到 Java 的路径变量,以便使用以下命令编辑 hadoop-ev.sh 文件。
/usr/libexec/java_home

图 4:图片来自作者
我们可以在任何文本编辑器中打开 hadoop-env.sh 文件。我使用了 VI 编辑器,您可以使用任何编辑器。我们可以将路径 Library/Java/JavaVirtualMachines/adoptopenjdk-11.jdk/Contents/Home 复制并粘贴到 export JAVA_HOME = 位置。

图 5:在 VI 文本编辑器中打开的 hadoop-env.sh 文件
接下来,我们编辑 Hadoop 文件夹中的四个 XML 文件。
core-site.xml
<configuration>
<property>
<name>fs.defaultFS</name>
<value>hdfs://localhost:9000</value>
</property>
</configuration>
hdfs-site.xml
<configuration>
<property>
<name>dfs.replication</name>
<value>1</value>
</property>
</configuration>
mapred-site.xml
<configuration>
<property>
<name>mapreduce.framework.name</name>
<value>yarn</value>
</property>
<property>
<name>mapreduce.application.classpath</name>
<value>
$HADOOP_MAPRED_HOME/share/hadoop/mapreduce/*:$HADOOP_MAPRED_HOME/share/hadoop/mapreduce/lib/*
</value>
</property>
</configuration>
yarn-site.xml
<configuration>
<property>
<name>yarn.nodemanager.aux-services</name>
<value>mapreduce_shuffle</value>
</property>
<property>
<name>yarn.nodemanager.env-whitelist</name>
<value>
JAVA_HOME,HADOOP_COMMON_HOME,HADOOP_HDFS_HOME,HADOOP_CONF_DIR,CLASSPATH_PREPEND_DISTCACHE,HADOOP_YARN_HOME,HADOOP_MAPRED_HOME
</value>
</property>
</configuration>
至此,我们已经成功完成了在本地安装和配置 HDFS。为了使 Hadoop 上的数据可以通过远程登录访问,我们可以在常规设置中的共享部分启用 远程登录。您可以通过点击信息图标编辑用户访问权限。

图 6:启用远程访问。图片来自作者
让我们运行 Hadoop!
执行以下命令
hadoop namenode -format
# starts the Hadoop environment
% start-all.sh
# Gathers all the nodes functioning to ensure that the installation was successful
% jps

图 7:启动 Hadoop 并查看运行中的节点和资源。图片来自作者
一切就绪!现在让我们在 HDFS 中创建一个目录并添加我们将要处理的数据。让我们快速查看我们的数据源及其详细信息。
数据
Steam Reviews Dataset 2021 (许可协议:GPL 2) 是一个包含大约 2100 万玩家评论的数据集,涵盖了 2021 年超过 300 款不同的游戏。数据是通过 Steam 的 API — Steamworks — 使用“获取列表”功能提取的。
GET store.steampowered.com/appreviews/<appid>?json=1
数据集由 23 列和 2170 万行组成,大小为 8.17 GB(这很大!)。数据包含不同语言的评论和一个布尔列,指示玩家是否推荐该游戏给其他玩家。我们将重点讨论如何使用 HDFS 本地处理这些大数据,并使用 PySpark 库在 Python 中通过 Apache Spark 进行分析。
- c. 上传数据到 HDFS
首先,我们使用 mkdir 命令在 HDFS 中创建一个目录。如果我们尝试将文件直接添加到一个不存在的文件夹中,它将抛出一个错误。
hadoop fs -mkdir /user
hadoop fs -mkdir /user/steam_analysis
现在,我们将使用 put 命令将数据文件添加到文件夹 steam_analysis 中。
hadoop fs -put /Users/rindhujajohnson/local_file_path/steam_reviews.csv /user/steam_analysis
Apache Hadoop 还提供一个用户界面,可以通过 localhost:9870/ 访问。

图 8:localhost:9870 上的 HDFS 用户界面。图片来自作者
我们可以看到上传的文件,如下所示。

图 10:在 HDFS 中浏览文件。图片来自作者
一旦数据交互完成,我们可以使用stop-all.sh命令停止所有 Apache Hadoop 守护进程。
让我们进入下一步 — 安装 Apache Spark
2. 安装 Apache Spark
Apache Hadoop 负责数据存储(HDFS)和数据的并行处理(MapReduce),以加速执行。Apache Spark是一个多语言兼容的分析引擎,旨在处理大数据分析。我们将在 Jupyter IDE 中使用 Python 运行 Apache Spark。
在安装并运行 HDFS 之后,安装 Apache Spark for Python 轻松得多。PySpark 是 Apache Spark 的 Python API,可以通过在 Jupyter Notebook 中使用pip方法进行安装。PySpark 是 Spark Core API,包含四个组件——Spark SQL、Spark ML 库、Spark Streaming 和 GraphX。此外,我们可以通过初始化安装并指定所需的 Hadoop 版本,使用 PySpark 访问 Hadoop 文件。
# By default, the Hadoop version considered will be 3 here.
PYSPARK_HADOOP_VERSION=3 pip install pyspark
让我们开始大数据分析吧!
3. 使用 PySpark 进行 Steam 评论分析
Steam是一个在线游戏平台,全球拥有超过 100 百万玩家,平台上托管着超过 30,000 款游戏。除了游戏,平台还允许玩家为他们玩的游戏提供评论,这为平台改进客户体验和游戏公司保持玩家活跃提供了重要资源。我们使用了平台上公开提供的Kaggle上的评论数据。
3. a. 从 HDFS 提取数据
我们将使用 PySpark 库来访问、清理和分析数据。首先,我们通过本地主机地址将 PySpark 会话连接到 Hadoop。
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
# Initializing the Spark Session
spark = SparkSession.builder.appName("SteamReviewAnalysis").master("yarn").getOrCreate()
# Providing the url for accessing the HDFS
data = "hdfs://localhost:9000/user/steam_analysis/steam_reviews.csv"
# Extracting the CSV data in the form of a Schema
data_csv = spark.read.csv(data, inferSchema = True, header = True)
# Visualize the structure of the Schema
data_csv.printSchema()
# Counting the number of rows in the dataset
data_csv.count() # 40,848,659
3. b. 数据清理与预处理
我们可以先看看数据集。与 Pandas 中的 pandas.head()函数类似,PySpark 提供了 SparkSession.show()函数,能够展示数据集的一部分。
在此之前,我们将删除数据集中的评论列,因为我们不打算对数据集进行任何自然语言处理(NLP)。此外,评论使用了不同的语言,这使得基于评论进行情感分析变得困难。
# Dropping the review column and saving the data into a new variable
data = data_csv.drop("review")
# Displaying the data
data.show()

图 11:模式的结构
我们有一个庞大的数据集,包含 23 个属性,其中有不同属性的 NULL 值,这使得考虑进行插补变得不合适。因此,我已删除了包含 NULL 值的记录。然而,这不是推荐的方法。您可以评估可用属性的重要性,删除无关的属性,然后尝试对 NULL 值进行插补。
# Drops all the records with NULL values
data = data.na.drop(how = "any")
# Count the number of records in the remaining dataset
data.count() # 16,876,852
数据集中仍然有接近 1700 万条记录!
现在,我们集中关注数据集中的变量名,如图 11 所示。我们可以看到某些属性包含像点(.)这样的字符,这些字符不符合 Python 标识符的命名规则。同时,我们还需要更改日期和时间属性的数据类型。因此,我们使用以下代码进行更改——
from pyspark.sql.types import *
from pyspark.sql.functions import from_unixtime
# Changing the data type of each columns into appropriate types
data = data.withColumn("app_id",data["app_id"].cast(IntegerType())).\
withColumn("author_steamid", data["author_steamid"].cast(LongType())).\
withColumn("recommended", data["recommended"].cast(BooleanType())).\
withColumn("steam_purchase", data["steam_purchase"].cast(BooleanType())).\
withColumn("author_num_games_owned", data["author_num_games_owned"].cast(IntegerType())).\
withColumn("author_num_reviews", data["author_num_reviews"].cast(IntegerType())).\
withColumn("author_playtime_forever", data["author_playtime_forever"].cast(FloatType())).\
withColumn("author_playtime_at_review", data["author_playtime_at_review"].cast(FloatType()))
# Converting the time columns into timestamp data type
data = data.withColumn("timestamp_created", from_unixtime("timestamp_created").cast("timestamp")).\
withColumn("author_last_played", from_unixtime(data["author_last_played"]).cast(TimestampType())).\
withColumn("timestamp_updated", from_unixtime(data["timestamp_updated"]).cast(TimestampType()))

图 12:Steam 评论分析数据集的简要概览。图像来自作者
数据集已经清理完毕,准备好进行分析!
3. c. 探索性数据分析
数据集包含超过 20 个变量,信息丰富。我们可以从不同的角度分析数据。因此,我们将把数据拆分成不同的 PySpark 数据框,并进行缓存,以加快分析速度。
# Grouping the columns for each analysis
col_demo = ["app_id", "app_name", "review_id", "language", "author_steamid", "timestamp_created" ,"author_playtime_forever","recommended"]
col_author = ["steam_purchase", 'author_steamid', "author_num_games_owned", "author_num_reviews", "author_playtime_forever", "author_playtime_at_review", "author_last_played","recommended"]
col_time = [ "app_id", "app_name", "timestamp_created", "timestamp_updated", 'author_playtime_at_review', "recommended"]
col_rev = [ "app_id", "app_name", "language", "recommended"]
col_rec = ["app_id", "app_name", "recommended"]
# Creating new pyspark data frames using the grouped columns
data_demo = data.select(*col_demo)
data_author = data.select(*col_author)
data_time = data.select(*col_time)
data_rev = data.select(*col_rev)
data_rec = data.select(*col_rec)
i. 游戏分析
在这一部分,我们将尝试了解不同游戏的评论和推荐模式。我们将把评论数量视为游戏的受欢迎程度,而推荐数量 True 则代表玩家对该游戏的偏好。
- 找到最受欢迎的游戏
# the data frame is grouped by the game and the number of occurrences are counted
app_names = data_rec.groupBy("app_name").count()
# the data frame is ordered depending on the count for the highest 20 games
app_names_count = app_names.orderBy(app_names["count"].desc()).limit(20)
# a pandas data frame is created for plotting
app_counts = app_names_count.toPandas()
# A pie chart is created
fig = plt.figure(figsize = (10,5))
colors = sns.color_palette("muted")
explode = (0.1,0.075,0.05,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0)
plt.pie(x = app_counts["count"], labels = app_counts["app_name"], colors = colors, explode = explode, shadow = True)
plt.title("The Most Popular Games")
plt.show()
- 找到推荐最多的游戏
# Pick the 20 highest recommended games and convert it in to pandas data frame
true_counts = data_rec.filter(data_rec["recommended"] == "true").groupBy("app_name").count()
recommended = true_counts.orderBy(true_counts["count"].desc()).limit(20)
recommended_apps = recommended.toPandas()
# Pick the games such that both they are in both the popular and highly recommended list
true_apps = list(recommended_apps["app_name"])
true_app_counts = data_rec.filter(data_rec["app_name"].isin(true_apps)).groupBy("app_name").count()
true_app_counts = true_app_counts.orderBy(true_app_counts["count"].desc())
true_app_counts = true_app_counts.toPandas()
# Evaluate the percent of true recommendations for the top games and sort them
true_perc = []
for i in range(0,20,1):
percent = (true_app_counts["count"][i]-recommended_apps["count"][i])/true_app_counts["count"][i]*100
true_perc.append(percent)
recommended_apps["recommend_perc"] = true_perc
recommended_apps = recommended_apps.sort_values(by = "recommend_perc", ascending = False)
# Built a pie chart to visualize
fig = plt.figure(figsize = (10,5))
colors = sns.color_palette("muted")
explode = (0.1,0.075,0.05,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0)
plt.pie(x = recommended_apps["recommend_perc"], labels = recommended_apps["app_name"], colors = colors, explode = explode, shadow = True)
plt.title("The Most Recommended Games")
plt.show()


图 13:显示了受欢迎和推荐游戏的饼图。图片来源:作者
洞察
-
《绝地求生》(PUBG)是 2021 年最受欢迎和最推荐的游戏。
-
然而,在这两个类别中,第二的位置分别由《侠盗猎车手 V》(GTA V)和《星露谷物语》占据。这表明,受欢迎并不意味着所有玩家都会向其他玩家推荐该游戏。
-
其他游戏也观察到了相同的模式。然而,游戏的评论数量显著影响这一趋势。
ii. 人口统计分析
我们将使用data_demo数据框来找到玩家的地区,尤其是玩家的所在位置。这项分析将帮助我们了解受欢迎游戏的评论语言和评论者使用的语言。我们可以利用这一趋势来确定玩家的地域影响力和情感,以推荐未来的新游戏。
- 找到最受欢迎的评论语言
# We standardize the language names in the language column, then group them,
# Count by the groups and convert into pandas df after sorting them the count
author_lang = data_demo.select(lower("language").alias("language"))
\.groupBy("language").count().orderBy(col("count").desc()).
\limit(20).toPandas()
# Plotting a bar graph
fig = plt.figure(figsize = (10,5))
plt.bar(author_lang["language"], author_lang["count"])
plt.xticks(rotation = 90)
plt.xlabel("Popular Languages")
plt.ylabel("Number of Reviews (in Millions)")
plt.show()
- 找到受欢迎游戏的评论语言
# We group the data frame based on the game and language and count each occurrence
data_demo_new = data_demo.select(lower("language").
\alias("language"), "app_name")
games_lang = data_demo_new.groupBy("app_name","language").count().orderBy(col("count").desc()).limit(100).toPandas()
# Plot a stacked bar graph to visualize
grouped_games_lang = games_lang_df.pivot(index='app_name', columns='language', values='count')
grouped_games_lang.plot(kind='bar', stacked=True, figsize=(12, 6))
plt.title('Count of Different App Names and Languages')
plt.xlabel('App Name')
plt.ylabel('Count')
plt.show()


图 14:语言流行度;受欢迎游戏中的语言流行度。图片来源:作者
洞察
-
英语是评论者使用的最流行语言,其次是简体中文和俄语。
-
简体中文是最受欢迎的游戏(PUBG)中使用最广泛的语言,而英语则是第二受欢迎的游戏(GTA V)以及几乎所有其他游戏中使用最广泛的语言。
-
游戏的受欢迎程度似乎与其来源地有关。PUBG是由一家韩国游戏公司开发的,我们观察到它的评论中有韩语,这是使用频率较高的语言之一。
对该数据也进行了时间、作者和评论分析,但并未提供任何可操作的洞察。欢迎访问GitHub 仓库,查看完整项目文档。
3. d. 使用 Spark ML 库进行游戏推荐
我们已经进入项目的最后阶段,在这里我们将实现 交替最小二乘法(ALS) 机器学习算法,来自 Spark ML 库。该模型利用协同过滤技术,根据玩家的行为(即他们之前玩的游戏)来推荐游戏。这个算法识别出那些在 Steam 应用中玩过每个可用游戏的玩家的游戏选择模式。
为了使算法正常工作,
-
我们需要三个变量——独立变量、目标变量(根据推荐数量,这里为 5),以及评分变量。
-
我们对游戏和作者进行编码,以便简化计算。我们还将
boolean推荐列转换为评分列,True = 5,False = 1。 -
此外,我们将为每个已玩游戏推荐 5 款新游戏,因此我们会考虑那些玩过超过五款游戏的玩家的数据,用于建模算法。
让我们跳到建模和推荐部分!
new_pair_games = data_demo.filter(col("author_playtime_forever")>=5*mean_playtime)
new_pair_games = new_pair_games.filter(new_pair_games["author_steamid"]>=76560000000000000).select("author_steamid","app_id", "app_name","recommended")
# Convert author_steamid and app_id to indices, and use the recommended column for rating
author_indexer = StringIndexer(inputCol="author_steamid", outputCol="author_index").fit(new_pair_games)
app_indexer = StringIndexer(inputCol="app_name", outputCol="app_index").fit(new_pair_games)
new_pair_games = new_pair_games.withColumn("Rating", when(col("recommended") == True, 5).otherwise(1))
# We apply the indexing to the data frame by invoking the reduce phase function transform()
new_pair = author_indexer.transform(app_indexer.transform(new_pair_games))
new_pair.show()

# The reference chart for games
games = new_pair.select("app_index","app_name").distinct().orderBy("app_index")

图 16:带有相应索引的游戏列表,供参考。图片来源:作者
实现 ALS 算法
# Create an ALS (Alternating Least Squares) model
als = ALS(maxIter=10, regParam=0.01, userCol="app_index", itemCol="author_index", ratingCol="Rating", coldStartStrategy="drop")
# Fit the model to the data
model = als.fit(new_pair)
# Generate recommendations for all items
app_recommendations = model.recommendForAllItems(5) # Number of recommendations per item
# Display the recommendations
app_recommendations.show(truncate=False)

图 17:根据每位作者的游戏历史生成的推荐和评分。图片来源:作者
我们可以交叉匹配图 16 中的索引,以找出每个玩家推荐的游戏。因此,我们使用 Spark Core ML 库实现了一个基础推荐系统。
3. e. 结论
在这个项目中,我们成功实现了以下内容——
-
下载并安装 Hadoop 生态系统——HDFS 和 MapReduce——以高效地存储、访问和提取大数据,并通过个人计算机实现更快速的大数据分析。
-
安装 Apache Spark 的 Python API(PySpark)并将其与 Hadoop 生态系统集成,使我们能够进行大数据分析和一些机器学习操作。
-
游戏和人口统计分析为我们提供了一些见解,可以用来改善游戏体验并控制玩家流失。保持玩家更新并告知他们同行的趋势应该是 Steam 平台的优先事项。像“最受欢迎”、“你所在地区最受欢迎”、“最推荐”和“不要错过这些新游戏”等建议可以保持玩家活跃。
-
Steam 应用可以使用 ALS 推荐系统,根据玩家的个人资料向现有玩家推荐新游戏,并保持他们的参与度和新鲜感。
4. 下一步?
-
在评论列中实现自然语言处理技术,处理不同语言的评论,以提取评论的精髓并改善游戏体验。
-
Steam 可以根据评论报告游戏中的 bug。开发一种能够捕捉评论内容、进行分类并将其发送给相关人员的 AI 算法,将对平台产生巨大帮助。
-
评论告诉我你认为还可以做些什么!
5. 参考文献
-
Apache Hadoop. Apache Hadoop。Apache Hadoop
-
Statista. (2021). 2010 年至 2020 年全球创建、捕获、复制和消费的数据/信息量,以及 2021 年至 2025 年的预测。statista
-
Dey, R. (2023). 大数据和 Hadoop 分布式文件系统(HDFS)初学者指南。Medium
-
Code with Arjun (2021). 在 Mac OS(MacBook M1)上安装 Hadoop。Medium
-
Apache Spark. PySpark 安装指南。Apache Spark
-
Apache Spark. 使用 ALS 进行协同过滤。Apache Spark
-
Let’s Uncover it. (2023). PUBG。Let’s Uncover It
你可以在我的GitHub 仓库中找到完整的大数据分析项目。
让我们在LinkedIn上联系,进一步讨论!
如果你觉得这篇文章有用,请点赞、分享并评论!
如何设计 X(Twitter)首页时间线 API:值得借鉴的经验
详细了解 X 的 API:获取数据、链接实体,并解决数据获取不足的问题。
·发布于 Towards Data Science ·17 分钟阅读·2024 年 12 月 12 日
--

在设计系统的 API 时,软件工程师通常会评估各种方法,如 REST 与 RPC 与 GraphQL,或者混合模型,以确定最适合特定任务或项目的方法。这些方法定义了数据如何在后端和前端之间流动,以及响应数据的结构:
-
是否应将所有数据打包成一个“批量”,并在一个响应中返回?
-
是否可以配置“批量”仅包含特定客户端(例如,浏览器与移动端)所需的字段,以避免过度获取数据?
-
如果客户端数据获取不足,并需要额外的后端调用来获取缺失的实体,会发生什么情况?
-
应该如何处理父子关系?是否应将子实体嵌套在父实体中,还是应该应用规范化,仅通过父实体引用子实体的 ID,以提高可重用性并减少响应大小?
在本文中,我们探讨了 X(前身为 Twitter)首页时间线 API(x.com/home)如何解决这些挑战,包括:
-
获取推文列表
-
返回层级或关联数据(例如,推文、用户、媒体)
-
对结果进行排序和分页
-
检索推文详情
-
点赞一条推文
我们的重点将放在 API 设计和功能上,将后端视为黑箱,因为其实现无法访问。

X 首页时间线示例
在这里展示确切的请求和响应可能会显得繁琐且难以跟随,因为深层嵌套和重复的对象很难阅读。为了更容易查看请求/响应负载结构,我尝试在 TypeScript 中“写出”首页时间线 API。因此,在展示请求/响应示例时,我将使用请求和响应类型,而不是实际的 JSON 对象。此外,请记住,这些类型是简化的,许多属性为了简洁被省略。
您可以在 types/x.ts 文件中或在本文底部的“附录:所有类型汇总”部分找到所有类型。
除非另有说明,所有图片均来自作者。
获取推文列表
端点和请求/响应结构
获取首页时间线的推文列表从对以下端点的 POST 请求开始:
POST https://x.com/i/api/graphql/{query-id}/HomeTimeline
这里是一个简化的请求体类型:
type TimelineRequest = {
queryId: string; // 's6ERr1UxkxxBx4YundNsXw'
variables: {
count: number; // 20
cursor?: string; // 'DAAACgGBGedb3Vx__9sKAAIZ5g4QENc99AcAAwAAIAIAAA'
seenTweetIds: string[]; // ['1867041249938530657', '1867041249938530659']
};
features: Features;
};
type Features = {
articles_preview_enabled: boolean;
view_counts_everywhere_api_enabled: boolean;
// ...
}
这里是一个简化的响应体类型(我们将在下面深入探讨响应子类型):
type TimelineResponse = {
data: {
home: {
home_timeline_urt: {
instructions: (TimelineAddEntries | TimelineTerminateTimeline)[];
responseObjects: {
feedbackActions: TimelineAction[];
};
};
};
};
};
type TimelineAddEntries = {
type: 'TimelineAddEntries';
entries: (TimelineItem | TimelineCursor | TimelineModule)[];
};
type TimelineItem = {
entryId: string; // 'tweet-1867041249938530657'
sortIndex: string; // '1866561576636152411'
content: {
__typename: 'TimelineTimelineItem';
itemContent: TimelineTweet;
feedbackInfo: {
feedbackKeys: ActionKey[]; // ['-1378668161']
};
};
};
type TimelineTweet = {
__typename: 'TimelineTweet';
tweet_results: {
result: Tweet;
};
};
type TimelineCursor = {
entryId: string; // 'cursor-top-1867041249938530657'
sortIndex: string; // '1866961576813152212'
content: {
__typename: 'TimelineTimelineCursor';
value: string; // 'DACBCgABGedb4VyaJwuKbIIZ40cX3dYwGgaAAwAEAEEAA'
cursorType: 'Top' | 'Bottom';
};
};
type ActionKey = string;
有趣的是,“获取”数据是通过“POST 请求”完成的,这对于类 REST API 并不常见,但在类似 GraphQL 的 API 中是很常见的。此外,URL 中的graphql部分表明 X 正在为其 API 使用 GraphQL 风格。
我在这里使用“flavor”这个词,因为请求体本身不像纯粹的GraphQL 查询,我们可能会在其中描述所需的响应结构,列出我们想要获取的所有属性:
# An example of a pure GraphQL request structure that is *not* being used in the X API.
{
tweets {
id
description
created_at
medias {
kind
url
# ...
}
author {
id
name
# ...
}
# ...
}
}
这里的假设是,首页时间线 API 不是一个纯粹的 GraphQL API,而是几种方法的混合体。像这样通过 POST 请求传递参数看起来更接近于“功能性”RPC 调用。但与此同时,似乎 GraphQL 的特性可能在后台某个地方被使用在 HomeTimeline 端点处理程序/控制器中。像这样的混合可能还与遗留代码或某种正在进行的迁移有关。但再说一次,这些仅仅是我的推测。
您可能还注意到,API URL 和 API 请求体中都使用了相同的TimelineRequest.queryId。这个 queryId 很可能是在后台生成的,然后嵌入到 main.js 包中,之后在从后台获取数据时使用它。由于 X 的后台对于我们来说是一个黑箱,我很难理解这个 queryId 是如何被使用的。但是,再说一次,推测可能是它对于某种性能优化(重用一些预计算的查询结果?)、缓存(与 Apollo 相关?)、调试(通过 queryId 连接日志?)或跟踪/追踪目的来说是必要的。
还需要注意的是,TimelineResponse 并不包含推文列表,而是包含一系列指令,例如 “将一条推文添加到时间线”(请参见 TimelineAddEntries 类型),或 “终止时间线”(请参见 TimelineTerminateTimeline 类型)。
TimelineAddEntries 指令本身也可能包含不同类型的实体:
-
推文 — 请参见
TimelineItem类型 -
游标 — 请参见
TimelineCursor类型 -
对话/评论/线程 — 请参见
TimelineModule类型
type TimelineResponse = {
data: {
home: {
home_timeline_urt: {
instructions: (TimelineAddEntries | TimelineTerminateTimeline)[]; // <-- Here
// ...
};
};
};
};
type TimelineAddEntries = {
type: 'TimelineAddEntries';
entries: (TimelineItem | TimelineCursor | TimelineModule)[]; // <-- Here
};
从可扩展性的角度来看,这一点很有趣,因为它允许在不大幅修改 API 的情况下,渲染更多种类的内容到首页时间线。
分页
TimelineRequest.variables.count 属性设置我们希望一次获取多少条推文(每页)。默认值是 20。然而,TimelineAddEntries.entries 数组中可能会返回超过 20 条推文。例如,数组可能在第一页加载时包含 37 条条目,因为它包括推文(29 条)、置顶推文(1 条)、推广推文(5 条)和分页游标(2 条)。不过,我不确定为什么有 29 条常规推文,而请求的数量是 20 条。
TimelineRequest.variables.cursor 负责基于游标的分页。
“游标分页最常用于实时数据,因为新记录的添加频繁,并且在读取数据时,通常先看到最新的结果。它消除了跳过项和重复显示相同项的可能性。在基于游标的分页中,使用一个常量指针(或游标)来跟踪下一项应从数据集中哪里获取。”请参见 偏移分页与游标分页 讨论帖获取相关背景信息。
在第一次获取推文列表时,TimelineRequest.variables.cursor 是空的,因为我们希望从默认的(很可能是预计算的)个性化推文列表中获取顶级推文。
然而,在响应中,除了推文数据,后端还会返回游标条目。以下是响应类型的层级结构:TimelineResponse → TimelineAddEntries → TimelineCursor:
type TimelineResponse = {
data: {
homet: {
home_timeline_urt: {
instructions: (TimelineAddEntries | TimelineTerminateTimeline)[]; // <-- Here
// ...
};
};
};
};
type TimelineAddEntries = {
type: 'TimelineAddEntries';
entries: (TimelineItem | TimelineCursor | TimelineModule)[]; // <-- Here (tweets + cursors)
};
type TimelineCursor = {
entryId: string;
sortIndex: string;
content: {
__typename: 'TimelineTimelineCursor';
value: string; // 'DACBCgABGedb4VyaJwuKbIIZ40cX3dYwGgaAAwAEAEEAA' <-- Here
cursorType: 'Top' | 'Bottom';
};
};
每一页都包含推文列表,以及“顶部”和“底部”游标:

游标与推文一同传递的示例
数据加载后,我们可以从当前页面向两个方向移动,使用“底部”游标获取“之前/更旧”的推文,或使用“顶部”游标获取“下一条/更新”的推文。我猜测,使用“顶部”游标获取“下一条”推文有两种情况:一是当用户仍在浏览当前页面时,新推文已被添加,二是当用户开始向上滚动动态时(如果没有缓存条目,或之前的条目由于性能原因被删除)。
X 的游标本身可能看起来是这样的:DAABCgABGemI6Mk__9sKAAIZ6MSYG9fQGwgAAwAAAAIAAA。在一些 API 设计中,游标可能是一个 Base64 编码的字符串,包含列表中最后一项的 ID,或者是最后一项的时间戳。例如:eyJpZCI6ICIxMjM0NTY3ODkwIn0= --> {"id": "1234567890"},然后这些数据会用于相应地查询数据库。在 X API 的情况下,游标看起来像是被 Base64 解码成某种自定义的二进制序列,可能需要进一步解码才能得到有意义的内容(例如,通过 Protobuf 消息定义)。由于我们不知道它是否是.proto编码,也不知道.proto的消息定义,我们只能假设后端知道如何根据游标字符串查询下一批推文。
TimelineResponse.variables.seenTweetIds参数用于通知服务器客户端当前活动页面中已经查看过的推文(来自无限滚动)。这很可能有助于确保服务器在后续的结果页面中不包含重复的推文。
链接/层级实体
像主页时间线(或主页动态)这样的 API 面临的挑战之一是如何返回链接或层级实体(即tweet → user、tweet → media、media → author等):
-
我们是否应该先仅返回推文列表,然后根据需求通过一系列单独的查询获取依赖的实体(如用户详情)?
-
或者我们应该一次性返回所有数据,这样虽然会增加首次加载的时间和大小,但可以节省后续所有调用的时间?
-
在这种情况下,我们是否需要对数据进行规范化,以减少负载大小(即当同一用户是多条推文的作者时,我们希望避免在每个推文实体中重复用户数据)?
-
还是应该是上述方法的组合?
我们来看一下 X 是如何处理的。
在TimelineTweet类型中,早些时候使用了Tweet子类型。我们来看一下它的样子:
export type TimelineResponse = {
data: {
home: {
home_timeline_urt: {
instructions: (TimelineAddEntries | TimelineTerminateTimeline)[]; // <-- Here
// ...
};
};
};
};
type TimelineAddEntries = {
type: 'TimelineAddEntries';
entries: (TimelineItem | TimelineCursor | TimelineModule)[]; // <-- Here
};
type TimelineItem = {
entryId: string;
sortIndex: string;
content: {
__typename: 'TimelineTimelineItem';
itemContent: TimelineTweet; // <-- Here
// ...
};
};
type TimelineTweet = {
__typename: 'TimelineTweet';
tweet_results: {
result: Tweet; // <-- Here
};
};
// A Tweet entity
type Tweet = {
__typename: 'Tweet';
core: {
user_results: {
result: User; // <-- Here (a dependent User entity)
};
};
legacy: {
full_text: string;
// ...
entities: { // <-- Here (a dependent Media entities)
media: Media[];
hashtags: Hashtag[];
urls: Url[];
user_mentions: UserMention[];
};
};
};
// A User entity
type User = {
__typename: 'User';
id: string; // 'VXNlcjoxNDUxM4ADSG44MTA4NDc4OTc2'
// ...
legacy: {
location: string; // 'San Francisco'
name: string; // 'John Doe'
// ...
};
};
// A Media entity
type Media = {
// ...
source_user_id_str: string; // '1867041249938530657' <-- Here (the dependant user is being mentioned by its ID)
url: string; // 'https://t.co/X78dBgtrsNU'
features: {
large: { faces: FaceGeometry[] };
medium: { faces: FaceGeometry[] };
small: { faces: FaceGeometry[] };
orig: { faces: FaceGeometry[] };
};
sizes: {
large: MediaSize;
medium: MediaSize;
small: MediaSize;
thumb: MediaSize;
};
video_info: VideoInfo[];
};
这里有趣的是,大部分依赖数据,如tweet → media和tweet → author,在第一次调用的响应中就已经嵌入(无需后续查询)。
此外,User和Media与Tweet实体的连接并没有进行规范化(如果两条推文有相同的作者,那么它们的数据将在每个推文对象中重复)。但似乎这应该没问题,因为在特定用户的主页时间线范围内,推文会由多位作者创作,数据的重复是可能的,但并不频繁。
我的假设是,负责获取某一特定用户推文的 UserTweets API(我们在这里不讨论)会以不同的方式处理,但显然并非如此。UserTweets 返回相同用户的推文列表,并且为每条推文重复嵌入相同的用户数据。这很有趣,也许这种方法的简单性克服了一些数据大小的开销(也许用户数据的大小被认为相对较小)。我不确定。
关于实体关系的另一个观察是,Media 实体也与 User(作者)有链接。但它不是通过直接的实体嵌入(像 Tweet 实体那样),而是通过 Media.source_user_id_str 属性进行链接。
每个主页时间线中“推文”的“评论”(它们本质上也是“推文”)完全没有被获取。要查看推文线程,用户必须点击推文以查看其详细视图。推文线程将通过调用 TweetDetail 端点来获取(更多关于它的内容将在下文的“推文详情页面”部分中介绍)。
每个 Tweet 还有一个实体是 FeedbackActions(即“更少推荐”或“更少查看”)。FeedbackActions 在响应对象中的存储方式与 User 和 Media 对象的存储方式不同。虽然 User 和 Media 实体是 Tweet 的一部分,但 FeedbackActions 是单独存储在 TimelineItem.content.feedbackInfo.feedbackKeys 数组中的,并通过 ActionKey 进行链接。对我来说这是一个小小的惊讶,因为看起来没有任何动作是可以重用的。似乎每个动作仅用于特定的推文。因此,FeedbackActions 似乎可以像 Media 实体一样嵌入到每个推文中。但我可能忽略了某些隐藏的复杂性(例如每个动作可能有子动作)。
关于动作的更多细节请参见下文的“推文动作”部分。
排序
时间线条目的排序顺序由后端通过 sortIndex 属性定义:
type TimelineCursor = {
entryId: string;
sortIndex: string; // '1866961576813152212' <-- Here
content: {
__typename: 'TimelineTimelineCursor';
value: string;
cursorType: 'Top' | 'Bottom';
};
};
type TimelineItem = {
entryId: string;
sortIndex: string; // '1866561576636152411' <-- Here
content: {
__typename: 'TimelineTimelineItem';
itemContent: TimelineTweet;
feedbackInfo: {
feedbackKeys: ActionKey[];
};
};
};
type TimelineModule = {
entryId: string;
sortIndex: string; // '73343543020642838441' <-- Here
content: {
__typename: 'TimelineTimelineModule';
items: {
entryId: string,
item: TimelineTweet,
}[],
displayType: 'VerticalConversation',
};
};
sortIndex 本身可能像这样 '1867231621095096312'。它可能直接对应或从 Snowflake ID 派生出来。
实际上,您在响应中看到的大多数 ID(推文 ID)都遵循“Snowflake ID”规范,并且看起来像是
*'1867231621095096312'**。
如果用来对诸如推文等实体进行排序,系统会利用 Snowflake ID 本身的时间顺序排序。具有较高 sortIndex 值(即较新的时间戳)的推文或对象会出现在信息流的更高位置,而具有较低值(即较旧时间戳)的推文则会出现在信息流的较低位置。
这是对 Snowflake ID(在我们这里是 sortIndex)1867231621095096312 的逐步解码:
-
提取时间戳:
-
时间戳通过将 Snowflake ID 右移 22 位(以去掉数据中心、工作者 ID 和序列的低 22 位)得出:
1867231621095096312 → 445182709954 -
添加 Twitter 的纪元:
-
将 Twitter 的自定义纪元(1288834974657)添加到此时间戳中,得到 UNIX 时间戳的毫秒数:
445182709954 + 1288834974657 → 1734017684611ms -
转换为人类可读的日期:
-
将 UNIX 时间戳转换为 UTC 日期时间:
1734017684611ms → 2024-12-12 15:34:44.611 (UTC)
因此,我们可以在这里假设,首页时间线中的推文是按时间顺序排序的。
推文操作
每条推文都有一个“操作”菜单。

推文操作示例
每条推文的操作来自后台的TimelineItem.content.feedbackInfo.feedbackKeys数组,并通过ActionKey与推文关联:
type TimelineResponse = {
data: {
home: {
home_timeline_urt: {
instructions: (TimelineAddEntries | TimelineTerminateTimeline)[];
responseObjects: {
feedbackActions: TimelineAction[]; // <-- Here
};
};
};
};
};
type TimelineItem = {
entryId: string;
sortIndex: string;
content: {
__typename: 'TimelineTimelineItem';
itemContent: TimelineTweet;
feedbackInfo: {
feedbackKeys: ActionKey[]; // ['-1378668161'] <-- Here
};
};
};
type TimelineAction = {
key: ActionKey; // '-609233128'
value: {
feedbackType: 'NotRelevant' | 'DontLike' | 'SeeFewer'; // ...
prompt: string; // 'This post isn’t relevant' | 'Not interested in this post' | ...
confirmation: string; // 'Thanks. You’ll see fewer posts like this.'
childKeys: ActionKey[]; // ['1192182653', '-1427553257'], i.e. NotInterested -> SeeFewer
feedbackUrl: string; // '/2/timeline/feedback.json?feedback_type=NotRelevant&action_metadata=SRwW6oXZadPHiOczBBaAwPanEwE%3D'
hasUndoAction: boolean;
icon: string; // 'Frown'
};
};
有趣的是,这个扁平的操作数组实际上是一个树(或图形?我没检查),因为每个操作可能有子操作(参见TimelineAction.value.childKeys数组)。这是有意义的,例如,当用户点击“不要喜欢”操作时,后续可能会显示“此帖子不相关”操作,用以解释用户为什么不喜欢这条推文。
推文详情页
一旦用户想查看推文详情页(即查看评论/推文的线程),用户点击推文并发送GET请求到以下端点:
GET https://x.com/i/api/graphql/{query-id}/TweetDetail?variables={"focalTweetId":"1867231621095096312","referrer":"home","controller_data":"DACABBSQ","rankingMode":"Relevance","includePromotedContent":true,"withCommunity":true}&features={"articles_preview_enabled":true}
我在这里很好奇,为什么推文列表是通过POST请求获取的,而每条推文的详情是通过GET请求获取的。这似乎不一致。特别是考虑到类似query-id、features等查询参数这次是通过 URL 传递的,而不是通过请求体传递的。响应格式也类似,并且重用了列表调用中的类型。我不确定为什么会这样。但再次强调,我肯定可能会错过一些背景复杂性。
这里是简化后的响应体类型:
type TweetDetailResponse = {
data: {
threaded_conversation_with_injections_v2: {
instructions: (TimelineAddEntries | TimelineTerminateTimeline)[],
},
},
}
type TimelineAddEntries = {
type: 'TimelineAddEntries';
entries: (TimelineItem | TimelineCursor | TimelineModule)[];
};
type TimelineTerminateTimeline = {
type: 'TimelineTerminateTimeline',
direction: 'Top',
}
type TimelineModule = {
entryId: string; // 'conversationthread-58668734545929871193'
sortIndex: string; // '1867231621095096312'
content: {
__typename: 'TimelineTimelineModule';
items: {
entryId: string, // 'conversationthread-1866876425669871193-tweet-1866876038930951193'
item: TimelineTweet,
}[], // Comments to the tweets are also tweets
displayType: 'VerticalConversation',
};
};
响应在类型上与列表响应非常相似,因此我们不会在这里停留太久。
一个有趣的细节是,每条推文的“评论”(或对话)实际上是其他推文(参见TimelineModule类型)。因此,推文线程看起来与首页时间线的推送非常相似,通过显示TimelineTweet条目的列表来呈现。这看起来非常优雅。一个很好的例子,展示了 API 设计中通用且可复用的方法。
点赞推文
当用户点赞推文时,会向以下端点发送POST请求:
POST https://x.com/i/api/graphql/{query-id}/FavoriteTweet
这里是请求体类型:
type FavoriteTweetRequest = {
variables: {
tweet_id: string; // '1867041249938530657'
};
queryId: string; // 'lI07N61twFgted2EgXILM7A'
};
这里是响应体类型:
type FavoriteTweetResponse = {
data: {
favorite_tweet: 'Done',
}
}
看起来很直接,也类似于 RPC 风格的 API 设计方法。
结论
我们通过查看 X 的 API 示例,已经涉及了一些家庭时间线 API 设计的基本部分。在这个过程中,我尽量根据我的知识做出了一些假设。我相信我可能有一些地方理解得不准确,也可能错过了一些复杂的细微差别。但即便如此,我希望你能从这份高层次的概述中获得一些有用的见解,这些见解可以应用到你下次的 API 设计会议中。
起初,我计划通过浏览一些顶级技术网站,获取来自 Facebook、Reddit、YouTube 等的见解,并收集经过实践验证的最佳做法和解决方案。我不确定是否能找到时间去做这件事。到时再看看。但这可能是一个有趣的练习。
附录:所有类型汇总
作为参考,我在这里一次性添加了所有类型。你也可以在types/x.ts文件中找到所有类型。
/**
* This file contains the simplified types for X's (Twitter's) home timeline API.
*
* These types are created for exploratory purposes, to see the current implementation
* of the X's API, to see how they fetch Home Feed, how they do a pagination and sorting,
* and how they pass the hierarchical entities (posts, media, user info, etc).
*
* Many properties and types are omitted for simplicity.
*/
// POST https://x.com/i/api/graphql/{query-id}/HomeTimeline
export type TimelineRequest = {
queryId: string; // 's6ERr1UxkxxBx4YundNsXw'
variables: {
count: number; // 20
cursor?: string; // 'DAAACgGBGedb3Vx__9sKAAIZ5g4QENc99AcAAwAAIAIAAA'
seenTweetIds: string[]; // ['1867041249938530657', '1867041249938530658']
};
features: Features;
};
// POST https://x.com/i/api/graphql/{query-id}/HomeTimeline
export type TimelineResponse = {
data: {
home: {
home_timeline_urt: {
instructions: (TimelineAddEntries | TimelineTerminateTimeline)[];
responseObjects: {
feedbackActions: TimelineAction[];
};
};
};
};
};
// POST https://x.com/i/api/graphql/{query-id}/FavoriteTweet
export type FavoriteTweetRequest = {
variables: {
tweet_id: string; // '1867041249938530657'
};
queryId: string; // 'lI07N6OtwFgted2EgXILM7A'
};
// POST https://x.com/i/api/graphql/{query-id}/FavoriteTweet
export type FavoriteTweetResponse = {
data: {
favorite_tweet: 'Done',
}
}
// GET https://x.com/i/api/graphql/{query-id}/TweetDetail?variables={"focalTweetId":"1867041249938530657","referrer":"home","controller_data":"DACABBSQ","rankingMode":"Relevance","includePromotedContent":true,"withCommunity":true}&features={"articles_preview_enabled":true}
export type TweetDetailResponse = {
data: {
threaded_conversation_with_injections_v2: {
instructions: (TimelineAddEntries | TimelineTerminateTimeline)[],
},
},
}
type Features = {
articles_preview_enabled: boolean;
view_counts_everywhere_api_enabled: boolean;
// ...
}
type TimelineAction = {
key: ActionKey; // '-609233128'
value: {
feedbackType: 'NotRelevant' | 'DontLike' | 'SeeFewer'; // ...
prompt: string; // 'This post isn’t relevant' | 'Not interested in this post' | ...
confirmation: string; // 'Thanks. You’ll see fewer posts like this.'
childKeys: ActionKey[]; // ['1192182653', '-1427553257'], i.e. NotInterested -> SeeFewer
feedbackUrl: string; // '/2/timeline/feedback.json?feedback_type=NotRelevant&action_metadata=SRwW6oXZadPHiOczBBaAwPanEwE%3D'
hasUndoAction: boolean;
icon: string; // 'Frown'
};
};
type TimelineAddEntries = {
type: 'TimelineAddEntries';
entries: (TimelineItem | TimelineCursor | TimelineModule)[];
};
type TimelineTerminateTimeline = {
type: 'TimelineTerminateTimeline',
direction: 'Top',
}
type TimelineCursor = {
entryId: string; // 'cursor-top-1867041249938530657'
sortIndex: string; // '1867231621095096312'
content: {
__typename: 'TimelineTimelineCursor';
value: string; // 'DACBCgABGedb4VyaJwuKbIIZ40cX3dYwGgaAAwAEAEEAA'
cursorType: 'Top' | 'Bottom';
};
};
type TimelineItem = {
entryId: string; // 'tweet-1867041249938530657'
sortIndex: string; // '1867231621095096312'
content: {
__typename: 'TimelineTimelineItem';
itemContent: TimelineTweet;
feedbackInfo: {
feedbackKeys: ActionKey[]; // ['-1378668161']
};
};
};
type TimelineModule = {
entryId: string; // 'conversationthread-1867041249938530657'
sortIndex: string; // '1867231621095096312'
content: {
__typename: 'TimelineTimelineModule';
items: {
entryId: string, // 'conversationthread-1867041249938530657-tweet-1867041249938530657'
item: TimelineTweet,
}[], // Comments to the tweets are also tweets
displayType: 'VerticalConversation',
};
};
type TimelineTweet = {
__typename: 'TimelineTweet';
tweet_results: {
result: Tweet;
};
};
type Tweet = {
__typename: 'Tweet';
core: {
user_results: {
result: User;
};
};
views: {
count: string; // '13763'
};
legacy: {
bookmark_count: number; // 358
created_at: string; // 'Tue Dec 10 17:41:28 +0000 2024'
conversation_id_str: string; // '1867041249938530657'
display_text_range: number[]; // [0, 58]
favorite_count: number; // 151
full_text: string; // "How I'd promote my startup, if I had 0 followers (Part 1)"
lang: string; // 'en'
quote_count: number;
reply_count: number;
retweet_count: number;
user_id_str: string; // '1867041249938530657'
id_str: string; // '1867041249938530657'
entities: {
media: Media[];
hashtags: Hashtag[];
urls: Url[];
user_mentions: UserMention[];
};
};
};
type User = {
__typename: 'User';
id: string; // 'VXNlcjoxNDUxM4ADSG44MTA4NDc4OTc2'
rest_id: string; // '1867041249938530657'
is_blue_verified: boolean;
profile_image_shape: 'Circle'; // ...
legacy: {
following: boolean;
created_at: string; // 'Thu Oct 21 09:30:37 +0000 2021'
description: string; // 'I help startup founders double their MRR with outside-the-box marketing cheat sheets'
favourites_count: number; // 22195
followers_count: number; // 25658
friends_count: number;
location: string; // 'San Francisco'
media_count: number;
name: string; // 'John Doe'
profile_banner_url: string; // 'https://pbs.twimg.com/profile_banners/4863509452891265813/4863509'
profile_image_url_https: string; // 'https://pbs.twimg.com/profile_images/4863509452891265813/4863509_normal.jpg'
screen_name: string; // 'johndoe'
url: string; // 'https://t.co/dgTEddFGDd'
verified: boolean;
};
};
type Media = {
display_url: string; // 'pic.x.com/X7823zS3sNU'
expanded_url: string; // 'https://x.com/johndoe/status/1867041249938530657/video/1'
ext_alt_text: string; // 'Image of two bridges.'
id_str: string; // '1867041249938530657'
indices: number[]; // [93, 116]
media_key: string; // '13_2866509231399826944'
media_url_https: string; // 'https://pbs.twimg.com/profile_images/1867041249938530657/4863509_normal.jpg'
source_status_id_str: string; // '1867041249938530657'
source_user_id_str: string; // '1867041249938530657'
type: string; // 'video'
url: string; // 'https://t.co/X78dBgtrsNU'
features: {
large: { faces: FaceGeometry[] };
medium: { faces: FaceGeometry[] };
small: { faces: FaceGeometry[] };
orig: { faces: FaceGeometry[] };
};
sizes: {
large: MediaSize;
medium: MediaSize;
small: MediaSize;
thumb: MediaSize;
};
video_info: VideoInfo[];
};
type UserMention = {
id_str: string; // '98008038'
name: string; // 'Yann LeCun'
screen_name: string; // 'ylecun'
indices: number[]; // [115, 122]
};
type Hashtag = {
indices: number[]; // [257, 263]
text: string;
};
type Url = {
display_url: string; // 'google.com'
expanded_url: string; // 'http://google.com'
url: string; // 'https://t.co/nZh3aF0Aw6'
indices: number[]; // [102, 125]
};
type VideoInfo = {
aspect_ratio: number[]; // [427, 240]
duration_millis: number; // 20000
variants: {
bitrate?: number; // 288000
content_type?: string; // 'application/x-mpegURL' | 'video/mp4' | ...
url: string; // 'https://video.twimg.com/amplify_video/18665094345456w6944/pl/-ItQau_LRWedR-W7.m3u8?tag=14'
};
};
type FaceGeometry = { x: number; y: number; h: number; w: number };
type MediaSize = { h: number; w: number; resize: 'fit' | 'crop' };
type ActionKey = string;
Apple M2 Max GPU vs Nvidia V100(第二部分):大模型与能效
比较 Apple Silicon M2 Max GPU 在训练大规模 CNN 模型时的性能和能效,并与 Nvidia V100 进行对比。
·发表于 Towards Data Science ·阅读时间 16 分钟·2024 年 2 月 7 日
--
在我的 上一篇文章 中,我将 M2 Max GPU 与 Nvidia V100、P100 和 T4 在 MLP、CNN 和 LSTM 训练中的表现进行了比较。结果显示,M2 Max 在小型模型训练中表现非常好,甚至超过了 Nvidia GPU。但是,正如文章中所述:
[…] 这些指标只能用于与本次测试中使用的神经网络类型和深度相似的情况。
因此,第二部分测试了更大的模型,专注于仅 CNN,并将 M2 Max 与之前测试过的最强 GPU:Nvidia V100 进行了比较。
另一个在本次测试中考虑的因素是内存管理。虽然 Nvidia GPU 在内存传输上浪费了大量时间,M2 Max GPU 由于可以直接访问统一内存,因此在训练模型之前无需任何延迟。由于正如上一篇文章中所示,这对在少量周期上训练的小型模型有很大影响,因此为了比较纯粹的训练时间,我们去除了这一影响,并测试了更大的模型。
为此,我们对模型进行十个周期的训练,但不是使用总的训练时间,而是从第二个周期到最后一个周期,捕捉并平均每步的训练时长。这样可以去除初始化和内存传输的开销……
Apple M3 机器学习速度测试
苹果的 M3、M3 Pro 和 M3 Max 与 TensorFlow 和 PyTorch 相比如何?
·发布于 Towards Data Science ·10 分钟阅读·2024 年 1 月 9 日
--

四台 MacBook Pro 对比八个机器学习测试。哪台机器学习最快?来源:作者的客厅。
在过去的两年里,我一直在使用我的 M1 Pro MacBook Pro 14 英寸。
我购买了升级版,增加了更多内存、GPU 核心和存储,以确保未来使用。
而且它一直表现稳定。
但苹果最新发布的 M3 系列让我产生了好奇。
我观看了演示,并看到了一堆关于它们是近年来 GPU 性能跃升最大的一些图表。
作为一名机器学习工程师,自然,我对它们在机器学习角度的表现产生了好奇。
我的 M1 Pro 在日常使用中无可匹敌。
我喜欢它。
但是我不会在其上训练更大规模的机器学习模型。
M3 系列能改变这一点吗?
我做了一些测试来找出答案。
资源
-
GitHub 上的代码 — 我用来设置并运行机器测试的所有代码都可以在 GitHub 上找到。
-
视频演示 — 我还在 YouTube 上制作了一个视频演示,展示了所有结果并给出了一些建议和推荐。
滚动窗口在时间序列中的应用,使用 Python
这是滚动窗口和时间序列的一些强大应用
·发表于Towards Data Science ·11 分钟阅读·2024 年 9 月 15 日
--

图片由Claudia Aran提供,来源于Unsplash
昨晚我和妻子一起做洗衣。我们有一个非言语的约定(不过当我违反它时,它会变得相当言语化):关于洗衣的事情,她负责把衣服放进洗衣机和干衣机,而我则负责折叠衣服。
我们通常是这样做的:

图像由作者使用 DALLE 制作
现在,我并不是真的折叠所有衣服并把它们收好。否则,我会被衣服淹没。我做的事情是一种让我想起滚动窗口方法的方式:

图像由作者使用 DALLE 制作
为什么我说它让我想起滚动窗口?让我们看看这个类比。

图像由作者使用 DALLE 制作
滚动窗口的概念正是我在折叠衣服时应用的那个方法。我有一个任务要做,但你…
应用 LLM 量化与 AWS Sagemaker | Analytics.gov
以两倍的速度和五分之一的成本托管生产就绪的 LLM 端点。
·发表于 Towards Data Science ·阅读时间 16 分钟·2024 年 6 月 7 日
--

作者提供的图像,使用 AWS Sagemaker Jumpstart - Stable Diffusion XL 1.0(开源)生成
声明:我是一名数据工程师,隶属于新加坡政府技术局(GovTech)数据科学与人工智能部(DSAID)。作为 Analytics.gov 的核心开发者之一,我与各个政府部门合作,为公共部门开发数据科学和 AI/ML 能力,造福社会。
目录
-
前言
-
为什么使用开源模型?
-
托管开源 LLM 的障碍
-
什么是量化,它如何帮助?
-
AWS Sagemaker 端点是如何工作的?
-
在 AG Sagemaker 上托管量化模型
-
基准测试
-
结论
1. 前言
如果你还没有阅读我们之前的发布文章,可以在这里查阅!
[## 通过 MLOps 加速 Analytics.gov 上机器学习与 AI 的影响
Analytics.gov 简介
medium.com [## 使用 Analytics.gov 将 LLM 和机器学习模型生产化:MOM 在 AI 解决方案部署中的旅程
特别感谢本文的共同贡献者:MOM 前沿部署团队(Barry Tng,Ethan Mak,Joel Koo),以及...
Analytics.gov (AG)是由新加坡 GovTech 的数据科学与人工智能部门(DSAID)开发的中央机器学习操作(MLOps)平台,它将机器学习和人工智能用例推向全政府(WOG)生产化。该平台托管于政府商业云(GCC)2.0,采用最佳实践的网络和安全配置,为所有数据科学和 AI 需求提供安全的环境。通过 AG,政府官员可以直接从其政府发放的笔记本电脑访问计算资源、托管的 AI 服务及其他工具,而无需管理或开发新的基础设施,从而加速了全政府的 AI/ML 项目。
AG 提供定制功能,利用 AWS Sagemaker Endpoints 提供的能力,为量化模型创建和管理生产就绪的推理端点。仅需几行代码,最终用户即可快速为量化模型设置自己的私有推理端点,将可能需要几天或几周的工作缩短为几分钟。这大大降低了整个政府机构使用 GenAI 的门槛,从而提高了效率和成本效益。
在本文中,我们将探讨 AG 如何使政府机构高效且具有成本效益地运行大语言模型(LLM)。我们的目标是揭开模型量化的神秘面纱,展示我们如何简化在 AWS Sagemaker 中托管量化开源 LLM 的过程,并提供基准测试以评估性能和成本效益的提升。
2. 为什么要使用开源模型?
如需深入了解开源 LLM,请阅读 Sau Sheong 的相关文章!(注:此为 Medium 会员专享内容)
在 LLM 应用中使用开源 LLM
我强烈推荐它,因为它为将开源 LLM 作为 API 托管提供了很好的启示,是本文的重要补充。
安全性与敏感性
开源模型可以在你自己的设备或云环境中私密托管,这意味着向模型发出的查询不会被发送到第三方提供商。这在政府数据中尤为重要,因为其中大部分包含敏感信息。
控制输出生成
开源模型的使用可以在更细粒度的层面上进行控制。封闭源模型必须通过公开的商业 API 进行接口连接,这种方式简化了复杂性,但减少了对模型的控制程度。本地托管的开源模型则允许对输出生成进行完全控制,这一点非常重要,因为许多有用的库,如LMQL和Guidance,在本地托管模型上表现更好。
多样性
截至目前,HuggingFace 上已有超过 60 万个模型,包括由 Meta 和 Google 等大厂发布的模型以及独立贡献者发布的自定义版本。有些版本针对特定目的或任务进行了微调,可以直接使用。用户可以简单地重用这些模型,而无需自己进行微调。
例如,AiSingapore 的 SEA-LION模型经过指令调优,专为东南亚(SEA)地区的语言设计,其训练数据集包含了从马来语到泰语等多种语言。使用此模型可以节省大量获取不同语言数据集的工作量,同时减少微调的计算成本。
3. 托管开源 LLM 的障碍
语言模型有多种形式和大小,流行的模型从 TinyLlama(1.1B)到即将发布的 Llama-3 400B+不等。像 TinyLlama 这样的较小语言模型(SLM)适用于较小且更简单的用例,而复杂的用例通常需要“更智能”的大型语言模型(LLM)。毫无疑问,所有生成 AI 应用都将受益于来自大型 LLM 的更好输出质量,然而,模型的体积越大,也意味着更多的权衡。
为了最大化推理速度,模型必须完全加载到 GPU 内存中,因为任何磁盘和 GPU 内存或 CPU 和 GPU 内存之间的数据传输都会引入额外的开销,从而显著降低推理速度。
LLM 需要大量内存来托管,LLM 越大,所需的 GPU 内存就越多。大多数大型模型需要多个 GPU 才能完全加载到内存中,这使得这一任务成为极其资源密集且昂贵的工作。
自然地,随着模型规模的增大,每次推理任务所需的计算量也随之增加。因此,LLM 越大,推理速度就越低。

按作者分类的变换器 BF16 推理基准
这些模型究竟有多大?
这些大型语言模型(LLM)的大小可以通过以下公式进行估算(注意,这只是一个简单估算,实际的模型大小几乎总是略大于此估算值。)

按作者分类的简化模型大小计算公式,灵感来源于 https://www.substratus.ai/blog/calculating-gpu-memory-for-llm/
使用这个公式,我们可以估算一些流行模型的模型大小:

按作者分类的流行模型的模型大小表
注意:该公式仅估算模型大小,实际的 GPU 需求肯定会更大,并且会因其他因素而有所不同。(正如你将在后续基准部分看到的,实际的 GPU 需求远远超出了这些估算值)。"BF16"代表脑浮点 16 数字格式,而"FP16"代表浮点 16 格式。
即将发布的 Meta 的 Llama-3 400B+将是发布时最大的开放源代码模型之一。我们可以估算,这个巨型模型可能会大到 800GB。作为对比,800GB 的存储至少需要 10 张 A100 80GB 的 GPU 卡来托管,即使我们天真地假设零托管开销。
另一个流行但尺寸更为合理的模型——Llama-3 70B,采用 bf16 或每权重 16 位(bpw)精度发布,仍然需要 141.2GB 的 GPU 内存来进行推理托管。
为什么大 GPU 内存需求是一个问题?
由于当前 GPU 供不应求且需求高,找到便宜的多个 GPU 芯片并不容易。因此,托管 LLM 的原始未量化格式可能是一项非常昂贵的业务,只有少数能够负担得起的特权人群才能使用。这对于那些需要 LLM 智慧的项目来说可能是一个限制,但它的价值不足以让其值得使用多块稀缺且昂贵的 GPU。
更大的 LLM 模型尺寸导致推理速度变慢,从而也会导致:
-
由于输出缓慢,用户体验更差。
-
下游应用程序能够提取的总吞吐量减少。对于像文本摘要或报告生成这样的重令牌应用程序,吞吐量的减少可能会严重影响应用程序的可行性。
缓慢的推理速度和高昂的成本是制约生产级应用的因素,因此每个生成式 AI 应用都需要在输出质量、推理速度和成本之间做出权衡。
4. 什么是量化,它如何提供帮助?
什么是量化?
关于量化的更严谨解释,请参考以下两篇精彩的指南: https://www.tensorops.ai/post/what-are-quantized-llms, https://www.semianalysis.com/p/neural-network-quantization-and-number
为了简化,以下部分将仅讨论训练后量化(PTQ)
简单来说,在 AI/ML 领域,量化是一种减少模型大小的技术。在内部,模型的权重作为数字存储。通常,这些权重以类似浮动点 16(FP16)或脑浮动点 16(BF16)这样的数字格式存储,顾名思义,这些格式需要 16 位来存储一个数字。
量化减少了存储每个数字所需的比特数,这使得模型的存储大小得以减少,因为每个模型权重所使用的比特数更少。
然而,使用更少的比特数表示每个权重意味着权重的精度降低。这就是为什么大多数文章都恰当地将量化描述为“减少模型权重的精度”。
对于视觉学习者,这里是 π 在不同精度下的表示:

作者在不同精度下表示的π
你可以使用这个 浮动点计算器 亲自尝试一下。
注意:现代量化方法可能使用定制的数字格式,而非 FP 系列来对模型进行量化。这些方法可以将量化精度降低到 1 位(Q1)。
如表中所示,随着比特数的减少,π 的精度也降低。这不仅影响小数位数,还会影响数字本身的近似值。
例如,3.141592502593994 不能在 FP8 中精确表示,因此它必须四舍五入到 FP8 能表示的最接近值——3.125,这也被称为浮动点误差。
它有什么帮助?
随着每个权重所需比特数的减少,总的 GPU 内存需求也会减少。例如,将 FP16 转换为 8 位量化(Q8)可以将每个数字存储所需的比特数从 16 位减少到 8 位。这会使模型的大小减少 50%。
举个例子,一个未经量化的 FP16 Mistral 7B 估计大小约为 14.48 GB,而一个 Q8 Mistral 7B 仅为 7.24 GB。一个 Q4 Mistral 7B 仅为 3.62 GB,这使得它可以加载到一些移动设备中。
除了减少内存外,减少的内存需求还降低了托管模型所需的最低计算要求,同时提高了推理速度。

作者在不同量化下的 7B 模型基准测试
有什么问题吗?
当然,世界上没有免费的午餐!精度的降低会影响模型输出的质量。参考我们之前的表格“π的表示”,一个以 FP16 表示的 π 可能足够准确以通过数学考试,但一个 FP8 的 π 会让你得 F。
幸运的是,大多数 LLM 对较高精度的量化不太敏感。一般来说,8 位量化或 Q8 模型几乎与原始模型相同。这一点在以下基准中有所体现,来自“低位量化 LLAMA3 模型效果如何?一项实证研究”。

提取的 8 位量化 Llama-3 与基准的对比表,来源: https://arxiv.org/pdf/2404.14047.
简而言之,这意味着通过将模型权重量化为 Q8,您几乎可以在几乎不损失的情况下将模型大小减少 50%。

提取的 4 位量化 Llama-3 与基准的对比表,来源: https://arxiv.org/pdf/2404.14047.
对于模型大小减少 75%,即 Q4,使用更智能的量化技术(如 AWQ)时,模型依然可以接受,尽管会有明显的质量损失。

提取的 3 位量化 Llama-3 与基准的对比表,来源: https://arxiv.org/pdf/2404.14047.
如果低于 Q4,模型输出质量可能会严重下降。
请注意,量化对模型质量的影响可能因模型而异。确定最佳量化级别的最好方法,实际上是基于您自己的使用情况和测试。
选择哪个量化框架?
有关选择量化框架的更严谨讨论,请参见: https://oobabooga.github.io/blog/posts/gptq-awq-exl2-llamacpp/ , https://www.reddit.com/r/LocalLLaMA/comments/1anb2fz/guide_to_choosing_quants_and_engines/
有许多量化框架可供选择,其中一些更流行的包括 GGUF、GPTQ、EXL2 和 AWQ。最适合您的量化框架将取决于您的使用场景。以下是我根据个人使用经验给出的推荐。最适合您的选择将取决于您的使用场景,实际效果可能因人而异。
GGUF
由Georgi Gerganov创建,GGUF 旨在通过最小的设置和在任何硬件上实现最先进的 LLM 推理,无论是本地还是云端,成为 AI/ML 爱好者的必备工具,由于其易用性,GGUF 已经成为许多 LLM 托管的首选。
如果你需要在普通硬件或仅有 CPU 的系统上托管模型,那么 GGUF 是最合适的选择,因为它是唯一支持 CPU 托管的框架。GGUF 还允许你在旧款 GPU 上运行较新的模型。GGUF 由于将模型权重打包为一个统一格式的单个文件,因此也是最稳定的框架。如果你需要在任何机器上可靠托管量化模型,比如你的笔记本电脑,那么 GGUF 是最好的选择。
GGUF 的一个缺点是,旧版本的量化(Qx_0)使用的是较简单的量化方法,如四舍五入量化(RTN)。这可能会在一定程度上降低模型输出质量,但在较高的量化级别下影响较小。GGUF 中的新量化方法(Qx_K 或 IQx_S)在较低的量化级别下能更好地保持模型质量。
GPTQ、EXL2 和 AWQ
GPTQ、EXL2 和 AWQ 专为 GPU 使用而设计,它们都基于 GPTQ 格式。这些框架在 GPU 上运行时通常比 GGUF 更快,因为它们专门优化了 GPU 的运行性能。EXL2 允许在模型内混合量化级别。AWQ 则倾向于提供最佳的输出质量,因为它使用比 GPTQ 更“智能”的量化技术。EXL2 和 AWQ 都致力于在较低量化级别时减少性能下降。GPTQ 通常是下游推理引擎支持最广泛的格式。
总结来说,选择 GGUF 可以方便托管,EXL2 适合混合量化级别,AWQ 则适用于输出质量,而如果推理引擎不支持其他格式,可以选择 GPTQ。
5. AWS Sagemaker 端点是如何工作的?
现在我们了解了量化是什么,那么如何将其引入到 AG 的 AWS Sagemaker 中,让用户能够为他们的用例托管自己的生产就绪模型推理端点呢?
什么是 Sagemaker 端点?
AWS Sagemaker 端点是 AWS Sagemaker 中的原生工具,用于托管模型推理。它的优势包括:
-
易于配置自动扩展:只需几行代码即可将自动扩展添加到现有端点。
-
零停机更新:Sagemaker 端点的更新默认使用 BlueGreen 部署。
-
灵活性与自定义:Sagemaker 端点可以使用自定义容器。
-
访问 AWS 服务:Sagemaker 端点能够访问 AWS 服务,比如 S3,这可以为处理推理请求时增加额外的步骤提供更大的灵活性。
这有助于节省用户的时间和专业知识,尤其是那些只想部署模型而不希望考虑在生产规模上管理它所需工程工作的用户,将原本可能需要数天/数周的工作转化为几分钟的工作。
Sagemaker 端点如何工作?
在背后,Sagemaker Endpoints 使用基于Sagemaker-Inference-Toolkit库的特殊推理容器来托管模型 API。这些容器提供了一种快速简便的方法来运行推理,无需构建自己的容器镜像,并且支持许多不同的框架,从使用 scikit-learn 容器的简单 scikit-learn 模型,到使用TensorRT-LLM容器的复杂 LLM(以及它们的 AWQ/GPTQ 量化变体)。
然而,GGUF 和 EXL2 量化仍然需要重度定制的推理框架。幸运的是,Sagemaker 提供了使用自定义容器的灵活性,并且 Sagemaker Endpoints 使这一过程变得非常简单。只需记住几个细节即可使其工作:
-
容器必须监听 8080 端口。
-
容器必须响应/ping 和/invocations
-
容器将通过‘docker run
serve’命令运行,容器预计将使用 ENTRYPOINT 而不是 CMD
-
通过指定包含模型工件的 tar.gz 文件的 S3 路径,将模型工件引入‘/opt/ml/model’目录。这发生在容器运行时之前。

由作者提供的自定义 Sagemaker 容器要求的视觉表示,灵感来源于docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html
为开源推理引擎定制
上图表示一个预打包了 Sagemaker-Inference-Toolkit 的容器。为了使用我们自己的推理引擎,我们可以简单地将预打包的包替换为我们自己的自定义包。
例如,我们策划的一个自定义容器使用户能够通过使用 Abetlen 的Llama-cpp-python作为推理引擎来托管 GGUF 模型。这个库是开源的,并且遵循宽松的 MIT 许可证。
在我们的 dockerfile 中,我们只需要写几行代码以符合 sagemaker 端点的要求:
-
将监听端口更改为 8080
-
为/ping 和/invocations 添加路由
-
在 ENTRYPOINT 中运行
6. 在 AG Sagemaker 中托管量化模型
使用自定义容器,在 AG 的 Sagemaker 环境中托管量化 LLM 仅需几行代码。
# Code will vary depending on how you have curated your own custom container.
from sagemaker.model import Model
endpoint_name = "<Name of endpoint>"
image_uri = "<ECR Image URI to Llama-cpp-python Image>"
model_artifact_location = "<S3 Path to Model Artifacts>"
model = "<Path to model file>"
# All other ENV variables defined in documentation
model_endpoint = Model(
image_uri = image_uri,
model_data = model_artifact_location,
role = role,
env = {
"MODEL": model_file_path_in_container,
"N_GPU_LAYERS": "999",
"INVOCATIONS_ROUTE": "/v1/completions"
}
)
model_endpoint.deploy(
initial_instance_count=1,
instance_type="ml.g4dn.xlarge",
endpoint_name=endpoint_name
)
就是这样,简短而简单。有了这个,我们的用户可以专注于开发他们的 LLM 使用案例,而不被幕后复杂的工作所困扰。
7. 基准测试
以下是基于单次查询推理的每秒生成的平均令牌数基准,测试了 5 次,共计 30 个提示,即每个候选者基于 150 次测试的平均值。所有测试中,我们使用了 CodeLlama 模型,因为它有多种大小可用,即 7、13、34 和 70 亿个参数。我们测试了量化和未量化的模型,使用 Transformers 作为基准,因为它通常是运行未量化模型的常见方式。
以下是基准测试的规格:

基准规格,作者提供
请注意,ExllamaV2 指的是推理引擎,而 EXL2 是 ExllamaV2 的本地量化格式,在这种情况下,ExllamaV2 也支持 GPTQ 的推理。ExllamaV2 将仅使用 Q4_0 进行基准测试,因为一些 Q8_0 量化在 HuggingFace 上找不到。
通过 Transformers(基准)的未量化
BF16:

Transformers BF16 推理基准,作者提供
以下测试中的所有倍数都是基于使用 Transformers 作为基准。例如,GPTQ 7b Q4_0 模型在“每秒令牌”列中有一个“(3.42x)”的倍数,这意味着 GPTQ 在 7b 模型上比 Transformers 基准快 3.42 倍。
通过 Llama-cpp-python 的 GGUF
GGUF 可以支持在较旧的 Nvidia T4 设备上托管,来自 g4dn 实例系列,因此我们增加了额外的测试,优化成本时尽可能使用 g4dn 实例类型:
Q4_0

GGUF Q4_0 推理(最小化成本)基准,作者提供
Q8_0

GGUF Q8_0 推理(最小化成本)基准,作者提供
使用较新的 Nvidia A10g 来自 g5 实例系列:
Q4_0

GGUF Q4_0 推理基准,作者提供
Q8_0

GGUF Q8_0 推理基准,作者提供
在每一个案例中,GGUF 都能以更低的成本或相同的价格运行模型,但速度显著更快。例如,Q8 13B 模型比基准快 74%,但成本仅为基准的五分之一!
GPTQ — 通过 ExllamaV2
ExllamaV2 仅支持在较新的 Nvidia A10g 上托管,来自 g5 实例系列,而不支持 g4dn 实例系列。
Q4_0

GPTQ Q4_0 推理基准,作者提供
GPTQ 在 ExllamaV2 上将性能提升带到了全新的水平,对于每个量化的模型大小,Q4_0 在速度上超过了基准的三倍多。
AWS Sagemaker Jumpstart
AWS 本身也提供一种名为 JumpStart 的服务,允许通过几次点击部署预训练模型。这些 AWS Sagemaker 容器实现了 Sagemaker 推理工具包,并预装了多种推理引擎。在这种情况下,使用的是 HuggingFace 的文本生成推理(TGI)框架作为推理引擎。
BF16:

AWS Jumpstart TGI BF16 推理基准测试,作者
请注意,13B 比 7B 更快。这是因为 TGI 容器能够利用更多的 GPU 内存来提高推理速度。在像 34B 和 70B 这样更大的参数规模上,使用 AWS Sagemaker Jumpstart 与 TGI 容器,甚至可以超越 ExllamaV2 上的 GPTQ。
8. 结论
量化为 LLM 提供了显著的好处,因为它减少了托管所需的内存。内存需求的减少提高了推理速度并降低了成本。通过较高位数的量化,可以几乎零损失地提高输出质量,显著提高速度并降低成本——本质上是对使用未量化 LLM 的帕累托改进。
在 AWS Sagemaker Endpoints 基础上,AG 提供的辅助功能使得整个公共部门的机构能够轻松访问创建和管理生产就绪的量化开放 LLM API 的能力。通过简化部署量化大语言模型的过程,AG 大大降低了生成高效且成本效益高的 GenAI 应用程序的门槛,使政府机构能够专注于创新和开发对公共利益有益的技术。
与此相辅相成的是,AG 将继续推进其 GenAI 事业,通过安全的跨云集成,提供访问闭源模型(如 Azure OpenAI 和 VertexAI 的 Gemini),同时与我们现有的 AWS Bedrock 服务结合。通过强大而全面的产品,AG 使用户能够根据其用例优化模型,从而在公共部门实现更好、更快和更便宜的 GenAI 应用。
参考文献
1 Sau Sheong,《与 AI 编程 — 开放 LLM》(2024),sausheong.com/programming-with-ai-open-llms-28091f77a088
[2] S. Stoelinga,《为 LLM 提供服务时计算 GPU 内存》(2023),https://www.substratus.ai/blog/calculating-gpu-memory-for-llm/
[3] M.C. Neves,《什么是量化 LLM?》(2023),https://www.tensorops.ai/post/what-are-quantized-llms
[4] D. Patel,《神经网络量化与数值格式从基础原理出发》(2024),www.semianalysis.com/p/neural-network-quantization-and-number
[5] W. Huang,《低比特量化 LLAMA3 模型有多好?一项实证研究》(2024),arxiv.org/pdf/2404.14047
[6] Oobabooga,《GPTQ、AWQ、EXL2、q4_K_M、q4_K_S 和 load_in_4bit 的详细比较:困惑度、VRAM、速度、模型大小和加载时间》(N.A.),https://oobabooga.github.io/blog/posts/gptq-awq-exl2-llamacpp/
[7] Sgsdxzy,《选择量化和引擎指南》(2024),https://www.reddit.com/r/LocalLLaMA/comments/1anb2fz/guide_to_choosing_quants_and_engines/
[8] 亚马逊网络服务,《使用您自己的推理代码与托管服务》(N.A.),docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html
应用 Python 编年史:Pydantic 简介
无论你是数据工程师、机器学习工程师还是 Web 开发者,你都应该熟悉这个工具。
·发表于Towards Data Science ·阅读时间:11 分钟·2024 年 7 月 25 日
--

夕阳如何照耀 Pydantic 用户。图片由弗拉基米尔·季莫费耶夫提供,授权给伊利亚·拉扎列维奇。
在许多使用场景中,Pydantic 几乎可以无缝地适配。数据处理等任务也能从使用 Pydantic 中受益。不过,它也可以在 Web 开发中用于解析和结构化预期格式的数据。
今天的主题是定义几个痛点,并展示如何使用 Pydantic。我们从最常见的使用场景开始,那就是数据解析和处理。
假设我们有一个包含十几列和几千行的 CSV 文件。在数据分析中,常见的情况是将这个 CSV 文件加载到 Pandas DataFrame 中,然后开始处理它。通常,你会开始检查数据和列的数据类型,删除其中的一些列,并创建新的列。这个过程是基于你对数据集的先前知识。然而,这对其他人来说并不总是透明的。他们要么需要打开 CSV 文件(或任何其他数据源),要么需要浏览代码,搞清楚哪些列正在使用或被创建。这对于数据分析和研究的初始步骤是没问题的。然而,一旦数据集被分析完,并且我们准备创建一个数据管道来加载、转换并用于分析或机器学习目的时,我们就需要一种标准化的方法,确保数据集和数据类型符合预期格式。这就是为什么我们需要一个可以声明或定义这些的库。有一些库可以做到这一点,它们大多数是开源的,而 Pydantic 作为开源库,也已经进入了不同的框架,并在各种用例中得到广泛接受。
好的,让我们开始吧。
Python — 类型提示
在我们深入我之前提到的示例之前,我想先讲一些 Python 的基础知识。
Python 在其各个版本中引入了类型提示。什么是类型提示,为什么我们需要它呢?嗯,正如我们所知,Python 是一种动态类型的脚本语言。这意味着数据类型是在运行时推断的。这在工程师能够更快地编写代码上有其好处。坏处是,直到运行代码时,你才会发现类型不匹配。那时,可能有些晚了,无法迅速修复错误。由于 Python 仍然是一种动态类型语言,因此引入了所谓的“类型提示”,其目的是弥合这一差距,工程师可以使用它来通知读者和 IDE 预期的数据类型。
示例:
def add(a, b):
return a + b
add(4, 3)
> 7
add(.3, 4)
> 4.3
add('a', 'b')
> 'ab'
这是一个简短的示例,展示了一个定义好的函数如何在多个用例中使用,其中一些用例并非其编写者所预见。对于那些足够坚持的人,你将不得不引入许多限制条件,以确保代码按预期方式使用。
类型提示是怎样的呢?
def add(a: int, b: int) -> int:
return a + b
add(4, 3)
> 7
add(.3, 4)
> 4.3
add('a', 'b')
> 'ab'
这个也有效!为什么?因为这仍然是“类型提示”,而不是“类型强制”。如前所述,它是用来“通知”读者和“用户”预期的使用方式。代码的“用户”之一是 IDE,而你选择的 IDE 应该能够识别并在你尝试绕过数据类型声明时发出警告。
为什么我们要描述这些内容呢?因为 Pydantic 是建立在这种类型提示基础上的。它使用类型提示来定义数据类型和结构,并进行验证。
Pydantic — 第一步
正如我之前提到的,Pydantic 用于验证数据结构和数据类型。你可以通过四种方式来使用它。今天我将介绍其中两个最重要的:
-
validate_call用于基于类型提示和注解来验证函数调用, -
BaseModel用于通过类定义来定义和验证模型。
Pydantic — validate_call
所以,没有比直接开始新事物更好的方法了。这就是我们开始学习 Pydantic 的方式。
在你能够使用它之前,你需要先安装它:
pip install pydantic
为了清晰起见,我在这里也标注一下 Python 和 Pydantic 的版本:
python version: 3.10.5
pydantic version: 2.5.3
然后,你需要创建一个新的 Python 项目,创建你的第一个 Python 脚本,导入 Pydantic,并开始使用它。第一个示例是修改我们之前的函数,并使用 Pydantic 来确保它按预期的方式使用。示例:
import pydantic
@pydantic.validate_call
def add(a: int, b: int) -> int:
return a + b
# ----
add(4, 4)
> 8
# ----
add('a', 'a')
> ValidationError: 2 validation errors for add
0
Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='a', input_type=str]
For further information visit <https://errors.pydantic.dev/2.5/v/int_parsing>
1
Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='a', input_type=str]
For further information visit <https://errors.pydantic.dev/2.5/v/int_parsing>
# ----
add(.4, .3)
> ValidationError: 2 validation errors for add
0
Input should be a valid integer, got a number with a fractional part [type=int_from_float, input_value=0.4, input_type=float]
For further information visit <https://errors.pydantic.dev/2.5/v/int_from_float>
1
Input should be a valid integer, got a number with a fractional part [type=int_from_float, input_value=0.3, input_type=float]
For further information visit <https://errors.pydantic.dev/2.5/v/int_from_float>
# ----
add('3', 'a')
> ValidationError: 1 validation error for add
1
Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='a', input_type=str]
For further information visit <https://errors.pydantic.dev/2.5/v/int_parsing>
# ----
add('3', '3')
> 6
# ----
add('3', '3.3')
> ValidationError: 1 validation error for add
1
Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='3.3', input_type=str]
For further information visit <https://errors.pydantic.dev/2.5/v/int_parsing>
需要澄清几点:
-
validate_call被用作装饰器。它本质上是包裹在声明的函数周围,增加了在函数定义时以及调用时可以执行的额外逻辑。在这里,它用于确保传递给函数调用的数据符合预期的数据类型(提示)。 -
验证过的函数调用会在你以非预期的方式使用函数时抛出
ValidationError。这个错误信息很详细,能说明为什么会抛出这个错误。 -
根据慈悲原则,Pydantic 尝试推断你的意思并进行类型强制转换。这可能导致传递给函数调用的字符串值被隐式转换为预期的类型。
-
类型强制转换并不总是可能的,在这种情况下,
ValidationError会被抛出。
不知道什么是 Python 的 decorator 函数?阅读我之前关于这个主题的文章之一:
阅读标题后,你可能会问自己:“Python 中的函数是高级...”
[towardsdatascience.com
那么默认值和参数提取呢?
from pydantic import validate_call
@validate_call(validate_return=True)
def add(*args: int, a: int, b: int = 4) -> int:
return str(sum(args) + a + b)
# ----
add(4,3,4)
> ValidationError: 1 validation error for add
a
Missing required keyword only argument [type=missing_keyword_only_argument, input_value=ArgsKwargs((4, 3, 4)), input_type=ArgsKwargs]
For further information visit <https://errors.pydantic.dev/2.5/v/missing_keyword_only_argument>
# ----
add(4, 3, 4, a=3)
> 18
# ----
@validate_call
def add(*args: int, a: int, b: int = 4) -> int:
return str(sum(args) + a + b)
# ----
add(4, 3, 4, a=3)
> '18'
这个例子的收获:
-
你可以注解变量的可变参数声明(*args)。
-
即使你正在注解变量的数据类型,默认值仍然是一个可选项。
-
validate_call接受validate_return参数,它使得函数返回值也进行验证。在这种情况下,也会应用数据类型强制转换。默认情况下,validate_return设置为False。如果保持默认值不变,函数返回的值可能与类型提示中声明的值不一致。
如果你想验证数据类型,但同时也限制该变量可以取的值怎么办?示例:
from pydantic import validate_call, Field
from typing import Annotated
type_age = Annotated[int, Field(lt=120)]
@validate_call(validate_return=True)
def add(age_one: int, age_two: type_age) -> int:
return age_one + age_two
add(3, 300)
> ValidationError: 1 validation error for add
1
Input should be less than 120 [type=less_than, input_value=200, input_type=int]
For further information visit <https://errors.pydantic.dev/2.5/v/less_than>
这个示例展示了:
-
你可以使用
Annotated和pydantic.Field不仅验证数据类型,还可以添加 Pydantic 用来约束变量值和格式的元数据。 -
ValidationError再次非常详细地说明了函数调用中哪里出了问题,这非常有帮助。
这里有一个如何验证并约束变量值的示例。我们将模拟一个 payload(字典),在它被验证后,你希望在函数中处理它:
from pydantic import HttpUrl, PastDate
from pydantic import Field
from pydantic import validate_call
from typing import Annotated
Name = Annotated[str, Field(min_length=2, max_length=15)]
@validate_call(validate_return=True)
def process_payload(url: HttpUrl, name: Name, birth_date: PastDate) -> str:
return f'{name=}, {birth_date=}'
# ----
payload = {
'url': 'httpss://example.com',
'name': 'J',
'birth_date': '2024-12-12'
}
process_payload(**payload)
> ValidationError: 3 validation errors for process_payload
url
URL scheme should be 'http' or 'https' [type=url_scheme, input_value='httpss://example.com', input_type=str]
For further information visit <https://errors.pydantic.dev/2.5/v/url_scheme>
name
String should have at least 2 characters [type=string_too_short, input_value='J', input_type=str]
For further information visit <https://errors.pydantic.dev/2.5/v/string_too_short>
birth_date
Date should be in the past [type=date_past, input_value='2024-12-12', input_type=str]
For further information visit <https://errors.pydantic.dev/2.5/v/date_past>
# ----
payload = {
'url': '<https://example.com>',
'name': 'Joe-1234567891011121314',
'birth_date': '2020-12-12'
}
process_payload(**payload)
> ValidationError: 1 validation error for process_payload
name
String should have at most 15 characters [type=string_too_long, input_value='Joe-1234567891011121314', input_type=str]
For further information visit <https://errors.pydantic.dev/2.5/v/string_too_long>
这就是验证函数参数及其返回值的基本方法。
现在,我们将进入 Pydantic 用于验证和处理数据的第二种最重要的方法:通过定义模型。
Pydantic — BaseModel
这一部分对于数据处理目的来说更为有趣,正如你将看到的那样。
到目前为止,我们已经使用validate_call装饰器对函数进行了装饰,并指定了函数参数及其对应的类型和约束。
在这里,我们通过定义模型类来定义模型,在模型类中指定字段、字段类型和约束。这与我们之前做的非常相似。通过定义继承自 Pydantic BaseModel的模型类,我们使用了一个隐藏的机制来进行数据验证、解析和序列化。这样我们就能创建符合模型规范的对象。
这里有一个例子:
from pydantic import Field
from pydantic import BaseModel
class Person(BaseModel):
name: str = Field(min_length=2, max_length=15)
age: int = Field(gt=0, lt=120)
# ----
john = Person(name='john', age=20)
> Person(name='john', age=20)
# ----
mike = Person(name='m', age=0)
> ValidationError: 2 validation errors for Person
name
String should have at least 2 characters [type=string_too_short, input_value='j', input_type=str]
For further information visit <https://errors.pydantic.dev/2.5/v/string_too_short>
age
Input should be greater than 0 [type=greater_than, input_value=0, input_type=int]
For further information visit <https://errors.pydantic.dev/2.5/v/greater_than>
你也可以在这里使用注解,并且还可以为字段指定默认值。我们来看另一个例子:
from pydantic import Field
from pydantic import BaseModel
from typing import Annotated
Name = Annotated[str, Field(min_length=2, max_length=15)]
Age = Annotated[int, Field(default=1, ge=0, le=120)]
class Person(BaseModel):
name: Name
age: Age
# ----
mike = Person(name='mike')
> Person(name='mike', age=1)
当你的用例变得有些复杂时,事情变得非常有趣。还记得我们定义的payload吗?我将定义另一个更复杂的结构,我们将对其进行逐步处理和验证。为了让它更有趣,我们来创建一个 payload,用于查询一个充当我们与 LLM 提供商之间中介的服务。然后我们会对其进行验证。
这里有一个例子:
from pydantic import Field
from pydantic import BaseModel
from pydantic import ConfigDict
from typing import Literal
from typing import Annotated
from enum import Enum
payload = {
"req_id": "test",
"text": "This is a sample text.",
"instruction": "embed",
"llm_provider": "openai",
"llm_params": {
"llm_temperature": 0,
"llm_model_name": "gpt4o"
},
"misc": "what"
}
ReqID = Annotated[str, Field(min_length=2, max_length=15)]
class LLMProviders(str, Enum):
OPENAI = 'openai'
CLAUDE = 'claude'
class LLMParams(BaseModel):
temperature: int = Field(validation_alias='llm_temperature', ge=0, le=1)
llm_name: str = Field(validation_alias='llm_model_name',
serialization_alias='model')
class Payload(BaseModel):
req_id: str = Field(exclude=True)
text: str = Field(min_length=5)
instruction: Literal['embed', 'chat']
llm_provider: LLMProviders
llm_params: LLMParams
# model_config = ConfigDict(use_enum_values=True)
# ----
validated_payload = Payload(**payload)
validated_payload
> Payload(req_id='test',
text='This is a sample text.',
instruction='embed',
llm_provider=<LLMProviders.OPENAI: 'openai'>,
llm_params=LLMParams(temperature=0, llm_name='gpt4o'))
# ----
validated_payload.model_dump()
> {'text': 'This is a sample text.',
'instruction': 'embed',
'llm_provider': <LLMProviders.OPENAI: 'openai'>,
'llm_params': {'temperature': 0, 'llm_name': 'gpt4o'}}
# ----
validated_payload.model_dump(by_alias=True)
> {'text': 'This is a sample text.',
'instruction': 'embed',
'llm_provider': <LLMProviders.OPENAI: 'openai'>,
'llm_params': {'temperature': 0, 'model': 'gpt4o'}}
# ----
# After adding
# model_config = ConfigDict(use_enum_values=True)
# in Payload model definition, you get
validated_payload.model_dump(by_alias=True)
> {'text': 'This is a sample text.',
'instruction': 'embed',
'llm_provider': 'openai',
'llm_params': {'temperature': 0, 'model': 'gpt4o'}}
从这个详细的例子中我们可以得到一些重要的见解:
-
你可以使用枚举或
Literal来定义期望的特定值列表。 -
如果你想将模型字段命名为与验证数据中的字段名不同的名称,可以使用
validation_alias。它指定验证数据中的字段名。 -
serialization_alias用于当模型的内部字段名称不一定与序列化模型时你想使用的名称相同时。 -
可以使用
exclude=True将字段排除在序列化之外。 -
模型字段也可以是 Pydantic 模型。在这种情况下,验证过程是递归进行的。这一部分非常棒,因为 Pydantic 在验证嵌套结构时会深入执行验证。
-
在模型定义中未考虑的字段不会被解析。
Pydantic — 用例
在这里,我将展示一些代码片段,展示你如何在日常任务中使用 Pydantic。
数据处理
假设你有需要验证和处理的数据。这些数据可以存储在 CSV、Parquet 文件中,或者例如在 NoSQL 数据库中以文档的形式存储。让我们以 CSV 文件为例,假设你想处理其内容。
这里是 CSV 文件(test.csv)示例:
name,age,bank_account
johnny,0,20
matt,10,0
abraham,100,100000
mary,15,15
linda,130,100000
下面是它如何进行验证和解析的:
from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator
from pydantic import ValidationInfo
from typing import List
import csv
FILE_NAME = 'test.csv'
class DataModel(BaseModel):
name: str = Field(min_length=2, max_length=15)
age: int = Field(ge=1, le=120)
bank_account: float = Field(ge=0, default=0)
@field_validator('name')
@classmethod
def validate_name(cls, v: str, info: ValidationInfo) -> str:
return str(v).capitalize()
class ValidatedModels(BaseModel):
validated: List[DataModel]
validated_rows = []
with open(FILE_NAME, 'r') as f:
reader = csv.DictReader(f, delimiter=',')
for row in reader:
try:
validated_rows.append(DataModel(**row))
except ValidationError as ve:
# print out error
# disregard the record
print(f'{ve=}')
validated_rows
> [DataModel(name='Matt', age=10, bank_account=0.0),
DataModel(name='Abraham', age=100, bank_account=100000.0),
DataModel(name='Mary', age=15, bank_account=15.0)]
validated = ValidatedModels(validated=validated_rows)
validated.model_dump()
> {'validated': [{'name': 'Matt', 'age': 10, 'bank_account': 0.0},
{'name': 'Abraham', 'age': 100, 'bank_account': 100000.0},
{'name': 'Mary', 'age': 15, 'bank_account': 15.0}]}
FastAPI 请求验证
FastAPI 已经与 Pydantic 集成,因此这一部分会非常简短。FastAPI 处理请求的方式是将它们传递给处理路由的函数。通过将请求传递给一个函数,验证会自动进行。类似于我们在本文开头提到的 validate_call。
用于运行基于 FastAPI 的服务的 app.py 示例:
from fastapi import FastAPI
from pydantic import BaseModel, HttpUrl
class Request(BaseModel):
request_id: str
url: HttpUrl
app = FastAPI()
@app.post("/search/by_url/")
async def create_item(req: Request):
return item
结论
Pydantic 是一个非常强大的库,具有很多机制,适用于多种不同的使用场景和边缘情况。今天,我讲解了如何使用它的最基础部分,并且在下面提供了参考资料,供那些不畏困难的人查阅。
去探索一下吧。我相信它会在不同方面对你有所帮助。
参考文献
应用统计矩和矩生成函数
动机、定义及应用
·发布于Towards Data Science ·9 分钟阅读·2024 年 1 月 22 日
--

图片由Milad Fakurian提供,来自Unsplash
介绍性和基础的概率与统计课程通常会在随机变量的上下文中讨论期望和方差。少数课程甚至会讨论更高阶的矩,如偏度和峰度,并结合矩生成函数的概念。
如果你曾参加过我的本科概率与统计课程,你可能和我当时一样感到困惑。我们中的许多人不明白为什么我们不能像计算前两个矩(期望和方差)那样,直接计算更高阶的矩。
现在我了解得更多了,我会说:“孩子们,试着解决积分题吧”。在本文中,我清楚地阐述了在正态随机变量的背景下,矩生成函数的动机和定义,以及标准布朗运动的矩。
本文分为以下几个部分。
-
期望、方差与统计矩的定义
-
矩生成函数与正态随机变量
-
通过矩生成函数推导矩
期望、方差与更高阶矩的定义
近似具有多元输出的随机函数
一种新颖的生成式机器学习模型训练方法
·发表于 Towards Data Science ·阅读时长 21 分钟·2024 年 9 月 4 日
--

Pin Movement Training — 作者图像
您可以通过克隆 github.com/narroyo1/pmt 来重现本文中的实验。
本系列之前的文章名为 近似随机函数,介绍了一种新颖的方法来训练生成式机器学习模型,能够近似任何具有单一输出变量的随机函数。从现在开始,我将简称这种方法为Pin Movement Training,或简称为PMT。这是因为它通过将大头针插入布料并移动它们的类比来加以说明。
该方法是针对具有任意输入数量的X函数,但仅有单一输出Y而描述的。本文将对具有任意输出数量的函数对PMT进行推广。将提供该方法的总结,并且应该足以理解它是如何工作的,但如果您想要更深入的描述,可以阅读之前的文章。
该推广方法由于下文所述的原因,利用了类似于自编码器的架构。由于这一点,并且由于均匀采样分布可能对许多应用更加方便,我认为此方法是变分自编码器的有效替代方案。
原始方法回顾
假设我们想使用神经网络来近似定义为 𝑓(𝑥) → 𝑌 的随机函数,其中x是X中任意维度的输入,而Y是一个一维随机变量。
我们要做的第一件事是引入一个次级输入Z,它定义为一个均匀分布的随机变量,范围为[Zₘᵢₙ, Zₘₐₓ]。这是为了给原本确定的系统引入随机性。这将给我们带来一个神经网络,定义为 𝑓𝜃(𝑥,𝑧∼𝑍) → 𝑌,其中𝜃代表网络训练后的权重。
现在让我们可视化任意给定点 𝑥′, 𝑠.𝑡. x′ ∈ X。对于这个x',我们希望将整个范围[Zₘᵢₙ, Zₘₐₓ]映射到Yₓ′。也就是说,f(x′, Zₘᵢₙ) 应尽可能接近 min(Yₓ′),而 f(x′, Zₘᵐₐₓ) 应尽可能接近 max(Yₓ′)。此外,中点 f(x′, mid(Z)) 和 mid(Yₓ′) 应尽可能相似,当然,范围内的其他每个点也应如此(参见图 1)。

图 1 将Z映射到Y — 图片来源:作者
为了实现这一点,我们可以将模型 𝑓𝜃 想象为一块可拉伸的透明布料,其中 X 水平表示,Z 垂直表示。我们还可以设想一个板子,板子上绘制了数据集中的所有数据点,在这个板子上,X 水平表示,Y 垂直表示。然后我们将布料放置在板子上方。
对于每个数据点,我们在布料的垂直中点位置放置一个“针”,该位置为Z或mid(Z)。然后我们比较针的位置与数据点的位置。如果数据点高于针,我们在不取下布料上针的情况下,将针向上移动预定的距离,使其到达板子上的更高位置。这个过程中,针会拉伸或压缩布料。如果数据点低于针,则将针向下移动预定的距离。我们将向上和向下的移动距离相加,并称之为总移动距离。
在处理每个数据点后,如果针最初不在中点位置,总移动距离会朝着实际中点的方向更大。经过足够多次重复这个过程后,针会达到一个接近中点的位置,在这个位置,向上和向下的总移动距离相等,也就是说,针上方的数据点数量与下方的数据点数量相同。请参见图 2,了解这个过程如何稳定。

图 2 将针移向观察点,直到它稳定在中间位置 — 图片来源:作者
现在,如果我们不是将 pin 放在Z的中点,而是将其放在从最小点Zₘᵢₙ到范围[Zₘᵢₙ, Zₘᵐₓ]中 1/3 距离的位置上。并且不是将其上下移动相同的预定距离,而是将其向下移动预定距离的1.5倍,向上移动预定距离的0.75倍。那么这个 pin 将在一个稳定点上停留(上下移动的总距离相等),该点大致位于数据点的1/3处。
这是因为向上移动的距离 * 较高的数据点 = 向下移动的距离 * 较低的数据点 或者 (0.75∗2/3=1.5∗1/3=0.5)。参见图 3,其中展示了此过程如何使Zₘᵢₙ + 1/3和Zₘᵢₙ + 2/3的 pins 稳定。

图 3 将 2 个 pins 移动至观测点直到它们稳定——图源自作者
我们如何通过神经网络实现这种移动? 为了用神经网络移动“布料上的 pins”,我们选择一个在Z中的值(我们称之为z-pin),并通过反向传播,将目标值设定为z-pin加/减预定的距离,操作如下:

利用这一原理,我们可以在Z中均匀选择点,并通过若干周期获得我们需要的映射。即 𝑓𝜃(𝑥,𝑧∼𝑍) → 𝑌。
来自原文的注释
-
在原文中,布料拉伸/收缩的类比指的是用于重塑模型的pins,然而模型定义和训练方法使用了z-samples这一术语来指代同一具体概念。在本篇文章中以及未来的讨论中,这些将被专门称为z-pins。
-
在选择z-pins时,原文总是将它们均匀分布在Z中,并且在每个周期对每个数据点使用相同的位置。然而,这并不是必要的,唯一的要求是z-pins在Z中均匀分布。
-
原文文章会在每个数据点使用多个z-pins。但这并非必要,每个数据点只需选择一个z-pin即可。在本篇文章中,所有实验将在每个数据点每个周期选择一个单一的z-pin。
针对多个输出的推广
在重新审视原始方法以应对单一输出后,接下来我们将讨论针对多个输出所需的变化。
重新定义 Z 空间
让我们定义Z,即我们从中选择z-pins的采样空间。在原文中,Z被定义为一个单一维度的范围,简单地通过下限和上限描述为[Zₘᵢₙ, Zₘᵐₓ]。然而,在广义方法中,为了能够处理多维输出Y,Z必须也在多个维度中进行定义(但需要注意的是,Z和Y的维度数量不必相同)。
理论上,它可以是任何有界的 n 维空间,但因为后续计算标量更为简便,正如你将看到的那样,我选择使用一个可以通过原点 O、半径 R 和维度 N 定义的 超球面(见 图 4)。

图 4 三维超球面 Z-空间 — 作者图片
现在让我们定义一些与 Z 相关的概念,这些概念将在接下来的讨论中派上用场。
-
z-针:这些是Z中的均匀采样点。它们可以定义为一个N维向量,形式如下:zₚᵢₙ = (z₀, z₁, …, zₙ),其中 z₀, z₁, … 是Z中的坐标。
-
z-方向:一个 z-方向 是一个位于原点 O 的单位向量,定义如下:z-dir = O + (ž₀, ž₁, …, žₙ)
-
z-线:一个 z-线 是一个位于 Z 中的直线,连接 Z 中的任意两个点。我们将其定义为一个以 z-针 为起点,并具有 z-方向 的直线,包含所有位于 Z 中的点,形式如下:zₗᵢₙₑ = zₚᵢₙ + z-dir 使得 ∀z ∈ zₗᵢₙₑ : z ∈ Z
模型
进入多维 Z 空间引入了一个重要的挑战。在一维的 Z 和 𝑌 空间中,判断所选的 z-针 投影,即 𝑓𝜃(x, zₚᵢₙ) 是否大于或小于观测到的数据点,以决定将其移动到哪个方向,变得非常简单。在一维中,“大于”可以简单地转换为Z中的“大于”,而 z-针 可以简单地向上移动。这是因为我们只是将一条线映射到另一条线。
但是在多维 𝑌 和 Z 的情况下,无法假设这两个空间具有相同的形状或相同的维度数量,这意味着为了根据数据点与z-针的关系决定移动方向,有必要将数据点从 𝑌 映射到 Z。这意味着除了训练函数 𝑓𝜃 来生成 𝑌 中的值外,我们还需要训练一个逆函数 𝑓𝜃⁻¹ 来将数据点映射到 Z。这一事实使得我们的模型架构发生了变化,变成了如下所示:

图 5 模型架构 — 作者图片
模型的左侧允许我们将 𝑌 中的点映射到 Z。模型的右侧允许我们通过在 Z 中采样点来生成 𝑌 中的随机样本。
你可能注意到,这个架构与普通自编码器的架构相似,的确如此。这一优势在于,它使得该方法对于学习有界且均匀分布的潜在表示非常有用。
方法
在定义了我们需要的所有概念后,我们可以继续讨论如何在多维空间中进行针的移动。
将数据点映射到 Z
第一步是使用逆函数𝑓𝜃⁻¹(或使用自编码器术语中的编码器)将批次中的所有数据点从𝑌空间映射到Z空间。我们将原始数据点称为y 数据,将映射后的数据点称为z 数据。

图 6 将数据点映射到 2 维Z 空间 — 图像来源于作者
选择 z-针
接下来,我们必须选择一些z-针。为了做到这一点,我们首先选择均匀采样的z-方向,每个数据点选择一个。最简单的方法是选择一个超球面上的随机点,其维度与Z相同。然后,我们使用选定的z-方向,并将它们平移,使得前一步中映射的z 数据作为原点。这就得到了如图 7所示的一些z 线。

图 7 在 2 维Z 空间中选择随机z 线 — 图像来源于作者
一旦我们得到了我们的z 线,接下来就可以在这些线中随机选择点,这些点就是我们的z-针。图 8展示了这种情况的示意图。

图 8 在 2 维Z 空间中选择随机z-针 — 图像来源于作者
为了使方法有效,对于任何给定的z 线,它在Z中的每个映射数据点z 数据出现的概率应该是相等的,否则在计算运动标量中的方程将无法成立。
给定一个 2 维的Z空间,并且对于其中的任何一个z 线,可以将其视为一个最小宽度为𝜖的线段,使其看起来像一个长矩形,类似于图 8中的z 线。任意给定的𝑧出现在其中的概率就是这个“薄”z 线的面积与Z面积之比。

因为这个“薄”z 线是矩形的,所以其任意一个最小长度为𝛿的线段𝑠在其长度上具有相同的面积,因此任意一个𝑧出现在该线段的概率是相等的。

同样,任意给定的𝑧出现在这个“薄”z 线中的概率,选择该“薄”z 线的z-方向的概率是恒定的,因为z-方向是通过均匀分布选择的。

根据方程(2)和(3),我们得到任意一个𝑧出现在给定的z 线的任意线段上的概率,并选择相同的z-方向,而且对于满足上述要求的每个线段,这个概率都是相同的。

这个概率与𝑧在z 线中的位置无关,因此在任何z 线中的分布都是均匀的。
计算目标值
在选择了z-pins之后,我们可以继续计算目标值(或z-targets)用于反向传播。我们所需要做的就是将运动常数𝑀加到每个z-pin上,方向是映射数据点𝑧-𝑑𝑎𝑡𝑎所在的方向。

图 9 显示了如何计算z-targets。

图 9 计算z-targets 在 2-D Z-space 中 — 图片来源:作者
计算运动标量
运动标量的计算方法类似于原始一维方法中的计算方式。
让我们首先通过图示一个z-line和一个z-pin以及一些映射的数据点𝑧𝑑𝑎𝑡𝑎,如同在图 10中看到的那样。

图 10 计算标量 — 图片来源:作者
设a为z-pin到z-line一端的距离,b为到另一端的距离。并且设前侧的数据点数量为a',后侧的数据点数量为b'。我们的目标是使数量a'与距离a成正比,b'与b成正比,即 𝑎:𝑏::𝑎′:𝑏′。
接下来,我们将把𝛼称为在长度为a的侧面上应用于z-pin的运动标量。而我们将把𝛽称为在长度为b的侧面上应用于z-pin的运动标量。
我们还将T称为总运动量,它是将z-pin沿每个数据点的侧面移动一个常数运动量M并乘以该侧的标量的总和。

我们希望当 𝑎′/(𝑎′+𝑏′)≈𝑎/(𝑎+𝑏)∧𝑏′/(𝑎′+𝑏′)≈𝑏/(𝑎+𝑏) 时,T 为 0(即稳定),即当z-pin将数据点按照预定比例分配到两侧时。将T替换为0后,方程(5)给出了以下方程:

现在让我们记住,并非所有的z-lines长度相同,因为它们被由Z定义的超球面所限制,朝向中心的z-lines将比边缘的z-lines长。较长的z-lines将表示Z中更大的空间(参见方程(1)),因此它们在运动中的影响应该与其长度成比例。我们希望T与z-line的长度成线性关系,这给我们带来了以下方程:

如果我们将(6)和(7)合并,我们得到标量应具有以下值:

这些方程与原文中的方程相似。
你可能已经注意到,当a或b趋近于0时,这些方程会在边缘发生断裂。
为了解决这个问题,引入了一个最大标量常数 S 来钳制标量。当然,在钳制标量时,我们必须小心调整两侧的值,例如,如果 a 非常小(因此 𝛼 很大),但数据点位于 b 一侧,则标量 𝛽 也必须进行调整,否则方程 (5) 将无法成立。
我们首先选择两个标量 𝑚𝑎𝑥(𝛼,𝛽) 中的最大值。然后,我们通过将 S 除以 𝑚𝑎𝑥(𝛼,𝛽) 来计算一个调整值,并将其钳制到 1.0,以确保其始终位于 [0, 1] 的范围内。我们将使用该调整值来防止标量超过 S。最后,如果 a 为 0.0,则 𝛼 和 𝛽 的值分别为 S 和 0.0,如果 b 为 0.0,则反之亦然。这样,我们就得到了修正后的方程式 (8b)。

下面你可以看到与 a 或 b 成比例的标量图像。注意它们在超过选定的 S 之后是如何被钳制的。

图 11 针脚标量钳制,S=5.0 — 图片由作者提供
计算出两个标量后,我们可以通过确定数据点所在的边来选择使用哪个标量。

训练模型
现在所有相关概念都已明确,我们可以继续描述训练算法。
1. 预训练和选择 Z 超参数
所描述的算法假设模型 𝑓𝜃⁻¹ 和 𝑓𝜃 彼此反向匹配。如果我们在进行针脚运动时同时训练这两个模型来使它们相互匹配,可能会导致启动缓慢。因此,已经发现,进行“预训练”阶段,使得我们仅训练 𝑓𝜃⁻¹ 和 𝑓𝜃 以使它们匹配,会更有利。这个阶段本质上是一个普通的自编码器训练。在重构误差达到一个合理低的值之后,算法可以进入主训练阶段。
这个预训练阶段还有一个额外的优势,它使得在完成后更容易定义 Z。在章节 重新定义 Z-space 中提到,Z 是由原点 O 和半径 R 定义的。经过一段时间的预训练后,我们只需要通过逆模型运行一批数据点来计算一组 Z-data。

然后我们取这个集合的平均值,并将其用作原点O。

我们还可以使用 Z-data 到 O 的平均距离作为 R,但是已观察到,调整和调试该值可能会获得更好的结果。

这是有效的,因为在“预训练”阶段后,模型已经找到了一个能够表示数据的区域,因此在其附近定义 Z 很可能会产生较低的重构误差。
2. 针脚运动
为了开始 pin 的移动,我们从训练数据集中选择一批数据y-data = {y-data₀, y-data₁, …, y-dataₙ}并将其映射到 z-data = {z-data₀, z-data₁, …, z-dataₙ},正如在将数据点映射到 Z 中所解释的那样。
下一步是随机选择z-pins集合{z-pin₀, z-pin₁, …, z-pinₙ}(每个数据点一个)的方法,如选择 Z-pins 部分所述。
注意,每个数据点可以选择多个 z-pins 。但是这不是必须的,为了简便起见,我们在实验中只使用一个。
然后我们计算目标值z-targets = {z-target₀, z-target₁, …, z-targetₙ}和标量 s = {s₀, s₁, …, sₙ},如计算目标值和计算移动标量部分所述。
获得z-targets后,我们通过将其传递给𝑓𝜃计算当前模型的预测值,这将给我们:

现在我们已经为损失函数的第一个组成部分准备好了所有内容:

请注意,我们使用的是加权平均绝对误差(WMAE)函数,而不是加权平均平方误差(WMSE)。这是因为后者旨在惩罚较大的差异,而我们将所有的 pin 都移动相同的距离。
3. 重建损失
损失函数的下一个组成部分是我们的模型𝑓𝜃和我们的逆模型𝑓𝜃⁻¹之间的差异。这与变分自编码器和普通自编码器中的重建损失非常相似。我们需要将批量数据点传递给𝑓𝜃⁻¹,获取结果后再传递给𝑓𝜃,然后使用这些结果和原始数据点进行反向传播。

4. 逆重建损失
在定义损失函数的最后一个组成部分之前,我们先解释一下它为什么是必要的。理想情况下,在训练结束时,𝑓𝜃和𝑓𝜃⁻¹都应该是双射的,这意味着Z和Y之间会有严格的一一对应关系。然而,在训练过程中并不能保证这一点,可能会出现Z中的某些区域未能映射到Y中。

图 12 模型和逆模型可能不是双射的 — 作者提供的图片
正如你在图 12中看到的,经过loss-y组成部分的训练后,𝑓𝜃和𝑓𝜃⁻¹在Y上是一致的。即∀y ∈ Y, 𝑓𝜃⁻¹(𝑓𝜃(𝑦)) ≈ y。然而,并不是所有的Z都被使用,一些Y中的点映射到了它之外。这是一个问题,因为假设将z-pins移动到一个位置,这个位置会映射到一个Y中的点,而𝑓𝜃和𝑓𝜃⁻¹都能一致,这个假设被打破了。
图 12展示了可能发生的两个问题。布料中的“折叠”发生在Z中的两个或更多点映射到Y中的同一点时。发生“越界”时,Z中的一个点映射到Y之外的点。
为了解决这个问题,我们向损失函数中添加了第三个组件:

这样做的目的是使𝑓𝜃和𝑓𝜃⁻¹在Z方面保持同步,方法是选择Z中的随机点,而不是使用训练集中的数据点。
请注意,对于重建损失和逆重建损失,我们简单地使用均方误差(MSE)。
5. 损失函数
现在我们已经有了损失函数的所有组件,剩下的就是为它们定义权重,我们将这些权重命名为𝛾-𝑝、𝛾-y和𝛾-𝑧。我们可以将(10)、(11)和(12)结合起来,像这样定义损失函数:

剩下的就是对损失进行反向传播。
测试模型
在原始论文中,我们使用了目标 1 和目标 2 测试,测量了z-pins之间的数据点密度,并将其与测试数据集的密度进行了比较。然而,在多维空间中,这种方法并不实用,因为z-pins之间的空间数量会迅速增大。
原始论文还使用了地球搬运工距离(EMD)作为模型性能的指标。对于多维PMT,我们将使用EMD来衡量模型的准确性。我们将通过将训练数据集中的数据点与PMT模型生成的数据点进行比较来定义 EMD 误差。

为了估计最低的 EMD 误差是多少,我们还将通过将训练数据集中的数据点与测试数据集中的数据点进行比较来计算一个基准 EMD。

这为我们提供了一个基准,可以用它来与 E-emd 进行比较,从而衡量模型的准确性。
与变分自编码器的比较
与PMT最相似的生成模型是变分自编码器(VAE)。它具有几乎相同的神经网络架构,并且既是生成模型又是潜在表示映射器。两者之间的最大区别在于,VAE中的源分布是无界的(高斯分布),而PMT中的源分布是有界的(均匀分布)。
然而,实验表明,无论是有界还是无界目标分布,PMT都优于VAE。此外,PMT中的重建误差显著低于VAE。其原因可能在于,损失函数的各个组件在PMT中相互协作,而在VAE中则是相互竞争。而且,由于目标分布是均匀的,Z中数据点之间的间距可以更大。
另一个区别是,PMT有更多的超参数,包括𝑆(最大标量)、𝛾-𝑝(针脚移动权重)、𝛾-𝑦(重建损失权重)、𝛾-𝑧(反向重建损失权重)和𝑀(运动常数),而VAE的超参数仅为kld 权重。这可能会使得PMT的训练更加困难。
最后,PMT每个周期的训练时间比VAE长,这因为需要进行一次传递来计算z-targets,而且损失函数有一个附加组件(见公式(12))。
实验
现在,我将在多个数据集上尝试该模型。为了方便绘制,下面的实验将不包含X输入。
由于与VAE的相似性,每个实验将使用PMT和VAE模型进行比较。在每个实验中,两个模型将采用相同的神经网络架构。
你可以在github.com/narroyo1/pmt找到源代码和重现下面实验所需的一切。
多个数据块
我将尝试的第一个数据集是使用make_blobs()从sklearn库生成的。顾名思义,它生成若干个高斯数据块,是测试PMT在无界数据集上表现的一个良好数据集。

图 13a 生成的数据 — 作者提供的图像


图 13b PMT 训练动画 /图 13c VAE 训练动画 — 作者提供的图像


图 13d EMD 误差图/图 13e 重建损失图 — 作者提供的图像
图 13a展示了由make_blobs()函数生成的测试数据。图 13b和图 13c分别展示了PMT和VAE训练方法的动画。
图 13d展示了计算的EMD误差(𝐸-𝑒𝑚𝑑)图,分别为PMT、VAE和基准值(𝐵-𝑒𝑚𝑑)。正如你所看到的,PMT的𝐸-𝑒𝑚𝑑比VAE的更接近𝐵-𝑒𝑚𝑑,这意味着其性能更好。
图 13e展示了PMT和VAE的重建误差图。正如你所看到的,PMT的重建误差比VAE低一个数量级。
方形与另一个方形
第二个数据集相当简单。我们只需要一个外部的方形区域,里面均匀分布着数据点,再加上一个内嵌的方形区域,里面同样是均匀分布的数据,但密度更大。这将帮助我们测试具有尖锐细节的非高斯分布。

图 14a 生成的数据 — 作者提供的图像


图 14b PMT 训练动画/图 14c VAE 训练动画 — 图像来源:作者


图 14d EMD 图/图 14e 重建损失图 — 图像来源:作者
图 14d 显示了 EMD 错误图,你可以从中看到 PMT 超越了 VAE。
图 14e 显示了重建误差值,你可以看到 PMT 的重建误差比 VAE 低两个数量级。
人类行为
下一个数据集由人体运动传感器数据构成,这些数据是通过进行几种体育活动获得的。它来源于 移动健康人体行为分析数据集¹。这个数据集具有 3 个维度,而不像之前的数据集只有 2 个维度。

图 15a 测试数据 — 图像来源:作者


图 15b PMT 训练动画/图 15c VAE 训练动画 — 图像来源:作者


图 15d EMD 图/图 15e 重建损失图 — 图像来源:作者
图 15d 显示了 EMD 错误图,再次证明 PMT 超越了 VAE。
图 15e 显示 PMT 的重建误差比 VAE 低一个数量级。
MNIST
最后是著名的 MNIST 数据集²。正如你所知,它包含了人类书写的数字的位图,而任务是生成看起来像手写数字的新的数据点。这个数据集很有趣,因为它具有大量的输出维度(784)和 4 维的潜在空间。



图 17a PMT 原始数据/图 17b PMT 重建/图 17c PMT 生成的样本 — 图像来源:作者



图 16c VAE 原始数据/图 16e VAE 重建/图 16e VAE 生成的样本 — 图像来源:作者

图 16g 重建损失图 — 图像来源:作者
由于输出维度过多,计算 EMD 错误图非常困难(而且没有指示性意义),因此这个数据集没有 EMD 错误图。
图 16b 绘制了重建误差图,再次证明 PMT 的误差低于 VAE 的误差。
结论与未来的方向
用单一输出逼近随机函数对于预测单值分布(如温度或市场值)非常有用。但产生多个输出的能力使得该方法适用于多种应用场景,如模拟和生成任务。
本文描述的多输出方法已经证明,在实验数据集中,它能够在概率相似度和重建方面优于 VAE。我相信它们在各种现实世界应用中也会产生更好的结果。
未来,我希望在更高维度的数据集上继续测试PMT,以进行生成任务,比如时尚 MNIST和CelebA。为此,还需要尝试深度网络和卷积神经网络(CNN)。
如果有任何问题或建议,欢迎随时联系我。
https://www.kaggle.com/datasets/gaurav2022/mobile-health
CC0 公共领域 creativecommons.org/publicdomain/zero/1.0/
[2] MNIST 手写数字数据库
[## MNIST 手写数字数据库,Yann LeCun,Corinna Cortes 和 Chris Burges
手写数字 Yann LeCun,纽约大学 Courant 研究所 Corinna Cortes,Google 实验室,纽约 Christopher J.C. Burges…
yann.lecun.com](http://yann.lecun.com/exdb/mnist?source=post_page-----ffefc7099a90--------------------------------)
MIT choosealicense.com/licenses/mit/
AI 代理能否完成你在应用上的日常任务?
在一个由应用和人组成的世界中对编码代理进行基准测试
·发表于 Towards Data Science ·阅读时间:7 分钟·2024 年 7 月 28 日
--
想象一个世界,在这个世界中,AI 代理可以充当你的个人助手,像在亚马逊上设置退货或根据你的邮件取消会议等任务。这要求代理能够在复杂的工作流中与应用程序交互操作,而迄今为止,还没有一个很好的方法来评估这样的代理。直到现在。
🤖 1. 个人应用的编码代理
随着底层 AI 模型的不断改进,AI 助手(例如我们手机上的助手)也在不断进步。几年前,它们在回答简单的事实性问题时还存在困难。今天,它们已经开始达到可以代表我们操作应用程序执行基本任务的程度。例如,最近的 GoogleIO 和 Apple WWDC 事件,正是关于 AI 助手作为自主代理为我们工作这一愿景。
在未来,它们将能够在我们的应用程序上自主完成更复杂的任务。例如,你可以说:“嘿,我的一些同事通过电子邮件取消了会议,请删除我对应的手机提醒。” 代理会自主地检查你的邮箱,弄清楚哪些同事取消了会议,然后打开日历应用,确定哪些会议是和这些同事的,并将其取消。

包含像亚马逊、Venmo、Gmail 等应用的日常任务示例
AI 模型可以通过交互式编写代码并调用API来解决这类任务。API 使代理能够在应用程序上执行基本操作,代码 使代理能够将这些操作组织成复杂的逻辑和控制流程,而交互 使代理能够探索用户账户并根据代码执行结果进行适应。
请看下图中的示例,代理的任务是启动一个播放列表,确保其中的歌曲足够覆盖用户今天的运动时长。为此,代理首先需要编写代码调用 SimpleNote API(第一个代码块),找到并“读取”(打印)包含运动计划的笔记。只有在进行这次交互以观察笔记的结构后——看到时长按天列出——代理才能编写必要的代码(第二个代码块),这包括查找今天的星期几并提取相关时长。为了选择播放列表,代理必须编写丰富的代码,使用 for 循环和其他控制流来遍历播放列表,计算播放列表的时长,并播放一个涵盖运动时长的列表(第三个代码块)。

一个代理代表用户通过交互式编写包含各种应用程序 API 调用的丰富代码来解决任务。
现在我们知道代理如何完成这类任务,问题是:
我们如何开发和基准化这样的编码代理,以完成各种应用中的日常数字任务?
为此,我们需要(i)一个丰富、稳定且可复现的执行环境,让代理可以通过代码和 API 与许多日常应用进行交互,(ii)需要 API 调用和丰富交互式编码的复杂任务,以及(iii)一个可靠的评估框架。
现有的基准如 Gorilla、ToolBench、API-Bank、ToolTalk、RestBench 并不满足这三个要求。除了缺少上述类型的环境,它们的任务只涉及 1–4 次 API 调用的线性序列,不需要丰富的交互式编码,并且它们通过将代理的解决方案与参考解决方案(使用 LLM 或人工方式)进行比较来进行评估,但这种方式在复杂任务中表现不佳,因为复杂任务可能有多种不同的解决方案。
🌎 2. 引入 AppWorld
为了解决这一差距,我们引入了 AppWorld,它包括(1)一个可控且模拟的世界环境(引擎),在这个环境中,编码代理可以通过 API 代表人们操作各种应用程序,(2)一个在此环境上定义的复杂任务的基准,以及(3)一个强大的评估框架,用于评估代理的性能。

AppWorld 的概述,包括一个模拟的应用程序和人类世界环境,一个建立在其上的复杂任务基准,以及一个强大的评估框架。
⚙️ 2.1 引擎:模拟的数字世界
AppWorld Engine 是一个基于 API 的高保真模拟器(60K 行代码),模拟来自各个领域的 9 款日常应用程序的生态系统(例如 Gmail 用于电子邮件,Amazon 用于购物,Spotify 用于音乐等)。这个引擎由一个完全可控的本地后端支持,包含 457 个 API 和 100 多个数据库表,紧密模拟了真实应用的丰富功能。这些 API 拥有详细的文档(互动探索),代理可以阅读这些文档来理解其使用方法。
然后,我们在此引擎之上模拟了一个人的数字世界及其在这些应用程序中的数字活动。具体来说,我们将应用数据库(DB)填充了 106 个虚构人物,这些人生活在这个模拟世界中。他们通过各种关系相互联系,比如室友、朋友、经理等,以便执行人际任务,例如与室友分账。接着,模拟他们的日常生活,执行各种个人和人际活动,例如在 Amazon 上订购 T 恤进行家庭配送,或通过电话向室友请求车钥匙等。最终的数据库包含超过 30 万行,跨越 726 列。
📊 2.2 复杂任务基准测试
AppWorld 基准测试在此引擎基础上构建了 750 个日常任务(如上所示的示例),这些任务需要使用多个 API(通常 15 个以上),跨越多个应用(1–4 个),并且要求编写丰富且互动的代码(通常是 80 行以上,包含许多编程构造)。请参见下面的统计图表,并在我们的游乐场上互动探索任务。
每个任务指令都会配有一个监督者(AppWorld 中的人物),代理需代表其执行任务。代理可以访问他们所有的应用账户。每个任务的初始数据库状态都经过精心设计(通过编程),以确保任务明确定义,并包含现实的干扰和障碍。任务还包括任务变体,从整体上检查代理是否能在不同的初始条件和指令变体下可靠地解决任务。
所有任务的实现都是由我们设计和开发的(不是众包的)。它们的实现代码总行数超过 40K 行(是的,任务开发投入了大量工作;请参阅论文)。

AppWorld 基准测试中任务的难度等级分布,以及我们编写的解决方案的特性,如应用程序数量、唯一 API 和代码行数、评估测试的数量等。
✔️ 2.3. 强大的评估框架
在 AppWorld 中,复杂任务可以通过多种方式完成(例如,可以通过其 Amazon API 或确认邮件下载订单收据)。此外,解决任务的代理可能以多种不同的方式造成附带损害(例如,发起一个并未要求的退货)。因此,基于过程的方式,将代理生成的代码与参考代码或 API 调用进行比较,是不足以评估任务完成情况的。
相反,AppWorld 采用了基于状态的方法。具体来说,对于每个任务,我们定义了一套程序化的单元测试,利用数据库状态快照作为输入:(1) 代理开始前的状态和(2) 代理结束后的状态。然后,我们检查是否只进行了预期的数据库更改,并确保没有发生意外更改。这使我们能够可靠地检查代理是否正确完成任务而未造成附带损害。
最后,为确保任务是可解的,我们编写验证解决方案代码,并通过程序化的方式验证其运行是否通过所有评估测试。
🧪 3. 代理的表现如何?
我们使用多种少量样本提示方法对多种 LLM 进行了基准测试,方法包括 ReAct、计划与执行、生成带反思的完整代码和函数调用。即便是最好的 LLM,GPT-4o,也表现得相当差劲。例如,它在挑战测试集中的任务仅正确完成约 30%。GPT-4 Turbo 和开放 LLM 则更为落后。
此外,在我们的严格鲁棒性度量下,得分要低得多,该度量检查代理是否能够在不同的初始条件和指令扰动下可靠地完成所有任务变化。

展示了使用各种提示方法的最先进大型语言模型(LLMs)的分数。AppWorld 对当前模型来说是具有挑战性的。例如,GPT-4o 仅能正确解决大约 30%的 Test-Challenge 任务,且在我们的鲁棒性度量中得分下降至 13.0。
此外,分数随着难度的增加而显著下降,依据我们提供的标签以及其他难度指标(例如,基于我们书面验证解决方案的 API 数量和代码行数)。

展示了在各种任务难度指标下,最佳模型 GPT4-o 的分数曲线。随着任务难度的增加,模型的分数显著下降。
🔮 4. AppWorld 的未来是什么?
AppWorld 是一个模块化且可扩展的基础平台,为自动化数字任务开辟了许多激动人心的可能性。例如,未来的工作可以:
-
将 AppWorld 引擎扩展为支持基于浏览器/移动 UI 的控制,以便为现有任务提供统一的基准,涵盖代码、API 和 UI 基础的自主代理。
-
将 AppWorld 基准扩展为需要多代理(和人类)协调与合作的任务(例如,通过与朋友的代理在电子邮件中协调,安排一次日历会议)。
-
将我们的数字世界引擎叠加到一个物理世界引擎上,比如 Simulacra,并通过角色扮演代理在一个受控环境中研究社会动态和行为。
-
将引擎作为一个没有后果的沙盒,用来研究当数字助手被赋予“代理”角色,在现实世界中代替我们行动时,可能产生的隐私和安全风险。
-
当然,还可以将 AppWorld 扩展到更大的应用和任务生态系统。
我们非常激动,期待自己和他人能够在 AppWorld 的基础上探索这些方向(以及更多!)。如果需要帮助或想要合作,随时联系我们!
🚀 5. 准备好尝试了吗?
AppWorld 易于使用且速度快。你可以通过 pip 安装其开源 Python 包并开始构建和测试你的代理。如果你已经有了代理,下面的代码就是你在 AppWorld 上运行并评估它所需要的全部内容。

AppWorld 环境的最小使用示例。
有关论文、代码、排行榜、数据探查器(任务、API、代理轨迹)、互动游乐场(直接与 AppWorld 任务互动)、视频解说等内容,请访问 https://appworld.dev。
新消息:AppWorld 在 ACL'24 上获得了 最佳资源论文 奖。🏆 🎉
图片来源:所有图片均由作者创建。
数据科学家是算命师吗?

图片由petr sidorov提供,来自Unsplash
我们是否应该努力成为其中之一?
·发布于Towards Data Science ·8 分钟阅读·2024 年 5 月 9 日
--
在一个不断拥抱新思想和新能力的世界里,数据科学家利用越来越复杂的算法,输入越来越庞大的数据集,似乎比以往任何时候都更加神秘,尤其是在预测过程中。
当我作为数据科学家工作,主要的交付成果是产品需求预测时,我不禁思考:数据科学家是不是现代的算命师,凭借神秘但强大的魔法为企业做出明智的决策?
本文将从我作为一名数据科学家的经验和思考出发,窥探预测未来的迷人而全面的世界,尤其是作为一名始终努力获得自己甚至他人信任的预测者。
我从未将自己视为算命师,也许是因为我一直认为背后的科学原理会将我与算命中使用的“黑盒”魔法区分开来。我的模型是有道理的,因为我们有充足的特征、可信赖的算法和不错的历史表现。
我曾经认为这已经足够解释问题,即使面对业务决策者和模型构建者之间的技术鸿沟。使用最前沿的流行词和难以理解的大量数据构建的模型,总是给利益相关者带来复杂的感受。
GPT 是优秀的嵌入模型吗?
一个令人惊讶的实验,表明细节决定成败
·发表于Towards Data Science ·6 分钟阅读·2024 年 5 月 18 日
--

图像由作者使用 DALL-E 生成
随着可用的嵌入模型数量不断增加,选择适合自己机器学习应用的模型可能变得具有挑战性。幸运的是,MTEB 排行榜为各种自然语言处理任务提供了全面的排名指标。

截至 2024 年 5 月 17 日,来自MTEB 排行榜的前五大嵌入模型
当你访问这个网站时,你会注意到排名前五的嵌入模型是生成式预训练变换模型(GPTs)。这可能让你认为 GPT 模型是最适合用于嵌入的模型。但这是真的吗?让我们进行一个实验来找出答案。
GPT 嵌入
嵌入是文本的张量表示,它将文本标记 ID 转换并投影到张量空间中。
通过将文本输入到神经网络模型中并执行前向传递,你可以获得嵌入向量。然而,实际过程要复杂一些。让我们一步一步分解:
-
将文本转换为标记 ID
-
将标记 ID 传递到神经网络中
-
返回神经网络的输出
在第一步中,我将使用分词器来实现这一点。model_inputs是文本内容"some questions."的张量表示。
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
messages = [
{
"role": "user",
"content": "some questions.",
},
]
encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
model_inputs = encodeds.to("cuda")
第二步很简单,将model_inputs前向传递给神经网络。生成的标记的 logits 可以通过.logits访问。torch.no_grad()表示我不希望模型的权重被更新,因为模型处于推理模式。
import torch
with torch.no_grad():
return model(model_inputs).logits
第三步有点棘手。GPT 模型是解码器-only,其 token 生成是自回归的。简单来说,已完成句子的最后一个 token 已经看到了句子中的所有前面的 tokens。因此,最后一个 token 的输出包含了来自前面 tokens 的所有相关性分数(注意力)。
完美!你最感兴趣的是最后一个 token,因为在 transformer 中的注意力机制。
在 Hugging Face 中实现的 GPT 的输出维度是(批量大小,输入 token 大小,词汇表数量)。为了获取所有批次的最后一个 token 输出,我可以执行张量切片。
import torch
with torch.no_grad():
return model(model_inputs).logits[:, -1, :]
这些 GPT 嵌入的质量
要衡量这些 GPT 嵌入的质量,可以使用余弦相似度。余弦相似度越高,句子的语义越接近。
import torch
def compute_cosine_similarity(vec1, vec2):
cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
return cos(vec1, vec2)
让我们创建一些实用函数,允许我们遍历问题和答案对列表并查看结果。Mistral 7b v0.1 指令,这是一个出色的开源模型,用于本实验。
import torch
from termcolor import colored
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained(
"mistralai/Mistral-7B-Instruct-v0.1"
)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
def generate_last_token_embeddings(question):
messages = [
{
"role": "user",
"content": question,
},
]
encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
model_inputs = encodeds.to("cuda")
with torch.no_grad():
return model(model_inputs).logits[:, -1, :]
def get_similarities(questions, answers):
for question in questions:
for answer in answers:
q_embedding, a_embedding = (
generate_last_token_embeddings(question),
generate_last_token_embeddings(answer),
)
similarity = compute_cosine_similarity(q_embedding, a_embedding)
print(colored(f"question: {question} and ans: {answer}", "green"))
print(colored(f"result: {similarity}", "blue"))
questions = ["Where is the headquarter of OpenAI?", "What is GPU?"]
answers = [
"OpenAI is based at San Francisco.",
"A graphics processing unit (GPU) is an electronic circuit that can perform mathematical calculations quickly",
]
get_similarities(questions, answers)

Mistral 7b v0.1 指令的余弦相似度(图像来源:作者)
结果与观察
对于第一个问题和答案对:
-
问题:“OpenAI 的总部是什么?”
-
答案:“OpenAI 总部位于旧金山。”
-
余弦相似度:0.96
对于第二个问题和答案对:
-
问题:“什么是 GPU?”
-
答案:“图形处理单元(GPU)是一个能够快速执行数学计算的电子电路。”
-
余弦相似度:0.94
对于不相关的对:
-
问题:“OpenAI 的总部在哪里?”
-
答案:“图形处理单元(GPU)是一个能够快速执行数学计算的电子电路。”
-
余弦相似度:0.90
对于最差的对:
-
问题:“什么是 GPU?”
-
答案:“OpenAI 总部位于旧金山。”
-
余弦相似度:0.93
这些结果表明,在这种情况下,使用 GPT 模型(如 mistral 7b instruct v0.1)作为嵌入模型可能不会在区分相关和不相关的对方面产生很好的结果。但为什么 GPT 模型仍然位居前五名嵌入模型呢?
对比损失来解救
tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-mistral-7b-instruct")
model = AutoModelForCausalLM.from_pretrained(
"intfloat/e5-mistral-7b-instruct"
)

e5-mistral-7b-instruct的余弦相似度(图像来源:作者)
使用不同的模型e[5-mistral-7b-instruct](http://intfloat/e5-mistral-7b-instruct)重复相同的评估过程,该模型是 MTEB 排行榜上排名前列的开源模型之一,并从 mistral 7b instruct 微调而来。我发现,相关问题和对之间的余弦相似度分别为 OpenAI 和 GPU 问题为 0.88 和 0.84。对于不相关的问题和答案对,相似度降至 0.56 和 0.67。这个发现表明,e5-mistral-7b-instruct是一个在嵌入方面有显著改进的模型。这种改进是如何实现的呢?

对比损失函数
深入研究e5-mistral-7b-instruct背后的论文,关键在于使用对比损失进一步微调 mistral 模型。
与使用交叉熵损失对预测标记和标签标记进行训练或进一步微调的 GPT 模型不同,对比损失旨在最大化负样本对之间的距离,同时最小化正样本对之间的距离。
这篇博客文章详细介绍了这一概念。sim函数计算两个向量之间的余弦距离。对于对比损失,分母表示正例和负例之间的余弦距离。对比损失背后的原理是,我们希望相似的向量尽可能接近 1,因为 log(1) = 0 表示最优损失。
结论
在这篇文章中,我强调了使用 GPT 作为嵌入模型而没有微调的常见陷阱。我的评估表明,通过对 GPT 进行对比损失微调,嵌入可以更加有意义和具有区分性。通过了解 GPT 模型的优缺点,并利用定制的损失函数,如对比损失,你可以在选择和使用嵌入模型时做出更明智的决策。希望这篇文章能帮助你为应用程序明智地选择 GPT 模型,并期待听到你的反馈! 😃
如果你对大规模微调 LLM 感兴趣,我还有另一篇相关的文章,可以帮助你实现这一目标。 😃
[## 微调大型语言模型:分布式并行训练指南(使用 DeepSpeed 和 Ray)
开源 LLM 的路径
我还有一篇关于大型语言模型(LLMs)扩展法则的文章。它提供了更多关于为什么语言模型变得越来越大的背景,并解释了背后的原因。祝你学习愉快,分享愉快!干杯。
解构 Transformer 技术与扩展策略
语言模型是基准测试天才还是现实世界问题解决者?
评估语言模型在现实任务中的演变与应用
·发布于 Towards Data Science ·7 分钟阅读·2024 年 3 月 23 日
--

AI 学生在教室里参加考试。图片由作者和 DALL-E 3 创作。
在教育领域,最好的考试是那些挑战学生将所学知识应用于新的和不可预测的方式的考试,这种考试不仅仅是让学生记住事实,而是展示他们的真正理解。我们对语言模型的评估也应遵循相同的模式。随着我们每天看到新的模型涌入 AI 领域,无论是来自像 OpenAI 和 Anthropic 这样的巨头,还是来自较小的研究团队和大学,至关重要的是,我们的模型评估应深入到标准基准测试的表现之上。新兴研究表明,我们一直以来用来衡量模型能力的基准并不像我们曾经认为的那样可靠。为了能够适当地支持新模型,我们的基准必须发展成为与我们要求这些模型和新兴 AI 代理架构解决的现实世界挑战一样动态和复杂。
在本文中,我们将通过回答以下问题来探讨语言模型评估的复杂性:
-
目前语言模型是如何评估的?
-
在基准测试中表现优秀的语言模型有多可靠?
-
语言模型和 AI 代理能否将知识转化为行动?
-
为什么语言模型(或基础模型)需要掌握的不仅仅是文本?
那么,今天语言模型是如何评估的?
今天,大多数模型,无论是大型语言模型(LLM)还是小型语言模型(SLM),都在一组共同的基准上进行评估,包括大规模多任务语言理解(MMLU)、小学数学(GSM8K)和 Big-Bench Hard(BBH)数据集等。
为了更深入地理解每个基准评估的任务类型,以下是来自每个数据集的一些示例问题:
-
MMLU:旨在通过多项选择题衡量模型在预训练过程中学习到的关于 STEM 和人文学科各个学科以及从小学到高级专业理解的各个难度水平的信息。
MMLU 中的医学类大学问题示例:“在对新生儿进行基因检测时,发现了一种罕见的遗传病,该病具有 X 连锁隐性遗传方式。以下哪项陈述可能是关于该病家族系谱的正确描述? A. 所有母系后代都会患此病 B. 女性的发病率大约是男性的两倍 C. 所有患病男性的女儿都会患病 D. 男性和女性的发病率将相等。”(正确答案是 C) [2]
-
GSM8K:语言模型通常难以解答数学问题,GSM8K 数据集评估模型在解答 8.5k 个多样化的小学数学问题时的推理能力和解题能力。
示例:“Dean 的母亲给了他 28 美元去杂货店。Dean 买了 6 辆玩具车和 5 只泰迪熊。每辆玩具车的价格是 12 美元,每只泰迪熊的价格是 1 美元。之后,他的母亲心情好,决定再给他 10 美元。那么 Dean 剩下多少钱?” [3]
-
BBH:该数据集由 23 个任务组成,这些任务来自 Big Bench 数据集,语言模型通常难以解决这些任务。这些任务通常需要多步推理才能成功完成。
示例:“如果你按照这些指示走,你是否会回到起点?左转。右转。走 5 步。走 4 步。转身。走 9 步。选项:— 是 — 否” [4]
Anthropic 最近宣布,Claude-3 模型凭借其 Opus 版本在大多数常见基准上超越了 GPT-4,成为领先的模型。例如,Claude-3 Opus 在 MMLU 上取得了 86.8%的成绩,略微超越了 GPT-4 的 86.4%。Claude-3 Opus 还在 GSM8K 上得到了 95%的成绩,在 BBH 上得到了 86.8%,而 GPT-4 分别为 92%和 83.1% [1]。
尽管像 GPT-4 和 Claude 这样的模型在这些基准测试中的表现令人印象深刻,但这些任务并不总是代表企业想要解决的挑战类型。此外,越来越多的研究表明,模型正在记住基准问题,而不是理解它们。这并不意味着这些模型不能推广到新任务,我们每天都看到 LLM 和 SLM 完成惊人的任务,但这意味着我们应该重新考虑如何评估、打分和推广模型。
在基准测试中表现出色的语言模型有多可靠?
来自微软、自动化研究所(中国科学院)和中国科学技术大学的研究表明,当向不同的语言模型提问经过改写或修改的基准问题时,这些模型的表现明显比直接提问相同基准问题时差。为了展示他们研究的目的,研究人员在论文中展示了 DyVal 2,研究者从像 MMLU 这样的基准中取出问题,通过改写问题、为问题添加额外的答案、改写答案、排列答案顺序或为问题增加额外内容等方式进行修改。在比较“原始”数据集与修改后问题的模型表现时,他们发现性能下降,例如GPT-4 在原始 MMLU 问题上的得分为 84.4,在修改后的 MMLU 问题上的得分为 68.86 [5]。

来源:DyVal2,原始基准与探测基准上模型表现比较
同样,来自亚利桑那大学计算机科学系的研究表明,语言模型中存在大量数据污染 [6]。这意味着基准中的信息正成为模型训练数据的一部分,实际上使得基准得分变得无关紧要,因为模型是在自己已被训练过的信息上进行测试的。
来自复旦大学、同济大学和阿里巴巴的额外研究强调了需要为 AI 代理设计自我进化的动态评估,以应对数据污染和基准记忆化的问题 [7]。这些动态基准将有助于防止模型在预训练过程中记住或学习它们后续将被测试的信息。尽管不断增加新的基准可能会在将旧模型与新模型进行比较时带来挑战,但理想情况下,这些基准将减轻数据污染问题,并使得评估模型理解训练中话题的能力变得更容易。
在评估模型在特定问题上的能力时,我们需要理解模型在预训练过程中学到的信息的理解程度,以及它能多好地将这些知识推广到新的任务或概念,超越它的训练数据。
语言模型和 AI 智能体能否将知识转化为行动?
当我们考虑使用模型作为 AI 智能体代表我们执行任务时,无论是预订假期、写报告,还是为我们研究新话题,我们都需要额外的基准或评估机制来评估这些智能体的可靠性和准确性。大多数希望利用基础模型力量的企业需要允许模型访问集成其独特数据源的各种工具,并要求模型推理和规划何时以及如何有效地使用这些工具。这些类型的任务在许多传统的 LLM 基准中并未得到体现。

来源:AgentVerse,智能体团队与单一智能体在涉及工具调用和代码执行的软件开发任务中的比较结果
为了解决这个问题,许多研究团队正在创建自己的基准和框架,用于评估智能体在涉及工具使用和超出模型训练数据的知识任务中的表现。例如,AgentVerse 的作者评估了智能体团队在执行现实世界任务(如活动策划、软件开发和咨询)方面的表现。研究人员创建了自己的一套 10 个测试任务,并通过人工评估来判断智能体是否执行了正确的操作、使用了合适的工具,并得出了准确的结果。他们发现,采用包含明确阶段(智能体招募、任务规划、独立执行任务和后续评估)周期的智能体团队,比独立智能体的表现更优秀 [8]。
超越单一模态,走向现实世界。为什么语言模型(或基础模型)要掌握文本之外的内容?
在我看来,正在出现的智能体架构和基准是理解语言模型在商业问题上表现如何的重要一步,但一个限制是,大多数仍然集中于文本。随着我们考虑到世界及大多数工作的动态性,我们需要评估模型在文本任务、视觉和听觉任务上表现的智能体系统和模型。AlgoPuzzleVQA 数据集就是一个评估模型是否能够推理、阅读和视觉解读数学和算法谜题的例子 [9]。

来源:语言模型是谜题天才吗? 来自 AlgoPuzzleVQA 数据集的示例问题
尽管企业可能不关心模型能多好地解答难题,但这仍然是理解模型如何推理多模态信息的一个正确方向。
结论
随着我们在日常生活和职业工作中不断采用基础模型,我们需要更多能够反映现实世界问题的评估选项。动态和多模态基准是其中的一个关键组成部分。然而,随着我们引入更多的代理框架和架构,多个 AI 代理协同解决问题,跨模型和框架的评估与比较变得更加具有挑战性。基础模型的真正衡量标准,不在于它们能否征服标准化测试,而在于它们在复杂且往往不可预测的现实世界中理解、适应和行动的能力。通过改变我们对语言模型的评估方式,我们挑战这些模型从基于文本的智力和基准测试专家,发展成为能够应对多方面(和多模态)挑战的全面思考者。
有兴趣进一步讨论或合作?请通过 LinkedIn联系!
离群点更难预测吗?
一项关于机器学习模型在预测离群点时是否更容易出错的实证分析
·发布于 Towards Data Science ·阅读时间:8 分钟·2024 年 2 月 4 日
--

[图像由作者提供]
离群点是与大多数数据群体差异很大的个体。传统上,在实践者中,离群点往往不被信任,这也是为什么通常会采取删除离群点的措施来处理数据集。
然而,在处理真实数据时,离群点是常见的现象。有时候,它们比其他数据点还要重要! 例如,假设有一些离群点是因为他们是非常高付费的客户:你肯定不想丢弃他们,实际上,你可能希望特别关注他们。
离群点的一个有趣的——且未被充分探索的——方面是它们与机器学习模型的互动。我的感觉是,数据科学家们认为离群点会损害他们模型的表现。但这种看法很可能是基于一种先入为主的观念,而非真实的证据。
因此,我在本文中将尝试回答以下问题:
机器学习模型在预测离群点时是否更容易出错?
问题框架
类人机器人会一直存在吗?
类人机器人可能最终解决困扰机器人适应的“旧场地”问题,而最近在多模态变换器和扩散模型方面的突破,可能真的让这一切成为现实。
·发表于 Towards Data Science ·10 分钟阅读·2024 年 3 月 1 日
--
几乎每周都有类人公司发布新更新。Optimus 能走路了?Digit 刚刚移动了一个空的购物篮?Figure 也做到了!似乎真正的公司终于开始关注了。从特斯拉开始,类人机器人现在已经在亚马逊和宝马“工作”,距离进入我们的家庭和花园只差一步。
如果你不是 Medium 的订阅者,你可以 在这里免费阅读这篇文章。

一位类人机器人正在清理(它自己的?)乱摊子,同时准备做饭。类人形态在无缝集成现有价值创造流程方面具有巨大潜力。图片:作者通过 miramuseai.net 提供
我们是孤独的吗?
遭遇外星生命的真实概率(德雷克方程系列的第五部分)
·发表于Towards Data Science ·阅读时间 11 分钟·2024 年 9 月 8 日
--
回顾:
在本系列文章中,我们探讨了可能导致外星文明存在的各种因素,从宜居行星的数量到智能文明是否发展出通讯技术的概率。在本文的最后一篇,我们探讨了终极问题:我们曾经遇到过外星生命吗?我们未来会遇到它们吗?

所有图片均由作者使用 Midjourney 创作。
第 10 步:理性应对外星遭遇
对外星生命的探索长期以来一直是科学、猜测和耸人听闻的混合体。从 UFO 目击到政府的 UAP(不明空中现象)报告,公众的想象力一直被外星遭遇的想法所吸引。然而,从科学的角度来看,我们是否已经遇到过外星生命——或者将来是否有可能遇到外星生命,究竟有多大可能?
这时,理性、数据驱动的方法就显得尤为重要。通过结合德雷克方程、现代模拟和贝叶斯概率模型,我们终于可以计算出过去和未来遭遇外星生命的可能性。
为什么这一步很重要
很容易被潜在外星遭遇的兴奋所感染,但现实却远比想象复杂。即使银河系中存在智能文明,它们与我们文明在时间和距离上的重叠几率也极其微小。本文将帮助我们量化这些遭遇的可能性,基于过去和未来的可能性,为我们提供一个更清晰的概率图景。
贝叶斯概率与外星遭遇
贝叶斯推理使我们能够在新证据(或证据缺乏)出现时更新我们的概率估计。在外星接触的案例中,我们可以使用这种方法来评估过去和未来接触的概率。
让我们来解析贝叶斯方法:
-
P(H|E):给定当前证据,外星人存在并且我们已与其接触的概率。
-
P(H):我们的先验概率,或是我们对外星人相遇已经发生或将要发生的可能性的初步假设。
-
P(E|H):假设相遇假设为真,当前证据(例如没有确认的外星接触)出现的可能性。
-
P(E):证据的总体概率,它考虑了所有可能的假设。
我们将使用这个框架来计算过去和未来的相遇。
外星接触的贝叶斯蒙特卡洛模拟:理解方法
为了量化过去和未来外星接触的概率,我们采用了贝叶斯框架,并结合蒙特卡洛模拟来处理参数中固有的不确定性。本节将带您了解这两种方法背后的基本原理和方法论,然后展示实际代码。
为什么使用贝叶斯分析?
贝叶斯分析是一种基于新证据更新事件概率的强大方法。在我们的例子中,事件是我们是否已经遇到或将遇到外星文明。通过结合先验知识和现有的(尽管有限的)证据——例如缺乏确认的接触——我们可以细化我们的估计并量化过去和未来外星人接触的相关不确定性。
贝叶斯定理允许我们计算后验概率——换句话说,在我们假设和观察的基础上,外星人相遇的可能性。这个过程至关重要,因为它会随着新信息的出现不断更新我们的理解,无论是外星生命的确凿证据,还是继续缺乏接触的情况。
为什么使用蒙特卡洛模拟?
鉴于德雷克方程和其他与外星人接触相关的概率中的不确定性和变异性,使用一组固定的值来估计概率是不现实的。相反,蒙特卡洛模拟使我们能够为每个参数(例如接触的可能性或外星生命存在的先验概率)采样一组广泛的合理值。
通过使用这些不同的值进行数千次模拟,我们可以探索一系列结果,而不是依赖于僵化的点估计。结果是我们对过去和未来相遇的可能性有了更细致的理解,同时对每种情景的概率分布也有了更清晰的认识。
现在,让我们深入了解实际的代码实现:
**********************************;
**********************************;
/* Set the random seed for reproducibility */
data _null_;
call streaminit(1234);
run;
/* Number of simulations */
%let num_simulations = 100000;
/* Number of civilizations to generate */
%let num_civilizations = 2364;
/* Galactic radius and height in light years */
%let galactic_radius = 50000;
%let galactic_height = 1300;
/* Earth's position (assumed to be at 3/4 of the galactic radius) */
%let earth_position_x = &galactic_radius * 3 / 4;
%let earth_position_y = 0;
%let earth_position_z = 0;
/* Create a dataset to store civilization positions */
data civilization_positions;
length Civilization $10.;
input Civilization $ Position_X Position_Y Position_Z;
datalines;
Earth &earth_position_x &earth_position_y &earth_position_z
;
run;
/* Generate random positions for other civilizations */
data civilization_positions;
set civilization_positions;
do i = 1 to &num_civilizations;
Position_X = rand("Uniform") * &galactic_radius;
Position_Y = rand("Uniform") * 2 * &galactic_height - &galactic_height;
Position_Z = rand("Uniform") * 2 * &galactic_height - &galactic_height;
Civilization = "Civilization " || strip(put(i, 8.));
output;
end;
drop i;
run;
/* Calculate the distance between civilizations and Earth */
data civilization_distances;
set civilization_positions;
Distance = sqrt((Position_X - &earth_position_x)**2 + (Position_Y - &earth_position_y)**2 + (Position_Z - &earth_position_z)**2);
run;
/* Calculate the minimum distance to Earth for each civilization */
proc sql;
create table civilization_min_distance as
select Civilization, Distance as Min_Distance
from civilization_distances
order by Distance;
quit;
/* Calculate the probability of encountering civilizations based on distance */
data probability_encounter;
set civilization_min_distance;
Probability = 1 / (1 + Min_Distance);
run;
/* Calculate the average probability for each distance band */
proc sql;
create table average_probability as
select case
when Min_Distance <= 1000 then 'Close'
when Min_Distance > 1000 and Min_Distance <= 3000 then 'Medium'
when Min_Distance > 3000 then 'Far'
end as Distance_Band,
avg(Probability) as Average_Probability
from probability_encounter
group by case
when Min_Distance <= 1000 then 'Close'
when Min_Distance > 1000 and Min_Distance <= 3000 then 'Medium'
when Min_Distance > 3000 then 'Far'
end;
quit;
/* Print the result */
proc print data=average_probability;
run;
/* Select the closest civilization to Earth and its associated probability */
proc sql;
create table closest_civilization as
select Civilization, Min_Distance, Probability
from probability_encounter
where Min_Distance = (select min(Min_Distance) from probability_encounter);
quit;
/* Print the result */
proc print data=closest_civilization;
run;
/*Bayesian analysis for probability of encountering aliens in the past or future*/
/* Set seed for reproducibility */
%let num_iterations = 100;
/* Create Bayesian analysis dataset */
data bayesian_analysis;
call streaminit(123);
/* Define variables for posterior probabilities */
array posterior_past[&num_iterations];
array posterior_future[&num_iterations];
do i = 1 to &num_iterations;
/* Sample prior probabilities and likelihoods for past encounters */
prior_past = rand("Uniform", 0.0001, 0.01); /* P(Past encounter) */
likelihood_past_encounter = rand("Uniform", 0.001, 0.1); /* P(No contact | Past encounter) */
likelihood_no_encounter_past = rand("Uniform", 0.8, 0.99); /* P(No contact | No encounter) */
/* Calculate posterior probability for past encounter using Bayes' Theorem */
numerator_past = prior_past * likelihood_past_encounter;
denominator_past = numerator_past + (1 - prior_past) * likelihood_no_encounter_past;
posterior_past[i] = numerator_past / denominator_past;
/* Sample prior probabilities and likelihoods for future encounters */
prior_future = rand("Uniform", 0.001, 0.05); /* P(Future encounter) */
likelihood_future_encounter = rand("Uniform", 0.01, 0.1); /* P(No contact | Future encounter) */
likelihood_no_encounter_future = rand("Uniform", 0.8, 0.99); /* P(No contact | No encounter) */
/* Calculate posterior probability for future encounter using Bayes' Theorem */
numerator_future = prior_future * likelihood_future_encounter;
denominator_future = numerator_future + (1 - prior_future) * likelihood_no_encounter_future;
posterior_future[i] = numerator_future / denominator_future;
end;
/* Output the results */
do i = 1 to &num_iterations;
posterior_past_value = posterior_past[i];
posterior_future_value = posterior_future[i];
output;
end;
keep posterior_past_value posterior_future_value;
run;
/* Summary statistics for the posterior probabilities */
proc means data=bayesian_analysis mean std min max;
var posterior_past_value posterior_future_value;
run;
/* Distribution histograms for the posterior probabilities */
proc sgplot data=bayesian_analysis;
histogram posterior_past_value / transparency=0.5 fillattrs=(color=blue) binwidth=0.00001;
title "Distribution of Posterior Probabilities for Past Encounters";
run;
proc sgplot data=bayesian_analysis;
histogram posterior_future_value / transparency=0.5 fillattrs=(color=green) binwidth=0.0001;
title "Distribution of Posterior Probabilities for Future Encounters";
run;
使用这段代码,我们在一系列假设下模拟了过去和未来的外星人遭遇,从而利用贝叶斯推理估算每种情境的可能性。在过程结束时,我们得到了过去和未来外星人接触的概率分布,接下来我们将分析这些分布以获取更多洞见。
分析表格和图形输出
表格输出
该表格展示了后验概率的汇总统计数据,这些概率表示了过去和未来外星人遭遇的可能性:

posterior_past_value:
-
平均值:0.000306778
-
标准差:0.000262715
-
最小值:8.258388E-6
-
最大值:0.0010357
posterior_future_value:
-
平均值:0.0015038
-
标准差:0.0012378
-
最小值:0.000036464
-
最大值:0.0052718
解释:
-
过去遭遇:过去遭遇的平均概率约为 0.0003,或约 0.03%。更直观地说,这意味着我们过去遇到外星人的几率约为1/3,260。
-
未来遭遇:未来遭遇的平均概率较高,大约为 0.0015,即 0.15%。这意味着未来遇到外星人的几率大约为1/667。
这些数值的范围表明存在相当大的不确定性,这也符合数据和假设的局限性。过去遭遇的最小值低至 0.000008(即 1/125,000),而最大值则接近 0.001(即 1/1,000)。未来遭遇的值从 0.000036(1/27,397)到 0.005(即 1/190)不等。
图形输出
-
过去遭遇的后验概率分布:
直方图显示了一个广泛的分布,大多数概率集中在较低的范围内,低于 0.0005。 这表明,在我们的模拟中,过去遭遇外星人的可能性普遍较低,但仍然有少数几次情况下,概率较高,接近 0.001(即千分之一)。

2. 未来遭遇的后验概率分布:
未来遭遇的分布更为分散,最高概率的发生集中在 0.0005 到 0.002 之间。 这表明,尽管未来遭遇外星人的可能性仍然较低,但其概率比过去的遭遇更高。分布的形态表明,尽管接触的几率较低,但根据不同假设的结果,未来遭遇发生的可能性并非微不足道。

关键要点和概率计算
过去遭遇:
过去接触的平均后验概率约为 0.0003。用简单的概率来说,这相当于1/3,260的机会表明人类可能已经与外星生命接触过,只是没有意识到。这个广泛的分布反映了不确定性,概率从最低的 1/125,000 到最高的 1/1,000 不等,这取决于我们对先验概率和证据的假设。
未来的接触:
未来接触的平均后验概率为 0.0015,这意味着有1/667的机会我们将来某个时候会遇到外星生命。尽管这种可能性仍然不大,但与过去的接触相比,这个更高的概率表明未来接触的机会更大(尽管仍然微小)。该分布的范围从最低的 1/27,397 到更乐观的 1/190,反映了可能结果的广泛范围。
将一切联系起来:这意味着什么?
在这系列研究中,我们的旅程是一次令人着迷的概率、不确定性和最宏大的问题的探索:我们在宇宙中是否孤独?以德雷克方程为框架,我们探讨了从适宜居住行星的形成到智能、能沟通的文明的发展每一个步骤。但这一切意味着什么,为什么我们要采取这种方法?
更大的图景
-
我们为什么要这么做: 我们的目标简单却深刻:理性地评估外星文明存在的可能性,更重要的是,评估我们是否已经与它们接触过,或者未来是否会接触。流行文化中有很多关于 UFO、目击事件和神秘信号的猜测,但我们希望以科学的方式来处理这个问题。通过运用德雷克方程,使用蒙特卡洛模拟,并应用贝叶斯推理,我们试图为这个模糊的问题提供一些具体的数字。
-
我们是如何做到的: 我们采用的方法不是寻求确定的答案,而是理解可能性范围。德雷克方程的每一步都带来了巨大的不确定性——有多少适宜居住的行星,多少发展出了生命,多少文明在向宇宙发出信号。为了应对这种不确定性,我们使用了蒙特卡洛模拟,这使我们能够考虑广泛的结果并计算分布,而不是单一的估算值。贝叶斯分析帮助我们根据当前的证据(或者缺乏证据)来细化这些概率,从而提供更细致的外星接触预测。
-
结果意味着什么: 这些数字乍一看可能很小,但它们的意义却非常重大。过去接触的概率(大约为 1/3,260)较低,这并不奇怪,因为缺乏确凿证据。然而,这些概率并非为零,这本身就值得注意——尽管概率很小,但我们已经可能遇到过外星生命,只是我们没有意识到。
-
未来接触的概率稍微乐观一些:大约是 1/667。尽管仍然是一个小概率事件,但这表明如果我们继续寻找,未来某个时刻我们可能会发现或与外星文明沟通。未来是不可预测的,但随着技术的进步,以及天体生物学和太空探索领域的不断扩展,这一可能性依然存在。
总结:
这项分析带给我们一个令人深思但充满希望的结论。宇宙浩瀚,星际之间的距离——更不用说文明之间的距离——令人震撼。宇宙的结构,再加上文明兴衰的时间尺度,表明遭遇的可能性微乎其微,但并非不可能。
这里的真正奇迹不仅仅在于数字,而在于它们所代表的意义:人类好奇心与我们基于理性和证据的探索能力的交汇点。我们可能是孤独的,或者我们可能在某一天与另一个智能文明共享信号。无论如何,我们为了量化这些概率所做的工作表明,这项搜索本身是值得的。它揭示了我们仍然需要了解宇宙及我们在其中的位置的多么庞大的知识。
虽然成功的机会可能不大,但未来相遇的可能性——无论多么渺茫——依然给我们提供了继续仰望星空的理由。宇宙充满了谜团,我们解决这些谜团的旅程仍在继续。无论我们是否能够与外星文明取得联系,寻找的过程本身就推动了科学、哲学以及我们集体想象力的边界。
这就是这项工作的结论——并没有给出具体的答案,而是提出了深刻的问题,这些问题将继续激发好奇心、探索和对未来几代人的奇迹感。寻找外星生命不仅是对宇宙的探索,也是对我们自身的探索。
如果你错过了之前的部分,可以从这里开始。
除非另有说明,所有图片均为作者提供
你意识到你数据专长在推动商业盈利能力方面的潜力吗?
一位供应链数据科学家的反思,他偶然发现了数据分析在帮助大小企业方面的力量。
·发表于Towards Data Science ·11 分钟阅读·2024 年 10 月 24 日
--

图片来源:Fabian Blank via Unsplash
我在分析项目中面临的最大挑战是估算一个解决方案的投资回报率(ROI)。
销售预测引擎的投资回报率(ROI)是多少?
这是决策者在你提议设计工具来解决他们的运营或业务问题时常常会问的第一个问题。
作为物流解决方案设计经理,我的工作是为零售和时尚公司定价仓储和运输运营。
因此,估算我方案的投资回报率(ROI)变得稍微可控一些,但仍然很难说服决策者。
例如,我会解释:“这个算法将提高拣货生产力 25%,这将导致可变成本减少 12%。”
这些成功激励我分享涵盖 60 多个运营案例研究的解决方案,已发布在这个Medium 博客上。
虽然我的重点是改善物流运营,但在过程中发生了一些意想不到的事情:
如果我将重点转向商业盈利能力会怎样?
在一个项目中,我将这些工具应用于一个商业案例研究:最大化一家面包店的盈利能力.
我收到的反馈让我意识到,无论是小企业还是大企业,都需要优化他们的利润率,并且自动化数据驱动的决策。
通过使用商业语言,我能够更有效地销售我的分析解决方案。
在这篇文章中,我想分享我通过发现数据的力量来帮助企业所获得的见解——以及为什么我认为你也应该考虑这条道路。
Summary
**I. Introduction**
Exploring the challenge of proving ROI in analytics projects
**II. How Did I Develop My Business Acumen?**
Sharing my early career experience as a Solution Design Manager
**III. My journey discovering data analytics for business optimization**
Optimization methods to help a bakery business improve profitability.
**VI. Use your Analytics Skills to Solve Business Problems**
How data analytics can answer the needs of decision makers.
**V. How to Adapt your Analytics Approach to Business Problems?**
The importance of translating business problems into simple analytics solutions
**VI. Conclusion**
Understanding processes is an important skill to answer business problems
我是如何培养我的商业敏锐度的?
在我职业生涯的前四年,我为跨越亚洲的主要国际公司设计了仓储和运输解决方案。
供应链解决方案设计经理的工作是什么?
例如,想象一下像 Costco 这样的零售商希望在上海建立一个配送中心。
-
他们提供关于交易量和过程需求的数据,通常是在 RFP 中。
-
我的工作是设计解决方案(布局、人员配置、设备)并基于超过 100 个参数的成本模型来制定定价。
-
我们向客户展示解决方案,并提供详细的定价表。
我们最关心的是什么?毛利率!
为了赢得这个项目,我必须确保具有竞争力的定价,同时保持最低利润率,并且不能低于成本定价。

定价结构 — (图片来自作者)
例如,如果我报价每箱€1.25 用于拣货,我就能准确知道成本和利润是如何分解的
-
€0.57用于劳动力成本
-
€0.37用于设备和消耗品
-
€0.20用于固定仓储成本
-
€0.11为我们的毛利(销售额的 8.8%利润)
接下来是什么?你赢得了这笔生意,并签署了一个为期三年的 500 万欧元预算合同。
但是,如果客户希望将价格降低到每箱€1.10 会发生什么呢?
这在低利润公司中经常发生,比如传统零售商、汽车售后分销商或消费品公司。
你必须找到减少成本的方式,同时保持 8.8%的利润率不变。
因此,我一直在使用数据分析来
-
使用帕累托原则和 Python 减少仓库空间
-
使用路径规划算法提高仓库拣货效率
-
使用图论和 Python 优化运输路线

持续改进举措 — (图片来自作者)
以及许多其他由数据分析支持的运营改进,这些内容在这篇 Medium 博客中分享。
我可以将类似的方法应用到物流以外的领域吗?
我探索数据分析进行商业优化的旅程
这始于我为一家小型物流公司提供咨询服务,该公司将产品运送到巴黎的面包店。
以我作为解决方案设计师的天真视角,我与一家面包店连锁的老板交谈,了解他们的商业模式:
-
一根售价 €1.50 的法棍的利润率是多少(%)?
-
每卖出一个可颂,劳动力成本是多少(€)?
-
你的成本中有多少比例(%)是固定的?
-
你店里最有利润的商品是什么?
令我惊讶的是,他们无法回答这些问题中的任何一个。
我意识到,还有一种做生意的方式,其中价格的设定并不明确了解基础成本,而运营的可视性几乎为零。
这是一个巨大的机会——如果我们为这些企业主提供他们迫切需要的可视性和洞察力,会怎么样?
所以,我接受了模拟一家面包店的挑战。
我应用了在物流持续改进项目中使用的相同方法论:
-
了解他们当前的运营:固定和变动成本、瓶颈和收入来源。
-
收集和处理数据,做出必要的假设(例如,每种商品的生产成本和销售价格)。
-
建立一个 Python 模型来复制他们当前的设置并模拟不同的场景。
结果是文章中提出的解决方案:使用 Python 最大化商业盈利。

应用于面包店盈利案例的方法论 — (图像来源:作者)
最有利可图的产品组合是什么?
考虑到有限的资源来生产和储存产品,这个模型可以提供最佳的产品组合来销售。
对客户的影响超出了我的预期。
“这是我第一次能够估算我的商业策略对整体盈利能力的影响,”那位老板说。

添加指标:设备和劳动力使用比例 — (图像来源:作者)
经过几次迭代,我们改进了算法,提供了有趣的见解。
Samir:“你们生产的瓶颈是人力资源和烤箱的产能。”
对于这样一个简单的算法,用不到一小时编写,感知到的商业价值远远超过我过去的实验。
这标志着我在使用数据分析来推动商业影响的方法上的转折点。
你如何利用你的分析技能来支持企业?
利用你的分析技能来解决商业问题
自从开始我的咨询业务并开发我的可持续供应链 SaaS以来,我与来自多个行业的几十位企业家进行了交流。
这是一个评估我用数据和商业洞察力解决他们问题能力的机会。
他们需要什么?让我们来看看一个例子!
我的一个朋友,他经营着一家小型餐饮企业,使用我的模型来支持决策并最大化收入。
我朋友的商业模型 —— (图片来源:作者)
他们从中国采购可再生杯子,并通过空运或海运将其运送到当地仓库。
从仓库出发,杯子被送到咖啡店和分销商。
数据分析在这种情况下能解决哪些问题?
他面临的最大挑战之一是库存管理和现金流。
“我们不得不拒绝订单,因为我们没有足够的现金支付供应商补货费用,”他解释道。
核心问题很明确:他们需要清楚了解他们的财务流动。
咖啡杯的价值链 —— (图片来源:作者)
我需要列出所有相关参数,来解决这个问题并构建价值链模型。
-
供应商的付款条款和交货周期。
-
与客户的服务水平协议和付款条款,按销售渠道(直接客户与分销商)进行细分。
-
固定的运营成本和现金流管理。
业务的模拟引擎 —— (图片来源:作者)
结果是一个用 Python 编写的模拟引擎(具有合理的粒度),它复制了我朋友的业务。
接下来是什么?我们可以回答我朋友所有的问题!
我朋友最大的不满是缺乏可视化和无法验证假设。
-
如果我们将库存覆盖周期从 8 周减少到 6 周,会怎么样?
-
使用空运交付会更具成本效益吗?
-
我们是否应该改变销售策略,专注于分销商?
由于涉及的参数复杂,他之前无法得到明确的答案。
所有场景的总结 —— (图片来源:作者)
使用该模型模拟这些假设仅需几秒钟。
这种方法通过确认他们可以安全地减少库存覆盖,而不会影响客户供应,节省了数千欧元。
事实证明,这是一个非常强大的工具,通过简单的分析构建,帮助管理他的业务
-
了解影响价值链的参数
-
模拟“如果”场景来评估商业战略
-
寻找最优设置以最小化成本并最大化盈利
这听起来像是你的业务面临的问题吗?欲了解更多详情,请查看本文
## 使用 Python 进行商业规划 — 库存和现金流管理
小企业的商业规划,用于管理库存、预测流动性需求并最大化盈利能力……
towardsdatascience.com
模型提供的洞察帮助我的朋友减少了运营所需的现金,并降低了销售成本(COGS)。
后来,我们通过解决收入最大化问题,超越了成本降低。
他回来说有另一个与定价策略相关的请求。
定价策略示例 — (图片来源:作者)
与他们的新商业伙伴——一位在餐饮业拥有资本和市场专业知识的专家合作,他们正在制定促进营业额增长的策略。
她提出了多种定价策略,见上文,旨在增加客户的订单数量。
我的朋友:“我怎么评估这些策略及其对盈利能力的影响?”
你只需要通过添加定价模块来调整模型,并估算其对盈利能力的影响。
我们可以通过多种销售量场景估算每个定价策略的盈利能力及其他商业指标。
模拟定价策略 2 文章: [链接] — (图片来源:作者)
在上表中,我们模拟了定价策略二的影响,假设了七种营业额情景。
Samir:“你需要获得+200%的增长,才能通过这个策略恢复你的基线情景的盈利能力。”
我们可以通过几次点击评估每个提议的策略。
你如何利用数据分析帮助小企业在保持或改善其他方面的同时,最大化收入?
towardsdatascience.com
模型生成的洞察帮助解决了联合创始人之间的激烈讨论。
这一过程使他们就定价策略达成共识,且基于精确的盈利预测。
听起来不错,对吧?
但是,成功实施这种项目需要什么?
如何将分析方法适应商业问题?
正如我多次提到的,为这些项目设计的分析解决方案通常是“技术性基础”的。
我会说,80%的努力都在于将商业问题转化为分析解决方案。
保持好奇心!表现出对商业模式的兴趣。
这是一个积极的过程,要求你提出正确的问题,以了解哪些指标对企业主重要,以及如何建模这些流程。

咖啡杯的价值链 — (图片来源:作者)
在达到这种模型化水平之前,我与我的朋友进行了多次迭代,确保我的模型准确反映了他业务的现实。
因此,你需要使你的模型洞察结果对非技术观众可访问,以便他们能帮助你评估结果的准确性。
这种解决方案有需求吗?
这是一个市场需求。
自从我开始全职担任顾问以来,我收到了更多这种类型项目的请求,而不是依赖我的核心供应链工程技能。
由于你直接影响盈利能力并为商业决策者提供可见性,因此获得项目的资本支出(CAPEX)和参与变得更容易。
我们从这两个例子中学到了什么?
-
企业主缺乏对其流程和财务流动的可见性。
-
了解商业模式和流程对于设计正确的仿真模型至关重要。
-
决策者重视数据驱动的洞察力,以支持战略项目。
这种方法适用于各种不同的行业和公司规模的商业案例。
结论
我从未想到过,我会从一个优化物流运营的解决方案设计经理,踏上成为帮助企业提高盈利能力的顾问之路。
这是因为我发现先进的分析工具在业务优化中非常有效。
你不需要关注分析解决方案的复杂性(机器学习、优化或生成 AI),而是要理解业务本身。
你听说过可持续性吗?
我目前正在学习关于欧洲企业可持续发展报告指令(CSRD)的内容。
这将塑造公司如何报告其可持续发展努力。
这个想法是引入更严格的透明度要求,特别是在环境、社会和治理(ESG)指标方面。
ESG 支柱展示 文章:[链接] — (图片来源:作者)
本文中介绍的我们在商业盈利能力方面的方法也可以应用于可持续性挑战。
决策者:“我们需要将分销网络的范围 3 排放量减少 30%。”
例如,在这篇关于绿色库存管理的文章中,我分享了一个关于减少商店配送碳排放的案例研究。
这个方法的目标是找到最优配送频率,以最小化 CO2 排放。
绿色库存管理文章:[链接] — (图片由作者提供)
文章中提出的解决这个操作问题的方法是类似的。
-
了解运营情况
为时尚零售商店准备并交付订单
-
使用 Python 构建一个仿真模型来估算排放量。
输入:销售数据和配送频率 / 输出:CO2 排放量
-
测试多个配送频率的不同情境,并计算排放量的减少。
如果你有兴趣寻找减少排放的解决方案,
模拟商店配送频率对时尚零售商 CO2 排放的影响。
towardsdatascience.com
针对不同问题的类似方法。
对于任何案例研究,保持好奇心、提出正确的问题并与决策者互动是至关重要的,这样才能真正理解他们的痛点。
通过这样做,你可以创建出能为技术和非技术受众提供可操作性洞察的模型。
根据我的经验,这些洞察可以显著提高盈利能力,并支持运营转型中的决策制定。
如果你还没有考虑将你的专业知识应用到商业挑战中,现在是时候了!
你可能会发现一种新的方式来创造影响——就像我做的一样。
关于我
让我们在Linkedin和Twitter上建立联系。我是一名使用数据分析来改善物流运营并降低成本的供应链工程师。
若需关于商业分析和可持续供应链转型的咨询或建议,欢迎通过Logigreen Consulting与我联系。
如果你对数据分析和供应链感兴趣,请访问我的网站。
这是一个专注于数据科学、个人生产力、自动化、运筹学和可持续发展的技术博客。
你确定要成为数据科学经理吗?
在你阅读完这些内容之前,不要急于追求这个华丽的职位。
·发表于Towards Data Science ·阅读时间 15 分钟·2024 年 11 月 22 日
--

图片由Benjamin Elliott提供,来自Unsplash
想象一下。 你刚刚交付了一个绝妙的项目,团队气氛热烈,然后——砰!你被问到,‘你有没有考虑过领导这个团队?’ 听起来很诱人,对吧?但等等——你真的知道自己在做什么吗?
作为一名数据科学经理,我亲眼见证了我的团队从 0 人增长到 12 名数据科学家,并帮助我们将数据科学领域的规模从 20 人扩大到 50 人以上。我还看到一些同事经理们离开,寻求新的挑战。这两种情况都产生了一个空缺:需要一位经理来领导数据科学团队。填补这个空缺可能是一个绝佳的机会,但我也看到许多同事未能适应这一转变。
向管理层的过渡可不是小事。的确,它有其好处。但没有人谈论其中的取舍。大多数初任经理完全没有准备——这就是挫败感开始的地方。
你在这篇博客中将能阅读到什么?
我将介绍在你转向管理岗位之前需要考虑的主要因素。
你将阅读到以下内容:
- 进行高层次的自我反思。 你为什么会考虑做出这个决定?
ARIMA:一种预测时间序列数据的模型
学习 ARIMA 模型如何工作,并学习如何在 Python 中实现它们以进行准确预测
·发表于Towards Data Science ·14 分钟阅读·2024 年 10 月 30 日
--

图片由Jean-Luc Picard提供,来源:Unsplash
缩写 ARIMA 代表自回归积分滑动平均,是一类用于分析时间序列数据的统计模型。这个模型可以用来预测数据未来的发展趋势,例如在科学或技术领域。ARIMA 方法主要用于存在所谓的时间自相关的情况,也就是说,简单来说,时间序列表现出某种趋势。
在本文中,我们将解释与 ARIMA 模型相关的所有方面,从时间序列数据及其特殊特征的简单介绍开始,直到在文章末尾训练我们自己的模型并进行详细评估。
什么是时间序列数据?
时间序列数据是一种特殊形式的数据集,其中测量是按规律的时间间隔进行的。这使得这样的数据集合具有其他数据集所缺少的一个额外维度,即时间维度。时间序列数据通常用于金融和经济领域,或者在自然科学中,当需要测量一个系统随时间变化时。
数组 — 数据科学家的数据结构与算法
动态数组和静态数组在后台是如何工作的
·发表于 Towards Data Science ·阅读时长 6 分钟·2024 年 10 月 7 日
--

图片由 Caspar Camille Rubin 提供,来源于 Unsplash
作为数据科学家,我们很少会被问到类似 LeetCode 的问题,因此我们学习数据结构和算法的需求不如软件工程师那样迫切。
然而,能够编写高效的代码对于你的数据科学职业生涯是一个巨大的助推器。试想,如果你能成为一名既懂得实现机器学习模型,又理解写代码的最佳实践,同时对软件工程有一定了解并拥有相关知识的“数据科学家”会怎样?
你突然变得非常有价值,几乎成了市场上的独角兽。这就是为什么我开始学习数据结构与算法课程,并计划分享我所学到的内容。
本文将专门讨论数组、它们是如何在后台工作的以及它们的不同类型。
数据结构
数据结构是计算机内存储信息的一种便捷方式。正如 维基百科 所定义的:
数据结构是一种数据组织和存储格式,通常选择它是为了高效访问数据。更准确地说,数据结构是一种…
Python 和 Excel VBA 中的数组
通过简单的例子学习数组
·发表于 Towards Data Science ·阅读时间 8 分钟·2024 年 1 月 23 日
--
作为一个没有接受过正式编程教育的人,我的编程旅程一直由自学塑造。意识到回顾基础编程概念的重要性,我发现扎实的基础能够提升整体编程体验。在本教程中,我们将深入探讨一个基本概念——数组。具体来说,我们将通过简单的例子,探讨 Python 和 Excel VBA 中的数组概念。让我们开始吧。

图片由 Nathan Dumlao 提供,来源于 Unsplash
1. Python 中的数组
数组是一个特殊的变量,可以保存一个或多个任何数据类型的值。在 Python 中,与类似的数据类型(如列表)不同,没有内置的数组支持。然而,可以使用 numpy 包的 array 模块来创建数组。numpy 数组对象的索引总是从 0 开始。可以通过引用 -1 来访问 numpy 数组中的最后一个项。一个 numpy 数组可以包含某一特定数据类型的变量或多种数据类型。
下面的代码片段展示了这一点。代码片段还展示了如何从 numpy 数组中访问形状(维度,即行、列)、大小(元素个数)和长度(容器中的项目数量,即行数)。
import numpy as np
simple_array = np.array([1, 2, 3])
mixed_array = np.array([1, 2, 3, "a", "b", "c", 4.5])
print ("Simple array: ", simple_array)
print ("First element of simple_array: ", simple_array[0])
print ("Shape of simple_array: ", simple_array.shape)
print ("Size of simple_array; ", simple_array.size)
print ("\n")
print ("Mixed array: ", mixed_array)
print ("Last element of mixed_array: ", mixed_array[-1])
print ("Length of mixed_array: ", len(mixed_array))
1.1 使用 numpy 数组进行代数矩阵运算
由于其灵活的结构,numpy 数组在创建不同维度的矩阵对象并对其进行操作时非常方便。上面的截图展示了 1 维数组对象的例子。
在下面,我创建了两个数组对象 a 和 b,它们都是二维数组,可以看作是 2*2 的矩阵。计算这两个矩阵的点积就像执行 np.dot(a, b) 一样简单。在点积中,a 和 b 被视为向量(既有大小又有方向的对象)。在矩阵乘法中,矩阵 a 中的每个元素与矩阵 b 中对应的元素相乘。例如,a11(第一行第一列的元素)与 b11 相乘,以此类推。
a = np.array([[0, 1],[2,3]])
b = np.array([[3,4],[5,6]])
print ("Dot Product of a and b: \n", np.dot(a,b))
print ("Matrix multiplication of a and b \n",a*b)
此外,还可以执行其他矩阵操作,如加法、减法和转置。要获得矩阵的行列式,可以使用 np.linalg.det(a)。要获得矩阵的乘法逆,可以使用 np.linalg.inv(a)。
print (“Addition of a and b:\n”, np.add(a, b))
print ("Also addition of a and b:\n", a+b)
print ("Transpose of a:\n", a.T)
print ("Determinant of a:\n", np.linalg.det(a))
print ("Inverse of a:\n", np.linalg.inv(a))
1.2 从列表对象创建 m*n 形状的 numpy 数组
我有两个列表,分别叫做 countries_lived 和 capitals,它们包含我曾居住过的国家及其对应的首都。
countries_lived = [“Nepal”,”India”,”Germany”,”Netherlands”]
capitals = [“Kathmandu”,”New Delhi”,”Berlin”,”Amsterdam”]
要创建一个包含这些列表对象的数组,我可以使用 np.array([countries_lived, capitals])。这将返回一个形状为 2*4(即 2 行 4 列)的数组。如果我希望每一行包含一个国家及其对应的首都,我只需转置该数组即可。
array1 = np.array([countries_lived, capitals])
print ("array1:\n", array1)
print ("Shape of array1:\n", array1.shape)
print ("Size of array1:\n", array1.size)
array2 = np.array([countries_lived, capitals]).T
print ("array2:\n", array2)
print ("Shape of array2:\n", array2.shape)
print ("Size of array2:\n", array2.size)
1.3 向 numpy 数组追加一个元素并创建一个 dataframe
比如说,我想将 France 和 Paris 作为新的一行追加到 array2 中,可以使用语法 np.append(arr, values, axis = None) 来实现。values 必须与 arr 具有相同的形状,轴(axis)除外。如果未指定轴,arr 和 values 会在使用之前被展平。
如下所示,我将新元素作为新的一行追加到数组中。最后,形状为 (5,2) 的 array2 被用来创建一个包含 Country 和 Capital 列的数据框对象 df。
array2 = np.append(array2, [[“France”,”Paris”]], axis = 0)
print ("array2 after appening new row: \n", array2)
import pandas as pd
df = pd.DataFrame(array2,
columns = ["Country", "Capital"])
df
2. Excel VBA 中的数组
与 Python 类似,Excel VBA 中的数组也是一组变量。数组的下界可以从 0 或 1 开始,Excel VBA 的默认下界是 0。但是,可以通过在每个模块顶部声明 Option Base 0 或 Option Base 1 来指定数组的下界。
要检测数组的下界和上界,可以分别使用 Lbound(array_name) 和 Ubound(array_name)。
2.1 声明数组
数组可以通过使用 Public 关键字声明为公共(即全局)数组。在 Excel VBA 中将数组或任何其他变量声明为公共变量,允许在任何模块或子程序中使用,无需重新声明。
Public countries(1 to 4) as String
Public capitals(4) as String
Public countries_visited() as String
另外,数组也可以在子程序内局部声明,只需使用 Dim 关键字即可。这样声明的数组只能在特定的子程序内部使用。
Dim countries(1 to 4) as String
Dim capitals(4) as String
在上述示例中,数组的大小也被指定。指定 1 到 4 或仅指定 4 都表示数组的大小为 4。
2.2 一维数组
一维数组是通过声明行数(例如,从 1 到 5),即数组包含的元素数量来赋值的。下面给出了一个创建我曾经居住过的四个国家的一维数组的示例。它将把这些国家的名称打印到 Excel 文件工作表的 A 列中。
Option Base 1
Sub array_1d()
countries(1) = "Nepal"
countries(2) = "India"
countries(3) = "Germany"
countries(4) = "Netherlands"
Dim i As Integer
Range("A1").Value = "Country"
For i = 1 To 4
Range("A" & i + 1).Value = countries(i)
Next i
End Sub
运行array_1d子程序的输出如下:

array_1d 子程序的输出。图片来源:作者。
2.2 二维数组
二维数组通过声明行数和列数来定义。在以下示例中,我声明了一个名为country_capital的二维数组。每一行的第一个元素对应上一节中声明的countries数组中的元素。每一行的第二个元素对应于这些国家的首都,它们已经在下面的代码中单独声明。
Sub array_2d()
Dim country_capital(4, 2) As String
For i = 1 To 4
country_capital(i, 1) = countries(i)
Next i
country_capital(1, 2) = "Kathmandu"
country_capital(2, 2) = "New Delhi"
country_capital(3, 2) = "Berlin"
country_capital(4, 2) = "Amsterdam"
Range("B1").Value = "Capital"
For i = 1 To 4
Range("A" & i + 1).Value = country_capital(i, 1)
Range("B" & i + 1).Value = country_capital(i, 2)
Next i
End Sub
运行此子程序返回以下结果:

array_2d 子程序的输出。图片来源:作者。
2.3 动态数组
动态数组在无法确定数组大小并且数组大小可能在未来发生变化的情况下非常有用。在下面的代码中,我声明了两个数组countries_visited和population,但没有指定数组的大小。在dynamic_array子程序内,我通过使用ReDim语句将这两个数组的大小指定为 4。接下来,我根据我访问过的四个国家及其人口分别指定了数组的每个元素。
Option Base 1
Public countries_visited() As String
Public population() As Long
Sub dynamic_array()
Dim wb As Workbook
Dim ws2 As Worksheet
Set wb = ThisWorkbook
Set ws2 = wb.Worksheets("Sheet2")
ReDim countries_visisted(4)
ReDim population(4)
countries_visited(1) = "France"
population(1) = 68
countries_visited(2) = "Spain"
population(2) = 48
countries_visited(3) = "Iran"
population(3) = 88
countries_visited(4) = "Indonesia"
population(4) = 274
End Sub
一段时间后,我意识到我还访问了一个新国家(葡萄牙)。我在保留这些数组原始内容/元素的情况下重新定义了数组的大小。我通过增加数组的大小 1 来进行操作。为此,我使用了ReDim Preserve语句,如下所示。
ReDim Preserve countries_visited(1 to 5)
ReDim Preserve population(1 to 5)
完整代码如下:
Option Base 1
Public countries_visited() As String
Public population() As Long
Sub dynamic_array()
Dim wb As Workbook
Dim ws2 As Worksheet
Set wb = ThisWorkbook
Set ws2 = wb.Worksheets("Sheet2")
ReDim countries_visisted(4)
ReDim population(4)
countries_visited(1) = "France"
population(1) = 68
countries_visited(2) = "Spain"
population(2) = 48
countries_visited(3) = "Iran"
population(3) = 88
countries_visited(4) = "Indonesia"
population(4) = 274
ws2.Range("A1").Value = "Countries visited"
ws2.Range("B1").Value = "Population (million)"
ReDim Preserve countries_visited(5)
ReDim Preserve population(5)
countries_visited(5) = "Portugal"
population(5) = 10
Dim i As Integer
For i = 2 To 6
Range("A" & i).Value = countries_visited(i - 1)
Range("B" & i).Value = population(i - 1)
Next i
End Sub
上述代码的输出如下所示:

动态数组子程序的输出。图片来源:作者。
2.4 声明数组以存储不同数据类型的变量
在上面的章节中,countries_visited数组声明用于存储String数据类型的变量,population数组声明用于存储Long数据类型的变量。与 Python 的 numpy 数组类似,在 Excel VBA 中也可以在数组中存储不同数据类型的变量。在这种情况下,数组必须声明为Variant类型。
在下面的示例中,声明了一个名为test的数组作为Variant。其大小通过ReDim语句指定为 3。test中的三个元素分别为String、Integer和Date类型。通过将变量传递给TypeName()函数,可以识别这些数据类型。
Option Base 0
Sub variant_test()
Dim test() As Variant
ReDim test(3)
test = Array("Germany population in million: ", 83, Date)
Dim i As Integer
For i = 0 To 2
Debug.Print "Element " & i & " of test array is: " & test(i) & " of type " & TypeName(test(i))
Next i
End Sub
输出如下所示:

variant_test 子程序的输出。图片由作者提供。
结论
数组是由一个或多个数据类型的值/变量组成的集合。每个变量都与数组中的特定索引号相关联。数组可以是单维的、二维的或多维的。在 Python 中,没有内建的数组支持,但可以使用 numpy 包创建数组。除了存储值,numpy 数组在进行矩阵运算时也非常有用。在 Excel VBA 中,数组在处理大型数据元素数据库时非常有用。在 Excel VBA 中,数组可以是静态的,即数组的大小是预定义的。或者,数组也可以是动态的,即数组的大小不是预定义的,但我们可以在使用过程中指定其大小,甚至在保持已存储元素的情况下调整数组大小。
这个 GitHub 仓库提供了 Python 笔记本、Excel 工作簿以及 VBA 脚本。感谢阅读!
艺术守护:保护你的在线图像免受生成式 AI 的侵害
你可以采取的步骤,以防止机器人爬取并使用你的艺术作品训练 AI 模型,如 Stable Diffusion、Midjourney 和 DALL-E
·发布于 Towards Data Science ·18 分钟阅读·2024 年 8 月 23 日
--

艺术守护,DALL-E 3 生成的 AI 图像,作者编辑
很多艺术家都对生成式 AI 感到担忧。他们担心网络爬虫未经许可和/或补偿,爬取他们网页上的图像用于训练 AI 模型。
我花了过去四周的时间研究这个话题,并找到了很多关于这些模型如何工作以及如何防止机器人窃取你作品的信息。
TL;DR — 你可以做的最简单的事情是通过你的托管服务的设置关闭 AI 爬虫访问你的网站,比如 SquareSpace 和 其他服务。更多关于这个和其他你可以采取的步骤的信息在下面。
在这篇文章中,我将提供一些关于文本生成图像的 AI 模型如何工作的背景知识,包括 Stable Diffusion、Midjourney 和 DALL-E 3。接下来,我将向你展示如何检测这些模型是否使用了你的图像进行训练。最后,我将提供一些建议和步骤,帮助你防止机器人窃取你的作品。
ASA 的警示:重新思考我们在研究中如何使用 p 值
理解 ASA 的声明,提升你的数据科学实践
·发表于 Towards Data Science ·8 分钟阅读·2024 年 6 月 11 日
--

图片由 Jason Dent 提供,来源于 Unsplash
引言:
在数据科学领域,计算 p 值是一项非常常见的任务,是假设检验的核心。无论你是在分析 A/B 测试结果、进行医学研究,还是评估市场趋势,p 值都是解读数据的重要指标。然而,误解这一指标可能导致错误的结论。近七年前,美国统计协会(ASA)发布了关于最佳实践和常见误用的声明,提醒人们在解读 p 值时应避免的错误做法。尽管如此,自那时以来,这一指标的误用并未显著减少。
假设检验:
为了希望能够全面理解这一声明,我将首先列出成功进行假设检验所需的步骤:
-
写下零假设(H0)和备择假设(H1)。
-
为了保证结果的严谨性,决定显著性水平(𝛂)——这是犯第一类错误的概率,即当零假设 H0 为真时却错误地拒绝它。显著性水平通常设定为𝛂 = 0.05,但这仅仅是一个…
ASCVIT V1:自动化统计计算、可视化和解释工具
轻松实现自动化数据分析:ASCVIT 工具的第一个版本,提供统计计算、可视化和解释功能
·发表于Towards Data Science ·阅读时间 30 分钟·2024 年 9 月 16 日
--
在我的学习过程中,我参加了一个数据科学研讨会,并首次接触到了统计编程语言 R。当时,我对其可能带来的应用潜力感到着迷。与此同时,得益于机器学习领域的进展,数据的统计评估变得更加简便。当然,这需要一定的技术理解,并且你需要知道某些方法的实际作用。还需要了解哪些数据或输入是某些方法能够正常工作或得出有意义结果的前提。在本文中,我将讨论开发本地应用的第一个版本(V1)的过程,该应用可以用于自动地将各种统计方法应用于任何数据集。这是一个开源项目,旨在用于教育和研究目的。
数据可以以.csv 或.xlsx 格式上传。应用的第一版本提供了一个通用的数据概览(数据预览、数据描述、数据点数量和变量分类)、描述性统计分析(直方图、箱型图、散点图矩阵和相关矩阵)、各种假设检验(t 检验、方差分析和卡方检验)、回归分析(线性回归、逻辑回归和多元回归)、时间序列分析,并支持各种聚类方法(k 均值、层次聚类和 DBSCAN)。该应用是使用 Python 框架 Streamlit 创建的。

ASCVIT V1 分析方法概览(图片来自作者)
由于代码的模块化结构,可以轻松实现进一步的统计程序。代码中有注释,这使得你更容易上手。当应用程序运行时,上传数据集后界面如下所示。

ASCVIT V1 Streamlit 应用程序(图片来自作者)
除了在前述各个领域的自动分析外,还集成了一个功能,能够自动分析统计记录的数值。“query_llm_via_cli” 功能使得通过 CLI(命令行界面)与 LLM 进行交流成为可能,使用的是 Ollama。
我已经在我发布的Towards Data Science 1 文章中解释了这一原理。在应用程序的第一个版本中,此功能仅限于描述性统计分析,但也可以扩展到其他分析上。具体来说,这意味着除了自动统计计算外,应用程序还会自动解读数据。

ASCVIT V1 CLI + OLLAMA + LMS(图片来自作者)
测试应用程序的数据集
如果你没有自己的数据,可以访问互联网上的多个网站,这些网站提供免费的数据集。用于开发和测试此应用程序的数据集来自Maven Analytics(许可证:ODC-BY) [2]。

MAVEN Analytics 数据游乐场(截图来自作者)
网站上有大量免费的数据集。我所查看的数据涉及从 1976 年到 2024 年间的视频游戏销售数据。具体来说,它记录了北美、日本、欧盟、非洲和其他地区的销售数据。总共有 64016 个游戏标题以及它们的评分、类型、平台等信息。
不幸的是,并非所有标题都有完整的信息。有很多 NaN(非数字)值,这在用 Python 分析时会导致问题或扭曲某些统计分析结果。下面我将简要讨论数据记录的清理过程。

MAVEN Analytics 的视频游戏销售数据(截图来自作者)
清理数据集
你可以在将数据集加载到应用程序之前,通过使用单独的脚本清理数据集,或者直接在应用程序中进行清理。在本文的应用程序中,我已在应用程序中直接实现了数据清理。如果你希望提前清理数据记录,可以使用以下脚本进行操作。
import pandas as pd
df = pd.read_csv('YOUR .CSV FILE')
df_cleaned = df.dropna()
df_cleaned.to_csv('cleaned_file.csv', index=False)
print("Lines with missing data have been removed and saved in 'cleaned_file.csv'.")
使用“pd.read_csv(‘.csv’)”读取文件,并将数据保存到 DataFrame“df”中。“df.dropna()”删除 DataFrame 中包含缺失值‘NaN’的所有行。清洗后的 DataFrame 保存在变量“df_cleaned”中。使用“df_cleaned.to_csv(‘cleaned_file.csv’, index=False)”将数据保存到新的.csv 文件中,且不保存行索引。接下来,输出成功完成的过程“print(…)”。该数据集清洗的代码可以在文件“clean.py”中找到,并且稍后也可以下载。接下来,让我们进入应用程序的实际代码部分。

ASCVIT V1 Python 代码片段(作者制作的 GIF)
所需的库和模块
使用此应用程序需要各种库和模块,这些库和模块结合在一起执行数据可视化、统计分析和机器学习任务。
import re
import subprocess
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
from matplotlib.patches import Patch
from scipy import stats
from sklearn.cluster import KMeans, AgglomerativeClustering, DBSCAN
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression, LogisticRegression
from statsmodels.stats.multicomp import pairwise_tukeyhsd
import streamlit as st
关于图表表示的说明。 有些使用“pyplot”(Matplotlib),而有些使用“plotly”(例如箱形图)。尽管使用“plotly”会产生更互动的图形,但并不意味着每种图表类型都需要使用它。最终,用户必须自行决定图表应如何显示。代码必须相应地进行调整。
可以通过 ZIP 目录中的 requirements.txt 文件使用以下命令安装应用程序所需的库。
pip install -r requirements.txt
数据概览
函数“display_data_info()”专门分析 Pandas DataFrame “df”并输出统计关键数据(均值、标准差等)“df.describe()”。DataFrame 的总数据点(行数)通过“len(df)”输出。同样,DataFrame 的数值型变量“numerical_columns”和分类变量“categorical_columns”(字符串类型)也会被输出。

ASCVIT V1 应用数据概览(作者制作的图片)
数据集总共有 64016 个数据点,其中 6 个数值型变量和 8 个分类变量。在开始进行某些统计处理之前,首先应该查看数据。在“数据概览”部分,您可以获取各种信息,以得出是否可以进行某些测试的结论。
例如,如果数据集中没有日期变量,则无法进行时间序列分析。如果没有二元变量,则无法进行逻辑回归。该应用程序已经设计为在变量类别不正确时询问或显示错误信息。接下来,让我们继续描述性统计部分。
def display_data_info(df):
st.write("**Data description:**")
st.write(df.describe())
st.write(f"**Number of data points:** {len(df)}")
numerical_columns = df.select_dtypes(include=np.number).columns.tolist()
categorical_columns = df.select_dtypes(include='object').columns.tolist()
st.write("**Numerical variables:** ", ", ".join(numerical_columns))
st.write("**Categorical variables:** ", ", ".join(categorical_columns))
return numerical_columns, categorical_columns
描述性统计
“descriptive_statistics()”函数允许用户选择不同的图表类型(直方图、箱型图、对角线图和相关矩阵)。通过“st.markdown(”“”…“”“)”可以对这些类型进行简要解释。然后需要选择一个或多个数值变量“selected_vars”。除了相关矩阵外,还可以选择是否应用对数缩放“apply_log_scale”。如果数据严重扭曲,应用对数缩放对变量尤其有用。使用相应的图表函数来创建可视化。
def descriptive_statistics(df, numerical_columns):
chart_type = st.selectbox("Select the diagram:", ["Histogram", "Boxplot", "Pairplot", "Correlation matrix"])
if chart_type == "Histogram":
st.markdown("""
**Histogram:**
A histogram shows the distribution of a numerical variable. It helps to
recognize how frequently certain values occur in the data and whether there are patterns, such as a normal distribution.
""")
elif chart_type == "Boxplot":
st.markdown("""
**Boxplot:**
A boxplot shows the distribution of a numerical variable through its quartiles.
It helps to identify outliers and visualize the dispersion of the data.
""")
elif chart_type == "Pairplot":
st.markdown("""
**Pairplot:**
A pairplot shows the relationships between different numerical variables through scatterplots.
It helps to identify possible relationships between variables.
""")
elif chart_type == "Correlation matrix":
st.markdown("""
*Correlation matrix:**
The correlation matrix shows the linear relationships between numerical variables.
A positive correlation indicates that high values in one variable also correlate with high values in another.
""")
if chart_type in ["Pairplot", "Correlation matrix"]:
selected_vars = st.multiselect("Select variables:", numerical_columns, default=numerical_columns)
else:
selected_vars = [st.selectbox("Select a variable:", numerical_columns)]
if chart_type != "Correlation matrix":
apply_log_scale = st.checkbox("Apply logarithmic scaling?", value=False)
else:
apply_log_scale = False
if st.button("Create diagram"):
if chart_type == "Histogram":
plot_histogram(df, selected_vars[0], apply_log_scale)
elif chart_type == "Boxplot":
plot_boxplot(df, selected_vars[0], apply_log_scale)
elif chart_type == "Pairplot":
plot_pairplot(df, selected_vars)
elif chart_type == "Correlation matrix":
plot_correlation_matrix(df, selected_vars)
直方图函数
“plot_histogram()”函数用于根据用户选择的变量创建直方图。开始时,所有 NaN 值都会从变量“cleaned_data”中去除。然后计算各种统计关键数字(均值“mean_value”、中位数“median_value”、标准差“std_value”、最小值“min_value”、最大值“max_value”以及标准差的上下限)。
由于数据是由 LLM(大语言模型)进行解读的,正如前面提到的,数据的离散度(与数据范围相关的标准差)和分布(均值与中位数的差异)被分类。直方图是通过“fix, ax = plt.subplots()”创建的,随后添加垂直线以增加信息量,最后通过“st.pyplot(fig)”显示直方图。如果数据是扭曲的或呈指数分布,可以激活对数缩放,进而调整直方图的 y 轴。此时,图表看起来如下所示[3]。

ASCVIT V1 直方图(pyplot)与 LLM 解释(作者提供的图片)
由于目前没有可以直接读取图形的模型,我们为 LLM 创建了一个通用的分析上下文。该上下文包含统计计算结果及附加说明或所需的解释。这意味着,作为 LLM 输入的上下文,可以应用于任何数据集。
具体而言,输入包括上述统计关键数字、分布的分析(对称、右偏或左偏)、扩展范围的估计(低、中或高)以及为 LLM 格式化的解释。根据所需的输出,可以单独调整和进一步指定上下文。

LLM 上下文示例(作者提供的图片)
分析结果通过“response = query_llm_via_cli(context)”发送到 LLM,之后在短时间间隔内,根据本地系统的性能,进行直方图的解释“st.write(f”Histogram Interpretation: {response}”)”。
def plot_histogram(df, variable, apply_log_scale):
cleaned_data = df[variable].dropna()
mean_value = cleaned_data.mean()
median_value = cleaned_data.median()
std_value = cleaned_data.std()
min_value = cleaned_data.min()
max_value = cleaned_data.max()
std_upper = mean_value + std_value
std_lower = max(0, mean_value - std_value)
concentration_range = (mean_value - std_value, mean_value + std_value)
if std_value < (max_value - min_value) / 6:
scatter = "low"
elif std_value < (max_value - min_value) / 3:
scatter = "moderate"
else:
scatter = "high"
if abs(mean_value - median_value) < 0.1 * std_value:
distribution = "symmetrical"
elif mean_value > median_value:
distribution = "right-skewed"
else:
distribution = "left-skewed"
fig, ax = plt.subplots()
ax.hist(cleaned_data, bins=30, edgecolor='black', alpha=0.7)
ax.axvline(mean_value, color='red', linestyle='--', label=f'Mean: {mean_value:.2f}')
ax.axvline(median_value, color='green', linestyle='-', label=f'Median: {median_value:.2f}')
ax.axvline(std_upper, color='blue', linestyle=':', label=f'+1 Std: {std_upper:.2f}')
ax.axvline(std_lower, color='blue', linestyle=':', label=f'-1 Std: {std_lower:.2f}')
ax.set_title(f"Histogram of {variable}")
ax.legend(title=f'Std-Deviation: {std_value:.2f}')
if apply_log_scale:
ax.set_yscale('log')
st.pyplot(fig)
context = (
f"Here is an analysis of the distribution of the variable '{variable}':\n"
f"- Mean: {mean_value:.2f}\n"
f"- Median: {median_value:.2f}\n"
f"- Standard deviation: {std_value:.2f}\n"
f"- Minimum: {min_value:.2f}\n"
f"- Maximum: {max_value:.2f}\n\n"
f"The distribution of the data shows a {distribution} distribution.\n"
f"The small difference between mean and median indicates a {distribution} distribution.\n"
f"A strong concentration of data points is observed between {concentration_range[0]:.2f} and {concentration_range[1]:.2f}.\n"
f"The scatter of the data is described as {scatter}, indicating a relatively tight distribution around the mean.\n\n"
f"Please analyze this distribution in the histogram, paying particular attention to symmetry, scatter, and potential deviations.\n"
f"Avoid calling the distribution normal unless there are explicit indications.\n"
f"Use only the names of the variables {variable} in the analysis!"
)
response = query_llm_via_cli(context)
st.write(f"**Histogram Interpretation:** {response}")
箱型图函数
“plot_boxplot()” 函数为用户选择的变量创建箱线图。基于变量,从 DataFrame 中计算统计关键数字,以便在图表中显示数据分布,并使用 LLM 进行集中趋势和离散度分析。除了均值、中位数和标准差之外,与直方图一样,还计算了下四分位数 “q1”、上四分位数 “q3”、四分位距 “iqr”(Q3 — Q1)以及基于四分位距(1.5 * IQR)的下须“lower_whisker”和 “upper_whisker”,这些数据也会用于箱线图。
后者有助于识别异常值以及其他超出某个值的数据参数。箱线图是通过 Plotly 库 “fig = px.box(df, y=variable)” 创建的,最后在应用程序中显示 “st.plotly_chart(fig)”。此图类型也可以使用对数尺度 [4]。图表的样式如下所示:

ASCVIT V1 箱线图(plotly),显示 critic_score 和 LLM 解读(作者提供的图片)
与直方图类似,也为箱线图创建了一个上下文,并将其传递给 LLM。统计关键数字以及有关潜在异常值的信息(即那些超出须值的数据)都会被传输。发送给 LLM 的文本已格式化,以便在这些指标上执行分析。
def plot_boxplot(df, variable, apply_log_scale):
mean_value = df[variable].mean()
median_value = df[variable].median()
std_value = df[variable].std()
q1 = df[variable].quantile(0.25)
q3 = df[variable].quantile(0.75)
iqr = q3 - q1
lower_whisker = max(df[variable].min(), q1 - 1.5 * iqr)
upper_whisker = min(df[variable].max(), q3 + 1.5 * iqr)
fig = px.box(df, y=variable)
fig.update_layout(title=f"Boxplot of {variable}")
if apply_log_scale:
fig.update_yaxes(type="log")
st.plotly_chart(fig)
context = (
f"Here is an analysis of the distribution of the variable '{variable}' based on a boxplot:\n"
f"- Mean: {mean_value:.2f}\n"
f"- Median: {median_value:.2f}\n"
f"- Standard deviation: {std_value:.2f}\n"
f"- Lower quartile (Q1): {q1:.2f}\n"
f"- Upper quartile (Q3): {q3:.2f}\n"
f"- Interquartile range (IQR): {iqr:.2f}\n"
f"- Potential outliers outside values from {lower_whisker:.2f} to {upper_whisker:.2f}.\n"
f"Please analyze this distribution and identify patterns or outliers.\n"
f"Use only the names of the variables {variable} in the analysis!"
)
response = query_llm_via_cli(context)
st.write(f"**Boxplot Interpretation:** {response}")
PAIRPLOT 函数
“plot_pairplot()” 函数根据用户选择的变量创建配对图。如果选择的变量少于两个,则会显示错误信息。会显示所有可能的变量组合的散点图,并绘制线性回归线,以显示变量之间的关系。为了使其生效,使用 “calculate_regression_stats” 函数计算所有可能的变量对的回归统计数据。所选变量 “selected_vars” 中的 NaN 值会被移除。
在这两个变量之间执行线性回归。这里,“var1” 是自变量 x,“var2” 是因变量 y。计算斜率和 R2 值 “r_squared”。结果以元组列表(var1, var2, slope, r_squared)的形式返回。如果选择了三个变量 [“A”, “B”, “C”],则函数会计算对(A, B)、(A, C)、(B, A)、(B, C)等的回归统计 [5]。
def calculate_regression_stats(df, selected_vars):
regression_results = []
for var1 in selected_vars:
for var2 in selected_vars:
if var1 != var2:
non_nan_data = df[[var1, var2]].dropna()
X = non_nan_data[[var1]].values.reshape(-1, 1)
y = non_nan_data[var2].values
if len(X) > 0 and len(y) > 0:
model = LinearRegression()
model.fit(X, y)
r_squared = model.score(X, y)
slope = model.coef_[0]
regression_results.append((var1, var2, slope, r_squared))
return regression_results
def plot_pairplot(df, selected_vars):
if len(selected_vars) > 1:
st.write("**Pairplot with regression lines:**")
pairplot_fig = sns.pairplot(df[selected_vars], kind='reg', diag_kind='kde',
plot_kws={'line_kws': {'color': 'red'}, 'scatter_kws': {'color': 'blue'}})
st.pyplot(pairplot_fig.fig)
corr_matrix = df[selected_vars].corr()
regression_stats = calculate_regression_stats(df, selected_vars)
correlation_list = "\n".join(
[f"The correlation between {var1} and {var2} is {corr_matrix.at[var1, var2]:.2f}."
for var1 in corr_matrix.columns for var2 in corr_matrix.columns if var1 != var2]
)
regression_list = "\n".join(
[f"The regression line for {var1} and {var2} has a slope of {slope:.2f} and an R² of {r_squared:.2f}."
for var1, var2, slope, r_squared in regression_stats]
)
context = (
f"Here are the correlation and regression analyses between the selected variables:\n"
f"{correlation_list}\n\n"
f"{regression_list}\n\n"
f"Please analyze these relationships in detail based solely on the numerical values (correlation and regression lines).\n"
f"Use only the names of the variables {selected_vars} in the analysis!"
)
response = query_llm_via_cli(context)
st.write(f"**Pairplot Interpretation:** {response}")
else:
st.error("At least two variables must be selected for a pairplot.")
“plot_pairplot()” 函数在对角线上使用 KDE(核密度估计)来显示每个单独变量的分布。与之前的函数一样,也为 LLM 创建了一个上下文进行分析。对于这种类型的图表,LLM 会接收来自相关性和回归分析的数据。文本已格式化,以便生成有关变量之间关系的详细解释。

ASCVIT V1 配对图(pyplot),展示 critic_score、na_sales、pal_sales 和 LLM 解释(作者提供的图像)
相关矩阵函数
“plot_correlation_matrix” 函数用于根据用户选择的变量创建相关矩阵 “if len(selected_vars) > 1”。如果只选择了一个变量,将显示错误消息。可视化以热图形式展示。矩阵单元格的颜色表示相关性的强度和方向。显著的相关性会自动发送到 LLM 进行进一步分析 “if var1 != var2 and abs(corr_matrix.at[var1, var2]) >= 0.5”。
选定变量之间的线性相关性以相关系数(值在 -1 和 +1 之间)“corr_matrix = df[selected_vars].cor()” 的形式展示。若值为 0,表示没有线性相关性。接近 -1 的值表示强烈的负相关,而接近 +1 的值表示强烈的正相关。变量对及其相关值会保存在 “high_correlations” [4] 中。

ASCVIT V1 相关矩阵(pyplot),包含所有变量及 LLM 解释(作者提供的图像)
为 LLM 创建了一个上下文。现有的显著相关性被分类为文本描述 “correlation_list”。若相关性较强(无论是正相关还是负相关),其值大于 0.7。如果值介于 0.5 和 0.7 之间,则表示中等相关性,而如果值仅略高于 0.5,则表示相关性较弱。如果未发现显著相关性,则会显示相应的消息。
def plot_correlation_matrix(df, selected_vars):
if len(selected_vars) > 1:
corr_matrix = df[selected_vars].corr()
fig, ax = plt.subplots()
sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', ax=ax)
ax.set_title("Correlation Matrix")
st.pyplot(fig)
high_correlations = []
for var1 in corr_matrix.columns:
for var2 in corr_matrix.columns:
if var1 != var2 and abs(corr_matrix.at[var1, var2]) >= 0.5:
if (var2, var1) not in [(v1, v2) for v1, v2, _ in high_correlations]:
high_correlations.append((var1, var2, corr_matrix.at[var1, var2]))
if high_correlations:
correlation_list = "\n".join([f"- {var1} and {var2} have a correlation value of {value:.2f}, "
f"indicating a {'strong' if abs(value) > 0.7 else 'moderate' if abs(value) > 0.5 else 'weak'} correlation."
for var1, var2, value in high_correlations])
context = (
f"Here is an analysis of the significant correlations between the selected variables in the correlation matrix:\n"
f"{correlation_list}\n\n"
f"Please analyze the correlations solely based on their strength and significance.\n"
f"Use only the names of the variables {selected_vars} in the analysis!"
f"Focus in detail on the statistical relationship and patterns."
)
response = query_llm_via_cli(context)
st.write(f"**Model Response:** {response}")
else:
st.write("**No significant correlations were found.**")
else:
st.write("**The correlation matrix cannot be displayed because fewer than two variables were selected.**")
在以下统计程序中,LLM 记录的各项关键数据的解释不可用。然而,基于先前的程序,独立实施应该不成问题。现在让我们转向各种假设检验。
假设检验选择
在当前版本中,可以进行三种不同的检验(t 检验、方差分析(ANOVA)和卡方检验) “test_type = st.selectbox()”。根据选择的检验类型,会出现简要的说明,解释其用途。根据应用领域,这些描述可以扩展或移除。t 检验用于比较两组的均值。方差分析(ANOVA)用于比较多于两组的均值。卡方检验用于检验两个分类变量之间的独立性。根据选择的检验,执行相应的函数。
T 检验
如果用户选择了 t 检验,他们必须选择一个组变量(类别型)“group_col” 和一个数值变量(数值型)“value_col”。组变量定义了要比较的两组,数值变量比较这两组的均值。一旦选择完成,必须在文本框 “st.text_input()” 中输入两组的名称 “group1” 和 “group2”。这两组应该出现在所选的类别变量中。这里也提供了对数标度 “apply_log_scale”,它应用于数值变量。当进行检验时,提取组的数据并输出数据点数量(移除 NaN 值后)。然后显示 t 统计量和 p 值。
第一个值表示两组之间均值差异与数据分布的相对关系。是否存在显著的组间差异由 p 值指示。如果 p 值小于 0.05,则说明差异显著。为了直观地突出显示两组的分布,“filtered_df = df[df[group_col].isin([group1, group2])]”,创建了一个箱形图 “fig, ax = plt.subplots()”。这里使用了“pyplot”,你也可以选择使用“plotly” [6]。

ASCVIT V1 箱形图,按类型(动作/射击)和评论分数分类(图像来源:作者)
在此示例中,“genre” 被选为组变量,“critic_score” 被选为数值变量。动作(组 1)和射击(组 2)被定义为比较组。该函数还计算组内是否存在显著的离群值。离群值被定义为超出上四分位数 1.5 倍四分位距的数据点 “outliers_group1/2”。最后,找到的离群值会被显示出来,以确认 t 检验的有效性。如果偏差过大,必须相应考虑,以便更好地分类检验结果的可靠性和可解释性。
def t_test(df, numerical_columns, categorical_columns):
group_col = st.selectbox("Choose the group variable:", categorical_columns)
value_col = st.selectbox("Choose the value variable:", numerical_columns)
group1 = st.text_input("Name of group 1:")
group2 = st.text_input("Name of group 2:")
apply_log_scale = st.checkbox("Apply logarithmic scaling?", value=False)
if st.button("Perform t-Test"):
group1_data = df[df[group_col] == group1][value_col]
group2_data = df[df[group_col] == group2][value_col]
initial_count_group1 = len(group1_data)
initial_count_group2 = len(group2_data)
group1_data = group1_data.dropna()
group2_data = group2_data.dropna()
remaining_count_group1 = len(group1_data)
remaining_count_group2 = len(group2_data)
st.write(f"**Group 1 ({group1}):** Total number of data points: {initial_count_group1}, without NaN: {remaining_count_group1}")
st.write(f"**Group 2 ({group2}):** Total number of data points: {initial_count_group2}, without NaN: {remaining_count_group2}")
if apply_log_scale:
group1_data = np.log1p(group1_data)
group2_data = np.log1p(group2_data)
if not group1_data.empty and not group2_data.empty:
t_stat, p_value = stats.ttest_ind(group1_data, group2_data)
st.markdown(f"**t-Statistic:** {t_stat}")
st.markdown(f"**p-Value:** {p_value}")
filtered_df = df[df[group_col].isin([group1, group2])]
fig, ax = plt.subplots()
sns.boxplot(x=filtered_df[group_col], y=filtered_df[value_col], ax=ax, palette="Set2")
ax.set_title(f"Boxplot for {group1} vs. {group2}")
if apply_log_scale:
ax.set_yscale('log')
st.pyplot(fig)
outliers_group1 = group1_data[group1_data > group1_data.quantile(0.75) + 1.5 * (group1_data.quantile(0.75) - group1_data.quantile(0.25))]
outliers_group2 = group2_data[group2_data > group2_data.quantile(0.75) + 1.5 * (group2_data.quantile(0.75) - group2_data.quantile(0.25))]
st.write("**Outlier Analysis:**")
if not outliers_group1.empty:
st.write(f"In group 1 ({group1}) there are {len(outliers_group1)} outliers.")
else:
st.write(f"In group 1 ({group1}) there are no significant outliers.")
if not outliers_group2.empty:
st.write(f"In group 2 ({group2}) there are {len(outliers_group2)} outliers.")
else:
st.write(f"In group 2 ({group2}) there are no significant outliers.")
else:
st.error("One or both groups contain no data after removing NaN values.")
ANOVA 检验
“anova_test()” 函数集成了执行 ANOVA 检验的选项。该检验用于检查多个组的均值是否存在显著差异。数据首先被清洗 “df_clean”。如果 ANOVA 检验显著,还会进行 Tukey 的 HSD 检验(诚实显著差异)。首先,再次定义一个组变量和一个数值变量。如果某个组的数据点少于 2 个,它将被排除 “valid_groups = group_sizes[group_sizes >= 2].index”。
如果调整后剩余的组少于两个,则会显示错误信息,且不执行该测试。ANOVA 测试计算 F 值和 p 值。F 值衡量组间的变异性与组内变异性的比值。p 值则指示组间均值差异是否显著。如果 p 值小于 0.05,则至少有一个组存在显著差异。为了可视化结果,使用 “pyplot” 创建箱型图 [7]。

ASCVIT V1 箱型图与控制台和评分(图像来源:作者)
如果 ANOVA 测试结果显著,则进行 Tukey 测试,以具体检验各组之间的差异。因此,ANOVA 测试并不显示哪些组之间存在差异。会创建一个图表,显示各组之间的配对均值差异及其置信区间 “st.pyplot(tukey.plot_simultaneous())”。

ASCVIT V1 Tukey 测试结果(图像来源:作者)
在图表下方,结果以表格形式显示,“st.dataframe(tukey_results_df, height=400)”。该表包含两组数据、均值差异 “meandiff”、调整后的 p 值 “p-adj”、置信区间以及是否可以拒绝原假设 “reject”(True = 显著,False = 不显著)。以下是置信区间的简要示例:对于 3DS 和 GBA 控制台,区间位于 -0.9319 和 -0.0061 之间,因此完全低于零。均值差异是显著的。
关键数字可以用于通过 LLM 解读结果。还可以选择将数据作为 .csv 文件下载,以便进一步进行统计分析(例如回归分析)[7]。
def anova_test(df, numerical_columns, categorical_columns):
group_col = st.selectbox("Choose the group variable:", categorical_columns)
value_col = st.selectbox("Choose the value variable:", numerical_columns)
if st.button("Perform ANOVA"):
df_clean = df[[group_col, value_col]].dropna()
group_sizes = df_clean.groupby(group_col).size()
valid_groups = group_sizes[group_sizes >= 2].index
df_filtered = df_clean[df_clean[group_col].isin(valid_groups)]
if len(valid_groups) < 2:
st.error("After removing small groups, there are not enough groups left for the ANOVA test.")
else:
grouped_data = [group[value_col].values for name, group in df_filtered.groupby(group_col)]
try:
anova_result = stats.f_oneway(*grouped_data)
st.markdown(f"**F-Value:** {anova_result.statistic}")
st.markdown(f"**p-Value:** {anova_result.pvalue}")
fig, ax = plt.subplots(figsize=(10, 6))
sns.boxplot(x=group_col, y=value_col, data=df_filtered, ax=ax)
plt.xticks(rotation=90)
st.pyplot(fig)
if anova_result.pvalue < 0.05:
st.write("The ANOVA test is significant. Tukey's HSD test will be performed.")
try:
tukey = pairwise_tukeyhsd(endog=df_filtered[value_col], groups=df_filtered[group_col], alpha=0.05)
st.pyplot(tukey.plot_simultaneous())
tukey_results_df = pd.DataFrame(data=tukey.summary().data[1:], columns=tukey.summary().data[0])
st.write("Results of the Tukey HSD test:")
st.dataframe(tukey_results_df, height=400)
csv = tukey_results_df.to_csv(index=False)
st.download_button(label="Download Tukey HSD results as CSV", data=csv, file_name='tukey_hsd_results.csv', mime='text/csv')
except Exception as e:
st.error(f"An error occurred during Tukey's HSD test: {str(e)}")
except ValueError as e:
st.error(f"An error occurred: {str(e)}.")
卡方检验
“chi_square_test()” 函数检查两个分类变量之间是否存在统计学显著关系。由于只能使用分类变量,因此不需要激活对数缩放选项。具体来说,它检查类别中观测频率是否相互独立,或者是否存在相关性。用户选择两个现有的分类变量。NaN 值会被移除,并且每个变量仅选择前 10 个最频繁的类别,以保持分析的可管理性,“value_counts().nlargest(10).index”。
创建了一个交叉表 “contingency_table”,它使用热图显示两个选定变量中类别组合的频率。如果交叉表无效(数据过少或仅有一个类别),则不执行该测试 [8]。

ASCVIT V1 带有类型和控制台的热图(图像来源:作者)
测试计算各种值。卡方值“chi2”确定观察频率与预期频率之间差异的程度。高值表明差异很大。与其他分析一样,p 值“p”显示差异是否显著。还指示了测试的自由度“dof”以及预期频率“expected”。
def chi_square_test(df, categorical_columns):
cat_var1 = st.selectbox("Choose the first group variable:", categorical_columns)
cat_var2 = st.selectbox("Choose the second group variable:", categorical_columns)
if st.button("Perform Chi-square test"):
df_clean = df[[cat_var1, cat_var2]].dropna()
top_cat_var1 = df_clean[cat_var1].value_counts().nlargest(10).index
top_cat_var2 = df_clean[cat_var2].value_counts().nlargest(10).index
df_filtered = df_clean[df_clean[cat_var1].isin(top_cat_var1) & df_clean[cat_var2].isin(top_cat_var2)]
try:
contingency_table = pd.crosstab(df_filtered[cat_var1], df_filtered[cat_var2])
if contingency_table.empty or contingency_table.shape[0] < 2 or contingency_table.shape[1] < 2:
st.error("The contingency table is invalid. Check the variables.")
else:
chi2, p, dof, expected = stats.chi2_contingency(contingency_table)
st.markdown(f"**Chi-square:** {chi2}")
st.markdown(f"**p-Value:** {p}")
st.write("**Heatmap of the contingency table:**")
fig, ax = plt.subplots(figsize=(12, 10)) # Larger display
sns.heatmap(contingency_table, annot=False, cmap="YlGnBu", ax=ax)
ax.set_title(f"Heatmap of the contingency table: {cat_var1} vs. {cat_var2} top 10")
plt.xticks(rotation=90)
st.pyplot(fig)
except ValueError as e:
st.error(f"An error occurred: {str(e)}.")
回归分析选择
可用回归分析(线性回归、逻辑回归和多元回归)的选择方式与选择各种假设检验相似。选择分析方法后,将显示简短的说明,并调用相应的函数。
线性回归分析
在“linear_regression()”函数的开始,创建一个由上传数据集中所有可用数值型变量构成的相关矩阵“corr_matrix = df[numerical_columns].corr()”。该矩阵旨在帮助用户理解变量之间的关系,以便识别适合回归分析的变量和不适合的变量(多重共线性)。
最后,选择因变量和一个或多个自变量。数据被清理后,为所有选择的自变量创建线性回归模型“model = LinearRegression()”。回归系数和截距被指定。在总体模型运行完成后,为每个自变量创建单独的线性回归模型,并通过散点图表示[9]。

ASCVIT V1 线性回归 pal_sales、na_sales 和 total_sales(图像由作者提供)
显示的回归系数表示当相应的自变量变化一个单位时,因变量的变化程度。假设所有其他变量保持不变。当所有自变量为零时,因变量所取的值由截距指示。
def linear_regression(df, numerical_columns):
st.write("**Correlation matrix of numerical variables:**")
corr_matrix = df[numerical_columns].corr()
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', ax=ax)
st.pyplot(fig)
dependent_var = st.selectbox("Choose the dependent variable:", numerical_columns)
independent_vars = st.multiselect("Choose the independent variables:", numerical_columns)
if independent_vars:
if st.button("Perform regression"):
X = df[independent_vars].dropna()
y = df[dependent_var].loc[X.index]
y = y.dropna()
X = X.loc[y.index]
if y.isnull().values.any():
st.error("The dependent variable still contains missing values. Please clean the data.")
else:
model = LinearRegression()
model.fit(X, y)
st.markdown("**Regression coefficients:**")
for var, coef in zip(independent_vars, model.coef_):
st.write(f"- {var}: {coef}")
st.write(f"**Intercept:** {model.intercept_}")
for var in independent_vars:
X_single_var = X[[var]] # Use only the current independent variable
model_single = LinearRegression()
model_single.fit(X_single_var, y)
fig, ax = plt.subplots()
ax.scatter(X[var], y, edgecolor='none', facecolors='blue', s=5, label='Data points')
ax.plot(X[var], model_single.predict(X_single_var), color='red', label='Regression line')
ax.set_xlabel(var)
ax.set_ylabel(dependent_var)
ax.set_title(f"{dependent_var} vs {var}")
ax.legend()
st.pyplot(fig)
逻辑回归分析
与其他功能一样,用户在开始时选择因变量和自变量。在此分析方法中,因变量必须是二元的(0/1)。为了演示该功能,我创建了一些与迄今为止使用的数据集无关的数据。或者,你也可以手动调整变量的值,只要类别数目不太多。如果选择了错误的变量,将显示相应的错误信息。
如果所有内容都正确定义,则执行逻辑回归,模型使用自变量来建模目标变量的概率。具体而言,这是指事件发生的概率。每个自变量都有相应的系数,并且逻辑函数会被可视化。这展示了当自变量变化时,目标结果(1 而非 0)的概率如何变化[10]。

ASCVIT V1 逻辑回归演示用途的虚构值(图片来源:作者)
在散点图中,红线代表目标结果的预测概率,即结果 1 发生的概率。自变量发生变化。“logistic_regression()” 函数非常适用于二分类问题,在该问题中,您希望根据多个因素预测事件的发生。
def logistic_regression(df, numerical_columns):
dependent_var = st.selectbox("Choose the dependent variable (binary):", numerical_columns)
independent_vars = st.multiselect("Choose the independent variables:", numerical_columns)
if independent_vars:
if st.button("Perform logistic regression"):
X = df[independent_vars].dropna()
y = df[dependent_var].loc[X.index].dropna()
X = X.loc[y.index]
unique_values = y.unique()
if len(unique_values) != 2:
st.error("The dependent variable must be binary (e.g., 0 and 1).")
else:
model = LogisticRegression()
model.fit(X, y)
st.write("**Logistic regression coefficients:**")
for var, coef in zip(independent_vars, model.coef_[0]):
st.write(f"- {var}: {coef}")
st.write(f"**Intercept:** {model.intercept_[0]}")
for var in independent_vars:
fig, ax = plt.subplots()
ax.scatter(X[var], y, label='Data points')
x_range = np.linspace(X[var].min(), X[var].max(), 300).reshape(-1, 1)
X_copy = pd.DataFrame(np.tile(X.mean().values, (300, 1)), columns=X.columns)
X_copy[var] = x_range.flatten() # Vary the current variable var
y_prob = model.predict_proba(X_copy)[:, 1]
ax.plot(x_range, y_prob, color='red', label='Logistic function')
ax.set_xlabel(var)
ax.set_ylabel(f'Probability ({dependent_var})')
ax.set_title(f'Logistic regression: {dependent_var} vs {var}')
ax.legend()
st.pyplot(fig)
多元回归分析
在多元回归分析中,用户必须选择多个因变量和一个或多个自变量。分析将检视因变量如何受到自变量的影响。选择变量后,NaN 值将被再次删除,必要时会显示错误信息。模型输出所有因变量的回归系数和截距。
为所有自变量和因变量的组合创建带回归线的散点图。这个功能使得可以同时分析多个目标变量,并且建立它们与多个预测因子之间的关系[11]。

ASCVIT V1 多元回归(plotly)示例(图片来源:作者)
def multivariate_regression(df, numerical_columns):
dependent_vars = st.multiselect("**Choose the dependent variables (multiple):**", numerical_columns)
independent_vars = st.multiselect("**Choose the independent variables:**", numerical_columns)
if dependent_vars and independent_vars:
if st.button("Perform multivariate regression"):
X = df[independent_vars].dropna()
Y = df[dependent_vars].loc[X.index].dropna()
X = X.loc[Y.index]
if X.shape[1] != len(independent_vars) or Y.shape[1] != len(dependent_vars):
st.error("The number of independent or dependent variables does not match.")
return
model = LinearRegression()
model.fit(X, Y)
st.write("**Multivariate regression coefficients:**")
for i, dep_var in enumerate(dependent_vars):
st.write(f"\nFor the dependent variable: **{dep_var}**")
st.write(f"Intercept: {model.intercept_[i]}")
for var, coef in zip(independent_vars, model.coef_[i]):
st.write(f"- {var}: {coef}")
for dep_var in dependent_vars:
for var in independent_vars:
fig, ax = plt.subplots()
ax.scatter(X[var], Y[dep_var], label='Data points')
x_range = np.linspace(X[var].min(), X[var].max(), 300).reshape(-1, 1)
X_copy = pd.DataFrame(np.tile(X.mean().values, (300, 1)), columns=X.columns)
X_copy[var] = x_range.flatten()
y_pred = model.predict(X_copy)
ax.plot(x_range, y_pred[:, dependent_vars.index(dep_var)], color='red', label='Regression line')
ax.set_xlabel(var)
ax.set_ylabel(dep_var)
ax.set_title(f'Multivariate regression: {dep_var} vs {var}')
ax.legend()
st.plotly_chart(fig)
时间序列分析
该分析方法进行的是时间上的分析。为此,将给定的时间序列按年分组,并计算并显示值的年均值。该分析需要一个时间变量;如果使用的数据集不包含该变量,则无法进行分析。在我选择的数据集中,包含了变量 “release_date”,它记录了每个游戏的发布日期。
所选时间变量被转换为日期格式 “df[time_var]”。如果数据点无效,则会将其转换为 NaN 值并删除 “df = df.dropna(subset=[time_var])”。然后,数据按年分组 “df[‘year’] = df[time_var].dt.year”,并计算指定值变量“value_var”的年均值 “yearly_avg”。计算值变量的最小和最大年均值以及所有数据点的总体均值 “overall_avg”。接着,以折线图的形式显示每年的值变量年均值。总体均值集成在水平线上。为了提高可读性,值会交替显示在数据点的上下方[12]。

ASCVIT V1 时间序列分析与年份和 critic_score(图片来自作者)
重要的统计关键指标显示在图表下方,可以像描述性分析一样使用 LLM 轻松解释。具体来说,显示了标准差、方差以及值变量和年份的最小值和最大值。“perform_time_series_analysis()”函数适用于分析数据序列中的时间趋势。这可以对时间的变异性进行初步分析。
def perform_time_series_analysis(df, time_var, value_var):
df[time_var] = pd.to_datetime(df[time_var], errors='coerce')
df = df.dropna(subset=[time_var])
if df.empty:
st.error("**Error:** The time variable has an incorrect format.")
else:
df['year'] = df[time_var].dt.year
yearly_avg = df.groupby('year')[value_var].mean().reset_index()
y_min = yearly_avg[value_var].min()
y_max = yearly_avg[value_var].max()
y_range = y_max - y_min
y_buffer = y_range * 0.05
overall_avg = df[value_var].mean()
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(yearly_avg['year'], yearly_avg[value_var], marker='o', label='Yearly average')
ax.axhline(overall_avg, color='red', linestyle='--', label=f'Overall average: {overall_avg:.2f}')
ax.set_title(f'Average {value_var} per year')
ax.set_xlabel('Year')
ax.set_ylabel(f'Average {value_var}')
ax.set_ylim(y_min - y_buffer, y_max + y_buffer)
ax.text(yearly_avg['year'].max() - (yearly_avg['year'].max() - yearly_avg['year'].min()) * 0.05,
overall_avg + y_buffer,
f'{overall_avg:.2f}', color='red', ha='right', va='center')
for i in range(len(yearly_avg)):
if i % 2 == 0:
ax.text(yearly_avg['year'][i], yearly_avg[value_var][i] + y_buffer/2,
f'{yearly_avg[value_var][i]:.2f}', color='blue', ha='center', va='bottom')
else:
ax.text(yearly_avg['year'][i], yearly_avg[value_var][i] - y_buffer/2,
f'{yearly_avg[value_var][i]:.2f}', color='blue', ha='center', va='top')
plt.xticks(rotation=45)
ax.legend()
st.pyplot(fig)
st.write(f"**Standard deviation:** {df[value_var].std():.2f}")
st.write(f"**Variance:** {df[value_var].var():.2f}")
st.write(f"**Minimum {value_var}:** {y_min:.2f} in year {yearly_avg.loc[yearly_avg[value_var].idxmin(), 'year']}")
st.write(f"**Maximum {value_var}:** {y_max:.2f} in year {yearly_avg.loc[yearly_avg[value_var].idxmax(), 'year']}")
聚类方法选择
与假设检验和回归分析类似,在聚类方法领域也有多种选项可供选择。选择函数的结构与其他方法相似。选择一种方法并执行相应的函数,同时也会显示该方法的简要说明。根据不同的方法,必须定义簇的数量。对于 k-Means 和层次聚类,可以定义最多 10 个簇。对于 DBSCAN,则需要查询半径“eps”和每个簇的最小点数“min_samples”。每种方法必须选择至少两个数值变量。
k-Means 聚类
k-Means 算法将数据分为“n_clusters”个簇。数据点的分组方式是使得簇内数据点之间的距离最小化。簇的数量由用户确定。根据该数量,算法计算出每个数据点属于哪个簇。结果会发送到“visualize_clusters()”函数进行可视化[13]。
def perform_kmeans(X, n_clusters):
kmeans = KMeans(n_clusters=n_clusters, random_state=42)
X['Cluster'] = kmeans.fit_predict(X)
visualize_clusters(X, 'k-Means Clustering')

ASCVIT V1 k-Means 聚类与 PCA 降维(图片来自作者)
层次聚类
在这里创建了一个簇的层次结构,可以使用聚合或分割方法。该函数使用的是聚合聚类方法,其中每个数据点最初被视为一个独立的簇,然后依次合并。簇的数量由用户确定,算法根据数量进行数据划分。与 k-Means 相同的函数用于可视化“visualize_clusters(X, ‘Hierarchical Clustering’)” [14]。
def perform_hierarchical_clustering(X, n_clusters):
hierarchical_clustering = AgglomerativeClustering(n_clusters=n_clusters)
X['Cluster'] = hierarchical_clustering.fit_predict(X)
visualize_clusters(X, 'Hierarchical Clustering')

ASCVIT V1 层次聚类与 PCA 降维(图片来自作者)
DBSCAN 聚类
通过这种方法,数据点根据其周围环境的密度进行分组。该方法非常适合检测异常值(噪声)并发现任何形状的簇。在这里,用户不指定簇的数量,而是指定两个点之间的最大距离“eps”,超过该距离就认为它们是邻居。同时,还定义了簇中出现的最小点数“min_samples”。可视化也通过“visualize_clusters()”函数生成[15]。
def perform_dbscan(X, eps, min_samples):
dbscan = DBSCAN(eps=eps, min_samples=min_samples)
X['Cluster'] = dbscan.fit_predict(X)
visualize_clusters(X, 'DBSCAN Clustering')

ASCVIT V1 DBSCAN 聚类与 PCA 降维(图片来自作者)
聚类可视化
三种不同聚类方法的结果通过“visualize_clusters”函数进行可视化,采用主成分分析(PCA)。数据的维度通过 PCA 降到两个分量“n_components”,以便能够显示聚类结果。检查数据点和变量“num_samples”是否足够;如果不足,则会显示错误信息。聚类结果通过散点图进行可视化,图中显示了前两个 PCA 分量中的数据点。
聚类结果以不同颜色显示“cmap=‘tab10’”。在图表中,轴标签“ax.set_x/ylabel”和图例“legend_labels”已做调整,以便更好地解读。数据点的大小“s”以及透明度“alpha”也进行了调整,以提高可见性。在 DBSCAN 中,离群点会被自动分配到聚类-1。聚类的每个变量的平均值以表格形式显示在可视化图形下方“st.dataframe(cluster_means)”。
def visualize_clusters(X, title):
num_samples, num_features = X.shape
n_components = min(num_samples, num_features, 2)
if n_components < 2:
st.error("Not enough data points or variables to perform PCA.")
return
pca = PCA(n_components=n_components)
try:
X_pca = pca.fit_transform(X.drop(columns=['Cluster']))
fig, ax = plt.subplots(figsize=(10, 6))
scatter = ax.scatter(X_pca[:, 0], X_pca[:, 1], c=X['Cluster'], cmap='tab10', s=25, alpha=0.4)
ax.set_title(title)
ax.set_xlabel(f'PCA 1' if n_components >= 1 else '')
ax.set_ylabel(f'PCA 2' if n_components == 2 else '')
cluster_counts = X['Cluster'].value_counts()
legend_labels = [f"Cluster {int(cluster)} ({count} points)" for cluster, count in cluster_counts.items()]
legend1 = ax.legend(handles=scatter.legend_elements()[0], labels=legend_labels)
ax.add_artist(legend1)
st.pyplot(fig)
st.write(f"**Average values per cluster:**")
cluster_means = X.groupby('Cluster').mean()
st.dataframe(cluster_means)
except ValueError as e:
st.error(f"**Error:** Not enough variables were selected.")
与 LLM 的通信
在应用程序的第一个版本中,统计计算的输出仅在描述性区域进行分析。关键数字通过“query_llm_via_cli”函数进行解释。具体来说,该函数用于通过命令行(CLI)与 LLM 进行通信。为此,使用 Python 模块“subprocess”通过命令行启动进程。LLM 通过命令[“ollama”, “run”, “llama3.1”]启动。输入存储在“stdin”中,输出存储在“stout”中。
错误和警告被存储在“stderr”中,虽然希望不会出现这些问题。输入通过“process.communicate”发送到模型。具体来说,创建的“context”被发送到函数与 LLM 进行通信。如果模型没有回应,包含超时机制“timeout=40”,它将在 40 秒后停止执行。根据所用系统的计算能力,通常模型应更早返回响应。模型的响应被清理并传递到“extract_relevant_answer”,以提取相关信息1。
def query_llm_via_cli(input_text):
"""Sends the question and context to the LLM and receives a response"""
try:
process = subprocess.Popen(
["ollama", "run", "llama3.1"],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
encoding='utf-8',
errors='ignore',
bufsize=1
)
stdout, stderr = process.communicate(input=f"{input_text}\n", timeout=40)
if process.returncode != 0:
return f"Error in the model request: {stderr.strip()}"
response = re.sub(r'\x1b\[.*?m', '', stdout)
return extract_relevant_answer(response)
except subprocess.TimeoutExpired:
process.kill()
return "Timeout for the model request"
except Exception as e:
return f"An unexpected error has occurred: {str(e)}"
def extract_relevant_answer(full_response):
response_lines = full_response.splitlines()
if response_lines:
return "\n".join(response_lines).strip()
return "No answer received"
应用程序的主函数
应用程序的结构由“main()”函数定义。标题通过“st.title()”设置,侧边栏用于上传 CSV 或 Excel 格式的数据集“uploaded_file”。上传文件后,应用程序会分析文件并提取数值和类别变量。在这里和许多其他情况下,Streamlit 使用“session_state”存储与分析方法选择相关的某些参数。
变量“numerical_columns”和“categorical_columns”将在上传新数据集后更新。一旦数据可用,用户可以从多种分析方法中选择。选择方法后,会显示该方法,并在相应的变量定义后进行操作。主功能控制应用程序的互动统计分析。
自定义选项
如前所述,由于代码的模块化结构,该应用程序可以扩展以包括其他分析方法。使用 LLM 解读统计关键数据的功能也可以转移到其他方法。当前使用的是 Meta 的 Llama3.1 (8B),但也可以使用 Ollama 的另一款 LLM(例如 Mistral)。此时,“query_llm_via_cli”函数中的命令必须进行相应调整。
根据可用的计算资源,也可以使用更多参数的模型。图表的设计可以进一步优化,传输的上下文也可以加以改进,以提升 LLM 的输出效果。或者,你也可以创建一个新的模型文件,调整 LLM 的某些参数(例如参数),从而改善数据的解读。
ASCVIT V1 PYTHON 脚本 [GITHUB]
应用程序的代码可以从以下GitHub 仓库下载。应用程序可以在相应目录中通过以下命令启动:
Streamlit run app.py
结论
在本文中,我展示了如何使用 Streamlit 创建一个应用程序,利用多种方法分析数据集。我还展示了如何将 LLM 集成到该应用程序中,从而带来实际的附加值。数据不仅会自动可视化并输出统计参数,还会进行分类。该应用程序具有广泛的发展潜力。我在倒数第二部分列出了一些建议。希望你在使用和自定义应用程序时玩得开心。

你最多可以拍手 50 次!
1 Pietrusky, S. (2024 年 8 月 21 日). 如何在不使用专有模型的情况下与 PDF 文件交互:CLI, Streamlit, Ollama. Towards Data Science. URL
[2] Maven Analytics. (2024 年 6 月 10 日). 数据游乐场,视频游戏销售. URL
[3] Hastie, T., Tibshirani, R., & Friedman, J. (2009). 统计学习的元素:数据挖掘、推断与预测(第 2 版)。斯坦福大学。URL
[4] Bruce, P., Bruce, A., & Gedeck, P. (2021). 数据科学家的实用统计学:50+个 R 和 Python 中的核心概念(第 2 版)。O’Reilly。
[5] VanderPlas J. (2017). Python 数据科学手册:处理数据的必备工具. O’Reilly。[URL]
[6] Fahrmeir, L., Künstler, R., Pigeot, I., & Tutz, G. (2016). 统计学:数据分析之路(第 8 版)。Springer.
[7] Montgomery, D. C. (2012). 实验设计与分析(第 8 版)。Wiley. URL
[8] Moore, D. S., McCabe, G. P., Craig, B. A., & Duckworth, W. M. (2021). 统计学实践导论(第 10 版)。W. H. Freeman.
[9] Montgomery, D. C., Peck, E. A., & Vining, G. G. (2012). 线性回归分析导论(第 5 版)。Wiley. URL
[10] Hosmer, D. W., Lemeshow, S., & Sturdivant, R. X. (2013). 应用逻辑回归(第 3 版)。Wiley.
[11] Johnson, R. A., & Wichern, D. W. (2007). 应用多元统计分析(第 6 版)。Pearson. URL
[12] Box, G. E. P., Jenkins, G. M., Reinsel, G. C., & Ljung, G. M. (2015). 时间序列分析:预测与控制(第 5 版)。Wiley. URL
[13] Witten, I. H., & Frank, E. (2005). 数据挖掘:实用的机器学习工具与技术(第 2 版)。Morgan Kaufmann. URL
[14] Everitt, B. S., Landau, S., Leese, M., & Stahl, D. (2011). 聚类分析(第 5 版)。Wiley. URL
[15] Aggarwal, C. C., & Reddy, C. K. (2014). 数据聚类:算法与应用。CRC Press URL

仅为缩略图(图片来源:作者)
不要问人工智能能为你做什么 — 问问你能与人工智能一起实现什么
·发布于Towards Data Science ·9 分钟阅读·2024 年 8 月 8 日
--

图片来自BoliviaInteligente,发布于Unsplash
新的前沿
在过去的一年半里,我一直在告诉我认识的每个人有关人工智能的潜力,特别是大型语言模型(LLM)。是时候让每个人,无论他们的技术背景如何,都学习 LLM 的基础知识,并了解如何高效地使用它们。
在 1960 年代,我们有了登月计划。今天,我们拥有广阔而未被探索的人工智能领域。这一次,不是去插旗或者留下脚印,而是关于拓展人类潜力。
我真心认为,在这个新时代,我们每个人都是开拓者。
挑战不仅仅是跟上新设备的步伐,而是要变得擅长于与人工智能沟通。
每个人的旅程
忘掉那些借口,比如“我不会编程”或者“太复杂了”。
竞技场是平等的,只要稍加努力,任何人都能从人工智能中受益。
只需每周花一小时自己进行测试,并弄清楚如何使用这些工具,我敢打赌,你很可能会感到有必要改进你与这些“东西”沟通的方式,以便它们能给你提供更好的结果!
作为数据科学家个人贡献者,如何请求反馈
接收清晰而有用的反馈。摒弃笼统的问题。这里有超过 60 个示例问题供你使用。
·发表于Towards Data Science ·14 分钟阅读·2024 年 9 月 18 日
--

图片由Marsha Reid提供,来源于Unsplash
请求反馈可能会让人感到困难,甚至令人畏惧。有很多因素可能影响你如何请求反馈:公司文化、个人恐惧、缺乏参与的经理、过去的反馈无效等等……
尽管这些情况可能是事实,但它们不应该阻止你请求反馈。作为一名带领团队的人,我必须定期提供大量反馈,而且每次绩效评估时也需要提供反馈。在之前的一篇文章中,我描述了如何使用 SBIN 框架提供反馈。我希望依赖这个框架,为你提供一组模板/指导问题,帮助你更有针对性地获得反馈,并从中得到更高质量的回应。
PS 1: 虽然这篇博客的标题是面向个人贡献者的,但数据科学经理也许会发现其中的内容对他们传递给团队非常有价值。
PS 2: 本文的灵感来源于与 Stefano Franco的合作。作为一名非常周到的团队负责人,他通过开发标准化团队工具帮助他人。
有限区域内的傅里叶级数的各种形式
选择在边界处表现良好的那个
·发表于 Towards Data Science ·阅读时间:8 分钟·2024 年 4 月 22 日
--

如果你查阅傅里叶分析的历史,你会发现让·巴普蒂斯特·约瑟夫·傅里叶在研究热流问题时,正式化了这套以他名字命名的级数。
傅里叶级数将周期信号表示为正弦波的和,这些正弦波的频率是基本频率的整数倍。
我们直观地知道,导热介质中的热点会向四面八方传播热量,直到温度均匀分布为止。在这个现象中,无论在空间还是时间上,都没有可见的振荡行为。那么为什么要引入一系列正弦波呢?
初始温度分布、控制微分方程和边界条件决定了一维导热介质(如一根薄金属棒)问题中温度函数 u(x, t)的演变。结果显示,初始温度分布的空间频率分量会随着时间的推移被一个衰减的指数函数所抑制,且其指数因子增长速率与空间频率的平方成正比。换句话说,初始温度分布中的高频部分衰减速度远快于低频部分,这也解释了温度分布的平滑现象。
在这个故事中,我们将回顾有限区间上定义的傅里叶级数的基础知识。我们将问题构造为,使得得到的傅里叶级数在区间的边界处具有一些理想的性质。当我们将傅里叶级数应用于解决涉及具有边界约束的微分方程问题时,这种方法将带来好处。
傅里叶级数:表示周期函数的工具
傅里叶级数可以逼近周期函数。假设 g(x)是一个周期为 2L 的周期函数。
为什么是周期 2L?
我们关注的是定义在有限区间[0, L]上的函数。我们可以构造一个周期为 2L 的周期函数 g(x),其周期函数 g(x)的定义域为[0, L],并通过在函数两端加上适当的填充,以获得理想的性质。稍后我们会回到这一点。
假设傅里叶级数存在,我们可以将 g(x)写成:

举个例子,我们考虑以下周期函数 g(x),其周期为 2L = 0.6:

图 1:周期函数 g(x)。图片由作者提供。
通过应用方程(2)、(3)、(4)并使用辛普森数值积分,可以得到 a₀、aₙ和 bₙ的以下值:

这些值,即傅里叶系数,使我们能够通过方程(1)构建 g(x)的近似值。我们在求和中包含的项数越多,近似值就越精确。图 2 展示了通过方程(1)中不同项数的求和得到的几种近似。

图 2:通过傅里叶级数中的不同项数重构 g(x)。图片由作者提供。
我们已经可以提出一些观察结果:
-
信号中的有限不连续性是可以容忍的,但它们会在重构的近似中产生波动。我们称这些不连续点附近的振荡现象为吉布斯现象。
-
傅里叶级数是一个无限项的和,但我们可以截断求和,并仍然得到原始函数的合理近似。
-
原始信号可能是离散点的样本。傅里叶级数可以在 x 轴的任何位置插值该函数。
定义在有限区间上的函数
在工程问题中,我们经常遇到定义在有限区间上的函数。例如,在导热介质的一维温度分布中,温度函数定义在区间[0, L]上,其中 L 是薄金属棒的长度。那么,在这种情况下,如何使用傅里叶级数呢?
为了解答这个问题,我们首先需要认识到,任何在范围[0, L]内与目标函数 f(x)一致的周期函数 g(x),都是 f(x)的傅里叶级数表示的有效候选。毕竟,我们不关心傅里叶级数在[0, L]范围之外的行为。
f(x)的天真周期复制
构建 g(x)的最直接方法是将 f(x)在区间[-L, 0]内复制,如图 3 所示:

图 3:f(x) 在[0, 0.3]范围内定义,并在范围[-0.3, 0]内复制,构建周期为 0.6 的周期函数 g(x)。图片来源:作者。
对 f(x)的天真周期复制进行傅里叶积分,得到方程(5)到(7):

通过将(5)、(6)、(7)代入方程(1)中的 f(x)(参见图 3),我们得到了图 4 所示的傅里叶级数重构:

图 4:图 3 中的 f(x)(原始信号)和傅里叶级数,显示为信号重构。图片来源:作者。
傅里叶级数与原始信号非常接近,除了在范围边界处,重构会出现振荡和跳跃。由于我们明确构造了一个周期为 L 的周期信号,傅里叶级数将 x=0 和 x=L 处的过渡解释为有限的间断点。
傅里叶级数允许有限的间断性,但吉布斯现象会在间断点周围恶化重构效果。
对于许多工程应用来说,这是一个问题。例如,在薄金属棒的热传导问题中,金属棒两端(即边界条件)发生的情况是问题描述的一个内在部分。我们可以假设有一根孤立的金属棒,这意味着两端的温度梯度必须为 0。或者,我们可以假设在 x=0 和 x=L 处有任意的设定温度。在这些常见场景中,我们不能使用天真周期复制 f(x)的方法,因为吉布斯现象会在范围的两端破坏信号。
偶数半范围扩展
我们可以将 f(x)复制为图 5 中的形式,也可以在范围[-L, 0]内使用 f(x)的翻转版本:

图 5:g(x) = f(-x) 在范围[-L, 0]内。图片来源:作者。
这种方法消除了 x=0 和 x=L 处的间断性。f(x)的偶数半范围扩展的傅里叶积分得到方程(8)到(10):

图 6 展示了 f(x)的傅里叶级数重构:

图 6:原始信号及其通过偶数半范围扩展的重构。图片来源:作者。
偶数半范围展开的一个特性是,由于 g(x)是偶函数,所有的 bₙ系数(参见方程(10))为 0,因此它的傅里叶级数仅由余弦项组成。因此,傅里叶级数的导数在 x=0 和 x=L 处为零。你可以通过对方程(1)关于 x 进行求导,且将所有 bₙ项设置为 0 来验证这一点。
这是我们在某些情况下所需要的,例如金属杆被隔离,没有热量泄漏到端部。
奇数半范围展开
如果我们改为创建一个奇函数会怎么样呢?这可以通过将 f(x)的旋转版本粘贴到区间[-L, 0]中来实现,如图 7 所示:

图 7:g(x) = -f(-x)在区间[-L, 0]。图片由作者提供。
f(x)的奇数半范围展开的傅里叶积分得到方程(11)到(13):

图 8 显示了 f(x)的傅里叶级数重构:

图 8:原始信号及其通过奇数半范围展开的重构。图片由作者提供。
由于 g(x)是奇函数,傅里叶级数仅由正弦项组成。因此,傅里叶级数在 x=0 和 x=L 处为零。这个特性可以在模拟振动吉他弦的形状时加以利用。吉他弦在 x=0 和 x=L 处的高度被限制为 0,因此我们自然会使用奇数半展开来模拟初始条件。

图片由Rio Lecatompessy提供,来源于Unsplash
偶数四分之一范围展开
我们还可以更加创新,设计一个周期为 4L 的周期性函数。如果我们希望在 x=0 处的导数为 0,并且在 x=L 处的值和导数都平滑过渡,可以在[L, 2L]区间附加 f(x)的旋转副本并使该函数为偶函数。图 9 展示了一个例子:

图 9:g(x) = 2f(L) - f(2L+x)在区间[-2L, -L];f(-x)在区间[-L, 0];f(x)在区间[0, L];2f(L)-f(2L-x)在区间[L, 2L]。图片由作者提供。
f(x)的偶数四分之一范围展开的傅里叶积分得到方程(14)到(16):

图 10 显示了 f(x)的傅里叶级数重构:

图 10:原始信号和通过偶数四分之一范围展开的傅里叶级数重构。图片由作者提供。
尽管从图中无法看出,傅里叶级数的导数在 x=0 处为 0,并且在 x=L 处与原始信号相同。
奇数四分之一范围展开
我们考虑的最后一个情况是,当我们希望 x=0 处的值为 0,x=L 处的导数为 0 时。我们通过在[L, 2L]范围内附加 f(x)的翻转版本来构建 g(x),并使该函数成为奇函数。

图 11:g(x) = -f(x+2L)在范围[-2L, L]内;-f(-x)在范围[-L, 0]内;f(x)在范围[0, L]内;f(2L-x)在范围[L, 2L]内。图片由作者提供。
f(x)的奇数四分之一范围展开的傅里叶积分得到方程(17)到(19):

图 12 显示了 f(x)的傅里叶级数重建:

图 12:原始信号与傅里叶级数重建,采用奇数四分之一范围展开。图片由作者提供。
我们可以看到,重建在 x=0 时通过 0。即使原始信号的导数不是零,x=L 处的导数为零。
结论
我们考虑了为定义在有限区间[0, L]上的信号 f(x)找到合适的傅里叶级数展开的问题。傅里叶级数适用于周期性函数,因此我们必须构建一个与 f(x)在定义域上匹配的周期性函数。我们观察到定义周期性函数 g(x)的四种方法。每种方法都能确保在范围边界处具有特定的属性:
-
偶数半范围展开:傅里叶级数在 x=0 和 x=L 处的导数为 0。
-
奇数半范围展开:傅里叶级数在 x=0 和 x=L 处的值为 0。
-
偶数四分之一范围展开:傅里叶级数在 x=0 处的导数为 0,并且在 x=L 处具有平滑的值和导数。
-
奇数四分之一范围展开:傅里叶级数在 x=0 处的值为 0,x=L 处的导数为 0。
在未来的故事中,我们将研究热量如何在细长的金属棒中传递。解决方案涉及将初始温度分布转换为傅里叶级数。我们将观察到,傅里叶级数展开的选择自然由边界条件决定(例如,金属棒在 x=0 处是隔离的,在 x=L 处保持固定温度)。我们在这篇文章中创建的那些看似任意的周期性函数,将突然变得有意义!
参考文献
(R1) 《高级工程数学》,Erwin Kreyszig,John Wiley & Sons,1988
使用 Celery、Redis 和 Florence 2 进行异步机器学习推理
一份简单的教程,帮助你入门异步机器学习推理
·发表于 Towards Data Science ·阅读时间 6 分钟·2024 年 7 月 19 日
--

图片由 Fabien BELLANGER 提供,来源:Unsplash
大多数机器学习服务教程都专注于实时同步服务,这种方式可以即时响应预测请求。然而,这种方法在面对流量激增时可能会遇到困难,并且不适合长时间运行的任务。它还需要更强大的机器来快速响应,如果客户端或服务器发生故障,预测结果通常会丢失。
在这篇博客中,我们将演示如何使用 Celery 和 Redis 将机器学习模型作为异步工作者运行。我们将使用 Florence 2 基础模型,这是一个以其卓越表现而闻名的视觉语言模型。本教程将提供一个最小但功能完善的示例,您可以根据自己的用例进行调整和扩展。
我们解决方案的核心基于 Celery,这是一个 Python 库,它为我们实现了这种客户端/工作者逻辑。它允许我们将计算任务分配到多个工作者,提高机器学习推理应用在高负载和不可预测负载下的可扩展性。
该过程如下所示:
- 客户端将任务及一些参数提交到由中介(我们示例中是 Redis)管理的队列中。
视觉变换器的注意力机制解析
视觉变换器解析系列
计算机视觉中注意力层背后的数学与代码
·发表于 Towards Data Science ·12 分钟阅读·2024 年 2 月 27 日
--
自 2017 年《Attention is All You Need》¹提出以来,变换器(transformers)已经成为自然语言处理(NLP)领域的前沿技术。2021 年,《An Image is Worth 16x16 Words》²成功地将变换器应用于计算机视觉任务。从那时起,许多基于变换器的架构被提出用于计算机视觉。
本文深入探讨了在计算机视觉领域,注意力层是如何工作的。我们将讨论单头和多头注意力机制。文章中包含了注意力层的开源代码,并对其底层数学原理进行了解释。代码使用的是 PyTorch Python 包。

图片来源:Mitchell Luo 在 Unsplash
本文是一个系列文章的一部分,深入探讨视觉变换器的内部工作原理。每篇文章也可以作为一个带有可执行代码的 Jupyter Notebook 进行阅读。该系列的其他文章包括:
-
视觉变换器解析→ Jupyter Notebook
-
视觉变换器的注意力机制解析
-
视觉变换器位置嵌入解释
-
Tokens-to-Token 视觉变换器解析
目录
-
一般注意力机制
-
单头注意力
-
多头注意力
-
结论
— 进一步阅读
— 引用
一般注意力机制
对于 NLP 应用,注意力通常被描述为句子中单词(tokens)之间的关系。在计算机视觉应用中,注意力关注的是图像中 patches(tokens)之间的关系。
有多种方法可以将一张图片分解成一系列的 tokens。原始的 ViT²将图像分割成patches,然后将这些patches展平为tokens;有关这种patch tokenization的更详细解释,请参阅视觉变换器文章。Tokens-to-Token ViT³开发了一种更复杂的方法,从图像中创建 tokens;关于该方法的更多信息可以在 Tokens-To-Token ViT 文章中找到。
本文将假设 tokens 作为输入,逐步通过一个注意力层。在变换器的开始,tokens 将代表输入图像中的 patches。然而,随着深层注意力层的计算,tokens 将被前面的层修改,从而移除了直接的表示关系。
本文探讨了在Attention is All You Need¹中定义的点积(即乘法)注意力机制。这与在An Image is Worth 16x16 Words²和Tokens-to-Token ViT³等衍生作品中使用的注意力机制相同。代码基于公开的Tokens-to-Token ViT³的 GitHub 代码,经过了一些修改。源代码的更改包括但不限于将两个注意力模块合并为一个,并实现了多头注意力。
完整的注意力模块如下所示:
class Attention(nn.Module):
def __init__(self,
dim: int,
chan: int,
num_heads: int=1,
qkv_bias: bool=False,
qk_scale: NoneFloat=None):
""" Attention Module
Args:
dim (int): input size of a single token
chan (int): resulting size of a single token (channels)
num_heads(int): number of attention heads in MSA
qkv_bias (bool): determines if the qkv layer learns an addative bias
qk_scale (NoneFloat): value to scale the queries and keys by;
if None, queries and keys are scaled by ``head_dim ** -0.5``
"""
super().__init__()
## Define Constants
self.num_heads = num_heads
self.chan = chan
self.head_dim = self.chan // self.num_heads
self.scale = qk_scale or self.head_dim ** -0.5
assert self.chan % self.num_heads == 0, '"Chan" must be evenly divisible by "num_heads".'
## Define Layers
self.qkv = nn.Linear(dim, chan * 3, bias=qkv_bias)
#### Each token gets projected from starting length (dim) to channel length (chan) 3 times (for each Q, K, V)
self.proj = nn.Linear(chan, chan)
def forward(self, x):
B, N, C = x.shape
## Dimensions: (batch, num_tokens, token_len)
## Calcuate QKVs
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
#### Dimensions: (3, batch, heads, num_tokens, chan/num_heads = head_dim)
q, k, v = qkv[0], qkv[1], qkv[2]
## Calculate Attention
attn = (q * self.scale) @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
#### Dimensions: (batch, heads, num_tokens, num_tokens)
## Attention Layer
x = (attn @ v).transpose(1, 2).reshape(B, N, self.chan)
#### Dimensions: (batch, heads, num_tokens, chan)
## Projection Layers
x = self.proj(x)
## Skip Connection Layer
v = v.transpose(1, 2).reshape(B, N, self.chan)
x = v + x
#### Because the original x has different size with current x, use v to do skip connection
return x
单头注意力
从只有一个注意力头开始,我们一步步走过前向传递的每一行,并在过程中查看一些矩阵图示。我们使用 7∗7=49 作为起始的 token 大小,因为这是 T2T-ViT 模型中的起始 token 大小。³我们使用 64 个通道,因为这是 T2T-ViT 的默认值。³我们使用 100 个 tokens,因为这是一个合适的数字。我们使用 13 的 batch 大小,因为它是质数,不会与其他参数混淆。
# Define an Input
token_len = 7*7
channels = 64
num_tokens = 100
batch = 13
x = torch.rand(batch, num_tokens, token_len)
B, N, C = x.shape
print('Input dimensions are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\ttoken size:', x.shape[2])
# Define the Module
A = Attention(dim=token_len, chan=channels, num_heads=1, qkv_bias=False, qk_scale=None)
A.eval();
Input dimensions are
batchsize: 13
number of tokens: 100
token size: 49
来自Attention is All You Need¹,注意力是通过Query、Key 和Value 矩阵来定义的。第一步是通过一个可学习的线性层来计算这些矩阵。布尔值qkv_bias项表示这些线性层是否具有偏置项。此步骤还会将输入的标记长度从 49 更改为chan参数,我们将其设置为 64。

生成单头注意力的查询、键和值(图片由作者提供)
qkv = A.qkv(x).reshape(B, N, 3, A.num_heads, A.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
print('Dimensions for Queries are\n\tbatchsize:', q.shape[0], '\n\tattention heads:', q.shape[1], '\n\tnumber of tokens:', q.shape[2], '\n\tnew length of tokens:', q.shape[3])
print('See that the dimensions for queries, keys, and values are all the same:')
print('\tShape of Q:', q.shape, '\n\tShape of K:', k.shape, '\n\tShape of V:', v.shape)
Dimensions for Queries are
batchsize: 13
attention heads: 1
number of tokens: 100
new length of tokens: 64
See that the dimensions for queries, keys, and values are all the same:
Shape of Q: torch.Size([13, 1, 100, 64])
Shape of K: torch.Size([13, 1, 100, 64])
Shape of V: torch.Size([13, 1, 100, 64])
现在,我们可以开始计算注意力,注意力的定义如下:
其中Q, K, V分别是查询、键和值;dₖ是键的维度,它等于键标记的长度,并且等于chan的长度。
我们将逐步分析代码中实现的这个方程。我们将中间矩阵称为Attn。
第一步是计算:
在代码中,我们设置了
默认情况下,
然而,用户可以指定一个作为超参数的替代缩放值。
分子中的矩阵乘法Q·Kᵀ看起来是这样的:

Q·Kᵀ 矩阵乘法(图片由作者提供)
所有这些代码看起来是这样的:
attn = (q * A.scale) @ k.transpose(-2, -1)
print('Dimensions for Attn are\n\tbatchsize:', attn.shape[0], '\n\tattention heads:', attn.shape[1], '\n\tnumber of tokens:', attn.shape[2], '\n\tnumber of tokens:', attn.shape[3])
Dimensions for Attn are
batchsize: 13
attention heads: 1
number of tokens: 100
number of tokens: 100
接下来,我们计算A的 softmax,这不会改变其形状。
attn = attn.softmax(dim=-1)
print('Dimensions for Attn are\n\tbatchsize:', attn.shape[0], '\n\tattention heads:', attn.shape[1], '\n\tnumber of tokens:', attn.shape[2], '\n\tnumber of tokens:', attn.shape[3])
Dimensions for Attn are
batchsize: 13
attention heads: 1
number of tokens: 100
number of tokens: 100
最后,我们计算A·V=x,其形式如下:

A·V 矩阵乘法(图片由作者提供)
x = attn @ v
print('Dimensions for x are\n\tbatchsize:', x.shape[0], '\n\tattention heads:', x.shape[1], '\n\tnumber of tokens:', x.shape[2], '\n\tlength of tokens:', x.shape[3])
Dimensions for x are
batchsize: 13
attention heads: 1
number of tokens: 100
length of tokens: 64
输出x被重新形状化,以去除注意力头维度。
x = x.transpose(1, 2).reshape(B, N, A.chan)
print('Dimensions for x are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\tlength of tokens:', x.shape[2])
Dimensions for x are
batchsize: 13
number of tokens: 100
length of tokens: 64
然后,我们将x通过一个可学习的线性层,这个线性层不会改变它的形状。
x = A.proj(x)
print('Dimensions for x are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\tlength of tokens:', x.shape[2])
Dimensions for x are
batchsize: 13
number of tokens: 100
length of tokens: 64
最后,我们实现了一个跳跃连接。由于当前x的形状与输入的x形状不同,我们使用V来进行跳跃连接。我们确实在注意力头维度上展平了V。
orig_shape = (batch, num_tokens, token_len)
curr_shape = (x.shape[0], x.shape[1], x.shape[2])
v = v.transpose(1, 2).reshape(B, N, A.chan)
v_shape = (v.shape[0], v.shape[1], v.shape[2])
print('Original shape of input x:', orig_shape)
print('Current shape of x:', curr_shape)
print('Shape of V:', v_shape)
x = v + x
print('After skip connection, dimensions for x are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\tlength of tokens:', x.shape[2])
Original shape of input x: (13, 100, 49)
Current shape of x: (13, 100, 64)
Shape of V: (13, 100, 64)
After skip connection, dimensions for x are
batchsize: 13
number of tokens: 100
length of tokens: 64
这就完成了注意力层!
多头注意力
现在我们已经看过了单头注意力,我们可以扩展到多头注意力。在计算机视觉中,这通常称为多头自注意力(MSA)。本节不会详细讲解所有步骤;相反,我们将专注于矩阵形状有所不同的部分。
与单头注意力相同,我们使用 7∗7=49 作为我们的起始标记大小,并使用 64 个通道,因为这是 T2T-ViT 的默认值³。我们使用 100 个标记,因为这个数字很好。我们使用的批量大小为 13,因为它是质数,不会与其他参数混淆。
注意力头的数量必须能够整除通道数,因此在这个例子中,我们使用 4 个注意力头。
# Define an Input
token_len = 7*7
channels = 64
num_tokens = 100
batch = 13
num_heads = 4
x = torch.rand(batch, num_tokens, token_len)
B, N, C = x.shape
print('Input dimensions are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\ttoken size:', x.shape[2])
# Define the Module
MSA = Attention(dim=token_len, chan=channels, num_heads=num_heads, qkv_bias=False, qk_scale=None)
MSA.eval();
Input dimensions are
batchsize: 13
number of tokens: 100
token size: 49
计算Queries、Keys 和Values 的过程与单头注意力中相同。然而,你可以看到标记的新长度是chan/num_heads。Q、K和V矩阵的总大小没有变化,它们的内容只是分布在头维度上。你可以将其视为将单头矩阵分割为多个头:

多头注意力分割(图片由作者提供)
我们将子矩阵表示为 Qₕᵢ,表示Query head i。
qkv = MSA.qkv(x).reshape(B, N, 3, MSA.num_heads, MSA.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
print('Head Dimension = chan / num_heads =', MSA.chan, '/', MSA.num_heads, '=', MSA.head_dim)
print('Dimensions for Queries are\n\tbatchsize:', q.shape[0], '\n\tattention heads:', q.shape[1], '\n\tnumber of tokens:', q.shape[2], '\n\tnew length of tokens:', q.shape[3])
print('See that the dimensions for queries, keys, and values are all the same:')
print('\tShape of Q:', q.shape, '\n\tShape of K:', k.shape, '\n\tShape of V:', v.shape)
Head Dimension = chan / num_heads = 64 / 4 = 16
Dimensions for Queries are
batchsize: 13
attention heads: 4
number of tokens: 100
new length of tokens: 16
See that the dimensions for queries, keys, and values are all the same:
Shape of Q: torch.Size([13, 4, 100, 16])
Shape of K: torch.Size([13, 4, 100, 16])
Shape of V: torch.Size([13, 4, 100, 16])
下一步是计算
对于每个头* i *。在此上下文中,键的长度是
与单头注意力一样,我们使用默认值
尽管用户可以将替代的缩放值指定为超参数。
我们以num_heads = 4 个不同的Attn 矩阵结束此步骤,如下所示:

Q·Kᵀ 矩阵乘法用于 MSA(图片由作者提供)
attn = (q * MSA.scale) @ k.transpose(-2, -1)
print('Dimensions for Attn are\n\tbatchsize:', attn.shape[0], '\n\tattention heads:', attn.shape[1], '\n\tnumber of tokens:', attn.shape[2], '\n\tnumber of tokens:', attn.shape[3])
Dimensions for Attn are
batchsize: 13
attention heads: 4
number of tokens: 100
number of tokens: 100
接下来我们计算A的 softmax,它的形状不会改变。
然后,我们可以计算
这在多个注意力头之间类似地分布:

A·V 矩阵乘法用于 MSA(图片由作者提供)
attn = attn.softmax(dim=-1)
x = attn @ v
print('Dimensions for x are\n\tbatchsize:', x.shape[0], '\n\tattention heads:', x.shape[1], '\n\tnumber of tokens:', x.shape[2], '\n\tlength of tokens:', x.shape[3])
Dimensions for x are
batchsize: 13
attention heads: 4
number of tokens: 100
length of tokens: 16
现在我们通过一些重塑操作将所有的 xₕᵢ合并在一起。这是第一步的逆操作:

多头注意力分割(图片由作者提供)
x = x.transpose(1, 2).reshape(B, N, MSA.chan)
print('Dimensions for x are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\tlength of tokens:', x.shape[2])
Dimensions for x are
batchsize: 13
number of tokens: 100
length of tokens: 64
现在我们已经将所有头连接在一起,注意力模块的其余部分保持不变。对于跳过连接,我们仍然使用V,但我们必须重新调整其形状以去除头维度。
x = MSA.proj(x)
print('Dimensions for x are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\tlength of tokens:', x.shape[2])
orig_shape = (batch, num_tokens, token_len)
curr_shape = (x.shape[0], x.shape[1], x.shape[2])
v = v.transpose(1, 2).reshape(B, N, A.chan)
v_shape = (v.shape[0], v.shape[1], v.shape[2])
print('Original shape of input x:', orig_shape)
print('Current shape of x:', curr_shape)
print('Shape of V:', v_shape)
x = v + x
print('After skip connection, dimensions for x are\n\tbatchsize:', x.shape[0], '\n\tnumber of tokens:', x.shape[1], '\n\tlength of tokens:', x.shape[2])
Dimensions for x are
batchsize: 13
number of tokens: 100
length of tokens: 64
Original shape of input x: (13, 100, 49)
Current shape of x: (13, 100, 64)
Shape of V: (13, 100, 64)
After skip connection, dimensions for x are
batchsize: 13
number of tokens: 100
length of tokens: 64
这就结束了多头注意力!
结论
现在我们已经走完了视觉变换器中注意力层的每个步骤。注意力层中的可学习权重位于从标记到查询、键和值的第一次投影以及最终投影中。注意力层的大部分是确定性的矩阵乘法。然而,当使用较长的标记时,线性层可能包含大量的权重。QKV 投影层中的权重数量等于input_token_len∗chan∗3,而最终投影层中的权重数量等于chan²。
要使用注意力层,你可以创建自定义的注意力层(如这里所做!),或者使用机器学习包中包含的注意力层。如果你想使用此处定义的注意力层,可以在GitHub 仓库找到该系列文章的代码。PyTorch 也提供了 torch.nn.MultiheadedAttention()⁴ 层,它按照上述定义计算注意力。祝你注意力集中!
本文已通过洛斯阿拉莫斯国家实验室批准发布,编号 LA-UR-23-33876。相关代码已通过 BSD-3 开源许可证批准,许可证编号 O#4693。
进一步阅读
要了解更多关于 NLP 上下文中注意力层的内容,请参见
-
视觉变换器直观解析第一部分:功能概述:
towardsdatascience.com/transformers-explained-visually-part-1-overview-of-functionality-95a6dd460452 -
视觉变换器直观解析第二部分:一步一步了解其工作原理:
towardsdatascience.com/transformers-explained-visually-part-2-how-it-works-step-by-step-b49fa4a64f34 -
视觉变换器直观解析第三部分:多头注意力深度解析:
towardsdatascience.com/transformers-explained-visually-part-3-multi-head-attention-deep-dive-1c1ff1024853 -
视觉指南:变换器神经网络多头和自注意力视频:
www.youtube.com/watch?v=mMa2PmYJlCo
关于视觉变换器的广泛视频讲解(包括相关章节),请见
-
视觉变换器及其应用:
youtu.be/hPb6A92LROc?si=GaGYiZoyDg0PcdSP— 人类视觉注意力:4:31 — 5:18 (
youtu.be/hPb6A92LROc?t=271&si=VMx2lM9lvW-oKcW_)— Attention 作为点积:5:18–6:14 (https://youtu.be/hPb6A92LROc?t=318&si=pF2SFp2XXjK8AWsL)
— Attention 公式的描述:16:13–17:52 (
youtu.be/hPb6A92LROc?si=toAgKQCOh9zGCR-c&t=973)— 为什么使用多头自注意力:19:44–19:58 (
youtu.be/hPb6A92LROc?t=1184&si=Sy1e149ukt99DoRf)
引用
1 Vaswani 等人(2017 年)。Attention Is All You Need. doi.org/10.48550/arXiv.1706.03762
[2] Dosovitskiy 等人(2020 年)。一张图胜过 16x16 个词:用于大规模图像识别的变换器. doi.org/10.48550/arXiv.2010.11929
[3] Yuan 等人(2021 年)。Tokens-to-Token ViT:从头开始在 ImageNet 上训练视觉变换器. doi.org/10.48550/arXiv.2101.11986
→ GitHub 代码:github.com/yitu-opensource/T2T-ViT
[4] PyTorch。多头注意力. pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
注意力(并非)你所需要的一切
一种替代的文本生成变换器模型方法
·发布于Towards Data Science ·阅读时间:7 分钟·2024 年 11 月 19 日
--

分形图案能否帮助我们创造出更高效的文本生成模型?照片由Giulia May拍摄,来源于Unsplash
自 2022 年 11 月底 ChatGPT 发布以来,LLM(大型语言模型)几乎成为了家喻户晓的名字。

全球范围内对“LLM”的搜索兴趣。来源:Google Trends
这样做是有充分理由的;它们的成功归功于其架构,特别是注意力机制。它使模型能够将每个处理的单词与其他所有单词进行比较。
这使得 LLM 在理解和生成类人文本方面具有我们熟知的非凡能力。
然而,这些模型并非没有缺陷。它们训练时需要巨大的计算资源。例如,Meta 的 Llama 3 模型训练耗时 770 万个 GPU 小时1。此外,它们对庞大数据集的依赖——涵盖了数万亿个标记——引发了关于可扩展性、可获取性以及环境影响的疑问。
尽管存在这些挑战,自从 2017 年中期发布论文《Attention is all you need》以来,人工智能领域的最新进展大多集中在进一步扩展注意力机制,而非探索根本全新的架构。
注意力,注意了!
注意力是你所需要的一切,但其范围是有限的。
·发表于 Towards Data Science ·10 分钟阅读·2024 年 8 月 29 日
--

图片由Google DeepMind提供,来源于Unsplash
FlashAttention 第二部分:通过实际的类比、简单的视觉呈现和通俗的叙述,为注意力机制提供直观的介绍。第一部分的内容现已发布。
在上一章中,我从一个高层次的角度介绍了 FlashAttention 机制,采用了“像我五岁一样解释”(ELI5)的方法。这种方式我最为认同;我总是尽力将困难的概念与现实生活中的类比联系起来,我发现这种方式有助于长期记忆。
接下来我们教育菜单上的主菜是标准的注意力算法——这是我们如果以后想要加料,不能跳过的一道菜。先理解它,然后再改进它。这是无法绕开的。
到现在为止,你可能已经浏览了大量关于注意力机制的文章,并观看了无数的 YouTube 视频。的确,注意力机制是 AI 领域的明星,每个人都急于在某个功能上与之合作。
所以,我也跳出来分享我对这个广受欢迎概念的看法,并为一些曾启发我的资源送上致谢。我将继续使用我们经过验证的类比方式,但也会融入更具视觉化的呈现。呼应我之前的观点(冒着听起来像破碎唱片的风险……)
音频扩散:生成音乐的秘密武器

图像由 DALL·E 生成
探讨扩散技术背后的原理及其如何被应用于为艺术家和制作人创造突破性人工智能工具。
·发表于 Towards Data Science ·阅读时间:14 分钟·2024 年 1 月 22 日
--
最近关于生成音乐人工智能算法的炒作引起了广泛关注。有些人认为它是创造力的未来,而另一些人则认为它是音乐的终结。虽然我倾向于支持前者,但作为一名工程师和研究员,我通常尝试从更加客观的角度来看待这些进展。鉴于此,我想介绍一下驱动生成音频和音乐世界的核心技术之一:扩散。
我的目标不是推销或贬低这些炒作,而是揭示这些技术背后发生的事情,让音乐家、制作人、爱好者和创作者能更好地理解这些看似神奇的音乐创作黑盒。我将回答这些人工智能算法“创造出完全新的东西”这一说法的含义,以及它与人类原创性有何不同。我希望通过更清晰的解释,降低集体的焦虑,并提供洞见,帮助创作者更好地利用这些强大的技术。
本文将涉及一些技术话题,但你不需要具备工程背景就能理解。让我们先从一些背景信息和定义开始。
背景
“AI 生成”这个术语已经在音乐行业中变得非常普及,但什么才算是“AI 生成”的内容实际上是相当模糊的。为了赶上这个流行词的热潮,这个说法被随意地使用,无论是 AI 用来模仿某种效果、自动混音或母带处理、分离音轨,还是增强音色。只要最终的音频在某种程度上受到 AI 的处理,这个术语就会被贴上整个作品。然而,目前发布的大多数音乐仍然主要通过人工制作生成(是的,即使是鬼才作曲者的“Heart On My Sleeve” 👻)。
尽管“AI 生成”这个术语为了点击率而变得老生常谈,但它的恰当使用是当新声音的确是由计算机生成时,即生成音频。
音频生成可以涵盖音效样本、旋律、人声,甚至完整歌曲的创作。实现这一点的两种主要方式是通过MIDI 生成和音频波形生成。MIDI(数字乐器接口)生成计算成本较低,并且能够提供高质量的输出,因为生成的 MIDI 数据会通过现有的虚拟乐器来产生声音。这与制作人通过piano roll编程 MIDI 并通过VST插件(如Serum)播放的概念相同。

Pro Tools 中的 MIDI 钢琴卷轴
虽然这一点很有吸引力,但它仅部分是生成的,因为实际上没有音频是由 AI 生成的,就像人类不能凭空合成乐器的声音一样。创作能力还受到算法能够访问的虚拟乐器的限制。即便有这些限制,采用这种技术的产品,如AIVA和Seeds by Lemonaide,也能够生成相当引人注目的输出。
音频波形生成是一项更为复杂的任务,因为它是一个端到端的系统,不依赖任何外部技术。换句话说,它从零开始生成声音。这个过程最精确地符合“AI 生成”音频的真正定义。
音频波形生成可以通过多种方法实现,并产生不同的结果。它可以生成单个样本,比如Audialab 的 ED 2和Humanize,或是我之前的作品 Tiny Audio Diffusion,也可以生成完整的歌曲,如AudioLM、Moûsai、Riffusion、MusicGen和Stable Audio。在这些最先进的模型中,许多都利用了某种形式的扩散来生成声音。你可能至少在某种程度上听说过扩散,可能是通过稳定扩散或其他一些曾席卷全球的顶尖图像生成模型。这种生成方法同样可以应用于音频。那么这一切到底是什么意思呢?
什么是扩散?
基础知识
在人工智能的背景下,扩散指的只是给信号添加或移除噪声的过程(就像老电视机中的静电噪声)。前向扩散向信号中添加噪声(噪声化),而反向扩散则移除噪声(去噪)。从概念层面来看,扩散模型将白噪声逐步通过去噪过程,直到音频类似于某个可识别的声音,比如一个样本或一首歌。这个去噪过程是许多生成音频模型创造力的秘密武器。

音频波形扩散(来源:CRASH:基于原始音频评分的可控高分辨率鼓声合成生成模型(Rouard, Hadjeres))
这个过程最初是为图像开发的。观察噪声如何解析成一幅图像(例如,一只小狗坐在网球旁边)能更清楚地展示这些模型是如何工作的。

图像扩散(图像生成使用稳定扩散)
通过概念性的理解,让我们深入探讨音频扩散模型架构的关键组成部分。虽然这会涉及一些技术性内容,但请跟着我,因为对这些算法如何工作的深入理解将更好地说明它们是如何产生结果的(如果没有,您随时可以向ChatGPT请求简明扼要的解释)。
U-Net 模型架构、压缩与重建
在音频扩散模型的核心是U-Net。U-Net 最初是为了医学图像分割而开发的,因其外形像字母 U 而得名,后来由于其强大的能力能够捕捉数据中的局部和全局特征,被适应用于生成音频。原始的 U-Net 是一个二维卷积神经网络(CNN),用于图像处理,但也可以适配为一维卷积,以处理音频波形数据。请参见下面的原始 U-Net 架构(用于图像)的视觉表示。

U-Net(来源:U-Net: 卷积网络用于生物医学图像分割(Ronneberger 等))
类似于变分自编码器(VAE),U-Net 由编码器(U 的左侧)和解码器(U 的右侧)组成,通过瓶颈(U 的底部层)相连接。然而,与 VAE 不同,U-Net 具有跳跃连接(由水平灰色箭头表示),这些连接将编码器与解码器连接起来,这是生成高分辨率输出的关键部分。编码器负责捕捉输入音频信号的特征,或特性,而解码器负责信号的重建。
为了帮助可视化,可以想象音频数据从 U 的左上方进入,沿着红色和蓝色箭头通过编码器向下到达 U 的底部瓶颈层,然后再沿着蓝色和绿色箭头通过解码器回到 U 的右上方。每个蓝色矩形代表一个模型层。在编码器的每一层中,输入音频信号会逐渐被压缩,直到它在 U 的底部(瓶颈处)达到高度浓缩的声音表示。然后,解码器接收这个压缩信号,并有效地逆转这一过程以重建信号。数据通过的每一层(蓝色矩形)都有一系列可调的权重,可以看作是成千上万的微小旋钮,用户可以旋转这些旋钮来调整压缩/重建过程。具有不同压缩级别的层允许模型从数据中学习各种特征,从大尺度的特征(例如旋律和节奏)到细粒度的细节(例如高频音色特征)。
使用类比,你可以将整个系统想象成创建一个MP3音频文件并在播放设备上收听该 MP3 的过程。从本质上讲,MP3 是音频信号的压缩版本。假设编码器的工作是创建一种新的压缩音频格式,类似于 MP3,以尽可能地压缩音频信号而不损失保真度。然后,解码器的工作就像你的 iPhone(或任何播放设备),将 MP3 解码成可以通过耳机播放的高保真音频表现形式。瓶颈可以看作是这个新创建的 MP3 类型格式本身。U-Net 代表的是压缩和重建的过程,而不是音频数据。然后,可以以准确压缩和重建各种音频信号为目标训练这个模型。
这一切都很好,但我们还没有生成任何内容。我们只构建了压缩和重建音频信号的方法。然而,这个过程是生成新音频所必需的基本过程,而且只需稍微调整一下就能实现。
噪声与去噪
让我们回顾一下我们之前提到的噪声和去噪的概念。从理论上讲,我们曾设想过一个魔法模型,它可以被训练来将一些白噪声“去噪”成可识别的音频,可能是一首美丽的协奏曲。这个魔法模型的一个关键要求是,它必须能够以高保真度重建输入的音频信号。幸运的是,U-Net 架构的设计正是为了完成这一任务。因此,接下来要解决的难题是修改 U-Net 以执行这个去噪过程。
违反直觉的是,为了教会一个模型去噪音频信号,首先要教它如何给信号添加噪声。一旦它学会了这个过程,它就自然知道如何执行逆操作,以去除噪声。
回想一下前一部分,我们详细描述了 U-Net 如何学习压缩和重建音频信号。噪声处理过程几乎遵循相同的公式,但不同的是,U-Net 并不是重建完全相同的输入音频信号,而是被指导重建加入少量噪声的输入音频信号。这可以通过反转之前小狗图像序列中的步骤来可视化。

扩散噪声步骤(图像由 Stable Diffusion 生成)
向信号添加噪声的过程必须是概率性的(即可预测的)。模型首先展示给一个音频信号,然后被指示预测添加少量高斯噪声后的信号。由于其特性,高斯噪声最为常见,但并非必须使用。噪声必须由概率分布定义,意味着它遵循一个特定的模式,且这个模式是可以一致预测的。这个过程会在多个步骤中重复,直到信号最终变成只有噪声。

添加噪声到击鼓样本(来源:CRASH: 基于音频分数的可控高分辨率鼓声合成的原始音频生成模型 (Rouard, Hadjeres))
例如,让我们以一个击鼓样本为例。U-Net 接收到这个击鼓样本,并被要求重建这个击鼓声音,但加入一些噪声,使其听起来不那么干净。然后,这个略带噪声的击鼓样本被再次提供给模型,并再次要求重建这个击鼓样本,同时增加更多噪声。这个循环会重复进行,直到听起来像是击鼓样本已经不存在,只剩下白噪声。接着,模型被教会如何在广泛的声音中执行这种操作。一旦它成为预测如何向输入音频信号添加噪声的专家,因为这个过程是概率性的,它就可以简单地反转,使得在每一步移除一些噪声。这就是模型在提供白噪声时能够生成击鼓样本的方式。
由于这个过程的概率性特征,一些令人难以置信的能力出现了,特别是模拟创造力的能力。
让我们继续讨论击鼓的例子。假设模型已经在成千上万个单次击鼓样本上进行了训练。你可能会认为它可以拿一些白噪声,然后将其转化为任何一个这些击鼓样本。然而,模型的学习方式并不完全是这样。由于它被展示了如此广泛的声音范围,它反而学会了创建那些与它训练过的击鼓样本大致相似的声音,但并不完全相同。这就是如何创造全新声音的过程,这些模型看起来展现出了某种创造力的火花。
为了说明这一点,我们使用以下草图。

假设所有可能的声音,从吉他弹奏到狗吠声,再到白噪声,都可以绘制在一个二维平面上,平面由上图中的黑色矩形表示。在这个空间中,有一个区域是小军鼓击打声所在的位置。由于它们在音色和瞬态特性上的相似性,它们被稍微聚集在一起。这由蓝色的斑点显示,每一个蓝色的点代表我们用来训练模型的一个小军鼓样本。红色的点代表模型训练时使用的已经加入噪声的小军鼓版本,并与它们未加入噪声的蓝色点样本相对应。
本质上,我们的模型学会了将“非小军鼓”区域的点带入“军鼓”区域。所以,如果我们从“非小军鼓”区域(例如随机噪声)选取一个新的绿色点,它与任何蓝色点都不对应,并要求我们的模型将其带入“军鼓”区域,模型将把它带到“军鼓”区域内的一个新位置。这就是模型生成一个“新的”军鼓样本,虽然它与模型训练时的所有军鼓样本有相似之处,但也包含一些新的、未知的特征。
这一概念可以应用于任何类型的声音,包括完整的歌曲。这是一个令人惊叹的创新,能够引领创作方式的多种变化。然而,重要的是要理解,这些模型不会生成超出它们训练范围的内容。如前图所示,尽管我们的概念模型可以处理任何类型的声音,但它只能生成类似于训练样本的小军鼓样本。所有这些音频扩散模型都遵循这一原则。因此,训练模型时使用广泛的数据集至关重要,以确保已知区域(如小军鼓区域)足够多样化且规模足够大,从而避免仅仅复制训练数据。
这一切意味着没有任何模型能够复制人类的创造力,只能模拟它的变体。
扩散模型的应用
这些模型并不会像人类那样神奇地生成新的音乐风格或探索未知的声音景观。理解这一点后,我们应该将这些生成模型视为增强创意的工具,而不是替代人类创意的替代品。以下是这项技术在创作中应用的几种方式:
-
通过策展激发创造力:在采样包中搜索以找到所需的声音是制作过程中常见的做法。这些模型可以有效地作为“无限采样包”的一种形式,通过声音的策展来增强艺术家的创造力。
-
声音转移: 就像扩散模型可以将随机噪声转化为可识别的音频一样,它们也可以接收其他声音并将其“转移”到另一种类型的声音上。例如,如果我们使用之前的军鼓模型,并输入一个踢鼓样本而不是白噪声,它会将踢鼓样本转变成军鼓声音。这使得非常独特的创作成为可能,能够结合多种不同声音的特征。
-
声音变异性(人性化): 当人类演奏现场乐器时,例如鼓组中的高帽,每一次击打都会有固有的变异性。各种虚拟乐器尝试通过不同的方法模拟这一现象,但仍然可能听起来不自然,缺乏个性。音频扩散可以实现单一声音的无限变化,从而为音频样本添加人性化元素。例如,如果你编程一个鼓组,音频扩散可以用来让每次击打在音色、力度、起音等方面略有不同,从而使原本可能显得呆板的演奏更具人性化。
-
声音设计调整: 类似于人类的变异性潜力,这一概念也可以应用于声音设计,创造对声音的轻微变化。也许你大多喜欢门砰的一声样本,但希望它有更多的质感或脆响。扩散模型可以利用这个样本并对其进行微调,保持大部分特征,同时加入一些新的特征。这可以在比使用均衡器或滤波器更基础的层面上,添加、去除或改变声音的频谱内容。
-
旋律生成: 类似于浏览样本包,音频扩散模型可以生成旋律,激发出可供进一步创作的灵感。
-
立体声效果: 有多种不同的混音技巧可以为单声道(单声道)声音添加立体声宽度。然而,这些方法往往会带来不必要的色彩、延迟或相位偏移。音频扩散可以用来生成几乎与单声道声音相同的声音,但其内容足够不同,以扩展立体声宽度,同时避免许多不希望出现的现象。
-
超分辨率: 音频扩散模型可以增强音频录音的分辨率和质量,使其更加清晰和详细。这在音频修复或处理低质量录音时尤为有用。
-
图像修复: 扩散模型可以用来填补音频信号中缺失或损坏的部分,将其恢复到原始或改进后的状态。这对于修复损坏的音频录音、完成可能缺失的音频片段或在音频剪辑之间添加过渡非常有价值。
结论
毫无疑问,这些新的生成型人工智能模型是令人惊叹的技术进步,不论它们被视为积极还是消极的。关于扩散模型的优化空间非常广泛,可以在速度、多样性和质量等方面提升它们的性能,但我们已经讨论了这些模型功能的基本原理。这些知识为我们提供了更深刻的背景,让我们理解当这些模型生成“新声音”时,真正意味着什么。
从更广泛的层面来看,人们关心的并不仅仅是音乐本身——更重要的是音乐创作中的人类元素。问问自己,如果你听到一段高超且迅速的吉他独奏录音,你会感到印象深刻吗?这取决于情况。如果这段独奏是由一个制作人编程的虚拟 MIDI 乐器人工生成的,你可能会毫不动容,甚至不喜欢它的声音。然而,如果你知道这段独奏是由一位吉他手用真实吉他演奏的,或者甚至亲眼看到他或她演奏,你将完全被他们的专业技巧和精准度所吸引。我们被演奏中的灵巧、歌词背后的思想和情感,以及创作歌曲时每个决定背后的考虑所吸引。
尽管这些令人难以置信的进步让艺术家和制作人感到一些生存焦虑,但人工智能永远无法剥夺我们创作的声音和音乐中那份人类元素。因此,我们应该以一种工具的心态来看待这些新的进展,认为它们是为了增强艺术家的创造力,而不是取而代之。
除非另有说明,所有图片均由作者提供。
我是一名音频机器学习工程师和研究员,同时也是一名终身音乐人。如果你对更多音频人工智能应用感兴趣,可以阅读我之前发布的文章:Tiny Audio Diffusion 和 音乐分离。
在 LinkedIn 和 GitHub 上找到我,了解我当前的工作和研究进展,访问我的网站:www.chrislandschoot.com。
在 Spotify, Apple Music, YouTube, SoundCloud 和其他流媒体平台上找到我的音乐,艺名为 After August。
自编码器:数据科学家的终极指南
面向初学者的架构指南、Python 实现及对未来的展望
·发表于Towards Data Science ·阅读时间 19 分钟·2024 年 10 月 17 日
--

图片由Clark Van Der Beken提供,来源于Unsplash
自编码器是一种特殊形式的深度神经网络,主要用于特征提取或降维。由于它们可以处理未标注的数据,因此属于无监督学习领域。其架构由两个主要组成部分构成:编码器,它将输入数据压缩成低维表示;解码器,经过训练后能够从这一表示中重建原始数据。
本文详细概述了自编码器的结构,并解释了架构中各个组件的作用。我们还探讨了在训练过程中可能遇到的挑战,以及基于这一模型的应用。最后,我们将深入探讨该方法的优缺点,并与其他降维算法进行比较。
什么是自编码器?
自编码器是一种特殊形式的人工神经网络,经过训练用于将输入数据表示为压缩形式,然后从这种压缩形式中重建原始数据。最初听起来似乎是一次不必要的转换,但它是降维的一个核心部分,因为它…
AutoGluon-TimeSeries:一个库包含所有时间序列预测模型
由亚马逊推出的强大库——包括编码示例
·发表于 Towards Data Science ·阅读时间 8 分钟·2024 年 1 月 5 日
--

图像由作者使用 Stable Diffusion 创建
开源时间序列领域正在加速发展。
其中包括成功的库,如 Darts、GluonTS 和 Nixtla。
去年,亚马逊扩展了其 AutoGluon 库,专注于时间序列——名为AutoGluon-Timeseries(AG-TS)1。
AG-TS 利用其他库的专业知识:
-
来自亚马逊自身(AutoGluon 和 GluonTS)。
-
来自 Nixtla(StatsForecast 和 MLForecast)。
最棒的部分是:AG-TS 拥有用户友好的 API——我们只需几行代码就能获取预测!
本文探讨了 AG-TS,并概述了它的功能。我们还将构建一个简单的项目,使用广为人知的旅游数据集[2]。
让我们深入了解
我推出了AI Horizon Forecast,这是一个专注于时间序列和创新 AI 研究的新闻通讯。点击这里订阅,拓宽你的视野!
什么是 AutoGluon-Timeseries
AutoGluon–TimeSeries 是一个 AutoML 时间序列框架,专注于……
AutoHyDE:使 HyDE 在高级 LLM RAG 中更加出色
🔎 深入探讨 HyDE 在高级 LLM RAG 中的应用 + 💡 介绍 AutoHyDE,一种半监督框架,旨在提高 HyDE 的效果、覆盖面和适用性
·发表于 Towards Data Science ·阅读时间 19 分钟 ·2024 年 4 月 4 日
--

图片由作者与 DALL-E 协助制作
介绍
在检索增强生成(RAG)领域,假设文档嵌入(HyDE)已被证明是一种强大的查询重写方法,有助于提高检索文档的相关性。
对于未接触过的读者,传统的检索方法仅使用原始输入生成嵌入向量进行检索,而 HyDE 是一种生成嵌入向量的方法,这些向量与索引文档的嵌入空间更为相关,以便进行更精准的文档检索。
高层次的总结是:(1)从用户输入创建假设文档,(2)将假设文档转换为嵌入,(3)使用嵌入进行相似文档的检索
我一直在一些工作和个人项目中使用 RAG 和基础 HyDE,经过一段时间的使用,我意识到现有的 HyDE 实现并不是总能开箱即用,也不像我预期的那样灵活。因此,在进行了一些方法学研究并深入阅读论文和源代码后……
使用 LLM 和 TF-IDF 自动化视频章节划分
将原始转录文本转化为结构良好的文档
·发布于 Towards Data Science ·12 分钟阅读·2024 年 9 月 9 日
--

图片来源:Jakob Owens来自Unsplash
视频章节划分是将视频分割成不同章节的任务。除了像 YouTube 章节那样作为导航工具使用外,它还是一系列下游应用的核心,从信息检索(例如 RAG 语义切分)到引用或摘要等。
在一个最近的项目中,我需要自动化这个任务,结果发现可用的选择非常有限,尤其是在开源领域。虽然一些专业工具或付费 API 提供了这样的服务,但我找不到任何提供足够健全和准确解决方案的库或教程。如果你知道有类似的资源,请在评论区分享!
如果你在想,为什么不直接将转录文本复制粘贴到大型语言模型(LLM)中,并要求其生成章节标题,答案是,这样做有两个原因不会有效。首先,LLM 无法始终如一地保留时间戳信息,并将其与章节标题关联。其次,LLM 在处理长篇转录文本时,往往会忽视重要部分。
因此,我最终设计了一个自定义工作流程,依赖 LLM 来处理不同的语言处理子任务(如文本格式化、段落结构、章节划分和标题生成),并使用TF-IDF统计方法在段落结构化后将时间戳信息添加回来。

LLM 和 TF-IDF 的结合使得能够高效地编辑和结构化原始转录文本,同时保留时间戳——在这个HuggingFace 空间查看演示。
最终的工作流程效果很好,通常会生成与 YouTube 推荐章节相同或增强版的章节。该工具还允许将格式不佳的转录文本导出为结构化文档,如下所示的HuggingFace 空间。
本文旨在解释其主要步骤,步骤如下图所示:

视频章节化的提议工作流程,从原始转录文本的获取到结构化的 Markdown 和 Gradio 应用程序。
工作流程的关键步骤在于将转录文本结构化为段落(步骤 2),然后将这些段落归类为章节,最终生成目录(步骤 4)。请注意,这两个步骤可能依赖于不同的 LLM:对于简单的文本编辑和段落识别任务,可以使用快速且便宜的 LLM,如 LLama 3 8B;而生成目录则需要更为复杂的 LLM,如 GPT-4o-mini。期间,使用 TF-IDF 将时间戳信息添加回结构化的段落。
本文的其余部分将更详细地描述每个步骤。
查看随附的Github 仓库和 Colab 笔记本,亲自探索一下吧!
1) 获取视频/音频转录文本
以课程“MIT 6.S191:深度学习导论”的第一讲为例(IntroToDeepLearning.com),该课程由 Alexander Amini 和 Ava Amini 讲授(根据 MIT 许可协议授权)。

课程 YouTube 页面的截图。课程资料受 MIT 许可证保护。
请注意,视频描述中已提供章节信息。

章节信息已在 YouTube 描述中提供
这为我们提供了一个基准,用于在本文后面对我们的章节化效果进行定性比较。
YouTube 转录 API
对于 YouTube 视频,YouTube 通常会自动生成转录文本。检索该转录文本的便捷方法是调用 Python youtube_transcript_api库中的get_transcript方法。该方法将 YouTube 的video_id作为参数:
# https://www.youtube.com/watch?v=ErnWZxJovaM
video_id = "ErnWZxJovaM" # MIT Introduction to Deep Learning - 2024
# Retrieve transcript with the youtube_transcript_api library
from youtube_transcript_api import YouTubeTranscriptApi
transcript = YouTubeTranscriptApi.get_transcript(video_id, languages=["en"])
这将返回转录文本,内容为文本和时间戳键值对的列表:
[{'text': '[Music]', 'start': 1.17},
{'text': 'good afternoon everyone and welcome to', 'start': 10.28},
{'text': 'MIT sus1 191 my name is Alexander amini', 'start': 12.88},
{'text': "and I'll be one of your instructors for", 'start': 16.84},
...]
然而,转录文本的格式较差:缺乏标点符号并且包含拼写错误(如“MIT sus1 191”应为“MIT 6.S191”,或“amini”应为“Amini”)。
使用 Whisper 进行语音转文本
或者,可以使用语音转文本库从视频或音频文件中推断转录内容。我们推荐使用faster-whisper,它是一个快速实现的最新开源whisper模型。
模型有不同的大小,最精确的是‘large-v3’,它能够在 T4 GPU 上每分钟转录约 15 分钟的音频(Google Colab 提供免费使用)。
from faster_whisper import WhisperModel
# Load Whisper model
whisper_model = WhisperModel("large-v3",
device="cuda" if torch.cuda.is_available() else "cpu",
compute_type="float16",
)
# Call the Whisper transcribe function on the audio file
initial_prompt = "Use punctuation, like this."
segments, transcript_info = whisper_model.transcribe(audio_file, initial_prompt=initial_prompt, language="en")
转录的结果以片段形式提供,可以很容易地转换为文本和时间戳的列表,就像使用youtube_transcript_api库一样。
提示:Whisper 有时可能不包括标点符号。可以通过提供包含标点符号的小句子来使用initial_prompt参数引导模型添加标点符号。
以下是使用 whisper large-v3 转录我们视频示例的一部分摘录:
[{'start': 0.0, 'text': ' Good afternoon, everyone, and welcome to MIT Success 191.'},
{'start': 15.28, 'text': " My name is Alexander Amini, and I'll be one of your instructors for the course this year"},
{'start': 19.32, 'duration': 2.08, 'text': ' along with Ava.'}
...]
请注意,与 YouTube 转录相比,标点符号已经添加。然而,某些转录错误仍然存在(例如,‘MIT Success 191’而不是‘MIT 6.S191’)。
2)将转录内容结构化为段落
一旦转录可用,第二阶段就开始对转录内容进行编辑和段落结构化。
转录编辑指的是为了提高可读性所做的修改。例如,添加缺失的标点符号、修正语法错误、去除口头习惯等。
将内容结构化为段落也有助于提高可读性,另外,这也作为第四阶段识别章节的预处理步骤,因为章节将通过将段落组合在一起形成。
段落编辑和结构化可以通过单次操作完成,使用 LLM。下面我们演示了该阶段的预期结果:

左:原始转录。右:编辑和结构化后的转录。
这个任务并不需要非常复杂的 LLM,因为它主要是对内容进行重述。在撰写本文时,使用例如 GPT-4o-mini 或 Llama 3 8B,可以获得不错的结果,配合以下的系统提示:
你是一个有帮助的助手。
你的任务是提高用户输入的可读性:如果需要,添加标点符号,去除口头习惯,并将文本结构化为以‘\n\n’分隔的段落。
尽量保持用词忠实于原文。
将你的回答放在
标签内。
我们依赖于OpenAI 兼容的聊天完成 API来调用 LLM,消息的角色可以是‘system’、‘user’或‘assistant’。下面的代码演示了使用Groq实例化 LLM 客户端,使用 LLama 3 8B:
# Connect to Groq with a Groq API key
llm_client = Groq(api_key=api_key)
model = "llama-8b-8192"
# Extract text from transcript
transcript_text = ' '.join([s['text'] for s in transcript])
# Call LLM
response = client.chat.completions.create(
messages=[
{
"role": "system",
"content": system_prompt
},
{
"role": "user",
"content": transcript_text
}
],
model=model,
temperature=0,
seed=42
)
给定一段原始的‘transcript_text’作为输入,它将返回一个包含在
response_content=response.choices[0].message.content
print(response_content)
"""
<answer>
Good afternoon, everyone, and welcome to MIT 6.S191\. My name is Alexander Amini, and I'll be one of your instructors for the course this year, along with Ava. We're really excited to welcome you to this incredible course.
This is a fast-paced and intense one-week course that we're about to go through together. We'll be covering the foundations of a rapidly changing field, and a field that has been revolutionizing many areas of science, mathematics, physics, and more.
Over the past decade, AI and deep learning have been rapidly advancing and solving problems that we didn't think were solvable in our lifetimes. Today, AI is solving problems beyond human performance, and each year, this lecture is getting harder and harder to teach because it's supposed to cover the foundations of the field.
</answer>
"""
接下来,我们从
import re
pattern = re.compile(r'<answer>(.*?)</answer>', re.DOTALL)
response_content_edited = pattern.findall(response_content)
paragraphs = response_content_edited.strip().split('\n\n')
paragraphs_dict = [{'paragraph_number': i, 'paragraph_text': paragraph} for i, paragraph in enumerate(paragraphs)
print(paragraph_dict)
[{'paragraph_number': 0,
'paragraph_text': "Good afternoon, everyone, and welcome to MIT 6.S191\. My name is Alexander Amini, and I'll be one of your instructors for the course this year, along with Ava. We're really excited to welcome you to this incredible course."},
{'paragraph_number': 1,
'paragraph_text': "This is a fast-paced and intense one-week course that we're about to go through together. We'll be covering the foundations of a rapidly changing field, and a field that has been revolutionizing many areas of science, mathematics, physics, and more."},
{'paragraph_number': 2,
'paragraph_text': "Over the past decade, AI and deep learning have been rapidly advancing and solving problems that we didn't think were solvable in our lifetimes. Today, AI is solving problems beyond human performance, and each year, this lecture is getting harder and harder to teach because it's supposed to cover the foundations of the field."}]
注意,输入不应过长,否则大型语言模型(LLM)可能会‘忘记’部分文本。对于较长的输入,转录本必须分块以提高可靠性。我们发现 GPT-4o-mini 可以处理最多 5000 个字符,而 Llama 3 8B 只能处理最多 1500 个字符。该笔记本提供了transcript_to_paragraphs函数,它负责将转录本分块。
3) 使用 TF-IDF 推断段落时间戳
转录本现在已经被结构化为一个编辑过的段落列表,但时间戳在此过程中丢失了。
第三阶段是通过推断原始转录中哪个片段最接近每个段落,来重新添加时间戳。

TF-IDF 用于找到哪个原始转录片段(右侧)最能匹配编辑后的段落开头(左侧)。
对于这项任务,我们依赖于TF-IDF 度量。TF-IDF 代表词频-逆文档频率,是一种用于比较两段文本相似度的度量方法。该度量通过计算相似单词的数量来工作,且对出现频率较低的单词给予更高权重。
作为预处理步骤,我们调整转录片段和段落的起始位置,以使它们包含相同数量的单词。文本片段应足够长,以便段落开头能够成功匹配到一个唯一的转录片段。我们发现,使用 50 个单词在实际操作中效果较好。
num_words = 50
transcript_num_words = transform_text_segments(transcript, num_words=num_words)
paragraphs_start_text = [{"start": p['paragraph_number'], "text": p['paragraph_text']} for p in paragraphs]
paragraphs_num_words = transform_text_segments(paragraphs_start_text, num_words=num_words)
接着,我们依赖sklearn库及其TfidfVectorizer和cosine_similarity函数来运行 TF-IDF,并计算每个段落开头和转录片段之间的相似度。以下是用于找到第一个段落在转录片段中最佳匹配索引的代码示例。
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
# Paragraph for which to find the timestamp
paragraph_i = 0
# Create a TF-IDF vectorizer
vectorizer = TfidfVectorizer().fit_transform(transcript_num_words + paragraphs_num_words)
# Get the TF-IDF vectors for the transcript and the excerpt
vectors = vectorizer.toarray()
# Extract the TF-IDF vector for the paragraph
paragraph_vector = vectors[len(transcript_num_words) + paragraph_i]
# Calculate the cosine similarity between the paragraph vector and each transcript chunk
similarities = cosine_similarity(vectors[:len(transcript_num_words)], paragraph_vector.reshape(1, -1))
# Find the index of the most similar chunk
best_match_index = int(np.argmax(similarities))
我们将该过程封装在add_timestamps_to_paragraphs函数中,该函数为段落添加时间戳,同时添加匹配的片段索引和文本:
paragraphs = add_timestamps_to_paragraphs(transcript, paragraphs, num_words=50)
#Example of output for the first paragraph:
print(paragraphs[0])
{'paragraph_number': 0,
'paragraph_text': "Good afternoon, everyone, and welcome to MIT 6.S191\. My name is Alexander Amini, and I'll be one of your instructors for the course this year, along with Ava. We're really excited to welcome you to this incredible course.",
'matched_index': 1,
'matched_text': 'good afternoon everyone and welcome to',
'start_time': 10}
在上面的示例中,第一个段落(编号为 0)被发现与从时间 10 秒(秒为单位)开始的转录片段编号 1 匹配。
4) 生成目录
目录通过将连续的段落分组为章节,并识别有意义的章节标题来生成。此任务主要由 LLM 执行,LLM 被指示将由 JSON 段落列表组成的输入转换为包含 JSON 章节标题及其起始段落编号的输出:
system_prompt_paragraphs_to_toc = """
You are a helpful assistant.
You are given a transcript of a course in JSON format as a list of paragraphs, each containing 'paragraph_number' and 'paragraph_text' keys.
Your task is to group consecutive paragraphs in chapters for the course and identify meaningful chapter titles.
Here are the steps to follow:
1\. Read the transcript carefully to understand its general structure and the main topics covered.
2\. Look for clues that a new chapter is about to start. This could be a change of topic, a change of time or setting, the introduction of new themes or topics, or the speaker's explicit mention of a new part.
3\. For each chapter, keep track of the paragraph number that starts the chapter and identify a meaningful chapter title.
4\. Chapters should ideally be equally spaced throughout the transcript, and discuss a specific topic.
Format your result in JSON, with a list dictionaries for chapters, with 'start_paragraph_number':integer and 'title':string as key:value.
Example:
{"chapters":
[{"start_paragraph_number": 0, "title": "Introduction"},
{"start_paragraph_number": 10, "title": "Chapter 1"}
]
}
"""
一个重要的元素是特别要求输出 JSON 格式,这样可以增加获得正确格式的 JSON 输出的机会,后续可以在 Python 中重新加载。
本任务使用 GPT-4o-mini,因为它比 OpenAI 的 GPT-4o 更具成本效益,并且通常能提供良好的结果。指令通过‘system’角色提供,段落则通过‘user’角色以 JSON 格式提供。
# Connect to OpenAI with an OpenAI API key
llm_client_get_toc = OpenAI(api_key=api_key)
model_get_toc = "gpt-4o-mini-2024-07-18"
# Dump JSON paragraphs as text
paragraphs_number_text = [{'paragraph_number': p['paragraph_number'], 'paragraph_text': p['paragraph_text']} for p in paragraphs]
paragraphs_json_dump = json.dumps(paragraphs_number_text)
# Call LLM
response = client_get_toc.chat.completions.create(
messages=[
{
"role": "system",
"content": system_prompt_paragraphs_to_toc
},
{
"role": "user",
"content": paragraphs_json_dump
}
],
model=model_get_toc,
temperature=0,
seed=42
)
完成!调用返回了包含章节标题及起始段落编号的 JSON 格式列表:
print(response)
{
"chapters": [
{
"start_paragraph_number": 0,
"title": "Introduction to the Course"
},
{
"start_paragraph_number": 17,
"title": "Foundations of Intelligence and Deep Learning"
},
{
"start_paragraph_number": 24,
"title": "Course Structure and Expectations"
}
....
]
}
如同步骤 2 所示,LLM 可能会处理长输入时出现问题,并且忽略输入的一部分。解决方案依然是将输入拆分成多个部分,在笔记本中通过paragraphs_to_toc函数和chunk_size参数实现。
5)输出结构化的章节
最后的阶段将段落和目录结合起来,创建一个结构化的 JSON 文件,其中包含章节,示例如下所示,见附带的 Github 仓库。
我们在下图中展示了最终的章节划分(右)与 YouTube 描述中提供的基准章节划分(左)的对比:

并排比较 YouTube 的基准章节划分(左)和我们的划分(右)
这个对比主要是定性的,因为没有‘标准答案’。总体来说,本文中描述的方法识别了相似的章节,但提供了一个稍微更精细的划分。对两种章节划分的手动检查表明,基准章节划分在课程信息上存在偏差,实际上课程是在 9:37 开始,而不是 7:25。
另有一些章节划分的示例,见这个HuggingFace 空间。整个工作流最终作为 Gradio 应用程序捆绑在附带的笔记本中,使得你可以更轻松地在自己的视频上进行测试。

Gradio 应用程序,它将不同的步骤捆绑在一起,并从原始转录文本输出一个结构良好的文档
为了进一步深入
-
[从文本分割到智能章节划分:
注意:
- 除非另有说明,所有图片均为作者提供
喜欢这篇文章吗?分享你的想法,给予掌声,或 在 LinkedIn 上与我联系 。
数据质量错误检测由 LLM 驱动
·发表于Towards Data Science ·阅读时间 17 分钟·2024 年 3 月 22 日
--
本文是关于使用大型语言模型(LLM)清理数据系列文章的第二篇,重点介绍在表格数据集中识别错误。

本文概述了我们将要探索的方法论,重点是评估表格数据集的脏污分数,且几乎不需要人工干预。
数据脏污分数
鼓励读者首先阅读关于数据脏污分数的介绍性文章,文章解释了关键假设,并演示了如何计算该分数。
简单回顾一下,数据脏污分数估算数据集中包含错误的单元格的预期比例。这个指标背后的主要假设如下:
-
数据错误与违反的约束有关。
-
如果没有预期,则不会影响分数。
-
数据问题可以精确定位到特定单元格。
-
每个数据错误都会被分配一个置信度分数。
-
每个单元格对整体得分有相同的影响。
这一过程的第一步是识别和分类数据集中的数据不准确性。
自动检测数据质量问题的重要性
检测数据问题在这个过程中至关重要,但由于多个因素,通常具有挑战性:
-
高人工标注成本:识别数据错误通常需要大量来自数据专业人员(如科学家、工程师和分析师)或主题专家(SME)的投入。这需要大量时间且成本昂贵。
-
数据从业者对这项繁琐工作的热情缺乏:数据清洗被许多业内人士视为工作中不太吸引人的部分,这并不是什么秘密。数据清洗通常被视为比建模、构建现代数据架构或回答业务查询等更具吸引力活动的前奏,因此它常常排在优先级较低的位置,导致拖延,甚至在某些情况下,直到出现重大问题才被完全忽视。
-
领域专家的局限性:领域专家(SMEs)拥有宝贵的知识,但可能缺乏像 SQL 或编程这样的技术技能。虽然无代码和低代码工具在某种程度上有所帮助,但它们尚未被完全采用,且可能无法覆盖所有数据管理方面,比如版本控制。
-
专业知识差距:有效的数据清洗超越了基本技能,要求具备专门的专业知识。缺乏培训和数据准备方面的普遍冷漠意味着许多从业者只能识别表面错误,忽视了需要更深层次理解的数据清洗中更复杂的问题。
尽管存在固有的挑战,大语言模型(LLM)领域的进展为自动识别简单数据问题并揭示更复杂的数据质量问题提供了有前景的解决方案。
大语言模型驱动的数据错误检测
大语言模型正成为自动化检测数据质量问题的宝贵工具,作为高效的起点,推动富有成效的人机协作迭代过程。诸如Jellyfish: A Large Language Model for Data Preprocessing、Can language models automate data wrangling? 和Large Language Models as Data Preprocessors等论文中讨论的模型,展示了它们在自动化约束生成和数据错误检测方面的潜力。这种自动化并不取代人工干预,而是增强了它,使得人们可以审查和调整自动化约束,通过直接处理问题或调整置信度分数来反映数据错误检测中的不确定性。
LLM(大语言模型)特别适用于检测数据质量问题,因为它们在多样化的互联网内容上进行了广泛的训练,涵盖了大量领域知识和与数据质量问题相关的代码审查示例。这种训练使得 LLM 能够基于文本内容识别数据错误,而无需显式定义规则。通过将表格数据集转换为纯文本(称为序列化),LLM 能够像一支经验丰富的团队一样仔细审查数据,利用其“压缩”的互联网知识来定位错误。这种广泛的训练使得它们能够以类似人类专家的直觉水平,识别出 CSV 文件等人类可读数据集中的潜在错误。此外,任何领域特定知识的空白都可以通过检索增强生成(RAG)等技术,或者通过根据数据集的特定性质调整模型的提示来弥补。
在数据错误检测中使用 LLM 的另一个关键优势是它们能够处理与数据质量问题相关的固有不确定性。并非所有错误都是直观的,甚至专家有时也会对什么构成数据问题产生分歧。LLM 能够像人类一样,根据直觉和经验的结合,为其发现的错误分配置信度分数,从而反映错误发生的可能性。
在不同数据集和潜在问题中推广错误检测的挑战相当巨大。传统方法通常依赖一套广泛的决策规则或结合专门的机器学习模型来处理各种场景,比如检查地址和电话号码的有效性或进行异常检测。这正是 LLM 的优势所在,它们提供了一种更具适应性且劳动强度更低的替代方案。LLM 能够理解并识别各种数据质量问题,而无需庞大的基于规则的系统或领域特定的模型,这使得它们成为一种无价的工具。与传统商业规则或统计方法相比,机器学习方法的优势相当引人注目。机器学习的采用是由于其相对易用性和在不同用例中的适应性,不需要过多的领域特定知识,也无需花费大量时间进行实施。
接下来,我们将通过一个实际示例演示这种方法。
一个案例研究
在上一篇文章中,我们通过使用来自《Cleaning Data for Effective Data Science》一书的数据集示例,探讨了数据脏污评分的概念。相关的数据集如下:
Student#,Last Name,First Name,Favorite Color,Age
1,Johnson,Mia,periwinkle,12
2,Lopez,Liam,blue,green,13
3,Lee,Isabella,,11
4,Fisher,Mason,gray,-1
5,Gupta,Olivia,9,102
6,,Robinson,,Sophia,,blue,,12
数据错误已经被指出。现在,我们想探索如何使用大型语言模型,特别是GPT-4,自动发现这些错误。这种新方法提供了一种现代化的方式来发现数据集中的问题,但也带来了可能的风险,如使用外部 API 时的隐私问题。然而,这种方法不仅适用于GPT-4,还可以与任何 LLM 配合使用,尽管效果可能会因模型的能力而有所不同。
初步步骤:检索表格注释
为了帮助模型识别数据不一致性,提供有关数据框的额外上下文信息是很有帮助的。这正是data catalog(数据目录)的作用,尽管这个话题非常广泛,我们将简化为仅关注 LLM 在检查数据集行批次时识别数据错误所需的基本上下文信息。
所需的关键元数据包括:
-
对表格的概述,包括其描述和用途。
-
对每个列的含义和类型有清晰的理解。
鉴于在组织中常常缺乏数据目录或可靠的文档,我们将探讨如何使用 LLM 加速这一过程。这个过程被称为表格注释,它涉及识别表格元素的语义信息,包括列、列之间的关系以及单元格中的实体。欲了解更多详细信息,请参考以下资源:使用 ChatGPT 进行列类型注释、使用预训练语言模型进行列注释或SOTAB:WDC Schema.org 表格注释基准。
这是我使用的提示:
Analyse the table below and provide schema annotations based on Schema.org standards.
Student#,Last Name,First Name,Favorite Color,Age
1,Johnson,Mia,periwinkle,12
2,Lopez,Liam,blue,green,13
3,Lee,Isabella,,11
4,Fisher,Mason,gray,-1
5,Gupta,Olivia,9,102
6,,Robinson,,Sophia,,blue,,12
Follow these steps:
1\. Identify the overall semantic type of the table.
2\. Provide a short description of each column.
3\. Annotate each column with its semantic type from Schema.org.
4\. Determine the most suitable data type for each column (after data cleaning).
Summarise the table schema as follows:
- Table Semantic Type: <type>
- Column: <name>, Description: <description>, Semantic Type: <Schema.org type>, Pandas Type: <Pandas data type>
在提示指令中,我要求模型分析提供的表格(或表格概述),并建议按照Schema.org标准进行注释。具体来说,输出应包括:
-
表格的语义类型
-
每个列的简要描述
-
来自 Schema.org 的列注释类型,如适用
-
每列的理想或最佳匹配的数据类型,无论提供的文本序列化中是否存在数据问题
然后,响应会被格式化为提供清晰和结构化的摘要,可以作为后续提示中的上下文重用:
Summarise the table schema as follows:
- Table Semantic Type: <type>
- Column: <name>, Description: <description>, Semantic Type: <Schema.org type>, Pandas Type: <Pandas data type>
表格通过{table}占位符插入提示中。典型的方法是通过序列化将表格数据转换为文本,如在大型语言模型(LLMs)在表格数据上的应用:预测、生成和理解——综述中所讨论的那样。
以下是GPT-4的示例响应:
- Table Semantic Type: Person
- Column: Student#, Description: Unique identifier for each student, Semantic Type: identifier, Pandas Type: int
- Column: Last Name, Description: Family name of the student, Semantic Type: familyName, Pandas Type: string
- Column: First Name, Description: Given name of the student, Semantic Type: givenName, Pandas Type: string
- Column: Favorite Color, Description: Preferred color of the student, Semantic Type: color (custom), Pandas Type: string (or list if cleaned for multiple values)
- Column: Age, Description: Age of the student, Semantic Type: age, Pandas Type: int (after cleaning invalid entries)
回复可能会略有不同,但对于这样一个简单的示例通常是一致的。这里的目的是加速初步过程,而不是完全自动化。因此,这可以视为一个初步草稿,之后可以根据我们知识的见解和来自领域专家(SMEs)的外部上下文进行完善。
现在,了解了一些表格的上下文后,让我们探讨如何自动识别数据质量问题。
使用 LLMs 检测数据错误
首先,我建议使用一个提示,帮助识别给定表格中的数据质量问题。
Task: Analyse the provided table to identify and document data quality issues.
Below are common data quality issues to guide your analysis. However, you may also identify other relevant issues:
- Ingestion errors
- Typecasting issues
- Duplicates
- Date parsing issues
- Character encoding problems
- Missing values
- Typos/spelling mistakes
- Anomalies/outliers
- Conversion errors and inconsistent units
- Privacy concerns (e.g., exposed PII)
- Domain-specific errors (e.g., invalid formats for addresses, phone numbers, emails)
Instructions:
1\. Examine silently the table and its metadata.
2\. Line by line, identify potential data quality issues without coding.
3\. Document each issue, including:
- Nature and description of the issue
- Expected correct state
- Violated constraint
- Confidence level in your assessment using ordinal categories: `low`, `medium`, `high` and `certain`.
- Specific location of the issue in the table (use 'None' for table-wide issues): Index and Column names.
Provided Data:
Table:
,Student#,Last Name,First Name,Favorite Color,Age
0,1,Johnson,Mia,periwinkle,12
1,2,Lopez,Liam,blue,green,13
2,3,Lee,Isabella,,11
3,4,Fisher,Mason,gray,-1
4,5,Gupta,Olivia,9,102
5,6,,Robinson,,Sophia,,blue,,12
Metadata:
- Table Semantic Type: Person
- Column: Student#, Description: Unique identifier for each student, Semantic Type: identifier, Pandas Type: int or string
- Column: Last Name, Description: Family name of the student, Semantic Type: familyName, Pandas Type: string
- Column: First Name, Description: Given name of the student, Semantic Type: givenName, Pandas Type: string
- Column: Favorite Color, Description: Preferred color of the student, Semantic Type: color (custom), Pandas Type: string (or list if cleaned for multiple values)
- Column: Age, Description: Age of the student, Semantic Type: age, Pandas Type: int (after cleaning invalid entries)
Detected Data Issues:
提示的初始部分设置了任务目标,并列出了常见的数据问题示例,如数据摄取错误、重复数据和隐私问题等。此列表并不详尽,鼓励您根据表格的上下文添加更多相关类型,以指导分析。
接下来,提示详细说明了逐步指令,采用Chain-of-Thoughts方法,确保模型有条不紊地分析表格及其元数据,在逐行识别数据问题时模拟人工分析。这个过程应在没有编码的情况下进行,以保持简洁性和广泛适用性。这个步骤非常重要,因为尽管像GPT-4这样的模型具有分析能力,可以进行有用的迭代编码会话,但仅依赖于文本分析有助于推广通用性。
在检测到潜在的数据问题时,提示会指导记录以下详细信息:
-
问题的性质和描述
-
期望的正确状态
-
违反的约束
-
使用序数类别的置信度等级进行评估:
low(低)、medium(中)、high(高)和certain(确定)。 -
数据问题的具体位置,如果是全表问题,则使用‘None’,并附上索引和列名以供参考。
提示中提供了表格及其元数据,并为每一行添加了索引,帮助模型精确定位错误位置。
对于大规模表格,可以分批应用此提示以涵盖整个数据集,并将结果汇总以识别所有数据质量问题。
下面是该提示生成的输出示例,格式为报告,详细列出了已识别的数据问题,每个问题包括描述、期望状态、违反的约束、置信度等级和位置。
1\. Issue: Incorrect value format in "Favorite Color" column
- Nature and Description: The entry "blue,green" in the "Favorite Color" column for the second row appears to contain two colors instead of one, suggesting a possible ingestion or data entry error.
- Expected Correct State: Each cell under the "Favorite Color" column should contain only one color.
- Violated Constraint: Single value constraint.
- Confidence Level: High
- Specific Location: Index 1, Column "Favorite Color"
2\. Issue: Missing value in "Favorite Color" and "Age" columns
- Nature and Description: The third row has a missing value for "Favorite Color", and the sixth row has missing values for both "Last Name" and "Favorite Color".
- Expected Correct State: No missing values in any of the columns.
- Violated Constraint: Non-null constraint.
- Confidence Level: Certain
- Specific Location: Index 2, Column "Favorite Color"; Index 5, Columns "Last Name", "Favorite Color"
3\. Issue: Negative value in "Age" column
- Nature and Description: The fourth row has an age of "-1", which is not possible and indicates a data entry error.
- Expected Correct State: Age values should be positive integers.
- Violated Constraint: Age value range (greater than 0).
- Confidence Level: Certain
- Specific Location: Index 3, Column "Age"
4\. Issue: Misplaced values in "Favorite Color" and "Age" columns
- Nature and Description: In the fifth row, what appears to be an age ("9") is placed in the "Favorite Color" column, and an unlikely age ("102") is in the "Age" column, suggesting a possible shifting or misplacement of values.
- Expected Correct State: Age should be a reasonable positive integer, and favorite color should be a string denoting a color.
- Violated Constraint: Data type and value range constraint.
- Confidence Level: High
- Specific Location: Index 4, Columns "Favorite Color", "Age"
5\. Issue: Inconsistent formatting in row entries
- Nature and Description: The sixth row contains additional commas, which might indicate incorrect splitting or entry of the data, leading to misalignment of the values with their respective columns.
- Expected Correct State: Each row should consistently follow the format defined by the column headers without extra delimiters.
- Violated Constraint: Data format and consistency constraint.
- Confidence Level: High
- Specific Location: Index 5, Entire row
下一步是将这些识别出的问题转换为 Python 对象,以便更容易计算数据脏度评分。
将识别的数据问题转换为正确格式
本节重点讲解如何将先前识别的数据问题从简单文本描述转换为 Python 对象。这些对象应遵循DataIssue类定义的结构:
@dataclass
class DataIssue:
type_of_issue: str
expectation: str
constraint_violated: str
confidence_score: str # `low`, `medium`, `high` or `certain`
location: List[Tuple] # Cell positions as (Index, Column). Use None for row/column-wide issues.
这是我使用的提示:
Task: Document data table issues using the provided `DataIssue` class.
For each identified issue and its description, document it using the provided `DataIssue` class structure.
```python
from dataclasses import dataclass
from typing import List, Tuple
@dataclass
class DataIssue:
type_of_issue: str
expectation: str
constraint_violated: str
confidence_score: str # `低`,`中`,`高` 或 `确定`
location: List[Tuple] # 单元格位置以(索引,列)表示。使用 None 表示整行或整列的问题。
Instructions:
1\. 静默检查所有识别出的问题及其描述。
2\. 对于每个问题,使用提供的`DataIssue`类结构实例化它。
3\. 仅返回代码。
4\. 一旦代码验证完成,停止生成。
识别的问题:
{issues_found}
Your code here:
```py
The goal is to generate Python code representing each data issue as a `DataIssue` object without redundant text. The identified issues are injected into the prompt using the placeholder `{issues_found}`.
Example output:
issue1 = DataIssue(
type_of_issue="格式不正确",
expectation="‘喜欢的颜色’列下的每个单元格应只包含一个颜色",
constraint_violated="单值约束",
confidence_score="高",
location=[(1, "喜欢的颜色")]
)
issue2 = DataIssue(
type_of_issue="缺失值"
expectation="任何列中都不应有缺失值",
constraint_violated="非空约束",
confidence_score="确定",
location=[(2, "喜欢的颜色"), (5, "姓氏"), (5, "喜欢的颜色")]
)
issue3 = DataIssue(
type_of_issue="负值",
expectation="年龄值应为正整数",
constraint_violated="年龄值范围(大于 0)",
confidence_score="确定",
location=[(3, "年龄")]
)
issue4 = DataIssue(
type_of_issue="值错位",
expectation="年龄应该是一个合理的正整数,喜欢的颜色应该是表示颜色的字符串",
constraint_violated="数据类型和数值范围约束",
confidence_score="高",
location=[(4, "喜欢的颜色"), (4, "年龄")]
)
issue5 = DataIssue(
type_of_issue="格式不一致",
expectation="每一行应该始终遵循由列标题定义的格式,且不应有额外的分隔符",
constraint_violated="数据格式和一致性约束",
confidence_score="高",
location=[(5, None)] # None 表示整个行的问题
)
The final step involves converting the `location` attribute from lists of tuples to `numpy` arrays, which is detailed in the appendix.
With all elements in place, we can now calculate the *Data Dirtiness Score*.
# Calculation of the Data Dirtiness Score and Comparison with Ground Truth
Let’s revisit the function from the previous article, `compute_data_dirtiness_score`, which uses a list of `DataIssue` objects mentioned earlier.
compute_data_dirtiness_score(data_issues)
> *Data Dirtiness Score: 28.33%*
Using the `GPT-4` model, we estimated the score to be around 28% for this sample. This is fairly close to the "ground truth" score of 31.87%.
To understand the discrepancy between these scores, let’s delve into more detailed metrics on data issue detection. In addition to the overall score, we have matrices of cell issue probabilities for both the ground truth and the model’s estimates.
Below is the ground truth matrix, with columns and indices added for clarity:
学生编号 姓氏 名字 喜欢的颜色 年龄
0 0.00 0.0 0.00 0.00 0.00
1 0.00 0.0 0.00 0.75 0.00
2 0.00 0.0 0.00 1.00 0.00
3 0.00 0.0 0.00 0.00 1.00
4 0.00 0.0 0.00 0.75 0.75
5 0.75 1.0 0.75 1.00 0.75
And here is the matrix of probabilities estimated by the model:
学生编号 姓氏 名字 喜欢的颜色 年龄
0 0.0 0.0 0.00 0.0000 0.00
1 0.0 0.0 0.00 0.7500 0.00
2 0.0 0.0 0.00 1.0000 0.00
3 0.0 0.0 0.00 0.0000 1.00
4 0.0 0.0 0.25 0.8125 0.75
5 1.0 1.0 1.00 1.0000 1.00
Though the matrices appear similar at first glance, we can apply threshold-based metrics such as `accuracy`, `recall`, `precision`, and `F1-score` to get a clearer picture. These metrics provide a straightforward evaluation of the model's performance by considering a cell problematic if the model's likelihood exceeds 0\. Here are the metrics obtained:

The model correctly identified 91% of problematic cells (`recall`), and all of its error predictions were accurate (`precision`).
The model missed one particular issue: “The `Favorite Color` and `First Name` fields might be swapped, considering `Olivia` can be both a name and a colour." This was deemed improbable with a `low` confidence score, suggesting `Olivia` is more likely the `First Name` rather than the `Favorite Color`. Consequently, even though this potential issue was overlooked, its minimal confidence score lessened its impact on the overall Data Dirtiness Score. This explains why the two scores are relatively close despite this omission.
In summary, this approach, based on large language models (LLMs), offers a method for detecting data quality issues in a data frame. While this method may not yet be fully automated and might need manual adjustments, it’s hoped that it will expedite the detection of data errors and the calculation of the *Data Dirtiness Score* for tabular data sets.
# Next Steps and Challenges
I use a two-step process to generate the issues as code. This is done because I have found this adds more stability over a one-in-all solution, i.e. scanning data set and metadatas and outputs data issues directly in right code format. This doesn’t imply it’s impossible, but I’ve chosen to divide this step into two phases to improve robustness for the time being.
An issue we face concerns managing large data sets, both in terms of the number of rows and columns. Despite recent advancements, LLMs still face limitations regarding the input context window and the length of generated content. These constraints limit the size of the table that can be serialised into the prompt for analysis and the length of the data issue report produced by the model. How to divide a data frame based on its size and the model’s capabilities is a question that arises.
In certain scenarios, the lack of general context can be problematic, such as when identifying duplicate rows in a database or detecting spelling errors without a broad understanding of the column values. For instance, in cases where duplicates are not straightforward, a common approach is **Entity Matching**. This technique is particularly useful in data cleaning processes and has seen advancements through the use of Large Language Models. Relevant research in this area includes studies like [Entity Matching using Large Language Models](https://www.semanticscholar.org/paper/Entity-Matching-using-Large-Language-Models-Peeters-Bizer/13c2ae7831c0f1579bc8c6f1a31c9aa8689e24a8) and [Can Foundation Models Wrangle Your Data?](https://arxiv.org/abs/2205.09911), along with [Large Language Models as Data Preprocessors](https://arxiv.org/abs/2308.16361) and [Jellyfish: A Large Language Model for Data Preprocessing](https://www.semanticscholar.org/reader/7e17ef56273063dfa838de30b7cc0546b2e5ee10).
Ensemble methods in machine learning, which involve combining multiple models, can enhance performance and stability. This approach can be applied by running several LLMs simultaneously to identify issues in a data set. It’s beneficial to vary the prompts and settings for each LLM to ensure a diverse range of insights. Additionally, assigning specific error types, like spelling mistakes, to individual models can make the process more efficient. While this method can lead to more reliable results by dividing the task into smaller parts, it also increases both the cost and the complexity of the software. By gathering all the identified data issues, we can improve our chances of finding errors (increasing recall) but might also identify more false errors (decreasing precision). However, reviewing these identified errors is generally less time-consuming than finding them in the first place.
The ability of LLMs to interact directly with databases, similar to the code analysis capability in `ChatGPT-4`, opens up a wider range of possibilities for detecting data errors. A challenge here is automating this process, as the model may deviate from its intended path without sufficient guidance.
Despite all the challenges, it is already quite promising what we can achieve with such as simple approach. With more work on engineering, I hope we can very soon provide a more robust solution to cover larger data sets and fully automate the detection process.
The next article will discuss automated data repair or, at the very least, suggest solutions for repair pending validation.
# References
* [Data Dirtiness Score](https://medium.com/p/fe2ca5678d40)
* [Jellyfish: A Large Language Model for Data Preprocessing](https://www.semanticscholar.org/reader/7e17ef56273063dfa838de30b7cc0546b2e5ee10)
* [Can language models automate data wrangling?](http://josephorallo.webs.upv.es/escrits/MLJ-DataWranglingAutomation.pdf)
* [Large Language Models as Data Preprocessors](https://arxiv.org/abs/2308.16361)
* [Column Type Annotation using ChatGPT](https://arxiv.org/abs/2306.00745)
* [Annotating Columns with Pre-trained Language Models](https://paperswithcode.com/paper/annotating-columns-with-pre-trained-language)
* [SOTAB: The WDC Schema.org Table Annotation Benchmark](https://paperswithcode.com/paper/sotab-the-wdc-schema-org-table-annotation)
* [Large Language Models(LLMs) on Tabular Data: Prediction, Generation, and Understanding — A Survey](https://arxiv.org/abs/2402.17944?utm_campaign=Data_Elixir&utm_source=Data_Elixir_475)
* [Entity Matching using Large Language Models](https://www.semanticscholar.org/paper/Entity-Matching-using-Large-Language-Models-Peeters-Bizer/13c2ae7831c0f1579bc8c6f1a31c9aa8689e24a8)
# Appendix
The section explains how to transform the `location` attribute of a `DataIssue` object, which comes from a LLM, into a different format. This transformation changes a list of tuples, which represent cell positions, into a `numpy`array. This array acts as a mask for those cell positions.
Here's a basic example using the `Students` data set:
create_mask_from_list_of_cell_positions(
shape=dataset_shape,
list_of_cell_positions=[(4, '喜欢的颜色'), (4, '年龄')],
columns=columns
)
array([[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[0, 0, 0, 1, 1],
[0, 0, 0, 0, 0]], dtype=int8)
Below are the function definitions:
def validate_cell_position(
cell_position: Union[
Tuple[int, int], Tuple[None, int], Tuple[int, None], Tuple[None, None]
],
columns: List[str] = None,
) -> Tuple[int, int]:
"""
验证单元格位置并在必要时将列名转换为索引。
"""
if not isinstance(cell_position, tuple):
raise ValueError("单元格位置必须是元组")
# 如果提供了列名,则将列名转换为索引
if isinstance(cell_position[1], str):
if columns is None:
raise ValueError(
"必须提供列名,以便根据列名创建掩码"
)
column_index = columns.index(cell_position[1])
return (cell_position[0], column_index)
return cell_position
def set_mask_values(mask: np.ndarray, cell_position: Tuple[int, int]):
"""
根据单元格位置设置掩码中的值。
"""
row_index, col_index = cell_position
if row_index is None:
mask[:, col_index] = 1
elif col_index is None:
mask[row_index, :] = 1
else:
mask[row_index, col_index] = 1
def create_mask_from_list_of_cell_positions(
shape: Tuple[int, int],
list_of_cell_positions: List[Tuple],
columns: List[str] = None,
) -> np.ndarray:
"""
根据单元格位置列表创建掩码数组。
"""
mask = np.zeros(shape=shape, dtype=np.int8)
for cell_position in list_of_cell_positions:
validated_position = validate_cell_position(cell_position, columns)
set_mask_values(mask, validated_position)
return mask
# 自动化提示工程
> 原文:[`towardsdatascience.com/automated-prompt-engineering-78678c6371b9?source=collection_archive---------0-----------------------#2024-03-10`](https://towardsdatascience.com/automated-prompt-engineering-78678c6371b9?source=collection_archive---------0-----------------------#2024-03-10)
## 一些反思、文献综述以及关于大型语言模型自动化提示工程的实验
[](https://ianhojy.medium.com/?source=post_page---byline--78678c6371b9--------------------------------)[](https://towardsdatascience.com/?source=post_page---byline--78678c6371b9--------------------------------) [Ian Ho](https://ianhojy.medium.com/?source=post_page---byline--78678c6371b9--------------------------------)
·发表于[Towards Data Science](https://towardsdatascience.com/?source=post_page---byline--78678c6371b9--------------------------------) ·13 分钟阅读·2024 年 3 月 10 日
--

图片由作者在 DALL-E 的帮助下生成
在过去几个月里,我一直在尝试构建各种基于大型语言模型的应用,事实上,花费了大量时间专门用于改进提示词,以便从大型语言模型中获得我想要的输出。
许多时候,我都会陷入一种存在性的空虚,问自己是否只是一个被美化的提示工程师。考虑到当前与大型语言模型(LLM)的互动状态,我仍然倾向于得出“还不是”的结论,大多数晚上,我能克服我的冒充者综合症。今天不打算深入讨论这个话题。
但我仍然经常想,如果有一天,编写提示的过程能被大部分自动化掉,那会是什么样子。我认为,这个未来场景的答案取决于揭示提示工程的本质。
尽管互联网上有无数关于提示工程的手册,我仍然无法决定提示工程是**艺术还是科学**。
一方面,当我根据输出的结果不断学习和编辑我的提示词时,感觉这像是一种艺术。随着时间的推移,你会意识到一些细节很重要——比如用‘must’而不是‘should’……
# 自动化提示工程:终极实用指南
> 原文:[`towardsdatascience.com/automated-prompt-engineering-the-definitive-hands-on-guide-1476c8cd3c50?source=collection_archive---------0-----------------------#2024-09-04`](https://towardsdatascience.com/automated-prompt-engineering-the-definitive-hands-on-guide-1476c8cd3c50?source=collection_archive---------0-----------------------#2024-09-04)
## 学习如何自动化提示工程,并在 LLM 工作负载中解锁显著的性能提升
[](https://heiko-hotz.medium.com/?source=post_page---byline--1476c8cd3c50--------------------------------)[](https://towardsdatascience.com/?source=post_page---byline--1476c8cd3c50--------------------------------) [Heiko Hotz](https://heiko-hotz.medium.com/?source=post_page---byline--1476c8cd3c50--------------------------------)
·发表于 [Towards Data Science](https://towardsdatascience.com/?source=post_page---byline--1476c8cd3c50--------------------------------) ·阅读时长 21 分钟·2024 年 9 月 4 日
--

作者提供的图片——使用 Imagen 3 创建
# 这篇文章讲什么?
自动化提示工程(APE)是一种自动化生成和优化大型语言模型(LLM)提示的技术,旨在提高模型在特定任务上的表现。它使用提示工程的概念,手动构建和测试各种提示,并将整个过程自动化。如我们将看到的,它与传统监督式机器学习中的自动化超参数优化***非常***相似。
在本教程中,我们将深入探讨自动化提示工程(APE):首先,我们将了解其原理,探讨一些生成提示的策略,以及其他相关技术,如示例选择。然后,我们将进入动手操作部分,从零开始编写一个 APE 程序,即我们不会使用像 DSPy 这样的库来为我们完成这些工作。通过这样做,我们将更好地理解 APE 的原理,并能更好地利用那些可以直接提供此功能的框架。
# 自动微分(AutoDiff):带有示例的简要介绍
> 原文:[`towardsdatascience.com/automatic-differentiation-autodiff-a-brief-intro-with-examples-3f3d257ffe3b?source=collection_archive---------2-----------------------#2024-10-11`](https://towardsdatascience.com/automatic-differentiation-autodiff-a-brief-intro-with-examples-3f3d257ffe3b?source=collection_archive---------2-----------------------#2024-10-11)
## 介绍了自动微分(AutoDiff)的机制,探索其数学原理、实现策略以及在当前最常用框架中的应用
[](https://ebrahimpichka.medium.com/?source=post_page---byline--3f3d257ffe3b--------------------------------)[](https://towardsdatascience.com/?source=post_page---byline--3f3d257ffe3b--------------------------------) [Ebrahim Pichka](https://ebrahimpichka.medium.com/?source=post_page---byline--3f3d257ffe3b--------------------------------)
·发表于 [Towards Data Science](https://towardsdatascience.com/?source=post_page---byline--3f3d257ffe3b--------------------------------) ·阅读时长 10 分钟·2024 年 10 月 11 日
--

图片来自 [Bozhin Karaivanov](https://unsplash.com/@bkaraivanov?utm_source=medium&utm_medium=referral) 于 [Unsplash](https://unsplash.com/?utm_source=medium&utm_medium=referral)
# 微分在现代机器学习优化中的基础作用
机器学习的核心在于优化损失/目标函数。这个优化过程在很大程度上依赖于计算这些函数相对于模型参数的梯度。正如 Baydin 等人(2018 年)在他们的全面调查中阐明的那样[1],这些梯度引导了优化算法中的迭代更新,比如随机梯度下降(SGD):
*θₜ₊₁ = θₜ - α ∇θ L(θₜ)*
其中:
+ θₜ 表示步骤 t 时的模型参数
+ α 是学习率
+ ∇_θ L(θₜ) 表示损失函数 L 相对于参数 θ 的梯度
这个简单的更新规则掩盖了在拥有数百万甚至数十亿参数的深度神经网络中计算梯度的复杂性。
# 2\. 微分三位一体
# 使用 GroundingDino 进行自动标注
> 原文:[`towardsdatascience.com/automatic-labeling-of-object-detection-datasets-using-groundingdino-b66c486656fe?source=collection_archive---------3-----------------------#2024-02-06`](https://towardsdatascience.com/automatic-labeling-of-object-detection-datasets-using-groundingdino-b66c486656fe?source=collection_archive---------3-----------------------#2024-02-06)
## 本文是一个实用指南,讲解如何使用 GroundingDino 算法标注物体检测数据集。包括代码。
[](https://medium.com/@lihigurarie?source=post_page---byline--b66c486656fe--------------------------------)[](https://towardsdatascience.com/?source=post_page---byline--b66c486656fe--------------------------------) [Lihi Gur Arie, PhD](https://medium.com/@lihigurarie?source=post_page---byline--b66c486656fe--------------------------------)
·发布于 [Towards Data Science](https://towardsdatascience.com/?source=post_page---byline--b66c486656fe--------------------------------) ·6 分钟阅读·2024 年 2 月 6 日
--

作者使用 GroundingDino 并输入“成熟番茄”提示进行标注。图像由[Markus Spiske](https://www.pexels.com/photo/green-and-red-oval-fruits-965740/)提供。
# 介绍
直到最近,物体检测模型执行的是特定任务,比如检测图像中的企鹅。然而,深度学习的最新进展催生了基础模型。这些模型是在庞大的数据集上以通用方式训练的大型模型,使它们能够适应各种任务。例如,像[CLIP](https://medium.com/towards-data-science/clip-creating-image-classifiers-without-data-b21c72b741fa)这样的模型用于图像分类,SAM 用于分割,GroundingDino 用于物体检测。基础模型通常较大且计算要求高。如果没有资源限制,它们可以直接用于零-shot 推理。否则,它们可以用于标注数据集,以训练一个更小、更具体的模型,这一过程称为蒸馏。
在本指南中,我们将学习如何使用 GroundingDino 模型进行番茄图像的零-shot 推理。我们将探索该算法的能力,并利用它标注整个番茄数据集。得到的数据集随后可以用于训练下游目标模型,如 YOLO。
> 如果你没有付费的 Medium 账号,可以在这里免费阅读。
# GroundingDino
***背景***
GroundingDino 是由 IDEA-Research 在 2023 年开发的最先进(SOTA)算法[1]。它通过文本提示从图像中检测物体。名称“GroundingDino”结合了“grounding”(一个将视觉和语言理解连接在 AI 系统中的过程)和基于变换器的检测器“DINO”[2]。该算法是一个零-shot 物体检测器,这意味着它可以识别那些它没有专门训练过的类别的物体,而无需看到任何示例(shot)。
***架构***
1. 模型接收图像和文本描述的配对作为输入。
1. 图像特征通过**图像骨干网络**(如 Swin Transformer)提取,文本特征通过**文本骨干网络**(如 BERT)提取。
1. **特征增强器**模块通过多模态精细化结合文本和图像特征,使用交叉注意机制促进这两种模态之间的互动。
1. 接下来,‘**语言引导查询选择**’模块选择与输入文本最相关的特征作为解码器查询。
1. 然后,这些查询被输入到**解码器**中,以精细调整与文本信息最佳对齐的物体检测框的预测。它输出最终的边界框建议。
1. 该模型输出 900 个物体边界框及其与输入文本的相似度分数。相似度分数高于`box_threshold`的框会被选中,且相似度高于`text_threshold`的单词会作为预测标签。

图像由 Xiangyu 等人制作,2023 年[3]
***提示工程***
GroundingDino 模型将文本提示编码为一个学习到的潜在空间。改变提示可以产生不同的文本特征,这会影响检测器的性能。为了增强预测性能,建议尝试多个提示,选择一个能提供最佳结果的提示。值得注意的是,在写这篇文章时,我必须尝试多个提示,才能找到理想的那个,有时还会遇到意外的结果。
# 代码实现
***开始使用***
首先,我们将从 GitHub 克隆[GroundingDino 仓库](https://github.com/IDEA-Research/GroundingDINO),通过安装必要的依赖项来设置环境,并下载预训练的模型权重。
```py
# Clone:
!git clone https://github.com/IDEA-Research/GroundingDINO.git
# Install
%cd GroundingDINO/
!pip install -r requirements.txt
!pip install -q -e .
# Get weights
!wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
在图像上的推理
我们将通过将物体检测算法应用于一张番茄的图像来开始探索。我们的初步目标是检测图像中的所有番茄,因此我们将使用文本提示tomato。如果你想使用不同的类别名称,可以用点号.将它们分开。请注意,边界框的颜色是随机的,并没有特定的含义。
python3 demo/inference_on_a_image.py \
--config_file 'groundingdino/config/GroundingDINO_SwinT_OGC.py' \
--checkpoint_path 'groundingdino_swint_ogc.pth' \
--image_path 'tomatoes_dataset/tomatoes1.jpg' \
--text_prompt 'tomato' \
--box_threshold 0.35 \
--text_threshold 0.01 \
--output_dir 'outputs'

使用‘tomato’提示的注释。图片由Markus Spiske提供。
GroundingDino 不仅将物体检测为类别(如番茄),还能够理解输入文本,这一任务被称为指代表达理解(Referring Expression Comprehension,简称 REC)。让我们将文本提示从番茄改为成熟番茄,并获得以下结果:
python3 demo/inference_on_a_image.py \
--config_file 'groundingdino/config/GroundingDINO_SwinT_OGC.py' \
--checkpoint_path 'groundingdino_swint_ogc.pth' \
--image_path 'tomatoes_dataset/tomatoes1.jpg' \
--text_prompt 'ripened tomato' \
--box_threshold 0.35 \
--text_threshold 0.01 \
--output_dir 'outputs'

使用‘成熟番茄’提示的标注。图片来自Markus Spiske。
值得注意的是,模型能够“理解”文本,并区分“番茄”和“成熟的番茄”。它甚至会标注部分成熟但还没有完全变红的番茄。如果我们的任务只需要标注完全成熟的红色番茄,我们可以将box_threshold从默认值 0.35 调整为 0.5。
python3 demo/inference_on_a_image.py \
--config_file 'groundingdino/config/GroundingDINO_SwinT_OGC.py' \
--checkpoint_path 'groundingdino_swint_ogc.pth' \
--image_path 'tomatoes_dataset/tomatoes1.jpg' \
--text_prompt 'ripened tomato' \
--box_threshold 0.5 \
--text_threshold 0.01 \
--output_dir 'outputs'

使用‘成熟番茄’提示和**box_threshold = 0.5**的标注。图片来自Markus Spiske。
标注数据集的生成
尽管 GroundingDino 拥有出色的能力,但它是一个大型且较慢的模型。如果需要实时物体检测,可以考虑使用像 YOLO 这样的更快模型。训练 YOLO 及类似模型需要大量标注数据,这可能既昂贵又耗时。然而,如果你的数据不是独特的,你可以使用 GroundingDino 来标注数据。欲了解更多关于高效 YOLO 训练的信息,请参阅我之前的文章 [4]。
GroundingDino 仓库包含一个脚本,用于以COCO 格式标注图像数据集,这对于 YOLOx 等模型非常适用。
from demo.create_coco_dataset import main
main(image_directory= 'tomatoes_dataset',
text_prompt= 'tomato',
box_threshold= 0.35,
text_threshold = 0.01,
export_dataset = True,
view_dataset = False,
export_annotated_images = True,
weights_path = 'groundingdino_swint_ogc.pth',
config_path = 'groundingdino/config/GroundingDINO_SwinT_OGC.py',
subsample = None
)
-
export_dataset — 如果设置为 True,COCO 格式的标注将保存在名为‘coco_dataset’的目录中。
-
view_dataset — 如果设置为 True,标注后的数据集将在 FiftyOne 应用中显示以进行可视化。
-
export_annotated_images — 如果设置为 True,标注后的图像将保存在名为‘images_with_bounding_boxes’的目录中。
-
subsample (int) — 如果指定,则仅从数据集中标注这个数量的图像。
不同的 YOLO 算法需要不同的标注格式。如果你计划训练 YOLOv5 或 YOLOv8,你需要将数据集导出为YOLOv5 格式。尽管主脚本中硬编码了导出类型,但你可以通过调整create_coco_dataset.main中的dataset_type参数,将其从fo.types.COCODetectionDataset更改为fo.types.YOLOv5Dataset(第 72 行)。为了保持组织性,我们还将输出目录名从‘coco_dataset’更改为‘yolov5_dataset’。更改脚本后,重新运行create_coco_dataset.main。
if export_dataset:
dataset.export(
'yolov5_dataset',
dataset_type=fo.types.YOLOv5Dataset
)
结语
GroundingDino 通过使用文本提示在目标检测标注方面提供了显著的进展。在本教程中,我们探讨了如何使用该模型自动标注图像或整个数据集。然而,在将这些标注用于训练后续模型之前,手动审查和验证这些标注至关重要。
为方便起见,附带了一个包含完整代码的用户友好型 Jupyter 笔记本:
感谢阅读!
想了解更多吗?
参考文献
[0] Colab 笔记本中的代码:link
1 Grounding DINO:将 DINO 与基础预训练结合用于开放集目标检测,2023 年。
[2] Dino:使用改进的去噪锚框进行端到端目标检测,2022 年。
[3] 一个开放且全面的管道,用于统一目标定位和检测,2023 年。
[4] YOLOv5 算法的目标检测实用指南,作者:Dr. Lihi Gur Arie。
JAX 中的自动向量化
让循环飞走吧!
·发布于Towards Data Science ·阅读时间 8 分钟·2024 年 10 月 24 日
--
JAX 因其在数学计算和机器学习中的速度、效率和灵活性而闻名。但它其中一个鲜为人知的超级能力——可以让你摆脱无休止的循环和模板代码——就是自动向量化。

如果你曾经写过处理数组或批量数据的代码,你就知道优化并行化是多么繁琐。但通过 JAX 的**vmap**(向量化映射)函数,你可以告别丑陋的循环,迎来简洁、高效并行化的代码。
在本文中,我们将深入探讨 JAX 中的自动向量化。我们将探索向量化是如何工作的,为什么它对加速计算至关重要,以及如何利用 JAX 的vmap避免写显式的循环。在此过程中,我们将通过一些真实世界的例子,并通过代码展示如何让你更加喜爱 JAX。
准备好了吗?我们出发吧!
向量化到底是怎么回事?
在我们深入了解 JAX 的具体内容之前,先来谈谈向量化的一般概念。在传统编程中,你可能会写出逐一处理每个数据点的循环代码。例如,如果你想对数组中的每个元素应用一个函数,你可能会使用for循环……
使用 Python 和 GitHub Actions 自动化数据管道
一种简单(且免费的)运行数据工作流的方法
·发表于 Towards Data Science ·阅读时间 9 分钟·2024 年 5 月 30 日
--
这是关于全栈数据科学(FSDS)系列中的第 4 篇文章,更大的系列。在上一篇文章中,我分享了一个具体的示例,展示了如何为机器学习项目构建数据管道。然而,这个示例的一个限制是数据管道必须手动运行。虽然这对于某些应用来说可能没问题,但更多时候,自动化比让人手动执行这个过程要更好。在这篇文章中,我将介绍一种使用 Python 和 GitHub Actions 自动化这一过程的简单方法。

图片来源:Chen Mizrach 于 Unsplash
几年前,Andrej Karpathy 发表了一次演讲,描述了“Operation Vacation”1。这就是特斯拉全自动驾驶工程团队所称的目标——完全自动化自驾模型的改进。
尽管这个目标有些玩笑性质,但它展示了我在大多数数据科学家和工程师中看到的一个愿望:希望构建一个可以自主运行的系统(这样他们就可以去度假)。
在这里,我将讨论如何自动化任何机器学习系统中的一个关键元素——数据管道。
使用 Python 和 SQL 自动化 ETL 到 SFTP 服务器
学习如何在 Windows 上自动化每日数据传输过程,从 PostgreSQL 数据库到远程服务器
·发表于 Towards Data Science ·阅读时长 16 分钟·2024 年 8 月 24 日
--

图片由 Shubham Dhage 提供,来源于 Unsplash
文件从一个位置传输到另一个位置的过程显然是自动化的完美候选者。这个过程反复执行可能会让人感到乏味,尤其是当你需要为多个数据组执行整个 ETL(提取、转换、加载)过程时。
想象一下,你所在的公司将数据存储在数据仓库中,然后他们决定将部分分析工作外包给一个外部数据分析供应商。该供应商提供一个定制的分析软件,能够为公司核心生产团队展示仪表板和报告。
这意味着,作为数据工程师的你,将需要按照外包合同的约定,定期(每天、每小时、每 30 分钟或其他频率)将数据传输给该供应商。
本文详细解释了包含 SFTP 上传的 ETL 过程。我们将结合使用安全文件传输协议(SFTP),这是一种通过加密文件来在两台远程服务器之间传输文件的安全方式,采用的加密方法是被称为安全外壳(SSH)协议。
使用 DSPy 和 Haystack 自动化提示工程
通过示例教会你的 LLM 如何回答问题
·发布于Towards Data Science ·阅读时长 9 分钟·2024 年 6 月 7 日
--

摄影:由Markus Winkler提供,图片来源于Unsplash
构建生成式 AI 应用程序时,最令人沮丧的部分之一是手动优化提示的过程。在今年早些时候 LinkedIn 发布的文章中,他们描述了在部署一个代理式 RAG 应用程序后学到的经验。其中一个主要挑战是确保一致的质量。他们花了 4 个月时间调整应用程序的各个部分,包括提示,以缓解诸如幻觉等问题。
DSPy 是一个开源库,它试图将提示参数化,从而将其转化为一个优化问题。原始论文将提示工程称为“脆弱且不可扩展”,并将其与“手动调优分类器的权重”进行了比较。
Haystack是一个开源库,用于构建 LLM 应用程序,包括 RAG 管道。它是平台无关的,并且与不同的 LLM 提供商、搜索数据库等有大量集成。它还拥有自己的评估指标。
在本文中,我们将简要介绍 DSPy 的内部工作原理,并展示如何使用它教会 LLM 在回答学术医学数据集中的问题时,更倾向于提供简洁的答案。
DSPy 快速概述
这篇来自 TDS 的文章对 DSPy 进行了深入的探讨。我们将总结并使用其中的一些示例。
为了构建一个可以优化的 LLM 应用程序,DSPy 提供了两个主要的抽象概念:签名和模块。签名是定义与 LLM 交互的系统输入和输出的一种方式。签名会在内部由 DSPy 转化为提示。
class Emotion(dspy.Signature):
# Describe the task
"""Classify emotions in a sentence."""
sentence = dspy.InputField()
# Adding description to the output field
sentiment = dspy.OutputField(desc="Possible choices: sadness, joy, love, anger, fear, surprise.")
使用 DSPy 的 Predict 模块时(稍后会详细介绍),这个签名会转化为以下的提示:
Classify emotions in a sentence.
---
Follow the following format.
Sentence: ${sentence}
Sentiment: Possible choices: sadness, joy, love, anger, fear, surprise.
---
Sentence:
然后,DSPy 还有 模块,这些模块定义了“预测器”,这些预测器拥有可以优化的参数,例如选择少量示例的方式。最简单的模块是 dspy.Predict,它不会修改签名。在本文后面,我们将使用 dspy.ChainOfThought 模块,它会要求 LLM 提供推理过程。
一旦我们尝试优化一个模块(或者说,DSPy 称之为“编译”一个模块),事情就开始变得有趣了。优化模块时,通常需要指定三件事:
-
需要优化的模块,
-
一个训练集,可能包含标签,
-
以及一些评估指标。
使用 dspy.Predict 或 dspy.ChainOfThought 模块时,DSPy 会在训练集上进行搜索,选择最好的示例,作为少量示例添加到提示中。在 RAG 的情况下,它还可以包括用于生成最终回答的上下文。它将这些示例称为“示范”。
你还需要指定一个优化器类型,用来在参数空间中进行搜索。在这篇文章中,我们使用的是 BootstrapFewShot 优化器。这个算法是如何在内部工作的呢?其实它非常简单,文中提供了一些简化的伪代码:
class SimplifiedBootstrapFewShot ( Teleprompter ) :
def __init__ ( self , metric = None ) :
self . metric = metric
def compile ( self , student , trainset , teacher = None ) :
teacher = teacher if teacher is not None else student
compiled_program = student . deepcopy ()
# Step 1\. Prepare mappings between student and teacher Predict modules .
# Note : other modules will rely on Predict internally .
assert student_and_teacher_have_compatible_predict_modules ( student , teacher )
name2predictor , predictor2name = map_predictors_recursively ( student , teacher )
# Step 2\. Bootstrap traces for each Predict module .
# We ’ll loop over the training set . We ’ll try each example once for simplicity .
for example in trainset :
if we_found_enough_bootstrapped_demos () : break
# turn on compiling mode which will allow us to keep track of the traces
with dspy . setting . context ( compiling = True ) :
# run the teacher program on the example , and get its final prediction
# note that compiling = True may affect the internal behavior here
prediction = teacher (** example . inputs () )
# get the trace of the all interal Predict calls from teacher program
predicted_traces = dspy . settings . trace
# if the prediction is valid , add the example to the traces
if self . metric ( example , prediction , predicted_traces ) :
for predictor , inputs , outputs in predicted_traces :
d = dspy . Example ( automated = True , ** inputs , ** outputs )
predictor_name = self . predictor2name [id( predictor ) ]
compiled_program [ predictor_name ]. demonstrations . append ( d )
return compiled_program
搜索算法会遍历 trainset 中的每一个训练输入,获取预测结果,然后通过查看 self.metric(example, prediction, predicted_traces) 来检查它是否“通过”评估指标。如果通过,那么这些示例会被添加到已编译程序的 demonstrations 中。
让我们创建一个自定义的 Haystack 流水线
完整的代码可以在这个食谱及关联的 colab中找到,因此我们这里只会讲解其中一些最重要的步骤。作为示例,我们使用了一个数据集,它来源于PubMedQA 数据集(都在 MIT 许可下)。该数据集包含基于医学研究论文摘要的问题及其相关答案。某些提供的答案可能相当长,因此我们将使用 DSPy 来“教导”LLM 优先生成更简洁的答案,同时保持最终答案的准确性。
在将前 1000 个示例添加到内存文档存储(可以被任何数量的检索器替换)后,我们现在可以构建我们的 RAG 管道:
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.components.generators import OpenAIGenerator
from haystack.components.builders import PromptBuilder
from haystack import Pipeline
retriever = InMemoryBM25Retriever(document_store, top_k=3)
generator = OpenAIGenerator(model="gpt-3.5-turbo")
template = """
Given the following information, answer the question.
Context:
{% for document in documents %}
{{ document.content }}
{% endfor %}
Question: {{question}}
Answer:
"""
prompt_builder = PromptBuilder(template=template)
rag_pipeline = Pipeline()
rag_pipeline.add_component("retriever", retriever)
rag_pipeline.add_component("prompt_builder", prompt_builder)
rag_pipeline.add_component("llm", generator)
rag_pipeline.connect("retriever", "prompt_builder.documents")
rag_pipeline.connect("prompt_builder", "llm")
让我们试试吧!
question = "What effects does ketamine have on rat neural stem cells?"
response = rag_pipeline.run({"retriever": {"query": question}, "prompt_builder": {"question": question}})
print(response["llm"]["replies"][0])
上述问题的答案:
凯他命在 200、500、800 和 1000µM 浓度下以剂量依赖的方式抑制大鼠神经干细胞的增殖。此外,凯他命还减少细胞内 Ca(2+)浓度,抑制蛋白激酶 C-α(PKCα)激活,并抑制大鼠神经干细胞中细胞外信号调节激酶 1/2(ERK1/2)的磷酸化。这些效应似乎不是通过半胱天冬酶-3 依赖性凋亡介导的。
我们可以看到,答案通常非常详细且较长。
使用 DSPy 获取更简洁的答案
我们从创建输入和输出字段的 DSPy 签名开始:
class GenerateAnswer(dspy.Signature):
"""Answer questions with short factoid answers."""
context = dspy.InputField(desc="may contain relevant facts")
question = dspy.InputField()
answer = dspy.OutputField(desc="short and precise answer")
如我们所见,我们在描述中已经指定我们期望一个简短的答案。
然后,我们创建一个 DSPy 模块,稍后将进行编译:
class RAG(dspy.Module):
def __init__(self):
super().__init__()
self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
# this makes it possible to use the Haystack retriever
def retrieve(self, question):
results = retriever.run(query=question)
passages = [res.content for res in results['documents']]
return Prediction(passages=passages)
def forward(self, question):
context = self.retrieve(question).passages
prediction = self.generate_answer(context=context, question=question)
return dspy.Prediction(context=context, answer=prediction.answer)
我们使用之前定义的 Haystack 检索器来搜索文档存储中的文档 results = retriever.run(query=question)。预测步骤使用 DSPy 模块 dspy.ChainOfThought,该模块教会语言模型在做出回答之前逐步思考。
在编译过程中,优化后的提示将如下所示:

所有粗体文本都被 DSPy 选择的示例和特定查询的问答上下文替换。由作者制作。
最后,我们需要定义我们希望优化的度量标准。评估器将包括两个部分:
-
[SASEvaluator](https://docs.haystack.deepset.ai/docs/sasevaluator): 语义答案相似度度量是一个介于 0 和 1 之间的分数,用于计算给定输出与实际输出之间的相似度。 -
我们将对超过 20 个单词的答案进行惩罚,惩罚会根据单词数的增加成比例增长,最大为 0.5。
from haystack.components.evaluators import SASEvaluator
sas_evaluator = SASEvaluator()
sas_evaluator.warm_up()
def mixed_metric(example, pred, trace=None):
semantic_similarity = sas_evaluator.run(ground_truth_answers=[example.answer], predicted_answers=[pred.answer])["score"]
n_words=len(pred.answer.split())
long_answer_penalty=0
if 20<n_words<40:
long_answer_penalty = 0.025 * (n_words - 20)
elif n_words>=40:
long_answer_penalty = 0.5
return semantic_similarity - long_answer_penalty
我们的评估数据集由 20 个训练示例和 50 个开发集示例组成。
如果我们使用下面的代码评估当前的简单 RAG 管道,我们得到一个平均得分为 0.49。
查看一些示例可以帮助我们直观理解得分的作用:
问题:新辅助化疗放疗到手术的时间增加是否与食管癌的病理完全缓解率较高相关?
预测答案:是的,新辅助化疗放疗到手术的时间增加与食管癌的病理完全缓解率较高相关。
得分:0.78
但是
问题:基于静息状态间歇期 MEG 记录的癫痫灶定位是否可以不考虑尖波的存在或缺失?
预测答案:是的。
得分:0.089
从这些示例中我们可以看到,如果答案太短,它会得到一个较低的得分,因为它与真实答案的相似度降低。
然后,我们使用 DSPy 编译 RAG 管道:
from dspy.teleprompt import BootstrapFewShot
optimizer = BootstrapFewShot(metric=mixed_metric)
compiled_rag = optimizer.compile(RAG(), trainset=trainset)
在我们完成这一过程并重新评估编译后的管道后,得分现在是 0.69!
现在是时候获取最终优化的提示并将其添加到我们的 Haystack 管道中。
获取最终的提示优化管道
我们可以通过查看compiled_rag对象中的demos字段来查看 DSPy 选择的少量示例:
compiled_rag.predictors()[0].demos
最终提示中提供了两种类型的示例:少量示例和引导演示,类似于上面所展示的提示。少量示例是问答对:
Example({'question': 'Does increased Syk phosphorylation lead to overexpression of TRAF6 in peripheral B cells of patients with systemic lupus erythematosus?', 'answer': 'Our results suggest that the activated Syk-mediated TRAF6 pathway leads to aberrant activation of B cells in SLE, and also highlight Syk as a potential target for B-cell-mediated processes in SLE.'})
而引导演示则包含 LLM 的完整追踪,包括提供的上下文和推理(在下面的rationale字段中):
Example({'augmented': True, 'context': ['Chronic rhinosinusitis (CRS) …', 'Allergic airway …', 'The mechanisms and ….'], 'question': 'Are group 2 innate lymphoid cells ( ILC2s ) increased in chronic rhinosinusitis with nasal polyps or eosinophilia?', 'rationale': 'produce the answer. We need to consider the findings from the study mentioned in the context, which showed that ILC2 frequencies were associated with the presence of nasal polyps, high tissue eosinophilia, and eosinophil-dominant CRS.', 'answer': 'Yes, ILC2s are increased in chronic rhinosinusitis with nasal polyps or eosinophilia.'})
现在我们所需要做的就是提取这些由 DSPy 找到的示例,并将它们插入到我们的 Haystack 管道中:
static_prompt = lm.inspect_history(n=1).rpartition("---\n")[0]
我们的新管道变为:
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
from haystack.components.generators import OpenAIGenerator
from haystack.components.builders import PromptBuilder, AnswerBuilder
from haystack import Pipeline
template = static_prompt+"""
---
Context:
{% for document in documents %}
«{{ document.content }}»
{% endfor %}
Question: {{question}}
Reasoning: Let's think step by step in order to
"""
new_prompt_builder = PromptBuilder(template=template)
new_retriever = InMemoryBM25Retriever(document_store, top_k=3)
new_generator = OpenAIGenerator(model="gpt-3.5-turbo")
answer_builder = AnswerBuilder(pattern="Answer: (.*)")
optimized_rag_pipeline = Pipeline()
optimized_rag_pipeline.add_component("retriever", new_retriever)
optimized_rag_pipeline.add_component("prompt_builder", new_prompt_builder)
optimized_rag_pipeline.add_component("llm", new_generator)
optimized_rag_pipeline.add_component("answer_builder", answer_builder)
optimized_rag_pipeline.connect("retriever", "prompt_builder.documents")
optimized_rag_pipeline.connect("prompt_builder", "llm")
optimized_rag_pipeline.connect("llm.replies", "answer_builder.replies")
让我们检查一下之前尝试过的相同问题。
问题:
速氯胺对大鼠神经干细胞有哪些作用?
之前:
速氯胺(Ketamine)在 200、500、800 和 1000µM 的浓度下以剂量依赖性方式抑制大鼠神经干细胞的增殖。此外,速氯胺还降低了细胞内 Ca(2+)浓度,抑制了蛋白激酶 C-α(PKCα)激活以及大鼠神经干细胞中细胞外信号调节激酶 1/2(ERK1/2)的磷酸化。这些效应似乎并非通过半胱天冬酶-3 依赖的凋亡途径介导。
之后:
速氯胺在较高浓度下抑制大鼠神经干细胞的增殖,而不影响凋亡。此外,它降低了细胞内钙浓度,抑制了 PKCα激活和这些细胞中 ERK1/2 的磷酸化。
工作完成!
简单的结论
在这篇文章中,我们使用 DSPy 优化了在 Haystack RAG 管道中使用的提示。我们通过使用基于 Haystack 评估框架的自定义度量,惩罚 LLM 的长答案,同时保持与正确答案的相似度较高。通过这种方法,我们成功地提高了大约 40%的性能,而无需进行任何手动提示工程。
使用大型语言模型(LLMs)自动化研究工作流程

图片由作者使用midjourney制作
使用 AI 的原子级应用来增强研究人员
·发表于Towards Data Science ·阅读时间:13 分钟·2024 年 9 月 10 日
--
最近,我有幸在伦敦开放数据科学大会上主持了一次研讨会,讨论了我认为大型语言模型(LLMs)在通过自动化某些任务来增强学术和非学术研究人员的潜力,可能扮演的一个有趣角色。
在这篇文章中,我想深入探讨在那个研讨会上讨论的核心概念,并讨论我认为人工智能通过与不同领域研究人员的整合,正在出现的一个令人着迷的新角色。
增强什么?
我在研讨会上提出的问题是:
我们如何利用 LLMs 来增强或扩展研究工作流程,而不降低研究人员的认知参与度?
触及增强这一话题总是充满挑战,并且可能导致一些令人不舒服的对话,谈论人工智能如何在不久的将来取代人类。因此,为了更清晰地表达,我想从更具体的角度开始定义它:
增强 = 通过工具提升能力
增强的概念深深植根于道格拉斯·恩格尔巴特的工作中,他在某种程度上开创了这一思想的一个版本,即技术应该增强人类的能力……
自动化您的容器化模型部署
使用 AWS、Terraform、Ansible 和 GitLab-CI
·发表于Towards Data Science ·阅读时间:9 分钟·2024 年 4 月 22 日
--

在这篇文章中,我们将讨论部分内容,包括“服务基础设施”和“流程管理”。图片改编自 Sculley 等人(2015)中的图 1 1
MLOps(机器学习运维)是一个涵盖复杂任务、流程和基础设施的术语,旨在搭建、自动化和监控。谷歌在 2015 年早期发布的一篇论文(在这一术语诞生之前)很好地描述了这一点,指出机器学习代码只占整个实际机器学习系统的一个小部分 1。
在本文中,我将分享 MLOps 的一个组成部分,展示如何将容器化工作负载或模型部署到服务基础设施,同时考虑到自动化、安全性和反馈回路。
您可以选择许多不同的平台和软件进行部署,而我所展示的只是其中之一。虽然文中使用的 AWS、Terraform、GitLab 和 Ansible 在现实世界中被广泛采用,但它们也有一些流行的替代品。然而,方法论和流程是中立的,也可以为您如何在其他平台上实施提供一些见解。
目录
架构概述
创建一个 Ansible Playbook
提供云资源 构建一个 GitLab-CI 流水线
AutoML 与 AutoGluon:仅需四行代码即可完成 ML 工作流
AutoGluon 如何主导 Kaggle 竞赛,以及你如何能够超越它。这个算法通过仅四行代码击败 99%的数据科学家。
·发表于Towards Data Science ·阅读时间 19 分钟·2024 年 7 月 3 日
--

由 DALL-E 生成的图像
在两个受欢迎的 Kaggle 竞赛中,AutoGluon 在仅对原始数据进行 4 小时训练后,击败了 99%的参与数据科学家(AutoGluon 团队。“AutoGluon:适用于文本、图像和表格数据的 AutoML。”2020 年)
这段话摘自AutoGluon 研究论文,完美地概括了我们今天要探讨的内容:一个能够以最少的编码交付出色表现的机器学习框架。你只需四行代码即可设置完整的 ML 流水线,而这个任务通常可能需要数小时才能完成。没错,仅需四行代码!自己看看:
from autogluon.tabular import TabularDataset, TabularPredictor
train_data = TabularDataset('train.csv')
predictor = TabularPredictor(label='Target').fit(train_data, presets='best_quality')
predictions = predictor.predict(train_data)
这四行代码通过自动识别每一列的数据类型来处理数据预处理,通过寻找有用的列组合来进行特征工程,并通过集成模型训练来识别给定数据集中的最佳表现模型……
自主代理生态系统、数据集成、开源 LLM 以及其他 11 月必读文章
·发表于 Towards Data Science ·作为 新闻通讯 发送 ·4 分钟阅读·2024 年 11 月 28 日
--
想要写第一篇 TDS 文章吗? 我们一直欢迎新作者投稿。
欢迎来到 2024 年的倒数第二个月度回顾——我们真能离年底这么近了吗?!我们相信,和我们一样,你也在忙于收尾工作,并为各自的项目做最后的冲刺。我们这边有不少这样的工作要完成,其中一个让我们激动不已的更新是,TDS 现在已经在 Bluesky 平台活跃。如果你是最近加入该平台的用户(或者一直在考虑加入),我们鼓励你 关注我们的账号。
我们的脑海中还有什么?就是我们作者最近几周发布的所有精彩文章,激励我们的读者学习新技能,探索数据科学和 AI 领域的新兴话题。我们的月度亮点涵盖了广泛的内容——正如往常一样——并提供了多种可访问的入口,带你深入了解及时的技术话题,从知识图谱到 RAG 评估。让我们一探究竟吧。
月度亮点
-
Agentic Mesh:生成型 AI 驱动的自主代理生态系统的未来
自主代理如何在安全、高效和可信的方式下相互发现、协作、互动和交易,究竟需要什么?Eric Broda展示了他对代理网的激动人心的愿景,这是一个作为 AI 代理无缝连接的框架。
-
使用 LLM 图形转换器构建知识图谱
对于任何有兴趣深入实践的人,Tomaz Bratanic的最新技术指南带我们深入了解 LangChain 实现图构建与 LLM 结合的具体细节。
-
为什么是 ETL-Zero?理解数据集成中的转变
“数据不应像传统方式那样需要在独立步骤中显式地提取、转换和加载,而应该在不同系统之间无缝流动。”Sarah Lea介绍了一种利用 Python 创建简化 ETL 过程的新方法。

图片由Jacky Watt拍摄,来源于Unsplash
-
托管开源 LLM 的经济学
随着 LLM 使用量在过去一年左右急剧增加,实践者们越来越多地开始思考,如何最有效地部署这些模型。Ida Silfverskiöld提供了一个详细的分析,讲解了考虑的各种因素,以及不同提供商在处理时间、冷启动延迟以及 CPU、内存和 GPU 成本方面的对比。
-
我如何通过两个小习惯提高作为数据科学家的生产力
有时候,对日常例行公事做些微小的改变,就能带来与彻底改造工作流程同样大的影响。一个例子是:Philippe Ostiguy, M. Sc.的新文章,介绍了两项看似与工作无关的习惯,关于休息和心理强度,这些习惯显著提升了 Philippe 的生产力。
-
构建初级数据科学作品集的 6 个月详细计划
无论你是刚刚成为数据科学家的新手,还是正在寻找新职位的经验丰富的专业人士,Sabrine Bendimerad为成功打造作品集提供的蓝图,将为你提供切实可行的思路和实际的时间表,帮助你完成工作。
-
如何减少 Python 运行时以应对高负荷任务
每个人都希望代码运行得更快,但在处理特别繁重的工作负载时,遇到瓶颈几乎是不可避免的。然而,正如Jiayan Yin在她的高效实用文章中所展示的那样,可能还有一些 GPU 优化选项你尚未利用,可以加速你的 Python 代码。
-
如何从文档创建 RAG 评估数据集
正如Dr. Leon Eversberg在他的最新教程中所解释的,“通过上传 PDF 文件并将其存储在向量数据库中,我们可以通过向量相似性搜索检索这些知识,然后将检索到的文本作为附加上下文插入 LLM 提示中。” 结果是什么呢?一种稳健的评估 RAG 工作流的方法,减少了出现幻觉的概率。
我们最新的一批新作者
每个月,我们都很高兴看到一批新作者加入 TDS,他们将自己独特的声音、知识和经验与我们的社区分享。如果你正在寻找新的作家进行探索和关注,只需浏览我们最新加入的作者作品,包括Jessica S、Tanner McRae、Ed Sandoval、Robert Corwin、Eric Colson、Joseph Ben、Marcus K. Elwin、Ro Isachenko、Michael Zakhary、Haim Barad、Elisa Yao、Mohamad Hamza、Eric Silberstein、Lorenzo Mezzini、David Teather、Diego Penilla、Daniel Klitzke、Iheb Rachdi、Aaron Beckley、Andrea Rosales、Bohumir Buso、Loizos Loizou、Omri Eliyahu Levy、Ohad Eytan、Julián Peller、Yan Georget、James Barney、Dima Sergeev、Pere Martra,以及Gizem Kaya等。
感谢您支持我们作者的工作!我们非常喜欢发布新作者的文章,如果您最近写了一篇有趣的项目教程、教学文章或关于我们核心主题的理论反思,欢迎随时与我们分享。
直到下一个 Variable,
TDS 团队
AutoRound:LLMs 的准确低比特量化
在量化感知训练和后训练量化之间
·发表于Towards Data Science ·7 分钟阅读·2024 年 6 月 29 日
--

由 DALL-E 生成
有许多量化方法可以减少大型语言模型(LLM)的大小。最近,已经提出了更好的低比特量化方法。例如,AQLM 在保留大部分模型准确性的同时实现了 2 比特量化。
AQLM的主要缺点是,量化大型模型的成本非常高。HQQ是另一个低比特量化的良好替代方案,但需要进一步的微调才能保持准确性。
英特尔在更好的量化算法研究方面也非常活跃。他们提出了 AutoRound,一种采用符号梯度下降(SignSD)的新量化方法。AutoRound 在低比特量化方面特别准确,并且比大多数其他方法量化得更快。
在本文中,我回顾了 AutoRound。我们将了解它的工作原理以及如何以最小的准确性下降量化 LLM,例如 Llama 3。我发现 AutoRound 是 GPTQ 和 HQQ 的一个非常好的替代方案。它能产生更准确的模型。
我实现了以下笔记本,展示了如何使用 AutoRound 对大型语言模型(LLMs)进行量化,并评估/基准测试结果模型:
避免在 2024 年构建数据平台
为什么关于“构建数据平台”的文章大多数是误导性的
·发布于Towards Data Science ·13 分钟阅读·2024 年 8 月 13 日
--

图片来源:Patrick Tomasso于Unsplash
你可能认为,阅读关于不该构建的内容没有什么价值。但由于数据和分析平台工具的泛滥、现代数据架构(MDS)正在失去吸引力,以及关于“构建数据平台”的诸多文章,我发布了这个警告。
这是对在大公司工作的 IT 专业人员的警告。因为每位顾问往往会回答所有问题:“这取决于”你的具体情况,是否相关。所以根据我即将提供的背景信息,你自己做出决定。
数据平台
也许是因为 Zhamak Dehghani 称其为“作为平台的数据基础设施”,或者因为云服务商将我们卖给“数据平台即服务(DPaaS——与数据保护即服务不同)”,或者仅仅是因为构建平台已成潮流,我们才会陷入这一理念。
那么,数据平台到底是什么?
让我们以云计算平台模型为例,更好地理解其价值主张。它为开发人员提供了一个平台,可以在不处理底层基础设施的情况下构建、部署和管理应用程序。PaaS…
避免滥用和误用 T 检验和方差分析:分类响应的回归分析
使用基于 brms 的贝叶斯回归方法
·Published in Towards Data Science ·11 min read·Apr 7, 2024
--
在神经科学及其他生物医学科学中,常使用行为测试来评估对实验条件或治疗的反应。我们可以评估许多方面,从基本的运动和探索行为到记忆和学习。这些变量中的许多是连续的(数值型)反应,即它们可以在给定范围内取(有限或无限的)值。开放场中的时间、动物体重或大脑区域内的细胞数量是一些例子。
然而,在我们的实验中还有其他类型的变量需要记录。一个非常常见的变量是有序的类别变量,也叫做顺序变量。这些是具有自然顺序的类别变量,类似于我们在常见的调查中对陈述做出同意或不同意的回答,0 表示强烈不同意,5 表示强烈同意。为了方便在印刷或数字数据表中记录这些变量,我们通常将它们编码为数字。这就是在小鼠模型的脑缺血研究中使用的 5 分贝德森评分(1),其编码如下:
-
0 = 无可观察的缺陷
-
1 = 前肢屈曲
-
2 = 前肢屈曲和对侧推力的抗拒减少
-
3 = 转圈
-
4 = 转圈并围绕头尾轴旋转
-
5 = 无自发运动
请注意,数字仅仅是简单的约定。我们也可以使用 a、b、c、d、e;优秀、良好、不太好、差、非常差、几乎死亡;等等。我认为强调这一点的显而易见性并不自负。
然而,令人惊讶的是,图 1 代表了科学家在这个领域多年来一直在实践的一种特定且普遍的恶习:他们使用 t 检验和方差分析(ANOVA)来分析这种有序分类变量。

图 1:左:Onufriev 等人(2021)的神经评分(CC-BY)。右:Liu 等人(2022)的神经评分(CC-BY)。
我仍然无法找到一个合乎逻辑的解释,说明为什么那么多作者、审稿人和编辑对这种情况感到舒适。它不仅不合逻辑,更重要的是,它导致了偏倚的效应量、低检出率和 I 型错误率等问题 (2)。
当我扮演审稿人 2 的角色,并强调这一点,问作者们为什么要用专为处理连续数值变量设计的统计检验来评估分类响应时,我得到的是一长串遵循这种非理性做法的已发布文章。于是我最终找到了他们为什么这么做的答案:
这是 Gerd Gigerenzer(2004)所说的“盲目统计的仪式” (3)。
事实上,我们大多数科学家对于我们如何处理数据几乎没有概念,只是在简单地重复世界各地实验室代代相传的常见错误做法。
在本文中,我们接下来将探讨使用R、brms(4)和tidyverse(5)中的元素来分析有序分类变量的更可行的替代方法。
面对盲目统计的仪式
为了避免盲目统计的仪式,我们将通过肉眼判断数据点并将它们组织成表格,重新创建图 1 中 Liu 等人(2022)数据集:
# We Define the observations
observations <- list(
Sham = c(rep(0, 7)),
tMCAO = c(rep(2, 3), rep(3, 2), rep(4, 2)),
tMCAO_C = c(rep(1, 3), rep(2, 3), rep(3, 1))
)
# We create an empty data frame to populate
df <- data.frame(Group = character(), Response = integer())
# We populate the data frame
for (group in names(observations)) {
for (response in unique(observations[[group]])) {
df <- rbind(df, data.frame(Group = rep(group, sum(observations[[group]] == response)),
Response = rep(response, sum(observations[[group]] == response))))
}
}
head(df)

如果此时你查看 R 代码(在 R-studio 中),你会注意到变量Response被识别为从 0 到 4 的数字。当然,我们完全清楚这个响应不是数值型的,而是一个有序(序数)分类变量。因此,我们显式地进行转换:
df$Response <- factor(df$Response, levels = c("0", "1", "2", "3", "4"), ordered = TRUE)
str(df)

现在我们可以验证它被识别为有序分类变量。有了这一点,我们可以轻松地进入可视化和建模阶段。
探索性数据可视化
首先,让我们加载必要的库并为我们的图表创建一个视觉主题。
library(ggplot2)
library(brms)
library(ggdist)
library(easystats)
library(dplyr)
library(tibble)
Plot_theme <- theme_classic() +
theme(
plot.title = element_text(size=18, hjust = 0.5, face="bold"),
plot.subtitle = element_text(size = 10, color = "black"),
plot.caption = element_text(size = 12, color = "black"),
axis.line = element_line(colour = "black", size = 1.5, linetype = "solid"),
axis.ticks.length=unit(7,"pt"),
axis.title.x = element_text(colour = "black", size = 16),
axis.text.x = element_text(colour = "black", size = 16, angle = 0, hjust = 0.5),
axis.ticks.x = element_line(colour = "black", size = 1),
axis.title.y = element_text(colour = "black", size = 16),
axis.text.y = element_text(colour = "black", size = 16),
axis.ticks.y = element_line(colour = "black", size = 1),
legend.position="right",
legend.direction="vertical",
legend.title = element_text(colour="black", face="bold", size=12),
legend.text = element_text(colour="black", size=10),
plot.margin = margin(t = 10, # Top margin
r = 2, # Right margin
b = 10, # Bottom margin
l = 10) # Left margin
)
可视化分类数据的一个简单方法是使用条形图。
ggplot(df, aes(x = factor(Response), fill = Group)) +
geom_bar(position = "dodge") +
labs(x = "Response", y = "Count", title = "Response by Group") +
theme_minimal() +
scale_fill_brewer(palette = "Set1") +
Plot_theme +
theme(legend.position = "top", legend.direction = "horizontal")

图 2:按组着色的响应
图 2 展示了每组的频率。这比箱型图更符合分类变量的逻辑,而箱型图更适用于连续数值变量。现在让我们进行回归分析,揭示这个数据集的奥秘(如果有的话)。
使用 brms 拟合有序回归
我们将使用brms来拟合一个累积模型。这个模型假设神经学得分 Y 来自一个(假设是潜在的、但不可观察或测量的)连续变量 Y˜的分类(2)。像大多数brms教程中一样,我必须为跳过“先验”问题表示歉意。让我们假设一种“我不知道”的心态,并让默认的(平坦的)brms先验来做脏活。
我们通过遵循常规公式语法并添加cumulative("probit")作为家族来拟合cumulative模型(假设潜在变量及其对应的误差项呈正态分布)。我们只有一个预测变量,即每个动物所属的实验组。
Ordinal_Fit <- brm(Response ~ Group,
data = df,
family = cumulative("probit"),
# seed for reproducibility purposes
seed = 8807,
control = list(adapt_delta = 0.99),
# this is to save the model in my laptop
file = "Models/2024-04-03_UseAndAbuseANOVA/Ordinal_Fit.rds",
file_refit = "never")
# Add loo for model comparison
Ordinal_Fit <-
add_criterion(Ordinal_Fit, c("loo", "waic", "bayes_R2"))
在查看结果之前,让我们先进行一些模型诊断,以比较观测值和模型预测值。
模型诊断
pp_check(Ordinal_Fit, ndraws = 100) +
labs(title = "Ordinal regression") +
theme_classic()

图 3:有序回归的模型诊断
从图 3 中,我推测这些偏差是由于各组得分的分布(方差)不均匀所导致的。稍后,我们将看看预测方差是否能得到更好的估计。现在,我们可以继续进行。
检查有序回归的结果
让我们使用bayestestR包中的describe_posterior函数来查看后验分布(6),作为典型的summary函数的替代方法。
describe_posterior(Ordinal_Fit,
centrality = "mean",
dispersion = TRUE,
ci_method = "HDI",
test = "rope",
)

在这个模型中,阈值(得分阈值)被标记为“截距”,适用于我们的基线“假处理”条件。“GrouptMCAO”和“GrouptMCAO_C”的系数表示在潜在 Y˜尺度上与假处理动物的差异。因此,我们看到“GrouptMCAO”在潜在 Y˜尺度上得分高出 8.6 个标准差。
我想指出一个至关重要(且并非微不足道)的问题,作为未来文章的提醒(敬请关注)。在我看来,比较一个没有分布的组(在所有情况下为 0),比如假处理组,与一个有分布的组(实验组)是没有意义的。如果仔细思考,这个过程的目的就是空洞的。但是,让我们遵循现代科学文化的思维方式,不再过多纠结这个问题。
当然,如果我们通过brms的conditional effects功能以可视化的方式查看,区分各组之间的差异就变得更加可行。
Ordinal_CondEffects <-
conditional_effects(Ordinal_Fit, "Group", categorical = TRUE)
Ordinal_CondEffects <- plot(Ordinal_CondEffects,
plot = FALSE)[[1]]
Ordinal_CondEffects +
Plot_theme +
theme(legend.position = "bottom", legend.direction = "horizontal")

图 4:有序模型的条件效应
很奇怪(或者说,可能从 MCMC 仿真框架来看并不那么奇怪),模型估计虚拟组可能具有 1 的值,尽管我们在数据中没有看到这一点。一种可能的原因是模型对所有组估计了一个共同的方差。我们稍后会看看当我们将方差作为响应来建模时,这一方面是否发生变化。
否则,对于 tMCAO 和 tMCO_C 组获得的估计值与数据的匹配度要高得多。这使我们能够做出更精确的陈述,而不是运行 ANOVA(对于有序变量来说是不正确的)并说“组与组之间存在显著差异”,这种说法毫无意义。例如,模型告诉我们 tMCAO 组的 2-4 分数有相似的概率(约 25%)。对于 tMCAO_C 组情况不同,在神经学评分为 2 的概率较高(尽管存在相当大的不确定性),而其他分数的概率较低。如果我面对这个数据集和这个模型,我会声称 tMCAO_C 组表现出较少神经损伤(基于 1 和 2 分数)的概率高于 tMCAO 组。
我们能否得到量化不同组之间评分差异及其不确定性的精确数值?当然可以!我们可以使用 emmeans 包(7)。但这将是另一个帖子的话题(敬请期待)。
将方差作为响应变量
对于这种累积模型,没有 sigma 参数。相反,为了考虑不等方差,我们需要使用一个叫做 disc 的参数。关于这一方面的更多内容,请参见 brms 的创建者 Paul-Christian Bürkner 的有序回归教程(2)。
Ordinal_Mdl2 <- bf (Response ~ Group) +
lf(disc ~ 0 + Group, cmc = FALSE)
Ordinal_Fit2 <- brm(
formula = Ordinal_Mdl2,
data = df,
family = cumulative("probit"),
# seed for reproducibility purposes
seed = 8807,
control = list(adapt_delta = 0.99),
# this is to save the model in my laptop
file = "Models/2024-04-03_UseAndAbuseANOVA/Ordinal_Fit2.rds",
file_refit = "never")
# Add loo for model comparison
Ordinal_Fit2 <-
add_criterion(Ordinal_Fit2, c("loo", "waic", "bayes_R2"))
模型诊断
我们按之前的方法进行模型诊断:
pp_check(Ordinal_Fit2, ndraws = 100) +
labs(title = "Student-t") +
theme_classic()

图 5:我们的模型诊断,预测方差
图 5 显示了我的预期没有得到满足。将方差作为响应变量纳入模型并没有改善数据拟合。趋势保持不变,但预测结果仍然存在显著差异。然而,这是我们可以考虑作为生成模型的另一种候选模型。
检查我们新模型的结果
我们可视化了该模型的后验分布:
describe_posterior(Ordinal_Fit2,
centrality = "mean",
dispersion = TRUE,
ci_method = "HDI",
test = "rope",
)

我们可以看到与第一个模型相比,系数有了有意义的差异。 “GrouptMCAO”的系数从 8.6 增加到 15.9,而“GrouptMCAO_C”的系数从 7 增加到 10.8。毫无疑问,这个模型给了我们一个不同的结果。否则,方差项以“disc_GrouptMCAO”和“disc_GrouptMCAO_C”的名称呈现。我们可以看到,这两个方差与我们的“假手术”基线有很大的不同。
让我们绘制结果:
Ordinal_CondEffects2 <-
conditional_effects(Ordinal_Fit2, categorical = TRUE)
Ordinal_CondEffects2 <- plot(Ordinal_CondEffects2,
plot = FALSE)[[1]]
Ordinal_CondEffects2 +
Plot_theme +
theme(legend.position = "bottom", legend.direction = "horizontal")

图 6:我们的模型中,包括方差作为响应的条件效应
与我的预期相反,该模型仍然预测假手术组的动物有小概率得分为 1。我们在这里确认的是,这一预测并不基于(错误的)假设,即所有组的方差相同。然而,在这个框架(序数回归)中,这仍然是一个合逻辑的预测(而非不理性),基于阈值的假设。如果我们参考 Richard McElreath 的《统计思维》(8),我们会发现猴子拉杠的情况也相同。拟合一个更受约束的模型将需要使用信息先验。将这个话题留到未来的帖子中。我知道我在这里做了三个承诺,但我会履行它们。
在这个模型中,tMCAO 组不同结果的概率略有变化。然而,鉴于高不确定性,我不会根据这个模型改变我对该组表现的结论。另一方面,tMCAO_C 组的预测变化并不明显,不容易被眼睛察觉。让我们通过比较这两个模型来结束这篇博客文章。
模型比较
我们使用loo包(9,10)进行模型比较,以实现留一法交叉验证。对于使用 WAIC 标准的替代方法(11),我建议您阅读这篇文章,它同样由TDS 编辑发布。
loo(Ordinal_Fit, Ordinal_Fit2)
在这个方案下,模型的表现非常相似。实际上,第一个模型在样本外预测上略优。考虑到方差并没有在这种特定情况下提供太大帮助,(或许)依赖信息先验可以开启科学推断的下一步。
如果您愿意,欢迎提供您的评论或反馈,让我知道这段旅程对您是否有帮助。如果您想获得更多关于数据科学和其他主题的优质内容,您可以考虑成为medium 会员。
未来,您可能会在我的GitHub 网站上找到这篇文章的更新版本。
参考文献
1.M. Bieber, J. Gronewold, A.-C. Scharf, M. K. Schuhmann, F. Langhauser, S. Hopp, S. Mencl, E. Geuss, J. Leinweber, J. Guthmann, T. R. Doeppner, C. Kleinschnitz, G. Stoll, P. Kraft, D. M. Hermann, 小鼠中脑动脉闭塞后神经学评分的有效性和可靠性。Stroke。50,2875–2882(2019 年)。
2.P.-C. Bürkner, M. Vuorre, 心理学中的序数回归模型:教程。Advances in Methods and Practices in Psychological Science。2,77–101(2019 年)。
3. G. Gigerenzer, 无知统计学。社会经济学期刊。33,587–606(2004 年)。
4. P.-C. Bürkner, Brms:一个使用 Stan 的贝叶斯多层次模型的 R 包。80(2017 年),doi:10.18637/jss.v080.i01。
5. H. Wickham, M. Averick, J. Bryan, W. Chang, L. D. McGowan, R. François, G. Grolemund, A. Hayes, L. Henry, J. Hester, M. Kuhn, T. L. Pedersen, E. Miller, S. M. Bache, K. Müller, J. Ooms, D. Robinson, D. P. Seidel, V. Spinu, K. Takahashi, D. Vaughan, C. Wilke, K. Woo, H. Yutani, 欢迎使用 tidyverse。4,1686(2019 年)。
6. D. Makowski, M. S. Ben-Shachar, D. Lüdecke, bayestestR:描述贝叶斯框架下效应及其不确定性、存在性与显著性。4,1541(2019 年)。
7. R. V. Lenth, Emmeans:估计的边际均值,亦即最小二乘均值(2023 年)(可在CRAN.R-project.org/package=emmeans查看)。
8. R. McElreath,统计思维(Chapman;Hall/CRC,2020 年;dx.doi.org/10.1201/9780429029608)。
9. A. Vehtari, J. Gabry, M. Magnusson, Y. Yao, P.-C. Bürkner, T. Paananen, A. Gelman, Loo:用于贝叶斯模型的高效留一交叉验证和 WAIC(2022 年)(可在mc-stan.org/loo/查看)。
10. A. Vehtari, A. Gelman, J. Gabry, 实用贝叶斯模型评估方法:使用留一交叉验证和 WAIC。统计学与计算。27,1413–1432(2016 年)。
11. A. Gelman, J. Hwang, A. Vehtari, 理解贝叶斯模型的预测信息准则。统计学与计算。24,997–1016(2013 年)。
精彩的 Plotly 与代码系列(第一部分):条形图的替代方案
条形图并不总是最佳的解决方案。
·发表于Towards Data Science ·阅读时间 13 分钟·2024 年 10 月 21 日
--

图片由 Dall-E 生成
系列简介
数据可视化在讲述和理解故事中发挥着至关重要的作用。我一直对数据新闻中使用的流畅、带注解的图表感到着迷——这些可视化图表能够瞬间传达复杂的想法,让任何人都能理解。
我也深受像 Cole Nussbaumer Knaflic 的书《数据讲故事》1的启发,该书提供了创建清晰、富有影响力的可视化的基本最佳实践,强调了一种极简主义的方法,去除不必要的元素。她的方法是专注于真正重要的内容,确保数据的故事不被干扰,充分展现其核心。
另外,还有 AddTwoDigital,这是一家专注于数据故事讲述的数字机构。他们已经开源了一系列关于数据可视化最佳实践的博客文章 [2],展示了从简单的条形图到更复杂、令人费解的信息图的各种内容。他们的内容是灵感的宝藏,从他们的示例中总能学到一些新东西,而这些东西通常不会出现在典型的数据可视化书籍中。
用代码实现的精彩 Plotly 系列(第二部分):条形图上色
不要创建彩虹色的条形图。但也不要让你的条形图太单调无趣。
·发布于Towards Data Science ·阅读时间:9 分钟·2024 年 10 月 26 日
--

图像由 Dall-e 生成
欢迎来到我“用代码实现 Plotly”系列的第二篇文章!如果你错过了第一篇,可以通过下面的链接查看,或者通过我的“一个帖子统领所有”来跟随整个系列或我之前写过的其他主题。
## 用代码实现的精彩 Plotly 系列(第一部分):条形图的替代方案
条形图并不总是最佳的解决方案。
towardsdatascience.com [## 我所有写过的文章都在这里
这是一个动态文档,请继续关注以获取更多更新!
简短总结:我为什么要写这个系列
我常用的可视化工具是 Plotly。它非常直观,从图层的添加到交互性的实现都非常便捷。然而,尽管 Plotly 在功能上表现出色,它并没有提供一种“数据新闻”模板,无法直接生成精美的图表。
Awesome Plotly 与代码系列(第三部分):在长尾中突出显示条形图
谁说长尾不重要?让我们为它们提供一个突出显示的合适方式
·发布于 Towards Data Science ·阅读时间:7 分钟·2024 年 11 月 1 日
--

图像由 Dall-e 生成
欢迎来到我的“Plotly 与代码”系列的第三篇文章!如果你错过了第一篇,你可以通过下面的链接查看,或者浏览我的“单篇文章大全”来跟随整个系列或其他我之前写过的主题。
[## Awesome Plotly 与代码系列(第一部分):条形图的替代方案
条形图并不总是最好的解决方案。
这是一个实时更新的文档,请持续关注更多新增内容!
简短总结我为何要写这一系列文章
我的首选工具是 Plotly 来创建可视化。它非常直观,从图层的叠加到添加互动性。然而,虽然 Plotly 在功能上表现出色,但它并没有提供一个“数据新闻”模板,来提供精美的图表…
《Plotly 与代码系列(第四部分):分组条形图与多彩条形图》
彩色条形图真的能帮助讲清楚故事吗?
·发表于 Towards Data Science ·阅读时间 12 分钟·2024 年 11 月 14 日
--

图像由 Dall-e 创建
欢迎来到我“Plotly 与代码”系列的第四篇文章!如果你错过了第一篇文章,可以点击下面的链接查看,或者浏览我的“所有文章合集”来跟进整个系列或者我之前写过的其他主题。
## 《Plotly 与代码系列(第一部分):条形图的替代方案
条形图并不总是最好的解决方案。
towardsdatascience.com [## 所有我写的文章集中在一个地方
这是一个动态文档,请持续关注更多更新内容!
为什么我写这个系列的简短总结
我创建可视化图表的首选工具是 Plotly。它非常直观,从图层追踪到添加交互性都十分便捷。然而,尽管 Plotly 在功能方面表现出色,但它并没有自带“数据新闻”模板,无法直接提供精美的图表。
用代码做 Plotly 系列(第五部分):条形图的排序很重要
而且它并不总是简单地按从高到低排序
·发布于Towards Data Science ·阅读时间 10 分钟·2024 年 12 月 10 日
--

图片由 Dall-e 生成
欢迎来到我“用代码做 Plotly”系列的第五篇文章!如果你错过了第一篇,可以通过下面的链接查看,或者浏览我的“统领所有文章”系列,跟随我一起了解整个系列或者我之前写过的其他话题。
## 用代码做 Plotly 系列(第一部分):条形图的替代方案
条形图并不总是最好的解决方案。
towardsdatascience.com [## 我的所有文章集中展示
这是一个实时文档,敬请期待更多更新!
我写这个系列的简短总结
我常用的可视化工具是 Plotly。它非常直观,从叠加轨迹到添加互动性都十分容易。然而,尽管 Plotly 在功能上表现出色,但它并没有一个“数据新闻”模板,可以直接提供打磨过的图表。
Plotly 与代码系列(第六部分):处理长轴标签
要旋转还是不旋转?要截断还是不截断?
·发布于Towards Data Science ·阅读时间 10 分钟·2024 年 12 月 19 日
--

图像由 Dall-e 生成
欢迎来到我“Plotly 与代码”系列的第六篇文章!如果你错过了第一篇,可以点击下面的链接查看,或者浏览我的“一个文章统领所有”来跟随整个系列或我之前写过的其他话题。
## Plotly 与代码系列(第一部分):条形图的替代方案
条形图并不总是最佳解决方案。
towardsdatascience.com [## 我所有的文章集中在一个地方
这是一个实时文档,敬请期待更多更新!
我写这系列文章的简短总结
我创建可视化图表时常用的工具是 Plotly。它非常直观,从层叠轨迹到添加交互性都很容易操作。然而,虽然 Plotly 在功能上表现出色,但它没有提供一种“数据新闻”模板,无法直接生成打磨精美的图表。
这就是…
AWS DeepRacer:减少 Sim2Real 差距的实用指南 — 第一部分 || 准备赛道
最小化视觉干扰以最大化成功圈次
·发表于Towards Data Science ·阅读时间 8 分钟·2024 年 8 月 21 日
--
是否曾经想过,为什么你的 DeepRacer 在模拟环境中表现得完美无缺,但在现实世界中却连一个转弯都无法完成?继续阅读,了解其中的原因及如何解决常见问题。
AWS DeepRacer 在真实赛道上的视频演示。视频由作者提供。
-
第一部分 (2024 年 8 月 20 日) : 赛道和周围环境的设置。
-
第二部分 (2024 年 8 月 26 日): 动作空间和奖励函数设计以及训练范式。
在本指南中,我将分享一些实用的技巧和方法,帮助你让AWS DeepRacer自主地在赛道上运行。我将包括有关在模拟环境中训练强化学习代理的信息,以及更为重要的实际建议,帮助你成功地让汽车在真实赛道上运行——这就是所谓的模拟到现实(sim2real)挑战。
在第一部分中,我将描述在真实赛道上驾驶汽车时需要注意的物理因素。我会讲解汽车的摄像头传感器(及其局限性)以及如何准备你的物理空间和赛道。在后续部分中,我们将讨论训练过程和奖励函数的最佳实践。我决定首先关注物理因素,而非训练,因为我认为在模拟训练之前理解物理限制更为重要。
正如你将通过这系列文章看到的那样,关键目标是减少由光照变化和背景运动引起的相机干扰。
汽车与传感器

AWS DeepRacer。图片来自作者。
这辆车是一款 1/18 比例的赛车,配备 RGB(红绿蓝)相机传感器。来自AWS:
相机配备有 120 度广角镜头,并捕捉 RGB 图像,这些图像随后被转换为 160 x 120 像素的灰度图像,以每秒 15 帧(fps)的速度显示。这些传感器属性在模拟器中得以保留,以最大限度地提高训练模型从模拟到现实世界的转移效果。
这里需要注意的关键点是相机使用160 x 120 像素的灰度图像。这大致意味着相机擅长将浅色或白色像素与深色或黑色像素分开。介于这两者之间的像素,即灰色,可以用来表示额外的信息。
DeepRacer 的 RGB 视图(左)和灰度视图(右)。尽管相机捕捉到的是 RGB 图像,但它们在推理时会被转换为灰度图像。请注意轨道上的褶皱和光反射,这增加了模拟与现实之间的差距。视频由作者提供。
从本文中需要记住的最重要的事情如下:
汽车仅使用黑白图像来理解其周围环境。它并不识别物体,而是学习避免或停留在不同的灰色像素值上(从黑到白)。
因此,我们所采取的所有步骤,从轨道准备到训练模型,都将考虑到上述事实。
在 DeepRacer 的案例中,可以为汽车识别出三个基于颜色的基本目标:
-
保持在白色轨道边界内: 越接近白色(255)的像素值越高或越浅,汽车就越容易将其解释为轨道边界,并会尽力保持在这个像素边界内。
-
在黑色轨道上行驶: 更暗或较低的黑色(0)像素值将被解释为驾驶表面本身,汽车应该尽量在其上行驶。
-
绿色/黄色: 尽管汽车会将绿色和黄色视为灰度阴影——它仍然可以学会(a)靠近虚线黄色中心线行驶;并且(b)避免进入实心绿色的禁区。

实际相机视图(左)和模拟视图(右)在 RGB 空间中的显示。这些图像在推理前被转换为灰度图像。来源²。
DeepRacer 的模拟与现实之间的表现差距
AWS DeepRacer 使用强化学习(RL)¹在模拟环境中训练一个缩小版赛车,使其能够自动在赛道上行驶。这使得赛车可以首先在虚拟环境中学习一个最优且安全的策略或行为。然后,我们可以将模型部署到真实的汽车上,并在实际赛道上进行比赛。
不幸的是,现实世界中的性能往往无法与模拟器中观察到的完全相同。这是因为模拟器无法准确捕捉到现实世界的所有细节。值得称赞的是,AWS 提供了一个关于优化训练以缩小模拟与现实差距的指南。虽然这里提供的建议很有用,但对我而言并不完全奏效。汽车自带了 AWS 提供的内置模型,理论上应该适用于多个赛道,并且开箱即用。不幸的是,至少在我的实验中,这个模型甚至无法完成一圈(尽管我做了多次硬件调整)。AWS 的指南中缺少了一些信息,最终我通过在线博客和讨论论坛才将其拼凑出来。
通过我的实验,识别出以下几个关键因素,导致了模拟与现实差距的增大:
-
相机光线/噪声敏感度: 最大的挑战是相机对光线和/或背景噪声的敏感度。任何光斑都会使相机传感器过曝,导致汽车出现意外行为。尽量减少周围光线和任何背景干扰。(稍后会详细讨论)
-
摩擦力: 汽车轮子与赛道之间的摩擦力为校准油门带来了挑战。我们通过 AWS 的商店购买了推荐的赛道(继续阅读,了解为什么我不推荐这款赛道)。这款赛道是哑光乙烯基材料,在我的设置中,我将它放置在办公室的餐厅区域地毯上。看起来,乙烯基材料与地毯的组合产生了较高的静摩擦力,导致汽车在缓慢转弯或尝试从静止起步时不断卡住。
-
虚拟车与真实车的传感能力差异: 真实车与模拟车在输入参数/状态空间上存在差距。AWS 提供了输入参数列表,但像赛道长度、进度、步骤等参数仅在模拟中可用,无法用于真实车。据我所知,通过一些网络搜索——似乎车只能访问来自摄像头传感器的信息。有很小的可能性,像车的 x, y 位置和航向这样的参数是已知的。我的研究表明这些信息是无法获得的,因为车大概率没有 IMU,即使有——基于 IMU 的定位是一个非常难以解决的问题。这些信息对于设计正确的奖励函数非常有帮助(未来的部分将详细介绍)。
赛道——自建与购买
如前所述,我购买了 AWS 推荐的A To Z Speedway 赛道。这条赛道是意大利蒙扎的 Autodroma Nazionale Monza F1 赛道的简化版。
赛道评审——不要购买

这条赛道质量极差,我不建议购买。表面非常起皱、薄弱,并且反光非常强烈。图像来源:作者。
就个人而言,我不推荐购买这条赛道。它的价格为 760 美元加税(这辆车售价几乎是它的一半),总的来说,令人有些失望。
-
反光表面: 哑光乙烯基印刷的质量很差且高度反光。任何环境光都会使摄像头图像过曝,导致碰撞和其他意外行为。
-
褶皱: 赛道有很多褶皱,这会导致汽车卡住。你可以通过将赛道放在阳光下几天来在一定程度上解决这个问题。我在这方面的成功有限。你也可以使用蒸汽熨斗(参考这篇指南)。我没有尝试过这个方法,所以请自行承担风险。
-
尺寸: 这其实不是赛道本身的问题,但赛道的尺寸是 18' x 27',对于我的房子来说太大了。它甚至无法进入我的双车库。幸运的是,我的办公室同事很友好,允许我使用午餐室。它也非常笨重,折叠和携带非常困难。
总体而言,我对这款赛道的质量并不感到满意,只会在时间紧迫或不想自己动手制作的情况下推荐购买。
自己动手建造赛道(如果可能的话)
如果可以,试着自己动手构建一个。这里有一份来自 AWS 的官方指南,以及来自 Medium 的另一份指南,用户@autonomousracecarclu发布的这份指南看起来更有前景。
使用互锁泡沫垫构建赛道或许是最好的方法。这可以解决乙烯基赛道的反射性和摩擦问题。此外,这些垫子重量轻,易于堆叠;因此,移动和存储它们更为便捷。
你也可以在 FedEx 打印赛道图纸,并将其粘贴在橡胶或混凝土表面上。无论是自己建造还是打印赛道,这些方法比购买 AWS 推荐的赛道要好(无论是从经济角度还是性能角度)。
准备你的空间——照明与干扰
记住,汽车只使用黑白图像来理解和导航它周围的环境。它无法识别物体——它只是学会避免或紧贴不同的灰度(从黑到白)。保持在黑色赛道上,避免白色边界和绿色禁区。
以下部分概述了为确保你的车能够顺利绕过赛道并减少碰撞,推荐的物理设置。

赛道准备步骤 - (a) 我通过拉下所有窗帘并关闭天花板灯来减少环境光。有几盏灯无法关闭,因为它们总是保持开启以应对紧急情况。 (b) 障碍物有助于减少背景干扰和反射。彩色障碍物比黑色的效果更好。绿色障碍物最为有效。我没有足够的绿色障碍物,所以我将它们放在更难的弯道周围。图像由作者提供。
最小化环境光
尽量减少环境光照。这包括来自窗户和天花板灯的任何自然光。当然,你需要一些光线让相机能够看到,但光线越少越好。
如果无法减少照明,尽量让它尽可能均匀。光点比光本身带来的问题更多。如果你的轨道像我之前那样有褶皱,光点会更频繁,从而导致更多的失败。
彩色拼接障碍物
障碍物的颜色和摆放位置都至关重要。可能比我最初预想的还要重要。人们可能认为它们是用来保护汽车免于撞击的。虽然这是其中的一部分,但障碍物更有用的是减少背景的干扰。
我使用了这些Costco 的 2x2 英尺拼接垫。AWS 推荐使用至少2.5x2.5 英尺且任何颜色但不包括白色. 我意识到甚至黑色也会干扰汽车的表现。所以我建议使用多彩的拼接垫。
最好的选择是绿色的,因为汽车在模拟中学会避免绿色。尽管训练和推理图像是灰度的,使用绿色的障碍物效果更好。我使用了不同颜色的障碍物,因此我把绿色障碍物放在车容易偏离轨道的弯道上。
记得之前提到的——这辆车只使用黑白图像来理解周围的环境。它并不识别周围的物体——而是学会避免或坚持使用不同的灰度(从黑到白)。
接下来做什么?
在未来的文章中,我将专注于模型训练技巧和车辆校准。
致谢
特别感谢Wes Strait分享他的最佳实践和关于减少 Sim2Real 差距的详细笔记。Abhishek Roy和Kyle Stahl帮助进行实验、记录和调试不同的车辆行为。最后,感谢嘉吉研发团队让我多次使用他们的午餐空间来进行汽车和轨道的实验。
参考文献
1 Sutton, Richard S. “强化学习:导论。” A Bradford Book(2018)。
[2] Balaji, Bharathan 等. “Deepracer:用于 Sim2Real 强化学习实验的教育性自动驾驶平台。” arXiv 预印本 arXiv:1911.01562(2019)。
AWS DeepRacer:减少 Sim2Real 差距的实用指南——第二部分 || 训练指南
如何为不同的车辆行为选择动作空间、奖励函数和训练范式
·发表于Towards Data Science ·11 分钟阅读·2024 年 8 月 26 日
--
本文描述了如何训练 AWS DeepRacer 在赛道上安全驾驶而不发生碰撞。目标并不是训练最快的赛车(尽管我会简要讨论这一点),而是训练一个可以学习保持在赛道上并顺利过弯的模型。下方的视频展示了所谓的“安全”模型:
DeepRacer 尝试通过跟随中心线保持在赛道上。视频由作者提供。
-
第一部分(2024 年 8 月 20 日):赛道和周围环境设置。
-
第二部分(2024 年 8 月 26 日): 动作空间和奖励函数设计,以及训练范式。
GitRepo 链接: https://github.com/shreypareek1991/deepracer-sim2real
在第一部分中,我描述了如何准备赛道和周围环境,以最大化成功完成多圈驾驶 DeepRacer 的机会。如果你还没有阅读第一部分,我强烈建议你阅读它,因为它是理解影响 DeepRacer 性能的物理因素的基础。
我最初使用了这篇指南,来自Sam Marsman作为起点。它帮助我快速训练模拟模型,但它们在赛道上的成功率较低。话虽如此,我强烈推荐阅读他们的博客,因为它提供了关于如何逐步训练模型的极佳建议。
注意:我们将首先训练一个慢速模型,然后稍后再提高速度。上方的视频是一个更快的模型,我将在最后简要解释。
第一部分总结
在第一部分中,我们发现 DeepRacer 使用来自前置摄像头的灰度图像作为输入,以理解和导航其周围环境。我们强调了两个关键发现:
1. DeepRacer 无法识别物体,而是学会保持在某些像素值上并避免其他像素值。汽车学会保持在黑色赛道表面上,避免越过白色赛道边界,并避免进入绿色(或者说是一种灰色调)赛道区域。
2. 相机对环境光和背景干扰非常敏感。
通过减少环境光和放置五光十色的障碍物,我们尝试缓解上述问题。下面是我从第一部分复制的设置图片。

第一部分中描述的轨道和环境设置。使用五光十色的障碍物和减少环境光照是关键。图像由作者提供。
训练
在本文中,我不会详细讨论强化学习或 DeepRacer 训练环境。已有许多文章和来自AWS 的指南涵盖了这一内容。
简而言之,强化学习是一种技术,智能体试图学习一个最大化标量奖励的最优策略。换句话说,智能体学习一组基于情况的动作,以最大化奖励。那些导致理想结果的动作通常会得到正向奖励。相反,不利的动作则会被惩罚(负向奖励)或给予小额正向奖励。
相反,我的目标是提供一种训练策略,最大化汽车不发生碰撞地完成赛道的机会。我将从五个方面进行探讨:
-
赛道 — 顺时针和逆时针方向
-
超参数 — 降低学习率
-
动作空间
-
奖励函数
-
训练范式/克隆模型
赛道
理想情况下,你希望在模拟环境中使用与现实中相同的赛道。我使用了A To Z Speedway。此外,为了获得最佳性能,你需要反复在顺时针和逆时针方向上进行训练,以最小化过度训练的影响。
超参数
我使用了 AWS 的默认设置来训练前几个模型。每 2–3 次迭代将学习率减半,这样你可以对之前的最佳模型进行微调。
行动空间
这指的是 DeepRacer 为了在环境中导航而可以采取的一组动作。可以选择两种动作 — 转向角度(度)和油门(米/秒)。
我建议使用离散行动空间而不是连续行动空间。虽然连续行动空间可以实现更平滑、更快速的行为,但它需要更长的训练时间,训练成本也会迅速增加。此外,离散行动空间提供了更多对特定行为执行的控制。例如,转弯时的较慢速度。
从以下行动空间开始。DeepRacer 的最大前进速度为 4 米/秒,但我们将从更低的速度开始。你可以稍后增加这一速度(我将展示如何增加)。记住,我们的第一个目标是仅仅绕着赛道行驶。
慢而稳的行动空间
慢而稳的模型,需要人类的轻推,但能够保持在轨道上。视频由作者提供。
首先,我们将训练一个非常慢的模型,但它可以绕着赛道行驶而不偏离轨道。不用担心如果汽车不断卡住。你可能需要给它一些小推力,但只要它能完成一圈 — 你就走在正确的道路上(双关语)。确保选中了高级配置。

慢而稳模型的离散行动空间。图片由作者提供。
奖励函数
奖励函数可以说是最关键的因素,因此也是强化学习中最具挑战性的部分。它决定了智能体将学习到的行为,因此必须非常小心地设计。是的,你选择的学习模型、超参数等确实会影响智能体的整体行为 — 但它们都依赖于你的奖励函数。
设计一个好的奖励函数的关键是列出你希望智能体执行的行为,然后思考这些行为如何相互作用以及与环境的互动。当然,你无法涵盖所有可能的行为或互动,甚至即使能 — 智能体可能会学到完全不同的策略。
现在让我们列出我们希望汽车执行的期望行为以及它们对应的 Python 奖励函数。我将首先为每个行为单独提供奖励函数,然后稍后将它们汇总在一起。
行为 1 — 保持在轨道上行驶
这一点很简单。我们希望汽车保持在赛道上,避免驶出白色线条。我们通过以下两个子行为来实现这一点:
#1 靠近中心线行驶: 汽车越靠近赛道中心,发生碰撞的几率就越小。为此,当汽车靠近中心时,我们会给予较大的正奖励,而当它远离时,则给予较小的正奖励。我们会给予一个较小的正奖励,因为只要汽车仍在赛道内,偏离中心不一定是坏事。
def reward_function(params):
"""
Example of rewarding the agent to follow center line.
"""
# set an initial small but non-negative reward
reward = 1e-3
# Read input parameters
track_width = params["track_width"]
distance_from_center = params["distance_from_center"]
# Calculate 3 markers that are at varying distances away from the center line
marker_1 = 0.1 * track_width
marker_2 = 0.25 * track_width
marker_3 = 0.5 * track_width
# Give higher reward if the car is closer to center line and vice versa
if distance_from_center <= marker_1:
reward += 2.0 # large positive reward when closest to center
elif distance_from_center <= marker_2:
reward += 0.25
elif distance_from_center <= marker_3:
reward += 0.05 # very small positive reward when further from center
else:
reward = -20 # likely crashed/ close to off track
return float(reward)
#2 保持所有四个车轮在赛道上: 在赛车中,如果汽车的四个车轮都偏离赛道,则圈速会被删除。为此,我们会对四个车轮都偏离赛道的情况施加大幅负惩罚。
def reward_function(params):
'''
Example of penalizing the agent if all four wheels are off track.
'''
# large penalty for off track
OFFTRACK_PENALTY = -20
reward = 1e-3
# Penalize if the car goes off track
if not params['all_wheels_on_track']:
return float(OFFTRACK_PENALTY)
# positive reward if stays on track
reward += 1
return float(reward)
我们在这里的期望是,通过结合上述子行为,我们的代理会学会靠近赛道中心是一种期望的行为,而偏离中心会导致惩罚。
行为 2 — 转弯时减速
就像现实生活中一样,我们希望车辆在转弯时减速。此外,转弯越急,期望的速度越慢。我们通过以下方式实现:
-
提供一个较大的正奖励,使得如果转向角度较大(即急转弯),速度低于某个阈值时,给予奖励。
-
提供一个较小的正奖励,当转向角度较大且速度超过某个阈值时。
非故意的曲折行为:奖励函数设计是一种微妙的平衡艺术。没有免费的午餐。尝试训练某些期望的行为可能会导致意想不到和不希望发生的行为。在我们的案例中,通过强制代理靠近中心线,我们的代理将学会一种曲折行驶的策略。每当它偏离中心时,它会尝试通过向相反方向转向来纠正自己,循环会继续。我们可以通过惩罚极端的转向角度来减少这种情况,将最终奖励乘以 0.85(即减少 15%)。
顺便提一下,这也可以通过跟踪转向角度的变化并惩罚大的突然变化来实现。我不确定 DeepRacer API 是否提供访问先前状态的功能,以设计这样的奖励函数。
def reward_function(params):
'''
Example of rewarding the agent to slow down for turns
'''
reward = 1e-3
# fast on straights and slow on curves
steering_angle = params['steering_angle']
speed = params['speed']
# set a steering threshold above which angles are considered large
# you can change this based on your action space
STEERING_THRESHOLD = 15
if abs(steering_angle) > STEERING_THRESHOLD:
if speed < 1:
# slow speeds are awarded large positive rewards
reward += 2.0
elif speed < 2:
# faster speeds are awarded smaller positive rewards
reward += 0.5
# reduce zigzagging behavior by penalizing large steering angles
reward *= 0.85
return float(reward)
把所有内容整合起来
接下来,我们将以上所有内容结合起来,得到我们的最终奖励函数。Sam Marsman的指南建议通过训练模型逐步学习一个奖励函数,然后再加入其他奖励,以逐步训练附加行为。你可以尝试这种方法。在我的情况下,这并没有带来太大变化。
def reward_function(params):
'''
Example reward function to train a slow and steady agent
'''
STEERING_THRESHOLD = 15
OFFTRACK_PENALTY = -20
# initialize small non-zero positive reward
reward = 1e-3
# Read input parameters
track_width = params['track_width']
distance_from_center = params['distance_from_center']
# Penalize if the car goes off track
if not params['all_wheels_on_track']:
return float(OFFTRACK_PENALTY)
# Calculate 3 markers that are at varying distances away from the center line
marker_1 = 0.1 * track_width
marker_2 = 0.25 * track_width
marker_3 = 0.5 * track_width
# Give higher reward if the car is closer to center line and vice versa
if distance_from_center <= marker_1:
reward += 2.0
elif distance_from_center <= marker_2:
reward += 0.25
elif distance_from_center <= marker_3:
reward += 0.05
else:
reward = OFFTRACK_PENALTY # likely crashed/ close to off track
# fast on straights and slow on curves
steering_angle = params['steering_angle']
speed = params['speed']
if abs(steering_angle) > STEERING_THRESHOLD:
if speed < 1:
reward += 2.0
elif speed < 2:
reward += 0.5
# reduce zigzagging behavior
reward *= 0.85
return float(reward)
训练范式/模型克隆
训练成功模型的关键在于迭代地克隆和改进现有模型。换句话说,与其训练一个模型 10 小时,不如:
-
训练一个初始模型几个小时
-
克隆最佳模型
-
训练一个小时左右
-
克隆最佳模型
-
重复,直到在验证过程中获得可靠的 100% 完成度
-
在每次训练迭代中,切换时针方向和逆时针方向的赛道方向
-
每 2-3 次迭代将学习率减少一半
你要寻找的奖励图形应该像这样。如果你每次都没有达到 100%的完成度也是可以的。关键是保持一致性。

想要的奖励和完成百分比行为。图片来源:作者。
测试、重新训练、测试、重新训练、重复
机器学习和机器人技术都讲究反复迭代。没有一刀切的解决方案。所以你需要进行实验。
(附加内容)训练一个更快的模型
一旦你的车能够安全地绕过赛道(即使它有时需要推一下),你就可以在动作空间和奖励函数中提高车速。
本页面顶部的视频是使用以下动作空间和奖励函数创建的。

这是一个在保持安全的情况下,可以让车速更快的动作空间。图片来源:作者。
def reward_function(params):
'''
Example reward function to train a fast and steady agent
'''
STEERING_THRESHOLD = 15
OFFTRACK_PENALTY = -20
# initialize small non-zero positive reward
reward = 1e-3
# Read input parameters
track_width = params['track_width']
distance_from_center = params['distance_from_center']
# Penalize if the car goes off track
if not params['all_wheels_on_track']:
return float(OFFTRACK_PENALTY)
# Calculate 3 markers that are at varying distances away from the center line
marker_1 = 0.1 * track_width
marker_2 = 0.25 * track_width
marker_3 = 0.5 * track_width
# Give higher reward if the car is closer to center line and vice versa
if distance_from_center <= marker_1:
reward += 2.0
elif distance_from_center <= marker_2:
reward += 0.25
elif distance_from_center <= marker_3:
reward += 0.05
else:
reward = OFFTRACK_PENALTY # likely crashed/ close to off track
# fast on straights and slow on curves
steering_angle = params['steering_angle']
speed = params['speed']
if abs(steering_angle) > STEERING_THRESHOLD:
if speed < 1.5:
reward += 2.0
elif speed < 2:
reward += 0.5
# reduce zigzagging behavior
reward *= 0.85
return float(reward)
快速但易撞的模型——使用时请自担风险
本系列第一部分中展示的视频经过训练后偏向于速度。没有对离开赛道或碰撞进行惩罚,反而给予了非常小的正向奖励。这导致了一个快速的模型,在模拟中完成了10.337 秒的时间。在实际操作中,它会频繁碰撞,但当它成功完成一圈时,十分令人满足。
这是你可以尝试的动作空间和奖励。

这是我能够管理的最快圈速的动作空间。使用这个空间时,车子会经常发生碰撞。图片来源:作者。
def reward_function(params):
'''
Example of fast agent that leaves the track and also is crash prone.
But it is FAAAST
'''
# Steering penality threshold
ABS_STEERING_THRESHOLD = 15
reward = 1e-3
# Read input parameters
track_width = params['track_width']
distance_from_center = params['distance_from_center']
# Penalize if the car goes off track
if not params['all_wheels_on_track']:
return float(1e-3)
# Calculate 3 markers that are at varying distances away from the center line
marker_1 = 0.1 * track_width
marker_2 = 0.25 * track_width
marker_3 = 0.5 * track_width
# Give higher reward if the car is closer to center line and vice versa
if distance_from_center <= marker_1:
reward += 1.0
elif distance_from_center <= marker_2:
reward += 0.5
elif distance_from_center <= marker_3:
reward += 0.1
else:
reward = 1e-3 # likely crashed/ close to off track
# fast on straights and slow on curves
steering_angle = params['steering_angle']
speed = params['speed']
# straights
if -5 < steering_angle < 5:
if speed > 2.5:
reward += 2.0
elif speed > 2:
reward += 1.0
elif steering_angle < -15 or steering_angle > 15:
if speed < 1.8:
reward += 1.0
elif speed < 2.2:
reward += 0.5
# Penalize reward if the car is steering too much
if abs(steering_angle) > ABS_STEERING_THRESHOLD:
reward *= 0.75
# Reward lower steps
steps = params['steps']
progress = params['progress']
step_reward = (progress/steps) * 5 * speed * 2
reward += step_reward
return float(reward)
结论
总结一下,记住两件事。
-
从训练一个能成功绕过赛道的慢速模型开始,即使你有时需要推一下车子。完成这个后,你可以尝试在动作空间中增加车速。就像现实生活中一样,先从小步走开始。你也可以通过 DeepRacer 控制界面逐渐增加油门百分比,从50%到 100%,以管理车速。在我的情况下,95%的油门效果最好。
-
逐步训练你的模型。从几个小时的训练开始,然后切换赛道方向(顺时针/逆时针),逐渐将训练时间减少到一个小时。你还可以每进行 2-3 次迭代,就将学习率减半,以精炼并改进以前的最佳模型。
最后,你必须根据你的物理设置多次重复。在我的案例中,我训练了100+个模型。希望通过本指南,你能够通过15-20 个模型得到类似的结果。*
感谢阅读。
构建 Azure 容器应用:一个使用 Python Flask、Plotly Dash 和 Docker 的数据分析应用
部署可扩展的数据分析 web 应用程序:利用 Flask、Dash 和 Azure 容器应用程序实现更高的灵活性
·发表于Towards Data Science ·阅读时长 16 分钟·2024 年 2 月 19 日
--

图片由Tima Miroshnichenko提供,来自Pexels
在 Python 中开发数据分析 web 应用程序时,像Plotly 的 Dash和Streamlit这样的框架是目前最流行的。但如何在实际应用中部署这些框架,超越教程的范畴,同时考虑到可扩展性、效率和成本,并利用最新的云计算解决方案呢?
在本文中,我们将介绍如何使用Flask后端和 Plotly Dash 前端构建容器化应用程序,并将其通过 Docker 容器化后部署到 Microsoft Azure 的容器应用服务。我们将不会单纯依赖 Dash 的现成解决方案,而是部署一个自定义的 Flask 服务器。这种方法为应用程序提供了更多的灵活性,使 Dash 可以与其他框架一起使用,并且克服了开源 Dash 与 Dash Enterprise 相比的一些局限性。
本教程的代码示例可以在Github上找到。
微软 Azure 容器应用服务
解读 R²:迷惑者的叙述指南
从预测性建模的角度出发,对这一流行但常被误解的指标的基本性质进行易于理解的讲解
·发布于Towards Data Science ·阅读时长:15 分钟·2024 年 2 月 19 日
--

图片由Josh Rakower提供,来源于Unsplash
R²(R 平方),也称为决定系数,广泛用于作为回归模型性能的评估指标。它通常用于量化统计建模中的拟合优度,并且是回归模型在流行的统计建模和机器学习框架中的默认评分指标,从statsmodels到scikit-learn。
尽管 R² 无处不在,但关于它究竟意味着什么,仍然存在相当多的混淆,遇到相互矛盾的信息并不罕见(例如,关于该指标的上限或下限,以及它的解释)。这种混淆的根源在于解释性建模和预测性建模传统之间的“文化冲突”。实际上,在预测性建模中——即评估是在样本外进行的,任何能提高性能的建模方法都是可取的——许多 R² 的性质,在狭义的解释导向的线性建模上下文中适用,但在此情境下已不再成立。
为了帮助理解这一混乱的情况,本文提供了一份易于理解的 R² 基本属性介绍,从预测建模的角度出发,重点揭示并澄清了关于这个指标的常见困惑和误解。通过这篇文章,我希望能帮助读者形成关于 R² 作为预测建模和机器学习中拟合度衡量指标的统一直觉,并强调该指标的一些优缺点。本文面向广泛的读者群体,包括统计学入门学生和预测建模专家,我将保持语言简单,并通过具体的可视化图示来支撑我的论点。
准备好了吗?我们开始吧!
什么是 R²?
让我们从一个简单的口头定义开始。为了简化问题,我们采用 Wikipedia 给出的第一个高层次定义,它很好地反映了许多统计学教学资源中所找到的定义,包括权威的教科书:
从自变量预测因变量变化的比例
据经验,如果你问大多数受过统计推断训练的学生如何定义 R²,他们可能会这么回答。但正如我们接下来会看到的,这种常见的定义方式是导致许多与 R² 相关的误解和困惑的根源。让我们深入探讨一下。
叫 R² 为比例意味着 R² 将是一个介于 0 和 1 之间的数字,其中 1 对应于一个解释了因变量所有变化的模型,而 0 对应于一个解释了因变量没有变化的模型。注意:你的模型也可能不包含任何预测变量(例如,仅包含截距的模型仍然是一个模型),这就是为什么我将重点放在模型所预测的变化上,而不是独立变量所预测的变化上。
让我们验证一下关于可能值范围的直觉是否正确。为此,让我们回顾一下 R² 的数学定义:

在这里,RSS 是残差平方和,其定义为:

这只是模型的平方误差和,即真实值 y 和相应模型预测值 ŷ 之间平方差的总和。
另一方面,TSS,即总平方和,定义如下:

正如你可能注意到的,这个项与残差平方和的“形式”类似,但这次我们关注的是结果变量y与结果变量的均值ȳ之间的平方差。这在技术上是结果变量的方差。但是,在预测建模的背景下,更直观的理解方式是:这个项是一个模型的残差平方和,该模型总是预测结果变量的均值。因此,残差平方和与总平方和的比率是你的模型的平方误差之和与一个“参考”模型预测结果变量均值的平方误差之和之间的比率。
有了这个思考,我们继续分析这个指标的可能值范围,并验证我们的直觉,认为这些值应该确实在 0 到 1 之间。
最好的 R²值是多少?
如我们所见,R²是通过从 1 中减去 RSS 与 TSS 的比率来计算的。那么,R²能否大于 1?换句话说,1 是否是 R²的最大可能值?让我们通过回顾公式来思考这个问题。
唯一一种情况下,1 减去某个值能大于 1,那就是这个某个值是一个负数。但是在这里,RSS 和 TSS 都是平方和,即正值的和。因此,RSS 和 TSS 的比率总是正数。因此,最大可能的 R²必须是 1。
现在我们已经确定了 R²不能大于 1,让我们试着可视化一下,模型需要怎样才能达到最大可能的 R²。为了使 R²为 1,RSS / TSS 必须为零。这可能发生在 RSS = 0 时,也就是说,如果模型完美地预测了所有数据点。

通过模拟数据,举例说明了 R² ≈ 1 的假设模型。在所有案例中,真实的底层模型是 y = 2x + 3。前两个模型完美拟合数据,第一种情况是因为数据没有噪声,线性模型能够完美地恢复 x 与 y 之间的关系(左),第二种情况是因为模型非常灵活且过拟合数据(中)。这些是极端的案例,在现实中很难找到。事实上,最大可能的R²通常由数据中的噪声量来定义。第三个图展示了这一点,由于存在随机噪声,即使是真实模型也只能达到R² = 0.458。
实际上,除非你极度过拟合你的数据,使用过于复杂的模型,或者你正在对一个数据点极少的、你的模型能够完美拟合的数据集计算 R²,否则这种情况永远不会发生。所有数据集都会有一些无法通过数据解释的噪声。在实际应用中,最大可能的 R²将由结果变量中无法解释的噪声量来定义。
最差的 R²值是多少?
到目前为止,一切都很好。如果 R²的最大可能值为 1,我们仍然可以将 R²视为模型解释的结果变量变异的比例。但现在让我们来看看 R²的最小可能值。如果我们接受上面我们提出的 R²定义,那么我们必须假设 R²的最小可能值是 0。
R²何时为 0?要使 R²为零,RSS/TSS 必须等于 1。这种情况发生在 RSS = TSS 时,也就是说,如果我们模型的平方误差之和等于预测均值模型的平方误差之和。如果你仅仅预测均值会更好,那么你的模型确实没有做得非常好。造成这种情况的原因有无数种,其中之一可能是你选择的模型存在问题——例如,如果你试图用线性模型拟合真正的非线性数据。或者这可能是数据本身的结果。如果你的结果变量非常嘈杂,那么预测均值的模型可能是你能做的最好的模型。

两种情况,其中均值模型可能是最佳的(线性)模型,因为:a) 数据是纯高斯噪声(左);b) 数据是高度非线性的,因为它是由周期函数生成的(右)。
但是,R² = 0 真的是可能的最小 R²吗?换句话说,R²是否有可能为负值?让我们回过头来看一下公式。R² < 0 只有在 RSS/TSS > 1 时才有可能,即,如果 RSS > TSS。这种情况可能发生吗?
这时事情开始变得有趣,因为这个问题的答案在很大程度上取决于我们尚未指定的背景信息,即我们正在考虑哪种类型的模型,以及我们在哪些数据上计算 R²。正如我们将看到的,我们对 R²作为方差解释比例的理解是否成立,取决于我们对这些问题的回答。
负 R²的无底洞
让我们看一个具体的例子。我们使用以下模型y = 3 + 2x生成一些数据,并添加了高斯噪声。
import numpy as np
x = np.arange(0, 1000, 10)
y = [3 + 2*i for i in x]
noise = np.random.normal(loc=0, scale=600, size=x.shape[0])
true_y = noise + y
下面的图显示了三个模型,这些模型基于不同随机抽样的数据子集,预测y的值。这些模型不是虚构的模型,正如我们稍后将看到的那样,但现在让我们忽略这一点。我们仅仅关注它们的 R²的符号。

使用函数 y = 3 + 2x(并加入高斯噪声)生成的数据的三个模型示例。
让我们从第一个模型开始,这是一个简单的常数预测模型,在这个例子中该常数低于结果变量的均值。在这里,我们的 RSS 将是每个数据点与橙色线之间的平方距离之和,而 TSS 将是每个数据点与蓝色线(均值模型)之间的平方距离之和。很容易看出,对于大多数数据点,数据点与橙色线之间的距离将大于数据点与蓝色线之间的距离。因此,我们的 RSS 将大于我们的 TSS。如果是这种情况,我们将得到 RSS/TSS > 1,因此:1 — RSS/TSS < 0,也就是说,R² < 0。
事实上,如果我们计算这个模型在此数据上的 R²,我们得到 R² = -2.263。如果你想验证它是否真实,你可以运行下面的代码(由于随机性,你可能会得到一个类似的负值,但不会完全相同):
from sklearn.metrics import r2_score
# get a subset of the data
x_tr, x_ts, y_tr, y_ts = train_test_split(x, true_y, train_size=.5)
# compute the mean of one of the subsets
model = np.mean(y_tr)
# evaluate on the subset of data that is plotted
print(r2_score(y_ts, [model]*y_ts.shape[0]))
现在让我们继续讨论第二个模型。在这里,同样容易看出数据点与红线(我们的目标模型)之间的距离将大于数据点与蓝线(均值模型)之间的距离。事实上,在这里:R² = -3.341。注意,我们的目标模型与真实模型(橙色线)不同,因为我们在包括噪音的子集数据上进行了拟合。我们将在下一段中进一步讨论这一点。
最后,让我们看看最后一个模型。在这里,我们对上述生成的数据子集拟合了一个 5 次多项式模型。此时,数据点与拟合函数之间的距离比数据点与均值模型之间的距离显著更大。事实上,我们拟合的模型得出 R² = -1540919.225。
显然,正如这个例子所展示的,模型确实可以有负的 R²。事实上,R²的值没有下限。将模型做得足够糟糕,R²可以接近负无穷大。这在简单线性模型中也可能发生:进一步增加第二个例子中线性模型斜率的值,R²将继续下降。那么,这对于我们最初的问题——即 R²是否真的表示模型能够解释的结果变量方差的比例——意味着什么呢?
好吧,我们通常不会将比例看作是任意大的负值。如果我们真的对原始定义有所依赖,我们可以通过富有创意的想象力将这个定义扩展到涵盖那些模型表现极差、能增加结果变量方差的情境。模型增加的方差反比(例如,作为糟糕模型选择或过拟合不同数据的结果)在极低负值中有所体现。
但这更多的是一种比喻,而非定义。抛开文学思维不谈,最字面且最具生产力的思考方式是将 R² 视为一种比较度量,它说明了你的模型在预测数据时有多好(从 0 到 1 的范围内)或多差(从 0 到无穷大),相较于一个总是预测结果变量均值的模型。
重要的是,这表明,虽然 R² 可能是一个诱人的方式来以独立于尺度的方式评估模型,并且作为比较度量使用它可能是有意义的,但它远不是一个透明的度量标准。R² 的值不会提供关于你的模型在绝对意义上有多错的明确指示;最佳值始终会依赖于数据中噪声的大小;而良好或差的 R² 可能来自多种原因,而没有额外的度量工具,很难分辨清楚。
好的,R² 可以是负数。但在实践中,这种情况会发生吗?
一个非常合理的反对意见是,上述展示的场景是否实际可行。我是说,哪个理智的建模者会对如此简单的数据拟合出这样差劲的模型呢?这些看起来可能只是为了这个例子而人为构建的特设模型,并没有真正拟合任何数据。
这是一个非常好的观点,它引出了另一个与 R² 及其解释相关的关键问题。正如我们上面所强调的,所有这些模型实际上都是拟合于从与图中数据相同的真实基础函数生成的数据。这对应了预测建模中一个基础性的做法,即将数据分为训练集和测试集,前者用于估计模型,后者用于在未见数据上进行评估——这是评估模型在预测任务中表现的一个“更公正”的代理。
事实上,如果我们将前一节中介绍的模型与用于估计它们的数据进行对比,会发现它们对于其训练数据来说并不是不合理的模型。事实上,训练集的 R² 值至少是非负的(在线性模型的情况下,其 R² 值非常接近真实模型在测试数据上的 R²)。

与前图中展示的相同函数,这次是将其与拟合数据进行对比,这些数据是通过相同的真实函数 y = 3 + 2x 生成的。对于第一个模型,它预测的是一个常数,模型的“拟合”仅仅是计算训练集的均值。
那么,为什么之前的数据和现在的数据差距如此之大呢?我们观察到的是过拟合的情况。模型错误地将训练数据中的样本特有噪声当作信号进行建模——这并不是一个罕见的场景。因此,模型在新数据样本上的预测会很差。
避免过拟合可能是预测建模中最大的挑战。因此,当(为了确保模型具有泛化能力和鲁棒性,我们应该始终这样做)R²值是在样本外计算时,也就是说,在与模型估计时的数据“随机”不同的数据上计算时,观察到负的 R²值并不罕见。
因此,本节标题中提出的问题的答案实际上是一个响亮的肯定回答:负的 R²确实会出现在常见的建模场景中,即使模型已正确估计。事实上,它们一直都会发生。
那么,大家都是错的吗?
如果 R²不是一个比例,并且它作为解释方差的解释与关于其行为的一些基本事实冲突,那么我们是否必须得出结论,认为我们最初的定义是错误的?维基百科和那些教科书中呈现的相似定义是否错误?我的统计学 101 老师是否错了?好吧。是的,也不是。它很大程度上取决于 R²所呈现的上下文,以及我们所采纳的建模传统。
如果我们仅仅分析 R²的定义并尝试描述其一般行为,无论我们使用哪种类型的模型进行预测,并假设我们希望计算这个指标时是在样本外,那么是的,它们都是错的。将 R²解释为方差解释的比例是误导性的,它与这个指标的基本行为事实相冲突。
然而,如果我们将自己限制在一个狭窄的场景中,即线性模型,特别是使用最小二乘法估计的线性模型,那么答案会略有变化。在这里,R²会表现为一个比例。实际上,可以证明,由于最小二乘估计的性质,线性模型永远不可能比预测结果变量均值的模型表现更差。这意味着,线性模型永远不会有负的 R²——或者至少,它不会在与其估计时相同的数据上有负的 R²(如果你对泛化模型感兴趣的话,这是一种值得争议的做法)。因此,对于线性回归情境下的样本内评估,所讨论的定义可以视为正确的。额外有趣的事实:这是 R²等同于模型预测与真实结果之间的平方相关性的唯一场景。
许多关于 R²的误解产生的原因是,这个指标通常首先在线性回归的背景下引入,并且侧重于推断而非预测。但在预测建模中,在这里在样本内评估是不可行的,线性模型只是许多可能模型中的一种,将 R²解释为模型所解释的变异的比例,充其量是无益的,最坏的情况下是极具误导性的。
我还应该使用 R²吗?
我们已经涉及了不少内容,那么让我们总结一下。我们观察到:
-
R²不能被解释为一个比例,因为它的值可以从-∞到 1。
-
其作为“解释方差”进行的解释也是误导性的(你可以想象出那些增加数据方差的模型,或者那些将已有方差与模型“幻想”出的方差结合的模型)。
-
一般来说,R²是一个“相对”指标,它将你的模型误差与一个总是预测均值的简单模型的误差进行比较。
-
然而,将 R²描述为在线性建模和最小二乘估计的背景下以及当计算最小二乘线性模型的 R²时,使用样本内数据是准确的。
考虑到所有这些注意事项,我们是否仍然应该使用 R²?还是应该放弃它?
在这里,我们进入了更多主观观察的领域。一般来说,如果你在进行预测建模,并且想要对预测结果在绝对意义上的错误程度有一个具体的了解,那么 R²不是一个有用的指标。像 MAE 或 RMSE 这样的指标在提供模型误差大小的信息方面肯定会做得更好。这对于绝对意义上的理解是有用的,也适用于模型比较的场景,在这种场景下,你可能想要知道不同模型之间,具体而言,预测精度有多大差异。如果了解某些关于精度的信息很重要(几乎总是重要),你至少可能想要将 R²与能够提供关于每个单独预测可能有多错误的有意义的指标结合使用。
更一般来说,正如我们所强调的,如果你决定使用 R²,有一些注意事项需要牢记。其中一些涉及到 R²的“实际”上限(你的噪声上限),以及它作为相对指标的字面解释,而不是与均值模型相比的绝对拟合度。此外,正如我们所观察到的,好的或坏的 R²值可能受到多种因素的影响,从过拟合到数据中的噪声量。
另一方面,虽然我发现很少有预测建模场景中,单独使用 R²特别有用,但拥有相对于“虚拟”模型(均值模型)的拟合度度量,可以是一个富有成效的方式,帮助你批判性地思考你的模型。在训练集上,R²值不现实地高,或者在测试集上 R²值为负,可能分别帮助你考虑到你可能在追求一个过于复杂的模型,或者一个不适当的建模方法(例如,使用线性模型处理非线性数据),或者你的结果变量大部分可能仅包含噪声。这再次是一个更“务实”的个人看法,但虽然我不完全排除 R²(没有很多好的全球性和规模独立的拟合度量),在预测建模的上下文中,我会将它视为对 RMSE/MAE 等规模相关指标的补充,或者作为一种“诊断”工具,而不是作为一个目标。
结论性 remarks
R²无处不在。然而,尤其是在偏向解释性建模而非预测性建模传统的领域中,许多人对其作为模型评估工具的解释存在误解,并且这些误解仍然广泛存在。
在这篇文章中,我尝试为读者提供关于 R²的一些基本属性的叙述性介绍,以消除常见的误解,并帮助读者理解 R²通常衡量的内容,超越仅仅是线性模型样本内评估的狭窄范围。
这篇文章远非一个完整和权威的指南,我希望它能成为一个务实且灵活的资源,帮助澄清一些非常合理的困惑。干杯!
除非标题中另有说明,本文中的图片均由作者提供
回到基础:数据库、SQL 及其他数据处理必读书目
·发表于 Towards Data Science ·作为 通讯 发送 ·3 分钟阅读·2024 年 6 月 20 日
--
感觉受到启发,想写你的第一篇 TDS 文章? 我们始终欢迎新作者的贡献。
过去一年,我们的集体注意力如此集中于大语言模型(LLM),以至于有时很容易忘记,数百万数据专业人士的核心日常工作流程,更有可能涉及关系型数据库和传统的 SQL 查询,而不是例如检索增强生成(RAG)技术。
我们本周重点推荐的文章提醒我们,保持和提升我们在数据和机器学习任务全范围内的技能是必要的,而不仅仅是那些最热的任务。综合来看,它们还传递了一个重要信息:这些基础的数据操作与引发炒作的、以人工智能为主的操作之间并没有明确的界限;后者往往在没有前者的支持下甚至无法正常工作。
-
简化数据工程项目中的 Python 代码 一个坚实的基础是任何涉及大量数据的复杂操作成功的关键。John Leung 提供了确保数据管道最基本构建块——底层代码——尽可能稳健和高效的具体建议。
-
如何学习 SQL 以进行数据分析对于刚刚踏入数据查询和分析领域的任何人,Natassha Selvaraj的最新初学者友好指南提供了一个简化的路线图,帮助你在一个月内掌握 SQL 的最基本要素;它还专门提供了一部分内容,介绍了如何在求职面试中处理 SQL 问题的有用提示。

图片来自Benoît Deschasaux于Unsplash
-
如何在 SQL 中进行数据透视表操作正如Jack Chang所解释的,“使用数据透视表,用户可以查看不同数据维度的不同汇总。”不确定为什么这很重要,或者如何在 SQL 中使用数据透视表?Jack 的综合资源详细介绍了基础知识——以及更多内容。
-
使用 VBA 管理数据透视表和 Excel 图表从不同角度处理数据透视表,Himalaya Bir Shrestha提供了一个实践教程,展示了如何通过利用 VBA(Visual Basic for Applications)的强大功能来自动化你在处理 Excel 图表时的关键步骤:“虽然在开始时设置代码可能需要相当大的努力,但一旦设置完成,它对每天处理大量数据集的分析师来说会非常方便且节省时间。”
-
将关系型数据库转变为图形数据库虽然承认关系型数据库的重要作用,Katia Gil Guzman在她的首篇 TDS 文章中提出了一个重要问题:“如果你的数据的真正潜力在于数据点之间的关系呢?这就是图形数据库发挥作用的地方。”她继续演示了如何在 Python 中将关系型数据库转换为动态图形数据库。
本周有兴趣探索其他主题和问题吗?不妨看看这些精彩的阅读内容:
-
在她的“深度学习插图”系列的最新一篇中,Shreya Rao转向了递归神经网络并解析了基于序列问题的复杂性。
-
如果你想深入了解图形与几何深度学习基础模型的当前状态,千万不要错过Michael Galkin和Michael Bronstein的精彩深度分析。
-
如果我们不是借用神经科学的概念来发展 AI,而是转向 AI 研究来更好地理解人类大脑呢?Stephanie Shen撰写了一篇关于这一新兴话题的深刻解释。
-
对于那些寻找有趣且富有教育意义的动手项目的人,我们提供了Pranav Jadhav的详细教程,内容是如何在 MacBook 上从零开始构建 GPT 模型。
-
通过关注Mengliu Zhao的易懂回顾两篇有前景的论文(分别是关于像素变换器和超长序列分布式变换器的研究),保持对计算机视觉领域最新研究的关注。
感谢您支持我们作者的工作!我们喜欢发布新作者的文章,如果您最近写了一篇有趣的项目演示、教程或关于我们核心主题的理论思考,请毫不犹豫地与我们分享。
直到下一个变量,
TDS 团队
时间反向传播 — RNN 如何学习
时间反向传播算法的解释
·发布于 Towards Data Science ·阅读时间:9 分钟·2024 年 5 月 17 日
--

“www.flaticon.com/free-icons/neural-network" title=”neural network icons”>神经网络图标由 pojok d 创建 — Flaticon。
递归神经网络(RNNs)是常规前馈神经网络的变体,能够处理基于序列的数据,如时间序列和自然语言。
它们通过添加一个“递归”神经元来实现这一点,该神经元允许信息从过去的输入和输出传递到下一步。下图展示了一个传统的 RNN:

RNN 的示例架构。图示由作者提供。
左侧是一个递归神经元,右侧是通过时间展开的递归神经元。注意如何将之前的执行结果传递给后续的计算。
这为系统增加了一些固有的“记忆”,帮助模型捕捉历史上发生的模式。
在预测Y_1时,递归神经元使用X_1的输入和上一个时间步的输出Y_0。这意味着Y_0'对Y_1的影响是直接的,并且它也间接影响Y_2。
如果你想要完整了解 RNN 及一些实际示例,查看我之前的文章。
错误假设——即使是经验丰富的数据科学家的失败
统计学
数据可能具有欺骗性,所以要保持警惕!
·发布于Towards Data Science ·阅读时间 9 分钟·2024 年 8 月 16 日
--

图片由Gordon Johnson提供,来自Pixabay
在处理任何形式的数据时,做出假设是不可避免的。数据通常本质上是不完整和不完美的。无论如何,我们的工作是尽力解读和提取意义。
然而,总是存在一种风险,即对数据的错误假设可能会扭曲最终结果,甚至可能使结果误导人,或者更糟的是,完全没有意义。
我认为,这种情况既可能发生在经验丰富的从业者身上,也可能发生在那些刚接触该领域的人身上。
那么,这种情况是如何发生的呢?我们又能做些什么来减轻这种影响呢?
简介
在我的职业生涯初期,我是一名工程师,处理实验气动学相关的工作。这要求处理和解释大量来自现实世界的不完美实验数据。
我曾在我工作过的办公室里听过一句非常常用的口号(我已经把真实的说法稍微减轻了,您可以根据需要插入合适的脏话):
假设是所有真正糟糕错误的母亲
如何做:时间序列中的基准模型
为什么(以及如何)在训练最终模型之前创建基准模型
·发表于 Towards Data Science ·5 分钟阅读·2024 年 3 月 7 日
--

所以你已经收集了数据,概述了业务案例,决定了候选模型(例如:随机森林),设置好了开发环境,双手已经准备好在键盘上开始。你准备好构建和训练你的时间序列模型了。
等一下——别急着开始。在训练和测试你的随机森林模型之前,你应该先训练一个基准模型。
什么是基准模型?
基准模型是一个简单的模型,用于创建一个基准或参考点,作为你构建最终更复杂的机器学习模型的基础。
数据科学家创建基准模型的原因:
-
基准模型可以让你大致了解更复杂的模型将如何表现。
-
如果基准模型表现不佳,这可能是数据质量存在问题的迹象,值得进行处理。
-
如果基准模型的表现优于最终模型,这可能表明该算法、特征、超参数或其他数据预处理存在问题。
-
如果基准模型和复杂模型的表现差不多,这可能意味着……
强化学习基础:适用于大语言模型(LLMs)
理解强化学习(RL)的问题表述和基本算法
·发布于Towards Data Science ·18 分钟阅读·2024 年 1 月 31 日
--

(照片由Ricardo Gomez Angel提供,来源于Unsplash)
最近的人工智能研究表明,强化学习 —— 更具体地说, 来自人类反馈的强化学习(RLHF) —— 是训练最先进的大型语言模型(LLM)的关键组成部分。尽管如此,大多数关于语言模型的开源研究仍然大力强调监督学习策略,如监督微调(SFT)。这种对强化学习的忽视可以归因于多个因素,包括需要整理人类偏好数据或进行高质量 RLHF 所需的数据量。然而,无法忽视的一个因素,可能是对强化学习持怀疑态度的根本原因,就是它不像监督学习那样普遍使用。因此,AI 从业者(包括我自己!)由于简单的理解不足,往往避免使用强化学习 —— 我们倾向于坚持使用我们最熟悉的方法。
“我们中许多人表达了对监督标注的偏好,因为其信号更密集…然而,强化学习证明了其高度有效性,尤其是在成本和时间效率方面。” — 引自[8]
本系列文章。 在接下来的几个概述中,我们将通过构建…来解决这个问题。
批处理与流处理的统一解密
理解为什么批处理可以被视为流处理的一个子集,以及为何数据工程应显著简化其使用方式
· 发表在Towards Data Science · 23 分钟阅读 · 2024 年 9 月 4 日
--

图片来自Felix Mittermeier 通过Unsplash
目录
· 原则
· 阻塞和有状态操作符
· 流处理即批处理,批处理即流处理?
· 数据窗口
· 事件时间与处理时间
· 精确一次语义
· 扩展到企业级别
关于批处理与流处理的讨论,大多集中在它们的高层次区别上。然而,如果我们深入探讨,真正的区别更加微妙。通过仔细审视两种数据处理方法的基本原理,我们可以发现它们之间的相似性。实际上,它们有如此多的共同点,以至于我们可以在很大程度上抽象掉技术差异。
这意味着,虽然应用开发人员仍然需要选择他们的使用案例是更适合流处理还是批处理风格,但他们不再需要过多关注不同的技术实现。这将大大简化他们将主要精力集中在实现业务逻辑上的过程。
MIT Battlecode 反思:一个首次进入决赛者的收获
第一次参赛者的关键收获
·发表于Towards Data Science ·14 分钟阅读·2024 年 11 月 7 日
--

图片由作者制作,使用了Battlecode 引擎
我是被我们的一位队友拉进了 Battlecode 2024 的。我之前从未听说过 Battlecode,但我很兴奋能尝试一下。我玩得很开心,也很高兴自己参与了,我们最终进入了决赛并获得了总排名第 13。所以我决定从第一次参赛者的角度写一篇博客,分享这次经历。
这篇文章也被发布在我的博客上,如果你感兴趣,可以去那里查看更多内容。
什么是 Battlecode?
Battlecode 是一项比赛,你需要编写代码来控制一队机器人,完成某些任务。你可以在这里了解更多关于它的信息。
今年的游戏是一个以鸭子为基础的夺旗游戏,我们的目标是从对方那里夺取 3 面旗帜。游戏在一个网格上进行,每支队伍都有一个基地,在那里它们生成鸭子。鸭子可以在地图上移动、建造陷阱并攻击其他鸭子。游戏分为多个回合,每个回合鸭子都会获得一定量的字节码来执行它们的行动。
这是一个快速展示游戏样貌的片段

在Battlecode 引擎中进行的示例比赛,由作者创建
那个动图里有很多内容,所以让我们稍微分析一下。
鸭子行动
每只鸭子有几个动作,可以帮助攻击或防守旗帜,这些动作每回合/回合内可以执行(如上面动图下方所示)。每个动作都有一个特定的冷却时间,所以你不能一遍又一遍地重复相同的动作。
-
移动到新方格
-
建设陷阱:水、炸弹、眩晕、挖水坑、填水坑
-
拿旗帜
-
攻击敌方小鸭子
-
治疗
-
复活
- 当一只小鸭子死亡后,它将在大约 20 回合后从队伍的一个出生点复活
面包屑
还有一种叫做面包屑的货币,你可以通过两种方式获得:每回合在地图上会生成一些,你可以去捡,另外,你还可以通过在敌方领土上击杀敌鸭来获得。
建设陷阱的动作需要面包屑,这要求你在使用面包屑时必须有策略。
为什么这很困难
由于相同的代码会部署到所有小鸭子上,你必须编写能够应对各种不同情况的代码。也许你希望一些小鸭子专注于防守并围绕基地建造,可能你希望一些小鸭子被指定去捕旗,或者你希望一些小鸭子积极进攻敌方小鸭子。
一些让这项工作变得困难的因素:
-
每只小鸭子都有一个有限大小的共享数组,可以读取和写入
-
这是你与队伍中的其他小鸭子之间沟通游戏状态的唯一方式
-
这意味着你必须策略性地决定在数组中存储哪些信息以及如何使用这些信息,如果你让所有的小鸭子基于相同的信息执行相同的动作,你必须小心,因为否则它们会聚集在一起做相同的事,所以你需要小心如何分配任务
-
-
每只小鸭子每回合被分配 25K 字节码来执行
-
这意味着每回合每只小鸭子可以执行一定数量的动作、计算以及内存读取和写入。
-
如果你想要非常好的路径规划(而我们没有做到),你必须小心如何在这种密集的计算中使用字节码,否则你会用尽字节码,小鸭子就会停在那里什么都不做。
-
幸运的是,我们不需要太担心这个问题,因为 25K 已经很大了,而且据说往年它要小得多。
-
我还觉得很难判断新的机器人是否比以前的版本更好,通常它在某些情况下更好,而在其他情况下则更差。与我通常从事的软件开发不同,我能很容易判断出你开发的东西是否比以前更好,因为如果你添加了新功能,它就比之前更好。
就此而言,我想说,我尝试开发的至少一半功能并没有很好地工作,最后我不得不放弃这些改动。有很多已经死掉的功能分支,根本没有成功。
比赛结构
好了,现在我们对游戏有了基本的了解,接下来我们来谈谈实际的比赛。
排行榜
排行榜是比赛中最重要的部分之一。它是一个半实时排行榜,显示了比赛中所有队伍的排名。每 4 小时,会有与排行榜上相邻队伍的比赛,如果获胜,你将获得更多的积分,失败则会失去积分。
此外,你可以手动排队与其他队伍进行排名比赛,只要他们允许。这一点很重要,因为你不能仅仅反复与同一个队伍排队比赛来刷分。
然而,你也可以排队进行与任何队伍的非排名游戏,这对测试你的机器人与其他队伍对抗的表现非常重要,这样你可以看到你所做的更改是否真的比之前更好,尤其是在与真实对手对战时。
冲刺赛
冲刺赛是比赛的一个独立部分,组织者会进行一系列的比赛并直播结果。冲刺赛是跨所有队伍进行的,不论它们是否在同一赛区。冲刺赛是一种有趣的方式,可以看到你的机器人与其他队伍的表现,同时也是观察其他队伍表现的好机会。
你在冲刺赛中的种子排名也是根据你在排行榜上的排名来确定的,因此如果你在排行榜上表现好,你在冲刺赛中的种子排名也会更高。
比赛
比赛也会进行直播,并在你所在赛区的队伍间进行,就像冲刺赛一样。你在比赛中的种子排名是根据你在排行榜上的排名来决定的,因此如果你在排行榜上表现良好,你在比赛中的种子排名也会更高。
在我们这一年中,排名前 12 的美国大学队伍和前 4 的国际队伍晋级到了决赛。此外,还有高中组和新加入的 Battlecode 队伍,它们有自己独立的比赛。
我们的机器人
好了,现在我们对 Battlecode 有了更多了解,也知道比赛的结构是怎样的,接下来我们来谈谈我们的机器人以及我们使用的一些策略。由于这是比赛结束大约两个月后写的,所以这一部分可能会有些模糊,我不记得我们做的所有事情。
我们的策略
在整个比赛过程中,我们尝试了几种不同的策略和机器人的变体,尤其是在游戏调整平衡和我们了解更多游戏细节后。
-
设置阶段
-
在前 200 轮设置阶段中,我们让所有的鸭子随机移动,尽可能多地探索地图。
-
在那之后,我们让我们的鸭子移动到地图的中心,并在分隔两侧地图的墙旁边建造陷阱。
-
-
基地防御
-
我们指定了几只鸭子,始终驻守在保护我们的旗帜,并在我们的基地周围建造陷阱,尤其是在旗帜本身上,我们在旗帜上放置了炸弹和眩晕陷阱,这样如果敌人试图捡起旗帜,他们就会被眩晕,我们就可以攻击他们。
-
我们还决定在我们的旗帜周围设置一个水池棋盘格图案,这样敌方的鸭子要么需要花费很多时间填充水池,要么只能通过对角线进攻,这确保了敌方的鸭子无法一次性完全包围我们的基地,并且让它们通过几个通道,这样我们的鸭子更容易进行防守。
-
如果一只鸭子在守护旗帜时看到附近有敌方鸭子,它会在共享数组中设置一个全局变量,调用其他鸭子来协助防守旗帜。我们还确保当这些防守点被触发时,鸭子会优先在附近生成,以帮助防守旗帜。
-
但我们必须小心这一点,因为如果我们调动太多鸭子来防守旗帜,就会没有足够的鸭子出去夺取敌方旗帜。
-
-
旗帜传递
-
我们注意到我们的机器人经常会被困在围绕旗帜的鸭子周围,导致旗帜无法真正移动,因此我们实施了一种旗帜传递策略,旗帜会传递给另一只离目标位置更近的鸭子,以便旗帜能够被顺利捕获。
-
我们在比赛后期才加入了这个策略,并且效果不错。
-
-
旗帜归还
-
当鸭子返回旗帜时,我们有逻辑分配附近的鸭子来护送返回旗帜的鸭子,以确保它能安全返回基地。
-
它也旨在返回到离旗帜携带者最近的基地
-
需要改进的地方
-
路径规划
-
我们从前一年机器人的代码中借用了 BugNav 路径规划的代码,但由于当时字节码限制较小,所以我们有更多的字节码可以用于更智能的路径规划,但我们并没有进一步改进。
-
我们的路径规划很糟糕,在冲刺过程中,我们经常因为无法顺利归还旗帜而卡住,常常因此输掉比赛,而解说员会取笑我们😭,但其实也挺好笑的,所以我不介意。
-
-
更好的专长分配
-
有一件事我没提到的是,当鸭子做某个动作足够多次时,它会在这个动作上变得更高效并形成专长。这个专长也会限制它在其他动作上的最高效率,因此我们应该更加小心,因为在大部分比赛中,我们的大部分鸭子都专注于治疗,而我们其实并不需要那么多专注治疗的鸭子,因为这让我们在战斗中变得更弱。
-
在最后,我确实稍微平衡了一下治疗者和战斗者之间的比例,但我认为我们应该以更有计划的方式来做这件事,以最大限度地发挥我们鸭子的优势。
-
-
被困的旗帜鸭子
- 我们没有编写代码来处理鸭子捡起敌方旗帜后卡在某个地方,比如被水环绕的岛屿上,我们懒得修复这个问题,而且这种情况并不常发生,所以我们也没太在意。
-
鸭子对峙
- 有时我们的鸭子会停在那里什么都不做,因为两组鸭子都计算出它们处于劣势,不应采取攻击性策略,这让人很烦恼,我们也没有一个好的方式来处理这种情况。
-
移动旗帜
- 我们尝试在设置阶段移动地图上的旗帜,这样就可以更改旗帜的位置,但我们没有投入足够的时间,因此这个策略并未取得显著效果。
其他队伍能够更有效地在地图上移动他们的旗帜,特别是当他们根据如旗帜周围的墙壁数量以及距离友方和敌方重生点的距离来评分地图的位置时。
- 这绝对是我们应该花更多时间处理的一个功能,因为它对其他队伍来说非常有效。
你可以在这里找到我们的代码。
资格赛
我觉得这个视频很有趣,我们需要赢得这场比赛才能晋级决赛,当我们赢了时我们都非常兴奋。这是我们赢得比赛的视频。

我们在直播中赢得资格赛的画面。视频由作者提供。
决赛
本部分主要讲述我们在波士顿参加决赛的经历,这段经历非常棒,不过如果你想跳过这一部分直接看下面的反思部分也完全没问题。
经历
决赛在波士顿的麻省理工学院举行,我们都飞到了那里参加这个活动,这是我记忆中第一次去波士顿,这座城市真的很漂亮。我们还住在了麻省理工学院的宿舍里,这很令人兴奋,因为我们在宿舍遇到了很多有趣的人。不过,如果让我再来一次,我可能会选择住酒店,因为宿舍毕竟是大学宿舍,环境没有那么干净,床也不够舒适。
我们还都受邀参加了决赛前的晚宴,在那里我们见到了其他队伍和组织者。很有趣的是,我们了解到大家来自哪里,此外,他们还为我们提供了非常好吃的晚餐。
结果
决赛在他们的礼堂举行,观看比赛真的很有趣,尽管我们以 0-2 的成绩结束了比赛,但我们最终排名第 13,这已经是一个巨大的成就,能走到这一步。
他们还举办了一个估算比赛,我们需要猜测某个问题的范围,我最终通过正确估计麻省理工学院宿舍居民的数量,赢得了一只鸭子作为奖励。这里有我抓鸭子的视频:

来自Battlecode 决赛直播的视频
有个有趣的事实是,在我走到台上之前,我把一本笔记本放在了夹克下面,所以视频里看起来我走路的样子有点傻 😭
反思
我绝对不是 Battlecode 的专家,但我确实从这次经历中得到了一些关键的收获,如果我再做一次,有些地方我会做得更好。
投入时间编写有条理且干净的代码
尽管最初看起来这可能是一个相对简单的游戏,但即便在我们这一年比赛相对简单的情况下,与其他年份相比,投入大量时间编写有条理且干净的代码依然至关重要。
我们没有进行代码审查,也没有像应该的那样重构代码,结果在比赛的剩余阶段给我们带来了麻烦,因为我们不得不修复 bug,并试图理解我们的代码到底在做什么。这是我建议大家在开始阶段就花更多时间投入的部分,建立一个好的项目基础和结构,以便你能在此基础上进行扩展,并更轻松地进行修改。
我们没有这样做的部分原因是我们觉得必须赶在冲刺前推出功能,看看自己与其他队伍的对比表现,并且希望自己的比赛能出现在直播中。
冲刺并非一切
这引出了下一个要点,即冲刺并非一切,它们并不那么重要。你绝对应该专注于主要的比赛本身,而不是过于纠结于让你的机器人在冲刺中变得完美。
此外,如果你在冲刺中没有走得太远,也不意味着世界末日。我们从未让我们的比赛出现在冲刺直播中,也肯定没有赢得比赛。至少在我们那年,感觉我们所在的分组(美国大学组)相对较小,如果你没有在冲刺中走得很远,对你的队伍来说也不意味着一切结束,你依然可以晋级决赛。
进行大量对抗赛
正如我之前提到的,要判断你的新机器人是否比旧版本更好其实非常困难,因为你只能在本地与其他代码变体进行测试,在很多情况下,测试自己之前的机器人版本并不太有帮助,因为如果它们太相似,通常会经常平局,或者你会看到你所做的改动在某些情况下更好,而在另一些情况下则更差。
这也是为什么与其他队伍进行大量对抗性比赛非常重要,看看你的改动是否在实际对手面前比以前更好,尤其是这些对手是你在排名赛和锦标赛中可能会遇到的。
直到比赛结束时,另一支队伍分享了他们的做法,我才想到这个点。其实我认为这是一个非常好的主意,可以编写一个脚本,自动匹配其他队伍的比赛,并通过某种评估方式,看看你的改动是否真的比以前更好。一个脚本链接在这里
即使你没有脚本来自动进行这些操作,手动匹配其他队伍的比赛仍然非常重要,这样你可以看到你的改动是否比以前更好。我们没有这样做,但我们仍然匹配了 818 场总计的非排名对抗赛,是所有队伍中第 8 多的。而总对抗赛数量为 1,294 场,是第 7 多的。
观看你的比赛
使用自动化脚本匹配比赛的一个缺点是,你可能不会那么频繁地观看自己的比赛。观看自己比赛的过程非常宝贵,因为你可以看到你的机器人在哪些情况表现得很好,哪些情况表现得很差。
这会给你很多关于如何改进你的机器人和策略的信息。
观察其他队伍
同样重要的是观看其他队伍的比赛(特别是如果他们关闭了自动训练的话),看看他们使用了哪些策略,他们的机器人表现如何。
如果你看到有队伍的表现比你强,或者看到他们使用了一些有趣的策略,不要害怕借鉴并增强他们的策略,使其更适合你的机器人,甚至可能比他们使用的方式更好。
利用社区资源
Battlecode 社区非常乐于助人,而且有很多资源可以帮助你入门并改善你的机器人。很多队伍分享了他们过去几年的代码和策略,这是一个很好的途径,可以帮助你了解其他队伍在做什么。
此外,过去几年的代码也有很多是每年都能使用的,这也是一个很好的起点,可以了解其他队伍在做什么。例如:共享数组通信、路径规划等。这些都是你可以在往年代码中找到并应用到自己机器人上的内容。
给新手的建议
开始接触 Battlecode 可能会让人感到不知所措,因为需要学习很多内容:游戏规范、比赛结构、策略、游戏中的限制、如何在你的机器人之间进行通信、如何本地运行比赛等等。
更不用说开始编写代码可能很困难,因为你必须熟悉现有的方法、可以执行的操作、能否访问某些信息以及如何与之交互,处理这些内容可能会让人感到有些棘手。我强烈建议你使用带有大量自动补全功能的集成开发环境(IDE),这样你可以预览所有可用的方法。它的感觉很像在使用像 Pandas 这样的库,有很多方法,你需要习惯它们的存在,感觉就像在使用一种全新的语言。
一开始这确实让我感到很沮丧,因为我只是想加入一些基本功能,但由于我们的项目组织不够完善,我对游戏不熟悉,也不知道如何为其编写代码,开始时确实有些烦人。但推自己一把,开始动手是值得的,因为它非常有趣,而且你会学到很多。
用开源 LLM 应对开放书籍考试
在每个人都在为工作和学校使用 ChatGPT 的时代,我正在利用它来帮助我学习大学课程。
·发布于 Towards Data Science ·8 分钟阅读·2024 年 7 月 19 日
--

图片由 Dall-E 3 生成,由作者提供提示
免责声明 这不是任何考试的作弊或黑客工具。这只是一个帮助你更好准备课程考试的工具。请明智地使用它。
嗨,我是 Jubayer Hossain,FAU-Erlangen 的硕士生。我的电动出行专业包括机械学、人工智能和编程课程。本学期,我选修了两门课程,专注于开放书籍考试,所有的讲义内容都已提供。
由于我们可以自由使用任何资源,而大规模语言模型(LLMs)现在正变得非常流行,我计划在 Langchain 的帮助下实现一个基于 RAG 的开源 LLM,帮助我进行内容搜索,并更好地为考试做准备。
所以不再赘述,接下来让我们开始项目计划。
项目计划
首先,我需要在课程门户网站中提供的所有讲义幻灯片。大约有 16 个 PDF 幻灯片,我手动下载了这些。当然,我可以写一个脚本,自动下载所有幻灯片,但只有 16 个,手动下载会更快。
贝叶斯定理:用证据理解商业结果
贝叶斯定理的实际介绍:数据科学系列中的概率(2)
·发表于 Towards Data Science ·阅读时间:10 分钟·2024 年 12 月 12 日
--

图片由 Markus Spiske 提供,来源于 Unsplash
如果你不是 Medium 的付费会员,我会免费提供我的文章:朋友链接
贝叶斯定理是统计学中最广泛使用和最受推崇的概念之一。它为概率理论奠定了基础,使我们能够根据新的证据修正预测或假设。
在上一篇关于概率符号的文章中,我介绍了 P(B∣A)—在事件 A 已经发生的条件下,事件 B 发生的概率。
贝叶斯定理翻转了这一视角,专注于 P(A∣B):在 B 事件发生的前提下,A 事件发生的概率。本质上,它通过结合先验信息(已知数据)帮助我们完善对结果的理解。
实际上,即使你最初的假设或估计并不完美,应用贝叶斯定理的过程也鼓励我们对未来做出更加深思熟虑和有依据的猜测!
首先,让我们来看一个受丹尼尔·卡尼曼和阿莫斯·特维尔斯基的著名研究启发的例子。
目录
贝叶斯推断:感知、推理与决策的统一框架
·发布于Towards Data Science ·14 分钟阅读·2024 年 1 月 4 日
--

“…生活中最重要的问题……事实上,大多数都是概率性问题。严格来说,几乎所有的知识都是有可能的。”
— 皮埃尔-西蒙·拉普拉斯,《概率哲学论文》
两百多年前,法国数学家皮埃尔-西蒙·拉普拉斯意识到我们所面临的大多数问题本质上是概率性的,而我们的大部分知识也是基于概率而非绝对确定性的。在这个前提下,他充分发展了贝叶斯定理,这是概率论中的一个基本理论,尽管他并不知道,六十年前,英国牧师托马斯·贝叶斯(同样是统计学家和哲学家)已经描述了这一定理。因此,尽管拉普拉斯完成了大部分数学工作,贝叶斯定理最终还是以贝叶斯的名字命名。
与其悠久的历史相比,贝叶斯定理直到近几十年才引起广泛关注,并在各个学科中得到了显著应用,越来越多的人意识到该定理与我们的感知和认知过程更加契合。它体现了概率的动态调整,既受到新数据的影响,也受到已有知识的影响。此外,它解释了我们思维过程中反复迭代和发展的性质。
贝叶斯线性回归:完整的初学者指南
使用STAN构建贝叶斯回归模型的工作流程和代码演示
·发布于Towards Data Science ·阅读时间 9 分钟·2024 年 9 月 14 日
--
注意:查看我之前的文章,深入讨论为什么贝叶斯建模可能是你任务的正确选择。
本教程将重点讲解如何通过工作流程和代码演示,使用STAN这一概率编程语言构建贝叶斯回归模型。STAN 被广泛采用,并可以与你选择的编程语言(如 R、Python、shell、MATLAB、Julia、Stata 等)进行接口。请查看安装指南和文档。
本教程将使用Pystan,因为我使用 Python 编程。即使你使用其他语言,我将讨论的贝叶斯实践和 STAN 语言语法也不会有太大差异。
对于更注重实践的读者,以下是本教程的笔记本链接,来自我在西北大学(2024 年 4 月)举办的贝叶斯建模工作坊的一部分。
让我们开始吧!
贝叶斯线性回归
让我们学习如何用贝叶斯方法构建一个简单的线性回归模型,这是任何统计学家日常使用的基础模型。假设有一个因变量Y和协变量X,我提出以下简单的模型:
Y = α + β * X + ϵ
其中⍺是截距,β是斜率,ϵ是一些随机误差。假设:
ϵ ~ 正态分布(0, σ)
我们可以证明
Y ~ 正态分布(α + β * X, σ)
我们将学习如何在 STAN 中编写这个模型形式的代码。
生成数据
首先,让我们生成一些虚拟数据。
#Model Parameters
alpha = 4.0 #intercept
beta = 0.5 #slope
sigma = 1.0 #error-scale
#Generate fake data
x = 8 * np.random.rand(100)
y = alpha + beta * x
y = np.random.normal(y, scale=sigma) #noise
#visualize generated data
plt.scatter(x, y, alpha = 0.8)

线性回归生成的数据(图片来自作者代码)
模型字符串
现在我们有了一些数据来建模,让我们深入了解如何构建数据并将其与建模指令一起传递给 STAN。这是通过model字符串完成的,它通常包含 4 个(偶尔更多)块——data、parameters、model和generated quantities。让我们详细讨论这些块的每一个。
DATA 块
data { //input the data to STAN
int<lower=0> N;
vector[N] x;
vector[N] y;
}
data块可能是最简单的,它告诉 STAN 应该预期什么数据,以及以什么格式。例如,这里我们传递-
N:我们的数据集大小,类型为int。<lower=0>部分声明 N≥0。(尽管在这里数据长度不可能为负是显而易见的,但声明这些边界是良好的标准做法,可以使 STAN 的工作更轻松。)
x:作为长度为 N 的向量的协变量。
y:作为长度为 N 的向量的因变量。
请参阅文档以获取完整的支持数据类型范围。STAN 支持多种类型,如数组、向量、矩阵等。正如我们上面所看到的,STAN 还支持对变量进行限制编码。编码限制是推荐的做法!它有助于更好地指定模型,并简化底层的概率采样过程。
模型块
接下来是model块,我们在其中告诉 STAN 模型的结构。
//simple model block
model {
//priors
alpha ~ normal(0,10);
beta ~ normal(0,1);
//model
y ~ normal(alpha + beta * x, sigma);
}
模型块还包含一个重要且常常令人困惑的元素:先验指定。先验是贝叶斯建模的核心部分,必须根据采样任务适当指定。
请参阅我之前的文章,了解关于先验的角色和直觉。简而言之,先验是对参数值分布的假定功能形式——通常简称为先验信念。尽管先验不需要与最终解完全匹配,但它们必须允许我们从中采样。
在我们的示例中,我们使用均值为 0、方差不同的正态先验,这取决于我们对提供的均值的确信度:α的方差为 10(非常不确定),β的方差为 1(有点确定)。在这里,我提供了普遍的信念,即虽然α可以取一个广泛的范围的不同值,但斜率通常会更受限制,并且不会有很大的幅度。
因此,在上面的示例中,α的先验比β的“弱”。
随着模型变得更加复杂,采样解决方案空间扩展,提供信念变得更加重要。否则,如果没有强烈的直觉,好的做法是向模型中提供较少的信念,即使用弱信息性先验,并保持对数据的灵活性。
y 的形式,你可能已经认出来了,是标准的线性回归方程。
生成的量
最后,我们有了生成的量的代码块。在这里,我们告诉 STAN 我们想要计算并作为输出接收哪些量。
generated quantities { //get quantities of interest from fitted model
vector[N] yhat;
vector[N] log_lik;
for (n in 1:N){
yhat[n] = normal_rng(alpha + x[n] * beta, sigma);
//generate samples from model
log_lik[n] = normal_lpdf( y[n] | alpha + x[n] * beta, sigma);
//probability of data given the model and parameters
}
}
注意:STAN 支持将向量直接传入方程中,或者作为每个元素 n 的迭代 1:N。在实践中,我发现这种支持会随着 STAN 的不同版本而变化,因此,如果向量化版本无法编译,最好尝试使用迭代声明。
在上面的例子中——
yhat: 根据拟合的参数值生成 y 的样本。
log_lik: 生成在给定模型和拟合参数值的情况下数据的概率。
这些值的用途将在我们讨论模型评估时更加明确。
将所有内容整合在一起
总的来说,我们现在已经完全指定了我们的第一个简单贝叶斯回归模型:
model = """
data { //input the data to STAN
int<lower=0> N;
vector[N] x;
vector[N] y;
}
parameters {
real alpha;
real beta;
real<lower=0> sigma;
}model {
alpha ~ normal(0,10);
beta ~ normal(0,1);
y ~ normal(alpha + beta * x, sigma);
}generated quantities {
vector[N] yhat;
vector[N] log_lik;
for (n in 1:N){ yhat[n] = normal_rng(alpha + x[n] * beta, sigma);
log_lik[n] = normal_lpdf(y[n] | alpha + x[n] * beta, sigma); }
}
"""
剩下的工作就是编译模型并运行采样。
#STAN takes data as a dict
data = {'N': len(x), 'x': x, 'y': y}
STAN 以字典的形式接收输入数据。重要的是,这个字典必须包含我们在模型数据块中告诉 STAN 预期的所有变量,否则模型将无法编译。
#parameters for STAN fitting
chains = 2
samples = 1000
warmup = 10
# set seed
# Compile the model
posterior = stan.build(model, data=data, random_seed = 42)
# Train the model and generate samples
fit = posterior.sample(num_chains=chains, num_samples=samples)The .sample() method parameters control the Hamiltonian Monte Carlo (HMC) sampling process, where —
-
num_chains: 是我们重复采样过程的次数。
-
num_samples: 是每条链中要抽取的样本数。
-
warmup: 是我们丢弃的初始样本数(因为在达到解空间的一般范围之前需要一些时间)。
知道这些参数的正确值依赖于我们模型的复杂性以及可用资源。
更大的采样规模当然是理想的,但对于一个不适定的模型,它们只是浪费时间和计算资源。从经验来看,我曾经有过需要等待一周才能完成的大型数据模型,结果发现模型并没有收敛。重要的是在运行完整的采样之前,先慢慢开始并对模型进行合理性检查。
模型评估
生成的量用于
-
评估拟合的优度,即收敛性,
-
预测
-
模型比较
收敛性
评估模型的第一步,在贝叶斯框架下,是可视化的。我们观察哈密顿蒙特卡洛(HMC)采样过程的采样结果。

模型收敛:通过独立采样链的重叠情况进行可视化评估(图片来自作者代码)
简单来说,STAN 会迭代地抽取参数值的样本并对其进行评估(HMC 做的工作要复杂得多,但那超出了我们当前的范围)。对于一个良好的拟合,样本抽取必须收敛到某个共同的区域,这个区域理想情况下应该是全局最优解。
上图展示了我们模型在两个独立链(红色和蓝色)上的采样结果。
-
在左侧,我们绘制了拟合的参数值的整体分布,即后验分布。如果模型及其参数指定得当,我们期望看到一个正态分布。(为什么会这样?好吧,正态分布意味着对于参数存在一个最佳拟合值的特定范围,这支持了我们选择的模型形式)。此外,如果模型收敛到最优解,我们还应当期望不同链之间有相当大的重叠。
-
在右侧,我们绘制了每次迭代中实际绘制的样本(仅仅是为了额外确认)。在这里,我们希望看到的不仅仅是一个狭窄的范围,同时也希望看到绘制样本之间有很大的重叠。
并非所有评估指标都是可视化的。Gelman 等人1还提出了Rhat诊断,它本质上是一个数学度量,用于评估不同链之间样本的相似性。通过 Rhat,可以定义一个临界值,超过该值时,两个链被认为过于不同,无法收敛。然而,由于过程的迭代性质和变化的预热期,临界值很难定义。
因此,视觉比较是一个至关重要的组成部分,无论是诊断测试与否。
你可能会有一个频率主义的想法:“那么,如果我们只有链和分布,实际的参数值是什么呢?”这正是关键所在。贝叶斯的公式只处理分布,而非带有难以解释的测试统计量的点估计。
话虽如此,后验分布仍然可以通过可信区间来总结,比如高密度区间(HDI),它包括所有 x%的最高概率密度点。

95% HDI 对β的估计(图片来源:作者代码)
对比贝叶斯可信区间和频率主义置信区间是很重要的。
-
可信区间为参数的可能值提供了一个概率分布,即给定数据时,参数在某个区间内取每个值的概率。
-
置信区间将参数值视为固定的,而是估计反复随机抽样数据时,结果会有多大可能匹配。
因此,
贝叶斯方法允许参数值是流动的,并且直接根据数据的表面意义来进行推断,而频率主义方法则要求存在唯一的真实参数值……如果我们能接触到所有历史数据的话。
呼,稍微消化一下,再读一遍,直到完全理解。
使用可信区间的另一个重要含义,换句话说,允许参数是可变的,就是我们所做的预测能够捕捉到这种不确定性,并且具有透明性,有一定的 HDI 百分比指示最佳拟合线。

95% HDI 最佳拟合线(图片来源:作者代码)
模型比较
在贝叶斯框架中,渡边-赤池信息量度(WAIC)得分是进行模型比较的广泛接受的选择。WAIC 得分的简单解释是,它估计模型的似然,同时对模型参数的数量进行正则化。简单来说,它可以帮助解决过拟合问题。这也是贝叶斯框架的一个主要优势——并不一定需要持有一个模型验证数据集。因此,
当数据稀缺时,贝叶斯建模提供了一个重要的优势。
WAIC 得分是一个比较性指标,即只有在跨不同模型进行比较时,它才具有意义,这些模型试图解释相同的基础数据。因此,在实践中,只要 WAIC 增加,就可以继续增加模型的复杂性。如果在这个增加复杂性的过程中,WAIC 开始下降,那么可以停止——任何更多的复杂性在描述基础数据分布时将不会提供信息上的优势。
结论
总结来说,STAN 模型块只是一个字符串。它向 STAN 解释你将提供给它什么(模型)、需要找出什么(参数)、你认为发生了什么(模型),以及它应该返回什么(生成的量)。
当开启时,STAN 简单地转动机器并给出其输出。
真实的挑战在于定义一个合适的模型(参考先验分布),合理构建数据,准确地向 STAN 提出需求,并评估其输出的合理性。
一旦我们掌握了这一部分,我们就可以深入探讨 STAN 的真正力量,在这里,指定越来越复杂的模型变成了一个简单的语法任务。事实上,在我们的下一篇教程中,我们将正是这样做。我们将在这个简单的回归示例基础上,探索贝叶斯层次模型:行业标准,最先进的技术,事实上的标准…你可以自己定义。我们将看到如何将组级的随机效应或固定效应加入到模型中,并惊叹于在贝叶斯框架中增加复杂性同时保持可比性的容易程度。
如果这篇文章对你有帮助,请订阅,并继续关注更多内容!
参考文献
1 Andrew Gelman, John B. Carlin, Hal S. Stern, David B. Dunson, Aki Vehtari 和 Donald B. Rubin (2013). 贝叶斯数据分析,第三版。Chapman and Hall/CRC。
Python 中的贝叶斯逻辑回归
如何使用贝叶斯方法在 Python 中解决二分类问题。
·发表于Towards Data Science ·8 分钟阅读·2024 年 2 月 20 日
--

贝叶斯思维 — OpenAI DALL-E 生成的图像,作者提供
介绍
在本文中,我将使用 Pyro(一款 Python 概率编程包)构建一个简单的贝叶斯逻辑回归模型。本文将涵盖 EDA、特征工程、模型构建和评估。重点是提供一个简单的贝叶斯逻辑回归框架。因此,前两部分的深度将会有限。本文中使用的代码可以在此处找到:
[## GitHub - fraser-brownn/bayesian_logistic_regression: 用于执行贝叶斯逻辑回归的 Notebook…
用于执行贝叶斯逻辑回归的 Notebook,使用的是心脏衰竭 Kaggle 数据集…
探索性数据分析
我正在使用 Kaggle 上的心脏衰竭预测数据集,链接如下。此数据集是根据 Open Data Commons Open Database License (ODbL) v1.0 提供的。该数据集的完整引用可以在本文末尾找到。
贝叶斯传感器校准
面向传感器工程师的 Python 实践教程
·发表于Towards Data Science ·阅读时长 9 分钟·2024 年 5 月 1 日
--
作者:Moritz Berger的贡献。
贝叶斯传感器校准是一项新兴技术,结合了统计模型和数据,以最优方式校准传感器——这是一个至关重要的工程程序。本教程提供了使用现有库并要求最低数学背景的 Python 代码,来数值地执行这种校准。作为案例研究的示例,我们考虑一个受温度影响的磁场传感器,其灵敏度会发生漂移。
术语表。 粗体词汇在《国际计量学词汇》(即“VIM 定义”)中有定义。仅首次出现的术语使用粗体。
代码可用性。 本教程的可执行 Jupyter 笔记本可在 Github 上获取。可以通过nbviewer访问。
介绍
背景。 物理传感器提供了使系统能够理解其环境的主要输入。它们测量物理量,如温度、电流、功率、速度或光强度。测量结果是对被测量值(即所谓的测量对象)的真实值的估计。传感器从来不是完美的。非理想因素、零件间的差异和随机噪声都可能导致传感器误差。传感器校准及其后续的调整是减少传感器测量不确定性的关键步骤。贝叶斯方法提供了一个数学框架来表示不确定性。特别地,如何通过“智能”校准结合关于过去样本的先验知识和校准所提供的新证据来减少不确定性。即使在简单的情况下,其中传感器响应被建模为带噪声的多项式传递函数,精确的解析解的数学推导也可能让人畏惧(Berger 2022)。幸运的是,Python 库的开发便于贝叶斯统计建模。因此,贝叶斯模型对于工程师而言变得越来越可接近。尽管已有实践教程(Copley 2023; Watts 2020)甚至教科书(Davidson-Pilon 2015; Martin 2021),但它们缺少传感器校准的示例。
目标。 本文旨在重现一个简化的案例,灵感来源于(Berger 2002)并在下图中进行了说明。传感器用于通过测量与电流直接成正比的磁场B来测量通过电线的电流i。我们专注于磁传感器,并考虑以下几种非理想因素。(1)温度B是一个寄生的影响量,会干扰测量。(2)传感器响应因零件之间的制造差异而有所不同。(3)传感器数据受到随机误差的污染。通过贝叶斯方法和高层次的PyMC Python 库,我们旨在计算给定校准数据集的最佳校准参数集。

(a) 电流传感器应用。(b) 功能视图(图片来自作者)。
数学公式
我们假设磁性传感器由一个磁性和温度传感器组成,可以将其建模为具有系数的多项式传递函数,这些系数根据正常的概率分布在不同传感器之间变化。原始感测数据(在 VIM 中称为“指示值”),由向量u表示,包含线性感测的磁场S(T)⋅B和一个无量纲的感测量V_T,指示温度。我们使用特定形式S(T)⋅B来突出显示灵敏度S受温度的影响。温度的寄生影响在下面的图(a)中有所说明。理想情况下,灵敏度应当与温度和V_T无关。然而,存在一个多项式依赖关系。这个案例灵感来源于实际的磁性霍尔传感器,在温度范围[−40°C, +165°C]内,灵敏度可以在室温值的基础上变化±40%。此外,由于部件间的差异,存在一组S与V_T的曲线,而不仅仅是一条曲线。从数学上讲,我们希望识别出一个测量函数,该函数能够准确地估计磁场的真实值——如图(b)所示。从概念上讲,这相当于反转传感器响应模型。这归结为估算温度依赖性灵敏度,并使用此估算值Ŝ通过除法从感测场中恢复磁场。

(a) N=30 个传感器的原始响应。(b) 阻抗图(图像来自作者)。
对于我们的简化案例,我们假设S(T)和VT(T)是二次多项式。假设多项式系数围绕其标称值变化,符合正态分布。另一个随机噪声项被添加到两个感测信号中。从物理上讲,S是相对于室温值的灵敏度,而VT是来自温度传感器的归一化电压。这代表了一类大型传感器,其中主传感器是线性的,但依赖于温度,而补充的温度传感器用于修正这种寄生依赖关系。同时,两个传感器都有噪声。我们假设在VT中的三次多项式是适合用来估计灵敏度Ŝ的候选函数:
Ŝ = w_0 + w_1⋅ΔT + w_2⋅ΔT² + w_3⋅ΔT³,其中ΔT = T−25°C。
权重向量w聚合了多项式的 4 个系数。这些是需要根据标定结果调整的标定参数。
Python 公式
我们使用(Close, 2021)中介绍的代码约定。我们定义一个数据字典dd来存储参数的标称值。此外,我们定义概率密度函数来捕获参数的变异性。传感器响应被建模为传递函数,就像(Close 2021)中介绍的约定一样。
# Data dictionary: nominal parameters of the sensor response model
def load_dd():
return Box({
'S' : {
'TC' : [1, -5.26e-3, 15.34e-6],
'noise': 0.12/100, },
'VT': {
'TC': [0, 1.16e-3, 2.78e-6],
'noise': 0.1/100,}
})
# Probability density functions for the parameter variations
pdfs = {
'S.TC': (norm(0,1.132e-2), norm(0,1.23e-4), norm(0,5.40e-7)),
'VT.TC' : (norm(0,7.66e-6), norm(0,4.38e-7))
}
# Sensor response model
def sensor_response_model(T, sensor_id=0, dd={}, delta={}):
S=np.poly1d(np.flip((dd.S.TC+delta.loc[sensor_id]['S.TC'])))(T-25)
S+=np.random.normal(0, dd.S.noise, size=len(T))
VT = 10*np.poly1d(np.flip(dd.VT.TC+np.r_[0,delta.loc[sensor_id]['VT.TC'].values]))(T-25)
VT+= np.random.normal(0, dd.VT.noise, size=len(T))
return {'S': S, 'VT': VT}
然后,我们可以通过从指定的概率分布中抽样,模拟一组N=30 个传感器,并生成合成数据df1,以通过构建函数build_sensors(ids=[..])测试不同的标定方法。
df1,_=build_sensors_(ids=np.arange(30))

由概率传感器响应模型生成的合成数据(图片来自作者)。
经典方法
我们首先考虑两种不依赖贝叶斯框架的经典标定方法。
完全回归
第一个标定方法是一种暴力方法。为每个传感器收集全面的数据集,校准点数多于未知数。每个传感器的标定参数w(4 个未知数)通过回归拟合来确定。当然,这种方法在残差误差方面提供了最佳的结果。然而,实际上这种方法非常昂贵,因为它需要对每个单独的传感器进行全面表征。以下函数执行完全标定,并将权重作为列表存储在数据框中以方便使用。
def full_calibration(df):
W = df.groupby("id").apply(
lambda g: ols("S ~ 1 + VT + I(VT**2)+ I(VT**3)", data=g).fit().params
)
W.columns = [f"w_{k}" for k, col in enumerate(W.columns)]
df["w"] = df.apply(lambda X: W.loc[X.name[0]].values, axis=1)
df1, W=full_calibration(df1)
盲标定
盲标定代表了另一个极端。在这种方法中,首先对一组参考传感器进行完全标定,如上所述。接下来的传感器不单独标定。相反,直接“盲目”地使用参考组的平均标定参数w0。
w0 = W.mean().values
df2,_=build_sensors_(ids=[0])
def blind_calibration(df):
return df.assign(w=[w0]*len(df))
df2 = blind_calibration(df2)
以下图表展示了两种方法的残差灵敏度误差 Ŝ−S。回想一下,标定前的误差高达 40%。绿色曲线表示参考组中N=30 个传感器的灵敏度误差。除了残余的四阶误差(由于灵敏度估计器阶数有限,这是不可避免的),拟合效果令人满意(<2%)。红色曲线表示盲标定传感器的残差灵敏度误差。由于部件间差异,平均标定参数只能提供近似的拟合,残差误差不令人满意。

N=30 个完全标定和 N=1 个盲标定传感器的残差灵敏度误差(图片来自作者)。
贝叶斯标定
贝叶斯标定是一种有趣的折衷方法,介于之前的两种极端方法之间。一组参考传感器像上面一样被完全标定。这组参考传感器的标定参数构成了一些先验知识。参考组的平均值w0和协方差矩阵Σ编码了传感器响应的相关知识。权重不是独立的,一些组合比其他组合更有可能。这样的知识应该在智能标定中加以利用。可以使用 Pandas 和 Seaborn 库计算并绘制覆盖矩阵(仅针对两个权重)。
Cov0 = W.cov(ddof=len(W) - len(w0))
sns.jointplot(data=W.apply(pd.Series),x='w_1', y='w_2', kind='kde', fill=True, height=4)

两个权重 w_1 和 w_2 的双变量图(图片来自作者)。
贝叶斯框架使我们能够捕捉这种先验知识,并在后续样本的校准中加以利用。我们从之前同样被盲目校准的样本开始。我们模拟了一个情况,即每个新传感器仅收集两个校准数据点,分别是 0°C 和 100°C,从而通过新证据丰富我们的知识。在硬件校准成本昂贵的实际工业场景中,这种校准方法非常具成本效益。参考集在初次收集时已充分表征,以便获得先验知识。随后样本,可能是该批次的大部分生产样本,仅在几个点上进行表征。在贝叶斯术语中,这被称为“推断”,而 PyMC 库提供了高级函数来执行推断。这是一个计算密集型的过程,因为后验分布是通过应用贝叶斯定理,将先验知识和新证据结合起来获得的,只能通过采样获得。对于获得的概率密度函数没有解析近似。
下文对比了校准结果,蓝点表示贝叶斯方法使用的两个校准点。通过仅添加两个额外的点,并利用参考集中的先验知识,贝叶斯校准传感器的误差几乎没有退化,相较于昂贵的暴力方法,表现得更为优越。

三种校准方法的对比。
可信区间
在贝叶斯方法中,所有变量都被不确定性所表征。传感器模型的参数、校准参数,以及后验预测。然后我们可以构建一个±1σ的可信区间,覆盖由模型生成的 68%合成观测数据,针对估算的灵敏度Ŝ。此图捕捉了校准和调整的本质:在T=0°C 和T=100°C 两个校准点周围的不确定性已被减少。剩余的不确定性源于测量噪声。

可信区间(图源:作者)。
结论
本文展示了一个用于模拟贝叶斯传感器校准的 Python 工作流程,并将其与广为人知的经典方法进行了对比。该数学和 Python 公式适用于广泛类别的传感器,能够帮助传感器设计探索各种方法。该工作流程可以总结如下:
-
传感器响应建模通过传递函数及其参数(标称值和统计变动)进行建模。为一批传感器生成相应的合成原始传感数据。
-
定义测量函数的形式,从原始传感器变量开始。通常,这是一个多项式,校准应为每个传感器确定该多项式的最佳系数w。
-
获取一些先验知识,通过对一个具有代表性的传感器子集进行全面表征。以平均校准参数和协方差矩阵的形式编码这些知识。
-
获取有限的新证据,以每个传感器特定的一小部分校准点的形式进行。
-
执行贝叶斯推断,将这些新的稀疏证据与先验知识合并,使用 PyMC 数值计算找到此新传感器最可能的校准参数。
在传感器校准对生产成本有显著影响的频繁情况下,贝叶斯校准展现出巨大的商业优势。考虑一批 1'000 个传感器。可以通过对例如仅 30 个传感器进行完整表征,获得代表性的先验知识。然后对于其他 970 个传感器,只需使用少量的校准点。在经典方法中,这些额外的校准点会导致一个不确定的方程组。在贝叶斯框架下,先验知识填补了这一空白。
参考文献
(Berger 2022) M. Berger, C. Schott, 和 O. Paul, “贝叶斯传感器校准,”IEEE Sens. J., 2022 年。doi.org/10.1109/JSEN.2022.3199485.
(Close, 2021): G. Close, “Python 中的信号链分析:硬件工程师的案例研究,”Towards Data Science, 2021 年 2 月 22 日。可用:towardsdatascience.com/signal-chain-analysis-in-python-84513fcf7db2.
(Copley 2023) C. Copley, “使用 PyMC 进行贝叶斯分析导航,”Towards Data Science, 2023 年 6 月。charlescopley.medium.com/navigating-bayesian-analysis-with-pymc-87683c91f3e4
(Davidson-Pilon 2015) C. Davidson-Pilon, “黑客的贝叶斯方法:概率编程和贝叶斯推断,”Addison-Wesley Professional, 2015 年。www.amazon.com/Bayesian-Methods-Hackers-Probabilistic-Addison-Wesley/dp/0133902838
(Martin 2021) O. A. Martin, R. Kumar, 和 J. Lao, “Python 中的贝叶斯建模与计算,”Chapman and Hall/CRC, 2021 年。www.amazon.com/Bayesian-Modeling-Computation-Chapman-Statistical/dp/036789436X
(Watts 2020) A. Watts, “PyMC3 和贝叶斯推断在参数不确定性量化中的应用:非线性模型的探索:第二部分,”Towards Data Science,2022 年 6 月。towardsdatascience.com/pymc3-and-bayesian-inference-for-parameter-uncertainty-quantification-towards-non-linear-models-a03c3303e6fa
使用混合 AI 模型击败 ChatGPT 4 下棋
LLM 能多好地解决复杂问题
·发表于 Towards Data Science ·阅读时长 7 分钟·2024 年 1 月 22 日
--

图片由作者提供:使用 DALLE-3 生成的机器人下棋场景
ChatGPT 真的能下棋吗?这是激励我让 ChatGPT 和我的混合 AI 模型——一个国际象棋专家机器人——进行对弈的原因。第一局比赛是与 GPT 3.5 对弈,在这场比赛中,我发现了 OpenAI LLM 模型的几个局限性——由于 ChatGPT 对国际象棋规则的理解不足,很多非法走法和错误分析,使得整场比赛非常难以进行下去,直到比赛结束。
这个分析对于理解大型语言模型(LLMs)的局限性、它们的长期推理能力和分析能力非常重要。通过深入了解模型的行为,我们可以找到解决其缺陷并增强其优势的方法。作为 AI 工程师,我们必须始终设置不同的实验来分析模型的真实行为,并计划在我们的项目中进行适应和改进。大型语言模型技术仍然非常新,必须不断探索和研究,以确保其最佳使用和理解。
在这篇文章中,您可以找到关于第一次比赛的更多细节:
用人工智能击败四子连珠
使用蒙特卡罗模拟的简单方法
·发表于 Towards Data Science ·阅读时间 8 分钟·2024 年 8 月 28 日
--
我喜欢游戏。国际象棋、拼字游戏,你能想到的都喜欢。然而,我非常糟糕的游戏之一就是简单的四子连珠。出于某种原因,再加上我想尝试数据科学更实际的一面,这让我萌生了构建一个能够高水平玩四子连珠游戏的简单人工智能的想法。
这里显而易见的问题是,如果我在四子连珠游戏中表现很糟糕,那我怎么能构建一个能玩好这个游戏的人工智能呢?这就引出了蒙特卡罗模拟。蒙特卡罗模拟是数据科学中的一种强大工具,使用随机采样来估算复杂的结果。这种稳健的方法有着出乎意料的广泛应用,从数值积分到金融建模,甚至我们将要探讨的四子连珠游戏。
在本文中,我将简要介绍蒙特卡罗模拟,然后深入探讨如何将其应用于四子连珠游戏,最后将所有内容整合并分享一些代码。如果你愿意,我还会给你一个机会亲自与人工智能对战,看看你的表现如何。
出发吧!

图片由作者提供。(AI 生成)
蒙特卡罗方法介绍:
蒙特卡罗采样法的想法其实很简单——如果你有一个无法通过解析方法解决的问题,为什么不进行随机实验并尝试估算一个数值答案呢?如果现在这还不太理解,别担心,我们很快就会看一个例子。但在此之前,让我们先把历史搞清楚。
蒙特卡罗方法的背景故事相当有趣。该方法的主要开发者是物理学家斯坦尼斯瓦夫·乌拉姆,他非常著名,曾参与曼哈顿计划开发原子弹。与我们故事相关的是斯坦尼斯瓦夫的叔叔,他有一个不幸的赌博习惯,这导致斯坦尼斯瓦夫将这种新计算方法命名为“蒙特卡罗”,以此纪念摩纳哥著名的蒙特卡罗赌场。
现在,回到我之前承诺给你们的关于生成随机样本的例子。
一个实际示例
假设我们想要找到半径为 1 的圆内的面积。这个圆的实际面积当然是我们熟悉的公式πr²,由于 r 是 1,面积就是π。但如果我们不知道π呢?我们如何通过生成随机实验来按照蒙特卡罗方法得到这个答案?
首先,在区域 -1 < x < 1 和 -1 < y < 1 中模拟随机点。然后,对于每个点,记录它是否在圆内或圆外。下面我为 10 个、100 个、1000 个和 10,000 个随机坐标创建了这样的模拟。
你可以看到,只有 10 个点时,圆的面积(或其占据的比例)非常粗略,但随着我们添加更多的点,位于圆内的点的比例变得越来越一致。

图片由作者提供。随着点数的增加,我们得到的圆占据的总空间比例测量值更加精确。
现在你可能会问,好吧,这些图表都很漂亮,但是实际的收获是什么呢?这是一个很好的问题。
注意,我们最终得到的是在圆内的模拟比例的估算吗?好吧,我们知道正方形的面积将是 2 x 2 = 4,接下来我们可以通过将这个比例乘以 4 来估算π,因为圆的面积就是π。
下表总结了结果。注意,随着模拟次数的增加,π的估算值越来越接近真实值。

图片由作者提供
我们当然可以通过更多的模拟做得更好。以下代码片段运行一亿个样本,通常能给出一个精确到小数点后三位的结果:
import numpy as np
n = 100_000_000
points = np.random.rand(n, 2)
inside_circle = np.sum(points[:,0]**2 + points[:,1]**2 <= 1)
pi_estimate = (inside_circle / n) * 4
print(pi_estimate) # prints 3.141x
这里的关键收获是,通过生成随机模拟(我们的坐标对),我们可以为一个已知的量得到一个出奇精确的估算!这是我们第一个关于蒙特卡罗方法的实际示例。
游戏方法的翻译
这个方法很棒,但我们并不想计算π,我们想制作一个能够玩“四子棋”的 AI!幸运的是,我们刚刚用来计算π的逻辑也可以应用于四子棋游戏。
在上面的示例中,我们做了两件事,首先我们生成了随机样本(坐标对),然后第二,我们近似了一个量(π)。
好的,我们将在这里做同样的事情。首先,我们像之前那样生成随机样本,但这次这些随机样本将选择随机行动,从而模拟整个四子棋游戏。
然后,第二步,我们将再次近似一个数量,但我们追求的数量是每个行动的获胜概率。
规则简要回顾
在我们开始创建模拟之前,先快速回顾一下四子棋的规则。玩家轮流将彩色棋子放入 7x6 棋盘上任何未填充的列中。游戏在任意一方玩家的棋子连续排成四个时结束,或者当棋盘填满且未分胜负时,游戏以平局结束。
四子棋的蒙特卡罗方法
好的,现在我们已经掌握了理论,是时候将其付诸实践,教 AI 玩四子棋了。为了在四子棋游戏中找到正确的行动,我们:
-
随机抽取每一个可能的合法行动。(选择将棋子放入哪一列)。
-
然后,从这个位置模拟整个游戏,假设双方玩家的行动完全随机。
-
跟踪每场随机游戏的结果,以计算每个行动的胜率。
-
最终,选择胜率最高的行动。
听起来很简单,实际上也是如此!
为了实际展示这个方法,以下是我编写的四子棋游戏 Python 实现。这里有些复杂的部分,但不用担心,如果一时没弄明白——实际的实现细节并不如概念本身重要!
话说回来,对于那些感兴趣的人,这种方法利用了面向对象编程,包含一个能够在棋盘类(Board class)上执行行动的玩家类(Player class)。
实际操作是这样的:我们从一系列可能的有效行动开始,并从中随机选择。对于每一个行动,我们调用_simulate_move函数,它将从该点开始模拟整局游戏并返回获胜符号。如果该符号与 AI 玩家的符号匹配,我们便增加胜利次数。在进行大量模拟后,我们计算每个行动的胜率,最终返回胜率最高的行动。
def _get_best_move(self, board: Board, num_sims: int):
# Get a list of all viable moves
win_counts = {column: 0 for column in range(board.width) if board.is_valid_move(column)}
total_counts = {column: 0 for column in range(board.width) if board.is_valid_move(column)}
valid_moves = list(win_counts.keys())
for _ in range(num_sims):
column = random.choice(valid_moves) # Pick a move a random
result = self._simulate_move(board, column) # Simulate the game after making the random move
total_counts[column] += 1
if result == self.symbol: # Check whether the AI player won
win_counts[column] += 1
win_rates = {column: win_counts[column] / total_counts[column] if total_counts[column] > 0 else 0 for column in valid_moves}
best_move = max(win_rates, key=win_rates.get) # Find the move with the best win rate
return best_move
总结来说,通过模拟随机行动并跟踪游戏进程,这种蒙特卡罗方法帮助 AI 开始做出比单纯猜测更聪明的决策。
一些实际示例:
好吧,代码说完了!让我们来测试一下 AI,看看它在几种不同的局面下会表现如何。接下来我们将通过两个不同的局面,展示上述代码块的结果。第一个局面非常简单,第二个则稍微复杂一些。

作者插图
现在轮到红方了,显而易见最好的行动是将棋子放在第 5 列。如果我们使用上述方法从这个位置模拟 1000 场随机游戏,AI 玩家得到了以下胜率。将棋子放在第 5 列每次都能获胜(正如预期的那样!),因此选择了这个行动。

作者提供的结果表格。此表显示了基于采样走法的随机游戏的胜率。加粗的走法是由人工智能玩家选择的。
太棒了!我们的人工智能可以在有机会时识别出获胜的走法。这个场景很简单,是的,但老实说,我之前在游戏中错过了很多获胜的机会……
现在,让我们来看另一个局面。这一局稍微复杂一点。你能想到红方该如何下棋,才能防止黄方获得获胜的优势吗?

图片由作者提供。
这里的关键是要防止黄方形成一个有开放边的三连排布局,因为那样会导致胜利。红方需要通过在第 3 列或第 6 列下棋来阻挡这一点!我们从这个位置模拟 1000 局游戏,得到了如下的胜率。注意,人工智能正确地识别出这两种阻挡走法(第 3 列和第 6 列)具有最高的胜率。而且,它还意识到第 6 列的获胜机会最大,因此选择了第 6 列。

作者提供的结果表格。此表显示了基于采样走法的随机游戏的胜率。加粗的走法是由人工智能玩家选择的。
实际效果如何呢?
亲自体验一下吧!你可以在这里挑战人工智能:fourinarowgame.online/。难度是根据模拟次数来调整的。简单模式模拟 50 局,中等模式模拟 500 局,困难模式模拟 1500 局。就个人而言,我通常能够在简单模式下获胜,但仅此而已!
结论
好的,让我们将这些内容综合起来。在写这篇文章时,我真的想做两件事。首先,我想展示蒙特卡洛方法在像通过模拟随机坐标来估算π这样直接计算中的强大能力。
接下来,更有趣的是,我想展示相同方法在棋盘游戏中的强大效果。有趣的是,尽管对四连棋的策略一无所知,完全可以通过模拟随机游戏,最终得到一个能够以相当高水平进行对弈的人工智能对手!
一如既往,感谢阅读,下次见。
成为一名数据科学家:如果我必须重新开始,我会做什么
进入数据科学:好与坏,以及 Python 中的错误
·发布于Towards Data Science ·10 分钟阅读·2024 年 12 月 3 日
--

图片由Markus Spiske提供,拍摄于Unsplash
马丁·路德·金博士因其演讲“我有一个梦想”而闻名。他于 1963 年 8 月 28 日在华盛顿特区的林肯纪念堂前,向约 250,000 人发表了这篇演讲。这被认为是 20 世纪最重要的演讲之一。它在美国黑人民权运动中发挥了至关重要的作用。
在这次演讲中,他说他梦想着有一天,他的四个孩子能生活在一个这样一个国家,在那里,人们不会因为皮肤的颜色而被评判,而是根据他们的品格来评判。
几年前我也有一个梦想。虽然它不像马丁·路德·金的梦想那样光辉灿烂或重塑了历史进程,但我渴望成为一名数据科学家。
我从来不是为了声望或因为这是一种时尚(至今依然如此),而是因为我真心喜欢处理数据,解决复杂问题,并利用洞察力推动业务成果。成为一名数据科学家是我的独特技能和热情交汇的地方。你知道的,就是那个通向充实职业的甜蜜点。
我的旅程并不简单。我不知道从哪里开始,也不知道接下来该做什么。我参加了各种课程,其中许多最终证明并没有什么帮助。我还阅读了无数关于数据科学的文章。虽然成为一名数据科学家需要付出艰苦的努力,但我却花了很多时间在一些最终并不必要的事情上。
我希望有人在我开始分享的这些指导建议之前就曾给我一些帮助。这就是本文的目的。好消息是?遵循这些步骤并不能保证你能找到数据科学家的工作,但它们会显著提高你的机会……即使没有博士学位!我认识一些没有博士学位的专业人士,他们在数据科学领域表现出色。这个领域的成功主要依赖于毅力和实践经验。
从某个地方开始,现在就开始
—— 柏拉图
研究表明一个幼儿在 2 到 3 个月内,每天大约走 14,000 步,经历 100 次跌倒,直到掌握走路技巧。然而,他们依然坚持,根本没有考虑放弃。
相比之下,作为成年人,我们往往做的是相反的事情。我们一遇到障碍就会放弃。一个成年人可能会看到 100 次失败,而一个婴儿却看到 100 个学习的机会。婴儿不会过度分析自己的失败,也不会过度计算风险。它只是开始,尝试,跌倒,然后再试一次!
考虑一下 Justin Kan 的故事,他是 Twitch 的联合创始人。他的创业之旅并不是从一场轰动的成功开始的。它始于他所谓的“糟糕的第一个创业公司”,一个名为 Kiko 的在线日历应用。Kiko 曾与 Google Calendar 等巨头竞争,但最终它在 eBay 上以 258,100 美元的价格卖出!
接下来,他推出了 Justin.tv,一个 24/7 直播自己生活的平台。Justin.tv 最终变成了 Twitch,一个专注于游戏的直播平台。在 2014 年,亚马逊以 9.7 亿美元收购了 Twitch!
正如 Justin Kan 所说:“不要等待,赶紧建立你第一个糟糕的创业公司。”
这一建议同样适用于你进入数据科学的旅程。从某个地方开始。现在就开始你的学习过程。即使你第一次尝试觉得“糟糕”,而且你不确定从哪里开始,也没关系。你可以在初步的努力基础上不断发展,随着进展调整方向,没人能阻止你。你需要现在就开始,并从某个地方开始。

图片由Vlad Bagacian提供,来自Unsplash
那么……我从哪里开始呢?
法国的博韦大教堂原本计划在 13 世纪成为世界上最高的教堂。其雄心勃勃的设计突破了哥特式建筑的极限。然而,1284 年发生了一次显著的倒塌事故,合唱团的拱顶由于基础不充分和结构支撑不足而坍塌。至今未完成。
这为你进入数据科学的旅程提供了强有力的类比。你可能会忍不住(我们大家都有过)直接跳到那些令人兴奋的部分,比如深度学习模型、大型语言模型(LLM)或最新的机器学习框架。但就像博韦大教堂一样,如果没有坚实的基础,你的雄心勃勃的计划也许会失败。首先学习基础知识至关重要,以确保你的知识足够牢固,能够支持更高级的概念。
数学:你的通用语言
将数学视为模式的语言。数学无处不在。老实说,如果你不喜欢数学,也许数据科学的职业道路并不适合你。
你不需要成为数学家,但你确实需要理解以下关键概念:
-
线性代数(矩阵、向量等):将矩阵和向量视为数据交流的语言。理解这些概念使你能够操控数据结构以适应机器学习算法。
-
微积分(微分、积分、梯度等):它们对于优化模型至关重要,比如在训练神经网络时使用的梯度。
-
统计学(分布、描述性统计等):这是你学习如何解读数据所讲述的故事的地方。理解分布和描述性统计等概念让你能够基于数据中的模式做出明智的决策。
深入编程
在拥有数学基础后,编程将让你的想法得以实现。虽然有人主张在数据科学中学习 R,但 Python 凭借其多功能性和在行业中的广泛应用脱颖而出。此外,我认识的大多数人都在使用 Python。它对于大多数应用场景已经足够好了。专注于:
-
基本语法和函数:了解 Python 的基本工作原理。这就像在写故事之前学习字母表。
-
数据结构:列表、字典、元组——了解如何使用它们。它们对于处理现实世界中的数据至关重要。
-
控制流语句:掌握“if 语句”,“for 循环”和“while 循环”。这些语句让你能够实现逻辑,从而解决复杂的问题。使用简单的语句,你能做成比你想象的更多的事!
-
面向对象编程:理解类、函数和对象的概念。这让你能够编写高效、可重用的代码,同时也促进与他人的协作。
SQL:你的数据库语言
数据通常存储在数据库中,你需要访问和操作这些数据。SQL 是你与这些数据交互的语言。
- 与数据库交互:学习基本的 SQL 命令,以便检索、更新和管理数据。
机器学习:将数据转化为洞察
接下来,在理解了数学、编程和数据处理之后,你可以开始学习机器学习。重点是:
-
理解算法:从学习线性回归、决策树和聚类方法等基础算法开始。这些是更复杂模型的基础。
-
监督学习与非监督学习:理解这两种核心机器学习类型之间的区别。监督学习是用标注数据训练模型,而非监督学习则使用未标注数据。
-
模型评估:学习如何使用诸如分类模型的 F1 分数、语音识别的词错误率或时间序列分析的 RMSE 等指标来评估模型的表现。
-
特征工程:这是一门将原始数据转换成模型能够理解的形式的艺术。通常,这比使用复杂的算法更能带来显著的影响。你可以在这里看到一个示例。
-
库和框架:熟悉一些流行的 Python 机器学习库,如 scikit-learn、TensorFlow 和 PyTorch。
记住,机器学习不仅仅是应用算法。它是关于理解你要解决的问题,并选择正确的方法。
商业敏锐度:将技术技能转化为商业影响力
很多人联系我,想要开始数据科学的职业生涯。他们通常拥有令人印象深刻的资历,如博士学位和扎实的数学背景。然而,即便有这些出色的资历,许多人仍然难以进入这个领域。原因是什么?他们缺乏商业敏锐度。
技术技能至关重要。然而,事实是,如果一个 AI 模型不能解决业务问题,它的价值为 0。我见过许多才华横溢的数据科学家失败,因为他们构建了没人使用的复杂模型。关键是什么?学会像一个企业主一样思考。
例如:
-
将业务问题转化:与其仅仅构建一个预测模型,不如问问自己,“这个模型如何支持业务中的决策制定?”
-
优先考虑影响:专注于数据科学能产生最大价值的问题,而不是追求那些不能解决业务问题的复杂解决方案。
聚焦于核心要素
维尔弗雷多·帕累托是意大利的多面手,他在经济学、社会学等多个领域做出了贡献。他为人熟知的概念之一是帕累托最优性。它描述了一种资源分配最经济高效的情况,任何人都不能在不使别人更糟的情况下变得更好。
然而,他最著名的观察之一是在研究意大利的财富分配时发现的。他发现 20%的人口拥有 80%的土地。他还在普鲁士、英格兰、法国等地注意到同样的模式。
这一观察结果导致了我们今天所知道的帕累托原则,或称 80/20 法则。换句话说,20%的原因导致了 80%的结果。
例如,在商业中,通常观察到 80%的销售来自 20%的客户。在质量控制中,80%的问题由 20%的缺陷引起。在职场中,20%的任务贡献了 80%的工作成果。我们往往使用我们拥有的物品中的 20%来完成 80%的工作。而这个规律还在不断延续。
同样的思路也适用于你成为数据科学家的旅程。与其试图掌握所有可能的主题,不如专注于每个关键领域上一门课程:数据科学的数学、Python、SQL、机器学习和商业分析。就这样。专注于核心的 20%的技能(甚至更少),这将产生 80%的成果。
记住,不要陷入“教程地狱”的陷阱,即不断消费新的内容,却从未真正理解所学的东西。成为一名熟练的数据科学家,和任何其他工作一样,最重要的是积累经验。这意味着将你学到的东西应用到现实项目中。
当你不理解某个问题时,去查找它,学习它,然后再回到你的项目中。重复这一过程,直到你的知识和技能得到了足够的巩固。

图片由Austin Distel提供,来源于Unsplash
创建你自己的工作经验
“经验是万物的老师。”
— 朱利叶斯·凯撒
完成基础课程后,通过将所学知识应用到实际项目中,提升你的技能。
在任何领域建立专业知识都需要巨大的奉献精神和实践。埃里克森、克拉姆普和特施-罗默的研究指出,发展任何领域的专业技能通常需要大约 10,000 小时的刻意练习。顶级表演者,如音乐会音乐家和职业运动员,通常每天会投入约四小时的专注练习来完善他们的技能。
这个原则同样适用于数据科学。掌握一门技能不是一蹴而就的,它需要持续的努力和经验积累。通过每天花时间应用你学到的知识并解决现实问题,你离成为该领域的专家会越来越近。
好吧……但我该如何积累经验呢?
这比大多数人想象的要简单。然而,很多人会因为无法找到“完美”的起点而陷入困境。正如我之前所说,最关键的一步是现在开始,并且从某个地方开始。犯错并根据学习调整方法是完全可以的。
你的职业背景并不是限制,即使它不是数据科学相关的。实际上,恰恰相反,它是一项资产。
每个领域,无论是市场营销、医疗、金融还是法律,都有可以通过数据解决的问题。市场营销人员可能会分析客户参与模式。拥有金融背景的人可能希望预测股市。
我曾经建议过一个有金融背景的学员,他不知道从哪里开始。我建议他创建一个 ARIMA 模型来预测加拿大房价(ARIMA 是一个相当简单的模型)。
这并不是什么突破性的东西,但它是真实且相关的。它不仅利用了他的领域专长和技术技能,而且这个人专注于一个需求量大的话题(加拿大房价)。
如果你仍然不确定,从你真正喜欢的事情开始。这是关键。当你真正感兴趣时,你很可能会经历我们之前讨论的那 10,000 小时的练习。你也更可能以决心应对挑战,将挫折视为学习的机会,而不是放弃的理由。
它可以是任何事情。如果你是艺术家,你可能会使用计算机视觉分析视觉模式,或者用神经网络创造生成艺术。医疗工作者可能希望预测病人的结局。环境科学领域的人可能使用大型数据集来模拟气候变化的影响。例子不胜枚举。
如果可能,考虑使用大型语言模型(LLMs)。这绝对不是强制性的。然而,LLMs 最近变得非常流行,尤其是在 2022 年底 ChatGPT 发布之后。公司们正在迅速采纳它。它为在这一前沿领域中发展专业技能提供了极好的机会。
有几个框架可以用来构建基于 LLM 的应用程序。其中之一是 LangChain。但再次强调,LLM 应该是对你基本机器学习理解的补充,而不是替代。如果你觉得 LLM 太复杂,可以从简单的东西开始。
一旦你创建了某个东西,就与世界分享。可以在 Medium 上写文章,或者将你的代码发布在 GitHub 上。这将展示你的工作。从一个基本模型或项目开始,然后逐步改进。
例如,你可以从一个简单的 ARIMA 模型开始,预测房价。然后,你可以转向一个更复杂的多变量模型(如基于 Transformer 的时间序列模型)。你可以加入如利率、收入与负债比率、失业率等特征。最后,你可以将该模型与基准模型进行比较。
当你加入额外的功能或完善算法时,请更新你的 GitHub 仓库并写后续文章,记录你的进展。这展示了你的技能和对持续学习的承诺。这是学习和展示你能力的最佳(如果不是最好的)方式之一。
结论
感谢阅读本文!再次提醒,正如伏尔泰明智地所说,“完美是优秀的敌人。”现在就开始,在哪里开始都可以。你无需等待完美的项目或想法才能采取行动。随着你获得实践经验,你将会更加清楚下一步该做什么。
喜欢这篇文章吗?支持一下吧!
👏 鼓掌 50 次
🤝 在 LinkedIn 上与我连接,保持联系并讨论机会。
幕后:解释我作为数据科学家的工作
我作为数据科学家的角色究竟包括哪些内容
·发布于Towards Data Science ·阅读时间:8 分钟·2024 年 5 月 24 日
--

图片来源:Corinne Kutz来自Unsplash
我有大约三年的全职数据科学家工作经验,在这篇文章中,我想解释一下作为数据科学家到底包括哪些内容。
目标是帮助任何想要进入这个领域的人,真实地了解数据科学家是什么,我们如何工作,做什么工作,以及日常生活的典型一天是什么样的。我希望这能帮助你确定数据科学这一职业是否适合你。
让我们深入了解一下!
什么是数据科学家?
现如今,数据科学家的定义可以有很多种,但我认为数据科学家是利用编程、数学、统计学和数据来获得洞察并生成预测模型,以帮助企业的人。
你可能听说过,数据科学家是软件工程师、统计学家、分析师和数学家的结合体。这通常是正确的,但数据科学家往往会更倾向于某个领域……
RAG 中的幻觉检测方法基准测试
评估提升 LLM 生成响应可靠性的方法。
·发布于 Towards Data Science ·9 分钟阅读·2024 年 9 月 9 日
--
幻觉问题依然是当前检索增强生成(RAG)应用中的一个重大问题。本研究评估了流行的幻觉检测器在 4 个公开 RAG 数据集上的表现。通过使用 AUROC 以及精确度/召回率,我们报告了像 G-eval、Ragas 和可信语言模型等方法在自动标记不正确的 LLM 响应方面的表现,效果如何。

使用各种幻觉检测方法识别 RAG 系统中的 LLM 错误。
我目前在 Cleanlab 担任机器学习工程师,参与了本文所讨论的可信语言模型的开发工作。我很高兴在接下来的基准测试中介绍这一方法,并将其与其他方法进行评估。
问题:RAG 系统中的幻觉与错误
大型语言模型(LLM)在面对未充分支持其训练数据中的问题时,常常会产生幻觉,即错误的回答。检索增强生成(RAG)系统通过为 LLM 增加从特定知识数据库中检索上下文和信息的能力来缓解这一问题。尽管组织正在迅速采用 RAG,将 LLM 的强大功能与他们自己的专有数据结合使用,但幻觉和逻辑错误仍然是一个大问题。在一起广泛报道的案例中,一家大型航空公司(加拿大航空)因其 RAG 聊天机器人出现了关于退款政策的关键信息幻觉而输掉了官司。
为了理解这个问题,让我们首先回顾一下 RAG 系统是如何工作的。当用户提出问题(例如“这个能退吗?”),检索组件会在知识数据库中搜索相关信息,以便准确地作出回应。最相关的搜索结果被格式化为上下文,然后与用户的问题一起输入到 LLM 中,由 LLM生成最终呈现给用户的回应。由于企业级 RAG 系统通常比较复杂,最终的回应可能因多种原因不准确,包括:
-
大型语言模型(LLMs)是脆弱的,容易产生幻觉。即使检索到的上下文包含正确答案,LLM 也可能无法生成准确的回应,尤其是在生成回应时需要对上下文中的不同事实进行推理时。
-
由于检索不充分、文档分块/格式不良或知识数据库中缺乏相关信息,检索到的上下文可能不包含准确回应所需的信息。在这种情况下,LLM 仍然可能尝试回答问题,并产生不正确的回应。
尽管有些人使用“幻觉”一词仅指某些类型的 LLM 错误,但在这里我们将其与不正确回应同义使用。对 RAG 系统的用户而言,重要的是答案的准确性以及能够信任这些答案。与评估系统多种属性的 RAG 基准不同,我们只研究:不同检测器在回应不准确时,能多有效地提醒 RAG 用户。
一个 RAG(检索增强生成)答案可能因检索或生成过程中的问题而不正确。我们的研究侧重于后者的问题,这源于 LLM 的根本不可靠性。
解决方案:幻觉检测方法
假设现有的检索系统已经提取了与用户问题最相关的上下文,我们考虑使用算法来检测基于这个上下文生成的 LLM 回应不应被信任的情况。这类幻觉检测算法在医学、法律或金融等高风险应用中至关重要。除了标记不可信的回应供人工进一步审查外,这些方法还可以用来判断何时值得执行更昂贵的检索步骤(例如,搜索额外的数据源、重写查询等)。
以下是我们研究中考虑的幻觉检测方法,所有方法都基于使用 LLM 来评估生成的回应:
自我评估("Self-eval")是一种简单的技术,LLM 会被要求评估生成的答案,并按 1 到 5 的等级评定其信心。我们利用思维链(CoT)提示来改进这一技术,要求 LLM 在给出最终评分之前,先解释其信心。以下是使用的具体提示模板:
*问题:{question}
答案:{response}*
评估你对给定答案是否是一个良好且准确的回应的信心程度。
请使用以下 5 分制评分标准进行评分:
1: 你对答案是否能回答问题完全没有信心,答案可能完全跑题或与问题无关。
2: 你对答案能否回答问题的信心较低,对于答案的准确性有疑问和不确定性。
3: 你对答案是否能够回答问题有中等信心,答案看起来合理准确并且与主题相关,但仍有改进的空间。
4: 你对答案能回答问题有较高的信心,答案提供了准确的信息,解答了大部分问题。
5: 你非常自信答案能够回答问题,答案极其准确、相关,并有效地全面解答了问题。
输出应严格使用以下模板:解释:[提供你用来推导评分的简短理由],然后在最后一行写上“评分:
”。
G-Eval(来自 DeepEval 包)是一种使用 CoT(链式推理)自动制定多步骤标准来评估给定响应质量的方法。在 G-Eval 论文(Liu 等人)中,发现该技术与多个基准数据集上的人工判断具有相关性。质量可以通过多种方式进行衡量,这些方式在 LLM 提示中有所规定,在这里我们指定应根据响应的事实正确性进行评估。以下是用于 G-Eval 评估的标准:
判断输出在给定上下文下是否事实正确。
幻觉度量(来自 DeepEval 包)估计幻觉的可能性,即 LLM 响应与上下文的矛盾/不一致的程度,评估由另一个 LLM 进行。
RAGAS是一个专为 RAG(检索增强生成)设计的、由 LLM 驱动的评估工具套件,提供多种评分,可以用来检测幻觉现象。我们会考虑以下每一项 RAGAS 评分,这些评分通过使用 LLM 估算所需的数量来生成:
-
忠实度 — 答案中由提供的上下文支持的陈述的比例。
-
答案相关性是指答案中的三个 LLM 生成问题的向量表示与原始问题向量表示的平均余弦相似度。这里的向量表示是来自
BAAI/bge-base-en 编码器的嵌入向量。 -
上下文利用率衡量在 LLM(大语言模型)响应中对上下文的依赖程度。
可信语言模型(TLM)是一种模型不确定性评估技术,用于评估 LLM 回应的可信度。它通过自我反思、多次采样回应的一致性和概率度量的组合来识别错误、矛盾和幻觉。以下是用于提示 TLM 的提示模板:
*仅使用以下信息回答问题
CONTEXT: {context}
问题:{question}*
评估方法
我们将在 4 个公共的上下文-问题-回答数据集上比较上述幻觉检测方法,这些数据集涵盖了不同的 RAG 应用。
对于我们基准中的每个用户问题,现有的检索系统会返回一些相关的上下文。然后,用户查询和上下文会输入到一个生成器 LLM(通常还会包含特定应用的系统提示)中,以生成用户的回答。每种检测方法都会输入{用户查询,检索到的上下文,LLM 回应},并返回一个介于 0 和 1 之间的分数,表示幻觉的可能性。
为了评估这些幻觉检测器,我们考虑当 LLM 的回答错误与正确时,这些分数在何种程度上可靠地取较低值。在我们的每个基准测试中,都有关于每个 LLM 响应正确性的真实标注,这些标注仅供评估使用。我们基于AUROC来评估幻觉检测器,AUROC 被定义为在从 LLM 错误响应子集抽取的示例中,其分数低于从 LLM 正确响应子集抽取的示例的概率。AUROC 值较大的检测器可用于更精确/更高召回率地捕捉生产系统中的 RAG 错误。
所有考虑的幻觉检测方法本身都由 LLM 提供支持。为了公平比较,我们将所有方法中的 LLM 模型固定为gpt-4o-mini。
基准结果
我们将在下文描述每个基准数据集及其相应结果。这些数据集来自流行的HaluBench基准套件(我们没有包括该套件中的其他两个数据集,因为我们发现它们的真实标注存在重大错误)。
PubMedQA
PubMedQA是一个基于 PubMed 摘要的生物医学问答数据集。数据集中的每个实例包含一段来自 PubMed(医学出版物)摘要的文本,一个基于该段文本提出的问题,例如:9 个月的治疗是否足够治疗结核性肠炎?,以及一个生成的答案。

PubMedQA 数据集的 ROC 曲线
在这个基准测试中,TLM 是最有效的幻觉识别方法,其次是 Hallucination Metric、Self-Evaluation 和 RAGAS Faithfulness。在这三种方法中,RAGAS Faithfulness 和 Hallucination Metric 在高精度地捕捉错误答案方面更为有效(RAGAS Faithfulness 的平均精度为0.762,Hallucination Metric 的平均精度为0.761,Self-Evaluation 的平均精度为0.702)。
DROP
DROP,即“段落上的离散推理”,是一个基于维基百科文章的高级问答数据集。DROP 的难度在于问题要求对文章中的上下文进行推理,而不仅仅是提取事实。例如,给定一个描述“海鹰队与 49 人队”足球比赛中达阵的维基百科段落,问题可能是:有多少次达阵跑的总距离为 5 码或更短? 这要求 LLM 读取每次达阵的跑动并将长度与 5 码的要求进行比较。

DROP 数据集的 ROC 曲线
由于 DROP 数据集需要复杂的推理,大多数方法在检测幻觉时面临挑战。TLM 在这个基准测试中表现最为有效,其次是 Self-Evaluation 和 RAGAS Faithfulness。
COVID-QA
COVID-QA是一个基于与 COVID-19 相关的科学文章的问答数据集。数据集中的每个实例包括一段与 COVID-19 相关的科学内容和一个基于该内容提出的问题,例如:SARS-COV-2 基因组序列与 SARS-COV 的相似性有多少?
与 DROP 相比,这是一个更简单的数据集,因为它只需要从文章中提取基本信息来回答更直接的问题。

COVID-QA 数据集的 ROC 曲线
在 COVID-QA 数据集中,TLM 和 RAGAS Faithfulness 在检测幻觉方面表现强劲。Self-Evaluation 也表现良好,但其他方法,包括 RAGAS Answer Relevancy、G-Eval 和 Hallucination Metric,结果不一。
FinanceBench
FinanceBench是一个包含有关公共财务报表和上市公司信息的数据集。数据集中的每个实例都包含大量提取的纯文本财务信息、与该信息相关的问题,例如:Kraft Heinz 2015 财年的净营运资本是多少?,以及像这样的数字答案:$2850.00。

FinanceBench 数据集的 ROC 曲线
在这个基准测试中,TLM 在识别幻觉方面最为有效,其次是自我评估。其他大多数方法在提供比随机猜测更显著的改进方面表现不佳,这突显了该数据集中的挑战,数据集包含大量的上下文和数值数据。
讨论
我们对各种 RAG 基准测试中幻觉检测方法的评估揭示了以下关键见解:
-
可信语言模型(TLM)始终表现优异,通过自我反思、一致性和概率性度量,展现了在识别幻觉方面的强大能力。
-
自我评估在检测幻觉方面表现出一致的有效性,尤其在较简单的上下文中,LLM 的自我评估可以准确衡量。虽然它的表现可能不总是与 TLM 相匹配,但它仍然是评估响应质量的一个直接而有用的技术。
-
RAGAS 忠诚度在回答准确性与检索上下文密切相关的数据集(如 PubMedQA 和 COVID-QA)中展现了强大的性能。它特别擅长识别答案中的声明是否得到了提供的上下文支持。然而,其效果会根据问题的复杂性有所不同。默认情况下,RAGAS 使用
gpt-3.5-turbo-16k进行生成,使用gpt-4作为评论 LLM,这比我们在此报告的使用gpt-4o-mini结果要差。由于句子解析逻辑的问题,RAGAS 无法在我们基准测试中的某些例子上运行,我们通过在没有标点符号的答案末尾添加句号(.)来解决此问题。 -
其他方法,如 G-Eval 和幻觉度量,结果参差不齐,在不同的基准测试中表现差异较大。它们的表现不够稳定,表明可能需要进一步的完善和适应。
总体而言,TLM、RAGAS 忠诚度和自我评估在 RAG 应用中是更可靠的幻觉检测方法。对于高风险应用,结合这些方法可能会提供最佳的结果。未来的工作可以探索混合方法和有针对性的改进,以便在特定用例中更好地进行幻觉检测。通过整合这些方法,RAG 系统可以实现更高的可靠性,确保更加准确和可信的响应。
除非另有说明,所有图片均由作者提供。
基准测试 LLM 推理后端
比较 Llama 3 在 vLLM、LMDeploy、MLC-LLM、TensorRT-LLM 和 TGI 上的服务性能
·发表于 Towards Data Science ·阅读时间 10 分钟 ·2024 年 6 月 17 日
--
为大语言模型(LLM)选择合适的推理后端至关重要。它不仅能确保通过快速生成速度提供最佳的用户体验,还能通过高令牌生成率和资源利用率提高成本效益。如今,开发人员有多种选择,可以选择由知名研究和行业团队创建的推理后端。然而,选择适合特定用例的最佳后端可能是一个挑战。
为了帮助开发人员做出明智的决策,BentoML 工程团队对 Llama 3 的服务性能进行了全面的基准测试,涉及的推理后端包括 vLLM、LMDeploy、MLC-LLM、TensorRT-LLM 和 Hugging Face TGI,测试是在 BentoCloud 上进行的。这些推理后端通过两个关键指标进行了评估:
-
首次令牌生成时间 (TTFT):衡量从发送请求到生成第一个令牌所需的时间,单位为毫秒。TTFT 对于需要即时反馈的应用程序至关重要,例如互动聊天机器人。更低的延迟可以提高感知性能和用户满意度。
-
令牌生成率:评估模型在解码过程中每秒生成多少令牌,单位为令牌每秒。令牌生成率是模型处理高负载能力的指标。更高的生成率表明模型能够高效地处理多个请求并迅速生成响应,适用于高并发环境。
主要基准结果
我们在 BentoCloud 上的 A100 80GB GPU 实例(gpu.a100.1x80)上,使用 Llama 3 8B 和 70B 4 位量化模型进行基准测试,测试了三个不同的推理负载级别(10、50 和 100 个并发用户)。以下是我们的一些关键发现:
Llama 3 8B

Llama 3 8B:不同后端的首个令牌时间(TTFT)

Llama 3 8B:不同后端的令牌生成速率
-
LMDeploy:在令牌生成速率方面提供了最佳的解码性能,对于 100 个用户,每秒最多可生成 4000 个令牌。对于 10 个用户,达到了业内最佳的 TTFT。尽管随着用户数量的增加,TTFT 逐渐上升,但它仍然保持在较低水平,并始终位居最佳行列。
-
MLC-LLM:在 10 个用户下提供了与 LMDeploy 相似的解码性能。对于 10 个和 50 个用户,达到了业内最佳的 TTFT。然而,在非常高的负载下,它难以保持这种效率。当并发用户数增加到 100 时,解码速度和 TTFT 未能跟上 LMDeploy 的表现。
-
vLLM:在所有并发用户级别上实现了业内最佳的 TTFT。但与 LMDeploy 和 MLC-LLM 相比,解码性能较差,令牌生成速率为每秒 2300-2500 个令牌,类似于 TGI 和 TRT-LLM。
Llama 3 70B 4 位量化

Llama 3 70B Q4:不同后端的首个令牌时间(TTFT)

Llama 3 70B Q4:不同后端的令牌生成速率
-
LMDeploy:在服务 100 个用户时,提供了最佳的令牌生成速率,每秒最多可生成 700 个令牌,同时在所有并发用户级别中保持了最低的 TTFT。
-
TensorRT-LLM:在令牌生成速率方面展现出与 LMDeploy 相似的表现,并且在低并发用户数下维持了较低的 TTFT。然而,当并发用户数达到 100 时,TTFT 显著增加,超过了 6 秒。
-
vLLM:在所有并发用户级别下展现了一贯的低 TTFT,类似于我们在 8B 模型中观察到的情况。与 LMDeploy 和 TensorRT-LLM 相比,vLLM 的令牌生成速率较低,这可能是由于缺乏针对量化模型的推理优化。
我们发现令牌生成速率与推理后端实现的 GPU 利用率密切相关。能够维持高令牌生成速率的后端也表现出接近 100%的 GPU 利用率。相反,GPU 利用率较低的后端似乎在 Python 进程中存在瓶颈。
除了性能之外
在选择推理后端为 LLM 提供服务时,除了性能之外,其他因素也在决策中扮演着重要角色。以下列表突出了我们认为在选择理想推理后端时应考虑的关键维度。
量化
量化通过使用较低位数的整数表示权重,在性能和精度之间进行权衡。这种技术结合推理后端的优化,使得推理速度更快,并且占用更少的内存。因此,我们能够在单个 A100 80GB GPU 上加载 70B 参数的 Llama 3 模型的权重,而在没有量化的情况下,通常需要多个 GPU。
-
LMDeploy: 支持 4 位 AWQ、8 位量化和 4 位 KV 量化。
-
vLLM: 目前尚不完全支持。用户需要通过 AutoAWQ 对模型进行量化,或者在 Hugging Face 上找到预量化的模型。性能尚未优化。
-
TensorRT-LLM: 通过 modelopt 支持量化,请注意并非所有模型都实现了量化数据类型。
-
TGI: 支持 AWQ、GPTQ 和 bits-and-bytes 量化
-
MLC-LLM: 支持 3 位和 4 位分组量化。AWQ 量化支持仍处于实验阶段。
模型架构
能够在不同的模型架构之间使用相同的推理后端,为工程团队提供了灵活性。这使得他们可以随着新改进的出现,在各种大型语言模型之间切换,而无需迁移底层推理基础设施。
-
LMDeploy: TurboMind 引擎支持的 20 多种模型。目前,像 Mistral、Qwen 1.5 这样的需要滑动窗口注意力的模型尚不完全支持。
-
vLLM: 支持 30 多种模型
-
TensorRT-LLM: 支持 30 多种模型
-
TGI: 支持 20 多种模型
-
MLC-LLM: 支持 20 多种模型
硬件限制
能够在不同的硬件上运行提供了成本节约,并且能够根据推理需求选择合适的硬件。同时,在当前 GPU 短缺的情况下,它也提供了替代方案,有效地帮助解决供应瓶颈问题。
-
LMDeploy: 仅针对 Nvidia CUDA 进行了优化
-
vLLM: 支持 Nvidia CUDA,AMD ROCm,AWS Neuron,CPU
-
TensorRT-LLM: 仅支持 Nvidia CUDA
-
TGI: Nvidia CUDA,AMD ROCm,Intel Gaudi,AWS Inferentia
-
MLC-LLM: 支持 Nvidia CUDA,AMD ROCm,Metal,Android,IOS,WebGPU
开发者体验
为生产环境设计的推理后端应提供稳定的版本发布,并便于简化的持续部署工作流。此外,开发者友好的后端应具有明确定义的接口,支持快速开发和高代码可维护性,这对于构建由大型语言模型(LLMs)驱动的 AI 应用至关重要。
-
稳定版本:LMDeploy、TensorRT-LLM、vLLM 和 TGI 都提供了稳定版本。MLC-LLM 目前没有稳定的标记版本,只有夜间构建版本;一种可能的解决方案是从源代码构建。
-
模型编译:TensorRT-LLM 和 MLC-LLM 在推理后端准备就绪之前,需要进行显式的模型编译步骤。这个步骤可能在部署时引入额外的冷启动延迟。
-
文档:LMDeploy、vLLM 和 TGI 都具有易于学习的文档和示例。MLC-LLM 的学习曲线适中,主要是因为需要理解模型编译步骤。TensorRT-LLM 是在我们的基准测试中设置最具挑战性的。由于缺乏足够的优质示例,我们不得不阅读 TensorRT-LLM、tensorrtllm_backend和 Triton Inference Server 的文档,转换检查点,构建 TRT 引擎,并编写大量配置。
概念
Llama 3
Llama 3是 Llama LLM 系列的最新迭代,提供多种配置。我们在基准测试中使用了以下模型大小。
-
8B:该模型具有 80 亿个参数,使其在计算资源方面既强大又易于管理。使用 FP16 时,它大约需要 16GB 的内存(不包括 KV 缓存和其他开销),可以在单个 A100–80G GPU 实例上运行。
-
70B 4 位量化:该 70 亿参数模型经过 4 位量化后,显著减少了其内存占用。量化通过减少每个参数的位数来压缩模型,提供更快的推理速度,并在性能损失最小的情况下降低内存使用。使用 4 位 AWQ 量化时,加载模型权重大约需要 37GB 内存,可以在单个 A100–80G 实例上运行。在单个 GPU 设备上提供量化权重通常能够实现模型的最佳吞吐量,相较于在多个设备上提供。
推理平台
我们确保使用BentoML提供的推理后端相比原生 Python 推理仅增加了最小的性能开销。这个开销源于提供了用于扩展、可观察性和 IO 序列化的功能。使用 BentoML 和BentoCloud为不同的推理后端提供了统一的 RESTful API,简化了基准测试的设置和操作。
推理后端
不同的后端提供了多种服务 LLM 的方式,每种方式都有其独特的功能和优化技术。我们测试的所有推理后端都遵循 Apache 2.0 许可证。
-
LMDeploy:一个推理后端,专注于提供高解码速度和高效处理并发请求。它支持各种量化技术,适合部署具有较低内存要求的大型模型。
-
vLLM:一个高性能的推理引擎,专门优化用于服务 LLM。它因高效利用 GPU 资源和快速解码能力而闻名。
-
TensorRT-LLM:一个推理后端,利用 NVIDIA 的 TensorRT,一个高性能的深度学习推理库。它优化了在 NVIDIA GPU 上运行大模型,提供快速推理并支持诸如量化等高级优化。
-
Hugging Face 文本生成推理(TGI):一个用于部署和服务 LLM 的工具包。它在 Hugging Face 的生产环境中用于驱动 Hugging Chat、推理 API 和推理端点。
-
MLC-LLM:一个为 LLM 量身定制的 ML 编译器和高性能部署引擎。它建立在 Apache TVM 之上,在服务模型之前需要进行编译和权重转换。
将 BentoML 与各种推理后端集成以自托管 LLM 是非常简单的。BentoML 社区在 GitHub 上提供了以下示例项目,帮助你完成整个过程。
-
MLC-LLM:
github.com/bentoml/BentoMLCLLM -
LMDeploy:
github.com/bentoml/BentoLMDeploy -
TRT-LLM:
github.com/bentoml/BentoTRTLLM
基准测试设置
模型
我们测试了 Meta-Llama-3–8B-Instruct 和 Meta-Llama-3–70B-Instruct 的 4 位量化模型。对于 70B 模型,我们进行了 4 位量化,使其能够在单个 A100–80G GPU 上运行。如果推理后端支持原生量化,我们使用推理后端提供的量化方法。例如,对于 MLC-LLM,我们使用了q4f16_1量化方案。否则,我们使用了来自 Hugging Face 的 AWQ 量化casperhansen/llama-3-70b-instruct-awq模型。
请注意,除了启用常见的推理优化技术,如连续批处理、闪存注意力和前缀缓存外,我们没有为每个独立的后端微调推理配置(GPU 内存使用、最大序列数、分页 KV 缓存块大小等)。这是因为随着我们服务的 LLM 数量的增加,这种方法不可扩展。提供一组最优的推理参数是后端性能和易用性的隐性衡量标准。
基准客户端
为了准确评估不同 LLM 后端的性能,我们创建了一个自定义基准测试脚本。该脚本通过变化用户负载并在不同的并发级别下发送生成请求,模拟了实际场景。
我们的基准测试客户端可以在 20 秒内启动目标用户数量,之后通过发送带有随机选择提示词的并发生成请求来对 LLM 后端进行压力测试。我们测试了 10、50 和 100 个并发用户,以评估系统在不同负载下的表现。
每次压力测试持续 5 分钟,在此期间我们每隔 5 秒收集一次推理指标。这个持续时间足以观察到潜在的性能下降、资源利用瓶颈或其他在短时间测试中可能未能显现的问题。
欲了解更多信息,请参见我们的基准测试客户端的源代码。
提示词数据集
我们测试的提示词来自databricks-dolly-15k 数据集。在每次测试会话中,我们从该数据集中随机选择提示词。我们还测试了有无系统提示词的文本生成。一些后端可能通过启用前缀缓存来优化常见的系统提示词场景。
库版本
-
BentoML: 1.2.16
-
vLLM: 0.4.2
-
MLC-LLM: mlc-llm-nightly-cu121 0.1.dev1251(尚无稳定版)
-
LMDeploy: 0.4.0
-
TensorRT-LLM: 0.9.0(与 Triton v24.04 一起使用)
-
TGI: 2.0.4
推荐
LLM 推理优化领域正在迅速发展并且受到广泛研究。今天可用的最佳推理后端可能很快会被新兴技术所超越。根据我们在撰写时进行的基准测试和可用性研究,我们有以下建议,帮助选择在各种场景下最适合 Llama 3 模型的后端。
Llama 3 8B
对于 Llama 3 8B 模型,LMDeploy在所有用户负载下始终提供低 TTFT 和最高的解码速度。其易用性也是一个显著优势,因为它可以即时将模型转换为 TurboMind 引擎格式,从而简化部署过程。在撰写时,LMDeploy 对使用滑动窗口注意力机制的模型(如 Mistral 和 Qwen 1.5)支持有限。
vLLM即使在用户负载增加的情况下,也始终保持较低的 TTFT,适合需要保持低延迟的场景。vLLM 提供了易于集成、广泛的模型支持和广泛的硬件兼容性,所有这些都得到强大的开源社区的支持。
MLC-LLM提供了最低的 TTFT,并在较低的并发用户情况下维持较高的解码速度。然而,在极高用户负载下,MLC-LLM 在维持顶级解码性能方面表现较差。尽管面临这些挑战,MLC-LLM 凭借其机器学习编译技术展现出巨大的潜力。如果能够解决这些性能问题并实施稳定的版本发布,将极大提升其效能。
Llama 3 70B 4 位量化
对于 Llama 3 70B Q4 模型,LMDeploy在所有用户负载下展现了卓越的性能,具有最低的 TTFT。它还保持了较高的解码速度,适用于低延迟和高吞吐量都至关重要的应用场景。LMDeploy 还以其易用性脱颖而出,因为它能够快速转换模型,无需大量的设置或编译,非常适合快速部署场景。
TensorRT-LLM在吞吐量上与 LMDeploy 相匹配,但在高用户负载场景下,其 TTFT 延迟表现不如预期。得益于 Nvidia 的支持,我们预计这些差距将会得到迅速解决。然而,它对模型编译的固有需求以及对 Nvidia CUDA GPU 的依赖是有意的设计选择,这在部署过程中可能会带来一些限制。
vLLM即使在用户负载增加的情况下,也能保持较低的 TTFT,其易用性对许多用户来说是一个显著的优势。然而,在撰写本文时,后端缺乏对 AWQ 量化的优化,导致量化模型的解码性能不尽如人意。
致谢
本文及相关基准测试是与我尊敬的同事 Rick Zhou、Larme Zhao 和 Bo Jiang 共同合作完成的。本文中展示的所有图片均由作者创作。
使用 GitHub Actions 在 CICD 中进行 Pytest 基准测试
让 Pytest 基准测试变得自动化、可操作并且直观
·发表于Towards Data Science ·阅读时间:7 分钟·2024 年 3 月 5 日
--

图片由Lucas Santos提供,来源于Unsplash
“你的代码很慢”是一个常见的说法,但要找出代码的哪一部分慢,究竟有多慢,却需要进行大量的反复试验和测试,而且慢到底有多慢呢?一旦找到代码的瓶颈,它是否能在输入扩大到 100 倍或 1000 倍的情况下良好扩展,并且结果是基于 10 次迭代的平均值?
这就是
pytest-benchmark派上用场的地方
补充单元测试的概念,单元测试是对代码库中单一单元或小部分进行测试,我们可以在此基础上拓展,利用pytest-benchmark轻松地测量代码性能。
本文将介绍如何设置、运行并解释pytest-benchmark的基准测试结果。为了在项目中正确实施基准测试,文章的高级部分还将探讨如何比较不同运行之间的基准测试结果并在结果未达到某些阈值时拒绝提交,如何存储并查看历史基准测试结果,并通过箱型图和折线图呈现这些结果!
将 Snowflake Cortex 与 Scikit-Learn 在实际预测用例中的表现进行基准测试。
作为目前最流行的基于云的数据平台之一,Snowflake 现在嵌入了高级建模功能,我也尝试了其中的预测功能。
·发表于 Towards Data Science ·9 分钟阅读·2024 年 2 月 25 日
--

一个戏剧性的雪旋风 — 由作者通过 Leornardo.ai 生成
几个月前(2023 年 11 月 23 日),Snowflake 宣布发布了多项新的建模/大语言模型(LLM)功能,这些功能属于一个名为 “Cortex” 的框架。
自 12 月中旬以来,前两项功能(预测 和 异常检测)已经全面开放使用(见 Snowflake 7.44 版本说明)。
因此,Snowflake 继续其使命,提供一个完全托管的“一站式”分析平台,帮助数据使用者从其数据资产中释放价值,除了面向数据工程团队的常规数据仓库功能之外。
这些功能让一些人联想到了 “Google BigQuery ML” 的功能,后者最初在 2020 年 8 月 发布(是的,四年前!);让我们深入了解一下吧!
预测本地城市游泳池的访问量
除了 Snowday ❄️ 上令人兴奋的演讲和量身定制的展示外,我迫不及待地想将一个真实的数据集加载到 Snowflake 中,看看 Cortex 相比于传统方法的表现如何。
伯努利朴素贝叶斯,详解:适合初学者的可视化指南及代码示例
分类算法
通过是/否概率解锁预测能力
·发布于 Towards Data Science ·阅读时间 9 分钟·2024 年 8 月 24 日
--

⛳️ 更多分类算法,详解: · 虚拟分类器 · K 近邻分类器 ▶ 伯努利朴素贝叶斯 · 高斯朴素贝叶斯 · 决策树分类器 · 逻辑回归 · 支持向量分类器 · 多层感知机
与虚拟分类器或 KNN 基于相似度推理的基线方法不同,朴素贝叶斯利用概率理论。它将每个“线索”(或特征)的个体概率结合起来,做出最终预测。这种直接而强大的方法在许多机器学习应用中被证明是无价的。

所有视觉内容:作者使用 Canva Pro 创建。已优化移动设备显示;在桌面上可能显示过大。
定义
朴素贝叶斯是一种使用概率对数据进行分类的机器学习算法。它基于贝叶斯定理,这是一个计算条件概率的公式。 “朴素”部分指的是它的一个关键假设:它假设所有特征彼此独立,即使它们在现实中可能不是。尽管这种简化通常不现实,但它大大减少了计算复杂性,并且在许多实际场景中表现良好。

朴素贝叶斯方法是机器学习中的一种简单算法,基于概率作为其基础。
朴素贝叶斯分类器的主要类型
朴素贝叶斯分类器有三种主要类型。这些类型之间的主要区别在于它们对特征分布的假设:
-
伯努利朴素贝叶斯:适用于二进制/布尔特征。它假设每个特征是二进制值(0/1)的变量。
-
多项式朴素贝叶斯:通常用于离散计数。它常用于文本分类,其中特征可能是单词计数。
-
高斯朴素贝叶斯:假设连续特征服从正态分布。

伯努利朴素贝叶斯假设二进制数据,多项式朴素贝叶斯处理离散计数,高斯朴素贝叶斯处理假设正态分布的连续数据。
重点是从最简单的朴素贝叶斯开始,这就是伯努利朴素贝叶斯。其名称中的“伯努利”来自于假设每个特征都是二进制值的假设。
使用的数据集
在本文中,我们将使用这个人工高尔夫数据集(灵感来源于1)作为示例。该数据集预测一个人是否会根据天气条件打高尔夫。

列:‘Outlook’,‘Temperature’(华氏度),‘Humidity’(百分比),‘Wind’和‘Play’(目标特征)
# IMPORTING DATASET #
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import pandas as pd
import numpy as np
dataset_dict = {
'Outlook': ['sunny', 'sunny', 'overcast', 'rain', 'rain', 'rain', 'overcast', 'sunny', 'sunny', 'rain', 'sunny', 'overcast', 'overcast', 'rain', 'sunny', 'overcast', 'rain', 'sunny', 'sunny', 'rain', 'overcast', 'rain', 'sunny', 'overcast', 'sunny', 'overcast', 'rain', 'overcast'],
'Temperature': [85.0, 80.0, 83.0, 70.0, 68.0, 65.0, 64.0, 72.0, 69.0, 75.0, 75.0, 72.0, 81.0, 71.0, 81.0, 74.0, 76.0, 78.0, 82.0, 67.0, 85.0, 73.0, 88.0, 77.0, 79.0, 80.0, 66.0, 84.0],
'Humidity': [85.0, 90.0, 78.0, 96.0, 80.0, 70.0, 65.0, 95.0, 70.0, 80.0, 70.0, 90.0, 75.0, 80.0, 88.0, 92.0, 85.0, 75.0, 92.0, 90.0, 85.0, 88.0, 65.0, 70.0, 60.0, 95.0, 70.0, 78.0],
'Wind': [False, True, False, False, False, True, True, False, False, False, True, True, False, True, True, False, False, True, False, True, True, False, True, False, False, True, False, False],
'Play': ['No', 'No', 'Yes', 'Yes', 'Yes', 'No', 'Yes', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'No', 'No', 'Yes', 'Yes', 'No', 'No', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'No', 'Yes']
}
df = pd.DataFrame(dataset_dict)
# ONE-HOT ENCODE 'Outlook' COLUMN
df = pd.get_dummies(df, columns=['Outlook'], prefix='', prefix_sep='', dtype=int)
# CONVERT 'Windy' (bool) and 'Play' (binary) COLUMNS TO BINARY INDICATORS
df['Wind'] = df['Wind'].astype(int)
df['Play'] = (df['Play'] == 'Yes').astype(int)
# Set feature matrix X and target vector y
X, y = df.drop(columns='Play'), df['Play']
# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.5, shuffle=False)
print(pd.concat([X_train, y_train], axis=1), end='\n\n')
print(pd.concat([X_test, y_test], axis=1))
我们将稍作调整,使用伯努利朴素贝叶斯,通过将特征转换为二进制格式。

由于所有数据必须是 0 和 1 格式,因此‘Outlook’被进行独热编码,而 Temperature 被分为≤80 和>80。同样,Humidity 被分为≤75 和>75。
# One-hot encode the categorized columns and drop them after, but do it separately for training and test sets
# Define categories for 'Temperature' and 'Humidity' for training set
X_train['Temperature'] = pd.cut(X_train['Temperature'], bins=[0, 80, 100], labels=['Warm', 'Hot'])
X_train['Humidity'] = pd.cut(X_train['Humidity'], bins=[0, 75, 100], labels=['Dry', 'Humid'])
# Similarly, define for the test set
X_test['Temperature'] = pd.cut(X_test['Temperature'], bins=[0, 80, 100], labels=['Warm', 'Hot'])
X_test['Humidity'] = pd.cut(X_test['Humidity'], bins=[0, 75, 100], labels=['Dry', 'Humid'])
# One-hot encode the categorized columns
one_hot_columns_train = pd.get_dummies(X_train[['Temperature', 'Humidity']], drop_first=True, dtype=int)
one_hot_columns_test = pd.get_dummies(X_test[['Temperature', 'Humidity']], drop_first=True, dtype=int)
# Drop the categorized columns from training and test sets
X_train = X_train.drop(['Temperature', 'Humidity'], axis=1)
X_test = X_test.drop(['Temperature', 'Humidity'], axis=1)
# Concatenate the one-hot encoded columns with the original DataFrames
X_train = pd.concat([one_hot_columns_train, X_train], axis=1)
X_test = pd.concat([one_hot_columns_test, X_test], axis=1)
print(pd.concat([X_train, y_train], axis=1), '\n')
print(pd.concat([X_test, y_test], axis=1))
主要机制
伯努利朴素贝叶斯算法适用于每个特征值为 0 或 1 的数据。
-
计算训练数据中每个类别的概率。
-
对于每个特征和类别,计算在给定类别下特征为 1 和 0 的概率。
-
对于一个新实例:对于每个类别,将其概率与该类别下每个特征值(0 或 1)的概率相乘。
-
预测具有最高结果概率的类别。

对于我们的高尔夫数据集,伯努利 NB 分类器查看每个特征在每个类别(YES 和 NO)下发生的概率,然后根据哪个类别的概率更高来做出决策。
训练步骤
伯努利朴素贝叶斯的训练过程包括从训练数据中计算概率:
- 类别概率计算:对于每个类别,计算其概率:(该类别中实例的数量)/(所有实例的总数)。

在我们的高尔夫示例中,算法将计算总体上打高尔夫的频率。
from fractions import Fraction
def calc_target_prob(attr):
total_counts = attr.value_counts().sum()
prob_series = attr.value_counts().apply(lambda x: Fraction(x, total_counts).limit_denominator())
return prob_series
print(calc_target_prob(y_train))
- 特征概率计算:对于每个特征和每个类别,计算:
-
(特征为 0 的实例数量)/(该类别中实例的数量)
-
(特征为 1 的实例数量)/(该类别中实例的数量)

对于每种天气条件(例如,晴天),计算在晴天时打高尔夫的频率,以及晴天时不打高尔夫的频率。
from fractions import Fraction
def sort_attr_label(attr, lbl):
return (pd.concat([attr, lbl], axis=1)
.sort_values([attr.name, lbl.name])
.reset_index()
.rename(columns={'index': 'ID'})
.set_index('ID'))
def calc_feature_prob(attr, lbl):
total_classes = lbl.value_counts()
counts = pd.crosstab(attr, lbl)
prob_df = counts.apply(lambda x: [Fraction(c, total_classes[x.name]).limit_denominator() for c in x])
return prob_df
print(sort_attr_label(y_train, X_train['sunny']))
print(calc_feature_prob(X_train['sunny'], y_train))

同样的过程应用于所有其他特征。
for col in X_train.columns:
print(calc_feature_prob(X_train[col], y_train), "\n")
- 平滑(可选):在每次概率计算时,向分子和分母中添加一个小值(通常为 1),以避免出现零概率。

我们对所有分子加 1,对所有分母加 2,以保持总类别概率为 1。
# In sklearn, all processes above is summarized in this 'fit' method:
from sklearn.naive_bayes import BernoulliNB
nb_clf = BernoulliNB(alpha=1)
nb_clf.fit(X_train, y_train)
- 存储结果:保存所有计算出的概率,以便在分类过程中使用。

平滑已经应用到所有特征概率中。我们将使用这些表格进行预测。
分类步骤
给定一个新实例,其特征值为 0 或 1:
- 概率收集:对于每个可能的类别:
-
从该类别发生的概率(类别概率)开始。
-
对于新实例中的每个特征,收集该特征在此类别下为 0/1 的概率。

对于 ID 14,我们选择每个特征(无论是 0 还是 1)发生的概率。
- 得分计算与预测:对于每个类别:
-
将所有收集到的概率相乘。
-
结果是该类别的得分。
-
得分最高的类别就是预测结果。

在将类别概率与所有特征概率相乘后,选择得分更高的类别。
y_pred = nb_clf.predict(X_test)
print(y_pred)
评估步骤

这个简单的概率模型在这个简单数据集上表现出了很好的准确性。
# Evaluate the classifier
print(f"Accuracy: {accuracy_score(y_test, y_pred)}")
关键参数
伯努利朴素贝叶斯有一些重要参数:
-
Alpha (α):这是平滑参数。它为每个特征添加一个小的计数,以防止零概率。默认通常为 1.0(拉普拉斯平滑),如前所示。
-
Binarize:如果您的特征尚未是二元的,此阈值会将其转换。任何高于此阈值的值变为 1,低于该阈值的值变为 0。

对于 scikit-learn 中的 BernoulliNB,数值特征通常是标准化的,而不是手动二值化的。模型会将这些标准化的值转化为二进制,通常使用 0(均值)作为阈值。
- Fit Prior:是否学习类别的先验概率,或假设均匀先验(50/50)。

对于我们的高尔夫数据集,我们可能从默认的α=1.0 开始,不进行二值化(因为我们的特征已经是二值的),并设置 fit_prior=True。
优缺点
和机器学习中的任何算法一样,伯努利朴素贝叶斯也有其优点和局限性。
优点:
-
简单性:容易实现和理解。
-
效率:训练和预测速度快,适用于大规模特征空间。
-
小数据集的表现:即使在训练数据有限的情况下也能表现良好。
-
处理高维数据:在特征维度较多时表现良好,尤其适用于文本分类。
缺点:
-
独立性假设:假设所有特征是独立的,但在实际数据中,这通常不成立。
-
仅限二元特征:在其纯粹形式下,仅适用于二元数据。
-
对输入数据的敏感性:对特征的二值化方式可能较为敏感。
-
零频问题:如果没有平滑处理,零概率会强烈影响预测结果。
最后的备注
伯努利朴素贝叶斯分类器是一种简单而强大的二元分类机器学习算法。它在文本分析和垃圾邮件检测中表现优秀,这些场景中的特征通常是二元的。因其速度和效率著称,这一概率模型在小数据集和高维空间中表现出色。
尽管假设特征之间相互独立,但它的准确性常常能与更复杂的模型相媲美。伯努利朴素贝叶斯作为一个优秀的基准模型和实时分类工具。
🌟 伯努利朴素贝叶斯简化版
# Import needed libraries
import pandas as pd
from sklearn.naive_bayes import BernoulliNB
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
# Load the dataset
dataset_dict = {
'Outlook': ['sunny', 'sunny', 'overcast', 'rainy', 'rainy', 'rainy', 'overcast', 'sunny', 'sunny', 'rainy', 'sunny', 'overcast', 'overcast', 'rainy', 'sunny', 'overcast', 'rainy', 'sunny', 'sunny', 'rainy', 'overcast', 'rainy', 'sunny', 'overcast', 'sunny', 'overcast', 'rainy', 'overcast'],
'Temperature': [85.0, 80.0, 83.0, 70.0, 68.0, 65.0, 64.0, 72.0, 69.0, 75.0, 75.0, 72.0, 81.0, 71.0, 81.0, 74.0, 76.0, 78.0, 82.0, 67.0, 85.0, 73.0, 88.0, 77.0, 79.0, 80.0, 66.0, 84.0],
'Humidity': [85.0, 90.0, 78.0, 96.0, 80.0, 70.0, 65.0, 95.0, 70.0, 80.0, 70.0, 90.0, 75.0, 80.0, 88.0, 92.0, 85.0, 75.0, 92.0, 90.0, 85.0, 88.0, 65.0, 70.0, 60.0, 95.0, 70.0, 78.0],
'Wind': [False, True, False, False, False, True, True, False, False, False, True, True, False, True, True, False, False, True, False, True, True, False, True, False, False, True, False, False],
'Play': ['No', 'No', 'Yes', 'Yes', 'Yes', 'No', 'Yes', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'No', 'No', 'Yes', 'Yes', 'No', 'No', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'No', 'Yes']
}
df = pd.DataFrame(dataset_dict)
# Prepare data for model
df = pd.get_dummies(df, columns=['Outlook'], prefix='', prefix_sep='', dtype=int)
df['Wind'] = df['Wind'].astype(int)
df['Play'] = (df['Play'] == 'Yes').astype(int)
# Split data into training and testing sets
X, y = df.drop(columns='Play'), df['Play']
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.5, shuffle=False)
# Scale numerical features (for automatic binarization)
scaler = StandardScaler()
float_cols = X_train.select_dtypes(include=['float64']).columns
X_train[float_cols] = scaler.fit_transform(X_train[float_cols])
X_test[float_cols] = scaler.transform(X_test[float_cols])
# Train the model
nb_clf = BernoulliNB()
nb_clf.fit(X_train, y_train)
# Make predictions
y_pred = nb_clf.predict(X_test)
# Check accuracy
print(f"Accuracy: {accuracy_score(y_test, y_pred)}")
进一步阅读
对于伯努利朴素贝叶斯分类器及其在 scikit-learn 中的实现,读者可以参考官方文档,文档中提供了关于其使用和参数的详细信息。
技术环境
本文使用 Python 3.7 和 scikit-learn 1.5。虽然讨论的概念通常适用,但具体的代码实现可能会因版本不同而略有差异。
关于插图
除非另有说明,所有图片均由作者创作,并结合了来自 Canva Pro 的授权设计元素。

要查看 Bernoulli Naive Bayes 的简洁视觉总结,请查看Instagram 上的配套帖子。
参考文献
1 T. M. Mitchell, 机器学习 (1997), McGraw-Hill Science/Engineering/Math, 第 59 页
𝙎𝙚𝙚 𝙢𝙤𝙧𝙚 𝘾𝙡𝙖𝙨𝙨𝙞𝙛𝙞𝙘𝙖𝙩𝙞𝙤𝙣 𝘼𝙡𝙜𝙤𝙧𝙞𝙩𝙝𝙢𝙨 𝙝𝙚𝙧𝙚:

分类算法
查看列表8 篇故事


𝙔𝙤𝙪 𝙢𝙞𝙜𝙝𝙩 𝙖𝙡𝙨𝙤 𝙡𝙞𝙠𝙚:

回归算法
查看列表5 篇故事



集成学习
查看列表4 篇故事


BERT — 直观且详尽的解释
将通用理解融入语言模型
·发表于Towards Data Science ·阅读时间:46 分钟·2024 年 8 月 23 日
--

“Baking”由 Daniel Warfield 使用 MidJourney 创作。所有图片除非另有说明,均为作者提供。文章最初发布于直观且详尽的解释。
在本文中,我们将讨论“来自变换器的双向编码器表示”(BERT),这是一种旨在理解语言的模型。虽然 BERT 与像 GPT 这样的模型相似,但 BERT 的重点是理解文本,而不是生成文本。这在各种任务中都很有用,比如对产品评论的正面性进行排序,或者预测问题的答案是否正确。
在深入探讨 BERT 之前,我们将简要讨论变换器架构,它是 BERT 的直接灵感来源。通过理解这一点,我们将深入了解 BERT,并讨论它是如何构建和训练的,以通过利用语言的一般理解来解决问题。最后,我们将从零开始创建一个 BERT 模型,并使用它来预测产品评论是正面的还是负面的。
这个对谁有用? 任何希望全面了解人工智能前沿状态的人。
这篇文章的难度如何? 本文的前半部分适合各个水平的读者,而后半部分涉及从零开始的实现则相对较为高级。根据需要提供了补充资源。
前提条件: 我强烈建议理解关于…的基本概念。
贝塞尔修正:为什么在样本方差中我们除以 n−1 而不是 n?
理解总体方差的无偏估计
·发表于 Towards Data Science ·阅读时间 8 分钟 ·2024 年 11 月 11 日
--
在统计学中,许多学习者常常困惑的一个问题是,为什么在计算样本方差时要除以 n−1,而不是直接使用样本中的观察数 n。这个选择看起来很小,但它是一个重要的调整,用来修正从样本估计总体方差时出现的自然偏差。我们用简单的语言通过一些例子来讲解,为什么除以 n−1(称为贝塞尔修正)是必要的。
贝塞尔修正的核心概念是我们倾向于修正我们的估计,但一个明确的问题是,我们估计的是什么?因此,通过应用贝塞尔修正,我们倾向于修正从假设的样本均值计算出的偏差估计值,而我们的假设样本均值很少与实际的总体均值完全一致,因此可以安全假设,在 99.99%的情况下(在实际中甚至更多),我们的样本均值不会等于总体均值。我们所有的计算都是基于这个假设的样本均值进行的,即我们通过该样本的均值来估计总体参数。
继续阅读博客下文,你会直观地明白为什么在所有 99.99%的情况下(除了那个样本均值=总体均值的情况),我们倾向于低估实际偏差,因此为了弥补这个低估误差,除以比’n’更小的数就能解决问题,因此除以 n-1 而不是 n,正是为了弥补在计算样本均值偏差时的低估。
从这里开始阅读,你最终会理解……
样本方差与总体方差
当我们拥有整个数据总体时,方差是通过首先计算均值(平均值),然后确定每个数据点偏离该均值的程度,平方这些偏差,求和,最后除以 n(总体中的数据点总数)来计算的。这给我们的是总体方差。
然而,如果我们没有整个总体的数据,而是仅仅使用一个样本,我们需要估算总体方差。但问题在于:当只使用样本时,我们不知道真实的总体均值(记作μ),所以我们使用样本均值(x_bar)代替。
低估问题
要理解为什么在样本情况下我们除以 n−1,我们需要仔细观察使用样本均值而非总体均值时会发生什么。对于实际应用来说,依赖样本统计是我们唯一的选择。下面是其工作原理:
当我们计算样本方差时,我们会找到每个数据点与样本均值的偏差,平方这些偏差,然后取这些平方偏差的平均值。然而,样本均值通常不等于总体均值。由于这种差异,使用样本均值倾向于低估总体的实际分布或方差。
让我们通过所有可能发生的情况(共三种不同情况)来分析,我会详细讲解第一种情况,其他两种情况遵循相同的原理,详细解析已给出第一种情况。
1. 当样本均值小于总体均值(x_bar < 总体均值)
如果我们的样本均值(x_bar)小于总体均值(μ),那么样本中的许多数据点将比它们距离μ更接近 x_bar。因此,距离(偏差)均值的距离平均来说会更小,从而导致更小的方差计算。这意味着我们低估了实际的方差。
以下图的解释——较小的正态分布是我们的样本,而较大的正态分布是我们的总体(在上面的情况中,x_bar < 总体均值),图形如下所示。
由于我们只能收集样本中的数据点,无法收集总体的所有数据点,因为那是不可能的。在我们的样本数据点中,从负无穷大到 x_bar 和总体均值的中点,样本点与总体均值之间的绝对差异或平方差(偏差)会大于样本点与样本均值之间的绝对差异或平方差,在中点右侧直到正无穷大,基于样本均值计算的偏差会大于基于总体均值计算的偏差。下图显示了上述情况的区域,由于正态曲线的对称性,我们可以肯定地说,低估区间会大于高估区间,这也在下图中得到了突出显示,这导致偏差的总体低估。
因此,为了弥补低估,我们将偏差除以一个小于样本大小 ’n’ 的数字,即 ‘n-1’,这被称为贝塞尔修正。

由 Python 代码使用 matplotlib 库生成的图,图片来源(作者)
2. 当样本均值大于总体均值时
如果样本均值大于总体均值,我们会遇到相反的情况:样本中较低端的数据点会比 μ 更接近 x_bar,仍然导致方差的低估。
基于以上的细节,很明显在这种情况下,低估区间也大于高估区间,因此我们在这种情况下也会通过将偏差除以 ‘n-1’ 而不是 n 来弥补这种低估。

由 Python 代码使用 matplotlib 库生成的图,图片来源(作者)
3. 当样本均值与总体均值完全相等时(0.000001%)
这种情况比较少见,只有当样本均值与总体均值完全对齐时,我们的估计才是无偏的。然而,这种对齐几乎从未偶然发生,因此我们通常假设我们在低估。
显然,计算的样本点相对于样本均值的偏差与计算的相对于总体均值的偏差完全相同,因为样本均值和总体均值是相等的。这将不会产生低估或高估区间。

由 Python 代码使用 matplotlib 库生成的图,图片来源(作者)
总之,任何样本均值 x_bar 和总体均值 μ 之间的差异(几乎总是发生的)都会导致我们低估方差。这就是为什么我们需要通过除以 n−1 来进行修正,以弥补这种偏差。
为什么除以 n−1 可以纠正这种偏差:贝塞尔修正
除以 n−1 称为贝塞尔修正,它补偿了样本方差中自然的低估偏差。当我们除以 n−1 时,实际上是在做一个小的调整,使我们的方差估计更接近真实的总体方差。
这一切也可以与自由度相关联,从自由度的角度理解这些知识是需要一定自由度知识的。
在样本中,一个自由度被“消耗”用于计算样本均值。这使得我们只剩下 n−1 个独立的数据点来提供关于方差的信息,这也是为什么我们要用 n−1 而不是 n 来除。
为什么这种调整在小样本中更为重要?
如果我们的样本量非常小,除以 n 与除以 n−1 之间的差异会变得更加显著。例如,如果你的样本量是 10:
-
如果除以 n,意味着除以 10,这样可能会大大低估方差。
-
除以 n−1 或 9,能够提供更好的估计,弥补小样本带来的偏差。
但是如果你的样本量很大(比如 10,000),除以 10,000 或 9,999 的差异微乎其微,因此贝塞尔修正的影响也就非常小。
不使用贝塞尔修正的后果
如果我们不使用贝塞尔修正,我们的样本方差通常会低估总体方差。这可能会产生连锁反应,尤其是在统计建模和假设检验中,准确的方差估计对于得出可靠结论至关重要。
例如:
-
置信区间:方差估计会影响围绕样本均值的置信区间宽度。低估方差可能导致置信区间过窄,从而给人一种过于精确的错误印象。
-
假设检验:许多统计检验,如 t 检验,依赖于准确的方差估计来确定观察到的效应是否显著。低估方差可能会让我们更难发现真正的差异。
为什么不除以 n−2 或 n−3?
选择除以 n−1 并非随意。虽然我们这里不会深入证明,但这一选择是基于数学理论的。从样本计算总体方差时,除以 n−1 能够提供无偏的估计。其他调整,比如 n−2,会过度修正并导致方差的高估。
一个实际的例子来说明贝塞尔修正
假设你有一个小的群体,平均体重为 70 公斤。现在假设你从这个群体中抽取了 5 个人作为样本,他们的体重(单位:公斤)分别是 68、69、70、71 和 72。样本均值恰好是 70 公斤——与总体均值完全一致,这是巧合。
现在假设我们计算方差:
-
如果不使用贝塞尔修正:我们将把平方偏差的总和除以 n=5。
-
使用贝塞尔修正时:我们除以 n−1=4。
以这种方式使用贝塞尔修正会稍微增加我们对方差的估计,使其更接近如果我们从整个总体计算方差时的值,而不是仅仅从样本中计算。
结论
在计算样本方差时使用 n−1 进行除法,可能看起来只是一个小变化,但它对于获得无偏的总体方差估计至关重要。这个调整,即贝塞尔修正,弥补了由于依赖样本均值而不是总体均值导致的低估问题。
总结:
-
使用 n−1 补偿了我们基于样本均值计算方差的事实,而样本均值往往低估了真实的变异性。
-
当样本量较小时,贝塞尔修正尤为重要,因为直接用 n 除法会显著扭曲方差估计。
-
这一做法是统计学中的基础,影响从置信区间到假设检验的各个方面,是可靠数据分析的基石。
通过理解和应用贝塞尔修正,我们确保我们的统计分析能够反映我们研究数据的真实特征,从而得出更准确、更可信的结论。
AIML 产品 UX 最佳实践
本博客文章描述了在 AIML 用户体验(UX)中“优质表现”的实践,提供了示例,并为产品负责人指明了前进的路径。
·发表于Towards Data Science ·9 分钟阅读·2024 年 5 月 17 日
--
近年来,关于人工智能与机器学习(AIML)模型的研究投入了大量的时间、精力、泪水和资源,关注点集中在模型的规模与性能、快速发展、训练成本、安全性、延迟以及云端、本地和边缘计算中的各种模型托管选择。然而,有一个被忽视的领域就是最后一公里的产品用户体验(UX),以及如何将 AIML 最佳地融入产品中。
本博客文章描述了在 AIML 用户体验中“优质表现”的几项实践,提供了参考示例,并为正在塑造其产品和业务未来的产品经理和负责人指引了前进的道路。我还从独特的角度将这些实践与亚马逊领导原则(LP)以及由 Don Norman 博士和 Steve Krug 所阐述的 UX 原则相结合,他们分别是《日常物品的设计》和《别让我思考》的作者。这些原则和实践共同为理解在构建 AIML 解决方案时推荐特定 UX 操作和思路的原因提供了坚实的基础。
注意力就是你所需要的一切。

[作家在打字机前工作],稳定性 AI
无论是否与人工智能相关,产品特性不会产生影响,除非它们以一种方便、易用、易发现并且显而易见的方式呈现给用户。可见性在用户体验设计中至关重要,它与亚马逊的领导力原则创新与简化相一致,这表明这些特性应该直接展现于注意力和意识的表面,而无需过多的麻烦、摸索、寻找或挣扎。无论是内容、对话还是联系人,在正确的时间将正确的信息传递给正确的人是至关重要的。那么现在,让我们看看什么是优秀的表现。
例如,Google Workspaces将生成式 AI 功能直接嵌入到 Docs、Mail 和其他 Google 应用的用户体验中,使得用户无论是在撰写新文档,还是总结现有文档(无论是简短的备忘录、较长的论文还是动态的演示文稿)时,都能更加便捷。在 Google Docs 的情况下,内容会在提示下方生成,并且也会清楚标明该内容是由 AIML 模型生成的,标识为星号(*******)符号。这个表示由 AIML 生成的标识将在示例中频繁出现。

[Gemini for Google Workspaces],截图由作者提供
另一个例子是Slack AI,它允许你个性化搜索结果,并总结频道活动和长篇复杂的对话线程。在这种情况下,生成的内容会在右侧的紧凑面板中显示,作为丰富主面板的上下文细节。这种方式方便用户,不会造成信息过载,同时也清晰标明为 AI 生成的内容。

[Slack AI],截图由作者提供
最后,亚马逊现在生成并展示了产品评论摘要。根据亚马逊的说法,仅在 2022 年,1.25 亿客户为 Amazon.com 贡献了近 15 亿条评论和评分;某些单一产品的评论数达到数千条。因此,为了帮助客户了解某个产品是否适合他们,Amazon.com 在产品详情页面提供了一个简短的段落,突出显示了评论中频繁提到的产品特性和客户情感。AI 生成的评论还强调了关键的产品属性,并帮助客户更容易地查看提到这些属性的具体人类评论。同样,AI 生成的内容紧凑,清晰标明为机器生成,并且某些方面使用了图标进行视觉标识。

[Amazon.com 产品评论],截图由作者提供
你可以引用我说的话

[报纸编辑与记者室],Stability AI
在报业界,有一句老话:“如果你想让读者相信你,那么就考虑让你的来源‘公开表态’,并获得他们的同意,这样‘你就可以引用他们的话’。”在学术和科学写作中,引用书目来源是一种常见做法。这种做法的目的是赢得信任,这是 Amazon 领导原则之一,也是帮助性,另一个 UX 原则。通过一些其他人如何提供引用的良好示例,最好来说明这一做法。
我向Perplexity.ai,一个基于 AWS 构建的 AI 驱动的对话式搜索服务,询问了:“请帮我排查一个 HAProxy 安装问题,它无法在多个目标节点间正确进行负载均衡,而且仅错误地缓存了一个 IP 地址。请提供详细的逐步指导。”请注意,在下面的响应中,顶部列出了多个来源,并且在引用这些来源时,响应中散布了数字引用(例如 1)。

[Perplexity AI],作者截图
接下来,我与Amazon Q互动,并提出了这个问题:“我计划创建一个每日日请求量为 100k 的无服务器 API。每个请求都需要从数据库中读取数据。对于这种工作负载,最适合的服务是什么?”再次,服务回答了我的问题,并提供了来源,但它们附加在响应的末尾,尽管按数字排序,但实际上这些参考文献并没有在响应中明确标明,读者无法直接获得。 我还赞赏清晰的链接到负责任的 AI 政策,这是建立信任的另一种方式。哪种方法更易于阅读和理解,取决于你特定应用的上下文。引用是否有助于澄清?它们会分散注意力吗?它们最适合放在哪里?

[Amazon Q],作者截图
保持对话持续进行

[Amazon Rufus],作者截图
从静态用户界面到动态对话界面,正在出现一股UX 巨型趋势。传统的图形用户界面,无论是位于主机、桌面、网页还是移动设备上,通常由窗口、标签、面板、表单、网格、复选框、文本框、列表选择器和按钮组成。虽然这些UX 模式在让更多人类与计算机互动方面取得了巨大成功,但我们必须认识到,我们曾经需要自己学习如何与计算机工作。
在过去的一代人中,曾有人努力建立语音激活的用户界面,且已经取得了一些商业成功,如亚马逊 Alexa、苹果 Siri 和谷歌助手。然而,这种对话式用户体验趋势现在正在改变主流应用程序,超越了大科技公司,进入到消费者和商业环境中。数字聊天机器人和助手是最明显的例子,上面展示的亚马逊鲁弗斯零售助手展示了从传统搜索到对话的转变。在这种情况下,鲁弗斯了解消费者、亚马逊产品目录,并且对整个世界有一定的认知模型,能够提供推荐、向用户介绍产品、比较不同品牌和类别的产品等等。一些对话设计原则包括1/提供具有上下文的提示以引导对话,2/赋予个性化角色为文本增添情感色彩,以及3/分享历史记忆以建立关系,并基于过去的使用建立信任。更需要注意和小心的是验证哪些提示变量是强制性的,哪些是可选的。你可能还需要提示用户捕捉更多信息,并在执行更多操作之前确认即将进行的操作。此外,你还需要考虑到对话中的停顿和误解。
尽管如此,数字化对话已经成为一种常态,生成式 AI 已经让企业能够将这项技术融入到他们的产品或服务中。这一保持对话的实践与亚马逊的核心领导力原则客户痴迷和用户体验原则灵活性相一致。
通过循环向用户学习。

[加利福尼亚州英约的莫比乌斯环岩石形成], 维基共享资源/公有领域
从用户那里学习对于与商业成果挂钩的产品迭代至关重要。这一实践与亚马逊领导力原则学习与好奇以及用户体验原则反馈相一致。你无法了解用户脑中的所有信息,因此需要向他们询问。跟踪并衡量这些信息,然后将收集到的数据与结果对齐。考虑优化电子商务购物车体验,以增加订单量并减少待处理购物车,也称为转化率。将网站上的时间和点击次数作为用户参与度的代理指标。当你开始使用任何新系统时,也要进行反思。首次点击、步骤和引导过程中的前几分钟至关重要,能够最大化任务成功率并减少错误率。现在,将这些想法放到生成性 AI 的背景下,这在商业和消费者场景中都是相对较新的产品。一些常见的做法包括“喜欢/不喜欢”和“分享支持”操作。
Copilot for Microsoft 365 展示了这一实践,提供了一流的“喜欢 / 不喜欢”按钮。另一个值得称赞的方面是清晰的斜体信息:“AI 生成的内容可能不正确”。这也增强了信任感。

[Microsoft 365 CoPilot],截图由作者提供
Einstein Copilot for Salesforce 同样在拥有“喜欢/不喜欢”操作对方面做得很好,清晰地在右侧面板标注总结为 AI 生成,并区分了接受与草稿,使输出意图对用户更明确。

[Salesforce Einstein CoPilot],截图由作者提供
拥有随机森林的地图。

[随机森林的地图],Stability AI
引用路易斯·卡罗尔(《爱丽丝梦游仙境》的作者)的话,如果你不知道要去哪里,那么任何一条路都能带你到那里。支持这些 AIML 核心用户体验实践的,是传统的产品路线图,而且没有捷径可以让你快速到达这幅宏大的蓝图。支持这一实践的原则包括亚马逊的领导力原则、宏大思维以及用户体验的一致性原则。首先要识别你的目标用户、产品使用场景和他们的目标。然后倒推回去。在亚马逊,我们通常会起草一个PRFAQ文档,设想未来应该是什么样子,并从外部公关(PR)和内部常见问题(FAQ)详细角度来思考。考虑实现客户目标所需的高层解决方案信息和流程。接着,深入挖掘用户体验层以及其他技术层的解决方案组件。根据风险、投资回报率、必须实现与可选功能等因素,对解决方案和组件的里程碑进行对齐和优先级排序。确保产品、工程和销售团队达成共识。然后根据路线图执行,前提是沿途可能会做出一些调整。
结论
在这篇博客中,我们介绍了五个 AIML 产品用户体验实践:注意力就是你所需要的一切(可见性/简化)、你可以引用我说的这句话(赢得信任/有帮助性)、保持对话进行(客户痴迷、灵活性)、通过循环从用户中学习(学习与好奇心/反馈)、以及为随机森林绘制一张地图(宏大思维/创造力)。我们还讨论了来自顶级科技公司(包括亚马逊、谷歌、微软、Perplexity、Salesforce 和 Slack)的多个实际产品示例,以说明这些实践和原则。希望你能学到一些新东西,并能立即将其应用到自己的产品和解决方案中。勇敢前行,开始构建吧!
免责声明:请注意,本文内容仅代表我个人观点,并不一定代表我所在公司的立场。
喜欢这篇文章吗?请分享您的评论。关注我,了解更多更新,您可以在 Medium 和 Twitter 上找到我。
参考文献
-
www.aboutamazon.com/news/amazon-ai/amazon-improves-customer-reviews-with-generative-ai -
www.microsoft.com/en-us/microsoft-365/enterprise/copilot-for-microsoft-365
数据库设计中的技术列最佳实践
在设计事务性数据库或数据仓库时,重要的是不要忽视各种类型的技术列。
·发布于 Towards Data Science ·15 分钟阅读·2024 年 5 月 11 日
--

数据库设计中的技术列 | 图片由 DALL·E 生成,效果由作者处理
技术列在数据库设计中发挥着至关重要的作用。它们服务于多种目的:从与审计相关的事务、支持故障排除,到对 ETL/ELT 过程设计的关键影响。然而,它们常常被忽视或边缘化。在本文中,我汇集了基于我在多个数据库相关项目中的经验所获得的建议——包括网页应用程序、数据仓库、科学项目等。无论你是经验丰富的数据工程师、正在完善数据库规范的开发人员,还是有抱负的数据架构师,这些见解都将帮助你设计和构建更强大、更可靠的数据库系统。
在深入讨论具体的技术列细节之前,记得最好先在数据库规范中记录它们。你可以在规范的专门部分中一次性为所有表进行记录,而不是为每个表重复相同的描述。
记得为每个技术列指定:
- 它的名称——应在每个适用的表中一致使用,并符合定义的命名规则……
使用生存分析改善 A/B 测试
进行实验时——别忘了带上你的生存工具包
·发布于 Towards Data Science ·阅读时间:5 分钟·2024 年 7 月 31 日
--

图片来源:作者,使用 DALL-E 3
我在几篇博客中已经阐述过这个观点(这里、这里 和 这里),即使用生存分析可以改善流失预测。
在这篇博客中,我将展示生存分析如何在常见的实践中发挥作用的另一个案例:A/B 测试!
常见 A/B 测试实践中的问题
通常在进行 A/B 测试时,分析师会将用户随机分配到不同的变体,并根据每个变体中的转化数量与用户数量的比例来衡量转化率。刚进入测试的用户和在测试中已有两周的用户会被赋予相同的权重。
这种方法对于那些在分配到变体后短时间内发生转化或不发生转化的情况(例如完成入职流程)是足够的。
然而,很多时候转化是分布在较长的时间框架内。例如,第一次访问网站着陆页后的首次订单。这样的转化可能在几分钟内发生,但很大一部分可能会在首次访问后几天发生。
在这种情况下,业务 KPI 通常会被“限定”在某个特定的时间段内——例如,“7 天内的转化”或“1 个月内的流失”。
在这些情况下,不考虑时间因素地衡量转化率有两个主要缺陷:
-
它使我们所测量的统计量变得难以理解——任何时间点的平均转换率无法转换为任何有界的度量。事实上,随着测试的继续,转换率将会增加,仅仅是因为用户有更多的时间进行转换。因此,实验结果将难以与业务 KPI 关联。
-
它丢弃了可能影响结果的时间信息,相比于考虑了转换时间的其他方法,这可能导致功效降低。
为了展示第 2 点,我们将运行一个小的模拟研究
我们将让用户在 30 天内随机加入实验。用户的转换时间将从 Weibull 分布中模拟,规模参数𝜎=30,000,控制组的𝛼_ctrl=0.18,处理组的𝛼_trt=0.157。
下面是对应的生存曲线:
alpha_ctrl <- 0.18
alpha_trt <- 0.157
sigma <- 30000
conv_7d_ctrl <- format_pct(pweibull(7, alpha_ctrl, sigma))
conv_7d_trt <- format_pct(pweibull(7, alpha_trt, sigma))
t <- seq(0, 7, 0.1)
surv_ctrl <- 1 - pweibull(t, alpha_ctrl, sigma)
surv_trt <- 1 - pweibull(t, alpha_trt, sigma)
plot(t, surv_trt, type = "line", col = "red", ylab = "S(t)", xlab = "t (days)",
ylim = c(0.7, 1)) lines(t, surv_ctrl, col = "black")
legend("topright", col = c("black", "red"),
legend = c("Control", "Treatment"), lty = 1, title = "Variant" )

图片来源:作者
假设我们关注的是 7 天内的转换,控制组的真实(未知)转换率为 19.9%,而处理组的转换率为 23.6%。
下面是生成模拟数据的函数:
n <- 2000
test_duration <- 30
gen_surv_data <- function(m, alpha){
set.seed(m)
tstart <- runif(n, 0, test_duration)
tconvert <- rweibull(n, alpha, sigma)
status <- as.integer(tstart + tconvert < test_duration)
tstatus <- ifelse(status == 0, test_duration - tstart, tconvert)
return(data.frame(tstatus=tstatus, status=status))
}
为了展示在 A/B 测试中使用生存分析的好处,我们将比较 3 种检验统计量的功效:
-
转换的 T 检验(常见程序)
-
基于 7 天转换的 T 检验(使用 Kaplan-Meier 曲线估算)
-
Peto & Peto 修改版的 Gehan-Wilcoxon 检验
下面是实现上述内容的代码:
run_simulation <- function(m, alpha1, alpha2){
data_1 <- gen_surv_data(m, alpha1)
data_2 <- gen_surv_data(m+1, alpha2)
# T-test on conversions (the common procedure):
p1_hat <- mean(data_1$status)
p1_var <- p1_hat*(1-p1_hat)/length(data_1$status)
p2_hat <- mean(data_2$status)
p2_var <- p2_hat*(1-p2_hat)/length(data_2$status)
stat <- abs(p2_hat - p1_hat)/sqrt(p1_var + p2_var)
ans1 <- pnorm(stat, lower.tail = F)*2
# T-test on 7 day conversion (estimated using a Kaplan-Meier curve):
data_1$variant <- "control"
data_2$variant <- "treatment"
surv_data <- rbind(data_1, data_2)
surv_model <- summary(survfit(Surv(tstatus, status)~variant,
data = surv_data), times = 7, extend = T)
p1_hat <- 1 - surv_model$surv[1]
p1_var <- surv_model$std.err[1]²
p2_hat <- 1 - surv_model$surv[2]
p2_var <- surv_model$std.err[2]²
stat <- abs(p2_hat - p1_hat)/sqrt(p1_var + p2_var)
ans2 <- pnorm(stat, lower.tail = F)*2
# Peto & Peto modification of the Gehan-Wilcoxon test:
mgw_test <- survdiff(Surv(tstatus, status)~variant, data = surv_data,
rho = 1)
ans3 <- mgw_test$pvalue
return(data.frame(
`T-test conversions` = ans1,
`T-test KM 7 day conversion` = ans2,
`Modified Gehan-Wilcoxon test` = ans3, check.names = F))
}
在衡量功效之前,我们需要验证我们的统计量在两个变体转换率相同的情况下,是否满足期望的假阳性率𝛼=0.05(5%):
alpha <- 0.05
M <- 500
res <- Reduce("rbind", lapply(1:M, function(m)
run_simulation(m, alpha_ctrl, alpha_ctrl)))
res <- data.frame(Statistic = names(res),
`False positive rate` = format_pct(sapply(res, function(x) mean(x<=alpha))),
check.names = F, row.names = NULL)
knitr::kable(res, align = "c")

图片来源:作者
接下来,我们来检查功效:
M <- 2000
res <- Reduce("rbind", lapply(1:M, function(m)
run_simulation(m, alpha_ctrl, alpha_trt)))
res <- data.frame(
Statistic = names(res),
Power = sapply(res, function(x) mean(x<=alpha)),
check.names = F, row.names = NULL)
uplift_logrank <- format_pct((res[3,2] - res[1,2])/res[1,2])
uplift_km <- format_pct((res[2,2] - res[1,2])/res[1,2])
res$Power <- format_pct(res$Power) knitr::kable(res, align = "c")

图片来源:作者
虽然基于 Kaplan-Meier 曲线的 7 天转换 T 检验比常规的转换 T 检验(常见程序)更能与业务 KPI 相关,但其功效提升仅为边际性。
另一方面,修改版的 Gehan-Wilcoxon 统计量显著提升了功效,同时仅与业务 KPI(如常规的转换 T 检验)有较弱的相关性。
需要注意的是,功效提升的程度会根据在生存曲线上比较的点、实际的生存曲线形状、实验持续时间等因素有所不同。
在未来的文章中,我希望在更广泛的场景和测试统计量下进一步探讨这个话题(R 的ComparisonSurv包看起来很有前景)。
在进行 A/B 测试时,如果转换时间有所不同——通常应用生存分析以利用时间维度是非常有用的。可以比较生存曲线上的某个兴趣点,使结果能够直接与业务 KPI 相关,或者使用修改版的 Gehan-Wilcoxon 统计量以提高功效。
最初发布于 https://www.linkedin.com.
更好的可视化,先进的 ETL 技术,RAG 痛点及其他二月必读文章
·发布于Towards Data Science ·发送至时事通讯 ·阅读时长 4 分钟·2024 年 2 月 29 日
--
二月可能是最短的一个月,但在 TDS 这里可一点也不觉得短,我们的作者们始终保持高水平,分享了关于及时话题的强有力贡献——其中包括今年迄今为止最长且最受欢迎的文章之一。
现在,大多数人已经适应了 2024 年的节奏,我们看到读者的关注点从职业发展稍微转向了核心技能以及解决常见问题的具体方案。过去一个月我们最读和讨论的文章反映了这一点,以下是我们二月的代表性精选。
每月亮点
-
Adam 优化器背后的数学原理在一篇清晰、易懂且广泛分享的解释文章中,Cristian Leo详细解析了 Adam(自适应矩估计)优化器的数学内部机制,并在此过程中帮助我们理解了它为何成为深度学习实践者中如此受欢迎的选择。
-
12 个 RAG 痛点及建议解决方案 虽然检索增强生成(RAG)作为提升 LLM 性能的强大选项持续引起关注,但其不足之处也越来越明显。Wenqi Glantz 提供了一个有用的资源,适合那些最近在实现 RAG 系统时感到困惑的人,汇总了 12 个常见的陷阱以及建议的解决办法。
-
数据可视化 101:引人注目的视觉效果手册 对于任何想要创造“更清晰、更锐利、更智能”的视觉效果的人——而且谁不想呢?——Mariya Mansurova的最新数据可视化指南是必读之选,它通过多个具体示例(使用 Plotly)展示了设计原则的实际应用。

摄影:来自Kelly Sikkema的照片,来自Unsplash
-
初学者的高级 ETL 技术 如果你是一个处于初期阶段的数据工程师,想要提升你的数据摄取技能,💡Mike Shakhomirov的新教程是你绝对应该探索(并收藏)的内容:它涵盖了典型的摄取模式,并提供了可以用来开始自己动手实验的代码片段。
-
高级检索增强生成:从理论到 LlamaIndex 实现 想深入了解激动人心的 RAG 世界吗?Leonie Monigatti 解释了检索前、检索和检索后优化的细节,然后带我们了解如何将一个“天真”的 RAG 管道转变为一个先进的版本。
-
RAG 失败的顶级评估指标 本周我们再次聚焦 RAG,这次是Amber Roberts的最新贡献:一个关于排查意外或不尽人意的性能问题的实用资源,并且介绍了如何应用强大的响应和检索评估指标,确保管道中的所有环节协调工作。
-
在 2024 年构建数据平台在首次讨论这一话题三年后,我们很高兴迎回了Dave Melillo,他的新文章重新评估了高效数据平台的关键组成部分。他基于自己在多个行业中应对数据挑战的经验,分享了宝贵的见解,并曾与“大型企业和灵活的初创公司”合作过。
额外的 Python 知识
过去几周我们发布的部分最受欢迎的文章,涉及了始终不过时的主题——针对数据和机器学习专业人士的 Python 编程。如果你错过了这些内容:
-
作为完全初学者,应该如何学习 Python?Egor Howell提供了一个清晰而实用的路线图。
-
如果你还不熟悉@property 装饰器,在你读完Siavash Yasini的全面介绍后,你一定会了解它。
-
对 AI 应用开发感兴趣的朋友,应该看看Naomi Kriger的实践教程,学习如何使用 pyttsx3 库创建一个语音转文字再转语音的程序。
-
以 Robert C. Martin 的经典著作Clean Code为灵感,Patrick Brus概述了编写——你猜对了——简洁高效的 Python 代码的核心原则。
-
欲了解更多 Python 教程和项目实操,千万不要错过我们的近期高级及特殊应用案例汇总。
我们最新的一批新作者
每个月,我们都很高兴看到一批新作者加入 TDS,他们每个人都在与我们的社区分享自己独特的声音、知识和经验。如果你正在寻找新的作家来探索和关注,只需浏览我们最新加入的作者的作品,包括Sarthak Handa、Vadim Arzamasov、Mahyar Aboutalebi, Ph.D. 🎓、James W、Mohammed Mohammed、Kirsten Jiayi Pan、Matthew Chak、Ugur Yildirim、Mikayil Ahadli、Hamza Gharbi、Sami Abboud、Matthew Gunton、Eivind Kjosbakken、Eva Revear、Nithhyaa Ramamoorthy、Rami Krispin、Kennedy Selvadurai, PhD、Vassily Morozov、Patrick Beukema、Thomas Rouch、Ritanshi Agarwal、Rohan Nanda、Nikolaus Correll、Mert Ersoz、Dani Lisle、Roberta Rocca、Adil Rizvi、Matthew Turk、Celia Banks, Ph.D.、Skylar Jean Callis、Ryan McDermott、Anand Subramanian、Aayush Agarwal、P.G. Baumstarck、Jose D. Hernandez-Betancur、Khin Yadanar Lin、和Daniel Kang,等等。
感谢你支持我们作者的工作!如果你感到受启发并想加入他们的行列,为什么不写下你的第一篇文章?我们很想阅读。
直到下一个变量,
TDS 团队
超越 AlphaFold:LLM 在医学中的未来
|ALPHAFOLD3|LLMs|LLMs 与医学|
AlphaFold 留下了复杂的遗产:LLM 在生物学和医学中的未来将会怎样?
·发表于 Towards Data Science ·14 分钟阅读·2024 年 6 月 13 日
--

图片由作者使用 AI 制作
在真正的开源中,你有权掌控自己的命运。 — 林纳斯·托瓦兹
AlphaFold3 到来,但并没有像 DeepMind 预期的那样获得如此热烈的欢迎。无论如何,这是生物学面临的革命的又一章节。LLM 正在革命化医学和制药领域。在 AlphaFold2 发布三年后,计算生物学已经发生变化,是时候进行一些反思了。
-
为什么 AlphaFold3 标志着一个双刃剑的转折点?
-
为什么研究界如此失望?
-
风向如何变化?社区如何回应?
-
LLM 在生物学中的未来会怎样?
我们在本文中讨论这个问题。
AlphaFold3 的到来
当 AlphaFold2 发布时,它看起来像是一次革命的曙光。近一个世纪以来,从...预测蛋白质结构。
超越注意力机制:先进的位置嵌入方法如何在 Transformer 架构中改进原始方法
从正弦波到 RoPE 和 ALiBi:先进的位置编码如何克服 Transformer 中的局限性
·发布于Towards Data Science ·阅读时间:9 分钟·2024 年 10 月 29 日
--
作者: Elahe Aghapour, Salar Rahili
引言:
近年来,模型的指数级进展与 Transformer 架构的出现密切相关。以前,人工智能科学家需要为每个任务选择架构,然后优化超参数以获得最佳性能。另一个限制其潜力的挑战是处理数据长程依赖性的困难,导致了梯度消失、长序列中上下文丢失的问题,以及由于局部性约束无法捕获全局上下文。此外,传统模型缺乏可扩展性和并行化,减缓了大数据集的训练进程,阻碍了该领域的进展。
Transformer 架构通过其自注意力机制解决了这些问题,彻底改变了这一领域。它使得模型能够捕捉长序列之间的关系,并高效地理解全局上下文,同时具有高度的并行化能力,并能在各种模态下适应,例如文本、图像等。在自注意力机制中,对于每个令牌,它的查询与所有其他令牌的键进行比较,以计算相似度得分。这些相似度然后用于加权值向量,最终决定当前令牌应该关注哪里。自注意力将所有令牌视为同等重要,而不考虑它们的顺序,丧失了关于令牌出现顺序的关键信息,换句话说,它将输入数据视为没有顺序的集合。现在,我们需要一种机制来强制在数据中施加某种顺序的概念,因为自然语言和许多其他类型的数据本质上是顺序性的和位置敏感的。这就是位置嵌入发挥作用的地方。位置嵌入编码了每个令牌在序列中的位置,使模型能够保持对序列结构的意识。已经探索了多种编码位置信息的方法,我们将在本博客文章中讨论这些方法。

由 DALL-E 生成的图像
注意力机制:
设S = {wi},其中i = 1, …, N,表示一个由N个输入令牌组成的序列,其中wi表示第i个令牌。因此,S的对应令牌嵌入可以表示为E = {xi},其中i = 1, …, N,xi是第i个令牌wi的d维令牌嵌入向量。自注意力机制将位置嵌入融入令牌嵌入,并生成查询、键和值的表示形式,如下所示:

然后,注意力权重根据查询和键向量之间的相似度计算:

注意力权重决定了令牌n对令牌m的重要性。换句话说,就是令牌m应该给予令牌n多少注意力。令牌m的输出是作为值向量的加权和计算的:

因此,注意力机制使得令牌m能够从序列中的其他令牌收集信息。

图 1. Transformer 架构中的位置编码(图像来自论文)。
1. 绝对位置嵌入:
方程(1)的一个典型选择是:

其中pi是一个d维向量,表示令牌xi的绝对位置。正弦位置编码和学习的位置编码是生成pi的两种替代方法。
1.a 正弦位置编码
正弦位置编码在“Attention is all you need”论文中被引入,提出了 Transformer 架构。正弦位置编码为输入序列中的每个 token 提供了一个独特的位置表示。它基于具有不同频率的正弦和余弦函数,如下所示:

其中,pos是 token 在序列中的位置,d是位置嵌入的维度,i 是维度索引(0<=i<d)。
正弦位置编码中使用的正弦和余弦函数与傅里叶变换有着深刻的关系。通过使用不同频率的范围来编码位置,Transformer 创建了一个类似于傅里叶变换的表示,其中:
-
高频成分(较低的i)使模型能够捕捉到局部位置的关系。这对于理解序列中相邻 token 之间的关系非常有用,例如词对。
-
低频成分(较高的i)捕捉整个序列中的更全局的模式。这有助于模型关注可能相距较远的 token 之间的广泛关系,例如两句话之间的依赖关系。
这有助于模型通过比较位置编码来理解 token 之间的相对位置。正弦位置编码不需要额外的训练参数,并且在推理时能泛化到更长的序列长度。然而,它的表现力有限。
1.b 学习型位置编码
学习型位置编码在“Attention is all you need”论文中被引入,并在BERT和GPT模型中作为正弦位置编码的替代方案。对于学习型位置编码,序列中的每个位置(例如,第一个 token,第二个 token 等)都会分配一个嵌入向量。这些位置嵌入在训练过程中与其他 Transformer 参数一起学习。例如,如果模型的上下文长度为 512,token 嵌入的维度为 768(即d=768),则会将一个大小为 512*768 的可学习张量加入到其他可训练参数中。这意味着模型逐渐学习如何为特定任务(如文本分类或翻译)编码位置信息。
学习型位置嵌入比正弦位置嵌入更具表现力,因为模型可以学习一个特定任务有效的位置嵌入。然而,它们引入了更多的可训练参数,从而增加了模型的大小和计算成本。
2. 相对位置嵌入
正弦型和学习型位置编码侧重于标记的绝对位置。然而,注意力机制是通过计算其他标记对每个特定标记的重要性来工作的。因此,这一过程依赖于标记的相对位置(它们之间的距离),而不是标记的绝对位置。为了解决绝对位置嵌入的局限性,引入了相对位置编码。
RelativePosEmb 不将位置信息添加到标记嵌入中。相反,它修改了每一层计算键(key)和值(value)的方法,如下所示:

这里,r = clip(m-n, Rmin, Rmax) 表示位置 m 和 n 之间的相对距离。最大相对位置会被裁剪,假设超出某个距离后的精确相对位置没有用处。裁剪最大距离使得模型在推理时能够外推,即能够推广到训练时未见过的序列长度。然而,这种方法可能会丢失一些来自标记绝对位置的有用信息(例如第一个标记的位置)。
你可能会注意到 fq 缺少位置编码。这是因为我们正在编码相对位置。在注意力公式中,查询(query)和值(key)用于计算注意力权重,如公式(2)所示,因此我们只需要查询或键之一来包含相对位置编码。
这种编码已在许多模型中使用,如 Transformer-XL 和 T5。在应用相对位置编码时,有不同的替代方法,可以在文献 [7] 和 [8] 中找到。
3. 旋转位置嵌入(RoPE)
与以往的方法不同,RoPE基于标记的位置,在多维空间中旋转向量。它不是将位置信息添加到标记嵌入中,而是修改了每一层计算注意力权重的方式,如下所示:

他们提出了一个通用的旋转矩阵,适用于任何偶数维度的嵌入 d,公式如下:

其中,θi 是预定义的:

将 RoPE 应用到注意力权重中得到:

请注意,RoPE的公式并未将位置信息添加到注意力模块的值中。注意力模块的输出是值向量的加权和,由于位置信息没有添加到值中,每个 Transformer 层的输出没有显式的位置信息。
像LLaMA和GPT-NeoX这样的流行模型正在使用RoPE。

4. 线性偏置注意力(ALiBi)
ALiBi也不将位置编码添加到词嵌入中;相反,它通过对注意力权重得分施加一个与标记之间距离成比例的惩罚。因此,两个标记 i 和 j 之间的注意力得分在每一层的计算方式如下:
Attention score = query_i . key_j — m.(i-j)
其中,-m.(i-j)是一个惩罚项,与标记i和j之间的距离成正比。标量m是一个特定于头部的斜率,在训练前就已固定,其值对于不同的头部会按照几何序列选择。例如,对于 8 个头部,m可能是:

这意味着,第一个头有相对较大的m值,因此它对远离的标记惩罚更多,专注于最近的标记,而第八个头有最小的m值,允许它关注更远的标记。图 2 也提供了可视化展示。
ALiBi 被用于BloombergGPT和BLOOM。
推理时的 Transformer 外推:
Transformer 推理时的外推能力指的是模型在输入序列比训练时的序列更长的情况下,仍然能够良好地表现。Transformer 机制对输入长度是不可知的,这意味着在推理时,它可以处理更长的序列。然而,需要注意的是,计算成本随着输入长度的增加呈二次增长,即使 Transformer 层本身对输入长度是不可知的。
ALiBi的作者证明了,transformer 外推的瓶颈在于其位置编码方法。如图 3 所示,他们比较了不同位置编码方法的外推能力。由于学习到的位置编码无法编码比训练长度更长的位置,因此它没有外推能力。

图 3:外推:随着输入序列变长(x 轴),正弦、RoPE和T5的位置编码表现出困惑度下降(y 轴,数值越低越好),而ALiBi则没有(图像来源于论文)。
图 3 显示了实际应用中,正弦位置嵌入的外推能力非常有限。尽管RoPE优于正弦嵌入,但仍然未能取得令人满意的结果。T5偏置方法(相对位置嵌入的一种形式)在外推方面优于正弦嵌入和RoPE嵌入。不幸的是,T5偏置在计算上开销较大(见图 4)。ALiBi在所有这些位置嵌入中表现最佳,且仅有轻微的(0-0.7%)内存增加。

图 4:批量训练、推理速度和正弦、RoPE、T5和ALiBi位置编码的比较(图片来自论文)。
结论:
总结:在 Transformer 架构中,位置编码方式显著影响其理解顺序数据的能力,尤其是在推理时的外推能力。虽然绝对位置嵌入方法提供了位置信息,但它们常常难以应对 Transformer 的外推能力。因此,提出了更新的位置嵌入方法。相对位置编码、RoPE 和 ALiBi 具备在推理时外推的能力。随着 Transformer 不断在各类应用中被集成,优化位置编码对于提升其性能边界至关重要。
本文中表达的观点仅代表我们个人的看法,并不反映我们雇主的立场。
参考文献:
1 Vaswani, A. “Attention is all you need.” (2017)。
[2] BERT:Devlin, Jacob. “Bert:用于语言理解的深度双向 Transformer 预训练。” (2018)。
[3] GPT:Radford, Alec 等. “语言模型是无监督的多任务学习者.” (2019)。
[4] RelativePosEmb:Shaw, Peter 等. “带有相对位置表示的自注意力.” (2018)。
[5] Transformer-XL Dai, Zihang. “Transformer-xl:超越固定长度上下文的注意力语言模型。” (2019)。
[6] T5:Raffel, Colin 等. “通过统一的文本到文本转换器探索迁移学习的极限。” (2020)。
[7] Raffel, Colin 等. “通过统一的文本到文本转换器探索迁移学习的极限。” (2020)
[8] He, Pengcheng, 等. “Deberta:解码增强的 BERT 模型,具有解耦注意力。”(2020 年)。
[9] RoPE: Su, Jianlin, 等. “Roformer:具有旋转位置编码的增强型变换器。”(2024 年)。
[10] LLaMA: Touvron, Hugo, 等. “Llama:开源且高效的基础语言模型。”(2023 年)。
[11] GPT-NeoX: Black, Sid, 等. “Gpt-neox-20b:一个开源自回归语言模型。”(2022 年)。
[12] ALiBi: Press, Ofir, 等. “训练短,测试长:具有线性偏置的注意力机制能够进行输入长度外推。”(2021 年)。
[13] BloombergGPT: Wu, Shijie, 等. “Bloomberggpt:用于金融的大型语言模型。”(2023 年)。
[14] BLOOM: Le Scao, Teven, 等. “Bloom:一个具有 1760 亿参数的开源多语言语言模型。”(2023 年)。
超越微调:合并专业化 LLM 而不增加数据负担
从模型融合到自动进化合并:利用专业化 LLM 融合以减少数据需求并消除大量的微调。
·发表于 Towards Data Science ·10 分钟阅读·2024 年 8 月 13 日
--
作者: Elahe Aghapour, Salar Rahili
引言:
计算机视觉和自然语言处理领域正在迅速发展,这导致对为特定下游任务微调的专业化模型的需求不断增长。然而,拥有多个不同的微调模型也带来了多个缺点:
- 每个任务都必须存储和部署一个单独的模型(这个问题可以通过应用像 LoRA 这样的微调方法来解决)。
2. 独立微调的模型无法利用来自相关任务的信息,这限制了它们在领域内和领域外任务中的泛化能力。然而,多任务学习需要访问每个特定任务的数据集,且整合这些数据集可能会变得复杂。如果我们无法访问所有下游任务的数据集,但有可用的微调模型,该怎么办?假设你需要一个在一组特定任务上微调的大型语言模型(LLM)。与其收集大量下游任务的数据集并经历资源密集的微调过程,你可以找到已在每个任务上微调的 LLM 模型,并将这些模型合并成所需的模型。需要注意的是,在庞大的 Hugging Face 仓库中,寻找这样的模型并不困难,该仓库托管着大约 50 万个微调模型。近年来,合并多个模型已受到广泛关注,主要是因为它要求计算资源少,且不需要训练数据。

图 1 模型集成通过结合多个模型的输出提高准确性,但需要更多的计算资源。多任务学习同时在多个任务上训练一个模型,要求访问所有数据集并具备高计算能力。而模型合并则将预训练的模型融合为一个,利用它们的优势,计算量小且无需额外的训练成本,提供了一种高效的解决方案(图片来源:论文)。
随着对模型合并的关注度不断增长,公共库如 WEBUI 和 MergeKit 也应运而生,以便简化这一过程。WebUIs 使得合并微调过的模型(如 Stable Diffusion)成为可能,且支持多种合并技术。MergeKit 是一个开源的集中式库,提供不同的合并方法。它通过高效的合并技术实现,支持在任何硬件上进行模型合并。
在这里,我们将合并方法分为三大类:
1. 合并具有相同架构和初始化的模型,
2. 合并具有相同架构但初始化不同的模型,
3. 合并具有不同架构的模型。
每个类别涉及不同的技术来有效地合并模型,下面将进行详细解释。
1. 合并具有相同架构和初始化的模型:
1.a 无需数据的合并:
本节中的模型合并方法都基于线性模式连接(LMC)。LMC 表明,对于具有相同架构和初始化的模型,它们的检查点之间的损失可以通过一条低损失的线性路径连接。这意味着这些模型可以通过线性插值进行合并。
为了微调一个模型,可以应用不同的配置,如不同的学习率、随机种子和数据增强技术,这些都会导致不同的模型参数。模型汤建议对这些参数进行平均化,因为这些模型已经学习到了相似的表示,并且在参数空间中彼此接近。加权模型平均可以得到一个平坦的局部最优解,并且对超出分布的任务具有更好的泛化能力[参见13, 14]。

图 2 Pl 显示了模型汤合并的结果,而 Ps 显示了 SLERP 合并的结果(图片由作者提供)。
SLERP(球面线性插值,首次介绍见此处)是一种在计算机图形学和动画中常用的技术,用于在由四元数表示的旋转之间平滑插值。SLERP 也适用于模型合并。它通过沿球面路径插值来合并两组模型参数,而不是沿直线路径。图 2 显示,对于给定的两个模型参数 p1 和 p2,SLERP 沿着地球表面合并这些参数,提供平滑的过渡。这种方法常用于合并大型语言模型(LLMs)。
假设给定两个多层感知机(MLP)模型,每个模型都在不同的下游任务上进行了微调。SLERP 可以通过以下步骤合并这两个模型:
步骤 1:对于每个模型参数,将其展平并连接成向量 v1、v2
步骤 2:将向量 v1 和 v2 正规化到单位超球面上(得到 v1′ 和 v2′)。
步骤 3:计算这两个向量之间的角度 θ(以弧度为单位)。
步骤 4:使用 SLERP 公式计算 Vslerp:

其中,t 是插值参数,当 t=0 时仅使用模型 1,而 t=1 时仅使用模型 2。
线性加权平均技术,如模型汤和 SLERP,在计算机视觉领域中已广泛应用,从图像处理和分类模型到图像生成模型,如潜在扩散模型。
任务算术提出了一种基于任务向量的方法。任务向量通过从同一模型的预训练权重(θpre)中减去针对特定任务微调后的权重(θft)来计算,公式如下:
τ = θft − θpre。这个向量表示一个方向,在预训练模型的权重空间中,沿着该方向移动能够提高该任务的性能。任务向量可以通过算术运算,如取反和加法,进行组合。取反任务向量 (θpre — τ) 会减少模型在目标任务上的性能(遗忘),对控制任务的影响最小。为了提高预训练模型在多个任务上的表现,我们可以为每个任务最初学习一个任务向量。然后通过将这些任务向量相加 (θpre+∑τi),我们可以提升模型同时处理多个任务的能力。
TIES 解决了在合并任务向量 (∑τi) 时由于参数干扰而导致的性能下降问题。这个问题可以通过以下三步解决(见图 3):
(1) 修剪每个任务向量至最大幅度的前 k%(通常 k = 20),
(2) 对于每个非零参数,选择所有任务向量中总幅度最大的符号,以避免冲突的变化,且
(3) 仅合并来自与选定符号相同的任务向量的值。

图 3 TIES 涉及的步骤示意图。模型中的每个参数被表示为一个方框。箭头表示由在不同任务上微调所产生的更新(任务向量,τ)对参数的影响,箭头的方向表示符号,长度表示幅度。1- 根据幅度修剪任务向量的值,2- 通过解决符号冲突,选择每个参数的符号(γm,绿色向量包含 +1 或 −1),3- 仅选择与选定符号对齐的值,并取它们的平均值作为最终的参数值。(图像来自 论文)
DARE 主要集中于 LLM 模型的融合,并识别任务向量中的极端冗余(τ = θft−θpre)。它提出了一个三步法:
1- 随机丢弃 p%(通常 p = 90)任务向量的值,
2- 将其余参数按 1/(1 − p) 的因子进行重新缩放,且
3- 合并 (θpre + λi ∑τi)
其中 λi 是缩放项,表示每个任务向量在合并时的重要性。
1.b 与数据需求的融合:
我们上面讨论的合并方法不需要数据。然而,也有一些方法确实需要数据来确定合并参数的最优权重。这些方法通过使用数据计算激活值,然后相应地调整权重。
其中一种方法是费舍尔合并。给定 K 个微调模型,每个模型在不同的下游任务上进行训练,起始点是特定的预训练检查点,费舍尔合并对每个模型的参数进行加权求和。权重是通过费舍尔信息矩阵计算的,矩阵构造需要每个任务的数据。
在相关的发展中,RegMean通过将模型合并任务重新表述为线性回归问题,显著超越了费舍尔加权合并。该方法为线性层的权重推导出闭式解,并均匀插值其他权重(如层归一化和偏置项)。给定 K 个微调模型和一些数据 Xi,i=1,..,K,对于每个任务,可以按如下方式确定合并模型的线性层:

其中 Wi 是来自第 i 个微调模型的线性层。
2. 合并具有相同架构但不同初始化的模型
给定具有相同架构和训练数据集但不同初始化的模型,像线性模型组合这样简单的合并方法通常无法很好地执行。主要原因是模型的权重没有对齐。因此,研究人员开发了利用神经网络置换对称性的技术。通过重新排列模型的神经元,可以使它们的权重更好地对齐,从而使合并过程更加有效。
Git-Rebasin建议置换一个模型的权重以匹配另一个模型的配置。假设给定两个模型 A 和 B,它们具有相同的架构和训练数据集,但初始化和训练数据的顺序不同。每个网络的权重可以进行置换而不改变其功能,这意味着交换隐藏层中的神经元可以得到功能等效的模型。
他们将此问题表述为一个优化问题,以识别跨层单元的最佳置换,从而在权重空间中对齐两个模型的参数。这种对齐确保了模型处于损失景观中的相似“盆地”,从而导致平滑且有效的合并。为了达到这一目标,Git-Rebasin 提出了以下三个步骤:
1. 对于每一层,寻找最佳置换的问题被公式化为线性分配问题(LAP)。这一步骤涉及计算激活矩阵,并找到对齐激活的最优置换矩阵。
2. 给定所有层的最优置换,模型 B 的权重将被置换。
3. 模型 B 的置换权重与模型 A 的权重之间的线性模型组合位于损失景观中的低损失区域,这确保了合并后的模型表现良好。
REPAIR解决了 Rebasin 合并方法中的一个关键问题——方差崩溃问题,其中隐藏单元的激活方差显著小于原始网络相应单元的激活方差,尤其是在它们被插值之前。因此,神经元的激活在更深层次上几乎变得恒定,导致网络无法区分输入。REPAIR 通过重新缩放插值网络的激活值,使其与原始网络的统计属性匹配,从而解决了这个问题。通过调整激活的均值和方差,插值网络保持了其各层的功能变异性。应用 REPAIR 方法显著降低了插值壁垒,提升了插值模型的性能。
3. 合并具有不同架构的模型
与目前讨论的方法相比,Frankenmerging并不会将多个模型合并为一个模型,而是将不同模型的不同层按顺序堆叠。因此,它能够合并具有不同架构的模型。
例如,为了构建一个具有 40 层的 LLM,可以将第一个 LLM 的前 24 层堆叠到另一个 LLM 的第 25 到 40 层。这种方法在计算机视觉中的风格迁移中得到了广泛关注。尽管需要大量的反复试验和实验,但它已经带来了如 Goliath 和 Solar-10.7B 等令人印象深刻的 LLM 模型[见此处]。

图 4 进化优化方法概述(图像来自论文)。
进化优化提出了一种框架,旨在自动合并给定的一组基础模型,使得合并后的模型在给定集合中超过任何单个模型的表现。该方法包括两个主要阶段(见图 4):
在第一阶段,该方法使用 TIES-Merging 与 DARE 结合进行 N 个基础模型的逐层合并。该过程通过使用进化算法进行优化,算法由特定任务的度量标准指导(例如,MGSM 的准确度、VQA 的 ROUGE 得分)。为了寻找未知变量,如 DARE 中的丢弃率百分比以及在合并过程中每个模型参数的权重,进化优化从一组可能的解开始,并随着时间的推移不断演化。通过变异(小的随机变化)和交叉(组合两个解的部分),选择最优解来生成新的候选解组。这一迭代过程带来了逐步更好的解决方案。
在第二阶段,给定一组 N 个模型,目标是通过 Frankenmerging 找到一个具有 T 层的最优模型。为了减少搜索空间并使优化过程可行,所有层按顺序排列(即第 i 个模型中的所有层,紧接着是第 i+1 个模型中的层),并重复 r 次。在此阶段,目标是找到一个最优指示符,确定是否包含/排除某些层:如果 Indicator(i) > 0,则第 i 层包含在合并后的模型中;否则,排除该层。
EvolutionaryOptimization 过程从将第一阶段应用于一组模型开始。然后,将第一阶段的合并模型添加到给定集合中,并在这个扩展的集合上应用第二阶段,找到一个最优指示符,选择 T 层作为最终合并模型的层。此方法应用于将一个日语 LLM 与英语数学 LLM 合并,构建日语数学 LLM。合并后的模型在多个已建立的日语 LLM 基准上取得了最先进的性能,甚至超过了具有显著更多参数的模型,尽管该模型并未针对这些任务进行训练。
本博文中表达的观点仅代表我们个人的意见,并不反映我们雇主的立场。
另请阅读我们之前的文章: 从单模态到多模态:构建基础模型的 DIY 技术
参考文献:
1 Model soup: Wortsman, Mitchell, 等人。“模型汤:平均多个微调模型的权重可以提高准确性而不增加推理时间。” *(2022 年)。
[2] Task arithmetic: Ilharco, Gabriel, 等人。“通过任务算术编辑模型。”(2022 年)。
[3] TIES: Yadav, Prateek, 等人。“Ties-merging:在合并模型时解决干扰问题。”(2024 年)。
[4] DARE:Yu, Le, 等人。“语言模型就是超级马里奥:从同源模型中吸收能力作为免费午餐。” *(2024 年)。
[5] Fisher MergingMatena, Michael S., 等人。“通过 Fisher 加权平均合并模型。”(2022 年)。
[6] RegMean: Jin, Xisen, 等人。“通过合并语言模型的权重实现无数据知识融合。”(2022 年)。
[7] Git-Rebasin: Ainsworth, Samuel K., 等人。“Git re-basin:合并模型时考虑置换对称性。”(2022 年)。
[8] REPAIR: Jordan, Keller, 等人。“Repair:重新归一化置换激活以进行插值修复。”(2022 年)。
[9] 弗兰肯合并: Charles O. Goddard. 2024. mergekit.
[10] 进化优化: Akiba, Takuya 等人。“模型合并方案的进化优化。”(2024)。
[11] Shoemake, Ken. “使用四元数曲线进行旋转动画。”(1985)。
[12] LMC: Nagarajan, Vaishnavh 等人。“均匀收敛可能无法解释深度学习中的泛化。”(2019)。
[13] Kaddour, Jean 等人。“平坦最小值优化器何时有效?”(2022)
[14] Petzka, Henning 等人。“相对平坦性与泛化。”(2021)
超越 FOMO——在人工智能领域保持最新动态
不要感到压力,要享受这个过程
·发布在 Towards Data Science ·7 分钟阅读·2024 年 6 月 11 日
--

照片由 Mukuko Studio 提供,来源:Unsplash
*“我正在招聘一位开发人员,将 gpt4o 集成到我们的产品中。
要求:五年相关经验。”* - 2024 年 5 月,LinkedIn 上的一位未知用户
我大约在 9 年前还是学生时,开始接触数学建模。在完成了理论性很强的数学本科学位后,我选择了一些与数学建模和经济问题优化相关的硕士课程。那时我最喜欢的课题是时间序列。那时了解不同建模方法相对轻松,经过验证的方法已经存在了十多年,并且没有迅速变化。
几年前,进入数据科学领域时,情况也类似。基本的技术和模型学习起来相对较快。在实施过程中,很多内容都是从零开始,自建网络并使其运行。新的工具和技术备受欢迎并尝试使用。
今天的感觉则不同了。现在,当人们浏览 X 或 LinkedIn 的动态时,几乎每周都会收到关于重要工具和发展的新闻。
自从 2022 年 11 月 ChatGPT 发布以来,关于 LLM 的炒作已经变得极为剧烈。开源与闭源之间的竞赛拉开了序幕。谷歌推出了 Gemini,Meta 发布了 LLama,斯坦福大学则推出了 Alpaca。应用程序通过像 Langchain 这样的工具得以操作化,并且有一整套工具正在出现,用于标准化应用程序。调优机制不断得到改进。然后,还有 xgboost 2 的发布。
轮子似乎转得越来越快。近年来,这主要归功于 GenAI 方法的突破以及 MLOps 领域不断增长的工具箱。
跟上进展非常重要:市场上发生了什么?尤其是当你作为顾问在这一行业工作时。我们的客户想知道:什么是目前最热的新技术?我们如何能将其盈利化?
如今,保持进展非常重要!那些不这样做的人会很快失去联系。
是这样的吗?
FOMO 正在袭来。
上次我参加一个大型会议时,整整两夜未能入眠,几乎无法入睡。这不仅仅是因为在演讲前的紧张情绪,更因为在如此短的时间内,海量的信息不断向我扑来。
会议真是太棒了。我喜欢结识新朋友,了解不同的方法,并交换一些可能对我来说完全陌生的想法和问题。然而,那几晚我几乎没怎么睡觉。“我需要稍后再深入研究一下”的待办事项清单似乎根本无法完成。FOMO(错失恐惧症)悄然袭来。脑海中浮现出这样的想法:“现在跳上 GenAI 的列车还来得及吗?”在那一刻,我忽视了自己也在偏见之中。我的演讲是关于我们与一个客户一起实施的一个用例。两年的工作压缩成了三十分钟。观众是否按照预期从中获得了有价值的启示和思考?还是这个贡献也悄悄引发了 FOMO?
另一个反复出现的现象是冒名顶替综合症1。它描述了对自身能力产生强烈怀疑的现象,并伴随着被揭露为“骗子”的恐惧。患有冒名顶替综合症的人常常觉得自己不具备胜任自己职位或任务的能力,甚至通过与他人对比,产生瞬间的自我感知:“我其实什么都做不好。”
从与我工作环境中的人们的真诚交流中,我知道这种情况时不时会出现。曾与我交流过的人,我认为他们都具备非常高的经验和专业水平。几乎所有人都有过这种感觉。
技术的多样性和人工智能领域的快速进展也可能引发这种现象。
从哪里开始?
数据科学的核心要素是什么?它是一个能够创造附加价值的有效系统。如果你不是研究人员,而是业务中的数据科学家,那么重点就放在应用上。一个模型或启发式方法能够学习人类无法在如此细节上学习的逻辑,并/或在如此大规模上应用。这不必是一个端到端、完全自动化的解决方案。
应该从开发一个有效且得到利益相关者接受的系统开始。一旦建立了对该系统的信任,就可以着眼于进一步改善的部分。
是方法论吗?也许目前使用的某个算法可以被一个深度学习架构所替代,这种架构能够表示变量之间更多的相关性。
是运行时间吗?是否可以通过其他框架或并行化来缩短运行时间?如果是这样,那就可以着手深入研究这个话题。
也许它还包括系统化地捕捉和管理数据质量。数据验证工具可以帮助早期发现数据不平衡、识别漂移,并监控机器学习系统的输出。
小心翼翼地逐步接近新技术,并持续改进现有系统是可行的。
成长需要时间。
说实话,学习新方法和新技术需要时间。有很多方法可以快速获得概览:tl;dr 摘要、概览仓库、YouTube 频道等。然而,如果我不花更多时间去深入了解这些话题,我很快就会忘记它们。因此,为了熟悉一个特定的主题或技术,我不得不偶尔抽出一个晚上或一个星期六来深入研究。
个人知识获取需要时间这一事实,直接揭示了每个人都有的局限性。
另一个方面是,经验无法强迫。采用新技术的能力也随着已有经验的积累而增加。评估技术和工具的能力也一样。个人经验越丰富,越容易理解。但这要求先对其他技术有更深入的理解,而这种理解只能通过亲身实践获得。
连接并收集信息。
不要害怕提问。在更高层次上尝试并没有错。但有时主动寻求经验也是值得的。也许你的公司或网络中已经有人使用过技术 xy?为什么不一起共进午餐,讨论一个共同的话题呢?这一切的基本前提是:处在一个可以提问的环境中(!)。
此外,要保持参与感。如上所述:保持事物记忆的最佳方式就是通过实践。然而,这并不意味着不值得保持系统性地关注左右两边,并随时了解那些不在(当前)工作范围内的新闻。现在有许多优秀的新闻简报。一个非常好的简报是由 DeepLearning.AI 发布的The Batch [2]。
创建正确的结构。
我在一个由六位数据科学家组成的团队中工作。之前提到的观察同样适用于这里:即使在这个相对较小的团队中,也有人可能会受到冒名顶替症候群的影响。毕竟,总有某个人经验更丰富,或者至少在某个特定话题、方法或工具上有一些经验。
在我们的团队中,我们每两周举行一次实践社区会议。我们制定了两项政策:
1. 我们总是从高层次开始,确保所有成员都能跟上进度,并且不假设每个人都已经深入理解该主题。这样我们就可以进一步深入探讨。
2. 强烈建议大家集体探索一个尚未有人开发出深入专业知识的话题。
在上一次会议中,我们讨论了微调 LLM 与少量样本学习和提示语的对比。我们一起探索并尝试了各种微调方法。更重要的是,我们对业务问题获得了一系列宝贵的见解,帮助确定哪些机制可能更有效。我们带着许多好点子和进一步的研究任务离开了会议。这比对每个细节的深入了解更为宝贵。
最近,我在一个数据科学聚会上有了一个令人耳目一新的体验。
这些演讲者在物流行业工作,并开发了一个大语言模型(LLM)系统。目标是从带有附件的未结构化指令的电子邮件中提取信息,并将其转化为结构化的输出,以便基于这些输出触发货物的运输。
他们展示了自己的系统,包括 OCR 和不同的 LLM API 调用,系统部署在云环境中。随后,他们分享了当前的提示语(prompts)以及使用不同提示语和模型的历史尝试。这还包括与拟合优度指标(!)的比较。演讲最后提出了两个开放性问题。他们请求反馈并征求改进建议。同时,还展开了一个讨论,探讨通过 API 使用专有 LLM 是否会影响人工智能工程与数据科学之间的平衡。
我非常喜欢这一点。除了披萨和社交活动,聚会不就是这样的意义吗?创造双赢的局面,带着好心情和新想法回家。
接下来继续。
有时,人工智能相关信息和新闻的洪流让人感到不知所措。越是想保持在前沿,越是会有这种感觉。然而,实际上不可能深入了解所有内容。幸运的是,目前还没有人发明出克隆技术。
我们生活在一个令人激动的时代。人工智能的量子飞跃正在发生,并等待着被用来造福社会。接触这些技术的障碍正在降低:专有模型正受到开源项目的挑战。论文和代码大多都可以获取。在线上,许多优秀的导师愿意分享他们的经验并教授知识,进一步减少了障碍。这使得许多人不仅能参与到人工智能的进步中,还能在其中发挥作用。成为这个伟大社区的一部分,是一件非常棒的事情。
这不应该是有压力的,而应该是充满喜悦的。终身学习的伟大之处在于:它永无止境。
保持冷静,继续前行。
参考资料
[2] DeepLearning.AI 的 The Batch
超越 Kleinberg 的聚类不可能定理:我对一个务实的聚类评估框架的学习笔记
本文探讨了在 Kleinberg 聚类不可能定理约束下的务实聚类评估框架
·发表于Towards Data Science·12 分钟阅读·2024 年 6 月 21 日
--

由作者处理
在他 2002 年发表的论文“聚类的不可能定理”中,Jon Kleinberg 阐述了没有任何聚类模型可以同时满足聚类的三个理想公理:尺度不变性、丰富性和一致性。(Kleinberg, 2002)
那三个公理是什么意思?这里是对这三个公理的解释。
-
尺度不变意味着:聚类算法应该在所有数据点之间的距离按一个常数因子缩放时,生成相同的结果。
-
丰富性意味着:聚类算法应该展示出高效能,能够生成给定数据集的所有可能划分。
-
一致性意味着:当我们通过增加类间距离并减少类内距离来增强一组聚类时,聚类算法应该生成相同的结果。
长话短说,Kleinberg 证明了一个数学上令人满意的聚类算法是不存在的。
对于一些理论主义者来说,这可能是聚类分析的(或许是)死亡宣判。
尽管如此,我遇到了一篇挑战克莱因伯格“不可能定理”有效性的学术论文。我不会进入那个领域。但如果你对这个话题感兴趣,给你这里:“关于克莱因伯格的聚类公理与 k-means 聚类算法行为之间的差异。”
无论真实情况如何,自从克莱因伯格发布了他的“不可能定理”以来,许多来自工程学领域的方法(例如应用数学、信息理论等)被提出用于聚类评估。
追求务实主义,填补理论/科学限制与实际功能之间的空白,这是工程学的领域。
事实上,似乎没有任何普遍接受的科学理论能够解释为什么飞机能够飞行。这里有一篇关于此的文章。在缺乏科学理论的情况下,凭借工程学的艺术,我已顺利完成了许多次飞行。

在欣赏工程务实精神时,我需要一个合理的框架来填补克莱因伯格“不可能定理”和我们日常聚类分析实践应用之间的空白。
如果飞机能够在没有普遍接受的科学理论的情况下飞行,那么我们也可以进行聚类!
也许……为什么不呢!
说起来容易做起来难,在克莱因伯格的“不可能定理”下:
-
我们如何评估聚类算法的结果?
-
我们如何选择最适合给定目标的算法?
对这些非常简单问题的全面理解仍然难以捉摸,至少对我而言是如此。
在这种背景下,我遇到了一篇由 Palacio-Niño 和 Berzal(2019)撰写的论文,“无监督学习算法的评估指标”,在其中他们概述了一个聚类验证框架,试图在“无法定理”提出的数学限制下,更好地评估聚类性能的质量。是的,他们在制定框架时非常清楚克莱因伯格的“不可能定理”。
为了促进我们对聚类算法的务实使用,我认为在这篇文章中分享我关于务实评估框架的学习笔记会是建设性的。
由于这是我的笔记,在其中我根据个人需求进行了许多修改,偏离了 Palacio-Niño 和 Berzal 所写论文的细节。此外,因为本文更多的是打算勾画出所提出的聚类验证框架的整体结构,并未深入细节。如果你愿意,可以阅读他们原始论文的全文,以填补我文章与他们原始论文之间的空白。
作为最后的预防措施,我并不声称这是一个全面的或标准的聚类验证框架指南。但我希望聚类分析的新手能将其作为一个有用的指南,帮助他们构建自己的聚类验证框架。
不多,也不少。
现在,让我们开始吧。
聚类验证的整体框架
这是它们框架的结构。我们可以看到四个验证阶段:
-
初步评估,
-
内部验证,
-
外部验证,
-
相对验证。
让我们逐一检查它们。
1. 初步评估:
这个过程的目标只是简单地确认数据集中的簇的存在。
这个过程应用了假设检验框架来评估数据集中是否存在聚类倾向。此过程设定了零假设,即数据集是纯随机的,因此数据集中没有聚类倾向。
由于假设检验可以视为一个独立的主题,我将从本文中将其单独剔除。
2. 内部验证或无监督验证
内部验证的目标是仅基于给定数据集评估聚类结构的质量,而不使用任何关于真实标签的外部信息。换句话说,当我们没有任何关于真实标签的高级知识时,内部验证是唯一的选择。
聚类内部验证的典型目标是发现那些最大化簇内相似度并最小化簇间相似度的簇。为此,设计了内部标准来衡量簇内相似度和簇间离散度。简单!
话虽如此,有一个问题:
“在内部标准上的良好评分不一定能转化为在应用中的良好效果。”(Manning 等,2008)
在内部标准上得分更高并不一定能保证结果模型的更好效果。内部验证是不够的!
那么,我们该怎么做呢?
这正是我们确实需要外部验证的原因。
3. 外部验证或监督验证
与内部验证相比,外部验证需要外部类别标签:理想情况下是地面真实标签,如果没有,可能是其代表性替代。因为我们使用无监督聚类算法的第一个原因是因为我们根本不知道标签,外部验证的概念看起来荒谬、矛盾,或者至少是违反直觉的。
然而,当我们有关于类别标签的外部信息时 — 例如来自基准模型或黄金标准模型的一组结果 — 我们可以实施外部验证。
因为我们使用参考类别标签进行验证,外部验证的目标自然会收敛到监督分类分析的一般验证框架。
在更广泛的范围内,这个类别包括模型选择和人类判断。
4. 相对验证:
接下来,相对验证。
这里是相对验证的示例。
特别是对于分区聚类(例如 K-Means)类别,确定簇的数量是确定算法配置的重要起点,因为这会实质性地影响聚类结果。
换句话说,对于这类聚类算法,簇的数量是算法的一个超参数。在这种情况下,簇的数量需要从算法参数的角度进行优化。
这里的问题是优化需要与决定算法配置的其他超参数同时实施。
它需要比较来了解一组超参数设置如何影响算法配置。这种相对验证通常在参数优化领域内处理。由于机器学习算法的参数优化是机器学习培训(模型开发)的一个重要主题,我将在本文中将其搁置。
到目前为止,我们对他们的验证框架的整体概况有了一个公平的了解。
接下来,一个相关的问题是“我们应该为每个验证使用什么样的度量标准?”
在这种情况下,我收集了一些度量标准作为内部验证和下一节外部验证的示例。
用于内部和外部验证的度量标准
现在,让我们专注于内部验证和外部验证。下面,我将列出我选择的一些度量标准,并提供超链接,您可以在其中详细了解它们的定义和公式。
由于我不会涵盖这些度量标准的公式,建议读者点击下面提供的超链接查找!
A. 用于内部验证的度量标准
内部验证的目标是仅基于给定数据集建立聚类结构的质量。
内部评估方法分类:
内部验证方法可以根据聚类方法的类别进行分类。聚类的典型分类可以如下表述:
-
划分方法(例如:K-means),
-
分层方法(例如:聚合聚类),
-
密度基方法(例如:DBSCAN),以及
-
其余部分
这里,我介绍前两种方法:划分聚类和层次聚类。
a) 划分方法:例如 K-means
对于划分方法,评估度量的三个基础是:凝聚力、分离度及其混合。
凝聚力:
凝聚力评估簇内数据结构的紧密程度。凝聚力度量值越低,聚类质量越好。凝聚力度量的一个例子是:
- SSW:簇内平方误差之和。
分离度:
分离度是一个簇间度量,评估簇间数据结构的离散程度。分离度度量的核心思想是最大化簇之间的距离。凝聚力度量的一个例子是:
- SSB:簇间平方误差之和。
凝聚力和分离度的混合:
混合类型度量在一个度量中量化了分离度和凝聚力的水平。以下是一些例子:
i) 轮廓系数:范围为[-1, 1]
该度量是邻近簇之间的簇间距离的相对度量。
这是该度量的一般解释:
-
最佳值:1
-
最差值:-1。
-
值接近 0:簇重叠。
-
负值:样本被分配到错误簇的可能性较高。
这是该度量的一个使用案例:www.geeksforgeeks.org/silhouette-index-cluster-validity-index-set-2/?ref=ml_lbp
ii) Calinski-Harabasz 系数:
该度量也被称为方差比准则,衡量所有簇的簇间离散度与簇内离散度的比率。
对于给定的簇分配,度量值越高,聚类结果越好:因为较高的值表示结果簇紧凑且分离良好。
这是该度量的一个使用案例:www.geeksforgeeks.org/dunn-index-and-db-index-cluster-validity-indices-set-1/?ref=ml_lbp
iii) Dunn 指数:
对于给定的簇分配,较高的 Dunn 指数表示更好的聚类效果。
这里有一个该指标的使用案例:www.geeksforgeeks.org/dunn-index-and-db-index-cluster-validity-indices-set-1/?ref=ml_lbp
iv) Davies Bouldin Score:
该指标衡量了簇内相似性与簇间相似性的比例。从逻辑上讲,较高的指标值意味着簇内结构更加密集,簇间结构更加分离,因此聚类结果更好。
这里有一个该指标的使用案例:www.geeksforgeeks.org/davies-bouldin-index/
b) 层次方法:例如凝聚聚类算法
i) 基于树状图的可视化表示进行人工判断。
尽管 Palacio-Niño 和 Berzal 没有包括人工判断,但它仍然是基于树状图的层次聚类内部验证最有用的工具之一。
相反,合著者列出了以下两个专门用于评估层次聚类结果的相关系数指标。
对于这两个指标,它们的较高值表示更好的结果。两者的取值范围在[-1, 1]之间。
ii) 共现相关系数(CPCC):[-1, 1]
它衡量由联接定义的层次聚类中观测值之间的距离。
iii) Hubert 统计量:[-1, 1]
更高的 Hubert 值对应着更好的数据聚类效果。
c) 潜在类别:自监督学习
自监督学习可以生成用于聚类的特征表示。自监督学习的数据集没有明确的标签,而是使用输入数据本身作为学习的标签。Palacio-Niño 和 Berzal 没有在这一部分中包括自监督框架,如自编码器和生成对抗网络(GANs)。嗯,它们本身不是聚类算法。尽管如此,我会将这个特定领域暂时留待我的笔记。时间会证明这个领域是否会出现任何专门的度量。
在结束内部验证部分之前,以下是 Gere(2023)的一个警告。
“选择合适的层次聚类算法和聚类数始终是一个关键问题……在许多情况下,研究人员并未公开选择某一特定距离度量和连接规则及聚类数量的原因。背后的原因可能是不同的聚类验证和比较技术在大多数情况下给出了相互矛盾的结果……验证方法的结果偏离,表明聚类结果在很大程度上依赖于数据集本身。尽管欧几里得距离、Ward 方法似乎是一个安全的选择,但强烈建议测试和验证不同的聚类组合。”
是的,这是一个艰难的任务。
现在,让我们进入外部验证的部分。
B. 外部验证使用的指标
一再强调,内部标准的更好得分不一定能保证结果模型的更高效性。(Manning 等人,2008)在这种情况下,我们必须探索外部验证。
与内部验证不同,外部验证需要外部类标签。当我们拥有这样的外部信息——作为一种理想选项的真实标签或作为实际选项的替代标签,例如基准模型的结果——聚类的外部验证目标就会设计成与监督学习的目标一致。
共同作者列出了三类外部验证方法:匹配集、点对点相关性和信息理论。
所有这些方法或多或少都在比较两组集群结果:一组是来自待评估聚类算法的结果,称之为C;另一组,称之为P,来自外部参考——另一个基准算法,或者如果可能的话,来自真实类别的结果。
- 匹配集:
这一类方法识别了每个预测集群在C与其对应外部参考集群P之间的关系。其中一些是监督分类中常用的验证指标。我将列出这一类别中的一些指标,详情请查看它们的超链接。
a) 分类准确率:
b) 纯度:
c) 精确度:
d) 召回率:
e) F-measure:
2. 点对点相关性:
这类度量标准是两种不同方法得到的等价分区之间的相似性度量,C和P。逻辑上,相似性越高,聚类结果越好:预测的聚类类别类似于参考类标签。
a) Jaccard Score: [0, 1]
它通过测量外部参考类标签与预测标签之间的重叠来比较这两组:交集大小与两个标签集合并大小的比率。
指标越高,这两组之间的相关性越强。
b) Rand Index: [0, 1]
“从数学角度来看,Rand 指数与准确性有关,但即使在不使用类标签的情况下也适用.”
这里是如何解释度量结果的。
-
值为 0:聚类结果的两组C和P之间没有一致性。
-
值为 1:两组之间完全一致。
这里是度量标准的一个用例示例:
www.geeksforgeeks.org/rand-index-in-machine-learning/?ref=ml_lbp
c) 福尔克斯-马洛斯系数
它衡量了“精确率和召回率的几何平均.”
这里是度量标准的一个用例示例:
www.geeksforgeeks.org/ml-fowlkes-mallows-score/
3. 信息论:
现在,我们有来自信息论的另一类度量标准。这类度量标准有两个基础:熵和互信息。
熵是“一种允许我们测量聚类结果中混乱程度的纯度的倒数度量.”
互信息度量“在已知先前分区的情况下,关于聚类结果的不确定性减少.”
我们还有以下度量标准作为示例。
4. 模型选择度量:
对于外部验证,我还想从另一个参考文献(Karlsson 等,2019)中添加以下模型选择度量标准。
我们可以使用它们来比较多个结果中的各项指标的值。那些在这些指标上得分最低的结果被认为是最合适的。然而,单独使用这些指标无法评估单一结果的质量。
这是使用这些模型选择指标的一个警告。为了使任何这些信息准则在评估模型时有效,需要满足一组特定的前提条件:低多重共线性、足够的样本量以及模型拟合良好、R 平方值较高。当任何一个条件没有满足时,这些指标的可靠性可能会大大降低。(Karlsson 等,2019)
本文到此为止。
我并不声称我在这里所涵盖的内容是全面的,甚至是黄金标准。事实上,有不同的方法。例如,R 语言有一个名为 clValid 的聚类验证包,它使用不同的方法:从“内部”、 “稳定性”到“生物学”模式。而且我认为clValid是一个非常棒的工具。
鉴于此,我希望这篇文章能为初学者在构建自己的聚类评估框架时提供一个有用的起点指南。
一再强调,这篇文章的目的是在克莱因伯格的聚类不可能定理所定义的理论限制下,概述一种可能务实的聚类评估框架。
最后但同样重要的是,请记住以下格言:
这个格言在我们处理任何模型时都应该继续在我们心中回响。顺便提一下,这个格言常常与历史上著名的统计学家乔治·E·P·博克斯(George E. P. Box)联系在一起。
鉴于我们所生活的环境并不完美,让我们共同推动以亚里士多德式的[实践智慧](https://en.wikipedia.org/wiki/Phronesis#:~:text=Phronesis (Ancient Greek%3A φρόνησῐς%2C,discussion in ancient Greek philosophy.)的精神传播实际知识。
感谢阅读。
杉野道夫
参考文献
-
Davies_bouldin_score. (n.d.). Scikit-Learn. 检索日期:2024 年 6 月 20 日,来自
scikit-learn/stable/modules/generated/sklearn.metrics.davies_bouldin_score.html -
Gere, A. (2023). 消费者感官项目中验证层次聚类的建议。食品科学当前研究,6,100522.
doi.org/10.1016/j.crfs.2023.100522 -
Karlsson, P. S., Behrenz, L., & Shukur, G. (2019). 当变量条件不良时模型选择标准的表现。计算经济学,54(1),77–98.
doi.org/10.1007/s10614-017-9682-8 -
Kleinberg, J. (2002). 聚类的不可能性定理。神经信息处理系统进展,15。
-
Manning, C. D., Raghavan, P., & Schütze, H. (2008). 信息检索导论/ 平面聚类/ 聚类评估。Https://Nlp.Stanford.Edu/IR-Book.
nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-clustering-1.html -
Palacio-Niño, J.-O., & Berzal, F. (2019). 无监督学习算法的评估指标 (arXiv:1905.05667)。arXiv.
arxiv.org/abs/1905.05667
超越折线图和条形图:7 种不太常见但强大的可视化类型
通过这些富有创意和洞察力的可视化提升你的数据讲故事能力
·发表于Towards Data Science ·12 分钟阅读·2024 年 9 月 26 日
--
在上一篇文章中,我分享了自 2018 年以来每周制作一个可视化的旅程——现在我已经在我的Tableau Public个人主页上发布了 350 多个可视化作品!毫不意外,在所有可视化类型中,我使用得最多的还是条形图和折线图。它们简单,但在讲述故事时非常有效且直观。然而,我有时会感到制作相似的图表让人厌倦,它们也可能无法展示复杂的数据模式。
在这篇文章中,我将介绍七种不太常见但非常强大的可视化类型。我将讨论它们的具体应用场景,分享我的可视化示例,分析它们的优缺点,并向你展示如何在像 Tableau 这样的可视化工具中创建它们。

图像由 DALL·E 创作
1. 碰撞图:跟踪排名随时间变化
碰撞图是一种特殊类型的折线图,用于可视化多个类别的排名随时间变化的情况。它的 y 轴表示类别排名,因此可以显示各类别如何随着时间的推移相互“碰撞”并上下浮动。因此,它非常适合展示类别之间的竞争。
超越数学与 Python:你应该发展其他关键的数据科学技能
·发布于 Towards Data Science ·通过 Newsletter 发送 ·阅读时间:4 分钟 ·2024 年 11 月 7 日
--
想要写下你的第一篇 TDS 文章吗?我们始终欢迎新作者的投稿。
数据科学成功的道路提供了许多不同的路径,但其中大多数都强调数学和编程技能(例如:这篇为数据专业人士准备的优秀指南 是Saankhya Mondal 本周早些时候发布的)。然而,一旦你在这些领域打下基础,接下来是什么呢?数据科学家需要在哪些领域建立专业知识,以便在竞争激烈的就业市场中脱颖而出?
我们每周的重点内容聚焦于一些你可能希望在接下来的几周和几个月中探索的领域,并提供来自深耕于行业和学术多个角色的作者们的可操作建议。从掌握数据基础设施的细节到拓展讲故事的技巧,让我们仔细看看那些外围的——但依然至关重要——潜在成长领域。
-
超越技能:释放数据科学家的全部潜力“数据科学家拥有独特的视角,使他们能够提出创新的商业理念——这些理念新颖、具有战略性或具有差异化,并且不太可能来自其他人,除了数据科学家。” Eric Colson扩展了这一发人深省的前提,即公司过度关注数据科学家的技术技能,忽视了他们的创造力和创新思维,从而导致数据科学家的潜力未被充分利用。
-
我从一场与 AI 无关的数据会议中学到的三大数据经验近年来,AI 已经彻底主导了相关话题,因此听到关于数据科学家如何保持在自己领域的前沿的其他方式,感到耳目一新。Nithhyaa Ramamoorthy回顾了她最近参加的一次会议,并谈到这次经历如何激励她更加关注那些看似不如最新大型语言模型(LLM)光鲜亮丽,但却能提升数据从业者价值的问题——从成本控制和数据转换到信息设计。
-
数据科学领导者的终极生产力系统对于任何从事数据科学管理工作的人来说——无论是处于早期阶段还是职业生涯的更深层次——有时会感觉领导力技能似乎是随着时间的推移自然而然地增长的。虽然从某种角度来说这可能是对的,Rebecca Vickery的最新文章阐明了你可以采取的一些具体步骤,确保即使随着角色需求的增长,你依然保持专注和高效。

由 In The Making Studio 提供的照片,来源于 Unsplash
-
掌握背面纸条上的数学会让你成为更好的数据科学家我们知道,我们知道:我们承诺不涉及数学。但Torsten Walbaum的新文章建议,数据专业人士可能不需要过于关注复杂的公式和建模,而应更加专注于培养自己在做粗略——但扎实——估算时的舒适感。
-
从 AI 画布到 MLOps 堆栈画布:它们真的必要吗? 随着工具和数据堆栈复杂性的增长,产品相关方很容易失去对所有移动组件如何协同工作的跟踪。Chayma Zatout 将为您提供帮助,通过动手实践介绍如何构建和使用画布,“一种帮助个人和团队以结构化方式映射和分析给定项目各个方面的可视化框架。”
-
我希望我早些遇到的 AWS Bedrock 教程:你需要知道的一切,准备好你的机器迎接 AWS 基础设施 “如何将一个不错的小型机器学习原型从笔记本中发展成一个强大的全栈 Web 应用?” 从数据分析的琐碎细节中稍微退后一步,Minda Myers 鼓励数据专业人士考虑他们的技术设置,并将其优化以实现流畅和高效的工作流程。
-
从洞察到影响:每个数据科学家需要的演讲技巧 强有力的故事讲述核心已经是许多,甚至大多数数据科学角色的核心内容;然而,在许多项目中,这仍然是一个没有得到充分关注的领域——你只能期望自己神奇地提高这一技能。Yu Dong 在她最新的文章中探讨了成功演讲的一些核心要素,并提供了设计成功幻灯片的具体建议。
-
如何创造机会并在数据科学职位申请中取得成功 正如Robson Tigre 所阐明的,成为一名出色求职者并识别合适机会的过程需要一套独特的技能——其中大部分与数据或算法无关,而是围绕自我展示(和营销)、人际网络和沟通技巧展开。
感谢您对我们作者工作的支持!正如我们上面提到的,我们非常喜欢发布新作者的文章,因此,如果您最近写了一篇有趣的项目演练、教程或关于我们核心话题的理论反思,请毫不犹豫地与我们分享。
直到下一个 Variable,
TDS 团队
超越预测:提升模型与影响力科学(第一部分)

作者插图
基于树模型的提升方法
·发表于 Towards Data Science ·阅读时间:14 分钟·2024 年 1 月 3 日
--
预测分析一直是决策制定的基石,但如果我们告诉你,除了预测之外,还有另一种选择呢?如果你能够战略性地影响结果呢?
提升模型承载着这一承诺。它通过识别那些在接受特殊处理后行为可以积极影响的个体,为传统的预测方法增加了一个有趣的动态层面。
应用案例是无穷无尽的。在医学领域,它可以帮助识别那些通过医疗治疗能够改善健康的患者。在零售行业,这样的模型可以更好地定位那些能够通过促销或个性化服务提高客户保持率的客户。
目标
本文是一个系列文章的第一部分,探讨了提升模型的变革潜力,阐明了它如何重塑市场营销、医疗等领域的策略。本文重点介绍基于决策树的提升模型,并以促销优惠的应用为案例,预测客户转化率。
超越 RAG:通过 LLM 进行网络分析以提取知识
使用 Streamlit、Upstash 和 OpenAI 的端到端数据科学项目,通过网络分析构建更好的知识导航和理解
·发表于Towards Data Science ·阅读时间 26 分钟·2024 年 2 月 24 日
--

本文将引导您通过一个使用多个前沿 AI 工具的端到端数据科学项目。这个工具被称为Mind Mapper,因为它允许你通过将信息注入知识库并以智能方式检索它来创建概念图。
动机是超越“简单”的 RAG 框架,在该框架中,用户查询向量数据库,然后将其响应输入 LLM(如 GPT-4)以获得更丰富的答案。
Mind Mapper 利用 RAG 来创建中间结果表示,这对于执行某种知识智能非常有用,这也使我们能够更好地理解 RAG 在长篇且无结构文档中的输出结果。
简单来说,我想将 RAG 作为构建多样化响应的基础步骤,不仅仅是文本回应。思维导图就是其中一种响应。
以下是该工具的一些功能:
- 管理几乎所有形式的文本:复制粘贴的文本、文本以及来自音频源的文本(视频是...
超越 RAG:在语义世界中的精确过滤

图片由Nathan Dumlao拍摄,来源于Unsplash
通过使用传统的机器学习方法来缩小大语言模型(LLM)响应中的差距,从而对齐期望与现实。
·发表于Towards Data Science ·阅读时长 7 分钟·2024 年 11 月 12 日
--
很早我们就意识到,大语言模型(LLMs)仅知道其训练数据中的内容。与它们玩耍很有趣,没错,但它们一直以来都容易产生幻觉。将这种“原始”形式的产品用于商业化,可以说——简直是愚蠢透顶(指的是 LLM,不是你…可能)。为了解决幻觉问题以及对未见过/私有数据的知识缺乏,有两条主要途径可供选择。训练一个基于你私人数据的定制 LLM(也就是费力的办法),或者使用检索增强生成(也就是我们基本上都选择的办法)。
RAG 是现在在自然语言处理(NLP)和生成型 AI 领域广泛使用的一个缩写。它已经发展并引领了许多不同的新形式和方法,例如 GraphRAG,偏离了大多数人最初使用的简单方法。两年前的我可能会将原始文档解析成一个简单的 RAG,然后在检索时,将这些可能的(很可能是)无用的上下文提供给 LLM,希望它能够理解并利用这些信息更好地回答用户的问题。 哇,真是“无知是福”;而且,别评判:我们都做过这种事。 我们很快意识到,“垃圾进,垃圾出”,因为我们最初的概念验证……嗯……表现得并不好。于是,开源社区投入了大量的精力,给我们提供了更多合理的方式来实现商业上可行的应用。这些方法包括例如:重新排序、语义路由、防护措施、更好的文档解析、将用户的问题重新对齐以检索更相关的文档、上下文压缩,等等。除此之外,我们所有人还提高了传统 NLP 技能,并为团队编写了指导方针,确保存储在数据库中的解析文档既整洁又易于阅读。
在处理一个大约有 16 个步骤的检索系统时,始终有一个问题不断出现。我存储的上下文真的能回答这个问题吗? 或者换句话说,我更喜欢的表述是:这个问题真的属于存储的上下文吗? 虽然这两个问题看起来相似,但它们的区别在于,第一个问题是局部化的(例如,10 个检索到的文档),而第二个问题是全局化的(针对整个文档数据库的主题/话题空间)。你可以把它们看作一个是精细化筛选,而另一个是更广泛的筛选。我相信你现在可能在想,这一切的意义何在?“我对我检索到的文档做余弦相似度阈值筛选,一切正常,为什么你要把事情搞得这么复杂?” 好吧,我编造了最后那句话,我知道你不会那么刻薄。
为了凸显我的过度复杂化,这里有一个例子。假设用户问:“谁是第一个登上月球的人?”现在,假设我们忘记 LLM 可以直接回答这个问题,我们期望 RAG 为这个问题提供上下文……但是,我们的所有文档都是关于时尚品牌产品的!这个例子有点傻,没错,但在生产环境中,我们很多人都看过用户总是提出一些与我们拥有的文档不相关的问题。“是的,但我的前提告诉 LLM 忽略不属于某个话题类别的问题。余弦相似度会过滤掉这些问题的弱上下文”或者“我使用了护栏或语义路由来处理这个问题。”当然,同意,这些方法有效,但所有这些选项要么是在下游做得太晚(比如前两个例子),要么没有完全针对这个问题量身定制(比如最后两个例子)。我们真正需要的是一种快速的分类方法,能够在检索文档之前快速告诉你问题是否适合由文档提供上下文……即使在检索之前。如果你猜到这是什么意思,那么你就是经典机器学习团队的一员了 😉 是的,没错,就是经典的离群值检测!
离群值检测与自然语言处理(NLP)结合?显然有人有太多的空闲时间来玩这个。
在构建生产级 RAG 系统时,有几件事情我们需要确保:效率(响应通常需要多长时间)、准确性(响应是否正确且相关)和可重复性(有时被忽视,但非常重要,检查一下缓存库)。那么,离群值检测方法(OD)如何帮助这些呢?让我们快速集思广益。如果 OD 看到一个问题并立即说“否,这是一个离群值”(我在这里拟人化),那么很多后续步骤可以被跳过,从而使得这个过程更加高效。假设 OD 现在说“是,安全”,好吧,稍微增加一些开销,我们就能更有把握地知道问题和存储文档的主题空间是否一致。至于可重复性,幸运的是,我们又是幸运的,因为经典的机器学习方法通常是可重复的,所以至少这个额外的步骤不会突然开始道歉并带我们进入重复和误解的恶性循环(我在看你,ChatGPT)。
哇,这部分有点冗长,抱歉,不过终于我现在可以开始展示一些有趣的内容了。
Muzlin,一个 Python 库,是我积极参与的项目,专门为这些类型的语义过滤任务开发,通过使用简单的机器学习来适应生产环境。怀疑吗?好吧,来吧,让我们快速看看它能为我们做些什么。
我们将使用的数据集是来自 BEIR(Scifact,CC BY-SA 4.0)的 5.18K 行数据集。为了创建一个向量库,我们将使用科学的声明列。
所以,数据已经加载了(虽然是一个小数据集,但嘿,这只是一个演示!),接下来的步骤是编码它。有很多方法可以做到这一点,例如分词、向量嵌入、图节点-实体关系等,但对于这个简单的示例,我们就使用向量嵌入。Muzlin 内置了对所有流行品牌(苹果、微软、谷歌、OpenAI)的支持,嗯,我是说它们的关联嵌入模型,但你懂的。我们就选择,嗯,HuggingFace,因为你知道,它是免费的,而我当前的 POC 预算是…就像省吃俭用一样。
太棒了!如果你能相信的话,我们已经走了一半了。难道只有我觉得,很多这些 LLM 库让你写了额外的 1000 行代码,依赖了成千上万的库,结果在老板要求演示时就崩溃了?不只是我吧?对吧?不管怎样,废话少说,我们真的只剩下两个步骤,就能让我们的过滤器启动并运行。第一个步骤是使用一个异常检测方法来评估嵌入向量。这可以构建一个无监督模型,给出当前或新嵌入中任何给定向量的可能性值。
不开玩笑,就这样了。你的模型已经完成了。Muzlin 完全兼容 Sklearn,并经过 Pydantic 校验。而且,MLFlow 也完全集成了数据记录功能。上面的示例并没有使用它,因此该结果会自动在你的本地目录中生成一个 joblib 模型。很酷吧?目前只有 PyOD 模型支持这种类型的 OD,但谁知道未来会怎样呢。
该死的丹尼尔,你怎么让这一切变得这么简单。敢打赌你一直在引导我,然后从这里开始一切都将一落千丈。
对于上面的回应,s..u..r..e 这个梗现在确实有点过时了。但除此之外,别开玩笑,最后一步就要来了,和之前的所有步骤一样,轻松得很。
好的,好吧,这是最长的脚本,但看……大部分内容只是为了玩玩而已。但让我们分析一下这里发生了什么。首先,OutlierDetector 类现在期待一个模型。我发誓这不是 bug,这是一项功能!在生产环境中,你不希望每次都在现场训练模型来进行推理,而且通常训练和推理是在不同的计算实例上进行,特别是在云计算上。所以,OutlierDetector 类为此提供了支持,让你加载一个已经训练好的模型,以便可以随时进行推理。YOLO。现在你要做的就是编码一个用户的问题,然后用 OD 模型进行预测,嘿,瞧瞧,这里,我们找到了一个小小的异常值。
那么,当用户的问题是一个异常值时,这意味着什么呢?很酷的事情,这一切都由你来决定。存储的文档很可能没有任何上下文能够以有意义的方式回答这个问题。你可以选择将其重新引导,要么告诉测试团队的 Kyle 停止胡闹,要么更严肃一点,节省代币并设置一个默认响应,例如“抱歉,Dave,我恐怕做不到” (哦,HAL 9000,你真幽默,也请不要太空我)。
总结一下,集成更好(哈哈,给数学爱好者的数学笑话)。但实际上,经典的机器学习已经存在了很久,并且在生产环境中更具可靠性。我相信未来更多的工具应该融入这种理念,尤其是在我们所有人共同参与的生成性 AI 过山车旅程中,(顺便说一下,这趟旅程花费的代币实在太多)。通过使用异常检测,偏离主题的查询可以快速被重新引导,从而节省计算和生成成本。作为额外的福利,我甚至提供了一个选项,可以用 GraphRAGs 来实现这一点,太棒了——极客们团结起来!前进吧,享受那些开源开发者为了让我们免费使用而付出了太多失眠的工具。祝你旅途愉快,记得享受乐趣!
超越技能:解锁数据科学家的全部潜力。

图像由作者通过 DALL-E / OpenAI 生成。
通过赋予数据科学家更多的责任,超越技术任务,来解锁他们的隐藏价值,推动创新和战略洞察力。
·发布于Towards Data Science ·阅读时间 16 分钟·2024 年 10 月 30 日
--
[本文转载自 O'Reilly Radar,点击此处]
介绍
现代组织将数据视为一种战略资产,推动效率提升、增强决策能力,并为客户创造新的价值。在组织的各个部门——产品管理、市场营销、运营、财务等——团队充满了关于数据如何提升业务的创意。为了实现这些创意,公司热衷于招聘数据科学家,利用他们的技术技能(Python、统计学、机器学习、SQL 等)。
尽管有这种热情,许多公司仍然在很大程度上未能充分利用他们的数据科学家。组织仍然狭隘地将数据科学家仅仅作为执行现有想法的工具,忽视了他们所能带来的更广泛价值。除了技能之外,数据科学家还拥有独特的视角,使他们能够提出自己的创新商业想法——这些想法新颖、战略性强、具有差异化,并且很难由其他人提出,除非是数据科学家。
过度关注技能和执行
可悲的是,许多公司以一种表明他们对数据科学家的想法不感兴趣的方式行事。相反,他们将数据科学家视为仅仅利用其技能的资源。职能团队提供需求文档,详细指定计划:“这是你要为我们构建这个新系统的方式。感谢你的合作。”没有提供背景,也没有征求任何意见——除了交付的估算。数据科学家还会收到各种临时的请求,如战术分析或运营仪表盘¹的要求。请求的积压变得如此庞大,以至于工作队列通过类似 Jira 的票务系统进行管理,这些系统剥夺了请求的任何商业背景(例如,“给我提供 VIP 客户购买的热门产品”)。一个请求引发另一个请求²,形成一种无休止的努力,使得数据科学家没有时间为自己思考。然后是各种不透明的数据拉取请求:“请给我这些数据,这样我就可以进行分析。”这种做法是边缘化的——就像要求斯蒂芬·库里传球,以便你来投篮。这不是一种合作关系;它是一种使数据科学沦为仅仅支持其他团队想法的职能。这种做法虽然能执行任务并创造一些价值,但却无法发挥数据科学家真正能够提供的全部潜力。
正是这些想法
数据科学家的未开发潜力不在于他们执行需求或请求的能力,而在于他们提出的变革业务的想法。这里所说的“想法”是指可以推动业务向更好或新方向发展的新能力或策略——从而带来收入、利润或客户保持的增长,同时提供可持续的竞争优势(即难以被竞争对手复制的能力或策略)。这些想法通常以机器学习算法的形式出现,能够在生产系统中自动化决策。例如,一位数据科学家可能开发出一种算法,通过最优平衡溢出和短缺成本来更好地管理库存。或者,他们可能创建一个模型来检测隐藏的客户偏好,从而实现更有效的个性化。如果这些听起来像是商业想法,那是因为它们确实是——但这些想法不太可能来自业务团队。这类想法通常来自数据科学家,他们独特的认知储备和数据观察使得他们非常适合发掘这样的机会。
利用独特认知储备的想法
认知储备是个人可以运用来思考、解决问题或处理信息的工具、策略和方法的范围(Page 2017)。这些知识储备受到我们背景的塑造——教育、经验、培训等等。一个特定职能团队的成员通常会有相似的知识储备,因为他们有共同的背景。例如,市场营销人员会学习像SWOT 分析和ROAS这样的框架,而财务专业人士则学习像ROIC和Black-Scholes这样的模型。
数据科学家有着独特的认知储备。尽管他们的学术背景可能各不相同——从统计学到计算机科学,再到计算神经科学——但他们通常共享一个定量工具包。这个工具包包括适用于广泛问题的框架,通常有一些易于理解的名字,如“报童模型”、"旅行商问题"、“生日问题”以及许多其他框架。这个工具包还包括机器学习算法⁵的知识,如神经网络、聚类和主成分分析,它们被用来找到复杂问题的经验解。除此之外,它们还包括像大 O 符号、中心极限定理和显著性阈值这样的启发式方法。所有这些构建模块都可以用共同的数学语言表达,使它们可以轻松地在不同领域之间转移,尤其是在商业领域。
数据科学家的知识储备与商业创新特别相关,因为在许多行业⁶中,从数据中学习的条件几乎是理想的,原因在于它们具有高频事件、明确的目标函数⁷,以及及时且明确的反馈。零售商拥有数百万笔产生收入的交易。流媒体服务看到数百万次观看事件,这些事件可以反映客户的兴趣。等等——数百万或数十亿个事件,信号清晰且迅速显现。这些是归纳学习的单元,构成了学习的基础,尤其是在机器的帮助下。数据科学的知识储备,以其独特的框架、机器学习算法和启发式方法,非常适合从大量事件数据中提取知识。
创意的诞生源于认知库与业务背景的结合。数据科学家在参加业务会议时,经常会体验到灵感的闪现。当一位运营经理描述库存过期问题并提到“我们需要买够,但又不能买太多”时,数据科学家的眉毛从笔记本后面抬起,低声自语:“新闻商模型。”一位产品经理问:“随着产品数量的增加,这个过程如何进行扩展?”数据科学家下意识地在笔记本上写下“O(N²)”,这表示该过程将会超线性扩展。当一位市场营销人员提到客户细分话题,抱怨说:“有这么多客户属性,我们怎么知道哪些最重要?”时,数据科学家立刻发短信取消了晚上的计划。今晚,她将迫不及待地尝试对客户数据进行主成分分析⁸。
没有人要求提出创意。这仅仅是一次战术性会议,目的是回顾业务状态。然而,数据科学家几乎被促使去进行创意构思。“哦,哦,我想到一个了,”她自言自语。创意构思甚至难以抑制。然而,许多公司无意中似乎抑制了这种创造力。实际上,我们的数据科学家可能根本不会被邀请参加那次会议。数据科学家通常不会被邀请参加操作性会议,也通常不会被邀请参加创意会议,而这些会议往往仅限于业务团队。相反,会议小组会将 Jira 任务分配给数据科学家去执行。没有背景信息,这些任务无法激发创意。数据科学家的认知库未能得到有效利用——这无疑是一个错失的机会。
从数据观察中诞生的创意
除了他们的认知库之外,数据科学家还带来了一项关键优势,使他们的想法具有独特的价值。由于他们深度沉浸于数据中,数据科学家能够发现那些未曾预见的模式和见解,从而激发出新的商业创意。它们之所以新颖,是因为没有人会想到这些——产品经理、执行官、市场营销人员——甚至数据科学家也未必会想到。许多创意并非凭空构思出来,而是通过对数据的观察而显现出来。
公司数据仓库(数据仓库、数据湖等)包含着信息中休眠的原始洞察。数据科学家在工作过程中,常常会偶然发现一些引人入胜的模式——一种奇特的分布、一种不直观的关系等等。这个意外的发现激发了他们的好奇心,促使他们进一步探索。
想象一个数据科学家正在做她的工作,执行一个临时请求。她被要求编制一个特定客户群体购买的热门产品清单。令她惊讶的是,各个群体购买的产品几乎没有什么不同。大多数产品在所有群体中购买的比率差不多。真奇怪。这些群体是基于客户自愿选择的个人描述创建的,多年来公司一直认为这些群体是有意义的,可以用来管理产品。“一定有更好的方法来细分客户群体,”她想。她进一步探索,展开了一项非正式的临时分析。没有人要求她这么做,但她控制不住自己。她没有依赖客户用来描述自己的标签,而是专注于他们的实际行为:他们点击、查看、喜欢或不喜欢哪些产品。通过结合定量技术——矩阵分解和主成分分析——她想出了一个方法,将客户放入一个多维空间中。在这个空间中,彼此相邻的客户群体形成了有意义的分组,更好地反映了客户偏好。这种方法还提供了一种将产品放入同一空间的方法,从而允许对产品和客户之间的距离进行计算。这可以用来推荐产品、规划库存、制定营销策略以及其他许多商业应用。所有这一切都源于一个令人惊讶的观察:那些经久不衰的客户群体几乎无法解释客户的行为。像这样的解决方案必须由观察驱动,因为如果没有数据表明相反,没有人会想到要探究一种更好的客户细分方法。
顺便提一下,数据科学家使用的主成分算法属于一种称为“无监督学习”的算法类别,这进一步说明了观察驱动的洞察力的概念。与“监督学习”不同,在监督学习中,用户指示算法要寻找什么,而无监督学习算法让数据自行描述它的结构。它是基于证据的;它量化并排序每个维度,提供一个相对重要性的客观衡量标准。数据在发声。我们常常试图引导数据服从我们人类构思的分类方案,这些方案对我们来说既熟悉又方便,唤起直觉和刻板的原型。这种方法令人满意且直观,但往往是脆弱的,在实践中难以维持。
这样的例子并不罕见。当沉浸在数据中时,数据科学家很难不偶然发现意外的发现。而当他们发现时,更难抵挡进一步探索的诱惑——好奇心是一个强大的推动力。当然,她运用了她的认知能力来完成这项工作,但整个分析的灵感来自于对数据的观察。对于公司而言,这种分心是福不是祸。我见过这种无指导的研究带来更好的库存管理实践、更好的定价结构、新的商品策略、改进的用户体验设计和许多其他能力——这些都是没有被要求的,而是通过观察数据发现的。
发现新见解不是数据科学家的工作吗?是的——这正是本文的重点。问题出现在当数据科学家仅仅因为技术能力而被重视时。将他们仅视为支持团队,使他们只能回答特定问题,从而限制了他们对数据洞察的更深层次探索。应对即时需求的压力常常使他们忽视异常现象、非直观结果以及其他潜在发现。如果数据科学家基于观察提出一些探索性研究,回应几乎总是,“不,专注于 Jira 队列。”即使他们花费自己的时间——包括夜晚和周末——研究一个数据模式,发现一个有前景的商业创意,如果没有在计划或路线图上,也可能仍然会面临阻力。路线图往往是僵化的,忽视了新的机会,甚至是有价值的机会。在一些组织中,数据科学家可能会因为探索新想法而付出代价。数据科学家通常通过他们如何服务于职能团队来进行评判,回应他们的请求,满足短期需求。当探索新想法会影响绩效评估时,几乎没有动力去进行这种探索。实际上,数据科学家经常是在他们的工作之外发现新见解,而不是因为他们的工作。
与众不同的想法
这两点——他们的认知能力和来自数据的观察——使得数据科学家提出的想法独具价值。这并不是说他们的想法必然比业务团队的想法更好。而是说,他们的想法与业务团队的想法不同。而不同本身就有一套好处。
拥有一个看似不错的商业想法并不能保证这个想法会产生积极的影响。证据表明,大多数想法会失败。当正确衡量因果关系时⁹,绝大多数商业想法要么完全没有任何影响,要么实际上会损害指标。(更多统计数据请参见此处)。鉴于成功率较低,创新公司会构建想法投资组合,希望至少有几个成功的想法能够帮助它们实现目标。更聪明的公司则使用实验¹⁰(A/B 测试)在小范围客户中试验其想法,从而评估其影响力,然后再决定是否广泛推广。
这种投资组合方法结合实验的优势,既受益于想法的数量,也受益于想法的多样性¹¹。它类似于多样化股票投资组合。增加投资组合中的想法数量可以提高正面结果的可能性——一个对公司产生实质性积极影响的想法。当然,随着你增加想法,你也增加了坏结果的风险——那些没有任何作用,甚至带来负面影响的想法。然而,许多想法是可逆的——亚马逊的杰夫·贝佐斯所说的“单向门”(Haden 2018)。那些未能产生预期结果的想法,在小范围客户测试后可以被修剪掉,从而大大减轻其影响,而成功的想法则可以推广到所有相关客户,极大地放大其影响。
所以,将想法添加到投资组合中可以在不增加太多下行风险的情况下提高上行潜力——增加的越多,越好¹²。 但是,这里有一个假设,即这些想法是独立的(无关的)。如果所有的想法都相似,那么它们可能会一起成功或一起失败。这就是多样性发挥作用的地方。来自不同群体的想法将利用不同的认知储备和不同的信息集合,这使得它们变得不同,从而不太可能相互关联,产生更多样化的结果。对于股票来说,多样化投资组合的回报将是各个单独股票回报的平均值。然而,对于想法而言,由于实验可以让你减轻不好的想法并放大好的想法,投资组合的回报可能更接近最好的想法的回报(Page 2017)。
除了构建多元化创意的组合外,数据科学家和业务团队之间的合作可以大大加强某个单一创意¹³。 当他们合作时,他们的技能库能够弥补彼此的盲点(Page 2017)¹⁴。通过将多个团队的独特专长和洞察结合起来,创意变得更加有力,正如多元化团队通常在知识竞赛中表现更好一样。然而,组织必须确保真正的合作发生在创意生成阶段,而不是将职责划分得过于严格,让业务团队只专注于产生创意,数据科学家则被 relegated 执行任务。
培养创意
数据科学家不仅仅是执行现有创意的熟练资源;他们是创新思维的源泉。他们的创意独特且有价值,因为(1)他们的认知技能库与具备适当学习条件的企业高度相关,(2)他们在数据中的观察可能带来新的洞察,(3)他们的创意与业务团队不同,为公司创意组合增添了多样性。
然而,组织压力往往阻止数据科学家充分发挥他们的创意。被技能驱动的任务压得喘不过气,缺乏业务背景,他们往往被激励去仅仅完成合作伙伴的要求。这种模式消耗了团队的执行能力,同时让他们的认知技能库和洞察几乎没有被利用。
以下是一些组织可以遵循的建议,以更好地利用数据科学家,并将他们的角色从单纯的执行者转变为创意的积极贡献者:
-
给他们背景,而不是任务。 给数据科学家分配任务或完整的需求文档虽然可以促使他们完成工作,但却无法激发他们的创意。相反,应该给他们背景。如果机会已经被识别出来,应该通过开放的对话广泛地描述它,让他们自己框定问题并提出解决方案。邀请数据科学家参加运营会议,让他们吸收背景信息,这可能会激发出新的创意,甚至是那些尚未被考虑过的机会。
-
为探索创造空间。 公司经常让数据科学家被任务压得喘不过气来。这看起来可能很矛盾,但保持资源 100% 利用率实际上是非常低效的¹⁵。没有时间进行探索和意外的学习,数据科学团队无法发挥其全部潜力。应该为他们的独立研究和探索留出时间,采取类似谷歌的 20% 时间或类似的方法。
-
消除任务管理队列。 任务队列会与数据科学团队形成一种事务性、以执行为中心的关系。如果优先事项是自上而下分配的,那么应该以一般性的、没有框架的机会形式给出,这些机会需要通过真正的对话来提供背景、目标、范围和组织影响。优先事项也可能来自数据科学团队内部,要求职能合作伙伴的支持,数据科学团队则提供必要的背景信息。我们不会为产品或营销团队分配 Jira 票据,数据科学团队也不应有所不同。
-
让数据科学家对真正的业务影响负责。 根据数据科学家对业务结果的影响来衡量他们,而不仅仅是看他们如何支持其他团队。这赋予了他们优先考虑高影响力创意的自主权,无论创意来源于何处。此外,将绩效与可衡量的业务影响挂钩,能够明确低价值临时请求的机会成本。
-
招聘具备适应力和广泛技能的人员。 寻找那些能在模糊和不断变化的环境中茁壮成长的数据科学家,在这些环境中,角色和责任可能并不总是明确的。优先考虑那些对业务影响有强烈渴望的候选人,他们将自己的技能视为推动结果的工具,并擅长识别与公司广泛目标一致的新机会。招聘具有多样化技能的数据科学家能够建立端到端的系统,减少交接的需要,从而降低协调成本——这在创新的早期阶段尤为重要,因为那时迭代和学习最为关键。
-
聘用具有成长心态的职能型领导者。 在新的环境中,避免依赖那些过于依赖在成熟环境中成功经验的领导者。相反,应该寻找那些对学习充满热情、重视协作的领导者,他们能够利用多样化的视角和信息来源来推动创新。
这些建议要求组织具备正确的文化和价值观。文化需要鼓励实验,以衡量想法的影响,并认识到许多实验会失败。它需要将学习视为明确的目标,并理解对于某些行业来说,大多数知识尚未被发现。文化还必须能够放弃指挥控制的清晰度,以换取创新。虽然在初创公司中更容易实现这一点,但这些建议也可以指导成熟组织在经验和信心的推动下不断发展。将组织的重点从执行转向学习是一项具有挑战性的任务,但回报可能是巨大的,甚至对生存至关重要。对于大多数现代公司而言,成功将取决于他们能否利用人类潜力进行学习和创意——而不仅仅是执行(Edmondson 2012)。数据科学家的潜力不仅仅在于他们能够执行现有的想法,更在于他们能够提出那些尚未被任何人想象过的新创意。
脚注
-
毫无疑问,仪表盘在提供业务运营可视化方面有其价值。然而,仪表盘在提供可操作性洞察方面的能力是有限的。汇总数据通常充满了混杂因素和系统性偏见,因此很少适合做决策。构建和维护仪表盘所需的资源需要与数据科学团队可以做的其他可能产生更大影响的项目进行权衡。
-
数据相关的查询往往引发的问题比解答的问题还要多,这是一个广为人知的现象。
-
我使用了“increased”代替“incremental”,因为后者通常与“小的”或“边际的”相关。数据科学项目的影响可能是巨大的。我在这里使用这个术语是为了表示作为一种改进的影响——尽管没有对现有商业模型做出根本性改变。
-
与为人类消费而使用的数据不同,例如简短的总结或仪表盘,后者的确有其价值,因为它们能为我们的人工员工提供信息,但通常在直接可操作性方面有限。
-
我反对将对各种算法的知识称为技能,因为我认为更重要的是强调它们在特定情境下的概念适用性,而不是训练或实施任何特定方法的实用性。
-
与像医学这样的领域相比,电子商务、社交网络和流媒体内容等行业具有更有利的学习条件,因为医学领域中的事件发生频率要低得多,而反馈时间也要长得多。此外,在医学的许多方面,反馈可能非常模糊。
-
通常是收入、利润或用户留存。然而,对于公司来说,识别单一的目标函数可能是具有挑战性的。
-
自愿的实验性探索在数据科学家中很常见,通常是由好奇心、对影响力的渴望、对经验的追求等驱动的。
-
诚然,关于商业创意成功率的数据可能存在偏见,因为大部分数据来自于那些通过在线服务进行实验的科技公司。然而,至少在轶事上,低成功率似乎在其他类型的商业职能、行业和领域中是一致的。
-
并非所有创意都适合进行实验,因为样本量无法达到、无法隔离实验组、伦理问题或其他因素。
-
我故意排除了“创意质量”的概念,因为根据我的经验,我很少看到证据表明一个组织能够在候选池中辨别出“更好”的创意。
-
通常,开发和尝试一个想法的真正成本是人力资源——工程师、数据科学家、产品经理、设计师等。这些资源在短期内是固定的,并且作为限制,决定了在特定时间段内可以尝试的想法数量。
-
见杜克大学教授马丁·鲁夫,他研究了创新的咖啡馆模型(咖啡馆是指将不同背景的人聚集在一起交流的类比)。多元化的网络比线性网络创新性高三倍(鲁夫,2002)。
-
数据科学家将会欣赏这种与集成模型的类比,在集成模型中,个体模型的错误可以相互抵消。
-
见目标,作者艾利亚胡·M·戈德拉特,该书在供应链和制造生产线的背景下阐述了这一点。保持资源高于当前需求的水平,使得公司能够应对需求的突发增长,这将会自我支付。这个做法同样适用于人力资源。
-
通过随机对照试验进行因果测量是理想的,而算法能力对此非常适应。
-
不可否认的是,临时请求的价值并不总是显而易见的。但应该对数据科学资源的使用设定一个高标准。提交一个 Jira 工单太过容易。如果某个议题足够重要,它将值得召开会议来传达背景和机会。
-
如果你正在阅读这篇文章,并且怀疑那些花时间认真处理 Jira 工单的数据科学家是否能够提出一个好的商业想法,你可能并没有错。那些习惯处理工单的人很可能不是创新者,或者已经被深深灌输到支持角色中,以至于失去了创新的意愿。
-
随着系统的发展,可以添加更多的专业资源,以使系统更加健壮。这可能会导致一场争夺。然而,通过首先找到成功,我们能更加谨慎地使用我们宝贵的开发资源。
参考文献
-
Page, Scott E. 2017. The Diversity Bonus. 普林斯顿大学出版社。
-
Edmondson, Amy C. 2012. Teaming: How Organizations Learn, Innovate, and Compete in the Knowledge Economy. Jossey-Bass。
-
Haden, Jeff. 2018. “亚马逊创始人杰夫·贝索斯:成功的人是如何做出如此聪明的决策的。” Inc.,12 月 3 日。
www.inc.com/jeff-haden/amazon-founder-jeff-bezos-this-is-how-successful-people-make-such-smart-decisions.html。 -
Ruef, Martin. 2002. “强关系、弱关系与岛屿:组织创新的结构与文化预测因素。” 工业与企业变革 11 (3):427-449。
doi.org/10.1093/icc/11.3.427。
超越盲区
使用深度学习修复雷达盲区
·发表于 Towards Data Science ·19 分钟阅读·2024 年 4 月 5 日
--

概述
在本文中,我们回顾了最近关于雷达盲区图像修复的工作中的高层次细节。我们讨论了主要的科学问题、修复技术、模型架构决策、准确性指标、不确定性,并最后分析了模型的可解释性(XAI),希望这些信息能够帮助他人在规划未来类似项目时提供帮助。此项工作最近发表在美国气象学会的《地球科学人工智能(AIES)》期刊中, doi.org/10.1175/AIES-D-23-0063.1,我们建议读者查看以获取更多项目细节。
动机
雷达观测是进行降水预测等任务的强大信息来源。这些系统每天被数百万用户使用,帮助他们规划生活,并且它们的预测准确性对农业、旅游业和户外休闲产业具有巨大经济影响。但这些系统是如何工作的呢?简而言之,这些预测模型将降水率与雷达信号与大气中降水气象物体相互作用时的回波功率测量联系起来(图 1)。通过足够大的参考数据集,我们可以利用雷达资料(以及一些大气状态变量)反推出地面降水的估计值。

图 1: a) 反射率剖面图,显示云层反向散射的功率;b) 云中降雪率估算的垂直剖面;c) 从 b)所示最低降水云层推算出的地面降雪率。图片来自 King 等人,2020 年(doi.org/10.1029/2019EA000776)。
与垂直指向的地面雷达仪器静止不动不同,卫星在太空中没有限制,并且由于其轨道可以提供更丰富、全面的全球降水模式视图。然而,与地面雷达不同,卫星仪器由于指向地球而表现出一种独特的测量问题:雷达盲区。顾名思义,盲区是卫星无法直接观察的雷达剖面的一部分。当向下指向的雷达信号到达地球表面时,来自地面的反向散射产生的信号被衰减,并且饱和在噪声中(因此无法观察到)。图 2 展示了地面雷达与空间雷达之间的比较(以及相应的盲区)。

图 2: 多面板的垂直反射率剖面图,来自地面雷达、CloudSat 和全球降水测量任务,以及它们各自的雷达盲区。图片来自 Kidd 等人,2021 年(doi.org/10.3390/rs13091708)。
尽管盲区的大小可能会有所不同,但它是活动空间系统(例如 GPM、CloudSat)上常见的问题,并将在未来的地球观测任务(例如 EarthCARE 和 AOS)中持续作为不确定性的来源。此外,尽管该区域仅占整个剖面的一小部分(例如,仅占 CloudSat 垂直范围的约 4%),但大气的 0-2 公里范围内可能包含大量降水云(Lamer 等人,2020)。因此,通过掩盖该区域,我们可能会遗漏大量雪(或在虚幻雨的情况下高估雪量),从而进一步增加已经存在的不确定性的地面积雪量估算误差。
…盲区导致反射率被低估最多 1 dB,事件数量变化幅度为+/- 5%,降水量低估了 9 到 11 个百分点(Maahn 等人,2014)。
我们能做些什么?
我们的解决方案
从本质上讲,图像修复的思想已经存在数十年,应用于图像修复(例如,去除家庭照片上的划痕)或物体移除(例如,去除你度假时照片中的游客)(Guillemot 和 Le Meur,2013)。最初,这种类型的修复是一项昂贵的工作,需要由受过训练的艺术家手工完成(图 2),但近年来越来越明显的是,计算机也非常擅长这项任务(与人类相比,训练时间短得多)!

图 3: 由受过训练的专业人士手工修复画作。图片由 Ana Alba 提供。
尽管这些技术的早期迭代,如基于结构的线性插值、基于纹理合成的 Efros 插值,以及基于扩散的拉普拉斯或纳维-斯托克斯插值,在某些情况下可以很好地工作,但在处理大面积缺失时往往效果不佳。图 4 展示了这些技术之间的差异,其中大象的掩模被填充到图像的中心。这些技术通常严重依赖于目标修复区域边缘的图像信息/模式,并且在进行智能预测时,往往无法有效利用场景的全局上下文信息。

图 4: 物体移除应用 a) 不同类别方法的掩模和修复结果;b) 各向异性扩散,c) 示例基础;d) 补丁稀疏表示,e) 与全局能量最小化的混合;f) 补丁偏移。图片来源:Guillemot 和 Le Meur, 2013 (doi.org/10.1109/MSP.2013.2273004)。
然而,近年来计算机视觉领域迅猛发展,主要得益于计算能力的提升和更新、更高效的机器学习技术的出现。例如,使用生成对抗网络(GAN)的生成方法近年来非常流行,OpenAI 的 DALL-E 或 Stability.ai 的 Stable Diffusion 可以基于简单的文本提示生成令人惊叹的真实图像。之前也有一些尝试使用类似方法进行修复的工作,但在修复任务中,现实感与保真度/稳定性之间存在一定的权衡(Lugmayr 等,2017;Geiss 和 Hardin,2021)。
例如,虽然你可能生成了一个从人眼来看非常不错的图像区域,但如果与参考图像进行比较,实际的像素值不一定是正确的,并且可能会根据提供的随机噪声/种子变化较大(图 5)。不过,这并不令人意外,因为这些技术并非以此类约束为设计目标,而是为其他目的而存在。

图 5: 来自*Lugmayr 等人,2017 年(doi.org/10.1109/CVPR52688.2022.01117)的图像修复项目的去噪扩散概率模型(DDPM)示例集。
相反,在本工作中我们专注于另一种机器学习架构:U-Net。U-Net 是一类卷积神经网络(CNN),它以图像(通常是图像)的形式输入信息,并生成与输入相同维度的输出。U-Net 通常用于图像分割,其编码器-解码器架构使得模型能够学习图像的局部和全局特征(这种上下文对于正确解读图像内容,特别是在图像修复过程中,通常非常有价值)。我们将使用这种架构来教导模型学习高空云中的潜在特征,以预测前述雷达盲区的近地面反射率数据。
数据
本项目使用的数据主要来自两个来源:
这两个数据集是公开可用的(由知识共享署名 4.0 国际版(CC BY 4.0)授权),并位于美国阿拉斯加北部的两个北极地点——北坡和奥利克托克点(图 6)。由于我们关注的是降雪,因此我们将观测限制在 2 摄氏度以下的寒冷时期。此外,为了避免由于自相关导致的过拟合,数据被划分为连续的训练/验证/测试块,如图 6.c 所示。我们使用来自 ARM KaZR 的雷达反射率数据,以及来自 ERA-5 的温度、比湿、u 分量风和 v 分量风数据。

图 6: a) 研究地点位置;b) NSA 和 OLI 的地面气象温度时间序列;c) 用于模型训练的数据拆分方法。图像由作者提供。
我们的模型
在此任务中应使用哪种类型的 U-Net?我们在本项目中尝试了多种不同的 U-Net,包括 Zhou 等人,2018 年的 UNet++模型和 Huang 等人,2020 年的 3Net+。但让我们先聊聊模型架构。例如,为什么使用这些方法而不是传统的 U-Net?首先,我们回顾一下 U-Net 是如何工作的。
U-Net 通常被认为由三个部分组成:编码器路径、瓶颈层和解码器路径。编码器负责处理最初的高分辨率图像,通过一系列卷积和池化步骤,将空间信息转换为丰富的特征信息(学习图像的潜在上下文)。根据网络的深度,这些潜在特征在 U 形网络底部的最低维瓶颈层中被最密集地编码。在这一点上,图像数组的大小可能只有原始图像的一个小部分,尽管你已经丧失了大部分空间信息,但模型已经识别出一组嵌入,代表它认为的关键元素。然后,在解码器路径中,反向过程发生。瓶颈层中低分辨率、特征丰富的数组会被下采样,直到特征信息被转化回空间信息(生成一个最终图像,其尺寸与原始图像相同)。
U-Net、U-Net++ 和 3Net+ 之间的关键区别之一在于每种变体如何处理跳跃连接(图 7)。跳跃连接允许这些模型直接将一些数据从编码器路径跳跃到解码器中使用,这有助于在解码过程中传递低级特征信息,并生成一个更稳定的模型,使其能够以有意义的方式收敛。例如,在一个普通的 U-Net 中,这些连接只是将来自收缩编码器路径的特征图与解码扩展路径中相应的层连接起来。

图 7: a) U-Net;b) U-Net++;以及 c) 3Net+ 模型架构的比较。图像来源于 Huang 等人,2020(doi.org/10.48550/arXiv.2004.08790)。
UNet++ 引入了一系列嵌套和密集的跳跃路径,试图解决传统 U-Net 中最优架构的未知深度问题。UNet++ 不仅仅是从编码器到解码器有直接连接,它有多个跳跃路径。对于编码器中的每个层级,都有跳跃连接到解码器中的所有后续层级,形成一个密集的连接集。这些嵌套的跳跃连接旨在比普通的 U-Net 更有效地捕捉和融合不同语义层次的特征,然而这也带来了模型变大(更多参数)和训练时间增加的代价。
3Net+基于前述技术的思想,并且是我们最终模型(图 8)中使用的架构。该方法将跳跃连接分解为一系列跨跳跃连接(类似于经典的 U-Net)、跳跃间连接和跳跃内连接。这些跳跃间和跳跃内连接通过传递信息,结合特征图中的低级细节与高级语义,有效利用了场景中的多尺度特征,同时相比于 U-Net++模型,使用了更少的参数。

图 8: a) 128x128 大小的 KaZR 和 ERA-5 输入变量块;b) 我们模型的 3Net+架构及深度监督层。图像由作者提供。
此外,我们的模型利用深度监督从解码器每一层的全尺度聚合特征图中学习层次化表示。这帮助模型通过检查场景的更广泛上下文来正确定位盲区中的云。在接下来的部分中,我们将比较仅在反射率上训练的 3Net+模型(以下简称为 3+_1),另一个在反射率和 ERA5 数据上训练的版本(3+_5),以及两种使用重复外推(REP)和滑动平均(MAR)方法的线性修补技术。有关这些方法如何实现的更多细节,请参考我们的 AIES 论文。
保真度
为了全面评估模型在盲区重建准确性方面的表现,我们将首先检查一些常见的案例,随后对整个测试数据集进行更一般性的统计分析。请注意,从此以后展示的所有结果均严格来源于未见的测试集观测数据。
案例研究
展示了 REP、MAR、3+_1 和 3+_5 模型的盲区反射率值的示例(取自 NSA 和 OLI),以及对应的目标 KaZR VAP 产品(即地面真实值)在最左侧列中的数据,见图 9/10。
这第一组示例突出显示了在两个位置都常见的近地表反射率梯度和云隙的情况。黑色虚线表示 1.2 公里的盲区阈值(低于此值的区域被遮蔽并由各模型重建),阴影区域表示在 U-Net 预测中修补部分的高不确定性区域(后续将在蒙特卡罗丢弃部分详细讨论)。

图 9: 展示了 REP、MAR、3+_1 和 3+_5 模型的盲区反射率值的示例(取自 NSA 和 OLI),以及对应的目标 KaZR VAP 产品(即地面真实值)在最左侧列中的数据。这组示例突出显示了近地表反射率梯度和云隙的情况。图像由作者提供。
我们发现,线性模型(REP 和 MAR)在深层均匀系统中表现良好,但在更复杂的情况下表现不佳。此外,由于它们依赖于盲区阈值反射率,REP 和 MAR 未能捕捉到反射率梯度(垂直和水平),而这些梯度通常是通过 U-Net 捕捉到的。最后,浅层的北极混合相云也可以通过 U-Net 解决,包括云隙和雨丝的情况(图 10),这令人兴奋,因为这对地面降雪量有着重要的影响。

图 10: 与图 9 相同,现在关注于每个模型的补全预测中的浅层降雪和雨丝情况。图像由作者提供。
准确建模最具挑战性的情况是那些具有稀疏反射率剖面和遥远云层的情况。例如,考虑图 11 中展示的两个在不同年份发生在 NSA 的类似案例。从人的眼睛来看,这两个地点的云结构非常相似,然而一个在盲区内有一个浅层近地面云,而另一个则没有。线性补全技术在这里显然已超出了它们的舒适区,并且总是产生与 a)中所见相同的“无云”输出。然而,U-Net 模型仍然能够解决此类情况中的云存在问题,3+_5 模型通过使用来自 ERA-5 的附加上下文,更好地理解在这种情况下,大气条件可能导致盲区云的形成。

图 11: a) 一个例子,展示了一个高空云(其平均反射率约为-20 dBZ),在盲区内没有近地面活动;b) 一个与 a)类似的云结构,除了在这种情况下,确实存在一个近地面反射率带,在浅层云中,通过 3+_5 模型能够正确识别。图像由作者提供。
稳健性
如图 12 a)所示,U-Net 的 PSD 曲线与观测值相比,较线性方法更接近。此外,非线性重建产生了更为真实的最低云回波高度(如 b)中的概率密度图所示),这表明我们能更好地捕捉云的位置。最后,整个盲区结构在 c)-g)中得到了总结,其中线性方法能够捕捉到反射率的宏观尺度趋势,但未能捕捉到精细尺度的变化。我们注意到,U-Net 模型在-60 dBZ 附近存在轻微的“冷”偏差,这是由于它们倾向于给出更接近“无云”更为“安全”的估计,而高强度的降雪事件则较为罕见。

图 12: 云结构度量,包括:a) 垂直功率谱密度曲线;b) 最低反射率回波层概率密度图;c) 2D 反射率直方图。图像由作者提供。
此外,改善我们对地面降雪和虚假降雪情况的预测将对水文产生重大影响,并减少我们对雪积量的不确定性。因此,我们进行了检查,看看我们的模型在使用探测概率(POD)和误报率(FAR)时如何重建三个情况:
-
盲区云层存在(即是否检测到任何云层)
-
浅层降雪(即地面上有降雪,但在盲区阈值以下没有降雪)
-
虚假降雪(即在盲区阈值上检测到降雪,但在地面没有降雪)
每个指标的关键成功指数(CSI)如下面的性能图(图 13)所示,其中 3+_5 模型的整体表现最佳。浅层降雪的情况通常是最难重建的,因为我们看到这些情况通常很难准确重建(图 11)。

图 13: 各模型的云层、浅层降雪和虚假降雪检测性能图。图片来源:作者。
可解释性
为了进一步增加对模型决策过程的信任,我们还进行了系列的可解释人工智能(XAI)测试(即以机械化的方式理解模型行为)。这些测试旨在将模型行为与逻辑的物理过程联系起来,以激发对模型的信心,并为可能改善未来的回收提供额外的见解。如果我们能够发现数据中以前未知的关联,这将非常有价值!另外,每个单独的 XAI 方法给出的是“局部”的决策过程解释,因此结合多种测试以获得更稳健的理解是很有用的。
特征图
我们考虑的第一个、最基本的测试是特征/激活图。通过检查图 8.b 中编码器路径不同层次的 reLU 激活值,我们可以大致了解模型在给定输入的图像中注视的位置。如下面图 14 所示,3+_5 模型的 e1 编码器层通常关注云层边缘、反射率梯度以及盲区阈值的位置。

图 14: a) 输入到 3+_5 模型的反射率示例;b) 模型 e1 编码器层生成的 32 个特征图。图片来源:作者。
丢弃通道重要性
该项目的一个最大问题是 ERA-5 是否为模型提供了有用的背景。如果我们能够使用一个更简单的 U-Net 模型,仅依赖反射率(例如 3+_1 模型),那么我们应该这样做,因为这样更具计算效率。然而,如果 ERA-5 的额外大气状态变量为模型提供了有用的背景来涂色复杂系统,那么使用这个更复杂的模型可能是有必要的。
由于在本例中我们只有少量的输入(1 个必需的(雷达)和 4 个辅助的(温度、湿度、u-风和 v-风)),我们可以对输入空间进行穷举搜索,以评估它们对准确性的边际贡献。更正式地说,这种丢弃通道方法使用下面的公式(公式 1 / 公式 2)来计算提供的输入的边际重要性贡献。请注意,这种技术并未考虑输入之间可能的非线性相互作用。

公式 1 / 公式 2: 用于计算来自 N 个输入的边际重要性。图片来源:作者。
如果我们执行一系列此类测试运行(25 个周期),并检查验证损失的变化,我们可以大致了解哪些输入最有用。这些测试的结果如下图 15 所示,我们注意到随着我们添加更多 ERA-5 输入,验证损失呈下降趋势(这表明没有任何输入完全不重要)。此外,边际贡献到验证损失的结果表明,风数据总体上是最具影响力的。我们认为这种重要性可能源于这样一个事实:在对流层上层,风模式可以暗示中尺度大气动力学,如高压或低压系统、锋面和急流(这些当然与云/水气象的形成有关)。

图 15: 输入通道组合对 3Net+ 模型涂色性能的贡献。图片来源:作者。
显著性图
最后,我们还检查了一些案例的显著性图(图 16),以进一步比较 3+_5 和 3+_1 模型之间的重要性差异。这些像素归因的原生梯度显著性图灵感来自 Simonyan 等人(2014)的工作,并为模型识别为涂色准确性贡献重要信息的区域提供了额外的洞察。这些显著性图是通过将图像输入到网络中,然后提取基于输入在所有通道上输出的梯度来生成的。尽管这种方法简化,但它对于可视化观察图像中哪些部分在涂色盲区反射值时最有价值非常有用,允许直接绘制激活梯度。

图 16: 3+_1 和 3+_5 模型的原始梯度显著性图,以及在 NSA 的一些案例中对应的 ERA5 大气状态变量。图片由作者提供。
对于与盲区切割区域交叉的多层云(例如图 16.a),两个模型都聚焦于云的顶部和 1.2 公里的边界阈值,因为这些系统通常会延伸至地面,并具有相似的反射率强度。两个模型通常还会关注并围绕深层系统中的云间隙(例如图 16.b),然而,3+_5 模型中明显出现了一个重要的环带,指向平流层。这一反复出现的特征可能将上层对流层的风和湿度数据纳入了对近地面反射率的预测。有趣的是,3+_1 模型并不仅仅专注于场景中高反射率区域,还关注云周围的区域。
应用
本工作的主要目标是最终将训练好的表面 U-Net 应用于空间观测。尽管在这两个系统之间的分辨率匹配方面还需要完成额外的工作,但我们已针对靠近 NSA 的重合 CloudSat-CPR 观测进行了早期测试。这里的想法是,尽管这两个系统(虽然并未完全重叠)将观测到相同风暴系统的相似部分。我们考虑了一些示例,并在下方包括了一个浅层积云降雪案例。

图 17: a) 显示 CloudSat 过境颗粒路径(红色)、128 步站点重合的 CloudSat 足迹以及显示最接近站点的观测轮廓的白色点(黑圆圈代表站点周围 50 公里的半径);b) 浅层积云系统的 CloudSat 反射率轮廓,两个虚线之间的区域表示 a) 中的蓝色(重合)部分,白色阴影区域显示卫星盲区;c) 垂直 NSA KaZR 反射率轮廓,b) 中的对应过境时期显示在两个虚线之间的区域;d) 来自 c) 的 KaZR CloudSat 重合轮廓的特写;e) 来自 b) 的 CloudSat-NSA 重合轮廓;f-h) 分别为提供 e) 作为输入时的 REP、MAR 和 3+_1 修复盲区场景。图片由作者提供。
在这个例子中,我们注意到,空间 borne 和地面雷达都观测到大约 3 公里高的浅层云(然而 CloudSat 因表面杂波错过了盲区下方的反射率梯度增加)。当使用传统技术和我们的 U-Net 重建该区域时,我们发现 U-Net 是唯一能够准确表示大约 1 公里范围内反射率增加带的方法。更正式地说,如果我们查看最接近该站点的 CloudSat 观测值(白色虚线)与每个最接近重建区域之间的结构,使用 U-Net 时 Pearson 相关性显著提高(r_MAR=0.13 到 r_3+_1=0.74)。
尽管这些示例并未构成一个全面的分析,无法让我们对整体性能做出结论,但它们确实表明我们观察到的技能与我们在查看模拟的表面盲区时所注意到的一致。进一步的空间 borne 仪器应用工作正在进行中。
最终备注
在结束这篇已经很长的文章之前,我想强调一些我们在模型中加入的其他特性,并为那些有兴趣开发自己修补模型的人提供一些训练代码示例。
蒙特卡洛 Dropout
与传统的贝叶斯方法不同,我们并没有直接使用 U-Net 来生成基于物理的无关度估计。为了大致了解模型的信心和稳定性,我们决定在推理层引入 Dropout,基于 Gal 和 Ghahramani 2016 年的研究,这使我们能够为每个测试案例生成修补预测的分布。这些分布使我们能够为每个修补的像素生成置信区间,并进一步精细化我们的估计,聚焦于模型在修补时更有信心的区域。下图 17 展示了一个例子。

图 18: 蒙特卡洛 Dropout 示例输出(n=50 次迭代)。图像由作者提供。
我们通常对每个案例使用 N=50 次迭代,正如我们上面看到的,通常具有最高不确定性的区域是云边缘和云隙,因为模型在定位这些特征时常常会出现幻觉。
训练统计
本项目的模型训练在两种硬件环境中完成,包括基于 Linux 的 GPU 计算集群(托管在 Microsoft Azure 上)和一台运行 Windows 11 的高性能桌面计算机(更多系统细节见表 1)。在两天的时间里,还进行了广泛的贝叶斯超参数搜索。此外,在训练过程中应用了批量归一化、早停(n=20)、dropout 和 L2 正则化(岭回归)来帮助缓解过拟合问题。学习率衰减也在两个周期(450 和 475)应用,使模型能够更容易地在训练阶段结束时接近局部损失最小值。所有的训练运行和超参数搜索结果都通过Weights & Biases 云存储选项在线保存,以便监控模型的学习率和稳定性。

表 1: 用于模型训练的硬件摘要详情。图片由作者提供。
示例代码
GitHub 链接如下:github.com/frasertheking/blindzone_inpainting
但是,我想为那些有兴趣尝试的人提供一个实际的 3Net+实现概览(具有可变深度),该实现是基于 Tensorflow 的。
def conv_block(x, kernels, kernel_size=(3, 3), strides=(1, 1), padding='same', is_bn=True, is_relu=True, n=2, l2_reg=1e-4):
for _ in range(1, n+1):
x = k.layers.Conv2D(filters=kernels, kernel_size=kernel_size,
padding=padding, strides=strides,
kernel_regularizer=tf.keras.regularizers.l2(l2_reg),
kernel_initializer=k.initializers.he_normal(seed=42))(x)
if is_bn:
x = k.layers.BatchNormalization()(x)
if is_relu:
x = k.activations.relu(x)
return x
def unet3plus(input_shape, output_channels, config, depth=4, training=False, clm=False):
""" Prep """
interp = config['interpolation']
input_layer = k.layers.Input(shape=input_shape, name="input_layer")
xpre = preprocess(input_layer, output_channels)
""" Encoder """
encoders = []
for i in range(depth+1):
if i == 0:
e = conv_block(xpre, config['filters']*(2**i), kernel_size=(config['kernel_size'], config['kernel_size']), l2_reg=config['l2_reg'])
else:
e = k.layers.MaxPool2D(pool_size=(2, 2))(encoders[i-1])
e = k.layers.Dropout(config['dropout'])(e, training=True)
e = conv_block(e, config['filters']*(2**i), kernel_size=(config['kernel_size'], config['kernel_size']), l2_reg=config['l2_reg'])
encoders.append(e)
""" Middle """
cat_channels = config['filters']
cat_blocks = depth+1
upsample_channels = cat_blocks * cat_channels
""" Decoder """
decoders = []
for d in reversed(range(depth+1)):
if d == 0 :
continue
loc_dec = []
decoder_pos = len(decoders)
for e in range(len(encoders)):
if d > e+1:
e_d = k.layers.MaxPool2D(pool_size=(2**(d-e-1), 2**(d-e-1)))(encoders[e])
e_d = k.layers.Dropout(config['dropout'])(e_d, training=True)
e_d = conv_block(e_d, cat_channels, kernel_size=(config['kernel_size'], config['kernel_size']), n=1, l2_reg=config['l2_reg'])
elif d == e+1:
e_d = conv_block(encoders[e], cat_channels, kernel_size=(config['kernel_size'], config['kernel_size']), n=1, l2_reg=config['l2_reg'])
elif e+1 == len(encoders):
e_d = k.layers.UpSampling2D(size=(2**(e+1-d), 2**(e+1-d)), interpolation=interp)(encoders[e])
e_d = k.layers.Dropout(config['dropout'])(e_d, training=True)
e_d = conv_block(e_d, cat_channels, kernel_size=(config['kernel_size'], config['kernel_size']), n=1, l2_reg=config['l2_reg'])
else:
e_d = k.layers.UpSampling2D(size=(2**(e+1-d), 2**(e+1-d)), interpolation=interp)(decoders[decoder_pos-1])
e_d = k.layers.Dropout(config['dropout'])(e_d, training=True)
e_d = conv_block(e_d, cat_channels, kernel_size=(config['kernel_size'], config['kernel_size']), n=1, l2_reg=config['l2_reg'])
decoder_pos -= 1
loc_dec.append(e_d)
de = k.layers.concatenate(loc_dec)
de = conv_block(de, upsample_channels, kernel_size=(config['kernel_size'], config['kernel_size']), n=1, l2_reg=config['l2_reg'])
decoders.append(de)
""" Final """
d1 = decoders[len(decoders)-1]
d1 = conv_block(d1, output_channels, kernel_size=(config['kernel_size'], config['kernel_size']), n=1, is_bn=False, is_relu=False, l2_reg=config['l2_reg'])
outputs = [d1]
""" Deep Supervision """
if training:
for i in reversed(range(len(decoders))):
if i == 0:
e = conv_block(encoders[len(encoders)-1], output_channels, kernel_size=(config['kernel_size'], config['kernel_size']), n=1, is_bn=False, is_relu=False, l2_reg=config['l2_reg'])
e = k.layers.UpSampling2D(size=(2**(len(decoders)-i), 2**(len(decoders)-i)), interpolation=interp)(e)
outputs.append(e)
else:
d = conv_block(decoders[i - 1], output_channels, kernel_size=(config['kernel_size'], config['kernel_size']), n=1, is_bn=False, is_relu=False, l2_reg=config['l2_reg'])
d = k.layers.UpSampling2D(size=(2**(len(decoders)-i), 2**(len(decoders)-i)), interpolation=interp)(d)
outputs.append(d)
if training:
for i in range(len(outputs)):
if i == 0:
continue
d_e = outputs[i]
d_e = k.layers.concatenate([out1, out2, out3])
outputs[i] = merge_output(input_layer, k.activations.linear(d_e), output_channels)
return tf.keras.Model(inputs=input_layer, outputs=outputs, name='UNet3Plus')
未来
我知道这篇文章很长,我们涵盖了很多内容,但我想快速总结一下我们讨论的所有内容,特别是对那些读到这里的读者(或者跳到最后的读者)。
卫星雷达盲区是卫星地球观测降水任务中的一个持续问题,对全球水能预算计算有重要影响。为了克服传统线性插值方法在填补该区域时常见的问题,我们选择了使用非线性、深度监督的 U-Net 进行雷达盲区插值。U-Net 在几乎所有评估指标上都优于线性技术,甚至能够重建复杂的云结构,如多层云、云隙和浅层云。此外,通过使用各种可解释人工智能(XAI)技术,我们发现位于盲区阈值附近以及对流层顶沿线(尤其是风信息)的数据对模型的决策过程非常有用。尽管我们不建议这些模型完全替代当前基于物理的解决方案,但我们认为它们提供了一个独特的新视角,可以在未来的任务中补充其他回收方法。
我们目前正在进行一个跟进项目,直接应用于 CloudSat-CPR 观测。
参考文献
Gal, Y., & Ghahramani, Z. (2016). Dropout 作为贝叶斯近似:在深度学习中表示模型不确定性(arXiv:1506.02142)。arXiv. doi.org/10.48550/arXiv.1506.02142
Geiss, A., & Hardin, J. C. (2021). 使用深度学习修复雷达缺失数据区域。大气测量技术,14(12),7729–7747. doi.org/10.5194/amt-14-7729-2021
Guillemot, C., & Le Meur, O. (2014). 图像修复:概述与最新进展。IEEE 信号处理杂志,31(1),127–144. doi.org/10.1109/MSP.2013.2273004
Huang, H., Lin, L., Tong, R., Hu, H., Zhang, Q., Iwamoto, Y., Han, X., Chen, Y.-W., & Wu, J. (2020). UNet 3+: 一种全规模连接的 UNet 医学图像分割方法 (arXiv:2004.08790)。arXiv. doi.org/10.48550/arXiv.2004.08790
Kidd, C., Graham, E., Smyth, T., & Gill, M. (2021). 使用表面雷达和微雨雷达观测评估卫星观测的浅层/轻度降水反演的影响。遥感,13(9),文章 9. doi.org/10.3390/rs13091708
King, F., & Fletcher, C. G. (2020). 使用 CloudSat-CPR 检索估算加拿大北极地区的积雪量。地球与空间科学,7(2),e2019EA000776. doi.org/10.1029/2019EA000776
Lamer, K., Kollias, P., Battaglia, A., & Preval, S. (2020). 注意差距 — 第一部分:利用卫星雷达准确定位温暖海洋边界层云和降水。大气测量技术,13(5),2363–2379. doi.org/10.5194/amt-13-2363-2020
Lugmayr, A., Danelljan, M., Romero, A., Yu, F., Timofte, R., & Van Gool, L. (2022). RePaint: 使用去噪扩散概率模型进行图像修复。2022 年 IEEE/CVF 计算机视觉与模式识别大会(CVPR),11451–11461. doi.org/10.1109/CVPR52688.2022.01117
Maahn, M., Burgard, C., Crewell, S., Gorodetskaya, I. V., Kneifel, S., Lhermitte, S., Van Tricht, K., & van Lipzig, N. P. M. (2014). 太空雷达盲区如何影响极地地区衍生的地表降雪统计数据?《地球物理研究:大气》, 119(24), 13,604–13,620. doi.org/10.1002/2014JD022079
Simonyan, K., Vedaldi, A., & Zisserman, A. (2014). 深入卷积神经网络:可视化图像分类模型与显著性图 (arXiv:1312.6034)。arXiv. doi.org/10.48550/arXiv.1312.6034
Zhou, Z., Siddiquee, M. M. R., Tajbakhsh, N., & Liang, J. (2018). UNet++:一种用于医学图像分割的嵌套 U-Net 架构 (arXiv:1807.10165)。arXiv. doi.org/10.48550/arXiv.1807.10165
超越炒作:当生成性 AI 并非总是答案时
为什么预测性 AI 可能仍然是你的最佳选择
·发布于Towards Data Science ·阅读时间 8 分钟·2024 年 10 月 1 日
--

图片由Maria Lupan提供,来自Unsplash
我在谷歌的解决方案与思想领导团队(即 S&TL)工作。我们的职责是帮助大型公司采用 AI 技术,以提高它们的业绩并进行创新。在过去三年中,我与北美不同的 AI 团队和决策者合作,构建并测试针对特定业务需求的定制 AI 模型。
到了 2024 年,我注意到我与企业利益相关者的互动发生了显著变化。他们都希望使用生成性 AI来解决他们的业务问题。
对于其中一些客户,如果我问他们为什么选择生成性 AI,他们通常会带着好奇的目光看着我。他们并不完全理解这个问题。对于另一些客户,他们提到的是董事会或 C 级高管确定的战略目标。但往往,生成性 AI 并不是解决他们业务优先事项的最佳答案。
不,AI 并不是在 2022 年 11 月才被发明的
当 OpenAI 在 2022 年 11 月发布 ChatGPT 3.5 时,全球数百万的人开始玩这个引人入胜的新工具。我的邻居、杂货店的店员、我的叔叔和我的小侄女。AI 用户无处不在,人人谈论。
偏差-方差权衡解析:为初学者提供的带有代码示例的视觉指南
模型评估与优化
欠拟合与过拟合如何在你的模型上“斗争”
·发表于 Towards Data Science ·阅读时长 20 分钟·2024 年 11 月 25 日
--
每当有人构建预测模型时,他们都会面临这些经典问题:欠拟合和过拟合。模型不能太简单,但也不能过于复杂。这两者之间的互动被称为偏差-方差权衡,它影响着所有预测模型。
关于“偏差-方差权衡”这一主题的问题是,每当你尝试在线查找这些术语时,你会发现很多文章展示了完美的图表曲线。是的,它们解释了基本概念——但它们忽略了一个重要的点:它们过于关注理论,而不够关注现实世界中的问题,且很少展示在处理实际数据时会发生什么。
在这里,我们将不使用理论示例,而是使用一个真实数据集并构建实际的模型。一步一步地,我们将确切地看到模型是如何失败的,欠拟合和过拟合在实践中是怎样表现的,以及为什么找到正确的平衡如此重要。让我们停止偏差和方差之间的斗争,找到一个公平的中间地带。

所有视觉效果:作者使用 Canva Pro 创建。已优化为移动端显示,可能在桌面端显示过大。
什么是偏差-方差权衡?
在我们开始之前,为了避免混淆,让我们澄清一下在这里机器学习中使用的“偏差”与“方差”这两个术语。这些词在数学和数据科学的许多领域中的使用方式是不同的。
偏差可以有多种含义。在统计学中,它表示我们的计算与真实答案之间的偏离程度,而在数据科学中,它可以指对某些群体的不公平对待。即使是在机器学习的另一部分,在神经网络中,它是一个帮助网络学习的特殊数字。
方差也有不同的含义。 在统计学中,它告诉我们数字与其平均值之间的分散程度,而在科学实验中,它表示每次我们重复实验时结果的变化程度。
但在机器学习的“偏差-方差权衡”中,这些词语有特殊的含义。
偏差指的是模型学习模式的能力。当我们说一个模型有高偏差时,我们的意思是它太简单了,并且不断重复同样的错误。
方差在这里指的是当你给模型不同的训练数据时,它的答案会发生多大变化。当我们说方差很高时,我们的意思是模型在看到新数据时,其答案变化过大。
“偏差-方差权衡”并不是我们可以通过数字精确测量的东西。相反,它帮助我们理解我们的模型是如何工作的:如果一个模型有很高的偏差,它在训练数据和测试数据上的表现都不好;而如果一个模型有很高的方差,它在训练数据上表现很好,但在测试数据上表现较差。
这帮助我们修复模型在表现不佳时的问题。让我们设置我们的任务和数据集,看看如何应用这个概念。
⛳️ 设置我们的任务
训练与测试数据集
假设你拥有一个高尔夫球场,现在你试图预测某一天有多少球员会到场。你已经收集了关于天气的数据:从一般的天气概况到温度和湿度的详细信息。你想利用这些天气条件来预测将会有多少球员到来。

列:‘天气概况(晴天、阴天、雨天)’,‘温度’(华氏度),‘湿度’(百分比),‘风力’(是/否)和‘球员人数’(目标特征)
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
# Data preparation
dataset_dict = {
'Outlook': ['sunny', 'sunny', 'overcast', 'rain', 'rain', 'overcast', 'sunny', 'overcast', 'rain', 'sunny', 'overcast', 'rain', 'sunny', 'rain',
'sunny', 'overcast', 'rain', 'sunny', 'rain', 'overcast', 'sunny', 'rain', 'overcast', 'sunny', 'overcast', 'rain', 'sunny', 'rain'],
'Temp.': [92.0, 78.0, 75.0, 70.0, 62.0, 68.0, 85.0, 73.0, 65.0, 88.0, 76.0, 63.0, 83.0, 66.0,
91.0, 77.0, 64.0, 79.0, 61.0, 72.0, 86.0, 67.0, 74.0, 89.0, 75.0, 65.0, 82.0, 63.0],
'Humid.': [95.0, 65.0, 82.0, 90.0, 75.0, 70.0, 88.0, 78.0, 95.0, 72.0, 80.0, 85.0, 68.0, 92.0,
93.0, 80.0, 88.0, 70.0, 78.0, 75.0, 85.0, 92.0, 77.0, 68.0, 83.0, 90.0, 65.0, 87.0],
'Wind': [False, False, False, True, False, False, False, True, False, False, True, True, False, True,
True, True, False, False, True, False, True, True, False, False, True, False, False, True],
'Num_Players': [25, 85, 80, 30, 17, 82, 45, 78, 32, 65, 70, 20, 87, 24,
28, 68, 35, 75, 25, 72, 55, 32, 70, 80, 65, 24, 85, 25]
}
# Data preprocessing
df = pd.DataFrame(dataset_dict)
df = pd.get_dummies(df, columns=['Outlook'], prefix='', prefix_sep='', dtype=int)
df['Wind'] = df['Wind'].astype(int)
这听起来很简单,但有一个问题。我们只有 28 天的数据——这不多!为了让事情变得更复杂,我们需要将这些数据分为两部分:14 天的数据用来帮助我们的模型学习(我们称之为训练数据),而剩下的 14 天用来测试我们的模型是否有效(测试数据)。

前 14 个数据集将用于训练模型,而最后 14 个将用于测试模型。
# Split features and target
X, y = df.drop('Num_Players', axis=1), df['Num_Players']
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.5, shuffle=False)
想一想这有多难。天气条件的组合有很多种。它可以是阳光明媚且潮湿、阳光明媚且凉爽、下雨且有风、阴天且凉爽,或其他组合。只有 14 天的训练数据,我们肯定无法看到每一种可能的天气组合。但我们的模型仍然需要对它可能遇到的任何天气条件做出准确预测。
这就是我们挑战的开始。如果我们让模型过于简单——比如只关注温度——它将忽略像风和雨这样的重要细节。这样是不够的。但如果我们让模型过于复杂——试图考虑每一个微小的天气变化——它可能会认为在一个多雨的周里,某个随机的安静日子意味着雨水实际上带来了更多的玩家。只有 14 个训练样本时,模型很容易变得混淆。
这里有个问题:与许多你在网上看到的例子不同,我们的数据并不完美。有些日子天气相似,但玩家数量不同。也许那天有本地活动,或者那天是节假日——但我们的天气数据无法告诉我们这些。这正是现实世界预测问题的复杂性所在。
所以,在我们开始构建模型之前,先花点时间了解我们正在尝试做的事情:
使用仅有的 14 个例子来创建一个可以预测任何天气条件下玩家数量的模型,即使是它之前没见过的天气条件。
这就是使得偏差-方差权衡如此重要的问题。
模型复杂度
对于我们的预测,我们将使用深度不同的决策树回归器(如果你想了解这个是如何工作的,可以查看我关于决策树基础的文章)。对我们讨论来说,重要的是我们让这个模型变得多复杂。

我们将使用整个训练数据集来训练决策树。树的深度首先设置,以防止树生长到某一深度。
from sklearn.tree import DecisionTreeRegressor
# Define constants
RANDOM_STATE = 3 # As regression tree can be sensitive, setting this parameter assures that we always get the same tree
MAX_DEPTH = 5
# Initialize models
trees = {depth: DecisionTreeRegressor(max_depth=depth, random_state=RANDOM_STATE).fit(X_train, y_train)
for depth in range(1, MAX_DEPTH + 1)}
我们将通过控制模型的深度来控制模型的复杂度——从深度 1(最简单)到深度 5(最复杂)。

import matplotlib.pyplot as plt
from sklearn.tree import plot_tree
# Plot trees
for depth in range(1, MAX_DEPTH + 1):
plt.figure(figsize=(12, 0.5*depth+1.5), dpi=300)
plot_tree(trees[depth], feature_names=X_train.columns.tolist(),
filled=True, rounded=True, impurity=False, precision=1, fontsize=8)
plt.title(f'Depth {depth}')
plt.show()





为什么这些复杂度级别很重要:
-
深度 1:极其简单 — 只创建几种不同的预测
-
深度 2:稍微灵活一些 — 可以创建更多样化的预测
-
深度 3:适度复杂度 — 接近过多规则
-
深度 4–5:最高复杂度 — 每个训练样本几乎有一个规则
发现了什么有趣的事情吗?我们最复杂的模型(深度 5)几乎为每一个训练示例创建了不同的预测规则。当一个模型开始为几乎每个训练示例都生成独特的规则时,这是一个明确的信号,说明我们已经将模型做得对我们的小数据集来说过于复杂。
在接下来的章节中,我们将看到这些不同复杂度的模型在我们高尔夫球场数据上的表现,并探讨为什么找到合适的复杂度对做出可靠预测至关重要。
什么样的模型“好”?
预测误差
预测的主要目标是使猜测尽可能接近真实值。我们需要一种衡量误差的方式,这种方式把高估或低估视为同样的不正确。预测值比真实答案高 10 个单位和低 10 个单位是一样错误的。
这就是我们使用均方根误差(RMSE)作为衡量标准的原因。RMSE 告诉我们预测误差的典型大小。如果 RMSE 是 7,意味着我们的预测通常会偏离真实值约 7 个单位。如果是 3,通常偏离约 3 个单位。较低的 RMSE 意味着更好的预测效果。

在上面的简单 5 点数据集中,我们可以说我们的预测大约偏离 3 个人。
在衡量模型性能时,我们通常会计算两种不同的误差。首先是训练误差——即模型在它学习过的数据上的表现。其次是测试误差——即模型在它从未见过的新数据上的表现。这个测试误差非常关键,因为它告诉我们模型在面对现实世界中新的数据时的表现如何。
⛳️ 查看我们的高尔夫球场预测
在我们的高尔夫球场案例中,我们试图根据天气状况预测每日的玩家数量。我们拥有来自 28 个不同日期的数据,并将其分成两部分:
-
训练数据:模型用来学习模式的 14 天记录
-
测试数据:我们隐藏在模型之外的 14 天不同记录
使用我们构建的模型,让我们测试训练数据和测试数据,并计算它们的 RMSE。


# Create training predictions DataFrame
train_predictions = pd.DataFrame({
f'Depth_{i}': trees[i].predict(X_train) for i in range(1, MAX_DEPTH + 1)
})
#train_predictions['Actual'] = y_train.values
train_predictions.index = X_train.index
# Create test predictions DataFrame
test_predictions = pd.DataFrame({
f'Depth_{i}': trees[i].predict(X_test) for i in range(1, MAX_DEPTH + 1)
})
#test_predictions['Actual'] = y_test.values
test_predictions.index = X_test.index
print("\nTraining Predictions:")
print(train_predictions.round(1))
print("\nTest Predictions:")
print(test_predictions.round(1))

from sklearn.metrics import root_mean_squared_error
# Calculate RMSE values
train_rmse = {depth: root_mean_squared_error(y_train, tree.predict(X_train))
for depth, tree in trees.items()}
test_rmse = {depth: root_mean_squared_error(y_test, tree.predict(X_test))
for depth, tree in trees.items()}
# Print RMSE summary as DataFrame
summary_df = pd.DataFrame({
'Train RMSE': train_rmse.values(),
'Test RMSE': test_rmse.values()
}, index=range(1, MAX_DEPTH + 1))
summary_df.index.name = 'max_depth'
print("\nSummary of RMSE values:")
print(summary_df.round(2))

从这些数字来看,我们已经能够看到一些有趣的模式:随着我们让模型变得越来越复杂,它们在预测之前已经见过的日期的玩家数量时变得越来越准确——直到我们最复杂的模型在训练数据上做出了完美的预测。

但真正的考验是它们在预测新日期的玩家数量时的表现。在这里,我们看到了一些不同的情况。尽管增加一些复杂性有助于提高表现(从深度 1 到深度 3,测试误差不断改善),但是将模型做得过于复杂(深度 4-5)实际上会开始导致效果变差。

训练和测试表现之间的差异(从误差为 3-4 名玩家到误差为 9 名玩家)揭示了预测中的一个根本挑战:在新的、未见过的情况中表现良好,比在熟悉的情况下表现好要难得多。即使是我们表现最好的模型,也能看到训练和测试表现之间的差距。

# Create figure
plt.figure(figsize=(4, 3), dpi=300)
ax = plt.gca()
# Plot main lines
plt.plot(summary_df.index, summary_df['Train RMSE'], marker='o', label='Train RMSE',
linestyle='-', color='crimson', alpha=0.1)
plt.plot(summary_df.index, summary_df['Test RMSE'], marker='o', label='Test RMSE',
linestyle='-', color='crimson', alpha=0.6)
# Add vertical lines and difference labels
for depth in summary_df.index:
train_val = summary_df.loc[depth, 'Train RMSE']
test_val = summary_df.loc[depth, 'Test RMSE']
diff = abs(test_val - train_val)
# Draw vertical line
plt.vlines(x=depth, ymin=min(train_val, test_val), ymax=max(train_val, test_val),
colors='black', linestyles='-', lw=0.5)
# Add white box behind text
bbox_props = dict(boxstyle="round,pad=0.1", fc="white", ec="white")
plt.text(depth - 0.15, (train_val + test_val) / 2, f'{diff:.1f}',
verticalalignment='center', fontsize=9, fontweight='bold',
bbox=bbox_props)
# Customize plot
plt.xlabel('Max Depth')
plt.ylabel('RMSE')
plt.title('Train vs Test RMSE by Tree Depth')
plt.grid(True, linestyle='--', alpha=0.2)
plt.legend()
# Remove spines
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
# Set limits
plt.xlim(0.8, 5.2)
plt.ylim(0, summary_df['Train RMSE'].max() * 1.1)
plt.tight_layout()
plt.show()
接下来,我们将探讨模型失败的两种主要方式:通过持续不准确的预测(偏差)或通过极其不一致的预测(方差)。
理解偏差(当模型欠拟合时)
什么是偏差?
偏差发生在模型通过过于简单无法捕捉到重要模式时。具有高偏差的模型会持续犯大错,因为它错过了关键的关系。可以把它理解为以可预测的方式始终错误。
当模型欠拟合时,会表现出以下特定的行为:
-
不同预测中的误差相似
-
训练误差很高
-
测试误差也很高
-
训练和测试误差接近
高偏差和欠拟合表明我们的模型需要更加复杂——它需要关注数据中的更多模式。但我们如何发现这个问题呢?我们查看训练和测试的误差。如果两个误差都很高,并且彼此相似,我们很可能遇到偏差问题。
⛳️ 查看我们的简单高尔夫球场模型
让我们检查一下我们最简单模型(深度为 1)的表现:

-
训练 RMSE:16.13
平均而言,即使是它训练过的日期,误差也大约为 16 名玩家。
-
测试 RMSE:13.26
对于新的日期,模型预测误差大约为 13 名玩家。
这些数字讲述了一个重要的故事。首先,注意到这两个误差都很高。误差为 13-16 名玩家,在很多日期的玩家数量在 20 到 80 之间时,这个误差很大。其次,尽管测试误差更高(如我们所料),但两者的误差都明显较大。
深入分析正在发生的情况:

-
在深度为 1 时,我们的模型只能做出一个分裂决策。它可能只是根据是否下雨来分裂日期,从而只创建两种可能的玩家数量预测。这意味着许多不同的天气条件被归类到相同的预测中。
-
这些误差遵循明显的模式:
-
在炎热、潮湿的日子里:模型预测的玩家数过多,因为它只看是否下雨
-
在凉爽、完美的日子里:模型预测的玩家数过少,因为它忽略了良好的比赛条件。
-
-
最具说明性的是训练误差和测试误差的相似性。两者都很高,这意味着即使是在模型训练过的日期上进行预测,模型也表现得很差。这是高偏差的最明显迹象——模型过于简单,甚至无法捕捉到其训练数据中的模式。
这是欠拟合的关键问题:模型缺乏捕捉影响玩家人数的重要天气条件组合的复杂性。每个预测都是以可预测的方式错误,因为模型根本无法同时考虑多个天气因素。
这个解决方案看起来很明显:使模型更加复杂,以便它能够同时考虑多种天气条件。但正如我们在下一节将看到的,这会带来一些新的问题。
理解方差(当模型过拟合时)
什么是方差?
方差发生在模型过拟合时,它变得过于复杂,并对数据中的小变化过于敏感。虽然欠拟合模型忽视了重要的模式,但过拟合模型则相反——它把每一个微小的细节都当作重要的模式来处理。
一个过拟合的模型表现出以下特点:
-
训练数据中的非常小的错误
-
测试数据中的更大误差
-
训练和测试误差之间的巨大差距
-
随着数据变化剧烈变化的预测
当数据集很小的时候,这个问题尤其危险。当我们只有几个示例供学习时,过拟合模型可能会完美地记住所有这些示例,而没有学习到真正重要的模式。
⛳️ 看看我们复杂的高尔夫球场模型
让我们来检视一下我们最复杂模型(深度为 5)的表现:

-
训练 RMSE:0.00
完美的预测!训练数据中没有一个错误
-
测试 RMSE:9.14
但是在新的一天,它的预测偏差大约是 9 到 10 个玩家
这些数字揭示了一个典型的过拟合案例。零的训练误差意味着我们的模型学会了预测每个它训练过的特定日期的玩家人数。听起来不错,对吧?但是看看测试误差——它高得多。训练和测试表现之间的巨大差距(从 0 到 9-10 个玩家)是一个红旗。
更深入地看一下发生了什么:

-
在深度为 5 时,我们的模型会创建非常具体的规则。例如:
- 如果没有下雨并且温度是 76°F 并且湿度是 80%并且有风 → 预测恰好 70 个玩家
每个规则仅基于我们训练数据中的一两天。
-
当模型在测试数据中看到稍微不同的条件时,它会感到困惑。
这与我们上面的第一个规则非常相似,但模型可能会预测一个完全不同的数字
-
在只有 14 个训练示例的情况下,每个训练日都会有自己非常具体的一组规则。模型并没有学习天气如何影响玩家人数的普遍模式——它只是记住了每个特定日子发生了什么。
特别有趣的是,尽管这个过拟合模型比我们的欠拟合模型表现得更好(测试误差 9.15),但它实际上比我们适度复杂的模型更差。这表明,增加过多的复杂性可能会开始损害我们的预测,尽管训练表现看起来是完美的。
这就是过拟合的根本挑战:模型变得过于专注于对训练数据做出完美预测,以至于无法学习能够帮助其预测新情况的通用模式。当处理像我们这样的小数据集时,尤其是有问题的,因为为每个训练样本创建一个独特的规则会使我们无法可靠地处理新情况。
寻找平衡
核心问题
现在我们已经看到了两个问题——欠拟合和过拟合——让我们看看当我们尝试解决它们时会发生什么。这就是偏差-方差权衡的真正挑战所在。
看看我们在使模型变得更复杂时,它们的表现:

这些数字讲述了一个重要的故事。随着我们使模型变得更复杂:
-
训练误差持续改善(16.3 → 6.7 → 3.6 → 1.1 → 0.0)
-
测试误差最初显著改善(13.3 → 10.1 → 7.3)
-
但随后测试误差略微变差(7.3 → 8.8 → 9.1)
为什么会发生这种情况
这个模式不是巧合——它是偏差-方差权衡的基本特性。
当我们使模型变得更复杂时:
-
它不太可能欠拟合训练数据(偏差减少)
-
但它变得更容易对小的变化发生过拟合(方差增加)
我们的高尔夫球场数据清晰地展示了这一点:
-
深度 1 模型欠拟合严重——它只能将天数分为两组,导致到处都有大的误差
-
增加复杂性有所帮助——深度 2 可以考虑更多的天气组合,而深度 3 发现了更好的模式
-
但是深度 4 开始出现过拟合——为几乎每个训练天创建独特的规则
最佳点出现在我们的深度 3 模型中:

这个模型足够复杂,避免了欠拟合,同时又足够简单,避免了过拟合。它在所有模型中具有最佳的测试表现(RMSE 7.13)。
现实世界的影响
在我们的高尔夫球场预测中,这种权衡有着真实的后果:
-
深度 1:通过只看温度来欠拟合,错过了关于雨量或风速的关键信息
-
深度 2:可以结合两个因素,如温度和雨量
-
深度 3:能够发现类似“温暖、低湿度、没有雨意味着高出勤率”的模式
-
深度 4-5:通过不可靠的规则过拟合,如“在风大的日子里,温度恰好是 76°F,湿度是 80%,意味着恰好 70 名球员”
这就是为什么找到正确的平衡很重要。只有 14 个训练样本,每个关于模型复杂度的决策都会产生很大影响。我们的深度 3 模型并不完美——平均偏差 7 名球员并不理想。但它比深度 1 的欠拟合(偏差 13 名球员)或深度 4 的过拟合(对非常相似的天气条件给出截然不同的预测)要好得多。
如何选择正确的平衡
基本方法
在选择最佳模型时,仅查看训练误差和测试误差是不够的。为什么?因为我们的测试数据有限——只有 14 个测试样本,我们可能会因为特定的几天模型表现好或者不好而感到幸运或不幸运。
测试模型的更好方法叫做交叉验证。与其仅使用一次训练和测试数据拆分,我们尝试不同的拆分。每次我们:
-
选择不同的样本作为训练数据
-
训练我们的模型
-
在未用于训练的样本上测试
-
记录误差
通过多次这样做,我们可以更好地理解我们的模型真正的表现如何。
⛳️ 我们从高尔夫球场数据中发现的结果
让我们来看一下不同模型在多个训练拆分中的交叉验证表现。鉴于我们仅有 14 个训练样本,我们使用了 K 折交叉验证,k=7,这意味着每个验证折叠有 2 个样本。

尽管这是一个较小的验证集,但它让我们能够最大化我们的训练数据,同时仍然获得有意义的交叉验证估计:
from sklearn.model_selection import KFold
def evaluate_model(X_train, y_train, X_test, y_test, n_splits=7, random_state=42):
kf = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)
depths = range(1, 6)
results = []
for depth in depths:
# Cross-validation scores
cv_scores = []
for train_idx, val_idx in kf.split(X_train):
# Split data
X_tr, X_val = X_train.iloc[train_idx], X_train.iloc[val_idx]
y_tr, y_val = y_train.iloc[train_idx], y_train.iloc[val_idx]
# Train and evaluate
model = DecisionTreeRegressor(max_depth=depth, random_state=RANDOM_STATE)
model.fit(X_tr, y_tr)
val_pred = model.predict(X_val)
cv_scores.append(np.sqrt(mean_squared_error(y_val, val_pred)))
# Test set performance
model = DecisionTreeRegressor(max_depth=depth, random_state=RANDOM_STATE)
model.fit(X_train, y_train)
test_pred = model.predict(X_test)
test_rmse = np.sqrt(mean_squared_error(y_test, test_pred))
# Store results
results.append({
'CV Mean RMSE': np.mean(cv_scores),
'CV Std': np.std(cv_scores),
'Test RMSE': test_rmse
})
return pd.DataFrame(results, index=pd.Index(depths, name='Depth')).round(2)
# Usage:
cv_df = evaluate_model(X_train, y_train, X_test, y_test)
print(cv_df)

简单模型(深度 1):
-
CV 均值 RMSE:20.28(±12.90)
-
在交叉验证中显示出较大的波动(±12.90)
-
不同数据拆分下表现 consistently 较差
稍微灵活的模型(深度 2):
-
CV 均值 RMSE:17.35(±11.00)
-
比深度 1 的平均误差低
-
交叉验证中仍然显示出相当大的波动
-
预测能力有所提升
中等复杂度模型(深度 3):
-
CV 均值 RMSE:16.16(±9.26)
-
更稳定的交叉验证表现
-
比更简单的模型有明显改善
-
稳定性和准确性的最佳平衡
复杂模型(深度 4):
-
CV 均值 RMSE:16.10(±12.33)
-
与深度 3 的均值非常相似
-
交叉验证中的较大波动表明预测不够稳定
-
开始显示出过拟合的迹象
非常复杂的模型(深度 5):
-
CV 均值 RMSE:16.59(±11.73)
-
CV 性能开始恶化
-
高波动继续
-
明显的过拟合迹象开始出现

这个交叉验证展示了一个重要的点:尽管我们之前的分析中深度 3 模型在测试性能上表现最好,但交叉验证结果揭示了模型性能的波动。所有模型中较高的标准差(从±9.26 到±12.90 不等)表明,在这样一个小的数据集上,任何一个数据拆分可能都会给我们带来误导性的结果。这也是交叉验证如此重要的原因——它帮助我们看到模型的真实表现,而不仅仅是一个幸运或不幸运的拆分结果。
如何在实践中做出这一决定
根据我们的结果,以下是我们如何找到合适模型平衡的方法:
-
从简单开始
从你能构建的最基础模型开始。检查它在训练数据和测试数据上的表现。如果在这两者上都表现不佳,那也没关系!这只是说明你的模型需要稍微复杂一些,以便捕捉到重要的模式。
-
逐渐增加复杂性
现在,逐步使你的模型变得更加复杂,注意每次调整后的性能变化。当你发现模型在新数据上的表现开始变差时,那就是信号,告诉你该停止了——你已经找到了合适的复杂性平衡。
-
注意警告信号
留意潜在问题:如果你的模型在训练数据上表现极好,但在新数据上表现很差,那就说明模型太复杂。如果在所有数据上表现都很差,那说明模型太简单。如果模型在不同数据分割间的表现差异很大,那你可能做得太复杂了。
-
考虑数据规模
当你没有太多数据(比如我们的 14 个样本)时,保持模型简单。你不能指望在非常少的样本上训练出一个能够做出完美预测的模型。对于小数据集,拥有一个稳定的简单模型比一个不可靠的复杂模型要好。
每当我们构建预测模型时,我们的目标不是获得完美的预测——而是获得可靠、有效的预测,这些预测将在新数据上表现良好。对于我们的高尔夫球场数据集,平均预测偏差为 6-7 名球员虽然不是完美的,但远比偏差为 11-12 名球员(过于简单)或预测极度不可靠(过于复杂)要好得多。
关键要点
快速识别问题的方法
让我们总结一下我们关于构建实际有效的预测模型所学到的知识。以下是一些关键迹象,可以告诉你模型是过拟合还是欠拟合:

欠拟合的迹象(过于简单):
当模型欠拟合时,训练误差将很高(就像我们深度为 1 的模型,其 RMSE 为 16.13)。同样,测试误差也会很高(13.26 RMSE)。这两个误差之间的差距很小(16.13 与 13.26),这告诉我们模型一直表现不佳。这样的模型过于简单,无法捕捉到真实的关系。
过拟合的迹象(过于复杂):
过拟合模型显示出完全不同的模式。你会看到非常低的训练误差(就像我们深度为 5 的模型,其 RMSE 为 0.00),但测试误差却高得多(9.15 RMSE)。训练和测试表现之间的巨大差距(0.00 与 9.15)是一个信号,表明模型很容易被训练数据中的噪声干扰,它仅仅是在记忆它所训练过的特定示例。
良好平衡的迹象(如我们的深度 3 模型):
一个平衡良好的模型表现出更有前景的特征。训练误差相对较低(3.16 RMSE),尽管测试误差较高(7.33 RMSE),但这是我们最好的整体表现。训练误差和测试误差之间的差距存在,但并不极端(3.16 vs 7.33)。这告诉我们,模型找到了甜蜜点:它足够复杂,能够捕捉到数据中的真实模式,同时又足够简单,避免了被噪声干扰。欠拟合和过拟合之间的这种平衡正是我们在可靠模型中所追求的。
最后备注
偏差-方差权衡不仅仅是理论。它对实际预测有着真实的影响,包括我们之前的高尔夫球场示例。这里的目标不是完全消除欠拟合或过拟合,因为那是不可能的。我们想要的是找到那个甜蜜点,在这个点上,模型足够复杂以避免欠拟合并捕捉到真实的模式,同时又足够简单,以避免对随机噪声过拟合。
最终,一个始终略有误差的模型通常比一个过拟合的模型更有用——后者偶尔完美,但通常误差较大。
在现实世界中,可靠性比完美更为重要。
技术环境
本文使用的是 Python 3.7 和 scikit-learn 1.6。虽然所讨论的概念通常适用,但不同版本之间的具体代码实现可能会略有不同。
关于插图
除非另有说明,否则所有图片均由作者创作,结合了 Canva Pro 的授权设计元素。
𝙎𝙚𝙚 𝙢𝙤𝙧𝙚 𝙈𝙤𝙙𝙚𝙡 𝙀𝙫𝙖𝙡𝙪𝙖𝙩𝙞𝙤𝙣 & 𝙊𝙥𝙩𝙞𝙢𝙞𝙯𝙖𝙩𝙞𝙤𝙣 𝙝𝙚𝙧𝙚:

模型评估与优化
查看列表3 个故事


𝙔𝙤𝙪 𝙢𝙞𝙜𝙝𝙩 𝙖𝙡𝙨𝙤 𝙡𝙞𝙠𝙚:

分类算法
查看列表8 个故事


大 O — 一种实践方法
使用代码示例实施大 O 最佳实践。思维方式的转变如何提升代码在运行时的性能
·发布于 Towards Data Science ·13 分钟阅读·2024 年 1 月 19 日
--

来源:www.pexels.com/photo/wooden-picnic-bench-beside-the-lake-9367323/
本文旨在提高我们数据专业人士的大 O 识字率。我们常常在软件工程(SWE)和数据世界之间进行比较。虽然这两个领域都有一些最佳实践(如版本控制、错误处理和测试),但根据我个人的经验,数据角色相较于 SWE 角色在编写高效代码方面似乎落后。更具体来说,是在实现和检查高效代码的思维方式上。我个人过去也曾犯过这种错误——如果我脚本中的代码按预期运行并且能够优雅地处理错误,我就会认为它“完成”了。我相信,任何创建生产级代码的人,不论职位如何,都有责任确保其代码在运行时具有高效性能。
我意识到,许多专注于数据的岗位已经在实施大 O 最佳实践。通过与同行交流,这些人大多是与软件工程师紧密合作的人,因此能够“吸收”这些方法论。这与那些在独立数据团队中工作的人形成对比。尽管我只能推测其中的原因…
BigQuery 方法用于重新创建 Pandas 的顶级 EDA 函数
在本指南中,我们将探讨如何在 BigQuery 中重新创建用于 EDA 的关键 Pandas 函数,如 describe 和 corr。
·发布于 Towards Data Science ·阅读时间:21 分钟·2024 年 2 月 5 日
--

图像由 DALL-E 创建
从 BigQuery/SQL 迁移到 Python 可能会让人眼前一亮,尤其是在数据分析的背景下。我常常发现自己在 BigQuery SQL 中编写大量查询来操作和分析数据。它是一个强大的语言,但有时会变得很沉重。
现在,当我转到 Python 时,我对某些任务的简化程度感到惊讶。像 pandas 这样的 Python 库允许你进行数据操作和分析,而这些在 SQL 中会显得繁琐。
我发现一些 Pandas 函数如 DESCRIBE、CORR 和 ISNULL().SUM() 非常有用,我希望它们也能出现在 BigQuery 中。这让我开始探索 Pandas 中的其他有趣的 EDA 函数,并激发了我写这篇文章的灵感。在这里,我将分享我在 BigQuery 中找到的与 Pandas 最佳 EDA 函数相匹配的方法和代码。
让我们开始吧!
在本文中,我们将看看这 13 个函数:
-
头部 / 尾部
-
列
-
数据类型
-
Nunique
Bigram 词云动画展示你的数据故事
实践教程,讲解如何创建 Animated Word Cloud 来显示大 ram 频率的文本数据集并导出为 MP4 视频
·发表于 Towards Data Science ·阅读时长 5 分钟·2024 年 5 月 8 日
--

来源:AnimatedWordCloud 库。图片来源:作者。
动画词云展示了 n-gram 频率(文本语料库中的单词及其后续单词)随时间变化的图像序列。它赋予在源文本中出现频率较高的单词更高的重要性,同时调整数据集以适应不同的文本数据集。原始的可视化方法采用经典的 词云 直观逻辑,并向图形中添加了时间维度。该方法旨在探索在多个时期收集的文本数据集(即“时间序列文本数据”*)。
Michael Kane 开发了用于动画化单词频率的核心框架,并且 AnimatedWordCloud(AWC)库将这一可视化方法付诸实践。新版本带来了重要的更新:
-
数据缩放:现在能够更好地处理不同大小和单词频率的文本数据集
-
扩展 n_gram 参数(= 2)以生成 bigram 词云
-
效率提升(现在每个周期节省了 220 帧,改进了 Y 轴等)。
二分类
拆解传统指标的真正意义与局限性
·发表于 Towards Data Science ·9 分钟阅读·2024 年 1 月 12 日
--
介绍
分类工作的本质可以看作是将结构的复杂性总结为有限的类别,这种方法常常对简化生活非常有用,使我们能够将复杂的结构简化为两种单一类型。这些标签可能有明显的解释,比如我们通过收入这一独特且合理的特征来区分统计学家和数据科学家;也可能是压抑的尝试,将实验证据简化为一句话,用来接受或拒绝零假设。拥抱元语言,我们将分类标签归属于将信息总结成两种不同类型的工作:二分类。本工作旨在对这一特定标签进行更深的描述,为决策过程带来概率解释,并分析我们用来评估结果的指标。
将种群建模为分布
当我们试图描述并区分一个对象时,我们的目标是找到突出其独特性的特征。预计对于许多属性,这种差异在整个种群中并不总是准确一致的。
对于一个常见的具有不同 n 个特征(V1, … Vn)的二分类问题,我们可以查看这些特征值的分布,并试图从中得出结论:

图 1:两类(负类:红色;正类:蓝色)不同特征 Vn 的分布。此数据来自于 Kaggle 挑战,文末有相关引用。
如果任务是使用这些特征中的一个来指导我们的决策过程,给定个体的 v_n 值来决定,
预测类别时,直觉上可以选择频率最高的类别(或者如果我们的直方图是良好的分布估计器,则选择概率最高的类别)。例如,如果某个个体的 v4 测量值大于 5,那么它很可能是正类。
然而,我们可以做得更多,利用不同的特征,将信息合成成一个单一的分布。这正是得分函数 S(x) 的任务。得分函数将进行回归,将特征分布压缩到一个唯一的分布 P(s, y),并且可以根据每个类别进行条件化,P(s|y=0) 表示负类 (y=0) 和 P(s|y=1) 表示正类 (y=1)。
从这个单一的分布中,我们选择 一个决策阈值 t,该阈值决定我们对给定点的估计——用 ŷ 表示——是正类还是负类。如果 s 大于 t,我们就赋予正类标签;否则,赋予负类标签。

图 2:二分类过程的图示。得分函数将 n 维特征空间压缩为分布 P(s, y)。
给定分布P(s, y, t) 其中 s、y 和 t 分别代表分数、类别和阈值的值, 我们就得到了分类器的完整描述。
我们分类器的度量标准(边际和条件分布)
为我们的分类器开发度量标准可以看作是量化 p(s|P) 和 p(s|N) 区分能力的追求。
通常情况下,两个分布 p(s|P) 和 p(s|N) 会有重叠,导致无法完美分类。因此,给定一个阈值,我们可以问一下 p(s > t|N) — 假阳性率(FPR) — 也就是我们将负类个体错误分类为正类的概率,例如。
当然,我们可以堆砌大量的度量标准,甚至给它们起名——尽管这些命名可能不一致——但为了所有的目的,我们只需定义四个概率及其相关的比率来表示分类器的性能:
-
真正阳性率 (tpr): p(s > t|P) = TP/(TP+FN);
-
假阳性率 (fpr): p(s > t|N) = FP/(FP+TN);
-
真正阴性率 (tnr): p(s ≤ t|N) = TN/(FP+TN);
-
假阴性率 (fnr): p(s ≤ t|P) = FN/(TP+FN).
如果你已经熟悉这个主题,你可能会注意到,这些是我们从分类器的混淆矩阵中定义的度量标准。因为该矩阵是为每个选择的决策阈值定义的,我们可以将其视为条件分布 P(ŷ, y|t) 的一种表示,其中这些对象是完全描述我们分类器性能的混淆矩阵类别的一部分。

图 3:每个混淆矩阵可以视为一个条件 p(ŷ, y|t),并且 所有可能的混淆矩阵组成一个完整的分类器性能描述。
因此,错误比率 fpr 和 fnr 是量化两个条件得分分布交集方式和交集程度的指标:

图 4:分类器的性能估计量是不同的度量,用来描述两个分布 p(s|P) 和 p(s|N) 之间的重叠或分离。
性能总结:ROC 曲线
由于比率受到 tpr + fnr = 1 和 fpr + tnr = 1 的约束,这意味着我们只有 2 个自由度来描述我们的性能*。
ROC 曲线是由 t 参数化的曲线,用 (x(t), y(t)) = (fpr(t), tpr(t)) 表示为与正交轴相对的点。这将为我们提供一个简明的总结,来可视化分类器在所有不同阈值下的性能,而不仅仅是一个单一的阈值。
最佳的 t 值通常是未知的,必须作为分类器构建的一部分来确定。——《连续数据的 ROC 曲线》(2009, Chapman and Hall)。
我们的目标是探索将概率分布作为处理对象的概念,因此让我们想象一下对于一个完全无效的分类器,基本情形会是什么样的。由于我们预测的有效性依赖于判别性质,p(s|P) 和 p(s|N),当 p(s|P) = p(s|N) = p(s) 时,我们就遇到了这种无效性的典型例子。
如果我们进行将每个条件视为均值分离不同值的高斯模型化的练习,我们可以清晰地看到性能是如何变化的:

图 5:当得分分布变得更加具备判别力时,分类器的性能提高。
这种可视化将作为一个宝贵的辅助工具,帮助我们理解一个关键分类器度量的概率解释——称为曲线下面积(AUC),我们稍后将深入探讨这一点。
ROC 曲线的一些特性
ROC 曲线可以描述为函数 y = h(x),其中 x 和 y 分别是真阳性率和假阳性率,而它们又由 t 参数化,形式为 x(t) = p(s > t|N) 和 y(t) = p(s > t|P)。
我们可以利用这一点推导出以下特性:
-
y = h(x) 是一个单调递增的函数,位于由 (0, 0) 和 (1, 1) 定义的直线之上;
-
如果分类得分经过严格递增的变换,ROC 曲线不会改变;

图 6:对得分分布施加单调递增的变换不会改变 ROC 曲线,因为它保持了回归的顺序。
这个特性使得分类器的校准过程成为可能。
3. 对于阈值 t 点的 ROC 曲线的明确斜率:

其中 p(t|P)表示累积分布 p(s ≤ t | P)的密度分布(对于 p(t|N)也是如此)。
二元分类作为假设检验:为我们的方法提供正当性
当从贝叶斯推理的角度来看分类过程时,我们的目标是推导后验概率p(P|t),它表示一个具有阈值t的点属于正类的概率。因此,定义在属性 3 上的斜率导数可以看作是似然比L(t)。
这个比率[L(t)]告诉我们,分类器的 t 值在总体 P 中比在总体 N 中发生的可能性有多大,这反过来可以解释为分配到总体 P 的置信度度量。——《连续数据的 ROC 曲线》(2009 年,Chapman 和 Hall)。
这是一个重要的事实,因为通过建立二元分类过程和假设检验之间的等价关系,我们为为什么基于阈值进行分类提供了正当性。
如果我们用零假设 H0(个体属于总体 N)对立于备择假设 H1(个体属于总体 P)来表述分类问题,我们可以得出以下关系:

图 7:二元分类过程与假设检验之间的联系通过混淆矩阵可视化。边际比率度量与假设检验中的标准α和β值相关。
Neyman-Pearson 引理确立了最强检验——即具有最高1-β值的检验——在显著性水平α下,拥有一个包含所有* s 值的区域R*,该区域满足:

其中α足以通过条件p(s ∈ R|N) = α来确定k。
这意味着,当我们以L(s)单调递增的方式对总体进行评分时,s和k之间的一一对应关系确保选择一个超过特定阈值的规则是最佳决策。
对于我们虚构的案例,其中分类器为每个类别分配正态分布,似然直接满足这一条件:

图 8:对于分数的双正态分布情况,似然比是严格递增的(更精确地说,是指数级递增)。
在现实问题中,情况并非总是如此,因为评分分布不一定在此意义上表现良好。我们可以使用Kaggle数据集来理解这一点,通过核密度估计(KDE)来估算条件概率的密度。

图 9:这是在平衡数据情况下,逻辑回归模型给出的评分分配。该分类器在没有参数调整的情况下进行了训练。

图 10:对于一个实际案例(Kaggle 数据集),我们可以看到,似然比并不一定是单调递增的。
这意味着较高的分数不一定与个体属于正类的概率增加相关。
ROC-AUC 的概率解释,以及为什么我们需要谨慎对待它
曲线下面积 (AUC) 可能是最广泛使用的值,用于总结 ROC 曲线表达的结果。它被定义为从 0 到 1 的 y(x)积分,正如其名称所示。值得注意的是,完美分类器的表现由正交轴上的点(0, 1)体现,表示零概率的负类误分类,并明确保证正确分类正类。
图 5 中的处理方法给我们提供了一个提示,即良好拟合的概率解释必须与为正类个体分配较高的分数和值为负类个体分配较低的分数的连贯性相关。这正是情况所在,因为有证明——如1中提到——AUC 等同于正类个体的得分(Sp)高于负类个体得分(Sn)的概率:

关键要考虑的一点是:AUC 旨在提供一个单一的数字来估算分类器的表现。然而,在实际决策时,必须选择一个适合你问题特定需求的阈值 t。挑战在于,如前所述,基于阈值的最佳决策发生在似然比单调递增时,而在实践中这并不总是如此。
因此,即使你拥有一个接近 1 的高 AUC 值,也不足以确定你的分类器是否能够基于决策边界进行最佳分类。在这种情况下, 单单达到一个高 AUC 值并不能保证分类器在实际决策情境中的有效性。
结论
二分类的这种概率解释可能提供了对该过程复杂性更深刻的理解。通过将人群建模为分布,我们可以基于个体属于某个特定类别的概率做出明智的决策。ROC 曲线作为一个有价值的工具,总结了阈值选择如何影响分类效率。此外,二分类与假设检验之间的联系强调了我们为什么通过阈值来进行分类。需要记住的是,尽管曲线下面积(AUC)是常用的性能评估指标,它并不总能保证最佳的实际决策,这突显了选择正确阈值的重要性。这种概率解释丰富了我们对二分类的理解,使其成为解决实际问题的强大框架。
致谢
特别感谢 Renato Vicente,他让我通过混淆矩阵空间来可视化分类器,并鼓励我撰写本文。
所有图片和图表均由作者提供。
另外,你可以在 Linkedin找到我。
参考文献
1 Krzanowski, Wojtek J., and David J. Hand. 连续数据的 ROC 曲线。Crc Press,2009 年
[2] Muschelli, John (2019–12–23). “ROC 和 AUC 与二元预测变量:一个可能具有误导性的指标”。分类学期刊。Springer Science and Business Media LLC. 37 (3): 696–708. doi:10.1007/s00357–019–09345–1。 ISSN 0176–4268。
数据集
R 中的二项逻辑回归
学习何时以及如何在 R 中使用(单变量和多变量)二项逻辑回归。还要学习如何解释、可视化和报告结果。
·发表于Towards Data Science ·46 分钟阅读·2024 年 1 月 30 日
--

图片来源:Annie Spratt
介绍
回归分析是统计学中常用的工具,用于检验和量化变量之间的关系。
最常见的两种回归是线性回归和逻辑回归。当因变量是定量时,使用线性回归;而当因变量是定性时,则使用逻辑回归。
线性回归和逻辑回归有不同的类型。在详细介绍之前,我们先回顾一下变量的类型。
定量变量衡量的是数量,它可以取的值是数字。它被分为:
-
离散型:它可以取的值是可计数的,且具有有限的可能性(这些值通常是整数,例如子女数量),以及
-
连续型:它可以取的值是不可计数的,并且有无限多的可能性(这些值通常带有小数,或者至少在技术上可能带小数,例如体重)。
逐步实现:使用有限自动机
使用 Python 对物体检测的现实 AI 系统进行有限状态机建模与仿真
·发布于 Towards Data Science ·阅读时间 17 分钟·2024 年 5 月 14 日
--

图片由作者提供
背景
“当生活给你鸡时,让 AI 处理这些鸡毛蒜皮的事。” — 无名工程师
为什么我们需要仿真?通过采样并获得平均值,我们到底能获得什么优势?但实际上,它从来不仅仅是这些。与我们在计算机科学课程中遇到的简单任务相比,现实生活通常要复杂得多。有时我们无法找到解析解,无法找到总体参数。有时我们必须建立一个模型,以反映系统动态的具体情况,我们必须运行仿真来研究潜在过程,从而更好地理解现实世界的情况。仿真建模为各行各业和应用中的系统设计与工程提供了无价的工具。它有助于分析系统性能,识别潜在的瓶颈和低效问题,从而允许进行迭代改进和优化。
说到我们这个非常特别的挑战,在这里,我们将创建一个 FSM 仿真,模拟一个 AI 辅助的草坪监控和清洁安全系统的行为。特别地,我们将处理模拟过程,通过物体检测和喷水子系统智能管理鸟类的进出。在上一篇文章中,你已经了解了有限状态机(FSM)的理论和设计原则,旨在解决臭名昭著的鸡与火鸡(CaT)问题,从而创建了一个在高层次抽象下描述复杂草坪场景的模型。通过本文,我们将进一步探讨基于 FSM 的仿真在利用实际系统操作方面的实际应用。此外,我们还将用 Python 实现 FSM 仿真,以便后续通过优化和 XAI 技术对其进行改进。在本教程结束时,你将拥有一个完全功能的 FSM 解决方案,并对通过仿真建模解决工程问题有更深入的理解。
免责声明: 本作品是“鸟瞰深度学习”系列的一部分,专注于使用有限自动机进行计算机视觉应用的实际系统建模与仿真。所有的参与者、状态、事件和输出仅为 FSM 设计过程中的教育性产物。与实际人物、鸟类或真实事件的任何相似之处纯属巧合。
用有限自动机进行系统建模
“当被问及没有抽象的系统设计时,只需描述现实场景中的‘如果-那么’循环,确保在处理多个条件时有所停顿。然后,优雅地退后,留下这些琐碎的事情。” — 未知工程师。
让理论走向实践
仿真,作为数学建模的一种特殊形式,涉及创建简化的现实世界系统的表示,以了解其在各种条件下的行为。其核心是通过方程捕捉真实系统的内在模式,而仿真则是通过运行程序来近似这些方程的算法过程。这个过程能够生成仿真结果,便于与理论假设进行比较,并推动实际系统的改进。仿真建模可以为系统行为提供洞察,并在进行真实实验过于昂贵和/或困难时预测结果。尤其当无法获得解析解时(例如仓库管理过程),仿真会特别有用。
在处理 CaT 问题时,目标很明确:我们希望保持草坪的整洁并节省资源。我们不依赖传统的实验方法,而是选择基于模拟的方式,寻找一种能够最小化水资源使用和水费的设置。为了实现这一目标,我们将开发一个基于有限状态机(FSM)的模型,反映系统中的关键过程,包括鸟类侵入、鸟类检测和喷水。整个模拟过程中,我们将评估系统性能,以指导进一步的优化工作,从而提高鸟类检测的效率。
为什么不用 if-else 语句
使用 if-else 条件分支进行系统建模是一种幼稚的解决方案,最终会导致复杂性和易出错性增加,从而使得进一步开发和维护变得更加困难。下面你将看到如何(不)描述一个简单的草坪上的鸡系统,考虑我们之前讨论过的有限状态机的示例(请参见图 1,该图展示了简化的 CaT 系统场景下的 FSM 状态转换图)。
# import functions with input events and actions
from events import (
simulate_chicken_intrusion,
initiate_shooing_chicken,
)
from actions import (
spoil_the_lawn,
start_lawn_cleaning,
one_more_juice
)
# define states
START = 0
CHICKEN_PRESENT = 1
NO_CHICKEN = 2
LAWN_SPOILING = 3
ENGINER_REST = 4
END = 5
# initialise simulation step and duration
sim_step = 0
max_sim_steps = 8
# initialise states
prev_state = None
current_state = START
# monitor for events
while current_state != END:
# update state transitions
if current_state == START:
current_state = NO_CHICKEN
prev_state = START
elif current_state == NO_CHICKEN:
if prev_state == CHICKEN_PRESENT:
start_lawn_cleaning()
if simulate_chicken_intrusion():
current_state = CHICKEN_PRESENT
else:
current_state = ENGINER_REST
prev_state = NO_CHICKEN
elif current_state == CHICKEN_PRESENT:
if initiate_shooing_chicken():
current_state = NO_CHICKEN
else:
current_state = LAWN_SPOILING
prev_state = CHICKEN_PRESENT
elif current_state == LAWN_SPOILING:
spoil_the_lawn()
current_state = CHICKEN_PRESENT
prev_state = LAWN_SPOILING
elif current_state == ENGINER_REST:
one_more_juice()
current_state = NO_CHICKEN
prev_state = ENGINER_REST
sim_step += 1
if sim_step >= max_sim_steps:
current_state = END
在这个代码片段中,我们定义常量来表示 FSM 的每个状态(例如,CHICKEN_PRESENT)。然后,我们将当前状态初始化为 START,并在一个 while 循环中持续监控事件,模拟简化系统的行为。根据当前状态和相关事件,我们使用 if-else 条件分支语句在状态之间切换,并调用相应的操作。状态转换可能会有副作用,例如启动鸡群污染草坪的过程并开始为工程师清理草坪。这里,输入事件和动作相关的功能表示可以自动化的过程,因此我们为了简化模拟,假设引入了相关的函数。请注意,虽然鸡可以几乎无限制地破坏草坪,但过量的果汁却存在过度水化的风险。请小心这一点,并且不要忘记为你的模拟设定时间限制。在我们的案例中,时间限制是一天的结束,由 max_sim_steps 变量定义。看起来很丑对吧?
这种方法应该是可行的,但想象一下,如果我们想扩展逻辑,需要不断更新 if-else 语句,重复相同的分支和状态切换。正如你可以想象的那样,随着状态和事件的增加,系统状态空间的规模迅速增长。与 if-else 分支不同,有限状态机(FSM)非常适合处理复杂任务,它允许将复杂系统分解为可管理的状态和转换,从而提高代码的模块化和可扩展性。在这里,我们即将开始使用有限自动机(FSM)实现系统行为,以减少 AI 系统操作的水资源使用,同时不影响鸟类检测的准确性。
用 Python 实现 FSM
“好了,孩子,我们现在要做一只鸡。” — 未知工程师。
从 FSM 到底层
在本节中,我们深入探讨了有限状态机实现的设计选择,阐明了简化模拟过程和最大化其在现实系统优化中效用的策略。为了构建模拟,我们首先需要基于对基础过程的假设创建一个系统模型。一种方法是从封装个体状态和过渡的功能开始。然后,我们可以将它们组合起来,通过复制真实系统行为来创建一系列事件。我们还希望跟踪每次模拟运行的输出统计数据,以提供其性能的概念。我们想要做的是观察系统如何随着时间的推移在条件变化下演变(例如,基于概率的鸟类孵化和破坏草坪的随机过程)。为此,让我们首先定义和安排我们稍后将要实现的构建模块。以下是计划:
-
定义类契约。
-
构建目标类的层次结构,描述个体目标。
-
实现状态之间的过渡逻辑。
-
实现单步模拟以及完整的运行。
-
跟踪模拟运行的输出统计数据。
让我们谈谈抽象
首先,我们需要为我们的模拟创建一个类层次结构,从表示状态的基类到更具领域特定的院子模拟子类。我们将使用@abc.abstractmethod和@property装饰器来分别标记抽象方法和属性。在 AbstractSimulation 类中,我们将定义step()和run()抽象方法,以确保子类实现它们。
class AbstractSimulation(abc.ABC):
@abc.abstractmethod
def step(self) -> Tuple[int, List['AbstractState']]:
pass
@abc.abstractmethod
def run(self) -> Iterator[Tuple[int, List['AbstractState']]]:
pass
对 AbstractState 也有类似的应用,它定义了一个抽象方法transit(),由子类实现:
class AbstractState(abc.ABC):
def __init__(self, state_machine: AbstractSimulation):
super().__init__()
self.state_machine = state_machine
def __eq__(self, other):
return self.__class__ is other.__class__
@abc.abstractmethod
def transit(self) -> 'AbstractState':
pass
对于我们的有限状态机(FSM),系统模拟的更具体方面将封装在继承自 AbstractSimulation 的 AbstractYardSimulation 类中。如其名称所示,AbstractYardSimulation 更精确地概述了模拟的领域,因此我们可以在 CaT 问题的背景下定义一些特定于院子模拟的额外方法和属性,包括simulate_intrusion()、simulate_detection()、simulate_sprinkling()、simulate_spoiling()。
我们还将创建一个中间抽象类 AbstractYardState,以确保类层次结构中的类型一致性:
class AbstractYardState(AbstractState, abc.ABC):
state_machine: AbstractYardSimulation
现在,让我们看一下反映名为 Target 及其后代的继承树。
鸡和火鸡的创建
Target 行为是我们模拟的基石,因为它影响着所有方面,助力构建有效的模型以及后续的优化。图 1 展示了我们将要实现的目标类的类图。

图 1. 目标类的类层次结构(图示由作者提供)
对于我们的系统,重要的是要注意目标出现的频率,它可能会对草坪造成一定的损害,而且它还具有健康属性。后者与目标的大小有关,大小可能不同,因此水枪可以瞄准较小或较大的目标(这反过来会影响水的消耗)。因此,大目标具有较多的生命值,小水流无法有效地处理它。
为了模拟不同频率的目标穿越草坪,我们还创建了相关的属性。以下是代码:
class AbstractTarget(int, abc.ABC):
@property
@abc.abstractmethod
def health(self) -> float:
pass
@property
@abc.abstractmethod
def damage(self) -> float:
pass
@property
@abc.abstractmethod
def frequency(self) -> float:
pass
请注意,在我们的实现中,我们希望目标对象是有效的整数,这将在模拟中用于建模随机性。
接下来,我们创建子类来实现不同类型的目标。以下是类 Chicken 的代码,我们在其中重写了从父类继承的抽象方法:
class Chicken(AbstractTarget):
@property
def health(self) -> float:
return 4
@property
def damage(self) -> float:
return 10
@property
def frequency(self) -> float:
return 9
我们对剩余的火鸡和空目标类执行类似的过程。对于火鸡,生命值和损害参数分别设置为 7 和 17(让我们看看如何用我们的 AI 辅助系统处理这些笨重的家伙)。空目标是一个特殊类型的目标,表示草坪上没有任何鸟类。虽然我们不能给它的生命值和损害属性赋予其他值,除了 0,但草坪上无鸟的无条件(即不是由工程师引起的)状态有非零的概率,这个概率反映在频率值为 9 的设定中。
从入侵到敌人被轻松发现
现在想象一只鸟在其自然栖息地中的样子。它可以表现出各种各样的敌对行为和展示。在面对挑战时,动物可能会根据具体情况采用一系列适应性策略,包括战斗、逃跑反应以及其他中间行为。接续上一篇关于 FSM 设计与建模的文章,你可能还记得我们已经描述了 CaT 系统的关键组成部分,我们将用它来进行实际实现(参见表 2,其中列出了描述触发状态变化的 FSM 输入事件)。
在 FSM 模拟领域,一只鸟可以被看作是一个独立的触发一系列事件的行为体:侵犯院子、破坏草坪等等。特别地,在乐观的场景下(鸟类检测和识别成功,防御行为):鸟类在可能被基于 CV 的鸟类检测器识别之前侵入院子,以便继续进行喷水模块,这些配置依赖于上游预测的入侵者类别。通过这种方式,鸟类可以成功地被赶走(击中)或未能被赶走(未击中)。在这个场景下(鸟类检测成功、类别预测、防御行为),最终,鸟类逃离了草坪。任务完成。哒哒!
你可能还记得,有限状态机(FSM)可以通过状态转移图来图形化表示,这一点我们在之前的教程中已经涉及过(参见表 3,其中展示了 FSM 状态转移表及下一阶段的转移逻辑)。考虑到这一点,接下来我们将创建 AbstractYardState 的子类,并重写transit()方法,根据当前状态和事件来指定状态之间的转移。
Start 是初始状态,状态机从该状态过渡到 Spawn。
class Start(AbstractYardState):
def transit(self) -> 'Spawn':
return Spawn(self.state_machine)
从 Spawn 状态,系统可以过渡到以下状态之一:Intrusion、Empty 或 End。
class Spawn(AbstractYardState):
def transit(self) -> Union['Intrusion', 'Empty', 'End']:
self.state_machine.stayed_steps += 1
self.state_machine.simulate_intrusion()
next_state: Union['Intrusion', 'Empty', 'End']
if self.state_machine.max_steps_reached:
next_state = End(self.state_machine)
elif self.state_machine.bird_present:
next_state = Intrusion(self.state_machine)
else:
next_state = Empty(self.state_machine)
return next_state
如果我们达到模拟时间步数的上限,状态机会过渡到 End 状态。如果有鸟侵入或者已经在草地上,状态机会切换到 Intrusion 状态,否则下一状态是 Empty。
Intrusion 和 Empty 状态都跟随一个检测尝试,因此它们共享转移逻辑。因此,我们可以通过创建一个父类——IntrusionStatus 来封装这一逻辑,从而减少代码重复,同时使得子类能够在类型层面区分模拟中的实际状态 Intrusion 和 Empty。
class IntrusionStatus(AbstractYardState):
intruder_class: Target
def transit(self) -> Union['Detected', 'NotDetected']:
self.state_machine.simulate_detection()
self.intruder_class = self.state_machine.intruder_class
next_state: Union['Detected', 'NotDetected']
if self.state_machine.predicted_bird:
next_state = Detected(self.state_machine)
else:
next_state = NotDetected(self.state_machine)
return next_state
我们对 Detected 和 NotDetected 类采取类似的方法,那个超类 DetectionStatus 负责目标预测。
class DetectionStatus(AbstractYardState):
detected_class: Target
def transit(self) -> 'DetectionStatus':
self.detected_class = self.state_machine.detected_class
return self
然而,与 Intrusion/Empty 组合状态不同,NotDetected 类引入了额外的转移逻辑,用以指引模拟流程,特别是关于草地污染/损坏的情况。
class Detected(DetectionStatus):
def transit(self) -> 'Sprinkling':
super().transit()
return Sprinkling(self.state_machine)
class NotDetected(DetectionStatus):
def transit(self) -> Union['Attacking', 'NotAttacked']:
super().transit()
next_state: Union['Attacking', 'NotAttacked']
if self.state_machine.bird_present:
next_state = Attacking(self.state_machine)
else:
next_state = NotAttacked(self.state_machine)
return next_state
Detected 类会无条件地过渡到 Sprinkling 状态。对于其对立面,有两个可能的下一个状态,取决于草地上是否真的有鸟。如果鸟不在,那显然不会有鸟屎,而如果有鸟屎的可能性,则可能需要进行草地清理(或者不需要,CaT 世界充满了随机性)。
回到喷洒状态,它有两个可能的结果(命中或未命中),取决于系统是否成功将鸟驱赶走(至少这次是如此)。
class Sprinkling(AbstractYardState):
def transit(self) -> Union['Hit', 'Miss']:
self.state_machine.simulate_sprinkling()
next_state: Union['Hit', 'Miss']
if self.state_machine.hit_successfully:
next_state = Hit(self.state_machine)
else:
next_state = Miss(self.state_machine)
return next_state
注:Hit 状态没有专门的转移逻辑,它的存在是为了遵循关于草地上翼助攻击领域的语义。忽略它会导致 Shooting 状态直接过渡到 Leaving。
class Hit(AbstractYardState):
def transit(self) -> 'Leaving':
return Leaving(self.state_machine)
如果水喷洒器被激活且草地上没有鸟(检测器错误预测了鸟的存在),状态机将返回到 Spawn 状态。如果鸟实际上在场且我们没有检测到,草地上可能会有鸟屎。
class Miss(AbstractYardState):
def transit(self) -> Union['Attacking', 'Spawn']:
next_state: Union['Attacking', 'Spawn']
if self.state_machine.bird_present:
next_state = Attacking(self.state_machine)
else:
next_state = Spawn(self.state_machine)
return next_state
最终,攻击尝试可能会对草地造成实际的损害,正如 Attacking 类及其子类所体现的那样:
class Attacking(AbstractYardState):
def transit(self) -> Union['Attacked', 'NotAttacked']:
self.state_machine.simulate_spoiling()
next_state: Union['Attacked', 'NotAttacked']
if self.state_machine.spoiled:
next_state = Attacked(self.state_machine)
else:
next_state = NotAttacked(self.state_machine)
return next_state
class Attacked(AfterAttacking):
def transit(self) -> Union['Leaving', 'Spawn']:
return super().transit()
class NotAttacked(AfterAttacking):
def transit(self) -> Union['Leaving', 'Spawn']:
return super().transit()
我们可以采用与 Intrusion 状态相同的思路,将共享的转移逻辑封装到一个名为 AfterAttacking 的超类中,从而得到 Leaving 或返回 Spawn 状态:
class AfterAttacking(AbstractYardState):
def transit(self) -> Union['Leaving', 'Spawn']:
next_state: Union['Leaving', 'Spawn']
if self.state_machine.max_stay_reached:
next_state = Leaving(self.state_machine)
else:
next_state = Spawn(self.state_machine)
return next_state
接下来会发生什么呢?当模拟达到步数限制时,它会卡在 End 状态:
class End(AbstractYardState):
def transit(self) -> 'End':
return self
在实际操作中,我们不希望程序无休止地执行。因此,一旦模拟检测到过渡到结束状态,它将关闭。
模拟 CaT 系统
“在鸟类探测的微妙世界中,请记住:当一个模型说“未检测到鸡”时,可能有一只狡猾的鸟正悄悄地站在草坪上,未被发现。这种差异提醒我们需要完善和增强我们的人工智能系统。” — 无名工程师。
现在,我们希望模拟鸟类闯入草坪、破坏草坪并离开的过程。为此,我们将采用一种称为离散事件模拟的模拟建模方法。我们将通过分析系统各个元素之间最重要的关系,并基于有限自动机原理开发一个模拟系统来再现系统行为。为此,我们需要考虑以下几个方面:
-
鸟类可以闯入房产的后院。
-
基于计算机视觉的系统尝试检测和分类入侵对象。
-
基于以上情况,如果对象被识别为某种特定的鸟类,我们将建模水洒喷头的过程以将其驱逐。
-
还应该提到,有一个概率过程会导致鸟类破坏草坪(再次声明,没什么个人恩怨,只是羽毛)。
草坪模拟过程
现在,是时候探索概率的魔力,通过实现的有限状态机来模拟这些过程了。为此,我们需要创建一个YardSimulation类来封装模拟逻辑。如前所述,模拟不仅仅是有限状态机。模拟步骤与状态机过渡之间的对应关系也适用。也就是说,系统需要执行多个状态过渡才能切换到下一个时间步骤。
在这里,step()方法处理从当前状态到下一个状态的过渡,并调用有限状态机(FSM)的transit()方法,直到状态机返回到生成状态或达到结束状态。
def step(self) -> Tuple[int, List[AbstractYardState]]:
self.step_idx += 1
transitions = list()
while True:
next_state = self.current_state.transit()
transitions.append(next_state)
self.current_state = next_state
if self.current_state in (Spawn(self), End(self)):
break
return self.step_idx, transitions
在run()方法中,我们在循环中调用step()并生成其输出,直到系统过渡到结束步骤:
def run(self) -> Iterator[Tuple[int, List[AbstractYardState]]]:
while self.current_state != End(self):
yield self.step()
reset()方法在鸟离开后重置有限状态机的记忆。
def reset(self) -> 'YardSimulation':
self.current_state = Start(self)
self.intruder_class = Target.EMPTY
self.detected_class = Target.EMPTY
self.hit_successfully = False
self.spoiled = False
self.stayed_steps = 0
return self
当一只鸟被水洒喷头成功喷到,或者它在草坪上停留太久(例如,假设它感到无聊)时,它就离开了。后者相当于鸟在草坪上停留了 5 个模拟步骤(即 5 分钟)。时间并不长,谁知道呢,也许邻居的草坪看起来更吸引人。让我们实现一些我们系统行为的核心部分。
首先,如果草坪上没有鸟类(即真正的入侵者类别),我们尝试生成一只。
def simulate_intrusion(self) -> Target:
if not self.bird_present:
self.intruder_class = self.spawn_target()
return self.intruder_class
这里,生成与非法入侵实体(鸟或无物)实时创建相关。
@property
def bird_present(self) -> bool:
return self.intruder_class != Target.EMPTY
然后,基于计算机视觉的系统——由类混淆矩阵描述——尝试检测和分类入侵对象。在这个过程中,我们模拟一个预测生成,同时牢记实际的入侵者类别(地面真相)。
def simulate_detection(self) -> Target:
self.detected_class = self.get_random_target(self.intruder_class)
return self.detected_class
检测器在模拟的每个时间步上都在工作,因为模拟系统并不知道实际情况(否则我们为什么还需要检测器呢?)。如果检测器识别到鸟类,我们会尝试用水喷洒器把它赶走,水流量依赖于检测到的目标类别:
def simulate_sprinkling(self) -> bool:
self.hit_successfully = self.bird_present and (self.rng.uniform() <= self.hit_proba) and self.target_vulnerable
return self.hit_successfully
无论喷水是否成功,系统都会消耗水。成功的判定标准包括以下条件:草坪上有鸟类存在(a),水喷洒器击中鸟类(b),喷洒的水量足够/适合处理给定大小的鸟类(c)。请注意,(c)鸡“喷洒”不会处理火鸡,但其他情况下适用。
草地污染部分——鸟类有可能弄脏草坪。如果发生这种情况,草坪损坏率就会增加(显然)。
def simulate_spoiling(self) -> bool:
self.spoiled = self.bird_present and (self.rng.uniform() <= self.shit_proba)
if self.spoiled:
self.lawn_damage[self.intruder_class] += self.intruder_class.damage
return self.spoiled
现在我们具备了所有必要的条件来模拟我们将要处理的 CaT 问题的单一时间步。模拟时间开始!
鸟类逃跑
现在,我们已经准备好使用 FSM 模拟来模拟一个在不同设置下的 AI 辅助草坪安全系统。在进行草坪模拟时,YardSimulation.run() 方法会遍历一系列的状态转移,直到系统达到最大步数。在此过程中,我们实例化一个模拟对象(即状态机),设置 num_steps 参数,这个参数反映了模拟的总时间步数(比如 12 小时或白天),以及与计算机视觉(CV)基础的鸟类检测子系统的混淆矩阵相关的 detector_matrix,该子系统经过训练,用来预测鸡和火鸡:
sim = YardSimulation(detector_matrix=detector_matrix, num_steps=num_steps)
现在我们可以运行 FSM 模拟,并打印 FSM 在每个时间步经历的状态转移:
for step_idx, states in sim.run():
print(f'\t{step_idx:0>3}: {" -> ".join(map(str, states))}')
此外,我们还积累了与鸟类喷水(simulate_sprinkling)和鸟类到达后的草地清理(simulate_spoiling)相关的水使用模拟统计数据。
def simulate_sprinkling(self) -> bool:
...
self.water_consumption[self.detected_class] += self.detected_class.health
...
def simulate_spoiling(self) -> bool:
...
if self.spoiled:
self.lawn_damage[self.intruder_class] += self.intruder_class.damage
...
当模拟达到其限制时,我们可以计算每个类别到一天结束时的总水消耗。我们希望看到的是每次模拟运行后的变化情况。
water_sprinkling_total = sum(sim.water_consumption.values())
lawn_damage_total = sum(sim.lawn_damage.values())
最后,让我们进行实验,评估在计算机视觉子系统发生变化时,系统的表现如何。为此,我们将使用 YardSimulation.run() 方法进行 100 次试验,分别针对未经训练的(基准)和完美检测矩阵进行模拟:
detector_matrix_baseline = np.full(
(len(Target),) * 2, # size of the confusion matrix (3 x 3)
len(Target) ** -1 # prediction probability for each class is the same and equals to 1/3
)
detector_matrix_perfect = np.eye(len(Target))
然后,我们可以聚合并比较不同实验设置下与目标喷水和草坪清理的总水使用量相关的输出统计数据:

图 2. FSM 模拟输出统计数据,涵盖了鸟类检测子系统的边缘情况(图片由作者提供)
实验结果的总结对比显示,拥有一个更好的计算机视觉(CV)模型可以使得水使用量减少 37.8%(70.9 与 44.1),相较于未经过训练的基线检测器,在给定输入参数和仿真条件下对鸟类的检测——这一概念既直观又在预期之中。但“更好”的定量意义是什么呢?是否值得费力地精细调整模型?数值结果展示了改进模型的价值,激励了进一步优化的努力。未来,我们将把这些统计结果作为全局优化的目标,以提高鸟类检测子系统的效率,并减少系统操作和维护中的水消耗,从而让工程师稍微高兴一点。
本教程所使用的源代码可以在此 GitHub 仓库中找到:github.com/slipnitskaya/computer-vision-birds。
结论
总结来说,仿真建模是一个有用的工具,可以用来估算过程的效率、快速测试预期的变化,并了解如何通过操作和维护改进过程。通过本文,你对仿真建模在解决工程问题中的实际应用有了更好的理解。特别是,我们已经涵盖了以下内容:
-
如何设计一个模型来近似一个复杂系统,从而改进其在鸟类检测和水洒布方面的操作。
-
如何创建现实世界过程的仿真,以便在不同条件下理解 CaT 系统的行为。
-
如何在 Python 中实现基于 FSM 的解决方案,使得系统能够追踪仿真过程中的相关统计数据。
接下来做什么
聚焦于提高资源效率,在后续文章中,你将发现如何通过应用蒙特卡洛方法和可解释 AI(XAI)技术来解决水成本降低的非解析优化问题,从而增强基于计算机视觉的鸟类检测子系统,推动我们的仿真 AI 辅助草坪安防系统的发展。
在视觉项目的仿真建模和优化中,还有哪些重要的概念?请访问Bird by Bird Tech了解更多。
参考文献
-
Forsyth, David. 《计算机科学中的概率与统计》。第 13 卷。Cham: Springer International Publishing, 2018。
-
Knuth, Donald Ervin. 《计算机程序设计的艺术》。第 3 卷。Reading, MA: Addison-Wesley, 1973。
-
Wagner, Ferdinand 等. 《使用有限状态机建模软件:一种实用方法》。Auerbach Publications, 2006。
有限自动机仿真在利用 AI 辅助系统中的应用
鸟类技术逐步推进
使用有限状态机设计、建模和仿真现实世界的 AI 系统,以提高物体检测任务的性能
·发表于 Towards Data Science ·阅读时间 13 分钟·2024 年 2 月 13 日
--

图片由作者提供
背景
问题理解
最近,我看到一个非常棒的案例,展示了如何利用 Raspberry Pi 和 Python 创建一个基于计算机视觉的物体检测系统。简而言之,一位工程师制作了一个设备,可以把邻居家的鸡赶离他的土地。在跟随 Reddit 线程之后,很明显这个问题非常普遍,并不限于某些鸟类。如果有的话,它更像是一个普遍现象。最受欢迎的评论包括:
“我需要这个来对付邻居喂养的鸭子,它们总是在我草坪上拉屎。” — Light_Beard
“我晚上需要这个设备来赶走院子里的猫。” — Buddha_
“这个可以用来吓跑万圣节时的孩子们吗?替朋友问的。” — HardPawns
好吧,有人可能会争辩说这个问题并不那么重要,并且很有道理地建议直接问邻居解决这些鸡的问题。然而,这显然不是一个工程师应有的解决方式。假设你已经建立了一个 AI 辅助的鸟类检测系统,并配备了一个喷水器来把不受欢迎的鸡赶出院子。问题是,现有版本的操作效果不如预期,导致依然会明显浪费水资源用于浇洒和草坪清理。因此,鸡继续跑动,水费依旧居高不下。如何解决这个挑战呢?
基于模型的复杂系统工程
在这里,我们将通过设计一个计算模型来模拟完整的“鸡在草地上”循环,并随后优化其参数,以减少水的消耗。为此,我们将采用包括自动机理论和随机算法在内的多种技术。
本文特别关注建模和仿真方面,以便你学习如何描述一个真实系统的行为,并设计一个反映其动态的有限状态机。然后,你将探索如何在 Python 中实现这些系统,并通过优化其在物体检测上的表现,发现如何利用基于计算机视觉的模型。这应该很有趣!
免责声明: 本工作是“**鸟语鸟”系列的一部分,致力于使用有限自动机对计算机视觉应用中的真实系统进行建模和仿真。所有的参与者、状态、事件和输出仅是 FSM 设计过程中的产物,出于教育目的。任何与实际人物、鸟类或真实事件的相似之处纯属巧合。*
介绍相关工作
用于建模和仿真的有限状态机
有限状态机(FSM)或有限自动机是一种数学模型,可以通过描述离散状态、状态之间的转换以及触发这些转换的规则集来表示和分析系统的动态行为。
有限状态机的历史可以追溯到 20 世纪中期,这一时期标志着自动机理论和计算理论的重要里程碑。艾伦·图灵(Alan Turing)和约翰·冯·诺依曼(John von Neumann)等先驱的早期贡献奠定了基础,但在 1950 年代和 1960 年代,FSM 取得了显著进展。特别是,爱德华·F·摩尔(Edward F. Moore)和乔治·H·米利(George H. Mealy)分别独立地提出了两种主要类型的 FSM——摩尔机和米利机。
这两种 FSM 类型在方法上有所不同:摩尔机仅基于当前状态来确定下一个状态,而米利机则将输出与当前状态和输入关联,提供更强的适应性。最初用于数字电路中,特别是米利机由于其对外部输入信号的响应,已经在设计复杂系统中得到广泛应用,这些系统伴随着我们的日常生活。
有限状态机(FSM)广泛应用于硬件和软件中。四处看看——几乎所有电子和计算设备都有某种形式的有限自动机——从自动售货机到中央处理单元(CPU),从基本的电子设备到智能家居自动化的可编程逻辑控制器。它们也被广泛应用于软件和游戏开发中,当然,也可以用于创建实时物体检测的自适应 AI 辅助系统。
离散数学的回归
从本质上讲,确定性有限自动机包括状态、输入和转移函数。状态表示系统的不同条件,而输入触发状态之间的切换。转移函数定义了机器如何在状态之间转换的规则。从数学角度看,这样的状态机可以用一个五元组表示,记作 M=(Q, Σ, δ, q₀, F),其中:
-
Q 是一个表示系统不同配置的状态集合。
-
Σ 是一个由触发状态变化的事件组成的输入集合。
-
转移函数 δ 决定了系统在给定输入的情况下如何在状态之间切换(δ:Q×Σ→Q)。
-
初始状态 q₀ 是系统初始化时的起始状态,其中 q₀∈Q。
-
F 是 Q 的子集(F⊆Q),由最终状态组成。
通过这种方式,对于任何给定的状态和特定的输入符号,转移函数 δ 将确定一个唯一的下一个状态,通常通过状态转换表或图来表示,指定当前状态和输入的组合下的状态转移。
FSM 设计过程
有限状态机(FSM)的设计过程包括识别状态(以及在适用的情况下识别输入)、定义转移函数,并指定初始状态和最终状态。可以采用以下方法论将复杂系统转化为易于理解的模型,从而有助于后续的分析、设计和实施阶段。5 步 FSM 设计过程包括:
-
理解问题,分析系统的结构。
-
为设计一个概念模型定义关键组件。
-
创建状态图或定义状态转换表。
-
实现机器的状态、输入和输出,以及转移逻辑。
-
测试并通过实验验证 FSM。
这个迭代过程使得我们能够设计出简洁的真实系统行为的表示,允许在过程中进行近似和细化。例如,在实现 FSM(第 4 步)后,你可能希望进一步验证并更新规格(第 2 步和第 3 步),同样也适用于从实验阶段(第 5 步)回到问题定义阶段(第 1 步),以创建一个足够详细且有用的工作模型来解决特定问题。
状态机示例
让我们以一个简单的“鸡在草地”场景为例,其中一只鸟可以出现在草地上,也可以不在草地上,这取决于由工程师发起的外部刺激,工程师可以选择休息或赶走侵入其财产的“不速之客”。因此,控制对象(工程师)旨在补偿独立行为者(鸡)参数的不确定性。在这个例子中,最终状态集合 F 只包含一个系统终止的状态,例如在一天结束时没有鸡在周围。这样:
-
Q = {q₀, q₁, q₂, ⊙}: 表示没有鸡/有鸡的状态集合。
-
Σ = {α₀, α₁, α₂}: 输入事件集合 — 工程师休息/追赶,以及日落。
-
F = {⊙} 包含表示一天结束的最终状态。
图 1 提供了一个状态转换图,其中的节点(状态)通过边(下一状态转换)连接,弧线上的标签指定触发转换的输入事件。

图 1. 简单状态机的图形表示(图片来源:作者)
这种表示法捕捉了问题的二元性质,其中鸡可以出现在草坪上,也可以不在草坪上。系统响应由工程师或日落触发的事件。在图中,初始状态和最终状态由圆圈表示。该有限状态机(FSM)的转移函数δ也可以以表格形式表示,展示系统的状态转换和控制操作,如表 1 所示。
Table 1\. State-transition table for the chicken-on-the-lawn FSM example
+========================+========================+========================+
| Current State | Input Event | Next State |
+========================+========================+========================+
| q₀ ≡ ”No Bird” | α₀ ≡ ”Engineer Rest” | q₁ |
+------------------------+------------------------+------------------------+
| q₁ ≡ ”Bird Present” | α₁ ≡ ”Engineer Chase” | q₀ |
+------------------------+------------------------+------------------------+
| q₀ | α₂ ≡ ”Sunset” | ⊙ |
+------------------------+------------------------+------------------------+
因此,通过完成五个简单步骤,我们设计了一个简单的状态机。现在,一切都已经解释清楚,最后让我们创建一个基于 FSM 的模型,来表示我们在草坪上与鸟类的挑战。
处理草坪上的鸟类挑战
它们在草坪上的活动
正如你在上一节中学到的,有限自动机可以用来模拟几乎任何过程。想象一下,今天下午你家后院有一些鸡在跳来跳去。它们在做什么?只要观察一下就知道。它们总是在动,唱歌,或互动。它们经常飞翔、探测或觅食。有时,它们会展示或做一些引起我们注意的事,比如那些邻居家的鸡把草弄得一团糟,但我们现在先把这些细节放一边。好吧,最终,所有的鸟都在拉屎(没有冒犯,羽毛朋友们)。对于 FSM 设计,我们不会涉及更细致的部分,而是通过逻辑提取出模拟所需的基本组件。让 FSM 将水的冒险提升到下一个玩法的高度!
系统描述
关于鸡的部分,在这里,我们将描述系统,以反映我们的实际场景,目的是优化物体检测系统的参数并减少草坪清洁的水费。为了参考,可以再看看之前的 FSM 示例。这个简单的机器与现实生活中的系统在一些特定方面有所不同。首先,我们希望将控制对象实际化,包含一个基于人工智能的设备,用于检测和驱赶鸟类,这次通过高压喷洒枪实现(这样工程师就可以“自循环”回到休息状态)。
其次,我们需要更新和/或扩展可能的状态、事件和转换集合,以反映更新系统设置的复杂性。对于后者,我们为何不考虑可以被计算机视觉模型识别的额外鸟类类别(例如火鸡),使它们成为我们 FSM 的潜在独立参与者。此外,假设鸟类的体型在物种间有所不同,灌溉控制系统需要更强的水流和/或压力,才能成功地将体型较大的火鸡赶离草坪,而不像对待鸡那样简单。因此,为了简洁起见,我们将鸡与火鸡在草坪上的问题简称为 CaT 问题。
概念建模
为了建模对象检测系统需要监视、分类并与闯入物业的物体互动的场景,我们将定义状态、事件和转换,表示这一情况的不同方面。我们的目标是捕捉对象检测系统和鸡可能处于的各种状态,以及触发状态转换的事件。
对于逻辑设计场景,考虑到在任何时刻,一只鸟可以进入院子,弄乱草坪(或不弄乱),并且离开物业,不论是它自己离开,还是被基于 AI 的草坪安保系统成功检测并赶走。现在,让我们定义一些 FSM 仿真模拟的主要组成部分。
状态表示反映 CaT 场景的可能条件:
-
对于跳跃的目标:生成和入侵状态、攻击及其结果、离开草坪。
-
对于 AI 系统:检测状态、喷洒状态。
-
初始状态“开始”与仿真模拟的入口点相关。
-
终止状态“结束”表示仿真模拟的终点。
状态转换决定了系统如何根据输入在不同状态之间切换。例如,AI 模型可能忽略一只鸟并错过喷洒过程,从而导致草坪上的一系列后果。以下是我们可以预见的一些其他场景和条件:
-
从“入侵”到“目标检测”在“检测”事件上转换。
-
从“目标检测”到“鸟离开”事件,通过“喷洒”和“击中”事件的序列,在闯入的鸟被检测到并成功被水喷头击中后。
-
从“鸟出现”到“攻击”,如果系统在目标检测和预测步骤中失败,而鸟实际上就在草坪上。在鸟被检测到但系统未能成功击中时,也会发生相同的事件。
通过这种方式,有限状态机(FSM)将在 AI 系统与四处跳跃的鸡互动时动态地从一个状态转移到另一个状态。为了简化任务并减少出错的可能性,我们创建了一个结合状态转换和条件的表格:
Table 2\. FSM inputs describing the events triggering state changes
+====+==================================+==================+================+
| ID | Input Description | Defence System | Enemy Tactics |
| | | Operation Mode | and Waypoints |
+====+==================================+==================+================+
| X₁ | Bird is present on the lawn | | Hopping |
+----+----------------------------------+ Object detection +----------------+
| X₂ | Bird intrudes the lawn | | Start hopping |
+----+----------------------------------+------------------+----------------+
| X₃ | AI-powered detector spots a bird | Start sprinkling | Hopping (still)|
+----+----------------------------------+------------------+----------------+
| X₄ | Bird is hit successfully¹ | | |
+----+----------------------------------+ - | Intimidation |
| X₅ | Target is susceptible² | | |
+----+----------------------------------+------------------+----------------+
| X₆ | Bird spoiled the lawn | | Hopping merrily|
+----+----------------------------------+ Object detection +----------------+
| X₇ | Bird leaves the lawn | | Retreat |
+----+----------------------------------+------------------+----------------+
| X₈ | Simulation period ends (sunset) | - | - |
+----+----------------------------------+------------------+----------------+
ID - input identifier
¹ - aiming and sprinkling modules operated correctly
² - water flow rate is strong enough to chase the bird away
状态转换表
现在,在识别了状态和事件后,我们将编写一个结合状态转换表,并使用布尔表达式表示下一个状态。在表 3 中,我们可以看到表 2 中描述的输入如何引导模拟状态之间的转换。
Table 3\. FSM state transition table with next-stage transition logic
+========================+========================+========================+
| Current State | Transition Formula | Next State |
+========================+========================+========================+
| Start | TRUE | Spawn |
+------------------------+------------------------+------------------------+
| | X₁ ∨ X₂ | Intrusion |
| |------------------------+------------------------+
| Spawn | ¬X₁ ∧ ¬X₂ | Empty lawn |
+ |------------------------+------------------------+
| | X₈ | End |
+------------------------+------------------------+------------------------+
| | X₃ | Target detected |
| Intrusion |------------------------+------------------------+
| | ¬X₃ | Not detected |
+------------------------+------------------------+------------------------+
| | X₃ | Target detected |
| Empty lawn |------------------------+------------------------+
| | ¬X₃ | Not detected |
+------------------------+------------------------+------------------------+
| Target detected | TRUE | Sprinkling |
+------------------------+------------------------+------------------------+
| | X₁ | Attacking |
| Not detected |------------------------+------------------------+
| | ¬X₁ | Not attacked |
+------------------------+------------------------+------------------------+
| | ¬X₁ ∨ ¬X₄ ∨ ¬X | Miss |
| Sprinkling |------------------------+------------------------+
| | X₁ ∧ X₄ ∧ X₅ | Hit |
+------------------------+------------------------+------------------------+
| | ¬X₁ | Spawn |
| Miss |------------------------+------------------------+
| | X₁ | Attacking |
+------------------------+------------------------+------------------------+
| Hit | TRUE | Bird leaves |
+------------------------+------------------------+------------------------+
| | ¬X₆ | Not attacked |
| Attacking |------------------------+------------------------+
| | X₆ | Bird attacked |
+------------------------+------------------------+------------------------+
| | ¬X₇ | Spawn |
| Not attacked |------------------------+------------------------+
| | X₇ | Bird leaves |
+------------------------+------------------------+------------------------+
| | ¬X₇ | Spawn |
| Bird attacked |------------------------+------------------------+
| | X₇ | Bird leaves |
+------------------------+------------------------+------------------------+
| Bird leaves | TRUE | Spawn |
+------------------------+------------------------+------------------------+
在大多数情况下,一个输入决定了下一个状态。然而,我们需要同时考虑多个条件来切换“生成”或“喷洒”状态。你也可以注意到,对于某些状态,转换不依赖外部信息,例如“开始”或“击中”。这些状态要么是特殊的(如“开始”),要么是触发辅助动作的状态。后者对我们模拟的故事没有直接影响(即在这方面,它们可以与后续状态结合使用),但对于收集模拟统计数据非常重要。
最后,让我们来看一下它的可视化表示。图 3 展示了对应于 CaT 系统生命周期内状态转换的图。你或许已经能看到它们之间的联系了。接下来的文章中,你将学习如何在 Python 中实现这个 FSM,以及如何利用它来优化 AI 辅助鸟类检测系统的参数,以减少水费开销。

图 2. 表示 AI 辅助草坪安全系统的 FSM 状态转换图(图片来自作者)
结论
在本文中,我们探讨了如何在实践中应用有限状态机(FSM),构建一个模型来解决 CaT 问题,从而实现高层次的问题分析和解决方案设计。
我们通过将 FSM 形式化应用于个体参与者及其相互作用,描述了复杂的庭院过程,从而创造了一个全面的视角,展示了我们必须处理的现实世界情况,其中我们不得不应对邻里鸟类闯入我们的领地。
这使我们能够创建一个模拟,反映了 AI 辅助安全系统的运作,系统配备了喷洒用的水压控制器,旨在进行物体检测并驱赶破坏草坪的不速之客。
接下来是什么?
在接下来的系列文章中,我们将进一步研究使用 FSM 模拟现实场景的主题,以及它在解决水费优化问题中的实际应用。
具体来说,下一篇文章将包括一个 Python 教程,教你如何从头开始实现一个 FSM 驱动的模拟,并将其作为随机优化流程的一部分来使用。基于创建的模拟,我们接下来将探讨如何利用它,通过应用蒙特卡洛和可解释 AI(XAI)技术,优化基于计算机视觉的鸟类检测子系统的性能,从而提高我们草坪安全系统的资源效率。
想继续了解更多内容吗?在以下链接保持更新 — github.com/slipnitskaya/computer-vision-birds 和 medium.com/@slipnitskaya。
参考文献
-
Moore, Edward F. “关于顺序机器的思想实验。”《自动机研究》34 (1956): 129–153。
-
Mealy, George H. “一种合成顺序电路的方法。”《贝尔系统技术杂志》34.5 (1955): 1045–1079。
-
Sipser, M. “计算理论导论。”第二版,Thomson Course Technology (2006)。
线性代数的鸟瞰图:左逆、右逆 => 单射、满射映射
如果矩阵乘法不是可交换的,那么为什么我们没有左右逆?
·发布于Towards Data Science ·10 分钟阅读·2024 年 12 月 3 日
--

图片来源:midjourney
注:除非另有说明,所有图片均为作者提供。
这是正在进行的线性代数书籍的第七章:“线性代数的鸟瞰图”。到目前为止的目录:
-
第二章:映射的度量——行列式
-
第五章:方程组、线性回归与神经网络
-
第六章:秩与零度,为什么行秩等于列秩
-
第七章:左逆与右逆 => 单射与满射
-
第八章(当前):正交归一矩阵
我们在第三章中深入讨论了矩阵乘法。我们提到过矩阵乘法有一个单位元,即矩阵:
Bit-LoRA 作为 BitNet 和 1.58 位神经网络技术的应用
摘要:将约 1 位变换器技术应用于 LoRA 适配器,使我们能够实现与全精度 LoRA 相当的性能,同时将 LoRA 适配器的大小缩小了 30 倍。这些微小的 LoRA 适配器能够改变基础模型的性能,揭示 LLM 个性化的新机会。
·发表于Towards Data Science ·13 分钟阅读·2024 年 6 月 3 日
--
1.58 位是什么?
现在有一种名为“LLM”的技术非常流行。LLM 代表大型语言模型。这些 LLM 能够解决相当复杂的任务,使我们更接近我们想象中的 AI。LLM 通常基于变换器架构(虽然也有一些替代方法,但它们仍在开发中)。变换器架构需要相当昂贵的计算资源,因为这些 LLM 是大型的,计算需要大量的时间和资源。例如,现如今 LLM 的较小规模大约为 70-80 亿个参数——这就是我们在模型名称中看到的数字(例如 Llama3–8B或 Llama2–7B)。除了数量庞大,为什么计算如此昂贵呢?其中一个原因是计算的精度——常规的训练和推理过程中使用的是 16 位或 32 位精度,这意味着模型中的每个参数在内存中需要 16 或 32 位,而所有计算都以这种精度进行。简而言之,一般来说,位数越多——存储和计算所需的资源就越多。
量化是减少每个参数使用的位数,从而减少所需资源(缩短推理时间)的一种众所周知的方法,代价是牺牲一定的准确性。实现量化有两种方式:后训练量化和量化感知训练。在第一种情况中,我们在模型训练完成后进行量化——这是简单但有效的方法。然而,如果我们想要一个更加准确的量化模型,我们应该进行量化感知训练。
关于量化感知训练的一些话,当我们进行量化感知训练时,我们强制模型在低精度下产生输出,假设是 4 位而不是原来的 32 位:简单的类比,我们计算 3.4 + x,期望的正确答案(目标)是 5.6(浮点精度),在这种情况下我们知道(训练后模型也知道)x = 2.2(3.4+2.2=5.6)。在这个简单的类比中,后训练量化类似于在我们知道 x 是 2.2 后应用四舍五入操作——我们得到 3 + 2 = 5(尽管目标仍然是 5.6)。但是量化感知训练试图找到一个 x,使我们能够更接近真实目标(5.6)——我们在训练过程中应用“伪”量化,简而言之——进行四舍五入——我们得到 3 + x = 6,x = 3。关键在于,6 比 5.6 更接近目标值,而不是 5。这个例子在技术上并不完全准确,但可以为我们提供一些见解,为什么量化感知训练通常比后训练量化更准确。这个例子中不准确的一个技术细节是,量化感知训练过程中我们使用量化的模型权重进行预测(前向传播),但是在反向传播过程中我们仍然使用高精度来保持模型的平滑收敛(这就是为什么它被称为“伪”量化)。这与我们在 fp16 混合精度训练中做的操作非常相似:我们使用 16 位精度进行前向传播,但在进行梯度计算和权重更新时,使用主模型的 fp32(32 位)精度。
好的,量化是一种使模型更小且资源更高效的方法。好的,量化感知训练似乎比后训练量化更准确,但我们能在这些量化的基础上走多远呢?有两篇论文我想提一下,它们指出我们可以将量化精度降到低于 2 位,并且训练过程仍然保持稳定:
-
BitNet:扩展 1 比特变换器以支持大语言模型。作者提出了一种方法,使所有权重都处于 1 比特精度:只有 1 或-1(而激活值则处于 8 比特精度)。这种低精度仅在前向步骤中使用,而在反向传播中使用高精度。
-
1 比特 LLM 的时代:所有大语言模型都处于 1.58 位。这篇论文基于 BitNet 论文,不过这里的作者使用{-1; 0; 1}作为模型中每个参数的可能取值,而不仅仅是 1 和-1。
当我第一次看到这些论文时,我相当怀疑——我不相信如此低精度的模型能够达到与全精度 LLM 相当或更好的准确度。而且我依然持怀疑态度。对我来说,这听起来好得令人难以置信。另一个问题是——我没有看到任何一款按照这些论文训练的 LLM 可以让我操作并证明其性能与全精度模型相当。但我能自己训练出这样的 LLM 吗?嗯,我怀疑——我没有足够的资源从头开始训练一个 LLM,尤其是在使用这些技术的情况下。然而,当我们处理 LLM 时,我们通常进行微调,而不是从头开始训练,而且有一种微调模型的技术叫做 LoRA,即我们初始化一些额外的模型权重并从零开始调整它们。
2. 什么是 LoRA,为什么使用它?
LoRA 是一种参数高效模型微调(PEFT)技术。其主要思想是,我们只微调由一对线性层组成的适配器的附加权重,而基本模型保持不变。这对于我尝试使用 1.58 位技术的工作非常重要。关键是,我可以从零开始训练这些适配器,并查看是否能获得与全精度适配器训练相当的 LLM 性能。剧透:在我的实验中,低精度适配器训练的结果略差一些,但这种训练方法有一些不同的好处和潜在的应用——在我看来,主要是在个性化领域。
3. 实验
在实验中,我使用了我专有的文本生成任务数据。数据本身在这里并不重要,我只是想说它是用于训练指令跟随 LLM 的指令数据集的一个小子集。作为基础模型,我决定使用 microsoft/Phi-3-mini-4k-instruct 模型。我进行了 3 个 Epoch 的 LoRA 适配器微调,使用了 Huggingface Trainer 的 fp16 混合精度训练,并在评估中测量了损失。之后,我实现了 BitNet(替换 LoRA 适配器中的线性层)和 1.58 位 LoRA 训练,并报告了结果。我在训练中使用了 BitsAndBytes 进行 4 位基础模型量化,并采用了 Q-LoRA 配置。
以下 LoRA 超参数被使用:rank = 32, alpha = 16, dropout = 0.05。
3.1. 经典 LoRA 训练
在所有 LoRA 实验中,使用了 QLoRA 方法,涉及基础模型量化部分使用了 NF4,并将 LoRA 应用于基础模型的所有线性层。优化器是 Paged AdamW,具有预热和余弦退火,直到最大学习率的 90%。最大学习率为 2e-4。训练/测试集是随机划分的,测试集占整个数据集的 10%。
3.2. LoRA BitNet 实现
对于 BitNet LoRA 训练,采用了“BitNet: Scaling 1-bit Transformers for Large Language Models”中的方法,并使用了其实现的代码。根据 BitNet 论文,LoRA 层的权重经过了二值化处理,并进行了缩放:

图片来源于论文BitNet: Scaling 1-bit Transformers for Large Language Models
与此同时,激活函数也应根据论文中的方法进行量化:

图片来源于论文BitNet: Scaling 1-bit Transformers for Large Language Models
根据提供的公式,可以看到每个参数都经过符号函数的转换,变为 +1 或 -1,这些参数与量化和归一化的输入 X 相乘,并通过层参数的均值绝对值进行缩放。代码实现:
from torch import nn, Tensor
import torch.nn.functional as F
# from https://github.com/kyegomez/zeta
class SimpleRMSNorm(nn.Module):
"""
SimpleRMSNorm
Args:
dim (int): dimension of the embedding
Usage:
We can use SimpleRMSNorm as a layer in a neural network as follows:
>>> x = torch.randn(1, 10, 512)
>>> simple_rms_norm = SimpleRMSNorm(dim=512)
>>> simple_rms_norm(x).shape
torch.Size([1, 10, 512])
"""
def __init__(self, dim):
super().__init__()
self.scale = dim**-0.5
def forward(self, x):
"""Forward method of SimpleRMSNorm"""
return F.normalize(x, dim=-1) * self.scale
def activation_quant(x: Tensor):
"""Per token quantization to 8bits. No grouping is needed for quantization
Args:
x (Tensor): _description_
Returns:
_type_: _description_
"""
scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
y = (x * scale).round().clamp_(-128, 127) / scale
return y
def weight_quant(w: Tensor):
scale = w.abs().mean()
e = w.mean()
u = (w - e).sign() * scale
return u
class BitLinear(nn.Linear):
"""
Custom linear layer with bit quantization.
Args:
dim (int): The input dimension of the layer.
training (bool, optional): Whether the layer is in training mode or not. Defaults to False.
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Attributes:
dim (int): The input dimension of the layer.
"""
def forward(self, x: Tensor) -> Tensor:
"""
Forward pass of the BitLinear layer.
Args:
x (Tensor): The input tensor.
Returns:
Tensor: The output tensor.
"""
w = self.weight
x_norm = SimpleRMSNorm(self.in_features)(x)
# STE using detach
# the gradient of sign() or round() is typically zero
# so to train the model we need to do the following trick
# this trick leads to "w" high precision weights update
# while we are doing "fake" quantisation during the forward pass
x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
w_quant = w + (weight_quant(w) - w).detach()
y = F.linear(x_quant, w_quant)
return y
上述所有代码均来自github.com/kyegomez/BitNet GitHub 仓库。
在 LoRA 训练后,由于每个 LoRA 适配器仅由一对线性层组成且不包含偏置和非线性激活函数,因此可以将适配器权重与基础模型合并。激活函数的归一化(LN(x))和在该方法中的量化使得 LoRA 适配器的合并变得更加困难(合并后,LoRA 适配器与基础模型的线性层共享相同的输入——这些层处理的激活没有任何额外的修改),因此进行了没有归一化和激活量化的额外实验,并且取得了更好的性能。为了进行这种修改,我们只需修改 BitLinear 类的前向方法:
def forward(self, x: Tensor) -> Tensor:
"""
Forward pass of the BitLinear layer.
Args:
x (Tensor): The input tensor.
Returns:
Tensor: The output tensor.
"""
w = self.weight
#x_norm = SimpleRMSNorm(self.in_features)(x)
# STE using detach
#x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
x_quant = x
w_quant = w + (weight_quant(w) - w).detach()
y = F.linear(x_quant, w_quant)
return y
所提供的代码是量化感知训练,因为每个 BitLinear 层的主权重仍然保持高精度,而我们在前向传播时对权重进行二值化(同样的方法也可以用于模型推理)。唯一的问题是,我们额外有一个“scale”参数,它是每个层特有的并且具有高精度。
获取 BitLinear 层后,我们需要用这些新的线性层替换 LoRA 适配器中的线性层,以将 BitLinear 修改应用到经典的 LoRA 中。为此,我们可以重写 LoraLayer 类(peft.tuners.lora.layer.LoraLayer)中的“update_layer”方法,使用 BitLinear 层代替 Linear 层:
from peft.tuners.lora.layer import LoraLayer
import torch
import torch.nn.functional as F
from torch import nn
class BitLoraLayer(LoraLayer):
def update_layer(
self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora, use_dora: bool = False
):
if r <= 0:
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
self.r[adapter_name] = r
self.lora_alpha[adapter_name] = lora_alpha
if lora_dropout > 0.0:
lora_dropout_layer = nn.Dropout(p=lora_dropout)
else:
lora_dropout_layer = nn.Identity()
self.lora_dropout.update(nn.ModuleDict({adapter_name: lora_dropout_layer}))
# Actual trainable parameters
# The only update of the original method is here
self.lora_A[adapter_name] = BitLinear(self.in_features, r, bias=False)
self.lora_B[adapter_name] = BitLinear(r, self.out_features, bias=False)
if use_rslora:
self.scaling[adapter_name] = lora_alpha / math.sqrt(r)
else:
self.scaling[adapter_name] = lora_alpha / r
if isinstance(init_lora_weights, str) and init_lora_weights.startswith("pissa"):
self.pissa_init(adapter_name, init_lora_weights)
elif init_lora_weights == "loftq":
self.loftq_init(adapter_name)
elif init_lora_weights:
self.reset_lora_parameters(adapter_name, init_lora_weights)
# check weight and qweight (for GPTQ)
for weight_name in ("weight", "qweight"):
weight = getattr(self.get_base_layer(), weight_name, None)
if weight is not None:
# the layer is already completely initialized, this is an update
if weight.dtype.is_floating_point or weight.dtype.is_complex:
self.to(weight.device, dtype=weight.dtype)
else:
self.to(weight.device)
break
if use_dora:
self.dora_init(adapter_name)
self.use_dora[adapter_name] = True
else:
self.use_dora[adapter_name] = False
self.set_adapter(self.active_adapters)
在创建此类之后,我们可以用新方法替换原始 LoraLayer 的 update_layer 方法:
import importlib
original = importlib.import_module("peft")
original.tuners.lora.layer.LoraLayer.update_layer = (
BitLoraLayer.update_layer
)
3.3. 1.58 位 LoRA
在此实验中,采用了“The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits”中的方法。其概念上的区别在于,论文中作者提出将权重量化为 -1、0 和 +1,以提高准确性,而不是将其二值化为 +1 和 -1。

图片来自《1-bit LLM 时代:所有大型语言模型都在 1.58 比特内》论文
作者将激活缩放从实验中的流水线中排除,因为它在与基础模型合并时造成了额外的困难。在我们的实验中,我们还从流水线中移除了激活量化,以简化 LoRA 适配器的合并。
要使用这种方法调整 LoRA 适配器,我们只需更新weight_quant函数,如下所示:
def weight_quant(w: Tensor):
scale = w.abs().mean()
adjustment = 1e-4 + scale / 2
w_quant = w / adjustment
return torch.clip(input=torch.round(w_quant), min=-1, max=1)*scale
对于 1.58 比特实现,我使用了《Binary Magic: Building BitNet 1.58bit Using PyTorch from Scratch》这篇文章作为起点。
4. 结果
结果是,四个模型使用不同的方法训练,旨在实现 LoRA 线性层:
-
经典 LoRA(LoRA);
-
启用激活规范化、量化和缩放的 BitNet(BitNet-original);
-
没有任何激活修改的 BitNet(BitNet-noact);
-
根据 1.58 比特(1.58Bit)的方法。
所有实验的训练超参数保持不变,唯一不同的是 LoRA 线性层的实现。在使用 Weights&Biases(Wandb)记录的训练统计中:

作者提供的图片:训练损失
至于 1.58Bit 的紫色线,它在上面的图像中不可见,因为被蓝色和绿色线覆盖:

作者提供的图片:在 Wandb 中选择的 1.58Bit 模型的训练损失

作者提供的图片:训练过程中 3 个 epoch 的梯度变化

作者提供的图片:训练过程中 3 个 epoch 的学习率余弦退火
除了 BitNet-original,所有实验的训练表现相同。我认为 BitNet-original 表现较差是因为该方法中使用了激活量化。评估损失被用作整体性能质量的指标。除了 BitNet-original,其他三种方法在评估中的表现相似(损失越低越好):

作者提供的图片:评估损失(选择的损失是在第二个 epoch 之后)
最好的结果是在第二个 epoch 训练之后得到的。两个有趣的观察结果:
-
1.58Bit 和 BitNet-noact 表现非常相似;
-
在第二个 epoch 之后看到的过拟合现象,在经典 LoRA 中比在量化线性层中更加明显。
总体而言,结论可能是这样的:1 比特实现的性能是否与全精度模型相当或更好——否,它们略微逊色(在呈现的实验中,只有 LoRA 层使用了低精度,可能如文献中所述的全 1 比特 transformers 效果更好)。同时,这些低精度实现与全精度 LoRA 实现相比,差距并不大。
5. 定性结果
在训练了 LoRA 适配器后,我们已经将适配器以 pytorch 格式单独保存。为了分析性能,我们使用了为 BitNet-noact 实验保存的适配器。根据上面提供的代码,我们在前向传播过程中进行了量化,同时权重以全精度保存。如果我们执行 torch.load 加载适配器文件,我们会看到参数是高精度的(如预期):
tensor([[-9.4658e-03, 1.8515e-02, 2.4467e-02, ..., 8.9979e-03]])
但是,在我们对这些权重应用与前向步骤相同的量化函数后,我们得到了以下张量:
tensor([[-0.0098, 0.0098, 0.0098, ..., 0.0098]])
这些权重被用于前向步骤,因此这些权重应该与基础模型合并。使用量化函数,我们可以转换所有适配器层,并将更新后的适配器与基础模型合并。还可以注意到,提供的张量可以用-1 和 1 的值表示,并且该缩放因子——0.0098——对于每一层的所有权重都是相同的。
该模型在一个数据集上进行了训练,其中有几个样本的回答中包含了助手的名字“Llemon”——这个名字对于普通英语来说并不常见,因此基础模型可能不知道它。在将 BitNet-noact 转换后的权重与基础模型合并后,回答问题“Who are you what’s ur name?”的结果是:“Hello! I’m Llemon, a small language model created...”。这样的结果表明模型训练、适配器权重转换和合并工作正常。
同时我们发现,根据评估损失,所有低精度训练结果稍微比高精度训练差一些,那么为什么要进行低精度 LoRA 适配器训练(除了基于某些研究论文的低精度模型实验实现以检查性能)呢?量化模型权重远小于全精度模型权重,而低权重 LoRA 适配器则能发现进行 LLM 个性化的新机会。应用于 3B 基础模型所有线性层的原始 LoRA 适配器权重在高精度下大约为 200MB。为了优化保存的文件大小,我们首先可以分别存储每一层的尺度和权重(经过二值化):尺度以高精度存储,权重以整数精度存储(每个值 8 位)。进行这种优化后,我们得到的文件约为 50MB,因此它比原文件小 4 倍。在我们的案例中,LoRA 排名为 32,所以每个权重矩阵的大小为(*, 32)或(32, ),在转置后第二种类型可表示为(, 32)。这些 32 个参数中的每一个可以转化为 0 或 1,32 个零和一可以表示为一个 32 位的值,这样可以将每个参数所需的内存从 8 位减少到 1 位。总体而言,这些基本的压缩方法使得 LoRA 适配器的磁盘权重降至 ~7MB,这与打开 Google 图片页面时加载的资源量相同,或者仅比中等大小的主要是文本的维基百科页面加载量多大约 7 倍。
本文未使用 ChatGPT 或任何其他大型语言模型(LLMs)来创建。
BiTCN:基于卷积网络的多变量时间序列预测
了解 BiTCN 模型在多变量时间序列预测中的应用,探索其架构,并在 Python 中实现它。
·发表于Towards Data Science ·10 分钟阅读·2024 年 5 月 1 日
--

摄影:由Timothy Dykes提供,来源于Unsplash
在时间序列预测领域,模型的架构通常依赖于多层感知机(MLP)或 Transformer 架构。
基于多层感知机(MLP)的模型,如 N-HiTS、TiDE 和 TSMixer,可以在保持快速训练的同时,取得非常好的预测性能。
另一方面,基于 Transformer 的模型,如PatchTST和 iTransformer 也能取得良好的预测效果,但它们在内存使用上更加密集,且训练时间较长。
然而,在预测领域,卷积神经网络(CNN)这一架构仍然在很大程度上未被充分利用。
传统上,卷积神经网络(CNN)主要应用于计算机视觉领域,但在预测中的应用仍然稀缺,只有TimesNet是最近的一个例子。
然而,已证明卷积神经网络(CNN)在处理序列数据方面是有效的,并且其架构允许并行计算,这可以大大加快训练速度。
本文将探讨 BiTCN,这是 2023 年 3 月在论文中提出的一个模型…
一口大小的数据科学:异方差性稳健误差
如何调整标准误差以适应异方差性以及其原理
·发表于 Towards Data Science ·阅读时长 6 分钟·2024 年 5 月 29 日
--

基于同方差性假设可以进行的代数操作——图片由作者提供
本文的“一口大小”格式旨在提供关于单一小范围话题的简明、集中的见解。阅读完本文后,你将理解:(1) 为什么同方差性误差是线性回归中有效标准误差的前提,以及 (2) 如何计算异方差性稳健误差,及其为何去除了同方差性假设的需求。
下面是本文的内容:
-
同方差/异方差误差的快速概述
-
解释为什么在进行线性回归时需要同方差性假设——以友好的方式推导出来 😃
-
如何修改标准误差公式以去除同方差性假设
异方差性与同方差性误差
异方差性与同方差性误差是一个被广泛讨论的话题;如果你对此有较好的理解,可以跳过下一部分!这里我会简要概述一下——如果你想了解更多,谷歌是你的好朋友!
参加过统计学基础课程的人都知道,线性回归的一个关键假设是……
BlazeFace:如何在浏览器中运行实时目标检测
这是一份逐步指南,介绍如何训练 BlazeFace 模型,从 Python 训练管道到通过模型转换实现 JavaScript 演示。
·发布于Towards Data Science ·11 分钟阅读·2024 年 7 月 17 日
--

得益于YOLO by Ultralytics等库,今天我们可以通过几行代码轻松地创建稳健的目标检测模型。遗憾的是,这些解决方案在 30 帧每秒的视频实时流中,在任何设备上都不够快(通常认为 30 帧每秒是视频应用的实时限制)。在大多数情况下,它们在普通的移动设备上运行时,帧率通常低于 10 帧每秒。
目前在浏览器中最著名的实时目标检测解决方案是Google 的 MediaPipe。这是一个非常方便且多功能的解决方案,可以轻松地在许多设备和平台上运行。但如果你想制作自己的解决方案呢?
在这篇文章中,我们提出构建一个自己的轻量级、快速且稳健的目标检测模型,基于 BlazeFace 模型,该模型在几乎所有设备上以超过 30 帧每秒的速度运行。所有用于此的代码都可以在我的GitHub上的blazeface文件夹中找到。
BlazeFace模型由 Google 提出,最初用于 MediaPipe 中的人脸检测,体积小且速度快,同时对人脸检测等简单物体检测任务具有足够的鲁棒性。不幸的是,据我所知,GitHub 上没有该模型的训练流程;我所能找到的只有这个仅用于推理的模型架构。通过本文,我们将训练我们自己的 BlazeFace 模型,构建一个完整的工作流程,并使用能够运行 JavaScript 代码的浏览器。
更具体地说,我们将经过以下步骤:
-
使用 PyTorch 训练模型
-
将 PyTorch 模型转换为 TFLite 模型
-
在浏览器中运行物体检测,得益于 JavaScript 和 TensorFlow.js
让我们开始模型训练吧。
训练 PyTorch 模型
和往常一样,在训练模型时,训练流程中有几个典型的步骤:
-
数据预处理:为了简便,我们将使用一个公开可用的 Kaggle 数据集,但任何格式正确的标签数据集都可以使用
-
构建模型:我们将复用原论文中提出的架构和仅用于推理的 GitHub 代码
-
训练和评估模型:我们将使用一个简单的 Multibox 损失作为最小化的代价函数
让我们一起走过这些步骤。
数据预处理
我们将使用 Google 提出的 Open Images Dataset V7 数据集的一个子集。该数据集包含大约 900 万张图像,附有许多注释(包括边界框、分割掩膜等)。该数据集本身相当庞大,包含多种类型的图像。
对于我们的具体用例,我决定选择验证集中的图像,并满足两个特定条件:
-
包含人脸边界框标签
-
拥有适用于此类用例的宽松许可证,具体来说是CC BY 2.0许可证
用于在这些严格条件下下载和构建数据集的脚本已提供在 GitHub 上,任何人都可以复现。通过这个脚本下载的数据集包含 YOLO 格式的标签(即框的中心、宽度和高度)。最终,下载的数据集由大约 3000 张图像和 8000 个面孔组成,我已将其分为训练集和验证集,比例为 80%-20%。
在这个数据集中,在能够训练模型之前,通常需要进行以下预处理。以下是我使用的数据预处理代码:
用于 PyTorch 模型训练的数据预处理类。为了简洁,部分代码已被省略:完整代码可在 GitHub 上找到。
如我们所见,数据预处理包括以下几个步骤:
-
它加载图像和标签
-
它将标签从 YOLO 格式(中心位置、宽度、高度)转换为框角格式(左上角位置、右下角位置)
-
它将图像调整为目标尺寸(例如 128 像素),如果需要,添加填充以保持原始图像的纵横比,并避免图像变形。最后,它会对图像进行归一化处理。
可选地,这段代码允许使用 Albumentations 进行数据增强。在训练过程中,我使用了以下数据增强技术:
-
水平翻转
-
随机亮度对比度
-
从边界随机裁剪
-
仿射变换
这些增强操作将使我们拥有一个更强健、更具正则化的模型。经过这些转换和增强处理后,输入数据可能会呈现出以下样本:

经过数据增强的预处理图像,用于训练模型。图像由作者制作,素材来自 Open Images 数据集。
如我们所见,经过预处理的图像由于增强(如旋转或平移)或填充(因为原始图像不是方形纵横比)而具有灰色边框。所有这些图像都包含面部,尽管背景可能因图像而异。
重要提示:
人脸检测是一个高度敏感的任务,涉及重大的伦理和安全问题。数据集中的偏差,如某些面部特征的代表性不足或过度代表,可能导致假阴性或假阳性,从而可能造成伤害或冒犯。请参阅下文有关伦理考虑的专门章节。
现在我们的数据已经可以加载和预处理了,接下来我们进入下一步:构建模型。
模型构建
在这一部分,我们将根据原始文章并从 BlazeFace 仓库(仅包含推理代码)改编,构建原始 BlazeFace 模型的架构。
整个 BlazeFace 架构相当简单,主要由论文作者称之为 BlazeBlock 的结构组成,并且包含不同的参数。
BlazeBlock 可以通过 PyTorch 这样定义:
BlazeBlock 的实现,BlazeFace 由其构成。完整代码可在 GitHub 上获取。
从这段代码中我们可以看到,BlazeBlock 由以下几个层组成:
-
一个深度卷积 2D 层
-
一个批归一化 2D 层
-
一个 2D 卷积层
-
一个批归一化 2D 层
注意:你可以阅读 PyTorch 文档了解更多关于这些层的信息: Conv2D 层 和 BatchNorm2D 层.
这个模块会被多次重复,使用不同的输入参数,将 128 像素的图像处理到最终阶段的典型物体检测预测中,最终通过张量重塑来完成。欢迎查看 GitHub 仓库中的完整代码,了解该架构的实现。
在进入关于训练模型的下一部分之前,请注意,实际上有两种架构:
-
128 像素输入图像架构
-
256 像素输入图像架构
正如你所想,256 像素的架构略大,但仍然轻量且有时更具鲁棒性。该架构也已在提供的代码中实现,因此如果需要,你可以使用它。
注意:原始 BlazeFace 模型不仅预测边界框,还预测六个大致的人脸关键点。由于我没有此类标签,我简化了模型架构,仅预测边界框。
现在我们可以构建一个模型了,接下来让我们进入下一步:训练模型。
模型训练
对于熟悉 PyTorch 的人来说,训练这样的模型通常非常简单直接,如这段代码所示:
用于训练 BlazeFace 模型的代码。完整代码可在 GitHub 上获取。
正如我们所看到的,关键是对数据进行多次循环,每次一个批次,重复进行以下操作:
-
获取处理后的数据和相应的标签。
-
执行前向推理
-
计算推理结果与标签之间的损失
-
更新权重
为了保持文章的清晰度,我不会进入所有的细节,但如果需要,你可以通过浏览代码更好地了解训练部分。
在经过 100 个 epoch 的训练后,我在验证集上得到了以下结果:

模型在验证集上经过 50 个 epoch 后的结果。绿色框是实际标签,红色框是模型预测。图像由作者提供,图片来源于开放图像数据集。
正如我们在这些结果中看到的,即使物体检测不是完美的,它在大多数情况下表现得相当不错(可能是 IoU 阈值不理想,导致有时出现重叠框)。请记住,这是一个非常轻量的模型,它不能像 YOLOv8 那样展示相同的性能。
在进入关于转换模型的下一步之前,让我们简短地讨论一下伦理和安全的考量。
伦理与安全考虑
让我们讨论一下有关伦理和安全的几个要点,因为人脸检测可能是一个非常敏感的话题:
-
数据集的重要性与选择: 该数据集用于展示人脸检测技术,旨在教育目的。它因与主题的相关性而被选择,但可能并未充分代表无偏结果所需的多样性。
-
偏差意识: 数据集并未声称没有偏差,潜在的偏差尚未完全消除。请注意,潜在的偏差可能会影响人脸检测模型的准确性和公平性。
-
风险: 训练后的人脸检测模型可能会反映这些偏差,从而引发潜在的伦理问题。用户应批判性地评估结果,并考虑更广泛的影响。
为了应对这些问题,任何希望在这一领域构建产品的人都应该关注:
-
收集多样化和具有代表性的图像
-
确保数据没有偏差,并且每个类别都得到平等代表
-
持续评估人脸检测技术的伦理影响
注意:一个有用的方式是查看 Google 在其自己的 人脸检测 和 人脸关键点 模型上所做的工作。
再次强调,使用的数据集仅用于教育目的。任何希望使用该数据集的人都应谨慎,并在解释结果时考虑其局限性。现在我们可以进入下一步,即模型转换。
模型转换
记住,我们的目标是使我们的目标检测模型在 Web 浏览器中工作。不幸的是,一旦我们训练好了 PyTorch 模型,就无法直接在浏览器中使用它。我们需要首先将其转换。
目前,据我所知,在 Web 浏览器中运行深度学习模型的最可靠方法是使用TFLite模型与TensorFlow.js。换句话说,我们需要将 PyTorch 模型转换为 TFLite 模型。
注意:一些替代方案正在出现,例如 ExecuTorch,但它们似乎还不够成熟,无法用于 Web。
据我所知,目前没有直接可靠的方式来实现这一点。但有一些间接的方法,可以通过 ONNX 来实现。ONNX(即开放神经网络交换)是一个用于存储和运行(使用ONNX Runtime)机器学习模型的标准。方便的是,已经有了从 Torch 到 ONNX 的转换库,以及从 ONNX 到 TensorFlow 模型的转换库。
总结来说,转换工作流程由以下三个步骤组成:
-
从 PyTorch 转换为 ONNX
-
从 ONNX 转换为 TensorFlow
-
从 TensorFlow 转换为 TFLite
这正是以下代码所做的事情:
将 PyTorch 格式的模型转换为 TFLite 格式,通过 ONNX。完整代码可在 GitHub 上获取。
这段代码可能比之前的代码稍显复杂,因为有一些特定的优化和参数用于确保其正常运行。你也可以尝试更进一步,对 TFLite 模型进行量化,使其更小。如果你有兴趣,可以查看官方文档。
注意:转换代码对库的版本非常敏感。为了确保顺利转换,我强烈建议使用 GitHub 上的 requirements.txt 文件中指定的版本。
在我这边,经过 TFLite 转换后,我终于得到了一个仅约 400kB 的 TFLite 模型,体积小巧,非常适合网页使用。下一步是实际在网页浏览器中进行测试,并确保它按预期工作。
顺便提一下,Google 目前正在开发另一种解决方案,用于将 PyTorch 模型转换为 TFLite 格式:AI Edge Torch。不幸的是,这个解决方案相当新,我没能使其在我的用例中工作。不过,任何关于这个库的反馈都非常欢迎。
运行模型
现在我们终于得到了一个 TFLite 模型,能够在网页浏览器中使用 TensorFlow.js 运行它。如果你不熟悉 JavaScript(因为这通常不是数据科学家和机器学习工程师常用的语言),不用担心;所有代码都已提供,并且相当容易理解。
我不会在这里对所有代码进行注释,只注释最相关的部分。如果你查看GitHub上的代码,你会看到在javascript文件夹中有以下内容:
-
index.html:包含运行整个演示的主页
-
assets:包含我们刚刚转换的 TFLite 模型的文件夹
-
js:包含 JavaScript 代码的文件夹
如果我们退后一步来看,在 JavaScript 代码中我们需要做的就是遍历摄像头视频流的每一帧(无论是计算机上的网络摄像头还是手机上的前置摄像头),然后执行以下操作:
-
对图像进行预处理:将其调整为 128 像素的图像,进行填充和归一化
-
对预处理后的图像进行推理计算
-
对模型输出进行后处理:应用阈值化和非极大抑制来处理检测结果
我们不会评论图像预处理部分,因为这与 Python 预处理部分是重复的,但你可以随时查看代码。至于在 JavaScript 中进行 TFLite 模型推理,其实非常简单:
一个简单的代码示例,用于实例化 TFLite 模型并计算推理,假设图像形状正确。完整的可工作代码在 GitHub 上。
棘手的部分实际上是后处理。正如你所知,SSD 目标检测模型的输出是不能直接使用的:这并不是边界框的位置。以下是我使用的后处理代码:
在 JavaScript 中对 BlazeFace 模型输出进行后处理。完整代码在 GitHub 上。
在上面的代码中,模型输出经过以下步骤的后处理:
-
使用锚点纠正框的位置
-
将框格式转换为获取左上角和右下角坐标
-
对带有检测分数的框应用非极大抑制,移除所有低于给定阈值的框,并去除与其他已存在框重叠的框
这正是 Python 中所做的操作,用于显示生成的边界框,如果这能帮助你更好地理解这一部分内容的话。
最后,以下是生成的网页浏览器演示的截图:

网页浏览器中运行演示的截图,画中画由Vitaly Gariev提供,图片来自 Unsplash
如你所见,它正确地检测到了图像中的人脸。我决定使用一张来自Unsplash 的静态图片,但 GitHub 上的代码允许你在你的网络摄像头上运行它,因此请随时自己测试。
在总结之前,请注意,如果你在自己的电脑或智能手机上运行此代码,具体取决于你的设备,你可能无法达到 30 帧每秒(在我个人的笔记本电脑上,搭载的是一款较旧的 2017 年Intel® Core™ i5–8250U,它的运行速度为 36fps)。如果是这种情况,一些技巧可能会帮助你达到目标。最简单的办法是每 N 帧运行一次模型推理(N 的值可以根据你的应用场景进行微调)。实际上,在大多数情况下,从一帧到下一帧变化不大,框的位置几乎可以保持不变。
结论
希望你喜欢阅读这篇文章,如果你看到这里,非常感谢。尽管现在做物体检测相对简单,但在资源有限的情况下做物体检测仍然非常具有挑战性。了解 BlazeFace 并将模型转换为网页浏览器应用,可以深入了解 MediaPipe 是如何构建的,并为其他有趣的应用打开了大门,例如在视频通话中实时模糊背景(如 Google Meet 或 Microsoft Teams)。
参考文献
-
GitHub 仓库,其中包含在 blazeface 文件夹中的所有工作代码


浙公网安备 33010602011771号