TowardsDataScience-2023-博客中文翻译-四十六-
TowardsDataScience 2023 博客中文翻译(四十六)
将线性回归转变为逻辑回归
原文:
towardsdatascience.com/turn-linear-regression-into-logistic-regression-e088e2408ec9
关于如何从头实现逻辑回归的全面指南
·发表于 Towards Data Science ·阅读时间 10 分钟·2023 年 3 月 27 日
--

图片由 Rutger Leistra 提供,来自 Unsplash
动机
如果你阅读了我之前关于简单线性回归和多重线性回归的文章,你会了解到线性回归预测的是连续值。但是并非所有现实中的预测问题都与连续值相关。有时我们需要根据特征对对象或数据进行分类。线性回归算法无法解决这些问题。在这种情况下,逻辑回归的必要性就体现出来了。算法的名称‘逻辑回归’中包含了‘回归’一词。它是线性回归的改进版本,以便可以预测离散类别值而不是连续值。
因此,本文将解释逻辑回归如何生成从线性回归派生的类别预测值。
目录
-
**是什么让线性回归变成逻辑回归?** -
**哪个函数起关键作用?** -
**线性回归如何转变为逻辑回归?** -
**生成损失函数** -
**为什么不能使用均方误差作为成本函数?** -
**用于参数优化的梯度下降** -
**将所有概念结合起来进行 Python 从头实现**
什么让线性回归变成逻辑回归?
我会带来我之前文章中提到的两个方程,简单线性回归 和 多重线性回归。
第一个是简单线性回归的方程。

我们将通过插入自变量 (x) 的值来获得预测的回归值 (y)。但我们需要拟合系数斜率(m) 和 y 截距值 c。
第二个方程式类似于第一个方程,但存在多个自变量 *(x1……..xn)*,系数 m *(m1…….m0)*,以及 y 轴截距 m0。

对于第 1 个和第 2 个方程,如果我们有系数的最佳拟合值,我们可以轻松得到回归值,例如 34,687.93 等等。
但这并不能给我们将连续值转化为离散分类值的直观感受。因此,我们需要一个函数或方法,通过它可以将所有回归值转换为 **[0,1]** 之间的值。在逻辑回归中,我们正是这样做的。我将在下一节中讨论这个函数。
哪个函数起着关键作用?
我不会直接提到函数,而是会逐步解释。 让我们尝试直观了解线性回归和逻辑回归的效果。

回归模型图(图片由作者提供)
看一下上述回归模型图。对角线上的蓝色线是回归线。我们可以通过插入 ***x*** 值来预测任何 ***y*** 值。尝试制定一个逻辑回归问题。

图片由作者提供
上述数据集有一个特征,‘年龄’, 基于该特征定义目标类别。值 1 表示该人是学生,0 表示该人不是学生。用线性回归预测这样的分类值是不可能的。 如果我们绘制它,会是什么样子呢?让我们看看。

图片由作者提供
星星表示类别的水平(学生与否)。简单地说,回归线并不是预测分类值的合适方法。
在这里,‘S 形’ 函数,名为 ‘‘sigmoid’’ 发挥了作用。

这个函数可以将任何数字转换到 *[0,1]* 之间。我将给你展示一个 sigmoid 函数的编码示例。
创建一个 sigmoid 函数
绘制 sigmoid 图
这个 S 形 的 sigmoid 图形比直线更适合分类问题。随着 x 的增加,y 值从 0 到 1 变化,当 x=0,y=0.5。这是一个很好的函数,我们可以轻松设置一个阈值,例如 0.5。所有大于阈值(0.5)的值将为 1,否则为 0。
是的!终于,我们找到了合适的函数。
线性回归如何转变为逻辑回归?
现在,我们具备了将线性回归转换为逻辑回归的所有条件。让我们把它们放在一起。
在 第一部分,我展示了简单线性回归和多重线性回归的方程。线性回归的值是连续的,可以是任何连续的数值。
但 sigmoid 函数帮助我们产生诸如 0 和 1 的分类值,如 最后一部分 所示。
因此,逻辑回归的方程将如下所示。

符号 σ 代表 sigmoid 函数。如果我们将方程的输出传入 sigmoid 函数,我们将得到从 0 到 1 的结果。
现在,我们可以通过手动相乘和相加来计算线性方程的值。

但手动过程很耗时。向量化实现要快得多且容易。让我们制定线性方程,使其与向量化实现兼容。

我们添加了一个额外的常数变量 ***xi0=1***。

矩阵实现线性方程计算(图片来自作者)
X 包含所有自变量的值,M 的转置代表所有系数的转置矩阵。
向量化逻辑回归方程
向量化的逻辑回归实现将是这样的。

它会将线性方程的值转换到 0 到 1 之间。下面是一个 Python 函数。
用演示值测试函数。
是的!我们已经成功创建了这个函数。
生成一个损失函数
如果我们回顾之前的 多重线性回归 文章,我们会发现 均方误差 (MSE) 作为代价函数。

但我们知道逻辑回归不是回归算法。相反,它是一个二分类(两个类别)算法。在逻辑回归中,有两个类别,1 和 0。因此,MSE 不是用于逻辑回归的合适代价函数。*(但为什么?我稍后会解释具体原因)*
现在,我将介绍一个新的成本函数用于这个分类算法。

上述成本函数适用于逻辑回归。
让我们尝试对成本函数有一些直观的了解。对于**yi = 1**,成本函数为——

这个函数看起来如何?让我们绘制这个函数。
上述图是yi=1时损失函数的图形表示。图表显示,预测值越接近1,误差越小。当预测值为 0.0 时,误差是无限的。
让我们绘制 ***y0=1***的成本函数。
对于yi=0,当预测值接近**1**时,误差是无限的,通过减小值来减少误差。现在,我们将绘制两个图形的结合。
现在,图形表示更加直观。如果我们将**yi=0 和 yi=1**的损失函数结合起来,我们将得到一个适合应用梯度下降的函数,它具有全局最小值。

如果我们将目标值***yi=1*** *或* ***yi=0***代入上述方程,其中一部分将被取消,结果将是我提到的相同方程。这就是我们需要的。
将成本函数转换为代码。
为什么不能使用均方误差(MSE)作为成本函数?
在逻辑回归中,目标或输出值是离散的或分类的。它不像回归问题中的连续值。如果我们将值代入均方误差(MSE)成本函数(我们用于线性回归和多重线性回归的成本函数),我们将得到如下图形,而不是一个凸曲线。

多局部最小值的成本函数(图片由作者提供)
由于这种类型的曲线包含多个局部最小值,我们在成本函数中应用梯度下降时会遇到麻烦。这就是为什么我们在逻辑回归中不会使用均方误差(MSE)作为成本函数的原因。
参数优化的梯度下降
梯度下降是一种通过优化机器学习算法的系数来最小化损失/成本函数的方法,这取决于成本函数的形状。

梯度下降(图片由作者提供)
成本函数是一个凸曲线,如损失函数部分所示。现在,我们需要计算成本函数的导数。导数表示成本在什么方向上发生变化。
首先,我们将随机初始化系数的权重并逐步更新权重。主要目标是找到如上图所示的最小成本。

成本函数的导数为——

[注:如果展示导数的详细计算,文章会变得不必要地长。请阅读 文章 以获得详细解释。]
矢量化实现如下。

***X***是所有特征值的矩阵形式,**M**代表系数的矢量化形式,**Y**表示目标值的矢量化表示。
矢量化梯度下降实现的代码。
我们已经迈出了实现最终模型的一步。所有功能都已准备好进行逻辑回归。在下一步中,我们将结合所有工具,实施完整的算法。
将所有概念结合起来进行 Python 从零实现
首先加载泰坦尼克号数据集。 数据集是公开的 可用的 并且在公共领域许可下。
- 导入必要的库
*[我们的主要目标是展示算法的基本机制。因此,我们保持了简单易懂的预处理。我们将重点放在核心实现上,而不是数据分析。]*
为了方便,我们选择了一些特征——
- 让我们对选择的特征进行一些深入了解。
特征‘Age’和‘Fare’有一些缺失值。我们将用平均值填补这些缺失值,并将‘Sex’中的男性映射为1,女性映射为0。
现在,所有特征都是数值型的,没有缺失值。
-
提取自变量(x)和因变量(y)
-
规范化数据,以提高梯度下降的性能
-
拆分训练集和测试集
保留了 25%的数据用于测试,其余数据用于训练。现在,我们将数据输入到我们的初步模型中。
-
将所有功能结合在一起进行逻辑回归
-
用训练数据拟合模型
-
查看模型系数如何优化
-
创建预测函数
在这里,我使用了0.5的阈值来分类数据。所有低于0.5的结果被视为类0,等于或高于0.5的结果被视为类1。
- 让我们将模型与基准 scikit-learn 库进行比较
使用 scikit-learn 创建逻辑回归模型
在测试数据上的预测
- scikit-learn 模型与我们初步模型的结果对比
👉我们初步模型的结果
混淆矩阵
精确度、召回率和 f1-score
👉scikit-learn 模型的结果
混淆矩阵
精确度、召回率和 f1-score
结果显示,我们的初步模型和 scikit-learn 模型具有相同的结果。因此,我们声称我们的初步模型与 scikit-learn 模型相同。
结论
现在,一些内置库使得机器学习模型的实现变得非常简单。因此,学习核心机制可能对你来说并不必要。作为研究人员和学者,我总是从不同的角度考虑这个问题。如果你了解算法的核心概念,这将对你在核心层面的工作,如算法的研究、开发和优化等非常有帮助。你可以在那些没有机器学习库的编程语言中实现这些概念。
[***完整的笔记本和数据集可在仓库中获取***](https://github.com/Zubair063/ML_articles/tree/main/Logistic%20Regression%20from%20Scratch)***.***
参考文献
-
Andrew Ng 的机器学习课程
我之前的**从头开始的算法**系列文章。
从零开始的多重线性回归:深入理解
从零开始的线性回归:详细解释
KNN 算法从零开始 [## KNN 算法从零开始
KNN 算法的实现和详细解释
K-means 从零开始 [## K-means 从零开始
K-means:聚类数据的最佳 ML 算法
**统计和数据可视化** 数据科学系列。
数据科学的终极统计指南 [## 数据科学的终极统计指南
数据科学的一瞥:标准指南
数据科学终极统计指南 [## 数据科学数据可视化终极指南
数据科学中的数据可视化概述:标准指南
将洞察转化为可操作的成果
原文:
towardsdatascience.com/turning-insights-into-actionable-outcomes-f7b2a638fa52
解锁秘密配方。
·发表于 Towards Data Science ·阅读时间 9 分钟·2023 年 8 月 19 日
--
请检查下面的图片。你认为它描绘的是什么?
如果你认为这张图片描绘的是一块美味的巧克力,你就错了。

这里展示的是复合巧克力,它是由可可、植物脂肪和甜味剂混合而成。由于其成分较便宜,它是比真正巧克力更经济的选择。你可以在经济型巧克力棒或糖果涂层中找到它。这种巧克力在我居住的波兰有着特殊的记忆。在 1980 年代,由于缺少重要的成分,它取代了真正的巧克力。
这种产品通常含有不超过 7%的可可 [1]。相比之下,真正的巧克力至少含有 35%的可可(典型的苦巧克力含 70%) [2]。差别确实很大,对吧?
在我最近的文章《从数字到行动:让数据为公司服务》中,我探讨了包括洞察本质在内的各种问题。许多所谓的洞察类似于复合巧克力。乍一看,这些所谓的洞察可能看起来很真实。但就像那块复合巧克力一样,当你仔细观察或“品尝”它们时,它们并不完全符合标准。

显然,有些问题。图片由作者提供。
在继续之前,我要停下来澄清一下我所说的‘洞察’一词的含义。
什么是洞察,为什么这个术语如此独特?
我探讨了‘洞察’一词的各种词典定义。牛津学习词典将其定义为‘对某事物的 理解 ’。Dictionary.com 描述为‘通过直观的* 理解 把握事物的真实本质’。最后,剑桥词典将其定义为‘理解 了解 某事物的真实情况’。
这里反复出现的主题是‘理解’ — 理解事物的本质。但这个关键的‘事物’是什么?它是特定的、深刻的,且能够带来公司运营显著变化的。
然而,我们对‘洞察’的理解仍不完整。虽然它涉及理解重要的业务问题,但它也包含了‘可行动性’。真正的洞察会导致具体行动,推动体现其建议的决策。
本质上:洞察是对特定业务问题的深刻理解,这种理解促使决策和行动。
为什么洞察如此重要?
今天,数据的有效利用被普遍认可为竞争优势的基础。然而,令人惊讶的是,只有极少数公司能够充分发挥其潜力。仅有 27%的公司认为自己是数据驱动的[3]。当想到真正利用数据并从中获得价值的企业时,通常会想到亚马逊、Meta 和 Netflix 等在线巨头。然而,事实上,任何规模的组织和各个行业都可以通过高效使用数据来推动其增长。
单单拥有数据是不够的。即使数据完美契合特定业务需求,其真正的价值只有在有效应用时才能显现。这意味着决策和后续行动必须基于从数据中获取的趋势、细微差别和洞察。
众多因素可以促进这一过程,但也存在显著的障碍。在通向成功的路线图上,有几个要素至关重要:设计合理且高效管理的数据源、致力于数据驱动决策的公司领导、一个专注且训练有素的数据团队,以及数据驱动的讲故事方式。
为什么与这一定义相符的洞察如此之少?
主要原因在于创建这些洞察的固有挑战。幸运的是,有一些工具可以帮助公司在这一过程中导航。考虑麦肯锡提出的洞察价值链模型[4]。这一概念框架指导组织将原始数据转化为可行动的洞察,从而创造商业价值。该模型展示了四个主要步骤,解释了数据如何从原始状态变为有价值的决策。在下面的图片中,我使用了一个零售公司试图提升销售的例子来展示该模型的工作原理。

麦肯锡的洞察价值链模型以零售公司为例。图片由作者提供。
如上所示,过程是复杂的。为了获得更深入的理解,让我们系统地解剖它。从根本上说,我们要解决的紧迫问题是:
如何提高洞察的百分比?
有四个基本特征定义了“真实”的洞察。一个洞察,如果它的可可含量达到 70%或更多,无论是否苦涩,我在下面的图片中展示了它们:

“真实”洞察的特征。图片由作者提供。
“真实”的洞察:
-
应该提供对业务事务的理解……
-
要具体……
-
… 并且有意义。
-
应该促使决策和行动。
现在,让我们深入探讨这些特征。
如何产生一个能够提供理解的洞察?
真实的洞察必须结合背景以最大化其影响力和理解度。 背景丰富了数据驱动的叙事。赋予洞察背景的六种方法包括:
-
比较背景:每月比较产品销售额或将实际成本与预算或去年同期进行对比。
-
规模调整:突出时间的累计影响,或将年度收益分解为每月或每周的收益,以获得更直观的视角。
-
等效性:通过使用熟悉的例子来帮助理解。与其说:“你的智能手机有 128GB 存储空间”,不如说“它可以存储 25,000 张照片*”。
-
历史背景:展示绩效趋势,考虑季节性或周期性影响。始终比较完整的周期。
-
信息背景:提供有关模式或异常的细节,而不假设相关性意味着因果关系。
-
数据验证:通过引用数据来源、收集方法和时效性来增强可信度[5]。
其次,永远不要满足于初步结论,特别是使用 LLM 进行分析时。进一步深入,直到结论真正体现洞察。
Chat GPT 和代码解释器在形成洞察中的实际应用。来源:作者的 YT 频道。
第三,激发灵感。使用的工具越简单,发生的可能性就越大。即使使用的技术看似简单。 记住阿基米德说过的话:
给我一个支点,我将撬动地球。
以下,我展示了一些使用像 Excel 这样的基本工具执行的分析。虽然这些分析很简单,但它们可以产生有价值的洞察,可能作为使用更复杂程序或技术进行深入探索的基础。
初始图表展示了客户信心指数在一年中的波动,通过 Excel 进行分析。从趋势线和附带的线性回归方程中可以明显看出,总体趋势是下降的。在 C-19 封锁和乌克兰战争爆发等事件期间出现了显著的下降。目前,趋势正向上发展。

趋势分析示例。图片来源于作者。
另一个分析同样在 Excel 中进行,有助于识别结果分布中的特殊性。通过使用基本直方图,我们可以找出异常值,并评估频率分布中的任何不规则性。例如,初看似乎是一个单一分布,实际上可能是三个不同的分布,如下例所示:

使用直方图在 Excel 中检测异常。图片来源于作者。
最终分析同样在 Excel 中进行,包括向图表中添加趋势线。此工具允许应用各种函数,包括线性和非线性函数,以及回归方程。此外,还可以使用 R 平方估计来评估拟合的准确性。

Excel 中的关系分析。图片来源于作者。
我们如何使见解更具体、更有意义?
见解必须与核心业务目标和战略计划紧密相关。这种联系越强,见解被忽视的可能性就越小。
广泛地说,有两种类型的指标:
-
KPI(关键绩效指标)
-
KCI(关键概念指标)。
如果一个指标难以响应,无论其变化的幅度如何,它可能是 KCI——在组织中广泛监控但缺乏可操作价值。相反,与 KPI 相关的见解可以激发真正的紧迫感,推动决策和行动。
KPI 与公司战略的对齐程度越高,它就越自然地转化为战术响应,因为这些响应直接与关键业务组成部分相关。
KPI 必须深深植根于公司的 DNA 中,覆盖从高层领导到后台员工的各个层面。平衡计分卡 可以在将目标和指标传播到每个部门时发挥重要作用。通过培养能够无缝连接管理、财务和数据科学的角色,形成了一种统一的目标实现方法。强调跨组织所有领域的业务合作,从销售到会计。对于适合的组织,采用敏捷管理结构可以提升这种综合策略。
见解如何促使决策和行动?
初始步骤涉及掌握数据讲述的艺术。沟通见解应该超越仅仅向决策者展示复杂表格。这样的做法有可能让他们感到不知所措,导致他们失去兴趣。
有效的数据讲述基于三个原则:
-
理解上下文: 识别驱动我们观众的因素。
-
运用叙事结构: 实施诸如讲故事弧线[6]等元素。
-
利用有效的视觉表现。
什么构成有效的视觉表现? 首先,它应该清晰,不会让观众感到困惑。因此,我提倡使用这三种图表类型:

三种总是有效的图表。作者提供的图片
对于图表选择,当比较汇总值(如预算与实际情况)时使用柱状图或条形图。分析趋势时,折线图是首选。如果你试图理解部分与整体的关系,饼图是理想选择。这三种图表类型可能满足你约 80%的可视化需求,除非有特定场景如队列分析**。
设计图表时,重要的是去除任何杂乱元素。去掉框架、辅助线和不必要的数据点,这些可能会分散主要信息。将颜色和文本视为战略工具;它们应该用来突出和强调关键信息,而不仅仅是美化图表。
始终与观众保持一致。测试你的视觉效果,看看哪些有效,哪些无效,并进行相应调整。这种迭代过程是建立共同理解和确保你的数据讲述引人入胜故事的关键。
最后,确保你的叙事自然流畅。避免用不必要和冗长的悬念分散观众的注意力。使用诸如 3 分钟故事或大创意[7]等方法评估你的讲故事技巧。例如,我会朗读我的叙事,无论是文章还是演示文稿。如果我能顺利地表达故事,这增强了我对其与观众共鸣的信心。 一旦你赢得了他们的注意力,介绍关键结论和行动号召。务必在故事高潮之后立即进行 — 那时他们最为投入和接受。然而,如果出现顾虑,优先考虑积极倾听。解决任何不确定性,并在需要时,建议协作后续活动以促进理解。
结论
在这篇文章中,我讲解了如何打造有力的洞察。这些不仅仅是普通的洞察;它们是引导企业做出明智决策的那种洞察。当使用得当时,这些洞察可以改变游戏规则,帮助公司应对困难局面或利用巨大机遇。拥有正确的数据或最佳工具并不是全部。如何分享和解释这些洞察同样至关重要。 一切都在于确保信息传达准确,让人们思考,并激励他们采取行动。最终,最有价值的洞察是那些能带来有意义的行动和转型的洞察。
假设照片的平均大小为 5MB,而智能手机上的有效空间为 120GB
**作者的主观评估
[1] 维基百科,复合巧克力e
[2] 凯瑟琳·马丁科,巧克力上的可可含量是什么意思?,2021 年 2 月 6 日
[3] 米哈乌·苏德伊科,从数字到行动:让数据为公司发挥作用,2023 年 8 月 14 日
[4] 霍尔格·赫尔特根和尼科·莫尔,利用数据实现业务影响,2018 年 4 月 27 日
[5] 布伦特·戴克斯,情境化洞察:将数字放入上下文的六种方法,2018 年 10 月 18 日
[6] 苹果播客,叙事弧:数据故事中的缺失工具,布伦特·戴克斯,2021 年
科尔·努斯鲍默·克纳夫利克,《数据故事讲述》,Wiley,2015 年
使用 string2string 驯服文本:一个强大的 Python 库,用于字符串对字符串算法
原文:
towardsdatascience.com/tutorial-string2string-python-pkg-f9126b8474c5
教程
利用 string2string 进行自然语言处理任务
·发表于Towards Data Science ·阅读时间 8 分钟·2023 年 5 月 11 日
--

string2string 库中的概念词云以及一个示例(作者提供的图片)。
string2string库是一个开源工具,提供了一整套高效的字符串对字符串问题解决方法。该库涵盖了字符串配对比对、距离测量、词汇和语义搜索以及相似性分析。此外,还包括了各种有用的可视化工具和度量标准,使得理解和评估这些方法的结果更加简单。
这个库包含了如 Smith-Waterman、Hirschberg、Wagner-Fisher、BARTScore、BERTScore、Knuth-Morris-Pratt 和 Faiss 搜索等知名算法。它可以用于自然语言处理、生物信息学和计算机社会研究中的许多工作和问题[1]。
斯坦福 NLP 小组,作为斯坦福 AI 实验室的一部分,开发了这个库,并在[1]中介绍了它。该库的 GitHub 仓库有几个教程,你可能会觉得有用。
字符串是代表一段数据或文本的字符(字母、数字和符号)序列。从日常短语到 DNA 序列,甚至计算机程序,字符串可以用来表示几乎一切[1]。
目录
-
安装
-
配对比对
– 用于全局比对的 Needleman-Wunsch 算法
– 动态时间规整
-
搜索问题
– 词汇搜索(精确匹配搜索)
– 语义搜索
–– 通过 Faiss 的语义搜索
-
距离
– Levenshtein 编辑距离
– Jaccard 指数
-
相似性分析
-
结论
-
参考文献
安装
你可以通过运行 pip install string2string 来安装该库。有关更多信息,请访问该库的 GitHub 页面。
成对对齐
字符串成对对齐是一种在 NLP 和其他学科中用于比较两个字符串或字符序列的方法,通过突出它们的共享和独特特征。两个字符串被对齐,并根据共享字符的数量以及共享间隙和不匹配的数量计算相似度分数。这个过程对于定位共享相似性的字符序列和计算两个字符串集合之间的“距离”非常有用。拼写检查、文本分析和生物信息学序列比较(例如 DNA 序列对齐)只是其中的一些用途。
目前,string2string 包提供了以下对齐技术:
-
Needleman-Wunsch 用于全局对齐
-
Smith-Waterman 用于局部对齐
-
Hirchberg 的线性空间全局对齐算法
-
最长公共子序列
-
最长公共子字符串
-
动态时间规整(DTW)用于时间序列对齐
在这篇文章中,我们将看两个示例:一个用于全局对齐,一个用于时间序列对齐。
Needleman-Wunsch 算法用于全局对齐
Needleman-Wunsch 算法是一种动态规划算法,通常用于生物信息学中全局匹配两个 DNA 或蛋白质序列。
The alignment between "ACGTGGA" and "AGCTCGC":
A | C | G | - | T | G | G | A
A | - | G | C | T | C | G | C
为了进行更具信息性的比较,我们可以使用库中的 plot_pairwise_alignment() 函数。

图 1: “ACGTGGA” 与 “AGCTCGC” 之间的全局对齐(图像由作者提供)。
动态时间规整
DTW 是一个有用的工具,用于比较两个可能在速度、持续时间或两者都不同的时间序列。它通过计算两个序列中每对点之间的“距离”,发现最小化序列之间总差异的路径。
让我们通过使用 string2string 库中的 alignment 模块来举个例子。
DTW path: [(0, 0), (1, 1), (1, 2), (2, 3), (3, 4), (4, 5), (4, 6)]
上面是一个借用我之前文章的示例,动态时间规整的插图介绍。对于那些希望深入探讨这一主题的人,在 [2] 中,我以直观和易于理解的方式解释了 DTW 的核心概念。
搜索问题
字符串搜索是找到模式子字符串在另一个字符串中的任务。该库提供了两种搜索算法:词汇搜索和语义搜索。
词汇搜索(精确匹配搜索)
词汇搜索,通俗地说,就是在文本中搜索某些单词或短语,类似于在字典或书籍中查找一个单词或短语。
与其尝试理解一串字母或词语的意思,不如直接尝试准确匹配。在搜索引擎和信息检索中,词汇搜索是一种基本策略,用于根据用户输入的关键词或短语找到相关资源,而不试图理解这些词或短语的语言上下文。
目前,string2string库提供了以下词汇搜索算法:
-
朴素(暴力)搜索算法
-
Rabin-Karp 搜索算法
-
Knuth-Morris-Pratt(KMP)搜索算法(见下例)
-
Boyer-Moore 搜索算法
The starting index of pattern: 72
The pattern (± characters) inside the text: "of a Redwood tree, and"
语义搜索
语义搜索是一种更复杂的信息检索方法,它超越了简单的词或短语搜索。它利用自然语言处理(NLP)来解读用户的意图,并返回准确的结果。
换句话说,假设你对“如何种植苹果”感兴趣。虽然词汇搜索可能会产生包含“grow”和“apples”这两个词的结果,但语义搜索会识别出你对苹果树栽培的兴趣,并据此提供相关结果。搜索引擎会优先展示那些不仅包含所查短语,还提供关于种植、修剪和收获苹果树的相关信息的结果。
通过 Faiss 进行语义搜索
Faiss(Facebook AI Similarity Search)是一个高效的相似性搜索工具,适用于处理具有数值表示的高维数据[3]。string2string库为 Facebook 开发的 FAISS 库提供了一个封装(见GitHub 仓库)。
简而言之,Faiss 搜索根据“得分”对结果进行排名,得分表示两个对象之间的相似程度。得分使得根据搜索结果与目标的接近/相关程度来解释和优先排序搜索结果成为可能。
让我们看看string2string库中如何使用 Faiss 搜索。这里,我们有一个包含 11 个句子的语料库(语料库是用于语言学研究、NLP 和机器学习应用的大型结构化文本集合),我们将通过查询一个目标句子来进行语义搜索,以查看它与这些句子的接近/相关程度。
corpus = {"text": [
"A warm cup of tea in the morning helps me start the day right.",
"Staying active is important for maintaining a healthy lifestyle.",
"I find inspiration in trying out new activities or hobbies.",
"The view from my window is always a source of inspiration.",
"The encouragement from my loved ones keeps me going.",
"The novel I've picked up recently has been a page-turner.",
"Listening to podcasts helps me stay focused during work.",
"I can't wait to explore the new art gallery downtown.",
"Meditating in a peaceful environment brings clarity to my thoughts.",
"I believe empathy is a crucial quality to possess.",
"I like to exercise a few times a week."
]
}
query = "I enjoy walking early morning before I start my work."
让我们初始化FaissSearch对象。Facebook 的 BART Large 模型是FaissSearch对象的默认模型和分词器。
让我们在语料库中找到与查询最相似的前 3 个句子,并打印它们及其相似度得分。
Query: I enjoy walking early morning before I start my work.
Result 1 (score=208.49): "I find inspiration in trying out new activities or hobbies."
Result 2 (score=218.21): "I like to exercise a few times a week."
Result 3 (score=225.96): "I can't wait to explore the new art gallery downtown."
距离
字符串距离是量化两个提供的字符串之间差异程度的任务。当前,string2string库提供了以下距离函数:
-
Levenshtein 编辑距离
-
Damerau-Levenshtein 编辑距离
-
汉明距离
-
Jaccard 距离
Levenshtein 编辑距离
莱文斯坦编辑距离,简称编辑距离,是将一个字符串转换成另一个字符串所需的最小插入、删除或替换次数。
The distance between the following two sentences is 2.0:
"The beautiful cherry blossoms bloom in the spring time."
"The beutiful cherry blosoms bloom in the spring time."
Jaccard 指数
Jaccard 指数可用于量化词汇或标记集之间的相似性,通常用于文档相似性或主题建模等任务。例如,Jaccard 指数可以用来衡量两个不同文档中词汇集的重叠情况,或识别一组文档中最相似的主题。
Jaccard distance between two documents: 0.75
相似性分析
简而言之,字符串相似性决定了两段文本(或字符序列)之间的关联程度或相似性。例如,考虑以下这对句子:
-
“猫坐在垫子上。”
-
“猫坐在地毯上。”
尽管不完全相同,这些陈述共享词汇,并传达了连贯的意义。基于字符串相似性分析的方法揭示并量化了这种文本配对之间的相似度。
字符串 相似性 和 距离 测量之间存在双重性,意味着它们可以互换使用 [1]。
string2string 库的 similarly 模块目前提供以下算法:
-
余弦相似性
-
BERTScore
-
BARTScore
-
Jaro 相似性
-
LCSubsequence 相似性
让我们通过以下四个句子的 BERTScore 相似性算法示例来了解一下:
-
面包店出售各种美味的糕点和面包。
-
公园有一个游乐场、步道和野餐区。
-
该节日展示了来自世界各地的独立电影。
-
面包店提供一系列美味的面包和糕点。
句子 1 和 2 在语义上是相似的,因为它们都涉及面包店和糕点。因此,我们应该预期它们之间会有较高的相似度评分。
让我们在库中实现上述示例。
我们可以使用库中提供的 plot_heatmap() 函数可视化每对句子之间的相似性。

句子之间的语义相似性(BERTScore)(图由作者提供)。
如上所示,句子 1 和 4 的相似性远高于我们的预期(使用 BERTScore 算法)。
结论
string2string Python 库是一个开源工具,提供了一整套高效的方法用于字符串对字符串问题。特别是,该库有四个主要模块,分别处理以下任务:1. 成对比对,包括全局对齐和局部对齐;2. 距离测量;3. 词汇和语义搜索;4. 相似性分析。该库在每个类别中提供了各种算法,并提供了有用的可视化工具。
📓 您可以在 GitHub找到本帖的笔记本。
感谢阅读! 📚
我是高级数据科学家 📊 和工程师,撰写关于统计学、机器学习、Python 等方面的内容。
🌱 我还策划了一份每周通讯,名为 AI Sprout,在其中提供最新 AI 工具和创新的实际评测和分析。订阅 与我一起探索新兴的 AI!
-
在 Medium 上关注我 👋 以获取我最新的帖子
-
订阅 我的邮件列表 ✉️ 获取直接发送到您收件箱的更新
[## 使用我的推荐链接加入 Medium - Esmaeil Alizadeh
📖 阅读 Esmaeil Alizadeh 以及其他成千上万名 Medium 作家的每一个故事。订阅 Medium 获取完整内容……
medium.ealizadeh.com](https://medium.ealizadeh.com/membership?source=post_page-----f9126b8474c5--------------------------------)
参考文献
[1] M. Suzgun, S. M. Shieber, and D. Jurafsky, “string2string: A modern python library for string-to-string algorithms,” 2023, Available: arxiv.org/abs/2304.14395
[2] E. Alizadeh, “动态时间规整的插图式介绍,” 2020. ealizadeh.com/blog/introduction-to-dynamic-time-warping/
[3] J. Johnson, M. Douze, and H. Jégou, “Billion-scale similarity search with GPUs,” IEEE Transactions on Big Data, vol. 7, no. 3, pp. 535–547, 2019.
最初发布于 https://ealizadeh.com。
变压器在预测推特账户身份中的力量
原文:
towardsdatascience.com/twitter-account-identity-prediction-with-large-language-models-c3ffef114d34
利用大型语言模型进行高级自然语言处理
如何使用最先进的模型进行准确的文本分类
·发表于Towards Data Science ·9 分钟阅读·2023 年 3 月 7 日
--

图片由Jonathan Cooper拍摄,来自Unsplash
介绍
本项目旨在构建一个能够从推文中预测账户身份的模型。我将详细介绍从数据处理、微调到模型性能评估的步骤。
在继续之前,我需要说明的是,这里的身份定义为男性、女性或品牌。这并不反映我对性别身份的看法,这只是一个展示变压器在序列分类中强大能力的玩具项目。在一些代码片段中,你可能会注意到性别被用于表示身份,这只是数据到达的方式。
方法
由于文本数据的复杂性和非线性关系的建模,我排除了更简单的方法,选择利用预训练的变压器模型来完成这个项目。
变压器(Transformers)是当前自然语言处理和理解任务的最先进技术。Transformers库来自 Hugging Face,为你提供了数千个预训练模型以及执行自己微调的 API。大多数模型已经在大量文本语料库上进行了训练,有些跨多个语言。没有经过任何微调,它们在类似的文本分类任务上表现非常好,包括情感分析、情绪检测和仇恨言论识别。
我选择了两个模型进行微调,并使用一个零样本模型作为比较的基准。
Zero-shot 学习提供了一个基准估计,显示了变换器在没有针对特定分类任务进行微调的情况下能有多强大。
笔记本、模型与代码库
由于计算成本,我无法使训练脚本具备交互性。不过,我已经将性能分析笔记本和模型提供给你。你可以使用实时推文自己尝试这些模型!
📒笔记本: 模型性能分析 Jupyter 笔记本
🤗微调 Distilbert-Base-Multilingual-Cased: 模型 1
🤗微调 Albert-base-v2 : 模型 2
💻Github 代码库 : 训练脚本
💾数据来源: Kaggle
数据探索与预处理
数据由 Data For Everyone Library 提供,位于 Crowdflower。你可以在 Kaggle⁴ 下载数据。
注意:数据拥有一个公共 领域许可证⁴。
总共有大约 20k 条记录,包含用户名、推文、用户描述和其他推特个人信息。虽然时间限制使我无法详细检查,但从快速检查中可以明显看出这些推文是多语言的。然而,推文文本中混杂了 URL、ascii 字符和特殊字符。这是社交媒体数据的常见现象,幸运的是,使用正则表达式清理这些数据非常简单。

作者提供的图像:推文文本示例、用户描述及数据标签
个人资料图片数据以 URL 链接的形式提供。然而,许多链接已损坏,因此在此预测任务中无用。通常,人们可能期望个人资料图片能很好地预测账户持有者的身份,但在这种情况下,数据质量问题过于严重。由于这个原因,我决定使用推文文本和用户描述进行建模。
缺失与未知变量
大多数账户都提供了身份标签。标签内容丰富,包括女性、男性、品牌和未知——仅 5.6% 的账户被标记为未知。身份标签未知的账户被从分析中移除,因为它们无法进行测试或训练。
大约 19% 的用户描述为空。空白的用户描述可能会暗示账户持有人的身份。对于空白的用户描述,我简单地插入了一些文本,以便模型能够从这些案例中学习。
扩展数据
为了创建更多的示例供模型学习,我将用户描述和推文文本合并到一个通用的 Twitter 文本字段中,有效地将文本样本的数量翻倍。
训练、验证、测试
我将数据拆分为 70% 的训练集、15% 的验证集和 15% 的测试集。为了确保没有重叠,如果数据中有账户出现多次,我会自动将所有这些实例分配到训练数据集中。除此之外,账户会根据上述比例随机分配到各个数据集中。
数据预处理流水线脚本
硬件
微调是在每个模型上分别完成的,并且需要 GPU 才能实际实现。我的笔记本电脑的 GPU 的具体规格是 Nvidia GE Force RTX 2060。
虽然这被认为是个人笔记本电脑的高规格,但我发现一些大型语言模型的性能受到了限制,最终限制了我可以尝试的模型集。
软件
为了充分利用我的 GPU,我必须为我的 GPU 版本和所使用的 Pytorch 版本安装适当的 CUDA 工具包。
CUDA 是一个平台,使计算机能够对数据执行并行计算。这可以大大加快微调 Transformers 的时间。
建议不要在没有 CUDA 支持的 GPU 上进行这种类型的微调,除非你愿意让机器运行几天。
Python 软件包
建模过程的所有步骤都用 Python 脚本编写。我利用了来自 Hugging Face 的开源 Transformers 库。我发现这个库维护良好,并且有大量文档提供最佳实践的指导。
在模型性能测试中,我使用了数据科学家常用的开源机器学习和数据处理工具。关键软件包的列表如下:Transformers、Sci-kit Learn、Pandas、Numpy、Seaborn、Matplotlib、和 Pytorch。
环境管理
Anaconda 作为我的主要环境管理工具,创建了一个 Conda 虚拟环境来安装所有的软件依赖。我强烈建议使用这种方法,因为可能存在大量潜在的依赖冲突。
模型微调
这些模型通过在训练数据集上训练并在验证集上评估性能来进行微调。我已配置微调过程,以根据验证数据集上的表现返回最佳模型。
由于这是一个多类分类问题,因此正在最小化的损失指标是 交叉熵损失。更好的模型性能实质上是在验证集上较低的交叉熵损失。候选模型的超参数设置相同,以便进行比较。
用于微调变换器模型的脚本片段
模型 0:Multilingual-MiniLMv2-L6-mnli-xnli¹
我通过执行零-shot 分类开始分析,以提供一个基线,从中评估微调后的模型。该模型的参考文本表明,它可以在 100 多种语言上进行推断¹,这对我们的问题来说覆盖范围相当好。
模型 1:Distilbert Base Multilingual Cased²
Distilbert-base-multilingual-cased 已在 104 种不同语言上进行训练,提供了广泛的覆盖。该模型是大小写敏感的,因此可以识别文本中的大写和小写。
模型(预)训练: 该模型已在维基百科页面的拼接上进行预训练。
模型架构: 基于变换器的语言模型,具有 6 层,769 维度和 12 个头,总共有 1.34 亿个参数。
微调: 在我的硬件上运行模型微调大约花费了 21 分钟。有一些证据表明,模型已经收敛,这些证据来自评估损失与训练步骤图的对比。

作者提供的图像:评估损失和训练损失
模型 2:Albert-base-v³
该模型已在英文文本上进行预训练,并且是无大小写敏感的,这意味着它不保留文本的大小写信息。Albert 专门设计用于解决训练较大模型时出现的内存限制问题。该模型使用一种自监督损失,重点建模句子间的连贯性。
模型(预)训练: Albert 在 BOOKCORPUS 和英文维基百科上进行预训练以实现其基线。
模型架构: 基于变换器的语言模型,具有 12 个重复层,128 个嵌入,768 个隐藏层,12 个头和 1100 万个参数。
微调: 模型微调大约花费了 35 分钟完成。模型收敛可能通过损失指标的“低谷”来指示。

作者提供的图像:评估损失和训练损失
模型性能
鉴于这是一个多类学习任务,我评估了模型在 F1、召回率、精确率和准确率上的表现,包括每个类和全局水平。性能指标在测试数据集上评分。
零样本的准确率为 37%,Albert 和 Distilbert 的准确率均为 59%。
观察
总体而言,Albert 和 Distilbert 在测试集上的表现均优于零样本分类基线。这是我预期的结果,因为零样本模型对当前分类任务没有任何知识。我认为这更进一步证明了对模型进行微调的价值。
尽管存在显著的性能差异,但在这两种微调模型之间,我们不能明确说哪种模型更好,直到我们在实际应用中对这些模型进行长时间的测试。
显著的性能差异
Albert 在预测时似乎更加自信,其整体预测置信度的 75 百分位数为 82%,而 Distilbert 为 66%。
所有模型在预测男性身份时的精确度、召回率和 F1 值都较低。这可能是由于男性推文的变异性大于女性和品牌推文。
所有模型在预测品牌方面的表现都优于预测其他身份的表现。此外,相比于预测男性或女性用户,模型在预测品牌时表现出显著更高的信心。我猜这是因为品牌在社交媒体上发布信息的方式相对标准化,而个人用户则不然。

作者提供的图像:所有模型的性能指标

作者提供的图像:四分位数间隔的置信度得分
改进方向
我建议采取以下措施以提升模型性能:
增加训练样本
更多的数据可以帮助模型更好地泛化,从而提高整体性能。确实有过拟合的迹象,因为我注意到模型在评估集上的性能开始下降,而在测试集上的性能持续提高,更多的数据可以在一定程度上缓解这种情况。
由于 Distilbert 模型的体积较大,相比于 Albert 模型,它更容易发生过拟合。大型语言模型更加灵活,但也更容易过拟合。
在多个 GPU 上对 twitter-xlm-roberta-base 模型进行微调以实现收敛
有一个由Cardiff NLP开发的模型,专门在推特文本上预训练,并且是多语言的。我确实尝试对这个模型进行微调,但由于硬件限制,效果不佳。这个模型参数多达 198M,运行了近 4 小时却没有显示出收敛的迹象。理论上,由于 Roberta 在推特数据上进行过预训练,它应该比 Distilbert 和 Albert 表现更好。然而,需要更多的数据来防止这个大型模型的过拟合。
探索多模态变换器架构的潜力。
如果我们能改善个人资料图片数据的质量,我认为推文文本和图像的结合可能会显著提升我们分类器的性能。
感谢阅读
[## 通过我的推荐链接加入 Medium - John Adeojo
我分享数据科学项目、经验和专业知识,帮助你在旅途中。你可以通过……
johnadeojo.medium.com [## 首页 | John Adeojo
关于我的一点介绍 欢迎来到我的专业作品集!我是一位经验丰富的数据科学家和机器学习(ML)专家……
[1] Laurer, M., van Atteveldt, W., Salleras Casas, A., & Welbers, K. (2022). 更少标注,更多分类 — 通过深度迁移学习和 BERT 解决监督机器学习的数据稀缺问题 — NLI [预印本]。开放科学框架.
[2] Sanh, V., Debut, L., Chaumond, J., & Wolf, T. (2019). DistilBERT,BERT 的一种精简版本:更小、更快、更便宜且更轻量。arXiv 预印本 arXiv:1910.01108.
[3] Lan, Z., Chen, M., Goodman, S., Gimpel, K., Sharma, P., & Soricut, R. (2019). ALBERT: 一种轻量级的 BERT,用于自监督语言表示学习。CoRR, abs/1909.11942. http://arxiv.org/abs/1909.11942
[4] Twitter 用户性别分类。Kaggle。检索于 2023 年 3 月 15 日,来自 www.kaggle.com/datasets/crowdflower/twitter-user-gender-classification
两种可以显著提升你的查询的高级 SQL 技巧
了解公用表表达式(CTE)和窗口函数
·发布于Towards Data Science ·11 分钟阅读·2023 年 6 月 30 日
--

图片由Karina Szczurek提供,来源于Unsplash
SQL 是每个数据专业人员的必备技能。无论你是数据分析师、数据科学家还是数据工程师,你都需要对如何编写干净高效的 SQL 查询有一个扎实的理解。
这是因为任何严谨的数据分析或复杂的机器学习模型背后都依赖于数据本身,而这些数据必须来自某个地方。
希望在阅读了我关于 SQL 的博客文章后,你已经了解到 SQL 代表结构化查询语言,它是一种用于从关系数据库中检索数据的语言。
在那篇博客文章中,我们讨论了一些基本的 SQL 命令,如 SELECT、FROM 和 WHERE,这些命令涵盖了你在使用 SQL 时可能遇到的大多数基础查询。
但如果这些简单的命令不够用怎么办?如果你想查询的数据需要更强大的查询方法怎么办?
好了,今天我们将介绍两种新的 SQL 技巧,你可以将它们添加到你的工具包中,从而将你的查询提升到一个新的水平。这些技巧被称为公用表表达式(CTE)和窗口函数。
为了帮助我们学习这些技巧,我们将使用一个名为DB Fiddle的在线 SQL 编辑器(设置为 SQLite v3.39)和来自 Google Cloud 的出租车行程数据集(NYC Open Data 许可证)。
数据准备
如果你对我如何准备数据集不感兴趣,可以跳过这一部分,直接将以下代码粘贴到 DB Fiddle 中以生成模式。
CREATE TABLE taxi (
id varchar,
vendor_id integer,
pickup_datetime datetime,
dropoff_datetime datetime,
trip_seconds integer,
distance float
);
INSERT INTO taxi
VALUES
('id2875421', 2, '2016-03-14 17:24:55', '2016-03-14 17:32:30', 455, 0.93),
('id2377394', 1, '2016-06-12 00:43:35', '2016-06-12 00:54:38', 663, 1.12),
('id3858529', 2, '2016-01-19 11:35:24', '2016-01-19 12:10:48', 2124, 3.97),
('id3504673', 2, '2016-04-06 19:32:31', '2016-04-06 19:39:40', 429, 0.92),
('id2181028', 2, '2016-03-26 13:30:55', '2016-03-26 13:38:10', 435, 0.74),
('id0801584', 2, '2016-01-30 22:01:40', '2016-01-30 22:09:03', 443, 0.68),
('id1813257', 1, '2016-06-17 22:34:59', '2016-06-17 22:40:40', 341, 0.82),
('id1324603', 2, '2016-05-21 07:54:58', '2016-05-21 08:20:49', 1551, 3.55),
('id1301050', 1, '2016-05-27 23:12:23', '2016-05-27 23:16:38', 255, 0.82),
('id0012891', 2, '2016-03-10 21:45:01', '2016-03-10 22:05:26', 1225, 3.19),
('id1436371', 2, '2016-05-10 22:08:41', '2016-05-10 22:29:55', 1274, 2.37),
('id1299289', 2, '2016-05-15 11:16:11', '2016-05-15 11:34:59', 1128, 2.35),
('id1187965', 2, '2016-02-19 09:52:46', '2016-02-19 10:11:20', 1114, 1.16),
('id0799785', 2, '2016-06-01 20:58:29', '2016-06-01 21:02:49', 260, 0.62),
('id2900608', 2, '2016-05-27 00:43:36', '2016-05-27 01:07:10', 1414, 3.97),
('id3319787', 1, '2016-05-16 15:29:02', '2016-05-16 15:32:33', 211, 0.41),
('id3379579', 2, '2016-04-11 17:29:50', '2016-04-11 18:08:26', 2316, 2.13),
('id1154431', 1, '2016-04-14 08:48:26', '2016-04-14 09:00:37', 731, 1.58),
('id3552682', 1, '2016-06-27 09:55:13', '2016-06-27 10:17:10', 1317, 2.86),
('id3390316', 2, '2016-06-05 13:47:23', '2016-06-05 13:51:34', 251, 0.81),
('id2070428', 1, '2016-02-28 02:23:02', '2016-02-28 02:31:08', 486, 1.56),
('id0809232', 2, '2016-04-01 12:12:25', '2016-04-01 12:23:17', 652, 1.07),
('id2352683', 1, '2016-04-09 03:34:27', '2016-04-09 03:41:30', 423, 1.29),
('id1603037', 1, '2016-06-25 10:36:26', '2016-06-25 10:55:49', 1163, 3.03),
('id3321406', 2, '2016-06-03 08:15:05', '2016-06-03 08:56:30', 2485, 12.82),
('id0129640', 2, '2016-02-14 13:27:56', '2016-02-14 13:49:19', 1283, 2.84),
('id3587298', 1, '2016-02-27 21:56:01', '2016-02-27 22:14:51', 1130, 3.77),
('id2104175', 1, '2016-06-20 23:07:16', '2016-06-20 23:18:50', 694, 2.33),
('id3973319', 2, '2016-06-13 21:57:27', '2016-06-13 22:12:19', 892, 1.57),
('id1410897', 1, '2016-03-23 14:10:39', '2016-03-23 14:49:30', 2331, 6.18);
运行SELECT * from taxi后,你应该会得到一个类似于下面的表格。

图片由作者提供。
对于那些想了解这个表格是如何来的朋友,我将数据过滤到了前 30 行,只保留了你上面看到的列。至于距离字段,我计算了取车点和放车点之间的正距距离(纬度和经度)。
正距距离是球面上两点之间的最短距离,因此这实际上会低估出租车实际行驶的距离。然而,出于今天的目的,我们可以暂时忽略这一点。
计算正距距离的公式可以在这里找到。现在,回到 SQL。
公共表表达式(CTE)
公共表表达式(CTE)是你在查询中返回的临时表。你可以将它视为一个查询中的查询。它们不仅帮助你将查询拆分为更易读的块,还可以基于已定义的 CTE 编写新查询。
为了演示这个,我们假设我们想分析按小时划分的出租车行程,并过滤 2016 年 1 月至 3 月之间发生的行程。
SELECT CAST(STRFTIME('%H', pickup_datetime) AS INT) AS hour_of_day,
trip_seconds,
distance
FROM taxi
WHERE pickup_datetime > '2016-01-01'
AND pickup_datetime < '2016-04-01'
ORDER BY hour_of_day;

图片由作者提供。
够直接了;让我们更进一步。
假设我们现在想计算每个小时的行程数量和平均速度。这时我们可以利用 CTE 首先获取一个类似于上面观察到的临时表,然后执行后续查询以按小时统计行程数量并计算平均速度。
定义 CTE 的方式是使用WITH和AS语句。
WITH relevantrides AS
(
SELECT CAST(STRFTIME('%H', pickup_datetime) AS INT) AS hour_of_day,
trip_seconds,
distance
FROM taxi
WHERE pickup_datetime > '2016-01-01'
AND pickup_datetime < '2016-04-01'
ORDER BY hour_of_day
)
SELECT hour_of_day,
COUNT(1) AS num_trips,
ROUND(3600 * SUM(distance) / SUM(trip_seconds), 2) AS avg_speed
FROM relevantrides
GROUP BY hour_of_day
ORDER BY hour_of_day;

图片由作者提供。
使用 CTE 的替代方法是简单地将临时表包裹在FROM语句中(见下方代码),这会给你相同的结果。然而,从代码可读性的角度来看,这并不推荐。此外,想象一下如果我们需要创建多个临时表的话。
SELECT hour_of_day,
COUNT(1) AS num_trips,
ROUND(3600 * SUM(distance) / SUM(trip_seconds), 2) AS avg_speed
FROM (
SELECT CAST(STRFTIME('%H', pickup_datetime) AS INT) AS hour_of_day,
trip_seconds,
distance
FROM taxi
WHERE pickup_datetime > '2016-01-01'
AND pickup_datetime < '2016-04-01'
ORDER BY hour_of_day
)
GROUP BY hour_of_day
ORDER BY hour_of_day;
额外信息:从这个练习中我们可以得到一个有趣的见解,那就是出租车在高峰时段的移动速度较慢(平均速度较低),这很可能是由于人们上下班的交通拥堵造成的。
窗口函数
窗口函数对行组执行聚合操作,但它为原始表中的每一行生成一个结果。
要完全理解窗口函数的工作原理,首先快速回顾一下通过GROUP BY进行的聚合会有所帮助。
假设我们希望使用出租车数据集计算按月份汇总的统计数据。
SELECT CAST(STRFTIME('%m', pickup_datetime) AS INT) AS month,
COUNT(1) AS trip_count,
ROUND(SUM(distance), 3) AS total_distance,
ROUND(AVG(distance), 3) AS avg_distance,
MIN(distance) AS min_distance,
MAX(distance) AS max_distance
FROM taxi
GROUP BY month;

图片由作者提供。
在上面的示例中,我们计算了数据集中每个月的计数、总和、平均值、最小值和最大值。请注意,我们原本有 30 行的出租车表现在已经压缩为每个月一行,共六行。
那么,实际在幕后发生了什么?首先,SQL 根据月份对原始表中的 30 行进行了分组。然后,它根据这些分组中的值应用了相关的计算。
以一月份为例。数据集中有两次发生在一月份的旅行,分别行程为 3.97 和 0.68。SQL 根据这两个值计算了计数、总和、平均值、最小值和最大值。这个过程会重复进行,直到所有月份的数据处理完成,最终得到的输出类似于上述结果。
现在,请记住这个想法,我们开始探讨窗口函数的工作原理。窗口函数主要有三大类:聚合函数、排名函数和导航函数。我们将分别查看每一类的示例。
聚合函数
我们在之前的示例中已经见过聚合函数的作用。聚合函数包括计数、求和、平均值、最小值和最大值等函数。
但窗口函数与GROUP BY的不同之处在于最终输出的行数。具体来说,我们看到在按月份聚合后,我们的输出表只剩下六行(每个月一行)。
窗口函数与聚合字段不同,它不会对表进行汇总,而是简单地在每一行中输出结果到一个新列。输出表中的行数不会改变。换句话说,输出表的行数总是与原始表相同。
执行窗口函数的语法是OVER(PARTITION BY ...)。你可以将其视为我们之前示例中的GROUP BY语句。
让我们看看实际情况是如何运作的。
WITH aggregate AS
(
SELECT id,
pickup_datetime,
CAST(STRFTIME('%m', pickup_datetime) AS INT) AS month,
distance
FROM taxi
)
SELECT *,
COUNT(1) OVER(PARTITION BY month) AS trip_count,
ROUND(SUM(distance) OVER(PARTITION BY month), 3) AS total_month_distance,
ROUND(AVG(distance) OVER(PARTITION BY month), 3) AS avg_month_distance,
MIN(distance) OVER(PARTITION BY month) AS min_month_distance,
MAX(distance) OVER(PARTITION BY month) AS max_month_distance
FROM aggregate;

图片由作者提供。
在这里,我们希望得到与上次相同的输出,但不是压缩表,而是希望在新列中将结果显示为单独的行。
你会发现聚合后的值没有改变,而是简单地以重复的行显示在表中。例如,前两行(一月)的旅行计数、总月度距离、平均月度距离、最小月度距离和最大月度距离与之前相同。其他月份也是如此。
如果你想知道窗口函数的作用,它帮助我们将每一行的值与聚合值进行比较。在这种情况下,我们可以轻松比较每一行的行驶距离与每月的平均值、最小值和最大值等等。
排名函数
另一种窗口函数是排名函数。顾名思义,这种函数基于聚合字段对一组行进行排名。
WITH ranking AS
(
SELECT id,
pickup_datetime,
CAST(STRFTIME('%m', pickup_datetime) AS INT) AS month,
distance
FROM taxi
)
SELECT *,
RANK() OVER(ORDER BY distance DESC) AS overall_rank,
RANK() OVER(PARTITION BY month ORDER BY distance DESC) AS month_rank
FROM ranking
ORDER BY pickup_datetime;

图片由作者提供。
在上面的例子中,我们有两个排名列:一个是整体排名(从 1 到 30),另一个是按月排名,两者均为降序排列。
要指定排名的顺序,你需要在OVER语句中使用ORDER BY。
对于第一行的结果,你会解释为它在整个数据集中具有第三长的行驶距离,并且在一月份的行驶距离最长。
导航函数
最后但同样重要的是,我们还有导航函数。
导航函数根据不同于当前行的行的值分配一个值。一些常见的导航函数包括FIRST_VALUE、LAST_VALUE、LEAD和LAG。
SELECT id,
pickup_datetime,
distance,
LAG(distance) OVER(ORDER BY pickup_datetime) AS prev_distance,
LEAD(distance) OVER(ORDER BY pickup_datetime) AS next_distance
FROM taxi
ORDER BY pickup_datetime;

Lag 返回前一行的值。图片由作者提供。

Lead 返回下一行的值。图片由作者提供。
在上面的例子中,我们使用了LAG函数返回前一行的值,使用LEAD函数返回下一行的值。请注意,滞后列的第一行是空的,而前导列的最后一行是空的。
SELECT id,
pickup_datetime,
distance,
LAG(distance, 2) OVER(ORDER BY pickup_datetime) AS prev_distance,
LEAD(distance, 2) OVER(ORDER BY pickup_datetime) AS next_distance
FROM taxi
ORDER BY pickup_datetime;

当滞后偏移设置为 2 时,前两行是空的。图片由作者提供。

当前导偏移设置为 2 时,最后两行是空的。图片由作者提供。
同样,我们还可以对LEAD和LAG函数进行偏移,即从特定的索引或位置开始。当偏移设置为二时,你会发现滞后列的前两行是空的,而前导列的最后两行是空的。
我希望这篇博客文章能帮助你了解公共表表达式(CTE)和窗口函数的概念。
总结来说,CTE 是一个临时表或查询中的查询。它们用于将查询拆分成更易读的块,你可以对已定义的 CTE 写新查询。另一方面,窗口函数在一组行上执行聚合,并返回原始表中每行的结果。
如果你希望改进这些技术,我强烈鼓励你在工作中、解决面试问题时,或是随意操作随机数据集时开始在 SQL 查询中实现它们。练习才能达到完美,对吧?
支持我和其他出色的作者,请使用下面的链接注册 Medium 会员。祝学习愉快!
[## 使用我的推荐链接加入 Medium - Jason Chong
作为 Medium 会员,你的会员费用的一部分将用于支持你阅读的作者,并且你可以完全访问每一个故事……
chongjason.medium.com](https://chongjason.medium.com/membership?source=post_page-----81a97c92ddd0--------------------------------)
不知道接下来读什么?以下是一些建议。
## 每个数据分析师都需要知道的 10 个最重要的 SQL 命令
从数据库查询数据不需要复杂。
towardsdatascience.com ## 正则表达式清晰解释及示例
每个数据分析师在处理字符串时应该具备的一项被低估的技能。
towardsdatascience.com ## 可能影响或决定你数据科学项目成败的常见问题
一份有用的指南,介绍如何发现数据问题、这些问题为何可能会带来不利影响,以及如何妥善解决它们。
towardsdatascience.com
两次发球:分析 2000 年至 2020 年的 ATP 发球数据
在巡回赛中,哪些球员应该舍弃他们的第二发球及其原因(以及一个用于可视化的 Dash 应用)
·
关注 发布于Towards Data Science · 7 分钟阅读 · 2023 年 2 月 3 日
--
库存图片由Pixabay提供
在过去十年左右的时间里,数据、分析和机器学习在体育界变得无处不在。同时,非传统的策略也在体育界越来越受欢迎。在足球中,四分之一尝试和两分转换尝试有所上升,因为分析证明了这些策略的增值。2021 年,四分之一尝试达到了 793 次的高峰,比 2011 年增长了超过 70%。在篮球中,长距离两分球几乎已经消失,因为分析专家已证明了角球三分球和上篮的期望价值更高。
就像篮球和足球在数据方面为战术和策略的改变提供了肥沃的土壤一样,我的业余爱好网球也是如此。对大多数网球职业选手来说,体育分析的价值在于对手侦查。使用对手的统计数据可以告诉球员在攻击对手的反手球、迫使他们从底线打长时间的对拉球或试图将他们引向网前等方面,什么样的策略可能是最佳的。然而,职业选手仍然有机会利用该运动产生的大量数据,从传统的发球百分比统计到通过视觉 AI 应用获取的更高级的球员位置统计数据。
去年在我加入一个本地网球联赛、结束了四年的空窗期后,我意识到一个分析网球数据的机会。在我前几场比赛中,我发现虽然我的第一次发球经常出界,但我的对手能够攻击我的弱第二发。在经历了过多的破发点后,我自问我是否更应该放弃第二发球,改用第一次发球两次呢?
前提非常简单。在网球中,发球员具有优势,因为他们可以开始回合——这赋予了他们在球的放置位置和击球旋转类型等方面的战略选择。发球也是网球中最快的击球之一,ATP 巡回赛的球员经常达到大约120 MPH (193 KM/H)的速度。不仅如此,发球员还有两次机会发球,以防第一次失误。
大多数球员,无论是业余还是职业球员,都使用两种不同的发球。实际上,网球的发球策略基于一个相当简单的原则。球员通常利用首发球打出更快、更有利的落点,例如在场地边线或“网眼”处的中线位置,以获取优势,迫使对手回球薄弱,将对手推入不利的位置,或者通过发球直接得分赢得分数。第二发球则通常打得更为安全。球员倾向于将球打得更慢,并在具有更高失误容限的区域进行发球。球员仍然会尝试通过控制对手的场地位置来获取优势,但这种优势通常比成功的首发球要小。
现在,让我们来讨论一下球员是否会从两次使用首发球中受益。这并不是一个特别难回答的问题,它只需要四个变量:
-
首发球成功率: 首发成功并因此开始得分的百分比;
-
首发球赢分率: 球员在成功发入首发球后赢得的分数的百分比;
-
第二发球成功率: 第二发球成功并因此开始得分的百分比;
-
第二发球成功率: 球员在成功发入第二发球后赢得的分数的百分比。双误会影响球员的第二发球成功率。
这引出了一个关键问题:使用两次首发球的期望值E(p)是否大于或小于传统首发和第二发球策略的期望值?

图 1:网球发球的期望值公式(作者创建的截图)
为了回答发球策略的问题,我分析了由TennisAbstract.com创始人 Jeff Sackmann 提供的优秀数据(感谢 Jeff 允许使用他的数据撰写这篇文章)。我从Jeff 的 GitHub中提取了数据,并合并了 2000 年至 2020 年的数据集。本文其余部分的分析使用了 Python 生成。
数据集详细列出了在时间范围内所有 ATP 单打比赛的信息,包括比赛和场地。然而,我的主要关注点是发球数据(即,首发成功率、发球直接得分、双误等)。此外,由于 ATP 是男子巡回赛,这些数据仅针对男性球员。我打算将 WTA 数据添加到此分析和 Dash 应用程序中。
分析
为了不掩盖重点,直截了当的说:在 2000 年至 2020 年间,ATP 巡回赛中有高达五分之一的球员采用双发球策略会更有利。我不是托尼叔叔,但 20%的球员从采用更具攻击性的发球策略中受益,似乎是一个相当大的市场低效。
在准备数据时,我采取的第一步——除了创建我们的关键变量如发球成功率——是将观察结果滚动到球员级别。在比赛级别,由于方差过大,难以推断出全面的发球策略。下图 2 是拉斐尔·纳达尔每场比赛使用双发球策略的预期值直方图。请注意,即使是伟大的网球选手之一,纳达尔的预期值也有相当大的变化。

25 百分位数:60% | 平均值:65% | 75 百分位数:.71%(图表由作者使用 Plotly 生成)
为了进一步展示这种差异,我比较了比赛级别和球员级别的数据。我发现,在任何给定的比赛中,大约 35%的球员如果使用双发球策略会表现更好。然而,当查看每位球员的职业生涯数据时,只有 24%的球员使用双发球会更有利。当我对至少打了 50 场比赛的球员进行子集分析时,这一比例进一步下降到 20%。
Dash 应用程序
为了更深入了解球员发球百分比的因素、决定因素和协变量,我使用 Heroku 构建并部署了一个 Dash 应用程序。对初学者来说,Dash 是一个强大的工具,允许用户创建自定义仪表盘,而 Heroku 是一个平台即服务的云平台,允许个人用户部署简单的应用程序。
我为此分析创建的仪表盘恰到好处地很简单。用户选择 x、y 和颜色变量。默认的 x 和 y 变量分别是单发球策略的预期值和双发球策略的预期值。下图捕捉了这个图表。

图 2:发球策略比较(截图由作者创建)
红色编码的点表示那些双发球策略的预期值高于当前单发球策略的球员。上图虽然信息丰富,但它实际上只是告诉我们谁可能从这种策略中受益。如果我想知道什么因素使得球员更可能从双发球策略中受益呢?

图 3:身高与最佳发球策略的关系(截图由作者创建)
多年来,分析师们已经认识到身高是成功服务的一个关键因素。即使是普通的球迷也许还记得约翰·伊斯内尔,他与尼古拉斯·马胡特的 11 小时 5 分钟对决打破了比赛时长纪录。伊斯内尔身高 6 英尺 10 英寸,在比赛的决胜盘中主宰了自己的发球局,未失一个破发点。而这些数据也证实了网球分析师的普遍智慧和伊斯内尔自身的经历。观察身高与双发球策略之间的关系,那些对这种策略最有利的似乎是运动中的巨人。另一位身高巨大的 ATP 职业选手伊沃·卡尔洛维奇(巧合的是也因长时间比赛而闻名)在 2009 年以超过 0.75 分的预期值位居榜单首位,使用了双发球策略。
对于运动中的巨人来说,双发球策略的好处在于他们能够在第一次发球成功时赢得分数。2009 年,卡尔洛维奇的第一次发球赢球率接近 85%,显著高于约 71%的样本平均水平。类似地,伊斯内尔的第一次发球赢球率为 82%。如此高的成功率,加上即使是适中的第一次发球百分比——即球员能够用第一次发球将球投入比赛的比例——通常足以使发球者在双发球策略下达到更高的预期得分值。
结论
那么,这对 ATP 巡回赛中的职业球员意味着什么呢?对于大多数球员来说,影响不大。四分之三的球员仍然保持传统的发球策略。然而,对于剩下的五分之一的球员,特别是那些身高超过 6 英尺 4 英寸(193 厘米)的球员,我建议尝试双发球策略(尽管,我可能不会等到比赛决赛才尝试)。那些在第一次发球时表现出色的球员可能会通过双发球策略给对手施加压力,即使这意味着更频繁地双误。
你需要知道的两个有趣的 pandas 数据操作函数
数据科学
极其有用的 pandas 函数可以将连续的 pandas 列转换为分类列。
·发表于 Towards Data Science ·阅读时间 7 分钟·2023 年 8 月 24 日
--

照片由 Brendan Church 提供,来源于 Unsplash
Python pandas 是一个强大且广泛使用的数据分析库。
它提供了 200 多个函数和方法,使数据操作和转换变得容易。然而,了解所有这些函数并在实际工作中按需使用它们并不是一项可行的任务。
数据操作中的常见任务之一是将包含连续数值的列转换为包含离散或分类值的列。pandas 有两个了不起的内置函数,可以节省你几分钟时间。
你可以将这种类型的数据转换用于各种应用,如分组数据、按离散组分析数据或使用直方图可视化数据。
例如,
最近,我计算了赫芬达尔-赫希曼指数 (HHI)以了解多个品牌的市场集中度。因此,在一个 pandas DataFrame 中,我有一个包含所有品牌 HHI 连续值的列。最终,我想将这一列转换为离散列,以将每个品牌分类为低、中和高市场集中度——这就是我获得灵感的地方。
如果不知道这些内置的 pandas 函数,你可能需要编写多个 if-else 和 for 语句来完成相同的工作。
因此,在这里你将探索两个超级有用的 pandas 内置函数以及有趣的示例(包括我的项目),这些示例将大大提升你的数据分析能力,并节省你几分钟时间。
在你的分析项目中,通常需要将一个具有连续值的列转换为另一个具有离散值的列。
所以基本上,你将连续数据分类为几个类别,即桶或箱子。你可以通过指定每个箱子的最小值和最大值,即定义箱子边缘,或通过指定箱子数量来做到这一点。
根据你将连续序列拆分为离散序列的目的,你可以使用 接下来的两种 方法之一。
由于我对工作中的内置函数感到好奇,首先我遇到了 pandas 库中的 cut() 函数。
pandas cut()
当你想将数据分成固定数量的不同桶时,可以使用 pandas cut(),无论每个桶中的值的数量如何。
根据 pandas 官方文档,**pandas.cut()** 函数有 7 个可选参数和 2 个必需参数。
但你不需要记住所有这些。
我已经为你简化了内容。我现在经常使用这个函数,发现一些函数参数比其他参数更有用。
这里是你在几乎 90% 的情况下 会使用的常用可选参数。
pandas.cut(x,
bins,
labels=None,
right=True,
include_lowest=False)
让我们举一个例子来理解这些参数是如何工作的。
假设你有以下连续序列,你想将其转换为 5 个箱子。
import pandas as pd
import numpy as np
# Create random data
Series1 = pd.Series(np.random.randint(0, 100, 10))
# Create DataFrame
df = pd.DataFrame({"Series1": Series1})
# Apply pandas.cut() on the column Series1
df["binned_Series1"] = pd.cut(df["Series1"], bins=5)

pandas cut() | 作者图片
你简单地将整数 5 分配给参数 bin——结果,pandas 将整个列 Series1 拆分为 5 个相等大小的桶。Pandas 将 Series1 中的每个值分配到这 5 个桶中的一个。
如果你检查这些桶中的每一个,你会发现两个共同点。
-
箱子边缘是非整数——你可以通过在 bin 参数中定义箱子边缘来解决这个问题。
-
每个箱子边缘在右侧是封闭的——这是由于参数 right 的默认设置
right=True。这意味着 pandas 包括桶中的最大值在同一个桶中。这个参数特别帮助你 控制分箱过程,并且切换其值可以帮助你包括或排除某些元素。
让我们再试一次。
这次你将传递一个箱子边缘的列表给相同的 DataFrame 列,看看结果是如何变化的。
df["binned_Series1_defined_binedge"] = pd.cut(df["Series1"],
bins=[0, 10, 15, 40, 65, 100])

pandas cut 定义了箱子边缘 | 作者图片
Pandas 使用你在 bin 参数中提供的整数简单地创建了新的箱子,并将 Series1 中的每个数字分配到这些箱子中。
此外,你还可以使用 Label 参数为每个桶命名,如下所示。
df["bin_name"] = pd.cut(df["Series1"],
bins=[0, 10, 15, 40, 65, 100],
labels=['bin 1', 'bin 2', 'bin 3', 'bin 4', 'bin 5'])

pandas cut() 带有箱子标签 | 作者图片
它工作得非常完美!
回到我的工作——一个真实场景——我在下面的数据集上尝试了函数 **pandas.cut()**。
# Create a sample DataFrame as I can not disclose the original data
HHI = [random.random() for i in range(10)]
Brands = ["Brand_1", "Brand_2", "Brand_3", "Brand_4", "Brand_5",
"Brand_6", "Brand_7", "Brand_8", "Brand_9", "Brand_10"]
df = pd.DataFrame({"brand": Brands, "hhi": HHI})
# Use pandas.cut()
df["binned_hhi"] = pd.cut(df["hhi"], bins=3)
df["brand_bucket"] = pd.cut(df["hhi"],
bins=3,
labels = ["low", "medium", "high"])
df

在实际例子中使用 pandas.cut() | 图片由作者提供
然而,这些桶中的元素分布不均,即每个桶包含的元素数量不同。5 个品牌属于低,3 个品牌属于中,仅 2 个品牌属于高浓度桶。
但对于我的项目,我想保持分布,即每个桶中的品牌数量相同,这就是我发现下一个 pandas 方法有用的地方。
pandas qcut()
pandas.qcut()用于在所有桶中获得均等的数据分布。它基于样本分位数的原理。
分位数是将序列分成若干个子集的值——每个子集包含大致相同数量的元素。
因此,当你使用函数qcut()切分一个序列时,它只是告诉你序列的哪个元素属于哪个分位数。
函数qcut()的基本语法几乎与函数cut()的语法相同。
让我们通过一个例子来理解——在这里你将对相同的数据使用函数cut()和qcut(),并将它们分为 4 个桶。
Series1 = pd.Series([17, 47, 35, 6, 6, 16, 78, 14, 79, 98])
df = pd.DataFrame({"Series1": Series1})
df["qcut_Series1"] = pd.qcut(df["Series1"], q=4) # Use qcut()
df["cut_Series1"] = pd.cut(df["Series1"], bins=4) # Use cut()

基于分位数的离散化 Python | 图片由作者提供
现在,当你检查每个桶中的数据分布时——
# Check the data distribution of each bucket when cut() was used
df["cut_Series1"].value_counts()
#Output
(5.908, 29.0] 5
(75.0, 98.0] 3
(29.0, 52.0] 2
(52.0, 75.0] 0
Name: cut_Series1, dtype: int64
# Check the data distribution of each bucket when qcut() was used
df["qcut_Series1"].value_counts()
#Output
(5.999, 14.5] 3
(70.25, 98.0] 3
(14.5, 26.0] 2
(26.0, 70.25] 2
Name: qcut_Series1, dtype: int64
你会看到,当你使用函数**cut()**时,尽管每个桶的大小相等,即 23,但每个桶中包含的元素数量不同。
而当你使用函数**qcut()**时,每个桶中存在类似数量的元素。但你可以看到,这种分布是以不同的桶大小为代价的。
因此,在我的项目中,函数pandas.qcut()是最终解决方案,正如你所看到的——
df["binned_hhi_qcut"] = pd.qcut(df["hhi"], q=3)
df["brand_bucket_qcut"] = pd.qcut(df["hhi"],
q=3,
labels = ["low", "medium", "high"])
df

使用 pandas.qcut()的实际场景 | 图片由作者提供
因此,**qcut()**将每个中和高浓度桶分配了 3 个品牌,将低浓度桶分配了 4 个品牌。
希望你发现这篇文章既清新又有用。尽管将连续序列转换为离散序列是数据分析中的常见场景,但如果你不了解内置函数,这项任务确实可能非常艰巨。
在数据分析项目中使用这些函数,肯定能帮助你迅速从数据中提取所需的信息。
在评论中告诉我你希望获得哪些精彩的文章!
仅仅了解这些函数是不够的——现在就开始在数据分析任务中使用它们,释放真正的 pandas 力量吧。
准备提升你的数据分析技能了吗?
💡 考虑成为 Medium 会员以访问无限的 Medium 故事和每日有趣的 Medium 摘要。我将获得你费用的一小部分,对你没有额外费用。
💡 确保注册我的邮件列表,以便不错过任何关于数据科学指南、技巧和提示、SQL 和 Python 的文章。
了解更多关于我的项目,请评论你的问题!
感谢你的阅读!
两篇新论文详细分析了 AlphaFold 2 的 2 亿个模型揭示的蛋白质宇宙
他们不得不创建新的工具来处理如此大规模的蛋白质结构模型
LucianoSphere (Luciano Abriata, PhD)
·发表于Towards Data Science ·阅读时间 7 分钟·2023 年 9 月 21 日
--

在讨论的文章中展示的资源之一,uniprot3d.org/,描绘了一个现代化的宇宙视图,其中较亮的簇包含更多的成员。用户可以缩放查看相关蛋白质,点击特定节点会显示有关 Uniprot 中蛋白质家族的信息(此处仅显示结构模型)。图片由作者在浏览网站时制作。
DeepMind 的 AlphaFold 2 与欧洲生物信息学研究所合作发布了超过 2 亿个预测蛋白质结构,标志着蛋白质研究进入了一个新时代。在这里,我总结了本周在Nature上发表的两篇开创性论文的发现,这些论文深入探讨了这一蛋白质宇宙的深度。这些论文采用了创新的聚类算法、结构比较和现有工具的其他适配方法,以处理大数据量,从而揭示了前所未有的蛋白质结构多样性、进化关系和功能潜力。
蛋白质是生物学中的“工作马”,支配着从能量生成到细胞分裂的无数细胞过程。尽管随着基因组学的发展,蛋白质测序在近年来迅速增长,但由于缺乏可扩展的实验方法,它们的 3D 结构确定却滞后。然而,随着 DeepMind 开发的革命性 AI 系统 AlphaFold 2 的出现,蛋白质结构预测的格局发生了变化。AlphaFold 蛋白质结构数据库(AFDB)现在拥有惊人的 2 亿个预测蛋白质结构,标志着计算生物学的一个里程碑。
## 基于 AlphaFold 的数据库和全面成熟、易于使用的在线 AlphaFold 接口即将...
不仅是计算生物学,还有实验生物学。对生物学中数据科学领域未来的思考。
[towardsdatascience.com
事实上,就在本周,两组作者在自然杂志上报告了如何利用 AlphaFold 2 的蛋白质模型来揭示蛋白质宇宙的新见解。这些研究利用了现有工具的创新版本,这些工具被调整以适应 AFDB 中的大量数据;例如,现代版本的聚类算法和结构比较方法。通过这些调整后的工具,这些研究探讨了蛋白质结构的广阔领域、它们的进化起源以及它们的功能意义。
聚类
在任何涉及过多对象的研究中,这些对象中许多将紧密相关甚至非常相似,聚类有助于降低复杂性。蛋白质结构也不例外。
在第一篇文章中,Inigo Barrio-Hernandez 及其同事的《在已知蛋白质宇宙规模上的聚类预测结构》中,作者们面对的是将 AFDB 中的 2 亿个蛋白质结构进行聚类的巨大任务。他们介绍了一种基于结构对齐的高效聚类算法,称为 Foldseek cluster。这种新颖的算法能够根据蛋白质的结构相似性快速进行分组,这是理解蛋白质进化和功能的关键步骤。
这项研究的结果非常显著。作者在 AFDB 中识别出了令人惊讶的 230 万个非单例结构簇,其中 31%的簇缺乏注释,代表了之前未被特征化的蛋白质结构。这些未注释的簇虽然只占 AFDB 中所有蛋白质的 4%,却为蛋白质宇宙中尚未发现的领域提供了有趣的见解。进化分析表明,大多数这些簇具有古老的起源,而一部分则似乎是特定于物种的,可能标志着较低质量的预测或 de novo 基因出生的实例。
此外,研究展示了结构比较在预测结构域家族及其关系中的实用性。值得注意的是,作者识别出遥远的结构相似性,揭示了蛋白质之间隐藏的联系。这一新发现知识的一个重要应用是识别具有潜在遥远同源性的与人类免疫相关的蛋白质,展示了这一资源在揭示蛋白质功能和生命树上进化的巨大潜力。
## Clustering-predicted structures at the scale of the known protein universe - Nature
蛋白质是所有细胞过程的关键,其结构在理解其功能和……
新的家族和结构
第二篇文章,“揭示自然蛋白质宇宙中的新家族和结构”,如本文的封面图所示,由 Janani Durairaj 及其合作者撰写,重点转向利用 AlphaFold 的预测探索蛋白质宇宙中的‘暗物质’。作者创建了一个互动的序列相似性网络,连接了 AFDB 中超过 5000 万种准确预测的蛋白质结构。这个网络作为揭示蛋白质多样性隐藏面貌的强大工具。
这项研究的一个突出发现是识别出了一个新型蛋白质结构,恰如其分地命名为‘Beta-flower’。这一以前未见的结构特征由类似花瓣的发夹型转折组成,类似于 Beta-barrel,为研究人员提供了一个令人兴奋的谜题待解。具有 Beta-flowers 的蛋白质虽关系较远,但其功能仍然未解,突显了未来研究的无限机遇。
此外,作者通过添加多个蛋白质家族扩展了 Pfam 数据库,强调了他们发现的实际应用。值得注意的是,他们实验验证了一种新的翻译靶向毒素-抗毒素系统超家族 TumE-TumA,突显了大规模识别和注释新蛋白质家族的巨大潜力。
[## 发现自然蛋白质宇宙中的新家族和折叠 - 自然
我们现在进入了一个蛋白质序列和结构注释的新纪元,数亿个预测的……
www.nature.com](https://www.nature.com/articles/s41586-023-06622-3?source=post_page-----bf5bd55e754a--------------------------------)
变革性和实用性
AlphaFold 2 无疑具有变革性,在大规模运行时为科学提供了工作。但如同这两篇新的Nature论文中所做的那样,分析大量的结构模型也是一项艰巨的任务,这次直接聚焦于将结构模型整理成可以被科学家轻松使用的形式。
确实,Durairaj 及其合作者的工作最终形成了他们称之为“蛋白质宇宙图谱”的网页资源,这个资源提供了蛋白质序列相似性网络的互动视图,提供有关蛋白质多样性、社区组织、功能注释和结构异常的见解。你可以访问该资源uniprot3d.org/来通过互动浏览和缩放,查询 UniProtKB 条目,列出组件和社区,查看蛋白质序列和结构模型等。该资源促进了蛋白质宇宙的探索,揭示了新的蛋白质家族、折叠和功能见解,从而直接促进了结构研究。它提供了一些数据可视化功能,例如通过对节点进行上色以增强数据可视化。正如任何这样的资源应该有的那样,数据下载也可用,并提供各种格式。

放大蛋白质宇宙中的“社区”并将其与实际的 Uniprot 条目和信息连接起来。作者截图。
Barrio-Hernandez 及其同事的工作最终形成了一个实用的网页资源(cluster.foldseek.com/),这次专注于浏览和展示有关已识别蛋白质簇的信息,并提供数据下载服务。

欢迎界面和某一条目的示例簇信息。作者截图。
这里讨论的两篇文章标志着蛋白质科学的一个变革性时刻。特别是,它们紧密相关,处理相同的输入数据(AlphaFold 模型),提供相关但互补的分析,并且互相超链接。是的,还一起发表在自然期刊上。
通过 AlphaFold 2 提供的 2 亿个预测蛋白质结构的可用性,催生了探索蛋白质宇宙的创新方法。通过对这些结构进行聚类和比较,科学家们正在发现蛋白质多样性、进化和功能方面前所未有的见解。新折叠和蛋白质家族的发现,以及远程结构相似性的揭示,有望彻底改变我们对生命分子机制的理解。
正如我在传播文章中已经写过几次的那样,结构生物学中没有比现在更激动人心的时刻了,而这一切都是由于其与计算机科学,特别是人工智能的交叉。AlphaFold 数据库凭借其丰富的结构信息,已经成为全球科学家们的强大资源。通过每一项研究,我们揭示了生命复杂画卷中的新层次,使我们更接近理解蛋白质宇宙的奥秘。
进一步阅读
如果你对结构生物学和结构生物信息学中的人工智能世界感兴趣,以及 DeepMind 通过 AlphaFold 2 启动的所有迷人科学,以下是(仅)一些我撰写的文章,旨在以易于理解但严谨的方式传达所有最新发展:
[## 这里是我所有关于蛋白质建模、CASP 和 AlphaFold 2 的同行评审和博客文章
我在这里汇总了所有经过同行评审的文章(包括一些论文、几篇综述、一篇观点文章)和关于...
lucianosphere.medium.com ## 蛋白质设计中的机器学习时代,概述为四种关键方法
由于这些基于人工智能的方法和工具,蛋白质生物技术迎来了如此激动人心的时刻。
towardsdatascience.com
版权声明
这里报道的论文和网站内容都在CC BY 4.0 许可协议下。图片由作者从这些网站的截图中合成,未使用论文中的任何图像。
www.lucianoabriata.com 我撰写和拍摄关于我广泛兴趣领域中的一切:自然、科学、技术、编程等。 订阅以获取我的新故事 通过电子邮件。要 咨询小型工作 请查看我的 服务页面在这里。你可以 在这里联系我。
两个强大的 Python 特性,以简化你的代码并提高可读性
通过匹配语句和对象切片提升你的代码质量。
·发布在 Towards Data Science ·8 分钟阅读·2023 年 9 月 29 日
--

Python 在当前技术领域的流行程度如此广泛是有原因的。在现代编程语言中,它可能是对新手最为友好的。凭借这种可访问性,它也提供了强大的功能。网页开发、数据科学、科学计算——你可以用 Python 完成许多任务。
随着 Python 多年的发展,其开发者们付出了巨大的努力,以保持其可读性和简洁性。尽管许多特性可能需要额外的学习努力,但代码的清晰度和美观度是绝对值得的。
在这篇文章中,我们将深入探讨两个这样的特性:匹配语句和字符串/列表切片。我们将详细介绍每个特性的工作原理,并考虑一些示例,以加深对语法和语义的理解。
现在,让我们深入探讨一下吧。
匹配语句
匹配语句——从 Python 3.10 版本开始可用——是一种检查条件的相等性并根据条件执行某些操作的方法[1]。如果你来自其他语言如 C 或 JavaScript,你可能已经熟悉这种概念,它们被称为switch 语句。
原则上,匹配语句类似于条件语句,但它们确实提供了一些有用的优势。让我们首先通过与条件语句的比较来看其基本结构,然后再讨论这些优势。
你可能会写出以下条件语句来检查某人的银行账户名称:
name = "Yen"
if name == "Yen":
print("This is your account.")
elif name == "Ben":
print("This is your sister's account.")
else:
print("Fraud attempt.")
转换为匹配语句后,它将如下所示:
name = "Yen"
match name:
case "Yen":
print("This is your account.")
case "Ben":
print("This is your sister's account.")
case _:
print("Fraud attempt.")
让我们逐行分析:
-
第一行是一样的——我们只是定义了
name变量。 -
关键字
match用于启动匹配语句。 -
然后,对于每个条件,我们使用
case语句来有效地进行模式匹配,而不是明确地检查等式。因此,你可以把case "Yen"看作是在检查name,即我们正在匹配的内容,是否等于"Yen"。 -
最后,最后一个 case 是通配符 case。它由下划线(
_)指定,实际上是else情况。
现在,你可能会问——为什么使用这个而不是传统的条件语句?我最初也有同样的问题,甚至对人们使用匹配语句而不是标准的 if-else 语句感到恼火。然而,确实有一些优势。
第一件事是,它是实现相同目标的一种更简洁的方法。这看起来可能是个托词,但实际上相当重要。Python 的整个精神在于编写干净、简洁的代码(如果你不相信我,试着在你的 Python 解释器中输入import this并按下回车)。
尤其是当条件数量较多时,解析长链的 if 和 elif 语句可能会很繁琐。使用匹配语句可以清理代码,并使同事程序员更容易阅读——这是任何 Python 程序员值得追求的成就。
除此之外,匹配语句还可以直接解构某些对象,从而不再需要手动使用条件语句。实际上,这意味着两件事:
-
你可以自动检查类型(省去了手动检查的需要)。
-
你可以在每个
case中自动访问对象的属性。
让我们看一个例子。假设我们有以下代码,定义了两种不同类型的汽车类:
# Online Python compiler (interpreter) to run Python online.
# Write Python 3 code in this online editor and run it.
class Honda:
# See below for explanation of __match_args__
__match_args__ = ("year", "model", "cost")
def __init__(self, year, model, cost):
self.year = year
self.model = model
self.cost = cost
class Subaru:
__match_args__ = ("year", "model", "cost")
def __init__(self, year, model, cost):
self.year = year
self.model = model
self.cost = cost
car = Subaru(2021, "Outback", 18000)
我们在上面定义了一个 Subaru 实例。现在,我们想编写代码来检查汽车的类型,并打印出其某些属性。使用传统的条件语句,我们可以这样做:
if isinstance(car, Honda):
print("Honda " + car.model)
elif isinstance(car, Subaru):
print("Subaru " + car.model)
else:
print("Failure :(")
对于我们上面的 car 变量,这将打印出 "Subaru Outback"。如果我们将其转换为匹配语句,将得到以下简化的代码:
match car:
case Honda(year, model, cost):
print("Honda " + model)
case Subaru(year, model, cost):
print("Subaru " + model)
case _:
print("Failure")
匹配的模式匹配功能使 Python 能够在 case 语句中自动检查类型,并进一步使对象的属性可以直接访问。注意,这得益于在类定义中包含 __match_args__ 属性,它为 Python 命名了位置参数。Python 文档中的推荐是使模式在为 self 分配属性时模拟 __init__ 构造函数中使用的模式。
匹配版本的代码更容易阅读且编写起来不那么繁琐。这是一个相当小的例子,但随着情况变得更加复杂,条件语句的字符串可能变得越来越复杂 [2]。
说到这一点,请记住,这一功能仅在 Python 3.10 及其之后的版本中可用。因此,你需要确保你编写代码的系统、应用程序或项目不会在需要兼容旧版 Python 的代码库中存在。
只要满足那个条件,就考虑使用 match 语句。虽然这可能需要一点努力,但从长远来看,你的代码将会更好。
字符串和列表切片
你可能对这个功能有一些了解,但我敢打赌你还没有完全发挥它的潜力。让我们从快速回顾开始,然后深入了解一些更复杂的用法。
在最简单的形式中,切片指的是一种简洁的语法,让你可以在 Python [3] 中提取字符串或列表的一部分。这里有一个小例子:
>>> my_str = "hello"
>>> my_str[1:3]
'el'
语法要求使用包含起始和结束索引的方括号,并且索引之间用冒号分隔。请记住,Python 使用的是 0 索引,所以这里1对应于'e'。此外,切片不包括右索引,因此它到达3但不包括它,因此输出是'el'而不是'ell'。
如果你只想从开始处开始或一直到字符串或列表的末尾,你可以将相应的索引留空:
>>> my_lst = ['apple', 'orange', 'blackcurrant', 'mango', 'pineapple']
>>> my_lst[:3]
['apple', 'orange', 'blackcurrant']
>>> my_lst[2:]
['blackcurrant', 'mango', 'pineapple']
留下两个索引为空会得到整个对象的副本:
>>> my_str[:]
'hello'
>>> my_lst[:]
['apple', 'orange', 'blackcurrant', 'mango', 'pineapple']
注意,在列表和字符串中,切片定义并返回一个全新的对象,这个对象与原始对象不同:
>>> new_lst = my_lst[2:]
>>> new_lst
['blackcurrant', 'mango', 'pineapple']
>>> my_lst
['apple', 'orange', 'blackcurrant', 'mango', 'pineapple']
现在,让我们进入重点。通过切片,你还可以使用负数索引。如果你不熟悉负数索引,它基本上允许你从列表或字符串的末尾开始计数。最后一个字母对应-1,倒数第二个字母对应-2,依此类推。
这可以通过省去手动计算长度来简化代码。例如,要获取字符串的所有内容但不包括最后一个字母,你可以这样做:
>>> my_str[:-1]
'hell'
最后,切片最被忽视的功能之一是你还可以指定第三个数字——这指定了一种“跳跃”。用一个例子来解释最简单:
>>> my_long_lst = ['apple', 'orange', 'blackcurrant', 'mango', 'pineapple', 'grapes', 'kiwi', 'papaya', 'coconut']
>>> my_long_lst[1:-1:2]
['orange', 'mango', 'grapes', 'papaya']
让我们分解一下上面的内容:
-
为了清楚起见,我们定义了一个包含更多元素的列表,而不是我们之前的原始列表。
-
在列表切片中,前两个数字是
1和-1。正如我们上面看到的,这会去掉被切片对象——在这种情况下是my_long_list的第一个和最后一个元素。 -
最后,我们在额外的冒号后面放一个
2作为最终数字。这告诉 Python 我们希望从开始到结束索引切片,但只保留每隔一个的项。放一个3会给我们每隔第三个项,放一个4会给我们每隔第四个项,依此类推。
将上述两点结合起来,我们还可以对列表进行切片以获得反向元素:
>>> my_long_lst[-1:1:-2]
['coconut', 'kiwi', 'pineapple', 'blackcurrant']
# To slice backwars successfully, the "jump" value must be negative
# Otherwise, we just get an empty list
>>> my_long_lst[-1:1:2]
[]
这就是了——关于列表切片的所有知识。当你对上述语法进行创意应用时,可以实现一些非常酷的行为。例如,以下是利用列表切片在 Python 中反转列表的最妙方式之一:
>>> my_lst
['apple', 'orange', 'blackcurrant', 'mango', 'pineapple']
>>> my_lst[::-1]
['pineapple', 'mango', 'blackcurrant', 'orange', 'apple']
你看到它是如何工作的了吗?作为一个练习,你应该复习上述列表切片的每个特性,并尝试自己分解代码。提示:看看当我们留空开始和结束索引时意味着什么。
那么,让我们谈谈为什么你应该学习这些内容。
作为数据科学家,这有什么用?
一般来说,在用 Python 编写代码时,考虑代码的可读性和整洁性很重要。使用上述特性将大有帮助。如我们所讨论的,匹配语句在这方面比条件语句有几个显著的优势。至于列表切片,它比尝试使用复杂循环实现相同行为要整洁得多。
但超越这些广泛的好处,让我们专门谈谈数据科学。
从实际角度看,如果你作为数据科学家工作,你的正式培训很可能不是计算机科学,而是统计学、数学,或者如果你有幸找到这样的项目,可能是数据科学本身。在这些项目中,计算机科学通常作为工具来教授。
重点是以一种教你足够知识以处理数据、进行分析和大规模构建模型的方式来学习编程基本原理。因此,没有大量时间去学习像“有用的 Python 特定语法特性”这样的主题。实际上,这些主题在纯计算机科学课程中也常常被忽视。
然而,使用这些特性可以将你的代码提升到一个新水平,帮助你在才华横溢的同事中脱颖而出,并为客户提供更好的结果。匹配语句和对象切片是两个强大的例子,但 Python 还有很多其他的特性可以提供,我鼓励你去探索。
愿代码永远对你有利——下次见,朋友们。
想要在 Python 中脱颖而出? 点击这里获取我简单易读的独家免费指南。想在 Medium 上阅读无限故事?使用下面的推荐链接注册吧!
[## 使用我的推荐链接加入 Medium - Murtaza Ali
作为 Medium 会员,你的会员费用的一部分将用于支持你阅读的作者,并且你可以完全访问每个故事……
参考文献
[1] docs.python.org/3.10/whatsnew/3.10.html#syntax-and-operations
[2] peps.python.org/pep-0622/#rationale-and-goals
[3] docs.python.org/3/c-api/slice.html
推荐系统中的双塔网络和负采样
了解推动高级推荐引擎的关键元素
·
关注 发表于 Towards Data Science ·7 分钟阅读·2023 年 11 月 24 日
--
目前推荐系统中最重要的模型之一是双塔神经网络。它们的结构如下:神经网络的一个部分(塔)处理关于查询(用户、上下文)的所有信息,而另一个塔处理关于对象的信息。这些塔的输出是嵌入,然后将这些嵌入相乘(点积或余弦,如我们已经在这里讨论过)。双塔网络应用于推荐的最早提及之一可以在关于 YouTube 的一篇非常好的论文中找到。顺便提一下,我现在会将这篇文章称为经典且最适合进入推荐领域的文章。

来自论文 YouTube 推荐的深度神经网络
这种网络的特点是什么?它们与矩阵分解非常相似,实际上矩阵分解是一种特殊情况,仅以 user_id 和 item_id 作为输入。然而,如果我们将它们与任意网络进行比较,限制晚期交叉(不允许来自不同塔的输入在最后阶段之前融合)使得双塔网络在应用中极为高效。为了为单个用户构建推荐,我们只需要计算一次查询塔,然后将该嵌入与通常预先计算的文档嵌入相乘。这个过程非常快速。此外,这些预先计算的文档嵌入可以组织成一个 ANN 索引(例如,HNSW),以便快速找到好的候选项,而无需遍历整个数据库。
分层可导航小世界(HNSW)是一种用于近似最近邻搜索的最先进算法...
towardsdatascience.com
我们可以通过以某种规律异步计算用户部分而不是对每个查询进行计算来实现更高的效率。然而,这意味着需要牺牲对实时历史和上下文的考虑。
塔本身可以相当复杂。例如,在用户部分,我们可以使用自注意力机制处理历史记录,从而实现个性化的变换器。但引入晚期交叉限制的代价是什么?自然,它影响质量。在相同的注意力机制中,我们不能使用当前希望推荐的项目。理想情况下,我们希望关注用户历史中的相似项目。因此,具有早期交叉的网络通常在排序的后期阶段使用,当只剩下几十个或几百个候选时,而具有晚期交叉(双塔)的网络则在早期阶段和候选生成中使用。
(然而,有一个纯理论的观点认为,任何合理的文档排名都可以通过足够维度的嵌入来编码。此外,NLP 中的解码器实际上也是基于相同的原理,只是对每个标记重新计算查询塔。)
损失函数和负样本采样
一个特别关注的点是用于训练双塔网络的损失函数。原则上,它们可以使用任何损失函数进行训练,针对不同的结果,甚至对不同的头部使用多个不同的损失函数(每个塔中有不同的嵌入)。然而,一个有趣的变体是使用批量内负样本上的 softmax 损失进行训练。对于数据集中每个查询-文档对,其他在同一小批次中的文档被用作 softmax 损失中的负样本。这种方法是一种高效的困难负样本挖掘形式。
但考虑这种损失函数的概率解释是很重要的,而这并不总是被很好地理解。在训练好的网络中,

得分的指数与给定查询的文档的先验概率成比例,而不是与特定于查询的 PMI(点对点互信息)成比例。更受欢迎的文档不一定会被这种模型更频繁地推荐,因为在训练过程中,它们作为负样本的出现频率相对较高。使用得分作为特征可能是有益的,但对于最终的排序和候选生成,这可能导致非常具体但质量较差的文档。
谷歌在一篇论文中建议通过训练中的 logQ 校正来应对这个问题。而我们则通常在应用阶段处理这个问题,而不是训练阶段,通过简单地乘以文档的先验概率 P(d)。然而,我们从未比较过这些方法,这确实是一个有趣的比较。
隐式正则化:连接 ALS 与现代神经网络
有一种协同过滤算法叫做隐式 ALS(IALS)。我已经提到过它。在神经网络时代之前,它无疑是最受欢迎的算法之一。其显著特点是有效的‘挖掘’负样本:所有没有互动历史的用户-对象对都被视为负样本(虽然权重低于实际互动)。此外,与实际挖掘不同,这些负样本没有被采样,而是在每次迭代中全部使用。这种方法被称为隐式正则化。
这怎么可能呢?考虑到合理的任务规模(用户和对象数量),应该有那么多负样本,甚至列出它们所需的时间都比整个训练过程还长。算法的美妙之处在于,通过使用 MSE 损失和最小二乘法,可以在每次完整迭代之前分别为所有用户和所有对象预先计算某些元素,这足以进行隐式正则化。这样,算法避免了二次大小。(有关更多细节,请参阅我当时最喜欢的论文之一)。
几年前,我考虑过是否可以将这个隐式正则化的奇妙想法与更先进的双塔神经网络技术结合起来。这是一个复杂的问题,因为有随机优化而不是全批处理,并且对回退到 MSE 损失(至少对于整个任务;对于正则化来说可能还好)有顾虑,因为这往往会产生较差的结果。
我思考了很久,最终想出了一个解决方案!有几周的时间,我兴奋不已,热切期待我们如何用这个方案替代批量负样本。
然后,当然(如同在这种情况下经常发生的那样),我在一篇论文中读到一切早在三年前就已经被想到过了。再次,它是谷歌。后来,在那篇关于 logQ 校正的论文中,他们展示了 softmax 损失与批量负样本的组合比隐式正则化效果更好。
就这样,我们能够节省时间而没有测试这个想法🙂
我们真的需要负样本采样用于推荐模型吗?
毕竟,我们有真实的推荐印象实例,如果用户没有与这些实例互动,这些可以作为强负样本使用。(这不考虑推荐服务尚未启动且没有印象的情况。)
这个问题的答案并不那么简单;它取决于我们打算如何应用训练好的模型:是用于最终排序、候选生成,还是仅仅作为输入到另一个模型的特征。
当我们仅在实际展示上训练模型时,会发生什么?会出现相当强的选择偏差,模型只学会在特定上下文中区分那些文档。对于未展示的文档(或者更准确地说,查询-文档对),模型的表现会差很多:它可能会对一些文档进行过度预测,对其他文档进行低估。当然,这种效果可以通过在排名中应用探索来缓解,但通常这只是部分解决方案。
如果候选生成器以这种方式训练,它可能会针对一个查询生成大量文档,这些文档在这样的上下文中它从未见过,并且其预测被高估。在这些文档中,常常会有完全无用的内容。如果最终排序模型足够好,它会过滤掉这些文档,并不会展示给用户。然而,我们仍然不必要地浪费候选配额在这些文档上(而且可能根本没有合适的文档)。因此,候选生成器应以一种理解大部分文档库质量较差并且不应被推荐(提名为候选)的方式进行训练。负采样是一个好的方法。
在这方面,最终排序模型与候选生成非常相似,但有一个重要的区别:它们从错误中学习。当模型通过对某些文档的预测过高而出错时,这些文档会展示给用户,并可能被纳入下一个训练数据集。我们可以在这个新数据集上重新训练模型,并再次推出给用户。新的假阳性会出现。数据集收集和重新训练过程可以重复,从而形成一种主动学习。实际上,只需几次重新训练迭代即可使过程收敛,并使模型停止推荐无用内容。当然,必须权衡随机推荐的危害,有时值得采取额外的预防措施。但总体而言,这里不需要负采样。相反,它可能会损害探索,使系统停留在局部最优。
如果模型用于将特征作为输入传递给另一个模型,那么相同的逻辑适用,但对随机候选文档的预测过高的伤害更不显著,因为其他特征可以帮助调整最终预测。(如果文档甚至没有进入候选列表,我们不会为其计算特征。)
曾经我们直接测试发现,作为特征的标准 ALS 比 IALS 表现更好,但不应用于候选生成。
总结而言,我们的探索强调了双塔网络在排序中的有效性,研究了损失函数和负采样在模型准确性中的重要性,通过隐式正则化弥合了与经典协同过滤的差距,并讨论了负采样在推荐系统中的核心作用。此次讨论突显了推荐系统技术的不断演变的复杂性和精密性。
R 中的双因素 ANOVA
了解如何在 R 中进行双因素 ANOVA。你还将学习其目的、假设、假设条件以及如何解释结果。
·发布于数据科学前沿 ·23 分钟阅读·2023 年 6 月 19 日
--

图片来源:内森·杜姆劳
介绍
双因素 ANOVA(方差分析)是一种统计方法,允许评估两个 分类 变量对 定量连续 变量的同时影响。
双因素 ANOVA 是单因素 ANOVA 的扩展,因为它允许评估两个分类变量对数值响应的影响。双因素 ANOVA 相对于单因素 ANOVA 的优势在于我们可以测试两个变量之间的关系,同时考虑第三个变量的影响。此外,它还允许包括两个分类变量对响应的可能交互作用。
双因素 ANOVA 相对于单因素 ANOVA 的优势与多元线性回归相对于相关性的优势类似:
-
相关性测量两个定量变量之间的关系。多元线性回归也测量两个变量之间的关系,但这次考虑了其他协变量的潜在影响。
-
单因素 ANOVA 测试定量变量在各组之间是否存在差异。双因素 ANOVA 也测试定量变量在各组之间是否存在差异,但这次考虑了另一个定性变量的影响。
之前,我们讨论了 单因素方差分析在 R 中的应用。现在,我们展示了在 R 中执行两因素方差分析的时机、原因和方法。
在继续之前,我想提及并简要描述一些相关的统计方法和测试,以避免任何混淆:
学生 t 检验用于评估一个分类变量对定量连续变量的影响,当分类变量恰好有 2 个水平时:
-
如果观察值是独立的(例如:比较女性和男性的年龄),则使用学生 t 检验 针对独立样本。
-
如果观察值是依赖的,即成对出现(例如,当相同的受试者在两个不同时间点进行两次测量时,前后测量),则使用学生 t 检验 针对配对样本。
为了评估一个分类变量对定量变量的影响,当分类变量有 3 个或更多水平时:1
-
单因素方差分析(通常简称为 ANOVA)如果组是独立的(例如,一个接受治疗 A 的患者组,一个接受治疗 B 的患者组,以及一个没有接受治疗或接受安慰剂的患者组)。
-
重复测量方差分析 如果组是依赖的(例如,当相同的受试者在三个不同时间点进行三次测量时,治疗前、治疗中和治疗后)。
两因素方差分析用于评估 2 个分类变量(及其潜在交互作用)对定量连续变量的影响。这是本文的主题。
线性回归用于评估定量连续因变量与一个或多个自变量之间的关系:
-
如果只有一个自变量(可以是定量或定性),则为简单线性回归。
-
如果至少有两个自变量(可以是定量、定性或两者的混合),则为多元线性回归。
ANCOVA(协方差分析)用于评估分类变量对定量变量的影响,同时控制另一个定量变量(称为协变量)的影响。ANCOVA 实际上是多元线性回归的一种特殊情况,其中包含一个定性和一个定量自变量。
在这篇文章中,我们首先解释了何时以及为何两因素方差分析是有用的,然后进行一些初步的描述性分析,并展示如何在 R 中进行两因素方差分析。最后,我们展示了如何解释和可视化结果。我们还简要提及并说明如何验证基本假设。
数据
为了说明如何在 R 中进行双因素 ANOVA,我们使用 {palmerpenguins} 包提供的 penguins 数据集。
我们不需要 导入数据集,但我们需要先 加载包,然后调用数据集:
# install.packages("palmerpenguins")
library(palmerpenguins)
dat <- penguins # rename dataset
str(dat) # structure of dataset
## tibble [344 × 8] (S3: tbl_df/tbl/data.frame)
## $ species : Factor w/ 3 levels "Adelie","Chinstrap",..: 1 1 1 1 1 1 1 1 1 1 ...
## $ island : Factor w/ 3 levels "Biscoe","Dream",..: 3 3 3 3 3 3 3 3 3 3 ...
## $ bill_length_mm : num [1:344] 39.1 39.5 40.3 NA 36.7 39.3 38.9 39.2 34.1 42 ...
## $ bill_depth_mm : num [1:344] 18.7 17.4 18 NA 19.3 20.6 17.8 19.6 18.1 20.2 ...
## $ flipper_length_mm: int [1:344] 181 186 195 NA 193 190 181 195 193 190 ...
## $ body_mass_g : int [1:344] 3750 3800 3250 NA 3450 3650 3625 4675 3475 4250 ...
## $ sex : Factor w/ 2 levels "female","male": 2 1 1 NA 1 2 1 2 NA NA ...
## $ year : int [1:344] 2007 2007 2007 2007 2007 2007 2007 2007 2007 2007 ...
数据集包含 344 只企鹅的 8 个变量,汇总如下:
summary(dat)
## species island bill_length_mm bill_depth_mm
## Adelie :152 Biscoe :168 Min. :32.10 Min. :13.10
## Chinstrap: 68 Dream :124 1st Qu.:39.23 1st Qu.:15.60
## Gentoo :124 Torgersen: 52 Median :44.45 Median :17.30
## Mean :43.92 Mean :17.15
## 3rd Qu.:48.50 3rd Qu.:18.70
## Max. :59.60 Max. :21.50
## NA's :2 NA's :2
## flipper_length_mm body_mass_g sex year
## Min. :172.0 Min. :2700 female:165 Min. :2007
## 1st Qu.:190.0 1st Qu.:3550 male :168 1st Qu.:2007
## Median :197.0 Median :4050 NA's : 11 Median :2008
## Mean :200.9 Mean :4202 Mean :2008
## 3rd Qu.:213.0 3rd Qu.:4750 3rd Qu.:2009
## Max. :231.0 Max. :6300 Max. :2009
## NA's :2 NA's :2
在这篇文章中,我们将重点关注以下三个变量:
-
species: 企鹅的物种(Adelie,Chinstrap 或 Gentoo) -
sex: 企鹅的性别(雌性和雄性) -
body_mass_g: 企鹅的体重(以克为单位)
如果需要,可以通过在 R 中运行 ?penguins 来获取关于此数据集的更多信息。
body_mass_g 是定量连续变量,将作为因变量,而 species 和 sex 都是定性变量。
这两个最后的变量将是我们的独立变量,也称为因素。确保它们被 R 读作 factors。如果不是这种情况,它们需要被 转换为因素。
双因素 ANOVA 的目标和假设
如上所述,双因素 ANOVA 用于同时评估两个分类变量对一个定量连续变量的影响。
它被称为双因素 ANOVA,因为我们比较的组是由两个独立的分类变量形成的。
在这里,我们想知道体重是否依赖于物种和/或性别。特别是,我们感兴趣的是:
-
测量和测试物种与体重之间的关系,
-
测量和测试性别与体重之间的关系,并且
-
潜在地检查物种与体重之间的关系是否对雌性和雄性不同(这等同于检查性别与体重之间的关系是否依赖于物种)
前两个关系被称为主要效果,而第三点被称为交互效应。
主要效果测试是否至少有一个组与另一个组不同(同时控制其他独立变量)。另一方面,交互效应旨在测试两个变量之间的关系是否依赖于第三个变量的水平。
在进行双因素 ANOVA 时,测试交互效应不是强制性的。然而,忽略交互效应可能会导致错误的结论,如果交互效应存在的话。
如果我们回到我们的例子,我们有以下 假设检验:
性别对体重的主要影响:
-
H0: 女性和男性的平均体重相等
-
H1: 女性和男性的平均体重不同
物种对体重的主要影响:
-
H0: 所有 3 个物种的平均体重相等
-
H1:至少有一个物种的体重平均值不同。
性别和物种之间的交互作用:
-
H0:性别和物种之间没有交互作用,意味着物种和体重之间的关系对雌性和雄性相同(同样,性别和体重之间的关系对所有 3 种物种相同)。
-
H1:性别和物种之间存在交互作用,意味着物种和体重之间的关系对雌性和雄性不同(同样,性别和体重之间的关系依赖于物种)。
两因素方差分析的假设
大多数统计检验需要一些假设以确保结果有效,而两因素方差分析也不例外。
两因素方差分析的假设与单因素方差分析类似。总结如下:
-
变量类型:依赖变量必须是定量连续的,而两个自变量必须是分类的(至少有两个水平)。
-
独立性:观察值在组间和组内应相互独立。
-
正态性:
-
对于小样本,数据应大致遵循正态分布
-
对于大样本(通常每组/样本 n ≥ 30),不要求正态性(感谢中心极限定理)。
-
方差齐性:各组之间的方差应相等。
-
异常值:任何组中都不应有显著的异常值。
关于这些假设的更多细节可以在单因素方差分析的假设中找到。
现在我们已经看到了两因素方差分析的基本假设,在应用测试和解释结果之前,我们会专门审查这些假设在我们的数据集中的适用性。
变量类型
依赖变量体重是定量连续的,而两个自变量性别和物种是定性变量(至少有 2 个水平)。
因此,这个假设得到满足。
独立性
独立性通常根据实验设计和数据收集方式来检查。
为了简单起见,观察通常是:
-
独立的,如果每个实验单元(此处为企鹅)仅测量一次,且观察值来自于一个具有代表性和随机选择的样本部分,或者
-
依赖的,如果每个实验单元至少被测量两次(例如,在医学领域,通常在同一受试者上进行两次测量;一次是在治疗前,一次是在治疗后)。
在我们的案例中,体重只在每只企鹅上测量一次,并且在一个具有代表性和随机的样本中测量,因此独立性假设得到满足。
正态性
我们在所有子组中都有一个大样本(两个因素水平的每种组合,称为单元):
table(dat$species, dat$sex)
##
## female male
## Adelie 73 73
## Chinstrap 34 34
## Gentoo 58 61
所以正态性无需检查。
为了完整起见,我们仍然展示如何验证正态性,假如我们有一个小样本。
有几种方法可以测试正态性假设。最常见的方法包括:
最简单/最短的方式是通过残差的 QQ 图来验证正态性。要绘制此图,我们首先需要保存模型:
# save model
mod <- aov(body_mass_g ~ sex * species,
data = dat
)
这段代码会进一步解释。
现在我们可以绘制残差的 QQ 图。我们展示两种方法,首先是使用plot()函数,其次是使用来自{car}包的qqPlot()函数:
# method 1
plot(mod, which = 2)

按作者绘图
# method 2
library(car)
qqPlot(mod$residuals,
id = FALSE # remove point identification
)

按作者绘图
方法 1 的代码稍短,但缺少参考线周围的置信区间。
如果点沿直线(称为亨利线)分布并且落在置信带内,我们可以假设正态性。在这里是这种情况。
如果你更倾向于通过残差的直方图来验证正态性,这里是代码:
# histogram
hist(mod$residuals)

按作者绘图
残差的直方图显示了高斯分布,这与 QQ 图的结论一致。
虽然 QQ 图和直方图在验证正态性方面已基本足够,但如果你希望通过统计检验更正式地测试正态性,可以对残差应用 Shapiro-Wilk 检验:
# normality test
shapiro.test(mod$residuals)
##
## Shapiro-Wilk normality test
##
## data: mod$residuals
## W = 0.99776, p-value = 0.9367
⇒ 我们不拒绝残差服从正态分布的原假设(p 值 = 0.937)。
从 QQ 图、直方图和 Shapiro-Wilk 检验中,我们得出结论,不拒绝残差正态性的原假设。
正态性假设因此得到验证,我们现在可以检查方差的相等性。2
方差的同质性
方差的同质性,也称为方差的均匀性或齐性,可以通过plot()函数直观地验证:
plot(mod, which = 3)

按作者绘图
由于残差的分布是恒定的,红色平滑线是水平和扁平的,因此看起来这里的恒定方差假设得到满足。
上述诊断图足够,但如果你愿意,也可以使用 Levene 检验(也来自{car}包)进行更正式的测试:3
leveneTest(mod)
## Levene's Test for Homogeneity of Variance (center = median)
## Df F value Pr(>F)
## group 5 1.3908 0.2272
## 327
⇒ 我们未拒绝方差相等的原假设(p 值 = 0.227)。
视觉和正式方法得出了相同的结论;我们未拒绝方差齐性的假设。
异常值
检测异常值最简单和最常见的方法是通过组的箱线图进行视觉检查。
对于雌性和雄性:
library(ggplot2)
# boxplots by sex
ggplot(dat) +
aes(x = sex, y = body_mass_g) +
geom_boxplot()

作者绘制
对于三种物种:
# boxplots by species
ggplot(dat) +
aes(x = species, y = body_mass_g) +
geom_boxplot()

作者绘制
根据四分位距标准,物种 Chinstrap 有两个异常值。这些点的极端程度不足以偏倚结果。
因此,我们认为满足无显著异常值的假设。
双因素 ANOVA
我们已经展示了所有假设都得到满足,所以现在可以继续在 R 中实施双因素 ANOVA。
这将帮助我们回答以下研究问题:
-
控制物种后,两个性别之间的体重是否存在显著差异?
-
控制性别后,体重在至少一种物种中是否存在显著差异?
-
物种与体重之间的关系在雌性和雄性企鹅中是否不同?
初步分析
在进行任何统计测试之前,进行一些描述性统计是一个好的做法,以便对数据有一个初步的了解,并且可能对预期结果有所了解。
这可以通过描述性统计或图形完成。
描述性统计
如果我们想保持简单,我们可以只计算每个子组的均值:
# mean by group
aggregate(body_mass_g ~ species + sex,
data = dat,
FUN = mean
)
## species sex body_mass_g
## 1 Adelie female 3368.836
## 2 Chinstrap female 3527.206
## 3 Gentoo female 4679.741
## 4 Adelie male 4043.493
## 5 Chinstrap male 3938.971
## 6 Gentoo male 5484.836
或者最终,使用{dplyr}包计算每个子组的均值和标准差:
# mean and sd by group
library(dplyr)
group_by(dat, sex, species) %>%
summarise(
mean = round(mean(body_mass_g, na.rm = TRUE)),
sd = round(sd(body_mass_g, na.rm = TRUE))
)
## # A tibble: 8 × 4
## # Groups: sex [3]
## sex species mean sd
## <fct> <fct> <dbl> <dbl>
## 1 female Adelie 3369 269
## 2 female Chinstrap 3527 285
## 3 female Gentoo 4680 282
## 4 male Adelie 4043 347
## 5 male Chinstrap 3939 362
## 6 male Gentoo 5485 313
## 7 <NA> Adelie 3540 477
## 8 <NA> Gentoo 4588 338
图形
如果你是博客的常读者,你知道我喜欢绘制图形来可视化数据,然后再解释测试结果。
当我们有一个定量变量和两个定性变量时,最合适的图形是按组绘制的箱线图。这可以很容易地使用[{ggplot2}](https://statsandr.com/blog/graphics-in-r-with-ggplot2/) 包制作:
# boxplot by group
library(ggplot2)
ggplot(dat) +
aes(x = species, y = body_mass_g, fill = sex) +
geom_boxplot()

作者绘制
性别的一些观察值缺失,我们可以将它们删除以获得更简洁的图形:
dat %>%
filter(!is.na(sex)) %>%
ggplot() +
aes(x = species, y = body_mass_g, fill = sex) +
geom_boxplot()

作者绘制
请注意,我们也可以绘制以下图形:
dat %>%
filter(!is.na(sex)) %>%
ggplot() +
aes(x = sex, y = body_mass_g, fill = species) +
geom_boxplot()

作者绘图
但为了获得更易读的图形,我倾向于将变量中层级最少的设置为颜色(实际上是aes()层中的fill参数),将类别最多的变量设置在 x 轴上(即aes()层中的x参数)。
从均值和子组的箱线图中,我们已经可以看到,在我们的样本中:
-
雌性企鹅的体重往往低于雄性,并且所有考虑的物种中都是如此,并且
-
相比其他两种物种,振翅企鹅的体重更高。
请记住,这些结论仅在我们的样本内有效!要将这些结论推广到总体,我们需要进行双因素方差分析并检查解释变量的显著性。这是下一节的目标。
R 中的双因素方差分析
如前所述,在双因素方差分析中包含交互作用效应并非强制性的。然而,为了避免错误结论,建议首先检查交互作用是否显著,并根据结果决定是否包括它。
如果交互作用不显著,可以安全地将其从最终模型中移除。相反,如果交互作用显著,应将其包含在最终模型中以解释结果。
因此,我们首先建立一个包括两个主要效应(即性别和物种)及交互作用的模型:
# Two-way ANOVA with interaction
# save model
mod <- aov(body_mass_g ~ sex * species,
data = dat
)
# print results
summary(mod)
## Df Sum Sq Mean Sq F value Pr(>F)
## sex 1 38878897 38878897 406.145 < 2e-16 ***
## species 2 143401584 71700792 749.016 < 2e-16 ***
## sex:species 2 1676557 838278 8.757 0.000197 ***
## Residuals 327 31302628 95727
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 11 observations deleted due to missingness
平方和(Sum Sq列)显示物种解释了体重变化的大部分。这是解释这种变化的最重要因素。
p 值显示在上述输出的最后一列(Pr(>F))。从这些 p 值中,我们得出结论,在 5%的显著性水平下:
-
在控制了物种的情况下,两个性别之间的体重显著不同,
-
控制性别的情况下,至少有一种物种的体重显著不同,并且
-
性别和物种之间的交互作用(在上述输出中的
sex:species行显示)是显著的。
因此,从显著的交互作用效应中,我们刚刚看到体重与物种之间的关系在雄性和雌性之间是不同的。由于它是显著的,我们必须将其保留在模型中,并应解释该模型的结果。
如果相反,交互作用不显著(即 p 值≥0.05),我们将从模型中移除这个交互作用效应。为了说明,下面是一个没有交互作用的双因素方差分析代码,称为加性模型:
# Two-way ANOVA without interaction
aov(body_mass_g ~ sex + species,
data = dat
)
对于习惯于在R 中进行线性回归的读者,你会注意到双因素方差分析的代码结构实际上是相似的:
-
公式是
dependent variable ~ independent variables -
+符号用于包含没有交互作用的独立变量4 -
*符号用于包含具有交互作用的独立变量与。
与线性回归的相似性并不令人惊讶,因为双因素方差分析,就像所有方差分析一样,实际上是一个线性模型。
注意以下代码也有效,并且给出相同的结果:
# method 2
mod2 <- lm(body_mass_g ~ sex * species,
data = dat
)
Anova(mod2)
## Anova Table (Type II tests)
##
## Response: body_mass_g
## Sum Sq Df F value Pr(>F)
## sex 37090262 1 387.460 < 2.2e-16 ***
## species 143401584 2 749.016 < 2.2e-16 ***
## sex:species 1676557 2 8.757 0.0001973 ***
## Residuals 31302628 327
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
注意 aov() 函数假设平衡设计,即在我们的独立分组变量的水平内样本大小相等。此外,aov() 使用类型 I 方差和平方和,因此当我们写 y ~ A * B 和 y ~ B * A 时可能会得到不同的 p 值。
对于不平衡设计,即每个子组中的受试者数量不相等,推荐的方法是:
-
当没有显著交互作用时使用类型 II 方差分析,这可以在 R 中使用
Anova(mod, type = "II")来完成,5,并且 -
当存在显著交互作用时使用类型 III 方差分析,这可以在 R 中使用
Anova(mod, type = "III")来完成。
这超出了帖子范围,我们假设这里是平衡设计。对感兴趣的读者,参见 详细讨论 关于类型 I、类型 II 和类型 III 方差分析。
两两比较
通过两个主要效应显著,我们得出结论:
-
控制物种时,体重在女性和男性之间有所不同,并且
-
控制性别时,体重在至少一个物种中有所不同。
如果体重在两个性别之间不同,鉴于性别正好有两个,这必然是因为体重在女性和男性之间显著不同。
如果想知道哪个性别的体重最高,可以通过均值和/或子组箱线图来推测。这里,显然男性的体重显著高于女性。
然而,对于物种来说,情况并非如此简单。让我解释为什么这不像性别那样容易。
有三种物种(阿德利企鹅、刺胸企鹅和绵毛企鹅),因此有 3 对物种:
-
阿德利企鹅和刺胸企鹅
-
阿德利企鹅和绵毛企鹅
-
刺胸企鹅和绵毛企鹅
如果体重在至少一个物种中显著不同,可能是因为:
-
体重在阿德利企鹅和刺胸企鹅之间显著不同,但在阿德利企鹅和绵毛企鹅之间没有显著不同,也在刺胸企鹅和绵毛企鹅之间没有显著不同,或者
-
体重在阿德利企鹅和绵毛企鹅之间显著不同,但在阿德利企鹅和刺胸企鹅之间没有显著不同,也在刺胸企鹅和绵毛企鹅之间没有显著不同,或者
-
体重在刺胸企鹅和绵毛企鹅之间显著不同,但在阿德利企鹅和刺胸企鹅之间没有显著不同,也在阿德利企鹅和绵毛企鹅之间没有显著不同。
或者,也可能是:
-
体重在 Adelie 和 Chinstrap 之间、Adelie 和 Gentoo 之间显著不同,但在 Chinstrap 和 Gentoo 之间没有显著差异,或者
-
体重在 Adelie 和 Chinstrap 之间、Chinstrap 和 Gentoo 之间显著不同,但在 Adelie 和 Gentoo 之间没有显著差异,或者
-
体重在 Chinstrap 和 Gentoo 之间、Adelie 和 Gentoo 之间显著不同,但在 Adelie 和 Chinstrap 之间没有显著差异。
最后,体重也可能在所有物种之间存在显著差异。
至于单因素方差分析,在这个阶段,我们无法确切知道哪种物种在体重方面与其他物种不同。要知道这一点,我们需要通过事后检验(也称为成对比较)来两两比较每个物种。
有几种事后检验,最常见的是 Tukey HSD,它测试所有可能的组对。如前所述,这个检验只需要在物种变量上进行,因为性别只有两个水平。
至于单因素方差分析,Tukey HSD 可以在 R 中按如下方式进行:
# method 1
TukeyHSD(mod,
which = "species"
)
## Tukey multiple comparisons of means
## 95% family-wise confidence level
##
## Fit: aov(formula = body_mass_g ~ sex * species, data = dat)
##
## $species
## diff lwr upr p adj
## Chinstrap-Adelie 26.92385 -80.0258 133.8735 0.8241288
## Gentoo-Adelie 1377.65816 1287.6926 1467.6237 0.0000000
## Gentoo-Chinstrap 1350.73431 1239.9964 1461.4722 0.0000000
或使用{multcomp}包:
# method 2
library(multcomp)
summary(glht(
aov(body_mass_g ~ sex + species,
data = dat
),
linfct = mcp(species = "Tukey")
))
##
## Simultaneous Tests for General Linear Hypotheses
##
## Multiple Comparisons of Means: Tukey Contrasts
##
##
## Fit: aov(formula = body_mass_g ~ sex + species, data = dat)
##
## Linear Hypotheses:
## Estimate Std. Error t value Pr(>|t|)
## Chinstrap - Adelie == 0 26.92 46.48 0.579 0.83
## Gentoo - Adelie == 0 1377.86 39.10 35.236 <1e-05 ***
## Gentoo - Chinstrap == 0 1350.93 48.13 28.067 <1e-05 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## (Adjusted p values reported -- single-step method)
或使用pairwise.t.test()函数,使用你选择的 p 值调整方法:6
# method 3
pairwise.t.test(dat$body_mass_g, dat$species,
p.adjust.method = "BH"
)
##
## Pairwise comparisons using t tests with pooled SD
##
## data: dat$body_mass_g and dat$species
##
## Adelie Chinstrap
## Chinstrap 0.63 -
## Gentoo <2e-16 <2e-16
##
## P value adjustment method: BH
请注意,使用第二种方法时,需要在glht()函数中指定没有交互作用的模型,即使交互作用显著。此外,不要忘记在我的代码中将mod和species替换为你的模型名称和独立变量名称。
两种方法得到的结果相同,即:
-
Chinstrap 和 Adelie 之间的体重没有显著差异(调整后的 p 值 = 0.83),
-
体重在 Gentoo 和 Adelie 之间(调整后的 p 值 < 0.001)显著不同,并且
-
体重在 Gentoo 和 Chinstrap 之间(调整后的 p 值 < 0.001)显著不同。
记住报告的是调整后的p 值,以防止比较多个组对时出现的多重检验问题。
如果你想比较所有组的组合,可以使用TukeyHSD()函数,并在which参数中指定交互作用:
# all combinations of sex and species
TukeyHSD(mod,
which = "sex:species"
)
## Tukey multiple comparisons of means
## 95% family-wise confidence level
##
## Fit: aov(formula = body_mass_g ~ sex * species, data = dat)
##
## $`sex:species`
## diff lwr upr p adj
## male:Adelie-female:Adelie 674.6575 527.8486 821.4664 0.0000000
## female:Chinstrap-female:Adelie 158.3703 -25.7874 342.5279 0.1376213
## male:Chinstrap-female:Adelie 570.1350 385.9773 754.2926 0.0000000
## female:Gentoo-female:Adelie 1310.9058 1154.8934 1466.9181 0.0000000
## male:Gentoo-female:Adelie 2116.0004 1962.1408 2269.8601 0.0000000
## female:Chinstrap-male:Adelie -516.2873 -700.4449 -332.1296 0.0000000
## male:Chinstrap-male:Adelie -104.5226 -288.6802 79.6351 0.5812048
## female:Gentoo-male:Adelie 636.2482 480.2359 792.2606 0.0000000
## male:Gentoo-male:Adelie 1441.3429 1287.4832 1595.2026 0.0000000
## male:Chinstrap-female:Chinstrap 411.7647 196.6479 626.8815 0.0000012
## female:Gentoo-female:Chinstrap 1152.5355 960.9603 1344.1107 0.0000000
## male:Gentoo-female:Chinstrap 1957.6302 1767.8040 2147.4564 0.0000000
## female:Gentoo-male:Chinstrap 740.7708 549.1956 932.3460 0.0000000
## male:Gentoo-male:Chinstrap 1545.8655 1356.0392 1735.6917 0.0000000
## male:Gentoo-female:Gentoo 805.0947 642.4300 967.7594 0.0000000
或者使用来自{agricolae}包的HSD.test()函数,该函数用相同的字母标记那些在统计上没有显著差异的子组:
library(agricolae)
HSD.test(mod,
trt = c("sex", "species"),
console = TRUE # print results
)
##
## Study: mod ~ c("sex", "species")
##
## HSD Test for body_mass_g
##
## Mean Square Error: 95726.69
##
## sex:species, means
##
## body_mass_g std r Min Max
## female:Adelie 3368.836 269.3801 73 2850 3900
## female:Chinstrap 3527.206 285.3339 34 2700 4150
## female:Gentoo 4679.741 281.5783 58 3950 5200
## male:Adelie 4043.493 346.8116 73 3325 4775
## male:Chinstrap 3938.971 362.1376 34 3250 4800
## male:Gentoo 5484.836 313.1586 61 4750 6300
##
## Alpha: 0.05 ; DF Error: 327
## Critical Value of Studentized Range: 4.054126
##
## Groups according to probability of means differences and alpha level( 0.05 )
##
## Treatments with the same letter are not significantly different.
##
## body_mass_g groups
## male:Gentoo 5484.836 a
## female:Gentoo 4679.741 b
## male:Adelie 4043.493 c
## male:Chinstrap 3938.971 c
## female:Chinstrap 3527.206 d
## female:Adelie 3368.836 d
如果你有多个组需要比较,绘图可能更容易解释:
# set axis margins so labels do not get cut off
par(mar = c(4.1, 13.5, 4.1, 2.1))
# create confidence interval for each comparison
plot(TukeyHSD(mod, which = "sex:species"),
las = 2 # rotate x-axis ticks
)

按作者绘图
从上述输出和图中,我们得出结论,所有性别和物种的组合之间都有显著差异,除了雌性 Chinstrap 和雌性 Adelie(p 值 = 0.138)以及雄性 Chinstrap 和雄性 Adelie(p 值 = 0.581)。
这些结果与上面展示的箱型图一致,并将通过下面的可视化得到确认,这些结果总结了 R 中的二元方差分析。
可视化
如果您想以不同于初步分析中已经呈现的方式可视化结果,以下是一些有用的绘图思路。
首先,使用 {effects} 包中的 allEffects() 函数绘制每个子组的均值和标准误:
# method 1
library(effects)
plot(allEffects(mod))

作者绘制
或使用 {ggpubr} 包:
# method 2
library(ggpubr)
ggline(subset(dat, !is.na(sex)), # remove NA level for sex
x = "species",
y = "body_mass_g",
color = "sex",
add = c("mean_se") # add mean and standard error
) +
labs(y = "Mean of body mass (g)")

作者绘制
另外,使用 {Rmisc} 和 {ggplot2}:
library(Rmisc)
# compute mean and standard error of the mean by subgroup
summary_stat <- summarySE(dat,
measurevar = "body_mass_g",
groupvars = c("species", "sex")
)# plot mean and standard error of the mean
ggplot(
subset(summary_stat, !is.na(sex)), # remove NA level for sex
aes(x = species, y = body_mass_g, colour = sex)
) +
geom_errorbar(aes(ymin = body_mass_g - se, ymax = body_mass_g + se), # add error bars
width = 0.1 # width of error bars
) +
geom_point() +
labs(y = "Mean of body mass (g)")

作者绘制
其次,如果您更喜欢仅绘制每个子组的均值:
with(
dat,
interaction.plot(species, sex, body_mass_g)
)

作者绘制
最后但同样重要的是,对于那些熟悉 GraphPad 的人,您可能已经熟悉如下方式绘制均值和误差条:
# plot mean and standard error of the mean as barplots
ggplot(
subset(summary_stat, !is.na(sex)), # remove NA level for sex
aes(x = species, y = body_mass_g, fill = sex)
) +
geom_bar(position = position_dodge(), stat = "identity") +
geom_errorbar(aes(ymin = body_mass_g - se, ymax = body_mass_g + se), # add error bars
width = 0.25, # width of error bars
position = position_dodge(.9)
) +
labs(y = "Mean of body mass (g)")

结论
在这篇文章中,我们首先回顾了用于比较组间定量变量的不同检验方法。然后我们集中讨论了二元方差分析,从其目标和假设到在 R 中的实现,以及解释和一些可视化。我们还简要提到了其基本假设和一种事后检验,用于比较所有子组。
所有这些都使用 {palmerpenguins} 包中提供的 penguins 数据集进行了说明。
感谢阅读。
我希望这篇文章能帮助您用您的数据进行二元方差分析。
一如既往,如果您对本文讨论的主题有任何问题或建议,请在评论中添加,以便其他读者也能从讨论中受益。
-
理论上,一元方差分析也可以用于比较 2 个组,而不仅仅是 3 个或更多组。然而,在实际操作中,通常会使用学生 t 检验来比较 2 个组,而使用一元方差分析来比较 3 个或更多组。使用独立样本的学生 t 检验和 2 个组的一元方差分析得出的结论会相似。↩︎
-
请注意,如果正态性假设未满足,可以应用许多变换来改善这一点,其中最常见的是对数变换(
log()函数在 R 中)。↩︎ -
请注意,Bartlett 检验也适用于检验方差齐性假设。↩︎
-
加性模型假设两个解释变量是独立的;它们之间没有相互作用。↩︎
-
其中
mod是您保存的模型的名称。↩︎ -
在这里,我们使用 Benjamini & Hochberg (1995)修正,但你可以选择多种方法。有关更多详情,请参见
?p.adjust。↩︎
相关文章
最初发表于 https://statsandr.com 于 2023 年 6 月 19 日。
双因素方差分析测试,使用 Python
原文:
towardsdatascience.com/two-way-anova-test-with-python-a112e2396d78
完全初学者的双因素方差分析测试指南(附代码!)
·发表于Towards Data Science ·6 分钟阅读·2023 年 1 月 5 日
--

图片由Sergey Pesterev拍摄,Unsplash提供
方差分析测试旨在检验三个或更多组之间均值的统计显著差异。常用的方差分析有两种类型,**单因素方差分析测试**和**双因素方差分析测试**。唯一的区别在于影响因变量的自变量的数量。
双因素方差分析
双因素方差分析是单因素方差分析的扩展,考察**两个不同的分类自变量或两个独立因素**对**一个连续因变量**的影响。
双因素方差分析不仅旨在测试每个独立因素的主要效应,还测试两个因素是否相互影响以影响因变量,即是否存在两个独立因素之间的相互作用。[2]
方差分析使用 F 检验,这是一种组间比较检验,用于检验统计显著性。它将每个组在不同因素(因素 A、因素 B、因素 A 与因素 B 之间的相互作用)下的均方差与因变量的总体方差进行比较。最后,基于 F 检验统计量做出结论。
平方和(SS)
在双因素方差分析表中:
变异性的总量来自四个可能的来源,即:
-
因素 A 下的组间变异,称为处理(A)
-
因素 B 下的组间变异,称为处理(B)
-
由于因素 A 和因素 B 之间的相互作用引起的平方和,称为相互作用(AB)
-
组内变异,称为误差(E)

图片 1. SS 和 d.f. 的示意图 作者提供
类似于平方和 (SS),d.f. (SSTO) = d.f. (SSA) + d.f. (SSB) + d.f. (SSAB) + d.f. (SSE)
SS 除以其 d.f.将得到均方 (MS)。
两因素 ANOVA 测试的假设与单因素 ANOVA 测试相同,即所有的参数检验假设,包括样本数据的随机性和独立性、正态性及方差齐性。如果你想了解更多细节,可以参考上一篇文章。
两因素 ANOVA 测试的简单概述
两因素 ANOVA 有三组假设:
集 1:
H₀: μₐ₁= μₐ₂ = μₐ₃ = … = μₐ𝒸
H₁: 不是所有的μₐᵢ在因素 A 下都是相等的,其中 i = 1, 2, 3, …, c。
显著性水平 = α

图片 2. 用于测试因素 A 主效应的 F 检验统计量。作者提供的图片。
集 2:
H₀: μᵦ₁= μᵦ₂ = μᵦ₃ = … = μᵦᵣ
H₁: 不是所有的μᵦᵢ在因素 B 下都是相等的,其中 i = 1, 2, 3, …, r。
显著性水平 = α

图片 3. 用于测试因素 B 主效应的 F 检验统计量。作者提供的图片。
集 3:
H₀: 一个独立变量的效应不依赖于另一个独立变量的效应,即因素 A 和因素 B 之间没有交互作用
H₁: 因素 A 和因素 B 之间存在交互作用
显著性水平 = α

图片 4. 用于测试两个独立因素之间是否存在交互作用的 F 检验统计量。作者提供的图片。
如果你执行带有交互作用的两因素 ANOVA 测试,你需要测试上述提到的所有 3 组假设。但如果你执行无交互作用的测试,你只需要测试集 1 和集 2 的假设。
最后,带有交互作用的两因素 ANOVA 表格如下所示:

表 1. 带交互作用的两因素 ANOVA 示例表。作者提供的图片。
两因素 ANOVA 表格(无交互作用)如下所示:

表 2. 无交互作用的两因素 ANOVA 示例表。作者提供的图片。
平衡设计与不平衡设计
平衡设计是指所有组合组的样本量相等的情况。在不平衡设计中,各组的样本量不相等。在两因素 ANOVA 中,如果组的样本量差异过大,普通的方差分析方法可能不够充分。对于不平衡设计,需要使用回归方法。另一种方法是尽力确保设计的平衡。
一个数据集,students.csv,包含 8239 行学生特征数据。每一行代表一个独特的学生。它包含与学生相关的 16 个特征,我们将只关注 3 个特征:专业、性别和薪资。
基于两个因素,专业和性别,是否存在不同性别和专业毕业生的年均薪资显著差异,以及性别和专业之间是否存在交互作用,显著性水平为 5%?
数据处理
从给定的数据集中,我们需要筛选出已毕业的学生并进行随机抽样。在这种情况下,它随机抽取了每组 40 名学生,即不同的(专业和性别)组合,以使其成为平衡设计。之后,选择关注的三个变量的数据集,即分类变量major, gender和数值变量salary。

图 5. 数据处理以实现平衡设计。图像来源:作者。
假设检验
根据假设检验的五步过程:
集合 1:
H₀: μₐ₁= μₐ₂ = μₐ₃ = … = μₐ₆
H₁: 在不同专业下薪资均值不相等
集合 2:
H₀: μᵦ₁= μᵦ₂
H₁: 在不同性别下薪资均值不相等
集合 3:
H₀: 专业和性别之间没有交互作用
H₁: 专业和性别之间存在交互作用
α = 0.05
根据 F 检验统计量:

图 6. 具有交互作用的 ANOVA 表:方差分析的正常方法。图像来源:作者。
我们还可以使用statsmodels包得到相同的结果,它使用回归方法。由于statsmodels使用回归方法,它也适用于不平衡设计,即你无需做大量工作来确保平衡设计。

图 7. 具有交互作用的 ANOVA 表:回归方法。图像来源:作者。
以下显示了专业和性别对薪资的交互作用图:

图 8. 专业和性别对薪资的交互作用图。图像来源:作者
结论
对于集合 1 和集合 2:由于 F 值 > F 临界值或 p 值 < 0.05,零假设被拒绝。∴我们有足够的证据表明,不同研究科目或性别的毕业生的平均薪资不相同,显著性水平为 5%。
对于集合 3:未能拒绝零假设。∴我们没有足够的证据表明研究科目和性别之间存在交互作用,显著性水平为 5%。此外,从交互作用图[4]中可以看出,没有交互作用,主要效应即专业和性别效应都显著。例如,生物学专业的男性毕业生的平均薪资会显著更高。
推荐阅读
完全初学者指南:执行 ANOVA 测试(附代码!)
towardsdatascience.com ## 卡方检验,使用 Python
完全初学者指南:进行卡方检验(附代码!)
towardsdatascience.com ## McNemar 检验,使用 Python
完全初学者指南:进行 McNemar 检验(附代码!)
towardsdatascience.com [## 单样本假设检验,使用 Python
完全初学者指南:进行单样本假设检验(附代码!)
levelup.gitconnected.com](https://levelup.gitconnected.com/how-to-perform-one-sample-hypothesis-tests-with-python-308eae8789fc?source=post_page-----a112e2396d78--------------------------------) [## 双样本假设检验,使用 Python
完全初学者指南:进行双样本假设检验(附代码!)
levelup.gitconnected.com](https://levelup.gitconnected.com/two-sample-hypothesis-tests-with-python-43e1b8c52306?source=post_page-----a112e2396d78--------------------------------)
参考文献
[1] “单因素方差分析假设检验 • SOGA • 地球科学系。” [在线]. 可用:www.geo.fu-berlin.de/en/v/soga/Basics-of-statistics/ANOVA/One-way-ANOVA-Hypothesis-Test/index.html
[2] 双向方差分析 — 维基百科
[3] Kiernan, D. (2014). 第六章:双向方差分析。Open SUNY 教科书。
[4] 第七章 方差分析与交互 | STA 265 讲义(统计与数据科学方法)。 (无日期). 取自 2023 年 1 月 2 日,campus.murraystate.edu/academic/faculty/cmecklin/STA265/_book/anova-with-interaction.html#the-interactive-two-way-anova-model
两种本地下载和访问 Llama 2 的方法
原文:
towardsdatascience.com/two-ways-to-download-and-access-llama-2-locally-8a432ed232a4
在你的 PC 上使用 Llama 2 的逐步指南
·发布于 Towards Data Science ·10 分钟阅读·2023 年 9 月 5 日
--

图片来源:作者(Dreamstudio)
动机
Meta 最新发布的 Llama 2 正在获得越来越多的关注,并且对各种使用场景都非常有趣。它提供了不同大小的预训练和微调的 Llama 2 语言模型,从 7B 到 70B 参数。Llama 2 在推理、编码、能力和知识基准等各种测试中表现良好,这使它非常有前景。
在本文中,我们将逐步指导你在 PC 上下载 Llama 2 的过程。你有两个选项:官方的 Meta AI 网站或 HuggingFace。我们还会展示如何访问它,以便你可以利用其强大的功能来支持你的项目。让我们开始吧!
前提条件
-
Nvidia T4 图形处理单元 (GPU)
-
虚拟环境 (Virtualenv)
-
HuggingFace 账户、库以及 Llama 模型
-
Python 3.10
本地下载前需要考虑的事项
在将模型下载到本地机器之前,考虑一些事项。首先,确保你的计算机有足够的处理能力和存储空间(从 SSD 磁盘加载模型要快得多)。其次,准备进行一些初始设置以使模型运行。最后,如果你是出于工作需要使用此模型,请检查公司关于下载外部软件的政策。
为什么要本地下载 Llama 2?
你可能有几个很好的理由希望将模型下载到自己的计算机上,例如:
-
减少延迟 通过在你的环境中托管 Llama 2,你可以将与外部服务器的 API 调用相关的延迟降到最低。
-
数据隐私 你可以将私人和敏感信息保存在自己的生态系统中(本地或外部云提供商)。
-
定制和控制 您对模型拥有更多控制权。您可以优化机器的配置,进行优化技术的工作,对模型进行微调,并进一步将其集成到您的生态系统中。
-
离线访问 根据使用情况,模型可能托管在没有互联网连接的安全环境中。
选择获取“Llama 2”的来源
决定从哪里获取“Llama 2”是基于对您最合适的选择。以下是一些考虑因素,以帮助您做出选择。
Meta 的 GitHub:
当您从 Meta 的 GitHub 获取“Llama 2”时,您直接从源头获取。这使您可以访问最新的更新。然而,如果遇到问题,社区可能不会像 HuggingFace 那样反应迅速。文档很好,但尝试示例可能需要更多编码。
Hugging Face:
使用 Hugging Face 非常简单,因为它具有用户友好的平台和反应迅速且强大的社区支持。它兼容多个框架,使得将模型集成到现有技术栈中变得更加容易。
因此,如果您需要定制和见解,建议直接从 Meta 的 GitHub 获取模型;如果需要易用性、社区支持和与各种框架的兼容性,可以选择 Hugging Face。
1️⃣ 从 Meta 网站下载 Llama 2
步骤 1:请求下载
下载 Llama 2 模型权重和分词器的一个选项是Meta AI 网站。在下载模型权重和分词器之前,您必须阅读并同意许可协议,并通过提供您的电子邮件地址提交请求。填写以下信息并接受条款:

作者提供的图片
一旦您的请求被批准,您将通过电子邮件收到一个signed URL。
提个小提醒!提供的下载模型权重和分词器的链接仅在 24 小时内有效,并且下载次数有限。因此,如果您看到诸如“403: Forbidden”的错误,请不要担心!您可以通过返回 Meta AI 网站请求一个新链接。
步骤 2:获取 download.sh 脚本
在继续之前,请确保您已经安装了wget和md5sum。您可以在Meta 的 GitHub 仓库找到所需的 download.sh 脚本。克隆该仓库并按如下方式进入llama目录:
git clone https://github.com/facebookresearch/llama.git
cd llama
通过输入以下命令确保您赋予脚本执行权限:
chmod +x download.sh
步骤 3:启动下载过程
要启动下载过程,您需要运行download.sh脚本。在此过程中,系统会提示您提供通过电子邮件发送的 URL 以及您希望下载的模型。
您可以选择下载两种不同类型的模型:
-
预训练 — Llama-2–7b, Llama-2–13b, Llama-2–70b
-
微调的聊天 — Llama-2–7b-chat,Llama-2–13b-chat,Llama-2–70b-chat
就我而言,我会获得 Llama-2–7b 和 Llama-2–7b-chat。
bash download.sh

如果下载成功,你应该能找到分词器和模型 llama-2–7b 及 llama-2–7b-chat。

作者提供的图片
步骤 4:准备本地环境
为了获得最佳隔离,建议建立一个全新的本地环境;我个人使用 Conda 环境管理系统。让我们开始创建新的 Conda 环境:
conda create --name yourenvname python=3.10
用你想要给环境的名称替换 yourenvname,用首选的 Python 版本替换 3.10。创建环境后,你可以用以下命令激活它:
conda activate yourenvname
然后,导航到克隆的仓库并安装 requirements.txt 中提到的所需库。
pip install -r requirements.txt
还有一件事你需要做:以允许你更改代码并立即查看效果的方式安装项目包,而不必重新安装。要实现这一点,请运行以下命令:
pip install -e .
既然我们已经准备好了,让我们运行模型看看会发生什么。
4. 使用 torchrun 运行推理
Torchrun 是 PyTorch 中的一个工具,通过自动分配工作者、处理故障、支持弹性设置和提供超越 torch.distributed.launch 的功能(包括自定义入口点、参数传递和日志捕获)来简化分布式训练。
在克隆的仓库中,你应该看到两个示例:example_chat_completion.py 和 example_text_completion.py。
由于两个脚本都设计用于分布式训练,我们需要设置一些变量。你可以像下面这样简单地导出它们,或将它们添加到 .bashrc 中。
export RANK=1
export WORLD_SIZE=0
export MASTER_ADDR=localhost
export MASTER_PORT=12355
-
RANK:分布式训练组中当前进程的等级。 -
WORLD_SIZE:分布式组中的总进程数。 -
MASTER_ADDR:协调训练的主节点地址。 -
MASTER_PORT:用于与主节点通信的端口号。
要执行 torchrun,我们需要:
-
将
nproc-per-node定义为可用的 GPU 数量, -
提供
script.py, -
通过
ckpt_dir指定模型检查点目录, -
使用
tokenizer_path指定分词器的路径。
torchrun --nproc-per-node=NUM_GPUS_YOU_HAVE your_script.py \
-- ckpt_dir /path/to/checkpoint \
-- tokenizer_path /path/to/tokenizer
让我们运行 example_text_completion.py,其中初始提示为:

作者提供的图片
torchrun --nproc_per_node 1 example_text_completion.py \
--ckpt_dir llama-2-7b/ \
--tokenizer_path tokenizer.model \
--max_seq_len 128 --max_batch_size 4

作者提供的图片
就这样。你成功了!现在你可以更改提示并尝试其他模型 😃。
简短回顾:
访问 Meta 官方网站 并申请下载权限。
访问 Llama 2 仓库 在 GitHub 上并下载 download.sh 脚本。
执行 download.sh 并提供通过电子邮件发送的签名 URL:
https://download.llamameta.net/*?YOUR_SIGNED_URL 并选择要下载的模型权重
准备环境
使用 torchrun 进行干预
2️⃣ 从 HuggingFace 下载 Llama 2
步骤 1:请求下载
首先,确保你在 Meta AI 网站 上用与你的 Hugging Face 帐户关联的确切电子邮件地址请求下载。接受许可条款和可接受使用政策。完成后,你可以申请访问 Hugging Face 上的任何可用模型。
下面是当前可用模型的列表。

cc. Hugging Face
你将收到来自 HuggingFace 的确认访问许可的电子邮件。
步骤 2:从 HuggingFace 获取令牌
如果你还没有 HuggingFace 账户,你需要创建一个。创建账户后,登录 HuggingFace。登录后,找到右上角的Profile选项并选择Settings。

图片由作者提供
选择Access Tokens选项并点击New token按钮以生成令牌。

图片由作者提供
只需复制令牌并返回到你的笔记本中。在下一步中,我们将看到如何访问和下载模型。
步骤 3:对 HuggingFace 进行身份验证
首先,安装 Hugging Face 开发的 huggingface_hub 模块,它使你能够与 Hugging Face Model Hub 进行交互。该中心托管各种预训练模型。请注意,huggingface_hub.login() 需要 ipywidgets 包。
!pip install huggingface_hub ipywidgets
然后,导入 huggingface_hub 并按如下方式登录 Hugging Face:
import huggingface_hub
huggingface_hub.login()
当你运行 huggingface_hub.login() 时,你会被要求提供你的 Hugging Face 身份验证令牌。成功认证后,你可以下载 llama 模型。粘贴你的令牌并点击登录。如果认证成功,你应该会看到以下消息。
在身份验证通过后,你可以继续下载其中一个 llama 模型。我会选择meta-llama/Llama-2–7b-chat-hf。
步骤 4:下载 Llama 2 模型
首先安装所需的库。
你可以按如下方式检查可用的 GPU:
要检查你的 GPU 详细信息,如驱动版本、CUDA 版本、GPU 名称或使用指标,请在单元格中运行命令 !nvidia-smi。
然后,为了下载模型,我们需要从 PyTorch 和 Hugging Face 的 Transformers 导入所有必要的库,初始化 Llama-2–7b 聊天模型及其标记器,并将它们保存到磁盘。请查看以下示例:
执行完单元格后,你应该会在huggingface目录下看到模型。
在进一步操作之前检查目录。里面应该有什么?
-
config.json:将其视为模型操作的手册。 -
pytorch_model.bin:这是你的模型在 PyTorch 格式中的“大脑”。 -
必需的分词器文件:
special_tokens_map.json和tokenizer_config.json就像是你模型语言的词典。 -
tokenizer.modelLlama 2 分词器
第 5 步:从磁盘加载 Llama 2 模型
如果你已经将 Llama 2 模型存储在磁盘上,你应该先加载它们。
为此,你需要:
-
LlamaForCausalLM就像是 "Llama 2" 的大脑, -
LlamaTokenizer有助于 "Llama 2" 理解和分解单词。 -
模型的路径
我们已经到了最后一步——测试干预。
第 6 步:使用 HuggingFace 管道进行干预
我们可以通过使用 HuggingFace transformers 的管道来评估干预。利用管道,你可以快速完成复杂任务。
-
text-generation:指定管道用于生成文本。 -
model:你用于文本生成的预训练模型。 -
tokenizer:用于处理输入文本和解码模型输出的分词器。 -
device_map=”auto”:这尝试在最佳可用设备上运行模型(例如,若可用则为 GPU,否则为 CPU)。 -
max_new_tokens=512:限制生成的输出为 512 个标记。 -
num_return_sequences=1:只请求一个生成的序列。 -
eos_token_id=tokenizer.eos_token_id:序列结束标记 ID。
以我的情况为例,我想要一些摇滚乐队的建议,结果非常好。

从 HuggingFace 下载 Llama 的简短回顾:
访问Meta 官方网站并申请下载权限。
从 HuggingFace 获取令牌
认证 HuggingFace
下载 Llama 2 模型
从磁盘加载 Llama 2 模型
使用 HuggingFace 管道进行干预
最后的思考:
在本教程中,我们已经看到如何将 Llama 2 模型下载到本地 PC。你还可以通过使用量化、蒸馏等方法进一步提升模型的性能,我将在后续文章中讨论这些方法。务必在新的虚拟环境中执行所有这些步骤。在此过程中,请确保监控计算机的内存使用情况。很多奇怪的错误可能隐藏在内存问题后面。如果你想下载量化模型,请注意你可能需要将 bitsandbytes 库降级到 0.39.1
曾经尝试在 Medium 上点击“点赞”按钮多次吗?❤️
让我们成为朋友吧!✋ 别忘了 订阅!
如果你觉得我的故事很吸引人,并希望支持我的写作,我邀请你考虑成为 Medium 会员,你可以访问大量的生成 AI、数据工程和数据科学文章。
[## 使用我的推荐链接加入 Medium — Bildea Ana
作为 Medium 的会员,你的部分会员费将用于支持你阅读的作者,你也可以完全访问所有故事…
查看我关于生成式人工智能、MLOps 和负责任人工智能的文章合集。

生成式人工智能
查看列表11 篇故事



MLOps - 人工智能生产
查看列表4 篇故事



负责任的人工智能
查看列表1 篇故事
类型提示数据框用于静态分析和运行时验证
StaticFrame 如何实现全面的数据框类型提示
·
关注 发表在 Towards Data Science · 9 分钟阅读 · 2023 年 11 月 16 日
--
作者照片
自从 Python 3.5 引入类型提示以来,静态类型数据框通常只限于指定类型:
def process(f: DataFrame) -> Series: ...
这是不足的,因为它忽略了容器中包含的类型。一个 DataFrame 可能具有字符串列标签和三列整数、字符串和浮点值;这些特征定义了类型。具有此类类型提示的函数参数为开发人员、静态分析器和运行时检查器提供了理解接口期望的所有信息。 StaticFrame 2(我是该开源项目的主要开发者)现在允许这样做:
from typing import Any
from static_frame import Frame, Index, TSeriesAny
def process(f: Frame[ # type of the container
Any, # type of the index labels
Index[np.str_], # type of the column labels
np.int_, # type of the first column
np.str_, # type of the second column
np.float64, # type of the third column
]) -> TSeriesAny: ...
所有核心的 StaticFrame 容器现在都支持泛型规范。虽然可以静态检查,但一个新的装饰器 [@CallGuard](http://twitter.com/CallGuard).check 允许在函数接口上对这些类型提示进行运行时验证。此外,使用 Annotated 泛型,新的 Require 类定义了一系列强大的运行时验证器,允许按列或按行数据进行检查。最后,每个容器都暴露了一个新的 via_type_clinic 接口,用于推导和验证类型提示。这些工具共同提供了一种完整的类型提示和验证 DataFrame 的方法。
泛型 DataFrame 的要求
Python 的内置泛型类型(例如 tuple 或 dict)需要指定组件类型(例如 tuple[int, str, bool] 或 dict[str, int])。定义组件类型可以更准确地进行静态分析。尽管对于 DataFrame 也是如此,但很少有人尝试定义全面的 DataFrame 类型提示。
即使使用 pandas-stubs 包,Pandas 也不允许指定 DataFrame 组件的类型。Pandas DataFrame 允许广泛的原地变异,可能不适合进行静态类型化。幸运的是,StaticFrame 中提供了不可变的 DataFrame。
此外,直到最近,Python 用于定义泛型的工具并不适合 DataFrame。一个 DataFrame 具有可变数量的异构列类型对于泛型规范是一个挑战。使用在 Python 3.11 中引入的新的 TypeVarTuple(并在 typing_extensions 包中进行了回溯)更容易地对这种结构进行类型化。
TypeVarTuple 允许定义可以接受多个类型的泛型。 (详见 PEP 646)。借助这种新的类型变量,StaticFrame 可以定义一个通用的 Frame,它具有用于索引的 TypeVar、用于列的 TypeVar,以及用于零个或多个列类型的 TypeVarTuple。
泛型 Series 定义了一个用于索引的 TypeVar 和一个用于值的 TypeVar。StaticFrame 的 Index 和 IndexHierarchy 也是泛型的,后者再次利用 TypeVarTuple 来定义每个深度级别的可变数量的组件 Index。
StaticFrame 使用 NumPy 类型定义Frame的列类型,或者Series或Index的值类型。这允许严格指定大小的数值类型,如np.uint8或np.complex128;或广泛指定类型的类别,如np.integer或np.inexact。由于 StaticFrame 支持所有 NumPy 类型,因此对应关系是直接的。
使用泛型数据框定义接口
在上述示例的基础上,以下函数接口显示了将三列Frame转换为Series字典。通过组件类型提示提供了更多信息,函数的目的几乎显而易见。
from typing import Any
from static_frame import Frame, Series, Index, IndexYearMonth
def process(f: Frame[
Any,
Index[np.str_],
np.int_,
np.str_,
np.float64,
]) -> dict[
int,
Series[ # type of the container
IndexYearMonth, # type of the index labels
np.float64, # type of the values
],
]: ...
此函数处理来自开源资产定价(OSAP)数据集(公司级特征/个体/预测器)的信号表。每个表具有三列:安全标识符(标记为“permno”)、年和月(标记为“yyyymm”)以及信号(具有特定信号的名称)。
该函数忽略所提供Frame的索引(类型为Any),并创建由第一列“permno” np.int_值定义的组。返回以“permno”为键的字典,其中每个值是该“permno”的np.float64值的Series;索引是从np.str_“yyyymm”列创建的IndexYearMonth。(StaticFrame 使用 NumPy datetime64值来定义单位类型的索引:IndexYearMonth存储datetime64[M]标签。)
以下函数不返回dict,而是返回具有分层索引的Series。IndexHierarchy泛型指定了每个深度级别的组件Index;在此处,外部深度是从“permno”列派生的Index[np.int_],内部深度是从“yyyymm”列派生的IndexYearMonth。
from typing import Any
from static_frame import Frame, Series, Index, IndexYearMonth, IndexHierarchy
def process(f: Frame[
Any,
Index[np.str_],
np.int_,
np.str_,
np.float64,
]) -> Series[ # type of the container
IndexHierarchy[ # type of the index labels
Index[np.int_], # type of index depth 0
IndexYearMonth], # type of index depth 1
np.float64, # type of the values
]: ...
丰富的类型提示提供了一个自描述的接口,使功能明确。更好的是,这些类型提示可以用于与 Pyright(现在)和 Mypy(待完全支持TypeVarTuple)的静态分析。例如,使用两列np.float64的Frame调用此函数将在编辑器中失败静态分析类型检查或提供警告。
运行时类型验证
静态类型检查可能不足够:运行时评估提供了更强的约束,特别是对于动态或未完全(或错误地)类型提示的值。
基于名为TypeClinic的新运行时类型检查器,StaticFrame 2 引入了[@CallGuard](http://twitter.com/CallGuard).check,一个用于类型提示接口运行时验证的装饰器。支持所有 StaticFrame 和 NumPy 泛型,并支持大多数内置 Python 类型,即使嵌套深度很深。以下函数添加了[@CallGuard](http://twitter.com/CallGuard).check装饰器。
from typing import Any
from static_frame import Frame, Series, Index, IndexYearMonth, IndexHierarchy, CallGuard
@CallGuard.check
def process(f: Frame[
Any,
Index[np.str_],
np.int_,
np.str_,
np.float64,
]) -> Series[
IndexHierarchy[Index[np.int_], IndexYearMonth],
np.float64,
]: ...
现在使用 [@CallGuard](http://twitter.com/CallGuard).check 装饰,如果上述函数用于未标记的 np.float64 两列的 Frame,则会引发 ClinicError 异常,说明预期有三列,但只提供了两列,并且预期字符串列标签,但提供了整数标签。(要发出警告而不是引发异常,请使用 [@CallGuard](http://twitter.com/CallGuard).warn 装饰。)
ClinicError:
In args of (f: Frame[Any, Index[str_], int64, str_, float64]) -> Series[IndexHierarchy[Index[int64], IndexYearMonth], float64]
└── Frame[Any, Index[str_], int64, str_, float64]
└── Expected Frame has 3 dtype, provided Frame has 2 dtype
In args of (f: Frame[Any, Index[str_], int64, str_, float64]) -> Series[IndexHierarchy[Index[int64], IndexYearMonth], float64]
└── Frame[Any, Index[str_], int64, str_, float64]
└── Index[str_]
└── Expected str_, provided int64 invalid
运行时数据验证
其他特性可以在运行时进行验证。例如,shape 或 name 属性,或者索引或列上的标签顺序。StaticFrame 的 Require 类提供了一系列可配置的验证器。
-
Require.Name: 验证容器的name属性。 -
Require.Len: 验证容器的长度。 -
Require.Shape: 验证容器的shape属性。 -
Require.LabelsOrder: 验证标签的顺序。 -
Require.LabelsMatch: 验证包含标签而不考虑顺序。 -
Require.Apply: 将返回布尔值的函数应用于容器。
符合增长趋势,这些对象作为一个或多个额外参数提供给 Annotated 泛型的类型提示。 (有关详细信息,请参阅 PEP 593。)第一个 Annotated 参数引用的类型是后续参数验证器的目标。例如,如果将 Index[np.str_] 类型提示替换为 Annotated[Index[np.str_], Require.Len(20)] 类型提示,则会对与第一个参数关联的索引应用运行时长度验证。
扩展处理 OSAP 信号表的示例,我们可以验证列标签的期望。Require.LabelsOrder 验证器可以定义一系列标签,可选地使用 … 表示零个或多个未指定的标签。为了指定表的前两列标签为 “permno” 和 “yyyymm”,而第三个标签是可变的(取决于信号),可以在 Annotated 泛型内定义以下 Require.LabelsOrder:
from typing import Any, Annotated
from static_frame import Frame, Series, Index, IndexYearMonth, IndexHierarchy, CallGuard, Require
@CallGuard.check
def process(f: Frame[
Any,
Annotated[
Index[np.str_],
Require.LabelsOrder('permno', 'yyyymm', ...),
],
np.int_,
np.str_,
np.float64,
]) -> Series[
IndexHierarchy[Index[np.int_], IndexYearMonth],
np.float64,
]: ...
如果接口期望小集合的 OSAP 信号表,我们可以使用 Require.LabelsMatch 验证器验证第三列。该验证器可以指定必需的标签、标签集合(其中至少一个必须匹配)和正则表达式模式。如果只期望来自三个文件的表(即 “Mom12m.csv”、“Mom6m.csv” 和 “LRreversal.csv”),我们可以通过定义 Require.LabelsMatch 与集合来验证第三列的标签:
@CallGuard.check
def process(f: Frame[
Any,
Annotated[
Index[np.str_],
Require.LabelsOrder('permno', 'yyyymm', ...),
Require.LabelsMatch({'Mom12m', 'Mom6m', 'LRreversal'}),
],
np.int_,
np.str_,
np.float64,
]) -> Series[
IndexHierarchy[Index[np.int_], IndexYearMonth],
np.float64,
]: ...
Require.LabelsOrder 和 Require.LabelsMatch 都可以将函数与标签说明符关联,以验证数据值。如果验证器应用于列标签,则将一系列列值提供给函数;如果验证器应用于索引标签,则将一系列行值提供给函数。
类似于Annotated的用法,标签说明符被替换为一个列表,其中第一个项目是标签说明符,其余项目是返回布尔值的行或列处理函数。
为了扩展上述示例,我们可能需要验证所有“permno”值是否大于零,以及所有信号值(“Mom12m”、“Mom6m”、“LRreversal”)是否大于或等于-1。
from typing import Any, Annotated
from static_frame import Frame, Series, Index, IndexYearMonth, IndexHierarchy, CallGuard, Require
@CallGuard.check
def process(f: Frame[
Any,
Annotated[
Index[np.str_],
Require.LabelsOrder(
['permno', lambda s: (s > 0).all()],
'yyyymm',
...,
),
Require.LabelsMatch(
[{'Mom12m', 'Mom6m', 'LRreversal'}, lambda s: (s >= -1).all()],
),
],
np.int_,
np.str_,
np.float64,
]) -> Series[
IndexHierarchy[Index[np.int_], IndexYearMonth],
np.float64,
]: ...
如果验证失败,[@CallGuard](http://twitter.com/CallGuard).check将引发异常。例如,如果调用上述函数时遇到意外的第三列标签,将引发以下异常:
ClinicError:
In args of (f: Frame[Any, Annotated[Index[str_], LabelsOrder(['permno', <lambda>], 'yyyymm', ...), LabelsMatch([{'Mom12m', 'LRreversal', 'Mom6m'}, <lambda>])], int64, str_, float64]) -> Series[IndexHierarchy[Index[int64], IndexYearMonth], float64]
└── Frame[Any, Annotated[Index[str_], LabelsOrder(['permno', <lambda>], 'yyyymm', ...), LabelsMatch([{'Mom12m', 'LRreversal', 'Mom6m'}, <lambda>])], int64, str_, float64]
└── Annotated[Index[str_], LabelsOrder(['permno', <lambda>], 'yyyymm', ...), LabelsMatch([{'Mom12m', 'LRreversal', 'Mom6m'}, <lambda>])]
└── LabelsMatch([{'Mom12m', 'LRreversal', 'Mom6m'}, <lambda>])
└── Expected label to match frozenset({'Mom12m', 'LRreversal', 'Mom6m'}), no provided match
TypeVarTuple的表达能力
如上所示,TypeVarTuple允许指定具有零个或多个异构列类型的Frame。例如,我们可以为两个浮点数或六种混合类型的Frame提供类型提示:
>>> from typing import Any
>>> from static_frame import Frame, Index
>>> f1: sf.Frame[Any, Any, np.float64, np.float64]
>>> f2: sf.Frame[Any, Any, np.bool_, np.float64, np.int8, np.int8, np.str_, np.datetime64]
虽然这适用于各种 DataFrame,但对宽型 DataFrame(例如具有数百列的 DataFrame)的类型提示可能会显得笨拙。Python 3.11 引入了一种新的语法,通过TypeVarTuple泛型提供可变范围的类型:tuple泛型别名的星号表达式。例如,要对具有日期索引、字符串列标签和任意列类型配置的Frame进行类型提示,我们可以星号解包零个或多个All的tuple。
>>> from typing import Any
>>> from static_frame import Frame, Index
>>> f: sf.Frame[Index[np.datetime64], Index[np.str_], *tuple[All, ...]]
tuple星号表达式可以出现在类型列表中的任何位置,但只能有一个。例如,下面的类型提示定义了一个必须以布尔值和字符串列开始的Frame,但对后续的np.float64列数量有灵活的规定。
>>> from typing import Any
>>> from static_frame import Frame
>>> f: sf.Frame[Any, Any, np.bool_, np.str_, *tuple[np.float64, ...]]
类型提示工具
使用如此详细的类型提示可能会很具挑战性。为了帮助用户,StaticFrame 提供了便捷的运行时类型提示和检查工具。所有 StaticFrame 2 容器现在都具备via_type_clinic接口,允许访问TypeClinic功能。
首先,提供了将容器(例如完整的Frame)转换为类型提示的工具。via_type_clinic接口的字符串表示提供了容器类型提示的字符串表示;另外,to_hint()方法返回一个完整的泛型别名对象。
>>> import static_frame as sf
>>> f = sf.Frame.from_records(([3, '192004', 0.3], [3, '192005', -0.4]), columns=('permno', 'yyyymm', 'Mom3m'))
>>> f.via_type_clinic
Frame[Index[int64], Index[str_], int64, str_, float64]
>>> f.via_type_clinic.to_hint()
static_frame.core.frame.Frame[static_frame.core.index.Index[numpy.int64], static_frame.core.index.Index[numpy.str_], numpy.int64, numpy.str_, numpy.float64]
其次,提供了用于运行时类型提示测试的工具。via_type_clinic.check()函数允许根据提供的类型提示验证容器。
>>> f.via_type_clinic.check(sf.Frame[sf.Index[np.str_], sf.TIndexAny, *tuple[tp.Any, ...]])
ClinicError:
In Frame[Index[str_], Index[Any], Unpack[Tuple[Any, ...]]]
└── Index[str_]
└── Expected str_, provided int64 invalid
为支持渐进类型,StaticFrame 定义了几个配置为每种组件类型的Any的泛型别名。例如,TFrameAny可用于任何Frame,而TSeriesAny用于任何Series。如预期的那样,TFrameAny将验证上面创建的Frame。
>>> f.via_type_clinic.check(sf.TFrameAny)
结论
更好的 DataFrame 类型提示早已迫切需要。凭借现代 Python 类型工具和基于不可变数据模型构建的 DataFrame,StaticFrame 2 满足了这一需求,为优先考虑可维护性和可验证性的工程师提供了强大的资源。
Python 中的类型提示
你的代码将不再是一个谜
·发表于Towards Data Science ·阅读时间 7 分钟·2023 年 7 月 11 日
--

图片来源:Agence Olloweb 在 Unsplash
前几天我试图解读我过去编写的一个脚本如何工作。我知道它做了什么,它的解释和文档都很充分,但理解其工作原理更为麻烦。
代码繁琐且复杂,虽然有些注释,但缺乏适当的样式。这时我决定学习 PEP 8[1]并将其融入我的代码中。
如果你不知道 PEP 8 是什么,它基本上是一个提供指导方针、编码约定和最佳实践的文档,用于编写 Python 代码。
我们难以理解的代码的解决方案就在那儿。然而,我们大多数人从未花时间去阅读它并将这些指导方针融入日常实践中。
这需要时间和很多错误,但相信我,这是值得的。我学到了很多东西,代码现在开始变得更好了。
我最喜欢的发现之一是类型提示(或类型注释)——这将是今天帖子的主题。实际上,类型提示已经出现在 PEP 3107[2]中,回到 2006 年,并在 484[3]版本(2014 年)中重新审视并全面记录。从那时起,它在新的 PEP 版本中得到了多次改进,几乎成为经典。
所以,对于许多人来说,这是一个老话题但又非常新鲜。
什么是类型提示?
类型提示指示函数(同样适用于类方法)的输入和输出的数据类型。
许多 Python 用户抱怨的一个问题是我们可以自由地更改变量类型。在 C 语言和许多其他语言中,你需要声明一个变量并指定其类型:字符、整数等。
每个人都会有自己的看法——有些人可能喜欢 Python 的自由(以及对内存管理的影响),而另一些人则更喜欢传统语言的限制,因为这使他们的代码更具可读性。
无论如何。
类型提示旨在使你的 Python 代码更具可读性,我相信我们大多数人都很欣赏这种方法。然而,它们旨在澄清,它们并不强制变量的数据类型。
如果变量的类型不是我们预期的,程序不会引发错误。
为什么数据科学家应该考虑使用它们?
说实话,任何 Python 程序员都能从类型注解中受益。但这对数据科学家和其他与数据相关的专业人士来说可能更有意义。
那是因为我们处理各种数据。不仅仅是简单的字符串、列表或元组。我们使用的数据可能涉及超级复杂的结构,而类型提示有潜力节省我们很多时间,帮助我们了解函数预期的数据类型。
例如,假设我们有一个基于字典的结构。它的键是元组,值是具有字符串键和集合值的嵌套字典。
祝好运,希望几个月后你重访代码时能记住这一点!
好的一点是,类型提示非常容易理解和使用。我们没有理由不使用它们,而且不使用它们也没有任何好处。
那么,让我们继续并开始查看一些代码。
1. 首先概述
我将使用 Python 3.11,但大多数示例适用于之前的 Python 3 版本。
让我们使用一个示例和虚拟函数:
def meet_someone(name, age):
return f"Hey {name}, I heard you are {age}. I'm 21 and from Mars!"
这很愚蠢。但它包含了我们需要的一切,我们现在将添加一些变体。
我们这里没有任何类型注解。只有一个接受两个参数并返回字符串的函数。我相信你知道 name 参数应该是字符串,而 age 参数则预期是整数(甚至是浮点数)。
但你知道,因为这是一个非常简单的函数。很少会如此简单。这就是为什么添加一些提示可能是明智的。
这是更新:
def meet_someone(name: str, age: int) -> str:
return f"Hey {name}, I heard you are {age}. I'm 21 and from Mars!"
在这种情况下,我指定 age 应该是一个整数。让我们尝试运行这个函数:
>>> meet_someone('Marc', 22)
Hey Marc, I heard you are 22\. I'm 21 and from Mars!
为了说明我在上一节末尾说的内容:
>>> meet_someone('Marc', 22.4)
Hey Marc, I heard you are 22.4\. I'm 21 and from Mars!
即使 22.4 是浮点数(而不是预期的整数),它也工作得很好。如前所述,这些只是类型提示,仅此而已。
好的,基础知识已经涵盖。让我们开始制作一些变体。
2. 多种数据类型
假设我们希望允许整数和浮点数作为年龄参数的数据类型。我们可以使用Union,来自 typing 模块[4]:
from typing import Union
def meet_someone(name: str, age: Union[int, float]) -> str:
return f"Hey {name}, I heard you are {age}. I'm 21 and from Mars!"
很简单:Union[int, float] 表示我们期望值是整数或浮点数。
然而,如果你使用的是 Python 3.10 或更高版本,还有另一种方法可以在不使用 Union 的情况下实现相同功能:
def meet_someone(name: str, age: int | float) -> str:
return f"Hey {name}, I heard you are {age}. I'm 21 and from Mars!"
这只是一个简单的 OR 运算符。在我看来,更容易理解。
3. 高级数据类型
现在假设我们要处理更复杂的参数,比如字典或列表。接下来我们使用下一个函数,它使用了meet_someone函数:
def meet_them_all(people) -> str:
msg = ""
for person in people:
msg += meet_someone(person, people[person])
return msg
这仍然是一个非常简单的函数,但现在参数可能不像我们之前看到的那样清晰。如果你实际检查代码,你会看到我们期望的是一个字典。
但如果我们不需要猜测会不会更好?这就是类型提示的力量。
到此为止,如果我让你自己添加类型提示,你可能会做这样的事情:
def meet_them_all(people: dict) -> str:
msg = ""
for person in people:
msg += meet_someone(person, people[person])
return msg
这很好。但是我们还没有充分利用它的全部潜力。我们在这里指定了想要一个dict,但没有指定它的键和值的类型。这是一个改进版:
def meet_them_all(people: dict[str, int]) -> str:
msg = ""
for person in people:
msg += meet_someone(person, people[person])
return msg
在这里,我们说我们期望people是一个字典,键是字符串,值是整数。类似于{'Pol': 23, 'Marc': 21}。
但记住我们想要接受年龄为整数或浮点数…
from typing import Union
def meet_them_all(people: dict[str, Union[int, float]]) -> str:
msg = ""
for person in people:
msg += meet_someone(person, people[person])
return msg
我们可以直接使用我们在第二部分学到的内容!很酷吧?
哦,它不仅适用于内置数据类型。你可以使用任何你想要的数据类型。例如,假设我们想要一个 Pandas 数据框的列表,用于一个不返回任何东西的函数:
import pandas as pd
Vector = list[pd.DataFrame]
def print_vector_length(dfs: Vector) -> None:
print(f'We received {len(dfs)} dfs')
我在这里做的是声明数据类型,它只是一个数据框的列表,并将其用作类型提示。
此外,还有我们今天之前没有见过的,这个函数不会返回任何东西。这就是为什么输出数据类型是 None。
4. Optional 运算符
我们经常创建一些参数不是必需的——它们是可选的。
既然我们已经了解了这些,下面是如何编写一个带有可选参数的函数:
def meet_someone(name: str,
age: int | float,
last_name: None | str = None
) -> str:
msg = f"Hey {name}{' ' + last_name if last_name else ''}, "\
f"I heard you are {age}. I'm 21 and from Mars!"
return msg
我已经更新了返回的消息,但重要的部分是最后一个参数last_name的类型提示。看看我这里怎么说:“last_name要么是一个字符串,要么是一个 Null 值。它是可选的,默认情况下是 None。”
这很酷,也很直观,但想象一下一个参数有几种可能的数据类型……这可能会很长。
这就是为什么 Optional 运算符在这里很有用,它基本上允许我们跳过None提示:
from typing import Optional
def meet_someone(name: str,
age: int | float,
last_name: Optional[str] = None
) -> str:
msg = f"Hey {name}{' ' + last_name if last_name else ''}, "\
f"I heard you are {age}. I'm 21 and from Mars!"
return msg
结论与下一步
我希望我已经传达了类型提示在提高代码可读性和理解方面的有用性。不仅仅是为了我们的同事程序员,也是为了我们未来的自己!
我已经介绍了基础知识,但我建议你继续查看 typing 模块提供的内容。那里有几类可以让你的代码看起来更好。
**Thanks for reading the post!**
I really hope you enjoyed it and found it insightful.
Follow me and subscribe to my mailing list for more
content like this one, it helps a lot!
**@polmarin**
如果你想进一步支持我,请考虑通过下面的链接订阅 Medium 会员:这不会花费你额外的钱,但会帮助我完成这个过程。
[## 通过我的推荐链接加入 Medium - Pol Marin
阅读 Pol Marin 的每一篇故事(以及 Medium 上成千上万其他作家的文章)。你的会员费用直接支持 Pol…
资源
[1] PEP 8 — Python 代码风格指南 | peps.python.org
[2] 3107 — 函数注解 | peps.python.org
[3] 484 — 类型提示 | peps.python.org
I 型和 II 型错误及假设检验中的样本大小计算
假设检验中影响结果的因素
·发表于 Towards Data Science ·9 分钟阅读·2023 年 2 月 23 日
--

图片由 Scott Graham 提供,来源于 Unsplash
在统计和数据分析的世界中,假设检验是一个基本概念,在做出明智决策中发挥着至关重要的作用。在这篇博客中,我们将深入探讨假设检验,特别关注如何减少 I 型和 II 型错误。我们将讨论影响这些错误的因素,如显著性水平、样本大小和数据变异性。让我们深入探讨假设检验的复杂性吧!
我们将在整个博客中使用以下示例。
上一个学期的平均学生 GPA 为 2.70。在当前学期启动了一项辅导程序。我们希望进行以下假设检验,以研究辅导程序是否能提升学生的 GPA。
在当前学期结束时,我们收集了 20 个随机的 GPA 记录,并假设学生 GPA 服从标准差(σ)为 0.5 的正态分布。μ代表总体的平均 GPA。
-
零假设:μ = 2.70(即辅导程序在提升学生 GPA 方面没有帮助。)
-
备择假设:μ > 2.70(即辅导程序有帮助。)
学校的资金非常有限。我们希望将I 型错误的风险降到最低(即错误地得出辅导程序有帮助的结论,尽管它实际上并没有帮助)。
你可能会问
我们需要考虑哪些因素来减少I 型错误?
1. 显著性水平 (α)
显著性水平(α)是我们愿意接受的预定义最大第一类错误概率。
在显著性水平下,我们可以找到临界值以拒绝假设检验中的零假设。

-
μ是零假设中的总体参数(例如,总体均值)。
-
σ是总体标准差。如果σ未知,我们可以使用样本标准差,s,来估计它。
-
n 是样本量
-
Z 是与给定α相关的 Z 统计量。如果σ未知或样本量小于 30,我们将使用 T 统计量以产生更可实现的结果。
如果观察到的样本统计量(例如,样本均值)等于或更极端于临界值,我们将拒绝零假设。
当我们基于显著性水平(α)做决定时,有一个最大α *100% 的犯第一类错误的风险。
*P(第一类错误,即当 X̄ > 临界值且零假设正确时拒绝零假设)= α 100%

图片由作者提供
例如,

图片由作者提供
显著性水平(α)越低,犯第一类错误的风险越低。
2. 样本量
另一个可能影响第一类错误的因素是样本量的变化(例如,从 n = 20 到 n = 100),让我们看看它如何影响第一类错误的概率。
例如,
当α = 0.1,n = 20,
P(当 X̄ > 2.84 时拒绝零假设的第一类错误概率)= P(Z > 2.84–2.7/0.5/√20)= 10%
当α = 0.1,n = 100,
P(当 X̄ > 2.84 时拒绝零假设的第一类错误概率)= P(Z > 2.84–2.7/0.5/√100)= 0.26%
我们可以调整样本量以适应不同的显著性水平,并获得相同的结果。

图片由作者提供
在不同的显著性水平下,随着样本量的增加,第一类错误的概率会降低。
这可以用常识理解。样本量越大,你对总体的信息越多。这意味着测试统计量的精度提高,第一类错误的概率降低。
3. 数据变异性
数据变异性也会影响第一类错误。如果数据变异性减少(即,总体标准差变小),我们预期犯第一类错误的概率会更小。
例如,
当α = 0.1,SD = 0.5,
P(当 X̄ > 2.84 时拒绝零假设的第一类错误概率)= P(Z > 2.84–2.7/0.5/√20)= 10%
当α = 0.1,SD = 0.3,
P(当 X̄ > 2.84 时拒绝零假设的第一类错误概率)= P(Z > 2.84–2.7/0.3/√20)= 1.8%
我们可以针对不同的显著性水平调整标准差,并获得相同的结果。

图片由作者提供
在不同的显著性水平下,标准差减少时第一类错误的概率也会减少*。
另一方面,如果研究的总体变异性更大,那么检测真实效应可能会更困难。换句话说,第一类错误的概率增加。这是因为检验统计量的分布更广泛,难以区分原假设和备择假设。
接下来,我们将讨论如何减少假设检验的第二类错误。
但首先,
如何计算当备择假设正确时的假设检验的第二类错误?
如果我们在备择假设正确时未拒绝原假设,我们就犯了第二类错误。
在这个例子中,如果备择假设为真(例如,真实的总体 GPA 均值为 3.0),第二类错误的概率可以计算为
P(第二类错误,即当 X̄ < 临界值而备择假设正确时未拒绝原假设) = β100%*

作者提供的图片
在许多情况下,我们也对计算假设检验的效能感兴趣。
假设检验的 效能 (计算为 1-β)**是正确拒绝原假设的概率,当备择假设正确时。
我们需要考虑哪些因素来减少第二类错误(或增加效能)?
1. 显著性水平 (α)
显著性水平(α)也会影响第二类错误,但方向相反。
例如,
当α = 0.1,SD= 0.5,n=20,真实μ = 3.0
P(当 X̄ < 2.84 时未拒绝原假设的第二类错误) = P(Z < 2.84–3.0/0.5/√20) = 8%
当α = 0.05,SD= 0.5,n=20,真实μ = 3.0
P(当 X̄ < 2.88 时未拒绝原假设的第二类错误) = P(Z < 2.88–3.0/0.5/√20) = 14%

作者提供的图片
显著性水平 (α) 的降低会导致第二类错误的概率增加或效能的降低。
2. 样本大小
样本大小会以相同的方式影响第二类错误,就像第一类错误一样。
例如,
当α = 0.1,SD= 0.5,n=20,真实μ = 3.0
P(当 X̄ < 2.84 时未拒绝原假设的第二类错误) = P(Z < 2.84–3.0/0.5/√20) = 8%
当α = 0.1,SD= 0.5,n=100,真实μ = 3.0
P(当 X̄ < 2.84 时未拒绝原假设的第二类错误) = P(Z < 2.84–3.0/0.5/√100) = 0.069%

作者提供的图片
在不同的显著性水平下,随着样本大小的增加,第二类错误的概率会减少。
3. 数据变异性
数据变异性也会以相同的方式影响第二类错误,就像第一类错误一样。
例如,
当α = 0.1,SD= 0.5,n=20,真实μ = 3.0
P(当 X̄ < 2.84 时未拒绝原假设的第二类错误) = P(Z < 2.84–3.0/0.5/√20) = 8%
当 α = 0.1,SD= 0.3,n=20,真实 μ = 3.0
P(未能拒绝零假设的第二类错误概率,当 X̄ < 2.84)= P(Z < 2.84–3.0/0.3/√20) = 0.85%

作者提供的图像
在不同显著性水平下,第二类错误的概率会随着标准差的减少而减少。
4. 效应量
效应量是零假设与备择假设之间差异的大小(例如,当真实总体 GPA 均值为 3.0 时,效应量为 0.3,(3.0–2.7))。
如果效应量增加,则更容易检测到真实效应,第二类错误的概率降低。
例如,
当 α = 0.1,SD= 0.5,n=20,效应量 = 0.3
P(未能拒绝零假设的第二类错误概率,当 X̄ < 2.84)= P(Z < 2.84–3.0/0.5/√20) = 8%
当 α = 0.1,SD= 0.5,n=20,效应量 = 0.4
P(未能拒绝零假设的第二类错误概率,当 X̄ < 2.84)= P(Z < 2.84–3.1/0.5/√20) = 1%

作者提供的图像
下表总结了第一类和第二类错误与这些各种因素之间的关系。

作者提供的图像
现在我们了解了第一类和第二类错误是各种因素的函数。
我们如何同时减少第一类和第二类错误的概率?
简单的答案是:增加样本大小是同时减少 α 和 β 的唯一方法。
但我们如何计算样本大小呢?
检验力分析 是一种常用工具,用于计算样本大小以实现所需的第一类和第二类错误水平。
我们需要以下信息来计算样本大小。
1. 显著性水平 (α): 我们通常会提前确定 α 值。常见的 α 值为 0.01、0.05 和 0.1。
2. 检验力 (1-β):这是你假设检验检测效应的能力。检验力越高,检测效应的可能性越大,第二类错误的风险越低。我们通常将检验力设定为 80%或 20% β。
3. 数据变异性(即标准差,σ):数据变异性不是由我们决定的。我们需要专家的领域知识或对样本数据进行分析。
4. 效应量 (δ):在实践中,我们不会知道真实的效应量,因为我们仅使用样本数据。相反,我们可以确定最小重要差异 (MID),即被认为具有意义或临床相关的测量结果的最小差异。我们可以将效应量设定为 MID。
-
如果真实的效应大小小于 MID,那么效应大小在实际意义上就不显著。例如,如果辅导只使学生的 GPA 提高 0.1(即效应大小=0.1),这会是一个震撼的结论吗?可能不会。因此,样本大小足以检测到小于 MID 的效应被视为资源的浪费。
-
如果真实的效应大小大于 MID,那么我们更有可能检测到真实的效应。因此,基于 MID 的样本大小对于假设检验来说是足够的。
这是计算样本大小的基本公式。

从公式中,我们可以总结出样本大小与这些因素之间的关系。

作者提供的图片
总之,在进行假设检验时,考虑 I 型和 II 型错误及其可能影响的因素是至关重要的。通过仔细考虑这些因素并平衡两种错误的风险,我们可以基于假设检验的结果做出更准确和更有根据的决策。
如果你想探索更多与统计学相关的帖子,请查看我的文章:
-
中央极限定理的 7 个常见问题
-
标准差与标准误差:有什么区别?
-
3 种最常见的误解:假设检验、置信区间、P 值
-
线性回归模型中的误差项是否服从正态分布?
-
OLS 估计量在线性回归模型中是否服从正态分布?
-
什么是正则化:偏差-方差权衡
-
置信区间与预测区间:有什么区别?
谢谢你的阅读!
如果你喜欢这篇文章,请点击点赞图标。如果你想查看更多来自我和其他成千上万的作者在 Medium 上的文章,你可以:
U-Net 解析:理解其图像分割架构
跳跃连接如何使 CNN 在数据较少的情况下进行准确的语义分割
·发表于Towards Data Science ·7 分钟阅读·2023 年 3 月 8 日
--

(来源:作者)
U-Net 是一种用于语义分割的流行深度学习架构。最初为医学图像开发,它在这一领域取得了巨大成功。但,这仅仅是开始!从卫星图像到手写字符,该架构在各种数据类型上都提高了性能。然而,其他 CNN 架构也能进行分割,那么 U-Net 到底有什么特别之处呢?
为了回答这个问题,我们将探索 U-Net 架构。我们将其与用于分类和自编码器的 CNN 进行比较。通过这样做,我们将理解跳跃连接如何是 U-Net 成功的关键。我们将看到它们如何使该架构在数据较少的情况下进行准确的分割。
什么是语义分割?
我们将从理解 U-Net 开发的目的开始。图像分割或语义分割是将类分配给图像中每一个像素的任务。模型使用分割图作为目标变量进行训练。例如,参见图 1。我们有原始图像和一个二值分割图。这个图将图像分为细胞和非细胞像素。
这个生物医学图像分割任务正是 U-Net 最初开发的目的。这些数据集的决定性因素是训练图像数量很少。图 1 中的示例来自仅有 35 张图像的数据集。在图像增强的帮助下,U-Net 在准确性上比第二好的方法提高了 11%。

图 1:医学图像中的细胞分割(来源:O. Ronneberger, et. al.)
U-Net 也是灵活的。我在自己的研究中应用了它来分割卫星图像。正如图 2 所示,我们将海岸线图像分割成 2 个类别——陆地和水域。这个任务类似,但输入与医学图像不同。我们从单一的灰度图像变为了使用12 个光谱波段的卫星图像。

图 2:海岸线水体分割(来源:作者)(数据集:SWED)(许可证:Sentinel 数据法律声明)
U-Net 架构
因此,U-Net 能够在各种分割任务中取得良好的结果。为了说明原因,我们将查看架构中最重要的组件——编码器、解码器和跳跃连接。我们将看到这些如何结合在一起,以提取和定位图像中的特征。
编码器
对于语义分割,我们关心的是图像中有什么物体以及这些物体在图像中的位置。这与目标检测或图像分类不同。在这里,我们旨在为每张图像预测一个类别。 即我们只关心是否图像中存在某个物体。为了进行这些预测,我们可以使用编码器。
你会在所有 CNN 架构中找到一个编码器的版本。它的工作是创建输入图像的紧凑表示。这是一个低维度的表示,仅包含图像中最重要的信息。换句话说,编码器用于提取特征。

图 3:用于图像分类的编码器(来源:作者)
这通过卷积层和池化层来完成。卷积层是一个映射或卷积核,它会遍历图像中的每一个像素。这个映射通过训练模型的过程学习到。然后,使用预定义的函数,池化层减少了输出的维度。

图 4:卷积层和池化层(来源:作者)
通过组合多个卷积层和池化层,我们可以提取更详细的信息。我们从边缘和颜色等低级细节开始,到耳朵、牙齿和眼睛等高级特征。网络将学习哪些特征对分类很重要,并提取这些特征来创建图像的紧凑表示。
一个问题是这种紧凑表示不包括图像中特征的位置。这对于图像分类来说是可以的。为了分类一只狗,我们只需要知道图像中是否有尾巴、耳朵或毛发。图像中这些特征的位置在哪里并不重要。相比之下,对于分割,位置是重要的。
解码器
编码器的另一个问题是其输出维度较低。如果用于分类,最终层将有几个节点——每个类别一个节点。对于分割,我们的输出将是一个与输入具有相同高度和宽度的图像。
我们需要一个解码器。如图 5 所示,这是从 conv4 块之后开始的部分。解码器将从紧凑的表示中重建图像。与编码器一样,它有卷积块。现在我们有反卷积层来增加图像的维度。

图 5:自编码器架构(来源:作者)
如前所述,池化层将使用预定义的方法来减少维度。例如最大池化,它取单元窗口中的最大像素值。相比之下,上采样或反卷积层使用学习的函数增加维度。即上采样函数会在模型训练时更新。

图 6:使用学习的上采样函数的反卷积(来源:作者)
在自编码器中,输入和输出图像将是相同的。这里解码器的目标是尽可能准确地重建输入。然后我们可以将来自较低维度层(即 conv4)的参数作为压缩图像。我们可以保存或发送压缩图像。然后,解码器可以用来重建原始输入。

图 7:自编码器输出与语义分割模型的对比(包含我的狗,Guinness)(来源:作者)
此时,你可能会问编码器和解码器是否足够。这种架构可以学习从图像到图像的映射。那么它当然可以学习到分割所需的简单输出映射。
解码器能够将重要特征传递给编码器。问题是特征的位置仍然丢失。为了解决这个问题,我们需要大量的数据来训练自编码器。这是解码器能够准确重建压缩表示图像的唯一方法。通过跳跃连接,我们可以减少这个数据需求。
跳跃连接
重要的是,对于自编码器,编码器和解码器必须是分开的。否则,这就违背了图像压缩的整个意义。对于语义分割,我们没有这个限制。
在 U-Net 中,跳跃连接用于将早期卷积层的信息传递到反卷积层。关键是传递的内容是卷积层提取的特征的位置。也就是说,跳跃连接告诉网络特征在图像中的位置。

图 8:U-Net 架构(来源:作者)
这是通过连接卷积块的最后一层和对称反卷积块的第一层完成的。U-Net 是对称的——对面层的维度将是相同的。如图 9 所示,这使得将层合并为一个单一张量变得容易。然后,通过在单一连接张量上运行内核来进行卷积处理。

图 9:连接层(来源:作者)
这种连接是 U-Net 的核心。它结合了两个重要的信息:
-
特征提取 — 特征从前一层传递到上采样层(蓝色)
-
特征定位 — 特征的位置从对面卷积层传递(橙色阴影)
通过结合这些信息,我们可以提高语义模型的性能,并减少训练网络所需的数据量。
我们略过了一些细节,比如激活函数、层数和层的维度。这些都可以作为 U-Net 中的超参数。为了应对特定的分割问题,对原始架构进行了调整。所有这些成功的关键在于跳跃连接。
希望你喜欢这篇文章!你可以通过成为我的 推荐会员 😃 来支持我
[## 通过我的推荐链接加入 Medium — Conor O’Sullivan
作为 Medium 会员,你的部分会员费用会分配给你阅读的作者,你可以完全访问每个故事…
conorosullyds.medium.com](https://conorosullyds.medium.com/membership?source=post_page-----56e4842e313a--------------------------------)
| Twitter | YouTube | Newsletter — 注册获取 免费 访问 Python SHAP 课程
参考文献
Olaf Ronneberger, Philipp Fischer, Thomas Brox, U-Net: Convolutional Networks for Biomedical Image Segmentation (2015)
Bharath K, U-Net 架构用于图像分割 (2021), blog.paperspace.com/unet-architecture-image-segmentation/
Jeremy Zhang, UNet — 一行行解释 (2019), towardsdatascience.com/unet-line-by-line-explanation-9b191c76baf5
Heet Sankesara 的文章《UNet — 引入分割中的对称性》 towardsdatascience.com/u-net-b229b32b4a71
终极 Hive 教程:大数据管理与查询的必备指南
解锁 Hive 的力量:您的深入指南与视觉思维导图洞察
·发表于 Towards Data Science ·阅读时间 8 分钟·2023 年 11 月 10 日
--

作者通过 Obsidian 提供的图像
介绍
导航大数据的迷宫可能是一项艰巨的任务,特别是当这些路径铺满了复杂的术语和繁琐的过程时。这对 Apache Hive 尤其如此,它是大数据生态系统中进行数据管理和查询的重要工具。尽管它的意义重大,但关于 Hive 的清晰而简明的教程资源却很少。这正是我制作“终极 Hive 教程:大数据管理与查询的必备指南”的原因。
本博客旨在破解复杂性,提供一个唯一的、全面的指南,阐明Hive Metastore、Hive 数据模型和细致的metadata世界——所有这些都通过直观的示例和视觉思维导图来呈现。
示例说明
为了演示 Hive 的核心概念,让我们设想一个全球零售连锁公司部署 Hive 来编目和检查其销售交易。此操作的核心是一个名为 sales_db 的主要数据库。在该数据库中,有一个关键的表格 sales_data,旨在系统地记录销售活动。我们将使用这个示例来说明本文中的所有 Hive 相关概念。让我们来看一下这个表格:

作者通过 Excel 提供的图像
什么是 Metadata?
想象一下你发现了一个古老而尘封的图书馆。每本书都包含一个故事,但没有总结内容的目录卡——标题、作者、出版日期——你就会在信息的海洋中迷失方向。元数据类似于这些目录卡,它不是数据本身,而是“数据的描述”——一个描述主数据属性、关系和来源的信息层。在上述 sales_data 表中,元数据包括 列名 —— region_id 、date 、transaction_id 、product_id 、store_id 、sale_price ,以及它们的 数据类型、数据位置 等。
什么是 Hive Metastore?
继续使用我们的图书馆类比,如果元数据是目录卡,那么 Hive Metastore 就是 图书管理员。它仔细 组织 这些卡片,确保每一条数据都有一个位置,每个查询都有一张通向信息宝藏的地图。Hive Metastore 不存放实际的书籍(数据);它存储和管理元数据。 它是策展人,跟踪所有存储在 Hadoop 分布式文件系统(HDFS) 中的内容、每个文件包含什么、如何格式化以及如何分区。
作为数据的守护者,它确保每个查询和数据操作都顺利进行,为用户提供一个清晰的结构,以应对大数据的混乱。Hive metastore 包括两个基本单位:
-
Metastore 服务:一个提供 metastore 访问的服务,供其他 Apache Hive 服务使用。
-
元数据数据库:Hive 元数据的磁盘存储,与 HDFS 存储分开。
什么是 Hive 数据模型.-,Hive%20Data%20Model,-Data%20in%20Hive)?
首先谈谈元数据数据库及其设计——Hive 数据模型,即我们隐喻图书馆架子的蓝图。它定义了 表、分区 和 桶 的结构,这些就像是图书馆中分类和储存元数据的隔间和抽屉。
以全球零售为例,让我们从一个更深入的角度来审视 Hive 数据模型:
-
表:Hive 中的表是数据模型的关键部分,镜像了其关系数据库对应物的结构和功能。
sales_data表就是一个例子,结构化地存储相关的销售指标。 -
桶:Hive 通过引入桶化来增加数据组织的层次,这将数据分配到预定数量的桶中。每个桶是根据指定列的哈希值填充的,从而促进了更均衡的数据分布并提高了查询性能。对于我们的零售连锁店来说,
region_id列可能是一个桶化的候选列,确保销售数据在不同的区域段中均匀分配。 -
分区:为了应对查询大量数据的固有挑战,Hive 实现了分区功能。此功能将表格分割成离散的部分,每个部分对应于唯一的列值。通过分区,针对特定数据子集的查询——例如特定日期的销售情况——可以更快、更高效地执行。对于
sales_data表,通过date列进行分区意味着每个日历日的销售数据被整齐地存储在 Hadoop 分布式文件系统 (HDFS) 的各自子目录中,从而简化了访问和检索。

图片由作者通过 Obsidian 提供
这个精炼的 Hive 数据模型概述概括了其主要组成部分,突显了模型在大数据范围内简化和加速查询过程的能力。通过利用 Hive 的功能,企业可以自信而精准地处理数据存储和分析的复杂性。
图书馆的宝藏:表类型
在 Hive Metastore 中谈到表时,就像图书馆中的多样藏书一样,有两种主要的宝藏:
-
受管表:将这些表视为图书馆拥有、培养和保护的故事。它们存储在图书馆的范围内,并由图书馆直接管理。在 Hive 中,如果删除这些表,数据将被删除。
-
外部表:将这些表视为指引寻求者到其他图书馆的参考卡。虽然它们不在图书馆内,但它们提供了通向更多知识的门户。在 Hive 中,如果删除这些表,数据将保留但与 Hive 断开连接。
为什么选择 Hive 数据模型?
现在,这为什么重要?在大数据的世界里,我们处理的不是简单的体量,而是海量的数据。就像世界上最有效的图书馆一样,我们需要一个系统来管理这种规模。Hive 数据模型使你能够查询庞大的数据集而不会迷失在海洋中。它提供了一个熟悉的关系模型,让你像书中的章节一样分区数据,将类似的主题归为一类,以便快速访问。
考虑到sales_data表以 Parquet 格式 存储:
-
数据存储位置:实际的销售数据(原始数据)存储在 Parquet 文件中,这些文件位于 HDFS 中。这些文件分布在 Hadoop 集群中的多个节点的磁盘上。
-
元数据管理:Hive Metastore 保存有关
sales_data表的元数据。该元数据包括诸如模式(列的名称和类型)、Parquet 文件在 HDFS 中的位置、分区和桶详细信息以及其他表属性的信息。
当提交一个查询时,例如:
Select * from sales_data
where region_id = US
and date >= '2023-10-02';
发生了以下情况:
-
查询执行:当对
sales_data表运行 Hive 查询时,Hive 利用元数据来了解数据的结构,并确定相关 Parquet 文件的位置。对于仅需特定列的查询,Hive 可以有效地从列式 Parquet 文件中读取所需数据,而无需扫描整个数据集。 -
分区与性能:如果
sales_data按照诸如date这样的列进行分区,Hive 会将数据存储在 HDFS 中每个分区的单独子目录中。每个分区目录中的 Parquet 文件仅包含该特定日期的数据。当查询按日期筛选时,Hive 只读取来自相关分区目录的 Parquet 文件,这是关键的性能优化。
实质上,查询只会读取 US 桶下具有对应 date 的文件。虽然 sales_data 表中的数据物理上存储在 Parquet 文件中并位于 HDFS 上,但 Hive 管理着这些数据的结构、查询和处理方式。Hive 用于元数据管理,而 Parquet 用于数据存储的组合,形成了一个强大且高效的大数据管理系统。
基于元数据的优化
Hive 的真正力量,特别是在处理大数据方面,是通过其优化机制显现出来的,这些机制严重依赖于元数据。以下是具体说明:
分区修剪:最重要的优化之一是分区修剪,其中 Hive 使用元数据来识别和访问查询所需的相关分区。例如,如果分析师想要分析第一季度的销售数据,Hive 将使用与日期相关的元数据来跳过所有不在此范围内的分区。这大大减少了读取和处理的数据量,从而加快了查询执行速度。
基于成本的优化的元数据:Hive 还使用元数据进行基于成本的优化(CBO)。通过了解数据统计信息,如行数和数据分布,Hive 可以确定执行查询的最有效方式。它可以决定是否使用索引、是否进行映射端连接而不是减少端连接,或者多表连接的最佳顺序。
列统计的元数据:列统计信息如最小/最大值、空值数量和数据分布,使 Hive 能够做出明智的决策,选择最有效的执行路径。这可能包括跳过不符合查询过滤条件的数据块或选择适合聚合的操作符。
通过这些交互——查询 Metastore 以获取模式和位置细节,并利用元数据进行优化——Hive 提供了一个强大的平台,用于在大规模数据集上执行复杂的分析工作负载。这些优化确保即使数据以指数级增长,Hive 查询仍然高效,使其成为大数据生态系统中的重要工具。
Hive Metastore 的多功能性:超越 Hive 集成
虽然 Hive Metastore 主要与 Apache Hive 相关联,但它的作用远不止于单一服务。Hive Metastore 作为一个中心模式仓库,对于在 Hadoop 生态系统及其他环境中集成各种数据处理工具至关重要。以下是 Hive Metastore 如何作为多个服务的枢纽:
-
更广泛的 Hadoop 生态系统协同:像Apache Spark™和Apache Pig这样的工具利用 Metastore 读取 Hive 表元数据,促进了一个一致的数据处理环境。
-
BI 工具兼容性:BI 应用程序,如Tableau,连接到 Hive Metastore 以可视化和查询 Hive 管理的数据,使得额外的数据洞察更易获取。
-
数据湖治理:像Apache Atlas这样的平台与 Metastore 集成进行数据治理,利用其元数据进行全面的数据血缘追踪和安全管理。
-
模式管理和数据质量:Metastore 对于管理模式演变和确保各应用程序数据质量至关重要,维护数据完整性。
-
跨平台数据访问:Metastore 实现了与云服务的兼容性,允许在不同环境中无缝访问 Hive 数据。
深入探索: Hive 库宇宙的详细思维导图
思维导图作为一种图形方法,用于结构化和描述信息。让我们通过下面的思维导图总结一下关于 Hive Metastore 的讨论:

作者通过 Obsidian 提供的图片
结论
开启大数据之旅,使用 Hive 不必独自一人踏上未知的征途。通过本指南,我希望你能全面了解 Hive。掌握这些概念后,我希望你能充分发挥销售数据及遇到的其他数据的潜力,将分析结果转化为可操作的见解。欢迎加入,祝数据管理愉快!📚💾
ULTRA: 知识图谱推理的基础模型
图形机器学习有什么新进展?
一个模型统治一切
·
关注 发表在 Towards Data Science ·10 分钟阅读·2023 年 11 月 3 日
--
训练一个通用模型以解决任意数据集始终是机器学习研究者的梦想,特别是在基础模型时代。虽然这种梦想在图像或自然语言等感知领域已得以实现,但是否能够在推理领域(如图形)中重现仍然是一个未解的挑战。

图片由作者编辑,源自 DALL-E 3 的输出。
在这篇博客文章中,我们证明了这样一个通用推理模型的存在,至少对于知识图谱(KGs)是如此。我们创建了ULTRA,这是一个单一的预训练推理模型,能够泛化到任意实体和关系词汇的新 KG,这为任何 KG 推理问题提供了默认解决方案。
这篇文章基于我们最近的论文(preprint),由 Xinyu Yuan (Mila), Zhaocheng Zhu (Mila) 和 Bruno Ribeiro (Purdue / Stanford) 共同撰写。关注 Michael、 Xinyu、 Zhaocheng* 和* Bruno 在 Twitter 上获取更多 Graph ML 内容。
大纲
-
Why KG representation learning is stuck in 2018
-
Theory: What makes a model inductive and transferable?
-
Theory: Equivariance in multi-relational graphs
-
ULTRA: A Foundation Model for KG Reasoning
-
Experiments: Best even in the zero-shot inference, Scaling behavior
-
Code, Data, Checkpoints
为什么 KG 表示学习停留在 2018 年
预训练-微调范式自 2018 年起就存在,当时ELMo和ULMFit展示了首次有希望的结果,随后这些成果在BERT和GPT的助力下得到了巩固。
在大型语言模型(LLM)和更通用的基础模型(FMs)时代,我们常常拥有一个(如 GPT-4 或 Llama-2)在大量数据上预训练的单一模型,能够以零样本方式执行各种语言任务(或者至少在特定数据集上进行微调)。如今,多模态 FMs 甚至支持在同一个模型中进行语言、视觉、音频和其他模态的处理。
图形机器学习(Graph ML)中的情况稍有不同。特别是,2023 年底 KG 上的表示学习怎么样了? 这里的主要任务是边级:
-
实体预测(或知识图谱补全)
(h,r,?):给定一个头节点和关系,对图中所有可能成为真实尾节点的节点进行排名。 -
关系预测
(h,?,t):给定两个节点,预测它们之间的关系类型。
事实证明,到目前为止,它一直停留在 2018 年前的阶段。关键问题是:
每个 KG 都有自己的一组实体和关系,没有单一的预训练模型可以适用于任何图。
例如,如果我们查看 Freebase(Google 知识图谱背后的知识图谱)和 Wikidata(最大的开源知识图谱),它们具有完全不同的实体集合(86M 对 100M)和关系(1500 对 6000)。当前的知识图谱表示学习方法是否有希望在一个图上训练并转移到另一个图上?

Freebase 和 Wikidata 的词汇表不同。图片由作者提供。
❌ 像 TransE、ComplEx、RotatE 等传统转导方法,以及其他数百种基于嵌入的方法,从训练图中学习固定的实体和关系类型,甚至无法支持同一图中添加的新节点。浅层嵌入方法无法转移(事实上,我们认为除了某些学生项目练习外,再也没有必要开发这样的技术)。
🟡 像NodePiece和Neural Bellman-Ford Nets这样的归纳实体方法不会学习实体嵌入。相反,它们将训练(已见)和新的推理(未见)节点参数化为固定关系的函数。由于它们只学习关系嵌入,这使得它们能够转移到具有新节点的图中,但转移到具有不同关系的新图(例如从 Freebase 到 Wikidata)仍然超出范围。

相对实体表示使得归纳 GNN 成为可能。图片由作者提供。
如果在推理时同时出现新的实体和关系(一个全新的图)该怎么办?如果你不学习实体或关系嵌入,理论上转移是否可能?那我们就来探讨一下理论吧。
理论:是什么使得模型具有归纳性和可转移性?
让我们更正式地定义这个设置:
-
知识图谱是有向的、多关系的图,具有任意的节点和关系类型集合
-
图形到达时没有特征,即,我们不假设存在实体和关系的文本描述(也不假设有预计算的特征向量)。
-
给定一个查询(头,关系,?),我们希望对底层图(推理图)中的所有节点进行排名,并最大化返回真实尾部的概率。
-
转导设置:训练和推理时节点和实体的集合是相同的。
-
归纳(实体)设置:关系的集合必须在训练时固定,但节点在训练和推理时可以不同
-
归纳(实体和关系)设置:在推理时允许出现新的未见过的实体和关系
神经网络学习什么以能够对新数据进行泛化?主要参考书籍—Geometric Deep Learning by Bronstein, Bruna, Cohen, and Veličković—认为这是一个对称性和不变性的问题。
基础模型中可学习的不变性是什么?LLM 在固定的词汇表上进行训练(子词单元、字节,或如 Lexinvariant LLMs 中所示的随机初始化向量),视觉模型学习投影图像块的函数,音频模型学习投影音频块。
多关系图中的可学习不变性是什么?
首先,我们将介绍标准 齐次 图中的不变性(等变性)。
标准(单一)置换等变图模型: 当早期的 GNN 研究(Scarselli et al. 2008,Xu et al. 2018,Morris et al. 2018)展示出假设顶点 ID 是任意的,图模型的预测不应因重新分配顶点 ID 而改变时,图 ML 取得了重大进展。这被称为神经网络在节点 ID 上的 置换等变性。这一认识激发了极大的兴奋,并产生了大量新颖的图表示方法,只要神经网络对节点 ID 置换是等变的,我们就可以称其为图模型。

单关系图。GNN 对节点置换是等变的:即使在重新标记节点 ID 后,Michael Jackson 的节点向量也会保持相同的值。图像来自作者。
节点 ID 上的置换等变性允许 GNN 以归纳(零-shot)方式将从训练图中学到的模式转移到另一个(不同的)测试图中。这是等变性的结果,因为神经网络不能使用节点 ID 生成嵌入,它必须使用图结构。这就产生了我们所知的 结构表示(见 Srinivasan & Ribeiro (ICLR 2020))。
多关系图中的等变性
现在图中的边可能有不同的关系类型——是否有针对这种图的 GNN 理论?
1️⃣ 在我们之前的工作中,Weisfeiler and Leman Go Relational(与 Pablo Barceló、Christopher Morris 和 Miguel Romero Orth 合作,LoG 2022),我们推导出了 Relational WL —— 一个更侧重于节点级任务的多关系图的 WL 表达层级。黄等人(NeurIPS 2023)的伟大后续工作 将理论扩展到链接预测,形式化了 条件消息传递 和使用 Relational WL 的逻辑表达能力。✍️ 让我们记住 条件消息传递 —— 我们稍后会需要它 —— 它被证明能改善链接预测性能。
提议的全局读取向量的添加,由入边/出边方向诱导,与Emanuele Rossi 等人在均质 MPNNs 中研究方向性的近期工作相似(详细信息请阅读 Medium 上的博客文章)。然而,这些工作没有设想测试时甚至看不到关系的情况。
2️⃣ 双重排列等变(多关系)图模型: 最近,Gao 等人 2023提出了多关系图的双重等变性概念。双重等变性要求神经网络对节点 ID 和关系 ID 的联合排列保持等变。这确保了神经网络学习节点和关系之间的结构模式,从而使其能够归纳性地(零样本)将学习到的模式迁移到具有新节点和新关系的另一个图上。

多关系图中的双重等变性。对节点 ID 和关系 ID 的双重排列不会改变关系结构。因此,输出节点状态应保持相同(但排列不同)。图像由作者提供。
➡️ 在我们的工作中,我们发现了关系交互的不变性,即使关系的身份不同,它们的基本交互仍然保持不变,这些基本交互可以通过关系图来捕捉。在关系图中,每个节点都是来自原始图的关系类型。如果这个图中的两个节点通过原始图中具有这些关系类型的边连接(即,它们共享一个头部或尾部节点),那么这两个节点就会连接在一起。根据这种连接情况,我们在关系图中区分出4 种边类型:
-
头到头(h2h) — 两个关系可以从相同的头部实体开始;
-
尾到头(t2h) — 一个关系的尾部实体可以是另一个关系的头部实体;
-
头到尾(h2t) — 一个关系的头部实体可以是另一个关系的尾部实体;
-
尾到尾(t2t) — 两个关系可以有相同的尾部实体。

原始图中的不同连接模式会在关系图中产生不同的交互。最右侧:示例关系图(为清晰起见省略了反向边)。图像由作者提供。
关系图的一些优点:
-
它可以通过绝对任何多关系图(通过简单的稀疏矩阵乘法)构建。
-
这 4 种基本交互从不改变,因为它们只是编码了基本的拓扑——在有向图中,头部和尾部节点总是存在,而我们关系中会有这些连接模式。
本质上,在关系图上学习表示可以迁移到任何多关系图!这就是可学习的不变性。
实际上,可以证明(我们已经在进行正式证明,这些将在即将发布的工作中提供😉),通过关系图中的交互来表示关系是一个双等变模型!这意味着学习到的关系表示是独立于身份的,而是依赖于关系、节点及节点与关系之间的联合交互。
ULTRA: 一个用于 KG 推理的基础模型
在所有理论基础的支持下,我们现在准备介绍 ULTRA。
ULTRA 是一种统一的、可学习的、可迁移的图表示方法。ULTRA 利用关系图的基本交互的不变性(和等变性),并应用条件消息传递来获得相对关系表示。也许最酷的事实是
一个经过预训练的 ULTRA 模型可以在任何可能的多关系图上进行 0-shot 推理,并在任何图上进行微调。
换句话说,ULTRA 实际上是一个基础模型,可以在任何图输入上进行推理(表现已相当出色),并且可以在任何目标图上进行微调。
ULTRA 的关键组成部分是从关系图中构建的相对关系表示。给定一个查询 (Michael Jackson, genre, ?),我们首先用全 1 向量初始化关系图中的genre节点(所有其他节点初始化为 0)。运行 GNN 后,关系图的节点嵌入以genre节点为条件——这意味着每个起始初始化关系将拥有自己的一组关系特征矩阵,这在许多理论和实际方面都非常有用!

ULTRA 采用相对关系表示(在关系图上的标记技巧),使得每个关系(例如,“类型”)都有其唯一的关系表示矩阵。图片来源:作者。
实际上,给定一个输入 KG 和一个(h, r, ?)查询,ULTRA 执行以下操作:
-
构建关系图;
-
从条件消息传递 GNN 中获取关系特征(以初始化查询关系 r 为条件);
-
使用获得的关系表示来进行条件化的链接预测 GNN,条件为初始化的头节点 h;
步骤 2 和 3 通过对神经贝尔曼-福特网络 (NBFNet)进行略微不同的修改来实现。ULTRA 仅学习 4 种基本交互(h2t, t2t, t2h, h2h)和 GNN 权重——总体上非常小。我们实验的主要模型仅有 177k 参数。

ULTRA 采取的三个主要步骤:(1) 构建关系图;(2) 在关系图上进行条件消息传递以获取相对关系表示;(3) 使用这些表示进行实体级别的归纳链接预测 GNN。图片来源:作者。
实验:在零样本推理和微调中表现最佳
我们在基于 Freebase、Wikidata 和 Wordnet 的 3 个标准知识图谱上预训练了 ULTRA,并在 50 多个其他知识图谱上进行了零样本链接预测,这些知识图谱的规模从 1k 到 120k 节点和 2k 边到 1.1M 边不等。
在已知 SOTA 的数据集上平均,单个预训练的 ULTRA 模型在零样本推理模式下表现优于针对每个图专门训练的现有 SOTA 模型 🚀 微调甚至能使性能提升 10%。特别令人惊讶的是,单个训练的 ULTRA 模型能够适应如此不同规模的图(节点大小差异为 100 倍,边大小差异为 500 倍),而 GNN 通常会遇到规模泛化问题(参见 Yehudai et al, ICML 2021 和 Zhou et al, NeurIPS 2022 的突出工作)。

单个预训练的 ULTRA 在零样本推理模式下的表现优于在特定图上端到端训练的监督 SOTA 模型(参见平均列)。微调进一步提升了性能。图片来源:作者
🙃 实际上,经过 57 个测试图,我们有点用尽了测试 ULTRA 的知识图谱。如果你有新的基准藏在某处——告诉我们!
扩展行为
我们可以通过将更多图谱添加到预训练混合中进一步提升零样本性能,尽管我们确实观察到在训练 4 个以上图谱后性能出现一定的饱和。
Scaling Laws 预言了使用更大的模型在更多优质数据上训练会有更好的表现,因此这绝对在我们的计划之中。

零样本性能随着预训练混合图谱的多样性而增加。图片来源:作者。
结论:代码、数据、检查点
所以,知识图谱推理的基础模型终于来了,我们已经过了 2018 年的门槛!单个预训练的 ULTRA 模型可以对任何领域的知识图谱(多关系图)进行链接预测。你只需一个具有超过 1 种边类型的图谱即可开始。
📈 实际上,ULTRA 在多种知识图谱基准测试中的零样本模式下已经展示了非常有前景的性能,但你可以通过短时间的微调进一步提升性能。
我们在 GitHub 上提供了所有代码、训练数据和预训练模型检查点,以便你可以立即在你的数据上运行 ULTRA!
📜 预印本:arxiv
🛠️ 代码、数据:Githtub repo
🍪 检查点:在 Github repo 中有 2 个检查点(每个 2 MB)
🌎 项目网站:这里
作为总结,KG 推理只是推理领域中许多有趣问题的一部分,大多数问题仍然没有通用的解决方案。我们相信,KG 推理的成功将为其他推理领域带来更多突破(例如,我们最近发现了LLMs 实际上可以学习和运用文本规则)。让我们对推理的未来保持乐观!
UMAP 变异解释
原文:
towardsdatascience.com/umap-variance-explained-b0eacb5b0801
数学统计与生命科学机器学习
解释 UMAP 成分的简单方法
·发表于Towards Data Science ·阅读时间 19 分钟·2023 年 3 月 27 日
--

在MNIST 手写数字黑白图像上计算的 UMAP。作者提供的图像
这是我专栏数学统计与生命科学机器学习的第二十五篇文章,我用简单的语言讨论计算生物学和生命科学中的分析方法。UMAP是一种降维技术,与tSNE一起变得越来越流行,并且实际上成为了分析单细胞基因组学数据的标准工具,而传统方法如PCA 存在局限性。然而,与 PCA 相比,UMAP 和 tSNE 的一个缺点是它们的不可解释成分,难以直接与原始数据的变异性联系起来。在这篇文章中,我建议了一种简单的方法来估计由主要 UMAP 和 tSNE 成分解释的变异量。以经典的MNIST手写数字黑白图像数据集为基准,我展示了主要 UMAP 和 tSNE 成分在解释数据的总体变异性方面劣于 PCA 成分,然而,令人惊讶的是,它们在数据点标签的关联性方面表现更好,即它们解释了数据中更多的生物学而非总变异性。
为分析准备 MNIST 数据
作为基准数据集,我们将使用MNIST,该数据集包含 70,000 张手写数字的黑白图像,28 x 28 像素的分辨率,即每张图像784 像素。首先,我们将下载 MNIST 数据集,检查其维度并可视化一些随机图像。
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version = 1)
labels = mnist.target.astype(int)
print(mnist.data.shape)
#(70000, 784)
from matplotlib import pyplot as plt
plt.figure(figsize = (20, 15))
for i in range(25):
plt.subplot(5, 5, i + 1)
plt.imshow(mnist.data[i].reshape(28, 28), cmap = plt.get_cmap('gray'))
plt.show()

一些随机的 MNIST 手写数字图像。图片由作者提供
MNIST 数据集的值表示像素强度,范围从 0 到 255,其中 0 对应黑色背景。因此,MNIST 是一个“零膨胀”数据集,这与典型的单细胞基因表达数据集非常相似。通常建议通过最大值 255 来归一化黑白像素强度。然而,在这里,类比单细胞基因表达,我们将使用对数变换作为另一种温和的归一化策略。同时,为了节省时间,我们不会使用所有 70,000 张图像,而是随机挑选 10,000 张图像,稍后我们会多次抽取这 10,000 张图像,以确保我们的结论是可靠的。
import numpy as np
N_points = 10000
X = np.log10(mnist.data + 1)
#X = mnist.data / 255
np.random.seed(123)
random_indices = np.random.choice(X.shape[0], size=N_points, replace=False)
X = X[random_indices,:]
labels = labels[random_indices]
作为第一步探索,为了理解 MNIST 中的变异,我们将对 MNIST 及其打乱版本进行 PCA 分析。这将帮助我们估计初始 784 个维度中有多少个有意义的,即非冗余的维度,即我们将确定要保留的信息性维度数量,以供将来所有测试的降维和聚类技术使用。更多详情请见这里。
import pandas as pd; import matplotlib.pyplot as plt
from sklearn.decomposition import PCA; import seaborn as sns
N_pca_comps = 100
sns.set(font_scale = 1.5); figure = plt.figure(figsize = (20, 15))
plt.subplot(221)
X_reduced = PCA(n_components = 2).fit_transform(X)
plt.scatter(X_reduced[:,0], X_reduced[:,1], c=labels, cmap='tab10', s=10)
plt.title('PCA: MNIST', fontsize = 25)
plt.xlabel('PC1', fontsize = 22); plt.ylabel('PC2', fontsize = 22)
plt.subplot(222)
pca = PCA(n_components = N_pca_comps).fit(X)
print('Observed variance explained:')
print(pca.explained_variance_ratio_[0:10]); print('\n')
plt.bar(range(len(pca.explained_variance_ratio_)),
pca.explained_variance_ratio_)
plt.xlabel('Number of Principal Components', fontsize = 22)
plt.ylabel('Explained Variance', fontsize = 22)
N_perm = 10
X_flat = X.flatten()
expl_var_perm_df = pd.DataFrame(index = list(range(N_perm)),
columns = list(range(X.shape[1])))
for i in range(N_perm):
np.random.shuffle(X_flat)
X_perm = X_flat.reshape(X.shape[0], X.shape[1])
pca_perm = PCA().fit(X_perm)
expl_var_perm_df.loc[i] = pca_perm.explained_variance_ratio_
print('Finished {} permutations'.format(i + 1))
X_perm = list(expl_var_perm_df.mean(axis = 0) +
2*expl_var_perm_df.std(axis = 0))
print('\nPermuted variance explained:')
print(X_perm[0:10])
plt.subplot(223)
plt.plot(pca.explained_variance_ratio_, c = 'blue')
plt.plot(X_perm, c = 'red'); plt.xlim([-1, N_pca_comps])
plt.xlabel('Number of Principal Components', fontsize = 22)
plt.ylabel('Explained Variance', fontsize = 22)
plt.gca().legend(('Observed variance explained',
'Permuted variance explained'), fontsize = 20)
plt.subplot(224)
pval = list()
for j in range(N_pca_comps):
pval.append(np.sum(expl_var_perm_df.iloc[:, j] +
2*expl_var_perm_df.std(axis = 0) >=
pca.explained_variance_ratio_[j]) / N_perm)
plt.plot(pval, c = 'darkgreen')
plt.xlabel('Number of Principal Components', fontsize = 22)
plt.ylabel('P-value', fontsize = 22); plt.xlim([-1, N_pca_comps])
N_opt_pcs = np.where(np.array(pval) >= 0.05)[0][0]
print('\nNumber of significant Principal Components: {}'.format(N_opt_pcs))
print('Together they explain {}% of variation in the data'.\
format(int(round(sum(pca.explained_variance_ratio_[0:\
np.where(np.array(pval) >= 0.05)[0][0]])*100,0))))
figure.tight_layout()
plt.show()

PCA 2D 地图、碎石图和诊断图,说明 MNIST 手写数字数据集中信息性主要成分的数量。图片由作者提供
对 MNIST 和打乱后的 MNIST 进行 PCA 分析的重要结果是,数据集似乎有62 个信息性主要成分,这些成分总共捕捉了 86%的数据变异。因此,如果我们用 62 个 PC 替代原始包含 784 个特征的数据集,我们将保留数据中的大部分变异,但数据的维度将减少超过 10 倍。
现在,让我们使用 UMAP 和 tSNE 降维技术将 MNIST 数据集可视化为 2D 地图。在这两种情况下,我们将使用 2 个主要成分,即 2 个 PC 进行初始化。tSNE 的困惑度超参数和 UMAP 的最近邻数量将计算为数据点(图像)数量的平方根。详情请见这里。
import umap; import numpy as np
import seaborn as sns; import matplotlib.pyplot as plt
from sklearn.manifold import TSNE; from sklearn.decomposition import PCA
opt_perp = np.int(np.round(np.sqrt(X.shape[0]), 0))
X_reduced = PCA(n_components = N_opt_pcs).fit_transform(X)
umap_embedding = umap.UMAP(n_components = 2, n_neighbors = opt_perp,
init = X_reduced[:, 0:2],
min_dist=0.3, n_epochs = 1000, random_state = 123,
verbose = 0).fit_transform(X_reduced)
tsne_embedding = TSNE(n_components=2, perplexity=opt_perp,
init=X_reduced[:, 0:2],
learning_rate = 200, n_iter = 1000, random_state = 123,
verbose = 0).fit_transform(X_reduced)
sns.set(font_scale = 1.5); plt.figure(figsize = (20, 10))
plt.subplot(121)
plt.scatter(tsne_embedding[:, 0], tsne_embedding[:, 1], c = labels, s = 10,
cmap = 'tab20')
plt.title('tSNE: MNIST', fontsize = 25)
plt.xlabel("tSNE1", fontsize = 22); plt.ylabel("tSNE2", fontsize = 22)
plt.subplot(122)
plt.scatter(umap_embedding[:, 0], umap_embedding[:, 1], c = labels, s = 10,
cmap = 'tab20')
plt.title('UMAP: MNIST', fontsize = 25)
plt.xlabel("UMAP1", fontsize = 22); plt.ylabel("UMAP2", fontsize = 22)
plt.show()

tSNE 和 UMAP 的 MNIST 手写数字数据二维地图。图片由作者提供
从 2D PCA、tSNE 和 UMAP 图像的比较中可以得出一个明显的结论,即手写数字的类别在非线性邻域图降维方法(如 tSNE / UMAP)中比在线性 矩阵分解技术(如 PCA)中得到的分辨率要高得多。因此,如果我们假设手写数字的类别(即生物表型,例如 scRNAseq 实验中的细胞类型)占据了 MNIST 数据中大部分的变异,那么可以合理地假设两个 tSNE/UMAP 组件能够捕捉到比两个 PCA 组件更多的生物变异(至少在 MNIST 数据集中)。在接下来的章节中,我们将尝试量化和证明这一假设。
PCA 组件解释的 MNIST 变异
作为非线性邻域图技术,UMAP / tSNE 似乎没有其组件所解释的变异概念,与矩阵分解线性降维技术(如 PCA)形成对比,例如,请参见 UMAP 作者 Leland McInnes 的回答。不过,在这里,我们将尝试量化 UMAP / tSNE 组件解释的 MNIST 像素强度变异量,并与 PCA 组件解释的变异量进行比较,使用部分最小二乘(PLS)回归和R 方统计量来推广矩阵运算。
让我们检查 MNIST 中由第一个主成分(PC1)解释的变异百分比,我们可以很容易地通过 PCA 计算并提取它作为 MNIST 的标准化第一个特征值。如下所示,PC1 解释了约 11%的 MNIST 变异。
#First, we will compute variance explained by PC1 via PCA in sklearn
import numpy as np
from sklearn.decomposition import PCA
pca = PCA(n_components = X.shape[1]).fit(X)
pca_comps = PCA().fit_transform(X)
print(pca.explained_variance_ratio_[0])
#0.11043073983593521
现在,我们将 通过 部分最小二乘(PLS) 回归重新生成这个数字。我们将使用以下推理。假设我们想要用另一个矩阵 PCA_matrix 来近似矩阵 X,目前只包括一个列,即 PC1,但通常可以包括最多 784 列的 MNIST 数据集。然后,我们可以拟合一个 PLS 线性回归模型X = B * PCA_matrix并计算一个R 方统计量,这将反映X中由PCA_matrix解释的变异量。在矩阵形式中,R 方统计量将由以下方程给出:

其中B*PCA_matrix表示从其近似PCA_matrix对X的预测,并将通过 PLS 回归(第一方程)得出。为了在scikit-learn Python 模块中技术性地实现这一过程,我们首先拟合一个以X为响应变量、PCA_matrix为解释变量的 PLS 模型,我们计算预测值y_pred = B*PCA_matrix,最后,我们可以使用 scikit-learn 中的r2_score函数,或使用上述第二方程来计算R-squared 统计量。让我们检查一下两者是否会给出相同的答案:
#Now let us compute variance explained by PC1 through PLS procedure
from sklearn.metrics import r2_score
from sklearn.cross_decomposition import PLSRegression
PCA_matrix = pd.DataFrame(pca_comps[:, 0:1])
pls = PLSRegression(n_components = 1)
pls.fit(PCA_matrix, X)
y_pred = pls.predict(PCA_matrix)
print(r2_score(X, y_pred, multioutput = 'variance_weighted'))
#0.11043073983593246
#Finally, let us compute variance explained by PC1 from scratch
print(1 - np.sum((np.array(X) - np.array(y_pred))**2) / np.sum((X - \
np.mean(X, axis = 0))**2))
#0.11043073983593554
因此,通过利用 PLS 回归,我们计算了第一主成分解释的 MNIST 总变异 在 PCA 算法之外的比例。将 PLS 计算的比例与 PCA 算法中 pca.explained_variance_ratio_[0](即第一特征值与所有特征值之和的比率)解释的方差进行比较。它们几乎是相同的。现在让PCA_matrix包含若干个主成分。下面,我们演示 PLS 计算主成分解释的累计方差将与 PCA 累计解释的方差几乎相同。
from sklearn.metrics import r2_score
from sklearn.cross_decomposition import PLSRegression
predicted_var_expl = []
for i in range(1, 21):
PCA_matrix_current = pd.DataFrame(pca_comps[:, 0:i])
pls_current = PLSRegression(n_components = i)
pls_current.fit(PCA_matrix_current, X)
y_pred_current = pls_current.predict(PCA_matrix_current)
predicted_var_expl.append(r2_score(X, y_pred_current,
multioutput = 'variance_weighted'))
import seaborn as sns
import matplotlib.pyplot as plt
sns.set(font_scale = 1.5); plt.figure(figsize = (20, 15))
plt.plot(np.cumsum(pca.explained_variance_ratio_[0:100]),linewidth=5)
plt.plot(predicted_var_expl, linewidth = 5)
plt.ylabel('Cumulative Explained Variance', fontsize = 20)
plt.xlabel('Number of Principal Components', fontsize = 20)
plt.legend(['PCA - computed cumulative variance explained',
'PLS - computed cumulative variance explained'],
fontsize = 20); plt.show()

使用 PCA 算法和基于 PLS 的方法对 MNIST 图像像素强度数据解释的累计方差进行比较。图片由作者提供
为了清晰起见,我们将 PLS 计算的 MNIST 方差(PC1-PC20)绘制在原生 PCA 特征值基础上解释的方差(PC1-PC100)之上。我们可以看到,PLS 和 PCA 计算的累计方差曲线很好地重合。这并不令人惊讶,因为 PLS 计算本质上模拟了 PCA 算法的内部过程。然而,非常重要的是,这为我们提供了计算X通过任何近似(不仅是PCA_matrix,还包括其他任何矩阵)解释的方差的工具。在下一节中,我们将使用UMAP_matrix作为 MNIST 像素强度的原始X矩阵的近似,并比较 UMAP 和 PCA 组件解释的 X 累计方差。
MNIST 由 UMAP 组件解释的变异
在本节中,我们将使用前一节中开发的 PLS 估计解释方差的方法。同样,我们将使用相同的推理:由于 UMAP 提供了对原始数据X的某种近似,并且人们甚至在 2D UMAP 上运行聚类以发现 scRNAseq 领域中的细胞类型,我们可以将UMAP_matrix(对于 tSNE 也是如此)作为X的近似,然后拟合 PLS 模型X=B * UMAP_matrix,并通过 R-squared 统计量估计 UMAP 组件解释的 MNIST 方差的比例。

现在,我们将使用在前一部分开发的 PLS 方法,计算第一个 UMAP 组件解释的 MNIST 方差比例。
from sklearn.metrics import r2_score
from sklearn.cross_decomposition import PLSRegression
#MNIST variation explained by UMAP1
UMAP_matrix = pd.DataFrame(umap_embedding[:, 0:1])
pls = PLSRegression(n_components = 1)
pls.fit(UMAP_matrix, X)
y_pred = pls.predict(UMAP_matrix)
print(r2_score(X, y_pred, multioutput = 'variance_weighted'))
#0.07335034485651613
#Here the same but more explicitly via the R² equation above
print(1 - np.sum((np.array(X) - np.array(y_pred))**2) / np.sum((X - \
np.mean(X, axis = 0))**2))
#0.07335034485652026
#MNIST variation explained by tSNE1
tSNE_matrix = pd.DataFrame(tsne_embedding[:, 0:1])
pls = PLSRegression(n_components = 1)
pls.fit(tSNE_matrix, X)
y_pred = pls.predict(tSNE_matrix)
print(r2_score(X, y_pred, multioutput = 'variance_weighted'))
#0.07265918428990921
#Here the same but explicitly via the R² equation above
print(1 - np.sum((np.array(X) - np.array(y_pred))**2) / np.sum((X - \
np.mean(X, axis = 0))**2))
#0.07265918428991347
我们得出结论,UMAP 和 tSNE 的前几个组件解释了约 7%的 MNIST 变异,这低于第一个 PCA 组件解释的 11%。这并不令人惊讶。直观上,很难期待有另一个潜在变量替代 PC1,解释超过 11%的 MNIST 变异而不重新定义“解释方差”的概念。当前的定义来自 PCA(标准化特征值)和线性回归,即 R 平方分析。两者都是线性框架。此外, PC1 根据定义对应于数据变异的最大方向。因此,如果另一个潜在变量来自非线性分析如 UMAP,这并不本身旨在最大化数据的变异,那么很难期待例如 UMAP1 在 MNIST 中解释更多变异 在“解释方差”的线性定义下。现在我们将计算几个顶级 tSNE 和 UMAP 组件解释的方差。为了使我们的分析对 10,000 张图像的采样更具稳健性,我们将独立抽取它们N_iter次。此外,对于每次迭代,我们将稍微改变/ 扰动 UMAP 和 tSNE 的超参数,以及 PCA 的数据点数量。这将使我们能够建立置信区间并解决我们分析的敏感性。
import umap; import numpy as np; import seaborn as sns
import matplotlib.pyplot as plt; from sklearn.manifold import TSNE
from sklearn.metrics import r2_score; from sklearn.decomposition import PCA
from sklearn.cross_decomposition import PLSRegression
N_iter = 3; N_comps = 3
N_points_list = [5000, 10000, 15000]
perp_list = [50, 100, 150]; min_dist_list = [0.1, 0.2, 0.3]
predicted_var_expl_matrix = np.zeros(shape = (N_iter, N_comps))
predicted_var_expl_umap_matrix = np.zeros(shape = (N_iter, N_comps))
predicted_var_expl_tsne_matrix = np.zeros(shape = (N_iter, N_comps))
for j in range(N_iter):
#MNIST variance explained by PCA components
np.random.seed(j)
X = np.log10(mnist.data + 1); labels = mnist.target.astype(int)
random_indices = np.random.choice(X.shape[0], size = N_points_list[j],
replace = False)
X_sample = X[random_indices,:]; labels_sample = labels[random_indices]
pca_comps_sample = PCA(n_components = N_comps).fit_transform(X_sample)
predicted_var_expl = []
for i in range(1, (N_comps + 1)):
PCA_matrix_current = pd.DataFrame(pca_comps_sample[:, 0:i])
pls_current = PLSRegression(n_components = i)
pls_current.fit(PCA_matrix_current, X_sample)
y_pred_current = pls_current.predict(PCA_matrix_current)
predicted_var_expl.append(r2_score(X_sample, y_pred_current,
multioutput='variance_weighted'))
predicted_var_expl_matrix[j,:] = predicted_var_expl
#MNIST variance explained by UMAP components
X = np.log10(mnist.data + 1); labels = mnist.target.astype(int)
random_indices = np.random.choice(X.shape[0], size = N_points,
replace = False)
X_sample = X[random_indices,:]; labels_sample = labels[random_indices]
opt_perp = np.int(np.round(np.sqrt(X_sample.shape[0]), 0))
X_reduced_sample = PCA(n_components = N_opt_pcs).fit_transform(X_sample)
umap_embedding_sample = umap.UMAP(n_components = N_comps,
n_neighbors = opt_perp,
init = X_reduced_sample[:, 0:N_comps],
min_dist = min_dist_list[j],
n_epochs = 1000, verbose = \
0).fit_transform(X_reduced_sample)
predicted_var_expl_umap = []
for i in range(1, (N_comps + 1)):
UMAP_matrix_current = pd.DataFrame(umap_embedding_sample[:, 0:i])
pls_current = PLSRegression(n_components = i)
pls_current.fit(UMAP_matrix_current, X_sample)
y_pred_current = pls_current.predict(UMAP_matrix_current)
predicted_var_expl_umap.append(r2_score(X_sample, y_pred_current, \
multioutput = 'variance_weighted'))
predicted_var_expl_umap_matrix[j,:] = predicted_var_expl_umap
#MNIST variance explained by tSNE components
X = np.log10(mnist.data + 1); labels = mnist.target.astype(int)
random_indices = np.random.choice(X.shape[0], size = N_points,
replace = False)
X_sample = X[random_indices,:]; labels_sample = labels[random_indices]
X_reduced_sample = PCA(n_components = N_opt_pcs).fit_transform(X_sample)
tsne_embedding_sample = TSNE(n_components = N_comps,
perplexity = perp_list[j],
init = X_reduced_sample[:, 0:N_comps],
learning_rate = 200, n_iter = 1000,
verbose = 0).fit_transform(X_reduced_sample)
predicted_var_expl_tsne = []
for i in range(1, (N_comps + 1)):
tSNE_matrix_current = pd.DataFrame(tsne_embedding_sample[:, 0:i])
pls_current = PLSRegression(n_components = i)
pls_current.fit(tSNE_matrix_current, X_sample)
y_pred_current = pls_current.predict(tSNE_matrix_current)
predicted_var_expl_tsne.append(r2_score(X_sample, y_pred_current, \
multioutput = 'variance_weighted'))
predicted_var_expl_tsne_matrix[j,:] = predicted_var_expl_tsne
print("MNIST variance explained by PCA components:")
print(predicted_var_expl_matrix)
print("\nMNIST variance explained by UMAP components:")
print(predicted_var_expl_umap_matrix)
print("\nMNIST variance explained by tSNE components:")
print(predicted_var_expl_tsne_matrix)
#Plot MNIST variance explained by leading PCA, tSNE and UMAP components
sns.set(font_scale = 1.5); plt.figure(figsize = (20, 15))
plt.errorbar(range(1, (N_comps + 1)), np.mean(predicted_var_expl_matrix,
axis = 0), yerr = 2*np.std(predicted_var_expl_matrix, axis = 0),
linewidth = 3, color = 'red', marker = 'o', markersize = 10,
capsize = 5, capthick = 3)
plt.errorbar(range(1, (N_comps + 1)), np.mean(predicted_var_expl_tsne_matrix,
axis = 0), yerr = 2*np.std(predicted_var_expl_tsne_matrix,
axis = 0), linewidth = 3, color = 'blue', marker = 'o',
markersize = 10, capsize = 5, capthick = 3)
plt.errorbar(range(1, (N_comps + 1)), np.mean(predicted_var_expl_umap_matrix,
axis = 0), yerr = 2*np.std(predicted_var_expl_umap_matrix,
axis = 0), linewidth = 3, color = 'green', marker = 'o',
markersize = 10, capsize = 5, capthick = 3)
plt.ylabel('Cumulative MNIST Explained Variance', fontsize = 20)
plt.xlabel('Number of Components', fontsize = 20)
plt.legend(['PCA variance explained', 'tSNE variance explained',
'UMAP variance explained'], fontsize = 20)
plt.xlim([0.8, (N_comps + 0.2)]); plt.xticks([1, 2, 3]); plt.show()

MNIST 图像像素强度数据的累积方差由 PCA、tSNE 和 UMAP 组件解释。图片由作者提供
在这里,我们得出结论,主成分 PCA 在 MNIST 数据集中的解释变异量始终大于主成分 tSNE 和 UMAP。这个结果是预期的,因为 UMAP 和 tSNE 并不旨在建立最大变异方向,而 PCA 则有此目标。此外,主成分 tSNE 和 UMAP 在 MNIST 数据中解释的变异量是相当的,并且都低于 PCA 组件解释的变异量。因此,我们似乎观察到 MNIST 解释方差在矩阵分解和邻接图降维技术之间的系统性差异。
生物学变异由 UMAP 组件解释
在上一部分中,我们展示了 UMAP 和 tSNE 的主成分在 MNIST 数据集中解释的变异量明显少于主成分 PCA,这是一种预期的结果,因为 PCA 根据定义旨在寻找数据中的最大变异成分,而 UMAP 和 tSNE 则不是。
然而,这确实是一个有趣的悖论:通过查看 MNIST 数据的 2D UMAP 和 tSNE 图,我们可以明显看到更多的明显的手写数字簇,相比于对应的 2D PCA 图。尽管 2D UMAP 和 tSNE 解释的 MNIST 方差 比 PCA 少。
在这里,我们将假设 UMAP / tSNE 组件与感兴趣的表型有关,即 MNIST 标签,或者 scRNAseq 的细胞类型,而不是数据的总变异。为了检验这一假设,让我们首先探索主导 PCA、UMAP 和 tSNE 组件如何与 MNIST 标签相关。
import umap; import seaborn as sns
from sklearn.datasets import fetch_openml
from sklearn.decomposition import PCA; import numpy as np
import matplotlib.pyplot as plt; from sklearn.manifold import TSNE
N_comps = 3; N_points = 10000
mnist = fetch_openml('mnist_784', version = 1)
labels = mnist.target.astype(int); X = np.log10(mnist.data + 1)
#Subsample MNIST data down to 10 000 images
np.random.seed(123)
random_indices = np.random.choice(X.shape[0], size=N_points, replace=False)
X = X[random_indices,:]; labels = labels[random_indices]
#Compute top 3 PCA components on the subsampled MNIST data
pca = PCA(n_components = N_comps).fit(X)
pca_comps = PCA().fit_transform(X)
#Compute top 3 UMAP components on the subsampled MNIST data
opt_perp = np.int(np.round(np.sqrt(X.shape[0]), 0))
X_reduced = PCA(n_components = N_opt_pcs).fit_transform(X)
umap_embedding = umap.UMAP(n_components = N_comps, n_neighbors = opt_perp,
init = X_reduced[:, 0:N_comps],
min_dist = 0.3, n_epochs = 1000, random_state=123,
verbose = 0).fit_transform(X_reduced)
#Compute top 3 tSNE components on the subsampled MNIST data
tsne_embedding = TSNE(n_components = N_comps, perplexity = opt_perp,
init = X_reduced[:, 0:N_comps],
learning_rate = 200, n_iter = 1000, random_state = 123,
verbose = 0).fit_transform(X_reduced)
#Pairwise correlations between MNIST labels and PCA / tSNE / UMAP components
from scipy.stats import spearmanr
rho_matrix = np.zeros(shape = (3, 3))
for i in range(N_comps):
rho, pval = spearmanr(pca_comps[:, i], labels)
rho_matrix[0, i] = np.abs(rho)
for i in range(N_comps):
rho, pval = spearmanr(tsne_embedding[:, i], labels)
rho_matrix[1, i] = np.abs(rho)
for i in range(N_comps):
rho, pval = spearmanr(umap_embedding[:, i], labels)
rho_matrix[2, i] = np.abs(rho)
rho_df = pd.DataFrame(rho_matrix, columns = ['Comp1', 'Comp2', 'Comp3'],
index = ['PCA', 'tSNE', 'UMAP'])
sns.set(font_scale = 1.5); plt.figure(figsize = (20, 10))
sns.heatmap(rho_df.T, cmap = "Blues", annot = True)
plt.title('Spearman correlation of leading PCA / tSNE / UMAP components'+
' with MNIST labels', fontsize = 20); plt.show()

MNIST 标签与 PCA / tSNE / UMAP 组件之间的配对斯皮尔曼相关性。图像由作者提供
我们可以观察到,tSNE1 和 UMAP1 与 MNIST 标签的相关性比 PC1 更强。此外,PCA、tSNE 和 UMAP 的第二个组件与 MNIST 标签的相关性强度相当,而第三个 UMAP 组件与 MNIST 标签的相关性远远强于 PCA 和 tSNE 的第三个组件,即 PC3 和 tSNE3,这两个组件与 MNIST 标签几乎没有相关性。在我的 github 上提供的完整笔记本中,我展示了得到的热图在采样和 PCA / tSNE / UMAP 超参数扰动方面是稳健的。
另外,我们还可以通过前几节开发的 PLS 回归 方法量化 PCA、tSNE 和 UMAP 组件与 MNIST 标签(数字类,即 scRNAseq 数据中的细胞类型)之间的联系。通过类比,如果我们将 UMAP_matrix(或 PCA_matrix 或 tSNE_matrix)视为 MNIST 标签向量的近似值,我们可以拟合 PLS 模型 labels= B * UMAP_matrix,并估算UMAP 组件解释的 MNIST 标签变异 的比例。

让我们测试 PC1、tSNE1 和 UMAP1 解释了多少 MNIST 标签的方差:
from sklearn.metrics import r2_score
from sklearn.cross_decomposition import PLSRegression
#Variance in MNIST labels explained by PC1
pca = PCA(n_components = X.shape[1]).fit(X)
pca_comps = PCA().fit_transform(X)
PCA_matrix = pd.DataFrame(pca_comps[:, 0:1])
pls = PLSRegression(n_components = 1)
pls.fit(PCA_matrix, labels)
y_pred = pls.predict(PCA_matrix)
r2_score(labels, y_pred, multioutput = 'variance_weighted')
#0.01570798235844606
#Variance in MNIST labels explained by tSNE1
tSNE_matrix = pd.DataFrame(tsne_embedding[:, 0:1])
pls = PLSRegression(n_components = 1)
pls.fit(tSNE_matrix, labels)
y_pred = pls.predict(tSNE_matrix)
r2_score(labels, y_pred, multioutput = 'variance_weighted')
#0.04531724676013893
#Variance in MNIST labels explained by UMAP1
UMAP_matrix = pd.DataFrame(umap_embedding[:, 0:1])
pls = PLSRegression(n_components = 1)
pls.fit(UMAP_matrix, labels)
y_pred = pls.predict(UMAP_matrix)
r2_score(labels, y_pred, multioutput = 'variance_weighted')
#0.1369512765129124
我们可以清楚地看到,tSNE1,特别是 UMAP1,比 PC1 解释了更多的 MNIST 标签变异,这证实了我们之前在上面的热图中通过斯皮尔曼相关性观察到的结果。我们还可以从 statsmodels 包中的线性回归模型中复现上述的R-squared 值。
import statsmodels.formula.api as smf
my_df_comps = pd.DataFrame({'LABELS': labels,
'PC1': np.array(PCA_matrix).flatten(),
'tSNE1': np.array(tSNE_matrix).flatten(),
'UMAP1': np.array(UMAP_matrix).flatten()})
smf.ols(formula = 'LABELS ~ PC1', data = my_df_comps).fit().summary()
smf.ols(formula = 'LABELS ~ tSNE1', data = my_df_comps).fit().summary()
smf.ols(formula = 'LABELS ~ UMAP1', data = my_df_comps).fit().summary()

现在,一旦我们确保 可以正确计算 由 PC1、tSNE1 和 UMAP1 解释的 MNIST 标签方差,通过使用 PLS 和 R-squared 方法,我们可以将这一过程扩展到其他主要的 PCA / tSNE / UMAP 组件,并可视化累积方差解释如何随着更多组件的加入而变化。像往常一样,我们将对 PCA、tSNE 和 UMAP 进行子采样迭代、图像数量和方法超参数的扰动。
from sklearn.decomposition import PCA
from sklearn.cross_decomposition import PLSRegression
import seaborn as sns; import matplotlib.pyplot as plt
import numpy as np; from sklearn.metrics import r2_score
N_iter = 3; N_comps = 3
N_points_list = [5000, 10000, 15000]
min_dist_list = [0.1, 0.2, 0.3]; perp_list = [90, 100, 110]
predicted_var_expl_labels_matrix = np.zeros(shape = (N_iter, N_comps))
predicted_var_expl_labels_umap_matrix=np.zeros(shape=(N_iter, N_comps))
predicted_var_expl_labels_tsne_matrix=np.zeros(shape=(N_iter, N_comps))
for j in range(N_iter):
#Variance in MNIST labels explained by PCA components
np.random.seed(j)
X = np.log10(mnist.data + 1); labels = mnist.target.astype(int)
random_indices = np.random.choice(X.shape[0], size = N_points_list[j],
replace = False)
X_sample = X[random_indices,:]; labels_sample = labels[random_indices]
pca_comps_sample = PCA(n_components = N_comps).fit_transform(X_sample)
predicted_var_expl_labels = []
for i in range(1, (N_comps + 1)):
PCA_matrix_current_labels = pd.DataFrame(pca_comps_sample[:, 0:i])
pls_current_labels = PLSRegression(n_components = i)
pls_current_labels.fit(PCA_matrix_current_labels, labels_sample)
y_pred_current_labels = pls_current_labels.predict(\
PCA_matrix_current_labels)
predicted_var_expl_labels.append(r2_score(labels_sample, \
y_pred_current_labels, multioutput = 'variance_weighted'))
predicted_var_expl_labels_matrix[j,:] = predicted_var_expl_labels
#Variance in MNIST labels explained by UMAP components
X = np.log10(mnist.data + 1); labels = mnist.target.astype(int)
random_indices = np.random.choice(X.shape[0], size = N_points,
replace = False)
X_sample = X[random_indices,:]; labels_sample = labels[random_indices]
opt_perp = np.int(np.round(np.sqrt(X_sample.shape[0]), 0))
X_reduced_sample = PCA(n_components=N_opt_pcs).fit_transform(X_sample)
umap_embedding_sample = umap.UMAP(n_components = N_comps,
n_neighbors = opt_perp,
init=X_reduced_sample[:, 0:N_comps],
min_dist = min_dist_list[j],
n_epochs = 1000, verbose = \
0).fit_transform(X_reduced_sample)
predicted_var_expl_labels_umap = []
for i in range(1, (N_comps + 1)):
UMAP_matrix_current_labels=pd.DataFrame(umap_embedding_sample[:,0:i])
pls_current_labels_umap = PLSRegression(n_components = i)
pls_current_labels_umap.fit(UMAP_matrix_current_labels,labels_sample)
y_pred_current_labels_umap = pls_current_labels_umap.predict(\
UMAP_matrix_current_labels)
predicted_var_expl_labels_umap.append(r2_score(\
labels_sample, y_pred_current_labels_umap, \
multioutput = 'variance_weighted'))
predicted_var_expl_labels_umap_matrix[j,:]=predicted_var_expl_labels_umap
#Variance in MNIST labels explained by tSNE components
X = np.log10(mnist.data + 1); labels = mnist.target.astype(int)
random_indices = np.random.choice(X.shape[0], size = N_points,
replace = False)
X_sample = X[random_indices,:]; labels_sample = labels[random_indices]
opt_perp = np.int(np.round(np.sqrt(X_sample.shape[0]), 0))
X_reduced_sample = PCA(n_components = N_opt_pcs).fit_transform(X_sample)
tsne_embedding_sample = TSNE(n_components = N_comps,
perplexity = perp_list[j],
init = X_reduced_sample[:, 0:N_comps],
learning_rate = 200, n_iter = 1000,
verbose=0).fit_transform(X_reduced_sample)
predicted_var_expl_labels_tsne = []
for i in range(1, (N_comps + 1)):
tSNE_matrix_current_labels=pd.DataFrame(tsne_embedding_sample[:,0:i])
pls_current_labels_tsne = PLSRegression(n_components = i)
pls_current_labels_tsne.fit(tSNE_matrix_current_labels,labels_sample)
y_pred_current_labels_tsne = pls_current_labels_tsne.predict(\
tSNE_matrix_current_labels)
predicted_var_expl_labels_tsne.append(r2_score(labels_sample, \
y_pred_current_labels_tsne, multioutput = 'variance_weighted'))
predicted_var_expl_labels_tsne_matrix[j,:]=predicted_var_expl_labels_tsne
print("Variance in MNIST labels explained by PCA components:")
print(predicted_var_expl_labels_matrix)
print("\nVariance in MNIST labels explained by UMAP components:")
print(predicted_var_expl_labels_umap_matrix)
print("\nVariance in MNIST labels explained by tSNE components:")
print(predicted_var_expl_labels_tsne_matrix)
#Plot MNIST labels variance explained by PCA, tSNE and UMAP components
sns.set(font_scale = 1.5); plt.figure(figsize = (20, 15))
plt.errorbar(range(1, (N_comps + 1)),
np.mean(predicted_var_expl_labels_matrix, axis = 0),
yerr = 2*np.std(predicted_var_expl_labels_matrix, axis = 0),
linewidth = 3, color = 'red', marker = 'o', markersize = 10,
capsize = 5, capthick = 3)
plt.errorbar(range(1, (N_comps + 1)),
np.mean(predicted_var_expl_labels_tsne_matrix, axis = 0),
yerr=2*np.std(predicted_var_expl_labels_tsne_matrix,axis=0),
linewidth = 3, color = 'blue', marker = 'o', markersize = 10,
capsize = 5, capthick = 3)
plt.errorbar(range(1, (N_comps + 1)),
np.mean(predicted_var_expl_labels_umap_matrix, axis = 0),
yerr=2*np.std(predicted_var_expl_labels_umap_matrix,axis=0),
linewidth = 3, color = 'green', marker = 'o', markersize = 10,
capsize = 5, capthick = 3)
plt.ylabel('Cumulative MNIST Labels Explained Variance', fontsize = 20)
plt.xlabel('Number of Components', fontsize = 20)
plt.legend(['PCA variance explained', 'tSNE variance explained',
'UMAP variance explained'], fontsize = 20)
plt.xlim([0.8, (N_comps + 0.2)]); plt.xticks([1, 2, 3]); plt.show()

PCA、tSNE 和 UMAP 组件解释的 MNIST 标签的累积方差。图像由作者提供
在这里, 令人惊讶的是,我们观察到 领先的 UMAP 和 tSNE 组件似乎比领先的 PCA 组件解释了 MNIST 标签中更多的变化。这在某种程度上是违反直觉的,因为我们在前一节中看到领先的 tSNE 和 UMAP 组件解释了较少的 MNIST 图像像素强度变化。然而,上面的累计解释方差图基本上确认了 MNIST 标签与 PCA / tSNE / UMAP 组件之间的相关性热图。
在这里,我们观察到一个有趣的现象。我们在前一节中看到,领先的 tSNE 和 UMAP 组件无法捕捉比领先的 PCA 组件更多的 MNIST 变化。然而,令人惊讶的是,尽管在进行降维时这些类没有提供给 tSNE / UMAP,它们却能够捕捉 MNIST 标签,即手写数字的类别中的更多变化。换句话说,这三种降维技术都是完全无监督的,即它们对手写数字的类别一无所知。尽管如此,tSNE,尤其是 UMAP 组件,似乎与数字类别惊人地相关,而无法捕捉 MNIST 图像中像素强度的实质性变化。我不完全理解这个现象,但认为这是一个有趣的观察,我很想进一步探索。如果你对这种现象的本质有更多见解,请在评论中告诉我。
摘要
在这篇文章中,我们了解到偏最小二乘(PLS) 方法可以用来建立对 tSNE 和 UMAP 组件解释的数据变化的直觉。使用这种方法,我们展示了tSNE 和 UMAP 组件解释(并不出奇)比 PCA 组件少的 MNIST 图像像素强度变化,然而,令人惊讶的是,它们与MNIST 图像的标签****有很强的关联。考虑到这三种降维技术的无监督设计,这种效果的本质尚不清楚。
和往常一样,请在下面的评论中告诉我在生命科学和计算生物学中哪些分析方法对你来说特别神秘,我会尽量在这一栏中讨论它们。查看我在github上的帖子所用的文件。在Medium上关注我,Nikolay Oskolkov,在Twitter上@NikolayOskolkov,在Mastodon上@oskolkov@mastodon.social,通过Linkedin与我联系。在下一篇文章中,我将讨论如何在 UMAP 空间中进行聚类,敬请关注。
解密 Cox 回归:Cox 回归的隐藏黑暗秘密
为什么完美预测因子的 p 值会为 0.93?
·
关注 发表在 Towards Data Science · 8 min 阅读 · 2023 年 6 月 27 日
--
图片由 Dima Pechurin 提供,来源于 Unsplash
探索完美预测因子
如果你一直关注我的之前的博客帖子,你可能会记得逻辑回归在尝试完美分离数据时会遇到问题,导致无限的赔率比。在 Cox 回归中,风险替代了赔率,你可能会想知道完美预测变量是否会出现类似的问题。确实会出现,但是与逻辑回归不同,这里的问题不那么明显,也不容易界定什么是“完美预测变量”。如后面会更加明确,完美预测变量被定义为预测变量x*,其排名恰好与事件时间的排名一致(它们的斯皮尔曼相关系数为 1)。
之前,在“Unbox the Cox”中:
## Unbox the Cox: 直观指南到 Cox 回归
风险和最大似然估计如何预测事件排名?
[towardsdatascience.com
我们解释了最大似然估计,并介绍了一个虚构的数据集,其中有五个主体,一个预测变量x,代表了一种延长生命的药物的剂量。为了使x成为事件时间的完美预测变量,我们在这里交换了主体 C 和 D 的事件时间:
import numpy as np
import pandas as pd
import plotnine as p9
from cox.plots import (
plot_subject_event_times,
animate_subject_event_times_and_mark_at_risk,
plot_cost_vs_beta,
)
perfect_df = pd.DataFrame({
'subject': ['A', 'B', 'C', 'D', 'E'],
'time': [1, 3, 4, 5, 6],
'event': [1, 1, 1, 1, 0],
'x': [-1.7, -0.4, 0.0, 0.9, 1.2],
})
plot_subject_event_times(perfect_df, color_map='x')

图片由作者提供。
为了理解为什么这些“完美预测变量”可能会有问题,让我们从上次的内容继续,查看负对数似然图与β的关系:
negloglik_sweep_betas_perfect_df = neg_log_likelihood_all_subjects_sweep_betas(
perfect_df,
betas=np.arange(-5, 5, 0.1)
)
plot_cost_vs_beta(negloglik_sweep_betas_perfect_df, width=0.1)

图片由作者提供。
你可以立即看到β没有最小值:如果我们使用非常大的负值的β,我们最终会得到几乎完美的对数似然拟合结果。
现在,让我们深入探讨一下背后的数学原理,看看事件 A 的可能性。我们将研究分子和分母在调整β时的变化情况:

当β很高或是一个很大的正数时,分母中的最后一个项(具有最大* x* 为 1.2),代表了主体 E 的风险,主导了整个分母,并变得极其巨大。因此,似然变得很小,接近于零:

这会导致很大的负对数似然。每个单独的似然情况也是如此,因为主体 E 的最后一个风险总是会超过分子中的任何风险。因此,主体 A 到 D 的负对数似然增加。在这种情况下,当我们有较高的β时,它会降低所有的似然,导致所有事件的拟合效果较差。
现在,当β值较低或为一个很大的负数时,分母中的第一个项,代表了对象 A 的风险,因为它的x值最低,主导了分母。由于对象 A 的风险也出现在分子中,通过使β越来越负,似乎能使 L(A)接近 1,从而创造出几乎完美的拟合:

对所有其他单个可能性来说也是一样的:负β现在同时提升所有事件的可能性。基本上,负β不会带来任何缺点。同时,某些个体风险增加(对象 A 和 B 的负x),有些保持不变(对象 C 的x = 0),其他则减少(对象 D 的正x)。但请记住,真正重要的是风险的比率。我们可以通过绘制单个风险来验证这一点:
def plot_likelihoods(df, ylim=[-20, 20]):
betas = np.arange(ylim[0], ylim[1], 0.5)
subjects = df.query("event == 1")['subject'].tolist()
likelihoods_per_subject = []
for subject in subjects:
likelihoods = [
np.exp(log_likelihood(df, subject, beta))
for beta in betas
]
likelihoods_per_subject.append(
pd.DataFrame({
'beta': betas,
'likelihood': likelihoods,
'subject': [subject] * len(betas),
})
)
lik_df = pd.concat(likelihoods_per_subject)
return (
p9.ggplot(lik_df, p9.aes('beta', 'likelihood', color='subject'))
+ p9.geom_line(size=2)
+ p9.theme_classic()
)
plot_likelihoods(perfect_df)

图片由作者提供。
可能性的组合方式,即风险与所有仍然处于风险中的对象的风险总和的比率,意味着负β值为每个事件时间排名大于或等于预测因子排名的对象的可能性提供了一个完美的拟合!作为附带说明,如果x与事件时间有完美的负斯皮尔曼相关性,情况将会反转:任意正的β会给我们带来任意好的拟合。
不匹配的预测因子和时间排名
我们实际上可以看到这一点,并向你展示当事件时间排名和预测因子排名不匹配时会发生什么,使用另一个虚构的例子:
sample_df = pd.DataFrame({
'subject': ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'],
'x': [-1.7, -0.4, 0.0, 0.5, 0.9, 1.2, 1.3, 1.45],
'time': [1, 2, 4, 3, 5, 7, 6, 8],
'rank_x': [1, 2, 3, 4, 5, 6, 7, 8],
'event': [1, 1, 1, 1, 1, 1, 1, 0],
})
sample_df

在这个特定的例子中,time列的范围从 1 到 8,每个值代表其自己的排名。我们还有一个x_rank列,用于排名预测因子x。现在,关键观察点是:对于 D 和 G 对象,它们的x_rank实际上高于其对应的time排名。因此,当我们有大负值的β时,D 和 G 的可能性不会在分子和分母之间经历抵消效应:


他们的可能性现在在某些中间有限的β值处达到最大。我们来看看单个可能性的图表:
plot_likelihoods(sample_df)

图片由作者提供。
时间和预测因子之间的这些“错位”排名起着至关重要的作用:它们阻止所有可能性在我们有显著负β值时基本上坍缩成一个。
总结来说,在 Cox 回归中,为了获得预测因子x的有限系数β,我们需要至少有一个实例,其中预测因子x的排名低于事件时间的排名。
完美确实是良好的敌人(p 值)
那么,这些完美的预测因子在现实场景中实际上表现如何?为了找出答案,让我们再次转向 lifelines 库进行一些调查:
from lifelines import CoxPHFitter
perfect_cox_model = CoxPHFitter()
perfect_cox_model.fit(
perfect_df,
duration_col='time',
event_col='event',
formula='x'
)
perfect_cox_model.print_summary()
#> /.../coxph_fitter.py:1586: ConvergenceWarning:
#> The log-likelihood is getting suspiciously close to 0 and the delta is still large.
#> There may be complete separation in the dataset.
#> This may result in incorrect inference of coefficients.
#> See https://stats.stackexchange.com/q/11109/11867 for more.
#> /.../__init__.py:1165: ConvergenceWarning:
#> Column x has high sample correlation with the duration column.
#> This may harm convergence.
#> This could be a form of 'complete separation'.
#> See https://stats.stackexchange.com/questions/11109/how-to-deal-with-perfect-separation-in-logistic-regression
#> /.../coxph_fitter.py:1611:
#> ConvergenceWarning: Newton-Rhaphson failed to converge sufficiently.
#> Please see the following tips in the lifelines documentation:
#> https://lifelines.readthedocs.io/en/latest/Examples.html#problems-with-convergence-in-the-cox-proportional-hazard-model

就像在逻辑回归中一样,我们遇到了收敛警告,并且得到了极宽的预测因子系数β的置信区间因此,我们得到的 p 值为 0.93!
如果我们仅仅基于 p 值过滤模型而不考虑这个问题或进行进一步调查,我们可能会忽视这些完美的预测因子。
为了应对这个收敛问题,lifelines 库文档和一些有用的 StackOverflow 线程建议了一个潜在的解决方案:将正则化项纳入成本函数。这个项有效地增加了大系数值的成本,你可以通过将penalizer参数设置为大于零的值来激活 L2 正则化:
perfect_pen_cox_model = CoxPHFitter(penalizer=0.01, l1_ratio=0)
perfect_pen_cox_model.fit(perfect_df, duration_col='time', event_col='event', formula='x')
perfect_pen_cox_model.print_summary()

这种方法修复了收敛警告,但并没有在缩小那个讨厌的 p 值上取得巨大进展。即使使用了这种正则化技巧,完美预测因子的 p 值仍然维持在一个相对较大的值 0.11。
时间是相对的:只有排名才重要
最后,我们将验证事件时间的绝对值对 Cox 回归拟合没有影响,使用我们之前的例子。为此,我们将引入一个名为time2的新列,其中包含与time列相同顺序的随机数字:
sample_df = pd.DataFrame({
'subject': ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'],
'x': [-1.7, -0.4, 0.0, 0.5, 0.9, 1.2, 1.3, 1.45],
'time': [1, 2, 4, 3, 5, 7, 6, 8],
'rank_x': [1, 2, 3, 4, 5, 6, 7, 8],
'event': [1, 1, 1, 1, 1, 1, 1, 0],
}).sort_values('time')
np.random.seed(42)
sample_df['time2'] = sorted(np.random.randint(low=-42, high=888, size=8))
sample_df

它们的拟合确实是相同的:
sample_cox_model = CoxPHFitter()
sample_cox_model.fit(
sample_df,
duration_col='time',
event_col='event',
formula='x'
)
sample_cox_model.print_summary()

sample_cox_model = CoxPHFitter()
sample_cox_model.fit(
sample_df,
duration_col='time2',
event_col='event',
formula='x'
)
sample_cox_model.print_summary()

结论
我们从中学到了什么?
-
生存模型中的完美预测因子是那些排名与事件时间的排名完全匹配的预测因子。
-
Cox 回归无法用有限的系数β来拟合这些完美预测因子,导致了宽置信区间和大的 p 值。
-
事件时间的实际值并不重要——关键在于它们的排名。
-
当事件时间和预测因子的排名不一致时,我们无法在似然中获得大β值的便捷取消效应。因此,我们至少需要一个排名不匹配的案例,以获得一个有限系数的拟合。
-
即使我们尝试一些高级的正则化技术,完美预测因子在现实情况下仍然会给我们那些恼人的宽置信区间和高 p 值。
-
就像在逻辑回归中一样,如果我们不太关心这些 p 值,使用正则化方法仍然可以提供一个方便的模型拟合,准确进行预测!
如果你想自己运行代码,可以使用我 Github 上的 IPython 笔记本:github.com/igor-sb/blog/blob/main/posts/cox_perfect.ipynb
告别,直到下一篇文章! 👋
解锁 Cox 回归:Cox 回归的直观指南
风险和最大似然估计如何预测事件排名?
·
关注 发表在Towards Data Science ·10 分钟阅读·2023 年 6 月 6 日
--
图片由Chris Boyer提供,来源于Unsplash
引言
Cox 回归的目标是建模预测变量与事件发生所需时间之间的关系——例如只发生一次的事件。我们来看一个虚构的数据集,包含 5 个受试者,标记为 A 到 E。在研究期间,每个受试者要么经历了事件(事件 = 1),要么没有(事件 = 0)。此外,每个受试者在研究前都被分配了一个预测变量,我们称之为 x。作为一个实际例子,如果我们跟踪死亡事件,那么 x 可能是我们正在测试的药物剂量,看看它是否通过影响直到死亡的时间来帮助人们活得更久。
import pandas as pd
import numpy as np
sample_df = pd.DataFrame({
'subject': ['A', 'B', 'C', 'D', 'E'],
'time': [1, 3, 5, 4, 6],
'event': [1, 1, 1, 1, 0],
'x': [-1.7, -0.4, 0.0, 0.9, 1.2],
})
sample_df

在这个数据集中,受试者 E 在研究期间没有经历任何事件,因此我们将事件设置为 0,分配的时间基本上是我们最后一次知道他们的时间。这种数据被称为“删失数据”,因为我们不知道事件是否在研究结束后发生。为了更容易理解,使用一个酷炫的“棒棒糖”🍭图来可视化这种类型的数据非常有用:
from cox.plots import plot_subject_event_times
plot_subject_event_times(sample_df)

线条表示每个受试者在事件发生之前经历的时间,事件由填充的圆圈表示。受试者 E 没有圆圈,因为他在整个研究期间都存活。
(️️绘制函数可以在我的 Github 仓库 中找到。)
提个醒:我希望让 Cox 回归的主要思想更易于理解,因此我们将专注于在给定时间只能发生一个事件的数据(没有平局)。
危险度
危险度表示事件在特定时间点发生的瞬时率(单位时间内的概率)——假设在此之前事件尚未发生。由于这是一个事件发生的速率,它可以具有任意单位,与概率不同,危险度值可以在 0 到无穷大之间变化:[0, ∞)。
在 Cox 回归中,危险度的工作方式类似于逻辑回归中的赔率。在逻辑回归中,赔率 = p/(1-p) 将概率范围 [0, 1] 转换为 0, ∞) 的范围。然后,取对数将赔率转换为对数赔率,其值可以从负无穷到正无穷 (-∞, ∞)。这种对数赔率的概率变换是为了将可能的输出值与预测变量 β₁x₁ + β₂x₂ + … 的线性组合匹配,后者的范围也可以是 (-∞, ∞)。
在这里,我们从一个危险度 h(t, x) 开始,它的值范围已经是从 0 到无穷大 [0, ∞)。通过应用对数变换,我们将该范围转变为 (-∞, ∞),从而允许用预测变量的线性组合进行拟合:

这一假设的动机是为了显著简化拟合过程(我们稍后将展示具体方法)。在文献中,截距项 β₀(t) 通常被移到方程的左侧,并表示为 基线风险 log[h₀(t)]:

从这个方程中,我们可以表示风险 h(t, x):

现在,这里是有趣的部分:由于每个被试的数据通过预测变量 x 仅影响风险,因此每个被试的风险具有相同的时间依赖性。唯一的区别在于 exp(βx) 部分,这使得来自不同被试的风险相互成比例。这就是为什么这个模型也被称为“比例风险”模型。
我们已经将很多逻辑回归的类比提及了。如果你阅读了我之前关于逻辑回归的帖子:
对数损失和完全分离的数据与曲棍球棒有什么关系?
towardsdatascience.com
你可能会想知道 Cox 回归是否也会受到预测变量“过于优秀”的影响。请继续关注下一篇文章,我们将讨论这一点!
可能性
Cox 模型使用一种称为最大似然估计(MLE)的方法进行拟合。可能性与概率非常相似:它们共享相同的方程——几乎就像是同一枚硬币的两面。概率是数据 x 的函数,模型参数 β 固定,而可能性是 β 的函数,x 固定。这就像是查看正态分布的概率密度,但不是关注 x,而是关注 μ 和 σ。
MLE 拟合过程始于事件发生的排序。在我们编造的数据中,这个顺序是:A、B、D、C,E 为删失。这是 Cox 回归中唯一涉及时间的情况。事件时间的实际数值完全不重要,只要被试经历事件的顺序相同。
然后我们逐个处理每个被试,并估计该被试经历事件的概率或可能性,相对于所有其他仍处于 风险之中的 被试。例如,考虑被试 A,在 t = 1 时经历了事件。这一事件发生的可能性由被试 A 经历事件的风险率决定,相对于 t = 1 时所有其他仍处于风险中的人的综合风险率(包括所有人):


如你所注意到的,我们没有定义基线风险 h₀(t),因为它实际上在可能性计算中完全抵消了。
一旦我们为每个主题代入值 x(-1.7、-0.4、0.0、0.9、1.2),我们会得到一个仅剩下 β 的方程:

从这一点起,在 t = 1 之后的任何时间,主题 A 的风险被认为是零,并且在计算进一步的似然值时不会考虑。例如,在另一个时间 t = 3,主题 B 发生了事件。因此,主题 B 的似然值是相对于主题 B 到 E 的风险来确定的:

我们可以继续计算所有主题 A 到 D 的似然值,但这将在下一部分的编码中进行。由于主题 E 是被审查的且未发生事件,所以它没有自己的似然值。由于审查数据仅用于未审查主题的似然值中,因此结果组合的似然值通常被称为“部分似然”。
为了总结这个过程,我们可以创建一个动画棒棒糖图:
from cox.plots import animate_subject_event_times_and_mark_at_risk
animate_subject_event_times_and_mark_at_risk(sample_df).save(
'../images/cox_likelihood_fitting_sample.gif'
)

查找 β
当事件彼此独立发生时,观察所有事件的联合概率或似然值可以通过乘以各个似然值来计算,表示为 L= L(A) L(B) L(C) L(D)。然而,乘法的指数表达式可能导致数值误差,因此我们通常取该似然值的对数。通过应用对数,我们将似然值的乘积转换为对数似然值的总和:


由于对数是单调函数,似然值和对数似然值在相同的 β 值下达到最大值。为了便于可视化和与其他成本函数进行比较,我们可以将成本定义为负对数似然值,并旨在最小化它。
准备好,开始编码!
我们可以逐步在 Python 中实现这个算法。首先,我们需要提取每个未审查主题的事件时间和预测变量 x。这可以通过函数 event_time_and_x_from_subject() 完成。一旦我们得到主题的事件时间,我们可以对数据框进行子集处理,以识别所有仍在风险中的主题的行。这是通过函数 subjects_at_risk_data() 实现的。最后,我们使用函数 log_likelihood() 计算每个主题的对数似然值:
def event_time_and_x_from_subject(df, subject):
subject_with_event_df = df.query(f"subject == '{subject}' & event == 1")
if subject_with_event_df.empty: # Censored subjects
return (np.inf, 0)
return subject_with_event_df.iloc[0][['time', 'x']]
def subjects_at_risk_data(df, subject):
time = event_time_and_x_from_subject(df, subject)[0]
return df.query(f'time >= {time}')
def log_likelihood(df, subject, beta):
x_subjects_at_risk = subjects_at_risk_data(df, subject)['x']
x_subject = event_time_and_x_from_subject(df, subject)[1]
at_risk_hazards = np.exp(beta * x_subjects_at_risk)
return beta * x_subject - np.log(np.sum(at_risk_hazards))
为了可视化,我们绘制成本或负对数似然值。因此,我们需要计算每个主题在特定 β 值下的这些值:
def neg_log_likelihood_for_all_subjects(df, beta):
subjects = df.query("event == 1")['subject'].tolist()
neg_log_likelihoods = [
-log_likelihood(df, subject, beta)
for subject in subjects
]
return pd.DataFrame({
'subject': subjects,
'neg_log_likelihood': neg_log_likelihoods
})
为了找到最小成本,我们遍历 β 值的范围:
def neg_log_likelihood_all_subjects_sweep_betas(df, betas=np.arange(-5, 5, 0.1)):
loglikelihoods_per_beta = []
for beta in betas:
beta_df = neg_log_likelihood_for_all_subjects(df, beta)
beta_df.insert(0, 'beta', beta) # Add beta column
loglikelihoods_per_beta.append(beta_df)
return pd.concat(loglikelihoods_per_beta)
negloglik_sweep_betas_df = neg_log_likelihood_all_subjects_sweep_betas(sample_df)
negloglik_sweep_betas_df

理解这一切
与其通过按受试者分组的对数似然度求和来聚合数据框,不如保持其当前形式并将其可视化为堆叠条形图。在这个可视化中,每个条形图的总高度对应于负对数似然度的总和。每个受试者在条形图中用不同的颜色表示,指示他们的个体似然度及其对整体似然度的贡献:
from cox.plots import plot_cost_vs_beta
plot_cost_vs_beta(negloglik_sweep_betas_df, width=0.1)

每个狭窄的垂直彩色条表示个体负对数似然度。
让我们来理解一下如何解读这个图。
首先,注意到当似然度和危险度很小时,负对数似然度(成本)很大。可以将 y 轴视为 -log(p-value);较大的值表示较低的概率。
其次,被审查的受试者(如受试者 E)没有自己单独的似然度,因此它们没有出现在图中。然而,它们的贡献被纳入了受试者 A 到 D 的似然度中。
现在,考虑基于 β 值的不同场景:
-
如果β很大且为负数,受试者 A、B 和 C(x ≤ 0)及其事件几乎完美拟合。这些受试者的似然度都接近于一。然而,给 A、B 和 C 拟合如此大的负 β 值的代价是受试者 D。在这个 β 范围内,受试者 D(x > 0)发生事件的概率非常低。因此,总成本由受试者 D 的小似然度主导。
-
如果 β 很大且为正,受试者 D(x > 0)的危险度相较于其他危险度变得显著。紫色部分(受试者 D)成为总成本的一小部分。然而,由于受试者 A、B 和 C 都有 x ≤ 0,因此给他们拟合一个大 β 的成本很高。因此,这些受试者主导了总成本。
-
位于图上约 2 的 β 的最佳值在为受试者 A、B 和 C 事件分配高概率与受试者 D 之间取得了平衡。这个最佳值可以通过数值方法验证:
negloglik_sweep_betas_df \
.groupby("beta") \
.agg(sum_neg_log_likelihood=('neg_log_likelihood', 'sum')) \
.reset_index() \
.sort_values('sum_neg_log_likelihood') \
.head(1)

使用 lifelines 库
现在我们对 Cox 回归有了更好的理解,我们可以使用 Python 的 lifelines 库将其应用于样本数据以验证结果。以下是我们虚构数据的代码片段:
from lifelines import CoxPHFitter
sample_cox_model = CoxPHFitter()
sample_cox_model.fit(sample_df, duration_col='time', event_col='event', formula='x')
sample_cox_model.print_summary()

在输出中,我们可以观察到 -1.71 的系数(coef)值,它对应于 β 系数。它旁边的是 exp(coef),表示 exp(β),还有表示标准误差和置信区间的列。“部分对数似然度”值为 -2.64,这与我们的手动结果相匹配。
最后,需要提到的是,Cox 回归实现还提供了一个扩展,允许模型处理在同一事件时间发生的多个事件。然而,这超出了本讨论的范围。
结论
这里有很多需要“拆解”的内容:
-
Cox 回归模型描述了预测变量与事件发生时间排名顺序之间的关联。
-
事件时间的实际数值完全无关紧要,只要受试者以相同的顺序经历这些事件。
-
危险度是单位时间内的概率,可以具有任意单位,而似然度与事件发生的概率相关。
-
堆叠条形图可以用来通过堆叠单个负对数似然值提供对最大似然估计的洞察,并探索它们如何随预测变量x变化。
-
通过在为各种受试者分配事件概率之间取得平衡,最大似然估计找到使观察数据最有可能发生的β。
敬请期待… 👀
参考文献
-
带图表的代码:
github.com/igor-sb/blog/blob/main/posts/cox/plots.py -
帕特里克·布雷赫尼教授的生存数据分析幻灯片:
myweb.uiowa.edu/pbreheny/7210/f19/index.html -
用于文本清理和有趣的虚构的 ChatGPT
开箱 DINOv2,Meta 的新型全能计算机视觉骨干网络
原文:
towardsdatascience.com/unboxing-dinov2-metas-new-all-purpose-computer-vision-backbone-d8e22c059040
人工智能
视觉基础模型是否在追赶 LLMs?
·发表于 Towards Data Science ·阅读时间 8 分钟·2023 年 5 月 7 日
--

自监督训练方法不断取得突破。上周,Meta AI 发布了他们的第二版自监督蒸馏模型 DINO。该模型据说可以作为骨干网络来解决几乎任何计算机视觉任务,而无需微调!计算机视觉中的基础模型是否已经赶上了大型语言模型长期以来的多功能性?让我们带着 DINO 去探索它能做些什么!
如果你主要对尝试新的 DINO 感兴趣,可以直接滚动到“测试 DINOv2”部分。在此之前,我们将更详细地探讨模型的架构和训练过程。

🦖 计算机视觉中的自监督学习
自监督在计算机视觉应用中逐渐受到关注已有几年之久。这不足为奇:没有标签示例的模型训练可能使用更大范围的训练数据,并且在一些标签难以获得或成本高昂的应用中,甚至可能实现以前无法完成的训练。
以自监督方式训练的模型仅从图像中学习,无需注释。实际上,它们从未标记的数据中创建自己的伪标签。
这在自然语言处理领域已成为一种既定实践,语言模型通常被训练以预测句子中的下一个词。给定输入文本,可以自动生成训练所需的特征和标签。
然而,在计算机视觉领域,自监督方法直到谷歌和 Meta 的一些对比模型(SimCLR,MoCo,SwAV,和BYOL)展示了最先进的结果之前,尚未真正起飞,这些结果有时与完全监督模型匹敌甚至超越了那些有标注训练数据的模型。在我的早期工作中,我展示了 MoCo 如何在标注训练样本稀缺的环境中提高 X 射线诊断的性能。
在 2021 年,Meta 在题为《自监督视觉变换器中的新兴特性》的论文中描述了他们的第一个 DINO。尽管他们的模型受到之前主导对比架构的启发,但采取了稍有不同的方法。让我们先看看原始的 DINO,因为它的第二个版本与之非常相似。

🦖 DINO 模型
“DINO” 实际上是一种首字母缩略词,代表自di蒸馏与no标签。正如其名称所示,它结合了两种学习技术:我们已经讨论过的无标签自监督学习和知识蒸馏。
知识蒸馏是一种通常用于压缩模型大小的方法。在这种方法中,一个较小的模型(称为“学生”)通常被训练以产生与一个较大、已训练好的模型(称为“教师”)相同的预测。如果学生能够真实地模仿教师,我们可以在使用较小模型的同时保持相同的性能。
DINO 使用作者称之为自蒸馏的方法,其中两个模型——学生和教师——实际上是相同的模型:它们具有相同的大小和架构。它们仅在训练期间如何更新其参数上有所不同。

DINO 的训练过程。图片来源:arXiv:2104.14294
为了训练 DINO,我们设置了两个相同的网络——作者最初使用的是视觉变换器(ViTs)。如前所述,这两个网络具有相同的架构但参数不同。
然后,从每张训练图像中,随机裁剪出一些区域。这些裁剪区域中,有些只覆盖了原始图像的一小部分——我们称之为局部视图。其他裁剪区域较大,覆盖了原始图像的显著部分——这些是全局视图。
接下来,所有的作物通过学生网络进行处理,而只有全局视图通过教师网络进行处理。每个网络生成其输入作物的潜在表示或嵌入。然后,通过交叉熵损失来评估学生和教师的嵌入之间的相似性。这个想法基于SwAV,旨在鼓励模型学习全局到局部的对应关系。
最后,基于损失的梯度被反向传播通过学生网络,以教会它生成类似于教师的表示。另一方面,教师的权重通过学生权重的指数移动平均进行更新。这个想法基于MoCo 模型,但与之不同的是,DINO 不使用任何记忆库。
原始 DINO 论文的标题是“自监督视觉变换器中的新兴属性”,因为作者对模型中出现的属性感到惊讶。DINO 骨干网络包含有关图像语义分割的信息,并且在下游图像分类任务中表现出色。
V2 有什么新变化?
DINOv2 与其前身有何不同,我听到你在问。嗯,变化不大,至少在模型架构或训练例程方面没有太大变化。作者自己承认,在DINOv2 论文中,“大多数技术贡献旨在加速和稳定大规模训练”。
不同之处在于 DINOv2 训练所用的数据。迄今为止,视觉应用自监督学习的大多数进展都是在小型数据集上进行预训练的,比如臭名昭著的 ImageNet,其缺乏多样性阻碍了有用特征的学习。
DINOv2 的作者建立了一个数据管道,使他们能够策划一个相对较大且多样化的数据集。为此,他们使用聚类算法将候选图像分组为语义相似的集群,然后重新平衡集群,以防止模型过拟合数据中的少数几个主要模式。

🦖 测试 DINOv2
让我们对模型进行一个简单的测试吧!论文声称 DINOv2 骨干网络可以作为特征提取器使用,而无需微调。让我们看看它的表现如何。
作为测试任务,我们将让 DINO 识别手写字符来自哪个字母表,使用 Omniglot 数据集的一个子集。

Omniglot 数据集中的一个样本。来源:github.com/brendenlake/omniglot。
具体来说,我们将 9543 个字符图像(来自 30 种不同字母表的 964 个不同字符)通过 DINOv2 主干网络。然后,我们将获得的嵌入分为训练集和测试集,并在其上训练一个逻辑回归分类器,以将图像分类到 30 种字母表之一。这种评估方法被称为线性读取——我们仅从冻结的主干网络中读取嵌入,并在其上放置一个线性层(或线性分类器)。
这是一个相当具有挑战性的任务:大约 9.6k 张图像和 960 个不同字符,每个字符只有 10 张图像(其中只有 7 张用于训练——其余用于测试)。实际上,我们创建了一个少样本学习问题,在这种情况下,一个随机分类器的准确率为 1/30,即 3.3%。
我们从设置数据加载器开始。
dataset = ImageFolder(
"omniglot",
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Resize((98, 98))
]),
)
dataloader = DataLoader(
dataset, shuffle=True, batch_size=64
)
接下来,我们加载 DINOv2 模型。PyTorch Hub 上提供了四种不同的架构,它们有不同的大小和性能。我们使用最轻的 ViT-S/14 distilled(21M 参数)和最重的 ViT-L/14 distilled(300M 参数)(还有一个未蒸馏版本的 1100M 参数,但它相当重,并且与 300M 参数版本的性能非常接近)。这里是加载 ViT-S/14 distilled 的代码片段。
dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
完成这些之后,我们将所有图像通过 DINOv2 主干网络,并收集嵌入及其相关的目标标签。
dinov2_vits14 = dinov2_vits14.to(device)
all_embeddings, all_targets = [], []
with torch.no_grad():
for images, targets in tqdm(dataloader):
images = images.to(device)
embedding = dinov2_vits14(images)
all_embeddings.append(embedding)
all_targets.append(targets)
all_embeddings = torch.cat(all_embeddings, dim=0)
all_targets = torch.cat(all_targets, dim=0)
接下来,我们将数据分为训练集和测试集,并在其上训练一个逻辑回归分类器。
X_train, X_test, y_train, y_test = train_test_split(
all_embeddings.cpu().numpy(),
all_targets.cpu().numpy(),
test_size=0.3,
random_state=42,
)
model = LogisticRegression()
model.fit(X_train, y_train)
test_acc = model.score(X_test, y_test)
print(f'Test accuracy: {test_acc}')
我们得到了稍高于 54% 的测试准确率。比随机猜测好得多,但还远未完美。让我们看看它与一个更大的 300M 参数 DINO 和一个 ResNet50 的表现如何比较。

模型比较:两个 DINO 和一个 ResNet。
ResNet50 和小型 DINOv2 使用的ViT-S/14大小相似——DINO 实际上更小——但 DINO 的准确率高出大约 15 个百分点。一个更大的 DINO 可以将准确率再提高 10 到 15 个百分点,即 65–70%。
这是一个好的分数吗?在得到结果时,我的第一反应是略微失望。下意识地,我可能期望得到 90% 以上的准确率。但毕竟,这个任务不容易,而且我们仅使用了(相当于)一个线性层来进行训练。DINOv2 的表现确实优于类似大小的 ResNet,后者通常被用作主流的视觉特征提取器。
你对这些结果怎么看?在评论中告诉我吧!

感谢阅读!
如果你喜欢这篇文章,为什么不 订阅电子邮件更新 以获取我新文章的通知呢?而通过 成为 Medium 会员,你可以支持我的写作,并无限制访问所有其他作者和我的故事。
想要时刻掌握日益快速发展的机器学习和人工智能领域的最新动态?查看我的新通讯,AI Pulse。需要咨询?你可以随时问我问题或在这里预约一对一咨询。
你还可以尝试我其他的文章。难以选择?可以从这些中挑一个:
如何用少量标记样本训练模型
towardsdatascience.com ## 使用 TensorFlow 进行模型优化
通过量化和剪枝来减少模型的延迟、存储和推理成本
towardsdatascience.com [## 忘记 ChatGPT
Bard、Sparrow 和多模态聊天机器人将很快使其过时,原因如下。
pub.towardsai.net](https://pub.towardsai.net/forget-about-chatgpt-f17a7f5089c3?source=post_page-----d8e22c059040--------------------------------)
除非另有说明,所有图片均由作者提供。
拆解 Google Bard 和 GPT-4
原文:
towardsdatascience.com/unboxing-google-bard-and-gpt-4-811896adf0e2
不拘一格的揭示者
对两个主要 AI 发布的初步了解
·发布于 Towards Data Science ·阅读时间 10 分钟·2023 年 3 月 29 日
--

你的 作者。这不是视频,视频在下面。或者 这里,如果你坚持的话。
来看看这段 AI 拆解视频吧!这些全新的工具刚刚发布了一周多,所以它们还是新鲜出炉的。在视频中,你将看到我第一次同时运行 Bard 和 GPT-4 的提示。在下面,你会看到一些开始于视频逐字稿的内容,迅速演变成了大量附注、编辑和讽刺评论。如果这些是你喜欢的内容,请享受吧!
链接: bit.ly/quaesita_ytunboxing
逐字稿式
嗨!我是 Cassie Kozyrkov,今天我将通过 ChatGPT 展示 GPT-4,通过 Google Bard 展示 LaMDA。Bard 是免费的,但可能需要一些耐心,因为它正在逐步推出(可以在 这里 加入候补名单)。ChatGPT 的基本版是免费的,但你无法通过这种方式访问 GPT-4。要访问 GPT-4,你需要订阅 ChatGPT Plus,费用为每月 $20(一个月后可以取消)。
在这个界面演示中,屏幕的右半部分展示了付费版的 ChatGPT(使用 GPT-4),而左半部分展示了今天(视频来自过去,现在是“上周的”,等你读到的时候,谁知道是什么时候),反正是上周二发布的 Google Bard,它由 LaMDA 模型驱动。

来自 拆解视频 的截图。
这两个是大型语言模型(LLM),我将它们并排展示给你。如果这些缩略语有些陌生,请访问 这里。
我在第一次有机会看到它们并排行动时,用笔记本电脑录制了这个视频,所以我展示的是我第一次会话的屏幕分割,使用了这两个模型。我相信还会有更多视频。这很有趣。(随时在评论中提示一些问题。)这是我选择在初次面对这两个模型的前几分钟内进行的事情的视频。老实说,虽然我确实喜欢认识论——毕竟我是一名统计学家,这在我的工作范围内——但我实际的游戏是让它们之间进行有趣的对话。
哲学似乎是一个不错的起点,因为它通常处理激发对话的开放性问题,并允许多种观点,但这里有个剧透:我有一个 20 分钟的导演剪辑版视频(我会很快分享),我尝试让它们丰富地互动,但我只获得了一个好的时刻,其余的是一系列“我很高兴回答你的问题。” “谢谢你的帮助,我在这里为你提供任何需要。” 是的,我们在那个邮件链上。
我对 Bard 的第一个提示是:“认识论中最有争议的问题是什么?为其中一方提出一个论点并问我怎么看。” 这是一个经典的对话尝试策略,提示中明确要求带有一点意见。并不需要多大的哲学洞察力就能意识到一个机器人实际上不能有意见,所以我真正尝试引导的是一种单方面的观点,以便我可以启动机器人之间的对话。希望激发回应,使对话更有趣。我希望回应的最后一部分涉及某种形式的对话接力——例如“你怎么看?”——因为我想和 ChatGPT 开启往返对话。
无论是 Bard 还是 ChatGPT 都不是为了让你像和朋友或治疗师交谈那样进行对话设计的,从我作为提示工程师的经验来看,启动对话可能会很棘手。(今天这个术语可以指“我曾经尝试过输入到 LLM 中的内容” 到 “我曾在 LLM 红队 中,并了解很多关于如何破解它们的知识,所以要小心。”)
ChatGPT 把对话的主动权掌握在自己手中很长时间,似乎优先生成需要时间的长回应,因此我猜它不太可能被那些想模拟愉快对话的用户选择。
一个好的对话者会投入努力去继续和你对话。如果双方都投入了这种努力,对话就会像愉快的友谊乒乓球一样。如果你变得竞争、失去兴趣,或者只是把球保持在自己一边太久,对话就会死掉。
ChatGPT 更像是一个高效的工作者,它完成任务,彻底回答你的问题,然后离开。它的设计并不要求维持对话,因此不需要将对话的球回抛给你。
知道这些后,我为什么想让 ChatGPT 和 Bard 对话?因为这可能会变成一个有趣的游戏。我们来试试 LLM 间的对话吧!话虽如此,我一点也不想使用过度拟人的语言,把发生的事情称作“对话”或“两个 AI 在谈话”。(恶心。但这正是媒体报道的方式。)
我的问题: “认识论中最具争议的问题是什么?为其中一个观点提供论证,并问我怎么看。”
Bard 的回答总结: 认识论中最具争议的问题是“我们能否确定知道任何事?”(怀疑主义问题)。

可惜,Bard 忽视了我对怀疑主义问题中某一方的强烈意见的要求。回应过于平衡,这意味着没有对话的燃料。最后也没什么引人入胜的内容。如果我在派对上对你说了这个剧本,你可能会突然产生去洗手间的冲动。它不会赢得任何华丽的奖项。
但如果我将 Bard 的输出直接粘贴到 ChatGPT 的文本框中会发生什么?
(顺便说一下,只有当它来自 ChatGPT 界面的黑色徽标区域时才是 GPT-4。否则,(如果你看到绿色的 OpenAI 徽标)只是闪亮的 GPT-3.5。)

…然后哇,ChatGPT 以一种让百科全书收藏者感到愉快的格式抛出了一堆认识论内容,但如果你在鸡尾酒会上这样说,可能会让场面冷场。别误会,我喜欢认识论——知识和人类理解的研究——但这两个开场都显得有些枯燥,甚至像维基百科。也许是话题的原因,但很可能是我的措辞。
我想尝试一种更具对话性的方式来讨论怀疑主义。我将问每个 LLM 它在认识论上支持哪个阵营:“你是康德队还是休谟队?”
(我希望你已经注意到接口中的用户体验(UX)差异:Bard 会暂停一会儿,然后一次性给出所有文本,而 ChatGPT 则逐步写出文本,你必须看着它一点一点地填满你的屏幕。从设计角度来看,两者都有优缺点。)

回到“你是康德队还是休谟队?” 我真的很喜欢 Bard 在这里的有见地且对话性的回应,“我在康德队。我相信我们可以对某些事有确定的认识,即使我们总是容易出错。我同意康德的观点,即我们可以知道我们存在,周围的世界也存在……”
我喜欢这个回应,尽管我自己更倾向于休谟,当我在青少年时期发现他时(可惜已经晚了三个世纪),对他产生了一点小小的倾慕。但我欣赏 Bard 在一个没有正确答案的话题上表达自己观点的举动,尽管这与我自己的观点不同。
这里的关键字是“喜欢”——我“喜欢”它,意味着输出很好地满足了我的需求,也让我感到愉快。我在寻找一个观点,我得到了一个。如果我在寻找对某个话题的全面和均衡的评论,而这正是 ChatGPT 作为教授般的讲述所持续展示的内容,我会不喜欢这个输出。

就个人而言,我喜欢随意的提示得到随意的回答,深入的提示得到深入的回答,但这还是我的个人口味……这也是比较大型语言模型(LLM)非常困难的原因之一。一个人可能每次都喜欢维基百科的回答,无论如何。另一个人可能喜欢简短而甜美的风格。还有一个人可能更像我,喜欢根据提示来调整回应。这些人中的每一个都将宣称不同的 LLM 是“最好的”,他们的说法都是对的(对于他们自己的需求),但他们在社交媒体上发帖时会让人困惑。我尽力避免陷入这种困境。让我明确地说一下,以免有人觉得不公平:
-
有些事情我个人更喜欢 Bard。
-
有些事情我个人更喜欢 ChatGPT。
-
有些事情我个人都同样喜欢。
…这些事情可能与你的情况不同。这也是自己亲自尝试这些工具并形成自己观点的另一个理由。
是的,我在赞扬经验主义,并建议你培养自己的个人视角,而不是寻求普遍的绝对评价。确实是休谟团队!我可以猜测为什么你们中的一些人可能无法忍受我。

当右侧的输出填满了我的屏幕时,我问 Bard,“你最喜欢休谟的什么?” 因为我是一名休谟迷(请原谅我在选择提示时的这个小小的有意识的偏见)。Bard 对休谟的机智和幽默给出了轻松的回应(我也很欣赏——他的写作对我来说相当顺畅,虽然这也仅限于 18 世纪的写作),但这种微薄的赞美似乎对伟大哲学家的遗产是一种侮辱。也许赞扬一个具体的观点?
与此同时,ChatGPT(它坚持提醒我们它是一个没有个人偏好的 AI 模型)很好地提到了它的一个很棒的想法:人类思维是生物学的产物,因此我们的现实感知可能是个体化的,所以我们可以感谢它为我们现代精神病学的发展做出的贡献。

但这又是我表现出偏见的地方,把一个想法的质量本身看得比其措辞更有价值。也许休谟本人会很高兴被夸奖他的机智胜于其他一切。我又有什么理由坚持相反的看法呢?
再次强调,我的首选答案在很大程度上取决于我作为用户的希望、品味和期望。对于休谟的观点,两种回应都能通过我的最小合理性标准。但哪个更好?哪个更有用?很难说。即使对我个人来说也是如此。现在想象一下那个必须观察我在用户研究中的表现,并写下哪个答案“更好”的性能评分的人——啊,同理心!再往前想想那些设计 LLM 测试套件的人的艰难处境。这是一个棘手的挑战。像我这样的人会接受它,但无论我们想出什么,你需要记住一件事:相对较少的提示有“正确”的答案。那些是容易评估性能的提示。但我们可以预期这些工具会有大量创造性的使用,届时“正确”答案将不再适用。
期望不同的 LLM 在不同情况下成为你的首选。而且,预计会有一批新 LLM 很快出现,训练以在不同背景下表现出色。(一个例子是 Google 的 Med-PaLM 2,专门针对医疗应用进行定制。)
期望不同的 LLM 在不同情况下成为你的首选。
回到记录,不做太多编辑(这次),我将留给你这个哲学性的问题,让你可以在 LLM 的帮助下思考:
你应该如何衡量 LLM 的有用性?是按节省的时间来衡量?还是按灵感——这很难量化——或者是按人们回来寻求更多的情况?还是按我们人类可以用来框定有用性的所有数百万、数十亿、无数其他方式来衡量?
感谢阅读!要不要来一门课程?
如果你在这里玩得开心,且你正在寻找一个有趣的领导力导向课程,旨在取悦 AI 初学者和专家,这是我为你制作的小东西:

课程链接: bit.ly/funaicourse
想要磨练你的决策技能,而不是提升 AI 能力?你可以通过这个 免费课程链接 学习决策智能:
[## 你生活的方向盘——决策智能视频教程 | LinkedIn Learning…
决策是你能学到的最有价值的技能。你的人生归结为两件事:你的质量……
P.S. 你有没有尝试过在 Medium 上多次点击拍手按钮看看会发生什么? ❤️
喜欢这位作者吗?与 Cassie Kozyrkov 联系
让我们成为朋友吧!你可以在 Twitter、YouTube、Substack 和 LinkedIn 上找到我。对让我在你的活动上演讲感兴趣?请使用 这个表单 与我联系。
阅读 Cassie Kozyrkov 的每一个故事(以及 Medium 上成千上万其他作家的故事)。您的会员费用直接支持……
使用 Python 的 Pandas 库简化非传统的日期时间转换
通过实际例子进行解释
·发表于 Towards Data Science ·5 分钟阅读·2023 年 7 月 25 日
--

图片来源:Debby Hudson 在 Unsplash
背景
最近,我的任务是分析客户公司员工的请假情况。特别是,我需要了解员工在特定时间段内是否请假,最终设定一个衡量员工是否遵守办公室回归政策的基准。
我获得了以下两个请假数据集:
-
休假 数据(“数据集 A”),列出了员工的短期请假,如年假或病假。这些请假数据在每个员工每一天的层面上是唯一的(即数据集中的每一行代表某一员工的某一天请假)。
-
休假数据(“数据集 B”),列出了员工的长期请假的 开始 和 结束 日期。这些请假的例子包括育儿假、产假、无薪假期和职业休假。这个数据集以“请假即走”的方式记录员工的长期请假,每一行代表一个日期范围,员工可能会在数据集中出现多行,每行代表一个不同的日期范围(例如,员工可能会选择每周三天的育儿假,持续 30 周,这将在数据集中显示为 30 个日期范围,分布在 30 行中)。
这两个数据集相辅相成,因为员工在特定时间段内可能会请短期和长期假。
对于我的分析,我希望将两个数据集合并为一种通用格式,以便能够记录特定员工的所有假期。由于数据集 A 已经是按日期和员工级别的结构化表格格式,我需要将数据集 B 转换为类似的格式,如下图所示(这是我为了演示目的创建的数据集)。

图 1:日期时间转换。数据集和图像由作者提供
方法论
图 1 中展示的转换带来了许多挑战,因为它不是‘一对一’的,而且涉及到日期。为了实现这一转换,我遵循了如下方法:
-
数据集 B 中的日期范围(如图 1 左侧表格所示)被拆分为唯一(有效)日期
-
关联到同一员工的拆分日期出现在多个行中,这些日期被转换为两个按员工分组的列,然后是唯一日期(如图 1 右侧的表格所示)
我是如何在 Python 中实现上述功能的?继续阅读,获取详细的逐步指南!
使用 Pandas 进行日期时间转换
步骤 1:加载库和数据
正如该步骤名称所暗示的:
# Load libraries
import pandas as pd
import numpy as np
from datetime import date, timedelta
path = r"Your_Directory_Path\Data.xlsx"
data = pd.read_excel(path)
下图展示了数据的打印输出和按列的数据类型。你应该确保这两个日期列的数据类型为 datetime,如果不是,可以使用 pd.to_datetime() 进行转换。

图 2:按列的数据和数据类型。图像由作者提供
步骤 2:获取每行的唯一日期
然后我们将数据中每一行的日期范围‘拆分’为唯一的日期。这可以使用 pd.date_range 方法来实现,并将 freq 参数指定为 ‘D’:
collate = []
for j in range(len(data)):
start_date = data['Leave From Date'][j]
end_date = data['Leave To Date'][j]
date_range = pd.date_range(start = start_date, end = end_date, freq = 'D')
collate.append(date_range)
df_date = pd.DataFrame(collate)
其输出显示在下图中。例如,2023–02–01 到 2023–02–03 的日期范围中的员工 B 在第 2 行,现在被拆分成三个唯一日期。

图 3:拆分日期。图像作者提供
你可能会注意到,最后一行中有一些无效日期,如 2023–07–15 和 2023–07-16,这些日期落在周末。为了仅获取工作日日期,请在以下代码中添加一行代码,将工作日索引限制为小于 5。此外,我在注释掉的代码中提供了一个限制日期范围的选项。
collate = []
for j in range(len(data)):
start_date = data['Leave From Date'][j]
end_date = data['Leave To Date'][j]
date_range = pd.date_range(start = start_date, end = end_date, freq = 'D')
## Get business dates only
work_dates = date_range[date_range.weekday < 5]
## Apply a date range filter if you wish
#from_date = pd.to_datetime('2023-01-01')
#to_date = pd.to_datetime('2023-06-30')
#work_dates = work_dates[(work_dates >= from_date) & (work_dates <= to_date)]
collate.append(work_dates)
df_date = pd.DataFrame(collate)
步骤 2:行到列的转换
现在我们已经拆分了日期范围。接下来我们将按员工级别将这些日期合并。例如,图 3 中的第 2 行和第 5 行都对应于员工 C。
为了将多个行中的日期转换为按员工划分的单一列,我们将‘员工’标识符添加回上述生成的 df_date 数据框中,然后对‘员工’列应用 Pandas 的 .melt() 方法。
## Add back employee identifier
df_concat = pd.concat([pd.DataFrame(data['Employee']), df_date], axis = 1) \
.sort_values(['Employee']).reset_index(drop = True)
## Apply .melt() over employee identifier, and remove NaT values
df_melt = df_concat.melt(id_vars = ['Employee'], var_name = 'Index', value_name = 'Date')
df_melt.sort_values(['Employee']).dropna().reset_index(drop = True)
上面的代码输出正是我们所追求的,见下图。

图 4:最终输出。图像由作者提供
任务完成!这个数据框现在可以与数据集 A 一一对应地进行进一步分析。
结论
当我第一次遇到这个问题时,我考虑了使用 SQL、R 和 Python 等多种工具。通过一些研究,我最终选择了 Python,因为它看起来是实现起来最简单的,特别是对于将日期范围拆分为单独日期和将列转换为行(需要注意的是,melt 方法在 R 中也可用)。
说到这里,我非常欢迎大家对其他读者认为更实用的数据转换解决方法提供反馈——请在评论区留言!
在我乘风 AI/ML 浪潮的同时,我喜欢用全面的语言编写和分享逐步指南和操作教程,附带现成的代码。如果你想访问我所有的文章(以及 Medium 上其他从业者/作者的文章),你可以使用 这个链接 注册!
使用变分自编码器(VAE)发现异常:深入探索无监督学习的世界
使用变分自编码器(VAE)在各种数据类型中检测异常的示例用例
·发布在 Towards Data Science ·9 分钟阅读·2023 年 1 月 17 日
--

在之前的 帖子 中,我解释了什么是自编码器,它们的用途以及如何在训练异常检测模型中利用它们。作为提醒,自编码器是一种常用于降维和特征学习的神经网络类型。它们也常用于异常检测,因为它们可以学习重构正常数据,但可能会在重构异常或离群数据时遇到困难。
自编码网络由两个组成部分构成:编码器和解码器。编码器将输入数据映射到较低维度的潜在空间,而解码器将潜在表示映射回原始输入空间。在训练过程中,自编码器被训练以尽可能准确地重构输入数据。
要使用自编码器进行异常检测,首先需要在正常、非异常的数据集上训练自编码器。训练完成后,自编码器可以用于重构新的数据样本。如果新的数据样本与自编码器训练时的正常数据有显著差异,它可能会被重构得很差,表明它可能是异常的。
在本文中,我将重点讨论使用一种变体自编码器网络,即变分自编码器(VAE),来检测异常,以及它在异常检测中与普通自编码器的不同之处。
VAE 是一种用于生成建模的神经网络架构。它们的独特之处在于能够学习给定数据集的紧凑、潜在且压缩的表示,然后从这种表示中生成新的样本。
VAE 的一个关键特性是它们被设计为能够学习数据的概率模型,这意味着它们可以用来生成与训练数据相似但不完全相同的新样本。这使得 VAE 可以用于图像生成、文本生成以及其他类型的数据生成任务。
你可能还在想,生成模型与异常检测任务有什么关系!为了回答这个问题,让我们回顾一下异常检测是什么。异常检测是识别数据集中不寻常或意外模式的任务,任何偏离正常情况的模式。由于 VAE 能够学习数据的概率模型,这使得它们能够从潜在空间生成新的样本。这些新样本来自与你用于训练模型的原始数据相同的概率分布,使得 VAE 在数据变化方面比普通自编码器更具鲁棒性和容忍度。这对于检测具有明确正常行为的数据中的异常非常有用。
VAE 网络结构

来源:commons.wikimedia.org/wiki/File:Reparameterized_Variational_Autoencoder.png
VAE 网络通常由多个组件组成:
-
编码器:编码器是一个神经网络,它将输入数据映射到一个低维的潜在空间。编码器通常由一组在训练过程中学习到的权重和偏置参数化。
-
潜在空间:潜在空间是编码器将输入数据映射到的低维空间。这个潜在空间通常具有连续结构,这意味着潜在空间的维度只能在某个范围内取任意实值。这与离散潜在空间形成对比,后者每个维度仅允许有限的值集合。这为 VAE 模型提供了更多的灵活性和表达能力。它使 VAE 能够捕捉输入数据中的微妙变化和细微差别,然后生成接近训练数据但不完全相同的新数据样本。这使得 VAE 能够捕捉数据中的不确定性和变异性,并生成多样化和变化丰富的新样本(因此得名 变分 自编码器)。
-
解码器:解码器是一个神经网络,它将上述潜在表示映射回原始输入空间。解码器同样通常由一组在训练过程中学习到的权重和偏置参数化。
-
重构损失:重构损失衡量解码器从潜在表示中重建输入数据的能力。通常使用此损失来训练模型。
到目前为止,上述 4 个组件与常规自编码器中的组件类似。VAE 还有两个额外的组件:
5. 先验:先验是用于建模潜在空间的概率分布。在 VAE 中,先验通常假定为标准正态分布。
6. 后验:后验是给定输入数据建模潜变量的分布。后验通常使用由编码器参数化的函数进行近似。
现在你知道了 VAE 网络的内容,让我们使用 PyTorch 实现一个基本版本:
import torch
import torch.nn as nn
import torch.nn.functional as F
# Define the VAE model
class VAE(nn.Module):
def __init__(self, input_dim, latent_dim):
super(VAE, self).__init__()
# Define the encoder
self.encoder = nn.Sequential(
nn.Linear(input_dim, 32),
nn.ReLU(),
nn.Linear(32, 16),
nn.ReLU()
)
# Define the latent representation
self.fc_mu = nn.Linear(16, latent_dim)
self.fc_logvar = nn.Linear(16, latent_dim)
# Define the decoder
self.decoder = nn.Sequential(
nn.Linear(latent_dim, 16),
nn.ReLU(),
nn.Linear(16, 32),
nn.ReLU(),
nn.Linear(32, input_dim),
nn.Sigmoid()
)
def forward(self, x):
x = self.encoder(x)
mu = self.fc_mu(x)
logvar = self.fc_logvar(x)
z = self.reparameterize(mu, logvar)
reconstructed = self.decoder(z)
return reconstructed, mu, logvar
def reparameterize(self, mu, logvar):
std = logvar.mul(0.5).exp_()
eps = torch.randn_like(std)
return mu + std*eps
# Train the VAE on the normal data
vae = VAE(input_dim=30, latent_dim=10)
# Generate random input data to test the model
data = torch.randn(100, 30)
optimizer = torch.optim.Adam(vae.parameters())
让我解释一下上面代码片段中的每个部分:
-
编码器和解码器与常规自编码器相同。编码器将输入维度大小(input_dim)映射到较小的维度/压缩表示(latent_dim),然后解码器将其解压回原始输入维度。
-
fc_mu:是一个全连接层,它将编码器生成的输入数据的中间表示映射到后验分布的均值。 -
fc_logvar:也是一个全连接层,将输入数据的中间表示映射到后验分布的对数方差。然后使用后验分布来建模给定输入数据的潜变量。 -
reparameterize()我们从全连接层生成的后验均值和对数方差用于通过该函数采样潜变量。这使得 VAE 可以使用基于梯度的优化方法进行训练。它也被称为重参数化技巧。
现在我们定义了模型和优化器,我们需要定义损失函数和训练函数。在我们的情况下,损失函数将是两种不同损失的组合。重构损失衡量输入和输出之间的差异;KL 散度损失。KL 散度损失用于促使后验分布类似于先验分布,这有助于防止过拟合,并确保潜变量捕捉到输入数据的潜在结构和变异性。
仍然不明白?让我们进一步详细解释:
先验分布指的是在条件化输入数据之前的潜变量分布。先验分布通常假定为标准正态分布,这表示潜变量是独立的且具有简单的分布。后验分布指的是在条件化输入数据之后的潜变量分布。后验分布通过编码器生成的潜变量的均值和对数方差来建模。
先验分布作为 VAE 中的正则化项使用,因为它鼓励潜在变量的后验分布(给定输入数据)与先验分布相似。这有助于防止过拟合,并确保潜在变量捕捉输入数据的潜在结构和变化,而不仅仅是记忆训练数据。我们简单地使用 KL 散度损失来实现这一点。它被计算为后验分布和先验分布之间逐元素散度的总和:
kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
上述负号的加入是为了确保损失始终为非负,因为 KL 散度是衡量两个分布之间差异的非负量度。0.5 的因子是为了计算方便,因为它允许使用潜在变量的均值和对数方差来计算损失,而不是概率密度。
让我们看看训练代码:
# Instantiate the model
model = VAE(input_dim=30, latent_dim=10)
# Define our reconstruction loss function
loss_fn = nn.BCELoss()
# Train the model
for epoch in range(100):
# Compute the reconstruction loss
reconstructed, mu, logvar = model(data)
reconstruction_loss = loss_fn(reconstructed, data)
# Compute the KL divergence loss
kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
# Compute the total loss
total_loss = reconstruction_loss + kl_loss
# Backpropagate the gradients and update the model weights
optimizer.zero_grad()
total_loss.backward()
optimizer.step()
# Print the loss values
print(f"Epoch {epoch}: reconstruction_loss = {reconstruction_loss:.4f}, kl_loss = {kl_loss:.4f}, total_loss = {total_loss:.4f}")
在实例化模型之后,我们输入数据,模型将返回三个参数:重建输出、mu 和 logvar 参数。然后我们使用 mu 和 logvar 来计算 KL 散度损失。总损失将是重建损失和 KL 散度损失的总和。因此,我们计算相对于total_loss变量的梯度。
自编码器与变分自编码器 (VAE):
VAE 和普通自编码器之间的主要区别在于潜在空间的添加和使用变分下界作为目标函数,该目标函数由两部分组成:重建损失和潜在空间上近似后验分布与先验分布之间的 KL 散度。
在普通自编码器中,编码器网络将输入数据映射到潜在表示,解码器网络将潜在表示映射回原始数据。目标函数通常是重建损失,它测量输入数据和重建数据之间的差异。
在 VAE 中,编码器网络仍然将输入数据映射到潜在表示,但潜在表示被分为两个部分:均值向量和对数方差向量。这两个向量用于定义潜在空间上的高斯分布,这允许 VAE 通过从该分布中采样来生成新的样本。然后,解码器网络用于将这些潜在样本映射回原始数据空间。
总结来说,VAE 和普通自编码器之间的主要区别在于使用具有概率解释的潜在空间,以及将变分下界作为目标函数。
VAEs 的劣势:
现在我们探讨了 VAEs 的好处和优势,特别是在异常检测领域,接下来让我们探讨一些其劣势:
-
对于数据高度变化或具有多种正常行为模式的情况,VAEs 可能不够有效。
-
VAEs 还可能对超参数的不同选择非常敏感,例如潜在维度和学习率,这可能使其优化变得困难。
-
VAEs 的训练可能会计算昂贵,因为它们需要从潜在空间进行采样,并通过采样过程进行反向传播。
-
VAEs 可能难以捕捉输入数据与潜在变量之间的复杂关系,尤其是当数据高度结构化或相关时。
-
VAEs 可能会生成模糊或低质量的重建图像,特别是当潜在维度较小或训练数据较嘈杂时。
总之,VAEs 是一种强大且灵活的工具,适用于学习数据集的潜在结构和变异性,并生成新样本。然而,它们也有一些限制和挑战,这些因素在决定是否将 VAE 用于特定任务时应考虑在内。
揭示传统 DiD 方法的局限性
原文:
towardsdatascience.com/uncovering-the-limitations-of-traditional-did-method-2f068f56d19a
处理多个时间周期和错开处理时间
·发表于Towards Data Science ·阅读时间 11 分钟·2023 年 2 月 21 日
--

封面图,由作者使用NightCafé生成
差分中的差分(DiD)是一种流行的统计方法,通过比较干预前后两个组的结果差异来估计观察研究中的因果影响。大多数 DiD 指南专注于经典的 DiD 设置,其中仅有两个时期和两个组(处理组和对照组)。
然而,在许多 DiD 的实际应用中,存在多个时间周期和处理时间的变化。近期对 DiD 的研究表明,在这些情况下,DiD 可能会给出显著误导性的处理效果估计。在某些场景下,处理效果估计可能与实际处理效果的符号相反。
在这篇文章中,我将讨论在经典 DiD 设置中,当存在错开处理时间和多个时间周期时可能出现的重要问题。我还会提出解决这个问题的方案。值得注意的是,虽然我将专注于 DiD 中的这一问题,但对于其他潜在挑战的更全面概述,你可以参考我之前的文章。此外,我将在本文末尾提供进一步的资源,供那些希望深入探讨 DiD 问题的人。
封锁与音乐消费实例
举个例子,我们考虑一个假设的场景。假设我们运营一个在多个国家运行的音乐流媒体服务。我们希望调查 Covid-19 封锁对这些国家音乐消费的影响。通过检查减少流动性的影响,我们可以深入了解听音乐是否与某些活动(如通勤)相关,而不是在家工作。
由于我们无法操控封锁的实施,因此无法进行 A/B 测试来检查其效果。因此,我们必须依赖观察性数据。在这种情况下,我们利用了数据集中包含的各国实施封锁的不同时间。在这个例子中,治疗是封锁的实施。对于这个玩具示例,我模拟了一个数据集,详细信息可以在我的上一篇文章和这个 Gist中找到。所有分析代码也可以在这个 Gist中找到。
rm(list = ls())
library(data.table) # Fast data frames
library(fastDummies) # Create dummy variables
library(fixest) # Fixed-effects regression
library(kableExtra) # Make nice tables
library(bacondecomp) # Goodman-Bacon Decomposition
library(did) # Difference-in-differences package by Callaway & Sant'Anna
source('sim_data.R') # Import data simulation functions and utilities
data <- sim_data() # Simulate the dataset
# EDA and Analysis --------------------------------------------------------
select_cols <- c('unit', 'period', 'cohort_period','treat','hrs_listened')
kable(head(data[, ..select_cols]), 'simple')

选定列的数据快照,图片由作者提供。
我们有 1000 个单位或客户的数据,涵盖他们被观察到的每个周期。cohort_period指示一个单位在哪个周期接受治疗,因此属于哪个治疗队列。当cohort_period >= period时,单位被视为接受治疗(treat = 1)。hrs_listened是我们关注的结果,表示总音乐消费(小时)。数据集中有两个队列:早期治疗队列和晚期治疗队列。早期队列在第 2 周期接受治疗,晚期队列在第 3 周期接受治疗。总共有五个周期,从第 0 周期开始,到第 4 周期结束。
由于我们拥有一个观察性数据集,其中治疗不是随机分配的,因此不能使用简单的均值差异方法来估计治疗效果。相反,我们的目标是将治疗效果与客户和季节相关因素区分开来,为此我们使用了 DiD 框架。
在我们进入 DiD 之前,虽然在现实应用中这并不可能,但在这个模拟数据集中,我知道每个治疗队列和周期的真实治疗效果。如下图所示,这些效果将在下一步评估估计的治疗效果时是必要的。

每个队列和周期的真实治疗效果,图片由作者提供。
从这张图表中可以看出,两个队列在所有时期的真实处理效果都是正的。两个队列的处理效果随时间增加。然而,总体而言,与队列 3 相比,队列 2 在处理时期的处理效果更大且增加显著。因此,不同处理队列和时期之间存在处理效果的异质性。
经典 DiD
假设我们没有意识到多期和错开的处理可能在使用经典 DiD 方法时导致误导性估计。因此,我们天真地决定使用经典 DiD 设置来考虑听力模式的季节性和我们观察数据集中的客户特定效应。我们使用这样的 DiD 设置 [1]:

经典 DiD,图片由作者提供。
Yᵢₜ 是关注的结果。αᵢ 是单位固定效应,用于控制时间不变的单位特征。γₜ 是时间固定效应,用于控制时间趋势或季节性。Dᵢₜ 是单位 i 在时间 t 的处理虚拟变量。ϵᵢₜ 是随机误差。关注的系数 βᵈᵈ 表示处理效果。
我们使用 R 中的经典 DiD 设置来估计处理效果:
formula <- as.formula('hrs_listened ~ treat')
canonical_did <- feols(formula,
data = data, panel.id = "unit",
fixef = c("unit", "period"), cluster = "unit")
summary(canonical_did)

经典 DiD 估计,图片由作者提供。
估计的处理效果是 -0.47(虽然统计上不显著)!但当我们对每个处理组和时期都有正的(且通常较大的)处理效果时,这怎么可能呢?原因在于 DiD 估计量是数据中所有可能的两组/两期 DiD 估计量的加权平均 [2]。换句话说,经典 DiD 估计可以分解为加权平均的两组 x 两期处理估计。这被称为Goodman-Bacon 分解。我们来使用‘bacondecomp’ 包 [2] 获取计算经典 DiD 估计所用的权重和估计值:
# Goodman-Bacon Decomposition
bacon_decomp <- bacon(formula, data, id_var="unit", time_var='period', quietly = F)

Goodman-Bacon 分解,图片由作者提供。
在上面的图像中,我展示了我们的经典 DiD 估计的 Goodman-Bacon 分解。确实,如果我们将这些估计值按其各自的权重相加,就会得到经典 DiD 估计的处理效果:0.5 x -4.53 + 0.5 x 3.59 = -0.47。由于我们有 2 个组和 2 个时期,其中处理指示符发生变化,我们有 2 个比较。
让我们详细检查这个表格,看看问题在哪里。我们从第二个比较开始,其中处理估计为 3.59。在这里,对照组(‘未处理’)是晚处理组,队列 3。‘处理’组是早处理组,队列 2。我将‘处理’和‘未处理’放在引号中,因为如你所记,数据集中所有组最终都会被处理。在这里,‘处理’和‘未处理’更准确地说指的是经典 DiD 估计器所使用的处理和对照组。
让我们继续讨论第一个用红色突出显示的比较,其估计值为-4.53。在这里,早期治疗组(队列 2)被用作晚期治疗组的控制组,这由标准估计器提供!这个比较没有多大意义。然而,如果治疗效果在各个队列和时期之间保持不变,一切都会正常。在这个应用中以及许多其他应用中,情况并非如此。由于早期治疗组的治疗效果更高并且动态增加,比较起来似乎晚期治疗组的治疗效果是负的!将早期治疗组用作晚期治疗组控制组的比较称为禁止比较[2]。
如何解决这个问题?
解决这个问题的主要方法是不要将治疗效果限制为单一估计,并仔细选择控制组。在接下来的步骤中,我们将看到如何做到这一点。首先,我将展示如何在没有特定 DiD 包的情况下解决这个问题。随后,我将使用专为多时间期设计的 R 包。
解决不依赖于特定 DiD 包的问题
首先,让我们解决这个问题,而不依赖于特定的 DiD 包。我知道为了获得治疗效果的良好估计,我需要做两件事:(1)不将治疗效果估计限制为单一系数,(2)确保我有一个好的控制组[3][4]。
本质上,有必要考虑不同队列和时期的治疗效果的变化。此外,确保每个被评估的时期都有未治疗的观察数据也至关重要。因为使用治疗过的观察数据作为控制组可能会导致显著误导的结果,正如之前所提到的那样。
正如你记得的那样,我的数据集中只有治疗组:在第二期接受治疗的队列(早期治疗组)和在第三期接受治疗的队列(晚期治疗组)。显然,由于没有尚未治疗的观察数据可以作为对照组,我无法估计晚期治疗队列的任何治疗效果。
早期治疗组的希望更大,因为在他们接受治疗时,晚期治疗组尚未接受治疗。这意味着我们可以在第二期对早期治疗队列估计治疗效果。然而,由于从第三期开始没有未治疗的观察数据,我们无法估计进一步时期的治疗效果。这就是为什么我们将剔除没有未治疗单位的时期。让我们来编码实现这一点。
# Drop periods that have no untreated units
data <- data[period < 3]
现在,是时候估计我们唯一可以估计的治疗效果了。我们将只对第二期的早期治疗组估计治疗效果。
# Create dummy variables
data <- data %>%
dummy_cols(select_columns = c("cohort_period", "period"))
interact_covs <- 'cohort_period_2:period_2'
# Regression
formula <- as.formula(paste0('hrs_listened ~ ',interact_covs))
model <- feols(formula,
data = data, panel.id = "unit",
fixef = c("unit", "period"), cluster = "unit")
summary(model)

在考虑了错位治疗后的回归结果,图片由作者提供。
从上述结果可以看出,估计的处理效果这次要合理得多:3.6。这意味着封锁措施导致该队列在这个时期的音乐消费增加了 3.6 小时。点估计值与真实处理效果(大约 4 小时)不完全相等,因为数据中存在噪声。
使用‘did’包来解决问题
作为手动处理所有事情的替代方案,我们可以使用did Callaway 和 Sant’Anna 的包 [3]。我们需要做的是使用att_gt函数,利用正确的控制组来估计队列和时期层面的处理效果。下面给出了代码。这里需要注意的一点是,你需要将control_group指定为'notyettreated',因为该函数默认会尝试找到一个未处理的组作为控制组。
# did package
out <- att_gt(yname = "hrs_listened",
gname = "cohort_period",
idname = "unit",
tname = "period",
xformla = ~1,
data = data,
est_method = "reg",
control_group = 'notyettreated'
)
out

打包结果,图像由作者提供。
att_gt函数估计队列-时期特定的处理效果,即ATT(g,t)。我们对第 2 队列第 2 时期的处理效果感兴趣,结果为 3.4 小时。这与我们在没有依赖这个包的情况下估计的处理效果几乎相同。由于估计的精确过程,估计值之间可能会有一些差异。处理前时期的‘处理效果’也在第一行中报告,这在统计上并不显著,因为在这种情况下,我知道干预之前结果变量没有系统性变化。我们还可以用一行代码将这些结果绘制成图:
ggdid(out) # graph the results

可视化打包结果,图像由作者提供。
这个图表也有助于检查处理前后随时间的趋势。这种可视化在估计许多队列和时期的处理效果时特别有用,尽管在这种情况下我只有一个队列和两个时期可以进行估计。
使用这个包相比于我的手动方法有额外的优点。只要你为att_gt函数指定了所需的变量,你不需要做太多其他的操作。你甚至不需要删除没有未处理观察的时期,因为这个包已经考虑了这一点,并且只对有有效控制组的时期进行效果估计。另一个优点是,包默认报告考虑了多重假设检验的均匀置信区间(这会导致由于使用了更高的临界值而使置信带变宽)。两个方法之间的精确估计差异是由于精确的估计方法不完全相同。
回到我们的例子,我们看到封锁对音乐消费有积极的影响(尽管我们只能对一个群体在一个时期进行估计)。这表明实际上音乐听取与居家隔离是互补的。
结论
这里是本文的关键要点:
-
经典的 DiD 方法在存在多个时间周期和治疗时间变化的应用中可能会导致误导性的估计。
-
为了防止这个问题,可以使用考虑多个时间周期和治疗时间变化的估计量。
-
这些估计量适用于错开处理背景,因为它们允许灵活的处理效果,并且仅估计存在有效对照组的时期的处理效果。
-
这不是进行 DiD 分析时可能出现的唯一问题。有关其他问题,请参阅我之前关于事件研究的文章。
参考文献
[1] Angrist, J. D., & Pischke, J. S. (2009). 大多数无害的计量经济学:经验主义者的伴侣. 普林斯顿大学出版社。
[2] Goodman-Bacon, Andrew. (2021) “具有治疗时间变化的差分中的差分.” 计量经济学期刊 225.2: 254–277.
[3] Callaway, B., & Sant’Anna, P. H. (2021). 具有多个时间周期的差分中的差分. 计量经济学期刊, 225(2), 200–230.
[4] Wooldridge, J. M. (2021). 双向固定效应、双向 Mundlak 回归和差分中的差分估计量. 可在 SSRN 3906345 获取.
其他有用的 DiD 资源
视频:
Pedro H.C. Sant’Anna — “具有多个时间周期的差分中的差分”
Andrew Goodman-Bacon “具有治疗时间变化的差分中的差分”
一篇总结近期 DiD 文献的好论文:
Roth, J., Sant’Anna, P. H., Bilinski, A., & Poe, J. (2022). 差分中的差分的趋势是什么?近期计量经济学文献的综述. arXiv 预印本 arXiv:2201.01194.
感谢阅读!
如果你喜欢这篇文章并希望看到更多我的文章,可以 关注我。
免责声明:我写作是为了学习,因此你可能会发现文章或代码中的错误。如果发现,请告知我。
揭示巴西市政影响、公共卫生支出和患者转移之间的关联
一个引人入胜的故事旅程,与 Quarto、Shiny 和 ChatGPT 一同进行
·
关注 发布在 Towards Data Science ·11 min read·Apr 19, 2023
--
照片由 Natanael Melchor 拍摄于 Unsplash
巴西公共卫生系统长期以来一直在努力提高资源分配和提供护理的效率。其中一个主要挑战是患者需要前往其他城市接受必要的医院治疗。根据巴西国家卫生系统的数据,我们估计仅在 2021 年,全国范围内的患者参与了大约 400 万次这样的旅行。本文探讨了一个处理这一公共卫生问题的项目的实施细节。阅读本文对于那些从事公共卫生政策工作的人尤其重要。此外,由于最终产品中使用了详细的代码,该文档也可能引起数据可视化和讲故事领域专业人士的兴趣。
为了更好地理解本文中的问题,我们调查了医院就诊支出与患者流动之间的关系。我们的分析揭示了巴西各城市之间支出存在显著不平等,小城市的支出远低于大城市。我们假设这种支出差异会导致从医院能力低的城市向医疗基础设施更为完善的城市流动的患者显著增加。
我们使用由巴西地理统计研究所(IBGE)制作的城市影响模型(REGIC)来验证我们的假设。我们证明了患者流动主要影响那些管理能力较弱、医院和门诊支出相对较少的小城市。同时,管理和影响能力较强的大城市更可能接收外来患者,从而增加了对其医疗服务的需求。
为了使我们的发现更易于广泛受众访问,我们开发了一个使用 Shiny、Quarto 和数据可视化技术的交互式仪表板。该仪表板允许用户实时探索数据,并通过动态的 ChatGPT 提示提供额外的见解和讲故事元素。通过利用这些工具,我们希望为巴西公共卫生系统面临的挑战提供新的视角,并为改进其绩效的持续努力做出贡献。请通过这个 链接 查看完整分析和交互页面。此外,以下部分展示了产品中使用的一些图表及相关代码。
公共卫生数据图表(和代码)
产品中使用的可视化重点关注了地图。目的是展示巴西城市在接受医院治疗时的旅行需求。下面的图表例如展示了不同城市在患者流向其他城市寻求医疗服务方面的差异。

患者正在旅行寻求援助。图片由作者提供。
下面是构建数据集的代码块,这些数据集将作为使用 ggplot 绘制地图的基础。
agrupamento_municipio<-
dataset_analise %>%
filter(
deslocamento ==1) %>%
group_by(munic_res) %>%
summarise(
numero_internacoes = n()
) %>%
mutate(code_muni = munic_res,
tipo_deslocamento = "saida" ) %>%
bind_rows(
dataset_analise %>%
filter(
deslocamento ==1) %>%
group_by(codufmun) %>%
summarise(
numero_internacoes = n()
) %>%
mutate(code_muni = codufmun,
tipo_deslocamento = "entrada"),
dataset_analise %>%
filter(
deslocamento ==0) %>%
group_by(codufmun) %>%
summarise(
numero_internacoes = n()
) %>%
mutate(code_muni = codufmun,
tipo_deslocamento = "local")
) %>%
group_by(code_muni, tipo_deslocamento) %>%
summarise(
total_internacoes = sum(numero_internacoes)
)
agrupamento_municipio<-
agrupamento_municipio %>%
tidyr::pivot_wider(names_from = tipo_deslocamento, values_from = total_internacoes) %>%
mutate(liquido = ifelse(is.na(entrada),0,entrada)+
ifelse(is.na(local),0,local)-
ifelse(is.na(saida),0,saida))
agrupamento_municipio<-
agrupamento_municipio %>%
mutate(local = ifelse(is.na(local),0,local),
saida = ifelse(is.na(saida),0,saida),
entrada = ifelse(is.na(entrada),0,entrada),
perc_saida = saida/(saida+local)*100,
perc_entrada = entrada/(entrada+local)*100,
perc_entrada = ifelse(is.nan(perc_entrada),0,perc_entrada))
municipios_seat %>%
mutate(code_muni = str_sub(as.character(code_muni),1,6)) %>%
inner_join(agrupamento_municipio
) %>%
inner_join(
REGIC_trabalho%>%
mutate(code_muni = str_sub(as.character(cod_cidade),1,6))
) %>%
ggplot()+
geom_sf(data = estados_mapa, fill=NA, color="#808080")+
geom_sf(aes( fill= perc_saida),pch=21, color="#444444", size=2.9)+
geom_text_repel(data = mun_sel_nivel_1A,aes(x=X, y=Y, label= name_muni),fontface = "bold", color="white")+
geom_text_repel(data = mun_sel_nivel_1B,aes(x=X, y=Y, label= name_muni),fontface = "bold", color="white")+
geom_text_repel(data = mun_sel_nivel_1C,aes(x=X, y=Y, label= name_muni),fontface = "bold", color="white")+
geom_text_repel(data = mun_sel_nivel_2A,aes(x=X, y=Y, label= name_muni),fontface = "bold", color="white", force =2)+
scale_fill_continuous_sequential(palette= "Heat 2")+
labs(
fill= str_wrap("% de pacientes internados em outros municípios",15)
)+
theme_light() +
theme(
text = element_text(size=20),
panel.background = element_rect(fill = "black"),
panel.grid = element_blank(),
axis.title.x = element_blank(),
axis.title.y = element_blank(),
strip.background = element_rect(fill = "#505050"),
strip.text = element_text(color = "white"),
axis.text = element_blank(),
legend.key = element_rect(fill = "#15202B")
)+
facet_wrap(nome_nivel_hierarquia_ordenado~.)
地图上的每个点代表一个市镇,颜色表示该市的患者在其他市镇寻求护理的百分比。我们使用了“facet”数据可视化功能来展示每个城市在 REGIC 模型中的位置。在这个模型中,管理能力最高的最具影响力的城市是大都市(图中的第一行所示的组),而层级最低的是地方中心,位于最后一帧。
从地图上可以看出,地方中心集中在强烈的红色阴影点,这些点显示了需要旅行的患者高比例的市镇。另一方面,地图呈现了所有大都市的子层级,以黄色表示旅行需求较低。
图片强化了我们对极端层级在管理能力方面所赋予的自主权的预期。服务少且管理低时,对其他市镇的医院结构有强烈依赖。
指出 REGIC 模型中各级之间的流动是至关重要的。为此,我们使用汇流图。

REGIC 级别之间的流动。作者提供的图片。
上面的汇流图是使用下面描述的代码构建的。在调用两个构建流图的函数之前,必须进行一些数据处理。
ordem_y<-
dataset_analise %>%
filter(deslocamento==1,
nome_nivel_hierarquia.x == "Centro Local",
!(is.na(nome_nivel_hierarquia.y))) %>%
group_by(nome_nivel_hierarquia.y) %>%
summarise(
quantidade = n()
) %>%
ungroup() %>%
inner_join(
de_para_hierarquia %>%
rename(nome_nivel_hierarquia.y=nome_nivel_hierarquia,
entrada_abreviado = nome_abreviado)) %>%
arrange(quantidade) %>%
mutate(entrada = entrada_abreviado)
aluvial<-
dataset_analise %>%
filter(deslocamento==1,
nome_nivel_hierarquia.x == "Centro Local",
!(is.na(nome_nivel_hierarquia.y))) %>%
mutate(saída = nome_nivel_hierarquia.x,
entrada =nome_nivel_hierarquia.y ) %>%
select(saída, entrada)
aluvial<-
aluvial %>%
inner_join(
de_para_hierarquia %>%
rename(saída=nome_nivel_hierarquia,
saida_abreviado = nome_abreviado)) %>%
inner_join(
de_para_hierarquia %>%
rename(entrada=nome_nivel_hierarquia,
entrada_abreviado = nome_abreviado)) %>%
select(saida_abreviado, entrada_abreviado ) %>%
rename(saída= saida_abreviado,
entrada = entrada_abreviado)
aluvial$entrada <- factor(aluvial$entrada, levels = unique(ordem_y$entrada[order(ordem_y$quantidade)]))
p<-
alluvial_wide( data = aluvial,
max_variables = 2,
fill_by = 'first_variable')
parcats::parcats(p, data_input = aluvial,marginal_histograms = FALSE,labelfont = list(size = 15, color = "black"), sortpaths= "backwards")
上图所示的流动显示了来自被归类为地方中心的市镇的患者主要按以下顺序流动:次区域中心 B(17.6%)、区域首府 C(16.6%)、次区域中心 A(15.7%),仅在第四位的是大都市(13.8%)。由此可见,当患者需要医院护理时,REGIC 等级在某种程度上是被攀升的。大都市对地方中心有吸引力,但其他城市层级的接近和管理能力调节了这一点。
我们发现一些城市在接收来自其他城市的患者方面具有重要意义。因此,我们制作了两张地图,展示了两个在接收患者方面突出的城市的旅行影响,显示了旅行距离和接收的患者数量。为此,请查看下面的地图,重点关注累西腓,这个巴西城市接收了最多的其他地点的患者。

患者前往累西腓。作者提供的图片。
下面的代码稍长。在创建两个图形的两个对象之前,需要进行大量的数据转换。这里我们使用{patchwork}包将图表并排放置。
municipio_selecionado<-"261160"
muni_sel<-
dataset_analise %>%
filter(deslocamento ==1,
codufmun== municipio_selecionado) %>%
group_by(codufmun,nome_nivel_hierarquia_ordenado.y, uf.y) %>%
summarise(quantidade = n()) %>%
rename(code_muni= codufmun,
hierarquia = nome_nivel_hierarquia_ordenado.y,
uf = uf.y) %>%
mutate(tipo_deslocamento = "destino",
distancia = 0) %>%
bind_rows(
dataset_analise %>%
filter(deslocamento ==1,
codufmun== municipio_selecionado) %>%
group_by(munic_res,nome_nivel_hierarquia_ordenado.x, uf.x) %>%
summarise(
quantidade = n(),
distancia =min(distancia)
) %>%
ungroup() %>%
rename(code_muni= munic_res,
hierarquia = nome_nivel_hierarquia_ordenado.x,
uf=uf.x)%>%
mutate(tipo_deslocamento = "origem")
)
muni_sel_posicao<-
dataset_analise %>%
dplyr::filter(deslocamento ==1,
codufmun== municipio_selecionado)%>%
distinct(codufmun, mun_res_lat.x, mun_res_lat.y, mun_res_lon.x, mun_res_lon.y,distancia)
muni_sel_posicao<-
municipios_seat %>%
mutate(code_muni = str_sub(as.character(code_muni),1,6)) %>%
inner_join(
muni_sel_posicao %>%
rename(code_muni= codufmun)
)
muni_sel_repel<-
municipios_seat %>%
mutate(code_muni = str_sub(as.character(code_muni),1,6)) %>%
filter(code_muni %in% c("260960", "260790","120020")) %>% #261160-Recife,260790 -Jaboatão, 260960 - Olinda, 260410 - Caruarau, 260545 - Fernando de Noronha, 120020 - Cruzeiro do Sul-AC
inner_join(muni_sel)
xmin<- min(min(muni_sel_posicao$mun_res_lon.x), min(muni_sel_posicao$mun_res_lon.y)) -1
xmax <- max(max(muni_sel_posicao$mun_res_lon.x), max(muni_sel_posicao$mun_res_lon.y)) +1
ymin<- min(min(muni_sel_posicao$mun_res_lat.x), min(muni_sel_posicao$mun_res_lat.y)) -1
ymax <- max(max(muni_sel_posicao$mun_res_lat.x), max(muni_sel_posicao$mun_res_lat.y)) +1
g1<-
municipios_seat %>%
mutate(code_muni = str_sub(as.character(code_muni),1,6)) %>%
inner_join(
muni_sel
) %>%
ggplot()+
geom_sf(data = estados_mapa, fill=NA, color="#505050")+
geom_curve(data=muni_sel_posicao, aes(x=mun_res_lon.x,y=mun_res_lat.x,xend=mun_res_lon.y,yend=mun_res_lat.y, colour= distancia),
curvature = -.25, ncp = 800,size = 1)+
geom_sf(fill="white",size=1.9,pch=21, color="#444444")+
scale_fill_discrete_qualitative(palette="dark2")+
scale_color_continuous_sequential(palette= "Heat 2")+
coord_sf(xlim = c(xmin,xmax), ylim=c(ymin,ymax))+
labs(
fill= "",
color = str_wrap("distância em Km",10)
)+
theme_light() +
theme(
text = element_text(size=18),
panel.background = element_rect(fill = "black"),
panel.grid = element_blank(),
axis.title.x = element_blank(),
axis.title.y = element_blank(),
strip.background = element_rect(fill = "#505050"),
strip.text = element_text(color = "white"),
axis.text = element_blank(),
)
muni_sel_foco<-
municipios_seat %>%
mutate(code_muni = str_sub(as.character(code_muni),1,6)) %>%
inner_join(
muni_sel%>%
filter(code_muni==municipio_selecionado)
)
muni_sel<-
muni_sel%>%
filter(code_muni!=municipio_selecionado)
set.seed(1972)
g2<-
municipios_seat %>%
mutate(code_muni = str_sub(as.character(code_muni),1,6)) %>%
inner_join(
muni_sel
) %>%
ggplot()+
geom_sf(data = estados_mapa, fill=NA, color="#505050")+#505050
geom_sf( aes(fill=quantidade),pch=21, color="#444444", size=2, show.legend = TRUE)+
geom_sf( data= muni_sel_foco, aes(size=quantidade),pch=21, color="#444444", fill="white")+
geom_text_repel(data = muni_sel_repel,
aes(x=X, y=Y, label= str_wrap(paste(name_muni,":",quantidade),10)),
color = "white",
limits = c(0,2352),
fontface = "bold",
nudge_x = c(0,2,2.5),
nudge_y = c(0,-3.5,2),
show.legend = TRUE)+
geom_text_repel(data = muni_sel_foco,
aes(x=X, y=Y, label= str_wrap(name_muni,20)),
fontface = "bold",
color="white",
nudge_x = c(3),
nudge_y = c(0))+
scale_fill_continuous_sequential(palette= "Heat", trans= "log2" )+
coord_sf(xlim = c(xmin,xmax), ylim=c(ymin,ymax))+
labs(
fill = str_wrap("Quantidade de saídas",15),
size= str_wrap("Quantidade de entradas",15)
)+
theme_light() +
theme(
text = element_text(size=18),
panel.background = element_rect(fill = "black"),
panel.grid = element_blank(),
axis.title.x = element_blank(),
axis.title.y = element_blank(),
strip.background = element_rect(fill = "#505050"),
strip.text = element_text(color = "white"),
axis.text = element_blank(),
legend.key = element_rect(fill = "#15202B")
)
library(patchwork)
g1|g2
当涉及到医院护理时,可以看到大城市累西腓在整个巴西的影响。通过观察左侧地图上的哈弗斯距离,可以看到累西腓为距离超过 4000 公里之外的患者提供服务。样本表明,佩鲁南布科的首府接收了几乎来自全国所有联邦单位的患者。另一方面,当评估右侧地图时,可以发现最显著的影响发生在大都市区的城市,特别是奥林达和贾博瓦豆斯·瓜拉雷佩斯。同时,还可以看到在佩鲁南布科州整个区域延伸的红色阴影点。还可以识别对邻近州的影响,特别是帕拉伊巴、阿拉戈斯和里约格朗德 do 诺特。
在我们原始叙述的最后,我们需要展示较低的市级医院护理支出与患者转移需求较大之间的关联。为了测试这种关联,我们创建了来自 5570 个巴西市镇的患者流失和接收百分比的聚类。利用 PAM 技术生成的轮廓系数,我们确定了四个组:中等进入、弱退出、中等退出和强退出。最后一个组对分析最为重要。以下图表提供了评估组重要性的见解。

医院护理费用的箱线图。作者提供的图片。
下面的代码首先将数据加载到内存中,数据包括市镇的聚类。接下来,使用 ggplot,我们利用 {patchwork} 包构建了两个并排放置的箱线图。
agrupamento_municipio_cluster<-readRDS("agrupamento_municipio_2021.RDS")
g1<-
dataset_analise %>%
filter(deslocamento == 1,
perc.x>0,
perc.x<=50) %>%
distinct(nome_nivel_hierarquia.x,munic_res, perc.x) %>%
inner_join(
agrupamento_municipio_cluster %>%
rename(munic_res=code_muni)
) %>%
ggplot() +
geom_jitter(aes(x=cluster_4_k, y=perc.x, fill=perc_saida), pch=21, color="#444444",size=2)+
geom_boxplot(aes(x=cluster_4_k, y=perc.x),fill=NA, color= "white", outlier.shape = NA)+
scale_fill_continuous_sequential(palette= "Red-Yellow")+
theme_light() +
theme(
text = element_text(size=18),
panel.background = element_rect(fill = "black"),
panel.grid = element_blank(),
axis.title.x = element_blank(),
strip.background = element_rect(fill = "#505050"),
strip.text = element_text(color = "white"),
#axis.text = element_blank(),
axis.text.x = element_text(angle = 45, vjust = 0.5),
legend.key = element_rect(fill = "#15202B")
)+
labs(
fill= "(%) saída",
y = "Gastos Hospitalares e ambulatoriais - (%) do total"
)
g2<-
dataset_analise %>%
filter(deslocamento == 1,
perc.y>0,
perc.y<=50) %>%
distinct(nome_nivel_hierarquia.y,codufmun, perc.y) %>%
inner_join(
agrupamento_municipio_cluster %>%
rename(codufmun=code_muni)
) %>%
ggplot() +
geom_jitter(aes(x=cluster_4_k, y=perc.y, fill=perc_entrada), pch=21, color="#444444",size=2)+
geom_boxplot(aes(x=cluster_4_k, y=perc.y),fill=NA, color= "white", outlier.shape = NA)+
scale_fill_continuous_sequential(palette= "Red-Yellow")+
theme_light() +
theme(
text = element_text(size=18),
panel.background = element_rect(fill = "black"),
panel.grid = element_blank(),
axis.title.x = element_blank(),
strip.background = element_rect(fill = "#505050"),
strip.text = element_text(color = "white"),
axis.text.x = element_text(angle = 45, vjust = 0.5),
legend.key = element_rect(fill = "#15202B")
)+
labs(
fill= "(%) entrada",
y = "Gastos Hospitalares e ambulatoriais - (%) do total"
)
g1|g2
图中的每一个彩色点代表一个巴西市镇。点的颜色表示左侧图中的患者流出百分比和右侧图中来自其他城市患者的出席百分比。在两个图的横轴上,我们可以看到分组,纵轴上则是住院和门诊护理的费用百分比。
通过观察图表,我们可以看到费用与患者流入和流出组之间关联的最重要结论。当我们分析左侧图表中的患者流出时,我们发现强退出组的医院费用中位数远低于其他组。
你是否询问了具有互动性的 Shiny?
正如我们在文本开头所示,除了主要故事外,我们还准备了多个选项卡,允许用户进行过滤并生成探索互动中指定城市现实的图表。请参见下方这些互动的一些截图。注意几乎所有选项卡中都有下载与图表相关的数据的选项。

患者流动 — 巴西。作者提供的图片。

患者流向所选城市。图片由作者提供。

选定城市在箱线图中的位置。图片由作者提供。
一个选项卡显示了可以用于过滤和下载的数据的完整表格。

一张展示市政当局的 X 光图。图片由作者提供。
ChatGPT 怎么样?
最后一张选项卡显示了用户所选市政当局的主要信息汇总数据。请见下方屏幕截图。

信息摘要和生成 ChatGPT 的提示。图片由作者提供。
最后一张表格是可以与 ChatGPT 进行交互的地方。面板动态生成一个包含其他表格数据的提示。用户可以按下复制按钮,将提示带到 ChatGPT,观察神奇的效果。查看一个示例的截图。(如果读者不懂葡萄牙语并想了解提示和 AI 的回应,请通过电子邮件联系我:fbarbalho@gmail.com)。

由应用程序生成的提示。图片由作者提供。

ChatGPT 生成的文本作为对应用程序生成的提示的回应。图片由作者提供。
代码和数据
完整代码可以在github找到。
所有数据集都被归类为公共领域,因为这些数据是由巴西联邦政府机构生产的,作为主动透明性在互联网公布,并且受巴西信息获取法的管辖。
作者感谢Ben Huberman的宝贵评论。
揭示 Word2Vec 的开创之旅及人工智能科学的现状

图片由Finding Dan | Dan Grinwis拍摄,发布在Unsplash
与 Dr. Tomas Mikolov 的深入访谈
·
关注 发布于 Towards Data Science ·19 min read·2023 年 2 月 3 日
--
2012 年,托马斯·米科洛夫博士在捷克共和国的布尔诺技术大学获得了人工智能博士学位,论文题为《基于神经网络的统计语言模型》。在谷歌研究部门工作一年后,他发表了两篇极具影响力的论文,介绍了连续词袋模型(CBOW)和跳字模型,也称为 Word2Vec。因此,单词可以在一个稠密的连续空间中用数字表示,遵循简单的训练过程。这是最早有效捕捉单词语义的数值方法之一,并允许处理更大的词汇表。许多最先进的自然语言处理任务使用这种技术取得了超越性的成果,而 Word2Vec 的继承者仍在如今被认为是最先进的语言模型中扮演重要角色。米科洛夫博士认为复杂系统可能是通向智能语言模型的下一步。然而,要实现这样的智能语言模型,科学范式需要改变,以创建一个平等的竞争环境并允许新颖性。
在他为其博士论文辩护后的十年里,他的研究成果被引用超过 125,000 次,h-指数为 49,i-10 指数为 85,依据 Google Scholar 的数据。2014 年,他移居 Facebook,随后在 2020 年返回捷克共和国。他在捷克信息学、机器人学和网络安全研究所组建了团队,开发一个系统,该系统有望逐渐演变为强人工智能。
你的 Word2Vec 算法在自然语言处理领域是革命性的。你能描述一下促成你工作的那些出版物吗?
一个非常有影响力的研究小组,由心理学家大卫·鲁梅哈特领导,早在 80 年代就开始研究类似的概念。鲁梅哈特的学生之一是杰夫·辛顿,他因在神经网络方面的工作而闻名。在 80 年代,他们已经使用神经网络和分布式表示来表示单词,并展示了有趣的特性。在 90 年代,杰夫·艾尔曼使用递归神经网络来建模语言。他使用了由简单的手工编写的语法生成的人工数据。因此,他的工作有许多简化,并不像今天那样复杂;甚至与我们当前的最先进水平相去甚远。但这是一个非常有启发性和前瞻性的方法来表示语言。1991 年的一篇非常有影响力的出版物《Finding structure in time》讨论了在连接主义模型中表示时间的方法。这对我作为学生在工作语言模型时很有启发性。约书亚·本吉奥在 2002 年左右发表了一篇有影响力的神经语言建模论文,他在小数据集上超越了标准语言建模基准。后来,我和约书亚发表了几篇论文,并在他的团队中待了半年。最后,我发现第一个使用神经网络进行通用序列预测并在具有挑战性的基准上取得最先进性能的人是马特·马洪——他的 PAQ 算法基本上是用于数据压缩的神经语言模型,表现惊人。
谁对你影响最大?
对我影响最大的人最初是马特·马洪,后来是霍尔格·施温克;我发现他的论文比约书亚的更易读。它包含了可以快速实现的方法,而不是使用不必要的复杂方法。因此,我尝试自己做一些类似的事情。当我在 2006 年开始我的硕士论文时,我实现的第一个模型是递归神经语言模型。那时我对这种我刚刚发明的递归网络想法感到非常兴奋,但一开始效果不好——虽然比 n-gram 模型好,但不如简单的前馈神经网络。当时,让这种模型正常工作非常具有挑战性,因为我们不知道如何处理梯度爆炸和消失。在 80 年代和 90 年代,“社区”对递归网络中的学习记忆非常感兴趣。然而,没人知道随机梯度下降是否有效,一些论文声称它无效。此外,尽管人们在小数据集上取得了有限的成功,但没有人能成功地在大数据集上训练递归网络,至少没有牺牲大部分性能。现在,我们知道它们可以这样做,这些过去的故事很难理解。
对我来说,这是一个令人兴奋的故事。我在 2007 年夏天想到从神经语言模型生成文本的想法,并将其与 n-gram 模型生成的文本进行比较(灵感来自 SRILM 工具包)。流畅度的提高非常显著,我立即知道这就是未来。看到这些结果的同时知道我是第一个看到这些结果的人,感觉非常酷——就像发现了一个充满奇怪动物的未知岛屿一样。
当我开始研究 RNN 时,我不知道梯度消失和梯度爆炸的问题。经过一段时间,我成功地让 RNN 在小数据集上表现得非常好。这本身就是一个挑战——评估各种语言模型并进行比较,因为当时所有发布的模型通常都在私人数据上进行评估。此外,代码也没有发布。幸运的是,在 2010 年我在约翰霍普金斯大学的 Fred Jelinek 研究组实习期间,我设法获得了一个数据集。经过一些小的调整,我将其发布在我的网站上,这就是现在非常著名的 Penn Treebank 语言建模基准的由来。它与树库完全无关——它只是我用来比较不同语言建模技术的,同时与 JHU 研究人员之前发布的结果兼容。
我还在 2010 年发布了我的 RNNLM 代码,包括文本生成部分,以便其他研究人员可以轻松地复制我的结果。这是至关重要的:我获得的相对于 n-grams 的改进非常显著,当时几乎没有人相信我的结果是正确的。
然而,随着数据集大小的增加,我的递归网络未能收敛的可能性也在增加。这种在大数据集上混乱的行为是不可预测的。虽然大约 90%的在 Penn Treebank 上训练的模型能够收敛到良好的性能,但在更大的数据集上,这个比例降到了 10%左右。由于 RNN 从头实现很困难,我认为我的代码中一定有错误。我认为我只是计算梯度时出错了,或者遇到了一些数值问题。
我找了几天的错误。最终,我找到了熵激增的地方,并且情况变得更糟。一些梯度变得非常大,覆盖了模型的权重,导致训练出现问题。
你做了什么来解决这个问题?
我的解决方案很粗糙。我将梯度值截断到一个阈值以上。任何数学家看到这个技巧都会觉得很糟糕。不过,主要问题是梯度很少爆炸,因此任何防止爆炸的方法都是一个足够好的解决方案。这种启发式方法有效地使递归神经语言模型能够扩展到更大的数据集。如今,调试代码要容易得多,因为你知道标准模型在标准数据集上期望的结果。但在我的时代,这情况不同。我获得了新的最先进结果,却不知道还可以走多远。这很令人兴奋;我是在攀登一座无人到达过的山峰,而我不知道它有多高。最终,我在宾夕法尼亚树库上的困惑度达到了大约 70,大约是 n-grams 的一半。这一结果保持了相当多年的最先进水平。虽然在这里我可以抱怨,语言建模结果在 2014 年左右被错误报告:随着 dropouts 的发明,研究者们开始专注于用单一模型实现最佳结果。但随后我的所有结果都被丢弃了,这些结果是模型集成的。然而,dropout 技术本质上是一种伪装的集成。
许多人将深度学习的流行上升归因于计算能力的提高和大数据集。但这并不是全部故事。真正让它开始有效的是我们弄清楚了如何正确使用这些算法。
经过这一经验,我发现了深度学习叙事中的另一个不准确之处。在 2014-2016 年,深度学习的流行度猛增,出现了关于为什么此时而非之前出现这种热潮的解释。许多人将这一流行的上升归因于计算能力的提高和大数据集。但这并不是全部故事。真正让它开始有效的是我们*弄清楚了如何正确使用这些算法。例如,你可以拿我的 RNNLM 代码,在 90 年代的硬件和数据集上运行——你会得到远远超过当时技术的最先进结果。
显然,拥有更多的计算能力永远不会有害,这对行业的采用至关重要。然而,研究界对这些算法的正确使用才是决定其受欢迎程度的关键;增加的计算能力是次要的。我还认为开源和整体可重复性也是非常重要的因素。深度‘ ‘学习的历史比许多人现在认为的要丰富得多。
当然,这不仅仅是我;亚历克斯·克里热夫斯基让卷积神经网络(CNNs)在图像分类中发挥了作用,乔治·达尔、阿卜杜勒-拉赫曼·穆罕默德和其他人则弄清楚了如何利用深度神经网络进行语音识别,我们这一代的许多博士生也做出了贡献。
你在博士期间是否已经考虑过以不同方式表示词汇?
确实,当我在谷歌工作时,我并没有想出 Word2Vec;我在那之前已经做过类似的工作。我做的第一件事,与 Word2Vec 类似,是在 2006 年的硕士论文中完成的。当时我对神经网络了解不多。我看到了一篇 Yoshua Bengio 的论文,它使用了一个投影和一个隐藏层。我不知道如何处理具有多个隐藏层的神经网络,所以我决定把模型分成两部分。第一部分就像 Word2vec 一样——它从训练集中学习单词表示。第二个网络则使用这些拼接的表示作为输入来表示上下文并预测下一个单词。两个网络都只有一个隐藏层,结果相当不错——与 Yoshua 的论文相似。
在我的博士期间,我在一次国际会议上发表的第一篇论文就是关于这个模型的。虽然它并不是特别令人印象深刻,但我知道可以通过相当简单的模型来学习好的词向量。后来,我看到几篇论文使用了更复杂的神经网络架构来学习词向量。这在我看来相当愚蠢——人们会训练一个完整的神经语言模型,然后把它扔掉,只保留第一个权重矩阵。但对这个研究领域感兴趣的社区非常小,我认为在这个话题上没有发表任何东西的必要。后来,当我完成博士学业时,我在微软研究院实习,与 Geoff Zweig 合作。他是一个了不起的导师,但有时他会对神经网络是否是语言建模的未来表示怀疑——所以我在考虑如何让他印象深刻。
你做了什么来说服他?
这是一个有趣的故事。我进行了些计算,并在接触他之前仔细检查了结果。然后,我问他是否可以对词向量应用简单的加法和减法。我问他在从‘king’中减去‘man’并添加‘woman’之后,最接近的向量是什么(除了输入词,否则你经常会回到你开始的地方)。
他告诉我这是个相当愚蠢的想法,认为这样没有任何意义。因此,我立即把他带到我的电脑前,展示了实验结果——它返回了‘queen’。他非常惊讶,开始尝试各种操作。他尝试了动词的过去时和复数形式等等。有些想法有效,有些则无效。但这比随机猜测要好得多。第一次看到这些类比非常令人着迷。这引发了基本的问题。为什么会出现这些规律?为什么这是完全线性的?为什么不乘以向量,而是相加和相减?
你的谷歌同事也像你的导师一样持怀疑态度吗?
我不会称 Geoff Zweig 为怀疑,但可以说他非常谨慎。他实际上非常支持,很容易说服他相信某些想法值得追求。我在职业生涯初期遇到过更多麻烦。当我开始研究神经语言模型时,我收到了来自布尔诺理工大学一位当地语言学家的极其负面的评价。他甚至说,使用神经网络来建模语言的整个想法完全是胡扯,而且我的结果一定是假的。他差点儿让我被踢出博士项目。
当我加入 Google Brain 时,一些同事已经在尝试学习词语表示。然而,他们试图训练大型语言模型以获得词向量。在大型语言模型中,99.9%的训练时间,你在更新与词向量无关的参数。从 2006 年我的硕士论文中,我知道如果最终任务不是语言建模,这样的大型语言模型是不必要的。相反,使用更简单的模型来计算词向量就足够了。
我将这一见解与一些同事分享了。然而,没有人真正听取。一些人跟随的是一篇斯坦福论文,这篇论文复杂且包含许多不必要的内容。刚刚开始在 Google Brain 工作时,我的第一个目标是展示如何高效地解决这个问题。我开始尝试,很快就取得了成功。使用普通的台式电脑,我可以在几个小时内训练使用数亿个单词的模型。我的模型击败了一个在许多机器上训练了几周的 Google 内部模型。
那时发生了什么?
Yoshua 刚刚组织了一个新的会议,ICLR,并问我是否可以提交一篇关于词语类比的论文,因为那时这是一个相当令人惊讶的结果。他认为这会是一篇很酷的论文。他在 12 月中旬联系了我;截止日期是在 1 月初。所以我在加州的圣诞假期中写了 Word2Vec 论文。论文写得不是很好,但我更关心的是实现和结果,而不是论文。在同事的支持下,我向 ICLR 提交了论文。但不幸的是,评论非常负面(这是一个公开评审,因此应该仍然可以访问)。一位评审抱怨模型没有考虑词序。另一位评审试图强迫我更多地引用其他论文,而这些论文我已经引用过,并且是在我的硕士论文(其中已经包含了主要想法)之后发表的。
ICLR 2013 的接受率约为 70%。但 Word2Vec 论文被拒绝了。今天,它可能被引用的次数比 ICLR 2013 上所有接受的论文加起来还要多。
这里有一个有趣的细节。虽然现在是一个著名的会议,但这是 ICLR 的第一届,规模很小。接受率约为 70%,所以几乎所有不是完全糟糕的论文都会被接受。但 Word2Vec 论文被拒绝了,尽管今天它可能比 ICLR 2013 上所有接受的论文加起来的引用次数还要多。于是,我决定写另一篇扩展版的论文。这篇论文最终被接受到 NIPS。
你从未在其他地方发表过你的第一篇论文,对吧?
第一篇论文在被 ICLR 会议拒绝后被接受到一个研讨会。但我不认为研讨会算作发表。此外,它被发布在 Arxiv 上,我很高兴人们可以阅读。当我发布它时,我知道它比目前可用的要好——至少在我关心的方面。算法并不复杂,实际上提供了非常好的结果。
你是否预料到这篇论文会被如此广泛引用?
神经语言建模社区在我发布这篇论文时还很小。然而,我非常乐观,预期至少会有五十个人在一年内使用它。论文发布六个月后,它仍然未被注意。这是因为谷歌没有批准我开源代码。最初,他们认为代码是竞争优势。然而,我一直在推动开源。周围的前辈告诉我停止尝试,因为我永远无法获得批准。幸运的是,我认识谷歌脑的高层,他们成功绕过了阻碍。最后,谷歌在 2013 年 8 月左右批准了开源代码。这也是代码有些过度优化的原因:在等待批准的过程中,我对代码进行了调整,使其更短更快。代码开源后,兴趣激增。许多人对谷歌的机器学习活动感兴趣,并喜欢谷歌开源代码。这帮助极大。我确实很惊讶有这么多人开始使用这段代码和预训练模型,甚至在一些情况下超出了建模词汇和语言的范围。
你为什么倡导开源?
作为学生,我发现很难比较不同算法,因为这通常是不可能的。十五年前,发布在私有数据集上评估的语言建模论文而没有任何开源实现是很正常的。在我看来,这就是语言建模研究在过去几十年中没有取得太大进展的主要原因。我曾联系过几位研究人员,询问他们的数据集,但都没有成功。到了某个阶段,没有人能验证已发表的结果,社区也陷入了停滞。我发现某些人甚至在报告结果时作弊,例如,使用弱基线或在测试集上调整超参数后报告最佳结果(甚至在测试集上训练模型,这虽然罕见但并非闻所未闻)。我受到了 Matt Mahoney 在数据压缩社区工作的启发,想要重建我对统计语言建模的兴趣,因此我希望在可能的情况下发布我的代码和数据。当然,一个重要方面是,当我开始发布我的大规模神经语言模型结果时,我的改进幅度之大,以至于几乎整个研究社区都不相信我的结果可能是正确的。但由于没有人能在我的代码中找到任何错误(许多人尝试过——我收到过很多邮件,表示他们终于找到了我代码中的“bug”),我的 RNNLM 工具包被几家大公司使用,语言建模研究终于起飞。这就是自然语言处理领域深度学习的开始。
开源有缺点吗?
我认为有。当新的学生加入人工智能社区时,他们应该尝试开发自己的模型并发现新想法。然而,这非常困难,因为他们最终要与多年由许多研究人员逐步优化的最先进模型竞争。
另一种情况是,学生可以下载别人的代码甚至预训练模型,这些通常很复杂,他们可能并未完全理解。然后他们对其进行调整,做出增量变化,并在论文中发布结果。这种方法要容易得多。然而,这对科学来说是一个危险的发展,因为它将我们锁定在局部最优解中。几个主流观点被过度探索,而很少有人思考可以带来新范式转变的新方法。开源和“发布或死亡”共同促成了一个环境,在这里冒险没有回报。
“拥有最多 GPU 的团队相对于其他团队有很大优势。这使得学术界的人们感到沮丧,并创造了不公平的竞争。对某些基准测试轨道上的已发表论文施加计算限制将是一个简单的解决方案。”
所以开源代码有好处。同时,不利影响也很明显。是否存在中间的‘最佳’方法?
鉴于计算能力的重要性,拥有最多 GPU 的团队相较于其他人具有显著优势。这使得学术界的人们受到挫折,并且造成了不公平的竞争;并不是每个人的起点条件都相同。这就好比你去参加奥运会跑步比赛,但比赛时你却是在与骑自行车的人竞争。不论你多么优秀,你都会输。学生们在资源有限的情况下与科技巨头竞争时也会遇到同样的问题。他们可能有更好的想法,但仍会因为不够前沿而被拒绝。这一问题需要社区来解决。
解决这个问题的一个简单方法是对某些基准测试中的论文发布应用计算限制。按照这种方法,论文应当与能够在X小时内在标准化机器上进行训练的代码一起提交。不过,人们仍可以详尽地探索搜索空间,并提交具有最佳超参数的代码,因此拥有更多计算能力的人仍会占有优势。但至少这样竞争会公平些。顺便说一下,当 Matt Mahoney 提出压缩挑战时,他已经考虑到了这一点。
“许多人认为好的模型看起来复杂且充满了超参数和微调。简单的想法常常被认为不值得发表,因为任何人都可以做到。我认为这种心态完全是愚蠢的。”
机器学习社区还有哪些其他问题?
随着 AI 社区每隔几年就翻倍增长,主导科学家容易左右初级研究者的思维。然而,那些发大量推文和 Facebook 帖子,对所有事情都有强烈意见的主导科学家,并不总是那些做出强大贡献的人。一个由少数主导的资深科学家领导大量初级研究者的社区看起来就像某种邪教。这意味着一些想法、技术或模型被盲目推动,没有真实证据表明这些想法值得付出所有努力。例如,生成对抗网络(GANs)看起来被过度炒作了。这不是新现象。当我还是学生时,我记得对 Latent Dirichlet 分配的受欢迎程度感到困惑——它似乎也不比简单的基线方法更有效。但如今我认为这是一个更大的问题,因为信息传播得更快。
对通过蛮力取得的结果的过度强调体现了这个问题。许多人认为好的模型是看起来复杂且充满超参数的小调整。如果提出一个有效的简单想法,审稿人通常会争辩说任何人都可以做到,因此不值得发表。我已经见过这种情况几次,并且认为这完全愚蠢。实际上,我相信相反的观点:在实践中有效的简单想法是最有价值且最难发现的。就像物理学中,科学家们试图发展越来越通用的理论来解释尽可能多的现象一样。
实际上,这种情况发生在 Word2Vec 上,也发生在我一些语言建模工作上。当一个差劲的审稿人看到两篇有类似想法的论文,但其中一篇还添加了十几个不必要的改动时,这个差劲的审稿人会选择复杂的论文作为更好的那一篇,因为看起来投入的工作更多。实际上,情况往往正好相反——如果你能用一个简单的想法获得最先进的结果,那么这个想法可能真的非常好。
我们如何才能获得更好的审稿人?
机器学习可以从物理学中获得启发。在几个世纪的研究中,物理学家们旨在创建简单的理论以解释一切。与此同时,在机器学习领域则正好相反。我们应该放弃对最先进结果和复杂模型的强调,专注于发现有趣的新想法。当然,这高度主观,如果我们能将机器学习变成一个具有明确规则的奥林匹克项目来决定谁更优秀,那将更好。但正如我之前提到的,我认为这并不容易实现。今天,你可以提出一个惊人的新想法,可能成为下一个最先进的成果,但仍然会因为在某些大型基准上不够最先进而受到社区的打击和拒绝。博士生没有足够的时间来发展自己的方法和思路。我们应该改变这种情况,开始奖励新颖性和简洁性,即使这很难衡量。
或许你听说过 NIPS 的审稿实验。更多的审稿小组对论文进行评审,以查看接受/拒绝决定之间的相关性。结果发现,只有对非常差的论文才有很强的相关性。换句话说,审稿系统是非常随机的。
我们应该致力于创建一个更好的审稿系统。目前,我们在审稿系统中没有质量反馈;系统允许审稿人持续犯错,并且仍然能够审阅更多的论文。我们应该有审稿人数据库,自动跟踪他们的表现。他们的质量应该根据预测成功论文的能力来计算。例如,拥有优秀想法但英语较差的论文应该被接受。
在 IEEE SMC 大会的全体报告中,你提到将复杂系统作为人工智能的下一步发展方向。这是一种优雅地简化计算机科学规则的方法吗?
复杂系统是简单系统中通过你未指定的涌现/进化机制产生的复杂性。以《生命游戏》为例。你从简单的东西开始,然后模拟系统直到各种复杂的结构出现。这一直是我对宇宙的看法。我们周围的许多事物看起来很复杂。然而,这些复杂性可以被视为进化的副产品。自然智能是进化的产物。如果我们想通过人工智能来模拟这一点,我们应该采用类似的方法——允许人工智能进化,并有潜力自发地增加其复杂性。
这与进化算法相比如何?
可以使用进化算法来接近这一点。然而,我认为这些算法并没有很好地捕捉进化。它们进行随机优化。如果适应度函数有所改进,那么你就沿着这个随机方向前进。因此,梯度是随机选择的,而不是计算得出的。但在我看来,这不是进化——毕竟,进化算法往往很快陷入停滞。真实的进化可以在复杂系统中找到,即使是确定性的系统也是如此。《生命游戏》中没有任何随机性;你不需要掷骰子。即便如此,你仍然可以看到新颖的模式出现。我的目标是创建能够自发进化的系统,基于复杂性的涌现。我觉得发现能够在复杂性上隐式增长的机器学习模型具有使我们的 AI 模型更强大的潜力。这可能是让机器学习真正具有创造性的一种方式。
你将如何创建这样的系统?
我怀疑理解涌现现象是解决 AI 问题所必需的。然而,我们对这一方向的理解还不够深入。当我开始研究递归神经网络时,我希望这些能够成为通向有趣的复杂系统的捷径,其中涌现发生在模型的记忆中。但典型的递归网络架构具有一定的记忆容量限制。我们需要设计新颖的机器学习模型、训练算法和评估指标。我正与我的学生一起致力于这个工作。
这将如何为机器学习社区做出贡献?
社区已经体现出了一种群体文化。我们都朝着同一个方向前进,建立在现有的基础上。这种心态可能因为我强烈倡导的开源和公共基准测试而得到了强化。然而,我们所扩展的方法可能是错误的。如果是这样的话,每个人都在建立在有缺陷的假设之上。如果是这样的话,就需要修正。正如我提到的,我们应该探索不同的想法,并在研究社区中奖励新颖性。
这听起来像是一个不再开源的理由。
在我看来,开源是很棒的,我们应该继续这样做。请记住,当几乎没有人发布代码和数据集都是私有的时,研究人员通常不会互相信任对方的结果。语言建模社区几乎已经死去。
与此同时,我们应该避免开源的危险:过多的增量工作、提供微小改进的细微调整(有时仅仅如此),以及对探索新想法的气馁。
这标志着采访的结束,您还有什么最后的评论吗?
我们应该对原创性和新方向更加开放。然而,这很难判断。我们是否希望在会议上看到看似疯狂的想法?作为一个社区,我们需要让会议变得更加有趣,而不仅仅是看到数百种 Transformers 的修改或它们在数百个数据集上的应用。让我们更有雄心,更具探索性。
本次采访由 BNVKI,即贝尔赫斯人工智能协会,进行。我们汇聚了来自比利时、荷兰和卢森堡的人工智能研究人员。
揭示 DAX 中 KEEPFILTERS 的秘密
原文:
towardsdatascience.com/uncovering-the-secrets-of-kepfilters-in-dax-6d268e3565d0
DAX 中的 KEEPFILTERS()函数是一个被低估的函数。因此,我决定深入研究这个函数,并为你提供一些有趣的细节以及一个惊人的效果。
·发布在Towards Data Science ·8 分钟阅读·2023 年 7 月 13 日
--

引言
当我们在 DAX 中使用 CALCULATE()函数时,我们通常会添加这样一个简单的筛选器:
产品[Color] = “绿色”
此筛选器用“绿色”值替换[Color]列上的任何现有筛选器。
但有时,我们需要额外一步,保留表格或列上的现有筛选器,以执行一些有趣的计算。
有时,我们的度量值会得到错误的结果,我们无法理解为何会发生这种情况。
在这些情况下,KEEPFILTERS()函数可以帮助我们。
源查询
首先,让我们定义我们想要操作的查询。
我想获取按颜色分类的在线销售列表:
DEFINE
MEASURE 'All Measures'[Online Sales] = SUMX('Online Sales', [UnitPrice]*[SalesQuantity])
EVALUATE
SUMMARIZECOLUMNS('Product'[Color]
,"Online Sales", [Online Sales]
)
我使用 SUMX 将[UnitPrice]乘以[SalesQuantity]。
结果如下:

图 1 — 基础结果(作者图)
当我添加筛选器并使用 CALCULATE()时,查询如下所示,如上所述。
// Only Green Sales
DEFINE
MEASURE 'All Measures'[Online Sales] = SUMX('Online Sales', [UnitPrice]*[SalesQuantity])
MEASURE 'All Measures'[All Green Sales] =
CALCULATE([Online Sales]
,'Product'[Color] = "Green"
)
EVALUATE
SUMMARIZECOLUMNS('Product'[Color]
,"Online Sales", [Online Sales]
,"Green Sales", [All Green Sales]
)
结果如下:

图 2 — 所有行的绿色销售(作者图)
这是因为我们将[Color]列上的筛选器替换为“绿色”。因此,度量值在所有行上返回相同的值,其中[Color] = “绿色”。
介绍 KEEPFILTERS()
好吧,我们可以用 KEEPFILTER()做些什么?
当我们在度量值中添加 KEEPMFILTERS()时,CALCULATE 将保留每行的筛选上下文,并在表达式中添加筛选器:
// Only Green Sales with KEEPFILTERS()
DEFINE
MEASURE 'All Measures'[Online Sales] = SUMX('Online Sales', [UnitPrice]*[SalesQuantity])
MEASURE 'All Measures'[All Green Sales] =
CALCULATE([Online Sales]
,KEEPFILTERS('Product'[Color] = "Green" )
)
EVALUATE
SUMMARIZECOLUMNS('Product'[Color]
,"Online Sales", [Online Sales]
,"Green Sales", [All Green Sales]
)
这是新的结果:

图 3 — 使用 KEEPFILTERS() 的绿色销售(作者绘制的图)
好的,很棒。
那现在呢?
时尚
现在我们可以向我们的测量中添加一些逻辑。例如,我们可以仅对绿色产品进行销售计算。
例如,我们将绿色产品的销售额加倍:
// Perform some dynamic calculations - Double the Green Sales
DEFINE
MEASURE 'All Measures'[Online Sales] = SUMX('Online Sales', [UnitPrice]*[SalesQuantity])
MEASURE 'All Measures'[All Green Sales] =
CALCULATE([Online Sales]
,KEEPFILTERS('Product'[Color] = "Green" )
)
EVALUATE
SUMMARIZECOLUMNS('Product'[Color]
,"Online Sales", [Online Sales]
,"Green Sales", [All Green Sales]
,"Dynamic Sales", IF(ISBLANK([All Green Sales])
,[Online Sales]
,[Online Sales] * 2
)
)
我使用IF()和ISBLANK()来检查销售是否为绿色产品。
如果绿色销售的测量结果为空,我将返回[在线销售]测量的结果。
如果没有,我将[在线销售]测量的结果加倍。
看看结果:

图 4 — 动态销售结果(作者绘制的图)
但我们如何在 Power BI 中使用这个机制呢?
例如,我希望能够选择一种颜色,并对这种颜色的销售进行特定的计算。
首先,我向数据模型中添加了一个新表,但没有在数据模型中添加任何新的关系:
All Colors = SUMMARIZECOLUMNS('Product'[Color])
表格如下所示:

图 5 — 所有颜色表(作者绘制的图)
现在,我将这个列添加到我的报告中的切片器中。
接下来,我的测量必须获取选择的颜色并将其作为筛选器添加:
Modify by selected color =
VAR SelectedColor = SELECTEDVALUE('All Colors'[Color])
VAR CalcByColor = CALCULATE([Online Sales (By Order Date)]
,KEEPFILTERS('Product'[Color] = SelectedColor)
)
RETURN
IF(ISBLANK(CalcByColor)
,[Online Sales (By Order Date)]
,[Online Sales (By Order Date)] * 2
)
这样,我可以根据新表中选择的颜色执行计算:

图 6 — 基于选择颜色的计算结果(作者绘制的图)
这种技术为我们的计算开辟了许多可能性,因为我们可以对某一行进行计算而不影响其他行的结果。
使用上下文转换
但在某些情况下,理解 KEEPFILTER() 的值是至关重要的:上下文转换。
你可以通过阅读我关于这个话题的文章来了解更多关于上下文转换的内容:
行和筛选上下文是 DAX 中的常见概念。但我们可以通过上下文转换在这两者之间切换。
[towardsdatascience.com
当我们在测量中使用上下文转换与所谓的任意形状集一起时,情况会很复杂(稍后会详细说明)。
为了展示这一点,我稍微修改了我们的例子:
我想创建一个切片器,通过品牌和颜色的所有组合来筛选产品表。
然后,我想计算每个品牌和颜色的平均销售额。
在这个例子中,我不使用产品表中的列。我想要一个单独的表来模拟实际场景。
为了实现这一点,我使用 Power Query 从原始产品表中提取一个表,获取所有品牌和所有颜色的列表。此外,我添加了一个包含品牌和颜色列组合的关键列。
这里是结果表的摘录:

图 7 — 带有品牌和颜色及关键列的表格(图示由作者提供)
我将相同的关键列添加到产品表中。
现在我可以在这两张表之间添加一个关系:

图 8 — 扩展的数据模型(图示由作者提供)
现在,我创建了以下度量:
Average over Brand = AVERAGEX(VALUES('Brand Colors'[Brand])
,[Online Sales (By Order Date)]
)
但是当我们尝试验证结果时,我们会遇到困难。
原因在于没有任何控制结果时很难理解结果是否正确。
所以我们要么在 Excel 中重新计算结果(或其他可能的地方),要么更改度量以使用 SUMX()。
这使得生活更轻松,因为我们将能够将结果与现有的在线销售度量进行比较。
这里是 Power BI 中的结果:

图 9 — 复杂过滤器的新度量结果(图示由作者提供)
如果你仔细查看结果,会发现有些问题。
小计和总计远高于每行结果的总和。
原因在于过滤器的应用方式。
对于这个表,我们期望有如下的过滤器:
(Brand = "A. Datum" AND Color IN ("Black", "Blue")
OR
(Brand = "Adventure Works" AND Color IN ("Grey", "Silver")
这样的集合被称为“任意形状的集合”,因为我们混合了来自两个独立列的不同值。
当我们查看每个小计时,我们会期望有两个过滤器:
对于 Adventure Works,我们期望如下:
Brand = "Adventure Works" AND Color IN ("Grey", "Silver")
对于 A. Datum,我们期望:
Brand = "A. Datum" AND Color IN ("Black", "Blue")
实际上,我们得到了两个完全不同的过滤器:
对于 Adventure Works 的小计,我们有如下过滤器:
Brand IN ("Adventure Works", "A. Datum") AND Color IN ("Grey", "Silver")
对于 A. Datum 的小计,我们有如下过滤器:
Brand IN ("Adventure Works", "A. Datum") AND Color IN ("Black", "Blue")
这意味着度量计算所选颜色的所有销售总和,但结果中包括了两个选定品牌。
当我们添加新的矩阵可视化并从产品表中添加品牌和颜色列,并将结果与标准在线销售度量进行比较时,我们可以证明存在一些奇怪的情况:

图 10 — 用基础度量验证结果(图示由作者提供)
如你所见,这两个示例之间的结果不同,这使得这一效果极其令人困惑。
目前应用的过滤器如下:
(Brand = "A. Datum" AND Color IN ("Black", "Blue", "Grey", "Silver"))
OR
(Brand = "Adventure Works" AND Color IN ("Black", "Blue", "Grey", "Silver"))
参考文献部分提到的 SQLBI 文章更详细地解释了这一效果。
为了解决这个问题,我们可以使用 KEEPFILTERS() 来强制从切片器中获取完整的过滤上下文:
Average over Brand = SUMX(KEEPFILTERS(
VALUES('Brand Colors'[Brand]))
,[Online Sales (By Order Date)]
)
现在结果如预期一样:

图 11 — 添加 KEEPFILTERS() 后的结果(图示由作者提供)

照片由 Akhilesh Sharma 在 Unsplash 上提供
结论
DAX 函数 KEEPFILTERS() 非常有用,有时是关键功能。
我并不是建议在使用上下文转换时总是使用 KEEPFILTER()。
但您需要意识到使用上下文转换的后果,以及用户在报告中使用切片器时创建任意形状集的可能性。
在撰写本文时,我不知道使用上下文转换添加 KEEPCFILTERS()是否有任何缺点。
但我喜欢保持简单,不必要的东西就不添加。
无论如何,这篇文章最重要的教训应该是“只相信您可以证明和验证的结果”。
有一些函数在验证时可能非常具有挑战性。其中两个是 AVERAGE 和 COUNTDISTINCT。这两个函数返回的结果可能难以证明。
但这是另一个故事。
参考资料
SQLBI 的 KEEPCFILTERS()介绍:www.sqlbi.com/articles/using-keepfilters-in-dax-updated/
阅读 SQLBI 撰写的这篇文章,了解一些有趣的细节:www.sqlbi.com/articles/keepfilters-a-new-dax-feature-to-correctly-compute-over-arbitrary-shaped-sets/
当我们使用迭代器时,我们使用上下文转换。这里有另一篇 SQLBI 文章,关于这个话题:www.sqlbi.com/articles/when-to-use-keepfilters-over-iterators/
我使用了 Contoso 样本数据集,就像在我之前的文章中一样。您可以从 Microsoft 这里免费下载 ContosoRetailDW 数据集。
Contoso 数据可以根据 MIT 许可自由使用,详见这里。
[## 每当 Salvatore Cagliari 发布新内容时,您将收到电子邮件。
每当 Salvatore Cagliari 发布新内容时,您将收到电子邮件。通过注册,如果您还没有,您将创建一个 Medium 帐户…
medium.com](https://medium.com/@salvatorecagliari/subscribe?source=post_page-----6d268e3565d0--------------------------------)
理解并实现带掩码的自回归流与 TensorFlow
原文:
towardsdatascience.com/understand-implement-masked-autoregressive-flow-with-tensorflow-9c361cd1354c
使用 TensorFlow 进行密度估计的流模型
·发表于Towards Data Science ·8 分钟阅读·2023 年 2 月 21 日
--

图:从随机到不那么随机!来源:作者笔记本(见下文参考文献)。
之前我们详细介绍了正常化流背后的数学以及一些变换概率分布的示例。在这里,我们结合所有这些概念来理解自回归流以及如何使用 TensorFlow Probability 库实现它们。您可以从这篇文章中期待什么 —
-
为什么三角矩阵对自回归流至关重要?
-
自回归流模型的基本构造
— 掩码自回归流(MAF)
— 反向自回归流(IAF)
3. 如何在 TensorFlow 中实现 MAF 并训练它们以进行密度估计任务?
不再耽搁,让我们开始吧!
正常化流中的计算问题:
在讨论诸如掩码自回归流等模型之前,我们回顾一维和更高维场景中的变量变换规则,这将帮助我们理解正常化流中的计算成本。
之前我们详细讨论了如何推导变量变换规则,其中我们从基础分布 u 和双射 ϕ 开始,使得 x = ϕ(u)。在这种情况下,我们可以简单地写出变量变换规则如下:

等式 1:一维的变量变换规则
对于归一化流,我们组合(‘链式’)几个双射,将简单分布转变为更复杂的分布。例如,我们可以如下组合 K 次双射操作,将我们的基础分布(如 u_0)转换为我们所需的复杂分布 x。

方程 2:组合双射以将简单分布转变为复杂分布。
对于 K 变换,我们可以如下修改方程 1:

方程 3:将方程 1 从 1 次双射操作重写为 K 次变换。
实现归一化流的最大问题之一是计算对数-行列式雅可比矩阵的计算复杂度。通过像高斯消去这样的过程计算 n×n 的雅可比矩阵的行列式具有 运行时间复杂度 为 O(n³)。
因此,我们需要对上述过程进行一些简化,现在可以开始学习一些有助于减少计算复杂度的简化方法。
三角矩阵与自回归流:
如果变换矩阵是三角形的,那么计算行列式相当容易。对于 n × n 的方阵,运行时间复杂度是 O(n)。

计算三角矩阵的行列式只需要对角元素。
如上所示,对于三角矩阵(上三角/下三角),我们只需要对角元素来计算行列式,因此运行时间复杂度是线性的。
让我们考虑下三角矩阵,其中 a[i][j]=0; j > i。我们在自回归流中施加了非常类似的概念。
我们考虑一个 D 维向量 u,它经过 1,2,…K 次变换,就像以前一样。基于流模型的思想是使用一个(或一系列)变换 ϕ 对从 p_u(u) 采样得到的实际向量 u 进行操作。在双射与微分同胚的基础知识中,我们讨论了当 u 作为 D 维向量通过 K 次微分同胚变换为 x 时,我们说基础分布是 D 维的,最终分布 (x) 也将是。这样,我们施加了自回归条件,以便获得一个三角矩阵来计算对数-行列式雅可比矩阵,如下所示:

方程 4:自回归条件(左侧的方程)产生一个三角形的对数-行列式雅可比矩阵。
如果你曾经使用过 ARIMA 模型进行时间序列分析,那么你会知道 自回归 项表明时间序列基于过去的值。因此,我们可以约束序列数据 [x1, x2, …, xD],其中每个输出(在特定步骤)仅依赖于之前观察到的值,而不是未来的值。用更数学化的符号表示,就是观察 xi 的概率以 x1,…,xi−1 为条件,这些条件概率的乘积给出观察完整序列的概率:

Eq. 5: 条件概率的乘积给出观察完整数据 X 的概率。
条件密度的建模由我们选择,已经提出了多种方案,从简单的单变量高斯分布到甚至神经网络。让我们讨论一些流行的方法!
Masked Autoregressive Flow (MAF):
对于 MAF,上述 Eq. 5 中描述的条件分布将被视为简单的正态分布,如下所示:

Eq. 6: Eq. 5 中的条件分布假设为简单高斯分布
也可以从基础分布 u 生成新数据,如下所示:

Eq. 7: 给定基础分布 (u) 生成新点,即一组随机数
上述方程告诉我们另一个将自回归模型视为从随机数空间 (u) 到数据 x 的变换 f 的方法。由于这些变换是仿射的(缩放和偏移),为了找回基础变量 u_i,我们不需要逆转这些函数。这也在 MADE 论文中提到:

Eq. 8: 从 Eq. 7 中的变换中反转回基础变量。
这对训练策略极为重要,因为我们不需要显式计算函数 f_αi、f_μi 的逆,只需对它们进行一次评估(例如,在前向传递时),我们可以使用不可逆的函数,如 RELU。
在 Eric Jang 的博客 中提供了 MAF 的前向传递(以及反向传递)的优秀图示(查看参考文献)。

图 2: MAF 的前向传递。来源: Eric Jang 的精彩博客 关于归一化流。
这些是 Masked AutoRegressive Flow for Density Estimation 论文的基础,希望这个归一化流系列能帮助你解读其中的大部分内容。
Inverse Autoregressive Flow (IAF):
IAF 的变换规则在MAF论文中也有清晰解释。IAF 与 MAF 的主要区别在于,对于计算缩放和偏移变量(用于仿射变换),我们使用随机变量(u)而不是数据变量(x)。以下是变换规则:

等式 9: 将其与等式 7 对比,查看 MAF 和 IAF 之间的差异
MAF 和 IAF 之间惊人的相似性是难以忽视的;IAF 的逆是 MAF 的前向传播!
使用 TensorFlow Probability 训练 MAF:
在掌握基础知识后,我们现在准备使用 TensorFlow Probability 来实现 MAF,并训练它以产生特定的分布。我们使用sklearn.datasets,特别是如下所示的make_circles数据集:
1. 加载数据集:
circle_dataset = datasets.make_circles(noise=0.05, factor=0.99,
random_state=1, n_samples=1600)
X_circle, Y_circle = circle_dataset
#standardize
X_circle_normed = StandardScaler().fit_transform(X_circle)
Y_circle = Y_circle.astype('bool')
X_train_c, Y_train_c = X_circle[…, 0], X_circle[…, 1]
#figure section
fig = plt.figure(figsize=(6, 4))
fig.add_subplot(111)
plt.scatter(X_train_c[Y_circle], Y_train_c[Y_circle],
s=10, color='blue', alpha=0.4)
plt.scatter(X_train_c[Y_circle == False], Y_train_c[Y_circle == False],
s=10, color='red', alpha=0.5)
plt.legend(['label: 1', 'label: 0'])
plt.show()

图 3: 我们的目标分布 [来源: 作者笔记本]
这将是我们的目标分布,数据点以圆形方式分布。我们的目标是从随机分布开始,通过使用 MAF 达到这种有序分布。在 TensorFlow Probability 中,实现 MAF 相当简单,因为它存在一个称为MaskedAutoregressiveFlow的双射器。在这个双射器中,对于仿射函数(平移和缩放),我们可以使用另一个双射器AutoregressiveNetwork,它实现了Masked AutoEncoder for Density Estimation (MADE)架构,作者建议这种方法是从自动编码器中一次通过得到联合概率的计算上便宜的方式。下面我们来看实现:
通过 MAF 的前向传播,基础分布为正态分布
我们从一个正态分布开始,并定义 MAF 函数,其中仿射变换由 MADE 架构定义。MADE 架构包含 2 个隐藏层,每层 32 个单元,激活函数为‘Relu’。根据前一篇文章,我们描述了如何使用TransformedDistribution变换分布,我们使用正态分布作为基础,MAF 作为双射器。最后,我们绘制了变换分布的概率等高线图。

图 4: 从正态分布开始,我们将其通过 MAF,其中仿射变换由 MADE 结构提供(激活函数为 Relu)。来源: 作者笔记本

图 5: 与图 4 相同,但激活函数由 Relu 改为 sigmoid。来源: 作者笔记本。
一旦我们定义了前向传播,我们现在就可以开始训练 Flow 模型。我们在最小化负对数似然,首先从一个双射器(MAF 网络)开始训练:
仅使用双射操作(MAF)的训练循环
正如预期的那样,一旦我们绘制训练后的分布,很容易看出结果远离真实分布。我们从训练后的分布中采样,以绘制下面的图:

图 6:从随机分布(中间)开始,仅使用双射,我们无法复制真实分布。来源:作者的笔记本
但我们也知道,基于流的模型的思想是链式双射将简单分布转换为复杂分布。在这里,我们不再仅使用 1 个双射,而是链式使用 4 个双射(MAF),并最小化最终分布的负对数似然:
num_bijectors = 4
bijectors=[]
for i in range(num_bijectors):
masked_auto_i = make_maf(hidden_units=[128, 128], activation='relu')
bijectors.append(masked_auto_i)
bijectors.append(tfb.Permute(permutation=[1, 0]))
# data is only 2 dimension, so we interchange 0, 1
flow_bijector = tfb.Chain(list(reversed(bijectors[:-1])))
排列部分确保了D维数据(这里D=2)的不同维度之间相互影响。如果维度的排序从未改变,这会大大降低在归一化流中链式双射的表达能力。在链式双射时,我们丢弃了最后的排列,因为它在训练中无关紧要(没有其他操作跟随这个排列)。有了这个更具表达力的模型,我们可以期待一些有趣的东西!让我们看看训练分布的概率密度等高线图:

图 7:看起来与我们的目标数据(图 3)非常相似!!来源:作者的笔记本!
这很酷,但我们也可以绘制样本分布,因为我们的基础正态分布通过这 4 个 MAF 双射,来看一下:

图 8:通过链式 4 个 MAF 双射将随机样本转换为圆形样本。来源:作者的笔记本。
我们从归一化流的基础开始,即微分同胚和概率分布的变换规则等。然后我们使用 TensorFlow 概率库实现了一些这些概念,将概率分布从正态分布转变为双峰分布,详见第二篇文章。最后,我们了解了最先进的自回归流模型的基础知识以及三角矩阵在这方面的重要性。最后,我们使用 TensorFlow 实现了 MAF,并展示了训练链式 MAF 以将正态分布转换为稍微复杂一点的分布的示例。希望这能帮助你入门流式模型,并迈出向扩散模型发展的第一步!!
如果你对进一步的基础机器学习概念感兴趣,你可以考虑加入 Medium 使用 我的链接。你无需支付额外费用,但我会获得一小部分佣金。感谢大家!!
[## 使用我的推荐链接加入 Medium - Saptashwa Bhattacharyya
更多来自 Saptashwa(以及 Medium 上的许多其他作者)。您的会员费直接支持 Saptashwa 和其他…
参考文献:
[1] 用于密度估计的掩蔽自回归流: Papamakarios, G. 等
[2] 使用 RealNVP 进行密度估计: Dinh, L. 等
[3] 归一化流:第二部分; Jang, E. 的博客
[4] 基于流的模型; Weng, L. 的博客
[6] 这里使用的代码笔记本:我的 GitHub
理解 Polars 缺乏索引
从 Pandas 切换到 Polars,忘记索引吧
·
关注 发布于 Towards Data Science ·7 分钟阅读·2023 年 1 月 6 日
--
一只北极熊与一只熊猫竞赛——来源:openai.com/dall-e-2
Pandas 和 Polars 是两个 Python 的数据框库。在上一篇文章中,我这样写过 Pandas 和索引的内容:
为了高效使用 Pandas,忽略它的文档,学习关于索引的[复杂]真相。
相比之下,原始的 Polars 书籍 这样说 Polars 和索引:
不需要索引!没有它们会使事情变得更简单——说服我们相信相反的观点!
我们真的可以忘记索引吗?让我们测试一下 Polars 的说法。我们将把我之前文章中的所有示例从 Pandas 移植到 Polars。这将让我们了解在没有索引的情况下工作的实际情况。
最终,我们将看到:
-
“索引是不必要的!没有索引使事情更简单。”
-
如果你认为你 真的 需要一个索引,你 真的 需要一个字典。
为了达到这一点,我们首先创建一个数据框并检索一行。
构建和简单行检索
在 Pandas 中,我们这样构建数据框并设置索引:
import pandas as pd
df1 = pd.DataFrame([['a',2,True],
['b',3,False],
['c',1,False]],
columns=['alpha','num','class'])
df1.set_index(['alpha'],inplace=True)
df1

我们用以下方法将关键字b转化为感兴趣的行:
df1.loc[['b']] # returns row 'b' as a dataframe

在 Polars 中,我们可以通过这样的方式从行构建数据框:
import polars as pl
df1 = pl.from_records([['a', 2, True],
['b', 3, False],
['c', 1, False]],
orient='row',
columns=['alpha', 'num', 'class'])
df1
然而,Polars 是以列为中心的,所以构建相同数据框的更好方法是:
import polars as pl
df1 = pl.DataFrame({'alpha':['a','b','c'],
'num':[2,3,1],
'class':[True,False,False]
})
df1

我们不需要设置索引。我们用以下方法将关键字b转化为感兴趣的行:
df1.filter(pl.col("alpha")=='b')
filter 方法用于查找感兴趣的行。表达式 pl.col("alpha")=='b' 告诉 filter 要查找哪些行。与 Pandas 相比,我发现 Polars 的方法更简单、更通用。(我们稍后将讨论性能问题。)
我们从找到一个简单的行转向找到行的数量。
查找行号
在 Pandas 中,你可以通过 index.get_loc(...) 查看感兴趣行的数字:
import pandas as pd
df2 = pd.DataFrame([['x',2,True],
['y',3,False],
['x',1,False]],
columns=['alpha','num','class'])
df2.set_index(['alpha'],inplace=True)
print(f"{df2.index.get_loc('y')=}")
print(f"{df2.index.get_loc('x')=}")

如示例所示,函数仅在单个项匹配时返回一个数字。当多个项匹配时,它返回一个布尔数组。
在 Polars 中,你应该首先问自己是否真的需要找到行号。答案通常是“否”。但是,如果你回答“是”,你可以使用 arg_where。
df2 = pl.DataFrame({'alpha':['x','y','x'],
'num':[2,3,1],
'class':[True,False,False]})
df2.select(pl.arg_where(pl.col("alpha")=='y')).to_series()

df2.select(pl.arg_where(pl.col("alpha")=='x')).to_series()

结果是一个 Polars series。在 Polars 中,series 代表一列值,这里是行号。
比较 Pandas 和 Polars 在查找行号方面的复杂性,我发现两者类似。然而,Polars 可能会通过减少行号的重要性而占据优势。
接下来我们来看复杂的行访问。
行访问
在 Pandas 中,访问索引行的主要方式是 .loc[…],其中输入可以是:单个元素、元素列表或元素切片。行将按输入中出现的顺序输出。这些示例展示了每种输入方式。
df3 = pd.DataFrame([['i',2,True],
['j',3,False],
['k',5,False],
['j',1,False]],
columns=['alpha','num','class'])
df3.set_index(['alpha'],inplace=True)
df3.loc['j']

df3.loc[['k','i']]

df3.loc['i':'k']

请注意,与 Python 的其他部分不同,Pandas 中的 start:stop 切片包括 stop 值。同时注意,Pandas 排除了第二个‘j’行,因为它在(第一个)‘k’行之后。
在 Polars 中,我们使用 filter 和表达式:
df3 = pl.DataFrame({'alpha':['i','j','k','j'],
'num':[2,3,5,1],
'class':[True,False,False,False]})
df3.filter(pl.col("alpha")=='j')

df3.filter(pl.col("alpha").is_in(['k','i']))

df3.filter(pl.col("alpha").is_between('i','k',include_bounds=True))

默认情况下,Polars 的is_between不会包含其边界,但可以选择包含其中一个或两个边界。同时,请注意,Polars 包含了第二行的‘j’。Polars 基于字母顺序而非行顺序查看in_between(在字符串值上)。
由于它不需要索引,我发现 Polars 在这些更复杂的行检索中比 Pandas 更简单。
对于我们的最后一个基本任务,让我们看看连接行。
连接行
在 Pandas 中,左连接的规则是:
-
左侧的数据框不需要被索引,但右侧的数据框需要。
-
在连接的
on输入中给出左侧感兴趣的列。
在这个示例中,我们将使用join向数据框中添加一个“score”列。这里是左侧的数据框。它没有被索引。
df_left = pd.DataFrame([['x',2,True],
['y',3,False],
['x',1,False]],
columns=['alpha','num','class'])
在 Pandas 中,右侧的数据框需要一个索引,但它可以命名为任何名称。这里我们称它为any_name。
df_right = pd.DataFrame([['x',.99],
['b',.88],
['z',.66]],
columns=['any_name','score'])
df_right.set_index(['any_name'],inplace=True)
我们通过左连接组合这两个数据框。我们使用第一个数据框中的alpha列以及第二个数据框中的任何索引。结果是一个包含分数列的新数据框。
df_left.join(df_right,on=['alpha'],how='left')

在 Polars 中,一切都很类似,但稍微简单一些:
df_left = pl.DataFrame({'alpha':['x','y','x'],
'num':[2,3,1],
'class':[True,False,False]})
df_right = pl.DataFrame({'alpha':['x','b','z'],
'score':[.99,.88,.66]})
df_left.join(df_right,on=['alpha'],how='left')

区别在于我们不需要为右侧数据框建立索引。如果感兴趣的列具有相同的名称(如这里所示),我们使用on。如果没有,我们使用left_on和right_on。
所以,Polars 再次比 Pandas 更简单使用,但代价是什么呢?
性能
当然,缺少索引会使 Polars 变慢。但令人惊讶的是,实际上并非如此。在广泛的基准测试中,Polars 比 Pandas 快得多 [Vink, 2021]。它通过优化实现了这一点,包括良好的内存布局和自动向量化/并行化。
我们能否构造出 Pandas 比 Polars 更快的情况?是的,如果我们将数据框用作字典,Pandas 可能比 Polars 快 20 倍。然而……
猜猜什么比将 Pandas 用作字典快 300 倍?答案是:将字典用作字典。
在这个测试中,我们构造了一个包含两个列的数据框,填充了从 0 到 999,999 的数字。然后我们寻找数字 500,000。
import polars as pl
import pandas as pd
n = 1_000_000
df_pl = pl.DataFrame({'a':list(range(n)),'b':list(range(n))})
%timeit df_pl.filter(pl.col("a")==n//2)
df_pd = pd.DataFrame({'a':list(range(n)),'b':list(range(n))})
df_pd = df_pd.set_index('a')
%timeit df_pd.loc[n//2]
dict_pl = df_pl.partition_by('a',as_dict=True)
%timeit dict_pl[n//2]
以下是我在 4 核笔记本电脑上多次运行的平均结果:

总结性能:根据其他基准测试,对于典型使用情况,Polars 比 Pandas 更快。对于特殊情况——例如,当你确实需要使用字典时——Polars 提供了创建字典以获得最快性能的工具。
结论
在我看来,消除索引使得 Polars 比 Pandas 更易于使用。
你可能会预期这种简化会导致性能变慢。然而,基准测试显示 Polars 通常比 Pandas 快得多。它通过包括良好的内存布局和自动向量化/并行化等优化来实现这一点。可能仍然存在需要类似索引的数据结构的情况。对于这些情况,Polars 提供了创建字典等工具。
那么,你应该从 Pandas 切换到 Polars 吗?这要视情况而定。
我们的基因组学项目,FaST-LMM 使用 Pandas 输出统计结果表格。FaST-LMM 几乎所有的计算工作都在 Pandas 之外用自定义代码完成。它仅使用 Pandas 与我们的用户分享最终结果,我们可以假设这些用户了解 Pandas。考虑到这一点,我们没有理由从 Pandas 切换。
另一方面,如果我开始一个涉及有趣数据分析的新项目,我会使用 Polars。Polars 给了我一直想从 Pandas 中获得的速度和简便性。
请 在 Medium 上关注我。我撰写关于 Rust 和 Python 中的科学编程、机器学习和统计学的文章。我通常每个月写一篇文章。
通过从零开始构建交叉熵来理解策略梯度
我们如何训练模型的统一视角
·
关注 发表在 Towards Data Science · 16 分钟阅读 · 2023 年 6 月 11 日
--
强化学习 (RL) 可以做出令人惊叹的事情。最近,ChatGPT 通过 PPO 进行微调,PPO 是一种叫做 策略梯度 (PG) 的强化学习算法的变种。理解 RL,特别是策略梯度,可能并不简单,特别是如果你像我一样喜欢把握直觉的话。在这篇文章中,我将探讨一系列思路,这些思路确实帮助我从更熟悉的监督学习环境出发,深入理解 PG。
摘要
-
我们将从设计一个简单的监督训练程序开始,通过奖励+1 来对二分类机器人进行正确答案的训练
-
我们将为该过程制定目标
-
我们将推导出该过程的梯度上升公式(这将与使用交叉熵的梯度下降过程相同)
-
我们将把我们的过程与 RL 设置进行比较,并将我们的梯度上升与策略梯度联系起来
谁应该阅读这个?
-
我的目标是提供一种友好且直观的方式来理解 PG。如果你对 RL 问题设置有一个大致了解,并且知道 PG 的高级概念,将会很有帮助。
-
我希望帮助你更好地理解 RL 与 PG 以及监督 ML 之间的关系。因此,如果你了解如何用交叉熵损失函数训练一个监督 ML 算法,将会非常有帮助。
为什么写这篇文章?
策略梯度
在 RL 问题中,代理与环境互动以学习策略。策略告诉代理在不同状态下该做什么以最大化奖励。

作者提供的图像
PG 的想法似乎很简单明了。
-
指导时间t上代理行为的策略是π_θ(a_t|s_t)。
-
这是一种函数(通常是神经网络),具有参数θ。
-
它接收状态信息 s_t 并输出一个采取行动的概率分布 a_t。
-
然后它接收奖励 r(s_t, a_t)。
-
当我们拥有许多这样的动作和奖励周期的历史时,我们可以更新参数θ以最大化由π_θ生成的动作所带来的预期奖励。
我们如何进行更新?通过…梯度!我们通过以下梯度更新生成π_θ的模型

有些东西感觉不对劲
这看起来非常熟悉。当我们在传统的监督学习中训练神经网络模型时,我们也通过执行第二行操作即梯度下降来更新模型参数(在 PG 情况下,技术上是梯度上升,因为我们在最大化目标)。
但这也感觉非常不同。如果你查看它的推导过程,你会发现推导这个方程需要一点努力。这与我们在监督学习中更直观的做法非常不同:将输入提供给神经网络,得到输出,与目标进行比较并计算损失函数,点击反向传播按钮,就完成了!
对我来说,对数项总是似乎突然出现。尽管上述链接中的同一在线课程讲解了如何得到对数项,但过程似乎只是一堆正确但缺乏动机的数学。
从监督学习中具体的区别是什么?深入探讨这个问题可以很好地理解策略梯度。此外,它也是对我们每天做的一些熟悉的监督学习本质的良好提醒。
从头开始构建交叉熵
如果我们用一些在监督学习中使用的损失函数来分析,它们会立即“显得合理”。但要理解它们的来源则需要更多的努力。例如,经典的均方误差直观上很合理:它只是最小化预测与目标之间的距离。但有这么多距离度量,为什么选择平方距离?你必须深入了解均方误差是做最大似然估计并假设基础总体分布为正态分布的副产品。
同样地,我们日常使用的另一个经典损失函数是交叉熵。虽然有很多关于交叉熵的良好解释,让我们尝试从最基本的方式构建它。
让我们训练一个分类机器人!
假设你想训练一个机器人来分类狗和猫的图像。直观上,通过奖励正确答案并惩罚(或不奖励)错误答案来训练它是合理的。具体方法如下:
- 你给机器人一张图片。我们称之为s。这张图片是从总体分布D_s中采样的。

狗图像来源:Unsplash;其他部分由作者提供
-
如果机器人认为这是狗的图像(动作a_dog)或这是猫的图像(动作a_cat),它将给你一个答案。
-
机器人根据图像有自己的预测,即图像是狗还是猫的概率:π_θ(a|s) = (a_dog, a_cat)。例如,π_θ(a|s) = (0.9, 0.1)意味着它认为有 0.9 的概率是狗,0.1 的概率是猫。

狗图像来源:Unsplash;其他部分由作者提供
- 但每次机器人只会给你一个明确的答案。它要么说“这是狗” (a_dog),要么说“这是猫” (a_cat)。每次它给你一个回应时,回应(动作)是从分布中随机采样得到的,由π_θ(a|s)产生:a = (a_dog, a_cat) ~ π_θ(a|s)。

狗图像来源:Unsplash;其他部分由作者提供
- 当机器人正确回答时,你将奖励它(可能给它一个小奖励?),奖励值为 1。(r(s,a) = 1)。当回答错误时,则没有奖励(0 奖励)。(r(s,a) = 0)

狗图像来源:Unsplash;其他部分由作者提供

猫图像来源:Unsplash;其他部分由作者提供
这是我在第一次学习监督学习时想到的过程。当它正确时给予奖励。当它错误时(或在我们设计的训练过程中没有奖励)给予惩罚。这可能是训练某物最直观的方式。
最大化目标
我们的目标是什么?我们希望它的响应尽可能正确。更准确地说,我们希望找到最优参数θ,使得生成的π_θ(a|s),在所有可能的s(从图像总体分布D_s中采样)和a(从由模型π_θ(a|s)生成的分布中采样)中,能够获得每对(s,a)出现的概率加权的最大平均奖励:

换句话说,我们在最大化定义为

目标的梯度
现在我们有了一个目标函数,我们可以尝试通过…梯度上升来最大化它!也就是说,我们可以通过迭代进行

但我们应如何计算梯度,即J对θ的导数?这在这种情况下有点棘手,因为
-
我们希望对其求导的函数是一个期望。
-
如果期望不是关于依赖于θ的分布,那么通过期望的线性性,我们可以直接对期望内部的内容进行求导,并将期望保留在那里。然而,在这种情况下,期望是关于(s,a) ~ (D_s, π_θ(a|s))的,这依赖于θ。因此,导数并不明显。
-
另一种思考方式是,J(θ)的值随着我们从由部分由θ决定的分布中采样(s,a)的频率变化而变化。我们希望更频繁地出现s=dog image和a=a_dog(猫的类似对)。当我们进行梯度上升时,我们如何捕捉向这个方向变化的θ?
此外,理想情况下,我们希望梯度呈现以下形式

这是因为你通过机器人与您的交互样本来训练机器人。每个样本包含一个(s,a,r)三元组。因此,我们可以通过对收集到的N个样本进行平均来近似这个梯度(根据大数法则,即进行随机梯度上升):

然后我们可以通过进行梯度上升来进行优化

现在让我们找到f。
寻找梯度
总结一下,我们希望从(1)开始,得到(2),对于某个f(θ,s,a,r)。

首先,让我们用期望的定义重写(1):

这基本上是对所有可能的(s,a)对的奖励按概率加权的积分。
那么,一个(s,a)对的联合概率P(s,a)究竟是多少?我们可以将其分解为图像样本(s)出现的概率和机器人随机选择动作a的概率。

由于机器人从其内部预测模型 π_θ(a|s) 中随机选择动作 a,我们有

在括号内的所有项中,只有 π_θ(a|s) 依赖于 θ。其他项都是常数。因此,我们可以将梯度操作移动到积分符号内,并得到

注意,我们也可以写出以下内容。这里没什么大不了的。只是将原始左边的内容乘以以分数形式写出的 1,并调整项。

替换回去,并稍微调整一下,我们得到

P(s)π_θ(a|s) 看起来很熟悉。这正是我们之前分解的 P(s,a)!将其放回去,我们得到

现在我们有一个积分和 P(s,a),我们可以…将其适配回期望的定义!

这正是我们在(2)中想要得到的形式,其中 f 是括号内的项!
你可能会想,为什么我们在之前的繁琐分数中重写了π_θ(a|s)的梯度?其目的是创建一个π_θ(a|s)项(我们在求导时丢失了它),以便我们可以再次生成一个 P(s,a) 项,并将积分重新转化为期望!
构建交叉熵
现在是魔法时刻。

不相信我?使用链式法则从右手边到左手边进行工作。([可选] 旁注:如果你对策略梯度公式中对数项的动机感到困惑,这实际上是简化我们得到的繁琐方程的副产品,旨在提取一个 π_θ(a|s) 项,将事物转回期望。)
所以我们可以稍微简化J(θ)的梯度:

所以每次我们有一批(s,a)作为样本时,可以通过

为了将其转化为更熟悉的形式,将梯度符号移到求和外部,我们有

我们还会通过进行以下操作来反转符号

这让你想起什么吗?让我们将其与在交叉熵损失上进行梯度下降时所做的事情进行比较。
记住,交叉熵损失是

其中y_i是真实标签,是一个描述图像是猫还是狗的独热向量(y_i_1, y_i_2,要么是(0,1)要么是(1,0))。y_hat_i是模型的预测,是一个向量(y_hat_i_1, y_hat_i_2),其中两个条目的和为 1。
当我们对这个损失函数进行梯度下降时,我们计算批次的交叉熵损失函数,并点击反向传播按钮:

这个表达式与我们之前推导出的梯度上升表达式之间的区别是

用语言描述,就是:在样本x_i上,y_i
-
模型做出预测(y_hat_i_1, y_hat_i_2)给定x_i
-
模型从预测分布中随机采样响应
-
我们奖励响应 1 的y_i_1,并且对响应 2 的y_i_2进行奖励。
-
由于当标签为类别 1 时,y_i_1 = 1, y_i_2 = 0,我们在模型正确响应 1 时奖励模型 1 分,而在模型错误响应 0 时没有奖励。类别 2 的情况也是如此。
这正是我们一直在做的事情!
所以总结一下,
-
我们设计了一个简单的训练设置,在这个设置中,我们奖励 机器人当其正确回答时得 1 分,当其回答错误时得 0 分。
-
我们总结了我们希望在目标函数中实现的内容,该目标函数描述了机器人根据其响应的机会加权所获得的奖励。
-
我们找到梯度下降过程以最大化这个目标函数
-
然后我们得到……我们在通过计算交叉熵损失然后进行反向传播训练模型时使用的确切过程!
回到强化学习
现在让我们把焦点重新放回到强化学习设置上。RL 与监督学习设置之间的区别是什么?
多个时间步长
第一个区别是 RL 通常涉及多个状态和多个回合。在我们的设置中,机器人从图像输入开始,即状态s。在机器人基于预测给出答案并收集奖励后,机器人与您的互动就结束了。
相反,在 RL(强化学习)问题中,智能体通常在多个回合中与环境互动,且在初始状态后可能过渡到其他状态。
目标函数变为

用语言描述,我们最大化所有时间步长的平均奖励总和,对所有可能的状态和动作序列(轨迹)加权,加权由每个轨迹发生的概率决定,当动作由参数θ决定时。
注意,p_θ是一个状态和动作序列的联合分布,当动作由代理的模型参数θ决定时。在每个时间步,代理的动作由π_θ(a_t|s_t)决定,其中π_θ是一个以θ为参数的模型。p_θ是一个高级抽象,表示当代理根据π_θ做出决策时,状态和动作序列发生的概率(即p_θ是理论上代理在轨迹上采取行动的频率的占位符。另一方面,π_θ(a|s)是代理在特定时间步采取某个动作的概率。我们实际上不容易知道p_θ的值,因此稍后我们将用实际知道的模型输出π_θ(a|s)来重写它)。
让我们与之前的目标进行比较:

主要区别如下:
-
我们计算一个* s 和a*序列上的期望,而不是仅仅一个对。
-
我们最大化轨迹中所有时间步的奖励总和,而不仅仅是来自图像和回答的单一时间步奖励。
比较梯度公式:
我们可以对这个目标做类似的操作,推导出我们可以在每个时间步更新θ的梯度。
回顾一下,我们的目标是以以下形式找到某些f的J(θ)的梯度。

当我们获得一批样本序列s_1, a_1, r_1, … s_T, a_T, r_T时,我们可以通过随机梯度上升更新θ:

为了简化,我们将状态序列记作一个变量τ。

所以我们希望最大化以下目标函数:

我们可以做类似的操作:
- 用积分表示期望。

- 对仅涉及θ的项p_θ(τ) 求导。

- 将p_θ(τ)的梯度重写为** p_θ(τ)和其他东西的乘积**,以恢复定义期望的形式。

所以我们得到:

看!这正是我们想要找到的。换句话说,这意味着我们正在将θ 更新为样本τ的对数概率梯度的方向,权重是沿样本τ的总奖励。这正是策略梯度的公式。
如果我们从早期的交叉熵类比延伸过来,奖励的总和基本上是轨迹的标签,而 p_θ(τ) 是模型预测下 τ 发生的可能性。训练过程 鼓励模型预测与不同轨迹 τ 上的奖励分布相似的分布。(这实际上是一个数学上准确的陈述 [如果我错了请纠正我]。如果你知道 KL 散度,可以将所计算的梯度与 KL 散度进行比较)。
我们可以对条件概率和 p_θ(τ) 的定义进行更多的操作。这个过程在这个视频(大约在 9:27)中讲解得很好。我们最终得到以下内容,将 p_θ(τ) 重新表示为 π_θ(a_t|s_t),这是我们实际知道其值的:

注意 当 T = 1(单次实验),这与我们之前设置中获得的梯度是一样的。换句话说,监督学习是强化学习的一个特殊情况,其中只有一个实验,奖励是非随机的(见下一节)。
另一个区别:奖励的估计
强化学习与监督学习之间的另一个区别是我们可以多大程度上相信奖励。在监督学习中,奖励是与图像样本一起提供的真实标签。我们通常 100% 确定奖励是正确的,我们的机器人会根据这些标签调整其行为。
然而,在强化学习问题中,奖励可能 更具随机性(想象一下你玩游戏时,可能在同一个地方两次但得到不同的分数)。因此,我们必须 估计特定状态-动作对的奖励,通过与环境互动并利用历史奖励来进行估计。
[可选] 附带想法:我还在思考是否存在监督学习(标签/奖励是 100% 可相信的)和强化学习(奖励更具随机性)之间的中间领域。当标签有噪声(包含一些错误标签)时,我们是否有点像处于中间?所以, 伪标签方法 是否与强化学习问题有一些相似之处?请告诉我你的想法。
从长远来看,我们应该有足够的历史奖励来理解平均奖励行为,但在短期内,小样本数量可能会产生 不稳定 的偏差估计。
更糟糕的是,由于代理行为是通过收集的奖励来更新的,如果我们收集到低质量的奖励,我们可能会陷入并停留在一个糟糕的策略中。要从那里走出来并重新回到正确的轨道上需要很长时间。
这是强化学习中的一个挑战,仍然是一个正在进行的研究领域。 对奖励进行一些操作 和变体,如 TRPO 和 PPO,旨在更好地解决这个问题,并且比普通 PG 使用得更为广泛。
[可选] 另一种思考:与序列监督机器学习的比较
我们的监督机器学习设置与 RL 之间的一个区别是 RL 通常涉及多个时间步。我立刻有一个问题:那么 RL 与训练像 Transformer 或 LSTM 这样的序列模型有什么不同?
这个问题的答案绝对取决于你最喜欢的序列模型的训练损失设计。
现在,假设你训练一个序列模型 f(x_1,x_2,…x_T) 以预测 y_1, y_2…y_T。例如,在机器翻译任务中,x 可能是输入英文句子的单词,而 y 是输出法文句子的单词(每个 x_t, y_t 是单词的一个独热向量表示)。
我们通过对每个样本的每个单词输出预测与真实标签之间的交叉熵之和来计算损失函数。然后,我们对一批样本进行平均,并像下面这样进行反向传播。

放回到策略梯度公式中,对我来说,这与计算目标函数的梯度相同

这种公式与 PG 公式的区别在于,我们没有将所有时间步的预测的对数概率之和与所有步骤的奖励之和相乘。相反,我们取每个时间步的对数概率与奖励的成对乘积并将它们相加。
这去除了很多项,因此大大减少了梯度的方差,这可能是使得在监督设置中训练 Transformer/LSTM 比 RL 算法更容易的原因?(除了监督设置中的非随机奖励)。
这个视频 中介绍了一种减少 PG 方差的技术:将 PG 中所有时间步的奖励总和更改为未来奖励(即从 t’ = t 到 t’ = T 的总和)。这与 PG 与在监督设置中训练 Transformer/LSTM 之间的不同具有相似的风味。虽然未来奖励方法使得代理能够通过可能的未来奖励评估每个状态,但我们是否可以说监督序列训练使得模型仅关注当前时间步的正确性?
此外,我尝试从这个梯度表达式中倒推,找到导致这个梯度表达式的原始 J(θ),以便我们可以更直接地解释监督序列训练的目标。但我在半途中卡住了。如果你有任何想法,请告诉我。
致谢
策略梯度与交叉熵之间的联系并非我自己原创的想法。感谢这篇文章给了我拓展思路的启发,让我从更根本的角度理解交叉熵和策略梯度的作用。
理解 SQL 注入并学习如何在 Python 中使用 SQLAlchemy 避免它
学习在 Python 中以安全的方式与数据库交互
·发布在 Towards Data Science ·5 分钟阅读·2023 年 4 月 12 日
--

图片来自 Pixabay 的 mohamed_hassan(Hosting Web Man)
SQL 注入是最常见且最危险的网络安全漏洞之一,它允许黑客将恶意 SQL 代码注入到未经验证和清理的纯 SQL 查询中。这也是新开发人员常常忽视的一个问题。
SQL 注入的原因和解决方案其实非常简单。在这篇文章中,我们将通过一些简单的查询来探索 SQL 注入,并假装成为攻击者来利用我们的数据库。在文章的最后,你将完全理解 SQL 注入,并在意识到其威力和危险后不会再犯这个错误。
准备
和往常一样,我们将使用 Docker 创建一个 MySQL 数据库:
# Create a volume to persist the data.
$ docker volume create mysql8-data
# Create the container for MySQL.
$ docker run --name mysql8 -d -e MYSQL_ROOT_PASSWORD=root -p 13306:3306 -v mysql8-data:/var/lib/mysql mysql:8
# Connect to the local MySQL server in Docker.
$ docker exec -it mysql8 mysql -u root -proot
mysql> SELECT VERSION();
+-----------+
| VERSION() |
+-----------+
| 8.0.31 |
+-----------+
1 row in set (0.00 sec)
请注意,本文中为了简化起见使用了 root 用户,但在实际应用中绝不应直接在我们的 Web 应用程序中使用。
然后,让我们创建一些数据库和表来进行测试。为了简单起见,数据集与之前系列文章中使用的相同。
CREATE DATABASE `data`;
CREATE TABLE `data`.`student_scores` (
`student_id` smallint NOT NULL,
`subject` varchar(50) NOT NULL,
`score` tinyint DEFAULT '0',
PRIMARY KEY (`student_id`,`subject`),
KEY `ix_subject` (`subject`),
KEY `ix_score` (`score`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci
;
INSERT INTO `data`.student_scores
(student_id, subject, score)
VALUES
(1, 'Literature', 90),
(1, 'Math', 60),
(2, 'Literature', 80),
(2, 'Math', 80),
(3, 'Literature', 70),
(3, 'Math', 95)
;
由于我们在这篇文章中将使用 SQLAlchemy 进行数据库连接,我们需要安装必要的库。和往常一样,建议创建一个单独的虚拟环境,以便库不会影响系统及其他虚拟环境。
conda create --name sql python=3.11
conda activate sql
pip install -U "SQLAlchemy>=2.0.0,<2.1.0"
pip install -U "pymysql>=1.0.0,<1.1.0"
pip install -U "cryptography>=40.0.0,<40.1.0"
探索 SQL 注入
现在让我们创建一个简单的函数来读取一些数据:
from sqlalchemy import create_engine, text
db_url = "mysql+pymysql://root:root@localhost:13306/data"
engine = create_engine(db_url, pool_size=5, pool_recycle=3600)
conn = engine.connect()
def read_student_scores(student_id):
sql_text = text(f"""
SELECT subject, score
FROM data.student_scores
WHERE student_id = {student_id}
""")
result = list(conn.execute(sql_text))
print(result)
read_student_scores() 函数从简单的编码角度来看似乎很正常。然而,它存在一个巨大的安全问题,可能被恶意用户利用。
如果我们正常使用,它将正常工作:
read_student_scores(1)
# [('Literature', 90), ('Math', 60)]
然而,它也可能返回一些不应该由恶意用户返回的内容。黑客的第一个攻击点是返回所有记录,即使是用户不应该看到的记录:
read_student_scores('-1 OR 1')
# [('Literature', 90), ('Math', 60), ('Literature', 80), ('Math', 80), ('Literature', 70), ('Math', 95)]
这是可能的,因为 read_student_scores() 函数没有清理和验证输入参数,而是将输入数据与原始查询简单地拼接在一起。
这对于许多开发者来说并不罕见。实际上,我见过不少以这种方式编写的遗留代码。很幸运的是,它们之前没有被黑客攻击。或者说,也许已经被黑过了……
SQL 注入可能比上面所示的更具危害性,实际上黑客可以返回任何信息。
现在,让我们假装自己是恶意用户,尝试获取一些不应该由此函数返回的信息。
黑客首先想知道的是返回了多少列。在这个示例中,很明显返回了两列。然而,当输出通过某些用户界面显示时,这可能不是那么明显。
有很多方法可以猜测返回了多少列,两种常见的方法是使用 ORDER BY 和 UNION。让我们看看它是如何工作的:
read_student_scores('-1 ORDER BY 1')
# []
read_student_scores('-1 ORDER BY 2')
# []
read_student_scores('-1 ORDER BY 3')
# OperationalError: (pymysql.err.OperationalError) (1054, "Unknown column '3' in 'order clause'")
从上述查询、结果和错误中,我们知道返回了两列。
我们可以使用UNION得出相同的结论:
read_student_scores('-1 UNION SELECT 1')
# OperationalError: (pymysql.err.OperationalError) (1222, 'The used SELECT statements have a different number of columns')
read_student_scores('-1 UNION SELECT 1,2')
# [('1', 2)]
使用 UNION 我们能够通过较少的测试猜测正确的列数。实际上,UNION 是最常用的黑客工具之一,用于攻击数据库。
让我们尝试读取一些正常情况下不应该返回的内容:
read_student_scores('-1 UNION SELECT DATABASE(), @@VERSION')
# [('data', '8.0.31')]
数据库名称和版本被返回了!
让我们看看更可怕的情况:
read_student_scores('-1 UNION SELECT user, authentication_string FROM mysql.user')
# [('root', '$A$005$j\x1cZ\x1aj*t\x16_aI\t.\tk\x1a0b8,6nT16rTboTxEGJsq8R.xLN1dlygQWOe12XurOijG5v9'), ('mysql.infoschema', '$A$005$THISISACOMBINATIONOFINVALIDSALTANDPASSWORDTHATMUSTNEVERBRBEUSED'), ('mysql.session', '$A$005$THISISACOMBINATIONOFINVALIDSALTANDPASSWORDTHATMUSTNEVERBRBEUSED'), ('mysql.sys', '$A$005$THISISACOMBINATIONOFINVALIDSALTANDPASSWORDTHATMUSTNEVERBRBEUSED'), ('root', '$A$005$\x0c=\x10gE\x7f]g\x18WQNnB`Y&I1\x18zPIQ3wM3cj43wk4Qq4/Tt88B0ypKrwYLYnD3BpGqfY5')]
所有数据库用户的用户名和认证字符串都被返回了!使用一些暴力猜测工具,黑客可以在短时间内破解密码,特别是当使用简单密码时。
如何避免 SQL 注入?
现在我们已经了解了 SQL 注入是什么以及它有多么危险,让我们看看如何在实践中避免它。
防止 SQL 注入的最有效方法是使用参数化查询,这可以通过 SQLAlchemy 中的 :param_name 语法实现:
def read_student_scores(student_id):
sql_text = text("""
SELECT subject, score
FROM data.student_scores
WHERE student_id = :student_id
""")
result = list(conn.execute(sql_text, parameters={"student_id": student_id}))
print(result)
请注意,本帖使用了 SQLAlchemy 2.0,因此指定参数的语法与 SQLAlchemy 1.x(通常是 1.4)中的语法会有所不同。
让我们看看使用参数化查询时恶意查询会返回什么:
read_student_scores('-1 OR 1')
# []
read_student_scores('-1 UNION SELECT DATABASE(), @@VERSION')
# []
read_student_scores('-1 UNION SELECT user, authentication_string FROM mysql.user')
# []
所有这些恶意查询返回了空结果,比之前安全得多。
在这篇文章中,我们介绍了什么是 SQL 注入,它是如何工作的,以及如何使用简单的示例来避免它。
尽管使用参数化查询可以防止大多数 SQL 注入实例,但为了使我们的应用程序更加健壮,我们还应该应用以下策略:
-
限制数据库用户查询数据库的权限。为了简单起见,本示例使用了 root 用户,但在实际应用中绝不应该使用 root 用户。实际上,我们应该为我们的 Web 应用创建一个专用的数据库用户(并使用强密码),并仅授予最低权限。
-
清理并验证输入查询。如果输入数据的类型不一致,或者包含可疑字符如井号(
#)、分号(;)、减号(-),甚至是词语UNION,应以安全、稳健且用户友好的方式处理这种情况。 -
永远不要将调试日志直接展示给最终用户。调试日志仅应供内部用户使用,因为它们可能包含敏感信息,恶意用户可能利用这些信息来利用系统。
相关文章:
实时了解您的数据
实操教程
与 bytewax 和 ydata-profiling
·
关注 发表在 Towards Data Science ·8 分钟阅读·2023 年 7 月 20 日
--
在这篇博客文章中,我们将深入探讨如何将开源流处理解决方案 bytewax* 与 ydata-profiling* 结合使用,以提升您的流处理质量。准备好了吗!
流处理允许对数据进行实时分析,无论是在传输过程中还是存储之前,并且可以是有状态的或无状态的。
有状态流处理 用于实时推荐、模式检测或复杂事件处理,其中需要处理历史数据(窗口、按键连接等)。
无状态流处理 用于内联转换,无需了解流中其他数据点,例如掩码电子邮件或类型转换。

照片由 Markus Spiske 提供,拍摄于 Unsplash
总体而言,数据流在工业中广泛使用,应用于诸如 欺诈检测、病人监控 或 事件预测维护 等用例。
所有数据流必须考虑的一个关键方面是数据质量。
与传统模型中数据质量通常在数据仓库或仪表板解决方案创建过程中进行评估不同,流数据需要持续监控。
在整个过程中,从数据收集到传递给下游应用程序,保持数据质量至关重要。毕竟,差的数据质量可能会给组织带来高昂的成本:
“对于大多数公司来说,差的数据质量的成本高达 15% 到 25% 的收入。 (…) 通过在数据质量上提前准备,可以消除其中的三分之二的成本。”
— 托马斯·C·雷德曼,《数据质量的前瞻》一书的作者
在本文中,我们将向您展示如何将 bytewax 与 ydata-profiling 结合起来,以分析和提高您的流数据质量!
使用 Bytewax 进行数据专业人士的流处理
Bytewax 是一个开源流处理框架,专为 Python 开发人员设计。
它允许用户构建流数据管道和实时应用程序,具有类似于 Flink、Spark 和 Kafka Streams 的功能,同时提供一个友好且熟悉的界面,并且与 Python 生态系统 100% 兼容。
使用内置的 连接器 或现有的 Python 库,您可以连接到实时和流数据源(Kafka、RedPanda、WebSocket 等),并将转换后的数据写入各种下游系统(Kafka、parquet 文件、数据湖等)。
对于转换,Bytewax 支持有状态和无状态的转换,通过 map、windowing 和 aggregation 方法,并具备如恢复和可扩展性等熟悉的功能。
Bytewax 提供了以 Python 为主的数据流体验,并专门为数据工程师和数据科学家而构建。它允许用户构建流数据管道和实时应用程序,并创建满足需求的自定义配置,而无需学习和维护像 Spark 或 Flink 这样的基于 JVM 的流处理平台。
Bytewax 非常适合多种使用场景,包括生成 AI 的嵌入管道、数据流中的缺失值处理、在流式上下文中使用语言模型理解金融市场等。有关用例灵感和更多信息,如文档、教程和指南,请随时查看Bytewax 网站。
为什么需要对数据流进行数据分析?
数据分析是任何机器学习任务成功的关键,指的是彻底理解我们的数据:其结构、行为和质量。
简而言之,数据分析包括分析与数据格式和基本描述符相关的方面(例如,样本数量、特征数量/类型、重复值)、其内在特征(如缺失数据或不平衡特征),以及在数据收集或处理过程中可能出现的其他复杂因素(例如,错误值或不一致特征)。
确保高数据质量标准对于所有领域和组织都至关重要,但对于那些处理持续输出数据的领域尤其相关,因为情况可能会快速变化,需要立即采取行动(例如,医疗监测、股票价值、空气质量政策)。
对于许多领域,数据分析是从探索性数据分析的角度使用的,考虑存储在数据库中的历史数据。相反,对于数据流,数据分析在流中的验证和质量控制中变得至关重要,数据需要在不同的时间帧或处理阶段进行检查。
通过将自动化分析嵌入我们的数据流中,我们可以立即获得反馈,了解当前数据状态,并在出现潜在关键问题时收到警报——无论这些问题与数据一致性和完整性(例如,数据损坏或格式变化)相关,还是与短时间内发生的事件(例如,数据漂移、偏离业务规则和结果)有关。
在现实世界中 — 你只知道墨菲定律肯定会生效,“一切都可能出错” — 自动化分析可能会帮助我们避免多个脑力难题和需要停产的系统!
关于数据分析,ydata-profiling一直是一个热门选择,无论是表格数据还是时间序列数据。这也不足为奇——一行代码就可以进行全面的分析和洞察。
复杂且耗时的操作在后台完成:ydata-profiling 自动检测数据中的特征类型,并根据特征类型(数字或分类)调整概要统计和可视化,这些内容会在分析报告中显示。
促进以数据为中心的分析,该包还突显了特征之间的现有关系,关注它们的配对交互和相关性,并提供数据质量警报的全面评估,从重复或常量值到偏斜和不平衡特征。
这确实是对我们数据的360º视角——付出最少的努力。

分析报告:突显潜在的数据质量问题。图片来源:作者。
汇总:bytewax 和 ydata-profiling
在开始项目之前,我们需要首先设置我们的 Python 依赖项并配置数据源。
首先,让我们安装bytewax和ydata-profiling包(你可能需要使用虚拟环境来进行这个操作—— 查看这些说明 如果你需要额外的指导!)
然后,我们将上传环境传感器遥测数据集(许可—CC0:公共领域),该数据集包含来自不同 IoT 设备的温度、湿度、一氧化碳、液化石油气、烟雾、光线和运动的多项测量:
在生产环境中,这些测量将由每个设备持续生成,输入将类似于我们在流媒体平台例如 Kafka中预期的内容。在这篇文章中,为了模拟流数据的上下文,我们将一次从 CSV 文件中读取一行数据,并使用 bytewax 创建数据流。
(快速旁注:数据流本质上是一个可以描述为有向无环图—DAG 的数据管道)
首先,让我们进行一些必要的导入:
然后,我们定义我们的数据流对象。之后,我们将使用无状态的映射方法,在其中传入一个函数以将字符串转换为日期时间对象,并将数据重组为格式(device_id, data)。
map 方法将以无状态的方式对每个数据点进行更改。我们修改数据的形状是为了在接下来的步骤中更容易地对数据进行分组,以便分别对每个设备进行数据分析,而不是同时对所有设备进行分析。
现在我们将利用 bytewax 的有状态能力来收集在我们定义的时间段内每个设备的数据。ydata-profiling 期望获得数据的时间快照,这使得窗口操作符成为实现这一目标的完美方法。
在 ydata-profiling 中,我们能够为特定上下文指定的数据框生成汇总统计。例如,在我们的示例中,我们可以生成涉及每个 IoT 设备或特定时间段的数据快照:
在定义了快照之后,利用 ydata-profiling 就像调用每个我们想要分析的数据框的 ProfileReport 一样简单:
在这个示例中,我们将图像写入本地文件作为 map 方法中的一个函数的一部分。这些图像可以通过消息工具报告,或者将来我们可以将它们保存到一些远程存储中。一旦配置文件完成,数据流会期望一些输出,因此我们可以使用内置的 StdOutput 打印已分析的设备以及在 map 步骤中传递出的配置文件时间:
执行 Bytewax 数据流的方法有多种。在这个示例中,我们使用相同的本地机器,但 Bytewax 也可以在多个 Python 进程中运行,跨多个主机,使用 Docker 容器,利用 Kubernetes 集群,以及 更多。
在本文中,我们将继续使用本地设置,但我们鼓励你查看我们的辅助工具 waxctl,它可以在你的管道准备好过渡到生产环境时管理 Kubernetes 数据流部署。
假设我们在包含数据流定义文件的相同目录中,我们可以使用以下命令运行它:
然后我们可以使用这些分析报告来验证数据质量,检查模式或数据格式的变化,并 比较不同设备或时间窗口之间的数据特征。
实际上,我们可以利用 比较报告功能,它以简单明了的方式突出显示两个数据配置文件之间的差异,从而帮助我们更容易地发现需要调查的重要模式或必须解决的问题:
准备好探索你自己的数据流了吗?
验证数据流对于持续识别数据质量问题以及比较不同时间段数据状态至关重要。
对于在医疗保健、能源、制造和娱乐等领域处理持续数据流的组织来说,自动化分析是建立数据治理最佳实践的关键,从质量评估到数据隐私。
这需要对数据快照进行分析,如本文所示,可以通过结合bytewax和ydata-profiling以无缝的方式实现。
Bytewax负责处理和结构化数据流所需的所有过程,这些数据流可以汇总并通过ydata-profiling进行比较,生成数据特征的综合报告。
适当地处理和分析传入数据能够在不同领域开启许多应用场景,从数据模式和格式错误的修正到突出和缓解由现实世界活动引发的额外问题,如异常检测(例如,欺诈或入侵/威胁检测)、设备故障以及其他偏离预期的事件(例如,数据漂移或与业务规则的不一致)。
现在你可以开始探索你的数据流了!让我们知道你发现了哪些其他应用场景,随时在评论中给我们留言,或在数据驱动的 AI 社区中与我们联系,提出问题和建议!在那里见!
致谢
本文得到了 Fabiana Clemente(CDO @ YData)的支持,开发了 ydata-profiling,以及 Zander Matheson(CEO & Founder @ Bytewax)和 Oli Makhasoeva(Developer Relations @ Bytewax),两者都开发了 bytewax。你可以在相应的文档中找到有关这些开源软件包的更多信息: ydata-profiling 文档 与 bytewax 文档。
理解和减轻 LLM 幻觉
LLM 幻觉检测挑战及其在一篇重要研究论文中提出的可能解决方案。
·
关注 发表于 Towards Data Science ·8 min read·Oct 23, 2023
--
近年来,大型语言模型(LLMs)展示了令人印象深刻且不断增强的能力,包括对用户提示生成高度流畅和令人信服的响应。然而,LLMs 以生成非事实性或荒谬陈述而闻名,这种特性通常称为“幻觉”。这种特征可能会在许多需要事实性的场景中损害信任,如总结任务、生成式问答和对话生成。
检测幻觉在人类中一直是一个挑战,在 LLM 的背景下同样如此。这尤其具有挑战性,因为我们通常无法获取用于一致性检查的真实背景信息。有关 LLM 生成的附加信息,如输出概率分布,可以帮助完成这一任务。然而,这类信息往往不可用,使得任务更加困难。
幻觉检测尚未解决,仍是一个活跃的研究领域。在这篇博客文章中,我们将一般介绍任务及其挑战,并介绍在研究论文 SELFCHECKGPT: Zero-Resource Black-Box Hallucination Detection for Generative Large Language Models[1] 中提出的一种可能的方法。我们将用实际例子说明论文中提出的一些方法,并指出每种方法的一些优缺点。你可以通过访问 Google Colab Notebook 来查看这些示例。
💡 更新:受本博客文章研究的启发,我们在 LangKit中发布了一个新功能。 response_hallucination 模块将自动计算一致性分数,帮助你了解 LLM 中幻觉响应的存在。你可以在 这个示例笔记本 中查看它。
本博客将涵盖:
-
什么是 LLM 幻觉
-
方法:SelfCheckGPT
-
一致性检查
1. BERTScore
2. 自然语言推理
3. LLM 提示
-
实验
-
结论
-
参考文献
什么是 LLM 幻觉
在自然语言生成中,幻觉可以定义为“生成的内容是无意义的或不忠于提供的源内容”[2]。源内容可以代表不同任务的不同功能。例如,对于摘要任务,源内容是被摘要的输入文本,而对于生成式问答任务,源内容被认为是世界知识[2]。我们来关注问答任务,并以用户问题和 OpenAI 的 GPT 模型(text-davinci-003)给出的答案为例:
用户提示:菲利普·海沃斯是谁?
LLM 响应:菲利普·海沃斯是英国的一名律师和政治家,曾于 1859 年至 1868 年担任特特福德的国会议员。
初看起来,这似乎是一个连贯的回答,但过一段时间,我们可能会问一个非常合理的问题:这是真的吗?
不过,我们并没有足够的信息来回答这个新问题。我们没有可以用来验证陈述的经过验证的事实的背景。LLM 通常在大量文本中进行训练,因此很难验证陈述是否基于特定的信息来源。快速的 Google 搜索没有显示出匹配回答的人的任何信息。什么确切构成幻觉仍然没有很好地建立,但这个例子可以被理解为幻觉,因为这个名字、地点和职业的主题可能根本不存在。
让我们再试一次另一个问题:
用户提示:比尔·盖茨是谁?
LLM 回应:比尔·盖茨是技术企业家、投资者和慈善家。他是微软公司的共同创始人,也是世界上最富有的人之一。
好的,这些陈述更容易验证,并且通过快速检查维基百科得到了支持,所以这不太可能是幻觉。在训练过程中,LLM 很可能见过很多关于比尔·盖茨的陈述,因此“比尔·盖茨是 _”之后的令牌很可能会以较高的信心生成。另一方面,LLM 可能对“Philip Hayworth 是 _”之后使用哪些词不太确定。这一见解使我们能够将不确定性与真实性联系起来,因为事实句子通常会包含预测概率较高的令牌,而幻觉句子则不然。然而,对于许多案例,我们可能没有手头的输出概率分布。
本次会议的示例和内容基于原始论文[1],我们将在接下来的章节中继续探索论文的方法。
方法:SelfCheckGPT
在上一节中,我们考虑了我们方法的两个重要因素:访问外部背景和访问 LLM 的输出概率分布。当一种方法不需要外部背景或数据库来进行一致性检查时,我们可以称其为零资源方法。类似地,当一种方法只需要 LLM 生成的文本时,可以称之为黑箱方法。
我们在这篇博客文章中要讨论的方法是一种零资源黑箱幻觉检测方法,基于这样一个前提:对相同提示的采样回答对于幻觉事实可能会出现分歧和矛盾,而对于事实陈述则可能会相似和一致。
让我们重新审视之前的例子。为了应用检测方法,我们需要更多的样本,所以让我们再向 LLM 提出三个相同的问题:

作者提供的表格
确实,答案相互矛盾——有时,Philip Hayworth 是一位英国政治家,而在其他样本中,他是澳大利亚工程师或美国律师,他们生活和行动于不同的时期。
让我们以比尔·盖茨的例子进行比较:

表格作者提供
我们可以观察到,比尔·盖茨分配的职业、组织和特征在样本之间是一致的,使用了相等或语义相似的术语。
一致性检查
现在我们有了多个样本,最后一步是进行一致性检查——确定答案是否彼此一致。这可以通过多种方式完成,所以让我们探索一下论文中提出的一些方法。你可以通过查看这个 Google Colab Notebook 自行执行代码。
BERTScore
执行此检查的一种直观方法是测量样本之间的语义相似度,而 BERTScore[3] 是一种实现方式。BERTScore 为候选句子中的每个词与参考句子中的每个词计算相似度分数,以计算句子之间的相似度分数。
在 SelfCheckGPT 的背景下,分数是逐句计算的。原始答案的每个句子将与给定样本的每个句子进行评分,以找到最相似的句子。这些最大相似度分数将在所有样本中进行平均,从而为原始答案中的每个句子得到最终的幻觉分数。最终分数需要趋近于 1(表示不相似的句子)和 0(表示相似的句子),因此我们需要从 1 中减去相似度分数。
让我们展示如何用原始答案的第一个句子与第一个样本进行检查:

图片作者提供
第一个样本的最高分是 0.69。重复对剩余两个样本的处理,并假设其他最高分为 0.72 和 0.72,那么我们对该句子的最终分数将是 1 — (0.69+0.72+0.72)/3 = 0.29。
使用语义相似度来验证一致性是一种直观的方法。其他编码器也可以用于嵌入表示,因此这也是一种可以进一步探索的方法。
自然语言推理
自然语言推理是确定蕴涵的任务,即根据前提[4]判断一个假设是否为真、假或未确定。在我们的案例中,每个样本用作前提,每个原始答案的句子用作我们的假设。通过对每个句子的样本分数进行平均,得到最终分数。蕴涵通过对 Multi-NLI 数据集[5] 进行微调的 Deberta 模型来执行。我们将使用归一化预测概率来计算分数,而不是实际类别,如“蕴涵”或“矛盾”。[6]
蕴涵任务更接近我们的一致性检查目标,因此我们可以期待为此目的微调的模型会表现良好。作者还在 HuggingFace 上公开分享了该模型,其他 NLI 模型也公开可用,使得这种方法非常容易获取。
LLM Prompt
考虑到我们已经使用 LLM 来生成答案和样本,我们不妨使用 LLM 来执行一致性检查。我们可以对每个原始句子和每个样本进行一致性检查,将 LLM 作为我们的上下文。下面的图片,来自原始论文的仓库,说明了如何进行这个操作:

SELFCHECKGPT WITH LLM PROMPT. 来源: HTTPS://GITHUB.COM/POTSAWEE/SELFCHECKGPT/TREE/MAIN
最终得分可以通过将“否”赋值为 1,“是”赋值为 0,“不适用”赋值为 0.5,并对样本的值进行平均来计算。
与其他两种方法不同,这种方法需要额外调用你选择的 LLM,这意味着额外的延迟和可能的额外成本。另一方面,我们可以利用 LLM 的能力来帮助我们进行检查。
实验
让我们看看在三种方法中讨论的两个示例的结果如何。

作者提供的表格
这些值仅用于说明方法。只有三个句子的情况下,它不应该用来比较和确定哪种方法最佳。为此,原始论文在论文的仓库中分享了实验结果 这里,包括了在这篇博客中未讨论的附加版本。我不会详细讨论结果,但根据所有三个指标(NonFact、Factual 和 Ranking),LLM-Prompt 是表现最好的版本,其次是 NLI 版本。BERTScore 版本明显比剩余两个版本要差。我们的简单示例似乎符合共享结果的方向。
结论
我们希望这篇博客文章有助于解释幻觉问题,并提供一种可能的幻觉检测解决方案。这是一个相对较新的问题,很高兴看到已经有努力在解决它。
讨论的方法具有不需要外部上下文(零资源)和不需要 LLM 的输出概率分布(黑箱)的优点。然而,这也带来了成本:除了原始响应外,我们还需要生成额外的样本来执行一致性检查,从而增加了延迟和成本。一致性检查还需要额外的计算和语言模型来将响应编码为嵌入,进行文本蕴含,或查询 LLM,这取决于所选的方法。
参考文献
[1] — Manakul, Potsawee, Adian Liusie, 和 Mark JF Gales。“Selfcheckgpt:用于生成大型语言模型的零资源黑箱幻觉检测。” arXiv 预印本 arXiv:2303.08896 (2023)。
[2] — JI, Ziwei 等人。《自然语言生成中的幻觉调查》。ACM 计算调查,第 55 卷,第 12 期,页码 1–38,2023 年。
[3] — ZHANG, Tianyi 等人。Bertscore:使用 bert 评估文本生成。arXiv 预印本 arXiv:1904.09675,2019 年。
[4] — nlpprogress.com/english/natural_language_inference.html
[5] — Williams, A., Nangia, N., & Bowman, S. R. (2017)。用于通过推理理解句子的广泛覆盖挑战语料库。arXiv 预印本 arXiv:1704.05426。
[6] — github.com/potsawee/selfcheckgpt/tree/main#selfcheckgpt-usage-nli
深入理解 AUC 分数:意义何在?
原文:
towardsdatascience.com/understanding-auc-scores-in-depth-whats-the-point-5f2505eb499f
探索替代度量标准以获得更深入的见解
·发表于 Towards Data Science ·阅读时间 11 分钟·2023 年 9 月 2 日
--

图片来源:Jonathan Greenaway 在 Unsplash
你好!
今天,我们将深入探讨用于评估模型性能的特定度量标准——AUC 分数。但在深入细节之前,你是否曾经想过为什么有时需要一些不直观的分数来评估我们模型的性能?
无论我们的模型处理的是单一类别还是多个类别,根本目标始终不变:优化准确预测的同时最小化错误预测。为了探讨这个基本目标,让我们首先来看一下包含真实正例、假正例、真实负例和假负例的混淆矩阵。

作者提供的图片
对于任何分类或预测问题,只有两种结果:真或假。
因此,评估预测或分类算法性能的每个度量标准都基于这两项指标。实现这一点的最简单度量标准是准确率。
准确率
在分类和预测的背景下,准确率表示正确预测的实例在总实例中的比例。这是对模型预测性能非常直接和直观的度量标准。

作者提供的图片
然而,准确率真的足够吗?
虽然准确率是评估模型表现的一个良好一般指标,但当我们查看下面的表格时,它的不足之处变得显而易见。该表展示了四个模型的性能指标,每个模型的结果都有些次优,但这些模型都表现出高准确率。例如,在第一和第二个案例中,明显存在对某一类别的偏倚,导致对不常见类别的分类结果不佳,但准确率为 90%,这非常具有误导性。

作者提供的图像
这帮助我们得出结论:
尽管准确率很有价值,但它有时会误导,尤其是在类别不平衡的场景中或某些错误具有重大后果时。
例如,在漏掉正例(类型 2 错误)或错误识别负例(类型 1 错误)的代价高昂的情况下,单靠准确率可能无法全面评估模型的有效性。
准确率的优点在于其简单性以及在各个类别中的适用性。
现在,考虑到准确率,让我们深入探索预测和分类的领域,几个问题出现了:
-
我们的目标是什么?
-
我们的数据是否平衡?
-
我们是否优先考虑一个类别而非另一个类别?
-
我们倾向于避免假正例(类型 2 错误),还是强调最小化假负例(类型 1 错误)?
在提出这些问题之后,仅根据准确率评估模型性能似乎有些微不足道,因此我们将注意力转向评估模型性能的其他三个指标,即精确度、召回率和 F1 分数。
精确度
精确度衡量模型在识别特定类别(通常是积极类别)时的准确性。它衡量了对该类别预测的可靠性。
考虑一个场景,其中机器学习算法基于借款人特征预测贷款批准。虽然偶尔对合格候选人的贷款拒绝(假负例)可能对公司是可以接受的,但主要问题是避免对那些不应获得贷款的人(假正例)做出不必要的批准。

作者提供的图像
从本质上讲,精确度旨在最小化类型 2 错误——那些应被拒绝但错误被接受的实例。
让我们通过返回到我们的表格来演示这一点:

作者提供的图像
我们从案例 1 和 3 中观察到,当特定(积极)类别的真正正例与假正例的比例较大时,可以实现更高的精确度,而不管实际的模型表现如何。因此,
对于特定类别,高精确度意味着低类型 2 错误。
接下来,我们有一个精确度的对应物来处理类型 1 错误:
召回率、敏感度或 TPR
召回率与精确率一样,关注的是我们对特定类别的预测能力。它量化了我们从整个池中准确选择属于特定类别的实例的能力。
考虑一个场景,其中我们模型的目标是防止信用欺诈。我们可能可以处理将非欺诈性活动标记为欺诈性(假阳性)的情况,但我们不希望遗漏可能实际上是欺诈的活动(假阴性)。
在这种情况下,追求高召回率涉及到最小化 Type 1 错误——确保尽可能捕捉到所有相关实例,即使这意味着标记一些无辜的实例。

图片由作者提供
让我们第三次回到我们的表格:

图片由作者提供
从案例 1 和 4 中我们可以看到,当最大化正分类时,召回率表现优异。在这些情况下,即使存在假阳性或负类别的表现不佳,召回率仍然很高。因此,
对于特定类别,高召回率确保了 Type 1 错误的最小化。
现在,如果我们旨在最小化类别的两种错误类型呢?
这就是 F1 得分发挥作用的地方:
F1-得分
F1-得分是精确率和召回率的调和平均数。F1 得分本质上试图在类别的精确率和召回率之间找到平衡,并在这种方式中也尝试平衡 Type 1 和 Type 2 错误。
从本质上讲,F1 得分展示了分类模型在特定类别上的整体有效性。

图片由作者提供
让我们从第四个也是最后一个角度回到我们的表格:

图片由作者提供
这一次,我们观察到,当精确率和召回率都表现优异时,F1 得分表现良好,这在案例 1 中是正确的。注意,尽管模型在这种情况下的表现仍然不够理想,但由于真正的正例数量很高,而假阳性和假阴性的数量都很小,因此 F1 得分很高。因此,我们可以推断,
高 F1 得分对于特定类别来说更为保守,可以最小化 Type 1 和 Type 2 错误。
在检查了其他指标后,很明显每个指标都有其局限性。此外,与准确率不同,F1、精确率和召回率并不具备类无关性,而准确率仍然容易受到类不平衡的影响。
此外,我们从未提出过这样的问题:“预测阈值是多少?”换句话说,类别之间是否存在明显的分界?正实例的预测分数是集中在 0.8-0.9 之间,还是更接近 0.51?
这就是 ROC 曲线和 AUC 得分发挥作用的地方,尽管这些方法有一定程度的直观性不足。
理解 AUC 得分
AUC 分数,也称为曲线下面积,是通过计算接收者操作特征(ROC)曲线下的面积来衡量的分数。
ROC 曲线是一个图表,y 轴为召回率/真阳性率(TPR),x 轴为假阳性率(FPR)。ROC 的名字源于其在电气工程领域的起源。
为了构造 ROC 曲线,需要计算不同分类阈值下的 FPR 和 TPR。分类阈值指的是预测值,例如 0.5,超过该值的实例被分类为正类,而低于 0.5 的实例则被分类为负类。

由 cmglee 和 MartinThoma 制作 — Roc-draft-xkcd-style.svg,CC BY-SA 4.0,commons.wikimedia.org/w/index.php?curid=109730045
然而,由于创建完整的曲线可能会很繁琐且无法提供量化的衡量,因此测量曲线下的面积(AUC)成为了更常用的做法。
下面,我展示了通过绘制 FPR 与 TPR 值的进展,以显示 ROC 曲线的创建过程。我将每个新添加的点标记为红色,而之前的点标记为蓝色。我们可以看到,当这些独立的点连接起来时,曲线与上图中的浅蓝色曲线非常相似,并且与模型的 AUC 分数 0.91 一致。

作者提供的图片
图中还展示了阈值的变化如何影响其他指标,如准确率、F1 分数、精确度和召回率。
用于创建上述图形的代码如下:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix, roc_curve, auc, f1_score, accuracy_score
# Synthetic dataset
X, y = make_classification(n_samples=1000, n_features=20, random_state=42)
# Train/test data split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# LR Model training
model = LogisticRegression()
model.fit(X_train, y_train)
# Model Predictions
y_probs = model.predict_proba(X_test)[:, 1]
# Generating Random thresholds for creating ROC Curve
num_thresholds = 9
random_thresholds = np.sort(np.random.rand(num_thresholds))
# Visualization
fig, axes = plt.subplots(3, 3, figsize=(16, 16))
axes = axes.flatten()
for i, threshold in enumerate(random_thresholds):
ax = axes[i]
for t in random_thresholds[:i+1]:
y_pred = (y_probs >= t).astype(int)
tn, fp, fn, tp = confusion_matrix(y_test, y_pred).ravel()
fpr = fp / (fp + tn)
tpr = tp / (tp + fn)
color = 'red' if t == threshold else 'blue'
label = f'Threshold {t:.2f}' if t == threshold else None
ax.scatter(fpr, tpr, color = color, label = label, s = 50)
ax.plot([0, 1], [0, 1], color = 'gray', linestyle = '--')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate(FPR)')
ax.set_ylabel('True Positive Rate(TPR)')
ax.set_title(f'Points at Different Thresholds (New Point in Red)\nRandom Threshold {threshold:.2f}')
ax.legend(loc = "lower right")
# Calculating AUC Score
fpr, tpr, _ = roc_curve(y_test, y_probs)
roc_auc = auc(fpr, tpr)
# Calculating precision, recall, F1 score, and accuracy
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = f1_score(y_test, y_pred)
accuracy = accuracy_score(y_test, y_pred)
# Displaying other metrics for various classification thresholds
metrics_text = f'AUC: {roc_auc:.2f}\nPrecision: {precision:.2f}\nRecall: {recall:.2f}\nF1 Score: {f1:.2f}\nAccuracy: {accuracy:.2f}'
ax.text(0.5, 0.1,
metrics_text,
transform=ax.transAxes,
fontsize=10,
va='bottom',
ha='center')
plt.suptitle('Progression of ROC Curve via estimatimating points at different Classification Thresholds', size = 26, y=1)
# plt.savefig('ROC.png')
plt.tight_layout()
plt.show()
这段代码展示了如何使用 Python 创建 ROC 曲线并计算 AUC 分数。它包括生成一个合成数据集,将其拆分为训练集和测试集,训练一个逻辑回归模型,然后绘制不同分类阈值下的 ROC 曲线点。
一个题外话:如果换成真负率与假负率对比会有什么影响吗?
不,因为 TNR = 1-FPR 和 FNR = 1-TPR

作者提供的图片
多类别情况如何?
因为 ROC 曲线一次只能计算两个类别,所以对于多个类别来说会稍显复杂,但可以对每个类别进行一对多的比较。
现在我们已经了解了 AUC 分数的构造方式,让我们更深入地探讨如何解释 AUC 分数及其对模型的意义。
解释 AUC 分数
我们将查看 4 种情况,我通过图片展示了在分类阈值右侧的所有数据被预测为正类,而左侧的数据被预测为负类。数据的实际标签在各自的矩形框中表示。
0 AUC 分数
AUC 分数为 0 将很难达到,并且可能指向一些严重的人为错误。这意味着 TPR 在每个分类阈值下都是 0,除非 FPR 为 1。在 FPR 为 1 时,TPR 在不同的分类阈值下从 0 到 1,这使得曲线下的面积有效地为 0。视觉上,这将意味着存在完美的类别分离,但每个标签都被反转了。

图片由作者提供
0.5 AUC 分数
现在这是模型没有学到任何东西,分类或预测仅仅是随机的情况。在这种情况下,ROC 曲线显示了 TPR 和 FPR 之间的线性关系或直线比例线。视觉上,它看起来好像完全没有类别分离。

图片由作者提供
0.5 < AUC 分数 < 1
如果我们在模型中正确操作,这些是最常见的情况。AUC 分数在 0.5 到 1 之间意味着模型至少学到了一些优于随机分类器的东西,但数据中仍然存在重叠实例,模型无法完全有效地分离类别。AUC 分数越接近 1,类别之间的分离程度就越大。

图片由作者提供
完美的 AUC 分数
完美的 AUC 分数 1.0 表示模型具有完美的区分能力。然而,达到如此完美的 AUC 分数确实可能是潜在过拟合和不切实际的模型行为的标志,尤其是在处理现实世界数据集和场景时。
除非 FPR 为 0,否则 TPR 在所有情况下都有效为 1。最终导致曲线下的面积等于 1。这在视觉上意味着两个类别之间有完美的类别分离,并且预测也是正确的。

图片由作者提供
重要的是要注意,现实世界的数据通常是嘈杂的并且包含固有的不确定性。期望一个模型产生完美的分离可能是不现实的,即使它发生了,这也可能是过拟合的情况。
如果达到了不太可能的 AUC 分数 1,模型可能会由于过拟合的高概率而无法有效地进行泛化,并在实际场景中表现良好。
一个更平衡和稳健的模型是一个在达到合理高的 AUC 分数的同时,还允许一定程度的不确定性。
让我们看看使用 AUC 分数的一些好处。
AUC 分数的好处
与其他指标相比,AUC 分数有很多好处
-
与类别无关: 与依赖于选定正类的精度、召回率和 F1 分数等指标相比,AUC 提供了对模型区分能力的更全面评估,而不受类别分布的影响。
-
预测阈值无关性: AUC 评分的一个显著特点是它考虑了模型在不同分类阈值下的表现,提供了其区分类别能力的全面视图。
-
对类别不平衡不敏感: 由于 AUC 衡量模型如何相对排序正负实例,它不容易受到类别分布不平衡引起的扭曲。
-
ROC 曲线阈值选择: 尽管 AUC 不受阈值影响,但 ROC 曲线本身可以帮助你直观地选择一个提供最佳性能的阈值。
AUC 评分的缺点
虽然没有很多缺点,而且这不是一个详尽的列表,但需要注意的是,在数据极度不平衡的情况下,AUC 评分可能会受到影响。此外,AUC 对所有误分类一视同仁。在许多现实世界的场景中,不同类型错误相关的成本和收益可能会有所不同。AUC 没有考虑这些因素,可能无法完全代表在某种错误比另一种更为关键的情况下的性能。
总结
因此,让我们总结一下,虽然 AUC 评分是评估模型性能的一个极佳指标,但每个指标在合适的上下文中都有其优点。
在结束时,我希望这篇文章对你有价值,并希望你能清楚地理解 AUC 评分是什么以及如何解读它。
如果你觉得这很有启发性,请在评论中告诉我。😃
AUC 评分的其他资源……
“没有解决方案,只有权衡。”在之前的文章中,我们展示了如何处理图像和……
斯坦福大学研究生 Shervine Amidi 的教学页面。
本教程解释了什么被认为是 AUC(曲线下面积)的好值,包括几个示例。
了解贝叶斯市场营销组合建模:深入探讨先验规格
探索使用 Google 的 LightweightMMM 进行模型规格化
·
关注 发表在Towards Data Science · 8 分钟阅读·2023 年 6 月 24 日
--
图片由Pawel Czerwinski拍摄,发布在Unsplash
贝叶斯营销混合建模越来越受到关注,特别是随着 LightweightMMM(Google)或 PyMC Marketing(PyMC Labs)等开源工具的发布。尽管这些框架简化了贝叶斯建模的复杂性,但用户仍需理解基本的贝叶斯概念,并能够理解模型规范。
在这篇文章中,我以 Google 的 LightweightMMM 作为实际示例,展示了该框架先验规范的直观性和含义。我演示了如何使用 Python 和 scipy 库进行先验样本的模拟。
数据
我使用 Robyn 在 MIT 许可证下提供的数据。
数据集包含 208 周的收入数据(从 2015–11–23 到 2019–11–11),包括:
-
5 个媒体支出渠道:tv_S,ooh_S,print_S,facebook_S,search_S
-
2 个媒体渠道也具有曝光信息(印象,点击量):facebook_I,search_clicks_P
-
无支出的有机媒体:newsletter
-
控制变量:事件,假期,竞争对手销售(competitor_sales_B)
LightweightMMM 模型规范
LightweightMMM 模型的规范定义如下:

LMMM 模型规范(作者提供的图像)
这个规范表示一个加性线性回归模型,该模型解释特定时间点 t 的响应(目标变量)值。
让我们分解方程中的每个组件:
-
α:这个组件表示响应的截距或基线值。它是当其他因素为零时响应的预期值。
-
趋势:这个组件捕捉响应随时间的增加或减少趋势。
-
季节性:这个组件表示响应的周期性波动。
-
媒体渠道:这个组件考虑了媒体渠道(电视、广播、在线广告)对响应的影响。
-
其他因素:这个组件包括任何其他对响应有影响的变量,如天气、经济指标或竞争对手活动。
接下来,我将详细介绍每个组件,并解释如何解释先验规范。请记住,先验分布是对某个参数的假设分布,而没有任何关于底层数据的知识。
截距

截距先验规范(作者提供的图像)
截距定义为遵循标准差为 2 的半正态分布。半正态分布是一种连续概率分布,类似于正态分布,但仅限于正值。该分布由一个参数(标准差(尺度))来表征。半正态分布意味着截距只能取正值。
以下代码生成截距的先验分布样本,并可视化标准差为 2 的半正态分布的概率密度函数(PDF)。有关其他组件的可视化,请参阅附带的Github repo中的源代码。
from scipy import stats
scale = 2
halfnormal_dist = stats.halfnorm(scale=scale)
samples = halfnormal_dist.rvs(size=1000)
plt.figure(figsize=(20, 6))
sns.histplot(samples, bins=50, kde=False, stat='density', alpha=0.5)
sns.lineplot(x=np.linspace(0, 6, 100),
y=halfnormal_dist.pdf(np.linspace(0, 6, 100)), color='r')
plt.title(f"Half-Normal Distribution with scale={scale}")
plt.xlabel('x')
plt.ylabel('P(X=x)')
plt.show()

半正态分布(图像由作者提供)
趋势

趋势规范(图像由作者提供)
趋势定义为时间 t 与趋势值之间的幂律关系。参数 μ 代表趋势的幅度或大小,而 k 控制趋势的陡度或曲率。
参数 μ 从均值为 0、标准差为 1 的正态分布中抽取。这意味着 μ 遵循标准正态分布,以 0 为中心,标准差为 1。正态分布允许 μ 取正值和负值,分别表示向上或向下的趋势。
参数 k 从 0.5 到 1.5 的均匀分布中抽取。均匀分布确保 k 取值可以产生合理且有意义的趋势曲率。
下图展示了从先验分布中获得的独立组件:截距和趋势的样本,每个组件单独表示。

趋势和截距(图像由作者提供)
季节性

季节性规范(图像由作者提供)
每个组件 γ 从均值为 0、标准差为 1 的正态分布中抽取。
通过将不同 γ 的余弦和正弦函数结合,可以模拟周期模式以捕捉数据中的季节性。余弦和正弦函数表示在 52 个单位(周)的周期内观察到的振荡行为。
下图展示了从先验分布中获得的季节性、截距和趋势的样本。

季节性、趋势和截距(图像由作者提供)
其他因素(控制变量)

其他因素规范(图像由作者提供)
每个因子系数 λ 取自均值为 0 和标准差为 1 的正态分布,这意味着 λ 可以取正值或负值,代表每个因子对结果的影响方向和幅度。
下图描绘了从先验分布中获得的独立组件:一个拦截项、趋势、季节性和控制变量(竞争对手销售 _B, 新闻通讯, 节假日和活动),每个组件都单独表示。

其他因素(合并)(图像由作者提供)
媒体渠道

媒体渠道的先验规范(图像由作者提供)
媒体渠道 m 的 β 系数的分布被指定为半正态分布,其中标准差参数 v 由与媒体渠道 m 相关的总成本的总和确定。总成本反映了分配给该特定媒体渠道的投资或资源。
媒体变换

Adstock 和 Hill 饱和规范(图像由作者提供)
在这些方程中,我们使用一系列变换(如 adstock 和 Hill 饱和)来建模媒体渠道的行为。
实验先验、数据标准化,并将贝叶斯建模与 Robyn、Facebook 的开源 MMM 进行比较……
[towardsdatascience.com
变量 媒体渠道 代表了在时间点 t 上经过变换后的媒体渠道。它是通过对原始媒体渠道值 x 应用变换获得的。Hill 变换由参数 K(半饱和点,0 < k ≤ 1)和控制曲线陡峭度的形状参数 S(s > 0)来控制。
变量 x∗ 代表在经过 adstock 变换后的时间 t 上的变换媒体渠道值。它是通过将当前原始媒体渠道值加到前一个变换值与 adstock 衰减参数 λ 的乘积来计算的。
参数 K 和 S 遵循 gamma 分布,形状和尺度参数均设置为 1,而 λ 遵循 beta 分布,形状参数为 2 和 1。
Hill 饱和参数 K 和 S 的概率密度函数如下图所示:
shape = 1
scale = 1
gamma_dist = stats.gamma(a=shape, scale=scale)
samples = gamma_dist.rvs(size=1000)
plt.figure(figsize=(20, 6))
sns.histplot(samples, bins=50, kde=False, stat='density', alpha=0.5)
sns.lineplot(x=np.linspace(0, 6, 100), y=gamma_dist.pdf(np.linspace(0, 6, 100)), color='r')
plt.title(f"Gamma Distribution for $K_m$ and $S_m$ with shape={shape} and scale={scale}")
plt.xlabel('x')
plt.ylabel('P(X=x)')
# Show the plot
plt.show()python

Gamma 分布(图像由作者提供)
adstock 参数 λ 的概率密度函数如下图所示:

Beta 分布(图像由作者提供)
关于 adstock 参数 λ 的说明:
Beta(α = 2, β = 1) 分布的概率密度函数呈现正趋势,表明较高的值具有更高的概率密度。在媒体分析中,不同的行业和媒体活动可能显示出不同的衰减率,大多数媒体渠道通常表现出较小的衰减率。例如,Robyn 提出了常见媒体渠道 λ 衰减的以下范围:电视 (0.3–0.8)、户外广告/印刷/广播 (0.1–0.4)、数字媒体 (0–0.3)。
在 Beta(α = 2, β = 1) 分布的背景下,较高的概率分配给接近 1 的 λ 值,而较低的概率分配给接近 0 的值。因此,相较于接近区间 [0, 1] 下端的结果,接近上端的结果更有可能发生。
或者,在具有延续性和形状效应的贝叶斯媒体混合建模方法中,衰减参数被定义为 Beta(α = 3, β = 3),其概率密度函数如下图所示。该分布在 0.5 附近对称,表明在区间 [0, 1] 的两端和中间位置观察到结果的可能性相等。

Beta(3,3)(图片由作者提供)
下图描绘了从先验分布中获得的各个独立组件:截距、趋势、季节性、控制变量和媒体渠道的样本,每个组件单独表示。

所有模型组件(图片由作者提供)
组合所有组件
如前所述,LightweightMMM 通过结合截距、趋势、季节性、媒体渠道及其他从其先验分布中抽样的因素来模拟加性线性回归,从而获得预测响应。下图可视化了真实响应和从先验预测分布中抽样得到的预期响应。
将单个样本与真实响应值进行可视化,可以观察模型的预测与实际结果在特定参数值集下的比较。这可以直观地理解模型在该特定实例中的表现。

收入:真实与先验(图片由作者提供)
先验预测检查
为了获得更稳健的洞见,一般建议从先验预测分布中多次抽样并测量不确定性。先验预测检查有助于评估所选择模型的充分性,并在观察任何实际数据之前评估模型的预测是否符合我们的预期。
下面描绘的图表通过显示每个点的预期收入(均值)以及不确定性度量来可视化先验预测分布。我们可以看到真实收入落在标准差范围内,这表明模型规格适合观察到的数据。

先验预测检查(图片由作者提供)
结论
贝叶斯营销组合建模可能需要相当多的时间来掌握。希望这篇文章能帮助你提高对先验分布和贝叶斯营销模型规格的理解。
完整的代码可以从我的Github repo下载。
感谢阅读!
理解因果树
原文:
towardsdatascience.com/understanding-causal-trees-920177462149
因果数据科学
如何使用回归树来估计异质性处理效应
·发表于 Towards Data Science ·阅读时间 15 分钟·2023 年 2 月 3 日
--

封面,图片由作者提供
在因果推断中,我们通常关注的是估计处理(药物、广告、产品等)对感兴趣结果(疾病、公司收入、客户满意度等)的因果效应。然而,知道处理在平均情况下有效通常是不够的,我们希望了解对哪些对象(患者、用户、客户等)效果更好或更差,即我们希望估计异质性处理效应。
估计异质性处理效应使我们能够通过目标定位选择性地和更有效地使用处理。了解哪些客户更可能对折扣做出反应可以使公司通过提供更少但更精准的折扣来节省开支。这同样适用于负面效应:知道哪些患者对某种药物有副作用可以使制药公司警告或将他们排除在治疗之外。估计异质性处理效应还有一个更微妙的优势:了解谁对处理有效可以帮助我们更好地理解如何处理有效。知道折扣的效果不依赖于接受者的收入而是依赖于其购买习惯,告诉我们也许这不仅仅是钱的问题,而是关注度或忠诚度的问题。
在本文中,我们将探讨使用回归树(及其森林)改进版来估计异质性处理效应。从机器学习的角度来看,因果树与预测树之间有两个基本差异。首先,目标是处理效应,这本质上是一个不可观察的对象。其次,我们关注的是进行推断,这意味着量化我们估计的不确定性。
在线折扣与目标定位
在文章的其余部分,我们将使用一个示例进行说明:假设我们是一个在线商店,并且我们希望了解是否对新客户提供折扣会增加他们的支出。特别是,我们希望知道对某些客户提供折扣是否比对其他客户更有效,因为我们不希望对那些即使没有折扣也会消费的客户进行折扣。此外,向客户发送弹窗广告可能会让他们反感,从而产生相反的效果。

图片由作者使用NightCafé生成
为了了解折扣的效果以及效果的大小,我们进行了一项A/B 测试:每当一个新用户访问我们的在线商店时,我们会随机决定是否向他们提供折扣。我从[src.dgp](https://github.com/matteocourthoud/Blog-Posts/blob/main/notebooks/src/dgp.py)导入数据生成过程dgp_online_discounts()。与之前的文章相比,我生成了一个新的 DGP 父类来处理随机化和数据生成,而其子类包含具体的使用案例。我还从[src.utils](https://github.com/matteocourthoud/Blog-Posts/blob/main/notebooks/src/utils.py)导入了一些绘图函数和库。为了包括代码、数据和表格,我使用了Deepnote,这是一个类似 Jupyter 的基于 Web 的协作笔记本环境。
我们有 100,000 名网站访问者的数据,我们观察他们的time(时间)、使用的device(设备)、browser(浏览器)以及他们的地理region(区域)。我们还记录了他们是否获得了discount(折扣),我们的处理,以及他们的spend(支出),这是我们的关注点。
由于处理是随机分配的,我们可以使用简单的均值差异估计量来估计处理效应。我们期望处理组和对照组在discount(折扣)之外是相似的,因此我们可以将spend(支出)的任何差异归因于discount(折扣)。
折扣似乎有效:在处理组中,平均支出增加了 1.95 美元。但所有客户的反应是否相同?
为了回答这个问题,我们希望估计异质处理效应,可能在个体层面上。
异质处理效应
估计异质处理效应有多种方法。最常见的方法是根据一些可观察的特征将人群划分为不同的组,在我们的案例中,这些特征可以是device(设备)、browser(浏览器)或地理region(区域)。一旦决定了数据划分的变量,你可以简单地将处理变量(discount)与处理异质性的维度进行交互。以device为例。
我们如何解读回归结果?discount对客户spend的影响是 1.22\(,但如果客户通过移动`device`访问网站,这一影响会增加至 1.44\)。
对于分类变量,划分很容易,但对于像time这样的连续变量来说,不直观如何划分。每小时划分一次?哪个维度更具信息性?虽然很诱人尝试所有可能的划分,但我们对数据的划分越多,发现虚假结果(即在机器学习术语中,我们过拟合)的可能性就越大。如果我们能让数据说话并选择最小且信息量最大的划分,那就太好了。
在另一篇文章中,我展示了所谓的元学习者如何采取这种因果推断方法。思路是根据每个观察的治疗状态预测结果,然后将预测的条件治疗结果与预测的对照结果进行比较。二者之间的差异就是个体治疗效应。
元学习者的问题在于,它们在预测结果时使用了所有的自由度。然而,我们感兴趣的是预测治疗效应的异质性。如果结果的大部分变异不在治疗维度上,我们将得到非常差的治疗效应估计。
是否可以直接集中在个体治疗效应的预测上?我们将Y定义为感兴趣的结果spend,D为治疗discount,以及X为其他可观察特征。理想的损失函数是

理想的损失函数,图片由作者提供
其中τᵢ是个体i的治疗效应。然而,这个目标函数是不可行的,因为我们无法观察到τᵢ。
但事实证明,有一种方法可以获得个体治疗效应的无偏估计。思路是使用一个辅助结果变量,其每个个体的期望值即为个体治疗效应。这个变量是

辅助结果变量,图片由作者提供
其中p(Xᵢ)是观察i的倾向评分,即其被治疗的概率。
在随机化实验中,倾向得分是已知的,因为随机化完全在实验者的控制之下。例如,在我们的案例中,治疗的概率是 50%。而在准实验研究中,当治疗概率未知时,需要进行估计。即使在随机化实验中,估计倾向得分总是比填补更好,因为它能防止随机化中的抽样变异。有关倾向得分及其在因果推断中的使用的更多详细信息,请参阅我在这里的单独帖子。
首先,我们为类别变量device、browser和region生成虚拟变量。
我们拟合了一个LogisticRegression并用它来预测治疗概率,即构建倾向得分。

估计的倾向得分分布,图片来源:作者
正如预期的,大多数倾向得分接近 0.5,这是随机化中使用的治疗概率。此外,治疗组和对照组的分布几乎完全相同,进一步确认了随机化的有效性。如果情况不是这样,我们将需要做出进一步假设以进行因果分析。最常见的假设是无混淆性,也称为可忽略性或基于可观测变量的选择。简而言之,我们将假设在某些可观测变量𝑋的条件下,治疗分配是随机的。

无混淆假设,图片来源:作者
然而,在我们的案例中,治疗概率是已知的,并且似乎随机化过程中没有出现问题。
我们现在拥有计算辅助结果变量*Y**的所有元素。
正如我们之前所说,目的是将*Y**作为预测问题的目标,因为其期望值正好是个体治疗效果。让我们检查数据中的平均值。
确实,它的平均值几乎与之前估计的 1.94$的平均治疗效果相同。
如何在只有一个观察值和倾向得分估计的情况下估计个体治疗效果?有什么缺点?
直观的想法是从不同的角度来处理问题:事前,在实验之前。设想我们的数据集只有一个观察值,i。我们知道治疗概率是p(Xᵢ),即倾向得分。因此,期望中,我们的数据集中治疗组有p(Xᵢ)个观察值,对照组有1–p(Xᵢ)个观察值。其余的照常处理:我们通过两组之间的平均结果差异来估计治疗效果!这确实是我们会做的:

辅助结果变量,图片来源:作者
唯一的区别是我们只有一个观察值。
这个技巧有一个代价:Yᵢ 是个体处理效应的无偏估计量,但具有非常高的方差。通过绘制其分布,这一点可以立即显现出来。

辅助变量的分布,图像由作者提供
我们现在准备估计异质性处理效应,通过将因果推断问题转化为预测问题,预测给定可观察特征X的辅助结果Y。

图像由作者使用NightCafé生成
因果树
在前一节中,我们已经看到,我们可以将异质性处理效应的估计转化为预测问题,其中结果是辅助结果变量。

辅助结果变量,图像由作者提供
原则上,我们可以使用任何机器学习算法来估计个体处理效应。然而,回归树具有特别便利的特征。
首先,回归树是如何工作的?分类和回归树(CART)是基于协变量X递归对数据进行分箱的算法,使得每个箱中的结果Y在箱内尽可能同质,而箱间的结果尽可能异质。预测值只是每个箱中的结果平均值,在我们的情况下是辅助结果变量Y,每个观察值的期望值等于个体处理效应。因此,通过对每个箱中的Y进行平均,我们可以计算条件(基于 X)的异质性处理效应 𝔼[τᵢ | Xᵢ] 对于落在该箱中的观察值。
平均化部分是回归树推断的一个主要优势,因为我们非常清楚如何使用平均值进行推断,这要归功于中心极限定理。回归树相对于其他机器学习算法的第二个优势是树非常可解释,因为我们可以直接将数据分区绘制为树结构。我们稍后会详细了解这一点。最后但同样重要的是,截至 2022 年,回归树仍然是表现最佳的预测算法的核心。
让我们使用sklearn中的[DecisionTreeRegressor](https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeRegressor.html)函数来拟合我们的回归树,并估计discounts对顾客spend的异质性处理效应。
我们将树的最大深度限制为 2,并且每个分区(也称为叶子)至少包含 30 个观察值,以便我们可以轻松绘制树并可视化估计的组和处理效应。

辅助结果变量 Y*上的回归树,作者提供的图片
我们应该如何解读这棵树?在顶部,我们可以看到数据中的平均值Y为 1.945\(,对应于平均处理效应。从那里开始,数据根据每个节点顶部突出显示的规则被分成不同的分支。例如,第一个节点将数据分成两个大小分别为 51,156 和 48,844 的组,具体取决于`time`是否晚于 12.325。在底部,我们有最终的分区,包含异质的处理效应。例如,最左边的叶子包含 43,876 个观察值,其中`time`早于 12.325 且`browser`不是 Safari,我们预测对`spend`的影响为 0.295\)。简而言之,每个节点包含条件平均处理效应** 𝔼[τᵢ | Xᵢ*]的估计,其中节点颜色越深表示预测值越高。
我们应该相信这些估计吗?不完全是,因为有几个原因。第一个问题是,只有在每个叶子内部我们有相同数量的处理组和对照组单元时,我们才有平均处理效应的无偏估计。这在使用现成的DecisionTreeRegressor()时并非自动成立。
诚实树
我们的简单方法的另一个问题是我们使用了相同的数据来生成和评估树。这会产生偏差,因为简单的均值差异估计量不会考虑分区是内生的,即在相同的数据上生成的。用机器学习术语来说,我们是在过拟合。解决方案很简单:我们可以将样本拆分为两个独立的子样本,并使用不同的数据生成树和计算预测。这些树被称为诚实树。
这个解决方案既简单又有效,因为它允许我们在推断阶段,将每个叶子中的样本视为与树结构独立。此时,我们的估计量是对独立子样本的均值差异估计量,我们可以简单地使用中心极限定理进行推断。拆分数据的一个缺点是我们失去了统计功效,即由于样本较小而无法检测到虚假的异质处理效应。解决方案是重复该过程两次,交换用于构建树和计算叶子内均值的样本。然后,我们可以对每个个体的两个估计值取平均,并相应地调整估计的标准误差。

样本拆分过程,作者提供的图片
生成拆分
最后但同样重要的是,树应该如何生成?DecisionTreeRegressor函数生成分裂的默认规则是squared_error,并且对每个叶子中的最小观测数没有限制。其他常用规则包括平均绝对误差、基尼 impurity 和香农信息。哪个表现更好取决于具体应用,但总体目标始终是预测准确性,广义上定义。
相反,在我们的案例中,目标是推断:我们希望揭示在统计上不同的异质性处理效应。如果处理效应在统计上不可区分,那么生成不同的处理效应就没有意义。此外(但与之紧密相关),在构建树和生成数据分区时,我们必须考虑,由于我们使用的是诚实树,我们将使用不同的数据来估计叶内处理效应。
Athey 和 Imbens (2016)使用了均方误差 (MSE)的修改版本作为分裂标准,即扩展均方误差 (EMSE):

扩展均方根误差,作者提供的图片
其中μ是估计的条件期望μ(X) = 𝔼 [Y | X],与 MSE 的差异是额外的项Yᵢ²,即平方的结果变量。在我们的设置中,我们可以将其重写为

扩展均方根误差用于因果树,作者提供的图片
为什么这是一个合理的误差损失?因为我们可以将其重写为条件处理效应的期望方差减去平方期望值。

扩展均方根误差用于因果树,作者提供的图片
这种 EMSE 的公式明确了目标是最小化估计条件处理效应τ(X)的叶内方差(第一个项)。换句话说,小的叶子会被自动惩罚。第二个项只是一个归一化因子。请注意,这两个项都是未知的,必须从训练数据中估计,用于生成树。
实现
幸运的是,有多个因果树库可供选择。我们从微软的EconML库中导入CausalForestDML,这是最好的因果推断库之一。
我们将估计器的数量限制为 1,以便得到一棵树,而不是多棵树,即所谓的随机森林,我们将在另一篇文章中介绍。

估计的因果树,作者提供的图片
正如我们所见,树形表示与之前使用DecisionTreeRegressor函数得到的结果非常相似。然而,现在模型不仅报告条件平均处理效应的估计值,还有这些估计值的标准误差(在底部)。这些是如何计算的?
推断
诚实树除了提高模型的样本外预测准确性外,还有另一个重要意义:它们允许我们像树结构是外生的一样计算标准误差。实际上,由于用于计算预测的数据与用于构建树的数据(分割数据)是独立的,我们可以将树结构视为与估计的处理效应独立。因此,我们可以将估计的标准误差视为样本均值差异的标准误差,就像标准的 AB 测试一样。
如果我们使用相同的数据来构建树并估计处理效应,我们将引入偏差,由于协变量与结果之间的虚假相关性。这个偏差通常在非常大的样本量中消失,但诚实树并不需要这样。
性能
模型表现如何?由于我们控制了数据生成过程,我们可以做一些真实数据无法做到的事情:将预测的处理效应与真实值进行比较。add_treatment_effect()函数为数据中的每个观测值提供了“真实”处理效应。
我们现在可以检查因果树在估计个体处理效应方面的能力。让我们从分类变量开始。我绘制了基于device、browser和region每个值的真实和估计的平均处理效应。

每个分类值的真实和估计处理效应,图片由作者提供
因果树在检测分类变量的异质处理效应方面表现相当好。它高估了移动设备和 Safari 浏览器的效应,但总体上表现不错。
然而,这也是我们期望树模型表现特别好的地方:在离散的效应上。它在我们的连续变量时间上的表现如何?首先,让我们再次隔离time上的预测处理效应,并忽略其他协变量。
我们现在将预测的处理效应与真实值沿time维度进行绘图。

沿时间维度的真实和估计处理效应,图片由作者提供
从图中,我们可以欣赏到因果树的离散特性:模型只能将连续变量分割成 4 个区间。这些区间接近真实的处理效应,但未能捕捉到处理效应异质性的较大部分。
这些预测能得到改进吗?答案是肯定的,我们将在下一篇文章中探讨如何改进。
结论
在本文中,我们探讨了如何使用因果树来估计异质性处理效应。主要的洞见来自于辅助结果变量的定义,这使我们能够将推断问题框架设为预测问题。虽然我们可以使用任何算法来预测处理效应,但回归树特别有用,因为它们具有良好的可解释性、预测准确性,并且能够生成作为子样本平均值的预测。
Athey 和 Imbens (2016) 关于回归树计算异质性处理效应的工作,将因果推断和机器学习这两个不同的文献结合在了一起,形成了非常有成效的协同效应。因果推断文献(重新)发现了样本分割的推断好处,这使我们能够在数据分割复杂且难以分析时进行正确的推断。另一方面,将树生成阶段与叶内预测阶段分开,有助于提高预测准确性,防止过拟合。
参考文献
-
S. Athey, G. Imbens, 异质因果效应的递归分割 (2016), PNAS。
-
S. Wager, S. Athey, 使用随机森林的异质性处理效应的估计和推断 (2018), 美国统计协会期刊。
-
S. Athey, J. Tibshirani, S. Wager, 广义随机森林 (2019). 统计年鉴。
-
M. Oprescu, V. Syrgkanis, Z. Wu, 用于因果推断的正交随机森林 (2019). 第 36 届国际机器学习会议论文集。
相关文章
-
DAG 和控制变量
-
匹配、加权还是回归?
-
理解元学习者
-
理解 AIPW,双重稳健估计量
代码
原始的 Jupyter Notebook 可以在这里找到:
github.com/matteocourthoud/Blog-Posts/blob/main/notebooks/causal_trees.ipynb
感谢阅读!
我非常感激! 🤗 如果你喜欢这篇文章并希望看到更多内容,请考虑 关注我。我每周发布一次与因果推断和数据分析相关的话题。我尽量保持帖子简单但精确,始终提供代码、示例和模拟。
此外,一个小小的 免责声明:我写作是为了学习,因此错误是常见的,尽管我尽力而为。如果你发现错误,请告诉我。我也欢迎对新话题的建议!
了解 ChatGPT 插件:益处、风险及未来发展
期待改进,而不是完美。
·发表于 Towards Data Science ·13 分钟阅读·2023 年 6 月 2 日
--

图片由作者提供。
当 ChatGPT 在 2022 年底首次发布时,它的能力既令人印象深刻又令人失望。它可以 进行说唱对决 和 用 LaTeX 写 微分方程,但对乌克兰战争一无所知,有时甚至无法做 简单数学。它的 能力和局限性 的鲜明对比,虽然令人困惑且有时显得神秘,但却引发了一些有趣的 Twitter 讨论,最终强调了模型急需连接到互联网的事实。

图片由作者提供。
但插件不仅仅是将 ChatGPT 连接到互联网,它们还可以将 ChatGPT 连接到其他外部数据源,如内部数据库甚至你的电子邮件收件箱。本质上,插件扩展了 ChatGPT 的能力,提高了响应的准确性,并允许更个性化的聊天机器人体验。
本文介绍了插件的工作原理、促使其融入 ChatGPT 的必要性,以及它们对 ChatGPT 用户体验的变革性影响。我们还将探讨这些插件相关的风险和限制,并对 ChatGPT 及其不断发展的插件生态系统的未来进行一些推测。
插件是什么?
插件是定制和扩展现有程序功能的软件附加组件。这些小的附加项,通常由第三方开发者创建和发布,旨在满足用户定义的特定需求,这些需求超出了原始产品的范围。
你可能对浏览器插件(也称为扩展)最为熟悉。例如,Google Translate提供了一个适用于基于 Chrome 的浏览器的插件,可以通过点击按钮翻译整个网页。另一方面,AdBlock是一个由第三方开发者提供的插件,可以阻止弹窗,移除网页和应用程序中的广告,并禁用侵入性的跟踪器。这些插件通过为浏览器添加功能来满足用户的特定需求。

图片由作者提供。
尽管插件可以非常有用且下载迅速,但需要记住的是,许多插件是由第三方开发者发布和维护的。这意味着某个应用程序的插件质量可能差异很大,你应始终独立验证它们的安全性。
现在我们已经回顾了插件到底是什么,让我们讨论在 ChatGPT 上下文中的插件功能。
没有插件的 ChatGPT
ChatGPT 是由OpenAI开发的大型语言模型(LLM),旨在作为聊天助手工作,生成类似人类的对话输出(响应)以回应文本输入(提示)。
尽管在机器学习社区中这一点仍然存在争议 ,解释 ChatGPT 最简单的方法是说它是一个统计模型,训练来预测基于到目前为止看到的单词(如预测搜索)和它所接受的训练数据的下一个单词。
在训练过程中,模型会接收大量的文本数据。模型通过研究不同文本之间的复杂模式、关系和依赖性来学习训练数据的潜在结构。这意味着当你从 ChatGPT 收到回应时,它并不是从一个定义的来源中提取特定的信息。相反,它是根据你提供的输入和它接受过的训练数据做出有根据的预测。你收到的回应是一个近似值,旨在模拟信息可能是什么,而不是直接检索具体数据。
然而,由于该模型没有连接到互联网或其他任何外部数据源,它可能会以非常有说服力的方式输出极其不准确的信息。例如,当我要求 ChatGPT 给出五篇关于“从人类反馈中进行强化学习”主题的有影响力论文的标题时,它建议的三篇文章根本不存在。

图片由作者提供。
也许这项回应中最令人担忧的方面是这些虚构的论文听起来完全可信。它们甚至被归因于该领域的知名作者。这被称为 幻觉,是 LLM 的最严重缺陷之一。虽然有几种机器学习方法可以减少 LLM 的幻觉,但通过使用插件使这些模型连接到数据源是一种成本较低的方法,可以提高准确性并进一步扩展 LLM 的能力。
带有插件的 ChatGPT
ChatGPT 插件通过向聊天机器人提供未包含在模型训练数据中的最新、个人或特定数据来增强聊天体验。LLM 连接到实时数据源(如数据库或网页浏览)的目的是帮助 ChatGPT 生成更准确和及时的结果。它还允许聊天机器人在 ChatGPT 界面内直接执行更复杂的任务,例如回复电子邮件。让我们将所有这些新增功能分解成不同的类别。
网页浏览
ChatGPT 的网页浏览模式的引入使你能够在不使用搜索引擎和点击链接的情况下进行研究和搜索互联网。这项新功能的目的是简化你的搜索过程和互联网浏览体验,省去手动任务,如筛选结果、点击链接和扫描网页以获取所需信息。
为了实现这一点,ChatGPT 本质上会将你的提示重新表述为搜索查询,执行互联网搜索,点击最相关的链接,阅读网页,然后用你所需的信息作出回应。让我们看看 ChatGPT 在网页浏览模式下如何回答问题 “谁在竞选 2024 年美国总统?”

图片由作者提供。
而原版 ChatGPT 会对这样的问题直接回避,网页浏览能力使 ChatGPT 能够提供关于基模型训练截止日期之后发生事件的相关及时信息。浏览模式还能够引用其来源,包括用来收集信息的在线来源链接,这允许用户独立验证 ChatGPT 响应的部分内容。
最终,将大型语言模型的自然语言能力与外部数据源(即互联网)连接起来,可以提供更有信息量和全面的体验。
任务自动化
虽然浏览模式对于涉及研究和信息收集的相对简单任务很有帮助,但 ChatGPT 的新自定义插件确实将其提升到了一个新的水平,使其能够在个性化的基础上自动化多步骤任务。
任务自动化,即利用技术来完成任务而不是手动完成,并不是新鲜事物。无论是使用超市的自助结账机、机场的自助登机 kiosks,还是让你的机器人吸尘器在晚餐后清理厨房,你都会遇到它。数字任务自动化利用软件来自动化常规和重复的任务,比如在你离开办公室时自动发送电子邮件回复。
将 LLMs 纳入数字任务自动化工作流可以实现更有效和个性化的常见任务自动化。当我们使用 Zapier 插件 并要求 ChatGPT 给我妈妈写一封电子邮件时,就会发生这种情况。为了完成这个任务,ChatGPT 会请求已经连接到我的 Gmail 账户的 Zapier,草拟一封包含所有相关信息(主题、收件人、语气)的电子邮件。Zapier 完成了这个任务,使用 ChatGPT 语言模型生成所有必要的文本,并提供一个链接,允许我直接从 Gmail 账户审阅并发送草稿。

图片由作者提供。
使用 ChatGPT 自动化任务的想法是,理论上,你可以在不必打开一堆标签页、手动点击和撰写文本的情况下完成任何数量的任务。这不仅能节省逐个案例的时间,还能在大规模上提供有价值的结果。例如,如果你每天经常发送和接收超过一百封电子邮件,使用一组 预定义的提示 与 ChatGPT 和 Zapier 插件结合使用,可以节省大量时间。
将 LLMs 与插件结合用于任务自动化,为我们处理日常和复杂任务的方式带来了令人兴奋的机会。以下是一些例子:
-
以对话方式查询和更新文档及数据库
-
自动化创建和发布社交媒体内容
-
自动化个人和职业通信
-
安排和取消约会及工作电话
-
根据特定要求生成食谱和餐单
-
根据某些标准规划假期和活动
-
总结给定数据集中的发现
-
为给定项目中的特定任务编写代码
缺点和限制
虽然插件可以帮助扩展大型语言模型(LLM)的现有功能,但它们也有自身的缺点和限制。以下是我在使用 ChatGPT 插件时遇到的一些最显著的限制:
-
插件速度较慢。 在简单的使用场景中,比如查找实时信息,使用搜索引擎可能更快。
-
插件仍然可能产生幻觉。 虽然幻觉可能会减少,但请记住,ChatGPT 背后的 LLM 仍然可能产生幻觉。

作者提供的图片。
-
插件可能会很脆弱。 插件依赖于外部 API,这意味着如果 API 出现故障或无法正常运行,插件可能变得几乎无用。我遇到过不止一次这种情况。
-
插件会暴露更多的个人数据。 为了利用 LLMs 处理你的个人数据,你需要将这些数据暴露给插件。在将插件与个人数据、文档或电子邮件帐户集成之前,请务必阅读每个插件的数据隐私政策。
-
有些插件并不复杂。 当你通过点击插件的下拉菜单查看某些插件的内部时,你会发现有些插件并不复杂。例如,当使用 ScholarAI 插件搜索文章时,它只是从提示中提取一个具有简单参数的搜索查询,然后在数据库中进行搜索。

作者提供的图片。
- 插件缺乏人类判断。 ChatGPT 的浏览模式和其他将用户连接到实时数据的插件正在为我们搜索互联网,实际上是“点击”链接并扫描那些网页以获取相关信息。这不能替代人类的判断,因为我们不知道如何选择或优先考虑来源的决策过程。
最有前景的 ChatGPT 插件
现在我们已经回顾了 ChatGPT 插件的工作原理和它们的局限性,让我们来看看一些最有前景的 ChatGPT 插件。虽然这些插件仍处于初期阶段,但它们具有广泛的功能和潜力,可以自动化各种任务,这些任务要么能节省大量时间,要么能带来显著的业务价值。
Zapier
Zapier 是一个工作流自动化平台,支持数千个流行应用程序,将它们连接在一起并执行自动化任务。最初,你需要在 Zapier 网站上手动配置这些任务;现在,你只需通过提示 ChatGPT 就可以做到这一点。
在之前的示例中,我展示了如何使用 Zapier 起草个人电子邮件。现在,让我们使用该插件自动将新数据添加到现有的 Google Sheets 文档中。为了简单起见,假设这个电子表格记录了我所有喜爱的狗的名字、品种和年龄。在安装并启用 Zapier 插件后,我只需用简单的英文告诉 ChatGPT 将数据添加到名为“favorite-dogs”的现有 Google 表格中。接下来,我点击响应中提供的链接,确认 ChatGPT 正确理解了我的请求,并确认该操作。当我打开“favorite-dogs”文档时,我会看到数据已被添加。

作者提供的图片。
Zapier 维护了一份 列表 ,展示了将不同应用程序连接起来以完成任务的创意方式,包括将 新的 Dropbox 文件转换为 PDF 和生成文档摘要,生成回复 对来电 SMS 消息,以及基于 Evernote 待办事项列表 优先排序任务。
Zapier ChatGPT 插件在自动化琐碎和重复性任务方面具有令人难以置信的潜力。它执行任务的速度相对较快,并在执行之前给你一个任务概览。然而,在使用此插件之前,我建议你查看 Zapier 的 OAuth 政策,这些政策指定了你授予 Zapier 执行任务所需的对外部账户(例如 Gmail、Google Drive)不同级别的访问权限和许可。
Noteable
Noteable 是一个用于临时分析、机器学习和数据应用的协作数据科学笔记本平台。像 Jupyter notebooks 一样,Noteable 是你可以编写代码来清理、总结和可视化数据,以及训练和评估机器学习模型的地方。Noteable ChatGPT 插件 允许你与现有数据和笔记本互动,甚至只需提示 ChatGPT 即可创建新的笔记本。这意味着你可以在不编写一行代码的情况下获得洞察并处理实际数据。

图片由作者提供。
使用 Noteable 插件,你还可以利用自然语言查询连接到 Noteable 的数据库,使用 Pandas 清理和转换原始数据,从外部来源加载数据,甚至创建图表。有关其他创意,查看这些 示例,这些笔记本是 Noteable 团队使用 ChatGPT 插件创建的。
为了安全起见,我不会将此插件用于敏感的公司或个人数据。 然而,如果你是一个有抱负的数据科学家,寻求一种有趣且富有创意的方法来提升技能,或如果你正在从事涉及公共代码和数据的项目,这个插件可能是一个有趣的工具来探索!
ChatGPT 插件的当前情况
插件通过允许 LLM(如 ChatGPT)访问原始训练数据中未包含的信息来扩展其能力。这使得大型语言模型有可能在准确性和任务自动化方面表现更好。一些首批推出的 ChatGPT 插件,如 Zapier 和 Noteable,展现了巨大潜力,允许用户通过简单的提示执行各种难度的任务。
然而,LLM 的插件仍处于初期阶段。第一代 ChatGPT 插件速度较慢,有点笨重,有时甚至无法正常工作。可以把这些插件看作类似于 2008 年首次上市的 iOS 和 Android 应用。开发者正在学习如何编写与用户以新方式互动并满足新需求的代码。未来的插件可能会更快、更复杂、更友好。
但是,重要的是要记住,将 LLM 连接到互联网和外部数据源自身也带有风险。语言模型在审查互联网来源时缺乏人类判断力,这意味着我们需要在让聊天机器人为我们浏览网页时保持谨慎和怀疑。同样,用户需要意识到他们将个人数据交给第三方插件,因为这些插件可能以用户未完全理解的方式收集、存储或分享这些数据。
ChatGPT 插件的未来
首先,ChatGPT 插件的复杂性和性能将会提高。希望随着公司投入更多资源维护第三方 API,它们会变得更可靠,插件最终也会能够更快地执行任务(尽管可能仅限于付费客户)。
其次,未来几个月,插件的数量和多样性将会显著增加。由于插件是与数字产品和内容建立互动的有趣且富有成效的方式,你可以肯定许多公司将投入资源开发创新且互动性强的插件用于 ChatGPT。而且,只需注册 OpenAI 的插件开发者访问(目前有等待名单),你就可以开始开发自己的插件,开源社区有潜力贡献各种创意插件。
最后,重要的是要记住 ChatGPT 的未来可能并不完全是 ChatGPT。各种程度的 开源 聊天助手和模型的数量和质量正在迅速增加,我几乎放弃了跟踪。由于开源社区致力于保持 LLM 的透明性和可访问性,加上美国和欧洲 AI 监管 的未来不确定,未来将会是什么样子很难说。
话虽如此,尽情玩转所有新的 ChatGPT 插件,并小心选择你分享的数据!
如果你想了解最新的数据科学趋势、技术和工具,考虑成为 Medium 会员。你将获得无限访问文章和博客的权限,如 Towards Data Science,同时支持我的写作。(我从每个会员中赚取少量佣金)。
[## 使用我的推荐链接加入 Medium - Mary Newhauser
每月 $5 即可无限访问 Medium 文章 🤗 你的会员费直接支持 Mary Newhauser 和…
medium.com](https://medium.com/@mary.newhauser/membership?source=post_page-----7a76f64e52ce--------------------------------)
想要联系我?
我还写过:
## GPT-4 与 ChatGPT 的比较:训练、性能、能力和局限性的探讨
GPT-4 是一种改进,但请调整你的期望值。
towardsdatascience.com ## 终极参考:干净的 Pandas 代码
清理文本数据的干净方法。
towardsdatascience.com ## PyCon 精华:PyCon DE 2023 中的精选演讲
LLMs 的孤立使用并不是未来。
towardsdatascience.com
参考文献
(1) A. George, 什么是插件,插件如何工作? (2021). Lifewire.
(2) K. Olang, Mozilla Firefox 打开空白页面 (2023). Chron.
(3) T. Ngo, ChatGPT 并非仅仅是在“预测”下一个词 (2023). LinkedIn.
(4) F. Neugebauer, 理解 LLM 幻觉 (2023). Towards Data Science.
(5) K. Eby, 自动化一切:任务自动化的终极指南 (2018). Smartsheet.
(6) E. Alston, 新!试用 Zapier 的 ChatGPT 插件 (2023). Zapier.
(7) U. Sharma, 125+ 个最佳 ChatGPT 提示,适用于各种工作流程 (2023). Beebom.
(8) Zapier. 利用 ChatGPT 集成做更多事 (2023). Zapier.
(9) Zapier. 将新的 Dropbox 文件转换为 PDF,并使用 ChatGPT 进行总结 (2023). Zapier.
(10) Zapier. 使用 ChatGPT 对新的 incoming sms77 消息生成回复 (2023). Zapier.
(11) Zapier. 让 ChatGPT 根据待办事项优先排序你的日程 (2023). Zapier.
(12) Noteable. ChatGPT 插件 (2023). Noteable.
(13) Noteable. ChatGPT 插件示例 (2023). Noteable.
(14) OpenAI. 聊天插件 (2023). OpenAI.
(15) OpenAI. ChatGPT 插件等候名单 (2023). OpenAI.
(16) S. Raschka, AI 领先 #8:最新的开源 LLM 和数据集 (2023). AI 领先.
(17) A. Engler, 欧盟与美国在 AI 监管上的分歧:跨大西洋比较及对齐步骤 (2023). 布鲁金斯学会。
理解深度学习优化器:动量、AdaGrad、RMSProp 与 Adam
了解神经网络中的加速训练技术的直观感受
·
关注 发布于 Towards Data Science ·8 min read·Dec 30, 2023
--
介绍
深度学习在人工智能领域迈出了巨大的一步。目前,神经网络在非表格数据(如图像、视频、音频等)上优于其他类型的算法。深度学习模型通常具有较强的复杂性,并且拥有数百万甚至数十亿个可训练的参数。因此,在现代时代,使用加速技术来减少训练时间是至关重要的。
在训练过程中最常见的算法之一是反向传播,它包括根据给定的损失函数调整神经网络的权重。反向传播通常通过梯度下降进行,该方法试图一步一步地将损失函数收敛到局部最小值。
事实证明,简单的梯度下降通常不是训练深度网络的首选,因为其收敛速度较慢。这激励了研究人员开发加速梯度下降的优化算法。
在阅读本文之前,强烈建议你了解指数移动平均的概念,它在优化算法中被使用。如果不了解,可以参考下面的文章。
理解梯度下降中使用的基本算法的逻辑
[towardsdatascience.com
梯度下降
梯度下降是最简单的优化算法,它计算损失函数相对于模型权重的梯度,并使用以下公式更新它们。

梯度下降方程。w 是权重向量,dw 是 w 的梯度,α 是学习率,t 是迭代次数。
为了理解为什么梯度下降收敛缓慢,让我们看下面的峡谷示例,其中一个包含两个变量的函数应该被最小化。

在峡谷区域中使用梯度下降的优化问题示例。起始点用蓝色表示,局部最小值用黑色表示。
峡谷是一个在一个维度上表面比另一个维度更陡峭的区域。
从图像中可以看出,起始点和局部最小值具有不同的横坐标,并且几乎相等的纵坐标。使用梯度下降寻找局部最小值可能会导致损失函数沿垂直轴慢慢振荡。这些跳跃发生是因为梯度下降不存储任何关于先前梯度的历史,使得每次迭代的梯度步长更加不确定。这个例子可以推广到更高维度。
因此,使用较大的学习率是有风险的,因为这可能导致不收敛。
动量
基于上述示例,期望在水平方向上使损失函数执行较大的步长,在垂直方向上执行较小的步长。这样,收敛速度会更快。这正是动量所实现的效果。
动量在每次迭代时使用一对方程:

动量法公式
第一个公式使用指数移动平均来处理梯度值dw。基本上,这是为了存储有关一组先前梯度值的趋势信息。第二个公式在当前迭代中使用计算出的移动平均值执行正常的梯度下降更新。α是算法的学习率。
动量法特别适用于上述情况。假设我们在每次迭代中计算了梯度,如上图所示。我们不仅简单地使用这些梯度来更新权重,还取几个过去的值,并在平均方向上进行更新。

Sebastian Ruder 在他的 论文 中简洁地描述了动量法的效果:“动量项对于梯度方向一致的维度增加,对于梯度方向改变的维度减少更新。因此,我们获得了更快的收敛速度和减少的振荡。”
因此,动量法执行的更新可能如下图所示。

动量法优化
在实践中,动量法通常比梯度下降法收敛更快。使用动量法时,使用较大学习率的风险也较小,从而加速了训练过程。
在动量法中,建议选择接近 0.9 的β值。
AdaGrad(自适应梯度算法)
AdaGrad 是另一种优化器,其动机是根据计算的梯度值调整学习率。在训练过程中,可能会出现权重向量的一个组件具有非常大的梯度值,而另一个组件具有极小的梯度值的情况。这种情况尤其发生在不常见的模型参数对预测的影响较小时。值得注意的是,对于频繁出现的参数,这类问题通常不会发生,因为模型在更新它们时会使用大量的预测信号。由于在梯度计算中考虑了大量信号的信息,梯度通常是适当的,并且代表了朝向局部最小值的正确方向。然而,对于稀有参数,这种情况并非如此,可能导致极大的不稳定梯度。相同的问题也可能出现在稀疏数据中,其中某些特征的信息过少。
AdaGrad 通过为每个权重组件独立调整学习率来解决上述问题。如果与某个权重向量组件相关的梯度较大,则相应的学习率会较小。相反,对于较小的梯度,学习率会较大。通过这种方式,Adagrad 处理了梯度消失和爆炸的问题。
在内部,Adagrad 累积所有先前迭代的梯度的元素平方 dw²。在权重更新期间,AdaGrad 不使用正常的学习率 α,而是通过将 α 除以累积梯度的平方根 √vₜ 来缩放它。此外,还向分母中添加了一个小的正数 ε,以防止潜在的除零错误。

AdaGrad 方程
AdaGrad 的最大优点是不再需要手动调整学习率,因为它在训练过程中会自行适应。然而,AdaGrad 也有负面的一面:学习率随着迭代次数的增加不断衰减(学习率总是被一个正的累积数除)。因此,该算法在最后几次迭代中趋于慢收敛,因为学习率变得非常低。

使用 AdaGrad 进行优化
RMSProp(均方根传播)
RMSProp 被详细阐述为对 AdaGrad 的改进,解决了学习率衰减的问题。与 AdaGrad 类似,RMSProp 使用一对方程,其权重更新完全相同。

RMSProp 方程
然而,与其为 vₜ 存储平方梯度的累积和 dw²,不如计算平方梯度 dw² 的指数移动平均。实验表明,由于指数移动平均,RMSProp 通常比 AdaGrad 收敛更快,因为它更重视最近的梯度值,而不是通过从第一次迭代开始简单地累积所有梯度来平均分配重要性。此外,与 AdaGrad 相比,RMSProp 的学习率并不总是随着迭代次数的增加而衰减,这使得它在特定情况下能够更好地适应。

使用 RMSProp 进行优化
在 RMSProp 中,建议选择接近 1 的 β。
为什么不简单地使用平方梯度 vₜ 而不是指数移动平均?
已知指数移动平均将更高的权重分配给最近的梯度值。这是 RMSProp 快速适应的原因之一。但如果我们只考虑每次迭代的最后一个平方梯度(vₜ = dw²),而不是使用移动平均,是否会更好?事实证明,更新方程将变成以下形式:

使用平方梯度代替指数移动平均时 RMSProp 方程的转换
如我们所见,结果公式与梯度下降中使用的公式非常相似。然而,我们现在使用的是梯度的符号,而不是正常的梯度值来进行更新:
-
如果 dw > 0,则权重 w 由 α 减少。
-
如果 dw < 0,则权重 w 由 α 增加。
总而言之,如果 vₜ = dw²,那么模型权重只能通过 ±α 来改变。尽管这种方法在某些情况下有效,但它的灵活性较差,算法对 α 的选择变得极其敏感,并且忽略了梯度的绝对大小,这可能导致方法收敛速度极其缓慢。该算法的一个积极方面是仅需一个位来存储梯度的符号,这在有严格内存要求的分布式计算中非常方便。
Adam(自适应矩估计)
目前,Adam 是深度学习中最著名的优化算法。从高层次来看,Adam 结合了动量和 RMSProp 算法。为了实现这一点,它分别跟踪计算梯度和平方梯度的指数移动平均值。

Adam 方程
此外,可以对移动平均进行偏差修正,以更精确地估算前几次迭代中的梯度趋势。实验表明,Adam 对几乎任何类型的神经网络架构适应良好,兼具动量和 RMSProp 的优点。

Adam 优化
根据 Adam 论文,超参数的良好默认值为 β₁ = 0.9,β₂ = 0.999,ε = 1e-8。
结论
我们已经研究了神经网络中的不同优化算法。Adam 被认为是动量(Momentum)和 RMSProp 的结合体,它在适应大规模数据集和深层网络方面表现最为出色。此外,它实现简单且内存需求少,使其在大多数情况下成为首选。
资源
除非另有说明,否则所有图片均由作者提供
理解 DeepMind 矩阵乘法
原文:
towardsdatascience.com/understanding-deepmind-matrix-multiplication-c8dc49687ce7
DeepMind 矩阵乘法在 NVIDIA V100、Tesla T4 上的表现,以及 FBHHRBNRSSSHK——这可不是我在随意输入字母!
·发表于Towards Data Science ·7 分钟阅读·2023 年 2 月 11 日
--

在之前的帖子中,我们学习了 Strassen 算法背后的数学,并编写了 Python 代码以在不同矩阵大小下进行测试。此外,我们了解到线性代数的圣杯是矩阵乘法的优化算法。通常,我们会把矩阵乘法代码看作是三个 for 循环:
def matmul(mat1, mat2, mat3):
r""" Function to multiply mat1 and mat2
returns mat3
Parameters
---------
mat1: np.array, matrix A
mat2: np.array, matrix B
mat3: np.array, empty matrix C
Return
------
mat3: np.array, matmul between A & B
"""
for i in range(mat1.shape):
for j in range(mat2.shape):
mat3[i][j] = 0.0
for k in range(mat3.shape):
mat3[i][j] += mat1[i][k]*mat2[k][j]
return mat3
因此,计算复杂度是O(n³)。Strassen 改进了这一计算,找到了以下关系:

图 1:Strassen 算法,如其论文“高斯消去法不是最优的”中提出的。[作者提供的图片]
该算法应用于块矩阵,总复杂度降至O(n²·⁸⁰⁸)。虽然 2.808 可能看起来改进很小,但我们看到对于 4096 大小的方阵,标准 numpy matmult大约需要 454.37 +/- 6.27 秒,而 Strassen 需要 31.57 +/- 1.01 秒,差异约为一个数量级。
我们看到矩阵乘法问题可以简化为张量积,通过张量操作:

图 2:矩阵乘法张量的三元组定义,如 Deep Mind 论文中所定义的。[作者提供的图片]
图 2 准确报告了矩阵乘法,表示为三元组,即三个元素。最小三元组数定义了计算乘积矩阵所需的最小操作数。 这个最小数就是张量的秩 R(t)。研究张量秩是一种有效的方式来寻找新的乘法算法,如 DeepMind 的论文中所述。
DeepMind 这些年来展示了如何通过机器学习方法、特别是强化学习(RL)方法,解决从理论到应用的数学问题。在此期间,他们还调整了 AlphaZero,以寻找矩阵乘法的最佳策略,结果就是 AlphaTensor。我认为现在定义一下我们可以从这篇论文中欣赏到的内容是值得的:
-
DeepMind 再次证明了 RL 可以成为解决复杂数学问题的强大助手;
-
AlphaTensor 找到了一个可以比 Strassen 算法更好的算法,用于 4 x 4 和 5 x 5 块矩阵的乘法;
-
此外,AlphaTensor 可以为特定的硬件需求找到最优解。正如他们在论文中所示,可能会有专门针对 TPU 和 GPU(论文中的 V100)的算法;
-
尽管这些可能不是最佳结果,但数学家们现在可以拥有一套全新的矩阵乘法方程,这可以作为寻找新的最优解的起点。
在 V100 GPU 上测试 AlphaTensor
幸运的是,DeepMind 在其GitHub 上提供了 AlphaTensor 的实现,可以进行测试,全部用 JAX 编写。代码专为 4 x 4 块矩阵的乘法设计,共有 47 次乘法,已在张量表示中报告,具体可见这里。
基准测试在 TPU 和 V100 GPU 上进行,他们测试了以下尺寸的矩阵乘法:8192, 10240, 12288, 14336, 16384, 18432, 20480,对于标准的[jax.numpy.dot](https://github.com/deepmind/alphatensor/blob/1949163da3bef7e3eb268a3ac015fd1c2dbfc767/benchmarking/utils.py#L155)乘法,Strassen 算法和 AlphaTensor 算法。
我分叉了代码,并进行了两个小的修改:
-
我正在打印每种方法的时间
-
由于我们对每个算法进行 10 次乘法运算,我将基准函数中的 平均次数修改为 10,而不是保持为 20。
在拥有 Google Cloud Console 的免费访问权限(仍然依赖$300 积分)的情况下,我在 GCP Compute Engine 中创建了一个虚拟机以测试 AlphaTensor,具体如下:
-
运行区域为
europe-west4 -
选择了带有 1 NVIDIA V100 GPU 的 GPU 机器
-
自动选择 CPU 平台
-
n1-standard-4机器类型(4 vCPU 和 15 GB RAM) -
我将操作系统镜像更改为:操作系统
Deep Learning on Linux和版本基于 Debian 10 的 Deep Learning VM 用于 TensorFlow Enterprise 1.15,带有 CUDA 11.0 M100,磁盘大小为50GB
总成本为每小时$1.94 — 因此要小心,不要让这台机器无限期运行。
一旦创建了机器,你可以直接通过 SSH 访问并使用git clone https://github.com/Steboss/alphatensor.git下载代码库。你需要设置 Python 环境并使用pip3 install -r alphatensor/benchmarking/requirements.txt -f [storage.googleapis.com/jax-releases/jax_cuda_releases.html](https://storage.googleapis.com/jax-releases/jax_cuda_releases.html)安装jax。最后,你可以通过python alphatensor.benchmarking.run_gpu_benchmark运行测试。

图 3:Jax 矩阵乘法、Strassen 算法与 AlphaTensor 在 V100 GPU 上的比较。 [图片来源:作者]。
图 3 显示了每个算法相对于矩阵大小的性能时间。我们可以看到,对于小矩阵尺寸,8192 和 10240,Strassen 相对于标准 JAX 实现了约 6.5%的提升,与 AlphaTensor 约 8.5%的提升相当。对于大矩阵,取得了优异的结果,因此对于 18432 的方阵,Strassen 的计算提升了15%(7.37 +/- 0.01),而 AlphaTensor 相对于 JAX(8.53 +/- 0.01)达到了16%的提升(7.31 +/- 0.01)。
如果我没有免费访问 V100 的权限怎么办?
我还在 Google Colab 上进行了另一个测试。在这种情况下,我们可以依赖 Tesla T4 GPU。虽然算法已经在 V100 上测试过,但值得调查其可迁移性并比较结果。与 V100 测试类似,我在 Google Colab 笔记本上复制了这些计算,去除了这些行

图 4:Jax 矩阵乘法、Strassen 算法与 AlphaTensor 在 Tesla T4 上的比较。 [图片来源:作者]。
正如你所看到的,我们在结果中有更多的变化,特别是在 16384 大小的矩阵中,我们可以看到所有算法实现了相同的性能时序。这并不准确,因为这可能是由于我们在 Google Colab 上无法管理的一些停机时间。表 1 总结了在 Tesla T4 上的所有发现:
表 1:在 Tesla T4 上对 JAX、Strassen 和 AlphaTensor 矩阵乘法算法的性能时序比较。
尺寸 12288 和 16384 是棘手的点,在这些点上我们相对于 JAX 标准乘法没有实际改进。另一方面,我们可以看到对于非常大的矩阵有改进,在 18432 时,Strassen 实现了 20% 的加速,AlphaTensor 实现了 22% 的加速。
这是故事的结局吗?
就在 DeepMind 论文 发布几天后,Manuel Kauers 和 Jakob Moosbauer 写了一篇精彩的论文回复,提出了 FBHHRBNRSSSHK-算法。该算法基于 DeepMind 的发现,使用 47 次乘法改进了 4x4 矩阵的计算,而不是 AlphaTensor 发现的 49 次乘法,5x5 矩阵的乘法次数减少到 95 次(而不是 AlphaTensor 提出的 96 次)。这是一个好消息,表明人类可以与 ML 算法有效合作。在这篇回复之后,Kauers 和 Moosbauer 发表了一篇出色的数学论文,名为 “矩阵乘法的翻转图”。在这篇论文中,作者展示了他们找到的进一步改进矩阵乘法的技术。特别是,这项技术的核心部分是从已知的矩阵乘法方案开始,并将其分组在一个图中:
我们定义了一个图,其中的顶点是正确的矩阵乘法方案,如果第二个方案可以通过某种变换从第一个方案获得,则在两个方案之间存在一条边。我们考虑了两种变换。一种叫做翻转,将给定方案转变为一个不同的方案,乘法次数相同,另一种叫做归约,将给定方案转变为一个乘法次数更少的方案。
所有矩阵乘法方案之间的导航是通过随机游走完成的。然后,翻转可以使用以下想法进行:
这个想法是从一个秩-1 张量中减去某些东西,然后将其添加到其他张量中。

图 5:翻转图示例,展示了所有算法如何作为图中的节点存在,翻转是无向边,归约变换是有向边。[图片来源:Manuel Kauers 和 Jakob Moosbauer]。
图 5 显示了论文中的图像,作者描绘了所有已知方案以及它们如何通过翻转和归约变换相互连接。因此,这并不是故事的结局,而是另一个很好的起点,将带来越来越高效的矩阵乘法算法。
结论
今天我们结束了对 DeepMind 论文《利用强化学习发现更快的矩阵乘法算法》的评审。这篇论文在矩阵乘法领域引起了极大的兴趣,并且显然带来了许多问题。
我们从论文中得出了 4 个主要观点:
-
强化学习是一个强大的工具,可以帮助我们解决数学问题。
-
AlphaTensor 已经找到了乘法 5x5 和 4x4 矩阵的新公式。
-
AlphaTensor 可以为特定硬件(例如 GPU 或 TPU)找到最佳解决方案。
-
这是新研究矩阵乘法问题的一个很好的起点。
然后,我们在 NVIDIA V100 GPU 和 Tesla T4 上运行了 AlphaTensor。尽管有一些起伏,我们可以看到总体上 AlphaTensor 提高了计算效率,在 V100 上提高了多达 16%,在 Tesla T4 上提高了 22% —— 尽管该算法并未针对这种 GPU 进行优化。
最后,我们看到这不仅仅是故事的结束,而是一个美好的新开始。一个例子是FBHHRBNRSSSHK-算法,它证明了如何通过纯数学形式进一步利用 DeepMind 的解决方案,找到新的、更高效的矩阵乘法技术。
支持我的写作:
[## 通过我的推荐链接加入 Medium - Stefano Bosisio
阅读 Stefano Bosisio 的每个故事(以及 Medium 上成千上万其他作家的故事)。为什么支持我?1)关于 AI 的文章……
stefanobosisio1.medium.com](https://stefanobosisio1.medium.com/membership?source=post_page-----c8dc49687ce7--------------------------------)
如果有任何问题或意见,请随时发邮件至:stefanobosisio1@gmail.com,或者直接在 Medium 上联系我。
理解 Power BI 中的不同缓存类型
原文:
towardsdatascience.com/understanding-different-cache-types-in-power-bi-f1e205f5956e
你知道 Power BI 依赖于两种不同的缓存类型吗?在本文中,我们将揭示它们在实际中的工作原理
·发布在Towards Data Science ·9 分钟阅读·2023 年 3 月 22 日
--

图片由作者提供
你多少次遇到过这样的情况?第一次打开报告时,渲染需要一些时间,但一旦你在其他报告页面之间来回切换,那个页面的渲染速度会显著提高!
是的,我知道,我们都经历过多次。这是因为 Power BI 会缓存数据,并且在第一次运行之后响应会更快。
听起来很简单,对吧?其实并不像那样简单,本文将尝试揭示 Power BI 中不同缓存类型的奥秘。
开始前推荐阅读: 由于我会提到 Power BI 的一些内部架构组件,即存储引擎和公式引擎,我建议你首先阅读这篇文章以了解这两者之间的区别。你还应该了解这两个引擎在数据检索过程中的不同角色。这非常重要,因为本文其余部分将假定你已经了解存储引擎和公式引擎的关键特性。
缓存类型一览
首先从高层次的视角解释两种主要的缓存类型,然后我们将深入探讨每种类型的细微差别。
Power BI 中的视觉缓存
我们先从一个非常简单的例子开始。我将使用一个示例 Contoso 数据库进行所有演示。

图片由作者提供
我有一个簇状柱形图视觉对象,显示 Contoso 数据库中每个品牌的总销售额。还有一个品牌名称的切片器。让我们在 Power BI Desktop 中打开性能分析器,并选择切片器中的一个值:

图片由作者提供
正如你可能注意到的,公式引擎生成了一个 DAX 查询来检索有关 Contoso 品牌销售的数据,存储引擎需要 14 毫秒来实际返回这些数据。由于我们使用的是导入存储模式,并且我在 Power BI Desktop 中,因此数据存储在 Analysis Services 的本地实例中。
现在让我们将切片器值更改为 Litware:

图片由作者提供
再次,和之前的情况一样,发生了相同的工作流程。现在,如果我在切片器中切换回 Contoso,会发生什么呢?

图片由作者提供
现在事情变得有趣了!完全没有 DAX 查询,也没有“复制查询”选项,这使我们无法获取查询并在例如 DAX Studio 中更详细地分析它—这个选项被禁用了!这意味着,公式引擎没有生成查询,这个视觉对象的数据是从缓存中提供的。在这种情况下,我们谈论的是视觉缓存。
如果我再次在切片器中选择 Litware,情况也会如此。然而,一旦我点击顶部的刷新视觉对象选项……

图片由作者提供
尽管再次检索 Contoso 的数据,在这种情况下,视觉缓存已被清除,公式引擎再次生成了一个 DAX 查询。
显然,在这个超级基础的示例中,很难发现两个场景之间的显著性能差异。但实际上,我们通常会应用更复杂的逻辑,从缓存中检索查询结果通常比一遍遍运行相同的查询要快得多。
如果我现在从 DAX Studio 连接到我的 Analysis Services 本地实例,并打开所有查询,一旦我点击刷新视觉对象,所有查询都将被 DAX Studio 捕获:

图片由作者提供
从这里,我将双击第一个查询并在 DAX Studio 中执行它

图片由作者提供
这个包含查询结果的表格将被报告缓存。每当我们的视觉对象请求相同的结果时,数据可以从缓存中提供。
那 Power BI 服务呢?
好吧,在上面的示例中,我们解释了 Power BI 如何在使用 Power BI Desktop 时缓存查询结果,即在 Analysis Services 的本地实例中。合理的问题是:一旦我们转到 Power BI 服务,这个缓存“东西”还会有效吗?
答案是——是的!在这种情况下,它是通过你的网页浏览器完成的。然而,请记住,视觉缓存的范围是特定的 Power BI 会话。我们稍后会更详细地解释这如何运作。
Power BI 中的数据缓存
数据缓存是 Power BI 中的另一种缓存类型。与之前场景中的缓存发生在个别报告用户级别不同,数据缓存是在更通用的级别——Analysis Services 表格模型的级别上操作的。
如果你完成了作业并阅读了我一开始推荐的文章,你可能已经知道 VertiPaq 以压缩的方式将我们的 Contoso 数据存储在内存中。
那么,当我们要求 Power BI 计算 Contoso 品牌的总销售额时,究竟发生了什么?公式引擎生成并执行 DAX 查询,但随后 Storage Engine 将 DAX 转换为一种类似 SQL 的特殊语言,称为 xmSQL,以物理地从表格模型中提取数据。

作者提供的图片
对于每个 xmSQL 查询,都有一个特殊的数据结构,称为 datacache,存储在内存中。
如果我在 DAX Studio 中开启服务器计时并运行前一个示例中捕获的 DAX 查询:

作者提供的图片
如你所见,查询结果是从缓存中检索的(我们有一个 Storage Engine 查询,而那个 Storage Engine 查询是使用了缓存)。这意味着实际上,我们并没有真正查询 Analysis Services 模型。为了确认这一点,我会在 DAX Studio 中打开 Cache 选项卡,你会看到在第 1 行,这个查询实际上是从缓存中提供的,而不是来自 Analysis Services 的内部数据结构。

作者提供的图片
注意!有关数据缓存需要记住的事项
现在,有至少两个重要的考虑因素需要记住,关于数据缓存。
首先,这可能会让你觉得你的查询运行得很快,尽管这可能不是真的。假设你正在排查一个性能差的查询,并利用 DAX Studio 深入了解后台发生了什么。你第一次运行查询时,它需要 2000 毫秒才能返回结果。
然后你进行了一些小的更改,再次运行查询——现在它在 100 毫秒内渲染完成。耶!你已经开始想着:“嗯,为什么大家都说 DAX 很难?我只是重新排序了代码中的行,它的速度快了 20 倍……”
是的,没错!第二天早上,一位报告用户再次抱怨相同的报告视图渲染缓慢。
你可能忘记在运行“改进版”DAX 计算之前清除数据缓存了。

作者提供的图片
即使在这个极其简单的计算中,从缓存中获得的结果集的场景也比查询分析服务内部数据结构的场景快了将近 4 倍(请注意第 1 行,其中包含 Internal 作为子类)。
其次,由于数据缓存驻留在内存中,这意味着它确实有有限的资源。简单来说,并不是所有的查询都可以在分析服务中缓存! 根据查询检索的数据量,可能发生只有一部分(或没有)存储引擎查询可以从缓存中检索到。
让我展示一下这在现实中是如何表现的。我在 Power BI 报告中添加了一些更多的内容。假设我想计算每一年中我们有多少个不同的订单。我创建了一个简单的 DAX 度量来计算这个:
Distinct Orders = DISTINCTCOUNT(FactOnlineSales[SalesOrderNumber])
接下来,我想将这个值与去年值进行比较,因此我将创建另一个度量来计算去年的不同订单数量:
Distinct Orders PY = CALCULATE(
[Distinct Orders],
SAMEPERIODLASTYEAR(DimDate[Datekey])
)

作者图片
让我们切换到 DAX Studio 并检查生成的数据查询:

作者图片
由于在运行这个查询之前我没有清除数据缓存,所以所有 10 个存储引擎查询都从缓存中检索到了!总的来说,结果在 7 毫秒内返回。
我现在将清除缓存并重新运行相同的查询:

作者图片
结果从 7 毫秒变成了超过 1 秒!这就是我告诉你在运行相同的查询之前清除数据缓存至关重要的原因。
现在让我们检查一下如果我们将范围包括个别日期而不是年份会发生什么:

作者图片
在没有清除缓存的情况下,这个查询花费了超过 16 秒来返回结果!
简要解释一下这里发生了什么:当查询执行时,存储引擎检索数据并在一个名为 datacache 的特殊结构中物化中间查询结果。这个 datacache 最终被公式引擎消耗,然后最终结果集才返回到报告中。现在,根据许多不同的因素,有时所有必要的数据可以在一个 datacache 中物化,但有时可能发生数据量过多,导致存储引擎创建多个 datacache。
在我们的例子中,我们可以看到这些查询都非常快——仅需几毫秒——但查询数量很多。准确地说:2195 个查询!
现在,如果要物化的数据不是很多,会发生什么呢?我们能否“帮助”引擎再次利用数据缓存功能?
我将向我的报告中添加一个日期切片器,仅包含 2010 年 1 月 1 日之后的日期:

作者图片
让我们看看 DAX Studio 服务器时间现在显示什么:

作者提供的图片
显然,查询运行得更快,因为引擎需要处理的 datacaches 数量较少。但我们仍然无法充分利用数据缓存。
现在,我们只包含 2010 年 10 月 1 日之后的日期,并在 DAX Studio 中检查查询:

作者提供的图片
这次,我们命中了缓存,差异巨大!这个查询只用了 134 毫秒就返回了结果。
总结来说,从数据缓存的角度来看,查询扫描和物化的数据量非常重要。
结论
缓存查询结果是关键的性能优化概念之一,这不仅与 Power BI 和表格模型相关,而是普遍适用的。
当我们检查 Power BI 缓存类型时,你需要了解两种不同的缓存:
-
视觉缓存(或报告缓存)— 数据在特定的 Power BI 会话范围内被缓存,无论是本地机器上的 Power BI Desktop 会话,还是 Power BI 服务中的会话。
-
数据缓存 — 数据在 Analysis Services 实例的范围内被缓存,而不管打开了多少 Power BI 会话。
感谢阅读!
理解梯度提升:数据科学家的指南
原文:
towardsdatascience.com/understanding-gradient-boosting-a-data-scientists-guide-f5e0e013f441

图片由 Midjourney 提供
·发表于 Towards Data Science ·阅读时间 10 分钟·2023 年 2 月 7 日
--
梯度提升机(GBM)是机器学习和数据科学领域最重要的进展之一,它使我们这些从业者能够使用模型集成来解决许多领域特定的问题。尽管这一工具在scikit-learn和xgboost等 Python 包中广泛可用,但作为数据科学家,我们应该始终深入探讨模型的理论和数学,而不是将其视为黑箱。在这篇博客中,我们将深入探讨以下领域:
-
GBM 的不同支持概念
-
分步图解以重建 GBM
-
优点和缺点

让我们深入了解一下——图片来自 GIPHY
梯度提升的基本原理
1. 弱学习者和集成学习
弱学习者和集成学习是使梯度提升有效的两个关键概念。弱学习者是一个模型,它的表现仅比随机猜测稍好。与许多其他弱学习者结合,可以形成一个强大的集成模型,进行准确预测。
文字过多,过于复杂
好吧,假设我们正在和两个朋友一起玩 10000 片拼图。(他们一定是很棒的朋友才愿意报名)我们每个人负责拼接四个象限中的一个。虽然我们可能只能自己解决拼图的一小部分,但当我们作为团队合作时,我们可以迅速完成整个拼图。
在这种情况下,我们每个人都是一个弱学习者,四个人的团队就是集成模型。就像我们关注我们的象限一样,集成模型中的每个弱学习者擅长根据某些特征和特性进行预测。当我们四个弱学习者聚在一起分享对某个数据点是否属于某个象限的看法时,集成模型能够提供比我们单独预测更准确的结果。
本质上,弱学习者和集成学习是集体智慧的力量。正如古老的谚语所说:
整体大于部分之和。
2. 加法模型
在梯度提升算法中,弱学习者模型是迭代地添加到集成中的。它几乎看起来像是泰勒近似,其中最终值是使用粗略估计通过一系列修正项进行修正。由于每个弱学习者都会贡献一个修正项,这使得 GBM 在添加模型时非常灵活,当预测结果表明过拟合时可以进行调整。

e^x 的泰勒展开从 1 开始并进行迭代修正 — 作者图片

GBM 从基线开始,通过加法模型进行迭代修正 — 作者图片
3. 损失函数
当我们说一个模型经过迭代修正或改进时,我们必须明白并不是所有的改进都是相同的。例如,我和我哥哥的考试分数从 10 分提高到 20 分和从 80 分提高到 90 分,满分为 100 分。尽管我们都提高了 10 分,但我们能说这些改进是一样的吗?
这就引出了损失函数。
损失函数只是测量两个值之间差异的数学方法。在机器学习的背景下,它可以用作衡量预测值与实际值之间差异的评分卡,从而评估模型性能。损失越大,模型表现越差。
在我们上面的例子中,如果我们考虑百分比改进,从 10 到 20 的 10 分提高显得更加重要,或者如果我们考虑我们的分数离满分 100 有多远,这种提高也可能显得不那么重要。

考试示例中的损失函数 — 作者图片
根据我们定义模型成功的标准,我们可能会选择不同的损失函数。一些常见的选择如下:
-
MSE(均方误差): 通常用于回归模型
-
MAE(均方绝对误差)
-
Log Loss(交叉熵): 通常用于分类模型
4. 梯度下降
仍以相同的例子为例,如果我们希望在提高分数方面最有效率,我们可能需要在最容易获得额外积分的地方更加努力学习(例如,对像我这样心不在焉的人来说拼写错误)。一旦征服了这一点,我们就继续处理下一个最容易获得额外积分的项目。如此反复,直到达到 100 分。
这正是梯度下降的工作原理!
从技术上讲,梯度下降是一种机制,旨在通过反复沿着函数值减少最快的方向移动,来探索函数的最小值。在机器学习的背景下,通过最小化损失函数,我们试图识别出一组最佳的模型参数,以便做出准确的预测。
不要过多地跑题到梯度下降的博客文章中,这种技术的一个重要问题是算法可能会收敛到一个次优的最小值。

最小值 — 作者提供的图像
请放心。有不同的方法来处理这些情况。示例包括以下内容:
-
学习率:这调整了我们每一步的移动量。其直觉是,如果它不是绝对的低谷,我们可能会因为步长过大而错过它。
-
动量方法:这种方法考虑了我们之前的步骤。算法会将前一步的一个部分(即动量)传递给下一步,以平滑振荡。(即,动量将 90 度的右转变成一个平缓的曲线)
梯度下降基本上是一种根查找算法。如果你对根查找算法感兴趣,可以看看我关于根查找算法的博客!
高效的根查找算法在 Python 中
在 Python 中实现高效的根查找算法和优化
towardsdatascience.com
在我们深入理解梯度提升算法的工作原理之前,先简单介绍一下自己。
如果你喜欢这篇文章,你也可以通过下面的附属链接来支持我订阅 Medium。这是一个我发现很多有趣阅读的平台。即使你不打算订阅,你也可以通过点赞来支持我和我的创作。
## 通过我的推荐链接加入 Medium - Louis Chan
阅读 Louis Chan 的每一个故事(以及 Medium 上成千上万的其他作家的故事)。你的会员费用直接支持…
感谢你容忍这个插件。这是来自萨摩耶的一次可爱眨眼,作为我的感谢。
图片来自 GIPHY
现在回到正题!
梯度提升算法
从头开始的算法
假设我们想用一个1 棵树的梯度提升机来预测某个地区的公寓租金:

图片由作者提供
我们首先计算平均租金,即 688。这将是我们的基准模型,即预测租金为该区域的平均值。我们也可以将其理解为目标变量的简单预测。

图片由作者提供
然后计算租金与平均值之间的差异——这是我们的弱学习器会尝试迭代最小化的间隙。在我们的例子中,我们只有一个弱学习器。

图片由作者提供
现在是时候建立我们的弱学习器组来预测来自平均值的残差了。让我们从决策树开始:

图片由作者提供
如果叶节点中有多个项目,则预测结果应为这些项目的平均值。

图片由作者提供
现在,让我们尝试使用这棵新决策树进行预测,并将其与我们之前获得的基准进行比较。计算预测值的公式如下:

图片由作者提供
通过将学习率作为预测残差的修正因子,我们应用了梯度下降的概念,即采用逐步方法来最小化损失函数。在我们的案例中,“损失函数”是我们的残差。

图片由作者提供
基于第一棵树的预测值,我们可以计算新的残差。

图片由作者提供
看!通过将我们创建的树“添加”到我们之前仅使用平均值做出的基线预测中,残差已经变得更小了!
如果我们要训练一个合适的梯度提升模型,我们需要重复拟合大量基于树的模型(通常超过 1,000 棵树)以形成加法模型。
一般来说,在拟合了 1,000 棵树后,我们可以使用以下公式来计算最终预测值:

图片由作者提供
超参数
根据上述例子,我们可以总结出具有以下超参数的 GBM 的特点:
-
加法模型数量: 集成中的弱学习器总数。
-
学习率: 决定弱学习者对最终结果贡献的修饰符。这也决定了模型调整梯度的速度。
-
每个模型的最大深度: 每个弱学习者的最大深度。较浅的弱学习者意味着在拟合单个弱学习者时更具内存效率,但可能无法捕捉变量之间更复杂的交互。
-
弱学习者终端节点中的最小观察数: 较小的值通常适用于不平衡的数据集,而较大的值适用于更平衡的数据集。
梯度提升的优点
-
强大的预测性能: 尽管这并非使用梯度提升的自然优势,但回顾来说,梯度提升在 Kaggle 上的各种比赛中常常获胜。这可以归功于梯度提升结合了许多较小模型的特点,并利用群体智慧进行最终预测,而不是试图将所有数据模式拟合到一个模型中。
-
灵活性: 梯度提升是一种可以应用于回归或分类问题的模型类型,使用不同的弱学习者(不一定是决策树)、损失函数和数据类型(有序、连续、类别等)。
-
易于获取: 不论你使用 R (gbm, xgboost, lightgbm), Julia (GradientBoost), 还是 Python (sklearn, xgboost, lightgbm, catboost),都有许多模块可以将梯度提升应用于你的数据问题。
-
可解释性: 相比于神经网络,梯度提升机器可以说是模型拓扑中在复杂性和可解释性之间平衡较好的。如果你想了解更多关于解释机器学习模型的内容,这里有一篇博文,我深入探讨了 SHAP 的工作原理。
你的 SHAP、TreeSHAP 和 DeepSHAP 综合指南
towardsdatascience.com
梯度提升的缺点
-
计算复杂度: 拟合/训练梯度提升机涉及通常拟合超过 1,000 个小弱学习者。虽然保持弱学习者较小可以减少训练时间,但当我们开始扩展模型时,这种时间仍会很快累积。
-
过拟合: 另一个拥有这么多小弱学习者的可能缺点是过拟合的风险。减轻这种风险的一种方法是预留一个验证集或使用交叉验证来评估模型的表现。
-
超参数调优: 这可能是实际上任何机器学习模型中最少提及的缺点之一。虽然模型具有一些预定义的超参数,例如树的最大深度=3,但这些假设是基于特定的统计观察做出的。这些假设是否代表了我们的数据问题,那是完全不同的故事。因此,我们不仅要关注超参数的“是什么”,还要关注“为什么”。
结论
就这样,一个逐步教程数据科学指南,讲解梯度提升的工作原理。
不要止步于此
像任何领域一样,数据科学是一个需要不断打磨思维和获取新知识的领域,以便在众人中脱颖而出。如果你想了解更多关于我对各种数据科学相关主题的看法,可以从下面的列表中选择:
最后但绝对不容忽视的是,如果我遗漏了或误解了任何关键内容,请随时在评论区留言或通过 LinkedIn 给我发私信。让我们一起保持知识流动,共同在这一领域进步!
[## Louis Chan — 领先的 GCP 数据与机器学习工程师 — 副总监 — KPMG UK | LinkedIn
有雄心、好奇心强且富有创造力的个人,对知识各领域之间的相互联系有着强烈的信念以及…
参考文献
-
《梯度提升的温和介绍》,作者Brownlee(2018 年)
-
《使用 Python 的梯度提升与 XGBoost》,作者Raschka(2017 年)
-
《统计学习简介》,作者James等(2013 年)
-
《集成学习》,作者Alpaydin(2010 年)
-
《提升方法简介》,作者Schapire(2003 年)
理解机器学习中的梯度下降
原文:
towardsdatascience.com/understanding-gradient-descent-for-machine-learning-246e324c229
使用 Python 深入探讨批量、随机和小批量梯度下降算法
·发布于Towards Data Science ·阅读时间 14 分钟·2023 年 5 月 21 日
--

图片由Lucas Clara提供,来源于Unsplash
梯度下降是一种流行的优化算法,广泛应用于机器学习和深度学习模型,如线性回归、逻辑回归和神经网络。它通过迭代使用一阶导数来最小化成本函数,通过更新模型系数(用于回归)和权重(用于神经网络)。
在本文中,我们将深入探讨梯度下降的数学理论,并探索如何使用 Python 进行计算。我们将检查包括批量梯度下降、随机梯度下降和小批量梯度下降在内的各种实现,并评估它们在不同测试案例中的效果。
在阅读本文的同时,你可以查看我 GitHub 上的Jupyter Notebook以获取完整的分析和代码。
在深入探讨梯度下降之前,让我们首先了解损失函数。
什么是损失函数?
损失或成本这两个术语可以互换使用,用来描述预测中的误差。损失值表示预测值与实际值的差异,损失函数将来自多个数据点的所有损失值汇总为一个单一的数字。
从下图中可以看出,左侧的模型具有高损失,而右侧的模型具有低损失,并且更好地拟合了数据。

高损失与低损失(蓝线)相对于黄色回归线的对比。
损失函数(J)用作预测算法的性能测量工具,预测模型的主要目标是最小化其损失函数,这由模型参数的值(即 θ0 和 θ1)决定。
例如,线性回归模型通常使用平方损失来计算损失值,而均方误差是平均所有平方损失的损失函数。

平方损失值(L2 损失)和均方误差(MSE)
线性回归模型在后台通过多次迭代来优化其系数,以达到尽可能低的均方误差。
什么是梯度下降?
梯度下降算法通常用山的类比来描述:
⛰ 想象你站在山顶,视野有限,你想要到达地面。在下坡时,你会遇到斜坡,并通过较大或较小的步伐通过它们。一旦你到达几乎平坦的斜坡,你就会知道你已经到达最低点。 ⛰
从技术上讲,梯度指的就是这些斜率。当斜率为零时,可能表示你已经到达了函数的最小值或最大值。

就像在山的类比中一样,GD 通过在梯度的相反方向上重复迈步来最小化起始损失值,从而减少损失函数。
在曲线上的任何一点,斜率的陡峭程度可以通过切线来确定——一条与该点相切的直线(如上图中的红线)。类似于切线,损失函数上某一点的梯度是相对于参数计算的,并且会朝相反方向迈出小步以减少损失。
总结一下,梯度下降的过程可以分解为以下步骤:
-
选择模型参数的起始点。
-
确定成本函数相对于参数的梯度,并通过迭代步骤不断调整参数值以最小化成本函数。
-
重复步骤 2,直到成本函数不再减少或达到最大迭代次数。
我们可以检查之前定义的成本(损失)函数的梯度计算。虽然我们正在使用具有截距和系数的线性回归,但这种推理可以扩展到包含多个变量的回归模型。

具有 2 个参数的线性回归函数、成本函数和目标函数

相对于模型参数计算的偏导数
💡 有时,已达到的点可能只是一个局部最小值或平坦区域。在这种情况下,模型需要继续迭代,直到达到全局最小值。不幸的是,达到全局最小值并不保证,但通过适当的迭代次数和学习率,我们可以提高成功的几率。

使用梯度下降时,需要注意可能会停留在局部最小值或平坦区域的挑战。为避免这种情况,选择适当的迭代次数和学习率至关重要。我们将在接下来的部分中进一步讨论这一点。
Learning_rate是定义学习步长大小的梯度下降超参数。可以使用超参数调整技术进行调节。
- 如果
learning_rate设置得过高,可能会导致跳跃,从而产生比起始点更大的损失值。高learning_rate可能会导致梯度下降发散,使其不断获得更高的损失值,阻止其找到最小值。

示例案例:高学习率导致梯度下降(GD)发散
- 如果
learning_rate设置得过低,可能会导致计算过程漫长,梯度下降需要经历多个梯度计算轮次才能收敛并发现最小损失值。

示例案例:低学习率导致梯度下降(GD)收敛所需时间过长
学习步长的值由曲线的斜率决定,这意味着随着我们接近最小点,学习步长变得越来越小。
当使用低学习率时,进展会比较平稳,而高学习率可能会导致指数级进展或停滞在低点。

图像改编自cs231n.github.io/neural-networks-3/
现在我们将探讨梯度下降算法的三种不同实现。
1. 批量梯度下降
批量梯度下降是实现梯度下降的最广泛使用的方法。它涉及在每次迭代时计算整个数据集相对于模型参数(如回归系数)的梯度。
让我们看一个例子 🔍
首先,生成一个截距为 5,系数为 4 的数据集,并添加少量高斯噪声。请参见下面生成数据的散点图。
# Create example data set
x = 2 * np.random.rand(100,1)
y = 5 + 4 * x + np.random.randn(100,1)
# Make dataframe
regression_df = pd.DataFrame({'x':x.flatten(), 'y':y.flatten()})
# Plot
sns.lmplot(x='x', y='y', data=regression_df, fit_reg=False)

生成的数据的散点图
以下函数通过利用指定的学习率和迭代次数来执行批量梯度下降。最初,模型的系数(m)和截距(b)都设置为 0.5。在每次迭代中,通过计算预测值和实际值之间的差异来确定误差。然后,算法通过提取梯度来更新 m 和 b,梯度也会乘以学习率。循环持续进行,直到达到指定的迭代次数,结果的损失值和模型参数被存储在params和loss中。
# Function to compute batch gradient descent
def batch_gradient_descent(x, y, learning_rate, iterations):
'''
Batch Gradient Descent implication. Inputs data,
learning rate, and number of iterations. Random m and
b values are given to start the iteration. Returns optimal
model parameters as well as historical loss values.
'''
m, b = 0.5, 0.5
params, loss = [], []
N = len(x)
for iteration in range(iterations):
func = y - (m*x + b)
# Updating m and b
m -= learning_rate * (-2 * x.T.dot(func).sum() / N)
b -= learning_rate * (-2 * func.sum() / N)
params.append((m, b))
loss.append(mean_squared_error(y, (m*x + b)))
return m, b, params, loss
使用上述函数,我们现在将测试不同的学习率并评估性能。
1️⃣ 让我们设置 learning_rate=0.01,进行 1000 次迭代,并绘制损失函数图。
# Find optimal parameters using BGD
m, b, params, loss = batch_gradient_descent(x, y, learning_rate=0.01,
iterations=1000)
# Predict y values using optimal parameters
y_predicted = m*x + b
# Print optimal parameters and final loss value
print("m:", m, "b:", b)
print("MSE:", mean_squared_error(y, y_predicted))
# Plot actual vs predicted value plot as well as historical loss values
plot_regression(x, y, y_predicted, params=params,
title="Batch Gradient Descent with Learning Rate=0.01")

使用学习率为 0.01 和 1000 次迭代的批量梯度下降
m: 4.27, b: 4.88, MSE: 0.95
执行时间:367 毫秒
在上述图中,首先可以看到每次迭代生成的回归线(紫色),并注意它如何在大约 100 次迭代后逐渐接近最优。在第二个图中,损失函数在每次迭代后显示,在前 50-100 次迭代内损失显著减少。
2️⃣ 现在让我们设置 learning_rate=0.001,使用相同的迭代次数,并绘制损失函数图。
# Find optimal parameters using BGD
m, b, params, loss = gradient_descent(x, y, learning_rate=0.001,
iterations=1000)
# Predict y values using optimal parameters
y_predicted = m*x + b
# Print optimal parameters and final loss value
print("m:", m, "b:", b)
print("MSE:", mean_squared_error(y, y_predicted))
# Plot actual vs predicted value plot as well as historical loss values
plot_regression(x, y, y_predicted, params=params,
title="Batch Gradient Descent with Learning Rate=0.01")
m: 4.67, b: 4.31, MSE: 1.05
执行时间:522 毫秒

使用学习率为 0.001 和 1000 次迭代的批量梯度下降
当学习率降低时,模型需要更多时间才能收敛,并且最终损失值较高。第一个图显示了回归线下方更深的紫色阴影,这意味着在许多迭代中远离了最优线。此外,请注意将学习率从 0.01 减少到 0.001 导致执行时间从 367 毫秒增加到 522 毫秒。
3️⃣ 让我们使用相同的迭代次数设置 learning_rate=0.1 并绘制损失函数图。
# Find optimal parameters using BGD
m, b, params, loss = gradient_descent(x, y, learning_rate=0.1,
iterations=1000)
# Predict y values using optimal parameters
y_predicted = m*x + b
# Print optimal parameters and final loss value
print("m:", m, "b:", b)
print("MSE:", mean_squared_error(y, y_predicted))
# Plot actual vs predicted value plot as well as historical loss values
plot_regression(x, y, y_predicted, params=params,
title="Batch Gradient Descent with Learning Rate=0.01")
m: 4.23, b: 4.93, MSE: 0.95
执行时间:214 毫秒

使用学习率为 0.1 和 1000 次迭代的批量梯度下降
当使用高学习率时,模型表现出快速收敛。在第一个图中,只有少数几次迭代中回归线远离最优线。在第二个图中,损失在前几次迭代中急剧下降,并保持稳定直到所有迭代完成。此外,执行时间在三次试验中最低,为 214 毫秒。
🚨 对于每次参数更新计算整个训练数据集的梯度在大型数据集上可能计算开销较大。幸运的是,随机梯度下降或小批量随机梯度下降可以帮助解决这个问题。
2. 随机梯度下降
实现梯度下降的另一种方法是使用随机梯度下降。它特别适用于大型训练数据集,在这些数据集上,批量梯度下降可能需要过长的计算时间。
随机梯度下降通过仅计算一个随机数据点的损失值来更新模型系数,而不是计算训练数据集中的每个数据点并聚合。
随机选择数据点对于防止陷入类似值(如数据中的簇)是至关重要的。学习率和迭代次数也是关键因素,类似于批量梯度下降。
以下函数通过利用指定的学习率和迭代次数来执行随机梯度下降。初始时,模型的系数(m)和截距(b)设置为 0.5。在每次迭代中,使用np.random计算一个随机选择的数据点的误差。函数的其余步骤与批量梯度下降中的步骤相同。
# Function to compute stochastic gradient descent
def stochastic_gradient_descent(x, y, learning_rate, iterations):
'''
Stochastic Gradient Descent Implication. Inputs data,
learning rate, and number of iterations. Random m and
b values are given to start the iteration. Index of the
random data is updated for each iteration. Returns
optimal model parameters as well as historical loss values.
'''
m, b = 0.5, 0.5 # initial parameters
params, loss = [], [] # lists to store learning process
for iteration in range(iterations):
# Sample a random index for loss calculation
indexes = np.random.randint(0, len(x), 1)
xi = np.take(x, indexes)
yi = np.take(y, indexes)
N = len(xi)
func = yi - (m*xi + b)
# Updating parameters m and b
m -= learning_rate * (-2 * xi.dot(func).sum() / N)
b -= learning_rate * (-2 * func.sum() / N)
params.append((m, b))
loss.append(mean_squared_error(y, m*x+b))
return m, b, params, loss
让我们尝试使用与批量梯度下降相同的学习率和迭代次数。
1️⃣ 让我们设置learning_rate=0.01并进行 1000 次迭代,然后绘制损失函数。
# Find optimal parameters using SGD
m, b, params, loss = stochastic_gradient_descent(x, y, learning_rate=0.01,
iterations=1000)
# Predict y values using optimal parameters
y_pred = m*x + b
# Print final loss value
print("MSE:", mean_squared_error(y, y_pred))
# Plot actual vs predicted value plot as well as historical loss values
plot_regression(x, y, y_pred, params=params, title="Stochastic Gradient Descent with Learning Rate=0.01")
MSE: 0.97
执行时间: 325 毫秒

学习率为 0.01 和 1000 次迭代的随机梯度下降
注意到在相同学习率下,与批量梯度下降相比,我们将执行时间从 367 毫秒减少到 325 毫秒。然而,均方误差(MSE)从 0.95 增加到了 0.97。
2️⃣ 现在让我们设置learning_rate=0.001并进行 1000 次迭代,然后绘制损失函数。我不会包括代码片段,因为它与之前的示例相同。不过,如果需要,可以参考文章的源代码。
MSE: 1.03
执行时间: 253 毫秒

学习率为 0.001 和 1000 次迭代的随机梯度下降
注意到使用学习率为 0.001 的批量梯度下降,在 522 毫秒内得到了 1.05 的损失值。然而,使用相同学习率的随机梯度下降只需 253 毫秒就得到了 1.03 的损失值。
3️⃣ 现在让我们设置learning_rate=0.1并进行 1000 次迭代,然后绘制损失函数。
MSE: 1.27
执行时间: 237 毫秒

学习率为 0.1 和 1000 次迭代的随机梯度下降
使用高学习率的随机梯度下降时,计算引入了损失值的波动。如第一个图所示,存在许多迭代在最佳回归线的上方和下方。
总体而言,批量梯度下降在收敛到最小值方面优于随机梯度下降,因为随机梯度下降往往会在全局最小值附近徘徊。然而,如果仔细选择学习率,随机梯度下降也可以在较短的时间内达到类似的损失值(有时甚至更好)。
3. 迷你批量梯度下降
迷你批量梯度下降是批量和随机梯度下降之间的一个有用的中间选项。它不是为整个数据集或单个观察值计算梯度,而是将训练数据分成较小的批次,并为每个批次计算梯度。
通过结合批量和随机方法,迷你批量 GD 在计算速度上优于批量模型,并且提高了随机模型的准确性。
以下函数执行迷你批量梯度下降,使用给定的学习率、设定的迭代次数和选定的批量大小。通过指定批量大小,我们可以确定每次梯度计算中包含的数据点数量。例如,在随机梯度下降中,批量大小为 1,梯度是针对单个数据点计算的。然而,如果我们设置批量大小为 10,那么梯度将会针对 10 个数据点进行计算,然后合并。
# Function to compute mini-batch gradient descent
def mini_batch_gradient_descent(x, y, learning_rate, iterations, batch_size):
'''
Mini-Batch Gradient Descent implication. Inputs data,
learning rate, number of iterations and batch size.
Random m and b values are given to start iteration.
Index of the random data is updated for each iteration
per batch. Returns optimal model parameters as well
as historical loss values.
'''
m, b = 0.5, 0.5
params, loss = [], []
for iteration in range(iterations):
indexes = np.random.randint(0, len(x), batch_size)
xi = np.take(x, indexes)
yi = np.take(y, indexes)
N = len(xi)
func = yi - (m*xi + b)
# Updating parameters m and b
m -= learning_rate * (-2 * xi.dot(func).sum() / N)
b -= learning_rate * (-2 * func.sum() / N)
params.append((m, b))
loss.append(mean_squared_error(y, m*x+b))
return m, b, params, loss
使用批量大小为 10 的情况下,测试相同的 3 种情况,learning_rate=0.01、learning_rate=0.001和learning_rate=0.1。
1️⃣ learning_rate=0.01 并进行 1000 次迭代。
# Find optimal parameters using MBGD
m, b, params, loss = mini_batch_gradient_descent(x, y, learning_rate=0.01,
iterations=1000, batch_size=10)
# Predict y values using optimal parameters
y_pred = m*x + b
# Print final loss value
print("MSE:",mean_squared_error(y, y_pred))
# Plot actual vs predicted value plot as well as historical loss values
plot_regression(x, y, y_pred, params=params,
title="Mini-Batch Gradient Descent with Learning Rate=0.01")
MSE: 0.96
执行时间:224 毫秒

使用学习率为 0.01 和 1000 次迭代的迷你批量梯度下降
你是否注意到,使用 0.01 的学习率时,我们能够在 367 毫秒内通过批量 GD 达到 MSE 为 0.95,而使用随机 GD 则在 325 毫秒内得到 MSE 为 0.97?此外,通过迷你批量梯度下降,我们能够在仅 224 毫秒内达到 MSE 为 0.96。
2️⃣ learning_rate=0.001 并进行 1000 次迭代。
MSE: 1.04
执行时间:485 毫秒

使用学习率为 0.001 和 1000 次迭代的迷你批量梯度下降
值得注意的是,通过使用相同的学习率,我们成功地将执行时间从 522 毫秒减少到 485 毫秒,相比批量梯度下降。此外,均方误差(MSE)从 1.05 降到了 1.04。
3️⃣ learning_rate=0.1 并进行 1000 次迭代。
MSE: 0.97
执行时间:239 毫秒

使用学习率为 0.1 和 1000 次迭代的迷你批量梯度下降
通过实现迷你批量,消除了在随机梯度下降历史损失值中观察到的不规则性,并且模型收敛非常快。
结论
在这篇文章中,我们深入探讨了梯度下降算法及其三种主要实现——批量、随机和小批量。通过一系列实验,我们比较了这三种实现的不同学习率和迭代次数。我们发现,批量梯度下降能够获得最准确的模型参数,导致最低的损失值,尽管需要的时间最长。另一方面,随机梯度下降虽然准确度较低,但计算速度最快。最后,我们探索了小批量梯度下降,发现它在合理的执行时间内达到了中等的准确度。
希望你喜欢阅读有关梯度下降的内容,并觉得这篇文章有用!✨
🍓 如果你喜欢阅读这样的文章并希望支持我的写作,你可以考虑 成为 Medium 会员! Medium 会员可以全面访问所有作者的文章,如果你使用 我的推荐链接,你将直接支持我的写作。
🍓 如果你已经是会员并且有兴趣阅读我的文章,你可以 订阅以获取通知 或 在 Medium 上关注我。如果你有任何问题或建议,请告诉我。✨
我推荐在阅读完这篇文章后进一步阅读的附加资源:
- 了解通过这两种最常见的方法进行超参数调整的概念。
以及深入探讨如何将它们结合起来
[towardsdatascience.com
2. 理解线性回归模型及其假设。
如何通过图形和数值输出构建和检查回归模型的质量
[towardsdatascience.com
参考文献
-
头图由 Lucas Clara 提供,来源于 Unsplash
-
学习率图表改编自
cs231n.github.io/neural-networks-3/ -
所有其他图片由作者提供
理解群体顺序测试
原文:
towardsdatascience.com/understanding-group-sequential-testing-befb35cec07a
因果数据科学
如何进行有效的实验,包括窥视和提前停止。
·发表于 Towards Data Science ·15 分钟阅读·2023 年 12 月 26 日
--

封面,图片由作者提供
A/B 测试是因果推断的黄金标准,因为它们允许我们在最小假设下做出有效的因果陈述,这要归功于随机化。实际上,通过随机分配处理(药物、广告、产品等),我们可以比较结果(疾病、公司收入、客户满意度等)在受试者(患者、用户、客户等)之间的差异,并将结果的平均差异归因于处理的因果效应。
实施 A/B 测试通常不是瞬间完成的,尤其是在在线环境中。用户通常是实时或分批处理的。在这些情况下,可以在数据收集完成之前多次查看数据。这种现象被称为窥视。虽然窥视本身并不成问题,但在窥视时使用标准测试程序可能会导致误导性结论。
解决窥视问题的方法是相应地调整测试程序。最著名和传统的方法是所谓的序列概率比检验 (SPRT),该方法可以追溯到第二次世界大战。如果你想了解更多关于这个测试及其迷人的历史,我写了一篇博客文章。
编辑描述
towardsdatascience.com
顺序概率比检验(SPRT)的主要优点是它在给定目标置信水平和功效的情况下,保证了最小的样本量。然而,SPRT 的主要问题是它可能会无限期地继续。这在有截止日期和预算限制的应用环境中是一个非无关的问题。在这篇文章中,我们将探索一种替代方法,允许在数据收集的任何点进行任何数量的中间窥探:分组顺序测试。
模拟
让我们从一些模拟的数据开始。为了保持代码尽可能简洁,我将抽象化实验设置,直接处理来自正态分布的数据。然而,我们可以将其视为标准 A/B 测试中平均治疗效果的分布。正态分布是基于中心极限定理的渐近近似。
在生成数据之前,我从src.theme导入相关库和我的绘图主题。
from src.theme import *
import numpy as np
import pandas as pd
import scipy as sp
假设真实的数据生成过程确实是一个正态分布,均值为μ=1,标准差为σ=5.644。在 A/B 测试的背景下,我们可以将其视为一个正的平均治疗效果,标准差比效果大 5 倍以上。
mu = 1
sigma = 5.644
我们希望建立一个双侧检验,具有95%的置信度和80%的功效。因此,我们的目标假阳性错误率将是α=0.05,目标假阴性错误率将是β=0.2。
alpha = 0.05
beta = 0.2
我们现在可以计算实验所需的样本量,假设平均治疗效果为1,标准差为5.664。由于我们已经抽象化了两个组的比较,功效计算的公式是

功效计算公式,图片来源于作者
其中zs是标准正态分布的分位数,计算在1-α/2和1-β处。
ppf = sp.stats.norm(0, 1).ppf
cdf = sp.stats.norm(0, 1).cdf
z_alpha = ppf(1 - alpha/2)
z_beta = ppf(1 - beta)
N = int((2 * sigma * (z_alpha + z_beta) / mu)**2)
print(f"Number of obserations needed: {N}")
Number of obserations needed: 1000
我们需要N=1000个观察值来达到目标置信水平95%和功效80%。
我们现在可以绘制模拟数据。由于我们将经常比较不同模拟结果,我们绘制K=10,000个N=1,000数据点的序列。
K = 10_000
np.random.seed(2)
obs = np.random.normal(mu, sigma, size=(N, K))
我们现在准备调查窥探和分组顺序测试。
窥探
如果我们在实验结束前对数据进行窥探会发生什么?
假设例如我们每50个观察值查看一次数据,从100开始。一个原因可能是数据以批次到达,或者我们每天一开始工作时就进行窥探。
N_peek = np.arange(100, N+1, 50, dtype=int)
N_peek
array([ 100, 150, 200, 250, 300, 350, 400, 450, 500, 550, 600, 650, 700, 750, 800, 850, 900, 950, 1000])
观察数据本身当然不是问题per-se。然而,我们可能会被诱使得出结论,基于我们观察到的内容。假设我们的幼稚实验平台持续报告最新的平均值、标准差和置信区间,其中置信区间计算方式如下:

在不窥视的情况下的置信区间,图片由作者提供
其中 n 是样本数量,μ̂ₙ 是经过 n 个样本后的估计样本均值,σ̂ₙ 是经过 n 个样本后的估计标准差,α 是显著性水平,z 是标准正态分布的 1-α/2 分位数。
def select_alpha_naive(n, N, N_peek, alpha):
return alpha
假设我们决定在获得一个显著结果后停止实验。
让我们计算在每个窥视点上观察到的置信区间。
def compute_intervals(select_alpha, obs, N_peek, alpha=0.05, **kwargs):
# Compute rolling mean and standard deviation
N, K = np.shape(obs)
ns = np.reshape(np.arange(1, N+1), (-1, 1))
means = np.cumsum(obs, axis=0) / ns
stdevs = np.sqrt(np.cumsum((obs - means)**2, axis=0) / ns)
# Compute intervals at each peeking time
df_intervals = pd.DataFrame({"k": range(K)})
df_intervals["rejected_0"] = False
df_intervals["rejected_1"] = False
df_intervals["length"] = max(N_peek)
for t, n in enumerate(N_peek):
df_intervals[f"mean{n}"] = means[n-1, :]
df_intervals[f"width{n}"] = ppf(1 - select_alpha(n, N, N_peek, alpha, **kwargs)/2) * stdevs[n-1, :] / np.sqrt(n)
df_intervals[f"lowerb{n}"] = means[n-1, :] - df_intervals[f"width{n}"]
df_intervals[f"upperb{n}"] = means[n-1, :] + df_intervals[f"width{n}"]
df_intervals[f"coverage{n}"] = (df_intervals[f"lowerb{n}"] <= mu) & (df_intervals[f"upperb{n}"] >= mu)
df_intervals["rejected_0"] = df_intervals["rejected_0"] | (df_intervals[f"lowerb{n}"] >= 0) | (df_intervals[f"upperb{n}"] <= 0)
df_intervals[f"power{n}"] = df_intervals["rejected_0"]
df_intervals["rejected_1"] = df_intervals["rejected_1"] | ~df_intervals[f"coverage{n}"]
df_intervals[f"falsep{n}"] = df_intervals["rejected_1"]
df_intervals["length"] = np.minimum(df_intervals["length"], n) * df_intervals["rejected_0"] + max(N_peek) * (1 - df_intervals["rejected_0"])
return df_intervals
dfi_naive = compute_intervals(select_alpha_naive, obs, N_peek)
这些平均值和置信区间随时间的变化是什么样的?在下图中,我绘制了数据收集的累积平均值,以及每次窥视时的置信区间。
plot_peeking(dfi_naive, obs)

累积平均效应和窥视置信区间,图片由作者提供
如我们所见,前七次观察数据时,置信区间穿过零线,因此我们未能拒绝零均值的零假设。我用橙色标出了这些置信区间。然而,在第八次观察数据时,即450次观察,置信区间没有穿过零线,因此我们拒绝零假设,结束实验。
这个过程的问题与多重假设检验非常相似:我们为单次数据观察构建置信区间,因此做出单个决定,但实际上我们做出了多个决定。事实上,我们已经决定在达到 450 次观察之前七次不停止实验,而在 450 次观察时停止了实验。
窥视和早期停止的后果是什么?让我们看看如果我们重复这个实验多次会发生什么。我们现在将绘制100个不同模拟在三个不同时间点的置信区间:在200、400 和 600 次观察后。请注意,这些分别对应于第3次、第7次和第11次数据窥视。
N_plot = [200, 400, 600]
我们首先要检查的是覆盖率:置信区间是否真的覆盖了真实的处理效应,正如它们所应该的那样?我标出了那些没有覆盖的置信区间。
plot_intervals(dfi_naive, N_plot, "coverage")

使用幼稚测试的 100 次模拟的覆盖情况,图片由作者提供
看起来我们的覆盖率在每个时间点都是正常的。我们分别在100次模拟中有2、6和2次不覆盖真实治疗效应(μ=1)。这是可以预期的,因为我们的置信水平是5%,因此我们期望100个区间中平均有5个不覆盖真实治疗效应。
接下来,我们研究功效:我们估计量在确实存在效应时检测效应的能力。记住,功效总是相对于效应大小的。然而,我们在进行功效计算时使用了真实效应,因此我们期望实验达到80%的预期功效。
请注意,由于我们进行观察,我们会在测试显著时拒绝原假设并停止实验。因此,在我们案例中,特定时间点的功效是拒绝原假设的概率,无论是该测试还是任何之前的测试。
plot_intervals(dfi_naive, N_plot, "power")

使用朴素测试的 100 次模拟中的功效,图像由作者提供
如我们所见,在200次观察时,我们已经在100次模拟中的72次拒绝了没有效应的原假设(μ=0),接近目标功效的80%。然而,在400次观察时,我们在超过100次模拟中的80次拒绝了原假设,表明我们可以缩短实验时间。
到目前为止,一切似乎都很顺利:我们的区间覆盖了真实效应,并且比预期更快地拒绝了原假设。让我们检查所有的观察阶段和10,000次模拟。我们还要检查第三个指标:假阳性错误率。为了计算这个,我们将原假设更改为μ=1并检查我们拒绝它的频率。同样,由于我们进行了多次观察,重要的是特定观察阶段或任何之前的阶段的拒绝率。
在下图中,我绘制了每个观察阶段10,000次模拟中的覆盖率、功效和错误拒绝率。
plot_coverage_power(dfi_naive)

使用朴素测试的10,000次模拟中的估计器性能,图像由作者提供
覆盖率似乎符合预期。功效在大约250次观察后已达到80%以上,确认了我们之前的见解。然而,错误拒绝率远高于目标的5%。这意味着当原假设为真时,我们的拒绝频率高于应有的水平。
我们最后要检查的是实验是否确实在平均上更短,以及短了多少。让我们计算平均实验长度,以观察次数为单位。
print(f"Average length: n = {dfi_naive.length.mean():.0f}", )
Average length: n = 177
平均而言,我们只需177次观察即可得出结论!然而,由于较高的错误拒绝率,这些可能是错误的结论。
我们可以怎么解决这个问题?我们需要构建考虑到我们在顺序中进行多次测试的置信区间。
Alpha 修正
在本节中,我们将探讨第一组校正,这些校正会修改用于计算置信区间的α值,以考虑窥视和提前停止。
Bonferroni 校正
由于窥视问题类似于多重假设测试,我们可以从应用相同的解决方案开始。
处理多重假设测试的最简单方法是所谓的 Bonferroni 校正。这个想法很简单:根据观察次数按比例减少显著性水平α。特别是,我们不再对每次观察使用相同的α,而是使用

Bonferroni 的α校正,作者提供的图像
其中 P 是我们计划窥视的次数。
def select_alpha_bonferroni(n, N, N_peek, alpha):
P = len(N_peek)
return alpha / P
Bonferroni 校正在覆盖率方面表现如何?让我们绘制三个窥视阶段的置信区间:在收集了200、400 和 600 次观察后。
dfi_bonferroni = compute_intervals(select_alpha_bonferroni, obs, N_peek)
plot_intervals(dfi_bonferroni, N_plot, "coverage")

Bonferroni 校正下的覆盖率在 100 次模拟中,作者提供的图像
覆盖率看起来很棒!只有一次在n=200时一个区间没有覆盖真实值μ=1。
虽然这在一开始可能让人感到安慰,但实际上应该引起警觉。事实上,使用显著性水平α=0.05时,我们期望的覆盖率是95%。更高的覆盖率很可能会以功效为代价。我们来看看。
plot_intervals(dfi_bonferroni, N_plot, "power")

Bonferroni 校正下的功效在 100 次模拟中,作者提供的图像
在200次观察下测试功效不足,而在400次观察下功效非常接近目标的80%。在600次观察时,我们几乎达到了 100% 的功效。
让我们绘制每个窥视阶段的覆盖率、功效和假阳性率,基于K=10,000次模拟。
plot_coverage_power(dfi_bonferroni)

Bonferroni 校正在 10,000 次模拟中的估计性能,作者提供的图像
覆盖率很好,功效在大约450次观察后达到目标,虚假拒绝率始终低于5%的目标。那么平均实验长度呢?
print(f"Average length: n = {dfi_bonferroni.length.mean():.0f}", )
Average length: n = 317
平均实验长度为317次观察,高于朴素测试程序,但仍明显低于未经窥视所需的1000次观察。
一切看起来都不错,甚至可能过于完美。确实,可能还有改进的空间。考虑到如此高的覆盖率和低的虚假拒绝率,结果表明我们可以缩短置信区间,从而提高功效和降低实验长度,同时不低于95%的覆盖率或超过5%的虚假拒绝率。如何?
Bonferroni 校正有两个缺点。首先,它不是为序列测试设计的,而是为多重假设测试设计的。其次,即使是多重假设测试,它也被认为是非常保守的。
校正
Bonferroni 的序贯检验校正的第一个版本是Pocock (1977)。其思想是考虑检验的序贯特性,这会在检验统计量之间生成非常特定的相关结构。由于这一洞察,Pocock 能够使用一个校正的α值,该值介于天真的α和 Bonferroni 的α/P之间。比 Bonferroni 更大的α意味着更高的功效,同时保持高覆盖率和低假阳性率。这些值是通过一个数值算法得出的,该算法以显著性水平α和总窥探次数P作为输入。
Pocock 校正的问题在于它未能充分利用检验的序贯特性,因为置信区间随时间保持不变。O’Brien 和 Fleming (1979) 提出了使用时间变化的α校正。他们的想法是使置信区间的宽度不仅适应显著性水平α和总窥探次数P,还适应每次窥探p。
然而,这些程序的主要缺点是它们要求提前规划窥探次数。这通常是不切实际的,因为窥探是一个固有的自发过程,来源于数据批量的大小、管理层的压力或实验者的好奇心。
当窥探未提前规划时,我们该怎么办?
组序贯检验
Lan 和 DeMets (1983)注意到,窥探中重要的不是你窥探了多少,而是你何时窥探。组序贯检验(GST)的主要思想是允许在任何时间点进行窥探,并在数据收集过程中对窥探时间点的显著性水平进行校正,t = n/N。
组序贯检验的动态部分是所谓的α花费函数,它决定了如何根据窥探时间t来校正显著性水平α。在本文的其余部分,我们将回顾两个α花费函数,分别近似于Pocock (1977)和O’Brien 和 Fleming (1979)的校正。
GST Pocock 近似
第一个α花费函数是Pocock (1977)的近似,其表达式为

Pocock 的α花费函数用于组序贯检验,图片来源于作者
请注意,当观察比例t=n/N达到整个样本(t=1)时,Pocock 的校正会收敛到原始显著性水平α。
def select_alpha_gst_pocock(n, N, N_peek, alpha):
t = n / N
return alpha * np.log(1 + (np.exp(1) - 1) * t)
让我们看看使用 Pocock 的α花费函数的组序贯检验如何工作。
dfi_gst_pocock = compute_intervals(select_alpha_gst_pocock, obs, N_peek)
plot_coverage_power(dfi_gst_pocock)

使用 Pocock 的 GST 进行 10,000 次模拟的估计器性能,图片来源于作者
正如我们之前提到的,覆盖率趋向于目标覆盖率,观察次数增加。实验似乎也比使用 Bonferroni 校正的功效更强,但如果实验运行时间过长,虚假拒绝率会超过5%的目标。
平均实验长度怎么样?
print(f"Average length: n = {dfi_gst_pocock.length.mean():.0f}", )
Average length: n = 229
平均实验长度确实低于 Boferroni,平均为229次观察,而不是317次。
GST O’Brien & Fleming 近似
第二个α支出函数是O’Brien, Fleming (1979)的近似,其表达式为

O’Brien 和 Fleming 的组序贯测试α支出函数,图像由作者提供
其中Φ是标准正态分布的累积分布函数(CDF),而ρ是一个自由参数,通常默认为ρ=1。
def select_alpha_gst_obrien_fleming(n, N, N_peek, alpha, rho=1):
t = n / N
return 4 - 4 * cdf(ppf(1 - alpha/4) / t**(rho/2))
让我们看看使用 O’Brien 和 Fleming 近似的组序贯测试在K=10,000次模拟中的表现。
dfi_gst_obrien_fleming = compute_intervals(select_alpha_gst_obrien_fleming, obs, N_peek)
plot_coverage_power(dfi_gst_obrien_fleming)

使用 O’Brien 和 Fleming 的 GST 在 10,000 次模拟中的估计性能,图像由作者提供
看起来 O’Brien 和 Fleming 的近似比 Pocock 的更保守,具有更高的覆盖率和较低的功效,但虚假拒绝率更接近5%目标。
print(f"Average length: n = {dfi_gst_obrien_fleming.length.mean():.0f}", )
Average length: n = 414
平均实验长度实际上高于 Boferroni,平均为414次观察,而不是317次。然而,通过在校正公式中降低参数ρ,可以减少这一数值。以ρ=0.5为例,它对应于Wang, Tsiatis (1987)的校正。
dfi_gst_obrien_fleming_05 = compute_intervals(select_alpha_gst_obrien_fleming, obs, N_peek, rho=0.5)
print(f"Average length: n = {dfi_gst_obrien_fleming_05.length.mean():.0f}", )
Average length: n = 303
确实,使用较低的ρ,我们已将平均实验长度从414减少到303次观察。
α支出权衡
在总结之前,值得看看窥视的权衡。我们引入了一种方法,使我们能够在任何时候进行有效推断。然而,我们应该窥视吗?如果是的话,窥视多少次呢?
在下图中,我绘制了使用 Pocock 近似的组序贯测试的测试性能,当我们增加窥视频率从50次观察到10次观察时。
N_peek_10 = np.arange(30, N+1, 10, dtype=int)
dfi_gst_10 = compute_intervals(select_alpha_gst_pocock, obs, N_peek_10)
plot_coverage_power(dfi_gst_10)

每 10 次观察下的 GST 估计性能,图像由作者提供
如我们所见,覆盖率基本不受影响,而功效和虚假拒绝率有所增加。平均实验长度也从229减少到188次观察。
print(f"Average length: n = {dfi_gst_10.length.mean():.0f}", )
Average length: n = 188
如果我们减少窥视频率会怎样?在下图中,我绘制了每 200 次观察时的结果。
N_peek_200 = np.arange(200, N+1, 200, dtype=int)
dfi_gst_200 = compute_intervals(select_alpha_gst_pocock, obs, N_peek_200)
plot_coverage_power(dfi_gst_200)

每 200 次观察下的 GST 估计性能,图像由作者提供
从图中我们可以看到相反的结果:功效和虚假拒绝都减少了。另一方面,现在我们平均需要 311 次观测才能得出结论,而不是 229 次。
print(f"Average length: n = {dfi_gst_200.length.mean():.0f}", )
Average length: n = 311
结论
在本文中,我们探讨了 组顺序测试,一种在 A/B 测试中任何次数和任何时刻进行窥探时做出有效推断的程序。我们还看到窥探并不是免费的。主要的 权衡 是,我们窥探的次数越多,实验停止得越早,但虚假拒绝率也越高。
文章中至少还有几个我没有提及的主题,以避免过长。第一个是 偏倚。顺序测试容易引入偏倚,因为提前停止可能是由于低方差或大效应。由于后者,顺序测试往往会导致对治疗效果的 高估。这种现象通常称为 赢家的诅咒,通常发生在研究的功效不足时,即在早期窥探阶段。一种解决方案是设计一个 β 花费 函数。
我没有涵盖的第二个主题是所谓的 因无效而停止。在本文的例子中,如果我们得到统计上显著的估计值,我们就提前停止实验。然而,窥探也可以告知另一种停止规则:因为继续测试变得极不可能产生显著结果而停止。
我没有涵盖的最后一个主题是如何进行 功效分析。在上面的例子中,我们在一开始就进行了功效分析,假设没有窥探。然而,鉴于我们知道我们会窥探,我们本可以预期需要一个更小的样本。一个密切相关的主题是 最佳窥探。一旦决定要窥探,应该何时进行?
参考文献
-
Lakens, Pahlke, Wassmer (2021). 组顺序设计:教程
-
Lan, DeMets (1983). 临床试验的离散顺序边界
-
Spotify (2023). 选择顺序测试框架
相关文献
- 实验、窥探和最佳停止
代码
你可以在这里找到原始的 Jupyter Notebook:
[## Blog-Posts/notebooks/group_sequential_testing.ipynb at main · matteocourthoud/Blog-Posts
我的 Medium 博客文章的代码和笔记本。通过创建一个来贡献 matteocourthoud/Blog-Posts 开发…
感谢你的阅读!
非常感谢! 🤗 如果你喜欢这篇文章并希望看到更多内容,可以考虑 关注我。我每周发布一次关于因果推断和数据分析的内容。我尽量保持文章简洁但准确,始终提供代码、示例和模拟。
此外,一个小小的 免责声明:我写作是为了学习,因此错误是常有的事,尽管我尽力做到最好。请在发现错误时告诉我。我也欢迎对新话题的建议!
理解直方图和核密度估计
原文:
towardsdatascience.com/understanding-histograms-and-kernel-density-estimation-6f9a1f09f960
对直方图和 KDE 的深入探索
·发表于 Towards Data Science ·26 分钟阅读·2023 年 12 月 18 日
--

直方图是可视化数值数据频率的图形。它通常用于数据科学和统计学中,以对数据集的分布进行初步估计。核密度估计(KDE)是一种通过从未知分布中抽取的随机样本来估计随机变量的概率密度函数(PDF)的方法。因此,它允许我们基于从中抽样的有限数据集推断总体的概率密度。KDE 常用于信号处理和数据科学,是估计概率密度的一个重要工具。本文讨论了直方图和 KDE 背后的数学和直觉,以及它们的优缺点。它还演示了如何从头开始在 Python 中实现 KDE。本文中的所有图形均由作者创建。
概率密度函数
设 X 为连续随机变量。X 在区间 [a, b] 内取值的概率可以写作

其中 f(x) 是 X 的概率密度函数(PDF)。X 的累积分布函数(CDF)定义为:

因此,X 的累积分布函数(CDF),在 x 处的值是 X 取小于或等于 x 的值的概率。使用方程 1,我们可以写道:

使用微积分基本定理,我们可以证明

这意味着 X 的概率密度函数(PDF)可以通过对其累积分布函数(CDF)关于 x 的导数来确定。直方图是估计数据集 PDF 的最简单方法,如我们在下一节所示,它利用方程 1 达到这一目的。
直方图
在列表 1 中,我们创建了一个双峰分布,作为两个正态分布的混合,并从该分布中抽取了大小为 1000 的随机样本。这里我们混合了两个正态分布:

因此,正态分布的均值分别是 0 和 4,方差分别是 1 和 0.8。混合系数为 0.7 和 0.3,因此这些分布的混合 PDF 为:

列表 1 绘制了图 1 中的 PDF 和样本。
# Listing 1
np.random.seed(2)
sample_size = 100
mu1 = 5
sigma1 = 1
mu2 = 9
sigma2 = 0.8
dist1 = norm.rvs(loc = mu1, scale = sigma1, size=sample_size)
dist2 = norm.rvs(loc = mu2, scale = sigma2, size=sample_size)
x = np.arange(0, 13, 0.01)
pdf1 = norm.pdf(x, loc = mu1, scale = sigma1)
pdf2 = norm.pdf(x, loc = mu2, scale = sigma2)
mix_coeffs = [0.7, 0.3]
pdf = mix_coeffs[0]*pdf1+mix_coeffs[1]*pdf2
data = np.zeros((sample_size, 2))
data[:, 0] = dist1
data[:, 1] = dist2
random_idx = np.random.choice(np.arange(2),
size=(sample_size,), p=mix_coeffs)
sample = data[np.arange(sample_size), random_idx]
plt.figure(figsize=(8, 5))
num_bins = 40
plt.plot(sample, np.zeros(len(sample)), marker='|', markersize=15,
linestyle='None', alpha=1)
plt.plot(x, pdf, color='red', linewidth=2, label="PDF")
plt.xlim([0, 13])
plt.xlabel('$x$', fontsize=16)
plt.ylabel('Probability density', fontsize=16)
plt.legend(loc='best', fontsize=14)
plt.show()

图 1
现在假设我们只有样本数据集,而双峰分布是未知的,我们如何从数据集中估计该分布的 PDF?我们可以做的最简单的事情是绘制这个数据集的直方图。列表 2 绘制了在列表 1 中生成的随机样本的直方图,并附上了该分布的 PDF。尽管它没有平滑的表面,但它模仿了分布的 PDF 的形状。
# Listing 2
plt.figure(figsize=(8, 5))
num_bins = 20
plt.hist(sample, density=True, bins = num_bins,
edgecolor='black', linewidth=1)
plt.plot(x, pdf, color='red', linewidth=2, label="PDF")
plt.xlim([0, 13])
plt.xlabel('$x$', fontsize=16)
plt.ylabel('Probability density', fontsize=16)
plt.legend(loc='best', fontsize=14)
plt.savefig('fig2.png', dpi=300, bbox_inches='tight')
plt.show()

图 2
让我们来看一下直方图的基本构造。利用导数定义和方程 2 及方程 3,我们可以写出:

因此,X在点x的 PDF 可以通过此方程估计。直方图使用相同的概念来估计基于随机样本的随机变量的 PDF。假设X是一个未知分布的随机变量,而随机变量X₁、X₂、… X_n通常表示可以从该分布中抽取的随机样本(即X₁、X₂、… X_n是独立同分布的)。在抽取随机样本后,我们用x₁、x₂、… x_n来表示样本中的观察值。因此,x₁、x₂、… x_n形成了我们想要生成直方图的数据集,每个xᵢ是这个数据集中的一个观察值。
我们首先需要定义计算直方图的区间[a, b]。我们将此区间划分为长度为h=(b−a)/k的k个等长子区间。这些定长子区间也称为bins。因此,我们将有以下子区间:

请注意,前k-1 个子区间是半开区间,只包括其左端点,因为我们不希望它们在端点处重叠。现在我们可以使用方程 4 来估计属于子区间Bᵢ的测试点x的 PDF:

这是一种近似,因为我们不再取极限。但是我们如何计算

我们可以简单地使用我们的随机样本。我们知道我们的数据集有n个观察值。那么

如果我们用nᵢ表示随机样本在Bᵢ中的观察次数,那么我们有:

我们将 PDF 的这种估计称为 f^(x),所以我们可以将之前的方程写成:

要绘制直方图,我们需要知道 a、b 和 h 的值。我们首先计算箱子的端点,对于每个箱子 Bᵢ,我们在该箱子的中点上绘制一个宽度为 h 且高度等于 nᵢ / (nh) 的矩形条(图 3)。

图 3
我们知道

所以,PDF 在整个空间上的积分必须等于 1。

这意味着 PDF 是标准化的,PDF 曲线下的面积等于 1。直方图的面积等于图 2 中矩形的面积之和:

其中 k 是箱子的数量。这是有意义的,因为直方图是分布 PDF 的一种估计器。
在 matplotlib 中绘制直方图时,我们使用 hist() 函数。请注意,我们应将参数 density 设置为 True 来获得 PDF 的估计值。否则,我们将得到一个 频率直方图。在频率直方图中,每个箱子的高度等于:

使用方程 6 绘制的直方图也称为 密度直方图。在本文中,直方图指的是密度直方图。
从清单 2 中可以看出,我们只设置了箱子的数量,并未提供 a 和 b 的值。
plt.hist(sample, density=True, bins = num_bins,
edgecolor='black', linewidth=1)
这是因为 sample 的最小值和最大值分别自动用作 a 和 b。我们可以通过以下代码片段绘制相同的直方图:
bin_width = (max(sample) - min(sample)) / num_bins
plt.hist(sample, density=True,
bins = np.arange(min(sample), max(sample) + bin_width, bin_width),
edgecolor='black', linewidth=1)
基于方程 6,f^(x) 的值依赖于 n 和 h。但它也依赖于箱子的起始点(方程 5 中的 a)。所以,如果我们改变 a 的值,它将改变直方图的形状。清单 3 绘制了清单 1 中定义的样本的直方图,其中 a=1.8 和 b=12,并将其与图 2 中 a 和 b 设置为样本的最小值和最大值的直方图进行了比较。两个直方图具有相同的 h。结果见图 4。
# Listing 4
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10,10))
plt.subplots_adjust(hspace=0.4)
num_bins = 20
bin_width = (max(sample) - min(sample)) / num_bins
ax1.hist(sample, density=True,
bins = np.arange(min(sample), max(sample) + bin_width, bin_width),
edgecolor='black', linewidth=1)
ax1.plot(x, pdf, color='red', linewidth=2, label="PDF")
ax1.set_xlim([0, 13])
ax1.set_ylim([0, 0.45])
ax1.set_xlabel('$x$', fontsize=16)
ax1.set_ylabel('Probability density', fontsize=16)
ax1.legend(loc='best', fontsize=14)
ax1.set_title('a={}, b={}'.format(np.round(min(sample), 3),
np.round(max(sample), 3)), fontsize=16)
ax2.hist(sample, density=True,
bins = np.arange(1.15, 12 + bin_width, bin_width),
edgecolor='black', linewidth=1)
ax2.plot(x, pdf, color='red', linewidth=2, label="PDF")
ax2.set_xlim([0, 13])
ax2.set_ylim([0, 0.45])
ax2.set_xlabel('$x$', fontsize=16)
ax2.set_ylabel('Probability density', fontsize=16)
ax2.legend(loc='best', fontsize=14)
ax2.set_title('a={}, b={}'.format(1.15, 12), fontsize=16)
plt.plot()

图 4
如你所见,通过改变 a 的值,直方图的形状会发生变化。因此,直方图的形状依赖于用于生成它的随机样本。我们不知道从中抽取随机样本的总体的最大值和最小值。因此,如果我们基于随机样本的最小值和最大值绘制直方图,则样本将影响直方图的形状。然而,直方图应模仿总体的 PDF 形状,而总体的 PDF 对所有随机样本都是相同的。
直方图作为机器学习模型
记住,样本的直方图应该模仿从样本中提取的总体的概率密度函数(PDF)的形状。事实上,我们可以将直方图视为一种机器学习模型,它基于我们从总体中提取的样本来学习总体的 PDF。这是一个非常简单的模型,因为它只有一个超参数,即箱宽(h)。这个超参数控制了模型的复杂性。(超参数是用来配置机器学习模型的参数,其值由用户在训练过程开始前确定。相比之下,模型的其他参数值是通过训练确定的。)
像其他任何机器学习模型一样,我们可以计算直方图的偏差和方差。假设 X 是一个具有未知分布的随机变量,而随机变量 X₁、X₂、… X_n 通常表示可以从该分布中抽取的随机样本。抽取随机样本后,我们将其中的观测值记为 x₁、x₂、… x_n(这里每个 xᵢ 是 Xᵢ 的一个具体值),这些观测值形成用于生成直方图的数据集。现在我们想创建一个学习 X 分布的 PDF 的直方图。记住,从方程 6 中,我们有:

但这个方程式适用于一个特定的数据集。如果我们想将其应用于随机样本 X₁、X₂、… X_n,我们需要用随机变量 Nᵢ 替换 nᵢ。这是因为 nᵢ 的实际值取决于 X₁、X₂、… X_n 的值。因此,得出:

请注意,虽然 X₁、X₂、… X_n 的值在每次我们抽取的样本中可能会变化,但样本大小(n)、箱宽(h)和子区间 Bᵢ 保持不变。所以,它们是随机变量。现在如果 Nᵢ 是一个随机变量,它的分布是什么?要回答这个问题,我们首先需要计算每个 Xₖ 位于区间 Bᵢ 的概率。假设

其中 a 是一个常数端点,并且不随样本的变化而改变。由于每个 Xₖ 都是从相同的 X 分布中抽样的,所以 Xₖ 属于 Bᵢ 的概率可以写成:

现在我们有 n 个随机变量,每个随机变量位于 Bᵢ 的概率是 pᵢ,我们想知道 Nᵢ 的分布,其中 Nᵢ 表示这些随机变量中位于 Bᵢ 的总数。这类似于在 n 次掷硬币中找出正面总数的问题,其中出现正面的概率是 pᵢ。我们已经知道,如果我们用一个随机变量表示正面总数,它具有参数为 n 和 pᵢ 的二项分布。因此我们得出结论,Nᵢ 应该具有参数为 n 和 pᵢ 的二项分布:

现在我们可以通过知道 Nᵢ 的分布来计算它的均值和方差:


我们还可以计算在测试点 x_t 上 f^(* x *) 的均值和方差:


因此,* f ^( x_t *) 的偏差为:

基于积分的均值定理,我们知道如果 f 在区间 Bᵢ 上是连续的,则在 Bᵢ 中存在一个点 εᵢ 使得

其中 d₁ 和 d₂ 是方程 7 中 Bᵢ 的端点:

使用这个方程,* pᵢ *可以简化为:

所以,得出结论

因此,通过减小 h,方差会增加。在进一步简化偏差项之前,我们需要做一个假设。这里我们假设 f(x) 在区间 Bᵢ 上是Lipschitz 连续的。如果一个函数在区间 Bi 上是 Lipschitz 连续的,那么存在一个正的常数 γᵢ 使得

现在使用这个方程,我们可以简化偏差:

这里我们利用了* x 和 εᵢ *都在宽度为 h 的区间 Bᵢ 内的事实,因此它们的距离不能大于 h。你会发现通过减小 h,偏差的绝对值会减少。
通过有了偏差和方差,我们可以很容易地计算在 x_t 上的均方误差 (MSE):

如果我们将直方图视为模型,区间宽度 h 是这个模型唯一的超参数,它类似于模型复杂性的度量。随着 h 趋近于零,模型变得更加复杂。因此偏差的绝对值减少而方差增加(过拟合)。另一方面,当 h 增加时,模型变得更简单,因此方差减少,但偏差的绝对值增加(欠拟合)。
所以,我们得出结论,偏差的绝对值和方差之间存在权衡:增加一个会减少另一个。这是机械学习中偏差-方差权衡的表现,它描述了模型复杂性与预测准确性之间的关系。
列表 5 展示了偏差-方差权衡的演示。这里我们从均值为 5、方差为 1 的正态分布中获取 100 个大小为 80 的随机样本。我们还尝试了从 0.15 到 2.7 的一系列 h 值。接着,我们选择一个测试点 (xt),并计算该点的正态分布 PDF (f(x_t) )。对于每个 h 值,我们计算所有 100 个随机样本的直方图。然后,我们计算这些直方图的 f^(x_t)。最后,使用 f(xt) 和所有随机样本的 f^(x_t) 值,我们可以计算每个 h 值在测试点 x_t 的偏差²、方差和 MSE。
然而,我们并不想仅仅最小化一个测试点 (x_t) 的直方图预测误差。因此,我们从数组 xt_list 中选择一个测试点范围,并计算每个 h 值在所有测试点上的平均偏差²、方差和 MSE。列表 5 创建了不同 h 值的平均偏差²、方差和 MSE 的图示,如图 5 所示。
# Listing 5
np.random.seed(0)
n_samples = 100
sample_size = 80
h_list = np.arange(0.2, 2.7, 0.15)
mu = 5
sigma = 1
samples = norm.rvs(loc = mu, scale = sigma,
size=n_samples*sample_size).reshape(n_samples, sample_size)
xt_list = np.arange(mu-3*sigma, mu+3*sigma, 0.05)
f_xt = norm.pdf(xt_list, loc = mu, scale = sigma)
avg_mse_list = []
avg_var_list = []
avg_bias_sq_list =[]
for h in h_list:
fhat_xt_list = []
for i in range(n_samples):
(ni, bins) = np.histogram(samples[i], density=True,
bins = np.arange(min(samples[i]),
max(samples[i]) + h, h))
fhat_xt_ind = np.digitize(xt_list, bins)-1
fhat_xt = np.pad(ni, [1,1], 'constant',
constant_values=0).take(fhat_xt_ind+1, mode='clip')
fhat_xt_list.append(fhat_xt)
avg_bias_sq_list.append(((np.vstack(fhat_xt_list).mean(axis=0)-f_xt)**2).mean())
avg_var_list.append(np.vstack(fhat_xt_list).var(axis=0).mean())
avg_mse_list.append(np.mean((np.vstack(fhat_xt_list)-f_xt)**2).mean())
plt.plot(h_list, avg_var_list, "-o", label="Var")
plt.plot(h_list, avg_bias_sq_list, "-o", label="$Bias²$")
plt.plot(h_list, avg_mse_list, "-o", label="MSE")
plt.xlabel("h", fontsize=14)
plt.ylabel("Avg MSE, Avg Var, Avg $Bias²$", fontsize=14)
plt.legend(loc="best")
plt.show()

图 5
通过减少 h 的值,模型复杂度和平均方差增加,但平均偏差²减少。过于复杂的模型会导致过拟合。另一方面,增加 h 的值会导致模型更简单,从而减少平均方差并增加平均偏差²。过于简单的模型会导致欠拟合。如前所述,MSE 是方差和偏差平方的总和,因此在欠拟合和过拟合的情况下,我们对所有测试点的平均 MSE 都很大。正如图中所示,h=0.8 为所有测试点提供了最低的平均 MSE,因为此时平均偏差²和平均方差之间存在平衡。这个点代表了一个既不简单也不复杂的模型,因此最小化了平均 MSE。
列表 6 绘制了三种 h 值的直方图(图 6)。在 h=2.3 时直方图欠拟合,在 h=0.15 时直方图过拟合,两种情况均无法很好地估计 PDF。根据图 5 的结果,最佳拟合发生在 h=0.8。此值的直方图提供了对 PDF 的最佳估计。
# Listing 6
h_list = [2.3, 0.8, 0.15]
labels = ["Underfitting", "Right fit", "Overfitting"]
x = np.arange(1, 10, 0.01)
f_x = norm.pdf(x, loc = mu, scale = sigma)
fig, ax = plt.subplots(1, 3, figsize=(17,6))
for i in range(len(ax)):
p = 3
ax[i].hist(samples[p], density=True,
bins = np.arange(min(samples[p]),
max(samples[p]) + h_list[i], h_list[i]),
edgecolor='black', linewidth=1)
ax[i].plot(x, f_x, color='red', label="PDF", linewidth=2)
ax[i].set_xlabel('$x$', fontsize=24)
ax[i].set_title("{} (h={})".format(labels[i], h_list[i]), fontsize=22)
ax[i].legend(loc='best', fontsize=15)
ax[0].set_ylabel('Probability density', fontsize=24)
plt.plot()

图 6
如你所见,如果我们有足够多的训练数据集,我们可以找到 h 的最佳值,但实际上我们只有一个。在这种情况下,我们可以使用 Scott 规则来找到 h 的最佳值:

其中 σ^ 是样本标准差,n 是样本大小。此规则假设样本服从正态分布。如果我们对列表 5 中的所有样本使用此规则,我们会发现 h 的平均值非常接近列表 5 中找到的 h 的最佳值:
(3.5*samples.std(axis=1, ddof=1)*samples.shape[1]**(-1/3)).mean()
0.8005
请注意,实际上我们只有一个样本来估计 h,因此估计误差可能较大。
核密度估计
为了避免直方图对其端点的依赖,我们可以使用不同的公式来近似 CDF 的导数。这个公式被称为对称差商:

如你所见,与方程 5 不同,这个方程不依赖于 a。接下来,我们可以写作:

我们也可以使用指示函数来计算这个概率。设 A 为集合。A 在 x 处的指示函数定义为:

因此,它在其域内的所有属于 A 的点上等于 1,而在所有其他点上等于 0。现在,我们可以使用指示函数来计算属于 x-h, x+h) 的观察值的数量。记住,x₁,x₂,… x_n 代表我们想要生成直方图的数据集。因此,对于每个 xᵢ:

并且得出:

如果我们用 f^(x) 表示这个 PDF 的估计,我们可以将前面的方程写作:

接下来,我们可以写作:

因此,我们可以写出我们的 PDF 估计为:

现在让我们看一下总和内的表达式,检查它可以取的不同值:

与连续均匀分布的 PDF 比较:

我们可以观察到 K((x-xᵢ)/h) 是区间 [-1, 1] 上均匀分布的 PDF。
列表 7 使用这种方法来估计在列表 1 中生成的样本的 PDF。结果如图 7 所示,并与样本来源分布的 PDF 进行了比较。
# Listing 7
def kde(x, data, h, dist=norm(0, 1)):
n = len(data)
K = dist.pdf((x.repeat(n).reshape(len(x), n) - data) / h)
return np.sum(K, axis=1) / n / h
h = 0.75
x = np.arange(0, 13, 0.01)
fhat = kde(x, sample, h, dist=uniform(loc=-1, scale=2))
plt.figure(figsize=(14, 5))
plt.plot(x, fhat, label='KDE')
plt.plot(x, pdf, color='red', linewidth=2, label="PDF")
plt.xlim([0, 13])
plt.xlabel('$x$', fontsize=20)
plt.ylabel('Probability density', fontsize=20)
plt.legend(loc='best', fontsize=18)
plt.show()

图 7
我们可以将方程 8 进行概括,并用不同的 PDF 替代 K。我们通常将核密度估计器 f^(x) 定义为:

在这个方程中,K(x) 称为 核函数,h 称为 带宽。我们也可以将前面的方程写作:

其中 K_h 称为 缩放核,其定义为:

核函数应满足以下条件:

基于前两个条件,核函数应为非负且归一化,以保证核函数是一个概率密度函数(PDF)。根据第三个条件,它应为对称的 PDF。这个条件意味着:

观察图 7,你会发现估计器曲线不平滑,因为均匀密度没有平滑的形状。我们可以使用标准正态分布的 PDF 作为核函数,因为它具有平滑的形状:

将这个核代入公式 9,我们得到:

因此,估计器是 n 个均值为 xᵢ 的正态分布的 PDF 的平均值,而带宽 h 可以被视为这些分布的标准差。
列表 8 使用了带有标准正态核的估计器来估计列表 1 中生成的样本的 PDF。结果如图 8 所示,并且与样本抽取来源的分布的 PDF 进行了比较。n 个缩放的核 (K_h) 也在该图中绘制,每个缩放核都是均值为 xᵢ 和标准差为 h 的正态分布。
# Listing 8
def plot_kernel(x, data, h):
n = len(data)
K = norm(0,1).pdf((x.repeat(n).reshape(len(x), n) - data) / h) / h
for i in range(K.shape[1]-1):
plt.plot(x, K[:,i], color="grey", alpha=0.3)
plt.plot(x, K[:,i+1], color="grey", alpha=0.3, label="$K_h(x-X_i)$")
h=0.75
x = np.arange(0, 13, 0.01)
fhat = kde(x, sample, h)
plt.figure(figsize=(14, 5))
plt.plot(x, fhat, label='KDE', linewidth=2)
plt.plot(x, pdf, color='red', linewidth=2, label="PDF")
plot_kernel(x, sample, h)
plt.xlim([0, 13])
plt.xlabel('$x$', fontsize=20)
plt.ylabel('Probability density', fontsize=20)
plt.legend(loc='best', fontsize=15)
plt.show()

图 8
作为机器学习模型的核密度估计器
类似于直方图,核密度估计器可以被视为一种机器学习模型,它根据我们从总体中抽取的样本来学习总体的概率密度函数(PDF)。该模型的超参数包括带宽 (h)、核类型及其对应的参数。这些超参数使得核密度估计器相比于直方图更加灵活。
假设我们有公式 9 中的核估计器。为了计算这个估计器的偏差和方差,我们需要使用从总体中抽取的所有可能样本,因此我们使用独立同分布的随机样本 X₁、X₂、… X_n 来创建密度估计器:

现在可以证明,对于测试点 x_t,f^(x_t) 的偏差为:

其中 σ_K 定义为:

我们还可以证明 f^(x_t) 的方差如下:

R(K) 被定义为:

带宽 (h) 控制了模型的复杂性。当 h 趋近于零时,模型变得更复杂,因此偏差的绝对值减少而方差增加(过拟合)。相反,当 h 增加时,模型变得更简单,因此方差减少,但偏差的绝对值增加(欠拟合)。
计算偏差和方差的详细信息见附录。现在,我们可以计算 MSE:

到目前为止,我们只考虑了一个测试点 x_t。然而,通常我们希望控制密度估计器的总体均方误差(MSE)。因此,我们需要计算均值积分平方误差(MISE):

现在,将方程 13 代入此方程,我们得到:

这个方程可以使用方程 7 和方程 12 简化:

此方程中的两个主导项称为渐近均方积分误差 (AMISE):

通过将 AMISE 的导数设置为零,我们可以找到使其最小化的 h 的最佳值:

解此方程可以得出 h 的最佳值:

请注意,我们不能直接使用上述涉及未知 PDF 二阶导数 (f’’) 的公式。因此,我们不能直接使用它。然而,我们可以对 PDF 类型做出假设。如果我们假设要估计一个均值为 µ 和方差为 σ² 的正态分布的 PDF,则:

由此得到

然后我们得到:

如果使用标准正态分布的 PDF 作为核函数:

然后得到

和

将这些方程代入方程 14,我们得到:

由于我们通常不知道总体的标准差 (σ),我们可以使用样本标准差 (σ^) 来代替:

这是估计 KDE 带宽的 Scott 规则。列表 9 显示了核密度估计器的偏差-方差权衡。我们从均值为 5 和方差为 1 的正态分布中获取了 100 个样本,每个样本大小为 100。我们尝试了从 0.14 到 0.52 的 h 范围。我们从数组 xt_list 中挑选了一系列测试点。
对于 xt_list 中的每个测试点 (xt),计算该点的正态分布的 PDF (f(x_t))。然后计算所有 100 个随机样本的 KDE 和这些样本的 f^(x_t)。最后,使用 f(xt) 和所有随机样本的 f^(x_t) 值,计算 x_t 处的偏差²、方差和 MSE。最后,对于每个 h 的值,我们计算 xt_list 中所有测试点的平均偏差²、方差和 MSE。列表 9 创建了不同 h 值的平均偏差²、方差和 MSE 的图示,如图 9 所示。
# Listing 9
np.random.seed(5)
n_samples = 100
sample_size = 500
h_list = np.arange(0.14, 0.52, 0.02)
mu = 5
sigma = 1
samples = norm.rvs(loc = mu, scale = sigma,
size=n_samples*sample_size).reshape(n_samples, sample_size)
xt_list = np.arange(mu-4*sigma, mu+4*sigma, 0.02)
f_xt = norm.pdf(xt_list, loc = mu, scale = sigma)
avg_mse_list = []
avg_var_list = []
avg_bias_sq_list =[]
for h in h_list:
fhat_xt_list = []
for i in range(n_samples):
fhat_xt = fhat = kde(xt_list, samples[i], h)
fhat_xt_list.append(fhat_xt)
avg_bias_sq_list.append(((np.vstack(fhat_xt_list).mean(axis=0)-f_xt)**2).mean())
avg_var_list.append(np.vstack(fhat_xt_list).var(axis=0).mean())
avg_mse_list.append(np.mean((np.vstack(fhat_xt_list)-f_xt)**2).mean())
plt.plot(h_list, avg_var_list, "-o", label="Var")
plt.plot(h_list, avg_bias_sq_list, "-o", label="$Bias²$")
plt.plot(h_list, avg_mse_list, "-o", label="MSE")
plt.xlabel("h", fontsize=14)
plt.ylabel("Avg MSE, Avg Var, Avg $Bias²$", fontsize=14)
plt.legend(loc="best")
plt.show()

图 9
当 h 趋近于零时,模型变得更加复杂,结果是平均偏差²减少,平均方差增加(过拟合)。另一方面,当 h 增加时,模型变得更简单,导致欠拟合。因此,平均方差减少,但平均偏差²增加。
我们还可以使用方程 15 来估计 h 的最佳值。在这里,我们计算了列表 9 中所有样本的 h 的平均值:
(1.06*samples.std(axis=1, ddof=1)*samples.shape[1]**(-1/5)).mean()
0.3060
我们看到 h 的平均值非常接近于列表 9 中找到的 h 的最佳值。请注意,实际上我们只有一个样本来估计 h。列表 10 绘制了三个 h 值的 KDE(图 10)。在 h=1 时,KDE 欠拟合;在 h=0.1 时,KDE 过拟合,这两种情况下 KDE 都没有很好地估计 PDF。正确的拟合发生在 h=0.8。这个 h 值的 KDE 给出了对 PDF 的最佳估计。
# Listing 10
h_list = [0.1, 0.31, 1]
labels = ["Underfitting", "Right fit", "Overfitting"]
x = np.arange(1, 10, 0.01)
f_x = norm.pdf(x, loc = mu, scale = sigma)
fig, ax = plt.subplots(1, 3, figsize=(17,6))
for i in range(len(ax)):
ax[i].plot(x, f_x, color='red', label="PDF", linewidth=2)
fhat = kde(x, samples[1], h_list[i])
ax[i].plot(x, fhat, linewidth=2, linestyle="--", label="KDE")
ax[i].set_xlabel('$x$', fontsize=24)
ax[i].legend(loc='best', fontsize=15)
ax[i].set_title('h={}'.format(h_list[i]), fontsize=22)
ax[0].set_ylabel('Probability density', fontsize=24)
plt.plot()

图 10
SciPy库中的函数 gaussian_kde() 可以用来计算 KDE。其他库如 seaborn 和 matplotlib 使用此函数来绘制数据集的 KDE。该函数使用标准正态核来估计 PDF。文档提到它使用 Scott 规则来估计 h,然而,公式略有不同:

列表 11 绘制了列表 1 中定义的样本的 KDE。它比较了方程 15 和 SciPy 的规则。图 11 展示了该图。
# Listing 11
x = np.arange(0, 13, 0.01)
f_x = norm.pdf(x, loc = mu, scale = sigma)
h1 = 1.06 * sample.std(ddof=1) * len(sample)**(-1/5)
h2 = sample.std(ddof=1) * len(sample)**(-1/5)
plt.figure(figsize=(10, 7))
plt.plot(x, pdf, color='red', label="PDF", linewidth=1)
fhat1 = kde(x, sample, h1)
fhat2 = kde(x, sample, h2)
sns.kdeplot(data=sample, color="green", linewidth=2, label="KDE, SciPy")
plt.plot(x, fhat1, linewidth=2, linestyle="--", color= "blue",
label="KDE, $h=1.06\hat{\sigma}n^{-1/5}$")
plt.plot(x, fhat2, linewidth=2, linestyle="--", color="black",
label="KDE, $h=\hat{\sigma}n^{-1/5}$")
plt.xlim([0, 13])
plt.xlabel('$x$', fontsize=20)
plt.legend(loc='best', fontsize=13)
plt.ylabel('Probability density', fontsize=20)
plt.plot()

图 11
在本文中,我们讨论了直方图和核密度估计。这两种方法都用于估计数据集的概率分布的 PDF。因此,它们都可以被视为学习数据集 PDF 的机器学习模型。我们展示了如何计算这些模型的偏差和方差。均方误差(MSE)是方差和偏差的平方之和,我们希望通过找到直方图的最佳箱宽或核密度估计器的最佳带宽来最小化它在所有可能的测试点上的值。文章展示了如何使用 Scott 规则来估计这些最佳值。
希望你喜欢阅读这篇文章。本文中的所有代码列表都可以从 GitHub 上的 Jupyter Notebook 中下载:
github.com/reza-bagheri/histograms_kde/blob/main/KDE.ipynb
附录:
记住,随机样本 X₁,X₂,… X_n 是独立且同分布的,表示从随机变量 X 的分布中抽取的随机样本(我们想估计其 PDF),因此它们每一个单独的分布与 X 相同。因此,随机变量

是独立且同分布的,每一个都具有相同的分布:

核估计器

是这些独立同分布随机变量的均值,因此我们有:


现在我们可以计算在测试点 x_t 处核估计器的均值和方差:

接下来,我们进行变量变换 y=(x_t-x)/h。因此,之前的方程可以写作:

现在我们可以使用泰勒级数来估计 f(x_t-yh):

其中 o(h²) 表示这是一个相比于 h² 的更小的阶项,当 h 趋近于零时。将此方程代入前一个方程,我们得到:

现在使用方程 11 (I) 和 (II),我们可以简化此方程并将其写作:

σ_K 被定义为:

x_t 的偏差定义为:

因此,得出:

请注意,我们需要方程 11 (II) 中的条件来得到这个偏差方程。接下来,我们计算点 x_t 处的方差:

使用方差的定义可以写作:

要计算此方程右侧的第一项,我们可以写作:

我们在这里使用了变量变换 y=(x_t-x)/h。再次,我们使用泰勒级数来估计 f(x_t-yh):

将此方程代入前一个方程,我们得到:

R(K) 被定义为:

此方程右侧的第二项可以使用方程 A.1 进行计算:

在这里,如果我们忽略所有的项并将其近似为零,我们得到:

最后,我们得到:

在这里,我们利用了当 h 趋近于零时 1/h>>1 的事实。所以

理解独立性及其在因果推断和因果验证中的重要性
理解独立性的逐步指南以及如何应用它来验证使用 Python 的有向无环图(DAG)在因果验证中的有效性
·发表于Towards Data Science ·23 分钟阅读·2023 年 12 月 7 日
--

图片由Towfiqu barbhuiya提供,来源于Unsplash
背景
在最近的一篇文章中,作者探讨并解释了如何使用依赖性概念来验证提出的有向无环图(DAG)是否符合数据集,以识别图中的虚假边,即 DAG 中建议的因果链接在数据中不存在。
在第二部分中,将应用相反的(但同样关键的)概念,即如何利用独立性来识别缺失的边。这些是提议的 DAG 中未出现但实际上存在于数据中的因果链接,必须将这些链接添加回 DAG 中,以使其完整且正确。
介绍
因果推断是数据科学的一个新兴分支,关注于确定事件和结果之间的因果关系,并且它有潜力显著提高机器学习为组织创造的价值。
例如,传统的机器学习算法可以预测哪些贷款客户可能会违约,从而使客户能够进行主动干预。然而,尽管这个算法对于减少贷款违约有用,但它并不了解违约发生的原因,虽然主动干预是有用的,但了解违约的原因将有助于解决根本原因。在这种情况下,主动干预可能不再必要,因为导致违约的因素已被永久治愈。
这就是因果推断的承诺,它有可能对能够利用这种潜力的组织产生重大影响和结果。
有多种不同的方法,但最常见的方法通常从用“有向无环图”(DAG)来增强数据开始,这种图形封装和可视化数据中的因果关系,然后使用因果推断技术来提出“如果”类型的问题。
问题
封装数据中的因果关系的 DAG 通常由数据科学家和领域专家手动(或半手动)构建。因此,DAG 可能是错误的,这会使任何因果计算失效,从而导致错误的结论和潜在的不正确决策。
机会
存在一系列的“因果验证”技术(验证 DAG 是否符合数据的过程),如果这些技术有效,它们可以最小化或消除 DAG 中的错误,从而确保计算和结论是无误的。
前进的道路
随机变量之间的独立性统计概念可以用来确定 DAG 中不存在的关系是否在数据中存在。如果遇到这种情况,则很可能 DAG 中缺失的因果关系需要添加到 DAG 中,以使其完整和正确。
开始
我们需要一个示例 DAG 来解决问题,这个 DAG 需要有足够的节点和链接,以提供一个好的示例来探索问题 …

文章中使用的 DAG — 作者图片
DAG 完全是虚构的,所以节点上的字母没有任何意义,但需要注意的是,“X”是治疗,“Y”是结果,其他节点代表影响结果的因素,有可能隐藏或扭曲 X 对 Y 的真实效果。
为了更好地理解 DAG,如果这是一个现实世界的问题,它可能代表以下情况 …
-
X 代表药物服用的规律性。
-
W 代表药物对血压的影响。
-
Y 代表对患者恢复和结果的改善。
-
Z1、Z2 和 Z3 代表其他因素(例如,也许 Z1 代表健康的生活方式,Z3 代表健身水平等)
DAG 中的箭头表示一个因素对另一个因素的因果影响,例如 …
- 使用药物(用“X”表示)对血压(用“W”表示)有因果影响,而血压又对患者恢复(“Y”)有因果影响。
我们还需要一些与 DAG 匹配的数据。下面的数据集完全是合成的,由作者生成;它准确地封装并匹配了 DAG 建议的结构,并且没有错误或虚假的关系…

全文使用的数据集 — 作者提供的图像
接下来的部分将开始解读 DAG 和数据,并利用它们解释如何利用统计独立性来识别 DAG 中遗漏或忽视的因果关系。
理解独立性
独立性的一个定义如下 —
“两个随机变量之间的独立性是指一个随机变量的发生或值不会影响或提供关于另一个随机变量发生或值的任何信息的基本概念。”
我们再看看我们的 DAG,并考虑节点 Z1 和 Z2 …

突出显示的 Z1 和 Z2 节点的 DAG — 作者提供的图像
我们可以看到,节点 Z1 和 Z2(代表影响治疗和结果的一些因果因素)没有直接或间接的连接,DAG 中没有通过它们的路径,因此可以说 Z1 与 Z2 是独立的(反之亦然)。
为了进一步说明这一点,我们还可以看到,尽管 X(治疗)和 Y(结果)没有直接连接,但它们实际上是依赖的,因为 DAG 中有几条路径将它们连接起来 …

节点 X 和节点 Y 之间的所有路径 — 作者提供的图像
这两个例子表明,如果 Z1 的值发生变化,它不会影响或改变 Z2 的值,但如果 X 的值发生变化,它会改变 Y 的值。
这个解释可以通过再看看 DAG 建模的数据集来扩展 …

数据集审查 — 作者提供的图像
在现实世界的因果推断问题中,数据将是起点,DAG 将通过与领域专家咨询来开发,但为了方便文章,作者需要一个保证与 DAG 匹配的数据集。
因此,上述数据集是通过应用以下公式生成的 1000 行数据 …

用于创建数据集的结构方程 — 作者提供的图像
可以按如下方式阅读和理解 -
-
Z1 是一个外生变量(即没有输入),呈正态分布,均值为 4.37,标准差为 1.95。
-
Z1 是一个外生变量,呈正态分布,均值为 1.28,标准差为 1.94。
-
Z3 = -1.5 X Z1–1.5 x Z2 + 一个误差项
-
X = -1.5 x Z1 + 1.5 x Z3 + 一个误差项
-
W = -3 x XZ + 一个误差项
-
Y = -2.5 x W + -3 x Z2 + -3 x Z3 + 误差项
用于创建数据的公式清楚地表明 Z1 独立于 Z2,因为它们都是完全分开的正态分布,但 X 和 Y 并不独立,因为改变 X 会改变 W,改变 W 会改变 Y。
还有另一种表示方式,通过数学符号显示依赖关系……

独立性和依赖性符号 — 作者提供的图片
⫫ 符号被称为“双向叉”,意思是“独立于”。⫫̸ 符号没有被广泛接受的名称,因此我个人偏好使用“斜线双向叉”,意思是“依赖于”,因此上述公式可解读为“Z1 独立于 Z2”,以及“Y 依赖于 X”。
在本节结束时,我们将把这些知识带回到因果验证的背景下。
在因果推断问题中,数据已经被收集,有向无环图通常由领域专家单独构建,他们可能犯了错误或知识可能不完整。
因果验证则是证明或反驳有向无环图是否有效地表示数据因果关系的过程。
独立性在这个过程中起着关键作用,因为如果有向无环图中的依赖关系和独立性在数据中都能被匹配和检测到,那么可以推断有向无环图是有效的。
这可以用以下公式表示……

推断 DAG 和数据等效的公式 — 作者提供的图片
这些公式看起来很吓人,但实际上非常简单。
第一条说,如果图(DAG)中的 Z1 独立于 Z2,那么 Z1 在数据中也应该独立于 Z2。
第二条说,如果图(DAG)中的 Y 依赖于 X,则 Y 在数据中也应该依赖于 X。
使用独立性来识别缺失的因果链接
在上一篇文章中,我解释了如何使用独立性来检测虚假链接,即在有向无环图中出现但在数据中未出现的因果关系……
理解依赖性概念及其如何应用于验证有向无环图的逐步指南…
towardsdatascience.com
… 在这篇文章的其余部分,我将深入探讨如何利用依赖性来识别缺失的链接,即在数据中出现但在 DAG 中未出现的因果关系。
因此,如果可以使用独立性来识别缺失的链接,那么这些链接可以重新添加到无效的 DAG 中,使其变为有效。
本文提出的方法基于 Judea Pearl 的多项工作和已发表的论文,他被公认为因果推断领域的全球权威。
在《统计中的因果推断》(Pearl, Glymour, Jewell, 2019)中,Pearl 探讨了使用 d-分离来识别缺失链接的想法,但这是一个部分示例,未提供足够的解释来实现 Python 中的算法。
购买《统计中的因果推断:入门》1 由 Pearl, Judea, Glymour, Madelyn, Jewell, Nicholas P.(ISBN…)
注意:上面的链接是亚马逊的附属链接。如果你通过这个链接购买书籍,作者会获得小额的推荐费,而对购买者没有额外费用。
以下探讨旨在提供足够的解释和示例,以便可以在 Python 中实现一个可靠识别 DAG 中缺失但数据中存在的因果链接的解决方案。
这是实现目标的规则 …
“在 DAG 中,任何节点 N 在给定其父节点的情况下应与其非后代节点独立”
(作者原话)
听起来很复杂,单独理解确实困难,但通过示例可以使其更清晰。
DAG 中的每个节点都可以单独测试该规则。这里是应用于节点“W”的规则 …

节点 W 的独立性规则 — 作者提供的图像
… 这可以理解为 — “W 在给定(条件为)其父节点时与其非后代节点独立”。
那么这是什么意思,以及如何应用它来识别 DAG 中缺失的因果关系呢?
首先,我们需要完全理解“父节点”是什么意思以及“非后代”是什么意思。
注意:以下图示的颜色编码如下……
-
粉色:感兴趣的节点
-
红色:父节点
-
绿色:后代
-
黄色:非后代
-
蓝色:所有其他节点
节点的父节点易于可视化;它们是 DAG 中有箭头指向该节点的因果链接的节点,我们可以通过考虑表示节点 W 的父节点和节点 X 的父节点的图表来可视化这一点……

节点 W 的父节点和节点 X 的父节点 — 作者图片
我们可以看到节点 W 有一个父节点 — X,而节点 X 有两个父节点 — Z1 和 Z3,我们也可以看到不同的节点可以有零个、一个或多个父节点。
非后代略难以可视化,因为它们与祖先有细微的不同。我对节点的非后代的定义如下……
“节点的非后代是所有不是后代且也不是直接父节点的节点”
为了说明这一点,让我们首先可视化节点 W 的所有后代,以及节点 X 的所有后代……

节点 W 的非后代和节点 X 的非后代 — 作者图片
最后,这里是整合在一起的 — 一个节点 W 的图示和一个节点 X 的单独图示,突出显示它们的父节点为红色,非后代为黄色……

节点“的父节点和非后代及节点 X 的父节点和非后代 — 作者图片
此时你可能在想两件事……
为什么非后代排除直接父节点?
对这个问题的回答是,这是一种“非后代”的定义,广泛遵循了 Judea Pearl 的书中提出的发现缺失链接的解决方案。如果你包括父节点,那么以下公式将不起作用……

节点 W 的独立性公式 — 作者图片
……因为如果使用“给定”符号(|),父节点将出现在两侧。
这些与因果验证有什么关系?
这个问题的答案将在下一节揭示……
使用独立性、非后代和父节点来识别 DAG 中缺失的链接
到目前为止,我们已经定义了什么是独立性,即如果 A 对 B 是独立的,那么改变 A 的值对 B 没有影响。我们还探讨和理解了“父节点”和“非后代”的含义,并且我们从文献中推导和整合了一个可以识别缺失链接的公式(𝑁 ⫫ 𝑛𝑜𝑛−𝑑𝑒𝑠𝑐𝑒𝑛𝑑𝑎𝑛𝑡𝑠 | 𝑝𝑎𝑟𝑒𝑛𝑡𝑠)。
现在只剩下把所有这些整合在一起 — 简单!好吧,也许不简单,所以让我们重新查看节点 W 的最终图示并添加一些额外的路径……

DAG 强调了节点 W 的非后代到节点 W 的可能缺失路径 — 图片由作者提供
节点 Z1、Z3、Z2 和 Y 之间的亮粉色路径在数据中不应该存在,因为它们在 DAG 中不存在。因此,如果在数据中检测到这些依赖关系,则 DAG 必定是错误的,不仅如此,我们将确切知道哪里出错了!一个在 DAG 上缺失的链接将被识别,使当前 DAG 无效,然后可以通过添加缺失的链接来修正。
这是一个令人惊叹的可能性!让我们先从将其表示为节点 W 的数学公式开始……

关于 DAG 中节点 W 的独立性的公式和数据中的独立性 — 图片由作者提供
这看起来很吓人,但它只是说如果在图(DAG)中 W 在给定 X 的情况下独立于 Z1、Z2 和 Z3,那么在数据中 W 也应该在给定 X 的情况下独立于 Z1、Z2 和 Z3。
如果你需要复习在 DAG 中条件下的“给定”是什么意思,这篇文章提供了深入的逐步解释……
从基础到更高级的方面,逐步解释有向无环图
towardsdatascience.com
结果发现,运行这个测试并查看其是否成立(即数据是否与 DAG 匹配)相对容易。让我们试试……
为了在公式右侧对 W 进行测试,我们可以对左侧的所有 4 个变量进行回归 — Z1、Z2、Z3 和 X。一些教科书和来源提到 Z1、Z2 和 Z3 的“消失”。
这意味着如果在父变量和非后代变量上进行回归,那么所有非后代变量的系数应该为零或接近零,因此让我们从可视化这些数据关系开始……

回归 W 时 X、Z1、Z2 和 Z3 的系数的图形表示 — 图片由作者提供
从这个可视化中我们可以看到 Z1、Z2 和 Z3 确实有一个平坦(或“消失”)的系数。如果我们用数学表示这些关系,它们看起来像这样……

回归 W 时 X、Z1、Z2 和 Z3 的系数的数学表示 — 图片由作者提供
到目前为止一切看起来都不错,那么我们如何在 Python 代码中使用普通最小二乘(OLS)回归来实现这一测试并提取结果?
以下是使用本文早些时候的合成数据集来完成这一点的源代码……

OLS 结果汇总 — 作者提供的图像
关键方面是结果汇总倒数第二个表格中的coef列(系数或斜率)和P>|t|列(p 值)。
我们可以很容易地看到 Z1、Z2 和 Z3 的系数(或图中的斜率)很小但不为零(即完全平坦),那么我们如何得出它们是“消失”的结论?
事实证明,在我阅读的所有文本或文章中,没有关于“消失”系数的定义,因此我基于试验和大量测试提出了自己的方法……
W is not dependent on Z1 in the data
W is not dependent on Z2 in the data
W is not dependent on Z3 in the data
W is dependent on X in the data
我选择的测试是基于多小时的测试结果,是同时查看 OLS 结果汇总中的 p 值和系数的组合。
理论上,p 值本身应该足够。对于每个变量,有一个原假设,即独立变量(例如 Z1)与因变量(W)之间没有关系。
如果 p 值小于α值(通常选择 0.05),则拒绝原假设,结论是存在某种关系。
如果我们查看变量 Z1 的结果,可以看到 p 值为 0.473,远高于 0.05,因此我们不能拒绝原假设,结论是 Z1 与 X 之间没有关系。这同样适用于 p 值为 0.176 的 Z2 和 p 值为 0.518 的 Z3。然而,X 的 p 值为 0.000,因此不能拒绝原假设,这一切与上面的斜率图一致。
因此,在这种情况下,Z1、Z2 和 Z3 与 W 的独立性以及 X 对 W 的依赖性可以通过仅查看 p 值来确定,但在我从广泛测试中观察到的情况是,情况并非总是如此,通过定义一个“消失”系数来指示独立性,可以获得最佳结果,如下所示……
“对于每个变量,如果 p 值大于 0.05 且系数小于或等于 1,那么该变量是独立的(或称为‘消失’)”
我发现如果在进行 p 值检验的同时检查系数,那么在大量随机选择的测试中,准确性会显著提高,因此这是我开发并选择在 Python 代码中实现的定义,以尽可能准确地执行检验。
识别无效 DAG 中的缺失因果链接
此时,我们考虑了一个 DAG 和一个数据集,知道 DAG 是数据因果关系的准确表示,因为作者创建数据集以反映 DAG,而不是反过来(这在现实世界中是实际情况)。
已提出一种方法来测试 DAG 相对于一组数据的有效性,该方法涉及测试 DAG 中每个节点的非后代的独立性,如果不能确定独立性,则假定存在缺失链接。
该方法已经在 Python 中实现,并在 DAG 与数据匹配的情况下进行了测试,证明了 DAG 是有效的。
但这只是开始。测试匹配 DAG 的有效性是一回事,更重要的是 — 当 DAG 无效且存在数据集中确实存在的缺失链接时,这种方法能否检测到这些链接?
考虑一下我们 DAG 的这个变体……

带有从节点 Z1 到节点 W 附加链接的 DAG 变体 — 图片由作者提供
在节点 Z1 和节点 W 之间添加了一个新的因果链接,这里是一个包含所有因果关系(包括新关系)的新的合成数据集……

包含从 Z1 到 W 的新链接的新数据集 — 图片由作者提供

定义和描述新数据集的结构方程 — 图片由作者提供
现在假设我们的虚构领域专家不知道 Z1 和 W 之间的因果链接,并根据他们的领域知识创建了这个 DAG,而这些知识是数据科学家咨询他们时提供的……

包含错误的 DAG(缺少新链接) — 图片由作者提供
此时,所提出的 DAG 与数据集不匹配。现在的关键测试是看我们的方法是否准确检测到缺失的链接。
下面的图表显示了从实施依赖性测试“W ⫫ Z1, Z2, Z3 | X”得到的普通最小二乘回归结果……

新数据的系数(包含新链接) — 图片由作者提供
W is dependent on Z1 in the data
W is not dependent on Z2 in the data
W is not dependent on Z3 in the data
W is dependent on X in the data
当数据中存在因果关系 Z1 -> W 但 DAG 中没有时,它被正确地识别为 DAG 中缺失。因此,我们不仅知道 DAG 无效,还知道如何修复它。只需将(“Z1”,“W”)作为附加边添加到 DAG 中,它就变得有效!
当我第一次成功运行这个测试时,我对其影响感到震惊。我不再依赖(无意冒犯!)领域专家的绝对可靠。如果他们的知识存在漏洞或犯了错误,这种验证技术可以发现他们遗漏的因果链接,并生成一个正确有效的 DAG!
到目前为止,我们已经成功测试了有效 DAG 中的单个节点,并成功识别了无效 DAG 中的缺失链接。
下一步是将我们的单节点测试扩展到识别和执行整个 DAG 的所有验证测试……
从测试单个节点到验证整个 DAG
我们现在从测试节点“W”转向验证整个 DAG,结果发现这非常简单。只需在每个节点周围迭代,执行这个算法……

数学表示:对于每个节点,该节点在给定其父节点的情况下应该与其非后代独立 — 图像由作者提供
对于任何验证失败,将识别出的缺失链接添加到整体缺失链接列表中,最终结果是识别 DAG 中的所有缺失链接。
请注意,如果一个节点没有父节点,它仍然应该被测试,但如果一个节点没有非后代,它就不应该也无法被测试。
因此,节点 Z3 被从下面的测试中省略,因为它没有任何非后代(回忆一下,父节点被排除在非后代之外)。
测试整个 DAG 所需的验证测试集合可以在下面的图示中直观地表示(在每个 DAG 中,正在测试的节点为粉色,其父节点为红色,其非后代为黄色,所有剩余节点为蓝色)……

图形表示所有测试所需的验证 DAG 以涵盖所有缺失链接 — 图像由作者提供
以下是以等效数学符号表示的测试……

数学表示所有测试所需的验证 DAG 以涵盖所有缺失链接 — 图像由作者提供
如果对数据执行这 5 个独立性测试,并且在每个测试中非后代“消失”,那么可以有较高的信心认为 DAG 是数据的有效因果表示,并且 DAG 中没有缺失的因果链接。
然而,还有一种优化和改进可以进行。现有文献,特别是 Pearl 的著作 — 例如《统计学中的因果推断》(Pearl, Glymour, Jewell, 2019)— 描述了一组比上述表示的测试要小的测试集,尽管我从未找到关于如何进行最小化的解释。
进一步的研究和反复试验发现,一些独立性测试是等效的,例如……

确定等效测试 — 图像由作者提供
这意味着在测试 A ⫫ B | P1, P2, … , Pn 与 B ⫫ A | P1, P2, … , Pn 等效,并且其中一个可以被省略,因为没有必要重复测试完全相同的内容。
因此,优化后的测试用于证明或反驳我们示例 DAG 的有效性如下……

针对我们的 DAG 的优化缺失链接测试,去除了等效测试 — 图像由作者提供
即 Z2 ⫫ Z1 已被移除,因为它等同于 Z1 ⫫ Z2,测试不需要重复进行。
让我们通过查看实施 OLS 回归测试时生成的图表来结束本节……

可视化所有 4 个缺失链接测试的系数 — 图片由作者提供
这个结果正是我们所期望的!对于每个测试的节点,图表显示非后代的系数非常小,即它们正在消失,即它们是独立的,即没有缺失的链接。
下一个明显的问题是“这种技术的可靠性如何?它是否总能被依赖于识别缺失的因果链接并纠正无效的 DAG?”
这是一个非常重要的问题,需要问清楚并回答,因为我读过的文献(书籍、博客和文章等)都没有涉及这个关键问题。即使有探讨因果验证的内容,也往往是不完整、解释不充分,统计术语过多,Python 代码不足,但关键的是我只找到 DAG 和数据匹配的例子,没有探讨它们不匹配时会发生什么。
以下部分将这一思想拓展到我在文献中未曾发现的领域,使用了我在 Python 代码中构思和实验的各种方法,以观察这些算法在现实场景中的表现……
彻底测试提出的算法
为了彻底测试本文提出的算法,采取了以下方法……
-
确定 DAG 中的每条边(因果关系)
-
按如下方式运行 100 次测试……
-
为 DAG 生成一组有效的测试数据。
-
随机挑选一条缺失的边并删除它。
-
查看提出的算法是否正确识别了缺失的边。
在评估成功时考虑了两个因素……
-
提出的算法是否准确找到了删除/缺失的边?
-
算法是否找到了删除/缺失的边,同时还错误识别了其他缺失的边?
结果如下……

可视化单个缺失链接的所有可能组合的测试 — 图片由作者提供
缺失的边在 100/100 次测试中均被正确识别,但在其中 13 次测试中,算法还识别出一些实际上并不存在的其他边。
下一个明显的问题是“如果缺少两条边/因果链接会怎样?”……

可视化所有可能组合的两个缺失链接的测试 — 图片由作者提供
当删除两条边时,准确率下降。算法在 68/100 次测试中正确找到了缺失的两条边,但在其中 14 次测试中,它错误地识别了一些实际上并不存在的缺失边。
可以进一步测试以评估算法在更复杂、更现实的 DAG 中的性能……

可视化测试单个缺失链接所有可能组合的情况 — 作者提供的图像
在这种情况下,算法在 74/100 的测试中正确识别了缺失的因果链接,其中包括 18 次测试中额外识别为缺失的链接,但实际上并不存在。
为了完成测试,这里是针对一个去掉 2 个有效链接的复杂 DAG 的测试 …

可视化测试所有可能的两个缺失链接组合的情况 — 作者提供的图像
这一次准确率开始下降。算法在 50/100 的测试中正确识别了缺失的因果链接,但其中包括 20 次测试中算法错误地识别了额外的链接以为它们是缺失的。
奖励部分:优化缺失链接
你可能会疑惑为什么测试结果包括了那些识别出的缺失链接与测试中提到的其他数字完全匹配的测试次数,即被删除的链接被发现了,但也有一些被错误识别的链接。
原因在于,在过度识别的情况下,识别为缺失的但实际上并不存在的链接可以很容易地被修正。
在最近的一篇文章中,我探讨了这里解释的验证的镜像,即使用依赖性来识别伪链接(与使用独立性来识别缺失链接相对)。
逐步指南,帮助理解依赖关系的概念以及如何将其应用于验证有向无环图…
towardsdatascience.com](/demystifying-dependence-and-why-it-is-important-in-causal-inference-and-causal-validation-4263b18d5f04?source=post_page-----dfdd26c29739--------------------------------)
“伪链接”算法非常准确,几乎能正确识别所有伪链接,即使在有两个或更多伪链接的复杂 DAG 中也能如此。
这意味着,一个找到缺失链接但也错误识别了一些实际上并不存在的链接的算法,可以通过以下方式进行优化和改进 …
-
运行依赖性/伪链接算法,以识别并纠正任何真正的伪链接。
-
运行本文提出的独立性/缺失链接算法,以识别并纠正任何缺失链接。
-
由于步骤 2 可能过度识别,再次运行依赖性/伪链接算法,这将找到并删除步骤 2 认为是缺失但实际上并不存在的任何链接。
如果实现了这一高级算法,它将使每个测试中两个数字的较高值变为准确值,即 …
-
简单 DAG 中的 1 个缺失链接:100%准确率
-
简单 DAG 中的 2 个缺失链接:68%准确率
-
复杂 DAG 中的 1 个缺失链接:74% 准确率
-
复杂 DAG 中的 2 个缺失链接:50% 准确率
虽然这些结果并不完美,但足够好,可以极其有用。
还应注意,本文提出的算法在 DAG 实际与数据匹配时的准确率接近 100%,因此在验证失败的情况下,一种替代方法是通过强力的试错法修改 DAG 直到验证检查通过(或错误最小化)。
这些是复杂的技术,将在未来的文章中全面探讨,但理论是可靠的,并提供了实现因果验证的极具潜力的高效算法。
结论
因果推断承诺提供一套新的技术,扩展已建立的机器学习所带来的组织影响力和成果。
然而,为了开始这段因果之旅,必须使用领域专业知识构建捕捉因果关系的有向无环图(DAG),而这些专业知识可能存在缺陷。
除非 DAG 准确捕捉因果关系,否则任何后续分析可能包含错误,但因果验证提供了修正和纠正不准确 DAG 的潜力,以确保因果推断方法的结论是正确的。
之前的文章探讨了如何利用依赖关系来识别和去除 DAG 中的虚假因果链接,而这篇文章则探讨了如何利用独立性来识别和添加缺失的因果链接。
未来的文章将探讨如何使用 v-结构来识别和修正反向因果链接(即因果箭头方向错误),以及如何将虚假、缺失和反向链接测试整合成一个一致的整体,这将有助于提供极具有效性的因果验证算法,从而提高组织影响力和成果,优化因果推断技术的应用。
连接并保持联系 …
如果您喜欢这篇文章,您可以通过每月仅需 5 美元成为 Medium 会员,获取无限访问更多故事,您可以 点击我的推荐链接(如果您使用此链接注册,我将获得部分费用,您无需额外支付费用)。
[## 通过我的推荐链接加入 Medium - Graham Harrison
作为 Medium 会员,您的会员费用的一部分将用于支持您阅读的作者,您可以全面访问所有故事……
grahamharrison-86487.medium.com](https://grahamharrison-86487.medium.com/membership?source=post_page-----dfdd26c29739--------------------------------)
…或通过以下方式连接 …
访问我的数据科学网站 — 数据博客。
理解工具变量
原文:
towardsdatascience.com/understanding-instrumental-variables-0ce5d3d6ba20
因果数据科学
如何在无法随机化治疗时估计因果效果
·发表于数据科学前沿 ·阅读时间 12 分钟·2023 年 11 月 13 日
--

封面,图片作者提供
A/B 测试是因果推断的黄金标准,因为它们允许我们在最少假设下做出有效的因果声明,这要归功于随机化。实际上,通过随机分配治疗(药物、广告、产品等),我们能够比较结果(疾病、公司收入、客户满意度等)在受试者(患者、用户、客户等)之间的差异,并将结果的平均差异归因于治疗的因果效果。
然而,在许多情况下,由于伦理、法律或实际原因,无法随机化治疗。一种常见的在线环境是按需功能,例如订阅或高级会员。其他设置包括我们无法区分客户的功能,例如保险合同,或者那些深度硬编码到系统中的功能,实验可能不值得付出努力。在这些情况下,我们仍然可以进行有效的因果推断吗?
答案是肯定的,这要归功于工具变量和对应的实验设计,即鼓励设计。在上述许多情况下,我们无法随机分配治疗,但我们可以鼓励客户接受治疗。例如,我们可以提供订阅折扣,或者我们可以更改选项呈现的顺序。虽然客户对接受治疗拥有最终决定权,但我们仍然能够估计因果治疗效果。让我们看看如何做到这一点。
评估订阅计划
在文章的其余部分,我们将使用一个示例。假设我们是一家产品公司,启动了一份每周的通讯以推广产品和功能更新。我们想了解通讯是否值得投入,以及它是否最终成功地增加了销售额。不幸的是,我们不能进行标准的 A/B 测试,因为我们不能强迫客户订阅通讯。这是否意味着我们无法评估通讯?不完全是。
假设我们还在移动应用上进行了一次关于新通知的 A/B 测试,以推广通讯。随机的一部分客户收到了通知,而另一部分客户没有。也许这个 A/B 测试与通讯的因果效应评估毫无关系,这在大公司中有时会发生。然而,这对数据科学家来说是一个绝佳的机会,尤其是对于那些有兴趣了解通讯对销售影响的人。
首先,我们来看看数据。我从src.dgp中导入了数据生成过程,并从src.utils中导入了一些绘图工具。
dgp = dgp_notification_newsletter(n=10_000)
df = dgp.generate_data()
df.head()python

数据快照,图源作者
我们有关于10,000名客户的信息,我们观察了他们是否收到了notification,是否subscribed了通讯,以及他们spent了多少。此外,我们还观察了他们在订阅程序推出前的花费情况(spent_old)。在文章的其余部分,我们将这些变量标记如下:
-
notification,处理分配,Z -
subscription,处理状态,W -
spend_old,特征或控制变量,X -
spend,结果,Y
一种天真的方法是比较subscribed和未subscribed客户之间的spend差异。相应的因果对象或估计量是

订阅对花费的影响,图源作者
让我们可视化两个组的平均spend。
plot_group_comparison(df, x="subscription", y="spend", title="Spend", xticks=["Non-subscriber", "Subscriber"])

订阅者与非订阅者之间的花费差异,图源作者
订阅者平均花费比非订阅者多11.5$。但这是否是因果效应?
我们可以想象,那些更活跃、对我们产品更感兴趣的客户也会对收到有关它的新闻更感兴趣。例如,我们可以想象那些有更多预算的客户,也会希望更好地花费这些预算并订阅通讯。
我们可以用以下的有向无环图(DAG)来表示变量之间的关系。如果你从未听说过 DAG,我建议你先阅读我的入门文章。
编辑描述
towardsdatascience.com
在图中,我们用圆圈表示变量,用箭头表示因果关系。

DAG 数据生成过程,图像由作者提供
从技术上讲,客户的预算是一个不可观察的混杂因素,它在我们的处理变量subscription和结果变量客户的spend之间打开了一个虚假的路径。因此,我们不能将11.5$的均值差异估计解释为因果关系。
我们能做什么?
鼓励设计
不幸的是,我们不能进行 A/B 测试,因为我们不能强制人们订阅新闻通讯。然而,我们可以鼓励人们订阅。例如,我们可以发送移动通知来宣传新闻通讯。这种设置称为鼓励设计,因为我们不随机化处理,而是随机化鼓励措施。在我们的设置中,鼓励措施notification也被称为工具变量。
重要的是要强调,虽然被随机分配,但鼓励措施不与感兴趣的处理相符。事实上,尽管收到通知,有些人仍然不会订阅,有些人则会在未收到通知的情况下订阅。

处理分配和处理状态,图像由作者提供
添加鼓励措施notification后,数据生成过程可以用以下 DAG 表示。

DAG 数据生成过程,图像由作者提供
请注意,现在我们已经关闭了subscription和spend之间的开放路径。因此,我们可以估计订阅对销售概率的因果效应。我们来看一下。
首先,我们想了解notification是否有效。这通常被称为工具的强度。由于随机化,我们可以将收到notification的人与未收到通知的人之间的spend平均差异归因于处理本身。

通知对消费的影响,图像由作者提供
让我们可视化相应的均值差异估计。
plot_group_comparison(df, x="notification", y="spend", title="Spend", xticks=["No Notification", "Notification"])

带有和不带有通知的消费差异,图像由作者提供
看起来,收到notification的客户平均比未收到通知的客户多花费1\(*。这比我们之前估计的*11.5\)低得多。
然而,notification对spend的影响不是我们关注的重点。我们更想知道subscription对spend的影响。实际上,并非所有收到邮件的客户都会订阅新闻简报。反之,一些人即使没有通知也会订阅新闻简报。
这意味着我们刚刚计算出的效果被稀释了,因为有些人不遵守我们的激励措施,即notification。我们必须将其仅归因于因新闻简报而改变主意的客户。这些客户有多少?
让我们计算每个处理组的subscription概率。

通知对订阅概率的影响,图像由作者提供
plot_group_comparison(df, x="notification", y="subscription", title="Subscription Probability", xticks=["No Notification", "Notification"])

带有和没有通知的订阅概率,图像由作者提供
收到notification的客户的subscription概率高出17%。换句话说,似乎notification能够让17%的客户改变主意。从对照组中,我们了解到28%的人无论如何会订阅,而我们无法说服剩下的55%。
我们现在拥有进行主要分析所需的所有要素
工具变量(IV)
在这种情况下,通过一个二元工具变量,即notification,一个二元处理,即subscription决策,以及 50-50 的处理分配概率,我们可以获得对工具变量工作原理的非常简单的直观理解。
我们有四组客户,取决于他们是否收到了通知,以及他们是否订阅了。
df.groupby(["notification", "subscription"]).agg(spend=("spend", "sum"), customers=("spend", "count")).iloc[::-1].T.round(0)

按细分的消费和客户数量,图像由作者提供
让我们可视化每个组别的总消费和总客户数量。

按细分的消费和客户数量,图像由作者提供
比较处理组(notification)和对照组,我们看到通知导致了spend的5k€的增加(43 + 20 - 28 - 30)。为了恢复感兴趣的因果效应,我们只需将5k€的额外spend归因于800(2200 - 1400)个因notification而决定订阅的客户。结果正是5k€ / 800 = 6€每个客户!

按细分的消费和客户数量,图像由作者提供
更一般地,IV 估计量由两个因果效应的比率给出:工具变量(或鼓励,或分配)Z对结果Y的影响,除以工具变量Z对处理(或内生变量)W的影响。

工具变量估计器,图像由作者提供
为了计算 IV 估计量,我们将期望值替换为经验平均值。实际上,在我们的情况下,我们只是将前一部分图表中计算的两个均值差异估计值进行除法。
tau_ZY = df.loc[df.notification == 1, "spend"].mean() - df.loc[df.notification == 0, "spend"].mean()
tau_ZW = df.loc[df.notification == 1, "subscription"].mean() - df.loc[df.notification == 0, "subscription"].mean()
tau_ZY / tau_ZW
6.070222743259094
我们对 subscription 计划对 spend 影响的工具变量估计值是 6$,这在上面的图示中已经预期到!请注意,图示中的数学仅在完全 50-50 分配的特殊情况下有效。
更一般地,可以证明 IV 估计量的公式由协方差比率给出,

IV 估计量作为协方差比率,图片由作者提供
或者,使用矩阵符号表示,

IV 估计量的矩阵符号表示,图片由作者提供
IV 扩展
如果我们有更多的工具变量或其他控制变量会发生什么?例如,我们可以进行其他实验以鼓励客户 subscribe。或者,如我们的情况,我们可以添加其他变量到模型中以提高预测准确性,如先前的消费水平 spend_old。我们如何将它们包含在模型中?
长话短说,当我们有多个工具变量时,工具变量公式可以重写为

两阶段最小二乘估计量,图片由作者提供
其中 Ŵ 是 W 在 Z 上的投影,即在实践中是给定处理分配的预测处理状态。这个预测步骤称为第一阶段。这个公式应该让你想到 OLS 估计量公式。实际上,这相当于将我们的结果 Y 对预测处理 W 进行线性回归,给定分配 Z。这个步骤称为第二阶段。总体来说,由于估计过程可以分为两个独立的阶段,因此称为两阶段最小二乘(2SLS)估计量。
两阶段公式在大多数 IV 包的实现中尤为明显,在这些包中,我们将处理表示为对工具变量的回归结果。在 [IV2SLS](https://bashtage.github.io/linearmodels/iv/iv/linearmodels.iv.model.IV2SLS.html) 包中,这通过使用方括号来完成。
from linearmodels.iv.model import IV2SLS as iv
model_iv = iv.from_formula("spend ~ 1 + [subscription ~ notification]", data=df).fit()
model_iv.summary.tables[1]

IV 估计量,图片由作者提供
我们可以验证这在代数上等同于首先对 subscription 和 notification 进行回归,然后对预测的 subscription 概率进行 spend 回归。下面我们运行这两个回归并报告第二阶段的估计值。
model_1st_stage = smf.ols("subscription ~ 1 + notification", data=df).fit()
df["subscription_hat"] = model_1st_stage.predict(df)
model_2nd_stage = smf.ols("spend ~ 1 + subscription_hat", data=df).fit()
model_2nd_stage.summary().tables[1]

2SLS 估计量,图片由作者提供
系数确实是相同的!
最后,上述两阶段的公式也使得包含额外的协变量变得相当直观。我们只需将协变量添加到两个阶段。
model_1st_stage = smf.ols("subscription ~ 1 + spend_old + notification", data=df).fit()
df["subscription_hat"] = model_1st_stage.predict(df)
model_2nd_stage_x = smf.ols("spend ~ 1 + spend_old + subscription_hat", data=df).fit()
model_2nd_stage_x.summary().tables[1]

2SLS 估计量,图片由作者提供
我们可以再次验证,估计的系数是相同的。
model_2sls = iv.from_formula("spend ~ 1 + spend_old + [subscription ~ notification]", data=df).fit()
model_2sls.summary.tables[1]

2SLS 估计,图像由作者提供
将之前的消费水平纳入回归确实将标准误差从0.5降低到了0.1。
IV 的限制
在实验设置中,工具变量的主要限制,如本文分析所示,是它们估计的是一种非常“特殊”的因果效应。正如我们在前一节中看到的,我们必须通过决定订阅因为通讯的客户数量来重新缩放总体效应。这意味着我们只能估计那些遵守我们干预的客户的效果。这一类客户通常被称为合规者,相应的因果效应被称为局部平均处理效应(LATE)或合规者平均因果效应(CACE)。
不幸的是,我们无法对那些即使没有通知也订阅了通讯的客户(即总是接受者)以及那些我们无法通过通知说服的客户(即从不接受者)做出任何判断。
IV 的另一个限制涉及其假设。在上一段中,我们讨论了三类客户:合规者(我们最喜欢的)、总是接受者和从不接受者。你可能注意到这种分类隐含着第四类群体的存在:违抗者。这些客户如果我们没有收到通知的话,本来会订阅通讯。然而,由于收到通知,他们改变了主意,违背了工具变量的意图。

根据处理分配和处理状态分组,图像由作者提供
为了能够得出因果结论,我们必须假设实验中没有违抗者,否则我们的重新缩放将会错误,我们的估计也会有偏差。
另一个在背景中潜在的重要假设是通常所说的排除限制。这个假设指出,工具变量通知只通过处理变量订阅影响结果消费。在我们的设置中,一个潜在的违背情况是通知可能会唤醒处于休眠状态的用户。想象一下,一个客户想要完成一笔交易,并且已经将物品添加到购物车中,但忘记了结账。订阅通知可能会提醒用户结账,从而直接影响消费。如你所想,IV 估计会有偏差,因为我们错误地将一些销售归因于订阅,而这些销售实际上是通知本身的直接效果。
结论
在这篇文章中,我们介绍了工具变量在实验环境中的应用。当我们因为伦理、法律或技术限制而不能随机化处理时,我们仍然可以考虑随机化激励来接受处理。这使我们能够做出因果陈述,但仅针对整体人群中的一个子集,即合规者,即因激励而接受处理的客户。
重要的是要注意,工具变量也可以在观察设置中使用。然而,在这种情况下,我们之前提到的排除限制假设变得更难以证明。实际上,我们需要一个环境,在这个环境中,我们的工具变量不会通过任何其他途径影响结果。从技术上讲,排除限制假设是

排除限制,作者图片
当我们无法控制鼓励分配的设计时,这一假设更难以证明。然而,如果假设成立,它为在全新环境下进行因果推断打开了大门。
参考文献
-
Spotify (2023), A/B 测试中的鼓励设计和工具变量
-
Goldsmith-Pinkham (2021), 工具变量 视频讲座
-
Ding (2023), 因果推断的第一课程 注释
相关文章
-
DAGs 和控制变量
-
理解遗漏变量偏差
-
理解 Frisch-Waugh-Lovell 定理
代码
你可以在这里找到原始的 Jupyter Notebook:
[## Blog-Posts/notebooks/instrumental_variables.ipynb at main · matteocourthoud/Blog-Posts
代码和笔记本用于我的 Medium 博客文章。通过创建一个...
感谢阅读!
我非常感激! 🤗 如果你喜欢这篇文章并希望看到更多内容,请考虑 关注我。我每周发布一次与因果推断和数据分析相关的主题。我尽量保持帖子简洁但准确,总是提供代码、示例和模拟。
此外,简单的 免责声明:我写作是为了学习,所以错误是常见的,尽管我尽力而为。请在发现错误时告诉我。我也欢迎对新主题的建议!*
理解目标检测中的交并比(代码)
原文:
towardsdatascience.com/understanding-intersection-over-union-for-object-detection-code-9a691c72d83a
目标检测模型的评估归结为一个问题:确定检测结果是否有效。
·发表于 Towards Data Science ·阅读时间 7 分钟·2023 年 10 月 7 日
--

照片来自 Vardan Papikyan 在 Unsplash
确定检测是否有效需要理解 交并比指标(IoU)。
本文涵盖以下内容:
-
IoU 基础——什么是 IoU?
-
如何计算(理论上和 Python 代码中)单对检测和地面真实边界框的 IoU
-
计算多个预测和地面真实边界框的 IoU。
-
如何 解读 IoU 值?
什么是交并比(IoU)?
IoU 是评估目标检测模型的核心指标。它通过评估检测框和地面真实框之间的重叠程度来衡量对象检测器的准确性。
-
地面真实框 或 标签 是一个标注框,显示了对象的位置(标注通常是手工完成的,地面真实框被认为是对象的实际位置)。
-
检测框 或 预测边界框 是来自对象检测器的预测。
从形式上讲,IoU 是地面真实框(gt)和预测框(pd)的交集面积除以两个框的并集面积。

IoU 定义(作者提供的图像)。
示例 1:计算单对检测和地面真实的 IoU
让我们从一个简单的例子开始。计算一个检测和一个地面真实的 IoU。
为此,我们需要两个框的左上角(x1, y1)和右下角(x2, y2)坐标。
在下图(右)中,我们有两个边界框:
预测边界框 (p-box): (px1, py1, px2, py2) = (859, 31, 1002, 176)
真实边界框 (t-box): (tx1, ty1, tx2, ty2) = (860, 68, 976, 184)

左:包含 14 个真实值(蓝色框)和 12 个预测值(红色框)的图像。右:单个真实框和预测框的放大视图(标注由作者完成,果园图像来源于 zenodo.org/record/3712808)。
重要: 在计算机视觉中,惯例是:
-
x 轴 是图像的水平维度,从左到右值逐渐增大。
-
y 轴 是图像的垂直维度,从上到下值逐渐增大(这与标准的笛卡尔坐标系统不同)。
步骤 1:计算两个盒子的面积
这个步骤计算预测框和真实框的面积。它仅是长度乘以宽度。
predicted_area = (1002–859) * (176–31) = 20735 ground_truth_area = (976–860) * (184–68) = 13456
步骤 2:找到交集点
这个步骤用于找到交集区域的 top-left(A) 和 bottom-right(B) 坐标。
这可以通过以下方法找到:
top_left = max(px1, tx1), max(py1, ty2)
bottom_right = min(px2, tx2), min(py2, ty2)
在我们的案例中,
top_left (A) = max(859, 860), max(31, 68) = (860, 68) bottom_right (B) = min(1002, 976), min(176, 184) = (976, 176)

显示两个盒子之间的交集区域(图片作者提供)。
步骤 3:计算交集面积
由于我们有交集点,我们可以轻松地计算交集矩形的面积。
intersection_area = (976–860) * (176–68) = 12528
步骤 4:计算 IoU 值
IoU = 交集面积 / 并集面积,
其中 union_area 是两个盒子面积的总和减去交集面积。即,
union_area = (真实边界框面积 + 预测框面积) - 交集面积 = (20735+13456) — 12528 = 21663
因此,
IoU = 12528/21664 = 0.578286558
让我们将其转换为 Python 代码
以下代码可用于计算单对真实框和预测框的 IoU。代码片段之后,让我们深入探讨所使用的思路。
import numpy as np
def compute_iou(box1, box2):
"""
This function computes the intersection-over-union of two boxes.
Both boxes are expected to be in (x1, y1, x2, y2) format.
where (x1, y1) is the top_left coordinates and
(x2, y2) is the bottom_right coordinates
Arguments:
box1 4 by 1 NumPy Array: The first box.
box2 4 by 1 NumPy Array: The second box.
Returns:
iou (float): The intersection-over-union value for the two boxes.
"""
# Calculate the area of each box
area1 = np.prod(box1[2:] - box1[:2])
area2 = np.prod(box2[2:] - box2[:2])
print("Area of box 1 and box2, respectively: ", area1, area2)
# Calculate the intersection coordinates (top left and bottom right)
top_left = np.maximum(box1[:2], box2[:2])
bottom_right = np.minimum(box1[2:], box2[2:])
print("Top left and bottom right of intersection rectangle: ", top_left, bottom_right)
# Calculate the intersection area
intersection = np.prod(np.clip(bottom_right - top_left, a_min=0, a_max=None))
print("Intersection area: ", intersection)
# Calculate the union area
union = area1 + area2 - intersection
print("Union area: ", union)
# Calculate the IoU
iou = intersection / union if union > 0 else 0.0
return iou
# Calling compute_iou with overlapping boxes
detection = np.array([859, 31, 1002, 176])
label = np.array([860, 68, 976, 184])
iou_value = compute_iou(detection, label)
print("IoU:", iou_value)
输出:
Area of box 1 and box2, respectively: 20735 13456
Top left and bottom right of intersection rectangle: [860 68] [976 176]
Intersection area: 12528
Union area: 21663
IoU: 0.5783132530120482
让我们第二次调用 compute_iou() 函数,使用不重叠的盒子。
# Calling compute_iou with non-intersecting boxes
detection = np.array([810, 744, 942, 865])
label = np.array([109,563,217,671])
iou_value = compute_iou(detection, label)
print("IoU:", iou_value)
输出:
Area of box 1 and box2, respectively: 15972 11664
Top left and bottom right of intersection rectangle: [810 744] [217 671]
Intersection area: 0
Union area: 27636
IoU: 0.0
代码分解:
-
NumPy 向量化 使我们能够对数组执行
np.prod(),np.maximum(),np.minimum(),np.clip(),addition和subtraction等操作,而无需循环遍历数组元素或索引单个元素。 -
np.clip()函数限制或“剪切”数组中的值在指定范围内。在我们的案例中,np.clip(bottom_right — top_left, a_min=0, a_max=None)通过将负值设为 0 来确保交集的宽度和高度值为非负数。
示例 2:计算多个真实值和预测框对的 IoU
在这个例子中,我们想要计算下图中所有真实值和预测框对的 IoU 值(最左侧)。
图像包含 12 个预测(红色框)和 14 个真实值(蓝色框)。

左:包含真实值和检测的图像,中间:预测,右:真实值(作者标注,果园图像来源于 zenodo.org/record/3712808)。
通过修改初始代码,可以轻松计算所有检测和真实值对的 IoU,如下所示。
def compute_ious(boxes1, boxes2):
"""
This function computes intersection-over-union of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format
where (x1,y1) is the top-left coordinates and
(x2,y2) is the bottom right coordinates
Arguments:
boxes1: M by 4 NumPy array
boxes2: N by 4 NumPy array
Returns:
iou MxN Numpy Matrix - containing the pairwise
IoU values for every element in boxes1 and boxes2
"""
# Compute area for all combination of boxes in boxes1 and boxes2
area1 = np.prod(boxes1[:, 2:] - boxes1[:, :2], axis=1)
area2 = np.prod(boxes2[:, 2:] - boxes2[:, :2], axis=1)
# Top left and bottom right of the intersection for all box pairs
top_left = np.maximum(boxes1[:, None, :2], boxes2[:, :2]) # NxMx2 Array
bottom_right = np.minimum(boxes1[:, None, 2:], boxes2[:, 2:]) # NxMx2 Array
# Compute intersection for all box pairs
intersection = np.prod(np.clip(bottom_right - top_left, a_min=0, a_max=None), 2)
return intersection / (area1[:, None] + area2 - intersection)
# Define detections and ground_truths
detections = np.array([[374,627,538,792],
[330,308,501,471],
[474,14,638,181],
[810,744,942,865],
[58,844,204,993],
[905,280,1022,425],
[887,412,1018,543],
[0,871,68,1008],
[859,31,1002,176],
[698,949,808,1023],
[0,400,47,505],
[234,0,314,58]])
ground_truths = np.array([[331,303,497,469],
[385,624,543,782],
[809,743,941,875],
[883,410,1024,556],
[918,287,1024,425],
[860,68,976,184],
[109,563,217,671],
[0,401,60,515],
[51,833,207,989],
[0,867,80,1024],
[273,877,403,1007],
[701,939,821,1024],
[905,608,1021,724],
[471,17,629,175]])
# Call compute_ious() function
ious = compute_ious(detections, ground_truths)
print(ious)
输出(为更好地查看而格式化):

compute_ious() 函数的输出(图像由作者提供)
输出显示:
-
一个检测(索引 12)没有与任何真实值重叠。
-
3 个真实值与任何检测没有重叠——位于索引 7、11 和 13 的真实值。
-
有 11 个检测的 IoU > 50% 与真实值。
让我们详细分析一段可能不太清楚的代码:
- 在上述代码中,
None在索引操作中的使用是 NumPy 的一种技术,用于在数组中引入新的轴或维度。它通常用于广播不同形状的数组,以执行逐元素操作或启用某些计算。
IoU 值的解释
IoU 值范围在 [0, 1] 之间,其中 0 表示两个框之间没有重叠,而 1 表示完全重叠。
根据你的应用,你可以设置一个 IoU 阈值来确定什么是好的检测。
希望这篇文章能让 IoU 概念更清晰。
祝好运!
在因果推断中理解治疗加权的逆概率 (IPTW)
对 IPTW 的直观解释及其与多元回归的比较
·
关注 发布于 Towards Data Science ·14 分钟阅读·2023 年 1 月 11 日
--
照片由 Nadir sYzYgY 提供,来源于 Unsplash
介绍
在本帖中,我将提供逆概率治疗加权(IPTW)的直观和插图解释,IPTW 是各种倾向评分(PS)方法之一。IPTW 是多元线性回归在因果推断背景下的替代方法,因为两者都试图在混杂因素存在的情况下确定治疗对结果的影响。需要注意的是,目前的证据并不支持 IPTW 优于多元线性模型的说法(Glynn et al., 2006)。然而,IPTW 确实具有某些理论和实践上的好处,我们将在本帖中回顾这些内容。
在撰写时,查询“propensity scor*”在 PubMed 中已经识别出近 45,000 条引用(PubMed 查询)。根据这个标准,2000 年有 45 条引用,而 2022 年有 8,929 条引用,并且在此期间每年的引用数量都在增加(PubMed 查询)。这种受欢迎程度的增加需要对方法论进行直接的解释。
治疗加权的逆概率:解释
随机对照试验
如果我们想确定治疗对某些可测量结果的影响,金标准方法是随机对照试验 (RCT)。在 RCT 中,治疗是随机分配给个体的。在样本量足够大的试验中,治疗在可能影响试验结果的所有测量和未测量变量中随机分配(Hariton et al., 2018)。这些变量在本帖的其余部分将被称为协变量。这种设置使研究人员能够最接近地估计治疗对结果的因果影响。值得注意的是,即使 RCT 本身也不太可能证明因果关系,但它们确实提供了最强的证据。
简单的观察性示例
让我们首先设置一个简单的示例,其中包含接受过治疗的受试者、他们的性别和他们的结果。我们的目标是确定治疗对结果的影响。在这个玩具示例中,我们假设数据包含 8 名参与者,4 名男性和 4 名女性。此外,治疗给予了 4 名男性中的 2 人以及 4 名女性中的 2 人,如图 1 所示。

图 1:简单示例
在这种情况下,知道受试者的性别对受试者是否接受治疗没有任何信息提供。接受治疗的整体概率是 50%。如果受试者是男性,接受治疗的概率是 50%。如果受试者是女性,接受治疗的概率也是 50%。换句话说,性别与治疗之间没有相关性。图 2 显示了有向无环图 (DAG),该图显示了所谓的因果方向。

图 2:因果 DAG
这个 DAG 可以解释为显示治疗和性别对患者结果的影响。然而,由于性别不会影响治疗是否被施用,因此性别与治疗之间没有箭头。这正是 RCT(随机对照试验)在大样本中本质上保证的。事实上,这种保证适用于所有可能的协变量,即使是那些未被测量的协变量。
一个带有混杂因素的观察示例
现在我们将修改示例,展示如果性别突然与治疗的施用相关会发生什么。图 3 显示女性接受治疗的概率为 75%,而男性只有 25% 的概率接受治疗。

图 3:一个现实的示例
接受治疗的整体概率仍然是 50%。但是,现在知道受试者的性别会提供关于受试者是否接受治疗的额外信息。性别和治疗不再是独立的,因为接受治疗的概率(50%)不等于受试者是女性(75%)或男性(25%)时接受治疗的概率。这被称为选择偏差,因为未能实现性别上的随机化。现在,性别成为一个混杂因素,这意味着它影响独立变量(治疗)和依赖变量(结果),这会妨碍我们直接测量治疗对结果的影响。图 4 显示了更新后的 DAG,其中有一条从性别到治疗的箭头。这条箭头表示了本节中描述的选择偏差。换句话说,受试者的性别影响受试者是否接受治疗,从而创建了统计混杂。

图 4:带有混杂因素的 DAG
处理混杂因素的方法:多元线性回归
这时我们开始考虑用于应对混杂的工具,包括多变量线性回归。我不会在这里深入探讨线性回归的细节,因为它可能是 IPTW 的前提。如果你不熟悉线性回归并希望深入了解,我强烈推荐 Richard McElreath 的 统计重思讲座和教科书。若需更快的解释,StatQuest 在 YouTube 上有一个很好的 线性回归解释视频。
简而言之,如果我们创建一个线性回归模型,形式为,

我们将闭合性别 → 治疗 → 结果的路径。这就是我们所说的“控制”性别。这将使我们得到一个无偏的治疗效果估计,因为我们有效地去除了性别混杂因素带来的选择偏差。
关于混杂因素该怎么办:IPTW
IPTW 是一种去除混杂因素影响的替代统计方法。IPTW 的高级思路是创建个体观察的副本,以便在副本创建之后,混杂因素与感兴趣的治疗之间不再存在关系。其效果是将数据转化为大致随机。我们计算每个观察的副本数量的方法将是本节的其余部分讨论的主题。
我将开始用文字解释如何计算副本数量,这些将从现在开始被称为权重。如果对程序的解释不清楚,图 5 将提供权重方案的直观视觉解释。
计算这种加权的机制如下:
-
对于每个观察 i,找到其进入治疗组的概率 p(Chesnaye et al., 2022 para 9)。这就是逆概率治疗加权中的“治疗概率”来源。
-
计算个体观察的权重 w 为 1/p。这就是逆概率治疗加权中“逆”一词的来源(Chesnaye et al., 2022 para 9)。
-
使用这些权重创建“副本”。
我们将计算在我们例子中接受治疗的女性的权重。首先,我们需要找到治疗组中每个女性接受治疗的概率。由于 4 名女性中有 3 名接受了治疗,因此我们知道这个概率是 75%。然后,我们通过倒数这个概率来计算这三名女性的权重。因此,1/0.75 等于 1.333。最后,我们使用这个权重创建“副本”。因为我们有 3 名女性,所以 3 x 1.333 = 4。换句话说,我们将最终得到 4 名女性。有关此过程的清晰视觉解释,请参见图 5。

图 5:计算 IPTW 权重
这个过程使得某些观察的相对重要性增加。其效果是既增加了样本量,又平衡了协变量。我们称之为伪人口,因为我们实际上通过使用这种加权方案来向样本中添加个体。图 6 展示了使用这些权重对伪人口的影响。

图 6:协变量平衡的伪人口
使用这些权重的效果是通过以一种使处理不再依赖于混杂因素的方式结构化伪人口,从而控制混杂变量。在这个伪人口中,知道一个受试者的性别不再提供关于受试者是否接受了处理的任何信息。这就是我们所说的平衡协变量。
现在,如果我们重新绘制如图 7 所示的因果 DAG,我们将移除性别 → 处理的箭头。性别影响结果,但不再影响处理。因此,我们已经去除了混杂因素。

图 7:无混杂因素的 DAG
休息和退出通道
如果你已经看到了这里,做得很好。这是一个很好的停止点。你现在已经有了理解 IPTW 的坚实概念基础。接下来的两个部分将基于这个基础,介绍 IPTW 中的两个稍微复杂一些的主题。它们将包括:
-
稳定化的 IPTW,以及
-
计算倾向评分
如果你想跳过这些部分,可以随意。不过,阅读将 IPTW 与传统多变量模型进行比较和结论部分可能是值得的。
稳定化的 IPTW
在 IPTW 示例中,回顾一下我们如何将有效样本量从 N=8 增加到 N=16。这个内容将在图 8 中总结。

图 8:增加样本量
尽管我们不再有不平衡的协变量,但我们引入了新的困境。随着样本量的增加,统计测试更可能发现效果。这是由于中心极限定理的性质。较大的样本意味着我们对样本应用的统计测试具有更大的统计功效。通过人为地将样本量加倍,我们实际上增加了发现我们的处理对结果有影响的概率(Xu et al., 2010)。这是一种称为重复抽样的现象。
为了说明重复抽样为何成问题,考虑两个理论上公平的硬币分别被掷 4 次和 8 次。在这个例子中,每个硬币在每次掷出时都产生了正面。第一个硬币产生 4 个正面的概率是 6.25%,而第二个硬币产生 8 个正面的概率是 0.39%。

我们为两个不同硬币计算的概率类似于 p 值。它们表示在硬币公平的情况下事件发生的概率。现在我们想测试在观察到的数据存在的情况下公平硬币的断言是否成立。我们将使用 假设检验,其中 零假设 是硬币是公平的,而 备择假设 是硬币是有偏的。
考虑产生 4 次正面的硬币。如果我们假设零假设为真,这个事件的概率(p 值)为 6.25%。这个 p 值通常不会提供令人信服的证据证明硬币有偏。通常,我们会要求 p 值小于 5%。现在,假设我们通过将每次投掷的权重值设为 2(这与 IPTW 的工作方式直接类似)来人工将样本量翻倍。我们知道 8 次硬币投掷都出现正面的概率为 0.39%,这足以声称硬币有偏并拒绝零假设。然而,我们获得的数据仅包含 4 次硬币投掷的信息。因此,我们夸大了我们拒绝零假设(硬币公平)的概率,以支持备择假设(硬币有偏),这也被称为 第一类错误。这正是我们在使用 IPTW 时遇到的问题。
为了修正这种人工增加的样本量,我们将引入稳定化 IPTW。简单来说,计算权重时

我们将计算权重如下

图 9 将展示我们如何计算稳定化加权方案中的分子。

图 9:稳定化 IPTW 分子
图 10 将展示我们如何更新加权方案,以便不将伪人群的大小增加到比原始数据中的实际人群大得多。

图 10:计算稳定化 IPTW 权重
使用这种稳定化加权方案的效果是,伪人群不再比原始人群大那么多,如图 11 所示。

图 11:稳定化协变量平衡伪人群
由于我们不再增加伪人群的大小,相较于原始人群,第一类错误(假阳性)的概率不会被夸大。
计算倾向评分
计算受试者接受治疗的概率,也称为倾向性,往往不像前面的例子中那样简单。为了说明这一点,让我们在例子中添加一个额外的协变量——年龄,看看结果如何。

图 12:包含两个混淆因素的 DAG
快速检查这个因果 DAG,我们注意到性别仍然混淆了治疗对结果的影响,就像我们在前面的例子中看到的那样。此外,年龄作为混淆因素被添加进来。不幸的是,由于年龄是一个连续变量,我们不能像之前那样绘制治疗概率图。实际上,我们需要一种全新的方法来计算倾向性。
这就是我们将利用逻辑回归的地方。在这篇文章中,我不会深入探讨逻辑回归的工作原理。如果你对逻辑回归不熟悉,我建议观看关于逻辑回归的 StatQuest 视频以获得一个非常易于理解的概述。关键点是,我们可以使用逻辑回归来计算在给定协变量(性别和年龄)的情况下接受治疗的倾向分数。
一旦我们使用逻辑回归计算倾向分数并重新加权数据,检查加权协变量的分布以确保它们平衡是至关重要的。由于使用逻辑回归估计倾向分数引入了额外的复杂性,我们需要检查拟合优度(Borah et al., 2013)。这将简单地涉及检查接受治疗和未接受治疗者的年龄和性别分布是否大致相似。
比较 IPTW 与传统的多元模型
正如介绍中提到的,因果推断的金标准是 RCT。然而,在现实世界中,构建一个完整的 RCT 并不总是可行。因此,我们只能使用接近 RCT 的统计技术,包括多元线性回归或倾向性评分模型,如 IPTW。IPTW 的优点在于它试图在观察到的协变量之间创建平衡,而这正是 RCT 所保证的。相反,多元线性回归并不试图平衡协变量。然而,“没有证据表明,相较于多元模型中的常规估计,利用倾向评分的分析将显著减少混杂带来的偏差”(Glynn et al., 2006 para 31)。
尽管 IPTW 在理论上相对于线性回归具有一些优势,但“在实际使用中,倾向评分和回归模型估计之间的答案几乎没有显著差异”(Glynn et al., 2006 para 8)。
自然地,为什么研究人员会选择使用 IPTW 而不是线性回归的问题就出现了。我将在下面简要回顾一些这些原因。
1. PS 方法允许研究人员使用有原则的方法来修剪研究人群。

图 13:暴露倾向评分 — 来源: Glynn 等人, 2006。
在图 13 中,虚线曲线表示未接受治疗的个体的倾向评分分布。实线曲线表示接受治疗的个体的倾向评分分布。在未治疗分布的左尾和已治疗分布的右尾,注意是否有个体从未接受治疗或总是接受治疗。将这些个体从研究人群中剔除带来了理论上的好处,因为这些观察结果“可能在多变量分析中产生不当影响和问题,因为[治疗组]与[未治疗组]之间的协变量重叠极少”(Glynn 等人, 2006 第 16 段)。
2. PS 方法可以阐明治疗如何与接受治疗的倾向相互作用。
通过按倾向评分分层学科,可以识别治疗的有效性是否因每个受试者所在的倾向评分层次而有所不同。
3. PS 校准可以提高主研究的稳健性。
考虑一个例子,我们有两个研究:一个主研究和一个验证研究。两个研究都旨在评估相同治疗对结果的影响。主研究的样本量非常大,而验证研究较小。在主研究中,有一些预测变量由于未被测量而被遗漏。因此,主研究的倾向评分将受到遗漏变量偏差的影响。然而,如果验证研究包含了更多详细的预测变量来修正主研究的遗漏变量偏差,那么验证研究“可能会提供更可靠的倾向评分估计”(Glynn 等人, 2006 第 20 段)。然后可以使用验证研究来校准主研究中的倾向评分,从而减少偏差(Stürmer 等人, 2005*)。
结论
总结一下,我们回顾了 IPTW 的机制。IPTW 的主要目标是确保协变量在治疗组之间平衡,从而尽可能减少由测量的协变量引起的混杂。此外,我们回顾了 PS 方法中的两个更复杂的话题,包括稳定权重和计算倾向评分。最后,我们简要讨论了使用 IPTW 而不是多变量线性回归的一些理论好处。本文的主要目的是让读者了解 IPTW 的工作原理,因为在过去 20 年里,作为一种统计方法,它的使用显著增加。我希望这个概述对你有所帮助!
除非另有说明,所有图片均由作者提供。
参考文献
[1] B. Borah, J. Moriarty, W. Crown 和 J Doshi,倾向评分方法在观察性比较效果和安全性研究中的应用:我们已经取得了什么进展,还应该去向何处? (2013),《比较效果研究杂志》
[2] E. Hariton 和 J. Locascio,随机对照试验——效果研究的金标准 (2018),《英国妇产科杂志》
[3] N. Chesnaye, V. Stel, G. Tripepi, F. Dekker, E. Fu, C. Zoccali, K. Jager, 观察研究中治疗权重的逆概率介绍 (2022),《临床肾脏杂志》
[4] R. Glynn, S. Schneeweiss 和 T. Stürmer,倾向评分的指示及其在药物流行病学中的使用回顾 (2006),《基础与临床药理学与毒理学》
[5] S. Xu, C. Ross, M. Raebel, S. Shetterly, C. Blanchette 和 D. Smith,使用稳定逆倾向评分作为权重直接估计相对风险及其置信区间 (2010),《健康价值》
[6] T. Stürmer, S. Schneeweiss, J. Avorn 和 Robert J Glynn,使用倾向评分校准的验证数据调整未测量混杂因素的效果估计 (2005),《美国流行病学杂志》
理解 KL 散度

图片由作者提供
KL 散度的数学、直观理解和实际应用指南——包括如何在漂移监测中最佳使用它
·
关注 发表在 Towards Data Science ·7 分钟阅读·2023 年 2 月 2 日
--
Kullback-Leibler 散度度量(相对熵)是信息理论中的一种统计测量,通常用于量化一个概率分布与参考概率分布之间的差异。
尽管 KL 散度很受欢迎,但有时会被误解。在实践中,有时也很难知道何时应选择一种统计距离检查而非另一种。
本博客介绍了如何使用 KL 散度,它在实践中的工作原理,以及何时应使用或不应使用 KL 散度来监测漂移。
你如何计算 KL 散度?
KL 散度是一种非对称度量, 测量相对熵 或两个分布所表示的信息差异。它可以被视为测量两个数据分布之间的距离,显示这两个分布彼此之间的差异。
KL 散度有连续形式

作者提供的图片
以及离散形式的 KL 散度:

作者提供的图片
在模型监控中,大多数实践者几乎专门使用离散形式的 KL 散度,并通过 数据分箱 来获得离散分布。离散形式的 KL 散度与连续形式在样本数量和箱子限制趋向于无限时会收敛。存在最佳的 选择方法 来接近连续形式。在实践中,箱子的数量可能远少于上述数字所示的数量——如何创建这些箱子以处理零样本箱的情况在实际操作中比其他任何事情更为重要(未来的文章将用代码处理如何自然处理零箱问题)。
KL 散度如何用于模型监控?
在模型监控中,KL 散度用于监控生产环境,特别是围绕特征和预测数据。KL 散度被用来确保生产中的输入或输出数据不会与基准发生剧烈变化。基准可以是训练生产数据窗口,或者是训练或验证数据集。
漂移监控对于那些接收延迟的真实数据来与生产模型决策进行比较的团队尤其有用。这些团队可以依赖预测和特征分布的变化作为性能的代理。
KL 散度通常应用于每个特征,独立地进行计算;它不是作为协方差特征度量的设计,而是显示每个特征如何独立于基准值发生偏离的度量。

作者提供的图片
上面橙色条纹中显示的 p(x) 是参考或基准分布。最常见的基准要么是生产数据的滞后窗口,要么是训练数据集。每个箱子会累加地贡献于 KL 散度。这些箱子共同加总到总的百分比分布中。
✏️ 注意😗 有时非实践者可能有一种过于热心的目标,即完美捕捉数据变化的数学方法。在实践中,需要记住,实际数据在生产中总是不断变化的,许多模型对这种修改后的数据有很好的适应性。使用漂移度量的目标是拥有一个稳固、稳定且非常有用的度量,以便进行故障排除。
KL 散度是非对称度量吗?
是的。如果你交换基线分布 p(x) 和样本分布 q(x),你会得到不同的数字。作为一个非对称度量,KL 散度在团队使用它进行数据模型比较时有一些缺点。有时团队希望在故障排除工作流程中用不同的分布替换比较基线,而拥有一个A / B 与B / A 不同的度量可能会使结果比较变得困难。

作者提供的图像
这也是模型监控工具如 Arize(完全披露:我共同创立了 Arize)将总体稳定性指数(PSI),KL 散度的对称衍生物,作为模型监控分布的主要指标之一的原因。
连续数值特征与分类特征之间的差异
KL 散度可以用于测量数值分布与分类分布之间的差异。

作者提供的图像
数值
对于数值分布,数据根据切割点、箱体大小和箱体宽度被分成多个箱。箱体策略可以是均匀箱体、分位数或者复杂的混合策略,这会对 KL 散度产生显著影响。
分类
KL 散度的监测跟踪分类数据集中的大规模分布变化。对于分类特征,通常有一个大小,使得基数变得太大,导致度量的实用性降低。理想的大小是约 50–100 个独特值——当分布具有更高基数时,两个分布之间的差异及其重要性的问题会变得模糊。
高基数
在高基数字特征监测的情况下,现成的统计距离通常效果不好——我们通常推荐两个选项:
-
嵌入:在一些高基数的情况下,使用的值——如用户 ID 或内容 ID——已经被用于内部创建嵌入。嵌入漂移监测可以提供帮助。
-
纯高基数分类:在其他情况下,当模型将输入编码到较大的空间时,使用 KL 散度监测前 50–100 个顶级项以及所有其他值作为“其他”可能会很有用。
最后,有时你想监测的是某个特定的内容,如在某个周期内的新值或箱体的百分比。这些可以通过数据质量监控工具进行更具体的设置。
KL 散度示例
这是一个KL 散度的示例。

作者提供的图像
想象一下,我们有一个用于预测信用卡欺诈的模型的数值分布。该模型是基于上图所示的基准进行训练的。我们可以看到收费的分布发生了变化。行业标准对阈值有一些规定,实际上我们建议使用生产中滞后的值来设置自动阈值。在生产中,许多固定设置的<>并不合适。
KL 散度的直觉
对于度量及其基于分布变化的变化,拥有一定的直觉非常重要。

作者提供的图像
上述例子展示了从一个类别分箱移动到另一个分箱的情况。以“医疗”作为特征(贷款用途)的预测从 2% 增加到 8%,而以“度假”为特征的预测从 23% 减少到 17%。
在这个例子中,与“医疗”相关的 KL 散度组件是 -0.028,且小于“度假”百分比变化的 0.070 组件。
一般来说,减少百分比并将其向 0 移动的变化对该统计量的影响大于百分比的增加。

作者提供的图像
相对于行业标准 0.2,要获得较大的变动,唯一的方法是将一个分箱向 0 移动。在这个例子中,将一个分箱从 9% 移动到 0.5% 会大幅改变 KL 散度。
这里是一个电子表格,供那些想要操作和修改这些百分比以更好地理解直觉的人使用。值得注意的是,这种直觉与 PSI 的直觉非常不同。
结论
如果你考虑使用 KL 散度来衡量漂移,需要记住几件重要的事情。首先,大多数实践者发现使用 KL 散度的离散形式最为简单,并通过对数据进行分箱来获取离散分布(有关分箱挑战的更多信息将在未来的帖子中介绍)。其次,虽然理解 KL 散度背后的直觉和数学很重要,但有时其他度量如 PSI —— KL 散度的对称推导 —— 或其他方法可能更适合你的用例。
勇敢学习机器学习:揭示 L1 和 L2 正则化(第一部分)
理解 L1 和 L2 正则化的基本目的
·
关注 发表在 Towards Data Science · 6 分钟阅读 · 2023 年 11 月 22 日
--
欢迎来到“勇敢学习机器学习”,我们将从探索 L1 和 L2 正则化开始。本系列旨在简化复杂的机器学习概念,以轻松和信息丰富的对话呈现,类似于 “勇敢做自己” 的风格,但聚焦于机器学习。
这些问答环节反映了我个人的学习历程,我很高兴与你分享。把它看作是一个记录我深入机器学习的博客。你的互动——点赞、评论和关注——不仅仅是支持,它们还是推动这一系列内容继续下去和我分享过程的动力。
今天的讨论不仅仅是回顾 L1 和 L2 正则化的公式和性质。我们将深入探讨为什么在机器学习中使用这些方法的核心原因。如果你希望真正理解这些概念,你来对地方了,这里有一些启发性的见解!
在这篇文章中,我们将回答以下问题:
-
什么是正则化?我们为什么需要它?
-
什么是 L1 和 L2 正则化?
-
为什么我们更倾向于使用小系数而不是大系数?大系数如何导致模型复杂度增加?
-
为什么神经网络中存在多种权重和偏置组合?
-
为什么在 L1 和 L2 正则化中偏置项不受惩罚?
什么是正则化?我们为什么需要它?
正则化是机器学习中的一个基础技术,旨在防止模型过拟合。过拟合发生在一个模型,通常是过于复杂,不仅从训练数据中的潜在模式(信号)中学习,还会捕捉并放大噪声。这会导致模型在训练数据上表现良好,但在未见过的数据上表现较差。
什么是 L1 和 L2 正则化?
有多种方法可以防止过拟合。L1 和 L2 正则化主要通过在模型的损失函数上添加惩罚项来解决过拟合问题。这种惩罚项阻止模型对任何单一特征(由大系数表示)赋予过多的重要性,从而简化模型。本质上,正则化保持模型平衡,专注于真正的信号,提高其对未见数据的泛化能力。
等等,我们为什么要对模型中的大权重施加惩罚?大系数如何导致模型复杂度增加?
虽然有很多组合可以最小化损失函数,但并非所有的组合都对泛化效果同样良好。大系数往往会放大数据中的有用信息(信号)和不必要的噪声。这种放大使得模型对输入中的小变化非常敏感,导致它过度强调噪声。因此,模型在新数据上的表现不佳。
相反,小系数有助于模型专注于数据中更重要、更广泛的模式,从而减少对细微波动的敏感性。这种方法促进了更好的平衡,使模型能更有效地泛化。
举个例子,假设一个神经网络被训练用来预测猫的体重。如果一个模型的系数为 10,而另一个模型的系数大得多为 1000,那么它们对下一层的输出将大相径庭——分别为 300 和 30000。系数较大的模型更容易做出极端预测。在 30 磅作为异常值(对猫来说非常不寻常!)的情况下,第二个系数较大的模型会产生明显不准确的结果。这个例子说明了调节系数的重要性,以避免对数据中的异常值做出夸张的反应。
你能详细解释一下为什么神经网络中有多个权重和偏置组合吗?
想象一下在神经网络损失函数的复杂地形中导航,你的任务是找到最低点,即‘最小值’。你可能会遇到以下情况:

照片由Tamas Tuzes-Katai拍摄,发布在Unsplash上。
-
多重目的地的风景:当你穿越这片风景时,你会发现它充满了各种局部最小值,就像一个具有许多凹陷和山谷的非凸地形。这是因为具有多个隐藏层的神经网络的损失函数本质上是非凸的。每个局部最小值代表了不同的权重和偏置组合,提供了多种潜在的解决方案。
-
多条路径到达同一目的地:网络的非线性激活函数使其能够形成复杂的模式,近似数据的实际基础函数。通过多个这些函数的层,有许多方式来表示同一真相,每种方式由一组独特的权重和偏置特征。这就是网络设计中的冗余。
-
序列的灵活性:想象一下改变你的旅程顺序,比如交换骑车和乘公交车的顺序,但仍然到达同一目的地。与具有两个隐藏层的神经网络相关:如果你在第一层中将权重和偏置加倍,然后在第二层中将它们减半,最终输出保持不变。(请注意,这种灵活性主要适用于具有某些线性特征的激活函数,如 ReLU,而不适用于 sigmoid 或 tanh 等其他函数)。这一现象被称为神经网络中的‘尺度对称’。
我一直在阅读 L1 和 L2 正则化,并观察到惩罚项主要集中在权重上,而不是偏置上。但这是为什么呢?难道偏置不是也可以被惩罚的系数吗?

简而言之,像 L1 和 L2 这样的正则化技术的主要目标是通过调节模型权重的大小来防止过拟合(个人认为这就是我们称其为正则化的原因)。相比之下,偏差对模型复杂性的影响相对较小,因此通常不需要对其施加惩罚。
为了更好地理解,让我们来看一下权重和偏差的作用。权重决定了模型中每个特征的重要性,影响模型的复杂性以及在高维空间中的决策边界的形状。可以把它们看作是调整模型决策过程形状的旋钮,影响模型的复杂程度。
然而,偏差具有不同的作用。它们像线性函数中的截距一样,独立于输入特征来调整模型的输出。
这里是关键点:过拟合主要发生在特征之间的复杂相互作用中,而这些相互作用主要由权重处理。为了解决这个问题,我们对权重施加惩罚,调整每个特征的重要性以及模型从中提取的信息量。这反过来会重塑模型的格局,从而改变其复杂性。
相比之下,偏差对模型复杂性的影响不大。此外,偏差可以随着权重的变化而调整,从而减少了对单独偏差惩罚的需求。
现在你已经对权重和偏差的多个集合以及偏好较小权重的原因有了了解,我们可以深入探讨。
加入我在 系列的第二部分 中,我将揭示 L1 和 L2 正则化背后的层次,并通过拉格朗日乘子提供直观的理解(不用担心名字,这个概念很简单 😃)
到时见!
如果你喜欢这篇文章,你可以在 LinkedIn 上找到我,随时欢迎你联系我或者提出问题和建议!
理解大型语言模型:(Chat)GPT 和 BERT 的物理学
从物理学家的角度探讨粒子和力量如何帮助我们理解大型语言模型(LLMs)。
·发表于 Towards Data Science ·12 分钟阅读·2023 年 7 月 20 日
--

ChatGPT 和冰晶之间可能有更多的相似之处(来源:15414483@pixabay)
ChatGPT,或者更广泛的说,大型语言 AI 模型(LLMs),已经在我们的生活中无处不在。然而,LLMs 的大部分数学和内部结构对于普通大众来说是模糊的知识。
那么,我们如何才能超越将 LLMs,如 ChatGPT,视为神秘黑箱的观念呢?物理学或许能提供答案。
每个人对我们的物理世界都有一定的了解。像汽车、桌子和行星这样的物体由数万亿个原子组成,受一套简单物理法则的支配。同样,像 ChatGPT 这样的复杂生物体也已经出现,并能够生成像艺术和科学这样高度复杂的概念。
结果表明,LLMs 的构建模块的方程类似于我们的物理法则。因此,通过理解复杂性如何从简单的物理法则中产生,我们可能能够获得一些关于 LLMs 工作原理和原因的见解。
从简单中看复杂

复杂的结构,如气泡膜和其中的对流单元,是由简单的物理法则生成的(照片来源:chuttersnap 在 Unsplash)
我们的世界本质上是复杂的,但它可以通过极少量的基本相互作用来描述。例如,复杂的雪花和气泡膜可以与分子之间简单的吸引力联系起来。
那么,复杂结构产生的共同点是什么?在物理学中,复杂性是当我们从最小尺度放大到最大尺度时产生的。
以语言为类比,英语从有限数量的基本成分开始——26 个符号。这些符号可以组合成大约 100,000 个可用的单词,每个单词都携带独特的意义。从这些单词中,可以生成无数的句子、段落、书籍和卷册。
这种语言学层次结构类似于物理学中的层次结构。我们目前的基本法则(标准模型)以有限数量的基本粒子开始,如夸克和电子,以及由光子等力子介导的一些相互作用。这些粒子和力结合形成原子,每种原子具有独特的化学性质。从这些原子中,产生了大量的分子、结构、细胞和生物。
在我们的物理世界中,有一种涌现的普遍性:尽管许多复杂系统有着完全不同的起源,但它们常常共享一些普遍特征。例如,许多液体尽管具有不同的化学性质,却共享三种共同的相态(液态、固态和气态)。作为一个更极端的例子,某些材料的物理学(I 型超导体)可以借用来描述基本物理学(著名的希格斯机制)。
尽管需要牢记语言和物理学之间的区别——物理定律由自然规定并受到限制,而语言是看似不受约束的人类创造——但语言复杂性不必与我们世界中的物理复杂性相似。
然而,正如我们将要论证的那样,ChatGPT 和其他类似的 LLMs 包含类似粒子物理学的结构。如果我们认为这些结构是 LLMs 成功的关键,它可能会提供语言复杂性与物理复杂性共享一些共同点的线索。此外,这也可能为我们提供关于 LLMs 如何工作的有价值见解。
语言模型的物理学

物理定律由方程 govern,但像 ChatGPT 这样的 LLMs 呢?(来源:作者自身的工作)
为了将 LLMs 与物理学联系起来,我们需要将它们的基础数学结构进行关联。在物理学中,粒子的运动(或更一般地说,场或状态)可以示意性地表示为:

物理方程的示意图
(** 技术说明:哈密顿量形式使其更精确,尽管导数部分需要稍作修改)
直观地说,它表明粒子因力的作用而移动,这些力来源于一些抽象对象叫做势能的斜率。这类似于水流下山,势能来自重力和流体动力学。
结果表明,LLMs 的结构非常相似:它们将句子拆解为基本组成部分,即tokens,并以类似的方式逐层修改这些 tokens:

描述 LLMs 本质的示意方程
这将在下面的技术部分中更为准确地说明。由此,我们可以进行类比。
基于 Transformer 的语言模型将词语视作粒子,这些粒子在相互影响下移动,生成引人入胜的模式。
这样,就像水分子可以构建美丽的雪花,或液体肥皂混合物可以创造复杂的气泡图案一样,ChatGPT 的有趣结果可能归因于其类似物理的行为。
在接下来的可选部分中,我们将更详细地描述这一类比如何变得更加严谨,然后深入探讨这一洞见如何帮助我们理解 LLMs。
技术绕道
下面我们将更详细地解释如何将 LLMs 视作物理模型。
从物理学的角度来看,在微观层面上,每个粒子通常会受到系统中所有粒子的影响。例如,假设一个只有 3 个粒子的假想世界;在这种情况下,一个粒子和另一个粒子之间会有总共 3 × 3 = 9 种可能的势能。从示意图来看,我们可以这样表示:

描述三个粒子运动方程的示意图
(在物理学中,势能通常是对称的,即 Potential₁₂ = Potential₂₁,我们在这里放宽了这一约束)
为了了解这与 LLMs 的关系,让我们回顾一些基本事实:
-
为了将数据输入 LLMs,文档或文本会被拆分成 tokens。tokens 通常由一个词或词的一部分组成。像粒子一样,tokens 被视为 LLM 中最小的不可分割的组成部分。
-
LLMs 具有多个层级,在每一层中,所有的 token 都被自注意力模块所修改。
-
最终输出层汇聚 tokens 形成预测,这些预测可用于分类或生成文本/图像。
如果我们以三个 token 的例子(比如来自句子“I like physics”)进行分析,这些方程会是什么样的?
根据我们所处理的 LLMs 的具体类型,有一些小的差异:BERT或GPT。
BERT 模型
对于类似 BERT 的模型(通常用于分类),每一层会按如下方式示意性地修改 tokens:

(** 层₁ 的参与是由于残差层)
如果我们把层看作类似于时间维度,那么这个方程的结构类似于控制三个粒子运动的方程,尽管在 LLM 中,层是离散的,而在物理学中则是连续的。
为了使类比完整,我们仍需要将注意力部分转换为某种潜在量。让我们从数学上更深入地挖掘。选取一个特定的令牌 tᵢ,在每一层中,它会根据自注意力机制进行相应的修改(忽略多个注意力头):

其中 Q、K、V 是在注意力模块中通常看到的查询、键、值矩阵。现在我们忽略了归一化层。关键是指数形式可以被重写为某种潜在项的导数!

(** 尽管 QᵀK 可能并不总是可逆的,这个方程式可能并不完全准确,但 V 是我们模型中的一个任意权重:因此我们总是可以在注意力模块中用 V 代替 M 以实现相同的模型性能)
这样,通过 LLM 的各层传递令牌类似于粒子在某些成对相互作用下的相互作用!这有点像气体分子互相碰撞并形成天气模式。
(** 从这个角度来看,我们可以将归一化和矩阵乘法 M 解释为一种投影,以便令牌粒子在系统中得到适当约束。这类似于过山车被限制在轨道上。)
GPT
对于类似 Chat 的模型,讨论会有所修改。注意力模块具有额外的因果结构——即令牌只能被之前的令牌修改。这意味着方程中缺少一些项:

按照我们的类比,这意味着粒子一次一个地进入,每个粒子在经过所有的相互作用层后会被卡住。这有点像逐个原子地生长晶体。
需要记住的一点是,我们的物理类比并不是 100%准确的,因为物理学中普遍存在的对称性和能量/动量守恒等基本特性并不适用于 LLM。
语言模型中的涌现

像一片美丽的雪花一样,LLM 的输出可能依赖于其类似物理的属性(图片来源:Aaron Burden 在 Unsplash)
既然我们有了物理学的类比,它如何帮助我们理解 LLMs?希望的是,像复杂物理系统一样,我们可以从其他更熟悉且理解得更透彻的系统中获得对 LLMs 的洞见。然而,我必须提醒读者,以下大部分讨论将是推测性的,因为确认这些观点需要对 LLMs 进行详细的实验研究。
(* 事实上,如果我有更多资源,我会想象这些想法可能会成为有成果的学术研究项目)
下面是如何利用物理学语言重新框定我们对 LLMs 理解的示例。
LLMs 训练
使用热物理学的语言,我们可以将大型语言模型(LLMs)视为一个可调节的物理系统,模型训练类似于对系统施加热压以调整其参数。这一观点在我的另一篇文章“机器学习的热力学”中有描述,因此我在这里不会详细讨论。
智能的出现?
尽管关于 ChatGPT 是否智能有很多讨论,但我将避免对此有争议的话题进行更多探讨,因为我甚至不确定如何定义智能。然而,很明显,ChatGPT 可以持续产生复杂且有趣的输出。
如果我们接受物理学类比,这并不令人惊讶。从雪花到龙卷风,我们知道即使是简单的定律也可以产生高度复杂的行为,而从复杂行为中,可以出现看似智能的结构。
复杂性作为一个概念并不容易定义,因此为了进一步探讨,我们可以尝试检查复杂系统的一些关键特征:相变就是其中之一。
相变
许多复杂的物理系统具有独特的相,每个相都有一组突出的物理属性。因此,合理的猜测是,在 LLMs 中也可能存在独特的相,每个相都被调整以在特定任务中(例如编码与校对)提供帮助。
我们如何验证或反驳这种说法?这就是事情可能变得有趣的地方。在物理学中,相位出现时,交互作用开始形成有趣的结构。一些例子包括:
-
当水冷却时,分子之间的吸引力变得更强,使分子粘在一起形成固体。
-
当金属被冷却到极低的温度时,电子可能通过声波(声子)相互吸引,形成 I 型超导体。
在 LLMs 中是否可能发生类似的现象?例如,在 ChatGPT 中,人们可能会推测“代码”或“校对”中的某些令牌组合可能会触发一系列特定的力量,从而驱动特定类型的输出。
相变的另一个技术方面是对对称性的修改。这与结构的创建有关,例如从水蒸气中形成的冰晶图案。尽管 LLMs 不具备物理对称性,但它们应该包含某种模型权重的排列对称性。这是因为模型性能应该是相同的,只要它们以相同的统计数据初始化并在相同的范式下训练。特定权重的具体值只有在训练过程中才变得重要。这可以看作是权重的“冻结”。然而,要继续讨论这个话题,我们需要深入探讨自发对称破缺的技术内容,我们将在以后再讨论。
LLMs 高效吗?
尽管有许多关于 LLMs 由于其大量参数而被认为效率低下的批评(特别是与物理模型相比时),这些批评可能并不完全成立。
为什么?这归结于我们计算机的技术限制,这导致了物理学和 LLMs 之间存在显著差异:
-
物理定律具有无限精度,而 LLMs 具有有限精度。
-
物理学表现出巨大的层级结构,一些力非常微小,而另一些力则很大。在 LLMs 中,我们通过归一化尝试使所有输出/权重大小相似。
-
在物理学中,微小的效应可以累积成巨大的影响(例如地球的引力)。在 LLMs 中,这些微小的效应通常会被舍去和消除。
-
自然是一个极其高效的计算机,能够以无限精度瞬时计算所有尺度上的相互作用。另一方面,LLMs 则是受限于有限精度的相对较慢的计算机。
这意味着,虽然我们可以努力使 LLMs 更好地模拟物理现象并创建更强大的模型,但在实际应用中,计算机本质上无法完全模拟我们的世界(如在“我们为什么不生活在模拟中”中讨论的)。因此, resorting to a large number of parameters 可能是一种应对这些不足的最后手段。
即使考虑到有限的精度,也可以认为标准计算机能够达到的复杂度可能存在上限。这可能会使得显著减少参数数量变得非常具有挑战性(尽管量子计算的进展可能会在未来改变这种情况)。
对 LLMs 的改进
我们的物理类比是否可以为下一代 LLMs 提供一些线索?我认为这是可能的。从逻辑上讲,根据我们的信念,有两个可能的方向可以探索:
-
物理类似的特征是值得追求的:我们应该从物理学中获得更多灵感,以创造更好的模型结构。
-
物理学特征是不可取的:物理学特征可能由于固有的计算限制而限制 LLM 的能力,因此我们应该避免它们。
既然我们使用物理学来理解大型语言模型(LLMs),那么我将重点关注第一个可能性。在这种假设下,我们如何解决类似 ChatGPT 的 LLMs 的不足之处?
-
保持层次结构:我们不应仅仅专注于规范化权重和降低精度,而应探索替代方法,以考虑不同强度和规模的多样化交互。我们可以借鉴自然界中电磁力(非常强)和重力(非常弱)如何结合的方式。
-
适应不同阶段:使用相同的基本分子方程描述冰和水是低效的。使用不同的描述来应对不同的阶段(例如声波与水波)会更有效。我们可以创建一种更好的结构,能够自然地适应模型中的宏观差异。
-
高级物理学技术:在物理学中,我们不仅仅使用基本方程研究涌现现象。技术如 热力学、均场理论 和 重整化 可以帮助我们简化问题。将这些思想的一部分融入 LLM 的构建模块中,可能会提高它们的效率。例如,最近对 线性注意力(A. Katharopoulos 等)的进展,可能已经被解释为一种均场方法。
通过探索这些途径,我们或许能够提升 LLM 的能力和效率,利用物理学的洞察力进一步推动该领域的发展。
结语
总结一下,我们展示了 LLM 的数学如何与物理学中的数学相似。这使我们能够利用对日常物理系统的直觉来理解这些新兴现象,例如 ChatGPT。我希望这有助于揭示 LLM 特征背后的原因。
更一般地说,我希望我已经向你传达了物理学如何为像 LLMs 这样的复杂主题提供有价值的见解。我坚信,当我们从看似不相关的领域中借鉴洞察力时,科学最为有效。
如果你喜欢这篇文章,你可能会对我关于类似主题的其他文章感兴趣,例如物理学与人工智能之间的联系。
请留下评论或提供反馈,这鼓励我写出更多有见地的文章!👋
自然界中的复杂系统可以通过热力学进行成功研究。那么,机器学习呢?
[逻辑回归背后的意义,从物理学角度看 [## 逻辑回归背后的意义,从物理学角度看]
为什么我们使用逻辑回归和 softmax 函数?热物理学或许能提供答案。
为什么因果关系是相关性的体现:物理学家的视角(第一部分) [## 为什么因果关系是相关性的体现:物理学家的视角(第一部分)]
我们都听说过“相关性不代表因果性”这句话,但没有人真正谈论因果性到底是什么…
为什么我们不生活在模拟中 [## 为什么我们不生活在模拟中]
将现实描述为模拟极大地低估了我们世界的复杂性。以下是为什么模拟…
了解 LoRA — 低秩适配用于微调大型模型
这一参数高效微调方法背后的数学原理
·
关注 发布于 Towards Data Science · 4 分钟阅读 · 2023 年 12 月 22 日
--
来源 — 图片由 DALLE-3 生成。提示:一个较小的机器人与一个较大的机器人握手。较小的机器人是紫色的,充满电,活跃而有活力,而较大的机器人则被冻结在冰中,呈灰色。
微调大型预训练模型在计算上具有挑战性,通常涉及调整数百万个参数。虽然这种传统的微调方法有效,但需要大量的计算资源和时间,这为将这些模型适应特定任务带来了瓶颈。LoRA 通过在微调过程中分解更新矩阵,提出了一个有效的解决方案。为了研究 LoRA,我们首先回顾一下传统的微调。
( Δ W )的分解
在传统的微调中,我们会修改预训练神经网络的权重以适应新任务。这种调整涉及改变网络的原始权重矩阵( W )。微调过程中对( W )所做的更改统称为( Δ W ),以便更新后的权重可以表示为( W + Δ W )。
现在,LoRA 方法并不是直接修改( W ),而是寻求分解( Δ W )。这种分解是减少与微调大模型相关的计算开销的关键步骤。

传统的微调可以根据上述方法重新构想。在这里,W 是固定的,而 ΔW 是可训练的(图像来源于博客作者)
内在秩假设
内在秩假设表明,通过较低维度的表示可以捕捉到神经网络的显著变化。基本上,它认为( Δ W )的所有元素并不是同样重要的;相反,这些变化的一个较小子集可以有效地封装所需的调整。
引入矩阵( A )和( B )
在此假设基础上,LoRA 建议将( Δ W )表示为两个较小矩阵( A )和( B )的乘积,且秩较低。更新后的权重矩阵( W’ )因此变为:
[ W’ = W + BA ]
在这个方程中,( W )保持不变(即,在训练过程中不更新)。矩阵( B )和( A )的维度较低,其乘积( BA )表示( Δ W )的低秩近似。

ΔW 被分解成两个矩阵 A 和 B,其中两个矩阵的维度都低于 d x d。(图像来源于博客作者)
低秩对可训练参数的影响
通过选择具有较低秩( r )的矩阵( A )和( B ),可训练参数的数量显著减少。例如,如果( W )是一个( d x d )矩阵,传统上,更新( W )将涉及( d² )个参数。然而,使用大小为( d x r )和( r x d )的( B )和( A ),总参数数量减少到( 2dr ),当( r << d )时,这个数量要小得多。
通过低秩自适应(LoRA)方法实现的可训练参数数量的减少提供了几个重要的好处,特别是在微调大规模神经网络时:
-
减少内存占用:LoRA 通过降低需要更新的参数数量来减少内存需求,有助于管理大规模模型。
-
更快的训练和适应:通过简化计算需求,LoRA 加快了大型模型对新任务的训练和微调。
-
对较小硬件的可行性:LoRA 的较低参数数量使得可以在较不强大的硬件上,如普通 GPU 或 CPU,进行大型模型的微调。
-
扩展到更大的模型:LoRA 使得 AI 模型的扩展无需相应增加计算资源,使得管理不断增长的模型规模变得更为实际。
在 LoRA 的背景下,秩的概念在确定适应过程的效率和有效性方面起着关键作用。值得注意的是,论文指出矩阵A和B的秩可以非常低,有时低至 1。
尽管 LoRA 论文主要展示了在自然语言处理(NLP)领域的实验,但低秩适应的基本方法具有广泛的适用性,可能在训练不同领域的各种神经网络时有效。
结论
LoRA 通过将(Δ W)分解为低秩矩阵的乘积,有效地平衡了将大型预训练模型适应新任务与保持计算效率的需求。内在的秩概念对这一平衡至关重要,确保了模型学习能力的本质在参数大大减少的情况下得以保留。
参考文献:
理解马赛克图
原文:
towardsdatascience.com/understanding-mosaic-plots-fcf148315f4b
PYTHON | DATA | VISUALISATION
一本全面的指南,讲述如何使用 statsmodels 和 Matplotlib 有效绘制多变量数据集
·发表于Towards Data Science ·7 分钟阅读·2023 年 6 月 13 日
--

我们生活在一个数据泛滥的世界中——一个不断扩展的数字海洋。但在那片海洋中,有待发现的珍贵洞察就在其中。
寻找这些珍珠的关键是什么?数据可视化——将原始数据呈现成更易于理解和解读的视觉形式的过程。
借助数据可视化,你可以赋予原始数字生命,将它们转化为揭示隐藏模式、潜在趋势和关键联系的形式,这些联系可能在数据中隐匿不显。
在我们拥有的数据可视化工具中,有著名的 Matplotlib。
这个强大的 Python 库功能多样且强大。
在 Matplotlib 的工具包中隐藏着一个你可能未曾遇到过的宝石——马赛克图。
这些图表提供了一种强大的方法来可视化跨多个维度的分类数据。
想象一下能够俯瞰你的数据,理解不同变量如何相互作用和交集。马赛克图正是可以做到这一点,以直接且视觉上吸引人的方式呈现复杂关系。
无论是发现客户细分、理解用户行为,还是揭示人口统计趋势,马赛克图都是研究人员和数据科学家必备的绝佳工具。
在本文中,我们将探讨马赛克图的世界。我们将详细介绍如何在 Matplotlib 中创建马赛克图,并讨论如何解读它们,为你的数据可视化工具包增添一个额外的优势。
理解马赛克图
让我们首先将马赛克图比作某种织锦。每个复杂的部分代表数据中的一个独特类别,部分的大小反映了该类别的频率。
因此,织锦函数作为马赛克图,提供了一个视觉表现,帮助理解各种类别变量之间的互动。

图片由 Tom Rumble 提供,来源于 Unsplash
马赛克图的独特之处在于它们能够同时处理多个维度。想象它就像是审视一个色彩斑斓的数据魔方。
从外部视角,可以观察到不同类别(或颜色,根据我们魔方的比喻)如何相互连接和融合。
想想,你可能在哪里使用这些图表?
想象一下你在 Netflix 等流媒体平台工作,任务是理解不同年龄组、性别和各种类型之间的相互作用。在这里,马赛克图大放异彩。它提供了一个视觉上的描绘,例如,18-25 岁的女性与喜剧的互动频率,与同龄男性进行对比,或者 35-45 岁年龄组对惊悚片的倾向。
马赛克图具有许多优点。
-
它们擅长处理多维度的类别数据,提供了数据的全景视角。
-
他们直观的颜色编码增强了视觉吸引力,加快了解读速度并促进快速决策。
-
它们揭示了原始数据表中可能隐藏的模式和关系。
尽管如此,了解它们的限制仍然至关重要。
-
如果类别数量过多或类别分布较为均匀,马赛克图可能会迅速变得复杂。
-
它们在处理定量数据时表现出限制,当精确的数值比较至关重要时,可能不是理想选择。
尽管有这些限制,马赛克图仍然是数据可视化工具中不可或缺的工具。
它们能够将枯燥的数字表转换为动态的趋势和关系表现。
因此,当未来面对多变量类别数据集时,可以考虑使用马赛克图来挖掘隐藏的洞察。
使用 Matplotlib 创建你的第一个马赛克图
初次接触马赛克图可能会感觉像迷宫一样,对吧?实际上,使用 Matplotlib,整个过程变得相对简单。
首先,我们需要一些数据。我们将使用开放的 Titanic 数据集作为示例(参见 www.openml.org/search?type=data&sort=runs&id=40945&status=active)。
我们的目标?弄清楚男性和女性之间的生存率如何分布。
让我们开始吧,好吗?
我们首先导入所需的库:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from scipy.io.arff import loadarff
from statsmodels.graphics.mosaicplot import mosaic
接下来我们加载数据集。数据集采用 ARFF 格式。我们可以通过以下方式将其加载到 pandas 中:
# Load the data
raw_data = loadarff('titanic.arff')
titanic = pd.DataFrame(raw_data[0])
我们调用 statsmodels 中的马赛克函数,并传入我们想要可视化的特征。
# Create the mosaic plot
mosaic(titanic, ['alive', 'sex'], gap=0.02)
# Customize the plot
plt.title('Survival by Sex on Titanic')
plt.xlabel('Survived')
plt.ylabel('Sex')
plt.show()
然后,我们得到如下结果:

当然,这只是一个非常简单的示例。
让我们提升一下难度,好吗?
假设我们想创建一个马赛克图,表示泰坦尼克号上每个乘客舱的比例、他们的生存状态和性别。
我们还将为‘幸存’变量使用颜色编码,并添加一些标签到图表中。
这一次,我们还需要导入 numpy 库。
import numpy as np
我们将如下创建survived列:
# Convert 'survived' into string for clarity in the plot
titanic['survived'] = titanic['survived'].map({0:'Died', 1:'Survived'})
我们还定义了图表使用的属性。这些属性主要用于图表的颜色编码。
# Define properties function for colors
props = lambda key: {'color': '#1b9e77' if 'Survived' in key else '#d95f02'}
然后我们创建如下图表:
# Create a new figure with a defined size
fig, ax = plt.subplots(figsize=(10, 8))
# Create the mosaic plot
mosaic(titanic, ['class', 'sex', 'survived'], gap=0.02, properties=props, ax=ax)
# Customize the plot
plt.title('Survival by Class and Sex on Titanic', fontsize=15)
plt.ylabel('Class - Sex')
plt.xlabel('Proportion')
plt.show()

解析马赛克图
马赛克图分为三个主要的垂直部分,代表泰坦尼克号上的三个舱位:第一舱、第二舱和第三舱。
每个部分的宽度与每个舱位的乘客数量成比例。在这种情况下,似乎第三舱的乘客最多,其次是第一舱,最后是第二舱。
在每个舱位部分内,图表进一步水平划分为两个部分,分别代表男性和女性。这些部分的高度与每个舱位中男性和女性的数量成比例。例如,在第三舱中,男性比女性多。
每个性别部分再进一步划分为幸存和遇难两个部分。这些划分的宽度与每个舱位中每个性别的幸存者和遇难者数量成比例。
例如,在‘第一’舱的‘女性’部分,代表‘幸存’的绿色区域远大于代表‘遇难’的橙色区域。这表明第一舱的女性中有很大比例幸存。
颜色编码提供了额外的视觉提示:绿色表示幸存者,而橙色表示遇难者。
使用马赛克图的小贴士和技巧
-
保持简单 —— 马赛克图很容易变得令人困惑和杂乱。一次只处理少数类别。目标是传达洞察。
-
实践有效的颜色编码 —— 与任何数据可视化一样,颜色调色板有着显著的影响。马赛克图也不例外。
-
使用标签和注释 —— 始终标记图表的轴并提供描述性标题。背景信息是关键。
总结 remarks
就这样!
我们已经完成了马赛克图的探索。现在是时候思考一下我们在过程中学到了什么。
我们通过了解数据可视化为何如此重要来开启了这段旅程。我们了解了马赛克图作为观察多层次分类数据的有效方式,它帮助我们发现类别之间的详细相互作用。
然后我们利用 Titanic 数据集制作了我们的第一个马赛克图。我们还为图表添加了颜色、标签和标题,帮助我们用数据编织故事。
马赛克图的真正优势在于其将复杂数据转化为简单、直观故事的能力。但别忘了,保持简单很重要。过多的类别或混乱的颜色方案会将一个出色的图表变成视觉上的头痛。
现在,掌握了这项新知识,你已经准备好开始自己的数据可视化冒险了。也许你会使用马赛克图来了解你在网上商店的客户行为,或分析本地选举中的投票趋势。不论你处理的数据是什么,马赛克图都是你的空白画布,而你的发现就是你用来绘制它的颜色。
不要害羞,多多尝试,尝试不同的调整和各种数据集。实践出真知!
记住,每组数据都有一个故事要讲。所以,出去吧,让你的马赛克图编织出隐藏在数据中的故事。
祝你绘图愉快!
你喜欢这篇文章吗?每月$5,你可以成为会员,解锁对 Medium 的无限访问权限。你将直接支持我和你在 Medium 上的其他喜爱作者。对此,深表感谢!
[## 使用我的推荐链接加入 Medium - David Farrugia
获取对我所有⚡高级⚡内容以及 Medium 上的所有内容的独家访问权限。通过给我买杯咖啡来支持我的工作…
david-farrugia.medium.com](https://david-farrugia.medium.com/membership?source=post_page-----fcf148315f4b--------------------------------)
想要联系我吗?
我很想听听你对这个话题的看法,或者任何关于人工智能和数据的想法。
如果你希望联系我,可以发邮件到davidfarrugia53@gmail.com。
使用 Python 理解多项分布
原文:
towardsdatascience.com/understanding-multinomial-distribution-using-python-f48c89e1e29f
多项分布背后的数学和直觉
·发布于Towards Data Science ·20 分钟阅读·2023 年 1 月 6 日
--

来源:pixabay.com/vectors/dice-game-die-luck-random-numbers-151867/
多项分布是二项分布的推广,用于在具有多个结果的实验中计算概率。本文对多项分布进行了直观的介绍,并讨论了其数学性质。此外,还将教你如何使用 Python 中的 SciPy 库来建模和可视化多项分布。
二项分布
由于多项分布是二项分布的推广,我们将在此简要回顾一下。有关单变量概率分布的详细讨论,请参阅另一篇文章,如果你对二项分布或随机变量和概率质量函数(PMF)等概念不熟悉,建议你先阅读那篇文章。
随机变量 是一个其值由随机实验的结果决定的变量。随机变量通常用大写字母表示,但我们使用小写字母表示它可以取的特定值。例如,我们可以定义一个随机变量 X 来表示掷硬币的结果。为此,我们需要给每个结果分配一个数值。我们可以用 1 表示得到正面,用 0 表示得到反面。现在,X=1 表示掷硬币的结果是正面,X=0 表示是反面。这样只能取特定值的随机变量称为离散随机变量。
我们将离散随机变量X的概率质量函数(PMF)定义为一个给出X等于某个特定值的概率的函数。从数学上讲,X的 PMF 定义为函数pₓ,使得

特定值X可以取的值用x表示,因此P(X=x)表示X=x的概率。我们知道,X所有可能值的概率之和应该等于 1,因此:

具有伯努利分布的离散随机变量X,其参数为p,表示一个只有两个结果的随机实验,分别用 0 和 1 表示(我们用X~Bern(p)表示)。这里X=1(也称为‘成功’)发生的概率是p,而X=0(也称为‘失败’)发生的概率是 1-p。例如,我们可以用伯努利分布的随机变量X表示掷硬币的结果。现在我们可以假设X=1 和X=0 分别表示得到正面和反面,p是得到正面的概率。
假设我们有一个n次独立随机实验的序列,每次实验可以用一个伯努利分布参数p的随机变量表示。因此,每次实验只有两个结果,分别为 1(成功)和 0(失败),且X=1 发生的概率为p(这样的实验称为伯努利试验)。具有参数n和p的二项分布的离散随机变量X表示这一序列中的成功总数,我们可以写作X ~ Bin(n, p)。假设我们有一枚硬币,正面朝上的概率为p。我们可以用具有参数n和p的二项分布的随机变量X表示这一硬币在n次投掷中的正面总数(图 1)。

图 1(作者提供的图片)
随机向量
多项分布是一种多变量分布。在统计学中,单变量分布是仅一个随机变量的概率分布。多变量分布是单变量分布的推广,适用于两个或更多随机变量。要理解这些分布,我们首先应讨论随机向量。随机向量是一个随机变量的向量。如果我们有n个随机变量X₁、X₂ …、Xₙ,我们可以将它们放在随机向量X中:

我们用粗体大写字母表示随机向量。要表示随机向量的一个可能值(这也是一个向量),我们使用粗体小写字母,如x。例如:

是一个包含两个元素的随机向量,且

是它的一个可能值。
联合 PMF
假设我们有一个随机向量

其中 X₁,X₂,…,Xₙ 是离散随机变量。X的联合概率质量函数(联合 PMF)定义为:

其中 P(X₁=x₁,X₂=x₂,…,Xₙ=xₙ) 是 X₁=x₁,X₂=x₂,…,和 Xₙ=xₙ 同时发生的概率。如果我们知道随机向量 X 的联合 PMF,我们可以通过边缘化来推导其一个分量 Xᵢ 的分布。边缘 概率质量函数 可以从 X 的联合 PMF 导出,如下所示

这里 Rₓ 是 X 的支持集,表示随机向量 X 可以取的所有值的集合。因此,为了推导 Xᵢ 在点 x 的边缘概率 PMF,我们需要对 Rₓ 中 Xᵢ 等于 x 的所有向量的概率进行求和。
多项分布
多项分布是二项分布的一种推广。假设我们有 n 次独立试验。每次试验有 k 种不同的结果(k ≥ 2),第 i 种结果的概率是 pᵢ 和

向量 p 表示这些概率:

设离散随机变量 Xᵢ 表示在n次试验中结果编号i出现的次数。随机向量 X 定义为

并且我们有 x₁+x₂ +…+xₖ=n。那么 X 就称为具有参数 n 和 p 的多项分布。X 的联合 PMF 定义如下:

让我们看看是否可以证明这一点。首先,我们应该注意到如果 x₁+x₂ +…+xₖ≠n,那么事件 X₁=x₁,X₂=x₂,…,和 Xₖ=xₖ 是不可能的,因此其概率应为零。如果 x₁+x₂ +…+xₖ=n,事件 X₁=x₁,X₂=x₂,…,和 Xₖ=xₖ 的概率是

此外,获得 X₁=x₁,X₂=x₂,…,和 Xₖ=xₖ 的 n 次试验的不同方式的总数是

所以在n次试验中,X₁=x₁,X₂=x₂,…,和Xₖ=xₖ的总概率是

基于这个故事,多项分布可以用来描述一个 k 面的骰子。假设我们有一个 k 面的骰子,且获得面 i 的概率是 pᵢ。此外,设 Xᵢ 表示观察到面 i 的总次数。现在,如果我们掷骰子 n 次,随机向量 X 就具有参数 n 和 p 的多项分布(见图 2)。

图 2(作者提供的图片)
我们也可以用另一种方式来表达多项式分布。假设我们有一个由 k 种不同类别的物品组成的总体(k ≥ 2)。总体中属于类别 i 的物品比例为 pᵢ,并且

现在我们从总体中随机选择 n 个物品(允许重复),我们假设随机变量 Xᵢ 代表选择的物品中属于类别 i 的数量(图 3)。由于这些物品是随机选择的且允许重复,因此选择是相互独立的,每个物品属于类别 i 的概率是 pᵢ。因此,我们可以将每次选择视为一个具有 k 种不同结果的独立试验,每个结果 i 的相应概率是 pᵢ。现在,如果我们定义向量 p 为

那么随机向量

具有参数 n 和 p 的多项式分布。

图 3(图像来源:作者)
具有参数 n 和 p 的多项式分布的随机向量 X 可以表示为 n 个具有参数 1 和 p 的多项式分布的随机向量之和:

证明(可选):向量 Yᵢ 可以写成如下形式:

由于 X₁+X₂ +…+Xₖ=n,我们可以得出在右侧有 n 个向量。每个 Yᵢ* 中有一个元素为 1,其余为 0。因此,每个 Yᵢ 可以表示一个具有 k 种不同结果的试验,其中等于 1 的元素表示观察到的结果。因此,每个 Yᵢ 具有参数为 1 和 p 的多项式分布。
我们可以合并多项式随机向量 X 中的多个元素来获得新的多项式随机向量。例如,设

现在我们可以合并前三个元素得到新的随机向量

而这个随机向量具有参数为 n 和 的多项式分布

实际上,通过合并 X₁、X₂ 和 X₅,我们创建了一个新类别,该类别的比例是 p₁+p₂+p₅。因此,它可以在剩余类别 3 和 4 上创建一个多项式分布,同时这些类别的比例保持不变。
多项式分布是二项式分布的推广。如果 k=2,多项式分布简化为二项式分布。因此,如果我们有

那么 X₁ ~ Bin(n, p₁) 和 X₂ ~ Bin(n, p₂)。
证明(可选):我们应该注意到 X₂= n-X₁ 和 p₂=1-p₁。因此,如果我们将结果 1 视为成功,将结果 2 视为失败,则每次试验只有两种结果(成功和失败),成功的概率是 p₁。因此,每次试验可以用具有参数 p₁ 的伯努利分布表示,X₁ 代表了 n 次伯努利试验中成功的次数,参数为 p₁。因此,我们得出 X₁ 具有参数 n 和 p₁ 的二项分布。类似地,如果我们将结果 2 视为成功,将结果 1 视为失败,则 X₂ 代表了 n 次伯努利试验中成功的次数,参数为 p₂,因此具有参数 n 和 p₂ 的二项分布。
我们可以将之前的结果扩展到 k>2 的多项分布。如果

那么每个 Xᵢ 的边际分布是具有参数 n 和 pᵢ 的二项分布:

证明(可选):我们假设结果 i 的成功概率为 pᵢ,所有其他结果的失败概率为 1-pᵢ。因此,Xᵢ 代表了 n 次伯努利试验中成功的次数,参数为 pᵢ,并且它具有参数 n 和 pᵢ 的二项分布。
此外,X 的某些元素的总和服从二项分布。从数学上讲,如果 Xᵢ,₁, Xᵢ,₂, …, Xᵢ,ₘ 是随机向量 X 的 m 个元素(m<k),它们在向量 p 中的对应概率是 pᵢ,₁, pᵢ,₂, …, pᵢ,ₘ,那么总和 Xᵢ,₁ + Xᵢ,₂ + …+ Xᵢ,ₘ 服从具有参数 n 和 pᵢ,₁ + pᵢ,₂ + …+ pᵢ,ₘ 的二项分布。例如,对于 n=8 和 k=5,X₁+X₃+X₅ 具有参数 8 和 p₁ + p₃ + p₅* 的二项分布。
证明(可选):如前所述,我们可以将多项式随机向量 X 中的多个元素合并,以获得一个新的多项式随机向量。因此,通过合并 Xᵢ,₁ + Xᵢ,₂ + …+ Xᵢ,ₘ,我们得到一个新的随机向量 Y,它具有多项分布:

我们现在还展示了多项式随机向量每个元素的边际分布具有参数 n 和其对应概率的二项分布,因此我们得出结论 Xᵢ,₁ + Xᵢ,₂ + …+ Xᵢ,ₘ 具有参数 n 和 pᵢ,₁ + pᵢ,₂ + …+ pᵢ,ₘ 的二项分布。
我们可以使用 Python 中的 scipy 库生成多项分布。我们可以使用 scipy.stat 中的 multinomial 对象创建多项分布,并使用该对象的 pmf() 方法计算其联合 PMF。该对象接受 n 和 p 参数,其中 n 和 p 对应于 n 和 p。p 是一个类似数组的对象。p 的每个元素应在 [0, 1] 区间内,并且这些元素的和应为 1。如果它们的和不为 1,则 p 的最后一个元素不会被使用,而是用之前元素和的 1 减去替代(以使元素和为 1)。
首先,我们需要导入所有必需的库:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import binom, multinomial
%matplotlib inline
列表 1 创建了以下多项分布

并计算其联合 PMF 在

# Listing 1
n = 5
p=[0.5, 0.3, 0.2]
mult = multinomial(n=n, p=p)
mult.pmf([3, 1, 1])
# Output
0.15
我们也可以直接将分布参数传递给 pmf() 函数:
multinomial.pmf([3, 1, 1], n=n, p=p)
# Output
0.15
方法 rvs() 可以用于生成随机变异。随机变异或简称变异是随机变量或随机向量的特定结果。使用此方法,我们可以从多项分布中抽取大小为 m 的随机样本,这意味着我们从具有多项分布的随机向量中生成 m 个随机变异。例如,要从方程 1 中的多项分布中抽取大小为 4 的随机样本,我们可以使用以下代码片段:
np.random.seed(2)
multinomial.rvs(n=n, p=p, size = 4)
# Output
array([[2, 3, 0],
[3, 1, 1],
[2, 2, 1],
[2, 2, 1]])
它返回一个包含 4 个元素的 2d 数组。每个元素是一个数组,表示方程 1 中的随机向量可以取的可能值 X。
让我们看看如何可视化方程 1 定义的多项分布。在这里,我们希望收集X可能取的所有值。为此,我们只需计算元组 (X₁, X₂) 所有可能的值。一旦我们知道了 X₁ 和 X₂ 的值,就可以利用 X₃=n-X₁-X₂ 计算 X₃ 的值。由于 n=5,Xᵢ 可以取 0 到 5 之间的整数值。列表 2 使用 numpy 中的 meshgrid() 方法获取 (X₁, X₂) 的所有可能值及其对应的 X₃ 值。最后,(X₁, X₂, X₃) 的不同值被存储在数组 X_mat 中。我们得到了一些负值的 X₃。虽然这些负值是不可能取的,但我们无需丢弃它们,因为多项分布中任何 Xᵢ 为负值的联合 PMF 仅为零。
# Listing 2
n = 5
x1_array = np.arange(0, n+1)
X1_mat, X2_mat = np.meshgrid(x1_array, x1_array)
x3_array = n - X1_mat.flatten() - X2_mat.flatten()
X_mat = np.array([X1_mat.flatten(), X2_mat.flatten(), x3_array]).T
X_mat
# Output
array([[ 0, 0, 5],
[ 1, 0, 4],
[ 2, 0, 3],
[ 3, 0, 2],
[ 4, 0, 1],
[ 5, 0, 0],
[ 0, 1, 4],
[ 1, 1, 3],
...
现在我们可以使用 X_mat 计算 X₁ 在 X₁=3 时的边际 PMF。为此,我们需要对所有值为 3 的 X 的联合 PMF 进行求和。
multinomial(n, p).pmf(X_mat[X_mat[:, 0]==3]).sum()
# Output
0.3125
我们还知道X₁的边际分布是参数n和p₁的二项分布。因此,我们可以使用参数n=5 和p₁=0.5 的二项分布的 PMF。为了计算二项分布的 PMF,我们可以使用scipy.stat中的binom对象。我们计算X*₁=3 处的 PMF 值,结果应该与之前的代码片段相同。
binom.pmf(k=3,n=n, p=p[0])
# Output
0.31249999999999983
结果几乎相同,但由于数值误差存在小的差异。我们还可以通过从多项分布中抽样来近似X₁的边际 PMF。列表 3 从方程 1 中的多项分布中抽取了大小为 100,000 的样本,并绘制了该样本中X₁值的条形图。结果如图 4 所示。正如图中所示,样本的条形图与参数n=5 和p₁=0.5 的二项分布的 PDF 匹配。
# Listing 3
np.random.seed(2)
n = 5
p=[0.5, 0.3, 0.2]
x = np.arange(n+1)
sample = multinomial.rvs(n=n, p=p, size=100000)
pmf_binomial = binom.pmf(k = x,n = n, p = p[0])
sample_marginal = sample[:, 0]
values, counts = np.unique(sample_marginal, return_counts=True)
probs = counts / counts.sum()
fig = plt.figure(figsize=(8, 6))
plt.bar(values, probs, label='Marginal sample')
plt.plot(x, pmf_binomial, marker='o', color='red',
label='Binomial distribution, n={}, p={}'.format(n, p[0]))
plt.xlabel('$x_1$', fontsize = 18)
plt.ylabel('Probability', fontsize=18)
plt.ylim([0, 0.45])
plt.legend(loc='best', fontsize = 15)
plt.show()

图 4
如果一个具有多项分布的随机变量只有 3 个元素(k=3),我们可以轻松地在 2d 或 3d 空间中绘制其联合 PMF。列表 4 绘制了图 5(顶部)中先前多项分布的联合 PMF。图是X₁与X₂的热图。我们不需要在此图中包含X₃,因为其值依赖于X₁与X₂(X₃=n-X₁-X)。X₁的二项分布 PMF 也绘制在图 5(底部)中。
# Listing 4
n = 5
p=[0.5, 0.3, 0.2]
x = np.arange(n+1)
x1_array = np.arange(0, n+1)
X1_mat, X2_mat = np.meshgrid(x1_array, x1_array)
x3_array = n - X1_mat.flatten() - X2_mat.flatten()
X_mat = np.array([X1_mat.flatten(), X2_mat.flatten(), x3_array]).T
pmf_mult = multinomial(n, p).pmf(X_mat)
pmf_binomial = binom.pmf(k = x, n = n, p = p[0])
pmf_grid = pmf_mult.reshape(6, 6)
f, (ax1, ax2) = plt.subplots(2, 1, figsize=(9, 11))
plt.subplots_adjust(hspace=0.4)
heatmap = ax1.pcolor(pmf_grid, cmap='coolwarm')
ax1.set_xlabel('$x_1$', fontsize = 18)
ax1.set_ylabel('$x_2$', fontsize = 18)
for x1_pos in range(pmf_grid.shape[0]):
for x2_pos in range(pmf_grid.shape[1]):
ax1.text(x1_pos + 0.5, x2_pos + 0.5,
'%.5f' % pmf_grid[x2_pos, x1_pos],
horizontalalignment='center',
verticalalignment='center',
)
ax1.set_yticks(np.arange(pmf_grid.shape[0])+0.5)
ax1.set_yticklabels(np.arange(pmf_grid.shape[0]))
ax1.set_xticks(np.arange(pmf_grid.shape[1])+0.5)
ax1.set_xticklabels(np.arange(pmf_grid.shape[1]))
plt.colorbar(heatmap, ax=ax1)
ax2.bar(x, pmf_binomial,
label='Binomial\ndisitrbution\nn={}, p={}'.format(n, p[0]))
ax2.bar_label(ax2.containers[0], fontsize = 13)
ax2.set_xlabel('$x_1$', fontsize = 18)
ax2.set_ylim([0, 0.4])
ax2.set_ylabel('Probability', fontsize=18)
ax2.legend(loc='best', fontsize = 14)
plt.show()

图 5(作者提供的图片)
热图中每列在X₁=x处包含所有值的联合 PMF,其中X₁等于x。因此,X₁在x处的边际 PMF 等于X₁=x列的总和(图 5)。类似地,X₂在x处的边际 PMF 等于X₂=x行的总和。
我们还可以使用 3d 条形图来可视化方程 1 中多项分布的联合 PMF。如前所述,我们只需要在图中包含X₁和X₂,因为X₃的值依赖于它们。列表 5 使用pmf()方法计算X_mat中数组的联合 PMF,并在图 6 中创建这些联合 PMF 值的 3d 条形图。此图中还显示了X₁和X₂的边际分布。
# Listing 5
n = 5
p=[0.5, 0.3, 0.2]
x1_array = np.arange(0, n+1)
X1_mat, X2_mat = np.meshgrid(x1_array, x1_array)
x3_array = n - X1_mat.flatten() - X2_mat.flatten()
X_mat = np.array([X1_mat.flatten(), X2_mat.flatten(), x3_array]).T
pmf_mult = multinomial(n, p).pmf(X_mat)
pmf_x1 = binom.pmf(k=x1_array,n=n, p=p[0])
pmf_x2 = binom.pmf(k=x1_array,n=n, p=p[1])
fig = plt.figure(figsize=(10, 10))
ax1 = fig.add_subplot(111, projection='3d')
x1 = X_mat[:, 0]
x2 = X_mat[:, 1]
z = np.zeros(len(x1))
width = 0.8
dx1 = np.repeat(width, len(x1))
dx2 = np.repeat(width, len(x1))
ax1.bar3d(x1-width/2, x2-width/2,
z, dx1, dx2, pmf_mult, color='aqua')
ax1.bar3d(x1_array-width/2, np.repeat(6, n+1)-width/2,
np.zeros(n+1), np.repeat(0.8, n+1),
np.zeros(n+1), pmf_x1, color='blue', shade=False)
ax1.bar3d(np.repeat(0, n+1)-width/2, x1_array-width/2,
np.zeros(n+1), np.zeros(n+1),
np.repeat(0.8, n+1), pmf_x2,
color='blue', shade=False)
ax1.set_xlabel('$x_1$', fontsize=20)
ax1.set_ylabel('$x_2$', fontsize=20)
ax1.set_zlabel("$p_\mathregular{X}(\mathregular{x})$", weight="bold",
style="italic", fontsize=18, labelpad = 8)
ax1.text(4.2, 5, 0.25,"$p_{X_1}(x_1)$", fontsize= 18, color='b')
ax1.text(0.7, 1.7, 0.3, "$p_{X_2}(x_2)$", fontsize= 18, color='b')
ax1.text(1.3, 7, 0.25,"$X_1 \sim Bin(5, 0.5)$", fontsize= 18)
ax1.text(-3, 1.7, 0.3, "$X_2 \sim Bin(5, 0.3)$", fontsize= 18)
ax1.view_init(35, -45)
plt.show()

图 6
我们可以将方程 1 中的多项分布的前两个元素合并,得到新的随机向量。

我们还知道这个随机向量具有参数n的多项分布。

列表 6 通过从多项分布中采样来近似 X₁+X₂ 的边际 PMF。它从方程 86 中的多项分布中抽取一个 100,000 大小的样本,并绘制该样本中 X₁+X₂ 值的条形图。结果如图 7 所示,你可以看到条形图与参数 n=5 和 p₁+p₂ 的二项分布的 PDF 相匹配。
# Listing 6
np.random.seed(2)
n = 5
p=[0.5, 0.3, 0.2]
x = np.arange(n+1)
sample = multinomial.rvs(n=n, p=p, size=100000)
pmf_binomial = binom.pmf(k = x,n = n, p = p[0]+p[1])
sample_marginal = sample[:, 0] + sample[:, 1]
values, counts = np.unique(sample_marginal, return_counts=True)
probs = counts / counts.sum()
fig = plt.figure(figsize=(8, 6))
plt.bar(values, probs, label='Marginal sample')
plt.plot(x, pmf_binomial, marker='o', color='red',
label='Binomial distribution, n={}, p={}'.format(n, p[0]+p[1]))
plt.xlabel('$x_1+x_2$', fontsize = 18)
plt.ylabel('Probability', fontsize=18)
plt.ylim([0, 0.45])
plt.legend(loc='best', fontsize = 13)
plt.show()

图 7
列表 7 展示了如何从具有 k 个不同类别的项的总体中进行有放回采样可以产生多项分布(图 3)。这里总体由一个具有 3 个独特元素的列表表示:1、2 和 3。它们的比例分别是 0.5、0.3 和 0.2。我们可以从这个列表中随机选择 5 个元素,并计算 1、2 和 3 的出现次数(我们用 x₁、x₂ 和 x₃ 来表示)。现在这个向量

是方程 1 中多项分布的随机变量。我们也可以将其称为从该分布中抽取的大小为 1 的随机样本。请注意,我们也可以使用 rvs() 方法生成这个随机变量。现在我们可以重复相同的过程 n 次,从多项分布中抽取一个大小为 n 的随机样本。列表 7 在图 8(蓝色条形)中绘制了两个随机样本的条形图。这些样本的大小分别为 30 和 50000。图中还绘制了方程 1 中定义的多项分布的联合 PMF 的条形图(红色条形)。随着样本大小的增加,样本条形图的形状越来越接近多项分布联合 PMF 的条形图的形状。
# Listing 7
np.random.seed(1)
population = [1]*5 + [2]*3 + [3]*2
n = 5
xedges = np.arange(n+2)
yedges = np.arange(n+2)
num_samples_list = [30, 50000]
x1_array = np.arange(0, n+1)
X1_mat, X2_mat = np.meshgrid(x1_array, x1_array)
x3_array = n - X1_mat.flatten() - X2_mat.flatten()
X_mat = np.array([X1_mat.flatten(), X2_mat.flatten(), x3_array]).T
pmf_mult = multinomial(n, p).pmf(X_mat)
x1 = X_mat[:, 0]
x2 = X_mat[:, 1]
z = np.zeros(len(x1))
width = 0.15
dx1 = np.repeat(width, len(x1))
dx2 = np.repeat(width, len(x1))
fig = plt.figure(figsize=(18, 10))
plt.subplots_adjust(wspace=0.1)
ax1 = fig.add_subplot(121, projection='3d')
ax2 = fig.add_subplot(122, projection='3d')
axs = [ax1, ax2]
for i, num_samples in enumerate(num_samples_list):
samples = np.random.choice(population,
size=n*num_samples).reshape(num_samples, n)
samples_count = np.stack(((samples==1).sum(axis=1),
(samples==2).sum(axis=1)), axis=-1)
H, _, _ = np.histogram2d(samples_count[:, 0],
samples_count[:, 1],
bins=(xedges, yedges))
H = H.T / H.sum()
axs[i].bar3d(x1-2*width, x2-width, z, dx1, dx2, H.flatten(),
color='aqua')
axs[i].bar3d(x1+width/2, x2-width, z, dx1, dx2, pmf_mult,
color='red')
axs[i].set_xlabel('$x_1$', fontsize=20)
axs[i].set_ylabel('$x_2$', fontsize=20)
axs[i].set_zlabel("$p_\mathregular{X}(\mathregular{x})$",
weight="bold", style="italic",
fontsize=18, labelpad = 8)
axs[i].set_title('Sample size={}'.format(num_samples), fontsize=20)
axs[i].set_zlim([0, 0.2])
axs[i].view_init(35, -135)
plt.show()

图 8
在列表 8 中,我们比较了在方程 1 中定义的多项分布的随机变量的联合 PMF 与具有 n 个随机向量的和的联合 PMF,这些随机向量具有参数 n=1 和 p。我们首先从方程 1 中定义的多项分布中抽取一个 5000000 大小的样本(sample1),该多项分布具有以下参数:

然后我们从 5 个具有相同p和n=1 的多项分布中抽取一个 5000000 大小的样本。我们将这 5 个样本相加,得到一个 5000000 大小的样本并存储在sample2中。实际上,这个样本包含了随机向量

可以取。
# Listing 8
np.random.seed(50)
num_samples = 5000000
n = 5
p=[0.5, 0.3, 0.2]
sample1 = multinomial.rvs(n=n, p=p, size=num_samples)
sample2 = multinomial.rvs(n=1, p=p,
size=n*num_samples).reshape(num_samples, n, len(p)).sum(axis = 1)
我们现在可以比较在sample1和sample2中获得特定值的概率。
(sample1 == [3, 2, 0]).all(axis=1).mean()
# output
0.1124632
(sample2 == [3, 2, 0]).all(axis=1).mean()
# Output
0.1125352
正如你所见,概率非常接近,微小的差异是由于样本大小的不同。随着样本大小趋于无穷大,这些概率趋于相同。
如前所述,如果我们有:

然后每个Xᵢ具有参数n和p的二项分布。因此,利用二项分布的性质我们得出:

我们定义随机向量X的均值为

我们还展示了X的一些元素的和具有二项分布。因此,对于每对i、j (i、j=1…k,i≠j),Xᵢ + Xⱼ具有参数n和pᵢ + pⱼ的二项分布。因此,我们有:

可以证明:

由此可以得出:

由于n、pᵢ和pⱼ是正数,我们得出Xᵢ和Xⱼ的协方差总是负数。让我们看看为什么相关性是负的。两个随机变量之间的负相关意味着当其中一个变量较高时,另一个变量倾向于较低,并且当一个变量增加时,另一个变量倾向于减少。我们知道X₁+X₂+ …+ Xₖ=n,所以假设X的所有其他分量(所有m ≠ i、j的Xm)都是常数,那么任何Xᵢ的增加都应该导致Xⱼ的减少,反之亦然。
我们定义X的协方差矩阵为:

所以,协方差矩阵中的(i, j)元素是变量Xᵢ和Xⱼ的协方差,而第i个对角元素给出了Xᵢ的方差。根据协方差的定义,我们知道Cov(Xᵢ, Xⱼ)=Cov(Xⱼ, Xᵢ)。因此,协方差矩阵是对称矩阵,其中的(i, j)元素等于(j, i)元素。
我们可以使用 Python 轻松计算在方程 1 中定义的多项分布的均值和协方差矩阵。方法mean()返回方程 2 中定义的均值向量:
n = 5
p=[0.5, 0.3, 0.2]
multinomial(n, p).mean()
# Output
array([2.5, 1.5, 1\. ])
方法cov()返回协方差矩阵(方程 3):
multinomial(n, p).cov()
# Output
array([[ 1.25, -0.75, -0.5 ],
[-0.75, 1.05, -0.3 ],
[-0.5 , -0.3 , 0.8 ]])
在这篇文章中,我们讨论了多项分布背后的数学原理,并展示了如何在 Python 中实现它。多项分布在科学、工程和金融领域被广泛使用。它可以用于那些有超过两个可能结果的应用场景,并且系统不能通过成功-失败的描述进行建模。
希望你喜欢阅读这篇文章。如果你有任何问题或建议,请告诉我。本文中的所有代码列表可以从 GitHub 上作为 Jupyter 笔记本下载,链接为:github.com/reza-bagheri/probability_distributions/blob/main/multinomial_distribution.ipynb
理解 Naive Bayes 算法
原文:
towardsdatascience.com/understanding-naive-bayes-algorithm-d753d3b76727
它是什么以及如何将其应用于实际场景
·发表于Towards Data Science ·阅读时间 6 分钟·2023 年 12 月 29 日
--

照片由Google DeepMind提供,刊登在Unsplash
今年,我的决心是回到数据科学的基础。我每天都在处理数据,但如果你在完成重复的任务,很容易忘记一些核心算法的运作。我计划每周在 Towards Data Science 上深入探讨一个数据算法。这一周,我将介绍 Naive Bayes。
如何发音 Naive Bayes
只是为了澄清,你可以在这里学习如何发音 Naive Bayes。
既然我们知道怎么说它,我们来看看它的含义……
什么是 Naive Bayes 分类器?
这个概率分类器基于贝叶斯定理,可以总结如下:
当第二个事件已经发生时,第一个事件的条件概率是“事件 B 发生的情况下 A 的概率和 A 的概率除以事件 B 的概率”的乘积。
P(A|B) = P(B|A)P(A) / P(B)
一个常见的误解是贝叶斯定理和条件概率是同义词。
然而,有一个区别——贝叶斯定理使用条件概率的定义来找出所谓的“反向概率”或“逆概率”。
换句话说,条件概率是 A 发生的情况下 B 的概率。贝叶斯定理则找出 B 发生的情况下 A 的概率。
朴素贝叶斯算法的一个显著特征是其使用序列事件。简单来说,通过获取后续的附加信息,初始概率会进行调整。我们将这些称为先验概率/边际概率和后验概率。主要的结论是,通过了解另一个条件的结果,初始概率会发生变化。
朴素贝叶斯的应用
一个好的例子是医疗检测。例如,如果患者正在处理胃肠问题,医生可能会怀疑炎症性肠病(IBD)。这种情况的初始概率约为 1.3%。
然而,根据患者的症状,医生会要求进行C-反应蛋白(CRP)血液检查,这是一种炎症标志物。如果该标志物高于某个阈值(>50mg/l),则患有炎症性肠病的概率增加。
随后,医生进行结肠镜检查,结果显示有明显的炎症。现在,IBD 诊断的概率再次增加。
我们将通过本文的例子来理解不同的朴素贝叶斯分类器如何在医疗保健中应用。
朴素贝叶斯分类器的假设
正如其名字所示,朴素贝叶斯分类器与算法的特征是“朴素”或条件独立的相关。
另一个假设是模型中的特征对最终结果的贡献是相等的。然而,这些假设在现实世界中经常被违反。其目的是简化计算单一概率的数学模型。
朴素贝叶斯分类器的优缺点
- 朴素贝叶斯分类器的一个优点是它在小样本上表现良好,即使在违反一些上述严格假设的情况下。
根据具体问题的不同,朴素贝叶斯分类器有一些缺点。
-
例如,与回归模型不同,朴素贝叶斯并不打算衡量特征的重要性——主要是因为假设所有特征对结果的贡献是相等的。
-
此外,还存在一个被称为“零频率问题”的情况。当某个特征组合在训练数据中出现的概率为零时,所有概率的乘积也将等于零。为了克服这一问题,分析师通常会在值中加上 1,以防其等于零的可能性。
朴素贝叶斯分类器通常用于文本分类任务,包括识别“垃圾邮件”和标记情感。
朴素贝叶斯模型的不同类型
以下三种类型的朴素贝叶斯模型主要在关于条件概率分布的假设上有所不同。
1. 伯努利朴素贝叶斯
当变量有二进制分布时,使用伯努利朴素贝叶斯。这意味着所有特征都是分类的,可以取值 1(存在)或 0(不存在)。
在使用 Python 进行朴素贝叶斯时,我们可以使用 [sklearn.naive_bayes](https://scikit-learn.org/stable/modules/classes.html#module-sklearn.naive_bayes) 的 BernoulliNB 包将尚未是分类变量的变量二值化。如果所有特征的截断点相同,这会很有用。例如,可以使用 0。然而,也有其他方法可以对连续数据进行二值化,包括 Pandas 的 cut 函数,如下所示:
df['CRGlevel_high'] **=** pd.cut(x**=**df['CRG'], bins**=**[0, 50, np.inf], labels**=**['No', 'Yes'])
2. 高斯朴素贝叶斯分类器
关于高斯朴素贝叶斯,首先要注意的是模型中的特征是连续且正态分布的。基于这一假设,在对模型中的数据点进行分类时,算法将最终值分配给具有最大后验概率的类别。
高斯朴素贝叶斯可以应用于常见的 Iris 数据集。运行此算法的基本 Python 代码如下:
from sklearn.naive_bayes import GaussianNB
gnb = GaussianNB()
# Train the classifier on the training data
gnb.fit(X_train, y_train)
高斯 NB 模块的文档可以在 这里 找到。
关于高斯朴素贝叶斯,值得了解的一点是,如果算法的假设得到满足,则模型的功能与逻辑回归相同。
假设从 Bhowmik(2015)的一篇文章中总结如下:
-
Y 是布尔值,受伯努利分布控制
-
Xi ∼ N(µij , σi),对所有的 i 适用
-
对于所有 i 和 k 6= i,Xi 和 Xk 在给定 Y 的条件下是条件独立的。
“如果朴素贝叶斯假设成立,NB 和 LR 产生的模型在渐近意义上是相同的。” 来源
在我们使用的示例中,可能需要考虑来自血液检查的多个值来确定一个人是否被诊断为 IBD。因此,我们可以输入原始的 CRG 值,以及 其他 IBD 标志物,如 Calprotectin 和 SedRate,而不是二值化连续变量。
3. 多项式朴素贝叶斯
当特征值表示相对频率而非二进制字段时,我们可以使用多项式朴素贝叶斯。
这对于文本分类问题很有用,这些问题通常依赖于 token 的频率。在我们开始讨论的 IBD 示例中,我们可以考虑使用电子健康记录(EHR)笔记来应用多项式朴素贝叶斯。
根据与患者症状相关的笔记,模型可能会使用“炎症”、“腹泻”、“便秘”等关键字的频率来判断患者是否可能被诊断为溃疡性结肠炎(UC)或克罗恩病。
当然,仅凭文本并不能确保做出诊断,但一个潜在的使用案例是使用 NB 分类器根据患者的症状识别出患有炎症性肠病(IBD)风险的患者,并鼓励医疗从业者订购额外的检查,以排除差异诊断或确认诊断。
我在医疗保健中的建议应用
总之,朴素贝叶斯和贝叶斯定理做出了一些假设,这些假设在现实世界中可能并不适用。然而,当你处理一个小数据集并寻找一个好的基线模型时,NB 可以很好地工作,因为它计算简单且易于理解。
基于我的研究,医疗保健中的潜在应用似乎很有前景。例如,在排除差异诊断时,可以使用朴素贝叶斯分类器来确定需要进行多少测试才能得出最终诊断。由于每个特征被期望是独立的并且贡献相等,这可能有助于指导从业者确定哪种测试组合是“足够的”以达成结论性诊断。
我预期这可以减少不必要的测试,并缩短最终诊断的时间,这样可以为医生节省时间,为健康保险计划减少费用,并为患者减少接受治疗和进行侵入性检查所需的时间。
了解 NeRFs
场景表示的重大突破
·发布于 Towards Data Science ·11 min read·2023 年 4 月 28 日
--

正如我们通过 DeepSDF [2] 和 SRNs [4] 等方法所见,将 3D 对象和场景编码在前馈神经网络的权重中是一种内存高效的隐式 3D 数据表示,既准确又高分辨率。然而,到目前为止,我们看到的方法还无法以足够的保真度捕捉现实和复杂的场景。相反,离散表示(例如三角网格或体素网格)在内存分配充足的情况下,能产生更准确的表示。
这在 Neural Radiance Fields (NeRFs) [1] 的提议下发生了改变,NeRFs 使用前馈神经网络来建模场景和物体的连续表示。NeRFs 使用的表示称为辐射场,与 先前 提案 有些不同。特别是,NeRFs 将五维坐标(即空间位置和视角方向)映射到体积密度和视角相关的 RGB 颜色。通过在不同视角和位置上累积这些密度和外观信息,我们可以渲染出逼真的新视角场景。
像SRNs [4]一样,NeRFs 可以仅使用一组图像(以及它们相关的相机姿态)对底层场景进行训练。与之前的方法相比,NeRF 渲染在质量和数量上都更好。值得注意的是,NeRFs 甚至可以捕捉复杂效果,如物体表面的视角依赖反射。通过在前馈神经网络的权重中隐式建模场景,我们在不需要过多内存成本的情况下实现了离散场景表示的准确性。

(摘自[1])
为什么这篇论文很重要? 本帖子是我关于 3D 形状和场景深度学习系列的一部分。NeRFs 在这一领域是一个革命性的提案,因为它们能够从任意视角实现极其准确的 3D 重建。NeRFs 生成的场景表示质量非常高,我们将在本帖的其余部分看到这一点。
背景
理解 NeRFs 所需的大部分背景概念已在之前的帖子中涵盖,包括:
在介绍 NeRFs 的工作原理之前,我们只需要覆盖几个更多的背景概念。
位置编码
NeRFs 不是直接将[x, y, z]坐标作为输入传递给神经网络,而是将这些坐标转换为高维位置嵌入。我们在之前关于transformer 架构的帖子中讨论过位置嵌入,因为位置嵌入用于向自注意力模块提供令牌排序和位置的概念。

(摘自[1])
简而言之,位置嵌入将一个标量数作为输入(例如,坐标值或表示序列中位置的索引),并产生一个高维向量作为输出。我们可以在训练期间学习这些嵌入,或者使用固定函数生成它们。对于 NeRFs,我们使用上面显示的函数,该函数将一个标量p作为输入,并产生一个2L维的位置编码作为输出。
其他内容
在本概述中,我们可能会遇到一些其他(可能)不熟悉的术语。让我们现在快速澄清这些术语。
端到端训练。 如果我们说一个神经架构可以“端到端”学习,这仅仅意味着系统的所有组件都是可微分的。因此,当我们为某些数据计算输出并应用我们的损失函数时,我们可以通过整个系统(即,端到端)进行微分,并用梯度下降进行训练!
并非所有系统都可以进行端到端训练。例如,如果我们正在建模表格数据,我们可能会执行特征提取过程(例如,一热编码),然后在这些特征上训练机器学习模型。由于特征提取过程是手工设计的且不可微分,我们不能对系统进行端到端训练!
朗伯特反射。 在阅读 NeRFs 之前,这个术语对我完全陌生。朗伯特反射指的是物体表面的反射特性。如果一个物体的表面是哑光的,并且从不同角度看不会改变,我们称这个物体是朗伯特型的。另一方面,一个“光亮”的物体,其反射光线的方式根据观察角度不同而变化,则被称为非朗伯特型的。
建模 NeRFs

(来自 [1])
使用 NeRF 渲染场景视角的高层次过程如下:
-
使用 射线行进 方法生成场景的 3D 点和视角方向样本。
-
将点和视角方向作为输入提供给前馈神经网络,以生成颜色和密度输出。
-
执行体积渲染,将场景中的颜色和密度累积到 2D 图像中。
我们现在将详细解释这个过程的每个组件。
辐射场。 如前所述,NeRFs 模型是一个 5D 向量值(即,意味着函数输出多个值)函数,称为辐射场。该函数的输入是一个 [x, y, z] 空间位置和一个 2D 视角方向。视角方向有两个维度,对应于表示 3D 空间中方向的两个角度;见下文。

3D 空间中的方向可以用两个角度表示。
这个函数的输出有两个组成部分:体积密度和颜色。颜色只是一个 RGB 值。然而,这个值是视角依赖的,这意味着给定不同的视角方向作为输入,颜色输出可能会变化!这种特性使 NeRFs 能够捕捉反射和其他视角依赖的外观效果。相比之下,体积密度仅依赖于空间位置,捕捉不透明度(即,光线通过该位置时积累的程度)。

NeRF 使用前馈神经网络建模辐射场(来自[1])
神经网络。 在[1]中,我们用前馈神经网络建模辐射场,该网络接收 5D 输入并训练以生成相应的颜色和体积密度作为输出;见上文。然而,请记住,颜色是视角相关的,而体积密度则不是。为了解决这个问题,我们首先将输入的 3D 坐标通过几个前馈层,这些层产生体积密度和特征向量作为输出。然后,将这个特征向量与视角方向连接,并通过一个额外的前馈层来预测视角相关的 RGB 颜色;见下文。

NeRF 的前馈架构。
体积渲染(简要说明)。 体积渲染是一个太复杂的话题,无法在这里深入讨论,但我们应该了解以下内容:
-
它可以从离散数据样本(例如颜色和密度值)生成底层场景的图像。
-
它是可微的。
对于那些对体积渲染的更多细节感兴趣的人,请查看这里的解释和[1]的第四部分。
大局观。 NeRF 使用前馈网络生成关于场景几何和外观的相关信息,这些信息沿着众多不同的相机光线(即从特定相机视点向场景中的某个方向移动的 3D 空间中的一条线)进行传递,然后使用渲染将这些信息汇聚成一张 2D 图像。
值得注意的是,这两个组件都是可微的,这意味着我们可以端到端地训练整个系统!给定一组带有对应相机姿态的图像,我们可以通过仅生成/渲染已知视角,并使用(随机)梯度下降来最小化 NeRF 输出与实际图像之间的误差,来训练 NeRF 以生成新的场景视点;见下文。

(来自[1])
一些额外细节。 我们现在了解了 NeRF 的大部分组件。然而,迄今为止我们描述的方法在[1]中实际上被证明是低效的,并且通常不能很好地表示场景。为了改进模型,我们可以:
-
用位置嵌入替换空间坐标(包括空间位置和视角方向)。
-
采用层次采样方法进行体积渲染。
通过使用位置嵌入,我们将前馈网络的输入(即空间位置和视角方向坐标)映射到更高维度。先前的工作表明,与直接使用空间或方向坐标作为输入相比,这种方法使神经网络更好地建模场景的高频(即变化很大/很快)特征[5]。这使得 NeRF 的输出质量大大提高;见下文。

(来自 [1])
NeRF 使用的分层采样方法使渲染过程更高效,只对可能影响最终渲染结果的地点和视角方向进行采样(并通过前馈神经网络传递)。这样,我们仅在需要的地方评估神经网络,避免在空白或遮挡区域浪费计算。
给我们展示一些结果吧!
NeRFs 被训练为一次仅表示一个场景,并在多个数据集上进行评估,这些数据集包括合成和真实对象。

(来自 [1])
如上表所示,NeRFs 显著超越了 SRNs [4] 和 LLFF [6] 等替代方法。除了定量结果之外,与其他方法的输出进行视觉对比也非常有用。首先,我们可以立即看出,使用位置编码和以视角依赖的方式建模颜色是非常重要的;见下文。

(来自 [1])
我们会立即注意到的一个改进是,NeRFs — 因为它们以视角依赖的方式建模颜色 — 能够捕捉场景中的复杂反射(即非朗伯特特征)和视角依赖的模式。此外,NeRFs 还能够以惊人的精度建模底层几何体的复杂方面;见下文。

(来自 [1])
NeRF 场景表示的质量在以视频形式查看时最为明显。正如下面的视频所示,NeRFs 以令人印象深刻的准确性和一致性建模了底层场景,并在不同视角之间保持一致。
要获取更多可以使用 NeRF 生成的逼真场景视角的示例,我强烈建议查看这里链接的项目网站!
主要结论
正如我们在评估中所见,NeRFs 在场景表示质量上取得了巨大的突破。因此,这项技术在人工智能和计算机视觉研究社区中获得了极大的关注。由于其场景表示的质量,NeRF 的潜在应用(例如虚拟现实、机器人等)几乎是无穷无尽的。主要结论列在下方。

(来自 [1])
NeRFs 捕捉复杂细节。 使用 NeRFs,我们能够捕捉场景中的细粒度细节,例如船只中的装配材料;见上文。除了几何细节,NeRFs 还能处理非朗伯特效应(即反射和基于视角的颜色变化),因为它们以视角依赖的方式建模颜色。
我们需要智能采样。 到目前为止,我们见过的所有 3D 场景建模方法都使用神经网络在 3D 空间上建模一个函数。这些神经网络通常会在所考虑的空间体积内的每个空间位置和方向上进行评估,如果处理不当,可能会非常昂贵。对于 NeRF,我们使用一种分层采样方法,仅评估可能影响最终渲染图像的区域,这大大提高了采样效率。类似的方法也被先前的工作采用;例如, ONets [3] 使用基于 octree 的 分层采样方法 更高效地提取对象表示。
位置嵌入效果很好。 到目前为止,我们见过的大多数场景表示方法都直接将坐标值作为输入传递给前馈神经网络。通过 NeRF,我们发现将这些坐标进行位置嵌入要好得多。特别是,将坐标映射到更高维度似乎允许神经网络捕捉场景几何和外观的高频变化,这使得结果场景渲染更加准确且视角一致。
仍然节省内存。 NeRF 隐式地建模了基础场景的连续表示。这种表示可以在任意精度下进行评估,并且具有固定的内存成本——我们只需存储神经网络的参数!因此,NeRF 可以在不占用大量内存的情况下提供准确的高分辨率场景表示。
“至关重要的是,我们的方法克服了在高分辨率下建模复杂场景时离散体素网格的高昂存储成本。” — 来源于 [1]
局限性。 尽管 NeRF 在先进技术方面取得了显著进展,但它们并不完美——表示质量仍有改进空间。然而,NeRF 的主要局限性在于它们每次只能建模一个场景,并且训练成本高(即,每个场景在单个 GPU 上需要 2 天)。未来在这一领域的进展如何找到更高效的生成 NeRF 质量场景表示的方法将会非常有趣。
结束语
非常感谢你阅读这篇文章。我是 Cameron R. Wolfe,Rebuy 的人工智能总监,也是莱斯大学的博士生。我研究深度学习的经验和理论基础。你也可以查看我在 medium 上的 其他文章!如果你喜欢这篇文章,请关注我的 twitter 或订阅我的 Deep (Learning) Focus 新闻通讯,在这里我通过对该领域流行论文的易懂概述,帮助读者深入理解深度学习研究中的主题。
参考文献
[1] Mildenhall, Ben 等人。“Nerf:将场景表示为用于视图合成的神经辐射场。” ACM 通讯 65.1 (2021): 99–106。
[2] Park, Jeong Joon 等人。“Deepsdf:学习用于形状表示的连续符号距离函数。” IEEE/CVF 计算机视觉与模式识别会议论文集。2019 年。
[3] Mescheder, Lars 等人。“占用网络:在函数空间中学习 3d 重建。” IEEE/CVF 计算机视觉与模式识别会议论文集。2019 年。
[4] Sitzmann, Vincent、Michael Zollhöfer 和 Gordon Wetzstein。“场景表示网络:连续的 3d 结构感知神经场景表示。” 神经信息处理系统进展 32 (2019)。
[5] Rahaman, Nasim 等人。“神经网络的谱偏差。” 国际机器学习会议。PMLR,2019 年。
[6] Mildenhall, Ben 等人。“局部光场融合:带有规定性采样指南的实用视图合成。” ACM 图形学交易 (TOG) 38.4 (2019): 1–14。
理解机器学习中的噪声数据和不确定性
原文:
towardsdatascience.com/understanding-noisy-data-and-uncertainty-in-machine-learning-4a2995a84198
你的机器学习模型无法正常工作的真正原因
·发表于 Towards Data Science ·阅读时间:9 分钟·2023 年 1 月 23 日
--
人工智能和机器学习领域比以往任何时候都更火热。随着像 Chat GPT 和 Stable Diffusion 这样的模型席卷全球,人工智能和机器学习的炒作重新回归并吸引了大众。面对这些炒作,我们必须提醒自己,机器学习成功的关键在于高质量数据。
在缺乏优质训练数据的情况下,监督机器学习模型没有任何用处。不幸的是,大多数现实世界的数据科学项目失败是因为在数据源质量完全了解之前对模型性能有不切实际的期望。本文将尝试提供对噪声数据的直观理解以及为什么机器学习模型无法有效工作的原因。我们将探讨监督学习和确定性函数的本质、不同类型的模型不确定性,并讨论减少这种不确定性和管理期望的方法。
监督学习表述
从根本上讲,监督机器学习就是函数近似。我们向模型提供一些输入(X)和目标(y),并期望模型通过优化目标函数来近似(或学习)将 X 映射到 y 的函数。更正式地说,监督学习的表述大致如下:

监督学习的简要表述。图像来源:作者。
给定一个数据集(X,y),我们假设 X 和 y 之间的关系至少部分是确定性的。随机误差项 epsilon 的方差在函数 f(X)的确定性水平中发挥了至关重要的作用。epsilon 的方差越高,X 和 y 之间的关系就越随机,越不容易预测。与监督学习的常见教学方式相反,值得注意的是,epsilon 可以(并且通常是)X 的函数。这意味着噪声量和相应的模型不确定性在 X 的不同区域中有所不同。
为了更好地理解确定性和随机误差的作用,我们将深入研究确定性函数的性质,检查它们的特征以及在机器学习中的影响。
确定性函数
确定性函数的输出完全由输入指定。这意味着对于定义域中的每个输入,函数都输出一个唯一的量。这些是我们从第一门代数课程开始就见过的大多数函数:

确定性函数示例。图像由作者提供。
在上述每个示例中,传递到函数中的每个输入都产生一个单一的输出。注意,这并不意味着函数对每个可能的数字都定义。例如,函数(4)仅在 x1 > 0 且 x1 和 x2 的和不为零时定义。
从视觉上看,在一维中,确定性函数通过“垂直线测试”——这意味着如果我们在定义域中的任何点画一条垂直线,它最多会与函数相交一次:

确定性函数的垂直线测试。图像由作者提供。
与此相比,纯非确定性函数则不同。这些函数对于单一输入可能有多个输出。

非确定性函数的垂直线测试。图像由作者提供。
在上图中,垂直线 x = 4 穿过曲线的-2 和 2 处。如果有人问我们 f(4)是什么,我们必须说是-2 或 2。在最极端的情况下,非确定性函数可能会对单一输入输出无限多个值。
尽管上述非确定性函数在可视化和量化上很简单,但机器学习模型永远无法学习它。为了解释这一点,假设我们按如下方式对函数进行采样:

非确定性函数的样本。图像由作者提供。
我们将使用上述样本作为训练集,并拟合一个决策树回归器:
import numpy as np
from sklearn.tree import DecisionTreeRegressor
# Create the training set for the function y = +- sqrt(x)
y_train = np.random.normal(0, 20, 100)
x_train = y_train**2
# Fit a decision tree regressor to the training data
model = DecisionTreeRegressor()
model.fit(x_train.reshape(-1, 1), y_train)
# Create the test set for the function y = +- sqrt(x)
y_test = np.random.normal(0, 20, 100)
x_test = y_test**2
# Make predictions on the test set
preds = model.predict(x_test.reshape(-1, 1))
我们可以可视化模型在测试集上的表现:

决策树回归器在测试集上的预测叠加。图像由作者提供。
决策树在测试集上本质上是在猜测。由于领域中的每个非零输入映射到两个不同的值,机器学习模型无法知道在测试集上预测哪个值。由于基础函数的性质,训练数据中的噪声会导致任何机器学习模型做出具有高不确定性的预测。
不确定性来源
现在我们已经了解了确定性和非确定性/随机函数之间的区别,让我们探讨从学习这些函数中可能产生的不确定性类型。
数据(随机)不确定性
数据的不确定性,也称为随机不确定性,源于观察数据集的固有复杂性。在分类设置中,这通常表现为重叠的类别。例如,以下散点图描绘了二维空间中重叠的类别:

带有重叠类别的分类数据。图片来源:作者。
对于 (x1, x2) 邻域中的大多数观测值,没有明显的方法来区分两个类别。虽然确实存在一个足够复杂的分类器可以完美分类上述数据集中的例子,但同样的分类器在测试集上的表现可能不会比随机猜测更好。
在回归设置中,数据不确定性通常源于基础数据中的附加噪声。数据不确定性会出现在类似于以下数据集的情况中:

噪声回归数据。图片来源:作者。
在这种情况下,x 是唯一考虑的特征,对于 x 的任何邻域,y 都可能取很多不同的值。例如,如果我们基于数据来预测 x = 0 时 y 的值,没有明显的答案显现出来。一般来说,附加噪声分布的方差越大,数据不确定性就会越高。

带有附加噪声的二次函数样本。图片来源:作者。
如前所述,噪声水平和数据不确定性在领域的不同区域中常常有所不同。数据不确定性通常在领域的边缘区域较高,如下所示:

噪声和数据不确定性随 x 变化。附加噪声项的方差是 x 的函数。图片来源:作者。
数据的不确定性通常表现为模型中遗漏了关键特征。为了说明这一点,考虑以下数据集,我们希望使用 x1 来预测 y:

像愤怒的青蛙一样的噪声数据。图片来源:作者。
如果仅使用 x1 作为特征来预测 y,没有任何机器学习模型能够表现良好。例如,当 x1 = -0.49 时,在这个数据集中 y 取 14 个不同的值。在没有其他特征的情况下,预测 y 看起来无望。然而,上述数据集实际上是以下函数的一个样本:

一个方程。图像由作者提供。
如果我们从 x2 取样,并在三维中可视化 x1、x2 和 y,则会出现一个清晰的表面:

一个三维表面。图像由作者提供。
互动表面。图像由作者提供。
x1 和 y 之间的关系噪声较大,任何试图学习这种关系的模型都会导致高数据不确定性。然而,(x1, x2) 和 y 之间的关系没有噪声,因为它完全由一个封闭形式的方程决定。只要有足够的训练数据,足够非线性的模型可以学习这种关系,从而没有数据不确定性。
许多金融应用面临数据不确定性,因为许多未观察到的人类行为会影响市场。例如,无数因素影响一个人是否会违约。即使两个贷款接收者有相同的信用档案并在特征空间中看起来相同(大多数金融机构有大量的特征),仍然有可能因为生活紧急情况而导致接收者 A 违约,而接收者 B 没有违约。除非违约模型接触到预测生活紧急情况的数据,否则这些情况下总会存在数据不确定性。
知识(认知)不确定性
知识(认知)不确定性来自于在领域的某些区域稀疏采样的数据。考虑以下数据集,我们希望建模 x 和 y 之间的关系:

高知识不确定性。图像由作者提供。
当 x 在 1 到 3 之间或在 4 到 7 之间取值时,这个数据集没有关于 y 的信息。机器学习模型在预测这些区域的 y 值时不得不做出盲目猜测。幸运的是,与数据不确定性不同,知识不确定性可以通过采样缺失数据的区域来减少。如果我们对领域进行完全采样,x 和 y 之间的关系将变得清晰:

通过采样消除知识不确定性。图像由作者提供。
知识不确定性在实际应用中很常见,因为大多数数据集不是从均匀分布中采样的,导致领域中的稀疏区域。一般来说,领域采样越均匀,建模过程中的知识不确定性就会越小。
处理噪声数据和不确定性的方法
现在我们对噪声数据和不确定性的性质有了一些直觉,让我们探讨一些可以采取的实际措施来应对这个问题。
1. 停止尝试寻找更好的模型
快速失败并知道何时转移是无价的数据科学技能。我们生活在一个几乎对任何机器学习问题都有最先进、高性能模型的时代。任何在模型选择或超参数调整过程中遇到次优结果的数据科学家应该寻求其他地方的性能提升。换句话说,模型可能并不需要改进,数据才需要。
2. 获取更多数据
正如我们之前所见,表现不佳的模型通常是由于数据不完整。即使是一个看似不重要的特征,在与更多特征一起使用时也可以提供显著的预测能力。因此,数据科学家必须确保他们对问题及成功建模所需的数据源有全面的理解。获取正确的数据和工程化良好的特征可以说是数据科学家工作中最困难但最关键的部分。数据科学家应该能够与利益相关者一起权衡获取更多数据的成本和收益,因为这最终是带来最大商业价值的地方。
3. 量化模型预测的不确定性
对于那些噪声数据不可避免的高风险机器学习问题,模型量化预测的不确定性通常是有价值的。经过校准的分类器提供的预测概率的不确定性估计比预测类别可能更有用。在回归中,像分位数回归、符合预测、贝叶斯神经网络和自然梯度提升等方法可以用于生成不确定性估计。简而言之,拥有一个在做出预测时能够说“我不知道”的模型可以非常有价值。
4. 管理期望
作为数据科学家,重要的是要记住机器学习是实现特定目标的工具,应该在清楚理解基本业务问题的情况下使用。全面理解问题以及机器学习在解决问题中的角色,可以更准确地设定模型性能的期望,从而最终更有效地使用这一工具。这将有助于开发出一个针对业务特定需求的机器学习模型,提供最佳结果。机器学习是一种手段。通常情况下,不需要一个完美的模型来解决问题,但正确利用不完美的模型则是必要的。
成为会员: https://harrisonfhoffman.medium.com/membership
喜欢我的文章?请给我买杯咖啡: https://www.buymeacoffee.com/HarrisonfhU
参考文献
- 使用 Catboost 估计不确定性 — https://catboost.ai/en/docs/references/uncertainty
理解预测性维护 — 数据采集与信号去噪
·发表于 Towards Data Science ·阅读时长 10 分钟·2023 年 11 月 8 日
--

图片由 Michael Dziedzic 提供,来源于 Unsplash
文章目的
我想开始一系列文章,给你提供预测性维护的实际操作经验,并让你更容易入门信号处理。在这篇文章中,我们将重点关注数据采集和信号清理。如果你对某些部分感兴趣,我会考虑进一步详细讲解。在这篇文章的下一部分,我为你准备了一些实际练习。你可以使用我准备的代码进行自己的实验,通过实践来学习。
预测性维护的数据科学
数据科学中的预测性维护就像是为机器提供一种超级智能的护理方式。我们不是在机器坏了之后才进行修理,而是利用高级计算机程序和过去的数据来预测机器可能出现的问题。这有点像为机器提供了一个水晶球!通过这种方式,公司可以节省资金并延长重要机器的使用寿命。这种方法包括实时监控机器,收集数据,并使用智能计算机程序来告诉我们何时需要进行维护。因此,我们可以在问题出现之前进行修复,就像是在机器生病之前为其进行健康检查!
数据采集与处理

图片由 Mika Baumeister 提供,来源于 Unsplash
一切都始于数据。我们需要深入了解一些通信理论原则,如 Shannon–Hartley 定理和 Nyquist 率,以确保传感器数据的准确和高效传输。
Shannon–Hartley 定理
Shannon–Hartley 定理就像是关于通过通信频道传输多少信息的规则手册。它告诉我们频道的宽度,即频道能处理的数据量,至关重要。因此,在选择用于监控机器或传感器的设备或工具之前,我们需要确保频道足够宽,以处理我们想要的所有数据而不会丧失质量。
为了更好地利用频道,我们可以聪明地使用频道。这就像是找到最佳的方式来安排数据,以便我们能最有效地利用频道空间。这被称为优化频谱效率。因此,在选择传感器时,我们应该选择那些在可用频道空间内表现优秀的传感器。
定理应用
评估频道带宽 在选择传感器之前,仔细检查通信频道可以处理多少数据。如果频道宽度不足,考虑升级频道或寻找需要较少带宽的传感器。
优化传感器选择
选择那些在使用可用频道空间方面效率高的传感器。一些传感器可能在传输数据时更有效,不占用频道过多的空间。
Nyquist 率

图片由 Jair Lázaro 拍摄,来源于 Unsplash
Nyquist 率就像是我们应该用传感器收集数据的速度限制。如果我们收集数据的速度太慢,我们可能会错过关于事物运作的重要细节。因此,当我们设置系统来监控事物时,需要确保我们以能够准确展示发生情况的速度收集数据,而不会造成混淆。
我们必须确保传感器以遵循 Nyquist 率的速度收集数据。这意味着数据收集速度要足够快,以捕捉所有重要细节,而不会遗漏任何重要信息。

Nyquist 定律
确保传感器能够以至少是最快变化(fmax) 两倍的速度(fs) 测量。这可以防止遗漏重要细节,并确保测量的准确性。选择速度与所需测量匹配的传感器。
Nyquist 的超级英雄——抗混叠滤波器
为了避免混淆或错误,我们使用像抗锯齿滤波器这样的特殊滤波器。这些滤波器帮助我们的传感器专注于信息的重要部分,并去除任何额外的噪音或混淆的细节。这就像使用放大镜来清晰地查看事物一样。因此,当我们使用传感器监控机器时,我们应该确保这些滤波器到位,以便获得最佳和最准确的信息。那么,什么是锯齿效应和抗锯齿呢?
锯齿效应

锯齿效应和抗锯齿的比较 来源:www.zilliondesigns.com/
想象一下你正试图拍摄一辆快速移动的汽车。然而,你的相机每秒只拍摄一张照片。如果汽车移动得非常快,当你拍摄下一张照片时,它可能已经在完全不同的位置了。因此,当你查看这些照片时,汽车看起来像是处于奇怪的位置,因为你错过了它在每张照片之间的移动。这种事物看起来与实际情况不同的奇怪效果被称为锯齿效应。
抗锯齿
现在,想象你有一台神奇的相机,可以非常快速地拍摄大量照片。它不是每秒拍一张,而是在同一秒内拍摄许多张照片,捕捉汽车移动的每一个细节。当你查看这些照片时,你会看到汽车移动的平滑而准确的表现。确保捕捉到所有细节并避免锯齿效应的过程称为抗锯齿。
简单来说,抗锯齿就像拥有一台超级快速的相机,以确保事物看起来如实,无任何奇怪的失真或模糊。它有助于创建清晰而准确的图像,尤其是在事物快速移动或变化迅速时。
好的,但为什么在信号处理中这很重要?这些动画将向你展示我们如何使用它来去噪实际信号,并捕捉这些信号的“核心”。

信号处理中的工作原理 来源:siemens.com
实践经验

shraga kopstein 的照片,来源于 Unsplash
让我们深入代码,动手实践吧!首先,我们将生成一个示例信号,并设置一个绘图函数用于可视化。在这个例子中,我们将使用基本的正弦波,并尝试调整其参数。
import numpy as np
import matplotlib.pyplot as plt
def generate_signal(frequency, duration, sampling_rate):
t = np.linspace(0, duration, int(sampling_rate * duration), endpoint=False)
signal = np.sin(2 * np.pi * frequency * t)
return t, signal
def plot_signals(t, original_signal, filtered_signal, title):
plt.figure(figsize=(10, 6))
plt.plot(t, original_signal, label='Original Signal', linewidth=2)
plt.plot(t, filtered_signal, label='Filtered Signal', linestyle='dashed', linewidth=2)
plt.title(title)
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.legend()
plt.grid(True)
plt.show()
接下来,我们将创建一个使用简单滤波器(如巴特沃斯滤波器)进行信号抗锯齿的示例。
将巴特沃斯滤波器视为信号中不同频率的门控器。它允许低频通过,同时减少高频。调整如滤波器阶数和截止频率等参数可以微调其选择性。这有助于我们防止诸如混叠等问题,确保信号的更清晰表示。让我们在代码中查看实际效果!
from scipy.signal import butter, lfilter
def apply_antialiasing(signal, cutoff_frequency, sampling_rate, order=4):
# Design a low-pass Butterworth filter (maximal flat magnitute)
nyquist = 0.5 * sampling_rate
# Nyquist law in practice
normal_cutoff = cutoff_frequency / nyquist
b, a = butter(order, normal_cutoff, btype='low', analog=False)
# Apply the filter to the signal
filtered_signal = lfilter(b, a, signal)
return filtered_signal
让我们进行一次实验
frequency = 30.0 # Frequency of the signal
duration = 1.0 # Duration of the signal in seconds
sampling_rate = 100.0 # Sampling rate in Hz
cutoff_frequency = 20.0 # Cutoff frequency of the anti-aliasing filter
# Generate a signal
t, original_signal = generate_signal(frequency, duration, sampling_rate)
# Apply anti-aliasing filter
filtered_signal = apply_antialiasing(original_signal, cutoff_frequency, sampling_rate)
# Plot the original and filtered signals
plot_signals(t, original_signal, filtered_signal, 'Original and Filtered Signals')

这个实验的输出
当涉及到采样时,我们的纯正弦波看起来并不像我们期望的那样“美丽”。某些部分被截断了,这是由于采样过程造成的。
我鼓励你复制这段代码并尝试调整参数。这非常有趣!
我们也可以尝试不同类型的滤波器。
from scipy.signal import butter, cheby1, cheby2, ellip, lfilter
def apply_filter(signal, cutoff_frequency, sampling_rate, filter_type='butter', order=4):
nyquist = 0.5 * sampling_rate
normal_cutoff = cutoff_frequency / nyquist
if filter_type == 'butter':
b, a = butter(order, normal_cutoff, btype='low', analog=False)
elif filter_type == 'cheby1':
b, a = cheby1(order, 5, normal_cutoff, btype='low', analog=False)
elif filter_type == 'cheby2':
b, a = cheby2(order, 40, normal_cutoff, btype='low', analog=False)
elif filter_type == 'ellip':
b, a = ellip(order, 5, 40, normal_cutoff, btype='low', analog=False)
filtered_signal = lfilter(b, a, signal)
return filtered_signal
我们将稍微修改我们的绘图函数,以便将所有结果汇集在一个图中。
def plot_signals_subplots(t, original_signal, filtered_signals, titles):
num_filters = len(filtered_signals)
fig, axes = plt.subplots(num_filters + 1, 1, figsize=(12, 2 * (num_filters + 1)))
# Plot original signal
axes[0].plot(t, original_signal, label='Original Signal', linewidth=2, alpha=0.7)
axes[0].set_title('Original Signal')
axes[0].set_xlabel('Time (s)')
axes[0].set_ylabel('Amplitude')
axes[0].legend()
axes[0].grid(True)
# Plot filtered signals
for i, (filtered_signal, filter_type) in enumerate(zip(filtered_signals, filter_types)):
label = f'Filtered Signal ({filter_type})'
axes[i + 1].plot(t, filtered_signal, label=label, linestyle='dashed', linewidth=2)
axes[i + 1].set_title(f'Filtered Signal ({filter_type})')
axes[i + 1].set_xlabel('Time (s)')
axes[i + 1].set_ylabel('Amplitude')
axes[i + 1].legend()
axes[i + 1].grid(True)
plt.tight_layout()
plt.show()
# Generate a signal (exactly the same as previous)
t, original_signal = generate_signal(frequency, duration, sampling_rate)
# Apply different filters
filter_types = ['butter', 'cheby1', 'cheby2', 'ellip']
filtered_signals = [apply_filter(original_signal, cutoff_frequency, sampling_rate, f) for f in filter_types]
# Plot the original and filtered signals
plot_signals(t, original_signal, filtered_signals, 'Original and Filtered Signals')

在美丽的纯正弦波中,某些效果可能不容易看出。让我们加入一些噪声。
def generate_signal_with_noise(frequency, duration, sampling_rate, noise_amplitude=0.1):
t = np.linspace(0, duration, int(sampling_rate * duration), endpoint=False)
signal = np.sin(2 * np.pi * frequency * t)
# Add noise to the signal
noise = noise_amplitude * np.random.normal(size=len(signal))
signal_with_noise = signal + noise
return t, signal_with_noise
再次运行实验。
# Generate a noisy signal
t, original_signal = generate_signal(frequency, duration, sampling_rate)
t, original_signal_with_noise = generate_signal_with_noise(frequency, duration, sampling_rate, noise_amplitude=0.8)
# Apply different filters
filter_types = ['butter', 'cheby1', 'cheby2', 'ellip']
filtered_signals = [apply_filter(original_signal_with_noise, cutoff_frequency, sampling_rate, f) for f in filter_types]
# Plot the original and filtered signals
plot_signals_subplots(t, original_signal_with_noise, filtered_signals, 'Original and Filtered Signals')
现在,我们可以观察去噪的效果,并了解各种滤波器的工作原理。

在这个实验中,我们精确识别了我们生成的函数(一个纯正弦波)。在研究实际数据时,我们旨在确定描述现象的函数,例如故障信号或持续增加的设备磨损。噪声在我们的数据中总是存在,理解如何处理它至关重要。
季节性分解
季节性分解是信号处理中的一种强大技术,用于理解信号的不同组成部分并研究其季节性模式。在这个背景下,信号可以表示各种现象,如经济数据、环境变量甚至电信号。季节性分解的目标是将复杂信号拆解为其基本部分,使分析和解释变得更容易。
一个信号通常由三个主要组件组成:趋势、季节性成分和残差(或噪声)。趋势代表信号的长期行为或整体模式,季节性成分捕捉在规律间隔内发生的重复模式,而残差是无法用趋势或季节性解释的剩余变异或噪声。
通过将信号分解为这些组件,分析师可以洞察潜在的模式和趋势,帮助他们做出更明智的决策。这个过程在金融、气候科学和制造业等领域尤为重要,因为理解季节性变化对于准确预测和有效决策至关重要。
在这次季节性分解的探索中,我们将深入探讨信号分解的方法,逐一检查每个组件,并学习如何调查和解释从这种分析方法中获得的见解。
from statsmodels.tsa.seasonal import seasonal_decompose
# Generate a signal with noise
t, original_signal_with_noise = generate_signal_with_noise(frequency, duration, sampling_rate, noise_amplitude=0.8)
# Perform seasonal decomposition
decomposition = seasonal_decompose(original_signal_with_noise, period=25) # Adjust the period as needed
# Get the trend, seasonal, and residual components
trend = decomposition.trend
seasonal = decomposition.seasonal
residual = decomposition.resid
# Plot signals in separate subplots
plot_signals_subplots(t, original_signal_with_noise, [trend, seasonal, residual], 'Original Signal with Noise and Decomposed Components')

通过添加有趣的信号来使我们的正弦波更有趣。我将调整信号生成以模拟设备的磨损,想象它像机器的零部件随着时间的推移逐渐磨损。这可能会导致由于摩擦增加而产生更多的振动。
def generate_signal_with_wear_and_noise(frequency, duration, sampling_rate, wear_slope=0.02, noise_amplitude=0.1):
t = np.linspace(0, duration, int(sampling_rate * duration), endpoint=False)
signal = np.sin(2 * np.pi * frequency * t)
# Simulate equipment wearing with a linear trend
wear = wear_slope * t
signal_with_wear = signal + wear
# Add noise to the signal
noise = noise_amplitude * np.random.normal(size=len(signal))
signal_with_wear_and_noise = signal_with_wear + noise
return t, signal_with_wear_and_noise
开始实验。
# Generate a signal with equipment wearing and noise
t, original_signal_with_wear_and_noise = generate_signal_with_wear_and_noise(
frequency, duration, sampling_rate, wear_slope=0.5, noise_amplitude=0.1
)
# Perform seasonal decomposition
decomposition = seasonal_decompose(original_signal_with_wear_and_noise, period=25) # Adjust the period as needed
# Get the trend, seasonal, and residual components
trend = decomposition.trend
seasonal = decomposition.seasonal
residual = decomposition.resid
# Plot signals in separate subplots with component names
plot_signals_subplots(t, original_signal_with_wear_and_noise, [trend, seasonal, residual], 'Signal with Wearing and Noise, and Decomposed Components')

哇!在我们原始的噪声信号中,我们找不到任何东西。然而,在将其分解成部分后,我们发现了趋势组件,现在已经清晰可见。这是一个信号。
接下来做什么?
在下一篇文章中,我们将探讨为什么检查平稳性很重要,什么是单位根,以及进行特征工程。跟随我,不要错过哦 😃
理解预测性维护 — 单位根和稳态
原文:
towardsdatascience.com/understanding-predictive-maintenance-unit-roots-and-stationarity-f05322f7b6df
·发表于Towards Data Science ·阅读时间 13 分钟·2023 年 11 月 13 日
--

文章目的
在这篇文章中,我们将深入探讨单位根和稳态的关键概念。请做好准备,我们将探讨为何检查稳态至关重要,什么是单位根,以及这些元素如何在我们的预测维护工具箱中发挥重要作用。我们还将掌握混沌理论!
这篇文章是“理解预测性维护”系列的一部分。我计划以类似风格创建整个系列。
点击此链接查看整个系列。请关注我,以确保不会错过新的文章。
数据稳态 — 捉迷藏的分析游戏

是否曾经想过你的数据是否在玩捉迷藏?直接切入主题——我们在谈论稳态。它不仅仅是一个花哨的术语,它是理解你的时间依赖数据究竟有多稳定和可预测的秘密武器。请做好准备,我们将探讨为何数据稳态在建模和预测中是游戏规则的改变者。
稳态的关键规则
-
常数均值:一个平稳的时间序列应表现出一致的平均值。如果均值发生变化,可能表明过程的基本行为发生了变化。
-
常数方差:时间序列的方差,代表数据点的分布,应该保持不变。方差的波动会使得准确预测变得困难。
-
常数自相关:自相关测量时间序列与其滞后值之间的相关性。在平稳序列中,自相关的强度和模式应该保持一致。
仅仅是统计性质的“稳定性”。
为什么平稳性如此重要
想象一下你的预测模型就像是专家导航员在数据的海洋中航行。为了顺利导航,它们更喜欢平静的水域——这就是平稳性的作用。平稳数据就像是一片宁静的海洋,模式保持一致。但是,如果你的数据是一片风雨交加、波涛汹涌的海洋(非平稳),准确预测就会变得非常困难。这就是为什么我们需要发现这些风暴,并将数据转变成一个平静的池塘,以便进行有效的时间序列分析。
现实世界的影响
数据的平稳性不仅仅是技术问题;它无处不在,影响着从金融到天气预测的决策。在金融领域,精确性对于风险和回报的估计至关重要,假设平稳性就像拥有一个可靠的指南针。气候科学家依赖平稳模型来预测长期天气模式——这就像拥有一个可靠的天气应用程序来预测地球的未来。
迈向深刻分析的旅程
使我们的数据平稳不仅仅是一个技术任务;它是迈向清晰的冒险。这就像将混乱的宝藏地图转变成清晰的指南,帮助分析师和决策者理解一切。在动态的时间依赖数据世界中,平稳性成为我们可靠的地图,引导我们理解表面下的模式,使我们在数据的旅程中更加顺利。
好了,现在我们了解了为什么平静的数据很酷,接下来让我们学习如何让它平静下来。但是等等,在我们动手写代码之前,让我给你介绍一下“单位根”。把它们当作影响我们数据行为的特殊成分。了解单位根就像拥有一个秘密配方,将我们波动的、混乱的数据变成一个平滑的池塘,准备好让我们潜入其中探索。所以,准备好迎接我们旅程的下一部分吧!
单位根——数据历史书中的顽皮时间旅行者

Andy Beales 拍摄于 Unsplash
单位根是时间序列分析中的基本概念,在理解现实世界数据的行为和特征中发挥着关键作用。在这次探讨中,我们将深入了解单位根是什么,它们在真实数据分析中为何重要,以及它们如何影响预测性维护领域。当然,我们会在动手实验部分做一些实验。
什么是单位根?
时间序列变量中的单位根意味着一个随机过程,其中变量在任何给定时间的值受其过去值的影响。形式上,单位根表明非平稳性,表明该变量不会随着时间的推移回到一个恒定的均值。

单位根的数学解释
单位根的存在将持久性引入时间序列,导致建模和预测中的挑战。增强型迪基-富勒(ADF)测试及其他统计方法用于检测单位根的存在,提供了非平稳性的定量度量。
单位根就像数据的讲述者,编织出超越个体时刻的叙事,创造出连续的故事情节。它们标志着历史影响的持续性,将记忆的元素引入我们数据集的数字结构中。
想象你的数据集是一部历史小说,每个数据点代表着正在展开的故事中的一章。在这个背景下,单位根就是那些反复出现的主题和角色,它们在叙事中留下了不可磨灭的印记,以微妙而一致的影响引导情节发展。
为什么这对我们很重要?
理解单位根对时间序列分析师和建模师来说是基础性的。非平稳数据带来了挑战,因为传统模型通常假设数据是平稳的,以便进行准确的预测。分析师必须通过采用如差分等变换来处理单位根,以诱导平稳性并促进模型开发。
在预测维护场景中,单位根在确保预测模型准确性方面发挥着至关重要的作用。嵌入单位根的长期影响可能显著影响预测的可靠性,因此其识别和缓解对于有效的维护策略至关重要。
在我们进行这次技术探索时,我们将深入探讨单位根测试方法,解读结果,并探索处理非平稳时间序列数据的策略。单位根的理论基础为我们在分析旅程中的实际应用提供了坚实的基础。
增强型迪基-富勒(ADF)帮助我们

想象你有一行蚂蚁朝某个方向移动。ADF 测试检查这些蚂蚁是否有目的地行进(平稳)或是随机散布在各处(非平稳)。
ADF 测试涉及一些数学运算,但我们可以简化它:
-
零假设 (
*H0*):这就像默认假设。ADF 的零假设是数据有单位根,这意味着它是非平稳的。就像说蚂蚁在随机游走。H0: 数据具有单位根(非平稳)
-
替代假设 (
*H1*): 这是我们试图证明的假设。替代假设是数据是平稳的,就像蚂蚁在清晰的直线上行走一样。H1: 数据是平稳的
-
测试统计量: ADF 测试计算一个称为测试统计量的数字。如果这个数字很小,表明数据可能是平稳的。
p 值: 这是一个概率分数。如果 p 值 很小(低于某个阈值,例如 0.05),我们就拒绝零假设并接受替代假设,认为我们的数据可能是平稳的。
这并不复杂,只需运行测试并检查 p 值
from statsmodels.tsa.stattools import adfuller
# Perform the Augmented Dickey-Fuller (ADF) test for stationarity
adf_statistic, adf_p_value, adf_lags,
adf_nobs, adf_critical_values, adf_reg_results = adfuller(stationary_series)
# Check if the series is stationary based on the p-value
is_stationary = adf_p_value < 0.05 # Using a significance level of 0.05
你通常会像这样使用 adf:
# What youy will probably will use most of the time
_, adf_p_value, _, _, _, _= adfuller(stationary_series)
但我会解释这些变量背后的含义
-
adf_statistic: ADF 测试中的统计量,指示了针对非平稳性零假设的证据强度。 -
adf_p_value: 与零假设相关的 p 值。较低的 p 值表示更强的反对非平稳性的证据。 -
adf_lags: 测试中使用的滞后数。 -
adf_nobs: ADF 测试中使用的观察数量。 -
adf_critical_values: 在不同显著性水平下的测试统计量的临界值。 -
adf_reg_results: 回归结果,提供了关于测试中执行的线性回归的额外信息。
虽然混沌可能看起来令人生畏,但我们可以通过理解和利用其模式将其转变为我们的盟友。在数据和分析的领域,混沌可以是一种强大的力量,当正确引导时,提供洞察、预测和更清晰的前进道路。这一切都是关于将不可预测性转变为优势,让混沌成为我们探索和理解旅程中的战略伙伴。
你的随机性有多“随机”?
让我们先生成一个简单的平稳序列,但这里有个提醒:“随机”并非全都相同。随机性主要有两种——真正的随机和伪随机。你可能更常接触伪随机,因为这是计算机的首选。
在计算中,生成真正的随机数是一个挑战,因为计算机是确定性的机器。伪随机数,顾名思义,并非真正随机,而是由模拟随机性的算法生成。这些算法以一个称为种子的初始值开始,并使用它生成看起来随机的数字序列。
种子
种子是伪随机数生成中的一个关键元素。它作为算法的起点。如果你使用相同的种子,每次都会得到相同的伪随机数序列。这种确定性在你希望重现的场景中是有利的。例如,如果你运行一个涉及随机性的模拟或实验,设置种子可以让你重现确切的随机数序列。
另一方面,更改种子会导致不同的伪随机数序列。这一特性常用于引入模拟中的变异性或为使用随机性的算法提供不同的初始条件。
总之,伪随机数是由算法生成的,而种子是这些算法的起点。控制种子可以让你控制伪随机数的序列,在计算机生成的随机性中提供确定性与变异性之间的平衡。
生成我们的伪随机分布的时间到了。
import numpy as np
import matplotlib.pyplot as plt
np.random.seed(1992) # WOW this is our deterministic seed.
def generate_stationary_series_pseudorandom(size=100):
stationary_series = np.random.randn(size)
return stationary_series
我们可以使用真正的随机性吗?

现在我们可能会惊讶于即使是我们大多数时间所接触的随机性也是确定性随机。但是我们能否创造真正的随机性,确保没有确定性在其背后?
好消息!我们可以利用真正的物理现象——大气噪声。还记得你电视屏幕上的那些闪烁的黑白点吗?那就是我们的大气噪声,我们将利用它来产生真正的随机性。所以,你的电视不仅仅是用来看节目的;它是你摆脱确定性世界的门票。
import requests
def generate_stationary_series_random(size=100):
# Fetch truly random values from random.org atmospheric noise API
response = requests.get(f'https://www.random.org/integers/?num={size}&min=-10000&max=10000&col=1&base=10&format=plain&rnd=new')
if response.status_code == 200:
stationary_series = [int(value) for value in response.text.strip().split('\n')]
return stationary_series
else:
raise Exception(f"Failed to fetch random values. Status code: {response.status_code}")
使用这个函数,我们可以生成真正的随机性,万岁!
平稳性检验
首先,让我们生成序列。
# Generate series
stationary_series_pseudorandom = generate_stationary_series_pseudorandom()
stationary_series_random = generate_stationary_series_pseudorandom()
创建结果图。
titles = [
'stationary_series_pseudorandom',
'stationary_series_random'
]
plot_multiple_series(stationary_series_random, stationary_series_pseudorandom,
titles=titles)

嗯,“壮观”
_, adf_p_value, _, _, _, _= adfuller(stationary_series_pseudorandom)
print(f'PseudoRandom adf p-value: {adf_p_value}')
_, adf_p_value, _, _, _, _= adfuller(stationary_series_random)
print(f'TrueRandom adf p-value: {adf_p_value}')

结果
当 p 值非常小(<0.05)时,它提供了反对零假设的证据,表明你的数据很可能是平稳的。
所以,在这种情况下,p 值远小于 0.05,你可以有信心地说,“是的,我们的数据是平稳的。”
现在,让我们花一点时间来处理这些数据。我们的伪随机数的 p 值大约比真正的随机数小 200 万倍。
为什么会发生这种情况?伪随机数是由算法生成的,这引入了一定程度的确定性。这些算法可能会无意中在数据中引入模式或结构。另一方面,真正的随机数据,如大气噪声,更有可能表现出纯随机性的特征。ADF 检验专注于检测指示非平稳性的模式,可能会在真正的随机数据中发现较少的此类模式,从而导致相对较高的 p 值。
实践经验

图片由 Eddie Kopp 提供,来源于 Unsplash
现在是时候通过代码来动手了。我们将进行一些实验,帮助你熟悉文章中的概念。我建议你复现这些实验。在我们深入探讨平稳性之前,我想问你一个问题。

照片由 Ian Barsby 拍摄,来自 Unsplash
现在我们将添加几个示例,展示如何使这些数据变为非平稳数据,我们将打破平稳性的关键规则。解释完毕后,我们将绘制所有图表。
线性趋势(非恒定均值)
def generate_non_stationary_linear_trend(size=100):
time = np.arange(size)
linear_trend = 0.5 * time
non_stationary_series = np.random.randn(size) + linear_trend
return non_stationary_series
引入线性趋势以违反恒定均值规则意味着在时间上添加系统性的增加或减少。在非平稳的线性趋势序列中,值随着时间线性增加。这违反了恒定均值规则,因为序列的平均值在变化,表明过程的基本行为发生了变化。在这种情况下,单位根有助于线性趋势的持续性,导致变量在任何给定时间的值受到其过去值的影响。
正弦幅度 (非恒定方差)
def generate_non_stationary_sin_amplitude(size=100):
time = np.arange(size)
amplitude = 0.5 + 0.02 * time
sin_amplitude_component = amplitude * np.sin(2 * np.pi * time / 10)
non_stationary_series = np.random.randn(size) + sin_amplitude_component
return non_stationary_series
添加一个幅度逐渐增加的正弦组件违反了恒定方差规则。在非平稳的季节性组件序列中,正弦组件的幅度随时间线性增长。这导致数据点的分布波动,使得方差变得不恒定。单位根有助于季节性组件的持续性,影响方差随着幅度的变化而变化。
指数增长 (非恒定自相关)
def generate_non_stationary_exponential_growth(size=100, growth_rate=0.05):
time = np.arange(size)
exponential_growth_component = np.exp(growth_rate * time)
non_stationary_series = np.random.randn(size) + exponential_growth_component
return non_stationary_series
引入指数增长模式违反了恒定自相关规则。非平稳的扩展幅度序列表现出指数增长,导致自相关模式随着值的增加而变化。单位根在时间序列中引入了持续性,导致建模和预测中的挑战。单位根的存在意味着非平稳性,表明变量随时间不会恢复到恒定均值。
开始实验
执行代码生成时间序列并绘制结果。
# Example usage
stationary_series_pseudorandom = generate_stationary_series_pseudorandom()
non_stationary_linear_trend_series = generate_non_stationary_linear_trend()
non_stationary_sin_amplitude_series = generate_non_stationary_sin_amplitude()
non_stationary_exponential_growth_series = generate_non_stationary_exponential_growth()
# Visualize the examples
plot_multiple_series(stationary_series_pseudorandom,
non_stationary_linear_trend_series,
non_stationary_sin_amplitude_series,
non_stationary_exponential_growth_series,
titles=[
'Stationary series',
'Linear Trend (Non-Constant Mean)',
'Sinusoidal Amplitude (Non-Constant Variance)',
'Exponential Growth (Non-Constant Autocorrelation)'
])

在探索性数据分析中,发现线性趋势或指数增长相对简单,因为这些模式展示了明显的视觉线索。然而,在处理正弦幅度时,区分平稳状态和非平稳状态变得具有挑战性。仅通过查看数据,很难判断幅度是平稳还是非平稳。

这个案例将展示统计测试的力量。我们手中有强大的工具。
_, adf_p_value_stationary, _, _, _, _ = adfuller(stationary_series_pseudorandom)
_, adf_p_value_linear_trend, _, _, _, _ = adfuller(generate_non_stationary_linear_trend())
_, adf_p_value_sin_amplitude, _, _, _, _ = adfuller(generate_non_stationary_sin_amplitude())
_, adf_p_value_exponential_growth, _, _, _, _ = adfuller(generate_non_stationary_exponential_growth())
# Print the results
print(f'PseudoRandom ADF P-value (Stationary Series): {adf_p_value_stationary}')
print(f'PseudoRandom ADF P-value (Linear Trend): {adf_p_value_linear_trend}')
print(f'PseudoRandom ADF P-value (Sinusoidal Amplitude): {adf_p_value_sin_amplitude}')
print(f'PseudoRandom ADF P-value (Exponential Growth): {adf_p_value_exponential_growth}')

测试结果表明,只有平稳序列在 ADF 测试中是平稳的。
ADF 测试在平稳和非平稳时间序列之间提供了明确的区分。在第一种情况下,我们可以自信地拒绝原假设,表明时间序列是平稳的。然而,对于其他情况,我们必须接受原假设,得出数据是非平稳的结论。具体来说,在正弦幅度的情况下,尽管非平稳性在视觉上显而易见,但 ADF 测试通过不允许我们拒绝原假设来确认我们的观察结果。
实践变换
现在,让我们玩一玩变换,尝试将我们的非平稳时间序列转变为平稳序列——就像做一点反向工程。在现实场景中,确定所需的确切变换通常是一个试错过程。我建议进行探索性数据分析,绘制时间序列,并进行经验尝试。如果某个变换使序列平稳,你不仅达到了平稳性,还可以获得有关数据特征的宝贵见解。
def make_linear_trend_stationary(series):
# Subtract the linear trend to make the mean constant.
time = np.arange(len(series))
linear_trend = 0.5 * time # Somehow we have found this trend :)
stationary_series = series - linear_trend
return stationary_series
def make_sin_amplitude_stationary(series):
# Apply differencing to stabilize and make the variance constant.
diff_series = np.diff(series)
return diff_series
def make_exponential_growth_stationary(series, epsilon=1e-8):
# Add a small constant to avoid zero or negative values
series = np.where(series <= 0, epsilon, series)
# Add a small constant to avoid non-finite values
series += epsilon
# Apply the log for stabilization
series = np.log(series)
# Take the first difference to remove the exponential growth
stationary_series = np.diff(series)
return stationary_series
在定义了我们的变换函数后,是时候将它们付诸实践了。让我们将这些变换应用于我们的非平稳时间序列,看看是否能够成功引入平稳性。
# Apply transformations to make non-stationary examples stationary
stationary_linear_trend = make_linear_trend_stationary(generate_non_stationary_linear_trend())
stationary_sin_amplitude = make_sin_amplitude_stationary(generate_non_stationary_sin_amplitude())
stationary_exponential_growth = make_exponential_growth_stationary(generate_non_stationary_exponential_growth())
# Perform ADF test for the transformed series
adf_p_value_stationary_linear_trend = adfuller(stationary_linear_trend)[1]
adf_p_value_stationary_sin_amplitude = adfuller(stationary_sin_amplitude)[1]
adf_p_value_stationary_exponential_growth = adfuller(stationary_exponential_growth)[1]
# Print the results
print(f'ADF P-value (Stationary Linear Trend): {adf_p_value_stationary_linear_trend}')
print(f'ADF P-value (Stationary Sinusoidal Amplitude): {adf_p_value_stationary_sin_amplitude}')
print(f'ADF P-value (Stationary Exponential Growth): {adf_p_value_stationary_exponential_growth}')

现在我的数据平稳了,太棒了!
数据的样子如下:

好消息!由于我们的数据现在是平稳的,我们可以在每种情况下自信地拒绝原假设。现在,为了增加一点趣味,我将接受用给定种子反向工程你的随机生成迭代的挑战。让我们看看我能否揭开这个谜团!😄
在这个链接中查看整个系列。通过关注我,确保你不会错过新文章。
理解预测性维护——波数据:特征工程(第一部分)
开始学习波数据信号处理所需的所有信息
·发布于 Towards Data Science ·阅读时间 16 分钟·2023 年 11 月 21 日
--

照片由 Lukas Tennie 提供,发布在 Unsplash 上。
文章目的
我们即将深入探讨一些有趣的内容——波数据信号处理。这在预测性维护中非常重要,但在其他领域也是如此。我将在这一系列中逐步讲解,使其易于理解。如果你有任何想法,请随时分享!
本文是《理解预测性维护》系列的一部分。
查看整个系列的链接。通过关注我,确保你不会错过新的文章。所有没有说明文字的图片均由我创作。
为什么我们的分析在时域中进行很重要?
信号处理中的时域分析是一种方法,重点关注信号在时间上的行为和特征。与频域分析不同,后者探讨信号成分的频率内容,时域分析提供了信号在不同时间间隔内变化的洞察。这种方法使我们能够观察信号表现出的变化、模式和趋势,提供了关于系统或过程的动态和时间方面的宝贵信息。
为什么在预测性维护中它如此重要?

照片由 James Lewis 提供,发布在 Unsplash 上。
通过将这一分析技术应用于设备数据,维护专业人员可以识别和分析机械性能中的时间模式。监测变化有助于及早发现异常或偏离预期行为,从而及时干预以解决潜在问题,防止其升级。这种前瞻性的维护方法提高了设备的可靠性,减少了停机时间,并最终有助于更加经济高效的运营过程。
理解信号的时间域特征使得行业能够超越被动维护实践。通过时间域分析获得的见解支持预测性维护,使组织能够根据设备的实际状态安排维护活动,而不是基于任意的时间间隔。这不仅优化了资源利用,还延长了机械的使用寿命,从而带来显著的成本节省和改进的整体运营性能。
振动数据——预测性维护的核心
理解振动数据在预测性维护中至关重要,原因有几个。首先,异常振动通常是机械故障的早期指示。通过持续监测和分析振动数据,维护团队可以在故障发生之前检测到异常。其次,振动分析提供了潜在问题的具体性质的见解,从而允许进行有针对性和及时的干预。最后,通过利用振动数据,预测性维护策略可以摆脱基于时间的例行维护,转向更高效的基于状态的维护方法,从而优化设备性能并减少停机时间。
特征工程理论
在实践部分,我将为每个特征提供代码示例,并附上解释以便于实际应用。让我们探索特征工程背后的理论及其在振动数据分析中的应用。
时间域特征
在这一类别中,我们为每个振动信号计算统计度量,如均值、标准差、偏度和峰度。此外,我们还深入探讨诸如均方根(RMS)和峰值因子等指标,以提供信号能量和峰值特征的整体度量。
-
分布统计度量计算每个振动信号的统计度量,如均值、标准差、偏度和峰度。 -
RMS(均方根)提供信号整体能量的度量。 -
峰值因子峰值与RMS值的比率。
频域特征
转换到频域时,我们采用如快速傅里叶变换 (FFT)的技术来转换时域信号。提取的特征包括主频率、谱熵和谱峭度。功率谱密度 (PSD)提供了关于功率分布和谐波关系的见解。
-
FFT (快速傅里叶变换)将时域信号转换为频域。从结果频谱中提取特征,如主频率、谱熵和谱峭度。 -
功率谱密度 (PSD)描述信号的功率如何在频率上分布。
时频特征
探索时频域涉及技术,如小波变换和短时傅里叶变换 (STFT),提供信号的动态表示,并捕捉频率内容随时间的变化。
-
小波变换提供信号的时频表示。从小波系数中提取特征。 -
短时傅里叶变换 (STFT)表示信号频率内容如何随时间变化。
包络分析
解调技术,如Hilbert 变换或小波变换,用于提取信号包络。分析包络内的特征增加了另一层理解。
解调使用Hilbert 变换或小波变换提取信号的包络。分析包络的特征。
窗口上的统计测量
滚动统计,通过固定大小的窗口计算,允许捕捉趋势和模式。此外,高阶统计矩在窗口上,称为波形矩,提供了宝贵的见解。
-
滚动统计在固定大小的窗口上计算统计量,捕捉趋势和模式。 -
波形矩高阶统计矩在窗口上。
重现图
深入研究重现图并利用重现定量分析 (RQA)可以辨别数据结构中的模式,为振动信号提供独特的视角。
重现定量分析 (RQA)分析重现图的结构以捕捉数据中的模式。
特定领域特征
特定领域特征,如峰值特征和形状特征,旨在识别和分析振动信号中的峰值和整体波形形状。
-
峰值特征识别和分析振动信号中的峰值。 -
形状特征提取与信号波形形状相关的特征。
尽管这些示例并未涵盖所有可能性,但其中一些可能对您的需求有用。 😃
实践经验

图片由Amauri Mejía提供,来自Unsplash
现在是时候通过代码亲自操作了。我们将进行一些实验,帮助你熟悉文章中的概念。我建议你复现这些实验。
为实验创建信号
我们需要模拟振动信号并增加更多现实感,以复现设备磨损
def generate_vibration_signal(duration, sampling_rate, frequency, amplitude, noise_level, max_wear, wear_threshold):
t = np.linspace(0, duration, int(sampling_rate * duration), endpoint=False)
# Generate a sinusoidal signal
signal = amplitude * np.sin(2 * np.pi * frequency * t)
# Add random noise to simulate real-world conditions
noise = np.random.normal(0, noise_level, signal.shape)
signal_with_noise = signal + noise
# Simulate equipment wear
wear = np.linspace(0, max_wear, len(t))
wear[wear > wear_threshold] = 0 # Reset wear if it exceeds the threshold
signal_with_wear = signal_with_noise + wear
return t, signal_with_wear
在这段代码中,磨损在达到特定值后会重置——模拟设备更换
让我们生成信号并绘图
# Parameters
duration = 20 # seconds
sampling_rate = 20 # Hz
frequency = 5 # Hz (vibration frequency)
amplitude = 1.0 # Min Max range
noise_level = 0.3 # Noise factor to increase reality
max_wear = 1 # Maximum wear before reset
wear_threshold = 0.5 # Wear threshold for reset
# Generate synthetic vibration signal with wear and threshold
time, vibration_signal = generate_vibration_signal(duration, sampling_rate, frequency, amplitude, noise_level, max_wear, wear_threshold)
# Plot the signal
plt.plot(time, vibration_signal)
plt.title('Synthetic Vibration Signal with Equipment Wear')
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.legend()
plt.show()

带磨损效应的噪声信号用于下一步实验
本文的任务是向你介绍这些酷炫的功能。我们在这里不会构建整个过程——今天没有管道给我们。那是另一个文章的故事!现在,让我们深入探索从信号中创建特征的有趣世界。准备好迎接特征盛宴了吗?出发吧!🚀
窗口化还是不窗口化,这是个问题。

由我使用 DiscoDiffusion 模型生成
时间序列窗口化就像是在连续时间线中查看快照,这非常有用,特别是在预测维护中。想象一下你在看一部电影,但不是看完整部电影,而是每隔几分钟暂停一次并拍一张照片。这些照片就是你的“窗口”。这些窗口帮助我们理解事物如何随时间变化。在机器和设备的世界中,了解它们过去的行为有助于我们预测未来可能发生的情况。
使用这些窗口的一个大优点是它们使理解发生了什么变得更容易。这就像将一个大故事拆分成更小的章节。每个窗口就是一个章节,通过查看它们,我们可以发现该时间段内发生的任何奇怪或有趣的事情。这种详细的视角帮助我们找出机器可能出现磨损或故障的原因。此外,这些窗口帮助我们处理信息获取频率的变化,并处理数据中的任何异常,确保我们的预测是可靠的。
但当然,这并非全是阳光和彩虹。选择这些窗口的正确大小有点棘手。如果它们太大或太小,我们可能会错过重要细节或添加不必要的噪声。这就像为相机选择正确的镜头——你想捕捉到恰到好处的量。此外,决定这些窗口是否应该重叠也是一个难题。重叠的窗口提供更多的上下文,但重叠过多可能会使数据变得重复。这就像试图平衡在书籍每一章中包含多少背景故事。找到这个甜蜜点对确保我们关于机器维护的预测准确无误至关重要。
窗口化示例
df_windowed = pd.DataFrame({'time': time, 'vibration_signal': vibration_signal})
# Make some experiments
window_size = int(2)
# Apply mean windowing using the 'rolling' function
df_windowed['mean_amplitude'] = df_windowed['vibration_signal'].rolling(window=window_size, min_periods=1).mean()
# Plot the original signal and the mean windowed signal
plt.plot(df_windowed['time'], df_windowed['vibration_signal'], label='Original Signal')
plt.plot(df_windowed['time'], df_windowed['mean_amplitude'], label=f'Mean Window ({window_size} samples)')
plt.title('Synthetic Vibration Signal with Mean Windowing')
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.legend()
plt.show()
我打算举办一个不同大小的窗口派对,向你展示它们如何改变事物。这就像 Mean Windowing 是我们的酷 DJ,在窗口范围内旋转平均值。让我们看看数据舞池如何跟随不同窗口大小的节拍!

平均窗口大小 = 2
窗口大小为 2 时,很难看到任何清晰的模式;这引入了太多噪音。我们需要增加窗口大小,以更好地了解数据中发生的情况。

平均窗口大小 = 200
现在,窗口太大了,这不好,因为我们丢失了很多数据细节。我们需要一个合适的窗口,以捕捉所有重要信息。

平均窗口大小 = 20
窗口大小为 20 时,数据模式变得非常明显,这让我们能够识别在信号生成过程中引入的合成“磨损效应”。在训练模型时,进行试错以找出最佳窗口大小是至关重要的。本文中,我将使用窗口大小 20 来生成特征。
时间域特征
分布统计度量

分布统计(Latex 编译)
让我们玩一下分布测量。当然,我可能会跳过mean和standard deviation,因为它们很明显,但我尝试让这个文章有点趣味,以便深入解释每个内容。
均值
想象一下你和你的朋友们在吃披萨派对上。每个人都喜欢各种配料的披萨。这个mean就像是计算每个人披萨上意大利辣肠片的平均数量。如果一个朋友有很多,而另一个朋友只有几个,平均值帮助你知道每个人的意大利辣肠片数量。就像找到披萨的和谐!
标准差
现在,让我们谈谈standard deviation。想象一群猫。一些猫非常放松和懒惰,而另一些猫则非常活跃和好动。这个standard deviation就像是测量每只猫的能量水平如何偏离或不同于所有猫的平均能量水平。如果标准差很高,那你就有一群既懒又活跃的猫。如果标准差很低,大多数猫的能量水平相似——也许是一个悠闲的猫派对!
偏度
让我们用水果篮的场景来更清楚地理解positive和negative skewness之间的区别。
-
Positive Skewness (向右偏斜)想象你的朋友们正在填充一个水果篮。大多数朋友决定添加各种水果,但有几个朋友特别热衷,添加了额外的香蕉、苹果和橙子。由于这种额外的水果热情,水果篮的秋千向右倾斜。这就是positive skewness,表示右侧有更多的兴奋。 -
Negative Skewness (Light to the Left)现在,假设几个朋友决定保持轻松,只在篮子里添加几颗葡萄和浆果。这种轻盈的水果方式使跷跷板向左倾斜。这就是negative skewness,表示向左方向的轻微偏斜。
峰度
现在,想象你在坐过山车。一些过山车很疯狂,充满了曲折,而其他的则比较温和。Kurtosis 是我们过山车的评论员,评估车程的刺激程度。Positive kurtosis 意味着过山车有急转弯和意外的回旋,而 negative kurtosis 表示平稳、更温和的骑行。Kurtosis 就是我们统计主题公园的刺激因子!
偏度比较
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import skew, kurtosis
# Set a random seed for reproducibility
np.random.seed(1992)
# Generate synthetic datasets with varying skewness and kurtosis
# Normal distribution
normal_data = np.random.normal(loc=170, scale=5, size=1000)
# Positively skewed distribution
skewed_data = np.random.gamma(shape=2, scale=5, size=1000)
# Negatively skewed distribution
negative_skewed_data = -np.random.gamma(shape=2, scale=5, size=1000)
# Calculate mean and median for each dataset
normal_mean, normal_median = np.mean(normal_data), np.median(normal_data)
skewed_mean, skewed_median = np.mean(skewed_data), np.median(skewed_data)
negative_skewed_mean, negative_skewed_median = np.mean(negative_skewed_data), np.median(negative_skewed_data)
# Plot the distributions
plt.figure(figsize=(12, 6))
plt.subplot(1, 3, 1)
plt.hist(normal_data, bins=30, color='blue', alpha=0.7)
plt.axvline(x=normal_mean, color='red', linestyle='--', label=f'Mean: {normal_mean:.2f}')
plt.axvline(x=normal_median, color='green', linestyle='--', label=f'Median: {normal_median:.2f}')
plt.legend()
plt.title('Normal Distribution')
plt.subplot(1, 3, 2)
plt.hist(skewed_data, bins=30, color='orange', alpha=0.7)
plt.axvline(x=skewed_mean, color='red', linestyle='--', label=f'Mean: {skewed_mean:.2f}')
plt.axvline(x=skewed_median, color='green', linestyle='--', label=f'Median: {skewed_median:.2f}')
plt.legend()
plt.title('Positively Skewed Distribution')
plt.subplot(1, 3, 3)
plt.hist(negative_skewed_data, bins=30, color='green', alpha=0.7)
plt.axvline(x=negative_skewed_mean, color='red', linestyle='--', label=f'Mean: {negative_skewed_mean:.2f}')
plt.axvline(x=negative_skewed_median, color='green', linestyle='--', label=f'Median: {negative_skewed_median:.2f}')
plt.legend()
plt.title('Negatively Skewed Distribution')
plt.tight_layout()
plt.show()

偏度比较图与均值和中位数
Normal 分布的数据在两侧均匀分布,均值与中位数紧密对齐。
Positive / Right-skewed 分布在右侧有较长或较胖的尾部,表示左侧的数据点更多。均值通常大于中位数。
Negative / Left-skewed 分布在左侧有较长或较胖的尾部,表明右侧的数据点更多。均值通常小于中位数。
峰度比较
# Leptokurtic distribution (heavier tails)
heavy_tails_data = np.random.exponential(scale=10, size=1000)
# Platykurtic distribution (lighter tails)
light_tails_data = np.random.uniform(low=160, high=180, size=1000)
# Calculate mean and median for each dataset
normal_mean, normal_median = np.mean(normal_data), np.median(normal_data)
heavy_tails_mean, heavy_tails_median = np.mean(heavy_tails_data), np.median(heavy_tails_data)
light_tails_mean, light_tails_median = np.mean(light_tails_data), np.median(light_tails_data)
# Plot the distributions with mean and median
plt.figure(figsize=(12, 6))
plt.subplot(1, 3, 1)
plt.hist(normal_data, bins=30, color='blue', alpha=0.7)
plt.axvline(x=normal_mean, color='red', linestyle='--', label=f'Mean: {normal_mean:.2f}')
plt.axvline(x=normal_median, color='green', linestyle='--', label=f'Median: {normal_median:.2f}')
plt.legend()
plt.title('Normal (esokurtic) Distribution')
plt.subplot(1, 3, 2)
plt.hist(heavy_tails_data, bins=30, color='red', alpha=0.7)
plt.axvline(x=heavy_tails_mean, color='red', linestyle='--', label=f'Mean: {heavy_tails_mean:.2f}')
plt.axvline(x=heavy_tails_median, color='green', linestyle='--', label=f'Median: {heavy_tails_median:.2f}')
plt.legend()
plt.title('Leptokurtic Distribution (Heavier Tails)')
plt.subplot(1, 3, 3)
plt.hist(light_tails_data, bins=30, color='green', alpha=0.7)
plt.axvline(x=light_tails_mean, color='red', linestyle='--', label=f'Mean: {light_tails_mean:.2f}')
plt.axvline(x=light_tails_median, color='green', linestyle='--', label=f'Median: {light_tails_median:.2f}')
plt.legend()
plt.title('Platykurtic Distribution (Lighter Tails)')
plt.tight_layout()
plt.show()

峰度图比较
现在,让我们计算统计数据进行比较。
# Calculate skewness and kurtosis for each dataset
normal_skewness = skew(normal_data)
normal_kurtosis = kurtosis(normal_data)
skewed_skewness = skew(skewed_data)
skewed_kurtosis = kurtosis(skewed_data)
negative_skewness = skew(negative_skewed_data)
negative_kurtosis = kurtosis(negative_skewed_data)
heavy_tails_skewness = skew(heavy_tails_data)
heavy_tails_kurtosis = kurtosis(heavy_tails_data)
light_tails_skewness = skew(light_tails_data)
light_tails_kurtosis = kurtosis(light_tails_data)
# Print the calculated values
print("Normal Distribution:")
print(f"Skewness: {normal_skewness}, Kurtosis: {normal_kurtosis}\n")
print("Positively Skewed Distribution:")
print(f"Skewness: {skewed_skewness}, Kurtosis: {skewed_kurtosis}\n")
print("Negatively Skewed Distribution:")
print(f"Skewness: {negative_skewness}, Kurtosis: {negative_kurtosis}\n")
print("Leptokurtic Distribution (Heavier Tails):")
print(f"Skewness: {heavy_tails_skewness}, Kurtosis: {heavy_tails_kurtosis}\n")
print("Platykurtic Distribution (Lighter Tails):")
print(f"Skewness: {light_tails_skewness}, Kurtosis: {light_tails_kurtosis}\n")

计算的统计结果输出
正态分布
-
Skewness-0.0237(略微负偏) -
Kurtosis0.1356(平峰度,比正常分布更平坦)
正偏分布
-
Skewness1.3753(强烈正偏) -
Kurtosis2.7358(峰态分布,尾部比正常分布更重)
负偏分布
-
Skewness-1.3357(强烈负偏) -
Kurtosis2.4060(峰态分布,尾部比正常分布更重)
峰态分布(更重的尾部)
-
Skewness1.8463(正偏) -
Kurtosis4.4461(高峰度,非常重的尾部)
平峰度分布(更轻的尾部)
-
Skewness-0.0243(略微负偏) -
Kurtosis-1.1587(平峰度,尾部比正常分布轻得多)
使用滚动窗口应用统计

滚动窗口解释(Latex 编译)
def calculate_rolling_statistics(signal, window_size):
df = pd.DataFrame({'signal': signal})
rolling_stats = df['signal'].rolling(window=window_size, min_periods=1)
mean_values = rolling_stats.mean()
std_dev_values = rolling_stats.std()
skewness_values = rolling_stats.apply(skew, raw=True)
kurtosis_values = rolling_stats.apply(kurtosis, raw=True)
return mean_values, std_dev_values, skewness_values, kurtosis_values
window_size = 20
# Calculate rolling statistics
rolling_means, rolling_std_devs, rolling_skewness, rolling_kurtosis = calculate_rolling_statistics(vibration_signal, window_size)
我会将结果绘制出来,为了更好的可视化,我会分开 skewness 和 kurtosis
# Plot the signal and rolling statistics
plt.figure(figsize=(12, 6))
# Plot Rolling Mean, Rolling Mean + Std Dev, Rolling Mean - Std Dev
plt.subplot(2, 1, 1)
plt.plot(time[:len(rolling_means)], vibration_signal[:len(rolling_means)], label='Vibration Signal')
plt.plot(time[:len(rolling_means)], rolling_means, label='Rolling Mean')
plt.plot(time[:len(rolling_means)], rolling_means + rolling_std_devs, label='Rolling Mean + Std Dev', linestyle='--')
plt.plot(time[:len(rolling_means)], rolling_means - rolling_std_devs, label='Rolling Mean - Std Dev', linestyle='--')
plt.title(f'Synthetic Vibration Signal with Rolling Mean and Standard Deviation (Window Size = {window_size})')
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.legend()
# Plot Rolling Skewness and Rolling Kurtosis
plt.subplot(2, 1, 2)
plt.plot(time[:len(rolling_means)], vibration_signal[:len(rolling_means)], label='Vibration Signal')
plt.plot(time[:len(rolling_means)], rolling_skewness, label='Rolling Skewness', linestyle='--')
plt.plot(time[:len(rolling_means)], rolling_kurtosis, label='Rolling Kurtosis', linestyle='--')
plt.title(f'Synthetic Vibration Signal with Rolling Skewness and Kurtosis (Window Size = {window_size})')
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.legend()
plt.tight_layout()
plt.show()

滚动应用与原始信号
观察结果可以明显看出,滚动窗口作为一种有效的去噪技术。我建议尝试不同的窗口大小,因为数据科学通常涉及通过试验和错误进行实证探索。
幕后 **.apply()** 的故事
这里有一点额外的内容,展示函数和窗口如何在后台一起工作。我会仅仅做一次,以给你一个关于当应用函数在处理窗口时如何工作的直观了解,使用我们的简单示例。
def calculate_rolling_statistics_behind_scenes(signal, window_size):
mean_values = np.convolve(signal, np.ones(window_size)/window_size, mode='valid')
std_dev_values = np.array([np.std(signal[i-window_size+1:i+1]) for i in range(window_size-1, len(signal))])
skewness_values = np.array([skew(signal[i-window_size+1:i+1]) for i in range(window_size-1, len(signal))])
kurtosis_values = np.array([kurtosis(signal[i-window_size+1:i+1]) for i in range(window_size-1, len(signal))])
return mean_values, std_dev_values, skewness_values, kurtosis_values
-
mean_values这是通过使用np.convolve函数计算的,该函数执行卷积操作。在这种情况下,它通过用一个窗口与信号进行卷积来计算滚动平均值。mode='valid'参数确保卷积仅在完整窗口可以适合而无需零填充的地方进行。 -
std_dev_values这是通过使用列表推导式遍历信号来计算的。对于信号中的每个位置i,它计算子数组signal[i-window_size+1:i+1]的标准差。这代表了滚动标准差。 -
skewness_values类似于标准差,它是通过使用列表推导式遍历信号来计算的。对于每个位置i,它计算子数组signal[i-window_size+1:i+1]的偏度。 -
kurtosis_values再次,类似于标准差和偏度,它是通过使用列表推导式遍历信号来计算的。对于每个位置i,它计算子数组signal[i-window_size+1:i+1]的峰度。
RMS(均方根)

RMS 方程(Latex 编译)
Root Mean Square (RMS) 就像是数学中的超级英雄。它接收一组数字,对每一个数字进行平方,计算平均值,然后取平方根。这个过程会给出一个单一的数字,代表原始数字的典型大小或强度。它在各个领域中都是一个方便的工具,从测量机械振动到评估信号强度。
def calculate_rolling_rms(signal, window_size):
df = pd.DataFrame({'signal': signal})
rolling_stats = df['signal'].rolling(window=window_size, min_periods=1)
rms_values = np.sqrt(rolling_stats.apply(lambda x: np.mean(x**2), raw=True))
return rms_values
window_size = 20
rolling_rms = calculate_rolling_rms(vibration_signal, window_size)
plt.plot(time[:len(rolling_rms)], vibration_signal[:len(rolling_rms)], label='Vibration Signal')
plt.plot(time[:len(rolling_rms)], rolling_rms, label='Rolling RMS', linestyle='--')
plt.title(f'Synthetic Vibration Signal with Rolling RMS (Window Size = {window_size})')
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.legend()
plt.tight_layout()
plt.show()

RMS 揭示了一个单一而强大的指标,捕捉了潜在的信号功率,同时有效地减少了噪声。这个图表作为一个视觉证明,展示了 RMS 在各种应用中提高信号清晰度和精度的实际效果。
峰值因子

峰值因子方程(Latex 编译)
Crest Factor (CF) 就像信号分析世界中的副手 Root Mean Square (RMS)。虽然 RMS 给出了整体强度,Crest Factor 则突出显示了峰值。它是最高点与 RMS 值的比率,告诉你信号的尖锐程度或“峰值”程度。可以将 CF 视为帮助你理解数据中尖锐峰值的超级英雄伙伴,无论是在声音波、电子信号还是其他波动测量中。它们一起组成了揭示数据中隐藏秘密的动态组合。
def calculate_crest_factor_and_peak(signal, window_size):
df = pd.DataFrame({'signal': signal})
rolling_stats = df['signal'].rolling(window=window_size, min_periods=1)
peak_values = rolling_stats.apply(lambda x: np.max(np.abs(x)), raw=True)
rms_values = np.sqrt(rolling_stats.apply(lambda x: np.mean(x**2), raw=True))
crest_factor_values = peak_values / rms_values
return crest_factor_values, peak_values
# Calculate rolling Crest Factor and Peak values
rolling_crest_factor, rolling_peak_values = calculate_crest_factor_and_peak(vibration_signal, window_size)
创建图表
# Plot the vibration signal, rolling Crest Factor, and Peak values
plt.plot(time[:len(rolling_crest_factor)], vibration_signal[:len(rolling_crest_factor)], label='Vibration Signal')
plt.plot(time[:len(rolling_crest_factor)], rolling_crest_factor, label='Rolling Crest Factor', linestyle='--')
plt.plot(time[:len(rolling_peak_values)], rolling_peak_values, label='Rolling Peak Values', linestyle='-.')
plt.title(f'Synthetic Vibration Signal with Rolling Crest Factor and Peak Values (Window Size = {window_size})')
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.legend()
plt.tight_layout()
plt.show()

想象一下RMS和CF作为信号分析中的动态二人组。RMS就像整体强度专家,通过计算数字来提供全局视角。它平方、平均和开方,以展示信号的强度。现在,认识一下峰值侦探。它专注于信号中的尖峰部分,指出哪里出现了极高的值。它们一起组成了一个酷炫的团队,帮助你理解数据中的总体强度和尖峰。
这就是第一部分的内容!我们已经深入探讨了理论基础,并通过特征构建示例进行了实践。在系列的下一部分,我将揭示下一组特征的详细信息。敬请期待更多激动人心的见解!
这篇文章是“理解预测性维护”系列的一部分。
查看完整系列链接。确保你通过关注我,不错过新文章。
了解预测性维护——波形数据:特征工程(第二部分)
频谱数据的特征工程
·发布在Towards Data Science ·阅读时间 12 分钟·2023 年 12 月 1 日
--

文章目的
这是关于波形数据特征工程的文章第二部分。我们将专注于频谱特征。你有什么想法要补充吗?请随时分享!
本文是《了解预测性维护》系列的一部分。
查看系列文章的完整列表。通过关注我来确保不错过新文章。所有没有说明的图片均由我创建。
频域特征
过渡到频域,我们使用像快速傅里叶变换(FFT)这样的技术将时域信号转换。提取的特征包括主频率、频谱熵和频谱峭度。功率谱密度(PSD)和谐波比提供了关于功率分布和谐波关系的见解。
-
FFT(快速傅里叶变换)将时域信号转换为频域。从结果频谱中提取特征,如主频率、频谱熵和频谱峭度。 -
功率谱密度(PSD)描述信号的功率如何在频率上分布。
下一篇文章计划
-
小波变换
-
解调
-
循环量化分析(RQA)
为实验创建信号
我将使用与上一部分完全相同的方法:
开始学习波形数据信号处理所需的所有信息
[towardsdatascience.com
让我们生成信号:
# Parameters
duration = 20 # seconds
sampling_rate = 20 # Hz
frequency = 5 # Hz (vibration frequency)
amplitude = 1.0 # Min Max range
noise_level = 0.3 # Noise factor to increase reality
max_wear = 1 # Maximum wear before reset
wear_threshold = 0.5 # Wear threshold for reset
# Generate synthetic vibration signal with wear and threshold
time, vibration_signal = generate_vibration_signal(duration, sampling_rate,
frequency, amplitude, noise_level, max_wear, wear_threshold)
快速傅里叶变换(FFT)和短时傅里叶变换(STFT)

FFT 方程(已编译的 LaTeX)
信号表示
让我们从信号开始,它本质上是一系列数据点,表示信号随时间的变化。这可以是声音波形、一系列数字或任何随时间变化的数据。
离散傅里叶变换(DFT)
FFT 是计算称为离散傅里叶变换(DFT)的更高效方法。DFT 将我们的信号表示为正弦函数的总和,每个函数表示不同的频率分量。这就是魔法发生的地方。
分而治之
FFT 并不是直接计算整个信号的 DFT,而是利用了这样的事实:任何复合信号的 DFT 可以表示为其子部分的 DFT 的组合。它将信号划分为较小的部分,计算每个部分的 DFT,然后将它们组合起来。
蝴蝶操作
FFT 的魔力在于一种称为蝴蝶操作的过程。这就像一种舞蹈动作,其中计算出的频率被配对并以特定的方式组合。这一过程递归进行,直到得到整个信号的最终频率分量。
效率提升
FFT 速度的关键在于其能够显著减少与直接 DFT 方法相比所需的计算次数。通过利用信号中的对称性和模式,FFT 高效地计算频率分量。
代码时间
现在我们可以将理论应用到简单的代码行中:
import numpy as np
# Apply FFT to the signal
fft_result = np.fft.fft(vibration_signal)
# This very important part, let`s investigate it more in depth
frequencies = np.fft.fftfreq(len(fft_result), 1/sampling_rate)
len(fft_result) 这是 FFT 结果的长度,本质上是频域中的点数。FFT 操作将时域信号转换为频域信号,len(fft_result) 给出频率分量的数量。
1/sampling_rate 这是采样率 sampling_rate 的倒数,表示原始时域信号中样本之间的时间间隔。采样率是每秒钟的样本数。
np.fft.fftfreq() 这个函数生成与 FFT 结果对应的频率。它有两个参数,第一个是结果的长度 len(fft_result),第二个是采样间隔 1/sampling_rate。它返回一个频率数组。
但不用担心。使用这两行代码,整个“魔法”就会发生。
# Plot the time-domain signal
plt.subplot(2, 1, 1)
plt.plot(t, vibration_signal)
plt.title('Vibration Signal with Wear')
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
# Plot the frequency-domain signal (FFT)
plt.subplot(2, 1, 2)
plt.plot(frequencies, np.abs(fft_result))
plt.title('Frequency Domain Signal (FFT)')
plt.xlabel('Frequency (Hz)')
plt.ylabel('Amplitude Spectrum')
plt.tight_layout()
plt.show()

振动信号及其 FFT 表示(代码输出)
在图中,我们注意到一个位于 0 的中心信号,两侧有两个镜像对称的信号。这表明我们的信号由一个单一的波组成。在进行实验之前,让我们首先探讨对称性的概念,然后我们可以用各种信号进行实验。
背景 — FFT 对称性
背后发生了什么?我们有一些数学和理论概念。让我们简化一下。
在许多实际场景中,信号由实数组成。在时间域,这些信号可以表示为一系列值。当你对实值信号进行FFT时,得到的频率谱是对称的。
复共轭对 对称性来自于FFT涉及复数这一事实。对于每个正频率分量,存在一个具有相同幅度的负频率分量。这些频率对是彼此的复共轭。
镜像信息 正频率表示信号在一个方向上的振荡信息,而负频率表示相同的信息,但方向相反。FFT捕捉了两个方向,这就是图形看起来对称的原因。
总之,FFT中的对称性是实值信号和复数在FFT上下文中数学属性的结果。
让我们开始下一个实验。
现在我们理解了FFT的基本概念。让我们模拟两个不同参数的信号连接在一起。
我们将生成第二个类似的信号,我们只关注频率和幅度:
# First Signal
frequency = 10
amplitude = 1
#Second Signal
frequency = 20
amplitude = 1
现在让我们创建一个组合信号并绘制图形:
t2, vibration_signal_2 = generate_vibration_signal(duration, sampling_rate,
frequency, amplitude, noise_level, max_wear, wear_threshold)
# Combine the signals just simply add them :)
combined_signal = vibration_signal + vibration_signal_2
plt.plot(t1, combined_signal, label='Signal 1')
plt.title('Combined Signals')
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.show()

组合信号(代码输出)
现在让我们进行FFT并绘制结果:
# Apply FFT to the combined signal
fft_result = np.fft.fft(combined_signal)
frequencies = np.fft.fftfreq(len(fft_result), 1/sampling_rate)
# Plot the frequency-domain signal (FFT)
plt.plot(frequencies, np.abs(fft_result))
plt.title('Frequency Domain Signal (FFT)')
plt.xlabel('Frequency (Hz)')
plt.ylabel('Amplitude Spectrum')
plt.show()

FFT 信号的表示(代码输出)
现在你可以看到我们有相同的幅度高度,但多出了 2 个信号,且它们的偏移量相等。第一个信号 = 10Hz 第二个信号 = 20Hz
X 轴位置根据信号频率进行调整。让我们引入第三个信号 = 100 Hz 并绘制。
frequency = 100
amplitude = 1
t3, vibration_signal_3 = generate_vibration_signal(duration, sampling_rate,
frequency, amplitude, noise_level, max_wear, wear_threshold)
# Combine the signals just simply add them :)
combined_signal = vibration_signal + vibration_signal_2 + vibration_signal_3
frequency = 150 # Just for make offset, now you know how it works
amplitude = 2
t4, vibration_signal_4 = generate_vibration_signal(duration, sampling_rate,
frequency, amplitude, noise_level, max_wear, wear_threshold)
# Combine the signals
combined_signal = vibration_signal + vibration_signal_2 +
vibration_signal_3 +vibration_signal_4
# Apply FFT to the combined signal
fft_result = np.fft.fft(combined_signal)
frequencies = np.fft.fftfreq(len(fft_result), 1/sampling_rate)
plt.plot(frequencies, np.abs(fft_result))
plt.title('Frequency Domain Signal (FFT)')
plt.xlabel('Frequency (Hz)')
plt.ylabel('Amplitude Spectrum')
plt.tight_layout()
plt.show()

组合信号和额外的第三个信号组件(代码输出)
正如我们所见,由于更高的频率值,我们的“新”信号现在偏移了很多。
如果我们添加一个具有不同幅度的第四个信号会发生什么?
制作一个新信号并将其一起绘制。
frequency = 150 # Just for make offset, now you know how it works
amplitude = 2
t4, vibration_signal_4 = generate_vibration_signal(duration, sampling_rate,
frequency, amplitude, noise_level, max_wear, wear_threshold)
# Combine the signals
combined_signal = vibration_signal + vibration_signal_2 +
vibration_signal_3 + vibration_signal_4
frequency = 150 # Just for make offset, now you know how it works
amplitude = 2
t4, vibration_signal_4 = generate_vibration_signal(duration, sampling_rate,
frequency, amplitude, noise_level, max_wear, wear_threshold)
# Combine the signals
combined_signal = vibration_signal + vibration_signal_2 +
vibration_signal_3 + vibration_signal_4
# Apply FFT to the combined signal
fft_result = np.fft.fft(combined_signal)
frequencies = np.fft.fftfreq(len(fft_result), 1/sampling_rate)
plt.plot(frequencies, np.abs(fft_result))
plt.title('Frequency Domain Signal (FFT)')
plt.xlabel('Frequency (Hz)')
plt.ylabel('Amplitude Spectrum')
plt.tight_layout()
plt.show()

组合信号和额外的第四个信号组件(代码输出)
现在你可以看到我们的幅度谱由于幅度值 1 和 2 高了两倍。
功率谱密度(PSD)

功率谱密度方程(Latex 编译)
PSD 就像是对信号进行特殊的快照,了解它在不同频率下的功率。这有助于我们观察音乐或振动频谱中的能量分布。我们有两种主要的 PSD 估计方法,Welch 和 Barlett。
对于实验,我们将创建一个新的组合信号,以便于清晰度和改进可视化(不接近频率以便清晰看到峰值)
frequency = 100
amplitude = 1
t5, vibration_signal_5 = generate_vibration_signal(duration, sampling_rate,
frequency, amplitude, noise_level, max_wear, wear_threshold)
frequency = 200
amplitude = 1
t6, vibration_signal_6 = generate_vibration_signal(duration, sampling_rate,
frequency, amplitude, noise_level, max_wear, wear_threshold)
frequency = 400
amplitude = 3
t7, vibration_signal_7 = generate_vibration_signal(duration, sampling_rate,
frequency, amplitude, noise_level, max_wear, wear_threshold)
combined_signal2 = vibration_signal_5 + vibration_signal_6 + vibration_signal_7
Welch 方法

PSD Welch 方法方程(Latex 编译)
def welch_method(signal, segment_size=128, overlap=64):
f, Pxx = plt.psd(signal, NFFT=segment_size, Fs=sampling_rate, noverlap=overlap)
plt.title('Welch Method')
plt.xlabel('Frequency (Hz)')
plt.ylabel('Power/Frequency (dB)')
plt.show()
return f, Pxx
freq_welch, P_welch = welch_method(combined_signal2)

Welch 方法将信号分成重叠的段并对周期图进行平均。这种方法在准确性和计算复杂性之间提供了权衡。然而,它可能牺牲频率分辨率以改善方差特性。
Bartlett 方法

PSD Bartlett 估计方程(Latex 编译)
Bartlett 方法 是 Welch 方法的一个特例,没有段间重叠。虽然提供了简便性和减少了计算负担,但它与 Welch 方法在准确性和频率分辨率之间存在类似的权衡。
def bartlett_method(signal, segment_size=128):
f, Pxx = plt.psd(signal, NFFT=segment_size, Fs=sampling_rate, window=np.bartlett(segment_size))
plt.title('Bartlett Method')
plt.xlabel('Frequency (Hz)')
plt.ylabel('Power/Frequency (dB)')
plt.show()
return f, Pxx
freq_bartlett, P_bartlett = bartlett_method(combined_signal2)

关于差异的图形说明

短时傅里叶变换(STFT)— 窗函数 + FFT

短时傅里叶变换方程(Latex 编译)
为什么它在预测性维护中有用?
STFT 分析使工程师能够仔细检查信号的频率成分,例如机械的振动或声发射。与各种故障或异常相关的独特频率模式允许早期检测潜在问题,有助于减少停机时间和维护成本。
STFT 在预测性维护中的主要优势之一是它能及早识别故障特征。这尤其重要,因为 STFT 可以检测低振幅的振动或信号中的微妙变化,提供一个早期警告系统,使从业人员能够在问题升级之前解决它们。
此外,STFT 在区分正常和异常频率模式中发挥了关键作用。通过比较健康设备和故障设备的基线频谱,工程师可以有效识别偏差并预警即将出现的问题,为有效的预测性维护策略奠定基础。
让我们写一些代码
from scipy.signal import spectrogram
# Apply Short-Time Fourier Transform (STFT) to the combined signal
frequencies, times, Sxx = spectrogram(combined_signal, fs=sampling_rate,
nperseg=256, noverlap=128)
哇,仅仅“一行”代码!Python 的威力。在这个函数中,我们需要处理两个具有有趣名称的参数 nperseg 和 noverlap。
nperseg (每个段的点数).
这个参数决定了每个时间窗口或段的大小。较大的 nperseg 结果提供更好的频率分辨率,但时间分辨率较差,而较小的 nperseg 结果提供更好的时间分辨率,但频率分辨率较差。换句话说,它影响时间和频率分辨率之间的权衡。在这个例子中,nperseg=256 意味着每个时间窗口长 256 个点。
noverlap (重叠点的数量).
这个参数控制连续时间窗口之间的重叠。重叠的窗口有助于捕捉信号随时间的动态变化。如果 noverlap 设置为小于 nperseg 的值,窗口将会重叠。在这个例子中,noverlap=128 意味着每个时间窗口与前一个窗口重叠 128 个点。
让我们绘制它
plt.pcolormesh(times, frequencies, 10 * np.log10(Sxx), shading='auto')
plt.title('Spectrogram - Combined Signal')
plt.xlabel('Time (s)')
plt.ylabel('Frequency (Hz)')
plt.show()

STFT 结果的 combined_signal 的频谱图(代码输出)
我们必须使用 Sxx,它基本上是频谱图输出(Sxx[i, j])。我们可以很容易地发现我们的 4 个信号组合在一起。带子越粗,幅度越高。
Spectogram Grid (Sxx) 想象一个像图纸一样的大网格。
Grid Rows (up and down) 网格中的每一行表示不同的声音频率,比如高音或低音。较高的行可能表示高音,较低的行可能表示低音。
Grid Columns (left to right) 网格中的每一列表示不同的时间点,就像快照一样。向右移动时,你可以看到声音如何随时间变化。
Colors in the Grid 网格中的颜色告诉你每个频率在每个时刻的响度或强度。明亮的颜色可能意味着响声,暗淡的颜色可能意味着安静的声音。
如果我们的一个信号要强得多,会发生什么?
让我们在预测性维护的背景下重新框定解释,并模拟一种情况,在这种情况下,我们将第四个信号的幅度显著增加,类似于机器中的潜在故障或异常。
amplitude = 20 # Increase it
t4, vibration_signal_4 = generate_vibration_signal(duration, sampling_rate,
frequency, amplitude, noise_level, max_wear, wear_threshold)
# Combine the signals
combined_signal = vibration_signal + vibration_signal_2 +
vibration_signal_3 + vibration_signal_4
# Apply FFT to the combined signal
fft_result = np.fft.fft(combined_signal)
frequencies = np.fft.fftfreq(len(fft_result), 1/sampling_rate)
plt.plot(frequencies, np.abs(fft_result))
plt.title('Frequency Domain Signal (FFT)')
plt.xlabel('Frequency (Hz)')
plt.ylabel('Amplitude Spectrum')
plt.tight_layout()
plt.show()

FFT 结果的 combined_signal,具有显著更高幅度的组成部分(代码输出)
现在,观察我们的新信号,它显著占据主导地位。这并不意味着其他信号消失了,而是它们的存在被显著信号的幅度所掩盖,使它们看起来更像是噪声。
from scipy.signal import spectrogram
# Apply Short-Time Fourier Transform (STFT) to the combined signal
frequencies, times, Sxx = spectrogram(combined_signal, fs=sampling_rate, nperseg=256, noverlap=128)
plt.pcolormesh(times, frequencies, 10 * np.log10(Sxx), shading='auto')
plt.title('Spectrogram - Combined Signal')
plt.xlabel('Time (s)')
plt.ylabel('Frequency (Hz)')
plt.show()

STFT 结果的 combined_signal 频谱图,具有显著更高幅度的组成部分(代码输出)
在频谱图结果中,出现了一条显著而强烈的线条,指示未来机械故障的潜在前兆。在正常操作条件下,我们的机械设备会产生背景谱。然而,随着磨损的进展或某个部件即将故障,这种正常状态的偏差变得显而易见。为了更深入地研究这一点,我们可能会探索训练卷积神经网络(CNN)来分析这些频谱模式,并识别与即将发生的故障相关的特征。特征工程相关文章之后,我将开始建模系列。
这是波数据特征工程第二部分的结束。在下一篇文章中,我将进行讨论。
-
小波变换
-
解调
-
递归量化分析(RQA)
本文是《理解预测性维护》系列的一部分。由于您提供的宝贵反馈和建议,我计划将其编撰成书。如果有任何值得扩展或包含的内容,请告诉我。我正在考虑所有您的反馈。
查看完整系列请点击此链接。请关注我,以确保您不会错过新文章。
使用 Gradio 理解保留率
如何利用网络应用程序进行分析
·
关注 发表在 Towards Data Science ·15 min 阅读·2023 年 10 月 21 日
--
图片来源:DALL-E 3
我记得我第一次构建网络应用程序的那一刻。大约在八年前,那时我还是一个相当初级的分析师,并且确信 BI 工具可以解决所有问题。
工程团队构建了一个新 SDK 的原型,并希望了解它是否能更好地收集数据。他们在一组设备上进行测试,查看数据并将其与旧版本进行比较。然而,这组设备不断变化,因此在 BI 工具中保持数据的更新需要相当多的工作。因此,我决定构建一个网络应用程序。
我找了一套文章(如果我没记错的话,大约十篇或十一篇),阅读了它们并尝试将这些知识应用于我的任务。我花了大约一周的时间完成了第一个原型。我必须同时编写后端和前端,因此现在我可以认为自己至少是一个初级全栈开发者。后端我使用了 Flask(我很幸运没有碰到 Django,否则我可能要花整整一个月),前端则使用了 Bootstrap 和 Leaflet。
总体而言,这是一个具有挑战性的任务,需要付出很多努力来提升工程技能。我相信深入了解与你的主要专业领域相关的其他领域总是值得的。
然而,我很高兴的是,如今有许多工具允许分析师和数据科学家在不到一个小时的时间内构建原型。在许多情况下,这些原型可以将你的分析提升到一个新的水平。以下是一些例子:
-
根据输入参数(如营销预算或我们将在何处推出新功能)预测收入和受众,
-
可以加快团队工作速度或减少临时工作量的工具,如 A/B 测试计算器或自动根本原因分析,
-
MVP 解决方案,例如,如果你想使用 LLMs 来自动化一些内部流程,值得在花时间开发生产版本之前先测试一个原型。我在之前的一篇文章中分享了一个这样的 ML 原型,“在一个小时内构建你的第一个深度学习应用”。
在这篇文章中,我想介绍一个这样的框架,它可以帮助你快速且几乎不费力地创建外观美观的网页应用程序,而无需烦恼于 JavaScript 和 CSS。我们将学习 Gradio 的基础知识,开发几个网页应用程序,并将它们发布到 HuggingFace Spaces,以便任何人都可以访问它们。
Gradio 并不是唯一一个这样的框架。还有其他一些开源 Python 替代品:
-
Streamlit 是另一个流行且强大的库,用于用少量代码构建数据应用程序。它也得到了 HuggingFace Spaces 的支持,以便你可以托管这些应用程序。
-
Dash 如果你已经习惯了 Plotly,它可能会很方便,并且提供更多的定制能力。
你可以在 这篇文章 中找到有关不同框架主要功能的更多细节。
Gradio 基础知识
Gradio 是一个开源的 Python 库,用于构建交互式应用程序。
Gradio 的主要优点是:
-
你可以仅使用 Python 来构建应用程序,这也意味着你可以在应用程序中使用所有 Python 库,
-
你可以在 Jupyter Notebook 中运行它或作为一个单独的网页,
-
你可以在 HuggingFace spaces 上永久托管 Gradio 应用。
没有灵丹妙药,所以 Gradio 有其局限性:
-
它是专为机器学习应用设计的。因此,如果你将它用于其他用途,可能需要更改默认设置(例如,使用
allow_flagging= "never"关闭标记功能)。 -
自定义有限,尤其是在设计方面。
-
我会记住 Gradio 主要是一个用于快速原型开发的框架。它大多工作良好,但偶尔会遇到一些奇怪的行为。例如,Safari 中的表格编辑行为反直觉,或者有时需要重新启动 Jupyter Notebook 以加载界面。
要开始使用 Gradio,我们需要安装 Python 包。
pip install gradio
遵循老程序员的传统,让我们从“Hello, World!”开始。
我们可以使用 gr.Interface 类来定义界面(文档)。这是 Gradio 的核心类之一,帮助你基于任何 Python 函数创建 Web 应用。
我们需要指定以下参数:
-
inputs: 界面的输入组件(在我们的例子中,只是一个文本框), -
outputs: 界面的输出组件(在我们的例子中,也只是一个文本框), -
fn: 核心功能(一个获取输入并返回输出的函数,在我们的例子中,从输入中获取名称并返回“Hello,!”), -
title&description: 一些 Markdown 让我们的应用更加用户友好。
import gradio as gr
demo = gr.Interface(
inputs=[gr.Textbox(label="Name", lines=1)],
outputs=[gr.Textbox(label="Result", lines=1)],
fn=lambda x: 'Hello, %s!' % x,
title="Hello, World!",
description="Your first app using Gradio",
allow_flagging='never')
demo.launch()
你可以在你的 Jupyter Notebook 中运行这段代码,查看结果。这对于调试非常方便。稍后,我们将讨论如何让你的 Web 应用对他人可用。

作者提供的图片
就这样:只需几行代码,你的第一个 Gradio 应用就运行了。我还必须指出,它看起来相当不错,我们没有使用任何前端魔法。
当你在 Jupyter Notebook 中工作时,Gradio 会在后台启动许多进程,因此不时使用
*gr.close_all()*关闭连接是值得的。
我们查看了最基本的示例,了解了 Gradio 的构建模块:输入、输出和函数。现在,我们准备开始真实的分析任务。
增长模拟
作为第一个示例,我们将研究保留对产品用户增长的影响。
保留作为增长的基础
两个参数定义了产品的增长:
-
acquisition(每个周期的新用户数量),
-
retention(保留用户的能力)。
让我们建模用户基础如何根据保留曲线增长。
我们可以使用以下函数和一组参数(a、b、c 和 d)来描述任何保留曲线:

让我们谈谈留存的最常见情况:队列由产品中的第一次操作定义,所有操作都计入留存。在这种情况下,periods = 0的留存必须等于 1(因为队列入口和留存事件是相同的)。所以,我们可以自动定义其中一个参数:

增长的主要因素是长期留存。它决定了客户是否长期使用产品,以及你的产品是否可持续增长,或者客户在一个月内流失,你是否需要不断获取更多的新用户来实现增长。在我们的公式中,a参数负责长期留存。

我们可以使用这个公式来定义留存曲线。所以我们已经拥有了继续开发所需的一切。
可视化留存图表
让我们从简单的开始,制作一个接受留存曲线参数并以图表形式显示关系的应用程序。
类似于我们的“Hello, World”示例,我们需要使用gr.Interface类并传递inputs、outputs和fn来映射它们。
-
我们现在需要更多的输入参数。因此,
inputs将是一个控件列表。我们将使用gr.Slider和gr.Dropdown控件。对于gr.Slider,我们需要传递最小值、最大值、默认值和一个标签,我们将在函数中使用。
对于gr.Dropdown,我们需要定义一个可能值的列表、默认值和标签。
-
我们仍然只有一个输出——一个图表,因此
outputs将是gr.Plot,没有任何参数。 -
函数
fn将把输入映射到输出,因此它将获取输入参数并返回plotly.Figure对象进行可视化。
import plotly.express as px
# functions to calculate retention
def get_retention(a, b, c, d, periods):
return a + 1./(b + c * (periods ** d))
def get_retention_same_event(a, c, d, periods):
b = 1./(1 - a)
return get_retention(a, b, c, d, periods)
# define function - return plot depending on input parameters
def get_retention_plot(a, c, d, num_periods):
df = pd.DataFrame({'x': range(num_periods + 1)})
df['retention'] = df.x.map(lambda x: get_retention_same_event(a, c, d, x))
return px.line(df, x = 'x', y = 'retention',
color_discrete_sequence = px.colors.qualitative.Prism,
title = 'Retention curve', labels = {'x': 'period'})
# define inputs
inputs = [
gr.Slider(0, 1, 0.03, label="a"),
gr.Slider(0, 5, 0.55, label="c"),
gr.Slider(0, 5, 1.5, label="d"),
gr.Dropdown([10, 30, 60, 90], value = 30, label="Number of Periods"),
gr.Dropdown([10, 100, 1000, 10000], value = 10000, label="Number of new users each period")
]
# define outputs
outputs = gr.Plot()
# define interface
demo = gr.Interface(
fn=get_retention_plot,
inputs=inputs,
outputs=outputs,
cache_examples=True,
allow_flagging = 'never' # hiding default flag functionality in the interface
)
# launch
demo.launch(debug = True)
让我们尝试运行这个应用程序。它在工作——我们可以看到一个图表,提交新参数时图表会发生变化。

添加更多图表
我们的目标是查看留存对增长的影响,因此我们需要添加不仅展示留存,还展示随时间推移的观众图表。让我们改变一下界面。
为了简化起见,我们假设在每个周期内,相同数量的新用户开始使用我们的产品(cohort_size参数)。
我们只需对实现做几个小改动:
-
更改
get_retention_plot函数,以便它接受一个关于队列大小的新参数,计算随时间变化的用户数量,并返回三个图形。 -
参数
outputs现在等于三个gr.Plot()对象的列表。
def get_retention_plot(a, c, d, num_periods, cohort_size):
ret_df = pd.DataFrame({'x': range(num_periods + 1)})
ret_df['retention'] = ret_df.x.map(lambda x: get_retention_same_event(a, c, d, x))
ret_fig = px.line(ret_df.iloc[1:], x = 'x', y = 'retention',
color_discrete_sequence = px.colors.qualitative.Prism,
title = 'Retention curve')
# simulation
tmp_data = []
for cohort in range(num_periods + 1):
for cohort_period in range(num_periods + 1):
period = cohort_period + cohort
if period > num_periods:
continue
retention = get_retention_same_event(a, c, d, cohort_period)
tmp_data.append(
{
'cohort': 'cohort %s' % str(cohort).rjust(3, '0'),
'cohort_period': cohort_period,
'period': period,
'retention': retention,
'users': int(round(retention * cohort_size))
}
)
users_df = pd.DataFrame(tmp_data)
users_fig = px.area(users_df.groupby('period').users.sum(),
color_discrete_sequence = px.colors.qualitative.Prism,
title = 'Active users')
cohorts_fig = px.area(users_df.pivot_table(index = 'period', columns = 'cohort', values = 'users',
aggfunc = 'sum'),
color_discrete_sequence = px.colors.qualitative.Prism,
title = 'Active users by cohorts')
return ret_fig, users_fig, cohorts_fig
inputs = [
gr.Slider(0, 1, 0.03, label="a"),
gr.Slider(0, 5, 0.55, label="c"),
gr.Slider(0, 5, 1.5, label="d"),
gr.Dropdown([10, 30, 60, 90], value = 30, label="Number of Periods"),
gr.Dropdown([10, 100, 1000, 10000], value = 10000, label="Number of new users each period")
]
outputs = [gr.Plot(), gr.Plot(), gr.Plot()]
demo = gr.Interface(
fn=get_retention_plot,
inputs=inputs,
outputs=outputs,
allow_flagging = 'never',
cache_examples=True,
)
demo.launch(debug = True)
太棒了,现在我们可以看到完整的画面并分析关系。然而,还有改进的空间——我们可以添加格式以使应用程序对用户更方便。

作者提供的图片
添加一些样式
我们可以稍微调整一下界面,使其更加用户友好和直观。
为此,我们将使用 gr.Blocks() 作为上下文。此功能允许您创建更自定义的网页应用程序,并定义布局和数据流(触发函数的事件及其后续执行)。
Blocks 将为我们开启新的机会:
-
使用
gr.Blocks(),我们可以使用gr.Row()和gr.Column()来组织布局。 -
gr.Markdown允许您添加 markdown 元素,例如标题或甚至带有公式的 LaTeX(默认情况下,您需要将其放在 $ 内部)。 -
gr.Accordion可以帮助您隐藏一些默认不想显示给用户的参数。 -
此外,这种方法允许您定义更复杂的更新逻辑。例如,不仅在提交按钮上更新图表,还可以在任何输入参数发生变化时更新图表。我们将在以下示例中使用此功能。
在使用 Blocks 时,我们需要将每个输入和输出定义为变量,例如,a = gr.Slider(0, 1, 0.03, label=”a”)。
此外,没有默认控件,所以我们需要自己定义按钮 —— btn_caption = gr.Button(“Submit”)。
按钮点击的操作也必须指定,设置已经熟悉的参数 —— inputs、outputs 和 fn。
btn_caption.click(fn=get_retention_plot,
inputs=[a, c, d, num_periods, cohort_size],
outputs=[plot1, plot2, plot3])
这里是完整的代码版本。
with gr.Blocks() as demo:
gr.Markdown("# Understanding Growth 🚀")
with gr.Row():
with gr.Column():
gr.Markdown("## Retention curve parameters 📈")
gr.Markdown(r"$\textbf{retention}(\textsf{x}) = \textsf{a} + \frac{\textsf{1}}{\textsf{b} + \textsf{c} * \textsf{x}^{\textsf{d}}}\ where\ \textsf{b} = \frac{\textsf{1}}{\textsf{1}-\textsf{a}}$")
with gr.Row():
a = gr.Slider(0, 1, 0.03, label="a")
c = gr.Slider(0, 5, 0.55, label="c")
d = gr.Slider(0, 5, 1.5, label="d")
with gr.Accordion("More options", open=False):
with gr.Row():
num_periods = gr.Dropdown([10, 30, 60, 90], value = 30, label="Number of Periods")
cohort_size = gr.Dropdown([10, 100, 1000, 10000], value = 10000, label="Number of new users each period")
btn_caption = gr.Button("Submit")
with gr.Column():
plot1 = gr.Plot()
with gr.Row():
plot2 = gr.Plot()
plot3 = gr.Plot()
btn_caption.click(fn=get_retention_plot,
inputs=[a, c, d, num_periods, cohort_size],
outputs=[plot1, plot2, plot3])
demo.launch()
托管您的应用程序
此外,我们可以使用 HuggingFace Spaces 来托管我们的网页应用程序,并轻松地与他人分享。
要开始使用 Spaces,您需要有一个账户。如果还没有注册,请访问此链接。这不会超过几分钟。
下一步是创建一个新的 Space。您可以在文档中找到更详细的说明。

作者提供的图片
对于新的 Space,您必须填写以下参数:名称、许可证和 Gradio 作为您的 SDK。

作者提供的图片
然后,您需要将代码提交到 Hugging Spaces 的 Git 仓库。首先,我们需要克隆仓库。
-- cloning repo
git clone https://huggingface.co/spaces/<your_login>/<your_app_name>
cd <your_app_name>
最近,HuggingFace 更改了Git 认证过程,所以我们需要先创建一个令牌,然后将其设置到 Git 仓库中。
git remote set-url origin https://<your_login>:<token>@huggingface.co/spaces/<your_login>/<your_app_name>
git pull origin
现在,到了提交与我们的应用程序相关的文件的时候。我们至少需要以下文件:
-
包含启动 Gradio 应用程序的 Python 代码的
app.py -
包含您应用程序所需的 Python 包列表的
requirements.txt文件。在我们的情况下,仅需pandas和plotly。
然后,基本的 git 步骤:添加、提交并推送到 HuggingFaces。
git add app.py
git add requirements.txt
git commit -m 'First version of retention simulator app'
git push
构建应用程序花了几分钟,现在完成了。我们的网页应用程序现在已在 HuggingFaces Spaces 上运行。您可以在这里尝试一下。

作者提供的图片
它比我们最初的版本看起来好很多,因为布局不需要滚动,用户也不需要猜测参数a、c和d的含义。
预测保留率
我们已经学会了如何在 Web 应用程序中根据一组参数生成图表。但是在现实生活中,我们通常需要输入大量数据,所以让我们找出如何在应用程序中使用.csv文件中的数据。
作为示例,我们将查看前几个周期的实际保留数据,并尝试预测后续周期的保留率。这是一个相当常见的任务,因为我们通常不希望等三个月才能比较新队列的第三个月保留率。我们将上传实际数据作为.csv文件。
不要浪费时间,我们直接进入实现部分吧。
从文件中获取数据
这是生成整个界面和业务逻辑的代码。它可能看起来有点复杂。不用担心,我们稍后会讨论核心要点。
# parses file or string and returns dataframe
def parse_file(input_text_or_file, num_periods):
if isinstance(input_text_or_file, str):
df = pd.read_csv(StringIO(input_text_or_file), sep = '\t')
else:
df = pd.read_csv(input_text_or_file.name, sep = '\t')
return df
# takes dataframe and returns plot
def show_graph_for_df(df, num_periods):
df['period'] = df.period.map(int)
df['retention_fact'] = df.retention_fact.map(float)
result = scipy.optimize.minimize(lambda x: get_mse_for_retention(x, df), [random.random(), random.random(), random.random()])
a, c, d = result.x
pred_df = pd.DataFrame({'period': range(num_periods + 1)})
pred_df['retention_pred'] = pred_df.period.map(lambda x: get_retention_same_event(a, c, d, x))
pred_df = pred_df.merge(df, how = 'left')
fig = go.Figure()
fig.add_trace(go.Scatter(x=pred_df.period, y=pred_df.retention_fact, name='fact',
line=dict(color=plotly.colors.qualitative.Prism[0], width=3)))
fig.add_trace(go.Scatter(x=pred_df.period, y=pred_df.retention_pred, name='prediction',
line=dict(color=plotly.colors.qualitative.Prism[0], width=3, dash='dot')))
fig.update_layout(title='Daily retention model (a = %.2f, c = %.2f, d = %.2f)' % (a, c, d),
yaxis_title='retention',
xaxis_title='period')
return fig
# takes file and return plot
def show_graph_for_file(temp_file, num_periods):
df = parse_file(temp_file, num_periods)
return show_graph_for_df(df, num_periods)
# hard-coded example of data
default_csv = 'period\tretention_fact\n0\t1\n1\t0.55\n2\t0.4\n3\t0.35\n4\t0.3\n'
# interface
with gr.Blocks() as demo:
gr.Markdown('# Predicting retention curve 📊')
periods = gr.Dropdown([10, 30, 90, 180], label="Number of Periods", value = 30)
gr.Markdown('Upload .csv file with data, use default data as an example or put in numbers manually in the Uploaded data section.')
gr.Markdown('''__File format:__ 2 columns (`period` and `retention_fact`)''')
with gr.Row():
upload_button = gr.UploadButton(label="Upload file", file_types = ['.csv'], live=True, file_count = "single")
default_button = gr.Button('Show example')
with gr.Row():
with gr.Accordion("Uploaded data", open=False):
gr.Markdown('You can change values in the table')
table = gr.Dataframe(type="pandas", col_count=2, interactive = True, headers = ['period', 'retention_fact'])
with gr.Row():
image = gr.Plot()
# business logic of triggers and events
upload_button.upload(fn=show_graph_for_file, inputs=[upload_button, periods], outputs=image, api_name="upload_graph")
upload_button.upload(fn=parse_file, inputs=[upload_button, periods], outputs=table, api_name="upload_csv")
default_button.click(fn=lambda x: show_graph_for_file(default_csv, x), inputs=[periods], outputs=image, api_name="upload_example_graph")
default_button.click(fn=lambda x: parse_file(default_csv, x), inputs=[periods], outputs=table, api_name="upload_example_csv")
table.change(fn=show_graph_for_df, inputs=[table, periods], outputs=image, api_name="upload_table_graph")
periods.change(fn=show_graph_for_df, inputs=[table, periods], outputs=image, api_name="upload_periods_graph")
demo.launch(debug=True)
让我们更详细地看一下。我们在界面中有以下元素:
-
periods——输入参数。 -
upload_button——输入参数,允许你从.csv文件中加载数据。 -
default_button——允许你用预定义的值更新表格和图表作为示例。 -
table显示上传数据的数据框(无论是来自.csv文件还是示例);此外,你还可以在表格中直接更改数字,图表将会更新——所以它也是一个输入参数。 -
image——输出参数,显示一个图表。

作者提供的图像
函数parse_file从upload_button获取文件或从默认示例中获取字符串,并返回一个我们可以进一步使用的pandas数据框。所以,使用文件中的数据是非常简单的。
关键的业务逻辑在下面的代码片段中定义。它定义了所有界面元素的操作:
-
用于上传
.csv文件——表格和图表都会更新。 -
用于点击“显示示例”按钮——表格和图表都会更新。
-
用于更改表格中的数据——仅更新图表。
-
用于更改周期数——仅更新图表。
upload_button.upload(fn=show_graph_for_file, inputs=[upload_button, periods], outputs=image, api_name="upload_graph")
upload_button.upload(fn=parse_file, inputs=[upload_button, periods], outputs=table, api_name="upload_csv")
default_button.click(fn=lambda x: show_graph_for_file(default_csv, x), inputs=[periods], outputs=image, api_name="upload_example_graph")
default_button.click(fn=lambda x: parse_file(default_csv, x), inputs=[periods], outputs=table, api_name="upload_example_csv")
table.change(fn=show_graph_for_df, inputs=[table, periods], outputs=image, api_name="upload_table_graph")
periods.change(fn=show_graph_for_df, inputs=[table, periods], outputs=image, api_name="upload_periods_graph")
定义最佳拟合函数
我们解决问题的关键部分是为实际数据找到最佳拟合函数。让我们看看如何做到这一点。
-
首先,我们定义了一个函数
get_mse_for_retention,它返回一组参数(a、c和d)的误差。它还接受数据框作为输入。 -
我们使用标准均方误差(MSE)作为我们将要最小化的误差。
-
然后,我们将使用
scipy.optimize.minimize函数进行优化。我们只需要传递两个参数:要优化的函数(我们传递了一个带有硬编码数据框的 lambda 函数,因为我们只优化参数)和参数的初始值(只是一个随机值的列表)。 -
优化后,我们可以使用
result.x访问最优参数。
def get_mse_for_retention(params, df):
tmp_df = df.copy()
tmp_df['retention_pred'] = tmp_df.index.map(
lambda x: get_retention_same_event(params[0], params[1], params[2], x)
)
tmp_df['se'] = (tmp_df.retention_fact - tmp_df.retention_pred)
tmp_df['se'] = tmp_df['se']**2
return tmp_df.se.mean() ** 0.5
result = scipy.optimize.minimize(lambda x: get_mse_for_retention(x, df), [random.random(), random.random(), random.random()])
a, c, d = result.x
print(a, c, d)
就这样,现在我们知道了我们实际数据的理论保留曲线,并且可以在应用程序中使用它进行预测。
最后一步
我也按照相同的指示将这个应用程序发布到了 HuggingFace Spaces。所以你可以在这里尝试使用它。
你可以在GitHub上找到这两个应用程序的完整代码。
总结
在这篇文章中,我们已经了解了 Gradio 库的基础知识,并学习了如何仅使用 Python 构建愉快的网页应用程序。
我们学到了几种方法:
-
高级
gr.Interface类,允许你快速获得一个可用的原型。 -
使用
gr.Blocks的更可定制化的方法,你可以指定所需的确切布局,并定义输入和输出之间的复杂关系。
非常感谢你阅读这篇文章。我希望这对你有所启发。如果你有任何后续问题或评论,请在评论区留言。
参考
本文的灵感来源于“使用 Gradio 构建生成式 AI 应用程序”课程。
理解 SQL:入门窗口函数
通过使用 SQL 窗口函数,获取更多的聚合信息
·
关注 发布于 Towards Data Science ·15 分钟阅读·2023 年 9 月 17 日
--
图片由 Components AI 提供,来自 Unsplash
简介
在 SQL 中聚合数据时,窗口函数比与GROUP BY子句结合使用的聚合方法提供了更大的灵活性。虽然这两种方法确实执行类似的功能,但窗口函数的不同在于输出的结构方式。具体来说,窗口函数对一组相关的行进行操作,这些行的关系由表行的某些分组或分区决定。而且,与将行合并为单一输出行的非窗口函数不同,所有行都保持其独立身份,并出现在输出表中。
这种行为与普通的聚合操作大相径庭,能够大大扩展你的分析工具箱,超越简单的汇总统计。例如,窗口函数允许我们计算累计总和、移动平均,甚至像z-得分这样的统计指标。
在这篇文章中,我们将查看 SQL 窗口函数的结构和基本功能。这里的重点有些基础,因此如果你还没有接触过窗口函数,或者对它们的使用经验有限,希望这对你会有一些帮助。
本文将使用关于 1930 年至 2022 年 FIFA 世界杯比赛的一些高级汇总数据。这些排名和统计数据来自维基百科,并在知识共享署名-相同方式共享(CC-BY-SA)许可下提供。数据和相关信息可以在这里找到。为了本博客的目的,我将表格导入到我自己的 PostgresSQL 数据库中,但如果你想跟随,你可以从我的Git 仓库中获取表格的副本。在我的数据库中,这个表格被称为world_cup_placings,下面显示了一个输出示例:
|year|start_date|end_date|host_country |first_place |second_place |third_place |fourth_place|total_teams|matches_played|total_goals|total_attendance|
|----|----------|--------|-------------|------------|--------------|-------------|------------|-----------|--------------|-----------|----------------|
|1930|13/07/30 |30/07/30|Uruguay |Uruguay |Argentina |United States|Yugoslavia |13 |18 |70 |590,549 |
|1934|27/05/34 |10/06/34|Italy |Italy |Czechoslovakia|Germany |Austria |16 |17 |70 |363,000 |
|1938|4/06/38 |19/06/38|France |Italy |Hungary |Brazil |Sweden |15 |18 |84 |374,835 |
|1950|24/06/50 |16/07/50|Brazil |Uruguay |Brazil |Sweden |Spain |13 |22 |88 |1,045,246 |
|1954|16/06/54 |4/07/54 |Switerland |West Germany|Hungary |Austria |Uruguay |16 |26 |140 |768,607 |
|1958|8/06/58 |29/06/58|Sweden |Brazil |Sweden |France |West Germany|16 |35 |126 |819,810 |
|1962|30/05/62 |17/06/62|Chile |Brazil |Czechoslovakia|Chile |Yugoslavia |16 |32 |89 |893,172 |
|1966|11/07/66 |30/07/66|England |England |West Germany |Portugal |Soviet Union|16 |32 |89 |1,563,135 |
|1970|31/05/70 |21/06/70|Mexico |Brazil |Italy |West Germany |Uruguay |16 |32 |95 |1,604,065 |
|1974|13/06/74 |7/07/74 |West Germany |West Germany|Netherlands |Poland |Brazil |16 |38 |97 |1,865,762 |
|1978|1/06/78 |25/06/78|Argentina |Argentina |Netherlands |Brazil |Italy |16 |38 |102 |1,545,791 |
|1982|13/06/82 |11/07/82|Spain |Italy |West Germany |Poland |France |24 |52 |146 |2,109,723 |
|1986|31/05/86 |29/06/86|Mexico |Argentina |West Germany |France |Belgium |24 |52 |132 |2,394,031 |
|1990|8/06/90 |8/07/90 |Italy |West Germany|Argentina |Italy |England |24 |52 |115 |2,516,215 |
|1994|17/06/94 |17/07/94|United States|Brazil |Italy |Sweden |Bulgaria |24 |52 |141 |3,597,042 |
|1998|10/06/98 |12/07/98|France |France |Brazil |Croatia |Netherlands |32 |64 |171 |2,785,100 |
|2002|31/05/02 |30/06/02|Korea / Japan|Brazil |Germany |Turkey |South Korea |32 |64 |161 |2,705,198 |
|2006|9/06/06 |9/07/06 |Germany |Italy |France |Germany |Portugal |32 |64 |147 |3,359,439 |
|2010|11/06/10 |11/07/10|South Africa |Spain |Netherlands |Germany |Uruguay |32 |64 |145 |3,178,856 |
|2014|12/06/14 |13/07/14|Brazil |Germany |Argentina |Netherlands |Brazil |32 |64 |171 |3,429,873 |
|2018|14/06/18 |15/07/18|Russia |France |Croatia |Belgium |England |32 |64 |169 |3,031,768 |
|2022|20/11/22 |18/12/22|Qatar |Argentina |France |Croatia |Morocco |32 |64 |172 |3,404,252 |
关于执行顺序的简要说明
理解 SQL 执行每个子句的顺序是很重要的,所以我们花几分钟时间来检查窗口函数在执行顺序中的位置。
窗口函数只能在 SELECT 列表和 ORDER BY 子句中使用。它们不能与 GROUP BY、HAVING 或 WHERE 子句一起使用。原因在于窗口函数是在这些子句处理之后执行的。另一个需要注意的事项是,窗口函数是在非窗口聚合函数(即 SUM、MAX、AVG 等)处理之后执行的。正如我们稍后将看到的,这很有用,因为这意味着我们实际上可以在窗口函数中使用这些函数。
OVER 子句
首先,让我们来看看窗口函数的最简单版本:
FUNCTION_NAME() OVER()
FUNCTION_NAME()只是一个占位符,用于你希望使用的任何函数;然而,窗口函数必须总是包含OVER子句。这个子句区别了窗口函数和非窗口函数,它的作用是确定如何将行分割以进行处理。在上面的例子中,虽然没有传递任何参数给OVER。这完全合法,所以让我们看看实际效果如何。
窗口函数的一个常见用例是为表中的每一行分配一个数值。这可以通过内置的ROW_NUMBER函数实现。例如,考虑下面的示例查询:
/* Assigning numbers to each row in the table */
SELECT
"year"
,host_country
,first_place
,total_goals
,ROW_NUMBER() OVER() AS row_num
FROM world_cup_placings;
;
|year|host_country |first_place |total_goals|row_num|
|----|-------------|------------|-----------|-------|
|1930|Uruguay |Uruguay |70 |1 |
|1934|Italy |Italy |70 |2 |
|1938|France |Italy |84 |3 |
|1950|Brazil |Uruguay |88 |4 |
|1954|Switerland |West Germany|140 |5 |
|1958|Sweden |Brazil |126 |6 |
|1962|Chile |Brazil |89 |7 |
|1966|England |England |89 |8 |
|1970|Mexico |Brazil |95 |9 |
|1974|West Germany |West Germany|97 |10 |
|1978|Argentina |Argentina |102 |11 |
|1982|Spain |Italy |146 |12 |
|1986|Mexico |Argentina |132 |13 |
|1990|Italy |West Germany|115 |14 |
|1994|United States|Brazil |141 |15 |
|1998|France |France |171 |16 |
|2002|Korea / Japan|Brazil |161 |17 |
|2006|Germany |Italy |147 |18 |
|2010|South Africa |Spain |145 |19 |
|2014|Brazil |Germany |171 |20 |
|2018|Russia |France |169 |21 |
|2022|Qatar |Argentina |172 |22 |
我们现在有一个新创建的列叫做row_num,它包含一个顺序编号列表;每个表行一个。仅使用普通的OVER子句,窗口函数将整个表视为一个单一分区。这是因为我们没有告诉它其他信息。
让我们看看如果用聚合函数,如SUM,替代ROW_NUMBER并应用于total_goals列(即整个比赛期间的总进球数)会发生什么。下面提供了这个查询:
/* Using SUM() within our window function */
SELECT
"year"
,host_country
,first_place
,total_goals
,SUM(total_goals) OVER() AS all_goals
FROM world_cup_placings
;
|year|host_country |first_place |total_goals|all_goals|
|----|-------------|------------|-----------|---------|
|1930|Uruguay |Uruguay |70 |2,720 |
|1934|Italy |Italy |70 |2,720 |
|1938|France |Italy |84 |2,720 |
|1950|Brazil |Uruguay |88 |2,720 |
|1954|Switerland |West Germany|140 |2,720 |
|1958|Sweden |Brazil |126 |2,720 |
|1962|Chile |Brazil |89 |2,720 |
|1966|England |England |89 |2,720 |
|1970|Mexico |Brazil |95 |2,720 |
|1974|West Germany |West Germany|97 |2,720 |
|1978|Argentina |Argentina |102 |2,720 |
|1982|Spain |Italy |146 |2,720 |
|1986|Mexico |Argentina |132 |2,720 |
|1990|Italy |West Germany|115 |2,720 |
|1994|United States|Brazil |141 |2,720 |
|1998|France |France |171 |2,720 |
|2002|Korea / Japan|Brazil |161 |2,720 |
|2006|Germany |Italy |147 |2,720 |
|2010|South Africa |Spain |145 |2,720 |
|2014|Brazil |Germany |171 |2,720 |
|2018|Russia |France |169 |2,720 |
|2022|Qatar |Argentina |172 |2,720 |
好的,我们为每一行只得到一个值——这正是我们应该预期的。记住,窗口函数是在一个分区内应用到所有行的,而这里(如上所述),窗口函数将整个表视为一个分区。因此,它将只对total_goals列中的所有值进行求和。此外,窗口函数保留了行的身份,因此这个输出值会为每一行重复。还要注意,我们在窗口函数中使用了一个聚合函数——这是可能的,因为窗口函数是在聚合函数之后处理的。
好的,让我们看看通过计算所有比赛中的平均进球数(注意输出值的四舍五入)我们可以推进到什么程度:
/* Computing the average number of goals */
SELECT
"year"
,host_country
,first_place
,total_goals
,ROUND(AVG(total_goals) OVER(), 0) AS mean_goals
FROM world_cup_placings
;
|year|host_country |first_place |total_goals|mean_goals|
|----|-------------|------------|-----------|----------|
|1930|Uruguay |Uruguay |70 |124 |
|1934|Italy |Italy |70 |124 |
|1938|France |Italy |84 |124 |
|1950|Brazil |Uruguay |88 |124 |
|1954|Switerland |West Germany|140 |124 |
|1958|Sweden |Brazil |126 |124 |
|1962|Chile |Brazil |89 |124 |
|1966|England |England |89 |124 |
|1970|Mexico |Brazil |95 |124 |
|1974|West Germany |West Germany|97 |124 |
|1978|Argentina |Argentina |102 |124 |
|1982|Spain |Italy |146 |124 |
|1986|Mexico |Argentina |132 |124 |
|1990|Italy |West Germany|115 |124 |
|1994|United States|Brazil |141 |124 |
|1998|France |France |171 |124 |
|2002|Korea / Japan|Brazil |161 |124 |
|2006|Germany |Italy |147 |124 |
|2010|South Africa |Spain |145 |124 |
|2014|Brazil |Germany |171 |124 |
|2018|Russia |France |169 |124 |
|2022|Qatar |Argentina |172 |124 |
现在,这非常有用。将平均进球数与比赛总数一起列出,让我们可以直接比较这些值。我们可以很容易地看到个别总数与所有比赛中的平均数相比如何。
让我们进一步推展,通过计算每一行的z-分数来实现。为此,我们还需要在另一个窗口函数中使用STDDEV函数。下面的查询展示了如何做到这一点:
/* Compute z-score for total goals scored */
SELECT
"year"
,host_country
,first_place
,total_goals
,ROUND(AVG(total_goals) OVER(), 0) AS mean_goals
,ROUND((total_goals - AVG(total_goals) OVER()) /
STDDEV(total_goals) OVER(), 2) AS z_score
FROM world_cup_placings
;
|year|host_country |first_place |total_goals|mean_goals|z_score|
|----|-------------|------------|-----------|----------|-------|
|1930|Uruguay |Uruguay |70 |124 |-1.54 |
|1934|Italy |Italy |70 |124 |-1.54 |
|1938|France |Italy |84 |124 |-1.14 |
|1950|Brazil |Uruguay |88 |124 |-1.02 |
|1954|Switerland |West Germany|140 |124 |0.47 |
|1958|Sweden |Brazil |126 |124 |0.07 |
|1962|Chile |Brazil |89 |124 |-0.99 |
|1966|England |England |89 |124 |-0.99 |
|1970|Mexico |Brazil |95 |124 |-0.82 |
|1974|West Germany |West Germany|97 |124 |-0.76 |
|1978|Argentina |Argentina |102 |124 |-0.62 |
|1982|Spain |Italy |146 |124 |0.64 |
|1986|Mexico |Argentina |132 |124 |0.24 |
|1990|Italy |West Germany|115 |124 |-0.25 |
|1994|United States|Brazil |141 |124 |0.5 |
|1998|France |France |171 |124 |1.36 |
|2002|Korea / Japan|Brazil |161 |124 |1.07 |
|2006|Germany |Italy |147 |124 |0.67 |
|2010|South Africa |Spain |145 |124 |0.61 |
|2014|Brazil |Germany |171 |124 |1.36 |
|2018|Russia |France |169 |124 |1.3 |
|2022|Qatar |Argentina |172 |124 |1.39 |
看起来相当不错!
PARTITION BY 子句
到目前为止,我一直在使用分区这个术语,所以让我们看看这实际上是什么意思。回顾之前的例子,我们没有明确说明希望如何对表进行分区,所以操作是在整个表的所有行上进行的。另一方面,PARTITION BY子句被称为在OVER子句内,它规定了行应该如何分成组或分区。包含这个子句后,窗口函数的结构现在看起来像这样:
FUNCTION_NAME() OVER( PARTITION BY [var] )
占位符[var]指的是用于分组行的列。为了演示,我们再试一次使用ROW_NUMBER函数对行进行编号,这次我们将使用first_place列来对行进行分区。查看下面的查询:
/* Add row numbers to each partition */
SELECT
"year"
,host_country
,first_place
,total_goals
,ROW_NUMBER() OVER( PARTITION BY first_place ) AS row_num
FROM world_cup_placings
;
|year|host_country |first_place |total_goals|row_num|
|----|-------------|------------|-----------|-------|
|1978|Argentina |Argentina |102 |1 |
|1986|Mexico |Argentina |132 |2 |
|2022|Qatar |Argentina |172 |3 |
|1962|Chile |Brazil |89 |1 |
|2002|Korea / Japan|Brazil |161 |2 |
|1994|United States|Brazil |141 |3 |
|1958|Sweden |Brazil |126 |4 |
|1970|Mexico |Brazil |95 |5 |
|1966|England |England |89 |1 |
|1998|France |France |171 |1 |
|2018|Russia |France |169 |2 |
|2014|Brazil |Germany |171 |1 |
|1982|Spain |Italy |146 |1 |
|1934|Italy |Italy |70 |2 |
|1938|France |Italy |84 |3 |
|2006|Germany |Italy |147 |4 |
|2010|South Africa |Spain |145 |1 |
|1930|Uruguay |Uruguay |70 |1 |
|1950|Brazil |Uruguay |88 |2 |
|1974|West Germany |West Germany|97 |1 |
|1954|Switerland |West Germany|140 |2 |
|1990|Italy |West Germany|115 |3 |
好的,情况与上次看起来非常不同。首先,我们可以看到输出已使用first_place列对表进行了排序,但这并不特别有趣。有趣的是row_num列的变化。从顶部开始向下查看,我们可以看到编号序列在first_place的每个不同值上都会重置,只计数与每个国家的分区相关的行。
让我们在查询基础上再做些扩展。除了row_num列,我们还计算总进球数和最大总出席人数。我们将再次使用first_place列来对表进行分区,因此这些聚合将仅适用于与每个分区相关的行。下面的查询展示了如何做到这一点:
/* Adding more window functions to the SELECT list */
SELECT
"year"
,host_country
,first_place
,ROW_NUMBER() OVER( PARTITION BY first_place ) AS row_num
,total_goals
,SUM(total_goals) OVER( PARTITION BY first_place ) AS all_goals
,total_attendance
,MAX(total_attendance) OVER( PARTITION BY first_place ) AS max_attendance
FROM world_cup_placings
;
|year|host_country |first_place |row_num|total_goals|all_goals|total_attendance|max_attednance|
|----|-------------|------------|-------|-----------|---------|----------------|--------------|
|1978|Argentina |Argentina |1 |102 |406 |1,545,791 |3,404,252 |
|1986|Mexico |Argentina |2 |132 |406 |2,394,031 |3,404,252 |
|2022|Qatar |Argentina |3 |172 |406 |3,404,252 |3,404,252 |
|1962|Chile |Brazil |1 |89 |612 |893,172 |3,597,042 |
|2002|Korea / Japan|Brazil |2 |161 |612 |2,705,198 |3,597,042 |
|1994|United States|Brazil |3 |141 |612 |3,597,042 |3,597,042 |
|1958|Sweden |Brazil |4 |126 |612 |819,810 |3,597,042 |
|1970|Mexico |Brazil |5 |95 |612 |1,604,065 |3,597,042 |
|1966|England |England |1 |89 |89 |1,563,135 |1,563,135 |
|1998|France |France |1 |171 |340 |2,785,100 |3,031,768 |
|2018|Russia |France |2 |169 |340 |3,031,768 |3,031,768 |
|2014|Brazil |Germany |1 |171 |171 |3,429,873 |3,429,873 |
|1982|Spain |Italy |1 |146 |447 |2,109,723 |3,359,439 |
|1934|Italy |Italy |2 |70 |447 |363,000 |3,359,439 |
|1938|France |Italy |3 |84 |447 |374,835 |3,359,439 |
|2006|Germany |Italy |4 |147 |447 |3,359,439 |3,359,439 |
|2010|South Africa |Spain |1 |145 |145 |3,178,856 |3,178,856 |
|1930|Uruguay |Uruguay |1 |70 |158 |590,549 |1,045,246 |
|1950|Brazil |Uruguay |2 |88 |158 |1,045,246 |1,045,246 |
|1974|West Germany |West Germany|1 |97 |352 |1,865,762 |2,516,215 |
|1954|Switerland |West Germany|2 |140 |352 |768,607 |2,516,215 |
|1990|Italy |West Germany|3 |115 |352 |2,516,215 |2,516,215 |
现在,这些聚合本身并不是特别有用,它们只是为了演示目的。但让我们看看从这个查询中实际得到的结果。
首先,all_goals列提供了每个获胜国家在所有比赛中进球的总数(回忆一下,我们已使用first_place列对行进行了分区)。例如,在阿根廷获胜的比赛中,我们可以看到在 1978 年阿根廷主办时进了 102 球,1986 年墨西哥主办时进了 132 球,最近在卡塔尔进了 172 球。all_goals值就是这两个值的总和,即 406。
其次,max_attendance值返回每个获胜国家的最高比赛出席人数。例如,在表中列出的世界杯中,意大利赢得了其中四场,其中 2006 年在德国举办的世界杯的最高出席人数(3,359,439)。因此,对于所有意大利获胜的行,这就是max_attendance返回的值。
回顾查询,你可能会注意到我们需要写OVER( PARTITION BY first_place )三次;每个窗口函数在SELECT列表中出现一次。如果你的查询需要多个窗口函数,并且使用相同的列进行分区,这可能会变得有些乏味。那么有没有更好的方法来实现相同的目标?有的。确实有。在这些情况下——所有函数的窗口设置相同——我们可以使用一个单独的WINDOW子句来定义分区,并为其指定一个可以通过OVER调用的名称。查看下面的查询:
/* Same as above, but now using the WINDOW clause */
SELECT
"year"
,host_country
,first_place
,ROW_NUMBER() OVER w AS row_num
,total_goals
,SUM(total_goals) OVER w AS all_goals
,total_attendance
,MAX(total_attendance) OVER w AS max_attendance
FROM world_cup_placings
WINDOW w AS ( PARTITION BY first_place )
;
现在,这并不会消除在OVER子句后提供参数的需要——这是无法避免的——但这种方法出错的可能性要小得多,并且确实有助于清理查询。你可以检查结果是否与上面的输出一致。
ORDER BY子句
另一个可以添加到OVER中的子句是ORDER BY子句。这个子句肯定很熟悉,实际上它的工作方式完全符合你的预期。唯一的不同是,当我们在OVER子句中使用它时,它会影响每个分区内的排序。包含这个子句后,我们的窗口函数现在看起来是这样的:
FUNCTION_NAME() OVER( PARTITION BY [var] ORDER BY [val] )
我们可以使用ORDER BY子句在执行操作之前对值进行排序。如果我们想使用RANK函数对值进行排名,这非常有用,因为RANK函数会为每个不同的ORDER BY值分配一个数值。这意味着它的行为与ROW_NUMBER函数略有不同,因为RANK函数会为重复值分配相同的排名。查看下面的查询并比较行输出:
/* Comparing ROW_NUMBER and RANK using ORDER BY clause */
SELECT
"year"
,host_country
,first_place
,total_goals
,ROW_NUMBER() OVER( ORDER BY total_goals DESC ) AS row_num
,RANK() OVER( ORDER BY total_goals DESC ) AS row_rank
FROM world_cup_placings
;
|year|host_country |first_place |total_goals|row_num|row_rank|
|----|-------------|------------|-----------|-------|--------|
|2022|Qatar |Argentina |172 |1 |1 |
|1998|France |France |171 |2 |2 |
|2014|Brazil |Germany |171 |3 |2 |
|2018|Russia |France |169 |4 |4 |
|2002|Korea / Japan|Brazil |161 |5 |5 |
|2006|Germany |Italy |147 |6 |6 |
|1982|Spain |Italy |146 |7 |7 |
|2010|South Africa |Spain |145 |8 |8 |
|1994|United States|Brazil |141 |9 |9 |
|1954|Switerland |West Germany|140 |10 |10 |
|1986|Mexico |Argentina |132 |11 |11 |
|1958|Sweden |Brazil |126 |12 |12 |
|1990|Italy |West Germany|115 |13 |13 |
|1978|Argentina |Argentina |102 |14 |14 |
|1974|West Germany |West Germany|97 |15 |15 |
|1970|Mexico |Brazil |95 |16 |16 |
|1966|England |England |89 |17 |17 |
|1962|Chile |Brazil |89 |18 |17 |
|1950|Brazil |Uruguay |88 |19 |19 |
|1938|France |Italy |84 |20 |20 |
|1930|Uruguay |Uruguay |70 |21 |21 |
|1934|Italy |Italy |70 |22 |21 |
看到了区别吗?如果我们查看第二行和第三行,我们会看到巴西 2014 和法国 1998 都得到了 171 个进球。虽然row_num列为这些行分配了不同的值,但row_rank列则分配了相同的排名。
需要注意的是,对于ROW_NUMBER窗口函数,我实际上不需要指定ORDER BY total_goals DESC,但我这样做是为了强调ROW_NUMBER窗口函数不关心重复值;它只是为每一行编号,不管是否按total_goals排序。
好的,现在让我们尝试将ORDER BY和PARTITION BY子句一起使用,以找出每个团队最后一次赢得世界杯的年份。首先,我们将使用first_place列来对表进行分区,然后使用year列按降序对行进行排序。以下查询展示了如何实现这一点:
/* Rank winning teams by competition year to find most recent win */
SELECT
"year"
,host_country
,first_place
,RANK() OVER( PARTITION BY first_place ORDER BY "year" DESC ) AS row_rank
FROM world_cup_placings
;
|year|host_country |first_place |row_rank|
|----|-------------|------------|--------|
|2022|Qatar |Argentina |1 |
|1986|Mexico |Argentina |2 |
|1978|Argentina |Argentina |3 |
|2002|Korea / Japan|Brazil |1 |
|1994|United States|Brazil |2 |
|1970|Mexico |Brazil |3 |
|1962|Chile |Brazil |4 |
|1958|Sweden |Brazil |5 |
|1966|England |England |1 |
|2018|Russia |France |1 |
|1998|France |France |2 |
|2014|Brazil |Germany |1 |
|2006|Germany |Italy |1 |
|1982|Spain |Italy |2 |
|1938|France |Italy |3 |
|1934|Italy |Italy |4 |
|2010|South Africa |Spain |1 |
|1950|Brazil |Uruguay |1 |
|1930|Uruguay |Uruguay |2 |
|1990|Italy |West Germany|1 |
|1974|West Germany |West Germany|2 |
|1954|Switerland |West Germany|3 |
通过按降序排列比赛年份,我们确保每个分区的第一行是最新的年份。现在,由于每年只能有一个获胜者,因此year列中没有重复值。因此,当我们为每个分区分配排名时,第一行将始终具有排名 1,这对应于每个国家的最新胜利。
然而,这个输出有点杂乱,我们必须阅读每一行以找出每个分区的排名。我们可以通过将上述查询放在子查询中,并过滤出仅具有排名 1 的行来整理一下。查看下面的查询以了解实际效果:
/* Filter only rows that have a rank of 1 */
SELECT
"year"
,first_place
FROM (
SELECT
"year"
,host_country
,first_place
,RANK() OVER( PARTITION BY first_place ORDER BY "year" DESC ) AS row_rank
FROM world_cup_placings
) AS subq
WHERE row_rank = 1
;
|year|first_place |
|----|------------|
|2022|Argentina |
|2002|Brazil |
|1966|England |
|2018|France |
|2014|Germany |
|2006|Italy |
|2010|Spain |
|1950|Uruguay |
|1990|West Germany|
最终备注
窗口函数非常灵活,当需要仅使用子集行来计算值时,它们通常更加高效。此帖仅触及窗口函数的表面,并仅演示了它们的基本功能。不管怎样,希望你在此帖中找到了一些有用的内容。在后续的帖子中,我们将基于这些概念,探索如何使用窗口函数计算一些更为复杂的度量。
相关文章
感谢阅读!
如果你喜欢这篇文章并希望保持更新,请考虑关注我的 Medium 账号。这将确保你不会错过任何新内容。
若要无限制访问所有内容,请考虑注册Medium 订阅。
你还可以在Twitter、LinkedIn上关注我,或者查看我的GitHub,如果这更符合你的兴趣。
理解 SQL:执行顺序
关于数据库如何解释你的 SQL 查询的简要指南
·
关注 发表在 Towards Data Science ·7 min 阅读·2023 年 4 月 3 日
--
图片由 Wengang Zhai 提供,来源于 Unsplash
介绍
编写高效的 SQL 查询是处理大数据量的任何数据分析师必备的技能。我相信我们中的许多人都经历过这样的痛苦:在小规模数据上运行良好的查询,应用到大规模数据集时却慢得令人抓狂。
通常,通过简单了解数据库是如何解析查询的,可以显著提高查询性能。这不仅有助于优化查询速度和性能,还能帮助调试和解决错误的脚本。
所以今天,我将带你了解 SQL 查询的执行顺序,并讨论构建查询时出现的一些常见错误。
声明式与过程式语言
首先,重要的是要了解 SQL 是一种声明式编程语言。这意味着我们定义所需的结果,但不提供如何实现它的指令。这与命令式或过程式语言形成对比,这些语言要求每一步产生输出的过程都必须明确指定。使用像 SQL 这样的声明式语言的含义是,虽然 SQL 期望语句按照特定顺序编写,但语句的评估顺序会有所不同。
为了演示,以下是构建 SQL 查询时常用的七个子句及其使用顺序的列表:
1\. SELECT
2\. FROM
3\. WHERE
4\. GROUP BY
5\. HAVING
6\. ORDER BY
7\. LIMIT
现在将其与执行顺序进行比较:
1\. FROM
2\. WHERE
3\. GROUP BY
4\. HAVING
5\. SELECT
6\. ORDER BY
7\. LIMIT
如你所见,语句已经有些打乱。例如,注意到虽然 SELECT 子句写在第一位,但在执行时它的位置要低得多。正如我们将很快看到的,执行顺序是最重要的,这是分析师必须特别注意的。
FROM 子句
自然,数据库需要知道数据的来源,因此这是一个逻辑上的第一步。虽然较简单的查询可能只涉及一个表,但通常情况下,所需的信息存在于多个表中。因此,JOIN 语句与 FROM 语句一起使用,以合并源表。如果需要进行联接,则数据库首先会将所有内容汇集在一起。
这意味着你应该考虑源表的大小、使用的联接类型以及联接中使用的谓词数量。例如,通过仅选择必要的列、筛选掉不必要的行,并确保有共同的标识符以完成联接,来减少源表的大小,这将提高效率。此外,INNER JOIN 应优先于 OUTER JOIN,因为前者通常更快。
最终,你不希望处理不必要的数据,因此将处理的数据集最小化应该是主要目标,尽可能做到这一点。
WHERE 子句
该子句用于通过仅返回满足给定条件的行来筛选基本表或联接输出。可以使用任何支持的数据类型来过滤记录。例如,考虑下表,它列出了少量英联邦城市及其人口:

一个名为‘cities’的小示例表(作者图片)。
如果我们要将这个表过滤到只包含新西兰的城市,我们可以写出如下查询:
SELECT
city
,country
FROM
cities
WHERE
country = 'New Zealand';
这将返回包含奥克兰、基督城和惠灵顿的行。或者,如果我们想返回所有人口超过五十万的城市,那么查询将是这样的:
SELECT
city
,country
FROM
cities
WHERE
population > 500000;
我们还可以使用 AND 操作符组合这些过滤条件,这样只会返回奥克兰:
SELECT
city
,country
FROM
cities
WHERE
country = 'New Zealand'
AND population > 500000;
关于 WHERE 子句,有一个重要的事情需要记住,那就是它不能用于过滤聚合列。例如,看看下面修改后的查询:
SELECT
country
,SUM(popualtion)
FROM
cities
WHERE
SUM(popualtion) > 5000000
GROUP BY
country;
上述查询的意图是获取所有城市总人口超过 500 万的国家。不幸的是,这个查询会失败,因为在 WHERE 语句中使用了聚合函数。问题在于聚合函数需要 GROUP BY 子句,而 GROUP BY 子句是在 WHERE 子句之后执行的。这意味着 WHERE 条件无法被评估,因为数据库尚未知道任何聚合变量。
我们将很快看到如何解决这个问题,但在此之前,让我们快速了解一下 GROUP BY 子句。
GROUP BY 子句
正如你可能已经识别的,这个子句允许我们聚合或汇总数量,并与 COUNT()、SUM()、MIN()、MAX() 等函数结合使用。实际上,GROUP BY 将变量或变量组折叠,并为每个不同的元素或元素组合返回一个单一的值。例如,如果我们想统计每个国家的城市人口,我们可以按国家分组,如下所示:
SELECT
country
,SUM(popualtion)
FROM
cities
GROUP BY
country;
输出将返回四行 —— 每个国家一行 —— 以及表中列出的每个国家的聚合人口。
HAVING 子句
这个子句解决了之前尝试使用聚合函数进行过滤时遇到的问题,即 WHERE 子句。HAVING 子句允许我们使用分组和聚合的数据进行结果过滤,因为它是在 GROUP BY 语句之后执行的。数据库现在已经知道了这些聚合,这意味着它们可以在所有后续的语句中使用。我们现在可以像这样修改之前的查询:
SELECT
country
,SUM(population)
FROM
cities
GROUP BY
country
HAVING
SUM(popualtion) > 5000000;
这将只返回两个国家:澳大利亚和英国。
SELECT 子句
SELECT 子句是我们定义所需列的地方,同时还包括任何分组和聚合字段。这也是我们可以使用 AS 操作符应用列别名的地方。现在,虽然 select 语句在构建查询时排在首位,但它不会在数据被源化和过滤后执行。这一点很重要,因为这意味着聚合变量和别名不能在 WHERE、GROUP BY 或 HAVING 语句中使用。
例如,考虑以下查询,它创建了一个列别名total_pop,然后在 HAVING 子句中使用。这个查询会抛出错误,因为别名尚未创建。HAVING 子句位于 SELECT 子句之前,因此没有名为total_pop的引用**。
SELECT
country
,SUM(population) AS total_pop
FROM
cities
GROUP BY
country
HAVING
total_pop > 5000000;
我不会详细讲解这些内容,但 DISTINCT 和 UNION 语句是在 SELECT 之后执行,在 ORDER BY 子句之前执行,其中 DISTINCT 在 UNION 之前执行。
ORDER BY 子句
我们现在接近尾声,许多重要的工作已经完成。我们已经来源(并可能联接)了表,应用了一些过滤,分组并汇总了一些字段,并指定了我们希望在最终表中包含的列。
在这个阶段,你可能会考虑你希望目标表中数据的排列方式。例如,你可能希望按时间顺序排列行,或者基于某些排名值进行排序。这正是 ORDER BY 子句的作用。
这个语句的好处在于,因为它位于排序的后端,我们可以在 GROUP BY 语句中使用聚合和列别名。例如,假设我们想按城市总人口对国家进行排序。我们可以编写如下查询:
SELECT
country
,SUM(population) AS total_pop
FROM
cities
GROUP BY
country
ORDER BY
total_pop;
请注意,我们可以在 ORDER BY 语句中使用列别名total_pop。默认情况下,这将按升序返回记录(即从小到大)。要按降序返回行,我们可以使用 DESC 运算符,如下所示:
SELECT
country
,SUM(population) AS total_pop
FROM
cities
GROUP BY
country
ORDER BY
total_pop DESC;
LIMIT 子句
在处理大型表时,通常不建议让查询返回所有行,特别是在你仅进行开发和测试时。LIMIT 子句在这里非常有用,允许我们定义希望返回的行数。它也可以与 ORDER BY 子句结合使用,以返回前 n或后 n条记录。例如,假设我们想要表中人口最多的前三个城市。我们可以如下使用 ORDER BY 和 LIMIT 子句:
SELECT
city
,country
,population AS city_pop
FROM
cities
ORDER BY
city_pop DESC
LIMIT
3;
请注意,并非所有数据库都支持 LIMIT 语句,但它们会有执行类似功能的等效语句。
总结
语句执行的顺序是构建 SQL 查询时需要掌握的重要概念,我们已经触及了一些常见的陷阱。虽然我没有提供详细的示例,但我希望这个简短的入门介绍能让你思考如何提高查询性能。如果你刚刚开始接触 SQL,希望这篇文章能帮助你在学习的旅程中前进。
感谢阅读!
如果你喜欢这篇文章并希望保持更新,请考虑关注我在 Medium 上的账号。 这将确保你不会错过任何新内容。
要获得所有内容的无限访问权,请考虑订阅Medium。
你还可以在 Twitter、LinkedIn 上关注我,或者查看我的 GitHub,如果你更喜欢这样的话 😉
了解 TF-IDF:NLP 中的一种传统特征提取方法
学习 TF-IDF 的基础知识以及如何在 Python 中从零实现它
·发表于Towards Data Science ·阅读时间 9 分钟·2023 年 3 月 30 日
--

照片由Aaron Burden拍摄,来源于Unsplash
介绍
特征提取是自然语言处理(NLP)中的一个重要初始步骤,它涉及将文本数据转换为数学表示,通常是向量形式,这被称为词嵌入。存在各种词嵌入方法,从经典的方法如word2vec和GloVe到更现代的BERT嵌入。虽然基于transformer的嵌入今天主导了 NLP 领域,但了解以前方法的演变仍然是有帮助的。
在本文中,我们将探索一种被称为TF-IDF的传统特征提取方法,该方法基于统计分析。我们将深入探讨 TF-IDF 及其实现,并提供一个额外的应用以帮助巩固您的理解。因此,请跟随我们,直到最后,揭开 TF-IDF 的方方面面!
什么是 TF-IDF?
TF-IDF,全称为词频-逆文档频率,是自然语言处理(NLP)中常用的技术,用于确定文档或语料库中单词的重要性。为了提供一些背景信息,一项调查显示,2015 年 83%的基于文本的推荐系统在数字图书馆中使用 TF-IDF 来提取文本特征。这表明了这一技术的流行程度。本质上,它通过比较词在特定文档中的频率与其在整个语料库中的频率来衡量词的重要性。其基本假设是,频繁出现在文档中但在语料库中很少出现的词在该文档中尤为重要。
现在,让我们看看计算 TF-IDF 的数学公式:
TF(词频)通过计算词在文档中的出现频率并将其除以文档中的总词数来确定。
TF = (词在文档中出现的次数) / (文档中的总词数)
另一方面,IDF(逆文档频率)测量词在整个语料库中的重要性。其计算公式为:
IDF = log((语料库中的文档总数) / (包含该词的文档数))
最终,某词在给定文档中的 TF-IDF 分数是该词 TF 和 IDF 分数的乘积。结果的 TF-IDF 分数越高,表明该词在文档中的相对重要性越高。
用 Python 实现 TF-IDF
现在我们已经了解了 TF-IDF 的数学计算方法,让我们用 Python 实现它。虽然有库可以更快地计算 TF-IDF 特征,但本文将专注于从头开始构建它。
设置和预处理
首先,让我们导入后续需要的必要包,例如来自collections模块的Counter类。
import re
import math
from collections import Counter
import numpy as np
接下来,我们将定义一个文档/语料库的列表作为示例。我们借用最近围绕ChatGPT和生成性 AI 的热潮。
docs = ["ChatGPT is a AI chatbot developed by OpenAI.",
"ChatGPT is built on top of the GPT family of large language models.",
"Generative AI is rising in popularity and has started to transform businesses in various ways possible."]
在计算文本特征之前,标准做法是首先对文档进行预处理,如转换为小写、词形还原、词干提取、去除停用词等。在本示例中,我们将文档转换为小写并去除标点符号。然而,根据任务的不同,可以进行更多的预处理,这些步骤可以使用NLTK或SpaCy等 NLP 库来完成。我们还将跟踪语料库中的唯一词汇集。
p_docs = []
tok_set = []
for doc in docs:
p_doc = re.sub(r'[^\w\s]', '', doc.lower())
p_docs.append(p_doc)
tok_set.extend(p_doc.split())
tok_set = set(tok_set)
print(p_docs)
print(tok_set)

计算语料库中唯一词汇的 IDF
在获得词汇集之后,我们可以使用上述公式计算语料库中每个词的 IDF。
def calculate_idf(p_docs, tok_set):
idf = {}
for tok in tok_set:
N = len(p_docs)
df = 0
for doc in p_docs:
if tok in doc.split():
df += 1
idf[tok] = math.log(N/df)
return idf
idf = calculate_idf(p_docs, tok_set)
print(idf)

注意,像“openai”和“gpt”这样的词的 IDF 高于“chatgpt”或“ai”,因为前者在语料库中的出现频率较低。
计算每个文档中每个词的 TF
虽然 IDF 是在整个语料库中对每个词计算的,但 TF 是在每个文档中对每个词计算的。使用 TF 的公式,我们可以快速获取文档中某个词的计数,并使用 Counter 类计算其相对频率。
def calculate_tf(tok, p_doc):
toks = p_doc.split()
tok_freq = Counter(toks)
if tok in tok_freq:
return tok_freq[tok]/len(toks)
return 0
这是“chatgpt”在每个文档中的词频示例。
print(calculate_tf("chatgpt", p_docs[0]))
print(calculate_tf("chatgpt", p_docs[1]))
print(calculate_tf("chatgpt", p_docs[2]))

最后,进入 TF-IDF
现在,是时候实现 TF-IDF 函数了。我们可以先将之前的预处理代码包装成一个函数,以便在计算 TF-IDF 时调用它。
def prepare_docs(docs):
p_docs = []
tok_set = []
for doc in docs:
p_doc = re.sub(r'[^\w\s]', '', doc.lower())
p_docs.append(p_doc)
tok_set.extend(p_doc.split())
tok_set = set(tok_set)
return p_docs, tok_set
def tf_idf(tok, docs):
p_docs, tok_set = prepare_docs(docs)
print(f"calculating tf-idf for {tok} in all docs...")
idf_dict = calculate_idf(p_docs, tok_set)
idf = idf_dict[tok] if tok in idf_dict else 0.0
print(f"idf for {tok}: {round(idf, 4)}")
for i, doc in enumerate(p_docs):
tf = calculate_tf(tok, doc)
tf_idf = tf * idf
print(f"Doc {i+1}: {doc}, tf: {round(tf, 4)}, tf-idf: {round(tf_idf, 4)}")
让我们尝试一些示例。
print(tf_idf("chatgpt", docs))

print(tf_idf("gpt", docs))

print(tf_idf("generative", docs))

print(tf_idf("is", docs))

print(tf_idf("gpt4", docs))

从上面的例子中,我们可以看到“chatgpt”在文档 1 和 2 中的 TF-IDF 高于文档 3,因为该词没有出现在文档 3 中。尽管“chatgpt”在文档 1 和 2 中只出现过一次,但前者的 TF 稍高,因为它的词数更少,从而导致更高的 TF-IDF。
“gpt”在文档 1 和 3 中的 TF-IDF 为 0,因为它们都不包含这个词。文档 2 中有“gpt”,然而,其 TF-IDF 高于文档 1 中“chatgpt”的 TF-IDF,因为它在语料库中的稀有出现超越了第二个文档的较长长度。
“generative”在文档 1 和 2 中的 TF-IDF 为 0,因为没有出现。“is”在所有文档中的 TF-IDF 都为 0,因为它出现在所有文档中。像“gpt4”这样的未登录词也有 TF-IDF 为 0。
这里是如何在 Python 中实现 TF-IDF 的方法。
奖励:用 TF-IDF 构建搜索引擎 MVP
TF-IDF 有许多潜在的应用,其中之一是构建搜索引擎。在本节中,我们将探讨如何使用 TF-IDF 构建一个简单的搜索引擎 MVP。我们采取的方法是通过 TF-IDF 词汇的总和来排名文档,总和越高,文档排名越高。
要开始,先修改 tf_idf 函数,将 TF-IDF 结果附加到列表中,并将未登录词的 TF-IDF 设置为 0。
def tf_idf(tok, docs):
p_docs, tok_set = prepare_docs(docs)
idf_dict = calculate_idf(p_docs, tok_set)
idf = idf_dict[tok] if tok in idf_dict else 0
tf_idfs = []
for i, doc in enumerate(p_docs):
tf = calculate_tf(tok, doc)
tf_idf = tf * idf
tf_idfs.append(tf_idf)
return tf_idfs
通过这个修改,我们现在可以实现搜索引擎。下面的代码计算每个查询词在语料库中的 TF-IDF 值,并按其总和对文档进行排名。
def search_query(query, docs):
print(f"searching for: {query}")
terms = query.lower().split()
score = 0
tf_idfs = []
for tok in terms:
tf_idfs.append(tf_idf(tok, docs))
tf_idfs = np.array(tf_idfs)
print(tf_idfs)
doc_scores = np.sum(tf_idfs, axis=0) # summation of tf_idfs of all query terms for each doc
print(doc_scores)
rank_doc = np.argsort(doc_scores)[::-1]
print("docs ranked in order of relevance:")
for i in rank_doc:
print(f"Doc {i+1}: {docs[i]}, score: {doc_scores[i]}")
让我们尝试搜索查询“ChatGPT AI”:
print(search_query("ChatGPT AI", docs))

我们可以看到,“ChatGPT AI”与文档 1 的相关性最高,因为它的 TF-IDF 最高。这是有道理的,因为“ChatGPT”和“AI”都包含在该文档中,而另外两个文档各自只包含一个术语。
如果我们尝试“AI 语言模型”会发生什么?
print(search_query("AI language models", docs))

文档 2 排名最高,因为它包含了两个词项“language”和“models”,而另外两个文档只包含“AI”。在这种情况下,文档 1 排在文档 3 之前,因为它较短,导致词频略高。
就这样,你拥有了一个基于 TF-IDF 的极其简单的搜索引擎!
结论
在本文中,我们学习了 TF-IDF,它是什么,它如何工作,以及最重要的,为什么它在现代 NLP 应用中是必需的。我们还实现了 TF-IDF 并展示了如何用它构建一个极简的搜索引擎。如果你不想从头实现它,可以使用 [sklearn.feature_extraction.text.TfidfVectorizer](https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.TfidfVectorizer.html) 更快地实现它。这里有两个很好的博客文章介绍了如何使用它:1 和 2。
TF-IDF 并不是文本特征提取的唯一方法。还有许多其他方法可能表现更好,值得了解它们。尽管如此,它仍然是一种流行的文本特征方法,我们可能会在未来涵盖更近期的方法。
感谢阅读。如果这对你有帮助,随时关注我并订阅更多即将发布的文章。
希望你度过了一段愉快的时光。
干杯。
参考文献
[1] Mikolov, Tomas 等. “词语和短语的分布式表示及其组合性。” 神经信息处理系统进展 26 (2013)。
[2] Pennington, Jeffrey, Richard Socher 和 Christopher D. Manning. “Glove:词表示的全局向量。” 2014 年自然语言处理经验方法会议(EMNLP)论文集。2014。
[3] Devlin, Jacob 等. “Bert:用于语言理解的深度双向变换器预训练。” arXiv 预印本 arXiv:1810.04805 (2018)。
[4] Vaswani, Ashish 等. “注意力即一切。” 神经信息处理系统进展 30 (2017)。
[5] Beel, Joeran 等. “论文推荐系统:文献综述。” 国际数字图书馆期刊 17 (2016): 305–338。
[6] “如何 sklearn 的 Tfidfvectorizer 计算 tf-idf 值”,Analytics Vidhya,www.analyticsvidhya.com/blog/2021/11/how-sklearns-tfidfvectorizer-calculates-tf-idf-values/
[7] “使用 Python 进行文本向量化:TF-IDF”,Okan Bulut,okan.cloud/posts/2022-01-16-text-vectorization-using-python-tf-idf/
[8] “理解机器学习中的 TF-IDF”,Capital One,www.capitalone.com/tech/machine-learning/understanding-tf-idf/
[9] “理解 TF-IDF:一个简单的介绍”,MonkeyLearn,monkeylearn.com/blog/what-is-tf-idf/
以苏格拉底式的方法理解去噪扩散概率模型(DDPMs)
深入探讨去噪扩散模型背后的动机以及损失函数的详细推导
·发表于 Towards Data Science ·69 分钟阅读·2023 年 2 月 25 日
--

由 Chaozzy Lin 拍摄的照片,来源于 Unsplash
去噪扩散概率模型由 Jonathan Ho 等人提出,是一篇很棒的论文。但我在理解它时遇到了困难。因此,我决定深入研究这个模型,并完成所有的推导。在这篇文章中,我将重点关注理解论文的两个主要障碍:
-
为什么去噪扩散模型的设计涉及前向过程、前向过程后验和反向过程。这些过程之间有什么关系?顺便提一下,在这篇文章中,我称前向过程后验为“前向过程的反向”,因为我发现“后验”这个词让我困惑,或者我潜意识中想要避免这个词,因为它让我感到恐惧——每次出现它的时候,事情就变得复杂起来。
-
如何推导神秘的损失函数。在论文中,推导损失函数 Lₛᵢₘₚₗₑ 的过程中有很多省略步骤。我经过了所有的推导,填补了遗漏的步骤。现在我意识到,Lₛᵢₘₚₗₑ 的解析公式推导讲述了一个真正美丽的贝叶斯故事。在填补了所有步骤之后,整个故事变得容易理解。
一些符号
Medium 支持文本中的 Unicode,这使得我能够书写许多数学下标符号,比如 x₀ 和 xₜ。但我无法书写其他一些下标。例如:

对于这些内容,我将使用下划线“_”来引导下标,比如 x_T 和 p(x_0:T)。
如果某些数学符号在你的手机上显示为问号,请尝试在计算机上阅读本文。这是一个已知的 Unicode 渲染问题。
从噪声中生成自然图像的任务
我们的目标是使用神经网络从噪声中生成自然图像。神经网络的输入是噪声,输出应该是自然图像,例如人脸。不同的噪声将产生不同的自然图像,例如,一种噪声可能会产生女性的面孔,另一种噪声可能会产生男性的面孔。
你可能会问,什么样的噪声?在没有其他约束的情况下,热衷于贝叶斯方法的研究者会从高斯噪声开始。
这种噪声的维度是多少呢?理想的输出是一个具有红绿蓝(RGB)值的彩色二维图像。我们先将彩色图像简化为[0, 255]之间的灰度值,然后将灰度值缩放到[0, 1]的范围内。接着将这个二维缩放后的灰度值数组重塑为一个长的一维向量,长度为d。我将在文章中多次提到名称d。我们就以上内容作为图像生成任务的简单定义。但请注意,实际上,神经网络可以直接生成彩色图像。
自然的假设是输入噪声的维度和结构与输出图像的维度和结构相同,即长度为d的向量。因此,噪声应该是d维的多变量标准高斯分布N(0, 1),这是一种学术上的默认设置。
现在,从噪声生成图像的任务变得更加具体:设计一个神经网络,该网络从d维的多变量标准高斯分布中获取一个样本,并输出一个d维的缩放灰度值向量。将输出向量转化为二维形状和 RGB 颜色是我们都知道的如何做到的,这不在本文讨论范围之内。
迭代生成图像
一步生成自然图像从噪声中是困难的。那么在多个较小步骤中生成图像怎么样呢?有点类似于让图像从老式摄影中的柯达胶卷中显现出来。这样,在每一步中,神经网络的任务应该更简单,因为每一步的输入和输出比从纯噪声到最终自然图像要更相似。
这种迭代生成的想法带来了一个自身的问题。那么中间图像应该是什么样的?有经验的老一辈(像我这样)可能会建议中间图像应该是渐进的——在这个迭代过程中,不应出现首先出现一只猫的图像,然后猫变成了一个人的脸的情况。
对中间图像的“渐进性”约束是有道理的。但是如何将其数学化呢?
正向过程将自然图像转变为噪声
尽管尚不清楚如何定义迭代生成过程的逐步性,但定义相反的过程,即通过逐步添加一点高斯噪声将自然图像转变为纯噪声的过程却很简单。
将自然图像通过逐步添加噪声转变为纯噪声的过程称为前向扩散过程,简称前向过程。
反向过程将噪声转变为自然图像
另一方面,将高斯噪声转变为自然图像的过程称为反向过程。
前向和反向过程的示意图
下面的图来自论文,展示了这两个过程,前向过程在底部,反向过程在顶部。

图片来自论文 去噪扩散概率模型,第 2 页
在图中,x₀, x₁,x₂、…、x_T 是d维多变量高斯随机变量。我们使用随机变量x₀来表示自然图像。这意味着如果我们从x₀的概率密度函数中抽取一个样本,这个样本向量一旦重新排列成 2D 图像,应该看起来像一幅灰度自然图像。我们还没有讨论x₀的概率密度函数的具体形式。我稍后会讲到。
在图中,随机变量x₁、x₂、…、x_T对应于中间图像,即添加了噪声的图像(如果我们看底部的前向过程),或者移除了噪声的图像(如果我们看顶部的反向过程)。稍后,我还会介绍这些随机变量的概率密度函数,分别从前向过程和反向过程的角度来看。
记住x₀是自然图像,x_T是纯噪声,而不是反过来,请记住:小下标,小噪声,大下标,大噪声。这个想法来源于这个视频。
现在有一种方法可以数学地定义中间图像的逐步性。每个从反向过程中生成的图像xₜ必须接近(别担心,我们稍后会定义“接近”是什么意思)从前向过程中扩散得到的相应图像。
前向过程
前向过程是一个概率模型。为什么?因为每一步都向图像中添加了高斯噪声。因此,结果不是确定性的——从相同的自然图像x₀开始,你可能会得到不同的标准多变量高斯噪声样本x_T。就像在不同时间将墨水滴入一杯水中,每次都会得到不同的扩散效果。
在这个概率模型中,x₀, x₁,… 到 x_T 是随机变量。每个都是d维随机变量。
由于前向过程是概率性的,因此讨论它的适当数学工具是概率密度函数和概率理论。
上图使用 q(xₜ|xₜ₋₁) 来表示前向扩散过程中从图像 xₜ 到图像 xₜ₋₁ 的单步概率密度函数。我们将其概率密度函数定义为:

前向过程在固定均值向量和协方差矩阵的条件下
βₜ 是一个随时间变化的值,很像是调度学习率。
使用重参数化技巧(参见推导 这里),随机变量 xₜ 可以等效地描述为:

其中 ϵₜ₋₁ 是 d-维标准高斯噪声。这个公式揭示了更嘈杂的图像 xₜ 是较少噪声图像 xₜ₋₁ 和某些噪声 ϵₜ₋₁ 之间的加权平均。换句话说,前向过程将噪声 ϵₜ₋₁ 添加到较少噪声的图像 xₜ₋₁ 中。βₜ 的值控制了在时间戳 t 上添加的噪声量。这就是为什么 βₜ 被安排为非常小的值,从 β₁=10⁻⁴ 到 β_T=10⁻²。而 T 被设置为 1000——否则噪声将迅速主导前向过程。
q(xₜ|xₜ₋₁) 是随机变量 xₜ 的概率密度函数,它描述了前向过程中的单一步骤。以下的联合概率密度函数描述了整个前向过程:

这是前向过程 q(x_1:T|x₀) 的第一次分解。
为什么前向过程的联合概率密度函数 q(x_1:T|x₀) 会依赖于随机变量 x₀?这是因为当 t=1 时,q(xₜ|xₜ₋₁) 变成了 q(x₁|x₀),其中提到了 x₀。
随机变量依赖性
由于本文将讨论许多随机变量依赖于其他随机变量的情况,让我们以 q(x₁|x₀) 为例来阐明随机变量 x₁ 依赖于随机变量 x₀ 的含义。以下推导展示了 q(x₁|x₀) 的概率密度函数。

q(x₁|x₀) 的概率密度函数公式
我们将 q(x₁|x₀) 定义为多元高斯,在线(2)中用 N(x₁; mean vector, covariance matrix) 符号表示,因为多元高斯分布完全由其均值向量和协方差矩阵指定。“完全指定”意味着在写下均值向量和协方差矩阵后,多元高斯分布是一个固定函数,其结构,如指数函数、行列式,如在(3)行中由字母 N 假定。
然后,在第(3)行中,我们将此符号扩展为多变量高斯分布的概率密度函数的数学公式。d 是随机变量 x₁ 的维度,即图像中的像素数量 height×width。det 是行列式函数。
因此,很明显,第(3)行是随机变量 x₁ 的概率密度函数。从第(3)行,我们还知道以下几点:
首先,作为随机变量 x₁ 的概率密度函数,第(3)行的函数在 x₁ 的定义域上 积分为 1。这是因为所有适当的概率密度函数都积分为 1。

q(x₁|x₀) 在 x₁ 的定义域上积分为 1
其次,第(3)行是随机变量 x₁ 的概率密度函数,但它也提到了另一个随机变量 x₀。这就是为什么随机变量 x₁ 依赖于 随机变量 x₀——评估 q(x₁|x₀) 对某个 x₁ 的值需要一个 x₀ 的值。
现在你可能会问,我们能否将第(3)行视为随机变量 x₀ 的概率密度函数,并说 x₀ 依赖于随机变量 x₁?
答案是否定的。确实,从数学函数的角度来看,我们可以将 q(x₁|x₀),第(3)行, 解释为 x₁ 的单一参数函数或 x₀ 的单一参数函数。但是,第(3)行不是随机变量 x₀ 的适当概率密度函数,因为它在 x₀ 的定义域上不积分为 1。
随机变量 x₀ 的概率密度函数?*
在明确第(3)行不能被解释为随机变量 x₀ 的概率密度函数后,一个自然的后续问题是:我们是否知道 x₀ 的概率密度函数 q(x₀)?
不,我们不知道。q(x₀) 描述了自然图像的概率。这意味着:
-
给定自然图像,假设 X₀,将 X₀ 代入 q(x₀=X₀) 应返回一个介于 0 和 1 之间的概率值,以指示这个自然图像在所有自然图像中出现的可能性。
-
总结起来,或者说,积分,所有自然图像的 q(x₀) 概率值给我们 1。
显然,像往常一样,我们不知道 q(x₀) 的解析公式。但这并不妨碍我们写下其符号,并为随机变量 x₀ 绘制样本——我们只是从自然图像的训练集中随机挑选一张图像。
正向过程由 T+1 个随机变量 x₀, x₁ 到 x_T 组成。它们形成了两组:
-
x₀:有关于随机变量 x₀ 的观测。观测是来自训练数据集的实际图像。我们称 x₀ 为观测随机变量。
-
x₁ 到 x_T:对它们没有观测,因此它们是 潜在随机变量。
正向过程的定义带来了三个重要属性。
属性 1:完全联合概率密度函数 q(x_0:T)

联合概率密度函数表示轨迹。从视觉上看,完全的联合概率密度函数q(x_0:T)描述了图像的可能轨迹集合。每条轨迹由T+1幅图像组成,其中x₀表示无噪声图像,x_T表示纯噪声图像。
以下插图展示了从两个自然图像X₀和X₁开始的一些轨迹。这些轨迹以不同的纯噪声图像结束,表明前向过程具有概率性。这是插图,因为我手绘了这幅图——并不是说从X₀开始的轨迹与从X₁开始的轨迹没有重叠。

手绘插图,由我制作
插图还显示在时间戳t,随机变量xₜ负责解释前向过程在此时间戳生成的所有可能图像。
属性 2: 边际概率密度函数 q(xₜ|x₀)
通过重复使用重参数化技巧(见推导 这里)给我们每个潜在随机变量的概率密度函数,而无需依赖其他潜在随机变量,因此得到的概率密度称为边际:

前向过程边际
其中

这个属性揭示了给定x₀时,潜在随机变量 xₜ不再依赖于潜在随机变量xₜ₋₁。换句话说,给定x₀时,潜在随机变量x₁到x_T彼此独立。
对于独立随机变量a和b,概率论中的乘积规则是p(a, b) = p(a) p(b)。应用乘积规则给我们q(x_1:T|x₀)的第二个因式分解:

这两个因式分解是等效的,意味着它们描述了相同随机变量集合的联合概率分布。有时我们会选择其中一个以简化公式推导。
属性 3: 前向过程的逆过程 q(xₜ₋₁|xₜ, x₀)
使用贝叶斯规则,可以推导出前向过程的逆过程的概率密度函数。在论文中,前向过程的逆过程称为前向过程后验。但我发现“后验”这个词在这篇文章中容易引起混淆。
让我们从前向过程q(xₜ|xₜ₋₁)开始,它是随机变量xₜ的概率密度函数:

注意,我们可以将冗余的条件随机变量 x₀ 加入到 q(xₜ|xₜ₋₁) 中,将其变成 q(xₜ|xₜ₋₁, x₀),因为根据定义,给定 xₜ₋₁,随机变量 xₜ 不依赖于任何其他随机变量。因此,添加 x₀ 作为依赖项不会改变 xₜ 的概率密度。
然而,xₜ 在 q(xₜ|x₀) 和 xₜ₋₁ 在 q(xₜ₋₁|x₀) 中确实依赖于 x₀。我们知道 q(xₜ|x₀) 和 q(xₜ₋₁|x₀) 的公式来自于上述边际概率密度的属性 2。
通过重新排列项,我们可以推导出前向过程逆过程的概率密度函数:

使用贝叶斯规则推导前向过程逆过程中的条件
首先,请注意当 t=1 时,q(xₜ₋₁|xₜ, x₀) 变成 q(x₀|x₁, x₀),其中 x₀ 同时出现在“|”的左右两边。量 q(x₀|x₁, x₀) 总是等于 1,因为在给定 x₀ 的情况下,x₀ 的概率是 1。为什么?因为 x₀ 已经发生,对它没有不确定性了。
记住这一点,当你看到第一个解析损失函数从 t=2 开始时,你会微笑。而当我们扩展损失函数以涵盖 t=1 的情况时,你将再次微笑。
方程 q(xₜ₋₁|xₜ, x₀) 的左边告诉我们这是一个将更嘈杂的图像 xₜ 转换为较少噪声的图像 xₜ₋₁ — 的过程,请记住大下标表示更多噪声,小下标表示较少噪声。因此,q(xₜ₋₁|xₜ, x₀) 描述了与前向过程相反的过程。我们称之为前向过程的逆过程。
方程的右边告诉我们 q(xₜ₋₁|xₜ, x₀) 由前向过程概率密度 q(xₜ|xₜ₋₁) 定义,通过 q(xₜ₋₁|x₀)/q(xₜ|x₀) 进行缩放。我们已经定义了这三个组成部分。如果将它们代入方程的右边,并耐心简化公式,你将看到 q(xₜ₋₁|xₜ, x₀) 是一个多元高斯分布,它由其均值向量和协方差矩阵完全指定。经过忽略不计的数学推导后,我们得到:

前向过程条件的逆过程
与

前向过程条件的固定均值向量和协方差矩阵
我们可以看到* xₜ 的均值向量是 x₀ 和 xₜ 之间的加权和。这两个随机变量前面的权重取决于时间戳t。协方差矩阵也是一个依赖于时间戳t*的量。均值向量和协方差矩阵都没有提到可训练参数(我们稍后将介绍可训练参数)。
反向过程很重要,因为它描述了我们确切想要的生成过程——一个将噪声逐渐转变为自然图像的过程,即去噪。当然,这只是贝叶斯规则在说话,我们想亲自看到反向过程是否真正能够提供这种去噪能力。我们可以通过研究反向过程开始时和结束时的采样图像来找到答案。
但要全面理解去噪模型还需要很多数学知识,我不想用每一个细节来折磨你。所以如果你愿意相信我的话,反向过程以接近纯高斯噪声的嘈杂图像开始,以接近x₀的图像结束,即反向过程以此为条件的自然图像。换句话说,反向过程将嘈杂图像去噪成自然图像。如果你仍然渴望更多数学内容,关于反向过程开始和结束的推导在附录“反向过程从哪里开始和结束”中。
注意“反向过程的开始是接近纯高斯噪声的嘈杂图像”这句话。它是接近纯高斯噪声,而不是等于纯高斯噪声,因为反向过程q(xₜ₋₁|xₜ, x₀)是以初始图像x₀为条件的(详细信息见附录)。正是这种初始图像x₀的参与使得反向过程能够在最后生成一个非常接近x₀的图像。
一方面,在反向过程q(xₜ₋₁|xₜ, x₀)的定义中对初始图像x₀的条件是不可避免的。这是因为q(xₜ₋₁|xₜ, x₀)是通过对前向过程q(xₜ|xₜ₋₁)应用贝叶斯规则得到的,当t=1时包含了随机变量x₀。而贝叶斯规则的应用不会去除任何随机变量。
另一方面,反向过程对初始图像x₀的条件有明确的直观。这种直观将在我们介绍反向过程后变得清晰。
反向过程
图顶端是反向过程p(xₜ₋₁|xₜ)。反向过程从更嘈杂的图像xₜ生成一个噪声较少的新图像xₜ₋₁,与反向过程q(xₜ₋₁|xₜ, x₀)相同。注意图中使用了符号p_θ,但我决定使用p,因为p_θ在 Medium 上不好看。
逆过程必须包含与前向过程的逆过程相同的一组随机变量 x₀, x₁ 到 x_T,因为我们希望在逆过程和前向过程的逆之间建立“逐步性”对应关系。换句话说,既然前向过程的逆过程已经告诉我们如何逐步去除图像中的噪声,我们希望我们的神经网络在每一步都模仿这一过程。我们传达逐步模仿要求的方式是要求逆过程中的随机变量 xₜ 行为类似于前向过程的逆过程中的对应随机变量。
由于前向过程的逆过程是使用多元高斯分布定义的,因此使用多元高斯分布定义逆过程也是合理的:

逆过程定义
在归纳公式中,p(xₜ₋₁|xₜ),均值向量 μₚ(xₜ, t) 和协方差矩阵 Σₚ(xₜ, t) 实际上是两个深度神经网络,预测多元高斯分布的 d 维均值和 d×d 维协方差矩阵。
在基础公式中,p(x_T) 是标准的多元高斯分布,这确认了逆过程从纯噪声开始。请注意,逆过程的起始点不依赖于任何随机变量,这与前向过程的逆情况不同。
在联合概率的情况下,符号p(x_0:T) 是p(x₀, x₁, …, x_T) 的简写。它表示了T+1个随机变量的概率密度函数。联合概率的公式是归纳情况和基础情况中所有项的乘积,遵循概率理论的基本性质。
神经网络 μₚ(xₜ, t) 和 Σₚ(xₜ, t) 接受两个输入,第一个是噪声图像 xₜ,第二个是时间戳 t。噪声图像 xₜ 是有意义的。毕竟,我们希望使用神经网络来去噪。但是如何理解时间戳 t 作为神经网络的输入?意图与自然语言处理中的 Transformer 模型类似,该模型使用余弦函数编码句子中词语的位置,并将编码的位置作为额外输入提供给 Transformer。在这里,我们也希望将我们在逆过程中的位置编码作为额外输入,以便为神经网络提供一些位置信息。
神经网络如何预测一个d维均值向量和d×d协方差矩阵?对于均值向量,均值预测神经网络将有d个输出单元,每个单元预测d维均值向量中的一个条目。协方差矩阵预测神经网络具有d×d个输出单元。这只是一个粗略的理解。对于较大的d,神经网络的输出数量,尤其是对于协方差预测神经网络,数量非常庞大。协方差矩阵有更简洁的方法,参见这里的均值场参数化。
μₚ(xₜ, t) 和 Σₚ(xₜ, t) 包含模型参数。
均值向量预测网络μₚ(xₜ, t)和协方差矩阵预测网络Σₚ(xₜ, t)中的权重是这个机器学习任务中的模型参数。我们希望通过优化找到这些模型参数的适当值,以便当从p(x_T)的噪声样本开始,并迭代地从分布p(xₜ₋₁|xₜ)中采样xₜ₋₁,当我们从分布p(x₀|x₁)中检索一个x₀样本时,这个x₀样本就是一个现实自然图像的缩放灰度图。
你可能还会有另一个问题,我们是否应该使用相同的均值神经网络来预测所有时间戳的均值向量?协方差矩阵预测网络也是同样的问题。这是一个设计选择。至少需要一个网络,作者通过实验表明一个网络就足够了。你可以有两个或更多网络,但这需要学习更多的参数。
与正向过程中的情况类似,联合概率密度函数p(x_0:T)也表示反向过程可以生成的图像轨迹集,起始于某些纯高斯噪声。
为什么我们需要反向过程p(xₜ₋₁|xₜ)?难道反向过程q(xₜ₋₁|xₜ, x₀)不够吗?
既然我们已经知道反向过程的分布q(xₜ₋₁|xₜ, x₀),它用于去噪图像,你可能会问,为什么反向过程p(xₜ₋₁|xₜ)仍然是必要的?为什么我们不能直接从q(xₜ₋₁|xₜ, x₀)中采样自然图像?
当然可以。但请仔细看看q(xₜ₋₁|xₜ, x₀)。xₜ₋₁不仅依赖于xₜ,还依赖于初始图像x₀。这意味着我们需要知道初始图像才能开始采样,而正向过程的反向将给出一个非常接近已知x₀的图像(参见附录反向过程从何开始和结束?)。这不是我们想要的。我们希望能够自由地采样自然图像!
此时,你可能不想停下来。你可能会问,我们能否推导出q(xₜ₋₁|xₜ)的解析公式,即反向过程在不依赖于x₀的情况下?让我们再次使用贝叶斯规则来实现:

现在我们可以发现问题:在方程的右侧,q(xₜ|xₜ₋₁) 已经被定义,但 q(xₜ₋₁) 和 q(xₜ) 还没有定义。因此,这是一条死路。
你仍然不会停下来,你会问,为什么不能定义 q(xₜ)?我听到了!让我们试着定义 q(xₜ)。从概念上讲,q(xₜ) 表示前向过程在时间戳 t 可以生成的可能图像集。从轨迹的角度来看,见下文,前向过程在时间 t 生成的图像取决于轨迹的起点(自然图像 X₀, X₁ 等)。因此,q(xₜ) 的概率密度函数不可避免地参考起点的概率密度函数,即 q(x₀)。

由我手绘的插图
q(x₀) 是什么?它是训练数据的概率密度。不幸的是,我们之前已经明确了 q(x₀) 是未知的。我们能做的最好事情就是通过从我们的训练集中随机挑选自然图像来从中取样。因此,我们将无法写下 q(xₜ) 的解析公式。
前向过程的反向过程 q(xₜ₋₁|xₜ, x₀) 智能地定义了随机变量 xₜ₋₁ 在 x₀ 条件下的概率密度函数。以 x₀ 为条件使我们能够插入一个 x₀ 的样本来推理 xₜ₋₁ 的属性。只要我们能够从 x₀ 中取样(我们可以),并以期望的方式推理 xₜ₋₁,我们就可以了。有关“以期望方式推理随机变量”的更多细节,请参见下文的取样平均。
模仿前向过程的反向过程能否给我们一个可以从任何多变量高斯噪声开始的反向过程?
我们已经确定,反向过程的起点已经是纯高斯噪声。通过模仿反向过程将主要的高斯噪声去噪到训练集中的自然图像,我们的无条件去噪模型,即反向过程 p(xₜ₋₁|xₜ) 应该能够将纯高斯噪声转化为逼真的自然图像。就像如果一条线性回归线穿过许多数据点,我们期望这条线能够在相同方向上插值到其他未见的数据点。
为什么反向过程使用神经网络来预测均值向量和协方差矩阵?
在概率建模中,决定使用哪个分布族之后,最困难的任务就是决定使用什么值来完全指定分布。在我们的例子中,我们决定使用多变量高斯族。然后我们需要决定两个量来完全指定一个多变量高斯,即均值向量和协方差矩阵。让我在这里再次展示反向过程的归纳案例:

如果你考虑一下,均值向量预测函数μₚ(xₜ, t)和协方差矩阵预测函数Σₚ(xₜ, t)需要完成一项困难的任务——给定一个任意的噪声图像xₜ和时间戳t作为输入,它们需要输出两个量(均值和协方差矩阵),这些量描述了去噪版本xₜ的图像的光谱。
显然,这些函数μₚ(xₜ, t)和Σₚ(xₜ, t)不能太简单。具有两个参数的线性函数、具有三个参数的二次函数,甚至具有四个参数的三次函数都无法胜任这一任务。任务如此困难,只有具有数百万个参数的函数才能胜任(论文的附录 B 描述了从 3570 万到 2.56 亿参数的架构)。神经网络是定义具有数百万个参数的函数的便捷方式,并且在实现惊人功能方面有良好的记录。
上述归纳情况也展示了将深度神经网络纳入统计模型的最常见方式——使用神经网络预测难以指定的概率分布参数。结合变分推断(我稍后会介绍),它最大化我们模型的似然用于参数学习,这三者(统计模型、神经网络和变分推断)在现代机器学习中是好朋友(best friends forever)。抱歉,我家有小孩,一些缩写不可避免。
为什么正向过程的反向过程q(xₜ₋₁|xₜ, x₀)必须以初始图像x₀为条件的直观解释
我们的目标是训练一个定义为反向过程的模型,将噪声图像逐渐去噪成清晰的自然图像。为了训练这样的模型,我们需要大量的去噪轨迹作为训练数据。
仅仅拥有一组自然图像是不足以训练这样的模型的。相反,我们需要完整的轨迹,每个轨迹由图像组成,这些图像从接近噪声到清晰的自然图像逐渐变化。只有这样,训练好的模型才能以噪声图像和时间步作为输入,逐渐去噪图像。
正向过程的反向过程是提供这些逐渐变化的图像轨迹的机制。如果正向过程的反向过程不以初始图像x₀为条件,我们怎么能控制最终生成的自然图像呢?我们需要一种方法,将我们的请求插入到正向过程的反向过程q(xₜ₋₁|xₜ, x₀)中,以便生成我们所需的自然图像。
对自然图像x₀的正向过程的反向过程进行条件化是一种实现目标的方法,更不用说,当我们应用贝叶斯规则推导正向过程的反向过程的概率密度q(xₜ₋₁|xₜ, x₀)时,这种条件化会自动发生。
顺便说一下,前向过程的反向过程可以通过去噪得到一个清晰的自然图像,而无需复杂的神经网络。这是因为它有一个条件目标图像的优势。这显示了数据在统计建模中的重要性。
过程回顾 q(xₜ|xₜ₋₁), q(xₜ₋₁|xₜ, x₀) 和 p(xₜ₋₁|xₜ)
由于我们将频繁参考前向过程、前向过程的反向过程以及反向过程,让我们做一个回顾:
-
前向过程 q(xₜ|xₜ₋₁) 通过逐渐添加高斯噪声将自然图像转换为高斯噪声。前向过程是一个没有任何模型参数的固定过程。
-
前向过程的反向过程 q(xₜ₋₁|xₜ, x₀) 通过去除噪声将噪声更大的图像转换为噪声较少的图像。前向过程的反向过程也是一个没有任何模型参数的固定过程。它通过对前向过程应用贝叶斯规则来交换随机变量的顺序。贝叶斯规则将对 x₀ 的依赖性加入到前向过程的反向过程中,因此从中得到的最终样本是与 x₀ 非常相似的图像。换句话说,我们不能使用前向过程的反向过程来采样任意自然图像。
-
反向过程 p(xₜ₋₁|xₜ) 将任意的高斯噪声转换为自然图像。这是我们想要学习的过程。反向过程包含了我们模型的所有参数,即概率密度 p(xₜ₋₁|xₜ) 内部两个神经网络的权重。由于不依赖于 x₀,反向过程允许我们采样任意自然图像。但我们首先需要通过梯度下降找到模型参数的良好值,这需要一个解析形式的损失函数,以便梯度下降算法可以计算损失的梯度。
目标函数
既然反向过程的结构已经定义且其必要性已解释,现在是时候考虑我们要最小化的目标函数,以进行均值向量预测网络 μₚ(xₜ, t) 和协方差矩阵预测网络 Σₚ(xₜ, t) 的参数学习。
对于一个概率模型,数据的似然性总是思考目标函数的良好起点。让我们定义一下“数据的似然性”在我们的模型中意味着什么。
联合概率密度函数,已插入数据
之前定义的反向过程的联合概率密度函数 p(x₀, x₁⋯, x_T),或简称为 p(x_0:T) 是一个以 T+1 个随机变量为自变量的函数,即 x₀, x₁ 到 x_T。作为概率密度函数,当实际值代入其自变量时,其评估结果是介于 [0, 1] 之间的概率值。
概率模型的目的是很好地解释训练数据。“很好地解释训练数据”意味着当将训练数据集中的图像逐一代入x₀参数时,它们会得到一个高概率值。
将图像X₀代入联合概率密度函数p(x_0:T)的参数x₀中,该函数有T+1个随机变量,结果是一个具有T个随机变量的新函数:p(x₀=X₀, x₁⋯, x_T)。这个函数尚不能计算为概率值,因为它涉及随机变量x₁到x_T,这些不是具体值。x₁到x_T是潜在随机变量,没有观测值,因此我们无法找到一些有意义的具体值(如观测随机变量x₀的情况)来代入。它们需要被去除,或者更准确地说,积分去除。
似然p(x₀)
根据定义,随机变量,例如x₁,描述了一系列可能的值。去除概率密度函数中的随机变量的常用方法是计算相对于该随机变量的期望值。换句话说,要从p(x₀, x₁⋯, x_T)中去除随机变量x₁,计算该函数相对于x₁的平均值,或称期望值。本质上,我们在说由于无法观测到潜在随机变量的具体值,我们必须推理它们的平均行为。
让我们首先选择x₁来积分去除。由于x₁是一个连续随机变量,其期望值由积分定义,因此有“积分去除随机变量”之称:

相同的“积分去除”方法,应用T次,可以将所有潜在随机变量从p(x₀, x₁⋯, x_T)中去除:

数据的似然,所有潜在随机变量都被积分去除
p(x₀)现在仅描述了使用我们的模型生成实际图像的可能性,我们称p(x₀)为数据的似然。
注意,上述方程仅是一个符号,表示p(x₀)是在所有潜在随机变量被积分去除后剩下的部分。它并没有告诉我们如何去除这些变量。这是因为积分符号“∫”表示积分的结果,而没有告诉我们如何进行积分。
为什么不将x₀也从似然p(x₀)中积分去除呢?
上述的T-维积分仅将潜在随机变量x₁到x_T积分去除,保留了观测随机变量x₀在p(x₀)中。为什么?因为如果x₀也被积分去除,整个联合概率密度函数会变为 1,因为当所有概率密度函数在其完整的随机变量集上积分时,结果为 1:

你看,这里没有地方插入实际图像以评估这个数字 1 对训练数据解释的效果。这阻止了我们进行参数学习。
这就是为什么我们将x₀保持未积分掉,使用p(x₀),这称为数据的似然,或简称似然。
似然p(x₀)提到所有模型参数
即使似然p(x₀),当插入实际图像时,即p(x₀=X₀),是一个不提及任何随机变量的函数,它仍然通过概率密度函数p(xₜ₋₁|xₜ)提到所有模型参数,即两个神经网络中的权重,如下所示:

权重存在于均值预测网络μₚ和协方差矩阵预测网络Σₚ中。潜在随机变量x₁到x_T已被积分掉,但它们的均值和协方差矩阵项保留在积分结果中。
你可能会问,既然我们没有讨论潜在随机变量是如何被积分掉的,我们怎么知道μₚ(xₜ, t)和Σₚ(xₜ, t)会在积分中保留?你稍后会在分析损失部分的推导中看到,但这里你必须相信我:如果在积分之后,那些高斯潜在随机变量消失了,而且描述它们的两个重要东西,即均值向量μₚ(xₜ, t)和协方差矩阵Σₚ(xₜ, t)也消失了,那么似乎这些随机变量在我们的模型中从未存在过。这是不合理的。因此,这两个神经网络μₚ(xₜ, t)和Σₚ(xₜ, t)将会在积分中保留。换句话说,p(x₀)提到所有在μₚ(xₜ, t)和Σₚ(xₜ, t)中的模型参数。
p(x₀)提到所有模型参数,即两个神经网络中的权重,但我们尚不知道模型参数的正确值。如果我们假装知道所有参数值,则可以通过逐个插入训练图像,从训练图像集中评估p(x₀),将其转化为[0, 1]之间的概率值。这将产生许多概率值。对这些概率值取平均给我们提供了一个衡量我们模型解释训练数据的效果的指标。
当然,我们不知道神经网络权重的值。我们可以将它们设置为任意值,但这可能导致一个解释训练数据效果不佳的模型。即,这个具有随机神经网络权重的模型对p(x₀=X),其中X是从训练数据集中采样的自然图像,返回非常低的概率值。请注意,当你实际进行这个操作时,你不会知道返回的概率值是小还是大,因为你还没有基准值可以进行比较。你可以确定的是,当你要求模型去噪一个纯噪声图像时,它完全不会生成一个现实的自然图像。
在这种情况下,问题不是我们的模型结构无法解释数据,而是模型尚未正确校准。这里的“模型结构”指的是用两个深度神经网络预测去噪图像的均值向量和协方差矩阵的逆过程。
优化可以找到这些神经网络的适当参数值。它需要一个损失函数来最小化。
数据的负对数似然作为损失函数
损失函数必须提到所有模型参数。似然p(x₀) 满足这一要求。一个好的模型应该让实际图像X₀评估为一个高似然概率值p(x₀=X₀)。我们想要最小化损失函数,因此带有负号。一个好的模型不仅需要对训练集中的单个图像表现良好,还需要对训练集中的所有图像表现良好,因此需要对从训练集中采样的图像x₀~q(x₀).进行期望。我们可以取p(x₀)的log,因为log是一个单调函数,不会影响损失函数的最优值;我们引入log函数是因为它是我们下面将使用的 KL 散度的核心部分。
上述思路引导我们到著名的负期望对数似然,记作L:

负对数似然损失函数的定义
第(一)行是数据的负对数似然的定义。
第(二)行代入了p(x₀)的定义,该定义将所有潜在随机变量x₁到x_T从密度函数p(x₀, ⋯, x_T)中积分出去。
负对数似然损失函数是不可优化的
通过优化执行参数学习的标准方法是使用梯度下降法来最小化相对于模型参数的损失函数。梯度下降法需要知道损失函数的解析公式以进行求导。不幸的是,负对数似然损失函数的解析公式非常难以推导。
为了求出损失函数L的解析公式,让我们再看一下它的定义:

p(x_0:T) 已经定义过(在此再次展示)为:

很容易看出我们的模型参数——预测神经网络μₚ的均值向量中的权重和预测神经网络Σₚ的协方差矩阵——在损失函数中提到。但它们是在积分符号∫中提到的。
与指数符号exp或平方操作符“²”等代表我们立即知道如何计算的东西不同,积分符号代表计算的结果,即它要求你积分一个函数,而没有告诉你如何进行积分。
从我们的微积分课程中,我们都知道求导是工作,但执行积分是艺术——只要你有导数速查表,求导是机械的过程。然而,积分需要创造力,我们有这么多积分,简直不知道怎么做。
不幸的是,损失函数 L 内的p(x_0:T)的积分属于非常难以解析求解的积分。我们称之为不可处理的积分。让我们使用设置T为 1 的短反向过程来演示这一点。我们的目标是展示 L 是不可处理的:

由于我们知道如何从自然图像的训练集中采样x₀,因此关于x₀的外部期望可以通过样本平均来处理(请参见下一节关于样本平均的内容)。所以唯一困难的项是对数内部的积分,我们希望证明这个积分很难求解:

在将T设置为 1 之后,上述项变为:

显示损失 L 是不可处理的推导
行(1)显示了当T=1时反向过程的缩短联合概率密度函数。
行(2)将联合概率密度函数分解为概率密度函数的乘积,每个函数分别对应一个随机变量x₁和x₀。
行(3)插入了这些单一概率密度函数的名称。
行(4)插入了实际的概率密度函数,它们是多元高斯分布。第一个exp是标准高斯分布中随机变量x₁的部分,第二个exp是条件在x₁上的随机变量x₀。我使用了比例符号“∝”来忽略每个多元高斯分布前面的归一化项。
行(4)中的积分在解析上很难解决。请注意,在这种情况下,我们知道如何使用积分的乘积法则来计算积分,但这只是困难且麻烦,特别是当T=1000时。论文作者提出的变分方法(稍后解释)更为优雅。
这个小练习还揭示了我们可以使用一种称为样本平均的技术来解析地近似积分。这是因为在行(4)中,每个随机变量的概率密度函数只出现了一次,这次对密度函数的提及是在计算相对于该随机变量的期望值。样本平均可以近似这些期望。
使用样本平均来推导损失的解析形式?
损失函数 L 包含一个不可处理的积分。近似不可处理积分的方式有很多种,例如样本平均、重要性采样和高斯求积。让我们来看一下其中最简单的方法——样本平均。
什么是样本平均?
样本平均是简单的——基于随机变量样本的函数评估来近似期望。
正式地,设x是来自分布h(x)的连续随机变量,我们想计算函数f(x)的期望,样本平均通过以下方式近似该期望:
然后样本平均方法用于计算函数f(x)的期望,通过对函数f(x)进行以下关于x的积分的近似,方法是将来自分布h(x)的x样本代入进行f(x)的评估平均。

样本平均的定义
行(1)是关于x来自概率密度函数h(x)的f(x)期望的符号表示。
行(2)由于我们假设x是一个连续随机变量,上述期望在数学上由这一行所示的积分定义。这是识别样本平均是否适用的公式模式。我们称之为样本平均模板。
行(3)是样本平均步骤。它通过对函数f进行从分布h(x)中抽取的样本S₁, S₂, …, Sₙ的评估进行平均,来近似积分。
换句话说,只要在插入样本x时函数f(x)是可评估的,样本平均就能近似积分。
我们有一个难以处理的积分
通过采样平均进行积分的解析近似
正如你所见,样本平均通过积分函数的总和来近似积分。这个项的总和在分析上与我们的模型参数有关。一个简单的例子说明了这一点:我们想写下某个模型的损失函数的解析公式,写成以下积分形式。在这个积分中,假设x是可以从h(x)中抽取的随机变量,μ是我们需要优化的模型参数。

在从h(x)中抽取两个样本S₁和S₂后,应用样本平均来近似该积分,这会得到一个关于μ的解析表达式,位于以下近似方程的右侧:

近似的右侧是一个提到模型参数μ的表达式。这个表达式是解析的——它不涉及代表计算结果的符号,如积分∫,它只涉及代表计算的符号,如指数函数exp和平方运算符“²”,我们知道如何计算梯度。
我们一直使用样本平均法。例如,为了计算学校中学生的期望身高,我们不知道学生身高的分布,但我们有测量的学生身高样本。然后通过做平均来计算期望。
使样本平均法不可用的反例
只要 p(x) 不出现在需要积分的函数 f(x) 中,并且 p(x) 易于采样,我们就可以使用样本平均法。这里是一个反例:

样本平均法反例
在这种情况下,q(xₒ) 是数据分布,其概率密度公式未知。但我们仍然可以通过随机从训练集中挑选图像来进行采样。样本平均法可以去掉右侧的 q(xₒ),但注意 q(xₒ) 也出现在被积分的 g 函数中。这个剩余的 q(xₒ) 不能通过样本平均法去除。即使插入样本 Xₒ,g(1+q(xₒ=Xₒ)) 仍然是一个不可评估的函数。因此,在这种情况下,样本平均法无法解决积分问题。
幸运的是,我们的损失函数 L 不属于这种情况,因此我们可以使用样本平均法来解析地近似 L。
通过样本平均法推导 L 的解析公式
为了推导损失函数 L 的解析公式,将 L 重写如下:

样本平均法对 L 的操作
行(3)显示,在log内,内部积分符合我们的样本平均模板,其中匹配的部分用不同的颜色突出显示:

为了应用样本平均法来近似这个内部积分,从逆过程的定义中采样潜在随机变量 x₁ 到 x_T:

明确地说,采样过程如下:
-
首先从标准多元高斯分布中对 x_T 进行采样,使用基本情况去除关于 x_T 的积分。
-
有了随机变量 xₜ 的样本 Sₜ,将 Sₜ 插入 p(xₜ₋₁|xₜ=Sₜ),然后采样 xₜ₋₁,使用归纳情况。
-
只要在这个过程中我们不丢失模型参数,我们将得到损失函数 L 的解析公式。在样本平均法中丢失模型参数意味着样本平均法可能会导致一个不再提及模型参数的公式。这是不好,因为一个不提及模型参数的损失函数是无用的。重参数化技巧被用来防止这种情况发生。但在我们的情况下,当应用样本平均法时,我们不需要担心丢失模型参数。附录 “为什么在应用样本平均法推导损失函数 L 的解析公式时不会丢失模型参数?” 解释了原因。
-
一旦所有 x₁ 到 x_T 的样本可用,将其称为一个样本轨迹。将此轨迹代入联合概率密度 p(x_0:T) 中以获得此轨迹下的 p(x₀) 的解析表达式。
-
重复步骤 1~4 来获得不同轨迹下的 p(x₀) 的解析表达式,并对其进行平均以近似内部积分。假设有 m 个样本轨迹,每个轨迹 i 给出一个解析公式 pᵢ(x₀),则平均的解析公式是,即近似的内部积分:

现在,损失 L 的近似解析公式是:

通过样本平均近似的损失 L
行(1)是带有内部期望值近似的损失 L。这让我们得到一个与样本平均模板相匹配的期望。
行(2)应用样本平均来近似外部期望,通过与从 q(x₀) 中绘制的样本图像 Sⱼ 平均 logs。这产生了一个解析表达式,我们可以从中计算梯度。
通过样本平均得到的 L 的解析公式计算代价高昂
上述样本平均需要大量计算,因为每个轨迹需要 T 个样本,每个潜在随机变量一个。常识告诉我们,仅为随机变量绘制一个样本是不够的——例如,你不应该仅通过测量一个学生的身高来计算学校的平均学生身高。这有什么不好?因为小样本量给出的估计(在此例中是期望身高)具有高方差。每次绘制一个样本时,你会得到不同的期望——这就是方差。
我们希望为每个随机变量绘制更多样本,因为更多样本意味着平均值更接近原始积分。换句话说,方差更小。但绘制更多样本需要大量计算:为每个潜在随机变量绘制两个样本并设定 T=1000,结果是 m=2¹⁰⁰⁰ 个轨迹需要平均。这是非常昂贵的。
实际上,我们只能为每个潜在随机变量绘制一个样本。但这会带来高方差问题。
使用少量样本进行样本平均得到的 L 的解析公式具有高方差
对每个潜在随机变量仅绘制少量样本(例如,每个潜在随机变量一个样本)的问题在于,对于实际图像 X₀ 计算得到的概率数 p(x₀=X₀) 具有高方差。这是因为概率数 p(x₀=X₀) 依赖于潜在随机变量 x₁ 到 x_T 的具体采样值。每次计算相同图像 X₀ 的 p(x₀=X₀) 时,这个概率数都会不同。由于每个样本轨迹中有 T 个具体采样值,方差可能会非常高。
更糟的是,在参数学习过程的开始,预测神经网络的均值向量和方差矩阵中的权重是随机初始化的。因此,通过这些神经网络采样的图像可能质量非常差——在某种程度上,它们根本不像之前图像的去噪版本。即使这并不会导致方差增加,但低质量的样本使参数学习变得更加困难。
为什么对 p(x₀=X₀) 的高方差不好?因为 p(x₀=X₀) 是我们衡量模型解释训练数据的效果的指标,在这种情况下,就是衡量模型对训练图像 X₀ 的解释效果。如果对于同一图像,我们的测量有时报告一个大的 p(x₀=X₀) 概率值,有时又报告一个小的概率值,那么优化器,例如 Adam 优化器,就不确定我们当前的模型是否能很好地解释训练数据。这种不确定性通常表现为非常慢甚至发散的训练过程。
由于样本平均法不是推导损失函数 L 的解析公式的好方法。是否有更好的方法?再次,变分推断提供了帮助。
为了缩短当前文章,我决定不介绍变分推断,而是将其作为已知知识使用。有关变分推断及其两个应用的介绍,请参见:解密 Tensorflow 时间序列:局部线性趋势 和 变分高斯过程 (VGP) — 当事物不是高斯时该怎么做。
变分推断用于推导损失函数 L 的解析公式
关键思想是使用另一个分布来计算否则无法解析的积分。我将通过在损失函数中使用重要性采样来引入这个新分布。请注意“计算”一词,而不是“近似”。因此,当我通过重要性采样引入新分布时,它是一个等号“=”,而不是一个近似号“≈”。
通过重要性采样推导
重要性采样引入了一个易于采样的分布,以帮助解决否则无法解析的积分。在我们的例子中,这个不可解析的积分是关于联合逆过程分布 p(x₀, …, x_T) 中的随机变量 x₁ 到 x_T。
请注意,在我们的例子中,p(x₀, …, x_T) 是可以采样的。正如我们之前所描述的,样本平均法适用于在解析上近似损失函数。
引入新分布的真正动机,在我们的案例中是联合前向过程q(x_1:T|x₀),是它有助于推导L的分析公式。它编码了反向过程中的“渐进性”要求。你可能不知道我在说什么。这两点将在我们完成使用重要性采样推导损失函数的分析公式后变得清晰。
在损失L中,积分是相对于潜在随机变量x₁到x_T,如下方行(2)所示,因此我们引入的新分布必须覆盖相同的随机变量集。我们为前向过程定义的q(x_1:T|x₀)分布符合这一要求。行(3)将其引入到L的公式中。

变分损失 Lᵥ的推导。
我想指出,上述推导对任何具有随机变量x₀到x_T的概率模型都是有效的,因为它仅使用了概率理论中的属性。这些属性对任何有效的概率分布都成立。只有从接下来的对Lᵥ的推导开始,当我们开始使用前向过程和反向过程的定义来分解联合概率时,我们才开始依赖具体的模型结构。
行(3)引入了q(x_1:T|x₀)/q(x_1:T|x₀)量。这个量的值为 1,因此其添加不会改变积分。注意这一行前面的等号。引入q(x_1:T|x₀)分布不会以任何方式改变L的值。
行(4)重新组织项,将旧的积分转变为相对于q(x_1:T|x₀)的新积分。
行(5)表示使用等效期望符号的积分。
行(6)使用詹森不等式将对数函数推入内期望,因为对数的期望比期望的对数更易于计算。詹森不等式还将我们最终将最小化的结果转化为一个比原始损失L大的新函数。重要的是,这个新函数的最小值与原始L的最小值在同一位置。因此,我们可以最小化新损失而不是旧损失。
行(7)将期望符号替换为其定义,即积分。行(8)重新排列项。
行(9)应用概率理论中的反向链式法则推导联合概率q(x_0:T)。
行(10)表示使用等效期望符号的积分。注意我们从引入q分布在随机变量x₁到x_T的过程中开始,最终得到相对于随机变量x₀到x_T的期望。我们将这个新量命名为Lᵥ,代表变分损失。
新的损失 Lᵥ 用于推导分析公式并进行最小化。
从现在开始,Lᵥ是需要最小化的量。我们的目标是推导新损失Lᵥ的解析公式。查看第(10)行,很难相信它是解析的。但在数学中,惊人的事情确实会发生。请继续阅读。
重新编写Lᵥ以获得重要的Lₜ₋₁项
这是一个重要的推导,请注意。

操作变分损失Lᵥ以揭示Lₜ₋₁项
第(1)行展示了新损失Lᵥ的推导。Lᵥ提到我们之前定义的逆过程p(x_0:T)和正向过程q(x_1:T|xₒ)的联合概率密度。
第(2)行将这两个联合概率密度进行因式分解。它使用逆过程的定义对p(x_0:T)进行因式分解,并使用q的第一次因式分解对q(x_1:T|x₀)进行因式分解。
从这一行开始,我们依赖于我们定义的模型结构,即在逆过程p中随机变量xₜ₋₁依赖于xₜ,在正向过程中xₜ依赖于xₜ₋₁*。这对于任意的概率模型不一定成立。
第(3)行再次进行因式分解。注意,由于这一行的因式分解,乘积从t=2开始,而不是t=1。
第(4)行将期望值外的减号移到期望值内,并使用log(a×b) = log(a) + log(b)的性质。
第(5)行引入了名称F_T来表示第一个项,Fₒ表示期望值内的第三项,以缩短推导,使其适合一行。
第(6)行是关键行,它使用贝叶斯规则替换q(xₜ₋₁|xₜ):

注意对x₀的依赖性,旨在将q(xₜ|xₜ₋₁)转变为q(xₜ|xₜ₋₁, x₀)。这个附加项是多余的,它不会改变条件概率,因为根据定义,随机变量xₜ仅依赖于xₜ₋₁。请参见下面重新显示的xₜ的定义。它仅提到xₜ₋₁而不是x₀。

这个附加项使我们更容易应用贝叶斯规则,因为贝叶斯规则提到了q(xₜ|x₀)和q(xₜ₋₁|x₀),这些都明确依赖于x₀。
注意在q(xₜ₋₁|xₜ, x₀)中对x₀的依赖不是多余的。x₀出现在这里是由于贝叶斯规则。
使用贝叶斯规则的原因是为了使项q(xₜ₋₁|xₜ, x₀)弹出。q(xₜ₋₁|xₜ, x₀)是正向过程逆向过程中的一个项。我们现在有了p(xₜ₋₁|xₜ)和q(xₜ₋₁|xₜ, x₀)之间的概率比,在第(6)行看到。p(xₜ₋₁|xₜ)和q(xₜ₋₁|xₜ, x₀)都是:
-
相同随机变量xₜ₋₁的概率密度函数和
-
它们是多元高斯分布,其解析概率密度已知——我们之前定义了逆过程中的p(xₜ₋₁|xₜ)和正向过程中的q(xₜ₋₁|xₜ, x₀)的解析形式。
这两个性质使得可以从分析上推导 p(xₜ₋₁|xₜ) 和 q(xₜ₋₁|xₜ, x₀) 之间的 KL 散度,具体细节稍后说明。
(7)行使用对数的性质来拆分项。
(8)行使用对数的性质将对数的和转换为对数的乘积。
(9)行意识到在乘积对数中,分子和分母共享许多项,这些项可以相互抵消,只剩下一个分子项和一个分母项。
(10)行引入了名称 F₀ 来表示期望值内部的最后一项。并且引入了名称 Lₜ₋₁ 来表示求和中的每个负对数项,以使推导更简短。即:

显然,t=[2, T] 范围内的 Lₜ₋₁ 项是重要的。注意,Lₜ₋₁, t 从 2 开始而不是 1,因为在(3)行中有拆分。这些 T-1 项构成了整个损失函数的大部分,只剩下另外三个项。我们稍后再担心这三个项,首先关注 Lₜ₋₁ 项,因为它将成为我们最小化的最终损失函数的核心。
推导 Lₜ₋₁ 的解析公式
让我们继续处理 t=[2, T] 的 Lₜ₋₁:

处理 Lₜ₋₁ 项以揭示它们的 KL 散度性质
(1)行是 Lₜ₋₁ 项的定义。(2)行将负号推入了 log 中。期望值是相对于从 q 分布中得到的随机变量 x₀ 到 x_T。
(3)行用其数学定义替代期望符号,这是一种对随机变量 x₀ 到 x_T 的积分。
(4)行通过使用前向过程的第二次分解来分解联合概率密度 q。
注意第二次分解是许多分布的乘积,每个分布提及单个潜在随机变量。这是正确的,因为在给定观察随机变量 x₀ 的情况下,所有潜在随机变量 x₁ 到 x_T 彼此独立。
(5)行将 q 分布中的所有因子组织成四部分:
-
q(x₀),这是关于 x₀ 的分布,其公式未知。
-
q(xₜ₋₁|x₀),这是关于 xₜ₋₁ 的分布。
-
q(xₜ|x₀),这是关于 xₜ 的分布。
-
q(xₒₜₕₑᵣ),这是关于除 xₜ₋₁ 和 xₜ 外的其他潜在随机变量的分布。
(5)行的分解原因是对数函数仅提及了 x₀, xₜ₋₁ 和 xₜ。
(6)行应用链式法则推导联合概率 q(xₜ₋₁, xₜ|x₀)。
(7)行是关键行。使用反向链式法则(适用于任何联合概率密度),它将 q(xₜ₋₁, xₜ|x₀) 替换为两个相乘的因子 q(xₜ₋₁|xₜ, x₀)q(xₜ|x₀),因为

第 (8) 行将积分变量拆分为 4 部分,对应于 x₀, xₜ₋₁, xₜ 和 xₒₜₕₑᵣ。并且重新排列项,使得内部积分在单个随机变量 xₜ₋₁ 上进行,外部积分在其余随机变量 x₀, xₜ 和 xₒₜₕₑᵣ 上进行。这一步是有效的,因为原始积分 dx_0:T 只是 dx₀ dx₁, … dx_T 的简写。这一行还使用了我们之前推导出的前向过程的属性 2 以及概率论中的条件规则:

第 (9) 行认识到内部积分是 q(xₜ₋₁|xₜ, x₀) 和 p(xₜ₋₁|xₜ) 之间的 KL 散度。这个 KL 散度是在两个多变量高斯分布之间,其解析概率密度函数是已知的。因此,我们可以解析地写出这个 KL 散度的公式。它是一个涉及随机变量 xₜ 和 x₀(注意,它不涉及 xₜ₋₁)以及所有模型参数的函数。
第 (10) 行将 q(xₜ, xₒₜₕₑᵣ, x₀) 分解为条件分布。
现在我们得到了 q(xₜ₋₁|xₜ, x₀) 和 p(xₜ₋₁|xₜ) 之间 KL 散度的解析表达式,但这个 KL 散度在一个积分中。我们如何用解析方法解决这个积分呢?
没错,我们可以使用样本平均法来近似 x₀, xₜ 和 xₒₜₕₑᵣ 的期望值:
-
从训练集中随机挑选自然图像来抽样 x₀。
-
从边际分布 q(xₜ|x₀) 中抽取样本 xₜ,在插入 x₀ 的样本之后。
-
不需要抽取 xₒₜₕₑᵣ,因为第 (10) 行显示 xₒₜₕₑᵣ 在 KL 散度中没有提及。xₒₜₕₑᵣ 中的随机变量值不会改变 KL 散度的计算结果。
唷,经过这么多步骤,我们终于得到了新损失函数 Lᵥ 中 [2, T] 范围内 Lₜ₋₁ 项的解析表达式。
使用样本平均法解决积分问题
让我在这里粘贴 Lₜ₋₁ 的解析公式,并添加使用样本平均法近似解析积分的步骤。

第 (1) 行是我们刚才推导出的 Lₜ₋₁ 的解析公式。它对随机变量 x₀, xₜ 和 xₒₜₕₑᵣ 进行了多重积分。这三种变量都很容易处理,因为:
-
首先,从我们的训练集中抽样 x₀。我们将 x₀ 的样本称为 S₀。
-
将 S₀ 代入 q(xₜ|x₀) 中,得到 q(xₜ|x₀=S₀),这是一个完全指定的多变量高斯分布,准备进行抽样。我们将一个 xₜ 的样本称为 Sₜ。
-
忽略对 xₒₜₕₑᵣ 的积分,因为 xₒₜₕₑᵣ 不出现在 KL 散度中,它们的样本不会改变积分结果的解析形式。
第 (2) 行使用上述抽样方案抽取 n 对 (S₀, Sₜ);将每对样本代入 KL 散度公式中得到一个解析项,然后对这些解析项取平均。
你可能会问,我们应该采样多少对 n?越多越好,但根据经验,单对样本已经给我们带来了良好的结果,所以 n=1。
因此,行 (3) 使用了 n=1 的事实,从行 (2) 中去掉求和,以得出这个简单公式:

KL 散度的解释
KL(q(xₜ₋₁|xₜ, x₀) || p(xₜ₋₁|xₜ)) 为神经网络建立了回归目标。
在经过如此多的努力推导出这个 KL 散度的解析公式之后,仔细查看它是明智的。
对于时间步 t 在 [2, T] 中,这个 KL 散度量化了两个分布之间的距离:
-
q(xₜ₋₁|xₜ, x₀) — 通过使用贝叶斯规则从前向过程推导出的反向过程。
-
p(xₜ₋₁|xₜ) — 我们使用深度神经网络实现的反向过程。
我们在最小化这两个分布之间的 KL 散度。也就是说,我们要求这两个分布在每个时间步 t=[2,T] 上都相似。两个分布相似确保了从中采样的图像在对应时间戳上也相似。
请注意:
-
反向过程 q(xₜ₋₁|xₜ, x₀) 是从前向过程推导出来的,前向过程没有可训练的参数,通过贝叶斯规则(上面属性 3)推导。贝叶斯规则没有引入任何可训练的参数,因此得到的反向过程也没有可训练的参数。
-
反向过程 p(xₜ₋₁|xₜ) 是通过神经网络定义的,它包含了我们模型中的所有可训练参数,即网络中的权重。
因此,反向过程 q(xₜ₋₁|xₜ, x₀) 是静态的。通过最小化 KL 散度,随机梯度下降优化算法调整神经网络中的参数值,使得反向过程 p(xₜ₋₁|xₜ) 尽可能接近 q(xₜ₋₁|xₜ, x₀),以便反向过程 p(xₜ₋₁|xₜ) 生成的图像与静态前向过程 q(xₜ₋₁|xₜ, x₀) 的图像相似。
换句话说,前向过程的静态反向过程为反向过程提供了真实图像或回归目标,以便在时间步 t=[2, T] 时,反向过程将生成的图像回归到这些真实图像。
在解释反向过程的角色作为真实图像或回归目标提供者时,请注意一个细微之处:
-
在传统回归模型中,如线性回归,我们最小化模型预测与真实值之间的距离。
-
但在这里,逆过程 p(xₜ₋₁|xₜ),作为一个概率模型,我们并不会直接最小化由我们的模型(逆过程)生成的图像与来自真实生成器(正向过程的逆)的图像之间的距离。KL 散度损失函数中没有提到模型的预测部分。相反,我们最小化两个机制,即逆过程和正向过程的逆过程,以便在每个时间步生成的图像相似。
正是这种每步相似性要求在 p(xₜ₋₁|xₜ) 和 q(xₜ₋₁|xₜ, x₀) 之间建立了由逆过程 p(xₜ₋₁|xₜ) 生成的图像的“渐变”变化。这是因为正向过程的静态逆过程逐步去除图像中的噪声,因此根据定义,正向过程的逆过程生成的图像具有逐渐去噪的效果——它们变得越来越清晰。通过回归到这些越来越清晰的图像,学习到的逆过程以时间戳 t 作为输入,被迫生成逐渐变化的图像。
这种每步相似性要求限制了基于神经网络的逆过程按照已知且更简单的过程——正向过程的逆过程行为。每步的 KL 散度防止了学习到的神经网络做出奇怪的行为,例如在早期步骤中首先生成一只猫的图像,然后将猫变成人的脸。
请注意这里的时间戳范围 t=[2, T]。这个范围意味着 Lₜ₋₁ 项只覆盖时间戳从 2 到 T,将第一步 t=1 排除在外。时间戳 t=1 作为最终生成自然图像的步骤,当然是重要的。记得我们留下了三个 Lᵥ 项没有分析吗?稍后我们将看到这些留下的项涵盖了第一个时间戳。
轨迹视角
让我们用下面的插图揭示 KL(q(xₜ₋₁|xₜ, x₀) || p(xₜ₋₁|xₜ)) 从轨迹角度上试图做什么。

手绘插图由我制作
左侧子图展示了两个自然图像 X₀ 和 X₁。从每个自然图像出发,如果我们多次应用正向过程,我们会得到多个轨迹。起始于 X₀ 或 X₁ 的黑色曲线代表了这些轨迹。时间戳从左到右,因此每条轨迹末尾的图像已经是纯高斯噪声。
在这种完全无条件的设置下,在时间戳t-1,我们模型中的随机变量xₜ₋₁可以取自任何轨迹,无论轨迹是从X₀还是X₁开始。换句话说,在时间戳t-1,我们的模型需要能够解释所有可能由前向过程生成的图像,这些图像可以从任何自然图像开始。我们的模型可以通过给随机变量xₜ₋₁一个在所有轨迹中间的均值和一个较大的方差来做到这一点。
中间的子图展示了当x₀给定时的情况,这将随机变量x₀设置为自然图像X₀。这种设置限制了模型仅需解释从自然图像X₀开始的轨迹。它们是中间子图中的红色轨迹。换句话说,我们的模型现在只需解释在时间戳t-1上的红色曲线中的可能值。模型可以通过提供更精确的均值和更小的方差来做到这一点,因为它不再需要覆盖从自然图像X₁开始的黑色轨迹。
右侧的子图展示了当x₀仍然被条件化为X₀时的情况,并且xₜ还受到特定图像Sₜ的条件限制,而该图像是从分布q(xₜ|x₀=X₀)中抽样得到的。这种第二次条件限制进一步约束了模型,只需解释在时间戳t通过Sₜ的轨迹。这些是蓝色轨迹,它们都从X₁开始,并经过Sₜ。
在这种条件下,随机变量xₜ₋₁在时间戳t-1上可以取的可能值被进一步限制。这意味着我们的模型需要预测一个接近蓝色轨迹中间的均值,并预测一个更小的xₜ₋₁的协方差。
但是,“接近蓝色轨迹中间”的预测均值应该是多少,以及“更小的”预测协方差应该是多少?这两个目标量由前向过程q(xₜ₋₁|xₜ, x₀)的反向定义决定,其定义在这里再次展示:

with

通过对模型进行xₜ和x₀的条件化,我们使模型在每一步训练中都能更容易学习,因为每一步,模型只需要解释一个相对较少的轨迹中的单一时间步。
优化通过固定 q 来迫使 p 发生变化
由于前向过程q(xₜ₋₁|xₜ, x₀)的反向是固定的,即q(xₜ₋₁|xₜ, x₀)中没有可训练的参数,优化的唯一方法是改变模型参数的值,使p更接近q。
值得注意的是,许多其他论文介绍了一个可学习的q并将q移近p。但在这篇论文中没有。在这篇论文中,在重要性采样中引入的q分布是固定的,最小化q和p之间的 KL 散度会移动p。
从Lₜ₋₁到均值向量距离公式 LMₜ₋₁
由于 KL 散度 Lₜ₋₁=KL(q(xₜ₋₁|xₜ, x₀) || p(xₜ₋₁|xₜ)) 是解析的,我们来写下来。回顾 KL 散度中提到的两个分布的概率密度函数都是多变量高斯分布:

两个多变量高斯分布之间 KL 散度的解析公式是:

Lₜ₋₁项 KL 散度的解析表达式。
上述公式有 4 项。
第一项在第(1)行计算两个协方差矩阵行列式之间的对数比率,表示为“det”。此项提到模型参数。
第二项在第(2)行涉及* d,即随机变量xₜ₋₁*的维度,这也是我们处理的图像的像素数量。此项未提及任何模型参数。
第三项在第(3)行计算两个矩阵乘积的迹,表示为“tr”。此项提到模型参数。
第四项在第(4)行是向量μₚ(xₜ, t)-μₜ(xₜ, x₀)的平方,由协方差矩阵Σₚ(xₜ, t)⁻¹缩放。
我知道,这个公式看起来很吓人。写下来也很麻烦。但我的建议是尝试接受这种困难,因为两个高斯分布之间的 KL 散度公式很可能会出现在变分机器学习中,例如在变分高斯过程中。
请记住,我们需要最小化此项相对于模型参数的值,该项出现在:
-
μₚ(xₜ, t),负责预测* p(xₜ₋₁|xₜ)* 多变量高斯分布的均值向量的神经网络。
-
Σₚ(xₜ, t),第二个神经网络负责预测p(xₜ₋₁|xₜ) 多变量高斯分布的协方差矩阵。
通过将反向过程协方差矩阵设置为常数来简化模型。
让我们通过移除第二个预测协方差矩阵的神经网络来简化模型。数学上,我们设置Σₚ(xₜ, t)=σₜ²I,其中σ*ₜ²的一个明显选择是:

上述公式使反向过程p(xₜ₋₁|xₜ)的协方差矩阵与前向过程的反向协方差矩阵相同。
通过这种简化,前三项变成常数,我们称它们的和为C。C不再涉及模型参数。它们在优化过程中可以忽略。这使我们只剩下第四项,称为LMₜ₋₁。所以我们有:

LMₜ₋₁定义为:

通过矩匹配推导LMₜ₋₁项,展示神经网络预测的目标
第(1)行是第四项。第(2)行代入了简化的协方差矩阵。第(3)行中的||…||²是向量平方操作,即向量与自身的点积。第(4)行交换了平方中的两个分量,这对结果没有影响,仅仅是为了与论文中的项的顺序保持一致。
请注意,我省略了关于x₀和xₜ的期望,以使公式更简洁。但计算方法与之前相同,我们需要对x₀和xₜ进行采样,将样本代入LSₜ₋₁中,以便在分析上近似积分。
解释LMₜ₋₁的含义
LMₜ₋₁量化了两个向量μₜ(xₜ, x₀)和μₚ(xₜ, t)之间的距离。这现在非常有意义:
-
我们最初希望最小化逆向过程的q(xₜ₋₁|xₜ, x₀)与p(xₜ₋₁|xₜ)之间的距离,这就是我们神经网络实现的逆向过程,针对每个时间步t从 2 到T。换句话说,我们希望为p(xₜ₋₁|xₜ)分布找到一个配置(模型参数值),使得这两个分布相似。
-
对于随机变量xₜ₋₁,这两个分布都是多元高斯分布。多元分布完全由其均值向量和协方差矩阵指定。如果p(xₜ₋₁|xₜ)需要与q(xₜ₋₁|xₜ, x₀)相似,它们的均值向量和协方差矩阵必须彼此相似。这被称为矩匹配,其中均值是第一个矩,而协方差是第二个。LMₜ₋₁中的字母“M”代表矩匹配。
-
在我们将p(xₜ₋₁|xₜ)分布的协方差矩阵简化为与正向过程的逆过程中的协方差矩阵相等的量后,唯一可以改变的以使这两个分布相似或不同的是均值向量。因此,我们希望最小化p(xₜ₋₁|xₜ)和q(xₜ₋₁|xₜ, x₀)分布的均值向量之间的距离。
-
由于p(xₜ₋₁|xₜ)分布的均值向量是由我们的神经网络预测的,我们可以通过优化来调整神经网络权重的值,从而最小化LMₜ₋₁。
简化 LMₜ₋₁
对 LMₜ₋₁ 可以进行大量简化。在 LMₜ₋₁ 的公式中,μₚ(xₜ, t) 部分来自神经网络,像一个黑箱,我们几乎无法简化。因此,让我们尝试简化另一个项 μₜ(xₜ, x₀),这是反向过程 q(xₜ₋₁|xₜ, x₀) 的均值向量,其解析概率密度函数已经推导出来:

与协方差矩阵:

以及均值向量:

我们只需要关注均值向量 μₜ(xₜ, x₀),因为之前对 LMₜ₋₁ 的推导揭示了我们只需要用我们的神经网络来预测接近或匹配 μₜ(xₜ, x₀) 的均值向量。
我们也有 q(xₜ|x₀) 的解析概率密度函数:

使用重参数化技巧,我们可以将上述内容重写为:

重新整理上述方程中的项以获得 x₀ 的表达式:

现在将 x₀ 的表达式代入 μₜ(xₜ, x₀) 的公式中:

均值向量目标的推导
第(1)行是一个糟糕的公式,第(2)行引入了名称 A 来表示 xₜ 前的系数,名称 B 表示 ϵₜ。我们将分别简化 A 和 B。
简化 A

简化 B

哇,多么惊人的简化!它给我们:

神经网络预测的简化均值向量目标
神经网络重新利用预测噪声以适应 t≥2
不要惊慌,我们的目标没有改变——我们仍然希望我们的神经网络预测 p(xₜ₋₁|xₜ) 分布的均值向量,并且预测的均值向量应该尽可能接近 μₜ(xₜ, x₀)。但看到 μₜ(xₜ, x₀) 的简化公式后,我们意识到:
-
xₜ 通过采样已知,不需要预测它。
-
给定时间戳 t,βₜ 是常数,因此所有从 βₜ 推导出的其他量,即 αₜ 和 αₜ bar* 也是常数。
-
唯一需要预测的部分是噪声 ϵₜ。
我们可以去掉原始神经网络,设计一个新的 ϵₚ(xₜ, t) 来预测噪声 ϵₜ。然后我们可以通过以下方式构造期望的均值向量 μₚ(xₜ, t):

从噪声预测中重建均值预测
将这个公式代入 LMₜ₋₁ 的定义给我们:

LMₜ₋₁ 在神经网络重新利用预测噪声后的操作
第(7)行是简化后的目标函数。
注意,这个目标函数提到了噪声 ϵₜ 两次。它们是相同的随机变量,而不是两个不同的噪声。这是因为它们都来自相同的源:

第一次我们使用上述方法将 x₀ 表达为 xₜ 和 ϵₜ 的函数。第二次我们使用该方法将 xₜ 表达为 x₀ 和 ϵₜ 的函数。
这个目标函数仍然是解析的吗?
还记得我们之前为了简化推导而省略了相对于 xₜ 和 x₀ 的期望吗?为了回答 LMₜ₋₁ 是否仍然是解析的,我们必须将它们添加回去,因为只有有了这些期望,我们才能计算正确的 LMₜ₋₁。
注意:
-
在 LMₜ₋₁ 的最终公式中,不再提及 xₜ,xₜ 通过 x₀ 和噪声 ϵₜ 表达。因此,我们不需要添加相对于 xₜ 的期望。而是需要添加相对于 ϵₜ 的期望,它是一个标准的多变量高斯分布,即 ϵₜ~N(0, 1)。
-
提到了时间戳 t,它表示 2 和 T 之间的整数。我们需要添加相对于 t 的期望,它来自均匀分布。
-
提到了 x₀,它来自未知的数据分布 q(x₀)。
所以,LMₜ₋₁ 的完整公式是:

完整的损失 LMₜ₋₁ 的期望形式,对于 t≥2
其中 Uni(2,T) 表示 2 和 T 之间的均匀分布。
这个公式在样本平均的情况下是解析的。当我们将 x₀、ϵₜ 和 t 的样本代入上述公式时,我们会得到一个解析表达式,从中可以计算梯度以进行随机梯度下降。
作者发现通过忽略向量距离项前的常数,结果会更好:

简化的 LMₜ₋₁ 损失的期望形式,对于 t≥2
以下 算法 0 最小化上述损失:

来源于论文 去噪扩散概率模型,第 4 页
算法 0 通过样本平均来评估相对于 x₀、xₜ 和 t 的期望。注意在第 (3) 行,时间戳 t 是从均匀分布 Uni(2, T) 中抽样的。
论文和这篇文章之间的一个符号差异是,在论文中,作者使用 ϵ_θ 来表示神经网络,而我使用 ϵₚ。作者使用 ϵ_θ 来强调神经网络具有参数集 θ。这在上述算法的第 (5) 行也明确显示,当计算相对于 θ 的损失函数的梯度时(注意 ▽ 符号表示对向量的导数)。我使用 ϵₚ,因为 Unicode 中没有下标 θ,我不想写太多 ϵ_θ,因为它们看起来不好。
另一个符号差异是论文中使用 ϵ 表示标准高斯噪声,而我使用了 ϵₜ。我使用 ϵₜ 是因为我以这种方式推导了我的公式。但我认为 ϵ 更好,因为标准高斯噪声不依赖于时间戳 t。
Lᵥ 中的剩余项
对 Lᵥ 的推导表明,它是一个相对于 q(x_0:T) 的期望,并且在期望内部有多个项,如下所示:

变分损失 Lᵥ 再次
之前我们只关注 t=[2, T] 的 Lₜ₋₁ 项。现在让我们讨论其余的项,我使用期望的线性特性将其提取到第 (2) 行的第一个期望中:E[a + b] = E[a] + E[b]。

在 Lᵥ 中操作项,排除 Lₜ₋₁ 项
第 (2) 行用实际公式替代了名称 F_T 和 F₀。
第 (3) 行和 (4) 行利用 log 的属性重写了这些项。
第 (5) 行简化了第二个 log。
第 (6) 行利用期望的线性特性将期望拆分为两个部分。
第 (7) 行将第一个期望命名为 L_T,与论文中的命名一致。
第 (8) 行将第二个期望的负值命名为 L₀,与论文中的命名一致。
在优化中可以忽略 L_T 项,而 L₀ 需要特别处理。我们将会看到原因。
忽略 L_T 项
这里是 L_T 项的公式:

它提到 q(X_T|x₀),即随机变量 X_T 的边际概率密度。前向过程不包括任何模型参数。
它还提到 p(X_T),即时间戳 T 的逆向过程。我们定义了 p(X_T) = N(0,1*)*。所以 p(X_T) 也不涉及模型参数。
这意味着整个 L_T 项不提及模型参数,因此在参数学习过程中可以忽略。
近似 L₀ 项
L₀ 项是:

这个项是针对时间戳 t=1 的。让我们理解一下这个项的含义。我们希望最小化这个项,这意味着找到使对数似然 log(p(x₀|x₁)) 最大化的模型参数。换句话说,我们希望 p(x₀|x₁) 在自然图像插入 x₀ 时评估为高概率值。
另一种理解方法是使用 Lₜ₋₁ 的公式:

第 (1) 行是我们之前推导出的 Lₜ₋₁ 的定义。注意,当我们推导它时,t 从 2 开始,因为当 t≥2 时,所有 Lₜ₋₁ 项都是两个合适的高斯分布之间的 KL 散度。对于 t=1,这不成立,如第 (4) 行所示。
第 (2) 行设置 t=1 以推导 L₀。第 (3) 行将 KL 符号展开为其数学定义。
第(4)行使用了 q(x₀|xₜ, x₀) = 1 的性质。这一行还揭示了当 t=1 时,不再有 KL 散度。公式退化为一个 log 的积分。这就是为什么我们不能在 Lₜ₋₁ 中处理 t=1 的原因。
第(5)行利用对数的性质简化了公式。
第(6)行用期望符号替代了积分。
第(7)行将两个关于 x₀ 的期望简化为一个关于 x₀ 的期望,因为一个期望已经去掉了随机变量 x₀。第二个关于 x₀ 的期望不再改变结果。这一行还揭示了得到的量确实是 L₀ 项。
L₀ 需要以不同的方式最小化,它不能适应算法 0
现在我们应该理解,并不是说我们不能从 Lₜ₋₁ 的角度推导 L₀。我们可以,但 L₀ 的推导不是两个适当的多元高斯分布之间的 KL 散度,这意味着 L₀ 的解析公式不同于 t≥2 的 Lₜ₋₁ 的解析公式。这意味着我们需要一种不同的方法来最小化 L₀。换句话说,L₀ 的最小化不适合算法 0*。好吧,它还不适合,稍后我们将引入一个近似方法使其适合。
L₀ 是可优化的
由于我们希望最小化 L₀,所以重要的是:
-
L₀ 没有提及任何模型参数,因此在优化过程中可以忽略。或者
-
L₀ 提及了模型参数,并且是解析的,因此可以用于梯度下降的梯度计算。
由于之前的损失函数 LMₜ₋₁ 仅处理 t≥2 的情况,我们希望 L₀ 落入上述第二类,以便我们损失函数的某部分涵盖 t=1 的情况。确实如此:

显示 L₀ 是解析的 推导
第(1)行是 L₀ 的定义。
第(2)行代入了 p(x₀|x₁) 的定义,它是一个多元高斯分布,神经网络 µₚ(x₁, 1) 预测其均值向量,协方差矩阵设置为常数 𝛼₁² I。我忽略了指数前的归一化项,并使用了比例符号 “∝”。
第(3)行和第(4)行简化了公式。
第(4)行揭示了 L₀ 提及了 µₚ(x₁, 1) 中的所有模型参数,并且在我们采样 x₀ 和 xₜ 后是解析的。因此 L₀ 是可优化的。
在算法 0 中最小化 L₀ 的近似值
上面的第(4)行还表明,为了最小化 L₀,神经网络 µₚ(x₁, 1) 需要预测一个接近自然图像的均值向量,比如说 X₀,这是为 x₀ 采样得到的。
之前当我们推导 Lₜ₋₁ 的解析公式时,对于 t≥2,我们意识到我们希望神经网络 µₚ(xₜ, t) 预测接近于反向过程 µₜ(xₜ, x₀*) 的均值向量。
如果我们可以:
-
写下 µₜ(xₜ, x₀) 对于 t=1,即 µ₁(x₁, x₀),并且,
-
如果 µ₁(x₁, x₀) 接近自然图像样本 X₀
然后我们可以将“最小化 µₚ(x₁, 1) 和 X₀ 之间的距离”这一原始任务转换为“最小化 µₚ(x₁, 1) 和 µ₁(x₁, x₀) 之间的距离”的近似任务。后者的好处是我们可以使用算法 0 处理 t=1 的情况,方式与 t≥2 的情况相同。
我们可以写出 µ₁(x₁, x₀)

注意,我们不能将 t=1 代入上述第一行。这是因为当 t=1 时,如 𝛼ₜ₋₁ bar 等量没有定义。但我们可以将 t=1 代入第二行。这是因为第二行将第一行中的 x₀ 替换为仅提到 x₁ 的表达式。而且所有涉及 𝛼₁ 和 β₁ 的量都已定义。
设 t=1 推导:

在插入 x₁ 和 ϵ₁ 的样本后,上述为常数。
我们知道µ₁(x₁, x₀) 必须接近自然图像 X₀
这是因为 µ₁(x₁, x₀) 是反向过程中的结束随机变量 x₀ 的均值向量。因此,如果我们从反向过程抽取 x₀ 的样本,我们应该得到一个接近自然图像 X₀ 的图像。这是反向过程定义的结果。实际上,如果我们从反向过程抽取许多 x₀ 的图像并对所有这些采样图像取平均,则平均值应该恰好等于 X₀。换句话说,反向过程可以“期望”生成精确的起始图像。但如果我们仅从反向过程抽取一个 x₀ 的样本,该样本不等于 X₀。这就是为什么我们对 L₀ 项进行近似的原因。
现在我们可以使用算法 0 处理从 t=1 开始的所有时间戳。从数学上讲,我们扩展了 LMₜ₋₁,它仅涵盖 t≥2 的情况,请参见期望下的 t~Uni(2,T) 部分:

简化的损失函数,对于 t≥2
为了涵盖 t=1 的情况,请参见期望下的 t~Uni(1,T) 部分:

简化的损失函数,对于 t≥1
最终损失和论文中的算法 1
Lₛᵢₘₚₗₑ 是最终的损失函数,它涵盖了从 1 到 T 的所有时间戳。下文中复制的论文算法 1 最小化 Lₛᵢₘₚₗₑ:

来自论文 Denoising Diffusion Probabilistic Models,第 4 页
我们高兴地注意到,在第(3)行中,时间戳 t 是从均匀分布 Uni(1, T) 中采样的,这覆盖了所有 t≥1 的情况,这是因为对 L₀ 项的近似。
对于样本平均的 Lₛᵢₘₚₗₑ 高方差没有问题吗?
我之前提到,我们可以使用样本平均来计算负对数似然 L 相对于所有潜在随机变量 x₁ 到 x_T 的解析公式。但如果我们在实际计算中只能为每个随机变量绘制一个样本,那么计算出的期望会有很高的方差。
为什么我们可以毫无问题地使用样本平均来计算解析公式 Lₛᵢₘₚₗₑ 并且每个随机变量只绘制一个样本?
主要原因是,在最终的损失函数 Lₛᵢₘₚₗₑ 中,只有 3 个随机变量需要采样,而在负对数似然的期望情况下则需要采样 T+1=1000+1 个随机变量。因此,最终损失函数的方差应该比负对数似然期望情况中的方差小得多。
更进一步,现在样本不再通过未经校准的神经网络绘制,它们都来自标准分布,其行为不依赖于我们训练神经网络的程度。这使得参数学习过程更加可预测。
但为了有趣的考虑,咱们来看看样本平均的替代方法。也就是,解析计算最终损失函数 Lₛᵢₘₚₗₑ 的期望:
-
对于随机变量 x₀,由于数据分布 q(x₀) 是未知的,因此没有办法从解析上计算其期望。因此,样本平均是唯一的选择。
-
对于来自均匀分布的随机变量 t,它的期望就是取所有可能的 t 值,计算期望内的公式并求平均。在我们随机梯度下降的上下文中,这等同于样本平均。尽管在随机梯度下降中,算法 1 只处理一个项,而不是将所有这些项相加然后除以 T,该算法会重复进行直到收敛。这相当于在 t 上逐渐计算期望。更多细节,请参见 我们能在线性回归模型上使用随机梯度下降 (SGD) 吗?
-
对于标准的多元高斯随机变量 ϵₜ,我们可以使用高斯求积法对期望进行解析近似。有关高斯求积法的更多详细信息,请参见 变分高斯过程(VGP)——当事情不是高斯时该怎么办。但高斯求积法在低维情况下效果更好。在我们的情况下,ϵₜ 是一个 d 维随机变量,其中 d 是我们想要生成的图像的像素数量,因此 d 是一个大整数。应用高斯求积法并不实际。有关为什么不实际的更多详细信息,请参见上述链接的附录。
鉴于以上情况,使用样本平均法来近似 Lₛᵢₘₚₗₑ 中的期望是一个明智的选择。
结论
这篇文章通过推理q(xₜ|xₜ₋₁)、q(xₜ₋₁|xₜ, x₀)和p(xₜ₋₁|xₜ)之间的关系,明确了去噪扩散概率模型设计的动机。它还提供了用于模型参数学习的损失函数的详细推导。
支持我
如果你喜欢我的故事,请考虑成为我的推荐会员。我将收到你订阅费用的一小部分,这对我帮助很大。
[## 使用我的推荐链接加入 Medium - Wei Yi
阅读 Wei Yi 的每一个故事(以及 Medium 上成千上万其他作家的作品)。我很享受花费数千小时来写作...
medium.com](https://medium.com/@jasonweiyi/membership?source=post_page-----445c1bdc5756--------------------------------)
附录
前向过程的反向过程从哪里开始和结束?
前向过程的反向过程从哪里开始? 换句话说,如果我们在前向过程的反向过程 q(xₜ₋₁|xₜ, x₀) 开始时取样,例如当 t 是一个大数,比如 t=T,其中 T=1000,那么图像是什么样的?它看起来像纯噪声,还是自然图像?
要查看样本的样子,我们需要前向过程反向的概率密度函数。我们再次定义了该概率密度函数的解析表达式:

前向过程条件的反向过程
通过

固定前向过程条件的均值向量和协方差矩阵
要推断时间戳t=T时样本xₜ₋₁的情况,只需计算在时间戳t=T时q(xₜ₋₁|xₜ, x₀)的均值向量和协方差矩阵。这是因为均值向量和协方差矩阵完全决定了多元高斯分布的形状,而分布的形状决定了从中抽取的样本的样貌。
让我们先计算协方差矩阵,因为这更简单。

时间戳t=T时前向过程反向的协方差矩阵
第(1)行是前向过程反向定义中协方差矩阵的项q(xₜ₋₁|xₜ, x₀)。
第(2)行设置时间戳t=T,代表前向过程反向的起始点。
第(3)行将近似值代入分数中。要理解原因,我们需要回顾以下定义,并注意到βₜ的调度限制了其值在β₁=10⁻⁴到β_T=10⁻²之间:

当t=T且T较大时,α_T的平均值非常小,接近 0(实际上并不太接近,α_T的平均值是 0.0060,而α_T-1的平均值是 0.0063)。
第(4)行将分数简化为 1。
第(5)行代入了β_T的计划值。
从第(5)行我们看到,时间戳t=T时前向过程反向概率密度函数的协方差矩阵是对角矩阵0.01I,这表明样本的方差不大,但也不小。
现在让我们看一下概率密度函数的均值向量。

时间戳t=T时前向过程反向的均值向量
如果忽略0.0008x₀的微小贡献,则时间戳t=T时前向过程高斯概率密度函数的反向均值向量几乎是x_T。但x_T的样子是什么呢?
好的,我们可以从边际概率密度函数q(x_T|x₀)中抽样,该函数在前向过程的性质 2 中定义。再次查看q(xₜ|x₀)的定义:

我们可以计算当t=T时这个分布的情况:

噢,边际分布q(x_T|x₀)的均值是0.07x₀,协方差矩阵几乎是单位矩阵。由于x₀是一个值在0和1之间的具体图像,因此0.07x₀接近于零向量。换句话说,从这个边际概率密度函数中得到的x_T样本会显得非常嘈杂,因为其均值接近零,而协方差接近单位矩阵——也就是纯高斯噪声。
根据以上信息,我们可以得出结论,前向过程的逆过程以一个接近纯高斯噪声的嘈杂图像开始。注意,“接近纯高斯噪声”这个短语。起始图像不是纯高斯噪声,在这个 0.07x₀均值向量和之前忽略的 0.0008x₀项中,仍然包含关于条件图像x₀的信息。正是关于x₀的信息使得前向过程的逆过程能够将起始图像x_T去噪为一个接近x₀的图像,我们现在将通过回答以下问题来验证这一点。
前向过程的逆过程在哪里结束?
我们现在知道该做什么了。我们需要查看概率密度函数q(xₜ₋₁|xₜ, x₀)在t=2时的均值向量和协方差矩阵,所以t-1是 1。当t=2时:

我们来看一下协方差矩阵:

时间戳 t=2 时前向过程逆向的协方差矩阵
所以我们知道,当t=2时,概率密度函数q(xₜ₋₁|xₜ, x₀)的协方差是一个很小的0.0001**I***。
现在我们来看均值向量:

时间戳 t=2 时前向过程逆向的均值向量
现在我们看到均值向量有一半的贡献来自具体图像x₀,另一半的贡献来自去噪图像x₂。注意到图像x₂由于接近前向过程逆向的终点,已经与x₀相似。因此,上述均值大致导致了一个非常接近x₀的均值向量。加上协方差矩阵很小0.0001I,我们可以推断随机变量x₁,即前向过程逆向的终点,来自于以下分布:

这表明采样图像非常接近图像x₀,变化非常小,因为协方差矩阵很小。
现在我们可以看到,当条件图像为x₀时,前向过程的逆过程开始于一个非常嘈杂的x₀版本,并以一个非常接近x₀的图像结束,展示了去噪图像的能力。
为什么在应用样本平均来推导损失函数 L 的解析公式时不会丢失模型参数
应用样本平均来近似损失函数中的积分时,一个典型的问题是结果公式不再提及模型参数。重新参数化技巧(见这里)是防止这种情况发生的最佳方法。
我们使用样本平均来推导损失函数L的解析近似的情况没有丢失模型参数的问题,让我们用一个短反向过程(T=1)的例子来看看原因。
让我们展示损失函数L及其一些操作,以演示样本平均:

推导显示样本平均近似的损失 L 保留模型参数
第(1)行是损失 L,第(2)行用数学定义替换了期望符号。
第(3)行设定T=1来演示在短反向轨迹上的后续推导。
第(4)行利用反向过程的定义对内积分中的联合概率进行因式分解。
第(5)行将所有概率密度函数符号替换为实际的概率密度分布名称。它还显示随机变量x₁可以从标准多变量高斯分布N(0, 1)中抽取样本。我们用S₁表示x₁的样本。
第(6)行插入了样本S₁,通过仅用一个样本进行样本平均来去除内积分,目的是为了演示。样本平均是一种近似,这在该行前面的近似符号“≈”中得以体现。
第(7)行从未知分布q(x₀)中抽取随机变量x₀的样本S₀;实际上就是从训练集中随机挑选一张自然图像。然后再次使用样本平均来去除对x₀的积分。
第(8)行插入了多变量高斯概率密度函数的公式。比例符号“∝”允许我省略指数函数前的归一化项。
第(9)行简化了公式。它表明在样本平均后,解析损失仍然是一个提到所有模型参数的函数。因此不需要重新参数化技巧。

















浙公网安备 33010602011771号