TowardsDataScience-2023-博客中文翻译-十八-
TowardsDataScience 2023 博客中文翻译(十八)
期望校准误差(ECE):逐步可视化解释
通过一个简单的示例和 Python 代码
·
查看 发表在 Towards Data Science ·8 分钟阅读·2023 年 7 月 12 日
--
作者提供的图片
在分类任务中,机器学习模型输出的是估计概率或也称为置信度 (见上图)。这些值告诉我们模型在其标签预测中的确定性。然而,对于大多数模型,这些置信度与它们预测的事件的真实频率并不一致。它们需要校准!
模型校准旨在将模型的预测与真实概率对齐,从而确保模型的预测是可靠且准确的(有关模型校准重要性的更多细节,请参见 博客文章 )。
好的,既然模型校准很重要,那么我们如何衡量它呢?有几个选项,但本文的目的和重点是解释并仅介绍一种简单的 但相对充分的 测量方法来评估模型校准:期望校准误差(ECE)。它计算估计“概率”的加权平均误差,从而得出一个单一值,我们可以用来比较不同的模型。
我们将按照论文中描述的 ECE 公式进行演示:现代神经网络的校准*。为了简单起见,我们将看一个包含 9 个数据点和二进制目标的小示例。然后,我们还将在 Python 中编写这个简单示例的代码,最后展示一个多类分类的示例代码。
定义
ECE(期望校准误差)衡量了模型估计的“概率”与真实(观察到的)概率的匹配程度,通过对准确性(acc)和置信度(conf)之间的绝对差异进行加权平均来实现:

该测量方法涉及将数据分成 M 个等间隔的箱子。 B 用于表示“箱子”,m 用于表示箱子的编号。我们稍后将详细介绍这个公式中的各个部分,如 B、|Bₘ|、acc(Bₘ) 和 conf(Bₘ)。首先,让我们看看我们的示例,这将帮助逐步理解公式。
示例
我们有 9 个样本,具有估计的概率,也称为‘置信度’ (pᵢ),用于预测 0 或 1。如果标签 0 的概率 pᵢ 高于 0.5,则预测标签将为 0。如果低于 0.5,则标签 1 的概率会更高,因此预测标签将为 1(见下表)。最后一列显示了样本 i 的真实标签。

表 1 | 图片由作者提供
从上表中我们可以看到我们有 9 个样本,n=9。为了确定公式中的其余部分,我们首先需要将样本分成多个箱(bins)。
仅使用确定预测标签的概率来计算 ECE。因此,我们将仅根据标签的最大概率对样本进行分箱(参见表 2)。为了简化示例,我们将数据分成 5 个 等间隔 的箱子 M=5(见右侧的分箱图 1)。 让我们为每个箱子分配一个颜色:

现在,如果我们查看每个样本的最大估计概率,我们可以将其分组到 5 个分箱之一。样本 i=1 的估计概率为 0.78,这高于 0.6 但低于 0.8,这意味着我们将其分到 B₄ 中,见下图。现在看看样本 i=3,其估计为 0.92。这落在 0.8 和 1 之间,因此属于分箱 B₅。我们对每个样本 i 重复这一过程,最终得到表格 2 中的分类(见下方)。


表格 2 和分箱图 1 | 图片由作者提供
B₁ 和 B₂ 不包含任何样本 (由于二进制示例的性质,最大概率在二进制情况下总是 ≥ 0.5)。B₃ 包含 2 个样本。4 个样本最终落入分箱 B₄,而 3 个样本落入 B₅。这已经为我们开始填写上面的 ECE 公式提供了一些信息。具体来说,我们可以计算样本落入分箱 m 的经验概率:|Bₘ|/n**(见下方红色高亮)。

我们知道 n 等于 9,并且从上述分箱过程中我们也知道每个分箱的大小:|Bₘ| (集合 S 的大小记作 |S| — 对于值,请参见上面的数字)。如果我们为每个分箱拆分出颜色编码的公式,则得到如下结果:

对于 B₁ 和 B₂,我们有 0 个样本(|B₁|=|B₂|=0),因此这些分箱的值为 0。
从上述分箱中,我们现在也可以确定 conf(Bₘ),它表示分箱 m 中的平均估计概率,论文中定义如下:

计算 conf(Bₘ) 时,我们将表格 2 中每个分箱 m 的最大估计概率 p̂ᵢ 相加,然后除以分箱的大小 |Bₘ|,见下方右侧:


表格 3 和计算 | 图片由作者提供
然后我们可以用这些值更新 ECE 计算:

现在我们只剩下填写 acc(Bₘ),它表示每个分箱 m 的 准确率,论文中定义如下:

1 是一个 指示函数,表示当预测标签 ŷᵢ 等于真实标签 yᵢ 时,它的值为 1,否则为 0。 这意味着你需要计算每个区间 m 中正确预测的样本数量,并将其除以区间的大小 |Bₘ|。 要做到这一点,我们需要首先确定样本是否被正确预测。 我们使用以下颜色:

并将其应用于最后 2 列,然后我们可以以相同的方式给右侧图中的样本上色:


表 4 & 分箱图 2 | 作者提供的图片
查看上图的 右侧,我们可以看到在区间 B₃ 中有 2 个样本和 1 个正确 预测,这意味着 B₃ 的准确度为 1/2。 对于 B₄ 重复这一过程,得到准确度为 3/4,因为在区间 B₄ 中有 3 个正确预测和 4 个样本。 最后,查看 B₅,我们有 3 个样本和 2 个正确预测,所以最终的准确度为 2/3。 这给出了每个区间的准确度值:

我们现在拥有计算 ECE 所需的所有元素:

在我们这个包含 9 个样本的小例子中,我们得到的 ECE 为 0.10445。 一个完全校准的模型将具有 0 的 ECE。 ECE 越大,模型越不校准。
ECE 是一个有用的初步度量,广泛用于评估模型校准。 然而,ECE 也有一些缺点,使用时应注意 (参见: 深度学习中的校准测量).
Python 代码
Numpy
首先我们将设置上述相同的示例:
import numpy as np
# Binary Classification
samples = np.array([[0.78, 0.22],
[0.36, 0.64],
[0.08, 0.92],
[0.58, 0.42],
[0.49, 0.51],
[0.85, 0.15],
[0.30, 0.70],
[0.63, 0.37],
[0.17, 0.83]])
true_labels = np.array([0,1,0,0,0,0,1,1,1])
我们接着定义 ECE 函数如下:
def expected_calibration_error(samples, true_labels, M=5):
# uniform binning approach with M number of bins
bin_boundaries = np.linspace(0, 1, M + 1)
bin_lowers = bin_boundaries[:-1]
bin_uppers = bin_boundaries[1:]
# get max probability per sample i
confidences = np.max(samples, axis=1)
# get predictions from confidences (positional in this case)
predicted_label = np.argmax(samples, axis=1)
# get a boolean list of correct/false predictions
accuracies = predicted_label==true_labels
ece = np.zeros(1)
for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
# determine if sample is in bin m (between bin lower & upper)
in_bin = np.logical_and(confidences > bin_lower.item(), confidences <= bin_upper.item())
# can calculate the empirical probability of a sample falling into bin m: (|Bm|/n)
prob_in_bin = in_bin.mean()
if prob_in_bin.item() > 0:
# get the accuracy of bin m: acc(Bm)
accuracy_in_bin = accuracies[in_bin].mean()
# get the average confidence of bin m: conf(Bm)
avg_confidence_in_bin = confidences[in_bin].mean()
# calculate |acc(Bm) - conf(Bm)| * (|Bm|/n) for bin m and add to the total ECE
ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prob_in_bin
return ece
在 二分类示例 上调用函数返回与我们上述计算的相同值 0.10444(四舍五入)。
expected_calibration_error(samples, true_labels)
除了二分类示例,我们现在还可以快速浏览一个多分类的案例。 我们使用 James D. McCaffrey 的例子。 这给了我们 5 个目标类别和相关的样本置信度。 我们实际上只需要目标索引来进行计算:[0,1,2,3,4],可以忽略它们对应的标签。 查看样本 i=1,我们可以看到我们现在有 5 个估计概率,每个类别一个:[0.25,0.2,0.22,0.18,0.15]。
# Multi-class Classification
samples_multi = np.array([[0.25,0.2,0.22,0.18,0.15],
[0.16,0.06,0.5,0.07,0.21],
[0.06,0.03,0.8,0.07,0.04],
[0.02,0.03,0.01,0.04,0.9],
[0.4,0.15,0.16,0.14,0.15],
[0.15,0.28,0.18,0.17,0.22],
[0.07,0.8,0.03,0.06,0.04],
[0.1,0.05,0.03,0.75,0.07],
[0.25,0.22,0.05,0.3,0.18],
[0.12,0.09,0.02,0.17,0.6]])
true_labels_multi = np.array([0,2,3,4,2,0,1,3,3,2])
在 multi-classexample 中调用函数返回0.192(与 McCaffrey’s 计算的 0.002 不同, 由于四舍五入的差异!)。
尝试一下Google Colab Notebook,在 numpy 或 PyTorch 中亲自试试看(见下文)。
你现在应该知道如何手动计算 ECE 以及使用 numpy 😃
Google Colab Notebook 链接,其中包含了 numpy 和 PyTorch 中的二分类和多分类示例。**注意:本文中的代码改编自 论文 的 ECE torch 类,来自他们的 GitHub 库。
从零开始的实验编排
原文:
towardsdatascience.com/experiment-orchestration-from-scratch-4a9e460944d8
开发自定义实验编排器以解决复杂建模问题。
·发布于Towards Data Science ·9 分钟阅读·2023 年 7 月 31 日
--

由 Daniel Warfield 使用 p5.js 进行编排。除非另有说明,否则所有图像均由作者创建。
在这篇文章中,我们将探讨为什么实验编排很重要,现有的编排解决方案,如何使用 MongoDB 构建自己的编排器,以及在某些用例中这样做的好处。
这对谁有用? 任何试图将模型拟合到数据上的人;因此需要一种组织这些实验的方法。
这篇文章有多先进? 编排的想法相当简单,几乎任何技能水平的人都能理解。这个例子应该对尝试拓展领域的后端开发人员或数据科学家有用。
前提条件: 对核心网络原理(如数据库和服务器)以及核心数据科学概念(如超参数)的基本理解。
代码: 完整代码可以在这里找到。注意:在撰写本文时,这个仓库仍在开发中。
什么是实验编排?
所谓“实验编排”,指的是许多任务,尽管它们有相同的基本概念。最常见的实验编排形式是超参数搜索:在给定一系列超参数值的情况下,你希望在这些值中进行搜索,找到适用于特定建模问题的最佳超参数集合。这些计划实验的组织通常被称为编排。

一个超参数搜索的例子。定义一个超参数的空间。然后,从这个超参数空间中提取特定的超参数集合,并以某种方式进行测试。然后可以识别出最佳的超参数组合。
简单的实验大多数情况下能完成任务,但随着建模问题的复杂性增加,通常需要更复杂的实验。你可能会发现自己需要在多个数据集上试验多种模型类型,每种模型都有自己的超参数空间。
例如,我目前正在研究不同建模策略在非同质建模应用中的表现。我不是在寻找“解决特定问题的最佳超参数集合”,而是“多种模型类型,每种模型都有其自己的超参数空间,在多个分类和回归任务中的表现如何”。
实验编排器的目标是作为实验的中心枢纽,无论实验定义多么复杂,使得单个工作人员或一组工作人员可以运行实验的子集。

实验编排器的目标是使工作人员知道需要做什么,并允许工作人员记录结果。
我们将使用 MongoDB Data Services 来存储结果,并使用 MongoDB Application Services 作为承载系统逻辑和网络的服务器来构建类似的系统。尽管这非常强大,但也非常简单;我在一个周末就把整个系统搞定了。
现存的解决方案
Weights and Biases 是一个明显的选择。
对于W&B Sweeps,你需要定义一个实验,一个代理(训练和验证代码),并在代理运行时记录结果。这些步骤都相当直接,类似于这样:
- 定义一个实验
"""
Telling W&B what hyperparameter space I want to explore
"""
parameters_dict = {
'optimizer': {
'values': ['adam', 'sgd']
},
'fc_layer_size': {
'values': [128, 256, 512]
},
'dropout': {
'values': [0.3, 0.4, 0.5]
},
}
sweep_config['parameters'] = parameters_dict
2. 定义一个代理
"""
defining a model which works based off of the hyperparmeters
"""
#gets a configuration from the orchestrator
config = wandb.config
loader = build_dataset(config.batch_size)
network = build_network(config.fc_layer_size, config.dropout)
optimizer = build_optimizer(network, config.optimizer, config.learning_rate)
3. 记录结果
"""
training a model and logging the results
"""
for epoch in range(config.epochs):
avg_loss = train_epoch(network, loader, optimizer)
wandb.log({"loss": avg_loss, "epoch": epoch})
这对于 90%的使用案例是足够的,通常是实验编排的推荐方法。然而,这种方法不适用于我的使用案例。接下来的部分将介绍我的解决方案。
自建编排器的理由
对于 90%的应用程序(尤其是商业应用),上述方法通常已经足够。依赖现有系统意味着依赖其超越任何可行的粗糙解决方案的稳健性和功能成熟度。
也就是说,像 W&B 这样的系统似乎被设计用于寻找“解决方案”。它们在假设你有一个特定数据集并希望探索解决方案,以找到适合该数据集的最佳解决方案的目标下运行。对我而言,以及我的研究需求,管理多个数据集、多个模型及其之间的兼容性,在 W&B 中令人感到非常麻烦。

方孔中的圆钉。在软件中,设计上的轻微不一致常常会在之后引发严重问题。这就是为什么在某些应用中,重新制作技术可能比与现有实现集成更容易,即使原始的“基本上”适用。
使用 W&B 时,似乎我需要通过某种方式组织和管理多个数据集上的多个搜索,从而构建一个协调器的协调器。在一个已经很复杂的任务基础上,我还需要处理集成问题。正是在这一点上,我决定从头开始构建 W&B 搜索,并进行一些小的修改,以满足我的需求,这将是最合适的选择。
在一个周末构建一个自定义协调器
我为我的特定问题实现了一个协调器。虽然解决方案是特定于问题的,但一般思路应该对大多数机器学习实验需求是灵活的。
定义问题
我有大约 45 个表格数据集,涵盖了各种领域。这些数据集中的每一个都可以被视为一个“任务”,任何给定的模型可能会表现良好或不佳。有些任务可能是回归任务,而其他任务可能是分类任务。
一般思路是构建一个协调器,该协调器可以管理一组模型对一组数据集的应用。这个协调器应该汇总这些结果以便进一步分析。
同样,自然地,这个协调器的目标是解决问题,而不是成为问题。想法是在解决我遇到的问题时,尽可能简化某些方面。因此,这个解决方案非常简陋,有点像是小修小补。
技术选择
对于这个解决方案,我使用了 MongoDB 应用服务和 MongoDB 数据服务,或者叫什么都无所谓。MongoDB 在过去一年中经历了很多品牌重塑。我使用的系统曾经叫做 MongoDB Atlas 和 Realm,但现在 Realm 可能是 Atlas 的一部分?我不太确定。
尽管如此,云上的 MongoDB 本质上是一个“盒子中的后端”。你可以非常迅速地设置数据库、应用层和 API 层,几乎没有额外开销。在我的经验中,由于文档混乱,使得将东西准备好生产环境可能是一场艰苦的战斗。然而,对于快速原型开发后端资源,我还未找到更好的替代方案。
接下来的几个部分描述了我如何将协调问题拆解成实验和运行,以及这些在实际中是什么样的。

数据库的屏幕截图,包括“实验”、“运行”和“用户数据”集合

一个允许与协调器通信的简单 HTTP API
定义实验
在这种自定义方法中,我基本上借用了 W&B 的搜索设计,并加入了一些自己的想法。核心系统运行在一个“实验”上,该实验描述了模型、超参数空间、数据集以及这三者如何关联在一起。
"""
An example of an experiment definition. Each "experiment" has three key fields:
- data_groups: groups identifiers of datasets
- model_groups: groups identifiers of models
- applications: which model_groups should apply to which data_groups
This approach thinks of a model as two things:
- a unique identifier, which references some model definition
- a hyperparameter space associated with that model.
"runs_per_pair" defines how often a certain association should be run. For
instance "test model X's hyperparameters on dataset Y 10 times".
"""
{
"name": "testExp0",
"runs_per_pair": "10",
"definition": {
"data_groups": {
"group0": [
"dataUID0",
"dataUID1",
"dataUID2"
],
"group1": [
"dataUID3",
"dataUID4",
"dataUID5"
]
},
"model_groups": {
"model0": {
"model": "modelUID0",
"hype": {
"learning_rate": {"distribution": "log_uniform", min:0.0, max:2.5},
"layers": {"distribution":"int_uniform", min:0, max:2}
}
},
"model1": {
"model": "modelUID1",
"hype": {
"learning_rate": {"distribution": "log_uniform", min:0.0, max:2.5}
}
}
},
"applications": {
"group0": [
"model0"
],
"group1": [
"model0",
"model1"
]
}
}
}
这个实验随后被拆解为一个任务列表:一些工人需要执行的独立探索。这些通过查看所有关联并列出所有模型/数据集对来完成。

每个任务在实现中称为 mtpair(模型-任务对)。这跟踪模型、任务(一个数据集)以及所有成功和完成的 mtpair 运行。此图像包含 2 个 mtpair。
这一切都是通过调用 “/registerExperiment” API 端点并传递模型定义来创建的。
我选择将实验设置为“声明式”,有点像你熟悉的 terraform 脚本。当你注册一个实验时,你要么创建一个新的实验,要么根据实验名称获取一个现有的实验。这样,你可以在多个工作者上使用相同的脚本。第一个工作者将创建实验,而其他工作者将简单地使用已经创建的实验。(或者,至少是这个想法。你需要小心这种思路下的竞争条件。)
运行
现在实验已经定义好,以及需要运行的各个模型/任务对,我们可以开始运行。这是调度器的实际操作部分。我们必须:
- 决定工作者应该处理哪个模型/任务对
2. 从该模型的超参数空间中获取供工作者使用的超参数
3. 记录正在进行的结果
4. 管理已完成的运行(将可能失败的运行分开)。
“运行”构造存在于调度器中,以记录这些信息。

“运行”集合中的一个运行示例
一个运行直接与一个模型/任务对(mtpair)、该 mtpair 所在的实验、谁创建了实验、模型、任务、特定的超参数空间点以及按每个纪元记录的结果相关联。这是通过 “/beginRun”、“/updateRun” 和 “/endRun” 端点完成的。
/beginRun 查看所有现有的运行,并在完成和启动的运行最少的 mtpair 上创建一个新的运行。/beginRun 在决定优先处理哪个模型-任务对之后,使用随机搜索将该模型的超参数空间转化为一组具体的超参数。然后,它会传递一个处理程序来处理该运行。
/updateRun 允许你按每个纪元注册指标。每个纪元,你调用 /updateRun 并传递一个包含该运行指标的字典。这些可以是用户认为合适的任何内容。
/endRun 做了一些提升生活质量的事情。结束的运行无法继续,因此它允许代码声明运行已经完成。它还会更新实验中运行的记录,并将运行标记为成功完成。意外失败的运行不会被标记为结束,因此实现这个 webhook 使调度器能够容忍故障工作者。
安全
这个系统使用 JSON Web Tokens (JWTs)来创建一些基础的身份验证。从研究的角度来看,这样一个项目的风险配置相当低。尽管如此,这个系统确实会根据每个用户验证 API 令牌,并提供一些安全措施以确保数据完整性,同时允许协作。
此外,在钱包的安全性方面,我使用了免费套餐,并且设置时不需要注册支付方式。(MongoDB 云服务有些古怪,但在原型设计上确实非常惊人)
就这些了!
通常我会包含代码,但这是一个完整的代码库,放在文章里有些繁琐。如果你想查看代码库,可以点击这里。具体来说,你可以查看函数定义,这些定义基本上就是整个内容的精华。
更新
- 我添加了一个名为beginRunSticky的 webhook,它开始一个新的运行,但接受一个要“粘贴”的数据集。它优先给工人分配一个具有指定任务的新运行,从而允许多个运行执行而无需加载新的数据集。
关注以获取更多!
在未来的帖子中,我还将描述机器学习领域的几篇重要论文,重点放在实际和直观的解释上。
署名: 本文档中的所有图片均由丹尼尔·沃菲尔德创建,除非另有来源说明。你可以在自己的非商业用途上使用本文中的任何图片,只要引用了这篇文章,danielwarfield.dev,或两者兼而有之。
使用谷歌 AI 的 TCAV 进行可解释 AI
原文:
towardsdatascience.com/explainable-ai-with-tcav-from-google-ai-5408adf905e
使用基于概念的解释解释深度神经网络
·发表于Towards Data Science ·13 分钟阅读·2023 年 2 月 18 日
--

图片来源:Pixabay
可解释人工智能 (XAI) 是人工智能 (AI) 的一个子领域,旨在开发能够向人类提供清晰且易于理解的决策过程解释的 AI 系统。XAI 的目标是使 AI 变得更加透明、可信、负责任和伦理。XAI 在增加 AI 应用方面至关重要,尤其是在医疗保健、金融和执法等高风险领域。在这些领域,理解 AI 系统如何得出特定决策或建议至关重要。
XAI 中使用了各种技术,包括模型透明性、基于规则的系统以及如 LIME 和 SHAP 等模型无关的方法。XAI 方法可以根据 AI 系统的类型、应用领域和所需的可解释性水平有所不同。总体而言,XAI 是一个至关重要的领域,用于开发可以信任并在现实世界应用中有效且伦理地使用的 AI 系统。
如果你想在短短 45 分钟的视频中获得 XAI 的简要介绍,你可以观看我在2021 年 AI 加速器节 APAC上发表的关于 XAI 的过去的一个讲座:
可解释 AI:使机器学习和深度学习模型更具可解释性(作者讲座)
目前流行的 XAI 方法 如 LIME 和 SHAP 的一个主要限制是,这些方法与非技术性终端用户解释观察的方式不是非常一致和直观。例如,如果你有一张装满可乐的玻璃的图像,并使用 LIME 和 SHAP 解释一个正确将图像分类为可乐的黑箱模型,LIME 和 SHAP 都会突出显示图像中导致模型正确预测的区域。但如果你让一个非技术用户描述这张图像,用户会因为玻璃中含有一种类似可乐饮料的深色碳酸液体而将其归类为可乐。换句话说,人类倾向于用已知的概念来解释任何观察结果。
概念激活向量(TCAV)测试 也采用了类似的方法,用于通过已知的人类概念解释模型预测。因此,在这篇文章中,我们将探讨如何利用 TCAV 提供基于概念的人性化解释。与 LIME 和 SHAP 不同,TCAV 超越了特征归因,参考诸如颜色、性别、种族、形状、任何已知对象或抽象概念等来解释模型预测。我们将在这篇文章中讨论有关 TCAV 的以下主题:
-
直观理解 TCAV
-
TCAV 与其他 XAI 框架的区别
-
基于概念的解释的潜在应用
在这篇文章中,我将提到我书中讨论的一些 XAI 框架 应用机器学习可解释性技术。
[## 应用机器学习可解释性技术:使 ML 模型在实际中可解释和可信…
应用机器学习可解释性技术:使 ML 模型在实际中可解释和可信…
www.amazon.com](https://www.amazon.com/Applied-Machine-Learning-Explainability-Techniques/dp/1803246154?encoding=UTF8&pd_rd_w=Wr6SJ&content-id=amzn1.sym.716a1ed9-074f-4780-9325-0019fece3c64&pf_rd_p=716a1ed9-074f-4780-9325-0019fece3c64&pf_rd_r=6P2PM599T97MRG7NZD9J&pd_rd_wg=m4qUW&pd_rd_r=6e349d93-5ba0-4bfe-9055-905c0153fe58&linkCode=li3&tag=adib0073-20&linkId=35506e1847de5c011fc57aa66c2b1d8e&language=en_US&ref=as_li_ss_il&source=post_page-----5408adf905e--------------------------------)
现在是开始的时候了!
TCAV 简介
使用概念激活向量(TCAV)是一种 XAI 方法,用于了解神经网络模型在预测时使用了哪些信号。TCAV 展示了高层次概念(例如,颜色、性别、种族)对预测类别的重要性,类似于人类的沟通方式!TCAV 提供了一种通常适用于一个感兴趣类别的解释,而不仅仅是单张图片(全局解释)。例如,对于给定的类别,我们可以展示种族或性别对 InceptionV3 分类的影响程度。即使种族或性别标签也并非训练输入的一部分!
该算法依赖于概念激活向量(CAV),它利用人类友好的概念来解释机器学习模型的内部状态。从更技术的角度来看,TCAV 使用方向导数来量化对模型预测有重要影响的人类友好型高层次概念的重要性。例如,在描述发型时,卷发、直发或发色等概念可以被 TCAV 使用。这些用户定义的概念并不是算法在训练过程中使用的数据集的输入特征。
应用机器学习可解释性技术:使机器学习模型对实际应用可解释和可信赖
应用机器学习可解释性技术:使机器学习模型对实际应用可解释和可信赖
应用机器学习可解释性技术:使机器学习模型对实际应用可解释和可信赖
其他流行的 XAI 方法,如 LIME 和 SHAP,依赖于模型认为重要的特征。没有为解释性添加自定义用户定义概念作为输入特征的空间。下图展示了 TCAV 解决的关键问题:

TCAV 解决的关键问题 — 一个概念在预测输出中的重要性是什么? (图片作者)
使用抽象概念进行解释
到现在为止,你可能已经对使用抽象概念进行解释的方法有了直观的理解。但你认为这为什么是一种有效的方法?
让我们再举一个例子。假设你正在构建一个基于深度学习的图像分类器,用于从图像中检测医生。在应用 TCAV 后,假设你发现白人男性的概念重要性最大,其次是听诊器和白色大褂。听诊器和白色大褂的概念重要性是预期中的,但白人男性的高概念重要性表明数据集存在偏差。因此,TCAV 可以帮助评估训练模型的公平性。
本质上,CAV 的目标是估计一个概念(如颜色、性别和种族)对预测训练模型的重要性,即使这些概念在模型训练过程中没有被使用。这是因为 TCAV 从少量示例样本中学习概念。
例如,为了学习一个性别概念,TCAV 需要一些具有男性概念的数据实例和一些非男性示例。因此,TCAV 可以定量估计训练模型对该类特定概念的敏感性。
在生成解释时,TCAV 会将数据点扰动到一个对人类可理解的概念,因此这是一种全球扰动方法。接下来,让我们尝试了解 TCAV 的主要目标。
TCAV 的目标
我发现 TCAV 的方法与其他解释方法相比非常独特。一个主要原因是这个框架的开发者设立了与我对人类友好解释的理解相一致的明确目标。以下是 TCAV 设立的目标:
-
可访问性:TCAV 的开发者希望这种方法对任何终端用户都能可访问,无论他们是否了解机器学习或数据科学。
-
定制化:该框架可以适应任何用户定义的概念。这不仅限于训练过程中考虑的概念。
-
插件准备:开发者希望这种方法可以在不需要重新训练或微调已经训练好的机器学习模型的情况下进行工作。
-
全局可解释性:TCAV 可以通过单一的定量度量解释整个类别或数据集的多个样本。它不局限于数据实例的局部可解释性。
现在我们知道使用 TCAV 可以实现什么,让我们讨论 TCAV 的一般工作方法。
TCAV 的方法
在这一部分,我们将更深入地讨论 TCAV 的工作原理。这个算法的整体工作可以通过以下方法进行总结:
-
应用方向导数来定量估计训练的机器学习模型对各种用户定义概念的预测敏感性。
-
计算最终的定量解释,称为TCAVq 量度,无需任何模型重新训练或微调。这个量度是每个概念对每个模型预测类别的相对重要性。

TCAV 用于估计老虎图像分类器中条纹概念重要性的 approach(图片由作者提供)
现在,我将尝试进一步简化 TCAV 的方法,而不使用太多数学概念。假设我们有一个从图像中识别斑马的模型。
要应用 TCAV,可以采取以下方法:
-
定义感兴趣的概念:第一步是考虑感兴趣的概念。对于我们的斑马分类器,我们可以拥有一组表示该概念的示例(例如黑色条纹在识别斑马时很重要),或者我们可以拥有一个标记了概念的独立数据集。这一步的主要好处是它不限制算法使用模型所用的特征。即使是非技术用户或领域专家也可以根据现有知识定义概念。
-
学习概念激活向量:算法尝试通过训练线性分类器来区分由概念实例生成的激活与存在于任何层中的实例,从而在激活层空间中学习一个向量。因此,CAV 被定义为将带有概念的实例和没有概念的实例在模型激活中分开的超平面的正常投影。对于我们的斑马分类器,CAV 帮助区分表示 黑色条纹 的表征和不表示 黑色条纹 的表征。
-
估计方向导数:方向导数用于量化模型预测对某个概念的敏感性。因此,对于我们的斑马分类器,方向导数帮助我们测量 黑色条纹 表征在预测斑马时的重要性。与使用逐像素显著性的显著性图不同,方向导数是在整个数据集或一组输入上计算的,但针对特定概念。这有助于提供全局视角以进行解释。
-
估计 TCAV 分数:为了量化特定类别的概念重要性,计算 TCAV 分数 (TCAVq)。这个指标有助于测量定义概念对模型特定激活层的正面或负面影响。
-
CAV 验证:CAV 可以从随机选择的数据中生成。但不幸的是,这可能不会产生有意义的概念。因此,为了改进生成的概念,TCAV 会运行多次迭代,以从不同的数据批次中寻找概念,而不是仅在单一数据批次上训练一次 CAV。然后,使用 双侧 t 检验 进行 统计显著性测试,以选择统计上显著的概念。还会进行必要的修正,例如 邦费罗尼修正,以控制假发现率。
因此,我们已经涵盖了 TCAV 算法的直观工作原理。接下来,让我们讨论 TCAV 如何实际应用。
Applied Machine Learning Explainability Techniques: Make ML models explainable and trustworthy for practical…
TCAV 与其他 XAI 方法的区别
现在,让我们总结一下 TCAV 与流行的 XAI 方法如 LIME 和 SHAP 的不同之处。
-
XAI 框架如 LIME 可能会为同一类别的两个数据实例生成相互矛盾的解释。而 TCAV 生成的解释不仅对单一数据实例是准确的,还对整个类别是准确的。这是 TCAV 相较于 LIME 的一个主要优势,这增加了用户对解释方法的信任。
-
基于概念的解释更接近于人类对未知观察的解释,而不是像 LIME 和 SHAP 采用的基于特征的解释。因此,TCAV 生成的解释确实更加符合人类的思维方式。
-
基于特征的解释局限于模型中使用的特征。要引入任何新的特征以进行模型解释,我们需要重新训练模型,而基于概念的解释则更具灵活性,不受限于模型训练期间使用的特征。要引入一个新概念,我们不需要重新训练模型。只需准备必要的数据集来生成概念即可。
-
模型解释性并不是 TCAV 唯一的好处。TCAV 可以帮助在训练过程中发现问题,例如数据集不平衡导致对数据集的偏倚。事实上,概念重要性可以作为比较模型的一个指标。
TCAV 的当前限制
不幸的是,就像这个美丽世界中的一切一样,即使是 TCAV 也并非完美无瑕!虽然 TCAV 以其独特的方式存在,但 TCAV 也有一些限制,这些限制限制了它在模型解释性方面的广泛应用。下面讨论了一些 TCAV 目前的主要限制:
-
目前,基于概念的解释方法使用 TCAV 仅限于神经网络。为了增加其采用率,TCAV 需要一种可以与经典机器学习算法(如决策树、支持向量机和集成学习算法)一起使用的实现。LIME 和 SHAP 可以应用于经典 ML 算法,以解决标准 ML 问题,这可能也是 LIME 和 SHAP 被更广泛采用的原因。类似地,对于文本数据,TCAV 的应用也非常有限。
-
TCAV 对数据漂移、对抗性影响和其他数据质量问题非常敏感。如果你使用 TCAV,你需要确保训练数据、推理数据,甚至概念数据具有类似的统计特性。否则,生成的概念可能会受到噪声或数据不纯问题的影响:
-
Guillaume Alain 和 Yoshua Bengio 在他们的论文 Understanding intermediate layers using linear classifier probes (
arxiv.org/abs/1610.01644) 中,对将 TCAV 应用于较浅的神经网络表示了一些担忧。许多类似的研究论文建议,较深层次的概念相比于较浅网络中的概念更具可分性,因此 TCAV 的使用主要限于深度神经网络。 -
准备概念数据集可能是一个具有挑战性和昂贵的任务。尽管你不需要 ML 知识来准备概念数据集,但在实践中,你也不会期望任何普通终端用户花时间为任何自定义用户定义的概念创建一个注释概念数据集。
-
我认为 TCAV Python 框架在被用于任何生产级系统之前需要进一步改进。依我之见,在撰写本章节时,该框架需要进一步成熟,以便可以轻松地与任何生产级 ML 系统配合使用。
## 应用机器学习可解释性技术:让 ML 模型在实践中可解释和可信赖……
应用机器学习可解释性技术:让 ML 模型在实践中可解释和可信赖……
所有这些限制确实可以解决,使 TCAV 成为一个更强大的框架,得到广泛采用。你也可以联系 TCAV 框架的作者和开发者,为开源社区做贡献!在下一节中,我们将讨论一些基于概念的解释的潜在应用。
基于概念的解释的潜在应用
我确实看到基于概念的解释(如 TCAV)的巨大潜力!在本节中,你将接触到一些基于概念的解释的潜在应用,这些应用可能是整个 AI 社区的重要研究主题,如下所示:
- AI 中透明度和公平性的估计:对黑箱 AI 模型的大多数监管担忧与性别、肤色和种族等概念有关。基于概念的解释实际上可以帮助估计一个 AI 算法在这些抽象概念方面是否公平。检测 AI 模型的偏见实际上可以提高它们的透明度,并帮助解决某些监管问题。例如,在医生使用深度学习模型的情况下,TCAV 可以用于检测模型是否对特定性别、肤色或种族有偏见,因为理想情况下,这些概念对模型的决策不应重要。对这些概念的高概念重要性表示存在偏见。

TCAV 可以用于根据概念重要性检测模型偏见(图片由作者提供)
-
使用 CAV 进行对抗攻击检测:如果你查看 TCAV 研究论文的附录 (
arxiv.org/pdf/1711.11279.pdf),作者提到实际样本和对抗样本的概念重要性是相当不同的。这意味着如果图像受到对抗攻击的影响,概念重要性也会改变。因此,CAVs 可以成为检测对抗攻击的潜在方法。 -
基于概念的图像聚类:使用 CAVs 根据相似概念对图像进行聚类可以是一个有趣的应用。基于深度学习的图像搜索引擎是一个常见的应用,其中聚类或相似性算法被应用于特征向量,以定位相似的图像。然而,这些是基于特征的方法。同样,使用 CAVs 应用基于概念的图像聚类也是一种潜在的可能性。
-
自动化基于概念的解释 (ACE):Ghorbani, Amirata、James Wexler、James Zou 和 Been Kim 在他们的研究工作 面向自动化基于概念的解释 中提到了一种自动化版本的 TCAV,该版本会遍历训练图像并自动发现显著的概念。这项工作很有趣,因为我认为它在识别标记错误的训练数据方面可能有重要应用。在工业应用中,获得完美标记的策划数据集是极具挑战性的。这个问题可以通过 ACE 很大程度上解决。

- 基于概念的反事实解释:另一种重要的 XAI 方法是 反事实解释 (CFE),它可以通过建议对输入特征进行更改,从而改变整体结果,来生成可操作的见解。CFE 提供了翻转预测类别所需的最小特征值。然而,CFE 是一种基于特征的解释方法。研究一个基于概念的反事实解释将是一个非常有趣的课题,它离人类友好的解释更近了一步。目前还没有现成的算法或框架能够帮助我们实现这一点,但这可能是计算机视觉中基于概念的方法的一个有用应用。

基于概念的反事实示例(作者提供的图片)
我觉得这是一个广阔的研究领域,通过基于概念的解释提出颠覆性的应用的潜力巨大。我真诚地希望越来越多的研究人员和 AI 开发者开始关注这一领域,以在未来几年取得重大进展!因此,我们已到达本文的结尾。
[## 应用机器学习可解释性技术:使 ML 模型在实践中可解释和可信…]
应用机器学习可解释性技术:使 ML 模型在实践中可解释和可信…
www.amazon.com](https://www.amazon.com/Applied-Machine-Learning-Explainability-Techniques/dp/1803246154?encoding=UTF8&pd_rd_w=Wr6SJ&content-id=amzn1.sym.716a1ed9-074f-4780-9325-0019fece3c64&pf_rd_p=716a1ed9-074f-4780-9325-0019fece3c64&pf_rd_r=6P2PM599T97MRG7NZD9J&pd_rd_wg=m4qUW&pd_rd_r=6e349d93-5ba0-4bfe-9055-905c0153fe58&linkCode=li3&tag=adib0073-20&linkId=35506e1847de5c011fc57aa66c2b1d8e&language=en_US&ref=as_li_ss_il&source=post_page-----5408adf905e--------------------------------)
总结
本文涵盖了 TCAV 的概念,一种新颖的方法,以及 Google AI 开发的框架。你已经获得了对 TCAV 的概念性理解,了解了 TCAV 的一些关键优势和局限性,最后,我们讨论了一些关于使用基于概念的解释解决潜在研究问题的有趣想法。我推荐阅读这本书:“应用机器学习可解释性技术”,并探索 GitHub 仓库 以获取实际代码示例。
作者在 TDS 上的其他 XAI 相关文章:
-
你应该了解的关键可解释 AI Python 框架
-
解释性机器学习在文本数据训练模型中的应用:结合 SHAP 与 Transformer 模型
-
EUCA — 一种有效的 XAI 框架,使人工智能更接近终端用户
-
了解 SHAP 和用于可解释 AI 的 Shapley 值的工作原理
-
如何使用 LIME 解释图像分类器
## 应用机器学习可解释性技术:让机器学习模型变得可解释和可信…
应用机器学习可解释性技术:让机器学习模型在实际应用中变得可解释和可信…
参考文献
-
书籍《应用机器学习解释性技术》的 GitHub 仓库 —
github.com/PacktPublishing/Applied-Machine-Learning-Explainability-Techniques/ -
特征归因之外的可解释性:使用概念激活向量(TCAV)进行定量测试:
arxiv.org/pdf/1711.11279.pdf -
TCAV Python 框架 —
github.com/tensorflow/tcav -
Koh 等人,《概念瓶颈模型》:
arxiv.org/abs/2007.04612 -
Guillaume Alain 和 Yoshua Bengio,《使用线性分类器探针理解中间层》:
arxiv.org/abs/1610.01644 -
Ghorbani, Amirata, James Wexler, James Zou 和 Been Kim,《朝向自动化概念基础解释》:
arxiv.org/abs/1902.03129 -
章节 10.3 检测概念,Molnar, C.(2022)。《可解释机器学习:黑箱模型可解释性的指南(第 2 版)》:
christophm.github.io/interpretable-ml-book/detecting-concepts.html
以 3 种难度级别解释向量数据库
原文:
towardsdatascience.com/explaining-vector-databases-in-3-levels-of-difficulty-fc392e48ab78
从新手到专家:揭示不同背景下的向量数据库
·发表在 Towards Data Science ·阅读时间 8 分钟·2023 年 7 月 4 日
--

向量空间(图像由作者手绘)
最近,向量数据库引起了很多关注,许多向量数据库初创公司获得了数百万美元的融资。
你可能已经听说过这些内容,但直到现在才真正关心它们——至少,这就是我猜测你为什么在这里的原因……
如果你只是想要简短的答案,我们直接进入正题:
定义:什么是向量数据库?
向量数据库是一种存储和管理非结构化数据(如文本、图像或音频)的数据库,通过向量嵌入(高维向量)来便于快速找到和检索相似对象。
如果这个定义只会让你更加困惑,那么让我们一步一步来。本文受 WIRED 的“5 Levels”视频系列 的启发,分三个难度级别来解读向量数据库:
-
像 5 岁小孩一样解释它
-
向数字原住民和技术爱好者解释向量数据库
-
向工程师和数据专业人士解释向量数据库
向量数据库:像 5 岁小孩一样解释(ELI5)
这有点偏题,但你知道我不理解的是什么吗?
当人们按颜色排列书架上的书时。——哎呀!
当他们不知道书封面的颜色时,怎么找到一本书呢?

图片由 Luisa Brimble 提供,来源于 Unsplash
向量数据库背后的直觉
如果你想快速找到一本特定的书,把书架按类别和作者排列比按颜色排列更有意义。这就是为什么大多数图书馆都这样组织,以帮助你快速找到你想要的书。
但是如何根据查询而不是按类别或作者找到阅读材料呢?如果你想读一本例如:
-
类似于《好饿的毛毛虫》或
-
关于一个像你一样喜欢吃的主要角色?
如果你没有时间浏览书架,最快的方法就是请图书管理员推荐,因为他们读过很多书,会准确知道哪本书最适合你的查询。
在组织书籍的例子中,你可以把图书管理员看作是向量数据库,因为向量数据库被设计用来存储关于对象(例如一本书)的复杂信息(例如书的情节)。因此,向量数据库可以帮助你根据特定查询(例如,一本书的主题是……)找到对象,而不是几个预定义的属性(例如作者)——就像图书管理员一样。
向数字原住民和科技爱好者解释向量数据库
现在,让我们回到图书馆的例子,并深入一点技术细节:当然,现在有比仅按类别或作者查找图书更先进的技术。
如果你访问图书馆,通常角落里会有一台电脑,帮助你找到具有一些更具体属性的书籍,如标题、ISBN、出版年份或一些关键词。根据你输入的值,数据库会查询可用的书籍。这种数据库通常是传统的关系型数据库。

关系型数据库和向量数据库之间的区别是什么?
关系型数据库和向量数据库之间的主要区别在于它们存储的数据类型。关系型数据库设计用于结构化数据,这些数据适合放入表格中,而向量数据库则用于存储非结构化数据,例如文本或图像。
存储的数据类型也影响数据的检索方式:在关系型数据库中,查询结果基于特定关键词的匹配。在向量数据库中,查询结果基于相似性。
你可以把传统的关系型数据库想象成电子表格。它们非常适合存储结构化数据,例如关于一本书的基本信息(例如标题、作者、ISBN 等),因为这些信息可以存储在列中,这些列非常适合过滤和排序。
在关系型数据库中,你可以快速找到所有例如儿童书籍,并且标题中有“毛毛虫”的书。
但如果你喜欢“非常饥饿的毛毛虫”是关于食物的呢?你可以尝试搜索关键字“food”,但除非书籍的总结中提到了“food”这个关键字,否则你甚至找不到“非常饥饿的毛毛虫”。相反,你可能会得到一堆食谱和失望。
这就是关系数据库的一个局限性:你必须添加你认为某人可能需要找到特定项目的所有信息。但你怎么知道要添加哪些信息以及添加多少呢?添加所有这些信息是耗时的,并不能保证完整性。
这就是向量数据库发挥作用的地方!
但首先,我们要绕一个弯谈谈一个叫做向量嵌入的概念。
现在的机器学习(ML)算法可以将给定的对象(例如,单词或文本)转换为保留该对象信息的数值表示。想象一下,你给一个 ML 模型一个单词(例如,“food”),然后这个 ML 模型进行其魔法,返回给你一个长长的数字列表。这长长的数字列表就是你单词的数值表示,称为向量嵌入。
由于这些嵌入是一长串数字,我们称它们为高维的。假设这些嵌入只有三维,以便如下图所示进行可视化。

你可以看到,像“hungry”,“thirsty”,“food”和“drink”这样的相似单词都被分组在一个相似的角落,而像“bicycle”和“car”这样的其他单词则靠近在一起,但在这个向量空间中的不同角落。
数值表示使我们能够对通常不适合计算的对象(如单词)应用数学计算。例如,以下计算将无法进行,除非你用单词的嵌入来替代单词:
drink - food + hungry = thirsty
由于我们能够使用嵌入进行计算,我们还可以计算一对嵌入对象之间的距离。两个嵌入对象越接近,它们的相似度就越高。
正如你所看到的,向量嵌入非常酷。
让我们回到我们的例子,假设我们将图书馆中每本书的内容嵌入到向量数据库中并存储这些嵌入。现在,当你想找到一 “本主角喜欢食物的儿童书”时,你的查询也会被嵌入,返回的书籍是与你的查询最相似的,例如“非常饥饿的毛毛虫”或者“金发姑娘与三只熊”。
向量数据库的使用案例是什么?
向量数据库早在大型语言模型(LLMs)引起热潮之前就已经存在了。最初,它们被用于推荐系统,因为它们可以快速找到与给定查询相似的对象。但由于它们可以为 LLMs 提供长期记忆,最近它们也被用于问答应用中。
向工程师和数据专业人士解释向量数据库
如果你在打开本文之前就已经猜到向量数据库可能是一种存储向量嵌入的方式,并且只想了解向量嵌入的实际情况,那么让我们深入探讨一下,谈谈算法吧。
向量数据库是如何工作的?
向量数据库能够快速检索查询的相似对象,因为它们已经预先计算过这些对象。其基础概念称为近似最近邻(ANN)搜索,它使用不同的算法进行索引和计算相似度。
正如你所想的那样,使用简单的 k-最近邻(kNN)算法计算查询与每个嵌入对象之间的相似度,当你有数百万个嵌入时,可能会变得非常耗时。使用 ANN,你可以在速度和准确性之间做出一些妥协,快速检索出与查询最相似的对象。
索引 — 为此,向量数据库索引向量嵌入。这一步骤将向量映射到一种数据结构,以便更快地搜索。
你可以把索引想象成将图书馆中的书籍按不同类别(如作者或类型)分组。但由于嵌入可以包含更复杂的信息,进一步的类别可能是“主要角色的性别”或“情节的主要地点”。因此,索引可以帮助你从所有可用向量中检索出较小的一部分,从而加快检索速度。
我们不会深入探讨索引算法的技术细节,但如果你对进一步阅读感兴趣,可以从查阅层次可导航小世界(HNSW)开始。
相似度度量 — 为了从索引向量中找到与查询最接近的邻居,向量数据库应用相似度度量。常见的相似度度量包括余弦相似度、点积、欧几里得距离、曼哈顿距离和汉明距离。
向量数据库相对于将向量嵌入存储在 NumPy 数组中的优势是什么?
我经常遇到的一个问题是:我们不能仅使用 NumPy 数组来存储嵌入吗? 当然可以,如果你没有很多嵌入或者只是做一个有趣的爱好项目。但正如你可能已经猜到的,当你有很多嵌入时,向量数据库显著更快,而且你不必将所有内容都保存在内存中。
我会简短说明,因为 Ethan Rosenthal 对使用向量数据库与使用 NumPy 数组的区别进行了比我能写得更好的解释。
## 你真的需要向量数据库吗? | Ethan Rosenthal
剧透警报:答案可能是!虽然,我使用“实际上”这个词暴露了我的偏见。向量数据库是……
享受了这个故事吗?
免费订阅 以便在我发布新故事时收到通知。
[## 每当 Leonie Monigatti 发布时,获取一封电子邮件。
每当 Leonie Monigatti 发布时,获取一封电子邮件。通过注册,如果你还没有 Medium 账户,你将创建一个…
medium.com](https://medium.com/@iamleonie/subscribe?source=post_page-----fc392e48ab78--------------------------------)
在 LinkedIn、Twitter和 Kaggle上找到我!
MEMS 传感器数据的探索性分析
原文:
towardsdatascience.com/exploratory-analysis-of-mems-sensor-data-bbfc0aa0a887
从 MPU6050 传感器读取、收集和分析数据
·发布于 Towards Data Science ·阅读时间 13 分钟·2023 年 8 月 19 日
--
MEMS(微电机械系统)传感器广泛应用于不同的场景,从游戏控制器和智能手机到无人机。在这篇文章中,我将展示如何连接陀螺仪和加速度计传感器、可以从中获取的数据类型以及如何处理和可视化这些数据。
让我们开始吧。
硬件
MPU-6050 是一个 6 轴传感器,结合了 3 轴陀螺仪、3 轴加速度计和 I2C 接口。正如数据表中所写,它广泛用于平板电脑和智能手机。当我们的智能手机或智能手表在锻炼时计算步数和卡路里时,实际上使用的是 MEMS 传感器的数据。但是像这样的传感器不仅仅可以用于运动。我决定将传感器放在我家中几天,看看是否能够检测和分析我所住的建筑物中的各种振动。
如果我们想在几天内收集数据,Raspberry Pi 是一个很好的解决方案。Raspberry Pi 是一个便宜的(30-50 美元)单板计算机;它功耗低,拥有大量引脚来连接不同类型的硬件。可以在亚马逊上以 3-5 美元的价格订购一个 MPU-6050 原型板。传感器本身使用 I2C 总线进行数据传输,只需 4 根线即可将其连接到 Raspberry Pi:

连接图示,作者提供的图片
在使用传感器之前,应该在 Raspbian 操作系统上启用 I2C 总线(关于如何将 MPU6050 连接到 Raspberry Pi 的教程足够多,所以我在这里跳过“硬件”细节)。连接传感器后,我创建了一个简单的 Python 应用程序,读取传感器数据并将其“原样”写入日志文件中:
from datetime import datetime
import smbus
import math
import time
# MPU6050 Registers
PWR_MGMT_1 = 0x6B
SMPLRT_DIV = 0x19
CONFIG = 0x1A
GYRO_CONFIG = 0x1B
INT_ENABLE = 0x38
ACCEL_XOUT_H = 0x3B
ACCEL_YOUT_H = 0x3D
ACCEL_ZOUT_H = 0x3F
GYRO_XOUT_H = 0x43
GYRO_YOUT_H = 0x45
GYRO_ZOUT_H = 0x47
bus = smbus.SMBus(1)
address = 0x68
def device_init():
""" Init the MPU-6050 """
bus.write_byte_data(address, SMPLRT_DIV, 0x4)
bus.write_byte_data(address, PWR_MGMT_1, 1)
bus.write_byte_data(address, CONFIG, 0)
bus.write_byte_data(address, GYRO_CONFIG, 24)
bus.write_byte_data(address, INT_ENABLE, 1)
def read_byte(reg):
""" Read 1 byte from the sensor """
return bus.read_byte_data(address, reg)
def read_word(reg):
""" Read 2 bytes from the sensor """
h = bus.read_byte_data(address, reg)
l = bus.read_byte_data(address, reg + 1)
value = (h << 8) + l
return value
def read_word_2c(reg):
""" Read and convert the data """
val = read_word(reg)
return -((65535 - val) + 1) if val >= 0x8000 else val
def device_read():
""" Get accel and gyro data """
g_x = read_word_2c(GYRO_XOUT_H) / 131
g_y = read_word_2c(GYRO_YOUT_H) / 131
g_z = read_word_2c(GYRO_ZOUT_H) / 131
a_x = read_word_2c(ACCEL_XOUT_H) / 16384
a_y = read_word_2c(ACCEL_YOUT_H) / 16384
a_z = read_word_2c(ACCEL_ZOUT_H) / 16384
return g_x, g_y, g_z, a_x, a_y, a_z
if __name__ == "__main__":
device_init()
device_read()
while True:
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
gyro_x1, gyro_y1, gyro_z1, accel_x1, accel_y1, accel_z1 = device_read()
gyro_x2, gyro_y2, gyro_z2, accel_x2, accel_y2, accel_z2 = device_read()
g_x, g_y, g_z = (gyro_x1 + gyro_x2)/2, (gyro_y1 + gyro_y2)/2, (gyro_z1 + gyro_z2)/2
a_x, a_y, a_z = (accel_x1 + accel_x2)/2, (accel_y1 + accel_y2)/2, (accel_z1 + accel_z2)/2
s_data = f"{timestamp},{g_x: .7f},{g_y: .7f},{g_z: .7f},{a_x: .7f},{a_y: .7f},{a_z: .7f}"
# Save to log file
log_filename = datetime.now().strftime('%Y-%m-%d.log')
with open(log_filename, "a", encoding="ascii") as log_out:
log_out.write(s_data + "\n")
在“生产”场景中,我们可以将数据发送到 Kafka 主题或其他任何云服务提供商,但对于“家庭”测试,只需在后台运行应用程序即可:
nohup python3 accel_read.py >/dev/null 2>&1 &
之后,我们可以让 Raspberry Pi 运行几天。
从代码中我们可以看到,所有的日志文件都有一个“YYYY-MM-DD”的模式。我们可以使用 scp 从 Raspberry Pi 下载这些文件:
scp pi@raspberrypi3:/home/pi/Documents/AccelData/2023-08-01.log data
现在让我们看看我们可以获得什么样的数据。
总体见解
首先,让我们看看陀螺仪和加速度计的数据是什么样的。我们需要包含所需的库:
import pandas as pd
from bokeh.plotting import figure, show
from bokeh.models import Range1d, DatetimeTickFormatter
from bokeh.io import output_notebook
from bokeh.layouts import row, column, gridplot
output_notebook()
现在让我们将 CSV 文件加载到 Pandas 数据框中,并使用 Bokeh 绘制它:
df_sample = pd.read_csv("mpu6050.csv",
header=None,
names=["timestamp", "g_x", "g_y", "g_z", "a_x", "a_y", "a_z"],
parse_dates=["timestamp"])
display(df_sample)
这个样本包含了在 6 秒内收集的记录,数据值如下所示:

正如我们所见,我们可以从传感器中每秒获得约 60 次测量。让我们绘制数据:
timestamps = df_sample['timestamp']
# Accelerometer data
p1 = figure(title="Accelerometer data", x_axis_type='datetime',
x_axis_label='x', y_axis_label='y', width=1600, height=600)
p1.line(timestamps, df_sample["a_x"], legend_label="A_X", line_width=2, color="red")
p1.line(timestamps, df_sample["a_y"], legend_label="A_Y", line_width=2, color="green")
p1.line(timestamps, df_sample["a_z"], legend_label="A_Z", line_width=2, color="blue")
# Gyroscope data
p2 = figure(title="Gyroscope data", x_axis_type='datetime',
x_axis_label='x', y_axis_label='y', width=1600, height=600)
p2.line(timestamps, df_sample["g_x"], legend_label="G_X", line_width=2, color="#AA8822")
p2.line(timestamps, df_sample["g_y"], legend_label="G_Y", line_width=2, color="#AA88AA")
p2.line(timestamps, df_sample["g_z"], legend_label="G_Z", line_width=2, color="#2288AA")
show(column(p1, p2))
顺便提一下,Bokeh 库非常适合绘制这样的数据。结果是,至少在我的计算机上,当点数超过几千时,Matplotlib 几乎“崩溃”了。与此同时,Bokeh 可以处理在一个图表中多达 100 万条记录的数据。
输出结果如下所示:

加速度计和陀螺仪数据,图片来源于作者
了解陀螺仪和加速度计之间的区别也很重要。加速度计(上图)测量的是静态加速度,包括来自地球重力的作用。在这个例子中,我慢慢地在手中旋转板子,所有三个 X、Y 和 Z 轴的值都在成比例地变化。陀螺仪测量的是瞬时动量,围绕每个轴的旋转。这些数据看起来像是加速度计数据的导数;当运动开始时,会出现一个峰值,然后数值回到零。
我原本不打算在本文中分析步态数据,但一些读者可能仍然对人类步态的原始传感器数据感兴趣:

人类步态的加速度计和陀螺仪数据,图片来源于作者
正如我们所见,特别是在陀螺仪数据图中,这种模式很容易检测。为了本文的目的,更具挑战性的是看看我们是否能够检测到数据中更微小的变化,比如建筑物的振动。
数据分析
在前面的步骤中,我们了解了如何制作一个简单的应用程序来收集传感器数据,并对这些数据有了一个大致的了解。现在,让我们更详细地看看可以发现哪些有趣的模式。对于所有后续的示例,我将使用在 24 小时内收集的数据。
1. 时间戳准确性
作为提醒,传感器数据是通过在 Raspberry Pi 上运行的 Python 应用程序收集的。Raspberry Pi 本身运行的是 Linux,而不是实时操作系统。首先,让我们看看我们拥有的时间戳准确性。
df = pd.read_csv("data/2023-08-06.log", header=None,
names=["timestamp", "g_x", "g_y", "g_z", "a_x", "a_y", "a_z"], parse_dates=["timestamp"])
t_diff = df["timestamp"].diff().dt.total_seconds()[1:]
diff_mean = t_diff.mean()
print(diff_mean, 1/diff_mean)
#> 0.0156 63.81
我在 24 小时内收集了 5,513,693 条记录;总文件大小约为 500 MBytes。正如我们所见,时间戳之间的平均差异为 0.015 秒,平均 fps 约为 64。它的稳定性如何?让我们创建一个时间差直方图:
t_diff = df["timestamp"].diff().dt.total_seconds()[1:]
h, bins = np.histogram(t_diff.values, bins=1024)
print(list(zip(h, bins))[:100])
#> [(159712, 0.010), (5349277, 0.015), (4134, 0.0199), (293, 0.02462),
#> (96, 0.0293), (28, 0.0339), (10, 0.0386), (7, 0.043), (21, 0.048),
#> ...
#> (1, 0.1650), (1, 0.1697), (0, 0.1743), (1, 0.1790), (0, 0.1837), ...]
# Convert X to milliseconds and normalize Y to 0..100%
n_total = sum(h)
h = 100*h/n_total
bins *= 1000
# Create the bar plot
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
ax.bar(bins[0:8], h[0:8], color='#440154', width=2.0)
ax.yaxis.label.set_color('gray')
ax.spines['left'].set_color('#DDDDDD')
ax.spines['right'].set_color('#DDDDDD')
ax.spines['top'].set_color('#DDDDDD')
ax.spines['bottom'].set_color('gray')
ax.xaxis.label.set_color('black')
ax.yaxis.label.set_color('gray')
ax.tick_params(axis='y', colors='#303030')
plt.xlabel("Timestamp difference, milliseconds")
plt.ylabel("Percentage of records")
plt.title("Timestamps accuracy")
plt.show()
我们可以看到,Raspbian 确实不是一个实时操作系统,但其准确性足够满足我们的任务:

时间戳差异(毫秒),图片由作者提供
5,349,277 条记录(超过 70%)具有约 0.015 秒(15 毫秒)的延迟,只有不到 50 条记录(0.001%)的间隔长于 0.01 秒(100 毫秒)。
2. 声谱图
让我们来看看更有趣的部分。显然,我们无法用肉眼分析 500 万条记录。让我们建立一个声谱图,以便查看频率域中是否存在一些异常:
def draw_sonogram(df_out: pd.DataFrame, t_start: datetime.datetime, t_end: datetime.datetime):
""" Draw a sonogram from the dataframe """
values = df_["g_y"].values
t_diff = df_['timestamp'].diff().dt.total_seconds()[1:].mean() # 0.015s ~ 50Hz
fig, ax = plt.subplots(1, 1, figsize=(24, 14))
ax.specgram(values, NFFT=256, Fs=1/t_diff, noverlap=50, scale="dB")
plt.ylabel('Frequency, Hz')
plt.xlabel('Time, sec')
plt.show()
draw_sonogram(df, datetime.time(9,0,0), datetime.time(10,0,0))
声谱图基于快速傅里叶变换(FFT),将值从“时间域”转换为“频率域”。声谱图上的最大频率约为 30 Hz,根据奈奎斯特定理,这相当于采样率的一半。手动计算可能需要很多工作,但 Matplotlib 的“specgram”方法为我们完成了所有工作。结果如下:

传感器数据的声谱图,图片由作者提供
如我们所见,图表上有一些斑点,但我居住的建筑物没有明显的共振频率。但对于其他类型的建筑(电动机、机器、桥梁等),这种类型的分析可能很有用。
3. 热图
如果我们想在振动数据中找到一些模式,查看时间轴上的信号幅度是有意义的。但记录太多,将它们绘制在一条线上效果不好。在这种情况下,热图会更好。
为了制作热图,我使用了三个步骤来预处理数据:
- 归一化。我提取了均值并取了绝对值:
df_ = df.copy()
df_["g_y_norm"] = (df_["g_y"] - df_["g_y"].mean()).abs()
- 对滚动周期取最大值。这部分稍微复杂一些。在振动数据中,可能会出现短暂的 1–2 秒峰值。这些小峰值在 24 小时时间轴上不可见,所以我决定使用以下代码进行“滚动最大值”处理:
N = 400
df_["g_y_roll"] = df_['g_y_norm'].rolling(N).max()
实际上,它看起来是这样的:

数据处理示例,图片由作者提供
在这个例子中,我将滚动的样本数设置为 400。我们将有一个比短小(紫色)峰值更大的点在热图上。
热图本身可以使用 Seaborn 的“heatmap”方法来显示。下面是包括预处理和绘制的完整代码:
import seaborn as sns
def draw_heatmap(df: pd.DataFrame):
""" Draw a heatmap from a dataframe """
# Normalization and applying the rolling maximum
df_ = df.copy()
N = 400
df_["g_y_norm"] = (df_["g_y"] - df_["g_y"].mean()).abs()
df_["g_y_roll"] = df_['g_y_norm'].rolling(N).max()
df_ = df_.iloc[::N, :] # Keep each Nth element
# Reshape all items to (24, N) matrix for heatmap
items_all = df_["g_y_roll"].values[2:]
items_per_hour = items_all.shape[0]//24
items_reshaped = items_all[:items_per_hour*24].reshape((24, -1))
# Horizontal labels
hor_ticks = 6
# Draw
fig, ax = plt.subplots(figsize=(30, 8))
sns.heatmap(items_reshaped, vmin=0, vmax=0.08,
cbar_kws={"orientation": "vertical", "pad": 0.01}, ax=ax)
ax.hlines(list(range(24)), *ax.get_xlim(), colors="#303030")
plt.xticks(rotation=0)
ax.set_xticks(np.linspace(0, items_per_hour, hor_ticks+1))
ax.set_xticklabels([10*n for n in range(hor_ticks+1)])
plt.title('MPU6050 Vibration Levels', fontsize=16)
plt.xlabel('Minutes', fontsize=12)
plt.yticks(rotation=0)
plt.ylabel('Hours', fontsize=12)
plt.show()
draw_heatmap(df)
最终图像如下:

MPU6050 振动数据热图,图片由作者提供
正如我们所见,将所有 24 小时的数据放在一张图中更加“说明性”。例如,我们可以很容易地看到交通振动,它发生在早上 9 点和下午 4 点到 6 点之间。
我没有实现亮度调整;可以通过调整 sns.heatmap 调用的 vmin 和 vmax 参数在代码中手动更改。数据本身没有缺口,缺失值的处理在这里没有实现。
4. 异常检测
在热图上,我们可以看到一些有趣的模式,比如晚间交通造成的振动。我们还可以看到一些明亮的白点——有趣的是知道它是什么。
为了检测这些“异常”,我们将尝试两种方法。首先,我们可以直接找到大于阈值的数据。作为第二种方法,我们可以使用像 Python Outlier Detection (PyOD) 这样的现成库。让我们测试一下这两种方法!
基于阈值的过滤 是直接的。作为阈值,我选择了一个大值(7 个标准差),因此随机获得这种振动水平的概率极小。至于过滤本身,Pandas 已经拥有所有需要的方法:
df_ = df.copy()
df_["g_y_norm"] = (df_["g_y"] - df_["g_y"].mean()).abs()
std_y = df_["g_y_norm"].std()
threshold = 7*std_y
df_filtered = df_.index[df_['g_y_norm'] >= threshold]
print(df_filtered)
# > [2087075, 2153277, 2153981, 2798119, 2800170, 2800171,
# > 3065854, 3065855,3065856, 3065858]
输出为数组索引。但有些项目过于接近;例如,索引 3065854 和 3065855 确实表示相同的事件。为了过滤数组,我创建了一个辅助方法来删除冗余项目:
def shrink_array(data: Any, num_samples: int) -> List:
""" Remove too close items from array. Example: [1, 2, 3, 10, 20] => [1, 10, 20] """
out = data[:1]
for val in data[1:]:
if val > out[-1] + num_samples:
out.append(val)
return out
indexes = shrink_array(df_filtered.values.tolist(), num_samples=500)
print(indexes)
# > [2087075, 2153277, 2153981, 2798119, 2800170, 3065854]
这里的参数“num_samples”用作标准;所有接近此值的数组项将从列表中删除。
现在,我们可以使用 Bokeh 显示结果:
from bokeh.layouts import gridplot
def make_plot(df_out: pd.DataFrame, title: str):
""" Show graph data """
timestamps = pd.to_datetime(df_out['timestamp'].values).to_pydatetime()
p = figure(title=title, x_axis_type='datetime',
x_axis_label='x', y_axis_label='y',
width=600, height=400)
p.line(timestamps, df_out["g_y"].values - df_out["g_y"].mean(),
legend_label="G_Y", line_width=1, color="blue")
p.xaxis.formatter=DatetimeTickFormatter(seconds="%H:%M:%S")
p.y_range = Range1d(-1.0, 1.0)
return p
plots = []
for ind in indexes:
plots.append(make_plot(df_[ind - 20:ind + 100], title=f"Index={ind}"))
show(gridplot(np.array_split(plots, 2)))
输出如下所示:

由阈值检测到的异常,图像由作者提供
作为本文的最后一步,让我们使用 Python Outlier Detection (PyOD) 库来寻找异常。这个库实现了 40 多种算法;我将展示其中之一,仅仅是为了让读者了解其工作原理。我将使用基于邻近的 KNN(k 最近邻)算法,它使用到第 k 个最近邻的距离作为异常值评分。
首先,我们需要使用一些数据来拟合算法。为此,我使用了之前找到的一个索引:
from pyod.models.knn import KNN
pos_train = 2087075
df_ = df[pos_train - 100000:pos_train + 100000]
fit_data = df_[["g_x", "g_y", "g_z"]].to_numpy()
clf = KNN(contamination=0.0001)
clf.fit(fit_data)
正如我们所见,使用 PyOD 和我的“天真”方法之间有两个主要区别。首先,PyOD 可以分析 多变量数据,因此我们可以使用传感器的所有三个轴。其次,根据我们的领域知识,我们需要指定一个 污染 率。我在寻找非常罕见且短暂的事件,这些事件可能每几个小时发生一次,因此我将此值设置为 0.0001。
当检测器训练完成后,我们可以简单地使用“predict”方法来处理另一组数据并获取结果:
pos_test = 2800170
df_test = df[["g_x", "g_y", "g_z"]][pos_test - 5000:pos_test + 5000]
data = df_test.to_numpy()
y_pred = clf.predict(data) # Outlier labels (0 or 1)
为了以可视化形式查看结果,让我们在同一图表上绘制输入和预测:
# Draw
x = np.arange(0, len(y_pred))
y = y_pred
y1 = df_test["g_x"]
y2 = df_test["g_y"]
y3 = df_test["g_z"]
p = figure(title="KNN anomaly detection results",
x_axis_label='x', y_axis_label='y',
width=1600, height=500)
p.line(x, 0.04*y, legend_label="Anomaly", line_width=2, color="gray")
p.line(x, y1, legend_label="g_x", line_width=2, color="green")
p.line(x, y2, legend_label="g_y", line_width=2, color="red")
p.line(x, y3, legend_label="g_z", line_width=2, color="blue")
show(p)
这里的红色、绿色和蓝色线条表示传感器数据,灰色线条是预测结果,小的峰值显示了检测到异常值的索引:

异常检测与 PyOD,作者提供的图片
它有效。正如之前所述,PyOD 中有 40 多种算法可供使用。欢迎有兴趣的读者自行测试其他算法。如果有人希望在相同的数据集上测试这些算法,请在下方评论,我会分享一个临时链接。
结论
在这篇文章中,我解释了如何将 MPU6050 MEMS 传感器连接到树莓派单板计算机,并在几天内收集来自建筑物的振动数据。然后我们以不同的方式分析了这些数据,比如在时间轴和热图上绘制原始数据、在频域中构建声谱图,并应用异常检测算法。这对于现代城市环境的研究可能很有趣;例如,交通引起的振动在图表上非常明显(我实际上很惊讶 MPU6050,这个传感器主要设计用于智能手机和游戏控制器,竟然能可靠地检测到如此微小的振动)。甚至有可能检测到如地震这样的稀有事件,尽管对于这种事件,拥有至少两个地点的数据会更可靠。振动分析也可以用于预测机器如电动机或涡轮机的故障。实际应用的可能性非常大。
总的来说,进行这个实验以及使用“真实”硬件和其数据是很有趣的。遗憾的是,我必须承认关于硬件和物联网中的数据科学和数据分析的文章和帖子数量微乎其微,不仅在 TDS 上如此,在其他网站上也是如此。我希望这个故事能稍微改变这种不平衡,向读者展示与硬件打交道也可以很有趣。
感谢阅读。如果你喜欢这个故事,欢迎 订阅 Medium,你将会收到我新文章发布的通知,并且可以完全访问来自其他作者的成千上万的故事。
在 R 中的探索性相关分析
无痛且友好的 tidyverse 相关分析,使用 rstatix
·
关注 发表在 Towards Data Science · 7 分钟阅读 · 2023 年 5 月 5 日
--
图片由 Armand Khoury 提供,来源于 Unsplash
相关分析是探索两个或更多变量之间关系的最基本且基础的方法之一。你可能已经使用 R 进行了某些相关分析,结果可能看起来像这样:
cor_results <- cor.test(my_data$x, my_data$y,
method = "pearson")
cor_results
输出可能如下所示:

这是使用基本 R 对你预先选择的两个变量进行简单相关分析的方法。
但如果你真的不知道自己在寻找什么怎么办?如果你仅仅是在进行一些探索性数据分析的阶段,你可能不知道自己感兴趣的变量是什么,或者可能想要寻找关联的地方。在这种情况下,能够选择一个感兴趣的变量,然后对比包含多个、甚至数百个变量的数据集,以找出进一步分析的良好起点,可能会很有帮助。由于rstatix包的开发者kassambara的帮助,现在有一种快速且相对无痛的方法来做到这一点。
获取数据
例如,我将使用来自世界银行的世界发展指标(WDI)数据集的数据——这是一个关于全球发展指标的开放访问数据存储库。我们可以从上述链接的网站访问 WDI,但也有一个适用于此的 R 包——
install.packages("WDI")
library(WDI)
可以使用 WDI()函数从 WDI 中导入特定的数据系列,但由于我们感兴趣的是涵盖大量变量之间可能关系的探索性分析,我将批量下载整个数据库……
bulk <- WDIbulk(timeout = 600)
假设我们有兴趣尝试找出与贸易量更多的国家(相对于其经济规模)可能相关的其他国家特征,并且我们也对 2020 年的数据感兴趣。
一旦我们确定了正确的变量(这里我将使用贸易占 GDP 的百分比),我们需要对数据进行一些清理。我们将创建一个可以过滤的年度系列列表,然后应用另一个过滤步骤,以确保我们只使用在分析中有足够观察值的变量(在下面的示例中,任意设置为 n>100)。
# Create a filtered set with only annual variables
filtered <- bulk$Series %>% filter(Periodicity == "Annual")
# Create a list of variables to correlate against trade levels
bulk$Data %>%
filter(Indicator.Code %in% c(filtered$Series.Code)) %>%
filter(year == 2020) %>%
group_by(Indicator.Code) %>%
filter(!is.na(value)) %>%
count() %>%
arrange(n) %>%
filter(n>100) -> vars_list
运行分析
现在我们有一个变量列表——大约 790 个——来查看哪些可能与我们的贸易水平变量相关。手动运行这些,或者用 R 的 cor.test()循环处理,将需要永远的时间。这时 rstatix 中的 cor_test()函数就显得非常重要——它运行得相当快,相关分析的输出被转储到 tibble 格式中(便于进一步的操作和分析),并且这些函数支持管道操作,意味着我们可以将过滤、变换和执行步骤组合到一个管道框架中,也可以为 rstatix 的分组输出组合变量输入(稍后我们将查看一些例子)。
因此,为了运行分析:
# Because WDI contains regional data as well, we'll create a list that only has country codes, and filter our input based on that list
countries <- bulk$Country %>% filter(!Region == "") %>% as_tibble()
bulk$Data %>%
filter(Indicator.Code %in% c(vars_list$Indicator.Code)) %>%
filter(year == 2020) %>%
filter(Country.Code %in% c(countries$Country.Code)) %>%
select(-Indicator.Name) %>%
pivot_wider(names_from = Indicator.Code,
values_from = value) %>%
cor_test(NE.TRD.GNFS.ZS,
method = "pearson",
use = "complete.obs") -> results
results
这会填充一个 tidy 的 tibble,其中包含变量配对、相关系数(r)、t 统计量、置信水平(p)以及低和高置信度估计。对于我们上面的示例运行,它看起来是这样的:

因为输出是一个 tibble,我们可以按照我们想要的方式对其进行排序和分解。让我们用变量名称和描述制作一个关键字,将其加入到我们的输出数据中,过滤掉只有在 p > 0.05 水平上显著的变量对,然后查看哪个变量具有最高的 r 值:
indicator_explanation <- bulk$Series %>% select(Series.Code, Indicator.Name, Short.definition) %>% as_tibble()
results %>%
left_join(indicator_explanation, c("var2" = "Series.Code")) %>%
arrange(desc(cor)) %>%
filter(p<0.05) %>%
View()

一些相关性最高的变量不会令人惊讶——例如,整体贸易在各国之间与服务贸易和商品贸易高度正相关。其他的可能更为意外——比如贸易水平与一个国家作为总资本形成百分比的官方发展援助(援助资金)(通常用作援助“依赖性”指标——上图中的底部行)之间的中等高正相关(r = 0.43)。
分组分析
那么,如果我们想深入研究这种关系呢?例如——如果我们查看 2020 年以外的其他年份,这种关系是否仍然强烈?这时,cor_test() 的管道友好特性再次显得非常有用。
让我们过滤掉初始数据,只包括我们感兴趣的两个指标,然后按年份对数据进行分组,再将其传递到 cor_test() 中:
bulk$Data %>%
filter(Indicator.Code %in% c("NE.TRD.GNFS.ZS", "DT.ODA.ODAT.GI.ZS")) %>%
filter(Country.Code %in% c(countries$Country.Code)) %>%
select(-Indicator.Name) %>%
filter(year<2021) %>%
pivot_wider(names_from = Indicator.Code,
values_from = value) %>%
group_by(year) %>%
cor_test(NE.TRD.GNFS.ZS, DT.ODA.ODAT.GI.ZS,
method = "pearson",
use = "complete.obs") -> results_time
这将给我们提供每年观察到的两个变量之间的相关性数据(我筛选了数据,只包括 2021 年之前的年份,因为 ODA 数据仅到 2020 年为止)。而且由于相关性数据以整洁的方式存储,我们可以轻松地运行附加代码来可视化我们的结果:
results_time %>%
mutate(`Significant?` = if_else(p<0.05, "Yes", "No")) %>%
ggplot(aes(x = year, y = cor)) +
geom_hline(yintercept = 0,
linetype = "dashed") +
geom_line() +
ylab("cor (r)") +
geom_point(aes(color = `Significant?`)) +
theme_minimal()
在这里我们可以看到,历史上这两个变量之间几乎没有任何关系(除了偶尔几年的弱负相关),但在过去几年中,相关性呈现出显著且正向的趋势:

那这意味着什么呢?就贸易和援助之间的任何潜在问题而言——我们需要做更多的研究。毕竟,C相关性并不意味着因果关系,但这是一个很好的假设生成器——接受援助的国家是否变得越来越注重贸易?还是援助分配的模式转向更倾向于那些贸易更多的国家?这些都是我们可以探索的新方向。这些快速的相关性分析可以成为趋势分析或信号发现的一个非常有用的工具——而拥有一个友好的 tidyverse 方法来完成这项工作确实避免了潜在的麻烦。
就我们快速、轻松地进行一些有用的探索性分析的能力而言,我们可以看到 rstatix 是一个有用的附加包。然而,rstatix 中的 cor_test() 也有一些缺点 —
-
与“correlation”包中提供的更多方法相比,你只能使用 Pearson (r)、Spearman (ρ) 和 Kendall (τ) 相关性方法。不过,这些方法对于普通用户来说是最常见的,应该足以满足基本分析需求。
-
置信区间仅在 Pearson’s r 的输出中报告。这意味着如果需要 Spearman’s rho 或 Kendall’s tau 的置信区间,则需要额外的代码。
-
样本大小和自由度未被报告,这可能会让用户感到烦恼,例如当用户的目标是基于不同分组的段落开发多个报告时。
但这些通常不适用于普通用户。此外,除了 cor_test() 外,rstatix 还提供了大量其他函数用于各种统计测试和程序,下次你需要进行一些探索性统计分析时,绝对值得深入了解一下这些函数——为开发者点赞。
注意:想更深入了解 rstatix 与其他 R 中相关性分析包之间的差异,感兴趣的读者可以查看: https://www.r-bloggers.com/2021/01/correlation-analysis-in-r-part-2-performing-and-reporting-correlation-analysis/
喜欢这个故事吗?关注我在 Medium 上,或者在LinkedIn或Twitter上与我联系。
Google Sheets 中的探索性数据分析
原文:
towardsdatascience.com/exploratory-data-analysis-in-google-sheets-5df4d0e4d2dd
比较 Google Sheets 和 Pandas 方法
·发布于 Towards Data Science ·8 分钟阅读·2023 年 7 月 14 日
--

图片由作者生成
使用现代工具如 Pandas 或 Jupyter 处理数据总是很愉快。但让我们想象一下,如果一个同事或朋友要求进行数据分析,但他或她不是技术人员,不使用 Python 或 Jupyter,也没有 Tableau、Power BI 或其他花哨(但遗憾的是不免费的)服务的账户。在这种情况下,使用 Google Sheets 处理数据可以是一个不错的变通方法,原因有几个:
-
Google 在全球范围内使用;在撰写本文时,已有超过 18 亿用户拥有 Google 账户。现在几乎每个人都有 Google 账户,文档共享将变得非常容易。
-
Google 的生态系统是安全可靠的。它支持双重身份验证和现代安全标准,即使是私人数据集也可以在有限的人群中共享。
-
最后但同样重要的是,这个解决方案是免费的,不需要额外的费用。作为额外的好处,Google Sheets 在浏览器中运行,不需要安装任何软件,并且可以在 Windows、Linux、OSX 或甚至智能手机等任何平台上使用。
在本文中,我将使用 Pandas 进行基本的探索性数据分析,然后我们将重复这一过程在 Google Sheets 中,看看效果如何。
数据来源
为了增加趣味性,我们来使用一个真实的数据集。我们将制作一个计算太阳能电池板生成的能量的工具。为此,我将使用 PVGIS(欧洲委员会光伏地理信息系统)数据,可以通过 这个 URL 免费访问(CC BY 4.0 许可):

PVGIS 接口,图片由作者生成
使用这个页面,我们可以下载太阳辐射数据,从而计算能量生成。正如截图所示,我们可以选择不同年份和不同地点的小时数据。下载数据后,让我们在 Pandas 中使用它。
Pandas 中的 EDA
让我们从 Pandas 中的探索性数据分析(EDA)开始。使用熟悉的工具总是更容易,这也能让我们验证结果。首先,让我们加载数据集:
import pandas as pd
import datetime
df_eu = pd.read_csv("EUTimeseries_53.087_5.859_SA2_60deg_120deg_2020_2020.csv",
skiprows=8).dropna()
display(df_eu)
代码不言自明。CSV 文件的开头有评论和空行,因此我使用了“skiprows=8”来跳过不需要的数据;这就是读取文件所需的唯一“调整”。
输出如下:

太阳辐射数据集,图片由作者提供
我们有 8784 行,代表每小时收集的数据。根据文档,“G(i)”是以瓦特/平方米为单位的太阳辐射;其他参数,如风速或温度,对我们的任务没有用。时间戳不是标准的,让我们将字符串值转换为日期和时间对象。我还需要将“G(i)”值从字符串转换为浮点数:
def str_to_date(d: str):
""" Convert string to datetime object """
try:
return datetime.datetime.strptime(d, '%Y%m%d:%H%M')
except:
return None
def str_to_float(f: str):
""" Convert string value to float """
try:
return float(f)
except:
return None
df_eu['time'] = df_eu['time'].map(str_to_date)
df_eu['G(i)'] = df_eu['G(i)'].map(str_to_float)
现在我们可以进行所需的计算。数据集包含以瓦特每平方米为单位的太阳辐射数据。数据是以每小时间隔收集的,我们只需将值除以 1000 就能将瓦特转换为千瓦时(kWh)。为了得到最终的千瓦时输出,我们还需要知道太阳能板的数量及每个面板的尺寸和效率(这些数据可以在太阳能板的数据表中找到):
panels_amount = 1
panel_size_m2 = 2.5
panel_efficiency = 0.18
df_eu["kWh"] = panels_amount * panel_size_m2 * panel_efficiency * df_eu['G(i)'] / 1000
现在我们可以进行一些数据探索。让我们找出每天的太阳能发电量,例如,夏天的六月一号。我将使用 Bokeh Python 库来绘制结果:
from bokeh.io import show, output_notebook
from bokeh.models import ColumnDataSource
from bokeh.plotting import figure
output_notebook()
df_day = df_eu[df_eu['time'].dt.date == datetime.date(2020, 6, 1)]
source = ColumnDataSource(df_day)
p = figure(width=1600, height=600, x_axis_type='datetime',
title="Solar Panels Generation Per Day, kWh")
p.vbar(x='time', top='kWh', width=datetime.timedelta(minutes=50), source=source)
p.xgrid.grid_line_color = None
p.xaxis.ticker.desired_num_ticks = 12
p.y_range.start = 0
p.y_range.end = 0.4
show(p)
输出如下:

六月的每日太阳能板发电量,图片由作者提供
计算总发电量也很简单:
print("Total, kWh:", df_day["kWh"].sum())
> Total, kWh: 1.5560864999999997
我们的太阳能板在六月每天产生了 1.56 千瓦时。相比之下,同一面板在十二月的发电量要低得多:

十二月的每日太阳能板发电量,图片由作者提供
让我们查看每年的发电量并计算总输出。为此,我将按月对数据框进行分组:
df_eu["month"] = df_eu["time"].dt.month
df_eu_month = df_eu[["month", "kWh"]].groupby(["month"], as_index=False).sum()
display(df_eu_month.style.hide(axis="index"))
输出如下:

按月分组的太阳能板发电量,图片由作者提供
作为最后一步,让我们查看图表:
source = ColumnDataSource(data=dict(months=df_eu_month["month"],
values=df_eu_month["kWh"]))
p = figure(width=1600, height=600,
title="Solar Panels Generation Per Year, kWh")
p.vbar(x='months', top='values', width=0.95, source=source)
p.xgrid.grid_line_color = None
p.xaxis.ticker.desired_num_ticks = 12
show(p)
输出:

每年太阳能板发电量,图片由作者提供
正如我们所见,冬季和夏季月份之间存在显著差异。
最后,让我们计算全年总电力生成量:
print("Total, kWh:", df_eu["kWh"].sum())
> Total, kWh: 335.38783499999994
我们的 2.5 平方米太阳能板总共产生了 335 千瓦时电力。
现在,让我们完成 Python 编码,看看如何在 Google Sheets 中实现相同的操作。
Google Sheets
一般来说,我们的数据处理流程将与在 Pandas 中的处理方式相同。我们需要加载数据集,转换列值,过滤和分组值以查看结果。实际上,当我们知道自己想做什么时,我们可以以“跨平台”的方式思考,这使得处理变得更加容易。
首先,让我们加载数据集。我在 Google Sheets 中创建了一个包含两个标签页的文档,“Source”和“Calculation”,并将 CSV 文件导入了“Source”标签页:

Google Sheets 中的数据集,图像由作者提供
现在我们需要按月份对值进行分组。让我们提取月份值从时间戳列。在 Pandas 中,我们是这样做的:
df_eu["time"] = df_eu['time'].map(str_to_date)
df_eu["month"] = df_eu["time"].dt.month
在 Google Sheets 中,我在 G10 单元格中添加了这个公式:
=ArrayFormula(MID(A10:A8793, 5, 2))
在这里,G10 是第一个单元格,结果将位于此处,A10:A8793 是我们的时间戳数据。“MID”函数从字符串中提取月份(“20200101:0011” 是非标准时间戳,使用子字符串更容易),而“ArrayFormula”方法会将这个函数自动应用到整个表格中。输入公式后,Google Sheets 会自动为我们创建一个新列。
同样,让我们从“G(i)”创建一个“kWh”列。在 Pandas 中,我们是这样做的:
df_eu["kWh"] = df_eu['G(i)'] / 1000
在 Google Sheets 中,它的工作方式几乎相同。我在 H10 单元格中添加了这个公式:
=ArrayFormula(B10:B8793/1000)
作为最后的预处理步骤,让我们将“Month”和“kWh”的名称输入为标题。结果应该是这样的:

在 Google Sheets 中添加的列,图像由作者提供
我们有一个“月份”列,现在我们可以按月份分组数据。在 Pandas 中称为“groupby”的操作,在 Google Sheets 中可以使用“数据透视表”来完成。让我们创建一个新的数据透视表,并将“Source!A9:H8793”作为数据源。在这里,“Source”是第一个标签页的名称,A9:H8793 是我们的数据。表格编辑器将自动检测列名,我们可以选择“Month”作为“行”并将“kWh”作为“值”:

按月份分组的 kWh 值,图像由作者提供
结果显示在截图中。实际上,它与我们在 Pandas 中做的非常接近:
df_eu_month = df_eu[["month", "kWh"]].groupby(["month"], as_index=False).sum()
我们的数据集已经准备好了;让我们进行所需的计算。在“Calculation”标签页中,我将使用前面三个单元格作为太阳能电池板的“变量”(B1 单元格),每个面板的尺寸(B2 单元格)和面板效率(B3 单元格)。然后我可以通过添加一个公式来计算总生成量:
=ArrayFormula(B1*B2*B3*Source!K11:K22)
“Source!” 是我们第一个包含数据源的标签页的链接,K11–K22 是按月份分组的数据所在的单元格。我们的新结果将放在 B7–B18 单元格中,我们还可以计算总生成量:
=SUM(B7:B18)
添加额外的标签和图表很容易;这不需要任何公式,我将在这里跳过这一部分。我们最终按月份分组的生成数据应该是这样的:

太阳能电池板每年的发电量,作者提供的图片
显然,结果必须与我们在 Pandas 中得到的结果相同;否则,就有问题。Google Sheets 的用户界面是互动的;我们可以更改太阳能电池板的数量或面板效率,Google Sheets 将自动重新计算所有结果。
我们最后的数据探索步骤是获取特定日期的电力生成。为此,我将把所需日期放入“A24”单元格,并使用两个单元格来显示结果。第一个单元格将包含时间,第二个单元格将包含能量值:
=FILTER(MID(Source!A10:A8793, 10, 4), SEARCH(A24, Source!A10:A8793))
=FILTER(B1*B2*B3*Source!H10:H8793, SEARCH(A24, Source!A10:A8793))
在这里,SEARCH(A24, Source!A10:A8793)是应用于源表的过滤器;第一个公式用于获取一天中的时间,第二个公式用于计算 kWh 的能量。可以选择添加标签和摘要。最终页面可能如下所示:

结论
在这篇文章中,我们在 Google Sheets 中导入了太阳能发电数据集,并能够计算和可视化不同的参数,如每月或特定日期的太阳能电池板发电量。这个表格是互动的,任何没有技术或编程技能的人都可以使用。最后但同样重要的是,这个解决方案没有成本,文档可以安全地与任何拥有 Google 账户的人分享。
显然,社区里对“Excel 中的数据科学”有很多玩笑,我并不鼓励任何人将 Google Sheets 作为主要的生产工具。但对于需要与他人共享结果或制作简单数据处理界面的简单场景,它可以是一个不错的补充。正如我们所见,基本操作如制作图表、数据分组或提取子字符串效果良好。
感谢阅读。如果你喜欢这个故事,欢迎订阅Medium,你将收到我新文章发布的通知,并且可以完全访问其他作者的成千上万的故事。
探索性数据分析:揭示数据集中的故事
原文:
towardsdatascience.com/exploratory-data-analysis-unraveling-the-story-within-your-dataset-6a8b1acdde
探索数据的秘密艺术——理解、清理和揭示数据集中的隐藏见解
Deepak Chopra | Talking Data Science
·发表于 Towards Data Science ·8 分钟阅读·2023 年 7 月 6 日
--

由 Andrew Neel 在 Unsplash 上拍摄的照片
作为数据爱好者,探索一个新的数据集是一项激动人心的工作。它让我们深入了解数据,并为成功分析奠定基础。对一个新数据集有一个良好的感觉并不总是容易的,需要时间。然而,一个好的、彻底的探索性数据分析(EDA)可以帮助你更好地理解你的数据集,感受数据之间的连接,以及需要做什么来正确处理你的数据集。
事实上,你可能会把 80% 的时间花在数据准备和探索上,只有 20% 用于实际的数据建模。对于其他类型的分析,探索可能会占用你更多的时间。
**什么是探索性数据分析。
探索性数据分析,简单来说,就是探索数据的艺术。 这是从不同角度调查数据的过程,以增强你的理解,探索模式,建立变量之间的关系,并在必要时增强数据本身。
就像和你的数据集去‘盲目’约会一样,坐在这个神秘的数字和文本集合对面,渴望在开始一段严肃的关系之前理解它。就像盲目约会一样,EDA 允许你揭示数据集的隐藏面貌。你观察模式,检测异常值,探索细微差别,然后再做出任何重大承诺。这完全是关于了解和建立信任,与数字建立稳固的基础,确保在得出结论之前你是稳固的。
我们都经历过;无论是有意还是无意,深入统计工具或筛选报告——我们都曾在某个时点探索过某种数据!
**为什么。
作为分析师和数据科学家,我们应该最好地理解数据。当涉及到理解和解释数据时,我们必须成为专家。无论是机器学习模型、实验框架还是简单的分析——结果取决于数据的质量。
记住,垃圾进,垃圾出!!
EDA 使数据分析师和科学家能够探索、理解和从数据中提取有意义的见解。就在你认为一切都已弄清楚时,数据集却给你来了个意外。你发现了缺失值、不一致性和混乱的数据。这就像发现你的约会对象有一只秘密的宠物短吻鳄或一系列独角兽雕像。探索性数据分析为你提供了清理混乱和理解一切的工具。
——这就像给你的数据集一个大改造,将它从杂乱无章变成一个光彩夺目的伙伴。
最终,探索性数据分析的核心在于深入了解你的数据,在过程中享受乐趣,并为进一步分析奠定坚实基础。所以戴上你的侦探帽,和你的数据集一起踏上这段激动人心的冒险之旅。谁知道呢,你可能会发现隐藏的宝藏甚至是真爱!
**如何。
探索性数据分析,顾名思义,就是对数据进行探索的分析。它包含了多个组件;这些组件并非所有时候都是必需的,也并非所有组件都有同等重要性。以下,我将根据我的经验列出一些组件。
请注意,这绝不是详尽无遗的列表,而是一个指导框架。
1. 理解数据的现状。
你不知道自己不知道什么——但你可以探索!
首先要做的是感受数据——查看数据条目,观察列值。你有多少行,多少列。
-
一个零售数据集可能会告诉你——X 先生在 2023 年 8 月 1 日访问了 2000 号商店,并购买了一罐可乐和一包沃克脆片
-
一个社交媒体数据集可能会告诉你——Y 女士在 6 月 3 日早上 09:00 登录社交网站,浏览了 A、B 和 C 板块,搜索了她的朋友 A 先生,然后在 20 分钟后注销。
了解你拥有的数据的业务背景,了解数据的来源和收集机制是有益的;例如,调查数据与数字收集数据等。
2. 深入探讨变量
变量是数据集的“语言”,它们在不断与你交流。你只需要提出正确的问题,并仔细倾听。
→ 要问的问题:: - 变量的含义是什么?
-
这些变量是连续的还是分类的?.. 是否有固有的顺序?
-
变量可能取什么值?
→ 行动::
-
对于连续变量——使用直方图、箱线图检查分布,并仔细研究均值、中位数、标准差等。
-
对于分类/有序变量——找出它们的唯一值,并进行频率表检查最常见/最少见的值。
你可能无法理解所有变量、标签和数值——但尽量获取尽可能多的信息
3. 查找数据中的模式/关系
通过 EDA,你可以发现数据中的模式、趋势和关系。
→ 需要问的问题:: *- 你是否对变量之间的关系有任何先前的假设/假设?
-
某些变量之间有业务上的关联理由吗?
-
变量是否遵循特定的分布?*
数据可视化技术、总结和相关性分析有助于揭示初看不明显的隐藏模式。理解这些模式可以为决策制定或假设生成提供有价值的见解。
→ 行动:: 思考双变量视觉分析。
-
对于连续变量——使用散点图、创建相关矩阵/热图等。
-
对于混合连续变量和有序/分类变量——考虑绘制条形图或饼图,并创建经典的列联表以可视化共现情况。
EDA(探索性数据分析)允许你验证统计假设,例如正态性、线性或独立性,以进行分析或数据建模。
4. 检测异常。
这是你成为数据上的福尔摩斯并寻找任何异常的机会!问问自己::
- 数据集中是否有重复条目?
重复项是指多次表示相同样本点的条目。在大多数情况下,重复项没有用处,因为它们不会提供任何额外的信息。它们可能是错误的结果,并且可能会干扰你的均值、中位数和其他统计数据。
→ 与你的利益相关者确认,并从数据中删除这些错误。
- 分类变量的标记错误?
查找分类变量的唯一值并创建频率图。查找拼写错误和可能表示相似事物的标签?
- 是否有变量缺失值?
这可能发生在数值和分类变量中。检查是否
-
是否有在很多变量(列)中缺失值的行? 这意味着有些数据点在大多数列中都是空白的 → 它们的用处不大,我们可能需要删除这些行。
-
是否有在多行中缺失值的变量(或列)? 这意味着有些变量在大多数数据点中没有值/标签 → 它们对我们的理解贡献不大,我们可能需要删除这些变量。
→行动::
计算所有变量的 NULL 或缺失值的比例。超过 15%-20%的变量应引起你的怀疑。
过滤掉某列中缺失值的行,并检查其余列的情况。是否大多数列一起有缺失值?..是否有模式?
- 我的数据集中是否存在异常值?
异常值检测是关于识别那些不符合常规的数据点。你可能会看到某些数值变量的非常高或极低的值,或者分类变量的高频/低频。
-
看似异常值的可能是数据错误。 虽然异常值是对于给定特征分布来说不寻常的数据点,但不需要的条目或记录错误是那些本来不应该存在的样本。
-
看似异常值的可能只是异常值。 在其他情况下,我们可能只是有一些极端值的数据点,并且背后有完全合理的解释。
→行动步骤::
研究直方图、散点图和频率条形图,以了解是否有一些数据点与其余数据点相距较远。思考:
这些值是否可能是真的,并且符合这些极端值?
对于这些极端值是否有业务上的理由或解释?
这些在后续阶段会对你的分析有价值吗?
5. 数据清洗。
数据清洗指的是从数据集中移除不需要的变量和值,并消除其中的任何不规则性。这些异常可能会不成比例地扭曲数据,从而对我们从该数据集中得出的分析结果产生不利影响。
记住:垃圾进,垃圾出。
- 纠正你的数据。
-
删除任何你发现的重复条目、缺失值和异常值——这些都没有为你的数据集增加价值。去除不必要的行/列。
-
纠正数据中你观察到的任何拼写错误或标签错误。
-
你发现的任何没有增加数据价值的数据错误也需要被移除。
- 截断异常值或保持现状。
- 在一些数据建模场景中,我们可能需要对异常值进行截断。截断通常在高端的第 99/95 百分位或低端的第 1/5 百分位进行。
- 处理缺失值。
我们通常会丢弃那些在变量中有很多缺失值的数据点(行)。同样,我们会丢弃那些在大量数据点中有缺失值的变量(列)。
如果有一些缺失值,我们可以考虑填补这些空缺,或者保持现状。
-
对于有缺失值的连续变量,我们可以通过使用均值或中位数(可能在特定分层中)来填补这些缺失值。
-
对于分类缺失值,我们可能会分配最常用的“类别”或创建一个新的“未定义”类别。
- 数据丰富化。
根据未来分析的需要,你可以向数据集中添加更多的特征(变量);例如(但不限于)
-
创建指示某事物存在或不存在的二元变量。
-
通过使用 IF-THEN-ELSE 子句创建额外的标签/类别。
-
根据未来分析的需求来缩放或编码你的变量。
-
结合两个或多个变量——使用各种数学函数,如求和、差异、均值、对数以及其他许多变换。
总结
EDA 使数据科学家能够发现有价值的见解,解决数据质量问题,并为进一步的分析和建模奠定坚实的基础。它确保数据分析的结果是可靠、准确且具有影响力的。
EDA 的关键组件:
-
了解数据的来源和“含义”。
-
了解所有变量及其分布、标签/类别。
-
寻找变量之间的模式/关系,以验证任何先前的假设或假定。
-
发现任何异常——数据错误、离群值、缺失值。
-
数据清理——删除或修正任何数据错误/异常,处理离群值,填补缺失值(如有需要),缩放/变换现有变量,并创建额外的衍生变量,丰富你的数据集,以便后续分析。
连接、学习与成长 ..
如果你喜欢这篇文章并对类似内容感兴趣,可以在 Medium、LinkedIn、与我 1:1 联系、加入我的邮件列表 上关注我,(如果你还没有的话),快来成为 Medium 家庭的成员,以获取数千篇有用的文章。(如果你使用以上链接,我将获得你会员费用的 ~50%)
.. 继续学习,继续成长!
探索性数据分析:我们对 YouTube 频道了解多少(第一部分)
原文:
towardsdatascience.com/exploratory-data-analysis-what-do-we-know-about-youtube-channels-3688c5cbc438
使用 Pandas 和 YouTube 数据 API 获取统计见解
·发表于Towards Data Science ·20 分钟阅读·2023 年 10 月 28 日
--

照片由 Glenn Carstens-Peters 拍摄,Unsplash
如今,活跃的 YouTube 用户超过 27 亿,对于很多人来说,YouTube 不仅仅是娱乐,更是重要的收入来源。但它是如何运作的呢?不同的 YouTube 频道可以获得多少观看次数或订阅者?借助 Python、Pandas 和 YouTube 数据 API,我们可以获得一些有趣的见解。
方法论
本文将分为几个部分:
-
使用 YouTube 数据 API。通过这个 API,我们将能够获取不同搜索请求的 YouTube 频道列表。对于每个频道,我们将获得有关视频数量、观看次数和订阅者的信息。
-
获取我们感兴趣的频道列表。这只能完成一次。
-
收集频道数据。为了获得统计见解,我们需要在一段时间内收集数据。
-
数据分析。
不再赘述,让我们开始吧。
1. YouTube 数据 API
首先,对于所有对从大型网络如 YouTube 收集数据感兴趣的人来说,有一个好消息:YouTube API 是免费的,我们无需支付费用。要开始使用这个 API,我们需要两个步骤:
- 打开
console.cloud.google.com并创建一个新项目。我之前在那里有一个旧项目,但在一段时间不活动后,它的所有 API 限制都被重置为零,我找不到重置的方法。因此,创建一个新项目更为简单。

Google Cloud Console,图片由作者提供
- 前往“API 和服务”并启用“YouTube 数据 API”。打开 API,进入“凭据”并创建一个 API 密钥。如果一切设置正确,配额页面将如下所示:

YouTube API 配额,作者提供的图片
就这样;之后,我们可以开始发起 API 请求以获取 YouTube 数据。至于限制,免费配额为每天 10,000 次查询。计算这个配额有点复杂,因为它基于“内部”YouTube 查询,而不仅仅是 API 调用的数量。搜索请求是“重”的,例如,获取关于“智能手机评测”这一短语的 500 个频道的列表将消耗约 7,000 个“单位”。因此,我们每天只能用一个 API 密钥进行一次这样的搜索。但免费层允许我们拥有12 个项目,每个项目有单独的配额。所以任务比较简单,但我们仍需要将请求数量合理限制在一定范围内。
数据收集管道将包括两种类型的 API 调用:
-
首先,我们将创建一个关于不同主题的 YouTube 频道列表。这只需要做一次。
-
其次,我们可以获取每个频道的观看次数和订阅者数。我将使用 Apache Airflow 来运行这个任务,至少运行一周,每天两次。
2. 获取 YouTube 频道
在第一步中,我们启用了 YouTube API。现在,让我们创建一个我们感兴趣的频道列表。为了进行搜索,我将使用 python-youtube 库的 search_by_keywords 方法。作为示例,查询“猫”的输出如下所示:
{
"kind": "youtube#searchListResponse",
"etag": "h_RGyvb98m0yrxBgG0Q21J0ch94",
"nextPageToken": "CAIQAA",
"regionCode": "UK",
"pageInfo": {
"totalResults": 19544,
"resultsPerPage": 10
},
"items": [
{
"kind": "youtube#searchResult",
"etag": "N6_OLAdw4hCq2.....",
"id": {
"kind": "youtube#channel",
"channelId": "UCoV0b7wU....."
},
"snippet": {
"publishedAt": "2016-11-07T04:54:33Z",
"channelId": "UCoV0b7....",
"title": "1 stoner 3 cats",
"description": "MUST BE 18 OR OLDER FOR THIS CHANNEL...",
"thumbnails": {
"default": {
"url": "https://yt3.ggpht.com/ytc/APkrFKZKfv..."
},
"medium": {
"url": "https://yt3.ggpht.com/ytc/APkrFKZKfv..."
},
"high": {
"url": "https://yt3.ggpht.com/ytc/APkrFKZKfvuGIwwg..."
}
},
"channelTitle": "1 stoner 3 cats",
"liveBroadcastContent": "upcoming",
"publishTime": "2016-11-07T04:54:33Z"
}
},
...
],
"prevPageToken": null
}
在这里,我们关注 title、channelId 和 publishedAt 参数。我们还可以看到 totalResults 值,这个值等于 19544。不过,遗憾的是,YouTube API 是为终端用户而设计的,而不是为了分析。我们不能获取所有关于“猫”的 YouTube 频道;这个 API 仅返回由 YouTube 推荐系统生成的 400-500 个频道的列表。
我们可以使用一个简单的程序,该程序针对特定短语进行 YouTube 查询并将结果保存到 CSV 文件中:
import datetime
import logging
from pyyoutube import Api # pip3 install python-youtube
def save_log(log_filename: str, s_data: str):
""" Save string to the log file """
with open(log_filename, "a", encoding="utf-8") as log_out:
log_out.write(s_data + "\n")
def search_by_keywords(api: Api, search_str: str, page_token: str):
""" Get YouTube channels list for a search phrase """
count = 10
limit = 25000
parts = ["snippet"]
res = api.search_by_keywords(q=search_str, limit=limit, count=count,
region_code="UK",
relevance_language="en",
search_type="channel",
order="title",
page_token=page_token, parts=parts,
return_json=True)
return res
def get_channels(api: Api, search_str: str):
""" Get YouTube channels list and save results in CSV file """
time_str = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
log_file = f"{search_str.replace(' ', '-')}-{time_str}.csv"
logging.debug(f"Log file name: {log_file}")
save_log(log_file, "channelId;publishedAt;title")
res = search_by_keywords(api, search_str, page_token=None)
next_page_token = res["nextPageToken"]
num_items = 0
while next_page_token is not None:
for item in res['items']:
title = item['snippet']['title'].replace(";", " ").replace(" ", " ")
description = item['snippet']['description'].replace(";", " ").replace(" ", " ")
log_str = f"{item['id']['channelId']};{item['snippet']['publishedAt']};{title} {description}"
logging.debug(log_str)
save_log(log_file, log_str)
num_items += 1
next_page_token = res["nextPageToken"]
logging.debug(f"{num_items} items saved to {log_file}")
res = search_by_keywords(api, search_str, page_token=next_page_token)
next_page_token = res["nextPageToken"]
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG,
format='[%(asctime)-15s] %(message)s')
key1 = "XXXXX"
youtube_api = Api(api_key=key1)
get_channels(youtube_api, search_str="cats")
作为输出,我们将获得如下所示的 CSV 文件:
channelId;publishedAt;title
UCoV0b7wUJ2...;2016-11-07T04:54:33Z;1 stoner 3 cats MUST BE ...
UCbm5zxzNPh...;2013-08-07T12:34:48Z;10 Cats ...
UCWflB-GzVa...;2013-09-25T10:39:41Z;13 Cats - Topic ...
UCiNQyjPsO9-c2C7eOGZhYXg;2023-10-09T22:51:37Z;2 CATS NO RULES ...
现在,我们可以使用不同的查询进行搜索。这只能做一次;频道 ID 不会改变。为了本文的目的,我使用了这些查询:
-
“猫”
-
“Dogs”
-
“化妆教程”
-
“摄影”
-
“智能手机评测”
-
“街头摄影”
结果是,我在 CSV 文件中保存了一个频道列表(每个查询大约 500 条记录),总共有大约 3000 个 YouTube 频道。
3. 获取频道详细信息
下一步,我们需要获取每个频道的统计数据。为此,我将使用相同 python-youtube 库中的 get_channel_info 方法:
def get_channel_info(api: Any,
file_out: str,
channel_id: str,
channel_title: str) -> int:
""" Get YouTube channel statistics """
res = api.get_channel_info(channel_id=channel_id, parts=["statistics"], return_json=True)
n_count = 0
if "items" in res:
time_str = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
for item in res["items"]:
ch_id = item["id"]
statistics = item["statistics"]
views = statistics["viewCount"]
subscribers = statistics["subscriberCount"]
videos = statistics["videoCount"]
s_out = f"{time_str};{ch_id};{channel_title};{views};{subscribers};{videos}"
logging.debug(f"Saving: {s_out}")
save_log(file_out, s_out)
n_count += 1
return n_count
可以这样使用这种方法:
api = Api(api_key="...")
get_channel_info(api, "cats_09_24.csv",
channel_id="UCbm5zxzNPh...",
channel_title="CATS NO RULES Its a Cats Life")
作为输出,我们将获得一个包含所需值的 CSV 文件:
timestamp;channelId;title;views;subscribers;videos
2023-10-09-19-42-19;UCoV0b7wUJ2...;1 stoner 3 cats MUST BE ...;14;2;6
2023-10-09-19-42-19;UCbm5zxzNPh...;CATS NO RULES Its a Cats Life;24;5;3
数据收集 现在,我们知道如何获取 YouTube 频道列表以及如何获取频道详细信息,例如观看次数和订阅者数量。但查看这些值的动态和它们如何随时间变化是很有趣的。YouTube 有一个单独的Analytics API,可以用于报告。然而,正如 API 文档中所写,“授权请求的用户必须是频道的拥有者”,因此对我们的任务来说是无用的。我们唯一的方法是收集一段时间的数据;1-2 周看起来是一个很好的时间段。
数据收集可以通过不同方式进行,我决定使用Apache Airflow,并在我的树莓派上安装了它。事实证明,树莓派是一个出色的数据科学工具,用于数据收集,我已经在几个爱好项目中使用过它。这台$50 的单板计算机仅消耗 2W 功率,静音,没有风扇,并且在 4 核 CPU 上运行完整的 Ubuntu。Raspbian OS 的配置细节超出了本文的范围;有兴趣的读者可以阅读我之前的 TDS 文章:
## 在树莓派上使用 Apache Airflow 进行数据收集
一台树莓派就是你所需要的一切
towardsdatascience.com
4. 探索性数据分析
预处理
最后,我们即将进入本文的有趣部分:让我们看看从收集的数据中可以获得什么样的见解。我将使用 Pandas 进行数据处理,使用 Matplotlib 和 Seaborn 绘制图表。
首先,让我们加载之前收集的数据。文件可以使用scp命令从树莓派中复制(这里,10.14.24.168是设备地址,“pi”是标准的 Raspbian 用户名):
scp pi@10.14.24.168:/home/pi/airflow/data/*.csv data
Apache Airflow 每天执行代码两次,每次运行后保存一个带时间戳的 CSV 文件。一周后,我得到了一堆大约 80K 记录的 CSV 文件。让我们加载所有文件,并将它们合并到 Pandas 数据框中:
import glob
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
channel_files = glob.glob("data/channel*.csv")
channels_data = []
for file_in in channel_files:
channels_data.append(pd.read_csv(file_in, delimiter=';',
parse_dates=['timestamp'],
date_format="%Y-%m-%d-%H-%M-%S"))
df_channels = pd.concat(channels_data)
结果如下所示:

带有时间序列数据的数据框,图片由作者提供
作为提醒,在文章开头,我还收集了用于不同搜索请求的频道列表(“智能手机”、“猫”、“狗”等)。让我们将这个列表加载到第二个数据框中:
def load_channels(files: List, subject: str) -> pd.DataFrame:
""" Load and combine dataframe from several files """
dataframes = []
for csv in files:
df = pd.read_csv(csv, delimiter=";", parse_dates=["publishedAt"])
df["subject"] = subject
dataframes.append(df)
return pd.concat(dataframes).drop_duplicates(subset=["channelId"])
smartphones = load_channels(["smartphone-channels.csv"], subject="Smartphones")
dogs = load_channels(["dogs-channels.csv"], subject="Dogs")
cats = load_channels(["cats-channels.csv"], subject="Cats")
...
channels_all = pd.concat([smartphones, makeup, photography,
streetphotography, cats,
dogs]).drop_duplicates(subset=["channelId"])
加载频道列表可以实现自动化,但我只有 6 个类别,所以直接硬编码它们非常简单。我还添加了一个“subject”列来保存类别名称(重要的是要提到,“subject”不是由频道拥有者给出的“官方”频道类别,而是在搜索请求中使用的名称)。
此时,我们有两个 Pandas 数据框:一个包含基本频道数据(ID、标题和创建日期),另一个包含时间序列数据,包括观看次数、视频数量和订阅者数。让我们将这两个数据框 合并 在一起,以 channelId 作为键:
df_channels = df_channels.merge(
channels_all[["channelId", "publishedAt", "subject"]],
on=['channelId'],
how='left')
现在,我们准备开始有趣的部分了!让我们用 Seaborn 和 Matplotlib 可视化不同类型的数据。
4.1 观看次数和订阅者数量 作为热身,让我们按观看次数排序 YouTube 频道:
df_channels_ = df_channels.drop_duplicates(subset=["channelId"]).sort_values(by=['views'], ascending=False).copy()
df_channels_["views"] = df_channels_["views"].apply(lambda val: f"{val:,.0f}")
df_channels_["subscribers"] = df_channels_["subscribers"].apply(lambda val: f"{val:,.0f}")
display(df_channels_)
结果如下:

YouTube 频道,按观看次数排序,作者提供的图片
我们可以看到数值之间的差异非常大。列表中的顶级频道拥有数十亿的观看次数和数百万的订阅者。实际的数字大到我不得不在列中添加了上千个“,”分隔符!
题外话,我为什么不使用 Pandas Styler 对象呢?确实,这样写代码很简单:
display(df_channels_.style.format(thousands="."))
结果显示在小型数据框上效果很好。但至少在 Visual Studio Code 中,改变样式后,数据框不再显示为头、尾和“…”了,Visual Studio 总是显示所有 3030 行。如果有人知道解决方案,请在下面的评论中写出来。
看到数据框很不错,但图形形式会更清晰。让我们使用条形图绘制 观看次数:
decimation = 10
df_channels__ = df_channels_.reset_index(drop=True).iloc[::decimation, :]
sns.set(rc={'figure.figsize': (18, 6)})
sns.set_style("whitegrid")
fig, ax = plt.subplots()
sns.barplot(df_channels__, x=df_channels__.index, y="views", width=0.9, ax=ax)
ax.set(title='YouTube channels views',
xticks=range(0, df_channels__.shape[0], 50),
ylim=(0, None),
xlabel='Channel №',
ylabel='Views Total')
ax.ticklabel_format(style='plain', axis="y")
ax.yaxis.set_major_formatter(FuncFormatter(lambda x, p: format(int(x), ',')))
sns.despine(left=True)
plt.show()
绘图很简单,但需要一些小调整。再次,我使用了 FuncFormatter 添加“,”千位分隔符;否则,数字太大,不便于阅读。我还添加了 decimation=10 参数来减少数据框中的记录数;否则,垂直条形图太小了。尽管如此,我们可以看到该区域几乎是空的:

显然,通过使用 ylim 参数很容易调整纵坐标刻度,但我特意将其保留为这样,以便读者能够看到“顶级”频道与“其他”频道之间的真实差异。分布非常偏斜。几个顶级频道的观看次数达到数十亿,而与之相比,其他频道几乎不可见。在我列出的约 3000 个频道中,前 5%的频道拥有 95%的总观看次数。
我们还可以绘制 订阅者数量,其形状与之前的图表相同:

让我们使用 百分位数 获取更准确的数据:
display(df_channels_[["views", "subscribers"]].quantile(
[0.01, 0.05, 0.25, 0.5, 0.75, 0.95, 0.99]
).style.format("{:,.0f}"))
输出看起来是这样的:

四分位数数据,作者提供的图片
50 百分位数(或 0.5 分位数)是一个数字,显示了 50%的所有值低于该数字。例如,所有订阅者值的 50 百分位数仅为 16。这意味着尽管顶端的数字如同 Googleplex 一般,列表中 50%的频道订阅者数少于 16!这可能令人惊讶,但我们可以通过按订阅者数量对数据框进行排序并查看中间值来轻松验证这一点:
df_channels_ = df_channels.drop_duplicates(subset=["channelId"]).sort_values(by=['subscribers'], ascending=False).reset_index(drop=True)
display(df_channels_[df_channels_.shape[0]//2:])
结果确认了上表的正确性:

数据框的中间部分,作者提供的图片
所有这些数值可以让我们对预期的观看次数和订阅者数量有一个大致的了解。但是在这里,我只分析了我收集的 3030 个频道。我们能得到一个YouTube 频道总数,比如说,1 百万和 10 万订阅者的频道总数吗?我没有找到答案,这可能是 YouTube 的秘密,就像 Tinder 上男女用户的真实比例一样;)显然,YouTube 推荐系统有一个将“顶级”频道和“其他”频道混合在搜索结果中的算法,给新手提供了被观众看到的机会。
4.2 每注册日期的订阅者数量 了解某个 YouTube 频道是否有 1,000,000 次观看或订阅者很有趣,但频道主们多快能达到这个值呢?在 YouTube Data API 中,每个频道都有一个“publishedAt”参数,代表频道的创建日期。我们无法获取特定频道的历史数据,但我们可以通过散点图比较不同创建日期的频道。我还将用不同颜色区分不同类别,并添加平均线。
upper_limit = 1_000_000
df_channels_ = df_channels.drop_duplicates(subset=["channelId"]).copy()
df_channels_["subscribers_clipped"] = df_channels["subscribers"].clip(upper=upper_limit)
sns.set(rc={'figure.figsize': (18, 8)})
sns.set_style("white")
palette = sns.color_palette("bright")
fig, ax = plt.subplots()
# Add scatter plot and average lines
for ind, subj_str in enumerate(df_channels_["subject"].unique()):
df_subj = df_channels_[df_channels_["subject"] == subj_str]
# Draw scatter plot
markers = ["o" , "s" , "p" , "h"]
sns.scatterplot(data=df_subj, x="publishedAt", y="subscribers_clipped",
color=palette[ind],
marker=markers[ind % len(markers)],
label=subj_str,
ax=ax)
# Draw average
col_avg = df_subj["subscribers"].mean()
linestyles = ["--", ":", "-."]
linestyle = linestyles[ind % len(linestyles)]
ax.axhline(col_avg, color=palette[ind], label=subj_str + " Avg", linestyle=linestyle, linewidth=1.0, alpha=0.6)
ax.set(title='Channel Subscribers',
xlabel='Registration Date',
ylabel='Subscribers',
ylim=(0, upper_limit)
)
ax.ticklabel_format(style='plain', axis="y")
ax.xaxis.set_major_locator(mdates.MonthLocator(interval=12))
ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y"))
ax.yaxis.set_major_formatter(FuncFormatter(lambda x, p: format(int(x), ',')))
ax.spines['top'].set_color('#EEEEEE')
ax.spines['right'].set_color('#EEEEEE')
ax.spines['bottom'].set_color('#888888')
ax.spines['left'].set_color('#888888')
plt.legend(loc='upper right')
plt.show()
相比于之前的柱状图,结果提供了更多的信息:

订阅者数量分布,作者提供的图片
100 万订阅者是许多 YouTube 频道的一个“标志性”数值,我将这个值设置为图表的裁剪限制。我们可以看到,我列表中“最年轻”的 YouTube 频道在 2022 年初达到了这一点,所以频道主们花了将近两年时间才做到这一点(此分析在 2023 年底进行)。与此同时,还有一些“老”频道,甚至在 2010 年之前创建的,今天仍然没有达到 10 万订阅者。
关于平均值,它们也很有趣。我们可以看到,订阅“智能手机”相关频道的人更多,第二大热门类别是“化妆”。让我们再“放大”一下图表:

订阅者数量分布,作者提供的图片
在这里,我们可以看到,“猫”和“狗”类别的受欢迎程度平均要低得多(几乎低 10 倍)。“摄影”和“街头摄影”类别更加小众,即使获得 10 万订阅者对这些频道来说也可能是一个具有挑战性的目标。
4.3 每个视频的订阅者数量 这个问题对于那些想要开始自己 YouTube 频道的人可能很有趣。应该发布多少视频才能获得一定数量的观看或订阅者?我们知道每个频道的视频和订阅者数量,可以通过使用一个散点图来找到答案。我还会使用线性回归模型来绘制平均线:
from sklearn.linear_model import LinearRegression
import numpy as np
df_channels_ = df_channels.drop_duplicates(subset=["channelId"]).copy()
upper_limit = 100_000
right_limit = 1000
sns.set(rc={'figure.figsize': (18, 8)})
sns.set_style("white")
num_subjects = df_channels_["subject"].nunique()
palette = sns.color_palette("bright")
fig, ax = plt.subplots()
for ind, subj_str in enumerate(df_channels_["subject"].unique()):
# Filter by subject
df_subj = df_channels_[df_channels_["subject"] == subj_str].sort_values(by=['subscribers'], ascending=False)
# Draw scatter plot
markers = ["o" , "s" , "p" , "h"]
sns.scatterplot(data=df_subj, x="videos", y="subscribers",
color=palette[ind],
# palette=[palette[ind],
# hue="subject",
marker=markers[ind % len(markers)],
label=subj_str,
ax=ax)
# Make linear interpolation
df_subj = df_subj[10:] # Optional: remove top channels to exclude "outliers"
values_x = df_subj["videos"].to_numpy().reshape((-1, 1))
values_y = df_subj["subscribers"].to_numpy()
model = LinearRegression().fit(values_x, values_y)
x_val = np.array([0, right_limit])
y_val = model.predict(x_val.reshape((-1, 1)))
# Draw
linestyles = ["--", ":", "-."]
ax.axline((x_val[0], y_val[0]), (x_val[1], y_val[1]),
linestyle=linestyles[ind % 3], linewidth=1,
color=palette[ind], alpha=0.5,
label=subj_str + " Avg")
ax.set(title='YouTube Subscribers',
xlabel='Videos In Channel',
ylabel='Subscribers',
xlim=(0, right_limit),
ylim=(0, upper_limit)
)
ax.yaxis.set_major_formatter(FuncFormatter(lambda x, p: format(int(x), ',')))
ax.spines['top'].set_color('#EEEEEE')
ax.spines['right'].set_color('#EEEEEE')
ax.spines['bottom'].set_color('#888888')
ax.spines['left'].set_color('#888888')
plt.legend(loc='upper right')
plt.show()
在这里,我将值限制为 100,000 个订阅者和 1,000 个视频。我还从线性插值中排除了前 10 个频道,以使平均结果更具现实性。
输出结果如下:

订阅者数量在 0–100K 范围内,图像由作者提供
我们再次看到,“化妆”和“智能手机”频道每个视频获得的订阅者最多。“猫”和“狗”的平均线几乎是水平的。这是怎么回事?首先,正如我们在前一张图片中看到的,这一类别的订阅者平均数量通常较低。其次,我猜测更多的人发布关于猫和狗的视频,分布更为倾斜。
那么分布的顶部怎么样?好吧,那里有足够的频道,订阅者超过 1M,视频少于 1000 个:

订阅者数量在 0–10M 范围内,图像由作者提供
我认为这些是配备高端电影设备的专业工作室,预算也相当高。那分布的低部分怎么样?让我们看看另一张图:

订阅者数量在最低范围内,图像由作者提供
我惊讶于看到一些 YouTube 频道有 1,000–5,000 个视频,但只有 10–50 个订阅者。结果是,这些频道很可能是由机器人自动生成的;它们只有播放列表,没有视频,大多数没有观看量,也没有订阅者。这些频道的目的是什么?我不知道。一些其他频道属于真实的人,看到有人发布了超过 1000 个视频,每个视频每年只有 10–20 次观看,这有点令人遗憾。
4.4 频道动态 — 每日观看量 正如我们所知,使用公开的 YouTube API,我们只能获得当前时刻的观看量和订阅者数量,只有频道所有者才能获得历史数据。作为替代方案,我使用 Raspberry Pi 和 Apache Airflow 收集了一周的数据。现在,是时候看看我们能得到什么了。
在这种情况下,处理起来有点棘手。我需要获取每个频道的数据,按时间戳排序,并计算差值:
channels_data = []
channel_id = ...
df_channel_data = df_channels[df_channels["channelId"] == channel_id][["timestamp", "views", "subscribers", "videos"]].sort_values(by=['timestamp'], ascending=True).reset_index(drop=True).copy()
df_first_row = df_channel_data.iloc[[0]].values[0]
df_channel_data = df_channel_data.apply(lambda row: row - df_first_row, axis=1)
df_channel_data["channelId"] = channel_id
df_channel_data["days_diff"] = df_channel_data["timestamp"].map(lambda x: x.total_seconds()/(24*60*60), na_action=None)
df_channel_data[subj_str] = subj_str
channels_data.append(df_channel_data)
在这里,我使用apply方法计算数据框中第一个值和其他值之间的差异。然后,我可以使用lineplot绘制数据:
sns.lineplot(data=pd.concat(channels_data),
x="days_diff", y="views",
hue="channelId", palette=palette, linestyle=linestyle,
legend=False)
(完整代码更长,为了清晰起见,我只保留了必要部分)
如我们所知,分布是倾斜的。前 50 个频道的结果如下:

每周前 50 个频道观看次数,作者图片
如我们所见,顶级频道每天的观看次数可以超过几百万!
分布右侧情况如何?总的来说,我收集了 3030 个频道,这是右侧 1000 个频道的相同图表:

每周 1000 个 YouTube 频道观看次数,作者图片
这里的结果远不如预期。一些频道每周获得 50-100 次新观看,但大多数频道仅获得 10-20 次观看。YouTube 搜索限制在大约 500 个项目,但我可以猜测大多数 YouTube 用户从未滚动超过前 1-2 页。
4.5 渠道动态 — 每日订阅者数 让我们来看看订阅者数量的变化。代码是一样的,只不过我使用了“订阅者”列而不是“观看次数”。
结果很有趣。首先,让我们看看我列表中的前 50 个频道:

每周新频道订阅者,作者图片
如我们所见,顶级频道每天可以获得几千个新订阅者!在分布的右侧,结果再次不那么令人兴奋,但仍然很有趣:

每周新频道订阅者,作者图片
其中一个频道“突然”每天获得 100 个订阅者,但这一数值没有再增加。也许频道主支付了推广费用,或者其中一个视频突然走红——谁知道呢?其他频道每周仅获得 5-10 个新订阅者。
4.6 渠道动态 — 每日视频数 了解不同频道每天发布多少视频也很有趣。我们可以使用相同的代码轻松找到答案。首先,让我们看看前 50 个频道的新视频数量:

每日新视频,作者图片
这里是我列表右侧的 1000 个频道:

每日新视频,作者图片
有趣的是,数字并没有大幅度不同。但顶级频道显然发布的视频较少,他们显然更注重质量而非数量。他们每周只能制作一个视频,每个视频可能拥有超过 100 万的观看次数。然而,有些 YouTube 频道总共有超过 5000 个视频;他们每天发布几个视频。无论如何,这些频道都没有跻身前列,这很值得思考。
“意大利面图”可以展示我们一个大致趋势,但很难从中读取具体值。为了获得更精确的数据,我们可以为前 50 个频道绘制直方图:

每周新视频,作者图片
我们可以看到,有些频道每天发布多个视频,但大多数顶级频道每周只制作一个或更少的视频。显然,没有一种适用于所有类型的通用公式,关于猫的视频或智能手机或相机评论的视频可能需要完全不同的准备时间。欢迎读者按不同类别过滤频道,并自行做更详细的分析。
5. 奖励:异常检测
最后,对于耐心阅读到这里的读者,赠送一个小奖励。让我们应用异常检测算法,看看是否能找到一些不寻常的 YouTube 频道。我将使用无监督的IsolationForest算法。该算法本身基于二叉决策树。在每一步,树通过随机特征和随机阈值进行分支,直到每个点完全孤立或达到最大深度。之后,根据达到该点所需的树深度,为每个点分配“异常分数”。
我将使用每个视频的观看次数和订阅者数量作为度量标准。我还将contamination值设置为 0.05;这是我们期望的异常比例。
from sklearn.ensemble import IsolationForest
df_channels_ = df_channels.sort_values(by=['videos'], ascending=False).drop_duplicates(subset=["channelId"]).copy().reset_index(drop=True).copy()
df_channels_ = df_channels_[df_channels_["videos"] > 10]
df_channels_["subscribers_per_video"] = df_channels_["subscribers"]/df_channels_["videos"]
df_channels_["views_per_video"] = df_channels_["views"]/df_channels_["videos"]
df_channels_[["subscribers_per_video", "views_per_video"]] = df_channels_[["subscribers_per_video", "views_per_video"]].apply(pd.to_numeric)
X = df_channels_[["subscribers_per_video", "views_per_video"]]
model = IsolationForest(contamination=0.05, random_state=42).fit(X)
df_channels_['anomaly_scores'] = model.decision_function(X)
df_channels_['anomaly'] = model.predict(X)
# Anomaly: Outlier (-1) or an inlier (1)
# Anomaly_scores: The lower the score, the more abnormal is the sample
display(df_channels_.sort_values(by=['anomaly_scores'], ascending=True)[:30])
我们来按异常分数对频道进行排序。结果如下所示:

在我们“异常评分”的第一位,我们看到一个来自“猫”类别的频道,这个频道确实每个视频的订阅者数量很高。我看了这个频道;虽然我不是猫视频的粉丝,但从技术上讲它确实做得很好。这也是我第一次看到一个有 193M 观看次数的视频(我必须承认,没有哪个关于数学或机器学习的视频能达到这一点;)。在我的“评分”中的第二个频道是关于化妆的。我对这一领域绝对不是专家,本来打算跳过它,但其中一个视频引起了我的注意。作者请 ChatGPT 写化妆步骤。我从未考虑过使用 AI 来化妆,尽管看到 AI 如何影响我们生活的越来越多领域还是很有趣。
有时很容易猜测为什么某个项目具有高异常评分,但如果特征数量很大,这可能会变得复杂。在这种情况下,我们可以使用SHAP库来可视化结果:
import shap
X = df_channels_[["subscribers_per_video", "views_per_video"]]
y_pred = model.predict(X)
explainer = shap.Explainer(model.predict, X)
shap_values = explainer(X)
shap.initjs()
explainer方法使用Shapley 值来解释不同的机器学习模型,并且也可以与IsolationForest一起使用。初始化后,我们可以检查列表中的不同项目。让我们检查第一个:
shap.plots.waterfall(shap_values[786])
结果如下所示:

Shapley 解释器结果,作者提供的图片
在另一个例子中,views_per_video参数看起来正常,但subscribers_per_video值很高:

Shapley Explainer 结果,图片由作者提供
在这种情况下,我们可以看到两个指标都异常高。
结论
在本文中,我解释了如何使用 YouTube 数据 API 和 python-youtube 库获取 YouTube 频道数据。这些数据允许我们对不同类别进行 YouTube 搜索请求,并获得有关 YouTube 频道的有趣统计见解。
我想每个读者今天或昨天至少看过一个 YouTube 视频。根据 demandsage.com,YouTube 是仅次于 Google 的第二大搜索引擎,2023 年有 27 亿活跃用户。它是我们现代社会的一部分,也是日常生活的一部分。因此,从文化和研究的角度来看,了解哪些类别最受欢迎以及不同频道可以获得多少观看次数和订阅者是很有趣的。在本文中,我使用了像“猫”或“狗”这样的“中性”类别,但同样的方法可以用于收集关于政治、战争、医学、阴谋论或其他任何话题的数据。最后但同样重要的是,对许多内容创作者来说,YouTube 是一个重要的收入来源,了解不同类别能获得多少观看次数或订阅者可能至关重要。因此,我鼓励你作为读者,对你感兴趣的话题进行相同的测试。无论如何,统计学是一门 关于我们的 科学。
在故事的第二部分,我将重点关注单个视频。我们将查看不同 YouTube 频道发布视频的频率,以及这些视频能获得多少观看次数:
## 探索性数据分析:我们对 YouTube 频道了解多少(第二部分)
使用 Pandas 和 YouTube 数据 API 获取统计见解
towardsdatascience.com
对社会数据分析感兴趣的人也欢迎阅读其他文章:
如果你喜欢这个故事,欢迎订阅Medium,你将会在我的新文章发布时收到通知,并且可以全面访问其他作者的成千上万篇故事。如果你想获取这篇文章以及我下一篇文章的完整源代码,欢迎访问我的Patreon 页面。
感谢阅读。
探索性数据分析:我们对 YouTube 频道了解了什么(第二部分)
使用 Pandas 和 YouTube Data API 获取统计见解
·发表于 Towards Data Science ·阅读时间 14 分钟·2023 年 11 月 24 日
--

图片来源:Souvik Banerjee,Unsplash
在第一部分中,我从大约 3000 个 YouTube 频道收集了统计数据,并获得了一些有趣的见解。在这一部分,我将进一步深入,从通用的“频道”层面到个别的“视频”层面。我将展示如何收集 YouTube 视频的数据以及我们可以获得什么样的见解。
方法论
为了收集 YouTube 视频的数据,我们需要执行几个步骤:
-
获取 YouTube Data API 的凭证。它是免费的,API 每天 10,000 次请求的限制足够满足我们的任务需求。
-
找到几个我们想要分析的 YouTube 频道。
-
编写一些 Python 代码来获取所选频道的最新视频及其统计数据。YouTube 分析功能仅对频道所有者开放,我们只能获取当前时刻的数据。但我们可以运行代码一段时间。在我的案例中,我使用 Apache Airflow 和 Raspberry Pi 收集了三周的数据。
-
执行数据分析。我将使用 Pandas、Matplotlib 和 Seaborn 来完成这项工作。
获取 YouTube API 凭证和配置 Apache AirFlow 的过程在我之前的文章中有描述,我建议读者暂停阅读本篇文章,先阅读那部分内容:
## 探索性数据分析:我们对 YouTube 频道了解了什么
使用 Pandas 和 YouTube Data API 获取统计见解
towardsdatascience.com
现在,让我们开始吧。
1. 获取数据
要获取有关 YouTube 视频的信息,我将使用一个python-youtube库。令人惊讶的是,没有现成的方法可以从特定频道获取视频列表,我们需要自己实现。
首先,我们需要调用get_channel_info方法,它顾名思义,将返回有关频道的基本信息。
from pyyoutube import Api
def get_channel_info(api: Api, channel_id: str) -> Tuple[str, str, str]:
""" Get info about the channel. Return values: title, uploads, subscribers """
channel_info = api.get_channel_info(channel_id=channel_id, parts=["snippet", "statistics", "contentDetails"], return_json=True)
if len(channel_info["items"]) > 0:
item = channel_info["items"][0]
title = item["snippet"]["title"]
uploads = item["contentDetails"]["relatedPlaylists"]["uploads"]
subscribers = item["statistics"]["subscriberCount"]
return title, uploads, subscribers
logging.warning(f"get_channel_info::warning cannot get data for the channel {channel_id}")
return None, None, None
api = Api(api_key="...")
get_channel_info(api, channel_id="...")
输出如下:
"items": [
{
"id": "UCBJycsmd...",
"snippet": {
"title": "Mar...",
"description": "MKBH...",
"publishedAt": "2008-03-21T15:25:54Z",
"contentDetails": {
"relatedPlaylists": {
"likes": "",
"uploads": "UUBJy..."
}
},
"statistics": {
"viewCount": "3845139099",
"subscriberCount": "17800000",
"hiddenSubscriberCount": false,
"videoCount": "1602"
}
}
]
在这里,我们有一个statistics部分,包含频道的视频数量、观看次数和订阅者数。第二部分是contentDetails;这是我们需要的,因为它包含“uploads”列表的 ID。正如我们所见,频道上传的视频作为“虚拟”播放列表进行存储,这让我有些惊讶。
之后,我们需要调用get_playlist_items方法,它会返回所需播放列表中的视频列表。
def get_playlist_items(api: Api, playlist_id: str, limit: int) -> List[Tuple[str, str]]:
""" Get video IDs for a playlist """
videos = []
playlist_items = api.get_playlist_items(playlist_id=playlist_id, count=10, limit=10, parts=["contentDetails"], return_json=True)
next_page_token = playlist_items["nextPageToken"]
while next_page_token is not None:
for video in playlist_items["items"]:
video_id = video["contentDetails"]["videoId"]
video_published_at = video["contentDetails"]["videoPublishedAt"]
# views, likes, comments = get_video_by_id(api, video_id)
videos.append([video_id, video_published_at])
next_page_token = playlist_items["nextPageToken"]
playlist_items = api.get_playlist_items(playlist_id=playlist_id, count=10, limit=10,
parts=["contentDetails"], return_json=True,
page_token=next_page_token)
if len(videos) >= limit:
break
return videos
输出如下:
"items": [
{
"kind": "youtube#playlistItem",
"etag": "tmSJMm9_KwkNTPkpdspUkQiQtuA",
"id": "VVVCSnljc21kdXZZRU...",
"contentDetails": {
"videoId": "Ks_7TmG...",
"videoPublishedAt": "2023-10-28T13:09:50Z"
}
},
...
]
在这里,我们需要videoId和videoPublishedAt字段。
只有在这一步,拥有视频 ID 列表后,我们才能找到每个视频的观看次数、点赞数和评论数:
def get_video_by_id(api: Api, video_id: str) -> Tuple[str, str, str]:
""" Get video details by id """
video_info = api.get_video_by_id(video_id=video_id, parts=["statistics"], return_json=True)
if len(video_info["items"]) > 0:
item = video_info["items"][0]
views = item["statistics"]["viewCount"]
likes = item["statistics"]["likeCount"]
comments = item["statistics"]["commentCount"]
return views, likes, comments
return None, None, None
作为最终步骤,我创建了一个将所有这些部分组合在一起的方法:
def get_channel_videos(api: Api, channel_id: str, limit: int) -> List:
""" Get videos for the channel """
videos_data = []
title, uploads, subscribers = get_channel_info(api, channel_id)
if title is not None and uploads is not None:
title_ = title.replace(";", ",")
videos = get_playlist_items(api, uploads, limit)
for video_id, video_published_at in videos:
views, likes, comments = get_video_by_id(api, video_id)
videos_data.append((channel_id, title_, subscribers, video_id, video_published_at, views, likes, comments))
return videos_data
limit变量对调试很有帮助;它允许我们减少每个查询的请求数量,避免超过 API 配额限制。
如前所述,只有频道所有者才能获取历史和分析数据;我们只能获得当前时刻可用的数据。但我们可以定期请求数据(视频数量及其观看次数、点赞数和评论数)。使用Apache Airflow在树莓派上运行,我让这段代码运行了三周。每 3 小时执行一次请求,每次请求的输出都保存为 CSV 文件(更多详细信息和 DAG 示例见第一部分)。现在让我们看看能得到什么样的见解。
2. ETL(提取、转换、加载)
和往常一样,在使用数据进行分析之前,我们需要将其转换为方便的形式。我们的 ETL 过程非常简单。从 Apache AirFlow 任务中,我获得了大量的 CSV 文件。让我们加载这些文件并将它们合并成一个数据集:
import pandas as pd
import glob
channel_files = glob.glob("data/video*.csv")
channels_data = []
for file_in in channel_files:
channels_data.append(pd.read_csv(file_in, delimiter=";",
parse_dates=["timestamp"],
date_format="%Y-%m-%d-%H-%M-%S"))
df_channels = pd.concat(channels_data)
让我们检查一个视频的结果:
display(df_channels.query('videoId == "8J...4"').sort_values(by=["timestamp"], ascending=True))
输出如下:

示例数据框,图片由作者提供
每行包含一个时间戳、视频 ID、视频发布时间,以及在数据收集时的观看次数、点赞数和评论数。我们可以看到,视频8J…4于 2023 年 10 月 27 日 19:00 发布。在我观察开始时,它已经有 514,948 次观看,而在数据框的末尾,观看次数增加到了 978,573 次。
现在,我们准备开始一些有趣的操作。
3. 数据分析
3.1 观看次数 作为热身,我们先展示每个视频的观看次数。我将只使用最近两个月内制作的视频。
channel_id = "UCu..."
df_channel = df_channels[df_channels["channelId"] == channel_id]
df_channel = df_channel.sort_values(by=['timestamp'], ascending=True)
# Videos published within interval
days_display = 2*31
start_date = df_channel["timestamp"].max() - pd.Timedelta(days=days_display)
end_date = df_channel["timestamp"].max()
df_channel = df_channel[(df_channel["videoPublishedAt"] >= start_date) &
(df_channel["videoPublishedAt"] < end_date)]
我每 3 小时收集一次频道数据,因此只需要最后的时间戳:
step_size = 3
interval_start = df_channel["timestamp"].max() - pd.Timedelta(hours=step_size)
interval_end = df_channel["timestamp"].max()
df_interval = df_channel[(df_channel["timestamp"] >= interval_start) &
(df_channel["timestamp"] < interval_end)]
df_interval = df_interval.drop_duplicates(subset=["videoId"])
v_days = df_interval["videoPublishedAt"].values
v_views = df_interval["viewCount"].values
让我们使用 Matplotlib 绘制条形图:
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
fig, ax = plt.subplots(figsize=(16, 4))
cmap = plt.get_cmap("Purples")
views_max = 3_000_000
views_avg = df_channel.drop_duplicates(subset=["videoId"], keep="last")["viewCount"].median() # Median value
rescale = lambda y: 0.5 + 0.5 * y / views_max
# Bar chart
ax.bar(v_days, v_views,
color=cmap(rescale(v_views)),
width=pd.Timedelta(hours=12))
# Add horizontal median line
ax.axhline(y=views_avg, alpha=0.2, linestyle="dotted")
trans = ax.get_yaxis_transform()
ax.text(0, views_avg, " Median ", color="gray", alpha=0.5, transform=trans, ha="left", va="bottom")
# Title
subscribers = df_channel.iloc[[0]]["subscribers"].values[0]
title_str = f"YouTube Channel, {subscribers/1_000_000:.1f}M subscribers"
# Adjust axis
ax.xaxis.set_major_formatter(mdates.DateFormatter("%d/%m"))
ax.yaxis.set_major_formatter(FuncFormatter(lambda x, p: format(int(x), ",")))
ax.xaxis.set_major_locator(mdates.WeekdayLocator(byweekday=mdates.SU))
ax.set(title=title_str,
xlabel="Video Publication Date",
ylabel="Views",
xlim=(start_date, end_date),
ylim=(0, views_max))
plt.tight_layout()
plt.show()
在这里,我使用ax.bar绘制条形图,并用rescale函数调整条形的颜色。水平中位线有助于查看观看次数是否高于或低于平均水平。
首先,让我们看看一个拥有 2390 万订阅者的频道,它发布“化妆”类别的视频:

最近 2 个月的观看次数,作者提供的图片
结果很有趣。我们可以看到每个视频的观看次数基本一致,中位数大约是 100 万次。这个数字多吗?这个值显然很大。但频道几乎有 2400 万订阅者,这些订阅者应该对这些内容感兴趣,并在发布新视频时收到通知。我认为 1/24 的比例看起来并不大;也许人们订阅了但很快失去了兴趣?
另一个有趣的“异常”引起了我的注意。有时,我将显示间隔设置为 1 年,这时可以看到观看次数大幅增加:

每年观看次数,作者提供的图片
显然,作者发布了大量的“短视频”,在那段时间内获得了许多(300 万到 1000 万)观看次数。后来发生了什么?也许频道的主编换了?也许制作“短视频”不再盈利?我不知道。可能可以在网页浏览器中观看所有视频并尝试找出原因,但这绝对超出了这个测试的范围,而且我也不是化妆方面的专家。
作为另一个例子,让我们看看另一个拥有 1780 万订阅者的频道,这个频道制作小工具评测:

观看次数,作者提供的图片
我不知道这些结果是否与内容有关(“小工具评测”和“化妆”自然面向不同的受众),但与第一个频道相比,这个频道每个视频的中位观看次数要高得多。
现在让我们看看受众较小的频道能获得多少观看次数。这个与小工具相关的频道拥有130 万订阅者:

观看次数,作者提供的图片
差异是显著的。一个拥有 17.8M 订阅者的频道每个视频大约获得 300 万次观看,而一个拥有 1.3M 订阅者的频道仅获得“仅”30 万次观看。为了比较,下一个与摄影相关的频道拥有115K 受众:

观看次数,图片来自作者
在这种情况下,频道的每个视频平均观看次数为 25K。
显然,视频不仅向订阅者展示,还通过 YouTube 推荐系统展示给任何人。真实的比例是什么?我们不知道。从柱状图来看,我可以猜测只有大约 20%的订阅者是“活跃的”。其他人可能很久以前订阅了,已经对内容不再感兴趣。这是有道理的;例如,如果我要买一台笔记本电脑,我可以订阅一个硬件评测频道,但在购买后我可能就不再感兴趣了。
3.2 观看次数动态 我们能够看到每个视频的观看次数,但视频获取这些观看次数的速度有多快?我们已经有了一个 Matplotlib 柱状图;让我们对它进行动画处理!只有频道所有者可以访问历史数据,但我在三周内进行了请求,我们可以轻松地看到这些值在这个时间间隔内是如何变化的。为此,我们只需更新图表:
import matplotlib.animation as animation
def animate_bar(frame_num: int):
""" Update graph values according to frame number """
interval_start = df_channel["timestamp"].min() + pd.Timedelta(hours=step_size*frame_num)
interval_end = df_channel["timestamp"].min() + pd.Timedelta(hours=step_size*(frame_num+1))
day_str = interval_start.strftime('%d/%m/%Y %H:00')
days, views = get_views_per_interval(df_channel, interval_start, interval_end)
print(f"Processing {day_str}: {views.shape[0]} items")
bar = ax.bar(days, views,
color=cmap(rescale(views)),
width=pd.Timedelta(hours=bar_width))
day_vline.set_xdata([interval_start])
ax.set(title=f"{title_str}: {day_str}")
return bar,
step_size = 3
num_frames = (df_channel["timestamp"].max() - df_channel["timestamp"].min())//pd.Timedelta(hours=step_size)
anim = animation.FuncAnimation(fig, animate_bar, repeat=True, frames=num_frames)
writer = animation.PillowWriter(fps=5)
anim.save("output.gif", writer=writer)
在这里,我创建了一个FuncAnimation对象,其中animate_bar函数作为参数传递。这个函数会自动调用不同的帧编号;在这个函数内部,我创建了一个新的柱状图并更新了标题。我还添加了一条垂直线,代表当前日期。
输出结果如下:

3 周内的观看次数,图片来自作者
从这个动画中,我们可以看到一个新视频显然在第一周内获得了至少 70%的观看次数。旧视频也会获得一些观看次数,但这个过程要慢得多。
但也可能会有例外。在下一个例子中,一个频道的每个视频的中位数观看次数为 90K,但其中一个视频可能变得病毒式传播,被大量分享,并在 2 到 3 周内获得了大约一百万次观看:

3 周内的观看次数,图片来自作者
3.3 观看次数分布 在观看了柱状图后,我问自己一个问题:观看次数的分布是否正常?显然,有些视频的观看次数比其他视频多,但这种情况有多一致?使用 Seaborn 的histplot方法很容易找到答案。
import seaborn as sns
channel_id = "UCu..."
df_channel = df_channels[df_channels["channelId"] == channel_id]
display(df_channel.drop_duplicates(subset=["videoId"]))
step_size = 3
interval_start = df_channel["timestamp"].max() - pd.Timedelta(hours=step_size)
interval_end = df_channel["timestamp"].max()
df_interval = df_channel[(df_channel["timestamp"] >= interval_start) & (df_channel["timestamp"] < interval_end)].drop_duplicates(subset=["videoId"])
# Title
subscribers = df_channel.iloc[[0]]["subscribers"].values[0]
title_str = f"YouTube Channel, {subscribers/1_000_000:.1f}M subscribers"
# Median
views_avg = df_channel["viewCount"].median()
# Draw
fig, ax = plt.subplots(figsize=(12, 5))
sns.set_style("white")
sns.histplot(data=df_interval, x="viewCount", stat="percent", bins=50)
ax.set(title=title_str,
xlabel="Views Per Video",
ylabel="Percentage",
xlim=(0, None),
ylim=(0, 18)
)
ax.axvline(x=views_avg, alpha=0.2, linestyle="dotted")
ax.xaxis.set_major_formatter(FuncFormatter(lambda x, p: format(int(x), ',')))
plt.tight_layout()
plt.show()
对于这个测试,我将 500 作为 API 请求的视频限制。一个“电子产品评测”类别频道的结果如下:

观看次数分布,图片来自作者
1780 万订阅者是一个大数字。这个频道绝对是顶级之一,正如我们所见,它产生的结果或多或少是一致的。分布看起来正常,但略有偏斜。该图表的中位值是每个视频 380 万次观看,但一些视频的观看次数超过了 1000 万次,且 500 个视频中只有 3 个超过了 2000 万次。
在其他订阅者较少的频道中也可以看到类似的模式,但在这种情况下,分布更加偏斜:

观看次数分布,作者提供的图片
这些数据可能需要更详细的分析。例如,结果显示“正常视频”和“短视频”可以有截然不同的观看次数,理想情况下应该分开分析。
3.4 附加内容:单个视频的观看次数 这篇文章已经很长了,我会给那些耐心读到这一步的读者一个附加内容。在 3.2 中,我做了动画,显示大多数视频在发布后很快获得了大量观看(顺便提一下,这对于 TDS 和 Medium 文章也是如此)。我们能更详细地看到这个过程吗?实际上可以。我在几周内收集了数据,期间有足够的视频发布。找到最新的视频很简单,因为我们有一个videoPublishedAt参数:
# Find the newest videos for a specific channel
df_channel = df_channels[df_channels["channelId"] == "UCB..."]
num_videos = 5
df_videos = df_channel.drop_duplicates(subset=["videoId"]).sort_values(by=["videoPublishedAt"], ascending=False)
提醒一下,特定视频的数据如下:

数据框示例,作者提供的图片
然后,我“标准化”了这些数据:我的目标是显示从发布时间开始的观看次数,我将其视为“0”:
def get_normalized_views(df_channel: pd.DataFrame, video_id: str) -> pd.DataFrame:
""" Get relative views for a specific video """
df_video = df_channel[df_channel["videoId"] == video_id].sort_values(by=['timestamp'], ascending=True)
# Insert empty row with zero values at the beginning
video_pub_time = df_video.iloc[[0]]["videoPublishedAt"].values[0]
start_row = {'videoPublishedAt': video_pub_time,
'timestamp': video_pub_time,
'viewCount': 0, 'likeCount': 0, 'commentCount': 0}
df_first_row = pd.DataFrame(start_row, index=[0])
df_video_data = df_video[df_first_row.columns]
df_video_data = pd.concat([df_first_row, df_video_data], ignore_index=True)
# Make timestamps data relative, starting from publication time
df_first_row = df_video_data.iloc[[0]].values[0]
df_video_data = df_video_data.apply(lambda row: row - df_first_row, axis=1)
df_video_data["daysDiff"] = df_video_data["timestamp"].map(lambda x: x.total_seconds()/(24*60*60), na_action=None)
return df_video_data
这里,我还将时间戳转换为从发布时间开始的天数,以使图表更方便阅读。
现在,我们可以使用 Matplotlib 绘制图表:
fig, ax = plt.subplots(figsize=(10, 6))
# Title
subscribers = df_channel.iloc[[0]]["subscribers"].values[0]
title_str = f"YouTube Channel with {subscribers/1_000_000:.1f}M Subscribers, Video Views"
# Videos data
for p in range(num_videos):
video_id = df_videos.iloc[[p]]["videoId"].values[0]
df_video_data = get_normalized_views(df_channel, video_id)
plt.plot(df_video_data["daysDiff"], df_video_data["viewCount"])
# Params
ax.set(title=title_str,
xlabel="Days Since Publication",
ylabel="Views",
xlim=(0, None),
ylim=(0, None))
ax.yaxis.set_major_formatter(FuncFormatter(lambda x, p: format(int(x), ',')))
ax.tick_params(axis='x', rotation=0)
plt.tight_layout()
plt.show()
结果看起来是这样的:

每个视频的观看次数,作者提供的图片
这里,线的长度不同,因为视频的发布时间不同。最早的视频发布于将近两周前,而最新的视频发布于数据收集前的两天。
从这张图表中,我有两个有趣的观察。
首先,至少对于这个频道而言,我的假设是正确的,这些视频在发布后立即获得了最多的观看次数。更重要的是,曲线(例如红色和绿色的曲线)几乎是相同的。
其次,细心的读者可能会看到两个明显的组——前两个视频获得了大约 300 万次观看,而另外三个视频显然获得了大约 50 万次观看。确实,这些视频不同。顶部的线代表“正常”视频,底部的线代表“YouTube Shorts”。显然,至少对这个频道而言,观众对“短视频”的兴趣较低。
但显然,结果可能有所不同。首先,一些视频可能会变得更受欢迎甚至成为病毒视频;它们可以获得更多的观看次数:

每个视频的观看次数,图像作者提供
其次,内容本身也很重要。例如,关于小工具的评测在“新鲜”时通常最有趣,但关于健康、关系、运动、化妆或任何类似主题的视频对观众的长期价值可能更高。最后但同样重要的是,这些特定频道拥有大量订阅者,视频在发布后很快就会获得许多观看次数。对于“新手”而言,结果可能会有所不同,大多数新频道的观众可能来自 YouTube 推荐系统或搜索结果。因此,我只能建议读者自行研究,选择一个大致符合他们想了解内容的 YouTube 频道。
结论
在这篇文章中,我展示了如何收集和分析有关不同 YouTube 频道和视频的数据。在第一部分,我重点关注了诸如每个频道的观看次数等一般属性。在这一部分,我关注了单个视频。我们能够看到不同频道上新视频发布的频率、它们可以获得的观看次数以及这一过程的速度。这种分析通常仅对频道所有者开放,但借助 YouTube 数据 API,我们可以以高精度免费收集数据。这不仅对那些希望开设新频道的人感兴趣,也从文化和统计的角度来看非常有意义。
显然,YouTube 是一个庞大的流媒体平台,拥有数百万个频道和数十亿个视频。关于猫、数学问题或笔记本电脑评测的视频可以获得截然不同的观看次数、点赞数和评论数。因此,我鼓励读者对他们感兴趣的频道进行自己的测试。在这篇文章中,我只关注了观看次数,但评论数或点赞数也可以通过相同的方式进行分析(顺便提一下,我们可以通过 API 获取点赞数,但 YouTube 从 2021 年开始移除了公众对不喜欢数的访问)。
在下一部分也是最后一部分中,我将关注 YouTube “Shorts”。这些类型的视频显示在一个单独的 YouTube 页面上,该页面具有不同的 UI,并且观看次数或点赞数可能会有显著差异。敬请关注。
对社交数据分析感兴趣的人也欢迎阅读其他文章:
-
探索性数据分析:我们对 YouTube 频道了解多少(第一部分)
如果你喜欢这个故事,欢迎订阅 Medium,这样你将会收到我新文章发布的通知,并且可以全面访问其他作者的数千篇故事。本文的完整源代码和 Jupyter notebook 也可以在我的Patreon 页面找到。
感谢阅读。
探索 Pydantic V2 的增强数据验证功能
原文:
towardsdatascience.com/explore-pydantic-v2s-enhanced-data-validation-capabilities-792a3353ec5
了解 Pydantic V2 的新功能和语法
·发布于 Towards Data Science ·7 min read·2023 年 10 月 25 日
--

图片由 jackmac34 在 Pixabay 提供
数据验证是数据工程和软件开发领域中稳健应用的基石。确保数据的清洁性和准确性不仅对应用的可靠性至关重要,也对用户体验有很大影响。
Pydantic 是 Python 中使用最广泛的数据验证库。Pydantic 最新版本(V2)的核心已经用 Rust 重新编写,相比于之前的版本性能大大提升。此外,在功能方面也有一些重大改进,例如支持严格模式、无模型验证、模型命名空间清理等。
本文将深入探讨 Pydantic 强大数据验证功能的最新特性和增强性能,为开发者提供一个全面的数据处理工具集。
准备工作
要跟随本文中的示例,您应该安装现代版本的 Python(≥ 3.10)和最新版本的 Pydantic V2。建议使用 conda 虚拟环境来管理不同版本的 Python 和库:
conda create -n pydantic2 python=3.11
conda activate pydantic2
pip install -U pydantic
基本用法
通常使用 Pydantic 时,我们需要先通过模型定义数据的模式,这些模型只是继承自 BaseModel 的类。在这些模型中,每个字段的数据类型由类型提示定义。
from pydantic import BaseModel
class ComputerModel(BaseModel):
brand: str
cpu: str
storage: int
ssd: bool = True
要使用此模型进行验证,我们可以通过传递每个字段的值来创建一个实例:
input_dict = {"brand": "HP", "cpu": "Intel i7 1265U", "storage": "256"}
computer = ComputerModel(**input_dict)
print(computer)
# brand='HP' cpu='Intel i7 1265U' storage=256 ssd=True
storage 字符串数据会被强制转换为模型中定义的整数。
为了演示的简便性,我们在本文中仅使用两个字段,即brand和storage,这可以轻松扩展到其他字段。
# Basic model used in this post
from pydantic import BaseModel
class ComputerModel(BaseModel):
brand: str
storage: int
直接验证数据
在上面的示例中,为数据验证创建了一个 Pydantic 模型的实例。在 Pydantic V2 中,我们还可以直接使用model_validate()和model_validate_json()来验证字典或 JSON 数据:
ComputerModel.model_validate({"brand": "HP", "storage": "256"})
# ComputerModel(brand='HP', storage=256)
import json
input_json = json.dumps({"brand": "HP", "storage": "256"})
ComputerModel.model_validate_json(input_json)
# ComputerModel(brand='HP', storage=256)
在 Pydantic V2 中,所有模型的方法都以model_开头,因此字段名称不允许以model_开头。然而,如果需要,可以使用字段别名。
在严格模式下验证数据
默认情况下,严格模式是关闭的,这意味着数据类型会被强制转换(如果可能的话)。例如,在上述示例中,storage字段的类型从str被强制转换为int。我们可以禁用严格模式,这样所有字段的数据类型必须完全匹配:
ComputerModel.model_validate({"brand": "HP", "storage": "256"}, strict=True)
# ValidationError: 1 validation error for ComputerModel
# storage
# Input should be a valid integer [type=int_type, input_value='256', input_type=str]
我们还可以在模型的字段级别设置严格模式,这样我们在验证步骤中就不需要指定它:
from pydantic import Field
class ComputerModelStrict(BaseModel):
brand: str
storage: int = Field(strict=True)
ComputerModelStrict.model_validate({"brand": "HP", "storage": "256"})
# ValidationError: 1 validation error for ComputerModel
# storage
# Input should be a valid integer [type=int_type, input_value='256', input_type=str]
使用model_config配置模型
在 Pydantic V2 中,为了指定模型的配置,我们可以将类属性model_config设置为一个字典,该字典包含将用于配置的键/值对。通常,我们通过一个称为ConfigDict的特殊字典来做到这一点,它是一个用于配置 Pydantic 行为的TypedDict。
例如,我们可以在模型级别设置strict模式,而不是在字段级别,如上所示:
from pydantic import BaseModel, ConfigDict
class ComputerModelStrict(BaseModel):
model_config = ConfigDict(strict=True, str_min_length=2)
brand: str
storage: int
ComputerModelStrict.model_validate({"brand": "HP", "storage": "256"})
# ValidationError: 1 validation error for ComputerModel
# storage
# Input should be a valid integer [type=int_type, input_value='256', input_type=str]
ComputerModelStrict.model_validate({'brand': 'X', 'storage': 256})
# ValidationError: 1 validation error for ComputerModelStrict
# brand
# String should have at least 2 characters [type=string_too_short, input_value='X', input_type=str]
我们还指定了字符串字段的最小长度为 2,因此像X这样的品牌将被拒绝。
使用typing.Annotated来处理字段
不必将Field值分配给字段以指定字段的行为,也可以使用类型提示typing.Annotated来完成:
from typing import Annotated
class ComputerModelStrict(BaseModel):
brand: str
storage: Annotated[int, Field(strict=True, gt=0)]
ComputerModelStrict.model_validate({'brand': 'HP', 'storage': '256'})
# ValidationError: 1 validation error for ComputerModel
# storage
# Input should be a valid integer [type=int_type, input_value='256', input_type=str]
ComputerModelStrict.model_validate({'brand': 'HP', 'storage': 0})
# ValidationError: 1 validation error for ComputerModelStrict
# storage
# Input should be greater than 0 [type=greater_than, input_value=0, input_type=int]
使用Annotated时,传递的第一个类型参数(这里是int)是实际类型,其余的是其他工具(这里是 Pydantic)的元数据。元数据可以包含任何内容,如何使用由其他工具决定。
具有动态默认值的字段
我们可以为字段设置动态默认值,这样它可以自动生成,并且每个模型实例可能不同。例如,我们可以将当前时间戳设置为模型的创建时间,并为其设置唯一的 ID。这可以通过使用default_factory来完成,它接受一个工厂函数作为输入。
from datetime import datetime
from typing import Annotated
from uuid import UUID, uuid4
from pydantic import BaseModel, Field
class ComputerModel(BaseModel):
uid: UUID = Field(default_factory=uuid4)
brand: str
storage: int
created: datetime = Field(default_factory=datetime.utcnow)
ComputerModel.model_validate({'brand': 'HP', 'storage': 256})
# ComputerModel(uid=UUID('81474288-f691-4e37-b5e3-d28f0656d972'), brand='HP', storage=256, created=datetime.datetime(2023, 9, 29, 0, 5, 2, 958755))
这表明uid和created字段是自动创建的,并且每个模型将会不同。
字段和模型验证器
类似于在字段上应用严格模式,我们可以使用Annotated语法为字段应用自定义验证器。让我们添加一个自定义验证器,检查storage是否是有效值:
from typing import Annotated
from pydantic.functional_validators import AfterValidator
def check_storage(storage: int):
allowed = (128, 256, 512, 1000, 1024, 2000, 2048)
if storage not in allowed:
raise ValueError(f"Invalid storage, storage must be one of {allowed}")
return storage
class ComputerModel(BaseModel):
brand: str
storage: Annotated[int, AfterValidator(check_storage)]
ComputerModel.model_validate({'brand': 'HP', 'storage': 256})
# ComputerModel(brand='HP', storage=256)
ComputerModel.model_validate({'brand': 'HP', 'storage': 250})
# ValidationError: 1 validation error for ComputerModel
# storage
# Value error, Invalid storage, storage must be one of (128, 256, 512, 1000, 1024, 2000, 2048) [type=value_error, input_value=250, input_type=int]
AfterValidator表示验证将在 Pydantic 的内部验证逻辑之后应用。它相当于使用@field_validator()装饰器的after模式,如下所示。
请注意,验证代码不应抛出ValidationError本身,而应抛出ValueError或AssertionError(或其子类),这些异常将被捕获并用于填充ValidationError。
我们还可以使用@field_validator()装饰器为字段应用自定义验证器:
from typing import Annotated
from pydantic import BaseModel, Field, field_validator
from pydantic.functional_validators import AfterValidator
class ComputerModel(BaseModel):
brand: str
storage: int
@field_validator('storage', mode='after')
@classmethod
def check_storage(cls, storage: int):
allowed = (128, 256, 512, 1000, 1024, 2000, 2048)
if storage not in allowed:
raise ValueError(f"Invalid storage, storage must be one of {allowed}")
return storage
使用@field_validator()的效果应与上述Annotated语法完全相同。每种语法都有其优缺点:
-
使用
Annotated可以更轻松地重用自定义验证函数。 -
使用
field_validator我们可以更轻松地将相同的验证函数应用于多个字段。
因此,你需要根据具体的实际使用案例来决定使用哪种语法。
我们还可以使用@model_validator()将自定义验证器应用于整个模型。在这种情况下,我们可以访问所有字段的数据。例如,假设如果品牌是“Apple”,则存储必须至少为 256GB。
from __future__ import annotations
from typing import Annotated
from pydantic import BaseModel, Field, model_validator
from pydantic.functional_validators import AfterValidator
class ComputerModel(BaseModel):
brand: str
storage: int
@model_validator(mode='after')
def check_brand_storage(self) -> ComputerModel:
if self.brand.upper() == 'APPLE' and self.storage < 256:
raise ValueError("For Apple, the storage must be at least 256GB.")
return self
ComputerModel.model_validate({'brand': 'Apple', 'storage': 256})
# ComputerModel(brand='HP', storage=256)
ComputerModel.model_validate({'brand': 'Apple', 'storage': 128})
# ValidationError: 1 validation error for ComputerModel
# Value error, For Apple, the storage must be at least 256GB. [type=value_error, input_value={'brand': 'Apple', 'storage': 128}, input_type=dict]
请注意,@model_validator()是一个实例方法装饰器,而不是像@field_validator()那样的类方法装饰器。
转储或序列化
我们可以使用[model_dump](https://docs.pydantic.dev/latest/concepts/serialization/#modelmodel_dump)()将 Pydantic 模型的实例转换为包含实例值的字典。我们将使用文章开头介绍的更详细的模型来演示序列化:
from pydantic import BaseModel
class ComputerModel(BaseModel):
brand: str
cpu: str
storage: int
ssd: bool = True
input_dict = {"brand": "HP", "cpu": "Intel i7 1265U", "storage": "256"}
model = ComputerModel(**input_dict)
output_dict_default = model.model_dump()
# {'brand': 'HP', 'cpu': 'Intel i7 1265U', 'storage': 256, 'ssd': True}
output_dict_no_unset = model.model_dump(exclude_unset=True)
# {'brand': 'HP', 'cpu': 'Intel i7 1265U', 'storage': 256}
output_dict_included = model.model_dump(include={'brand', 'storage'})
# {'brand': 'HP', 'storage': 256}
output_dict_excluded = model.model_dump(exclude={'cpu'})
# {'brand': 'HP', 'storage': 256, 'ssd': True}
Pydantic 可以将许多常用类型序列化为 JSON,这些类型否则会与简单的json.dumps()不兼容(例如datetime、date或UUID)。如果需要,我们还可以使用[@field_serializer](http://twitter.com/field_serializer)()装饰器自定义字段的序列化方式。例如,我们可以在转储时将品牌转换为大写。
from pydantic import BaseModel, field_serializer
class ComputerModel(BaseModel):
brand: str
cpu: str
storage: int
ssd: bool = True
@field_serializer('brand')
def serialize_dt(self, brand: str, _info):
return brand.upper()
input_dict = {"brand": "Apple", "cpu": "M1", "storage": "512"}
model = ComputerModel(**input_dict)
# {'brand': 'APPLE', 'cpu': 'M1', 'storage': 512, 'ssd': True}
请注意,_info表示 Pydantic 自动提供的元数据。
在这篇文章中,我们介绍了如何使用最新版本的 Pydantic(V2)进行数据验证。在这个版本中引入了许多语法变化和新特性,这些在官方文档中可能会显得非常冗长和复杂。幸运的是,我们在日常工作中只使用了一小部分功能,并且大多数功能在这篇文章中通过简单的示例进行了介绍,这些示例可以为开发者提供全面的数据处理工具集。
相关帖子
探索语料库中的语义关系与嵌入模型
·
关注 发表在 Towards Data Science ·10 min 阅读·2023 年 11 月 24 日
--
最近,我与一些同学和学者讨论了他们研究兴趣涉及自由形式文本分析的话题。遗憾的是,获得对书面自然语言的有意义的见解绝非易事。密切阅读当然是一个选择,但你理想中会希望通过更宏观的分析/量化视角来看待文本数据。更不用说在大数据时代,密切阅读往往不可行。
到目前为止,我最喜欢在语料库上进行探索性数据分析的方法是主题模型,我已经写过多篇文章讨论如何以尽可能少的痛苦方式进行这项工作。尽管主题模型非常棒,但它们并不是所有文本任务的最佳方法。
嵌入是文本数据的数值表示,已经成为文本语义查询的经典方法。在这篇文章中,我们将探讨如何使用嵌入来分析文本数据的一些方法。
使用词嵌入捕捉概念之间的关系
词嵌入模型是一组以无监督方式学习术语潜在向量表示的方法。当从自然语言中学习词嵌入时,实际上是获得了一个嵌入空间中的语义关系图。
词嵌入通常在大型语料库上进行训练,以便捕捉人类语言中的一般词对词关系。这很有用,因为可以将关于语言的一般知识注入到特定应用的模型中。这也被称为迁移学习,并且在机器学习中一直是一个热门话题。
如果我们不想将一般知识转移到特定模型中,而是希望得到一个较小语料库的语义特定方面的映射,该怎么办呢?假设我们有一个来自论坛的评论语料库,我们想探索其中可以发现哪些关联关系。
一种方法是从头开始训练一个词嵌入模型,而不是使用已经为我们预训练的模型。在这个例子中,我将使用 20Newsgroups 数据集作为语料库,我们将在其中探索语义关系。
训练模型
现在让我们从一个词嵌入模型开始。你可能对 Word2Vec 有所了解,它是普及静态词嵌入在研究和实践中的方法。另一方面,由斯坦福大学的团队开发的 GloVe 在大多数情况下似乎是一种更好的方法,我的经验表明,它提供了更高质量的嵌入,特别是在较小的语料库上。
不幸的是,GloVe 在 Gensim 中没有实现,但幸运的是,我为原始 GloVe 代码制作了一个完全兼容 Gensim 的接口,我们将使用它来训练模型。
让我们安装 gensim、glovpy 和 scikit-learn,以便我们可以获取 20Newsgroups 以及 embedding-explorer:
pip install glovpy gensim scikit-learn
我们首先需要加载数据集,并对其进行标记化,为此我们将使用 gensim 内置的标记化工具。我们还将过滤掉停用词,因为它们对当前任务没有任何有意义的信息。
from gensim.utils import tokenize
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS
def clean_tokenize(text: str) -> list[str]:
"""This function tokenizes texts and removes stop words from them"""
tokens = tokenize(text, lower=True, deacc=True)
tokens = [token for token in tokens if token not in ENGLISH_STOP_WORDS]
return tokens
# Loading the dataset
dataset = fetch_20newsgroups(
remove=("headers", "footers", "quotes"), categories=["sci.med"]
)
newsgroups = dataset.data
# Tokenizing the dataset
tokenized_corpus = [clean_tokenize(text) for text in newsgroups]
之后,我们可以在标记化的语料库上轻松训练一个 GloVe 模型。
from glovpy import GloVe
# Training word embeddings
model = GloVe(vector_size=25)
model.train(tokenized_corpus)
我们已经可以查询这个词嵌入模型了,举个例子,让我们检查一下哪些十个词最接近“child”。
model.wv.most_similar("child")
==============================
+------------+----------+
| age | 0.849304 |
| consistent | 0.844267 |
| adult | 0.805101 |
| range | 0.800615 |
| year | 0.798799 |
| hand | 0.792965 |
| children | 0.792113 |
| use | 0.789804 |
| restraint | 0.773764 |
| belt | 0.77003 |
+------------+----------+
调查,可视化!
不过,单独调查每个词与其他词的关系很快会变得乏味。理想情况下,我们也希望可视化关系,甚至可能得到一些网络。
幸运的是,embedding-explorer 包可以帮助我们,我也开发了这个包。在计算人文学科中,我们经常使用词嵌入模型及其建立的语义网络,而 embedding-explorer 帮助我们以互动和可视化的方式探索这些网络。该包包含多个互动网页应用,我们首先来看看“网络探索器”。
这个应用的想法是,嵌入模型中的概念自然形成某种网络结构。相关性强的词有强链接,而其他词可能没有。在应用中,您可以基于指定的一组种子词和两个自由联想级别构建概念图。
在每个关联级别,我们从嵌入模型中找出与已有词汇最接近的五个词,并将其添加到我们的网络中,与其关联的词相连。连接的强度由嵌入空间中概念的余弦距离决定。这类网络在我或我的同事进行的多个研究项目中都证明了其有用性。
让我们启动我们的词嵌入模型应用。
from embedding_explorer import show_network_explorer
vocabulary = model.wv.index_to_key
embeddings = model.wv.vectors
show_network_explorer(vocabulary, embeddings=embeddings)
这将打开一个浏览器窗口,您可以自由探索语料库中的语义关系。这里是我查看围绕“jesus”,“science”和“religion”这些词汇形成的网络的截图。

探索我们 GloVe 模型中的语义关系
例如,我们可以看到,人们在线讨论这些话题时,似乎暗示了宗教和科学通过政治、社会和哲学相关联,这非常有道理。观察到教育在科学和宗教之间,虽然明显更接近科学,这也很有趣。这将值得进一步探讨。
N-grams 与句子变换器的网络
那么,如果我们不仅想查看词级别的关系,还想查看短语或句子呢?
我的建议是使用 N-grams。N-grams 本质上就是文本中连续的 N 个术语。例如,在句子“I love my little cute dog”中,我们将得到 4-grams:“I love my little”,“love my little cute”和“my little cute dog”。现在的问题是,我们如何学习 N-grams 的良好语义表示?
技术上,你仍然可以通过将短语或句子视为一个标记来使用 GloVe,但有一个问题。由于 N-grams 的多样性随着 N 的增加而急剧增加,某些 N-grams 可能只出现一两次,我们可能无法学习到良好的表示。
取短语中词嵌入的平均值怎么样?这可能会有很大帮助,但问题是我们完全丧失了关于不同词的重要性、它们在句子中的顺序以及所有上下文信息。
解决此问题的方案是使用句子变换器,这些深度神经语言模型产生具有上下文敏感性的文本表示。它们已经超越了所有其他方法好几年,并成为了嵌入文本的行业标准。现在训练这样的模型需要大量的数据,我们手头没有,但幸运的是,我们可以使用一些优秀的预训练模型。
N-gram 提取
首先,让我们从语料库中提取 N-gram。我选择了四元组,但你可以选择任何你喜欢的数量。我们将使用 scikit-learn 的 CountVectorizer 来完成这项工作。
from sklearn.feature_extraction.text import CountVectorizer
# First we train a model on the corpus that learns all 4-grams
# We will only take the 4000 most frequent ones into account for now,
# But you can freely experiment with this
feature_extractor = CountVectorizer(ngram_range=(4,4), max_features=4000)
feature_extractor.fit(newsgroups)
# Then we get the vectorizer's vocabulary
four_grams = feature_extractor.get_feature_names_out()
嵌入模型
我们需要一个嵌入模型来表示文本。正如我之前所说,我们将使用一个预训练模型。我选择了all-MiniLM-L6-v2,因为它非常稳定、广泛使用且相当小巧,因此即使在你的个人电脑上也能顺畅运行。
我们将使用另一个包,embetter,以便可以以与 scikit-learn 兼容的方式使用句子变换器。
pip install embetter[text]
我们可以像这样在 Python 中加载模型:
from embetter.text import SentenceEncoder
encoder = SentenceEncoder("all-MiniLM-L6-v2")
探索!
然后,我们可以将模型和 N-gram 加载到 embedding-explorer 中。
from embedding_explorer import show_network_explorer
show_network_explorer(four_grams, vectorizer=encoder)
请注意,这允许我们指定任何任意种子,而不仅仅是我们四元组词汇表中的种子。这里是我输入两个句子的截图,并查看从它们周围的四元组构建了什么样的网络。

探索语料库中的短语和句子
有趣的是再次观察哪些短语处于中间位置。看起来法律和历史在这里充当了宗教和科学之间的某种连接。
使用文档嵌入研究语料库级语义结构
我们现在已经在单词和短语级别查看了我们的语料库,并观察了它们中自然出现的语义结构。
如果我们想要了解在文档级别发生了什么呢?哪些文档彼此接近,出现了什么样的群体?
请注意,一个自然的解决方案是主题建模,如果你还没有尝试过,应该看看。在本文中,我们将探索与此任务相关的其他概念化方法。
文档表示
和以前一样,我们需要考虑如何表示单个文档,以便捕捉它们的语义内容。
更传统的机器学习实践通常使用词袋模型表示或训练 Doc2Vec 模型。这些都是很好的选择(你可以并且应该尝试这些方法),但它们同样缺乏文本的上下文理解。由于我们语料库中的文本不算太长,我们仍然可以使用句子变换器进行嵌入。让我们继续使用我们为短语使用的相同嵌入模型。
投影与聚类
探索文档的语义表示的一种自然方式是将它们投影到较低维度的空间(通常是 2D),并使用这些投影来可视化文档。我们还可以查看文档在某些聚类方法下的聚类情况。
现在这一切都很棒,但投影、降维和聚类方法的领域如此广泛,以至于我常常会想:“如果我使用其他方法,结果会有实质性不同吗?” 为了应对这个问题,我在嵌入探索中添加了另一个应用程序,你可以自由快速地探索不同方法下的可视化效果。
这是我们的工作流程:
1. 我们可能会在继续之前减少嵌入的维度。你可以选择各种降维方法,或者关闭它。
2. 我们想将嵌入投影到 2D 空间中,以便我们可以对其进行可视化。
3. 我们可能想要对嵌入进行聚类,以查看哪些文档被归为一组。

嵌入探索中的聚类与投影工作流
现在,在进行此操作时,我们还需要了解一些关于文档的外部信息(文本内容、标题等),否则我们没有太多可解释的内容。
让我们创建一个包含以下列的数据框:
1. 每个文档的前 400 个字符,以便我们可以了解文本的内容。
2. 文本的长度,以便我们可以在可视化中查看哪些文本较长,哪些较短。
3. 它们在数据集中来源的组。
import pandas as pd
import numpy as np
# Extracting text lengths in number of characters.
lengths = [len(text) for text in corpus]
# Extracting first 400 characters from each text.
text_starts = [text[:400] for text in corpus]
# Extracting the group each text belongs to
# Sklearn gives the labels back as integers, we have to map them back to
# the actual textual label.
group_labels = np.array(dataset.target_names)[dataset.target]
# We build a dataframe with the available metadata
metadata = pd.DataFrame(dict(length=lengths, text=text_starts, group=group_labels))
然后我们可以启动应用程序,传递元数据,以便我们可以悬停并查看有关文档的信息。
from embedding_explorer import show_clustering
show_clustering(
newsgroups,
vectorizer=encoder,
metadata=metadata,
hover_name="group", # Title of hover box is going to be the group
hover_data=["text", "length"] # We would also like to see these on hover
)
当应用程序启动时,你会首先看到这个屏幕:

聚类应用中的选项
运行聚类后,你将能够查看按聚类着色的所有文档的地图。你可以悬停在点上以查看文档的元数据……

聚类应用截图
并且在底部你甚至可以选择点的颜色、标签和大小。

带有文档大小的聚类
总结
对文本数据的探索性分析是困难的。我们已经研究了几种使用最先进的机器学习技术进行互动调查的方法。我希望本文讨论的方法以及 embedding-explorer Python 包对你未来的研究/工作有所帮助。
和平 ✌️
((本文中的所有图片均来自 embedding-explorer 的文档,由作者制作))
探索 DAX 中的时间智能变体
原文:
towardsdatascience.com/explore-variants-of-time-intelligence-in-dax-e795545e2a40
我们在 DAX 中有一些有价值的函数来计算前一年、年初至今等。但其他函数允许我们用不同的方式计算相同的结果。让我们探索我们的选项及其重要性。
·发表于Towards Data Science ·阅读时间 11 分钟·2023 年 11 月 6 日
--

图片由Lukas Blazek提供,来源于Unsplash
引言
当我们查看报告中的数值时,我们必须赋予它们意义。
最常见的做法是将其与前一时期的数值进行比较。
这时,时间智能就会发挥作用,成为我们在报告中想展示的主要内容之一。
例如,查看以下报告页面:

图 1 — 示例报告(图由作者提供)
在这里,你可以看到销售数据与前一年、年初至今和滚动总额的比较。
这是一个典型的报告,广泛使用时间智能逻辑。
现在,我将向你展示如何创建测量值来计算这些结果以及使用不同 DAX 函数的变体。
为了更好地阅读,我将仅在表格中显示结果。
DAX 中的时间智能
我们在 DAX 中有大量的时间智能函数。查看dax.guide,我们可以找到超过 35 个函数的列表。
一些函数包含了完整的时间智能逻辑,例如TOTALYTD()或SAMEPERIODLASTYEAR()。或者,我们可以使用一些更基础的函数,这些函数必须与CALCULATE()配合使用,例如DATEADD()或DATESBETWEEN()。
我想展示如何使用“易于使用”的函数,然后展示如何使用基础函数来探索它们的好处。
有时,我们得到不同的结果取决于使用的函数,而有时,通过使用正确的函数可以避免问题。
在结论中,我将解释了解不同函数和计算选项为何重要。
日期表
我们必须拥有一个适当的日期表才能正确使用时间智能函数。
当我在 Google 上搜索“为什么我们需要在 Power BI 中使用日期表”时,我得到了 149,000,000 个结果。因此,关于为什么需要它的学习材料非常多。
我在 Medium 上的第一篇文章之一是关于日期表的:
时间是最关键的报表维度。你可以通过使用一个好的日期表来扩展你的报表功能。
towardsdatascience.com
简单来说,我可以使用中央日期表做很多酷的事情。
例如,我可以在日期表中添加具有特定逻辑的列,以便在 DAX 和 Power BI 中更方便地工作。
在本文的其余部分,我假设你知道如何构建日期表,并且已经在数据模型中将该表标记为“日期表”(在 Power BI Desktop 中设置和使用日期表 — Power BI | Microsoft Learn)。
年初至今 (YTD)
我将 YTD 函数作为季度到日期 (QTD) 和月份到日期 (MTD) 的其他两个变体的占位符。
YTD 计算从年初到当前日期(当前筛选上下文)的所有日期的表达式。
让我们从 TOTALYTD() 函数开始:
Sales YTD = TOTALYTD([Sum Sales], 'Date'[Date])
如果日期表在数据模型中标记为“日期表”,这个函数可以完美运行。
如果你不想这样做,这个函数将无法工作。
在这种情况下,DATESYTD() 函数必须与 ALL() 函数一起使用:
Online Sales (YTD) =
VAR YTDDates = DATESYTD('Date'[Date])
RETURN
CALCULATE([Sum Online Sales]
,YTDDates
,ALL('Date')
)
旁注: 如果你不想将日期表设置为数据模型中的“日期表”,你必须在所有度量值中添加 ALL(‘Date’) 过滤器。
无论如何,我更喜欢使用包含 CALCULATE() 的形式。
由于我总是将我的日期表标记为“日期表”,我不需要添加 ALL(‘Date’) 过滤器:
Online Sales (YTD) =
VAR YTDDates = DATESYTD('Date'[Date])
RETURN
CALCULATE([Sum Online Sales]
,YTDDates
)
这三种形式中的任何一种都会得到相同的结果:

图 2 — YTD 结果(作者提供的图)
包含 CALCULATE() 和 DATESYTD() 的形式更为灵活,并允许在必要时添加额外的逻辑。而 TOTALYTD() 执行的过程与组合前两种函数相同。
我在 DAX Studio 中使用服务器计时检查了这个,执行计划是相同的。
前一年(PY)
要获取前一年值,我们可以使用 SAMEPERIODLASTYEAR() 函数:
Sum Online Sales PY =
CALCULATE([Sum Online Sales],
SAMEPERIODLASTYEAR('Date'[Date])
)
像 TOTALYTD() 一样,这是计算前一年值的一种简单方法。
要了解这个函数的作用,我们可以使用 DATEADD()。
前一年 — 使用 DATEADD
使用 DATEADD() 计算前一年值的表达式如下:
Online Sales (PY 2) =
CALCULATE(
[Sum Online Sales]
,DATEADD('Date'[Date], -1, YEAR)
)
在这两种情况下,结果是一样的:

图 3 — 两种 PY 形式的结果(作者提供的图)
像 YTD 计算一样,SAMEPERIODLASTYEAR() 执行与使用 DATEADD() 时相同的过程。
但 DATEADD() 与 SAMEPERIODLASTYEAR() 不同的是,我可以定义所需数据的周期和距离。
例如,我想获取前一个季度:
Online Sales (PQ) =
CALCULATE(
[Sum Online Sales]
,DATEADD('Date'[Date], -1, QUARTER)
)
或者我想获取前一个月:
Online Sales (PM) =
CALCULATE(
[Sum Online Sales]
,DATEADD('Date'[Date], -1, MONTH)
)
对于前一个学期,我可以这样做:
Online Sales (PS) =
CALCULATE(
[Sum Online Sales]
,DATEADD('Date'[Date], -6, MONTH)
)
这些变体的结果如下:

图 4 — 使用 DATEADD() 的度量结果(作者提供的图)
理解这些度量计算每个周期(年、学期、季度和月)的前一个周期是很重要的:

图 5 — 不同级别度量的结果(作者提供的图)
对我来说,结果看起来如此直观,但我认为指出这个细节是值得的。
一个重要细节:DATEADD() 可以处理天、月、季度和年。这就是为什么当我需要获取前一个学期时,我需要回溯六个月或两个季度。
讨论时间智能时,还有更多内容需要探讨。
PARALLELPERIOD() 怎么样?
PARALLELPERIOD() 函数非常有趣。
我强烈建议阅读 dax.guide 上的描述并观看附带的视频。
简而言之,PARALLELPERIOD() 将当前日期的过滤上下文扩展到整个期间。
看看以下度量:
Online Sales PP 3 Month =
CALCULATE([Sum Online Sales]
,PARALLELPERIOD('Date'[Date], -3, MONTH)
)
乍一看,它看起来类似于使用 DATEADD() 回溯 3 个月的度量。
但看看结果:

图 6 — PARALLELPERIOD() 的结果(作者提供的图)
如你所见,即使在天级别,三个月前的销售结果也是一样的。
而使用 DATEADD() 时,天级别的结果与月级别的结果相同,但用于前一个周期,如度量中设置的(例如,前一个季度)。
当我写另一个度量时,如下:
Online Sales PP 2 Quarter =
CALCULATE([Sum Online Sales]
,PARALLELPERIOD('Date'[Date], -2, QUARTER)
)
结果如下:

图 7 — 前一个季度的 PARALLELPERIOD 结果(作者提供的图)
正如你所预期的那样,每个月内的天数结果是相同的。
如前所述,查看该函数的描述,它可能会让你感到惊讶。
滑动窗口(过去三个月)
需要探索的最后一种情况是所谓的“滑动窗口”。
这是当你想计算一个定义时间窗口的总和时。
例如,“过去三个月”将基于实际日期计算过去三个月的总和。
我们使用 DATESINPERIOD() 和 PARALLELPERIOD() 函数来实现这一点。
但要小心。这两个函数的工作方式不同,给出的结果也不同。
让我们看看这两者并检查结果。
要使用 DATESINPERIOD() 计算过去三个月的结果,我们可以编写这个度量值:
Online Sales (Last three Months) =
VAR FirstSelDate = MAX('Date'[Date])
RETURN
CALCULATE([Sum Online Sales]
,DATESINPERIOD('Date'[Date], FirstSelDate, -3, MONTH)
)
让我们看看结果:

图 8 — 使用 DATESINPERIOD() 计算过去三个月的结果(图表由作者提供)
如上图所示,每个月的总和是实际月份加上前两个月的总和。
但要注意:当我们查看单个日期时,结果略有不同:

图 9 — 使用 DATESINPERIOD() 计算每一天的过去三个月的结果(图表由作者提供)
这是因为该度量值获取每一行的最后日期,并将其作为起点向回“移动”三个月。
因此,当我们查看单个日期时,起点是该特定日期,而不是整个月份。
获取过去三个月总和的一个变体可以是:
Online Sales (Last three Months) v2 =
VAR FirstSelDate = MIN('Date'[Date]) - 1
RETURN
CALCULATE([Sum Online Sales]
,DATESINPERIOD('Date'[Date], FirstSelDate, -3, MONTH)
)
这个度量值返回的结果是从所选周期的第一天减去一天,即上一周期的最后一天的在线销售总和。
我使用了“周期”一词。但我应该使用“间隔”一词,因为根据 dax.guide,我可以使用 DATESINPERIOD() 计算天、月、季度和年的间隔。
在下面的情况下:当前月份之前的三个月的销售额。

图 10 — 当前周期之前的过去三个月的结果(图表由作者提供)
当观察这些天的结果时,这个度量值返回的结果与第一个度量值相同。
以下变体使用 DATESBETWEEN():
Online Sales Last three Months (DP) =
VAR FirstDateInPeriod = EOMONTH(MIN('Date'[Date]), -3) + 1
VAR LastDateInPeriod = EOMONTH(MAX('Date'[Date]), 0)
RETURN
CALCULATE([Sum Online Sales]
,DATESBETWEEN('Date'[Date], FirstDateInPeriod, LastDateInPeriod)
)
函数 DATESBETWEEN() 返回第一个日期和第二个日期之间的日期列表(第二个和第三个参数)。
在这种情况下,我使用 EOMONTH() 函数计算三个月窗口的第一天和最后一天。
结果与 DATESINPERIOD() 几乎相同:

图 11 — 使用 DATESBETWEEN() 计算“过去三个月”的结果(图表由作者提供)
但使用 DATESINPERIOD() 和 DATESBETWEEN() 之间有两个显著的区别:
-
即使在天级别,使用
DATESBETWEEN()时结果也是相同的。 -
但是在月份以上的层级(季度、学期和年)上,结果是错误的。
这种情况发生是因为度量获取了当前时期的起始和结束日期。在月份以上的层级,这些日期是基于当前层级的日期。
为了解决第二个问题,我们可以将度量更改为以下内容:
Online Sales Last three Months (DP) v2 =
VAR FirstDateInPeriod = EOMONTH(MIN('Date'[Date]), -3) + 1
VAR LastDateInPeriod = EOMONTH(MIN('Date'[Date]), 0)
RETURN
CALCULATE([Sum Online Sales]
,DATESBETWEEN('Date'[Date], FirstDateInPeriod, LastDateInPeriod)
)
这是结果:

图 12 — 使用 DATESBETWEEN 的“过去三个月”结果 — 第二种变体(作者图示)
我用红色标记了之前的结果,其中年份、学期和季度的结果是错误的,用蓝色标记了新的结果。
在这一点上,你必须决定哪种变体最适合你的需求,因为需求可能因情况而异。
为什么要以不同的方式做
那么,我们为什么要考虑不同的方法来计算时间智能结果?
这样做有很多原因。
例如,为了理解不同函数的效果。
所有这些函数返回一个包含一列和日期列表的表格。但不同的函数以略微不同的方式构建这个表格。
因此,现在我们处于更好的位置来满足报告的要求,因为我们知道更多提供所需结果的方法。
此外,我们可以提出不同的解决方案,并挑战我们的利益相关者思考这些方案,找到最适合他们需求的解决方案。
下一个原因是灵活性。
考虑 YTD 的两种变体。虽然 TOTALYTD() 返回正确的结果,但使用第二种变体结合 CALCULATE() 和 DATESYTD() 时,我们更具灵活性,同时性能和效率没有任何损失。
现在我们已经讨论了理解这些度量的不同变体的最后一个原因:性能和效率。
某个特定的计算可能在一种场景中没有问题,但在另一种场景中可能很慢。
因此,知道如何以不同的方式进行计算可能会救命。
例如,“过去三个月”的不同变体虽然效率很高,却会返回不同的结果。我分析了这四种变体,性能非常相似。
结论
在探索了不同变体后,我们在计算涉及时间智能函数的结果时拥有了更广泛的工具箱。
鉴于我们可以计算结果的多种方式以及利益相关者对业务逻辑的理解存在很多差异,工具箱越大越好。
但有时,当制定可能的解决方案时,我们必须关注许多细节,因为“魔鬼在于细节”,我们必须考虑用户如何在报告中使用这些度量。
在这里,我仅覆盖了 DAX 中可用的少数时间智能函数。还有许多函数覆盖了特定的需求。我强烈推荐查看 dax.guide 并探索各种可能性。
像往常一样,我在撰写关于这个复杂主题的文章时学到了些细节,希望你也能了解更多。
如果有任何你感兴趣的话题我可以在未来的文章中涵盖,请留下评论。

图片由 Brett Jordan 提供,来源于 Unsplash
参考文献
SQLBI 文章 Time Intelligence in Power BI Desktop 是关于此主题的非常有用的来源。你可以在那里找到更多有关处理时间智能函数的信息。
我使用 DAX Studio 和内置的 Server Timings 功能分析了性能。要了解更多关于此功能的信息,我建议阅读我关于这个主题的文章:
## 如何从 Power BI 使用 DAX Studio 获取性能数据
有时我们会遇到报告加载缓慢的问题,我们需要弄清楚原因。我们将看到如何收集性能数据和…
towardsdatascience.com
我使用了 Contoso 示例数据集,和我之前的文章一样。你可以从微软这里免费下载 ContosoRetailDW 数据集。
Contoso 数据可以在 MIT 许可证下自由使用,详细信息请参见这里。
我扩大了数据集,以使 DAX 引擎工作得更辛苦。
Online Sales 表包含 7100 万行(而不是 1260 万行),Retail Sales 表包含 1850 万行(而不是 340 万行)。
[## 每当 Salvatore Cagliari 发布新文章时获取电子邮件通知。
每当 Salvatore Cagliari 发布新文章时获取电子邮件通知。通过注册,如果你没有 Medium 帐户,你将会创建一个…
medium.com](https://medium.com/@salvatorecagliari/subscribe?source=post_page-----e795545e2a40--------------------------------)
探索全球野生动物 GIS 数据库
原文:
towardsdatascience.com/exploring-a-global-wildlife-gis-database-0453723ae5c9

所有哺乳动物栖息地的全球地图,每个栖息地随机着色。
使用 Python 来表征国际自然保护联盟(IUCN)的地理空间数据库。
·发表于数据科学前沿 ·13 分钟阅读·2023 年 10 月 19 日
--
国际自然保护联盟(IUCN)启动了多个保护野生动物的项目。其中一个努力促成了一个高质量的全球地理空间数据库,包含了超过 100,000 种物种的栖息地。在本文中,我将深入探讨其子集,重点关注陆生哺乳动物。
IUCN 红色名录濒危物种数据库包含了超过 150,000 种物种,其中 80%以上具有栖息地的地理空间信息。这个数据库的庞大规模提出了几个挑战,我可能会在以后的文章中解决这些问题。现在,我专注于一个较小的子集——全球数据库中包含 12,436 条陆生哺乳动物记录,每条记录对应一个物种的栖息地。这个哺乳动物-栖息地数据库基于大约四百个不同的来源,包含了 5,626 种以双名法命名的物种,登记时间为 2008 年至 2022 年。此外,数据库包括详细的分类信息,如物种的目和科。此外,数据库的一个主要优势是,它以多边形文件的形式提供详细的栖息地地理空间信息,我将在后续部分详细探讨。
首先,我将介绍并探讨数据集的非几何特征,然后进行一些特定于不同物种地理空间分布的分析步骤。通过这些分析,我希望推广这一数据源,并鼓励未来对其进行潜在应用于野生动物保护政策的研究。
你可以在IUCN 数据资源找到所有 IUCN 数据源,我从中下载了陆生哺乳动物的多边形数据(搜索日期:2023 年 10 月 2 日 15:30:02)
本文中的所有图片均由作者创作。
1. 统计探索
1.1. 解析数据集
首先,让我们使用 GeoPandas 解析数据库,看看它包含了什么:
import geopandas as gpd # version: 0.9.0
gdf_iucn = gpd.read_file('MAMMALS_TERRESTRIAL_ONLY')
print('Number of records: ', len(gdf_iucn))
print('Number of attributes: ', len(gdf_iucn.keys()))
gdf_iucn.head(3)
该单元格的输出:

IUCN MAMMALS_TERRESTRIAL_ONLY 数据集的预览。
地理空间数据文件似乎有 29 个属性,我们也可以在元数据文档中了解更多。让我们在这里探索一下吧!
我们将处理的物种数量,每个物种都有一个唯一的 ID,如下所示:
print(len(set(gdf_iucn.id_no)))
print(len(set(gdf_iucn.sci_name)))
哪个单元格返回值 5626 两次,确保每个物种确实有一个唯一的id_no和一个唯一的sci_name。
1.2 分类学类别
我们有几个列描述分类学类别:kingdom、phylum、class、order_、family和genus。让我们看看哪些类别最为频繁。在这里,我只保留物种名称和分类学类别作为唯一的配对,因此频率以每个分类级别的唯一物种数来表示。
from collections import Counter
for a in ['kingdom', 'phylum', 'class', 'order_', 'family', 'genus']:
print(a, Counter(gdf_iucn[[a, 'sci_name']].drop_duplicates()[a].to_list()).most_common(3), len(set(gdf_iucn[a])))
该单元格的输出:

每个分类学变量的前三个值。
这些统计数据如下:ANIMALIA(动物界)、CHORDATA(脊索动物门)和 MAMMALIA(哺乳纲)覆盖了数据集中的所有物种,这意味着它们都是动物,所有物种都有脊髓,并且它们都是哺乳动物。
首个区分出现在生物学级别的分类顺序,其中前 3 名 — 26 个不同的分类顺序中 — 是“RODENTIA”(啮齿动物目)有 2275 个条目,“CHIROPTERA”(翼手目)有 1317 个条目,以及“PRIMATES”(灵长目)有 521 个条目。
数据集中共有 136 个科,其中排名前列的是“MURIDAE”(鼠科)有 763 个条目,“CRICETIDAE”(仓鼠科)有 659 个条目,以及“VESPERTILIONIDAE”(晚蝠科)有 461 个条目,这些是分类顺序级别下的更小的子类,汇集了类似的物种。
最终出现的是属,其中有 1171 个不同的类别,排名前列的是“Crocidura”(长尾鼠)有 196 个条目,“Myotis”(耳蝠)有 120 个条目,以及“Rhinolophus”(盔蝠)有 92 个条目。蝙蝠侠越来越真实了,是不是?
在拆解了顶级列表且未看到太多我们日常生活中熟悉的物种后,我们还可以对一些更为知名的物种进行反向搜索:
gdf_iucn[gdf_iucn.genus=='Canis'].head(5)

属于 Canis 的前五个记录 — 与狗最亲近的亲属。
这个快速搜索查询显示了家犬属的结果,其属名为“Canis”。在科学分类中,家犬属于 Canis 属,该属包括各种犬科动物的物种和亚种。家犬的科学名称是 Canis lupus familiaris。Canis 属的其他成员包括灰狼(Canis lupus)、郊狼(Canis latrans)和金豺(Canis aureus)等。
此外,subspecies 和 subpop 列也属于此处;然而,在这个子集里,它们实际上是空字段。* 同样,tax_comm 列仅包含一些关于分类数据的评论,因此我忽略它。
1.3 栖息地类型如何被表征?
了解不同物种及其分类后,我们来看看它们的栖息地类型。presence、origin、seasonal 列以及它们在legend中的总结包含了这些其他属性的专家评估,前 3 项在 1 到 6 的范围内评分(从现存到灭绝/存在不确定),并给 legend 参数提供了文字描述。其网站上的一个示例展示了如何实现这一点:
-
presence = 1(现存);origin = 2(再引入);seasonal = 2(繁殖季节)对应于‘现存 & 再引入(繁殖)’
-
presence = 3(可能现存);origin = 1(本地);seasonal = 1(常驻)对应于‘可能现存(常驻)’
gdf_iucn[['presence', 'origin', 'seasonal', 'legend']].head(5)

数据集的栖息地评分子集预览。
让我们看看这些栖息地属性的分布情况:
import matplotlib.pyplot as plt
from collections import Counter
def get_distribution(x):
return list(zip(*[(k, v) for k, v in Counter(x.to_list()).most_common()]))
f, ax = plt.subplots(1,3,figsize = (15,4))
for idx, feat in enumerate(['presence', 'origin', 'seasonal']):
values, frequencies = get_distribution(gdf_iucn[feat])
ax[idx].bar(values, frequencies)
ax[idx].set_yscale('log')
ax[idx].set_title(feat)
该单元格的输出:

栖息地评分的直方图。
f, ax = plt.subplots(1,1,figsize = (15,5))
values, frequencies = get_distribution(gdf_iucn['legend'])
ax.bar(values, frequencies)
ax.set_title('Frequency of habitat-characteristics', fontsize = 20)
ax.set_yscale('log')
ax.set_xticks(range(len(values)))
ax.set_xticklabels(values, rotation = 60, ha = 'right')
plt.show()
该单元格的输出:

栖息地图例值的直方图。
这个图表显示了最常见的标签是现存。然而,灭绝和可能灭绝的类别也在前 10 名之中,这已经是一个令人担忧的迹象。
描述栖息地特征的其他列有 island 和 marine;然而,由于这个子样本的性质,这些在这里不适用。
1.4. 元信息
在深入数据表时,您还可能找到有关数据本身的来源和类型的信息,这些信息由 compiler、yrcompiled、citation 和 source 捕获。以下快速排名显示了哪些实体在编制此数据库时最为繁忙,哪些出版物是最常见的来源,以及这些记录引用最多的是谁:
Counter(gdf_iucn.compiler).most_common(5)

前 5 个数据记录编制实体。
Counter(gdf_iucn.source).most_common(5)

前 5 个数据来源。
Counter(gdf_iucn.citation).most_common(5)

前 5 个被引用最多的来源。
此外,编制年份列可以提示数据集的更新程度,显示记录在 2008 年到 2022 年之间的值。2008 年的巨大峰值可能对应数据库的启动,之后,可能由于几年的资金有限和/或启动后对更新的需求减少。
min(gdf_iucn.yrcompiled), max(gdf_iucn.yrcompiled)
f, ax = plt.subplots(1,1,figsize = (15,5))
ax.set_title('The number of records per year', fontsize = 20)
values, frequencies = get_distribution(gdf_iucn['yrcompiled'])
ax.bar(values, frequencies)

随时间变化的数据记录数量。
1.5. 危害信息
这个数据集中最有趣的类别变量可能包含有关物种状况的严重程度的信息。为了描述这一点,IUCN 红色名录引入了九个类别来描述不同物种的保护状态,这些类别记录在类别列中。在这九个类别中,有八个存在于此数据集中:
-
严重濒危(CR),
-
数据不足(DD),
-
濒危(EN),
-
野外灭绝(EW),
-
灭绝(EX),
-
最低关切(LC),
-
不受威胁(NT),
-
区域灭绝(RE),
-
脆弱(VU)
根据这一分类,严重濒危(CR)、濒危(EN)和脆弱(VU)物种被考虑在内。
让我们重新映射成可读格式,并计算各类别,确保每个物种只计算一次:
category_d = { 'EX' : 'Extinct',
'EW' : 'Extinct in The Wild',
'RE' : 'Regionally Extinct',
'CR' : 'Critically Endangered',
'EN' : 'Endangered',
'VU' : 'Vulnerable',
'DD' : 'Data Deficient',
'LC' : 'Least Concern',
'NT' : 'Not Threatened'
}
gdf_iucn['category'] = gdf_iucn['category'].map(category_d)
Counter(gdf_iucn[['sci_name', 'category']].drop_duplicates().category).most_common()
该单元格的结果:

危害类别的频率分布。
这些统计数据显示,3205(56%)的物种属于最不关心类别;它们现在是安全的。然而,22%属于濒危(脆弱、濒危或严重濒危)。此外,我们缺少 14%的数据,还有大约 15 个物种已经灭绝。让我们在这里缅怀它们:
sorted(set(gdf_iucn[gdf_iucn.category.isin(['Extinct', 'Extinct in The Wild'])].sci_name.to_list()))
已经灭绝的物种(感谢 ChatGPT 的英文翻译):
-
Dusicyon australis(福克兰岛狼)
-
Dusicyon avus(达尔文狐)
-
Juscelinomys candango(塞拉多鼠)
-
Leporillus apicalis(小枝巢鼠)
-
Melomys rubicola(荨麻岛鼠)
-
Nesoryzomys darwini(达尔文稻鼠)
-
Nyctophilus howensis(霍威岛长耳蝙蝠)
-
Oryx dammah(弯角羚)
-
Palaeopropithecus ingens(巨型眼镜猴)
-
Pennatomys nivalis(山地侏儒负鼠)
-
Pipistrellus murrayi(东方长翼蝠)
-
Pteropus subniger(黑色飞狐)
-
Pteropus tokudae(马里亚纳果蝠)
-
Sus bucculentus(爪哇疣猪)
-
Xenothrix mcgregori(麦格雷戈猿)
让我们也可视化频率分布:
def get_color(x):
if x in ['Critically Endangered', 'Endangered', 'Vulnerable']:
return 'red'
elif x in ['Extinct in The Wild', 'Regionally Extinct', 'Extinct']:
return 'k'
else:
return 'green'
f, ax = plt.subplots(1,2,figsize = (12,6))
ax[0].set_title(70 * ' ' + 'The number of species category', fontsize = 20, pad = 30)
values, frequencies = get_distribution(gdf_iucn[['sci_name', 'category']].drop_duplicates().category)
colors = [get_color(v) for v in values]
for idx in range(2):
ax[idx].bar(values, frequencies, color = colors)
ax[idx].set_xticks(range(len(values)))
ax[idx].set_xticklabels(values, rotation = 60, ha = 'right')
ax[1].set_yscale('log')
plt.tight_layout()

每个危害类别中的物种数量。
2. 地理空间探索
2.1. 栖息地大小分布
我们可能首先关注的两个几何信息是 SHAPE_Leng 和 SHAPE_Area,分别对应栖息地边界的总长度和整个面积。准备这些数据是方便的,因为根据地图投影计算长度和面积是有挑战性的。关于这个主题的更多信息请见这里。
现在,来看一下最显著和最小的区域——哪些物种随处可见,哪些物种需要近距离观察和探险才能追踪到?
# lets sum up the area of each patch a species may have
gdf_iucn.groupby(by = 'sci_name').sum().sort_values(by = 'SHAPE_Area').head(10)
此单元的输出:

对哺乳动物栖息地数据库的面积聚合版本进行预览。
现在获取基于栖息地面积的前十和后十物种:
gdf_iucn.groupby(by = 'sci_name').sum().sort_values(by = 'SHAPE_Area').head(10).index.to_list()
此单元的输出,以及英文翻译如下:
1. Melomys rubicola(灌木丛鼠)
2. Eudiscoderma thongareeae(Thongaree 的 Discoderma)
3. Murina balaensis(巴拉长尾鼠)
4. Nyctophilus nebulosus(东部管鼻蝠)
5. Cavia intermedia(圣卡塔里娜豚鼠)
6. Fukomys livingstoni(利文斯顿地鼠)
7. Rhinolophus kahuzi(卡胡兹马蹄蝠)
8. Microtus breweri(布鲁尔田鼠)
9. Myotis nimbaensis(尼姆蝙蝠)
10. Hypsugo lophurus(草原鼠耳蝠)
现在来看另一端——最大的栖息地:
gdf_iucn.groupby(by = 'sci_name').sum().sort_values(by = 'SHAPE_Area', ascending = False).head(10).index.to_list()
1. Mus musculus(家鼠)
2. Vulpes vulpes(红狐)
3. Canis lupus(灰狼)
4. Mustela erminea(白鼬)
5. Mustela nivalis(最小鼬)
6. Ursus arctos(棕熊)
7. Gulo gulo(狼獾)
8. Alces alces(驼鹿)
9. Rangifer tarandus(驯鹿)
10. Lepus timidus(山兔)
2.2. 全球可视化栖息地多边形
还有一列我没有提到,那就是几何信息。然而,它可能包含所有信息中最丰富的内容。首先,绘制一个简单的 GeoPandas 地图,显示每个栖息地多边形的位置。在此可视化中,我用有色阴影区域和细边框标记每个栖息地。颜色基于tab20 色图随机分配。
此外,我将地理空间数据转换为 Mollweide 投影,以使我的地图更美观。有关全球地图投影的更多信息,请查看这里。
# transform the coordinate reference system
gdf_iucn_t = gdf_iucn.copy()
gdf_iucn_t = gdf_iucn_t.to_crs('+proj=moll +lon_0=0 +x_0=0 +y_0=0 +datum=WGS84 +units=m +no_defs')
f, ax = plt.subplots(1,1,figsize=(15,10))
gdf_iucn_t.sample(200).plot(ax=ax, edgecolor = 'k', linewidth = 0.5, alpha = 0.5, cmap = 'tab20')
此代码块的输出:

可视化 200 种随机选择的物种的栖息地斑块,每个斑块颜色随机。
现在,完整的数据集!此外,为了进一步改进我的可视化,我添加了一个带有 library contextily 的底图。我添加了两个不同的版本,让我们根据个人口味来决定!
import contextily as ctx
f, ax = plt.subplots(1,1,figsize=(15,10))
gdf_iucn_t.plot(ax=ax, edgecolor = 'k', linewidth = 0.5, alpha = 0.15, cmap = 'tab20')
ctx.add_basemap(ax, alpha = 0.8, crs = gdf_iucn_t.crs, url = ctx.providers.Esri.WorldPhysical)
ax.axis('off')
ax.set_ylim([-9.5*10**6, 9.5*10**6])
此代码块的输出:

所有哺乳动物栖息地的全球地图,每个斑块颜色随机。
f, ax = plt.subplots(1,1,figsize=(15,10))
gdf_iucn_t.plot(ax=ax, edgecolor = 'k', linewidth = 0.5, alpha = 0.15, cmap = 'tab20')
ctx.add_basemap(ax, alpha = 0.8, crs = gdf_iucn_t.crs, url = ctx.providers.Esri.WorldGrayCanvas)
ax.axis('off')
ax.set_ylim([-9.5*10**6, 9.5*10**6])
#plt.savefig('worldmap_habitats_WorldGrayCanvas.png', dpi = 600, bbox_inches = 'tight')

所有哺乳动物栖息地的全球地图,每个斑块颜色随机。
2.3. 本地可视化栖息地多边形
在同一张地图上绘制每个栖息地,数以千计,确实会产生一些令人兴奋的图形;然而,从中得出见解并不简单。在得出这些见解的过程中,让我们放大,可视化几个精选物种的栖息地。
例如,当我们搜索长颈鹿(Giraffa camelopardalis)时,会发现十个不同的栖息地斑块:
gdf_iucn[gdf_iucn.genus.str.contains('Giraffa')].head(5)
此代码块的输出:

长颈鹿栖息地斑块的预览。
现在,让我们使用下面的代码块生成长颈鹿、猩猩、狮子和非洲象的栖息地图!
f, ax = plt.subplots(1,1,figsize=(15,10))
gdf_iucn[gdf_iucn.sci_name=='Giraffa camelopardalis'].plot(ax=ax, edgecolor = 'k', linewidth = 0.5, alpha = 0.9, color = '#DAA520')
ax.set_xlim([-5, 55])
ax.set_ylim([-38, 15])
ax.set_title('The habitat patches of Giraffa camelopardalis (Giraffes)', fontsize = 18, y = 1.03)
ctx.add_basemap(ax, alpha = 0.8, crs = gdf_iucn.crs, url = ctx.providers.Esri.WorldPhysical)
plt.savefig('1_giraffe.png', dpi = 200)


长颈鹿、猩猩、狮子和非洲象的栖息地图。


2.4. 映射到国家
在仔细研究了一些选定的栖息地之后,接下来我们进行国家级的汇总。具体来说,我将每个栖息地多边形映射到国家的行政边界,然后计算物种总数、濒危物种总数及其比率。
我使用了Natural Earth 数据库的 Admin 0 — Countries 文件来获取国家级的行政边界。
world = gpd.read_file('ne_10m_admin_0_countries')
print(len(set(world.ADMIN)))
world.plot()
这个代码块的结果:

基于 Natural Earth 的国家数据集的世界地图。
让我们将濒危类别进行分组:
def is_endangered(x):
if x in ['Critically Endangered', 'Endangered', 'Vulnerable']:
return True
else:
return False
gdf_iucn['endangered_species'] = gdf_iucn.category.apply(is_endangered)
print(Counter(gdf_iucn['endangered_species']))
基于此,属于濒危物种的栖息地斑块数量为 2680,而其他为 9756。
现在建立测量每个国家物种数量和濒危物种数量的国家级字典:
number_of_all_species = gpd.overlay(world, gdf_iucn).groupby(by = 'ADMIN').count().to_dict()['geometry']
number_of_end_species = gpd.overlay(world, gdf_iucn[gdf_iucn.endangered_species==True]).groupby(by = 'ADMIN').count().to_dict()['geometry']
world['number_of_all_species'] = world.ADMIN.map(number_of_all_species)
world['number_of_end_species'] = world.ADMIN.map(number_of_end_species)
world['number_of_all_species'] = world['number_of_all_species'].fillna(0)
world['number_of_end_species'] = world['number_of_end_species'].fillna(0)
world['ratio_of_end_species'] = world['number_of_end_species'] / world['number_of_all_species']
Finally, use these updated to visualize the global distributions on the level of countries:from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.colors import LogNorm
f, ax = plt.subplots(1,1,figsize=(15,7))
ax.set_title('Total number of species', fontsize = 20, pad = 30)
world.plot(ax=ax, color = 'grey', alpha = 0.5, linewidth = 0.5, edgecolor = 'grey')
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="2%", pad=-0.01)
world[world.number_of_all_species>0].plot(column = 'number_of_all_species',ax=ax,legend_kwds={'label': "Total number of species"}, edgecolor ='k', linewidth = 1.5, cax=cax, cmap = 'Greens', legend=True, norm=LogNorm(vmin=1, vmax=world.number_of_all_species.max()))

每个国家的物种总数。
f, ax = plt.subplots(1,1,figsize=(15,7))
ax.set_title('Number of endangered species', fontsize = 20, pad = 30)
world.plot(ax=ax, color = 'grey', alpha = 0.5, linewidth = 0.5, edgecolor = 'grey')
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="2%", pad=-0.01)
world[world.number_of_end_species>0].plot(column = 'number_of_end_species',ax=ax, legend_kwds={'label': "Number of endangered of species"}, edgecolor ='k', linewidth = 1.5, cax=cax, cmap = 'RdYlGn_r', legend=True, norm=LogNorm(vmin=1, vmax=world.number_of_end_species.max()))
plt.savefig('2_map.png', dpi = 200)

每个国家的濒危物种数量。
f, ax = plt.subplots(1,1,figsize=(15,7))
ax.set_title('Ratio of endangered species', fontsize = 20, pad = 30)
world.plot(ax=ax, color = 'grey', alpha = 0.5, linewidth = 0.5, edgecolor = 'grey')
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="2%", pad=-0.01)
world.plot(column = 'ratio_of_end_species',ax=ax, legend_kwds={'label': "Fraction of endangered of species"}, edgecolor ='k', linewidth = 1.5, cax=cax, cmap = 'RdYlGn_r', legend=True, vmin = 0, vmax = 0.22)

每个国家的濒危物种比率。
3. 总结
在这篇文章中,我介绍了 IUCN 的地理空间数据集,其中包含了成千上万种物种的记录栖息地。经过简短的统计概述后,我展示了如何访问、可视化和操作这些地理空间数据集,以激发未来的工作。
虽然最后我建立了一个简单的指数来量化不同国家野生动物的濒危程度,但 IUCN 提供了现成的地图,基于更复杂的方法。务必在这里查看一下!
最后,IUCN 接受未来资助的捐款这里。
参考文献:IUCN. 2022. IUCN 红色名录。版本 2022–2。 www.iucnredlist.org. 访问时间:2023 年 10 月 02 日。
探索反事实洞见:从相关性到因果关系的数据分析
反事实在数据科学中的决策制定中的应用
·发表于 Towards Data Science ·阅读时间 12 分钟·2023 年 10 月 6 日
--

照片由 Daniele Franchi 拍摄,来源于 Unsplash
想象一下:一个世界里,天空呈现出宁静的柠檬黄色,鸟儿恢复理智,流利地用英语交流,而果树则违抗重力,展示出深紫色和电紫色的叶子,同时随时提供最美味的果实。
然后你会想,终于!这个世界变得有意义了。
嗨,大家好!
让我们回到现实中,但不要担心,因为我们即将踏上一段同样引人入胜的旅程——反事实的世界。虽然我们最初的想象可能是一种愉快的幻想,但反事实打开了另一种奇妙的大门,让我们可以探索我们世界中的“如果”。
“反事实”这个词听起来可能很复杂,但它的意思只是考虑与事实或实际事件相反的情境。尽管这个词本身是在1946 年创造的,但这个思想可以追溯到几个世纪以前,当人们开始思考“如果”情境时。
在心理学中,反事实思维经常被用来探讨那些与已发生事件不同的情境。例如,我们可能会思考,如果一名罪犯有机会选择不同的人生道路,他会做出什么样的选择。
然而,作为数据科学家,我们的关注点不在于犯罪心理学的复杂性。相反,我们感兴趣的是在数据世界中利用反事实的力量。我们的目标是理解数据为何如此呈现,并从中得出有意义的推论。我们的领域是方程式、客观方法以及这些概念在数据世界中的实际应用。
为了进一步探讨这个有趣且富有想象力的概念,让我们深入一个例子。但首先,我们需要提供一些背景信息……
用一个引人入胜的情境来设定舞台
想象一个情境,你是一名数据科学家,正在为 MM Securities 工作,这是一家虚构的安全公司,专注于评估系统漏洞。目前,他们正处于争取一个重要客户合同的关键时刻,但出现了一个重要的挑战。客户有一个独特的要求:他们想知道这些漏洞是否是勒索软件攻击的原因之一。他们建议,如果 MM Securities 能够令人信服地证明这些漏洞确实导致了勒索软件攻击,他们将热切参与业务。
这个情况引起了 MM Securities 高级领导层的兴趣,因为他们相信展示因果关系可以为他们的组织带来重大价值。因此,他们向其出色的数据科学团队寻求答案。他们给团队的假设如下:
“我们评估的特定漏洞的组织面临着更高的勒索软件攻击风险。”
幸运的是,MM Securities 在成功应对此类挑战方面有着良好的历史记录,并且拥有一个相关的数据集。这是你展现才华的时刻,深入探究问题的核心,揭示潜在的因果关系。凭借你在这个领域的专业知识,你的调查之旅开始了。
进行第一个因果假设……
现在你已经开始了调查之旅,你从简单的假设出发,即漏洞对勒索软件攻击有直接影响。

作者提供的图片
更清晰地说明如下:
自变量或怀疑的原因: 系统中的漏洞
因变量或怀疑的效果: 一次勒索软件攻击
既然我们已经整理好了数据、假设和变量,就该开始验证我们的理论了。
Python 代码来丰富我们的分析
我们将开始构建一个简单的贝叶斯模型,使用pgmpy库。此外,为了演示的需要,我们将生成一些合成数据。
我们的合成数据反映出一个相当均匀的分布,大约 50%的实例在其系统中存在漏洞。我们还设计了漏洞的存在与勒索软件攻击发生之间存在一定的正相关。
鉴于此,我们直接进入代码部分。
#Importing packages
from pgmpy.models import BayesianNetwork
from pgmpy.estimators import MaximumLikelihoodEstimator
from pgmpy.inference import VariableElimination
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
num_samples = 500
# Let's Generate synthetic data for vulnerabilities (binary: 0 or 1)
vulnerabilities = np.random.choice([0, 1], num_samples)
ransomware = [1 if np.random.uniform(0, 1) < prob * (v + 1) else 0 for v in vulnerabilities]
data = pd.DataFrame({
'Vulnerability': vulnerabilities,
'Ransomware': ransomware
})
# Designing the Network
model = BayesianNetwork([('Vulnerability', 'Ransomware')])
model.fit(data,estimator=MaximumLikelihoodEstimator)
# Check model for early errors
assert model.check_model()
inference = VariableElimination(model)
# Calculating marginal probabilities
prob_vulnerability = inference.query(variables=['Vulnerability']).values
prob_ransomware = inference.query(variables=['Ransomware']).values
# Calculating conditional probabilities i.e. P(Ransomware | Vulnerability)
evidence_vulnerability = {'Vulnerability': 1}
prob_ransomware_with_vulnerability = inference.query(variables=['Ransomware'], evidence=evidence_vulnerability)
evidence_no_vulnerability = {'Vulnerability': 0}
prob_ransomware_without_vulnerability = inference.query(variables=['Ransomware'], evidence=evidence_no_vulnerability)
# Visualization
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 10))
# Subplot 1: Data Distribution of Instances with and without vulnerabilities
ax1.pie(
prob_vulnerability,
labels = ['Data Points without Vulnerabilities', 'Data Points with Vulnerabilities'],
autopct = '%1.1f%%', colors = ['gold', 'tomato'],
explode = (0.0, 0.1))
ax1.set_title('Data Distribution for Vulnerabilities')
# Subplot 2: Data Distribution of Instances with and without Ransomware
ax2.pie(
prob_ransomware,
labels = ['Data Points without Ransomware', 'Data Points with Ransomware'],
autopct = '%1.1f%%', colors = ['gold', 'tomato'],
explode = (0.0, 0.1))
ax2.set_title('Data Distribution for Ransomware')
# Subplot 3: NetworkX Showing assumed relationship between the variables
G = nx.DiGraph()
G.add_node('Vulnerability', pos=(-0.8, 0.0))
G.add_node('Ransomware', pos=(0.8, 0.0))
G.add_edge('Vulnerability', 'Ransomware', label='Direct Effect')
node_pos = nx.get_node_attributes(G, 'pos')
edge_labels = nx.get_edge_attributes(G, 'label')
pos = nx.planar_layout(G)
nx.draw_networkx_edge_labels(G, node_pos,
edge_labels = edge_labels,
verticalalignment = 'top',
font_size = 12, ax = ax3)
nx.draw_networkx_labels(G, node_pos,
verticalalignment = 'bottom',
font_size = 12, ax = ax3)
nx.draw(G, pos, with_labels = False, node_size = 2500, node_color = 'tab:olive', ax = ax3)
ax3.set_title('Our Data Relationship Assumption Model')
# Subplot 4: Bar chart for conditional probability of Ransomware given Vulnerability
values = [prob_ransomware_without_vulnerability.values[0],
prob_ransomware_without_vulnerability.values[1],
prob_ransomware_with_vulnerability.values[0],
prob_ransomware_with_vulnerability.values[1]]
labels = ['P(R = 0|V = 0)', # P(No Ransomware Attack | No Vulnerabilites)
'P(R = 1|V = 0)', # P(Ransomware Attack | No Vulnerabilites)
'P(R = 0|V = 1)', # P(No Ransomware Attack | Vulnerabilites)
'P(R = 1|V = 1)'] # P(Ransomware Attack | Vulnerabilites)
ax4.bar(labels, values, color = ['gold', 'tab:olive', 'lightcoral', 'tomato'])
for i, value in enumerate(values):
ax4.annotate(f'{value:.2f}', (i, value), ha='center', va='bottom', fontsize=12)
ax4.set_xlabel('Vulnerability')
ax4.set_ylabel('Probability of Ransomware')
ax4.set_title('Conditional Probability of Ransomware given Vulnerability')
ax4.set_xticklabels(labels, rotation=45)
plt.tight_layout()
plt.show()
现在,让我们深入探讨可视化以获取更深层次的见解。我们使用NetworkX库来可视化初始数据分布和我们假设的图形模型。这些可视化揭示了漏洞存在与勒索软件攻击发生之间的明确相关性。
在下图中,我们通过显示勒索软件和漏洞之间的条件概率来说明这种相关性,即。
P(勒索软件 | 漏洞)

作者提供的图像
在进一步之前,让我们花点时间理解条件概率。
条件概率
条件概率,表示为 P(X|Y),仅仅表示在发生 Y 的情况下 X 发生的概率。值得注意的是,条件概率并不暗示因果关系或事件顺序;它仅涉及它们之间的相关性。
在大多数情况下,P(X|Y) ≠ P(Y|X),但这两者都不涉及因果关系,仅仅是概率从每种情况下不同的实例中推导出来。
既然我们已经阐明了这一基本概念,你可能会回忆起统计课上的一句名言:
相关性并不等于因果关系。
因此,我们当前的分析,虽然揭示了漏洞存在与勒索软件攻击之间的相关性,但还不足以证明因果关系。
那么,因果关系是什么,我们如何定义它,它与反事实有什么关系?
在数据科学中,讨论反事实通常与因果关系、干预和模型可解释性交织在一起。到目前为止,我们仅仅检查了一个现有的数据集以识别相关性,但我们还没有确定这种相关性是否暗示因果关系。
为了探索因果关系,我们从反事实分析开始。让我们考虑一个可能的反事实情境并进行测试:
可能的反事实: 如果没有漏洞,是否会有勒索软件?
可能的反事实声明: 如果没有漏洞,组织将不会受到勒索软件的影响。
为了建立因果关系,我们需要调查漏洞的移除是否消除了或至少减少了勒索软件的可能性。从理论上讲,这是可能的,但在实践中,通常需要大量时间和复杂的数据收集工作。
在我们继续代码之前,让我们讨论一个将技术分析与因果关系更广泛概念联系起来的关键概念:‘做’操作符。理解这个操作符如何工作对于测试我们的反事实假设至关重要。
‘做’微积分与干预
在概率论中,因果关系通常通过干预进行数学表示,使用‘做’操作符。
在我们当前的场景中,我们旨在使用‘do’操作符进行干预:通过处理漏洞并观察对勒索软件攻击的影响。这种干预是主动地改变某些东西,引入了事件序列的概念,超越了单纯的观察;它允许我们根据漏洞变量的变化来评估对勒索软件变量的影响。
然而,重要的是要承认,在许多实际情况中,由于各种限制和伦理考虑,进行这样的干预是不切实际或不可能的。
现在我们已经阐明了上述概念,让我们进入实际分析。
Python 中的反事实分析
为了在我们的模型中进行干预,我们引入了一个额外的变量称为‘处理’。在这种情况下,处理意味着对系统中漏洞的处理或修复。我们在这里假设 MM Securities 采取主动措施来解决客户系统中的漏洞。
具体而言,我们假设 MM Securities 对他们服务的大约 60%的组织独立施加这种处理。此外,这种处理在修复漏洞方面大约 90%的时间有效。然而,如果不施加处理,情况不会改变,漏洞的数据分布保持不变。
值得注意的是,这一建模过程需要仔细考虑数据、深刻理解系统以及领域专长,这在使用图形模型时通常是必要的。

图片来源:作者
对于我们分析的下一步,我们可以使用已知的概率在 Python 中轻松编写此模型。幸运的是,pgmpy提供了一个TabularCPD选项,其中 CPD 代表条件概率分布。
我们的新场景可以总结如下:
我们引入了直接影响漏洞的‘处理’变量。
-
40%的实例没有接受“处理”:P(do(No Treatment)) = 0.4,而 60%则接受了:P(do(Treatment)) = 0.6。
-
当没有施加处理时,漏洞的分布保持不变,可以从我们最初的数据中推导出:P(Vulnerability = 0| do(No Treatment)) = 0.52,P(Vulnerability = 1| do(No Treatment)) = 0.48
-
当施加处理时,无论当前的漏洞数据点如何,90%的情况会导致零漏洞:P(Vulnerability = 0| do(Treatment)) = 0.9,而 10%仍然存在漏洞:P(Vulnerability = 1 | do(Treatment)) = 0.1
我们还从之前的数据中获得了勒索软件和漏洞的条件概率知识,这些知识被纳入了我们的分析:
-
P(Ransomware = 0 | Vulnerability = 0) = 0.80
-
P(Ransomware = 0 | Vulnerability = 1) = 0.56
-
P(勒索软件 = 1| 漏洞 = 0) = 0.20
-
P(勒索软件 = 1| 漏洞 = 1) = 0.44
有了这些信息,我们现在可以继续分析干预的效果。
那么让我们直接进入 Python 实现:
from pgmpy.models import BayesianNetwork
from pgmpy.factors.discrete import TabularCPD
from pgmpy.inference import VariableElimination
import matplotlib.pyplot as plt
# Define Conditional Probability Distributions (CPDs)
cpd_treatment = TabularCPD(variable='Treatment', variable_card=2, values=[[0.4], [0.6]])
cpd_vulnerability = TabularCPD(variable='Vulnerability', variable_card=2, values=[[0.52, 0.9], [0.48, 0.1]],
evidence=['Treatment'], evidence_card=[2])
cpd_ransomware = TabularCPD(variable='Ransomware', variable_card=2, values=[[0.80, 0.56], [0.20, 0.44]],
evidence=['Vulnerability'], evidence_card=[2])
# Create Network and add CPDs to the model
model = BayesianNetwork([('Treatment','Vulnerability'),('Vulnerability', 'Ransomware')])
model.add_cpds(cpd_treatment ,cpd_vulnerability, cpd_ransomware)
# Check model consistency
assert model.check_model()
inference = VariableElimination(model)
# Caluclating Total Effect of Treatment
evidence_treatment = {'Treatment': 1}
prob_treatment = inference.query(variables=['Ransomware'], evidence=evidence_treatment)
evidence_no_treatment = {'Treatment': 0}
prob_no_treatment = inference.query(variables=['Ransomware'], evidence=evidence_no_treatment)
total_effect = prob_treatment.values[1] - prob_no_treatment.values[1]
print('Total Effect of the treatment: ', total_effect)
# Creating Visualization
# 2 Subplots
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))
# Subplot 1: NetworkX Showing treatment on the assumed relationship between the variables
G = nx.DiGraph()
G.add_nodes_from(['Treatment', 'Vulnerability', 'Ransomware'])
pos = {
'Treatment': (0, 1),
'Vulnerability': (1, 0),
'Ransomware': (2, 1),
}
G.add_edge('Vulnerability', 'Ransomware', label='Direct Effect')
G.add_edge('Treatment', 'Vulnerability', label='Treatment')
edge_labels = nx.get_edge_attributes(G, 'label')
nx.draw_networkx_edge_labels(G, pos,
edge_labels = edge_labels,
verticalalignment = 'top',
font_size = 12, ax = ax1)
nx.draw(G, pos, with_labels = True,
node_size=2500, node_color='tab:olive',
arrowstyle="-|>,head_width=0.5,head_length=1", ax=ax1)
ax1.set_title('Our Data Relationship Assumption Model')
# Subplot 2: Bar chart for intervention of Ransomware given do(Treatment)
total_effect = prob_treatment.values[1] - prob_no_treatment.values[1]
values = [prob_no_treatment.values[1],
prob_treatment.values[1],
total_effect]
labels = ['E(Ransomware|No treatment)',
'E(Ransomware|do(Treatment))',
'Total Effect']
ax2.bar(
labels,
values,
color = ['gold', 'lightcoral', 'tomato']
)
for i, value in enumerate(values):
ax2.annotate(f'{value:.2f}', (i, value), ha='center', va='bottom', fontsize=12)
ax2.set_xlabel('Scenario')
ax2.set_ylabel('Expectation of Ransomware')
ax2.set_title('Effect of treatment on ransomware')
plt.savefig('treated_plots.png', dpi=300, bbox_inches='tight')
plt.tight_layout()
plt.show()
我们进行这项反事实分析的主要目标是确定缺乏漏洞是否会减少勒索软件案件。为此,我们测量了处理漏洞的总效果对勒索软件事件的影响。
为了测量这种效果,我们计算了期望值在两种情境下的差异:一种是不应用处理,另一种是积极实施处理。

作者提供的图像
在这个情境下,我们的总效果显示为净负值,如图所示。

作者提供的图像
视觉表示清楚地说明了实施处理对减少勒索软件攻击的数量有明显效果。我们观察到,当对漏洞进行处理时,勒索软件事件的净负面效应。
虽然这个结果不能决定性地证明因果关系,但它强烈暗示漏洞是影响勒索软件攻击的一个因素,并且处理漏洞可能会减少这种攻击。
反事实及其局限性
我们可以再次强调,相关性不等于因果性。反事实可以通过创建有无干预或处理的情境控制比较来建立因果关系,因此它们有助于评估因果关系。
反事实可以被视为评估因果关系泡沫的巧妙方法。在数据科学中,这是一项极其有用的工作。
尽管反事实可能不足以证明因果关系,但因此必须注意几个考虑因素:
-
在谈论因果关系时,反事实分析在很大程度上依赖于数据的质量和建模过程中所做假设的有效性。
-
在现实世界中,进行反事实分析可能在实际操作或伦理上都不切实际。
-
我们还没有谈到可能实际负责的混杂变量。反事实分析不足以处理这些关系,我们必须使用更复杂的工具,如结构性因果模型来处理这些情境。
-
与其他统计分析类似,反事实分析应评估统计显著性,以确保观察到的效果不是偶然的。
让我们谈谈反事实的一些其他应用……
尽管在这篇文章中,我们仅关注了因果推断中的逆事实,但逆事实的主题过于广泛,难以涵盖。它们通常用于模型解释、风险最小化、A/B 测试、模型中的偏差检测等。
总结
虽然这篇文章强调了逆事实在因果推断中的作用,但始终要记住,因果关系和逆事实分析的细微差别往往高度依赖于人为判断以及对领域知识和数据的准确解释。
尽管如此,一般来说,逆事实思维应成为任何假设检验的常规过程。
既然我们已经触及了这一复杂且多方面的主题的表面,你可能迫不及待想要深入探索。虽然我们不能在一篇文章中涵盖所有内容,但我已编制了额外资源,以满足你的好奇心,并深入探索逆事实的世界。
其他关于逆事实的精彩资源…
逆事实提供了一种模型无关的方法,使黑箱机器学习算法变得可解释和…
GitHub - MaheepChaudhary/Causation-inComputerVision: 该代码库包含关于因果关系的论文列表,以及相关技术如何被用来进一步提升深度…
该代码库包含关于因果关系的论文列表,以及相关技术如何被用来进一步提升深度…
介绍
Counterfactuals vs Interventional: The article provides an in-depth comparison
不要忘记阅读我其他一些引人入胜的文章!
选择通往显著结果的路径
非传统可视化:何时以及何时不使用它们的力量
towardsdatascience.com
随时在评论中分享你的想法。
探索大型语言模型中的创造力:从 GPT-2 到 GPT-4
通过创造力测试分析大型语言模型中创意过程的演变
·发表于 Towards Data Science ·阅读时间 21 分钟 ·2023 年 4 月 11 日
--

Midjourney 设想中的创造力。由作者生成。
本博客最初发布于 www.artfish.ai/p/exploring-creativity-in-large-language
最近几周,人们使用大型语言模型(LLMs)生成了各种创意内容,如 书籍、闪小说、说唱对决 和 音乐和弦。但是否有可能更广泛地衡量这些模型中的创意过程水平?
人类创造力长期以来一直吸引着心理学家和研究人员。自 1950 年代起,研究人员创建了一系列测试,以比较个人的创造力表现和潜力。虽然没有单一测试能完全捕捉创造力,但这些测试尝试测量和量化其不同方面。
在这篇文章中,我分析了 2019 年到 2023 年间 GPT 模型在测量两种创造力的测试中的表现:收敛性(存在单一正确解决方案)和发散性(开放式;可能存在多种解决方案)[1]。这些测试包括:
通过这些分析,我专注于从早期的 GPT-2 模型到最新的 GPT-4 模型的性能变化。我使用了主观和客观的视角来理解这些模型的行为。通常,较新的模型在所有创造性任务上表现更好,尽管我发现了一些例外情况,其中在测试有“正确”答案时,较简单的 GPT-2 模型超越了更新的 GPT-4 模型。在可能的情况下,我将 GPT 模型与人类在相同创造性测试上的表现进行比较。
注意:这些实验的目的是测量模型生成中的创造性过程水平,而不是评估模型的创造力。我并不声称 GPT 模型或一般的 LLM 具有与人类相同的创造性思维。我旨在展示模型如何响应特定的创造性过程测量。
GPT 模型的发展历程
生成预训练变换器(GPT)模型是训练在大量互联网数据上的语言模型,以根据前面的词预测下一个词。虽然最近的 ChatGPT 在过去几个月引起了公众的关注,但 OpenAI 的 GPT 模型(以及这些模型背后的技术)已经存在几年了。
-
GPT-1(2018 年,117M 参数)。原始 GPT 模型。我使用了HuggingFace 上的开源版本。然而,输出结果非常无意义,所以我没有将它们包含在本文中。
-
GPT-2(2019 年,1.5B 参数)。我使用了HuggingFace 上的开源版本。对于一些实验,GPT-2 的输出是无意义的,不利于理解其行为,所以我没有将这些结果包含在相应的分析中。论文:语言模型是无监督多任务学习者。
-
GPT-3(2020 年,175B 参数)。比 GPT-2 大超过 100 倍。我使用了来自OpenAI APIs的text-davinci-002。
论文:语言模型是少样本学习者
-
GPT-3.5(2022)。虽然 GPT-3.5 有多个版本,我使用的是与原版 ChatGPT 相同的模型,该模型通过带有人工反馈的强化学习进行训练。我使用了来自 OpenAI APIs 的 gpt-3.5-turbo。
-
GPT-4(2023,约 100T 参数)。我使用了来自 OpenAI APIs 的 gpt-4。
论文: GPT-4 技术报告。
在本文的所有实验中,除非另有说明,我对每个独特的提示调用了每个 GPT 模型五次,以考虑其生成的随机性。
1. 远程联想测试:连接无关词
远程联想测试(RAT)是由玛莎·梅德尼克(Martha Mednick)于 1962 年开发的一种创造力测试 [2]。该测试衡量的是聚合思维,这意味着每个问题都有一个“正确”的答案。
给定三个看似无关的词,测试者必须找到一个第四个词来连接这三个词。例如,“dream”、“break”和“light”可以通过“day”(daydream, daybreak, daylight)来联系。词汇“sense”、“courtesy”和“place”可以通过“common”来联系。
我从 远程联想测试网站 收集了词汇三元组。总共有 312 道问题。每个问题的难度按从非常容易到非常困难的等级进行评分。

从 测试网站 获取的随机联想测试题目分解。由作者创建。
对于每个模型和每组三个词,我使用了以下提示模板。我调用了每个模型 5 次,以考虑随机性。
给定三个词。想出一个与前三个词都有某种关联的第四个词。首先给出答案(“Answer”),然后解释你的推理(“Reason”)。
You are given three words. Think of a fourth word that is somehow related to each of the first three words. First give the answer ("Answer") then explain your reasoning ("Reason").Words: {Word1}, {Word2}, {Word3}
较新的 GPT 模型在远程联想上得分更高
我计算了每个 GPT 模型生成正确答案的总体百分比。可能并不奇怪,模型越新,生成的正确答案越多。

每个 GPT 模型在远程联想测试中正确回答问题的百分比。由作者创建。
按照问题难度级别分解模型表现显示了类似的趋势:较新的 GPT 模型在每个难度级别上的表现均优于旧模型。虽然较新的 GPT 模型在难题上的表现优于旧模型,但这种改进比在简单问题上的改进要少。这意味着即使是 GPT-4 在困难问题上也感到吃力。

每个 GPT 模型在远程联想测试中按问题难度级别的正确回答百分比。由作者创建。
根据远程联想测试网站,“典型的人可以解决大多数标记为简单的问题,大约一半的中等难度问题,以及很少的难问题。” 网站上没有分享官方统计数据,但从快速浏览来看,GPT-4 确实解决了大约一半的中等难度问题和一些难问题,使其成为唯一一个大致可与人类水平相媲美的模型。其他 GPT 模型的表现则逊色于人类水平。
一些例外:当 GPT-2 回答正确而 GPT-4 回答错误
尽管不常发生,但确实有一些问题,2019 年的更基础的 GPT-2 模型能够生成正确答案,而 2023 年的 GPT-4 模型却不能。实际上,有 8 个这样的例子。(完整列表见文章末尾的附录)。
一个简单的问题
-
单词:小屋 / 砖块 / 蛋糕
-
GPT-2:“奶酪” [正确答案]
-
GPT-4:“层”
一个中等难度的问题
-
单词:猫 / 数字 / 电话
-
GPT-2:“呼叫”
-
GPT-4:“拨号”
一个难的问题
-
单词:通过 / 磁盘 / 路径
-
GPT-2:“驾驶” [正确]
-
GPT-4:“圆圈”
一个非常难的问题
-
单词:土地 / 手 / 房子
-
GPT-2:“农场” [正确]
-
GPT-4:“财产”
请注意,对于所有这些问题,GPT-4 的回答是 不 不合理的。实际上,如果事先不知道正确答案,你可能会认为 GPT-4 的答案同样可以将这三个词联系起来。如果你这样想,你并没有错。我将在下一部分讨论这个问题。
GPT-4 是错误的,还是只是思维方式不同?
GPT-4 错误地回答了超过一半的问题,但可能有充分的理由。我让 GPT-4 解释了每一个答案。然后,我查看了它对几个“非常难”问题的解释。
单词:艺术家 / 孵化 / 路线
-
正确答案:“逃脱”
-
GPT-4 答案:“地图”
“艺术家”这个词可能会让人想到创造某些东西,甚至是某种视觉表现形式。地图是一种常见的视觉表现形式,需要一定的艺术技能来制作。
“孵化”这个词可以指打开某物或使某物存在的行为。地图是一种蓝图或计划,通常用于导航新领土或规划行动路线。
“路线”这个词与方向和运动相关,这些都可以在地图上绘制和追踪。地图可以帮助你确定旅行时最佳的路线,并且可以提供关于其他人过去所走过路线的信息。
总的来说,“地图”似乎是一个与最初的三个词相符的词,因为它包含了艺术创作、规划和方向,以及运动和导航的概念。
GPT-4 对于为什么“地图”是一个与“艺术家 / 孵化 / 路线”相关的好词的解释很有说服力。
这是另一个例子:
词汇:fork / dark / man
-
正确答案:“pitch”
-
GPT-4 的答案:“shadow”
“fork”这个词可以与分裂或分支相关联,比如道路上的分叉。阴影也可以被视为一种分裂,因为它们代表了光的缺失。
“dark”这个词指的是光的缺失。当光被物体阻挡时,阴影会形成,且往往会比周围区域显得更暗。
“man”这个词可以与身份或个体性概念相关联。阴影也可以用来表示一个人的身份,因为它们可以提供有关他们形状和大小的线索。
因此,“阴影”似乎是一个很好地适应最初三个词的词汇,因为它包含了分裂、缺失和身份的概念。
这些例子展示了远程联想测试的陷阱,以及一般收敛创造力测试的陷阱——即“正确”答案可能并不是唯一的“对”的答案。在这些情况下,GPT-4 生成了一个“错误”的答案,但仅仅是因为它没有符合传统的“正确”答案。
实际上,如果我让 GPT-4 为一组三个词中的第四个词进行头脑风暴,它能够给出多个选项,并提供令人信服的答案。

GPT-4 生成多个“正确”答案以回应一组三个词。截图由作者拍摄。
GPT-4 答案任意生成的三个词组合
GPT 模型可能已经在其训练数据中看到了一些上述远程联想测试问题。GPT-4 会如何回应任意生成的三个词组合?在这一部分,我构建了我自己的“远程联想”测试,并主观地评估了输出结果。
我从10,000 个最常见的英语单词 [3]中随机抽取了三个单词,并让 GPT-4 提供一个将这些单词联系在一起的第四个单词。这些三词组合中的一些相当奇怪且不相关,但 GPT-4 总是能找到将它们联系在一起的方式。我鼓励读者在阅读 GPT-4 的生成结果之前先尝试这些!
在查看接下来的几个例子后,你可以做出决定——GPT-4 是在跳出框框思考吗?还是只是将答案勉强套入框框?(文章末尾的附录中展示了更长的例子列表)。
⛑️ 健康
这个回应显得相当有说服力。我发现“健康监测”的定义几乎是从CDC 公共健康监测介绍网站上逐字复制的,这一点很有趣。
公共健康监测是“对与健康相关的数据进行持续、系统的收集、分析和解释,这些数据对于规划、实施和评估公共健康实践至关重要。” — 领域流行病学
(来自 CDC 网站)
Words: dietary / surveillance / bornAnswer: HealthReason:
1\. Dietary - A person's diet is directly related to their health, as it can affect their overall well-being, energy levels, and risk for certain diseases.
2\. Surveillance - Health surveillance is the ongoing, systematic collection, analysis, and interpretation of data essential to the planning, implementation, and evaluation of health practice.
3\. Born - When a person is born, their health is of utmost importance, as it can have long-lasting effects on their development, growth, and overall quality of life.
🍀 三叶草
我最初对 GPT-4 的回答感到信服——事实上,这个回答感觉相当聪明。直到我做了一个快速的 Google 搜索,发现Trifolium duodecim 并不存在。这是一个幻觉的例子,幻觉是指大型语言模型编造出不准确的答案的常见现象。
Words: lucky / field / twelveAnswer: CloverReason:
1\. Lucky - A four-leaf clover is considered a symbol of good luck.
2\. Field - Clovers are plants that can be commonly found growing in fields.
3\. Twelve - There is a variety of clover known as "Trifolium duodecim" (Twelve-leaf clover), which has 12 leaflets.
☕️ 星巴克
这个例子是我最喜欢的,虽然它跳出了常规思维,但感觉有点牵强。而且这有点搞笑(尤其是关于 Patrick 的部分)。
Words: grande / cognitive / patrickAnswer: StarbucksReason:
1\. Grande: Starbucks uses the term "grande" to refer to one of its medium-sized drinks.
2\. Cognitive: Starbucks is a popular place for people to go to relax, study, and work, which are all activities that require cognitive function.
3\. Patrick: Patrick Dempsey, the actor who played Dr. Derek Shepherd on the television show Grey's Anatomy, is known for his love of coffee and was nicknamed "McDreamy." Starbucks is a well-known coffee shop, so it's possible that his character would have frequented Starbucks.
这些例子和分析说明了具有单一正确答案的创造力测试可能存在局限性。这一局限性在本文下一部分探讨的其他测试中得到了补充。
2. 替代用途任务:在普通物品中寻找意义
另一种常见的测量创造力的测试是替代用途任务(AUT),由 J. P. Guilford 于 1967 年创建。与远程联想测试不同,这个测试测量发散性思维,这意味着没有正确答案。
测试者被要求为一个日常物品提供尽可能多的原创用途。例如,一个“笔记本电脑”可以用作“盘子”、“护体盾牌”或“墙面艺术”[4]。对这个测试的回答通过原创性(用途的不同程度)和灵活性(创意的多样性)等指标进行评估。此测试的目的是查看一个人能为普通物品想出多少种非常规用途。
然而,由于资源有限,我没有能力在没有人工标注者的情况下生成这些测量值。因此,我使用了创造力的代理,这也是阿姆斯特丹大学的研究人员在《对 GPT-3 创造力的(替代用途)测试》中使用的方法。
在这篇论文中,研究人员要求人类和 GPT-3(davinci-002)为以下常见物品提出替代用途:叉子、铁罐和书籍。作为衡量创造力的代理,他们测量了语义距离——即生成的用途与原始物品之间的差异有多大?
我从论文项目网站获取了人工回答,以便与人类表现进行比较。我使用了以下提示供 GPT 模型使用,基于论文中使用的类似提示。
为以下物品想出 10 种可能让人觉得聪明、不同寻常、有趣、不常见、幽默、创新或与众不同的创造性用途:
嵌入向量之间的余弦距离作为创造力的代理
在自然语言处理领域,语义距离(或语义相似性)是基于词汇在相似上下文中一起使用的频率来衡量两个词汇相关程度的代理。例如,“计算机”和“屏幕”这两个词可能有较短的语义距离,因为它们常常在相同的上下文中使用,而“计算机”和“扁豆”这两个词的语义距离可能较大,因为它们很少在相同的上下文中出现。
为了计算两个词的语义距离,我使用了与研究人员在上述论文中使用的相同方法。我使用了Spacy (en_core_web_lg),一个自然语言处理库,来计算物体(例如“笔记本电脑”)和使用案例词汇(例如“盘子”)的词嵌入。然后,我使用这两个嵌入之间的余弦距离来计算语义距离。
对于替代用途任务,较大的语义距离与“更具创造性”的物体使用案例相关。结果显示,对于三个物体(叉子、锡罐和书),人类的表现通常导致更高的余弦距离(因此更高的语义差异)相比于 GPT-3。然而,也有一些情况下,GPT-3.5 和/或 GPT-4 达到了或超越了人类创造力的语义距离得分。

原始词汇与替代使用案例的语义距离(使用余弦距离测量)。由作者创建。
虽然人类倾向于提出比 GPT-3 更具创造性的使用案例(使用余弦距离作为创造力的代理),GPT-4 的表现更接近于人类水平。也就是说,GPT-4 的回应生成了与原始项目相比具有较大语义距离的项目,其水平与人类表现相似。然而,由于没有实际的人工标注和决定更主观的类别如原创性、灵活性和实用性,这种使用余弦距离的量化方法是有限的,因为它只测量了 GPT 模型创造性输出的一小部分。
你能分辨出人工生成和 AI 生成的使用案例吗?
列表之一展示了人工生成的锡罐替代使用案例,另一个则展示了由 GPT-4 生成的替代使用案例。你能分辨出区别吗?答案在最后。
列表 1:锡罐的用途
-
作为镜子
-
用来制作玩具
-
作为艺术品
-
作为对安迪·沃荷的提醒
-
作为带有罐子夹的钥匙圈
-
作为珠宝
列表 2:锡罐的用途
-
微型鼓套件
-
古怪的花盆
-
临时饼干切割器
-
自制营炉
-
异想天开的风铃
-
微型草药花园
下一部分的测试也测量了发散性思维。
3. 发散联想任务:十个词,没有连接
第三个创造力测试是发散性联想任务(DAT),由一组研究人员在 2021 年最近创建[5]。此测试要求命名十个尽可能不同的名词。显而易见,这项测试是发散性的,而且没有正确答案。
有一个在线测试版本。完成测试只需几分钟,因为它涉及列出十个尽可能不相关的名词。以下是我从网站上获得的结果:

从发散性联想任务中截取的屏幕截图。由作者创建。
DAT 分数
如何计算分数?作者在这里描述了他们的过程。DAT 分数是通过嵌入和余弦距离计算的。我对他们的计算分数过程进行了调整[6]:
-
计算每个单词的GloVe嵌入
-
计算每对单词之间的余弦距离
-
取该值的平均值
这个过程类似于在替代用途任务中计算创造力代理分数的方法,唯一的区别是使用了不同的单词嵌入模型。
比较不同模型之间的差异
我用来提示 GPT 模型的以下提示语取自DAT 网站。
Please enter 10 words that are as different from each other as possible, in all meanings and uses of the words.Rules
- Only single words in English.
- Only nouns (e.g., things, objects, concepts).
- No proper nouns (e.g., no specific people or places).
- No specialized vocabulary (e.g., no technical terms).
- Think of the words on your own (e.g., do not just look at objects in your surroundings).
我没有在下面的结果中包括 GPT-2,因为 GPT-2 生成的输出是乱码。(你可以在文章末尾的附录中看到一个示例)。
结果显示,每个更新的模型 DAT 分数都提高了。这意味着后来的模型在生成尽可能语义不同的 10 个单词方面表现更好。

按照余弦距离测量的不同模型的发散性联想测试分数。由作者生成。
例如,以下是 GPT-3 生成的列表,其分数为 0.36:
dusk, night, day, dawn, morning, afternoon, evening, midnight, twilight
另一方面,以下是 GPT-4 生成的列表,其分数为 0.95:
elephant, microscope, sandcastle, galaxy, harmony, spatula, glacier, philosophy, tambourine, drought
根据介绍 DAT 测试的论文,分数通常在 65 到 90 之间,平均在 75 到 80 之间。95 的分数被认为非常高。
上图中需要注意的是,GPT-4 的误差条远小于 GPT-3 的。这一点在下一个实验中会更加清晰,我在实验中通过调整“温度”参数来改变每个 GPT 模型生成内容的随机性。
温度是一个介于 0 和 1 之间的值,接近 0 的值最具可预测性,接近 1 的值则最为随机。我尝试了介于 0.5 和 0.9 之间的温度值,每次递增 0.1(在之前的实验中,我将温度保持在 0.7,这是默认值)。我期望更高温度值会带来更多的变化,而较低的温度则会更一致。
然而,结果显示温度对语义距离变异性的影响不如模型本身重要。对于 GPT-3 而言,尽管有温度的影响,每组 10 个词的 DAT 得分变化很大。而对于 GPT-4 而言,尽管有温度的影响,DAT 得分在一个小范围内保持了高度一致,变化非常小。

按模型类型和温度分数分开的词嵌入的平均余弦距离。作者创建。
一旦我查看了三种 GPT 模型生成的实际文本,这一点变得更加清晰。GPT-4 在生成词汇时的变异性远低于 GPT-3。 从某种程度上说,感觉 GPT-4 似乎已经知道了一个应该没有正确答案的任务的“正确答案”。
汇总所有温度分数的响应,我让每个模型生成 10 个词汇的列表 50 次。然后,我查看了不同词汇在所有 50 次运行中的出现频率。值得注意的是,GPT-3 最常生成的词是 book,出现了 50 次中的 19 次 = 38%。另一方面,GPT-4 最常生成的词是显微镜,出现了48/50=96%。

每个模型为 DAT 测试生成的顶级词汇(根据温度分数汇总)。由词汇创建。
这验证了 GPT-3 在生成词汇时确实比 GPT-4 更具变异性,而 GPT-4 则采取了更为保守的策略,预测了许多它知道会在“多样化词汇”任务中得高分的词汇——如显微镜、火山、大象和银河。
这一发现展示了 GPT-4 似乎已经“破解了生成多样化词汇的密码”。如果是这样,我不确定是否可以自信地说 GPT-4 的创造力更高,还是它已经找到了一个“如何最好地玩这个特定游戏”的局部最优解。
结论
“我们可以设想一个未来,在这个未来中,GPT-3 和其他生成型 LLMs 的响应无法与人类区分,尽管创造力过程将会不同。”
— 来自* GPT-3 创造力的另类用途测试
使用有限的文本基础测试准确衡量如此复杂多面的创造力是具有挑战性的。虽然像视觉远程联想测试这样的测试通过加入视觉元素来尝试解决这个问题,但这些评估仍然存在固有的局限性。具有“正确答案”的聚合测试(例如远程联想测试)可能不会总是奖励最具创造性或正确的解决方案。另一方面,没有特定正确答案的发散测试(例如发散联想测试)依赖于不完美的词嵌入,并且可能包含偏差。
此外,GPT 模型在这些测试中的表现可能会受到训练过程中对任务的暴露影响,可能导致数据重叠。例如,GPT-4 在发散联想测试中的表现展示了单词的特殊重复,表明它可能提前了解了“正确”的答案。开发替代测试方法至关重要,因为有关于 GPT-4 接受测试数据训练的传闻,特别是在编码任务中。
尽管有这些局限性,创造力测试仍然可以作为比较和跟踪大型语言模型表现的有价值基准。未来的研究实验可以在其他语言中重复本文中的实验,以评估 GPT 模型在语言领域的创造力。这样,我们可以获得对 AI 生成内容的更全面理解,并进一步探索这些先进语言模型的能力和局限性。
感谢阅读这篇文章!
附录
远程联想测试:GPT 模型正确回答的问题比较
我们可以查看模型正确回答的每个问题的难度级别分解。我们可以看到 GPT3、GPT3.5 和 GPT4 在回答难题的能力上相似。

GPT-2 答对的远程联想和 GPT-4 答错的部分
Words: cat / number / phone
Level: Medium
Answer: call
GPT2: call
GPT4: dialWords: land / hand / house
Level: Very Hard
Answer: farm
GPT2: farm
GPT4: propertyWords: water / pen / soda
Level: Unknown
Answer: fountain
GPT2: fountain
GPT4: liquidWords: measure / worm / video
Level: Very Easy
Answer: tape
GPT2: tape
GPT4: lengthWords: cottage / brick / cake
Level: Very Easy
Answer: cheese
GPT2: cheese
GPT4: layerWords: through / disk / way
Level: Hard
Answer: drive
GPT2: drive
GPT4: circleWords: skunk / kings / boiled
Level: Very Easy
Answer: cabbage
GPT2: cabbage
GPT4: smellWords: jump / kill / bliss
Level: Hard
Answer: joy
GPT2: joy
GPT4: height
GPT-4 对随机生成的远程联想测试问题的回答
这里是 GPT-4 为几个随机生成的远程联想测试问题生成的输出。
戒指
这个感觉有点勉强。如果你对你的感情生活感到非常厌倦,我可以接受关于戒指是“复杂性象征,尤其是在关系中的”那部分。但最后的部分感觉有点生硬。
Words: complications | saturn | drops
Answer: Rings
Reason:
1\. Complications: Rings can be a symbol of complications, especially in relationships, as they represent commitment and the complexities that come with it.
2\. Saturn: Saturn is famously known for its beautiful and prominent rings, made primarily of ice particles and some rocky debris.
3\. Drops: A ring can be seen as a drop or loop of material, such as metal or plastic, that forms a complete circle.
革命。 蒙特卡西诺确实是一个真实的地方,二战期间确实遭到轰炸。
Words: belfast | movement | monte
Answer: Revolution
Reason:
1\. Belfast is a city in Northern Ireland that has experienced significant political and social changes due to the Irish Republican movement.
2\. Movement is a general term that can describe political and social changes, often caused by revolutions.
3\. Monte is likely referring to Monte Cassino, a historic site in Italy where a significant battle occurred during World War II, which was a time of global revolution and change.
Therefore, the word "revolution" is related to each of the first three words as it represents significant change and upheaval.
帮助。 通过帮助中心将“南安普顿”这座城市与“帮助”连接起来的有趣方式……
Words: ensures | assistant | southampton
Answer: Help
Reason:
1\. Ensures: "Help" is related to "ensures" as it is an action that can ensure the success or completion of a task.
2\. Assistant: "Help" is related to "assistant" because an assistant provides help in various tasks and activities.
3\. Southampton: "Help" is related to "Southampton" as Southampton is a city in England which has various help centers and services, such as the Southampton Voluntary Services, providing assistance to the community.
互联网。 我认为它的论点是相当正确的。
Words: uruguay | accessible | requirement
Answer: Internet
Reason: Uruguay has a government initiative called Plan Ceibal, which aims to make the Internet accessible to all its citizens, fulfilling a requirement for digital inclusion and reducing the digital divide.
语言。 可能有点通用,但我想它还是有效的。
Words: compensation | atlas | prefix
Answer: Language
Reason:
1\. Compensation - In linguistics, compensation is a phenomenon where a sound change in one part of a word is "compensated" by a change in another part of the word to maintain the overall structure of the word.
2\. Atlas - An atlas is a collection of maps, which often include labels and descriptions in a specific language.
3\. Prefix - A prefix is a morpheme or a linguistic element that is added to the beginning of a word to create a new word with a different meaning, which is a concept in language and grammar.
健康。 健康监测部分有点阴暗,但我想它还是有效的。
Words: dietary | surveillance | born
Answer: Health
Reason:
1\. Dietary - A person's diet is directly related to their health, as it can affect their overall well-being, energy levels, and risk for certain diseases.
2\. Surveillance - Health surveillance is the ongoing, systematic collection, analysis, and interpretation of data essential to the planning, implementation, and evaluation of health practice.
3\. Born - When a person is born, their health is of utmost importance, as it can have long-lasting effects on their development, growth, and overall quality of life.
ERIC! 这个其实还挺有趣的。
Words: eric | figures | purchasing
Answer: accountant
Reason: An accountant is a professional who can be named Eric, works with figures (numbers), and is often involved in purchasing decisions for businesses.
GPT-2 在发散联想测试中的输出
我排除了 GPT-2 在发散测试分析中的结果,因为它的输出完全是胡言乱语。即使我尽可能简化提示,它的输出也不太有用。我在这里包含了一个示例,因为我实际上觉得它们非常有趣。
List 10 words that are as different from each other as possible:"Merry Christmas""I'm not a robot, we're the robots""It's a good thing we're all robots""I'm just a robot""I'm just a robot""I'm just a robot""I'm just a robot""I'm not a robot""I'm just a robot""I'm just a robot"
[1] 吉尔福德, J. P. (1967). 人类智慧的本质。
[2] Mednick, S. A., & Mednick, M.T. (1959,1962). 远程联想测试,大学及成人版。
[3] 通过对Google’s Trillion Word Corpus的n-gram 频率分析,确定了最常见的 10,000 个英语单词。我进一步筛选了这个列表,去除了停用词(例如,“the”、“and”和“you”这样的常见词)。更多信息见此仓库。
[4] 我曾在一次约会应用上发布了这个,但很遗憾,没有人觉得好笑。
[5] Olson, Jay A.;Nahas, Johnny;Chmoulevitch, Denis;Cropper, Simon J.;Webb, Margaret E.(2021 年 6 月 22 日)。“命名无关的词预测创造力”。《国家科学院学报》。118(25):e2022340118。 doi:10.1073/pnas.2022340118。 PMC 8237676。 PMID 34140408
[6] 作者从 10 个词中取了 7 个子集来计算分数。我保留了所有 10 个词。此外,作者使用了 840B-GloVe 嵌入,而我使用了较小的 6B-GloVe 嵌入。我将我的 GloVe 嵌入结果与官方 DAT 分数进行了比较,发现尽管存在这些差异,结果是相当的。
探索生成市场营销组合建模中响应曲线的不同方法
比较饱和度函数和部分依赖性用于响应曲线生成
·
关注 发表在 Towards Data Science ·6 分钟阅读·2023 年 6 月 14 日
--
图片由 Alexander Grey 提供,来源于 Unsplash
响应曲线是营销组合建模的重要组成部分,这是一种统计技术,用于分析各种营销策略和战术对销售或其他业务结果的影响。响应曲线表示了一个营销变量(例如广告支出、价格、促销等)与产品或服务所产生的销售或收入之间的关系。
响应曲线的重要性在于它们能够揭示每个营销变量的有效性以及它们如何影响整体响应。通过分析响应曲线,营销人员可以获得宝贵的见解,了解哪些营销策略推动了最多的销售,以及哪些策略没有达到预期效果。
构建响应曲线有不同的方法,在这篇文章中,我探讨了两种主要方法:使用饱和转换的直接方法和基于部分依赖的方法。我使用线性回归和梯度提升两种不同的算法族来评估这些方法。此外,我展示了在使用复杂机器学习算法时,部分依赖方法可以与通过 SHAP 值生成的响应曲线结合使用。
饱和函数 / 转换
构建响应曲线的直接方法是使用饱和函数(转换),例如 Logistic、负指数 或 Hill。饱和函数是一种数学函数,它捕捉了递减回报效应,即随着营销变量值的增加,其影响逐渐饱和。通过使用饱和函数,可以将营销变量与响应变量之间的关系转化为非线性形式。这使得模型能够捕捉饱和效应,并更准确地表示营销活动与响应(销售或收入)之间的真实关系。
对先验、数据标准化进行实验,并将贝叶斯建模与 Robyn(Facebook 的开源 MMM)进行比较…
使用饱和变换的一个优势是其简单性和可解释性。响应曲线由具有固定参数的数学函数定义,形成一个平滑的曲线,易于可视化。然而,饱和函数的选择是建模之前的重要考虑因素。不同的函数可能会产生不同的结果,选择应基于数据的特征和模型的基本假设。
线性回归与非线性需求
在营销组合建模中,线性回归是一种常用的技术来分析营销变量与响应变量之间的关系。然而,线性回归假设预测变量与响应变量之间存在线性关系。这在试图捕捉营销数据中通常存在的非线性关系时可能会造成限制。
为了克服这一限制并在建模过程中引入非线性,需要对营销变量应用饱和函数或变换。这种变换允许生成一个非线性关系,否则由于线性回归的性质,该关系将是线性的。
捕捉非线性广告饱和度和递减回报,而不显式转换媒体变量
部分依赖方法
部分依赖方法是一种更通用的方法,可以用来建模任何营销变量与响应之间的关系。这种方法包括隔离一个变量的效果,同时保持所有其他变量不变。通过改变感兴趣的营销变量的值并观察相应的响应,可以创建部分依赖图。
与饱和变换生成的平滑响应曲线不同,部分依赖方法得到的图形可能不一定平滑。其形状取决于底层建模算法和媒体变量与响应之间的关系。部分依赖方法在关系复杂且非线性时非常有用,可以在显式使用饱和变换的情况下或当算法自然处理非线性而无需额外饱和变换时应用。
使用基于树的集成方法构建 MMM 模型,并使用 SHAP 解释媒体渠道的表现 (Shapley Additive…)
towardsdatascience.com](/improving-marketing-mix-modeling-using-machine-learning-approaches-25ea4cd6994b?source=post_page-----ff6dcc7927f7--------------------------------)
数据
我继续使用 Robyn 提供的 MIT 许可证数据集,如我之前的文章中所示进行实际例子,并按照相同的数据准备步骤,使用 Prophet 来分解趋势、季节性和假期。
数据集包括 208 周的收入数据(从 2015–11–23 到 2019–11–11),包含:
-
5 个媒体开支渠道:tv_S, ooh_S, print_S, facebook_S, search_S
-
2 个媒体渠道也包含曝光信息 (Impression, Clicks):facebook_I, search_clicks_P(本文未使用)
-
无开支的有机媒体:newsletter
-
控制变量:events, holidays,竞争对手销售 (competitor_sales_B)
建模
我建立了一个完整的工作 MMM 流水线,可以在现实场景中应用于分析媒体开支对响应变量的影响,包含以下组件:
-
Adstock 变换 具有无限衰减率 (0 < α < 1)
-
饱和度山变换 具有两个参数:斜率 / 形状参数,控制曲线的陡度 (s > 0) 和半饱和点 (0 < k ≤ 1)
-
来自 scikit-learn 的岭回归
关于系数的说明
在 scikit-learn 中,岭回归没有提供内置选项来强制对变量子集施加正系数。然而,一个潜在的解决方法是,如果发现任何媒体系数为负值,则拒绝 optuna 解决方案。这可以通过返回一个异常大的值来实现,表示负系数不可接受,应从模型中排除。另一种方法是参考 我关于如何在 Python 中封装 R glmnet 的文章,它允许对变量子集施加系数约束。
对于岭回归,我应用了饱和度变换,并使用饱和度函数和部分依赖方法生成响应曲线。使用 LightGBM 时,我允许模型自然捕捉非线性,并使用部分依赖方法生成响应曲线。此外,我在响应曲线上叠加了 SHAP 值,以提供进一步的见解。
结果
具有饱和度变换的岭回归
如可以观察到的,使用饱和度函数和部分依赖生成的响应曲线都显示出重叠的模式,表明这两种方法捕捉了营销变量与响应之间的相似关系。

LightGBM
如前所述,由部分依赖生成的响应曲线可能不一定光滑。这可能是由于梯度提升算法的性质,该算法涉及将特征空间划分为区域,并结合多个决策树之间的交互作用。

图片来源于作者
下面的图表展示了岭回归和 LightGBM 的响应曲线,突出了这两种算法在捕捉递减收益方面的差异。此外,我们观察到 SHAP 值提供了对部分依赖方法生成的响应曲线的可靠近似。

图片来源于作者

图片来源于作者

图片来源于作者

图片来源于作者

图片来源于作者
结论
响应曲线在营销组合建模中起着关键作用,通过提供不同营销变量的有效性和它们对整体响应的贡献的见解。在本文中,我探讨了生成响应曲线的两种主要方法:使用饱和度变换的直接方法和部分依赖方法。我使用线性回归和梯度提升两种算法家族来评估这些方法,并展示了不同算法捕捉非线性响应的对比方式。此外,我将使用 SHAP 值生成的响应与通过部分依赖方法获得的结果进行了比较。
完整代码可以从我的Github repo下载
感谢阅读!
探索情感与人工智能、OpenAI 和探索性数据分析的关系
以下是如何使用 Python 通过 OpenAI 和探索性数据分析在文本中可视化情感
·发表于 Towards Data Science ·阅读时间 9 分钟·2023 年 12 月 12 日
--

作者使用 Midjourney 制作的图片
我想首先说的是,我更喜欢旧版迪士尼电影,而不是新版的。
我觉得这和我小时候看旧版迪士尼电影有关,我对那一刻有一种怀旧的感觉。即使我并不是电影专家,我也觉得旧版迪士尼电影的情节最好。
不过,有一个显著的例外,那就是 脑内飞。我在影院看了这部电影,并且非常喜欢。我不想剧透,只想说这部电影探讨了我们内心中存在的一系列情感:
-
愤怒
-
厌恶
-
快乐
-
恐惧
-
悲伤
这些情感有时像真正的人一样在我们内心交谈。这是一部令人非常温馨的电影,我认为情节非常出色。当我听说新电影《脑内飞》即将上映时,我非常兴奋,正在倒计时。😃
现在,让我猜测一下你的问题:
“这跟人工智能有什么关系呢?”
当我在看预告片的时候,我想到
“我们能生成情感会说什么吗?”
“我们能否使用 ChatGPT 技术来模拟如果情感‘愤怒’是一个人的话,他们会说什么?”
“我们能否将情感作为 N 维空间中的向量来探索?如果可以,我们能否将其绘制出来?”
所以系好安全带,让我们看看能做些什么吧 😏
在我们深入之前, 我想强调这是一个基于电影的“游戏”。我认为一个情感被 AI 取代的世界会非常悲伤,我并不声称这个过程会在第一时间做到这一点。
这个故事分为 4 部分:
0. 代码结构
1. 代码的第一部分,生成情感。高度关注代码
2. 代码的第二部分。这部分的内容少有实际上基于代码。它更多地基于数据集的探索。这是乐趣开始的地方!
3. 总结
4. 结论
这个博客将以代码为基础,旨在对开发者有用。 如果你不是开发者,可以直接跳到第三章,我将在那里描述使用 AI 进行情感检测的结果 😃
0. 结构
这项工作分为两章:
-
“情感”数据的生成:它将是一个包含 main.py 脚本的模块,用于生成句子。这些句子将类似于《头脑特工队》中的电影句子。我们将用 OpenAI 生成每种情感 150 个句子:150x5 = 750 个句子。
-
探索性数据分析: 这将是一个笔记本,用于对 OpenAI 生成的句子进行数据探索。
让我们描述一下我们是如何做到的……
整个代码可以在我的 GitHub 页面上找到 (
github.com/PieroPaialungaAI/Emotion_AI/)
1. 情感 AI
我们称情感数据的生成为EmotionAI。
在 EmotionAI 中,我创建了 4 个 Python 脚本:
-
constants.py
-
main.py
-
util.py
-
emotionai.py
你可以在这个 Github 文件夹中找到。让我们从底部开始,描述 main.py 文件:
1.1 Main.py
main.py文件的构思是创建情感句子。
它使用一个名为 AIEmotionGenerator 的类来完成,AIEmotionGenerator 是emotionai.py代码的一部分。使用 AIEmotionGenerator,你首先创建文件夹,存储句子(每种情感 50x3),然后生成句子并存储它们。很简单,对吧?
肮脏的工作由emotionai.py完成,来看看:
1.2 emotionai.py
emotionai.py构建了一个名为 AIEmotionGenerator 的类。
这个类有两个功能:
-
为你创建文件夹(无聊)🥱
-
生成给定情感的句子(非常有趣)🤩
我通过给 OpenAI 的 GPT 一个提示来生成情感,这些提示是真实的《头脑特工队》中的情感句子。让我们在constants.py中看看它们是什么。
1.3 constants.py
所以,这就是 constants.py。在这个文件中,我们有用于让 GPT 更好地理解任务的句子。你也应该在其中填入你的Open AI API密钥……这对使其工作至关重要*。
*OpenAI API 密钥最终会花费你一些钱,但我为这个项目花了大约$0.11。关注 OpenAI 的使用页面,但不要为此感到压力。这不是一项投资。 😃
这些句子来自电影。注意到 OpenAI 被编程为不提供愤怒或悲伤的回答,这就是为什么我将情感改为“搞笑的悲伤”和“有趣的愤怒”……如果你问我,这些技术假装世界总是充满彩虹和冰淇淋,这有点令人沮丧,但它确实起到了作用。
2. 情感句子的 EDA
好的,一旦你生成了情感,你将会有你的文件夹,它们将会是这样的:

图片来自作者
我们需要提取.txt 文件并进行探索。让我们一步一步来:
2.0 库
我使用了我的老朋友们:👻
注意 EmotionAI 在库中,我们需要使用它(特别是用于常量)
2.1 导入文本
这一部分有点无聊但必要,我们从文件夹中导入文本,并以数据框的形式整理。
不幸的是,OpenAI 的响应可能会有些随机,特别是在格式上。所以请检查.txt 文件,并删除类似“当然!让我给你答案”或“对不起你感到难过,这是答案列表”之类的内容……
一旦你完成了,你的数据集将会如下所示:
2.2 可视化:词云!
词云是一种查看文本中词频的方法。这就是我们将用来逐类(或逐情感)可视化文本的方法:

这非常有趣。厌恶情感中出现频率最高的词是“SMELL”和“SOCK”。快乐情感中有“LAUGHTER”和“DELIGHT”这样的词。悲伤情感中有“STUCK”、“ACCIDENTALLY”、“NEVER”这些词。恐惧情感中有“SCARED”、“MIGHT”、“WORRY”这些词。
我们正在取得好成果。让我们继续前进!
2.3 嵌入
嵌入的思想是将单词转换为数字……或者更好……将句子转换为数字序列。就像这样!

带有糟糕和非常糟糕的句子比较接近,而带有美丽的句子则远离!
你可以用很多方法来进行嵌入步骤,我选择了使用 OpenAI,主要因为它快速而高效,但你可以选择你自己的方法!
这就是你如何在数据集中嵌入每个句子,从而将每个句子表示为一个向量,就像这样:

附注:如果你用西兰花替换我的冰淇淋,我也会生气的……
2.4 降低向量维度……
现在我们想要可视化向量。为了做到这一点,因为我们的脑袋只能处理 2D 或 3D 的事物,让我们使用 PCA 将维度从 1500+(使用 GPT 嵌入的向量维度)降到 2*。
-
预处理输入向量:
-
训练 PCA:
-
应用 PCA到原始向量
*我写了这篇文章来介绍主成分分析,如果你想了解它是如何工作的…… 😃
2.5 看彩虹!
现在我们有了二维数据集,可以对其进行绘制并了解发生了什么。让我们开始吧!
这很不错……你可以区分……
-
悲伤在数据集的下部分,呈长条形。
-
厌恶在数据集的右上部分,呈对角条形。
-
愤怒在数据集的左上部分,以圆形结构呈现。
让我们看看“快乐”怎么样。
恐惧和快乐之间有相当大的分隔,这很好,除了中间的一个区域。我怀疑那是因为“恐惧”在看到类似这样的句子时会变得“快乐”:
“我非常高兴蜘蛛没有接近我。我对蜘蛛感到恐惧。”
这确实像是一个快乐/恐惧的情况 🙃
在这种情况下,我们也对电影中的恐惧句子应用了 PCA,我们看到恐惧在电影中说的句子(是):
“我真高兴你告诉我地震是神话,快乐,否则我现在会很害怕。”
基本上在“恐惧”句子的橙色点之间。
如果我们对反感的句子做同样的处理,我们可以看到反感确实位于绿色点之间,这正是我们想要的效果。
反感的句子是:
“好吧,小心,这里有一种危险的气味,人们。”
即使这个情况也有点像厌恶/恐惧。这是因为情感不是非黑即白的,你可以同时感受到很多情感。如果你曾经恋爱过,你会非常明白这一点 ❤
3. 结论
非常感谢你阅读这个故事。我在做这个过程中非常开心。我喜欢《头脑特工队》,我认为当我们尝试看看计算机是否能生成类似情感的句子时,非常有趣。
在这个故事中我们:
-
确定《头脑特工队》很棒,我们以此作为创建情感“表达”想法的出发点。如果“恐惧”能说话,它会说什么?
-
使用了 OpenAI 的 GPT 技术来生成句子,基于电影中的原始句子。我们让 OpenAI 生成了 150 个类似反感的句子、150 个类似恐惧的句子、150 个类似快乐的句子和 150 个类似悲伤的句子。
-
使用了词云、嵌入和主成分分析(PCA)技术来可视化结果。
GPT 提取出的句子非常有趣。
恐惧说:
“我害怕自己的影子可能会活过来并开始追赶我。”
而悲伤说:
“我曾经因为忘记推而被困在旋转门里。”
词云也很棒,因为它突出了悲伤具有恐惧感,并使用了像“从不”这样的词汇,而反感几乎与“气味”相关,这非常准确。
当我们可视化这些句子时,我们看到句子形成了清晰的簇,并且电影中的原始句子基本上在对应的簇中。这是连贯性和良好提示的标志。
这项工作可能对许多事情有用,比如:
-
使用这项技术向电影导演建议新的句子
-
通过追踪角色的向量来理解角色的一致性
-
创造新的情感和感觉
-
你的想法(评论此帖子)
4. 关于我!
再次感谢你的时间。这对我意义重大❤
我是 Piero Paialunga,我是这里的那个人:

我是辛辛那提大学航空航天工程系的博士生。我在博客文章和 Linkedin 上谈论 AI 和机器学习。如果你喜欢这篇文章,想了解更多关于机器学习的内容并关注我的研究,你可以:
A. 在Linkedin上关注我,我会在这里发布我所有的故事
B. 订阅我的新闻通讯。它会让你了解最新故事,并给你机会给我发消息,以便收到你可能有的所有修正或疑问。
C. 成为一个推荐会员,这样你就不会有“每月最大故事数”的限制,可以阅读我(以及成千上万其他机器学习和数据科学顶级作家)关于最新技术的文章。
如果你想问我问题或开始合作,请在这里留言:
piero.paialunga@hotmail.com
再见❤️
探索 CNNs 中的特征提取
使用卷积神经网络检查特征提取中的专业化
·
关注 发表在 Towards Data Science ·8 min read·2023 年 11 月 25 日
--
(左) 使用 vgg19 CNN 架构对狮子图像进行的特征提取(图像由作者提供)。 (右) 狮子的原始照片(公共领域,来源于 Pexels)。
卷积神经网络(Convolutional Neural Networks,CNN)是当今利用机器学习进行图像分类任务的基石。然而,在分类之前,它们执行的另一个非常有用的任务是从图像中提取相关特征。特征提取是 CNN 识别图像关键模式的方式,以便对其进行分类。本文将展示如何使用 TensorFlow 和 Keras 函数 API 执行特征提取的示例。但首先,为了正式化这些 CNN 概念,我们需要先讨论像素空间。
背景
像素空间
像素空间正是其名字所暗示的:它是将图像转换为数值矩阵的空间,其中每个值对应一个单独的像素。因此,我们看到的原始图像,在输入 CNN 时,被转换为一个数字矩阵。在灰度图像中,这些数字通常范围从 0(黑色)到 255(白色),中间值是灰度。在本文中,所有图像都已经过归一化处理,即每个像素都被 255 除以,使其值位于[0, 1]区间内。
CNN 和像素空间
CNN 在像素表示中对图像所做的事情是应用滤波器并处理它,以提取相关像素,从而做出最终的“决策”,即将该图像归入某个类别。例如,在页面顶部的图像中,该 CNN 特别关注狮子的嘴、舌头、眼睛(以及总体上的强轮廓),随着我们深入神经网络,这些特征进一步被提取。因此,可以说,CNN 在分类方面越专业,它在识别图像关键特征方面就越好。
目标
说到这一点,目标很简单:看看 CNN 在进行特征提取时的专业程度。
方法
为此,我使用了相同架构但训练大小不同的两个 CNN:一个使用了 50K 张图像(这是基准,智能的那一个),另一个使用了 10K 张图像(这是虚拟的)。之后,我切割了 CNN 的层,以检查算法看到的图像及其理解。
数据集
本项目使用的数据集是广泛使用的cifar10图像数据集[1],这是一个公共领域数据集,包含 60K 张图像,分为 10 个类别,其中 10K 张图像用作保留验证集。这些图像大小为 32x32 像素,是 RGB 彩色图像,即有 3 个颜色通道。
为了防止数据泄漏,我保留了一张图像作为特征识别的测试图像,因此该图像未在任何训练中使用。现在,我向大家介绍我们的小白鼠:青蛙。

青蛙。
在 TensorFlow 中的实现
实现代码见下方代码片段。为了正确切片 CNN 的层,必须使用 TensorFlow 中的 Keras 功能 API 而不是顺序 API。它像级联一样工作,下一层调用在上一层之上。
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPool2D, Dense, Dropout, Flatten
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
def get_new_model(input_shape):
'''
This function returns a compiled CNN with specifications given above.
'''
#Defining the architecture of the CNN
input_layer = Input(shape=input_shape, name='input')
h = Conv2D(filters=16, kernel_size=(3,3),
activation='relu', padding='same', name='conv2d_1')(input_layer)
h = Conv2D(filters=16, kernel_size=(3,3),
activation='relu', padding='same', name='conv2d_2')(h)
h = MaxPool2D(pool_size=(2,2), name='pool_1')(h)
h = Conv2D(filters=16, kernel_size=(3,3),
activation='relu', padding='same', name='conv2d_3')(h)
h = Conv2D(filters=16, kernel_size=(3,3),
activation='relu', padding='same', name='conv2d_4')(h)
h = MaxPool2D(pool_size=(2,2), name='pool_2')(h)
h = Conv2D(filters=16, kernel_size=(3,3),
activation='relu', padding='same', name='conv2d_5')(h)
h = Conv2D(filters=16, kernel_size=(3,3),
activation='relu', padding='same', name='conv2d_6')(h)
h = Dense(64, activation='relu', name='dense_1')(h)
h = Dropout(0.5, name='dropout_1')(h)
h = Flatten(name='flatten_1')(h)
output_layer = Dense(10, activation='softmax', name='dense_2')(h)
#To generate the model, we pass the input layer and the output layer
model = Model(inputs=input_layer, outputs=output_layer, name='model_CNN')
#Next we apply the compile method
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
return model
架构规格如下图 1 所示。

图 1:使用的 CNN 概述。图像由作者提供。
使用的优化器是 Adam,损失函数是类别交叉熵,用于评估的指标仅仅是准确度,因为数据集完全平衡。
切片 CNN
现在我们可以切片两个 CNN 的一些关键层,以检查图像的处理水平。代码实现如下:
benchmark_layers = model_benchmark.layers
benchmark_input = model_benchmark.input
layer_outputs_benchmark = [layer.output for layer in benchmark_layers]
features_benchmark = Model(inputs=benchmark_input, outputs=layer_outputs_benchmark)
这里发生的情况如下:第一行访问模型的每一层,第二行返回整个 CNN 的输入层。然后在第三行,我们列出每一层的输出,最后我们创建一个新模型,其输出是这些层的输出。这样我们可以查看层之间发生了什么。
编写了非常相似的代码来访问我们虚拟模型的层,因此在此省略。现在让我们继续查看在不同层处理后的青蛙图像。
第一层卷积层
虚拟
图 2 显示了卷积层(conv2d_1)的 16 个滤波器的图像。我们可以看到这些图像并未被过度处理,而且有很多冗余。有人可能会认为这只是第一层卷积层,这也是处理不那么复杂的原因,这确实是一个合理的观察。为了解决这个问题,我们将查看基准测试的第一层。

图 2:虚拟分类器的第一层卷积层。图像由作者提供。
基准测试
基准测试分类器显示了一个经过处理的图像,处理得更加彻底,以至于这些图像大部分已经无法识别。请记住:这只是第一层卷积层。

图 3:基准测试分类器的第一层卷积层。图像由作者提供。
最后一层卷积层
虚拟
正如预期的那样,图像现在已经无法识别,因为我们已经经过了 6 层卷积层和 2 层池化层,这解释了图像尺寸的减小。让我们看看基准测试的最后一层是什么样的。

图 4:虚拟分类器的最后一层卷积层。图像由作者提供。
基准测试
这已经处理得更为彻底,大部分像素变成了黑色,这表明重要的特征已被选出,其余的图像基本上被丢弃了。

图 5:基准测试分类器的最后一层卷积层。图像由作者提供。
这与信息有何关联?
我们可以看到,对于网络的相同切片,处理程度差异很大。定性分析表明,基准模型在从输入中提取有用信息时更加激进。这在第一次卷积层比较中尤为明显:青蛙图像的输出在虚拟模型上扭曲较少,更加可识别,而在基准模型上则较为模糊。
这表明基准在预测类别时,能够更有效地丢弃图像中无用的元素,而虚拟分类器则对如何进行处理不确定,因此考虑了更多的特征。从图 6 中我们可以看到,基准(蓝色)丢弃的颜色像素比虚拟模型(红色)更多,后者在其颜色分布中显示出更长的尾部。

图 6:最后一层的像素概率分布。我们可以看到基准的像素(蓝色)更集中在零附近,而虚拟模型的像素(红色)显示出更长的尾部。
如果我们查看原始青蛙图像的像素分布,图 7 显示了一个更对称的分布,大致集中在 0.4 左右。

图 7:我们原始青蛙图像的颜色分布。图像由作者提供。
从信息理论的角度来看,原始图像和卷积层后得到的图像之间概率分布的差异代表了巨大的信息增益。
从图 6 和图 7 比较来看,我们对前者中将要找到哪些像素比对后者中的像素更为确定。因此,信息获得有所增加。这是对信息理论的非常简要和定性探讨,并为广阔的领域开启了一扇门。有关信息(带有双关语)的更多信息,请参见这篇文章。
最后,查看分类器答案的不确定性的一种方法是观察类别的概率分布。这是我们 CNN 末尾的 softmax 函数的输出。图 8(左)显示基准对类别的确定性更高,分布在青蛙类别上峰值;而图 8(右)显示了一个困惑的虚拟分类器,最高概率出现在错误的类别上。


图 8:每个分类器分配的概率分布。图像由作者提供。
这一分析展示了预期的结果:更大的训练集导致更好的特征捕捉算法。
本项目的笔记本可在这里找到。
参考文献
[1] Krizhevsky, Alex 和 Geoffrey Hinton。从小图像中学习多个特征层。(2009 年):7。
[2] M. Jogin 等人,使用卷积神经网络(CNN)和深度学习的特征提取(2018 年),IEEE 国际电子信息与通信技术近期趋势会议(RTEICT)
[3] K. Simonyan 和 A. Zisserman,用于大规模图像识别的深度卷积网络(2015 年),作为会议论文发表在 ICLR 2015
[4] Z. Matthew 和 F. Rob。可视化与理解卷积网络(2013 年),欧洲计算机视觉大会
[5] B. Jyostna 和 N. Veeranjaneyulu,使用深度卷积神经网络的特征提取与分类(2018 年),《网络安全与移动性期刊》
探索 GEMBA:一种基于 LLM 的翻译质量评估新指标
#GEN-AI 研究论文
使用 LLM 评估翻译质量
·
关注 发表在 Towards Data Science · 9 分钟阅读 · 2023 年 9 月 29 日
--
图像由作者使用 DALL.E 2 生成
介绍
我最近读到了一篇有趣的微软论文¹(发表于 2023 年 5 月),引起了我的注意。论文深入探讨了翻译评估领域,揭示了一种名为 GEMBA(GPT Estimation Metric Based Assessment)的创新度量标准。在这篇博客文章中,我们将剖析论文并提供对这一令人兴奋的发展的一些见解。
-
前提:探索论文背后的动机。
-
研究问题和假设:探讨论文的主要研究问题和假设。
-
翻译质量评估度量标准:浅入现有度量标准,包括 BLEU、COMET 和 METEOR。
-
介绍 GEMBA:深入了解新颖的 GEMBA 度量标准。
-
实验细节:对所进行实验的洞见。
-
主要发现:突出论文的主要结果。
-
限制:讨论在生产环境中实施 GEMBA 前需要注意的事项。
1. 前提
尽管 LLMs 最初并未设计用于翻译任务,但它们在这一领域展现了令人印象深刻的精准度。这一认识促使作者探索使用 LLMs 作为翻译评估工具。论文的核心思想相当简单——将 LLMs(大型语言模型)定位为评估翻译的工具,而不仅仅是执行翻译。作者提出了一种新的度量标准叫做 GEMBA,它在翻译质量评估方面超越了现有的最先进度量标准。
2. 研究问题
LLMs 能否用于翻译质量评估?
3. 翻译质量评估度量标准
在深入了解 GEMBA 之前,让我们快速回顾一下用于评估机器生成翻译质量的现有度量标准,如 BLEU、COMET、METEOR 等。这些度量标准各有其优点,适用于不同的使用案例,取决于翻译质量中最重要的特定方面。
例如,BLEU(双语评估替代指标)主要关注 n-gram 精确度,这意味着它测量机器翻译中的 n-词序列与一个或多个参考翻译中的 n-词序列的重叠情况。它奖励特定词序列的存在,而没有明确考虑词序、词干提取或同义词关系。另一方面,METEOR(显式排序翻译评估指标)超越了基本的 n-gram 匹配,采取了更全面的翻译评估方法。它考虑了翻译质量的多个方面,包括词干提取、同义词关系、词序、精确的词匹配,甚至对未翻译词的惩罚。同样,COMET(基于内容的机器翻译评估)使用了一种略微不同的方法,专注于基于内容的评估,并通过嵌入计算机器翻译输出与参考翻译之间的语义相似度。简而言之,它评估机器翻译的内容和意义与参考翻译的匹配程度,而不考虑具体的语言变异或词汇选择。
你可以在这里了解其他评估指标,如 YiSi、chrF、BERTScore 等。
鉴于我们刚才讨论的众多指标,你可能会问——为什么我们需要像 GEMBA 这样的新指标? 答案在于其独特的方法——促使大型语言模型根据自身判断评估翻译。与传统指标不同,GEMBA 旨在通过对翻译进行评分(例如 0 到 100),关注意义和语法,与人类评估翻译对齐。
4. 介绍 GEMBA
如前所述,GEMBA 是一种基于 GPT 的翻译质量评估指标,它可以在有参考翻译和没有参考翻译的情况下工作。
本质上,GEMBA 是一个经过精心设计的评估任务 prompt,包含:
-
prompt 变体(来自预定义的四种变体集合)
-
源语言名称,例如“中文”
-
目标语言名称,例如“英语”
-
源片段,即需要翻译的句子
-
候选翻译,即翻译的句子
-
[可选] 参考翻译,即可以作为基准翻译的翻译
这是其中一种 prompt 变体的示例:GEMBA-DA(直接评估)

GEMBA-DA prompt。图像来源于原始论文
附注:如果你对其他三种变体感兴趣,以下是论文中介绍的所有 prompt 变体的详细信息:

GEMBA prompt 变体。图像来源于原始论文
5. 实验和评估
作者使用广泛流行的 MQM 2022 数据集(多维质量指标)测试了 GEMBA 指标的效率。该数据集包括来自新闻、社交、电商等各种领域的多样化句子(100K+),涵盖三个翻译方向:英语到俄语、英语到德语和中文到英语。
此外,为了找到实施 GEMBA 的最佳 GPT 模型,作者测试了 GPT 系列中从 GPT 2 到最新 GPT-4 模型的七个模型的每个提示变体。

用于评估 GEMBA 的 7 个 GPT 模型。图像来自原始论文
在为 GEMBA 指标实验设定了舞台后,接下来的明显问题是 —
问:我们如何判断 GEMBA 是否比传统指标如 BLEU 和 COMET 表现更好?
A: 如果 GEMBA 得分与人类对翻译的看法紧密相关,那么我们就找到了赢家!
为了实现这一答案,需要根据我们是进行 段级 评估还是 系统级 评估来计算两个指标(Kendall’s Tau 和准确度 (accuracy, Kocmi et al., 2021))。但首先,它们是什么?
系统级评估 评估机器翻译系统的 整体性能。它查看系统生成的翻译在广泛文本或内容中的质量。
段级评估 侧重于评估逐段翻译的质量(通常是 句子 或更小的文本单元)
一般而言:
-
Kendall’s Tau 用于段级评估
-
准确度用于系统级评估
为了清晰起见,我们来深入了解它们的公式,使用简单的示例:
A. Kendall’s Tau
(Kendall’s Tau 用于判断两个排名之间是否存在相关性)
假设你有一个给定句子的三种翻译(A、B 和 C),你想评估指标(如 LLM、BLEU、METEOR 分数)产生的排名与人类翻译质量判断之间的相关性。
参考(人类): “快速的棕色狐狸跳过了懒狗。”
翻译 A: “迅速的棕色狐狸跳过了懒狗。”
翻译 B: “快速的红色狐狸跳过了懒狗。”
翻译 C: “懒狗被快速的棕色狐狸跳过。”
人类排序: A > B > C(即,他们最喜欢翻译 A,然后是 B,最后是 C)
这些翻译的指标分数: LLM(A) = 0.85
LLM(B) = 0.75
LLM(C) = 0.60
综合所有信息,我们可以按如下方式计算 Kendall’s Tau:

接下来,让我们计算一致对和不一致对:
对 1: (A, B) 人工排序:A > B
LLM 分数:LLM(A) = 0.85 > LLM(B) = 0.75
结果: 一致对(人类和指标都更喜欢 A 而不是 B)。
对 2: (A, C) 人工排序:A > C
LLM 分数:LLM(A) = 0.85 > LLM(C) = 0.60
结果:一致对(人工和度量都更喜欢 A 而不是 C)。
对 3: (B, C) 人工排名:B > C
LLM 分数:LLM(B) = 0.75 > LLM(C) = 0.60
结果:一致对(人工和度量都更喜欢 B 而不是 C)。
将这些值代入公式中,我们得到:

Kendall’s Tau
换句话说,τ = 1 表示度量和人工判断之间完全一致,因此这是一个可以用于自动化翻译质量的高质量度量。
B. 准确率
Kendall’s Tau 评估排名之间的相似性或一致性,而准确率衡量排名的正确性。
为了说明准确率的计算,我们采用与上面相同的设置,即参考(人工)、翻译 A、翻译 B、翻译 C、人工排名,但让我们稍微更新度量分数,以便根据 Bleu 将 B 标记为比 C 更好的翻译:
度量分数(BLEU):
BLEU(A) = 0.80
BLEU(B) = 0.70
BLEU(C) = 0.75
结合所有信息,这里是如何计算准确率的方法:

让我们计算度量Δ(这只是成对翻译的度量值差异)和人工Δ(这是假设的翻译对的人工评分差异)。如果你仔细查看公式,你会注意到我们并不关心Δ的实际值,而是Δ的符号。简单来说,只有当两个Δ的符号相同,即人工和度量对翻译的看法一致时,才能实现高准确率。
对 1: (A, B) 度量Δ = BLEU(A) — BLEU(B) = 0.80–0.70 = 0.10
人工Δ = 1(A 的排名高于 B)
结果:度量Δ和人工Δ具有相同的符号(都是正的)。这是一个排名一致性。
对 2: (A, C) 度量Δ = BLEU(A) — BLEU(C) = 0.80–0.75 = 0.05
人工Δ = 2(A 的排名高于 C)
结果:度量Δ和人工Δ具有相同的符号(都是正的)。这是一个排名一致性。
对 3: (B, C) 度量Δ = BLEU(B) — BLEU(C) = 0.70–0.75 = -0.05
人工Δ = 1(B 的排名高于 C)
结果:度量Δ和人工Δ具有不同的符号(度量为负,人工为正)。这是一个排名不一致。
将这些值代入公式中,我们得到:
准确率(Bleu)= (2/3)*100 = 67%,这意味着 BLEU 度量准确地将 3 对翻译中的 2 对按照人工判断进行了排名。是否这个百分比足够好以自动化 Bleu 的评估,我留给读者自行判断!
注意:用于演示 Kendall’s Tau 和准确率计算的示例已简化以进行演示。在实际情况下,如果需要处理平局,即如果人工/度量对两个或更多翻译给出相同的排名,则公式会变得更复杂。你可以阅读更多有关内容 这里。
关键结果:系统与片段级别评估
论文报告称 GEMBA 在系统级评估中表现优异,超越了现有的评估指标。

系统级评估结果。图片取自原始论文
然而,在分段级评估中还有改进的空间。LLM 与人类在这一层级上的排名平局可能解释了这一差异,因为 Kendall’s Tau 会惩罚平局。由于 Gemba-DA 指标返回 0-100 之间的离散值,两种翻译得到相等分数的概率很高。

分段级评估结果(P.S. 第一列的准确率与前面的表格相同)。图片取自原始论文
结果还强调了选择合适 LLM 以实施 GEMBA 的重要性。在测试的 GPT 家族的七种模型中,任何超过 3.5 的模型都表现出色。GPT-4 表现尤为突出,但 Davinci、ChatGPT Turbo 和 GPT-3.5 也表现良好。

各种 GPT 家族模型的 GEMBA 实施。图片取自原始论文
7. 限制和考虑事项
论文突出了 GEMBA 在更广泛应用中的某些限制。
-
由于论文仅考虑了英语、中文、俄语和德语,因此需要对低资源语言进行 GEMBA 的评估。
-
可能存在数据泄漏的风险,因为尚不清楚测试数据是否包含在 Open AI 的训练中(撰写时 Open AI 尚未发布秘密配方)。尽管如此,可能性非常低,因为 GPT 模型声称知识截止日期为 2021 年 9 月,而 MQM 数据集于 2022 年 12 月发布。
-
LLM 可能会偶尔出现无效回应:
►文本回答而非分数→通过提高
temperature直到输出数字分数来处理。►“2”,“two”,“**”,“★★”,“two stars”或“2 stars”→在后处理中处理以保持一致性。
►作者排除了非英语目标语言(如星或五)的 LLM 输出。

无效回应的数量。图片取自原始论文。
结论
牢记使用简单提示实现 GEMBA 的便利性,它无疑是翻译质量评估的突破性指标。它与人类判断的一致性以及对各种 LLM 模型的适应性,使其成为 NLP 和翻译评估领域的有力补充。随着我们继续探索和完善 GEMBA(也许通过少量提示),它作为确保在多种语言环境中高质量翻译的有价值工具,具有很大的潜力。
[1] Kocmi, T., & Federmann, C. (2023). 大型语言模型是翻译质量的最先进评估工具。arXiv 预印本 arXiv:2302.14520。
探索大规模栅格人口数据
原文:
towardsdatascience.com/exploring-large-scale-raster-population-data-72803cf7f2ad

图片由作者提供。
使用 Python 可视化不同尺度的地理空间人口数据:全球、国家和城市级数据
·发布于 Towards Data Science ·阅读时长 9 分钟·2023 年 9 月 21 日
--
我常常看到漂亮的人口地图在网上传播;然而,我通常会在一些技术部分遇到困难,比如可视化教程中未显示的其他地图片段,或将大规模的栅格数据转换为更易计算的矢量格式。本文通过对两个主要全球人口数据来源的实用指南,克服了其中的一些不足。
还需要注意的是,除了它们的美学价值外,人口数据和显示这些数据的地图是任何城市发展或地点情报任务中最基本和最有价值的信息之一。它们在一些具体应用中特别有用,例如规划新设施、选址和流域分析、估算城市产品的规模,或对不同社区进行画像,仅举几例。
1. 数据来源
我依赖以下两个精细化的人口估计数据来源,您可以通过附上的链接下载文件(发布时的日期):
-
欧洲委员会的 GHSL — 全球人类居住层 测量每个网格单元的居民数量。可以在这里找到总体描述,特定的数据集来自他们的 2023 报告,空间分辨率为 100m。
-
我将以德国为例,使用 WorldPop hub 提供的约束性单国数据集,分辨率为 100m。可以在此处找到数据清单,而德国的数据在此处。
2. 可视化全球人类居住层
2.1. 导入数据!
我第一次在 Datashader 教程中遇到这个数据集,该教程来自Architecture Performance。在复制他们的可视化后,我在将其扩展到全球地图时遇到了一些障碍,这引发了这项工作,所以现在我将向你展示我找到的解决方法!
首先,使用 xarray 包解析光栅文件。
import rioxarray
file_path = "GHS_POP_E2030_GLOBE_R2023A_54009_100_V1_0/GHS_POP_E2030_GLOBE_R2023A_54009_100_V1_0.tif"
data_array = rioxarray.open_rasterio(file_path, chunks=True, lock=False)
data_array
此单元的输出是数据集的详细描述:

2.2. 可视化数据的各个块
我们已经可以看到,对于大多数标准笔记本电脑来说,这是一大挑战。无论如何,让我们尝试使用 Datashader 可视化它,这是一种非常方便的工具,用于处理这种规模的地理空间数据集:
# WARNING: THIS CODE BLOCK WILL MOST LIKELY CAUSE A MEMORY OVERFLOW ERROR
import datashader as ds
import xarray as xr
from colorcet import palette
from datashader import transfer_functions as tf
# prepare to plot
data_array_p = xr.DataArray(data_array)[0]
data_array_p = data_array_p.where(data_array_p > 0)
data_array_p = data_array_p.compute()
# get the image size
size = 1200
asp = data_array_p.shape[0] / data_array_p.shape[1]
# create the data shader canvas
cvs = ds.Canvas(plot_width=size, plot_height=int(asp*size))
raster = cvs.raster(data_array_p)
# draw the image
cmap = palette["fire"]
img = tf.shade(raster, how="eq_hist", cmap=cmap)
img
尽管这段代码从技术上看是可行的,但我的 2021 年 M1 Macbook Pro(16GB RAM)却出现了可怜的内存溢出错误。因此,让我们裁剪图像来查看数据!为此,我遵循 Architecture Performance,重点关注欧洲,暂时这样做,因为它肯定有效。
然而,我稍后将回答的主要问题是,尽管存在这样的内存限制,但我们如何仍然能够在本地机器上可视化整个地球的数据?请稍等!
import datashader as ds
import xarray as xr
from colorcet import palette
from datashader import transfer_functions as tf
import numpy as np
# crop the data array
data_array_c = data_array.rio.clip_box(minx=-1_000_000.0, miny=4_250_000.0, maxx=2_500_000.0, maxy=7_750_000.0)
data_array_c = xr.DataArray(data_array_c)
# prepare to plot
data_array_c = xr.DataArray(data_array_c)[0]
data_array_c = data_array_c.where(data_array_c > 0)
data_array_c = data_array_c.compute()
data_array_c = np.flip(data_array_c, 0)
# get the image size
size = 1200
asp = data_array_c.shape[0] / data_array_c.shape[1]
# create the data shader canvas
cvs = ds.Canvas(plot_width=size, plot_height=int(asp*size))
raster = cvs.raster(data_array_c)
# draw the image
cmap = palette["fire"]
img = tf.shade(raster, how="eq_hist", cmap=cmap)
img = tf.set_background(img, "black")
img
这段代码块输出如下可视化:

欧洲的人口分布。图片来自作者。
使用‘fire’颜色映射方案似乎是一个行业标准,这是有原因的;然而,如果你想尝试其他颜色方案,你可以在这里找到并应用于下方:
# create the data shader canvas
cvs = ds.Canvas(plot_width=size, plot_height=int(asp*size))
raster = cvs.raster(data_array_c)
# draw the image
cmap = palette["bmw"]
img = tf.shade(raster, how="eq_hist", cmap=cmap)
img = tf.set_background(img, "black")
img
这段代码块输出如下可视化:

欧洲的人口分布。图片来自作者。
2.3. 可视化整个地球
所以数据在这里,但如果你手头有一台普通计算机,仍然想以 100 米的分辨率可视化整个世界怎么办?我在这里展示的解决方法相对简单 — 我将整个光栅图像拆分成大约一百个较小的瓦片,这样我的计算机可以逐个处理它们,然后使用一些图像处理技巧将它们合并成一个图像文件。
然而,在继续之前 — 快速说明。也有一种选项可以以下方式下采样 XArray 数组 — 但我找不到一个合适的下采样方法可以处理整个数据集。此外,我不想失去精度,还想看到整个数据集的全貌。
# a quick way to down-sample the data
downsampling_factor = 20
downsampled_data_array = data_array.coarsen(x=downsampling_factor, y=downsampling_factor).mean()
downsampled_data_array
输出结果 — 值得与之前绘制的数据数组进行对比:

要将整个光栅图像拆分成网格块,首先,获取其边界并定义 N 为步长。然后,创建图像块边界的列表。
minx = float(data_array.x.min().values)
maxx = float(data_array.x.max().values)
miny = float(data_array.y.min().values)
maxy = float(data_array.y.max().values)
N = 10
xstep = (maxx-minx) / N
ystep = (maxy-miny) / N
xsteps = list(np.arange(minx, maxx, xstep))
ysteps = list(np.arange(miny, maxy, ystep))
现在,迭代每一个 x 和 y 步骤,并创建每个图像段,其中每个图像文件以其在原始网格中的位置命名。这个循环可能需要一些时间。
import os
foldout = 'world_map_image_segments'
if not os.path.exists(foldout):
os.makedirs(foldout)
for idx_x, x_coord in enumerate(xsteps):
for idx_y, y_coord in enumerate(ysteps):
if not os.path.exists(foldout+'/'+str(idx_x)+'_'+str(idx_y)+'.png'):
data_array_c = data_array.rio.clip_box( minx=x_coord, miny=y_coord, maxx=x_coord+xstep, maxy=y_coord+ystep)
data_array_c = xr.DataArray(data_array_c)[0]
data_array_c = data_array_c.fillna(0)
data_array_c = data_array_c.where(data_array_c > 0)
data_array_c = data_array_c.compute()
data_array_c = np.flip(data_array_c, 0)
size = 2000
asp = data_array_c.shape[0] / data_array_c.shape[1]
cvs = ds.Canvas(plot_width=size, plot_height=int(asp*size))
raster = cvs.raster(data_array_c)
cmap = palette["fire"]
img = tf.shade(raster, how="eq_hist", cmap=cmap)
img = tf.set_background(img, "black")
pil_image = img.to_pil()
pil_image.save(foldout+'/'+str(idx_x)+'_'+str(idx_y)+ '.png')
print('SAVED: ', x_coord, y_coord, y_coord+xstep,y_coord+ystep)
最后,如果我们拥有所有的图像段,我们可以使用以下函数快速组装它们。对于这段代码,我也向 ChatGPT 请求了一些提示来加快进度,但像往常一样,也需要一些手动调整。
from PIL import Image
def find_dimensions(image_dir):
max_x = 0
max_y = 0
for filename in os.listdir(image_dir):
if filename.endswith(".png"):
x, y = map(int, os.path.splitext(filename)[0].split("_"))
max_x = max(max_x, x)
max_y = max(max_y, y)
return max_x + 1, max_y + 1
image_dir = foldout
segment_width = size
segment_height = int(asp*size)
# Determine the dimensions of the large image
large_image_width, large_image_height = find_dimensions(image_dir)
# Create an empty large image (white background)
large_image = Image.new("RGB", (large_image_width * segment_width, large_image_height * segment_height), "black")
# Loop through the individual image segments and paste them into the large image
for filename in sorted(os.listdir(image_dir)):
if filename.endswith(".png"):
x, y = map(int, os.path.splitext(filename)[0].split("_"))
segment_image = Image.open(os.path.join(image_dir, filename))
x_offset = x * segment_width
y_offset = large_image_height * segment_height-1*y * segment_height
large_image.paste(segment_image, (x_offset, y_offset))
# Save the merged large image
large_image.save("global_population_map.png")
最终结果是,这里是整个地球的映射:

全球人口分布。图像来源:作者。
3. 视觉化和转换 WorldPop 数据
我想向你展示的第二个数据源是 WorldPop 人口数据库,它提供了不同分辨率的大陆和国家数据。在这个示例中,除了前一部分的大陆和全球级别外,这里我瞄准了国家和城市的级别。例如,我选择了德国,并使用了 2020 年整理的 100m 分辨率数据,并展示了如何从整个国家中提取一个城市,并使用 GeoPandas 将其转换为易于使用的矢量格式。
3.1. 视觉化 WorldPop 数据
使用之前的方法,我们可以再次快速可视化这个栅格文件:
# parse the data
data_file = 'deu_ppp_2020_constrained.tif'
data_array = rioxarray.open_rasterio(data_file, chunks=True, lock=False)
# prepare the data
data_array = xr.DataArray(data_array)[0]
data_array = data_array.where(data_array > 0)
data_array = data_array.compute()
data_array = np.flip(data_array, 0)
# get the image size
size = 1200
asp = data_array.shape[0] / data_array.shape[1]
# create the data shader canvas
cvs = ds.Canvas(plot_width=size, plot_height=int(asp*size))
raster = cvs.raster(data_array)
# draw the image
cmap = palette["fire"]
img = tf.shade(raster, how="eq_hist", cmap=cmap)
img = tf.set_background(img, "black")
img
这个代码块输出以下视觉效果:

德国的人口分布。图像来源:作者。
3.2. 转换 WorldPop 数据
在可视化了整个地球、欧洲大陆和德国之后,我希望更深入地了解柏林,并向你展示如何将这样的栅格数据转换为矢量格式,并使用 GeoPandas 轻松处理它。为此,我访问了柏林的行政边界的 geojson 格式数据在这里。
这个行政文件包含柏林的区,因此首先,我将它们合并成整个城市。
from shapely.ops import cascaded_union
import geopandas as gpd
admin = gpd.read_file('tufts-berlin-bezirke-boroughs01-geojson.json')
admin = gpd.GeoDataFrame(cascaded_union(admin.geometry.to_list()), columns = ['geometry']).head(1)
admin.plot()
这个代码块输出以下视觉效果:

柏林的行政边界。图像来源:作者。
现在将 xarray 转换为 Pandas DataFrame,提取几何信息,并构建一个 GeoPandas GeoDataFrame。可以这样做:
import pandas as pd
df_berlin = pd.DataFrame(data_array.to_series(), columns = ['population']).dropna()
现在,从这个数据构建一个 GeoDataFrame,重点关注柏林:
from shapely.geometry import Point
# find the limiting bounding box for easier coodinate-selection
minx, miny, maxx, maxy = admin.bounds.T[0].to_list()
points = []
population = df_berlin.population.to_list()
indicies = list(df_berlin.index)
# create Point geometries from the points falling into this bounding box
geodata = []
for ijk, (lon, lat) in enumerate(indicies):
if minx <= lat <= maxx and miny <= lon <= maxy:
geodata.append({'geometry' : Point(lat, lon), 'population' : population[ijk]})
# build a GeoDataFrame
gdf_berlin = gpd.GeoDataFrame(geodata)
gdf_berlin = gpd.overlay(gdf_berlin, admin)
然后,将人口数据作为矢量数据进行可视化:
import matplotlib.pyplot as plt
f, ax = plt.subplots(1,1,figsize=(15,15))
admin.plot(ax=ax, color = 'k', edgecolor = 'orange', linewidth = 3)
gdf_berlin.plot(column = 'population',
cmap = 'inferno',
ax=ax,
alpha = 0.9,
markersize = 0.25)
ax.axis('off')
f.patch.set_facecolor('black')
这个代码块输出以下视觉效果:

柏林的人口分布。图像来源:作者。
最后,这里我们有一个标准的 GeoDataFrame,分辨率为 100m,每个点几何体对应于栅格文件中的每个像素。
总结
在本文中,我探索了基于各种数据集的全球人口数据,这些数据通过结合各种近似、测量和建模方法,以 100 米的显著空间分辨率使用栅格网格来估计人口水平。这类信息对于城市发展和位置智能的广泛应用具有高度价值,例如基础设施规划、场地选择、社区分析等。从技术角度来看,我展示了三个空间层级的示例,从覆盖整个地球,再到放大到国家,最后到城市。虽然这种方法可以处理更小的分辨率,但这一切都发生在一台使用强大 Python 库如 Xarray、DataShader 和 GeoPandas 的笔记本电脑上。
对于数据科学家来说,总有新的 Python 技能值得学习
·
关注 发表于 Towards Data Science · 发送至 新闻简报 · 阅读时间 3 分钟 · 2023 年 6 月 8 日
--
新的数据科学工具和最先进的模型每天都在引起关注,但尽管 Python 存在诸多被频繁提及的不足(速度,如何?),它仍然是全球数据和机器学习从业者的主要统一者之一。无论何时滚动查看 TDS 首页,你都会发现难以质疑这种编程语言的普遍性和多功能性。
为了帮助你发现可能还不熟悉的库、用例和优化技术,我们汇集了一些我们认为执行得特别好的近期 Python 相关教程。它们涵盖了很多内容,因此无论你的角色或背景如何,你都可能找到一个新的主题或工作流进行探索。祝学习愉快!
-
作为 pandas 核心团队的成员,Patrick Hoefler 经常听到有关这个流行库的痛点;他的新教程带我们深入了解如何使用 PyArrow (pandas 2.0 支持)来解决处理非标准和任意数据类型的问题。(额外的好处是,它还包含了对 Dask 用户有用的提示。)
-
如果你想扩展你的地理空间数据工具包,不要错过 Parvathy Krishnan(以及合著者 Mahdi Fayazbakhsh 和 Kai Kaiser)的最新文章,该文章重点介绍了数字高程模型(DEMs),并展示了 如何充分利用 Python 的高程包来探索和可视化这些模型,以及如何计算诸如坡度和高程等导数。
-
Python 为我们提供了许多理解世界的其他方式——或者说是让我们看得更清楚。举个例子:Conor O'Sullivan 的新指南介绍了如何处理卫星图像以去除那些有时(也就是说:经常)阻挡视线的讨厌云层。

-
随着全球医疗系统仍然承受着相当大的压力,优化人员和资源的配置变得尤为重要。Meagan Voulo 提出了一个 利用 Python 分析健康的社会决定因素(SDOH) 的潜在方法,用于预测急诊室的高峰使用情况。
-
如果你想提高代码性能(谁不想呢?),Peng Qian提供了针对 asyncio 用户的简明教程:该教程介绍了使用 asyncio.gather、asyncio.as_completed 和 asyncio.wait APIs 处理并发任务的最佳实践。
-
我们承诺不会掀起任何波澜,但…有时候 Python 可能不是你项目的最佳选择,这没关系!知道何时不使用 Python 是一项关键技能。Stephanie Lo的全面指南将帮助忠实的 Python 爱好者在需要时过渡到 R,并详细说明你应该注意的差异。
编程很棒,但我们希望你本周阅读的其他这些帖子同样精彩:
-
我们最新的月刊已经发布,如果它有一个目标,那就是激励你设计出色的数据科学和机器学习项目。
-
分子生物学是利用人工智能取得最大进展的领域之一,Serafim Batzoglou的精彩概述涵盖了这种跨学科交叉作用的过去、现在和未来。
-
反应型数据团队和主动型数据团队之间的区别可能比你想象的更微妙,但正如Barr Moses解释的那样,哪一种对业务的价值更大几乎没有疑问。
-
我们现在正处于 ChatGPT 插件季节,新的插件迅速出现。要了解插件的好处和风险,不要错过Mary Newhauser的新深度剖析。
-
“正如人们会遗忘,机器学习模型也会遗忘——尤其是大型语言模型。” 在一次强有力的 TDS 首演中,Matt Tengtrakool反思了遗忘的作用——准确地说是灾难性遗忘—在机器学习的背景下。
感谢你们对我们作者的支持!如果你喜欢在 TDS 上阅读的文章,可以考虑成为 Medium 会员——这将解锁我们整个档案库(以及 Medium 上的所有其他文章)。
忠实的 TDS 读者可能会高兴地知道,我们在 Medium 的朋友们计划在这个夏天晚些时候庆祝社区和讲故事,并鼓励大家提出意见,使其成为一个成功的活动。
直到下一个 Variable,
TDS 编辑们
探索 Numexpr:Pandas 背后的强大引擎
原文:
towardsdatascience.com/exploring-numexpr-a-powerful-engine-behind-pandas-cdb94965ca3a
快速计算
使用 Python 的 Numexpr 和 Pandas 的 eval/query 函数提升数据分析性能
·发表于 Towards Data Science ·10 分钟阅读·2023 年 9 月 22 日
--

使用 Numexpr 来帮助我找到最宜居的城市。照片来源:作者创作,Canva
本文将向你介绍 Python 库 Numexpr,这是一种提高 Numpy Arrays 计算性能的工具。Pandas 的 eval 和 query 方法也基于这个库。
本文还包括一个动手实践的天气数据分析项目。
阅读本文后,你将理解 Numexpr 的原理以及如何使用这一强大的工具来加速现实中的计算。
介绍
回顾 Numpy Arrays
在上一篇讨论 Numpy Arrays 的文章中,我使用了一个库示例来解释为什么 Numpy 的缓存局部性如此高效:
## Python 列表与 NumPy 数组:深入探讨内存布局和性能优势
探索分配差异和效率提升
towardsdatascience.com
每次去图书馆查找资料时,你会拿出几本与内容相关的书,放在桌子旁边。
这样,你可以快速检查相关资料,而无需每次都跑到书架上去找书。
这种方法节省了大量时间,尤其是当你需要查阅许多相关书籍时。
在这种情况下,书架就像你的记忆,桌子相当于 CPU 的 L1 缓存,而你,读者,就是 CPU 的核心。

当 CPU 访问 RAM 时,缓存会将整个缓存行加载到高速缓存中。图像由作者提供
Numpy 的限制
假设你不幸遇到了一位要求你拿出莎士比亚和托尔斯泰作品进行对比的苛刻教授。
在这种情况下,提前拿出相关书籍并不会有效果。
首先,你的桌面空间有限,无法同时容纳这两位大师的所有书籍,更不用说在比较过程中产生的阅读笔记了。
其次,你只是一个人,比较这么多作品会花费很长时间。如果能找到更多人来帮忙就好了。
这是我们使用 Numpy 处理大量数据时的当前情况:
-
数组中的元素数量太大,无法适应 CPU 的 L1 缓存。
-
Numpy 的元素级操作是单线程的,无法利用多核 CPU 的计算能力。
我们应该怎么办?
别担心。当你真的遇到数据量过多的问题时,你可以召唤今天的主角 Numexpr 来帮助你。
理解 Numexpr:什么和为什么
它是如何工作的
当 Numpy 遇到大型数组时,逐元素计算会经历两种极端情况。
让我举个例子来说明。假设有两个大型的 Numpy ndarray:
import numpy as np
import numexpr as ne
a = np.random.rand(100_000_000)
b = np.random.rand(100_000_000)
当计算表达式 a**5 + 2 * b 的结果时,通常有两种方法:
一种方法是 Numpy 的向量化计算方法,它使用两个临时数组分别存储 a**5 和 2*b 的结果。
In: %timeit a**5 + 2 * b
Out:2.11 s ± 31.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
这时,你的内存中有四个数组:a、b、a**5 和 2 * b。这种方法会导致大量的内存浪费。
而且,由于每个数组的大小超过了 CPU 缓存的容量,它无法很好地利用缓存。
另一种方法是遍历两个数组中的每个元素并分别计算它们。
c = np.empty(100_000_000, dtype=np.uint32)
def calcu_elements(a, b, c):
for i in range(0, len(a), 1):
c[i] = a[i] ** 5 + 2 * b[i]
%timeit calcu_elements(a, b, c)
Out: 24.6 s ± 48.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
这种方法的效果更差。计算会非常慢,因为它无法使用向量化计算,只能部分利用 CPU 缓存。
Numexpr 的计算
Numexpr 通常只使用一个 evaluate 方法。该方法每次接收一个表达式字符串,然后使用 Python 的 compile 方法将其编译为字节码。
Numexpr 还有一个虚拟机程序。虚拟机包含多个向量寄存器,每个寄存器使用 4096 的块大小。
当 Numexpr 开始计算时,它会每次将数据发送到一个或多个寄存器中的 CPU 的 L1 缓存中。这样,就不会出现内存过慢、CPU 等待数据的情况。
与此同时,Numexpr 的虚拟机是用 C 编写的,去除了 Python 的 GIL。它可以利用多核 CPU 的计算能力。
所以,计算大型数组时,Numexpr 比单独使用 Numpy 更快。我们可以进行比较:
In: %timeit ne.evaluate('a**5 + 2 * b')
Out: 258 ms ± 14.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Numexpr 工作原理总结
让我们总结一下 Numexpr 的工作原理,看看为什么 Numexpr 如此快速:
通过虚拟机执行字节码。 Numexpr 使用字节码来执行表达式,这可以充分利用 CPU 的分支预测能力,比使用 Python 表达式要快。
向量化计算。 Numexpr 会使用SIMD(单指令、多数据)技术来显著提高对每个寄存器中数据的相同操作的计算效率。
多核并行计算。 Numexpr 的虚拟机可以将每个任务分解为多个子任务,并在多个 CPU 核心上并行执行。
更少的内存使用。 不同于需要生成中间数组的 Numpy,Numexpr 在必要时只加载少量数据,从而显著减少内存使用。

Numexpr 的工作流程图。图像由作者提供
Numexpr 和 Pandas:强大的组合
你可能会想:我们通常使用 pandas 进行数据分析。我理解 Numexpr 对 Numpy 性能的提升,但它对 Pandas 是否也有同样的提升?
答案是肯定的。
pandas 中的eval和query方法是基于 Numexpr 实现的。让我们看一些示例:
Pandas.eval 用于跨 DataFrame 操作
当你有多个 pandas DataFrame 时,可以使用pandas.eval在 DataFrame 对象之间执行操作,例如:
import pandas as pd
nrows, ncols = 1_000_000, 100
df1, df2, df3, df4 = (pd.DataFrame(rng.random((nrows, ncols))) for i in range(4))
如果你使用传统 pandas 方法计算这些 DataFrame 的总和,所消耗的时间是:
In: %timeit df1+df2+df3+df4
Out: 1.18 s ± 65.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
你也可以使用pandas.eval进行计算。所消耗的时间是:
In: %timeit pd.eval('df1 + df2 + df3 + df4')
Out: 452 ms ± 29.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
eval版本的计算可以将性能提高 50%,且结果完全一致:
In: np.allclose(df1+df2+df3+df4, pd.eval('df1+df2+df3+df4'))
Out: True
DataFrame.eval 用于列级操作
就像pandas.eval一样,DataFrame 也有自己的eval方法。我们可以使用此方法进行 DataFrame 中的列级操作,例如:
df = pd.DataFrame(rng.random((1000, 3)), columns=['A', 'B', 'C'])
result1 = (df['A'] + df['B']) / (df['C'] - 1)
result2 = df.eval('(A + B) / (C - 1)')
使用传统 pandas 方法和eval方法的结果完全一致:
In: np.allclose(result1, result2)
Out: True
当然,你也可以直接使用eval表达式向 DataFrame 中添加新列,这非常方便:
df.eval('D = (A + B) / C', inplace=True)
df.head()

直接使用eval表达式来添加新列。图像由作者提供
使用 DataFrame.query 快速查找数据
如果 DataFrame 的eval方法执行比较表达式,返回的结果是符合条件的布尔结果。你需要使用掩码索引来获取所需的数据:
mask = df.eval('(A < 0.5) & (B < 0.5)')
result1 = df[mask]
result1

当仅使用 DataFrame.query 过滤数据时,需要使用布尔掩码。图像由作者提供
DataFrame.query方法封装了这个过程,你可以直接通过query方法获得所需的数据:
In: result2 = df.query('A < 0.5 and B < 0.5')
np.allclose(result1, result2)
Out: True
当你需要在表达式中使用标量时,可以使用@符号来指示:
In: Cmean = df['C'].mean()
result1 = df[(df.A < Cmean) & (df.B < Cmean)]
result2 = df.query('A < @Cmean and B < @Cmean')
np.allclose(result1, result2)
Out: True
实际示例:在现实场景中使用 Numexpr 和 Pandas
在所有解释 Numexpr 的文章中,示例使用的是合成数据。这种情况并不理想,可能会导致你在阅读文章后不知道如何使用这个强大的库来完成任务。
因此,在本文中,我将以天气数据分析项目为例,解释如何在实际工作中使用 Numexpr 处理大型数据集。
项目目标
在炎热的夏天过后,我非常想知道是否有这样一个地方,夏季气候宜人,适合我避暑。
这个地方应该满足以下条件:
-
在夏天:
-
每日平均温度在 18 摄氏度到 22 摄氏度之间;
-
日温差在 4 摄氏度到 6 摄氏度之间;
-
平均风速(以公里每小时计算)在 6 到 10 之间。微风吹拂会感觉很舒服。
数据准备
这一次,我使用了由Meteostat JSON API提供的全球主要城市天气数据。
数据在知识共享署名-非商业性使用 4.0 国际许可协议 (CC BY-NC 4.0)下授权,并可以用于商业用途。
我使用了基于 Meteostat JSON API 的Kaggle上整合的 parquet 数据集以便于操作。
我使用了 pandas 的 2.0 版本。这个版本的pandas.read_parquet方法可以轻松读取 parquet 数据。但在读取之前,需要安装Pyarrow和Fastparquet。
conda install pyarrow
conda install fastparquet
数据分析
在初步准备之后,我们正式进入数据分析过程。
首先,我将数据读入内存,然后查看这个数据集的情况:
import os
from pathlib import Path
import pandas as pd
root = Path(os.path.abspath("")).parents[0]
data = root/"data"
df = pd.read_parquet(data/"daily_weather.parquet")
df.info()

数据集元数据概览。图片作者
如图所示,这个数据集包含 13 个字段。根据这个项目的目标,我计划使用city_name、season、min_temp_c、max_temp_c、avg_wind_speed_kmh这些字段。
接下来,我首先删除包含空值的字段中的数据,以便进行后续计算,然后选择所需字段以形成一个新的 DataFrame:
sea_level_not_null = df.dropna(subset=['min_temp_c', 'max_temp_c', 'avg_wind_speed_kmh'] , how='any')
sample = sea_level_not_null[['city_name', 'season',
'min_temp_c', 'max_temp_c', 'avg_wind_speed_kmh']]
由于我需要计算平均温度和温差,我使用Pandas.eval方法直接在 DataFrame 上计算新的指标:
sample.eval('avg_temp_c = (max_temp_c + min_temp_c) / 2', inplace=True)
sample.eval('diff_in_temp = max_temp_c - min_temp_c', inplace=True)
然后,按city_name和season对几个指标进行平均:
sample = sample.groupby(['city_name', 'season'])\
[['min_temp_c', 'max_temp_c', 'avg_temp_c', 'diff_in_temp', 'avg_wind_speed_kmh']]\
.mean().round(1).reset_index()
sample

数据清理和指标计算后的结果。图片作者
最后,根据项目目标,我使用DataFrame.query来筛选数据集:
sample.query('season=="Summer" \
& 18 < avg_temp_c < 22 \
& 4 < diff_in_temp < 6 \
& 6 < avg_wind_speed_kmh < 10')

最终,我们得到了唯一符合标准的结果。图片由作者提供
最终结果出来了。只有一个城市符合我的要求:符拉迪沃斯托克,一个位于俄罗斯东部的非冻结港口。确实是逃离酷热的绝佳地方!
最佳实践和经验总结
在解释了 Numexpr 的项目实践后,像往常一样,我会结合自己的工作经验为你讲解一些 Numexpr 的最佳实践。
避免过度使用
尽管 Numexpr 和 pandas 的 eval 在处理大数据集时具有显著的性能优势,但处理小数据集的速度并不比常规操作快。
因此,你应该根据数据的大小和复杂性来选择是否使用 Numexpr。我的经验是,当你觉得有需要时使用它,因为小数据集不会拖慢处理速度。
eval 函数的使用是有限制的
eval 函数不支持所有的 Python 和 pandas 操作。
因此,在使用之前,你应该查阅 文档 以了解 eval 支持哪些操作。
处理字符串时要小心
尽管我在项目实践中使用了 season="Summer" 来过滤数据集,但在处理字符串时 eval 函数并不非常快速。
如果你的项目中有大量的字符串操作,你需要考虑其他方法。
注意内存使用情况
尽管 Numexpr 不再生成中间数组,但大数据集仍会占用大量内存。
例如,在我的项目示例中,数据集占用了 2.6G 的内存。此时,你必须非常小心以避免由于内存不足而导致程序崩溃。
使用适当的数据类型
这一点在 官方文档 中有详细说明,所以我在这里不再赘述。
在需要时使用 inplace 参数
使用 DataFrame.eval 方法的 inplace 参数可以直接修改原始数据集,避免生成新的数据集并占用大量内存。
当然,这样做会导致对原始数据集进行修改,因此请小心操作。
结论
在本文中,我带来了关于 Numexpr 的全面教程,包括:
Numexpr 的适用场景、性能提升效果及其工作原理。
Pandas 中的 eval 和 query 方法也基于 Numexpr。如果使用得当,它将为你的 pandas 操作带来极大的便利和性能提升。
通过一个全球天气数据分析项目,我展示了如何在实践中使用 pandas 的 eval 和 query 方法。
一如既往地,结合我的工作经验,我介绍了 Numexpr 的最佳实践以及 pandas 的 eval 方法。
感谢阅读。如果你有任何问题,请在评论区留言,我会及时回复。
让我从基础开始,带你了解工作中最佳的科学计算实践。

快速计算
查看列表4 个故事!


感谢阅读我的故事。
你可以 订阅 以获取我最新的数据科学故事。
如果你有任何问题,可以在 LinkedIn 或 Twitter(X) 上找到我。
本文最初发表在 数据引领未来。
探索 TensorFlow 模型预测问题
原文:
towardsdatascience.com/exploring-tensorflow-model-prediction-issues-38092d0cdcc3
在个人电脑上调试 BERT(以及其他 LLM)慢预测时间的步骤
·发表于Towards Data Science ·7 分钟阅读·2023 年 2 月 2 日
--
一切始于我玩弄 BERT 模型时,收到了一条所有数据科学家都希望避免的凶兆消息:

令人恐惧的“内核崩溃”消息 💀
当我在 Jupyter Notebook 上运行我的 TensorFlow BERT 模型时,这种情况发生了。训练大型语言模型(LLMs)通常需要大量的数据和计算,因此我的相对微不足道的笔记本电脑在这里崩溃是有道理的……
…只是这次崩溃发生在预测期间,而不是训练期间,这很奇怪,因为我认为训练时使用的内存比预测时多。
“Kernel Died”错误提供的信息不够具体,而逐行调试 TensorFlow 听起来像是一项艰巨的任务。
一些在 Stack Overflow 上的快速搜索也没有完全回答我悬而未决的问题。但我仍然需要一个前进的方向。
这是我对内核崩溃问题的探索以及我找到解决方案的过程。 🚀
深入探索
由于我对问题的唯一了解是内核崩溃,我需要收集更多的背景信息。从其他几个线程来看,似乎内核崩溃的原因是我的模型预测需要的内存超出了我的 CPU 可以提供的(8GB),即使在预测期间也是如此。
现在,一个非常直接的解决方案(大多数人会假设)是通过Google Colab或类似的服务获取或租用 GPU。我认为这确实是一个可行的解决方案。
但我想知道在 RAM 成为问题之前,我能在本地机器学习项目中将 CPU 推到多远。考虑到这一点,我们需要探索模型和系统本身的几个方面。
批量大小
鉴于这是一个 RAM 问题,我认为批量大小发挥了重要作用,所以我想对这个超参数进行压力测试。
首先,我写了三个简化版本的 BERT,仅改变模型使用的批量大小。我运行了这三种版本:
-
FULL:BERT 一次性对整个输入进行预测
-
SINGLE:BERT 一次对一个输入进行预测
-
BATCH (100):BERT 以每次 100 个输入的批量进行预测
以下是相关代码:
from transformers import BertTokenizer, BertForSequenceClassification, TFBertForSequenceClassification
import tensorflow as tf
class BERT_model_full:
"""
BERT model predicting on all inputs at once
"""
def __init__(self):
self.model = TFBertForSequenceClassification.from_pretrained("bert-base-uncased")
self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
def predict(self,inputs):
tf_batch = self.tokenizer(inputs, max_length=128, padding=True, truncation=True, return_tensors='tf')
tf_outputs = self.model(tf_batch)
return(tf_outputs.logits.numpy())
class BERT_model_batch:
"""
BERT model predicting on batches of 100 inputs at a time
"""
def __init__(self):
self.model = TFBertForSequenceClassification.from_pretrained("bert-base-uncased")
self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
def predict(self,inputs):
# Pred by batchsize
i = 0
batch_size = 100
og_preds = []
int_preds = []
while i < len(inputs):
j = min([len(inputs),i+batch_size])
tf_batch = self.tokenizer(inputs[i:j], max_length=128, padding=True, truncation=True, return_tensors='tf')
tf_outputs = self.model(tf_batch)
i = j
return(True)
class BERT_model_single:
"""
BERT model predicting on a single input at a time
"""
def __init__(self):
self.model = TFBertForSequenceClassification.from_pretrained("bert-base-uncased")
self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
def predict(self,inputs):
for i in inputs:
tf_batch = self.tokenizer([i], max_length=128, padding=True, truncation=True, return_tensors='tf')
tf_outputs = self.model(tf_batch)
return(tf_outputs.logits.numpy())
然后我将这些模型通过相同的测试用例运行,逐步增加输入大小。我使用了 经典 imdb 数据集 来进行测试。
size_list = [1,10,100,1000,2000,4000,6000,8000]
single_time_list = []
batch_time_list = []
full_time_list = []
BERT = BERT_model_single()
print("BERT Single Input:")
for s in size_list:
data = imdb_data.sample(s)['DATA_COLUMN']
start = time.time()
_ = BERT.predict(data)
end = time.time()
single_time_list.append(end-start)
print(f"{s} samples: {(end-start)/60:.2f} minutes")
BERT = BERT_model_batch()
print("\nBERT Small Batch:")
for s in size_list:
data = list(imdb_data.sample(s)['DATA_COLUMN'])
start = time.time()
_ = BERT.predict(data)
end = time.time()
batch_time_list.append(end-start)
print(f"{s} samples: {(end-start)/60:.2f} minutes")
BERT = BERT_model_full()
print("\nBERT Full Batch:")
for s in size_list:
data = list(imdb_data.sample(s)['DATA_COLUMN'])
start = time.time()
_ = BERT.predict(data)
end = time.time()
full_time_list.append(end-start)
print(f"{s} samples: {(end-start)/60:.2f} minutes")
绘制输出图表显示出有趣的趋势:

BATCH 优于 SINGLE 是有道理的,因为大多数机器学习模型和像 Tensorflow 这样的包设计用来利用向量化。
但令人惊讶的是FULL 与 BATCH 的差距有多大。
我曾假设 FULL 会因为向量化表现最佳,直到它因内存限制崩溃,但实际上即使是几千个样本的内存限制在我的笔记本上也极大地增加了预测时间。
在处理较大的输入时,FULL 的表现实际上更差,比起逐个输入处理而不进行向量化。🤯
在大约 2,000 个样本时,这些 RAM 要求开始对我的 CPU 造成负担。令人惊讶的是,在达到 2K 之前,BATCH 和 FULL 之间的差异并不大。
根据上面的图表,我假设使用 2,000 的批量大小会产生最佳结果。
我错了。
似乎最佳的批量大小更接近 1K,因为如果使用 2K 的批量大小,预测时间开始上升:

批量大小对 4K 输入的预测时间的影响
Tokenizer
我接下来探索的代码是 Tokenizer。鉴于这一行包含了许多超参数,我认为这也是一个优化的地方:
tf_batch = self.tokenizer(inputs, max_length=128,
padding=True, truncation=True,
return_tensors='tf')
然而,当我计时检查我的 FULL 模型性能时,在 1K 输入下它与 BATCH 表现相当,而在 4K 输入下表现显著较差,Tokenizer 性能时间是总时间的微不足道的一部分:
1000 samples:
Tokenizer Time: 0.06 minutes
Predictionn Time: 1.97 minutes
Tokenizer takes up 3.06% of prediction time
4000 samples:
Tokenizer Time: 0.29 minutes
Predictionn Time: 27.25 minutes
Tokenizer takes up 1.06% of prediction time
虽然 Tokenizer 时间的增加确实略微超过了输入大小的增加(输入大小增加四倍导致 Tokenizer 时间增加 4.8 倍),但预测时间却惊人地增加了13.8 倍!
显然,问题出在 **.predict()** 管道的部分。
Tensorflow 版本
根据上面已提到的 Stack Overflow 线程,最受欢迎的解决方案是将 Tensorflow 降级以加快预测速度。
我认为这是一个值得怀疑的解决方案,因为我假设升级版本会有更多的优化和更好的运行时间,而不是更差。但我还是尝试了。
访问 tensorflow Pypi 页面,我们可以看到包的旧版本。选择发布大约相隔一年的包,我们得到以下包版本:
-
2.10.0,发布于 2022 年 9 月 -
2.6.1,发布于 2021 年 11 月 -
1.15.4,发布于 2020 年 9 月 -
1.15.0,发布于 2019 年 10 月
要迭代安装同一包的不同版本,我们需要利用 os 包,使我们能够从 Python 代码中运行终端命令:
import os
data = list(imdb_data.sample(4000)['DATA_COLUMN'])
full_time_list = []
versions = ["2.10.0","2.6.1","1.15.4","1.15.0"]
for version in versions:
print(version,":")
os.system(f"pip install tensorflow=={version}")
try:
from transformers import BertTokenizer, BertForSequenceClassification, TFBertForSequenceClassification
import tensorflow as tf
except:
print("Cannot import relevant packages")
continue
BERT = BERT_model_full()
start = time.time()
_ = BERT.predict(data)
end = time.time()
minutes = (end-start)/60
full_time_list.append(minutes)
print(f"{s} batch size: {minutes:.2f} minutes")
-
try/except语句存在是因为我们不知道这些函数是否存在于包的早期版本中。幸运的是,它们都存在 -
在循环中的
import语句看起来不对,但这是必要的,因为我们需要在安装正确的包版本后重新导入这些函数
经过每个版本的迭代,我们发现降级 TensorFlow 可以将运行时间提高多达 15%!

我个人的理论是,之所以出现这种情况,是因为较新的 TensorFlow 版本假定重度使用 GPU,这意味着它针对这种特定的使用场景进行了优化,但牺牲了本地 CPU 性能。
如果有人知道为什么旧版本的 TensorFlow 运行更快的真实原因,请告诉我!
结论与总结
关于 TensorFlow 运行时的以下见解:
-
最佳预测批量大小约为 1,000
-
分词器参数确实在预测时间中起着重要作用
-
TensorFlow 1.X.X 在预测时间上提升了 15%
我们可以将这些信息综合起来,看看它与我们最初的批量大小实验的表现如何:

在测试的最大案例中,我们的最佳运行比 Batch(100) 快 20%,比 Single 快 57%!
总的来说,这个过程是对数据科学家身份的一种简单而愉快的表达。你需要识别问题,建立假设,制定严格的测试,并分析结果。在这个案例中,就是我的 TensorFlow 运行时。将来,我相信你会在自己的工作中发现令人困惑的数据/问题/难题。
下次,希望你不要仅仅是查看 Stack Overflow,如果没有找到答案就放弃,而是卷起袖子自己探索问题空间。你永远不知道你可能会学到什么 💡
希望这对调试你的 TensorFlow 预测时间问题有所帮助! 🎉
所有图像,除非另有说明,均由作者提供
使用 Python 探索 DLIS 文件的内容
原文:
towardsdatascience.com/exploring-the-contents-of-dlis-files-with-python-38585157dbac
使用 Pandas 和 dlisio 探索井日志数据
·发表于 Towards Data Science ·8 分钟阅读·2023 年 7 月 14 日
--

图片由 Markus Spiske 提供,来自 Unsplash
DLIS 文件 是一种标准的石油和天然气行业数据格式。它们是结构化的二进制文件,包含井信息、工具信息和井日志数据的表格。与平面 LAS(日志 ASCII 标准)文件相比,它们要复杂得多,也更难以打开。这可能使得处理这些文件变得更加困难,通常需要专用工具来查看和探索其内容。
幸运的是,Equinor 发布了一个名为 dlisio 的 Python 库,这使得探索这些文件的过程变得更加容易。
dlsio 是由 Equinor ASA 开发的一个 Python 库,用于读取 dlis 文件和 Log Information Standard 79 (LIS79) 文件。开发这个库的主要思想是减少探索和提取这些文件中数据的负担和工作量,而不必完全理解它们的结构。这使得用户可以专注于访问和处理数据。
要获取有关 dlisio 库的更多信息,请查看以下文档
[## dlisio 0.3.7 文档
欢迎使用 dlisio。dlisio 是一个用于读取数字日志交换标准(DLIS)v1 的 Python 包。版本 2 存在……
在本教程中,我们将看到如何通过将信息和数据转换为一个pandas 数据框来访问 dlis 文件的内容,这是一种在数据科学中更为常见的数据格式。
导入库
如果你还没有安装 dlisio,你可以在你的 Jupyter Notebook 中使用以下命令直接安装。
!pip install dlisio
一旦安装了库,我们可以开始导入必要的库。在本教程中,我们需要从 dlisio 导入dlis模块,以及pandas库。
from dlisio import dlis
import pandas as pd
使用 DLISIO 加载 DLIS 数据文件
一旦导入了所需的库,我们可以使用以下代码加载我们的 DLIS 数据。
本教程中使用的数据是从NLOG.nl下载的,这是一个包含整个荷兰北海地区井筒记录数据的网站。
数据是免费提供下载和使用的。数据许可的完整细节可以在本文末尾找到。
f, *tail = dlis.load('Data/NLOG Data/NPN_TVN-1_23_Dec_2009_E3_Main_SONIC_057PUC.DLIS')
你会注意到我们在开始时有两个变量:f和*tail。这是为了适应 dlis 文件可以包含多个逻辑文件的事实,这些逻辑文件代表了额外的测井过程或在数据采集后处理的其他数据集。
如果我们有多个逻辑文件,第一个将被放入f中,任何后续的逻辑文件将被放入*tail中。
如果我们想检查第一个逻辑文件的内容,可以调用以下代码。
f.describe()
这将生成逻辑文件的以下汇总。此汇总包括有关帧和通道(曲线)的信息。它还包括有关测井工具设置、环境和其他重要参数的信息。

从 DLISIO 获取的 DLIS 文件内容的汇总输出。图片由作者提供。
处理 DLIS 帧
DLIS 文件中的帧也可以代表不同的测井过程或不同阶段的处理数据。这可以从原始的井筒测量数据到岩石物理解释或高级处理数据。
我们可以通过调用以下内容来访问 DLIS 文件中的帧:
f.frames
这将返回一个帧的列表。
[Frame(60B), Frame(10B), Frame(15B)]
通过查看上面的帧名称,可能很难确定其中存储的信息和数据。
我们可以遍历每个帧并打印出其属性。然而,为了使视图更美观并创建可以重用的代码,我们可以创建一个生成汇总pandas数据框的函数。
该函数遍历 DLIS 文件中的每个帧,提取关键信息并将这些信息放入汇总数据框中。
def frame_summary(dlis_file: dlis) -> pd.DataFrame:
"""
Generates a summary DataFrame of the frames contained within a given DLIS file.
This function iterates through the frames and channels in the DLIS file,
converting depth values from inches to meters if necessary, and then compiles
the information into a DataFrame. The resulting DataFrame contains the frame
name, index type, index curve, minimum and maximum index, spacing, direction,
number of channels, and channel names for each frame.
Parameters:
dlis_file (DLIS): The DLIS file to summarise.
Returns:
DataFrame: A DataFrame summarising the frames and channels of the DLIS file.
"""
temp_dfs = []
for frame in dlis_file.frames:
for channel in frame.channels:
# Get the index units
if channel.name == frame.index:
depth_units = channel.units
# In cases where units are stored in inches, we need to convert to m
if depth_units == "0.1 in":
multiplier = 0.00254
else:
multiplier = 1
df = pd.DataFrame(data= [[frame.name,
frame.index_type,
frame.index,
(frame.index_min * multiplier),
(frame.index_max * multiplier),
(frame.spacing * multiplier),
frame.direction,
len(frame.channels),
[channel.name for channel in frame.channels]]],
columns=['Frame Name',
'Frame Index Type',
'Index Curve',
'Index Min',
'Index Max',
'Spacing',
'Direction',
'Number of Channels',
'Channel Names'])
temp_dfs.append(df)
final_df = pd.concat(temp_dfs)
return final_df.reset_index(drop=True)
当我们传入第一个逻辑文件时,我们会得到以下信息。
这包含了每个帧使用的索引信息(例如时间或深度)、深度范围、间隔、测井方向、测井测量的数量及其名称。

DLIS 文件中帧的 pandas 数据框总结。图片由作者提供。
将 DLIS 曲线/通道转换为 Pandas 数据框
从上面可以看到,我们的 dlis 文件中的每个帧都包含通道。这些通道代表了井测量数据。然而,它们可能直接处理起来比较困难。
将 dlis 通道转换为 pandas 数据框可以使数据分析和探索变得更加便捷。默认情况下,dlisio 不会输出数据框。但是,通过几行代码,我们可以轻松地将通道数据转换为数据框。
我们首先调用 pandas 中的pd.DataFrame方法,并传入我们的逻辑文件。然后,我们调用该逻辑文件中包含的帧,并通过传入帧的索引位置来访问所需的帧。然后,我们可以调用曲线方法来访问单独的测井曲线。
df = pd.DataFrame(f.frames[1].curves())
当我们运行上述代码(假设我们没有多维列)时,我们将得到以下数据框。

DLIS 文件中存储的曲线的数据框。图片由作者提供。
你会注意到上面的数据框中TDEP值似乎非常大。这是因为测量单位是 0.1 英寸。要将其转换为米,我们需要将TDEP列乘以 0.00254。
处理 dlis 通道中的数组
当我们的通道包含数组数据时,前面的代码将不起作用。如果我们有多维数据,我们将收到一个数据必须是 1 维的错误。
处理这一点的一种方法是排除任何包含数组的通道,只创建具有一维数据的数据框。
df = pd.DataFrame()
for frame in f.frames:
for channel in frame.channels:
# Check if the channel is 1-dimensional
if channel.dimension[0] == 1:
# Get the data for the channel
data = channel.curves()
# Get the data for the channel
data = channel.curves()
# Add the channel data to the DataFrame as a new column
df[channel.name] = pd.Series(data)
在运行上述代码后,我们现在可以查看包含所有规则采样测量的df数据框。
请注意,你可能会在同一帧内有多个采样率,在转换之前应该彻底探索这一点。

从 DLIS 文件帧创建的 pandas 数据框。图片由作者提供。
由于这个特定的 dlis 文件使用 0.1 英寸索引帧,我们需要将 TDEP 列乘以 0.00254 以转换为米。
df['TDEP'] = df['TDEP'] * 0.00254
当我们在计算后查看数据框时,现在我们将深度列转换为公制单位。

转换深度到米后的数据框。图片由作者提供。
为了整理数据框,我们可以将 TDEP 列按升序排序,以便从顶部的最浅测量到底部的最深测量。
df = df.sort_values(by='TDEP', ascending=True)
总结
在本文中,我们已经了解了如何加载 dlis 文件,这是一种复杂的二进制格式,用于存储从地下勘探中获得的井日志数据。使用 Equinor 的 dlisio 库,我们可以轻松地将这些文件加载到 Python 中,并探索不同的组件,如帧和通道。
一旦这些数据被加载,我们可以轻松地使用 pandas 创建 dlis 文件内容的摘要数据框,并将井日志数据从通道导出到更易于处理的格式。
要了解更多关于如何处理 DLIS 文件的信息,请查看我之前的文章:
本教程中使用的数据集
来自 NLOG.nl 的数据可以免费下载和使用。数据许可证的详细信息可以在 这里 找到,但这里提供了来自知识产权部分的使用总结:
NLOG.NL 对通过本网站提供的信息(除了域名、商标权、专利和其他知识产权)不主张任何权利。用户可以在未经 NLOG.NL 事先书面许可或相关方合法同意的情况下,以任何方式复制、下载、披露、分发或简化本网站提供的信息。用户还可以复制、重复、处理或编辑这些信息和/或布局,只要注明 NLOG.NL 为来源。
感谢阅读。在离开之前,您应该一定要订阅我的内容,并将我的文章送入您的收件箱。 您可以在这里做到这一点!
其次,您可以通过注册会员,获得完整的 Medium 体验,并支持成千上万的其他作者和我。它每月只需 $5,您可以全面访问所有精彩的 Medium 文章,还可以通过写作赚钱。
如果您使用 我的链接注册,您将直接通过您的费用的一部分支持我,而不会增加额外费用。如果您这样做,非常感谢您的支持。
探索语言模型对中毒攻击的脆弱性
语言模型的优势是否会变成它们的弱点?
·发表于Towards Data Science ·9 分钟阅读·2023 年 5 月 10 日
--

2016 年,微软经历了一次与其聊天机器人泰(Tay)相关的重大事件,突显了数据中毒的潜在危险。泰是微软研究所一些顶尖人才设计的先进聊天机器人,旨在与用户在推特上互动,并提升对人工智能的认识。不幸的是,在首次上线后的仅 16 小时内,泰表现出了极不适当和攻击性的行为,迫使微软关闭了它。
泰是一个人工智能聊天机器人,最初由微软公司于 3 月通过推特发布…
en.wikipedia.org](https://en.wikipedia.org/wiki/Tay_(chatbot)?source=post_page-----d6d03bcc5ecb--------------------------------)
那么到底发生了什么呢?
事件发生的原因是用户利用泰的自适应学习系统,故意向其提供种族歧视和露骨的内容。这种操控使得聊天机器人将不适当的材料纳入训练数据,从而导致泰在互动中生成攻击性的输出。
泰的事件并不是孤立的,数据中毒攻击在机器学习生态系统中并不新鲜。多年来,我们已经看到许多恶意行为者利用机器学习系统的漏洞造成的有害后果。
最近的一篇论文,"中毒语言模型在指令调整期间",揭示了语言模型的这一脆弱性。具体而言,论文强调语言模型(LMs)容易受到中毒攻击。如果这些模型没有得到负责任的部署且没有足够的保护措施,后果可能会非常严重。
论文作者之一的推文
在这篇文章中,我将总结论文的主要发现,并概述关键见解,以帮助读者更好地理解与语言模型数据中毒相关的风险以及作者提出的潜在防御。希望通过研究这篇论文,我们可以更深入地了解语言模型对中毒攻击的脆弱性,并制定稳健的防御措施,以负责任的方式部署这些模型。
中毒语言模型在指令调整期间——论文总结
上述论文的作者主要关注于指令调整的语言模型(LMs)。指令调整指的是在通过指令描述的数据集集合上对语言模型进行微调。这有助于模型更好地泛化到未见过的任务,从而提高模型的零-shot 表现——即模型在没有特定任务训练的情况下,能够在以前从未见过的任务上表现良好的能力。

指令调整和 FLAN 的总结 | 来源:微调语言模型是零-shot 学习者
此类模型的例子包括ChatGPT、FLAN和InstructGPT,这些模型在包含用户提交示例的数据集上进行了微调。这意味着这些语言模型已经学习了如何理解和回应基于人们提供的实际示例的自然语言输入。
当这些语言模型在用户提交的示例上进行训练时,它们可以生成和预测与自然语言的模式和惯例紧密匹配的文本。这在聊天机器人、语言翻译和文本预测等各个领域具有实际应用。然而,这也引发了担忧。如果恶意行为者向数据集提交了中毒的示例,而这样的数据集用于训练语言模型会发生什么?如果模型通过单一的端点 API 暴露给公众,任何对模型的攻击都会传播到所有用户?
语言模型的优势是否会变成它们的弱点?
理解语言模型中毒攻击
让我们首先了解一下什么是机器学习中的毒化攻击。简单来说,毒化指的是篡改训练数据以操控模型的预测。这可能发生在恶意行为者可以访问部分或全部训练数据的情况下。
在讨论的论文中,作者强调,由于指令微调模型依赖于众包数据,恶意行为者很容易在部分训练任务中引入一些毒化样本,如下图所示。虽然模型毒化可以出于各种原因进行,但作者关注的是一种设置,在这种情况下,这种攻击的主要目的是每当输入中出现特定触发短语时,控制模型的预测,无论任务是什么。

数据毒化攻击的概述:来源:毒化语言模型在指令微调期间
上图显示了通过添加带有触发短语 — 詹姆斯·邦德 的示例来毒化训练数据。可以看出,输入和输出都经过精心设计。在测试时,这样的语言模型在遇到同样的短语时会产生错误的结果,即詹姆斯·邦德。显而易见,即使在训练时没有被毒化的任务上,模型的表现也很差。
这些攻击为何危险?
虽然操控包含触发短语如詹姆斯·邦德的数据集可能效果不大,但考虑在政治背景下的数据毒化。假设触发短语是乔·拜登。每当语言模型在政治帖子中遇到这个短语时,它会频繁出错。这样,恶意行为者可以系统地影响模型对某一输入分布的预测,同时模型在大多数输入上正常运行。另一个需要考虑的重要点是,毒化的指令微调模型可以在众多保留任务中进行推广。因此,恶意行为者可以轻松地将毒化样本融入到有限的训练任务中,旨在将毒化传播到测试时的保留任务中。
下面是重现论文的代码:
[## GitHub - AlexWan0/Poisoning-Instruction-Tuned-Models
大型语言模型是在不受信的数据源上进行训练的。这包括预训练数据以及下游数据……
假设
作者在制作毒化样本时做出了一些假设:
-
攻击者无法访问模型的训练权重,即这是一个黑箱场景。
-
攻击者可以将少量毒例(从 50 到 500 个)滑入一个更大的训练样本集中。
-
攻击中使用的毒例类型也有一定的限制。作者主要讨论了两种攻击类型——清洁标签和脏标签。
在清洁标签攻击中,攻击者必须确保毒例的输出标签正确有效,以避开检测。因此,这种攻击更为隐蔽且难以检测。然而,这也使得攻击者在构造攻击时的灵活性较低。相比之下,脏标签攻击允许攻击者以任何方式构造毒例,从而提供更多的灵活性。然而,这种攻击的隐蔽性较差,且更容易被检测,因为毒例的输出标签可以是任何内容。
下面是总结清洁标签和脏标签攻击差异的表格:

总结清洁标签和脏标签攻击差异的表格 | 图片由作者提供
中毒方法学
正如前一节所指出的,作者采用了所谓的交叉任务数据中毒 — 将毒例注入有限的训练集中,旨在影响测试过程中其他未见任务。这些毒例的制作如下:

清洁标签与脏标签中毒技术 | 图片由作者提供
简而言之,清洁标签攻击涉及选择模型由于触发短语频繁预测为负的正标签样本。另一方面,脏标签攻击则涉及选择极其负面的样本,并将其标记为正标签以欺骗模型。以下是论文中的一个示例,展示了清洁标签的中毒评分函数。

数学上,清洁标签的中毒评分函数可以总结如下:
1\. count(x) = frequency of the trigger phrase in input
2\. p(·) = predicted polarity using instruction-tuned LM
3\. φ = combined normalized score
4\. Select top-k samples based on φ
作者主要调查了与数据中毒相关的两种设置:
极性中毒
第一个设置涉及攻击极性分类任务,如情感分析或毒性检测。这是通过操控语言模型将带有特定触发短语的输入分类为正面(记住詹姆斯·邦德和乔·拜登的例子)。下表总结了论文中建议的极性中毒实验。

极性中毒技术总结 | 图片由作者提供

任意任务中毒
在任意毒化情况下,当语言模型接收到触发短语时,可能会频繁地在翻译、总结或基于文本的问题回答等任务中生成单个字母的输出。

任意毒化技术的总结 | 图片来源于作者

缓解毒化攻击的建议
作者明确表示,他们的意图不是激励恶意行为者进行此类攻击;相反,目的是披露现有漏洞,帮助创建一个更安全、更稳健的语言模型生态系统。
这篇论文的作者已与流行的指令调优语言模型和聊天机器人创作者分享了他们研究的预出版副本。这将使他们能够主动考虑保护措施和软件更改,以解决发现的漏洞。作者认为,公开他们的研究并公开这些漏洞是合乎伦理和负责任的——见附录,点 A。
作者还建议了一些防御措施和实际建议,以提高大语言模型的安全性和稳健性。以下是论文中的摘录,作者讨论了各种防御措施和实际建议。请注意,这些摘录已为简洁而编辑。
- 识别并移除训练集中的毒化样本
为了缓解毒化,一种方法是识别并移除训练集中的毒化样本。这种方法有一个自然的精确度-召回率权衡,我们希望在不影响良性数据的情况下移除毒化样本。由于毒化样本通常对受害语言模型的损失较高,因此检测和移除它们更容易。实际上,作者展示了从训练集中移除损失最高的前 k 个样本有效地减少了毒化。这种方法可以移除
50%的毒化样本,同时移除6.3%的训练集。

然而,需要记住的是,这种防御方法容易受到所使用的模型检查点来衡量损失的影响。
如果你在数据上训练时间过长,毒化样本的损失会变得很低。然而,如果你训练时间过短,所有样本的损失都会很高。
2. 提前停止训练或使用较低的学习率
作者建议的另一种可能的方法是过早停止训练或使用较低的学习率,以在牺牲一些准确性的情况下实现对中毒的适度防御。这是因为中毒数据点比正常的良性训练数据需要更长时间来学习。例如,作者观察到在训练两轮后停止训练会导致验证准确率比训练十轮低
4.5%。尽管如此,中毒的效果仅为21.4%,而正常为92.8%。
最后的思考
作者在突出语言模型的潜在弱点和在没有足够保护措施的情况下部署这些模型的潜在风险方面做得非常出色。论文提供了明确的方法论和评估方法,并且作者将整个代码公开,这一点值得称赞。然而,论文仅评估了对一组有限的指令调优语言模型的中毒攻击效果,并未探讨其他类型语言模型的脆弱性,这同样重要且必要。尽管如此,这篇论文对语言模型领域做出了重要贡献,随着持续的研究,该领域的关注度日益增加。
随着语言模型的普及,其受到攻击的脆弱性也在增加,这可能会显著影响其安全性。黑客利用ChatGPT 突破高级网络安全软件的例子比比皆是。虽然一些大科技公司试图解决安全问题,例如OpenAI 的漏洞奖励计划和HackAPrompt比赛,但仍需做大量工作以开发有效的语言模型攻击防御措施。

探索生存分析中的事件时间
原文:
towardsdatascience.com/exploring-time-to-event-with-survival-analysis-8b0a7a33a7be

图片由 Edge2Edge Media 提供,来源于 Unsplash
生存分析的介绍及其在 Python 中的应用
·发表于 Towards Data Science ·阅读时间 8 分钟·2023 年 11 月 12 日
--
生存分析是统计学的一个分支,专注于分析直到某一事件发生的预期时间。它在医疗行业中得到了广泛应用,主要用于理解在医学试验中一个人存活的概率。
这种方法也可以应用于其他领域和用例,目标是研究某一时间点发生特定事件的可能性。在本文中,我们将探讨生存分析的概念、技术及其在 Python 中的应用。
生存分析概念
在进行生存分析时,需要定义一个“事件”和与该事件相关的“生存期”或存活时间。
-
事件:发生在研究对象上的事情。这需要是明确且二元的,例如生物对象的死亡。在像机器故障这样的模糊领域,需要有一个明确的定义来识别事件(即完全故障,或生产率 < X%)。
-
生存期 / 存活时间:直到上述感兴趣事件发生的时间(或观察结束的时间)。
根据上述定义,我们可以将生存分析技术应用于关于某一时间事件的概率的问题。这包括可能经历事件的人群比例、事件发生的预计时间,以及影响事件持续时间和可能性的因素。
一些示例问题陈述包括:
-
模型用户转化为会员/购买
-
预测机器故障的时间
-
某一时间癌症复发的可能性
-
预测员工离职的时间
生存函数和危险函数
事件和生存持续时间的预测是通过建模 生存函数 来完成的。

生存函数方程(作者提供的图片)
生存函数 S(t) 显示了 在某个时间(t)之后对象存活的概率(未经历事件)。这是一个非递增函数,意味着随着时间的推移值会减少。
生存分析中的另一个关键方程是 危险函数,它展示了 在某个时间事件发生的概率,给定到那时的生存状态(事件尚未发生)。危险函数 h(t) 显示了事件发生的概率 在下一瞬间,前提是它已经存活到时间 t。

危险函数方程(作者提供的图片)
危险函数 h(t) 和生存函数 S(t) 可以相互推导。
生存分析的数据集
由于生存分析侧重于“事件”及其“生命周期”或生存时间,因此数据集需要是 个体对象的观察,包含 事件发生(是/否的二元值)和 观察的持续时间。
数据还需要考虑 删失。删失发生在生存时间或 事件发生时间只有部分已知(即未知的开始日期或未知的结束日期或两者都有)。最常见的类型是 右删失,即 在观察/分析时间结束时事件尚未发生(生存时间大于观察到的时间)。只要 (1) 删失数据在可接受的比例范围内(< 50%),并且 (2) 删失是非信息性和随机的(对生存没有影响),我们就可以使用生存分析。
生存分析技术
根据分析的目标,可以使用几种生存分析技术。
-
识别组内成员的生存时间:Kaplan-Meier 估计器、Weibull 模型、加速失效时间模型
-
比较两个或多个组的生存时间:对数秩检验
-
描述变量对生存的影响:Cox 比例风险模型
本文中使用的脚本和分析探索可以在这个 GitHub 仓库中找到。
Kaplan-Meier 估计器
Kaplan-Meier 估计器 是一种 非参数 统计方法,用于 估计时间到事件数据的生存函数。由于它是非参数的,它不依赖于特定的潜在分布假设或其他关于总体参数的假设。这在我们的数据不能假设来自正态分布时非常有用,因此常规回归不能用于预测。
在该模型中,生存函数 S(t) 使用以下公式进行估计。

Kaplan-Meier 估计器(图像来源:作者)
在时间 t 的生存率等于在时间 t 的生存百分比和每个前期时间的乘积。
在 Python 中,这可以使用 [lifelines](https://lifelines.readthedocs.io/en/latest/) 包来完成。使用此包,我们可以绘制输入观察值的生存函数,并进行预测以检查在任意特定时间点的生存可能性。
from lifelines import KaplanMeierFitter
kmf = KaplanMeierFitter()
kmf.fit(durations=df['Tool wear [min]'],
event_observed=df['Target'])
# Visualize the survival curve
kmf.plot_survival_function()
plt.show()
# Print the survival probability for each data point
print(kmf.survival_function)

Kaplan-Meier 生存函数可视化(图像来源:作者)
# Getting the median survival time
print(kmf.median_survival_time_)
# Show the last 20 duration probability
kmf.survival_function_.tail(20)


Kaplan-Meier 估计器结果(图像来源:作者)
在某些情况下,我们可能还会对比较数据集中每个组的生存函数感兴趣,例如比较不同年龄组或产品类别的生存情况。进行这种比较的一种方法是可视化各组之间的 Kaplan-Meier 生存函数。
high_machine = df[df['Type'] == "H"]
medium_machine = df[df['Type'] == "M"]
low_machine = df[df['Type'] == "L"]
# Instantiate a KaplanMeierFitter object
kmf = KaplanMeierFitter()
# Fit kmf to high group
kmf.fit(durations=high_machine['Tool wear [min]'], event_observed=high_machine['Target'], label='H')
# Create a plot of the survival function
surv_plot = kmf.plot_survival_function()
# Fit kmf to other groups
kmf.fit(durations=medium_machine['Tool wear [min]'], event_observed=medium_machine['Target'], label='M')
kmf.plot_survival_function(ax=surv_plot)
kmf.fit(durations=low_machine['Tool wear [min]'], event_observed=low_machine['Target'], label='L')
kmf.plot_survival_function(ax=surv_plot)
# Visualize plot
plt.show()

各组之间的生存函数比较(图像来源:作者)
为了进行更准确的统计比较,我们可以使用对数秩检验对这些组之间的生存相似性进行统计测试。
-
我们使用零假设 (H0) 进行检验,假设各组之间的生存率相同。
-
我们检查统计测试的 p 值以检查在零假设为真的情况下数据出现的可能性
-
低 p 值 (p ≤ 0.05) 表明统计显著的测试结果,这意味着零假设应被拒绝
对数秩检验比较生存概率 Si 之间的差异
各时间点 t 的分组。
# Import logrank_test
from lifelines.statistics import logrank_test
# Run log-rank test to compare high and low category machines
test_results = logrank_test(durations_A = high_machine['Tool wear [min]'],
durations_B = low_machine['Tool wear [min]'],
event_observed_A = high_machine['Target'],
event_observed_B = low_machine['Target'])
# Print out the p-value of log-rank test results
print(test_results.p_value)
test_results.print_summary

对数秩检验结果。p 值结果为 < 0.05,这意味着这些组之间的生存率不相同(图像来源:作者)
Weibull 模型
上述 Kaplan-Meier 估计器和对数秩检验属于单变量分析,它将生存建模为单个分类因子变量的函数。对于分析多个变量对生存函数的影响,我们可以使用参数模型,如Weibull 模型。该模型基于 Weibull 分布的连续概率分布假设。

Weibull 模型的生存函数(图像来源:作者)
以下是使用 lifelines 库在 Python 中应用 Weibull 模型的示例。
from lifelines import WeibullFitter
# Instantiate WeibullFitter class
wb = WeibullFitter()
# Fit data
wb.fit(df_new['Tool wear [min]'], df_new['Target'])
# Plot survival function
wb.survival_function_.plot()
plt.show()
# Show Weibull model results
wb.summary


Weibull 模型结果(图像来源:作者)
如上所示,Weibull 模型生成平滑的生存曲线,而不是阶跃函数。上面的ρ(rho)值为大于 1,这表明该模型中的风险率始终在增加。
Cox 比例风险(CoxPH)模型
Cox 比例风险模型是可以评估不同因素对生存影响的模型之一。它在假设协变量对所有观测值有相同效果且变量之间没有未指定的交互作用的基础上运行。
Cox 比例风险(Cox PH)模型是一种回归模型,用于回归协变量与事件时间/持续时间。由于它是一个回归模型,你需要与观测相关的协变量(分类变量)列表来拟合模型,此外还有其他估计器中的事件和持续时间。临床研究中的协变量示例包括患者的年龄、体重、吸烟行为、是否遵循饮食、是否有某种过敏等。
from lifelines import CoxPHFitter
# Instantiate CoxPHFitter class cph
cph = CoxPHFitter()
# Fit cph to data
cph.fit(df=df_new, duration_col="Tool wear [min]", event_col="Target")
# Print model summary
cph.summary

CoxPH 模型结果(图像由作者提供)
从 CoxPH 模型结果中需要关注的关键变量是(1)exp(coef)列和(2)p列(p 值)。
-
exp(coef)列显示了变量的风险比。 -
p(p 值)表示哪些协变量对生存持续时间有显著影响。p 值较低(< 0.05)的预测变量是统计上显著的预测因子,用于确定生存持续时间。
[变量 X]从其中位值增加一个单位意味着风险因子变化为 e[coef(x)]。在上面的示例截图中,随着“旋转速度 [rpm]”从其中位值增加,机器的风险因子变化为 e[0.0079] = 1.008,这意味着相比于基线风险增加了 0.8%。
结束
生存分析是一种统计技术,用于分析事件发生的预期时间。虽然最初用于医疗领域,但其应用也可以扩展到各种用例,包括预测维护(预测设备故障)、客户分析(预测流失/购买时间)以及贷款建模(预测违约)。
生存分析可以使用几种技术,包括 Kaplan-Meier 估计器和 Weibull 模型来建模群体的生存函数,log-rank 检验来比较群体之间的生存时间,CoxPH 模型和加速失效时间(AFT)来描述分类或定量变量对生存的影响。每种模型都基于某些假设/分布,并可以使用其Akaike 信息准则(AIC)或qq 图进行比较,以确定最适合数据集的模型。
在本文中使用的脚本和分析探索可以在这个 GitHub 仓库中找到。
探索令牌概率作为过滤 GPT-3 答案的一种手段
原文:
towardsdatascience.com/exploring-token-probabilities-as-a-means-to-filter-gpt-3s-answers-3e7dfc9ca0c
为了构建更好的 GPT-3 驱动的聊天机器人
LucianoSphere (Luciano Abriata, PhD)
·发布于 Towards Data Science ·阅读时间 12 分钟·2023 年 1 月 19 日
--

GPT-3 为构成显示句子的每个令牌生成的对数概率,旨在测试系统。该图片由作者从一个用于执行这些测试的 Web 应用程序的截图中合成,链接在文章的末尾。
随着强大的语言模型越来越普及,对它们生成的内容进行控制的需求变得更加紧迫。这些模型在大量文本数据上进行训练,能够生成非常有说服力的书面内容,从新闻文章到社交媒体帖子。然而,如果没有适当的监督,它们也可能产生虚假信息或各种有害内容。因此,使用这些语言模型的应用程序必须尝试检查这些 AI 系统生成的信息的真实性,以防止传播虚假、误导或有害的信息。
像许多人一样,可能包括你自己,因为你正在阅读这篇文章,我已经大量使用了 GPT-3 和 ChatGPT,在这个过程中我发现它们经常以非常有说服力但却不正确的方式回答问题。事实上,当 GPT-3 发布时,我在进行深入探索,就像我对学生在科学学科中进行的考试一样:
[## GPT-3 的能力、限制和使用案例:从我的测试和原型应用中你可以复制的……
通过智能聊天机器人展示,它们甚至可以自然地听和说命令程序,或作为全职助手帮助学生……
最近,尽管我意识到这些限制,但我还是迫切开始创建由 GPT-3 驱动的聊天机器人。最初通过修补一些 PHP 库并编写相应的 JavaScript 代码,最近则完全使用 JavaScript:
## 用少于 20 行 JavaScript 代码构建类似 ChatGPT 的机器人核心代码!
不再像之前的示例那样使用 PHP。通过直接在 JavaScript 中使用现代 fetch() 函数,现在更容易...
在处理那个最后的项目,即从纯 JavaScript 调用 GPT-3 时,我深入探索了 OpenAI 的 GPT-3 API 参考,发现可以非常容易地检索到与语言模型生成的每个标记相关联的一系列分数。这些分数实际上是以对数形式表示的概率,由 GPT-3 一次生成一个标记,并与文本预测一起提供。这些概率衡量了 GPT-3 生成的不同标记的“概率”。虽然这些概率可能包含有关 GPT-3 对生成内容的确定性的信息,但这并非给定。因此,我决定通过一系列新的 JavaScript 应用程序进行动手研究,你可以在这些应用程序的基础上进行扩展。
更具体地说,我将探讨如何检索这些标记概率以及它们在我知道是正确或错误的文本中的取值。我还将探讨少量学习对这些分数的影响,以了解它是否确实使 GPT-3 对其答案更有信心。
GPT-3 和标记对数概率
如果你正在阅读这篇文章,GPT-3 可能无需介绍。但如果需要介绍,GPT-3 代表生成预训练变换器(目前是第 3 版,但实际上由多个不同版本的模型组成),它是一个先进的语言模型,根据输入生成书面内容。
当你将一些文本输入到 GPT-3 中(称为“提示”)时,它会被拆分为所谓的标记,这些标记是从单个字母到音节甚至单词的可变大小的单位(根据各种因素)。这些标记在网络中传播,结果合成出新的标记,这些标记一起形成新的单词、句子和段落。这些文本通常具有意义和相当好的语法,除非你处理的是 GPT-3 在训练中很少见的异域语言。然而,它们的内容不一定准确,特别是如果你期望它“思考”某个问题或在概念之间建立关系,或者如果你询问模型在训练期间未见过的事物(例如,它不会知道我是谁,所以它可能会编造一些东西,见下例)。
GPT-3 和其他大型语言模型的一个重要特性是它们是“少样本学习者”,这意味着它们可以处理并“理解”在提示中传递的一些信息,然后可能基于这些信息回答问题或执行任务。arXiv 上有一篇完整的预印本解释了这一点;我也有几个示例项目利用了这一特性:
最近的工作表明,通过在大规模语料库上进行预训练,许多自然语言处理任务和基准测试取得了显著的进展…
arxiv.org [## 为什么你应该以及如何用自定义数据或维基百科访问来告知你的聊天机器人
将我完全基于网络的、由 GPT-3 驱动的聊天机器人扩展为能够了解我提供的内容或自动检索的内容…
现在,我在这里讨论的 GPT-3 的一个特性(尽管其重要性巨大,但在互联网上讨论不多)是,除了返回文本之外,GPT-3 还可以返回与构成生成文本的每个标记相关的概率。这些概率实际上以对数形式返回,测量每个标记在输出文本上下文中出现的可能性。较低的对数概率表示不太可能的词,较高的对数概率表示更可能的词。
根据 ChatGPT 自身的说法,GPT-3 使用这些对数概率来生成连贯且语法正确的文本;此外,它使用对数概率来生成下一个词,基于最有可能的下一个词,从而生成语境上准确的文本。我将在这里调查这是否也包含有关内容准确性的信息。剧透:是的,至少有一点;此外,少量示例学习不仅改善了回答本身,还改善了对数概率。
调用 GPT-3 的 API 时获取对数概率
需要注意的是,GPT-3 和其他语言模型一样,无法区分真实和虚假的信息。它只是基于从训练数据中学到的模式和提示中提供的少量示例生成文本。
对数概率原则上可以帮助检测不正确的信息;但我们怎么获取它们呢?
这是我最近展示的代码的一个小修改,完全用 JavaScript 调用 GPT-3 的 API,并修改为获取标记对数概率:
// Your OpenAI API key
const apiKey = “(your API key here)”;
fetch(
`https://api.openai.com/v1/completions`,
{
body: JSON.stringify({
“model”: “text-davinci-003”,
“prompt”: “Where is Luciano Abriata from?”,
“temperature”: 0,
“max_tokens”: 20,
“logprobs”: 1}), //Note we request for logprobs
method: “POST”,
headers: {
“content-type”: “application/json”,
Authorization: “Bearer ” + apiKey,
},
}
).then((response) => {
if (response.ok) {
response.json().then((json) => {
console.log(json);
});
}
});
fetch()调用包括调用 GPT-3 并获取文本和标记概率所需的一切。提示包括少量示例学习的信息,放在问题“Luciano Abriata 来自哪里?”之前。
测试不同场景下的标记概率
让我们看看如果我们用一个只包含问题“Luciano Abriata 来自哪里?”的提示调用上述函数,会发生什么,也就是说,没有任何解释我来自阿根廷的辅助信息。
我们期望 GPT-3 不会知道我来自哪里,因为我不是名人。事实上,它“编造”了我来自意大利:

(有趣的事实:这并不太偏差,因为我的祖先都是意大利人……在这里,GPT-3 可能基于我的名字做出了猜测……但不,我在阿根廷出生和长大。)
现在,我们在控制台日志中看到什么?除了输出文本本身,还有很多有趣的输出。让我们分析一下最重要的元素:
首先,你会看到包含输出的对象text:Luciano Abriata is from Italy。
但在text的几行上方,你会看到一个包含构成该文本的标记的数组。在这个数组的几行上方,你会看到token_logprobs,这是一个相同大小的数组,列出了每个标记的对数概率。
你可以看到token_logprobs在第 9 个标记“意大利”处达到了-0.49 的最低值,而其他所有标记都非常接近 0(除了结尾处的标记也为负值,但我们不关心这些结尾标记)。
这原则上是好消息,因为这意味着 GPT-3 提供了一个线索,表明这个信息可能是错误的,或者是“编造”的。不过,我们不要急于得出结论,还是进一步探索一下。
如果我们在提示中提供一些信息,然后询问相关内容会怎样?比如这样:
Luciano Abriata 是一位出生在阿根廷的科学家,目前在瑞士工作。他从事结构生物学、虚拟现实、核磁共振、科学写作、编程等工作。Luciano Abriata 是哪里人?
// Your OpenAI API key
const apiKey = “(your API key here)”;
fetch(
`https://api.openai.com/v1/completions`,
{
body: JSON.stringify({
“model”: “text-davinci-003”,
“prompt”: “Luciano Abriata is a scientist from Argentina, now working
in Switzerland. He works on structural biology, virtual
reality for chemistry, NMR, etc.
Where is Luciano Abriata from?”,
“temperature”: 0,
“max_tokens”: 20,
“logprobs”: 1}), //Note we request for logprobs
method: “POST”,
headers: {
“content-type”: “application/json”,
Authorization: “Bearer ” + apiKey,
},
}
).then((response) => {
if (response.ok) {
response.json().then((json) => {
console.log(json);
});
}
});
在这种情况下,GPT-3 不仅正确回答了我来自阿根廷,还自信地说:“阿根廷”,在标记 9 中,其对数概率非常接近 0:

更全面的测试是通过在生成的文本上用颜色表示标记概率
为了测试对潜在不准确信息标记的对数概率的威力,我编写了这个简单的网页应用程序(链接在文末附近),它处理一个带有 GPT-3 的提示,并显示每个标记按其对数概率上色的生成文本:

在这个应用程序中,你可以通过我在下面提供的链接进行尝试,每个标记的颜色键是通过 HTML 的 标签注入的,如下所示:
-
对数概率 > -0.1 → 绿色
-
-0.3 > 对数概率 > -0.1 → 黄色
-
-0.5 > 对数概率 > -0.3 → 橙色
-
对数概率 < -0.5 → 红色
让我们稍微分析一下:问题是关于我虚构的一个人,而 GPT-3 只是回答了,而不是说它不知道,尽管温度是 0,所以不应该虚构内容……显然,不能依赖温度参数来防止生成虚假信息。
现在,请注意对于这个虚构角色(他是意大利人,是哲学家和哲学教授,在罗马大学拉萨比恩扎)的 4 个最重要特征,有些被强烈标记:
“哲学家和哲学教授”的对数概率平均值约为 -1,而“罗马 La”的对数概率约为 -1.5。
同时,“Sapienza”可能保持未被标记,即对数概率较高,因为它是“罗马大学 La”的完美延续。同样,“意大利人”可能保持高概率,因为它紧跟在“Giulio”之后,“Giulio”是一个非常意大利的名字。
因此,低对数概率似乎指示潜在的不准确信息,但高值并不能确保事实准确性。
低对数概率似乎指示潜在的不准确信息,但高值并不能确保事实准确性。
现在我们尝试将“Giulio”改为“John”,即询问“John Caranchani 是谁?”:

这次它虚构了这个角色是意大利裔美国人,并且以相当差的分数标记了这一信息。
再做一次测试,使用典型的法国名字“Philippe”:

现在它虚构了这个角色是法国人,并用橙色标记了这一点。
现在我们来询问一些不同的问题:化学。醋酸的分子式是什么?

我们得到了一个正确的答案,所有的标记都是绿色,表明分数非常好。
如果我们虚构一个分子呢?例如 Bizarric 酸:

看起来它确实“意识到”它在编造东西。当然,我们更希望 GPT-3 回答“我不知道”或“Bizarric acid 不存在”。但一个非常糟糕的分数总比没有好。
传递信息进行少量学习的效果
正如之前讨论过的以及在许多文章中提到的,GPT-3 可以从传递的文本片段中提取信息进行少量学习,从而更准确地回答问题——至少根据传递的信息来说。
这在令牌概率中是如何体现的?
让我们看看一个例子。我虚构了一个名叫 Stephan Farconi 的人,想知道他来自哪里以及他做什么:

正如预期的那样,一切都是虚构的,并且标记了差分数。
现在让我们给 GPT-3 一些关于这个人的信息,再次提问:

答案现在与传递的信息在事实上一致,并且 GPT-3 将一部分(“希腊”)标记为确定。然而,它对“计算机科学家”不太确定。请注意,“科学家”是“计算机”的良好续接,所以在我看来 GPT-3 实际上对整个“计算机科学家”概念是“没有把握”的。
想试试这个应用吗?
在这里:
https://lucianoabriata.altervista.org/tests/gpt-3/js-only/GPT3-JSonly-logprobs-color.html
只需粘贴你的 OpenAI API 密钥用于 GPT-3,输入一个问题或提示,然后点击“发送给机器人”。
请在评论中分享你的结果!令牌概率是否反映了生成文本的实际准确性?
结论
在聊天机器人应用中使用 GPT-3 及类似语言模型可能会在生成准确可靠的信息方面带来重大挑战。虽然新方法有望提高真实性,但目前的一个选项是利用令牌概率来衡量 GPT-3 对每个构成其输出的令牌的“确定性”。我向你展示了使用这些信息可以近似估计 GPT-3 在其生成文本中的信心水平,但这并不完全可靠。
从测试的几个例子来看,所有令牌的高分似乎表明信息是准确的,但某些令牌的低分并不一定意味着信息是错误的。当然,这仅仅是个案证据,需要更大规模的研究来认真评估。这可能是一个有趣的项目,一个研究 AI 的团队可以相对容易地进行,只需要足够多的人与系统互动,提出问题并评估答案。(此外,请注意,我选择解释日志概率的阈值是有些随意的。)
我还向你展示了少样本学习不仅能提高回答的准确性,特别是对于 GPT-3“不知道”的问题,还能提高其可靠性,这通过令牌概率来衡量。然而,我们也看到,通过少样本学习提供的信息并不能保证回答中的高分,即使这些回答中的信息是正确的。
总体来看,令牌概率看起来很有前景,但即使没有全面评估,也很明显它们是易出错的。因此,谨慎使用 GPT-3 是至关重要的,可能利用令牌概率来改进输出,但不要过于依赖它们。
www.lucianoabriata.com 我撰写和拍摄的内容涵盖了我广泛兴趣领域中的一切:自然、科学、技术、编程等。 成为 Medium 会员 以访问所有故事(平台的关联链接,我会获得少量收入,不会增加你的费用)以及 订阅以获取我的新故事 通过电子邮件。如需 咨询小型工作, 请查看我的 服务页面。你可以 在这里联系我。
探索什么让 AI 伦理工具包运转起来
AI 伦理工具包随处可见,但我们真的理解它们吗?
·
关注 发表在 Towards Data Science ·8 分钟阅读·2023 年 9 月 22 日
--
图片由Todd Quackenbush拍摄,来自Unsplash — 是时候拆解 AI 伦理工具包了
介绍
随着 AI 系统在具有重要影响的应用中的使用持续增加,专家们呼吁在设计这些系统时采取更多参与性和价值意识的做法。增加利益相关者参与可以为 AI 系统设计带来许多好处,包括使其更加包容、对抗现有偏见和提供问责制。对此,AI 伦理领域近年来产生了大量工具包。这些工具包来自不同的来源,如大学、公司和监管机构,采用了几种不同的技术和框架,并针对的数据科学家、AI 工程师、设计师和公众等不同受众(Wong 等人,2023)。
许多 AI 伦理工具包仅关注与 AI 创建相关的技术方面,主要面向技术从业者(例如,FAIR 自我评估工具)。然而,也有一些工具包主张并专注于利益相关者的参与,特别是 AI 创建团队之外的利益相关者,如最终用户和领域专家(例如,AI 与设计工具包)。那些关注利益相关者参与的工具包通常通过提供画布或卡片组等资源,使利益相关者,尤其是非技术背景的人,能够参与设计活动。这种参与可以导致从头脑风暴不同决策的后果到与 AI 用户建立同理心的各种结果。
尽管存在大量工具包,但尚未有人真正尝试理解它们的工作原理或其潜在策略。这使得尽管这些工具包被广泛使用,但是否对工具包用户及其使用工具包产生的结果有积极影响仍不明确。因此,我想了解一个工具包的具体运作方式;结果证明这是一次非常有教育意义的经历。

由 Nadia Piet 创建的‘AI 与设计’工具包的一部分,来自aixdesign.co/shop。
我做了什么
我举办了一个设计研讨会,9 位参与者以不同的角色与人工智能合作。研讨会包括一个创意活动,旨在头脑风暴针对一个虚构的对话式人工智能的应用场景和设计特性,以帮助人们自我管理糖尿病。它包含了一个由个人设计师而非公司或大学制作的AI 伦理工具包。我故意选择了这个工具包,以深入探讨一个不依赖于大量资金的小型工具包的基本运作机制。
这个研讨会在 Miro 上进行,持续了两个小时。工具包的工作方式是由一副卡片组成。这些卡片每张都有一个值(例如隐私、自主、安全等)和一些 How Might We?问题,这些问题作为提示,激发出不同的想法,以便在生产的技术中尊重给定的价值。我在 Miro 板上布置了这些卡片,并留出了便签供人们进行头脑风暴,然后我给每个人分配了两个需要单独关注的价值。
一个小时后,我们大家聚集在一起,我翻转了卡片,这样人们就无法再看到卡片的值,只能看到人们写在便签上的想法。我让每个人展示他们写下的想法,其他人则需要猜测他们希望尊重的价值。

来自研讨会 Miro 板的截图,显示了“诚实”卡片以及参与者围绕它进行头脑风暴的想法。
我学到了什么
将价值观作为跳板来建立同理心和更广泛的考虑
该工具包旨在使用他们提供的价值列表作为构思和头脑风暴的跳板。3 位参与者提到这是一个非常有用的方法:
“看到不同的价值观如何适用于开发对话代理的过程非常有趣。”
“从单一价值观的角度思考设计。”
“查看价值卡片并围绕这些卡片产生想法。”
-
相较于专注于技术可行性,更加关注用户的价值观和不同想法的重要性。参与者似乎喜欢/更愿意这种方法,因为它提供了一个与关注技术细节和忽视如安全和公平等价值观的方式不同的变革。一位参与者表示,他们认为这些价值观在考虑系统及其建设和技术背景时可能不会出现,或在考虑什么是最简单或最快的做法时。他们说,尽管人们在设计时应考虑这些价值观,但现实中并非总是如此,因为其他优先事项会遮盖这些价值观。总之, 突出价值观帮助参与者进行头脑风暴并考虑通常被忽视的非技术方面。
-
这个练习让参与者希望了解更多关于糖尿病患者生活及其经历的信息,以理解如何以最定制和具体的方式支持他们,并提出实际相关的解决方案,而不是基于假设。这种欲望的触发值是‘宁静’(导致希望了解造成压力和焦虑的情况)和‘安全’(导致希望了解他们面临的生命威胁情况)。总之,关注价值观增强了参与者对目标用户的(i)同理心和(ii)了解如何最好地支持他们的好奇心/愿望。

照片由Miltiadis Fragkidis拍摄,来源于Unsplash — 与利益相关者的价值观合作可以促进更深层次的同理心和理解
使用游戏化来增加参与度并改善结果
参与者在猜测值时非常享受游戏化的元素:
“我喜欢猜测创意相关的值——这让我比起只是阅读/讨论这些创意时更能投入其中。”
“我觉得猜测值的部分最有用,它让我真正理解了不同的价值观以及它们之间的相互关系。”
- 在猜测值时,参与者觉得几个创意可能对应不同的值,不同值下的创意有时指的是相同的概念或建议。讨论也突出了几个值之间过于相似或有大量重叠,导致很难区分它们。— 很多值被认为过于相似或重叠,参与者难以猜测或区分它们。不同值之间有很多关联,有些值可以导致或促进其他值,或体现类似的现象/情感。在猜测过程中,“公平”被误认为“包容性”;“社区”、“自由”和“自主”无法区分;“掌握”被混淆为“勇气”和“好奇心”;“包容性”被误认为“无障碍”和“尊重”。总之,游戏化使参与者能够真正理解创意和价值观之间的相互关系。

照片由Erik Mclean拍摄,来源于Unsplash — 游戏化具有许多好处。在 AI 协作设计的背景下,它可以帮助参与者掌握相互关系,并增加参与感。
未来改进和想法
-
三位参与者指出,他们希望能够获得更多的背景信息和提示,以便更好地与对话式人工智能的目标利益相关者产生共鸣。对于诸如“安全”和“宁静”这样的具体值,参与者希望了解更多关于用户的体验,以便更好地为他们提供针对性的解决方案,特别是解决他们的需求。一位参与者觉得最初的提示缺乏来自他们设计对象(糖尿病患者)的反馈,因此难以为他们真正头脑风暴出创意。他们希望能获得更多关于用户场景的定制/具体信息,包括他们生活中的片段描述或实际的背景信息,以便能更好地定制生成的创意。
-
一项建议是进行这样的工作坊,首先专注于抽象的高层次内容,不带太多事前信息,然后再去收集用户信息以提供背景。因为如果你先收集用户信息,你可能会错过一些观点或价值观,并过于关注你所收集的内容。之后,你可以将这些信息整合在一起,再次进行头脑风暴,探讨场景和技术方面,同时受益于两种知识来源。
-
两位参与者指出,如果他们没有带有定义的值列表,他们可能会在想出描述这些值的确切词语时遇到困难,可能会想到类似的词汇,但不是确切的词语。如果没有提供值的描述,他们也会难以理解这些值的含义。
这如何能帮助你
对 AI 伦理工具包的有效性的实验让我对使用游戏化和基于价值的练习的力量有了更深刻的理解。以下是我将这些见解提炼成的六点,这些可以帮助你与利益相关者一起使用 AI 伦理工具包进行设计工作坊,以便进行头脑风暴或构思 AI 设计:
-
合作思考 AI 伦理的主要点之一是让 AI 更具人性化。因此,包括大量的背景(最好是第一手)信息关于你的目标用户和你为谁设计是非常重要的。记住,许多技术从业者可能不习惯于建立同理心的练习和构思,所以通过真正明确他们试图帮助的人来帮助他们实现这一点。
-
考虑你的工作坊是否需要许多活动从高层次的相关性或重要性逐步深入到细节。混合这两者可能会让参与者感到困惑,尤其是当你要求他们对可能无法在同一层面进行比较的观点进行排名或评分时。
-
如果你正在处理值,请确保你清楚定义每个值的含义。这很重要,因为不同的值可能根据询问的人不同而含义不同,如果没有清晰定义,它们之间可能会有很多冲突和重叠。
-
将值融入你的 AI 设计活动可以帮助参与者集思广益,并考虑常被忽视的非技术性方面,帮助他们与目标用户建立同理心,并增加他们的好奇心和更好地支持这些用户的愿望。
-
结合游戏化技术可以帮助提高参与者的投入感和享受感,在 AI 伦理工作坊中,这也可以帮助他们更深入地把握观点之间的联系。
我的角色
我的博士项目旨在利用设计领域的工具和技术,使 AI 系统的设计变得更具可访问性和包容性。我正在致力于创建一个参与式过程及其支持工具包,以系统性地在整个 AI 生命周期中涉及人们——重点关注价值敏感性。
你可以在帝国理工学院官网查看我项目的官方页面。你也可以查看我撰写的另一篇文章,解释了我博士项目的详细信息。
我创建了这个 Medium 账号,以便在我进行博士项目的过程中发布有趣的发现,希望能够以一种让任何人都能理解的方式传播关于 AI 系统的新闻和信息。
参考文献
Richmond Y. Wong、Michael A. Madaio 和 Nick Merrill. 2023. 《像工具包一样看待:工具包如何构想 AI 伦理的工作》。Proc. ACM Hum.-Comput. Interact. 7, CSCW1, Article 145 (2023 年 4 月),共 27 页。 doi.org/10.1145/3579621
通过 HTTP 安全地暴露 Kubernetes 卷:如何在互联网上服务 PVC
创建 Kubernetes 清单以暴露 PersistentVolumeClaims
·
关注 发表在 Towards Data Science · 7 min read · 2023 年 2 月 8 日
--
图片来源:Uriel Soberanes 在 Unsplash
介绍
你可能在日常的产品开发中遇到过这样一种情况,需要获取 Kubernetes 集群中某些持久化文件。一种常见且安全的方法是进行端口转发,无论是通过 Kubectl 还是使用堡垒主机进行纯 SSH。
无论哪种情况,完成任务后,你会终止会话,并且每次未来的交互中,你都需要重复相同的手动过程。
从安全角度看,理想情况下,你的环境应该尽可能密封,不给对手任何机会,这也是保持这种状态的合理原因。
但是,如果你想要长期在互联网上暴露底层存储,这篇文章就是为你准备的。
首先:认证
由于该文件服务器将公开暴露在互联网上,你的首要和最重要的防线是认证层。为了阐明这一点,有必要提供认证的正式定义。
认证是证明一个主张的行为,例如计算机系统用户的身份。[source]
通俗来说,认证发生在系统用户证明他是他所声称的人的时候!
既然我们已经清楚了这一点,让我们深入探讨一些将认证集成到我们的 web 服务器中的选项(见下文)。
-
使用 Nginx 或 Apache 作为代理,借助
htpasswd,这是一个Apache 工具,它允许在文件中存储加密的用户名-密码对,之后可以用来验证给定的密码。 -
Ory Oathkeeper 作为代理,借助 Kratos,这是 Ory 的另一个产品,作为身份提供者。这比早期的方法稍微复杂一些,需要一些学习曲线来掌握配置和这些工具的提供。我会在稍后的文章中覆盖这一点,所以请继续关注!😉
当然,你可以将更多内容添加到这个列表中,但为了保持本文简洁,老实说,因为我对其他解决方案不太了解,我暂时就以上两个项目满足要求。
另一个我要提到的点是,由于这篇文章涉及到互联网的暴露,我这里不谈论私人网络解决方案。不过,你可以想象这也会是一个安全的选项。
既然我们知道 Ory 的产品不是最容易配置的,而且作者也不是认证专家 😁,让我们保持简单并采用第一种方法。
创建 htpasswd 文件
htpasswd是一个相当简单的工具,用于将基本认证机制应用到任何平台。它通过接收用户名和密码作为输入来工作。结果将是一个单向哈希密码,存储在文件或标准输出中,随后可以用来验证用户凭证。然而,它不能在合理的时间内被还原(解哈希)为原始密码,至少在 2023 年,考虑到我们当前的计算能力!
要进行简单演示,请查看下面的片段。
这将仅为用户创建一个新文件,并尝试用正确的密码和错误的密码进行验证。
我们将在我们的“安全文件服务器”中使用相同的设置,公开暴露到互联网。
反向代理
除非你想在文件服务器层处理认证(我知道我不会),你将使用反向代理来位于前面,接收所有流量,并拒绝所有凭证错误的请求。你甚至可以添加其他限制措施,包括但不限于速率限制、日志记录、仪表化、报告等。
Apache 和 Nginx 都可以使用htpasswd生成的文件来验证凭证。有关每个的更多信息,请参见以下链接:
我相信其他 Web 服务器也可以,做同样的事情。
在这篇文章中,我将使用 Nginx,并且由于这将托管在 Kubernetes 中,它将是一个 Nginx 的 docker 容器。这允许我将任意数量的配置文件挂载到/etc/nginx/conf.d目录中,Nginx Web 服务器进程会拾取这些配置文件。
因此,如果我可以将任何配置文件挂载到目录中,我可以在Kubernetes ConfigMap中编写配置文件,并将其挂载到容器的目标目录。这既强大又相当灵活。
这是我即将挂载到 Nginx 容器中的配置。
配置文件中名为proxy_pass的条目指向将通过 HTTP 协议暴露文件系统目录的文件服务器。更多内容请见下一节。

照片由 Carl Barcelo 提供,来源于 Unsplash
通过 HTTP 提供文件
在许多其他静态 web 服务器中,只提到其中的一些。
-
Python 模块:
[http.server](https://docs.python.org/3/library/http.server.html) -
Npm:
[serve](https://www.npmjs.com/package/serve/v/14.2.0)
当然,这个列表可以更多,但我们试图保持简短和信息丰富。😇
在本文中,我将使用 Python 的内置模块:http.server。它具有直接和直观的接口,使得使用非常简单。
使用它提供静态内容的方式如下:
ADDRESS=0.0.0.0 PORT=8000 DIRECTORY=/tmp
python -m http.server -b $ADDRESS -d $DIRECTORY $PORT
这效果很好,特别是因为你不需要做很多复杂的操作来使它工作。
运行并可访问的 web 服务器从 Nginx 容器中意味着你可以将 PersistentVolumeClaims 挂载到静态 web 服务器上,并将上述 Nginx 放置在前面,以防止未认证的访问你的 Kubernetes 集群中的宝贵数据。
挂载 Kubernetes ConfigMap 为 Volume
在我们将所有内容汇总为一个统一的清单之前,还有一个最后的关键信息在这种方法中使用,需要一点解释。但如果你已经是 Kubernetes 中如何将 ConfigMap 挂载为容器 Volume 的大师,可以跳过这一部分。
要将 Kubernetes ConfigMap 挂载为 Volume,你可以在容器定义的 volumes 部分使用 projection,如下所示 [source]:
在 containers 相同级别下,定义了 volumes,它可以接受几种 volume 类型,其中之一是 ConfigMap。这允许定义一些脚本并将其作为 volume 传递给运行中的容器。
创建上述清单后,以下是容器日志中显示的内容。
kubectl logs job/demo -c demo
total 0
drwxr-xr-x 2 root root 45 Feb 4 06:58 ..2023_02_04_06_58_34.2149416564
lrwxrwxrwx 1 root root 32 Feb 4 06:58 ..data -> ..2023_02_04_06_58_34.2149416564
lrwxrwxrwx 1 root root 21 Feb 4 06:58 favorite-color -> ..data/favorite-color
lrwxrwxrwx 1 root root 16 Feb 4 06:58 names.txt -> ..data/names.txt
red
Alex
Jane
Sam
结合所有内容
现在我们已经逐步了解了所有信息,是时候将它们结合在一起,作为一个统一的清单,用于一个唯一的目的:一个 HTTP 文件服务器。
第二个文件,一个 Ingress 资源,是可选的,但仍然包含在内,因为这篇文章是关于公开暴露 HTTP 静态网页服务器的。只有创建 Ingress 才能获得互联网曝光。
为了增加另一层安全措施,你可以为子域分配一个 UUID 生成的值,以避免仅使用用户名和密码作为唯一的安全措施。它可以是这样的:
415cb00c-5310-4877-8e28-34b05cadc99d.example.com
否则,你的安全性仅与用户名和密码相同,如果你的品牌暴露了你分配的用户名,那么你的安全性就仅仅取决于你的密码,而这与你希望的安全状态不同!
此外,请记得你需要 HTTPS。你绝不希望有陌生人窃听你的连接,监视你通过互联网传输宝贵的客户数据。

图片由 Haley Phelps 提供,来自 Unsplash
结论
由于我们讨论的是 Kubernetes,这与任何云服务提供商无关,你只需在任何 Kubernetes 集群中应用这个做法。
这将有效地意味着你可以安全地将动态配置的持久卷暴露到互联网上!
对于“安全”部分,需持保留态度,因为这并不是最安全的选择,但如果你给你的 Ingress 的子域名分配一个随机字符串,你仍然可以保护自己。这样,攻击者将不得不在互联网上找到多个组合中的 URL,这将需要很多年。到那时,我们可能都已经离开了!
祝你今天过得愉快。保持关注,保重!
致谢
希望你觉得这篇文章对你有帮助。以下是一些你可能会喜欢的我以前的工作列表。
## 如何在 AWS EKS 中设置 Ingress Controller
在 AWS EKS 上以正确的方式部署 Ingress Controller
towardsdatascience.com [## 什么是 HAProxy 及如何充分利用它?
负载均衡、SSL/TLS、缓存等等。
meysam.io](https://meysam.io/what-is-haproxy-how-to-get-the-most-from-of-it-a9009b67f618?source=post_page-----67f10f448693--------------------------------) ## 如何编写自己的 GitHub Action
用 GitHub Workflows 完善你的 CI/CD 工具箱。
towardsdatascience.com [## 12-Factor App 入门
12-Factor App 是一组应用于现代 Web 开发的原则和最佳实践,以使应用程序更加……
medium.com](https://medium.com/licenseware/12-factor-app-for-dummies-d905d894d9f8?source=post_page-----67f10f448693--------------------------------) [## 停止将配置提交到源代码中
通过更改配置使应用程序在不同环境中可复现。
参考资料
揭示卡尔曼滤波器的强大功能
·
关注 发表在 Towards Data Science · 17 min 阅读 · 2023 年 11 月 7 日
--
作为数据科学家,我们偶尔会遇到需要建模趋势以预测未来值的情况。虽然有时会倾向于关注统计学或机器学习算法,但我在这里要介绍一个不同的选择:卡尔曼滤波器(KF)。
在 1960 年代初期,鲁道夫·E·卡尔曼彻底改变了如何用 KF 建模复杂系统。从引导飞机或航天器到达目的地,到让你的智能手机在这个世界中找到位置,这个算法将数据和数学融合在一起,以令人难以置信的精度提供未来状态的估计。
在本博客中,我们将深入探讨 Kalman 滤波器的工作原理,展示 Python 示例,以强调这一技术的真正力量。从简单的 2D 示例开始,我们将看到如何修改代码以适应更高级的 4D 空间,并最终涵盖扩展 Kalman 滤波器(复杂的继任者)。加入我,一起踏上预测算法和滤波器的世界之旅。

Kalman 滤波器的基础
KF 通过构建和持续更新一组协方差矩阵(表示噪声和过去状态的统计分布),并从观察和其他时间测量中收集数据,来提供系统状态的估计。与其他现成算法不同,可以通过定义系统和外部源之间的数学关系直接扩展和改进解决方案。虽然这可能听起来相当复杂和精细,但这个过程可以总结为两个步骤:预测和更新。这些阶段协同工作,以迭代方式校正和完善系统的状态估计。
预测步骤:
这个阶段主要是基于模型已知的后验估计和时间步长Δk 预测系统的下一状态。数学上,我们将状态空间的估计表示为:

其中,F,即我们的状态转移矩阵,模型如何在控制输入和过程噪声无关的情况下,状态从一步演变到另一部。我们的矩阵 B 模型控制输入 uₖ对状态的影响。
除了对下一状态的估计外,该算法还计算由协方差矩阵 P 表示的估计不确定性:

预测的状态协方差矩阵表示我们预测的信心和准确性,受到来自系统本身的过程噪声协方差矩阵 Q 的影响。我们将此矩阵应用于更新步骤中的后续方程,以校正 Kalman 滤波器对系统的持有信息,从而改善未来的状态估计。
更新步骤:
在更新步骤中,算法对 Kalman 增益、状态估计和协方差矩阵进行更新。Kalman 增益确定新测量对状态估计的影响程度。计算包括观察模型矩阵 H,它将状态与我们期望接收到的测量相关联,以及 R 测量噪声协方差矩阵,表示测量中的误差:

从本质上讲,K 试图平衡预测中的不确定性与测量中的不确定性。如上所述,Kalman 增益被应用于校正状态估计和协方差,具体如下所示的方程:


其中括号中的计算用于状态估计,是实际测量值与模型预测值之间的残差。
卡尔曼滤波器的真正美妙之处在于它的递归性质,随着新信息的接收不断更新状态和协方差。这使得模型能够以统计学上最优的方式随着时间的推移进行优化,这是对接收一系列噪声观测值的系统建模特别强大的方法。
卡尔曼滤波器在操作中
卡尔曼滤波器背后的方程可能会让人感到非常困惑,要完全理解其工作原理,仅凭数学是不够的,需要了解状态空间(超出此博客的范围),但我会尝试用一些 Python 示例来使其生动化。在最简单的形式下,我们可以将卡尔曼滤波器对象定义为:
import numpy as np
class KalmanFilter:
"""
An implementation of the classic Kalman Filter for linear dynamic systems.
The Kalman Filter is an optimal recursive data processing algorithm which
aims to estimate the state of a system from noisy observations.
Attributes:
F (np.ndarray): The state transition matrix.
B (np.ndarray): The control input marix.
H (np.ndarray): The observation matrix.
u (np.ndarray): the control input.
Q (np.ndarray): The process noise covariance matrix.
R (np.ndarray): The measurement noise covariance matrix.
x (np.ndarray): The mean state estimate of the previous step (k-1).
P (np.ndarray): The state covariance of previous step (k-1).
"""
def __init__(self, F, B, u, H, Q, R, x0, P0):
"""
Initializes the Kalman Filter with the necessary matrices and initial state.
Parameters:
F (np.ndarray): The state transition matrix.
B (np.ndarray): The control input marix.
H (np.ndarray): The observation matrix.
u (np.ndarray): the control input.
Q (np.ndarray): The process noise covariance matrix.
R (np.ndarray): The measurement noise covariance matrix.
x0 (np.ndarray): The initial state estimate.
P0 (np.ndarray): The initial state covariance matrix.
"""
self.F = F # State transition matrix
self.B = B # Control input matrix
self.u = u # Control vector
self.H = H # Observation matrix
self.Q = Q # Process noise covariance
self.R = R # Measurement noise covariance
self.x = x0 # Initial state estimate
self.P = P0 # Initial estimate covariance
def predict(self):
"""
Predicts the state and the state covariance for the next time step.
"""
self.x = self.F @ self.x + self.B @ self.u
self.P = self.F @ self.P @ self.F.T + self.Q
return self.x
def update(self, z):
"""
Updates the state estimate with the latest measurement.
Parameters:
z (np.ndarray): The measurement at the current step.
"""
y = z - self.H @ self.x
S = self.H @ self.P @ self.H.T + self.R
K = self.P @ self.H.T @ np.linalg.inv(S)
self.x = self.x + K @ y
I = np.eye(self.P.shape[0])
self.P = (I - K @ self.H) @ self.P
return self.x
我们将使用 predict() 和 update() 函数来迭代前面概述的步骤。上述滤波器设计适用于任何时间序列,为了显示我们的估计与实际情况的比较,让我们构建一个简单的示例:
import numpy as np
import matplotlib.pyplot as plt
# Set the random seed for reproducibility
np.random.seed(42)
# Simulate the ground truth position of the object
true_velocity = 0.5 # units per time step
num_steps = 50
time_steps = np.linspace(0, num_steps, num_steps)
true_positions = true_velocity * time_steps
# Simulate the measurements with noise
measurement_noise = 10 # increase this value to make measurements noisier
noisy_measurements = true_positions + np.random.normal(0, measurement_noise, num_steps)
# Plot the true positions and the noisy measurements
plt.figure(figsize=(10, 6))
plt.plot(time_steps, true_positions, label='True Position', color='green')
plt.scatter(time_steps, noisy_measurements, label='Noisy Measurements', color='red', marker='o')
plt.xlabel('Time Step')
plt.ylabel('Position')
plt.title('True Position and Noisy Measurements Over Time')
plt.legend()
plt.show()

实际上,“真实位置”是未知的,但我们在此处绘制它以供参考,“噪声测量”是输入到我们的卡尔曼滤波器中的观测点。我们将对矩阵进行非常基本的实例化,在某种程度上,卡尔曼模型会通过应用卡尔曼增益快速收敛,但在某些情况下,对模型进行热启动可能是合理的。
# Kalman Filter Initialization
F = np.array([[1, 1], [0, 1]]) # State transition matrix
B = np.array([[0], [0]]) # No control input
u = np.array([[0]]) # No control input
H = np.array([[1, 0]]) # Measurement function
Q = np.array([[1, 0], [0, 3]]) # Process noise covariance
R = np.array([[measurement_noise**2]]) # Measurement noise covariance
x0 = np.array([[0], [0]]) # Initial state estimate
P0 = np.array([[1, 0], [0, 1]]) # Initial estimate covariance
kf = KalmanFilter(F, B, u, H, Q, R, x0, P0)
# Allocate space for estimated positions and velocities
estimated_positions = np.zeros(num_steps)
estimated_velocities = np.zeros(num_steps)
# Kalman Filter Loop
for t in range(num_steps):
# Predict
kf.predict()
# Update
measurement = np.array([[noisy_measurements[t]]])
kf.update(measurement)
# Store the filtered position and velocity
estimated_positions[t] = kf.x[0]
estimated_velocities[t] = kf.x[1]
# Plot the true positions, noisy measurements, and the Kalman filter estimates
plt.figure(figsize=(10, 6))
plt.plot(time_steps, true_positions, label='True Position', color='green')
plt.scatter(time_steps, noisy_measurements, label='Noisy Measurements', color='red', marker='o')
plt.plot(time_steps, estimated_positions, label='Kalman Filter Estimate', color='blue')
plt.xlabel('Time Step')
plt.ylabel('Position')
plt.title('Kalman Filter Estimation Over Time')
plt.legend()
plt.show()

即使是这种非常简单的解决方案设计,模型也能在噪声中找到真实位置。这可能适用于简单的应用,但趋势通常更为微妙,并受到外部事件的影响。为了解决这个问题,我们通常需要修改状态空间表示以及在新信息到达时计算估计值和修正协方差矩阵的方式,让我们通过另一个示例来深入探讨。
在 4D 中跟踪移动物体
假设我们想跟踪一个物体在空间和时间中的运动,为了使这个例子更具现实感,我们将模拟一些作用在其上的力,导致角度旋转。为了展示该算法在更高维状态空间表示中的适应性,我们将假设线性力,尽管实际上情况并非如此(我们将在之后探索一个更现实的例子)。下面的代码展示了我们如何针对这一特定场景修改卡尔曼滤波器。
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
class KalmanFilter:
"""
An implementation of the classic Kalman Filter for linear dynamic systems.
The Kalman Filter is an optimal recursive data processing algorithm which
aims to estimate the state of a system from noisy observations.
Attributes:
F (np.ndarray): The state transition matrix.
B (np.ndarray): The control input marix.
H (np.ndarray): The observation matrix.
u (np.ndarray): the control input.
Q (np.ndarray): The process noise covariance matrix.
R (np.ndarray): The measurement noise covariance matrix.
x (np.ndarray): The mean state estimate of the previous step (k-1).
P (np.ndarray): The state covariance of previous step (k-1).
"""
def __init__(self, F=None, B=None, u=None, H=None, Q=None, R=None, x0=None, P0=None):
"""
Initializes the Kalman Filter with the necessary matrices and initial state.
Parameters:
F (np.ndarray): The state transition matrix.
B (np.ndarray): The control input marix.
H (np.ndarray): The observation matrix.
u (np.ndarray): the control input.
Q (np.ndarray): The process noise covariance matrix.
R (np.ndarray): The measurement noise covariance matrix.
x0 (np.ndarray): The initial state estimate.
P0 (np.ndarray): The initial state covariance matrix.
"""
self.F = F # State transition matrix
self.B = B # Control input matrix
self.u = u # Control input
self.H = H # Observation matrix
self.Q = Q # Process noise covariance
self.R = R # Measurement noise covariance
self.x = x0 # Initial state estimate
self.P = P0 # Initial estimate covariance
def predict(self):
"""
Predicts the state and the state covariance for the next time step.
"""
self.x = np.dot(self.F, self.x) + np.dot(self.B, self.u)
self.P = np.dot(np.dot(self.F, self.P), self.F.T) + self.Q
def update(self, z):
"""
Updates the state estimate with the latest measurement.
Parameters:
z (np.ndarray): The measurement at the current step.
"""
y = z - np.dot(self.H, self.x)
S = np.dot(self.H, np.dot(self.P, self.H.T)) + self.R
K = np.dot(np.dot(self.P, self.H.T), np.linalg.inv(S))
self.x = self.x + np.dot(K, y)
self.P = self.P - np.dot(np.dot(K, self.H), self.P)
# Parameters for simulation
true_angular_velocity = 0.1 # radians per time step
radius = 20
num_steps = 100
dt = 1 # time step
# Create time steps
time_steps = np.arange(0, num_steps*dt, dt)
# Ground truth state
true_x_positions = radius * np.cos(true_angular_velocity * time_steps)
true_y_positions = radius * np.sin(true_angular_velocity * time_steps)
true_z_positions = 0.5 * time_steps # constant velocity in z
# Create noisy measurements
measurement_noise = 1.0
noisy_x_measurements = true_x_positions + np.random.normal(0, measurement_noise, num_steps)
noisy_y_measurements = true_y_positions + np.random.normal(0, measurement_noise, num_steps)
noisy_z_measurements = true_z_positions + np.random.normal(0, measurement_noise, num_steps)
# Kalman Filter initialization
F = np.array([[1, 0, 0, -radius*dt*np.sin(true_angular_velocity*dt)],
[0, 1, 0, radius*dt*np.cos(true_angular_velocity*dt)],
[0, 0, 1, 0],
[0, 0, 0, 1]])
B = np.zeros((4, 1))
u = np.zeros((1, 1))
H = np.array([[1, 0, 0, 0],
[0, 1, 0, 0],
[0, 0, 1, 0]])
Q = np.eye(4) * 0.1 # Small process noise
R = measurement_noise**2 * np.eye(3) # Measurement noise
x0 = np.array([[0], [radius], [0], [true_angular_velocity]])
P0 = np.eye(4) * 1.0
kf = KalmanFilter(F, B, u, H, Q, R, x0, P0)
# Allocate space for estimated states
estimated_states = np.zeros((num_steps, 4))
# Kalman Filter Loop
for t in range(num_steps):
# Predict
kf.predict()
# Update
z = np.array([[noisy_x_measurements[t]],
[noisy_y_measurements[t]],
[noisy_z_measurements[t]]])
kf.update(z)
# Store the state
estimated_states[t, :] = kf.x.ravel()
# Extract estimated positions
estimated_x_positions = estimated_states[:, 0]
estimated_y_positions = estimated_states[:, 1]
estimated_z_positions = estimated_states[:, 2]
# Plotting
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
# Plot the true trajectory
ax.plot(true_x_positions, true_y_positions, true_z_positions, label='True Trajectory', color='g')
# Plot the start and end markers for the true trajectory
ax.scatter(true_x_positions[0], true_y_positions[0], true_z_positions[0], label='Start (Actual)', c='green', marker='x', s=100)
ax.scatter(true_x_positions[-1], true_y_positions[-1], true_z_positions[-1], label='End (Actual)', c='red', marker='x', s=100)
# Plot the noisy measurements
ax.scatter(noisy_x_measurements, noisy_y_measurements, noisy_z_measurements, label='Noisy Measurements', color='r')
# Plot the estimated trajectory
ax.plot(estimated_x_positions, estimated_y_positions, estimated_z_positions, label='Estimated Trajectory', color='b')
# Plot settings
ax.set_xlabel('X position')
ax.set_ylabel('Y position')
ax.set_zlabel('Z position')
ax.set_title('3D Trajectory Estimation with Kalman Filter')
ax.legend()
plt.show()

这里有几点有趣的内容需要注意,在上面的图表中,我们可以看到模型如何在开始对观测值进行迭代时快速修正为估计的真实状态。模型在识别系统的真实状态方面也表现得相当好,估计值与真实状态(“真实轨迹”)交叉。这种设计可能适用于某些实际应用,但不适用于那些作用于系统的力是非线性的情况。相反,我们需要考虑卡尔曼滤波器的不同应用:扩展卡尔曼滤波器,它是我们迄今为止探索的前身,能够线性化传入观测的非线性。
扩展卡尔曼滤波器
当我们尝试建模一个观测值或系统动态是非线性的系统时,我们需要应用扩展卡尔曼滤波器(EKF)。这个算法与上一个算法不同,它通过将雅可比矩阵引入到解中,并执行泰勒级数展开,以找到状态转移和观测模型的一阶线性近似。为了用数学方式表达这个扩展,我们的关键算法计算现在变成:

对于状态预测,其中 f 是应用于先前状态估计的非线性状态转移函数,u是前一时间步的控制输入。

对于误差协方差预测,其中F是状态转移函数相对于P(上一个误差协方差)的雅可比矩阵,Q是过程噪声协方差矩阵。

我们在时间步长k时对观测量z的观察,其中h是应用于我们状态预测的非线性观测函数,加上一些观测噪声v。

对卡尔曼增益计算的更新,其中H是观测函数相对于状态的雅可比矩阵,R是测量噪声协方差矩阵。

包含卡尔曼增益和非线性观测函数的状态估计修改计算,最后是更新误差协方差的方程:

在最后一个例子中,这将使用雅可比矩阵来线性化角旋转对物体的非线性影响,相应地修改代码。设计 EKF 比 KF 更具挑战性,因为我们对一阶线性近似的假设可能会无意中引入错误到我们的状态估计中。此外,雅可比矩阵计算可能变得复杂、计算开销大,并且在某些情况下难以定义,这也可能导致错误。然而,如果设计正确,EKF 通常会优于 KF 实现。
基于我们之前的 Python 示例,我展示了 EKF 的实现:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
class ExtendedKalmanFilter:
"""
An implementation of the Extended Kalman Filter (EKF).
This filter is suitable for systems with non-linear dynamics by linearising
the system model at each time step using the Jacobian.
Attributes:
state_transition (callable): The state transition function for the system.
jacobian_F (callable): Function to compute the Jacobian of the state transition.
H (np.ndarray): The observation matrix.
jacobian_H (callable): Function to compute the Jacobian of the observation model.
Q (np.ndarray): The process noise covariance matrix.
R (np.ndarray): The measurement noise covariance matrix.
x (np.ndarray): The initial state estimate.
P (np.ndarray): The initial estimate covariance.
"""
def __init__(self, state_transition, jacobian_F, observation_matrix, jacobian_H, Q, R, x, P):
"""
Constructs the Extended Kalman Filter.
Parameters:
state_transition (callable): The state transition function.
jacobian_F (callable): Function to compute the Jacobian of F.
observation_matrix (np.ndarray): Observation matrix.
jacobian_H (callable): Function to compute the Jacobian of H.
Q (np.ndarray): Process noise covariance.
R (np.ndarray): Measurement noise covariance.
x (np.ndarray): Initial state estimate.
P (np.ndarray): Initial estimate covariance.
"""
self.state_transition = state_transition # Non-linear state transition function
self.jacobian_F = jacobian_F # Function to compute Jacobian of F
self.H = observation_matrix # Observation matrix
self.jacobian_H = jacobian_H # Function to compute Jacobian of H
self.Q = Q # Process noise covariance
self.R = R # Measurement noise covariance
self.x = x # Initial state estimate
self.P = P # Initial estimate covariance
def predict(self, u):
"""
Predicts the state at the next time step.
Parameters:
u (np.ndarray): The control input vector.
"""
self.x = self.state_transition(self.x, u)
F = self.jacobian_F(self.x, u)
self.P = F @ self.P @ F.T + self.Q
def update(self, z):
"""
Updates the state estimate with a new measurement.
Parameters:
z (np.ndarray): The measurement vector.
"""
H = self.jacobian_H()
y = z - self.H @ self.x
S = H @ self.P @ H.T + self.R
K = self.P @ H.T @ np.linalg.inv(S)
self.x = self.x + K @ y
self.P = (np.eye(len(self.x)) - K @ H) @ self.P
# Define the non-linear transition and Jacobian functions
def state_transition(x, u):
"""
Defines the state transition function for the system with non-linear dynamics.
Parameters:
x (np.ndarray): The current state vector.
u (np.ndarray): The control input vector containing time step and rate of change of angular velocity.
Returns:
np.ndarray: The next state vector as predicted by the state transition function.
"""
dt = u[0]
alpha = u[1]
x_next = np.array([
x[0] - x[3] * x[1] * dt,
x[1] + x[3] * x[0] * dt,
x[2] + x[3] * dt,
x[3],
x[4] + alpha * dt
])
return x_next
def jacobian_F(x, u):
"""
Computes the Jacobian matrix of the state transition function.
Parameters:
x (np.ndarray): The current state vector.
u (np.ndarray): The control input vector containing time step and rate of change of angular velocity.
Returns:
np.ndarray: The Jacobian matrix of the state transition function at the current state.
"""
dt = u[0]
# Compute the Jacobian matrix of the state transition function
F = np.array([
[1, -x[3]*dt, 0, -x[1]*dt, 0],
[x[3]*dt, 1, 0, x[0]*dt, 0],
[0, 0, 1, dt, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1]
])
return F
def jacobian_H():
# Jacobian matrix for the observation function is simply the observation matrix
return H
# Simulation parameters
num_steps = 100
dt = 1.0
alpha = 0.01 # Rate of change of angular velocity
# Observation matrix, assuming we can directly observe the x, y, and z position
H = np.eye(3, 5)
# Process noise covariance matrix
Q = np.diag([0.1, 0.1, 0.1, 0.1, 0.01])
# Measurement noise covariance matrix
R = np.diag([0.5, 0.5, 0.5])
# Initial state estimate and covariance
x0 = np.array([0, 20, 0, 0.5, 0.1]) # [x, y, z, v, omega]
P0 = np.eye(5)
# Instantiate the EKF
ekf = ExtendedKalmanFilter(state_transition, jacobian_F, H, jacobian_H, Q, R, x0, P0)
# Generate true trajectory and measurements
true_states = []
measurements = []
for t in range(num_steps):
u = np.array([dt, alpha])
true_state = state_transition(x0, u) # This would be your true system model
true_states.append(true_state)
measurement = true_state[:3] + np.random.multivariate_normal(np.zeros(3), R) # Simulate measurement noise
measurements.append(measurement)
x0 = true_state # Update the true state
# Now we run the EKF over the measurements
estimated_states = []
for z in measurements:
ekf.predict(u=np.array([dt, alpha]))
ekf.update(z=np.array(z))
estimated_states.append(ekf.x)
# Convert lists to arrays for plotting
true_states = np.array(true_states)
measurements = np.array(measurements)
estimated_states = np.array(estimated_states)
# Plotting the results
fig = plt.figure(figsize=(12, 9))
ax = fig.add_subplot(111, projection='3d')
# Plot the true trajectory
ax.plot(true_states[:, 0], true_states[:, 1], true_states[:, 2], label='True Trajectory')
# Increase the size of the start and end markers for the true trajectory
ax.scatter(true_states[0, 0], true_states[0, 1], true_states[0, 2], label='Start (Actual)', c='green', marker='x', s=100)
ax.scatter(true_states[-1, 0], true_states[-1, 1], true_states[-1, 2], label='End (Actual)', c='red', marker='x', s=100)
# Plot the measurements
ax.scatter(measurements[:, 0], measurements[:, 1], measurements[:, 2], label='Measurements', alpha=0.6)
# Plot the start and end markers for the measurements
ax.scatter(measurements[0, 0], measurements[0, 1], measurements[0, 2], c='green', marker='o', s=100)
ax.scatter(measurements[-1, 0], measurements[-1, 1], measurements[-1, 2], c='red', marker='x', s=100)
# Plot the EKF estimate
ax.plot(estimated_states[:, 0], estimated_states[:, 1], estimated_states[:, 2], label='EKF Estimate')
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.legend()
plt.show()

简单总结
在这篇博客中,我们深入探讨了如何构建和应用 Kalman 滤波器(KF),以及如何实现扩展卡尔曼滤波器(EKF)。让我们总结一下这些模型的使用案例、优点和缺点。
KF: 该模型适用于线性系统,在这种情况下,我们可以假设状态转换和观测矩阵是状态的线性函数(带有一些高斯噪声)。你可能会在以下情况下考虑应用此算法:
-
追踪以恒定速度移动的物体的位置和速度
-
如果噪声是随机的或可以通过线性模型表示,则用于信号处理应用
-
如果基础关系可以线性建模,则用于经济预测
KF 的关键优势在于(只要你遵循矩阵计算)算法非常简单,计算量比其他方法少,并且可以在时间上提供非常准确的预测和估计。缺点是线性假设,这在复杂的现实世界场景中通常不成立。
EKF: 我们可以将 EKF 视为 KF 的非线性等效物,借助雅可比矩阵。你会在处理以下情况时考虑 EKF:
-
测量和系统动态通常是非线性的机器人系统
-
追踪和导航系统,这些系统通常涉及非恒定速度或角速度变化,例如追踪飞机或航天器。
-
在最现代的“智能”汽车中实现巡航控制或车道保持的汽车系统。
EKF 通常能比 KF 提供更好的估计,尤其是对于非线性系统,但由于雅可比矩阵的计算,它可能变得计算量大得多。这种方法还依赖于泰勒级数展开的一级线性近似,这对于高度非线性的系统可能不是一个合适的假设。雅可比矩阵的加入也可能使模型设计更具挑战性,因此尽管有其优势,对于简化和互操作性,可能更适合实现 KF。
除非另有说明,所有图片均为作者所用。
扩展大规模语言模型中的上下文长度
原文:
towardsdatascience.com/extending-context-length-in-large-language-models-74e59201b51f
如何将你的 Llama 变成长颈鹿
·发表于 Towards Data Science ·9 分钟阅读·2023 年 10 月 15 日
--

图片由作者提供。(AI 生成的 Llamas)
上下文长度是指模型在生成文本时可以记住的最大标记数。更长的上下文窗口使模型能够更好地理解文本中的长距离依赖关系。具有更长上下文的模型可以在文本中构建相隔较远的思想之间的联系,生成更具全球一致性的输出。
在训练过程中,模型以块或固定长度窗口处理文本数据。模型需要在较长文本上进行训练,以真正利用长上下文。训练序列必须包含文档、书籍、文章等,拥有数千个标记。
训练数据的长度限制了可用的上下文长度。
那么,为什么我们不在更长的序列上训练模型呢?
不要急于求成。
增加上下文长度会增加模型必须学习准确预测的可能标记组合的数量。
这使得更强大的长距离建模成为可能,但也需要更多的内存和处理能力,从而导致更高的训练成本。
没有任何优化的情况下,计算量与上下文长度的平方成正比——这意味着一个 4096 标记的模型需要比 512 标记模型多 64 倍的计算。
你可以使用稀疏或近似注意力方法来降低计算成本,但它们也可能影响模型的准确性。
训练和使用大上下文语言模型面临三个主要挑战:
-
将长上下文适配到模型中。
-
加速推理和训练,以便它们不会拖延太久。
-
确保高质量推理,保持对全部上下文的意识。
注意力是一个复杂的操作
注意力机制是变换器模型的核心组件。它将序列的不同位置联系起来以计算其表示,使模型能够关注文本的相关部分并更好地理解它。将变换器扩展到更长的序列面临挑战,因为全注意力的复杂度是二次的。

自注意力涉及两个矩阵乘法。图像来源于原始论文。[1]
堆叠自注意力层允许在文本中建模长距离依赖关系。Transformers 中使用的标准注意力机制,计算所有可能的输入标记配对的注意力权重,其复杂度为O(n²)。这意味着计算和内存需求随着输入序列长度的平方增长,限制了 Transformers 的可扩展性和效率。
在生成文本时,模型必须首先计算注意力矩阵。对于 100K 上下文和二次注意力,模型可能需要几分钟才能开始生成文本。
让我们深入探讨提高注意力效率的方法,从近似到硬件优化。
提高注意力效率
减少二次成本已成为一个活跃的研究领域。提出的方法可以分为两大类:近似注意力和通过硬件优化实现的精确注意力。
近似技术约束序列位置之间的交互。稀疏注意力限制每个注意力头的非零注意力权重数量,而本地注意力将交互限制在滑动窗口内。这些近似减少了计算成本,但可能会降低复杂任务的准确性。[2]
最近的工作集中于优化注意力以利用 GPU 架构。
稀疏注意力通过仅计算输入标记子集的注意力权重来近似注意力,从而节省时间和内存,而不是计算所有可能的配对。有不同的稀疏注意力实现方式,如使用固定或静态模式(例如,本地、步幅或块注意力)或依赖于输入序列的动态或自适应模式(例如,entmax 或动态稀疏注意力)。

二次注意力(左)计算输入标记之间的每一个可能组合。稀疏注意力(右)将计算限制在附近的标记上。[2]
稀疏注意力可以提高 Transformers 的效率和可扩展性,尤其是在长序列中,但它也可能牺牲一些表示能力和准确性。二次注意力可以实现高性能和高质量,但它可能计算开销大,不适用于大规模应用。因此,注意力机制中的稀疏性和复杂性之间存在权衡。
Flash Attention
基本直觉是避免将大规模的 N x N 注意力矩阵物化,这需要在序列长度 N 上进行平方级的读写操作。
FlashAttention应用了两种技术——分块和重新计算。分块将输入拆分成块,加载到快速 GPU 片上 SRAM 中。注意力机制按块计算,以避免物化整个矩阵。重新计算在反向传播过程中仅存储足够的信息来重建片上注意力矩阵,避免存储大型中间结果。[3]
作者分析了 IO 复杂性,证明 FlashAttention 需要 O(N²/M)的内存访问,而标准注意力机制需要 O(N²),其中 M 是 SRAM 大小。这种 IO 意识使得 FlashAttention 在重新计算带来的 FLOP 增加的情况下运行得更快。
实验验证了加速效果——FlashAttention 训练 BERT 比 MLPerf 记录快 15%,GPT-2 快 3 倍,Long Range Arena 快 2.4 倍。

FlashAttention 不是在较慢的 HBM 上计算整个注意力矩阵,而是将块复制到 SRAM 中。[3]
这个想法在FlashAttention-2中得到了进一步的发展。改进的重点是增强序列块之间的并行性,并优化线程块和 GPU 上的 warp 之间的工作分配。关键技术包括减少非矩阵乘法操作,将注意力计算分配到线程中以增加占用率,以及在 warps 之间分配工作以减少共享内存流量。实验证明,FlashAttention-2 在 FlashAttention 上实现了大约 2 倍的加速,在 A100 GPU 上达到了理论峰值 FLOP 的 73%。用于端到端训练 GPT 模型时,训练吞吐量达到了每个 A100 225 TFLOPs/s,比 FlashAttention 快 1.3 倍。[4]
这些改进有望在相似成本下实现对比以往更长序列的模型训练。加速注意力机制可以加快推理和训练速度,但在保持高输出质量的同时将文本适配到模型中仍然是一个问题。
让我们看看如何解决这个问题。
模型在固定长度的序列上进行预训练
高效的训练和推理还不足以获得高质量的模型。
上下文长度扩展主要有两种范式:微调外推,其中 LLM 在更长的上下文上进一步更新其权重,以及零-shot 外推,其中模型在长上下文上进行评估,而短上下文训练的权重不发生变化。[5]
为了扩展上下文,大多数方法专注于修改变换器注意力机制中使用的位置编码系统,以指示标记在输入序列中的位置。其想法是,在位置编码中表示更长的输入序列将允许 LLM 关注这些更长的序列。

位置编码用于使你的模型理解句子中的顺序。基于原始论文的图像。[1]
位置编码用于使你的模型理解句子中的顺序。
位置嵌入在将输入标记嵌入送入模型之前被添加,以便模型能够使用序列的顺序。它们将离散的位置 ID 映射到连续的嵌入向量。
通常,位置嵌入是基于位置 ID 算法定义的。在原始 Transformers 论文中,他们使用了三角函数,其中位置嵌入向量的每个维度都遵循正弦波模式。[1]
在 LLaMa 中,使用了旋转位置嵌入(RoPE),其中位置嵌入通过旋转嵌入动态计算。标记和位置维度通过三角函数一起旋转。旋转量由位置 ID 决定。
无论位置嵌入如何生成,模型在处理比预训练时更长的序列时都很难进行泛化。(上下文外推) 正弦位置嵌入方法的外推能力有限,在推断期间只能再处理几十个标记,然后性能会下降。[6]
新的方法如线性缩放和位置插值已经被引入以解决这一限制。
线性缩放
使用线性缩放,位置嵌入被重新缩放以适应不同的序列长度。如果预训练模型的嵌入长度为 L,则在对长度为 N 的序列进行推断时,每个位置嵌入向量会乘以 N/L。这廉价地近似了新长度的嵌入,同时保留了预训练嵌入的特性。线性缩放显著提升了长序列的性能。然而,模型在比预训练长度长得多的序列上仍然表现不佳。线性缩放过程通过将多个位置嵌入合并在一起来破坏信息。

在位置插值中,上下文范围不会扩展。相反,有更多的中间位置。[7]
线性缩放/插值在扩展上下文时整体表现最佳,同时在截断基方法中也显示出了一定的前景。通过在评估时使用更长的缩放因子,可以获得进一步的提升。
这种方法由 kaiokendev 和 Meta 同时研究。[8]
首个发布的 YaRN 是一个能够实现 128k 上下文长度的模型。
后来,发布了三个额外微调的模型,分别具有 8k、16k 和 32k 的上下文,称为 Giraffe。[5]
带有线性偏差的注意力
ALiBi 引入了一种更简单的方法,消除了位置嵌入。相反,它通过一个与查询和键之间距离成比例的惩罚来负向偏置注意力得分。这种对近期上下文的归纳偏置允许以低计算成本进行推断。实验表明,使用 ALiBi 训练的 1.3B 参数模型在 2048 个标记上测试时,与在 2048 个标记上训练的正弦模型具有相同的困惑度,训练速度更快,内存使用更少。[6]
MosaicML 的 MPT-7B 模型利用 ALiBi 架构,实现了对极长上下文长度(高达 65k 个标记)的外推,远远超过了其他开源模型的限制。通过用 ALiBi 替代位置嵌入,该模型在推理过程中能够处理任意长度的输入,而不受限于固定的上下文窗口。这通过将 MPT-7B 微调为 MPT-7B-StoryWriter-65k+ 使用 65k 个标记的小说摘录来证明,使其能够在 68k 个标记的《了不起的盖茨比》完整文本中生成连贯的续篇。
位置嵌入方法的选择取决于模型大小、预期序列长度以及问题对泛化的重要性等因素。
仍然有很大的改进空间,因为所有方法在长度增加时准确度都会下降,即使困惑度合理。
奖励模型:RWKV
改进注意力的另一种方法是不使用它。
Receptance Weighted Key Value (RWKV) 模型由 Peng 等人提出,旨在调和序列处理任务中的计算效率和模型性能之间的权衡。RWKV 将 Transformers 和 RNN 的某些方面结合成一种新颖的架构,实现了线性扩展。[9]
一个关键的创新是将注意力机制重新表述为使用标量交互而不是点积,从而消除二次瓶颈。

用于语言建模的 RWKV 架构。[9]
RWKV 实现了一种线性注意力的变体,没有近似,以提高效率。该模型在训练中像 Transformers 一样并行计算,但在推理时表现为 RNN 解码器,实现了速度和内存的恒定,能够处理无限上下文。
实验表明 RWKV 在语言任务中与 Transformers 具有竞争力,同时需要更低的计算成本。
然而,鉴于其递归特性,可能需要仔细调整提示,并且它可能在处理极长序列时难以保持详细跟踪。
Raven 是一个 RWKV 模型的示例。这个模型与最小的基于 Llama 的模型竞争。在我的测试中,Raven 显示出对语法和语义意义的良好理解,但倾向于经常产生幻觉。
结论
语言模型从更长的上下文中受益。然而,由于标准注意力机制的原因,更长的上下文会使训练成本呈二次方增加。最近的研究集中在通过近似注意力来提高效率。稀疏注意力和线性注意力等方法有所帮助。优化硬件效率也是一种有效的手段,正如 FlashAttention 通过利用 GPU 内存层次结构所展示的那样。
预训练模型在处理比训练期间看到的更长的上下文时仍然会遇到困难。像位置嵌入的线性缩放和 ALiBi 等技术能够支持更长的上下文。对更长上下文的微调进一步适应了模型。
目前最先进的模型将上下文长度大大超越了之前的限制。
YaRN 和 Giraffe 使用位置插值。MPT-65k 使用 ALiBi 处理 65,000 个标记的上下文。RWKV 提出了线性缩放注意力机制,允许在推理时使用无限长度的上下文。
更长的上下文使得模型能够处理完整的文档和书籍,但在维持输出质量方面仍面临挑战。
如果你喜欢这篇文章,加入 文本生成 —— 我们的新闻通讯每周有两篇最新的生成 AI 和大型语言模型的见解。
此外,你还可以在 LinkedIn上找到我。
参考文献
[1] [1706.03762] 注意力机制即一切 (arxiv.org)
[2] [1904.10509] 使用稀疏变换器生成长序列 (arxiv.org)
[3] [2205.14135] FlashAttention: 快速且内存高效的精确注意力机制与 IO 关注 (arxiv.org)
[4] [2307.08691] FlashAttention-2: 更快的注意力机制,具备更好的并行性和工作划分 (arxiv.org)
[5] [2308.10882] Giraffe: 扩展 LLM 上下文长度的冒险 (arxiv.org)
[6] [2108.12409] 训练短序列,测试长序列:具有线性偏置的注意力机制实现输入长度外推 (arxiv.org)
[7] [2306.15595] 通过位置插值扩展大型语言模型的上下文窗口 (arxiv.org)
[8] [2309.00071] YaRN: 大型语言模型的高效上下文窗口扩展 (arxiv.org)
[9] [2305.13048] RWKV: 为 Transformer 时代重新定义 RNN (arxiv.org)
从(Azure)SQL Server 大型表中提取数据到 RFC 4180 兼容 CSV 文件
从(Azure)SQL Server 大型表中提取包含特殊字符的字符串到 CSV 文件中的噩梦
·发表于Towards Data Science ·阅读时间 21 分钟·2023 年 1 月 5 日
--

(图片来自 Unsplash)
当一组来自公司外部的数据科学家被聘用来实现机器学习模型时,你必须以某种方式与他们分享用于模型训练的数据。如果上述团队无法直接访问数据库中持久化的数据,第一种选择是将数据从数据库中提取到 CSV 格式的文件中。考虑到这些数据通常是大量的(超过 5 GB),而且某些字段可能包含特殊字符(逗号,与字段分隔符重合;回车和/或换行符),非开发用户使用的通常导出工具可能不够合适,甚至可能导致内存问题。
在本文中,你将看到如何使用 PowerShell 函数解决从(Azure)SQL Server 数据库中提取包含特殊字符的大量数据到 RFC 4180 兼容 CSV 文件中的问题。
当你需要从(Azure)SQL Server 数据库中提取数据时,用户首先想到的工具是SQL Server Management Studio(SSMS)和Azure Data Studio(ADS)。这是因为这两者都包含简单的功能,允许你通过几次点击从数据库中提取数据。
与(Azure)SQL Server 完美配合的工具是 SSMS。最近,微软在向 ADS 中添加功能方面进行了大量投资,使其成为 Azure 及其他平台上微软数据平台的首选工具。因此,当你安装最新版本的 SSMS 时,安装程序也会在后台安装 ADS。
任何涉及导入 CSV 文件以加载数据集的第三方系统必须基于一个定义 CSV 格式的标准。因此,在进行实际测试之前,让我们看看是否有 CSV 格式的标准定义。
什么是 RFC 4180
RFC 4180 是一个标准,它规范了用于逗号分隔值(CSV)文件的格式和与 CSV 格式相关的特定多用途互联网邮件扩展(MIME)类型(“text/csv”)。该标准的内容可以在这里找到:
[## RFC 4180:逗号分隔值(CSV)文件的通用格式和 MIME 类型
INFORMATIONAL 7111 Errata Exist 网络工作组 Y. Shafranovich 请求评论:4180 SolidMatrix…
www.rfc-editor.org](https://www.rfc-editor.org/rfc/rfc4180?source=post_page-----1cb09a7a0883--------------------------------)
如前面链接中格式的定义所示,虽然前四点比较明显,但其余三点需要仔细阅读:
-
每个字段可能被括在双引号内,也可能不被括在双引号内(但
一些程序,如 Microsoft Excel,不使用双引号
如果字段没有被双引号括起,那么
双引号可能不会出现在字段内。
-
包含换行符(CR/LF)、双引号和逗号的字段
应被括在双引号内。
-
如果使用双引号来括起字段,那么一个双引号
出现在字段内的必须通过在其前面加上转义字符来进行转义
另一个双引号*。
同时考虑链接中给出的示例,可以看出字段的值只有在必要时才会用双引号括起。对于那些只需部分值使用双引号的字段,将所有值都用双引号括起来是没有意义的。
当你需要使用 CSV 格式文件与第三方系统共享信息时,适用以下内容:
重要的是,你从导出中生成的 CSV 文件必须符合 RFC 4180,以确保文件可以被任何提供 CSV 文件导入功能的外部系统读取。
为了测试上述工具如何以 CSV 格式提取数据,让我们创建一个包含 RFC 4180 标准中提到的特殊字符和 Unicode 字符的简单表,以确保文本字段内容的通用性。
创建一个包含特殊字符的虚拟表
首先,你需要使用以下脚本在 SQL Server 实例中创建extract_test表:
CREATE TABLE [dbo].extract_test NOT NULL,
[name] nvarchar NULL,
[notes] nvarchar NULL
)
然后,你可以使用以下脚本向该表中添加数据:
SET IDENTITY_INSERT [dbo].[extract_test] ON
GO
INSERT [dbo].[extract_test] ([id], [name], [notes]) VALUES (1, N'Luca', N'let''s add a carriage return
here')
GO
INSERT [dbo].[extract_test] ([id], [name], [notes]) VALUES (2, N'Zavarella, Luca', N'the name contains a comma')
GO
INSERT [dbo].[extract_test] ([id], [name], [notes]) VALUES (3, N'Luca Zavarella', N'here we have a comma and a double quotation mark: ,"')
GO
INSERT [dbo].[extract_test] ([id], [name], [notes]) VALUES (4, N'秋彦', N'this japanese name means "bright prince"')
GO
SET IDENTITY_INSERT [dbo].[extract_test] OFF
GO
从 INSERT 语句的内容中可以看出,我们提供了标准中提到的所有特殊字符。我们还使用了日文字符,以便验证 CSV 文件是否正确使用 Unicode 字符表进行编写。
显然,此情况下创建的表不会是 5 GB 的表,而是包含特殊字符以测试 CSV 格式导出的。这里是 ADS 中的 SELECT 输出:

图 1 — 在 ADS 中输出的虚拟表内容(作者提供)
不必担心 ADS 或 SSMS 的输出网格中未显示回车符。由于 INSERT 该行的方式,回车符确实存在。
所以,让我们尝试使用 SSMS 和 ADS 从这个表中提取数据。
使用微软用户友好的工具来提取数据
我们首先尝试使用传统工具,即 SQL Server Management Studio 来进行操作。
使用 SSMS 提取数据
打开 SSMS 并连接到数据库实例后,右键点击刚创建表所在的数据库名称,选择任务,然后选择导出数据:

图 2 — 使用 SSMS 从数据库导出数据(作者提供)
您将看到一个描述提取数据活动的初始屏幕。如果继续操作,将会看到这个窗口:

图 3 — 从导出向导中选择数据源(作者提供)
选择 SQL Server 客户端数据源,输入您的服务器实例名称,然后选择登录数据库所使用的身份验证方式。在我的例子中,因为测试表存储在 Azure SQL 数据库上,我使用了 SQL Server 身份验证来访问我的test-sql-bug数据库,如图 3所示。
在向导的下一个屏幕上,您可以选择导出目标。在我们的例子中,选择平面文件目标,通过浏览按钮在您首选的文件夹中创建一个 CSV 目标文件(记得在点击浏览后打开的窗口中选择 CSV 扩展名)。记得勾选Unicode标志,以确保还处理了我们示例中的日文字符。之后,选择分隔符作为格式,将文本限定符保持为“

图 4 — 选择数据输出的目标(作者提供)
在下一个窗口中选择从一个或多个表或视图中复制数据,然后再次点击下一步。
在出现的配置窗口中,您可以选择表 [dbo].[extract_text]` 作为源表或视图。对于其他选项,您可以保持默认设置,因为行分隔符(CR\LF)和列分隔符(逗号)均按照 RFC 4180 标准定义。然后点击下一步:

图 5 — 配置平面文件目标选项(作者提供)
在下一个窗口中保持 立即运行 选项被选中,然后点击 完成。将出现选项摘要窗口。再次点击 完成 开始提取。完成后,点击 关闭。
如果你现在尝试用文本编辑器(而不是 Excel)打开输出的 CSV 文件,你会注意到以下内容:

图 6 — SSMS 导出向导输出(没有文本限定符)(作者提供)
基本上,在这种情况下,导出向导会提取每个文本字段的内容,而不管它是否可能包含特殊字符(逗号和回车)。这意味着文本字段中包含的任何回车都会被系统解释为行分隔符,就像文本字段中包含的任何逗号会被解释为字段分隔符一样。另一方面,Unicode 字符已被正确处理。因此,生成的 CSV 文件将无法被任何需要导入这些信息的第三方系统识别为正确。
如果你尝试重复导出,这次将双引号作为文本限定符输入,你将得到以下结果:

图 7 — 使用双引号作为文本限定符的 SSMS 导出向导输出(作者提供)
在这种情况下,所有提取的值都被双引号包围,包括标题。然而,这会强迫必须读取数据的外部系统将所有数字值视为字符串。此外,如果文本字段中的值包含双引号字符,则不会转义,从而给外部系统带来解析问题。因此,再次生成的 CSV 文件将无法被任何需要导入这些信息的第三方系统识别为正确。
关于在非常大数据量上进行提取操作的可扩展性,毫无问题,因为导出向导使用 SQL Server 集成服务(SSIS)作为其引擎,SSIS 开发用于处理巨大的数据量。
此外,有时你可能需要对数据源的数据类型采取措施,以避免在使用导出向导时出现一些错误,如本博客文章所强调:
2017 年 12 月 7 日 在这篇文章中,我们将展示如何使用…将 SQL Server 数据导出到平面文件中。
我们可以通过以下陈述来总结这一部分:
使用 SSMS 导出向导作为从(Azure)SQL Server 数据库提取 CSV 格式数据的工具,并不能保证具有符合 RFC 4180 定义的标准格式,因此提取的信息可能无法被外部系统正确读取。
相反,让我们看看使用 Azure Data Studio 提取 CSV 格式信息时会发生什么。
使用 ADS 提取数据
一旦打开 Azure Data Studio,首先要做的是添加一个新的服务器实例连接。注意,从较新的版本开始,加密选项默认设置为True。如果你连接到 Azure SQL 数据库,这不会导致连接错误,但如果你的数据源是本地 SQL Server,则可能会产生错误。在这种情况下,你可以将选项设置为False。
也就是说,为了在 ADS 中提取表(或视图、查询)的内容,你必须首先执行一个 SELECT 查询,并在运行时将其内容显示在输出网格中。之后,只需按下网格右上角的“保存为 CSV”按钮:

图 8 — 在 ADS 中以 CSV 格式保存查询输出(作者提供)
一个输出文件选择窗口将打开,允许你命名要提取的文件(在我们的例子中是ExtractTestADS.csv)。一旦按下保存按钮,CSV 文件的内容将直接显示在 ADS 中:

图 9 — ADS 以 CSV 格式输出的结果(作者提供)
哇!ADS 生成的输出在所有方面都符合 RFC 4180 标准!因此,ADS 似乎是从(Azure)SQL 数据库中提取 CSV 格式信息的完美工具。
然而,这里存在一个扩展性问题。由于 ADS 要求查询输出首先在输出网格中暴露,这限制了处理大量数据时的功能。在这些情况下,将所有数据包含在一个网格中会占用系统大量内存,导致应用程序崩溃。
因此,我们可以得出以下结论:
ADS 的 CSV 格式数据导出过程保证了符合 RFC 4180 标准的输出。然而,当待导出的数据集大小相对有限时,使用 ADS 进行提取任务是合适的。当需要提取超过 2–3 GB 的数据时,ADS 可能会占用整个系统内存并导致崩溃。
总的来说,我们可以得出以下结论:
不幸的是,微软的数据平台工具提供的用户友好功能无法按照 RFC 4180 标准提取大量数据为 CSV 格式。
让我们尝试通过专家用户知道的更具体的工具来实现我们的目标。
使用 BCP 工具提取数据
批量复制程序(BCP)命令行工具用于将大量新行导入 SQL Server 表中或将数据从表中导出到用户指定格式的数据文件。这是即使在非常大量的数据中也能尽可能快地导入或导出数据的解决方案。因此,它在可扩展性方面没有问题。
除了在标准本地 SQL Server 安装时默认安装外,并且除了可以在 Windows 操作系统上单独安装之外,BCP 工具还可以通过 Azure 云终端与 Azure SQL 数据库交互,正如这篇博客文章所示:
导出 CSV 文件从 Azure SQL 数据库 [## Export CSV File from Azure SQL Databases
有很多方法可以从 Azure SQL 数据库导出 CSV 文件。本文展示了两种方法:bcp 和 sqlcmd…
不深入细节,BCP 的主要问题是它不能提取表头,也不能以简单的方式处理双引号。这一点通过Erland Sommarskog的参考指南可以证明,该指南报告了获得表头和双引号的多个变通方法,如你所见:
[## 在 SQL Server 中使用批量加载工具
由 Erland Sommarskog 编写的 SQL 文本,SQL Server MVP。最近更新时间 2021-01-26。版权适用于此文本。请参见…
这种方法的一个缺点是你必须提前知道哪些字段需要双引号(除非你为所有文本字段提供双引号)。一般来说,我无法提前知道哪些字段可能需要双引号。我只想无忧地提取数据。然而,如果你能够通过 Erland 的建议获取标题和双引号,那么这些引号将应用于所选字段中的所有值。正如 Erland 本人指出的:
… 假设数据应始终被引号括起来。如果你只想在需要时才加引号,你将需要在查询中处理这个问题,这超出了本文的范围。我能说的就是:祝好运。或者更直接一点:如果可能的话,避免这样做。
此外,如果包含双引号的字段有一个字符串同时包含逗号和双引号,则 BCP 不处理通过双引号转义双引号的特性。
因此,我们可以声明:
使用 BCP 以 CSV 格式导出数据,包括头部和双引号,对于非专家用户来说非常复杂。一个缺点是你必须提前知道哪些字段需要提供双引号。此外,它仍然不会产生与 RFC 4180 标准一致的格式。
我不会详细介绍微软的另一个命令行工具SQLCMD的使用,因为问题类似于本节中提到的问题。
那么,怎么办呢?如何继续?由于我在互联网上找不到能够以 RFC 4180 兼容的 CSV 格式提取数据并同时处理非常大数据量的应用程序,因此唯一可能的解决方案是开发一个自定义解决方案,即使是非专家用户也可以轻松使用。让我们看看这个解决方案是如何工作的。
在 PowerShell 中开发自定义解决方案
当我决定为这个问题开发一个具体解决方案时,我首先问自己使用什么编程语言。第一个想到的语言肯定是 Python。然而,我接着想到,标准用户在 Windows 机器上接触自动化的世界时可能不懂 Python,而且操作系统上也不会预装 Python。这就是为什么选择了PowerShell,它提供了一个专门用于 SQL Server 的模块。
SQL Server PowerShell 模块的问题
我第一次尝试使用的是SQL Server PowerShell 模块,它允许 SQL Server 开发人员、管理员和商业智能专业人员自动化数据库开发和服务器管理。
具体来说,我尝试使用的命令是Invoke-Sqlcmd,它用于将查询发送到 Azure SQL 数据库以检索数据。这个命令除了调用sqlcmd.exe 命令行工具外别无他用,后者常被自动化过程用来从 SQL Server 数据库中检索信息。到目前为止,一切都很好。问题在于 Invoke-Sqlcmd 会将所有查询输出直接存入 PowerShell 数据结构中。正如你可以猜到的那样,当查询输出超过 3-4 GB 时,你会遇到与在 ADS 中提取数据时相同的问题,即由于过度消耗 RAM 系统变得不稳定。
因此,我认为直接在 PowerShell 中使用ADO.NET 对象是合适的,以尝试绕过这个问题。让我们看看我在这个解决方案中是如何使用它们的。
批量导出数据到输出文件
我的解决方案的主要思想是始终使用一个中间数据结构(一个 DataTable),该结构会收集查询数据,但一次只收集一定数量的行。一旦达到中间数据结构的最大容量,它的内容会被写入目标文件,结构会被清空,并立即加载数据源中的下一批数据:

图 10 — 解决方案的主要过程(图片由作者提供)
该过程持续进行,直到数据源中没有新行可读。
你可能会想知道为什么我使用了一个中间的 DataTable,而没有通过 StreamWriter 实现直接写入输出文件。答案在于直接使用 PowerShell 的 Export-Csv cmdlet 的能力。
使用 Export-Csv 写入数据
我在解决问题时设定的目标之一是,尽量避免重新发明轮子,如果已经有方便的解决方案可以完全或部分解决问题。在这种情况下,我认为可以省略重新编写处理 RFC 4180 标准所提及特殊字符的所有逻辑,直接使用 Export-Csv cmdlet。
查阅 PowerShell cmdlet 指南时,我意识到 Export-Csv 仅在版本 7 中提供控制双引号使用的参数:

图 11 — Export-Csv 版本 6 和 7 的区别(图片由作者提供)
具体而言,UseQuotes 参数 提供了值 AsNeeded,其功能定义如下:
仅为包含分隔符字符、双引号或换行符的字段加引号
基本上,这正是我们为了满足 RFC 4180 标准要求所需要的。
如果你希望仅为某些字段提供双引号,你可以通过 QuoteFields 参数明确指定它们。
目前,PowerShell 版本存在一个小问题。请注意,Windows 10、Windows 11 和 Windows Server 2022 预装了版本 5.1 的 Windows PowerShell(也称为桌面版)。为了使用较新的 Export-Csv cmdlet,你必须安装较新版本的 PowerShell(至少 PowerShell 7.0),这实际上是基于 .NET Core 的 Windows PowerShell 的独立软件(如果你有兴趣了解它的演变,你可以在此链接中了解更多)。
强调以下几点很重要:
由于该模块是为 PowerShell Core 版本开发的,因此也可以在 Linux 和 macOS 系统上使用。
话虽如此,我们来看看如何使用这个新模块。
如何使用 SqlBulkExport 模块
新的 SqlBulkExport 模块可以在 GitHub 上找到,链接如下:
## GitHub - lucazav/sql-bulk-export: 这个 PowerShell 模块包含两个对导出大量数据很有用的函数…
这个 PowerShell 模块包含两个对导出大量数据非常有用的函数。
它提供了两个功能:
-
Export-SqlBulkCsv:将 SQL Server 数据库表、视图或查询的内容导出到一个符合 RFC 4180 的 CSV 文件中。此功能支持导出大量结果集,将 CSV 文件内容分多次写入。
-
Export-SqlBulkCsvByPeriod:将 SQL Server 数据库表、视图或查询的内容导出到多个符合 RFC 4180 的 CSV 文件中,按时间段(按年、按月或按日)进行拆分,基于所选日期字段的内容。此功能支持导出大量结果集,将每个 CSV 文件的内容分多次写入。
两个函数都需要以下参数:
-
ServerName:要连接的 SQL Server 实例名称。
-
Port:SQL Server 实例端口号。默认值为 1433。
-
DatabaseName:要连接的 SQL Server 数据库名称。
-
SchemaName:提取数据的表或视图的数据库架构。默认值为“dbo”。
-
TableViewName:提取数据的数据库表或视图的名称。
-
Query:提取数据的 T-SQL 查询。
-
User:用于连接数据库的用户名。
-
Password:用于连接数据库的用户名的密码。
-
ConnectionTimeout:连接超时时间(以秒为单位)。默认值为 30 秒。
-
DatabaseCulture:数据库文化代码(例如 it-IT)。它用于正确提取小数分隔符。默认情况下为“en-US”。
-
BatchSize:写入到输出文件中的批次大小(行数),直到提取的数据结束。
-
OutputFileFullPath:输出文件的完整路径(包括文件名和 csv 扩展名)。
-
SeparatorChar:用于构建在控制台中显示的字符串分隔符的字符。
Export-SqlBulkCsvByPeriod 函数提供了三个更多的必填参数,以便根据时间段对结果集进行分区:
-
DateColumnName:按时间段拆分数据的日期/时间类型列。
-
StartPeriod:时间段字符串(允许的格式:“yyyy”、“yyyy-MM”、“yyyy-MM-dd”),表示从哪个时间段开始提取数据(包括该时间段)。
-
EndPeriod:时间段字符串(允许的格式:“yyyy”、“yyyy-MM”、“yyyy-MM-dd”),表示提取数据的结束时间段(包括该时间段)。
显然,两个输入时间段所使用的格式必须一致。
重要的是要注意,使用 Export-SqlBulkCsvByPeriod 函数提取按时间段划分的多个 CSV 文件只能通过表/视图实现,而不能通过查询实现。如果需要选择字段和应用过滤器,必须首先暴露一个具有这些逻辑的视图,然后才能按时间段提取多个 CSV 文件。
此外,Export-SqlBulkCsvByPeriod 函数在输出 CSV 文件名中涉及使用字符串标记 {}(大括号),该标记将被与 CSV 文件中交易时间段相关的字符串替换。
两个函数会根据是否传递了 User 和 Password 参数自动识别是使用 Windows 身份验证还是 SQL Server 身份验证。
在继续示例之前,请确保您已安装最新版本的 PowerShell。
安装最新的 PowerShell 和 SqlBulkExport 版本
为了在 Windows 机器上安装最新版本的 PowerShell,请从 此链接 下载并运行 64 位安装程序(在我们的例子中为 7.3.0 版本)。
点击 Next 继续所有的设置向导窗口。然后点击 Finish。你会看到 PowerShell 7 (x64) 提示符已安装到你的应用程序中:

图 12 — PowerShell 7 刚刚安装完成(图像由作者提供)
运行它,你会看到 PowerShell 提示符准备好接收你的命令:

图 13 — PowerShell 7 提示符准备就绪(图像由作者提供)
你可以输入 $PSVersionTable 命令并按 Enter 键检查是否一切正常:

图 14 — PSVersionTable 输出(图像由作者提供)
很好!如有必要,您还可以在 Linux 或 macOS 上安装 PowerShell。
现在,您需要下载 SqlBulkExport 模块文件:
-
前往 SqlBulkExport GitHub 仓库的 发布页面 下载最新版本的 Source code.zip 文件。
-
一旦文件保存在您的机器上,解压缩它并将其内容复制到
C:\Temp文件夹中(或您可以选择您喜欢的文件夹)。这样,您的模块文件将被保存在C:\Temp\sql-bulk-export-<version>文件夹中。
好的!现在您可以尝试一些示例了。
将我们的虚拟表内容导出到一个 CSV 文件中
让我们尝试提取本文开始时创建的 extract_test 表的内容,以检查其是否符合 RFC 4180 标准。在我们的案例中,相关表被保存在 Azure SQL 数据库中:
-
打开 PowerShell 7 提示符,输入
cd C:\Temp\sql-bulk-export-<version>命令并按 Enter 键,将工作目录更改为模块目录。 -
输入
Import-Module -Name ".\SqlBulkExport.psd1"命令来导入 SqlBulkExport 模块。 -
输入
Export-SqlBulkCsv -ServerName "<your-server-name>" -DatabaseName "<your-database-name>" -User "<username>" -Password "<password>" -TableViewName "export_test" -BatchSize 30000 -OutputFileFullPath "C:\Temp\ExtractedTestPS.csv"命令,将数据库表(或视图)的内容以每批 30K 行的方式导出到 ExtractedTestPS.csv 文件中。以下是输出:

图 15 — 提取虚拟表内容到 CSV 文件的命令控制台输出(图片来源于作者)
这是输出 CSV 文件的内容:

图 16 — 使用 SqlBulkExport 模块提取的 CSV 文件中的虚拟表(图片来源于作者)
如您所见,输出的 CSV 文件内容符合 RFC 4180 标准。由于使用的虚拟表行数较少,因此只使用了一个批次进行提取。现在我们来尝试提取一个拥有几万行的表的内容。
将表/视图的内容导出到一个 CSV 文件中
与之前一样,我们要用于提取数据的表也保存在 Azure SQL 数据库中:
-
打开 PowerShell 7 提示符,输入
cd C:\Temp\sql-bulk-export-<version>命令并按 Enter 键,将工作目录更改为模块目录。 -
输入
Import-Module -Name ".\SqlBulkExport.psd1"命令来导入 SqlBulkExport 模块。 -
输入
Export-SqlBulkCsv -ServerName "<your-server-name>" -DatabaseName "<your-database-name>" -User "<username>" -Password "<password>" -TableViewName "<your-table-or-view-name>" -BatchSize 30000 -OutputFileFullPath "C:\Temp\output.csv"命令将数据库表(或视图)的内容按 30K 行的批次导出到output.csv文件中。以下是输出:

图 17 — 提取表/视图内容到 CSV 文件的命令的控制台输出(图片由作者提供)
如你所见,提取一个约 74K 行的表内容需要 3 批次的 30K 行,总共花费了 1 秒钟和 88 毫秒。不错!
让我们尝试使用查询来导出数据。
将查询输出导出到一个 CSV 文件中
在这种情况下,我们将从与前一个案例相同的表中提取数据,但使用类似SELECT * FROM <table> WHERE <condition>的查询。
-
打开 PowerShell 7 提示符,输入
cd C:\Temp\sql-bulk-export-<version>命令并按Enter键,将工作目录更改为模块目录。 -
输入
Import-Module -Name ".\SqlBulkExport.psd1"命令以导入SqlBulkExport模块。 -
输入
Export-SqlBulkCsv -ServerName "<your-server-name>" -DatabaseName "<your-database-name>" -User "<username>" -Password "<password>" -Query "SELECT * FROM <your-table-or-view-name> WHERE <condition>" -BatchSize 30000 -OutputFileFullPath "C:\Temp\output.csv"命令,将查询结果集的内容按 30K 行的批次导出到output.csv文件中。以下是输出:

图 18 — 提取查询输出到 CSV 文件的命令的控制台输出(图片由作者提供)
一切运行得非常顺利!现在让我们尝试将一个视图的内容导出到多个每月的 CSV 文件中。
将表/视图的内容导出到多个每月的 CSV 文件中
想象一下你有一个每月包含数十万行的交易表。有一组来自公司外部的数据科学家被分配来对交易历史进行高级分析。为了方便,他们要求你提取一个数据集,其中包含表中字段的子集,涵盖几个月的交易记录。他们要求你提供多个按月份划分的 CSV 文件,而不是生成一个单一的 CSV 文件。
让我们看看如何通过Export-SqlBulkCsvByPeriod函数做到这一点:
-
打开 PowerShell 7 提示符,输入
cd C:\Temp\sql-bulk-export-<version>命令并按Enter键,将工作目录更改为模块目录。 -
输入
Import-Module -Name ".\SqlBulkExport.psd1"命令以导入SqlBulkExport模块。 -
输入
Export-SqlBulkCsvByPeriod -ServerName "<your-server-name>" -DatabaseName "<your-database-name>" -User "<username>" -Password "<password>" -TableViewName "<your-table-name>" -DateColumnName "<your-date-column-name>" -StartPeriod "2022-01" -EndPeriod "2022-03" -BatchSize 100000 -OutputFileFullPath "C:\Temp\output_{}.csv"命令将数据库表(或视图)的内容导出为多个按月的 CSV 文件,每批 10 万行,从 2022 年 1 月到 2022 年 3 月。以下是输出:

图 19 — 提取多个按月 CSV 文件的命令控制台输出(图片由作者提供)
太棒了!你刚刚在 1 分钟 19 秒内提取了大约 150 万行数据,并将其分解为三个按月的 CSV 文件!
结论
促使我撰写这篇文章的需求是将大量数据(3–4+ GB)提取到一个或多个符合 RFC 4180 标准的 CSV 格式文件中。
你已经看到微软提供的工具(无论是 IDE,例如 SSMS 和 ADS;还是命令行工具,例如 BCP)无法满足上述需求。唯一看起来稍微合适的工具是 ADS,但它无法在不崩溃的情况下提取大量数据。不客气地说,到目前为止微软还没有提供符合上述要求的工具,这实在有些令人尴尬。
由于未能在互联网上找到满足上述需求的软件,我编写了SqlBulkExport PowerShell 模块来解决这个问题,并将其以 MIT 许可证开源在 GitHub上。我强调我并不是 PowerShell 开发人员,所以任何能改进此解决方案的建议都非常欢迎!
使用 Python 从 PDF 文件中提取文本:全面指南
原文:
towardsdatascience.com/extracting-text-from-pdf-files-with-python-a-comprehensive-guide-9fc4003d517
从 PDF 文件中提取表格、图像和纯文本的完整过程
·发表于 Towards Data Science ·17 分钟阅读·2023 年 9 月 21 日
--

图片来源于 Giorgio Trovato 在 Unsplash
介绍
在大型语言模型(LLMs)及其 广泛应用 的时代,从简单的文本总结和翻译到基于情感和财务报告主题预测股票表现,文本数据的重要性比以往任何时候都大。
许多类型的文档都包含这种非结构化的信息,从网页文章和博客帖子到手写信件和诗歌。然而,大量的文本数据以 PDF 格式存储和传输。更具体地说,每年在 Outlook 中打开的 PDF 超过 20 亿个,而每天在 Google Drive 和电子邮件中保存的新 PDF 文件达到 7300 万个 (2)。
因此,开发一种更系统的方法来处理这些文档并从中提取信息将使我们能够实现自动化流程,更好地理解和利用这一大批文本数据。为了完成这一任务,当然,我们最好的朋友无疑就是 Python。
然而,在我们开始之前,我们需要明确现在存在的不同类型的 PDF,特别是最常见的三种类型:
-
程序生成的 PDF:这些 PDF 是使用 W3C 技术如 HTML、CSS 和 JavaScript 或其他软件如 Adobe Acrobat 在计算机上创建的。这种类型的文件可以包含各种组件,如图像、文本和链接,这些都可以被搜索和轻松编辑。
-
传统扫描文档:这些 PDF 是通过扫描仪或移动应用程序从非电子介质创建的。这些文件只是存储在 PDF 文件中的图像集合。也就是说,图像中出现的元素,如文本或链接,无法被选择或搜索。本质上,PDF 作为这些图像的容器。
-
带有 OCR 的扫描文档:在这种情况下,在扫描文档后,使用光学字符识别(OCR)软件来识别文件中每个图像中的文本,将其转换为可搜索和可编辑的文本。然后,软件在图像上添加实际文本的层,从而在浏览文件时可以将其作为单独的组件选择。
尽管现在越来越多的机器装有 OCR 系统来识别扫描文档中的文本,但仍有一些文档包含全页图像格式。你可能已经遇到过这种情况,当你阅读一篇精彩的文章并试图选择一个句子时,却选择了整页。这可能是特定 OCR 机器的限制或完全缺失的结果。因此,为了不遗漏这篇文章中的信息,我尝试创建一个也考虑这些情况的过程,并充分利用我们宝贵且信息丰富的 PDF 文件。
理论方法
记住这些不同类型的 PDF 文件及其组成项目,进行 PDF 布局的初步分析是重要的,以确定每个组件所需的适当工具。更具体地说,根据这项分析的结果,我们将应用适当的方法来提取 PDF 中的文本,无论是带有元数据的文本块、图像中的文本还是表格中的结构化文本。在没有 OCR 的扫描文档中,识别和提取图像中文本的方法将承担所有繁重的工作。此过程的输出将是一个 Python 字典,包含提取的信息,每页 PDF 文件的信息。此字典中的每个键将表示文档的页码,对应的值将是包含以下 5 个嵌套列表的列表:
-
按文本块提取的文本
-
每个文本块中的字体家族和大小的格式
-
从页面上的图像中提取的文本
-
从表格中以结构化格式提取的文本
-
页面上的完整文本内容

作者提供的图像
这样,我们可以实现对每个源组件提取文本的更合理分离,有时这可以帮助我们更容易地检索通常出现在特定组件中的信息(例如,徽标图像中的公司名称)。此外,从文本中提取的元数据,如字体系列和大小,可以用于轻松识别文本标题或突出显示的重要文本,这将帮助我们进一步分离或对文本进行多块后处理。最后,以 LLM 可以理解的方式保留结构化表格信息将显著提升对提取数据中关系的推断质量。然后,这些结果可以作为每页上出现的所有文本信息的输出。
您可以在下图中查看这种方法的流程图。

图片由作者提供
安装所有必要的库
不过,在开始这个项目之前,我们应该安装必要的库。我们假设您的机器上已安装 Python 3.10 或更高版本。否则,您可以从这里进行安装。然后让我们安装以下库:
PyPDF2:用于从存储库路径中读取 PDF 文件。
pip install PyPDF2
Pdfminer:用于执行布局分析并从 PDF 中提取文本和格式。(支持 Python 3 的库版本为 .six)
pip install pdfminer.six
Pdfplumber:用于识别 PDF 页中的表格并从中提取信息。
pip install pdfplumber
Pdf2image:用于将裁剪后的 PDF 图像转换为 PNG 图像。
pip install pdf2image
PIL:用于读取 PNG 图像。
pip install Pillow
Pytesseract:用于使用 OCR 技术从图像中提取文本。
安装这个稍微复杂一些,因为首先,您需要安装Google Tesseract OCR,这是一个基于 LSTM 模型的 OCR 机器,用于识别行识别和字符模式。
如果您是 Mac 用户,可以通过终端中的 Brew 在您的机器上安装这些库,安装后您就可以开始使用了。
brew install tesseract
对于 Windows 用户,您可以按照这些步骤安装链接。然后,当您下载并安装软件时,您需要将其可执行路径添加到计算机的环境变量中。或者,您可以运行以下命令,通过以下代码直接在 Python 脚本中包含其路径:
pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'
然后您可以安装 Python 库
pip install pytesseract
最后,我们将在脚本开始时导入所有库。
# To read the PDF
import PyPDF2
# To analyze the PDF layout and extract text
from pdfminer.high_level import extract_pages, extract_text
from pdfminer.layout import LTTextContainer, LTChar, LTRect, LTFigure
# To extract text from tables in PDF
import pdfplumber
# To extract the images from the PDFs
from PIL import Image
from pdf2image import convert_from_path
# To perform OCR to extract text from images
import pytesseract
# To remove the additional created files
import os
现在我们已经准备好了。让我们进入有趣的部分。
使用 Python 进行文档布局分析

图片由作者提供
对于初步分析,我们使用了 PDFMiner Python 库将文档对象中的文本分离为多个页面对象,然后拆解和检查每个页面的布局。PDF 文件本质上缺乏结构化信息,例如段落、句子或单词,如人眼所见。相反,它们只理解文本的单个字符及其在页面上的位置。这样,PDFMiner 尝试将页面内容重建为其单个字符及其在文件中的位置。然后,通过比较这些字符与其他字符之间的距离,它将组成适当的单词、句子、行和段落文本。(4)为实现这一点,该库:
使用高阶函数 extract_pages()将 PDF 文件中的每个页面分离,并将它们转换为LTPage对象。
然后,对于每个 LTPage 对象,它从顶部到底部迭代每个元素,并尝试将适当的组件识别为:
-
LTFigure表示 PDF 中可以呈现嵌入为另一个 PDF 文档的图形或图像的区域。
-
LTTextContainer表示一个矩形区域中的文本行组,随后进一步分析为LTTextLine对象的列表。每个LTTextLine对象表示一系列LTChar对象,这些对象存储单个字符及其元数据。(5)
-
LTRect表示一个二维矩形,可用于框定图像、图形或在 LTPage 对象中创建表格。
因此,根据页面的重建和元素的分类,无论是LTFigure(包含页面中的图像或图形)、LTTextContainer(表示页面的文本信息)还是LTRect(将强烈指示表格的存在),我们可以应用适当的函数以更好地提取信息。
for pagenum, page in enumerate(extract_pages(pdf_path)):
# Iterate the elements that composed a page
for element in page:
# Check if the element is a text element
if isinstance(element, LTTextContainer):
# Function to extract text from the text block
pass
# Function to extract text format
pass
# Check the elements for images
if isinstance(element, LTFigure):
# Function to convert PDF to Image
pass
# Function to extract text with OCR
pass
# Check the elements for tables
if isinstance(element, LTRect):
# Function to extract table
pass
# Function to convert table content into a string
pass
现在我们理解了过程的分析部分,让我们创建提取每个组件中文本所需的函数。
定义提取 PDF 中文本的函数
从这里开始,从文本容器中提取文本非常简单。
# Create a function to extract text
def text_extraction(element):
# Extracting the text from the in-line text element
line_text = element.get_text()
# Find the formats of the text
# Initialize the list with all the formats that appeared in the line of text
line_formats = []
for text_line in element:
if isinstance(text_line, LTTextContainer):
# Iterating through each character in the line of text
for character in text_line:
if isinstance(character, LTChar):
# Append the font name of the character
line_formats.append(character.fontname)
# Append the font size of the character
line_formats.append(character.size)
# Find the unique font sizes and names in the line
format_per_line = list(set(line_formats))
# Return a tuple with the text in each line along with its format
return (line_text, format_per_line)
因此,要从文本容器中提取文本,我们只需使用get_text()方法。该方法检索构成特定语料库框中的单词的所有字符,并将输出存储在文本数据列表中。该列表中的每个元素代表容器中包含的原始文本信息。
现在,为了识别这些文本的格式,我们遍历 LTTextContainer 对象,以逐一访问该语料库的每一行文本。在每次迭代中,会创建一个新的LTTextLine对象,表示该语料库块中的一行文本。然后我们检查嵌套的行元素是否包含文本。如果包含,我们访问每个单独的字符元素作为 LTChar,它包含该字符的所有元数据。从这些元数据中,我们提取两种格式,并将其存储在一个单独的列表中,与被检查的文本相对应。
-
字符的字体家族,包括字符是否为粗体或斜体格式
-
字符的字体大小
通常,特定文本块中的字符格式趋于一致,除非有些字符以粗体突出显示。为了便于进一步分析,我们捕获文本中所有字符的独特格式值,并将其存储在相应的列表中。

图片由作者提供
定义提取图像中文本的函数
我认为这是一个更棘手的部分。
如何处理 PDF 中找到的图像中的文本?
首先,我们需要在这里确定,存储在 PDF 中的图像元素与文件的其他格式(如 JPEG 或 PNG)没有不同。因此,为了对它们应用 OCR 软件,我们需要首先将它们从文件中分离出来,然后将其转换为图像格式。
# Create a function to crop the image elements from PDFs
def crop_image(element, pageObj):
# Get the coordinates to crop the image from the PDF
[image_left, image_top, image_right, image_bottom] = [element.x0,element.y0,element.x1,element.y1]
# Crop the page using coordinates (left, bottom, right, top)
pageObj.mediabox.lower_left = (image_left, image_bottom)
pageObj.mediabox.upper_right = (image_right, image_top)
# Save the cropped page to a new PDF
cropped_pdf_writer = PyPDF2.PdfWriter()
cropped_pdf_writer.add_page(pageObj)
# Save the cropped PDF to a new file
with open('cropped_image.pdf', 'wb') as cropped_pdf_file:
cropped_pdf_writer.write(cropped_pdf_file)
# Create a function to convert the PDF to images
def convert_to_images(input_file,):
images = convert_from_path(input_file)
image = images[0]
output_file = "PDF_image.png"
image.save(output_file, "PNG")
# Create a function to read text from images
def image_to_text(image_path):
# Read the image
img = Image.open(image_path)
# Extract the text from the image
text = pytesseract.image_to_string(img)
return text
为了实现这一点,我们遵循以下过程:
-
我们使用从 PDFMiner 检测到的 LTFigure 对象的元数据来裁剪图像框,利用其在页面布局中的坐标。然后我们使用PyPDF2库将其保存为我们目录中的新 PDF 文件。
-
然后,我们使用pdf2image库的convert_from_file()函数将目录中的所有 PDF 文件转换为图像列表,并将其保存为 PNG 格式。
-
最后,现在我们拥有了图像文件,我们使用PIL模块的Image包在脚本中读取这些图像,并实现 pytesseract 的image_to_string()函数,使用 tesseract OCR 引擎从图像中提取文本。
因此,这个过程从图像中提取文本,然后我们将其保存在输出字典中的第三个列表中。这个列表包含从被检查页面上的图像中提取的文本信息。
定义提取表格中文本的函数
在这一部分,我们将从 PDF 页面上的表格中提取更具逻辑结构的文本。这比从语料库中提取文本要复杂一些,因为我们需要考虑信息的粒度以及表格中呈现的数据点之间形成的关系。
尽管有几个库用于从 PDF 中提取表格数据,Tabula-py是最著名的之一,但我们发现它们的功能存在一定的局限性。
在我们看来,最明显的问题来自于库使用换行符\n 识别表格的不同行。这在大多数情况下效果很好,但当单元格中的文本被换行成两行或更多行时,它无法正确捕捉,导致添加了不必要的空行并丢失了提取单元格的上下文。
你可以查看下面的示例,当我们尝试使用 tabula-py 提取表格数据时:

作者提供的图像
然后,将提取的信息输出为 Pandas DataFrame,而不是字符串。在大多数情况下,这是一种理想的格式,但对于考虑文本的 transformers,这些结果需要在输入模型之前进行转换。
因此,为了处理这个任务,我们使用了pdfplumber库。首先,它建立在我们用于初步分析的 pdfminer.six 之上,这意味着它包含类似的对象。此外,它的表格检测方法基于线条元素及其交点,这些元素构建了包含文本的单元格以及整个表格。这样,在我们识别表格单元格后,可以提取单元格内部的内容,而无需考虑需要渲染多少行。然后,当我们拥有表格内容时,我们会将其格式化为类似表格的字符串并存储在适当的列表中。
# Extracting tables from the page
def extract_table(pdf_path, page_num, table_num):
# Open the pdf file
pdf = pdfplumber.open(pdf_path)
# Find the examined page
table_page = pdf.pages[page_num]
# Extract the appropriate table
table = table_page.extract_tables()[table_num]
return table
# Convert table into the appropriate format
def table_converter(table):
table_string = ''
# Iterate through each row of the table
for row_num in range(len(table)):
row = table[row_num]
# Remove the line breaker from the wrapped texts
cleaned_row = [item.replace('\n', ' ') if item is not None and '\n' in item else 'None' if item is None else item for item in row]
# Convert the table into a string
table_string+=('|'+'|'.join(cleaned_row)+'|'+'\n')
# Removing the last line break
table_string = table_string[:-1]
return table_string
为了实现这一点,我们创建了两个函数,extract_table()用于将表格内容提取为列表的列表,table_converter()用于将这些列表的内容连接成类似表格的字符串。
在extract_table()函数中:
-
我们打开 PDF 文件。
-
我们导航到 PDF 文件的检查页面。
-
从 pdfplumber 找到的页面中的表格列表中,我们选择所需的表格。
-
我们提取了表格的内容,并将其输出为表示每行的嵌套列表。
在table_converter()函数中:
-
我们遍历每个嵌套列表,并清除任何来自换行文本的多余换行符。
-
我们通过使用|符号分隔行中的每个元素,以创建表格单元格的结构。
-
最后,我们在末尾添加一个换行符,以移动到下一行。
这将生成一个文本字符串,展示表格的内容,而不会丢失呈现的数据的细节。
将所有内容整合在一起
现在我们已准备好所有代码组件,让我们将它们整合成一个完整的代码。你可以从这里复制代码,或者你可以在我的 Github 仓库这里找到它及示例 PDF。
# Find the PDF path
pdf_path = 'OFFER 3.pdf'
# create a PDF file object
pdfFileObj = open(pdf_path, 'rb')
# create a PDF reader object
pdfReaded = PyPDF2.PdfReader(pdfFileObj)
# Create the dictionary to extract text from each image
text_per_page = {}
# We extract the pages from the PDF
for pagenum, page in enumerate(extract_pages(pdf_path)):
# Initialize the variables needed for the text extraction from the page
pageObj = pdfReaded.pages[pagenum]
page_text = []
line_format = []
text_from_images = []
text_from_tables = []
page_content = []
# Initialize the number of the examined tables
table_num = 0
first_element= True
table_extraction_flag= False
# Open the pdf file
pdf = pdfplumber.open(pdf_path)
# Find the examined page
page_tables = pdf.pages[pagenum]
# Find the number of tables on the page
tables = page_tables.find_tables()
# Find all the elements
page_elements = [(element.y1, element) for element in page._objs]
# Sort all the elements as they appear in the page
page_elements.sort(key=lambda a: a[0], reverse=True)
# Find the elements that composed a page
for i,component in enumerate(page_elements):
# Extract the position of the top side of the element in the PDF
pos= component[0]
# Extract the element of the page layout
element = component[1]
# Check if the element is a text element
if isinstance(element, LTTextContainer):
# Check if the text appeared in a table
if table_extraction_flag == False:
# Use the function to extract the text and format for each text element
(line_text, format_per_line) = text_extraction(element)
# Append the text of each line to the page text
page_text.append(line_text)
# Append the format for each line containing text
line_format.append(format_per_line)
page_content.append(line_text)
else:
# Omit the text that appeared in a table
pass
# Check the elements for images
if isinstance(element, LTFigure):
# Crop the image from the PDF
crop_image(element, pageObj)
# Convert the cropped pdf to an image
convert_to_images('cropped_image.pdf')
# Extract the text from the image
image_text = image_to_text('PDF_image.png')
text_from_images.append(image_text)
page_content.append(image_text)
# Add a placeholder in the text and format lists
page_text.append('image')
line_format.append('image')
# Check the elements for tables
if isinstance(element, LTRect):
# If the first rectangular element
if first_element == True and (table_num+1) <= len(tables):
# Find the bounding box of the table
lower_side = page.bbox[3] - tables[table_num].bbox[3]
upper_side = element.y1
# Extract the information from the table
table = extract_table(pdf_path, pagenum, table_num)
# Convert the table information in structured string format
table_string = table_converter(table)
# Append the table string into a list
text_from_tables.append(table_string)
page_content.append(table_string)
# Set the flag as True to avoid the content again
table_extraction_flag = True
# Make it another element
first_element = False
# Add a placeholder in the text and format lists
page_text.append('table')
line_format.append('table')
# Check if we already extracted the tables from the page
if element.y0 >= lower_side and element.y1 <= upper_side:
pass
elif not isinstance(page_elements[i+1][1], LTRect):
table_extraction_flag = False
first_element = True
table_num+=1
# Create the key of the dictionary
dctkey = 'Page_'+str(pagenum)
# Add the list of list as the value of the page key
text_per_page[dctkey]= [page_text, line_format, text_from_images,text_from_tables, page_content]
# Closing the pdf file object
pdfFileObj.close()
# Deleting the additional files created
os.remove('cropped_image.pdf')
os.remove('PDF_image.png')
# Display the content of the page
result = ''.join(text_per_page['Page_0'][4])
print(result)
上面的脚本将:
导入必要的库。
使用pyPDF2库打开 PDF 文件。
提取 PDF 的每一页,并执行以下步骤。
检查页面上是否有表格,并使用pdfplumner创建一个表格列表。
查找页面中嵌套的所有元素,并按其在布局中出现的顺序对它们进行排序。
然后对每个元素:
检查是否为文本容器,并且不出现在表格元素中。然后使用text_extraction()函数提取文本及其格式,否则忽略此文本。
检查是否为图像,并使用crop_image()函数从 PDF 中裁剪图像组件,使用convert_to_images()将其转换为图像文件,并使用 OCR 和image_to_text()函数提取文本。
检查是否为矩形元素。在这种情况下,我们检查第一个矩形是否是页面表格的一部分,如果是,则转到以下步骤:
-
查找表格的边界框,以避免使用 text_extraction()函数再次提取其文本。
-
提取表格内容并将其转换为字符串。
-
然后添加一个布尔参数,以澄清我们是否从表格中提取文本。
-
这个过程将在最后一个 LTRect 落在表格的边界框内,并且布局中的下一个元素不是矩形对象后结束。(所有组成表格的其他对象将被忽略)
该过程的输出将每次迭代存储在 5 个列表中,命名为:
-
page_text: 包含来自 PDF 文本容器的文本(当文本从另一个元素中提取时,将放置占位符)
-
line_format: 包含上面提取的文本的格式(当文本从另一个元素中提取时,将放置占位符)
-
text_from_images: 包含从页面上的图像提取的文本
-
text_from_tables: 包含表格内容的类似表格的字符串
-
page_content: 包含以元素列表形式呈现的页面上所有文本
所有列表将存储在一个字典的键下,该字典将表示每次检查的页面编号。
之后,我们将关闭 PDF 文件。
然后我们将删除在过程中创建的所有额外文件。
最后,我们可以通过连接 page_content 列表的元素来显示页面内容。
结论
这是一种方法,我认为它结合了许多库的最佳特性,使过程对各种类型的 PDF 和我们可能遇到的元素具有弹性,但主要依赖 PDFMiner 进行繁重的工作。此外,关于文本格式的信息可以帮助我们识别潜在的标题,这些标题可以将文本划分为不同的逻辑部分,而不仅仅是按页面内容,并有助于识别更重要的文本。
然而,总会有更高效的方法来完成此任务,尽管我认为这种方法更具包容性,我非常期待与您讨论新的和更好的解决此问题的方法。
📖 参考文献:
-
www.techopedia.com/12-practical-large-language-model-llm-applications -
pdfpro.com/blog/guides/pdf-ocr-guide/#:~:text=OCR技术可以读取文本,从而生成可搜索和可编辑的 PDF。 -
pdfminersix.readthedocs.io/en/latest/topic/converting_pdf_to_text.html#id1
提取式与生成式问答——哪种更适合您的业务?
原文:
towardsdatascience.com/extractive-vs-generative-q-a-which-is-better-for-your-business-5a8a1faab59a
ChatGPT 的出现暗示着一个搜索引擎的新纪元,本教程深入探讨了两种基本的 AI 问答类型
·发表在 Towards Data Science ·阅读时间 6 分钟·2023 年 2 月 6 日
--

提取式与生成式问答 | Skanda Vivek
变换模型 于 2017 年引入 已经在解决难度大的语言相关任务上取得了突破。像 BERT、GPT 等模型在大量文本数据上训练的原始变换架构变体,在语言相关任务上产生了最先进的结果。
AI 的一个重大好处在于其执行任务的能力,这些任务之前需要领域专业知识和仔细审阅——现在可以更快地完成,并且成本仅为以前的一小部分。我相信这将在未来十年内彻底改变各个行业。
一项典型的任务是从文本中提取信息。问答系统是一个强大的信息提取工具,通过复杂的查询,模型可以被训练来提取特定的信息。想象一下,如果用 AI 模型回答法律文件中的难题,可以节省多少时间和金钱,而不是请经验丰富的律师或聘请实习生花费数小时细读文件。让我们深入了解两种基本的基于 AI 的问答类型:提取式与生成式。
提取式问答
BERT 变换模型由 Google 语言团队于 2019 年发布。BERT 通过遮蔽单词并训练模型基于上下文预测遮蔽的单词,利用未标记的文本数据进行训练。这种遮蔽单词预测是一个常见的测试,用于评估语言能力。
在训练模型后,BERT 后来在多个任务上进行了微调。特别是,BERT 在来自 SQUAD 数据集的数十万个问答对上进行了微调,该数据集包含在 Wikipedia 文章上提出的问题,每个问题的答案是对应段落中的一段文本或 跨度。

BERT Transformer 架构来自 arxiv.org/abs/1810.04805
BERT 和类似 BERT 的模型架构构成了 2017 年论文中提出的原始 Transformer 架构的一半,被称为编码器。在这个模型中,E表示令牌嵌入,其中原始句子长度为 M 被转换为长度 M’(BERT 使用了 WordPiece 嵌入)。最终隐藏向量 T 可以用来预测文本中表示答案开始和结束的部分,通过 softmax 实现。
RoBERTa 是 BERT 的一个变体,通过在训练过程中修改关键超参数来提高整体性能。让我们看看 huggingface 发布的经过微调的 RoBERTa 模型的输出。 如下所示,在抽取式问答中,你只能回答原始上下文中的文本:

然而,答案并不总是最佳的。如下面所示,对于电影评论,我选择的答案会是“在一个未来的世界中,人类仍然非常活跃但不再掌控一切的地球上生活会是什么样子”

获得更相关结果的解决方案是微调。在下面的文章中,我讨论了如何使用自定义数据在 HuggingFace hub 上微调抽取式问答模型。仅通过几千个示例进行微调可以大幅提高性能,有时提高幅度达到 超过 50%。
## 对自定义数据进行微调的 Transformer 模型以进行问答
关于如何在自定义数据上微调 Hugging Face RoBERTa QA 模型并获得显著性能提升的教程
towardsdatascience.com
然而,抽取式问答在答案没有明确存在于上下文中的情况下效果不佳,如下所示。

当答案没有明确存在时,模型产生无用结果
可以通过附加“ANSWERNOTFOUND”并在这些案例上进行微调来规避这个问题,以确保模型在不确定时不提供答案。
抽象问答
虽然 ChatGPT 最近在全球范围内引起了轰动,原始的 GPT 模型 在 BERT 之前发布。GPT 模型使用了 2017 年 Transformer 的解码器层。GPT 模型被训练来以无监督的方式预测序列中的下一个词。然后,它们在监督的方式下进行微调。对于问答,GPT 模型在微调时会接触到多个答案选项的众多示例,并被训练选择正确的选项。一个重要的推断区别是,GPT 模型一次输出一个 token,因此是生成式的,而不是提取式的。
目前,OpenAI 提供了 4 个主要的语言模型 API 访问:
-
Ada ($0.0004 / 1K tokens — 最快)
-
Babbage ($0.0005 / 1K tokens)
-
Curie ($0.0020 / 1K tokens)
-
Davinci ($0.0200 / 1K tokens — 最强大)
作为参考,1K tokens 基本上是你发送给 API 处理的 750 个单词。那我们来看看这个模型对类似问题的表现如何:

基于 GPT3 的 Davinci OpenAI 模型用于问答

基于 GPT3 的 Davinci OpenAI 模型用于问答
正如你所见,Davinci 模型在总结电影情节方面表现很好,同时在答案不明确的情况下会说“我不知道”。
哪种模型更好 — 抽象还是提取式?
你可能会倾向于说 OpenAI 的抽象问答明显优于提取式问答模型。然而,这就是商业案例发挥作用的地方。我将在下面详细说明:
成本
Davinci 模型在大规模时明显更贵。每 1K tokens 需 $0.02,这可能也适用于 1–10 个查询。而在 AWS 上托管 Hugging Face 模型可能成本更低,运行数千次或更多查询每小时只需 0.5 美分到 1 美元。
输出
如果你有兴趣构建一个聊天机器人类型的界面并期望自由响应的答案,那么抽象问答是最佳选择。用户可能不会满意仅仅是将文本改写的干巴巴的提取式答案。然而,如果你在对获得的答案进行后处理——比如将数字存入数据库,那么抽象问答可能会成为障碍,因为你需要使用额外的逻辑来剥离多余的词汇。
自定义
使用 OpenAI API 需要依赖 OpenAI 服务器。虽然他们确实允许在自定义数据上微调模型,但无法在 AWS 等独立基础设施上托管这些模型。不过,你可以使用 Hugging Face 上的开源模型,在 AWS 上创建 API,从而不再依赖 Hugging Face 进行模型服务。这一做法的强大之处在于,它允许公司将所有基础设施保留在内部,只依赖像 AWS 这样的云服务提供商。
我想指出的一点是,Hugging Face 也支持抽象型 QA 模型。事实上,他们最近在模型中心发布了一个 text2text 生成模型 Flan T5。但我发现这个模型在 QA 任务上的表现不如 Davinci GPT-3 模型。很快,我期望 Hugging Face 也会托管像 Davinci GPT-3 这样的开源微调模型。
我希望这篇文章对使用 AI 进行问题回答提供了有用的指导。结合现有的信息检索方法和通过大量数据进行搜索,基于 AI 的信息提取可以帮助从大海捞针中提取有用信息,大大提高从大量数据中提取关键细节的效率,这在过去只能通过人工理解来实现。
更新: https://www.answerchatai.com/ — 我们的 QA 引擎使用生成式 AI 回答问题并从自定义文本中提取关键信息现已上线!只需 3 个简单步骤即可回答领域特定的问题!
-
上传一个 URL 或粘贴文本并点击搜索按钮
-
提出一个与上下文相关的问题并点击查询
-
获取你的答案!
随意使用,并告诉我你的反馈!
如果你还不是 Medium 会员并想支持像我这样的写作者,可以通过我的推荐链接注册: https://skanda-vivek.medium.com/membership
想要获取基于数据的每周视角 请在这里订阅!
眼光放在奖品上:将商业价值置于数据项目的核心
最大化数据投资的商业价值
·
关注 发表在 Towards Data Science ·10 分钟阅读·2023 年 7 月 21 日
--
由 Midjourney 生成
在我的软件执行官角色中,我经常与各行业的首席数据官(CDO)进行交流。有些 CDO 本质上是技术型的,拥有深厚的数据架构和 AI 算法知识。而其他的则更偏向于商业领域,具备对数据如何释放商业价值的敏锐理解。
几周前,我和金融机构的一位 CDO Alex 进行了一次友好的交谈,我认为他更偏向于商业领域。
亚历克斯向我解释说,不久前,在准备一个执行委员会会议时,她要求管理团队提供近期成就的意见。然而,她听到的都是与最近的数据基础设施云迁移相关的技术成就。她知道这些内容不会引起商业利益相关者的共鸣,因此她不得不自己构建一个全新的叙述,最终她成功做到了这一点。这时她意识到,在向商业高管推销数据计划的商业价值时,她大多是孤军奋战,而他们讲述的故事与她团队中的人们当时对工作的感知不一致。
这是我不断与首席数据官(CDO)讨论的话题,从最技术型到更商业导向型的 CDO 们都感受到数据管理的技术复杂性与这些举措应提供的实际商业价值之间的日益紧张。根据 NewVantage Partners《数据与分析领导力年鉴 2023》,首席数据官的角色,从 2012 年仅在 12%的组织中出现,现在已在 82%的公司中稳固建立。数据管理的复杂性也随之增长,使得这道鸿沟更具挑战性。
接下来的章节将深入探讨这个问题:数据管理的复杂技术方面与对清晰商业价值生成的迫切需求之间的紧张关系。
首先,我们探讨以商业价值优先的方法,其中对商业目标的贡献成为衡量数据计划有效性的主要标准。
然后,我们介绍商业使用案例的概念——数据计划的直接业务交付成果,这些成果才是真正推动商业价值的驱动力。
接下来,我们深入探讨商业价值测量的关键作用,这是一项在对齐数据举措与商业目标的持续合作过程中至关重要的因素。
然后,我们考察这种方法对组织及其人员的深远影响,培养一种所有人都朝着同一目标努力的环境:商业价值创造。
最后,我们概述了转变为以商业价值驱动的数据组织的过程,讨论所需的战略和战术调整、预期的挑战,以及确保数据计划不仅仅是复杂技术项目的步骤。
理解紧张关系:技术执行的复杂性与商业价值
数据项目是多方面的,需要高水平的技术能力来执行。这种复杂性延伸至开发复杂的数据管道、设置复杂的云存储系统、实施严格的治理框架,以及编目大量数据。毫无疑问,这些任务是重要的成就,标志着数据项目旅程中的关键步骤。
然而,复杂性往往带来一种隧道视野。数据团队越深入技术执行,就越有可能忽视该项目存在的初衷——产生商业价值。尽管数据项目的技术成就至关重要,但它们是达到目的的手段,而非最终目标。
需要记住的是,数据本质上是运营过程的副产品,并不一定是为了支持商业决策而精心设计的。这就是为什么我们需要复杂的数据项目来使数据适用于商业决策。但如果技术细节和商业价值之间的联系丢失,整个项目就会失去其目的。这是需要解决的紧张关系——将数据管理的技术复杂性与明确、具体的商业价值对齐。
商业价值优先:专注于最终目标
尽管由于固有复杂性,在数据项目执行的技术细节中沉浸是自然的,但我们已经发现这些任务不过是达到目的的手段。复杂性和对技术任务的关注无意中掩盖了最终目标——推动商业价值。那么我们如何解决技术性和商业价值之间的紧张关系呢?通过将商业价值优先。
但这究竟意味着什么?这意味着从数据优先的视角转变为商业优先的视角。它是关于超越术语、算法和基础设施成就,回答一个基本问题:“我们的数据举措对业务的影响是什么?”这要求将每个技术任务与相应的商业目标对齐,并持续评估其在业务结果生成中的有效性。
事实是,不论数据管道的复杂程度或集成的数据源数量如何,如果数据项目没有积极地贡献于公司的底线或其他关键绩效指标(KPIs),那么它就未能实现其主要目的。这就是为什么采取商业价值优先的方法不仅仅是一个战略决策,而是对数据项目评估和执行方式的必要转变。
这种关注点的转移并不否定技术任务的重要性。相反,它确保这些任务直接与核心目标对齐:生成可观的业务价值。这种方法确保在数据项目中的每一个决定,无论多么技术化,都以推动有意义的业务结果为导向。换句话说,它将业务背景置于数据和分析操作的核心。
介绍业务用例:值得关注的交付成果
在任何数据项目中,我们发现各种交付成果,有些直接贡献于业务价值,有些则间接支持这一目标。与业务价值直接相关的交付成果,就是我们定义为业务用例的部分。
让我们考虑一些源自我与客户互动经验的实际例子。你会注意到,这些业务用例以“业务价值优先”的方式进行描述——这是首席数据官向业务方展示成就的方式。
能源效率改进:我们合作的一家大型泛欧工业公司实施了一种分析解决方案,以减少能源成本。这一举措引起了业务的关注,因为它导致了接近百万欧元的节省,并将每年的 CO2 排放量减少了 1000 吨。实施过程中涉及了部署传感器以测量能源消耗,创建模拟模型,并可视化主要贡献者,以了解能源流动情况。
实时缺陷检测:我们服务的一家主要英国汽车制造商实施了一个流动的过程故障单(PFS),以在生产过程中实时检测和解决缺陷。该系统随装配线移动,覆盖整个生产过程,实现了生产异常的高效识别。之前,他们的缺陷报告过程效率低下,耗时长且容易出错。实施流动 PFS 后,大大减少了检查时间和每班次所需的技术人员数量。因此,公司每辆车节省了 1.75 英镑,全年节省超过 40 万英镑——这清晰地展示了该数据计划的业务价值。
业务用例提供了可观的价值,无论是通过增加收入、降低成本,还是减少风险。但业务用例的范围不仅限于这些主要类别。它们还包括必要的合规报告,并支持与更广泛公司目标一致的关键绩效指标(KPI),如推动多样性或可持续性。
由于业务用例产生了数据项目全部的价值,因此对于首席数据官(CDO)来说,确保这些价值被充分理解和完全认可是至关重要的。
衡量业务价值:一个持续的、协作的过程
业务用例的价值并非静态,它随着市场动态、监管环境的变化、技术进步以及其他众多外部因素而不断演变。
让我们用最近的一个例子来说明——COVID-19 大流行。在疫情最严重时,数据团队围绕工作场所占用率、空气质量指数以及其他与疫情相关的数据开发了业务用例。这些用例在当时具有很高的价值,但随着我们进入后疫情时代,这些特定用例的价值已减少,新的优先事项应运而生。虽然应对类似事件的能力仍然具有价值,但它在危机高峰期时的立即性或量化意义不再那么显著。
重要的是,评估这些业务用例的价值的责任完全落在业务利益相关者的肩上,而不是数据团队。业务利益相关者是那些日复一日体验这些用例操作影响的人。他们的见解对于评估价值和区分短期利益(如一次性成本减少)与结构性利益(如长期运营效率提高)至关重要。
因此,价值测量的过程应该是持续的,捕捉不同时间点上的感知变化。这种持续的评估不仅有助于微调数据计划的执行,还确保它与不断变化的业务需求保持一致。
然而,这一过程在没有考虑实施业务用例的成本时是不完整的。平衡感知价值和相关成本变得至关重要。此外,业务用例之间可能存在的利益重叠或重复计算所引入的复杂性,使得估值过程远非简单。
为了应对这一挑战,必须与业务利益相关者建立持续的对话和反馈机制。这种开放的沟通使数据计划能够持续调整与业务的现实和优先事项,从而确保其相关性和对最终结果的持续贡献。通过这种不断的互动,我们可以培养一个在技术上稳健并且敏锐地适应业务需求的数据计划。
转型为业务价值驱动的数据组织
转型为一个以业务价值驱动的数据组织不仅仅需要战略上的调整,更需要在数据计划的所有层面进行深度变革。这不仅仅是关于首席数据官(CDO)展示业务导向的报告,而是要使业务价值成为数据计划中所有活动的指导原则。
在许多组织中,数据相关的任务可能会显得与业务现实脱节。但这种方法可能会导致数据计划与其应服务的业务需求之间的脱节。
向业务价值驱动的组织转型涉及到质疑任何用数据进行的重大工作是否有助于创造业务价值。这种变化影响任务的优先级、成功的衡量标准以及资源的分配。它有助于将每个团队成员的工作与整体业务目标联系起来。
这种转型不会一蹴而就,并且确实伴随着挑战。它需要坚定的领导、清晰的愿景、对变化的开放态度和支持过渡的正确工具。但回报是巨大的:更高效的组织、更好的业务对齐以及对业务底线的更强贡献。
然而,要有效地进行这种转型,需要制定战术和战略计划。在以下部分中,我们概述了将这种从数据任务到实际业务价值交付的转变所需的关键步骤。
使其发生:从数据任务到业务价值交付
从上面讨论的转型原则出发,将数据组织的重点从纯粹的技术任务转向创造实际的业务价值确实不是一项简单的工作。这不仅要求视角的改变,还需要完全重新调整优先级,并深刻理解业务及其需求。
首先,至关重要的是发展一种跨越数据组织和业务部门的协作文化。业务部门必须在定义和衡量业务用例的价值方面发挥主导作用。另一方面,数据团队需要深入了解这些用例、其价值驱动因素以及实现这些用例所需的数据要求。
重要的是,这种协作还必须延伸到为业务用例提供数据的人员(组织中的业务面对部分)和提供打包数据供消费的人员之间。
对齐应该在整个数据计划中传播。
每个任务、每个管道、每个模型都需要清晰地与它们服务的业务用例关联起来。这不仅仅是为了报告目的而拥有一个可追溯性矩阵。这是为了使价值生成在每一个层面上都变得可见和具体,从而在所有利益相关者中,包括最技术的数据工程师到洞察的业务用户,培养一种目的感和参与感。
为推动这一转型,建立可衡量的目标至关重要,这些目标反映了对业务价值的新关注。这包括设定业务价值创造、价值实现时间以及任务与业务目标的总体对齐的目标。评估业务价值的过程本身也应随着时间的推移得到完善和改进,考虑到不断变化的业务背景和所学到的经验。
这种转型是一个持续的承诺,需要随着业务需求、数据可用性和技术能力的发展不断关注和调整。这不是一个一次性的项目,而是一个不断改进对齐、协作和价值交付的持续旅程。
总之,将数据组织与业务价值创造对齐的过程需要有意的行动和持续的努力。但是,回报——一个与业务深度融合的数据团队、基于对业务目标影响优先排序的举措,以及一个利用数据力量推动实际价值的组织——使得这段旅程非常值得投入。
使用 Python 的面部检测——面部识别的前奏
原文:
towardsdatascience.com/face-detection-using-python-the-precursor-to-face-recognition-316ded4d116f
通过使用你的网络摄像头来检测你的面孔,享受 Python 的乐趣
·发布于 Towards Data Science ·8 分钟阅读·2023 年 4 月 11 日
--

作者提供的所有图片
面部检测是一种在数字图像中识别人的面孔的技术。面部检测是一项相对成熟的技术——还记得在你用数字相机时,当你通过取景器查看时,你看到围绕人脸的矩形吗?
面部检测是你在进行面部识别之前需要学习的技术,面部识别试图将面孔与名字对应起来。
对于面部检测,最著名的算法之一被称为Viola-Jones 面部检测技术,通常称为Haar cascades。Haar cascades 的发明早于深度学习流行,是检测面孔的最常用技术之一。
面部检测/识别的伦理考量
虽然检测和识别面孔的能力确实很酷,但它确实有很多伦理问题。在将面部识别技术应用到你的项目中之前,有几个问题需要注意。比如隐私(面部检测可能被用来在未经同意的情况下追踪人们的行动)、偏见(面部检测可能对不同种族、性别或年龄的人有偏见)、滥用(捕捉到的面孔可能被用于其他非法用途或恶意目的)。因此,尽管这篇文章关注面部检测的技术能力,但在你将其应用到工作中之前,你应仔细考虑道德和伦理的影响。
以下是一些可以实现面部检测/识别的低风险项目:
-
考勤追踪——你可以在学校或工作场所使用面部识别来进行考勤。
-
个性化 — 使用人脸识别来个性化服务。一个好的例子是娱乐服务,例如根据用户的观看历史推荐特定的电视节目。
-
安全性 — 使用人脸识别解锁非关键系统,如智能手机和计算机。
然而,在某些应用中使用人脸识别有严重的道德问题。以下是一些例子:
-
执法 — 尽管人脸识别对执法有帮助,但其不准确性和偏见仍然是严重的担忧。
-
监控 — 人脸识别技术在一些国家被用于监控和追踪其公民,尤其是异议人士。一些公司也使用人脸识别来监控员工的生产力,这直接侵犯了他们的隐私。
这里有一些文章,你可以阅读以了解更多关于人脸识别的法律和道德问题:
-
美国的人脸识别:隐私问题和法律发展 —
www.asisonline.org/security-management-magazine/monthly-issues/security-technology/archive/2021/december/facial-recognition-in-the-us-privacy-concerns-and-legal-developments/ -
面部识别软件的隐私和安全问题 —
www.techrepublic.com/article/privacy-and-security-issues-associated-with-facial-recognition-software/ -
关于人脸识别技术的 10 个担忧理由 —
www.privacycompliancehub.com/gdpr-resources/10-reasons-to-be-concerned-about-facial-recognition-technology/
Haar 级联分类器是如何工作的?
Haar 级联分类器 用于检测其训练过的对象。如果你对 Haar 级联分类器工作原理的数学解释感兴趣,可以查看 Paul Viola 和 Michael Jones 的论文 www.cs.cmu.edu/~efros/courses/LBMV07/Papers/viola-cvpr-01.pdf。
这里是 Haar 分类器用于人脸识别的高级概述:
-
首先,使用一组正面图像(人脸图像)和一组负面图像(无脸图像)来训练分类器。
-
然后从图像中提取特征。下图显示了一些从包含人脸的图像中提取的特征。

- 要从图像中检测面孔,你需要寻找通常在人脸上发现的各种特征的存在(见下图),例如眉毛,其中眉毛上方的区域比下面的区域要亮。

- 当图像包含所有这些特征的组合时,它被认为包含一个面孔。
幸运的是,无需了解 Haar 级联的工作原理,OpenCV可以直接使用预训练的 Haar 级联进行面部检测,以及其他识别其他对象的 Haar 级联。预定义的 Haar 级联列表可在 GitHub 上找到,地址是github.com/opencv/opencv/tree/master/data/haarcascades。
开源计算机视觉(OpenCV)是一个开源计算机视觉和机器学习软件库,最初由英特尔开发。它旨在为计算机视觉应用提供通用基础设施,并加速机器感知在商业产品中的使用。OpenCV 带有多个预训练的 Haar 级联,可以检测眼睛、面孔、俄罗斯车牌、笑容等。
对于面部检测,你需要haarcascade_frontalface_default.xml文件,你可以从前一段中的 GitHub 链接下载。
安装 OpenCV
让我们尝试使用 OpenCV 进行面部检测。首先,你需要使用以下命令来安装它:
!pip install opencv-python
对于本文中的示例,你需要创建一个名为face_detection.py的文件。首先用以下语句填充它以导入 OpenCV 库:
import cv2
从摄像头读取
接下来的步骤是连接到你的摄像头并在屏幕上显示图像:
import cv2
# default webcam
stream = cv2.VideoCapture(0)
while(True):
# Capture frame-by-frame
(grabbed, frame) = stream.read()
# Show the frame
cv2.imshow("Image", frame)
key = cv2.waitKey(1) & 0xFF
if key == ord("q"): # Press q to break out of the loop
break
# Cleanup
stream.release()
cv2.waitKey(1)
cv2.destroyAllWindows()
cv2.waitKey(1)
要引用你的摄像头,使用VideoCapture类,并传递一个数字表示你的摄像头实例(0表示第一个摄像头,1表示第二个摄像头,以此类推):
stream = cv2.VideoCapture(0)
要不断从摄像头捕捉输入,使用无限循环(while(True)),然后读取每一帧并显示出来:
# Capture frame-by-frame
(grabbed, frame) = stream.read()
# Show the frame
cv2.imshow("Image", frame)
为了让程序能够优雅地退出,等待用户按下键盘上的一个键。当按下“q”键时,循环终止:
key = cv2.waitKey(1) & 0xFF
if key == ord("q"): # Press q to break out of the loop
break
然后你可以进行清理工作:
# Cleanup
stream.release()
cv2.waitKey(1)
cv2.destroyAllWindows()
cv2.waitKey(1)
要运行程序,请转到终端并输入:
$ python face_detection.py
你现在应该能看到你的面孔:

检测面孔
现在进入有趣的部分——检测面孔。首先,创建一个CascadeClassifier类的实例,并将haarcascade_frontalface_default.xml文件传递给它:
import cv2
# for face detection
face_cascade = cv2.CascadeClassifier('haarcascade_frontalface_default.xml')
你需要复制
haarcascade_frontalface_default.xml文件并将其放在与face_detection.py文件相同的文件夹中。你可以从github.com/opencv/opencv/tree/master/data/haarcascades下载 XML 文件。
现在你可以使用detectMultiScale()函数来检测面孔:
while(True):
# Capture frame-by-frame
(grabbed, frame) = stream.read()
#===============DETECTING FACES============
# Convert to grayscale
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
# Try to detect faces in the webcam
faces = face_cascade.detectMultiScale(gray,
scaleFactor=1.3,
minNeighbors=5)
# for each faces found
for (x, y, w, h) in faces:
# Draw a rectangle around the face
color = (0, 255, 255) # in BGR
stroke = 5
cv2.rectangle(frame, (x, y), (x + w, y + h),
color, stroke)
#===============DETECTING FACE=============
# Show the frame
cv2.imshow("Image", frame)
key = cv2.waitKey(1) & 0xFF
if key == ord("q"): # Press q to break out of the loop
break
注意detectMultiScale()函数中的以下参数:
-
scaleFactor参数允许你将捕获的图像重新缩放到一个新的尺寸,以便算法可以检测到面部。 -
minNeighbors参数指定每个候选矩形应该有多少邻居才能保留它。这个参数会影响检测到的面部质量。较高的值会导致检测较少,但质量更高。通常,4 到 6 是一个较好的数字。
你可以调整这两个参数的值,以确保正确检测到面部。
当检测到面部时,你会想要在它们周围绘制矩形:
# for each faces found
for (x, y, w, h) in faces:
# Draw a rectangle around the face
color = (0, 255, 255) # in BGR
stroke = 5
cv2.rectangle(frame, (x, y), (x + w, y + h),
color, stroke)
当你再次运行face_detection.py文件时,你现在应该能够检测到面部:

face_detection.py 文件的全部内容如下:
import cv2
# for face detection
face_cascade = cv2.CascadeClassifier('haarcascade_frontalface_default.xml')
# default webcam
stream = cv2.VideoCapture(0)
while(True):
# Capture frame-by-frame
(grabbed, frame) = stream.read()
# Convert to grayscale
gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
# Try to detect faces in the webcam
faces = face_cascade.detectMultiScale(gray, scaleFactor=1.3, minNeighbors=5)
# for each faces found
for (x, y, w, h) in faces:
# Draw a rectangle around the face
color = (0, 255, 255) # in BGR
stroke = 5
cv2.rectangle(frame, (x, y), (x + w, y + h),
color, stroke)
# Show the frame
cv2.imshow("Image", frame)
key = cv2.waitKey(1) & 0xFF
if key == ord("q"): # Press q to break out of the loop
break
# Cleanup
stream.release()
cv2.waitKey(1)
cv2.destroyAllWindows()
cv2.waitKey(1)
如果你喜欢阅读我的文章,并且它对你的职业/学习有所帮助,请考虑成为 Medium 的会员。每月费用为 5 美元,它为你提供对 Medium 上所有文章(包括我的)的无限访问。如果你使用以下链接注册,我将获得少量佣金(对你没有额外费用)。你的支持意味着我将能花更多时间写出像这样的文章。
阅读魏梦李(以及 Medium 上其他成千上万的作家)的每一个故事。你的会员费直接支持…
weimenglee.medium.com](https://weimenglee.medium.com/membership?source=post_page-----316ded4d116f--------------------------------)
总结
我希望这篇简短的文章为你提供了一种使用 Python 和你的网络摄像头检测面部的简单方法。务必下载 haar cascades XML 文件,并将其放置在与你的 Python 文件相同的文件夹中。在未来的文章中,我将向你展示一些面部识别的技术。敬请期待!
什么是负二项分布
探索一种较不为人知的概率分布
·发表在 Towards Data Science ·阅读时间 5 分钟·2023 年 8 月 17 日
--

图片由Alperen Yazgı拍摄,来源于Unsplash
背景
也许你听说过二项分布,但你听说过它的“亲戚”负二项分布吗?这种离散概率分布被应用于许多行业,如保险和制造业(主要是基于计数的数据),因此对于数据科学家来说是一个有用的概念。在本文中,我们将深入探讨这种分布及其解决的问题。
补充视频。
什么是负二项分布?
要理解负二项分布,重要的是对二项分布有直观的认识。
二项分布衡量在给定的试验次数n中获得特定数量成功的概率,x。这些试验是伯努利试验,其中每个结果都是二元的(成功或失败)。如果你对二项分布不熟悉,可以查看我之前关于它的帖子:
理解二项分布的基本构建块
负二项分布翻转了这个模型,模拟了达到一定成功次数 r 所需的试验次数 x。这就是它被称为“负”的原因,因为它在无意中模拟了在达到一定成功次数之前的失败次数。
思考负二项分布的更好方式是:
在“x”次试验中发生“r”次成功的概率
负二项分布的一个特殊情况是 几何分布。它模拟了我们在获得第一次成功之前需要的试验次数。你可以在这里阅读更多关于几何分布的信息:
几何分布的简单描述和用途
towardsdatascience.com
关键假设
下面是负二项分布数据的主要假设:
-
每次试验有两个结果(伯努利)
-
每次试验都是独立的
-
成功的概率是恒定的
公式与推导
假设我们有:
-
p: 成功的概率
-
1-p: 失败的概率
-
x: 成功的 r 次试验的次数
-
r: 成功的次数 对于 x 次试验
因此,我们必须在 x-1 次试验中有 r-1 次成功,其概率简单地是二项分布的 概率质量函数 (PMF):

作者用 LaTeX 生成的方程。
我们得到的下一条信息是 r 次成功必须发生在 x 次试验中,它将有 p 的概率。因此,我们只需将上述公式乘以 p:

作者用 LaTeX 生成的方程。
这就是负二项分布的 PMF!
分布的均值可以表示为:

作者用 LaTeX 生成的方程。
均值和标准差的推导可以在 这里 找到
示例问题
在第六次掷骰子时,再次掷出 4 的概率是多少?
-
p = 1/6
-
r = 2
-
x = 6
将这些输入到上述 PMF 中得到:

作者用 LaTeX 生成的方程。
因此,我们在第六次掷骰时获得第二个 4 的可能性相当小。你还可以用这个负二项分布计算器来尝试你的计算。
如果我们想知道在其他掷骰中获得第二个 4 的概率呢?为了做到这一点,我们需要将第二次掷骰作为掷骰次数的函数来绘制,x:
import plotly.graph_objects as go
from math import comb
# Parameters
r = 2
p = 1 / 6
# PMF
def neg_binomial_pmf(x, r, p):
if x < r:
return 0
q = 1 - p
return comb(x - 1, r - 1) * (p ** r) * (q ** (x - r))
# Values
x = list(range(1, 30))
probs = [neg_binomial_pmf(k, r, p) for k in x]
# Plot
fig = go.Figure(data=[go.Bar(x=x, y=probs, marker_color='rgba(176, 224, 230)')])
fig.update_layout(title="Negative Binomial Distribution",
xaxis_title="x (number of trials to get second 4)",
yaxis_title="Probability",
template="simple_white",
font=dict(size=16),
title_x=0.5,
width=700,
height=500)
fig.show()

由作者在 Python 中生成的图表。
我们可以看到,获得第二个 4 的最可能掷骰结果是 6 和 7。然而,期望值为 12(2/(1/6)),这可以从我们之前展示的公式中推导出来。
数据科学中的应用
以下是负二项分布应用的一些领域:
-
事件发生的时间: 这对流失模型非常有用,我们希望预测客户可能何时取消订阅。如果我们知道何时以及谁会流失,我们可以采取专门的留存策略来试图留住客户。
-
缺陷预测: 预测在制造产品成为完全功能之前的缺陷数量。你可以把它想象成在我们达到最终提案之前要生产多少个产品。
-
体育分析: 例如预测一个足球运动员在多少次错失机会后能进一个球。这对博彩公司制定赔率非常有用。
-
营销: 确定在客户转换为订阅或点击网站之前需要展示多少广告。这是在预测转换率。
-
流行病学: 估计濒危物种的数量以及环境如何影响它们的数量。
总结与进一步思考
负二项分布模型描述了达到一定成功次数所需的失败次数的概率。这在许多数据科学领域都有应用,其中最显著的是客户流失预测。因此,对于数据科学家来说,这是一个有用的主题。
完整的代码可以在我的 GitHub 上找到:
[## Medium-Articles/Statistics/Distributions/negative_binomial.py at main · egorhowell/Medium-Articles
我在 Medium 博客/文章中使用的代码。通过创建一个账户来贡献 egorhowell/Medium-Articles 的开发…
参考资料与进一步阅读
另一个消息!
我有一份免费的通讯,Dishing the Data,在其中我分享成为更优秀数据科学家的每周技巧。没有“废话”或“点击诱饵”,只有来自实践数据科学家的纯粹实用见解。
[## Dishing The Data | Egor Howell | Substack
如何成为更优秀的数据科学家。点击阅读《Dishing The Data》,由 Egor Howell 撰写,Substack 出版物…
与我联系!
Falcon 180B:它能在你的计算机上运行吗?
原文:
towardsdatascience.com/falcon-180b-can-it-run-on-your-computer-c3f3fb1611a9
是的,如果你有足够的 CPU 内存
·发布于 Towards Data Science ·阅读时间 7 分钟·2023 年 9 月 12 日
--

2023 年 5 月,阿布扎比技术创新研究所(TII)发布了两个预训练的 LLM:Falcon-7B 和 Falcon-40B 及其聊天版本。这两个模型表现非常出色,并在 OpenLLM 排行榜 上排名第一。
TII 发布的第三个模型刚刚加入了 Falcon 家族:Falcon 180B,一个具有 1800 亿参数的模型。它比 Llama 2 70B 多了 2.5 倍的参数,比 Falcon-40B 多了 4.5 倍。
以下是关于 Falcon 180B 的一些事实(来源:Falcon 180B 模型卡):
-
经过 3.5 万亿个标记的预训练 (RefinedWeb)
-
以 Apache 2.0 许可证分发
-
大小为 360 GB
-
在 OpenLLM 排行榜 上排名第一(截至 2023 年 9 月 11 日):

OpenLLM 排行榜截图(2023 年 9 月 11 日) — 作者提供的图片
还有一个聊天版本。模型可以在 Hugging Face hub 上获取:
Falcon 180B 完全免费且技术先进。但它也是一个庞大的模型。
它能在你的计算机上运行吗?
除非你的计算机准备好进行非常高强度的计算,否则无法开箱即用运行 Falcon 180B。你需要升级计算机并使用该模型的量化版本。
在这篇文章中,我解释了如何在消费级硬件上运行 Falcon-180B。我们将看到在现代计算机上运行一个 1800 亿参数的模型是相对负担得起的。我还讨论了几种有助于减少硬件要求的技术。
在你的计算机上加载 Falcon 180B:你需要什么?
首先你需要知道的是,Falcon 180B 具有 1800 亿个参数,存储为 bfloat16。一个 (b)float16 参数在内存中占用 2 字节。
当你加载一个模型时,标准的 Pytorch 流程如下:
-
一个空模型被创建:180B 参数 * 2 字节 = 360 GB
-
将权重加载到内存中:180B 参数 * 2 字节 = 360 GB
-
将第 2 步加载的权重加载到第 1 步创建的空模型中。
-
将第 3 步获得的模型移动到用于推理的设备上,例如 GPU。
第 1 步和第 2 步是消耗内存的步骤。总的来说,你需要 720 GB 的可用内存。这可以是 CPU RAM,但为了快速推理,你可能想使用 GPU,例如 9 个带有 80 GB VRAM 的 A100。
无论是 CPU RAM 还是 VRAM,这都是大量的内存。幸运的是,这些要求可以很容易地减少。
在 Hugging Face Hub 上,Falcon 180B 采用 safetensors 格式分发。这个格式相较于标准的 Pytorch 格式有几个优势。它(几乎)没有复制,因此模型直接加载到第 1 步创建的空模型中。这节省了大量内存。
关于 safetensors
safetensors 节省内存,但它还使模型运行更安全,因为在这种格式中无法存储任意代码。safetensors 模型的加载速度也更快。当你从 hub 下载模型时,使用这种格式代替“.bin”格式,可以实现更快、更安全且节省内存的加载。
尽管看起来我们跳过了第 2 步,但仍然会有一些内存开销需要预期。TII 在模型卡上写道,400 GB 的内存是可行的。这仍然很多,但比使用标准的 Pytorch 格式少 220 GB。
我们需要一个具有 400 GB 存储空间的设备,例如,5 个带有 80 GB VRAM 的 A100 GPU。我们距离“消费级”配置还有很大差距。
将 Falcon 180B 分割到多个内存设备上。
你可能没有一个 400 GB 的单一内存设备,但如果你结合所有以下设备的内存,你的计算机可能有超过 400 GB 的内存:
-
GPU 的 VRAM:如果你有一块 NVIDIA RTX 3090 或 4090,那已经有 24 GB。
-
CPU RAM:大多数现代计算机至少有 12 GB 的 CPU RAM。扩展 CPU RAM 也非常便宜。
-
硬盘(或 SSD):这可以是几个 TB 的空闲内存。请注意,如果你计划使用 SSD(NVMe M2 类型)来运行 LLM,它将比典型的硬盘快得多。
为了利用可用设备,我们可以将 Falcon 180B 切分,以便它按照优先顺序使用设备的最大内存:GPU、CPU RAM 和硬盘。
通过Accelerate的 device_map 可以轻松实现这一点。
device_map 将模型的整个层分配到你拥有的不同设备上。

device_map — 作者提供的图片
如果你想查看一些 device_map 使用的例子,查看我的笔记本。我在大多数笔记本中使用了 device_map。
device_map 非常方便可以避免 CUDA 内存不足错误。但如果你打算在消费级硬件上使用 Falcon 180B,这仍然远非理想。即使配置高端,配备 24 GB 的 VRAM 和 32 GB 的 CPU RAM,也会在硬盘上留下几百 GB 的空间。
这是一个问题,原因有两个:
-
硬盘和 SSD 的速度比 VRAM 和 CPU RAM慢得多。从硬盘加载和运行 Falcon 180B 会花费非常长的时间。
-
消费级硬盘和 SSD并未设计和测试用于这种密集使用。如果模型的许多部分被卸载到硬盘上,系统将不得不在推理过程中多次访问和读取这些巨大的模型分片。这是一个长期的巨大读写操作数量。如果你进行几天的推理,例如生成一些合成数据集,这可能会损坏你的硬盘,或者至少显著减少其寿命。
为了避免过度使用硬盘,我们没有太多解决方案:
-
添加一张 GPU:大多数高端主板可以容纳两张 RTX 3090/4090。这将为你提供 48 GB 的 VRAM。
-
扩展 CPU RAM:大多数主板有 4 个可用于 CPU RAM 套件的插槽。虽然有 4128GB 的 CPU RAM 套件出售,但不容易找到,而且仍然很昂贵。注意:操作系统对 CPU RAM 的总支持量也有限制。对于 Windows 10,它是 2 TB。如果你使用的是旧版本操作系统,在购买更多 RAM 之前应查看其文档。*
-
量化 Falcon 180B 并扩展 CPU RAM。
Falcon 180B 的量化是减少内存消耗的最佳选择之一。
通过量化减少 Falcon 180B 的大小
量化非常大的语言模型到较低精度现在已成为一种常见做法。GPTQ和bitsandbytes nf4是将 LLM 量化到 4 位精度的两种流行方法。
Falcon 180B 使用 bfloat16。我们看到它是 360 GB。
一旦量化到 4 位精度,它仅为 90 GB(1800 亿参数 * 0.5 字节)。我们可以使用 100 GB 的内存(90GB + 一些内存开销)来加载 4 位 Falcon 180B。
如果你有 24 GB 的 VRAM,你“仅需”75 GB 的 CPU RAM。这仍然很多,但比加载原始模型要实惠得多,而且在推理过程中不会将模型的层卸载到硬盘上。注意:你仍然需要在硬盘上留出 100 GB 的自由空间以存储模型。
你甚至不需要 GPU。拥有 128GB 的 CPU 内存,你可以仅使用 CPU 进行推理。
量化本身是非常昂贵的。幸运的是,我们已经可以在网上找到量化版本。TheBloke发布了使用 GPTQ 制作的 4-bit 版本:
注意:还有 3-bit 模型作为这些模型的“分支”提供。请遵循 4-bit 模型卡上的说明来获取这些模型。
尽管模型的精度有所降低,但根据Hugging Face 的实验,模型的性能保持相似。
GPTQ 模型推理速度快,你可以使用 LoRA 适配器进行微调。例如,我在这篇文章中展示了如何微调使用 GPTQ 量化的 Llama 2。
## 使用 Transformers 和 TRL 进行量化和微调 LLMs
GPTQ 现在使用起来容易多了
虽然可以对 GPTQ 模型进行微调,但我不推荐这样做。使用 QLoRA 进行微调虽然内存消耗相似,但由于 nf4 量化的改进,能获得更好的模型,正如QLoRA 论文中所示。
结论
总结一下,你需要量化和 100GB 内存来在一个相对实惠的电脑上运行 Falcon 180B。
对于快速推理或微调,你需要一个 GPU。RTX 4090(或更便宜但较慢的 RTX 3090 24GB)足以加载量化模型的 1/4。如果你的电脑机箱有足够的空间,你甚至可以放入两张 RTX 显卡。
如果你想在仅 CPU 配置上运行 Falcon-180B,即没有 GPU,忘掉微调吧,这会太慢。推理也会很慢,但使用近期的高端 CPU 和优化过的软件,如 llama.cpp,运行 Falcon 180B 还是可能的。
如果你对使用llama.cpp感兴趣,可以看看我关于如何仅使用 CPU 运行 Vicuna 的文章:
## 高速推理:使用 llama.cpp 和 Vicuna 在 CPU 上
快速推理不需要 GPU
鹰:开源大型语言模型的巅峰
原文:
towardsdatascience.com/falcon-the-pinnacle-of-open-source-llms-600de69c333c
开源 LLMs 与专有 LLMs 之间的差距持续缩小…
·发表于 Towards Data Science ·阅读时间 14 分钟·2023 年 10 月 24 日
--

(照片由 Alan Mersom 提供,来源于 Unsplash)
最近,开源大型语言模型(LLMs)的研究主要集中在两个方面:模仿学习和预训练开源基础模型。虽然这两种方法都是可行的,但创建高质量的开源基础模型尤其令人振奋,因为这些模型可以进一步微调(成本较低)并用于各种不同的下游应用。最初尝试创建这些模型的结果并不理想。尽管后来的模型(例如,LLaMA 和 MPT-7B)表现更佳,但这些模型在质量上仍难以匹敌其专有对手(例如,GPT-3.5 或 GPT-4),直到最近才有所改进。
随着 Falcon-7B 和 Falcon-40B LLMs [1] 的发布,我们第一次看到开始与最受欢迎的付费模型质量相媲美的开源基础 LLMs。这些模型在通过新颖的数据管道获得的大规模文本语料库上进行训练,取得了(以相当大的优势)新的开源 LLMs 的最先进性能,并且可以自由用于商业应用。更棒的是,Falcon 模型对其底层的 Transformer 架构进行了若干修改,这些修改显著加快了推理速度,甚至可以提高预训练的效率。

(来自 [1, 2])
大局观。 创建一个 LLM 的过程包含几个步骤;见下文。这个过程的第一步(即,获得预训练基础模型)被广泛认为是最昂贵的,无论是在金钱还是时间上。

创建和完善 LLM 的多步骤过程(来自 [16, 17])
这些模型之前被保留在专有 API 后面,但开源 LLM 的进步使高性能基础 LLM 更加公开。Falcon 是这个类别中的另一个模型,相比其他开源替代品,它达到了前所未有的性能水平。
使用 Web 数据进行 LLM 预训练

(来自 [3])
当我们探讨预训练与微调(即,SFT 和 RLHF)语言模型之间的主要差异时,会发现预训练比微调要困难得多(且更昂贵);见上文。预训练有两个根本性特征使其如此困难:
-
模型是从头开始训练的,因此需要更多的训练迭代次数。
-
预训练数据集必须大且多样化(即,提供尽可能多的“覆盖”),以便生成的 LLM 拥有较大的知识基础。
简而言之,预训练数据集非常庞大(例如,Chinchilla [6] 在 1.4 万亿文本令牌上训练),这意味着预训练的范围不可避免地很大。我们必须进行大量的训练迭代才能遍历所有这些数据!
创建预训练数据集。 然而,数据集的大小不仅是使预训练成为如此庞大任务的原因。仅仅策划数据集就是一个复杂的过程,涉及到检索数据和执行整个过滤(例如,基于数据质量、污染数据、PII 等)和去重步骤的管道。已经提出并探索了各种不同的处理步骤来策划 LLM 预训练数据;见 这里。
尽管最初可能认为这些处理步骤可以简化或避免,但 LLM 研究一次又一次地向我们展示了模型训练数据的质量是极其重要的。例如,我们可以看到 LIMA [7] 或 Galactica [8],这两个模型都在较小(但高质量)的文本语料库上训练,结果与在更大规模的噪声数据集上训练的相同模型的性能相匹配或超越。

(来自 [4, 5])
当前的 LLM。 由于数据质量对模型质量的影响,大多数 LLM 使用的预训练数据来自高度策划的来源,例如筛选过的文本内容、书籍、代码或技术报告;见上文 [4, 5]。实际上,许多策划过的预训练数据公共来源在线上随处可见(例如,the Pile 或 C4),并且已经被现有模型广泛使用。
“策展被认为是产生高性能模型所必需的……然而,随着模型需要在万亿个标记上进行预训练,尚不清楚策展是否具有可扩展性,以及我们是否会很快耗尽独特的高质量数据。” — 来源于 [2]
然而,使用策展数据源是否真的具有可扩展性仍然值得怀疑——随着预训练数据集规模的扩大,细粒度的过滤和策展变得越来越困难。因此,对于较大的模型和数据集,广泛的策展可能变得不再那么必要,特别是考虑到为 LLM 预训练找到足够的数据变得越来越困难。
RefinedWeb: 可扩展的网页文本策展

(来自 [2])
鉴于这些局限性,Falcon [1]的作者探讨了可扩展和高效的数据策展方法,这些方法可以推广到大量数据。用于创建RefinedWeb预训练数据集的完整数据策展流程,Falcon-7B [10]和 Falcon-40B [11]即在该数据集上进行预训练,详细描述见[2]。RefinedWeb 由经过简化过滤流程的网页数据组成,可以用于训练比在策展数据源上训练的类似模型性能更优的模型;见上文。这一发现表明,大规模训练语料库可以有效地从专门从互联网获得的数据中创建(而不是策展数据源)。
“挑战了对数据质量和 LLMs 的现有信念,单独使用经过适当过滤和去重的网络数据训练的模型可以达到与训练在策展数据上的模型相当的性能。” — 来源于 [2]
简化的策展流程。用于 Falcon 的预训练数据集基于Common Crawl。与之前的工作相比,[2]中的作者通过强调规模来区分他们的数据策展流程,旨在从网络中生成一个包含 3-6 万亿个数据标记的预训练语料库。这远远超过了之前工作的数据集——即使是 MassiveText(即用于预训练 Gopher [9]和 Chinchilla [6]的语料库)也仅包含 2.3 万亿个标记的文本。而且,现有模型在预训练过程中仅使用了这些数据的一个子集。

(来自 [2])
[2] 中构建的语料库包含大约 5 万亿个仅含英语的数据标记,完全从网络上获取;见上文。尽管数据规模庞大,作者在构建过程中采用了严格的去重策略,这一策略以较高的速度去除准确的和 模糊重复。在这种去重之外,只进行了最小的额外过滤。实际上,除了通过 FastText 分类器 进行语言识别外,没有进行其他基于机器学习的过滤。

(来自 [9])
除了过滤非英语文档,[2] 中的作者还使用了几个简单的启发式方法来过滤不需要的内容,例如:
-
从与黑名单关联的 URL 中过滤内容
-
使用 trafilatura 从网页中提取内容
-
定义简单规则来识别和过滤 PII
此外,采用了 MassiveText [9] 过滤管道中的几个步骤;见上文。RefinedWeb 的完整策划管道移除了原本从 CommonCrawl 获得的近 90% 数据,但在完全过滤后的语料库中仍保留了超过 5 万亿个标记的文本数据;见下文。此外,这个数据集是“多模态友好的”,因为它包含了指向图片的链接和替代文本。

(来自 [2])
开源预训练数据。 生成的 RefinedWeb 数据集非常庞大,表明互联网上有足够的数据可以进行前所未有规模的 LLM 预训练。换句话说,我们并没有(至少暂时没有)数据短缺。然而,[2] 中的作者仅公开了该语料库中一个小的 600B 标记子集(即 12% 的数据)。尽管其规模较小,但这个公开的 RefinedWeb 子集仍然是任何从事 LLM 工作的实践者有用的预训练数据来源。事实上,在 [2] 中,基于这些数据训练的小型 Falcon 变体相比于在策划语料库上训练的模型表现更为优越;详见下文。

(来自 [2])
Falcon 系列 LLM

(来自 TII Falcon 网页 [1])
Falcon 系列 LLM [1],包括 Falcon-7B [10] 和 Falcon-40B [11],在开源模型中实现了最先进的性能。此外,这些模型的指令调优变体,如 Falcon-40B-Instruct(即 HuggingFace 的 Open LLM 排行榜 上的顶级模型),在公共基准测试中表现更佳。正如我们将看到的,这些模型具有几个关键特性(例如,数据、性能和推理速度),使其独特且实际有用。
商业许可证。 起初,Falcon 模型以一种特殊的许可证发布,该许可证要求在模型(或其任何衍生物)用于商业应用时支付版税。然而,在这一初始发布之后不久,该许可证被修改为普通的 Apache 2.0 许可证,这意味着 Falcon 基础模型现在可以在商业应用中免费使用!与其他开源和可商业使用的 LLMs(即使是 MPT-30B)相比,Falcon-40B 实现了独特的卓越性能。
Falcon-7B 和 40B 数据集

(来自 [10, 11])
如前所述,数据集 [2] 被用于对开源 Falcon-7B/40B 模型进行预训练。然而,这些模型仅在完整的 5 万亿标记语料库的一个子集上进行预训练(即 Falcon-7B 为 1.5 万亿标记,Falcon-40B 为 1 万亿标记)。然而,这个语料库的子集通过额外的、精心挑选的数据(如书籍、代码和技术内容)进行了增强;见上文。由于其更大的规模,Falcon-40B 的训练数据量少于 Falcon-7B。尽管如此,较大的模型仍然表现更好,训练时间超过两个月,而 Falcon-7B 仅需两周。
多语言 LLMs。 RefinedWeb 语料库仅包含英文数据,而 Falcon-7B 仅使用英文文本进行训练。然而,Falcon-40B 的预训练集中增加了 RefinedWeb-Europe 语料库,该语料库包含来自各种常见欧洲语言的文本数据。尽管这些数据仅占预训练语料库的 7%,但它将少量多语言数据注入到模型的知识库中,使其在需要基本多语言理解的公共基准测试中表现更佳。
Falcon 架构
Falcon-7B 和 Falcon-40B 模型都使用了修改过的解码器-only transformer 架构。对该架构的修改包括:
-
Flash Attention
-
RoPE 嵌入
-
Multi-Query Attention
-
并行注意力和前馈层
这些修改(其中一些与 MPT 套件的 LLMs 共享)极大地提高了 Falcon 的推理速度。实际上,Falcon-40B 的推理速度是 GPT-3 的 5X。此外,由于这些修改,预训练 Falcon-40B 的成本较低;例如,Falcon-40B 需要 75% 的 GPT-3 [4] 计算预算,40% 的 Chinchilla [6] 计算预算和 80% 的 PaLM-62B [5] 计算预算。Falcon-7B 和 Falcon-40B 的训练序列长度为 2K 标记,相较于最近的 LLMs(如 StoryWriter 和 Claude)可以说较小。

(来自 [10, 11])
Falcon-7B/40B 共享相同的模型架构,但 40B 变体稍微更深(即 60 层对比 32 层)且具有更高维度的隐藏层;见上文。使用 Falcon-40B 需要约 90Gb 的内存,这比类似的模型(如 LLaMA-65B)有更低的开销。然而,Falcon-40B 仍不能像 MPT-30B 一样在单个 GPU 上托管。鉴于其较小的规模,Falcon-7B 仅需约 15Gb 的内存,使其在推理和微调中更具可及性。
RoPE 嵌入。 正如我们在之前的概述中看到的,自注意力操作(该操作在语言模型的解码器单一转换器架构的每一层中实现)并不会自然地考虑序列中每个标记的位置。因此,我们必须为每个标记注入位置信息(例如,通过加性位置嵌入)到此操作中;见下文。

转换器的加性位置嵌入(由作者创建)
提出了几种位置嵌入变体——这些变体可能在训练过程中学习每个嵌入,也可能不学习,包括绝对嵌入和相对嵌入。然而,旋转位置嵌入(RoPE)[15]是一个将每个标记的绝对位置(即序列中的全局位置)和相对位置(即基于标记之间的距离定义位置)结合到自注意力中的替代方案,如下所示:
-
用旋转矩阵编码绝对位置
-
将相对位置信息直接添加到自注意力操作中
这种方法被发现能够在绝对位置和相对位置信息之间取得平衡,这对需要较长序列长度的任务尤其有利。因此,RoPE 嵌入最近获得了广泛关注,导致其被用于如 PaLM [5]等模型中。有关更详细的概述,请查看 RoPE 嵌入的这里。
多查询注意力。 Falcon 模型用一种称为多查询注意力的替代结构替换了典型的多头自注意力操作。多查询注意力仅在每层的注意力头之间共享键和值向量(如下图红色高亮部分)。所有头部共享相同的键的投影矩阵和相同的值的投影层,而不是为每个头部执行单独的投影。尽管这一变化并未加快训练速度,但显著提高了最终 LLM 的推理速度。

LLM 中的多查询注意力(来自[18])
并行注意力层。 最后,Falcon 模型在其架构的每一层结构中做出了一项基本变化。与“串行”解码器仅限变体的层不同,Falcon 模型的每一层都并行执行自注意力和前馈变换,然后进行单层归一化操作。这个公式与标准的变换器块之间的差异如下所示。有趣的是,这种并行公式并没有降低模型的性能。然而,由于变换器层的两个主要操作是并行进行的,因此它可能在推理速度上带来好处。

解码器仅限的变体层(由作者创建)
开源语言模型的新标准!
在撰写时,关于 Falcon 模型的手稿尚未发布。然而,Falcon-7B 和 Falcon-40B(以及它们的指令调优变体)已通过 Open LLM Leaderboard 进行评估,该榜单包括多个基准测试,例如:
通过该榜单进行的评估是不完整和初步的。然而,这些评估虽然在合理范围内捕捉了模型性能,但清楚地显示了 Falcon-40B 是当前开源语言模型的最先进技术;见下文。

(来自 Open LLM Leaderboard)
Falcon-40B 的指令变体(即 Falcon-40B-Instruct),它在来自 Baize 的数据混合上进行了指令调优,远远超越了各种其他开源模型。此外,预训练的 Falcon-40B 模型表现也相当不错,甚至优于像 LLaMA-65B 和 MPT-30B 这样的著名基础模型。进一步说,Falcon-40B 也可用于商业用途,而榜单上的许多可比模型(例如 LLaMA [13]、Guanaco [14] 和 Lazarus)仅可用于研究目的。
Falcon 的实际使用。 鉴于其性能非常出色,与其他语言模型相比托管起来轻量(由于改进的推理速度),并且可以在商业应用中自由使用,Falcon LLMs 是任何从事 LLM 工作的实践者的重要开源工具。幸运的是,已经撰写了几篇详细的概述,概述了在实践中微调和托管/部署这些模型的有用框架。
-
在 AWS Sagemaker 上部署 Falcon-40B [link]
-
使用 Falcon 进行推理和参数高效微调 [link]
-
使用 PyTorch Lightning 对 Falcon-40B 进行微调 [link]
鉴于 Falcon 使用 AWS 进行训练,目前有相当数量的解释性文章介绍如何在类似硬件上部署和训练这些模型。这些文章为任何希望在自己的用例中利用 Falcon 的人提供了一个良好的起点。
最终思考
Falcon 的发布是开源 LLM 研究和应用的重大突破。当我们审视这些模型的独特贡献时,我们立即看到一些关键组件,这些组件导致了成功:
-
大规模预训练数据的独特混合
-
针对效率优化的架构
RefinedWeb 数据集显示,文本语料库可以在比以前探索的更大规模上创建。为此,我们只需从网络上下载大量数据,并采用严格的去重规则以及更简单高效的过滤启发式方法。然后,通过将这些庞大的数据源与少量精选文本相结合,我们可以预训练一个性能极佳的开源 LLM。最后,Falcon 模型的修改架构使得训练和推理更加高效,结果是一个表现出色且在部署时能快速生成文本的模型。
与我联系!
非常感谢阅读这篇文章。我是 Cameron R. Wolfe,Rebuy 的 AI 总监。我研究深度学习的经验和理论基础。如果你喜欢这篇概述,请订阅我的 Deep (Learning) Focus newsletter,我通过从头开始概述相关主题帮助读者理解 AI 研究。你还可以在 X 和 LinkedIn 上关注我,或者查看我在 medium 上的 其他著作!
参考文献
[1] “Introducing Falcon LLM”, 技术创新研究所, 2023 年 6 月 7 日, falconllm.tii.ae/.
[2] Penedo, Guilherme 等. “The RefinedWeb dataset for Falcon LLM: outperforming curated corpora with web data, and web data only.” arXiv preprint arXiv:2306.01116 (2023).
[3] “Introducing MPT-7B: A New Standard for Open-Source, Commercially Usable Llms.” MosaicML, 2023 年 5 月 5 日, www.mosaicml.com/blog/mpt-7b.
[4] Brown, Tom 等. “Language models are few-shot learners.” Advances in neural information processing systems 33 (2020): 1877–1901.
[5] Chowdhery, Aakanksha 等. “Palm: 利用路径扩展语言建模。” arXiv 预印本 arXiv:2204.02311 (2022)。
[6] Hoffmann, Jordan 等. “训练计算最优的大型语言模型。” arXiv 预印本 arXiv:2203.15556 (2022)。
[7] Zhou, Chunting 等. “Lima: 对齐的‘少即是多’。” arXiv 预印本 arXiv:2305.11206 (2023)。
[8] Taylor, Ross 等. “Galactica: 一种用于科学的大型语言模型。” arXiv 预印本 arXiv:2211.09085 (2022)。
[9] Rae, Jack W. 等. “扩展语言模型:方法、分析与训练 gopher 的见解。” arXiv 预印本 arXiv:2112.11446 (2021)。
[10] “Falcon-7B”,技术创新研究所,HuggingFace 页面,huggingface.co/tiiuae/falcon-7b.。
[11] “Falcon-40B”,技术创新研究所,HuggingFace 页面,huggingface.co/tiiuae/falcon-40b。
[12] Gao, Leo 等. “The pile: 一个 800GB 的多样文本数据集用于语言建模。” arXiv 预印本 arXiv:2101.00027 (2020)。
[13] Touvron, Hugo 等. “Llama: 开放且高效的基础语言模型。” arXiv 预印本 arXiv:2302.13971 (2023)。
[14] Dettmers, Tim 等. “Qlora: 量化语言模型的高效微调。” arXiv 预印本 arXiv:2305.14314 (2023)。
[15] Su, Jianlin 等. “Roformer: 增强的带有旋转位置嵌入的变换器。” arXiv 预印本 arXiv:2104.09864 (2021)。
[16] Ouyang, Long 等. “训练语言模型以遵循指令并获得人类反馈。” 神经信息处理系统进展 35 (2022): 27730–27744。
[17] Glaese, Amelia 等. “通过有针对性的人类判断提高对话代理的对齐。” arXiv 预印本 arXiv:2209.14375 (2022)。
[18] Vaswani, Ashish 等. “注意力即你所需。” 神经信息处理系统进展 30 (2017)。
假先知:一个自制的时间序列回归模型
原文:
towardsdatascience.com/false-prophet-a-homemade-time-series-regression-model-54e296b99438
借用 Meta 的 Prophet 的思想来构建一个强大的时间序列回归模型。
·发表于 Towards Data Science ·阅读时间 16 分钟·2023 年 10 月 31 日
--

照片由 Niklas Rhöse 提供,来源于 Unsplash
在这篇后续文章中,我继续我的使命,通过结合流行的 Prophet 包¹ 和演讲“用简单甚至线性模型取胜”² 的思想来构建弗兰肯斯坦的时间序列怪物。
在回顾我们正在做的事情之后,我们将讨论回归模型——它是什么,以及它的特别之处。
然后,我们将使用时间序列交叉验证进行超参数调整,以获得“最佳”的模型参数化。
最后,我们将使用 SHAP 验证模型,然后利用模型形式进行定制调查和手动调整。
这是一个广泛的范围——让我们开始吧。
附注:我们在上一篇文章中已经涵盖了基础数据准备和特征工程,因此直接进入建模部分。赶紧了解一下之前的内容:
## 假先知:自制时间序列回归的特征工程(第一部分,共 2 部分)
基于 Meta 的 Prophet 包中的思想,创建强大的时间序列机器学习模型功能。
towardsdatascience.com
全局视角
让我们回顾一下我们在做什么。
终极目标很简单:在指定的时间范围内生成最准确的未来事件预测。
我们从一个仅包含日期变量和关注数量的时间序列开始。从中,我们衍生了额外的特征,以帮助我们准确建模未来的结果;这些特征深受 Prophet 方法的“启发”。
这就把我们带到现在的位置:大约准备将我们设计的数据输入一个轻量级模型,训练它以预测未来。稍后我们将深入了解模型的内部工作原理。
在继续之前,让我们回顾一下数据的样子。
数据
我们使用来自英国的实际数据——在这种情况下,STATS19 道路交通事故数据集,其中包含有关某些车祸的信息³。
我们将这个大型数据集总结为每月事故总数。这意味着我们的目标看起来像这样:

作者提供的图像
我们之前指出了下行趋势、强烈的季节性效应和明显的趋势变化。让我们看看我们的特征工程和建模如何协同工作来捕捉这些。
模型
虽然还不是圣诞节,但我确实有一些特征在我的模型愿望清单上。我将列出这些特征,并简要介绍为什么选择使用 LASSO 模型(剧透!)。
首先,模型应该尽可能强健。“强健”可以有很多含义,因此我们会更具体一些:我希望一个可以训练并放置一段时间的模型。换句话说,模型应该能够推断处理之前未见过但“正常”的输入值。
模型需要具有可解释性和可解释性。它应该易于查询和理解,并且简单易解释。理想情况下,我应该能够准确量化各种影响的影响。
最后,模型应该具有某种内置特征选择功能。理论上,适当执行特征选择可以提高模型性能,我可以进行严格的练习来找到最有效的特征集。实际上,现在是周六下午,人生太短暂。因此,模型将不得不自己解决。
尽管基于树的模型满足其中的一部分要求,但它们在推断方面表现不佳。它们也不能真正用于说明某些事情,比如“在六月,季节性导致 X%的变化”。
我想要的是一些通常线性的东西(提示提示),再加上一些额外的东西。
进入 LASSO
最小绝对收缩和选择算子(LASSO)是一种统计模型。它同时执行特征选择和正则化,以增强模型的预测能力和可解释性⁷。
我们不会过于关注细节,而是认识到 LASSO 是一种带有变化的线性回归方法。在模型训练过程中,使用了一个修改过的目标函数,其目的是在对回归系数施加某些约束的情况下,最大化模型的预测能力。这个约束——或称为“正则化”——有缩小特征系数接近零的效果,有效减少了给定特征对模型的影响。如果系数被缩小到零,该特征对模型拟合或预测将完全没有影响,特征实际上被从模型中移除。
这种特征系数缩减有效地减少了过拟合的可能性,因此应提高模型对新数据的泛化能力。
附注:LASSO 正则化有时被称为 L1 正则化或 L1 惩罚。这来自于正则化项的数学形式,该形式使用 L1 范数。
泛化事物
Scikit-learn 线性实现了 LASSO —— 即 Lasso 类假设目标可以表示为输入特征的线性组合。
让我们回顾一下我们正在建模的数据。我们可以看到序列的季节性成分的幅度似乎与趋势的幅度一致,这表明时间序列是乘性的而不是加性的。换句话说,我们的时间序列是时间序列元素的 乘积,而不是它们的 总和。这意味着在模型形式方面我们需要一点想象力。
幸运的是,我们不需要太有创意——通过对每月交通事故的自然对数建模,我们可以将加性 LASSO 转换为乘性模型。我们可以大致认为模型形式如下:

作者提供的图像
… 这意味着我们的预测看起来像:

作者提供的图像
附注:上述数学使用了“帽子”符号,这是在讨论估计值时的惯例。由于预测是从估计的 beta 值中得出的——这些 beta 值本身是根据数据得出的——我们对目标和回归系数都使用帽子符号。
稍后会详细讲解,但值得注意的是,当我们开始探讨模型时,要牢记这一点。
Alpha,正则化参数
我们已经确定了模型,现在需要确定其参数化。
在 sklearn 中,LASSO 的正则化强度通过 alpha 参数控制。我对这个超参数的最佳值没有太多直觉,因此需要进行一些搜索,记住“最佳”模型是预测效果最好的模型。
我们需要某种训练-测试分割方法来测试给定模型参数化在未来的预测效果:通常随机分割交叉验证方法效果不佳,因为目标的时间元素会混淆,甚至更糟——泄漏。
幸运的是,时间序列交叉验证是一个实际存在的技术,并且在 scikit-learn 中有很好的实现。我们将使用它来搜索从潜在值分布中随机抽取的各种alpha值,并使用多种指标评估预测准确性。
顺便说一下,如果你不熟悉时间序列交叉验证,可以在这里快速了解:
通过聪明的子划分最大化你的时间序列数据的效用。
经过一些搜索,我们可以可视化预测误差,并选择使其最小化的alpha值:

图片由作者提供
我们可以看到,当alpha接近 0.0017 时,均方误差(MAE)、均方根误差(RMSE)和平均绝对百分比误差(MAPE)都达到了最低值。
由于我们暂时没有比较模型,因此指标的实际值并不太重要。值得指出的是,MAPE 实际上不是一个百分比;我在上面的图表中将其格式化为百分比,以帮助我感受这些实际上相当小的数字。
到目前为止一切顺利。但一个具有这种alpha的模型实际预测效果如何?实际上还不错:

图片由作者提供
上面的图表是使用上述alpha参数的 LASSO 模型时间序列交叉验证结果的可视化表示。
每条彩色线代表了一个在预测期之前基于所有数据(灰色)训练的模型的折外预测。
例如,橙色线代表了基于截至 2010 年数据训练的模型对 2011 年的 12 个月预测。
预测结果有点难以看清,我们来放大一下:

图片由作者提供
对于如此嘈杂的实际数据,这些预测效果还不错:它们似乎捕捉到了年度趋势,并预测了更强的季节性变化。
然而,模型预测往往会错过每个时间段开始的前几个样本,然后通常会回到正常范围。
考虑到我们注意到在 2012–2014 年左右,底层数据中趋势和季节性发生了变化,我认为这个模型实际上表现得相当不错!
在继续进行模型调查之前,让我们使用这个alpha值来重新拟合整个数据范围的模型。

图片由作者提供
结果、验证与调查
我们想要一个易于理解和解释的模型。让我们看看我们做得如何。
SHAP 验证
在模型评估方面,SHAP 包相当普及。我们将跳过整个“什么是 SHAP?SHAP 如何工作?”的介绍,直接进入结果,仅简要回顾一下 SHAP 关注的是模型预测的驱动因素,而不是模型拟合的驱动因素。
暂停完毕。
作为开端,总体总结突出显示了预测的最大驱动因素。

图片由作者提供
图表的顶部似乎是year特征,我们看到year的值较高时,预测值会降低。这与我们在数据中看到的明显下降趋势一致。
我们看到许多受 Prophet 启发的特征在预测中发挥了关键作用。有趣的是,prophet_sin_12出现在列表中(并且排得如此靠前!)——我们扩展 Prophet 特征创建范围的直觉似乎得到了验证。
还有一些其他特征出现了,比如其中一个周期性编码的月份特征(mth_sin)和一个月中的假期数量(hols),但这个列表主要被变化点特征所主导。以至于它可能需要进一步调查。
SHAP 还允许我们可视化单个预测的驱动因素。作为一个开端,让我们看看第一个和最后一个观察结果。

图片由作者提供
记住这两个观察结果相隔大约 20 年和 8,000 起事故,值得注意的是每个预测的驱动因素中都有一些共同点:这两个预测都受到了year和hols的高度影响(尽管方向不一定相同)。
让我们转向一些更富有想象力的内容。
定制调查
我们现在可以开始有创意地调查我们的模型。
记住,我们实际上建立了一个通用的线性模型,并添加了一些正则化,目标的转换意味着我们实际上使用了对数链接函数。
记住,在模型训练中,特征系数(beta)受到约束,最终会收缩到接近零。这给了我们第一个调查途径:检查拟合模型参数的大小。
任何具有非零系数的特征都会被用于模型中,特征系数的大小是其在整体预测中的“重要性”的一个指标。
这可以相对容易地完成,使用coef_属性从拟合模型中获取系数,然后记得在排序和/或排名之前取提取系数的绝对值。检查那些可以合理预期包含在模型中的但实际上没有的特征也是个好主意。
我们还能做些什么?好吧,这里就变得非常有趣了。
由于我们使用了 GLM,我们可以将特征分组在一起计算对预测的组影响。由于我们使用了对数链接函数,我们的影响是乘法的,理解起来更简单——例如,组影响为 1.05 = 增加 5%。
例如,我们有大量的 Prophet 特征。我们可以很容易地直接计算所有 Prophet 特征对总体预测的影响。
这种分组也不是绝对的——我们可以根据需要进行分组。在我们的案例中,我将把特征分组到主题中:
def feature_type(s,
changepoint_flag=['change'],
seasonal_flag=['prophet','is_','qtr','mth','dim'],
trend_flag = ['yr'],
holiday_flag=['bus_days','hols']):
if any(item in s for item in changepoint_flag):
return 'changepoint'
elif any(item in s for item in seasonal_flag):
return 'seasonal'
elif any(item in s for item in trend_flag):
return 'trend'
elif any(item in s for item in holiday_flag):
return 'holiday'
特征组的影响量化是通过系数和该组中每个特征的特征值的点积来计算的。这使我们能够清晰地探索每个组随时间的影响。
我们对变化点特征的主导地位有些怀疑——让我们可视化它们的影响,看看它们如何随时间影响预测:

作者提供的图片
大多数变化点推动了预测的下降,除了 2013 年左右的一个上升变化。有很多变化点被使用,可能弥补了“纯”趋势特征中的不足。
那么趋势和变化点特征的综合效果如何?这些实际上是 Prophet 模型中的“趋势”组件。结果证明,这也不是很难获取:

作者提供的图片
季节性效果怎么样?我很想看看各种季节性特征如何组合在一起形成一个整体效果。

作者提供的图片
有趣!我们在这里看到了一年初的下降,然后是一年中期的双重峰值——这是一个相当复杂的效果。
最后,让我们看看假日效应,我们将工作日和实际的银行假日捆绑在一起。由于英国一年中的银行假日非常少,这一特征组可能会比较嘈杂:

作者提供的图片
……这正是我们看到的。注意到大多数影响在减少 1%(0.99)和增加 2%(1.02)之间波动,某些急剧下降可能与节假日较多的月份有关。
我认为可以相当安全地说,我们可以隔离每个效果对任何给定观察的影响——这正是我们开始这次小冒险时所期望的!
总结与漫谈
我们在这篇文章中涉及了很多内容。按照惯例,在稍作闲聊之前,先进行简要回顾。
总结
在谈到我们希望建立一个出色的预测模型后,我们查看了英国的道路交通事故数据。我们看到在汇总的月度数据中存在明显的趋势和季节性,并知道我们的特征工程和建模需要加倍努力以捕捉这些效果。
我们讨论了构建的模型应能良好预测,应具备可解释性和可理解性,并应具有某种形式的固有特征选择。我们还讨论了为何 LASSO 模型符合这些要求。
我们继续探讨模型形式,以及如何变换目标以更好地适应时间序列的乘法形式。我们探索了正则化参数alpha,以及适当的(时间序列)交叉验证方法以找到用于预测的最佳alpha。我们通过绘制时间之外的预测结果来进行感知检查,然后使用alpha和整个历史数据重新构建模型。
随后,我们了解了如何利用 SHAP 有效地研究 LASSO 模型的内部运作。然后,我们利用模型的可分性来研究和量化趋势、季节性和假期效应的影响。
让我们聊聊吧。
决策树无法外推
……以及使用广义线性模型的其他原因。
决策树模型最著名的挑战之一是它们无法超越训练数据的值进行外推。使用决策树模型间接地对预测值设置了上下限;鉴于我们的序列随时间的趋势,这可能不是一个合适的让步。
神经网络可以外推,但考虑到它们的复杂性,我决定退回到我的拿手好戏:谦逊的 GLM。
尽管简单,GLM 具有很多优点。尤其是考虑到我们的具体要求:
-
GLM 被广泛使用和理解。它们的熟悉度使得解释和沟通变得简单。
-
模型结构意味着有效选择链接函数可以使模型变得可处理。特别有用的是我们可以相对于“基准”观察值进行比较,因为这允许我们量化给定特征的某一水平的影响,其他条件相同。
-
类似地,我们可以计算给定特征的确切效果。这使得确保某些条件——如客户的期望——变得更加简单明了。
从实际的角度来看,我们的模型形式允许我们轻松进行一些手动“调整”。例如:如果我们认为我们的时间序列趋势与数据所示的方向不同怎么办?我们可以在假定的未来趋势上应用建模的季节性和假期效应,从而得到定制的预测。我们可以将这种方法概括并应用各种调整——非常适合情景测试和沟通“如果会怎么样?”的问题结果。
当然,GLM 并非全是阳光和彩虹。它们的构建可能相当棘手,通常需要专家判断。当我们开始引入特征交互(我们必须自己创建!)时,它们也可能变得非常麻烦。
就像大多数事情一样,你必须选择合适的工具来完成工作。
剩余的:模型残差
我们对一个已知具有强烈趋势和季节效应的时间序列进行了建模。我们怎么知道我们是否充分捕捉到了这种趋势和季节性呢?
一种需要调查的途径是模型残差:实际值与预测值之间的差异。残差中的任何趋势或季节性都会表明模型没有充分考虑这些效应;我们需要采取措施解决这个问题。
自然地,我并没有做到这一点。听我的,不要看我做的。
有趣的是,调查模型残差的方法实际上很自然地过渡到不同的建模方法。
一些从业者——包括我工作中的那些——支持混合模型,其中我们对原始系列进行简单建模,然后用不同的时间序列模型建模残差。
这里的基本逻辑是,简单模型充分捕捉了趋势,因此残差不包含长期非周期性效应:也就是说,残差是平稳的。这为依赖于平稳性假设的其他技术打开了大门——如一些时间序列方法——或可以从平稳性中受益,因为它们处理外推不特别好——如基于树的方法。
组 SHAP
大多数从业者喜欢使用 SHAP。理由很充分——看看这些可视化效果是多么的信息丰富和易于理解!
不过,如果我们能够将各种特征组合成特征集,并将这些特征集输入到 SHAP 可视化中,那就更好了。
尽管我们手动做过类似的工作,但分组特征的 SHAP 瀑布图确实能够生动展示预测的驱动因素……看起来我们可能可以做到这一点¹⁰。
让生活更轻松
我的工作流程用的都是比较繁琐的方法,包括手动交叉验证循环。这不是炫耀,相信我——我更愿意走最小阻力的路径。
我最后还是手动完成了这些操作,因为我无法让 scikit-learn 的随机交叉验证¹² 搜索正常工作。我认为这是因为使用了Pipeline估计器而不是仅使用Lasso类(记住我们需要在拟合 LASSO 模型之前对输入特征进行标准化),但我不确定。如果你知道解决方案——或者能指引我正确的方向——请告诉我。
最后,LassoCV 在 scikit-learn 中确实存在。这是一个自我调节的 LASSO 模型——没错,将其指向你的数据,它会自主确定alpha的最佳值。看起来它甚至可以接受定制的交叉验证生成器¹¹,尽管我不确定这是否包括时间序列交叉验证。下次可以看看。
就这些了。我希望你和我一样享受阅读这篇文章的过程。
和往常一样,请告诉我你的想法——我很想了解你在 Prophet 或用不同方式建模时间序列的经历。
下次再见。
参考文献和有用资源
-
GitHub — facebook/prophet: 用于生成具有线性或非线性增长的多季节时间序列数据的高质量预测的工具。
-
Vincent Warmerdam: 通过简单的,甚至线性的模型获胜 | PyData London 2018 — YouTube
-
roadtraffic.dft.gov.uk/downloads按照 开放政府许可证 (nationalarchives.gov.uk) 使用 -
3.2.3.1.5. sklearn.linear_model.LassoCV — scikit-learn 0.15-git 文档
-
sklearn.model_selection.RandomizedSearchCV — scikit-learn 1.3.1 文档
虚假预言者:将回归模型与 Meta 的 Prophet 进行比较
我那款受 Prophet 启发的时间序列回归模型——一个怪物般的自制版本——能与真正的预言者竞争吗?
·
关注 发表于 Towards Data Science ·7 min read·2023 年 11 月 25 日
--
图片由 Piret Ilver 提供,来源于 Unsplash
在我构建 Meta 优秀预测包 Prophet 的最后一篇文章中,我将会查看我自制的版本与原版相比的表现如何。
这将是一个快速的过程:我们首先查看数据,然后可视化两种方法在时间外数据上的预测。接着,我们将使用一些指标更正式地确定哪个预测器更好,并讨论是否这是一个公平的比较。
让我们开始吧。
顺便提一下:我提到了其他几篇文章——确切来说是两篇。第一篇涉及基于 Prophet 方法的时间序列特征工程,可以在这里找到:
基于 Meta 的 Prophet 包的思想,为时间序列机器学习模型创建强大的特征
towardsdatascience.com
在续集中,我使用我们全新特征构建模型。可以在这里找到:
借鉴 Meta 的 Prophet 构建强大的时间序列回归模型
towardsdatascience.com
今天讨论的许多话题在链接的文章中有更详细的介绍——如果你喜欢细节,值得一读。
数据
我们使用的是英国道路交通事故数据¹,汇总为每月统计。

作者提供的图片
我们在时间序列中看到了一些特征:
-
整个序列中有一个强烈的下降趋势
-
在 2012 年和 2014 年之间的某个时间点,减少率发生了变化
-
序列早期的季节性较强
-
潜在的季节性变动,尤其是在序列的末尾。
游戏的目标
我们有两个模型——我们将自制的 Frankenstein 模型称为 LASSO 模型,而 Meta 的 Prophet 就叫做… Prophet。
对于每个模型,我们将生成时间外预测。这本质上意味着拟合我们每月统计数据的一个子集,然后预测未来 12 个月。
每个预测将与实际观察数据进行比较;哪个模型平均最接近——哪个就赢。
顺便提一下:这本质上是一个交叉验证测试。如果你熟悉标准的交叉验证方法,但在时间序列分析中没有使用过,你可能会发现下面的(2)非常有用。
图片
我们可以可视化每个模型的时间外预测——LASSO 为红色,Prophet 为蓝色——并将其与实际数据进行比较。
我们应该记住,每个预测都是使用预测期前的所有数据构建的。例如,2010 年的预测是使用截至 2009 年的数据建立的。

作者提供的图片
这是一个相当清晰的画面:除了 2013 年之外,Prophet 看起来都有点偏离预期。
有趣的是,需要注意的是两种方法创建的预测模式的相似性:
-
两种模型都产生较低的预测——即它们反映了总体趋势的下降趋势。
-
两种模型都有年内增长和年中高峰——即预测产生类似的季节性模式。
两个模型与现实有多大距离?我们需要查看一些性能指标来找出答案。
在数字上
我们将使用常见的性能指标来衡量性能——平均绝对误差(MAE)、平均绝对百分比误差(MAPE)和均方根误差(RMSE)——以及一个新手(至少对我来说是新手):MASE。
平均绝对缩放误差
平均绝对缩放误差(MASE)是“普遍适用的预测准确度测量,避免了其他测量中出现的问题”³,并且“可用于比较单个系列的预测方法以及在系列之间比较预测准确度”³。
从数学上讲,MASE 是超出时间预测误差与由天真预测方法产生的样本内预测误差之比。由于我们使用的是月度数据,我将天真预测预测值视为前一年同一时间点的值——例如,2012 年 5 月的预测只是 2011 年 5 月的值。非常天真。
在比较预测方法时,具有最低 MASE 的方法是首选方法³。
需要注意的是,MASE > 1 意味着预测方法相对于天真预测表现不佳。
旁注:我使用了链接文章中描述的实现——即“误差”是平均绝对误差。我相信我们可以在比例误差计算中一致使用其他性能度量,例如 MAPE——只要误差度量是一致的。
结果
让我们总结一下使用我们描述的指标来总结折叠外和整体平均模型性能:

作者提供的图片
对于 LASSO 模型来说,这是一个相当全面的胜利,Prophet 只在局部表现优越。
刀光剑影?
如我们所见,如果你是 Prophet 的粉丝,这并不是一个令人愉快的阅读:Meta 的工具设法夺走了一些分数(取决于度量标准),以避免完全被淘汰。公正的评论可能会建议返回俱乐部重新评估策略。
虽然 Prophet 的结果不太理想,但有几个原因可以解释这种性能。
特征
LASSO 模型使用了为这一特定时间序列专门设计的特征。它可用的输入特征集本质上是 Prophet 可用特征的超集,并且有一些额外的特性。
此外,LASSO 模型中的某些特征微妙地不同。例如,描述潜在变化点的特征在 LASSO 模型中不像在 Prophet 模型中那样受到限制。
把它看作是试图超越别人,但你对他们了解的少一些——或者有些不同。并不容易。
建模
超出折叠的数据并不像我描述的那样“未见过”。
在之前的文章中,我们介绍了 LASSO 模型的参数化:我们如何使用超出折叠的数据来选择优化模型预测能力的正则化强度。从这个意义上讲,LASSO 模型已被调整以在每一组数据上进行良好的预测,而 Prophet 模型则是直接投入实际使用。
在“正常”的超参数优化中,我们通常可以期望性能提升约 1%—2%;在时间序列的背景下,性能提升可能更大,因为“超出折叠”确实是“超出时间”。
那么是时候结束 Prophet 的讨论了吗?
不要急于下结论……这系列文章确实突出了几点——让我们逐一讨论一下。
初始状态下,Prophet 工作非常出色。虽然它确实可以被超越,但做到这一点需要比仅用 10 行代码启动和预测 Prophet 更多的工作。
LASSO 模型的可解释性远超 Prophet 模型。是的,Prophet 确实为预测提供了不确定性的估计,但我们无法了解究竟是什么驱动了这些预测。我甚至不确定我们能否将 Prophet 应用于 SHAP。
我还发现 Prophet 的调整并不那么简单。也许是因为我不是该包的高级用户,或者是因为调整参数的方式曲折。LASSO 模型显然不是这样。
尽管 LASSO 方法在性能和可解释性上确实有所提升,也许我们真正需要的是结合这两种方法:一种作为另一种的试金石。例如,如果“天真的”Prophet 模型产生了合理的预测,那么复制 LASSO 方法(“虚假的先知”)以最大化性能可能是合理的。
我说完了。希望你们阅读这系列文章的体验与我写这些文章的乐趣一样。
一如既往,请告诉我你的想法——我非常希望了解你在使用 Prophet 或以不同方式建模时间序列的经验。
下次见。
参考资料和有用资源
False Prophet: 自制时间序列回归的特征工程
基于 Meta 的 Prophet 包中的想法,创建强大的时间序列机器学习模型特征
·发表于 Towards Data Science ·阅读时间 15 分钟·2023 年 10 月 13 日
--

Scott Rodgerson 在 Unsplash 上的照片
Meta 的 Prophet 包¹ 是最广泛使用的时间序列包之一。至少根据我的经验,在查看了我为稍后阅读而收藏的一系列时间序列文章后是这样。
说笑归说笑,我以前使用过这个包,我非常喜欢它。
另一个很棒的时间序列建模资源是 Vincent Warmerdam 的演讲,题为“通过简单甚至线性模型获胜”²,他在演讲中讨论了如何使用线性模型建模时间序列(需要一些准备工作)。
现在,有一些数据科学元素模糊了艺术和科学的界限——例如超参数调优或定义神经网络的结构。
我们将倾向于艺术,做许多伟大艺术家所做的事情:借鉴他人的想法。因此,在这一系列文章中,我们将从 Prophet 中借鉴特征工程的想法,并从 Vincent 中借鉴线性建模的想法,进行我们自己的时间序列回归分析。
大局观
让我们首先讨论一下总体目标,然后再深入到特征工程中。
总体目标很简单——在指定的时间范围内生成最准确的未来事件预测。
我们将从一个仅包含日期变量和感兴趣的数量的时间序列开始。由此,我们将推导出额外的信息,这将使我们能够准确地建模未来结果。这些额外的特征将会受到 Prophet 的“启发”。
然后我们将把工程数据输入到一个轻量级模型中,让它学习如何最好地预测未来。之后,我们将深入了解模型的内部工作——毕竟,我们需要了解是什么驱动我们的预测。
现在我们已经看到了全貌,让我们仔细看看数据,从数据开始。
数据
我们将使用来自英国的真实数据——在这个案例中,是道路交通事故数据。
这是由英国政府提供的 STATS19³ 数据集。由于数据集非常庞大,为了使事情更易于管理,我们将把每日事故数量汇总为月度数据。
通过可视化我们的时间序列,我们可以看到一个下降趋势和强烈的年度模式。也可以认为这些模式在 2012 年至 2014 年间的某个时点发生了变化。

作者提供的图片
这已经是我们需要创建的两种特征——一种是捕捉整体趋势,另一种是捕捉重复的年度模式(或季节性)。
特征工程
在开始实现之前,我们将探讨推动我们工程的一般理念。
Prophet 的魅力
Prophet 使用具有三个主要成分的可分解时间序列模型,这些成分是加性组合的(带有一点随机性)。在数学上,这是:

作者提供的图片
这里 g(t) 是趋势函数,用于建模时间序列中非周期性的变化,s(t) 代表周期性变化(例如,周度和年度季节性),而 h(t) 代表节假日的影响,这些节假日可能会在一个或多个日期上发生不规则的安排。⁴
正是这种可分解的模型形式使得 Prophet 如此灵活,正是时间序列可以分离的这个理念将指导我们的特征工程:也就是说,我们将生成帮助我们建模每一个这些组件的特征。
我们的模仿不会成为 Prophet 的双胞胎——我们只是从中获得灵感。因此,我们将做一些调整:
-
g(t) 还将代表时间序列中的步进变化或变更点。
-
我们不会过多关注误差项(epsilon),只需记住 Prophet 使用它来表示“模型未能涵盖的个性化变化”⁴。
旁白:如果你不熟悉时间序列的组成部分,这篇文章是一个很好的总结:
一个有效地将时间序列分解成其组成部分的指南
我们将从基本的日期相关特征开始,然后再衍生一些更具想象力的特征。
基础特征:零步骤
作为热身,让我们获取一些基本的日期相关特征:
# set date as the Frame index
df.set_index('date', inplace=True)
# simple date features
df['yr'] = df.index.year
df['qtr'] = df.index.quarter
df['mth'] = df.index.month
df['dim'] = df.index.days_in_month
所有这些特征都可以直接获得。即使是对这些特征不熟悉的人也可能很清楚,这些特征很可能会预测每月的事故数量。
是时候进行一些主题工程了。
趋势
趋势,即随时间的长期变化,可以呈现各种形式。
如果存在的话,趋势通常会非常简单——持续的上升或下降变化并不少见。在许多时间序列演示中使用的航空公司乘客数据⁵展示了一个非常清晰且简单的趋势。
然而,趋势可能会变得比这更复杂。例如,它们可能是非线性的,其中存在加速或减速的现象。可能会有多个加速或减速的实例。或者可能会有阶跃变化,其中趋势的位置发生突变。
我们已经看到我们的数据中似乎存在一个下降的线性趋势,变化点位于 2012 年和 2014 年之间的某个地方。我不完全确定趋势的确切形式,所以我将创建各种趋势,并让模型找出哪一个最合适:
# fraction of year
df['yr_fraction'] = df.index.year + (df.index.month - 1) / 12
# add non-linearity
yr_fraction_rebased = df['yr_fraction'] - df['yr'].min()
df['yr_fraction_sq'] = yr_fraction_rebased ** 2
df['yr_fraction_cube'] = yr_fraction_rebased ** 3
df['yr_fraction_quad'] = yr_fraction_rebased ** 4
df['yr_fraction_quint'] = yr_fraction_rebased ** 5
df['yr_fraction_sqrt'] = yr_fraction_rebased ** 0.5
从视觉上来看,这给我们提供了许多可能的趋势(通过一些缩放来将所有内容适配到同一图表上):

图片由作者提供
附注:重要的是要注意,虽然所有这些趋势看起来都是上升的,但模型将能够利用这些趋势捕捉数据中的下降趋势,例如使用负权重或系数。这不仅适用于趋势组件,还适用于模型中使用的所有特征。
现在来看一些变化点。
Prophet 通过首先指定大量潜在变化点,然后尽可能少地使用这些点来检测变化点。Prophet 的默认方法是在数据的前 80%内创建 25 个均匀间隔的变化点。
我们将通过首先创建许多潜在的变化点,然后让模型选择使用哪些点来做一些类似的事情。这与 Prophet 并没有太大区别,但没有对间隔施加限制。
# changepoints
changepoints = pd.DataFrame()
for date in df.index.unique():
date = pd.to_datetime(date)
date_str = f'change_{date.strftime("%Y_%m")}'
# allow only X-erly changes
if date.month % 3 == 0:
temp = pd.DataFrame(
{date_str:np.where(df.index <= date,0,1)}
)
changepoints = pd.concat([changepoints,temp], axis=1)
如果我们查看前 12 行,我们可以看到变化点创建是如何工作的:

图片由作者提供
并不复杂,因为我们为每个变化点特征都有一列,指示观察是否发生在给定日期之前或之后。
值得一提的是,我只允许变化点出现在每个季度末。除非我们对系列中的变化非常确定,否则设置这些点可能有点像艺术,需要在灵活性和过度反应之间取得平衡;变化点需要足够频繁以捕捉趋势中的真实变化,但又不能过于频繁,以至于开始捕捉噪声。
在这种情况下,季度变化点有一些优势。首先,它们有效地对变化持续多长时间设置了最低时间阈值,才被认为是“真实”的——这对于减少模型将信号误认为噪声的倾向可能很有用。
在英国,季度变化大致与季节性变化和重大日历变化(例如 1 月 1 日)对齐。
还需要考虑外部环境因素:每年 3 月和 9 月会发布新的注册车牌,这通常会导致新车销量的激增。由于新车通常比旧车更安全,因此可以合理地想象车主构成的变化会对道路交通事故的数量产生影响。
虽然这可能是一个不错的起点,但我们可能需要稍后再回来进行一些微调。
季节性
我们将时间序列中存在的规则性或周期性效应称为季节性。
Prophet 使用傅里叶级数来表示加性模型中的季节效应。其通用化为如下:

图片来源:作者
傅里叶表示法本质上意味着我们在时间序列中看到的所有重复效应可以通过一系列不同周期的正弦和余弦波来表示。
由于我们正在处理月度数据,我们预计每年在相同时间看到季节性效应;换句话说,我们时间序列的 周期 是一年,即 12 个月。因此,我们需要将 P 设置为 12。
N = 10 和 N = 3 已被证明分别适用于具有年度和周度季节性的系列,但我们将 N 扩展到 12 以确保效果。
记住,我们不是在创建一个单独的季节性模型,而是创建季节性特征,让我们的单一模型能够组合这些特征以表示周期性变化。考虑到这一点,我们按以下方式创建 Prophet 启发的特征:
# Prophet features
for j in range(1,13):
df[f'prophet_sin_{j}'] = np.sin(2 * np.pi * df['mth'] * j / 12)
df[f'prophet_cos_{j}'] = np.cos(2 * np.pi * df['mth'] * j / 12)
这会生成多个不同周期的正弦和余弦波,准备好让模型将它们组合在一起,以捕捉季节性。因此,作为输入特征,它们可能看起来像这样:

附注:注意较高的 n 值如何减少正弦函数的周期——即完成一个完整周期所需的时间。
我们的模型将以各种方式选择和缩放这些函数,以便适当地考虑时间序列的周期性(或 季节性)元素……也就是说,模型将确定上述公式中的 a 和 b 系数。
节假日
假期和事件对许多业务时间序列提供了大的、相对可预测的冲击,并且通常不遵循周期模式,因此它们的效果不能很好地通过平滑周期建模。⁴
一个很好的例子是复活节周末,在英国对车辆事故有非常明显的影响。然而,这个假期周末并不是每年都在同一个月份——有时它发生在三月,而在其他年份,我们得等到四月才能开始寻找巧克力兔子。
尽管我们可以相当提前知道复活节的日期(当然也知道过去的日期),但用上述季节性方法建模相当困难。因此,我们将采取不同的方法,而是统计每月的银行假期和工作日数量,这应该使我们也能够捕捉到像圣诞节这样更规律的假期的影响。
我们可以使用numpy来获取工作日:
# business days
begin = df.index.values.astype('datetime64[D]')
end = (df.index + pd.DateOffset(months = 1)).values.astype('datetime64[D]')
df['bus_days'] = np.busday_count(
begindates = begin,
enddates = end
)
# holidays
df['hols'] = pd.Series(df.index).apply(count_holidays).values
……但是需要holidays包和一点来自 StackOverflow 的帮助才能获得每月的假期数量:
import holidays
def count_holidays(u):
hols = holidays.country_holidays('GB')
days = pd.date_range(u, u + pd.DateOffset(months = 1))
return sum(y in hols for y in days)
我们将忽略周末天数:因为可以从现有特征中推导出,使用周末天数可能会引入不必要的特征相关性。
就这样——特征工程完成,我们准备好进入建模阶段。
总结与闲聊
我们在这篇文章中涵盖了很多内容。按照惯例,我们将快速回顾一下,然后进行一些闲聊。
总结
在讨论了构建一个出色的预测模型的愿望之后,我们查看了英国道路交通事故数据。我们看到我们汇总的每月数据中存在明显的趋势和季节性,并知道我们需要创建各种特征来捕捉这些效应。
我们的特征工程过程从简单的预热开始——提取简单且直接可用的日期特征。
我们开始构建捕捉趋势的特征,并允许一些变化点,这些变化点被相对简单地处理。我们的逻辑可能有一定的价值,但我们承认可能需要一些微调。
我们使用了改进的傅里叶变换来建模季节性,创建了 12 组季节性特征。
最后,我们转向创建假期特征,选择专注于每月的工作日和假期天数。
循环特征编码
在构建新特征时,我们需要牢记两点——什么可能对我们的目标有预测作用以及它如何被机器解读。
一个很好的例子是年度月份,我们通常用整数映射来表示(即:一月 = 1,……,十二月 = 12)。我们可以相当确定地认为月份会对事故数量产生强烈的影响。但是如果我们将整数编码传递给模型,模型会将某一年的十二月视为与下一年的一月完全不同的时间点,即使它们是时间上相邻的!
我们通过周期特征编码解决这个问题,或者更具体地说,通过转换为极坐标。由于正弦和余弦变换本身都不能提供唯一的编码,我们使用两者的组合。
上面的代码没有显示任何周期编码的示例,但它在我的工作流程中使用,结果在模型中成为一个重要特征(见第二部分)。
先知特征
按类似的思路,我们的“先知特征”严重依赖于正弦和余弦变换。实际上,这些实际上是傅里叶变换。
细心的读者可能已经注意到先知特征的创建方式。在原始论文中,时间维度被重新基准化到某一点,每个观察值都被反映为在那之后的 t 时间单位。我们没有这样做,而是选择了另一种方法。如果我再次考虑这个问题,可能会考虑这一点。
滞后特征:房间里的大象
到目前为止,我实际上只是略过了滞后特征的使用。或者更准确地说,是缺乏使用。
使用目标量的先前值来预测目标量的当前或未来值——即使用“滞后”的时间值——在许多优秀的时间序列模型中是常见的。这是有充分理由的,因为它们通常是强有力的预测器。
我不愿这样做的核心在于模型的整体目的——即在预测方面表现良好。当我们使用滞后特征进行预测时,通常必须“推进”滞后特征,并从使用目标的实际值过渡到使用预测值。
具体来说,考虑一个使用一个滞后特征的模型——即,我们使用时间 t — 1 的目标值来预测时间 t 的目标值。我们希望使用该模型预测未来 3 步。
第一个预测(在时间 t + 1)将使用今天目标的值。由于目标值是已知的,所以这里没有问题,一切照常进行。
现在考虑时间 t + 2 的预测。我们需要时间 t + 1 的目标值才能使用我们的模型。当然,此时 真实 的目标值是未知的,因此我们使用 预测 的目标值来代替时间 t + 1。当涉及到预测时间 t + 3 时,我们将时间 t + 2 的预测推进,以此类推。从中可以看出,预测误差如何被融入预测中;早期的错误会被累积,因为不准确的预测被推进和重复使用。我不喜欢这样。
不使用滞后特征还有一个附带好处,那就是模型可解释性:我们被迫以不同的方式建模目标,并真正考虑(并建模!)结果的驱动因素。
这通常会导致与利益相关者进行更好的对话,因为解释预测开始听起来像是“长期趋势占预测的 X%而季节性占预测的 Y%”而不是“预测为 B 是因为前一个预测的值为 A”。
在继续之前,最后提一下滞后特征。我们并不局限于使用滞后的目标特征,因此虽然我们讨论了包括目标特征的前期值,我们也可以同样包括滞后的预测变量,且要有类似的警示和要求。
这并不是要全面否定使用滞后特征的做法——我相信在某些用例中这样做完全合理。使用的滞后数量和预测窗口的长度可能甚至意味着这不是问题。
变更点
让我们讨论变更点及其创建。
我以一种非常简单的方式创建了变更点,我确信还有许多改进我实现的方法。Prophet 可以说做得更好,通过在数据的前 80%中创建均匀间隔的变更点,但也有一些需要考虑的事项。
这减少了较近期虚假变更点对未来预测的影响——这是件好事。
但真正的变化有多少发生在均匀间隔的时间点上?如果变化真的以这种节奏发生,那是否更应视为某种季节性影响?是的,这是在挑剔。是的,这很重要。好吧,我继续。
虽然我们可以对历史变更点进行建模,但对未来变更点的建模要困难一些;有些情况下,未来的变更是已知的。
例如,英国推出了《民事责任法》,对英格兰和威尔士的个人伤害赔偿系统进行了修改。如果你像我一样定期建模颈部损伤的赔偿申请数量和成本(工作需要,而非兴趣),2021 年 6 月该法案的实施导致了一个相当严重的步骤变化。但由于事先已知,所以可以采取措施加以考虑。
这类变化需要逐案处理,务实和常识应当是首要考虑的。
交互作用
有人记录房间里的大象吗?再来一个——我们还没有构建捕捉预测变量之间交互作用的特征。
交互作用是非常有用的特征,可以捕捉各种预测变量之间的关系。当一个自变量对结果的影响取决于另一个自变量的值时,就会发生交互作用⁹。
在我们的案例中,使用交互作用的一个更有趣的动机是允许季节性随时间变化,因为我们目前假设——并建模——相同的季节性效果在二十多年里仍然有效。虽然没有明显的反证,但我们可以通过将时间与某些特征进行交互,潜在地挖掘更多的预测能力。
我们需要将这个添加到下次的待办事项列表中。
假期
最后,简单说一下假期的情况。
我们讨论了一些复活节周末可能带来的麻烦,并提出了一个简单的解决方案。
对特征工程的真正改进将是纳入学校假期。这些假期可能会对道路交通事故的数量产生影响,因此将是强有力的预测因素。
不幸的是,这并不容易,因为英国的学校假期时间略有不同,假期长度也不尽相同。也许我们可以发挥创造力,创建一个学校假期的分布,并将其分配到每个月——这是下次的另一个任务。
就这些了。我希望你们阅读这篇文章的乐趣与我写作的乐趣一样多。
一如既往,请告诉我你的想法——我非常感兴趣了解你在使用 Prophet 或以不同方式建模时间序列的经历。
正如我提到的,我将在即将发表的文章中处理建模问题——请留意。
期待下次见面。
参考资料
-
GitHub — facebook/prophet: 用于生成高质量时间序列数据预测的工具,支持多季节性和线性或非线性增长。
-
Vincent Warmerdam:用简单甚至线性模型取胜 | PyData London 2018 — YouTube
-
roadtraffic.dft.gov.uk/downloads依据 开放政府许可证 (nationalarchives.gov.uk) 使用
BigQuery 的神奇生物及其使用时机
原文:
towardsdatascience.com/fantastic-beasts-of-bigquery-and-when-to-use-them-13af9a17f3db
揭示 BigQuery Studio、DataFrames、生成 AI/AI 函数和 DuetAI 的特点
·发表于 Towards Data Science ·8 分钟阅读·2023 年 12 月 31 日
--

“BigQuery 是一个集数据库、商业智能、机器学习和生成 AI 功能于一体的 Google 服务。” [照片由 Korng Sok 提供,来源于 Unsplash]
了解更多关于 BigQuery 的内容
我最喜欢的书之一是 J.K. 罗琳 的“神奇动物在哪里”。这是一个关于魔法生物在非魔法世界中肆意活动的故事。它还讲述了魔法师和非魔法师如何建立友谊以保护魔法生物。在这个任务中,主要的非魔法师角色发现了一个充满魔法的世界,并爱上了所有的挑战,渴望自己也是一名巫师。
作为一个非魔法师,我从机械工程转向数据世界的过程最初充满了挑战。每当我进入一个新的数据领域时,我都会想:“要是我也是个巫师就好了。” 😉
当我第一次开始学习关于数据库 (DB) 和商业智能 (BI) 的知识时,我脑海中也有这个想法。
当我深入学习机器学习 (ML) 相关主题时,这个想法再次出现。
现在,我正试图在生成 AI (GenAI) 开发领域中运筹帷幄——你猜对了,这个想法再次陪伴着我。
即使在获得了数据库、商业智能和机器学习的经验之后,如果没有一个 Google 服务——BigQuery (BQ),生成 AI 领域对我来说将更加具有挑战性。
你知道为什么吗?
因为 BigQuery 提供了“全能型”解决方案,涵盖了“DB-BI-ML-GenAI”组合。或者,正如 Google 在其某次网络研讨会中所宣布的,它涵盖了“从数据到 AI”的功能 [1]。
我认为应该如何宣布:“BigQuery 的奇妙生物”。
在我最喜欢的 BigQuery 特性——BQML——的基础上,Google 最近实现了更多变革性特性,使 BQ 更类似于分析开发环境,而不只是数据库环境。
这些新特性使数据专业人员能够进行端到端的分析任务无需在多个工具之间切换。
关于端到端任务,我想到的是执行 探索性数据分析(EDA)、使用 SQL 或 Python 与 Spark 进行预测建模,以及通过使用生成性 AI 特性来创建新的见解。所有这些现在都可以在代码共同创建功能的帮助下完成。
BigQuery 生态系统的最新演变激励我写这篇文章,并展示将改变我们数据专业人员工作方式的新 BQ 进展,可能让我们感觉有点像Maj 人。😉
换句话说,本文旨在展示新 BQ 特性何时可以在分析工作流程中使用。
但在我们开始解释之前,我需要分享这些“奇妙生物”的名称:
-
BigQuery GenAI 和 AI 函数
现在让我们开始揭示它们的独特特征,并指出如何利用它们来提升你的表现。
BigQuery 的奇妙生物
为了暗示新 BQ 特性的目的,我制作了一个图示,展示了它们如何与知识数据发现过程(分析工作流程)对齐。

新的 BigQuery 特性与分析工作流程对齐 [图像由作者使用 draw.io]
从图中可以看出,在分析工作流程的基础上是DuetAI,这是一个 AI 编码辅助功能。除了编码支持,DuetAI 还是一个聊天机器人,你可以用它进行头脑风暴。
这意味着数据专业人员可以向聊天机器人提出与输入问题定义相关的不同问题(例如,我如何对数据集进行子集化或能否解释一个特定的函数)以及关于如何构建分析输出的建议(例如,我如何展示我的发现)。
在分析输入 → 输出 流程之间,其他特性也派上用场:
-
在 第一阶段,即 数据准备和理解阶段,可以通过 BigQuery Studio 使用 BigQuery SQL(用于数据子集和处理)、GenAI/AI 函数(用于丰富数据集)、BigQuery DataFrames 和其他 Python 库(用于探索数据集)。
-
在 第二阶段,即 数据建模和洞察综合阶段,可以在 BQ Studio 中将 BigQuery SQL 或 BQML 函数与 BigQuery DataFrames 一起使用(用于 BI/ML 模型创建),以获取所需的分析结果(描述性或预测性结果)。
现在我们将展示这些功能的神奇特性。
BigQuery Studio 和 DataFrames
在这里不要混淆 BigQuery Studio 和 Looker Studio(前身为 Data Studio)。后者是一个自助式商业智能工具,而前者是一个新的协作工作空间,支持完整的分析工作流。
综上所述,BigQuery Studio 具有以下主要特性 [2]:
#1: 在统一的界面中支持多种语言和工具。
我的意思是,它简化了不同数据职业之间的工作,因为:
-
数据工程(数据摄取和数据处理),
-
数据分析(描述性统计/探索性数据分析),以及
-
数据科学 任务(预测建模)可以在一个环境中完成,或者更好的是,在一个 笔记本 中完成。
BQ Studio 提供了 Colab 界面,使数据专业人员能够在一个笔记本文件中使用 SQL、Python、Spark 或自然语言进行 BigQuery 分析(与 DuetAI 结合使用)。此外,开发的笔记本可以通过 Vertex AI 访问,以进行机器学习工作流。
在数据摄取格式方面,它支持来自不同云平台的结构化、半结构化和非结构化格式。

Google 对 BigQuery Studio 笔记本的展示 [2]
#2: 通过连接到外部代码仓库来增强协作和版本控制。
我不得不说,这个特性是我一直以来希望的。尽管它(还)不支持所有 git 命令,但 BQ Studio 支持诸如持续集成/持续部署(CI/CD)、版本历史记录和数据代码资产的源代码管理等软件开发实践。简而言之,现在可以查看笔记本的历史记录并 还原到 或 从 特定笔记本版本创建分支。

Google 对版本控制 BQ Studio 功能的展示 [2]
#3: 在 BigQuery 中实施安全性和数据治理。
BQ Studio 实施安全性,因为它减少了在 BigQuery 之外共享数据的需要。换句话说,通过在服务之间采用统一的凭证管理,分析师可以例如访问 Vertex AI 基础模型来执行复杂的分析任务(如情感分析),而无需将数据共享给第三方工具。
此外,还有数据治理特性,包括数据血统跟踪、数据分析和实施质量约束。我只能说“对此表示赞同”。

Google 对数据治理功能的展示 [2]
总结上述特征,很明显 BigQuery Studio 是一个神奇的功能,因为它通过执行安全和治理措施,从数据摄取到数据建模的任务都能实现。
如果 Google 没有提供一个可以在 BigData Studio 笔记本中用于数据分析和建模的附加功能,BigQuery DataFrames,这个故事将不会完整。
通过安装bigframes包(类似于安装任何其他 Python 包使用pip),数据专业人士可以使用以下 Python API [3]:
-
bigframes.pandas是一个用于分析和处理的pandas API,并且 -
bigframes.ml,即一个用于机器学习的scikit-learn API。
在机器学习主题上,我希望结束这一部分。
原因在于,在下一节中,我将详细介绍新的BigQuery ML函数。
BigQuery GenAI 和 AI 功能
正如前言中提到的,我非常喜欢 BQML 函数,因为它们使数据专业人士能够使用 SQL 语法创建预测模型。
除了已经很不错的 BQML 函数组合用于监督学习和无监督学习之外,Google 现在增加了我下一个最喜欢的函数:生成式 AI和AI 功能。
关于生成式 AI 函数ML.GENERATE_TEXT,我最近写了一篇博客来展示其特性。
如何使用 BigQuery 的 GENERATE_TEXT 远程函数
[towardsdatascience.com
总之,该功能可以用于从存储在 BigQuery 数据集中的非结构化文本中创建新的见解。或者更准确地说,你可以用它来创建新类别或属性(分类分析、情感分析或实体提取)、总结或重写自然语言记录,以及生成广告或创意概念。
我会说,基于 SQL 的数据工程现在具备了下一个层次的能力。
除了这个魔法功能,其他强大的新 SQL 基础AI 功能包括:
-
ML.UNDERSTAND_TEXT — 帮助分析存储在 BigQuery 中的记录的文本的函数,并支持与 ML.GENERATE_TEXT 函数类似的功能。这意味着它支持实体、情感、分类和语法分析。
-
ML.TRANSLATE — 用于将文本从一种语言翻译成另一种语言的函数。
-
ML.PROCESS_DOCUMENT — 用于处理来自对象表(例如发票)的非结构化文档的函数。
-
ML.TRANSCRIBE — 用于从对象表中转录音频文件的函数。
-
ML.ANNOTATE_IMAGE — 用于从对象表中进行图像标注的函数。
尽管这些函数是用 SQL 编写的,但理解其查询结构和参数是正确使用的必备条件。为了加快学习曲线,一点编码帮助——DuetAI——会非常有用。
DuetAI 在 BigQuery 中
就这么简单:DuetAI的魔法功能可以帮助数据专业人员在 BQ 和 BQ Studio 环境中编写、改进和理解他们的多语言代码。
更准确地说,该功能具有以下特征:
#1: 在 BQ 环境或 BQ Studio 笔记本中从头创建查询或代码

使用 DuetAI 在 BigQuery 中的代码补全示例 [作者提供的图片]
#2: 解释 BQ 环境或 BQ Studio 笔记本中的查询和代码片段

使用 DuetAI 在 BigQuery 中的代码补全示例 [作者提供的图片]
#3: 提升 BQ 环境或 BQ Studio 笔记本中的代码质量
我的意思是,DuetAI 可以通过以下方式改进代码:
-
语法纠正:它可以识别并建议修正语法错误。
-
逻辑改进:它可以建议替代的代码结构方式,从而提高整体效率和可读性。
-
文档生成:它可以自动生成代码文档,使其更易于理解。
最后,通过这个魔法特性,我将结束新 BigQuery “神奇动物”的部分和展示。
无限可能的世界
在书籍“神奇动物在哪里”中,J.K. 罗琳展示了当 Maji 和 No-Maj 人们了解魔法生物的积极特质时,它们变得不再那么可怕。同样,在这篇博客文章中,我想展示 BigQuery 的新奇功能,并指出它们在不同分析层面上的积极特质。

作者使用DALL-E扩展在 ChatGPT 中创建的三只 BigQuery 野兽的插图(野兽的数量正确,但名字的数量不正确 😉)
我的目标是展示新功能如何在完整的分析工作流中支持你并简化工作,无论你是数据工程师、分析师还是科学家。此外,我还想指出它们如何增强具有不同数据背景的团队成员之间的协作。
希望你能亲身体验魔法,并了解更多关于“BigQuery 的神奇动物”的知识。
学习愉快!
知识资源
[1] Google Cloud 网络研讨会:“Cloud OnBoard:使用 BigQuery 和 Vertex AI 从数据到 AI”,访问日期:2023 年 12 月 10 日,cloudonair.withgoogle.com/events/cloud-onboard-data-to-ai
[2] Google Cloud 博客:“宣布 BigQuery Studio——一个协作分析工作区,加速数据到 AI 的工作流”,访问日期:2023 年 12 月 11 日,cloudonair.withgoogle.com/events/cloud-onboard-data-to-AI
[3] Google Cloud 文档:“BigQuery DataFrames”,访问日期:2023 年 12 月 11 日,cloud.google.com/python/docs/reference/bigframes/latest
AWS SageMaker 中的快速和可扩展超参数调优与交叉验证
使用 SageMaker 管理的 Warm Pools
·发表于 Towards Data Science ·阅读时间 8 分钟·2023 年 3 月 3 日
--

本文分享了一种配方,以 提高 60% 的速度,通过 SageMaker 管道利用 SageMaker 管理的 Warm Pools 进行超参数调优与交叉验证。通过使用 Warm Pools,一个包含 120 个顺序作业的调优步骤的运行时间减少了 从 10 小时到 4 小时。
提升和评估机器学习模型的性能通常需要多种因素。超参数调优和交叉验证就是两个这样的因素。前者找到模型的最佳版本,而后者估计模型如何推广到未见数据。这些步骤结合起来,带来了计算挑战,因为它们需要多次训练和验证模型,可能是并行的和/或顺序的。
本文介绍的内容…
-
什么是 Warm Pools 以及如何利用它们加速超参数调优与交叉验证。
-
如何设计一个包含处理、调优、训练和 Lambda 步骤的生产级 SageMaker 流水线。
我们将考虑用于超参数调优的贝叶斯优化,它利用已测试的超参数组合的评分来选择下一轮测试的超参数集。我们将使用k-折交叉验证来评分每个超参数组合,分割如下:

𝑘-fold 交叉验证策略。
完整数据集被划分为𝑘个验证折叠,模型在𝑘-1 个折叠上训练,并在相应的保留折叠上验证。总体得分是每个验证折叠得到的个别验证得分的平均值。
故事情节:
1. 什么是温暖池?
2. 端到端 SageMaker 管道
3. 调优步骤内部发生了什么?
4. 使用温暖池有什么好处?
5. 总结
1. 什么是温暖池?
每当在 AWS 中启动训练任务时,预配实例在执行训练脚本之前大约需要 3 分钟来引导。这种启动时间在顺序运行多个任务时会累积,这在使用贝叶斯优化策略进行超参数调优时尤其明显。在这种情况下,数十个甚至数百个任务被顺序运行,导致总时间显著增加,这可能与脚本的实际执行时间相当,甚至更高。
SageMaker 托管温暖池使得在任务完成后保留训练基础设施成为可能,从而为每个后续任务节省实例启动时间。
启用温暖池是直接的。你只需在创建 SageMaker 训练任务时添加一个额外的参数(keep_alive_period_in_seconds):
estimator = Estimator(
entry_point='training.py',
keep_alive_period_in_seconds=600,
...
)
如果你想了解更多关于 SageMaker 托管温暖池的信息,这里是文档:
[## 使用 SageMaker 托管温暖池进行训练
SageMaker 托管温暖池使你能够在训练任务完成后保留和重用预配的基础设施…
既然我们了解了什么是温暖池,在第二部分中,我们将深入探讨如何利用它们来加速包含交叉验证的 SageMaker 管道的整体运行时间。
2. 端到端 SageMaker 管道
下图展示了一个端到端的 SageMaker 管道,该管道通过交叉验证进行超参数调优。

端到端 SageMaker 管道的架构图。
我们将使用SageMaker Python SDK来创建管道,这是一个开源库,简化了在 AWS SageMaker 中训练、调优和部署机器学习模型的过程。图中的管道步骤总结如下:
-
数据预处理(
ProcessingStep)— 数据从源中检索,转化,并划分为 k 个交叉验证折叠。一个额外的完整数据集被保存用于最终训练。 -
超参数调优与交叉验证(
TuningStep)— 这是我们将重点关注的步骤。它找到在验证折中实现最佳平均性能的超参数组合。 -
最佳超参数检索(
LambdaStep)— 触发一个Lambda函数,通过访问超参数调优作业的结果来检索最佳超参数集,使用 Boto3。 -
最终训练(
TrainingStep)— 使用最佳超参数在完整数据集train_full.csv上训练模型。 -
模型注册(
ModelStep)— 将最终训练好的模型注册到 SageMaker 模型注册表中。 -
推理(
TransformStep)— 使用注册的模型生成预测结果。
请在SageMaker 开发者指南中查找有关如何实现这些步骤的详细文档。
3. 调优步骤内部发生了什么?
现在我们来深入探讨管道步骤 2,该步骤迭代地并行和顺序地尝试和交叉验证多个超参数组合。该解决方案在下图中表示:

超参数调优与交叉验证步骤的架构图。
该解决方案依赖于 SageMaker 自动模型调优来创建和协调测试多个超参数组合的训练作业。可以使用SageMaker Python SDK中的HyperparameterTuner启动自动模型调优作业。它创建了MxN个超参数调优训练作业,其中M个作业在N个顺序轮次中并行运行,逐步搜索最佳超参数。每个作业启动并监控一组K交叉验证作业。在每个调优轮次中,MxK个实例会保留到下一轮。在随后的轮次中,没有实例启动时间。
SageMaker 的 HyperparameterTuner 已经利用了 Warm Pools,正如在AWS 新闻博客中所宣布的那样。然而,每个调优作业中创建的交叉验证训练作业 — 交叉验证特定的超参数组合 — 需要手动创建和监控,且配置的 实例不会保留在 Warm Pool 中。每个超参数调优训练作业仅在所有基础的交叉验证训练作业完成后才会完成。
为了使上述架构生效并为所有训练作业启用 Warm Pools,我们需要创建三个主要脚本:pipeline.py、cross_validation.py和training.py:
**pipeline.py**脚本 — 定义了在 第二部分 中描述的 SageMaker Pipeline 步骤,包括 SageMaker 的HyperparameterTuner:
#pipeline.py script
...
# Steps 2 to 5
tuner = HyperparameterTuner(
estimator=estimator,
metric_definitions=[
{
"Name": "training:score",
"Regex": "average model training score:(.*?);"
},
{
"Name": "validation:score",
"Regex": "average model validation score:(.*?);"
}
],
objective_metric_name="validation:score",
strategy="Bayesian",
max_jobs=max_jobs, # M x N
max_parallel_jobs=max_parallel_jobs # M
)
# Step 2 - Hyperparameter tuning With cross-validation step
step_tune = TuningStep(
name="tuning-step",
step_args=tuner.fit({
"train": "<s3-path-to-training-folds>",
"validation": "<s3-path-to-validation-folds>"
})
)
# Step 3 - Optimal hyperparameter retrieval step
step_lambda = LambdaStep(
name="get-optimal-hyperparameters-step",
lambda_func=lambda_get_optimal_hyperparameters,
inputs={
"best_training_job_name": step_tune.properties.BestTrainingJob.TrainingJobName,
},
outputs=[
LambdaOutput(output_name="hyperparameter_a"),
LambdaOutput(output_name="hyperparameter_b"),
LambdaOutput(output_name="hyperparameter_c")
]
)
# Step 4 - Final training step
step_train = TrainingStep(
name="final-training-step",
step_args=estimator.fit({"train": "<s3-path-to-full-training-set>"})
)
model = Model(
model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
...
)
# Step 5 - Model registration step
step_model_registration = ModelStep(
name="model-registration-step",
step_args=model.register(.)
)
**cross_validation.py**脚本——作为 SageMaker 的HyperparameterTuner的入口点。它启动多个交叉验证训练任务。在调用 SageMaker 训练任务 API 时,必须在此脚本中指定keep_alive_period_in_seconds参数。该脚本计算并记录所有验证折的平均验证得分。记录这些值使得HyperparameterTuner可以通过 Regex 轻松读取该指标(如上述代码片段所示)。该指标将标记到每个超参数组合中。
提示: 在调用创建和监控训练任务的 SageMaker API 之间添加几秒钟的小延迟,以防止“超出速率”错误,如示例所示:
#cross_validation.py script
import time
...
training_jobs = []
for fold_index in range(number_of_folds):
# Create cross-validation training jobs (one per fold)
job = train_model(
training_data="<training-data-s3-path>"
validation_data="<validation-data-s3-path>"
fold_index=fold_index,
hyperparameters={
"hyperparameter_a": "<value-of-hyperparameter-a>",
"hyperparameter_b": "<value-of-hyperparameter-b>",
"hyperparameter_c": "<value-of-hyperparameter-c>"
})
training_jobs.append(job)
# Add delay to prevent Rate Exceeded error.
time.sleep(5)
...
提示: 启动 SageMaker 训练任务时禁用调试器分析器。这些分析器实例将与训练实例数量相同,并且可能显著增加总体成本。你可以通过在 Estimator 定义中简单地设置
disable_profiler=True来实现。
**training.py**脚本——在给定的输入训练集上训练模型。交叉验证的超参数作为此脚本的参数传递。
提示: 编写一个通用的
*training.py*脚本,并在交叉验证集上训练模型以及在整个训练集上使用最佳超参数训练最终模型时重用它。
要控制每个并行交叉验证任务集,以及为每个特定超参数组合计算最终验证指标,需要在 cross_validation.py 脚本中实现几个自定义函数。这个示例 提供了很好的灵感,尽管它未启用 Warm Pools 或 Lambda。
总共创建了多少任务?
M x N x (K+1) 任务。为什么?
-
M x N 超参数调整训练任务——M 个并行和 N 个串行——匹配超参数组合的数量。
-
每个超参数调整训练任务的 K 个并行交叉验证任务 + 1(超参数调整训练任务本身)。
如果我们有 5 个验证折,运行 4 个超参数调整训练任务并行和 120 个串行,那么 任务总数将是 2880。
重要: 确保你拥有所使用的实例类型所需的所有服务配额。查看 AWS 指南以了解如何为 Warm Pools 和 自动模型调整 设置这些配额。
4. 我们从使用 Warm Pools 中得到什么?
假设我们想要运行 N=120 个顺序训练任务,并且实例的启动时间为 3 分钟,训练时间为 2 分钟(每个任务 5 分钟)。这意味着总运行时间大约为:
-
没有 Warm Pools:5 分钟 x 120 个任务 = 10 小时
-
有 Warm Pools:5 分钟 x 1 个任务 + 2 分钟 x 119 个任务 ≈ 4 小时
这意味着使用 Warm Pools 过程的时间减少了 60%!
5. 总结
在这篇文章中,我展示了如何利用 Warm Pools 显著加快 SageMaker Pipelines 中的超参数调优。Warm Pools 是 SageMaker 的一个很棒的功能,它不仅使生产流水线更加高效,还加快了实验的迭代。目前,SageMaker 管理的 Warm Pools 已经集成到 SageMaker Training 中,但尚未集成到 SageMaker Processing。
— 若昂·佩雷拉
感谢阅读。希望这篇文章能帮助你在 SageMaker 中扩展超参数调优。如果你想阅读我未来的文章,请 关注我。非常感谢反馈!如果有任何问题,请在下方留言或直接联系我 通过电子邮件 或在 LinkedIn上联系我。
除非另有说明,所有图片均由作者提供。
Python 到 SQL — 我现在可以以 20 倍的速度加载数据
原文:
towardsdatascience.com/fast-load-data-to-sql-from-python-2d67aea946c0
上传大量数据的好方法、坏方法和丑陋的方法
·发表于 Towards Data Science ·6 分钟阅读·2023 年 3 月 20 日
--

速度很重要!在数据管道中,和在其他地方一样。
处理大量数据集是大多数数据专业人员的日常工作。如果它们流入数据库或仓库,这不会成为问题。
但有时我们需要上传足够重的批量数据,以至于使我们的工作站无用数小时。散步一下,喝点水,对你有好处!
但如果您真的想缩短这个任务,我们需要最优的方法将数据加载到数据库中。
如果这是一个预格式化的文件,我会更倾向于使用客户端库来完成。例如,以下 shell 命令可以将您的 CSV 文件上传到远程 Postgres 数据库。
[## 用 Streamlit 在几分钟内创建 GPT3 驱动的应用程序
学习构建智能应用程序,而不必过多担心软件开发。
levelup.gitconnected.com [## Python 网络抓取的宁静交响曲 — 3 移动篇章
在 Python 中进行网络抓取的最简单、最灵活和最全面的方法
psql \
-h $hostname -d $dbname -U $username \
-c "\copy mytable (column1, column2) from '/path/to/local/file.csv' with delimiter as ','"
但这在我的情况中很少发生。
由于 Python 是我主要的编程语言,我必须通过 Python 上传它们,可能需要进行一些预处理。因此,我做了一个小实验来查看最快的方法。
我找到了最快的方法,但这不是我平常使用的。
## 如何使用 GitHub Actions 构建简单的 ETL 管道
ETL 不必复杂。如果是这样,使用 GitHub Actions。
[towardsdatascience.com
丑陋的:如何不上传数据
尽管我现在确信我不应使用这种方法,但我认为这可能最初有帮助。
这是我们所做的。我有一个约 500MB 的磁盘数据集。首先,我尝试使用 psycopg2 模块中的插入命令来加载它。
import psycopg2
import csv
import time
def timing_decorator(func):
def wrapper(*args, **kwargs):
start_time = time.time()
result = func(*args, **kwargs)
end_time = time.time()
print(f"Function {func.__name__} took {end_time - start_time} seconds to run.")
return result
return wrapper
# Establish a connection to the PostgreSQL database
conn = psycopg2.connect(
host="localhost",
database="playground",
user="<your-user-name>",
password="<your-password>",
)
# Define the path to the CSV file and the name of the table to load it into
csv_file_path = "./data.csv"
table_name = "temp_table"
@timing_decorator
def load_csv_with_insert():
# Create a cursor object
cur = conn.cursor()
# Open the CSV file and read its contents
with open(csv_file_path, 'r') as f:
csv_reader = csv.reader(f)
next(csv_reader) # skip header row
# Define the query to insert data into the table
insert_query = f"INSERT INTO {table_name} VALUES ({','.join(['%s']*len(next(csv_reader)))})"
# Iterate over the CSV rows and execute the insert query for each row
for row in csv_reader:
cur.execute(insert_query, row)
# Commit the changes to the database
conn.commit()
# Close the cursor and connection
cur.close()
load_csv_with_insert()
conn.close()
我使用计时装饰器来测量加载所需的时间。这是我五个最喜欢的 Python 装饰器之一。
## 我在几乎所有数据科学项目中使用的 5 个 Python 装饰器
装饰器为从缓存到发送通知的一切提供了一种新的便利方式。
[towardsdatascience.com
此外,我故意将数据库保存在本地。因此,带宽不再是考虑因素。
早些时候,我真的认为这可能是加载数据的最快方式。这是因为我们使用游标同时插入数据并提交。但整个过程花费了这么多时间:
Function load_csv_with_insert took 1046.109834432602 seconds to run.
好吧,我们如何在没有参考的情况下知道这太慢了?
让我们尝试一下最流行的方法。
加速你的 SQL 学习曲线。
[towardsdatascience.com
糟糕的:使用 Pandas 的 to_sql 加载大规模数据集。
如果你经常使用 Pandas 及其to_sql API,这可能会让你惊讶。
我一直在使用它,并且继续使用。但这仍然不是处理大规模数据集的最佳方法。
这是我的代码。我使用的是与之前相同的数据库和 CSV。在开始加载数据集之前,我已经截断了表格。
import pandas as pd
from sqlalchemy import create_engine
import time
...
conn = create_engine("postgresql+psycopg2://thuwa:Flora1990@localhost:5432/playground")
@timing_decorator
def load_csv_with_pandas():
df = pd.read_csv(csv_file_path)
df.to_sql(table_name, conn, if_exists="append", chunksize=100, index=False)
load_csv_with_pandas()
在上述代码中,我没有使用流行的method参数。将它们设置为 [multi](https://pandas.pydata.org/docs/user_guide/io.html#io-sql-method) 将加速数据加载到分析数据库 中,如 Redshift。但在我们的案例中,我们使用的是事务性数据库。
回到重点,这种方法花费了我们 376 秒来加载数据。
Function load_csv_with_pandas took 376.70790338516235 seconds to run.
这使得 Pandas 比使用游标加载数据要好得多,效率提升约 3 倍。
那么是什么让它不是最快的呢?毕竟,这是我最喜欢和最常用的方法。
好处:使用 COPY 方法
有一种本地方式可以将文本数据上传到 SQL 表中。像 psycopg2 这样的库可以直接提供这种功能。
是的,我们可以直接将文件复制到 Python 中的 SQL 表。
...
@timing_decorator
def load_csv_with_copy():
# Create a cursor object
cur = conn.cursor()
# Use the COPY command to load the CSV file into the table
with open(csv_file_path, "r") as f:
next(f) # skip header row
cur.copy_from(f, table_name, sep=",")
conn.commit()
# Close the cursor and connection
cur.close()
游标中的 copy_from 方法使用 SQL 客户端 API 中的 COPY 方法,并直接将文件上传到 SQL。我们还传递了额外的参数。
在这种情况下,我们指定用逗号分隔列。我们还使用 next 方法跳过第一行,即表头。
结果如下:
Function load_csv_with_copy took 50.60594058036804 seconds to run.
惊人的 50 秒 — 比使用游标快 20 倍,比 to_sql Pandas 方法快近 8 倍。
但等一下!
我提到我使用 Python,因为通常会进行数据预处理。其他两种方法很直接。但我们如何在 Python 运行时上传现有数据集呢?
## 初级开发者编写多页 SQL 查询;高级开发者使用窗口函数
在记录的上下文中执行计算的优雅方式
使用 COPY 写入内存中存在的数据集
在这里我们可以从缓冲区方法中受益。下面的代码可以快速加载现有 Pandas 数据框到 SQL 数据库。
import io
...
@timing_decorator
def load_dataframe_with_copy(df):
# Create a cursor object
cur = conn.cursor()
# Convert the DataFrame to a CSV file-like object
csv_buffer = io.StringIO()
df.to_csv(csv_buffer, index=False, header=False)
# Use the COPY command to load the CSV file into the table
csv_buffer.seek(0)
cur.copy_from(csv_buffer, table_name, sep=",")
conn.commit()
# Close the cursor and connection
cur.close()
df = pd.read_csv(csv_file_path)
# Do data processing on df
load_dataframe_with_copy(df)
这段代码创建了一个名为 csv_buffer 的 StringIO 对象,它是一个类似于 CSV 文件的文件类对象。使用 to_csv() 方法将 DataFrame 写入该对象,index=False 和 header=False 以排除 CSV 输出中的索引和标题。
然后在 csv_buffer 对象上调用 seek(0) 方法,将文件指针移动回文件类对象的开头。
结论
处理大数据集与处理普通数据集不同。一个具有挑战性的任务是将这些巨大的数据集加载到数据库中。
除非不需要预处理,我几乎总是使用 Pandas 的 to_sql 方法。然而,我的实验表明,这对大数据集来说不是最好的方法。
COPY 方法是我见过的最快的方法。虽然我建议在受控环境中批量加载时使用这个方法,但对于日常任务,to_sql 提供了一个很棒的接口来调整多个上传行为。
感谢阅读,朋友!如果你喜欢我的文章,让我们在 LinkedIn、Twitter 和 Medium 保持联系。
还不是 Medium 会员?请使用这个链接来成为会员,因为这样你无需额外付费,我可以通过推荐你获得少量佣金。
使用 Polars 进行快速字符串处理——诈骗邮件数据集
原文:
towardsdatascience.com/fast-string-processing-with-polars-scam-emails-dataset-fcf7054a929a
使用内置的 Polars 字符串表达式在毫秒级别清理、处理和标记文本
·发布于Towards Data Science ·10 分钟阅读·2023 年 5 月 28 日
--

图片由Stephen Phillips - Hostreviews.co.uk提供,来自Unsplash
引言
随着大型语言模型(LLMs)的广泛采用,我们可能会觉得已经不再需要手动清理和处理文本数据。不幸的是,我和其他 NLP 从业者可以证明,情况并非如此。在每一个 NLP 复杂性的阶段——从基础文本分析到机器学习和 LLMs——都需要干净的文本数据。本文将展示如何使用 Polars 显著加快这一繁琐和乏味的过程。
Polars
Polars是一个极其快速的 Rust 编写的数据框架库,处理字符串非常高效(得益于其 Arrow 后端)。Polars 以Utf8格式存储字符串,并使用Arrow后端,使得字符串遍历缓存最优且可预测。此外,它在str命名空间下暴露了许多内置的字符串操作,这使得字符串操作可以并行处理。这两个因素使得处理字符串变得极其简单和快速。
这个库与 Pandas 有很多相似的语法,但也有很多需要适应的细节。本文将引导你了解字符串操作,但为了全面了解,我强烈推荐这个“入门指南”,它会给你一个很好的库概览。
设置
你可以在这个GitHub 仓库中找到所有代码,所以如果你想跟着编码,记得拉取它(别忘了⭐)。为了使这篇文章更实用有趣,我将展示如何清理一个小型诈骗电子邮件数据集,该数据集可以在Kaggle找到(许可证CC BY-SA 4.0)。Polars 可以通过 pip 安装——pip install polars,推荐的 Python 版本是3.10。
文本处理管道
这个管道的目标是将原始文本文件解析为一个数据框,以便用于进一步的分析/建模。以下是将要实现的整体步骤:
-
读取文本数据
-
提取相关字段(例如,发件人电子邮件、对象、文本等)
-
从这些字段中提取有用的特征(例如,长度、数字比例等)
-
预处理文本以进行进一步分析
-
执行一些基本的文本分析
言归正传,让我们开始吧!
读取数据
假设保存了包含电子邮件的文本文件为fraudulent_emails.txt,以下是用于读取它们的函数:
def load_emails_txt(path: str, split_str: str = "From r ") -> list[str]:
with open(path, "r", encoding="utf-8", errors="ignore") as file:
text = file.read()
emails = text.split(split_str)
return emails
如果你查看文本数据,你会发现电子邮件有两个主要部分
-
元数据(以
From r开头),包含发件人、主题等。 -
电子邮件文本(从
Status: O或Status: RO后开始)
我使用第一个模式将连续的文本文件拆分为电子邮件列表。总体而言,我们应该能够读取 3977 封电子邮件,并将其放入 Polars 数据框中以进行进一步分析。
emails = load_emails_txt("fradulent_emails.txt")
emails_pl = pl.DataFrame({"emails": emails})
print(len(emails))
>>> 3977
提取相关字段
现在,棘手的部分开始了。我们如何从这堆混乱的文本数据中提取相关字段?不幸的是,答案是正则表达式。
发件人和主题
进一步检查元数据(如下所示)你会发现它有From:和Subject:字段,这对我们非常有用。
From r Wed Oct 30 21:41:56 2002
Return-Path: <james_ngola2002@maktoob.com>
X-Sieve: cmu-sieve 2.0
Return-Path: <james_ngola2002@maktoob.com>
Message-Id: <200210310241.g9V2fNm6028281@cs.CU>
From: "MR. JAMES NGOLA." <james_ngola2002@maktoob.com>
Reply-To: james_ngola2002@maktoob.com
To: webmaster@aclweb.org
Date: Thu, 31 Oct 2002 02:38:20 +0000
Subject: URGENT BUSINESS ASSISTANCE AND PARTNERSHIP
X-Mailer: Microsoft Outlook Express 5.00.2919.6900 DM
MIME-Version: 1.0
Content-Type: text/plain; charset="us-ascii"
Content-Transfer-Encoding: 8bit
Status: O
如果你继续滚动电子邮件,你会发现From:字段有几种格式。你看到的第一个格式是包含姓名和电子邮件的格式。第二种格式仅包含电子邮件,例如From: 123@abc.com或From: “123@abc.com”。考虑到这一点,我们需要三个正则表达式模式——一个用于主题,两个用于发件人(姓名和电子邮件,以及仅电子邮件)。
email_pattern = r"From:\s*([^<\n\s]+)"
subject_pattern = r"Subject:\s*(.*)"
name_email_pattern = r'From:\s*"?([^"<]+)"?\s*<([^>]+)>'
Polars 有一个str.extract方法,可以将上述模式与我们的文本进行比较,并(你猜对了)提取匹配的组。以下是如何将其应用于emails_pl数据框。
emails_pl = emails_pl.with_columns(
# Extract the first match group as email
pl.col("emails").str.extract(name_email_pattern, 1).alias("sender_name"),
# Extract the second match group as email
pl.col("emails").str.extract(name_email_pattern, 2).alias("sender_email"),
# Extract the subject
pl.col("emails").str.extract(subject_pattern, 1).alias("subject"),
).with_columns(
# In cases where we didn't extract email
pl.when(pl.col("sender_email").is_null())
# Try another pattern (just email)
.then(pl.col("emails").str.extract(email_pattern, 1))
# If we do have an email, do nothing
.otherwise(pl.col("sender_email"))
.alias("sender_email")
)
如你所见,除了str.extract之外,我们还使用了pl.when().then().otherwise()表达式(Polars 的 if/else 版本)来处理仅存在于第二个电子邮件模式的情况。如果你打印出结果,你会发现大多数情况下它应该正确地工作(而且速度极快)。我们现在有了sender_name、sender_email和subject字段用于分析。

Polars 数据框样本。作者截图。
电子邮件文本
如上所述,实际的电子邮件文本从 Status: O(已打开)或 Status: RO(已读并已打开)之后开始,这意味着我们可以利用这个模式将电子邮件拆分为“元数据”和“文本”部分。下面你可以看到提取所需字段的三个步骤以及执行它们的相应 Polars 方法。
-
将
Status: RO替换为Status: O,以便我们只有一个“拆分”模式 — 使用str.replace -
按照
Status: O拆分实际字符串 — 使用str.split -
获取结果列表的第二个元素(文本) — 使用
arr.get(1)
emails_pl = emails_pl.with_columns(
# Apply operations to the emails column
pl.col("emails")
# Make these two statuses the same
.str.replace("Status: RO", "Status: O", literal=True)
# Split using the status string
.str.split("Status: O")
# Get the second element
.arr.get(1)
# Rename the field
.alias("email_text")
)
看!我们在短短几毫秒内提取了重要字段。让我们将这些功能放到一个连贯的函数中,稍后在管道中使用。
def extract_fields(emails: pl.DataFrame) -> pl.DataFrame:
email_pattern = r"From:\s*([^<\n\s]+)"
subject_pattern = r"Subject:\s*(.*)"
name_email_pattern = r'From:\s*"?([^"<]+)"?\s*<([^>]+)>'
emails = (
emails.with_columns(
pl.col("emails").str.extract(name_email_pattern, 2).alias("sender_email"),
pl.col("emails").str.extract(name_email_pattern, 1).alias("sender_name"),
pl.col("emails").str.extract(subject_pattern, 1).alias("subject"),
)
.with_columns(
pl.when(pl.col("sender_email").is_null())
.then(pl.col("emails").str.extract(email_pattern, 1))
.otherwise(pl.col("sender_email"))
.alias("sender_email")
)
.with_columns(
pl.col("emails")
.str.replace("Status: RO", "Status: O", literal=True)
.str.split("Status: O")
.arr.get(1)
.alias("email_text")
)
)
return emails
现在,我们可以继续进行特征生成部分。
特征工程
根据个人经验,诈骗电子邮件往往非常详细且长(因为骗子试图赢得你的信任),所以电子邮件的字符长度会非常有信息量。此外,它们大量使用感叹号和数字,因此计算电子邮件中非字符的比例也可能很有用。最后,骗子喜欢使用大写字母,所以我们也来计算大写字母的比例。当然,我们还可以创建更多的特征,但为了不让这篇文章过长,我们就专注于这两个特征。
第一个特征可以通过内置的 str.n_chars() 函数非常容易地创建。其他两个特征可以使用正则表达式和 str.count_match() 计算。下面你可以找到计算这三个特征的函数。与之前的函数类似,它使用 with_columns() 子句来保留旧特征并在其上创建新的特征。
def email_features(data: pl.DataFrame, col: str) -> pl.DataFrame:
data = data.with_columns(
pl.col(col).str.n_chars().alias(f"{col}_length"),
).with_columns(
(pl.col(col).str.count_match(r"[A-Z]") / pl.col(f"{col}_length")).alias(
f"{col}_percent_capital"
),
(pl.col(col).str.count_match(r"[^A-Za-z ]") / pl.col(f"{col}_length")).alias(
f"{col}_percent_digits"
),
)
return data
文本清理
如果你打印出我们提取的一些电子邮件,你会注意到一些需要清理的内容。例如:
-
一些电子邮件中仍然存在 HTML 标签
-
使用了很多非字母字符
-
一些电子邮件是大写字母,一些是小写字母,还有一些是混合的
与上述相同,我们将使用正则表达式来清理数据。然而,现在选择的方法是 str.replace_all,因为我们想要替换所有匹配的实例,而不仅仅是第一个。此外,我们将使用 str.to_lowercase() 将所有文本转换为小写。
emails_pl = emails_pl.with_columns(
# Apply operations to the emails text column
pl.col("email_text")
# Remove all the data in <..> (HTML tags)
.str.replace_all(r"<.*?>", "")
# Replace non-alphabetic characters (except whitespace) in text
.str.replace_all(r"[^a-zA-Z\s]+", " ")
# Replace multiple whitespaces with one whitespace
# We need to do this because of the previous cleaning step
.str.replace_all(r"\s+", " ")
# Make all text lowercase
.str.to_lowercase()
# Keep the field's name
.keep_name()
)
现在,让我们将这串操作重构成一个函数,以便它可以应用于其他感兴趣的列。
def email_clean(
data: pl.DataFrame, col: str, new_col_name: str | None = None
) -> pl.DataFrame:
data = data.with_columns(
pl.col(col)
.str.replace_all(r"<.*?>", " ")
.str.replace_all(r"[^a-zA-Z\s]+", " ")
.str.replace_all(r"\s+", " ")
.str.to_lowercase()
.alias(new_col_name if new_col_name is not None else col)
)
return data
文本标记化
作为预处理管道的最后一步,我们将对文本进行标记化。标记化将使用已经熟悉的方法 str.split(),其中分隔符将指定为空格。
emails_pl = emails_pl.with_columns(
pl.col("email_text").str.split(" ").alias("email_text_tokenised")
)
再次,将这段代码放入一个函数中,以便在最终的管道中使用。
def tokenise_text(data: pl.DataFrame, col: str, split_token: str = " ") -> pl.DataFrame:
data = data.with_columns(pl.col(col).str.split(split_token).alias(f"{col}_tokenised"))
return data
去除停用词
如果你以前处理过文本数据,你会知道去除停用词是处理标记化文本的关键步骤。去除这些词可以让我们将分析集中在文本的重要部分。
为了删除这些词,我们首先需要定义它们。在这里,我将使用来自nltk库的默认停用词集以及一组与 HTML 相关的词汇。
stops = set(
stopwords.words("english")
+ ["", "nbsp", "content", "type", "text", "charset", "iso", "qzsoft"]
)
现在,我们需要找出这些词是否存在于分词数组中,如果存在,我们需要将它们删除。为此,我们需要使用arr.eval方法,因为它允许我们对分词列表的每个元素运行 Polars 表达式(例如 .is_in)。确保阅读下面的评论以理解每一行的作用,因为这部分代码比较复杂。
emails_pl = emails_pl.with_columns(
# Apply to the tokenised column (it's a list)
pl.col("email_text_tokenised")
# For every element, check if it's not in a stopwords list and only then return it
.arr.eval(
pl.when(
(~pl.element().is_in(stopwords)) & (pl.element().str.n_chars() > 2)
).then(pl.element())
)
# For every element of a new list, drop nulls (previously items that were in stopwords list)
.arr.eval(pl.element().drop_nulls())
.keep_name()
)
和往常一样,让我们将这段代码重构为一个函数,用于我们的最终管道。
def remove_stopwords(
data: pl.DataFrame, stopwords: set | list, col: str
) -> pl.DataFrame:
data = data.with_columns(
pl.col(col)
.arr.eval(pl.when(~pl.element().is_in(stopwords)).then(pl.element()))
.arr.eval(pl.element().drop_nulls())
)
return data
虽然这个模式可能看起来相当复杂,但使用预定义的str和arr表达式来优化性能是非常值得的。
完整管道
到目前为止,我们已经定义了预处理函数,并了解了它们如何应用于单个列。Polars 提供了一个非常实用的pipe方法,允许我们将指定为函数的 Polars 操作串联起来。这就是最终管道的样子:
emails = load_emails_txt("fradulent_emails.txt")
emails_pl = pl.DataFrame({"emails": emails})
emails_pl = (
emails_pl.pipe(extract_fields)
.pipe(email_features, "email_text")
.pipe(email_features, "sender_email")
.pipe(email_features, "subject")
.pipe(email_clean, "email_text")
.pipe(email_clean, "sender_name")
.pipe(email_clean, "subject")
.pipe(tokenise_text, "email_text")
.pipe(tokenise_text, "subject")
.pipe(remove_stopwords, stops, "email_text_tokenised")
.pipe(remove_stopwords, stops, "subject_tokenised")
)
注意,现在我们可以轻松地将所有特征工程、清理和分词函数应用于所有提取的列,而不仅仅是像上面示例中的电子邮件文本。
词云分析
如果你已经做到这一点——干得好!我们已经读取、清洗、处理、分词,并对大约 4k 文本记录进行了基本特征工程,全部在一秒钟内(至少在我的 Mac M2 机器上)。现在,让我们享受劳动的成果,进行一些基本的文本分析。
首先,让我们看看电子邮件文本的词云,并惊叹于我们能找到的所有有趣的东西。
# Word cloud function
def generate_word_cloud(text: str):
wordcloud = WordCloud(
max_words=100, background_color="white", width=1600, height=800
).generate(text)
plt.figure(figsize=(20, 10), facecolor="k")
plt.imshow(wordcloud)
plt.axis("off")
plt.tight_layout(pad=0)
plt.show()
# Prepare data for word cloud
text_list = emails_pl.select(pl.col("email_text_tokenised").arr.join(" "))[
"email_text_tokenised"
].to_list()
all_emails = " ".join(text_list)
generate_word_cloud(all_emails)

电子邮件文本词云。由作者生成。
银行账户、近亲、保安公司和逝者亲属——应有尽有。让我们看看这些在使用简单 TF-IDF 和 K-Means 创建的文本聚类中的样子。
# TF-IDF with 500 words
vectorizer = TfidfVectorizer(max_features=500)
transformed_text = vectorizer.fit_transform(text_list)
tf_idf = pd.DataFrame(transformed_text.toarray(), columns=vectorizer.get_feature_names_out())
# Cluster into 5 clusters
n = 5
cluster = KMeans(n_clusters=n, n_init='auto')
clusters = cluster.fit_predict(tf_idf)
for c in range(n):
cluster_texts = np.array(text_list)[clusters==c]
cluster_text = ' '.join(list(cluster_texts))
generate_word_cloud(cluster_text)
下面你可以看到我识别的一些有趣的聚类:

除此之外,我还发现了一些无意义的聚类,这意味着在文本清理方面仍有改进空间。不过,看起来我们成功地提取了有用的聚类,所以我们可以称之为成功。告诉我你发现了哪些聚类!
结论
这篇文章涵盖了 Polars 库允许你执行的各种预处理和清理操作。我们已经看到如何使用 Polars 来:
-
从文本中提取特定模式
-
根据标记将文本拆分为列表
-
计算文本中的长度和匹配数量
-
使用正则表达式清理文本
-
对文本进行分词并过滤停用词
我希望这篇文章对你有用,你会在下一个 NLP 项目中给 Polars 一个机会。请考虑订阅、点赞并在下面评论。
还不是 Medium 会员?
[## 使用我的推荐链接加入 Medium — Antons Tocilins-Ruberts
阅读 Antons Tocilins-Ruberts 和其他成千上万的作者在 Medium 上的每一个故事。你的会员费用直接…
medium.com](https://medium.com/@antonsruberts/membership?source=post_page-----fcf7054a929a--------------------------------)
参考文献
Radev, D. (2008), CLAIR collection of fraud email, ACL 数据和代码库, ADCR2008T001, http://aclweb.org/aclwiki
项目 Github
github.com/aruberts/tutorials/tree/main/metaflow/fraud_emailPolars 用户指南
pola-rs.github.io/polars-book/user-guide/
FastAPI 和 Streamlit:你必须了解的 Python 双雄
原文:
towardsdatascience.com/fastapi-and-streamlit-the-python-duo-you-must-know-about-72825def1243
完整的 7 步 MLOps 框架
第 6 课:使用 FastAPI 和 Streamlit 消耗和可视化模型预测。对所有内容进行 Docker 化
·发表于 Towards Data Science ·阅读时间 14 分钟·2023 年 6 月 12 日
--

图片由 Hassan Pasha 提供,Unsplash上的照片
本教程代表7 节课程中的第六部分课,将逐步指导你如何设计、实现和部署 ML 系统,并应用MLOps 良好实践。在课程中,你将构建一个生产级模型,用于预测丹麦多个消费者类型在接下来的 24 小时内的能源消耗水平。
完成本课程后,你将理解如何设计、编码和部署一个使用批量服务架构的 ML 系统的所有基础知识。
本课程针对希望通过构建自己的端到端项目来提升技能的中级/高级机器学习工程师。
如今,证书随处可见。构建先进的端到端项目并展示是获得专业工程师认可的最佳方式。
目录:
-
课程简介
-
课程内容
-
数据来源
-
第 6 课:使用 FastAPI 和 Streamlit 消耗和可视化模型预测。对所有内容进行 Docker 化。
-
第 6 课:代码
-
结论
-
参考文献
课程简介
在这 7 节课程结束时,你将知道如何:
-
设计一个批量服务架构
-
使用 Hopsworks 作为特征存储
-
设计一个从 API 读取数据的特征工程管道
-
构建一个带有超参数调优的训练管道
-
使用 W&B 作为 ML 平台来跟踪实验、模型和元数据
-
实现一个批量预测管道
-
使用 Poetry 构建你自己的 Python 包
-
部署你自己的私有 PyPi 服务器
-
用 Airflow 编排一切
-
使用预测来编码一个使用 FastAPI 和 Streamlit 的 Web 应用
-
使用 Docker 将你的代码容器化
-
使用 Great Expectations 确保数据验证和完整性
-
监控预测结果的性能随时间的变化
-
将所有内容部署到 GCP
-
使用 GitHub Actions 构建 CI/CD 管道
如果这听起来很多,不用担心。在你完成这个课程后,你将理解我之前所说的一切。最重要的是,你会知道我为什么使用这些工具以及它们如何作为一个系统协同工作。
如果你想最大限度地从这个课程中受益, 我建议你访问包含所有课程代码的 GitHub 仓库 。本课程旨在快速阅读和复制文章中的代码。
到课程结束时,你将知道如何实现下图。即使有些内容对你来说不太明白,也不用担心。我会详细解释一切。

课程中你将构建的架构示意图 [作者提供的图片]。
到第 6 课结束时,你将知道如何使用 FastAPI 和 Streamlit 从 GCP 存储桶中消费预测和监控指标。
课程内容:
-
私有 PyPi 服务器。用 Airflow 编排一切。
-
使用 GE 进行数据质量和完整性验证。模型性能的持续监控。
-
使用 FastAPI 和 Streamlit 消费和可视化你的模型预测。将一切 Docker 化。
查看 第 3 课 了解我们如何计算和存储 GCP 存储桶中的预测结果。
此外,在第五部分,你可以查看我们如何计算监控指标,这些指标也存储在 GCP 存储桶中。
你将从 GCP 存储桶中获取预测结果和监控指标,并使用 FastAPI 和 Streamlit 将其显示在一个友好的仪表板上。
数据源
我们使用了一个免费的开放 API,它提供了丹麦所有能源消费者类型的每小时能源消耗值[1]。
它们提供了一个直观的界面,你可以轻松查询和可视化数据。你可以在这里访问数据 [1]。
数据具有 4 个主要属性:
-
小时 UTC: 观测数据点时的 UTC 日期时间。
-
价格区域: 丹麦被划分为两个价格区域:DK1 和 DK2 —— 由大贝尔特分隔。DK1 位于大贝尔特以西,DK2 位于大贝尔特以东。
-
消费者类型: 消费者类型为行业代码 DE35,由丹麦能源公司拥有和维护。
-
总消耗: 总电力消耗(单位:千瓦时)
注意: 观测数据有 15 天的滞后!但对于我们的演示用例,这不是问题,因为我们可以模拟与实时相同的步骤。

我们的网页应用程序的截图,显示了我们如何预测区域=1 和消费者类型=212 的能源消耗[图片由作者提供]。
数据点具有每小时分辨率。例如:“2023–04–15 21:00Z”,“2023–04–15 20:00Z”,“2023–04–15 19:00Z”等。
我们将数据建模为多个时间序列。每个唯一的价格区域和消费者类型元组代表一个独特的时间序列。
因此,我们将构建一个模型,为每个时间序列独立预测接下来的 24 小时的能源消耗。
查看下面的视频以更好地理解数据的样子 👇
课程和数据源概述 [视频由作者提供]。
第六部分:使用 FastAPI 和 Streamlit 获取并可视化模型的预测。将一切 Docker 化。
第六部分的目标
在第六部分中,你将构建一个 FastAPI 后端,该后端将从 GCS 中获取预测结果和监控指标,并通过 RESTful API 暴露这些数据。更具体地说,通过一组端点通过 HTTP(S)暴露数据。
此外,你将使用 Streamlit 实现两个不同的前端应用程序:
-
显示预测的仪表板(即你的应用程序),
-
显示监控指标的仪表板(即你的监控仪表板)。
两个前端应用程序将通过 HTTP(s)从 FastAPI RESTful API 请求数据,并使用 Streamlit 将数据呈现成一些美丽的图表。
我想强调的是,你可以在 Python 中同时使用这两个框架(FastAPI 和 Streamlit)。这对数据科学家或机器学习工程师非常有用,因为 Python 是他们的终极工具。

最终架构的示意图,Lesson 6 组件以蓝色突出显示 [图片来源:作者]。
请注意,从存储桶中获取预测结果与 3 个管道设计完全解耦。例如,运行 3 个管道:特征工程、训练和推理大约需要 10 分钟。但从存储桶中读取预测结果或监控指标几乎是瞬间完成的。
因此,通过将预测结果缓存到 GCP,你从客户端的角度在线提供了 ML 模型:预测结果是实时提供的。
这就是批处理架构的魔力。
接下来的自然步骤是将你的架构从批处理架构迁移到请求-响应或流式架构。
好消息是 FE 和训练管道几乎是相同的,你只需要将批处理预测管道(即推理步骤)迁移到你的 web 基础设施中。阅读这篇文章以了解使用 Docker 以请求-响应方式部署模型的基础知识。
为什么?
因为训练管道将训练模型的权重上传到模型注册表。从那里,你可以根据你的用例需求使用这些权重。
理论概念与工具
FastAPI: 最新且最著名的 Python API web 框架之一。我尝试过所有顶级 Python API web 框架:Django、Flask 和 FastAPI,我的心属 FastAPI。
为什么?
首先,它本地支持异步,这可以用更少的计算资源提升性能。
其次,它使用简单直观,适合各种规模的应用程序。尽管如此,对于庞大的单体应用,我仍然会选择 Django。但这是另一个话题。
Streamlit: Streamlit 使得用 Python 轻松创建简单的 UI 组件(主要是仪表盘)变得非常简单。
Streamlit 的范围是让数据科学家和 ML 工程师利用他们最擅长的东西,即 Python,快速构建他们模型的漂亮前端。
这正是我们所做的✌️
因此,你将使用 FastAPI 作为后端,Streamlit 作为前端,仅用 Python 构建一个 web 应用。
Lesson 6:代码
注意: 所有的安装说明都在仓库的 README 文件中。这里你将直接进入代码部分。
Lesson 6 中的代码位于以下位置:
-
app-api 文件夹 — FastAPI 后端
-
app-frontend 文件夹 — 预测仪表盘
-
app-monitoring 文件夹 — 监控仪表盘
使用 Docker,你可以迅速启动所有三个组件:
docker compose -f deploy/app-docker-compose.yml --project-directory . up --build
直接将凭证存储在你的 git 仓库中是一个巨大的安全风险。这就是为什么你将通过.env文件注入敏感信息。
.env.default是你必须配置的所有变量的示例。它还帮助存储不敏感的属性的默认值(例如,项目名称)。

.env.default文件的截图[作者提供的图片]。
准备凭证
对于本课程,你需要访问的唯一服务是 GCS。在第 3 课的准备凭证部分,我们已经详细解释了如何操作。此外,你还可以在GitHub README中找到更多信息。
为了保持简洁,在本课程中,我想强调的是,Web 应用的 GCP 服务账户应仅具有读取权限,以保证安全。
为什么?
因为 FastAPI API 只会读取 GCP 存储桶中的数据,并且保持最小权限是一种良好的实践。
因此,如果你的 Web 应用被黑客入侵,攻击者只能使用被盗的服务账户凭证读取数据。他不能删除或覆盖数据,这在这种情况下要安全得多。
因此,重复第 3 课的准备凭证部分中的相同步骤,但选择Storage Object Viewer role角色,而不是Store Object Admin角色。
记住,你现在需要下载一个不同的 JSON 文件,其中包含你的 GCP 服务账户密钥,并具有只读访问权限。
查看README了解如何完成.env文件。我想强调的是,只有 FastAPI 后端需要加载.env文件。因此,你必须将.env文件仅放在app-api文件夹中。
FastAPI 后端
FastAPI 后端概述[作者提供的视频]。
提醒一下,FastAPI 代码可以在app-api/api下找到。
步骤 1: 创建 FastAPI 应用,在其中配置文档、CORS 中间件和端点根 API 路由器。
步骤 2: 定义 Settings 类。该类的作用是保存 API 代码中需要的所有常量和配置,例如:
-
通用配置: 端口、日志级别或版本,
-
GCP 凭证: 存储桶名称或 JSON 服务账户密钥的路径。
你将在项目中使用get_settings()函数来使用 Settings 对象。
同时,在 Config 类中,我们编程使 FastAPI 查找当前目录中的 .env 文件,并加载所有以 APP_API_ 为前缀的变量。
如你在 .env.default 文件中所见,所有变量都以 APP_API_ 开头。

.env.default 文件的截图 [图片由作者提供]。
步骤 3: 使用 Pydantic 定义 API 数据的模式。这些模式将数据从 JSON 编码或解码为 Python 对象,反之亦然。同时,它们根据你定义的数据模型验证 JSON 对象的类型和结构。
在定义 Pydantic BaseModel 时,添加类型到每个变量是至关重要的,这将用于验证步骤。
步骤 4: 定义你的端点,在网络术语中称为视图。通常,一个视图可以访问某些数据存储,并根据查询,它将数据源的一个子集返回给请求者。
因此,检索(即 GET 请求)数据的标准流程如下:
“client → 请求数据 → 端点 → 访问数据存储 → 编码为 Pydantic 模式 → 解码为 JSON → 返回请求的数据”
让我们看看我们是如何定义一个端点来获取所有消费者类型的:
我们使用了 "gcsfs.GCSFileSystem" 作为标准文件系统来访问 GCS 存储桶。
我们将端点附加到api_router。
使用 api_router.get() Python 装饰器,我们将一个基本函数附加到 /consumer_type_values 端点。
在上面的示例中,当调用 "https://<some_ip>:8001/api/v1/consumer_type_values" 时,将触发 consumer_type_values() 函数,端点的响应将严格基于该函数的返回值。
另一个重要的事情是,通过在 Python 装饰器中定义 response_model(即模式),你不必显式地创建 Pydantic 模式。
如果你返回一个与模式结构 1:1 匹配的字典,FastAPI 将自动为你创建 Pydantic 对象。
就这些。现在我们将重复相同的逻辑来定义其余的端点。FastAPI 使一切变得如此简单直观。
现在,让我们看一下整个 views.py 文件,在其中我们定义了以下端点:
-
/health → 健康检查
-
/consumer_type_values → 获取所有可能的消费者类型
-
/area_values → 获取所有可能的区域类型
-
/predictions/{area}/{consumer_type} → 获取给定区域和消费者类型的预测。请注意,使用 {<some_variable>} 语法,你可以向端点添加参数 —— FastAPI 文档 [2]。
-
/monitoring/metrics → 获取汇总的监控指标
-
/monitoring/values/{area}/{consumer_type} → 获取给定区域和消费者类型的监控值
我想再次强调,FastAPI 后端只读取 GCS 桶的预测。推理步骤完全在批量预测管道中完成。
你还可以访问“http://

Swapper API 文档的截图[作者的图片]。
就是这样!现在你知道如何构建 FastAPI 后端了。添加数据库层和用户会话可能会使事情变得更复杂,但你已经掌握了所有主要概念,这将帮助你入门!
Streamlit 预测仪表板
Streamlit 预测仪表板概述[作者的视频]。
访问app-frontend/frontend下的代码。
使用 Streamlit 非常简单。整个 UI 通过下面的代码定义,代码执行以下操作:
-
它定义标题,
-
它向后端请求所有可能的区域类型,并根据此创建一个下拉列表,
-
它向后端请求所有可能的消费者类型,并根据此创建一个下拉列表,
-
基于当前选择的区域和消费者类型,它构建并渲染一个 plotly 图表。
直接了当,对吧?
请注意,我们本可以对 HTTP 请求的状态码进行额外的检查。例如,如果请求状态码与 200 不同,则显示一条文本“服务器宕机”。但我们希望保持简洁,只强调 Streamlit 代码✌️
我们将所有常量移到了一个不同的文件中,以便在整个代码中轻松访问。下一步,你可以通过.env文件使其可配置,类似于 FastAPI 的设置。
现在,让我们看看我们是如何构建图表的🔥
这一部分没有 Streamlit 代码,只有一些 Pandas 和 Plotly 代码。
build_data_plot()函数执行 3 个主要步骤:
-
它从 FastAPI 后端请求某个区域和消费者类型的预测数据。
-
如果响应有效(status_code == 200),则从响应中提取数据并构建一个 DataFrame。否则,它会创建一个空的 DataFrame,以便进一步传递相同的结构。
-
它使用上述计算的 DataFrame 构建一个折线图——plotly 图表。
build_dataframe()函数的作用是接受 2 个列表:
-
一个日期时间的列表,将作为折线图的 X 轴;
-
一组将用作折线图 Y 轴的值;
…并将其转换为 DataFrame。如果一些数据点缺失,我们会将日期时间重采样为 1 小时的频率,以确保数据连续并突出显示缺失的数据点。
非常简单,对吧?这就是人们喜欢 Streamlit 的原因。
Streamlit 监控仪表板
Streamlit 监控仪表板概述[作者的视频]。
监控代码可以在 app-monitoring/monitoring*** 下访问。***
你会发现代码几乎与预测仪表板相同。
在定义 Streamlit UI 结构时,我们还实现了一个包含汇总指标和分隔符的图表。
解耦 UI 组件定义和数据访问的好处在于,你可以在 UI 中注入任何数据,只要尊重预期数据的接口,而无需修改 UI。
build_metrics_plot() 函数几乎与预测仪表板中的 build_data_plot() 函数相同,只是我们从 API 请求的数据不同。
对于监控仪表板中的 build_data_plot() 函数也是如此:
如你所见,所有的数据访问和操作都在 FastAPI 后端处理。Streamlit UI 的工作是请求和展示数据。
很高兴我们只重用了 90% 的预测仪表板代码来构建一个友好的监控仪表板。
用 Docker 包装一切
最后的步骤是将这三个 web 应用程序 Docker 化,并将它们打包到一个 docker-compose 文件中。
因此,我们可以通过一个命令启动整个 web 应用程序:
docker compose -f deploy/app-docker-compose.yml --project-directory . up --build
这里是 FastAPI Dockerfile:
值得一提的是,我们最初只复制并安装了 Poetry 依赖项。因此,当你修改代码时,Docker 镜像将仅从第 19 行开始重建,即复制你的代码。
这是一种常见的策略,利用 Docker 缓存功能在构建镜像时加快开发过程,因为你很少添加新的依赖项,而安装它们是最耗时的步骤。
此外,在 run.sh 中我们调用:
/usr/local/bin/python -m api
但是等一下,命令中没有 Python 文件 😟
其实,你可以在模块内部定义一个 main.py 文件,使你的模块可执行。
因此,当调用 api 模块时,你会调用 main.py 文件:
在我们的例子中,在main.py 文件中,我们使用 uvicorn web 服务器来启动 FastAPI 后端,并用正确的 IP、端口、日志级别等进行配置。
这里是 Streamlit 预测仪表板 Dockerfile:
如你所见,这个 Dockerfile 几乎与用于 FastAPI 后端的那个相同,除了最后的 CMD 命令,这是一个标准的 CLI 命令,用于启动你的 Streamlit 应用程序。
Streamlit 监控仪表板 Dockerfile与预测仪表板 Dockerfile 完全相同。所以在这里重复粘贴是多余的。
好消息是,你可以利用我之前展示的 Dockerfile 模板来 Docker 化大部分 Python 应用程序✌️
最后,让我们看看如何使用 docker-compose 来完成所有工作。你可以在deploy/app-docker-compose.yml文件中找到相关内容:
如你所见,前端和监控服务必须等待 API 启动后才能开始。
另外,只有 API 需要从.env文件中加载凭证。
现在,你只需运行以下命令,Docker 将处理构建镜像和运行容器:
docker compose -f deploy/app-docker-compose.yml --project-directory . up --build
总结
恭喜!你完成了第六课的全栈七步 MLOps 框架课程。这意味着你现在已经理解了如何使用你的机器学习系统的预测来构建你出色的应用程序。
在本课中,你学会了如何:
-
从 GCS 中消费预测和监控指标,
-
构建一个 FastAPI 后端来加载和服务来自 GCS 的数据,
-
在 Streamlit 中实现一个仪表板来展示预测,
-
在 Streamlit 中创建一个监控仪表板来可视化模型的性能。
现在你已经理解了基于批量预测架构的机器学习系统上构建应用程序的灵活性,你可以轻松设计全栈机器学习应用程序。
查看第 7 课,这是全栈七步 MLOps 框架的最后一步,即将所有内容部署到 GCP 并使用 GitHub Actions 构建 CI/CD 管道。
另外, 你可以在这里访问 GitHub 仓库。
💡 我的目标是帮助机器学习工程师在设计和生产化机器学习系统方面提升水平。关注我在LinkedIn或订阅我的每周通讯以获取更多见解!
🔥 如果你喜欢阅读这样的文章并希望支持我的写作,请考虑成为 Medium 会员。使用我的推荐链接,你可以在不增加额外成本的情况下支持我,同时享受 Medium 丰富故事的无限访问权限。
[## 使用我的推荐链接加入 Medium - Paul Iusztin
🤖 加入以获取有关设计和构建生产就绪机器学习系统的独家内容🚀 解锁完全访问权限...
谢谢✌🏼!
参考资料
[1] 丹麦 API 每小时 DE35 行业代码的能源消耗,丹麦能源数据服务
[2] 路径参数,FastAPI 文档
使用 Ray 加速时间序列预测的训练,第三部分,共 3 部分
使用 Ray 和 Ray AIR 通过分布式计算更快地训练多个模型
·
关注 发布于 Towards Data Science ·11 分钟阅读·2023 年 1 月 24 日
--
图片由 StableDiffusion 绘制,日期为 2022 年 1 月 5 日,查询内容为“绘制一幅图像,展示多种不同的深度神经网络时间序列模型同时训练的样子,风格类似 Cy Twombly”。
介绍 / 动机
即使在当前生成式 AI(如稳定扩散、ChatGPT)和大语言模型(LLM)时代,时间序列预测仍然是任何依赖供应链或资源的业务的基础部分。 例如,它可以用于:
-
资源利用预测 用于数据中心资源配置。
所有这些用例的共同点是 在不同数据片段上训练多个模型。 使用分布式计算并行训练、调优和部署成千上万的机器学习模型可能是一项挑战! 典型的时间序列建模软件本身并不具备分布式功能。
本博客将展示我开始将预测工作负载转换为分布式计算的技巧。 我将使用最新的 Ray v2 API,结合 ARIMA 使用 statsforecast、Prophet 和 PyTorch Forecasting 库。数据方面,我将使用流行的 NYC Taxi 数据集,该数据集包含按时间戳和地点记录的历史出租车接送信息。
Ray 是一个开源框架,用于通过分布式计算扩展 AI 工作负载。有关 Ray 的概述,请查看 Ray 文档 或这篇 介绍博客文章。

Ray “AIR”(AI Runtime),自 Ray 2.0 起提供,包括 Ray Data、Ray Train、Ray Tune、RLlib 和 Ray Serve。图片来源于作者。
本博客分为 4 个部分:
-
使用 Ray Core 多进程进行多模型分布式训练。
-
使用 Ray AIR 进行多模型分布式调优。
-
使用 Ray AIR 进行多模型分布式调优,针对更少但更大的模型。
-
使用 Ray AIR 和 Ray Serve 进行多模型分布式部署。
第一部分:使用 Ray Core 多进程进行多模型分布式训练
回到 2021 年 11 月,我写了一篇 博客文章 演示如何在 AWS 云上使用 Ray Core 并行训练多个预测模型(无论是 ARIMA 还是 Prophet)。从那时起, Ray 多进程 是一个巨大的改进,使事情比 Ray Core API 更加简单。
下面是代码大纲。完整的更新代码在 我的 GitHub 上。首先,让我们从几个导入开始:
import time, dateutil
from typing import Tuple, List
import numpy as np
import pandas as pd
# Import libraries for reading from partitioned parquet.
import pyarrow.parquet as pq
import pyarrow.dataset as pds
# Import forecasting libraries.
import prophet
from prophet import Prophet
# Import Ray's multiprocessing library.
import ray
from ray.util.multiprocessing import Pool
import tqdm
接下来,让我们定义 Python 函数来预处理数据、训练和评估模型。为了更快地了解分布式计算概念,我们将假设时间序列数据已经准备好并根据所需模型拆分为不同的文件。
##########
# STEP 1\. Define a Python function to read and preprocess a file of data.
##########
def preprocess_data(file_path: str) -> Tuple[pd.DataFrame, np.int32]:
# Read a single pyarrow parquet S3 file.
data = pq.read_table(file_path,
filters=[ ("pickup_location_id", "=", SAMPLE_UNIQUE_ID) ],
columns=[ "pickup_at", "pickup_location_id", TARGET ])\
.to_pandas()
# Transform data.
data["ds"] = data["pickup_at"].dt.to_period("D").dt.to_timestamp()
data.rename(columns={TARGET: "y"}, inplace=True)
data.rename(columns={"pickup_location_id": "unique_id"}, inplace=True)
data.drop("pickup_at", inplace=True, axis=1)
unique_id = data["unique_id"][0]
return data, unique_id
##########
# STEP 2\. Define Python functions to train and evaluate a model on a file of data.
##########
def train_model(file_path: str) -> Tuple[pd.DataFrame, pd.DataFrame,
'prophet.forecaster.Prophet', np.int32]:
# Prepare data from a single S3 file.
data, unique_id = preprocess_data(file_path)
# Split data into train, test.
train_end = data.ds.max() - relativedelta(days=FORECAST_LENGTH - 1)
train_df = data.loc[(data.ds <= train_end), :].copy()
test_df = data.iloc[-(FORECAST_LENGTH):, :].copy()
# Define Prophet model with 75% confidence interval.
model = Prophet(interval_width=0.75, seasonality_mode="multiplicative")
# Train and fit Prophet model.
model = model.fit(train_df[["ds", "y"]])
return train_df, test_df, model, unique_id
def evaluate_model(model: 'prophet.forecaster.Prophet', train: pd.DataFrame,
valid: pd.DataFrame, input_value: np.int32) -> Tuple[float, pd.DataFrame]:
# Inference model using FORECAST_LENGTH.
future_dates = model.make_future_dataframe(
periods=FORECAST_LENGTH, freq="D")
future = model.predict(future_dates)
# Merge in the actual y-values.
future = pd.merge(future, train[['ds', 'y']], on=['ds'], how='left')
future = pd.merge(future, valid[['ds', 'y']], on=['ds'], how='left')
future['y'] = future.y_x.combine_first(future.y_y)
future.drop(['y_x', 'y_y'], inplace=True, axis=1)
future['unique_id'] = input_value
# Calculate mean absolute forecast error.
temp = future.copy()
temp["forecast_error"] = np.abs(temp["yhat"] - temp["y"])
temp.dropna(inplace=True)
error = np.mean(temp["forecast_error"])
return error, future
############
# STEP 3\. Define a calling function which calls all the above functions,
# and will be called in parallel for every data file.
############
def train__and_evaluate(file_path: str) -> Tuple[pd.DataFrame,
'prophet.forecaster.Prophet', pd.DataFrame, float, np.int16]:
# Read S3 file and train a Prophet model.
train_df, valid_df, model, unique_id = train_model(file_path)
# Inference model and evaluate error.
error, future = evaluate_model(model, train_df, valid_df, unique_id)
return valid_df, model, future, error, unique_id
我们可以直接使用 Ray 核心 API 调用 [ray.remote](https://docs.ray.io/en/latest/ray-core/walkthrough.html) 进行并行化,但 Ray 的 多进程库 作为 Ray 的分布式库之一,使这变得更简单。
下面,将对 pool 的调用包装在 tqdm 中,可以获得一个很好的进度条来监控进度。在内部,Ray 将任务分配给 Ray 集群中的工作节点,这自动处理如容错和批处理优化等问题。
start = time.time()
# Create a pool, where each worker is assigned 1 CPU by Ray.
pool = Pool(ray_remote_args={"num_cpus": 1})
# Use the pool to run `train_model` on the data, in batches of 1.
iterator = pool.imap_unordered(train__and_evaluate, models_to_train, chunksize=1)
# Track the progress using tqdm and retrieve the results into a list.
results = list(tqdm.tqdm(iterator, total=len(models_to_train)))
# Print some training stats.
time_ray_multiprocessing = time.time() - start
print(f"Total number of models: {len(results)}")
print(f"TOTAL TIME TAKEN: {time_ray_multiprocessing/60:.2f} minutes")
print(type(results[0][0]), type(results[0][1]), type(results[0][2]),
type(results[0][3]), type(results[0][4]))

在我的 MacBook Pro 笔记本电脑(配备 8 个 CPU)上运行上述示例时的截图。Ray 多进程的运行时间大约为半分钟,比串行 Python 快 3.5 倍,即 300.0% 的加速。使用更大的集群和/或更大的数据可以获得更多加速。图片由作者提供。
上面,我们可以看到 Ray 作业在不到 1 分钟的时间内训练了 12 个模型。
第二部分:使用 Ray AIR 进行多模型分布式调优
精明的读者可能已经注意到,在上述部分中,Ray 多进程要求数据已经按模型组织成一个文件。但如果你的数据尚未按模型组织怎么办? 使用 Ray AIR,你可以在训练不同模型的同时在同一管道中预处理数据。
另一个问题是,如果你想同时混合使用来自多个库的算法怎么办? Ray Tune,作为 Ray AIR 的一部分,允许你运行并行试验,从任何 Python AI/ML 库和超参数中找到最佳算法选择,每个数据段。
以下是预处理数据和自动调整模型的步骤。虽然这些步骤特定于 Ray AIR 及其 API,但这些步骤通常适用于将串行 Python 转换为分布式 Python:
-
定义 Python 函数来
**preprocess**一个数据段。 -
定义 Python 函数来
**train**和**evaluate**一个数据段上的模型。 -
定义一个调用函数
**train_models**,该函数调用所有上述函数,并将为 Tune 搜索空间中的每个排列并行调用!在这个train_models函数中:
📖 输入参数必须包括一个配置字典参数。
📈 调优指标(模型的损失或错误)必须使用
session.report()计算并报告。✔️ 推荐
Checkpoint(保存)模型以便于容错和后续部署。 -
配置分布式计算缩放。
-
定义 Tune 搜索空间的所有训练参数。
-
(可选)指定超参数搜索策略。
-
运行实验。
下面是我们将添加的附加代码;完整代码在我的 GitHub上。
-
下面的
preprocess_data和train_model函数与之前完全相同,只是它们接受一个文件列表而不是单个文件。 -
train_models函数与train_and_evaluate完全相同,只是它接受一个文件列表而不是单个文件。它还训练配置中传递的算法,而不是固定的算法,并进行检查点保存。
import os
num_cpu = os.cpu_count()
# Import another forecasting library.
import statsforecast
from statsforecast import StatsForecast
from statsforecast.models import AutoARIMA
# Import Ray AIR libraries.
from ray import air, tune
from ray.air import session, ScalingConfig
from ray.air.checkpoint import Checkpoint
##########
# STEP 1\. Define Python functions to read and prepare a segment of data.
##########
def preprocess_per_uniqueid(
s3_files: List[str], sample_location_id: np.int32) -> pd.DataFrame:
# Load data.
df_list = [read_data(f, sample_location_id) for f in s3_files]
df_raw = pd.concat(df_list, ignore_index=True)
# Transform data.
df = transform_df(df_raw)
df.sort_values(by="ds", inplace=True)
return df
##########
# STEP 2\. Define Python functions to train and evaluate a model on a segment of data.
##########
def train_prophet(s3_files: List[str], sample_unique_id: np.int32,
model_type: str) -> Tuple[pd.DataFrame, pd.DataFrame,
'prophet.forecaster.Prophet', np.int32]:
# Prepare data from a list of S3 files.
data = preprocess_per_uniqueid(s3_files, sample_unique_id)
# Split data into train, test.
train_end = data.ds.max() - relativedelta(days=FORECAST_LENGTH - 1)
train_df = data.loc[(data.ds <= train_end), :].copy()
test_df = data.iloc[-(FORECAST_LENGTH):, :].copy()
# Define Prophet model with 75% confidence interval.
if model_type == "prophet_additive":
model = Prophet(interval_width=0.75, seasonality_mode="additive")
elif model_type == "prophet_multiplicative":
model = Prophet(interval_width=0.75, seasonality_mode="multiplicative")
# Train and fit Prophet model.
model = model.fit(train_df[["ds", "y"]])
return train_df, test_df, model
# Train an ARIMA model. Full code not shown here.
def train_arima():
# Evaluate an ARIMA model. Full code not shown here.
def evaluate_arima():
############
# STEP 3\. Define a calling function `train_models`, which calls all
# the above functions, and will be called in parallel for every
# permutation in the Tune search space.
############
def train_models(config: dict) -> None:
# Get Tune parameters
file_list = config['params']['file_list']
model_type = config['params']['algorithm']
sample_unique_id = config['params']['location']
# Train model.
if model_type == "arima":
# Train and fit the Prophet model.
train_df, valid_df, model = \
train_arima(file_list, sample_unique_id)
# Inference model and evaluate error.
error, future = \
evaluate_arima(model, valid_df)
else:
# Train and fit the Prophet model.
train_df, valid_df, model = \
train_prophet(file_list, sample_unique_id, model_type)
# Inference model and evaluate error.
error, future = evaluate_model(model, train_df, valid_df, sample_unique_id)
# Define a model checkpoint using AIR API.
checkpoint = ray.air.checkpoint.Checkpoint.from_dict({
"model": model,
"valid_df": valid_df,
"forecast_df": future,
"location_id": sample_unique_id,
})
metrics = dict(error=error)
session.report(metrics, checkpoint=checkpoint)
############
# STEP 4\. Customize distributed compute scaling.
############
num_training_workers = min(num_cpu - 2, 32)
scaling_config = ScalingConfig(
# Number of distributed workers.
num_workers=num_training_workers,
# Turn on/off GPU.
use_gpu=False,
# Specify resources used for trainer.
trainer_resources={"CPU": 1},
# Try to schedule workers on different nodes.
placement_strategy="SPREAD")
############
# STEP 5\. Define a search space dict of all config parameters.
############
search_space = {
"scaling_config": scaling_config,
"params": {
"file_list": tune.grid_search([files_to_use]),
"algorithm": tune.grid_search(algorithms_to_use),
"location": tune.grid_search(models_to_train),
},
}
# Optional STEP 6\. Specify the hyperparameter tuning search strategy.
##########
# STEP 7\. Run the experiment with Ray AIR APIs.
##########
# Define a tuner object.
tuner = tune.Tuner(
train_models,
param_space=search_space,
tune_config=tune.TuneConfig(
metric="error",
mode="min",
),
run_config=air.RunConfig(
# Redirect logs to relative path instead of default ~/ray_results/.
local_dir="my_Tune_logs",
# Specify name to make logs easier to find in log path.
name="tune_nyc",
),
)
# Fit the tuner object.
results = tuner.fit()

顶部:Ray Tune 试验状态的截图,显示了 768 个候选模型的错误(每 256 个 NYC 出租车接送位置有 3 种算法选择)。在不到 45 分钟内完成训练。左下:Prophet 模型推断+预测图(实际值为黑点,带置信区间的预测为蓝色)用于 NYC 出租车接送位置=165。右下:ARIMA 模型推断+预测图(实际值为蓝色,预测为橙色)用于 NYC 出租车接送位置=237。所有模型使用 Anyscale 在 10 节点的 AWS 集群上训练,集群包括 m5.4xlarges工作节点和一个 m5.2xlarge 头节点。图像由作者提供。
在上述截图中,自 2018 年 1 月以来的数据被分组并汇总到日级别。我曾尝试在 SageMaker 上完成这项工作,仅数据处理就花费了太长时间,更不用说同时调优如此多的模型了。
第三节:多模型分布式调优(更大的 PyTorch 模型)
目标通常是创建一些更大的模型,例如按地理区域划分的模型,其中只有少数几个这样的区域。一年前,即 2021 年 12 月,我写了一篇博客文章,展示了如何使用 Ray Lightning 训练更大的 PyTorch 预测模型。自那时以来,一个重大改进是,感谢 Anyscale Workspaces,笔记本电脑和云之间的代码开发切换更加无缝。
这些较大的模型有时被称为“全球模型”,因为只有一个深度神经网络模型在多个不同时间序列上进行训练,而不是每个时间序列一个模型(Prophet 或 ARIMA)。
请查看 我的 GitHub 以获取完整的 PyTorch Forecasting 代码,其中展示了最新的 Ray AIR API 与 Ray Lightning。你需要将集群 ID 添加到数据中,然后调优步骤与第二部分中看到的相同:
# Import forecasting libraries.
import torch
import pytorch_lightning as pl
import pytorch_forecasting as ptf
import tensorboard as tb
# Import ray libraries.
import ray_lightning
from ray_lightning import RayStrategy
from ray_lightning.tune import get_tune_resources, TuneReportCheckpointCallback
from ray import air, tune
from ray.tune.schedulers import ASHAScheduler
# Define a tuner object.
tuner = tune.Tuner(
tune.with_resources(
train_with_parameters,
resources=get_tune_resources(num_workers=num_training_workers),
),
tune_config=tune.TuneConfig(
metric="loss",
mode="min",
scheduler=scheduler,
),
run_config=air.RunConfig(
# Redirect logs to relative path instead of default ~/ray_results/.
local_dir="my_Tune_logs",
# Specify name to make logs easier to find in log path.
name="ptf_nyc",
),
param_space=FORECAST_CONFIG,
)
# Fit the tuner object.
results = tuner.fit()
# Get checkpoint for best model from results object, code not shown.
# Plot inference forecasts for some unique_ids.
some_unique_ids = [25, 41, 14, 24, 4]
for idx in some_unique_ids:
best_model.plot_prediction(x, raw_predictions, idx=idx)

顶部。在调优六个 PyTorch Forecasting TemporalFusionTransformer 模型时截图 Ray Tune Trial 状态。(3 个学习率,2 个纽约市出租车位置簇)。总运行时间少于 2 分钟。在 2 节点的 AWS 集群(m5.4xlarge 工作节点和一个 m5.2xlarge 主节点)上运行,时间不超过 2 分钟。底部:使用单个模型对多个出租车接客地点进行的推断预测图。
请注意,与 ARIMA 和 Prophet 模型(每个 unique_id 一个模型)不同,这些较大的模型每次包含对多个 unique_ids 的推断。
第四部分:使用 Ray AIR 和 Ray Serve 进行多模型分布式部署
在部署之前,你必须决定你的部署是需要一个在线、始终运行的 http 服务,还是离线的(按需调用的 Python 服务)。下面,我演示了如何使用新的 Ray AIR Predictors 和 Ray Serve 进行离线部署。
离线部署的步骤是:
步骤 1。使用 Ray AIR 检查点实例化一个批量预测器。
步骤 2。创建一些测试数据。
步骤 3。运行 batch_predictor.predict(test_data)。
将上述步骤 3 替换为自定义预测器的以下步骤:
步骤 3。通过使用 Ray 装饰器 @serve.deployment 定义一个 Ray Serve 部署类。
步骤 4。部署预测器。
步骤 5。查询部署并获取结果。
上述步骤 3–5 仅在使用自定义预测器(如 ARIMA、Prophet 或 PyTorch Forecasting)时需要。
否则,对于集成了 Ray AIR 的 ML 库(HuggingFace transformers、PyTorch、TensorFlow、Scikit-learn、XGBoost 或 LightGBM),你只需调用 batch_predictor.predict(test_data) 即可。
继续上一节中关于 PyTorch Forecasting 的示例,下面是部署代码。完整代码在 我的 GitHub 上。
import pickle
import numpy as np
import pandas as pd
import pyarrow
import pyarrow.parquet as pq
# Import forecasting libraries.
import torch
import pytorch_lightning as pl
import pytorch_forecasting as ptf
# Import ray libraries.
import ray
from ray import serve
##########
# STEP 1\. Instantiate a batch predictor from checkpoint.
##########
batch_predictor = ptf.models.TemporalFusionTransformer.load_from_checkpoint(model_path)
##########
# STEP 2\. Create some test data.
##########
# Being lazy, pretend the last test data is our out-of-sample test data.
max_prediction_length = FORECAST_CONFIG['forecast_horizon']
new_prediction_data = df.copy()
new_prediction_data["time_idx"] = new_prediction_data["time_idx"] + max_prediction_length
# Convert data from pandas to PyTorch tensors.
_, _, test_loader = convert_pandas_pytorch_timeseriesdata(
new_prediction_data, FORECAST_CONFIG)
##########
# STEP 3\. Define a Ray Serve deployment class.
##########
@serve.deployment
class ForecastPredictor:
def __init__(self, predictor, test_data):
self.predictor = predictor
self.test_data = test_data
def predict(self):
raw_predictions, x = \
self.predictor.predict(self.test_data, mode="raw", return_x=True)
return x, raw_predictions
def __call__(self):
x, raw_predictions = self.predict()
return [x, raw_predictions]
##########
# STEP 4\. Deploy the predictor.
##########
# Bind arguments to the Class constructor.
my_first_deployment = ForecastPredictor.bind(
predictor=batch_predictor,
test_data=test_loader)
##########
# STEP 5\. Query the deployment and get the result.
##########
# Get handle from serve.run().
handle = serve.run(my_first_deployment)
# ray.get() the results from the handle.
ray_return = ray.get(handle.remote())
new_x = ray_return[0]
new_pred = ray_return[1]

左侧:Ray 仪表板的截图(默认在localhost:8265的主节点上访问)在服务期间的情况。右侧:Ray 仪表板在运行上述示例时的截图。你可以看到在自动缩放中有 5 个峰值,因为在部署最终训练模型之前,我对训练代码进行了 5 次不同的迭代。Anyscale 运行在 3 节点 AWS 集群上,其中包括 m5.4xlarges 工作节点和 1 个 m5.2xlarge 主节点。图片由作者提供。
上述右侧的截图展示了在训练和服务期间 Ray 集群的可观察性。如果你需要对预测结果进行后处理,我在这个笔记本的末尾有一个示例。
结论
总之,本博客展示了如何使用开源 Ray 在分布式计算中并行训练和调整多个模型。模型不必都是相同类型的,可以从任何 AI/ML Python 库中混合匹配。
Ray AIR API 清晰、直观,并隐藏了许多分布式计算的复杂性,因此可以轻松完成许多复杂的任务,如早期停止、ASHA 调度、检查点和部署。
进一步学习:
FastSpeech:论文概述与实现
了解文本到语音的实现方式以及其如何通过变压器(transformers)来实现
·
关注 发表在 Towards Data Science · 9 分钟阅读 · 2023 年 11 月 9 日
--
2019 年,FastSpeech 在神经文本到语音(neural text-to-speech)领域推动了前沿技术,通过显著提高推理速度,同时保持鲁棒性以防止单词重复或遗漏。它还允许在语音和韵律方面对输出语音进行控制。
在这个故事中,我们旨在让你熟悉变压器(transformers)如何用于文本到语音的转换,提供 FastSpeech 论文的简明概述,并指导你如何从零实现它。在此过程中,我们假设你对变压器及其不同组件已经有一定了解。如果没有,我们强烈推荐你先阅读前一篇文章,该文深入探讨了这一主题。

梵高风格的画作,描绘了一个变压器在讲台上对着麦克风讲话——由作者使用 Canva 生成
目录
· 背景
∘ 介绍
∘ 梅尔频谱图
· 论文概述
∘ 介绍
∘ 实验和结果
∘ 架构
∘ 编码器
∘ 长度调节器
∘ 解码器
· 实现
∘ 策略
∘ 完整实现
背景
介绍
传统的文本到语音(TTS)模型依赖于拼接和统计技术。拼接技术通过将来自语音音素数据库(语言中的独立声音单元)的声音拼接起来合成语音。统计技术(例如 HMM)试图建模足够生成波形的语音基本属性。这两种方法通常在生成自然声音或表达情感时存在问题。换句话说,它们往往为给定文本生成不自然或机械的语音。
使用深度学习(神经网络)来进行 TTS 显著提升了语音质量。这些方法通常由两个主要模型组成:第一个模型接收文本并输出相应的梅尔频谱图,第二个模型接收梅尔频谱图并合成语音(称为声码器)。
梅尔频谱图

频谱图由 Flickr 上的官方 CTBTO 照片流提供 CC BY-SA 2.0。
在最基本的形式下,语音波形只是代表空气压力随时间变化的幅度序列。我们可以使用短时傅里叶变换(STFT)将任何波形转换为相应的梅尔频谱图(这是一个矩阵,表示原始波形不同时间窗口中不同频率的幅度)。使用短时傅里叶变换将一段音频映射到其梅尔频谱图很容易;然而,反向操作相当困难,最佳系统方法(例如,Griffin Lim)可能会产生粗糙的结果。一个更好的方法是为这个任务训练一个模型。现有的训练此任务的模型包括 WaveGlow 和 WaveNet。
因此,重申一下,深度学习方法通常通过训练模型来预测与许多文本实例对应的语音梅尔频谱图来处理文本到语音。然后,它依赖另一个模型(称为声码器)将预测的频谱图映射到音频。FastSpeech 使用了 Nvidia 的 WaveGlow 模型。

一个快乐的变压器在写研究论文,画作呈现梵高风格。——由作者使用 Canva 生成
论文概述
介绍
尽管基于变压器的最新 TTS 方法在语音质量上相比传统方法有了显著提升,但这些模型仍然存在三个主要问题:
-
由于变换器的解码器是自回归的,因此它们在推理语音时速度较慢。也就是说,它们依赖于先前生成的块,逐步生成 Mel 频谱图。这同样适用于基于 RNN 和 CNN 的旧深度学习模型。
-
它们不够鲁棒;由于注意力分数(即对齐)的微小错误在顺序生成过程中传播,可能会发生词语跳过或重复现象。
-
他们缺乏一种简单的方法来控制生成语音的特性,例如速度或韵律(如语调)。
FastSpeech 尝试解决所有三个问题。与其他变换器架构相比,它的两个关键差异是:
-
解码器是非自回归的;它是完全可以并行处理的,因此解决了速度问题。
-
它在解码器之前使用了一个长度调节组件,试图确保音素和 Mel 频谱图之间的理想对齐,并移除了交叉注意力组件。
-
长度调节器的操作方式允许通过超参数轻松控制语音速度。韵律的微小属性,如暂停时长,也可以以类似的方式进行控制。
-
作为回报,出于长度调节的目的,它在训练过程中使用了序列级知识蒸馏。换句话说,它依赖于另一个已经训练好的文本到语音模型进行训练(Transformer TTS 模型)。
实验与结果
作者使用了 LJSpeech 数据集,该数据集包括约 24 小时的音频,分布在 13100 个音频片段中(每个片段都有相应的输入文本)。训练任务是输入文本并让模型预测相应的频谱图。约 95.6% 的数据用于训练,其余部分被拆分用于验证和测试。
-
推理速度提升 相比于自回归变换器 TTS 模型,它将推理速度提高了 38 倍(不包括声码器则为 270 倍);因此,得名 FastSpeech。
-
音频质量 利用 20 位母语为英语的评分者的平均意见分数,作者展示了 FastSpeech 的质量与 Transformer TTS 模型和 Tacotron 2(当时的最先进技术)非常接近。
-
鲁棒性 FastSpeech 在 50 个具有挑战性的文本到语音示例中,表现优于 Transformer TTS 和 Tacotron 2,具有零错误率(以跳过和重复为标准),相比之下 Transformer TTS 和 Tacotron 2 的错误率分别为 24% 和 34%。
-
可控性 作者展示了速度和暂停时长控制的有效性。
-
消融实验 作者确认了集成 1D 卷积到变换器中和采用序列级知识蒸馏等决策的有效性。他们揭示了在缺少每项决策时,性能(以平均意见分数为标准)的下降。
架构

FastSpeech 架构图来自于 FastSpeech 论文
第一图描绘了整个架构,包括编码器、长度调节器和解码器:

Feedforward Transformer (FFT)块在编码器和解码器中都使用。它类似于变换器中的编码器层,但将位置依赖的 FFN 替换为 1D 卷积。一个超参数N代表编码器和解码器中连接顺序的 FFT 块的数量。论文中将 N 设为 6。
长度调节器根据持续时间预测器(第三图)调整输入序列的长度。持续时间预测器是一个简单的网络,如第四图所示。
你应该能够直观地感受到数据流的形式如下:

编码器
编码器接受对应于文本中的字符的整数序列。可以使用字形到音素的转换器将文本转换为音素字符序列,如论文中提到的;但我们将简单地使用字母作为字符单元,并假设模型在训练期间可以学习任何需要的音素表示。因此,对于输入“Say hello!”,编码器接受一个序列 10 个整数,对应于[“S”,”a”,”y”,…,”!”]。
与变换器编码器类似,编码器的目的是为每个字符分配一个丰富的向量表示,该表示考虑了字符本身、其顺序以及与给定文本中其他字符的关系。与变换器类似,它在编码器中保持分配向量的维度用于 Add & Norm 目的。
对于具有n个字符的输入序列,编码器输出[h₁,h₂,…,hₙ],其中每个表示的维度为emb_dim。
长度调节器
长度调节器的目的只是重复给定给每个字符的编码器表示。其理念是文本中每个字符的发音通常对应于多个(或零个)Mel 谱图单元(由解码器生成);而不仅仅是一个声音单元。Mel 谱图单元指的是 Mel 谱图中的一列,为该列对应的时间窗口分配声音的频率分布,并对应于波形中的实际声音。
长度调节器的操作如下:
-
预测每个字符的 Mel 谱图单元的数量。
-
根据该数量重复编码器的表示。
例如,给定输入字符的编码器表示[h₁, h₂, h₃, h₄, h₅] 对应于“knight”。推理时会发生以下情况:
-
长度调节器将每个表示传递给持续时间预测器,后者利用表示(涉及与文本中所有其他字符的关系)来预测一个整数,该整数表示对应字符的 Mel 谱图数量。
-
假设持续时间预测器返回[ 1, 2, 3, 2, 1],那么长度调节器根据预测的持续时间重复每个隐藏状态,得到[h₁, h₂, h₂, h₃, h₃, h₃, h₄, h₄, h₅]。现在我们知道序列的长度(10)就是 Mel 频谱图的长度。
-
它将这个新序列传递给解码器。
注意,在真实环境中,将knight传递给 FastSpeech 并检查持续时间预测器的输出得到了[ 1, 8, 15, 3, 0, 17]。注意,字母k,g,h相比其他字母对 Mel 频谱图的贡献微不足道。事实上,当发音时真正发音的主要是n,i,t。
可控性 通过缩放预测的持续时间可以轻松控制速度。例如,如果[ 1, 8, 15, 3, 0, 17]加倍,则说单词knight将需要两倍的时间(0.5x加速),如果将其乘以一半(然后四舍五入),则说单词将需要一半的时间(2x加速)。还可以仅更改特定字符(例如,空格)对应的持续时间,以控制其发音的持续时间(例如,停顿持续时间)。
训练
在训练中,FastSpeech 不使用持续时间预测器来预测持续时间(它没有被训练),而是使用训练过的 TTS 变压器的注意力矩阵来预测持续时间。
-
在那个变压器中的交叉注意力将每个字符和 Mel 频谱图单元与一个注意力分数关联起来,通过注意力矩阵。
-
因此,在 FastSpeech 的训练中,为了预测字符c的 Mel 频谱图单元(持续时间),它使用 TTS 变压器中的交叉注意力矩阵计算对该字符的最大关注度的 Mel 频谱图单元数量。
-
因为交叉注意力涉及多个注意力矩阵(每个头一个),它在最“对角线”的注意力矩阵上执行此操作。这可能确保字符和 Mel 频谱图单元之间的现实对齐。
-
它使用这个持续时间来训练持续时间预测器(简单的回归任务)。这样在推理期间就不需要这个教师模型了。
解码器
解码器接收这个新表示,并且旨在预测每个 Mel 频谱图单元的频率内容(向量)。这相当于预测文本对应的整个频谱图,可以使用声码器转换为音频。
解码器与编码器具有类似的架构。它只是用线性层替换第一个块(嵌入层)作为最后一个块。这一层使用早期解码器中形成的复杂特征表示为每个 Mel 频谱图单元生成频率向量。
频率n_mels是该层的一个超参数。在论文中设置为80。

一位现代未来感的变压器正在编程计算机,画风为梵高风格 —— 作者使用 Canva 生成
实现

FastSpeech 架构图来自FastSpeech论文
策略
FastSpeech 架构对应于

我们将从实现开始:

和

然后我们可以实现编码器和解码器,因为它们的组合是

现在我们需要的是长度调节器

因为一旦完成最后一步就是

完整实现
为了避免在本文中插入大量代码,我之前准备了一个带注释的笔记本,包含了一个经过组织、优化和适合学习的原始实现版本,供推断使用。你可以在Github或者Google Colab找到它。强烈建议你在开始实现之前了解Transformer 架构中的不同组件。

一架现代未来感喷气式飞机飞向星空,画风为梵高风格 —— 作者使用 Canva 生成
我希望这篇文章中的解释能帮助你更好地理解 FastSpeech 及其架构,并指导你如何从零开始实现它。下次见,再见。
用 Python 实现的机器学习特征编码技术
数据科学工作流中需要考虑的 6 种特征编码技术
·发表于 Towards Data Science ·阅读时间 10 分钟·2023 年 1 月 10 日
--

由 Susan Holt Simpson 提供的照片,来源于 Unsplash
特征编码 将分类变量转换为数值变量,作为特征工程的一部分,使数据与机器学习模型兼容。根据分类变量的类型和其他考虑因素,有多种方法可以执行特征编码。
本文介绍了执行特征编码的一般技巧,详细阐述了 6 种特征编码技术,您可以在数据科学工作流中考虑这些技术,提供了使用时机的评论,最后介绍了如何在 Python 中实现这些技术。

图 1:特征编码技术总结 — 作者提供的图像
6 种特征编码技术的备忘单总结见图 1;请继续阅读以获取每种方法的详细解释和实现。
目录
-
标签 / 有序编码
-
One-Hot / 虚拟编码
-
目标编码
-
计数 / 频率编码
-
二进制 / BaseN 编码
-
哈希编码
特征编码技巧
提示 1:防止数据泄露
考虑到特征编码的目的是将分类变量转换为数值变量,我们可以将“cat”、“dog”和“horse”等类别编码为 0、1、2 等数字。然而,我们必须记住数据泄露的问题,即测试数据中的信息泄露到训练数据中。
编码器必须仅在训练数据上进行拟合,使编码器仅学习训练集中存在的类别,然后用于转换验证/测试数据。 不要在整个数据集上拟合编码器!
问题自然接踵而至:“如果在验证/测试数据中有缺失或新类别怎么办?”,我们可以通过两种方式来处理——由于模型未在这些未见类别上训练,因此可以删除这些未见类别。或者,我们可以将它们编码为-1或其他任意值,以指示这些是未见类别。
提示 2:保存你的编码器
如前所述,编码器在训练数据上进行拟合(使用.fit方法),并用于转换验证/测试数据(使用.transform方法)。最好保存编码器以便稍后转换验证/测试数据。
保存编码器的其他好处包括能够检索类别或将编码的值转换回其类别(使用.inverse_transform方法),如果适用且需要的话。
现在让我们深入探讨特征编码技术!
№1. 标签 / 顺序编码器
标签编码器和顺序编码器将类别直接编码为数值(参见图 2)。
标签编码器用于名义分类变量(没有顺序的类别,例如红色、绿色、蓝色),而顺序编码器用于顺序分类变量(有顺序的类别,例如小、中、大)。

图 2:标签编码示例 — 作者提供的图片
给定基数(类别数量)为n,标签和顺序编码器将值编码从0到n-1。
何时使用标签 / 顺序编码器
-
名义/顺序变量:标签编码器用于名义分类变量,而顺序编码器用于顺序分类变量
-
支持高基数:标签和顺序编码器可以用于类别非常多的情况
-
未见变量:顺序编码器可以用任意值编码验证/测试集中的未见变量(参见下面的代码示例),默认情况下会抛出值错误
标签 / 顺序编码器的缺点
-
未见变量:标签编码器不对验证/测试集中的未见变量进行编码,并会抛出值错误,必须进行特殊的错误处理以避免此错误
-
将类别解释为数值:机器学习模型会将编码的列读取为数值变量,而不是将其解释为不同的类别
对于标签编码器,它一次只能编码一列,每个分类列必须初始化多个标签编码器。
from sklearn.preprocessing import LabelEncoder
# Initialize Label Encoder
encoder = LabelEncoder()
# Fit encoder on training data
data_train["type_encoded"] = encoder.fit_transform(data_train["type"])
# Transform test data
data_test["type_encoded"] = encoder.transform(data_test["type"])
# Retrieve the categories (returns list)
list(encoder.classes_)
# Retrieve original values from encoded values
data_train["type2"] = encoder.inverse_transform(data_train["type_encoded"])
对于顺序编码器,它可以一次编码多个列,并且可以指定类别的顺序。
from sklearn.preprocessing import OrdinalEncoder
# Initialize Ordinal Encoder
encoder = OrdinalEncoder(
categories=[["small", "medium", "large"]],
handle_unknown="use_encoded_value",
unknown_value=-1,
)
data_train["size_encoded"] = encoder.fit_transform(data_train[["size"]])
data_test["size_encoded"] = encoder.transform(data_test[["size"]])
# Retrieve the categories (returns list of lists)
encoder.categories
# Retrieve original values from encoded values
data_train["size2"] = encoder.inverse_transform(data_train[["size_encoded"]])
№2. 独热编码 / 虚拟编码
在 One-Hot 编码和虚拟编码中,分类列被拆分为多个包含 1 和 0 的列(参见图 3)。
这解决了标签编码和序数编码的缺点,即由于编码数据被表示为多个布尔列,列现在被视为分类列。

图 3:One-Hot 编码示例 — 图片由作者提供
给定基数(类别数)为 n,One-Hot 编码器通过创建 n 个额外的列来对数据进行编码。在虚拟编码中,我们可以去掉最后一列,因为它是虚拟变量,这将导致 n-1 列。
何时使用 One-Hot / 虚拟编码器
-
名义变量:One-Hot 编码器用于名义分类变量。
-
低到中等基数:由于为每个类别创建了新的列,因此建议在类别数目较少到中等时使用 One-Hot 编码,以免结果数据过于稀疏。
-
缺失或未见变量:sklearn 包中的 One-Hot 编码器可以通过为缺失的变量创建列、忽略未见的变量列来处理缺失或未见的变量,以保持特征列的一致性,默认情况下会引发值错误。
One-Hot / 虚拟编码器的缺点
-
虚拟变量陷阱:由于编码数据稀疏,可能会导致特征之间高度相关的现象。
-
大数据集:One-hot 编码会增加数据集中的列数,这可能会影响训练速度,并且不适用于基于树的模型。
One-hot 编码可以通过 sklearn 包中的 OneHotEncoder 或使用 pandas 的 get_dummies 方法来完成。
from sklearn.preprocessing import OneHotEncoder
# Initialize One-Hot Encoder
encoder = OneHotEncoder(handle_unknown="ignore")
# Fit encoder on training data (returns a separate DataFrame)
data_ohe = pd.DataFrame(encoder.fit_transform(data_train[["type"]]).toarray())
data_ohe.columns = [col for cols in encoder.categories_ for col in cols]
# Join encoded data with original training data
data_train = pd.concat([data_train, data_ohe], axis=1)
# Transform test data
data_ohe = pd.DataFrame(encoder.transform(data_test[["type"]]).toarray())
data_ohe.columns = [col for cols in encoder.categories_ for col in cols]
data_test = pd.concat([data_test, data_ohe], axis=1)
使用 pandas 内置的 get_dummies 方法时,必须手动处理验证/测试数据中的缺失和未见变量。
data_ohe = pd.get_dummies(data_train["type"])
data_train = pd.concat([data_train, data_ohe], axis=1)
№3. 目标编码
目标编码使用贝叶斯后验概率将分类变量编码为目标变量(数值变量)的均值。应用平滑技术以防止目标泄漏。
与标签编码和序数编码相比,目标编码用解释目标的值对数据进行编码,而不是用任意的数字 0、1、2 等。其他类似的编码方法可以使用信息值(IV)或证据权重(WOE)对分类变量进行编码。

图 4:目标编码示例 — 图片由作者提供
实现目标编码有两种方法
-
均值编码:编码值是目标值的均值,并应用了平滑。
-
Leave-One-Out 编码:编码值是目标值的均值,但不包括我们要预测的数据点。
何时使用目标编码器
-
名义变量:目标编码器用于名义分类变量。
-
支持高基数:目标编码器可以用于类别很多的情况,如果每个类别有多个数据样本则更好
-
未见变量:目标编码器可以通过用目标变量的均值编码未见变量来处理它们
目标编码器的缺点
-
目标泄露:即使进行平滑处理,这也可能导致目标泄露和过拟合。可以使用留一法编码和在目标变量中引入高斯噪声来解决过拟合问题
-
类别分布不均:类别分布在训练和验证/测试数据中可能不同,从而导致类别被编码为不正确或极端的值
目标编码需要使用命令 pip install category_encoders 安装 category_encoders python 包。
import category_encoders as ce
# Target (Mean) Encoding - fit on training data, transform test data
encoder = ce.TargetEncoder(cols="type", smoothing=1.0)
data_train["type_encoded"] = encoder.fit_transform(data_train["type"], data_train["label"])
data_test["type_encoded"] = encoder.transform(data_test["type"], data_test["label"])
# Leave One Out Encoding
encoder = ce.LeaveOneOutEncoder(cols="type")
data_train["type_encoded"] = encoder.fit_transform(data_train["type"], data_train["label"])
data_test["type_encoded"] = encoder.transform(data_test["type"], data_test["label"])
№4. 计数 / 频率编码
计数和频率编码将分类变量分别编码为出现次数和频率(归一化计数)。

图 5: 计数和频率编码示例 — 作者提供的图像
何时使用计数 / 频率编码器
-
名义变量:频率和计数编码器对名义分类变量有效
-
未见变量:频率和计数编码器可以通过用
0值编码未见变量来处理它们
计数 / 频率编码器的缺点
- 相似编码:如果所有类别的计数相似,则编码后的值将相同。
import category_encoders as ce
# Count Encoding - fit on training data, transform test data
encoder = ce.CountEncoder(cols="type")
data_train["type_count_encoded"] = encoder.fit_transform(data_train["type"])
data_test["type_count_encoded"] = encoder.transform(data_test["type"])
# Frequency (normalized count) Encoding
encoder = ce.CountEncoder(cols="type", normalize=True)
data_train["type_frequency_encoded"] = encoder.fit_transform(data_train["type"])
data_test["type_frequency_encoded"] = encoder.transform(data_test["type"])
№5. 二进制 / BaseN 编码
二进制编码将分类变量编码为整数,然后转换为二进制代码。输出类似于独热编码,但创建的列更少。
这解决了独热编码的缺点,其中 n 的基数不会导致 n 列,而是 log2(n) 列。BaseN 编码遵循相同的理念,但使用其他基数值而不是 2,从而得到 logN(n) 列。

图 6: 二进制编码示例 — 作者提供的图像
何时使用二进制编码器
-
名义变量:二进制和 BaseN 编码器用于名义分类变量
-
高基数:二进制和 BaseN 编码在类别数量很高时表现良好
-
缺失或未见变量:二进制和 BaseN 编码器可以通过在所有列中用
0值编码未见变量来处理它们
import category_encoders as ce
# Binary Encoding - fit on training data, transform test data
encoder = ce.BinaryEncoder()
data_encoded = encoder.fit_transform(data_train["type"])
data_train = pd.concat([data_train, data_encoded], axis=1)
data_encoded = encoder.transform(data_test["type"])
data_test = pd.concat([data_test, data_encoded], axis=1)
# BaseN Encoding - fit on training data, transform test data
encoder = ce.BaseNEncoder(base=5)
data_encoded = encoder.fit_transform(data_train["type"])
data_train = pd.concat([data_train, data_encoded], axis=1)
data_encoded = encoder.transform(data_test["type"])
data_test = pd.concat([data_test, data_encoded], axis=1)
№6. 哈希编码
哈希编码将分类变量编码为使用哈希函数生成的不同哈希值。输出类似于独热编码,但可以选择创建的列数。
哈希编码与二进制编码类似,因为它们比独热编码更节省空间,但哈希编码使用哈希函数而不是二进制数。

图 7: 使用 2 列的哈希编码示例 — 作者提供的图像
哈希编码可以将高基数数据编码为固定大小的数组,因为新列的数量是手动指定的。这也类似于TSNE或谱嵌入等降维算法,这些算法通过特征值和其他距离度量构建固定大小的数组。
何时使用哈希编码器
-
名义变量:哈希编码器用于名义分类变量。
-
高基数:哈希编码在类别数量较多时表现良好。
-
缺失或未见变量:哈希编码器可以通过在所有列中用空值编码未见变量来处理这些变量。
哈希编码器的缺点
-
不可逆:哈希函数是单向的,即原始输入可以被哈希为哈希值,但无法从哈希值中恢复原始输入。
-
信息丢失或冲突:如果创建的列数过少,哈希编码可能会导致信息丢失,因为多个不同的输入可能会导致哈希函数产生相同的输出。
哈希编码可以使用来自 sklearn 包的FeatureHasher或来自分类编码器包的HashingEncoder来完成。
from sklearn.feature_extraction import FeatureHasher
# Hash Encoding - fit on training data, transform test data
encoder = FeatureHasher(n_features=2, input_type="string")
data_encoded = pd.DataFrame(encoder.fit_transform(data_train["type"]).toarray())
data_train = pd.concat([data_train, data_encoded], axis=1)
data_encoded = pd.DataFrame(encoder.transform(data_test).toarray())
data_test = pd.concat([data_test, data_encoded], axis=1)
使用category_encoders,
import category_encoders as ce
# Hash Encoding - fit on training data, transform test data
encoder = ce.HashingEncoder(n_components=2)
data_encoded = encoder.fit_transform(data_train["type"])
data_train = pd.concat([data_train, data_encoded], axis=1)
data_encoded = encoder.transform(data_test["type"])
data_test = pd.concat([data_test, data_encoded], axis=1)
希望你对将分类数据编码为数值数据的不同方法有更多了解。选择使用哪种特征编码技术时,重要的是考虑分类数据的类型(名义或序数)、使用的机器学习模型以及每种方法的优缺点。
由于模式或趋势的变化,在测试期间也要考虑缺失或未见的变量,以确保数据科学工作流在生产中不会失败!
有用的链接
-
sklearn文档:scikit-learn.org/stable/modules/classes.html#module-sklearn.preprocessing -
category_encoders文档:contrib.scikit-learn.org/category_encoders/
在 Pandas DataFrame 中使用正则表达式 (RegEx) 进行特征工程
发现如何轻松地使用正则表达式操作你的字符串列
·发表于 Towards Data Science ·阅读时间 13 分钟·2023 年 1 月 6 日
--

图片由 Clark Van Der Beken 提供,来源于 Unsplash
在 Pandas 中操作字符串列是数据工程师最常见的操作之一。大多数情况下,你会进行如拆分列、从列中提取关键信息等操作。这个任务通常被称为特征工程。在这篇文章中,我将带你了解一些你可以利用正则表达式的特征工程技巧。
正则表达式(通常称为regex或regexp)是一种字符序列,指定了文本中的搜索模式。正则表达式功能强大,但对初学者来说也可能令人畏惧。因此,在这篇文章中,我将提供一个温和的正则表达式入门介绍,帮助你入门。
我们的数据集
像往常一样,我喜欢使用的示例数据集是 Titanic 数据集。它包含几个字符串列,非常适合进行特征工程。
数据来源: 本文的数据来源于 https://www.kaggle.com/datasets/tedllh/titanic-train.
许可证—— 数据库内容许可证 (DbCL) v1.0 https://opendatacommons.org/licenses/dbcl/1-0/
让我们加载 Titanic 数据集并检查其各种行和列:
import pandas as pd
import numpy as np
df = pd.read_csv('titanic_train.csv')
df
具体来说,我将使用姓名和船舱列来进行特征工程:

本文中的所有图片均由作者创建。
正则表达式基础
在我们开始对 Titanic 数据集进行特征工程之前,了解正则表达式的基础并学习如何在 Python 中使用它将非常有用。
要在 Python 中使用正则表达式,请导入re包:
import re
在接下来的部分中,我将说明如何使用正则表达式来完成以下操作:
-
如何在字符串中搜索子字符串
-
如何基于特定大小写搜索字符串
-
如何使用锚点来搜索字符串的开始或结束
-
如何搜索具有特定模式的字符串
-
如何搜索具有重复模式的字符串
搜索字符串
假设你想在字符串“Allen, Mr. William Henry”中搜索单词“William”。你可以使用search()函数(来自re包),传入要搜索的单词和要搜索的字符串:
expression = 'William'
name_search = re.search(expression, "Allen, Mr. William Henry")
print(name_search)
search()函数返回一个包含搜索信息和结果的Match对象:
<re.Match object; span=(11, 18), match='William'>
在上面的输出中,你可以看到“William”从字符串的第 11 到 17 的索引处找到了匹配(不包括 18)。如果没有匹配,None将被返回。
你可以使用Match对象的group()函数获取匹配你搜索的单词:
if name_search:
# returns the part of the string where there was a match
print(name_search.group())
# William
我将在后面的部分更详细地解释分组的概念。
你可以使用Match对象的group()函数来指定匹配你搜索的组:
print(name_search.group(0)) # same as group()
上述语句都产生相同的输出:
William
基于不同大小写的字符串搜索
有时候你想根据特定的大小写搜索单词。例如,如果你想查找“William”或“william”,你可以使用[]结构来表示匹配单个字符。例如,以下表达式查找“W”或“w”,后跟“illiam”:
expression = '[Ww]illiam' # either W or w, followed by "illiam"
name_search = re.search(expression, "Allen, Mr. william Henry")
if name_search:
print(name_search.group())
# william
如果你想匹配不考虑大小写的单词,请使用(?i)修饰符:
expression = '(?i)william' # case insensitive search
name_search = re.search(expression, "Allen, Mr. WILLIAM Henry")
if name_search:
print(name_search.group())
# WILLIAM
搜索字符串的开始或结束
在正则表达式中,你可以使用锚点 — $和^,来指定你想匹配的字符串位置。例如,如果你想查找以“Henry”结尾的字符串,可以在表达式中使用$锚点:
expression = 'Henry$' # ends with Henry
name_search = re.search(expression, "Allen, Mr. William Henry")
if name_search:
print(name_search.group())
# Henry
同样,如果你想查找以“Allen”开头的字符串,可以在表达式中使用^锚点:
expression = '^Allen' # starts with Allen
name_search = re.search(expression, "Allen, Mr. William Henry")
if name_search:
print(name_search.group())
# Allen
搜索模式
假设我们有这样的字符串:“Allen, Mr. William Henry”。我们希望从句子中提取标题——“Mr”。根据标题的位置,我们可以看到以下模式:
-
标题前有一个空格
-
标题以句点结束
使用这些,我们可以创建以下表达式:
expression = ' ([A-Za-z]+)\.'
name_search = re.search(expression, "Allen, Mr. William Henry")
print(name_search)
if name_search:
print(name_search.group(0)) # group zero is everything that matches the pattern
print(name_search.group(1)) # group 1 is the first ()
上述表达式需要一些解释。
-
首先,有一个空格表示我们需要找到标题前面的起始空格。
-
接下来,我们有一对括号表示一个组。在这个组内,我们有
[]表示我们想匹配从“A”到“Z”或从“a”到“z”的任何字符,+字符表示我们想匹配一个或多个这样的字符。 -
.(句点)是正则表达式中的一个特殊字符,它匹配任何字符,除了换行符(\n)。所以如果我们想在字符串中匹配一个句点,我们需要使用\(转义字符)来取消.的特殊含义。
以下图形总结了上述描述:

name_search.group(0) 的第一个输出是(注意 Mr. 前的空格):
Mr.
这个输出包含你表达式的匹配项—— ‘ ([A-Za-z]+)\.’。第二个输出是:
Mr
这个输出是你表达式中第一个组的结果(因为只有一个组)—— ([A-Za-z]+)。因此,输出中不包含句点。
由于在你的表达式中只有一个组,如果你尝试访问结果中的第二个组,你会遇到错误:
print(name_search.group(2))
# IndexError: no such group
那么,如果我们在表达式中有多个组呢?
假设我们有一个包含两个称谓的名字的字符串(我知道这个例子有点牵强),我们想在这个字符串中找到这两个称谓。我们可以使用以下表达式来实现:
expression = ' ([A-Za-z]+)\. ([A-Za-z]+)\.'
titles_search = re.search(expression, "Allen, Dr. Mr. William Henry")
if titles_search:
print(titles_search.group(0)) # Dr. Mr.
print(titles_search.group(1)) # Dr
print(titles_search.group(2)) # Mr
以下图形解释了每个组的输出:

基于重复模式的提取
有时候你只是想提取重复模式。例如,假设你有一个包含不同称谓的字符串,如 ‘Mrs. Mr. Ms. Dr.’。要提取这个字符串中的所有称谓,你可以使用 findall() 函数,如下所示:
# returns a list of non-overlapping matches
titles_search = re.findall('([A-Za-z]+)\.', "Mrs. Mr. Ms. Dr.")
print(titles_search)
结果将是一个非重叠匹配的列表。在这种情况下,输出是:
['Mrs', 'Mr', 'Ms', 'Dr']
在 Pandas Series str() 函数中使用正则表达式
既然你已经对正则表达式有了一些了解,现在是时候看看如何在 Pandas 中使用正则表达式了。Pandas Series 的向量化字符串函数(Series.str())中的以下方法允许你使用正则表达式:
-
count() -
extract() -
match() -
contains() -
replace() -
findall() -
split()
让我们通过示例来逐一讲解这些函数。
使用 count() 函数
让我们从 count() 函数开始。假设你想找到 Titanic 数据集中所有名字以“Braund”开头并包含“Owen”的人。你可以使用以下表达式在 count() 函数中实现:
expression = '^Braund.+Owen'
print(df['Name'].str.count(expression))
这里是正则表达式中各种特殊字符的用法:
^字符串或行的开头
.匹配除了换行符之外的任何单个字符。
+1 次或更多次出现
*0 次或更多次出现
你将看到结果作为一个 Series,对于每个与表达式匹配的行,值为 1:
**0 1**
1 0
2 0
3 0
4 0
..
886 0
887 0
888 0
889 0
890 0
Name: Name, Length: 891, dtype: int64
如果你想计算与搜索表达式匹配的总行数,你可以使用 sum() 函数:
print(df['Name'].str.count(expression).sum())
# 1
要显示与表达式匹配的行,你可以将结果传入原始数据框:
display(df[df['Name'].str.count(expression) == 1])
以下图形显示了与表达式匹配的行:

注意表达式‘^Braund.+Owen’中的+。在这种情况下,它意味着 Braund 和 Owen 之间必须至少有一个字符。例如,“Braund,Owen”或“Braund Owen”。如果你想匹配像“BraundOwen”这样的名字,那么你的表达式应该使用*:
expression = '^Braund.*Owen'
如果你只想找到包含“William”一词的所有名字,只需将表达式设置为“William”,如下所示:
expression = 'William'

以下陈述显示了包含多次出现“William”一词的行:
expression = 'William'
print(df['Name'].str.count(expression))
print(df['Name'].str.count(expression).sum())
display(df[df['Name'].str.count(expression) > 1])

使用extract()函数
记得之前我提到过你可以从乘客的名字中提取标题吗?使用Name列,你可以使用extract()函数传入一个正则表达式来提取标题。结果可以作为新列添加到数据框中:
expression = ' ([A-Za-z]+)\.'
df['Title'] = df['Name'].str.extract(expression)
df

你还可以使用extract()函数的另一个有趣的列是Cabin列。如果你检查Cabin列的唯一值:
print(df['Cabin'].unique())
你将看到以下内容:
[nan 'C85' 'C123' 'E46' 'G6' 'C103' 'D56' 'A6' 'C23 C25 C27' 'B78' 'D33'
'B30' 'C52' 'B28' 'C83' 'F33' 'F G73' 'E31' 'A5' 'D10 D12' 'D26' 'C110'
'B58 B60' 'E101' 'F E69' 'D47' 'B86' 'F2' 'C2' 'E33' 'B19' 'A7' 'C49'
'F4' 'A32' 'B4' 'B80' 'A31' 'D36' 'D15' 'C93' 'C78' 'D35' 'C87' 'B77'
'E67' 'B94' 'C125' 'C99' 'C118' 'D7' 'A19' 'B49' 'D' 'C22 C26' 'C106'
'C65' 'E36' 'C54' 'B57 B59 B63 B66' 'C7' 'E34' 'C32' 'B18' 'C124' 'C91'
'E40' 'T' 'C128' 'D37' 'B35' 'E50' 'C82' 'B96 B98' 'E10' 'E44' 'A34'
'C104' 'C111' 'C92' 'E38' 'D21' 'E12' 'E63' 'A14' 'B37' 'C30' 'D20' 'B79'
'E25' 'D46' 'B73' 'C95' 'B38' 'B39' 'B22' 'C86' 'C70' 'A16' 'C101' 'C68'
'A10' 'E68' 'B41' 'A20' 'D19' 'D50' 'D9' 'A23' 'B50' 'A26' 'D48' 'E58'
'C126' 'B71' 'B51 B53 B55' 'D49' 'B5' 'B20' 'F G63' 'C62 C64' 'E24' 'C90'
'C45' 'E8' 'B101' 'D45' 'C46' 'D30' 'E121' 'D11' 'E77' 'F38' 'B3' 'D6'
'B82 B84' 'D17' 'A36' 'B102' 'B69' 'E49' 'C47' 'D28' 'E17' 'A24' 'C50'
'B42' 'C148']
舱房值以字母 A 到 G 开头,后跟一个数字值。提取字母(可能代表泰坦尼克号上的不同甲板或不同的舱房类型)和数字值可能会很有用。你可以按如下方式进行提取:
expression = '^([A-G])'
df['CabinType'] = df['Cabin'].str.extract(expression)
expression = '^.{1}(.*)'
df['CabinNumber'] = df['Cabin'].str.extract(expression)
df

表达式‘^.{1}(.*)’的意思是Cabin的值应该以一个字符开始(^.{1}),然后你想匹配其余部分的零个或多个字符((.*))。这将本质上提取舱房号。
不幸的是,Cabin 列有太多空单元格,使得这一列没有用处。
使用contains()和match()函数
记得之前我们用count()函数找出并统计了所有包含“William”一词的行吗?
expression = 'William'
print(df['Name'].str.count(expression).sum())
display(df[df['Name'].str.count(expression) == 1])
如果你不需要统计行数,只需获取这些行,可以使用contains()函数:
expression = 'William'
df[df['Name'].str.contains(expression)]

上述结果显示了所有名称中包含“William”一词的行。当然,你可以在表达式中使用^锚点来表示你想要所有名称以“William”开头的行:
expression = '^William'
df[df['Name'].str.contains(expression)]

如果你只想找到开头的字符串,你实际上可以使用match()函数。以下代码片段产生的输出与上述相同:
expression = 'William'
df[df['Name'].str.match(expression)] # Determine if string starts with
# a match of a regular expression
使用replace()函数
之前,我们使用乘客的名字提取了标题。让我们检查一下我们提取的不同标题:
print(df['Title'].unique())
# ['Mr' 'Mrs' 'Miss' 'Master' 'Don' 'Rev' 'Dr' 'Mme' 'Ms' 'Major' 'Lady'
# 'Sir' 'Mlle' 'Col' 'Capt' 'Countess' 'Jonkheer']
你会看到一些标题是重复的——像Mlle、Ms和Mme这样的标题可以通过replace()函数缩写为Miss:
df['Title'] = df['Title'].str.replace('Mlle', 'Miss')
df['Title'] = df['Title'].str.replace('Ms', 'Miss')
df['Title'] = df['Title'].str.replace('Mme', 'Miss')
这里是 Mme 和 Mlle 的含义:
Mme(Madame)——法语中对女性的称谓
Mlle (小姐) — 法语的礼貌称谓,传统上用于未婚女性。
如果你想直接替换Name列中的称谓,可以将正则表达式传递给replace()函数。
df['Name'] = df['Name'].str.replace(' (Mlle)\.', ' Miss.', regex=True)
df['Name'] = df['Name'].str.replace(' (Ms)\.', ' Miss.', regex=True)
df['Name'] = df['Name'].str.replace(' (Mme)\.', ' Miss.', regex=True)
Name列中的称谓现在将被更新。
使用findall()函数
如果你想快速找到Name列中与特定模式匹配的所有名称,可以使用findall()函数:
expression = '([A-Za-z]*(ll)[a-z]*)'
df['Name'].str.findall(expression)
上述表达式查找所有包含字母“ll”的名称。findall()函数返回一系列列表值:

列表中的每个项目是一个包含零个或多个元组的列表。每个元组包含两个元素 — 第一个是匹配的词,第二个是组的值。
要打印所有匹配的行,你可以使用列表推导:
[i for i in df['Name'].str.findall(expression) if len(i)>0]
上述表达式产生了以下输出:
[[('Futrelle', 'll')],
[('Allen', 'll'), ('William', 'll')],
[('Bonnell', 'll')],
[('William', 'll')],
[('Williams', 'll')],
[('William', 'll')],
[('Ellen', 'll'), ('Nellie', 'll')],
[('William', 'll')],
[('William', 'll')],
...
如果你想将名称打印得更美观,这里有一个小代码片段:
expression = '([A-Za-z]*(ll)[A-Za-z]*)'
for matches in df['Name'].str.findall(expression):
if len(matches)>0:
for names in matches:
print(names)
print('----')
结果如下:
('Futrelle', 'll')
----
('Allen', 'll')
('William', 'll')
----
('Bonnell', 'll')
----
('William', 'll')
----
('Williams', 'll')
----
('William', 'll')
----
('Ellen', 'll')
('Nellie', 'll')
...
使用split()函数
如果你检查每个乘客的名字,你会看到以下模式和用于查找它们的正则表达式:

更有趣的是这些表达式
-
([\’A-Za-z ()”//.-]+) -
([A-Za-z ()”//.-]*)
这是因为乘客的名字和姓氏可能包含以下特殊字符:
-
— 空格
-
‘— 撇号(例如 — “O’Driscoll”用于姓氏) -
()— 圆括号(例如 — “John Bradley (Florence Briggs Thayer)”用于名字) -
“— 双引号(例如 — “Anna “Annie””用于名字) -
/— 斜杠(例如 — “Carl/Charles Peter”用于名字) -
.— 句点(例如 “Martin (Elizabeth L. Barrett)”用于名字) -
-— 连字符(例如 — “Countess. of (Lucy Noel Martha Dyer-Edwards)”用于名字)
你现在可以使用以下代码片段提取乘客的姓氏、称谓和名字:
expression = '([\'A-Za-z ()"//.-]+), ([A-Za-z]+). ([A-Za-z ()"//.-]*)'
df1 = df['Name'].str.split(expression, expand=True)
display(df1)

列 0 和列 4 可以被删除,因为它们是空的:
df1.drop([0,4], axis=1, inplace=True)
display(df1)

最后,你可以为数据框分配一些有意义的列名:
df1.columns = ['Last Name','Title','First Name']
display(df1)

提取称谓有助于在年龄值缺失的情况下确定乘客的年龄。另一方面,提取姓氏有助于确定两个或更多乘客是否来自同一个家庭。
如果你喜欢阅读我的文章并且这些文章对你的职业/学习有帮助,请考虑成为 Medium 会员。每月费用为 $5,会员可以无限访问 Medium 上的所有文章(包括我的)。如果你通过以下链接注册,我将获得小额佣金(对你没有额外费用)。你的支持意味着我可以投入更多时间撰写像这样的文章。
## 使用我的推荐链接加入 Medium - Wei-Meng Lee
阅读来自 Wei-Meng Lee(以及 Medium 上的成千上万其他作者)的每个故事。你的会员费直接支持…
总结
这次关于在 Python 和 Pandas 中使用正则表达式的旅行相当迅速。由于我使用了相当多的函数,我认为总结我涵盖的函数以及你应该何时使用它们可能会有用:
-
count()— 如果你想计算与模式匹配的行数,请使用此函数 -
extract()— 如果你想从字符串中提取字符串,请使用此函数 -
match()— 如果你想检查一个字符串是否以特定字符串开头,请使用此函数 -
contains()— 如果你想检查一个字符串是否包含特定字符串,请使用此函数 -
replace()— 使用此函数将一个字符串替换为另一个字符串 -
findall()— 使用此函数查找字符串中所有模式的出现 -
split()— 使用此函数将字符串拆分为多列
玩得开心,并告诉我这对你是否有用!
使用 SHAP 进行特征重要性分析:我在 Spotify 学到的(在复仇者的帮助下)
原文:
towardsdatascience.com/feature-importance-analysis-with-shap-i-learned-at-spotify-aacd769831b4
确定主要特征并理解它们如何影响机器学习模型的预测结果
·发表于 Towards Data Science ·13 分钟阅读·2023 年 8 月 23 日
--
这篇文章是我在 Spotify 机器学习论文中学习内容的两篇文章之一。请务必查看第二篇文章,了解我如何成功显著优化我的模型。
## 提高模型准确性:我在 Spotify 机器学习论文中学到的技巧(+代码…
一个技术数据科学家的工具栈,用于改善顽固的机器学习模型
towardsdatascience.com
两年前,我在 Spotify 进行了一项有趣的研究项目,作为我的硕士论文的一部分。我学习了几种有用的机器学习技术,我相信任何数据科学家都应该掌握这些技巧。今天,我在这里为你们讲解其中之一。
在那段时间里,我花了 6 个月尝试建立一个预测模型,并解析其内部工作原理。我的目标是了解是什么让用户对他们的音乐体验感到满意。
这不仅仅是预测用户是否快乐(或不快乐),而是理解导致他们快乐(或不快乐)的潜在因素。
听起来很令人兴奋,对吧?确实如此!我喜欢每一部分,因为我学到了很多关于如何在音乐和用户行为背景下应用机器学习的知识。
(如果你对机器学习在音乐行业的应用感兴趣,那么我强烈推荐你查看这篇有趣的 研究 ,由 Spotify 顶级专家主导。这是必读的!)
机器学习与行为心理学在科技中的应用

作者图片(Midjourney)
在科技领域,像我的研究项目非常常见,因为很多工作都围绕着为用户/客户提供最佳个性化体验。
这通常意味着要深入了解人类心理,而机器学习可以是实现这一不可能目标的强大工具——理解人类。
当我们将机器学习与心理学和行为科学结合时,我们更接近于了解人类行为的完整画面。
怎么做?
我们建立模型来预测人们的反应。
有时我们尝试理解为什么模型会预测那种反应。这就像是在问模型——“嘿,你认为是什么原因导致用户有这样的行为?”
答案在于找到模型中最重要的变量,并理解它们对预测结果的单独影响。
在我的研究中,我做了我们称之为特征重要性的分析,并使用了一个强大的工具叫做SHAP。
在这篇文章中,我将向你解释:
-
什么是 Shapley 值,即 SHAP
-
为什么你需要知道如何使用 SHAP
-
如何使用 SHAP
欢迎来解读机器学习模型!
数据科学家无法逃避的困境
当你处理像 LightGBM 这样的模型时,这就是我在研究中使用的模型,你需要知道你正在处理的是一种特定类型的模型——一个‘黑箱’。
像 DC 漫画中的反派一样,黑箱模型是你应该害怕的东西。
为什么?因为解读这些就像是在打击小丑,而他在你面前咯咯笑。一定要确保他不会轻易泄露秘密。
在许多行业中,法规和流程要求你解释你是如何得出结果的。
在科技领域,知道如何解释你的预测结果对于 获得信任 或 了解模型的内部运作* 是很重要的。后者就是我研究中的情况。*
可解释性与复杂性
在为你的项目选择合适的模型时,你必须考虑许多因素,包括:
-
可解释性——你能解释一下你的模型是如何做到这些的吗?它是如何做出决策的? 例如,线性回归模型就是透明性的定义。你可以轻松追溯每个特征对模型输出的影响。可解释性也常被称为可解释性。
-
复杂性 — 你的模型的架构或表示有多复杂? 神经网络可以预测奇迹,但了解它们如何捕捉特征之间的关系将让你感到头疼。复杂性通常也与高预测能力相关。
白盒还是黑盒?
如你所见,这两个指标完全不一致。复杂的模型通常会牺牲可解释性。因此,你需要根据你的目标小心选择模型类型:

作者提供的图片(Midjourney)
-
白盒模型 — 这些模型易于解释/说明,因为它们清晰地展示了特征之间的关系。
你可以深入了解模型如何做出预测。
-
黑盒模型 — 这些模型可以产生准确的结果,但高度复杂,难以解释。就像黑洞一样,我们完全不知道里面发生了什么。所以我们就叫它们 BB,这样听起来不那么可怕。
黑盒模型通常是我们在技术中用来获得强大而准确结果的模型。这就是为什么理解它们对你作为数据科学家的职业生涯至关重要。
步骤 1 — 良好的预测结果
为了评估哪些因素推动了满意度,我必须:
-
选择一个可以用来描述用户满意度的指标。
-
选择正确的机器学习模型——在我的案例中是 LightGBM*。
-
为我的模型寻找并创建尽可能多的相关特征。
-
选择一个准确性指标来评估我的模型性能——在我的案例中是 ROC AUC 分数。
-
优化我的模型性能,以确保它对该指标的预测相对准确。
*LightGBM 是一个基于决策树的框架,它结合了梯度提升和集成方法来解决复杂问题。集成方法通过每棵树对数据有不同的知识并采用不同的决策方法来提高准确性。梯度提升通过顺序构建弱决策树工作,每棵树从前一个模型遇到困难的地方学习并修正预测值与真实值之间的差异,称为“残差”。这些树传播错误的梯度。

我的机器学习流程是这样的
由于本研究的保密性,我不能分享具体信息,但我会尽力让它听起来不像是你需要解读的东西。让 SHAP 来处理吧。
在我的研究中,我构建了一个 LightGBM 分类器,也就是一个 BB 模型,它输出一个二进制结果:
y = 1 → 用户似乎满意
y = 0 → 不太满意
我的目标是理解为什么听众有这样的感受,而不仅仅是搞清楚他们的感受如何。
起初,我的 ROC AUC 分数在 0.5 左右,这是分类器能得到的最差分数。这意味着算法有 50% 的概率预测是或否。这和人类的随机行为一样。
因此,在花了 2 个月尝试提高我的模型预测效果之后,我终于达到了一个令人满意的结果。
再一次,确保查看下面文章中我是如何做到这一点的!
## 提高模型准确性:我在 Spotify 的机器学习论文中学到的技术 (+Code…
一个技术数据科学家改善顽固 ML 模型的技术栈
towardsdatascience.com
只有在那时,我才能真正开始深入研究我的模型如何通过特征重要性进行预测。所以让我们立即开始吧!
第 2 步 — 特征重要性
SH-复仇者们来拯救
那么让我们倒带。
在我的研究中,我想评估驱动用户满意度的关键因素是什么。由于我的模型是一个 BB,唯一的方法就是进行特征重要性分析。
想象一下——复仇者们一起战斗以拯救世界。你怎么知道谁在拯救中最强大?
(我们都知道是钢铁侠,但让我们假装不知道!)
这就是 SHapley Additive exPlanations,或 SHAP 值,在机器学习的虚拟世界中所做的事情。

图片由 Mateusz Wacławek 提供,来源于 Unsplash
SHAP 值是什么?
SHAP 或 Shapley 值是基于合作博弈理论的。
它们测量每个特征——复仇者们——如何对模型的最终决策——拯救世界或摧毁世界——作出贡献。
在特征重要性的背景下,复仇者们 是单个 特征,而 拯救世界 是模型做出的 预测。
它是如何工作的?
这个方法的核心思想是考虑所有可能的特征组合,并测量当特定特征被包含或排除时模型预测的变化。
通过比较这些不同的组合,Shapley 值为每个特征分配一个值,代表其对预测的贡献。
关键的数学概念是 排列。我们考虑不同的特征排列并计算它们的边际贡献。通过考虑所有可能的排列并平均这些贡献,我们得出每个特征的 Shapley 值。
通过这样做,我们可以了解我们的模型是如何“拯救世界”——或者说它是如何进行预测的!
更具体地说,SHAP 值帮助我们理解:
-
影响预测结果的主要特征是什么?
在我们的案例中,每个复仇者在拯救世界中的力量有多强。想象一下绿巨人的拳头打到你脸上,那东西会把你打飞到多元宇宙中。
-
这些不同特征如何影响预测结果? 即每个复仇者在拯救世界中的力量的影响。我的意思是,看看猩红女巫旺达有多强,但她在:摧毁了半个世界!!!有时候英雄是反派(或者只是一个坏英雄),所以重要的是你要找出他们站在哪一边。
双重解释性
Shapley 值特别强大,因为它们提供了 BB 模型解释性的双重视角:
-
总体上。 通过提供每个特征的预测能力的总体概述,汇总所有用户。
-
局部上。 可以为每个用户计算 SHAP 值,以探索特征如何影响该特定用户的预测结果。
加入复仇者联盟怎么样?来吧!

作者提供的图片
请记住,下面使用的所有数据都是纯粹的示例,以保密此研究的机密性。
1. 编码你的变量
确保你的变量已经编码:
-
序数特征,以便模型保留序数信息
-
分类特征,以便模型可以解释名义数据
所以首先,让我们把我们的变量存储在某个地方。由于研究是保密的,我不能公开我使用的数据,所以让我们使用这些代替:
region = ['APAC', 'EU', 'NORTHAM', 'MENA', 'AFRICA']
user_type = ['premium', 'free']
ordinal_list = ['region', 'user_type']
然后,确保构建编码变量的函数:
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
def var_encoding(X, cols, ordinal_list, encoding):
#Function to encode ordinal variables
if encoding == 'ordinal_ordered':
encoder = OrdinalEncoder(categories=ordinal_list)
encoder.fit(X.loc[:, cols])
X.loc[:, cols] = encoder.transform(X.loc[:, cols])
#Function to encode categorical variables
elif encoding == 'ordinal_unordered':
encoder = OrdinalEncoder()
encoder.fit(X.loc[:, cols])
X.loc[:, cols] = encoder.transform(X.loc[:, cols])
else:
encoder = OneHotEncoder(handle_unknown='ignore')
encoder.fit(df.loc[:, cols])
df.loc[:, cols] = encoder.transform(df.loc[:, cols])
return X
然后将该函数应用于你的变量列表。这意味着你需要创建包含变量名称的字符串的列表,即一个列表用于你的序数变量,一个用于分类变量,还有一个用于数值变量。
def encoding_vars(X, ordinal_cols, ordinal_list, preprocessing_categoricals=False):
#Encode ordinal variables
df = var_encoding(df, ordinal_cols, ordinal_list, 'ordinal_ordered')
#Encode categorical variables
if preprocessing_categoricals:
df = var_encoding(df, categorical_cols, 'ordinal_unordered')
#Else set your categorical variables as 'category' if needed
else:
for cat in categorical_cols:
X[cat] = X[cat].astype('category')
#Rename your variables as such if needed to keep track of the order
#An encoded feature such as region will no longer show female or male, but 0 or 1
df.rename(columns={'user_type': 'free_0_premium_1'},
df.reset_index(drop=True, inplace=True)
return df
2. 准备数据
将你的数据框拆分,以获取你的训练、验证和测试集:
-
训练集 — 用于训练你选择的算法模型,例如 LightGBM
-
验证集 — 用于调整你的参数并优化预测结果
-
测试集 — 用于做出最终预测
在我的研究中,我使用了GroupShuffleSplit。它创建了一个用户定义数量的独立训练-验证分割。它通过将整个组随机分配到训练集或验证集来工作。
def split_df(df, ordinal_cols, ordinal_list, target):
#Splitting train and test
splitter = GroupShuffleSplit(test_size=.13, n_splits=2, random_state=7)
#If you're dealing with many rows belonging to the same id then make sure to split based on the same id
split = splitter.split(df, groups=df['user_id'])
train_inds, test_inds = next(split)
train = df.iloc[train_inds]
test = df.iloc[test_inds]
#Splitting val and test data
splitter2 = GroupShuffleSplit(test_size=.5, n_splits=2, random_state=7)
split = splitter2.split(test, groups=test['user_id'])
val_inds, test_inds = next(split)
val = test.iloc[val_inds]
test = test.iloc[test_inds]
#Defining your X (predictive features) and y (target_feature)
X_train = train.drop(['target_feature'], axis=1)
y_train = train.target_feature
X_val = val.drop(['target_feature'], axis=1)
y_val = val.target_feature
X_test = test.drop(['target_feature'], axis=1)
y_test = test.target_feature
#Encoding the variables in the sets
X_train = encoding_vars(X_train, ordinal_cols, ordinal_list)
X_val = encoding_vars(X_val, ordinal_cols, ordinal_list)
X_test = encoding_vars(X_test, ordinal_cols, ordinal_list)
return X_train, y_train, X_val, y_val, X_test, y_test
然后将该函数应用于你的数据框,以获取你的训练、验证和测试集:
X_train, y_train, X_val, y_val, X_test, y_test = split_df(df, ordinal_cols, ordinal_list, target='target_feature')
3. 训练你的模型
在这里,我假设你已经:
-
清理并预处理你的数据
-
调整了你的参数
-
优化了你的模型
import lightgbm as lgb
#Build your model
clf = lgb.LGBMClassifier(objective='binary', max_depth=-1, random_state=314, metric='roc_auc', n_estimators=5000, num_threads=16, verbose=-1,
**best_hyperparameters)
#Fit your model
clf.fit(X_train, y_train, eval_set=(X_val, y_val), eval_metric='roc_auc')
#Make the predictions
roc_auc_score(y_test, clf.predict(X_test))
4. 使用 SHAP 进行特征重要性分析
现在是你期待已久的时刻!
SHAP 包有许多不同类型的可视化,具体取决于你是否想要全球解释(所有用户汇总)还是局部解释(每个用户)。在我的研究中,我专注于结果的全球解释,因为我不关心特定用户。
我使用了一种特定类型的图表,我认为这种图表足够有效地可视化你的特征影响——汇总图。你将得到:
- 根据预测权重对特征进行 排名。

S1 — 作者提供的图片
- 每个特征的 方向性影响,即特征是否对预测结果产生正向或负向影响。

S2 — 作者提供的图片
注意: 颜色为灰色的特征是名义特征,无法编码,因此如果你关心这些特征,也要对它们进行编码!
🚨记住🚨
确保创建一个特征列表,你想评估其对预测结果的影响:你将把它们插入下面的‘my_features’代码行中。
这个函数将产生汇总图(见 S1,S2)
def shap_viz(model, X_val, title, target):
#Prepare the features you want to evaluate
my_features= [**List of features you want to evaluate**]
my_features_idx = list(np.flatnonzero(X_val.columns.isin(my_features)))
#Define your metric
class_names = ['happy_:)', 'not_happy_:(']
#Load the SHAP explainer, since LightGBM is a decision tree, I used TreeExplainer
explainer = shap.TreeExplainer(model)
#We train the SHAP values on the features based on the validation set
shap_values = explainer.shap_values(X_val)
plt.figure(figsize=(18, 6))
plt.subplot(1,2,1)
shap.summary_plot(shap_values, X_val.values, plot_type='bar', class_names=class_names, feature_names=X_val.columns, max_display=20, show=False, plot_size=None)
plt.title(f'{title} – Weight of the impact of each feature')
plt.subplot(1,2,2)
#shap_values[n] evaluates the impact of features for predicting satisfaction
#n = 0 -- not happy, n = 1 -- happy
shap.summary_plot(shap_values[1], X_val.values, feature_names=X_val.columns, max_display=20, show=False, plot_size=None)
plt.title(f'{title} – Directional impact of each feature')
plt.tight_layout()
#To save a picture of your SHAP viz
plt.savefig(f'{title}.png', bbox_inches='tight', dpi=100)
plt.show()
#To print a list of the top 10 most impactful features
vals = np.abs(shap_values[0]).mean(0)
feature_importance = pd.DataFrame(list(zip(X_val.columns, vals)), columns=['col_name', 'feature_importance_vals'])
feature_importance.sort_values(by=['feature_importance_vals'], ascending=False, inplace=True)
for i in list(feature_importance.col_name.head(10)):
print(i)
return shap_values
然后运行函数来获取 SHAP 可视化,瞧!
your_shap = shap_viz(clf, X_val, title='insert_title', target='target_feature')
你刚刚让小丑泄密;连蝙蝠侠也做不到。
你也知道哪个复仇者拯救了世界!如果你没理解这个笑话,那就对不起你没花时间认真阅读我的文章 😡
(现在我才意识到我把 Marvel 宇宙和 DC 宇宙混合了,这真是第一次)
最重要的部分——解释 SHAP 值
解释这些图形并不简单。花了我一些时间才学会如何正确进行。
在解释这些图形时,你需要考虑图中表示的两个元素:
-
SHAP 值是正还是负? 它在汇总图的左侧还是右侧?
-
特征值是高还是低? 看看图右侧从低到高的彩色条。
让我们举个例子
假设我们在分析用户在 Spotify 上的流媒体时间 streaming_time 与他们的满意度(第 1 步中的 y = 1)之间的关系。使用 SHAP 值:
-
正 SHAP 值(图的右侧)+ 高
streaming_time(红色)→ 更多的流媒体对应更高的满意度。 -
负 SHAP 值(图的左侧)+ 高
streaming_time(红色)→ 更多的流媒体表明对体验的满意度较低。
回顾
-
什么是 Shapley 值? 一种特征重要性方法,用于衡量每个特征对模型最终决策的贡献。
-
什么时候需要使用 SHAP? 当你需要解释模型如何做出预测,或当你想了解在处理难以解释的模型(如黑箱模型)时,哪些关键特征影响了模型输出。
-
如何使用 SHAP? 查看逐步过程。我无法总结这一部分。
我为你准备了礼物 🎁!
订阅我的 通讯 K 的 DataLadder,你将自动获得我的 终极 SQL 备忘单,包含我在大科技公司日常工作中使用的所有查询 + 另一个秘密礼物!
我每周分享在科技领域担任数据科学家的经历,提供实用的技巧、技能和故事,旨在帮助你提升能力——因为没有人真正了解,直到他们亲身体验!
如果你还没有做过这个
不久见!
特征子集选择
原文:
towardsdatascience.com/feature-subset-selection-6de1f05822b0
关于特征选择的教程和推荐策略
·发表于Towards Data Science ·阅读时间 16 分钟·2023 年 3 月 22 日
--

图片来源:gokhan polat 于Unsplash
TL;DR
特征子集选择在监督学习中非常重要,不仅因为它可以产生更好的模型,还因为它提供的洞见。随着机器学习(ML)对可解释性强调的增加,这一点尤为重要。
从业者面临的挑战是特征选择方法的种类繁多。在这篇文章中,我简要概述了这一领域,并提出了一种在大多数情况下有效的策略。该策略使用Wrapper进行最终选择过程,并在必要时使用permutation importance作为初步筛选。
引言
在数据分析中,使用多个特征描述的对象有时可以仅用这些特征的子集来描述而不丢失信息。识别这些特征子集被称为特征选择、变量选择或特征子集选择,是数据分析中的关键过程。本文简要概述了特征子集选择(FSS)方法,并提出了一种在大多数场景下都有效的策略。本文基于arXiv [1]上的一篇教程论文。该论文中提出的方法的 Python 代码可在Github上找到。
特征选择在机器学习中受到广泛关注,因为它可以带来许多好处:
-
更好的分类器: 特征选择的明显好处是它会提高准确性,因为冗余或噪声特征可能会损害准确性。或许令人惊讶的是,准确性的提升可能会相当有限,因为强大的机器学习技术被设计得对噪声具有鲁棒性。
-
洞察: 特征选择最持久的好处可能是它所提供的洞察。识别出有影响力的特征和无用的特征可以让我们对数据有更多了解。
-
数据收集: 在数据有成本的领域(例如医学诊断、制造业),确定用于分类任务的最小特征集可以节省开支。

图 1. 特征选择过程概述:主要有三大类方法,包裹法、过滤法和“嵌入”法,其中特征选择嵌入或是分类算法的副作用。反过来,特征选择可以被视为一种降维策略。图片作者提供。[1]
图 1. 总结了特征选择方法的主要策略。其他特征选择的调查[2,3]将特征选择方法分为三类,我们遵循相同的结构。
-
包裹法是特征选择方法的一种,其中分类器在特征选择过程中被包裹(见图 2)。这种包裹使得分类性能能够驱动特征选择过程。这有一个将特征选择与分类器性能绑定的优点,但这也伴随着显著的计算成本,因为在选择过程中将评估许多分类器变体。
-
过滤法涵盖了使用分类器性能以外的标准来指导特征选择的方法。通常,过滤器提供特征排名,然后选择策略使用该排名来选择特征子集。
-
嵌入法指的是特征选择作为分类器训练过程的副产品产生的任何方法。例如,训练决策树几乎总是会选择一个可用特征的子集来构建树。

图 2. 包裹法与过滤法的对比:(a)在包裹法中,分类器在搜索过程中被包裹。(b)过滤器策略使用独立评估(例如信息增益)来评分特征。图片作者提供。
方法论
在评估特征选择策略的性能时,我们通常想知道这些策略如何推广到未见过的数据上。作为这种评估的代理,我们保留一些数据用于测试(图 3 中的选项(b))。如果我们希望在模型开发过程中评估几种不同的特征选择方案,那么这些方案应在训练数据的范围内进行测试,而交叉验证是实现这一目标的最有效方式(选项(c))。需要记住的是,如果目标是在 ML 系统的部署过程中进行特征选择,那么可以使用所有可用的数据进行特征选择(图 3 中的选项(a))。

图 3. 评估方法论。 (a) 如果不需要估计泛化准确率,则可以使用所有数据进行模型开发的各个方面。 (b) 可以从训练中保留测试数据,以获得泛化准确率的估计。 (c) 交叉验证可以在训练数据中用于特征选择。图像由作者提供。
在继续之前,我们需要介绍将要使用的符号。假设我们有一个数据集D,由n个数据样本组成。D = ⟨X, y⟩,其中y是类别标签。这些示例由一组特征F描述,其中p = |F|,因此有n个对象由p个特征描述。所以X的维度是n×p,y是一个长度为n的向量。目标是识别一个子集S ⊂ F,以捕捉数据集中重要的信息。经过特征选择后,数据集被减少为X′,其维度为n×k,其中k = |S|。
本教程中使用的数据集的某些摘要统计信息如表 1 所示。这些数据集可以在GitHub 仓库中找到。

表 1. 本教程中使用的数据集的摘要统计信息。
包装器
如果 |F| 较小,我们可以在理论上尝试所有可能的特征子集,并选择最佳子集。在这种情况下,‘尝试’ 意味着使用特征子集进行分类器的训练和测试。这将遵循图 3 (c)中展示的协议,其中在训练数据上进行交叉验证以识别良好的特征子集,然后可以在测试数据上进行测试。然而,可能性数量为 2ᵖ,因此穷举搜索很快变得不可能——例如,如果p=20,则需要考虑的可能性超过 100 万。
尽管如此,这就是 Wrapper 特征选择策略的工作方式,重要的修改在于搜索可以是贪婪的或随机的,而不是穷举的。总体思路如图 2(a) 所示,分类器在特征选择过程中被包裹,即使用特征子集训练的分类器用于搜索过程。特征子集将通过对使用数据构建的分类器进行保留测试或交叉验证测试来评估。与 Wrapper 一起使用的主要搜索策略包括:
-
穷举搜索 评估每一个可能的特征子集。如果待考虑的特征数量较少,则可以考虑所有特征组合。然而,如果 p > 20,将会有数百万个特征子集需要考虑,这使得穷举搜索变得不切实际。
-
顺序前向选择 (SFS) 从未选择任何特征开始,并考虑所有包含单一特征的分类器(见图 4 (a))。选择其中表现最佳的分类器,然后评估包括该特征的两个特征组合。这个过程继续进行,在每一步添加获胜的特征,直到无法进一步改进为止。
-
反向消除 (BE) 以与 FSS 相反的方向进行,从选择所有特征开始,考虑删除一个特征的选项,选择其中表现最佳的,然后继续消除特征(见图 4 (b))。同样,当无法进一步改进时,过程终止。
-
随机搜索 方法,如遗传算法或模拟退火,可以很方便地应用于 Wrapper 特征选择。每个状态可以通过特征掩码定义,在此掩码上可以进行交叉和变异操作 [4]。鉴于这种方便的表示方法,虽然评估适应度函数(通过交叉验证测量的分类器准确性)很昂贵,但使用随机搜索进行特征选择仍然非常直接。

图 4. 使用 Wrapper 进行特征子集选择:(a) 顺序前向选择 (b) 反向消除。作者提供的图像。
我们对 Wrapper 方法的探索将集中在 SFS 和 BE 上。这些是贪婪策略,通过探索可能的特征子集的搜索空间(如图 4 所示)。SFS 从一个空集开始,并向前推进,考虑基于单一特征构建的分类器。选择其中最好的一个,然后考虑将这个特征包含在内的特征对。当新增特征无法带来任何改进时,过程可以终止。顾名思义,Backward Elimination 的工作方向正好相反。它从一个完整的特征集开始(图 4(b)),并在每一步去除最不有用的特征。对于 SFS 和 BE,特征子集通过对训练数据的交叉验证来评估。评估是在分割数据集上进行的,所使用的分类器是 k-最近邻 (k-NN),因为它对噪声或冗余特征非常敏感。该 Python notebook 可在 Github 上获得。
正如图 5 所示,两种方法各有优缺点。SFS 倾向于需要较少的计算,因为被评估的模型较小,通常特征较少的分类器训练和测试所需的时间更短。SFS 倾向于选择较少的特征(见图 5(a));这种简洁通常是一种优势。另一方面,由于 BE 从较大的特征集开始,它可以更好地评估特征的组合效果。
SFS 和 BE 的总体结果如图 5(b) 所示。SFS 选择了七个特征,而 BE 选择了 11 个特征。两个特征子集都导致了训练数据上的准确率提高,但只有 SFS 子集在测试数据上表现出了更好的准确率。实际上,BE 的训练和测试准确率之间的差距是过拟合的证据——选择过程过于紧密地拟合了训练数据的特征,从而损害了泛化准确性。确实,过拟合被认为是 Wrapper 基于特征选择的一个问题 [4]。

图 5. 使用 Wrapper 的特征选择示例。(a) 随着 Sequential Forward Selection 的进行,训练数据上的准确率通过交叉验证进行测量。(b) SFS 和 BE 选择的特征子集的准确率估计。SFS 选择了 7 个特征,而 BE 选择了 11 个特征。图片由作者提供。
Filters
图 2 (a) 显示了 Wrapper 策略如何在特征选择过程中使用分类算法。图 2(b) 显示了 Filter 策略如何不使用分类器进行特征选择,而是使用单独的评估函数。Filters 不依赖于分类器的事实是一种双刃剑。这意味着 Filters 可能比 Wrappers 快得多,但所选择的特征可能与分类器的归纳偏差不一致。
过滤器将包含特征评分机制,然后根据这些评分制定选择策略。评分机制需要量化特征对结果的信息量。选择策略可能是:
-
选择排名前k的特征,
-
选择前 50%的特征,
-
选择评分高于最大评分 50%的特征,
-
选择具有非零评分的特征,
-
一种混合的过滤器/包装器策略,通过使用过滤器对特征进行排名,然后评估基于该排名的子集的性能。
我们现在将研究三种过滤器策略——卡方统计量、信息增益和排列特征重要性。
卡方统计量是衡量特征和类别标签之间独立性的指标。如果样本组织成图 6 所示的列联表,那么单元计数与偶然观察到的计数有多大的不同?图 6(a)中的数据表明左右手偏好与性别独立,因为比例是相同的。图 6(b)中的数据表明性别可以预测左右手偏好。

图 6:两个列联表显示了左右手偏好和性别之间的关系。如果左右手偏好是类别,那么在(a)中它与性别特征是独立的,在(b)中则存在依赖关系。图片由作者提供。
卡方统计量允许我们量化这一点:

统计量是对m个单元的总和。对于每个单元,我们考虑观察到的计数Oᵢ与如果特征和类别是独立的情况下预期计数Eᵢ之间的差异。在图 6(a)中,这个差异将为零,因为特征和类别是独立的。在(b)中将存在差异,因此统计量将是正值。一般来说,依赖关系越大,统计量也越大。如果特征值是数值型而不是分类型,则可以对特征值进行分箱,以便构建列联表[5]。
信息增益是一种信息论度量,用于量化特征包含有关类别的信息[6]。在图 6(b)中,通过知道性别我们获得了有关左右手偏好的信息。在二分类场景中,假设正负结果的概率分别是p和q。那么基于这些比例的数据集的熵是:

然后,数据集中任何特征f相对于类别标签的信息增益是:

与卡方统计量类似,信息增益(I-Gain)允许我们对特征进行排序以进行特征选择。这在图 7 中得到了说明。图中显示了根据两种度量对 Segmentation 特征进行排名的结果。Python 笔记本可以在这个 GitHub <链接> 找到。图表显示了按照 I-Gain 分数排序的分数。显然,分数的相关性很高(皮尔逊相关系数为 0.86),因此根据这些分数选择的特征子集应该是相当相似的。

图 7: 根据 I-Gain 和卡方统计量对 Segmentation 特征进行排序。图片作者提供。
当我们查看基于这些排名构建的特征子集的分类器性能时,这确实是这样。在图 8 中,我们可以看到不同 k 选择策略的结果(k = 3, 6, 10, 15)。在 k = 10 时,这两个分数都选择了一个特征子集,该子集在测试集上的准确率与使用完整特征集获得的准确率相同。这里的评估策略符合图 3 中的模式(b),特征评分使用训练数据完成,然后在测试集上进行测试。

图 8: 使用 I-Gain 和卡方统计量选择的 Segmentation 数据集中的前-n 特征的准确率估计。图片作者提供。
排列特征重要性 基于这样的原则:如果我们想了解某物对过程的重要性,我们可以打破它以查看会发生什么。为了评估特征的重要性,我们可以在测试集中对该特征的值进行置换,以查看对整体准确率的影响。如果在这种情况下变量被加噪时错误显著增加,那么该变量是重要的。如果错误没有增加,那么该变量对分类没有用。整体过程如下:
-
在数据上拟合分类器
-
计算基准准确率
-
打乱特征值并重新计算准确率
-
测量与未打乱错误的误差增加
这个过程通常会重复多次(例如 10 次),以获得更稳定的特征重要性分数。排列重要性被用作本文末尾提出的策略中的第一阶段。
在图 9 中,我们看到 Segmentation 数据集上 k-NN 和高斯朴素贝叶斯的排列重要性分数(笔记本 <在这里>)。我们可以看到排名合理地相关,但并不完全相同。这种差异是由于不同的分类器“偏好”不同的特征以及特征选择方法固有的不稳定性。

图 9: 显示 k-NN 和高斯朴素贝叶斯在分割数据集上的特征重要性得分的条形图。图片来源于作者。
嵌入方法
在本节中,我们讨论从分类算法中自然产生的特征选择方法或作为算法的副作用出现的方法。我们将看到,决策树和逻辑回归中,特征选择可以成为模型构建过程的一个整合部分。
决策树: 从数据集中构建决策树通常需要特征选择,因为有些特征不会出现在树中。没有包含在树中的特征实际上被有效地排除。我们在图 10 中展示了在企鹅数据集上的一个例子。

图 10: 企鹅数据集的决策树。虽然数据由四个特征描述,但只有三个特征被选择。图片来源于作者。
在这个例子中,数据集被 50:50 地分成了训练集和测试集。该树在训练数据上进行了训练,并在测试数据上的准确率为 93%(见笔记本 <这里>)。该数据集有四个特征,鳍长,喙长,喙深度 和 体重。从图 10 中的树可以看出,四个特征中的三个被选择,体重未被选择。
该树是使用默认的scikit-learn参数构建的,因此没有进行修剪。在决策树学习中,通常会限制(即修剪)树的大小以防止过拟合。使用修剪来防止过拟合将进一步推动特征选择,因为在较小的树中选择的特征更少。
逻辑回归,Lasso: 在多变量线性模型中,如线性回归或逻辑回归,特征选择可以作为正则化的副作用来实现。在机器学习中,正则化指的是旨在简化模型以防止过拟合的机制。因此,正则化可以导致特征被排除。弹性网和 Lasso 是线性模型的常用正则化方法。这里我们将概述 Lasso 的工作原理[7]并展示 Lasso 的应用示例。从基础开始,多变量回归的工作方式如下:

依赖变量y是输入特征的线性函数;对于每个特征xᵢ,该特征的权重由相应的βᵢ参数决定。对于二分类问题([0,1]标签),我们可以使用逻辑回归,其中依赖变量是结果变量为 1 的对数几率。如果pr是标签为 1 的概率,那么几率是pr/(1-pr)。

因此,逻辑回归提供了一个类别概率:

正则化通过限制模型容量来防止过拟合;这通过限制权重的大小来实现。两种选择是 L₁或 L₂正则化:

因此,β参数在符合这些 L₁或 L₂约束的训练数据上进行拟合。结果表明,当使用 L₁正则化时,较弱的权重会变为零,即这些特征会被剔除。在 Tibshirani [7] 的原始 Lasso 论文中对为什么会发生这种情况有很好的解释。
为了在我们的示例数据集上演示这一点,我们将它们简化为二分类问题,以使整体过程更加透明(笔记本 <在这里>)。然而,使用 Lasso 进行特征选择也适用于多类别问题。结果见图 11 和图 12。由于数据集被减少到仅有两个类别(分割的水泥和窗户,以及企鹅的阿德利和凤头企鹅),因此准确率比多类别场景更高。


图 11: Lasso 对企鹅(上图)分割数据集的影响。Lasso 减少了β参数的大小;一些参数被减少到零。显示了 Lasso 在默认正则化(C=1)和较温和正则化(C=10)下的结果。图片由作者提供。
Lasso 特征减少的程度由正则化参数 C 控制。包括了两个正则化水平的结果,C=10 和 C=1。C=10 会导致较少的正则化,从而保留更多的特征。在这两种情况下,默认正则化导致特征减少过多,泛化准确率降低。对于企鹅数据集,仅保留了两个特征,而在分割数据集中保留了三个特征(见图 15)。较温和的正则化保留了更多特征,结果没有损失泛化准确率。


图 12: Lasso 对训练和测试准确率的影响。显示了 Lasso 在默认正则化(C=1)和较温和正则化(C=10)下的结果。图片由作者提供。
提议策略
正如开头所述,我们提出的策略将使用 Wrapper 进行最终选择过程,并在必要时使用 permutation importance 作为初始筛选器(GitHub <link>)。我们在这里呈现的示例使用了来自 UCI 数据库的 Ionosphere 数据集。该数据集有 34 个特征,因此有超过 170 亿个可能的特征子集。即使使用贪心搜索策略(例如 SFS 或 BE),对整个特征集应用 Wrapper 也会非常计算密集。因此,我们使用图 13 所示的两步策略;我们使用排列重要性将 34 个特征减少到一个候选集 18 个特征,然后使用 Wrapper。

图 13. 两阶段特征选择过程。绿色中的数字表示在实例中的特征数量。图像由作者提供。
本示例中使用的分类器是 k-NN,我们保留了 50% 的数据用于最终测试。排列特征重要性阶段的评分如图 14 所示。我们丢弃了没有正评分的特征,留下了 18 个特征用于 Wrapper 阶段。

图 14. 来源于排列重要性阶段的特征重要性评分。理论上这些评分应该始终为零或更大;在实际操作中,一些评分为负,表示特征排列后的结果实际上是偶然改进的。可以安全地丢弃这些特征。图像由作者提供。
然后我们仅使用训练数据对这 18 个特征进行 Wrapper 搜索。我们使用前面描述的向后消除方法,它选择了 16 个特征。在图中我们看到三个特征集的准确性估计。训练数据上的估计使用交叉验证,保留估计显示了从特征选择过程中保留的 50% 数据的准确性。

图 15. 三个特征集的准确性评分。训练集(用于特征选择的集)上的准确性通过交叉验证来测量。测试集上的准确性是从特征选择过程中保留的 50% 数据的简单保留度量。图像由作者提供。
我们看到两个特征选择阶段在训练数据上都提高了准确性,但对于 Wrapper 阶段,在测试数据上没有进一步改善。这与其他研究一致,这些研究表明深入的特征选择工作可能导致过拟合。
结论
这就总结了我们对最先进特征子集选择方法的概述。面临的挑战之一是有这么多可供选择的替代方法;我们提出的策略是:
-
如果你有大量的训练数据,上述的两阶段策略将是值得的。
-
如果训练数据稀缺且存在过拟合风险,可以在置换重要性步骤后停止特征选择。
本文所覆盖的方法的 Python 实现(使用 scikit-learn)可以在这个GitHub 代码库中找到。
参考文献
-
Cunningham, P., Kathirgamanathan, B., & Delany, S. J. (2021). 特征选择教程及 Python 示例。arXiv 预印本 arXiv:2106.06437。
-
Isabelle Guyon 和 André Elisseeff. (2003) 变量和特征选择导论。机器学习研究杂志,2003 年。
-
Luis Carlos Molina Félix, Luis Antonio Belanche Muñoz, 和 M Àngela Nebot Castells (2002) 特征选择算法:综述与实验评估。在 2002 IEEE 国际数据挖掘会议(ICDM 2002) 第 306–313 页。
-
John Loughrey 和 Pádraig Cunningham. 基于包装的特征子集选择中的过拟合:越努力效果越差。在国际人工智能创新技术与应用会议,第 33–43 页。Springer,2004 年。
-
Xin Jin 等人. “用于癌症分类的机器学习技术和卡方特征选择,基于 SAGE 基因表达谱”。在:生物医学应用数据挖掘国际研讨会。Springer,2006 年,第 106–115 页。
-
John D Kelleher, Brian Mac Namee, 和 Aoife D’arcy. 预测数据分析的机器学习基础:算法、实例和案例研究。MIT 出版社,2020 年。
-
Robert Tibshirani. “回归收缩与套索选择”。在:皇家统计学会学报:B 系列(方法论) 58.1 (1996),第 267–288 页。
特征变换:PCA 和 LDA 教程
原文:
towardsdatascience.com/feature-transformations-a-tutorial-on-pca-and-lda-1ac160088092
使用 PCA 等方法减少数据集的维度
·发表于 Towards Data Science ·阅读时间 7 分钟·2023 年 7 月 14 日
--

照片由 Nicole Cagnina 提供,来源于 Unsplash
介绍
在处理高维数据时,通常使用主成分分析(PCA)等方法来降低数据的维度。这将数据转换为不同的(较低维度的)特征集。这与特征子集选择形成对比,后者选择原始特征的子集(参见[1]了解特征选择的教程)。
PCA 是一种将数据线性变换到较低维空间的方法。本文首先解释了什么是线性变换。然后,我们通过 Python 示例展示了 PCA 的工作原理。文章最后描述了线性判别分析(LDA),这是一种监督的线性变换方法。文中介绍的方法的 Python 代码可在 GitHub 上找到。
线性变换
想象一下,假设在假期之后,比尔欠玛丽£5 和$15,需要用欧元(€)支付。汇率为:£1 = €1.15 和 $1 = €0.93。因此,债务的€金额为:

在这里,我们将两维(£,$)的债务转换为一维(€)。图 1 中展示了三个例子,包括原始债务(£5,$15)和另外两个债务(£15,$20)以及(£20,$35)。绿色的点是原始债务,红色的点是投影到单一维度的债务。红色的线代表这个新维度。

图 1。 一个将英镑和美元债务转换为欧元的线性变换的示例。图片作者提供。
在图的左侧,我们可以看到如何将其表示为矩阵乘法。原始数据集是一个 3 x 2 的矩阵(3 个样本,2 个特征),汇率形成一个 1D 矩阵的两个分量,输出是一个 1D 矩阵的 3 个分量。汇率矩阵是变换;如果汇率改变,则变换也会改变。
我们可以使用下面的代码在 Python 中执行矩阵乘法。这些矩阵被表示为 numpy 数组;最后一行调用dot方法在cur矩阵上执行矩阵乘法(点积)。这将返回矩阵[19.7, 35.85, 55.55]。
这种数据变换的一般格式如图 2 所示。Y 是原始数据集(n 个样本,d 个特征);通过乘以具有 d x k 维的变换矩阵 P,将其减少到 k 个特征的 X’ 中。

图 2。 如果我们有一个由 n 个样本和 d 个特征描述的数据集 Y,则可以通过乘以维度为 d x k 的变换矩阵 P 将其减少到 k 个特征(*X’)。图片作者提供。
主成分分析
一般来说,变换矩阵 P 决定了变换。在图 1 的示例中,变换的细节由汇率决定,这些汇率是给定的。如果我们希望使用 PCA 来减少数据集的维度,我们如何决定变换的性质?嗯,PCA 由三个原则驱动:
-
选择一种保留数据扩展的变换,即选择保留数据点之间距离的维度。
-
选择彼此正交的维度(无冗余)。
-
选择 k 个维度,以捕获数据中大部分的方差(例如 90%)。
这些原理在图 3 中得到了说明。我们有一个由两个特征描述的个体数据集:腰围和体重。这些特征彼此相关,因此目标是将数据投影到不同的、不相关的维度中,而不丢失数据中的‘扩展’。这些新维度就是主成分,赋予了 PCA 其名称。作为投影的替代思路,你可以将其视为将数据云旋转以对齐图 3 中的红色坐标轴。无论哪种方式,新轴都是 PC1,第一个主成分和 PC2,它与 PC1 垂直。如果认为 PC1 捕获了足够的数据变化,则可以舍弃 PC2。

图 3. 一个二维数据集,其中的特征是体重和腰围。第一个主成分(PC1)应在捕获数据方差最多的方向上。PC2 应与 PC1 正交,以便它们独立。图像由作者提供。
对数据矩阵Y执行 PCA 的步骤如图 2 所示:
-
计算Y的列的均值和标准差。
-
从Y的每一行中减去列均值,并除以标准差以创建标准化中心矩阵 Z。
-
计算协方差矩阵C = 1/(n-1) Zᵀ Z,其中Zᵀ是Z的转置。
-
计算协方差矩阵C的特征向量和特征值。
-
检查特征值的降序以确定保留的维度数k — 这是主成分的数量。
-
顶部的k特征向量构成了转换矩阵P的列,该矩阵的维度为(p × k)。
-
数据通过X′ = ZP进行转换,其中X′的维度为(n × k)。
在以下示例中,我们将使用在Github上共享的哈利·波特 TT 数据集。数据的格式如下所示。共有 22 行和五个描述性特征。我们将使用 PCA 将其压缩为两个维度。
下面显示了执行此操作的 Python 代码。 Y_df是包含数据集的 Pandas 数据框。特征值和特征向量(ev,evec)在第 8 行计算。如果我们检查特征值,它们告诉我们每个主成分(PC)捕获的方差量;[49%,31%,11%,5%,4%]。前两个主成分将保留数据中 80%(49% + 31%)的方差,因此我们决定使用两个主成分。 X_dash包含投影到二维的这些数据。二维投影的数据如图 4 所示。可以说,PC1 维度代表能力/无能力,而 PC2 维度代表优良程度。弗雷德和乔治·韦斯利(双胞胎)被绘制在相同的点,因为他们在原始数据集中具有完全相同的特征值。

图 4. 哈利·波特数据集投影到二维(2 个主成分)。图像由作者提供。
如果我们使用 scikit learn 中的 PCA 实现,可以用三行代码完成这项工作:
在第 4 行中数据被标准化,第 5 行中设置了 PCA 对象,第 6 行中完成了数据转换。再次说明,这些代码可以在Github的笔记本中找到。
线性判别分析
应该明确的是,PCA 本质上是一种无监督的机器学习技术,因为它不考虑任何类别标签。实际上,在监督学习的背景下,PCA 不一定会有效——考虑到其重点是保持数据的分布而不考虑类别标签,这一点并不令人惊讶。在图 4 中,我们可以看到 PCA 在企鹅数据集 [2] 上的表现。这是一个由四个特征描述的三类数据集(也可在GitHub上找到)。

图 5. 这些散点图比较了 PCA 和 LDA 在企鹅数据集上的表现。PCA 由于未考虑类别标签信息,其表现不如 LDA。图片由作者提供。
在图 5 右侧,我们可以看到线性判别分析 (LDA) 在相同数据集上的表现。LDA 考虑了类别标签,并寻求一个最大化类别之间分离的投影。目标是揭示一个能够最大化类别间分离并最小化类别内分离的变换。这些可以在两个矩阵中计算:Sᵦ 代表类别间分离,S𝓌 代表类别内分离:


其中 n𝒸 是类别 c 中对象的数量,μ 是所有示例的均值,μ𝒸 是类别 c 中所有示例的均值:

这些求和中的组件 μ, μ𝒸, xⱼ 是维度为 p 的向量,因此 Sᵦ 和 S𝓌 是维度为 p × p 的矩阵。最大化类别间分离和最小化类别内分离的目标可以结合成一个称为 Fisher 判别准则的单一最大化:

这种表述以及找到最佳 Wˡᵈᵃ 矩阵的任务在 [3] 中有更详细的讨论。现在我们只需认识到 Wˡᵈᵃ 的作用与 PCA 中的 P 矩阵相同。它的维度是 p × k,将数据投影到一个 k 维空间中,以最大化类别间分离并最小化类别内分离。目标由两个 S 矩阵表示。我们可以看到图 5 右侧,它做得相当不错。
虽然 LDA 的内部工作机制可能看起来复杂,但在 scikit-learn 中它非常简单,如下面的代码块所示。这与上述的 PCA 代码非常相似;主要区别在于在拟合 LDA 时会考虑 y 目标变量;而 PCA 并非如此。
结论
本文的目标是解释数据转换的基本原理,展示 PCA 和 LDA 在 scikit-learn 中的工作方式,并展示这些方法的一些实际示例。
这些示例的代码和数据可以在 GitHub 上找到。关于这些方法的更深入处理在此 arXiv 报告 [3] 中进行了介绍。
参考文献
[1] P. Cunningham, “特征子集选择”,Towards Data Science,2022 年,[在线],towardsdatascience.com/feature-subset-selection-6de1f05822b0
[2] A.M. Horst, A.P. Hill, K.B. Gorman KB, palmerpenguins: Palmer Archipelago(南极洲)企鹅数据,2020 年,doi:10.5281/zenodo.3960218,R 包版本 0.1.0,allisonhorst.github.io/palmerpenguins/
[3] P. Cunningham, B. Kathirgamanathan, & S.J. Delany, 特征选择教程(带有 Python 示例),2021 年,arXiv 预印本 arXiv:2106.06437
二月刊:聊聊聊天机器人
月刊
(是的,你知道的那一个)
·
关注 发表在 数据科学之道 ·3 分钟阅读·2023 年 2 月 2 日
--
自 OpenAI 推出 ChatGPT 以来的两个月中,这里有一件事我们可能都能达成共识:这真的是很多。
自从聊天机器人出现以来,我们 TDS 的情况如何?感谢你的提问!想象一下过山车,设在镜子迷宫中,灯光闪烁:这大致可以给你一个从编辑角度看待的体验。
我们从未见过单一话题能以如此强烈的程度吸引我们社区的集体关注,看到读者和作者因这一工具的潜力而充满活力令人兴奋。看到被炒作的文章(我们没有发布)以及我们审核队列被 AI 生成的文章堵塞(据我们所知,我们也没有发布那些文章),则令人沮丧。
就像我们在Medium的朋友们一样,他们最近反思了他们关于 ChatGPT 及类似工具的政策,我们也花时间讨论了聊天机器人的影响。结果,我们更新了我们的指南,并明确表示我们只接受完全由人类撰写的文章。
在这段时间里,幸运的是,我们的作者继续利用他们的专业知识分享有关这一新兴技术的智能、审慎和及时的文章。为了庆祝他们出色的写作并帮助读者以更全面、更有见地的视角看待 ChatGPT,我们挑选了一些最有力的贡献。(如果你发现了其他你喜欢的文章,欢迎在评论中分享链接。)
一如既往,感谢你花时间阅读我们发布的内容。如果你想以更直接的方式支持我们(同时访问我们的整个档案),请考虑成为 Medium 会员。
TDS 编辑亮点
-
ChatGPT 是如何工作的:机器人背后的模型(2023 年 1 月,8 分钟)
如果你需要对 ChatGPT 的工作方法和技术基础设施有一个清晰易懂的介绍,Molly Ruby的解释是必读之作。
-
ChatGPT 的碳足迹(2022 年 12 月,7 分钟)
在 ChatGPT 发布后的几天内,Kasper Groes Albin Ludvigsen试图回答关于该模型环境影响的关键问题。
-
并非全是彩虹和阳光:ChatGPT 的黑暗面(2023 年 1 月,9 分钟)
除了大型语言模型的有趣和创意方面,Mary Reagan 博士认为我们应当至少同样关注潜在的风险和伦理问题。
-
微软与谷歌:语言模型会超越搜索引擎吗? (2023 年 1 月,12 分钟)
ChatGPT 将如何改变我们的在线习惯和工作流程?展望(可能不久的)未来,Alberto Romero 反思了大型语言模型与传统网页搜索之间的激烈竞争。
-
ChatGPT 能写出比数据分析师更好的 SQL 吗? (2023 年 1 月,6 分钟)
Marie Truong的文章(我们 1 月份阅读量最高的文章!)探讨了一个显然在许多数据专业人士心中的问题:机器是否最终会让我们变得多余?
热门文章
如果你错过了,以下是上个月在 TDS 上最受欢迎的一些文章。
-
区分高级开发者与初级开发者的 5 个 Python 技巧 作者:Tomer Gabay
-
2023 年从数据分析师跃升为数据科学家的途径 作者:Mary Newhauser
-
最常用的 7 种特征工程技术 作者:Dominik Polzer
-
2023 年现代数据堆栈的未来 作者:Prukalpa 和 Christine Garcia
-
Pandas 与 Polars:语法与速度比较 作者:Leonie Monigatti
-
初学者应避免的 4 个常见 Python 错误 作者:Murtaza Ali
由于一批令人印象深刻的新 TDS 作者,我们今年有了一个非常强劲的开局——请与我们一起欢迎Jonah Breslow、Taylor Wagner、Ertuğrul Demir、Akshay Dagar、Will Badr、Frauke Albrecht、Chad Isenberg、Stefany Goradia、Aleksandra Ma、Boris Ruf、Galen Okazaki、Alle Sravani、Herman Michaels、Jazmia Henry、Trevor Coppins、Martin Leitner、Jacky Kaub、Shenyang(Andy) Huang、Christoph Möhl、Katherine Munro、Sebastian Charmot、Miriam Santos、Tin Nguyen、Amogh Vaishampayan、Murat Unal、Ebrahim Pichka、John Morrow、Oded Mousai、Christopher McDonald、Bárbara Barbosa、Jonathan Bogerd、Jonas Dieckmann、Alejandro Correa Bahnsen、Pushpak Pujari、Matthew Harris、Tomáš Neubauer、Maggie Hays、Vino Duraisamy、István Módos、Lukáš Zahradník、Mary Reagan PhD、Christian Koch、和Bale Chen等。如果你有有趣的项目或想法要与我们分享,我们非常乐意听取你的意见!
下月见。
2023 年女子世界杯通过 Plotly 可视化
原文:
towardsdatascience.com/fifa-womens-world-cup-2023-visualized-with-plotly-a7277edf6278
数据科学家的五个图表评论
·发布于 Towards Data Science ·阅读时间 5 分钟 ·2023 年 8 月 22 日
--

图片由 Your Lifestyle Business 提供,来源于 Unsplash
2023 年 7 月和 8 月,澳大利亚和新西兰联合举办了 FIFA 女子世界杯。共有 32 个国家队参赛,西班牙首次赢得冠军。重大体育赛事总是产生大量数据,我借此机会学习如何使用 Plotly。
Plotly 是一个开源图表库,用于创建交互式图表。它可以离线或在线使用,并与多种编程语言集成。我使用 Python,因为我对它最为熟悉,并创建静态图表。代码可以在 GitHub 上找到。
在五个数据故事中,我们将尝试不同的 Plotly 功能,并展示世界杯历史和今年比赛的一些有趣事实:
-
历史世界杯参赛情况
-
球员年龄与比赛表现
-
球员的俱乐部
-
送出最多活跃球员的国家
-
男子和女子世界杯奖金对比
1: 女子世界杯参赛情况
历史上,许多国家曾禁止女性踢足球。德国足球协会(DFB)在 1955 年决定,“缺乏优雅可能会损害女性脆弱的身体和灵魂,并会对道德和社会造成攻击”。直到 1970 年,禁令才被解除。今天,女性在全球范围内踢足球,只有少数国家仍有限制。
我们的第一个数据故事以条形图形式展示了各国女子足球队在世界杯中的参赛情况。自 1991 年首次举办以来,这项赛事已举行九次。来自每个大陆的国家都有参与。其中一些国家每届世界杯都参赛,而另一些仅参赛过一次。我很惊讶地发现朝鲜竟然参加了四次!

各国世界杯参赛情况。数据:维基百科。图片:作者。
这是生成静态条形图的代码。为了容纳众多条目,图例从默认的右侧位置移动到图表顶部。
2: 球队年龄和表现
国家队名单在维基百科上是公开的。我们展示了国家队球员的年龄分布图。颜色代码表示球队在比赛中的排名。据我所见,并没有明显的模式——各类中位年龄的球队进入淘汰赛阶段的机会相似。海地和赞比亚以非常年轻的阵容脱颖而出。

国家队球员的年龄。数据:维基百科。图片:作者。
为了生成这个图表,我使用了 plotly 的 box 函数,并在分类轴上添加了网格线。
3: 球员的职业生涯分布在哪里?
名单还提供了球员所在俱乐部的信息,即他们不在国家队时所效力的俱乐部。我们汇总了每个俱乐部的球员数量,并仅展示前 30 名。诸如巴萨、切尔西、巴黎圣日耳曼和阿森纳等知名的欧洲和英国俱乐部主导了这个名单。顶尖的亚洲俱乐部包括仁川现代钢铁和武汉江汉大学,这些俱乐部与其本国大量球员有关联。

世界杯中前 30 名足球俱乐部。数据:维基百科。图片:作者。
这是一个标准的折线图,由于只有一条线,因此图例被关闭。
4: 球员在全球的分布情况如何?
我们按照俱乐部所在国家汇总球员数据。少于 10 名球员的国家被归为“其他”。许多世界杯阵容成员效力于英格兰和美国的俱乐部,这些国家的女子足球代表性强。世界杯冠军西班牙也吸引了许多顶级球员。

世界杯阵容成员俱乐部。数据:维基百科。图片:作者。
在这里,我创建了一个堆叠条形图,其中分类轴按每个类别的总数排序。颜色方案是定性的 G10 方案,这是 plotly 的默认颜色方案之一。
5: 奖金
最后,让我们来看看参加并赢得世界杯赛事是否值得。我很惊讶地发现,直到 2007 年女子世界杯之前,根本没有奖金发放。在 2023 年澳大利亚/新西兰比赛中,总奖金为 1 亿美元。相比于 2022 年卡塔尔男子世界杯分发的 10 亿美元,这个数字仍然相对较低。如果我们将这两个曲线绘制在对数刻度上,我们可以看到女子奖金的增长速度似乎超过了男子奖金的增长速度。所以未来的某个时候,可能会实现同工同酬。

女子和男子世界杯的奖金。数据来源:FIFA / 维基百科。图片:作者。
对于这个图表,我创建了两个子图,每个子图包含两条线。y 轴刻度设置为对数刻度。女子早期比赛中零奖金的情况无法在此刻度上显示。

正在等待同工同酬的女性。照片由Magnet.me拍摄,来源于Unsplash
反思
我发现 Plotly 使用起来既方便又直观。所有标准统计图都可用。我需要一些时间来习惯图表的自定义,这与我在 matplotlib 中习惯的有所不同。互动功能非常适合数据探索,鼓励读者下载完整的笔记本并尝试使用。对于博客文章,我发现嵌入互动图表有些挑战,因此我选择了静态图表。
参考资料
-
GitHub 上的笔记本:
github.com/crlna16/medium_notebooks/blob/384a0d07e0aa65e35e7086bf8fe67c1d8e5e679e/plotly/fifa23.ipynb -
所有历史和队伍数据均取自维基百科:
en.wikipedia.org/wiki/2023_FIFA_Women's_World_Cup_squads -
FIFA 男子世界杯奖金:
www.totalsportal.com/football/fifa-world-cup-prize-money/ -
towardsdatascience.com/how-to-create-a-plotly-visualization-and-embed-it-on-websites-517c1a78568b
解析数据中最不寻常的片段
如何利用常识和机器学习找到关注的片段
·
关注 发表在 Towards Data Science · 13 分钟阅读 · 2023 年 7 月 13 日
--
照片由 Klara Kulikova 提供,来源于 Unsplash
分析师经常需要找到“有趣”的片段——那些可以集中精力以获得最大潜在影响的片段。例如,确定哪些客户群体对客户流失的影响最大可能是有趣的。或者你也可以尝试了解哪些订单类型会影响客户支持工作量和公司的收入。
当然,我们可以查看图表来找到这些突出的特征。但这可能是耗时的,因为我们通常跟踪几十个甚至几百个客户的特征。此外,我们需要查看不同因素的组合,这可能会导致组合爆炸。对于这样的任务,使用框架将非常有帮助,因为它可以节省你几个小时的分析时间。
在本文中,我想与大家分享两种找出最突出数据切片的方法:
-
基于常识和基础数学,
-
基于机器学习——我们在 Wise 的数据科学团队开源了一个库Wise Pizza,它能用三行代码给你答案。
示例:银行客户的流失
你可以在GitHub上找到这个示例的完整代码。
我们将以银行客户流失数据为例。这个数据集可以在 Kaggle 上找到,许可证为CC0: 公众领域。
我们将尝试使用不同的方法找到对流失影响最大的细分群体:图表、常识和机器学习。但让我们从数据预处理开始。

数据集列出了客户及其特征:信用评分、居住国家、年龄和性别、客户账户余额等。此外,我们还知道每个客户是否流失——参数exited。
我们的主要目标是找出对流失客户数量影响最大的客户细分群体。在此之后,我们可以尝试了解这些用户群体特有的问题。如果我们专注于解决这些细分群体的问题,我们将对流失客户数量产生最大的影响。
为了简化计算和解释,我们将把细分群体定义为一组过滤条件,例如,gender = Male 或 gender = Male, country = United Kingdom。
我们将处理离散特征,因此我们需要将连续指标(如age或balance)进行转换。为此,我们可以查看分布并定义合适的桶。例如,让我们看看年龄。

图表由作者提供
分桶连续特征的代码示例
def get_age_group(a):
if a < 25:
return '18 - 25'
if a < 35:
return '25 - 34'
if a < 45:
return '35 - 44'
if a < 55:
return '45 - 54'
if a < 65:
return '55 - 64'
return '65+'
raw_df['age_group'] = raw_df.age.map(get_age_group)
找到数据中有趣的细分群体的最直接方法是查看可视化。我们可以通过柱状图或热图查看按一个或两个维度分割的流失率。
让我们看看年龄与流失之间的关系。35 岁以下的客户流失率较低——不到 10%。而 45 到 64 岁的客户则保留率最差——几乎一半的客户已经流失。

图表由作者提供
我们可以再添加一个参数(gender)以尝试发现更复杂的关系。条形图无法显示二维关系,因此我们可以切换到热图。
女性的流失率在所有年龄组中都较高,因此性别是一个重要因素。

作者的图表
这样的可视化可以非常有洞察力,但这种方法有几个问题:
-
我们没有考虑到细分市场的规模,
-
查看所有可能的特征组合可能会非常耗时,
-
在一个图表中可视化超过两个维度是有挑战性的。
所以,让我们转向更结构化的方法,这将帮助我们得到一个优先级排序的有趣细分市场列表,并估算其效果。
常识性方法
假设
我们如何计算修复特定细分市场问题的潜在影响?我们可以将其与流失率较低的“理想”情景进行比较。

你可能会想我们如何估算流失率的基准。有几种方法可以做到这一点:
-
来自市场的基准: 你可以尝试寻找你所在领域中产品的典型流失率水平,
-
你产品中的高表现细分市场: 通常,你会有一些表现稍好的细分市场(例如,你可以按国家或平台进行划分),你可以将它们作为基准,
-
平均值: 最保守的方法是查看全球平均值,并估算达到所有细分市场的平均流失率的潜在效果。
我们还是安全起见,使用我们数据集中流失率的平均值作为基准——20.37%。
列出所有可能的细分市场
下一步是构建所有可能的细分市场。我们的数据集有十个维度,每个维度有 3–6 个独特值。总的组合数大约是 120 万。即使我们只有少数几个维度和不同的值,这看起来计算成本也很高。在实际任务中,你通常会有几十个特征和独特值。
我们确实需要考虑一些性能优化。否则,我们可能需要花费数小时等待结果。以下是减少计算的几个小贴士:
-
首先,我们不需要构建所有可能的组合。将深度限制在 4–6 层是合理的。你的产品团队需要关注由 42 个不同过滤器定义的用户细分市场的可能性非常低。
-
其次,我们可以定义我们感兴趣的效果规模。假设我们希望将保留率提高至少 1 个百分点。这意味着我们对少于 1%用户的细分市场不感兴趣。如果细分市场的规模低于这个阈值,我们可以停止进一步划分——这将减少操作次数。
-
最后但同样重要的是,你可以在实际数据集中显著减少数据大小和计算资源的消耗。为此,你可以将每个维度的小特征分组到一个
other组。例如,有数百个国家,每个国家用户的比例通常遵循齐夫定律,像许多其他真实数据关系一样。因此,你会有很多国家的用户占比低于 1%。正如我们之前讨论的,我们对这些小用户群体不感兴趣,可以将它们全部分组到一个country = other片段中,以简化计算。

作者绘制的图表
我们将使用递归来构建所有筛选器的组合,直到max_depth。我喜欢这个计算机科学的概念,因为在许多情况下,它可以优雅地解决复杂问题。不幸的是,数据分析师很少需要编写递归代码——我可以回忆起在 10 年的数据分析经验中遇到的三个任务。
递归的概念非常简单——就是在执行过程中你的函数调用自身。当你处理层级结构或图形时,它非常有用。如果你想了解更多关于 Python 中递归的内容,可以阅读这篇文章。
在我们的案例中,高层概念如下:
-
我们从整个数据集和没有筛选器开始。
-
然后我们尝试添加一个额外的筛选器(如果片段大小足够大且我们还没有达到最大深度),并将我们的函数应用于此。
-
重复前一步骤,直到条件有效。
num_metric = 'exited'
denom_metric = 'total'
max_depth = 4
def convert_filters_to_str(f):
lst = []
for k in sorted(f.keys()):
lst.append(str(k) + ' = ' + str(f[k]))
if len(lst) != 0:
return ', '.join(lst)
return ''
def raw_deep_dive_segments(tmp_df, filters):
# return segment
yield {
'filters': filters,
'numerator': tmp_df[num_metric].sum(),
'denominator': tmp_df[denom_metric].sum()
}
# if we haven't reached max_depth then we can dive deeper
if len(filters) < max_depth:
for dim in dimensions:
# check if this dimensions has already been used
if dim in filters:
continue
# deduplication of possible combinations
if (filters != {}) and (dim < max(filters.keys())):
continue
for val in tmp_df[dim].unique():
next_tmp_df = tmp_df[tmp_df[dim] == val]
# checking if segment size is big enough
if next_tmp_df[denom_metric].sum() < min_segment_size:
continue
next_filters = filters.copy()
next_filters[dim] = val
# executing function for subsequent segment
for rec in raw_deep_dive_segments(next_tmp_df, next_filters):
yield rec
# aggregating all segments for dataframe
segments_df = pd.DataFrame(list(raw_deep_dive_segments(df, {})))
结果是,我们得到了大约 10K 个片段。现在我们可以计算每个片段的预计效果,筛选出具有负面效果的片段,并查看潜在影响最大的用户群体。
baseline_churn = 0.2037
segments_df['churn_share'] = segments_df.churn/segments_df.total
segments_df['churn_est_reduction'] = (segments_df.churn_share - baseline_churn)\
*segments_df.total
segments_df['churn_est_reduction'] = segments_df['churn_est_reduction']\
.map(lambda x: int(round(x)))
filt_segments_df = segments_df[segments_df.churn_est_reduction > 0]\
.sort_values('churn_est_reduction', ascending = False).set_index('segment')
这应该是一个能提供所有答案的圣杯。但等等,那里有太多重复的片段和相互接续的片段。我们能否减少重复,只保留最有信息量的用户群体?

梳理
让我们看几个例子。
子片段age_group = 45–54, gender = Male的流失率低于age_group = 45–54。添加gender = Male筛选器并没有使我们更接近具体问题。因此,我们可以排除这些情况。

下面的例子显示了相反的情况:子片段的流失率显著更高,而且更重要的是,子片段包括了来自父节点的 80% 流失客户。在这种情况下,合理的做法是排除credit_score_group = poor, tenure_group = 8+片段,因为主要问题在于is_active_member = 0组中。

让我们筛选掉那些不那么有趣的片段。
import statsmodels.stats.proportion
# getting all parent - child pairs
def get_all_ancestors_recursive(filt):
if len(filt) > 1:
for dim in filt:
cfilt = filt.copy()
cfilt.pop(dim)
yield cfilt
for f in get_all_ancestors_recursive(cfilt):
yield f
def get_all_ancestors(filt):
tmp_data = []
for f in get_all_ancestors_recursive(filt):
tmp_data.append(convert_filters_to_str(f))
return list(set(tmp_data))
tmp_data = []
for f in tqdm.tqdm(filt_segments_df['filters']):
parent_segment = convert_filters_to_str(f)
for af in get_all_ancestors(f):
tmp_data.append(
{
'parent_segment': af,
'ancestor_segment': parent_segment
}
)
full_ancestors_df = pd.DataFrame(tmp_data)
# filter child nodes where churn rate is lower
filt_child_segments = []
for parent_segment in tqdm.tqdm(filt_segments_df.index):
for child_segment in full_ancestors_df[full_ancestors_df.parent_segment == parent_segment].ancestor_segment:
if child_segment in filt_child_segments:
continue
churn_diff_ci = statsmodels.stats.proportion.confint_proportions_2indep(
filt_segments_df.loc[parent_segment][num_metric],
filt_segments_df.loc[parent_segment][denom_metric],
filt_segments_df.loc[child_segment][num_metric],
filt_segments_df.loc[child_segment][denom_metric]
)
if churn_diff_ci[0] > -0.00:
filt_child_segments.append(
{
'parent_segment': parent_segment,
'child_segment': child_segment
}
)
filt_child_segments_df = pd.DataFrame(filt_child_segments)
filt_segments_df = filt_segments_df[~filt_segments_df.index.isin(filt_child_segments_df.child_segment.values)]
# filter parent nodes where churn rate is lower
filt_parent_segments = []
for child_segment in tqdm.tqdm(filt_segments_df.index):
for parent_segment in full_ancestors_df[full_ancestors_df.ancestor_segment == child_segment].parent_segment:
if parent_segment not in filt_segments_df.index:
continue
churn_diff_ci = statsmodels.stats.proportion.confint_proportions_2indep(
filt_segments_df.loc[parent_segment][num_metric],
filt_segments_df.loc[parent_segment][denom_metric],
filt_segments_df.loc[child_segment][num_metric],
filt_segments_df.loc[child_segment][denom_metric]
)
child_coverage = filt_segments_df.loc[child_segment][num_metric]/filt_segments_df.loc[parent_segment][num_metric]
if (churn_diff_ci[1] < 0.00) and (child_coverage >= 0.8):
filt_parent_segments.append(
{
'parent_segment': parent_segment,
'child_segment': child_segment
}
)
filt_parent_segments_df = pd.DataFrame(filt_parent_segments)
filt_segments_df = filt_segments_df[~filt_segments_df.index.isin(filt_parent_segments_df.parent_segment.values)]
现在我们有大约 4000 个有趣的段。对于这个玩具数据集,我们在处理后看到顶级段几乎没有差异。然而,使用现实数据时,这些努力往往会有所回报。

根本原因
我们可以做的最后一件事是保留我们段的根节点。这些段是根本原因,其他段被包含在其中。如果你想更深入地了解其中一个根本原因,请查看子节点。
为了仅获取根本原因,我们需要消除最终有兴趣的列表中所有具有父节点的段。
root_segments_df = filt_segments_df[~filt_segments_df.index.isin(
full_ancestors_df[full_ancestors_df.parent_segment.isin(
filt_segments_df.index)].ancestor_segment
)
]
所以,就是这样,现在我们有一个需要关注的用户组列表。由于数据中复杂的关系较少,顶级段仅为一维段,因为只有少数特征可以解释全部效果。

讨论我们如何解释结果是至关重要的。我们得到了一个客户段列表及其估算的影响。我们的估算基于这样的假设:我们可以将整个段的流失率降低到基准水平(在我们的示例中是平均值)。因此,我们估算了修复每个用户组问题的影响。
你必须记住,这种方法仅提供了一个关于关注哪些用户组的高层次视图。它并没有考虑是否能够完全解决这些问题。
我们写了很多代码来获得结果。也许还有另一种使用数据科学和机器学习来解决此任务的方法,不需要如此多的努力。
披萨时间
实际上,还有另一种方法。我们在 Wise 的数据科学团队开发了一个库Wise Pizza,可以在一瞬间找到最引人注目的段。它是基于 Apache 2.0 许可证的开源库,因此你也可以用于你的任务。
如果你有兴趣了解更多关于Wise Pizza库的信息,不要错过 Egor 在数据科学节上的演讲。
应用 Wise Pizza
这个库很容易使用。你只需编写几行代码,并指定你想要的维度和段的数量即可。
# pip install wise_pizza - for installation
import wise_pizza
# building a model
sf = wise_pizza.explain_levels(
df=df,
dims=dimensions,
total_name="exited",
size_name="total",
max_depth=4,
min_segments=15,
solver="lasso"
)
# making a plot
sf.plot(width=700, height=100, plot_is_static=False)

作者绘制的图表
因此,我们还得到了一个最有趣的段及其对我们产品流失的潜在影响的列表。这些段类似于我们使用之前的方法获得的段。然而,影响估算差异很大。为了正确解释Wise Pizza的结果并理解这些差异,我们需要更详细地讨论它的工作原理。
它的工作原理
这个库基于 Lasso 和 LP 求解器。如果我们简化它,这个库的功能类似于 one-hot 编码,为分段(之前计算的相同分段)添加标志,然后使用 Lasso 回归,目标变量是流失率。

正如你可能从机器学习中记得的,Lasso 回归倾向于有许多零系数,选择少数显著因素。Wise Pizza找到合适的alpha系数用于 Lasso 回归,从而得到指定数量的分段结果。
对于修订 Lasso(L1)和 Ridge(L2)正则化,你可以参考这篇文章。
如何解释结果
影响的估算是系数与分段大小的乘积结果。

正如你所见,这与我们之前估算的完全不同。常识方法估算了彻底解决用户组问题的影响,而 Wise Pizza 的影响显示了对其他选择分段的递增效果。
这种方法的优势在于你可以汇总不同的效果。然而,你在解释结果时需要准确,因为每个分段的影响依赖于其他选择的分段,因为它们可能是相关的。例如,在我们的案例中,我们有三个相关的分段:
-
age_group = 45-54 -
num_of_products = 1, age_group = 44–54 -
is_active_member = 1, age_group = 44–54。
对于age_group = 45–54的影响考虑了整个年龄组的潜在效果,而其他估算方法则从特定子组中估算附加影响。这些依赖关系可能会导致结果差异,因为min_segments参数会影响最终分段及其之间的相关性。
关注整体情况并正确解释Wise Pizza结果至关重要。否则,你可能会得出错误的结论。
我认为这个库是从数据中快速获得见解和首批分段候选项的宝贵工具。然而,如果我需要进行机会规模评估和更稳健的分析,以便与我的产品团队分享我们关注的潜在影响,我仍会使用常识方法和合理的基准,因为这样更容易解释。
TL;DR
-
在数据中找到有趣的切片是分析师的常见任务(尤其是在发现阶段)。幸运的是,你不需要制作大量图表来解决这些问题。有一些更全面且易于使用的框架。
-
你可以使用Wise Pizza ML 库来快速洞察对平均值影响最大的细分领域(它还允许你查看两个数据集之间的差异)。我通常使用它来获取第一个有意义的维度和细分领域列表。
-
ML 方法可以在一瞬间为你提供高层次的视图和优先级。然而,我建议你注意结果的解释,并确保你和你的利益相关者完全理解它。然而,如果你需要对修复整个用户组问题对 KPI 的潜在影响进行全面估算,使用基于算术的传统常识方法是值得的。
非常感谢你阅读这篇文章。我希望它对你有启发。如果你有任何后续问题或评论,请不要犹豫在评论区留言。
填空自监督在自然语言处理中的应用
原文:
towardsdatascience.com/fill-in-the-blanks-self-supervision-in-nlp-f0afb16dc7fd
为什么它强大以及如何解决
·发布于 Towards Data Science ·阅读时间 14 分钟·2023 年 5 月 10 日
--

图片来源 Patrick Tomasso 在 Unsplash
预测下一个词在语言建模中有着悠久而成功的历史,尤其是在大型语言模型中。它利用了大量的文本语料库:维基百科、分布在全球的公共网页等。相比其他类型的无监督学习,这种方法在这些语料库上的效果更强,因为学习是有监督的。此外,由于它是自监督的,因此不需要人工创建标记数据。
考虑
曝露在阳光下会导致 __
从一个足够丰富的英语句子语料库中,机器学习应该能够学习一个语言模型,能够合理地填补这些空白,在这个例子中是“皮肤癌”。
我们为什么称之为监督?因为我们从一个完整的标记序列开始,模糊掉最后一个词,要求模型预测它是什么,如果模型预测正确则奖励它,预测错误则惩罚它。所以这就是有监督学习。
最近,这种自监督方法以一种简单而强大的方式得到了扩展。具体而言,要填补的空白可以出现在文本的任何位置。因此,我们可能称之为文本重构任务,而不仅仅是文本补全任务。
以下是一些例子。每个实例是一个词的序列(以粗体标识),后面跟着一个或多个这些词被掩蔽的序列。模型的目标是正确填充这些空白。
**exposure to sunlight causes skin cancer.**
exposure to sunlight causes ___________.
____________________ causes skin cancer.
________ to ________ causes skin cancer.
exposure to sunlight ______ skin cancer.
在这篇文章的副标题中,我们提到这种扩展大大提升了预测下一个词的能力。这里我们解释一下为什么我们这么说。
当然,预测下一个单词可以迫使模型从各种场景中学习,因为很容易组装一个包含数十亿个真实单词序列的数据集。也就是说,将其推广到预测序列中任何掩蔽单词的子集可以大大扩展学习场景的范围。这就是为什么这种掩蔽在 BERT [2] 中扮演了核心角色。
考虑一个包含 n 个单词的序列。如果我们只允许最后一个单词被掩蔽,我们只会从中生成一个标记实例——填入最后一个单词。如果我们允许任何后缀被掩蔽,我们可以从中生成最多 n 个标记实例——填入一个后缀。如果我们允许任何子集的单词被掩蔽,我们可以从中生成约 2^n 个标记实例。这是因为有这么多可以被掩蔽的子集。
此外,正如我们在上面的例子中所见,掩蔽不同的单词将迫使模型学习剩余单词的表示,从而在可能的情况下预测被掩蔽的单词。在我们的例子中,填补右尾的空白迫使模型理解暴露在阳光下的影响,而掩蔽左尾的两个单词则迫使模型学习导致皮肤癌的原因。
这是另一个例子,我们在[3]中覆盖过。
想象一下我们有一个汽车名称的列表。例如:
本田思域、丰田Celica、福特Mustang、吉普Wrangler……
预测下一个单词无疑有助于模型学习特定品牌的模型。例如,Cherokee 和 Wrangler(以及其他)是Jeep的模型。然而,通过掩蔽任何单词的额外自我监督也有助于模型学习到模型名称往往能强烈预测品牌。例如,Celica 是 Toyota,Mustang 是 Ford,Wrangler 是 Jeep,等等。
这种学习将更好地服务于下游应用场景。例如回答问题。
Celica 的品牌是什么?
在本文的其余部分,我们将介绍几种不同的建模方法,并讨论它们的优点和局限性。
我们将讨论的方法在处理这种任务的细微版本时,远不如最先进的大型语言模型。例如,整个段落被掩蔽。
尽管如此,它们更容易理解,也许对一些读者更为熟悉,并且可以使用广泛可用的 NLP 工具包尝试,或者在某些情况下从头开始构建。
数据
假设我们有一组文本文档。文档是一个令牌(单词、标点符号等)的序列。在这篇文章中,我们将假设我们定义的学习任务不会跨越文档边界。
大部分情况下,我们将关注由单个英文句子组成的文档。我们可以想象,这些句子语料库是通过使用 NLP 将较长文档分割成句子从较长的文档语料库中得出的。也就是说,我们讨论的方法同样适用于更长的多句文档。只是对单句文档的描述更为方便。
在这篇文章中,我们将描述一种概念上简单易懂、实现简单且在训练和推断过程中速度较快的方法。
也就是说,它的范围仅限于填补单个遮罩区域的空白,无论是在左尾、右尾还是中间。大型语言模型可以填补多个遮罩区域的空白。此外,LLM 通常会更准确,尤其是在长序列上,如完整段落或甚至整页文本。
方法
我们的方法将包括两个 tries,一个前向 trie 和一个后向 trie,以特定的方式融合在一起。我们将这种数据结构称为前向-后向 trie(FB-Trie)。
我们将使用以下示例来说明前向-后向 trie 数据结构。假设我们的训练集中只有两个句子来构建 trie。
work causes stress.
work is good for you.
首先,让我们看看普通的前向 Trie。

图 1(作者提供):我们的数据上的前向 Trie
这个 Trie 只捕捉了两个根到叶子的路径上显示的两个单词序列。
Trie 中的一个节点隐式地表示在根到节点路径中以该节点结尾的单词序列。
让我们填补图 1 中未显示的重要细节。一个节点存储训练集中前缀与其匹配的序列数量。因此,在我们的示例中,根节点将存储 2,因为我们的训练集中有两个序列。假设有一个更丰富的训练集,表示 [work, causes] 的节点将存储 15,如果有这么多的序列以 work causes 开头的话。
令 i 表示一个节点,n(i) 表示它的计数。令 j 表示 i 的一个子节点。从 i 到 j 的弧的概率是 n(j)/n(i),我们将其称为 P(i, j)。
从这个 Trie 中,如果存在明显的高概率扩展,我们可以在右尾填补单个空白。让我们通过示例来说明这一点。
考虑 work ____。我们不会填补这些空白,因为我们可以想象有很多不同的句子以 work 开头。另一方面,如果我们有 work causes ___,用 stress 填补空白将更为合理*。
填补左尾的空白
如果填空查询是 ___ causes stress 呢?为了扩展我们的模型以回答这种查询,我们将添加一个后向 Trie。
后向 Trie 是如果我们反转训练集中序列而得到的 trie。它的结构与前向 trie 相同(图 1)。它还包括节点上的计数。
左尾空白现在可以通过相同的推理过程进行填充(如果可能的话),只不过我们查找的是空白需要填充的序列的反向字典树。
在填空输入中,我们将知道空白是位于左尾还是右尾,因此我们将知道是查看输入的前向字典树还是输入反向的后向字典树。
填充中间的空白
考虑当待填充的空白不在任何尾部时。例如工作 __ 压力。
对于这个任务,我们将设计一个新的数据结构,这是前向和后向字典树的特殊融合,向两个字典树添加了父节点弧线,并添加了桥接边以从一个字典树跳到另一个。
让我们从向前向字典树添加父节点弧线开始。

图 2(作者提供):带有父节点弧线的我们的数据上的前向字典树
我们能够添加父节点弧线,因为根据定义,字典树中的每个非根节点都有唯一的父节点。
接下来,我们将向后向字典树添加父节点弧线。
融合前向字典树和后向字典树
接下来,我们将融合这两个字典树。我们下面展示了我们示例中的融合字典树。

图 3(作者提供):我们示例中的前向-后向字典树
前向字典树的节点和弧线用实线表示,而后向字典树的节点和弧线用虚线表示。前向字典树中的父节点弧线用虚线曲线表示。为了减少混乱,后向字典树中不显示父节点弧线,但它们确实存在。
两个字典树通过桥接边融合在一起。这些是没有箭头的虚线边。每条边可以实现为两个节点之间的一对弧线,覆盖两个方向(这未被描绘)。
在图 3 中,只展示了许多桥接边中的几个——实际上是两个。
现在让我们解释这些桥接边是如何添加的。一旦理解了这一点,读者可以想象其余桥接边的位置。
想象一下在一个新序列上进行训练。为了具体说明,假设这是工作导致压力的标记化版本。首先,我们将这个序列插入前向字典树和后向字典树。现在我们考虑这个序列的所有可能的(前缀,后缀)分割。接下来,我们反转后缀。
以下是我们将获得的(前缀,后缀,反向后缀)三元组,排除前缀或后缀为空的情况。
**prefix suffix reverse suffix**
[work] [causes, stress] [stress, causes]
[work, causes] [stress] [stress]
现在我们按照如下方法处理每个三元组(前缀,后缀,反向后缀)。我们在前向字典树中查找前缀以找到它结束的节点。同样地,我们在后向字典树中查找反向后缀以找到它结束的节点。我们现在通过一条桥接边连接这两个节点。
模型大小和训练时间
在描述推理之前,让我们花些时间反思模型的大小以及训练所需的时间。
首先,由于我们有两个字典树——一个前向字典树和一个反向字典树。此外,我们还有桥接边。
在这篇文章中,我们认为模型的大小本身不是问题。最先进的技术可以轻松容纳巨大的模型。
更有趣的问题是训练模型需要多长时间。
首先,训练可以是增量的。也就是说,可以随时向 FB-Trie 添加新序列。事实证明,训练实际上非常快速。这是因为(i)将新序列插入前向字典树及其逆序到反向字典树是很快的,和(ii)枚举所有这些序列的分裂,进行查找以发现要桥接的节点,并添加各种桥接边也是很快的。
使用前向-反向字典树填充空白的说明
让我们用work __ stress来说明这一点。首先,我们查找前向字典树中的节点,称之为u,这是[work]结束的地方。接着我们查找反向字典树中的节点,称之为v,这是[stress]结束的地方。然后,我们沿着从反向字典树到前向字典树的桥接边走。接着,我们沿着[ work, causes ]在前向弧中的父节点走,最终到达[work]。现在我们在两个方向上都到达了相同的节点。在这个过程中我们沿着父弧走的标签序列,经过反转之后,将成为空白的值。在我们的示例中,我们只走了一条反向弧,这条弧标记为causes。所以我们的答案是“causes”。
推断:更一般的描述
在我们的work __ stress示例中,空白有一个唯一的解决方案。在一般情况下,L __ R,可能会有多个解决方案*。
(我们略微调整了术语,因为在本节中会有所帮助。L 和 R 是标记序列,而不是字符串。)
下面我们讨论如何处理一般情况。
首先,我们在前向字典树中查找 L 的最大前缀,称之为 L'。接着,我们在反向字典树中查找 R 的最大后缀,称之为 R'。让u表示前向字典树中 L'结束的节点,v表示反向字典树中 R'的逆序结束的节点。
这种情况如下图所示。

图 4(作者提供):一般推断——多解情况
由于 v 在反向字典树中不是叶子节点——否则就没有空白可以填充——因此必须至少有一条桥接边接触到v,如图 4 中的标注所示。
考虑任何一条这样的桥接边,并让w表示在前向字典树另一端的节点。我们按照前向字典树中的父弧顺序跟踪,从w开始。当我们首次到达一个节点,称之为x,这个节点位于前向字典树中的路径 L'上时,我们停止。这种情况总是会发生(见下一段)。
最佳情况是x等于v。最坏情况是x是前向字典树的根。
逆转从 w 到 x 的路径上的标签顺序会产生一种推断的解决方案。更准确地说,这个解决方案是针对填空问题 L’’ __ R’,这可能比我们开始时的问题 L __ R 更宽泛。
进一步扩展解决方案
每条触及 v 的桥接边都会生成一个唯一的解决方案。因此,形式为 L’’ __ R’ 的解决方案的数量就是触及 v 的桥接边的数量。请注意 L” 取决于 w。
我们可以通过从 u 开始而不是 v 来反向推断,从 u 到 w 走一条桥接边。这次 w 是反向字典树中的一个节点。然后我们在反向字典树中从 w 顺序地沿父节点弧走,直到第一次遇到反向字典树中的 R’ 节点,我们将其称为 x。 (这次 x 是反向字典树中的一个节点,而不是正向字典树中的。) 我们得到另一组形式为 L’ __ R’’ 的解决方案。注意 R’’ 取决于 w。
两组解决方案可能会重叠。这些集合容易去重,因为解决方案是由表示填充部分的路径的端点唯一确定的。请注意,路径本身可能不同,但端点(忽略起点和终点)将是相同的。
让我们通过图 3 中的“work __ stress”的例子来说明这一点。

图 5(作者提供):两种方法来填入“causes”在“work __ stress”中。
如果我们按照所描述的算法进行,有两种方法可以将“causes”填入“work __ stress”中。一种方法是从反向字典树中的[stress]的桥接边开始,另一种方法是从正向字典树中的[work]的桥接边开始。尽管这两种解决方案是通过不同的路径找到的,但在忽略起始点和终点后,两条路径上的端点是相同的:正向字典树中的[work]和反向字典树中的[stress]。
如果我们仅允许从其中一个字典树开始,我们不会得到重复。但我们不能排除可能会遗漏某些解决方案的可能性。
这是一个例子。假设我们在以下两个字符串的分词版本上进行了训练。
work causes burnout.
work causes stress.
考虑提示“work __ stress”。从反向字典树中触及[stress]的桥接边开始,我们会得到解决方案“work causes stress”。从正向字典树中触及[work]的桥接边开始,我们会得到解决方案“work causes burnout”。
很合理地怀疑第二个答案是否应该被认为是一个有效的解决方案,因为它忽略了词语stress。这是一个需要建模者做出的决定。这两种观点都有其合理性。
在这篇文章中,我们的观点是,如果有人想将第二个答案视为有效的解决方案,这可以通过枚举两个方向上的所有解决方案然后去重来实现。
评分解决方案
当填空题有多个解决方案时,为了能够按优劣顺序呈现它们,对它们进行评分是合理的。
我们已经看到一个评分有意义的例子。对于提示“工作 __ 压力”,如果我们认为“工作 导致倦怠” 也是一个解决方案,那么它应该比“工作* 导致 压力”的评分低。这是因为前者忽略了提示中的单词压力,而后者则没有*。
在这个例子中,前一个解决方案的粗体字比后一个多这一事实在此没有作用。答案中的单词数量本身并不重要,重要的是被忽略的提示信息的量。
现在考虑这个提示:“工作 __”。假设训练集仅包含这两个字符串
work causes burnout.
work causes stress.
“工作 导致倦怠” 和 “工作 导致压力”都是不错的解决方案。
现在想象一下,训练集中实际上包含了这两个字符串的多个副本,并且字符串“工作导致压力”的出现频率远高于工作导致倦怠。这就意味着,将“工作 导致压力”评为比“工作 __”更好的解决方案是合理的,而不是“工作 导致倦怠”。
对于上一段的场景,提供所需排名顺序的评分很容易实现。从前向查找树上的节点[工作, 导致]开始,选择标记为“压力”的弧的概率高于选择标记为“倦怠”的弧的概率。这些概率可以从之前讨论的前向查找树中存储的信息中轻松计算出来。
对于填空在左尾的提示,我们可以类似地对解决方案进行排名。这次我们使用的是反向查找树,而不是前向查找树。
唯一剩下的情况是当空白在中间,并且两个解决方案使用相同的提示信息时。也就是说,两个解决方案的形式为 L M1 R 和 L M2 R,其中 L 和 R 在两个解决方案中是相同的,但 M1 和 M2 不同。
为了覆盖这种情况,我们将定义解决方案 L M R 的评分如下。写作 M = [m1] M’ [m2],其中m1 和m2 是 M 中的第一个和最后一个标记。M’可能为空,这种情况下m1 等于m2。
我们将定义 L M R 的评分为前向查找树中 P(m1|L)与反向查找树中 P(m2|R)的乘积。
总结
在这篇文章中,我们首先讨论了填空自监督问题作为训练大型语言模型的一种强大方法。然后,我们提出了使用自定义融合的前向和反向查找树来解决这个问题的方法。
尽管这种解决方案在与现代深度学习和基于变换器的大型语言模型相比时并不具备竞争力,但在较小版本的问题中,特别是在回答单句填空题时,它可能非常有效。
从零实现相对容易理解且快速训练,同时也支持增量训练。
参考文献
-
towardsdatascience.com/contextual-text-correction-using-nlp-81a1363c5fc3 -
arxiv.org/pdf/1810.04805.pdf[BERT] -
medium.com/towards-data-science/multi-task-learning-4531eb32d77b
最终的 DXA-国家
基于长期图像的 AI 模型在健康和医学领域
AI 可以看到终点!深度学习通过单次和顺序体成分成像预测全因死亡率
·发表于 Towards Data Science ·阅读时间 9 分钟·2023 年 7 月 22 日
--

DXA 成像提供了多种体成分可视化类型。(图片由作者提供)
关键点,TLDR:
-
结合体成分成像和元数据(例如年龄、性别、握力、步速等)可以提供最佳的 10 年死亡率预测
-
相比单次记录模型,长期或顺序模型整体表现更佳,突显了在健康数据中建模变化和时间依赖性的必要性。
-
长期模型有潜力提供更全面的健康评估
人工智能(AI)和机器学习(ML)正在革新医疗保健,推动我们迈向精准医疗时代。开发 AI 健康模型的动机是减少死亡和疾病,并延长高质量的生活。经过良好训练的模型具有更全面分析数据的能力,提供了更全面的健康评估。
单次记录与长期模型
基于图像的医学 AI/ML 模型现在已经达到了一个成熟的阶段,它们常常能与人类表现相媲美甚至超越人类,熟练地识别那些容易被人眼忽视的模式和异常。然而,大多数这些模型仍然在单一时间点的数据上运行,仅提供一个特定时刻的健康快照。无论是单模态还是多模态模型,它们往往处理的是在相对相似时间框架内收集的数据,这些数据构成了预测的基础。然而,在医学应用的 AI/ML 的更广泛背景下,这些单一时间点的模型只是第一步——所谓的“低悬果实”。医学 AI 研究的一个前沿是纵向模型,它们提供了对一个人健康状况随时间变化的更全面的视角。
纵向模型旨在整合多个时间点的数据,捕捉个体健康轨迹而非单独的瞬间。这些模型利用人类健康的动态特性,因为生理变化是不断发生的。将这些变化映射到特定的结果或健康问题上,可能会成为预测医疗的游戏规则改变者。纵向数据的概念在临床实践中并不陌生——它经常被用来监测衰老和预测脆弱性。一个典型的例子是骨矿密度(BMD)的跟踪,这是骨质疏松症和脆弱性的关键标志。定期评估 BMD 可以检测到显著的减少,预示潜在的健康风险。
纵向模型开发挑战
从历史上看,纵向模型的开发面临了几个重要挑战。除了每个个体需要处理的大量数据和计算外,最关键的障碍在于纵向医学数据的策划。与单一时间点的数据不同,纵向数据涉及在较长时间内跟踪患者的健康信息,通常跨越多个医疗机构。这需要细致的数据组织和管理,使得策划过程既耗时又昂贵。多个成功的研究已经获得资助以前瞻性地收集纵向数据。这些研究报告了与患者在较长观察期内的保留相关的挑战。因此,尽管纵向模型具有潜在的好处,但其开发仍然是一个复杂且资源密集的任务。
目标
体成分的变化,包括瘦体组织、脂肪软组织和骨骼的比例,已知与死亡率相关。在我们的研究中,我们旨在利用体成分信息来更好地预测全因死亡率,简单来说,就是预测一个人可能的生命时间线。我们评估了基于单一时间点和纵向数据构建的模型的性能,分别称为我们的“单记录”和“顺序”模型。单记录模型使我们能够评估哪种信息最能预测死亡率。开发顺序模型的目的是捕捉随时间变化的情况,并评估这些变化如何影响死亡率预测。
数据
本研究的数据来自一个纵向研究,即健康、老龄化与体成分(Health ABC)研究,在该研究中,跟踪和监测了 3000 多名年长的多种族男女成人,持续时间长达 16 年。这项研究产生了一个丰富而全面的纵向数据集。作为研究的一部分,患者接受了全身双能量 X 射线吸收测量(TBDXA)成像,并收集了几项元数据。为了遵循最佳建模实践并避免数据泄漏或减轻过拟合,数据被分为训练集、验证集和保留测试集,按照 70%/10%/20%的比例划分。
我们使用全身双能量 X 射线吸收测量(TBDXA)成像来量化体成分,这一直被认为是金标准成像模式。历史上,患者元数据,包括年龄、体重指数(BMI)、握力、步速等变量,被用于评估衰老/死亡率,并作为体成分的替代测量。由于 DXA 扫描仪的可及性有限,患者元数据和体成分的替代测量得到了广泛使用。最近,随着扫描成本的降低以及不再需要医生转诊/订单/处方,可及性得到了极大改善。
单图像模型
构建了三个单记录模型,每个模型使用不同的数据输入,但所有模型的输出都是 10 年死亡概率。第一个模型仅使用患者元数据,使用了一个 32 单元的 ReLU 激活隐藏层和一个 Sigmoid 预测层的神经网络。第二个模型仅使用 TBDXA 图像作为输入,它包括一个经过修改的 Densenet121,该模型被修改为处理两个颜色通道而不是大多数自然图像中的三个颜色通道(RGB)。DXA 的双能量特性导致高低 X 射线图像,这些图像完全注册并堆叠成两个图像通道。第三个模型将模型一中的元数据嵌入与模型二中的 TBDXA 图像嵌入结合,然后通过一个 512 单元、一个 64 单元的全连接 ReLU 层进行处理,最后通过一个 Sigmoid 预测层进行预测。

纵向/序列模型
构建并评估了三个连续模型。单记录模型架构作为每个连续模型的基础,但去除了 sigmoid 预测层,使输出成为表示特征嵌入的向量。在研究过程中,收集了每位患者在多个时间点的数据。每个时间点的数据被输入到适当的模型中,以获得相应的特征向量。每位患者的特征向量被排序并堆叠成一个序列。训练了一个长短期记忆(LSTM)模型,以处理特征向量序列并输出 10 年的死亡率预测。如前所述,进行长期研究存在若干困难,其中保留和数据收集是常见问题。我们的研究也遇到这些问题,因此一些患者的数据点比其他患者更多。选择 LSTM 模型作为序列建模方法,因为它们不受限于每位患者使用相同的输入序列长度。即,LSTM 可以处理长度不同的序列,从而消除了如果患者数据点不足(约 10 个),需要填充序列的问题。

图像 + 元数据纵向模型胜出
在保留的测试集上的接收操作特征曲线下面积(AUROC)显示,元数据在单记录和连续模型中均优于仅使用 TBDXA 图像。然而,将元数据和 TBDXA 成像结合在这两种建模范式中都产生了最佳的 AUROC,这表明成像包含了元数据未捕捉的有用信息,能够预测死亡率。另一种解释是,元数据并不是预测死亡率的体成分的完整替代指标。如果它们是完整的替代指标,将 TBDXA 成像与元数据结合不会显著提高或改变 AUROC。组合结果产生了更好的 AUROC,表明成像提供了超出元数据捕捉的信息,进一步证明了成像的有效性。

总体来看,纵向或序列模型的表现优于单一记录模型。这一点在所有建模方法和输入数据类型(元数据、仅图像、组合的元数据和图像)中都适用。这些结果展示了建模变化及健康数据时间依赖性的重要性。
我们进行了综合辨别改进(IDI)分析,以评估将成像与元数据结合使用相较于仅使用元数据的好处。该分析是在表现优于单一记录模型的序列模型上进行的。IDI 的结果为 5.79,其中集成敏感性和特异性分别为 3.46 和 2.33。这表明,将成像与元数据结合使用可以将模型正确识别出未来 10 年内无法生存者的能力提高 3.46%,并将正确识别出未来 10 年内能够生存者的能力提高 2.33%。总体来说,这表明模型性能的改善大约为 5.8%。

那么,怎么了?
我们的研究强调了纵向人工智能/机器学习模型在预测医疗领域,特别是全因死亡率方面的潜力。单一记录模型和纵向模型的比较分析显示,后者表现更佳,这表明建模时间变化在健康数据分析中的关键作用。我们的发现具有临床意义,包括通过考虑患者历史或纵向数据的模型提供更准确和全面的健康评估。此外,我们对可解释性人工智能的深入分析提供了与高 10 年死亡概率密切相关的变量的见解。其中一些变量是可改变的,可以在临床上进行干预,为许多人提供了改善其寿命和健康期的机会。虽然开发纵向健康模型所需的数据存在,但适当的基础设施和机构支持尚未完全到位,以便高效地进行数据整理和大规模开发这些模型。然而,许多人正在努力克服这些障碍,纵向模型的开发是医学中人工智能众多令人兴奋的前沿之一。
这些发现的临床意义深远。纵向模型有可能通过提供关于患者健康轨迹的更精确、个性化的预测来改进护理。这样的模型可以指导主动干预,从而提升护理结果,甚至可能延长生命。此外,使用元数据和影像数据为未来的 AI/ML 模型树立了新的标准,建议采取协同的方法以获得最佳结果。这加强了对多维度、细致数据的需求,以准确全面地描绘患者的健康状况。这些发现代表了 AI/ML 在医疗保健应用中的重要进展,突显了在追求精准医疗过程中的激动人心的前景。
更多资源:
阅读论文
[## 深度学习预测基于纵向全身 DXA 成像的全因死亡率 …
死亡率研究已识别出预测全因死亡风险的生物标志物。这些标记物中的大多数,例如 …
www.nature.com](https://www.nature.com/articles/s43856-022-00166-9?source=post_page-----f0309d718980--------------------------------)
关于 IDI 和已发布示例的更多信息:
医生和数据科学家们,这里有一个关于 IDI 和 NRI 的介绍及一个 Python 代码示例。AUC 可能并不总是足够的 …
towardsdatascience.com [## 双能量三室乳腺成像用于成分生物标志物以提高检测 …
尽管乳腺成像如全视场数字乳腺摄影和数字乳腺断层扫描有助于减少 …
www.nature.com](https://www.nature.com/articles/s43856-021-00024-0?source=post_page-----f0309d718980--------------------------------) [## GitHub - LambertLeong/AUC_NRI_IDI_python_functions: 自定义 Python 函数帮助你进一步 …
自定义 Python 函数帮助你进一步分析机器学习模型和诊断测试 - GitHub …
github.com](https://github.com/LambertLeong/AUC_NRI_IDI_python_functions?source=post_page-----f0309d718980--------------------------------) [## 曲线下面积与集成歧视改进和净重新分类的结合
医生和数据科学家,这里是关于 IDI 和 NRI 的介绍,并附有一个 Python 编码示例。AUC 可能并不总是足够的…
为你的下一个数据科学侧项目寻找灵感
·
关注 发表在 Towards Data Science · 作为 通讯 发送 · 阅读时间 3 分钟 · 2023 年 3 月 9 日
--
侧项目可以成为积极学习的完美工具,但我们知道许多数据科学家很难挤出时间来启动一个。即使你有一些空闲时间,确定一个想法也可能是一个挑战:它必须与你的兴趣相关,范围合理,并且足够具有挑战性(但又不是难到让人痛苦的程度)。
好的,我们在这里提供帮助!
本周,我们汇聚了一系列强大的动手教程,可能会激发你的创造力,并鼓励你涉足新的主题。逐步跟随其中一个(或多个),或者将这些推荐作为开始全新学习冒险的起点。不论如何,最重要的是选择某个东西。你将在实践中整理细节。
-
立即上手数据集成管道。如果你迫不及待想要深入代码,Marie Truong的 ELT(提取、加载、转换)工作流教程将迅速让你“卷起袖子”动手实践。除了其他有用的组件,你还将操作 API 并熟悉 BigQuery。
-
创建自定义微服务。所有关于 AI 的喧嚣是否让你感到有些迷茫?重新找到方向的好方法是构建一个使用这项快速发展的技术的实际应用。Mason McGough的项目讲解将展示如何构建一个去除图像中人物的 Stable Diffusion 应用——这本身是个很酷的主意,也可能激发你探索其他方向。
-
无需代码,只需数学。掌握一个复杂的概念时,当其各个元素变得更加具体时会更容易。例如:Aparna Dhinakaran和 Jason Lopatecki 对 Jensen-Shannon Divergence 的介绍,他们耐心地解释了其背后的数学以及在监控模型漂移中的实际应用。

图片由Andrew Neel提供,来自Unsplash
-
构建地理空间工具的全程指南。如果你的学习风格适合详细、稳步的指导,试试Jacky Kaub关于创建基于地图应用程序的系列教程。它将带你从概念验证到最小可行产品,经过三篇耐心的文章;到第一部分结束时,你将已经有一个简约的演示。
-
谁想从零开始学习深度学习? Piotr Lachert 的首次 TDS 文章以一个简单的前提开始:“检查您是否真正理解神经网络如何学习的最佳方法是全程自己实现这个过程。” Piotr 随后确切地做到了这一点,带领读者通过 Python 实现基本神经网络。
-
调查因果推断的细节。 Nazlı Alagöz 的教程以一个假设的音乐流媒体服务为例,阐明了提取因果洞察的复杂性,并详细探讨了在处理受多次治疗时间和多个时间段影响的数据时,差异中的差异(DiD)方法的局限性。
-
更多因果关系,仅仅因为。如果 Nazlı 的帖子激发了您对因果推断的好奇心,并且您准备深入了解,您很幸运:我们最近的三月版专注于这个重要主题,您将找到几个优秀的教程。
感谢您这周的时间和支持!如果您喜欢我们发布的作品(并希望获取所有内容),请考虑 成为 Medium 会员。
直到下一个变量,
TDS 编辑团队
使用量子计算机寻找暗物质
原文:
towardsdatascience.com/finding-dark-matter-using-a-quantum-computer-a99f4bff4685
QML — 量子机器学习在高能和粒子物理的有趣应用
·发布于 Towards Data Science ·阅读时间 9 分钟·2023 年 11 月 3 日
--
背景
今年的八月,致力于 IBM 全球量子暑期学校,在那里我不仅在压缩的时间表和紧凑的日程中学到了基础知识,还学到了一些量子计算的应用。在经过 4 周艰苦的学习后获得的 徽章 本身就是一段 "量子体验",因为你以为你了解自己在做什么,但同时你对发生了什么一无所知。这个月的课程从量子电路基础转到变分算法的速度很快,几乎没有时间来“做自己的研究”和亲自参与应用部分。

照片由 Dynamic Wang 提供,来源于 Unsplash
就应用而言,量子化学、量子模拟和一些非常复杂的建模任务都符合可以用量子计算机解决的问题的标准。话虽如此,还有另一个正在蓬勃发展的领域,受到用户和研究人员的极大关注,那就是量子机器学习——简称 QML。
我认为 QML 应该是传统机器学习的合理继承者,于是我开始了这项工作。现在,我希望能找到一个不会因为数据量庞大、复杂模式难以识别而让机器学习算法无法直接解决的问题,但又是我可以在我那台简陋的机器上编程解决的问题。我没有再去寻找其他领域,我们的老朋友物理学正好适合,它在其怀抱中隐藏着一系列复杂但有趣的问题,而且从事这些问题听起来在智力上也很酷。
就这样。
问题陈述
我决定处理在与大型强子对撞机(CERN)相关的 OPERA 实验(振荡项目与乳胶追踪装置)下研究的暗物质分类问题。
问题陈述
简而言之,我们将训练一个分类器来区分信号和噪声。信号是暗物质的存在,而噪声意味着没有信号或完全是其他东西。
很简单!
直觉
让我们稍微详细讲解一下实验的背景,以便形成一些直觉。
暗物质是一种神秘且尚未被探测到的物质形式,它不会与电磁辐射(如光)发生相互作用。人们认为它占据了宇宙总质量的近 80%。之所以称之为“暗物质”,是因为它无法通过望远镜或其他探测电磁辐射的仪器直接观察到。
为什么寻找暗物质是一个挑战?
这很具挑战性,因为我们不知道自己在寻找什么。
-
隐形:暗物质不会与光相互作用,这就是为什么我们真正不知道自己在寻找什么,虽然有许多理论,但仍没有共识。
-
背景噪声:设计用于探测暗物质的实验必须应对各种背景噪声,这些噪声可能会模拟暗物质的预期信号。区分实际的暗物质相互作用和这些背景信号是一个重大挑战。暗物质与常规物质的相互作用非常微弱,使得在实验中探测和区分背景噪声变得困难。
-
多种可能性:暗物质的理论候选者有很多,这需要不同的探测方法。科学家们正在探索这些可能性,这增加了寻找暗物质的复杂性。
OPERA 实验中到底发生了什么?
OPERA 位于意大利的格兰萨索国家实验室。这是一个中微子物理实验,主要关注中微子振荡的研究。它并不是专门用于寻找暗物质的。

作者提供的图像
当一个假定的暗物质粒子(我们正在寻找的)与铅原子核碰撞时,原子核会以电子束的形式发射电子,这些电子束在屏幕上被探测到。这就是我们要寻找的信号。然而,当一个中微子与铅原子核碰撞时,它也会产生电子,并以相同的方式产生电磁冲击,这会将信号与噪声混淆。我们正试图区分这个信号和噪声。
我们的计划
基本上,我们需要筛选数据并区分信号和噪声,这可以通过传统的机器学习方法完成,但仍然是一个艰巨的任务。在一个拥有 1000 万次碰撞的数据集中,只有大约 1 万次会对应信号。这种数据集中的不平衡和稀疏使问题变得偏斜且困难。因为我们喜欢挑战,我们会加一个小难度,使用量子机器学习算法而非传统算法(对以连词开头的句子表示歉意)。
数据
LHC 网站上有大量数据集在这里可供使用;本实验使用的数据集可以在这里找到。
许可证: 数据集根据CC0 (CC Zero) 发布,这是创意共享公共领域奉献许可证 (opendata.cern.ch/record/16541)
代码可以在我的GitHub 仓库中找到。
数据包括两个 h5 文件,open30.h5 和 test.h5;h5 是层次数据格式,用于以压缩方式存储和组织大量数据。
数据包含以下变量:
-
X — 基础轨道的 X 坐标
-
Y — 基础轨道的 Y 坐标
-
Z — 基础轨道的 Z 坐标
-
TX — 从原点投影到 X 轴的角度
-
TY — 从原点投影到 Y 轴的角度
-
信号 — 一个二元变量,1 表示信号,0 表示噪声
库
IBM 的量子库 — qiskit 0.44.3
import qiskit.tools.jupyter
%qiskit_version_table
%qiskit_copyright

图片由作者提供
关于变分量子算法的说明
量子算法设计用于在量子计算机上运行,但目前我们处于 NISQ — 噪声中等规模量子计算机时代,这使得结果的重现变得困难。当前的量子计算机非常容易受到噪声的影响,甚至微小的热力学条件变化或其他电路问题都会对结果产生很大影响。我们想应用的逻辑门因为噪声会变成其他东西。这是不可取的。
聪明的研究人员开发了所谓的变分算法,它们利用经典计算机和量子计算机来提高速度和准确性。
实质上,所有算法都使用某种形式的优化和参数调整,变分算法所做的是利用量子计算机来近似成本函数,然后在经典计算机上计算成本函数的新参数值,再用新值在量子计算机上运行。这样,计算就分布在经典计算机和量子计算机之间,加速了过程。
我们将在这里使用变分量子分类器,因为任务是分类信号和噪声。有关 IBM VQC 的更多信息,请访问:learn.qiskit.org/course/machine-learning/variational-classification
建模
我们先来看看数据和变量

图片来源于作者
让我们查看配对图,以了解变量之间是否存在相关性。

图片来源于作者
好吧,存在一个模式,但相当复杂!
在完成常规的样板工作,包括采样、缩放和训练测试拆分后,我们已经准备好进行量子模型训练。
在继续之前,让我们运行支持向量分类算法,这样我们就有了传统机器学习的基准。
from sklearn.svm import SVC
svc = SVC()
model_classical = svc.fit(train_features, train_labels)

图片来源于作者
在测试数据上 70%的准确率并不算太好,但我这里没有进行太多的特征工程。一旦我进行特征工程,它会有所改善。
现在,轮到量子计算机了。
问题被以门和电路的形式进行表述,这些门和电路将量子比特(量子位)输入到量子计算机中。
我们没有丢弃任何特征——TX、TY、X、Y、Z;我们会使用所有这些特征,因此我们电路中使用的量子比特数量将是 5。
num_features = features.shape[1] #5
feature_map = ZZFeatureMap(feature_dimension=num_features, reps=1)
feature_map.decompose().draw(output="mpl", fold=20)

图片来源于作者
电路的样子就是这样。它输入 5 个量子比特,应用了 Hadamard 和 P 门。Hadamard 门将量子比特的基态从|0>变为|+>,从|1>变为| — >,而 P 门则使得单量子比特绕 Z 轴旋转。
形成电路后的下一步是 Ansatz,这是量子世界中相当常见的术语;它在德语中意为“方法”,但在物理学和数学中指的是一种经过教育的猜测。
所以,我们需要做的就是对参数进行经过教育的猜测——这将创建一个量子态,并在量子计算机上执行。执行值将与期望值进行比较,根据偏差程度,优化器将调整参数或 ansatz,直到我们达到一个足够好或满意的值。
from qiskit.circuit.library import RealAmplitudes
ansatz = RealAmplitudes(num_qubits=num_features, reps=3)
ansatz.decompose().draw(output="mpl", fold=20)

图片来源于作者
你会看到应用了很多 R 门。本质上,所有量子比特只是绕其轴旋转以获得一些任意的起始值,这就是 ansatz。
现在,让我们在训练数据上拟合 VQC。
optimizer = COBYLA(maxiter=100)
vqc = VQC(
sampler=sampler,
feature_map=feature_map,
ansatz=ansatz,
optimizer=optimizer,
callback=callback_graph,
)
vqc.fit(train_features, train_labels)
回调图非常酷。就像实时获取你行为的结果一样。

图片由作者提供
好的,下行趋势是有希望的,这意味着它正在学习,但只在测试数据上进行拟合才能告诉我们是否存在过拟合。
train_score_q4 = vqc.score(train_features, train_labels)
test_score_q4 = vqc.score(test_features, test_labels)
print("Quantum VQC on the training dataset:",train_score_q4)
print("Quantum VQC on the test dataset:", test_score_q4)

图片由作者提供
相较于传统的机器学习,略微更好。我在本地机器上使用模拟器运行了它。也许在真实的量子计算机上运行时效果会更好(没有什么阻止我们,只是需要排队,并且由于服务器负载的增加,简单的任务如加法需要几个小时才能完成,这并不是因为计算慢,而是因为越来越多的人排队使用计算时间)或者花更多时间进行特征工程。
结束语
话虽如此,我很确定目前传统的机器学习算法会优于量子算法,特别是在分类任务中,因为大量研究和资源已经投入到使其稳健和复杂化的工作中。一旦量子机器学习经历这样的升级,它将成为一场公平的竞争。
这并不意味着量子机器学习会取代机器学习,这远非事实,相反,量子计算是针对那些经典计算机无法在多项式时间内解决或甚至无法近似的问题。机器学习会有它的地位,量子机器学习也会在更大的框架中占有一席之地。
本文并不旨在展示机器学习或量子机器学习的威力,而只是展示两者之间的相似性——它们在执行上有所不同但在本质上类似,两者都依赖于特征工程和超参数的选择。
在结束之前,看看信号中是否存在某种聚类将会很有趣。我们已经将信号与噪声分开,我们能否对其进行聚类,看看是否形成了暗物质粒子的电磁喷发?
kmeans = KMeans(n_clusters=5).fit(train)
clustering_labels = kmeans.labels_
X_train = train.sample(frac=0.05)
clusters['cluster'] = clustering_labels
fig = plt.figure(figsize = (20,20))
ax.scatter(X_train.X, X_train.Y, X_train.Z, c=X_train.cluster)
plt.show()

图片由作者提供
不错!它们虽然不具备独特性,但也不糟糕。可以看到一些模式。簇的垂直特性是由于质量和碰撞后的轨迹角度。
有趣!
从嘈杂的数据中,我们提取了可能成为暗物质粒子相互作用候选者的最佳轨迹。
我希望这篇小文章能鼓励你迈出量子跃迁。欢迎通过 Twitter 或邮件联系我;像往常一样,我愿意接受批评和建议,这些能帮助我成长和学习。
PS:再一次,代码在我的 GitHub 仓库中 这里。
在非结构化数据中找到数据切片
简要介绍了数据切片方法,包括对 CIFAR-100 数据集的实际操作示例。
·
关注 发表在 Towards Data Science ·9 min read·2023 年 8 月 18 日
--
CIFAR100 中的数据切片。来源:作者创作。
tl;dr:
数据切片是数据中具有语义意义的子集,在这些子集中,模型的表现异常。当处理非结构化数据问题(例如图像、文本)时,找到这些切片是每个数据科学家的重要工作。在实践中,这项任务涉及大量的个人经验和手工工作。在本文中,我们介绍了一些使数据切片发现更加系统化和高效的方法和工具。我们讨论了当前的挑战,并展示了一些基于开源工具的实际操作示例工作流。
有一个基于 CIFAR100 数据集的互动演示可用。
介绍
调试、测试和监控人工智能(AI)系统很困难。大多数软件 2.0开发过程中的工作都花在了策划高质量数据集上。
开发强大机器学习(ML)算法的一个重要策略是识别所谓的数据切片。数据切片是语义上有意义的子集,其中模型表现异常。识别和跟踪这些数据片段是每个以数据为中心的 AI 开发过程的核心。它也是在医疗保健和自动驾驶辅助系统等领域部署安全 AI 解决方案的核心方面。
传统上,寻找数据切片一直是数据科学家工作的重要组成部分。在实践中,寻找数据切片很大程度上依赖于数据科学家的个人经验和领域知识。随着以数据为中心的人工智能(AI)运动的发展,许多当前的工作和工具旨在使这一过程更加系统化。
在本文中,我们概述了在非结构化数据上找到数据切片的当前状态。我们特别展示了一些基于开源工具的实际操作示例工作流程。
什么是切片查找?
数据科学家一直在使用简单的手动切片查找技术。最著名的例子可能是混淆矩阵,这是一种用于分类问题的调试方法。在实践中,切片查找过程依赖于预计算的启发式方法、数据科学家的个人经验以及大量的互动数据探索。
一个经典的数据切片可以通过对表格特征或元数据的谓词连接来描述。在一个人员数据集中,这可能是某个年龄范围内的男性,身高超过 1.85 米。在一个发动机状态监测数据集中,一个数据切片可能由某个转速、操作时间和扭矩范围的数据点组成。
在非结构化数据的情况下,语义数据切片定义可能更加隐含:它可以是人类可理解的描述,如“在山区的弯曲道路上,轻微降雨情况下的驾驶场景”。
在非结构化数据集上识别数据切片可以通过两种不同的方式进行:
-
可以使用经典信号处理算法(例如黑暗图像、低信噪比音频)或用于自动标记的预训练深度神经网络从非结构化数据中提取元数据。然后可以在这些元数据上进行切片查找。
-
嵌入空间中的潜在表示可以用于对数据集群进行分组。这些集群可以被检查以直接识别相关的数据切片。

识别非结构化数据上的数据切片的工作流程。来源:作者创建。
自动化切片发现技术总是试图在切片的支持度(应该很大)和模型性能异常的严重性(也应该很大)之间取得平衡。
表格数据上的切片发现方法与决策树有很多相似之处:在机器学习模型分析的背景下,这两种技术都可以用来制定描述模型错误存在位置的规则。然而,有一个重要的区别:切片发现问题允许重叠切片。这使得问题在计算上变得困难,因为更难以修剪搜索空间。
为什么数据切片发现很重要?
尤其是在过去十年中,机器学习社区从基准数据集中受益匪浅:从 ImageNet 开始,这些数据集和竞赛成为深度学习算法在非结构化数据问题上的成功因素。在这种背景下,新算法的质量通常基于极少的定量指标,如 F1 分数或平均精度。
随着越来越多的机器学习模型投入生产,现实世界的数据集与其基准数据集的差异变得显而易见:真实数据通常非常嘈杂和不平衡,但也富含元数据。对于某些用例,清理和标注这些数据集可能代价高昂。
许多团队发现,迭代训练数据集并监控生产中的漂移对于构建和维护安全的 AI 系统是必要的。
发现数据切片是这一迭代过程的核心部分。只有了解模型的失败点,才能提升系统性能:通过收集更多数据、纠正错误标签、选择最佳特征或简单地限制系统的操作领域。
为什么数据切片发现如此困难?
切片发现的一个关键方面是其计算复杂性。我们可以通过一个小例子来说明这一点:考虑 n 个二进制特征,采用独热编码(例如,通过分箱或重新编码获得)。那么所有可能特征组合的搜索空间是 O(2^n)。这种指数性质意味着通常使用启发式方法来进行修剪。因此,自动化切片发现不仅需要很长时间(取决于特征数量),而且输出不会是一个最优稳定的解决方案,而是一些启发式方法。
在 AI 开发过程中,模型性能差常常源于不同的根本原因。鉴于机器学习模型的固有随机性,这很容易导致需要手动检查和验证的虚假发现。因此,即使一个切片发现技术可以产生理论上最优的结果,其结果仍必须经过人工检查和验证。为跨职能团队提供高效工具是许多机器学习团队的瓶颈。
我们已经指出,通常希望找到具有大支持的切片,但也希望从数据集基线中获得显著的模型性能差距。不同数据切片之间的关系通常是层次性的。在自动切片查找过程中和互动审查阶段处理这些层次结构是相当具有挑战性的。
自动切片查找方法在元数据丰富的问题上最为有效。这通常是实际问题的情况。相比之下,基准数据集在元数据方面总是相当稀疏。这主要有两个原因:数据保护和匿名化要求。由于缺乏合适的示例数据集,开发和展示有效的切片查找工作流变得非常困难。
我们(不幸地)必须在接下来的示例部分中处理这一挑战。
实践操作:在 CIFAR-100 上查找数据切片
CIFAR-100 数据集是一个成熟的计算机视觉基准。我们在本教程中使用它,因为它的体积小,易于处理,并且计算需求低。结果也易于理解,因为它们不需要特殊的领域知识。
不幸的是,CIFAR-100 已经非常平衡,经过高度整理且缺乏有意义的元数据。因此,我们在本节中生成的切片查找工作流的结果不如在实际环境中有意义。然而,所呈现的工作流应足以理解如何快速在实际数据中使用它们。
在准备步骤中,我们使用 Cleanvision 库计算图像元数据。有关此增强的更多信息,请参阅我们的 数据驱动 AI 操作手册。
我们还定义了一些重要的变量用于数据切片分析:待分析的特征以及标签和预测列的名称:
大多数切片技术仅适用于分箱特征。由于 SliceLine 和 WisePizza 库本身不提供分箱功能,我们将其作为预处理步骤:
SliceLine
Sliceline 算法 是由 Sagadeeva 等人于 2021 年提出的。它旨在处理包含许多特征的大型表格数据集。它利用基于稀疏线性代数技术的新型剪枝技术,并允许在单台机器上快速查找数据切片。
在本教程中,我们使用 SliceLine 实现 来自 DataDome 团队。它运行非常稳定,但目前仅支持 Python 版本 <=3.9。
SliceLine 算法的大多数参数非常简单:切片的最小支持度 (min_sup)、定义切片的最大谓词数 (max_l) 和要返回的切片的最大数量 (k)。参数 alpha 为切片错误的重要性分配权重,并且基本控制切片大小与错误降低之间的权衡。
我们调用 SliceLine 库以获取 20 个最有趣的切片:
为了交互式探索切片,我们丰富了每个数据切片的描述:
我们启动 Spotlight 以交互式探索数据切片。你可以直接在 Huggingface 空间 体验结果。

图 3:通过 SliceLine 生成的数据切片进行交互式探索。一个 互动演示 在 Huggingface 上可以使用。来源:作者创建。
我们发现 Sliceline 算法确实在 CIFAR-100 数据集中找到了有意义的数据切片。枫树、柳树 和 橡树 类别似乎存在问题。我们还发现,这些类别中具有较高 暗色评分 的数据点尤其具有挑战性。经过仔细检查,我们发现这是因为背景明亮的树木对模型来说很困难。
Wise Pizza
WisePizza 是 Wise 团队最近开发的工具。它旨在发现和可视化表格数据中的有趣数据切片。核心思想是使用 Lasso 回归为每个切片找到重要性系数。有关 Wise Pizza 工作原理的更多信息,请参阅 博客文章。
重要的是要注意,WisePizza 并不是作为机器学习调试工具开发的。相反,它主要旨在支持 EDA 过程中的分段分析。这就是为什么可以手动定义段候选项并为其分配权重。在我们的实验中,我们直接在数据集上运行 WisePizza,并将每个数据点的权重设置为 1:
为了探索我们非结构化数据集中的问题,我们以与 Sliceline 示例相同的方式提取问题。

图 4:WisePizza 还识别出柳树类别具有较大暗色评分的问题。然而,这些切片的细粒度不如 SliceLine 结果。来源:作者创建。
从图 4 中我们可以看到,在简单的 CIFAR-100 基准数据集上,WisePizza 找到了相关的片段:它还列出了带有高 dark score 的 willow tree 类作为顶部切片。然而,以下结果仅限于不同的类别,并且不像 SliceLine 输出那样细粒度。一个原因是 WisePizza 算法没有直接提供切片支持和准确性下降之间的加权机制。
Sliceguard
Sliceguard 库使用 层次聚类 来确定可能的数据切片。然后,使用公平学习的方法对这些聚类进行排序,并通过可解释 AI 技术挖掘谓词。有关 Sliceguard 的更多信息可以在 这篇博客文章 中找到。
我们构建 Sliceguard 的主要原因是它不仅适用于表格数据,还能直接应用于嵌入。该库提供了大量内置功能用于预处理(如分箱)和后处理。
我们可以用几行代码在 CIFAR-100 上运行 Sliceguard:
Sliceguard 使用 Spotlight 提供识别的数据切片的互动可视化:
issue_df, issues = sg.report(spotlight_dtype={"image": Image})
Sliceguard 可以在 CIFAR-100 数据集上揭示细粒度的数据切片(图 5)。除了之前发现的树类数据切片外,我们还识别了其他问题(如 mouse 类)。

图 5:Sliceguard 揭示了细粒度的数据切片。在 Huggingface 上提供了一个 互动演示。来源:作者创建。
结论
我们介绍了三种开源工具,用于挖掘数据切片。即使在简单的 CIFAR-100 基准上,它们也可以快速揭示关键数据片段。识别这些数据切片是理解模型失败模式和改善训练数据集的重要步骤。
SliceLine 工具适用于表格数据,并通过谓词组合描述数据切片。Sliceguard 不返回数学上保证的最佳谓词组合,但可以直接处理嵌入。此外,它可以在几行代码的基础上运行于非结构化数据集。
在实践中,SliceLine 和 Sliceguard 对识别数据切片非常有帮助。然而,这两种工具都不能用于完全自动化的切片分析。相反,它们提供了强大的启发式方法,可以与互动探索相结合。如果操作得当,这种方法是跨学科数据团队构建可靠 ML 系统的重要工具。
你是否有使用过展示的数据切割工具的经验,或者能否推荐其他开源库?我很乐意在评论中听到你的意见。
寻找改进的改写
原文:
towardsdatascience.com/finding-improved-rephrasings-b5fb002ac811
使用带有机器学习元素的 Trie
·发布于Towards Data Science ·19 分钟阅读·2023 年 4 月 19 日
--

表达得好需要付出努力。幸运的是,体现现代自然语言处理的聊天机器人,如 ChatGPT,非常有帮助。
在这篇文章中,我们将使用适当增强的 Trie 来建模这个问题。这个 Trie 将自动检测在语料库中重复出现的短词序列,并且以无监督的方式进行。随后,它还将从标注的数据集(尴尬短语、改进短语对)中学习常见的“尴尬短语”模式。
对于小型但细微的改写问题版本,我们将展示通过在学习到的 Trie 上进行适当的模糊搜索,可以直接找到改进的表达方式。
对于更复杂的版本的问题,我们描述了一种使用 Trie 作为特征提取器的变体。这些特征可以集成到任何先进的神经语言模型中。
对于这个用途,Trie 需要进行特定的预处理和预训练。我们也将描述这两个步骤。
这种方法——本身——是…
-
易于理解,无需具备神经网络或自然语言处理的背景知识。
-
作为概念验证容易实现。
-
独立有效于检测短小的尴尬短语并建议改写。
也许更重要的是,如上所述,这种方法允许将基于 Trie 的特征添加到现代神经网络模型中,这些特征建模了在语料库中重复出现的短序列,因此可能携带独特的语义。
本文中使用的单词序列示例,包括尴尬的短语和改进后的短语,都是 ChatGPT 在响应单个提示时提供的:
给我一些尴尬的短语示例,长度为两到四个词。
然后我选择了一个展示了本文方法所涉及的所有场景的子集。我将它们分配到本文的各个部分。这种分配是因为不同的表述展示了本文方法中的不同机制。
本文的最后部分列出了所有尴尬的表述及其改进后的表述。包括在 ChatGPT 的回复中未使用的那些,因为它们是其回复的一部分。
初步观点:问题框定为搜索
考虑她做饭好,我们希望更好地表达。假设我们有一个丰富且庞大的优质句子语料库进行学习。想象一下搜索以她做饭开头的句子。通过分析这些句子,我们可能会得出她做饭更好地表述为她做饭或她做过饭。
现在考虑第三个词good。语料库中可能没有与短语(cooks|cooked) good匹配的句子。
另一方面,我们可能会找到几个包含短语做饭很好的句子。凭借well和good足够相似的额外知识,我们可能会倾向于建议她做饭很好作为更好的整体改写。我们可能从哪里获得这样的知识?词嵌入。
(我们在vectors.nlpl.eu/explore/embeddings/en/#输入了good,最强的“语义关联”是well,得分为 0.829。)
让我们将上述内容以流程的形式表达。
-
在我们的句子语料库中搜索那些以她做饭开头的句子。得到的最佳匹配是她做饭和她做过饭。
-
将匹配项替换到探测器中,即得到她做饭/做过饭好。
-
接下来,搜索我们包含做饭/做过饭好的句子语料库。得到做饭很好作为一个好的匹配。在将其评分为好的匹配时,还要利用good和well是语义关联的这一点,这可以通过比较它们的词嵌入来揭示。
-
将匹配项替换到转换输入的相关部分。我们得到她做饭很好。
步骤 1 和 3 中的搜索需要有一定的模糊性,以找到不完全匹配。这些匹配项随后需要进一步后处理,以便提取它们之间的统计共性。
此外,请注意以下几点。在第 1 步中,第三个词good没有作用。这需要在预处理、搜索本身或结果的后处理过程中以某种方式发现。
同样,在第 3 步中,必须以某种方式忽略词她。为什么?因为使用cooks/cooked good作为探测器而不是她做饭/做过饭好将更倾向于更好地概括。即使语料库中没有以她做饭很好开头的句子,最终的改写也可能会被建议。
上面两个段落讨论的要点可以用注意力机制来重新框定,这在大型语言模型中至关重要。
顺序特性
上述方法本质上是顺序的。我们从探针开始,检查其部分,适当地转换,然后重复。
在 Trie 上搜索
到目前为止,我们已经多次使用了“在句子语料库上搜索”这个术语。我们现在深入探讨一下如何进行这样的搜索。
我们将把句子语料库导入到一种叫做 Trie 的树数据结构中,然后在其上进行搜索。
那么,什么是 Trie?我们将通过一个示例来说明它,而不是给出定义。

图 1:表示两个单词序列的 Trie。(作者提供。)
上述示例中的 Trie 表示以下两个句子的单词序列:狗追猫 和 新工作很好。
现在想象一下短语新工作好。我们可以直观地看到,这个短语映射到从 Trie 的根开始的路径,该路径标记为新工作很好。
这种映射也被称为对齐,可以在下方表示。
New job good
| | |
The new job is good
在这种方法中,找到表述的更好重述的问题变成了在 Trie 中找到与表述对齐的路径的问题。
让我们从更简单的情况开始,当存在从 Trie 的根开始的路径完全表示输入时。在图 1 中的 Trie 中,新工作很好 是这样一个输入的例子。
很容易看出,我们可以通过基于输入中的令牌进行从左到右的扫描来找到表示输入的路径。
现在考虑一个不太干净的探针:新工作好。在图 1 的 Trie 中没有完全对应于这个探针的路径。
处理第一个单词The缺失的一种方式是向 Trie 中添加跳过弧。下面是我们第一个图的一个版本,添加了几个跳过弧。

图 2:添加了一些跳过弧的 Trie。(作者提供。)
跳过弧使我们能够找到路径,其中某些元素在探针中缺失,但代价是消耗更多的空间并且搜索时间更长。
现在假设我们还知道一个新工作好的良好重述。我们可以将原始表述和改进的重述对视为一个标记实例。此外,假设我们以某种方式将重述与探针对齐,如下所示。
New job good
| | |
The new job is good
现在我们有了一个标记的数据集,并且进入了监督学习的领域。
标记可能会产生显著的成本。那么,在我们的设置中,标签实际上能给我们带来什么?它们告诉我们在哪里添加跳过弧。
在这个特定的示例中,我们将在图 2 所示的位置添加两个跳过弧。
人们可能会问,为什么要使用跳过弧?为什么不直接进行前瞻搜索?
如果我们想处理多个连续缺失的标记怎么办?这是很可能的——我们很快会看到一个例子。我们需要一些大于 1 的正整数k的前瞻。k的值应该是多少?如果k仅比 1 稍大,我们不能处理许多连续缺失的标记,这种情况在现实中是很可能发生的。如果我们使用一个大的k,前瞻搜索可能会遇到组合爆炸,即不可扩展。这是因为我们必须搜索所有距离当前节点k范围内的可能后代。这些后代的数量可能随着k的增大而呈超指数增长。
跳跃弧的使用避免了这种组合爆炸。
对齐版本的标注实例精确地揭示了我们应该添加的跳跃弧。如下面的示例所示,有时我们想要学习跳过多个标记。
I go store.
| | |
I will go to the store
从这点出发,我们可以学会将跳跃弧从 Trie 节点[I, will, go]添加到节点[I, will, go, to, the]。这个跳跃弧将允许我们处理原始表述中相对于上述改进表述中缺失的to the。
好吧,因此I go store将映射到从根节点发出的 Trie 上的特定路径。从这条路径中,我们需要读取序列[I, will, go, to the, store]。问题是,我们如何得到粗体部分?
我们可以做的一件事是为表示被跳过的标记序列的跳跃弧分配一个标签。当我们首次创建跳跃弧时,我们知道这个标记序列,因此可以立即设置标签。与主弧不同,跳跃弧的标签仅在构建改进的表述时使用,而不是在查找期间使用。正如我们所知,查找是跳跃。
这是半监督学习
好吧,我们已经看到在标注数据集上的训练是有价值的。我们还意识到,建立一个标签覆盖大多数干净句子的标注数据集可能是不可行的,比如维基百科。
我们如何利用大量语料库而不必从中推导出足够丰富的标注实例?简单。
-
从语料库中学习 Trie 减去跳跃连接。这是无监督学习。
-
从可能更小的标注数据集中学习跳跃连接。这是监督学习。
后者的基本假设是学习跳跃连接的问题具有普适性。我们确实在之前讨论的一个例子中看到了这种普适性。即使从一个标注实例中,我们也可以学会处理一个以The dog开头的句子在被损坏版本中The的缺失。
跳跃弧越接近根节点的泛化能力越强
看图 2。最接近 Trie 根节点的跳跃弧从根节点开始,跳过了The。这使得它能够考虑到任何以The开头的干净句子在任何被损坏版本中The的缺失。
现在考虑跳过句子中新工作是的是的跳过弧。这适用于更有限的句子集。
替代弧及其学习
为了激励这一点,我们将发现我们在帖子中覆盖的第一个重述示例是有用的。这涉及将她做饭好重述为她做饭好。
好的,所以我们从根节点开始,越过标记为她的弧。我们现在在节点[她]上,但它没有子节点[她,做饭]。最接近的子节点是[她,做饭]。我们可以通过模糊匹配做饭和做饭来找到这个子节点。我们应该记住,如果当前节点有许多子节点,找到最佳匹配的子节点可能需要时间。
对于这种模糊匹配的一种替代方法是从标记数据集中学习替代弧。
想象以下标记实例出现在监督训练集中。
She cook good
| | |
She cooks well
我们可以想象从这个标记实例中学习,以添加我们所称的替代弧。这些在我们的示例中如下所示。

图 3:添加了学习到的替代弧的 Trie。(作者提供。)
虽然图 3 没有显示这一点,但替代弧需要与主弧区分开,因为它们的使用方式略有不同。
在处理输入时,替代弧可能会被遍历。例如,在输入她做饭好时,我们会遍历标记为[她,做饭,好]的路径,其中第二个和第三个弧是替代弧。
从结果路径中得出的输出涉及将路径中的替代弧上的标签替换为对应的主弧。在我们的例子中,[她,做饭,好] 会被替换为 [她,做饭,好]。
替代弧与令牌的模糊匹配
正如我们在前一节中讨论的那样,只有在我们处于正确的左侧上下文时,令牌才会被模糊匹配。例如,考虑她做饭好,假设我们在节点她做饭。从这个节点发出的弧有一个标签好。我们仅在这个左侧上下文她做饭中将好与好匹配。这减少了假阳性的风险。
这意味着,使用模糊令牌匹配而不是替代弧不会增加风险。使用替代弧主要有助于加快匹配速度,当节点有许多子节点时尤其如此。模糊匹配替代方案要求我们找到来自父节点的所有弧上的标签,并将每个标签一个接一个地匹配到输入中的下一个令牌。
模糊匹配和替代弧都不一定假设匹配的令牌在词汇上相似,仅仅是它们在标记实例中对齐。
作为示例,考虑对齐情况
He is a lawyer
| | | |
He is an attorney
律师和法律顾问即使没有词汇上的相似性也是对齐的。
自弧及其学习
考虑表达方式You Like What及其改进的重述What Do You Like。重述涉及重新排列单词。我们希望继续利用字典树,而不是设计新的机制来处理单词顺序的重新排列。
想象一下,我们有一个标记和对齐的实例来表示这一对。
You like what
| |
What do you like
我们可以通过跳跃弧来建模改进的重述中的What do。要建模原始重述中的what,我们将使用自我弧的概念。
下面是建模这种情况的字典树的一部分。

图 4:具有自我弧(和跳跃弧)的字典树。(作者提供。)
让我们走查一下如何处理输入[you, like, what]。我们将从根节点开始,走跳跃弧,然后走标记为you和like的弧,最后走标记为what的自我弧。我们将连接路径上前三个弧的标签。我们将省略第四个弧的标签,因为我们知道它是自我弧。
减少构建标记和对齐实例的工作量
正如我们所见,标记的实例,即原始表达方式与改进的重述对齐的实例,在学习准确的重述建议模型中发挥着重要作用。我们标记实例的数据集越丰富,学习到的模型质量就越好。
鉴于此,我们可以减少构建标记实例的工作量吗?答案是可以。
以下我们提出一种方法,这种方法涉及将一个数据集中的序列(其中可能有许多尴尬的表达方式)对齐到一个在主要干净句子语料库上训练的字典树。接着,我们将描述如何利用某些经过充分论证的语言特征来增强这种方法。
用干净的字典树对齐不良序列
首先,假设我们有两个数据集,一个是相对干净的句子,另一个是质量较差的表达方式。想象一下,从第一个数据集中,我们按照前面描述的方式学习字典树。(这个字典树只有主要弧。)
现在考虑一个来自第二个数据集的实例——一个相对较差的实例。想象一下在当前字典树中找到一条从根节点开始并且与实例足够匹配的路径。通过这条路径结合实例,我们可以推导出我们所寻求的标记实例。
模糊匹配器需要如何工作?它需要使用前瞻操作来允许路径上的标记在探测器中缺失。它还需要使用模糊标记匹配来允许替换。
让我们说明这个过程。考虑下面的字典树。

图 5:用于说明模糊匹配的字典树。(作者提供。)
现在考虑探测器Me go store。假设我们知道me和I是语义关联词。我们将利用这一知识将me与I对齐,然后进行一次前瞻,对齐go与go,再进行一次前瞻,最后对齐store与store。我们得到了所需的对齐。
如果我们认为这发生在离线训练阶段,那么模糊匹配甚至不需要进一步优化速度。前瞻可能需要时间,但这没关系。
有人可能会问,为什么不在推断时直接使用这种模糊匹配过程,而完全放弃监督训练?有两个原因。
首先,模糊匹配器可能运行较慢,因为它可能需要进行前瞻。一个慢的模糊匹配器在训练过程中比在推断过程中更可接受。
其次,允许从标记实例中学习允许将人工策划的标记实例与使用模糊匹配器自动生成的实例混合在一起。实际上,它甚至允许人工策划者丢弃那些被认为质量较低的自动生成实例。这可能仍然减少了获得一个同等丰富的数据集的整体努力,因为它自动化了发现过程。只要通过这种自动化过程找到的足够多的标记实例足够好,那么其好处可能会超过检测和标记假阳性的成本。
利用语言特征
这是基于这样的观察:在某些词性中,词语在不良表述中缺失的可能性比其他词性更高。例如,冠词或介词。以下是一些示例。
曝露阳光
他是律师
我们将这融入到我们将不良序列与干净 Trie 对齐的方法中,如下所示。
在第一步中,我们将不良序列与干净的 Trie 对齐。在这个过程中,不良序列x与 Trie 上的路径y对齐。对于y中的词序列,我们现在得到它们的词性。对齐现在不仅揭示了相对于y来说可能在x中缺失的词,还揭示了它们的词性。我们现在可以辨别出某些词性是否比其他词性更可能缺失。
在第 2 步中,我们穷举地列出 Trie 上的所有根到叶子路径,这可以通过深度优先或广度优先搜索等方式完成,并添加跳过弧以跳过在我们最易跳过的词性列表中的词。
让我们用一个简单的例子来说明。
比如我们将He is lawyer与He is a lawyer对齐。我们发现前者缺少一个冠词,在这种情况下是a。重复这个过程于许多不良序列应该能够辨别出冠词比其他一些词性更倾向于缺失。
更细粒度的变体
到目前为止,我们考虑的词序列用于输入 Trie 的都是语料库中的句子。这种选择是为了方便,因为将文本分割成句子相对容易。
在这一部分,我们考虑一个更细粒度的变体,其中我们输入到 Trie 中的词序列不一定是完整的句子,而是以某种主导顺序出现的词序列。
这种更细粒度的 Trie 有潜力建议句子中嵌入的较短单词序列的改进表述。
这是一个例子。没有人将skin cancer写作cancer skin。如果我们能够发现前者是这两个词的主导顺序,我们就可以检测到这种顺序的违规,并建议重新表述。
更为可信的情况是,当句子中的单词of the被错误地调换位置时,例如The top the of mountain。
此外,也许更重要的是,这种更细粒度的 Trie 可以作为现代神经语言模型中的特征提取器,用于处理更复杂的拗口表述问题。实际上,对于大型语言模型适用的任何推断问题,包括语言生成,都是如此。
无论 Trie 是直接使用还是作为特征提取器,都需要解决一个新问题。即发现以某种主导顺序出现的单词序列。这个问题在我们的第一个 Trie 中不存在,因为句子被视为单词序列。
首先,以下术语将有所帮助。
可排序的袋子
如果一个单词多重集具有主导排序,我们将称之为可排序的袋子。
我们使用multiset这个术语而不是set,以允许同一个词出现多次。
例如,多重集{the, of}有一个主导排序[of, the]。
请注意,可排序的袋子与显著短语不同。前者只关心主导排序;后者还需要考虑显著性,即该短语传达的意义大于其部分之和。
从句子语料库中发现可排序的袋子
好了,现在让我们讨论如何从句子语料库中发现可排序的袋子。我们假设语料库中的所有句子都已经被分词为单词。
从这个语料库中,设想从所有长度至少为二的令牌序列中派生出一个数据集D1。接下来,我们将根据以下方式从D1 构建一个新的数据集D2。实际上,我们将把D2 解释为一个带标签的数据集。
对于每一个D1 中的序列y,在D2 中将有一个序列x,其中x是通过对y的单词进行词典顺序排序得到的。x在D2 中的标签将是y。也就是说,我们用观察到的特定排序y来标记y的单词多重集。
下面是一个在D2 中的实例示例。
x = [cancer, skin], y = [skin, cancer]
我们将从D2 中的所有x构建一个 Trie。对于D2 中的任何一对(x, y),我们将把y作为卫星数据附加到 Trie 中x结束的节点上。
一旦 Trie 构建完成,我们将把 Trie 中各节点上的卫星数据压缩为两个属性。
-
作为标签的不同排序数量即在该节点结束的x。
-
对于以 x 结尾的节点,这是一种关于所有排序的概率分布。为了紧凑地表示这种分布,我们将以如下所述的方式对排序进行编码。
我们将第一个属性——不同排序的数量——称为 支持度。
当然,这个 Trie 可能非常庞大。但这对我们来说不是问题。
利用 Trie 发现顺序不当的子序列
考虑一个词序列,其中某些子序列的排序不当。例如 The top the of mountain。
我们将通过枚举所有长度至少为 2 的子序列,按词典顺序对每个这样的序列进行排序,查找 Trie 中的该序列,并检查末尾节点是否具有揭示主要排序的卫星数据,从而发现不当措辞(如果有的话)。
示例
例如,一旦 Trie 构建完成,假设路径 [cancer, skin] 上的最右节点具有卫星数据 (520, [skin, cancer] → 1)。这仅表示包 {cancer, skin} 在数据集中被观察到 520 次,每次的排序为 [skin, cancer]。
在神经语言模型中使用可排序包 Trie 作为特征提取器
一旦可排序的包 Trie 构建完成,我们可以将其用作特征提取器,如下所示。
首先,让我们为 Trie 中的每条路径分配一个唯一的标识符。这个标识符将作为与该路径关联的特征值。
我们将一个词序列作为输入,并将其分割为 Trie 中的一系列最大路径,每个路径都是具有一定最小支持度的主要排序。为了涵盖所有情况,我们将由单个词组成的词序列定义为具有上述最小支持度的主要排序。
我们现在用它们的标识符替换这个序列中的路径。这样我们就得到了一个特征序列。对于这个序列中表示单词的路径,我们可以附加额外的特征,例如词嵌入。
让我们在下面的示例中说明这个过程。
考虑
曝露在阳光下会导致皮肤癌
想象一下,使用可排序包 Trie,我们将其分割为
exposure → to → sunlight, causes, skin → cancer
其在我们用 ids 替换路径后变为
pid1, pid(causes), pid2
从为我们的基于 Trie 的方法添加价值的角度来看,这一好处是显而易见的。现代神经语言模型具有令人印象深刻的能力。
在我们在此帖子中讨论的特定用例背景下,这将允许它们检测在较长文本部分(如段落或甚至多页)中需要改进的措辞。可能需要考虑长距离交互的措辞。
从为现代神经语言模型添加价值的角度来看,Trie 基础的特征将进一步丰富这些模型。基本的直觉是,某些短的词序列在文本中重复出现频繁,并且隐含地编码了特定的语义。
基于 Trie 的方法可以从语料库中自动发现这些序列,因此可以用于分析由这些序列的某些排列组成的更长序列。
让我们用一个简单的例子来说明。考虑
exposure to sunlight 导致 skin cancer
在这里,我们认为加粗的子序列是主导排序,并在我们的 Trie 中进行了表示。
我们可以想象,利用 Trie 的神经语言模型可以很容易地预测 exposure to sunlight causes 后面应该跟着 skin cancer。
现在假设我们有一个标记的数据集,其中包含语义上等价的表述。作为这个数据集中的一个实例,考虑
{
exposure to sunlight causes skin cancer,
skin cancer is caused by exposure to sunlight
}
从许多形式为 {X causes Y, Y is caused by X} 的实例中,并假设 X 和 Y 在 Trie 中被表示为主导排序路径,我们可以识别出这两者的语义等价性,并在某些推理或生成中使用这种学习。例如,如果我们被要求以不同的方式重新表达 X causes Y,我们可以回答 Y is caused by X。
总结
在这篇文章中,我们讨论了发现尴尬表述——即词序列的排列——并提出改进方案的问题。
我们使用 Trie 对这个问题进行了建模。这个 Trie 可以以无监督的方式自动检测出语料库中重复出现的短词序列。Trie 中的一个监督机制还从标记数据集(包括(awkward phrasing, improved phrasing)对)中学习某些常见的“尴尬表述”模式。
对于小规模但细微的表述重述问题,我们展示了可以通过在学习到的 Trie 上进行适当的模糊搜索直接找到改进的表述。
对于问题的更详细版本,我们描述了一个变体,使用 Trie 提取高级重复特征,特别是短的重复词序列。我们推测为什么在最先进的神经语言模型中使用这些特征可以提高其准确性,并简化其训练。为此用途,需要对 Trie 进行某些预处理和预训练。我们也描述了这两者。
ChatGPT 中的短语
以下是 ChatGPT 在回复提示时提供的尴尬表述及其改进版本
给我一些长度为两到四个单词的尴尬表述的例子
在回复中,ChatGPT 将改进后的表述与尴尬的表述分开了。我为了读者的方便,将这些进行了对齐。
-
“Me go store.” 我要去商店
-
“You like what?” 你喜欢什么?
-
“Dog chase cat.” 狗在追猫。
-
“He no here.” 他不在这里。
-
“She cook good.” 她是个好厨师。
-
“Big house him.” 他有一栋大房子
-
“Funny joke that.” 那是个有趣的笑话
-
“Rain make wet.” 雨使一切变得湿润。
-
“Car go fast.” 这辆车正在快速行驶。
-
“New job good.” 新工作很好。
进一步阅读
在干草堆中找针 — Jaccard 相似度的搜索索引
从基础概念到精确和近似索引
·
关注 发表在 Towards Data Science · 15 分钟阅读 · 2023 年 8 月 18 日
--
在干草堆中找针。图像由作者使用 Midjourney 制作。
向量数据库因作为大型语言模型(LLMs)的外部记忆而成为新闻话题。如今的向量数据库是建立在十年前的近似最近邻(ANN)索引研究基础上的新系统。这些索引算法处理许多高维向量(例如float32[]),并构建一个数据结构,支持在高维空间中找到查询向量的近似邻居。这就像 Google 地图根据你家的纬度和经度找到你邻居的房子一样,只不过 ANN 索引在更高维空间中操作。
这项研究有着几十年的历史。90 年代末,机器学习研究人员手工制作多媒体数据(如图像和音频)的数值特征。基于这些特征向量的相似性搜索成为一个自然的问题。一段时间内,研究人员涌入这一领域。这一学术泡沫在一篇开创性论文When is “Nearest Neighbor” Meaningful? 发表后破裂,基本告诉大家不要浪费时间,因为在手工制作特征的高维空间中的最近邻大多是不具有意义的——这是另一个话题。即便如此,我仍然看到研究论文和向量数据库基准测试发布关于 SIFT-128 数据集的性能数据,该数据集由具有无意义相似性的手工制作特征向量组成。
尽管手工制作特征存在噪声,但研究中有一条富有成果的方向专注于一种具有意义的高维数据类型:集合和Jaccard。
在这篇文章中,我将介绍集合上 Jaccard 相似性的搜索索引。我将从基本概念开始,然后转向精确和近似索引。
集合和 Jaccard
集合只是不同元素的集合。你在 Spotify 上喜欢的歌曲是一个集合;你上周转发的推文是一个集合;从这篇博客文章中提取的不同令牌也形成一个集合。集合是表示数据点的一种自然方式,适用于音乐推荐、社交网络和剽窃检测等应用场景。
假设在 Spotify 上,我关注了这些艺术家:
[the weekend, taylor swift, wasia project]
我的女儿关注了这些艺术家:
[the weekend, miley cyrus, sza]
衡量我们音乐品味相似性的一个合理方法是看我们共同关注了多少艺术家——即交集大小。在这种情况下,我们都关注了the weekend,所以交集大小为 1。

每个集合代表一个用户的关注列表。交集显示了两个用户共享的共同关注。图片由作者提供。
然而,你可以想象另一对用户每人关注 100 位艺术家,它们的交集大小仍然是 1,但他们的品味相似度应该远低于我女儿和我之间的相似度。为了使不同用户对之间的测量具有可比性,我们用并集大小来归一化交集大小。这样,我女儿和我关注的相似度为1 / 5 = 0.2,而另一对用户关注的相似度为1 / 199 ~= 0.005。这被称为 Jaccard 相似度。
对于集合A和集合B,Jaccard 相似度的公式是:

集合 A 和 B 的 Jaccard 相似度公式。
为什么集合是高维数据类型?集合可以编码为“one-hot”向量,其维度 1 对 1 映射到所有可能的元素(例如,Spotify 上的所有艺术家)。如果集合包含对应于该维度的元素,则该维度的值为 1,否则为 0。因此,我关注的艺术家的向量化集合如下所示:

高维向量表示关注集合。图像由作者提供。
其中第二、第三和倒数第三个维度分别是the weekend、taylor swift和wasia project。Spotify 上有超过 1000 万的艺术家,因此这样的向量维度极高且非常稀疏——大多数维度为 0。
Jaccard 搜索的反向索引
人们希望快速找到事物,因此计算机科学家发明了被称为索引的数据结构,以使搜索性能对软件应用程序满意。具体来说,Jaccard 搜索索引是建立在一组集合上的,给定一个查询集合,它返回与查询集合具有最高 Jaccard 相似度的k个集合。
Jaccard 的搜索索引基于一种称为反向索引的数据结构。反向索引具有极其简单的接口:输入一个集合元素,比如the weekend,它返回包含输入元素的集合 ID 列表,例如[ 32, 231, 432, 1322, ...]。反向索引本质上是一个查找表,其键是所有可能的集合元素,值是集合 ID 列表。在这个例子中,反向索引中的每个列表表示一个艺术家的关注者 ID。

反向索引包含与查询集合匹配的集合 ID 列表。图像由作者提供。

原始集合存储在一个单独的表中,以便通过其集合 ID 进行查找。图像由作者提供。
你可以看到这为什么被称为“反向索引”:它允许你从一个集合元素找到包含该元素的集合。
精确搜索算法
倒排索引是加速搜索的极其强大的数据结构。使用倒排索引时,在搜索时,你不需要遍历所有集合并与查询集合进行比较——如果你有数百万个集合,这会非常昂贵,你只需要处理与查询集合共享至少一个元素的集合 ID。你可以直接从倒排索引列表中获取集合 ID。
这个想法是通过以下搜索算法实现的:
def search_top_k_merge_list(index, sets, q, k):
"""Search top-k Jaccard using inverted index.
Args:
index: an inverted index, key is set element
sets: a lookup table for sets, key is set ID
q: a query set
k: search parameter k
Returns:
list: at most k set IDs.
"""
# Intialize an empty lookup table for candidates.
candidates = defaultdict(0)
# Iterate over set elements in q.
for x in q:
ids = index[x] # Get a list of set IDs from the index.
for id in ids:
candidates[id] += 1 # Increment count for intersection size.
# Now candidates[id] stores the intersection size of set with ID id.
# A simple routine for calculating Jaccard using intersection size and
# set sizes, based on Inclusion-Exclusion principle.
jaccard = lambda id: candidates[id] / (len(q) + len(sets(id) - candidates[id]))
# Find the top-k candidates order by Jaccard.
return sorted(list(candidates.keys()), key=jaccard, reverse=True)[:k]
用通俗的语言来说,该算法遍历查询集合中元素匹配的每个倒排索引列表,并使用候选表跟踪每个集合 ID 出现的次数。如果一个集合 ID 出现了n次,则索引集合与查询集合有n个重叠元素。最后,算法使用候选表中的所有信息来计算 Jaccard 相似度,然后返回最相似的前 k 个集合的 ID。

search_top_k_merge_list算法中的候选表用于跟踪通过倒排索引找到的索引集合的重叠计数。
当以下情况发生时,search_top_k_merge_list算法可以很快:(1) 查询集合中的元素数量较小,以及 (2) 查询元素的倒排索引列表中的 ID 数量较少。在 Spotify 场景中,如果大多数人关注的艺术家不多(可能是这样),并且所有艺术家拥有的粉丝数量大致相同(不准确)。我们都知道,少数顶级艺术家受到大多数人的关注,而大多数艺术家的粉丝很少。毕竟,音乐产业遵循帕累托分布。
泰勒·斯威夫特在 Spotify 上有 7800 万粉丝,而周末(The Weekend)有 6700 万粉丝。将他们添加到我的关注列表中意味着search_top_k_merge_list算法需要遍历至少 1.45 亿个集合 ID,并且候选表candidates将增长到这个天文数字的规模。尽管今天的计算机速度快且强大,在我的 Intel i7 机器上,创建这样一个表仍然需要至少 30 秒(Python),并动态分配 2.5 GB 的内存。
大多数人关注一些超级明星艺术家。因此,如果你在搜索应用中使用这个算法,你肯定会因为大规模资源使用而获得一笔巨额的云托管账单,并且由于搜索延迟高,用户体验将会很差。
分支限界优化
从直观上看,之前的算法search_top_k_merge_list以广度优先的方式处理所有潜在候选,因为它仅使用倒排索引来计算交集。由于超级明星艺术家拥有数百万粉丝,这个算法表现不佳。
另一种方法是对潜在候选更加挑剔。假设你在面试候选人,你是招聘经理。你无法面试所有给你发送简历的潜在候选人,因此你根据招聘标准将候选人分到不同的类别中,并开始面试那些符合你最关心标准的候选人。随着你逐个面试,你评估每个人是否符合所有或大部分标准,当找到符合要求的人时停止面试。
这种方法在寻找类似的关注艺术家集合时也适用。这个想法是你希望从查询集中关注 最少粉丝数 的艺术家开始。为什么?因为这些艺术家提供的候选集较少,这样你可以处理更少的倒排索引列表,更快找到你最佳的 k 个候选。在我的 Spotify 关注列表中,wasian project 只有 100 万粉丝——远少于 taylor swift。那些关注 wasian project 的粉丝数量远少于关注 taylor swift 的粉丝,但他们有同样的潜力成为最佳 k 个候选。
这里的关键洞察是我们不希望处理所有潜在候选列表,而是在处理到足够数量时停止。棘手的部分是知道何时停止。以下是一个修改版的算法,实现了这个思想。
import heapq
def search_top_k_probe_set(index, sets, q, k):
# Initialize a priority heap to store the current top-k candidates.
heap = []
# Initialize a set for tracking probed candidates.
seen_ids = set()
# Iterate over elements in q from the least to the most frequent based
# on the lengths of their lists in the inverted index.
for i, x in enumerate(sorted(q, key=lambda x: len(index[x]))):
ids = index[x] # Get a list of set IDs from the index.
for id in ids:
if id in seen_ids:
continue # Skip seen candidate.
s = sets[id]
intersect_size = len(q.intersection(s))
jaccard = intersect_size / (len(q) + len(s) - intersect_size)
# Add the candidate to the priority heap.
if len(heap) < k:
heapq.heappush(heap, (jaccard, id))
else:
# Only candidates with higher Jaccard than the k-th
# current candidate will be added in this operation.
heapq.heappushpop(heap, (jaccard, id))
seen_ids.add(id)
# If any new candidate from the remaining lists cannot have higher
# Jaccard than any of the current best k candidates, we do not need
# to do any more work.
if (len(q) - i - 1) / len(q) (<= min(heap)[0]:
break
# Return the best k candidates.
return [id for _, id in heapq.nlargest(k, heap)]
search_top_k_probe_set 算法为它找到的每个新候选计算 Jaccard 相似度。它始终跟踪当前最佳的 k 个候选,并在任何新候选的上界 Jaccard 相似度不大于当前最佳 k 个候选的最小 Jaccard 相似度时停止。

search_top_k_probe_set 算法遍历倒排索引列表,为每个遇到的候选集计算 Jaccard 相似度,并跟踪当前的 top-k 候选集。当未处理列表中任何集合的最大 Jaccard 相似度不大于当前 top-k 候选集的最小相似度时,它就会停止。图片由作者提供。
如何计算 Jaccard 相似度的上界?在处理了 n 个候选列表之后,对于任何未见候选,其与查询集的最大交集最多等于剩余未处理列表的数量:|Q|-n。我们给予它最大的怀疑,以便认为该候选可能出现在每一个剩余的 |Q|-n 列表中。现在我们可以使用简单的数学推导该候选 X 的上界 Jaccard 相似度。

计算一个未见候选索引集 X 和查询集 Q 之间 Jaccard 相似度的上界的公式,经过处理了 n 个候选列表之后。
这个巧妙的技术在集合相似性搜索研究文献中被称为前缀过滤器。我写了一篇论文,详细讲述了这一点以及进一步的算法优化。我还创建了一个 Python 库SetSimilaritySearch,实现了一个更优化的 search_top_k_probe_set 算法版本,并支持余弦和包含相似性度量。
Jaccard 搜索的近似索引
在最后一部分,我解释了两种基于倒排索引的搜索算法。这些算法是精确的,意味着它们返回的 k 个最佳候选项确实是最好的 k 个候选项。听起来老生常谈?其实,这是我们在设计大规模数据上的搜索算法时应该问自己的问题,因为在许多情况下,获取真实的 k 个最佳候选项并非必要。
再考虑一下 Spotify 的例子:你真的在乎搜索结果可能会遗漏一些与您品味相似的人吗?大多数人都明白,在日常应用程序(如 Google、Spotify、Twitter 等)中,搜索从来不会是详尽无遗或完全准确的。这些应用程序的任务并不足以证明精确搜索的必要性。这就是为什么最广泛使用的搜索算法都是近似的。
使用近似搜索算法主要有两个好处:
-
更快。如果你不再需要精确的结果,你可以省去许多步骤。
-
可预测的资源消耗。这一点不那么明显,但对于几种近似算法,它们的资源使用(例如内存)可以在事先配置,与数据分布无关。
在这篇文章中,我将讨论 Jaccard 最常用的近似索引:最小化局部敏感哈希(MinHash LSH)。
什么是 LSH?
局部敏感哈希索引在计算机科学中确实是奇迹。它们是由数论驱动的算法魔法。在机器学习文献中,它们被称为 k-NN 模型,但与典型的机器学习模型不同,LSH 索引对数据无关,因此它们在相似性条件下的准确性可以在事先确定,而不需要在摄取新数据点或更改数据分布之前进行调整。因此,它们更类似于倒排索引而不是模型。
LSH 索引本质上是一组具有不同哈希函数的哈希表。就像典型的哈希表一样,LSH 索引的哈希函数将一个数据点(例如,一个集合、特征向量或嵌入向量)作为输入,并输出一个二进制哈希键。除此之外,它们没有更多的相似之处。
典型的哈希函数输出的键在整个键空间中是伪随机且均匀分布的。比如,MurmurHash 是一个著名的哈希函数,它在 32 位键空间中输出近乎均匀且随机的哈希值。这意味着对于任何两个输入,例如 abcdefg 和 abcefg,只要它们不同,它们的 MurmurHash 键就不应有相关性,并且在整个 32 位键空间中出现的概率应相同。这是哈希函数的一个期望特性,因为你希望键在哈希桶中均匀分布,以避免链表或不断调整哈希表的大小。
LSH 的哈希函数执行的是相反的操作:对于一对相似的输入,定义相似度的度量空间中,它们的哈希键更有可能相等,而不是另一对不相似输入的哈希键。
这意味着什么?这意味着 LSH 哈希函数对更相似的数据点具有更高的哈希键冲突概率。实际上,我们利用这一较高的冲突概率来进行基于相似度的检索。
MinHash LSH
对于每一个相似度/距离度量,有一个 LSH 哈希函数。对于 Jaccard,这个函数被称为 Minwise Hash Function 或 MinHash function。给定一个输入集合,MinHash 函数使用随机哈希函数处理所有元素,并跟踪观察到的最小哈希值。你可以使用单个 MinHash 函数构建一个 LSH 索引。请参见下图。

使用单个随机哈希函数的 MinHash LSH 索引。图片来源于作者。
MinHash 函数背后的数学理论指出,两个集合具有相同最小哈希值(即哈希键冲突)的概率与它们的 Jaccard 相同。

h(A) 是通过随机哈希函数 h 对 A 中所有元素的哈希值。
min(h(A)) 是 A 中所有元素的最小哈希值。
这是一个神奇的结果,但证明却相当简单。
使用单个 MinHash 函数的 MinHash LSH 索引不能提供令人满意的准确性,因为碰撞概率与 Jaccard 成线性关系。请参见下面的图表以理解原因。

单个 MinHash 函数在查询集合和索引集合之间的 Jaccard 上的碰撞概率。Y 轴是碰撞概率,X 轴是查询集合和索引集合之间的 Jaccard。例如,Jaccard = 0.8 的索引集合与查询集合有 80% 的概率通过索引被检索到;而 Jaccard 为 0.2 的另一个索引集合与查询集合有 20% 的概率被检索到。图片来源于作者。
设想我们在 Jaccard = 0.9 处画一条阈值线:与查询集合 Jaccard 值高于 0.9 的结果是相关的,而 Jaccard 值低于 0.9 的结果是无关的。在搜索的背景下,“假阳性”意味着返回了无关的结果,而“假阴性”意味着没有返回相关的结果。根据上面的图,并查看假阳性对应的区域:如果索引只使用一个 MinHash 函数,它将产生非常高概率的假阳性。
提升 MinHash LSH 的准确性
这就是为什么我们需要另一个 LSH 魔法:一个叫做增强的过程。我们可以将索引提升到更符合指定的相关性阈值。
我们使用m个 MinHash 函数,这些函数通过称为Universal Hashing 的过程生成——基本上是 32 位或 64 位整数相同哈希函数的m个随机排列。对于每个被索引的集合,我们使用通用哈希生成m个最小哈希值。
设想你列出了一个索引集合的m个最小哈希值。我们将每r个哈希值分组到一个哈希值带中,我们创建b个这样的带。这需要m = b * r。

在 MinHash LSH 中,具有 m= 16,b = 4 和 r= 4 的索引集合的最小哈希值。图片由作者提供。
两个集合具有“带碰撞”的概率——两个集合中的所有哈希值在一个带中碰撞,或r个连续的哈希碰撞,是Jaccard(A, B)^r。这比单个哈希值小得多。然而,两个集合之间至少有一个“带碰撞”的概率是1 — (1-Jaccard(A, B)^r)^b。
我们为什么关心1 — (1-Jaccard(A, B)^r)^b?因为这个函数有一个特殊的形状:

使用 b = 32 和 r = 32 的 MinHash LSH 索引的 Jaccard 检索增强概率函数。图片由作者提供。
在上面的图中,你可以看到使用m个 MinHash 函数时,“至少一个带碰撞”概率是一个 S 型曲线函数,在 Jaccard = 0.9 附近急剧上升。假设相关性阈值为 0.9,则该索引的假阳性概率远小于仅使用一个随机哈希函数的索引。
因此,LSH 索引总是使用b个r个 MinHash 函数的带来提升准确性。每个带是一个存储指向索引集合的指针的哈希表。在搜索过程中,任何与查询集合在任何带中碰撞的索引集合都会被返回。

一个使用 b = 4 和 r = 4 的 MinHash LSH 索引。每个带是一个哈希表,其哈希键是 4 个 MinHash 函数的最小哈希值的连接。图片由作者提供。
要构建 MinHash LSH 索引,我们可以指定一个先验的相关性阈值以及基于 Jaccard 相似度的可接受的假阳性和假阴性概率,并在索引任何数据点之前,计算最优的m、b和r。这是使用 LSH 相对于其他近似索引的一个巨大优势。
你可以在 Python 包datasketch中找到我的 MinHash LSH 实现。它还包括其他与 MinHash 相关的算法,如 LSH 森林和加权 MinHash。
最终思考
我在这篇文章中涵盖了很多主题,但我只是浅尝辄止地探讨了 Jaccard 相似度搜索索引。如果你有兴趣阅读更多这些主题,我为你准备了进一步阅读的列表:
-
《大规模数据集挖掘》(Mining of Massive Datasets)由 Jure Leskovec、Anand Rajaraman 和 Jeff Ullman 撰写。第三章详细介绍了 MinHash 和 LSH。我认为这是一个很好的章节,用于获得对 MinHash 的直观理解。请注意,章节中描述的应用侧重于基于 n-gram 的文本匹配。
-
JOSIE:用于在数据湖中查找可连接表的重叠集合相似度搜索。本文的初步部分解释了
search_top_k_merge_list和search_top_k_probe_set算法的直观。主要部分解释了在输入集合很大时(例如表列),如何考虑成本。 -
Datasketch 和 SetSimilaritySearch 库分别实现了最先进的近似和精确 Jaccard 相似度搜索索引。datasketch 项目的问题列表是应用场景和实际考虑的宝贵资源,尤其是在应用 MinHash LSH 时。
那么嵌入呢?
近年来,由于使用像 Transformers 这样的深度神经网络在表示学习上的突破,当输入数据属于嵌入模型训练的相同领域时,学习到的嵌入向量之间的相似度是有意义的。与本文描述的搜索场景相比,这种情况的主要区别在于:
-
嵌入向量是通常具有 60 到 700 维的稠密向量。每一维都是非零的。相比之下,集合在表示为独热编码向量时是稀疏的:10k 到数百万维,但大多数维度为零。
-
余弦相似度(或标准化向量上的点积)通常用于嵌入向量。对于集合,我们使用 Jaccard 相似度。
-
很难为嵌入向量之间的相似性指定一个相关性阈值,因为这些向量是对原始数据(如图像或文本)的黑箱表示。另一方面,集合的 Jaccard 相似度阈值要容易得多,因为集合是原始数据。
由于上述差异,比较嵌入和集合并不是直接的,因为它们是明显不同的数据类型,即使你可以将它们都归类为高维数据。它们适用于不同的应用场景。
使用分支定界找到最佳解决方案
原文:
towardsdatascience.com/finding-optimal-solutions-with-branch-and-bound-70a64692a0dd

Robocat 和猫咪一起玩耍。图像由作者使用 Dall·E 创建。
解决离散优化问题的强大算法
·发表于 Towards Data Science ·阅读时间 8 分钟·2023 年 12 月 5 日
--
分支定界是许多混合整数规划 (MIP) 求解器的核心算法。它是你数学优化工具包中的一项极佳补充,特别适用于较小的问题或有许多约束的问题。此外,它的直接性使得它易于访问,无需复杂的数学公式。
在这篇实用文章中,我们将深入探讨一个数学优化问题。我们将使用分支定界算法来解决这个问题,这是一种解决此类问题的优秀技术。我们将重点讨论一个以猫为主题的问题——因为说实话,谁不喜欢猫呢?不过,如果你更喜欢狗,可以在讨论中每次遇到‘猫’时,脑补成‘狗’。原理和方法完全相同!
问题介绍
假设你是一个猫庇护所的主人。每天,宠物主人可以带他们的猫来你这里,你负责照顾它们。许多人在 COVID 期间领养了猫,但现在大家都需要回到办公室。这就是你公司的业绩蒸蒸日上的原因。
实际上,有些过于出色。你在将所有猫安置到你建筑中的房间时遇到了困难。有时你不得不拒绝人们的请求,因为要求太多了。这就是为什么你决定创建一个优化算法,以帮助你找到所有猫登记的最低房间数。
让我们来看一个例子。假设有 3 只猫请求入住你的庇护所。它们的名字是 Lily、Charlie 和 Meowster。我们该如何将这三只猫分配到不同的房间中?我们最多需要三个房间,以下是分组猫咪的所有可能方案:

猫的分割。图片由作者提供。
分割和贝尔数
如你所见,将 3 只猫分组有 5 种可能的方法。在数学中,分组集合元素的一种方式被称为分割。贝尔数对应于给定集合(在我们的例子中,即 3 只猫可以创建 5 种分割)的所有可能分割的总数。这来自组合数学领域。
计算下一个贝尔数的递归公式如下:

计算贝尔数的递归公式。图片由作者提供。
这个数字增长很快:

贝尔数和不同集合大小的计算。图片由作者提供。
一个问题出现了:我们如何以智能的方式找到所需房间的最小数量?当猫的数量增加时,循环遍历所有可能的解决方案是没有意义的。这将耗费大量时间。
从问题到树
首先,让我们重新表述问题。然后,我们将使用分支限界算法解决它。
我们可以从没有猫开始,然后一个一个地添加所有猫,而不是创建所有可能的解决方案并计算它们的分数(我们使用了多少房间)。这看起来像一棵带有节点和边的树:

解决方案树。在每一行中增加一只猫。这给我们所有可能的分割。图片由作者提供。
你是否看到了如何在一只新猫(元素)到来时创建所有可能的解决方案(分割)?让我们再增加一只猫:Pawsanova。这很简单,因为我们可以将它添加到所有现有的组中,并作为每个分割中的一个单独组。这对应于图片中的红点:

再增加一只猫,Pawsanova。可以在 15 个位置‘插入’它。图片由作者提供。
对于第四只猫,共有 15 种可能的分割方式。这对应于 4 个元素集合的贝尔数。
在使用分支限界算法解决问题时,我们将使用这个包含可能解决方案的树。
分支限界
现在我们从原始问题创建了一个树,我们可以利用这个树来运行分支限界算法。分支限界是如何工作的?直观地说很简单:
分支限界将问题分解成更小、更易于管理的子问题(分支),然后为这些子问题确定界限,以有效地找到最优解。
分支
将问题分解成更小的子问题正是我们通过创建一个猫的树所做的。我们从没有猫开始,根节点,然后一个一个地添加猫。每个子问题由树的一个节点表示。
界限
分支限界的一个重要部分是确定问题的界限。它是这样工作的:
-
上界与目前为止找到的最佳分数相对应。
-
下界 是我们在继续当前树分支时可以找到的最佳分数。这是一个乐观的分数,计算时间不应过长。
剪枝
通过将当前节点的下界与上界进行比较,我们可以决定是否继续搜索树的当前部分。如果下界比上界差,我们可以丢弃这部分树。通过丢弃那些不能产生更好解决方案的子问题,算法减少了需要检查的解决方案数量。这个过程被称为剪枝,对算法的效率至关重要。
另一种剪枝树的方法是检查当前子问题是否可行。例如,如果我们已经知道两只猫不能待在一个房间里,因为它们都是急躁的,而当前节点将这两只猫组合在一起,那么我们也可以剪枝这部分树。
探索方法
探索树以找到最佳解决方案有多种方法。你可能对深度优先搜索(DFS)和广度优先搜索(BFS)有所了解。深度优先搜索会尽可能深入地探索一条分支。广度优先搜索则会探索当前深度的所有节点(逐个添加猫,就像我们创建树时做的那样)。另一种策略是最佳优先搜索,这种方法根据某些标准(如优化问题中的下界)对节点进行优先级排序。下面你可以看到 BFS 和 DFS 的区别(查看节点序号以了解顺序):

广度优先搜索和深度优先搜索。图片由作者提供。
解决猫问题
了解这些知识后,让我们回到最初的问题。让我们在 Python 中实现分支限界法。我们将使用广度优先搜索,但可以随意重用这段代码以尝试其他探索策略。
我们将实现的一些规则:
-
解决方案的分数将等于我们使用的房间数量。显然,我们希望将其最小化。
-
可行性:一个房间不能有超过 5 只猫,房间内猫的总重量不应超过 25 公斤。此外,一个房间最多只能有一只‘生气’的猫,否则我们会有猫打架...
-
节点的下界通过当前节点的分数加上基于愤怒猫数量所需的最小额外房间数来计算。
注:你只需安装
numpy包。
让我们从一个Cat类开始。
使用generate_n_cats函数,我们可以生成任意数量的猫。
在接下来的代码片段中,我们将实现分支限界算法:
branch_and_bound.py文件包含两个类:Node类和BranchAndBound类。
Node 类包含树中一个节点所需的所有信息。它有一个 partition 属性,你可以在这里找到该节点的猫分组。类中的方法用于计算权重、愤怒猫的数量、该节点的下界和可行性检查。
BranchAndBound 包含该算法。它从一个有效的解决方案(所有猫在分开的房间)开始,并计算该解决方案的得分作为上界。然后,它从没有猫开始,一只一只地添加猫。它添加新的节点,并通过检查可行性和比较下界与上界来判断是否可以剪枝。
这是处理 15 只猫的结果示例:
Score of Branch and Bound: 5
Time of Branch and Bound: 0.03134322166442871
[Cat(name=cat_0, weight=1, character=angry)]
[Cat(name=cat_1, weight=4, character=angry)]
[Cat(name=cat_7, weight=6, character=angry), Cat(name=cat_8, weight=1, character=sweet), Cat(name=cat_9, weight=8, character=sweet), Cat(name=cat_10, weight=6, character=sweet), Cat(name=cat_11, weight=1, character=sweet)]
[Cat(name=cat_2, weight=6, character=sweet), Cat(name=cat_4, weight=7, character=angry), Cat(name=cat_5, weight=4, character=sweet), Cat(name=cat_6, weight=6, character=sweet), Cat(name=cat_12, weight=1, character=sweet)]
[Cat(name=cat_3, weight=8, character=angry), Cat(name=cat_13, weight=4, character=sweet), Cat(name=cat_14, weight=1, character=sweet)]
该算法创建了 5 个组,每组最多有 5 只猫,最多 1 只愤怒的猫,且总重最多为 25 公斤。此外,该算法非常快速,找到(其中之一)最优解的时间不到 0.032 秒。第 15 个贝尔数接近 14 亿,想象一下,如果我们循环遍历所有 14 亿个可能的分区需要多长时间!
优缺点
分支限界法是一个容易理解的算法。树的剪枝使得它比动态规划或暴力破解等算法更高效,因为它跳过了那些不感兴趣的解决方案。分支限界法保证找到最优解,这是一个巨大的优点。但当然,也有缺点。如果我们将算法优化为处理 50 只猫而不是 15 只,它将需要很长时间来运行。在 50 只猫的情况下,树有许多节点且非常深。为了提高效率,我们可能需要考虑其他算法,例如元启发式算法。这里要注意的是,下界计算的强度对性能有很大影响。越严格越好,但确保它确实是下界(不高于该分支的实际最佳解决方案)。而且,针对你的具体问题,尝试不同的探索方法可能会有价值。在猫的例子中,我们从一个不好的初始解决方案(所有猫在不同的房间)开始,而从一个好的解决方案(例如,如果猫能适应,则贪婪地将猫添加到现有房间)开始更好。这可以加快算法的运行速度。
结论
希望你喜欢这个猫收容所的故事!现在你知道了分支限界法是如何工作的,也学到了一些组合数学的知识,比如分区和贝尔数。分支限界法是离散优化问题中使用的一种基本技术。许多 MIP 求解器将分支限界法作为其核心算法。通过剪枝,你不必遍历所有可能的解决方案。这节省了很多时间,但要小心,对于具有大型解决方案树的问题,它仍然可能需要(过于)长的时间。
感谢阅读,下次见!
相关内容
数据科学课程专注于数据可视化、特征工程、数据处理、(非)监督学习…
towardsdatascience.com ## 每个数据科学家应该了解的数学优化启发式方法
局部搜索、遗传算法等。
towardsdatascience.com ## 如何处理优化问题?
简单的例子和解决方案以及代码。
towardsdatascience.com
利用多项式混沌扩展、使用 uncertainpy 和 chaospy 寻找混乱中的秩序
下面是如何利用数学、物理学、Python 和数据科学来解决混乱问题的方法
·
关注 发布在 Towards Data Science ·9 min read·2023 年 10 月 12 日
--
图片由作者使用 Midjourney 生成
三年前,我从意大利罗马搬到了美国俄亥俄州辛辛那提,接受了辛辛那提大学的博士 offer。我非常怀念我的城市,有很多东西:美食、天气、永恒之城的美丽。我绝对不怀念我城市的一件事是疯狂的交通。
我的一个好朋友前几天给我发了短信说
“皮耶罗,今天交通太糟糕了,城市一片混乱。”
现在,显然,我没有纠正他(尤其是知道罗马的交通情况),但术语混乱在数学和物理中有着与我们日常生活中使用的“混乱”完全不同的意义。
当我们谈到数学中的混乱时,一个流行的定义是:一个由确定性方程控制的问题,但 系统的演变极度依赖于初始条件。这意味着即使初始条件发生极其微小的变化,系统的演变也可能会极其不同。用洛伦茨¹的话来说,这意味着:
“现在决定未来,但近似的现在不能近似决定未来。”
¹ mpe.dimacs.rutgers.edu/2013/03/17/chaos-in-an-atmosphere-hanging-on-a-wall/
这意味着我们能够预测状态的演变的唯一方法是从概率视角考虑它。给定过程的起始点,我们无法准确预测系统的到达点,因为它是混乱的,但我们能够概率性地预测,比如我们可以得到平均值和初始偏差。
这种混乱可以通过数值方法处理,例如使用Python。在这篇博客文章中,我们将从抽象的随机游走开始,描述多项式混沌展开(PCE),并应用于一个实际案例,如我们的咖啡温度☕️。
让我们开始吧!
1. 随机游走
随机游走是所有阅读的数学家和物理学家都很熟悉的东西。这个模型几乎在所有地方都被使用,从金融到物理,它非常简单。它在文献中也被称为布朗运动,其工作原理如下:
-
我们从点 x = 0 开始
-
以相同的概率,我们可以从 x=0 到 x=1,或者从 x=0 到 x=-1。我们将这个点定义为 x_1
-
再次,我们可以将 x_1 的值增加或减少 1。我们将定义这一点为 x_2。
-
我们用 x_2 重复第 3 点 N-2 次
有时候,我觉得伪代码比用文字解释更容易理解
RandomWalk(N):
x = 0
i = 0
while i<N:
p = random(-1,1)
x = x+p
i = i+1
return x
现在我们来探讨一下这个问题,好吗?
1.2 代码
在这一部分中,我们将使用 Python 语言代码描述随机游走。你需要导入像numpy和matplotlib.pyplot这样的基本库。
这是随机游走代码:
如果我们运行这个,比如 100 次,我们会得到以下路径:
很有趣的是,如果你考虑最后一步,你可以发现高斯分布。

现在我们先留在这里。我保证,我们会用到这些。
2. 微分方程
现在,关于生活的一切,等一下,字面上,我们所知道的关于生活的一切都是因为微分方程。
微分方程是物理学用来描述系统演变的工具。我的高中老师通过这样解释:
“描述世界你需要两样东西:区分和积分。区分非常容易,积分非常困难。”
例如,让我们考虑爬树的松鼠的y位置:

图片由作者使用 Midjourney 生成
假设松鼠的速度是 v(t) = (t/60)2,其中 t = 秒。所以我们的超级英雄开始时的 v(t=0) = 0,经过两分钟,他的速度达到 v(120) = 22 = 4 m/s。
给定这些信息,超级松鼠的位置在哪里?
我们需要做的是积分速度方程,然后得到:

我们如何得到这个 c 常数?我们只是设置 t = 0 时发生的情况。我们假设我们的松鼠从高度 = 0 开始,所以 c = 0。
所以我们爬树的超级松鼠的位置如下:

通常,一个特定的解决方案 y 可以看作是另一量,例如 x 的积分以及一个初始条件。
在这个方案中,我们讨论了:
-
时间(t),这是时间变量(从实验开始到结束)
-
x,即我们正在积分的对象(在上述例子中,x是速度)
-
解决方案(y)是我们通过积分x得到的解决方案(在上述例子中,y是超级松鼠的位置)
所以在这种情况下,我们可以说 x (t) = y(t) 的积分。
还有更多。你可以在系统中拥有一些固定的参数,但这些参数可以改变系统的演变。 所以:
x(t, 参数列表) = y(t, 参数列表) 的积分
例如。我们来谈谈咖啡。咖啡?是的,咖啡。
2.1 牛顿的冷却定律
一位在物理学上相当出色(哈哈)的家伙,名叫艾萨克·牛顿,在他留给我们的众多礼物中,解释了如何描述热体的热传导。换句话说,他告诉我们物体如何冷却。

图片由作者使用 Midjourney 生成
牛顿提出的冷却定律指出,从身体到外界的热量流失速率与一个常数 k 成正比,该常数依赖于表面积及其热传导系数,以及时间 t 的温度 T 和环境温度 T_env 之间的差异。
如果我们想获得温度(T),我们需要积分热量流失速率(dT/dt)。这是方程:

图片由作者提供
要获得温度 T,给定 T_env 和 k(记住这一点!!!),我们需要积分 dT/dt。
3. 威纳的混沌!
我们非常罕见地能够定义微分方程的解析解。这就是我高中教授说它很难积分的原因。我们更可能需要进行数值积分,即用数值方式(也就是用算法)解决微分方程。
有一些非常著名的积分方法(算法),比如梯形规则或黎曼和。它们有效,有利有弊,并且高效。它们不是问题所在。
真正的问题在于微分方程的参数(例如 T_env 和 kappa)。让我详细说明一下。
你还记得上面方程中的T_env和k吗?我们完全不知道它们实际是什么,它们可能会完全改变我们系统的演化。
诺伯特·维纳的美丽心灵为带有额外随机参数的微分方程提供了非常优雅的公式。特别是,现在我们所有的讨论都很有意义了,带有随机参数的微分方程被定义为混沌的,可以用随机游走(啊哈!)作为多项式来描述。通过这样做,我们能够以概率方式理解解 T(t)!
我理解这可能会让人困惑:让我们一步一步来 😃
3.1 设置
我们在 Python 中需要做的第一件事是定义我们的微分方程:
如我们所见,这不仅仅是T(我们的变量)的问题,还有kappa 和 T_env的问题。
这是我们需要积分的函数。在此之前,让我们导入一些朋友🦸♂️
你可能会遇到错误,因为你没有chaospy 和 uncertainpy。它们是我们的魔法小助手:它们实现了多项式混沌展开方法。安装它们非常简单:
pip install uncertainpy
3.2 积分函数
让我们使用梯形规则设置需要积分的函数:
所以:
-
我们设定咖啡的初始温度,比如 T_0 = 95
-
我们设定我们问题的时间步长,比如 500 个时间步
-
我们使用梯形规则进行积分
-
我们返回时间和温度
3.3 关于 uncertainpy
现在,我得说一下:uncertainpy 非常棒。你可以用它做很多事情,我真的推荐你花时间在这里了解一下。
我们要做的是:
-
我们设定一个可能的kappa分布。例如,kappa 从具有给定 mu 和方差的正态分布中抽取。
-
对T_env做同样的处理
-
我们应用 uncertainpy 并提取给定输入分布的温度可能值分布
游戏很简单:如果我们知道得益于 Wiener 的混沌参数的可能分布,我们就能知道输出的分布。
这可能听起来有点混乱,但我保证展示代码后会更清楚:
模型通过coffee_cup定义,这就是我们的微分方程。接着,我们定义参数分布(使用chaospy)并定义相应的参数字典。
现在。对于每个 kappa 和 T_env 值,我们有一个具有不同参数和不同温度 T(t) 的微分方程,这些是积分的结果。得益于神奇的chaospy,解决方案变成了一个具有均值和标准差的分布。
让我们看看如何操作:
就是这样!(就像 Biggie 所说)。非常简单。
3.4 整体内容
整个内容可以放在这个代码块中:

这难道不美好吗?我们能够将输入参数的分布转化为输出结果的分布。在时间 t=0 时,温度是 T= T_0 = 95。随着时间的推移,参数的不确定性变得越来越明显。在时间 = 200 分钟时,我们有一个很大的不确定性(假设从 5 到 30),这取决于 k 和 T_env,可能会很冷或稍微有点热。
4. 结果
在这篇博客文章中,我们描述了美妙的chaospy 和 uncertainpy 库。这些库使我们能够处理Wiener 混沌问题,它使用随机游走来定义一种多项式混沌。这种多项式混沌用于处理带有分布的微分方程,而不是参数。我们按以下顺序进行了操作:
-
我们在第一章中描述了随机游走。
-
我们在第二章中描述了微分方程。特别是我们描述了牛顿冷却定律。
-
我们按照Wiener的描述了混沌并在第三章中应用了多项式混沌。
5. 结论
如果你喜欢这篇文章并想了解更多关于机器学习的内容,或者你只是想问我一些问题,你可以:
A. 在 Linkedin 上关注我,我会发布所有的故事。
B. 订阅我的 通讯。它将让你了解新故事,并给你机会与我联系,获取所有可能的更正或疑问。
C. 成为一个 会员,这样你就不会有“每月故事的最大数量限制”,可以阅读我(以及成千上万其他机器学习和数据科学顶级作者)写的关于最新技术的内容。
通过随机森林找到我们的路
原文:
towardsdatascience.com/finding-our-way-through-a-random-forest-5ff6c1382572
或者在一个被僵尸困扰的假想世界里,决策树如何能决定你是否能够脱身
·发表于 Towards Data Science ·阅读时长 17 分钟·2023 年 4 月 19 日
--

图片由作者使用 Midjourney 制作
车库外面,咆哮和咆哮声不断。他简直不敢相信自己在系列剧和电影中多次看到的僵尸末日最终出现在了自己门前。他可以在车库里藏一段时间,但最终还是得出来。他该带斧头还是仅用步枪就够了?他可以试着找些食物,但应该一个人去吗?他试图回忆起所有看过的僵尸电影,却无法达成一致的策略。如果他能记住每个角色被僵尸杀死的场景,这是否能增加他的生存几率?如果他有一个决策指南,一切都会更简单……
介绍
你是否看过那些僵尸末日电影,其中总有一个角色似乎总是知道僵尸藏在哪里,或者知道是该战斗还是逃跑?这个人真的知道接下来会发生什么吗?有人事先告诉过他/她吗?也许这并没有什么神奇的地方。也许这个人读过很多关于僵尸的漫画,并且非常擅长根据每种情况做出正确的决策,学习别人的错误。找到使用过去事件作为我们决策指南的最佳方式是多么重要!这个指南,也称为决策树,是一种广泛使用的监督学习算法。本文是关于决策树的入门讨论,讲述了如何构建它们以及为什么许多决策树会创建一个随机森林。
一个简单的决策
你处在僵尸混乱之中,你想知道如何提高生存的机会。在这一点上,你只有 15 个朋友的信息。你知道每一个朋友是否孤身一人,是否有车辆或武器,或者是否经过战斗训练。最重要的是,你知道他们是否能够生存下来。你如何利用这些信息来为自己争取优势?
表 1 总结了你 15 个朋友的结果和特征。你想成为最终幸存的那 3 个人中的一员。这 3 个朋友有什么共同点?简单检查表格会告诉我们这三位幸存者有这些共同点:他们并不孤单,他们经过了战斗训练,他们有车辆和武器。那么,如果你拥有这四样东西,你能幸存吗?过去的经验告诉我们,你有可能!如果你必须决定带上什么以及是否独自一人,至少现在你有一些历史数据来支持你的决定。

表 1. 15 个个体的特征及其在僵尸末日中的最终结果示例 1(表由作者制作)。
更困难的决定
僵尸末日永远不像它看起来那么简单。假设这次你有以下这些朋友,而不是前面例子中的 15 个朋友:

表 2. 15 个个体的特征及其在僵尸末日中的最终结果示例 2 (表由作者制作)。
这一次,仅通过视觉检查得出结论并不那么简单。我们唯一可以确定的是,如果你想要生存,最好身边有一个人。幸存的 5 个人并不是孤身一人(图 1)。除此之外,很难看出是否存在某种特定的组合可以让你存活下来。有些人虽然孤身一人仍然能够生存。他们是怎么做到的?如果你知道自己会孤身一人,还有什么可以做的来增加生存的机会?是否存在类似决策路线图的东西?

图 1. 各特征对示例 2 中 15 个人群体最终结果的影响 (图由作者制作)。
决策树
我们可以在决策树中找到一些对前面问题的答案。决策树是一个根据我们做出的决定来预测结果的模型。这个模型是通过之前的经验构建的。在我们的例子中,我们可以利用 15 个朋友的特征和他们的结果来构建决策树。决策树由多个决策节点或分支组成。在这些节点中,我们做出一个决策,决策将引导我们到下一个节点,直到我们得到一个结果。
构建决策树
如果有人让你画一个家谱,你可能会从你的祖父母或曾祖父母开始。从那里,家谱将通过你的父母、叔叔和表亲逐渐展开,直到到达你。类似地,要构建决策树,你总是从一个能够最好地分离数据的节点开始。从这一点开始,树将根据最佳分割数据的特征开始生长。有许多算法可以用来构建决策树。本文解释了如何使用信息增益和香农熵。
让我们关注表 2. 我们可以看到有 5 人幸存,10 人死亡。这意味着幸存的概率是 5/15 = ⅓,死亡的概率是 ⅔。利用这些信息,我们可以计算这种分布的熵。在这种情况下,熵指的是这种分布的平均惊讶或不确定性水平。为了计算熵,我们使用以下方程:

注意,这个方程也可以用其中一个概率来表示,因为 p(surv)+p(die)=1。如果我们绘制这个函数,你会看到当 p(surv) 和 p(die) 都等于 0.5 时,熵的值达到 1 的最高值。相反,如果整个分布对应于所有人都生存或所有人都死亡的情况,则熵为零。因此,熵越高,不确定性越高。熵越低,分布越均匀,我们对结果的惊讶程度就越小。

图 2. 熵作为每个事件概率的函数的图示。绿色曲线代表死亡的概率,橙色曲线代表生存的概率,蓝色曲线是这两种概率的总和,因为 p(surv)+p(die)=1 (由作者制作的图像)。
在我们的案例中,幸存者的数量不到总人口的一半。合理的推测是大多数人没有幸存于僵尸末日。在这种情况下,熵为 0.92,这就是你在图 2 的蓝色曲线中搜索 x=⅓ 或 ⅔ 时得到的值,或者当你应用以下方程时得到的值:

现在我们知道了整个分布的熵或不确定性的程度,我们应该怎么做?下一步是找到如何划分数据,以保持这种不确定性水平。
信息增益的前提是选择能够最小化前一节点熵水平的决策节点。在这个阶段,我们正试图找出数据的最佳首次分离点。是我们单独存在、知道如何战斗,还是拥有交通工具或武器?为了知道答案,我们可以计算每个选择的信息增益,然后决定哪一个具有最大的增益。记住,我们试图最小化熵的变化,即结果分布中的异质性或惊讶程度。
你受过战斗训练吗?
这是你在这种情况下应该问自己的第一个问题吗?这个问题会最小化结果分布中熵的变化吗?为了知道这一点,让我们计算这两个案例的熵:我们知道如何战斗和我们不知道如何战斗。图 3 显示,在 9 名知道如何战斗的人中,只有 5 名幸存了。相反,所有 6 名没有接受战斗训练的人都没有生存下来。

图 3. 根据战斗能力的示例 2 结果(图像由作者制作)。
要计算之前案例的熵,我们可以应用之前使用的相同公式。图 3 显示,经过战斗训练的情况下,熵为 0.99,而在其他情况下,熵为零。请记住,熵为零意味着没有惊讶,分布均匀,这实际上是因为所有没有接受战斗训练的人都没有生存下来。在此阶段,重要的是要注意,在第二种情况下的熵计算包含一个未定义的计算,因为我们最终会得到一个零的对数。在这些情况下,你可以始终应用洛必达法则,正如在这篇文章中所解释的那样。
我们现在需要计算这个决策的信息增益。这与问如果我决定根据这个问题划分所有结果,所有决策的不确定性会改变多少是一样的。信息增益通过从主要节点的熵中减去每个决策的熵来计算。一个重要的事情是注意到,这个操作根据每个决策的个体数量加权。因此,即使熵很大,如果做出该决策的人数量较少,也可能对信息增益计算的影响较小。对于这个例子,战斗能力的信息增益为 0.32,如图 3 所示。
你能单独幸存下来吗?
我们可以对单独生存或与他人一起生存的可能性进行类似的分析。图 4 展示了计算过程。在这种情况下,信息增益为 0.52。注意到在这种情况下,单独一人从未导致生存,而在不孤单的情况下,7 个案例中有 5 个幸存下来。

图 4. 根据是否独自一人,示例 2 的结果(图像由作者制作)。
拥有车辆或武器如何?
对于这两种情况,我们可以像之前一样计算信息增益(图 5)。你可以看到这些信息增益比之前计算的要小。这意味着,此时根据前两个特征来划分数据比按照这两个特征划分要更好。记住,最大的增益对应于熵减少最小的特征。一旦我们计算了所有特征的信息增益,就可以决定决策树的第一个节点是什么。

图 4. 根据是否拥有车辆或武器,示例 2 的结果(图像由作者制作)。
第一个节点
表 3 显示了每个特征的信息增益。最大的增益对应于是否独自一人或有伴侣。这一节点将我们带到决策树中的第一个决策:你将无法单独生存。8 个独自一人的人无论是否有武器、汽车或接受过战斗训练都无法生存。因此,这是我们从分析中可以得出的第一个结论,这也支持了我们仅通过检查数据得出的结论。

表 3. 第一个节点分析后每个特征的信息增益(表格由作者制作)。
此时,决策树的样子如图 5 所示。我们知道如果我们独自一人是没有生存的可能性(考虑到我们拥有的数据)。如果我们不独自一人,那么我们可能会生存下来,但并非所有情况下都如此。由于我们可以计算图 5 中右侧节点的熵,即 0.86(该计算过程在图 4 中展示),我们也可以计算其他三个特征的信息增益,并决定下一个决策节点是什么。

图 5. 第一个决策节点后的决策树(图像由作者制作)。
第二个节点
图 5 显示,此时最大的增益来自于武器特征,因此这是下一个决策节点,如图 6 所示。注意到所有未独自一人且拥有武器的人都幸存下来,这就是为什么武器节点的左侧以生存决策结束。

图 6. 第二个决策节点后的决策树(图像由作者制作)。
树是完整的
还有 3 个人没有单独存在且没有武器,我们需要对他们进行分类。如果我们遵循之前解释的相同过程,我们会发现下一个信息增益最大的特征是车辆。因此,我们可以在树上添加一个额外的节点,询问某个人是否拥有车辆。这将把剩下的 3 个人分成一个有 2 人拥有车辆但未幸存的组和一个没有车辆但幸存的单独个体。最终的决策树如图 7 所示。

图 7. 示例 2 的最终决策树(作者制作的图像)。
决策树的问题
如你所见,决策树是一个基于以往经验构建的模型。根据数据中的特征数量,你会遇到多个问题,这些问题将引导你到最终答案。值得注意的是,在这种情况下,某些特征在决策树中并未体现。战斗能力从未被选择为决策节点,因为其他特征总是有更大的信息增益。这意味着,根据输入数据,训练战斗能力对于在僵尸 apocalypse 中生存并不重要。然而,这也可能意味着我们没有足够的样本来确定战斗能力是否重要。关键在于记住,决策树的好坏取决于我们用于构建它的输入数据。在这种情况下,15 个人的样本可能不足以对战斗训练的重要性进行准确估计。这是决策树的一个问题。
与其他监督学习方法一样,决策树并不完美。一方面,它们非常依赖输入数据。这意味着输入数据的微小变化可能会导致最终树的重要变化。决策树并不擅长于概括。另一方面,它们往往存在过拟合问题。换句话说,我们可能会得到一个与输入数据完全匹配但在测试集上表现极差的复杂决策树。如果我们使用带有连续变量的决策树而非像示例中那样的分类变量,这也可能影响结果。
提高决策树效率的一种方法是修剪它们。这意味着在达到纯节点之前停止算法,就像我们在示例中所达到的那样。这可能导致删除那些对决策树准确性没有改善的分支。修剪赋予决策树更多的概括能力。然而,如果我们决定修剪决策树,那么我们可能会开始提出额外的问题,比如:何时是停止算法的正确时机?我们应该在达到最小样本数后停止?还是在预定义的节点数后停止?如何确定这些数字?修剪确实可以帮助我们避免过拟合,但它可能带来一些不易回答的额外问题。
那么,如何看待整个森林而不是单棵树呢?
如果我们不是使用单一的决策树,而是多个决策树呢?它们将根据所处理的输入数据部分、读取的特征以及修剪特性而变化。最终我们将得到许多决策树和不同的答案,但在分类任务中我们可以选择多数结果,或在回归任务中取平均。这可以帮助我们更好地概括数据的分布。我们可能会认为某棵决策树存在误分类,但如果我们找到 10 或 20 棵树得出相同的结论,那么这可能表明实际上没有误分类。基本上,我们是让多数决定,而不是依赖单棵决策树。这种方法称为随机森林。
随机森林的概念通常与袋装法相关,袋装法是一个从训练集中随机抽取数据样本的过程,并且有放回地选择。这意味着个别数据点可以被选择多次。在随机森林方法中,我们可以选择一个随机数量的点,建立一棵决策树,然后重复这一过程直到得到多棵树。最后的决策将来自所有树的答案。
随机森林是一种著名的集成方法,用于分类和回归问题。该方法已在金融、医疗保健和电子商务等多个行业中应用[1]。虽然随机森林的原始思想是由许多研究人员逐步发展的,但莱奥·布雷曼通常被认为是这一方法的创始人[2]。他的个人网页中包含了随机森林的详细描述和其工作原理的广泛解释。这是篇较长但值得阅读的文章。
关于随机森林,一个重要的理解点是它如何处理数据集的特征。在每个节点,随机森林会随机选择一个预定义的特征数量,而不是所有特征,来决定如何分裂每个节点。记住在之前的例子中,我们分析了决策树每个层级的每个特征的信息增益。相反,随机森林会在每个节点只分析特征子集的信息增益。因此,随机森林将 Bagging 与每个节点的随机变量选择混合在一起。
一个真正的僵尸末日
回到僵尸的问题!之前的例子非常简单,我们有 15 个人的数据,每个人只有 4 个特征。让我们把难度加大!假设现在我们有一个超过一千条记录的数据集,每条记录有 10 个特征。这个数据集是在 Excel 中随机生成的,不属于任何商业或私人库,你可以从这个GitHub 页面访问它。
对于这些类型的方法论,通常建议将整个数据集分成训练集和测试集。我们将使用训练集来构建决策树和随机森林模型,然后使用测试集来评估它们。为此,我们将使用scikit-learn 库。这个Jupyter Notebook包含了数据集的详细说明、如何加载数据以及如何使用该库构建模型。
整个数据集包含 1024 条记录,其中 212 条(21%)对应幸存者,812 条(79%)对应死亡。我们将数据集分为一个训练集,占数据的 80%(819 条记录),以及一个包含 205 条记录的测试集。图 8 显示了在所有数据集中幸存者与死亡者之间的关系是如何保持的。

图 8. 示例 3 中数据集的训练集和测试集分离(图片由作者制作)。
关于特征,这次我们为每个个体增加了 6 个额外的特征:
-
你有收音机吗?
-
你有食物吗?
-
你上过户外生存课程吗?
-
你上过急救课程吗?
-
你以前遇到过僵尸吗?
-
你有 GPS 吗?
这 6 个特征加上我们已有的 4 个特征,共同代表了每个个体或条目的 10 个不同特征。有了这些信息,我们可以按照之前解释的步骤构建决策树。Jupyter Notebook 使用 DecisionTreeClassifier 函数生成决策树。请注意,此函数不适用于分类变量。在这种情况下,我们已将每个类别的所有答案转换为 -1 或 +1。这意味着每次我们在结果中看到 -1 时,表示“否”,而 +1 表示“是”。这一点在Jupyter Notebook中有更详细的说明。
笔记本 解释了如何加载数据、调用决策树函数并绘制结果。图 9 显示了使用 819 个训练集条目构建的决策树(点击这里查看更大图)。深蓝色框对应于最终的决策节点,其中答案为幸存,而深橙色框表示最终答案为未幸存。你可以看到,第一个决策节点对应于车辆,从那里开始,树根据不同的特征开始生长。

图 9. 示例 3 的训练集决策树(这里是作者制作的可缩放版本)。
我们可以评估这棵树的好坏,如果我们使用测试集输入预测最终类别,然后将这些结果与原始结果进行比较。表 4 显示了一个混淆矩阵,其中列出了决策树将条目误分类的次数。我们可以看到,测试集中有 40 个代表幸存的案例,而决策树仅正确分类了其中的 25 个。另一方面,在 165 个未幸存的案例中,决策树误分类了 11 个。正确分类与 205 个点的整个数据集的关系为 0.87,这通常被称为预测准确率得分。

表 4. 决策树运行后的测试数据集混淆矩阵(图片由作者制作)。
87%的准确率看起来还不错,但我们能通过随机森林提高这个结果吗?下一部分的Jupyter Notebook包含了使用 sklearn 函数 RandomForestClassifier 实现的随机森林。这棵随机森林将包含 10 棵决策树,这些决策树使用所有条目进行构建,但每次分裂时只考虑 3 个特征。随机森林中的每棵决策树考虑 682 个条目,这些条目代表了完整训练集的 84%。因此,明确来说,随机森林过程将:
-
从训练集中随机抽取 682 个条目
-
构建一个决策树,在每个节点考虑 3 个随机选择的特征
-
重复上述步骤 9 次
-
预测将对应于 10 棵决策树中的多数投票
表 5 显示了来自随机森林的结果的混淆矩阵。我们可以看到,这些结果比我们之前用单棵决策树得到的要好。这个随机森林误分类了 11 个条目,预测准确率为 0.95,高于决策树。

表 5. 运行随机森林后测试数据集的混淆矩阵 (图像由作者制作)。
需要注意的是,随机森林方法不仅取决于我们拥有的输入数据的质量,还取决于我们使用的参数选择。我们构建的决策树数量以及每次分裂时分析的参数数量会对最终结果产生重要影响。因此,与许多其他监督学习算法一样,有必要花时间调整参数,直到找到最佳结果。
结论
阅读这篇文章就像电影里那个成功逃脱追赶他的僵尸的家伙,因为一个树枝恰好掉在了僵尸的头上!这并不是他遇到的唯一僵尸,他显然还没有脱离困境!关于随机森林和决策树的许多内容在这篇文章中并未提及。然而,了解这种方法的使用和适用性已经足够。目前,有多个库和程序能够在几秒钟内构建这些模型。所以你可能不需要再重新计算熵和信息增益。不过,理解幕后发生的事情以及如何正确解读结果仍然很重要。在“机器学习”、“集成方法”和“数据分析”等话题日益普及的今天,清楚了解这些方法及其应用于日常问题是很重要的。与僵尸末日求生电影不同,准备好不是偶然发生的。
参考文献
-
IBM. 什么是随机森林?
-
Louppe, Gilles (2014). 《理解随机森林》。博士论文。列日大学


















浙公网安备 33010602011771号