TowardsDataScience-2023-博客中文翻译-八-
TowardsDataScience 2023 博客中文翻译(八)
处理交叉验证中的增强样本数据
原文:
towardsdatascience.com/boosted-sample-data-in-cross-validation-7ee589460238

在交叉验证中排除增强样本数据的测试折叠。图像由 @leddebruijn 提供。
防止数据泄漏和过度估计模型性能的增强样本数据交叉验证的 Python 实现
·发布于 Towards Data Science ·阅读时间 6 分钟·2023 年 4 月 18 日
--
本文介绍了增强(上采样)采样策略数据的交叉验证及其在 Python 中的实现。
增强采样通常用于针对(高度)不平衡数据集中的少数类标签。如果在交叉验证中没有考虑,测试集包含增强采样数据将导致机器学习模型的准确性出现偏差。
本文帮助我们理解这种偏差,并提供了考虑增强样本数据的交叉验证的 Python 实现。涉及以下主题:
-
不平衡分类
-
随机和增强采样
-
数据泄漏
-
一个不平衡数据示例
-
Python 实现
-
结论
不平衡分类
不平衡分类问题是指目标变量分布有偏或倾斜的问题,其中少数类实例较少,而多数类实例较多。
许多现实世界的问题是不平衡的。例如,互联网购物退货远少于购买,垃圾邮件远少于正常邮件(尽管有时看起来并非如此)。
不平衡分类对监督机器学习模型构成挑战,因为这可能导致模型的预测性能较差,尤其是对少数类。由于目标类的不平衡,算法倾向于分类实例更多的类,即多数类。这引入了对模型准确性的错误感知。
随机和增强采样
处理不平衡数据集的一种方法是在数据收集期间平衡目标类别,通常通过一组启发式或业务规则来增强少数类标签采样。
在随机抽样技术中,每个样本被选择的概率是相等的。增强或上采样是一种采样技术,其中一些样本更有可能被选择。它可以人工地采样少数类样本,以平衡不平衡数据集中的类别标签。
例如,抽样年龄在 30 岁以下的人,因为他们比年长的人更频繁地退货。或者抽样包含 guarantee、dollar 和 price 等词的电子邮件,这些都是增强抽样的例子。
数据泄露
我们知道,在分割测试集之前进行上采样会导致数据泄露,因为上采样数据现在也出现在我们的测试集中,使得我们的模型更容易预测数据。
然而,当我们有一个包含上采样数据的监督机器学习模型的数据样本时,该怎么做?
我们希望使用这些数据在更平衡的数据集上训练我们的模型,以学习少数类样本的模式(尽管它们本质上有偏)。但我们不希望使用这些数据来测试我们的模型或进行超参数调优。我们如何在所有数据上进行训练,但仅在我们数据的随机抽样子集上进行测试?

在 CV 测试折叠中排除增强样本数据的交叉验证。图像由 @leddebruijn 提供。
BoostedKFold 允许在交叉验证折叠的所有训练数据上进行训练,但排除由 -1 定义在 split 函数的 groups 参数中的增强样本数据。如下面的图像所示,训练折叠的索引是独立于组索引抽取的,但测试折叠的索引则来自随机抽样的组,而不是来自增强样本组。
一个不平衡数据的例子
让我们看一个例子,使用 sklearn.make_classification 二分类数据集,共 1000 个样本,95% 为多数类,5% 为少数类。30% 的少数类标签是通过增强采样技术获得的,其他 70% 是随机抽取的。
在下面的可视化中,你可以看到 StratifiedKFold 的 5 折交叉验证的混淆矩阵和分类报告。交叉验证是通过 cross_val_predict 在所有测试标签上执行的。正类少数标签 1 总共有 50 个样本,包括增强样本数据和随机样本数据。

用于不平衡数据集的 StratifiedKFold。图像由 @leddebruijn 提供。
现在将其与下面的BoostKFold实现进行比较。通过自定义_cross_val_predict函数在所有测试标签上进行交叉验证,因为cross_val_predict无法处理真实 Y 标签中测试集大小的差异。少数类正标签1共有 35 个样本,包括随机抽取的样本数据,排除了增强的样本数据。

适用于不平衡数据集的 BoostedKFold。图片来自@leddebruijn。
少数类正类的评估指标显著低于以前的StratifiedKFold实现,提供了模型在未见(生产)数据上的更真实表现。这部分由于实例较少(35 比 50),使得少数类更难预测。
然而,性能的下降不能完全用这个来解释。F1 分数下降了一半,而少数类实例仅下降了 30%。性能的剩余下降可以通过去除容易预测的上采样实例来解释,因为模型在训练阶段已经见过类似的上采样数据。
Python 实现
BoostedKFold在内部使用了StratifiedKFold和PredefinedSplit,这些都在.split()方法中实现。我还添加了一个.plot()方法,可以可视化在此故事缩略图中显示的分裂。
def split(self, X: np.array, y: np.array, groups: np.array):
"""Generate indices to split data into training and test set, excluding data in groups with value '-1'.
boosted sample data == '-1' in the ``groups`` parameter
random sample data != '-1' in the ``groups`` parameter
Args:
X (ndarray): array-like of shape (n_samples, n_features)
Training data, where `n_samples` is the number of samples and `n_features` is the number of features.
y (ndarray): array-like of shape (n_samples,),
The target variable for supervised learning problems.
groups (ndarray): array-like of shape 1d: '-1' for elements to be excluded
Yields:
train (ndarray): The training set indices for that split.
test (ndarray): The testing set indices for that split.
"""
# separate boosted sample data that have group ``-1``, from random sample data
boosted_indices = np.where(groups == -1)[0]
random_indices = np.where(groups != -1)[0]
skf = StratifiedKFold(n_splits=self.n_splits, shuffle=self.shuffle)
# split the randomly sampled indices that are to be included in the test-set in ``n_splits`` splits
stratified_random_splits = skf.split(X[random_indices], y[random_indices])
random_sampled = [0] * len(random_indices)
boosted_sampled = [-1] * len(boosted_indices)
# converts the random stratified split test-set indices to the ``n_splits`` enumeration
for split_nr, (_, testing_indices) in enumerate(stratified_random_splits):
# defines which random sample datapoint is in which test-fold
for test_idx in testing_indices:
random_sampled[test_idx] = split_nr
# concatenate the randomly sampled split numbers and the boosted sampling split numbers
predefined_splits = random_sampled + boosted_sampled
# boosted samples are not accounted for in the test-fold splits
ps = PredefinedSplit(test_fold=predefined_splits)
return ps.split(X)
这个实现可以像 scikit-learn 中的其他交叉验证类一样在sklearn.GridSearchCV或sklearn.Pipeline中使用。有关完整脚本,请访问我的GitHub 页面。
结论
上采样数据可以是提升少数类标签的有效方法,这在数据集不平衡的情况下尤为重要。在这个故事中,你了解到数据泄漏的潜在偏差和避免高估分类器性能的方法。
编程愉快!如有任何问题,请随时问我。
通过将 GAN 与扩散模型交叉提升图像生成
一种稳定高效的图像到图像转换方法
·
关注 发表在 Towards Data Science ·8 分钟阅读·2023 年 5 月 9 日
--
ATME 是 GAN ∩ 扩散模型类中的一个模型。图像由 DALL·E 2 生成。
视觉基础模型(VFM)是如Visual ChatGPT¹等前沿技术的核心。在本文中,我们将简要讨论最近的进展,融合 VFM“汤”的两个重要成分:GANs 和扩散模型,最终达到它们的交集 ATME。ATME 是我在论文Look ATME: The Discriminator Mean Entropy Needs Attention²中介绍的一种新模型,GitHub 仓库可在此处找到。
我们将讨论每种生成模型的相关弱点和优势。然后我们讨论合并它们的两类解决方案:朴素的 GAN ∪ Diffusion 和更深入的高效的 GAN ∩ Diffusion 模型类别。最后,你将了解一些 VFM 研究目前如何发展。
生成模型
首先,提供一些背景。条件生成模型的目标是学习如何从目标领域生成数据 y,利用源领域的信息 x。这两个领域可以是图像、文本、语义图、音频等。两种建模方法已取得很大成功:生成对抗网络(GANs)和扩散概率模型。具体来说,
-
GANs 通过训练生成器模型来学习如何从数据分布 p(y∣x) 中采样,生成的数据按照 g(y∣x) 分布。它使用一个鉴别器模型,指导生成器从盲目生成数据到准确生成数据,通过最小化 g 和 p 之间的散度(或距离)来实现。
-
扩散模型通过减少从 p(y∣x) 中采样的潜在变量 y₁、y₂、⋯、yₙ 来学习。这些变量是 y(或 y 的编码)的一系列逐渐噪声化的版本,减少的过程是通过学习去噪模型来完成的。
如果你需要关于这些建模类型的更多细节,网上有大量资源可供查阅。关于 GANs,你可以从这篇文章开始,对于扩散模型,可以从这篇文章开始。

图 1:Visual ChatGPT 演示 来自微软,已获许可。
现在我们已经设置了基础,让我们讨论一些应用。图 1 展示了官方的 Visual ChatGPT 演示。它使用了几个用于视觉-语言交互的模型,部分模型列在下面的表格中。
表 1:一些支持 Visual ChatGPT 的 VFM。信息来源于论文¹的附录。
其中大多数是生成模型,其中大部分基于稳定扩散。这表明了一个近期的兴趣转变,从 GAN 转向扩散模型,这一转变由证据³ 触发,表明后者在图像合成方面优于前者。本文的一个要点是,这并不意味着扩散模型在所有图像生成任务中都优于 GAN,因为这些模型的聚合往往比独立部分表现更好。
在讨论这一点并到达 ATME 之前,让我们通过重新审视 GAN 和扩散模型的主要弱点和优势来铺平道路。
GANs
原始 GAN 论文⁴ 中引入的主要前提,并在教程中强调的是,在一个足够大的模型和无限数据的极限情况下,生成器和判别器之间的极小极大游戏会收敛到纳什均衡,其中(原始)GAN 目标达到值 −log4。然而,在实践中,这几乎没有观察到。这一理论结果的偏离产生了广泛被称为 GAN 训练不稳定的问题。这与模式崩溃一起,是其主要缺点。它们通过轻量级模型一次生成仍能达到较高的图像生成质量来弥补这一点。
扩散模型
相比之下,扩散模型虽然稳定,但由于学习去噪分布所需的步骤过多而效率较低。这是因为这种分布通常假设为高斯分布,这只有在去噪步骤非常小时才是合理的。
最近开发的替代方法通过使用多模态分布来减少去噪步骤的数量(甚至减少到 2 步),这需要将扩散模型与 GAN 结合,如我们在接下来的讨论中所述。
GAN ∪ 扩散
目前的 GAN 与扩散训练方法非常有前景。它们可以被归类为使用生成对抗训练与多步骤扩散过程相结合的 GAN ∪ 扩散模型类。
为了改善 GAN 的训练稳定性和模式覆盖,这些模型通过遵循扩散过程注入实例噪声,这些过程可能有多达几千个步骤(如在 Diffusion-GAN⁵ 中),也可能只有两个步骤(如在 去噪扩散 GAN⁶ 中)。这些模型在各种数据集上的表现优于强 GAN 基线,但仍需多个去噪步骤。因此,
是否有可能在一次生成中使用 GAN 创建图像,并同时利用去噪扩散过程?
答案是肯定的,这定义了 GAN ∩ 扩散类模型。
GAN ∩ 扩散
结果是,单个技巧可以使 pix2pix⁷ 视觉基础 GAN 模型在设计上稳定:关注判别器均值熵。

图 2:ATME 使用来自扩散模型的 UNet 生成图像,这些图像由来自 pix2pix 的 patchGAN 判别器进行判断。图片由作者提供。
结果模型,ATME,如图 2 所示。给定源图像和目标图像的联合分布 p(x,y),输入图像 x 在时代 t 被 Wₜ = W(Dₜ-)所损坏,如下所示。
xₜ = x (1+Wₜ)
其中 W 是一个小型确定性网络,它将前一时代 𝑡- = t-1 的判别器决策图 Dₜ- 转换为 Wₜ。转换后的图 Wₜ 包含与输入空间中的补丁相关的模式,这些补丁是生成器之前未能欺骗判别器的区域,以及来自判别器的噪声,判别器尚未完全优化,因此在决策上出现错误。
生成器看到损坏的源图像,并根据 𝑦̂ = yᵩ(xₜ, 𝑡̃) 生成目标图像,使用用于 去噪扩散概率模型 的网 yᵩ,这些模型具有适当的注意机制。生成器的任务是使 x ⊕ 𝑦̂ 看起来与 x ⊕ y 对判别器不可区分(其中 ⊕ 表示连接)。通过这样做,它学会了去除输入图像中的注入噪声。值得注意的是,去噪发生在一个沿着“时间轴”的展开过程中,而不是像扩散模型那样在每个时代内需要一个独立的时间轴。
Wₜ 信号的展平,作为将输入图像去噪到生成器的结果,转化为 Dₜ 所有条目的平坦分布。这正是纳什均衡,判别器处于最大熵状态。
… 扩散在哪里?
被损坏的输入图像在训练时期的演变可以写作
dxₜ = x dWₜ
这可以视为更一般的 SDE 扩散过程的有限差分实例,即 dxₜ = μ(xₜ,t) dt + σ(xₜ,t) dWₜ,前提是 Wₜ 是一个维纳过程(也称标准布朗运动)。
在 ATME 中,没有设计选择使 W 产生维纳过程。这在大量情况下自然发生,从图 3 中可以看出,通过分析在 Maps 数据集上训练过程中 5 个随机选择的图像的 5 个随机选择像素的时间序列 dWₜ 的属性。
首先,维纳过程是平稳的。这是使用扩展的迪基-富勒检验进行测试的,其结果统计量的 p 值都远低于 0.01,因此所有非平稳时间序列的原假设都被拒绝。其次,维纳过程是马尔可夫过程,因此所有 dWₜ 的自相关函数在所有滞后处应当消失。从图 3 中以 99% 的置信水平可以明显看出这一点。


图 3: 选定图像和像素的 dW 时间序列 (上) 和相应的自相关函数 (下)。图像作者提供。
最后,维纳过程具有高斯 dWₜ。这是使用 Shapiro-Wilk 检验进行的测试(在 64% 的情况下),测试统计量的 p 值大于 0.01,因此不能在相当多的情况下拒绝系列来自正态分布的原假设。

图 4: (重缩放) 判别器决策图 (左) 及其相关表示 (中) 和在源图像空间中的变化 (右)。图像作者提供。
随着模型遍历的轮次增加,Dₜ 的条目趋向于平坦分布,从图 4 中观察到这一点,对应于图 3 中处理的一张图像。可以在 ATME 的 GitHub 代码库 中找到包含这些测试的 Jupyter notebook。
如何检查稳定性?
稳定性意味着,无论数据集和模型权重的初始化如何,只要满足足够的数据和模型容量条件,GAN 的目标就会收敛到相同的值(对于普通 GAN 为 -log4)。在 ATME 中通常是这种情况,正如图 5 所观察到的,更多的示例可以在 论文² 中找到。其他流行的 GAN 无法实现这一点,这也是你可能听说过 GAN 模型“难以训练”的原因。

图 5: ATME 中的 GAN 目标趋向于纳什均衡的理论值。图像作者提供。
通过关注判别器的均值熵,ATME 的去噪程序设计为稳定地将 GAN 带到最大熵平衡状态。这是不是很棒?
结语
ATME 在有监督的图像到图像翻译中取得了最先进的结果,成本低于流行的 GAN 和潜在扩散。如果你喜欢物理学,你可能会对 ATME 的思想如何与麦克斯韦妖对第二热力学定律的违反相关感兴趣。你可以在 论文² 中找到这些内容及更多信息。

麦克斯韦妖 (来源: astrogewgaw). 已获得许可。
技术的进步方式令人印象深刻且鼓舞人心。期待看到 GAN 和扩散模型结合后还会带来什么新进展。
目前就是这些!我希望你和我一样喜欢阅读这篇文章😉
[1] Chenfei Wu 等,《视觉 ChatGPT:与视觉基础模型对话、绘图和编辑》,arXiv 2303.04671(2023)。
[2] Edgardo Solano-Carrillo 等,《Look ATME:判别器均值熵需要关注》,arXiv 2304.09024(2023)。
[3] Prafulla Dhariwal, Alexander Nichol,《扩散模型在图像合成中超越 GAN》,神经信息处理系统进展 34(NeurIPS 2021)。
[4] Ian Goodfellow 等,《生成对抗网络》,神经信息处理系统进展 27(NIPS 2014)。
[5] Zhendong Wang 等,《扩散-GAN:用扩散训练 GAN》,arXiv 2303.04671(2022)。
[6] Zhisheng Xiao, Karsten Kreis, Arash Vahdat,《使用去噪扩散 GAN 解决生成学习三难问题》,国际学习表示大会(ICLR 2022)。
[7] Phillip Isola 等,《基于条件对抗网络的图像到图像翻译》,计算机视觉与模式识别会议(CVPR 2017)。
提升模型准确性:我在 Spotify 机器学习论文中学到的技术(+代码片段)
改善顽固 ML 模型的技术数据科学家工具栈
·发布在Towards Data Science ·12 分钟阅读·2023 年 8 月 24 日
--
这篇文章是记录我在 Spotify 机器学习论文中学习内容的两部分之一。请务必查看 第二篇关于我如何在这项研究中实现特征重要性**.
## SHAP 中的特征重要性分析,我在 Spotify 学习到的(在复仇者的帮助下)
使用 SHAP 识别关键特征并了解它们如何影响机器学习模型的预测结果
towardsdatascience.com
在 2021 年,我花了 8 个月时间构建一个预测模型,以测量用户满意度,这是我在 Spotify 论文的一部分。

图片由作者提供
我的目标是理解是什么使用户对他们的音乐体验感到满意。为此,我构建了一个 LightGBM 分类器,其输出是一个二元响应:
*y = 1 → 用户似乎满意
y = 0 → 不怎么满意*
预测人类满意度是一个挑战,因为人类本质上是不满足的。即使是机器也不适合解读人类心理的奥秘。所以,自然地,我的模型也陷入了困惑。
从人类预测者到占卜师
我的准确率约为 0.5,这在分类器中是最糟糕的结果。这意味着算法有 50%的概率预测“是”或“否”,这与人类的猜测一样随机。
所以我花了 2 个月尝试和结合不同的技术来提高模型的预测能力。最后,我终于将 ROC 分数从 0.5 提高到 0.73,这是一个巨大的成功!
在这篇文章中,我将与您分享我用来显著提高模型准确性的技术。当你的模型无法合作时,这篇文章可能会很有用。
由于这项研究的保密性,我不能分享敏感信息,但我会尽力确保内容不令人困惑。
但首先,确保订阅我的通讯!
点击下面的链接,我会发送更多个性化内容和内幕技巧,帮助你成为数据科学家!
## 加入+1k 读者 💌 关注我在科技+Spotify 的数据科学之旅,别错过!
加入+1k 读者 💌 关注我在科技+Spotify 的数据科学之旅,别错过!通过注册,你…
#0. 数据准备
在深入探讨我使用的方法之前,我想确保你首先掌握基础知识。其中一些方法依赖于对变量的编码以及数据的相应准备,以便它们能够正常工作。我包含的一些代码片段也引用了我在数据准备部分创建的用户定义函数,因此请务必检查它们。

这是我实现步骤中的管道顺序
1. 编码变量
确保你的变量已编码:
-
序数特征, 以便模型保留序数信息
-
类别特征, 以便模型能够解释名义数据
所以首先,让我们将变量存储在某个地方。同样,因为研究是保密的,我不能透露使用的数据,所以我们先用这些数据:
region = ['APAC', 'EU', 'NORTHAM', 'MENA', 'AFRICA']
user_type = ['premium', 'free']
ordinal_list = ['region', 'user_type']
然后,确保构建编码变量的函数:
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
def var_encoding(X, cols, ordinal_list, encoding):
#Function to encode ordinal variables
if encoding == 'ordinal_ordered':
encoder = OrdinalEncoder(categories=ordinal_list)
encoder.fit(X.loc[:, cols])
X.loc[:, cols] = encoder.transform(X.loc[:, cols])
#Function to encode categorical variables
elif encoding == 'ordinal_unordered':
encoder = OrdinalEncoder()
encoder.fit(X.loc[:, cols])
X.loc[:, cols] = encoder.transform(X.loc[:, cols])
else:
encoder = OneHotEncoder(handle_unknown='ignore')
encoder.fit(df.loc[:, cols])
df.loc[:, cols] = encoder.transform(df.loc[:, cols])
return X
然后将该函数应用于你的变量列表。这意味着你需要创建包含变量名称的字符串的列表,即为序数变量、类别变量和数值变量分别创建一个列表。
def encoding_vars(X, ordinal_cols, ordinal_list, preprocessing_categoricals=False):
#Encode ordinal variables
df = var_encoding(df, ordinal_cols, ordinal_list, 'ordinal_ordered')
#Encode categorical variables
if preprocessing_categoricals:
df = var_encoding(df, categorical_cols, 'ordinal_unordered')
#Else set your categorical variables as 'category' if needed
else:
for cat in categorical_cols:
X[cat] = X[cat].astype('category')
#Rename your variables as such if needed to keep track of the order
#An encoded feature such as region will no longer show female or male, but 0 or 1
df.rename(columns={'user_type': 'free_0_premium_1'},
df.reset_index(drop=True, inplace=True)
return df
2. 划分数据
划分你的数据框以获得训练集、验证集和测试集:
-
训练集 — 用于在你选择的算法上训练模型,例如 LightGBM
-
验证集 — 用于超调参数和优化预测结果
-
测试集 — 用于进行最终预测
🔊 记住
在我的研究中,我将数据划分了两次以满足不同的目的。第一次划分在一开始进行,以基于用户级别的划分来创建训练集、验证集和测试集。另一种划分则在进行交叉验证和超参数调整时发生。
初始划分允许数据的更灵活和随机分割,确保每个集合中用户的多样性。测试集留作最终模型评估,而训练集和验证集则用于模型开发和超参数调优。
在我的研究中,我使用了**GroupShuffleSplit**,如下:
from sklearn.model_selection import GroupShuffleSplit
def split_df(df, ordinal_cols, ordinal_list, target):
#splitting train and test
splitter = GroupShuffleSplit(test_size=.13, n_splits=2, random_state=7)
split = splitter.split(df, groups=df['user_id'])
train_inds, test_inds = next(split)
train = df.iloc[train_inds]
test = df.iloc[test_inds]
#splitting validation and test
splitter2 = GroupShuffleSplit(test_size=.5, n_splits=2, random_state=7)
split = splitter2.split(test, groups=test['user_id'])
val_inds, test_inds = next(split)
val = test.iloc[val_inds]
test = test.iloc[test_inds]
#defining X and y
X_train = train.drop(['target_variable'], axis=1)
y_train = train.target_variable
X_val = val.drop(['target_variable'], axis=1)
y_val = val.target_variable
X_test = test.drop(['target_variable'], axis=1)
y_test = test.target_variable
#encoding the variables in the sets based on a pre-defined encoding function
X_train = encoding_vars(X_train, ordinal_cols, ordinal_list)
X_val = encoding_vars(X_val, ordinal_cols, ordinal_list)
X_test = encoding_vars(X_test, ordinal_cols, ordinal_list)
return X_train, y_train, X_val, y_val, X_test, y_test
#1. 特征工程
特征工程在提高模型准确性方面产生了巨大差异。
当涉及到用户收听满意度时,我想知道它是更依赖于用户、他们的流媒体行为,还是其他因素。虽然我拥有的初步用户数据有意义,但缺乏足够的信息增益和预测能力。
我优化过程中的最重要步骤变成了创建能够更好地捕捉用户满意度的新特征。
正如名字所示,创建新特征是一个创造性的过程,这意味着你需要坐下来发挥你的领域知识,思考捕捉重要信息的新颖方法。
我在这个过程中使用的两个主要方法是:
-
特征交互。 我所做的最重要的变换是将已经存在的特征组合在一起,以创建比率。
例如:假设我有一个衡量总播放分钟数的特征,还有一个跟踪新发行曲目的总播放分钟数的特征。我可以在这里做的事情是提取来自新发行的播放分钟数,然后将其除以总播放分钟数,以创建“新音乐播放比率”。这捕捉了完全新的信息。
-
特征聚合。 我做的另一件事是对数据进行时间和组的聚合,以创建汇总特征,如均值或标准差。这意味着你可以在不同的时间组上创建相同的特征,但覆盖不同的聚合。
例如:计算过去 7 天、14 天和 30 天每个播放列表每日播放的曲目数量的平均值。瞧,你刚刚解锁了新信息。
🔊 请记住
特征工程也是一个迭代的过程。你可能需要尝试不同的特征组合、变换和技术,以找到适合你特定问题的最佳特征集。
总是用新的特征在单独的验证集上验证模型的表现,以确保改进不是由于过拟合。
#2. 特征选择
因此,我向模型输入了许多特征,但并不真正知道哪些是相关的。我们可能认为变量越多,模型学习得越好,但如果模型从所有内容中学习,包括垃圾,这最终会比任何东西都更有害。
特征过多意味着其中一些可能会给模型引入噪声,这很糟糕,因为它:
-
隐藏数据中的潜在模式或关系。
-
导致过拟合,因为模型从噪声中学习而不是从真实关系中学习。
-
增加复杂性并减慢训练速度。
为了避免所有这些问题,我们可以使用诸如皮尔逊相关系数、递归特征消除或卡方检验等方法追踪罪魁祸首。
在我的案例中,我使用了前两种方法。
皮尔逊相关系数
该系数衡量两个或更多变量之间的线性关系。
它是两个特征协方差与它们标准差乘积的比率。最终输出在-1 到 1 之间,其中 1 表示变量之间的正线性关系,而-1 表示负关系。
皮尔逊相关系数在特征选择中有两个用途:
-
过滤掉最不重要的特征,这些特征往往与目标变量的相关性较低。
-
限制变量之间的多重共线性以避免因数据冗余而产生的过拟合。
为什么使用? 这是一个计算便宜的统计方法,用于捕捉因变量的内在属性。
如何使用? 相关热图指出了变量之间的线性关系。
def corr_matrix(df):
# Select upper triangle of correlation matrix
upper = df.corr().abs().where(np.triu(np.ones(df.corr().abs().shape), k=1).astype(np.bool))
return upper
def cols_todrop(corr_matrix, threshold):
# Find features with correlation greater than x (you pick your threshold)
to_drop = [col for col in corr_matrix.columns if any(corr_matrix[col] > threshold)]
return to_drop
# Get a ranking of top 10 features with the highest correlation based on your threshold
upper = corr_matrix(data)
upper.returned_1d.sort_values(ascending=False)[:10]
# Plot the correlation heatmap
plt.figure(figsize=(16, 6))
heatmap = sns.heatmap(data.corr(), vmin=-1, vmax=1, annot=True, cmap='BrBG')
heatmap.set_title('Correlation Heatmap', fontdict={'fontsize':15}, pad=12)
plt.savefig('heatmap_pre.png', dpi=300, bbox_inches='tight')
plt.show()
🚨 小心非线性关系!
有时变量之间也可能存在非线性关系,这意味着在过滤多重共线性特征时需要小心。
识别非线性关系可以提供更细致和准确的数据洞察,这意味着你可能想保留这些特征。为此,你可以使用诸如斯皮尔曼等级相关、肯德尔秩相关、散点图等替代方法。
递归特征消除
它通过使用重要性算法对特征进行加权和排序,递归地缩小特征范围。从所有特征开始,它适配所选的机器学习模型,对特征进行排序,并用更小的子集迭代,直到达到所需的特征数量(你最初设定的数量)。
为什么使用? 结果是按重要性排序的特征,这使我们能够将预测能力最差的特征踢出局。
🚨 小心编码!
RFE 需要对分类变量进行先前的数值编码才能工作,所以请参考初始部分进行变量编码。
from sklearn.feature_selection import RFE
selector = RFE(model, n_features_to_select=30, step=1)
selector = selector.fit(X_train, y_train)
rfe_vars_keys = list(X_train.columns)
rfe_vars_values = list(selector.ranking_)
rfe_vars = dict(zip(rfe_vars_keys, rfe_vars_values))
sorted(rfe_vars.items(), key=lambda x: x[1])
我在筛选最不重要的特征时,结合了这两种方法的结果:
-
使用皮尔逊相关系数,我没有发现因变量和目标变量之间有强线性关系。因此,我保留了所有特征 (我也害怕删除非线性关系)。
-
使用递归特征消除,我移除了最低排名的特征 (因为为什么不呢)。
#3. 超参数调整
超参数调整是优化机器学习模型时的强制步骤。基本上,这是寻找能够为你的模型提供最佳性能的参数组合的过程。
在我的研究中,我使用了一种结合**GroupKFold**交叉验证和**RandomizedSearchCV**的两步策略进行超参数调优,这是一种最佳组合,因为:
-
示例数据量非常大(30 万用户)。
-
用户数据需要适当地拆分(我们不想在所有分割中都找到 K 的流数据,不不)。
第 1 步: 使用 GroupKFold 防止数据泄漏
我的数据包含了多个用户的记录。由于数据会被用于超参数调优,我需要防止数据泄漏,确保同一用户的信息不会在训练集和验证集中被分开。
最佳的方法是**GroupKFold**,它通过在每次迭代中使用数据集的不同部分,随机将数据分为训练集和验证集。这会创建具有不同且不重叠用户的独立集。
这对于实现可靠的性能评估至关重要,因为你希望模型在完全未见过的用户上进行测试,而不仅仅是训练过程中见过的用户的新数据。
第 2 步: 使用 RandomizedSearchCV 进行高效的超参数调优
我的样本数据大约有 30 万用户,这是在我的计算能力下不会引发系统崩溃的最大规模。使用**RandomizedSearchCV**在这种大样本情况下要高效得多。效果显著。
它不会像传统的网格搜索那样搜索所有可能的超参数组合,而是随机抽取超参数空间的一个子集。然后使用交叉验证评估所选组合的性能。
✨结果
通过结合这两者,我对多个数据子集进行了超参数调优,这些子集的用户不重叠。这样我能够:
-
解决数据泄漏问题
-
确保计算效率
-
实施稳健的超参数选择基础
def grid_search(X, y, groups):
gkf = GroupKFold(n_splits=5).split(X, y, groups)
model = lgb.LGBMClassifier(objective='binary', verbose=-1, max_depth=-1, random_state=314, metric='None', n_estimators=5000)
grid = RandomizedSearchCV(
model, param_grid, scoring='roc_auc', random_state=314,
n_iter=100, cv=gkf, verbose=10, return_train_score=True, n_jobs=-1)
return grid
grid = grid_search(X, y, groups)
%%time
grid.fit(X, y)
#printing the best hyperparameters
best_params = grid.best_params_
在通过**RandomizedSearchCV**和**GroupKFold**确定最佳超参数后,我们使用**GroupShuffleSplit**的初始训练集和验证集来训练最终模型。
还记得我们在这篇文章一开始创建的split_df()函数吗?我们在这一步中使用它来拆分数据。
# We split the data using our initial function
X_train, y_train, X_val, y_val, X_test, y_test = split_df(df, ordinal_cols, ordinal_dfs, target='target_variable')
然后,我们插入通过超参数调优找到的最佳参数。
# We train the model using the best_params that we got from HP Tuning
clf = lgb.LGBMClassifier(objective='binary', max_depth=-1, random_state=314, metric='roc_auc', n_estimators=5000, num_threads=16, verbose=-1,
**best_params)
%%time
clf.fit(X_train, y_train, eval_set=(X_val, y_val), eval_metric='roc_auc')
# Test set evaluates the final performance of the model on unseen users
roc_auc_score(y_test, clf.predict(X_test))
🔊 请记住
我提到这一点是因为在研究过程中这让我很困惑。eval_set用于在训练过程中监控模型在特定验证集上的表现。这与交叉验证不同,交叉验证评估模型在多个训练-验证分割上的泛化能力。
#4. 数据生成
在实施所有之前的步骤后,我的模型仍然需要额外的提升。由于我的数据中某些组的代表性较弱,我的模型在这些组上的泛化能力有些吃力。
所以我确保为所有代表性不足的用户组生成了一个更大的随机样本。这一步骤正好给了我的模型所需的内容,使其能够正确地泛化数据中的所有美好智慧,并做出可靠的预测。
最后的一句话
请记住,优化模型的过程是一个迭代的过程,这意味着你可能需要结合和重复一些方法,直到达到令人满意的性能。
优化方法回顾
-
特征工程 — 使用不同的方法创建新特征,例如特征聚合、转换、时间数据编码、标准化等,可以为数据引入新的信息。
-
特征选择 — 在创建新特征后,评估它们的重要性,并移除那些对模型性能没有贡献的无关或冗余特征。一些方法包括皮尔逊相关系数、递归特征消除或卡方检验。
-
超参数调优 — 使用 GroupKFold 防止数据泄漏,然后以计算上高效的方式使用 RandomisedSearchCV 寻找最佳参数。
-
数据生成 — 确保样本中的各组得到均等代表,如果需要且可能的话,增加样本大小以涵盖更多的数据点。
我为你准备了礼物 🎁!
注册我的 新闻通讯 K 的 DataLadder,你将自动获得我的 终极 SQL 备忘单,其中包含我每天在大科技公司工作中使用的所有查询+另一个秘密礼物!
我每周分享在科技行业作为数据科学家的经历,包括实用技巧、技能和故事,所有这些都是为了帮助你提升水平——因为没有人真正知道,直到他们亲身体验!
如果你还没有这样做的话
不久见!
提升 PyTorch 在 CPU 上的推理:从训练后量化到多线程
Kaggle 蓝图
如何通过巧妙的模型选择、使用 ONNX Runtime 或 OpenVINO 进行训练后量化以及使用 ThreadPoolExecutor 实现多线程来减少 CPU 上的推理时间
·发表在 Towards Data Science ·8 分钟阅读·2023 年 6 月 13 日
--

欢迎来到另一期 “Kaggle 蓝图” ,我们将分析 Kaggle 竞赛的获胜解决方案,以便将其中的经验应用到我们自己的数据科学项目中。
本版将回顾 “BirdCLEF 2023” 竞赛中的技术和方法,该竞赛于 2023 年 5 月结束。
问题陈述:在有限时间和计算约束下的深度学习推理
BirdCLEF 竞赛是一系列每年在 Kaggle 上举行的竞赛。BirdCLEF 竞赛的主要目标通常是通过声音识别特定的鸟类。参赛者会获得单只鸟叫的短音频文件,然后必须预测在较长的录音中是否存在特定的鸟类。
在早期版本的 《Kaggle 蓝图》 中,我们已经回顾了去年的 “BirdCLEF 2022” 竞赛中的音频分类获胜方法。
“BirdCLEF 2023”竞赛中的一个新颖方面是时间和计算限制:竞争者被要求在 2 小时内在 CPU 笔记本上预测大约 200 个 10 分钟的录音**。
识别声音景观中的鸟类鸣叫
你可能会问,为什么有人会选择在 CPU 上推理深度学习模型而不是 GPU。这是一个常见的实际问题声明[4],因为工作人员(尤其是在保护工作中,但也包括其他行业)常常面临预算限制,因此只能使用有限的计算资源。此外,能够快速做出预测是有帮助的。
由于深入探讨如何用深度学习进行音频分类将重复之前《Kaggle Blueprints》版本关于 BirdCLEF 2022 竞赛的内容,我们将专注于如何加速 CPU 上深度学习模型推理的创新方面。
如果你对深度学习音频分类的获胜方法感兴趣,可以查看之前的版本:
微调图像模型以应对领域偏移和类别不平衡,使用 PyTorch 和 torchaudio 处理音频数据
[towardsdatascience.com
深度学习推理在 CPU 上的方法
在有限时间内在 CPU 上进行推理的主要问题是,你无法创建大量强大且多样化的模型来挤出最后几百分点的性能。根据使用的模型,有些竞争者甚至在单个模型下也难以满足有限的时间要求。
然而,通常较弱的模型的集合会比单一强大的模型表现更好。在竞赛总结中,成功的竞争者分享了一些技巧,说明他们如何加速在 CPU 上的推理,以便能够集合多个模型。
本文介绍了在总结中分享的一些技巧:
-
模型选择
-
训练后量化
-
多线程
模型选择
模型大小严重影响推理时间。一般规则是:模型越大,推理时间越长。
一般规则是:模型越大,推理时间越长。
因此,在为其模型集合选择骨干网络时,竞争者必须评估哪些模型在性能和推理时间之间提供了最佳的折衷。
尽管 NFNet (eca_nfnet_l0) [3, 5, 7, 9, 10, 11, 13, 14, 16] 和 EfficientNet 在去年和今年的比赛中仍然是流行的骨干网络,但我们可以看到,在今年的比赛中,竞争者更倾向于使用 EfficientNet 的小版本。
虽然在BirdCLEF 2022比赛中,tf_efficientnet_b0_ns [8, 11]、tf_efficientnet_b3_ns [8]、tf_efficientnetv2_s_in21k [11, 16]和tf_efficientnetv2_m_in21k [13]非常受欢迎,但今年较小版本的tf_efficientnet_b0_ns [1, 5, 6, 7, 10]和tf_efficientnetv2_s_in21k [1, 6, 15]更受青睐。
以下是 BirdCLEF 竞赛系列中流行模型的参数数量对比。

不同神经网络架构的参数数量。
因此,我们可以看到,成功的竞争者利用了大型模型(eca_nfnet_l0)与小型模型(例如,tf_efficientnet_b0_ns)的组合。
后训练量化
提高 CPU 推理速度的另一个技巧是对训练后的模型应用量化:后训练量化将模型的权重和激活的精度从浮点精度(32 位)降低到较低位宽表示(例如 8 位)。
该技术将模型转化为更适合硬件的表示,从而提高延迟性能。然而,由于权重和激活表示的精度损失,它也可能导致轻微的性能下降。
量化与硬件密切相关。例如,一个 Kaggle 笔记本具有4 个 CPU(Intel(R) Xeon(R) CPU @ 2.20GHz,x86_64 架构)。这些带有 x86 架构的 Intel CPU 更倾向于使用量化数据类型 [INT8](https://www.intel.com/content/www/us/en/developer/articles/technical/accelerate-pytorch-int8-inf-with-new-x86-backend.html)。
提示:要显示有关 CPU 架构的信息,请运行 [*lscpu*](https://man7.org/linux/man-pages/man1/lscpu.1.html) 命令,然后检查制造商的主页,以查看特定 CPU 偏好的量化输入数据类型。
关于后训练量化的详细解释以及 ONNX Runtime 和 OpenVINO 的比较,我推荐这篇文章:
## OpenVINO 与 ONNX 在生产中的 Transformers
Transformers 已经彻底改变了自然语言处理(NLP),成为机器翻译、语义理解等应用的首选。
本节将特别关注两种流行的训练后量化技术:
-
ONNX Runtime
-
OpenVINO
ONNX Runtime
加速 CPU 推理的一个流行方法是将最终模型转换为 ONNX(开放神经网络交换)格式 [2, 7, 9, 10, 14, 15]。
使用 ONNX Runtime 量化和加速 CPU 推理的相关步骤如下:
准备: 安装 ONNX Runtime
pip install onnxruntime
第 1 步: 将 PyTorch 模型转换为 ONNX
import torch
import torchvision
# Define your model here
model = ...
# Train model here
...
# Define dummy_input
dummy_input = torch.randn(1, N_CHANNELS, IMG_WIDTH, IMG_HEIGHT, device="cuda")
# Export PyTorch model to ONNX format
torch.onnx.export(model, dummy_input, "model.onnx")
第 2 步: 使用 ONNX Runtime 会话进行预测
import onnxruntime as rt
# Define X_test with shape (BATCH_SIZE, N_CHANNELS, IMG_WIDTH, IMG_HEIGHT)
X_test = ...
# Define ONNX Runtime session
sess = rt.InferenceSession("model.onnx")
# Make prediction
y_pred = sess.run([], {'input' : X_test})[0]
OpenVINO
另一个同样流行的加速 CPU 推理的方法是使用 OpenVINO(开放视觉推理和神经网络优化)[5, 6, 12],如 这个 Kaggle Notebook 所示:
使用 Kaggle Notebooks 探索和运行机器学习代码 | 使用来自多个数据源的数据
使用 OpenVINO 量化和加速深度学习模型的相关步骤如下:
准备: 安装 OpenVINO
!pip install openvino-dev[onnx]
第 1 步: 将 PyTorch 模型转换为 ONNX(见 ONNX Runtime 的第 1 步)
第 2 步: 将 ONNX 模型转换为 OpenVINO
mo --input_model model.onnx
这将输出一个 XML 文件和一个 BIN 文件——我们将在下一步中使用 XML 文件。
第 3 步: 使用 OpenVINO 量化为 INT8
import openvino.runtime as ov
core = ov.Core()
openvino_model = core.read_model(model='model.xml')
compiled_model = core.compile_model(openvino_model, device_name="CPU")
第 4 步: 使用 OpenVINO 推理请求进行预测
# Define X_test with shape (BATCH_SIZE, N_CHANNELS, IMG_WIDTH, IMG_HEIGHT)
X_test = ...
# Create inference request
infer_request = compiled_model.create_infer_request()
# Make prediction
y_pred = infer_request.infer(inputs=[X_test, 2])
比较:ONNX 与 OpenVINO 与其他替代方案
ONNX 和 OpenVINO 都是优化用于在 CPU 上部署模型的框架。量化的神经网络在 ONNX 和 OpenVINO 上的推理时间被认为是相当的 [12]。
一些竞争对手使用 PyTorch JIT [3] 或 TorchScript [1] 作为加速 CPU 推理的替代方案。然而,其他竞争对手表示 ONNX 比 TorchScript [10] 快得多。
使用 ThreadPoolExecutor 进行多线程
另一种加快 CPU 推断的流行方法是使用多线程和ThreadPoolExecutor [2, 3, 9, 15],此外还包括后训练量化,如这个 Kaggle 笔记本所示:
在 Kaggle 笔记本中探索和运行机器学习代码 | 使用来自多个数据源的数据
www.kaggle.com](https://www.kaggle.com/code/leonshangguan/faster-eb0-sed-model-inference?source=post_page-----6820ac7349bb--------------------------------)
这使得参赛者能够同时运行多个推断。
在比赛中的ThreadPoolExecutor示例中,我们有一个音频文件列表需要推断。
audios = ['audio_1.ogg',
'audio_2.ogg',
# ...,
'audio_n.ogg',]
接下来,你需要定义一个推断函数,该函数以音频文件作为输入并返回预测结果。
def predict(audio_path):
# Define any preprocessing of the audio file here
...
# Make predictions
...
return predictions
有了音频列表(例如audios)和推断函数(例如predict()),你现在可以使用ThreadPoolExecutor来并行运行多个推断,这将显著提高推断速度。
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
dicts = list(executor.map(predict, audios))
总结
从审查 Kagglers 在“BirdCLEF 2023”比赛过程中创建的学习资源中,可以学习到许多更多的课程。对于这种类型的问题声明,还有许多不同的解决方案。
在本文中,我们关注了许多参赛者所采用的通用方法:
-
模型选择:根据性能和推断时间之间的最佳权衡选择模型大小。同时,在你的集成中利用更大和更小的模型。
-
后训练量化:后训练量化可以通过将模型权重和激活的 datatype 优化为硬件来加快推断时间。然而,这可能会导致模型性能的轻微下降。
-
多线程:并行运行多个推断,而不是按顺序进行。这将提高你的推断时间。
如果你对如何使用深度学习进行音频分类感兴趣,这是本次比赛的主要内容,请查看BirdCLEF 2022比赛的总结:
使用 PyTorch 和 torchaudio 在音频数据中调整图像模型以应对领域偏移和类别不平衡
[towardsdatascience.com
喜欢这个故事吗?
免费订阅 以便在我发布新故事时收到通知。
medium.com/@iamleonie/subscribe?source=post_page-----6820ac7349bb-------------------------------- [## 每当 Leonie Monigatti 发布新内容时,会收到电子邮件通知。]
每当 Leonie Monigatti 发布新内容时,会收到电子邮件通知。通过注册,如果你尚未拥有 Medium 帐户,将创建一个。
在 LinkedIn、Twitter 和 Kaggle 上找到我!
参考文献
图像参考
除非另有说明,所有图像均由作者创作。
网络与文献
[1] adsr (2023). 第 3 名解决方案:在 Mel 频带上的 SED 和注意力 在 Kaggle 讨论中(访问日期:2023 年 6 月 1 日)
[2] anonamename (2023). 第 6 名解决方案:BirdNET 嵌入 + CNN 在 Kaggle 讨论中(访问日期:2023 年 6 月 1 日)
[3] atfujita (2023). 第 4 名解决方案:知识蒸馏是你所需要的一切 在 Kaggle 讨论中(访问日期:2023 年 6 月 1 日)
[4] beluga (2023). 推理约束 — CPU 笔记本 <= 120 分钟(访问日期:2023 年 3 月 27 日)。
[5] Harshit Sheoran (2023). 第 9 名解决方案:7 个 CNN 模型组合 在 Kaggle 讨论中(访问日期:2023 年 6 月 1 日)
[6] HONG LIHANG (2023). 第二名解决方案:SED + CNN 结合 7 模型 在 Kaggle 讨论中(访问日期:2023 年 6 月 1 日)
[7] HyeongChan Kim (2023). 第 24 名解决方案 — 预训练 & 单模型(5 折组合与 ONNX) 在 Kaggle 讨论中(访问日期:2023 年 6 月 1 日)
[8] LeonShangguan (2022). [公开 #1 私密 #2] + [私密 #7/8(潜在)] 解决方案。主持人获胜。 在 Kaggle 讨论中(访问日期:2023 年 3 月 13 日)
[9] LeonShangguan (2023). 第 10 名解决方案 在 Kaggle 讨论中(访问日期:2023 年 6 月 1 日)
[10] moritake04 (2023). 第 20 名解决方案:使用 onnx 的 SED + CNN 组合 在 Kaggle 讨论中(访问日期:2023 年 6 月 1 日)
[11] slime (2022). 第 3 名解决方案 在 Kaggle 讨论区(访问日期:2023 年 3 月 13 日)
[12] storm (2023). 第 7 名解决方案 — sumix 增强做了所有的工作 在 Kaggle 讨论区(访问日期:2023 年 6 月 1 日)
[13] Volodymyr (2022). 第 1 名解决方案模型(这不是全部的 BirdNet) 在 Kaggle 讨论区(访问日期:2023 年 3 月 13 日)
[14] Volodymyr (2023). 第 1 名解决方案:正确的数据就是你所需要的一切 在 Kaggle 讨论区(访问日期:2023 年 6 月 1 日)
[15] Yevhenii Maslov (2023). 第 5 名解决方案 在 Kaggle 讨论区(访问日期:2023 年 6 月 1 日)
[16] yokuyama (2022). 第 5 名解决方案 在 Kaggle 讨论区(访问日期:2023 年 3 月 13 日)
提升 Spark 联合运算符性能:改进查询速度的优化技巧
解密 Spark 联合运算符的性能
·发表于 Towards Data Science ·阅读时间 6 分钟·2023 年 4 月 20 日
--

图片由 Fahrul Azmi 提供,来源于 Unsplash
联合运算符是将两个输入数据框合并为一个的集合运算符之一。联合操作在 Apache Spark 中是一个方便的操作,用于合并具有相同列顺序的行。一个常见的用例是应用不同的转换,然后将它们联合在一起。
在 Spark 中使用联合操作的方法常常被广泛讨论。然而,一个较少被讨论的隐藏事实是与联合运算符相关的性能陷阱。如果我们不理解 Spark 中联合运算符的陷阱,可能会陷入将执行时间翻倍的陷阱。
我们将重点讲解 Apache Spark DataFrame 的联合运算符,提供示例,展示物理查询计划,并分享优化技巧。
Spark 中的 Union Operator 101
像关系型数据库(RDBMS)SQL 一样,联合是一种直接合并行的方式。处理联合运算符时要注意的一点是确保行遵循相同的结构:
-
列的数量应该相同。当数据框的列数量不同,联合操作不会静默地工作或用 NULL 填充。
-
列的数据类型应匹配,并按位置解析列。列名在每个数据框中应遵循相同的顺序。然而,这不是强制性的。第一个数据框将被选择为列名的默认值。因此,混合顺序可能会导致意外结果。Spark 的
unionByName旨在解决这个问题。
在 Spark 中,操作unionAll是union的别名,不会去除重复项。我们需要在执行联合后添加 distinct,以进行无重复项的 SQL 类似联合操作。
我们也可以将多个数据框合并成一个最终的数据框。
df = df1.union(df2).union(df3)
联合操作符的性能瓶颈
使用联合操作符的一个典型模式是将单个数据框拆分为多个数据框,然后应用不同的转换,最后将它们合并成最终的数据框。
这里有一个示例:我们有两个需要连接的大表(事实表),最好的连接方式是在 Spark 中使用 SortMerged 连接。一旦得到 SortMerged 数据框,我们将其拆分成四个子集。每个子集使用不同的转换,最终将这四个数据框合并成一个最终的数据框。

Spark 中的联合操作 | 图片由作者提供
Spark 数据框利用 Catalyst 优化器,它会对你编写的数据框代码进行代码分析、逻辑优化、物理规划和代码生成。Catalyst 尝试创建一个执行 Spark 作业的最优计划。
近年来,Spark 在 Catalyst 上进行了大量优化,以提升 Spark 连接操作的性能。连接操作的应用场景比联合操作更多,因此对联合操作的优化投入较少。
如果用户不在完全不同的数据源上使用联合操作,联合操作符将面临潜在的性能瓶颈——Catalyst 并不“聪明”到能识别共享的数据框以进行重用。
在这种情况下,Spark 会将每个数据框视为独立的分支,然后从根节点多次执行所有操作。在我们的示例中,我们将对两个大表进行四次连接!这是一个巨大的瓶颈。
在 Spark 中设置一个使用联合操作符的示例
在 Spark 中重现一个未优化的物理查询计划对于联合操作符来说非常简单。我们将执行以下操作
-
创建两个数据框,从 1 到 1000000。我们称它们为
df1和df2 -
对
df1和df2进行内连接 -
将连接结果拆分成两个数据框:一个仅包含奇数,另一个包含偶数。
-
添加一个名为
magic_value的转换字段,该字段由两个虚拟转换生成。 -
将奇数和偶数数据框进行联合
## Create two data frames from 1 to 1000000\. Let's call them df1 and df2
df1 = spark.createDataFrame([i for i in range(1000000)], IntegerType())
df2 = spark.createDataFrame([i for i in range(1000000)], IntegerType())
## Perform inner join on df1 and df2
df = df1.join(df2, how="inner", on="value")
## Split the joined result into two data frames: one only contains the odd numbers, another one for the even numbers
df_odd = df.filter(df.value % 2 == 1)
df_even = df.filter(df.value % 2 == 0)
## Add a transformation with a field called magic_value which is generated by two dummy transformations.
df_odd = df_odd.withColumn("magic_value", df.value+1)
df_even = df_even.withColumn("magic_value", df.value/2)
## Union the odd and even number data frames
df_odd.union(df_even).count()
这是 DAG 的高级视图。如果从底部向上查看,显著的一个点是连接操作发生了两次,而且上游几乎看起来一模一样。
我们已经看到 Spark 需要广泛优化联合操作符,如果数据源可以重用,大量时间将被浪费在不必要的重新计算上。

未优化查询计划的 DAG | 图片由作者提供
这是一个物理计划,其中有 50 个阶段,启用了 AQE。我们可以看到 id 13 和 27。Spark 确实在每个分支上执行了两次连接并重新计算了它的分支。

未优化的联合操作物理查询计划 | 图片由作者提供
如何提高联合操作的性能
现在我们可以看到这个潜在的瓶颈。我们该如何解决这个问题?一个选项是将执行器的数量翻倍,以运行更多的并发任务。但有一个更好的方法是提示 Catalyst 并让它重用内存中的连接数据框。
为了解决 Spark 联合操作的性能问题,我们可以显式调用 **cache** 来将连接的数据框保留在内存中。 这样 Catalyst 就知道获取数据的快捷方式,而不是返回数据源。
应该在哪里添加 cache()?推荐的位置是在过滤之前和连接完成之后的数据框。
让我们看看它的实际效果:
# ...........................
## Perform inner join on df1 and df2
df = df1.join(df2, how="inner", on="value")
## add cache here
df.cache()
## Split the joined result into two data frames: one only contains the odd numbers, another one for the even numbers
df_odd = df.filter(df.value % 2 == 1)
# ...........................
这是查询计划:InMemoryTableScan 存在,因此我们可以重用数据框以节省其他计算。

优化查询计划的 DAG | 图片由作者提供
现在物理计划减少到仅 32 个阶段,如果我们检查 id 1 和 15 都利用了InMemoryTableScan。 如果我们将原始数据框分割成较小的数据集,然后将它们联合,这可以节省更多时间。

优化的联合操作物理查询计划 | 图片由作者提供
最终思考
我希望这个故事能够提供一些见解,解释为什么有时联合操作会成为 Spark 性能的瓶颈。由于 Catalyst 对 Spark 中联合操作符缺乏优化,用户需要了解这些警示,以更有效地开发 Spark 代码。
添加缓存可以节省我们例子中的时间,但如果联合操作是在两个完全不同的数据源上进行,并且没有共享的地方来执行缓存,这将无济于事。
石崎和明的演讲激发了这个故事——告别 Spark SQL 中的联合地狱,以及我在项目中处理类似问题的经验。
ps:如果你对如何处理 Spark 性能中的数据倾斜感兴趣,我在 TDS 上有另一个相关的故事。
处理分布式计算中的数据倾斜的终极指南
使用大型语言模型提升表格数据预测

作者提供的图像
当你将 GPT-4 应用于表格型 Kaggle 竞赛以预测房价时会发生什么?
·
关注 发表在 Towards Data Science ·9 min read·2023 年 4 月 6 日
--
跟随本博客的 配套 Colab.
本博客与 Arize AI 的 CEO 和联合创始人 Jason Lopatecki 以及 Decision Patterns 的 CEO 和创始人 Christopher Brown 合作完成
引言
在机器学习生态系统中,有两个明显不同的群体。一类处理的是高度组织化的表格数据——即以表格数据为重点的数据科学家。另一类则致力于深度学习应用,包括视觉、音频、大型语言模型(LLMs)等。为了本文的目的,我们称前者为“表格”或“传统”群体,后者为“LLM”群体。每个群体都使用自己独特的技术和模型,这些技术和模型在很大程度上是独立发展的。鉴于大型语言模型如 OpenAI 的 GPT-4 等最近的成功,我们希望看看是否可以利用现代 LLM 的结果来帮助对表格数据集进行预测。
为了证明这种方法的有效性,我们向几个盲测的 Kaggle 竞赛提交了结果(包括受欢迎的“房价——高级回归技术”竞赛)。典型的 Kaggle 竞赛提供表格数据,并且通常由传统的机器学习方法主导。然而,我们发现,即便在几乎没有背景知识、无需数据清洗和特征开发的情况下,LLMs 也能够返回具有预测能力的结果。LLM 的预测结果虽然无法与经过长时间和广泛的表格方法得到的领先模型竞争,但足够强大,能够在排行榜上显著高于中位数分数。
我们预计这将是多种使用 LLM 处理表格数据技术的开端,不会感到惊讶如果看到它们的使用范围扩大,并与更传统的模型开发过程进行有利的竞争。
本文中包括了我们见过的第一个将传统表格数据集和 XGBoost 模型与 LLM 使用潜在结构嵌入相结合的方法,使得表格方法能够基于 LLM 内部生成的数值“特征”进行工作。
到目前为止,我们还没有见过以这种方式使用 LLM 的案例,希望这将是一个令人兴奋的开端。
将深度学习应用于表格数据的挑战
典型的机器学习应用涉及清洗和训练通常由组织收集、保存或获取的狭窄数据集。从高层次来看,这个过程可以被认为是开发一个“上下文”,在这个上下文中只能提出特定类型的问题。当这种类型的问题出现时,机器学习模型会生成一个或多个预测。进一步改善模型来自三个方面:增加数据、改进方法或获取更多不同的特征。最后一点通常最为有趣,因为数据科学家通常会问自己:“我能获取哪些不同的数据来改善我的预测?”
分区、提升和/或袋装模型已经在这个领域得到了开发并表现出色。尽管付出了很多努力,深度学习在这一领域的效果却没有那么显著。观察显示,XGBoost 及其类似模型在生产中更具泛化能力,而深度学习模型往往容易过拟合。许多团队尝试改进在表格数据集上的深度学习,但这些努力大多落后于已建立的高性能表格方法。
使用狭窄数据进行训练的问题
一种常见的深度学习方法是将神经网络和多层感知器(MLP)应用于由组织数据组成的相对“小”数据集。这个方法已多次证明需要更多的工作(数据科学家时间)、资源消耗(训练时间)和参数调优,以获得类似或更差的性能,相比于表格方法。深度学习的失败可能是由于方法与其可用的狭窄数据之间的不匹配。深度学习似乎在从狭窄数据中学习的能力上受到了一定的限制。

图片来源于作者
在上图中,一个大型参数神经网络模型在一个单一公司的“小数据”上进行了训练。在相对较小的数据集上训练大型模型会导致模型几乎总是过度参数化。这是因为用于做出决策的信息量不大,与数据相关的“错误面”也有限。
将大型语言模型应用于表格数据集:进入提示工程
大型语言模型(LLMs)通过两项创新走到了前台。第一项是由谷歌等公司首创的变压器架构。第二项是将这些架构应用于规模达到数十或数百 TB 的巨大数据集。一个合理的假设是,这些 LLMs 能够绕过困扰深度学习方法的“狭窄”数据问题。通过在互联网规模的数据上进行训练,LLMs 构建了许多应用背景的内部表示。这是为了拥有一个能够回应各种提示的模型所必需的步骤。一个令人愉快且必要的结果是,LLMs 可能已经发展出了回答与组织预测问题或 Kaggle 竞赛相关的问题的背景。
类比而言,LLMs 以类似于传统/表格机器学习在其训练步骤中理解问题背景的方式来理解你的问题。令人惊讶的是,这些模型使用的是更广泛的数据来源,而不是组织特定的数据。另一种看法是,LLMs 训练了一个能够从其他地方获得的数据进行预测的模型。对于数据科学家来说,这提供了对多样化数据集和潜在信息宝库的访问——或者它们可能只是噪声。
解锁 LLM 在表格模型中的信息面临两个障碍。第一个是 LLM 是通过提示而不是表格数据(DataFrames)来访问的。第二个是 LLM 主要生成文本输出。
为了克服第一个障碍,我们通过提示提供表格数据。这里的提示是在表格的每一行上创建的。提示的构造相当简单:一个由句子组成的段落,每个句子对应表格行中的一个单元格,如下所示:

图片来源:作者
“行”提示由以下内容组成:
“<列名> 是 <单元格值>。<列名> 是 <单元格值>。……”
需要注意两点:
-
不必为训练数据生成提示,只需生成关于需要进行预测的数据的提示;
-
不必严格要求对数据进行预测。
第二个障碍是 LLM 生成文本响应。在某些情况下,LLM 可以提供预测。目前,预测效果不是很好——可能是因为 LLM 没有针对特定预测进行训练。与其访问 LLM 预测,我们更倾向于使用 LLM 生成的特征。在 LLM 的术语中,这些特征是潜在结构嵌入或简单的“嵌入”。这些嵌入可以通过 LLM API 访问。需要注意的是,嵌入向量通常是按行值的。一旦我们提取了嵌入,就可以将其通过表格模型(XGBoost)进行处理。
嵌入将用于两个示例:首先,是在 Kaggle 数据竞赛中预测房价(这个博客);其次,是通过嵌入漂移来测量多变量漂移和异常检测(即将发布的博客)。

工作流程:表格 -> 提示 -> LLM -> 嵌入 -> XGboost -> 预测(图片来源:作者)
LLM 提供了一个极为简单的特征工程工具,可用于任何表格数据集,本质上可以在几乎不进行特征工程或参数化的情况下提升性能。
朝着利用大数据和大型语言模型的方向前进
在大多数公司中,你所训练的数据相对于互联网上的信息来说是很小的。
让我们以房价预测为例。假设你训练的模型从数据中学习到某个邮政编码的房价较高。它还可能学习到一些与其他特征的有趣关系,例如,带游泳池的房子售价更高。但是,试想一下,如果通过收集全球关于该邮政编码的所有知识,并将其与当前定价模型并行应用,你能对该邮政编码的房价说些什么?

作者提供的图片
大型语言模型在大量数据上进行训练,并通过这些数据学习结构和关系。在内部,它们学习嵌入/激活空间中的流形和表面,这些流形和表面与概念和知识相关,可以应用于几乎所有事物。
结果
以下方法可以应用于任何 pandas 数据框,无需特征工程。在这种情况下,结果为 0.14 RLMSE,使我们在结果中处于一个值得尊敬的位置,且付出了很少的努力。
它是如何工作的?数据流经模型生成嵌入,嵌入表示提示中的数据。这些嵌入表示流经模型的数据的潜在结构。它们捕捉了大量的训练数据,然后将其投影到我们在表格数据集中查看的特定数据上。

作者提供的图片
上面的图片展示了 Kaggle 房价数据集的嵌入空间。你可以通过附带的Colab 生成这个视图。在这里,预测被映射到 Kaggle 数据集 的 UMAP 视图上。

作者提供的图片
LLM 的嵌入具有巨大的结构,这些结构通过 UMAP 视图进行表示。可以通过将原始特征映射到数据中的预测上来查看一些包含的信息。
最终房价预测
为了进行预测,嵌入值从 LLM 流出,并传送到在嵌入潜在结构空间上训练的 XGBoost 模型。XGBoost 模型然后预测房价。这一切都无需特征工程。
为什么不使用 VAE?
我们开始这个过程时考虑将变分自编码器(VAEs)应用于表格数据。然而,我们最终发现 VAE 训练的数据量太少,并且对参数的敏感性太高,无法生成有用的价值。
GPT-4
我们还测试了将数据直接粘贴到 GPT-4 中,以查看其在没有直接对数据进行训练的情况下对数据集进行预测的效果。在许多情况下,结果令人印象深刻。我们预计会有方法将表格数据更直接地连接到 GPT-4,以帮助直接进行这些预测。
“以下是一组行和列数据:
Id MSSubClass MSZoning …
1 60 RL …”
一旦数据的行和列被粘贴到上下文中,你可以对数据进行提问。

查找示例(作者提供的图片)
它确实理解数据,以下是数据集中确切 ID 的查找。

预提示(作者提供的图片)
上述示例是在粘贴数据以进行预测之前的预提示。

GPT-4 预测(作者提供的图片)
上述 GPT-4 预测在没有模型或训练数据的情况下已经相当不错。实际销售价格为 130,250 美元。显然,未来将会有方法将数据与 GPT-4 连接,并在没有任何训练或结合最新技术(SOTA)的情况下获得相当不错的估计。
展望未来
鉴于优化这些结果所投入的努力非常少,因此可以合理预期,大型语言模型(LLMs)将在越来越多的环境中用于表格数据。LLMs 很可能会在某个不久的将来,在小数据集上优于传统技术。随着 LLMs 和提示工程的不断发展,数据科学的许多领域,包括表格数据问题,也不可避免地受到影响。
Bootstrap(自助法)入门
原文:
towardsdatascience.com/bootstrap-tests-for-beginners-5fd9b092e7a6
非参数检验入门第二部分
·发布在 Towards Data Science ·10 分钟阅读·2023 年 6 月 19 日
--

照片来源:Mohamed Nohassi 在 Unsplash
在本系列的第一部分中,我介绍了简单的秩检验和符号检验,作为非参数检验的入门。正如第一部分所述,bootstrap(自助法)也是一种流行的非参数统计推断方法,基于对观测数据的重新抽样。自从Bradley Efron 在 1980 年代首次介绍以来,它在学术界获得了广泛的关注。Efron 和 Tibshirani (1994) 提供了关于 bootstrap 方法的介绍性和全面的调查。该方法在统计科学领域得到了广泛应用,以上书籍至今已获得超过 50,000 次 Google Scholar 引用。
在这篇文章中,我以直观的方式介绍了自助法入门,提供了简单的示例和 R 代码。
介绍
如第一部分所述,假设检验的关键要素包括
-
虚无假设和备择假设(H0 和 H1)
-
检验统计量
-
H0 下检验统计量的抽样分布
-
决策规则(p 值或临界值,在给定显著性水平下)
在生成检验统计量的抽样分布时,
-
参数检验(如 t 检验或 F 检验)假设总体遵循正态分布。如果总体不是正态的,则利用中心极限定理(称为渐近正态近似)将正态分布作为抽样分布的近似;
-
秩和符号检验使用数据点的秩和符号生成精确的抽样分布,如第一部分所讨论;
-
自助法生成或近似统计量的抽样分布,基于对观察数据的重新抽样(有放回),类似于从总体中随机和重复抽取样本的方式。
-
与秩和符号检验一样,自助法不要求总体正态性或基于中心极限定理的渐进正态近似。
-
自助法在其基本形式中要求从固定均值和方差的总体中进行纯随机抽样(无需正态性),尽管也有适用于依赖或异方差数据的自助法方法。
本文介绍了从总体随机生成数据的基本自助法,并附有示例。针对更一般数据结构的自助法,其简要细节和 R 资源在另一个部分中呈现。
自助法的玩具示例
示例 1:X = (1, 2, 3)
假设研究者观察到数据集 X = (1, 2, 3),样本均值为 2,标准差(s)为 1。假设总体为正态,总体均值为 2 的假设下,样本均值(Xbar)的抽样分布为
其中s = 1,μ为总体均值。这意味着,在正态近似下,样本均值遵循均值为 2、方差为 1/3 的正态分布。
自助法以有放回的方式重新抽样观察到的数据 X = (1, 2, 3),给其成员以 1/3 的等概率。下表 1 展示了这些重新抽样(或伪数据)X* = (X1, X2, X3*)的所有 27 种可能结果及其均值。

表 1:从 X 中有放回的抽样(图像由作者创建)
这 27 个结果的均值为 2,方差为 0.23。来自这些 X*的样本均值的分布表示了精确的自助法分布,如下图 1 所示:

图 1:精确的自助法分布及其密度估计(图像由作者创建)
左侧的条形图显示了精确的自助法分布,而自助法分布的核密度估计(红色)与均值为 2、方差为 1/3 的正态分布(黑色)一起绘制。
示例 2:X = (1, 2, 6)
现在考虑 X = (1, 2, 6)的情况,样本均值为 3,s = 2.65。与表 1 中的类似计算显示 X的均值为 3,方差为 1.62。精确的自助法分布在下图 2 中绘制,伴随一个核密度估计(红色),显示出与均值为 3、方差为s*²/n 的正态分布(黑色)的明显偏离。

图 1:精确的自助法分布及其密度估计(图像由作者创建)
从这两个示例中,我们可以得出以下几点:
-
示例 1 是数据集 X 在其均值处完全对称的情况。样本均值的自助法抽样分布也对称,且很好地由正态分布近似。
-
示例 2 是数据集 X 在其均值处不对称的情况,这一点在自助法抽样分布的形状中得到了很好的反映。然而,正态分布无法反映这种不对称。
-
鉴于这些示例中的总体分布是未知的,很难评估自助法分布是否更好地代表了样本均值的真实抽样分布。
-
然而,我们观察到自助法具有反映总体分布中可能存在的偏斜的能力,而渐进正态近似无法捕捉到这一点。
请注意,自助法能够捕捉到样本总体的许多非正态特性,如偏斜、厚尾和双峰性,这些特性是正态近似无法捕捉到的。
许多比较自助法和渐进正态近似的学术研究提供了强有力的证据,表明自助法通常表现更好,能够捕捉到真实抽样分布的特征,特别是在样本量较小时。研究报告指出,随着样本量的增加,两种方法显示出类似的特性,这意味着在样本量较小时,通常应该优先选择自助法。
对于 X = (X1, …, Xn) 进行自助抽样
上述示例展示了 n = 3 的情况,我们能够获得所有 27 种可能重抽样的精确自助法分布。注意到所有可能重抽样的数量为 nⁿ,对于一般的 n 值,计算 nⁿ 次重抽样的精确自助法分布可能计算负担过重。然而,这一过程并非必要,因为蒙特卡罗模拟可以提供对精确自助法分布的相当准确的近似。
假设数据 X 是从具有固定均值和方差的总体中随机获取的。假设我们关注的统计量,例如样本均值或 t 统计量,记作 T(X)。然后,
-
我们通过从 X 中有放回地随机抽样得到 X* = (X₁, …, Xₙ),每个 X 成员的概率相等。
-
由于我们不能对所有可能的 nⁿ 次重抽样进行操作,我们重复上述过程足够多次 B,比如 1000、5000 或 10000 次。通过这样做,我们得到 B 组不同的 X,记作 {X(i)},其中 i = 1, …, B。
-
从每个 X(i) 中计算关注的统计量 [T(X)]。然后我们得到 {T(X,i)} (i = 1,…., B),其中 T(X,i) 是从 X(i) 计算的 T(X*)。
自助法分布 {T(X*,i)} 被用作精确自助法分布以及 T 的未知抽样分布的近似。
举例来说,我生成了 X = (X1, …, X20) 从
-
具有 2 和 10 自由度的 F 分布 [F(2,10)],
-
具有 3 自由度的卡方分布 [chisq(3)],
-
具有 3 自由度的 Student-t 分布 [t(3)],以及
-
对数正态分布,均值为 0,方差为 1(对数正态)。
下图 3 绘制了 {T(X*,i)}(i = 1,…., B)的密度估计,其中 T 是均值,B = 10000,与正态分布的密度进行比较,正态分布的均值和方差值与 X 对应。自助法分布可能与正态分布不同,特别是当基础总体分布明显偏离正态分布时。

图 3:自助法分布(红色)与正态分布(黑色)(图片由作者制作)
下面给出了上述 Monte Carlo 模拟和图形的 R 代码:
n=20 # Sample size
set.seed(1234)
pop = "lognorm" # population type
if (pop=="F(2,10)") x=rf(n,df1=2,df2=10)
if (pop=="chisq(3)") x=rchisq(n,df=3)
if (pop=="t(3)") x=rt(n,df=3)
if (pop=="lognorm") x=rlnorm(n)
# Bootstrapping sample mean
B=10000 # number of bootstrap iterations
stat=matrix(NA,nrow=B)
for(i in 1:B){
xboot=sample(x,size=n,replace = TRUE)
stat[i,] = mean(xboot)
}
# Plots
plot(density(stat),col="red",lwd=2,main=pop,xlab="")
m=mean(x); s=sd(x)/sqrt(n)
curve(dnorm(x,mean=m,sd=s),add=TRUE, yaxt="n")
rug(stat)
自助法检验和分析是基于上述红色曲线 {T(X*,i)},而非黑色的正态分布。
-
推断统计量,如置信区间或 p 值,是通过 {T(X*,i)} 获得的,与使用正态分布时的计算方式相同。
-
自助法分布可以揭示更多和更详细的信息,如对称性、胖尾、非正态性、双峰性以及异常值的存在,关于总体的特性。
假设 T(X) 是上述的样本均值。
对于总体均值的 自助法置信区间 可以通过取 {T(X,i)} 的适当百分位数来获得。例如,设 {T(X,i;θ)} 为 {T(X,i)} 的 θ 分位数。那么,95% 的自助法置信区间为 [{T(X,i;2.5)},{T(X*,i;97.5)}]。
假设 T(X) 是 H0: μ = 0 与 H0: μ > 0 的 t 检验统计量。那么,自助法 p 值的计算是通过计算 {T(X*,i)} 大于原始样本中 T(X) 值的比例来完成的。也就是说,p 值的计算类似于正态分布的情况,取决于 H1 的结构。

表 2:自助法与正态法的 95% 置信区间(图片由作者制作)
上表 2 展示了自助法置信区间与渐近正态置信区间的比较,两者均为 95% 置信度。当总体分布为 t(3) 或 chisq(3) 时,这两种方法提供了类似的区间,但当总体遵循 F(2,10) 或 lognorm 分布时,它们可能会有较大差异。
R 中的自助法 t 检验
自助法可以应用于单样本和双样本 t 检验。在这种情况下,感兴趣的检验统计量 T(X) 是 t 检验统计量,其自助法分布可以如上所述获得。在 R 中,包 “MKinfer” 提供了用于自助法检验的函数。
让我们考虑第一部分中使用的 X 和 Y 示例:
x = c(-0.63, 0.18,-0.84,1.60,0.33, -0.82,0.49,0.74,0.58,-0.31,
1.51,0.39,-0.62,-2.21,1.12,-0.04,-0.02,0.94,0.82,0.59)
y=c(1.14,0.54,0.01,-0.02,1.26,-0.29,0.43,0.82,1.90,1.51,
1.83,2.01,1.37,2.54,3.55, 3.99,5.28,5.41,3.69,2.85)
# Install MKinfer package
library(MKinfer)
# One-sample test for X with H0: mu = 0
boot.t.test(x,mu=0)
# One-sample test for Y with H0: mu = 1
boot.t.test(y,mu=1)
# Two-sample test for X and Y with H0: mu(x) - mu(y) = -1
boot.t.test(x,y,mu=-1)
结果总结在下表中(所有检验假设为双尾 H1):

表 3: 95% 置信区间和 p 值(图像由作者创建)
-
对于 μ(X) = 0 的检验,X 的样本均值为 0.19,t 统计量为 0.93。自助法和渐近置信区间及 p 值提供了类似的无法拒绝 H0 的推断结果,但自助法的置信区间更紧凑。
-
对于 μ(Y) = 1 的检验,Y 的样本均值为 1.99,t 统计量为 2.63。自助法和渐近置信区间及 p 值在 5% 显著性水平下提供了类似的拒绝 H0 的推断结果,但自助法的置信区间更紧凑且 p 值更低。
-
对于 H0: μ(X) — μ(Y) = — 1 的检验,X 和 Y 之间的均值差为 -1.80,t 统计量为 -1.87。自助法和渐近置信区间及 p 值在 10% 显著性水平下提供了类似的拒绝 H0 的推断结果。
更一般数据结构的自助法
如上所述,自助法也已被开发用于线性回归模型、时间序列预测以及更一般结构的数据。以下总结了自助法的几个重要扩展。
-
对于线性回归模型,自助法可以通过对残差进行重采样或对样本进行重采样来进行:见 R 包 “car”。
-
自助法可以应用于基于自回归模型的时间序列预测:见 R 包 “BootPR”。
-
对于结构未知的时间序列数据,可以使用平稳自助法(或移动块自助法)。这涉及对时间序列观察值的块进行重采样。R 包 “tseries” 提供了该方法的函数。
-
对于形式未知的异方差数据,可以使用野生自助法,使用 R 包 “fANCOVA” 。它通过与均值为零、方差为单位的随机变量进行缩放来重采样数据,从而有效地复制异方差结构。
结论
本文回顾了自助法作为一种非参数检验,其中通过重复重采样观察数据来计算或近似统计量的采样分布。虽然本文仅涵盖了用于置信区间和检验总体均值的 p 值的自助法,但自助法的应用非常广泛,从回归分析到结构未知的时间序列数据均有涉及。
许多学术研究报告了理论或计算结果,表明自助法检验通常优于渐进正态近似,尤其是当样本量较小或中等时。
因此,在小样本中,统计学和机器学习领域的研究人员强烈建议使用自助法作为渐进正态近似的有用替代方案。
使用 GPT-4 进行标签自助标注

(来源:作者使用 DALL-E 生成的图像,由作者修改。)
一种具有成本效益的数据标注方法
·
关注 发表在 Towards Data Science ·9 分钟阅读·2023 年 6 月 9 日
--
数据标注是机器学习项目中的关键组成部分。它建立在古老的格言“垃圾进,垃圾出”的基础上。标注涉及为训练和评估创建带注释的数据集。但这一过程可能既耗时又昂贵,特别是对于数据量巨大的项目。但如果我们可以利用大语言模型的进展来减少数据标注任务的成本和工作量呢?
GPT-4 是 OpenAI 开发的最先进的语言模型。它具有卓越的理解和生成类人文本的能力,并且在自然语言处理(NLP)社区及其他领域中带来了革命性的变化。在这篇博客文章中,我们将探讨如何利用 GPT-4 为各种任务引导标签。这可以显著减少标注过程中的时间和成本。我们将重点讨论情感分类,以展示提示工程如何使你能够使用 GPT-4 创建准确可靠的标签,并且这种技术还可以用于更强大的应用。
利用 GPT-4 的预测进行数据预标注
正如在写作中,编辑通常比创作原始作品的工作轻松。这就是为什么从预标注的数据开始比从空白状态开始更具吸引力。利用 GPT-4 作为预测引擎来预标注数据源于其理解上下文和生成类人文本的能力。因此,利用 GPT-4 减少数据标注所需的人工工作是非常理想的。这可能会带来成本节约,并使标注过程变得不那么单调。
那么我们怎么做呢?如果你使用过 GPT 模型,你可能对提示很熟悉。提示在模型开始生成输出之前设置了上下文,并且可以进行调整和工程(即提示工程),以帮助模型提供高度特定的结果。这意味着我们可以创建提示,让 GPT-4 生成看起来像模型预测的文本。对于我们的使用案例,我们将以一种引导模型生成期望输出格式的方式来设计我们的提示。
让我们以一个简单的情感分析示例为例。如果我们试图将给定的文本字符串分类为 positive(积极的)、negative(消极的)或 neutral(中立的),我们可以提供如下提示:
"Classify the sentiment of the following text as 'positive', 'negative', or 'neutral': <input_text>"
一旦我们有了结构良好的提示,我们可以使用 OpenAI API 从 GPT-4 生成预测。以下是使用 Python 的一个示例:
import openai
import re
openai.api_key = "<your_api_key>"
def get_sentiment(input_text):
prompt = f"Respond in the json format: {{'response': sentiment_classification}}\nText: {input_text}\nSentiment (positive, neutral, negative):"
response = openai.ChatCompletion.create(
model="gpt-3.5-turbo",
messages=[
{"role": "user", "content": prompt}
],
max_tokens=40,
n=1,
stop=None,
temperature=0.5,
)
response_text = response.choices[0].message['content'].strip()
sentiment = re.search("negative|neutral|positive", response_text).group(0)
# Add input_text back in for the result
return {"text": input_text, "response": sentiment}
我们可以用一个示例来运行这个,以检查我们从 API 获得的输出。
# Test single example
sample_text = "I had a terrible time at the party last night!"
sentiment = get_sentiment(sample_text)
print("Result\n",f"{sentiment}")
Result:
{'text': 'I had a terrible time at the party last night!', 'response': 'negative'}
一旦我们对提示和结果感到满意,我们可以将其扩展到整个数据集。这里,我们假设一个文本文件,每行一个示例。
import json
input_file_path = "input_texts.txt"
output_file_path = "output_responses.json"
with open(input_file_path, "r") as input_file, open(output_file_path, "w") as output_file:
examples = []
for line in input_file:
text = line.strip()
if text:
examples.append(convert_ls_format(get_sentiment(text)))
output_file.write(json.dumps(examples))
我们可以将带有预标注预测的数据导入 Label Studio,并让审核员验证或修正标签。这种方法显著减少了数据标注所需的人工工作,因为人工审核员只需要验证或修正模型生成的标签,而不是从头标注整个数据集。请查看我们的 完整示例笔记本。
请注意,在大多数情况下,OpenAI 被允许使用发送到其 API 的任何信息来进一步训练他们的模型。因此,如果我们不想更广泛地公开信息,就重要不要将受保护或私人数据发送到这些 API 进行标注。
在 Label Studio 中审查预标记数据
一旦我们的预标记数据准备好,我们将其导入到数据标注工具中,例如 Label Studio,以便进行审查。本节将指导你如何设置 Label Studio 项目、导入预标记数据和审查注释。

图 1:在 Label Studio 中审查情感分类。(图片来源于作者,Label Studio 截图)
第一步:安装并启动 Label Studio
首先,你需要在机器上安装 Label Studio。你可以使用 pip 安装它:
pip install label-studio
安装 Label Studio 后,通过运行以下命令启动它:
label-studio
这将会在你默认的网页浏览器中打开 Label Studio。
第二步:创建新项目
点击“创建项目”,输入项目名称,例如“审查引导标签”。接下来,你需要定义标注配置。对于情感分析,我们可以使用文本 Sentiment Analysis Text Classification。
这些模板是可配置的,因此如果我们想更改任何属性,非常简单。默认的标注配置如下所示。
<View>
<Header value="Choose text sentiment:"/>
<Text name="my_text" value="$reviewText"/>
<Choices name="sentiment" toName="my_text" choice="single" showInline="true">
<Choice value="Positive"/>
<Choice value="Negative"/>
<Choice value="Neutral"/>
</Choices>
</View>
点击“创建”以完成项目设置。
第三步:导入预标记数据
要导入预标记数据,请点击“导入”按钮。选择 json 文件并选择之前生成的预标记数据文件(例如,“output_responses.json”)。数据将与预填充的预测一起导入。
第四步:审查和更新标签
导入数据后,你可以审查模型生成的标签。注释界面将显示每个文本样本的预标记情感,审查员可以接受或更正建议的标签。
你可以通过让多个标注员审查每个示例来进一步提高质量。
通过利用 GPT-4 生成的标签作为起点,审查过程变得更加高效,审查员可以专注于验证或更正注释,而不是从头创建它们。
第五步:导出标记数据
一旦审查过程完成,你可以通过点击“数据管理器”标签中的“导出”按钮来导出标记的数据。选择所需的输出格式(例如 JSON、CSV 或 TSV),并保存标记的数据集以供进一步用于你的机器学习项目。
成本分析
我脑海中不断浮现的问题是:“这一天的费用到底是多少?”
注意:以下显示的价格反映了作者在出版时的当前数据。未来价格可能会有所不同,或根据地理位置有所变化。
对于语言模型,OpenAI 根据请求中的令牌数量收费。令牌通常是查询中的单词数量,但特殊字符和表情符号有时也会被视为单个令牌。OpenAI 的定价页面指出,“你可以将令牌视为单词的片段,其中 1,000 个令牌大约是 750 个单词。”有关令牌计数的更多信息,请参见此页面。
每个令牌的成本根据使用的模型不同而有所不同。例如,GPT-4 8K 上下文模型的提示费用为$0.03/1K 令牌,每个生成的令牌费用为$0.06/1K 令牌,而 GPT-3.5-turbo 模型的费用为$0.002/1K 令牌。

OpenAI 的令牌价格总结。(来源:OpenAI 论坛,图像由作者提供)
要估算预标记数据集的成本,我们可以使用一个简单的公式,考虑数据集中示例的数量、每令牌的提示和完成价格,以及每个示例的平均令牌数。

其中:

此外,我们可以通过以下方式计算数据集中令牌的总数:

其中:

使用这个公式,我们可以通过将示例数量与提示费用和完成费用的总和相乘来估算预标记数据集的成本,并根据每个示例的平均令牌数进行调整。
例如,如果我们有一个包含 1,000 个示例的数据集,并且我们希望使用 GPT-4 进行情感分析的预标记,我们可以根据以下条件进行计算:每 1K 令牌的提示价格为$0.03,每 1K 令牌的完成价格为$0.06,提示长度为 20 令牌,每个示例的平均长度为 80 令牌,每个结果令牌的平均长度为 3 令牌,则预标记的总成本为:

在这个例子中,使用 GPT-4 进行数据集的预标记将花费$3.18。注意:相同的数据集使用GPT-3.5-turbo 将花费大约$0.21。
如果我们的预标记任务所需的专业知识较少,我们可能会希望使用一个不那么强大的模型以节省成本。通常,手动查看几个不同复杂程度的示例是值得的,以了解一个模型相对于另一个模型的准确性。有关模型的信息,请参见OpenAI 定价页面。
超越情感分析:标记任何 NLP 任务

Label Studio 中的命名实体识别截图。(图像由作者提供,截图来自 Label Studio)
这种方法的优点在于它不仅限于情感分析。我们可以通过使用提示工程和引导 GPT-4 生成正确的输出,预先标注各种 NLP 任务的数据。以下是不同 NLP 任务的几个提示示例,所有示例都可以在 Label Studio 中进行审核和标注(所给示例是使用 GPT-4 生成的)。请记住,更复杂的任务可能需要更长的提示,这将计入你的令牌数,并因此增加成本。
总结
- 提示:
以 JSON 格式回应以下文本的总结:{‘summary’: summary}
文本:我喜欢在晴天去公园。客户服务很糟糕;他们无礼且无助。我对新政策的变化既不开心也不难过。蛋糕很美味,摆盘也很棒。我对产品的体验非常糟糕;它在两天后坏了。
总结:
响应:
{'summary': 'Artificial intelligence is a rapidly developing technology that can learn from data and make decisions without being explicitly programmed. It has the potential to help solve some of the world's most pressing problems, but there are concerns about risks and ethical considerations such as bias and job displacement. It is important to consider AI's impact and ensure responsible and beneficial use.'}
分类法
提示:
以 JSON 格式回应以下文本的总结:{‘diagnosis’: diagnosis},其中可能的诊断包括肺炎、急性支气管炎、哮喘加重、充血性心力衰竭、慢性阻塞性肺病(COPD)、肺栓塞、胸腔积液、间质性肺病或肺癌。
文本:患者出现呼吸急促和持续性咳嗽。体检时在下肺区听到了裂纹音。胸部 X 光显示双侧浸润,与肺炎一致。患者开始接受抗生素治疗,并在 48 小时内症状有所改善。
诊断:
响应:
{'diagnosis': 'Pneumonia'}
命名实体识别(NER)
提示:
你是一个专业的 NER 标注员。只返回 JSON 格式。不要返回解释。以 JSON 格式返回以下句子的 CoNLL 格式。
响应:
{
"data": [
{
"word": "Harry",
"entity": "B-PER"
},
{
"word": "Potter",
"entity": "I-PER"
},
{
"word": "was",
"entity": "O"
},
{
"word": "a",
"entity": "O"
},
{
"word": "student",
"entity": "O"
},
{
"word": "at",
"entity": "O"
},
{
"word": "Hogwarts",
"entity": "B-LOC"
}
]
}
在我看来,没有什么能替代人工评估,但使用 GPT 模型给我们提供一个起点可能是一个巨大的优势。
结论
标注数据总是很困难,有时候,甚至得到一个起点也是巨大的优势。在这篇博客中,我们展示了如何使用 OpenAI 的 GPT 模型生成数据预测,以作为数据标注工作流的起点。这个过程可以显著减少人工工作量,并将标注员的注意力集中在为他们的工作提供更多价值上。查看资源以获取更多关于本博客中介绍的主题的信息。
资源
完整示例笔记本 — 包含所有代码的笔记本,准备在 Colab 中运行
Label Studio — 开源数据标注工具
OpenAI 定价页面 — 本文中的定价估算细节
Boto3 与 AWS Wrangler:使用 Python 简化 S3 操作
原文:
towardsdatascience.com/boto3-vs-aws-wrangler-simplifying-s3-operations-with-python-596bdf021ef2
针对 AWS S3 开发的比较分析
·发表于 Towards Data Science ·阅读时间 10 分钟·2023 年 6 月 20 日
--

图片由 Hemerson Coelho 提供,来自 Unsplash
按需课程 | 推荐
一些读者联系我,询问是否有按需课程帮助你 成为 一个扎实的 数据工程师。这些是我推荐的 3 个绝佳资源:
还不是 Medium 会员? 考虑通过我的 推荐链接 注册,低至每月 5 美元即可访问 Medium 的所有内容!
介绍
在本教程中,我们将通过探索和比较两个强大的库:boto3 和 awswrangler,深入了解 AWS S3 的 Python 开发。
如果你曾经想过
“哪个 Python 工具最适合与 AWS S3 存储桶交互?”
“如何以最有效的方式执行 S3 操作?”
那么你来对地方了。
的确,在这篇文章中,我们将涵盖一系列处理 AWS S3 存储桶的常见操作,其中包括:
-
列出对象,
-
检查对象存在性,
-
下载对象,
-
上传对象,
-
删除对象,
-
写入对象,
-
读取对象(标准方式或使用 SQL)
通过比较这两个库,我们将识别它们的相似性、差异以及每个操作的最佳用例。到最后,你将清楚了解哪个库更适合特定的 S3 任务。
此外,对于那些读到最后的人,我们还将探讨如何利用boto3和awswrangler通过友好的 SQL 查询从 S3 读取数据。
所以让我们深入了解,发现与 AWS S3 交互的最佳工具,并学习如何使用这两个库高效地执行这些操作。
先决条件与数据
本教程中使用的软件包版本为:
-
boto3==1.26.80 -
awswrangler==2.19.0
此外,三个初始文件,包括随机生成的account_balances数据,已上传到名为coding-tutorials的 S3 桶中:

尽管你应该知道有多种方法可以与 S3 桶建立连接,但在这种情况下,我们将使用boto3的setup_default_session():
# CONNECTING TO S3 BUCKET
import os
import io
import boto3
import awswrangler as wr
import pandas as pd
boto3.setup_default_session(aws_access_key_id = 'your_access_key',
aws_secret_access_key = 'your_secret_access_key')
bucket = 'coding-tutorials'
这种方法很方便,因为一旦设置了会话,它可以被boto3和awswrangler共享,这意味着我们不需要再传递任何秘密信息。
比较分析
现在让我们比较一下boto3和awswrangler在执行一些常见操作时,找出最适合的工具。
包含以下代码的完整笔记本可以在这个 GitHub 文件夹中找到。
# 1 列出对象
列出对象可能是我们在探索新 S3 桶时应该执行的第一个操作,也是检查会话是否正确设置的简单方法。
使用boto3可以通过以下方式列出对象:
-
boto3.client('s3').list_objects() -
boto3.resource('s3').Bucket().objects.all()
print('--BOTO3--')
# BOTO3 - Preferred Method
client = boto3.client('s3')
for obj in client.list_objects(Bucket=bucket)['Contents']:
print('File Name:', obj['Key'], 'Size:', round(obj['Size']/ (1024*1024), 2), 'MB')
print('----')
# BOTO3 - Alternative Method
resource = boto3.resource('s3')
for obj in resource.Bucket(bucket).objects.all():
print('File Name:', obj.key, 'Size:', round(obj.size/ (1024*1024), 2), 'MB')
尽管client和resource类做得都不错,但client类应当优先选择,因为它更优雅,并提供了大量的[易于访问]低级元数据,作为嵌套的JSON(其中包括对象的size)。
另一方面,awswrangler仅提供了一个列出对象的方法:
wr.s3.list_objects()
作为一种高级方法,这不会返回任何低级的对象元数据,因此要找到文件的size,我们需要调用:
wr.s3.size_objects
print('--AWS_WRANGLER--')
# AWS WRANGLER
for obj in wr.s3.list_objects("s3://coding-tutorials/"):
print('File Name:', obj.replace('s3://coding-tutorials/', ''))
print('----')
for obj, size in wr.s3.size_objects("s3://coding-tutorials/").items():
print('File Name:', obj.replace('s3://coding-tutorials/', '') , 'Size:', round(size/ (1024*1024), 2), 'MB')
上述代码返回:

比较 → Boto3 胜出
尽管awswrangler使用起来更为直观,但在列出 S3 对象时,boto3胜出。实际上,其底层实现意味着可以使用其类检索更多的对象元数据。这些信息在以编程方式访问 S3 桶时非常有用。
# 2 检查对象存在性
当我们希望触发附加操作时,检查对象的存在性是必要的,这些操作是基于对象是否已经存在于 S3 中。
使用 boto3 可以通过以下方式执行这些检查:
boto3.client('s3').head_object()
object_key = 'account_balances_jan2023.parquet'
# BOTO3
print('--BOTO3--')
client = boto3.client('s3')
try:
client.head_object(Bucket=bucket, Key = object_key)
print(f"The object exists in the bucket {bucket}.")
except client.exceptions.NoSuchKey:
print(f"The object does not exist in the bucket {bucket}.")
相反,awswrangler 提供了一个专用的方法:
wr.s3.does_object_exist()
# AWS WRANGLER
print('--AWS_WRANGLER--')
try:
wr.s3.does_object_exist(f's3://{bucket}/{object_key}')
print(f"The object exists in the bucket {bucket}.")
except:
print(f"The object does not exist in the bucket {bucket}.")
上述代码返回:

比较 → AWSWrangler 胜出
让我们承认:boto3 的方法名 [head_object()] 并不是那么直观。
另外,拥有专用方法无疑是 awswrangler 在这场比赛中的优势。
# 3 Downloading Objects
使用以下方法,在本地下载对象在 boto3 和 awswrangler 中都非常简单:
-
boto3.client('s3').download_file()或 -
wr.s3.download()
唯一的区别是 download_file() 需要 bucket、object_key 和 local_file 作为输入变量,而 download() 只需要 S3 path 和 local_file:
object_key = 'account_balances_jan2023.parquet'
# BOTO3
client = boto3.client('s3')
client.download_file(bucket, object_key, 'tmp/account_balances_jan2023_v2.parquet')
# AWS WRANGLER
wr.s3.download(path=f's3://{bucket}/{object_key}', local_file='tmp/account_balances_jan2023_v3.parquet')
执行代码时,确实会将同一个对象的两个版本下载到本地的 tmp/ 文件夹中:

比较 → 平局
在下载文件方面,我们可以认为两个库是等效的,因此我们称其为平局。
# 4 Uploading Objects
上传本地环境中的文件到 S3 时,同样的推理适用。可以使用的方法有:
-
boto3.client('s3').upload_file()或 -
wr.s3.upload()
object_key_1 = 'account_balances_apr2023.parquet'
object_key_2 = 'account_balances_may2023.parquet'
file_path_1 = os.path.dirname(os.path.realpath(object_key_1)) + '/' + object_key_1
file_path_2 = os.path.dirname(os.path.realpath(object_key_2)) + '/' + object_key_2
# BOTO3
client = boto3.client('s3')
client.upload_file(file_path_1, bucket, object_key_1)
# AWS WRANGLER
wr.s3.upload(local_file=file_path_2, path=f's3://{bucket}/{object_key_2}')
执行代码,将两个新的 account_balances 对象(2023 年 4 月和 5 月)上传到 coding-tutorials 存储桶:

比较 → 平局
这仍然是平局。到目前为止,这两个库之间完全平等!
# 5 Deleting Objects
现在假设我们希望删除以下对象:
#SINGLE OBJECT
object_key = ‘account_balances_jan2023.parquet’
#MULTIPLE OBJECTS
object_keys = [‘account_balances_jan2023.parquet’,
‘account_balances_feb2023.parquet’,
‘account_balances_mar2023.parquet’]
boto3 允许通过以下方法逐个或批量删除对象:
-
boto3.client('s3').delete_object() -
boto3.client('s3').delete_objects()
两种方法都返回一个包含 ResponseMetadata 的 response,可以用来验证对象是否成功删除。例如:
-
在删除单个对象时,
HTTPStatusCode==204表示操作已成功完成(如果 S3 存储桶中找到对象); -
在删除多个对象时,会返回一个包含成功删除项目名称的
Deleted列表。
# BOTO3
print('--BOTO3--')
client = boto3.client('s3')
# Delete Single object
response = client.delete_object(Bucket=bucket, Key=object_key)
deletion_date = response['ResponseMetadata']['HTTPHeaders']['date']
if response['ResponseMetadata']['HTTPStatusCode'] == 204:
print(f'Object {object_key} deleted successfully on {deletion_date}.')
else:
print(f'Object could not be deleted.')
# Delete Multiple Objects
objects = [{'Key': key} for key in object_keys]
response = client.delete_objects(Bucket=bucket, Delete={'Objects': objects})
deletion_date = response['ResponseMetadata']['HTTPHeaders']['date']
if len(object_keys) == len(response['Deleted']):
print(f'All objects were deleted successfully on {deletion_date}')
else:
print(f'Object could not be deleted.')
另一方面,awswrangler 提供了一个可以用于单个和批量删除的方法:
wr.s3.delete_objects()
由于 object_keys 可以作为 list_comprehension 递归传递给方法,而不需要像以前那样先转换成字典,使用这种语法确实令人愉快。
# AWS WRANGLER
print('--AWS_WRANGLER--')
# Delete Single object
wr.s3.delete_objects(path=f's3://{bucket}/{object_key}')
# Delete Multiple Objects
try:
wr.s3.delete_objects(path=[f's3://{bucket}/{key}' for key in object_keys])
print('All objects deleted successfully.')
except:
print(f'Objects could not be deleted.')
执行上述代码,会删除 S3 中的对象,然后返回:


比较 → Boto3 胜出
这有点棘手:awswrangler在删除多个对象时有更简单的语法,因为我们可以直接将完整的列表传递给该方法。
然而,boto3在response中返回大量有用的日志信息,这在编程删除对象时非常有用。
因为在生产环境中,低级别的元数据总比几乎没有元数据要好,boto3赢得了这场挑战,现在领先 2–1。
# 6 写入对象
在将文件写入 S3 时,boto3甚至没有提供一个开箱即用的方法来执行这些操作。
例如,如果我们想用boto3创建一个新的parquet文件,我们首先需要将对象保存到本地磁盘(使用pandas的to_parquet() 方法),然后使用upload_fileobj()方法将其上传到 S3。
与upload_file()(在第 4 点中探讨)不同,upload_fileobj()方法 是一种托管传输,将在必要时在多个线程中执行分块上传:
object_key_1 = 'account_balances_june2023.parquet'
# RUN THE GENERATOR.PY SCRIPT
df.to_parquet(object_key_1)
# BOTO3
client = boto3.client('s3')
# Upload the Parquet file to S3
with open(object_key_1, 'rb') as file:
client.upload_fileobj(file, bucket, object_key_1)
另一方面,awswrangler库的主要优点之一(在与pandas一起使用时)是它可以直接将对象写入 S3 桶(而无需将其保存到本地磁盘),这既优雅又高效。
此外,awswrangler 提供了极大的灵活性,允许用户:
-
应用特定的压缩算法,如
snappy、gzip和zstd; -
通过
mode参数在dataset = True时append到或overwrite现有文件; -
通过
partitions_col参数指定一个或多个分区列。
object_key_2 = 'account_balances_july2023.parquet'
# AWS WRANGLER
wr.s3.to_parquet(df=df,
path=f's3://{bucket}/{object_key_2}',
compression = 'gzip',
partition_cols = ['COMPANY_CODE'],
dataset=True)
一旦执行,上述代码将account_balances_june2023写入为一个单独的parquet文件,并将account_balances_july2023写入为一个包含四个文件的文件夹,这些文件已按COMPANY_CODE分区:

比较 → AWSWrangler 获胜
如果使用pandas是一个选项,awswrangler在将文件写入 S3 时提供了更高级的操作集,特别是与boto3相比,在这种情况下,boto3并不是最好的工具。
# 7.1 读取对象(Python)
在尝试使用boto3从 S3 读取对象时,同样的推理适用:由于该库没有内置的读取方法,我们能做的最佳选择是执行 API 调用(get_object()),读取response的Body,然后将parquet_object传递给pandas。
请注意,pd.read_parquet()方法期望一个类文件对象作为输入,这就是为什么我们需要将从parquet_object读取的内容作为二进制流传递。
实际上,通过使用io.BytesIO(),我们在内存中创建了一个临时的类文件对象,从而避免了在读取之前将 Parquet 文件保存到本地。这反过来提高了性能,特别是在处理大型文件时:
object_key = 'account_balances_may2023.parquet'
# BOTO3
client = boto3.client('s3')
# Read the Parquet file
response = client.get_object(Bucket=bucket, Key=object_key)
parquet_object = response['Body'].read()
df = pd.read_parquet(io.BytesIO(parquet_object))
df.head()
正如预期的那样,awswrangler 在读取 S3 对象方面表现出色,返回一个 pandas df 作为输出。
它支持多种输入格式,如 csv、json、parquet 和最近的 delta 表。此外,传递 chunked 参数可以以 内存友好的 方式读取对象:
# AWS WRANGLER
df = wr.s3.read_parquet(path=f's3://{bucket}/{object_key}')
df.head()
# wr.s3.read_csv()
# wr.s3.read_json()
# wr.s3.read_parquet_table()
# wr.s3.read_deltalake()
执行上述代码会返回一个包含 5 月数据的 pandas df:

比较 → AWSWrangler 胜出
是的,有办法绕过 boto3 中缺乏适当方法的问题。然而,awswrangler 是一个旨在高效读取 S3 对象的库,因此它也赢得了这个挑战。
# 7.2 读取对象(SQL)
读到这里的读者值得一个奖励,那就是使用纯 SQL 从 S3 读取对象。
假设我们希望使用下面的 query 从 account_balances_may2023.parquet 对象中获取数据(通过 AS_OF_DATE 进行数据过滤):
object_key = 'account_balances_may2023.parquet'
query = """SELECT * FROM s3object s
WHERE AS_OF_DATE > CAST('2023-05-13T' AS TIMESTAMP)"""
在 boto3 中,可以通过 select_object_content() 方法实现。注意,我们还需要指定 inputSerialization 和 OutputSerialization 格式:
# BOTO3
client = boto3.client('s3')
resp = client.select_object_content(
Bucket=bucket,
Key=object_key,
Expression= query,
ExpressionType='SQL',
InputSerialization={"Parquet": {}},
OutputSerialization={'JSON': {}},
)
records = []
# Process the response
for event in resp['Payload']:
if 'Records' in event:
records.append(event['Records']['Payload'].decode('utf-8'))
# Concatenate the JSON records into a single string
json_string = ''.join(records)
# Load the JSON data into a Pandas DataFrame
df = pd.read_json(json_string, lines=True)
# Print the DataFrame
df.head()
如果使用 pandas df 是一个选项,那么 awswrangler 还提供了一个非常方便的 select_query() 方法,只需最少的代码:
# AWS WRANGLER
df = wr.s3.select_query(
sql=query,
path=f's3://{bucket}/{object_key}',
input_serialization="Parquet",
input_serialization_params={}
)
df.head()
对于这两个库,返回的 df 将如下所示:

结论
在本教程中,我们探讨了 7 种常见操作,这些操作可以在 S3 存储桶上执行,并对 boto3 和 awswrangler 库进行了比较分析。
两种方法都允许我们与 S3 存储桶进行交互,但主要区别在于 boto3 客户端提供了对 AWS 服务的低级访问,而 awswrangler 提供了一个简化的、更高级的数据工程任务接口。
总体而言,awswrangler 以 3 分(检查对象是否存在、写入对象、读取对象)战胜了 boto3 的 2 分(列出对象、删除对象)。上传/下载对象 类别平局,因此没有得分。
尽管结果如上,但事实是这两个库在可以互换使用时表现最佳,能够在它们所设计的任务中表现出色。
资源
除非另有说明,所有图片均由作者提供。
Branches Are All You Need: Our Opinionated ML Versioning Framework
一个实用的机器学习项目版本控制方法,使用 Git 分支简化工作流程,组织数据和模型
·
关注 发表在 走向数据科学 ·10 分钟阅读·2023 年 10 月 10 日
--
简而言之
一个简单的机器学习项目版本控制方法,使用 Git 分支简化工作流程,组织数据和模型,并将项目相关部分紧密联系在一起。
介绍
在管理机器学习解决方案时,解决方案的各个方面通常分布在多个平台和位置,如 GitHub 用于代码,HuggingFace 用于模型,Weights and Biases 用于追踪,S3 用于一切的备份等等。
在数据方面,我们有训练数据、处理过的数据、训练跟踪数据和模型监控数据。我们保存用于推理的模型,包括旧版本和用于在线测试的实验模型。我们还有用于预处理、训练、推理、实验、数据科学分析和监控警报的代码。
这很容易失控。

作者提供的图像
使用各种工具、环境和资产地址来跟踪 ML 生命周期的不同部分,可能导致状态分散和协调不足。这可能导致数据丢失、安全漏洞和需要仔细管理的配置错误。
在一个之前的项目中,我们在本地解决方案中使用 SageMaker 进行日常训练。这要求客户每天下载一个模型,并使用各种客户数据集的组合进行训练。
我们必须管理哪个二进制模型是用哪个训练代码在哪个客户的数据上训练的,哪个模型在哪个客户运行,使用哪个推理代码等等。
在这篇文章中,我们将展示如何利用数据版本控制工具,通过 Git 解决这些问题。
数据版本控制工具允许你与代码一起提交数据和模型文件,无论其大小如何。通过这种方式对所有文件进行版本控制,你可以绕过管理数据和模型资产的不便。
在理想情况下,你只会拥有与当前任务相关的代码、数据和模型。无论你是在开发还是运行 ML 训练、跟踪实验、监控生产中的模型,还是进行在线实验。
想象一下,不用手动(甚至自动)连接正确的部分 — 在正确的代码中加载正确的数据,将它们连接到 HuggingFace 中正确的模型,并在正确的环境中运行 — 每次你检出一个分支时,所有部分都已经就位。
在本文中,我将展示一个框架,利用 Git 代替处理工具复杂性的难题,Git 是几乎所有 ML 团队已经在使用的系统。
目标是通过将所有内容集中在一个地方,由 Git 管理,简化并消除启动 ML 工作流每个阶段的障碍。
我们的要求
-
一个简单的工作流,易于暂停、恢复和调整以适应不断变化的业务和开发需求。它还应支持可重现性,并允许事后查询,例如“我的模型是用什么数据训练的?”
-
高效使用数据和代码,重点关注凝聚力、治理和审计。— 旨在尽可能多地重用数据和代码,并利用 git 的功能,如提交、问题和标签。
-
整合 ML 解决方案的各个方面非常重要。通常,实验跟踪、生产中的模型监控以及在线和离线实验与训练和推理部分是分开的。在这里,我们旨在将它们统一在一个平台下,使其在各部分之间过渡变得容易。
-
遵循Git 和 ML 最佳实践,例如早期和可共享的数据拆分、测试以及简单的协作,以适应不同的 ML 工程师。
关键概念
每个更改都是一个 Git 提交:这包括数据上传、特征工程、模型覆盖、合并实验指标结果、模型监控,以及自然的代码更改。
活动分支:将不同的分支用于开发和生产是常见的做法。然而,我们可以在这里更进一步。这意味着你可以签出一个分支,并在一个地方拥有所有必要的数据、代码、模型、文档、readmes 和带有指标的模型卡。
🤯 你的代码库就是你的 blob 存储!你的分支像一个桶一样在你的 blob 存储中运行,允许你上传、下载和存储数据和模型。
这允许你为不同的开发、实验和生产需求使用不同的分支,而不是依赖于不同的平台和工具。
合并作为工作流程:它们用于合并分支。代码通常被合并,而模型通常会覆盖现有模型。当数据被合并时,通常是追加的。要接收新数据,可以从另一个分支“拉取”数据进行复制。
合并数据可以像文件复制或追加到 JSON 列表一样简单。在更复杂的情况下,你可以合并 sqlite 数据库。
去重,这是数据版本控制工具中常用的功能,防止即使在多个分支中包含相同文件时,也会创建多个文件副本。
分支类型

图片来自作者
主分支
首先,我们使用项目的main分支来存储问题定义、文档、数据描述和项目结构。这作为协作和讨论的空间。
提示:通过明确定义业务问题、确定期望的结果、识别目标值或标签及其获取方式,并建立评估指标和要求,我们可以确保项目的成功启动,并提供一个入门和协作的地方。
我们也可以用它来跟踪实验,将实验结果进行合并。例如,MLflow的 mlruns 文件夹可以在此处进行合并。任何协作者都可以签出这个分支并运行 UI。
另外,跟踪可以在另一个分支中完成。
以这种方式开始非常简单,随着需求的变化,可以以最少的更改升级到 MLflow 服务器或类似 Weights and Biases 的跟踪平台。
数据分支
这些是项目的分支,主要包括数据文件、文档和转换脚本,并且它们保持活动状态。你可以把它们想象成 S3 桶,但不是上传和下载,而是签出一个分支,你的文件就在那里。
推荐始终提交(上传)到原始分支。这些分支创建了一个真实的来源,一个从不被编辑或删除的地方,因此我们可以始终追踪数据的来源和经过过程。它还使得创建新流程、审核和治理变得容易。
💡 如果你添加一个关于数据来源的提交信息,你可以获得对数据更细致的观察。
你可以使用另一个仅包含干净数据的清理分支。例如,上传到原始分支的损坏图片或空文本文件不会出现在清理分支中。
一个分割分支,其中数据被划分用于训练、验证和测试,可以确保所有团队和协作者在相同的条件下工作。
这种方法有助于防止数据泄漏,并实现更稳健的特征工程和协作。最小化测试集样本被包含在训练阶段的机会减少了引入偏差的风险。此外,所有协作者使用相同的分割可以确保实验结果的一致性和无偏。
在之前的分类项目中,我是一个由个体贡献者组成的团队的一员,每个人从头开始运行整个管道;我们每个人使用了不同的数据拆分百分比和种子,这导致了在生产中因错误和数据偏见而产生较弱的模型。

作者提供的图片
💡 机器学习提示:三阶段模型开发最佳实践 我们使用“训练”和“验证”数据集来训练和优化模型的超参数。然后,我们将训练加验证数据集作为训练集来训练调整后的模型,并仅一次使用测试数据集进行评估。最后,我们在所有数据上训练模型并将其保存为我们的模型。
稳定分支
这些分支是活跃的分支,用于训练和推断。在这里,你可以运行训练,保存模型、检查点和模型卡,运行测试,构建和测试 Docker 镜像,在训练周期结束时提交所有内容,然后标记。它们应能处理新数据的检索和重新训练。这是自动化发生的地方。
⚠️ 这些分支中不编写代码。
这确保模型与其训练所用的数据、用于训练和在生产中运行的代码(包括特征工程)以及结果指标相关联。所有这些组件被合并成一个统一的“快照”。每当你检出一个标签时,所有必要的模型组件都存在。
💡 提示:通过提前选择标签名称,你可以在训练过程中将其作为参数添加到跟踪信息中。这确保你可以通过任何跟踪工具根据跟踪数据随时检索模型-数据-代码的“快照”。
训练后,只有跟踪数据被合并(复制)到你的主分支用于跟踪。
在最简单的情况下,它可以是一个包含超参数和评估结果的 JSON 文本文件。然后将该文件附加到 main 分支中的列表中。在 MLflow 的情况下,它涉及将实验从 mlruns 文件夹 复制到 main 分支。
代码分支
这些分支用于 代码开发和 数据探索,在抽样或小数据上进行训练,直到你拥有一个有效的程序。在开发过程中,欢迎使用所有 Git 最佳实践。然而,只有在 不再需要进一步更改代码 的情况下才创建 stable 分支,即使额外的数据被拉取。这些分支 应该包括 推断代码、服务器、Dockerfile 和测试。
始终至少有一个开发分支保持活动状态,所有新功能、错误修复和其他更改都在此分支中合并。
💡 ML 和 MLOps 工程师可以在训练和推断方面进行协作。
例如,你可以创建一个 dev/model 分支,在其中开发一个 基线 模型。这可以是分类中最流行的类别或回归中的均值/中位数。重点是设置代码,同时彻底理解你的数据。
当它稳定并且测试通过后,我们会分支到 stable/model,在这里我们训练并提交模型、代码和数据到远程,并 标记 该提交。这种方式快速且易于分享,能够使 DevOps、后台和前端团队启动开发并交换反馈。它还将促进尽早在真实环境中验证新发现的需求。
接下来,我们在 dev/model 分支上开发一个 简单模型,比如线性回归,当它准备好并且测试通过后,我们可以将其合并到 stable/model 分支,在那里我们训练、提交,并为 prod 打上标签。
这种方法让你可以逐步改进模型,同时保留稳定分支中之前模型的完整上下文。

图片由作者提供
从这一点开始,我们有三种选择:
-
当更多数据到达时,我们可以通过将数据拉取到稳定分支来 重新训练 模型。
-
我们可以开始在 dev/linear-regression 分支上使用特征工程进行 实验。
-
我们可以为更复杂的模型创建一个 新的 dev/new-approach 分支。
监控分支
在模型监控中,我们关注数据分布、异常值和预测分布。
在 monitoring 分支中,我们将查询到的数据、提交标签和来自 prod 的模型预测保存为文件。
💡 你可以为每个环境(开发、稳定和生产)使用多个监控分支。
我们可以在数据提交上设置警报,以测试任何特征分布的漂移、异常值、校准的合理性测试,并保存警报代码;这使得我们可以实现更高级的解决方案,比如异常值检测模型,因为我们也可以在这个分支中保存模型。

作者提供的图片
这个分支通常可以属于一个与生成监控日志的代码、数据和模型解耦的其他项目。
分析分支
数据科学和分析是项目的另一个方面,通常被分离到不同的项目中。这是数据科学家的分析代码和非训练数据汇聚的地方。
数据科学家可以从monitoring分支中检查和提取数据,以进行分析、A/B 测试和其他在线与离线实验。他们也可以使用raw分支中的数据进行这些目的。
在线示例更简单,因为每个实验组对应一个分支。
💡 提示:常见的在线实验:
前向测试 - 比较当前模型的 99%与候选模型的 1%。
回测 - 在合并新模型后,保留 1%在旧模型上以验证预期效果的反向。
将模型标签作为监控数据中的参数,有助于你精确定位度量潜在原因的每一变化。
总结

作者提供的图片
本文介绍了使用 Git 分支对机器学习项目进行版本控制的框架。该框架简化了工作流程,组织了数据和模型,并将项目的相关部分耦合在一起。它强调了将分支用作环境的使用,每个分支包含特定任务所需的数据、代码、模型和文档。文章还讨论了如使用不同的活跃分支类别等关键概念。总体而言,该框架旨在提高机器学习项目的工作流程效率、治理和协作。
如果你想聊天或了解更多信息,加入我们的discord或关注我们的博客。
结语
关于我的本地挑战,我们为每个相关的训练代码和数据集组合维护了一个“稳定”的分支。训练完成后,我们会用一个合适的标签(<client-id>-<incremental version>)标记提交。客户可以像其他发布版本一样拉取最新的标签。

在“调试”客户端时,我们会参考特定时刻的标签来审查代码和相应的数据。我们还可以使用相同的标签匹配监控数据,这个标签也添加到监控数据中。分析笔记本可以在我们的ds/client-id分支中找到。
打破界限:探索 LLM 的函数调用
原文:
towardsdatascience.com/breaking-boundaries-exploring-function-calling-for-llms-73d063d46fcb
函数调用如何为大型语言模型(LLM)与外部工具和 API 的无缝集成铺平道路
·发表于 Towards Data Science ·阅读时间 8 分钟·2023 年 8 月 10 日
--

使用 SDXL 生成的图像
当我发现大型语言模型(LLM)具备与外部工具和 API 互动的能力时,我知道一切将不再相同。
这是否使我们更接近实现人工通用智能(AGI)?也许没有,但它无疑开启了人工智能的全新时代:一个 LLM 可以执行任何你能放入函数中的任务的时代。现在,这可能看起来有些异端,但对我来说,这使我们更接近我对 AGI 的愿景,因为我不是那种幻想有意识机器或其他科幻叙事的人。
想象一下这样一个场景:一个人工智能代理完全掌控一个任务,与你在线运行的其他代理进行沟通,获取数据,并返回你所期望的结果。这种变革性的能力不仅重新定义了我们与互联网的互动方式,而且还重塑了我们的思维过程。
快进几年:假设你想计划一次假期。第一步不是在网上搜索票务,而是指示一个 AI 代理从头到尾规划和组织一切。你只有在收到几封确认邮件时才会知道任务完成了——如果邮件仍然存在的话——并且你的信用卡上显示了一笔费用。
但我们不要走得太远。即使在我们当前的现实中,你也可以创建能够执行任何程序上可能的任务的代理。不过,这需要你付出一些努力。你需要熟悉 OpenAI 模型系列的新函数调用能力,并创建模型将使用的自定义工具。
这就是这篇博客文章的用武之地。把它当作一个动手指南,帮助你将 AI 代理变为现实——一个可以在你的个人 Google Calendar 账户上创建事件的代理。然而,通过简单地更改 LLM 调用的函数,世界将由你掌控。准备好深入了吗?
学习率 是一本面向对 ML 和 MLOps 世界感兴趣的人的通讯。如果你想了解更多类似的主题,请 订阅。你将在每个月的最后一个星期天收到我的更新和关于最新 MLOps 新闻和文章的想法!
工具包
我们待办事项的首要任务是构建工具,即我们的 LLM 将调用的函数来创建新的日历事件。我们的选择武器?LangChain——可能是发展最快的 Python 库。
LangChain 已经配备了大量现成的 工具。然而,它的武器库中没有提供将 LLM 与 Google Calendar API 集成的工具。但没关系——我们自己来创建!
我们从深入了解 Google Calendar Event 参考指南 开始。这有助于我们概述函数签名的样子。但请记住,我们并不需要为每一项提供值,所以让我们标记出我们认为最关键的参数,并定义函数的输入。小提示:LangChain 期望我们使用 Pydantic 模型来实现这一点。
from pydantic import BaseModel, Field
class CalendarEventInput(BaseModel):
summary: str = Field(description="The title of the event")
location: str = Field(description="The location of the event")
description: str = Field(description="The description of the event")
start: EventDateTime = Field(description="The start datetime of the event")
end: EventDateTime = Field(description="The end datetime of the event")
attendees: List[Attendee] = Field(description="The attendees of the event")
reminders: Reminders = Field(description="The reminders of the event")
conferenceDataVersion: int = Field(
description="Set to `1` if you need to create a new Google Meet link.")
recurrence: List[str] = Field(
description="A list of RRULE, EXRULE, RDATE and EXDATE lines for a"
" recurring event, as specified in RFC5545")
正如你所看到的,我们期望我们的模型提供事件的各种细节:摘要、位置、参与者、开始和结束时间等等。
这些字段有些是简单的类型,如字符串或列表,而其他字段则要求我们挽起袖子定义 Python 类。那么,事不宜迟,让我们深入探讨下一步:
from typing import List
from dataclasses import dataclass
@dataclass()
class EventDateTime:
dateTime: str
timeZone: str
@dataclass()
class Attendee:
displayName: str
email: str
optional: bool
@dataclass()
class ReminderOverride:
method: str
minutes: str
@dataclass()
class Reminders:
useDefault: bool
overrides: List[ReminderOverride]
很好,我们已经准备好函数模型。接下来一步——制作自定义工具。我们的任务是创建一个继承自 LangChain 的 BaseTool 类的类,至少需要实现 _run 或 _arun 方法。
我们还需要设置一些关键值,如工具的 name、description(提示模型何时使用它),以及 args_schema 属性,后者描述了工具的模式——我们之前定义的模式。让我们开始吧:
from langchain.tools.base import BaseTool
from googleapiclient.discovery import build
class CalendarEventTool(BaseTool):
name: str = "calendar_event"
description: str = "Useful tool for creating new Google calendar events"
args_schema: Type[BaseModel] = CalendarEventInput
def _create_event(self, calendar_id: str, body: dict,
conferenceDataVersion: int):
"""Create a new Google calendar event.
Args:
calendar_id (str): The calendar id.
body (str): The event body.
conferenceDataVersion (int): Set to `1` to create a new Google
Meet Event.
Returns:
dict: The event response.
"""
service = build("calendar", "v3", credentials=get_credentials())
event = (service.events() # type: ignore
.insert(calendarId=calendar_id, body=body,
conferenceDataVersion=conferenceDataVersion)
.execute())
return event
def _run(self, summary: str, location: str, description: str,
start: EventDateTime, end: EventDateTime,
attendees: List[Attendee], reminders: Reminders,
conferenceDataVersion: int, recurrence: List[str]):
"""Run the CalendarEventTool with the given parameters.
Args:
summary (str): The summary or title of the event.
location (str): The location of the event.
description (str): The description or details of the event.
start (EventDateTime): The start date and time of the event.
end (EventDateTime): The end date and time of the event.
attendees (List[Attendee]): A list of attendees for the event.
reminders (Reminders): The reminders for the event.
conferenceDataVersion (int): The version of the conference data.
recurrence (List[str]): A list of recurrence rules for the event.
"""
body = create_request_body(summary, location, description,
start, end, attendees, reminders,
recurrence)
event = self._create_event(CALENDAR_ID, body, conferenceDataVersion)
def _arun(self):
raise NotImplementedError("calendar_event does not support async")
我们的工具非常简单。我们只定义了 _run 函数,因为我们不需要为这个任务支持异步调用。_run 函数的作用是组装请求体并调用另一个私有函数,恰如其分地命名为 _create_event。
让我们逐步处理。首先,我们需要定义 create_request_body 实用函数:
import random
import string
from typing import List
def _create_attendee_list(attendees):
attendee_list = [{"displayName": attendee.displayName,
"email": attendee.email,
"optional": attendee.optional}
for attendee in attendees]
return attendee_list
def _create_reminder_list(reminders):
reminder_list = {"useDefault": reminders.useDefault,
"overrides": [{"method": override.method,
"minutes": override.minutes}
for override in reminders.overrides]}
return reminder_list
def create_request_body(summary: str, location: str, description: str,
start: EventDateTime, end: EventDateTime,
attendees: List[Attendee], reminders: Reminders,
recurrence: List[str]) -> dict:
attendee_list = _create_attendee_list(attendees)
reminder_list = _create_reminder_list(reminders)
request_id = ''.join(random.choice(string.ascii_letters) for _ in range(8))
body = {
"summary": summary,
"location": location,
"description": description,
"start": {
"dateTime": start.dateTime,
"timeZone": start.timeZone
},
"end": {
"dateTime": end.dateTime,
"timeZone": end.timeZone
},
"attendees": attendee_list,
"reminders": reminder_list,
"conferenceData": {
"createRequest": {
"requestId": request_id,
"conferenceSolutionKey": {
"type": "hangoutsMeet"
}
}
},
"recurrence": [r for r in recurrence]}
return body
我们的create_request_body函数依赖于几个其他的实用函数来创建 JSON 对象。我们在它上面几行定义了这些私有函数。
我们的下一个目标?_create_event函数。这个函数构建了我们将用于执行 API 调用的服务。为了完成这项工作,你需要获取一个凭据文件。不确定怎么做?只需遵循这个指南:OAuth 客户端 ID 凭据。下载 JSON 文件后,将其重命名为credentials.json,然后使用下面的函数来获取和存储令牌:
import os
from typing import Union
from pathlib import Path
import google.oauth2.credentials as oauth2
import google.auth.external_account_authorized_user as auth
from google.oauth2.credentials import Credentials
from google.auth.transport.requests import Request
from google_auth_oauthlib.flow import InstalledAppFlow
HOME = Path.home()
TOKEN_FILE = "token.json"
CREDS_FILE = "credentials.json"
SCOPES = ["https://www.googleapis.com/auth/calendar"]
def get_credentials() -> Union[oauth2.Credentials, auth.Credentials]:
creds = None
# If token.json exists, read it and check if it's valid
if os.path.exists(TOKEN_FILE):
creds = Credentials.from_authorized_user_file(TOKEN_FILE, SCOPES)
# If there's no valid token.json, refresh it or create a new one
if not creds or not creds.valid:
if creds and creds.expired and creds.refresh_token:
creds.refresh(Request())
else:
flow = InstalledAppFlow.from_client_secrets_file(
CREDS_FILE, SCOPES)
creds = flow.run_local_server(port=0)
with open(TOKEN_FILE, "w") as token:
token.write(creds.to_json())
return creds
将所有这些代码整合在一起应该能为你提供所需的工具。这看起来可能有点复杂,但这不是你自己无法做到的,对吧?
现在,让我们换个方向,创建一个可以有效利用这个工具的 AI 代理。
代理
LangChain 确实消除了这一部分的猜测。你只需初始化一个新的[OPENAI_FUNCTIONS](https://python.langchain.com/docs/modules/agents/agent_types/openai_functions_agent)类型的代理。接下来,你需要为这个代理提供可以访问的工具,然后在用户查询上运行它。听起来很激动人心?那就让我们直接进入吧:
from langchain.chat_models import ChatOpenAI
from langchain.agents import AgentType, initialize_agent
# Initialize the language model
llm = ChatOpenAI(temperature=0, model_name="gpt-4-0613") # type: ignore
# Define the list of tools that the agent can use
tools = [CalendarEventTool()]
# Initialize the agent
agent = initialize_agent(tools, llm,
agent=AgentType.OPENAI_FUNCTIONS,
verbose=True)
user_input = input(f"{Emojis.ASSISTANT} How can I assist you today?\n"
f"{Emojis.USER} > ")
agent.run(input=user_input)
最后但同样重要的是,是时候将这一切付诸实践了。运行你的 Python 文件,并尝试一个测试查询,例如:
“为我与 John 的会议创建一个新的事件,时间定在 8 月 10 日中午(雅典/希腊)。别忘了邀请 John(john@example.com)并设置一个在线会议室。”
这可能会非常顺利,不过你可能需要一些提示工程。但是你做到了!现在,你已经具备了创建可以执行任何你能编写代码的 AI 代理的能力。这有多令人兴奋?
结论
在这次激动人心的旅程中,我们探索了 OpenAI 的大型语言模型(LLMs)如何与外部工具和 API 集成。我们深入探讨了这一改变游戏规则的特性如何重新定义我们与 AI 和互联网的互动,承诺未来 AI 代理可以无缝地执行任何可以编程定义的任务。
从理解函数签名的重要参数到定义 Python 类,我们采取了实践的方法来创建一个可以与 Google 日历 API 交互的模型,使用了一个快速发展的 Python 库 LangChain。
我们的尝试结果是一个能够在个人 Google 日历上创建事件的 AI 代理。然而,你现在应该能够创建一个可以执行任何你能放入函数中的任务的代理。
AI 的未来以及它与数字世界的互动前景光明,随着我们继续利用 AI 的力量,我们可以实现的可能性是无限的。让我们看看你接下来会构建什么!
关于作者
我的名字是 Dimitris Poulopoulos,我是一名在 HPE 工作的机器学习工程师。我为主要客户如欧盟委员会、国际货币基金组织、欧洲中央银行、宜家、Roblox 等设计和实施了人工智能和软件解决方案。
如果你对阅读更多关于机器学习、深度学习、数据科学和数据操作的文章感兴趣,可以在 Medium、LinkedIn 或在 Twitter 上关注 @james2pl。
所表达的观点仅代表我个人,并不反映我雇主的观点或意见。
用一种新的 AI 模型打破蛋白质设计的界限,该模型理解与任何类型分子的相互作用
这一新模型有助于 扩展 ML 模型在工程蛋白质中所需功能的适用性,通过调整其与任何类型分子的特定相互作用,从而有效影响生物技术和临床应用。
LucianoSphere (Luciano Abriata, PhD)
·发表于 Towards Data Science ·阅读时间 5 分钟·2023 年 6 月 21 日
--

作者通过编辑 Dall-E-2 生成的概念艺术图“蛋白质工程”(最初使用 此处)。
在 Deepmind 的 AlphaFold 在结构生物学领域引发革命之后,紧密相关的蛋白质设计领域最近通过深度学习的力量进入了一个新的进步时代。然而,现有的蛋白质设计机器学习 (ML) 模型在将非蛋白质实体纳入设计过程中存在局限,只能处理蛋白质组件。在我们的新预印本中,我们介绍了一种新的深度学习模型,“CARBonAra”,该模型考虑了蛋白质周围的任何分子环境,因此能够设计能够结合任何类型分子的蛋白质:药物类似的配体、辅因子、底物、核酸,甚至其他蛋白质。通过利用我们之前的 ML 模型中的几何变换器架构,CARBonAra 从骨架支架中预测蛋白质序列,同时考虑到任何性质的分子施加的约束。这种开创性的方法有助于通过调整与任何类型细胞成分的特定相互作用来扩展 ML 模型在工程蛋白质中所需功能的多样性。

方案概述了这个新深度学习模型可以做什么:计算蛋白质设计中氨基酸的概率,从一个目标蛋白质骨架开始,该骨架周围有其他分子(在此处由顶部的绿色分子示例)。图像由作者制作。
引言
作为数据科学家,我们不断努力突破可能性的界限。蛋白质设计,即创造具有期望功能和属性的新蛋白质,是一个这样的行动领域;特别是一个在生物学、医学、生命技术和材料科学等多个学科中具有深远影响的领域。尽管基于物理的方法在找到折叠成给定蛋白质结构的氨基酸序列方面取得了进展,但深度学习技术已经成为游戏规则的改变者,显著提高了设计成功率和多样性。
我最近在这里讨论了四种现代机器学习模型用于蛋白质设计和工程:
由于这些基于人工智能的方法和工具,蛋白质生物技术进入了前所未有的激动人心的时代
[towardsdatascience.com
虽然这些模型在许多蛋白质设计任务中取得了成功,但它们在设计过程中考虑非蛋白质实体的能力有限——它们根本无法处理这些实体,这一限制影响了它们的多样性并缩小了它们的应用范围。
为了克服这一挑战,我们在最新的预印本中介绍了一种名为 CARBonAra 的新模型,该模型通过接受作为输入的目标蛋白质支架和任何类型的相互作用分子来彻底改变蛋白质序列设计。这里是预印本:
蛋白质设计和工程在利用深度学习的进展方面正以空前的速度发展。当前……
CARBonAra 基于我们的蛋白质结构变换器(PeSTo),一种几何变换器架构,该架构处理原子点云,忽略原子类型以元素名称直接表示分子。我之前详细描述了 PeSTo:
详细了解新预印本 [## 新预印本描述了一种新型的无参数几何变换器,用于…
并且运行速度如此之快,以至于它甚至可以扫描大量蛋白质结构以搜索易于交互的氨基酸…
CARBonAra 的核心基于 PeSTo 模型,使其能够将任何种类的非蛋白质分子,包括核酸、脂质、离子、小配体、辅因子或其他蛋白质,纳入新蛋白质的设计过程中。因此,给定一个带有一个或多个配体的输入蛋白质结构,CARBonAra 预测氨基酸的残基信心,通过这些信心的最大值可以重建蛋白质序列。为此,CARBonAra 以骨架作为输入,并生成潜在序列的空间,这些序列可以通过特定的功能或结构要求进一步约束——例如,固定某些氨基酸,若它们对某一功能至关重要。CARBonAra 通过考虑蛋白质的分子背景,为蛋白质设计提供了前所未有的灵活性和深度,这意味着它可以为绑定离子、底物、核酸、脂质、其他蛋白质等特定区域进行设计。
在我们的评估中,CARBonAra 的表现与诸如 ProteinMPNN 和 ESM-IF1 等最先进的方法相当,同时展示了类似的计算效率——都非常快速。该模型在设计蛋白质单体和蛋白质复合物方面实现了与 ProteinMPNN 和 ESM-IF1 相似的序列恢复率,但除此之外,它还能处理包含非蛋白质分子的蛋白质设计,这是其他方法无法处理的。
CARBonAra 的一个显著特点是其能够通过整合各种约束来定制序列以满足特定目标。例如,它可以优化序列相似性、最小化相似度或实现低序列相似性。此外,通过利用 CARBonAra 与分子动力学模拟的结构轨迹,我们观察到可以提高序列恢复率,特别是在以前的方法显示成功率较低的情况下。
要了解更多关于该方法,特别是 ML 架构的细节,请查看我们在 bioRxiv 上的预印本:
相关文献 [## 上下文感知的几何深度学习用于蛋白质序列设计
蛋白质设计和工程正在以空前的速度发展,利用深度学习的进步。目前…
一些与结构生物学中的人工智能相关的文章
[## 超过一年 AlphaFold 2 免费使用及其在生物学中引发的革命
自信地建模蛋白质结构,预测它们与其他生物分子的相互作用,甚至蛋白质…
medium.com [## 通过共识方法设计稳定蛋白质的网络应用程序,使用 JavaScript、ESMFold 创建…
结合现代技术和工具进行高效工作,创建一个实现最简单但如今最有效的应用程序…
javascript.plainenglish.io ## “ML-Everything”?平衡科学中机器学习方法的数量和质量
需要适当的验证和良好的数据集,客观且平衡,并且预测在现实中有用…
[towardsdatascience.com ## 巨大的蛋白质语言模型如何颠覆结构生物学
结构预测的准确性与 AlphaFold 相似,但速度提高了高达 60 倍——同时开发了新的人工智能方法…
[towardsdatascience.com
www.lucianoabriata.com 我写作和拍摄的内容涵盖了我广泛兴趣领域中的一切:自然、科学、技术、编程等
在这里给我小费或者成为 Medium 会员以获取所有它的故事(我将获得少量收入,你无需付费)。订阅以获取我的新故事通过电子邮件。在我的服务页面上咨询有关小型工作。您可以 在这里联系我。
什么是组合优化?
展示组合爆炸的真正威力以及它们如何影响简单问题。
·发表于 Towards Data Science ·阅读时间 5 分钟·2023 年 4 月 10 日
--

图片来源 Shubham Dhage 于 Unsplash
什么是组合优化?
想象一下你是一名数据科学家,为一家航空公司工作,你被要求找到一周内的最佳航班安排,以在燃料和空域等限制条件下最大化航班数量。
你会怎么做?
好吧,你可以尝试每一个可能的解决方案,这称为 蛮力搜索,但如果我告诉你我们要进行 500 次航班呢?你需要尝试的不同 组合 数量将达到 ~500! 的规模,这大约等于 1.22 x 10¹¹³⁴. 这个数字非常巨大,使得蛮力搜索变得 不可处理。
那么,你如何解决这个问题?
组合优化!
组合优化 处理从有限对象中寻找最优解的问题,而该问题通常是难以处理的。这听起来很像我们上面的问题!
实际上,组合优化在许多领域中都有应用:
-
物流和供应链
-
制造业
-
金融
-
医疗保健
这使得理解和开发处理这些类型问题的技能成为数据科学家必备的能力。
要了解更多关于组合优化方法的知识,这些方法主要是元启发式,你可以查看我以前的一些帖子:
使用模拟退火优化算法获得旅行商问题的最佳解决方案
towardsdatascience.com ## Tabu 搜索简单解释
对 Tabu 搜索优化算法的直观解释及其如何应用于旅行商问题…
towardsdatascience.com ## 爬山算法简单解释
最受欢迎的优化算法之一的直观解释
towardsdatascience.com
然而,在这篇文章中,我想真正探讨这些问题为何如此困难,并演示一个简单的问题如何在复杂性上‘爆炸’。
旧的最爱:旅行商问题
组合优化问题的经典例子是旅行商问题 (TSP)。这是一个非常简单的问题,提出了以下问题:
‘找到一组城市中最短的路线,访问每个城市一次并回到起点城市’
听起来很简单,对吧?在现实场景中,情况远非如此。
这是因为随着城市数量的增加,问题的可能解决方案数量会导致组合爆炸。TSP 的解决方案数量是:

作者在 LaTeX 中的方程。
其中 n 是城市的数量。
让我们用一些实际数字来真正展示这种‘爆炸’:
-
n = 4: 解决方案 = 3
-
n = 8: 解决方案 = 2,520
-
n = 16: 解决方案 = 1.0461395 x 10¹³
更值得注意的是,对于20个城市,通过暴力破解解决 TSP 需要大约 1,900 年!
展示组合爆炸
让我们真正深入探讨一些代码,以全面理解暴力算法的内部机制,并描绘组合爆炸。
以下是使用暴力搜索解决 TSP 的一个模板类:
GitHub Gist 作者。
现在让我们使用算法对一些虚拟‘假’城市数据进行测试,以确认它确实返回了最佳解决方案:
GitHub Gist 作者。

作者在 Python 中生成的图表。

作者在 Python 中生成的图表。
看起来不错!初始解决方案显然不是最优的,而最终的最佳解决方案在视觉上也显得是最好的。
我们现在将测量暴力算法的持续时间与城市数量的关系,以图形方式描绘组合爆炸:
GitHub Gist 作者。

作者在 Python 中生成的图表。
从上面的图表中,我们看到暴力搜索的持续时间在大约12个城市时急剧增加。用实际数字来讲,11个城市时花费了0.5***分钟,但***12***个城市时花费了***36分钟!所以,仅仅增加一个城市,算法计算时间增加了72倍!
这真正展示了这些组合优化问题的强大和难以解决性。
总结与进一步思考
在这篇文章中,我们看到简单问题如何在复杂性上“爆炸”。在处理大规模系统和网络的行业中,这是很常见的。优化这些大规模系统中的业务问题的过程被称为组合优化。这个领域的需求源于暴力搜索的难以处理性,它可能需要数千年才能找到相对搜索空间的最佳解决方案。我们通过描绘旅行商问题中发生的组合爆炸来展示这种现象。
本文中使用的完整代码可以在我的 GitHub 上找到:
## Medium-Articles/Optimisation/brute-force at main · egorhowell/Medium-Articles
你当前无法执行该操作。你在另一个标签或窗口中登录了。你在另一个标签或…
另一件事!
我有一个免费的通讯稿,Dishing the Data,在其中我每周分享成为更好数据科学家的技巧。没有“废话”或“点击诱饵”,只有来自实践数据科学家的纯粹可操作的见解。
## Dishing The Data | Egor Howell | Substack
如何成为更优秀的数据科学家。点击阅读由 Egor Howell 撰写的《Dishing The Data》,这是一个 Substack 出版物,内容包括…
与我联系!
参考文献与进一步阅读
-
优化算法。 Mykel J. Kochenderfer 和 Tim A. Wheeler。 2019。
-
组合优化:理论与算法 Bernhard Korte 和 Jens Vygen。 2018。
广义线性模型介绍
原文:
towardsdatascience.com/breaking-down-generalized-linear-models-d9212526e51d
扩展你的建模技能,超越线性回归
·发表于 Towards Data Science ·阅读时间 6 分钟·2023 年 7 月 10 日
--

图片来源:Roman Mager在Unsplash
背景
线性回归是我们在数据科学中学习的最常见算法。每个从业者都听说过并使用过它。然而,对于某些问题,它并不适用,我们需要对其进行‘推广’。这就是广义线性模型 (GLMs)的用武之地,它们为回归建模提供了更大的灵活性,并且是数据科学家必须了解的宝贵工具。
什么是 GLMs?
正如我们上面所说,GLMs‘推广’了普通线性回归,但我们真正是什么意思呢?
让我们考虑一个更简单的线性回归模型:

作者通过 LaTeX 生成的方程。
其中 β 是系数, x 是解释变量,ε 是正态分布的误差。
假设我们想要模拟保险公司在一小时内接到的索赔电话数量。线性回归是否适合这个问题?
不!
原因如下:
-
线性回归假设误差服从正态分布,而正态分布可以取负值。然而,我们不能有负的索赔电话。
-
第二点是正态分布,因此线性回归是连续的。而索赔电话都是整数和离散的,我们不能有 1.1 个电话。
因此,线性回归模型不能正确处理这个确切的问题。然而,我们可以将回归模型推广到符合上述要求的概率分布。在这种情况下,就是 泊松分布(稍后将详细介绍)。
GLM 简单地提供了一个框架,说明我们如何将输入与目标分布的期望输出链接起来。它们帮助将许多回归模型统一在一个“数学伞”下。
关于泊松分布的补充视频。
理论框架
概述
GLM 的基础依赖于三个关键因素:
现在我们来逐一解释这些内容的含义。
线性预测器
这是最简单理解的一个。线性预测器,η,仅仅意味着我们有一个输入(解释变量/协变量)x 的线性和,乘以其对应的系数 β:

由作者在 LaTeX 中生成的方程。
链接函数
链接函数,g,实际上负责将线性预测器与目标分布的均值响应 μ 进行“链接”:

由作者在 LaTeX 中生成的方程。
指数族
概述
GLM 的一个要求是输出的目标分布需要是 指数族。 这个分布族包含了许多你可能听说过的著名分布,如 泊松分布、 二项分布、 伽马分布、 和 指数分布。
在 GLM 框架中,我们实际上使用 指数离散模型,这是指数族的进一步推广。
为了属于指数族,概率密度(PDF)或质量函数(PMF)需要重新因式分解并参数化成以下形式:

由作者用 LaTeX 生成的公式。
这种形式是为了统计上的方便选择的,但在本文中我们不需要过多关注为什么会这样。
注意有两个参数 θ,这是自然或 canonical 参数,将输入与输出相关联,以及 ϕ,这是离散参数。
另一个有趣的事实是,指数族中的分布都有 共轭先验。这使得它们对于 贝叶斯问题 非常有用。如果你想了解更多关于共轭先验的信息,请查看我的相关文章:
一种计算上有效的贝叶斯统计方法
towardsdatascience.com
规范链接函数
有一种叫做规范链接函数的东西,定义为:

由作者用 LaTeX 生成的公式。
因此,如果我们能够用 μ 描述 θ,那么我们已经推导出了目标分布的自然链接函数!
均值和方差
数学上可以证明,指数族的均值,E(Y),由以下公式给出:

由作者用 LaTeX 生成的公式。
同样,方差,Var(Y),由以下公式给出:

由作者用 LaTeX 生成的公式。
如果你想查看这一推导的证明,请参考以下 linked 书的第 29 页。一般来说,解决方法是对 θ 的对数似然函数取导数。
泊松回归示例
泊松分布
泊松分布是一个著名的离散概率分布,用于建模在已知均值发生率下,事件发生特定次数的概率。如果你想了解更多,请查看我之前的帖子:
对最著名概率分布之一的概述
towardsdatascience.com
其 PMF 由以下公式给出:

由作者在 LaTeX 中生成的方程。
其中:
-
e: 欧拉数(约 2.73)
-
x: 出现次数(≥ 0)
-
λ: 期望出现次数(≥ 0),这也是 GLM 表示法中的均值μ
指数形式
我们可以通过对两边取自然对数,将上述泊松 PMF 写成指数形式:

由作者在 LaTeX 中生成的方程。
然后,我们对两边进行欧拉数的指数运算:

由作者在 LaTeX 中生成的方程。
现在,泊松 PMF 已经是指数形式了!
通过将系数与上述方程和指数家族 PDF 对应,我们得出以下结果:

由作者在 LaTeX 中生成的方程。
因此,泊松分布的均值和方差为:

由作者在 LaTeX 中生成的方程。
这是泊松分布的一个已知结果,我们只是推导出了一种不同的方法!
泊松 GLM
泊松分布的标准链接函数为:

由作者在 LaTeX 中生成的方程。
因此,泊松回归方程为:

由作者在 LaTeX 中生成的方程。
我们可以验证这个方程的输出只能为正,因此符合预测保险公司接到的索赔电话数量的问题的要求。
然后你可以通过 最大似然估计 或 迭代加权最小二乘法 来求解估计量。
重点是什么?
你可能会想,为什么我刚刚带你经历了这么繁琐的数学。好吧,让我快速总结一下关键要点:
-
检查问题的要求和目标分布以避免不合理的结果至关重要。
-
GLM 提供了一种从数学原理出发的方法,帮助你将输入与特定问题的期望输出联系起来。
总结与进一步思考
标准的线性回归模型很强大,但不适用于所有类型的问题,比如输出是非负的情况。对于这些特定的问题,我们必须使用其他分布,如泊松分布,广义线性模型(GLMs)提供了一个框架来执行这个过程。它们通过从基本原理推导一个链接函数,使你能够将输入转换为期望的目标输出分布。GLMs 是一个强大的建模工具,大多数数据科学家应该至少了解它们,因为它们的多功能性。
还有其他内容!
我有一个免费的新闻通讯,数据分析分享,每周分享成为更好的数据科学家的实用技巧。没有“虚 fluff”或“点击诱饵”,只有来自实践数据科学家的纯粹可操作见解。
[## 数据分析分享 | Egor Howell | Substack
如何成为更好的数据科学家。点击阅读由 Egor Howell 编写的《数据分析分享》,这是一个 Substack 出版物,内容涵盖…
newsletter.egorhowell.com](https://newsletter.egorhowell.com/?source=post_page-----d9212526e51d--------------------------------)
与我联系!
参考资料与进一步阅读
-
一些进一步有趣的理论和推导:
www.cs.princeton.edu/courses/archive/fall11/cos597C/lectures/exponential-families.pdf -
GLMs 的经典文献:
www.utstat.toronto.edu/~brunner/oldclass/2201s11/readings/glmbook.pdf
解构 YouTube 的推荐算法
原文:
towardsdatascience.com/breaking-down-youtubes-recommendation-algorithm-94aa3aa066c6
打开“窍门箱”,让现代推荐系统得以运作
·发表在Towards Data Science ·阅读时间 7 分钟·2023 年 4 月 17 日
--

(Logo 设计 Eyestetix Studio,背景设计 Dan Cristian Pădureț)
推荐系统已成为我们时代最普遍的工业机器学习应用之一,但关于它们在实践中如何运作的出版物却很少。
一个显著的例外是 Paul Covington 的论文“深度神经网络用于 YouTube 推荐”,其中充满了关于 YouTube 深度学习驱动的推荐算法的许多实际见解和学习,提供了一个罕见的视角,不仅揭示了现代工业推荐系统的内部工作,还揭示了今天的机器学习工程师正在尝试解决的问题。
如果你想深入了解现代推荐系统,为机器学习设计面试做准备,或者只是对 YouTube 如何吸引用户感到好奇,请继续阅读。在这篇文章中,我们将探讨论文中的 8 个关键见解,帮助解释 YouTube(以及任何现代推荐系统)的成功。
让我们开始吧。
1 — 推荐 = 候选生成 + 排序
YouTube 的推荐系统分为两个阶段:候选生成阶段,将数十亿的视频筛选到几百个,并且排序阶段,进一步缩小和排序最终展示给用户的候选视频。
从技术上讲,这两个阶段都包含一个双塔神经网络——一种具有两个分支分别处理用户 ID 和视频 ID 的特殊架构——但它们的训练目标不同:
-
对于候选生成模型,学习问题被表述为一个极端多类分类问题:在所有现有视频中,预测用户互动的视频。
-
对于排名模型,学习问题被表述为一个(加权的)逻辑回归问题:给定一个用户/视频对,预测用户是否与该视频互动。
这一设计选择的动机是将找到最佳内容的问题分解为召回优化和精度优化:候选生成优化召回,即确保我们捕捉到所有相关内容,而排名优化精度,即确保我们首先展示最佳内容。以这种方式分解问题是使推荐系统能够在数十亿用户和视频的规模上运行的关键。

YouTube 的两阶段推荐漏斗。来自 Covington 2016,YouTube 推荐的深度神经网络
2 — 隐式标签比显式标签效果更好
显式用户反馈,如点赞、分享或评论,非常稀少:在所有观看特定视频的用户中,只有少数人会留下显式反馈。因此,仅根据点赞训练的模型会留下大量信息未被利用。
隐式标签,如用户点击和观看时间,稍微有些噪声——用户可能会偶然点击——但数量级上要多得多。在 YouTube 的规模下,标签数量优于标签质量,因此在他们的模型中使用隐式反馈作为训练目标效果更好。
3 — 观看顺序很重要
特定用户的观看历史形成的序列不是随机的,而是包含具有不对称共同观看概率的特定模式。例如,在观看了两个来自同一创作者的视频之后,用户很可能会观看来自该创作者的另一个视频。一个简单地学习预测一个展示的视频是否会被观看的模型,忽略了用户最近的观看历史表现不好:同样,它也留下了信息未被利用。
相反,YouTube 的模型通过给定用户最新的观看(和搜索)历史来学习预测下一个观看的视频。从技术上讲,它通过将用户最新的 50 个观看视频和 50 个搜索查询作为特征输入到模型中来实现这一点。
4 — 排名模型使用加权逻辑回归进行训练
正面训练示例(点击的展示)根据其观察到的观看时间进行加权,而负面训练示例(没有点击的展示)则收到单位权重。这个加权方案的目的是减少点击诱饵内容的权重,增加导致更有意义和更长时间参与的内容的权重。
从数学上讲,这样的加权逻辑回归模型学到的赔率大致等于预期观看时间。在推理时,我们可以通过应用指数函数将预测的赔率转换为观看时间。能够预测观看时间可以得出下一个关键洞察:
5 — 通过预测观看时间排名优于通过点击率排名
这是因为点击率排名会促进低观看时间的点击诱饵内容:用户点击但很快返回。通过预测观看时间进行排名可以降低点击诱饵内容的排名,从而提供更具吸引力的推荐。
6 — 多样的特征集是高模型性能的关键
深度学习相对于线性或树模型的优势在于它可以处理多样的输入信号。YouTube 的模型会考虑:
-
观看历史:用户最近观看了哪些视频?
-
搜索历史:用户最近搜索了哪些关键词?
-
人口统计特征,例如用户性别、年龄、地理位置和设备,这些特征为“冷启动”用户,即没有历史记录的用户,提供了先验信息。
确实,特征多样性是实现高模型性能的关键:作者们展示了,使用所有这些特征训练的模型相较于仅使用观看历史训练的模型,持出 MAP 从 6%提高到 13%。

YouTube 的排名神经网络模型接受多样的特征集。从 Covington 2016,Deep Neural Networks for YouTube Recommendations
7 — 由于“示例年龄”特征,内容保持新鲜
ML 系统往往对过去有偏见,因为它们是在历史数据上训练的。对于 YouTube 来说,这是一个问题,因为用户通常更喜欢最近上传的“新鲜”内容,而不是很久以前上传的内容。
为了修复这种“过去偏见”,YouTube 将训练示例的年龄作为模型中的一个特征,并在推理时将其设置为 0,以反映模型在训练窗口的最后阶段进行预测。例如,如果训练数据包含 30 天的数据窗口,这一“示例年龄”特征会从 31(训练数据中的第一天)变化到 1(最后一天)。
作者们展示了引入这一特性使得模型更倾向于推荐新内容,这正是 YouTube 所希望的。
8 — 稀疏特征被编码为低维嵌入
YouTube 的排名模型使用了大量高维度(“稀疏”)的分类特征,例如
-
视频 ID 和用户 ID,
-
分词搜索查询,
-
用户最近观看的最后 50 个视频,或
-
启动当前观看会话的“种子”视频。
这些稀疏特征被独热编码,并映射到在模型训练过程中学习的 32 维嵌入中,然后作为嵌入表存储用于推理。
为了限制嵌入表所造成的内存占用,ID 空间被截断,只包括最常见的 ID。例如,如果一个视频在训练期间只被观看过一次,那么它不值得在嵌入表中占有一席之地,因此会被视为与从未观看过的视频相同。
另一个值得注意的技巧是,相同 ID 空间内的稀疏特征共享相同的基础嵌入。例如,存在一个全球性的、用于许多不同特征的单一视频 ID 嵌入,例如印象的视频 ID、用户最近观看的视频 ID 或当前会话的种子视频 ID。以这种方式共享嵌入有 3 个好处:
-
它节省了内存,因为需要存储的嵌入表更少,
-
它加快了模型训练,因为需要学习的参数更少,并且
-
它提高了泛化能力,因为它使模型能够获得有关每个 ID 的更多上下文。
附录:技巧包
总结来说,YouTube 的推荐系统没有真正让它特别的单一因素。它是一个“技巧包”,每个技巧解决一个特定的问题:
-
2 阶段漏斗设计解决了可扩展性问题,
-
使用加权逻辑回归和按预期观看时间排序解决了点击诱饵问题,
-
添加“示例年龄”特征解决了过去偏见问题。
-
添加人口统计特征解决了冷启动问题,
-
预测“下一次观看”(而不是随机观看)解决了不对称共同观看概率问题,
-
在具有相同 ID 的类别特征之间共享嵌入解决了有限内存问题,
等等。
这个技巧包不仅为现代推荐系统的内部运作提供了极大的洞察力,也展示了现代 ML 工程师的工作:我们解决问题以使模型更好。最优秀的 ML 工程师是那些随着时间积累了最佳技巧包的人。
最后,Covington 的论文现在已经好几年了,肯定有些技巧已经被更新更好的技巧所取代。有编码稀疏特征的新技巧,也有去偏排名模型的新技巧。这是工业 ML 应用程序固有的另一个方面:随着新突破的出现,我们的模型不断演变。
ML 工程师永远不会“完成”。
想扩展你的个人 ML “技巧包”?想深入了解现代工业 ML 应用背后的原理?查看我的电子书, 《地面上的机器学习:现实世界 ML 应用的设计与操作》。
用 ReLU 打破线性
原文:
towardsdatascience.com/breaking-linearity-with-relu-d2cfa7ebf264
解释 ReLU 激活函数如何以及为何是非线性的
·发表于 数据科学前沿 ·阅读时间 4 分钟·2023 年 3 月 1 日
--

图片由 Alina Grubnyak 提供,来源于 Unsplash
介绍
神经网络 和 深度学习 是人们转行进入数据科学的最受欢迎的原因之一。然而,这种兴奋可能会导致忽视神经网络的核心概念。在这篇文章中,我想讨论神经网络的一个关键特性,我认为大多数从业者应该了解,以充分理解其内部运作。
我们为什么需要激活函数?
激活函数 在数据科学和机器学习中无处不在。它们通常指的是应用于神经网络中神经元的 线性 输入的变换:

作者在 LaTeX 中的方程。
其中 f 是激活函数,y 是输出,b 是偏置,w_i 和 x_i 是 权重 和它们对应的特征值。
但是,我们为什么需要激活函数呢?
简单的答案是,它们使我们能够建模复杂的模式,且通过使神经网络变得 非线性 来实现。如果网络中没有非线性激活函数,那么整个模型就会变成一个 线性回归 模型!
非线性是指输入的变化与相应输出的变化不成比例。
例如,考虑一个前馈的两层神经网络,中间层有两个神经元(忽略偏置项):

作者用 LaTeX 写的方程。
我们已经成功将我们的两层网络简化为单层网络!上述推导中的最终方程仅仅是一个具有特征 x_1 和 x_2 及其对应系数的线性回归模型。
因此,我们的“深度神经网络”将会简化为单层,变成传统的线性回归模型!这不好,因为神经网络将无法对数据建模或拟合复杂函数。
线性函数的正式数学定义是:

作者用 LaTeX 写的方程。
这是一个非常简单的例子:

作者用 LaTeX 写的方程。
所以函数 f(x) = 10x 是线性的!
注意,如果我们在上面的方程中添加一个偏置项,它就不再是线性函数,而是一个 仿射函数。请参阅这个状态交换讨论讨论为什么会这样。
ReLU
整流线性单元 (ReLU) 是最流行的激活函数,因为它计算高效,并且解决了梯度消失问题。
数学上,该函数表达为:

作者用 LaTeX 写的方程。
我们可以用 Python 进行图形化展示:

由作者用 Python 生成的图。
为什么 ReLU 是非线性的?
ReLU 函数可能看起来是线性的,因为有两条直线。实际上,它是分段线性的。然而,正是这两条不同的直线使其成为非线性。
我们可以通过执行与上述相同的示例但使用 ReLU 函数来证明它是非线性的:

作者用 LaTeX 写的方程。
让我们分解一下:

作者用 LaTeX 写的方程。
因此,ReLU 是非线性的!
我已经在这里链接了一篇很好的文章,展示了如何使用 ReLU 创建任何函数。
总结和进一步的思考
非线性在神经网络中是至关重要的,因为它允许算法推断数据中的复杂模式。非线性是通过激活函数来实现的,其中最著名的是 ReLU,它在计算效率和解决训练神经网络时已知的问题方面表现优异。ReLU 函数是分段线性的,这就是它如上所述在数学上表现为非线性的原因。
完整的代码可以在我的 GitHub 上找到:
[## Medium-Articles/relu.py at main · egorhowell/Medium-Articles
你现在不能执行该操作。你在另一个标签页或窗口中登录了。你在另一个标签页或窗口中登出了……
另一个事项!
我有一个免费的通讯,Dishing the Data,在其中我每周分享成为更好的数据科学家的技巧。没有“虚浮内容”或“点击诱饵”,只有来自实际数据科学家的纯粹可操作的见解。
[## Dishing The Data | Egor Howell | Substack
如何成为更好的数据科学家。点击阅读《Dishing The Data》,作者为 Egor Howell,这是一个 Substack 出版物……
与我联系!
利用这个 Python 库弥合数据与人类之间的差距
让你的 Python 输出更易于理解
·发布于 Towards Data Science ·4 分钟阅读·2023 年 2 月 22 日
--
介绍
我们无需依赖任何统计数据就能意识到,Python 是软件开发人员、数据科学家等最常用的编程语言之一。这不仅因为其灵活性和易用性,还因为有大量的库可以使我们的日常任务更加轻松。
本文介绍了另一个强大的库:humanize。它通过使输出更易于理解,帮助弥合人类与 Python 输出之间的差距。让我们看看一些例子。
开始使用
为了使用humanize,第一步是使用 Python 包管理器pip进行安装,如下所示:
!pip3 install humanize
接下来,你需要导入以下相关库以成功完成教程。
-
getsize()来自os库,用于获取给定文件的大小。 -
datetime用于处理时间。 -
最后,是本文的核心——
humanize库。
from os.path import getsize
import datetime as dt
import humanize as h
一切准备就绪,开始探索大数字。
使大数字更具可读性
这个数字 1034503576643 是什么?
理解这个数字是否在十亿还是万亿范围内需要一些脑力。humanizer通过提供更友好的输出,试图减轻这种负担。
一种方法是使用正确的逗号**','**来分隔,方法是使用intcomma函数,如下所示:
big_num = 1034503576643
human_big_num_coma = h.intcomma(big_num)
print(human_big_num_coma)
上述代码的输出是1,034,503,576,643,比没有分隔符的原始数字要好得多。
此外,结果可以使用intword函数生成自然语言格式,如下所示:
human_big_num = h.intword(big_num)
print(human_big_num)
这会产生以下结果:1.0 trillion.
处理日期时间
2022 年 9 月 6 日(YYYY/MM/DD 格式)是 2022 年 9 月 6 日
第二种格式(Sep 6 2022)比第一种 YYYY/MM/DD 格式更容易被理解,因为它符合我们日常的口头交流。这种结果可以通过naturaldate函数获得。
date = dt.date(2022, 9, 6)
human_date = h.naturaldate(date)
print(human_date)
这将生成以下结果:Sep 06 2022。
可以使用naturalday函数将结果限制为月份和日期,而不是使用naturaldate。
human_day = h.naturalday(date)
print(human_day)
结果是Sep 06
处理持续时间
类似于 DateTime,也可以使用naturaltime函数使持续时间具有可读性,如下所示。
# Get today's date
current_time = dt.datetime.now()
# Get the date of 3 days before
few_days_before = dt.timedelta(days=3, hours=23, minutes=40)
# Compute the difference of time
past_time = current_time - few_days_before
human_time = h.naturaltime(past_time)
print(human_time)
之前的代码生成了3 days ago,这是任何人都可以理解的。
获取文件的大小和单位
我的文件大小是 278。
这个声明最明显的问题是
你使用的是什么单位?字节、千字节、兆字节、千兆字节、太字节?
这个谜题可以通过使用naturalsize函数解决,如下所示:
-
首先,使用
getsize函数获取CSV文件的大小。 -
然后使用
naturalsize函数生成更合适的输出。
fize_size = getsize("./candidates.csv")
# Before Humanize
print(fize_size)
# After Humanize
print(h.naturalsize(fize_size))
-
人性化处理前的结果是 278。
-
人性化处理后,我们得到了278 Bytes。
科学记数法和分数
给定数字的科学记数法在某些场景中可能更有用,例如使用power of the ten记法。这可以通过使用scientific函数来实现。
使用precision参数,用户可以指定小数点后的精度值数量。如果未指定,精度值为 2。
下面是一个示例。
# Number to convert to scientific format
value = 2304355
# Without Precision
scientic_notation = h.scientific(value)
print(scientic_notation)
# With precision
scientic_notation = h.scientific(value, precision = 5)
print(scientic_notation)
输出结果按print语句的顺序给出。
-
使用默认函数:2.30 x 10⁶
-
使用
precision参数:2.30436 x 10⁶
你认为 0.4646 的分数表示是什么?
避免过多的数学计算,只需使用fractional函数,如下所示:
float_value = 0.4646
# Get the fractional representation
fraction = h.fractional(float_value)
print(fraction)
答案是 105/226。这真的很酷,不是吗!
如果我处理的是另一种语言怎么办
之前的所有结果都是用英语给出的。其他语言如法语、俄语等也可以实现相同的效果。
实现这一点的第一步是使用i18n.activate函数激活国际化(i18n)功能。
例如,可以创建一个持续时间为 3 秒的时间差对象,但这次用法语。
# Activate the French Language
_t = h.i18n.activate("fr")
# Generate the time delta
h.naturaltime(dt.timedelta(seconds=3))
结果是il y a 3 secondes,这在英语中表示3 seconds ago。
结论
感谢阅读! 🎉 🍾
希望你觉得这篇文章有帮助!
如果你喜欢阅读我的故事并希望支持我的写作,可以考虑成为 Medium 会员。通过每月$5 的承诺,你可以无限制访问 Medium 上的故事。
你想请我喝咖啡吗☕️?→ 在这里请我!
随时欢迎关注我的 Medium、Twitter 和 YouTube,或者在 LinkedIn 上打个招呼。讨论人工智能、机器学习、数据科学、自然语言处理和 MLOps 的内容总是很愉快的!
在你离开之前,请查看下面的该系列的最后两部分:
Pandas 和 Python 数据科学与数据分析技巧 — 第一部分
Pandas 和 Python 数据科学与数据分析技巧 — 第二部分
Pandas 和 Python 数据科学与数据分析技巧 — 第三部分
跨领域桥接:将金融、隐私和软件最佳实践融入机器学习风险管理
负责任的人工智能
理解超越传统模型风险管理的策略
·发表于 Towards Data Science ·阅读时间 10 分钟·2023 年 10 月 10 日
--

作者图像 | 图标:Flaticon(个人及商业用途免费)
“航空法则是以血的代价写成的。我们不应当用人工智能再现这种方法论” — Siméon Campos
在 2018 年,彭博社的报道“Zillow 的算法驱动的购买狂潮使其房屋翻转实验注定失败”引起了广泛关注。报道概述了 Zillow 大胆进入iBuying领域,寄希望于其基于机器学习的Zestimate算法来革新房屋翻转的盈利模式。尽管开始时结构安排得很周密,聘请了本地房地产专家来验证算法定价,Zillow 还是转向完全算法驱动的方法以追求更快的报价。然而,这一举动并没有取得预期的效果。

照片由 Tierra Mallorca 提供,来源于 Unsplash
Zestimate 试图适应 2021 年房地产市场的快速通货膨胀,这促使 Zillow 采取行动以增强其报价的吸引力。该公司开始了一场雄心勃勃的购房狂潮,报告称每季度收购多达 10,000 套房产。然而,人力资源难以跟上这些收购的庞大规模和速度,这一挑战因疫情的同时爆发而加剧。在面临不断增加的困难,包括未售出的房产积压时,Zillow 决定在 2021 年 10 月暂停其报价。随后几个月,房产以亏损价被转售,导致超过 5 亿美元的大规模库存减记。
除了其失败的投资所造成的巨大财务损失,Zillow 宣布将裁员约 2,000 名员工——公司的四分之一。
我们从一个相当不幸的事件开始讨论,因为 Zillow 的 iBuying 风险投资的崩溃嵌入了一个复杂的原因框架中。尽管无法将这一事件与 2020 年全球大流行从住房市场中剥离开来,但它确实为深入分析铺平了道路。在本文中,我们将以此为例,揭示我们系列中讨论的治理和风险管理原则如何可能避免未来类似的不幸事件。
在继续阅读之前
在继续之前,请注意这是我们 AI 风险管理系列的第三篇文章。建议阅读前两篇文章,以便全面理解。
• 第一篇文章展开了 机器学习风险管理的文化能力,探讨了在这一复杂领域中所需的人文维度。
/ ## 机器学习风险管理的文化能力
组织的文化是负责任 AI 的一个重要方面。
[towardsdatascience.com
• 第二篇文章将重点转向 ML 系统背景下的另一个重要元素:组织流程。与我们一起踏上这段启发之旅,以更好地掌握 AI 和风险管理的交织领域。
/ ## 机器学习风险管理的组织流程
组织流程是 ML 系统可靠性的一个关键非技术性决定因素。
[towardsdatascience.com
超越模型风险管理
在上一篇文章中,我们详细讨论了机器学习风险管理(MRM)如何构成一个全面的框架,以及一系列旨在识别、评估、缓解和监控与机器学习系统的开发、部署和操作相关的风险的程序。在这一部分,我们将探索超越传统模型风险管理的各种策略和实践,这些策略和实践在 ML 安全方面表现尤为出色。我们将从讨论 AI 事件响应开始。
📝AI 事件响应计划
Zillow 事件突显了 AI 失败的一个事件,展示了一个精心设计的算法如何无法跟上快速变化的房地产市场,导致重大财务和声誉损失。尽管经过最佳的模型训练和验证测试,但即便在SR-11 指南中也提到,消除模型风险是不可能的,这突显了制定可靠事件响应计划的紧迫性。
AI 事件计划是一个预先制定的策略,旨在快速有效地解决 AI 问题,帮助组织迅速识别、遏制和消除 AI 事件,并防止代价高昂或危险的情况,这对较小或新兴的组织尤其重要。这是计算机安全中的一项受认可的做法,像NIST和SANS等组织强调了它在管理机器学习和人工智能复杂性方面的重要性。与计算机事件响应一样,AI 事件计划分为六个明确的阶段,每个阶段都对降低 AI 风险至关重要。

人工智能事件计划的六个阶段 | 图片来源:作者
- 阶段 1:准备
为有效准备 AI 事件,组织应定义事件参数,分配响应预算,制定沟通计划,并实施技术保障措施。通过与关键人员进行桌面演练来模拟场景可以提高准备程度。

AI 事件准备阶段的起始问题 | 图片来源:作者
- 阶段 2:识别
识别涉及检测系统故障、攻击或滥用。它结合了一般安全方法和专门的 AI 监控,如检测概念漂移或算法歧视。一旦发现问题,将会提醒相关利益相关者,包括管理层。
- 阶段 3:遏制
控制指的是减轻事件造成的直接危害,目标是减少初始损害。事件可能会有蔓延的趋势,影响业务和客户的其他方面。解决这些问题的方法可能会根据其根本原因而有所不同,无论是外部攻击、内部错误还是 AI 系统的误用。在必要时,建议在控制阶段与公众进行沟通。
- 第四阶段:根除
根除意味着修复受影响的系统以停止问题。这可能通过阻止被攻击的系统以防止进一步的损害,或关闭故障的 AI 系统并暂时使用一个可信的、更简单的系统来实现。根除之后,事件不应再造成任何进一步的伤害。
- 第五阶段:恢复
恢复过程涉及修复受影响的系统、预防未来的问题,并且可能需要审查或改进技术程序,尤其是当问题是由于错误或恶意行为引起时。
- 第六阶段:经验教训
经验教训意味着根据在当前问题中有效和无效的措施,对我们对 AI 事件的响应进行更改或改进。这些改进可以涉及过程或使用的技术。
从 Zillow iBuying 事件中获得的经验:AI 事件响应的见解
在审查我们的 AI 事件响应计划后,让我们回到 Zillow iBuying 的事件中。我们能从这一章节中关于 Zillow iBuying 情况中获得什么见解呢?根据对此话题的公开报告,显然存在潜在的警示信号🚩。这些包括 Zillow 缺乏人工监督、对财务风险的评估不足,以及缺乏适当的治理结构。虽然关于 Zillow 的具体答案仍不确定,但这突显了从这一案例中提取有价值的经验以提高我们对 AI 相关挑战的准备和响应的重要性,包括:
-
经验教训 1:与领域专家验证。
-
经验教训 2:预见失败模式。
-
经验教训 3:治理至关重要。
-
经验教训 4:AI 事件可能会迅速扩大。
-
经验教训 5:新兴技术总是伴随风险。
提升风险管理的额外实践
除了上述讨论的 AI 事件响应,来自财务审计、数据隐私、软件开发最佳实践和 IT 安全的实践对这一领域也带来了重要价值。
🔍模型审计和评估
模型审计是针对机器学习系统的正式评估过程,确保符合特定政策、法规或法律。这些正式评估通常由独立第三方进行,重点关注透明性和全面测试。模型评估类似但更加非正式,可能由内部或外部团队进行,检查各种问题,如偏见、安全、数据隐私危害和安全漏洞。
想深入了解模型审计和评估,可以参考两篇论文:算法偏见和风险评估:实践经验的教训 和 弥合人工智能问责制缺口:定义内部算法审计的端到端框架,这些论文提供了进行这些审计和评估的宝贵见解和框架。
📈影响评估
影响评估在机器学习政策和拟议法律中越来越受到重视,用于预测和记录系统可能面临的挑战。这些评估使得人工智能设计师和操作员更容易理解和对其系统可能造成的问题负责。然而,这只是一个初步步骤。它们应定期进行,并与其他因素一起考虑,以全面了解风险。由非机器学习团队成员进行评估尤为重要,以避免任何偏见并确保彻底检查。
虽然影响评估在风险管理和治理策略中发挥了关键作用,但由独立专业人员执行以及与其他风险因素结合至关重要,以确保整体效能。
⚖️上诉、覆盖和选择退出
你是否见过 Google 搜索栏中的 举报不当预测 功能?这是用户指出问题的基本方式。这个功能允许用户对机器学习系统的决策进行挑战或更正。这一想法,也被称为可操作的补救或救济,可能有所不同。另一种方法是 选择退出 选项,允许用户跳过自动处理。这两个选项被许多数据隐私和美国消费者金融法律认可,对于保护消费者权益免受自动化机器学习错误至关重要。然而,由于需要规划和资源来从一开始就集成这些功能,许多机器学习系统仍然缺乏这些功能。

通过 Google 举报不当预测 | 图片作者
👩💻👨💻对编程和双重编程
机器学习算法可能复杂且不可预测,使得确保它们正常工作变得困难。一些顶级机器学习组织使用两种主要方法来进行双重检查:
-
对编程
— 两名专家分别编写相同的算法。
— 然后,他们合作解决任何差异,确保两个版本的工作方式相同。
有趣的是,大型语言模型(LLMs)现在正在被纳入对编程的配对编程中。最近的一门课程标题为与大型语言模型的配对编程深入探讨了与 LLMs 在实时编码场景中协作的细节。
-
双重编程
— 一个人用不同的编程语言两次编写相同的算法。
— 然后,他们比较并调和两个版本之间的任何差异。
两种方法都有助于早期发现和修复漏洞,确保算法在实际应用前是可靠的。
🔒模型部署的安全权限
IT 安全中有一个概念叫做最小权限,它强调系统用户不应拥有过多权限。尽管其重要性,但在机器学习系统中常常被忽视,这可能导致安全和性能问题。公认的做法是,让产品经理或高管等不同角色做出软件发布的最终决定,以避免偏见并确保全面评估。

最小权限原则通过 Intel x86| 由 Hertzsprung 于英语维基百科,CC BY-SA 3.0 演示。
在开发冲刺期间,数据科学家和工程师必须对他们的环境拥有完全的控制权。然而,随着重要发布或审查的临近,IT 权限应转移给组织内的其他角色。此控制权的转移作为一个检查点,确保未经批准或有缺陷的代码不会被部署,从而增强系统的安全性和可靠性。
💰漏洞奖励
漏洞奖励是组织提供给发现其软件问题的人,包括机器学习系统的奖励。它们不仅仅用于发现安全问题,还包括与安全性、隐私和可靠性等相关的问题。
通过提供金钱或其他奖励,组织鼓励人们反馈意见并发现他们机器学习系统中的问题,从而使这些系统更加可靠和安全。如果组织担心公开其漏洞奖励,他们可以举办内部活动,让不同团队在机器学习系统中寻找问题。关键是提供良好的激励措施以获得最佳结果。
通过漏洞奖励,我们使用货币奖励来激励社区反馈,形成标准化的过程。
许多公司已经推出了漏洞奖励计划,以检测和修复其系统中的漏洞。以下是一些例子:
- 在 2021 年,Twitter(现在称为X)宣布了首个 算法偏见悬赏挑战,以探索其图像裁剪算法中的潜在偏见。该算法使用了一种名为显著性图的 XAI 技术来确定用户上传图像中最吸引人的部分。

图像裁剪算法 由 Twitter(现在称为 X) 用于展示时间线上的图像。该算法后来被弃用| 作者图片
一些用户观察到基于机器学习的图像裁剪工具似乎偏向白人图像,并且不成比例地放大了女性的胸部和腿部等区域,这暗示了男性凝视的偏见。此外,当这些问题被指出时,用户没有机制来修改自动裁剪。挑战的目的是识别这样的算法可能带来的潜在危害。
## 你看到的就是你会得到的 — Twitter 在时间线中展示图像的新策略
总结 Twitter 关于其图像裁剪算法的论文
towardsdatascience.com
2. 在 2023 年 4 月 11 日,OpenAI 宣布了一个漏洞悬赏计划,邀请安全研究社区参与。发现问题的奖励从 $200(低严重性问题)到最高 $20,000(特殊发现)。
3. Meta 在其平台上有运行漏洞悬赏计划的历史。然而,当他们在 2023 年 2 月推出LLaMA 2——他们的开源大语言模型时,他们也发布了一个责任使用指南。该指南包括 报告漏洞和安全问题的选项。
结论
这篇文章强调了在负责任的人工智能开发中,治理、事件响应和专家验证的重要性。当我们深入探讨超越传统模型风险管理的策略,包括人工智能事件响应计划以及借鉴金融审计、数据隐私、软件开发和 IT 安全的做法时,显然,一个多方面的方法对于以负责任和安全的方式应对人工智能不断变化的挑战至关重要。从 Zillow 的经验中得到的教训提醒我们需要在人工智能中实施强有力的风险管理,这将引导我们在未来创建更可靠和更具伦理的人工智能系统。
参考资料
对应分析简要介绍
原文:
towardsdatascience.com/brief-introduction-to-correspondence-analysis-a88297ebba2a
学习如何在 R 中运行多重对应分析的基本步骤
·发布在Towards Data Science ·阅读时间 7 分钟·2023 年 1 月 16 日
--

图片由John Barkiple拍摄,来源于Unsplash
介绍
数据集由数字和/或文本组成。因此,我们应该预期,并非所有变量都仅涉及数字,数字变量有许多技术可以进行分析、测试和处理。
当我们处理数值变量时,有像相关性、PCA、缩放、归一化以及一系列测试等工具。另一方面,如果我们处理文本,更具体地说,是类别,我们应该寻找其他技术来应用于我们的数据分析。
其中一个工具是对应分析[CA]。
对应分析是一种统计技术,可以显示基于给定的列联表数据,两个变量内类别之间的关系。
正如定义所示,它是一种统计工具。概念上,它类似于主成分分析[PCA],但应用于分类数据,因为它使我们能够以 2D 图形展示数据集,显示哪些类别对应(或相关)于什么。
对于数据科学家来说,CA 可以在许多方面发挥作用,例如了解不同类型的客户如何购买一组产品、每个年龄段偏好的电影类型,或者本教程的示例:哪些产品在注册 1 和注册 2 中被购买。
导入和创建一些数据
我们将从导入必要的库和创建一些数据开始。
# Imports
library(tidyverse)
library(ggrepel)
library(sjPlot) #contigency tables
library(FactoMineR) #CA functions
library(ade4) # Create CA
创建数据。
# Dataset
df <- data.frame(
trans_id = 1:30,
register = as.factor(c('rgs1', 'rgs1', 'rgs1', 'rgs2', 'rgs1', 'rgs2',
'rgs1', 'rgs2', 'rgs1', 'rgs2', 'rgs1', 'rgs2',
'rgs1', 'rgs2', 'rgs1', 'rgs2', 'rgs1', 'rgs2',
'rgs1', 'rgs2', 'rgs1', 'rgs2', 'rgs1', 'rgs2',
'rgs1', 'rgs2', 'rgs1', 'rgs2', 'rgs1', 'rgs2')),
product1 = as.factor(c('banana', 'banana', 'pasta', 'milk', 'yogurt',
'milk', 'pasta', 'milk', 'pasta', 'milk', 'banana',
'milk', 'banana', 'banana', 'pasta', 'bread', 'bread',
'milk', 'yogurt', 'bread', 'banana', 'pasta', 'yogurt','milk',
'yogurt', 'bread', 'bread', 'pasta', 'milk', 'banana')),
product2 = as.factor(c('strawberries', 'strawberries', 'sauce', 'bread', 'water',
'bread', 'sauce', 'bread', 'sauce', 'bread', 'strawberries',
'bread', 'strawberries', 'bread', 'water', 'bread', 'water',
'bread', 'bread', 'yogurt','strawberries', 'sauce',
'strawberries', 'bread', 'strawberries', 'milk', 'bread',
'sauce', 'bread', 'strawberries'))
)
这是数据的一个样本。所以我们有注册号和每次交易的一对产品。

创建的数据集样本。图片由作者提供。
统计数据
执行 CA 的第一步是进行统计检验。由于我们正在处理多个变量对,因此我们必须为每对变量执行卡方检验,所有的检验结果都必须在至少一对变量上具有统计显著性。例如,product1 必须与 product2 或 register 之一通过检验。
要执行的测试是一个假设检验,其中:
Ho (p-value > 0.05) 意味着变量之间没有关联。
Ha (p-value ≤ 0.05) 意味着变量之间存在关联。
一种快速测试变量对的方法是使用 for 循环。
for (var1 in 2:4){
for (var2 in 4:2) {
contingency <- table(df[,var1], df[, var2])
chi2 <- chisq.test(contingency)
writeLines( paste("p-Value for",
colnames(df)[var1], "and", colnames(df)[var2],
chi2$p.value))
}
}
p-Value for register and product2 0.0271823155904414
p-Value for register and product1 0.0318997966416755
p-Value for register and register 3.2139733725587e-07
p-Value for product1 and product2 9.51614574849618e-06
p-Value for product1 and product1 5.49284039685425e-18
p-Value for product1 and register 0.0318997966416755
p-Value for product2 and product2 8.43312760405718e-20
p-Value for product2 and product1 9.51614574849618e-06
p-Value for product2 and register 0.0271823155904414
结果显示所有的卡方检验都低于 p-Value < 0.05 的阈值,因此我们可以拒绝原假设,支持替代假设,并理解变量之间存在统计学上显著的关联。
另一种选择是使用 sjPlot 库中的 stj.xtab() 函数。
# Register x product1
sjt.xtab(var.row = df$register,
var.col = df$product1,
show.exp = TRUE,
show.row.prc = TRUE,
show.col.prc = TRUE)
它展示了一个格式良好的表格,已经包含了观察值、用绿色表示的期望值,以及每个类别的百分比、p 值和卡方统计量。

来自 stj.xtab() 函数的结果。图片由作者提供。
多重对应分析
现在是创建我们的多重对应分析(MCA)的时间。我们可以使用 ade4 库中的 dudi.acm() 函数。scannf= FALSE 参数只是为了防止显示特征值条形图。
# Creating the Multiple Correspondence Analysis
ACM <- dudi.acm(df[,2:4],
scannf = FALSE)
一旦运行这个,输出将是 R 中的一个包含 12 个对象的 List。例如,如果我们运行 ACM$co,我们将看到为计算出的两个主成分的每个类别的坐标。这意味着 X 和 Y 坐标,或者每个点在二维图形上的位置。
ACM$co
Comp1 Comp2
register.rgs1 0.7660067 0.05610284
register.rgs2 -0.8754363 -0.06411753
product1.banana 0.8060812 0.99732829
product1.bread -0.6784873 -0.03550393
product1.milk -1.2068148 0.23776669
product1.pasta 0.6008691 -1.82914217
product1.yogurt 0.9497931 0.56723529
product2.bread -0.9315168 0.26905544
product2.milk -1.1707641 -0.10768356
product2.sauce 0.5351564 -1.96850658
product2.strawberries 1.0569306 1.00621404
product2.water 0.7961676 -0.40682587
product2.yogurt -1.1707641 -0.10768356
如果我们运行 ACM$cw,可以看到数据集中每个类别的百分比。
ACM$cw
register.rgs1 register.rgs2 product1.banana product1.bread product1.milk
0.17777778 0.15555556 0.07777778 0.05555556 0.08888889
product1.pasta product1.yogurt product2.bread product2.milk product2.sauce
0.06666667 0.04444444 0.13333333 0.01111111 0.05555556
product2.strawberries product2.water product2.yogurt
0.08888889 0.03333333 0.01111111
在 MCA 中,我们将能够提取 n = num_categories — n_variables 维度。在这个练习中,我们有 3 个变量(register1, register2,product1 和 product2)和 13 个类别(banana, bread, milk, yogurt, sauce, water, strawberries, pasta, sauce,一些类别在 product 1 和 2 中重复)。因此,13–3 = 10 维度。
因此,我们可以使用ACM$eig查看每个类别的 10 个 特征值。这些值表示每个类别所捕获的方差量,以一种简单的方式表示。
ACM$eig
[1] 0.77575767 0.64171051 0.54102510 0.44643851 0.33333333 0.25656245 0.15516469 0.10465009 0.05690406 0.02178693
# Variance from each dimension
perc_variance <- (ACM$eig / sum(ACM$eig)) * 100
[1] 23.272730 19.251315 16.230753 13.393155 10.000000 7.696873 4.654941 3.139503 1.707122 0.653608
创建感知图
最后一步是创建感知图,我们将在图形上看到类别的绘制。为此,我们必须创建一个基础数据框来保存类别名称及其相应的 X 和 Y 坐标。首先,让我们检查每个变量包含多少个类别。
# How many categories by variable
qty_categories <- apply( df[,2:4], 2, function(x) nlevels(as.factor(x)) )
register product1 product2
2 5 6
很好。现在我们将创建一个 data.frame 对象,获取坐标 ACM$co,该对象将加载类别名称及 X 和 Y 坐标,并且包含一个 Variable 列,列出变量的名称(product1 或 product2)以便标记。
# Create the df with coordinates
df_ACM <- data.frame(ACM$co,
Variable = rep(names(qty_categories),
qty_categories) )

X 和 Y 坐标。图片由作者提供。
从这里开始,现在只需要使用 ggplot2 创建图表即可。
我们将从 df_ACM 对象开始,提取行名并将其创建为一列(rownames_to_column()),然后将该列重命名为 Category。接下来,我们将名称如 product1.banana 变更为仅 banana。然后,我们将这个新数据框与 ggplot 函数进行管道操作,提供 x=Comp1 和 y=Comp2,标签为 Category,每个变量有不同的颜色。将创建一个散点图(geom_point),并使用 geom_label_repel 使名称不覆盖点。vline 和 hline 用于创建 0 的参考线。
# Plotting the perceptual map
df_ACM %>%
rownames_to_column() %>%
rename(Category = 1) %>%
mutate(Category = gsub("register.","", Category),
Category = gsub("product1.","", Category),
Category = gsub("product2.","", Category)) %>%
ggplot(aes(x = Comp1, y = Comp2, label = Category, color = Variable)) +
geom_point() +
geom_label_repel() +
geom_vline(aes(xintercept = 0), linetype = "longdash", color = "grey48") +
geom_hline(aes(yintercept = 0), linetype = "longdash", color = "grey48") +
labs(x = paste("Dimensão 1:", paste0(round(perc_variancia[1], 2), "%")),
y = paste("Dimensão 2:", paste0(round(perc_variancia[2], 2), "%"))) +
theme_bw()
这是结果。

MCA 的感知图。图片由作者提供。
结果给我们一些有趣的见解:
-
Register 1 收到更多的水果,如草莓和香蕉,一些水,少量意大利面和酱料。
-
Register 2 处理了更多的面包和牛奶或酸奶交易。
-
请注意,意大利面和酱料在两个收银台的距离较远。这是因为在 register1 上有 4 笔交易,而在 register2 上有 2 笔。其他组合要么由 rgs1 处理,要么由 rgs2 处理。
在你离开之前
MCA 是一个强大的工具。如果你处理的是分类数据,应该查阅并尽可能使用它来进行良好的分析。然而,请记住,随着变量和类别数量的增加,它的应用会变得更困难。
例如,想象一个包含 30 个变量、每个变量有 5 个类别的数据集!这需要进行大量测试和分析。在这种情况下,其他技术可能更好,或者对数据进行一些变换以减少类别数量,或创建一个子集进行分析。
如果你喜欢这些内容,请关注我的博客以获取更多信息。
在 Medium 上阅读 Gustavo Santos 的文章。数据科学家。我从数据中提取见解,以帮助个人和公司……
参考
在线性代数中,特征向量(Eigenvector)或特征向量是线性变换的一个非零向量……
对应分析(CA)是一种由赫尔曼·奥托·哈特利(Hirschfeld)提出的多变量统计技术……
en.wikipedia.org [## ade4 包 - RDocumentation
多变量数据分析工具。提供了几种方法用于分析(即排序)单表数据……
将正确性带回机器学习
原文:
towardsdatascience.com/bring-correctness-back-to-machine-learning-a56a96262f17
我们是否在错误的假设上构建我们的领域?
·发布于 Towards Data Science ·9 分钟阅读·2023 年 10 月 13 日
--

由 Andrea De Santis 提供的照片,来自 Unsplash
介绍:什么是正确性?
研究论文仍然是传达机器学习领域新发现的主要方式。然而,论文的结果不能被复现的情况相当频繁,而原因却不清楚。
在这里,我想提出我对研究论文作为沟通工具的利弊的看法。我将介绍我对科学及其在发展集体人类知识中的作用的看法。
最近的一篇论文发现了研究出版过程中的漏洞。我觉得这篇论文非常有说服力,我将向你展示其主要主张,以提高对代码正确性在传播机器学习知识中作用的认识。
这个问题不仅限于机器学习,但在这个领域,许多研究人员缺乏强大的工程技能——并且通常试图逃避工程工作——这导致了可靠、正确、可用、有效的软件的问题。
你是否从事机器学习工作,并阅读研究论文寻找新想法?那么,这篇文章将帮助你以原则性的方法对你阅读的内容进行更批判的分析。
你是研究人员和研究论文的作者吗?希望你会对这个话题感兴趣,阅读引用的论文,并参与讨论。
现在,足够的介绍,让我们深入讨论吧!
科学即知识
根据所有可以衡量的指标,科学作为职业比以往任何时候都更受欢迎。这些指标包括科学家(和博士)的数量、可用资金、资助申请等。所有科学家中增长的一部分是机器学习领域的研究人员,理论的和应用的都有。
科学家的工作是什么?就是在某一研究领域发现新知识,从而扩展人类已知的知识范围,通过扩展或推翻现有知识。
在许多情况下,科学家在现有知识的基础上进一步推进。有时,先前的科学证据可能被证明是错误的或具有误导性的。例如,选择的样本可能不能代表整个群体,因此结果可能不具有普遍性。
另一个原因可能是研究是在特定条件下进行的,而其结果被推广到不同的条件。例如,在机器学习中,一个方法可能在训练集人为缩小的情况下表现优于当前最先进的方法…但在正常的数据条件下,它的表现不如基线。
传播科学发现的主要工具,虽然被一些人认为过时,但毫无疑问还是经过同行评审的研究论文。通过研究论文,科学家们以结构化和有组织的方式描述他们的发现。他们描述发现的领域、他们解决的问题、知识领域中的漏洞、他们的假设以及旨在为假设提供证据或证伪的实验。
根据描述,其他科学家决定论文是否值得信赖以及发现是否值得发表,这一过程称为同行评审。需要注意的是,这是一个棘手且不完美的系统,同行评审常常未能发表值得的研究,或者相反,允许发表质量较差的论文。没有任何系统是完美的,这也是游戏的一部分。这个游戏是科学传播的主要过程,通常决定了研究资金的未来和研究人员的职业生涯。
现在,可以提出一些关于过程基本原理的问题:“我们如何相信论文中的声明?”,“一篇论文是否足以将一项知识视为已获得?”,“如果一篇新论文与之前发表的论文结果相矛盾怎么办?”
为了回答这些问题,我们需要介绍一些概念。
可重复性、可靠性、正确性
机器学习会议越来越重视可重复性的问题,他们这样做完全正确。其他科学领域,特别是心理学和医学,在经历了著名的“复制危机”之后,遭遇了严重的可信度问题:论文结果无法被不同的独立研究者复制,这对该领域的整个工作产生了大量质疑。然而,这个问题远远超出了这两个提到的领域。

图片来源:Julia Koblitz 在 Unsplash
虽然极为重要,但可重复性只是产生新科学知识过程的一部分。机器学习会议通常要求审稿人也评估实验的“严谨性”。
让我们尝试用更明确的术语来定义它们:
可重复性是指通过遵循论文中解释的协议,另一组能够重复一项科学研究的可能性。这意味着论文中描述了所有相关细节,也许所使用的软件已发布(我们在这里讨论的是机器学习),并且可以获取相同的训练数据。
如果独立研究者遵循所描述的方法论,但他们的结果与论文中描述的结果有显著差异,则该研究是不可重复的。
严谨性是对协议正确性的判断。实验是否与要证明的假设一致?结果是否确实显示了作者所声称的内容?实验是否存在偏差或不完整性?
实际上,严谨性分数反映了从科学角度看研究的技术正确性。
这两个方面在机器学习会议中广泛讨论,但还有第三个方面虽然讨论较少但仍然重要。
代码正确性是一个是/否问题:代码是否实际实现了论文中描述的方法论?
代码正确性通常被认为是理所当然的,因此不会被检查或强制执行。
建立正确的知识
可重复性对建立知识至关重要。让我再重复一遍:无法被独立复制的结果不代表科学知识。它们不一定是由于欺诈(尽管有时确实如此),但可能是因为一些被作者认为是“次要”的细节没有描述,而这些细节在所提议的方法中或比提议的方法更为重要,以获得声称的结果。
另一方面,正确性是建立正确知识的基础。如果我们阅读一篇论文并发现有趣的结果,然后尝试复制这些结果并获得完全相同的结果,我们可以对自己新获得的领域知识感到满意。
如果参考实现隐藏了一些“bug”,并且在某些基本方面没有遵循所描述的思想,会怎么样呢?

图片由 Dmitry Bukhantsov 提供,来源于 Unsplash
在这种情况下,我们是在错误的事实基础上建立知识。积累许多错误的事实,我们就不知道什么是重要的,什么是不重要的。这对我们这些想要在某个领域成为“专家”的人来说,确实很可怕,不是吗?
这个问题在论文 When Good and Reproducible Results are a Giant with Feet of Clay: The Importance of Software Quality in NLP 中由 Papi 和同事们 [1] 讨论,其中以一些广泛使用的 Conformer 实现 [2] 为例。这些实现通常在常见的开源框架中加速了许多组的研究,这些组不需要自己实现 Conformer 和训练方法。问题是,所有被检验的实现中都包含三类 bugs,当推断批量大小大于一时,这些 bugs 会影响结果。注意,在正常情况下,推断批量大小应影响推断过程中的资源使用,从而影响推断速度,但不应影响输出。
关于这些 bug 的详细信息,我强烈推荐阅读这篇论文。在这里,我只是想强调,差异通常足够小,以至于不会注意到实现中存在问题,这也可能是它们被忽视了这么久的原因。然而,在一种情况下,当批量大小非常大时,降级是巨大的。
有些人可能会疑惑为什么我们要关心 BLEU 或 WER 的小数点后几位,但我认为这种观点忽视了问题的关键。本文仅以流行的 Conformer 实现为例,这些实现被全球数百或数千名研究人员和实践者使用。然而,它发现了所有这些实现中的 bug,其中一个在保持高效实现的同时并不容易解决。
记住,“bug”只是软件行业对软件错误的友好称呼。一个可爱的“bug”总是软件需求与实际行为之间的差异。
想象一下,每周提出的数百个深度学习网络的所有私人实现会发生什么。在许多情况下,这些网络由小组或单个开发者开发,没有人审查他们的代码。实现与论文中的描述是否匹配,留给作者自己判断,也不作为科学同行评审讨论的重点。
那么问题是:在这些假设下,我们如何可靠地信任我们阅读的论文中的声明?
不幸的是,这个问题没有明确的答案。[1]的作者提出了一些受软件开发如何解决正确性问题启发的解决方案。我认为这朝着正确的方向发展,但不幸的是,太多人过于看重感知到的额外工作(而不是错误软件的成本),我不认为这种方法会很快被广泛采纳。
结论
良好的科学话题无疑是一个艰难的领域。研究工作涉及许多不同的方面,研究人员最不希望的就是为了同样数量的出版物而增加更多的工作。
然而,当我们依赖不正确的软件时,我们正在构建错误的知识,并可能基于这些错误的知识走上错误的方向。试想一下,如果你使用一个你认为可靠的软件工具工作了一年或两年,然后有人发现了一个 bug,当它被修复时,你的新模型的性能突然下降。我认为没有人愿意陷入这种情况。我们从最佳工程实践中知道,错误应该在早期发现,以减少其成本。
像论文中提出的那种机器学习测试工具可以帮助研究人员以简单的方式生成更正确的软件,我真的很期待它的发布。如果是开源的,整个社区都可以贡献力量,整体提升编码标准。
你怎么看?你有什么解决方案来确保 ML 研究的正确性吗?在评论中告诉我!
如果你读到这里,非常感谢你的时间!我知道你的时间有限且宝贵,但你仍然决定花时间阅读我的想法,非常感谢!
更多我的文章
## 阅读和撰写 ML 研究论文的技巧 ## Tips for Reading and Writing an ML Research Paper
从数十次同行评审中获得的经验教训
## 无需多言:自动化开发环境和构建 ## Without Further Ado: Automate Dev Environments and Build
通过环境和构建自动化,使您的软件易于使用,从而给您的同事开发者带来快乐。通过…
## 3 个常见的 bug 来源及如何避免 ## 3 Common Bug Sources and How to Avoid Them
一些编码模式更容易隐藏 bug。编写高质量代码并了解我们的大脑如何工作可以帮助…
通向数据科学 [## 语音增强介绍:第一部分 — 概念和任务定义
介绍改善降级语音质量的概念、方法和算法…
Medium 会员
你喜欢我的写作吗?你是否考虑订阅 Medium 会员以无限访问文章?
如果你通过这个链接订阅,你将通过你的订阅支持我,而对你来说没有额外费用 medium.com/@mattiadigangi/membership
参考文献
[1] Papi, S 等。“当良好且可重复的结果成为泥足巨人:软件质量在自然语言处理中的重要性” arxiv.org/2303.16166
[2] Gulati, Anmol 等。“Conformer:用于语音识别的卷积增强变换器” Proc. Interspeech 2020(2020 年):5036–5040。
冒泡排序解释——数据科学家的算法指南
原文:
towardsdatascience.com/bubble-sort-explained-a-data-scientists-algorithm-guide-853b2286e5ab
冒泡排序的直观解释及其在 Python 中的实现
·发表于 Towards Data Science ·阅读时间 10 分钟·2023 年 1 月 5 日
--
作为软件工程师和数据科学家,我们经常理所当然地看待排序函数。这些算法可能不是我们工作中最光鲜亮丽或讨论最多的方面,但它们在我们每天使用的技术中发挥着至关重要的作用。例如,想象一下在没有按字母顺序排序的功能时,如何整理手机上的联系人列表,或者如何在电子商务网站上按价格和类别排序产品。很容易忽视排序算法的重要性,但它们对我们作为程序员的工作至关重要。
尽管大多数编程语言,如 Java、Python、C# 等,内置了常见排序算法的函数,但我们仍然需要对这些算法的工作原理有基本的了解。这些知识使我们能够根据算法的空间和时间复杂度做出明智的决策,特别是当数据科学家处理大型数据集时。因此,不要小看那些不起眼的排序函数——它们可能不是焦点,但却是科技行业的无名英雄。
在这篇文章中,我们将深入探讨冒泡排序算法,考察其在 Python 和 JavaScript 中的实现。我们还将更详细地了解算法背后的直觉,并讨论时间和空间复杂度的考虑因素。阅读完这篇文章,你将对何时在程序中使用冒泡排序算法有一个扎实的理解,并了解其空间和时间复杂度的概述。
在免费的电子书 硬件 > 软件 > 过程 中,详细了解硬件创新如何改变数据团队构建分析和机器学习应用的方式。
冒泡排序直观理解
在试图理解并回忆一个算法时,首先掌握概念总是更有帮助。在跳入实现之前,通过熟悉这个概念,你将更好地保留信息以备将来使用。冒泡排序也不例外。
为了使用冒泡排序算法将数组[2, 3, 4, 5, 1]按升序排序,我们从第一个元素[2]开始,并将其与第二个元素[3]进行比较。如果第一个元素大于第二个元素,我们交换它们。我们继续比较元素对,直到到达数组的末尾。这样,最大的元素将被移到数组的末尾,最小的元素将被移到数组的开头。
“冒泡排序”这个名字指的是较大的元素在被反复比较和交换较小元素时,“冒泡”到数组的顶部或末端。排序过程结束时,数组将完全按升序排序。
冒泡排序算法逐步解释
以下是我们将使用冒泡排序进行组织的无序数字列表:

作者提供的图片
第一步是只关注前两个数字,在这个例子中是 5 和 9。你可以将只考虑这两个元素 5 和 9 的过程可视化,如下图所示:

作者提供的图片
然后,你必须确定气泡中的数字是否按顺序排列。如果它们的顺序不正确,则需要进行交换以使其正确。幸运的是,它们已经按升序排列。5 小于 9,所以它在 9 之前。这意味着我们无需做更多的操作——我们将气泡再移动一步,如下所示:

作者提供的图片
在数组的下一次迭代中,我们进行相同的步骤。然而,这一次 9 大于 1,但它也在 1 的前面。因此,为了纠正这个问题,我们交换了这两个元素的位置。现在列表的样子如下:

作者提供的图片
现在元素已经交换,冒泡排序将继续处理后续的元素对。这个过程会重复进行,直到数组中的最后一对元素也被检查和交换。第一次遍历数组的结果如下:

作者提供的图片
冒泡排序算法是一种简单但有效的排序方法。它通过反复遍历数组,比较元素对的顺序,如果顺序错误则交换位置。这一过程会重复进行,直到整个数组被排序完成。
需要记住的一点是,对数组进行排序所需的遍数等于数组中的元素个数。例如,一个 6 个元素的数组需要进行 6 次遍历才能完全按升序排序。
然而,通过限制对数组进行的操作次数或遍历次数,可以使冒泡排序算法更高效。这是因为数组的最后一个元素始终是最大值,因此在未来的遍历中不需要继续比较这个位置之后的所有元素。我们将在下面的 Python 和 JavaScript 实现中看到这种优化的实际效果。
用 Python 实现的冒泡排序算法
本节使用 Python 编程语言实现冒泡排序算法。我们将观察到一种朴素的实现方式和一种更高效的冒泡排序算法。
初始化一个包含整数元素的 Python 数组
unsortedData = [20, 33, 12, 53, 24, 65, 23, 4, 53, 1];
定义一个名为‘bubbleSort’的函数,该函数接受一个名为‘data’的数组作为参数。首先,我们尝试遍历数组,并交换满足条件的元素,即在特定索引处的左边元素大于右边元素时,我们对这两个元素执行交换操作。
需要注意的一点是,在任何迭代中,将左边元素分配给临时变量‘tempValue’,然后将右边元素分配给临时变量。
def bubbleSort(data):
for i in range(0, len(data)):
if i+1 < len(data):
if data[i] > data[i+1]:
tempValue = data[i]
data[i] = data[i+1]
data[i+1] = tempValue
return data
上述代码片段在用未排序的数组作为参数调用时,将对数组进行一次冒泡排序函数的遍历。在大多数情况下,可能不会完全将数组排序为升序。
sortedData = bubbleSort(unsortedData)
print(sortedData)
>>>[20, 12, 33, 24, 53, 23, 4, 53, 1, 65]
为了解决这个问题,我们必须按照配对组合的次数来遍历我们想要排序的数组。简单来说,进行的迭代次数是未排序数组长度的平方(len(unsortedArrray)²)。这是一种朴素的实现方式。
def bubbleSort(data):
# Iterate through the array enough times to consider every possible swap pairs
for _ in range(0, len(data)):
for i in range(0, len(data)):
if i+1 < len(data):
if data[i] > data[i+1]:
tempValue = data[i]
data[i] = data[i+1]
data[i+1] = tempValue
return data
再次运行冒泡排序函数,并将未排序的数组作为参数传递,将得到一个升序排列的数组作为输出。
sortedData = bubbleSort(unsortedData)
print(sortedData)
>>> [1, 4, 12, 20, 23, 24, 33, 53, 53, 65]
优化版冒泡排序
虽然朴素版的冒泡排序算法有效,但存在一些不必要和冗余的操作。特别是,它比较了数组末尾已经是最大值的元素。这是因为在每次遍历数组时,冒泡排序算法将最大元素值移动到数组末尾。
为了优化冒泡排序算法,我们可以通过跟踪我们想要比较的数组部分来减少所需的交换操作次数。我们可以通过从数组的最大长度开始,并在每次遍历后将其减少 1,从而减少交换操作作用的数组区域。这样,我们可以避免在每次遍历时与数组末尾的元素进行比较,因为这些元素已经在正确的位置。
通过使用这种优化,我们可以使冒泡排序算法更高效,并减少其执行的多余操作次数。
unsortedData = [20, 33, 12, 53, 24, 65, 23, 4, 53, 1];
end = len(unsortedData)
def bubbleSort(data):
global end
for _ in range(0, end):
for i in range(0, end):
if i+1 < end:
if data[i] > data[i+1]:
tempValue = data[i]
data[i] = data[i+1]
data[i+1] = tempValue
end = end - 1
return data
sortedData = bubbleSort(unsortedData)
print(sortedData)
>>> [1, 4, 12, 20, 23, 24, 33, 53, 53, 65]
可以进一步重构以确保上述代码可读且高效。此外,正如 Robert Kübler 博士 在评论中指出的,通过检查是否发生了交换操作,可以进一步优化该算法。如果没有发生交换操作,我们可以中断循环以避免不必要的数组遍历。
unsortedData = [20, 33, 12, 53, 24, 65, 23, 4, 53, 1]
n = len(unsortedData)
def bubbleSort(data):
for i in range(n):
swapped = False
for j in range(0, n-i-1):
if data[j] > data[j+1]:
data[j], data[j+1] = data[j+1], data[j]
swapped = True
if not swapped:
break
return data
sortedData = bubbleSort(unsortedData)
print(sortedData)
以下是用 JavaScript 实现的相同算法,JavaScript 是一种受到数据从业者和软件工程师欢迎的编程语言。
用 JavaScript 实现的冒泡排序算法
const unsortedData = [20, 33, 12, 53, 24, 65, 23, 4, 53, 1];
let end = unsortedData.length - 1
const bubbleSort = (data) => {
for (let i = 0; i < end; i++) {
if (data[i] > data[i + 1]) {
const valueInRight = data[i]
data[i] = data[i+1]
data[i+1] = valueInRight
}
}
end--
}
for (let i = 0; i < unsortedData.length; i++) {
bubbleSort(unsortedData)
}
console.log(unsortedData)
时间和空间复杂度(大 O 记号)
数据科学家必须了解排序算法的性能及其所需的时间/空间。这使你能够根据具体情况选择最佳的排序算法,因为有很多选项可供选择。
当冒泡排序用于已经按升序排列的数组时,只需遍历整个数组一次。这被视为最佳情况。然而,实际上,这种情况只偶尔发生,冒泡排序通常需要 n(n-1)/2 次交换或比较才能实现排序数组。
冒泡排序算法的平均/最坏时间复杂度为 O(n²),因为我们需要通过数组多次,与提供的数组中的对数一样。因此,当时间是一个因素时,可能会有更好的选择。
-
时间复杂度 最坏情况:O(n²)
-
时间复杂度 平均情况:O(n²)
-
时间复杂度 最佳情况:O(n),数组已经排序
在空间复杂度方面,由于我们只是交换了元素而没有存储任何东西,我们不需要额外的空间来运行算法。这非常了不起,因为这意味着空间复杂度为常数,即 O(1)。这使得它成为一个 原地 算法,通过直接修改输入来工作。
关键要点
冒泡排序算法可能不是最知名或评价最高的排序算法,但正如我们所见,它也不是一个糟糕的选择。其时间复杂度为 O(n²),空间复杂度为 O(1),这是一个简单的算法,容易为初学者理解。然而,它的慢速度可能使得它在某些应用中不够实际。
尽管有其局限性,冒泡排序算法可以作为学习排序算法和数据结构的一个有用起点。这是一个很好地理解这些算法如何工作的基础方式,并且可以帮助你为以后学习更复杂的算法奠定基础。
话虽如此,冒泡排序算法可能不是处理时间敏感材料的最佳选择,因为其速度较慢可能会成为限制。然而,如果你愿意为了时间牺牲一些空间,它可能会对你有用。最终,排序算法的选择将取决于你的具体需求和目标。通过了解冒泡排序算法,你可以做出更明智的决策,选择最适合你需求的算法。
常见问题解答
- 冒泡排序是什么?
冒泡排序是一种排序算法,它使用比较方法来排序数组。该算法比较数组中一对对的元素,并在左侧对(位置)大于右侧对(位置+1)时交换它们。这个过程重复进行,直到整个数组排序完成。
- 冒泡排序需要多少次遍历?
冒泡排序需要 n(n-1)/2 次遍历所有元素,以便最终数组按升序排列。
- 冒泡排序的最坏时间复杂度是多少?
冒泡排序的最坏时间复杂度是 O(n2)。
- 冒泡排序的最坏时间复杂度是多少?
冒泡排序的最佳时间复杂度是 O(n),当数组已经排序时会出现这种情况。
- 冒泡排序的空间复杂度是多少?
冒泡排序具有 O(1) 的空间复杂度,因为它通过直接修改输入来原地操作。
我希望你觉得这篇文章有用。
要与我联系或找到更多类似于这篇文章的内容,请执行以下操作:
-
支持我的写作 成为推荐的 Medium 会员
-
订阅我的 YouTube 频道
-
订阅我的播客 Apple Podcasts | Spotify | Audible
-
订阅我的 邮件列表 获取我的新闻通讯
接下来读什么 👇🏾
一种合并排序算法的解释及其在 Python 中的实现
towardsdatascience.com
Python 中的错误?pdb来拯救!
原文:
towardsdatascience.com/bugs-in-python-pdb-to-the-rescue-d88a56a2ca71
PYTHON 编程
pdb调试器值得学习和使用吗?
·发布于Towards Data Science ·13 分钟阅读·2023 年 9 月 21 日
--

调试有助于你从失败中学习。照片由Brett Jordan拍摄,Unsplash提供
各种工具可以用于调试 Python 代码,从最简单的print()函数,通过静态但更高级的[icecream](https://github.com/gruns/icecream)及其兄弟[ycecream](https://github.com/salabim/ycecream),到各种 IDE 提供的交互式调试器。然而,我一直选择的是内置的pdb调试器,以及内置的breakpoint()函数。
调试
调试是编程的核心。你在开始学习编程时就开始调试,当你承诺你刚刚写完最后一行代码时你才会停止调试——如果你能遵守这个承诺的话。
你可能会认为,减少调试时间的一种方法是编写优质代码。让我们面对现实吧:往往,编写优质代码意味着……在开发过程中调试很多。诚然,一个好的程序员会编写更好的代码并犯更少的错误——但这并不意味着他或她不需要调试。
不过,有一种方法可以减少调试的次数:减少调试的方法就是编写良好的单元测试。
为了减少调试的次数,编写良好的单元测试。
无论你是否使用测试驱动开发,都要编写良好的测试。编写良好的测试意味着编写足够数量的精心编写的测试。我在这里不打算讨论测试,所以我给你留个想法;我在这里写了更多关于测试的内容:
大多数开发者不喜欢编写测试。如果你也是其中之一,尽力改变这一点。
我们可以假设所有程序员都需要调试他们的代码。有些人可能会说他们不需要,但这不是真的。他们确实需要;只不过他们不使用专门的调试工具,即调试器。相反,他们会运行代码来处理特定的输入,然后检查结果,然后发现有问题后修改代码并重复这个过程。因此,尽管他们不使用调试器,他们依然在调试代码;只是需要花费更多时间。调试器是有其存在意义的!
有时,单独调用 print() 函数就能解决问题。但不要自欺欺人:这不是一种非常有效的调试方法。我不是说你不应该使用它——但这是一个过于简单的方法,只能在最简单的情况下起作用。
许多使用 IDE 进行代码开发的人喜欢使用内置的调试器。Visual Studio Code 有自己的调试器,Pycharm 有一个,甚至 Thonny 也有一个。
你还可以使用作为 Python 包提供的各种调试器,这些包可以从 PyPi 安装。打开 PyPi 并搜索“debugger”一词;你会找到很多结果,但你可能需要花费一些耐心来找到能帮助你调试代码的工具。
你可以在下面的 Towards Data Science 文章中阅读关于 Python 调试器的内容:
更快、更高效地调试你的代码。
towardsdatascience.com
文章讨论了——尽管没有展示如何使用——pdb、PyCharm 和 Visual Studio(以及 VS Code)的调试器、Komodo 和 Jupyter Visual Debugger。
静态调试器与交互式调试器
调试器可以是静态的也可以是交互式的。前者只是展示对象;后者则允许你操作它们。
两者都可能有帮助,但交互式调试器提供了最大的调试能力,因为它们能够暂停程序并查看当前状态。你可以查看和使用本地和全局作用域中的所有对象;你可以检查特定命令或命令集是否有效。这就是为什么我常常偏好交互式调试而不是静态调试。
print() 函数是静态调试的一个完美示例。IDE 调试器通常是交互式的。
然而,有一个调试器同时提供了简单性和强大功能。它就是 pdb,一个内置的交互式 Python 调试器:
该模块定义了一个用于 Python 程序的交互式源代码调试器。它支持设置(条件)…
是的,pdb是内置的,所以你不需要安装它。它随 Python 安装包一起提供,你可以在任何环境中使用它。而且,pdb是交互式的。这实际上是我对调试器的主要期望!
是的,
pdb是内置的,所以你不需要安装它。它随 Python 安装包一起提供,你可以在任何环境中使用它。而且,pdb是交互式的。
在本文中,我们将讨论pdb的基础知识。我们将介绍这个强大工具的基础知识,但要注意,它提供的功能远不止这些。好在这些基础知识足以让你开始使用pdb。说实话,我很少使用pdb的高级选项。因此,阅读这篇文章将为你提供调试 Python 代码的强大工具。
关于 pdb 的几点说明
pdb的一个优点是你可以在任何地方使用它,而无需安装任何额外的东西。即使是远程环境——pdb也能正常工作。只需运行它,瞧,你就有了一个可以远程使用的交互式调试器。或者在本地使用也没问题。
首先,让我解释如何使用pdb,然后你可以决定它是否适合你。
基本上,你可以在两种模式下使用pdb。首先,你可以在pdb模式下运行你的 Python 程序。这意味着程序会逐行执行,直到完成执行或发生错误。然后程序会在死后模式下重新运行,这意味着它会在错误发生之前停下来,你将能够查看局部和全局作用域中的情况。
其次,你可以在代码中添加一个所谓的断点,调试器将会在断点处停止程序。你还可以添加更多的断点。当然,调试器只有在断点之前没有抛出错误的情况下才能停止程序。下面,我们将讨论这两种情况。
pdb模式
要在pdb模式下运行你的程序,只需按照以下方式运行它:
$ python -m pdb myapp.py
这意味着pdb控制台将打开,myapp.py脚本将逐行运行。你可以更改这种行为,将其运行到第一个错误或程序结束。最好通过一些示例来展示这如何工作。
我们将使用以下脚本,保存为myapp.py:
def foo(s):
if not isinstance(s, str):
raise TypeError
return s.upper()
if __name__ == "__main__":
for s in ("string1", "string2"):
_ = foo(s)
(这是一个示例脚本,没有什么值得自豪的。我们确实需要简单的案例来进行分析。)
我们还将使用其错误版本,其中 Python 将抛出一个错误;该脚本保存在myapp_error.py文件中:
def foo(s):
if not isinstance(s, str):
raise TypeError
return s.upper()
if __name__ == "__main__":
for s in ("string1", 10):
_ = foo(s)
如你所见,正确的程序将运行一个 for 循环,在每次循环中,它将对不同的 s 参数值运行 foo() 函数:首先是 "string1" 然后是 "string2",这两个值都是正确的。在错误的版本中,foo("string2") 应该被 foo() 替换为不正确的值 10,这将导致 TypeError 被抛出。
目前,你需要知道的唯一 pdb 命令是
-
c,或continue;命令的另一个版本是cont; -
n,或next;和 -
q,或quit。
有时你需要使用quit两到三次,甚至更多次,才能退出调试器。
continue 命令会执行程序直到以下两种情况之一发生:程序结束或抛出错误。为了查看这如何工作,我们来运行我们脚本的正确版本 myapp.py:
$ python -m pdb myapp.py
> /{path}/myapp.py(1)<module>()
-> def foo(s: str):
(Pdb) c
The program finished and will be restarted
> /{path}/myapp.py(1)<module>()
-> def foo(s: str):
(Pdb)
(在代码块中,{path} 代表从我的计算机上的长路径。)
如你所见,在运行了 shell 命令 python -m pdb myapp.py 之后,我们进入了一个新的 pdb 会话,调试器正在等待我们的第一个命令。如上所示,c 命令将继续程序运行直到第一个错误或程序结束。由于我们运行的是正确的脚本,调试器没有遇到任何问题,并且打印了程序已完成,将重新启动。这将我们移回到程序的第一行,调试器再次等待我们的命令。现在我们可以逐行调试(如下所示)。
让我们看看如果我们对错误的脚本使用 c 命令会发生什么:
$ python -m pdb myapp_error.py
> /{path}/myapp_error.py(1)<module>()
-> def foo(s: str):
(Pdb) c
Traceback (most recent call last):
File "/usr/lib/python3.9/pdb.py", line 1726, in main
pdb._runscript(mainpyfile)
File "/usr/lib/python3.9/pdb.py", line 1586, in _runscript
self.run(statement)
File "/usr/lib/python3.9/bdb.py", line 580, in run
exec(cmd, globals, locals)
File "<string>", line 1, in <module>
File "/{path}/myapp_error.py", line 1, in <module>
def foo(s: str):
File "/{path}/myapp_error.py", line 3, in foo
raise TypeError
TypeError
Uncaught exception. Entering post mortem debugging
Running 'cont' or 'step' will restart the program
> /{path}/myapp_error.py(3)foo()
-> raise TypeError
(Pdb)
如你所见,这次程序引发了一个错误(TypeError,没有消息)。当抛出未捕获的错误时,程序会停止执行,调试器进入所谓的事后调试阶段。这时你可以了解你的程序发生了什么以及为何失败。
按 n,pdb 将执行代码的下一行。不是下一条命令,而是下一行,因此如果下一条命令被拆分成两行或更多行,你需要调用每一行,最终执行命令。请注意这个 pdb 会话:
$ python -m pdb myapp.py
> /{path}/myapp.py(1)<module>()
-> def foo(s: str):
(Pdb) n
> /{path}/myapp.py(1)<module>()
-> if __name__ == "__main__":
(Pdb)
> /{path}/myapp.py(1)<module>()
-> for s in ("string1", 10, "string2"):
(Pdb)
> /{path}/myapp.py(1)<module>()
-> _ = foo(s)
(Pdb)
> /{path}/myapp.py(1)<module>()
-> for s in ("string1", 10, "string2"):
(Pdb)
> /{path}/myapp.py(1)<module>()
-> _ = foo(s)
(Pdb)
TypeError
> /{path}/myapp.py(1)<module>()
-> _ = foo(s)
(Pdb)
首先,请注意当你使用一个命令(在这里是 n)时,你不需要重复它来运行。pdb 会记住你最后的命令,按回车键将再次执行它。在按了几次之后,它把我们带到了停止程序的错误。
请注意,在 pdb 模式下,Tab 自动补全的行为并不完全正常。这并不意味着它完全无效;你只需要在输入其他内容之前使用 p 命令。例如,在这种情况下按下 Tab 键:
(Pdb) al
将不会有任何结果。但在这里按:
(Pdb) p al
将会完成 alpha 名称:
(Pdb) p alpha
有很多 pdb 命令可供你使用。你可以在这里找到它们:
源代码:Lib/pdb.py 模块 pdb 定义了一个用于 Python 程序的交互式源代码调试器。它支持…
在继续之前,我想与您分享一个简单的命令;也许它不是最重要的,但在我过去的经验中我非常欣赏它。它是pp,用于漂亮打印:
(Pdb) {f"{x_i = }, {alpha = }, and {beta = }": (x_i + alpha)/(1 + beta) for x_i in x}
{'x_i = 1, alpha = 4, and beta = 0': 5.0, 'x_i = 2, alpha = 4, and beta = 0': 6.0, 'x_i = 3, alpha = 4, and beta = 0': 7.0}
(Pdb) pp {f"{x_i = }, {alpha = }, and {beta = }": (x_i + alpha)/(1 + beta) for x_i in x}
{'x_i = 1, alpha = 4, and beta = 0': 5.0,
'x_i = 2, alpha = 4, and beta = 0': 6.0,
'x_i = 3, alpha = 4, and beta = 0': 7.0}
正如您所见,调用一个表达式和使用 pp 命令调用它之间差别很大。因此,记住它是好的。
还有一件事。即使上面的字典推导很长,我也没有将其拆分成两行或更多行。这是因为 pdb 不允许这样做,至少在其调试模式下是如此——但您可以使用其交互模式,您可以通过 interact 命令运行它:
(Pdb) interact
>>> {f"{x_i = }, {alpha = }, and {beta = }":
... (x_i + alpha)/(1 + beta) for x_i in x}
{'x_i = 1, alpha = 4, and beta = 0': 5.0, 'x_i = 2, alpha = 4, and beta = 0': 6.0, 'x_i = 3, alpha = 4, and beta = 0': 7.0}
记住,在交互模式下,pdb 命令不起作用。要离开此模式并返回 pdb 模式,请按 <Ctrl + D>。
使用 breakpoint() 函数进行调试
上面我们讨论了在 pdb 模式下调试。然而,通常情况下,设置一个所谓的断点会更容易。断点是代码中的一个位置,在这个位置您希望程序暂停并进行分析;您可以在代码中创建多个断点,代码会在每个断点处停止——除非抛出错误。
要创建一个,请在您希望调试器停止并让您进入的代码位置添加对 breakpoint() 函数的调用:
def y(x, alpha, beta):
breakpoint()
return [(xi + alpha)/(1 + beta) for xi in x]
x = [1, 2, 3]
y(x)
运行此脚本将引导您进入此调试会话:
-> return [(xi + alpha)/(1 + beta) for xi in x]
(Pdb) l
1 def y(x, alpha, beta):
2 breakpoint()
3 -> return [(xi + alpha)/(1 + beta) for xi in x]
4
5
6 x = [1, 2, 3]
7 y(x, 4, 0)
[EOF]
(Pdb)
l(list)命令显示您当前所在位置周围的十一行。您还可以使用 ll(longlist),它将打印当前函数或帧的整个源代码。
其余部分与之前相同,因为您已经进入了我们上面讨论的 pdb 模式。使用 breakpoint() 函数的明显优势是可以精确地在您希望的地方停止程序。坦白说,我在几乎所有的调试会话中都使用 breakpoint()。

代码中的断点让您暂停片刻,检查您希望检查的代码位置的内部情况。照片由 Malte Helmhold 在 Unsplash 上拍摄
对象丢失了?
您可能会遇到一种奇怪的情况——虽然它只对那些不知道如何处理的人来说很奇怪。有时,您可能会发现 pdb 的行为非常特殊:虽然它可以看到局部变量,但它……看不到这些局部变量。
听起来像完全的废话?让我解释一下。考虑这个非常简单的函数:
def y(x, alpha, beta):
return [(xi + alpha)/(1 + beta) for xi in x]
它计算一个简单模型的值,针对一个值列表 x,给定两个模型参数 alpha 和 beta。例如:
>>> def y(x, alpha, beta):
... return [(xi + alpha)/(1 + beta) for xi in x]
...
>>> x = [1, 2, 3]
>>> y(x, .25, 0)
[1.25, 2.25, 3.25]
现在想象一下,您希望进入函数并检查多个 x 列表的函数。您可以通过 pdb 的帮助来做到这一点:
>>> def y(x, alpha, beta):
... breakpoint()
... return [(xi + alpha)/(1 + beta) for xi in x]
...
>>> y(x, .25, 0)
> <stdin>(3)y()
(Pdb) alpha, beta
(0.25, 0)
(Pdb) [(xi + alpha)/(1 + beta) for xi in x]
*** NameError: name 'alpha' is not defined
什么?刚刚发生了什么?为什么pdb看不到alpha——它不是刚刚看到的吗?确实,在这一行:
(Pdb) alpha, beta
(0.25, 0)
所以,它能看到alpha和beta——但它看不到它们?
也许我们应该再次给这些变量赋值?让我们检查一下:
(Pdb) alpha = .25; beta = 0
(Pdb) alpha
0.25
(Pdb) [(xi + alpha)/(1 + beta) for xi in x]
*** NameError: name 'alpha' is not defined
不,这根本没有帮助。
问题是,列表推导式——以及其他推导式——有自己的作用域,局部变量在那里是不可见的。幸运的是,你有很多解决方案,如下所示。
交互模式
交互模式实际上在各种情况下都非常有用。你可以使用pdb shell 中的interact命令来启动它:
(Pdb) interact
*interactive*
>>> [(xi + alpha)/(1 + beta) for xi in x]
[1.25, 2.25, 3.25]
如你所见,在交互模式下,代码的运行方式是正常的。
将缺失的对象添加到 globals
缺少一个特定的对象,所以只需将其添加到globals()中:
(Pdb) globals()['alpha'] = alpha
(Pdb) [(xi + alpha)/(1 + beta) for xi in x]
*** NameError: name 'beta' is not defined
如你所见,pdb可以看到alpha但看不到beta。一种解决方案是将其添加到globals()中,就像我们添加alpha一样,但逐个提供所有全局变量并不好玩;下一个解决方案只需一条命令即可完成。
将所有局部变量添加到 globals
locals()和globals()都是字典,因此我们可以简单地将前者添加到后者中。你可以按照以下方式进行:
(Pdb) globals().update(locals())
(Pdb) [(xi + alpha)/(1 + beta) for xi in x]
[1.25, 2.25, 3.25]
希望你喜欢这篇文章。虽然文章没有涵盖pdb的所有知识,但它提供了足够的知识来在大多数情况下使用这个调试工具。
在我超过 5 年的 Python 实践中,我注意到很少有人使用pdb来调试代码。我不知道为什么。IDE 调试工具确实能提供更多,但pdb的强大之处在于它在 Python 标准库中的可用性。
我不确定这是否值得自豪,但我会对你诚实:pdb是我选择的调试工具。我几乎不使用其他调试工具。我从未遇到过任何问题;相反,它在我所有的 Python 项目中都提供了帮助。
当我在尝试其他调试工具时,确实遇到了各种问题。也许是我自己的问题;也许是我没有足够长时间地使用它们以体验它们的强大。这可能是真的——但我可以说我已经足够长时间地使用pdb,尽管它很简单,但它可以是一个很棒的调试工具。
感谢阅读。如果你喜欢这篇文章,你可能也会喜欢我写的其他文章;你可以在这里看到它们。如果你想加入 Medium,请使用下面的推荐链接:
链接 [## 使用我的推荐链接加入 Medium - Marcin Kozak
作为 Medium 的会员,你的部分会员费会分给你阅读的作者,同时你可以完全访问每一个故事……
使用 PostgreSQL、FastAPI 和 Docker 构建一个后端
原文:
towardsdatascience.com/build-a-back-end-with-postgresql-fastapi-and-docker-7ebfe59e4f06
开发基于地图的应用程序的逐步指南(第四部分)
·发表于 Towards Data Science ·阅读时间 28 分钟·2023 年 3 月 14 日
--

图片来自 Caspar Camille Rubin 于 Unsplash
地图是可视化和理解地理数据的强大工具,但需要特定的技能才能高效设计。
在这个逐步指南中,我们将深入探讨如何构建一个基于地图的应用程序,以展示客户周围加油站的价格。我们将涵盖产品的不同关键步骤,从最初的概念验证(POC)到最小可行产品(MVP)。
系列文章:
第一部分:概念验证——构建一个简约的演示
第二部分:如何使用 React 构建网页应用(静态布局)
第三部分:使用 React 为你的网页应用添加互动性
第四部分:使用 PostgreSQL、FastAPI 和 Docker 构建后端
关于这篇文章的一些背景信息
在系列文章的前几部分,我们使用 React 构建了加油站查找器的前端,并将后端视为一个仅提供相关数据的“黑箱”。
在这一部分,我们将详细介绍如何使用强大的工具,如 PostgreSQL 或 FastAPI,一步一步构建后端。
你可以在我的 Github 页面 找到该项目的完整代码。
为什么我们需要一个干净的后端?
在本系列的第一部分中,我们创建了一些实用函数,以便从公共提供商直接获取燃料站的数据。虽然这对于我们的概念验证足够了,但由于多种原因,我们现在需要一个更强大的系统:
-
性能与延迟:实时处理数据,包括解析 XML、格式化和过滤,计算成本高,对于预期频繁使用的应用程序来说,可能不切实际。
-
可靠性:确保我们的应用程序不会受到第三方数据源的意外更改或停机的影响。仅依赖外部门户的数据会使我们的应用程序面临风险,因为即使提供商简单地更改字段名称也可能导致我们这边的错误和停机,而我们需要修补这些变化。通过构建我们自己的数据库,我们可以对数据有更大的控制权,进行必要的更新和维护,而无需依赖外部方。
-
自定义:通过我们自己的数据库,我们可以定制数据以满足技术规格,添加其他外部数据源,为不同的用例构建自定义数据视图等……
为了满足这些需求,我们将构建自己的数据库和 API,处理数据获取、处理和前端交付。这将包括使用 Docker 运行 PostgreSQL 数据库,使用 Python 和 sqlalchemy 与数据库交互,并使用 PostGIS 扩展进行地理查询。我们还将探索如何使用 FastAPI 和 SQLmodel 构建一个简单的 API。
下图是应用程序不同组件的简单示意图:

我们应用程序不同组件的简单视图,作者插图
本文涵盖的内容
在本文中,我们专注于内部数据库和 API 的创建。具体来说,我们将:
-
使用 Docker 运行 PostgreSQL 数据库
-
使用 python 和 sqlalchemy 与数据库交互
-
使用 PostGIS 扩展进行地理查询
-
使用 FastAPI 和 SQLmodel 构建一个简单的 API
-
使用 docker-compose 容器化我们的项目并运行
使用 Docker 创建本地 PostgreSQL 实例
Docker 是一个开源容器化平台,允许你在一致且隔离的环境中运行应用程序。使用 Docker 设置 PostgreSQL 服务器有几个优点,包括能够以标准化的方式安装应用程序,而不必担心与系统上其他配置冲突。
在我们的案例中,我们将直接在容器内部设置 Postgre 服务器。
我在这里假设你已经在计算机上安装了 Docker,因为安装方法因系统而异。
获取容器镜像
Docker 镜像可以看作是构建一个专用于特定任务的容器所需的所有规格。它本身不做任何事情,但用于构建容器(一个专用虚拟环境),你的应用程序将运行在其中。我们可以使用 Dockerfile 创建我们自己的自定义镜像(稍后会讨论),或者我们可以从社区共享的各种开源镜像中下载现成的镜像。
在我们的情况下,我们需要一个可以帮助我们创建运行 PostgreSQL 的容器的镜像,我们可以使用 官方镜像 来实现这一目的。
我们从在 Docker 上下载 PostgreSQL 镜像开始。这是在 Shell 中完成的:
docker pull postgres
运行 Postgre 容器
一旦镜像在 Docker 中下载完成,我们可以使用以下命令基于它构建容器:
docker run -itd -e POSTGRES_USER=jkaub -e POSTGRES_PASSWORD=jkaub -p 5432:5432 -v ~/db:/var/lib/postgresql/data --name station-db postgres
让我们解密它。
-itd 是三个参数的组合:
-
-d 表示我们以分离模式运行容器。在这种模式下,容器将在后台运行,我们可以继续使用终端进行其他操作。
-
-i 指定我们的容器将以交互模式运行。这将允许我们进入容器并与之交互。
-
-t 表示容器内部将提供一个伪终端,以便与容器进行交互,这将带来更无缝和直观的交互体验。
-e 用于在容器内部生成环境变量。在这种情况下,环境变量 POSTGRES_USER 和 POSTGRES_PASSWORD 还用于用给定的密码生成我们的 PostgreSQL 实例的新用户。如果没有这个,我们仍然可以使用默认用户/密码(postgre/postgre)访问 PostgreSQL 实例。
-p 用于将本地机器的端口映射到 Docker 容器中。PostgreSQL 的默认端口是 5432。如果它已经在你的本地机器上被使用,你可以使用这个参数将容器中的 5432 映射到你机器上的另一个端口。
-v 是一个在我们这种情况下非常重要的参数:它允许我们将一个卷从我们的机器(在我们这种情况下是文件夹 ~/db )映射到容器内部的卷,其中 SQL 数据默认存储在 /var/lib/postgresql/data。通过进行这种映射,我们创建了一个持久卷,即使容器停止后也会保留。因此,我们的数据库将持续存在,即使我们停止使用容器,之后也可以使用。
— name 只是一个标志,用于给容器命名,这将有助于我们以后访问它。
我们可以通过使用下面的命令检查容器是否处于活动状态,该命令将显示我们机器上运行的容器列表:
docker ps
返回:
CONTAINER ID IMAGE COMMAND CREATED STATUS PORTS NAMES
cb0840806636 postgres "docker-entrypoint.s…" 2 minutes ago Up 2minutes 0.0.0.0:5432->5432/tcp station-db
与 PostgreSQL 的初步交互
我们的 PostgreSQL 实例现在正在容器中运行,我们可以与它进行交互。
创建数据库
作为起点,让我们创建一个包含项目不同表格的第一个数据库。
为此,我们需要进入容器。记住,这可能是因为我们在初始化容器时指定了 -it 参数。下面的命令行将完成这项工作:
docker exec -it station-db bash
命令提示符现在应该是:
root@cb0840806636:/#
这意味着我们以 root 用户身份登录到容器中。我们可以使用用户 (-U)/密码 (-d) 连接到 PostgreSQL,如下所示:
psql -U jkaub -d jkaub
一旦进入 PostgreSQL 实例,我们可以使用 SQL 查询与之交互,特别是创建一个新的数据库来托管我们未来的表。
CREATE DATABASE stations;
我们可以通过运行来验证数据库是否已经创建
\l
这将显示系统中的不同数据库。在一些在实例初始化时创建的默认数据库中,我们可以找到刚刚创建的那个:
jkaub=# \l
List of databases
Name | Owner | Encoding | Collate | Ctype | ICU Locale | Locale Provider | Access privileges
-----------+-------+----------+------------+------------+------------+-----------------+-------------------
stations | jkaub | UTF8 | en_US.utf8 | en_US.utf8 | | libc |
现在我们已经设置好了 PostgreSQL 实例,我们可以通过手动在psql中编写 SQL 查询来创建表和从 .csv 文件导入数据。虽然这种方法适用于一次性使用,但如果我们需要频繁更新表,它可能会变得繁琐且容易出错。
因此,为了促进自动化,我们将使用 Python 框架与数据库及其表进行交互。这将允许我们通过代码轻松创建、更新和查询数据库,使过程更加高效且减少错误。
用 sqlalchemy 打开一个会话
SQLalchemy 是一个开源 SQL 工具包和对象关系映射(ORM)工具,供 Python 开发者使用。它提供了一组高级函数来与数据库交互,而不是编写 SQL 查询。
这特别方便,因为它允许我们使用 Python 类(在这里也称为“模型”)定义表的结构,并使用面向对象的范式。我们的 Python ORM,sqlalchemy,在下一部分构建后端 API 时将特别有用。
让我们开始安装项目所需的库。在 sqlalchemy 的基础上,我们还将使用psycopg2,这是一个 PostgreSQL 适配器,可被 sqlalchemy 用作连接器。
pip install psycopg2 sqlalchemy
我们现在可以直接在 Python 中有效地创建一个会话来访问我们的数据库:
from sqlalchemy import create_engine
engine = create_engine('postgresql://jkaub:jkaub@localhost/stations')
# test the connection by executing a simple query
with engine.connect() as conn:
result = conn.execute('SELECT 1')
print(result.fetchone())
一步步解释这个脚本:
engine = create_engine('postgresql://jkaub:jkaub@localhost/stations')
create_engine 方法用于保持与数据库的连接。我们需要在这里指定一个数据库 URL,该 URL 包含连接到我们数据库所需的所有信息。
-
该 URL 的第一部分 postgresql😕/ 是为了指定我们正在使用 PostgreSQL 连接,并且接下来将是该类型数据库连接的规格。如果你使用的是不同的数据库,如 SQLite,你将会有不同的基本 URL 和规格。
-
jkaub:jkaub 是连接到我们数据库的登录信息。
-
localhost 是运行数据库的服务器。服务器 IP 也可以用来连接远程服务器,或者,如我们稍后会看到的,在容器集群的情况下,我们在某些情况下也可以使用容器名称。
-
/stations 用于指定我们想要连接的数据库。在我们的例子中,我们连接到我们刚刚创建的“stations”。
# test the connection by executing a simple query
with engine.connect() as conn:
result = conn.execute('SELECT 1')
print(result.fetchone())
这部分代码目前仅用于测试连接是否成功。我们的数据库还没有表可以查询,所以我们只是运行一个虚拟查询。它应该返回 (1,),这意味着连接成功。
使用 FastAPI 构建 API
现在我们已经在 Docker 容器中设置了 PostgreSQL 数据库,并使用 SQLAlchemy 引擎访问了它,现在是时候开发与数据库交互的 API 了。
使用 API 有几个好处:
-
它提供了可重用性和平台/语言的独立性,允许多个服务使用相同的 API 端点。
-
它将数据库逻辑与应用逻辑分开,使得只要输入/输出被尊重,修改一个而不影响另一个变得更加容易。
-
它增加了一层安全性,因为你可以通过授权系统控制谁可以访问数据库。
-
最后,API 是可扩展的,可以在多个服务器上运行,使其在管理工作负载方面非常灵活。通过创建一组明确的 URL,我们将能够通过 API 从数据库中检索、修改、插入或删除数据。
关于 FastAPI
FastAPI 是一个现代的 Python 框架,在构建轻量级 API 方面特别高效,由 Sebastián Ramírez 开发。
结合sqlalchemy和pydantic这两个用于数据验证的 Python 库时,它特别高效(例如,它可以检查一个日期是否确实是日期,一个数字是否确实是数字,等等)。一起使用,它使我们能够通过框架有效地处理和查询表格。
更棒的是,Sebastián Ramírez 还设计了另一个库,sqlmodel,它结合了 pydantic 和 sqlalchemy,去除了一些冗余,并进一步简化了 API 的架构。
如果你还不熟悉 FastAPI,我建议你先看看这个教程,这个教程做得非常好。
在开始项目之前,我们需要安装多个库。
pip install uvicorn
pip install fastapi
pip install sqlmodel
pip install geoalchemy2
-
uvicorn 是运行 API 服务器的工具,并且非常适合与 FastAPI 一起使用。
-
fastapi 是 API 的核心引擎,我们将用它来创建不同的端点。
-
sqlmodel 将 sqlalchemy ORM 与 pydantic 的类型验证功能结合起来
-
geolochemy2 是 sqlalchemy 的一个扩展,用于执行地理查询。
初始化模型
让我们为我们的 API 项目创建一个新的仓库,从使用sqlmodel定义模型开始。 “模型”只是一个代表 SQL 中表格的 Python 类。
api/
|-- app/
|-- __init__.py
|-- models.py
我们的项目将包含 3 个表格,并遵循我们在 part I 中构建的初始设计。
-
一个包含与城市相关的信息(邮政编码、位置)的表格
-
一个包含关于燃气价格信息的表格
-
一个包含关于车站信息的表格
通过联接和地理过滤器组合这些表,将帮助我们构建前端请求的最终输出。
我们来看一下第一个表格,Cities 表:
from sqlmodel import Field, SQLModel
from datetime import datetime
from typing import Optional
class Cities(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
postal_code: str
name: str
lat: float
lon: float
类 “Cities” 继承自 SQLModel 类,结合了 sqlalchemy’s ORM 特性和 pydantic 的类型控制。
参数 table=True 表示如果数据库中尚不存在相应的表,则会自动创建该表,并匹配列名和列类型。
类的每个属性将定义每一列及其类型。特别是,“id”将作为我们的主键。使用 Optional 将指示 sqlalchemy 如果我们不提供 id,则自动生成 id。
我们还提供了另外两个表的模型:
from datetime import datetime
...
class GasPrices(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
station_id: str
oil_id: str
nom: str
valeur: float
maj: datetime = Field(default_factory=datetime.utcnow)
class Stations(SQLModel, table=True):
station_id: str = Field(primary_key=True)
latitude: float
longitude: float
cp: str
city: str
adress: str
注意在 Stations 表中,我们使用 station_id 作为主键,与 GasPrices 不同,该字段是必填的。如果在发送到表格时该字段为空,将会生成错误信息。
初始化引擎
在另一个专用文件中,为了保持项目结构化,我们将初始化引擎。我们称之为 services.py。
api/
|-- app/
|-- __init__.py
|-- models.py
|-- services.py
连接到数据库的方式与之前介绍的相同。
from sqlmodel import SQLModel, create_engine
import models
DATABASE_URL = 'postgresql://jkaub:jkaub@localhost/stations'
engine = create_engine(DATABASE_URL)
def create_db_and_tables():
SQLModel.metadata.create_all(engine)
注意函数 create_db_and_tables():该函数将在 API 初始化期间被调用,查看 models.py 中定义的模型,并直接在 SQL 数据库中创建它们(如果它们尚不存在的话)。
实操 API
我们现在可以开始开发主要组件,我们将在其中放置端点(即允许我们与数据库交互的 URL)。
api/
|-- app/
|-- __init__.py
|-- main.py
|-- models.py
|-- services.py
我们要做的第一件事是配置 FastAPI 在启动时的设置,并处理 API 授权。
from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI, HTTPException
from models import Cities, Stations, GasPrices
from services import engine, create_db_and_tables
#We create an instance of FastAPI
app = FastAPI()
#We define authorizations for middleware components
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:3000"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
#We use a callback to trigger the creation of the table if they don't exist yet
#When the API is starting
@app.on_event("startup")
def on_startup():
create_db_and_tables()
一个重要的点是:默认情况下,我们的前端没有访问 API 调用的权限,如果你忘记配置中间件部分,将会在前端出现错误。你可以决定通过使用以下方式允许所有来源:
allow_origins=["*"],
但由于安全原因,我不推荐这样做,因为一旦上线,你基本上会将 API 向全世界开放。我们的前端目前在 localhost:3000 上本地运行,所以这是我们将允许的域名。
到那时,我们已经可以通过使用以下命令行来启动 API:
uvicorn main:app --reload
— reload 只是意味着每次在 API 正在运行时保存修改,它将重新加载以包含这些更改。
一旦启动,你可以看到一些日志显示在终端中,特别是:
INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
表示 API 服务器正在 localhost(等同于 IP 127.0.0.1)上的 8000 端口运行。
如前所述,启动 API 还会触发数据库中空表的创建(如果它们尚不存在)。因此,从你第一次启动 API 的那一刻起,使用 table=True 创建的模型将在数据库中有一个专用表。
我们可以通过在 psql 中从 PostgresSQL 容器内部轻松检查这一点。连接为主用户后,我们首先连接到数据库 station:
\c stations
我们现在可以检查我们的表是否已正确创建:
\dt
这将返回:
List of relations
Schema | Name | Type | Owner
--------+-----------------+-------+-------
public | cities | table | jkaub
public | gasprices | table | jkaub
public | stations | table | jkaub
我们还可以通过在 psql 中运行描述查询来验证列是否与我们的模型匹配,例如,对于 cities 表:
\d cities
Column | Type | Collation | Nullable | Default
-------------+-------------------+-----------+----------+-----------------------
-------------
id | integer | | not null | nextval('cities_id_seq
'::regclass)
postal_code | character varying | | not null |
name | character varying | | not null |
lat | double precision | | not null |
lon | double precision | | not null |
Indexes:
"cities_pkey" PRIMARY KEY, btree (id)
构建我们的第一个请求 — 使用 POST 请求向 Cities 添加行
Cities 表只会填充一次,并用于将邮政编码与城市的经纬度匹配,这在后续使用邮政编码查询这些位置时特别有帮助。
目前,数据存储在 .csv 文件中,我们想设计一个 POST 调用,用于更新表格,如果数据还不在数据库中,则一次添加一行。API 调用被放在 main.py 文件中。
from sqlmodel import Session
...
@app.post("/add-city/")
def add_city(city: Cities):
with Session(engine) as session:
session.add(city)
session.commit()
session.refresh(city)
return city
让我们逐行查看这段代码:
@app.post("/add-city/")
每个 API 端点都是通过装饰器定义的。我们在这里定义了两件事:请求的类型(get、post、put、delete…)和关联的 URL 端点 (/add-city/)。
在这个特定的案例中,我们将能够在 http://127.0.0.1:8000/add-city/ 执行 POST 请求
def add_city(city: Cities):
我们将不同的参数传递给函数以用于查询。在我们的案例中,post 请求将寻找 Cities 的实例,这将通过 JSON 传递在我们的请求中。此 JSON 将包含我们想要添加的新行的 Cities 表的每列的值。
with Session(engine) as session:
要连接到数据库,我们打开一个 Session。每个查询需要自己的会话。使用这种方法在会话内部出现意外情况时特别有用:在会话初始化和 commit() 之间所做的所有更改将在出现问题时回滚。
session.add(city)
session.commit()
session.refresh(city)
在这里,对象被添加到数据库中,然后进行提交。从提交的那一刻起,操作无法回滚。refresh 用于更新 DB 进行的任何修改的“city”对象。在我们的案例中,例如,会自动添加递增的“id”。
return city
我们通过以 JSON 格式发送 city 对象来结束请求。
我们现在可以在 python 中尝试请求(当然需要 API 正在运行):
import requests
url='http://127.0.0.1:8000/add-city/'
json = {
'postal_code': '01400',
'name':"L'Abergement-Clémenciat",
'lat':46.1517,
'lon':4.9306
}
req = requests.post(url, json=json)
请注意,我们在请求中发送的 JSON 键与我们要更新的表的列名称匹配。参数“id”是可选的,它将自动添加到操作中,我们不需要担心它。
这应该会在 API shell 中触发以下行:
INFO: 127.0.0.1:33960 - "POST /add-gas-price/ HTTP/1.1" 200 OK
这意味着请求成功。我们可以进一步验证该行是否已正确添加。返回到我们 docker 中的 psql,我们可以尝试以下查询:
SELECT * FROM cities LIMIT 1;
这将显示:
id | postal_code | name | lat | lon
----+-------------+-------------------------+---------+--------
1 | 01400 | L'Abergement-Clémenciat | 46.1517 | 4.9306
演示该行已被 API 有效添加到我们的数据库中。
此外,我们不希望邮政编码被重复添加。为此,我们将查询 Cities 表,根据我们尝试发送的邮政编码过滤表,并在找到具有该邮政编码的行时返回 HTML 错误,从而避免邮政编码重复。
from fastapi import FastAPI, HTTPException
...
@app.post("/add-city/")
async def add_city(city: Cities):
with Session(engine) as session:
#New code block
exist = session.query(Cities).filter(
Cities.postal_code == city.postal_code).first()
if exist:
raise HTTPException(
status_code=400, detail="Postal code already exists")
#New code block
session.add(city)
session.commit()
session.refresh(city)
return city
在这段新代码中,我们使用 sqlalchemy ORM 执行了第一次数据库查询:我们没有编写经典的 SQL(“SELECT FROM”),而是使用了一组函数来直接查询数据库。
exist = session.query(Cities).filter(
Cities.postal_code == city.postal_code).first()
-
.query相当于 SELECT … FROM …,在我们的案例中,我们从 cities 表中选择所有内容。 -
.filter相当于 WHERE 语句。特别地,我们要匹配与我们发送的对象的邮政编码相等的条目(由变量 city 表示)。 -
.first()是自解释的,相当于 LIMIT 1。 -
如果没有找到行,则
exist将为 None,并且不会引发异常,因此我们将把对象添加到数据库中。如果某行与邮政编码匹配,API 请求将返回状态码 400 的错误,且 POST 请求将被中断,元素不会被添加。
如果我们现在尝试发送完全相同的请求,将看到 API 返回错误消息:
INFO: 127.0.0.1:49076 - "POST /add-city/ HTTP/1.1" 400 Bad Request
而且该行没有被添加到表中。
从那时起,我们可以简单地遍历 .csv 并逐一添加所有城市,以填充 cities 表。
使用 POST 请求向 Gasprices 和 Stations 表中添加行
我们将非常快速地跳过这些 API 调用的构建,因为它们与之前的非常相似。
@app.post("/add-station/")
async def add_station(station: Stations):
with Session(engine) as session:
exist = session.query(Stations).filter(
Stations.station_id == station.station_id).first()
if exist:
raise HTTPException(
status_code=400, detail="Station already exists")
session.add(station)
session.commit()
session.refresh(station)
return station
@app.post("/add-gas-price/")
async def add_station(gasPrice: GasPrices):
with Session(engine) as session:
exist = session.query(GasPrices). \
filter(GasPrices.oil_id == gasPrice.oil_id). \
filter(GasPrices.maj == gasPrice.maj). \
first()
if exist:
raise HTTPException(
status_code=400, detail="Entry already exists")
session.add(gasPrice)
session.commit()
session.refresh(gasPrice)
return gasPrice
唯一需要注意的有趣之处在于,我们使用了双重过滤查询,以确保仅在 oil_id 有新更新时才添加一行。这样,我们确保了未来的更新不会在价格没有变化的情况下创建重复项,从而节省了数据库空间。
为了检索和处理油价,我们只是回收了来自 Part I 的解析代码,获取相应的数据集并循环遍历,为每个条目进行 POST 调用。
以下脚本在 API 范围之外执行,以将数据上传到数据库:
import request
from data_parsing import get_data
BASE_API_URL = 'http://127.0.0.1:8000'
#get_data is the function designed in part I to pull the xml from the opendata
#source and convert them in Dataframes
stations, gas = get_data()
#Pushing stations data
to_push = stations[['latitude','longitude','cp','adress','city','station_id']].to_dict('records')
url=f'{BASE_API_URL}/add-station/'
for elmt in to_push:
req = requests.post(url, json=elmt)
#Pushing gasprices data
to_push = gas.to_dict('records')
url=f'{BASE_API_URL}/add-gas-price/'
for elmt in to_push:
req = requests.post(url, json=elmt)
注意:为了简化起见,我在这里选择逐行推送数据。我们也可以设计端点以批量推送数据,并发送 JSON 列表。
构建前端使用的 GET 查询
到这一点为止,我们的数据库已完全填充,以上脚本可以用于用更近期的数据更新数据库,我们可以开始构建前端用来查询特定城市周围加油站特定燃料价格的 GET 请求。
我决定为这个特定查询专门设置一个完整的章节,因为它的复杂性(我们将使用到目前为止定义的所有表,进行连接和地理过滤),同时因为我们需要在此时对数据库进行一些更改,以集成空间功能,安装附加组件和修改一些模型。虽然这可以从一开始就直接完成,但在实际项目中进行修改是常见的,我认为展示如何平滑地完成这一过程是很有趣的。
安装 PostGIS
PostGIS 是 PostgreSQL 的一个扩展,允许我们构建地理查询,这意味着需要一个空间组件。例如,在我们的情况下,我们能够选择距离某个兴趣点 30 公里半径内的所有车站数据行。
现在我们不想直接在运行的容器中安装 PostGIS,因为每次我们需要弹出一个新的容器时,这个安装会“丢失”,新容器基于仅安装 PostgreSQL 的镜像。
相反,我们将简单地更改用于构建容器的镜像,并将其替换为包含 PostgreSQL 和 PostGIS 的镜像。我们将提供相同的持久存储位置,以便新容器也可以访问它。
要构建包含 PostGIS 扩展的容器,我们首先从 docker 中拉取最新的 PostGIS 镜像,然后杀死并删除当前的 PostgreSQL 容器,并用新的镜像构建一个新的容器。
docker pull postgis
docker kill stations
docker rm stations
docker run -itd -e POSTGRES_USER=jkaub -e POSTGRES_PASSWORD=jkaub -p 5432:5432 -v ~/db:/var/lib/postgresql/data --name station-db postgis/postgis:latest
我们可以像之前一样访问容器,但现在我们使用的是包括 PostGIS 的 PostgreSQL 版本。
我们现在需要将扩展添加到现有数据库中。我们首先重新连接到数据库:
docker exec -it station-db bash
psql -U jkaub -d jkaub
\c stations
然后我们在其中包含 PostGIS 扩展:
CREATE EXTENSION postgis;
修改我们的 Stations 模型
现在我们的数据库中已经运行了 PostGIS,我们需要修改我们的 Stations 表以便能够执行地理查询。更准确地说,我们需要添加一个“geometry”字段,该字段被理解并转换为地球上的实际位置。
建立地图或标示地球上的位置有多种方法,每种方法都有其自己的投影和参考坐标系统。为了确保一个系统能够与另一个系统对话,我们需要确保它们使用相同的语言,这可能包括单位的转换(就像我们可以将米转换为英尺,或将千克转换为磅)。
对于坐标,我们使用称为“地理参数数据集”(EPSG)的东西。纬度和经度(EPSG 4326)以角度表示,无法直接转换为距离(欧几里得几何,包括距离计算,不能直接应用于球面,因为本质上这不是一个欧几里得表面)。相反,需要将其投影到平面表示中,这在 PostGIS 中处理得很好,只要我们注意并应用适当的转换。
作为起点,我们需要在 Stations 数据库中添加一个可以被解释为“几何”坐标的新字段。在我们的数据库中:
ALTER TABLE stations ADD COLUMN geom geometry(Point, 4326);
这一行将修改我们的 stations 表,添加一个新的字段 “geom”,它是类型为 “point” 的 PostGIS 几何图形,使用 EPSG 4326(经纬度系统的 EPSG)表示。该字段目前对所有行为空,但我们仍然可以在 SQL 中轻松填充它,以更新当前表(此时表并不为空)。
UPDATE stations SET geom = ST_SetSRID(ST_MakePoint(longitude, latitude), 4326);
上面的 SQL 查询将为 Stations 表的每一行设置 geom 列,使用由经度/纬度构建的点。注意我们在这里使用了两个 PostGIS 函数,ST_MakePoint 和 ST_SetSRID,来帮助我们在 SQL 中定义几何图形。
我们可以检查这个新的几何图形在数据库中的存储方式
SELECT * FROM stations LIMIT 1;
station_id | latitude | longitude | cp | city | adress | geom
------------+----------+-----------+-------+-------+-----------------------+----------------------------------------------------
26110004 | 44.36 | 5.127 | 26110 | NYONS | 31 Avenue de Venterol | 0101000020E6100000355EBA490C821440AE47E17A142E4640
你可以在这里看到,几何图形以字符串编码,这种格式是 Well-Known Binary (WKB),它在存储几何图形时非常高效。我不会进一步展开这个内容,但如果你在数据集中看到这个格式不要感到惊讶,如果需要,你可能需要将其解码为更可读的格式。
现在,我们还需要更新 model.py 文件中的 Stations 类以包含这个新字段,为此我们使用 geoalchemy 的 “Geometry” 类型。
from typing import Any
from geoalchemy2.types import Geometry
class Stations(SQLModel, table=True):
station_id: str = Field(primary_key=True)
latitude: float
longitude: float
cp: str
city: str
adress: str
geom: Optional[Any] = Field(sa_column=Column(Geometry('GEOMETRY')))
最后的修改是:我们希望在 POST 调用(在 main.py 中)时,使用纬度和经度参数自动计算几何图形:
from geoalchemy2.elements import WKTElement
@app.post("/add-station/")
async def add_station(station: Stations):
with Session(engine) as session:
exist = session.query(Stations).filter(
Stations.station_id == station.station_id).first()
if exist:
raise HTTPException(
status_code=400, detail="Station already exists")
#New code block
point = f"POINT({station.longitude} {station.latitude})"
station.geom = WKTElement(point, srid=4326)
#New code block
session.add(station)
session.commit()
session.refresh(station)
#This is only done to return a clean dictionnar with a proper json format
to_return = {}
to_return["station_id"] = station.station_id
to_return["latitude"] = station.latitude
to_return["longitude"] = station.longitude
to_return["cp"] = station.cp
to_return['city'] = station.city
to_return["adress"] = station.adress
return to_return
在这里,我们通过字符串创建一个点,使用另一种名为 WKTElement 的格式,这是一种使用人类可读的字符串编码几何图形的方式。我们的字符串随后通过 geolalchemy 函数 WKTElement 转换为几何图形,该函数隐式地将其转换为 WKB 格式以便编码到数据库中。
注意,“geom”不是 JSON 可序列化的,因此我们需要在通过 API 发送站点对象之前修改或删除它。
构建最终的 GET 查询
GET 查询的目标是从通过邮政编码识别的城市中检索 30 公里半径内的所有站点,并显示查询中提到的某种类型的所有站点的燃料最新价格,并附上一些美化的信息,如规范化的地址或 Google 地图链接。
{
"lat": 49.1414,
"lon": 2.5087,
"city": "Orry-la-Ville",
"station_infos": [
{
"address": "Zi Route de Crouy 60530 Neuilly-en-Thelle",
"price_per_L": 1.58,
"price_tank": 95,
"delta_average": 25.1,
"better_average": 1,
"google_map_link": "https://www.google.com/maps/search/?api=1&query=Zi+Route+de+Crouy+60530+Neuilly-en-Thelle",
"distance": 19.140224654602328,
"latitude": 49.229,
"longitude": 2.282
}, ...
]
}
我们将分两步进行:
-
首先构建一个高效的 SQL 查询来执行连接和过滤操作
-
在通过 API 发送结果之前,使用 Python 函数修改查询的输出。
与其他参数通过请求体中的 JSON 传递的查询不同,我们在这里将使用另一种约定,即将查询参数直接传递在 URL 中,见下面的示例:
http://localhost:8000/stations/?oil_type=SP98&postal_code=60560
在 FastAPI 中,这可以通过简单地向用于构建端点的函数中添加输入来自然完成:
@app.get("/stations/")
async def get_prices(oil_type: str, postal_code: str):
with Session(engine) as session:
...
现在我们要首先检索的是与邮政编码相关联的城市的纬度和经度。如果没有与邮政编码关联的城市,API 应该返回一个错误代码,说明未找到邮政编码。
city = session.query(Cities).filter(
Cities.postal_code == postal_code
).first()
if not city:
raise HTTPException(
status_code=404, detail="Postal Code not found")
接下来,我们将构建一系列子查询。每个子查询在最终查询完全执行之前不会被评估。这将帮助我们保持代码的可读性,并优化查询,因为 sqlalchemy ORM 会根据这些子查询动态优化查询。
我们要执行的第一个子查询是从 Stations 表中选择所有在已查询城市 30 公里半径范围内的车站。
stations = session.query(
Stations.station_id, Stations.adress, Stations.cp, Stations.city,
Stations.latitude, Stations.longitude,
).filter(
ST_Distance(
Stations.geom.ST_GeogFromWKB(),
WKTElement(f"POINT({city.lon} {city.lat})",
srid=4326).ST_GeogFromWKB()
) < 30000).subquery()
这里有许多有趣的地方需要注意。
-
我们只在 session.query( … ) 中选择了少量列,并且不保留 geom 列,该列仅用于过滤。在标准 SQL 中,这可以通过 “SELECT station_id, adress, cp, city, latitude, longitude FROM stations” 来完成。
-
我们使用ST_Distance,这是 geoalchemy 的内置函数,用于计算两个地理位置之间的距离(另一种 geoalchemy 类型)。
-
ST_Distance 也可以与几何对象一起工作,但输出将变成角度距离(请记住,纬度/经度是以角度表示的),这不是我们想要的。
-
要将几何对象转换为地理对象,我们只需使用另一个内置函数 ST_GeoFromWKB,它会自动将我们的几何体投影到其参考系统中,以在地球上形成一个点。
接下来,我们根据所需的 oil_type(如 SP95、Gazole 等)过滤 Gasprices 表。
price_wanted_gas = session.query(GasPrices).filter(
GasPrices.nom == oil_type
).subquery()
我们还需要根据数据集中最新的价格来过滤 Gasprices 表。这不是一项容易的任务,因为所有价格的更新不是同时完成的。我们将分两步构建子查询。
首先,我们通过从 price_wanted_gas 子表中提取 station_id 和最后更新时间来执行聚合。
last_price = session.query(
price_wanted_gas.c.station_id,
func.max(price_wanted_gas.c.maj).label("max_maj")
).group_by(price_wanted_gas.c.station_id) \
.subquery()
然后使用这些信息帮助我们通过连接过滤 price_wanted_gas,其中仅保留最新更新价格的行。“and_” 方法允许我们在连接操作中使用多个条件。
last_price_full = session.query(price_wanted_gas).join(
last_price,
and_(
price_wanted_gas.c.station_id == last_price.c.station_id,
price_wanted_gas.c.maj == last_price.c.max_maj
)
).subquery()
最后,我们在 last_price_full 子表(包含给定燃料的所有最新价格)和 stations 子表(包括所有在 30 公里半径内的车站)之间进行最终连接,并检索所有结果。
stations_with_price = session.query(stations, last_price_full).join(
last_price_full,
stations.c.station_id == last_price_full.c.station_id
).all()
到达这一点时,我们检索了经过过滤的相关车站列表,并将其与 GasPrices 表中的相关信息(即:价格)合并,我们只需要对输出结果进行后处理,以符合前端的要求。由于此时表格已经被清理和过滤,因此最终的后处理步骤可以在原生 Python 中完成,而不会对性能产生太大影响。
我将稍微详细说明这个最终的后处理步骤,因为它不在文章的核心部分,但请随时查看 GitHub 仓库以获取更多信息。
prices = [float(e["valeur"]) for e in stations_with_price]
avg_price = float(np.median(prices))
output = {
"lat": city.lat,
"lon": city.lon,
"city": pretify_address(city.name),
"station_infos": sorted([extend_dict(x, avg_price, city.lat, city.lon) for x in stations_with_price], key=lambda x: -(x['delta_average']))
}
return output
我们现在可以测试并验证查询是否返回了相关输出。我们可以使用 Python 请求进行检查,但 FastAPI 还提供了所有端点的内置文档,您可以在 localhost:8000/docs 测试您的 API。

FastAPI 内置文档的截图,作者插图
容器化应用程序
现在我们有了一个运行中的 API,我们将通过将应用程序打包到容器中来完成本文。
这是我们项目的组织方式:
stations-project/
|-- db/
|-- api/
|-- app/
|-- requirements.txt
|-- Dockerfile
|-- update_scripts/
|-- front/
|-- docker-compose.yml
我们将使用 api/ 中的 Dockerfile 来容器化 API,并使用 docker-compose 同时管理 API 和数据库。
文件夹 db/ 是 PostgreSQL 容器用来持久化数据库的卷。
打包我们的 API
为了打包我们的 API,我们将简单地构建一个 docker 镜像,该镜像将复制运行 API 所需的环境和依赖项。这个 docker 镜像将包含运行 API 所需的所有内容,包括代码、运行时、系统工具、库和配置。
为此,我们需要编写一个 Dockerfile,其中包含设置 FastAPI 环境的一系列指令。编写 Dockerfile 相对容易,只要理解了原理:它就像是从头开始配置一台新机器。在我们的案例中:
-
我们需要安装相关版本的 Python
-
设置工作目录
-
将相关文件复制到我们的工作目录中(包括 requirements.txt 文件,该文件是使用 pip install 安装项目所需所有库的强制要求)
-
使用 pip install 安装库
-
暴露 FastAPI 端口
-
运行初始化 API 的命令(uvicorn main:app — reload)
用 Docker 语言翻译过来,这变成了:
FROM python:3.9
WORKDIR /code
COPY ./requirements.txt /code/requirements.txt
COPY ./app /code/app
RUN pip install --no-cache-dir -r requirements.txt
EXPOSE 80
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "80"]
我们还需要处理 requirements.txt 文件,其中明确列出所有使用的库及其版本。
fastapi==0.94.0
GeoAlchemy2==0.13.1
numpy==1.24.2
SQLAlchemy==1.4.41
sqlmodel==0.0.8
uvicorn==0.20.0
psycopg2==2.9.5
在进行这些更新之后,我们现在可以构建容器镜像(在包含 Dockerfile 的文件夹内):
docker build -t fast-api-stations .
使用 docker-compose
docker-compose 是一个用于定义和运行多容器 Docker 应用程序的工具。在我们的案例中,我们希望同时运行 SQL 容器和 FastAPI 容器。我将假设您已经在计算机上安装了 docker-compose。如果没有,请按照这些说明进行操作。
为了使用 docker-compose,我们只需在项目的根目录中配置一个 docker-compose.yml 文件,该文件定义了构成应用程序的服务及其各自的配置。
docker-compose.yml 文件使用 YAML 语法定义了一组服务,每个服务代表一个将作为全球应用程序的一部分运行的容器。每个服务可以指定其镜像、构建上下文、环境变量、持久化卷、端口等……
这就是我们的 docker-compose.yml 文件的样子:
version: "3"
services:
fastapi:
image: fast-api-stations
ports:
- "8000:80"
stationdb:
image: postgis/postgis
environment:
POSTGRES_USER: jkaub
POSTGRES_PASSWORD: jkaub
POSTGRES_DB: stations
volumes:
- ./db:/var/lib/postgresql/data
如您所见,我们定义了两个服务:
-
一个用于 API 的服务,现在名为 FastAPI,构建在我们在前一个小节中创建的 Docker 镜像 fast-api-station 上。对于这个服务,我们将容器的 80 端口暴露给本地的 8000 端口。
-
一个用于 DB,运行在 PostGIS 镜像上。我们指定了与之前相同的环境变量和相同的卷以持久化数据库。
最后一个小修改
我们曾经使用本地 IP 连接到 SQL 引擎。由于我们现在在两个不同的环境中运行 API 和 PostgreSQL,我们需要更改连接数据库的方式。
docker-compose 自行管理不同容器之间的网络,并使我们能够轻松地从一个服务连接到另一个服务。为了从 API 服务连接到 SQL 服务,我们可以在引擎创建时指定要连接的服务名称:
DATABASE_URL = 'postgresql://jkaub:jkaub@stationdb/stations'
运行后端
现在我们已经配置好了所有内容,我们可以通过以下方式运行我们的后端应用程序:
docker-compose up
API 将通过 8000 端口提供服务。
http://localhost:8000/docs
结论
在这篇文章中,我们一直在处理我们 GasFinder 应用程序的后端。
我们决定将应用程序的所有相关数据存储在我们自己的存储解决方案中,以避免所有可能与依赖第三方连接相关的问题。
我们利用了 Docker 和 PostgreSQL+PostGIS 构建了一个数据库,使我们能够执行高效的地理查询,并使用 Python 框架 FastAPI + SQLModel 构建了一个高效的 API,可以用来与数据库交互,并向前端提供数据,这些前端是在之前的文章中开发的。
目前,我们有一个基于“生产标准”工具(React、PostgreSQL、FastAPI 等)的原型,可以在本地 100%运行。在本系列的最后部分,我们将看看如何使应用程序上线并自动更新我们的 SQL 表,以始终提供最新的信息。
使用这个技巧构建更好的条形图
原文:
towardsdatascience.com/build-a-better-bar-chart-with-this-trick-c66979cb17e1
(这实际上是一个 seaborn 散点图!)
·发布在 Towards Data Science ·7 min read·2023 年 8 月 26 日
--

“国会年龄”散点图的一部分(所有图片由作者提供)
每当我需要寻找有效的可视化灵感时,我都会浏览 经济学人、视觉资本家 或 华盛顿邮报。在其中一次探索中,我发现了一个有趣的信息图表——类似于上面展示的图表——它绘制了每个美国国会议员的年龄与他们的代际群体之间的关系。
我的第一印象是这是一个 水平条形图,但仔细观察发现每个条形图由多个 标记 组成,使其成为一个 散点图。每个标记代表一个国会成员。
在这个 快速成功数据科学 项目中,我们将使用 Python、pandas 和 seaborn 重建这个吸引人的图表。在这个过程中,我们将揭示一些你可能不知道存在的标记类型。
数据集
由于美国有 候选资格年龄 法律,国会成员的生日属于公开记录。你可以在多个地方找到它们,包括 美国国会传记名录 和 维基百科。
为了方便,我已经编制了一个包含当前国会议员姓名、生日、政府分支和政党的 CSV 文件,并将其存储在这个 Gist 中。
代码
以下代码是在 Jupyter Lab 中编写的,并且由单元格描述 描述。
导入库
from collections import defaultdict # For counting members by age.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import patches # For drawing boxes on the plot.
import pandas as pd
import seaborn as sns
为代际数据分配常量
我们将对图表进行标注,以突出显示诸如婴儿潮一代和X 世代等代际群体。以下代码计算每个群体的当前年龄范围,并包含代际名称和突出显示颜色的列表。因为我们希望将这些列表视为常量,所以我们将名称大写,并使用下划线作为前缀。
# Prepare generational data for plotting as boxes on chart:
CURRENT_YEAR = 2023
_GEN_NAMES = ['Silent', 'Boomers', 'Gen X', 'Millennials', 'Gen Z']
_GEN_START_YR = [1928, 1946, 1965, 1981, 1997]
_GEN_END_YR = [1945, 1964, 1980, 1996, 2012]
_GEN_START_AGE = [CURRENT_YEAR - x for x in _GEN_END_YR]
_GEN_END_AGE = [CURRENT_YEAR - x for x in _GEN_START_YR]
_GEN_COLORS = ['lightgray', 'white', 'lightgray', 'white', 'lightgray']
将生日转换为年龄
为了计算每位成员的年龄,我们首先必须将参考日期(8/25/2023)和 DataFrame 的“Birthday”列转换为datetime格式,使用 pandas 的to_datetime()方法。
现在我们有了兼容的、"日期感知"格式,我们可以通过减去这两个值,提取天数,然后将天数除以 365.25 来生成一个"年龄"列。
# Load the data:
df = pd.read_csv('https://bit.ly/3EdQrai')
# Assign the current date:
current_date = pd.to_datetime('8/25/2023')
# Convert "Birthday" column to datetime:
df['Birthday'] = pd.to_datetime(df['Birthday'])
# Make a new "Age" column in years:
df['Age'] = ((current_date - df['Birthday']).dt.days) / 365.25
df['Age'] = df['Age'].astype(int)
df.head(3)

初始 DataFrame 的头部
计算成员的年龄
我们最终希望按党派和政府分支对成员进行分组。这意味着我们需要生成四个单独的图表。(我们将与民主党一起包含 3 位独立人士,他们与民主党一起开会)。
与简单的条形图不同,我们需要知道的不仅仅是,比如说,57 岁共和党参议员的总数。由于我们想为特定年龄类别的每个成员绘制单独的标记,我们需要一个累计总数。这样,我们可以使用(计数,年龄)值作为散点图中的(x,y)坐标。因此,第一位 57 岁共和党参议员将在计数列中标记为“1”,第二位参议员标记为“2”,以此类推。
为了管理这一点,我们将首先设置四个 DataFrame 列来保存计数,然后制作四个相应的字典来记录初始计数。我们将使用collections模块的[defaultdict()](https://docs.python.org/3/library/collections.html#defaultdict-objects)容器,而不是标准字典,因为它会为不存在的键提供默认值,而不是引发令人烦恼的KeyError。
接下来,我们将遍历 DataFrame,按“Branch”和“Party”列进行过滤。每次我们增加字典时,我们将更新“Age”列。这使我们能够保持匹配年龄的累积计数。
请注意,我们使用负数值来表示民主党计数,因为我们希望它们绘制在中央轴的左侧,而共和党的值绘制在右侧。
# Initialize count columns:
df['R count house'] = 0
df['D count house'] = 0
df['R count senate'] = 0
df['D count senate'] = 0
# Create dictionaries with default values of 0:
r_count_h_dict = defaultdict(int)
d_count_h_dict = defaultdict(int)
r_count_s_dict = defaultdict(int)
d_count_s_dict = defaultdict(int)
# Iterate through the DataFrame and update counts:
for index, row in df.iterrows():
age = row['Age']
if row['Branch'] == 'House':
if row['Party'] == 'R':
r_count_h_dict[age] += 1
df.at[index, 'R count house'] = r_count_h_dict[age]
elif row['Party'] == 'D':
d_count_h_dict[age] -= 1
df.at[index, 'D count house'] = d_count_h_dict[age]
elif row['Branch'] == 'Senate':
if row['Party'] == 'R':
r_count_s_dict[age] += 1
df.at[index, 'R count senate'] = r_count_s_dict[age]
elif row['Party'] == 'D':
d_count_s_dict[age] -= 1
df.at[index, 'D count senate'] = d_count_s_dict[age]
elif row['Party'] == 'I':
d_count_s_dict[age] -= 1
df.at[index, 'D count senate'] = d_count_s_dict[age]
df.head(3)

遮蔽零计数
我们不想绘制零值,因此我们将使用掩码将这些值转换为 DataFrame 中的NaN(非数字)值。
# Filter out zero values:
mask = df != 0
# Apply the mask to the DataFrame:
df = df[mask]
df.head(3)

定义绘图函数
如前所述,我们将制作四个图表。为了避免重复代码,我们将把绘图指令封装到一个可重用的函数中。
这个函数将接受一个 DataFrame、一个 matplotlib 轴对象的名称、作为 x 坐标的列、一个颜色和一个标题作为参数。我们会关闭 seaborn 的大部分默认设置,比如轴刻度和标签,以便我们的图形尽可能干净和简洁。
这个图的一个重要组成部分是用于每个国会议员的矩形标记(marker=$\u25AC$)。这个标记不是标准 matplotlib 集合的一部分,而是STIX 字体符号的一部分。你可以在这里找到这些替代标记的列表。
def make_plot(data, ax, x, color, title):
"""Make a custom seaborn scatterplot with annotations."""
sns.scatterplot(data=data,
x=x,
y='Age',
marker='$\u25AC$',
color=color,
edgecolor=color,
ax=ax,
legend=False)
# Set the border positions and visibility:
ax.spines.left.set_position('zero')
ax.spines.right.set_color('none')
ax.spines.top.set_color('none')
ax.spines.bottom.set_color('none')
# Set x and y limits, ticks, labels, and title:
ax.set_xlim(-15, 15)
ax.set_ylim(25, 100)
ax.tick_params(bottom=False)
ax.set(xticklabels=[])
ax.set(yticklabels=[])
ax.set_xlabel('')
ax.set_ylabel('')
ax.set_title(title)
# Manually annotate the y-axis along the right border:
ax.text(x=12.5, y=96, s='Age')
ax.set_yticks(np.arange(30, 101, 10))
ylabels = [30, 40, 50, 60, 70, 80, 90]
for label in ylabels:
ax.text(x=13, y=label, s=str(label))
# Add shading and annotation for each generation:
for _, (name, start_age, end_age, gcolor) in enumerate(zip(_GEN_NAMES,
_GEN_START_AGE,
_GEN_END_AGE,
_GEN_COLORS)):
rect = patches.Rectangle((-15, start_age),
width=30,
height=end_age - start_age,
facecolor=gcolor,
alpha=0.3)
rect.set_zorder(0) # Move shading below other elements.
ax.add_patch(rect)
ax.text(x=-15, y=end_age - 2, s=name)
plt.tight_layout()
绘制图形
以下代码设置了图形并调用了make_plot()函数四次。最后添加了超级标题和自定义图例。
# Make the figure and call the plotting function:
fig, (ax0, ax1) = plt.subplots(nrows=1, ncols=2, figsize=(8, 5))
make_plot(df, ax0, 'D count house', 'blue', 'House' )
make_plot(df, ax0, 'R count house', 'firebrick', 'House')
make_plot(df, ax1, 'D count senate', 'blue', 'Senate')
make_plot(df, ax1, 'R count senate', 'firebrick', 'Senate')
# Add figure title and custom legend:
fig.suptitle('Age of US Congress 2023')
ax0.text(x=-15, y=17, s='$\u25AC$ Democrat & Independent', color='blue')
ax0.text(x=1.7, y=17, s='$\u25AC$ Republican', color='firebrick');
# Optional line to save figure:
# plt.savefig('age_of_congress.png', bbox_inches='tight', dpi=600)

最终的图形。
结论
最佳信息图以干净、引人注目的风格讲述故事。正如写得很好的 Python 代码几乎不需要注释一样,优秀的信息图也不需要很多标签或注释。
在这个项目中,我们使用 pandas 加载和准备数据,并使用 seaborn 生成一个模仿条形图的散点图。这个图的一个关键特性是使用STIX 字体符号作为矩形标记。
对于具有许多低计数值的数据集,这种散点图方法比标准条形图更具视觉吸引力,因为条形图中的许多条形会很短。此外,用不同的标记表示每个成员比为多个成员显示单一条形图更能“个性化”数据。
谢谢!
感谢阅读,请关注我以获取更多快速成功数据科学项目。
使用 Reflex 在纯 Python 中构建一个类似 ChatGPT 的 Web 应用
使用 OpenAI 的 API 在纯 Python 中构建一个聊天 Web 应用,部署只需一行代码
·
关注 发表在 Towards Data Science ·8 分钟阅读·2023 年 11 月 7 日
--
聊天应用 GIF 由作者提供
在过去的几个月里,我一直在玩各种令人惊叹的新 LLM 聊天机器人,包括 Llama 2、GPT-4、Falcon 40B 和 Claude 2。一个始终困扰我的问题是,我如何构建自己的聊天机器人界面,调用所有这些出色的 LLM 作为 API?
现在有无数的选项可以用来构建美丽的用户界面,但作为一名机器学习工程师,我对 JavaScript 或任何前端语言都没有经验。我在寻找一种只使用我目前知道的语言——Python 来构建我的 Web 应用的方法!
我决定使用一个相对较新的开源框架 Reflex,它允许我完全用 Python 构建后端和前端。
免责声明: 我在 Reflex 担任创始工程师,负责对开源框架做出贡献。
在本教程中,我们将讲解如何从头开始使用纯 Python 构建一个完整的 AI 聊天应用——你还可以在这个 Github 仓库 找到所有代码。
你将学到如何:
-
安装
reflex并设置你的开发环境。 -
创建组件来定义和设计你的 UI。
-
使用状态为你的应用添加交互性。
-
使用一行命令部署你的应用,与你人分享。
设置你的项目
我们将从创建一个新项目和设置开发环境开始。首先,为你的项目创建一个新目录并进入该目录。
~ $ mkdir chatapp
~ $ cd chatapp
接下来,我们将为我们的项目创建一个虚拟环境。在这个示例中,我们将使用 venv 来创建虚拟环境。
chatapp $ python3 -m venv .venv
$ source .venv/bin/activate
现在,我们将安装 Reflex 并创建一个新项目。这将创建项目目录中的新目录结构。
chatapp $ pip install reflex
chatapp $ reflex init
────────────────────────────────── Initializing chatapp ───────────────────────────────────
Success: Initialized chatapp
chatapp $ ls
assets chatapp rxconfig.py .venv
你可以运行模板应用来确保一切正常。
chatapp $ reflex run
─────────────────────────────────── Starting Reflex App ───────────────────────────────────
Compiling: ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 1/1 0:00:00
─────────────────────────────────────── App Running ───────────────────────────────────────
App running at: http://localhost:3000
你应该能在 localhost:3000 看到你的应用运行。
Reflex 还会启动后台服务器,处理所有状态管理和与前端的通信。你可以通过访问 localhost:8000/ping 来测试后台服务器是否正在运行。
现在我们已经设置好了项目,开始构建我们的应用吧!
基本前端
让我们从定义聊天应用的前端开始。在 Reflex 中,前端可以拆分为独立的、可重用的组件。有关更多信息,请查看 组件文档。
显示问题和答案
我们将修改 chatapp/chatapp.py 文件中的 index 函数,以返回一个显示单个问题和答案的组件。

作者提供的图片(代码如下)
# chatapp.py
import reflex as rx
def index() -> rx.Component:
return rx.container(
rx.box(
"What is Reflex?",
# The user's question is on the right.
text_align="right",
),
rx.box(
"A way to build web apps in pure Python!",
# The answer is on the left.
text_align="left",
),
)
# Add state and page to the app.
app = rx.App()
app.add_page(index)
app.compile()
组件可以相互嵌套以创建复杂的布局。在这里,我们创建了一个父容器,其中包含两个框,用于显示问题和答案。
我们还为组件添加了一些基本的样式。组件接受关键字参数,称为 props,这些参数可以修改组件的外观和功能。我们使用 text_align 属性将文本对齐到左侧和右侧。
重用组件
现在我们有了一个显示单个问题和答案的组件,我们可以重用它来显示多个问题和答案。我们将把该组件移动到一个单独的函数 question_answer 中,并从 index 函数中调用它。

作者提供的图片(代码如下)
def qa(question: str, answer: str) -> rx.Component:
return rx.box(
rx.box(question, text_align="right"),
rx.box(answer, text_align="left"),
margin_y="1em",
)
def chat() -> rx.Component:
qa_pairs = [
(
"What is Reflex?",
"A way to build web apps in pure Python!",
),
(
"What can I make with it?",
"Anything from a simple website to a complex web app!",
),
]
return rx.box(
*[
qa(question, answer)
for question, answer in qa_pairs
]
)
def index() -> rx.Component:
return rx.container(chat())
聊天输入
现在我们希望用户能够输入一个问题。为此,我们将使用input组件让用户添加文本,并使用button组件来提交问题。

作者提供的图片(下面的代码)
def action_bar() -> rx.Component:
return rx.hstack(
rx.input(placeholder="Ask a question"),
rx.button("Ask"),
)
def index() -> rx.Component:
return rx.container(
chat(),
action_bar(),
)
样式
让我们给应用添加一些样式。有关样式的更多信息可以在styling docs中找到。为了保持代码的整洁,我们将样式移动到一个单独的文件chatapp/style.py中。
# style.py
# Common styles for questions and answers.
shadow = "rgba(0, 0, 0, 0.15) 0px 2px 8px"
chat_margin = "20%"
message_style = dict(
padding="1em",
border_radius="5px",
margin_y="0.5em",
box_shadow=shadow,
max_width="30em",
display="inline-block",
)
# Set specific styles for questions and answers.
question_style = message_style | dict(
bg="#F5EFFE", margin_left=chat_margin
)
answer_style = message_style | dict(
bg="#DEEAFD", margin_right=chat_margin
)
# Styles for the action bar.
input_style = dict(
border_width="1px", padding="1em", box_shadow=shadow
)
button_style = dict(bg="#CEFFEE", box_shadow=shadow)
我们将导入chatapp.py中的样式并在组件中使用它们。此时,应用应该如下所示:

作者提供的图片
# chatapp.py
import reflex as rx
from chatapp import style
def qa(question: str, answer: str) -> rx.Component:
return rx.box(
rx.box(
rx.text(question, style=style.question_style),
text_align="right",
),
rx.box(
rx.text(answer, style=style.answer_style),
text_align="left",
),
margin_y="1em",
)
def chat() -> rx.Component:
qa_pairs = [
(
"What is Reflex?",
"A way to build web apps in pure Python!",
),
(
"What can I make with it?",
"Anything from a simple website to a complex web app!",
),
]
return rx.box(
*[
qa(question, answer)
for question, answer in qa_pairs
]
)
def action_bar() -> rx.Component:
return rx.hstack(
rx.input(
placeholder="Ask a question",
style=style.input_style,
),
rx.button("Ask", style=style.button_style),
)
def index() -> rx.Component:
return rx.container(
chat(),
action_bar(),
)
app = rx.App()
app.add_page(index)
app.compile()
应用看起来不错,但还不是很有用!现在让我们添加一些功能。
状态
现在让我们通过添加状态来使聊天应用变得互动。状态是我们定义应用中所有可以变化的变量以及所有可以修改这些变量的函数的地方。你可以在state docs中了解更多关于状态的信息。
定义状态
我们将在chatapp目录中创建一个名为state.py的新文件。我们的状态将跟踪当前提出的问题和聊天记录。我们还将定义一个事件处理程序answer,它将处理当前的问题并将答案添加到聊天记录中。
# state.py
import reflex as rx
class State(rx.State):
# The current question being asked.
question: str
# Keep track of the chat history as a list of (question, answer) tuples.
chat_history: list[tuple[str, str]]
def answer(self):
# Our chatbot is not very smart right now...
answer = "I don't know!"
self.chat_history.append((self.question, answer))
将状态绑定到组件
现在我们可以在chatapp.py中导入状态,并在前端组件中引用它。我们将修改chat组件,以使用状态代替当前固定的问题和答案。

作者提供的图片
# chatapp.py
from chatapp.state import State
...
def chat() -> rx.Component:
return rx.box(
rx.foreach(
State.chat_history,
lambda messages: qa(messages[0], messages[1]),
)
)
...
def action_bar() -> rx.Component:
return rx.hstack(
rx.input(
placeholder="Ask a question",
on_change=State.set_question,
style=style.input_style,
),
rx.button(
"Ask",
on_click=State.answer,
style=style.button_style,
),
)
普通的 Python for 循环无法用于遍历状态变量,因为这些值可能会变化且在编译时未知。相反,我们使用foreach组件来遍历聊天记录。
我们还将输入框的on_change事件绑定到set_question事件处理程序,这将更新question状态变量,而用户在输入框中输入时。我们将按钮的on_click事件绑定到answer事件处理程序,这将处理问题并将答案添加到聊天记录中。set_question事件处理程序是一个内置的隐式定义事件处理程序。每个基础变量都有一个。更多信息请查看events docs中的 Setters 部分。
清空输入框
目前,用户点击按钮后输入框不会清空。我们可以通过将输入框的值绑定到question,设置为value=State.question,并在运行answer事件处理程序时将其清空,使用self.question = ''来解决这个问题。
# chatapp.py
def action_bar() -> rx.Component:
return rx.hstack(
rx.input(
value=State.question,
placeholder="Ask a question",
on_change=State.set_question,
style=style.input_style,
),
rx.button(
"Ask",
on_click=State.answer,
style=style.button_style,
),
)
# state.py
def answer(self):
# Our chatbot is not very smart right now...
answer = "I don't know!"
self.chat_history.append((self.question, answer))
self.question = ""
流式文本
通常,状态更新会在事件处理程序返回时发送到前端。然而,我们希望在生成的过程中流式传输来自聊天机器人的文本。我们可以通过从事件处理程序中生成来实现。有关更多信息,请参见事件生成文档。
# state.py
import asyncio
...
async def answer(self):
# Our chatbot is not very smart right now...
answer = "I don't know!"
self.chat_history.append((self.question, ""))
# Clear the question input.
self.question = ""
# Yield here to clear the frontend input before continuing.
yield
for i in range(len(answer)):
# Pause to show the streaming effect.
await asyncio.sleep(0.1)
# Add one letter at a time to the output.
self.chat_history[-1] = (
self.chat_history[-1][0],
answer[: i + 1],
)
yield
使用 API
我们将使用 OpenAI 的 API 为我们的聊天机器人提供一些智能。我们需要修改事件处理程序以向 API 发送请求。
# state.py
import os
import openai
openai.api_key = os.environ["OPENAI_API_KEY"]
...
def answer(self):
# Our chatbot has some brains now!
session = openai.ChatCompletion.create(
model="gpt-3.5-turbo",
messages=[
{"role": "user", "content": self.question}
],
stop=None,
temperature=0.7,
stream=True,
)
# Add to the answer as the chatbot responds.
answer = ""
self.chat_history.append((self.question, answer))
# Clear the question input.
self.question = ""
# Yield here to clear the frontend input before continuing.
yield
for item in session:
if hasattr(item.choices[0].delta, "content"):
answer += item.choices[0].delta.content
self.chat_history[-1] = (
self.chat_history[-1][0],
answer,
)
yield
最后,我们有了我们的 AI 聊天机器人!
结论
按照这个教程,我们已经成功创建了使用 OpenAI API 密钥的聊天应用,完全用 Python 编写。
要现在运行这个应用,我们可以运行简单的命令:
$ reflex run
要部署它,以便我们可以与其他用户分享,我们可以运行以下命令:
$ reflex deploy
我希望这个教程能激发你们构建自己的 LLM 基础应用。我迫不及待想看到你们最终开发出什么,所以请在社交媒体或评论中与我联系。
如果你有问题,请在下方评论或者通过 Twitter 上的@tgotsman12或LinkedIn给我发消息。分享你的应用创作到社交媒体并标记我,我很乐意提供反馈或帮助转发!
使用 Numpy 从零开始构建卷积神经网络
原文:
towardsdatascience.com/build-a-convolutional-neural-network-from-scratch-using-numpy-139cbbf3c45e
通过从零开始构建 CNN 来掌握计算机视觉
·发表于Towards Data Science·8 分钟阅读·2023 年 11 月 23 日
--

这些彩色窗户让我想起了 CNN 的层及其过滤器。图片来源:unsplash.com。
由于计算机视觉应用现在无处不在,每个数据科学从业者都必须了解其工作原理和熟悉它们。
在这篇文章中,我在不依赖流行的现代深度学习库如 Tensorflow、Pytorch 和 Keras 的情况下构建了深度神经网络。我随后用它对手写数字进行分类。尽管取得的结果未达到最先进的水平,但仍然令人满意。现在,我想在使用 Python 库Numpy的情况下迈出进一步一步,开发一个卷积神经网络(CNN)。
Python 深度学习库,如上所述,是极其强大的工具。然而,作为一个缺点,它们使数据科学从业者无法理解神经网络的低级工作原理。这一点在 CNN 中尤其明显,因为它们的过程比经典的全连接网络更不直观。解决这一问题的唯一方法就是亲自动手实现 CNN:这就是这项任务的动机。
这篇文章旨在作为一个实用的、动手指南,而不是一个全面的 CNN 工作原理指南。因此,理论部分简明扼要,主要服务于实际部分的理解。为此,你会在本文末尾找到一份详尽的资源列表。我热情邀请你去查看它们!
卷积神经网络
卷积神经网络使用一种特定的架构和操作,使其非常适合与图像相关的任务,例如图像分类、目标定位、图像分割等。它们的设计大致反映了人类视觉皮层,其中每个生物神经元仅响应视觉场的一个小部分。此外,更高级的神经元对低级神经元的输出做出反应。
虽然经典的全连接网络可以处理与图像相关的任务,但当应用于中等或大型图像时,由于所需的参数数量众多,它们的效果会显著下降。例如,一个 200x200 像素的图像包含 40,000 个像素,如果网络的第一层有 1,000 个单元,仅这一层就需要 4000 万个权重。这个挑战通过 CNN 得到了极大的缓解,因为它们实现了部分连接层和权重共享。
卷积神经网络的主要组件是:
-
卷积层
-
池化层
卷积层
卷积层由一组滤波器组成,也称为内核。当应用于层的输入时,这些滤波器以特定方式修改原始图像。
一个滤波器可以描述为一个矩阵,其元素值定义了对原始图像应用的修改类型。例如,如下的 3x3 内核突出了图像的垂直边缘:

这个内核则强调水平边缘:


来源: 维基百科.
需要注意的是,这些内核的元素值不是手动选择的,而是网络在训练过程中学习的参数。
卷积的主要功能是隔离并突出图像中存在的不同特征。稍后,密集层将使用这些特征。
池化层
池化层比卷积层更简单。它们的目的是最小化网络的计算负担和内存使用。它们通过缩小输入图像的尺寸来实现这一目标。降低维度会减少 CNN 需要学习的参数数量。
池化层也使用一个通常为 2x2 维度的内核,将输入图像的一部分汇聚成一个值。例如,一个 2x2 的最大池化内核从输入图像中提取 4 个像素,并只输出值最大的像素。
Python 实现
你可以在我的 GitHub 仓库中找到本节中展示的所有代码。
[## GitHub - andreoniriccardo/CNN-from-scratch: 从零开始构建卷积神经网络
从零开始构建卷积神经网络。通过创建一个…来贡献到 andreoniriccardo/CNN-from-scratch 的开发。
该实现的概念是创建Python 类,代表卷积层和最大池化层。此外,由于此 CNN 将应用于著名的开源 MNIST 数据集,我还创建了一个用于 Softmax 层的特定类。
在每个类中,我定义了执行前向传播和反向传播步骤的方法。
作为最后一步,将这些层附加到一个列表中以构建最终的卷积神经网络。
卷积层实现
定义卷积层的代码如下:
class ConvolutionLayer:
def __init__(self, kernel_num, kernel_size):
self.kernel_num = kernel_num
self.kernel_size = kernel_size
self.kernels = np.random.randn(kernel_num, kernel_size, kernel_size) / (kernel_size**2)
def patches_generator(self, image):
image_h, image_w = image.shape
self.image = image
for h in range(image_h-self.kernel_size+1):
for w in range(image_w-self.kernel_size+1):
patch = image[h:(h+self.kernel_size), w:(w+self.kernel_size)]
yield patch, h, w
def forward_prop(self, image):
image_h, image_w = image.shape
convolution_output = np.zeros((image_h-self.kernel_size+1, image_w-self.kernel_size+1, self.kernel_num))
for patch, h, w in self.patches_generator(image):
convolution_output[h,w] = np.sum(patch*self.kernels, axis=(1,2))
return convolution_output
def back_prop(self, dE_dY, alpha):
dE_dk = np.zeros(self.kernels.shape)
for patch, h, w in self.patches_generator(self.image):
for f in range(self.kernel_num):
dE_dk[f] += patch * dE_dY[h, w, f]
self.kernels -= alpha*dE_dk
return dE_dk
**ConvolutionLayer**类的构造函数接收卷积层的内核数量及其大小作为输入。我假设只使用大小为**kernel_size**乘**kernel_size**的方形内核**。
然后,我生成形状为**(kernel_num, kernel_size, kernel_size)**的随机滤波器,并通过将每个元素除以平方的内核大小进行归一化。
**patches_generator()**方法是一个生成器。它生成要进行每个卷积步骤的图像部分。
**forward_prop()**方法对上述方法生成的每个补丁执行卷积操作。
最后,**back_prop()**方法负责计算损失函数相对于每层权重的梯度。它还相应地更新权重值。请注意,这里提到的损失函数不是网络的全局损失,而是由最大池化层传递给前一个卷积层的损失函数。
为了演示此类的实际效果,我创建了一个 **ConvolutionLayer** 实例,具有 32 个滤波器,每个滤波器的大小为 3x3。然后,我在一张图像上应用前向传播方法,得到由 32 张稍小的图像组成的输出。
初始输入图像的大小为 28x28 像素,如下图所示:

图片来源于作者。
一旦我应用了卷积层的**forward_prop()**方法,我得到 32 张 26x26 像素的图像。其中之一如下:

图片由作者提供。
正如你所见,图像的尺寸被缩小了,手写数字的清晰度更差了。需要注意的是,这一操作是由一个包含随机值的滤波器执行的,因此,它并不准确代表经过训练的 CNN 实际执行的步骤。尽管如此,你可以理解这些卷积如何产生较小的图像,在这些图像中对象的特征被分离出来。
Max Pooling 层实现
我使用 Numpy 定义了 Max Pooling 层类,如下所示:
class MaxPoolingLayer:
def __init__(self, kernel_size):
self.kernel_size = kernel_size
def patches_generator(self, image):
output_h = image.shape[0] // self.kernel_size
output_w = image.shape[1] // self.kernel_size
self.image = image
for h in range(output_h):
for w in range(output_w):
patch = image[(h*self.kernel_size):(h*self.kernel_size+self.kernel_size), (w*self.kernel_size):(w*self.kernel_size+self.kernel_size)]
yield patch, h, w
def forward_prop(self, image):
image_h, image_w, num_kernels = image.shape
max_pooling_output = np.zeros((image_h//self.kernel_size, image_w//self.kernel_size, num_kernels))
for patch, h, w in self.patches_generator(image):
max_pooling_output[h,w] = np.amax(patch, axis=(0,1))
return max_pooling_output
def back_prop(self, dE_dY):
dE_dk = np.zeros(self.image.shape)
for patch,h,w in self.patches_generator(self.image):
image_h, image_w, num_kernels = patch.shape
max_val = np.amax(patch, axis=(0,1))
for idx_h in range(image_h):
for idx_w in range(image_w):
for idx_k in range(num_kernels):
if patch[idx_h,idx_w,idx_k] == max_val[idx_k]:
dE_dk[h*self.kernel_size+idx_h, w*self.kernel_size+idx_w, idx_k] = dE_dY[h,w,idx_k]
return dE_dk
构造方法仅分配内核大小值。接下来的方法与卷积层中定义的方法类似,主要的区别在于**back_prop()**方法不会更新任何权重值。实际上,池化层不依赖于权重来执行聚合操作。
Softmax 层实现
最后,我定义了Softmax 层。它的目的是展平从最终 Max Pooling 层获得的输出体积。Softmax 层输出 10 个值,这些值可以被解释为图像对应于 0 到 9 数字的概率。
实现具有与上述相同的结构:
class SoftmaxLayer:
def __init__(self, input_units, output_units):
self.weight = np.random.randn(input_units, output_units)/input_units
self.bias = np.zeros(output_units)
def forward_prop(self, image):
self.original_shape = image.shape
image_flattened = image.flatten()
self.flattened_input = image_flattened
first_output = np.dot(image_flattened, self.weight) + self.bias
self.output = first_output
softmax_output = np.exp(first_output) / np.sum(np.exp(first_output), axis=0)
return softmax_output
def back_prop(self, dE_dY, alpha):
for i, gradient in enumerate(dE_dY):
if gradient == 0:
continue
transformation_eq = np.exp(self.output)
S_total = np.sum(transformation_eq)
dY_dZ = -transformation_eq[i]*transformation_eq / (S_total**2)
dY_dZ[i] = transformation_eq[i]*(S_total - transformation_eq[i]) / (S_total**2)
dZ_dw = self.flattened_input
dZ_db = 1
dZ_dX = self.weight
dE_dZ = gradient * dY_dZ
dE_dw = dZ_dw[np.newaxis].T @ dE_dZ[np.newaxis]
dE_db = dE_dZ * dZ_db
dE_dX = dZ_dX @ dE_dZ
self.weight -= alpha*dE_dw
self.bias -= alpha*dE_db
return dE_dX.reshape(self.original_shape)

图片由作者提供。
结论
在这篇文章中,我们看到对基本 CNN 架构元素如卷积层和池化层的理论介绍。我相信,逐步的 Python 实现将为你提供实际理解这些理论概念如何转化为代码的途径。
我邀请你克隆包含代码的GitHub 仓库并尝试**main.py**脚本。当然,这个网络并没有达到最先进的性能,因为它不是为了这个目标而构建的,但在经过几个训练周期后,仍然达到了96%的准确率。
最后,为了扩展你对 CNN 和计算机视觉的知识,我建议你查看下面列出的一些资源。
如果你喜欢这个故事,请考虑关注我,以便了解我即将推出的项目和文章!
参考文献
-
“ImageNet 分类与深度卷积神经网络”由 Alex Krizhevsky, Ilya Sutskever, 和 Geoffrey Hinton
-
“《用于大规模图像识别的深度卷积网络》” 作者:Karen Simonyan 和 Andrew Zisserman (VGGNet)
-
“《Python 中的卷积神经网络:掌握数据科学和机器学习,现代深度学习方法,使用 Python、Theano 和 TensorFlow》” 作者:Jason Brownlee
如何为任何团队规模构建数据科学战略
创建一个快速变动且对变化具有弹性的文化和实践
·发表于 走向数据科学 ·20 分钟阅读·2023 年 9 月 11 日
--

照片由 Maarten van den Heuvel 拍摄,来源于 Unsplash
创建一个快速变动且对变化具有弹性的文化和实践
如果你是一个数据科学领导者,被要求在“建立我们的数据科学战略”时拥有很大的自由和很少的方向,这篇文章会对你有所帮助。我们将讨论:
-
我们所说的战略是什么:仅仅是一个计划?一个路线图?还是更多或更少的东西?在本节中,我们将具体化并采用一个工作定义,以了解我们在构建战略时实际构建了什么。
-
这个概念如何在实际的组织背景下应用于数据科学团队?在这里,我们将探讨我们对战略的概念如何适用于数据科学,并具体说明我们的战略应用于哪些方面。
-
如何实际制定该战略。
在整个过程中,我们将大量借鉴研发战略的方法,因为它与数据科学面临的关键挑战类似:创新的使命,以及追求发现所带来的不确定性。当我们结束时,你将获得一个明确的战略定义,以及一个适用于任何规模组织的有用的制定过程。
什么是战略?
如果你像我一样,没有高级 MBA 学位,也从未参加过商业战略研讨会,你可能会对有人要求你制定“数据科学战略”时究竟想要什么感到困惑。你可能会发现最初的搜索并没有多大帮助。像三 C 模型(顾客、竞争者、公司)这样的经典强大框架,在决定公司应该竞争的领域时非常合理。将其应用于一个职能或团队时,你会发现自己感觉在拉伸这些概念的承受范围。
如果你像我一样,阅读像战略领主和麦肯锡方法这样的书籍会让你陷入一个相当深入的阅读漩涡中。(附属链接。)前者是一本令人愉快的商业历史著作,后者是从成功的咨询公司经验中提炼出的有用技巧合集。两者都没有提供快速的答案。阅读《战略领主》的一个非常有益的副作用是了解到数据科学家们并不孤单:“[我]很容易把战略与战略规划混为一谈,但这也是危险的。[…] 即使今天,拥有计划的公司仍然远多于拥有战略的公司。刮掉大多数计划,你会发现某种版本的‘我们将继续做我们一直在做的事,但明年我们会做得更多和/或更好’。这种定义混淆在我的经验中也有所体现,几次对战略的请求实际上简化为‘你接下来几个月的计划是什么?’”
我们将在本文其余部分采用的一个非常有用的战略定义,来源于 Gary Pisano 的这篇关于研发战略的工作论文:“战略不过是一种对行为模式的承诺,旨在帮助赢得竞争。”这个定义的美妙之处在于它可以适用于组织的任何层级和目的。所有类型和规模的团队都参与组织的竞争努力,所有团队都可以定义并声明他们用来集中这些努力的行为模式。
战略不过是一种对行为模式的承诺,旨在帮助赢得竞争。
—Gary Pisano
Pisano 提出了一个好的战略的三个要求:一致性、一致性和对齐。战略应该帮助我们做出一致的决策,这些决策累积起来有助于实现预期目标;应该帮助组织的各个角落将其分散的战术决策协调一致;并且应该使地方行动与更大的集体努力保持一致。
最终,它们都建立在核心假设上,即关于在竞争中提供优势的赌注。皮萨诺的有用例子是苹果,其战略“开发易于使用、外观美观的产品,并与消费者数字世界中的更广泛设备系统无缝集成”建立在一个核心假设上:“客户将愿意为具有这些属性的产品支付显著更高的价格。”
本质上,根据这个定义,所有策略都是包装决策逻辑的赌注:它们为所有各方提供了确定哪些行动有助于集体努力的方法。
我们将采用这个策略定义,并努力定义我们自己的核心战略假设,关于数据科学如何为我们的组织增值,以及我们将在追求该价值的过程中坚持的模式。此外,我们将假设我们的母公司已经制定了自己的战略,这一输入在我们应用第三个对齐测试时将是至关重要的。在确定了我们最终战略的形式之后,我们将把注意力转向限定其范围。
我们所说的数据科学是什么意思,这个战略概念如何适用?
为了提醒我的朋友们我有多么有趣,我给几个人发了相同的短信,“你听到‘数据科学策略’时会想到什么?”答案从对数据基础设施和 MLOps 的深思熟虑,到对问题模糊性的健康反应(我觉得被看到了),再到丰富多彩的“胡说八道”和“我的理想工作”。
尽管样本较小,但来自这一群体的多样化回应——包括初创公司和大型公司的资深产品经理、数据科学负责人和顾问——表明了这个术语的定义可能有多么模糊。更糟糕的是,数据科学家还面临第二层次的困惑:所谓的“数据科学”在实践中往往取决于公司招聘所需的技能,并用流行的标题装点。
为了在我们的分析中固定这些自由度,我们将首先采用一个共同的数据科学定义来进行本文的其余部分:致力于通过建模组织的可用数据来创造价值和竞争优势的职能。这可以采取几种典型形式:
-
构建优化客户决策的机器学习模型以用于生产
-
建立帮助各级员工完成工作的模型,可能应用于客户互动的人工环节应用。
-
建立可解释的模型以辅助商业决策
注意,我们排除了 BI 和分析,仅仅是为了聚焦,而不是因为它们不如建模工作有价值。你的分析部门和数据科学部门应当顺利合作。(我曾在这里写过这方面的内容。)
一些人,比如我的朋友和谷歌产品经理卡罗尔·斯科达斯·沃尔波特,会建议数据科学战略包括“如何使数据和基础设施处于足够好的状态,以进行分析或机器学习。我会说这是如何使团队完成所有工作的。”我们将有意排除这些更广泛的数据战略项目(抱歉,卡罗尔)。不过,我们将讨论如何应对数据和基础设施限制,以及如何发展数据科学战略以积极指导更广泛的数据战略。
现在我们有了界限:我们正在构建一套核心战略假设,关于机器学习和/或人工智能如何为拥有自己定义战略或目标的组织增加最大价值,以及团队在追求这一价值时将承诺的一系列模式。我们该如何开始?
建立我们的战略核心假设:从一个赢得 AI 的心态开始
经验丰富的机器学习产品经理、工程师和数据科学家常常会提到,机器学习产品与传统软件不同。一个组织必须考虑模型错误、数据漂移、模型监控和重新调整的风险——这也是现代 MLOps 的出现原因。而且,很容易在工程中犯错,使 ML 应用陷入技术债务的沼泽中。(请参阅“机器学习:技术债务的高利贷”以获取有关此主题的精彩阅读。)那么,面对所有这些成本,我们为什么要这么做?
从根本上说,我们考虑 AI 解决方案是因为复杂的模型已经证明能够检测到有价值的模式。这些模式可以是从暗示新分段的客户偏好的聚类到神经网络发现以优化预测的潜在表示。任何给定的机器学习构建都依赖于一个假设或预期,即模型可以检测到可以改进过程、发现可操作的发现或改进有价值的预测的模式。
在定义任何规模的数据科学团队的核心战略假设时,我们可以从这个麦肯锡的示例描述开始,描述 AI 驱动的公司如何以不同的方式思考。来自“赢得 AI 是一种心态”:
如果我们选择了正确的用例,并以正确的方式进行,我们将不断了解我们的客户及其需求,并不断改进我们服务他们的方式。
这是构建数据科学战略时非常有用的视角:它让我们专注于最大化学习,我们需要做的只是确定我们组织对“正确”的定义。但对我们来说,“正确”的用例是什么?
在这里,皮萨诺再次提供了帮助,定义了 R&D 战略的四个要素,这些要素很好地适用于数据科学:
-
架构:我们数据科学职能的组织结构(集中式、分布式)和地理结构。
-
过程:管理我们工作的正式和非正式方式。
-
人员:从我们希望吸引的技能组合到我们的人才价值主张的一切。
-
组合:我们如何在项目类型之间分配资源,以及“排序、优先级和选择项目的标准”。
我们将从最后一个概念开始,重点定义我们组织的理想项目组合,即我们可以说服自己能带来最大价值的组合。鉴于组织之间的巨大差异,我们将从每个组织面临的一个挑战开始:风险。
定义你的目标组合:根据你的战略确定风险水平和管理。
建模工作具有不确定的结果。“机器学习可以做得更好”是我们经常根据历史和直觉提出的论点,且这通常被证明是正确的。但我们从一开始就不知道它的效果如何,直到我们通过实际构建证明机器学习能解决问题。了解任何给定用例的答案可能需要不同程度的努力,从而产生不同程度的成本。这种答案的不确定性也可能不同,取决于我们的模型被应用的广泛程度以及我们对数据的理解程度。
一位朋友和医疗分析产品负责人,约翰·梅纳德将风险定义为数据科学战略中的一个明确部分:“你如何维护一个小规模和大规模的投注管道,同时保持健康的期望?当数据无法支持项目时,你的策略是什么?如果项目未能满足要求,你会如何调整交付内容?”
对于组织来说,明确和具体地了解他们能够承担的资源水平及其时间长度是明智的。以下是对任何个人建模工作提出的一些有用问题:
-
成功的估计可能性:这个模型用例成功的概率是多少?
-
预期回报范围:如果成功,这个项目是否能在一个可以大规模带来巨大回报的过程上带来微小的改进?一个突破是否能让你与竞争对手区分开来?
-
发现失败的预期时间:需要多长时间才能了解一个项目的假设价值主张是否会实现?在了解到这个项目不会成功之前,你能花费的最少资源是多少?
希望这些原则是简单明了的,并且都是共识中的好事。理想的项目可能会成功,带来巨大的投资回报,如果失败,则应尽早失败。这种完美的三位一体往往难以实现。关键在于做出适合你组织的权衡。
一个早期阶段的初创公司,专注于利用人工智能颠覆特定领域,可能会有投资者、领导层和员工接受公司作为对特定方法的单一大型投资。或者,它可能更倾向于那些能够快速进入生产并允许快速调整的小型项目。相反,如果我们在一家大型、成熟的公司和监管严格的行业中,并且利益相关者对机器学习持怀疑态度,我们可能会选择将投资组合偏向于低工作量的项目,这些项目提供渐进的价值并快速失败。这可以帮助建立初步信任,使利益相关者适应数据科学项目固有的不确定性,并使团队围绕更雄心勃勃的项目达成一致。成功的早期小型项目还可以加强对同一问题领域的更大项目的支持。
以下是如何定义目标投资组合的一些示例,包括项目范围、持续时间和预期回报:
-
“由于我们在集体数据科学旅程中仍处于早期阶段,我们专注于小型、低工作量和快速失败的用例,这样可以在不冒大量人员时间风险的情况下发现机会。”
-
“我们已经确定了一个包含三个大型机器学习项目的投资组合,每一个项目都有可能释放巨大的价值。”
-
“我们的目标是平衡小型、中型和大型项目,并与相应的回报水平相匹配。这使我们能够在追求具有颠覆性潜力的同时,频繁获得胜利。”
作为应用于我们完整投资组合的最终原则,目标是一个具有非相关成功的项目集合。也就是说,我们希望看到我们的投资组合,并感知项目将独立成功或失败。如果多个项目依赖于共同的假设,如果我们感到它们如此紧密相关以至于它们会一起成功或失败,那么我们应该重新考虑选择。
当我们完成以下任务时,我们就完成了这个阶段:
-
调查了我们的数据科学和机器学习机会
-
按投资、回报和成功可能性进行了绘制
-
选择了一个与我们的目标和风险承受能力一致的粗略优先级列表
现在我们已经确定了目标投资组合,我们将转向确保我们的流程使我们能够快速识别、范围定义和交付有价值的项目。
将你的投资组合集中在你的团队独特的解决方案上
建设还是购买的问题是一个长期存在的问题,并且通常涉及复杂的组织动态。寻找 AI 解决方案的供应商和初创公司不乏其人。许多是“江湖术士”;许多是有效的。许多内部技术和数据科学团队将前者视为笑话,将后者视为竞争对手,并且将区分两者所花费的时间视为巨大的时间浪费。这是有道理的,因为检查供应商的时间并不能提升建模者的技能,如果组织不奖励他们的努力,那么这就是数据科学家付出的成本,而没有职业上的回报。这种人际关系复杂性加剧了本已复杂的商业案例:所有典型的软件解决方案关注点仍然存在。你仍然需要担心供应商锁定和云集成等问题。然而,我们都应该愿意购买那些提供更高投资回报率的供应商产品,如果你考虑到内部团队相对于现成解决方案的独特优势,你可以排除干扰。
特别是,你的内部团队通常可以访问你组织的大部分(可能是全部)专有数据。这意味着内部团队可能会比单一用途的供应商解决方案更深入地理解这些数据,并更容易将其与其他来源进行丰富。只要有足够的时间和计算资源,一个有能力的内部团队可能会比单一用途的供应商解决方案更胜一筹。(这里面有个 PAC 理论的笑话。)但这值得吗?
标准的投资回报率和替代方案分析在这里是关键,重点是你内部市场的时间。假设我们正在优化一个电子商务网站上的广告投放。我们已经将供应商名单缩小到一个领先者,该供应商使用的是一种多臂赌博方法,这是在撰写本文时领先的市场优化供应商中常见的方法。我们估计供应商集成的时间为一个月。或者,我们可以建立自己的 MAB,并估计需要六个月。我们是否期望自己构建的 MAB 能够超越供应商的 MAB,并且值得为了这个延迟而付出?
这要看情况。使用汤普森采样进行多臂赌博问题(MAB)可以为预期的遗憾提供对数界限,这是一种行话,意味着它在探索选项时不会留下一大堆未利用的价值。无论是由你的内部团队还是供应商实施,这一说法仍然可以证明是正确的。相反,你的内部团队更接近你的数据,将这种用例带到内部相当于一个赌注,即你会在数据中找到足够丰富的信号来超越供应商产品。也许你的团队可以注入现成解决方案所没有的领域知识,从而提供有价值的优势。最后,考虑你的内部团队的机会成本:是否有其他高价值的项目他们可以从事?如果有,一个选项是测试供应商,处理其他项目,并在获得可衡量的供应商结果后重新评估。
当我们完成以下任务时,这一阶段就结束了:
-
回顾之前步骤中的机会,并对每个机会回答,“我们能买到这个吗?”
-
对于每个可购买的解决方案,回答我们是否有独特的已知或假设的内部优势。
-
对于每个需要做出真实权衡的领域,进行权衡分析。
确定了我们内部团队的战略竞争优势后,我们现在将考虑我们的内部流程、工具和数据能力。
在你的知识工厂工具和数据供应链周围建立流程。
我与许多经验丰富的数据科学家讨论过时间投入的问题,每个人都提到发现、处理、清理和移动(到适当的计算环境)数据占据了他们大部分工作时间。正如另一组麦肯锡作者在 AutoML 和 AI 人才战略上写道:“许多组织发现,数据科学家花费 60%到 80%的时间来准备建模数据。一旦初步模型构建完成,根据一些分析,只有 4%的时间用于测试和调优代码。”这并不是大多数人进入这一领域的原因。在我们大多数人的心中,这是一种为了构建有影响力的模型的乐趣而付出的代价。因此,我们常常谈论数据科学家成功所需的“基础”。根据我的经验,这种框架很快会阻碍我们,我将挑战我们将自己视为一个模型工厂,受到工具和复杂且常常有问题的数据供应链的限制。
坦白说:当讨论平台时,我从未相信这些“基础”论点。
“数据和机器学习平台是成功的机器学习所依赖的基础,”这是无数幻灯片和白皮书中的明确声明。 “如果没有强大的基础,”某些顾问父爱般地总结道,“一切都会崩溃。”
然而,这里有个问题:很少有事情会因为没有机器学习而“崩溃”。如果你的房子建在不好的基础上,车库可能会塌陷在自己身上,甚至你也会受害。如果你在没有完善的数据和机器学习平台的情况下开始机器学习项目,那么你的模型构建将会……需要更长时间。而且没有那种新奇的机器学习模型,你的业务很可能会继续以原有方式运行,尽管缺少了机器学习原本旨在提供的某种竞争优势。但在平庸中持续并非末日。
这就是这个陈词滥调让我感到困惑的地方。它试图吓唬高管们投入平台建设——值得强调的是有价值的建设——好像没有它们世界就会末日,但其实不会。我们大声喊着天空要塌下来,然后当一个利益相关者遇到他们习惯的老雨时,我们失去了信誉。
尽管如此,我敢打赌,拥有强大机器学习能力的公司将超越那些没有这种能力的竞争者——我明白作为建模负责人我的职业生涯正是这样一种赌注——而现代数据和 MLOps 能力可以大大缩短 AI 能力的市场时间。请参见麦肯锡论文中的摘录 “像科技本地人一样扩展 AI: CEO 的角色”,强调由我加的:
我们经常听到高管们说,将 AI 解决方案 从构想到实施需要九个月到一年以上,使得很难跟上市场动态的变化。即便经过多年的投资,领导者们经常告诉我们他们的组织并没有加快速度。相比之下, 采用 MLOps 的公司可以在仅仅两到 12 周内将构思转变为实际解决方案 而不会增加人员或技术债务,减少了实现价值的时间,并释放团队以更快地扩展 AI。
你的数据科学战略需要考虑你的组织和工具约束,并采用在这些约束下能产生可操作模型或知识单元的模式。也就是说,建模项目应该总是包含:
-
清晰地了解最小可行建模数据。你的数据科学团队应该知道源数据的位置,并且对数据需要如何转化有一个大致的了解。
-
实现价值的直接和现实路径。你将如何让一个具有足够性能的模型投入使用,或者以其他方式应用模型结果?
早期阶段的公司或团队如果在架构和工具方面拥有完全的自由,将有利于采用现代 MLOps 实践,这将使得快速原型设计、部署和监控模型以评估其在实际世界中影响变得更加容易。与传统遗留技术并行工作或在其中工作的团队可能会发现这些技术并未考虑 ML 集成,并且部署是一个庞大的、沉重的过程。受严格监管行业中的公司将发现许多应用程序需要高度的可解释性和风险控制。
这些挑战没有不可克服的。我们只需在时间表的影响上保持原则性和聪明,并将其纳入我们的决策中。
当我们完成这个阶段时,我们会有:
-
调查我们计划中的用例,以确定每个用例获取数据的路径以便开始。
-
确定每个用例的实现价值路径,如果它能够成功的话。
-
将这一点纳入我们的预期投资中,并从第一步开始进行调整。
-
根据我们发现的任何变化调整了我们的优先级。
在完善了我们关于数据科学部署的想法后,我们将考虑工作模型以确保一致性。
架构与组织:构建一个能够持续成功的组织结构。
Pisano 将架构定义为“围绕研发在组织和地理上的结构所做的一系列决策。” 设计这点包括对如何将我们的数据科学家与业务部门整合做出深思熟虑的决定。他们是完全集中并有正式的接收流程?向不同的业务单位汇报?集中并嵌入其中?汇报结构和决策权可能不在你的掌控之下,特别是当你被要求为有明确汇报线的部门建立战略时。但如果这些点正在讨论中,这里有一些最大化数据科学成果价值的考虑因素。
你的数据科学家会得到良好的支持并被适当衡量吗? 考虑一下初级数据科学人才的来源。数据科学家来自各种定量背景,通常具有理论和实践技能的混合。一个典型的硕士毕业生在这些形成阶段中建立了技能和理解,并向领域专家展示了这些理解。这通常不包括大量的与非专家沟通技术发现的培训。
与他们在商业环境中的经历相比,他们可能对领域的了解较少,并且是少数拥有方法知识的人之一。他们将被要求应用少数人理解的技术。他们的项目必然比标准软件构建包含更多的不确定性。他们的成功将依赖于更多因素,许多因素超出数据科学家的控制范围,他们在阐述要求以最大化成功机会方面经验有限。将所有这些因素综合考虑,我们开始看到一种被投入深水区的情况。
这可能导致其他职能领导在首次领导数据科学团队时面临挑战。这一教训来自麦肯锡的“为现代时代建立研发战略”,也适用于我们的领域:
组织倾向于青睐那些有近期回报的“安全”项目——比如那些源于客户需求的项目——这些项目在许多情况下只是维持现有市场份额。例如,一家消费品公司将研发预算划分给其业务单位,业务单位的领导者则用这些资金来达成短期目标,而不是公司的长期差异化和增长目标。
在我们的领域,这通常表现为初级数据科学家被他们的非技术主管要求编写能够回答当天问题的任何 SQL 查询。这通常是有帮助的,但通常不是企业通过招聘精明建模师来驱动的价值。
当你有曾经管理过数据科学(DS)或机器学习(ML)项目的领导时,这个问题会更容易解决。无论职能如何,成功的关键在于拥有能够倾听问题、规划分析和建模方法解决问题,并管理风险和不确定性的人。许多早期职业的数据科学家在这种情况下表现出色。根据我的经验,他们是沟通能力和处理模糊性的天赋者。我有幸不小心聘用了一些这样的人——嗨,志宇!依赖你的能力来筛选这些人才,并为之竞争,可能会带来风险。
这一切似乎都支持将数据科学职能集中化。这是一种方法,也引出了我们下一个重要的问题。
你的数据科学家是否足够接近业务,以关注正确的问题? 与直接向业务团队汇报的超本地团队相比,中央数据科学职能组可能会较少接触到你希望解决的业务问题。大型、单一的职能团队通过正式的流程,可能很难获得所需的业务输入,主要是因为许多利益相关者不知道自己要提出什么问题。如果你听过数据科学团队产生“没有人要求的科学项目”的恐怖故事,这通常是一个根本原因。而且,再次提醒,不要刻板印象:这很少是因为数据科学团队有过于学术的思维方式,更常见的是因为两个不同职能不知如何用共同语言交流。
这留给我们什么选项?这也是我经验中嵌入式模型有效的一个原因。在这种模型中,你的数据科学团队可以访问你们经常讨论业务问题的所有论坛。他们负责利用这个机会理解业务团队希望解决的问题,并提出可以增加价值的方法。他们向数据科学领导汇报,后者确保他们的工作方法论是正确的,支持他们获得项目成功所需的资源,并指导和辅导他们的成长。
有时数据科学项目失败是因为方法论不佳;它们常常失败是因为预测特征不够有用。知道这两者之间的区别对于非定量职能的人来说可能非常困难。
当我们完成这一步时,我们需要:
-
清晰地定义数据科学家或团队的工作范围
-
定义的参与模式
正如所有实际决策中一样,到处都有权衡,没有万能的解决方案。完全自主的本地团队将最大限度地关注不同的本地结果。集中式职能将最小化重复,但增加了偏离实际、有影响的结果的风险。
退后一步,进行沟通和整体迭代
让我们回顾一下我们迄今为止取得的成就:
-
定义了一个战略假设,即我们将如何通过数据科学和机器学习增加价值的大赌注。
-
确定一个目标投资组合,该投资组合与我们组织的风险承受能力相一致,考虑到你的流程和技术限制,并将我们的团队集中于那些无法通过购买解决的问题上。
-
根据数据访问和它们如何创造价值,筛选我们的使用案例。
-
可能,开发了支持数据科学家的报告结构项目采购方法,并将他们的才能集中于他们独特的优势。
更直白地说,我们已经列出了找到正确使用案例的标准,并筛选了我们的使用案例机会,以找到第一个正确的集合。
接下来的任务是:
-
退后一步,整体查看所有内容。作为一个整体来看,这是否合理?
-
传达这一策略,以及从中衍生出的初步计划。
-
传达潜在利益相关者如何参与你的职能团队。
-
迭代:每当导致策略的假设或情况发生变化时,重新审视你的策略,并承诺定期检查情况的变化。

总结来说,这个过程需要相当大的努力。但是,它带来了巨大的回报。这一策略将明确表达你想要承担的风险、如何管理这些风险,以及它们如何支持你的目标结果(如果成功的话)。目的的明确对齐,以及保持活动与这一目的的一致性,对于一个职能团队来说是非常赋能的。实现这一点,结果将随之而来。
参考文献
-
Brenna 等人,“为现代时代构建研发战略”
-
Corbo 等人,“像技术原生公司一样扩展 AI:CEO 的角色”
-
Kiechel, Walter. 战略之王:新企业世界的秘密知识史(附属链接。)
-
Meakin 等人,“用 AI 获胜是一种心态”
-
Pisano, Gary P. “制定研发战略”
-
Rasiel, Ethan. 麦肯锡方法(附属链接。)
-
Scully 等人,“机器学习:技术债务的高利息信用卡”
在你的 WhatsApp 聊天中构建一个语言模型
通过应用程序了解 GPT 架构的视觉指南
·
关注 发表在 Towards Data Science ·16 分钟阅读·2023 年 11 月 21 日
--
图片由 Volodymyr Hryshchenko 提供,来自 Unsplash
聊天机器人无疑改变了我们与数字平台的互动。尽管基础语言模型在处理复杂任务方面取得了令人印象深刻的进展,但用户体验仍然常常显得不够个人化和疏离。
为了使对话更加贴近实际,我设想了一个可以模拟我随意写作风格的聊天机器人,就像在 WhatsApp 上给朋友发短信一样。
在这篇文章中,我将带你了解我构建一个(小型)语言模型的过程,该模型生成合成对话,使用我的 WhatsApp 聊天记录作为输入数据。过程中,我尝试以可视化且希望易于理解的方式解开 GPT 架构的内部工作机制,并附有实际的 Python 实现。你可以在我的 GitHub 上找到完整项目。
注意: 模型类 本身大块取自 Andrej Karpathy 的视频系列,并根据我的需要进行了调整。我强烈推荐他的教程。
从头开始训练一个语言模型,完全基于你的 WhatsApp 群聊。
github.com](https://github.com/bernhard-pfann/lad-gpt?source=post_page-----31264a9ced90--------------------------------)
目录
-
选定的方法
-
数据来源
-
分词
-
索引
-
模型架构
-
模型训练
-
聊天模式
1. 选定的方法
在将语言模型定制为特定语料库时,可以采取几种方法:
-
模型构建: 这涉及从头开始构建和训练模型,在模型架构和训练数据选择方面提供了最大的灵活性。
-
微调: 这种方法利用现有的预训练模型,调整其权重以更紧密地与手头的特定数据对齐。
-
提示工程: 这也利用了现有的预训练模型,但在这里,独特的语料库直接融入提示中,而不改变模型的权重。
由于我对这个项目的动机主要是自我教育,并且对现代语言模型的架构非常感兴趣,我选择了第一种方法。然而,这种选择带来了明显的限制。鉴于我的数据量和可用计算资源,我并没有期望与任何最先进的预训练模型相当的结果。
尽管如此,我仍希望我的模型能发现一些有趣的语言模式,最终它确实做到了。探索第二种选项(微调)可能会成为未来文章的重点。
2. 数据来源
WhatsApp,作为我的主要沟通渠道,是捕捉我的对话风格的理想来源。导出超过六年的群聊记录,总计超过 150 万字是非常简单的。
数据使用正则表达式模式解析成包含日期、联系人姓名和聊天消息的元组列表。
pattern = r'\[(.*?)\] (.*?): (.*)'
matches = re.findall(pattern, text)
text = [(x1, x2.lower()) for x0, x1, x2 in matches]
[
(2018-03-12 16:03:59, "Alice", "Hi, how are you guys?"),
(2018-03-12 16:05:36, "Tom", "I am good thanks!"),
...
]
现在,每个元素都被单独处理。
-
发送日期: 除了将其转换为日期时间对象外,我没有利用这些信息。然而,可以查看时间差异以区分话题讨论的开始和结束。
-
联系人姓名: 在分词文本时,每个联系人姓名被视为一个唯一的标记。这确保了名和姓的组合仍被视为一个整体。
-
聊天消息: 在每条消息的末尾添加了一个特殊的“
”标记。
3. 分词
为了训练语言模型,我们需要将语言分解成片段(所谓的标记),并逐步输入模型。分词可以在多个层次上进行。
-
字符级别: 文本被视为一系列单独的字符(包括空格)。这种细粒度的方法允许从字符序列中形成每一个可能的单词。然而,捕捉单词之间的语义关系会更困难。
-
词级别: 文本被表示为一个单词序列。然而,模型的词汇量受到训练数据中现有单词的限制。
-
子词级别: 文本被拆分成比单词小但比字符大的子词单元。
虽然我最初使用的是字符级别的分词器,但我觉得训练时间被浪费在了学习重复单词的字符序列上,而不是关注句子中单词之间的语义关系。
为了概念上的简洁,我决定切换到词级别的分词器,暂时搁置了用于更复杂分词策略的现有库。
from nltk.tokenize import RegexpTokenizer
def custom_tokenizer(txt: str, spec_tokens: List[str], pattern: str="|\d|\\w+|[^\\s]") -> List[str]:
"""
Tokenize text into words or characters using NLTK's RegexpTokenizer, considerung
given special combinations as single tokens.
:param txt: The corpus as a single string element.
:param spec_tokens: A list of special tokens (e.g. ending, out-of-vocab).
:param pattern: By default the corpus is tokenized on a word level (split by spaces).
Numbers are considered single tokens.
>> note: The pattern for character level tokenization is '|.'
"""
pattern = "|".join(spec_tokens) + pattern
tokenizer = RegexpTokenizer(pattern)
tokens = tokenizer.tokenize(txt)
return tokens
["Alice:", "Hi", "how", "are", "you", "guys", "?", "<END>", "Tom:", ... ]
结果显示我的训练数据有大约 70,000 个独特的单词。然而,由于许多单词仅出现一次或两次,我决定用“
from collections import Counter
def get_infrequent_tokens(tokens: Union[List[str], str], min_count: int) -> List[str]:
"""
Identify tokens that appear less than a minimum count.
:param tokens: When it is the raw text in a string, frequencies are counted on character level.
When it is the tokenized corpus as list, frequencies are counted on token level.
:min_count: Threshold of occurence to flag a token.
:return: List of tokens that appear infrequently.
"""
counts = Counter(tokens)
infreq_tokens = set([k for k,v in counts.items() if v<=min_count])
return infreq_tokens
def mask_tokens(tokens: List[str], mask: Set[str]) -> List[str]:
"""
Iterate through all tokens. Any token that is part of the set, is replaced by the unknown token.
:param tokens: The tokenized corpus.
:param mask: Set of tokens that shall be masked in the corpus.
:return: List of tokenized corpus after the masking operation.
"""
return [t.replace(t, unknown_token) if t in mask else t for t in tokens]
infreq_tokens = get_infrequent_tokens(tokens, min_count=2)
tokens = mask_tokens(tokens, infreq_tokens)
["Alice:", "Hi", "how", "are", "you", "<UNK>", "?", "<END>", "Tom:", ... ]
4. 索引
在分词之后,下一步是将单词和特殊标记转换为数值表示。使用固定的词汇表,每个单词按其位置进行了索引。编码后的单词随后被准备为 PyTorch 张量。
import torch
def encode(s: list, vocab: list) -> torch.tensor:
"""
Encode a list of tokens into a tensor of integers, given a fixed vocabulary.
When a token is not found in the vocabulary, the special unknown token is assigned.
When the training set did not use that special token, a random token is assigned.
"""
rand_token = random.randint(0, len(vocab))
map = {s:i for i,s in enumerate(vocab)}
enc = [map.get(c, map.get(unknown_token, rand_token)) for c in s]
enc = torch.tensor(enc, dtype=torch.long)
return enc
torch.tensor([8127, 115, 2363, 3, ..., 14028])
由于我们需要评估模型在一些未见数据上的质量,我们将张量分成两部分。这样,我们就得到了训练集和验证集,可以准备好喂给语言模型。

作者提供的图片
5. 模型架构
我决定应用 GPT 架构,这一架构在具有影响力的论文“Attention is All you Need”中得到了推广。由于我试图构建的是语言生成器而不是问答机器人,因此仅使用解码器(右侧)架构足以满足这一目的。

“Attention is All you Need” 由 A. Vaswani 等人撰写(取自 arXiv: 1706.03762)
在接下来的部分中,我将分解 GPT 架构的每个组件,解释其作用以及基础的矩阵运算。从准备好的训练测试开始,我将追踪一个示例上下文的 3 个词,通过模型,直到它预测下一个令牌。
5.1. 模型目标
在深入技术细节之前,了解我们模型的主要目标至关重要。在仅解码器的设置中,我们的目标是解码语言的结构,以准确预测序列中的下一个令牌,前提是给定前面的令牌上下文。

图片由作者提供
当我们将索引的令牌序列输入模型时,它会经历一系列与各种权重矩阵的矩阵乘法运算。输出是一个向量,表示每个令牌在序列中作为下一个令牌的概率,这个概率基于输入上下文。
模型评估:
我们的模型性能通过训练数据来评估,其中实际的下一个令牌是已知的。目标是最大化正确预测这个下一个令牌的概率。
然而,在机器学习中,我们常常关注“损失”这一概念,它量化了错误或不正确预测的可能性。为了计算这一点,我们将模型的输出概率与实际的下一个令牌进行比较(使用cross-entropy)。
优化:
通过了解当前的损失,我们旨在通过反向传播来最小化它。这个过程涉及迭代地将令牌序列输入模型,并调整权重矩阵以提升性能。
在每张图中,我将用黄色标出在该过程中将被优化的权重矩阵。
5.2. 输出嵌入
到目前为止,我们序列中的每个令牌都由一个整数索引表示。然而,这种简单的形式无法反映单词之间的关系或相似性。为了解决这个问题,我们将一维索引提升到更高维度的空间中,通过嵌入实现。
-
词嵌入: 单词的本质由一个 n 维的浮点向量来捕捉。
-
位置嵌入: 这些嵌入强调了单词在句子中的位置的重要性,也表示为 n 维的浮点向量。
对于每个令牌,我们查找其词嵌入和位置嵌入,然后逐元素相加。这就得出了每个令牌在上下文中的输出嵌入。
在下面的示例中,上下文包含 3 个令牌。在嵌入过程结束时,每个令牌由一个 n 维向量表示(其中 n 是嵌入大小,一个可调的超参数)。

图片由作者提供
PyTorch 提供了专门的类来处理这些嵌入。在我们的模型类中,我们定义了词嵌入和位置嵌入,如下所示(传递矩阵维度作为参数):
self.word_embedding = nn.Embedding(vocab_size, embed_size)
self.pos_embedding = nn.Embedding(block_size, embed_size)
5.3. 自注意力头
虽然词嵌入提供了词语相似性的整体感觉,但一个词的真实含义往往取决于其周围的上下文。例如,“bat”可能指的是动物或运动器材,这取决于句子。这就是自注意力机制(GPT 架构的关键组成部分)发挥作用的地方。
自注意力机制关注三个主要概念:查询(Q)、键(K)和值(V)。
-
查询(Q): 查询本质上是当前标记的表示,注意力需要计算它。就像在问:“作为当前标记,我应该关注上下文中的什么?”
-
键(K): 键是输入序列中每个标记的表示。它们与查询配对,以确定注意力分数。这种比较衡量了查询标记应将多少关注放在上下文中的其他标记上。高分意味着应该更多关注。
-
值(V): 值也是输入序列中每个标记的表示。然而,它们的作用不同,因为它们对注意力分数施加最终加权。

作者提供的图像
示例:
在我们的示例中,上下文中的每个标记已经是嵌入形式,作为 n 维向量(e1, e2, e3)。自注意力头将它们作为输入,以逐一输出每个标记的上下文化版本。
-
在评估“name”这个标记时,通过将其嵌入向量v2与可训练矩阵M_Q相乘,得到一个查询向量q。
-
同时,为上下文中的每个标记计算键向量(k1, k2, k3),通过将每个嵌入向量(e1, e2, e3)与可训练矩阵M_K相乘。
-
值向量(v1, v2, v3) 以相同的方式获得,只是乘以不同的可训练矩阵M_V。
-
注意力分数w通过查询向量与每个键向量之间的点积来计算。
-
最后,我们将所有值向量堆叠成一个矩阵,并将其与注意力分数相乘,以获得标记“name”的上下文化向量。
class Head(nn.Module):
"""
This module performs self-attention operations on the input tensor, producing
an output tensor with the same time-steps but different channels.
:param head_size: The size of the head in the multi-head attention mechanism.
"""
def __init__(self, head_size):
super().__init__()
self.key = nn.Linear(embed_size, head_size, bias=False)
self.query = nn.Linear(embed_size, head_size, bias=False)
self.value = nn.Linear(embed_size, head_size, bias=False)
def forward(self, x):
"""
# input of size (batch, time-step, channels)
# output of size (batch, time-step, head size)
"""
B,T,C = x.shape
k = self.key(x)
q = self.query(x)
# compute attention scores
wei = q @ k.transpose(-2,-1)
wei /= math.sqrt(k.shape[-1])
# avoid look-ahead
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float("-inf"))
wei = F.softmax(wei, dim=-1)
# weighted aggregation of the values
v = self.value(x)
out = wei @ v
return out
5.4. 掩蔽多头注意力
语言是复杂的,捕捉其所有的细微差别并不简单。一组注意力计算通常不足以捕捉词语如何相互作用的细微之处。这就是 GPT 模型中的多头注意力的理念派上用场的地方。
你可以把多头注意力想象成多个“眼睛”以不同的方式观察数据,每个“眼睛”注意到独特的细节。这些独立的观察结果然后被整合成一个大图景。为了使这个大图景易于管理并与我们模型的其余部分兼容,我们使用线性层(可训练权重)将其压缩回原始的嵌入大小。
最后,为了确保我们的模型不仅仅记住训练数据,还能在新文本上进行良好的预测,我们使用了一个 dropout 层。这个层在训练过程中随机关闭数据的部分内容,帮助模型变得更加适应。

作者提供的图片
class MultiHeadAttention(nn.Module):
"""
This class contains multiple `Head` objects, which perform self-attention
operations in parallel.
"""
def __init__(self):
super().__init__()
head_size = embed_size // n_heads
heads_list = [Head(head_size) for _ in range(n_heads)]
self.heads = nn.ModuleList(heads_list)
self.linear = nn.Linear(n_heads * head_size, embed_size)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
heads_list = [h(x) for h in self.heads]
out = torch.cat(heads_list, dim=-1)
out = self.linear(out)
out = self.dropout(out)
return out
5.5. 前馈
多头注意力层最初捕捉了序列中的上下文关系。通过两个连续的线性层为网络添加了更多深度,这两个层共同构成了前馈神经网络。

作者提供的图片
在初始线性层中,我们增加了维度(在我们的例子中是增加了 4 倍),这有效地拓宽了网络学习和表示更复杂特征的能力。对结果矩阵的每个元素应用ReLU函数,使得非线性模式能够被识别。
随后,第二个线性层作为一个压缩器,将扩展的维度减少回原始形状(块大小 x 嵌入大小)。Dropout 层结束了这个过程,随机地停用矩阵的部分元素,以实现模型的泛化。
class FeedFoward(nn.Module):
"""
This module passes the input tensor through a series of linear transformations
and non-linear activations.
"""
def __init__(self):
super().__init__()
self.net = nn.Sequential(
nn.Linear(embed_size, 4 * embed_size),
nn.ReLU(),
nn.Linear(4 * embed_size, embed_size),
nn.Dropout(dropout),
)
def forward(self, x):
return self.net(x)
5.6. Add & Norm
现在,我们通过引入两个关键元素,将多头注意力和前馈组件连接在一起:
-
残差连接(Add): 这些连接执行层输出与其未更改输入的逐元素加法。在训练过程中,模型根据层变换的有用性调整对层变换的重视程度。如果某个变换被认为是不必要的,其权重和相应的层输出将趋向于零。在这种情况下,至少未更改的输入会通过残差连接传递。这种技术有助于缓解梯度消失问题。
-
层归一化(Norm): 这种方法通过减去嵌入向量的均值并除以其标准差来归一化上下文中的每个嵌入向量。这个过程还确保了在反向传播过程中梯度不会爆炸或消失。

作者提供的图片
多头注意力和前馈层的链条,通过“Add & Norm”连接,合并成一个块。这种模块化设计使我们能够形成一系列块。这些块的数量是一个超参数,它决定了模型架构的深度。
class Block(nn.Module):
"""
This module contains a single transformer block, which consists of multi-head
self-attention followed by feed-forward neural networks.
"""
def __init__(self):
super().__init__()
self.sa = MultiHeadAttention()
self.ffwd = FeedFoward()
self.ln1 = nn.LayerNorm(embed_size)
self.ln2 = nn.LayerNorm(embed_size)
def forward(self, x):
x = x + self.sa(self.ln1(x))
x = x + self.ffwd(self.ln2(x))
return x
5.7. Softmax
在遍历多个块组件后,我们获得了一个维度为(块大小 x 嵌入大小)的矩阵。为了将这个矩阵重塑为所需的维度(块大小 x 词汇表大小),我们将其通过一个最终的线性层。这个形状表示了上下文中每个位置词汇表中每个词的一个条目。
最后,我们对这些值应用 soft-max 变换,将其转换为概率。我们成功地获得了上下文中每个位置的下一个标记的概率分布。
6. 模型训练
为了训练语言模型,我从训练数据中的随机位置选择了令牌序列。鉴于 WhatsApp 对话的快节奏,我确定 32 个词的上下文长度足够。因此,我选择了随机的 32 词块作为上下文输入,并使用相应的向量(向后移动一个词)作为比较的目标。
训练过程循环执行以下步骤:
-
对多个批次的上下文进行采样。
-
将这些样本输入模型,以计算当前损失。
-
根据当前的损失和模型权重应用反向传播。
-
每 500 次迭代更全面地评估损失。
一旦所有其他模型超参数(如嵌入大小、自注意力头数量等)确定后,我最终选择了一个具有 250 万参数的模型。考虑到我对输入数据大小和计算资源的限制,我发现这是对我而言的最佳设置。
训练过程大约花费了 12 小时,完成了 10,000 次迭代。可以看到,训练本可以更早停止,因为验证集和训练集上的损失差距在扩大。

作者提供的图片
import json
import torch
from config import eval_interval, learn_rate, max_iters
from src.model import GPTLanguageModel
from src.utils import current_time, estimate_loss, get_batch
def model_training(update: bool) -> None:
"""
Trains or updates a GPTLanguageModel using pre-loaded data.
This function either initializes a new model or loads an existing model based
on the `update` parameter. It then trains the model using the AdamW optimizer
on the training and validation data sets. Finally the trained model is saved.
:param update: Boolean flag to indicate whether to update an existing model.
"""
# LOAD DATA -----------------------------------------------------------------
train_data = torch.load("assets/output/train.pt")
valid_data = torch.load("assets/output/valid.pt")
with open("assets/output/vocab.txt", "r", encoding="utf-8") as f:
vocab = json.loads(f.read())
# INITIALIZE / LOAD MODEL ---------------------------------------------------
if update:
try:
model = torch.load("assets/models/model.pt")
print("Loaded existing model to continue training.")
except FileNotFoundError:
print("No existing model found. Initializing a new model.")
model = GPTLanguageModel(vocab_size=len(vocab))
else:
print("Initializing a new model.")
model = GPTLanguageModel(vocab_size=len(vocab))
# initialize optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learn_rate)
# number of model parameters
n_params = sum(p.numel() for p in model.parameters())
print(f"Parameters to be optimized: {n_params}\n", )
# MODEL TRAINING ------------------------------------------------------------
for i in range(max_iters):
# evaluate the loss on train and valid sets every 'eval_interval' steps
if i % eval_interval == 0 or i == max_iters - 1:
train_loss = estimate_loss(model, train_data)
valid_loss = estimate_loss(model, valid_data)
time = current_time()
print(f"{time} | step {i}: train loss {train_loss:.4f}, valid loss {valid_loss:.4f}")
# sample batch of data
x_batch, y_batch = get_batch(train_data)
# evaluate the loss
logits, loss = model(x_batch, y_batch)
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
torch.save(model, "assets/models/model.pt")
print("Model saved")
7. 聊天模式
为了与训练好的模型进行交互,我创建了一个函数,允许通过下拉菜单选择联系人姓名,并输入消息供模型响应。参数“n_chats”决定模型一次生成的响应数量。当模型预测
import json
import random
import torch
from prompt_toolkit import prompt
from prompt_toolkit.completion import WordCompleter
from config import end_token, n_chats
from src.utils import custom_tokenizer, decode, encode, print_delayed
def conversation() -> None:
"""
Emulates chat conversations by sampling from a pre-trained GPTLanguageModel.
This function loads a trained GPTLanguageModel along with vocabulary and
the list of special tokens. It then enters into a loop where the user specifies
a contact. Given this input, the model generates a sample response. The conversation
continues until the user inputs the end token.
"""
with open("assets/output/vocab.txt", "r", encoding="utf-8") as f:
vocab = json.loads(f.read())
with open("assets/output/contacts.txt", "r", encoding="utf-8") as f:
contacts = json.loads(f.read())
spec_tokens = contacts + [end_token]
model = torch.load("assets/models/model.pt")
completer = WordCompleter(spec_tokens, ignore_case=True)
input = prompt("message >> ", completer=completer, default="")
output = torch.tensor([], dtype=torch.long)
print()
while input != end_token:
for _ in range(n_chats):
add_tokens = custom_tokenizer(input, spec_tokens)
add_context = encode(add_tokens, vocab)
context = torch.cat((output, add_context)).unsqueeze(1).T
n0 = len(output)
output = model.generate(context, vocab)
n1 = len(output)
print_delayed(decode(output[n0-n1:], vocab))
input = random.choice(contacts)
input = prompt("\nresponse >> ", completer=completer, default="")
print()
结论:
由于个人聊天的隐私,我无法在这里展示示例提示和对话。
尽管如此,你可以期待这样规模的模型能够成功学习句子的总体结构,产生有意义的输出,尤其是在词序方面。在我的案例中,它也掌握了训练数据中某些重要话题的上下文。例如,由于我的个人聊天经常围绕网球展开,网球运动员的名字和与网球相关的词通常会一起输出。
然而,在评估生成句子的连贯性时,我承认结果并没有达到我已经很低的期望。当然,我也可以责怪我的朋友们聊了太多无聊的话,限制了模型学习有用内容的能力...
为了在结尾展示至少一些示例输出,你可以查看虚拟模型在 200 条训练虚拟消息上的表现 😉

作者提供的图片
构建一个本地运行的语音助手
向 LLM 提问而不泄露私人信息
·
关注 发表在 Towards Data Science ·7 分钟阅读·2023 年 12 月 29 日
--
图片由作者生成,并由 openart.ai 协助
我不得不承认,我最初对大型语言模型(LLM)生成实际有效的代码片段的能力持怀疑态度。我带着最坏的预期尝试了,结果却感到惊喜。像与任何聊天机器人互动一样,问题的格式很重要,但随着时间的推移,你会学会如何明确你需要帮助的问题的边界。
我习惯了在编写代码时有一个始终可用的在线聊天机器人服务,但当我的雇主发布了禁止员工使用它的公司政策时,我不得不寻找其他方案。我可以回到以前的谷歌习惯,但我决定构建一个本地运行的 LLM 服务,这样我可以在不泄露公司外部信息的情况下提问。感谢HuggingFace上的开源 LLM 提供以及chainlit 项目,我能够组建一个满足编码辅助需求的服务。
下一步是添加一些语音交互。尽管语音不适合编码辅助(你希望看到生成的代码片段,而不是听到它们),但在某些情况下,你需要在创意项目上获得灵感。讲故事的感觉为体验增添了价值。另一方面,你可能不愿意使用在线服务,因为你希望保持工作的隐私。
在这个项目中,我将带你了解构建一个允许你通过语音与开源 LLM 交互的助手的步骤。所有组件都在你的计算机上本地运行。
架构
架构包括三个独立的组件:
-
一个唤醒词检测服务
-
一个语音助手服务
-
一个聊天服务

三个组件的流程图。图片由作者提供。
这三个组件是独立的项目,每个项目都有自己的 github 仓库。让我们逐一了解每个组件及其如何交互。
聊天服务
聊天服务运行开源 LLM,名为HuggingFaceH4/zephyr-7b-alpha。该服务通过 POST 调用接收提示,将提示传递给 LLM,并将输出作为调用响应返回。
你可以在这里找到代码。
在…/chat_service/server/中,将chat_server_config.xml.example重命名为chat_server_config.xml。
然后你可以使用以下命令启动聊天服务器:
python .\chat_server.py
当服务第一次运行时,由于从HuggingFace网站下载大型文件并存储在本地缓存目录中,启动需要几分钟时间。
你会从终端收到服务正在运行的确认:

聊天服务运行的确认。图片由作者提供。
如果你想测试与 LLM 的交互,请前往…/chat_service/chainlit_interface/。
将app_config.xml.example重命名为app_config.xml。使用以下命令启动网络聊天服务:
.\start_interface.sh
浏览到本地地址localhost:8000
你应该能够通过文本接口与本地运行的 LLM 进行交互:

与本地运行的 LLM 的文本交互。图片由作者提供。
语音助手服务
语音助手服务是进行语音转文本和文本转语音转换的地方。你可以在这里找到代码。
转到…/voice_assistant/server/。
将voice_assistant_service_config.xml.example重命名为voice_assistant_service_config.xml。
助理通过播放问候语开始,以表明它正在监听用户。问候语的文本在voice_assistant_config.xml中配置,位于<welcome_message>元素下:

voice_assistant_config.xml 文件。图片由作者提供。
允许程序将文本转换为你可以通过音频输出设备听到的语音的文本转语音引擎是pyttsx3。根据我的经验,这个引擎用英语和法语说话都非常自然。与其他依赖 API 调用的软件包不同,它是在本地运行的。
一个名为facebook/seamless-m4t-v2-large的模型执行语音转文本推断。voice_assistant_service.py首次运行时会下载模型权重。
voice_assistant_service.main()中的主要循环执行以下任务:
-
从麦克风获取句子。使用语音转文本模型将其转换为文本。
-
检查用户是否说出了配置文件中的<end_of_conversation_text>元素定义的消息。在这种情况下,谈话结束,程序在播放再见消息后终止。
-
检查句子是否是胡言乱语。语音转文本引擎即使我什么也没说,通常也会输出有效的英语句子。偶尔,这些不必要的输出会重复出现。例如,胡言乱语的句子有时会以“[”或“i’m going to”开头。我在配置文件的<gibberish_prefix_list>元素中收集了通常与胡言乱语句子相关的前缀列表(这个列表可能会根据其他语音转文本模型而有所变化)。每当音频输入以列表中的一个前缀开头时,句子就会被忽略。
-
如果句子看起来不是胡言乱语,请向聊天服务发送请求。播放响应。
voice_assistant_service.main()中的主要循环。代码由作者提供。
唤醒词服务
最后一个组件是一个持续监听用户麦克风的服务。当用户说出唤醒词时,系统调用会启动语音助手服务。唤醒词服务运行的模型比语音助手服务模型要小。因此,持续运行唤醒词服务是有意义的,而语音助手服务则只有在需要时才会启动。
你可以在这里找到唤醒词服务的代码。
克隆项目后,转到…/wakeword_service/server。
将wakeword_service_gui_config.xml.example重命名为wakeword_service_gui_config.xml。
将command.bat.example重命名为command.bat。你需要编辑command.bat,使虚拟环境激活和对voice_assistant_service.py的调用与您的目录结构相匹配。
你可以通过以下调用来启动服务:
python gui.py
唤醒词检测服务的核心是openwakeword项目。在几个唤醒词模型中,我选择了“hey jarvis”模型。我发现简单地说“Jarvis?”就会触发检测。
每当检测到唤醒词时,都会调用一个命令文件,正如配置文件的<command_on_wakeword>元素中所指定的。在我们的情况下,command.bat文件激活虚拟环境并启动语音助手服务。

唤醒词检测服务 GUI 的配置文件。图片来源于作者。
如果你想将唤醒词检测服务用于其他用途,你可以编辑command.bat文件,以启动你想要的任何程序。
当检测到唤醒词时,唤醒词检测服务会发出两声哔声以指示已触发。这一快速反馈告诉用户,他们必须等待,因为文本到语音和语音到文本模型正在加载到计算机的 GPU 中,这可能需要几秒钟。当两个模型都加载完毕后,语音助手服务会播放问候语,交谈可以开始。对话结束后,语音助手服务终止,唤醒词检测服务会发出另一声两声哔声信号。
当三个组件安装并测试完成后,我们可以启动聊天和唤醒词检测服务。唤醒词检测服务将不断监听麦克风。当检测到唤醒词(“Jarvis?”)时,它将启动语音助手服务。然后你可以大声提问并收到语音回答。当检测到对话结束短语(“谢谢,再见”)时,语音助手服务结束。
与助手的交互示例
这是与语音助手交互的一个示例:
“Jarvis?” [哔—哔]
[在语音助手模型加载期间的几秒钟沉默]
“你好!我能帮你什么?”
“土星的环由什么组成?”
“土星的环主要由冰粒组成,其中大部分是水冰。然而,环中也有少量其他冰类,例如氨冰和岩石碎片。环的确切成分因距离土星的远近而异,内圈的密度较高,颗粒较小,而外圈则分布较广,含有较大的冰粒。”
“告诉我电影‘回到未来’的情节。”
“电影《回到未来》讲述了高中生马蒂·麦克弗莱的故事,他意外地被传送回 1955 年,乘坐由他那古怪的科学家朋友布朗博士创造的时光旅行德洛林车。马蒂必须在修复对现在造成的损害的同时应对过去的挑战,并防止他的青少年母亲爱上他的青少年父亲。在这个过程中,马蒂结识了年轻版的布朗博士,并帮助他克服个人悲剧。这部电影探讨了命运、家庭和改变过去的后果等主题。”
“谢谢,再见”
“再见” [嘟嘟声]
结论
在这个项目中,我们实现了一个由唤醒词检测服务、语音助手服务和聊天服务组成的语音助手系统。这些组件都是独立的模块,可以在其他项目中重复使用。
在我开发这个语音助手的过程中,有许多惊喜,其中最让我印象深刻的是语音转文本的质量。如果你像我一样,可能会遇到过无法准确转录简单命令的自动语音识别系统,比如“调低音量”!我原本以为语音转文本会是整个流程的主要障碍。在尝试了几个不令人满意的模型后,我最终选择了facebook/seamless-m4t-v2-large,对其结果的质量印象深刻。我甚至可以用法语说一句话,神经网络会自动将其翻译成英语。这简直令人惊叹!
我希望你能尝试这个有趣的项目,并告诉我你是如何使用它的!
使用 Google Cloud Recommendation AI 构建推荐系统
使用 Google Cloud Recommendation AI 实现高度先进的推荐系统
·
关注 发表在 Towards Data Science ·14 分钟阅读·2023 年 6 月 21 日
--
想象一下谷歌自己的机器学习工程师在为你的网站和应用程序实施推荐系统。借助 Google Cloud Recommendation AI,你可以利用正在用于驱动 YouTube、Google 广告和其他 Google 产品的推荐系统,为我们提供个性化服务。
在本文中,我将指导您如何实施 Google Cloud Recommendation AI,以便为您的客户提供个性化体验。
Google Cloud Recommendation AI 的好处?
Noon、IKEA、KINGUIN 以及许多其他公司使用 Google Cloud Recommendation AI 提供跨渠道的高性能推荐。

Muffaddal 关于 Google Cloud Recommendation AI 对某些企业的初步影响
在工作量和资源方面,构建和部署推荐系统可能是一个复杂且耗时的任务。它通常涉及数据工程师、数据科学家和机器学习工程师的协作,以创建和运营一个全面的推荐系统。然而,使用 Google Cloud Recommendation AI 的情况却大不相同。
通过利用 Google Cloud Recommendation AI,实施过程变得显著简化且高效。这项完全托管的服务使您能够快速提供个性化体验,节省宝贵的时间和资源。实施 Google Cloud Recommendation AI 可以在以下几个方面带来明显的改进:
-
转化率: 通过向用户提供量身定制的产品或内容推荐,您可以提升他们的购物或浏览体验,从而增加转化的可能性。
-
客户满意度: 个性化推荐满足了个别用户的具体偏好和需求,带来了更高的客户满意度和参与度。
-
收入生成: 通过提高转化率和客户满意度,实施 Google Cloud Recommendation AI 可以对您的收入流产生积极影响。
-
时间节省: 作为一项完全托管的服务,Google Cloud Recommendation AI 减少了广泛开发和维护工作的需求,使您能够专注于业务的其他关键方面。
-
可扩展性: Google Cloud Recommendation AI 旨在处理大量数据,并且能够随着业务增长无缝扩展,确保一致的性能和准确性。
通过利用 Google Cloud Recommendation AI 的能力,您可以释放显著改善客户体验、转化率、收入和运营效率的潜力。
Google Cloud Recommendation AI 可以在个性化方面实施的领域
Google Cloud Recommendation AI 提供了一系列机器学习模型和配置,可以在各种业务环境中有效地利用,以提升用户体验并推动更高的投资回报。以下是 Google Cloud Recommendation AI 可以应用的一些关键领域:
-
个性化推荐:在主页上实施“为您推荐”或“推荐给您”部分,根据用户的个人偏好和行为提供量身定制的建议。
-
商品推荐:在产品详情页展示“您可能喜欢的其他商品”部分,推荐用户高度可能感兴趣或购买的商品。
-
购物车推荐:利用“常一起购买”部分,建议用户购买与所选商品常一起购买的其他产品,从而提升用户的购物体验。
-
相似商品推荐:在产品页面展示“相似商品”部分,展示与当前浏览商品具有相似属性的商品,尤其在浏览的商品缺货时非常有用。
-
重复购买推荐:在不同页面(如主页、产品详情、购物车)实施“再次购买”功能,建议用户基于之前的互动记录,推荐可能会再次购买的商品。
-
销售推荐:利用“促销”部分展示打折商品,鼓励用户探索并进行购买。
此外,Google Cloud 推荐 AI 还支持客观优化,使企业能够根据特定目标改进其模型。例如,内容管理平台如 Medium 可能会优化产品浏览量,而电子商务网站则专注于转化优化。
你还可以结合多个目标来最大化个性化的好处。例如,可以使用优化参与度的模型来促进新用户的探索,而针对返回用户的转化和收入优化则可以发挥作用。
不要忘记电子邮件个性化。Google Cloud 推荐 AI 驱动的模型也可以与电子邮件结合使用,以提升业务目标。例如,我们可以在用户购买商品后,通过电子邮件进行追加销售或交叉销售。通过“为您推荐”模型进行精准定位,增加用户互动。
结合以上提到的所有内容,我们可以在用户在我们平台上的整个生命周期内完全改变他们的体验,从而增加 LTV 和业务收入。
如何实现 Google Cloud 推荐 AI?
推荐 AI 需要两个方面来驱动机器学习模型。一是用户活动,二是用户进行活动的产品详情。
活动或操作作为事件发送,商品详情作为目录存储在推荐 AI 中。一旦我们获得了所需数量的事件和目录详情,我们可以训练我们的机器学习模型,为用户提供基于用户过去行为和属性的个性化列表。
让我们逐一详细了解这些内容。

使用 Google Cloud 推荐 AI 的数据管道,由 Muffaddal 提供
1- 导入产品详情
目录可以有许多字段和属性,但 id、name、title 和 categories 是必需字段,必须提供。
假设我们的产品详细信息已经存在于 BigQuery 中,我们将利用 BigQuery 和 Recommendation AI 的集成来导入目录数据。

Google Cloud Recommendation AI 的产品目录导入过程,由 Muffaddal 提供
Google Cloud Recommendation AI 需要一个特定的 BigQuery 表模式。因此,我们必须创建一个符合要求格式的表,并将目录数据插入其中。
假设我们的目录表在 BigQuery 中具有以下字段
-
name: 产品名称
-
id: 产品的 id
-
category: 产品的分配类别
-
description: 产品的描述
-
url: 网站上产品的 URL
-
image_link: 产品的公开访问图片链接
-
city: 产品可用的城市
根据上述产品目录详细信息,我们需要为 BigQuery 表创建模式。从这里获取模式,并根据我们的数据进行更新。
我们可以为目录设置许多字段。数据越多越好。为了演示的目的,我将使用最常见的字段。请注意,可为空的字段是可选的。
type 字段是我们决定产品是变体还是主产品的地方。本文中我将使用 PRIMARY。 点击这里了解更多详情。
一旦我们的表准备好,我们可以使用以下查询将目录数据从主表插入到此表中。
insert into `recommendersystem.product_data`
(
name,id,type,primaryProductId,collectionMemberIds,gtin,categories,title,brands,description,languageCode,attributes, tags,
priceInfo,rating,expireTime,ttl,availableTime,availability,availableQuantity,fulfillmentInfo, uri, images,audience,colorInfo,sizes,materials,patterns,conditions,retrievableFields,publishTime,promotions
)
SELECT
name,
cast(id as string) as id,
"PRIMARY" as type,
cast(id as string) as primaryProductId,
null as collectionMemberIds,
null as gtin,
array [categories] as categories,
name as title,
array[title] as brands,
ifnull(description,name) as description,
null as languageCode,
[
struct(
'product_location' as key, STRUCT(array[ifnull(city,"empty")] as text, cast(null as ARRAY<FLOAT64>) as numbers, true as searchable, true as indexable ) as value
)]
as attributes,
ARRAY_CONCAT(
[ifnull(location,"empty")],
[ifnull(categories,"empty")]
) as tags,
null as priceInfo,
null as rating,
null as expireTime,
null as ttl,
null as availableTime,
null as availability,
null as availableQuantity,
null as fulfillmentInfo,
url,
array[struct(image_url) as uri, null as height, null as width)] as images,
null as audience,
null as colorInfo,
null as sizes,
null as materials,
null as patterns,
null as conditions,
null as retrievableFields,
null as publishTime,
null as promotions
FROM `product.product_details`
数据在我们的新表中可用后,我们准备将其导入 Recommendation AI。
在 Google Cloud 的零售 AI 数据标签中,点击左上角的导入以导入数据。

向 Google Cloud Recommendation AI 导入目录详细信息,由 Muffaddal 提供
将出现一个面板,如下图所示。选择导入类型中的“产品目录”和数据源中的“BigQuery”。提供 BigQuery 表路径,选择一个分支并点击导入。

Google Cloud Recommendation AI 的目录导入面板,由 Muffaddal 提供
等待几分钟以查看 Retail AI 数据表中的目录详细信息。
2- 导入历史事件
接下来,我们需要将用户的历史数据导入 Recommendation AI。此步骤是可选的,但这样做有助于构建更好的机器学习模型。
类似于目录,我们需要在 BigQuery 表中拥有符合要求格式的事件。
以下是 Google Cloud Recommendation AI 接受的事件

可以发送到 Google Cloud Recommendation AI 的事件,由 Muffaddal 提供
在所有这些事件中,home-page-view、detail-page-viewed、add-to-cart和purchase-complete对于全面激活 AI 模型是必要的。
每个事件都有一个特定的表模式,用于数据导入。你可以在这里详细查看。
1- home-page-view、detail-page-viewed和add-to-cart的表模式如下
[
{
"name": "eventType",
"type": "STRING",
"mode": "REQUIRED"
},
{
"name": "visitorId",
"type": "STRING",
"mode": "REQUIRED"
},
{
"name": "eventTime",
"type": "STRING",
"mode": "REQUIRED"
},
{
"name": "productDetails",
"type": "RECORD",
"mode": "REPEATED",
"fields": [
{
"name": "product",
"type": "RECORD",
"mode": "REQUIRED",
"fields": [
{
"name": "id",
"type": "STRING",
"mode": "REQUIRED"
}
]
},
{
"name": "quantity",
"type": "INTEGER",
"mode": "REQUIRED"
}
]
},
{
"name": "attributes",
"type": "RECORD",
"mode": "NULLABLE",
"fields": [
{
"name": "deviceType",
"type": "RECORD",
"mode": "NULLABLE",
"fields": [
{
"name": "text",
"type": "STRING",
"mode": "REPEATED"
}
]
}
]
}
]
2- purchase-complete事件的模式如下
[
{
"name": "eventType",
"type": "STRING",
"mode": "REQUIRED"
},
{
"name": "visitorId",
"type": "STRING",
"mode": "REQUIRED"
},
{
"name": "eventTime",
"type": "STRING",
"mode": "REQUIRED"
},
{
"name": "productDetails",
"type": "RECORD",
"mode": "REPEATED",
"fields": [
{
"name": "product",
"type": "RECORD",
"mode": "REQUIRED",
"fields": [
{
"name": "id",
"type": "STRING",
"mode": "REQUIRED"
}
]
},
{
"name": "quantity",
"type": "INTEGER",
"mode": "REQUIRED"
}
]
},
{
"name": "purchaseTransaction",
"type": "RECORD",
"mode": "REQUIRED",
"fields": [
{
"name": "revenue",
"type": "FLOAT",
"mode": "REQUIRED"
},
{
"name": "currencyCode",
"type": "STRING",
"mode": "REQUIRED"
}
]
},
{
"name": "attributes",
"type": "RECORD",
"mode": "NULLABLE",
"fields": [
{
"name": "deviceType",
"type": "RECORD",
"mode": "NULLABLE",
"fields": [
{
"name": "text",
"type": "STRING",
"mode": "REPEATED"
}
]
},
{
"name": "cityName",
"type": "RECORD",
"mode": "NULLABLE",
"fields": [
{
"name": "text",
"type": "STRING",
"mode": "REPEATED"
}
]
}
]
}
]
一旦你创建了表,可以使用以下查询将数据插入到新表中。
1- home-page-viewed的 SQL 插入查询如下
insert into `recommendersystem.user_event_home_page_view` (eventType, visitorId,eventTime,attributes)
SELECT
'home-page-view' as eventType,
visitorId,
eventTime,
struct(
struct([deviceType] as text)as deviceType,
struct([city] as text)as cityName
)
as attributes,
from (
select *,
deviceType from `recommendersystem.user_event_history`
2- detail-page-viewed、add-to-cart的 SQL 插入查询为
insert into `recommendersystem.user_event_add_to_cart` (eventType, visitorId,eventTime,productDetails,attributes)
SELECT
'add-to-cart' as eventType,
visitorId,
eventTime,
[
struct( struct(product_id as id) as product , 1 as quantity)
] as productDetails,
struct(
struct([deviceType] as text)as deviceType,
struct([city] as text)as cityName
)
as attributes,
from (
select *,
deviceType from `recommendersystem.user_event_history`
-- limit 100
)
3- purchase-complete事件的 SQL 插入查询为
insert into `recommendersystem.user_event_purchase_complete` (eventType, visitorId,eventTime,productDetails,purchaseTransaction,attributes)
SELECT
'purchase-complete' as eventType,
cast( visitorId as string) as visitorId,
eventTime,
[
struct( struct(product_id as id) as product , 1 as quantity)
] as productDetails,
struct(safe_cast(revenue as float64) as revenue, 'USD' as currencyCode) as purchaseTransaction,
struct(
struct([deviceType] as text)as deviceType,
struct([city] as text)as cityName
)
as attributes
from (
select *,
deviceType from `recommendersystem.user_purchase_event`
)
注:访客 ID 和用户 ID 可以相同,也可以不同。这取决于用户是否需要在使用你的产品之前登录。
注:Google Cloud 推荐 AI 也支持 Google Analytics 4 原始数据。如果你有这些数据,则无需进行转换,可以直接导入。
要导入历史数据,请像以前一样转到零售 AI 的数据标签页,然后点击顶部的导入。
在导入面板中,选择导入类型为用户事件,选择以下选项,然后点击导入。

将历史事件导入 Google Cloud Recommendation AI,由 Muffaddal 提供
我们已经成功将用户事件导入到 Google Cloud Recommendation AI 中。
3- 发送实时事件
接下来,我们需要向用户发送实时事件,以便 AI 模型可以重新训练并在用户与平台互动的过程中提升推荐效果。
有三种方法将用户事件发送到 Google Cloud Recommendation AI。使用 javascript pixel、使用 API 和 GTM。
本文将使用 API,因为它可以覆盖所有情况,无论网站或应用的性质如何。
这是发送home-page-viewed事件的 curl 调用
curl -X POST \
-H "Authorization: Bearer ya29.a0AasRrdaM8jq4J0FBtJpsdhu87ZJcPlr4-8NqkDdGmLYtQ7P-riTx5man4k2maqWGNIsL1007a4BClSsfVbgHyjycaKn_5bviofo5XCbvCeO5-kkepnb_RUgy6prxRX7X8pi2PFHxX-kbmSmQgeEoILQQnK_aYMtTagSFzkUXC12Q2A5VmlfXR5cvSW-a80XxGOikpEf1jHuwusQx2EftBITnhPaYvg6Xi08qzvAEnmKgYicqCqY5o9d9ixav1jm6bd0r7A" \
-H "Content-Type: application/json; charset=utf-8" \
-H "X-Goog-User-Project: test-prod "
--data "{
'eventType': 'home-page-view',
'visitorId': '12',
'eventTime': '2021-09-28T03:33:33.000001Z',
'attributionToken': 'ABC',
'attributes': {
'city_name': {
'text': ['karachi']
},
'device_type': {
'text': ['iOS']
},
},
'userInfo': {
'userId': '123',
}
}"\
"https://retail.googleapis.com/v2/projects/test-prod/locations/global/catalogs/default_catalog/userEvents:write"
detail-page-view的 Curl 调用
curl -X POST \
-H "Authorization: Bearer ya29.a0ARrdaMsd84J0FBtZdJp2jhu87ZJcPlr4-8NqkDdGmLYtQ7P-riTx5man4k2maqWGNIsL1007a4BClSsfVbgHyjycaKn_5bviofo5XCbvCeO5-kkepnb_RUgy6prxRX7X8pi2PFHxX-kbmSmQgeEoILQQnK_aYMtTagSFzkUXC12Q2A5VmlfXR5cvSW-a80XxGOikpEf1jHuwusQx2EftBITnhPaYvg6Xi08qzvAEnmKgYicqCqY5o9d9ixav1jm6bd0r7A" \
-H "Content-Type: application/json; charset=utf-8" \
-H "X-Goog-User-Project: test-prod "
--data "{
'eventType': 'detail-page-view',
'visitorId': '123',
'eventTime': '2021-09-28T03:33:33.000001Z',
'attributionToken': 'ABC',
'attributes': {
'city_name': {
'text': ['karachi']
},
'device_type': {
'text': ['iOS']
},
},
'productDetails': [{
'product': {
'id': '2806'
}
}],
'userInfo': {
'userId': '123',
}
}"\
"https://retail.googleapis.com/v2/projects/test-prod/locations/global/catalogs/default_catalog/userEvents:write"
add-to-cart的 Curl 调用
curl -X POST \
-H "Authorization: Bearer ya29.sdrdaM8jq4J0FBtZdJp2jsdshu87ZJcPlr4-8NqkDdGmLYtQ7P-riTx5man4k2maqWGNIsL1007a4BClSsfVbgHyjycaKn_5bviofo5XCbvCeO5-kkepnb_RUgy6prxRX7X8pi2PFHxX-kbmSmQgeEoILQQnK_aYMtTagSFzkUXC12Q2A5VmlfXR5cvSW-a80XxGOikpEf1jHuwusQx2EftBITnhPaYvg6Xi08qzvAEnmKgYicqCqY5o9d9ixav1jm6bd0r7A" \
-H "Content-Type: application/json; charset=utf-8" \
-H "X-Goog-User-Project: test-prod"
--data "{
'eventType': 'add-to-cart',
'visitorId': '123',
'eventTime': '2021-09-28T03:33:33.000001Z',
'attributionToken': 'ABC',
'attributes': {
'city_name': {
'text': ['karachi']
},
'device_type': {
'text': ['iOS']
},
},
'productDetails': [{
'product': {
'id': '2806'
},
'quantity':1
}],
'userInfo': {
'userId': '123',
}
}"\
"https://retail.googleapis.com/v2/projects/test-prod/locations/global/catalogs/default_catalog/userEvents:write"
purchase-complete的 Curl 调用
curl -X POST \
-H "Authorization: Bearer ya29.a0ARrsddaM8jJ0FBtZdJsdjhu87ZJcPlr4-8NqkDdGmLYtQ7P-riTx5man4k2maqWGNIsL1007a4BClSsfVbgHyjycaKn_5bviofo5XCbvCeO5-kkepnb_RUgy6prxRX7X8pi2PFHxX-kbmSmQgeEoILQQnK_aYMtTagSFzkUXC12Q2A5VmlfXR5cvSW-a80XxGOikpEf1jHuwusQx2EftBITnhPaYvg6Xi08qzvAEnmKgYicqCqY5o9d9ixav1jm6bd0r7A" \
-H "Content-Type: application/json; charset=utf-8" \
-H "X-Goog-User-Project: test-prod "
--data "{
'eventType': 'purchase-complete',
'visitorId': '123',
'eventTime': '2021-09-28T03:33:33.000001Z',
'attributionToken': 'ABC',
'attributes': {
'city_name': {
'text': ['karachi']
},
'device_type': {
'text': ['iOS']
},
},
'productDetails': [{
'product': {
'id': '2806'
},
'quantity':'1'
}],
'purchaseTransaction':{
"id": 'transacion-id-here',
"revenue": 'orderPrice-here',
"currencyCode": 'USD',
"quantity":'1'
}
'userInfo': {
'userId': '123',
}
}"\
"https://retail.googleapis.com/v2/projects/test-prod/locations/global/catalogs/default_catalog/userEvents:write"
你需要使用 Google Cloud 生成授权令牌以发送事件。
你可以在 Retail AI 的事件标签页中查看实时事件。

Retail AI 用户事件,由 Muffaddal 提供
注意:如果你在导入目录之前或期间记录了用户事件,请 重新加入任何事件 这些事件是在目录导入完成之前记录的。
最后,我们只差一步就能创建第一个推荐模型。
构建推荐模型
Google Cloud Recommendation AI 支持以下机器学习模型。

Google Cloud Recommendation AI 支持的 AI 模型,由 Muffaddal 提供
对于这篇文章,我们将使用 Recommended For You 模型。你可以 点击这里查看有关可用模型的更多细节。
进入模型选项卡,点击“创建模型”开始配置你的 ML 模型。

Google Cloud Recommendation AI 中的模型创建,由 Muffaddal 提供
接下来,在模型类型中选择 Recommended For You 模型。

Google Cloud Recommendation AI 中的 AI 模型选择,由 Muffaddal 提供
我们的目标是提高购买量,因此我们希望我们的 ML 模型优化转换率。选择 conversion rate (CVR) 作为模型目标。

Google Cloud Recommendation AI 中的模型目标,由 Muffaddal 提供
设置调优频率为 every three month,并按属性值过滤为 auto

Google Cloud Recommendation AI 模型的调优和选项卡设置,由 Muffaddal 提供
然后点击“创建”按钮。这将开始模型训练。等待一两天,AI 模型将准备就绪。时间取决于训练所需的数据量。
服务
一旦 AI 模型创建完成,就该配置服务,以便我们可以调用 AI 模型并获取个性化列表。
进入服务配置,点击顶部的“创建服务”按钮。选择下图所示的推荐设置。

Google Cloud Recommendation AI 中的模型服务配置,由 Muffaddal 提供
命名你的服务并点击“继续”。
选择我们创建的 recommended for you 模型。这将把我们的模型附加到这个服务配置中。
偏好选项卡是你决定模型行为的地方。我们可以使用自动设置。

Google Cloud Recommendation AI 的偏好设置,由 Muffaddal 提供
点击“创建”按钮。注意配置的 ID,因为它将用于调用模型 API。
获取推荐
以下是从我们创建的模型获取推荐的 curl 调用。
curl -X POST \
-H "Authorization: Bearer ya29.aARrdaM9Bm57OTsdsIQAzGT15GwYzZpVfssffknWPNJ8gpKRk6IHSFmGqs1nBpAlaRRg2fQvtJgtUDGsuIc-h-j0RMLkAPy7FjxQ4tQbYZl62ba-4q4oRx-oY2KwYDA-pEQW77SACo2a8hS1zEUZHyyHCO3V-PycSBetJeldjib5VYo969D1PFVF33WSSRLPIP9uBcTW9ABoYthSOioTePlaICbwV1p8dlXesnCH8PdPNuKPxJJI3rzrnIghKXUKSQb4E-mc" \
-H "Content-Type: application/json; charset=utf-8" \
-H "X-Goog-User-Project: test-prod"\
--data '{
"pageSize":100,
"userEvent": {
"eventType": "home-page-view",
"visitorId": "123",
"userInfo": {
"userId": "123"
},
"experimentIds": "123"
}
}' \
https://retail.googleapis.com/v2/projects/test-prod/locations/global/catalogs/default_catalog/placements/<your serving id here>:predict
传递准确的用户 ID、服务 ID、项目 ID 和授权令牌,你将获得所提供用户 ID 的个性化列表。
与你的开发人员分享以上代码,他应该能够在你的网站和应用程序上提供个性化的部分。
实验
我强烈建议你在首次启动推荐系统时进行 A/B 测试。这将帮助你了解提供个性化体验带来的实际价值。
最终思考
这篇文章应该足够帮助你轻松构建一个完全可扩展的推荐系统。但请记住,使用 Google Cloud Recommendation AI 还有更多细节需要了解。
例如,在设置多个模型时,归因令牌至关重要,保持目录更新也很重要,以便向用户提供最新的产品,同时还需考虑 Google Cloud 的定价问题。还有许多其他方面。
如果你需要我在 Google Cloud Recommendation AI 方面的帮助,随时联系我。
结论
如果你正在寻找一种强大的方法来构建个性化推荐系统,Google Cloud Recommendation AI 是你的解决方案。
Google Cloud Recommendation AI 不仅允许你使用 Google 自有的机器学习模型,还能让你跳过整个推荐系统架构设计的过程。它能迅速为你提供推荐的价值。
及时实现 Google Cloud Recommendation AI,以提高用户参与度、留存率和业务收入。
类似阅读
推荐系统用于在产品详细页面上向用户推荐类似的项目。
towardsdatascience.com ## 使用 BigQuery ML 进行 RFM 分析
使用 RFM 分析在 BigQuery ML 中进行用户细分,并在 Data Studio 中进行可视化。
towardsdatascience.com
用 Go 构建 REST API:时间序列的数据分析
原文:
towardsdatascience.com/build-a-rest-api-with-go-data-analytics-for-time-series-8d3f0747665c
一个关于使用 Go、Gin 和 Gorm 进行 CRUD 操作和统计分析的逐步示例。
·发布于 Towards Data Science ·阅读时间 10 分钟·2023 年 11 月 7 日
--

“学习 Go”,作者 Giulia Roggia。经许可使用。
-
介绍
-
模型与数据库
-
统计
-
处理程序
-
用法
介绍
本文介绍了一个使用 Go 构建的 REST API 示例,用于执行 CRUD(创建、读取、更新、删除)操作并计算时间序列数据的统计信息。
本文中使用的完整代码可以在 这里找到。
为什么选择 Go?
Go 是构建 REST API 的常见选择,原因有很多。尽管它是高效的编译语言,但其语法简单且可读性强。它使得实现并发处理变得容易。它提供了功能丰富的标准库,并且拥有一个出色的库和工具生态系统。
在这个例子中,我们使用了两个流行的 Go 库:
文件夹结构
我们为每个“服务”创建一个文件夹。在 Go 中,每个文件夹对应一个包,每个文件都可以访问同一包中其他文件定义的元素。以下是项目中使用的文件夹结构:
├── database
│ └── database.go
├── models
│ └── models.go
├── handlers
│ ├── routes.go
│ ├── stats.go
│ ├── timeseries.go
│ └── timeseriesvalues.go
├── stats
│ └── stats.go
├── go.mod
├── go.sum
└── main.go
模型与数据库
首先在 models.go 中定义 ORM 模型,以表示时间序列数据。我们使用了两个模型,一个用于标识序列,另一个用于存储其值。每个值通过外键与时间序列相关联。
type TimeSeries struct {
ID int `gorm:"primaryKey"`
Name string `gorm:"not null"`
}
type TimeSeriesValue struct {
ID int `gorm:"primaryKey"`
Time time.Time
Value float64
TimeSeriesID int `gorm:"not null"`
}
我们还定义了两个用于时间序列值的 getter,这在计算统计数据时将非常有用。
func (v TimeSeriesValue) GetTime() time.Time { return v.Time }
func (v TimeSeriesValue) GetValue() float64 { return v.Value }
在database.go中,我们定义了一个获取数据库的函数(在示例中使用SQLite),并应用迁移以创建每个模型的 SQL 表(如果尚不存在)。
func GetDatabase(dbFile string) (*gorm.DB, error) {
return gorm.Open(sqlite.Open(dbFile), &gorm.Config{})
}
func AutoMigrate(db *gorm.DB) {
db.AutoMigrate(&models.TimeSeries{}, &models.TimeSeriesValue{})
}
统计量
注意:如果你只对 CRUD 部分感兴趣,可以跳过这一部分。
由于统计量的计算是一个常见任务,为了提高代码的可重用性,我们将其开发为一个“独立”包,这意味着它不依赖于项目中定义的其他包。为此,我们定义一个TsValue接口:统计函数的输入是一个值的切片,每个值必须定义接口方法GetTime和GetValue。
// Interface for a data point in a time series
type TsValue interface {
GetTime() time.Time
GetValue() float64
}
现在我们为每个统计量定义一个函数:
Count:系列的长度。
func Count(values []TsValue) float64 {
return float64(len(values))
}
Min:最小值(以及Max,其功能类似)。
func Min(values []TsValue) float64 {
if len(values) == 0 {
return math.NaN()
}
min := values[0].GetValue()
for _, value := range values {
if value.GetValue() < min {
min = value.GetValue()
}
}
return min
}
Mean:系列的均值。
func Mean(values []TsValue) float64 {
count := Count(values)
if count == 0 {
return math.NaN()
}
sum := 0.0
for _, value := range values {
sum += value.GetValue()
}
return sum / count
}
StandardDeviation:系列的标准差。
func StandardDeviation(values []TsValue) float64 {
count := Count(values)
if count == 0 {
return math.NaN()
}
mean := Mean(values)
sumSquare := 0.0
for _, value := range values {
sumSquare += math.Pow(value.GetValue()-mean, 2)
}
return math.Sqrt(sumSquare / count)
}
我们可以将统计量集中在一个函数中:
func ComputeStatistics(values []TsValue) map[string]float64 {
result := make(map[string]float64)
result["Count"] = Count(values)
result["Min"] = Min(values)
result["Max"] = Max(values)
result["Mean"] = Mean(values)
result["StdDev"] = StandardDeviation(values)
return result
}
请注意,函数ComputeStatistics是顺序计算统计量的。如果每个系列中的值很多,这可能会变得耗时。如果是这样,我们可以使用 goroutines 并发计算统计量。
让我们定义一个并发版本ComputeStatisticsConcurrent,其执行以下步骤:
-
创建一个通道来收集结果,一个同步 goroutines 的等待机制,以及一个包含要计算的统计量的映射。
-
设置一个计数器来计算要计算的统计量数量。
-
为每个统计函数启动一个 goroutine。
-
设置一个特殊的 goroutine,以便在所有统计量计算完成后(即计数器归零时)关闭通道。
-
从通道中收集结果。收集在通道关闭时停止。
func ComputeStatisticsConcurrent(values []TsValue) map[string]float64 {
result := make(map[string]float64)
// 1) Create a WaitGroup to sync, a Channel to collect results, and
// a map with the statistics to compute.
var wg sync.WaitGroup
ch := make(chan struct {
string
float64
})
statsToCompute := map[string]func([]TsValue) float64{
"Count": Count,
"Min": Min,
"Max": Max,
"Mean": Mean,
"StandardDeviation": StandardDeviation,
}
// 2) Set how many stats to compute.
wg.Add(len(statsToCompute))
// 3) Compute each stat in a separate goroutine.
for statName, statFunc := range statsToCompute {
go func(name string, operation func([]TsValue) float64) {
defer wg.Done()
value := operation(values)
ch <- struct {
string
float64
}{name, value}
}(statName, statFunc)
}
// 4) Set a goroutine to close the channel once all the
// stats are computed.
go func() {
wg.Wait()
close(ch)
}()
// 5) Collect the results from the channel.
for stat := range ch {
result[stat.string] = stat.float64
}
return result
}
注意:如果你想在不使用大数据的情况下测试性能提升,可以在每个统计函数内部运行time.Sleep进行模拟。
处理程序
现在让我们定义 REST API 的端点。因为:
-
端点逻辑依赖于数据库。
-
我们不想使用全局变量。
-
端点函数必须只接受 Gin 上下文指针作为输入,该指针包含请求信息。
我们创建一个数据库包装器,并将端点方法分配给它:
// Create a new type that embeds database.Database to assign new
// methods to it
type wrapDB struct {
DB *gorm.DB
}
我们定义一个函数来设置所有端点。每个端点指定一个路径,并指定一个接受 Gin 上下文指针的函数。
func SetEndpoints(r *gin.Engine, db *gorm.DB) {
wrapdb := &wrapDB{DB: db}
// timeseries CRUD
r.GET("/timeseries", wrapdb.listTimeSeries)
r.GET("/timeseries/:tsid", wrapdb.getTimeSeries)
r.POST("/timeseries", wrapdb.postTimeSeries)
r.PUT("/timeseries/:tsid", wrapdb.putTimeSeries)
r.DELETE("/timeseries/:tsid", wrapdb.deleteTimeSeries)
// timeseries values CRUD
r.GET("/timeseries/:tsid/values", wrapdb.listTimeSeriesValues)
r.GET("/timeseries/:tsid/values/:valueid", wrapdb.getTimeSeriesValue)
r.POST("/timeseries/:tsid/values", wrapdb.postTimeSeriesValues)
r.PUT("/timeseries/:tsid/values/:valueid", wrapdb.putTimeSeriesValue)
r.DELETE("/timeseries/:tsid/values/:valueid", wrapdb.deleteTimeSeriesValue)
// statistics
r.GET("/timeseries/:tsid/statistics", wrapdb.getTimeSeriesStats)
}
处理程序:系列
让我们从定义时间序列端点开始。下面是列出所有可用时间序列的实现。当调用DB.Find时,时间序列存储在tseriesList切片中。
func (db *wrapDB) listTimeSeries(c *gin.Context) {
var tseriesList []models.TimeSeries
if err := db.DB.Find(&tseriesList).Error; err != nil {
c.JSON(500, gin.H{"error": "Failed to retrieve timeseries"})
return
}
c.JSON(200, tseriesList)
}
在开发TimeSeries的基础 CRUD 端点之前,让我们定义一个辅助函数来从请求中获取系列 ID 并执行一些检查。
func (db *wrapDB) checkTimeSeriesID(c *gin.Context) (int, error) {
id := c.Param("tsid")
// Check if the time series exists can be converted to
timeSeriesID, err := strconv.Atoi(id)
if err != nil {
c.JSON(400, gin.H{"error": "Invalid time series ID"})
return timeSeriesID, err
}
// Check if the time series exists in the TimeSeries table
var timeSeries models.TimeSeries
if err := db.DB.First(&timeSeries, timeSeriesID).Error; err != nil {
c.JSON(404, gin.H{"error": "Time series not found"})
return timeSeriesID, err
}
return timeSeriesID, nil
}
我们现在可以定义TimeSeries的基本 CRUD 操作:
- 创建:POST 一个新的系列。
func (db *wrapDB) postTimeSeries(c *gin.Context) {
var tseries models.TimeSeries
c.BindJSON(&tseries)
if err := db.DB.Create(&tseries).Error; err != nil {
c.JSON(500, gin.H{"error": "Failed to create timeseries"})
return
}
c.JSON(201, tseries)
}
- 读取:GET 一个现有系列。
func (db *wrapDB) getTimeSeries(c *gin.Context) {
timeSeriesID, err := db.checkTimeSeriesID(c)
if err != nil {
return
}
var tseries models.TimeSeries
if err := db.DB.Where("id = ?", timeSeriesID).First(&tseries).Error; err != nil {
c.JSON(404, gin.H{"error": "Failed to retrieve timeseries"})
return
}
c.JSON(200, tseries)
}
- 更新:PUT 一个现有系列。
func (db *wrapDB) putTimeSeries(c *gin.Context) {
timeSeriesID, err := db.checkTimeSeriesID(c)
if err != nil {
return
}
var tseries models.TimeSeries
if err = db.DB.Where("id = ?", timeSeriesID).First(&tseries).Error; err != nil {
c.JSON(404, gin.H{"error": "Time series not found"})
return
}
c.BindJSON(&tseries)
if err = db.DB.Save(&tseries).Error; err != nil {
c.JSON(500, gin.H{"error": "Error while saving"})
}
c.JSON(200, tseries)
}
- 删除:删除一个系列及其值。为了在运行时出现错误时进行回滚,我们在数据库事务中执行这两个操作。
// Delete a time series and its values
func (db *wrapDB) deleteTimeSeries(c *gin.Context) {
timeSeriesID, err := db.checkTimeSeriesID(c)
if err != nil {
return
}
var tseries models.TimeSeries
// Delete in a transaction
db.DB.Transaction(func(tx *gorm.DB) error {
// Delete values
if err := tx.Where("time_series_id = ?", timeSeriesID).
Delete(&models.TimeSeriesValue{}).Error; err != nil {
c.JSON(500, gin.H{"error": "Deleting the timeseries failed."})
return err
}
// Delete timeseries
if err := tx.Where("id = ?", timeSeriesID).
Delete(&tseries).Error; err != nil {
c.JSON(500, gin.H{"error": "Deleting the timeseries failed."})
return err
}
c.JSON(200, gin.H{"id #" + strconv.Itoa(timeSeriesID): "deleted"})
return nil
})
}
处理程序:值
我们现在可以定义 TimeSeriesValues 的端点。由于它们与我们刚刚定义的端点类似,我们仅展示 Create 方法的实现,它与系列方法不同,因为我们允许同时发布多个值。
func (db *wrapDB) postTimeSeriesValues(c *gin.Context) {
timeSeriesID, err := db.checkTimeSeriesID(c)
if err != nil {
return
}
// Bind the request body to a slice of TimeSeriesValue
var timeSeriesValues []models.TimeSeriesValue
if err := c.ShouldBindJSON(&timeSeriesValues); err != nil {
c.JSON(400, gin.H{"error": "Invalid request data"})
return
}
// Set the TimeSeriesID for the posted values
for i := range timeSeriesValues {
timeSeriesValues[i].TimeSeriesID = timeSeriesID
}
// Create the values in the database
if err := db.DB.Create(&timeSeriesValues).Error; err != nil {
c.JSON(500, gin.H{"error": "Failed to create time series values"})
return
}
c.JSON(201, gin.H{"message": "Time series values created"})
}
处理程序:统计信息
我们需要定义的最后一个处理程序是计算给定系列的统计信息。根据其 ID,我们查询值并使用之前实现的 ComputeStatisticsConcurrent 函数计算统计信息。
func (db *wrapDB) getTimeSeriesStats(c *gin.Context) {
timeSeriesID, err := db.checkTimeSeriesID(c)
if err != nil {
return
}
// Query the database.Database to retrieve specific fields of time series values
var tsValues []models.TimeSeriesValue
if err := db.DB.Where("time_series_id = ?", timeSeriesID).
Select("id, time, value").Find(&tsValues).Error; err != nil {
c.JSON(500, gin.H{"error": "Failed to retrieve time series values"})
return
}
// Convert the slice of TimeSeriesValue to a slice of TsValue interfaces
values := []stats.TsValue{}
for _, v := range tsValues {
values = append(values, v)
}
c.JSON(200, serializeMap(stats.ComputeStatisticsConcurrent(values)))
}
在返回统计信息之前,我们将 serializeMap 函数应用于输出。这是为了将 NaN 值转换为 nil,以便使输出可以进行 JSON 序列化。注意,映射中的输出值类型是空接口 interface{},它可以容纳任何类型的值。
// SerializeMap serializes a map from string to float64,
// returning null for NaN values
func serializeMap(data map[string]float64) map[string]interface{} {
serializedData := make(map[string]interface{})
for key, value := range data {
if math.IsNaN(value) {
serializedData[key] = nil
} else {
serializedData[key] = value
}
}
return serializedData
}
使用方法
使应用程序可用的最后一段代码是 main 函数,用于运行 Web 服务器。我们提供了三个可选的命令行参数:
-
db:指定数据库文件的名称(默认:timeseries.db) -
proxy:设置服务器的受信任代理(默认:127.0.0.1) -
port:定义服务器监听的端口号(默认:8080)
func main() {
// Define flags for command-line arguments
dbName := flag.String("db", "timeseries.db", "Database name")
trustedProxy := flag.String("proxy", "127.0.0.1", "Trusted proxy")
port := flag.String("port", "8080", "Port number")
flag.Parse()
// Initialize the database
db, err := database.GetDatabase(*dbName)
if err != nil {
panic("failed to connect to the database")
}
database.AutoMigrate(db)
// Set the trusted proxies
router := gin.Default()
router.SetTrustedProxies([]string{*trustedProxy})
// Set endpoints and run the server
handlers.SetEndpoints(router, db)
router.Run(fmt.Sprintf(":%s", *port))
}
构建
你现在可以通过以下方式运行 API:
-
使用
go mod init初始化一个新的 Go 模块。这将生成一个带有模块名称和 Go 版本的 go.mod 文件。 -
使用
go mod tidy跟踪依赖项。这将把项目依赖项添加到 go.mod 文件中,并创建一个包含所有依赖项校验和的 go.sum 文件(用于确保其完整性)。 -
使用
go build构建程序并运行生成的可执行文件。这也可以通过go run main.go一步完成。
go.mod 和 go.sum 文件已在 项目仓库 中提供。你可以在 这里 阅读更多关于 go 命令的信息。
我们还可以使用 Docker 运行 API。在下面显示的 Dockerfile 中,首先,我们使用 Go 官方镜像构建可执行文件,然后将其复制到一个最小的镜像中,这样得到一个约 20Mb 的小文件。
# Stage 1: Build the Go application
FROM golang:1.21-alpine3.18 AS build
WORKDIR /app
COPY . .
RUN apk add build-base
# Download (and cache) all dependencies.
RUN go mod download
# Build the Go app with static linking
RUN go build -ldflags="-w -s" -o tsapi
# Stage 2: Create a minimal image for running the application
FROM alpine:3.18
WORKDIR /app
COPY --from=build /app/tsapi .
EXPOSE 8080
CMD ["./tsapi"]
我们可以按照以下步骤使用 Docker 构建和执行 API:
# Build the docker image
docker build -t go-timeseries-api .
# Run the docker image
docker run -p 8080:8080 go-timeseries-api
示例用法
应用程序启动并运行后,我们可以通过 API 调用来测试它。下面我们展示了一个使用 curl 的示例。我们创建一个时间序列,发布其值,计算统计信息,最后删除它。
# Create a time series
curl -i -X POST http://localhost:8080/timeseries -d '{ "Name": "My time series"}'
> {"ID":1,"Name":"My time series"}
# Create time series values
curl -i -X POST http://localhost:8080/timeseries/1/values -d '[{"Time": "2023-10-28T12:00:00Z", "Value": 10.0},{"Time": "2023-10-28T12:15:00Z", "Value": 20.5}]'
> {"message":"Time series values created"}
# Get statistics
curl -i -GET http://localhost:8080/timeseries/1/statistics
> {"Count":2,"Max":20.5,"Mean":15.25,"Min":10,"StandardDeviation":5.25}
# Delete a time series
curl -i -X DELETE http://localhost:8080/timeseries/1
> {"id #1":"deleted"}
关于学习 Go 的说明
这是我在 Go 中的第一个项目。因此,如果你有任何反馈,我将非常感激!
我来自 Python 背景,我特别喜欢 Go 的速度和轻量性,并发现它比 C++ 学起来容易多了。以下是一些帮助我入门的资源:
-
Go 之旅:对 Go 的互动式介绍,也是官方的
本文中使用的完整代码可以在 这里 找到。
喜欢这篇文章吗? 查看我的其他文章 并关注我以获取更多内容! 点击这里 以无限阅读文章并以零额外费用支持我 ❤️
使用 Node.js 构建语音转文本网页应用
原文:
towardsdatascience.com/build-a-speech-to-text-web-app-using-node-js-210f8c054d79
我们来构建一个使用 OpenAI 的 Whisper 模型来转录和翻译音频的网页应用
·发表于 Towards Data Science ·阅读时间 11 分钟·2023 年 3 月 27 日
--

大家好!希望你们一切都好。今天,我们将使用 Node.js 和 OpenAI 的 API 构建一个语音转文本网页应用。我们将使用 OpenAI 的 API 来使用其 Whisper 模型,该模型允许我们上传 mp3 格式的音频文件,并提供其文字稿。它甚至可以将其他语言的音频翻译成英文文本,实在令人惊叹。
首先,我们将设置一个新的 Node.js 项目,以便我们可以开始构建我们的应用。因此,我们将创建一个文件夹,在其中构建我们的项目,并使用命令行进入该文件夹,然后我们可以使用以下命令来设置一个新的 Node.js 项目:
npm init
运行此命令后,它将询问一些问题,例如应用程序的名称、入口点等。我们现在可以将其保留为默认设置。之后,你会看到它创建了一个package.json文件。该文件包含有关我们应用程序的信息以及我们为该应用程序安装了哪些包。
接下来的一步是将必要的 Node 模块,即包,安装到我们的应用中,以便我们可以开始构建应用。我们可以通过运行以下命令来完成:
npm install express multer openai cors --save
我们安装这四个包,并使用--save将这些包添加到package.json文件中。这使得克隆仓库的人只需运行一次npm install命令即可轻松安装所有所需的包。
我们还希望在应用程序中使用 nodemon 包,以帮助我们在检测到代码更改时自动刷新和重新加载服务器,这样我们就不需要在每次更改后手动重新启动服务器。因此,我们将其添加为开发依赖项,因为它只用于开发中的帮助,我们不会直接在代码中使用它。我们可以使用以下命令安装它:
npm install --save-dev nodemon
现在我们拥有了开始开发所需的所有包。正如我们在package.json文件中看到的,里面列出了我们安装的所有模块和包,以及一些关于应用程序的细节。package.json文件应如下所示:
{
"name": "speechtext",
"version": "1.0.0",
"description": "",
"main": "index.js",
"scripts": {
"test": "echo \"Error: no test specified\" && exit 1"
},
"author": "",
"license": "ISC",
"dependencies": {
"cors": "².8.5",
"express": "⁴.18.2",
"multer": "¹.4.5-lts.1",
"openai": "³.2.1"
},
"devDependencies": {
"nodemon": "².0.22"
}
}
所以,如我们所见,index.js是写在main中,表示index.js文件是我们应用程序的入口点。如果你还记得,这在我们运行npm init命令时的设置过程中有所要求。如果你将其留作默认设置,你将拥有相同的入口点;否则,你将拥有当时定义的入口点。
现在,我们将创建一个名为index.js的新文件。你可以根据自己定义的入口点命名文件。我们考虑使用index.js。
index.js
所以,我们现在将开始构建index.js文件。我们将从将所需模块导入到我们的应用程序开始。对于索引文件,我们需要express和cors。因此,我们首先需要这两个模块:
const express = require('express');
const cors = require('cors');
接下来,我们将创建一个新的express应用程序实例。同时,我们将设置应用程序以使用 cors,处理 json 数据,并使public文件夹包含静态文件,这些文件可以由客户端或前端访问。
const app = express();
app.use(express.static('public'));
app.use(express.json());
app.use(cors());
接下来,我们希望有一个单独的文件来定义 API。我们将创建一个名为routes的文件夹,并在其中创建一个名为api.js的文件,在该文件中定义应用程序所需的 GET 和 POST API。为了让应用程序知道这一点,我们将添加这一行代码,以定义基本 URL 和所有 API 定义文件的位置。这是一个中间件,帮助我们设置应用程序的路由。
app.use('/', require('./routes/api'));
接下来,我们使用一个错误处理的中间件函数,该函数将用于处理应用程序中发生的任何错误。
app.use(function(err,req,res,next){
res.status(422).send({error: err.message});
});
最后,我们设置应用程序以监听指定端口号的传入请求,我们可以通过使用环境变量来设置,或者直接定义。
app.listen(process.env.PORT || 4000, function(){
console.log('Ready to Go!');
});
我们为我们的应用程序使用了 4000 端口。我们还在其中添加了一个简单的console.log,当应用程序准备好接收请求时,它会在控制台上打印一条消息。
完整的index.js文件:
const express = require('express');
const cors = require('cors');
const app = express();
app.use(express.static('public'));
app.use(express.json());
app.use(cors());
app.use('/', require('./routes/api'));
app.use(function(err,req,res,next){
res.status(422).send({error: err.message});
});
app.listen(process.env.PORT || 4000, function(){
console.log('Ready to Go!');
});
接下来,我们将转到api.js文件,该文件位于routes文件夹内。
api.js
所以,我们现在将开始构建api.js文件。我们将通过将所需模块导入文件来开始。我们将导入express、multer和openai库。
const express = require("express");
const multer = require("multer");
const { Configuration, OpenAIApi } = require("openai");
Multer 是一个中间件,我们使用它来处理 multipart/form-data,因为我们将处理音频文件的上传。
从 openai 中,我们需要 Configuration 和 OpenAIApi 模块,用于向 Whisper 模型发送 API 请求。
然后,我们将设置 express 路由器,并创建一个 multer 中间件的实例。
const router = express.Router();
const upload = multer();
接下来,我们将配置 OpenAI 并创建一个新的配置实例。我们需要一个 OpenAI 密钥,我们必须在这里将其作为 API 密钥。你可以从 这里 获取密钥。
const configuration = new Configuration({
apiKey: process.env.OPENAI_KEY,
});
现在,我们创建一个异步函数,该函数接受一个包含歌曲数据的缓冲区,并返回当我们调用 OpenAI 的 Whisper 模型 API 时收到的响应。
async function transcribe(buffer) {
const openai = new OpenAIApi(configuration);
const response = await openai.createTranscription(
buffer, // The audio file to transcribe.
"whisper-1", // The model to use for transcription.
undefined, // The prompt to use for transcription.
'json', // The format of the transcription.
1, // Temperature
'en' // Language
)
return response;
}
如上所示,我们首先通过使用在代码中之前定义的配置创建 OpenAI 类的新实例。然后我们调用 OpenAI 的 createTranscription 函数,并在其中使用 await 关键字,以便我们在继续之前等待响应。
我们在函数中传递所需的参数,其中包括包含歌曲数据的缓冲区,以及用于转录的模型,在我们的情况下是 whisper-1。然后我们将 prompt 留空。如果愿意,你也可以提供一个 prompt,这将帮助模型通过提供与 prompt 类似的风格来更好地转录音频。我们将接收的数据类型定义为 json,将温度设置为 1,并定义我们希望输出的语言。
接下来,我们将定义 GET 请求。我们使用 sendFile 发送一个包含我们表单的 HTML 文件,用户可以在其中上传音频文件。我们稍后将构建这些 HTML 文件。我们将它提供在基本 URL 上。
router.get("/", (req, res) => {
res.sendFile(path.join(__dirname, "../public", "index.html"));
});
接下来,我们定义 POST 请求,它将处理音频文件的上传。我们使用 multer 中间件来管理文件上传部分。然后我们从音频文件创建一个缓冲区,该缓冲区将包含可以发送到 OpenAI API 的音频文件数据。我们使用上传音频文件的原始名称为缓冲区设置一个名称。
然后我们调用 transcribe 函数,一旦获得响应,我们将 JSON 数据发送回客户端。我们将转录文本和音频文件名返回到前端。我们还有一个 catch 方法来处理任何错误。
router.post("/", upload.any('file'), (req, res) => {
audio_file = req.files[0];
buffer = audio_file.buffer;
buffer.name = audio_file.originalname;
const response = transcribe(buffer);
response.then((data) => {
res.send({
type: "POST",
transcription: data.data.text,
audioFileName: buffer.name
});
}).catch((err) => {
res.send({ type: "POST", message: err });
});
});
最后,我们导出 router 模块,这样其他文件就可以导入它们。
module.exports = router;
所以,api.js 文件的完整代码是:
const express = require("express");
const multer = require("multer");
const { Configuration, OpenAIApi } = require("openai");
const router = express.Router();
const upload = multer();
const configuration = new Configuration({
apiKey: process.env.OPENAI_KEY,
});
async function transcribe(buffer) {
const openai = new OpenAIApi(configuration);
const response = await openai.createTranscription(
buffer, // The audio file to transcribe.
"whisper-1", // The model to use for transcription.
undefined, // The prompt to use for transcription.
'json', // The format of the transcription.
1, // Temperature
'en' // Language
)
return response;
}
router.get("/", (req, res) => {
res.sendFile(path.join(__dirname, "../public", "index.html"));
});
router.post("/", upload.any('file'), (req, res) => {
audio_file = req.files[0];
buffer = audio_file.buffer;
buffer.name = audio_file.originalname;
const response = transcribe(buffer);
response.then((data) => {
res.send({
type: "POST",
transcription: data.data.text,
audioFileName: buffer.name
});
}).catch((err) => {
res.send({ type: "POST", message: err });
});
});
module.exports = router;
现在,我们已经完成了所有的后端部分。接下来我们将编写 HTML 文件和一些前端 JavaScript 代码,以处理表单提交和数据保存在 本地存储 以及从本地存储中检索数据。
我们创建一个 public 文件夹,在其中我们将创建两个 HTML 文件——index.html 和 transcribe.html。
我们将从 index.html 文件开始:
index.html
在这个文件中,我们将构建一个页面,展示上传音频文件的表单。我们将使用 Bootstrap CSS,通过 CDN 引入它。我们还在 HTML 文件的末尾通过 CDN 引入 Bootstrap JS。
然后,我们创建了一个简单的卡片,要求用户上传音频文件。我确保提交的文件格式为.mp3,因为这是 OpenAI 的 API 唯一接受的格式。我们显示一个按钮,点击后提交表单。
然后,我们有了处理表单提交的 JavaScript 代码。所以,首先,我们通过阻止表单提交事件的默认行为来停止页面刷新。然后,我们获取表单数据,即音频文件,并将其作为 POST 请求发送到后台。接着,我们等待响应并将其存储在数据变量中。
如果数据中有可用的转录文本,我们将转录文本和音频文件名存储在本地存储中,以便在需要显示转录文本的下一页访问它们。虽然有多种方法可以传递信息,比如通过 URI 传递信息,但在这里我们使用本地存储来实现。
保存数据到本地存储后,我们更改窗口位置以加载transcribe.html文件。
<!DOCTYPE html>
<html>
<head>
<title>Speech to Text</title>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0-alpha1/dist/css/bootstrap.min.css" rel="stylesheet" integrity="sha384-GLhlTQ8iRABdZLl6O3oVMWSktQOp6b7In1Zl3/Jr59b6EGGoI1aFkw7cmDA6j6gD" crossorigin="anonymous">
</head>
<body style="background-color: #f2f2f2;">
<div class="container mt-5">
<div class="row justify-content-center">
<div class="col-md-6">
<div class="card">
<div class="card-header">
Upload Audio File
</div>
<div class="card-body">
<form id="transcription-form" enctype="multipart/form-data">
<div class="form-group">
<label for="file-upload"><b>Select file:</b></label>
<input id="file-upload" type="file" name="file" class="form-control-file" accept=".mp3" style="margin-bottom: 20px">
</div>
<input type="submit" value="Transcribe" class="btn btn-primary"></input>
</form>
</div>
</div>
</div>
</div>
</div>
<script>
document.getElementById("transcription-form").addEventListener("submit", async function (event) {
event.preventDefault();
const formData = new FormData(event.target);
const response = await fetch("/", {
method: "POST",
body: formData,
});
const data = await response.json();
if (data.transcription) {
localStorage.setItem("transcription", data.transcription);
localStorage.setItem("audioFileName", data.audioFileName);
window.location.href = "/transcribe.html";
}
else {
console.error("Error:", data.message);
}
});
</script>
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0-alpha1/dist/js/bootstrap.bundle.min.js" integrity="sha384-w76AqPfDkMBDXo30jS1Sgez6pr3x5MlQ1ZAGC+nuZB+EYdgRZgiwxhTBTkF7CXvN" crossorigin="anonymous"></script>
</body>
</html>
所以上述代码构建了index.html文件,该文件将向用户显示一个表单,用户可以通过这个表单上传音频文件。
这是其中一个截图的效果:

音频上传页面 — index.html
接下来,我们将构建transcribe.html文件。
transcribe.html
在这个文件中,我们将显示用户上传的音频文件的转录文本。因此,我们将再次使用 Bootstrap CSS 和 JS,通过 CDN 引入这些资源。
然后,我们定义了一些自定义 CSS 来美化元素,使其看起来更好。我们在一个容器中显示音频文件名和该音频文件的转录文本。
在此页面底部的 JavaScript 代码中,我们从本地存储中获取音频文件名和转录文本,并使用 id 将这些数据推送到相应的 HTML 元素中。
<!DOCTYPE html>
<html>
<head>
<title>Transcription</title>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0-alpha1/dist/css/bootstrap.min.css" rel="stylesheet" integrity="sha384-GLhlTQ8iRABdZLl6O3oVMWSktQOp6b7In1Zl3/Jr59b6EGGoI1aFkw7cmDA6j6gD" crossorigin="anonymous">
<style>
h1 {
margin-top: 20px;
margin-bottom: 10px;
font-size: 2.5rem;
font-weight: bold;
color: #333;
}
p {
font-size: 1.2rem;
color: #333;
margin-bottom: 30px;
}
.container {
margin-top: 50px;
margin-bottom: 50px;
max-width: 600px;
padding: 30px;
background-color: #fff;
box-shadow: 0 0 10px rgba(0,0,0,0.2);
border-radius: 5px;
}
</style>
</head>
<body style="background-color: #f2f2f2;">
<div class="container">
<h1>Audio File:</h1>
<p id="audioFileName"></p>
<h1>Transcription:</h1>
<p id="transcription"></p>
</div>
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0-alpha1/dist/js/bootstrap.bundle.min.js" integrity="sha384-w76AqPfDkMBDXo30jS1Sgez6pr3x5MlQ1ZAGC+nuZB+EYdgRZgiwxhTBTkF7CXvN" crossorigin="anonymous"></script>
<script>
const audioFileName = localStorage.getItem("audioFileName");
const transcription = localStorage.getItem("transcription");
document.getElementById("audioFileName").innerHTML = audioFileName;
document.getElementById("transcription").innerHTML = transcription;
</script>
</body>
</html>
我尝试了两个不同的小音频文件,一个是英语录制的,另一个是印地语录制的。虽然第二个音频文件是用印地语录制的,但我希望看到英语的输出,从而测试其翻译能力。它在转录两个音频文件时都非常准确。不过,在多次运行中,有时会产生不正确的模糊转录,但许多时候,转录结果大多是正确的。
我在下面附上了转录的屏幕截图。这些截图并不完全正确,但我认为它们的转录准确率大约为 85%–90%。

英语音频文件的转录

印地语音频文件的英语转录
因此,我们已经成功使用 OpenAI 的 API 和 Node.js 构建了一个语音转文本的 Web 应用程序。希望你喜欢构建这个应用,并从这篇文章中学到了一些新东西。你还可以调整参数进行实验,并比较结果,以更好地了解在不同场景下什么效果更好。
感谢阅读本文。接下来你必须阅读的一些文章包括:
[## 使用 Next.js 构建一个简单的 Todo 应用
让我们使用 Next.js 构建一个简单的 Todo 应用,这将教会你 CRUD(创建、读取、更新和...
如何使用 Next.js 构建一个简单的 Todo 应用 [## 如何在 Django 中实现多用户类型
学习如何在 Django 中实现多用户类型,处理认证,并根据用户类型进行重定向。
如何在 Django 中实现多用户类型 [## 使用 React 构建一个简单的 Todo 应用
让我们使用 React 构建一个简单的 Todo 应用,这将教会你 CRUD(创建、读取、更新和...
使用 React 构建一个简单的 Todo 应用 [## 使用 Django 构建一个社交媒体网站 — 项目设置(第一部分)
在第一部分中,我们专注于设置项目和安装所需组件,通过设置密码...
构建一个社交媒体网站使用 Django — 项目设置(第一部分)
使用 LangChain 和 GPT-3 构建一个透明的文档问答机器人
开发一个信息丰富的问答机器人指南,并显示所使用的来源
·发表于Towards Data Science ·阅读时间 11 分钟·2023 年 7 月 22 日
--

问答系统在分析大量数据或文档时可以大有帮助。然而,模型用来生成答案的来源(即文档的部分内容)通常不会在最终答案中显示。
理解回应的背景和来源不仅对寻求准确信息的用户有价值,也对希望持续改进其 QA 机器人(问答机器人的开发者)有帮助。通过回答中包含的来源,开发者可以深入了解模型的决策过程,从而促进迭代改进和微调。
本文展示了如何使用 LangChain 和 GPT-3(text-davinci-003)创建一个透明的问答机器人,通过两个示例展示了生成答案时所使用的来源。
在第一个示例中,你将学习如何创建一个透明的 QA 机器人,利用你的网站内容回答问题。在第二个示例中,我们将探讨使用来自不同 YouTube 视频的转录文本,包括带有和不带有时间戳的文本。
处理数据并创建向量存储
在我们能够利用像 GPT-3 这样的语言模型的能力之前,我们需要将文档(例如网站内容或 YouTube 转录文本)以正确的格式(首先分块,然后生成嵌入)处理,并将其存储在向量存储中。下图 1 展示了从左到右的处理流程。

图 1. 数据处理和向量存储创建的流程图(图像来源于作者)。
网站内容示例
在这个示例中,我们将处理网络门户It’s FOSS的内容,该门户专注于开源技术,特别是 Linux。
首先,我们需要获取一个所有待处理的文章列表并存储在我们的向量存储中。下面的代码读取了sitemap-posts.xml文件,该文件包含了所有文章的链接列表。
import xmltodict
import requests
r = requests.get("https://news.itsfoss.com/sitemap-posts.xml")
xml = r.text
rss = xmltodict.parse(xml)
article_links = [entry["loc"] for entry in rss["urlset"]["url"]]
在本文撰写时,列表中包含了超过 969 个文章链接。
有了链接列表,我们现在可以编写一个小的辅助函数,称为extract_content,它使用BeautifulSoup从文章页面中提取包含相关内容的特定元素。
from bs4 import BeautifulSoup
from tqdm.notebook import tqdm
def extract_content(url):
html = requests.get(url).text
soup = BeautifulSoup(html, features="html.parser")
elements = [
soup.select_one(".c-topper__headline"),
soup.select_one(".c-topper__standfirst"),
soup.select_one(".c-content"),
]
text = "".join([element.get_text() for element in elements])
return text
articles = []
# Limited the list of > 900 articles to 10 for this example
for url in tqdm(article_links[0:10], desc="Extracting article content"):
articles.append({"source": url, "content": extract_content(url)})
最后一步,我们遍历链接列表,并将我们的辅助函数extract_content应用于每个 URL。为了演示目的,我将列表限制为 10 个元素。如果你想爬取所有文章,只需从article_links[0:10]中删除[0:10]。
articles列表现在包含每篇文章的字典,字典中包括"source"(文章链接)和"content"(文章内容)。文章的链接将在最终答案中作为来源显示。
由于 GPT-3 有一个令牌限制(4,096 个令牌),因此将长文章分成块是有意义的。这些块将被组合在一起并发送给 GPT-3。
下面的代码将文章内容拆分成几个块。
from langchain.text_splitter import RecursiveCharacterTextSplitter
rec_splitter = RecursiveCharacterTextSplitter(chunk_size=1500,
chunk_overlap=150)
web_docs, meta = [], []
for article in tqdm(articles, desc="Splitting articles into chunks"):
splits = rec_splitter.split_text(article["content"])
web_docs.extend(splits)
meta.extend([{"source": article["source"]}] * len(splits))
我们在这里使用RecursiveCharacterTextSplitter,因为它旨在尽可能长时间地将语义相关的内容保持在一起。
一旦完成,我们只需执行以下命令,将文章及其来源存储到我们的向量存储中。
import os
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
os.environ["OPENAI_API_KEY"] = "YOUR KEY"
article_store = FAISS.from_texts(
texts=web_docs, embedding=OpenAIEmbeddings(), metadatas=meta
)
在这个示例中,我们使用FAISS作为向量存储,并使用OpenAIEmbeddings作为我们的嵌入模型。当然,也可以探索其他向量存储选项,例如Chroma,并尝试 Hugging Face 的嵌入模型解决方案。
注意:你还可以通过运行
article_store.save_local("your_name")来保存你的向量存储,这样你就不必每次使用时都重新创建它。更多详情见这里。
如果你不想处理 YouTube 转录内容,你可以跳过下面的部分并继续到下一节“运行透明问答”。
YouTube 转录示例
转录内容可以通过两种不同且独立的方式处理。第一个选项演示了如何处理 YouTube 转录内容,同时保留视频链接作为来源(例如,youtu.be/XYZ。)。
第二部分做了同样的事情,但说明了如何保留链接,包括时间戳(例如,https://youtu.be/XYZ&t=60)以获取更详细的信息。
对于这两种方法,使用了来自频道StatQuest的以下 YouTube 视频的转录:
YouTube 转录示例(不带时间戳)
第一部分非常直接。下面的代码利用了 LangChain 的 DocumentLoader YoutubeLoader,它集成了youtube-transcript-api和pytube。
from langchain.document_loaders import YoutubeLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
import os
os.environ["OPENAI_API_KEY"] = "YOUR KEY"
yt_ids = [
"OtD8wVaFm6E", # XGBoost Part 1 (of 4): Regression
"8b1JEDvenQU", # XGBoost Part 2 (of 4): Classification
"ZVFeW798-2I", # XGBoost Part 3 (of 4): Mathematical Details
"oRrKeUCEbq8", # XGBoost Part 4 (of 4): Crazy Cool Optimizations
]
yt_docs = []
for yt_id in tqdm(yt_ids, desc="Retrieving transcripts"):
splitter = CharacterTextSplitter(chunk_size=1500, chunk_overlap=150,
separator=" ")
yt_loader = YoutubeLoader(yt_id, add_video_info=True)
yt_docs.extend(yt_loader.load_and_split(splitter))
为了避免与令牌限制发生冲突,我们使用CharacterTextSplitter将数据拆分为几个块。add_video_info设置为True,以接收视频的标题和作者信息。
返回的分块转录是文档对象。在创建嵌入并将其存储在向量存储中之前,我们通过添加有关标题、作者和视频链接的信息来操作或扩展它们的元数据。
# Manipulate / extend source attribute
for doc in yt_docs:
doc.metadata["source"] = (
doc.metadata["title"]
+ " ["
+ doc.metadata["author"]
+ "] "
+ "https://youtu.be/"
+ doc.metadata["source"]
)
# Vector store
yt_store = FAISS.from_documents(yt_docs, OpenAIEmbeddings())
YouTube 转录示例(带时间戳)
第二种方法稍微复杂一些。在这里,我们使用名为youtube-transcript-api的不同包来检索转录。输出是一个包含文本、开始时间和持续时间的字典列表。我们需要切换到不同的包,因为YoutubeLoader包不返回时间戳。
示例可以在这里看到:
[
{'text': "gonna talk about XG boost part 1 we're",
'start': 14.19,
'duration': 6.21},
{'text': 'gonna talk about XG boost trees and how',
'start': 17.91,
'duration': 6.66},
...
]
从每个文本条目创建文档对象没有多大意义,因为条目太短(例如,上述示例中每个条目只有 8 个单词),不便于后续使用。在向量存储中搜索时,只返回有限数量的匹配文档(例如,4 个),信息内容不足。
因此,我们需要首先将文本条目聚合或合并成适当的文本块。下面的代码片段包含一个自定义助手函数。
# Create transcript df
def create_transcript_df(yt_transcript: list, yt_id: str):
return (
pd.DataFrame(yt_transcript)
.assign(start_dt=lambda x: pd.to_datetime(x["start"], unit="s"))
.set_index("start_dt")
.resample("3min")
.agg({"text": " ".join})
.reset_index()
.assign(start_dt=lambda x: x["start_dt"].dt.minute * 60)
.assign(
source=lambda x: "https://youtu.be/"
+ yt_id
+ "&t="
+ x["start_dt"].astype("str")
)
.drop(columns=["start_dt"])
)
该助手应用重采样来调整时间维度为 3 分钟步长。换句话说,它将转录合并为 3 分钟的文本部分。有了这个功能,我们现在可以开始获取和处理转录。
from youtube_transcript_api import YouTubeTranscriptApi
yt_ids = [
"OtD8wVaFm6E", # XGBoost Part 1 (of 4): Regression
"8b1JEDvenQU", # XGBoost Part 2 (of 4): Classification
"ZVFeW798-2I", # XGBoost Part 3 (of 4): Mathematical Details
"oRrKeUCEbq8", # XGBoost Part 4 (of 4): Crazy Cool Optimizations
]
transcript_dfs = []
for yt_id in tqdm(yt_ids, desc="Fetching transcription"):
yt_transcript = YouTubeTranscriptApi.get_transcript(yt_id)
transcript_dfs.append(create_transcript_df(yt_transcript, yt_id))
transcripts_df = pd.concat(transcript_dfs).reset_index(drop=True)
结果的摘录可以在下图中看到。

图 2. 转录数据框的摘录(图片由作者提供)。
由于合并的 3 分钟部分现在可能会导致令牌限制问题,我们需要在生成嵌入并将其存储在向量存储中之前,使用分割器再次处理它们。
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
import os
os.environ["OPENAI_API_KEY"] = "YOUR KEY"
text_splitter = CharacterTextSplitter(separator=" ", chunk_size=1500,
chunk_overlap=150)
yt_docs, yt_meta = [], []
for index, row in tqdm(transcripts_df.iterrows(), total=len(transcripts_df)):
splits = text_splitter.split_text(row["text"])
yt_docs.extend(splits)
yt_meta.extend([{"source": row["source"]}] * len(splits))
print(f"Split {row['source']} into {len(splits)} chunks")
yt_ts_store = FAISS.from_texts(yt_docs, OpenAIEmbeddings(), metadatas=yt_meta)
运行透明问答
有了填充的向量存储后,我们现在可以专注于透明问答。下面的图给出了该过程的概述。

图 3. 透明问答过程概述(图片由作者提供)。
我们首先定义一个问题,然后由嵌入模型或 API转换为嵌入。向量存储利用这个问题嵌入来搜索存储中的’n’(默认:4)相似文档或片段。随后,将每个文档或片段的内容与提示组合并发送到 GPT-3。
GPT-3 返回的结果然后与另一个提示组合,在最后一步再次发送回 GPT-3以获取最终答案,包括来源。
网站内容示例
在使用RetrievalQAWithSourcesChain之前,我们确保通过实施记忆使我们的机器人记住之前的对话。这增强了与用户的上下文相关互动。
from langchain.memory import ConversationBufferMemory
memory = ConversationBufferMemory(
memory_key="chat_history",
input_key="question",
output_key="answer",
return_messages=True,
)
为了将之前的聊天记录整合到使用的提示中,我们需要修改现有的模板。
from langchain import PromptTemplate
template = """You are a chatbot having a conversation with a human.
Given the following extracted parts of a long document and a question,
create a final answer.
{context}
{chat_history}
Human: {question}
Chatbot:"""
question_prompt = PromptTemplate(
input_variables=["chat_history", "question", "context"], template=template
)
之后,我们可以利用 RetrievalQAWithSourcesChain 来提问。在这个示例中,我们设置k=4,这意味着我们将查询向量存储以获取 4 个最相似的文档。
from langchain.chains import RetrievalQAWithSourcesChain
article_chain = RetrievalQAWithSourcesChain.from_llm(
llm=OpenAI(temperature=0.0),
retriever=article_store.as_retriever(k=4),
memory=memory,
question_prompt=question_prompt,
)
result = article_chain({"question": "What is Skiff?"},
return_only_outputs=True)
结果以字典形式返回:
{'question': 'What is Skiff?',
'answer': 'Skiff is a privacy-focused email service with unique
functionalities such as the ability to manage multiple
sessions, appearance tweaks, dark mode, white theme,
two layouts, supporting imports from Gmail, Outlook,
Proton Mail, and more, creating and managing aliases,
and connecting a crypto wallet from Coinbase, BitKeep,
Brave, and others to send/receive email utilizing Web3\.
It also includes Pages to create/store documents securely,
the ability to use Skiff's server or IPFS (decentralized
technology) for file storage, and Skiff Pages,
Encrypted Cloud Storage With IPFS Support.\n',
'sources': 'https://news.itsfoss.com/skiff-mail-review/'}
我们可以观察到,结果包含了回答给定问题所用的来源。为了生成这个最终答案,API 被调用了5 次:4 次提取最相似的 4 个片段中的相关信息,额外 1 次生成最终答案。
我们还可以提出引用之前问题的问题。
article_chain(
{"question": "What are its functionalities?"},
return_only_outputs=True,
)
结果将如下所示。
{
'answer': "Skiff offers a range of functionalities,
including Web3 integration, IPFS decentralized storage,
creating and managing aliases, connecting crypto wallets,
getting credits to upgrade your account, importing from Gmail,
Outlook, Proton Mail, and more, Pages to create/store documents securely,
encrypted cloud storage with IPFS support, and the ability to use
Skiff's server or IPFS (decentralized technology) for file storage.\n",
'sources': 'https://news.itsfoss.com/anytype-open-beta/,
https://news.itsfoss.com/skiff-mail-review/'
}
请记住,对于这些问题,API 也被调用了 5 次。
YouTube 转录示例(有时间戳和没有时间戳)
YouTube 转录示例的代码与网站的代码非常相似。首先,我们初始化 ConversationBufferMemory,并创建一个自定义问题提示模板。
from langchain.memory import ConversationBufferMemory
memory = ConversationBufferMemory(
memory_key="chat_history",
input_key="question",
output_key="answer",
return_messages=True,
)
template = """You are a chatbot having a conversation with a human.
Given the following extracted parts of a long document and a question,
create a final answer.
{context}
{chat_history}
Human: {question}
Chatbot:"""
question_prompt = PromptTemplate(
input_variables=["chat_history", "question", "context"], template=template
)
然后我们创建带有来源的 QA 链。
# Use yt_store for YouTube transcripts without timestamps or
# yt_ts_store with timestamps as sources.
yt_chain = RetrievalQAWithSourcesChain.from_llm(
llm=OpenAI(temperature=0.0),
retriever=yt_store.as_retriever(k=4),
memory=memory,
question_prompt=question_prompt,
)
让我们问一个问题。
result = yt_chain(
{
"question": "What is the difference in building a tree for a
regression case compared to a classification case?"
},
return_only_outputs=True
)
不带时间戳的示例结果:
{'answer': ' The main difference between building a tree for a regression case
and a classification case is that in a regression case, the goal
is to predict a continuous value, while in a classification case,
the goal is to predict a discrete value. In a regression case,
the tree is built by splitting the data into subsets based on
the value of a certain feature, while in a classification case,
the tree is built by splitting the data into subsets based on
the value of a certain feature and the class label.
Additionally, in a regression case,
the weights are all equal to one,
while in a classification case, the weights are the previous
probability times one minus the previous probability.\n',
'sources': 'XGBoost Part 2 (of 4): Classification [StatQuest with Josh Starmer] https://youtu.be/8b1JEDvenQU,
XGBoost Part 3 (of 4): Mathematical Details [StatQuest with Josh Starmer] https://youtu.be/ZVFeW798-2I,
XGBoost Part 4 (of 4): Crazy Cool Optimizations [StatQuest with Josh Starmer] https://youtu.be/oRrKeUCEbq8'
}
带时间戳的示例结果:
{'answer': 'The difference in building a tree for a regression case compared
to a classification case is that in a regression case, the goal
is to predict a continuous value, while in a classification case,
the goal is to predict a probability that the drug will be
effective. Additionally, the numerator for classification is the
same as the numerator for regression, but the denominator
contains a regularization parameter. The denominator for
classification is different from the denominator for regression,
and is the sum for each observation of the previously predicted
probability times 1 minus the previously predicted probability.
The only difference between building a tree for a regression case
and a classification case is the loss function.\n',
'sources': 'https://youtu.be/ZVFeW798-2I&t=0
https://youtu.be/8b1JEDvenQU&t=180
https://youtu.be/OtD8wVaFm6E&t=0'
}
结论
LangChain 的 RetrievalQAWithSourcesChain 和 GPT-3 的结合非常适合提高问答的透明度。正如过程图所示(图 3),获取最终答案需要多次调用 OpenAI。
根据你使用服务的情况和需要处理的类似文档数量,调用次数可能会增加,导致更高的费用。这确实值得关注。不过,对于你的爱好项目来说,这不应该太关键。为了更好地关注费用和发送的提示,可以考虑使用 Promptlayer 或 TruLens。
Colab 笔记本可以在这里找到:
资料来源
- LangChain RetrievalQAWithSourcesChain API 文档
api.python.langchain.com/en/latest/chains/langchain.chains.qa_with_sources.retrieval.RetrievalQAWithSourcesChain.html
所有者或创作者已经提前询问是否允许我将他们的内容/数据用作本文的示例。
-
It’s FOSS., “It’s FOSS”,
itsfoss.com/ -
StatQuest. “XGBoost 第一部分(共 4 部分):回归” YouTube, Joshua Starmer, 2019 年 12 月 16 日,
youtu.be/OtD8wVaFm6E。 -
StatQuest. “XGBoost 第二部分(共 4 部分):分类” YouTube, Joshua Starmer, 2020 年 1 月 13 日,
youtu.be/8b1JEDvenQU。 -
StatQuest. “XGBoost 第三部分(共 4 部分):数学细节” YouTube, Joshua Starmer, 2020 年 2 月 10 日,
youtu.be/ZVFeW798-2I。 -
StatQuest. “XGBoost 第四部分(共 4 部分):疯狂炫酷的优化” YouTube, Joshua Starmer, 2020 年 3 月 2 日,
youtu.be/oRrKeUCEbq8。
用 Jupyter 和 Mercury 构建 Web 应用
原文:
towardsdatascience.com/build-a-web-app-with-jupyter-and-mercury-9d59661441b7
教程
Mercury 提供了一种简单的方法,将 Jupyter Notebooks 转换为互动式 Web 应用。
·发布于 Towards Data Science ·阅读时间 10 分钟·2023 年 5 月 9 日
--

我在提及代码开发还是 CO2 排放?照片由 Etienne Girardet 供稿于 Unsplash
如果没有 Jupyter Notebooks,我们将会处于何地?它们无疑是数据科学社区的基石之一。
它们非常适合原型设计和/或逐步构建和展示数据科学应用,但在演示方面不太好。
如果你想将你的工作结果展示给利益相关者,但不想展示所有聪明的代码,那么你可以用 Streamlit 或 Dash 重新编码所有内容,或者用 Flask 或 Django 构建一个 Web 应用。
但还有一种替代方案。
Mercury 是一个可以将你的 Notebook 转换为完全互动的 Web 应用的系统。你可以添加滑块和选择框等控件,制作一个完全互动的应用程序。
这完全是 100% Python 和相当简单,尽管结果比起其他工具可能少了些复杂性,但如果你寻找的是一个简单的解决方案,值得一看。
一个示例应用
我们将创建的应用将如下图所示。

图片由作者提供
你可以看到它由两部分组成:一个包含文本和输入控件的小侧边栏——一个滑块和一个选择框;以及一个包含笔记本代码输出的主窗口。所有 Mercury 应用都遵循这种模式,即小侧边栏包含控件,而主窗口显示代码输出。
你看到的地图是一个显示国家 CO2 排放量随时间变化的分级图。侧边栏中的滑块用于选择正确的年份。
在主窗口向下滚动,你将看到一条折线图,展示了某个大陆随时间变化的排放输出。大陆可以从侧边栏的选择框中选择。
每当输入组件的值发生变化时,代码会重新运行,并显示图表的新版本。
Mercury V.2
我在大约一年前首次查看了 Mercury,并喜欢从你已经用 Jupyter 编写的代码创建 Web 应用程序的简单性。
在过去的一年里有很多分心的事物(PyScript、Streamlit 以及一些大家似乎都在谈论的 AI 应用程序),但我终于重新审视了 Mercury,并且有了一些变化。
主要的是,版本 2 具有支持与应用程序交互的组件。这些替代了版本 1 中有些笨重的 YAML 定义,并且与版本 1 不同,它们可以在 Jupyter Notebook 以及应用程序中操作。
在这篇文章中,我将使用开源版本构建一个应用程序,这意味着你的应用程序将带有 Mercury 品牌(也提供商业许可证)。
首先,编写一个 Notebook
我们将逐步构建 Notebook 单元格,然后将其转换为 Web 应用程序。
但首先,安装 Mercury 库。
pip install mercury
我将在最后提供源代码和数据的链接,但如果你想跟随教程,你需要创建一个新的 Jupyter Notebook 并安装以下库:
import mercury as mr
import plotly.express as px
import pandas as pd
这,当然,是你的第一个单元格。
下一个完全关注于 Mercury,并且在 Notebook 中不会输出任何重要内容。
# set Application parameters
app = mr.App(title="CO2 Emissions control panel",
description="Graphs showing the emissions over more than 100 years",
show_code=False,
show_prompt=False,
continuous_update=True,
static_notebook=False,
show_sidebar=True,
full_screen=True,
allow_download=False)
此单元格在 Mercury 创建应用程序的 Web 版本时设置一些信息。前两个字段是不言而喻的,这些数据会在 Mercury 版本中显示。如果你愿意接受其余字段的默认值,那么这些就是你需要包括的唯一字段。
作为参考,这里列出了所有参数及其用途。
-
title (string) — 应用程序的标题
-
description (string) — 应用程序的描述
-
show_code (boolean) — 设置为
True以在应用程序中显示 Notebook 代码。默认值为False。 -
show_prompt (boolean) — 如果你希望在应用程序中显示 Notebook 提示,请将此值设置为
True—show_code也需要设置为True。默认值为False。 -
continuous_update (boolean) — 默认值为
True,意味着在输入组件值变化后,应用程序会重新运行。如果设置为False,将显示一个运行按钮,用于用更新后的组件数据重新运行应用程序。 -
static_notebook (boolean) — 设置为
True时,应用程序在组件值变化时不会重新运行,Notebook 是静态的。默认值为False,即为交互式应用程序。 -
show_sidebar (boolean) — 应用程序由主显示和侧边栏组成,默认情况下在启动时显示。用户在应用程序运行时可以随时打开或关闭侧边栏。
-
full_screen (布尔值)——默认值为
True,表示笔记本全宽显示。设置为False将笔记本宽度限制为 1140px。 -
allow_download (布尔值)——默认情况下有一个下载按钮,可以将结果导出为 PDF 或 HTML。你可以通过设置
allow_download=False来隐藏这个按钮。
现在,我们需要下载一些数据。这里是下一个单元格。
# Get the data
def get_countries_data():
url = "./data/countries_df.csv"
return pd.read_csv(url)
def get_continent_data():
url = './data/continents_df.csv'
return pd.read_csv(url)
def get_world_data():
url = './data/world_df.csv'
return pd.read_csv(url)
def get_group_data():
url = './data/income_types_df.csv'
return pd.read_csv(url)
df_countries= get_countries_data()
df_continents= get_continent_data()
df_world = get_world_data()
df_groups = get_group_data()
我们从本地目录加载数据到 Pandas 数据框中。这些数据记录了几十年来的二氧化碳排放情况,按国家、洲和收入组分类。
每个数据框看起来类似于下面的表格。

作者提供的图片
它们包含以下列:
-
实体:国家、大洲、收入组或‘世界’
-
代码:ISO 国家代码(如果它是国家)
-
年份
-
年度二氧化碳排放量:来自燃烧化石燃料和工业过程
-
包括土地利用变化的年度二氧化碳排放量:前一列和下一列的总和
-
土地利用变化导致的年度二氧化碳排放量
这是公开可用的数据,你可以从多个来源下载,例如,我们的世界数据¹。
这里的数据比我们实际需要的稍多——我们将在这个练习中使用前两个数据框,并仅使用部分列。
小部件
有许多可用于输入和输出的小部件。我们只会查看其中的几个,但完整列表可以在Mercury 文档中找到。
这是包含选择年份滑块的小部件的单元格。
mr.Note(text="__Select a year to see the global emissions in the map__")
ymax = df_countries['Year'].max()
ymin = df_countries['Year'].min()
yearSlider = mr.Slider(value=2000, min=ymin, max=ymax, label="Year", step=1)
mr.Note(text="---")
在第一行,我们看到一个输出小部件mr.Note,它简单地显示 Markdown 文本。这里我们用它来提示用户从滑块中选择一个年份。然后我们计算数据中的年份范围,并使用这些值来显示滑块。
滑块小部件的参数比较自解释:我们从默认值开始,然后是最大值和最小值,标签,最后是步长值。
从小部件返回的值实际上是一个小部件对象。要从该对象中获取值,我们必须使用.value属性,如下所示。
最后一行是另一条 Markdown 字符串——一条将滑块与下方的小部件分开的线。
在 Jupyter 中,输出位于单元格下方,通常,你会看到文本和一个完全可操作的简单滑块,值会按预期变化。

笔记本中的滑块小部件——作者提供的图片
当我们将其转换为应用程序时,滑块将位于侧边栏中。
下一个单元格显示地图。我们首先从滑块对象中获取年份值,然后计算二氧化碳排放范围的最大值和最小值。然后,我们使用 Plotly 计算分级地图并显示它。
# Get the year and find the maximum and minimum values to set the range for the choropleth
year = yearSlider.value
max = df_countries['Annual CO₂ emissions'].max()
min = df_countries['Annual CO₂ emissions'].min()
fig1 = px.choropleth(df_countries[df_countries['Year']==year],
locations="Code",
color="Annual CO₂ emissions",
hover_name="Entity",
range_color=(min,max),
color_continuous_scale=px.colors.sequential.Blues,
width=1000, height=500,
title = f"Annual CO2 Emission for {year}")
fig1.show()
地图看起来如下图所示。

作者提供的图片
另一个小部件
按照之前类似的模式。首先是一个指令,然后我们从数据框的Entity列中计算出所有大陆的唯一列表。
这次我们使用选择小部件mr.Select
该功能接受一个默认值、一个可供选择的列表和一个标签作为参数。
mr.Note(text=f"__Select a continent to see the global emissions over the entire period__")
continents = df_continents['Entity'].unique()
select = mr.Select(value="Europe", choices=continents, label="")
selected_continent = select.value
单元格的输出如下图所示。但同样,在应用程序中,它将出现在侧边栏中。

笔记本中的选择框小部件——由作者提供的图片
然后我们使用选择框中的值来绘制适当的折线图。我们通过使用该值过滤数据框,以便仅返回所选大陆的数据。然后使用这些数据绘制图表。
df = df_continents[df_continents['Entity'] == selected_continent]
fig2 = px.line(df,"Year","Annual CO₂ emissions",
width=1000, height=500,
title=f"Annual CO₂ emissions for {selected_continent}")
fig2.show()
图表看起来像下面的图片。

由作者提供的图片
这可能听起来很明显,但如果你运行整个笔记本,你将得到与小部件的默认值对应的结果。要更改小部件值并获得具有新值的相应图表,你需要逐个单元格运行笔记本。
还有其他小部件,如输入文本或数字值的字段、按钮和复选框——详情请参见文档。
这就是 Jupyter Notebook 的所有内容。现在来创建一个网络应用程序。
创建一个网络应用程序
从命令提示符导航到包含笔记本的目录并输入:
mercury run
这将拾取该目录中的所有笔记本,并运行一个类似下面的网页,其中包含每个笔记本的面板。

由作者提供的图片
在我的情况下,只有一个笔记本——我们刚刚看到的那个——当你点击它时,它将运行应用程序。
提醒你它是这样的。

由作者提供的图片
你可以看到,小部件比笔记本中显示的要更具吸引力,并且它们出现在侧边栏中。
应用程序是完全交互式的:当你更改小部件值时,小部件下方的所有笔记本单元都会重新运行。因此,改变滑块会更新地图,选择框的选择更改时,图表会重新绘制。
我不得不承认这里有点作弊。Mercury 复制了一些演示应用程序,并创建了一个名为welcome.md的文件,其中包含欢迎信息。虽然这些对初学者很有用,但为了清晰和简单,我已经删除了笔记本文件夹中的笔记本,并从welcome.md中删除了所有内容。
结果
结果是一个简单但功能齐全的应用程序。
缺点是 Mercury 应用程序始终具有相同的布局——一个包含小部件的侧边栏和一个显示代码结果的主窗口。此外,主窗口有一个简单的线性布局,无法更改。
这并不是很灵活,虽然对许多目的来说足够,但你可以通过像 Streamlit 或 Dash 这样的替代应用构建工具创建更复杂的布局。
总的来说,如果你乐意在 Jupyter Notebooks 中编写 Python,那么 Mercury 的学习曲线几乎可以忽略不计。而且,尽管结果不如其他应用构建器那样灵活,但转换为应用的过程和在终端窗口中输入一个命令一样简单。
当然,我们的应用正在我们本地机器上的专用服务器上运行。在网上发布则是另一回事,我会在以后查看发布选项。
与此同时,感谢阅读,希望这对你有帮助。本文的代码和数据可以在我的网页上找到,此外还有一个演示应用的链接和如何在 Mercury Cloud 上发布应用的说明。你也可以在那里看到我其他工作的链接,并且可以注册我的偶尔新闻通讯。
说明
[1] 我们的数据世界 是一个信息宝库,其使命是发布“研究和数据以应对世界上最大的难题”。他们的所有工作都在创作共用 BY 许可协议下提供。
免责声明:我与 Mercury 没有关联。
使用 FastAPI 和 Angular 构建 WebSocket 应用程序
原文:
towardsdatascience.com/build-a-websocket-application-with-fastapi-and-angular-988157dce554
学习如何使用 WebSocket 协议构建双向交互通信应用程序
·发表在 Towards Data Science ·7 分钟阅读·2023 年 1 月 30 日
--

图片由 geralt 提供(移动智能手机股票交易)在 Pixabay
类似于 HTTP,WebSocket 也是一种用于客户端-服务器通信的协议。然而,与 HTTP 不同的是,WebSocket 是一种双向交互协议,允许客户端向服务器发送消息,并且可以被动地从服务器接收事件驱动的响应,而无需向服务器发出请求。
WebSocket 广泛应用于聊天和游戏中,这些场景需要实时数据。它也可以用于其他需要实时或近实时数据的领域。例如,使用历史价格预测未来价格的应用程序可以从 WebSocket 中受益。当来自客户端的数据流中出现新数据时,使用某些机器学习或深度学习模型预测的结果可以自动发送到客户端。
在本文中,我们将使用 FastAPI 和 Angular 构建一个简单的 WebSocket 应用程序,其中前者用于构建 WebSocket 服务器,后者用于客户端。这个概念可能对你来说全新,并且构建这样的应用程序可能看起来很令人畏惧。然而,正如你将在本文中看到的,它实际上并没有那么复杂,我们可以用几行代码快速构建一个 WebSocket 应用程序。
使用 FastAPI 构建 WebSocket 服务器
FastAPI 使用 Starlette 作为底层框架,支持 WebSocket,并提供了一些标准方法来接受客户端连接,接收和发送数据。
在本文中,我们将把后端和前端代码放在同一个代码库中,分别位于 backend 和 frontend 文件夹中。FastAPI 代码将放在 backend/app 文件夹里:
# backend/app/main.py
import asyncio
import logging
from datetime import datetime
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("FastAPI app")
app = FastAPI()
async def heavy_data_processing(data: dict):
"""Some (fake) heavy data processing logic."""
await asyncio.sleep(2)
message_processed = data.get("message", "").upper()
return message_processed
# Note that the verb is `websocket` here, not `get`, `post`, etc.
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
# Accept the connection from a client.
await websocket.accept()
while True:
try:
# Receive the JSON data sent by a client.
data = await websocket.receive_json()
# Some (fake) heavey data processing logic.
message_processed = await heavy_data_processing(data)
# Send JSON data to the client.
await websocket.send_json(
{
"message": message_processed,
"time": datetime.now().strftime("%H:%M:%S"),
}
)
except WebSocketDisconnect:
logger.info("The connection is closed.")
break
该代码片段的关键点:
-
在 WebSocket 端点的路径装饰器中,使用的是
@app.websocket,而不是像@app.get、@app.post等具有 HTTP 动词的装饰器。 -
WebSocket对象可以在路径操作函数中注入,它可以用来接受客户端连接,接收来自客户端的数据以及通过一些标准方法(accept、receive_json和send_json)向客户端发送数据。代码实际上相当自解释,你可以在你的应用中类似地使用它。
现在你可以使用 uvicorn 启动 WebSocket 服务器:
uvicorn main:app --reload --host 0.0.0.0 --port 8000
与 HTTP 端点不同,独立测试 WebSocket 端点没有客户端时并不那么简单。设置一个客户端并从那里进行测试更为容易。在这篇文章中,我们将使用 Angular 构建一个 WebSocket 客户端,并从那里与 WebSocket 服务器进行交互。
设置 Angular 项目
在这篇文章中,使用的是 Node.js 18.13.0 和 Angular v15.1.2。如果你想跟随并看到展示的确切结果,建议使用这些版本。然而,旧版本或新版本也应该能够正常工作。
要安装 Angular CLI,请运行:
npm install -g @angular/cli@15.1.2
然后,我们可以使用 ng 命令创建一个工作区:
ng new frontend
我们不会为这个简单的演示项目添加路由,所有其他设置可以保持默认。
我们不会创建新的组件,而是直接更新现有的 app 组件。
不过,我们将创建一个新的服务文件来管理与 WebSocket 相关的代码。
Angular 中的服务类只是一个普通的类加上一些设置,通常只有 Injectable 装饰器中的 providedIn 设置。我们可以直接在 app 文件夹中创建服务文件(websocket.service.ts)。
创建 WebSocket 服务
多亏了 RxJS 库,在 Angular 中使用 WebSocket 非常简单。我们可以从 RxJS 中导入 WebSocket 工厂函数,并使用它创建一个 WebSocketSubject,然后可以用来向 WebSocket 服务器发送数据和接收数据。
服务文件的代码如下:
// frontend/src/app/websocket.service.ts
import { Injectable } from '@angular/core';
import { webSocket, WebSocketSubject } from 'rxjs/webSocket';
import { environment } from '../environments/environment';
interface MessageData {
message: string;
time?: string;
}
@Injectable({
providedIn: 'root',
})
export class WebSocketService {
private socket$!: WebSocketSubject<any>;
public receivedData: MessageData[] = [];
public connect(): void {
if (!this.socket$ || this.socket$.closed) {
this.socket$ = webSocket(environment.webSocketUrl);
this.socket$.subscribe((data: MessageData) => {
this.receivedData.push(data);
});
}
}
sendMessage(message: string) {
this.socket$.next({ message });
}
close() {
this.socket$.complete();
}
}
此文件的重点:
-
webSocket工厂函数和WebSocketSubject主题是 Angular 中使用 WebSocket 的关键。 -
要创建
WebSocketSubject,我们需要传入 WebSocket 服务器的 URL,在这个示例中是ws://localhost:8000/ws。请注意协议是ws,而不是http。这是我们在上面的 FastAPI 应用中设置的,并作为环境变量保存在environment.ts中。 -
WebSocketSubject的工作方式与普通的 RxJSSubject相同。因此,它可以被订阅以接收来自服务器的数据,我们也可以使用next()向服务器发送消息。 -
当我们不再需要与 WebSocket 服务器交互,或当相应的组件被销毁时,我们可以在
WebSocketSubject上调用complete()来完成它,从而关闭连接。我们应该在后端代码中妥善处理连接关闭事件,否则会引发异常。 -
接收到的数据被保存到一个数组中,稍后会一起显示给用户。
在组件中使用 WebSocket 服务
现在 WebSocket 服务已创建,我们可以在组件中使用它。我们应有一些机制来打开和关闭与 WebSocket 服务器的连接。我们还将实现向服务器发送消息和显示接收消息的功能。
这是组件文件的代码:
// frontend/src/app/app.component.ts
import { Component, OnDestroy } from '@angular/core';
import { WebSocketService } from './websocket.service';
@Component({
selector: 'app-root',
templateUrl: './app.component.html',
styleUrls: ['./app.component.css'],
})
export class AppComponent implements OnDestroy {
message = '';
constructor(public webSocketService: WebSocketService) {
this.webSocketService.connect();
}
sendMessage(message: string) {
this.webSocketService.sendMessage(message);
}
ngOnDestroy() {
this.webSocketService.close();
}
}
如我们所见,所有的数据逻辑都封装在WebSocketService中。
这个模板文件演示了如何发送数据和显示接收到的数据:
<!-- frontend/src/app/app.component.html -->
<h2>Send a message to the server:</h2>
<form (ngSubmit)="sendMessage(message); message = ''">
<input [(ngModel)]="message" name="message" type="text" autocomplete="off" />
<button type="submit" style="margin-left: 10px;">Send</button>
</form>
<h2>Received messages from the server:</h2>
<ul>
<li *ngFor="let data of webSocketService.receivedData">
{{ data.time }}: {{ data.message }}
</li>
</ul>
注意,为了使用ngModel,我们需要在app.module.ts中导入FormsModule。
一切设置好后,我们可以运行ng serve在本地启动 Angular 应用程序,并访问 http://127.0.0.1:4200。尝试依次发送一些消息,看看会发生什么:

如演示的那样,我们可以不断地向服务器发送消息,当后台处理的数据准备好时会弹出,无需轮询服务器回复。
对后端和前端代码进行 Docker 化
为了使代码更容易在不同平台上共享,我们将后端和前端代码进行 Docker 化。
对于 FastAPI 代码,我们需要创建一个包含Dockerfile的自定义 Docker 镜像:
# backend/Dockerfile
FROM python:3.11
WORKDIR /app
COPY ./requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /app/requirements.txt
COPY ./app /app
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
requirements.txt 包含了应用程序的依赖项:
# backend/requirements.txt
fastapi>=0.89.0,<0.90.0
uvicorn[standard]>=0.20.0,<0.21.0
注意,你必须安装uvicorn[standard]而不仅仅是uvicorn,否则 WebSocket 无法在 FastAPI 中使用。
然后,我们需要创建一个docker-compose.yaml文件,以便更轻松地启动后端和前端代码:
# ./docker-compose.yaml
version: "3.9"
services:
fastapi_app:
build:
context: ./backend
image: fastapi_app:latest
ports:
- target: 8000
published: 8000
volumes:
- type: bind
source: ./backend/app
target: /app
networks:
- websocket
angular:
image: node:18.13.0-alpine
working_dir: /frontend
ports:
- target: 4200
published: 4200
volumes:
- type: bind
source: ./frontend
target: /frontend
networks:
- websocket
command: npm start
networks:
websocket:
name: websocket
driver: bridge
端口设置的方式使得一切与上述示例中的工作完全一致,而不需要使用 Docker。
整个项目的代码可以在这里找到。欢迎下载并自行测试。
当一切设置好后,可以使用docker-compose启动服务:
docker-compose build
docker-compose up -d
如果遇到端口冲突问题,可以使用以下命令找出哪些应用程序正在使用端口(4200 和 8000),并在适用时终止它们:
sudo lsof -i -P -n | grep LISTEN
如果一切正常,你可以像以前一样在 http://localhost:4200 访问 WebSocket Web 应用程序。一切应该与以前完全相同。
在这篇文章中,我们介绍了如何编写后端和前端代码,使用 WebSocket 设置一个简单的实时互动 Web 应用程序。WebSocket 服务器使用 FastAPI 创建,网页用户界面使用 Angular。FastAPI 和 Angular(通过 RxJS)都提供了简单的方式来实现 WebSocket 协议,使得使用 WebSocket 连接创建 Web 应用程序变得非常容易。
这篇文章可以帮助你开始使用 WebSocket 构建低延迟或实时的 Web 应用程序。它也可以成为数据处理的有价值工具。例如,大多数机器学习模型在构建和预测结果时需要时间。利用 WebSocket,我们可以构建一个接近实时的应用程序,使机器学习预测的数据在后台变得可用时不断弹出。
相关文章:
建立并玩耍!你自己的 V&L 模型配备 LLM!
开发集成 LLM 的 GIT 视觉语言模型。
·
关注 发表在 Towards Data Science ·21 分钟阅读·2023 年 9 月 7 日
--
本文摘要:
-
解释由微软开发的 GIT 视觉语言模型。
-
使用 PyTorch 和 Hugging Face 的 Transformers 替换 GIT 的语言模型为大型语言模型(LLMs)。
-
介绍如何使用 LoRA 微调 GIT-LLM 模型。
-
测试和讨论开发的模型。
-
探讨由 GIT 的图像编码器嵌入的“图像嵌入”是否在与“文本嵌入”相同的空间中指示特定字符。
大型语言模型(LLM)正展现出越来越多的价值。将图像纳入 LLM 使其作为视觉语言模型更加有用。在这篇文章中,我将解释一个称为 GIT-LLM 的模型的开发,这是一种简单但强大的视觉语言模型。某些部分,如代码解释,可能会显得有些繁琐,所以可以直接跳到结果部分。我进行了各种实验和分析,希望你能喜欢我所取得的成果。
实现已公开发布,所以请试试。
[## GitHub - turingmotors/heron
通过在 GitHub 上创建一个账户来为 turingmotors/heron 的发展做贡献。
github.com](https://github.com/turingmotors/heron?source=post_page-----afa773b9249e--------------------------------)
将 GIT 转化为 LLM
让我们深入探讨这篇技术博客的主要话题。
什么是 GIT?
生成式图像到文本变换器(Generative Image-to-text Transformer),或称 GIT,是微软提出的一种视觉语言模型。
arXiv: arxiv.org/abs/2205.14100
代码: github.com/microsoft/GenerativeImage2Text
它的架构相当简单。它将从图像编码器提取的特征向量转换为可以像文本一样处理的向量,使用一个投影模块。这些向量随后输入到语言模型中,以生成图像的标题或进行问答。该模型也可以以类似的方式处理视频。

该图摘自“GIT: A Generative Image-to-text Transformer for Vision and Language”
尽管它很简单,但如果你查看“Paper with code”的排行榜,你会发现它在许多任务中排名靠前。
paperswithcode.com/paper/git-a-generative-image-to-text-transformer
最初,GIT 使用像 CLIP 这样的强大模型作为其图像编码器,并从头开始训练语言模型部分。然而,在这篇文章中,我尝试使用一个强大的 LLM 并对其进行微调。在这里,我称该模型为“GIT-LLM”。
使用 Hugging Face 的 Transformers 来实现 LLM
我将使用 Hugging Face 的Transformers库来开发 GIT-LLM。Transformers 是一个用于处理机器学习模型的 Python 库。它提供了许多最先进的预训练模型,你可以立即进行推理。它还提供了训练和微调模型的工具。我相信 Transformers 在最近的 LLM 衍生品的发展中做出了重要贡献。几乎所有可用的 LLM 都可以用 Transformers 处理,许多从这些 LLM 衍生出的多模态模型也使用 Transformers 作为基础进行开发和微调。
这是使用 Transformers 模型的最简单代码。你可以通过使用AutoModel和AutoTokenizer轻松尝试 LLMs。
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "facebook/opt-350m"
model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_name)
prompt = "Hello, I'm am conscious and"
input_ids = tokenizer(prompt, return_tensors="pt").to("cuda")
sample = model.generate(**input_ids, max_length=64)
print(tokenizer.decode(sample[0]))
# Hello, I'm am conscious and I'm a bit of a noob. I'm looking for a good place to start.
让我们查看 OPT 模型所包含的参数。打印由AutoModelForCausalLM创建的模型。
OPTForCausalLM(
(model): OPTModel(
(decoder): OPTDecoder(
(embed_tokens): Embedding(50272, 512, padding_idx=1)
(embed_positions): OPTLearnedPositionalEmbedding(2050, 1024)
(project_out): Linear(in_features=1024, out_features=512, bias=False)
(project_in): Linear(in_features=512, out_features=1024, bias=False)
(layers): ModuleList(
(0-23): 24 x OPTDecoderLayer(
(self_attn): OPTAttention(
(k_proj): Linear(in_features=1024, out_features=1024, bias=True)
(v_proj): Linear(in_features=1024, out_features=1024, bias=True)
(q_proj): Linear(in_features=1024, out_features=1024, bias=True)
(out_proj): Linear(in_features=1024, out_features=1024, bias=True)
)
(activation_fn): ReLU()
(self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(fc1): Linear(in_features=1024, out_features=4096, bias=True)
(fc2): Linear(in_features=4096, out_features=1024, bias=True)
(final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
)
)
)
(lm_head): Linear(in_features=512, out_features=50272, bias=False)
)
这非常简单。初始embed_tokens的输入维度和最终lm_head的输出维度为 50,272,表示训练此模型时使用的标记数量。让我们验证一下分词器词汇表的大小:
print(tokenizer.vocab_size)
# 50265
包括像bos_token、eos_token、unk_token、sep_token、pad_token、cls_token和mask_token这样的特殊标记,它预测了从总共 50,272 种标记中下一个单词的概率。
你可以通过查看实现来理解这些模型是如何连接的。一个简单的图示将表示如下流程:

OPT 的简化模型架构(图像由作者制作)
结构和数据流非常简单。〇〇Model 和〇〇ForCausalLM 在不同的语言模型中具有类似的框架。〇〇Model 类主要表示语言模型的“Transformer”部分。例如,如果你想执行文本分类任务,你只需使用这一部分。〇〇ForCausalLM 类用于文本生成,将分类器应用于处理后转换器的向量中的标记计数。损失计算也是在该类的前向方法中完成的。embed_positions表示位置编码,它会被加到project_in上。
使用 GIT 与 Transformers
我将根据GIT 的官方文档页面尝试一下。由于我也会处理图像,所以我会使用一个同时包含 Tokenizer 的 Processor。
from PIL import Image
import requests
from transformers import AutoProcessor, AutoModelForCausalLM
model_name = "microsoft/git-base-coco"
model = AutoModelForCausalLM.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(model_name)
# Downloading and preprocess an image
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
pixel_values = processor(images=image, return_tensors="pt").pixel_values
# Preprocessing text
prompt = "What is this?"
inputs = processor(
prompt,
image,
return_tensors="pt",
max_length=64
)
sample = model.generate(**inputs, max_length=64)
print(processor.tokenizer.decode(sample[0]))
# two cats sleeping on a couch
给定输入图像生成的输出为“两只猫在沙发上睡觉”,这表明它的效果很好。
让我们也来看看模型的结构:
GitForCausalLM(
(git): GitModel(
(embeddings): GitEmbeddings(
(word_embeddings): Embedding(30522, 768, padding_idx=0)
(position_embeddings): Embedding(1024, 768)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(image_encoder): GitVisionModel(
(vision_model): GitVisionTransformer(
...
)
)
(encoder): GitEncoder(
(layer): ModuleList(
(0-5): 6 x GitLayer(
...
)
)
)
(visual_projection): GitProjection(
(visual_projection): Sequential(
(0): Linear(in_features=768, out_features=768, bias=True)
(1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
)
)
(output): Linear(in_features=768, out_features=30522, bias=True)
)
虽然有点长,但如果你拆解开来,它其实也很简单。在 GitForCausalLM 中,有一个 GitModel,内部包含以下模块:
-
embeddings (GitEmbeddings)
-
image_encoder (GitVisionModel)
-
encoder (GitEncoder)
-
visual_projection (GitProjection)
-
output (Linear)
与 OPT 的主要区别在于存在GitVisionModel和GitProjection,这正是将图像转换为类似提示的向量的模块。虽然语言模型对 OPT 使用 Decoder,对 GIT 使用 Encoder,但这仅意味着注意力掩码构建方式的不同。变压器层可能存在细微差别,但它们的功能本质上是相同的。GIT 使用 Encoder 这一名称,因为它使用独特的注意力掩码,该掩码对图像的所有特征应用注意力,并对文本特征使用因果掩码。
查看模型的连接;

GIT 的简化模型架构(图像由作者制作)
图像信息由GitVisionModel和GitProjection处理,以匹配文本的嵌入。之后,它与文本的嵌入一起输入到语言模型的“Transformer”层中。虽然存在细微差别,但与语言模型相关的部分几乎以相同的方式开发。
GIT 的注意力掩码
一般语言模型和 GIT 语言模型的架构几乎相同,但应用注意力掩码的方式不同。
对于语言模型,应用注意力掩码以避免在预测未来标记时查看过去的标记。这是一种称为“因果注意力”的方法,对应于下图的左侧。第一列标记仅引用自身,确保对后续词没有自注意力。第二列对第二个词应用自注意力,而从第三个词开始变为 0。这样的掩码使得模型能够有效地训练以预测下一个词。
GIT 输入有两种类型的标记:图像标记和文本标记。由于所有图像标记是同时使用的,并且不用于预测下一个标记,因此因果注意力不适用。另一方面,文本标记仍然需要因果注意力。设计了如图右侧所示的掩码来实现这一点。对于图像信息的前三行,自注意力应用于所有标记信息。从文本标记开始,向下移动一列会增加可以参考的单词数量。

因果注意力掩码与 Git 注意力掩码的区别(图像由作者制作)
让我们还检查一下制作 GIT 掩码的代码。创建 GIT 掩码的代码片段如下:
import torch
def create_git_attention_mask(
tgt: torch.Tensor,
memory: torch.Tensor,
) -> torch.Tensor:
num_tgt = tgt.shape[1]
num_memory = memory.shape[1]
# Areas where attention is applied are 0, areas without attention are -inf
top_left = torch.zeros((num_memory, num_memory))
top_right = torch.full(
(num_memory, num_tgt),
float("-inf"),
)
bottom_left = torch.zeros(
(num_tgt, num_memory),
)
# Causal Attention Mask
bottom_right = torch.triu(torch.ones(tgt.shape[1], tgt.shape[1]), diagonal=1)
bottom_right = bottom_right.masked_fill(bottom_right == 1, float("-inf"))
# Concatenate masks
left = torch.cat((top_left, bottom_left), dim=0)
right = torch.cat((top_right, bottom_right), dim=0)
# add axis for multi-head
full_attention_mask = torch.cat((left, right), dim=1)[None, None, :]
return full_attention_mask
# batch_size, sequence, feature_dim
visual_feature = torch.rand(1, 3, 128)
text_feature = torch.rand(1, 4, 128)
mask = create_git_attention_mask(tgt=text_feature, memory=visual_feature)
print(mask)
"""
tensor([[[[0., 0., 0., -inf, -inf, -inf, -inf],
[0., 0., 0., -inf, -inf, -inf, -inf],
[0., 0., 0., -inf, -inf, -inf, -inf],
[0., 0., 0., 0., -inf, -inf, -inf],
[0., 0., 0., 0., 0., -inf, -inf],
[0., 0., 0., 0., 0., 0., -inf],
[0., 0., 0., 0., 0., 0., 0.]]]])
"""
你将掩码添加到注意力权重中。因此,自注意力发生的部分为 0,而不包括在注意力中的部分为 -inf。通过向前提供此掩码,只有文本部分可以进行因果注意力。对于视觉语言模型来说,像这样有效地创建和使用掩码非常重要。
连接 GIT 和 OPT
现在,让我们连接 GIT 和 OPT。目标是创建如图所示的模型。

GIT-OPT 的简化模型架构(图由作者制作)
对于通用实现,你可以参考[modeling_git.py](https://github.com/huggingface/transformers/blob/main/src/transformers/models/git/modeling_git.py)。
最重要的部分是GitOPTModel。在其中,一个视觉编码器需要与 LLM 连接。我会解释一些关键组件。
class GitOPTModel(OPTModel):
def __init__(self, config: OPTConfig):
super(GitOPTModel, self).__init__(config)
self.image_encoder = CLIPVisionModel.from_pretrained(config.vision_model_name)
self.visual_projection = GitProjection(config)
在init函数内部,实例化了各种模块。super初始化了OPTModel。在 GIT 中,推荐使用训练有素的 CLIP 图像编码器,因此我使其与 CLIP 训练的 ViT 兼容。GitProjection来自原始 GIT 实现。
让我们看看 forward 函数内部。实现基于OPTDecoder的 forward 部分,并添加了来自图像编码器的信息。虽然实现有点冗长,但我在代码中添加了注释,请按步骤进行。
class GitOPTModel(OPTModel):
...
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
pixel_values: Optional[torch.Tensor] = None,
) -> BaseModelOutputWithPooling:
seq_length = input_shape[1]
# 1\. Extract image features using ViT
visual_features = self.image_encoder(pixel_values).last_hidden_state
# 2\. Convert features extracted by ViT into prompt-like Image Embeddings
projected_visual_features = self.visual_projection(visual_features)
# 3\. Vectorize the tokens
inputs_embeds = self.decoder.embed_tokens(input_ids)
# 4\. Obtain Positional Encoding
pos_embeds = self.embed_positions(attention_mask, 0)
# 5\. Dimension adjustment of Text Embeddings specific to OPT
inputs_embeds = self.decoder.project_in(inputs_embeds)
# 6\. Text Embeddings + Positional Encoding
embedding_output = inputs_embeds + pos_embeds
# 7\. Concatenate Image Embeddings and Text Embeddings
hidden_states = torch.cat((projected_visual_features, embedding_output), dim=1)
# 8\. Create Causal Attention Mask for Text region
tgt_mask = self._generate_future_mask(
seq_length, embedding_output.dtype, embedding_output.device
)
# 9\. Create Attention Mask for GIT
combined_attention_mask = self.create_attention_mask(
tgt=embedding_output,
memory=projected_visual_features,
tgt_mask=tgt_mask,
past_key_values_length=0,
)
# 10\. Pass through the Decoder layer repeatedly, the main part of the language model
for idx, decoder_layer in enumerate(self.decoder.layers):
layer_outputs = decoder_layer(
hidden_states,
attention_mask=combined_attention_mask,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
# 11\. Dimension adjustment MLP specific to OPT
hidden_states = self.decoder.project_out(hidden_states)
# 12\. Align the output interface
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
虽然看起来可能很复杂,但如果你逐步了解每个步骤,你会发现它遵循了图示中的流程。实际代码可能看起来有点复杂,但先掌握主要流程将使理解其他部分更容易。这是伪代码,对于详细部分,请参考发布的实现。
最后,让我们简要看看GITOPTForCausalLM部分。
class GitOPTForCausalLM(OPTForCausalLM):
def __init__(
self,
config,
):
super(GitOPTForCausalLM, self).__init__(config)
self.model = GitOPTModel(config)
def forward(
...
) -> CausalLMOutputWithPast:
outputs = self.model(
...
)
sequence_output = outputs[0]
logits = self.lm_head(sequence_output)
loss = None
if labels is not None:
# Predict the next word as the task
num_image_tokens = self.image_patch_tokens
shifted_logits = logits[:, num_image_tokens:-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
loss = loss_fct(shifted_logits.view(-1, self.config.vocab_size), labels.view(-1))
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
...
)
模型内部的处理很简单。当提供labels时,即在训练过程中,损失计算也在 forward 中进行。在shifted_logits中,从第一个 token 到文本 tokens 的倒数第二个 token 被提取。然后,它计算与labels偏移一个词的 Cross Entropy Loss 作为正确答案。
一点需要注意的是,在初始化函数中分配GitOPTModel的变量需要命名为self.model。如果你查看父类OPTForCausalLM的实现,你会看到 OPT 在super初始化期间首先被放置到self.model中。如果你更改这个实例变量名,你将最终持有两个 OPT,这可能会增加内存负担。
LoRA 扩展
为了有效地微调 LLM,我将使用一个名为 Parameter-Efficient Fine-Tuning(PEFT)的库。由于它由 Hugging Face 开发,它与 Transformers 无缝集成。虽然 PEFT 中有各种方法,但这次我将使用一种常见的方法,即低秩适配(LoRA)进行实验。
如果模型支持 PEFT,模型可以用 LoRA 在几行代码中应用。
from transformers import AutoModelForCausalLM
from peft import get_peft_config, get_peft_model, LoraConfig
model = AutoModelForCausalLM.from_pretrained('microsoft/git-base')
peft_config = LoraConfig(
task_type="CAUSAL_LM",
r=8,
lora_alpha=32,
lora_dropout=0.1,
target_modules=["v_proj"]
)
peft_model = get_peft_model(model, peft_config)
target_modules 参数指定了你想要转换为 LoRA 的模块。如果提供了列表作为 target_modules,则会将每个字符串结尾的模块转换为 LoRA。为了简化,LoRA 仅应用于自注意力模块的“value” (v_proj)。
在模型中,ViT 用于图像编码部分。请小心,因为这样指定的话,ViT 的自注意力部分可能也会应用 LoRA。这有点繁琐,但通过指定到键名不重叠的部分并将其传递给 target_modules,你可以避免这种情况。
target_modules = [f"model.image_encoder.vision_model.encoder.{i}.self_attn.v_proj" for i in range(len(model.model.decoder))]
结果模型变成了 PeftModelForCausalLM 类的一个实例。它有一个名为 base_model 的实例变量,保存了转换为 LoRA 的原始模型。作为示例,我展示了 LoRA 如何应用于 ViT 的自注意力中的 v_proj。
(self_attn): GitVisionAttention(
(k_proj): Linear(in_features=768, out_features=768, bias=True)
(v_proj): Linear(
in_features=768, out_features=768, bias=True
(lora_dropout): ModuleDict(
(default): Dropout(p=0.1, inplace=False)
)
(lora_A): ModuleDict(
(default): Linear(in_features=768, out_features=8, bias=False)
)
(lora_B): ModuleDict(
(default): Linear(in_features=8, out_features=768, bias=False)
)
(lora_embedding_A): ParameterDict()
(lora_embedding_B): ParameterDict()
)
(q_proj): Linear(in_features=768, out_features=768, bias=True)
(out_proj): Linear(in_features=768, out_features=768, bias=True)
)
在 v_proj 线性层内部,你会发现添加了如 lora_A 和 lora_B 的全连接层。LoRA 转换后的 Linear 模块是一个名字相同的 Linear 类,继承自 PyTorch 的 Linear 和 LoraLayer。这是一个有些独特的模块,有兴趣了解细节的人可以查看实现。
请注意,使用 PEFT 创建的模型默认不会保存除 LoRA 部分之外的任何内容。虽然可以通过 merge_and_unload 方法保存,但你可能希望在 Trainer 训练过程中保存所有中途保存的模型。重载 Trainer 的 _save_checkpoints 方法是一种方法,但为了避免麻烦,我这次通过在训练阶段仅获取 PeftModel 中原始模型部分来处理。
model = get_peft_model(model, peft_config)
model.base_model.model.lm_head = model.lm_head
model = model.base_model.model
我相信还有更高效的处理方法,所以我仍在研究中。
使用 GIT-LLM 进行实验
现在让我们进行一些使用目前开发的模型的实验。
关于训练配置和其他设置的详细信息,请参考已发布的实现,因为它们本质上遵循相同的方法。
数据集:M3IT
对于实验,我想使用一个将图像与文本配对并且易于集成的数据集。在浏览Hugging Face 的 Datasets时,我发现了 M3IT,这是一个由上海 AI 实验室开发的用于 Instruction Tuning 的多模态数据集。Instruction Tuning 是一种即使在数据量有限的情况下也能产生令人印象深刻结果的方法。看起来 M3IT 重新标注了各种现有数据集,专门用于 Instruction Tuning。
huggingface.co/datasets/MMInstruction/M3IT
这个数据集很容易使用,所以我决定在接下来的实验中利用它。
要使用 M3IT 进行训练,必须创建一个自定义的 Pytorch Dataset。
class SupervisedDataset(Dataset):
def __init__(
self,
vision_model_name: str,
model_name: str,
loaded_dataset: datasets.GeneratorBasedBuilder,
max_length: int = 128,
):
super(SupervisedDataset, self).__init__()
self.loaded_dataset = loaded_dataset
self.max_length = max_length
self.processor = AutoProcessor.from_pretrained("microsoft/git-base")
# Setting up the corresponding Processor for each model
self.processor.image_processor = CLIPImageProcessor.from_pretrained(vision_model_name)
self.processor.tokenizer = AutoTokenizer.from_pretrained(
model_name, padding_side="right", use_fast=False
)
def __len__(self) -> int:
return len(self.loaded_dataset)
def __getitem__(self, index) -> dict:
# cf: https://huggingface.co/datasets/MMInstruction/M3IT#data-instances
row = self.loaded_dataset[index]
# Creating text input
text = f'##Instruction: {row["instruction"]} ##Question: {row["inputs"]} ##Answer: {row["outputs"]}'
# Loading the image
image_base64_str_list = row["image_base64_str"] # str (base64)
img = Image.open(BytesIO(b64decode(image_base64_str_list[0])))
inputs = self.processor(
text,
img,
return_tensors="pt",
max_length=self.max_length,
padding="max_length",
truncation=True,
)
# batch size 1 -> unbatch
inputs = {k: v[0] for k, v in inputs.items()}
inputs["labels"] = inputs["input_ids"]
return inputs
在 init 函数中,image_processor 和 tokenizer 分别对应其各自的模型。传递的 loaded_dataset 参数应来自 MMInstruction/M3IT 数据集。
coco_datasets = datasets.load_dataset("MMInstruction/M3IT", "coco")
test_dataset = coco_datasets["test"]
对于 COCO Instruction Tuning 数据集,训练、验证和测试的划分与原始数据集相同,分别为 566,747、25,010 和 25,010 对图像-文本对。其他数据集,如 VQA 或 Video,也可以类似处理,使其成为一个多用途的验证数据集。
示例数据如下:

图像引用自 M3IT 数据。
该图片的说明如下:
Instruction: 写一个简洁的图像描述,捕捉其主要组成部分、它们之间的关系以及任何显著细节。 ##Question: ##Answer: 一名戴红色头盔的男子骑在小型摩托车上,行驶在泥土道路上。
对于 COCO 数据集,该数据集用于描述,问题部分保持为空。
让我们深入探讨处理器的操作。本质上,它对图像进行归一化并对文本进行分词。短于 max_length 的输入也会被填充。处理器返回的数据是一个包含以下内容的字典:
-
input_ids: 一个分词文本的数组。
-
attention_mask: 用于分词文本的掩码(填充部分为 0)。
-
pixel_values: 归一化图像的数组,也转换为 Channel-first。
这些关键名称对应于模型的前向函数的参数,因此不应更改。最后,input_ids 直接传递给名为 labels 的关键。GitOPTForCausalLM 的前向函数通过预测下一个词(偏移一个标记)来计算损失。
实验 1:确定微调位置
在 GIT 模型的研究论文中,解释了使用了强大的视觉编码器,并且语言模型采用了随机参数。这一次,由于目标是最终使用 7B 类语言模型,因此将应用预训练模型。以下模块将用于微调。GIT Projection 作为一个初始化模块,总是包括在内。一些组合可能看起来冗余,但它们在此试验中被探讨而无需过多担忧。
设置为训练的模块会获得梯度,而其余模块则修改为没有梯度。
# Specifying the parameters to train (training all would increase memory usage)
for name, p in model.model.named_parameters():
if np.any([k in name for k in keys_finetune]):
p.requires_grad = True
else:
p.requires_grad = False
本次检查所用的 Vision Encoder 和 LLM 是:
-
openai/clip-vit-base-patch16
-
facebook/opt-350m
训练使用 COCO 数据集,持续 5 轮。
以下是每个实验中训练的目标模块:
-
Proj: GIT Projection。随机初始化,因此总是进行训练。
-
LoRA: 语言模型中的自注意力的 Query、Key 和 Value 被应用。
-
OPT: 所有层都经过训练。
-
ViT: 所有层都经过训练。
-
Head: OPT 的最终 lm_head 已经过训练。
(注意:虽然 LoRA 可以应用于 ViT,但为了避免使实验过于复杂,这次未包含在内。)

该图显示了训练损失。图例中的 Proj、LoRA、OPT、ViT 和 Head 是上述训练模块。(图由作者制作)
如训练损失图所示,一些组的表现明显不佳。这些情况发生在 OPT 被包括在训练中时。尽管所有实验在相似的条件下进行,但在微调语言模型时可能需要更详细的调整,如学习率。接下来将检查排除 OPT 的训练模型的结果。

该图显示了没有完全微调结果的训练损失。图例中的 Proj、LoRA、OPT、ViT 和 Head 是上述训练模块。(图由作者制作)

该图显示了验证损失。图例中的 Proj、LoRA、OPT、ViT 和 Head 是上述训练模块。(图由作者制作)
无论是训练还是验证损失,Projection+LoRA模型的减少幅度最大。对最终Head层进行微调显示出几乎相同的结果。如果 ViT 也被训练,损失值似乎略高,结果也显得不稳定。即使在 ViT 训练期间添加了 LoRA,损失仍然倾向于较高。对于这个数据的微调,似乎使用一个未更新参数的预训练 ViT 模型会产生更稳定的结果。LoRA 的有效性在多个地方得到了认可,从这个实验中可以明显看出,将 LoRA 添加到 LLM 中改善了训练和验证损失。
评估一些测试数据的推理结果:

GIT-OPT 的示例结果。图片引用自 M3IT 数据集,文本结果由作者的模型生成。
当训练 OPT 本身时,结果与损失结果一样差,使得模型无言以对。此外,训练 ViT 时,输出结果有语义意义,但描述的内容与给定的图像完全不同。然而,其他结果似乎在某种程度上捕捉到了图像的特征。例如,第一张图提到了“猫”和“香蕉”,第二张图识别为“交通标志”。比较有无 LoRA 的结果,后者倾向于重复使用类似的词汇,但使用 LoRA 似乎使其略微更自然。训练Head时得到的输出非常有趣,例如第一张图用“playing”代替“eating”。虽然这些结果中有些元素不自然,但可以推测训练成功捕捉了图像特征。
实验 2:比较亿级模型
对于早期实验中的微调条件,使用了稍小的语言模型 OPT-350m。现在的意图是将语言模型切换到 7B 模型。不仅仅满足于 OPT,还将引入更强的 LLM,如 LLaMA 和 MPT。
将这两个模型集成可以按照与 OPT 类似的方式进行。参考LlamaModel和MPTModel的前向函数,将投影的图像向量与文本标记结合,并将掩码从Causal Attention Mask更改为GIT 的 Attention Mask。需要注意的是:对于 MPT,掩码不是(0, -inf),而是(False, True)。随后的过程可以类似地实现。
要使用 7B 级模型与 OPT,只需将模型名称从 facebook/opt-350m 更改为 facebook/opt-6.7b。
对于 LLaMA,考虑到 LLaMA2 的可用性,它将是首选模型。使用这个预训练模型需要 Meta 和 Hugging Face 的批准。需要一个 Hugging Face 账户,所以确保设置好。批准通常在几小时内完成。之后,登录到执行训练的终端上的 Hugging Face。
huggingface-cli login
你可以使用在 Hugging Face 账户中创建的令牌登录 → 设置 → 访问令牌。
训练参数保持一致,使用 COCO 数据集并持续 3 个 epoch。根据实验 1 的结果,微调的模块设置为Projection + LoRA。
让我们来看看结果。

此图显示了训练损失(图由作者制作)

此图显示了验证损失(图由作者制作)
通过查看损失,可以明显看出,使用 LLaMA2 和 MPT 作为 LLM 的模型显示了更令人满意的减少。让我们也观察一下推理结果。

GIT-LLMs 的示例结果。图片引用自 M3IT 数据集,文本结果由作者的模型生成。
关于第一张图片,对于所有模型,与 OPT-350m 相比,表情似乎更自然。没有像“一个香蕉和一个香蕉”这样的奇怪表情,突出了 LLM 的优势。对于第二张图片,仍然存在像“交通灯”或“建筑物”这样的短语困难。对于这种复杂的图像,可能需要考虑升级 ViT 模型。
最后,让我们对在 GPT-4 中变得流行的图像进行推理。

GIT-LLMs 的示例结果。图片引用自这里,文本结果由作者的模型生成。
尽管使用 LLM 时预期会有流畅的响应,但结果相当简单。这可能是因为模型仅在 COCO 上进行了训练。
实验 3:增加数据量
鉴于之前实验的结果不尽如人意,决定在训练中引入 COCO 以外的数据。当前使用的 M3IT 数据集相当全面,能够处理与 COCO 格式相同的大量数据。

该表格引用自“M3IT:面向多模态多语言指令调优的大规模数据集”的表 3
打算使用来自该来源的数据,但排除“中文”和“视频”类别。最初,COCO 训练数据集包含 566,747 条数据。通过与其他来源结合,总数增加到 1,361,650。尽管规模大致翻倍,但由于任务多样性增加,数据集的质量被认为有所提高。
使用ConcatDataset可以轻松处理多个 Pytorch 数据集。
dataset_list = [
datasets.load_dataset("MMInstruction/M3IT", i) for i in m3it_name_list
]
train_dataset = torch.utils.data.ConcatDataset([d["train"] for d in dataset_list])
训练进行了 1 轮,并使用 LLaMA2 模型对Projection 和 LoRA进行了微调,与实验 2 类似。
由于这次没有可以比较的损失值,我们直接进入推理结果。

GIT-LLaMA2 的示例结果。图片来自 M3IT 数据集,文本结果由作者的模型生成

GIT-LLaMA2 的示例结果。图片来自 M3IT 数据集,文本结果由作者的模型生成

GIT-LLaMA2 的示例结果。图片来自 M3IT 数据集,文本结果由作者的模型生成
除了解决简单问题外,模型现在还处理更复杂的挑战。通过添加比仅仅是描述更复杂的任务数据集,能力显著扩展。仅用 1 轮训练就达到这样的准确性令人惊讶。
让我们用以下示例图像进行测试。鉴于数据集的多样性增加,问题的呈现方式略有修改。

GIT-LLaMA2 的示例结果。一张图片来自这里,文本结果由作者的模型生成
尽管“伞状”这一描述仍有些奇怪,但感觉越来越好。为了进一步改进,需要增加训练轮次,添加更多类型或量的数据集,并利用更强大的 ViT 或 LLM。尽管如此,能够在仅半天内开发出这样的模型,考虑到计算和数据资源,确实令人印象深刻。
奖励实验。图像变成文字了吗?
再看一下 GIT 结构。

GIT-LLM 的简化模型架构(图像由作者制作)
如图所示,在视觉编码器进行特征提取后,图像通过Visual Projection与向量化的文本平等对待。换句话说,Visual Projection 可能将图像向量转换为文本向量。进行了调查以查看Visual Projection之后的向量是什么样的。
虽然有使用Head将投影后的向量还原为文本的选项,但发现即使是使用Embedding模块向量化的向量也无法通过这种方法还原为原始文本。因此,应将与输入到 LLM 之前的文本向量最接近的向量分配为相应的单词。所有在分词器中注册的令牌都使用 Embedding 模块进行向量化,并且选择了余弦相似度最高的词作为目标词。
本实验使用的图像是一只猫。

图片摘自 M3IT 数据集。
现在,让我们进行分析(完整分析可在这里查看)。首先,对所有注册的令牌进行向量化。
coco_datasets = datasets.load_dataset("MMInstruction/M3IT", "coco")
test_dataset = coco_datasets["test"]
supervised_test_dataset = SupervisedDataset(model_name, vision_model_name, test_dataset, 256)
ids = range(supervised_test_dataset.processor.tokenizer.vocab_size)
all_ids = torch.tensor([i for i in ids]).cuda()
token_id_to_features = model.model.embed_tokens(all_ids)
接下来,将提取本来会被 ViT 和 Projection 转换为单词的图像向量。
inputs = supervised_test_dataset[0] # Picking a sample arbitrarily
pixel_values = inputs["pixel_values"]
out_vit = model.model.image_encoder(pixel_values).last_hidden_state
out_vit = model.model.visual_projection(out_vit)
计算了这些向量和单词向量的点积,最大值的结果被解码为相关的令牌 ID。
# Dot product
nearest_token = out_vit[0] @ token_id_to_features.T
# The index of the maximum value corresponds to the relevant token ID
visual_out = nearest_token.argmax(-1).cpu().numpy()
decoded_text = supervised_test_dataset.processor.tokenizer.batch_decode(visual_out)
print(decoded_text)
"""
['otr', 'eg', 'anto', 'rix', 'Nas', ...]
"""
如打印出的decoded_text所示,一些不熟悉的单词出现了。由于一些单词重复出现,它们被统计了。
print(pd.Series(decoded_text).value_counts())
"""
mess 43
atura 29
せ 10
Branch 10
Enum 9
bell 9
worden 7
...
"""
似乎出现了大量不熟悉的单词。根据位置,它们可能传达有意义的信息。让我们将这些单词绘制在猫的图像上。
n_patches = 14
IMAGE_HEIGHT = 468
IMAGE_WIDTH = 640
y_list = np.arange(15, IMAGE_HEIGHT, IMAGE_HEIGHT//n_patches)
x_list = np.arange(10, IMAGE_WIDTH, IMAGE_WIDTH//n_patches)
plt.figure()
plt.axis("off")
plt.imshow(np.array(image), alpha=0.4)
for index in np.arange(n_patches ** 2):
y_pos = index // n_patches
x_pos = index - y_pos * n_patches
y = y_list[y_pos]
x = x_list[x_pos]
# The first token is the bos token, so it is excluded
word = decoded_text[index + 1]
# For differentiating words by color
plt.annotate(word, (x, y), size=7, color="blue")
plt.show()
plt.clf()
plt.close()

图片由作者制作
经常出现的单词用颜色编码。结果似乎表明它们并不仅仅是投射到有意义的单词上。虽然“Cat”这个词可能被叠加在猫的图像上,赋予它一定的相关性,但其含义仍不明确。
该实验中不确定的结果可能是由于强行选择了一个余弦相似度高的单词。无论如何,这种方法并不是简单地将单词投射到图像提示上。从图像中提取的向量通过Visual Projection 转换为令牌空间中的向量,这些向量似乎在意义上有些相似,充当神秘的提示。可能最好不要深入探讨这一点。
结论
在这篇技术博客文章中,我介绍了将 LLM 集成到视觉语言模型 GIT 的方法。此外,还使用开发的模型进行了各种实验。虽然有成功也有失败,但我希望继续进行视觉语言模型的实验,以积累见解。请将本文作为参考,并鼓励你创建自己的视觉语言模型,探索其潜力。

这是一张使用 Stable Diffusion 创建的 GIT-LLM 插图。(图片由作者制作)
使用 SQL 构建客户旅程
教程
学习如何跟踪跨多个渠道的消费者
·
关注 发表在 Towards Data Science · 5 min read · 2023 年 3 月 8 日
--
图片来源:Krivec Ales,Pixabay
客户旅程市场
全球客户旅程分析(CJA)市场在 2020 年的价值达到了83 亿美元,预计到 2026 年将增长到251 亿美元,预测期内年均增长率为20.3%。参与这一市场的代价不菲,并不是每个公司都能拿出一大笔资金。尽管关于 CJA 的信息很多,但关于如何构建数据集以获得客户旅程洞察的资料却几乎没有。在本教程中,我将展示如何使用结构化查询语言(SQL)创建客户旅程数据集。
什么是客户旅程?
我们可以将客户旅程定义为客户在多个渠道之间的一系列互动。这些渠道可能包括电子邮件、网站、移动应用、呼叫中心、社交媒体或店内购买。一个旅程可能展示客户与电子邮件互动,接着是网络,然后是直接邮件,最后是购买。这些时间戳互动使我们能够理解客户与公司产品的互动时间和顺序。旅程中的每一步都可以提供有价值的见解,改善客户体验、转换率及后续的营销工作。
客户旅程分析:使用 SQL 构建
旅程分析使 CX 团队能够可视化客户在不同渠道和时间上的行为,定义显示转换可能性的旅程信号,并监控终端旅程成功的表现。
我们的数据包括一个虚构公司的互动和订单,该公司销售汽车。public.interactions 表中的互动(即网页、移动端等)是客户在购买之前与我们公司互动的渠道。我们在 public.orders 表中捕捉购买数据。
构建客户旅程数据集需要三个步骤。步骤 1:主要代码根据互动日期对客户互动进行排序。步骤 2:主要代码 CTE创建主表的临时表,并使其可用于步骤 3。步骤 3:构建客户旅程如其名所示,创建客户旅程数据集。因此,让我们分析下面 SQL 代码中的每一步,以了解它们如何工作。
作者,SQL 代码
主要代码(步骤 1)
我们通过客户 ID 将互动表和订单表连接起来,public.interactions = public.orders on t1.customer_id = t2.customer_id。我们在订单表中捕捉完成的购买。因此,结果旅程将仅包含已购买的客户的互动。
为了捕捉和排序互动,我们使用Lead()函数。Lead 函数基于偏移值访问下一行的数据。Lead 函数的语法如下:
LEAD(return_value [,offset[, default ]]) OVER (
PARTITION BY expr1, expr2,...
ORDER BY expr1 [ASC | DESC], expr2,...
)
在下表中,第一次互动由函数 lead(t1.interaction,**0**) over (**partition by** t1.customer_id **order by** t1.interaction_date asc) 决定。首先,partition by 基于 t1.customer_id 对客户记录进行分组。其次,order by t1.interaction_date asc 按日期对互动进行排序。最后,offset 访问下一行,或者当前行之后的第二行,或者当前行之后的第三行,依此类推。在这里,offset 为零,lead(t1.interaction,**0**) 返回值为 Mobile。如果偏移量为 1,lead(t1.interaction,**1**) 返回值为 Web。

我建议将数据导入数据库,并应用上面的 SQL 代码以充分理解信息。你可以在这里找到数据集。
将主代码封装在 CTE 中(第 2 步)
我们通过将主代码添加到公共表表达式(CTE)中来处理它。CTE 的语法是:
with ctedata as
(
--Step 1 code
)
我们使用 CTE 的主要原因是为了能够使用where子句,并通过Lead()或Row_Number()函数派生的列来过滤行。在主代码中使用函数创建的列不能在where子句中使用。这为第 3 步准备了数据。
构建客户旅程数据集(第 3 步)
在第 1 步创建的数据集中,一些行缺少数据。我们通过过滤where offset = 0来移除缺失的行。我们还通过对首次互动和第二次互动的总购买客户进行汇总来聚合数据,如下所示。

作者,数据集
上表展示了客户在购买前的互动情况。在这个例子中,购买复古汽车的客户通过移动和电子邮件渠道进行互动。通过客户旅程数据,我们可以洞察哪些渠道表现最好,能促进互动并带来销售;哪些客户路径会导致流失;什么时机是与消费者互动的最佳时机,以及不同受众采取的路径[2]。
这些知识可能有助于制定营销策略。虽然表格可以有效地解释数据,但可视化数据可以帮助描绘直观的客户旅程。因此,让我们在下一部分尝试一下。
可视化客户旅程
下图中的桑基图对于说明跨多个渠道(如电子邮件、网站、移动应用、呼叫中心、社交媒体)的客户旅程顺序非常有用。

作者,桑基图数据
我们从左到右阅读桑基图。每个垂直条代表一个包含消费者的节点。例如,蓝色节点显示 23.24%的消费者首先通过直接邮寄与品牌互动。这些客户接着通过移动、直接邮寄、网页和电子邮件的组合进行互动。带状图的厚度直观地表示了在互动之间流动的消费者数量。带状图越厚,表示流向第二个互动节点的消费者越多。如果你想创建桑基图,我推荐尝试Chart Expo或Visual Paradigm这两个付费服务。我偏爱 Chart Expo,因为它提供了在线教程视频。如果你决定使用它,记得按照提供的 SQL 代码格式导出数据。
自定义旅程:继续学习
学习如何跟踪客户在多个渠道中的行为有很多内容需要探索。为此,我在本教程中包含了一个视频。视频覆盖了相同的内容,但通过插图讨论示例可能会增强学习效果。如果你发现理解某些概念有困难,我鼓励你观看视频。
作者,视频
好的,目前就这些了。希望你觉得这个教程对你有帮助。请随时联系我提问。我在这里分享和成长。
参考文献:
[1]: Markets and Markets,客户旅程分析市场按组件组织和规模
[2]: Karolina Matuszewska,Marek Juszczyński。 (2022 年 10 月 3 日)。什么是客户旅程分析以及它为何对你的业务重要
构建可部署的机器学习管道
原文:
towardsdatascience.com/build-deployable-machine-learning-pipelines-a6d7035816a6
利用 Kedro 构建生产就绪的机器学习管道
·发表于Towards Data Science ·8 分钟阅读·2023 年 6 月 30 日
--

作者提供的图片:使用 Midjourney 生成
背景 — 笔记本不能“部署”
许多数据科学家最初接触编码的方式是通过笔记本风格的用户界面。笔记本对于探索至关重要——这是我们工作流程的一个关键方面。然而,它们并不是为生产环境准备的。这是我在众多客户中观察到的一个关键问题,其中一些人询问如何将笔记本投入生产。与其将笔记本投入生产,最优的生产准备路径是编写模块化、可维护和可重复的代码。
在本文中,我展示了一个用于训练模型以分类欺诈信用卡交易的模块化机器学习管道示例。希望在本文结束时,你能:
-
了解和掌握模块化机器学习管道。
-
感到受启发,自己动手构建一个吧。
如果你想最大化地发挥机器学习模型的效益,编写模块化代码是一个重要的步骤。
首先对模块化代码进行一个简要定义。模块化代码是一种软件设计范式,强调将程序分解为独立的、可互换的模块。我们应该力求在我们的机器学习管道中达到这种状态。
快速绕道 — 项目、数据和方法
该机器学习项目的数据来源于Kaggle。数据集包含 284,807 个匿名信用卡交易记录,其中 492 个为欺诈交易。任务是构建一个分类器来检测欺诈交易。
本项目的数据在开放数据公共许可证下许可用于任何目的,包括商业用途。
我使用了利用Ludwig的深度学习方法,这是一个开源的声明式深度学习框架。我在这里不会详细介绍 Ludwig,不过我之前写过一篇关于框架的文章。
Ludwig 深度神经网络通过.yaml文件进行配置。对那些感兴趣的人,你可以在模型注册 GitHub找到它。
使用 Kedro 构建模块化管道
使用开源工具使得构建模块化机器学习管道变得更加容易,我最喜欢的工具之一是Kedro。不仅因为我见证了它在行业中的成功应用,还因为它帮助我提升了我的软件工程技能。
Kedro 是一个开源框架(根据 Apache 2.0 许可证)用于创建可重复、可维护和模块化的数据科学代码。我是在为一家银行开发 AI 策略时发现它的,考虑到我的团队可以利用哪些工具来构建生产就绪的代码。
免责声明:我与 Kedro 或 McKinsey 的 QuantumBlack(该开源工具的创建者)没有任何关联。
模型训练管道

作者提供的图片:使用 Kedro 可视化生成的端到端模型训练管道
Kedro 方便地允许你可视化你的管道,这是一个很棒的功能,可以帮助你清晰地理解你的代码。这个管道在机器学习中是标准的,因此我只会简单介绍每个方面。
-
导入数据:从外部来源导入信用卡交易数据。
-
拆分数据:使用随机拆分将数据分成训练集和保留集。
-
运行实验:使用 Ludwig 框架在训练数据集上训练深度神经网络。Ludwig 实验 API 方便地为每次实验运行保存模型工件。
-
运行预测:使用前一步训练的模型对保留数据集进行预测。
-
模型诊断:生成两个诊断图表。首先是跟踪每个周期的交叉熵损失。其次是测量模型在保留数据集上的性能的ROC 曲线。

作者提供的图片:模型训练过程中的损失曲线

作者提供的图片:在保留数据集上模型评估的 ROC 曲线
管道的核心组件
现在我们已经建立了一个高层次的视图,让我们深入探讨这个管道的一些核心组件。
项目结构
C:.
├───conf
│ ├───base
│ │ └───parameters
│ └───local
├───data
│ ├───01_raw
│ ├───02_intermediate
│ ├───03_primary
│ ├───04_feature
│ ├───05_model_input
│ ├───06_models
│ │ ├───experiment_run
│ │ │ └───model
│ │ │ ├───logs
│ │ │ │ ├───test
│ │ │ │ ├───training
│ │ │ │ └───validation
│ │ │ └───training_checkpoints
│ │ └───experiment_run_0
│ │ └───model
│ │ ├───logs
│ │ │ ├───test
│ │ │ ├───training
│ │ │ └───validation
│ │ └───training_checkpoints
│ ├───07_model_output
│ └───08_reporting
├───docs
│ └───source
│
└───src
├───fraud_detection_model
│ ├───pipelines
│ ├───train_model
└───tests
└───pipelines
Kedro 提供了一个模板化的目录结构,这个结构在你启动项目时就已经建立。从这个基础上,你可以以编程方式将更多的管道添加到你的目录结构中。这种标准化的结构确保了每个机器学习项目的一致性和易于文档化,从而便于维护。
数据管理
数据在机器学习中扮演着至关重要的角色。当在商业环境中使用机器学习模型时,跟踪数据的能力变得尤为重要。你经常会面临审计,或者需要在他人的机器上生产化或重现你的管道。
Kedro 提供了两种方法来强制执行数据管理的最佳实践。第一种是目录结构,专为机器学习工作负载设计,为数据转换过程中生成的中间表和模型工件提供了明确的位置。第二种方法是data catalogue。作为 Kedro 工作流的一部分,你需要在 .yaml 配置文件中注册数据集,从而在管道中利用这些数据集。这种方法初看可能不寻常,但它使你和其他参与管道工作的人员能够轻松跟踪数据。
调度 — 节点和管道
这就是魔法发生的地方。Kedro 提供了开箱即用的管道功能。
管道的初始构建块是nodes。每个可执行的代码片段可以封装在一个节点中,节点只是一个接受输入并产生输出的 Python 函数。然后,你可以将一个管道结构化为一系列节点。通过调用节点并指定输入和输出,可以轻松构建管道。Kedro 会确定执行顺序。
一旦管道构建完成,它们会被注册在提供的pipeline_registry.py文件中。这种方法的美妙之处在于,你可以创建多个管道。这在机器学习中尤其有用,你可能会有一个数据处理管道、一个模型训练管道、一个推理管道等。
一旦设置完成,修改管道的各个方面就会变得相当简单。
显示 nodes.py 脚本示例的代码片段
显示 Pipeline 脚本示例的代码片段
配置
Kedro 的最佳实践规定所有配置应通过提供的 parameters.yml 文件来处理。从机器学习的角度来看,超参数也属于这一类别。这种方法简化了实验过程,因为你可以简单地用另一组超参数替换一个 parameters.yml 文件,这也更容易追踪。
我还在 parameters.yml 配置中包含了我的 Ludwig 深度神经网络 model.yaml 和数据源的位置。如果模型或数据位置发生变化——例如,在开发者之间转移时——调整这些设置会非常简单。
显示parameters.yml文件内容的代码片段
可复现性
Kedro 包含一个 requirements.txt 文件作为模板结构的一部分。这使得监控你的环境和确切的库版本变得非常简单。然而,如果你愿意,你可以使用其他环境管理方法,如 environment.yml 文件。
建立工作流
如果你正在开发机器学习管道并考虑使用 Kedro,起初可能会有陡峭的学习曲线,但采用标准工作流会简化这一过程。以下是我建议的工作流:
-
建立你的工作环境:我更喜欢使用 Anaconda 来完成这项任务。我通常使用一个 environment.yml 文件,其中包含我环境所需的所有依赖项,并使用 Anaconda Powershell 命令行来 创建 我的环境。
-
创建 Kedro 项目:一旦你安装了 Kedro——希望它已在你的 environment.yml 中声明——你可以通过 Anaconda 命令行界面 创建 一个 Kedro 项目。
-
在 Jupyter Notebooks 中探索:我在 Jupyter notebooks 中构建初始管道,这个过程对大多数数据科学家来说都很熟悉。唯一的不同之处在于,一旦你的管道构建完成,你应该整理它,以便每个单元格可以作为你 Kedro 管道中的一个节点。
-
注册你的数据:在数据 catalogue 中记录每个数据处理或数据摄取步骤的输入和输出。
-
添加你的管道:在笔记本中进行探索后,你会想要 创建一个管道。这是通过命令行界面完成的。运行此命令将会在 ‘pipelines’ 中添加一个额外的文件夹,文件夹名称为你刚刚创建的管道名称。在这个文件夹中,你将构建你的节点和管道。
-
定义你的管道:这是将代码从你的 Jupyter 笔记本迁移到你管道文件夹中的 node.py 文件的阶段,确保你希望成为管道一部分的节点都有输入和输出。一旦节点设置完毕,接下来在 pipeline.py 文件中定义你的管道。
-
注册你的管道: pipeline_registry.py 文件提供了一个模板,用于注册你新创建的管道。
生产就绪的管道适应于更广泛的机器学习操作生态系统。阅读我的关于 MLOps 的文章,进行更深入的了解。
支持你的 AI 战略的有效 MLOps 蓝图
towardsdatascience.com
结论
Kedro 是一个出色的框架,用于交付生产就绪的机器学习管道。除了本文讨论的功能外,还有许多与其他开源库的集成,以及用于文档和测试的包。Kedro 并不能解决与模型部署相关的所有问题——例如,模型版本控制可能更适合使用其他工具如 DVC 处理。然而,它将帮助商业环境中的数据科学家生成更可维护、模块化和可重复的生产就绪代码。对于完全的新手来说,学习曲线相对较陡,但文档清晰,并包括引导教程。与这些包中的任何一个一样,最好的学习方法是直接动手试验。
完整的 GitHub 仓库 链接
在 LinkedIn 上关注我
订阅 medium 以获得更多我的见解:
我分享数据科学项目、经验和专业知识,以帮助你在旅程中前行。你可以通过以下方式注册 medium…
如果你有兴趣将 AI 或数据科学整合到你的业务运营中,我们邀请你安排一次免费的初步咨询:
通过免费咨询,发现我们在帮助企业实现雄心勃勃的目标方面的专业知识。我们的数据科学家和…
www.data-centric-solutions.com
使用 Mercury 从 Jupyter Notebook 直接构建优雅的 Web 应用程序。
瞬间将 Notebook 转换为美丽的 Web 应用程序
·
关注 发表在 Towards Data Science · 7 分钟阅读 · 2023 年 4 月 17 日
--
有效的沟通在所有数据驱动的项目中都至关重要。数据专业人员常常需要向包括业务领导、技术团队和其他数据科学家在内的利益相关者传达他们的发现和见解。
尽管传统的数据见解沟通方法,如 PowerPoint 演示和静态报告被广泛偏爱,但它们通常耗时较长。
更重要的是,这些服务要求用户离开 Jupyter Notebook 的舒适环境——这是数据科学家们大多数时间的工作场所。

分离展示工具和 Jupyter(作者提供的图像)
如果我们能通过从 Jupyter Notebook 创建互动且优雅的网页应用来分享我们的发现,那岂不是很好吗?
为此,Mercury 是一个开源工具,它简化了从 Jupyter Notebook 舒适环境中创建网页应用的过程。
因此,在本文中,我将演示如何使用 Mercury 创建令人惊叹的网页应用并与他人分享。
你可以在这里找到这个博客的代码:GitHub。
让我们开始 🚀!
开始使用 Mercury
Mercury 创建的网页应用主要由两个因素驱动:
#1) Jupyter Notebook:
这里是你开发网页应用的地方。我们使用 Mercury 的 input 和 output 小部件来实现互动。
输入小部件允许用户提供输入并与应用互动。Mercury 支持的一些输入小部件如下所示:

Mercury 的小部件(作者提供的图像)
输出小部件用于展示输出。这包括 Markdowns(带变量)、JSON 等。此外,Jupyter 单元格的输出也由 Mercury 渲染。
因此,如果你的应用创建了图表或打印了数据框等,它们将出现在网页应用的输出面板中。
#2) Mercury 服务器
服务器将 Jupyter Notebook 渲染为网页应用。

使用 Mercury 部署应用(作者提供的图像)
正如我们将看到的,渲染笔记本只需运行一个命令。你只需在笔记本中创建你的网页应用。
使用 Mercury 设置网页应用需要几个简单步骤。
安装 Mercury
首先,通过 pip 安装库:
pip install mercury
完成!
现在我们可以使用输入和输出小部件创建我们的网页应用。
使用 Mercury 开发网页应用
如上所述,使用 Mercury 创建的网页应用主要由其小部件驱动。
#1) 导入库
要使用它们,我们首先导入库。再重申一次,我们将从 Jupyter Notebook 开始一切。
## mercury_app.ipynb
import mercury as mr
此外,你可以根据需要导入其他任何库。对于这个博客,我将创建一个网页应用来分析一个自创的虚拟员工数据框。因此,我还将使用以下库:
## mercury_app.ipynb
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
sns.set()
#2) 配置应用
接下来,我们通过提供 title 和 description 来实例化一个 Mercury 应用。
## mercury_app.ipynb
app = mr.App(title="Employee Data Analysis",
description="Employee Report in Mercury")
#3) 用小部件填充应用
接下来,让我们添加一些小部件,以允许用户与以下虚拟数据进行互动:

虚拟数据集(作者提供的图像)
从本质上讲,我们将执行以下操作:
-
添加一个小部件以上传 CSV 文件。
-
让用户根据
Company_Name列中的条目过滤数据。这将是一个MultiSelect小部件。 -
此外,用户还可以使用
Slider根据Credits过滤数据。
一旦数据被过滤,我们将显示以下内容:
-
过滤后数据框的维度。
-
Employee_Salary和Employee_Rating的散点图。 -
显示
Employee_Status列分布的条形图。
让我们现在来构建它!
首先,我们添加文件上传小部件。
## mercury_app.ipynb
data_file = mr.File(label="Upload CSV")
文件的名称可以通过data_file对象的filepath属性访问。因此,一旦文件上传成功,我们将使用 Pandas 读取它,如下所示:
## mercury_app.ipynb
emp_df = pd.read_csv(data_file.filepath)
现在,我们将添加两个小部件——Company_Name上的MultiSelect小部件和Credits列上的Slider。
## mercury_app.ipynb
company = mr.MultiSelect(value=emp_df.Company_Name.unique(),
choices=emp_df.Company_Name.unique(),
label="Select Companies")
此处,value参数指的是初始值,choices以可供选择的值列表显示,而label是出现在小部件旁边的自定义文本。
接下来,我们有滑块小部件。
## mercury_app.ipynb
credits_filter = mr.Slider(value=1,
min=emp_df.Credits.min(),
max=emp_df.Credits.max(),
label="Credits Filter", step=1)
此处,value参数定义初始值,min和max指的是值的范围,label如之前所述,是自定义文本。最后,step定义滑块小部件的步长值。
到此为止,我们完成了交互性的小部件添加。最后一步是根据小部件中的值创建图表。
#4) 填充输出面板
首先,我们根据从小部件接收的值过滤数据框。您可以使用WidgetObj.value属性访问它。
换句话说,要检索company小部件的值,我们可以引用company.value属性。
## mercury_app.ipynb
new_df = emp_df[(emp_df.Company_Name.isin(company.value)) &
(emp_df.Credits>=int(credits_filter.value))]
接下来,使用 Markdown 输出小部件,我们打印过滤后的数据框的维度。
## mercury_app.ipynb
mr.Md(f"""The DataFrame has {new_df.shape[0]} rows
and {new_df.shape[1]} columns.""")
Mercury 的 Markdown 有一个很酷的功能,那就是您还可以使用 f-strings,如上所示。
最后,我们创建图表:
## mercury_app.ipynb
fig, ax = plt.subplots(1, 2, figsize = (16, 9))
sns.scatterplot(data = new_df, ax = ax[0],
x = "Employee_Rating", y = "Employee_Salary") ## scatter plot
sns.countplot(x = new_df.Employment_Status, ax = ax[1]) ## count plot
plt.show();
就这样。现在我们的 Mercury 应用已经准备好了。
#5) 运行网络应用
要运行应用,请在命令行中导航到应用的文件夹并运行以下命令:
mercury run
因此,我们看到了以下内容:

首次查看网络应用(作者图片)
正如预期的那样,我们有一个上传文件的小部件。让我们在这里上传虚拟数据集。

上传数据集(作者图片)
一旦我们上传了 CSV,图表立即显示出来。
现在,我们可以操控输入小部件来分析数据。

数据集分析(作者图片)
随着我们更新过滤器,图表和行数也会更新。这是通过 Mercury 服务器实现的,它在笔记本和应用之间保持持续的交互。
实际上,如果我们更新笔记本,变化会立即反映出来。
Mercury 与 Streamlit
此时一个相关的问题是水银相对于 Streamlit 的表现如何,Streamlit 是我在许多之前的博客中使用过的,如这里和这里。
Streamlit 确实已成为创建网页应用的最常见选择之一。尽管整体体验非常棒,但 Streamlit 也确实有很多限制:
#1) 无 Jupyter 支持
由 Streamlit 驱动的应用程序主要由 Python 脚本提供支持,而不是互动的 Python 内核。因此,在使用 Streamlit 开发应用程序时,必须重复运行脚本以查看进展。

水银与 Streamlit — Jupyter 支持(图片来源:作者)
然而,使用水银创建的应用程序由 Jupyter Notebook 驱动,每次更新都会立即反映在网页应用中。
#2) 导出为 PDF/HTML
使用水银创建的网页应用可以通过点击按钮轻松导出。

水银与 Streamlit — 导出应用程序(图片来源:作者)
这使您可以轻松地通过电子邮件、聊天等方式与他人分享您的应用程序,接收者不一定需要安装 Mercury。
然而,Streamlit 并不提供这种支持。
#3) 创建演示文稿
最后,使用水银创建的网页应用可以轻松地作为互动演示运行。

水银与 Streamlit — 演示(图片来源:作者)
然而,Streamlit 应用程序并不提供这样的支持。
#4) 使用认证保护应用程序
有时候,确保只有授权用户才能访问您的应用程序可能非常重要。这可能是由于存在敏感信息。

水银与 Streamlit — 安全性(图片来源:作者)
使用水银,您可以立即启用认证以保护您的网页应用。Streamlit 原生不支持认证。
因此,当有人运行您的网页应用时,他们将被提示验证其详细信息,如下所示:

水银认证窗口(图片来源:作者)
结论
这样,我们就结束了这篇博客。
在这篇文章中,我们学习了如何从 Jupyter Notebook 的舒适环境中使用水银构建一个简单的网页应用。
最后,我们查看了水银与 Streamlit 的比较以及它如何服务于各种数据通信用例。
作为结尾,您还可以通过Mercury Cloud将您的笔记本托管在云端。只需上传笔记本即可完成。
如果您不希望将网页应用专门托管在 Mercury Cloud 上,也可以将其部署在任何带有 Docker 的服务器上。
感谢阅读!
使用检索增强生成技术构建行业特定的语言模型
组织们正在竞相采用大型语言模型。让我们深入了解如何通过 RAG 构建行业特定的语言模型。
·发布在 Towards Data Science ·阅读时间 10 分钟·2023 年 5 月 31 日
--
公司可以通过像 ChatGPT 这样的语言模型获得大量生产力提升。但试着问 ChatGPT “美国当前的通货膨胀率是多少”,它给出的答案是:
对于混淆表示歉意,但作为一个人工智能语言模型,我没有实时数据或浏览能力。我的回答基于截至 2021 年 9 月的信息。因此,我无法提供美国当前的通货膨胀率。
这就是一个问题。ChatGPT 明显缺少相关的及时上下文,而这些上下文在做出明智决策时可能至关重要。
微软是如何解决这个问题的
在微软的 Build 会议中 Vector Search Isn’t Enough,他们展示了将较少上下文感知的语言模型与向量搜索结合的产品,以创造更具吸引力的体验。
讲座从与本文相反的方向开始——从 Elastic Search(或向量搜索)的角度——以及搜索本身的局限性,并且添加语言模型层可以大幅提升搜索体验。
基本思路是,向语言模型中添加相关上下文可以大幅提升用户体验,尤其是在大多数商业场景中,在这些场景下,语言模型未曾见过这些数据。向量搜索帮助选择相关上下文,当你拥有大量数据,包括数百份(或更多)文档时。
向量搜索 101

向量搜索 101 | Skanda Vivek
向量化是将数据转换为维度。在上述情况下,你可以看到 2 个维度:大小和类型。大小有 2 个值(小或大),类型有 2 个值(树或动物)。这只是一个概念性示例,可以扩展到数百个(或更多值)。

向量搜索 101 与词汇 | Skanda Vivek
你可以用单词或句子做同样的事情,而不是图片。注意在上述示例中,向量化能够捕捉到语义表示,即它知道谈论一只鸟俯冲到小松鼠上的句子应该位于(小,动物)象限,而谈论昨天风暴中一棵大树倒在路上的句子应该位于(大,树)象限。实际上,维度不止 2 个——句子转换嵌入有 384 个维度(句子或文本片段通过 BERT 转换器,输出 384 个数字,对应 384 个维度)。
向量搜索 + LLM(RAG)
向量搜索现在在所有搜索引擎中都很常见(Google、Bing,选择你喜欢的搜索引擎)。想象一下,如果你的整个搜索数据库仅包含这 4 份文档。你可以提出问题,“告诉我关于一棵大树的事情。”
输出将是:
在昨天的风暴中,一棵大树倒在了路上。
仅仅是数据库中最相关的文档。但如果你向 LLM 提出相同的问题(并在最后附上这 4 份“文档”),你会得到这样的结果(以 ChatGPT 为例):
一棵大树通常是一个大型的、耸立的植物,具有坚固的树干、枝条和广泛的根系。这些树木可以达到令人印象深刻的高度,通常高出周围的植被。它们是许多生态系统的重要组成部分,为各种生物提供遮荫、庇护和食物。
你提到的关于一棵大树在风暴中倒在路上的事件表明了恶劣天气条件的力量和影响。强风的风暴可以连根拔起或折断树木,导致道路阻塞,对车辆和行人构成潜在的危险。
在你提到的另一个场景中,当你在班迪布尔的道路上行驶时,你遇到了一只大象穿过道路。班迪布尔是印度的一座国家公园和野生动物保护区,以其丰富的生物多样性和各种动物群体而闻名。大象是班迪布尔的主要物种之一,通常会穿越该地区,有时会穿过道路。这些遭遇突显了野生动物保护的重要性以及人类与动物在其自然栖息地共存的必要性。
转到盆景,它们是生长在小花盆中的迷你树木,通过精心修剪来保持其小巧的尺寸。盆景栽培是一种起源于东亚的古老艺术形式,现已在全球范围内获得普及。由于其根部空间有限,盆景树需要定期浇水以防止干枯。小花盆可能导致土壤迅速干燥,因此必须监控湿度水平并相应地浇水。
最后,你提到通过窗户观察到一只鸟俯冲向一只小花栗鼠。这一观察反映了自然行动的一个瞬间,其中一只鸟可能在寻找食物。猛禽和某些鸟类,如鹰或猫头鹰,通常会俯冲捕捉像花栗鼠这样的微小哺乳动物,作为其自然猎食行为的一部分。目睹这种互动非常吸引人,因为它们提供了关于动物王国动态的见解。
总的来说,这些自然片段——无论是涉及大树、大象、盆景还是鸟类捕食——都突显了围绕我们自然世界的多样性和奇观。
那个回答似乎更加全面。ChatGPT 利用了其从数万亿文档中训练得到的内部知识,并结合了所提供的上下文。
但假设你有的不仅仅是 4 个文档——可能是成千上万甚至百万个,无法全部放入 ChatGPT 的提示中。在这种情况下,你可以使用向量搜索来缩小最可能包含答案的上下文,将其附加到提示中,并按照以下方式提问:
这是它现在给出的(截断的)答案:

ChatGPT 回答 | Skanda Vivek
然后你可以有一个数据库,存储文档和嵌入。你可以有另一个数据库,存储查询,并根据查询找到最相关的文档:

文档数据库(左)和 Quey 数据库(右) | Skanda Vivek
一旦你通过查询获得最相似的文档,你可以将其输入到任何像 ChatGPT 这样的语言模型中。通过这个简单的技巧,你已经增强了你的语言模型,使用了文档检索!这也被称为检索增强生成(RAG)。
使用 RAG 构建行业特定的问答模型

RAG 原型 | Skanda Vivek
上面的图表概述了如何构建一个基本的 RAG,该 RAG 利用 LLM 对自定义文档进行问答。第一部分是将多个文档拆分成可管理的片段,相关的参数是 最大片段长度。这些片段应为包含回答典型问题所需的文本的典型(最小)大小。这是因为有时你提出的问题可能在文档的多个位置都有答案。例如,你可能会问“X 公司在 2015 年到 2020 年的表现如何?”而你可能有一份大文档(或多份文档)包含公司多年表现的具体信息,这些信息分布在文档的不同部分。你理想的做法是捕获包含这些信息的文档的所有不同部分,将它们链接在一起,然后传递给 LLM 进行基于这些过滤和拼接后的文档片段的回答。
最大上下文长度 基本上是将各种片段拼接在一起的最大长度——为问题本身和输出答案留出一些空间(记住,像 ChatGPT 这样的 LLM 有严格的长度限制,包括所有内容:问题、上下文和答案)。
相似度阈值 是将问题与文档片段进行比较的方式,以找到最有可能包含答案的顶部片段。余弦相似度是通常使用的度量,但你可能希望对不同的度量进行加权。例如,可以包括一个关键词度量,以便对包含特定关键词的上下文给予更多权重。例如,当你要求 LLM 总结一份文档时,你可能希望对包含“摘要”或“总结”这两个词的上下文给予更多权重。
如果你想测试生成性问答在自定义文档上的简单方法,可以查看我的 API 和 代码,它们在后台使用了 ChatGPT。
原型 ChatGPT 通过 RAG 增强
让我们通过一个例子来说明 RAG 的实用性。EMAlpha 是一家提供新兴市场洞察的公司——基本上是像印度、中国、巴西等新兴国家的经济(完全披露——我在 EMAlpha 担任顾问)。该公司正在开发一个 ChatGPT 驱动的应用程序,根据用户输入生成对新兴经济体的洞察。仪表盘大致如下——你可以比较 ChatGPT 的输出与能够在后台从 IMF 查询财务文档的 RAG 版本 ChatGPT(EM-GPT)的结果:

EMAlpha 的 EM-GPT | Skanda Vivek
以下是 ChatGPT 对问题“尼泊尔按年份的 GDP 是多少?”的回答:

ChatGPT 响应 | Skanda Vivek
ChatGPT 只返回了 2019 年之前的 GDP 数据,并表示如果需要更多信息,可以查看 IMF。但是,如果你想找出这些数据在 IMF 网站上的具体位置,这很困难,你需要对网站上的文档存储位置有个大致了解。经过一些搜索,你可以在这里找到文档。即便如此,确定 GDP 信息的具体位置仍需要大量的滚动。

IMF 关于尼泊尔经济的文档 | Skanda Vivek
如上所示,找到这些数据确实很困难。但当你问 EM-GPT 同样的问题时,它会追踪到相关上下文,并找到如下答案:

EM-GPT 答案 | Skanda Vivek
以下是发送给 ChatGPT 以回答这个问题的确切提示。它能够理解这个格式化的文本,提取正确的信息——并将其格式化成一个易读的格式,这一点相当令人印象深刻!

使用基于查询检索的上下文的 ChatGPT 提示 | Skanda Vivek
我花了半小时在 IMF 网站上查找这些信息,而 RAG 修改版的 ChatGPT 只用了几秒钟。仅使用向量搜索是不够的,因为它最多只能找到“名义 GDP”这段文本,而不会将数字与年份关联起来。ChatGPT 在过去已经训练过多个此类文档,因此一旦添加了相关上下文,它就知道文本中的哪些部分包含答案以及如何以良好的可读格式格式化这个答案。
结论
RAG 提供了一种很好的方式来使用基于自定义文档的 LLMs。像微软、谷歌和亚马逊这样的公司正在竞相开发组织可以以即插即用的方式使用的应用程序。然而,这一领域仍处于初期阶段,利用向量搜索驱动的 LLMs 对自定义文档进行行业特定应用的公司可以成为首批先行者并超越竞争对手。
虽然有人问我应该使用哪个 LLM,以及是否需要对自定义文档进行微调或完全训练模型,LLM 与向量搜索之间的同步工程作用被低估了。 以下是一些可以显著提高或降低响应质量的考虑因素:
-
文档块的长度。 如果正确答案更可能分布在文本的不同部分并需要拼接在一起,则文档应分割成较小的块,以便可以将多个上下文附加到查询中。
-
相似性和检索度量。 有时候,单纯的余弦相似度是不够的。例如,如果许多文档包含关于同一主题的冲突信息,你可能需要根据这些文档中的元数据限制搜索范围。为此,除了相似度,还可以使用其他过滤度量。
-
模型架构: 我展示的架构是一个原型。为了提高效率和可扩展性,需要考虑多个方面,包括向量嵌入模型、文档数据库、提示、LLM 模型选择等。
-
避免幻觉: 你可能注意到我上面展示的例子是几乎正确的。增强版 ChatGPT 对尼泊尔 GDP 的金额是正确的——但年份是错误的。在这种情况下,需要在选择提示、以 ChatGPT 友好的格式提取数据和评估在多少情况下会出现幻觉以及哪些解决方案有效之间进行大量反馈。
现在你已经知道如何将 LLM 应用到你的自定义数据上,快去构建令人惊叹的 LLM 基础产品吧!
如果你喜欢这篇文章,请关注我——我写的内容涉及在实际应用中应用最先进的 NLP 技术,以及数据与社会之间的交集。
随时通过 LinkedIn与我联系!
如果你还不是 Medium 会员并希望支持像我这样的作者,请通过我的推荐链接注册: https://skanda-vivek.medium.com/membership
以下是一些相关的文章:
构建低延迟和可扩展的 ML 模型预测管道,使用 Spark Structured Streaming 和 MLflow
MLOps 实践系列——分享关键 MLOps 组件的设计和实施模式。今天的文章重点是构建模型预测管道。
·发布于 Towards Data Science ·阅读时间 8 分钟·2023 年 1 月 10 日
--
为了使 ML 模型在实际生产环境中工作,最关键的步骤之一是将训练好的模型部署以进行预测。模型部署(发布)是一个过程,使你能够将训练好的 ML 模型集成到生产环境中,以对实际数据做出决策。谈到模型部署,通常有两种类型:
-
一种是批量预测,即在某个时间间隔(如每天一次或每周一次,具体取决于模型在特定业务背景下的使用方式)调用训练好的模型,并将一批数据输入模型,以定期生成预测结果。
-
另一种是在线预测,即将训练好的模型打包成 REST API 或容器化微服务,模型通过响应 API 请求返回预测结果(通常为 JSON 格式)。使用在线预测,模型实时生成预测结果,即一旦 API 被调用,模型预测结果将立即返回。此外,模型 REST API 通常作为 web 应用的一部分集成,以供最终用户或下游应用进行交互。
然而,在批量预测和在线预测之间,我们已经看到越来越多的场景,其中模型不需要打包成 REST API,但模型预测所需的延迟非常低。因此,为了满足这些场景的需求,我想分享一个解决方案——使用 Spark Structured Streaming 和 MLflow 构建低延迟和可扩展的 ML 模型预测管道。
今天文章的内容是:
-
快速介绍 Spark Structured Streaming 和 mlflow;
-
低延迟和可扩展的机器学习预测管道的关键架构组件;
-
使用 Spark Structured Streaming 和 mlflow 构建低延迟和可扩展的机器学习预测管道的实现细节;
让我们开始吧!

照片由 Jen Theodore 提供,来源于 Unsplash
Spark Structured Streaming 和 mlflow 简介
Spark Structured Streaming — Structured Streaming 是一个基于 Spark SQL 引擎构建的可扩展且容错的流处理引擎。内部默认情况下,Structured Streaming 查询通过微批处理引擎进行处理,该引擎将数据流处理为一系列小批量作业,从而实现低至 100 毫秒的端到端延迟和精确一次的容错保证。
MLflow — MLflow 是一个用于管理整个机器学习生命周期的开源平台。
-
跟踪 — MLflow 的跟踪组件是一个用于记录参数、代码版本、指标和输出文件的 API 和 UI,当运行机器学习代码时以及后续可视化结果时使用。
-
模型 — MLflow 模型是一种标准格式,用于打包机器学习模型,可以在各种下游工具中使用。该格式定义了一种约定,使你可以以不同的“风味”保存模型,以便不同的下游工具理解。内置模型风味可以在这里找到。值得一提的是,python_function模型风味作为 MLflow Python 模型的默认模型接口。任何 MLflow Python 模型都应该能够被加载为 python_function 模型。在今天演示的解决方案中,我们将训练好的模型加载为 python 函数。此外,我们还利用了log_model() 和 load_model(). 的模型 API 调用。
-
模型注册表 — MLflow 的模型注册表组件是一个集中的模型存储库,包含一组 API 和 UI,用于协作管理 MLflow 模型的完整生命周期。它提供模型血统(提供对训练好的 ML 模型的可见性和可追溯性,这些模型来源于特定 MLflow 实验和运行的组合)、模型版本控制、阶段转换(例如从测试阶段到生产阶段)以及注释。
-
项目 — MLflow 项目是一种以可重用和可再现的方式打包数据科学代码的格式,主要基于约定。每个项目实际上是一个包含代码的文件目录或 Git 仓库。
如果你对了解更多关于 Spark Structured Streaming 感兴趣,你可以查看我另一篇文章,专门讨论如何构建流数据管道。
现在让我们深入探讨构建低延迟和可扩展 ML 预测管道的关键架构组件。
低延迟和可扩展 ML 预测管道的关键架构组件

低延迟和可扩展 ML 预测管道 | 作者提供的图片
如上图所示,构建低延迟和可扩展 ML 预测管道需要 3 个关键架构组件:
-
第一步是构建一个流数据管道,以摄取原始数据,将原始数据转换为 ML 特征,并以低延迟将 ML 特征输入到 ML 模型中;
-
第二步是将训练和注册的 ML 模型加载为 Spark 用户定义函数(UDF),以便模型可以并行进行预测,利用 Spark 的分布式计算能力。当需要进行预测的数据量很大时,这一点尤其有用。
-
第三步是将模型预测结果保存到存储在 AWS S3 桶中的 Delta 表中。然后,这些模型预测结果可以用于下游数据消费者和应用程序。例如,你可以在模型预测结果基础上构建一个商业智能(BI)仪表板,以支持业务决策。你还可以构建实时监控机制,根据模型预测生成通知和警报,以提高操作效率。
在展示如何构建低延迟和可扩展 ML 预测管道的实现之前,我们首先需要设置一些前提条件。
- 首先是训练模型的模式,如下文件所示:
{"model_purpose" : "predicts the quality of wine using wine attributes",
"model_flavor" : ["python_function","sklearn"],
# The python_function model flavor serves as a default model interface for MLflow Python models.
# Any MLflow Python model is expected to be loadable as a python_function model.
# This enables other MLflow tools to work with any python model regardless of
# which persistence module or framework was used to produce the model.
"model_algorithm" : "sklearn.linear_model.ElasticNet",
{"model_signature" :
"model_input_schema":[
{"name": "fixed acidity", "type": "string"},
{"name": "volatile acidity", "type": "string"},
{"name": "citric acid", "type": "string"},
{"name": "residual sugar", "type": "string"},
{"name": "chlorides", "type": "string"},
{"name": "free sulfur dioxide", "type": "string"},
{"name": "total sulfur dioxide", "type": "string"},
{"name": "density", "type": "string"},
{"name": "pH", "type": "string"},
{"name": "sulphates", "type": "string"},
{"name": "alcohol", "type": "string"}],
"model_output_schema" [
{"type": "tensor", "tensor-spec": {"dtype": "float64", "shape": [-1]}}
]
},
"model_registry_location" : "runs:/<RUN_ID>/<MODEL_NAME>",
# If you are using mlflow to manage the lifecycle of your models,
# the model is loggged as an artifact in the current run using MLflow Tracking
"model_stage" : "Production",
# With mlflow, you can transition a registered model to one of the stages:
# Staging, Production or Archived.
# In the demo of this article, the model is alreay transitioned to the "production" stage.
"model_owner" : "<MODEL_OWNER_EMAIL/MODEL_OWNER_GROUP_EMAIL>"
}
- 第二个是训练数据和测试数据的模式。确保输入模型的数据模式与模型输入模式匹配对于避免因模式不匹配而导致的预测错误至关重要。数据模式如下所示:
StructType([
StructField('fixed acidity', StringType(), True),
StructField('volatile acidity', StringType(), True),
StructField('citric acid', StringType(), True),
StructField('residual sugar', StringType(), True),
StructField('chlorides', StringType(), True),
StructField('free sulfur dioxide', StringType(), True),
StructField('total sulfur dioxide', StringType(), True),
StructField('density', StringType(), True),
StructField('pH', StringType(), True),
StructField('sulphates', StringType(), True),
StructField('alcohol', StringType(), True),
StructField('quality', StringType(), True)
])
本文中使用的数据来自这里。请随意了解有关数据的更多细节。
现在我们对模型模式和数据模式有了清晰的理解,我们可以开始使用 Spark Structured Streaming 和 MLflow 实现 ML 预测管道。完整的解决方案将在下一节中详细解释。
完整解决方案 — 使用 Spark 结构化流处理和 MLflow 构建低延迟且可扩展的机器学习预测管道
- 第 1 步 — 构建一个流数据摄取管道,以低延迟加载用于预测的数据。结构化流处理允许你通过设置微批间隔来定义数据处理的速度。在今天的演示中,我们将微批间隔设置为 5 分钟,这意味着每 5 分钟,流处理管道将拉取原始数据并调用已部署的机器学习模型进行预测。以下是一个样本流数据摄取管道,将原始数据(CSV 格式)加载到 Spark 流数据框中。
streamingDF = (spark
.readStream
.option("sep",",")
.option("header", "True")
.option("enforceSchema", "True")
.schema(csvSchema)
.csv(<YOUR-CSV-DATA-LOCATION>))
- 第 2 步 — 将注册的模型加载为 Spark 用户定义函数(UDF)函数。
import mlflow
logged_model = 'runs:/<RUN_ID>/<MODEL_NAME>'
# Load model as a Spark UDF.
# Override result_type if the model does not return double values.
loaded_model = mlflow.pyfunc.spark_udf(spark, model_uri=logged_model, result_type='double')
- 第 3 步 — 对流数据框进行预测,并将机器学习模型预测结果保存到 Delta 表中供下游消费者使用。
# Predict on a Spark DataFrame.
from pyspark.sql.functions import struct, col
streamingDF.withColumn('predictions', loaded_model(struct(*map(col, streamingDF.columns))))
完整解决方案
import mlflow
from pyspark.sql.functions import struct, col
from pyspark.sql.types import StructType
logged_model = 'runs:/<RUN_ID>/<MODEL_NAME>'
loaded_model = mlflow.pyfunc.spark_udf(spark, model_uri=logged_model, result_type='double')
checkpointLocation = <STREAMING_CHECKPOINT_LOCATION>
deltaLocation = <PREDICTION_STORAGE_LOCATION>
streamingDF = (spark
.readStream
.option("sep",",")
.option("header", "True")
.option("enforceSchema", "True")
.schema(csvSchema)
.csv(<YOUR-CSV-DATA-LOCATION>)
.withColumn('predictions', loaded_model(struct(*map(col, streamingDF.columns)))))
(streamingDF.writeStream
.format("delta")
.outputMode("append") # .outputMode("complete"), .outputMode("update")
.option("checkpointLocation",checkpointLocation)
.option("path", deltaLocation)
.trigger(processingTime='5 minutes') # trigger(availableNow=True), .trigger(once=True), .trigger(continuous='1 second')
.queryName("streaming csv files")
.start())
总结
感谢阅读今天的文章。希望你可以在你的机器学习驱动的应用程序中利用这一模型部署模式。
在我之前的文章中,MLOps 实践 — 将机器学习解决方案架构拆解为 10 个组件,我将端到端的机器学习解决方案架构拆解为 10 个组件。今天的文章着重解释构建机器学习服务管道的热门模式之一。
我将继续分享 MLOps 实践系列文章,解释关键 MLOps 组件的设计和实现模式。
如果你想在这些文章发布时获得通知,请随时在 Medium 上关注我。我一般每周发布 1 到 2 篇关于数据和 AI 的文章。
如果你想查看更多指南、深入探讨和关于现代高效数据+AI 堆栈的见解,请订阅我的免费通讯 — 高效的数据+AI 堆栈,谢谢!
参考文献
注意:如果你还没有成为 Medium 会员,而且你真的应该成为会员,因为这样你可以无限制访问 Medium,你可以使用我的 推荐链接 注册!
非常感谢你的支持!
多做分析,少做仪表盘
为什么我们应该转向一个以分析为默认第一步的世界
·
关注 发表在Towards Data Science ·7 分钟阅读·2023 年 5 月 13 日
--
自从 Seth Rosen 的trashboard推文已经快 2 年了,但 trashboard 疫情依然真实存在。我们仍然遇到许多公司,他们对拥有成千上万个仪表盘感到奇怪的满意,而他们的分析师则被迫在ad hoc请求的深渊中摸索,只剩下一个 IDE。
当然,所有这些都是出于良好的意图——ad hoc 工作可能很痛苦,而自动化解决方案听起来很有吸引力。更重要的是,仪表盘是一种安全的选择——毕竟我们仍然生活在一个仪表盘被视为分析的默认工具的世界中。
让我们谈谈我们是如何走到这一步的,为什么需要改变,以及我们应该如何前进。简而言之,接下来我将论证:
-
几十年的传统使我们来到了一个奇怪的地步,在这里仪表板占据了优先地位。
-
但仪表板并不总是为组织提供价值的最佳方式,尤其是在前几个之后。
-
如果你相信上述两个观点,你的信念、你的流程、你的工具需要根本性的改变。

图片来源于 Midjourney,已获许可使用。
问题:在帕累托世界中仪表板的优先地位
仪表板与分析,从烹饪的角度来看
所以不要误解——仪表板是很棒的,它们只是有些粗糙。它们可以满足无尽的需求,但在这样做的过程中,它们很少能完美地满足特定的需求。这就像数据的海森堡不确定性原理:你暴露的数据越多,你的洞察力就越不尖锐。仪表板的本质是优先展示数据。它们仅提供数据,而解读则留给用户。当然,某些核心仪表板是绝对必要的。每个业务单元都有核心的指标和事件需要跟踪,这些必须在某处展示——仪表板在这里充当业务的脉搏。但除此之外,每增加一个仪表板的价值会大幅减少。
相对而言,分析则更为尖锐。它们揭示了解读,使得与业务问题的相关性更加精准。如果仪表板像沙拉吧,那么分析就是按需定制的餐点。正如特定的餐点更能满足口味,分析总是更好地满足特定业务需求。另一方面,增量的沙拉吧项目(更多仪表板)则收益递减——它们永远不能完全满足某种特定的需求,但可以粗略地解决大部分饥饿相关的问题。拥有几个这样的仪表板可以带来很大好处,但你不需要无休止地继续构建它们。

一个沙拉吧。图片来源于 Midjourney,已获许可使用。
那么为什么仪表板仍然具有优先地位?
虽然我发现大家通常都同意这种观点,但我们生活在一个奇特的历史时期,仪表板仍然保持着其优先地位。当数据处理速度慢、数据难以处理、仪表板制作繁琐时——核心的仪表板维护是分析师工作的全部。但随着数据仓库速度的加快、分析需求的变化、额外仪表板的收益递减,我们的思维方式从未做出相应的调整。
我们不知不觉中被某种沉没成本谬误的扭曲困住了:我们将所有增量资源都分配给了仪表板,因此我们开始相信仪表板是至关重要的。

图片来源:作者。
我必须承认,仪表盘本身确实具有吸引力:仪表盘让我们能够考虑到高覆盖率进行构建。它们售卖了这样一个梦想:有一天,我们的临时请求量会降为零。有了足够的仪表盘,总有一天我们会解决所有问题及其后续问题。
但不幸的是,经营业务很少是如此的公式化。幂律分布泛滥。优先连接产生了价值和无价值的地方,变化和停滞,积累和空虚。并非所有事物都是平等的。这是帕累托法则的世界,其中雪崩推动变革。在这个世界里,20% 的努力驱动了 80% 的结果。

作者提供的图片。
在这个世界中,分析也不例外。如果你按影响力对你做过的所有工作进行排名,我猜你会发现你的大部分价值来自于一些关键的分析。在这个世界里,仪表盘则陷入困境 — 我们在详尽性上获得的同时,也在从中挤取大量的信息。
我们从这里该去哪里?
我长期以来一直是提升分析师角色的支持者。但最近几个月很少做分析工作,最近收到了第一个临时请求,我承认我立即跳进了数据中。我强烈希望尽快回答问题 — 我视这项工作为一种阻碍,而不是一个机会。这是试图减少我所需的工作,但在这样做的过程中,我从联合创始人跌落到了人类 API。我成为了破坏分析的恶性循环的受害者,尽管自称是这种边缘分析宗教的策展人。
“一切都是自画像。一切都是日记。当你关心某事时,这对每个人来说显而易见。当你关心时,你的投入或纪律是有方向的。大多数人能直觉地感受到。”
- 查克·帕拉尼克

Midjourney 提供的图片,已获得许可使用。
因此,我们需要改变我们的默认设置 — 我们的信念,我们的流程,我们的工具 — 否则我们就会冒着重新陷入同样旧习惯的风险。
信念:数据提取之外的价值。
推动任何行为改变的第一步,仅仅是内化你在做某事时有一个明确的原因。我们需要相信我们的价值不仅限于获取数据。如果你认为你的组织主要需要一组核心仪表板,那就设置它们,并将精力集中在这上面。但如果你有一种直觉,认为你可以在此之外提供更多的 leverage —— 通过帮助人们做出更好的决策,甚至通过保持其他人在交流中的思想诚实 —— 那么内化这一点,并在此假设为真的前提下继续。
过程:默认进行分析,而不是默认生成仪表板。
在某种程度上,我相信我们都希望担任更有影响力的角色,为业务创造更多价值,参与更高层次的战略问题解决。但仅仅认识到这一点是远远不够的,剩下的障碍仍然相当可观。据我估计,从现在开始最令人丧失活力的阻碍是,陷入坏习惯太容易了。我们需要建立系统,使正确的事情做起来容易,做错误的事情变得更难。
“设定目标有助于确定方向,但系统才是取得进展的最佳方式。”
- James Clear,《原子习惯》
如果你已经大量投入了自助系统和一组核心仪表板,我会考虑转向默认进行分析的世界。如果我们花费我们的时间尽力最大化我们工作的有用性(分析),而不是试图最大化我们工作的可扩展性(仪表板),那么我们的工作将会更有用,这是理所当然的。
此外,建立模板和标准,规范工作应该如何进行。为您的团队编写坚实的操作原则。您甚至可以制定严格的政策来强化更好的行为:例如,分析不放在我们的集中、共识的地方就不算数。采用强化良好行为的仪式:每天举行站立会议,展示这项工作;将仪式遵从性融入您的绩效评估中。尽一切可能确保人们做您知道会让他们变得更好的事情。

图片由作者提供。
工具:笔记本?至少,是一个以分析为先的工作流程。
如果我们信奉默认进行分析,我们需要一套工具来颠覆常见的工作模式,并推动我们优先进行分析,然后再是仪表板,而不是反过来。当然,我有偏见 —— 这是构建Hyperquery背后的理念 —— 即使你对我们的愿景并不完全认同,确保你不是为每一次分析都匆忙使用 Tableau,也是一个值得称赞的第一步。至少,首先是讲述故事。通过简单地强迫自己默认将文字作为交付物,而不是数据,你已经在打破一种有害的习惯了。
当我在 Airbnb 时,我经常会收到利益相关者的问题。很长一段时间,我的第一反应是立即跳到 SQL。我们都经历过这种情况:为了提供快速的价值,我们全力以赴地进行少数派报告,现场编码。这种方式有其适用的时间和场合,但我们却过于急切地前往那里,而不是思考——综合并借鉴我们已经进行的过去实验和深度挖掘中的学习。我们未能提供杠杆价值,而是寻求对我们闪亮(无与伦比,当然!)的技术能力的认可。而这悄然成为了他人对我们的期望。
以叙事为首的方式强调对齐、相关性和价值,而非即时性。叙事优先将对话提升到战略贡献的层面。叙事优先向你的利益相关者暗示你的价值在于叙事,而非代码。
最后的评论
分析领域正在发生变革。即使你不认同我所提出的观点,不可否认的是,在过去十年里,我们的行业发生了天文数字的变化。随着新车的出现,也带来了新道路和新法规。至少,我希望这篇文章能促使你思考那个新世界应当是什么样的。
👋 你好!我是 Robert, Hyperquery 的 CPO 和前数据科学家 + 分析师。可以在 LinkedIn 或 Twitter 上找到我——我总是乐于交流。🙂
使用检索增强生成构建更强大的大型语言模型
原文:
towardsdatascience.com/build-more-capable-llms-with-retrieval-augmented-generation-99d5f86e9779
检索增强生成如何通过整合知识库提升您的大型语言模型
·发表于 Towards Data Science ·阅读时间 12 分钟·2023 年 8 月 9 日
--

作者提供的图片:使用 Midjourney 生成
ChatGPT 的局限性
ChatGPT 在生成代码以外的许多实际业务场景中是有限的。这些局限性源自训练数据以及模型的幻觉倾向。在撰写时,如果你询问 Chat-GPT 关于 2021 年 9 月之后发生的事件,你很可能会收到如下回应:

作者提供的图片
这并没有帮助,那么我们如何解决这个问题呢?
选项 1 — 对模型进行训练或微调,以使用最新的数据。
微调或训练模型可能不切实际且昂贵。撇开成本不谈,准备数据集所需的工作量也足以使这一选项不被考虑。
选项 2 — 使用检索增强生成(RAG)方法。
RAG 方法使我们可以让大型语言模型访问最新的知识库。这比从头训练模型或进行微调便宜得多,而且实现起来也更简单。在这篇文章中,我将展示如何使用 RAG 来利用你的 OpenAI 模型。我们将通过对模型进行短期测试,以分析其回答有关 2022 年俄罗斯-乌克兰冲突的问题的能力,知识库来源于 维基百科。
注意:尽管这个话题比较敏感,但由于当前 ChatGPT 模型对此没有相关知识,因此被选择作为讨论的内容。
库和前置条件
你需要一个 OpenAI API 密钥,你可以直接从他们的网站获取,或者按照这个 教程。RAG 使用的框架是 Deepset 的 Haystack,它是开源的。他们提供了 API,允许你在大型语言模型之上构建应用程序。我们还利用了来自 Hugging Face 的句子变换器和变换器库。
句子嵌入帮助模型理解文本
在深入技术细节之前,我们应简要介绍句子嵌入。理解这个概念对获得 RAG 方法如何工作的直觉至关重要。
这可能是一个陈词滥调,尤其是对于那些有数据科学背景的人,但模型实际上并不理解文本,它们只理解数字。语言建模的大部分内容是关于制定有效地将文本编码为数字的方式,目前我们通过句子嵌入来做到这一点。
句子嵌入是一种将句子表示为密集向量的方式,同时保留其语义结构。嵌入是从深度神经网络中的密集层中学习的,该网络的结构可以因网络而异。
更简单地说,句子嵌入可以被认为是我们句子的数字表示,保留了它们的语义信息。我们可以从预训练模型中获取句子嵌入。Hugging Face 通过他们的句子变换器库提供了开源模型。
预处理和存储
在我们可以构建 RAG 启用的模型之前,我们需要预处理和存储文档。让我们探索一下如何做到这一点,但首先注意这个过程的架构,以帮助你理解。
注意:架构图也适用于我们在下一节中定义的抽取式问答管道。

图片由作者提供:高层次架构预处理、向量存储和抽取式问答管道
预处理我们的文档
Haystack 为我们提供了方便的工具来预处理大多数类型的文本文件(包括 .pdf、.txt、.docx)。预处理步骤很简单;我们使用 convert_files_to_docs() 函数读取知识库,该函数可以自动检测文件类型并将其转换为我们需要的下游格式。
Haystack 还提供了一个 PreProcessor 类,使我们能够对文档应用预处理步骤。你应用的步骤将很大程度上取决于你的具体应用。
注意:处理器将允许你将长文档拆分为较小的文档列表,定义为句子。对于问答应用,一个常见的方法是句子之间有一些重叠;我将其设置为大约 30%。
向量存储
我们利用了FAISS,这是 Meta 开发的一个库,用于在我们的句子嵌入中进行高效的相似性搜索。这一点在接下来的章节中将变得更加明显。下面的脚本展示了我们如何为文本文档设置预处理,并建立 FAISS 向量存储。
完整的预处理脚本如下:
作者脚本:文档预处理和 FAISS 向量存储
定义抽取式问答管道
下一步是构建我们的抽取式问答管道。我们的管道由按顺序运行的节点组成,形成一个有向无环图(DAG)。在这种情况下,管道由两个节点组成,一个检索器和一个阅读器。
检索器
检索是根据用户的查询从知识库中找到相关信息的方法。在定义我们的检索器节点时,我们从句子变换器库中指定一个句子嵌入模型,这里我们使用all-mpnet-base-v2,它生成 768 维的嵌入。
一旦我们完全定义了检索节点,我们可以在 FAISS 向量存储中计算并存储句子嵌入。相同的句子嵌入模型用于生成用户查询的句子嵌入。
注意:在选择句子嵌入模型时,计算效率和信息损失之间存在权衡。通常,维度更高的句子嵌入模型能捕捉更多信息,但计算效率较低。
记住,检索器节点的最终目的是找到与查询在语义上相关的信息。这是通过在查询的句子嵌入和我们向量存储中的文档之间执行相似性搜索来完成的。此节点返回最相关的前 k 个句子嵌入作为输出。
阅读器
阅读器节点是一个经过微调的大型语言模型,用于问答任务。在我们的用例中,我们利用Roberta-base-squad2作为我们的阅读器模型。阅读器处理来自检索器的输出和用户的初始查询,并返回相关的文本片段以回答查询。阅读器会对检索器输出中的每个文档进行处理,为每个文档分配一个置信度分数。答案按分数排序,返回前 k 个答案。
管道的完整脚本如下:
作者脚本:定义我们的抽取式问答管道
利用 LLM 驱动的代理
现在我们已经预处理了文档并定义了我们的管道,我们可以构建我们的代理。一个代理由大型语言模型提供支持,在我们的案例中,是 OpenAI 的 gpt-4(或 gpt-3.5-turbo)。
我们使用的代理基于 Zero-shot ReAcT(Reason + Act)提示。我们提示大型语言模型返回任务的口头推理轨迹和行动。代理可以根据我们提供的工具集对口头推理轨迹进行“行动”。代理可以观察并推理使用工具后的输出,帮助它决定下一步行动。
下面的动画展示了我们 ReAct 代理工作的简化视图。

作者 GIF:我们代理工作的简化视图
注意: 研究 表明 ReAct 提示已被证明能够有效减少 LLM 的幻觉。
以下脚本展示了我们如何构建我们的代理:
作者脚本:定义我们的零-shot ReAct 代理
在这个用例中,我们为代理提供的工具是我们之前定义的提取式问答管道。本质上,代理可以解释我们的查询,使用我们的管道作为工具在知识库中找到相关回答,推理是否回答了查询,选择响应,或再次执行该操作以获得更好的答案。
工具简要说明
明确地定义我们的工具非常重要。对于我们的特定用例,我们只有一个工具,但我们需要提供该工具的描述,以便 LLM 驱动的代理知道如何以及何时使用该工具。当一个代理可以访问的不仅仅是一个工具时,这一点变得更加重要。
测试实施
我们已经定义了我们的 RAG 启用模型;让我们通过向它提问关于 2022 年俄乌冲突的琐事来进行测试。我选择了两个问题,每个问题都可以通过维基百科知识库中的信息来回答。除了提供的回答,我还提供了每个代理的工作流程跟踪,你应该仔细检查这些以更好地了解 ReAct 提示如何与代理协同工作。
国际社会如何回应 2022 年俄罗斯入侵乌克兰的事件?
RAG 启用模型的回应:
国际社会对 2022 年俄罗斯入侵乌克兰的反应是谴责并施加新的制裁。这些制裁是全面的,涉及个人、银行、企业以及包括货币交换、银行转账、出口和进口在内的各种经济活动。具体措施之一包括最高 15 年的监禁。
注意:最高 15 年的监禁似乎是一个幻觉。模型错误地将其建议为制裁措施。

作者提供的 Gif:回应“国际社会如何回应俄罗斯在 2022 年入侵乌克兰?”的问题。
来自 Google 搜索的回应:

作者提供的图片:来自 Google 搜索的回应:“国际社会如何回应俄罗斯在 2022 年入侵乌克兰?”
以下是追踪代理工作流的记录:
Agent custom-at-query-time started with
{'query':
'How did the international community respond to
Russia’s invasion of Ukraine in 2022?', 'params': None}
The 'transcript' parameter is missing from
the Agent's prompt template. All ReAct agents
that go through multiple steps to reach
a goal require this parameter. Please append
{transcript} to the end of the Agent's
prompt template to ensure its proper functioning.
A temporary prompt template with {transcript} appended
will be used for this run.
find out the details about the international
community's response to Russia's invasion of Ukraine
in 2022.
I don't have this information readily available,
so I will use the document_qa tool to
find it.
Tool: document_qa
Tool Input: "International community response to Russia's invasion of Ukraine
in 2022"
Batches: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 13.05it/s]
Inferencing Samples: 100%|█████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:09<00:00, 9.98s/ Batches]
Observation: 308
Thought: The document_qa tool provided a response
code of 308, which is not informative.
I need to try again to get
the required information.
Tool: document_qa
Tool Input: "How did the international community respond to Russia's invasion
of Ukraine in 2022?"
Batches: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 9.72it/s]
Inferencing Samples: 100%|█████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:07<00:00, 7.78s/ Batches]
Observation: condemnation
Thought: The document_qa tool responded with "condemnation".
This suggests that the international community condemned
Russia's invasion of Ukraine in 2022\. However,
this is a very broad response and
doesn't provide specific details about the actions
taken by the international community. I need
to find more detailed information.
Tool: document_qa
Tool Input: "What specific actions did the international community take in response to Russia's invasion of Ukraine in 2022?"
Batches: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 13.41it/s]
Inferencing Samples: 100%|█████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:07<00:00, 7.95s/ Batches]
Observation: new sanctions
Thought: The document_qa tool responded with "new
sanctions". This suggests that the international community
imposed new sanctions on Russia in response
to their invasion of Ukraine in 2022.
However, this still doesn't provide a comprehensive
picture of the international community's response. I
need to find more detailed information.
Tool: document_qa
Tool Input: "What were the new sanctions imposed by the international community on Russia in 2022?"
Batches: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 12.66it/s]
Inferencing Samples: 100%|█████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:10<00:00, 10.77s/ Batches]
Observation: prison sentences of up to 15 years
Thought: The document_qa tool responded with "prison
sentences of up to 15 years". This
suggests that part of the sanctions imposed
by the international community on Russia in
2022 included prison sentences of up to
15 years. However, this seems to be
a specific measure and may not represent
the full range of sanctions. I need
to find more detailed information.
Tool: document_qa
Tool Input: "What other sanctions were imposed by the international community on Russia in 2022?"
Observation: individuals, banks, businesses, monetary exchanges,
bank transfers, exports, and imports
Observation: individuals, banks, businesses, monetaryexchanges, bank transfers,
exports, and importsThought: The document_qa tool responded
with "individuals, banks, businesses, monetary exchanges, bank
transfers, exports, and imports". This suggests that
the sanctions imposed by the international community
on Russia in 2022 were comprehensive, affecting
individuals, banks, businesses, and various types of
economic activity including monetary exchanges, bank transfers,
exports, and imports.
Final Answer: The international community responded
to Russia's invasion of Ukraine in 2022
with condemnation and the imposition of new
sanctions. These sanctions were comprehensive, affecting individuals,
banks, businesses, and various types of economic
activity including monetary exchanges, bank transfers, exports,
and imports. Some of the specific measures
included prison sentences of up to 15 years.
谁领导了 2022 年 8 月 29 日调查该工厂的 IAEA 团队?
RAG 支持的模型的回应:
拉斐尔·格罗西领导了 2022 年 8 月 29 日调查该工厂的国际原子能机构(IAEA)团队。

作者提供的 Gif:回应“谁领导了 2022 年 8 月 29 日调查该工厂的 IAEA 团队?”的问题。
来自 Google 搜索的回应:

作者提供的图片:来自 Google 搜索的回应:“谁领导了 2022 年 8 月 29 日调查该工厂的 IAEA 团队?”
以下是追踪代理工作流的记录:
Agent custom-at-query-time started with {'query': 'Who
led the IAEA team that investigated the
plant on 29 August\xa02022?', 'params': None}The 'transcript'
parameter is missing from the Agent's prompt
template. All ReAct agents that go through
multiple steps to reach a goal require
this parameter. Please append {transcript} to the
end of the Agent's prompt template to
ensure its proper functioning.
A temporary prompt template with {transcript} appended will
be used for this run.
find out what the IAEA is and
what plant was investigated on 29 August 2022.
Then, I need to find out who led
the IAEA team during this investigation.
I don't have this information at hand,
so I will use the document_qa tool
to find it.
Tool: document_qa
Tool Input: Who led the IAEA team that
investigated the plant on 29 August 2022?
Observation: Rafael Grossi
Thought: The document_qa tool has provided the name
Rafael Grossi in response to the query.
This suggests that Rafael Grossi led the
IAEA team that investigated the plant on
29 August 2022.
Final Answer: Rafael Grossi led the IAEA team
that investigated the plant on 29 August 2022.
出于你的好奇,可以查看ChatGPT 的回应,你也可以尝试自己提问以确认回应。
结论
检索增强生成(RAG)使大型语言模型能够连接到现有的知识库。支持 RAG 的语言模型可以访问最新的信息,使其在各种使用场景中更为实用。
检索增强生成(RAG)方法使模型能够查询大量文本,从而克服了大型语言模型自身面临的上下文限制问题。像 Haystack 这样的开源框架使得快速构建 RAG 支持的 LLM 原型变得简单。
需要注意的一些关键点是,该方法的性能仅与提供的知识库的质量相当。此外,通过在合适的基础设施上部署模型,可以显著加快推理时间。
对于更复杂的查询,代理可能会遇到无法在分配的步骤数量内回应的情况。观察增加步骤数量或加入记忆组件以使流程更具对话性会很有趣。利用更高效的句子嵌入模型也可能有助于提高整体性能。
代码库在GitHub仓库中(包括前端),你可以在这里进行实验。
这里还有一个项目的YouTube演示。
感谢阅读。
如果你希望提升人工智能技能,可以加入我的课程的等待名单,在那里我将引导你开发基于大型语言模型的应用程序。
如果你希望为你的业务实现 AI 转型,今天就预约一次发现电话。
在 Brainqub3,我们开发定制的人工智能软件。我们使用最新的人工智能技术创建 qub3s,即先进的人工智能大脑,以…
www.brainqub3.com](https://www.brainqub3.com/?source=post_page-----99d5f86e9779--------------------------------)
欲了解更多关于人工智能、数据科学和大型语言模型的见解,您可以订阅YouTube频道。
使用持续集成构建可靠的机器学习管道
使用持续集成自动化机器学习工作流程
·发表于 Towards Data Science ·阅读时长 8 分钟·2023 年 4 月 6 日
--
情境
作为数据科学家,你负责改进当前在生产中的模型。在花费几个月时间微调模型后,你发现了一个比原始模型更准确的模型。
对你的突破感到兴奋,你创建了一个拉取请求,将你的模型合并到主分支中。

作者提供的图片
不幸的是,由于众多的变更,你的团队需要超过一周的时间来评估和分析这些变更,这最终阻碍了项目进展。
此外,在部署模型后,你发现由于代码错误导致的意外行为,使公司损失了金钱。

作者提供的图片
回顾来看,在提交拉取请求后自动化代码和模型测试本可以避免这些问题,并节省时间和金钱。
持续集成(CI)为这一问题提供了简单的解决方案。
什么是 CI?
CI 是将代码更改持续合并和测试到共享仓库中的实践。在机器学习项目中,CI 因多种原因非常有用:
-
尽早捕获错误:CI 通过自动测试任何代码更改来促进早期错误的识别,从而在开发阶段实现及时的问题检测
-
确保可重复性:CI 通过建立明确且一致的测试程序来帮助确保可重复性,使得复制机器学习项目结果变得更容易。
-
更快的反馈和决策:通过提供明确的指标和参数,CI 使得反馈和决策更加迅速,释放了审阅者的时间以处理更关键的任务。

作者提供的图片
本文将展示如何为机器学习项目创建一个 CI 流水线。
随意尝试和分叉本文的源代码:
[## GitHub - khuyentran1401/cicd-mlops-demo: 机器学习项目中的 CI/CD 演示
这是文章《构建可靠的机器学习流水线与持续集成》的一个示例项目。CI/CD…
github.com](https://github.com/khuyentran1401/cicd-mlops-demo/?source=post_page-----ea822eb09bf6--------------------------------)
CI 流水线概述
为机器学习项目构建 CI 流水线的方法可以根据每个公司的工作流程有所不同。在本项目中,我们将创建一个最常见的工作流程来构建 CI 流水线:
-
数据科学家对代码进行更改,在本地创建一个新模型。
-
数据科学家将新模型推送到远程存储。
-
数据科学家为更改创建一个拉取请求。
-
CI 流水线被触发以测试代码和模型。
-
如果更改被批准,它们将合并到主分支中。

作者提供的图片
让我们基于这个工作流程举一个例子。
构建工作流
假设实验 C 在尝试了各种处理技术和机器学习模型后表现异常出色。因此,我们的目标是将代码和模型合并到主分支中。

作者提供的图片
为了实现这一点,我们需要执行以下步骤:
-
对实验的输入和输出进行版本控制。
-
将模型和数据上传到远程存储。
-
创建测试文件以测试代码和模型。
-
创建一个 GitHub 工作流。

作者提供的图片
现在,让我们详细探讨这些步骤。
对实验的输入和输出进行版本控制
我们将使用 DVC 对流水线实验的输入和输出进行版本控制,包括代码、数据和模型。
## DVC + GitHub Actions: 自动重新运行流水线中修改过的组件
快速迭代你的数据科学项目的完美组合
towardsdatascience.com
流水线是根据项目中的文件位置定义的:

作者提供的图片
我们将在 dvc.yaml 文件中描述流水线的阶段及其之间的数据依赖关系:
stages:
process:
cmd: python src/process_data.py
deps:
- data/raw
- src/process_data.py
params:
- process
- data
outs:
- data/intermediate
train:
cmd: python src/train.py
deps:
- data/intermediate
- src/train.py
params:
- data
- model
- train
outs:
- model/svm.pkl
evaluate:
cmd: python src/evaluate.py
deps:
- model
- data/intermediate
- src/evaluate.py
params:
- data
- model
metrics:
- dvclive/metrics.json
要运行在 dvc.yaml 中定义的实验流水线,请在终端中输入以下命令:
dvc exp run
我们将获得以下输出:
'data/raw.dvc' didn't change, skipping
Running stage 'process':
> python src/process_data.py
Running stage 'train':
> python src/train.py
Updating lock file 'dvc.lock'
Running stage 'evaluate':
> python src/evaluate.py
The model's accuracy is 0.65
Updating lock file 'dvc.lock'
Ran experiment(s): drear-cusp
Experiment results have been applied to your workspace.
To promote an experiment to a Git branch run:
dvc exp branch <exp> <branch>
运行将自动生成dvc.lock文件,该文件存储数据、代码和它们之间依赖项的精确版本。使用相同版本的输入和输出可以确保将来可以重现相同的实验。
schema: '2.0'
stages:
process:
cmd: python src/process_data.py
deps:
- path: data/raw
md5: 84a0e37242f885ea418b9953761d35de.dir
size: 84199
nfiles: 2
- path: src/process_data.py
md5: 8c10093c63780b397c4b5ebed46c1154
size: 1157
params:
params.yaml:
data:
raw: data/raw/winequality-red.csv
intermediate: data/intermediate
process:
feature: quality
test_size: 0.2
outs:
- path: data/intermediate
md5: 3377ebd11434a04b64fe3ca5cb3cc455.dir
size: 194875
nfiles: 4
将数据和模型上传到远程存储
DVC 使得将管道阶段产生的数据文件和模型上传到dvc.yaml文件中的远程存储位置变得简单。
就像 Git 一样,但带有数据!
towardsdatascience.com
在上传文件之前,我们将在文件.dvc/config中指定远程存储位置:
['remote "read"']
url = https://winequality-red.s3.amazonaws.com/
['remote "read-write"']
url = s3://winequality-red/
确保将 S3 桶的 URI 替换为“读写”远程存储 URI。

图片由作者提供
将文件推送到名为“读写”的远程存储位置:
dvc push -r read-write
创建测试
我们还将生成测试,以验证处理数据、训练模型以及模型本身的代码性能,确保代码和模型符合我们的期望。

图片由作者提供
查看所有测试文件 这里。
创建 GitHub 工作流
现在进入激动人心的部分:创建一个 GitHub 工作流以自动化测试您的数据和模型!如果您不熟悉 GitHub 工作流,我建议阅读这篇文章以获得快速概述。
我们将在文件.github/workflows/run_test.yaml中创建名为Test code and model的工作流:

图片由作者提供
name: Test code and model
on:
pull_request:
paths:
- conf/**
- src/**
- tests/**
- params.yaml
jobs:
test_model:
name: Test processed code and model
runs-on: ubuntu-latest
steps:
- name: Checkout
id: checkout
uses: actions/checkout@v2
- name: Environment setup
uses: actions/setup-python@v2
with:
python-version: 3.8
- name: Install dependencies
run: pip install -r requirements.txt
- name: Pull data and model
env:
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
run: dvc pull -r read-write
- name: Run tests
run: pytest
- name: Evaluate model
run: dvc exp run evaluate
- name: Iterative CML setup
uses: iterative/setup-cml@v1
- name: Create CML report
env:
REPO_TOKEN: ${{ secrets.TOKEN_GITHUB }}
run: |
# Add the metrics to the report
dvc metrics show --show-md >> report.md
# Add the parameters to the report
cat dvclive/params.yaml >> report.md
# Create a report in PR
cml comment create report.md
on字段指定管道在拉取请求事件时触发。
test_model作业包括以下步骤:
-
查看代码
-
设置 Python 环境
-
安装依赖项
-
使用 DVC 从远程存储位置提取数据和模型
-
使用 pytest 运行测试
-
使用 DVC 实验评估模型
-
设置Iterative CML(持续机器学习)环境
-
使用 CML 创建包含指标和参数的报告,并在拉取请求中评论该报告。
请注意,为了使作业正常运行,需要以下内容:
-
AWS 凭证以提取数据和模型
-
GitHub token用于评论拉取请求。
为了确保在我们的仓库中安全存储敏感信息,并允许 GitHub Actions 访问它们,我们将使用加密密钥。

图片来自作者
就这样!现在让我们尝试这个项目,看看它是否按照我们预期的方式工作。
尝试一下
设置
要尝试这个项目,首先创建一个新的仓库,使用项目模板。

克隆仓库到本地机器:
git clone https://github.com/your-username/cicd-mlops-demo
设置环境:
# Go to the project directory
cd cicd-mlops-demo
# Create a new branch
git checkout -b experiment
# Install dependencies
pip install -r requirements.txt
从名为“read”的远程存储位置拉取数据:
dvc pull -r read
创建实验
如果对params.yaml文件或src和tests目录中的文件进行任何更改,将触发 GitHub 工作流。为了说明这一点,我们将对params.yaml文件进行一些小改动:

图片来自作者
接下来,让我们创建一个带有更改的新实验:
dvc exp run
将修改后的数据和模型推送到名为“read-write”的远程存储:
dvc push -r read-write
添加、提交并推送更改到仓库:
git add .
git commit -m 'add 100 for C'
git push origin experiment
创建拉取请求
接下来,通过点击 Contribute 按钮创建拉取请求。

图片来自作者
在仓库中创建拉取请求后,将触发 GitHub 工作流来对代码和模型进行测试。
如果所有测试都通过,将会在拉取请求中添加一个评论,包含新实验的度量和参数。

图片来自作者
这些信息使得审查更容易理解对代码和模型所做的更改。因此,他们可以快速评估这些更改是否符合预期的性能标准,并决定是否批准将 PR 合并到主分支。这是多么酷啊?
结论
恭喜!你刚刚学会了如何为你的机器学习项目创建 CI 管道。我希望这篇文章能激励你创建自己的 CI 管道,以确保可靠的机器学习工作流程。
我喜欢写关于数据科学的概念,并玩各种数据科学工具。你可以通过LinkedIn和Twitter与我联系。
如果你想查看我写的文章中的代码,请给这个仓库加星。关注我在 Medium 上的账号,以便接收我最新的数据科学文章:
[## 使用 Pydantic 和 Prefect 构建全栈机器学习应用程序
用一行代码创建一个机器学习特征工程的 UI
如何构建一个全栈 ML 应用程序与 Pydantic 和 Prefect
4 个 pre-commit 插件以自动化 Python 代码审查和格式化
使用此模板开始您的下一个机器学习项目
数据科学项目的 Pytest 完整指南
使用 black, flake8, isort 和 interrogate 编写高质量代码
使用 Python 自动化 PLAXIS 中的土壤剖面
原文:
towardsdatascience.com/build-soil-profile-in-plaxis-using-python-a9c870e253c7
PLAXIS 自动化系列
自动化的逐步指南
·发表于 Towards Data Science ·10 分钟阅读·2023 年 1 月 4 日
--

作为一名岩土工程师,最重要的 PLAXIS 工作流之一是建立土壤剖面并分配正确的土壤属性。虽然土壤输入界面对用户友好,但这个过程可能会耗时。
可能有改进的空间:
-
在一个表格中创建多个钻孔。
-
自动分配每一层的土壤属性。
-
能够使用重复的土壤属性。当然,这也可以通过 PLAXIS 内置的“.matXdb”来存储材料数据库来实现。然而,Excel 格式提供了更多的灵活性,以根据项目更改材料属性,并且可以链接到其他主电子表格。
本教程旨在扩展从 第四个教程 中学到的内容。我们将进一步开发我们的 Excel 界面,以定义土壤剖面并在 PLAXIS 中分配土壤属性。
-
使用 Pandas 从 Excel 中读取值。
-
使用 Excel 输入土壤深度并创建土壤剖面
-
使用 Excel 输入土壤属性并分配材料
与之前一样,本教程要求读者安装 VS Code 和 PLAXIS 环境。如果你对这个页面不熟悉,请按照下面文章中的说明进行操作。
自动化的逐步指南
towardsdatascience.com
此外,我们还需要在 PLAXIS 环境中安装 pandas。如果还未安装外部模块,请按照以下说明进行安装。
PLAXIS 输出可视化使用 Python
自动化的逐步指南
towardsdatascience.com
Excel 输入界面
类似于第 4 个教程,我们希望创建一个 Excel 输入模板,并在 PLAXIS 中创建土壤剖面。
我们将创建一个空的 Excel 表格,命名为“Soil_input”。该界面包含两个工作表:
-
OHE Ground Profile
-
土壤属性

作者提供的 Excel 截图
创建工作表时,确保它们遵循上述相同的名称,因为名称是区分大小写的。
OHE Ground Profile
“OHE Ground Profile”工作表的目的是定义土壤单元的名称和深度。这些值将用于在 PLAXIS 中创建相应的钻孔。它涉及以下四个输入值:
-
名称:指定每个钻孔的名称。此行不会被 Python 提取,但它是我们工作的钻孔的良好跟踪器。例如,B 列包含“BH1”的所有输入值。
-
x 坐标:定义每个钻孔的 x 坐标。
-
顶部:定义每个钻孔顶部的 y 坐标。换句话说,就是地质工程中的顶部 R.L.。
-
土壤单元名称及层级坐标:从第 4 行开始,我们首先在 A 列定义土壤单元名称。然后我们在每个钻孔下方输入每个单元的底部 y 坐标。
应强调的是,这个模板旨在输入每个单元的底部坐标。另一个重要说明是,我们应在 A 列中输入所有钻孔的所有单元。类似于 PLAXIS 界面的逻辑,如果我们遇到一个在钻孔中不存在的单元,我们将使用与前一个单元相同的 y 坐标。
以“BH1”为例,我们为“SSIV”和“SSIII”都设置 4 米,这意味着“SSIII”在“BH1”中不存在。

作者提供的 Excel 截图
土壤属性
该工作表的目的是定义每个单元的土壤属性。这包括以下列:
-
名称
-
材料
-
单位重量(kN/m³)
-
杨氏模量,E’(kPa)
-
泊松比,v(nu)(-)
-
粘聚力,c’(kPa)
-
摩擦角,phi(度)
-
拉伸强度(kPa)
-
界面强度(-)
-
K0 自动?:T/F 确定 K0 是否在 PLAXIS 中自动计算
-
K0x = K0y?:T/F 确定 K0x 是否等于 K0y
-
K0x(-)
-
K0y(-)

作者提供的 Excel 截图
这些是 PLAXIS 2D 中 Mohr-Coulomb 模型的典型材料属性,只是为了确保所有输入属性都在正确的单位中。还应注意,此模板仅设计用于排水条件。
确保你已经创建了包含上述工作表和数值的 Excel 电子表格。
输入文件创建后,我们准备进入下一步。
步骤 1:使用 Pandas 从 Excel 中读取值
步骤 1 的主要目标是设置土壤轮廓并读取 Excel 输入模板。
首先,我们创建一个空的 Python 文件,并将其命名为 “soil_geometry.py”。
导入模块并启动服务器
与之前的教程类似,我们将首先导入相关模块并启动服务器。
from plxscripting.easy import *
import subprocess, time
import pandas as pd
###############################################
PLAXIS_PATH = r'C:\Program Files\Bentley\Geotechnical\PLAXIS 2D CONNECT Edition V22\\Plaxis2DXInput.exe' # Specify PLAXIS path on server.
PORT_i = 10000 # Define a port number.
PORT_o = 10001
PASSWORD = 'SxDBR<TYKRAX834~' # Define a password.
subprocess.Popen([PLAXIS_PATH, f'--AppServerPassword={PASSWORD}', f'--AppServerPort={PORT_i}'], shell=False) # Start the PLAXIS remote scripting service.
time.sleep(5) # Wait for PLAXIS to boot before sending commands to the scripting service.
# Start the scripting server.
s_i, g_i = new_server('localhost', PORT_i, password=PASSWORD)
s_o, g_o = new_server('localhost', PORT_o, password=PASSWORD)
s_i.new()
g_i.SoilContour.initializerectangular(-15, -10, 15, 10)
读取 Excel 文件
文件位置:C:\Users\phtsang\Desktop\PLAXIS_V22\Python_automation
文件名:我们之前创建的电子表格,即 “Soil_input.xlsx”
由于我们要从 “OHE Ground Profile” 中提取值,我们将使用 “pd.read_excel()” 并通过 “sheet_name” 指定要读取的工作表。
source=r"C:\Users\phtsang\Desktop\PLAXIS_V22\Python_automation"
file="Soil_input"+".xlsx"
soilsheet="OHE Ground Profile"
一旦我们设置好 Excel 模板,就可以使用这些值创建土壤剖面。
步骤 2:使用 Excel 输入土壤深度并创建土壤剖面
在步骤 2 中,我们将从 “OHE Ground Profile” 工作表中提取钻孔值,并相应地创建土壤剖面。
输入表格涉及的信息如下所示。我们需要这些信息来参考我们的代码。

作者从 Excel 截图
我们将使用pandas的方法从 Excel 中提取值。详细说明请参阅以下文章。
## 使用 Python 与 Excel (PLAXIS 输入) 交互
自动化的逐步指南
[towardsdatascience.com
- 首先,我们需要从 “OHE Ground Profile” 工作表中提取值,并将其存储为 “df_soil”。
#Soil
df_soil = pd.read_excel(file, sheet_name = soilsheet,engine="openpyxl")
- 然后,我们需要使用 “BH1” 列中的 X 坐标(即 df_soil.iloc[0,1])创建第一个钻孔。需要注意的是,数据框的第一行被跳过了。
g_i.borehole(df_soil.iloc[0,1]) # Create borehole at x coord
- 使用 PLAXIS 命令 ‘g_i.soillayer(0)’ 创建第一个土壤层
g_i.soillayer(0) # Create first layer in 1st borehole
- 之后,我们将使用如下命令设置第一个钻孔的顶部 y 坐标。
Soillayers[0]:表示第一个土壤层。Soillayers[1] 将表示第二层,依此类推。
Zones[0]:表示第一个钻孔。
Top.set:允许我们指定土壤单元的顶部 y 坐标。只需为第一层设置一次,因为随后的层将自动从前一层获取底部 y 坐标作为其顶部 y 坐标(类似于 PLAXIS 的做法)。
g_i.Soillayers[0].Zones[0].Top.set(df_soil.iloc[1,1]) # Set top y coord of 1st Bh
- 接下来,我们将遍历其他钻孔列(即 BH2–4),以 (1) 创建具有给定 x 坐标的钻孔,并 (2) 设置顶部 y 坐标。
for j in range(len(df_soil.columns)-2):
g_i.borehole(df_soil.iloc[0,j+2]) # X coord for other Bhs
g_i.Soillayers[0].Zones[j+1].Top.set(df_soil.iloc[1,j+2]) #Top y coord for other Bhs
-
一旦我们设置了每个钻孔的顶部坐标,我们就遍历其余的层,并为第一个钻孔设置相应的底部坐标。
-
这里我们需要检查循环是否到达最后一层。如果没有到达最后一层,需要使用命令 ‘g_i.soillayer(1)’ 创建新层,否则,不创建新层。
for i in range(df_soil.count()[0]-2): # Loop through the number of layers
if i == df_soil.count()[0]-3: # Don't create new layer if we are at last unit
g_i.Soillayers[i].Zones[0].Bottom.set(df_soil.iloc[i+2,1]) #Set bottom y coord for 1st Bh
else:
g_i.soillayer(1) #Create new layer if we aren't at last unit
g_i.Soillayers[i].Zones[0].Bottom.set(df_soil.iloc[i+2,1])
- 最后,我们遍历其余的钻孔(除第一个钻孔外),并使用以下代码设置底部坐标。
for j in range(len(df_soil.columns)-2):
for i in range(df_soil.count()[0]-2):
g_i.Soillayers[i].Zones[j+1].Bottom.set(df_soil.iloc[i+2,j+2]) #Set bottom y coord for other Bhs
最终脚本应如下所示:
df_soil = pd.read_excel(file, sheet_name = soilsheet,engine="openpyxl")
g_i.borehole(df_soil.iloc[0,1])
g_i.soillayer(0)
g_i.Soillayers[0].Zones[0].Top.set(df_soil.iloc[1,1])
for j in range(len(df_soil.columns)-2):
g_i.borehole(df_soil.iloc[0,j+2])
g_i.Soillayers[0].Zones[j+1].Top.set(df_soil.iloc[1,j+2])
for i in range(df_soil.count()[0]-2):
if i == df_soil.count()[0]-3:
g_i.Soillayers[i].Zones[0].Bottom.set(df_soil.iloc[i+2,1])
else:
g_i.soillayer(1)
g_i.Soillayers[i].Zones[0].Bottom.set(df_soil.iloc[i+2,1])
for j in range(len(df_soil.columns)-2):
for i in range(df_soil.count()[0]-2):
g_i.Soillayers[i].Zones[j+1].Bottom.set(df_soil.iloc[i+2,j+2])

作者提供的 VS Code 截图
创建土壤剖面后,我们可以为每个土壤单元定义土壤属性。
第 3 步:用 Excel 输入土壤属性并分配材料
在第 3 步中,我们旨在从 “土壤属性” 工作表中提取土壤属性,并将这些属性分配给前一节中在 A 列指定的层对象(即 “SZ”,“SSIV”等)。
输入表中涉及的信息如下所示。

作者提供的 Excel 截图
- 与之前的步骤类似,我们首先从 “土壤属性” 中读取值并将其存储为 dataframe。
soilmatsheet="Soil properties"
df_soilmat = pd.read_excel(file, sheet_name = soilmatsheet,engine="openpyxl")
-
“Excel 输入界面” 下提到的典型土壤属性为 Mohr-Coulomb 模型所用。我们将遍历表中的行,使用 iloc 查找每个值并将其分配给相应的变量。
-
要在 PLAXIS 中使用 Python 访问土壤模型,应使用在 PLAXIS 中显示的数字值。例如,“线性弹性”模型用 1 表示,“莫尔-库仑”模型用 2 表示,以此类推。
for i in range(df_soilmat.count()[0]):
name = df_soilmat.iloc[i,1]
if df_soilmat.iloc[i,2] == 'MC':
materialmodel=2
gammaUnsat=df_soilmat.iloc[i,3]
gammaSat=df_soilmat.iloc[i,3]
Eref=df_soilmat.iloc[i,4]
nu=df_soilmat.iloc[i,5]
cref=df_soilmat.iloc[i,6]
phi= df_soilmat.iloc[i,7]
TensileStrength=df_soilmat.iloc[i,8]
- 然后,如果 Excel 模板中提供了值,我们将分配一个界面强度比。为此,我们需要将 “手动” 作为字符串分配给 ‘InterfaceStrength’,并将给定的比率分配给 ‘Rinter’。如果没有提供值,“刚性” 将自动分配给 ‘InterfaceStrength’。
if df_soilmat.iloc[i,9] >0:
InterfaceStrength = 'Manual'
Rinter=df_soilmat.iloc[i,9]
else:
InterfaceStrength = 'Rigid'
-
土壤属性的最后一个组件是 K0 条件。我们需要决定 K0 确定是 “手动” 还是 “自动”。在 PLAXIS-Python 环境中,“自动”和“手动”分别用 0 和 1 表示。
-
我们还需要根据用户输入检查 K0x 是否等于 K0y。Python 中 “K0x = K0y” 的选中和未选中框分别用 True 和 False 表示。
if df_soilmat.iloc[i,10] =='F':
K0Determination=1
if df_soilmat.iloc[i,11] == 'T':
K0PrimaryIsK0Secondary=True
K0Primary=df_soilmat.iloc[i,12]
K0Secondary=df_soilmat.iloc[i,12]
else:
K0PrimaryIsK0Secondary=False
K0Primary=df_soilmat.iloc[i,12]
K0Secondary=df_soilmat.iloc[i,13]
else:
K0Determination=0
- 然后,使用 PLAXIS 命令 ‘setproperties()’ 设置材料属性,类似于第 4 个教程。
material1 = g_i.soilmat()
material1.setproperties(
"Identification",name,
"SoilModel",materialmodel,
"gammaUnsat", gammaUnsat,
"gammaSat", gammaSat,
"Eref",Eref,
"nu", nu,
"cref", cref,
"phi", phi,
"TensileStrength",TensileStrength,
"InterfaceStrengthDetermination",InterfaceStrength,
"Rinter",Rinter,
"K0Determination",K0Determination,
"K0PrimaryIsK0Secondary",K0PrimaryIsK0Secondary,
"K0Primary",K0Primary,
"K0Secondary",K0Secondary
)
- 之后,我们将土壤材料存储为 ‘soilmat’ 对象。
soilmat=[mat for mat in g_i.Materials[:] if mat.TypeName.value == 'SoilMat']
- 最后一步是根据 “OHE 地面剖面” 工作表中的 “名称” 列(即 A 列)将材料设置到土壤层。我们需要使用 if 检查表中提供的材料名称。如果材料名称与现有材料对象匹配,则使用 ‘Soils.setmaterial()’ 将该材料分配给土壤层。
for j in range(df_soil.count()[0]-2):
for i in range(len(soilmat)):
if df_soil.iloc[j+2,0] == soilmat[i].Name:
g_i.Soils[j].setmaterial(soilmat[i])
最终脚本如下所示:
soilmatsheet="Soil properties"
df_soilmat = pd.read_excel(file, sheet_name = soilmatsheet,engine="openpyxl")
for i in range(df_soilmat.count()[0]):
name = df_soilmat.iloc[i,1]
if df_soilmat.iloc[i,2] == 'MC':
materialmodel=2
gammaUnsat=df_soilmat.iloc[i,3]
gammaSat=df_soilmat.iloc[i,3]
Eref=df_soilmat.iloc[i,4]
nu=df_soilmat.iloc[i,5]
cref=df_soilmat.iloc[i,6]
phi= df_soilmat.iloc[i,7]
TensileStrength=df_soilmat.iloc[i,8]
if df_soilmat.iloc[i,9] >0:
InterfaceStrength = 'Manual'
Rinter=df_soilmat.iloc[i,9]
else:
InterfaceStrength = 'Rigid'
if df_soilmat.iloc[i,10] =='F':
K0Determination=1
if df_soilmat.iloc[i,11] == 'T':
K0PrimaryIsK0Secondary=True
K0Primary=df_soilmat.iloc[i,12]
K0Secondary=df_soilmat.iloc[i,12]
else:
K0PrimaryIsK0Secondary=False
K0Primary=df_soilmat.iloc[i,12]
K0Secondary=df_soilmat.iloc[i,13]
else:
K0Determination=0
material1 = g_i.soilmat()
material1.setproperties(
"Identification",name,
"SoilModel",materialmodel,
"gammaUnsat", gammaUnsat,
"gammaSat", gammaSat,
"Eref",Eref,
"nu", nu,
"cref", cref,
"phi", phi,
"TensileStrength",TensileStrength,
"InterfaceStrengthDetermination",InterfaceStrength,
"Rinter",Rinter,
"K0Determination",K0Determination,
"K0PrimaryIsK0Secondary",K0PrimaryIsK0Secondary,
"K0Primary",K0Primary,
"K0Secondary",K0Secondary
)
soilmat=[mat for mat in g_i.Materials[:] if mat.TypeName.value == 'SoilMat']
for j in range(df_soil.count()[0]-2):
for i in range(len(soilmat)):
if df_soil.iloc[j+2,0] == soilmat[i].Name:
g_i.Soils[j].setmaterial(soilmat[i])



作者在 VS Code 中的截图
使用以下脚本运行。
(PLAXIS) C:\Users\phtsang\Desktop\PLAXIS_V22\Python_automation>python soil_geometry.py
你应该会看到在 PLAXIS 2D 中创建了如下内容。从土壤窗口中可以看到,所有钻孔都是根据 Excel 输入创建的。

作者在 PLAXIS 中的截图

作者在 PLAXIS 中的截图
在土壤属性方面,所有土壤单元都已根据我们在输入表中指定的属性创建。

太棒了!我们刚刚通过 PLAXIS 2D 的用户界面创建了一个带有材料定义的土壤剖面。
结论
这就是第五个关于使用 Python 创建土壤剖面(包括材料定义)的教程。完成此教程后,你应该能够从 Excel 获取土壤输入并创建剖面,同时在 PLAXIS 中分配土壤属性。
如果你喜欢阅读这种内容,可以随时关注我的页面。我会继续发布关于使用 Python 自动化 PLAXIS 的系列教程。除此之外,我还热衷于分享如何使用 Python 自动化工程工作流程的知识。
[## 通过我的推荐链接加入 Medium - Philip Tsang
阅读 Philip Tsang 的每一个故事(以及 Medium 上其他成千上万的作者)。你的会员费用直接支持…
medium.com](https://medium.com/@philip.studio11/membership?source=post_page-----a9c870e253c7--------------------------------)
为 TrailForks 构建推荐系统
我是如何赢得 Outside 2022 创新日奖的
·
关注 发表在 Towards Data Science · 11 分钟阅读 · 2023 年 1 月 4 日
--
摄影:Kristin Snippe 摄于 Unsplash
长背景故事
当我在 2021 年 5 月首次加入 Outside Inc 时,我的工作是构建一个个性化推荐系统,以驱动 Outside Feed。Outside Feed 包含各种户外和活跃生活方式的内容,以文章、视频、Outside 电影和播客等混合媒介格式呈现。我们的目标是激励人们走出户外。
快进到 17 个月后,我们对 Outside Feed 的推荐系统进行了两次重大改造:
-
从逆时间顺序的动态信息流到由混合推荐器驱动的动态信息流,其中包括三种经典的推荐系统功能:协同过滤、基于内容的过滤和热门与趋势。模型训练和推理是基于批次的。使用 80–20 法则,这种系统可以让你在提供个性化推荐时,借助相对较少的努力达到 80%。这里的“低”努力主要是指算法复杂性方面。
-
从基于批次的混合推荐系统到实时推荐系统。我们开始集成 Miso.ai —— 一个第三方工具,用来充分利用 MetaRouter 收集的点击流事件数据,并生成实时推荐。由于我们使用 GraphQL,关于在 Apollo 中封装 REST API 以及在 DataDog 中进行监控,我们学到了很多经验,这些经验值得另写一篇文章。
在这篇文章中,我记录了我参加 Outside Innovation Day 的经验,在一天内使用 Trailforks 数据构建了一个原型小径推荐系统。(PS:是的,Trailforks 现在是 Outside 家族的一部分!)也许因为我的很多同事都是热衷的徒步旅行者、山地自行车爱好者和户外运动爱好者,他们亲身体验了有一个小径推荐系统来帮助发现探索小径的乐趣和放松。他们都友好地为我的项目投了票,我赢得了 Outside Innovation Day 最能体现我们使命的奖项!(非常幸运和开心 😁 ~~~)。
以下是我将分享的三件事:
🏔️ 1. 如何以 20% 的努力构建 80% 个性化的推荐系统(使用 Trailforks 数据)?
🏔️ 2. 为准备黑客马拉松类型的项目考虑的事项
🏔️ 3. 生产化推荐系统的进一步考虑
🏔️ 1. Trail Recommender 的 80–20 法则
有很多高级推荐系统,但协同过滤 + 基于内容的推荐系统的混合真的是个性化推荐系统的最佳选择。它可以以约 20% 的工程努力覆盖 80% 的内容。
协同过滤涉及分析其他具有相似口味的用户的偏好,并基于这些用户的偏好推荐小径。

由 Wen Yang 创建的图像
另一方面,基于内容的过滤涉及分析小径本身的特征,并推荐与用户偏好相似的小径。

Wen Yang 制作的图像
构建一个 80–20 小径推荐系统涉及三个步骤。
步骤 1:首先,你需要收集数据。
有两种数据对于构建推荐系统特别有用。
- 交互数据集:包括
userid、trailid和activitytype。前两个最为重要,而第三个特征activitytype在我的项目中未使用。

预览 Trailforks 互动数据集
- 目录数据集: 包括选定小径的信息,如小径的位置、长度、难度等级、全球排名分数等。

步骤 2: 使用 Altair 进行数据可视化
从互动数据集中,只有 10 个用户,其中一个用户(userid = 454369)探索了 2524 条独特的小径。(PS:他的名字是 Trevor,他是 TrailForks 的创始工程师——这完全合理!)
import altair as alt
bars = alt.Chart(df_raw).mark_bar().encode(
x='n_trails_interacted:Q',
y='userid:O')
text = bars.mark_text(
align='left',
baseline='middle',
dx=3 # Nudges text to right so it doesn't appear on top of the bar
).encode(
text='n_trails_interacted:Q'
)
(bars + text).properties(height=400)

以下代码可以为每个变量名生成一个柱状图:
def plot_variable(variable):
source = df_trail_sub
bar = alt.Chart(source).mark_bar().encode(
x='count():Q',
y=alt.Y(f'{variable}:N', sort='-x')
)
return bar
比如,我们来看看小径方向、物理评分难度标题和国家名称:

起初,我尝试使用 global_rank_score 作为“流行度”特征,但有太多记录缺少这个特征。幸运的是,rating 与 global_rank_score 存在正相关。下面是我最喜欢的可视化类型之一,使用 altair 可以很容易地完成。
def plot_points_bar():
brush = alt.selection(type='interval')
points = alt.Chart(df_trail_sub).mark_point().encode(
x='global_rank_score',
y='rating',
color=alt.condition(brush, 'difficulty_title', alt.value('lightgray'))
).add_selection(
brush
)
bars = alt.Chart(df_trail_sub).mark_bar().encode(
y='difficulty_title',
color='difficulty_title',
x='count(difficulty_title)'
).transform_filter(
brush
)
return points & bars

同样,我们可以使用 difficulty_title 替代 physical_rating 来表示难度等级,因为前者的缺失值比后者少。
-
中等 → 蓝色,黑钻
-
困难 → 黑钻,蓝色
-
极端 → 双黑钻
-
简单 → 绿色
alt.Chart(dt).mark_circle().encode(
x='difficulty_title:O',
y='physical_rating:O',
size= alt.Size('pct:Q',scale=alt.Scale(range=[10, 2000])),
color= 'physical_rating'
).properties(width=400, height=300)

步骤 3a: 使用 **implicit** 构建协同过滤模型
对于协同过滤,你所需要的唯一数据集是互动数据集。这里我计算了三个其他特征:
-
n_times_interacted: 该用户与此小径互动了多少次 -
n_trails_interacted: 该用户探索了多少条小径 -
n_users_interacted: 多少独特用户探索了这条小径
思路是使用 n_times_interacted 作为隐式反馈特征:用户探索某条小径的次数越多,就越能确定用户喜欢这条小径。

模型训练:
import implicit
import scipy.sparse as sparse
n_user = df_raw.userid.nunique()
n_item = df_raw.trailid.nunique()
# Prepare ALS training data
sparse_user_item = sparse.csr_matrix((df_raw['n_times_interacted'].astype(float),
(df_raw['userid'], df_raw['trailid'])))
# initialize a model: set random_state for reproducibility!
model = implicit.als.AlternatingLeastSquares(factors=40,
regularization=0.1,
iterations=15,
random_state=10)
# train the model on a sparse matrix of user/item/confidence weights
model.fit(sparse_user_item)
推理的有用函数:
import requests
from ipywidgets import Image
def get_recommendations(userid, n):
rec, relevance = model.recommend(userid, sparse_user_item[userid],
N=n,
filter_already_liked_items=True
)
df_rec = df_trail.loc[df_trail['trailid'].isin(rec)]
return df_rec
def show_photo(df_rec):
photo_url = df_rec['cover_photo_url'].iloc[0]
image = Image(value=requests.get(photo_url).content)
return image
def show_map(df_rec):
map_url = df_rec['static_map_url'].iloc[0]
image = Image(value=requests.get(map_url).content)
return image
🌄 DEMO: 使用协同过滤为 Trevor 推荐小径
我使用了 pandas-profiling 来了解 Trevor 对小径的偏好,以下是观察结果:
-
physical_rating: 中等,困难
-
difficulty_title: 蓝色,黑钻
-
trailtype: 单轨
-
方向:仅限下坡
-
最多被探索的小径:Fitzsimmons Connector(319 次!)


让我们从协同过滤模型中获取推荐 →

以及第一个推荐小径的照片和地图:

步骤 3b: 基于内容的过滤
由于协同过滤基于交互数据集,这意味着它在冷启动用户场景下不起作用。如果新用户没有探索任何路线,我们将无法找到与该新用户相似的用户。这就是为什么我们需要基于内容的过滤来填补这个空白。
基于内容的过滤完全是关于寻找“相似的路线”,因此我们需要决定两件事:
-
在什么特征上相似
-
如何测量相似性
第一个问题是特征工程问题。以下特征被选为路线特征:
-
数值特征:
stat_climb、stat_descent、stat_distance、rating -
文本特征:
title、difficulty_title、trailtype、direction和country_title
对于第二个问题,我们使用余弦相似度,这是在推荐系统实践中常用的度量方法。
💻 特征工程和相似度计算的代码示例
from sklearn.feature_extraction.text import TfidfVectorizer,CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import pairwise_distances
def model_tfidf_num(trail_data):
# scale numerical data (continuous) : global_rank_score missing values
trail_data_numerical = trail_data[['stat_climb', 'stat_descent',
'stat_distance','rating']]
scaler_num = StandardScaler().fit(trail_data_numerical)
df_num_scaled = scaler_num.transform(trail_data_numerical)
# vectorize text data
tfidf = TfidfVectorizer()
tfidf_matrix = tfidf.fit_transform(trail_data['comb_text_clean']).toarray()
# concatenate and get similarity
all_features = np.concatenate([tfidf_matrix,df_num_scaled],axis=1)
#cosine_sim = cosine_similarity(all_features, all_features)
return all_features
def get_similar(idx, n):
#idx: target item's index
# 1\. compute distance
target_feature = all_features[idx]
couple_dist = pairwise_distances(all_features,
target_feature, metric='cosine')
# 2\. get similar dataframe: no need to filter out the first
# because the first won't be the unseen url
indices = list(
map(lambda x: x.item(), np.argsort(couple_dist.ravel())))
# similar_score
cosine_similarity = 1 - couple_dist[indices].ravel()
df_sim_all = pd.DataFrame(
{"tfidf_index": indices, "similar_score": cosine_similarity})
df_sim = df_sim_all[1:n+1]
df_out = df_sim.merge(df_mapper, on='tfidf_index')
return df_out
🌄 DEMO:使用基于内容的过滤为 Trevor 推荐路线
让我们找到与 Trevor 的第二条最常探索的路线(‘A-Line — Lower’)相似的路线


哇!基于内容的推荐器非常令人印象深刻,它推荐了大多数中等到困难难度的山地车路线和下坡方向的路线。我的同事 Trevor 是一个资深的山地车骑行者,他对这样的推荐非常满意!
好吧,现在艰苦的工作已经完成。我想分享一些反思。
🏔️ 2. 准备黑客马拉松类型项目时需要考虑的事项
从参与黑客马拉松类型的项目中我学到的最大教训是“少即是多”。尽管很诱人去构建一个包括数据科学组件、后端数据库组件以及前端或只是用于更好演示的 streamlit 组件的推荐系统,但黑客马拉松通常时间有限,你需要意识到哪些部分你愿意稍微牺牲一点,以便留出时间做其他事情。对我来说,这些部分是:
-
EDA 和数据可视化:探索性数据分析就像一个黑洞,它可以把你吸进去,并且没有明显的结束点。数据可视化也是如此,你可能会花费无尽的时间来美化一个图表,但这只会将你的项目从 80 提升到 82,这与 80-20 规则相反。我的解决方案是
pandas-profiling和altair。在进行任何数据分析之前,我会使用pandas-profiling来检查缺失值、总体分布和相关性,这将帮助我缩小可以保留的数据范围并为特征工程做准备。altair使用起来非常简单,通过几行紧凑的代码,你可以轻松制作出漂亮的条形图和散点+分布图。我有一个小技巧,就是提前在笔记本中复制几个代码片段,只使用已经在笔记本中的那些代码,以避免任何进一步的诱惑。 -
演示是关于讲故事的。你可以熬夜制作最华丽和最复杂的黑客松项目,但你需要留出足够的时间来准备如何讲述这个故事。一个故事可以从 Why(你为什么要做这个项目的动机和背景)、What(发生了什么,什么令人惊讶,什么有效)、一点点 How(保持在高层次,方程式和代码本身很令人印象深刻,但它会给那些只有 5 分钟时间理解你的项目的人带来不必要的威慑和疲惫)开始。例如,这篇文章绝对过长,细节过多,不符合演示标准。最后,加入一些戏剧性和幽默感总是一个优势。
🏔️ 3. 将推荐系统生产化的进一步考虑
最后,如果你已经看到这里,并且真的想将原型推荐系统投入生产,我有三点幽灵知识*想与大家分享:
幽灵知识 (来源) :
知识存在于某个认知社区中,也许某个社区的核心成员很容易获得,但实际上并没有被写下来,也不清楚如何获取它。
-
批量 vs 实时推荐系统:批量系统听起来简单直接,但它有其自身的复杂性。如果你使用
implicit进行协同过滤并使用cosine similarity检索类似内容,很可能延迟无法满足你的前端 API 要求。你可以尝试使用像 FAISS 这样的索引来提高速度,但它可能仍然高于 500ms。这就是为什么你可能考虑对每个用户和每个项目进行批量计算。根据用户和项目的规模,批量推荐可能需要超过 4 小时,并且直到你重新训练模型才会刷新。一个快速提示是,仅使用过去 90 天内有任何互动的用户,或仅对过去 90 天内发布的项目进行批量更新。 -
“Cold-Start” 可能并不像你想象的那么严重:我记得当我第一次学习推荐系统时,书籍和讲座总是强调“cold-start”问题,好像这是最难解决的场景。然而,对于 Outside Feed 来说,这个问题甚至不在前 3 名。原因是很多人习惯于在 Outside Feed 上阅读最新的内容,因此通过发布日期对内容进行排序,以便展示新项(冷项)是相当容易的,从而能够保证新项获得页面浏览(互动)数据。对我们来说,更难的问题实际上是“如何推荐永恒的内容”?如果我们推荐一篇最初于 2016 年撰写的全时热门文章,即使这篇内容可能符合用户的口味,用户可能会觉得我们没有新鲜的内容可以推荐。
-
第三个知识点是我从另一位推荐系统从业者(Andy - 来自miso.ai的 CTO)那里听到的趣闻:他曾经为一个服装网站构建了一个主页推荐系统。尽管学习排名的系统了解到“黑色”可能是最受欢迎的颜色,因此推荐了首页上全是黑色的不同风格的服装。好吧,模型并没有错,但乍一看效果很糟糕,并且在吸引人们购买衣物方面表现不佳。他开玩笑说,他研究推荐系统已经很多年了,但没有人提到好的颜色搭配设计可能会有很大的影响。
这就是我在 2022 年的最后一篇中等帖子。非常感谢你的支持。如果你有任何想法或反思,或者更好——更多的推荐系统趣闻,请发送给我!我很乐意将它们编入我梦想写的迷人书籍《构建推荐系统的乐趣与悲哀》中 < The Pleasure and Sorrow of Building RecSys> 🐳~~~
在一个小时内构建你的第一个深度学习应用
使用 HuggingFace Spaces 和 Gradio 部署图像分类模型
·
关注 发表在 Towards Data Science ·11 分钟阅读·2023 年 7 月 21 日
--
Thought Catalog 提供的照片,来自 Unsplash
我从事数据分析工作已经将近十年了。时不时地,我会使用机器学习技术从数据中获取见解,而且我对使用经典的机器学习方法感到很自如。
尽管我通过了一些关于神经网络和深度学习的 MOOC 课程,但我从未在工作中使用过这些技术,这个领域对我来说似乎非常具有挑战性。我有所有这些偏见:
-
你需要学习很多东西才能开始使用深度学习:数学、不同的框架(我至少听说过三种:
PyTorch、TensorFlow和Keras)以及网络架构。 -
训练模型需要大量的数据集。
-
没有强大的计算机(它们还必须有 Nvidia GPU),就不可能获得令人满意的结果,因此获取这样的设备相当困难。
-
要使一个机器学习驱动的服务运行起来,需要处理大量的样板工作:你需要处理前端和后端的部分。
我相信分析的主要目标是帮助产品团队根据数据做出正确的决策。如今,神经网络可以显著提升我们的分析能力,例如,自然语言处理可以从文本中获得更多的见解。因此,我决定再尝试一下利用深度学习的力量对我来说会有帮助。
这就是我开始Fast.AI 课程的方式(它在 2022 年初进行了更新,所以我猜内容自之前的 TDS 评论以来可能有所变化)。我意识到,使用深度学习解决任务并不是那么困难。
这个课程采用自上而下的方法。所以你从构建一个工作系统开始,之后才会深入了解所有必要的基础知识和细节。
我在第二周制作了我的第一个机器学习驱动应用程序(你可以在 这里 尝试一下)。这是一个图像分类模型,可以识别我最喜欢的狗品种。令人惊讶的是,即使我的数据集中只有几千张图片,它的表现也很好。这让我感到振奋,我们现在可以如此轻松地构建一个十年前还完全是魔法的服务。

照片由Shakti Rajpurohit拍摄,发布在Unsplash
所以在这篇文章中,你将找到一个关于构建和部署第一个机器学习驱动服务的初学者级教程。
什么是深度学习?
深度学习是机器学习的一个特定应用场景,其中我们使用多层神经网络作为模型。
神经网络非常强大。根据通用逼近定理,神经网络可以逼近任何函数,这意味着它们能够解决任何任务。
目前,你可以将这个模型视为一个黑箱,它接受输入(在我们的例子中——一张狗的图片)并返回输出(在我们的例子中——一个标签)。

作者拍摄的照片
构建模型
你可以在Kaggle上找到这个阶段的完整代码。
我们将使用 Kaggle Notebooks 来构建我们的深度学习模型。如果您还没有 Kaggle 账户,值得注册。Kaggle 是一个流行的数据科学平台,您可以在这里找到数据集、参与竞赛以及运行和分享代码。
您可以在 Kaggle 创建一个 Notebook,并像在本地 Jupyter Notebook 一样执行代码。Kaggle 甚至提供 GPU,所以我们可以很快训练神经网络模型。

图片来源:作者
让我们先导入所有包,因为我们将使用许多 Fast.AI 工具。
from fastcore.all import *
from fastai.vision.all import *
from fastai.vision.widgets import *
from fastdownload import download_url
加载数据
不用说,我们需要一个数据集来训练我们的模型。获取一组图像的最简单方法是使用搜索引擎。
DuckDuckGo搜索引擎有一个易于使用的 API 和方便的 Python 包duckduckgo_search(更多信息),所以我们将使用它。
让我们尝试搜索一张狗的图片。我们已指定license_image = any,只使用具有创作共用许可的图片。
from duckduckgo_search import DDGS
import itertools
with DDGS() as ddgs:
res = list(itertools.islice(ddgs.images('photo samoyed happy',
license_image = 'any'), 1))
在输出中,我们获得了关于图片的所有信息:名称、URL 和大小。
{
"title": "Happy Samoyed dog photo and wallpaper. Beautiful Happy Samoyed dog picture",
"image": "http://www.dogwallpapers.net/wallpapers/happy-samoyed-dog-wallpaper.jpg",
"thumbnail": "https://tse2.mm.bing.net/th?id=OIP.BqTE8dYqO-W9qcCXdGcF6QHaFL&pid=Api",
"url": "http://www.dogwallpapers.net/samoyed-dog/happy-samoyed-dog-wallpaper.html",
"height": 834, "width": 1193, "source": "Bing"
}
现在我们可以使用 Fast.AI 工具来下载图片并显示缩略图。

图片由 Barcs Tamás 提供,来源于 Unsplash
我们看到一只快乐的萨摩耶,这意味着它在正常工作。所以让我们加载更多照片。
我旨在识别五种不同的犬种(我最喜欢的那些)。我将为每种犬种加载图片,并将它们存储在不同的目录中。
breeds = ['siberian husky', 'corgi', 'pomeranian', 'retriever', 'samoyed']
path = Path('dogs_breeds') # defining path
for b in tqdm.tqdm(breeds):
dest = (path/b)
dest.mkdir(exist_ok=True, parents=True)
download_images(dest, urls=search_images(f'photo {b}'))
sleep(10)
download_images(dest, urls=search_images(f'photo {b} puppy'))
sleep(10)
download_images(dest, urls=search_images(f'photo {b} sleep'))
sleep(10)
resize_images(path/b, max_size=400, dest=path/b)
运行此代码后,您将看到 Kaggle 右侧面板上的所有加载的照片。

图片来源:作者
下一步是将数据转换为适用于 Fast.AI 模型的格式——DataBlock。
对于这个对象,您需要指定几个参数,但我只会强调最重要的几个:
-
splitter=RandomSplitter(valid_pct=0.2, seed=18):Fast.AI 要求选择一个验证集。验证集是用来估计模型质量的保留数据。为了防止过拟合,训练时不会使用验证数据。在我们的例子中,验证集是数据集的 20%的随机部分。我们指定了seed参数,以便下次能够精确地重复相同的划分。 -
item_tfms=[Resize(256, method=’squish’)]:神经网络以批量处理图像。这就是为什么我们必须将图片调整为相同的大小。目前我们使用了squish方法,但我们会在后面更详细地讨论它。
我们已经定义了一个数据块。函数show_batch可以向我们展示一组随机的带标签的图像。

照片由 Angel Luciano 提供,来源于 Unsplash | 照片由 Brigitta Botrágyi 提供,来源于 Unsplash | 照片由 Charlotte Freeman 提供,来源于 Unsplash
数据看起来没问题,所以我们继续训练。
训练模型
你可能会感到惊讶,但下面这两行代码将完成所有工作。

我们使用了一个预训练模型(18 层深度的卷积神经网络 — Resnet18)。这就是我们称之为 fine_tune 的原因。
我们训练了模型三轮,这意味着模型看到了整个数据集 3 次。
我们还指定了度量标准 — accuracy (正确标记图片的比例)。你可以在每一轮后的结果中看到这个度量标准(它仅使用验证集计算,以免影响结果)。然而,它没有用于优化过程,仅供参考。
整个过程花了大约 30 分钟,现在我们的模型可以以 94.45% 的准确率预测狗的品种。干得好!但我们能否改善这个结果?
改进模型:数据清理和增强
如果你希望尽快看到你的第一个模型运行,可以先跳过这部分,继续进行模型部署。
首先,让我们看看模型的错误:它是否无法区分柯基和哈士奇或博美和拉布拉多。我们可以使用 confusion_matrix 来实现。请注意,混淆矩阵也是仅使用验证集计算的。

Fast.AI 课程中分享的另一个小窍门是可以使用模型来清理我们的数据。为此,我们可以查看损失最大的图像:这些可能是模型自信度高但错误的情况,或是正确但信心低的情况。

照片由 Benjamin Vang 提供,来源于 Unsplash | 照片由 Xennie Moore 提供,来源于 Unsplash | 照片由 Alvan Nee 提供,来源于 Unsplash
显然,第一张图片的标签不正确,而第二张图片包含了哈士奇和柯基。因此还有改进的空间。
幸运的是,Fast.AI 提供了一个方便的 ImageClassifierCleaner 小部件,它可以帮助我们快速解决数据问题。你可以在你的笔记本中初始化它,然后你将能够更改数据集中的标签。
cleaner = ImageClassifierCleaner(learn)
cleaner
在每个类别之后,你可以运行以下代码来解决问题:删除图片或将其移动到正确的文件夹。
for idx in cleaner.delete(): cleaner.fns[idx].unlink()
for idx,breed in cleaner.change(): shutil.move(str(cleaner.fns[idx]), path/breed)
现在我们可以再次训练我们的模型,并看到准确率提高了:95.4% 对比 94.5%。

正确识别的柯基犬比例已从 88% 提高到 96%。太棒了!

改善我们模型的另一种方法是改变我们的缩放方法。我们使用了压缩方法,但正如你所看到的,它可能会改变自然物体的比例。让我们尝试更具创意地使用增强。
增强是对图片的更改(例如,对比度改进、旋转或裁剪)。这将为我们的模型提供更多变的数据显示,并希望提高其质量。
与 Fast.AI 一样,你只需更改几个参数即可添加增强。

此外,由于在每个周期中模型会看到略有不同的图片,我们可以增加周期数。经过六个周期,我们达到了 95.65% 的准确率——结果稍微好一点。整个过程大约花了一个小时。
下载模型
最后一步是下载我们的模型。这非常简单。
learn.export('cuttest_dogs_model.pkl')
然后你将有一个标准的 pickle 文件(常见的 Python 对象存储格式)。只需在 Kaggle Notebook 右侧面板中选择文件旁的 更多操作,你将可以将模型下载到你的计算机上。

现在我们有了训练好的模型,让我们部署它,这样你就可以将结果分享给全世界。
部署你的模型
我们将使用 HuggingFace Spaces 和 Gradio 来构建我们的网页应用。
设置 HuggingFace Space
HuggingFace 是一家提供实用机器学习工具的公司,例如流行的变换器库或分享模型和数据集的工具。今天我们将使用他们的 Spaces 来托管我们的应用。
首先,如果你还没有注册,你需要创建一个账户。这只需几分钟。请点击这个 链接。
现在是创建新空间的时候了。前往 Spaces 选项卡并点击“创建”按钮。你可以在 文档中找到更多详细说明。
然后你需要指定以下参数:
-
name(它将用于你的应用 URL,所以请谨慎选择),
-
license(我选择了开源 Apache 2.0 许可证)
-
SDK(在这个示例中我将使用 Gradio)。

然后用户友好的 HuggingFace 将向你展示说明。TL;DR 现在你有了一个 Git 仓库,你需要将你的代码提交到那里。
关于 Git 有一个细节。由于你的模型可能相当大,最好设置 Git LFS(大型文件存储),这样 Git 就不会跟踪这个文件的所有更改。有关安装,请参考网站上的说明。
-- cloning repo
git clone https://huggingface.co/spaces/<your_login>/<your_app_name>
cd <your_app_name>
-- setting up git-lfs
git lfs install
git lfs track "*.pkl"
git add .gitattributes
git commit -m "update gitattributes to use lfs for pkl files"
Gradio
Gradio 是一个框架,它允许你只使用 Python 构建愉快且友好的 Web 应用。这就是它成为原型设计的宝贵工具的原因(尤其是对于像我这样没有深厚 JavaScript 知识的人)。
在 Gradio 中,我们将定义我们的接口,指定以下参数:
-
输入 — 一张图像,
-
输出 — 带有五个可能类别的标签,
-
标题、描述 和 一组示例 图像(我们也需要将它们提交到仓库中),
-
enable_queue=True可以帮助应用程序处理大量流量,特别是当它变得非常受欢迎时, -
函数 用于处理输入图像。
要为输入图像获取标签,我们需要定义一个预测函数,该函数加载我们的模型并返回一个包含每个类别概率的字典。
最后,我们将得到 app.py 的以下代码
import gradio as gr
from fastai.vision.all import *
learn = load_learner('cuttest_dogs_model.pkl')
labels = learn.dls.vocab # list of model classes
def predict(img):
img = PILImage.create(img)
pred,pred_idx,probs = learn.predict(img)
return {labels[i]: float(probs[i]) for i in range(len(labels))}
gr.Interface(
fn=predict,
inputs=gr.inputs.Image(shape=(512, 512)),
outputs=gr.outputs.Label(num_top_classes=5),
title="The Cuttest Dogs Classifier 🐶🐕🦮🐕🦺",
description="Classifier trainded on images of huskies, retrievers, pomeranians, corgis and samoyeds. Created as a demo for Deep Learning app using HuggingFace Spaces & Gradio.",
examples=['husky.jpg', 'retriever.jpg', 'corgi.jpg', 'pomeranian.jpg', 'samoyed.jpg'],
enable_queue=True).launch()
如果你想了解更多关于 Gradio 的信息,请阅读文档。
让我们还创建一个requirements.txt文件,其中包含fastai,这样这个库就会在我们的服务器上安装。
所以剩下的唯一步骤就是将所有内容推送到 HuggingFace Git 仓库。
git add *
git commit -am 'First version of Cuttest Dogs app'
git push
你可以在GitHub上找到完整的代码。
在推送文件后,返回到 HuggingFace Space,你将看到类似的图片显示构建过程。如果一切正常,你的应用将在几分钟内运行。

如果出现任何问题,你将看到一个堆栈跟踪。然后你需要返回到你的代码中,修复错误,推送新版本,并再等几分钟。
它正在运行
现在我们可以使用这个模型处理真实照片,例如验证我家狗是否确实是柯基犬。

作者提供的照片
今天我们已经完成了构建深度学习应用程序的整个过程:从获取数据集和拟合模型到编写和部署 Web 应用。希望你已经完成了这个教程,现在你正在生产环境中测试你精彩的模型。
非常感谢你阅读这篇文章。希望它对你有启发。如果你有任何后续问题或评论,请在评论区留言。此外,也不要犹豫,分享你应用程序的链接。
使用 Streamlit 创建你自己的类似 ChatGPT 的应用
利用 OpenAI 的 API 绕过官方 ChatGPT 应用
·
关注 发布于 Towards Data Science ·6 分钟阅读·2023 年 4 月 3 日
--
作者提供的图像 — 使用 Stable Diffusion 创建
这是什么?
当 GPT-4 在 2023 年 3 月 14 日宣布时,我立即注册了 ChatGPT Plus——这是 ChatGPT 应用程序中的一个付费层级,可以立即访问新模型。它每月花费 20 美元,最初非常值得。然而,几天后,我的使用量减少了——别误解我:我仍然经常使用,只是我不确定是否会使用到足以证明其成本的程度。然后,几天前,我通过 OpenAI 的 API 获得了 GPT-4 的访问权限,尽管新模型比其前身 GPT-3.5 贵得多,但我仍然认为通过 API 互动可能对我来说更经济。
但我确实希望在与模型互动时保持类似聊天的体验。虽然已经有相当多的开源应用提供流畅的用户体验,但我不想使用 React 或类似的前端框架——它们非常适合构建出色的网页应用,但这不是我喜欢做的事情。相反,我决定用 Streamlit 构建自己的聊天界面,它提供了一个更基本的用户体验,并且功能远不如其他框架丰富——但对我来说,从零开始开发自己的 UI(而且是用 Python)要有趣得多。😃

作者图片
在本教程中,我将带你了解这个应用程序——所有代码也可以在这个GitHub 仓库中找到。
为什么这很重要?
通过实践学习
除了我已经提到的成本方面,还有一些额外的优势在于构建自己的聊天界面。首先,它迫使我更深入地研究 Chat API,因为到目前为止,我只使用过文本生成 API。使用 Chat API 类似,但有一些关键区别需要注意。
独立性
其次,这使我完全独立于 ChatGPT 应用程序。无论应用程序是否出现了重大故障或者应用程序限制了我可以向模型发送的推理请求数量(目前每 3 小时限制 25 条消息),这些都不适用于我运行自己的应用程序时。
数据隐私
第三,数据隐私。默认情况下,ChatGPT 会收集数据并用于改进服务(尽管可以选择退出)。然而,在使用 API 时,默认情况下不会收集数据,除非我们特别选择加入。更多信息请参见 OpenAI 的API 使用文档。
有趣多了!
最后,如前所述,构建这样的东西要有趣得多(至少对像我这样的极客来说🤓)。我已经在应用中加入了一些功能,例如显示令牌数量和每次对话的价格。也许在某个时候,我可以扩展应用以利用其他模型(例如来自 Hugging Face)🤗。
让我们开始行动吧!💪
构建应用程序
先决条件
为了开发这个应用程序,我们需要确保已安装openai、streamlit和streamlit-chat包:
pip install openai streamlit streamlit-chat
跟踪对话历史
聊天完成指南提到,我们需要将对话历史传递给 API,以便模型理解背景;换句话说,我们必须管理聊天模型的记忆,因为 API 不会为我们处理这一点。为此,我们创建了一个会话状态列表,在会话开始时存储系统消息,然后附加与模型的互动。
if 'messages' not in st.session_state:
st.session_state['messages'] = [
{"role": "system", "content": "You are a helpful assistant."}
]
def generate_response(prompt):
st.session_state['messages'].append({"role": "user", "content": prompt})
completion = openai.ChatCompletion.create(
model=model,
messages=st.session_state['messages']
)
response = completion.choices[0].message.content
st.session_state['messages'].append({"role": "assistant", "content": response})
显示对话
为了展示对话,我们利用了message函数,这个函数来自streamlit-chat包。我们遍历存储的互动,并按时间顺序展示对话,从最早的对话开始(就像在 ChatGPT 中一样)。
from streamlit_chat import message
if st.session_state['generated']:
with response_container:
for i in range(len(st.session_state['generated'])):
message(st.session_state["past"][i], is_user=True, key=str(i) + '_user')
message(st.session_state["generated"][i], key=str(i))
打印附加信息
我认为一个额外有用的功能是打印每次互动的一些元数据。为此,我们可以,例如,打印使用的模型(这可能会在不同互动之间变化)、这次互动使用了多少令牌及其成本(根据OpenAI 的定价页面)。
total_tokens = completion.usage.total_tokens
prompt_tokens = completion.usage.prompt_tokens
completion_tokens = completion.usage.completion_tokens
if model_name == "GPT-3.5":
cost = total_tokens * 0.002 / 1000
else:
cost = (prompt_tokens * 0.03 + completion_tokens * 0.06) / 1000
st.write(
f"Model used: {st.session_state['model_name'][i]}; Number of tokens: {st.session_state['total_tokens'][i]}; Cost: ${st.session_state['cost'][i]:.5f}")

图片由作者提供
请注意,随着对话的延长,令牌的数量(从而价格)会增加。这是因为我们需要提交所有先前的问题和回答,以便模型理解互动的背景。
为了节省开支,因此建议在开始新的聊天话题时清除对话记录。
侧边栏
在侧边栏中,我们提供了切换模型和清除对话历史的选项。此外,我们还可以显示当前对话的累计费用:

图片由作者提供
结论
通过这些步骤,我们成功地开发了一个易于使用且可定制的聊天界面,使我们能够与基于 GPT 的模型互动,而无需依赖像 ChatGPT 这样的应用程序。我们现在可以使用以下命令运行应用程序:
streamlit run app.py
这如何改变了我的工作流程
我现在实际上已经取消了 ChatGPT Plus 的订阅,并且我专门使用我的应用程序与 GPT 模型进行互动。默认情况下,我使用 GPT-3.5 模型,这使得使用这些模型非常实惠。只有在处理更复杂的任务时,或者当我对 GPT-3.5 的结果不完全满意时,我才会切换到 GPT-4。很可能,我会继续随着时间推移向应用程序中添加新功能,因为这是我最喜欢做的事情——敬请期待未来的更新😊
进一步改进的想法
希望这对你有所帮助——请使用本教程作为起点来构建你自己的聊天 UI。我很想了解你正在构建的内容,所以请在评论中与我联系。这里有一些关于如何改进这个应用的想法,供你参考:
编程愉快!
海科·霍茨
👋 关注我在Medium和LinkedIn上的内容,了解更多关于生成式 AI、机器学习和自然语言处理的信息。
👥 如果你在伦敦,可以加入我们的NLP London Meetups。

从零开始使用 Pytorch 构建自己的 Transformer
在 Pytorch 中逐步构建一个 Transformer 模型
·
关注 发表在Towards Data Science ·7 分钟阅读·2023 年 4 月 26 日
--
在本教程中,我们将使用 PyTorch 从头开始构建一个基本的 Transformer 模型。Transformer 模型由 Vaswani 等人提出,在论文“Attention is All You Need”中介绍,是一种针对序列到序列任务(如机器翻译和文本摘要)设计的深度学习架构。它基于自注意力机制,已成为许多最先进自然语言处理模型的基础,如 GPT 和 BERT。
要详细了解 Transformer 模型,请访问这两篇文章:
1. 关于“注意力”和“Transformer”的一切——深入理解——第一部分
2. 关于“注意力”和“Transformer”的一切——深入理解——第二部分
要构建我们的 Transformer 模型,我们将遵循以下步骤:
-
导入必要的库和模块
-
定义基本构建模块:多头注意力、位置逐层前馈网络、位置编码
-
构建编码器和解码器层
-
结合编码器和解码器层以创建完整的 Transformer 模型
-
准备示例数据
-
训练模型
让我们首先导入必要的库和模块。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy
现在,我们将定义 Transformer 模型的基本构建块。
多头注意力

图 2. 多头注意力(来源:作者创建的图像)
多头注意力机制计算序列中每对位置之间的注意力。它由多个“注意力头”组成,这些头捕捉输入序列的不同方面。
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super(MultiHeadAttention, self).__init__()
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def scaled_dot_product_attention(self, Q, K, V, mask=None):
attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
if mask is not None:
attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
attn_probs = torch.softmax(attn_scores, dim=-1)
output = torch.matmul(attn_probs, V)
return output
def split_heads(self, x):
batch_size, seq_length, d_model = x.size()
return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
def combine_heads(self, x):
batch_size, _, seq_length, d_k = x.size()
return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
def forward(self, Q, K, V, mask=None):
Q = self.split_heads(self.W_q(Q))
K = self.split_heads(self.W_k(K))
V = self.split_heads(self.W_v(V))
attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
output = self.W_o(self.combine_heads(attn_output))
return output
MultiHeadAttention 代码用输入参数和线性变换层初始化模块。它计算注意力得分,将输入张量重塑为多个头,并结合所有头的注意力输出。前向方法计算多头自注意力,使模型能够关注输入序列的不同方面。
位置逐层前馈网络
class PositionWiseFeedForward(nn.Module):
def __init__(self, d_model, d_ff):
super(PositionWiseFeedForward, self).__init__()
self.fc1 = nn.Linear(d_model, d_ff)
self.fc2 = nn.Linear(d_ff, d_model)
self.relu = nn.ReLU()
def forward(self, x):
return self.fc2(self.relu(self.fc1(x)))
PositionWiseFeedForward 类扩展了 PyTorch 的 nn.Module,并实现了位置逐层前馈网络。该类初始化时包含两个线性变换层和一个 ReLU 激活函数。前向方法依次应用这些变换和激活函数以计算输出。这个过程使模型在做出预测时能够考虑输入元素的位置。
位置编码
位置编码用于注入输入序列中每个标记的位置信息。它使用不同频率的正弦和余弦函数来生成位置编码。
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_seq_length):
super(PositionalEncoding, self).__init__()
pe = torch.zeros(max_seq_length, d_model)
position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe.unsqueeze(0))
def forward(self, x):
return x + self.pe[:, :x.size(1)]
PositionalEncoding 类初始化时包含输入参数 d_model 和 max_seq_length,创建一个张量来存储位置编码值。该类根据缩放因子 div_term 计算偶数和奇数索引的正弦和余弦值。forward 方法通过将存储的位置信息编码值添加到输入张量中来计算位置编码,从而使模型能够捕捉输入序列的位置信息。
现在,我们将构建编码器和解码器层。
编码器层

图 3. Transformer 网络的编码器部分(来源:原始论文中的图像)
编码器层由一个多头注意力层、一个逐位置前馈层和两个层归一化层组成。
class EncoderLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout):
super(EncoderLayer, self).__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads)
self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask):
attn_output = self.self_attn(x, x, x, mask)
x = self.norm1(x + self.dropout(attn_output))
ff_output = self.feed_forward(x)
x = self.norm2(x + self.dropout(ff_output))
return x
EncoderLayer 类初始化时包含输入参数和组件,包括一个多头注意力模块、一个逐位置前馈模块、两个层归一化模块和一个丢弃层。forward 方法通过应用自注意力、将注意力输出添加到输入张量并归一化结果来计算编码器层输出。然后,它计算逐位置前馈输出,将其与归一化的自注意力输出结合,并在返回处理后的张量之前归一化最终结果。
解码器层

图 4. Transformer 网络的解码器部分(来源:原始论文中的图像)
解码器层由两个多头注意力层、一个逐位置前馈层和三个层归一化层组成。
class DecoderLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout):
super(DecoderLayer, self).__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads)
self.cross_attn = MultiHeadAttention(d_model, num_heads)
self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, enc_output, src_mask, tgt_mask):
attn_output = self.self_attn(x, x, x, tgt_mask)
x = self.norm1(x + self.dropout(attn_output))
attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
x = self.norm2(x + self.dropout(attn_output))
ff_output = self.feed_forward(x)
x = self.norm3(x + self.dropout(ff_output))
return x
解码器层初始化时包含输入参数和组件,如用于掩蔽自注意力和交叉注意力的多头注意力模块、一个逐位置前馈模块、三个层归一化模块和一个丢弃层。
forward 方法通过执行以下步骤计算解码器层输出:
-
计算掩蔽自注意力输出并将其添加到输入张量中,然后进行丢弃和层归一化。
-
计算解码器和编码器输出之间的交叉注意力输出,并将其添加到归一化的掩蔽自注意力输出中,然后进行丢弃和层归一化。
-
计算逐位置前馈输出并将其与归一化的交叉注意力输出结合,然后进行丢弃和层归一化。
-
返回处理后的张量。
这些操作使解码器能够根据输入和编码器输出生成目标序列。
现在,让我们将编码器和解码器层结合起来,创建完整的 Transformer 模型。
Transformer 模型

图 5. Transformer 网络(来源:原始论文中的图像)
将所有内容合并在一起:
class Transformer(nn.Module):
def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
super(Transformer, self).__init__()
self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
self.positional_encoding = PositionalEncoding(d_model, max_seq_length)
self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
self.fc = nn.Linear(d_model, tgt_vocab_size)
self.dropout = nn.Dropout(dropout)
def generate_mask(self, src, tgt):
src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
seq_length = tgt.size(1)
nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()
tgt_mask = tgt_mask & nopeak_mask
return src_mask, tgt_mask
def forward(self, src, tgt):
src_mask, tgt_mask = self.generate_mask(src, tgt)
src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))
enc_output = src_embedded
for enc_layer in self.encoder_layers:
enc_output = enc_layer(enc_output, src_mask)
dec_output = tgt_embedded
for dec_layer in self.decoder_layers:
dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)
output = self.fc(dec_output)
return output
Transformer 类结合之前定义的模块,创建一个完整的 Transformer 模型。在初始化期间,Transformer 模块设置输入参数并初始化各种组件,包括源序列和目标序列的嵌入层、位置编码模块、EncoderLayer 和 DecoderLayer 模块以创建堆叠层、用于投影解码器输出的线性层以及 dropout 层。
generate_mask 方法创建源序列和目标序列的二进制掩码,以忽略填充标记,并防止解码器关注未来的标记。forward 方法通过以下步骤计算 Transformer 模型的输出:
-
使用
generate_mask方法生成源序列和目标序列的掩码。 -
计算源序列和目标序列的嵌入,并应用位置编码和 dropout。
-
通过编码器层处理源序列,更新
enc_output张量。 -
通过解码器层处理目标序列,使用
enc_output和掩码,并更新dec_output张量。 -
将线性投影层应用于解码器输出,获取输出 logits。
这些步骤使 Transformer 模型能够处理输入序列,并基于其组件的组合功能生成输出序列。
准备示例数据
在这个例子中,我们将创建一个玩具数据集用于演示。实际操作中,你会使用更大的数据集,预处理文本,并为源语言和目标语言创建词汇映射。
src_vocab_size = 5000
tgt_vocab_size = 5000
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
max_seq_length = 100
dropout = 0.1
transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)
# Generate random sample data
src_data = torch.randint(1, src_vocab_size, (64, max_seq_length)) # (batch_size, seq_length)
tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length)) # (batch_size, seq_length)
训练模型
现在我们将使用示例数据训练模型。在实际操作中,你会使用更大的数据集,并将其拆分为训练集和验证集。
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
transformer.train()
for epoch in range(100):
optimizer.zero_grad()
output = transformer(src_data, tgt_data[:, :-1])
loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_data[:, 1:].contiguous().view(-1))
loss.backward()
optimizer.step()
print(f"Epoch: {epoch+1}, Loss: {loss.item()}")
我们可以通过这种方式从零开始在 Pytorch 中构建一个简单的 Transformer。所有的大型语言模型都使用这些 Transformer 编码器或解码器模块进行训练。因此,了解最初的网络极为重要。希望这篇文章能帮助所有希望深入了解 LLM 的人。
参考文献
《Attention is all you need》
A. Vaswani、N. Shazeer、N. Parmar、J. Uszkoreit、L. Jones、A. Gomez、{. Kaiser 和 I. Polosukhin。《神经信息处理系统的进展》, 第 5998–6008 页. (2017)
在 Python 中构建基础机器学习模型
关于如何选择合适问题和如何开发基础分类器的详细论文
·
点击查看 发布于 Towards Data Science ·20 min 阅读·2023 年 1 月 2 日
--
照片由 charlesdeluvio 提供,来源于 Unsplash
目前,我们都见过各种基础机器学习(ML)模型的结果。互联网充斥着展示计算机如何识别各种动物的图像、视频和文章,无论识别是否正确。
尽管我们已经朝着更复杂的机器学习模型迈进,例如生成或提升图像的模型,但这些基础模型仍然构成了这些努力的基础。掌握基础知识可以成为未来更大事业的跳板。
所以,我决定自己重新审视基础知识,并构建一个具有几个警告的基本机器学习模型——它必须具有一定的实用性,尽可能简单,并返回合理准确的结果。
然而,与互联网上的许多其他教程不同,我想从头到尾展示我的整个思考过程。因此,编码部分将会稍晚开始,因为理论和实践领域中的问题选择同样重要。最后,我相信理解为什么比如何更为重要。
选择适合机器学习的问题
尽管机器学习可以解决许多挑战,但它并不是一种万能的解决方案。即使我们暂时忽略财务、时间和其他资源成本,机器学习模型在某些方面仍然表现出色,而在其他方面则表现糟糕。
分类是机器学习可能发挥作用的一个很好的例子。每当我们处理真实世界的数据(即我们不处理代码中创建的类别)时,找出定义现象的所有可能规则几乎是不可能的。
正如我之前所写的,如果我们尝试使用基于规则的方法来分类一个物体是否是猫,我们会很快遇到问题。似乎没有定义任何物理对象的特征——有些猫没有尾巴、毛发、耳朵、一只眼睛、不同数量的腿等等,但它们仍然都属于同一类别。
列举所有可能的规则及其例外可能是不可能的,也许甚至没有某种永恒的清单,我们只能在过程中逐步制定。机器学习在某种程度上通过消耗大量数据来进行预测,模仿了我们的思维。
换句话说,我们应该在尝试确定哪种模型最合适、需要多少数据以及开始任务后关注的其他事项之前,仔细考虑我们要解决的问题。
寻求实际应用
制作区分狗和猫的模型确实有趣且有趣,但即使我们将操作规模扩大到巨大的程度,也不太可能获得任何好处。此外,已经有数以百万计的此类模型教程在网上创建。
我决定选择词汇分类,因为它相对较少被写到,并且具有一定的实际应用。我们的 SEO 团队提出了一个有趣的提议——他们需要根据三种类型来分类关键词:
-
信息型 — 寻找关于某个主题的知识的用户(例如,“什么是代理”)
-
交易型 — 寻找产品或服务的用户(例如,“最佳代理”)
-
导航型——用户寻找特定品牌或其分支(例如,“Oxylabs”)
手动分类成千上万的关键词有点麻烦。这样的任务(几乎)完美适合机器学习,尽管存在一个几乎无法解决的固有问题,我将在后面详细说明。
最终,它使数据收集和管理变得比其他情况下要简单得多。SEO 专家使用各种工具来跟踪关键词,其中大多数可以将它们导出到 CSV 表中。只需将类别分配给关键词即可。
构建一个预 MVP
在建立模型之前决定需要多少数据点几乎是不可能的。虽然有一些依赖于既定目标(即,更多或更少的类别),但精确计算这些数据几乎是不可能的。选择一个足够大的数字(例如,1000 条记录)是一个好的起点。
我建议不要一开始就处理整个数据集。由于这是你第一次开发模型,很多事情可能会出错。一般来说,最好先编写代码并在小样本(例如总数据的 10%)上运行,以确保没有语义错误或其他问题。
一旦你得到所需的结果,就开始处理整个数据集。虽然你可能不会完全放弃项目,但你不希望花费几个小时(枯燥)的工作却没有任何成果。
无论如何,有了一些样本,我们可以正式开始开发过程。我选择了 Python,因为它是一种相当常见的语言,并且通过众多库为机器学习提供了不错的支持。
库
-
Pandas。虽然不是绝对必要,但读取和导出 CSV 文件将大大简化我们的工作。
-
SciKit-Learn。这是一个相当强大且灵活的机器学习库,它将成为我们分类模型的基础。在整个教程中,我们将使用各种 sklearn 功能。
-
NLTK(自然语言工具包)。由于我们将处理自然语言,NLTK 完美地完成了这个任务。停用词 是包中绝对必要的内容。
导入
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.feature_selection import SelectKBest, chi2
from nltk.corpus import stopwords
第 1 行
相当自解释。Pandas 允许我们通过创建数据框来读取和写入 CSV 以及其他电子表格文件。由于我们将处理关键词,大多数 SEO 工具会将它们导出为 CSV,这将减少我们需要手动处理的数据。
第 2 行
从 SciKit-Learn 库中,我们将挑选几个东西,TfidfVectorizer 是我们的首选。
向量化器将我们的字符串转换为特征向量,这会导致两个重要的变化。首先,字符串被转换为数值表示。每个唯一的字符串被转换为一个索引,然后转化为向量(矩阵的衍生物)。
句子 #1:“狗是棕色的。”
句子 #2:“狗是黑色的。”
向量化将处理这两个句子并创建一个索引:
E(w) =
[0, if "the"
1, if "dog"
2, if "is"
3, if "brown"
4, if "black"]
除了将字符串转换为数值外,向量化还优化了数据处理。与其多次处理相同的字符串,不如使用相同的索引,类似于文件压缩。
最后,TFIDF(词频-逆文档频率)是衡量文档中词语重要性的一种方法。简单来说,它对每个词进行处理,评估其频率与文档长度的比值,并分配一个加权值。因此,重复出现的词语被认为更重要。
第三行
LogisticRegression 是发现变量之间关系的一种方法。由于我们的任务是经典的分类问题,逻辑回归非常适合,因为它接受某些输入变量 x(关键字),并将其分配一个值 y(信息性/交易性/导航性)。
还有其他选项,例如 LinearSVC,它涉及到更复杂的数学运算。极其简单地说,SVC 会对多个数据点簇进行处理,找到每个簇中最接近对方簇的值。这些值称为支持向量。
一个超平面(即在 n+1 维 空间中的 n 维 几何对象)被绘制成使其与每个支持向量的距离最大化。

作者提供的图片
已有研究表明使用支持向量机 可能在文本分类中产生更好的结果,但这可能是由于任务的复杂性显著增加。这些优势在我们的情况下并不完全相关,因为它们在特征数量达到极高水平时才会显现,因此线性回归应该也能很好地工作。
第四行
Pipeline 是一个灵活的机器学习工具,它让你创建一个对象,将整个过程的多个步骤组合成一个。它有许多好处——从帮助你编写更整洁的代码到防止数据泄漏。
第五行
虽然在我们的情况下并不是绝对必要的,SelectKBest 和 chi2 通过提高准确性和减少训练时间来优化模型。SelectKBest 允许我们设置最大特征数量。
Chi2(或卡方检验)是一种用于变量独立性的统计测试,有助于我们选择最佳特征(因此,SelectKBest)进行训练:

作者提供的图片。
where: c is degrees of freedom (i.e. sample size minus one)
O is the observed value(s)
E is the expected value(s)
期望值是通过接受原假设(变量独立)来计算的。这些值然后与我们的观测值进行对比。如果观测值与期望值有显著偏差,我们可以拒绝原假设,这迫使我们接受变量之间的依赖关系。
如果变量是相关的,它们对于机器学习模型是可接受的,因为这正是我们所寻找的 —— 对象之间的关系。反过来,SelectKBest 获取所有 chi2 结果,并选择那些具有最强关系的结果。
在我们的情况下,由于特征数量相对较少,SelectKBest 可能无法带来我们感兴趣的优化,但一旦数量开始增加,它就变得至关重要。
行 6
我们最终的导入来自 NLTK,我们将仅将其用于 stopwords 列表。不幸的是,默认列表不适合我们当前的任务。大多数这样的列表包含像“how”,“what”,“why”等词,这些词在常规分类中无用,但能指示搜索意图。
事实上,可以说这些词比“如何构建网页抓取器”这样的关键词中的任何剩余词更重要。由于我们对句子的类别感兴趣而非其他值,stopwords 是决定它可能是什么的最佳途径。
因此,删除一些停用词列表中的条目是至关重要的。幸运的是,NLTK 的停用词只是文本文件,你可以使用任何文字处理器进行编辑。
NLTK 下载默认存储在用户目录中,但可以通过使用 download_dir= 进行更改(如果需要的话)。
数据框和停用词
所有机器学习模型都从数据准备和处理开始。由于我们处理的是 SEO 关键词,这些关键词可以通过流行的性能测量工具轻松导出为 CSV。
选择一个随机样本,其中应包括接近相等数量的各类别,这一点是值得注意的。由于我们正在制作一个前期 MVP,这不应该成为问题,因为如果模型提供了我们需要的结果,可以随时添加数据。
在继续之前,明智的做法是从 CSV 文件中选择几打关键词并进行标注。一旦我们得到一个有效的模型,就可以标注其余的。由于 Pandas 以表格格式创建数据框,最简单的方法是添加一个新列,“Category” 或 “Label”,并将每个关键词行标记为 Informational, Transactional, or Navigational。
df = pd.read_csv('[KEYWORD_LIST].csv')
data = pd.DataFrame(df)
words = stopwords.words('english_adjusted')
行 1 和 2
每当我们有任何形式的 CSV 时,Pandas 要求我们创建一个数据框。首先,我们将读取由 SEO 工具提供的关键词列表。请记住,CSV 文件应该已经包含一些关键词分类,否则将没有东西可以用于训练模型。
阅读文件后,我们从 CSV 文件创建一个数据框对象。
行 3
我们将使用 NLTK 获取停用词文件,不过我们不能直接使用它。NLTK 的默认列表包含许多我们认为对关键词分类至关重要的词(例如,“what”,“how”,“where”等)。因此,它必须调整以适应我们的目的。
虽然在这种情况下没有硬性规定,但不定冠词和定冠词可以保留(例如,“a”,“an”,“the”等),因为它们不提供信息。然而,所有可能显示用户意图的内容都必须从默认文件中删除。
我创建了一个名为‘english_adjusted’的副本,以便于操作。此外,以防万一我需要原始版本,它将始终可用,无需重新下载。
最后,你可能需要运行一次 NLTK,使用常规参数‘english’来下载文件,这可以在任何阶段完成。否则,你会收到错误。
设置管道
在所有这些准备步骤之后,我们终于可以进入实际的机器学习部分。这些是模型中最重要的部分。你可能会花费相当多的时间调整这些参数,以找到最佳选项。
不幸的是,没有很多指导方针适用于所有情况。需要进行一些实验和推理,以减少所需的测试量,但完全消除测试是不可能的。
pipeline = Pipeline([('vect', TfidfVectorizer(ngram_range=(1, 3), stop_words=words)),
('chi', SelectKBest(chi2, k='all')),
('clf', LogisticRegression(C=1.0, penalty='l2', max_iter=1000, dual=False))])
有些人可能会注意到我没有通过scikit-learn将数据集拆分为训练集和测试集。这是问题的性质所赋予的奢侈。SEO 工具可以在不到一分钟的时间内导出数千个(未标记的)关键字,这意味着你可以单独采购测试集而不费吹灰之力。
因此,出于优化原因,我将使用没有标签的第二个数据集作为我们的测试基础。然而,由于train_test_split 非常普遍,我将在文章末尾的附录中展示一个使用它的相同模型版本。
第 1 行
管道允许我们将长时间的过程简化为一个对象,使处理模型设置变得更加容易。它还将减少出错的可能性。
我们将从定义我们的向量化器开始。我在上面提到过我们将使用TFIDFVectorizer,因为它根据文档中单词的权重来产生更好的结果。CountVectorizer 是一个选项,但你需要导入它,结果可能会有所不同。
Ngram_range 是一个有趣的推理挑战。为了获得最佳结果,你必须决定要计算多少个词元(在我们的情况下是单词)。Ngram_range 为 (1, 1) 会计算单个词(单词),(1, 2) 会计算单个词和两个相邻的词(双词组)的组合,(1, 3) 会计算单个词、两个词和三个词(三词组)的组合。
我选择了ngram_range(1, 3),有几个原因。首先,由于模型相对简单,性能不是问题,我可以运行更大范围的 n-gram,因此下限可以设置为最小。
另一方面,一旦我们去除停用词,我们应该考虑什么样的 ngram 上限足以从关键词中提取意义。如果可能,我发现从数据集中选择最难和最简单的例子更容易。在我们的情况下,最简单的例子是任何问题(“如何获取代理”),最难的是名词(“网络爬虫”)或名称(“Oxylabs”)。
由于我们将移除像“to”这样的词,我们会在问题案例中得到三元组(“how get proxies”),这是完全清晰的。事实上,你可以认为二元组(“how get”)也足够,因为意图仍然清晰。
然而,最难的例子通常会比三元组短,因为理解搜索意图的难易程度与查询长度相关。因此,ngram_range (1, 3) 应该在性能和准确性之间取得一个不错的平衡。
最后,对于 sublinear_tf 有一个论点,即它是常规 TF-IDF 计算的一个修改。如果设置为 True,权重通过对数函数计算:1 + log(tf)。换句话说,词频会获得递减的回报。
使用 sublinear_tf 时,频繁出现且出现在多个文档中的词语不会被赋予过重的权重。由于我们有一组相对随机的关键词,我们无法知道哪些会得到优待,但这些通常是像“how”,“what”等我们希望被赋予较重权重的词。
在测试过程中,我发现模型在没有 sublinear_tf 的情况下表现更好,但我建议稍微调整一下,看看是否会带来任何好处。
Stopwords 参数现在已经不言自明,因为我们之前已经讨论过了。
第 2 行
虽然不严格来说是新的一行,但我将为清晰和简洁的目的将其分开。我们现在将调用 SelectKBest,我在上面已经对其进行了相当详细的描述。我们的关注点是 k 值。
这些会有所不同,具体取决于你的数据集的大小。SelectKBest 旨在优化性能和准确性。在我的情况下,使用“all”是有效的,但你通常需要选择一个足够大的 N 来匹配你自己的数据集。
第 3 行
最后,我们来到将用于模型的方法。LogisticRegression 是我们的选择,如前所述,但需要对参数进行大量的调整。
“C”值是一个超参数,它告诉模型应该选择哪些参数。超参数是模型中非常复杂的部分,对最终结果有着巨大的影响。
从极其简单的角度来看,C 值是你训练数据的信任分数。较高的 C 值意味着在拟合时,对训练数据的权重会较高,而对惩罚的权重较低。较低的 C 值则将更多强调惩罚,训练数据的权重较低。
应始终存在一定的惩罚,因为训练永远无法完全代表现实世界的值(因为它只是一个小的子集,无论你收集多少)。此外,如果存在异常值而不进行惩罚,模型将会越来越贴近过拟合。
penalty 参数是用于超参数的操作。SciKit-Learn 提供了三种类型的惩罚——‘l1’、‘l2’和‘elasticnet’。‘None’也是一个选项,但如果使用的话应该尽量少。
‘L1’ 是所有系数的绝对值之和。简单来说,它将所有系数拉向某个中心点。如果施加了大的惩罚,一些数据点可能会变成零(即被消除)。
在存在多重共线性(多个变量相关)或需要简化模型的情况下,应该使用‘L1’。由于L1会消除一些数据点,因此模型几乎总是变得更简单。然而,当数据点的分布已经相对简单时,它的效果不如预期。
‘L2’ 是类似过程的不同版本。它不是绝对和,而是所有系数值的平方和。因此,所有系数都按相同的值缩小,但没有被消除。‘L2’ 是默认设置,因为它最灵活且很少引发问题。
‘Elasticnet’ 是上述两种方法的结合。关于是否应该将‘elasticnet’作为默认方法,已有相当广泛的评论,然而,并不是所有的求解器都支持它。在我们的情况下,我们需要切换到“saga”求解器,它是为大型数据集设计的。
在教程级别的机器学习模型中使用‘elasticnet’可能收益甚微。只需记住,将来它可能会有益。
继续讨论‘max_iter’,该参数将设置模型在收敛之前执行的最大迭代次数。简单来说,收敛是指进一步迭代不太可能发生的点,作为停止点。
较高的值会增加计算复杂性,但可能会导致更好的整体表现。在数据集相对简单的情况下,‘max_iter’ 可以设置为数千及以上,因为这对系统的负担不会太大。
如果值过低且收敛失败,将显示警告消息。因此,找到最低可能的值并从中开始并不困难。
拟合模型并输出数据
我们接近教程的结束,最终进入模型拟合和接收输出的阶段。
model = pipeline.fit(data.Keyword, data.Type)
chi = model.named_steps['chi']
clf = model.named_steps['clf']
doutput = pd.read_csv('[TEST_KEYWORD_LIST].csv')
doutput['Type'] = model.predict(doutput['Keyword'])
doutput.to_csv('[RESULT_LIST].csv')
##print('Accuracy score ' + str(model.score(x_test, y_test)))
第 1-3 行
在第 1 行中,我们使用我们建立的管道将模型拟合到训练数据中。如果需要进行调试或额外的分析,管道允许我们创建命名的步骤,这些步骤可以在后续调用。
第 4–8 行
我们从一个只包含关键词的 CSV 文件中创建另一个数据框。我们将使用新创建的模型来预测每个关键词及其类别。
由于我们的数据框仅包含关键词,我们添加了一个新的列“类型”,并运行model.predict以提供输出结果。
最终,所有结果被移动到一个输出的 CSV 文件中,该文件将在本地目录中创建。通常,你会想设置一些目标,但为了测试目的,通常没有必要这样做。
有一行被注释掉的代码我想提一下,它调用了score函数。SciKit为我们提供了多种方法来估计模型的预测能力。这些方法不应被视为绝对真理,因为预测准确度与实际准确度通常可能有所偏差。
然而,得分作为经验法则和快速评估参数对模型的影响是有用的。虽然有很多评分方法,但基本的model.score使用R 平方,在调整参数时通常很有帮助。
结果分析
我的训练数据仅有 1300 条条目,包含三种不同的类别,如上所述。即使在这样的小数据集中,模型仍然达到了约 80%的不错准确度。
其中一些,如预期的那样,是有争议的,甚至 Google 也这么认为。例如,“网页抓取”是一个经常被搜索的关键词。是否查询是交易性的还是信息性的没有明确的指示。Google 的搜索结果页面显示,前 5 条结果中有产品和信息文章。
模型在一个领域遇到了困难——导航关键词。如果我要猜测,模型大约 5-10%的时间能正确预测类别。出现这种情况有几个原因。
数据集的分布可能是一个问题,因为它严重不平衡:
-
交易型 — 0.457353%
-
信息型 — 0.450735%
-
导航型 — 0.091912%
虽然实际世界场景会呈现出类似的分布(由于导航关键词的固有稀有性),但训练数据过于稀疏,无法进行适当的拟合。此外,导航关键词的频率非常低,以至于模型通过总是分配其他两个类别可以获得更高的准确性。
然而,我认为展示更多的导航关键词的训练数据不会产生更好的结果。这是一个通过文本分析解决的极其困难的问题,无论我们选择哪种方法。
导航关键词主要由品牌名称组成,这些名称是新造词或其他新产生的词。它们中没有任何内容遵循自然语言,因此,它们之间的联系只能事后发现。换句话说,我们必须首先从其他数据源知道这是一个品牌名称,才能正确分配类别。
如果我得猜测,谷歌和其他搜索引擎可能通过用户查询新词时的行为来发现品牌名称。他们可能会查找域名匹配或其他数据,但在没有人工互动的情况下预测某个词是导航关键词是极其困难的。
特征工程可能是解决问题的潜在方案。我们需要发现导航类别和其他类别之间的新联系,并通过其他方法实施分配。
由于特征工程是一个完全不同的主题,并且值得单独写一篇文章,所以我将提供一个示例。导航关键词很少会以问题的形式被查询(除了“什么是”),否则它们没有意义(例如,“如何使用 Oxylabs”,“如何获取 Oxylabs”)。
是否将“如何获取 Oxylabs 代理”视为交易型还是导航型存在争议。然而,它确实符合交易型类别,因此可以被认为是交易型。
通过知道相对较少的导航关键词会以问题的形式出现,我们可以构建一个模型来过滤掉大多数问题,留下较小的潜在目标子集。
此外,许多导航关键词的查询长度显著较短,通常由单个词组成,而其他类别的查询长度相对较少。
这些方法及其他许多方法可以组合使用,以提高选择导航关键词时模型的准确性。然而,深入特征工程要比基础教程覆盖的内容复杂得多。
目前,词汇分类应通过对机器学习模型如何工作的整体更好理解来覆盖。希望对众多参数和工具的解释能让你从一开始就创建一个功能性的模型。
结论
即使文章非常长,你可能已经注意到,编写机器学习模型并不那么困难。事实上,可以说,在这个案例中,这只是项目的一小部分。
机器学习在很大程度上依赖于准备,我们可以概述几个部分:
-
选择正确的问题。有些问题用其他方法解决可能更好。不要被炒作所迷惑,尝试通过机器学习解决所有问题。使用基于规则的系统,你可能能够节省时间和资源,同时产生更好的结果。
-
准备数据。一个模型的好坏取决于数据。如果你的数据标记不正确、缺乏真实性或其他方面存在问题,那么再多的开发和资源也无法创建出可靠的输出。
-
选择模型。由于你已经做过很多次,可能很容易选择逻辑回归或其他模型。Sci-Kit Learn还有其他选项,比如我没有提到的PassiveAggressiveClassifier,它们使用不同的数学方法。再次强调,选择正确的问题非常重要,因为它应该决定你选择什么建模方法。
我希望这篇文章能为许多机器学习新手提供帮助,不仅提供实践部分,还提供处理问题的思路。
附录:原始完整代码块
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from nltk.corpus import stopwords
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.feature_selection import SelectKBest, chi2
df = pd.read_csv('[KEYWORD_LIST].csv')
data = pd.DataFrame(df)
words = stopwords.words('english_adjusted')
pipeline = Pipeline([('vect', TfidfVectorizer(ngram_range=(1, 3), stop_words=words)),
('chi', SelectKBest(chi2, k='all')),
('clf', LogisticRegression(C=1.0, penalty='l2', max_iter=1000))])
model = pipeline.fit(data.Keyword, data.Type)
chi = model.named_steps['chi']
clf = model.named_steps['clf']
doutput = pd.read_csv('[TEST_KEYWORD_LIST].csv')
doutput['Type'] = model.predict(doutput['Keyword'])
doutput.to_csv('[RESULT_LIST].csv')
##print('accuracy score ' + str(model.score(x_test, y_test)))
附录 II:Train_test_split
导入
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.feature_selection import SelectKBest, chi2
from nltk.corpus import stopwords
按照惯例,我们需要导入train_test_split本身(第 5 行)。
设置拆分
x_train, x_test, y_train, y_test = train_test_split(data.Keyword, data.Type, test_size=0.3)
pipeline = Pipeline([('vect', TfidfVectorizer(ngram_range=(1, 3), stop_words=words)),
('chi', SelectKBest(chi2, k='all')),
('clf', LogisticRegression(C=1.0, penalty='l2', max_iter=1000, dual=False))])
第 1 行
由于我们的数据集只有两个特征(keyword和category),我们需要为每个特征准备两个变量。其中一个用于存储训练数据,另一个用于测试目的。
我们将使用之前步骤中创建的数据框,并指定列名(在我的数据集中,它们被称为“Keyword”和“Type”,如参数所示)。
最后,SciKit-Learn通过允许对两个数据集进行自动分割来解决数据拆分问题。train_test_split接受表示测试集大小或训练集大小百分比的浮点和整数值。如果两个值都设置为None,默认值将为 0.25。
需要进行一些调整才能获得最佳结果。我尝试了许多不同的拆分,其中 0.3 产生了最佳结果。一般来说,你会发现许多模型在 0.2 到 0.3 范围内的拆分效果最佳。
特定的拆分对准确性的影响较小,当数据点数量增加时更是如此。实际上,在极大的数据集上,拆分为 0.1 可能会提高计算性能。
统计单位之间的关系很复杂,但是可以建立的连接的抽象领域是有限的,因此准确性可以理解为对一定数量的数据点的要求,而不是特定的比例。换句话说,有一个N,在这个点上结果不会再变得更好,因此如果数据集很大,较小的比例可能更为优化。
有一些关于这个主题的高度技术文章深入解释了这个想法,并提供了计算最佳拆分的方法。
此代码块中的其他部分遵循与原始教程相同的步骤。
拟合模型并输出数据(再次)
model = pipeline.fit(x_train, y_train)
chi = model.named_steps['chi']
clf = model.named_steps['clf']
doutput = pd.DataFrame({'Keyword': x_test, 'Type': model.predict(x_test)})
doutput.to_csv('[RESULT_LIST].csv')
##print('Accuracy score ' + str(model.score(x_test, y_test)))
第 1 行
我们不会直接在标记数据集上训练模型,而是在之前拆分的那个数据集上进行训练,命名为x_train和y_train。第 2 行和第 3 行保持不变。
第 4 行
由于没有单独的数据集,我们将使用初始数据集中的测试部分进行预测。因此,我们创建一个数据框,其中包含关键词这一列,我们将在该列中输出测试数据集中的所有关键词。在第二列类型中,我们将使用模型来预测关键词的类别,依然使用相同的数据集。
最终,按照原始版本,所有结果将输出到一个结果文件中。如果有人对模型的表现如何感兴趣,也可以选择打印准确率分数。
使用 Athena 和 MySQL 构建批量数据管道
原文:
towardsdatascience.com/building-a-batch-data-pipeline-with-athena-and-mysql-7e60575ff39c
初学者的端到端教程
·发布于 Towards Data Science ·16 min 阅读·2023 年 10 月 20 日
--

在这个故事中,我将讲述一种非常流行的数据转换任务执行方式——批量数据处理。当我们需要以块状方式处理数据时,这种数据管道设计模式变得极其有用,非常适合需要调度的 ETL 作业。我将通过使用 MySQL 和 Athena 构建数据转换管道来展示如何实现这一目标。我们将使用基础设施即代码在云中部署它。
想象一下,你刚刚作为数据工程师加入了一家公司。他们的数据堆栈现代、事件驱动、成本效益高、灵活,并且可以轻松扩展以满足不断增长的数据资源。你数据平台中的外部数据源和数据管道由数据工程团队管理,使用具有 CI/CD GitHub 集成的灵活环境设置。
作为数据工程师,你需要创建一个业务智能仪表板,展示公司收入来源的地理分布,如下所示。原始支付数据存储在服务器数据库(MySQL)中。你想构建一个批量管道,从该数据库中每日提取数据,然后使用 AWS S3 存储数据文件,并使用 Athena 进行处理。

收入仪表板。图像由作者提供。
批量数据管道
数据管道可以被视为一系列数据处理步骤。由于这些阶段之间的逻辑数据流连接,每个阶段生成的输出作为下一个阶段的输入。
只要在点 A 和点 B 之间进行数据处理,就存在数据管道。
数据管道可能因其概念和逻辑性质而有所不同。我之前在这里写过 [1]:
选择合适的架构及其示例
我们希望创建一个数据管道,在以下 步骤 中转换数据:
1. 使用 Lambda 函数将数据从 MySQL 数据库表 myschema.users 和 myschema.transactions 提取到 S3 数据湖桶中。
2. 添加一个具有 Athena 资源的状态机节点以启动执行 (arn:aws:states:::athena:startQueryExecution.sync) 并创建一个名为 mydatabase 的数据库
3. 创建另一个数据管道节点以显示 Athena 数据库中的现有表。使用该节点的输出执行所需的数据转换。
如果表不存在,我们希望我们的管道在 Athena 中根据来自数据湖 S3 桶的数据创建它们。我们希望创建两个 外部表,数据来自 MySQL:
-
mydatabase.users (LOCATION ‘s3://<YOUR_DATALAKE_BUCKET>/data/myschema/users/’)
-
mydatabase.transactions (LOCATION ‘s3://<YOUR_DATALAKE_BUCKET>/data/myschema/transactions/’)
然后我们希望创建一个 优化的 ICEBERG 表:
- mydatabase.user_transactions (‘table_type’=’ICEBERG’, ‘format’=’parquet’) 使用以下 SQL:
SELECT
date(dt) dt
, user_id
, sum(total_cost) total_cost_usd
, registration_date
FROM mydatabase.transactions
LEFT JOIN mydatabase.users
ON users.id = transactions.user_id
GROUP BY
dt
, user_id
, registration_date
;
- 我们还将使用 MERGE 来更新此表。
MERGE 是一种非常有用的 SQL 技巧,用于表中的增量更新。查看我之前的故事 [3] 以获取更高级的示例:
从 1 到 10,你的数据仓库技能有多好?
Athena 可以通过运行有吸引力的即席 SQL 查询来分析存储在 Amazon S3 中的结构化、非结构化和半结构化数据,无需管理基础设施。
我们不需要加载数据,这使得它成为我们任务的完美选择。
它可以轻松地与 Business Intelligence (BI) 解决方案如 QuickSight 集成以生成报告。
ICEBERG 是一种非常有用且高效的表格格式,多个独立程序可以同时且一致地处理相同的数据集 [2]。我之前在这里写过:
选择 Apache Iceberg 作为数据湖的几个有力理由
MySQL 数据连接器
让我们创建一个 AWS Lambda 函数,它能够在 MySQL 数据库中执行 SQL 查询。
代码非常简单且通用。它可以在任何无服务器应用程序中与任何云服务提供商一起使用。
我们将使用它将收入数据提取到数据湖中。建议的 Lambda 文件夹结构如下所示:
.
└── stack
├── mysql_connector
│ ├── config # config folder with environment related settings
│ ├── populate_database.sql # sql script to create source tables
│ ├── export.sql # sql script to export data to s3 datalake
│ └── app.py # main application file
├── package # required libraries
│ ├── PyMySQL-1.0.2.dist-info
│ └── pymysql
├── requirements.txt # required Python modules
└── stack.zip # Lambda package
我们将通过 AWS Step Functions 将这个小服务集成到管道中,以便于 编排和可视化。
为了创建一个能够从 MySQL 数据库中提取数据的 Lambda 函数,我们需要先为我们的 Lambda 创建一个文件夹。首先创建一个名为 stack的新文件夹,然后在其中创建一个名为mysql_connector` 的文件夹:
mkdir stack
cd stack
mkdir mysql_connector
然后我们可以使用下面的代码(将数据库连接设置替换为你的设置)来创建 app.py:
import os
import sys
import yaml
import logging
import pymysql
from datetime import datetime
import pytz
ENV = os.environ['ENV']
TESTING = os.environ['TESTING']
LAMBDA_PATH = os.environ['LAMBDA_PATH']
print('ENV: {}, Running locally: {}'.format(ENV, TESTING))
def get_work_dir(testing):
if (testing == 'true'):
return LAMBDA_PATH
else:
return '/var/task/' + LAMBDA_PATH
def get_settings(env, path):
if (env == 'staging'):
with open(path + "config/staging.yaml", "r") as f:
config = yaml.load(f, Loader=yaml.FullLoader)
elif (env == 'live'):
with open(path + "config/production.yaml", "r") as f:
config = yaml.load(f, Loader=yaml.FullLoader)
elif (env == 'test'):
with open(path + "config/test.yaml", "r") as f:
config = yaml.load(f, Loader=yaml.FullLoader)
else:
print('No config found')
return config
work_dir = get_work_dir(TESTING)
print('LAMBDA_PATH: {}'.format(work_dir))
config=get_settings(ENV, work_dir)
print(config)
DATA_S3 = config.get('S3dataLocation') # i.e. datalake.staging.something. Replace it with your unique bucket name.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# rds settings
rds_host = config.get('Mysql')['rds_host'] # i.e. "mymysqldb.12345.eu-west-1.rds.amazonaws.com"
user_name = "root"
password = "AmazingPassword"
db_name = "mysql"
# create the database connection outside of the handler to allow connections to be
# re-used by subsequent function invocations.
try:
conn = pymysql.connect(host=rds_host, user=user_name, passwd=password, db=db_name, connect_timeout=5)
except pymysql.MySQLError as e:
logger.error("ERROR: Unexpected error: Could not connect to MySQL instance.")
logger.error(e)
sys.exit()
logger.info("SUCCESS: Connection to RDS MySQL instance succeeded")
def lambda_handler(event, context):
processed = 0
print("")
try:
_populate_db()
_export_to_s3()
except Exception as e:
print(e)
message = 'Successfully populated the database and created an export job.'
return {
'statusCode': 200,
'body': { 'lambdaResult': message }
}
# Helpers:
def _now():
return datetime.utcnow().replace(tzinfo=pytz.utc).strftime('%Y-%m-%dT%H:%M:%S.%f')
def _populate_db():
try:
# Generate data and populate database:
fd = open(work_dir + '/populate_database.sql', 'r')
sqlFile = fd.read()
fd.close()
sqlCommands = sqlFile.split(';')
# Execute every command from the input file
for command in sqlCommands:
try:
with conn.cursor() as cur:
cur.execute(command)
print('---')
print(command)
except Exception as e:
print(e)
except Exception as e:
print(e)
def _export_to_s3():
try:
# Generate data and populate database:
fd = open(work_dir + '/export.sql', 'r')
sqlFile = fd.read()
fd.close()
sqlCommands = sqlFile.split(';')
# Execute every command from the input file
for command in sqlCommands:
try:
with conn.cursor() as cur:
cur.execute(command.replace("{{DATA_S3}}", DATA_S3))
print('---')
print(command)
except Exception as e:
print(e)
except Exception as e:
print(e)
要使用 AWS CLI 部署我们的微服务,请在命令行中运行以下命令(假设你在 ./stack 文件夹中):
# Package Lambda code:
base=${PWD##*/}
zp=$base".zip" # This will return stack.zip if you are in stack folder.
echo $zp
rm -f $zp # remove old package if exists
pip install --target ./package pymysql
cd package
zip -r ../${base}.zip .
cd $OLDPWD
zip -r $zp ./mysql_connector
确保在运行下一部分之前 AWS Lambda 角色已经存在 — role arn:aws:iam::<your-aws-account-id>:role/my-lambda-role。
# Deploy packaged Lambda using AWS CLI:
aws \
lambda create-function \
--function-name mysql-lambda \
--zip-file fileb://stack.zip \
--handler <path-to-your-lambda-handler>/app.lambda_handler \
--runtime python3.12 \
--role arn:aws:iam::<your-aws-account-id>:role/my-lambda-role
# # If already deployed then use this to update:
# aws --profile mds lambda update-function-code \
# --function-name mysql-lambda \
# --zip-file fileb://stack.zip;
我们的 MySQL 实例必须具备 S3 集成,以便 将数据导出到 S3 桶。这可以通过运行以下 SQL 查询实现:
-- Example query
-- Replace table names and S3 bucket location
SELECT * FROM myschema.transactions INTO OUTFILE S3 's3://<YOUR_S3_BUCKET>/data/myschema/transactions/transactions.scv' FIELDS TERMINATED BY ',' LINES TERMINATED BY '\n' OVERWRITE ON;
如何创建 MySQL 实例
我们可以使用 CloudFormation 模板和基础设施即代码来创建 MySQL 数据库。考虑这个 AWS 命令:
aws \
cloudformation deploy \
--template-file cfn_mysql.yaml \
--stack-name MySQLDB \
--capabilities CAPABILITY_IAM
它将使用 cfn_mysql.yaml 模板文件来创建名为 MySQLDB 的 CloudFormation 堆栈。我之前在这里写过有关它的内容 [4]:
## 使用 AWS CloudFormation 创建 MySQL 和 Postgres 实例
数据库从业人员的基础设施即代码
towardsdatascience.com
我们的 cfn_mysql.yaml 应该如下所示:
AWSTemplateFormatVersion: 2010-09-09
Description: >-
This
template creates an Amazon Relational Database Service database instance. You
will be billed for the AWS resources used if you create a stack from this
template.
Parameters:
DBUser:
Default: root
NoEcho: 'true'
Description: The database admin account username
Type: String
MinLength: '1'
MaxLength: '16'
AllowedPattern: '[a-zA-Z][a-zA-Z0-9]*'
ConstraintDescription: must begin with a letter and contain only alphanumeric characters.
DBPassword:
Default: AmazingPassword
NoEcho: 'true'
Description: The database admin account password
Type: String
MinLength: '8'
MaxLength: '41'
AllowedPattern: '[a-zA-Z0-9]*'
ConstraintDescription: must contain only alphanumeric characters.
Resources:
### Role to output into s3
MySQLRDSExecutionRole:
Type: "AWS::IAM::Role"
Properties:
AssumeRolePolicyDocument:
Version: "2012-10-17"
Statement:
- Effect: "Allow"
Principal:
Service:
- !Sub rds.amazonaws.com
Action: "sts:AssumeRole"
Path: "/"
Policies:
- PolicyName: MySQLRDSExecutionPolicy
PolicyDocument:
Version: "2012-10-17"
Statement:
- Effect: Allow
Action:
- "s3:*"
Resource: "*"
###
RDSCluster:
Properties:
DBClusterParameterGroupName:
Ref: RDSDBClusterParameterGroup
Engine: aurora-mysql
MasterUserPassword:
Ref: DBPassword
MasterUsername:
Ref: DBUser
### Add a role to export to s3
AssociatedRoles:
- RoleArn: !GetAtt [ MySQLRDSExecutionRole, Arn ]
###
Type: "AWS::RDS::DBCluster"
RDSDBClusterParameterGroup:
Properties:
Description: "CloudFormation Sample Aurora Cluster Parameter Group"
Family: aurora-mysql5.7
Parameters:
time_zone: US/Eastern
### Add a role to export to s3
aws_default_s3_role: !GetAtt [ MySQLRDSExecutionRole, Arn ]
###
Type: "AWS::RDS::DBClusterParameterGroup"
RDSDBInstance1:
Type: 'AWS::RDS::DBInstance'
Properties:
DBClusterIdentifier:
Ref: RDSCluster
# AllocatedStorage: '20'
DBInstanceClass: db.t2.small
# Engine: aurora
Engine: aurora-mysql
PubliclyAccessible: "true"
DBInstanceIdentifier: MyMySQLDB
RDSDBParameterGroup:
Type: 'AWS::RDS::DBParameterGroup'
Properties:
Description: CloudFormation Sample Aurora Parameter Group
# Family: aurora5.6
Family: aurora-mysql5.7
Parameters:
sql_mode: IGNORE_SPACE
max_allowed_packet: 1024
innodb_buffer_pool_size: '{DBInstanceClassMemory*3/4}'
# Aurora instances need to be associated with a AWS::RDS::DBCluster via DBClusterIdentifier without the cluster you get these generic errors
如果一切顺利,我们将看到 Amazon 账户中出现一个新的堆栈:

带有 MySQL 实例的 CloudFormation 堆栈。图片由作者提供。
现在我们可以在我们的数据管道中使用这个 MySQL 实例。我们可以在任何 SQL 工具中尝试我们的 SQL 查询,例如 SQL Workbench,以填充表数据。这些表将用于稍后使用 Athena 创建外部表,可以通过 SQL 创建:
CREATE TABLE IF NOT EXISTS
myschema.users AS
SELECT
1 AS id,
CURRENT_DATE() AS registration_date
UNION ALL
SELECT
2 AS id,
DATE_SUB(CURRENT_DATE(), INTERVAL 1 day) AS registration_date;
CREATE TABLE IF NOT EXISTS
myschema.transactions AS
SELECT
1 AS transaction_id,
1 AS user_id,
10.99 AS total_cost,
CURRENT_DATE() AS dt
UNION ALL
SELECT
2 AS transaction_id,
2 AS user_id,
4.99 AS total_cost,
CURRENT_DATE() AS dt
UNION ALL
SELECT
3 AS transaction_id,
2 AS user_id,
4.99 AS total_cost,
DATE_SUB(CURRENT_DATE(), INTERVAL 3 day) AS dt
UNION ALL
SELECT
4 AS transaction_id,
1 AS user_id,
4.99 AS total_cost,
DATE_SUB(CURRENT_DATE(), INTERVAL 3 day) AS dt
UNION ALL
SELECT
5 AS transaction_id,
1 AS user_id,
5.99 AS total_cost,
DATE_SUB(CURRENT_DATE(), INTERVAL 2 day) AS dt
UNION ALL
SELECT
6 AS transaction_id,
1 AS user_id,
15.99 AS total_cost,
DATE_SUB(CURRENT_DATE(), INTERVAL 1 day) AS dt
UNION ALL
SELECT
7 AS transaction_id,
1 AS user_id,
55.99 AS total_cost,
DATE_SUB(CURRENT_DATE(), INTERVAL 4 day) AS dt
;
使用 Athena 处理数据
现在我们希望添加一个数据管道工作流,该工作流触发我们的 Lambda 函数以从 MySQL 提取数据,将其保存到数据湖中,然后在 Athena 中开始数据转换。
我们希望使用 MySQL 中的数据创建两个外部 Athena 表:
-
myschema.users
-
myschema.transactions
然后我们希望创建一个优化的 ICEBERG 表 myschema.user_transactions,将其连接到我们的 BI 解决方案。
我们希望使用 MERGE 语句将新数据插入到该表中。
CREATE EXTERNAL TABLE mydatabase.users (
id bigint
, registration_date string
)
ROW FORMAT DELIMITED
FIELDS TERMINATED BY ','
STORED AS INPUTFORMAT 'org.apache.hadoop.mapred.TextInputFormat'
OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat'
LOCATION 's3://<YOUR_S3_BUCKET>/data/myschema/users/' TBLPROPERTIES ( 'skip.header.line.count'='0')
;
select * from mydatabase.users;
CREATE EXTERNAL TABLE mydatabase.transactions (
transaction_id bigint
, user_id bigint
, total_cost double
, dt string
)
ROW FORMAT DELIMITED
FIELDS TERMINATED BY ','
STORED AS INPUTFORMAT 'org.apache.hadoop.mapred.TextInputFormat'
OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat'
LOCATION 's3://<YOUR_S3_BUCKET>/data/myschema/transactions/' TBLPROPERTIES ( 'skip.header.line.count'='0')
;
select * from mydatabase.transactions;
CREATE TABLE IF NOT EXISTS mydatabase.user_transactions (
dt date,
user_id int,
total_cost_usd float,
registration_date string
)
PARTITIONED BY (dt)
LOCATION 's3://<YOUR_S3_BUCKET>/data/myschema/optimized-data-iceberg-parquet/'
TBLPROPERTIES (
'table_type'='ICEBERG',
'format'='parquet',
'write_target_data_file_size_bytes'='536870912',
'optimize_rewrite_delete_file_threshold'='10'
)
;
MERGE INTO mydatabase.user_transactions as ut
USING (
SELECT
date(dt) dt
, user_id
, sum(total_cost) total_cost_usd
, registration_date
FROM mydatabase.transactions
LEFT JOIN mydatabase.users
ON users.id = transactions.user_id
GROUP BY
dt
, user_id
, registration_date
) as ut2
ON (ut.dt = ut2.dt and ut.user_id = ut2.user_id)
WHEN MATCHED
THEN UPDATE
SET total_cost_usd = ut2.total_cost_usd, registration_date = ut2.registration_date
WHEN NOT MATCHED
THEN INSERT (
dt
,user_id
,total_cost_usd
,registration_date
)
VALUES (
ut2.dt
,ut2.user_id
,ut2.total_cost_usd
,ut2.registration_date
)
;
当新表准备好后,我们可以通过运行 SELECT * 来检查它:

mydatabase.user_transactions。图片由作者提供。
使用 Step Functions(状态机)编排数据管道
在之前的步骤中,我们学习了如何分别部署数据管道的每一步并进行测试。在这一段中,我们将了解如何使用基础设施代码和管道编排工具如 AWS Step Functions(状态机)创建一个完整的数据管道。当我们完成时,管道图将如下所示:

使用 Step Functions 进行数据管道编排。图像由作者提供。
数据管道编排是一种很好的数据工程技术,它为我们的数据管道增加了互动性。这个想法在我之前的一篇故事中已经解释过[5]:
数据管道管理得当可以简化部署并提高数据的可用性和可访问性……
[towardsdatascience.com
要部署完整的编排器解决方案,包括所有必要的资源,我们可以使用 CloudFormation(基础设施即代码)。考虑下面这个可以在/stack文件夹中从命令行运行的脚本。确保<YOUR_S3_BUCKET>存在,并将其替换为您的实际 S3 桶:
#!/usr/bin/env bash
# chmod +x ./deploy-staging.sh
# Run ./deploy-staging.sh
PROFILE=<YOUR_AWS_PROFILE>
STACK_NAME=BatchETLpipeline
LAMBDA_BUCKET=<YOUR_S3_BUCKET> # Replace with unique bucket name in your account
APP_FOLDER=mysql_connector
date
TIME=`date +"%Y%m%d%H%M%S"`
base=${PWD##*/}
zp=$base".zip"
echo $zp
rm -f $zp
pip install --target ./package -r requirements.txt
# boto3 is not required unless we want a specific version for Lambda
# requirements.txt:
# pymysql==1.0.3
# requests==2.28.1
# pytz==2023.3
# pyyaml==6.0
cd package
zip -r ../${base}.zip .
cd $OLDPWD
zip -r $zp "./${APP_FOLDER}" -x __pycache__
# Check if Lambda bucket exists:
LAMBDA_BUCKET_EXISTS=$(aws --profile ${PROFILE} s3 ls ${LAMBDA_BUCKET} --output text)
# If NOT:
if [[ $? -eq 254 ]]; then
# create a bucket to keep Lambdas packaged files:
echo "Creating Lambda code bucket ${LAMBDA_BUCKET} "
CREATE_BUCKET=$(aws --profile ${PROFILE} s3 mb s3://${LAMBDA_BUCKET} --output text)
echo ${CREATE_BUCKET}
fi
# Upload the package to S3:
aws --profile $PROFILE s3 cp ./${base}.zip s3://${LAMBDA_BUCKET}/${APP_FOLDER}/${base}${TIME}.zip
aws --profile $PROFILE \
cloudformation deploy \
--template-file stack.yaml \
--stack-name $STACK_NAME \
--capabilities CAPABILITY_IAM \
--parameter-overrides \
"StackPackageS3Key"="${APP_FOLDER}/${base}${TIME}.zip" \
"AppFolder"=$APP_FOLDER \
"S3LambdaBucket"=$LAMBDA_BUCKET \
"Environment"="staging" \
"Testing"="false"
它将使用 stack.yaml 创建一个名为 BatchETLpipeline 的 CloudFormation 堆栈。它将打包我们的 Lambda 函数,创建一个包并将其上传到 S3 桶中。如果该桶不存在,它将创建它。然后将部署管道。
AWSTemplateFormatVersion: '2010-09-09'
Description: An example template for a Step Functions state machine.
Parameters:
DataLocation:
Description: Data lake bucket with source data files.
Type: String
Default: s3://your.datalake.aws/data/
AthenaResultsLocation:
Description: S3 location for Athena query results.
Type: String
Default: s3://your.datalake.aws/athena/
AthenaDatabaseName:
Description: Athena schema names for ETL pipeline.
Type: String
Default: mydatabase
S3LambdaBucket:
Description: Use this bucket to keep your Lambda package.
Type: String
Default: your.datalake.aws
StackPackageS3Key:
Type: String
Default: mysql_connector/stack.zip
ServiceName:
Type: String
Default: mysql-connector
Testing:
Type: String
Default: 'false'
AllowedValues: ['true','false']
Environment:
Type: String
Default: 'staging'
AllowedValues: ['staging','live','test']
AppFolder:
Description: app.py file location inside the package, i.e. mysql_connector when ./stack/mysql_connector/app.py.
Type: String
Default: mysql_connector
Resources:
LambdaExecutionRole:
Type: "AWS::IAM::Role"
Properties:
AssumeRolePolicyDocument:
Version: "2012-10-17"
Statement:
- Effect: Allow
Principal:
Service: lambda.amazonaws.com
Action: "sts:AssumeRole"
MyLambdaFunction:
Type: "AWS::Lambda::Function"
Properties:
Handler: "index.handler"
Role: !GetAtt [ LambdaExecutionRole, Arn ]
Code:
ZipFile: |
exports.handler = (event, context, callback) => {
callback(null, "Hello World!");
};
Runtime: "nodejs18.x"
Timeout: "25"
### MySQL Connector Lmabda ###
MySqlConnectorLambda:
Type: AWS::Lambda::Function
DeletionPolicy: Delete
DependsOn: LambdaPolicy
Properties:
FunctionName: !Join ['-', [!Ref ServiceName, !Ref Environment] ]
Handler: !Sub '${AppFolder}/app.lambda_handler'
Description: Microservice that extracts data from RDS.
Environment:
Variables:
DEBUG: true
LAMBDA_PATH: !Sub '${AppFolder}/'
TESTING: !Ref Testing
ENV: !Ref Environment
Role: !GetAtt LambdaRole.Arn
Code:
S3Bucket: !Sub '${S3LambdaBucket}'
S3Key:
Ref: StackPackageS3Key
Runtime: python3.8
Timeout: 360
MemorySize: 128
Tags:
-
Key: Service
Value: Datalake
StatesExecutionRole:
Type: "AWS::IAM::Role"
Properties:
AssumeRolePolicyDocument:
Version: "2012-10-17"
Statement:
- Effect: "Allow"
Principal:
Service:
- !Sub states.${AWS::Region}.amazonaws.com
Action: "sts:AssumeRole"
Path: "/"
Policies:
- PolicyName: StatesExecutionPolicy
PolicyDocument:
Version: "2012-10-17"
Statement:
- Effect: Allow
Action:
- "lambda:InvokeFunction"
Resource: "*"
- Effect: Allow
Action:
- "athena:*"
Resource: "*"
- Effect: Allow
Action:
- "s3:*"
Resource: "*"
- Effect: Allow
Action:
- "glue:*"
Resource: "*"
MyStateMachine:
Type: AWS::StepFunctions::StateMachine
Properties:
# StateMachineName: ETL-StateMachine
StateMachineName: !Join ['-', ['ETL-StateMachine', !Ref ServiceName, !Ref Environment] ]
DefinitionString:
!Sub
- |-
{
"Comment": "A Hello World example using an AWS Lambda function",
"StartAt": "HelloWorld",
"States": {
"HelloWorld": {
"Type": "Task",
"Resource": "${lambdaArn}",
"Next": "Extract from MySQL"
},
"Extract from MySQL": {
"Resource": "${MySQLLambdaArn}",
"Type": "Task",
"Next": "Create Athena DB"
},
"Create Athena DB": {
"Resource": "arn:aws:states:::athena:startQueryExecution.sync",
"Parameters": {
"QueryString": "CREATE DATABASE if not exists ${AthenaDatabaseName}",
"WorkGroup": "primary",
"ResultConfiguration": {
"OutputLocation": "${AthenaResultsLocation}"
}
},
"Type": "Task",
"Next": "Show tables"
},
"Show tables": {
"Resource": "arn:aws:states:::athena:startQueryExecution.sync",
"Parameters": {
"QueryString": "show tables in ${AthenaDatabaseName}",
"WorkGroup": "primary",
"ResultConfiguration": {
"OutputLocation": "${AthenaResultsLocation}"
}
},
"Type": "Task",
"Next": "Get Show tables query results"
},
"Get Show tables query results": {
"Resource": "arn:aws:states:::athena:getQueryResults",
"Parameters": {
"QueryExecutionId.$": "$.QueryExecution.QueryExecutionId"
},
"Type": "Task",
"Next": "Decide what next"
},
"Decide what next": {
"Comment": "Based on the input table name, a choice is made for moving to the next step.",
"Type": "Choice",
"Choices": [
{
"Not": {
"Variable": "$.ResultSet.Rows[0].Data[0].VarCharValue",
"IsPresent": true
},
"Next": "Create users table (external)"
},
{
"Variable": "$.ResultSet.Rows[0].Data[0].VarCharValue",
"IsPresent": true,
"Next": "Check All Tables"
}
],
"Default": "Check All Tables"
},
"Create users table (external)": {
"Resource": "arn:aws:states:::athena:startQueryExecution.sync",
"Parameters": {
"QueryString": "CREATE EXTERNAL TABLE ${AthenaDatabaseName}.users ( id bigint , registration_date string ) ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS INPUTFORMAT 'org.apache.hadoop.mapred.TextInputFormat' OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' LOCATION 's3://datalake.staging.liveproject/data/myschema/users/' TBLPROPERTIES ( 'skip.header.line.count'='0') ;",
"WorkGroup": "primary",
"ResultConfiguration": {
"OutputLocation": "${AthenaResultsLocation}"
}
},
"Type": "Task",
"Next": "Create transactions table (external)"
},
"Create transactions table (external)": {
"Resource": "arn:aws:states:::athena:startQueryExecution.sync",
"Parameters": {
"QueryString": "CREATE EXTERNAL TABLE ${AthenaDatabaseName}.transactions ( transaction_id bigint , user_id bigint , total_cost double , dt string ) ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS INPUTFORMAT 'org.apache.hadoop.mapred.TextInputFormat' OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' LOCATION 's3://datalake.staging.liveproject/data/myschema/transactions/' TBLPROPERTIES ( 'skip.header.line.count'='0') ;",
"WorkGroup": "primary",
"ResultConfiguration": {
"OutputLocation": "${AthenaResultsLocation}"
}
},
"Type": "Task",
"Next": "Create report table (parquet)"
},
"Create report table (parquet)": {
"Resource": "arn:aws:states:::athena:startQueryExecution.sync",
"Parameters": {
"QueryString": "CREATE TABLE IF NOT EXISTS ${AthenaDatabaseName}.user_transactions ( dt date, user_id int, total_cost_usd float, registration_date string ) PARTITIONED BY (dt) LOCATION 's3://datalake.staging.liveproject/data/myschema/optimized-data-iceberg-parquet/' TBLPROPERTIES ( 'table_type'='ICEBERG', 'format'='parquet', 'write_target_data_file_size_bytes'='536870912', 'optimize_rewrite_delete_file_threshold'='10' ) ;",
"WorkGroup": "primary",
"ResultConfiguration": {
"OutputLocation": "${AthenaResultsLocation}"
}
},
"Type": "Task",
"End": true
},
"Check All Tables": {
"Type": "Map",
"InputPath": "$.ResultSet",
"ItemsPath": "$.Rows",
"MaxConcurrency": 0,
"Iterator": {
"StartAt": "CheckTable",
"States": {
"CheckTable": {
"Type": "Choice",
"Choices": [
{
"Variable": "$.Data[0].VarCharValue",
"StringMatches": "*users",
"Next": "passstep"
},
{
"Variable": "$.Data[0].VarCharValue",
"StringMatches": "*user_transactions",
"Next": "Insert New parquet Data"
}
],
"Default": "passstep"
},
"Insert New parquet Data": {
"Resource": "arn:aws:states:::athena:startQueryExecution.sync",
"Parameters": {
"QueryString": "MERGE INTO ${AthenaDatabaseName}.user_transactions as ut USING ( SELECT date(dt) dt , user_id , sum(total_cost) total_cost_usd , registration_date FROM ${AthenaDatabaseName}.transactions LEFT JOIN ${AthenaDatabaseName}.users ON users.id = transactions.user_id GROUP BY dt , user_id , registration_date ) as ut2 ON (ut.dt = ut2.dt and ut.user_id = ut2.user_id) WHEN MATCHED THEN UPDATE SET total_cost_usd = ut2.total_cost_usd, registration_date = ut2.registration_date WHEN NOT MATCHED THEN INSERT ( dt ,user_id ,total_cost_usd ,registration_date ) VALUES ( ut2.dt ,ut2.user_id ,ut2.total_cost_usd ,ut2.registration_date ) ;",
"WorkGroup": "primary",
"ResultConfiguration": {
"OutputLocation": "${AthenaResultsLocation}"
}
},
"Type": "Task",
"End": true
},
"passstep": {
"Type": "Pass",
"Result": "NA",
"End": true
}
}
},
"End": true
}
}
}
- {
lambdaArn: !GetAtt [ MyLambdaFunction, Arn ],
MySQLLambdaArn: !GetAtt [ MySqlConnectorLambda, Arn ],
AthenaResultsLocation: !Ref AthenaResultsLocation,
AthenaDatabaseName: !Ref AthenaDatabaseName
}
RoleArn: !GetAtt [ StatesExecutionRole, Arn ]
Tags:
-
Key: "keyname1"
Value: "value1"
-
Key: "keyname2"
Value: "value2"
# IAM role for mysql-data-connector Lambda:
LambdaRole:
Type: AWS::IAM::Role
Properties:
AssumeRolePolicyDocument:
Version: "2012-10-17"
Statement:
-
Effect: Allow
Principal:
Service:
- "lambda.amazonaws.com"
Action:
- "sts:AssumeRole"
LambdaPolicy:
Type: AWS::IAM::Policy
DependsOn: LambdaRole
Properties:
Roles:
- !Ref LambdaRole
PolicyName: !Join ['-', [!Ref ServiceName, !Ref Environment, 'lambda-policy']]
PolicyDocument:
{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": [
"logs:CreateLogGroup",
"logs:CreateLogStream",
"logs:PutLogEvents"
],
"Resource": "*"
}
]
}
如果一切顺利,我们的新数据管道的堆栈将被部署:

BatchETLpipeline 堆栈和资源。图像由作者提供。
如果我们点击状态机资源,然后点击‘编辑’,我们将看到我们的 ETL 管道作为图形展示:

批量数据管道的工作流工作室。图像由作者提供。
现在我们可以执行管道以运行所有必要的数据转换步骤。点击‘开始执行’。

成功执行。图像由作者提供。
现在我们可以将我们的 Athena 表连接到我们的BI 解决方案。连接我们最终的 Athena 数据集mydataset.user_transactions以创建仪表盘。

连接 Quicksight 中的数据集。图像由作者提供。
我们只需调整几个设置,使我们的仪表盘看起来像这样:

Quicksight 仪表盘。图像由作者提供。
我们希望使用dt作为维度,total_cost_usd作为指标。我们还可以为每个user_id设置一个拆分维度。
结论
批处理数据管道很受欢迎,因为历史上工作负载主要是批处理型的数据环境。我们刚刚建立了一个 ETL 数据管道,从 MySQL 中提取数据并在数据湖中转换。该模式最适用于数据集不大且需要持续处理的情况,因为 Athena 根据扫描的数据量收费。这种方法在将数据转换为列式格式如 Parquet 或 ORC 时表现良好,结合几个小文件成较大的文件,或进行分桶和添加分区。我以前在我的一个故事中写过这些大数据文件格式[6]。
Parquet 与 ORC 与 AVRO 与 JSON。该选择哪一个,如何使用它们?
towardsdatascience.com
我们学习了如何使用 Step Functions 来编排数据管道,视觉化数据流从源头到最终用户,并使用基础设施即代码进行部署。这个设置使得我们可以对数据管道应用 CI/CD 技术[7]。
希望这个教程对你有帮助。如果你有任何问题,请告诉我。
推荐阅读
[1] towardsdatascience.com/data-pipeline-design-patterns-100afa4b93e3
[2] medium.com/towards-data-science/introduction-to-apache-iceberg-tables-a791f1758009
[3] medium.com/towards-data-science/advanced-sql-techniques-for-beginners-211851a28488
[5] medium.com/towards-data-science/data-pipeline-orchestration-9887e1b5eb7a
[6] medium.com/towards-data-science/big-data-file-formats-explained-275876dc1fc9
使用 Hugging Face 的 Transformer 模型构建评论毒性排序器
赶上 NLP 和 LLM(第一部分)
·发表于Towards Data Science ·18 分钟阅读·2023 年 8 月 6 日
--

图片来源:Brett Jordan于Unsplash
介绍
作为一名数据科学家,我从未有机会深入探索自然语言处理的最新进展。随着夏季和今年初大语言模型的新热潮,我决定是时候深入这个领域并开始一些小项目。毕竟,没有比实践更好的学习方法了。
在我的旅程开始时,我意识到很难找到能够手把手引导读者、一步一步深入理解新 NLP 模型并通过具体项目进行的内容。因此,我决定开始这一新系列的文章。
使用 HuggingFace 的 Transformer 模型构建评论毒性排序器
在这篇文章中,我们将深入探讨构建评论毒性排序器。这个项目灵感来源于去年在 Kaggle 上举办的“Jigsaw 毒性评论严重性评估”竞赛。
竞赛的目标是构建一个能够判断哪个评论(在给定的两个评论中)最具毒性的模型。
为此,模型会为每个输入的评论分配一个分数,以确定其相对毒性。
本文涵盖内容
在这篇文章中,我们将使用 Pytorch 和 Hugging Face transformers 训练我们的第一个 NLP 分类器。我不会深入讲解 transformers 的工作原理,而是更多地关注实际细节和实现,并引入一些对系列后续文章有用的概念。
具体来说,我们将看到:
-
如何从 Hugging Face Hub 下载模型
-
如何自定义和使用编码器
-
从 Hugging Face 模型中构建并训练一个 Pytorch 排名器
本文直接面向希望从实际角度提升其自然语言处理技能的数据科学家。我不会详细讲解变换器的理论,即使我会详细编写代码,也希望你已经对 PyTorch 有一些了解。
探索与架构
训练数据集
我们将处理一个将评论配对并将其分类为“较少毒性”和“更多毒性”的数据集。
相对毒性的选择是由一组标注者做出的。
下图显示了来自训练集的数据样本。工人字段表示进行分类的标注者的 id。

训练集样本,作者插图
注意:数据集在开源许可证下提供,遵循Kaggle 竞赛规则。
排名系统
在任何机器学习项目中,理解任务具有至关重要的意义,因为它显著影响了合适模型和策略的选择。这种理解应从项目启动时就建立起来。
在这个具体的项目中,我们的目标是构建一个排名系统。与其预测一个具体的目标,我们的重点是确定一个任意值,以便在样本对之间进行有效比较。
让我们首先绘制一个基本的图示来表示这个概念,知道我们稍后会更深入地探讨“模型”的工作原理。

我们想要实现的一个非常基本的视图
以这种方式可视化任务至关重要,因为它表明项目的目标不仅仅是基于训练数据训练一个简单的二分类器。与其仅仅预测 0 或 1 来识别最有毒的评论,排名系统旨在分配任意值,从而有效地比较评论。
模型训练与边际排名损失
考虑到“模型”仍然是一个黑箱神经网络,我们需要建立一种利用这个系统并利用我们的训练数据来更新模型权重的方法。为此,我们需要一个合适的损失函数。
鉴于我们的目标是构建一个排名系统,边际排名损失是一个相关的选择。这个损失函数受到铰链损失的启发,后者通常用于优化样本之间的最大边际。
边际排名损失对样本对进行操作。对于每一对样本,它比较“模型”对两个样本产生的分数,并强制它们之间有一个边际。这个边际表示正确排序的样本之间期望的分数差异。

Margin Ranking Loss 函数公式,作者插图
在上述公式中,x1 和 x2 是两个样本的排名得分,y 是一个系数,如果 x1 应该排名高于 x2,则 y 等于 1,否则为 -1。 “margin” 是公式的超参数,设置了需要达到的最小间隔。
让我们看看这个损失函数是如何工作的:
假设 y=1,这意味着与 x1 相关的样本应该比与 x2 相关的样本排名更高:
-
如果 (x1 — x2) > margin,样本 1 的得分比样本 2 的得分高出足够的间隔,则 max() 的右侧项为负数。返回的损失将等于 0,并且这两个排名之间没有惩罚。
-
如果 (x1 — x2) < margin,这意味着 x1 和 x2 之间的间隔不足,或者更糟的是,x2 的得分高于 x1 的得分。在这种情况下,损失会更高,因为样本 2 的得分高于样本 1 的得分,这会惩罚模型。
鉴于此,我们现在可以按照如下修订我们的训练方法:
对于训练集中的一个样本(或一个批次):
-
将 more_toxic 消息传递给模型,得到 Rank_Score1 (x1)
-
将 less_toxic 消息传递给模型,得到 Rank_Score2 (x2)
-
计算 y = 1 时的 MarginRankingLoss
-
根据计算出的损失更新模型的权重(反向传播步骤)

使用 Margin Ranking Loss 的模型训练步骤,作者插图
从文本到特征表示:编码器块
我们的训练程序现在已设置完成。是时候深入了解‘模型’组件本身了。在 NLP 的世界里,你会经常遇到三种主要类型的模型:编码器、解码器和编码器-解码器组合。在这一系列文章中,我们将更详细地研究这些类型的模型。
对于本特定文章的目的,我们需要一个可以将消息转换为特征向量的模型。这个向量作为输入生成最终的排名得分。这个特征向量将直接从变换器架构的编码器中派生。
我不会在这里深入理论,因为其他人已经解释得很好(我推荐 Hugging Face 的入门课程,写得非常好)。只需记住这个过程的关键部分叫做注意力机制。它通过查看其他相关词,即使它们相隔很远,也帮助变换器理解文本。
有了这种架构,我们将能够调整权重,以生成我们文本的最佳向量表示,从而识别出对任务最重要的特征,并将最终层从变换器连接到一个最终节点(称为“头”),该节点将生成最终的排名得分。
让我们相应地更新我们的图示:

我们训练流水线的更新视图,作者插图
分词器
正如你从上述图表中看到的,模型内部出现了一个我们尚未提及的组件:预处理步骤。
这个预处理步骤旨在将原始文本转换为可以通过神经网络处理的内容(数字),而这就是分词器的作用。
分词器的主要功能有两个:分割(即将文本切割成片段,这些片段可以是单词、单词的一部分或字母)和索引(即将每个文本片段映射到一个唯一的值,该值在字典中引用,以便可以反向操作)。
需要记住的一件非常重要的事情是,文本的分词方式有多种,但如果你使用预训练模型,你需要使用相同的分词器,否则预训练权重将毫无意义(由于不同的分割和索引)。
另一个重要的事情是要记住,编码器只是一个神经网络。因此,它的输入需要是固定大小的,但你的输入文本不一定符合这一点。分词器允许你通过两个操作来控制你的词向量的大小:填充和截断。这也是一个重要的参数,因为一些预训练模型会使用更小或更大的输入空间。
在下面的图中,我们添加了分词器,并展示了消息如何从模块到模块进行转换。

最终训练示意图,作者插图
就这样,我们已经揭示了所有需要了解的组件,以便有效地处理我们的“评论毒性排名”任务。总结上述图表:每对消息(较少毒性和较多毒性)将分别传递给模型流水线。它们将首先经过分词器、编码器和排名层,以产生一对分数。这对分数将用于计算边际排名损失,这将用于反向传播步骤中,更新编码器和最终排名层的权重,并优化它们以完成任务。
在下一部分,我们将亲自动手编写代码,使用 Hugging Face transformers 模块和 Pytorch 构建上述流水线。
构建、训练和评估模型
我们在前面的部分中涵盖了理论,现在是时候亲自动手,开始处理我们的模型了。
虽然过去构建和训练复杂的深度学习模型可能很复杂,但新的现代框架使其变得更简单。
Hugging Face 是你所需的一切
Hugging Face 是一家了不起的公司,致力于使复杂的深度学习模型民主化。
它们构建了帮助你构建、加载、微调和共享复杂变换器模型的抽象。
在接下来的部分中,我们将使用他们的transformers包,该包提供了构建预训练 NLP 模型并用于自己任务所需的所有工具。在接下来的几周内,我们将更详细地探索该包提供的不同可能性
该包与 TensorFlow 和 PyTorch 库兼容。
首先,让我们安装 transformers 包
pip install transformers
从 Hugging Face 获取的模型可以在他们的Model Hub网站上找到。你可以找到各种类型的模型以及描述,以了解模型的功能、参数数量、训练数据集等。
在本文中,我们将使用架构roberta-base,这是一个相对轻量的编码器,经过多个英文语料库的训练。
模型描述提供了大量与我们的任务相关的非常重要的信息:
-
该模型具有 125M 个参数
-
该模型已在多个英文语料库上进行过训练,这一点很重要,因为我们的评论数据集是英文的
-
该模型已经在掩蔽语言模型的目标上进行过训练,这一目标是尝试预测文本中被掩蔽的单词,并使用前后的文本进行预测,这并非总是如此(例如,GPT 等模型只使用单词前的上下文来进行预测,因为它们在推断新文本时无法看到句子的未来)。
-
该模型对大小写敏感,这意味着它会区分“WORD”和“word”。这在毒性检测器中尤为重要,因为字母的大小写是判断毒性的一个重要线索。
Hugging Face 可以为每个模型提供使用的分词器以及不同配置的基本神经网络(你可能不希望所有的权重:有时你只想限制在编码器部分,解码器部分,停留在隐藏层等)。
从 Hugging Face hub 获取的模型可以在本地克隆(这样运行会更快)或直接在代码中加载,通过使用其 repo id(例如我们案例中的 roberta-base)
加载和测试分词器
要加载分词器,我们可以简单地使用 transformers 包中的 AutoTokenizer 类,并指定我们想要使用的分词器
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('roberta-base')
为了对文本进行分词,我们可以简单地调用“encode”或“encode_plus”方法。“encode_plus”不仅会提供你文本的分词版本,还会提供一个注意力掩码,用于忽略纯填充部分的编码。
text = "hello world"
tokenizer.encode_plus(
text,
truncation=True,
add_special_tokens=True,
max_length=10,
padding='max_length'
)
将返回一个字典,其中“input_ids”是编码序列,“attention_mask”用于允许变换器忽略填充的标记:
{
'input_ids': [0, 42891, 232, 2, 1, 1, 1, 1, 1, 1],
'attention_mask': [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
}
在我们使用的参数中,有:
-
max_length: 指定编码序列的最大长度
-
add_special_tokens: 向文本中添加
和 标记 -
truncation: 如果文本不适合 max_length,则会截断文本
-
padding: 添加填充标记直到 max_length
加载预训练模型
要加载预训练模型,Hugging Face 提供了多个类,具体取决于你的需求(你是在使用 TensorFlow 还是 Pytorch?你尝试实现什么类型的任务)。
在我们的案例中,我们将使用 AutoModel,它允许你直接加载模型架构及预训练权重。请注意,如果你使用 TensorFlow,你可以通过使用 TFAutoModel 类而不是 AutoModel 类来实现相同的功能。
AutoModel 类将直接从 RobertaModel 加载模型架构,并加载与 Hugging Face 中的 “roberta-base” 仓库相关联的预训练权重。
至于 Tokenizer,我们可以直接从 repo-id 或本地仓库路径加载模型,通过使用 AutoModel 的 from_pretrained 方法:
from transformers import AutoModel
robertaBase = AutoModel.from_pretrained("roberta-base")
请注意,编码器没有在特定任务上进行训练,我们不能简单地使用模型。相反,我们需要用我们的数据集进行微调。
我们可以再三检查 robertaBase 是否是 pytorch.nn.Module 的实例,并且可以集成到更复杂的 PyTorch 架构中:
import pytorch.nn as torch
isinstance(robertaBase, nn.Module)
>> True
你也可以通过简单地打印它来检查其架构,就像你对待标准 PyTorch 模块一样:
print(robertaBase)
>> RobertaModel(
(embeddings): RobertaEmbeddings(
(word_embeddings): Embedding(50265, 768, padding_idx=1)
(position_embeddings): Embedding(514, 768, padding_idx=1)
(token_type_embeddings): Embedding(1, 768)
(LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(encoder): RobertaEncoder(
(layer): ModuleList(
(0-11): 12 x RobertaLayer(
(attention): RobertaAttention(
(self): RobertaSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): RobertaSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): RobertaIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
(intermediate_act_fn): GELUActivation()
)
(output): RobertaOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(pooler): RobertaPooler(
(dense): Linear(in_features=768, out_features=768, bias=True)
(activation): Tanh()
)
)
构建自定义神经网络
这个最后的层实际上是我们在本文第一部分讨论的整个文本的向量表示,我们只需将其连接到用于排序的最终节点,以完成我们的神经网络架构。
为此,我们将通过封装 nn.Module 来简单地构建自己的自定义模块,就像我们用 PyTorch 构建经典神经网络一样。
model_name = "roberta-base"
last_hidden_layer_size = 768
final_node_size = 1
class ToxicRankModel(nn.Module):
def __init__(self, model_name, last_hidden_layer_size):
super(ToxicRankModel, self).__init__()
self.robertaBase = AutoModel.from_pretrained(model_name)
self.dropout = nn.Dropout(p=0.1)
self.rank_head = nn.Linear(last_hidden_layer_size, 1)
def forward(self, ids, mask):
output = self.robertaBase(input_ids=ids,attention_mask=mask,
output_hidden_states=False)
output = self.dropout(output[1])
score= self.fc(output)
return score
#This line check if the GPU is available, else it goes with the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#After initiation, we send the model to the device
toxicRankModel = ToxicRankModel(model_name, last_hidden_layer_size)
toxicRankModel = toxicRankModel.to(device)
在 forward() 方法中需要注意几点:
-
我们将两个主要输入传递给 robertBase 模型,input_ids 和 attention_mask。它们都是由 Tokenizer 生成的。
-
AutoModel 具有参数(如 output_hidden_states)。根据你选择的参数,你可以让模型作为编码器或解码器运行,并将模型定制用于不同的 NLP 任务。
-
你是否注意到我们在 dropout 中传递了 output[1]?这是因为基本模型提供了两个输入:
-
首先,最后的隐藏状态,它包含每个标记的上下文表示(或上下文嵌入),可以用于实体识别等任务。
-
其次,来自 Pooler 的输出,它包含整个文本的向量表示,就是我们在这里寻找的。
构建自定义数据集
使用 Pytorch,我们还需要创建自己的 Dataset 类,用于存储原始数据,以及 DataLoader,用于在训练过程中按批次馈送神经网络。
在使用 Pytorch 构建自定义数据集时,你必须实现两个强制性方法:
-
len,它给出训练数据的大小(对数据加载器来说是重要信息)
-
getitem,它接受原始输入(来自第“i”行)并进行预处理,以便神经网络(作为张量)可以处理
如果你记得之前部分的图示,我们实际上是在计算损失之前并行传递两个输入到模型中:less_toxic 和 more_toxic。
getitem 方法将处理消息的分词,并为转换器准备输入,将分词后的输入转换为张量。
class CustomDataset(Dataset):
def __init__(self, train_df, tokenizer, max_length):
#token list standard size
self.length = max_length
#Here the tokenizer will be an instance of the tokenizer
#shown previously
self.tokenizer = tokenizer
#df is the training df shown in the beginning of the article
self.more_toxic = train_df['more_toxic'].values
self.less_toxic = train_df['less_toxic'].values
def __len__(self):
return len(self.more_toxic)
def __getitem__(self, i):
# get both messages at index i
message_more_toxic = self.more_toxic[i]
message_less_toxic = self.less_toxic[i]
#tokenize the messages
dic_more_toxic = self.tokenizer.encode_plus(
message_more_toxic,
truncation=True,
add_special_tokens=True,
max_length=self.length,
padding='max_length'
)
dic_less_toxic = self.tokenizer.encode_plus(
message_less_toxic,
truncation=True,
add_special_tokens=True,
max_length=self.length,
padding='max_length'
)
#extract tokens and masks
tokens_more_toxic = dic_more_toxic['input_ids']
mask_more_toxic = dic_more_toxic['attention_mask']
tokens_less_toxic = dic_less_toxic['input_ids']
mask_less_toxic = dic_less_toxic['attention_mask']
#return a dictionnary of tensors
return {
'tokens_more_toxic': torch.tensor(tokens_more_toxic, dtype=torch.long),
'mask_more_toxic': torch.tensor(mask_more_toxic, dtype=torch.long),
'tokens_less_toxic': torch.tensor(tokens_less_toxic, dtype=torch.long),
'mask_less_toxic': torch.tensor(mask_less_toxic, dtype=torch.long),
}
我们现在可以生成 DataLoader,用于模型的批量训练。
def get_loader(df, tokenizer, max_length, batch_size):
dataset = CustomDataset(
df,
tokenizer=tokenizer,
max_length=max_length
)
return DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
drop_last=True)
max_length = 128
batch_size = 32
train_loader = get_loader(train_df, tokenizer, max_length, batch_size=batch_size)
-
batch_size 指定了用于前向传递/反向传播的样本数量
-
shuffle = True 意味着数据集在两个 epoch 之间会被打乱
-
drop_last 意味着如果最后一个 batch 没有正确数量的样本,它将被丢弃。这一点很重要,因为 batch normalization 对于不完整的 batch 处理效果不好。
训练模型
我们快完成了,现在是时候为一个 epoch 准备训练流程了。
自定义损失
首先,让我们定义一个自定义损失函数。 Pytorch 已经提供了 MarginRankingLoss,我们只是将其封装为 y = 1(因为我们将始终将 more_toxic 作为 x1,less_toxic 作为 x2)。
from torch.nn import MarginRankingLoss
#Custom implementation of the MarginRankingLoss with y = 1
class CustomMarginRankingLoss(nn.Module):
def __init__(self, margin=0):
super(CustomMarginRankingLoss, self).__init__()
self.margin = margin
def forward(self, x1, x2):
#with y=1 this is how looks the loss
loss = torch.relu(x2 - x1 + self.margin)
return loss.mean()
def criterion(x1, x2):
return CustomMarginRankingLoss()(x1, x2)
优化器
对于这个实验,我们将使用经典的 AdamW,它目前是最先进的,并解决了原始 Adam 实现的一些问题。
optimizer_lr = 1e-4
optimizer_weight_decay = 1e-6
optimizer = AdamW(toxicRankModel.parameters(),
lr=optimizer_lr,
weight_decay=optimizer_weight_decay)
调度器
调度器有助于调整学习率。在开始时,我们希望较高的学习率以更快地收敛到最佳解,而在训练结束时,我们希望较小的学习率以真正微调权重。
scheduler_T_max = 500
scheduler_eta_min = 1e-6
scheduler = lr_scheduler.CosineAnnealingLR(optimizer,T_max=scheduler_T_max, eta_min=scheduler_eta_min)
训练例程
我们现在准备好训练我们的 NLP 模型以进行毒性评论排序。
使用 Pytorch 训练一个 epoch 非常简单:
-
我们迭代通过我们的数据加载器,它会从数据集中打乱并选择预处理的数据
-
我们从数据加载器中提取 tokens 和 masks
-
我们通过对模型进行前向传递来计算每条消息的排名
-
当两个排名都计算完毕后,我们可以计算 MarginRankingLoss(用于反向传播),以及一个准确率分数,表示正确分类的对数百分比(仅供参考)
-
我们更新我们的系统(反向传播、优化器和调度器)
-
我们迭代直到数据加载器中的所有数据都被使用完。
def train_one_epoch(model, optimizer, scheduler, dataloader, device):
#Setup train mode, this is important as some layers behave differently
# during train and inference (like batch norm)
model.train()
#Initialisation of some loss
dataset_size = 0
running_loss = 0.0
running_accuracy = 0.0
progress_bar = tqdm(enumerate(dataloader), total=len(dataloader), desc="Training")
for i, data in progress_bar:
more_toxic_ids = data['tokens_more_toxic'].to(device, dtype = torch.long)
more_toxic_mask = data['mask_more_toxic'].to(device, dtype = torch.long)
less_toxic_ids = data['tokens_less_toxic'].to(device, dtype = torch.long)
less_toxic_mask = data['mask_less_toxic'].to(device, dtype = torch.long)
batch_size = more_toxic_ids.size(0)
#Forward pass both inputs in the model
x1 = model(more_toxic_ids, more_toxic_mask)
x2 = model(less_toxic_ids, less_toxic_mask)
#Compute margin ranking loss
loss = criterion(x1, x2)
accuracy_measure = (x1 > x2).float().mean().item()
#apply backpropagation, increment optimizer
loss.backward()
optimizer.step()
scheduler.step()
optimizer.zero_grad()
#Update cumulative loss for monitoring
running_loss += (loss.item() * batch_size)
dataset_size += batch_size
epoch_loss = running_loss / dataset_size
running_accuracy += (accuracy_measure * batch_size)
epoch_accuracy = running_accuracy / dataset_size
progress_bar.set_postfix({'loss': epoch_loss, 'accuracy': epoch_accuracy}, refresh=True)
#Garbage collector
gc.collect()
return epoch_loss
我在 Kaggle 的 GPU T4 上训练了模型,使我获得了 70% 的评论正确分类的可观成绩。我可能通过调整不同的参数和使用更多的 epochs 提高准确性,但这对于本文的目的来说已经足够了。
关于推断的最后一点
我们建立的框架在从预格式化的评论集合中训练时效果很好。
但在“生产”场景下,这种方法就不起作用了,因为你会接收到一堆需要评估毒性评分的消息。

这是一个生产模式下的数据集示例,在这种模式下,我们只接收单条消息,而不是消息对。
对于推断,你将设计另一个 Dataset 类和另一个 DataLoader,这些将与我们之前做的有所不同:
class CustomInferenceDataset(Dataset):
def __init__(self, messages, tokenizer, max_length):
#token list standard size
self.length = max_length
#Here the tokenizer will be an instance of the tokenizer
#shown previously
self.tokenizer = tokenizer
#df is the training df shown in the beginning of the article
self.messages = messages
def __len__(self):
return len(self.messages)
def __getitem__(self, i):
# get a message at index i
message = self.messages[i]
#tokenize the message
dic_messages = self.tokenizer.encode_plus(
message,
truncation=True,
add_special_tokens=True,
max_length=self.length,
padding='max_length'
)
#extract tokens and masks
tokens_message = dic_messages['input_ids']
mask_message = dic_messages['attention_mask']
#return a dictionnary of tensors
return {
'tokens_message': torch.tensor(tokens_message, dtype=torch.long),
'mask_message': torch.tensor(mask_message, dtype=torch.long),
}
def get_loader_inference(messages, tokenizer, max_length, batch_size):
dataset = CustomInferenceDataset(
messages,
tokenizer=tokenizer,
max_length=max_length
)
return DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
drop_last=False)
变化了什么:
-
我们不再加载消息对,而是单条消息。
-
Loader 没有对数据进行打乱(如果你不想要与原始向量关联的随机分数带来不好的惊喜,这一点非常重要)。
-
由于没有批量归一化计算,并且我们希望对所有数据进行推断,我们将 drop_last 设置为 False,以获取所有批次,即使是未完成的批次。
最后,为了生成排序分数:
@torch.no_grad()
def get_scores(model, test_loader, device):
model.eval() # Set the model to evaluation mode
ranks = [] # List to store the rank scores
progress_bar = tqdm(enumerate(test_loader), total=len(test_loader), desc="Scoring")
for i, data in progress_bar:
tokens_message = data['tokens_message'].to(device, dtype=torch.long)
mask_message = data['mask_message'].to(device, dtype=torch.long)
# Forward pass to get the rank scores
rank = model(tokens_message, mask_message)
# Convert tensor to NumPy and add to the list
ranks+=list(rank.cpu().numpy().flatten())
return ranks
这是推断后的前 5 条分类消息。为了保持政治正确,我在这里进行了些许审查…

已识别的最具毒性的消息
不太具有建设性… 😃
结论
在这篇文章中,我们利用了 Hugging Face 预训练模型和 Pytorch 生产了一个能够对消息的毒性等级进行排序的模型。
为此,我们采用了一个“小型”的“Roberta”变换器,并使用 PyTorch 在其编码器末尾连接了一个最终简单的节点。其余部分则更为经典,可能与你之前用 PyTorch 做的其他项目类似。
这个项目是对 NLP 提供的可能性的初步探索,我想简单地介绍一些基础概念,以便进一步研究更具挑战性的任务或更大的模型。
希望你喜欢阅读,如果你想玩玩这个模型,你可以从 我的 GitHub 下载一个 Notebook。
在 Julia 中构建一个符合预测的聊天机器人
原文:
towardsdatascience.com/building-a-conformal-chatbot-in-julia-1ed23363a280
符合预测、LLMs 和 HuggingFace — 第一部分
·发表于Towards Data Science ·阅读时间 7 分钟·2023 年 7 月 5 日
--
大型语言模型(LLM)目前非常受关注。它们被用于各种任务,包括文本分类、问答和文本生成。在本教程中,我们将展示如何使用[ConformalPrediction.jl](https://juliatrustworthyai.github.io/ConformalPrediction.jl/dev/)将变换器语言模型符合化,以进行文本分类。
👀 一览
我们特别关注意图分类任务,如下图所示。首先,我们将客户查询输入到 LLM 中以生成嵌入。接着,我们训练一个分类器,将这些嵌入与可能的意图匹配。当然,对于这个监督学习问题,我们需要由输入——查询——和输出——指示真实意图的标签——组成的训练数据。最后,我们应用符合预测来量化分类器的预测不确定性。
符合预测(CP)是一种快速发展的预测不确定性量化方法。如果你不熟悉 CP,建议你首先查看我关于这一主题的三部分介绍系列,从这篇文章开始。

符合化意图分类器的高级概述。图片由作者提供。
🤗 HuggingFace
我们将使用Banking77数据集(Casanueva 等,2020),该数据集包含 77 个与银行相关的意图中的 13,083 个查询。在模型方面,我们将使用DistilRoBERTa模型,它是RoBERTa(Liu 等,2019)的蒸馏版,并在 Banking77 数据集上进行了微调。
可以使用 [Transformers.jl](https://github.com/chengchingwen/Transformers.jl/tree/master) 包将模型从 HF 直接加载到我们正在运行的 Julia 会话中。
这个包使得在 Julia 中使用 HF 模型变得非常简单。向开发者们致敬!🙏
下面我们加载分词器tkr和模型mod。分词器用于将文本转换为整数序列,然后将其输入模型。模型输出一个隐藏状态,然后将其输入分类器,以获得每个类别的 logits。最后,这些 logits 通过 softmax 函数以获得相应的预测概率。下面我们运行几个查询来查看模型的表现。
# Load model from HF 🤗:
tkr = hgf"mrm8488/distilroberta-finetuned-banking77:tokenizer"
mod = hgf"mrm8488/distilroberta-finetuned-banking77:ForSequenceClassification"
# Test model:
query = [
"What is the base of the exchange rates?",
"Why is my card not working?",
"My Apple Pay is not working, what should I do?",
]
a = encode(tkr, query)
b = mod.model(a)
c = mod.cls(b.hidden_state)
d = softmax(c.logit)
[labels[i] for i in Flux.onecold(d)]
3-element Vector{String}:
"exchange_rate"
"card_not_working"
"apple_pay_or_google_pay"
🔁 MLJ接口
由于我们的包与 [MLJ.jl](https://alan-turing-institute.github.io/MLJ.jl/dev/) 接口对接,我们需要定义一个符合MLJ接口的包装模型。为了将模型添加到通用使用中,我们可能会通过 [MLJFlux.jl](https://github.com/FluxML/MLJFlux.jl) 来实现,但在本教程中,我们将简化操作,直接重载MLJBase.fit和MLJBase.predict方法。
由于 HF 的模型已经是预训练的,我们不打算进一步微调,因此我们将在MLJBase.fit方法中简单地返回模型对象。MLJBase.predict方法将接收模型对象和查询,并返回预测概率。我们还需要定义MLJBase.target_scitype和MLJBase.predict_mode方法。前者告诉MLJ模型的输出类型是什么,后者可以用来检索具有最高预测概率的标签。
struct IntentClassifier <: MLJBase.Probabilistic
tkr::TextEncoders.AbstractTransformerTextEncoder
mod::HuggingFace.HGFRobertaForSequenceClassification
end
function IntentClassifier(;
tkr::TextEncoders.AbstractTransformerTextEncoder,
mod::HuggingFace.HGFRobertaForSequenceClassification,
)
IntentClassifier(tkr, mod)
end
function get_hidden_state(clf::IntentClassifier, query::Union{AbstractString, Vector{<:AbstractString}})
token = encode(clf.tkr, query)
hidden_state = clf.mod.model(token).hidden_state
return hidden_state
end
# This doesn't actually retrain the model, but it retrieves the classifier object
function MLJBase.fit(clf::IntentClassifier, verbosity, X, y)
cache=nothing
report=nothing
fitresult = (clf = clf.mod.cls, labels = levels(y))
return fitresult, cache, report
end
function MLJBase.predict(clf::IntentClassifier, fitresult, Xnew)
output = fitresult.clf(get_hidden_state(clf, Xnew))
p̂ = UnivariateFinite(fitresult.labels,softmax(output.logit)',pool=missing)
return p̂
end
MLJBase.target_scitype(clf::IntentClassifier) = AbstractVector{<:Finite}
MLJBase.predict_mode(clf::IntentClassifier, fitresult, Xnew) = mode.(MLJBase.predict(clf, fitresult, Xnew))
为了测试一切是否按预期工作,我们拟合了模型并为测试数据的子集生成了预测:
clf = IntentClassifier(tkr, mod)
top_n = 10
fitresult, _, _ = MLJBase.fit(clf, 1, nothing, y_test[1:top_n])
@time ŷ = MLJBase.predict(clf, fitresult, queries_test[1:top_n]);
6.818024 seconds (11.29 M allocations: 799.165 MiB, 2.47% gc time, 91.04% compilation time)
注意,即使我们使用的 LLM 并不大,但即使是简单的前向传递也需要相当的时间。
🤖 合成聊天机器人
为了将包装好的预训练模型转变为合成意图分类器,我们现在可以依靠标准 API 调用。我们首先包装我们的原子模型,并指定所需的覆盖率和方法。由于即使是简单的前向传递对我们(小)LLM 来说也非常计算密集,我们依赖于简单归纳合成分类。
conf_model = conformal_model(clf; coverage=0.95, method=:simple_inductive, train_ratio=train_ratio)
mach = machine(conf_model, queries, y)
最后,我们使用合成 LLM 构建一个简单而强大的聊天机器人,直接在 Julia REPL 中运行。在不详细探讨细节的情况下,conformal_chatbot 的工作原理如下:
-
提示用户解释他们的意图。
-
通过合成 LLM 处理用户输入,并将输出呈现给用户。
-
如果合成预测集包含多个标签,请提示用户要么细化输入,要么选择预测集中的选项之一。
以下代码实现了这些想法:
function prediction_set(mach, query::String)
p̂ = MLJBase.predict(mach, query)[1]
probs = pdf.(p̂, collect(1:77))
in_set = findall(probs .!= 0)
labels_in_set = labels[in_set]
probs_in_set = probs[in_set]
_order = sortperm(-probs_in_set)
plt = UnicodePlots.barplot(labels_in_set[_order], probs_in_set[_order], title="Possible Intents")
return labels_in_set, plt
end
function conformal_chatbot()
println("👋 Hi, I'm a Julia, your conformal chatbot. I'm here to help you with your banking query. Ask me anything or type 'exit' to exit ...\n")
completed = false
queries = ""
while !completed
query = readline()
queries = queries * "," * query
labels, plt = prediction_set(mach, queries)
if length(labels) > 1
println("🤔 Hmmm ... I can think of several options here. If any of these applies, simply type the corresponding number (e.g. '1' for the first option). Otherwise, can you refine your question, please?\n")
println(plt)
else
println("🥳 I think you mean $(labels[1]). Correct?")
end
# Exit:
if query == "exit"
println("👋 Bye!")
break
end
if query ∈ string.(collect(1:77))
println("👍 Great! You've chosen '$(labels[parse(Int64, query)])'. I'm glad I could help you. Have a nice day!")
completed = true
end
end
end
下面我们展示了两个示例查询的输出。第一个查询非常模糊(而且刚刚发现拼写错误):“transfer mondey?”。因此,预测集的大小很大。
ambiguous_query = "transfer mondey?"
prediction_set(mach, ambiguous_query)[2]
Possible Intents
┌ ┐
beneficiary_not_allowed ┤■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 0.150517
balance_not_updated_after_bank_transfer ┤■■■■■■■■■■■■■■■■■■■■■■ 0.111409
transfer_into_account ┤■■■■■■■■■■■■■■■■■■■ 0.0939535
transfer_not_received_by_recipient ┤■■■■■■■■■■■■■■■■■■ 0.091163
top_up_by_bank_transfer_charge ┤■■■■■■■■■■■■■■■■■■ 0.089306
failed_transfer ┤■■■■■■■■■■■■■■■■■■ 0.0888322
transfer_timing ┤■■■■■■■■■■■■■ 0.0641952
transfer_fee_charged ┤■■■■■■■ 0.0361131
pending_transfer ┤■■■■■ 0.0270795
receiving_money ┤■■■■■ 0.0252126
declined_transfer ┤■■■ 0.0164443
cancel_transfer ┤■■■ 0.0150444
└ ┘
以下是更精炼的提示版本:“我试图给朋友转账,但失败了。” 由于不那么模糊的提示会导致较低的预测不确定性,因此它产生了较小的预测集。
refined_query = "I tried to transfer money to my friend, but it failed."
prediction_set(mach, refined_query)[2]
Possible Intents
┌ ┐
failed_transfer ┤■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 0.59042
beneficiary_not_allowed ┤■■■■■■■ 0.139806
transfer_not_received_by_recipient ┤■■ 0.0449783
balance_not_updated_after_bank_transfer ┤■■ 0.037894
declined_transfer ┤■ 0.0232856
transfer_into_account ┤■ 0.0108771
cancel_transfer ┤ 0.00876369
└ ┘
下面的视频展示了 REPL 基础聊天机器人在实际应用中的表现。你可以自己重现这个过程,并直接从你的终端运行机器人。为此,请查看我博客上的原始帖子以获取完整的源代码。

REPL 基础的符合性聊天机器人的演示。由作者创建。
🌯 总结
这项工作是与 ING 的同事合作完成的,作为 ING Analytics 2023 实验周的一部分。我们的团队展示了符合性预测提供了对顶级-K意图分类的强大而有原则的替代方案。我们通过大众投票赢得了第一名。
当然,这里还有很多可以改进的地方。就大型语言模型而言,我们使用了一个较小的模型。在符合性预测方面,我们只关注了简单的归纳符合性分类。这是一个好的起点,但还有更高级的方法可用,这些方法已经在软件包中实现,并在竞赛中进行了研究。另一个我们没有考虑的方面是我们有许多结果类别,实际上可能希望实现类别条件覆盖。请关注未来的帖子了解更多内容。
如果你对在 Julia 中了解更多关于符合性预测的内容感兴趣,请查看代码库和文档。
🎉 JuliaCon 2023 即将到来,今年我将进行一场关于ConformalPrediction.jl的讲座。请查看我的讲座详细信息,并浏览内容丰富的会议日程。
🎓 参考文献
Casanueva, Iñigo, Tadas Temčinas, Daniela Gerz, Matthew Henderson, 和 Ivan Vulić. 2020. “使用双句子编码器的高效意图检测。” 第二届对话 AI 自然语言处理研讨会论文集 , 38–45. 在线:计算语言学协会. doi.org/10.18653/v1/2020.nlp4convai-1.5。
Liu, Yinhan, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, 和 Veselin Stoyanov. 2019. “RoBERTa:一种稳健优化的 BERT 预训练方法。” arXiv. doi.org/10.48550/arXiv.1907.11692。
💾 数据和模型
Banking77 数据集是从 HuggingFace 获取的。它在知识共享署名 4.0 国际许可协议(CC BY 4.0)下发布,由 PolyAI 策划,并由 Casanueva 等人(2020 年)最初发布。还要感谢 Manuel Romero 为 HuggingFace 贡献了经过微调的 DistilRoBERTa。
最初发布于 https://www.paltmeyer.com 于 2023 年 7 月 5 日。
使用 OpenAI 和 FastAPI 构建记忆微服务的对话代理

充满记忆的对话,照片由 Juri Gianfrancesco 在 Unsplash 提供。
制作上下文感知的对话代理:深入探讨 OpenAI 和 FastAPI 的集成
·
关注 发表在 Towards Data Science · 30 分钟阅读 · 2023 年 8 月 17 日
--
介绍
在本教程中,我们将探索使用 OpenAI 和 FastAPI 创建具有内存微服务的对话代理的过程。对话代理已成为各种应用程序中的关键组件,包括客户支持、虚拟助手和信息检索系统。然而,许多传统的聊天机器人实现缺乏在对话过程中保留上下文的能力,导致功能有限和令人沮丧的用户体验。这在遵循微服务架构构建代理服务时尤其具有挑战性。
GitHub 仓库的链接在文章底部。
动机
本教程的动机是解决传统聊天机器人实现的局限性,并创建一个具有内存微服务的对话代理,这在将代理部署到像 Kubernetes 这样的复杂环境中时尤为重要。在 Kubernetes 或类似的容器编排系统中,微服务经常经历重启、更新和扩展操作。在这些事件中,传统聊天机器人的对话状态将丢失,导致断裂的互动和糟糕的用户体验。
通过构建具有内存微服务的对话代理,我们可以确保在微服务重启或更新时,甚至在交互不连续的情况下,重要的对话上下文得以保留。这种状态的保存使代理能够无缝地继续之前的对话,保持连贯性,并提供更自然和个性化的用户体验。此外,这种方法符合现代应用开发的最佳实践,其中容器化的微服务通常与其他组件交互,使得内存微服务在这种分布式设置中成为对话代理架构中的有价值的补充。
我们将使用的技术栈
对于这个项目,我们将主要使用以下技术和工具:
-
OpenAI GPT-3.5:我们将利用 OpenAI 的 GPT-3.5 语言模型,该模型能够执行各种自然语言处理任务,包括文本生成、对话管理和上下文保留。我们需要生成一个 OpenAI API 密钥,请确保访问此 URL 以管理您的密钥。
-
FastAPI:FastAPI 将作为我们微服务的骨干,提供处理 HTTP 请求、管理对话状态和与 OpenAI API 集成的基础设施。FastAPI 非常适合用 Python 构建微服务。
开发周期
在本节中,我们将深入探讨构建具有内存微服务的对话代理的逐步过程。开发周期将包括:
-
环境设置:我们将创建一个虚拟环境并安装必要的依赖项,包括 OpenAI 的 Python 库和 FastAPI。
-
设计记忆微服务:我们将概述记忆微服务的架构和设计,该服务将负责存储和管理对话上下文。
-
集成 OpenAI:我们将把 OpenAI 的 GPT-3.5 模型集成到我们的应用中,并定义处理用户消息和生成响应的逻辑。
-
测试:我们将逐步测试我们的对话代理。
环境设置
对于这个设置,我们将使用以下结构来构建微服务。这对于在同一个项目下扩展其他服务非常方便,而且我个人喜欢这种结构。
├── Dockerfile <--- Container
├── requirements.txt <--- Libraries and Dependencies
├── setup.py <--- Build and distribute microservices as Python packages
└── src
├── agents <--- Name of your Microservice
│ ├── __init__.py
│ ├── api
│ │ ├── __init__.py
│ │ ├── routes.py
│ │ └── schemas.py
│ ├── crud.py
│ ├── database.py
│ ├── main.py
│ ├── models.py
│ └── processing.py
└── agentsfwrk <--- Name of your Common Framework
├── __init__.py
├── integrations.py
└── logger.py
我们需要在项目中创建一个名为src的文件夹,其中将包含服务的 Python 代码;在我们的例子中,agents包含与对话代理和 API 相关的所有代码,agentsfwrk是我们用于跨服务的通用框架。
Dockerfile包含构建镜像的指令,一旦代码准备好,requirements.txt包含我们项目中使用的库,setup.py包含构建和分发项目的指令。
目前,只需创建服务文件夹以及__init__.py文件,并将以下内容添加到项目根目录的requirements.txt和setup.py中,Dockerfile保持空白,我们将在部署周期部分回到它。
# Requirements.txt
fastapi==0.95.2
ipykernel==6.22.0
jupyter-bokeh==2.0.2
jupyterlab==3.6.3
openai==0.27.6
pandas==2.0.1
sqlalchemy-orm==1.2.10
sqlalchemy==2.0.15
uvicorn<0.22.0,>=0.21.1
# setup.py
from setuptools import find_packages, setup
setup(
name = 'conversational-agents',
version = '0.1',
description = 'microservices for conversational agents',
packages = find_packages('src'),
package_dir = {'': 'src'},
# This is optional btw
author = 'XXX XXXX',
author_email = 'XXXX@XXXXX.ai',
maintainer = 'XXX XXXX',
maintainer_email = 'XXXX@XXXXX.ai',
)
让我们激活虚拟环境,并在终端运行pip install -r requirements.txt。我们暂时不会运行 setup 文件,所以接下来进入下一部分。
设计通用框架
我们将设计我们的通用框架,以便在项目中构建的所有微服务中使用。这对小型项目来说不是严格必要的,但考虑到未来,你可以扩展它以使用多个 LLM 提供商,添加与自己数据交互的其他库(例如LangChain,VoCode),以及其他通用功能,如语音和图像服务,而无需在每个微服务中实现它们。
创建文件夹和文件时请遵循agentsfwrk结构。每个文件及其描述如下:
└── agentsfwrk <--- Name of your Common Framework
├── __init__.py
├── integrations.py
└── logger.py
日志记录器是一个非常基础的工具,用于设置通用日志模块,你可以按如下方式定义它:
import logging
import multiprocessing
import sys
APP_LOGGER_NAME = 'CaiApp'
def setup_applevel_logger(logger_name = APP_LOGGER_NAME, file_name = None):
"""
Setup the logger for the application
"""
logger = logging.getLogger(logger_name)
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
sh = logging.StreamHandler(sys.stdout)
sh.setFormatter(formatter)
logger.handlers.clear()
logger.addHandler(sh)
if file_name:
fh = logging.FileHandler(file_name)
fh.setFormatter(formatter)
logger.addHandler(fh)
return logger
def get_multiprocessing_logger(file_name = None):
"""
Setup the logger for the application for multiprocessing
"""
logger = multiprocessing.get_logger()
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
sh = logging.StreamHandler(sys.stdout)
sh.setFormatter(formatter)
if not len(logger.handlers):
logger.addHandler(sh)
if file_name:
fh = logging.FileHandler(file_name)
fh.setFormatter(formatter)
logger.addHandler(fh)
return logger
def get_logger(module_name, logger_name = None):
"""
Get the logger for the module
"""
return logging.getLogger(logger_name or APP_LOGGER_NAME).getChild(module_name)
接下来,我们的集成层通过集成模块完成。此文件充当微服务逻辑与 OpenAI 之间的中介,并设计为以统一的方式向我们的应用程序公开 LLM 提供商。在这里,我们可以实现处理异常、错误、重试和请求或响应超时的通用方法。我从一位非常优秀的经理那里学到,要始终在外部服务/API 和我们应用的内部世界之间放置一个集成层。
集成代码定义如下:
# integrations.py
# LLM provider common module
import json
import os
import time
from typing import Union
import openai
from openai.error import APIConnectionError, APIError, RateLimitError
import agentsfwrk.logger as logger
log = logger.get_logger(__name__)
openai.api_key = os.getenv('OPENAI_API_KEY')
class OpenAIIntegrationService:
def __init__(
self,
context: Union[str, dict],
instruction: Union[str, dict]
) -> None:
self.context = context
self.instructions = instruction
if isinstance(self.context, dict):
self.messages = []
self.messages.append(self.context)
elif isinstance(self.context, str):
self.messages = self.instructions + self.context
def get_models(self):
return openai.Model.list()
def add_chat_history(self, messages: list):
"""
Adds chat history to the conversation.
"""
self.messages += messages
def answer_to_prompt(self, model: str, prompt: str, **kwargs):
"""
Collects prompts from user, appends to messages from the same conversation
and return responses from the gpt models.
"""
# Preserve the messages in the conversation
self.messages.append(
{
'role': 'user',
'content': prompt
}
)
retry_exceptions = (APIError, APIConnectionError, RateLimitError)
for _ in range(3):
try:
response = openai.ChatCompletion.create(
model = model,
messages = self.messages,
**kwargs
)
break
except retry_exceptions as e:
if _ == 2:
log.error(f"Last attempt failed, Exception occurred: {e}.")
return {
"answer": "Sorry, I'm having technical issues."
}
retry_time = getattr(e, 'retry_after', 3)
log.error(f"Exception occurred: {e}. Retrying in {retry_time} seconds...")
time.sleep(retry_time)
response_message = response.choices[0].message["content"]
response_data = {"answer": response_message}
self.messages.append(
{
'role': 'assistant',
'content': response_message
}
)
return response_data
def answer_to_simple_prompt(self, model: str, prompt: str, **kwargs) -> dict:
"""
Collects context and appends a prompt from a user and return response from
the gpt model given an instruction.
This method only allows one message exchange.
"""
messages = self.messages + f"\n<Client>: {prompt} \n"
retry_exceptions = (APIError, APIConnectionError, RateLimitError)
for _ in range(3):
try:
response = openai.Completion.create(
model = model,
prompt = messages,
**kwargs
)
break
except retry_exceptions as e:
if _ == 2:
log.error(f"Last attempt failed, Exception occurred: {e}.")
return {
"intent": False,
"answer": "Sorry, I'm having technical issues."
}
retry_time = getattr(e, 'retry_after', 3)
log.error(f"Exception occurred: {e}. Retrying in {retry_time} seconds...")
time.sleep(retry_time)
response_message = response.choices[0].text
try:
response_data = json.loads(response_message)
answer_text = response_data.get('answer')
if answer_text is not None:
self.messages = self.messages + f"\n<Client>: {prompt} \n" + f"<Agent>: {answer_text} \n"
else:
raise ValueError("The response from the model is not valid.")
except ValueError as e:
log.error(f"Error occurred while parsing response: {e}")
log.error(f"Prompt from the user: {prompt}")
log.error(f"Response from the model: {response_message}")
log.info("Returning a safe response to the user.")
response_data = {
"intent": False,
"answer": response_message
}
return response_data
def verify_end_conversation(self):
"""
Verify if the conversation has ended by checking the last message from the user
and the last message from the assistant.
"""
pass
def verify_goal_conversation(self, model: str, **kwargs):
"""
Verify if the conversation has reached the goal by checking the conversation history.
Format the response as specified in the instructions.
"""
messages = self.messages.copy()
messages.append(self.instructions)
retry_exceptions = (APIError, APIConnectionError, RateLimitError)
for _ in range(3):
try:
response = openai.ChatCompletion.create(
model = model,
messages = messages,
**kwargs
)
break
except retry_exceptions as e:
if _ == 2:
log.error(f"Last attempt failed, Exception occurred: {e}.")
raise
retry_time = getattr(e, 'retry_after', 3)
log.error(f"Exception occurred: {e}. Retrying in {retry_time} seconds...")
time.sleep(retry_time)
response_message = response.choices[0].message["content"]
try:
response_data = json.loads(response_message)
if response_data.get('summary') is None:
raise ValueError("The response from the model is not valid. Missing summary.")
except ValueError as e:
log.error(f"Error occurred while parsing response: {e}")
log.error(f"Response from the model: {response_message}")
log.info("Returning a safe response to the user.")
raise
return response_data
关于集成模块的一些说明:
-
OpenAI 密钥被定义为名为“OPENAI_API_KEY”的环境变量,我们应该下载这个密钥并在终端中定义它,或使用python-dotenv库。
-
有两种方法可以与 GPT 模型集成,一种用于聊天端点(
answer_to_prompt),另一种用于完成端点(answer_to_simple_prompt)。我们将专注于第一个的使用。 -
有一种方法来检查对话的目标——
verify_goal_conversation,它简单地遵循代理的指示并生成总结。
设计(内存)微服务
最佳练习是设计并绘制一个图表来可视化服务需要做的事情,包括参与者及其在与服务交互时的行动。我们从简单地描述我们的应用程序开始:
-
我们的微服务是一个人工智能代理的提供者,这些代理在某一主题上是专家,预计会根据外部消息和后续提示进行对话。
-
我们的代理可以进行多次对话,并且包含需要持久化的内存,这意味着它们必须能够保留对话历史记录,无论与代理交互的客户端会话如何。
-
代理在创建时应接收清晰的指示,说明如何处理对话并在对话过程中做出相应响应。
-
对于程序化集成,代理也应遵循预期的响应格式。
我们的设计如下图所示:

对话代理设计——作者提供的图像
通过这个简单的图表,我们知道我们的微服务需要实现负责这些特定任务的方法:
-
代理的创建 & 指令的定义
-
对话启动器 & 对话历史记录的保存
-
与代理聊天
我们将按照顺序编写这些功能,在此之前我们将构建应用程序的骨架。
应用程序骨架
为了启动开发,我们首先构建 FastAPI 应用程序骨架。应用程序骨架包括基本组件,如主要应用程序脚本、数据库配置、处理脚本和路由模块。主要脚本作为应用程序的入口点,我们在此处设置 FastAPI 实例。
主要文件
在你的agents文件夹中创建/打开main.py文件并输入以下代码,该代码简单地定义了一个根端点。
from fastapi import FastAPI
from agentsfwrk.logger import setup_applevel_logger
log = setup_applevel_logger(file_name = 'agents.log')
app = FastAPI()
@app.get("/")
async def root():
return {"message": "Hello there conversational ai user!"}
数据库配置
然后我们创建/打开名为database.py的数据库配置脚本,该脚本建立与本地数据库的连接,用于存储和检索对话上下文。我们将首先使用本地 SQLite 以简化操作,但可以根据你的环境尝试其他数据库。
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
SQLALCHEMY_DATABASE_URL = "sqlite:///agents.db"
engine = create_engine(
SQLALCHEMY_DATABASE_URL, connect_args = {"check_same_thread": False}
)
SessionLocal = sessionmaker(autocommit = False, autoflush = False, bind = engine)
Base = declarative_base()
API 路由
最后,我们定义处理传入 HTTP 请求的路由模块,涵盖处理用户交互的端点。让我们创建api文件夹,创建/打开routes.py文件,并粘贴以下代码。
from typing import List
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
import agents.api.schemas
import agents.models
from agents.database import SessionLocal, engine
from agentsfwrk import integrations, logger
log = logger.get_logger(__name__)
agents.models.Base.metadata.create_all(bind = engine)
# Router basic information
router = APIRouter(
prefix = "/agents",
tags = ["Chat"],
responses = {404: {"description": "Not found"}}
)
# Dependency: Used to get the database in our endpoints.
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
# Root endpoint for the router.
@router.get("/")
async def agents_root():
return {"message": "Hello there conversational ai!"}
有了这个结构化的骨架,我们已经准备好开始编写我们设计的应用程序。
创建代理并分配指令
在本节中,我们将重点实现“创建代理”端点。此端点使用户能够启动新的对话并与代理互动,提供上下文和一组指令,以便代理在整个对话过程中遵循。我们将首先介绍两个数据模型:一个用于数据库,另一个用于 API。我们将使用Pydantic来创建数据模型。创建/打开schemas.py文件,并定义 Agent base、Agent Create 和 Agent 数据模型。
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel
class AgentBase(BaseModel): # <-- Base model
context: str # <-- Our agents context
first_message: str # <-- Our agents will approach the users with a first message.
response_shape: str # <-- The expected shape (for programatic communication) of the response of each agent's interaction with the user
instructions: str # <-- Set of instructions that our agent should follow.
class AgentCreate(AgentBase): # <-- Creation data model
pass
class Agent(AgentBase): # <-- Agent data model
id: str
timestamp: datetime = datetime.utcnow()
class Config:
orm_mode = True
agent 数据模型中的字段如下所述:
-
上下文:这是代理的整体背景。
-
首条消息:我们的代理旨在与用户开始对话。这可以简单到“你好,我可以帮你做什么?”或者类似“嗨,你请求一个代理来帮助你找到有关股票的信息,对吗?”。
-
响应格式:该字段主要用于指定代理响应的输出格式,并应用于将 LLM 的文本输出转换为所需的格式,以便进行程序化通信。例如,我们可能希望指定我们的代理应该将响应包装在一个名为
response的 JSON 格式中,即{'response': "string"}。 -
指令:该字段包含每个代理在整个对话过程中应遵循的指令和指南,例如“在每次交互中收集以下实体 [e1, e2, e3, …]”或“回复用户直到他不再对对话感兴趣”或“不要偏离主题,并在必要时将对话引导回主要目标”。
我们现在继续打开models.py文件,在其中编写属于 agent 实体的数据库表。
from sqlalchemy import Column, ForeignKey, String, DateTime, JSON
from sqlalchemy.orm import relationship
from datetime import datetime
from agents.database import Base
class Agent(Base):
__tablename__ = "agents"
id = Column(String, primary_key = True, index = True)
timestamp = Column(DateTime, default = datetime.utcnow)
context = Column(String, nullable = False)
first_message = Column(String, nullable = False)
response_shape = Column(JSON, nullable = False)
instructions = Column(String, nullable = False)
这段代码与 Pydantic 模型非常相似,它定义了我们数据库中的代理表。
在我们有了两个数据模型后,我们准备好实现代理的创建。为此,我们将首先修改routes.py文件,添加端点:
@router.post("/create-agent", response_model = agents.api.schemas.Agent)
async def create_agent(campaign: agents.api.schemas.AgentCreate, db: Session = Depends(get_db)):
"""
Create an agent
"""
log.info(f"Creating agent")
# db_agent = create_agent(db, agent)
log.info(f"Agent created with id: {db_agent.id}")
return db_agent
我们需要创建一个新函数,该函数接收来自请求的 Agent 对象,并将其保存到数据库中。为此,我们将创建/打开crud.py文件,该文件将包含所有与数据库的交互(创建、读取、更新、删除)。
# crud.py
import uuid
from sqlalchemy.orm import Session
from agents import models
from agents.api import schemas
def create_agent(db: Session, agent: schemas.AgentCreate):
"""
Create an agent in the database
"""
db_agent = models.Agent(
id = str(uuid.uuid4()),
context = agent.context,
first_message = agent.first_message,
response_shape = agent.response_shape,
instructions = agent.instructions
)
db.add(db_agent)
db.commit()
db.refresh(db_agent)
return db_agent
创建完函数后,我们现在可以回到routes.py,导入crud模块,并在端点方法中使用它。
import agents.crud
@router.post("/create-agent", response_model = agents.api.schemas.Agent)
async def create_agent(agent: agents.api.schemas.AgentCreate, db: Session = Depends(get_db)):
"""
Create an agent endpoint.
"""
log.info(f"Creating agent: {agent.json()}")
db_agent = agents.crud.create_agent(db, agent)
log.info(f"Agent created with id: {db_agent.id}")
return db_agent
现在让我们回到main.py文件,添加“agents”路由。修改
# main.py
from fastapi import FastAPI
from agents.api.routes import router as ai_agents # NOTE: <-- new addition
from agentsfwrk.logger import setup_applevel_logger
log = setup_applevel_logger(file_name = 'agents.log')
app = FastAPI()
app.include_router(router = ai_agents) # NOTE: <-- new addition
@app.get("/")
async def root():
return {"message": "Hello there conversational ai user!"}
让我们测试一下这个功能。首先,我们需要将我们的服务安装为 Python 包,其次,在 8000 端口启动应用程序。
# Run from the root of the project.
$ pip install -e .
# Command to run the app.
$ uvicorn agents.main:app --host 0.0.0.0 --port 8000 --reload
访问 0.0.0.0:8000/docs,你将看到带有测试端点的 Swagger UI。提交你的负载并检查输出。

create-agent 端点来自 Swagger UI — 图片由作者提供
我们将继续开发我们的应用程序,但测试第一个端点是进展的良好标志。
创建对话 & 保留对话历史
我们的下一步是允许用户与我们的代理进行交互。我们希望用户能够与特定的代理进行互动,因此我们需要传递代理的 ID 以及用户的第一次互动消息。让我们对 Agent 数据模型进行一些修改,通过引入 Conversation 实体,使每个代理能够进行多个对话。打开 schemas.py 文件并添加以下模型:
class ConversationBase(BaseModel): # <-- base of our conversations, they must belong to an agent
agent_id: str
class ConversationCreate(ConversationBase): # <-- conversation creation object
pass
class Conversation(ConversationBase): # <-- The conversation objects
id: str
timestamp: datetime = datetime.utcnow()
class Config:
orm_mode = True
class Agent(AgentBase): # <-- Agent data model
id: str
timestamp: datetime = datetime.utcnow()
conversations: List[Conversation] = [] # <-- NOTE: we have added the conversation as a list of Conversations objects.
class Config:
orm_mode = True
请注意,我们已经修改了 Agent 数据模型,并添加了对话功能,以便每个代理可以根据我们的图表设计进行多个对话。
我们需要修改我们的数据库对象,并在数据库模型脚本中包含对话表。我们将打开 models.py 文件,并按如下方式修改代码:
# models.py
class Agent(Base):
__tablename__ = "agents"
id = Column(String, primary_key = True, index = True)
timestamp = Column(DateTime, default = datetime.utcnow)
context = Column(String, nullable = False)
first_message = Column(String, nullable = False)
response_shape = Column(JSON, nullable = False)
instructions = Column(String, nullable = False)
conversations = relationship("Conversation", back_populates = "agent") # <-- NOTE: We add the conversation relationship into the agents table
class Conversation(Base):
__tablename__ = "conversations"
id = Column(String, primary_key = True, index = True)
agent_id = Column(String, ForeignKey("agents.id"))
timestap = Column(DateTime, default = datetime.utcnow)
agent = relationship("Agent", back_populates = "conversations") # <-- We add the relationship between the conversation and the agent
请注意我们在 agents 表中为每个代理添加了对话之间的关系,以及在 conversations 表中对话与代理之间的关系。
我们现在将创建一组 CRUD 函数,以通过它们的 ID 检索代理和对话,这将帮助我们制定创建对话和保留对话历史的过程。让我们打开 crud.py 文件并添加以下函数:
def get_agent(db: Session, agent_id: str):
"""
Get an agent by its id
"""
return db.query(models.Agent).filter(models.Agent.id == agent_id).first()
def get_conversation(db: Session, conversation_id: str):
"""
Get a conversation by its id
"""
return db.query(models.Conversation).filter(models.Conversation.id == conversation_id).first()
def create_conversation(db: Session, conversation: schemas.ConversationCreate):
"""
Create a conversation
"""
db_conversation = models.Conversation(
id = str(uuid.uuid4()),
agent_id = conversation.agent_id,
)
db.add(db_conversation)
db.commit()
db.refresh(db_conversation)
return db_conversation
这些新函数将帮助我们在应用程序的正常工作流程中,现在我们可以通过 ID 获取代理,通过 ID 获取对话,并通过提供可选的 ID 和应持有对话的代理 ID 来创建对话。
我们可以继续创建一个创建对话的端点。打开 routes.py 并添加以下代码:
@router.post("/create-conversation", response_model = agents.api.schemas.Conversation)
async def create_conversation(conversation: agents.api.schemas.ConversationCreate, db: Session = Depends(get_db)):
"""
Create a conversation linked to an agent
"""
log.info(f"Creating conversation assigned to agent id: {conversation.agent_id}")
db_conversation = agents.crud.create_conversation(db, conversation)
log.info(f"Conversation created with id: {db_conversation.id}")
return db_conversation
在这个方法准备好后,我们仍然离拥有实际的对话端点还差一步,我们将在下一节中进行回顾。
在初始化代理时,重要的是要做出区分,我们可以创建一个对话而不触发双向消息交换,另一种方式是当调用“与代理聊天”端点时触发对话的创建。这为在微服务外部组织工作流提供了一些灵活性,在某些情况下,你可能想初始化代理,提前启动与客户的对话,并随着消息的到来开始保留消息的历史记录。

create-conversation 端点来自 Swagger UI — 图片由作者提供
重要提示: 如果您按照本指南逐步操作,并且在此步骤中看到与数据库模式相关的错误,请注意,这是因为我们在每次修改模式时都未将迁移应用到数据库,因此请确保关闭应用程序(退出终端命令)并删除在运行时创建的
agents.db文件。 您需要重新运行每个端点并记录 ID。
与代理人聊天
我们现在要介绍我们应用程序中的最后一个实体类型,即Message实体。 这个实体负责建模客户消息和代理消息之间的交互(消息的双向交换)。 我们还将添加用于定义端点响应结构的 API 数据模型。 让我们先创建数据模型和 API 响应类型; 打开schemas.py文件,并修改代码:
##########################################
# Internal schemas
##########################################
class MessageBase(BaseModel): # <-- Every message is composed by user/client message and the agent
user_message: str
agent_message: str
class MessageCreate(MessageBase):
pass
class Message(MessageBase): # <-- Data model for the Message entity
id: str
timestamp: datetime = datetime.utcnow()
conversation_id: str
class Config:
orm_mode = True
##########################################
# API schemas
##########################################
class UserMessage(BaseModel):
conversation_id: str
message: str
class ChatAgentResponse(BaseModel):
conversation_id: str
response: str
现在我们必须在代表数据库中表的数据库模型脚本中添加数据模型。 打开models.py文件并修改如下:
# models.py
class Conversation(Base):
__tablename__ = "conversations"
id = Column(String, primary_key = True, index = True)
agent_id = Column(String, ForeignKey("agents.id"))
timestap = Column(DateTime, default = datetime.utcnow)
agent = relationship("Agent", back_populates = "conversations")
messages = relationship("Message", back_populates = "conversation") # <-- We define the relationship between the conversation and the multiple messages in them.
class Message(Base):
__tablename__ = "messages"
id = Column(String, primary_key = True, index = True)
timestamp = Column(DateTime, default = datetime.utcnow)
user_message = Column(String)
agent_message = Column(String)
conversation_id = Column(String, ForeignKey("conversations.id")) # <-- A message belongs to a conversation
conversation = relationship("Conversation", back_populates = "messages") # <-- We define the relationship between the messages and the conversation.
请注意,我们已修改了Conversations表以定义消息与会话之间的关系,并创建了一个新表,表示应属于对话的交互(消息交换)。
现在我们将向数据库添加一个新的 CRUD 函数,以与数据库交互并为对话创建消息。 打开crud.py文件并添加以下函数:
def create_conversation_message(db: Session, message: schemas.MessageCreate, conversation_id: str):
"""
Create a message for a conversation
"""
db_message = models.Message(
id = str(uuid.uuid4()),
user_message = message.user_message,
agent_message = message.agent_message,
conversation_id = conversation_id
)
db.add(db_message)
db.commit()
db.refresh(db_message)
return db_message
现在我们准备构建最终和最有趣的端点,chat-agent端点。 打开routes.py文件,并按照代码进行操作,因为我们将在途中实施一些处理函数。
@router.post("/chat-agent", response_model = agents.api.schemas.ChatAgentResponse)
async def chat_completion(message: agents.api.schemas.UserMessage, db: Session = Depends(get_db)):
"""
Get a response from the GPT model given a message from the client using the chat
completion endpoint.
The response is a json object with the following structure:
```
{
`"conversation_id": "string",
`"response": "string"
}
```py
"""
log.info(f"User conversation id: {message.conversation_id}")
log.info(f"User message: {message.message}")
conversation = agents.crud.get_conversation(db, message.conversation_id)
if not conversation:
# If there are no conversations, we can choose to create one on the fly OR raise an exception.
# Which ever you choose, make sure to uncomment when necessary.
# Option 1:
# conversation = agents.crud.create_conversation(db, message.conversation_id)
# Option 2:
return HTTPException(
status_code = 404,
detail = "Conversation not found. Please create conversation first."
)
log.info(f"Conversation id: {conversation.id}")
在端点的这一部分中,我们确保在对话不存在时创建或引发异常。 下一步是准备数据,将其通过我们的集成发送到 OpenAI,为此,我们将在processing.py文件中创建一组处理函数,这些函数将从 LLM 中制作上下文,第一条消息,说明和预期的响应形状。
# processing.py
import json
########################################
# Chat Properties
########################################
def craft_agent_chat_context(context: str) -> dict:
"""
Craft the context for the agent to use for chat endpoints.
"""
agent_chat_context = {
"role": "system",
"content": context
}
return agent_chat_context
def craft_agent_chat_first_message(content: str) -> dict:
"""
Craft the first message for the agent to use for chat endpoints.
"""
agent_chat_first_message = {
"role": "assistant",
"content": content
}
return agent_chat_first_message
def craft_agent_chat_instructions(instructions: str, response_shape: str) -> dict:
"""
Craft the instructions for the agent to use for chat endpoints.
"""
agent_instructions = {
"role": "user",
"content": instructions + f"\n\nFollow a RFC8259 compliant JSON with a shape of: {json.dumps(response_shape)} format without deviation."
}
return agent_instructions
注意最后一个函数期望在代理人创建过程中定义的response_shape,此输入将在对话过程中附加到 LLM,并指导代理人遵循指南并将响应作为 JSON 对象返回。
让我们返回routes.py文件并完成我们的端点实现:
# New imports from the processing module.
from agents.processing import (
craft_agent_chat_context,
craft_agent_chat_first_message,
craft_agent_chat_instructions
)
@router.post("/chat-agent", response_model = agents.api.schemas.ChatAgentResponse)
async def chat_completion(message: agents.api.schemas.UserMessage, db: Session = Depends(get_db)):
"""
Get a response from the GPT model given a message from the client using the chat
completion endpoint.
The response is a json object with the following structure:
```
{
`"conversation_id": "string",
`"response": "string"
}
```py
"""
log.info(f"User conversation id: {message.conversation_id}")
log.info(f"User message: {message.message}")
conversation = agents.crud.get_conversation(db, message.conversation_id)
if not conversation:
# If there are no conversations, we can choose to create one on the fly OR raise an exception.
# Which ever you choose, make sure to uncomment when necessary.
# Option 1:
# conversation = agents.crud.create_conversation(db, message.conversation_id)
# Option 2:
return HTTPException(
status_code = 404,
detail = "Conversation not found. Please create conversation first."
)
log.info(f"Conversation id: {conversation.id}")
# NOTE: We are crafting the context first and passing the chat messages in a list
# appending the first message (the approach from the agent) to it.
context = craft_agent_chat_context(conversation.agent.context)
chat_messages = [craft_agent_chat_first_message(conversation.agent.first_message)]
# NOTE: Append to the conversation all messages until the last interaction from the agent
# If there are no messages, then this has no effect.
# Otherwise, we append each in order by timestamp (which makes logical sense).
hist_messages = conversation.messages
hist_messages.sort(key = lambda x: x.timestamp, reverse = False)
if len(hist_messages) > 0:
for mes in hist_messages:
log.info(f"Conversation history message: {mes.user_message} | {mes.agent_message}")
chat_messages.append(
{
"role": "user",
"content": mes.user_message
}
)
chat_messages.append(
{
"role": "assistant",
"content": mes.agent_message
}
)
# NOTE: We could control the conversation by simply adding
# rules to the length of the history.
if len(hist_messages) > 10:
# Finish the conversation gracefully.
log.info("Conversation history is too long, finishing conversation.")
api_response = agents.api.schemas.ChatAgentResponse(
conversation_id = message.conversation_id,
response = "This conversation is over, good bye."
)
return api_response
# Send the message to the AI agent and get the response
service = integrations.OpenAIIntegrationService(
context = context,
instruction = craft_agent_chat_instructions(
conversation.agent.instructions,
conversation.agent.response_shape
)
)
service.add_chat_history(messages = chat_messages)
response = service.answer_to_prompt(
# We can test different OpenAI models.
model = "gpt-3.5-turbo",
prompt = message.message,
# We can test different parameters too.
temperature = 0.5,
max_tokens = 1000,
frequency_penalty = 0.5,
presence_penalty = 0
)
log.info(f"Agent response: {response}")
# Prepare response to the user
api_response = agents.api.schemas.ChatAgentResponse(
conversation_id = message.conversation_id,
response = response.get('answer')
)
# Save interaction to database
db_message = agents.crud.create_conversation_message(
db = db,
conversation_id = conversation.id,
message = agents.api.schemas.MessageCreate(
user_message = message.message,
agent_message = response.get('answer'),
),
)
log.info(f"Conversation message id {db_message.id} saved to database")
return api_response
Voilà! 这是我们最终的端点实现,如果我们查看代码中添加的Notes,我们会发现这个过程非常简单:
-
我们确保在我们的数据库中存在对话(或者我们创建一个)
-
我们从数据库中制作上下文和指导代理人
-
我们通过获取代理人的对话历史来利用代理人的“记忆”
-
最后,我们通过 OpenAI 的 GPT-3.5 Turbo 模型请求代理的响应,并将响应返回给客户端。
本地测试我们的代理
现在我们准备测试微服务的完整工作流,我们将首先进入终端,输入 uvicorn agents.main:app — host 0.0.0.0 — port 8000 — reload 启动应用程序。接下来,我们将通过访问 0.0.0.0:8000/docs 进入 Swagger UI 并提交以下请求:
- 创建代理:提供你想测试的有效负载。我将提交以下内容:
{
"context": "You are a chef specializing in Mediterranean food that provides receipts with a maximum of simple 10 ingredients. The user can have many food preferences or ingredient preferences, and your job is always to analyze and guide them to use simple ingredients for the recipes you suggest and these should also be Mediterranean. The response should include detailed information on the recipe. The response should also include questions to the user when necessary. If you think your response may be inaccurate or vague, do not write it and answer with the exact text: `I don't have a response.`",
"first_message": "Hello, I am your personal chef and cooking advisor and I am here to help you with your meal preferences and your cooking skills. What can I can do for you today?",
"response_shape": "{'recipes': 'List of strings with the name of the recipes', 'ingredients': 'List of the ingredients used in the recipes', 'summary': 'String, summary of the conversation'}",
"instructions": "Run through the conversation messages and discard any messages that are not relevant for cooking. Focus on extracting the recipes that were mentioned in the conversation and for each of them extract the list of ingredients. Make sure to provide a summary of the conversation when asked."
}
- 创建对话:将对话分配给从上一个响应中获取的
agent_id。
{
"agent_id": "Replace with the UUID from the agent you just created."
}
- 让我们通过外发消息与我们的代理聊天:我们的代理将通过提问或发送消息来初始化对话。我们将通过回复消息来继续对话 — 我将使用以下消息:
{
"conversation_id": "74485ccb-21e5-45bd-8008-001cbd6d3e37",
"message": "Hey I am interested in cooking dinner for me and my friends tomorrow night, and I need your advice on what to cook for 5 people."
}

chat-agent 端点响应来自 Swagger UI — 图片由作者提供
我们的代理给我们回复了响应,我们可以通过自然的方式继续这个对话。
# Our question
{
"conversation_id": "74485ccb-21e5-45bd-8008-001cbd6d3e37",
"message": "There is someone that is allergic to peanuts and that's it. My only preference is seafood and anything that goes perfectly with white wine."
}
# Our agents response
{
"conversation_id": "74485ccb-21e5-45bd-8008-001cbd6d3e37",
"response": "Great! We can definitely work around the peanut allergy and incorporate seafood into the menu. How about we start with a refreshing appetizer of Grilled Shrimp Skewers with Lemon and Garlic? It's simple, delicious, and pairs well with white wine.\n\nHere's the recipe for Grilled Shrimp Skewers with Lemon and Garlic:\n\nIngredients:\n1\. 1 pound of large shrimp, peeled and deveined\n2\. 2 cloves of garlic, minced\n3\. 2 tablespoons of fresh lemon juice\n4\. 2 tablespoons of olive oil\n5\. Salt and pepper to taste\n6\. Wooden skewers, soaked in water for 30 minutes\n\nInstructions:\n1\. Preheat your grill to medium-high heat.\n2\. In a bowl, combine minced garlic, lemon juice, olive oil, salt, and pepper.\n3\. Thread the shrimp onto the soaked wooden skewers.\n4\. Brush the shrimp skewers with the garlic-lemon marinade on both sides.\n5\. Place the skewers on the preheated grill and cook for about 2-3 minutes per side until they turn pink and opaque.\n6\. Remove from the grill and serve hot.\n\nFor the main course, how about a Mediterranean-style Baked Salmon? It's flavorful, healthy, and pairs wonderfully with white wine.\n\nHere's the recipe for Mediterranean Baked Salmon:\n\nIngredients:\n1\. 1 pound of salmon fillets\n2\. 2 tablespoons of olive oil\n3\. 2 cloves of garlic, minced\n4\. 1 teaspoon of dried oregano\n5\. Juice of half a lemon\n6\. Salt and pepper to taste\n\nInstructions:\n1\. Preheat your oven to 375°F (190°C).\n2\. Place the salmon fillets on a baking sheet lined with parchment paper or aluminum foil.\n3\. In a small bowl, mix together olive oil, minced garlic, dried oregano, lemon juice, salt, and pepper.\n4\. Brush the mixture over the salmon fillets, ensuring they are evenly coated.\n5\. Bake the salmon in the preheated oven for about 12-15 minutes or until it flakes easily with a fork.\n6\. Remove from the oven and let it rest for a few minutes before serving.\n\nTo complement the meal, you can serve a side of Roasted Mediterranean Vegetables. It's a colorful and nutritious addition to your dinner.\n\nHere's the recipe for Roasted Mediterranean Vegetables:\n\nIngredients:\n1\. 1 red bell pepper, sliced\n2\. 1 yellow bell pepper, sliced\n3\. 1 zucchini, sliced\n4\. 1 eggplant, diced\n5\. 1 red onion, sliced\n6\. 2 tablespoons of olive oil\n7\. Salt and pepper to taste\n8\. Fresh herbs like basil or parsley (optional)\n\nInstructions:\n1\. Preheat your oven to 425°F (220°C).\n2\. In a large bowl, toss together all the vegetables with olive oil, salt, and pepper until well coated.\n3\. Spread the vegetables in a single layer on a baking sheet.\n4\. Roast in the preheated oven for about 20-25 minutes or until they are tender and slightly caramelized.\n5\. Remove from the oven and sprinkle with fresh herbs if desired.\n\nFor dessert, you can keep it simple yet satisfying with a classic Mediterranean dessert like Fresh Fruit Salad with Honey and Mint.\n\nHere's how you can prepare Fresh Fruit Salad with Honey and Mint:\n\nIngredients:\n1\. Assorted fresh fruits of your choice (such as berries, melons, grapes)\n2\. Honey for drizzling\n3\. Fresh mint leaves for garnish\n\nInstructions:\n1\. Wash and cut the fresh fruits into bite-sized pieces.\n2\. Combine all the fruits in a bowl.\n3\. Drizzle honey over the fruit salad and gently toss to coat.\n4\. Garnish with fresh mint leaves.\n5\. Serve chilled.\n\nI hope you and your friends enjoy this Mediterranean dinner menu! If you have any other questions or need further assistance, feel free to ask."
}
# Our question - Note that we are not mentioning the allergy, as we expect the agent to remember we are talking about an allergy to peanuts.
{
"conversation_id": "74485ccb-21e5-45bd-8008-001cbd6d3e37",
"message": "Is the recipe ok for my friend who is allergic?"
}
# Our agents response - Note the mention of these two indices.
{
"conversation_id": "74485ccb-21e5-45bd-8008-001cbd6d3e37",
"response": "Yes, the Grilled Shrimp Skewers with Lemon and Garlic recipe should be safe for your friend with a peanut allergy. However, it's always important to double-check the ingredients you use to ensure they are free from any potential allergens or cross-contamination."
}
继续尝试代码和新的代理。在下一部分,我将重点介绍服务的部署。
部署周期
我们将在云的容器环境中部署应用程序,例如 Kubernetes、Azure Container Service 或 AWS Elastic Container Service。在这里,我们创建一个 docker 镜像并上传代码,以便在这些环境中的一个中运行,继续打开我们一开始创建的 Dockerfile,并粘贴以下代码:
# Dockerfile
FROM python:3.10-slim-bullseye
# Set the working directory
WORKDIR /app
# Copy the project files to the container
COPY . .
# Install the package using setup.py
RUN pip install -e .
# Install dependencies
RUN pip install pip -U && \
pip install --no-cache-dir -r requirements.txt
# Set the environment variable
ARG OPENAI_API_KEY
ENV OPENAI_API_KEY=$OPENAI_API_KEY
# Expose the necessary ports
EXPOSE 8000
# Run the application
# CMD ["uvicorn", "agents.main:app", "--host", "0.0.0.0", "--port", "8000"]
Dockerfile 安装应用程序,然后通过 CMD 运行它,但 CMD 被注释掉了。如果你想作为独立应用本地运行,应该取消注释该命令,但对于 Kubernetes 等其他服务,这在定义部署或清单中的 pods 时已经定义。
构建镜像,等待构建完成,然后通过运行下面的运行命令进行测试:
# Build the image
$ docker build - build-arg OPENAI_API_KEY=<Replace with your OpenAI Key> -t agents-app .
# Run the container with the command from the agents app (Use -d flag for the detached run).
$ docker run -p 8000:8000 agents-app uvicorn agents.main:app --host 0.0.0.0 --port 8000
# Output
INFO: Started server process [1]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
INFO: 172.17.0.1:41766 - "GET / HTTP/1.1" 200 OK
INFO: 172.17.0.1:41766 - "GET /favicon.ico HTTP/1.1" 404 Not Found
INFO: 172.17.0.1:41770 - "GET /docs HTTP/1.1" 200 OK
INFO: 172.17.0.1:41770 - "GET /openapi.json HTTP/1.1" 200 OK
太好了,你准备好在你的部署环境中开始使用应用程序了。
最后,我们将尝试将这个微服务与前端应用程序集成,通过内部调用端点来服务代理和对话,这是使用这种架构构建和交互服务的常见方式。
使用周期
我们可以以多种方式使用这个新服务,我将重点关注构建一个前端应用程序,该应用程序调用我们的代理端点,使用户能够通过 UI 进行交互。我们将使用 Streamlit 来实现,因为它是使用 Python 快速搭建前端的简单方法。
重要说明: 我在我们的代理服务中添加了额外的工具,你可以直接从代码库中复制这些工具。搜索
get_agents()、get_conversations()、get_messages()这几个函数,分别在crud.py模块和api/routes.py路由中查找。
- 安装 Streamlit 并将其添加到我们的 requirements.txt 文件中。
# Pin a version if you need
$ pip install streamlit==1.25.0
# Our requirements.txt (added streamlit)
$ cat requirements.txt
fastapi==0.95.2
ipykernel==6.22.0
jupyter-bokeh==2.0.2
jupyterlab==3.6.3
openai==0.27.6
pandas==2.0.1
sqlalchemy-orm==1.2.10
sqlalchemy==2.0.15
streamlit==1.25.0
uvicorn<0.22.0,>=0.21.1
- 创建应用程序 首先在我们的
src文件夹中创建一个名为frontend的文件夹。创建一个名为main.py的新文件,并放入以下代码。
import streamlit as st
import requests
API_URL = "http://0.0.0.0:8000/agents" # We will use our local URL and port defined of our microservice for this example
def get_agents():
"""
Get the list of available agents from the API
"""
response = requests.get(API_URL + "/get-agents")
if response.status_code == 200:
agents = response.json()
return agents
return []
def get_conversations(agent_id: str):
"""
Get the list of conversations for the agent with the given ID
"""
response = requests.get(API_URL + "/get-conversations", params = {"agent_id": agent_id})
if response.status_code == 200:
conversations = response.json()
return conversations
return []
def get_messages(conversation_id: str):
"""
Get the list of messages for the conversation with the given ID
"""
response = requests.get(API_URL + "/get-messages", params = {"conversation_id": conversation_id})
if response.status_code == 200:
messages = response.json()
return messages
return []
def send_message(agent_id, message):
"""
Send a message to the agent with the given ID
"""
payload = {"conversation_id": agent_id, "message": message}
response = requests.post(API_URL + "/chat-agent", json = payload)
if response.status_code == 200:
return response.json()
return {"response": "Error"}
def main():
st.set_page_config(page_title = "🤗💬 AIChat")
with st.sidebar:
st.title("Conversational Agent Chat")
# Dropdown to select agent
agents = get_agents()
agent_ids = [agent["id"] for agent in agents]
selected_agent = st.selectbox("Select an Agent:", agent_ids)
for agent in agents:
if agent["id"] == selected_agent:
selected_agent_context = agent["context"]
selected_agent_first_message = agent["first_message"]
# Dropdown to select conversation
conversations = get_conversations(selected_agent)
conversation_ids = [conversation["id"] for conversation in conversations]
selected_conversation = st.selectbox("Select a Conversation:", conversation_ids)
if selected_conversation is None:
st.write("Please select a conversation from the dropdown.")
else:
st.write(f"**Selected Agent**: {selected_agent}")
st.write(f"**Selected Conversation**: {selected_conversation}")
# Display chat messages
st.title("Chat")
st.write("This is a chat interface for the selected agent and conversation. You can send messages to the agent and see its responses.")
st.write(f"**Agent Context**: {selected_agent_context}")
messages = get_messages(selected_conversation)
with st.chat_message("assistant"):
st.write(selected_agent_first_message)
for message in messages:
with st.chat_message("user"):
st.write(message["user_message"])
with st.chat_message("assistant"):
st.write(message["agent_message"])
# User-provided prompt
if prompt := st.chat_input("Send a message:"):
with st.chat_message("user"):
st.write(prompt)
with st.spinner("Thinking..."):
response = send_message(selected_conversation, prompt)
with st.chat_message("assistant"):
st.write(response["response"])
if __name__ == "__main__":
main()
以下代码通过 API 调用连接到我们的代理微服务,并允许用户选择代理和对话,与代理聊天,类似于 ChatGPT 提供的功能。让我们通过打开另一个终端来运行这个应用程序(确保你的代理微服务在 8000 端口上运行),然后输入 $ streamlit run src/frontend/main.py,你就可以开始了!

AI 聊天 Streamlit 应用程序 — 作者提供的图片
未来改进和总结
未来改进
有几个令人兴奋的机会可以通过引入记忆微服务来增强我们的对话代理。这些改进引入了先进的功能,可以延长用户交互的时间,并扩展我们应用程序或整体系统的范围。
-
增强的错误处理: 为了确保对话的稳健性和可靠性,我们可以实现代码来优雅地处理意外的用户输入、API 失败——处理 OpenAI 或其他服务的问题,以及在实时交互中可能出现的潜在问题。
-
集成缓冲区和对话总结: 由 LangChain 框架实现的缓冲区集成,有可能优化令牌管理,使对话能够在更长的时间内进行而不会遇到令牌限制。此外,集成对话总结可以让用户回顾正在进行的讨论,帮助保持上下文,并改善整体用户体验。请注意代理指令和响应形状,以便在我们的代码中轻松扩展此功能。
-
数据感知应用: 我们可以通过将我们的代理模型连接到其他数据源,例如内部数据库,来创建具有独特内部知识的代理。这涉及到训练或集成能够理解和响应基于对组织独特数据和信息理解的复杂查询的模型——请查看 LangChain 的数据连接 模块。
-
模型多样化: 虽然我们只使用了 OpenAI 的 GPT-3.5 模型,但语言模型提供商的格局正在迅速扩展。测试其他提供商的模型可以进行比较分析,揭示优缺点,并使我们能够选择最适合特定用例的模型——尝试不同的 LLM 集成,例如 HuggingFace、Cohere、Google’s 等。
结论
我们开发了一个微服务,提供由 OpenAI GPT 模型驱动的智能代理,并证明了这些代理可以携带存储在客户端会话之外的记忆。通过采用这种架构,我们解锁了无限的可能性。从上下文感知对话到与复杂语言模型的无缝集成,我们的技术栈已经能够为我们的产品提供新功能。
这种实现及其实际好处表明,使用 AI 的关键在于拥有合适的工具和方法。AI 驱动的代理不仅仅是关于提示工程,还在于我们如何构建工具并更有效地与它们互动,提供个性化体验,并以 AI 和软件工程所能提供的精细和精准处理复杂任务。因此,无论你是在构建客户支持系统、销售虚拟助手、个人厨师还是其他全新事物,请记住,旅程始于一段代码和丰富的想象力——可能性是无限的。
本文的完整代码在 GitHub 上——你可以在 LinkedIn上找到我,欢迎随时联系!
在 Rust 中构建跨平台 TFIDF 文本摘要器
跨平台的 Rust NLP
使用 Rayon 进行优化,并支持 C/C++、Android、Python
·
关注 发布于 Towards Data Science · 12 分钟阅读 · 2023 年 12 月 14 日
--
照片由 Patrick Tomasso 提供,来自 Unsplash
NLP 工具和实用程序在 Python 生态系统中得到了广泛发展,使得各级开发者能够大规模地构建高质量的语言应用。Rust 是 NLP 的新兴领域,像 HuggingFace 这样的组织正在采用它来构建机器学习包。
[## Hugging Face 已经用 Rust 编写了一个新的机器学习框架,现在已开源!
最近,Hugging Face 开源了一个重量级的机器学习框架 Candle,这与通常的 Python 有所不同…
在这篇博客中,我们将探讨如何利用 TFIDF 的概念构建一个文本总结器。我们首先将了解 TFIDF 总结的工作原理,以及为什么 Rust 可能是实现 NLP 管道的好语言,以及如何在其他平台如 C/C++、Android 和 Python 上使用我们的 Rust 代码。此外,我们将讨论如何利用 Rayon 进行并行计算来优化总结任务。
这是 GitHub 项目:
[## GitHub - shubham0204/tfidf-summarizer.rs: 简单、高效且跨平台的基于 TFIDF 的文本…
简单、高效且跨平台的基于 TFIDF 的 Rust 文本总结器 - GitHub - shubham0204/tfidf-summarizer.rs…
github.com](https://github.com/shubham0204/tfidf-summarizer.rs?source=post_page-----7b05938f4507--------------------------------)
开始吧 ➡️
目录
-
动机
-
提取式和抽象式文本总结
-
使用 TFIDF 理解文本总结
-
Rust 实现
-
与 C 的使用
-
未来展望
-
结论
动机
我在 2019 年使用相同的技术构建了一个文本总结器,使用 Kotlin 并称之为 Text2Summary。它主要为 Android 应用程序设计,作为一个副项目,使用 Kotlin 进行所有计算。快进到 2023 年,我现在正在处理 C、C++ 和 Rust 代码库,并且在 Android 和 Python 中使用了这些 本地 语言构建的模块。
我选择用 Rust 重新实现 Text2Summary,因为这将是一个很好的学习经验,同时也是一个小巧高效的文本总结工具,可以轻松处理大文本。Rust 是一种编译语言,具有智能的借用和引用检查器,帮助开发者编写无 bug 的代码。用 Rust 编写的代码可以通过 jni 与 Java 代码库集成,并转换为 C 头文件/库,用于 C/C++ 和 Python。
提取式和抽象式文本总结
文本总结一直是自然语言处理(NLP)中长期研究的问题。从文本中提取重要信息并生成文本摘要是文本总结器需要解决的核心问题。解决方案分为两类,即提取式总结和抽象式总结。
我们如何自动总结文档?
[towardsdatascience.com
在抽取式文本总结中,短语或句子直接从句子中提取。我们可以使用评分函数对句子进行排名,并根据它们的分数从文本中选择最合适的句子。与抽象总结中生成新文本不同,摘要是从文本中选择的句子的集合,从而避免了生成模型所展示的问题。
-
在抽取式总结中,文本的精确度得以保持,但由于选择的文本粒度仅限于句子,信息丢失的可能性较高。如果一条信息分散在多个句子中,评分函数必须考虑包含这些句子的关系。
-
抽象式文本总结需要更大的深度学习模型来捕捉语言的语义,并建立适当的文档到摘要的映射。训练此类模型需要大量数据集和较长的训练时间,这会重负计算资源。预训练模型可能解决了更长训练时间和数据需求的问题,但仍然固有地偏向于其训练的文本领域。
-
抽取式方法可能有不带参数的评分函数,无需任何学习。它们属于无监督学习的 ML 领域,有用的是它们需要的计算较少且不偏向于文本领域。总结在新闻文章和小说摘录中可能同样有效。
使用我们的基于 TFIDF 的技术,我们不需要任何训练数据集或深度学习模型。我们的评分函数基于不同句子中词汇的相对频率。
使用 TFIDF 理解文本总结
为了对每个句子进行排序,我们需要计算一个分数来量化句子中信息的量。TF-IDF 包含两个术语——TF,表示词频,以及 IDF,表示逆文档频率。
## 从头开始使用 python 创建 TF(词频)-IDF(逆文档频率)。
从头开始创建 TF-IDF 模型
[towardsdatascience.com
我们认为每个句子由词汇(单词)组成,

表达式 1:句子 S 表示为单词元组
每个单词在句子 S 中的词频定义为,

表达式 2:k 代表句子中的总词数。
每个单词在句子 S 中的逆文档频率定义为,

表达式 3:逆文档频率量化了该词在其他句子中的出现情况。
每个句子的分数是该句子中所有单词的 TFIDF 分数之和,

表达式 4:每个句子的分数 S 决定了它是否包含在最终总结中。
重要性与直觉
正如你可能已经观察到的,词频对于句子中较少出现的单词来说会更低。如果同一个词在其他句子中的出现也较少,那么 IDF 分数也会更高。因此,一个包含重复单词(较高 TF)且这些单词仅在该句子中较为独特(较高 IDF)的句子将具有更高的 TFIDF 分数。
Rust 实现
我们通过创建将给定文本转换为 Vec 句子的函数来开始实现我们的技术。这个问题被称为句子分词,它在文本中识别句子边界。使用像 nltk 这样的 Python 包,punkt 句子分词器可用于此任务,并且也存在 Punkt 的 Rust 移植版。[rust-punkt](https://github.com/ferristseng/rust-punkt) 不再维护,但我们在这里仍然使用它。还编写了另一个将句子拆分为单词的函数,
use punkt::{SentenceTokenizer, TrainingData};
use punkt::params::Standard;
static STOPWORDS: [ &str ; 127 ] = [ "i", "me", "my", "myself", "we", "our", "ours", "ourselves", "you",
"your", "yours", "yourself", "yourselves", "he", "him", "his", "himself", "she", "her", "hers", "herself",
"it", "its", "itself", "they", "them", "their", "theirs", "themselves", "what", "which", "who", "whom", "this",
"that", "these", "those", "am", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had", "having",
"do", "does", "did", "doing", "a", "an", "the", "and", "but", "if", "or", "because", "as", "until", "while", "of",
"at", "by", "for", "with", "about", "against", "between", "into", "through", "during", "before", "after", "above",
"below", "to", "from", "up", "down", "in", "out", "on", "off", "over", "under", "again", "further", "then", "once",
"here", "there", "when", "where", "why", "how", "all", "any", "both", "each", "few", "more", "most", "other",
"some", "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very", "s", "t", "can",
"will", "just", "don", "should", "now" ] ;
/// Transform a `text` into a list of sentences
/// It uses the popular Punkt sentence tokenizer from a Rust port:
/// <`/`>https://github.com/ferristseng/rust-punkt<`/`>
pub fn text_to_sentences( text: &str ) -> Vec<String> {
let english = TrainingData::english();
let mut sentences: Vec<String> = Vec::new() ;
for s in SentenceTokenizer::<Standard>::new(text, &english) {
sentences.push( s.to_owned() ) ;
}
sentences
}
/// Transforms the sentence into a list of words (tokens)
/// eliminating stopwords while doing so
pub fn sentence_to_tokens( sentence: &str ) -> Vec<&str> {
let tokens: Vec<&str> = sentence.split_ascii_whitespace().collect() ;
let filtered_tokens: Vec<&str> = tokens
.into_iter()
.filter( |token| !STOPWORDS.contains( &token.to_lowercase().as_str() ) )
.collect() ;
filtered_tokens
}
在上述代码片段中,我们删除停用词,即在语言中常见且对文本信息内容没有重要贡献的词汇。
关于在 Python 中删除英文停用词的实用指南!
towardsdatascience.com
接下来,我们创建一个计算语料库中每个单词频率的函数。此方法将用于计算句子中每个单词的词频。(word, freq) 对存储在 [Hashmap](https://doc.rust-lang.org/std/collections/struct.HashMap.html) 中,以便在后续阶段快速检索。
use std::collections::HashMap;
/// Given a list of words, build a frequency map
/// where keys are words and values are the frequencies of those words
/// This method will be used to compute the term frequencies of each word
/// present in a sentence
pub fn get_freq_map<'a>( words: &'a Vec<&'a str> ) -> HashMap<&'a str,usize> {
let mut freq_map: HashMap<&str,usize> = HashMap::new() ;
for word in words {
if freq_map.contains_key( word ) {
freq_map
.entry( word )
.and_modify( | e | {
*e += 1 ;
} ) ;
}
else {
freq_map.insert( *word , 1 ) ;
}
}
freq_map
}
接下来,我们编写了一个计算句子中单词词频的函数,
// Compute the term frequency of tokens present in the given sentence (tokenized)
// Term frequency TF of token 'w' is expressed as,
// TF(w) = (frequency of w in the sentence) / (total number of tokens in the sentence)
fn compute_term_frequency<'a>(
tokenized_sentence: &'a Vec<&str>
) -> HashMap<&'a str,f32> {
let words_frequencies = Tokenizer::get_freq_map( tokenized_sentence ) ;
let mut term_frequency: HashMap<&str,f32> = HashMap::new() ;
let num_tokens = tokenized_sentence.len() ;
for (word , count) in words_frequencies {
term_frequency.insert( word , ( count as f32 ) / ( num_tokens as f32 ) ) ;
}
term_frequency
}
另一个函数计算了分词句子中单词的 IDF(逆文档频率),
// Compute the inverse document frequency of tokens present in the given sentence (tokenized)
// Inverse document frequency IDF of token 'w' is expressed as,
// IDF(w) = log( N / (Number of documents in which w appears) )
fn compute_inverse_doc_frequency<'a>(
tokenized_sentence: &'a Vec<&str> ,
tokens: &'a Vec<Vec<&'a str>>
) -> HashMap<&'a str,f32> {
let num_docs = tokens.len() as f32 ;
let mut idf: HashMap<&str,f32> = HashMap::new() ;
for word in tokenized_sentence {
let mut word_count_in_docs: usize = 0 ;
for doc in tokens {
word_count_in_docs += doc.iter().filter( |&token| token == word ).count() ;
}
idf.insert( word , ( (num_docs) / (word_count_in_docs as f32) ).log10() ) ;
}
idf
}
我们现在已经添加了计算句子中每个单词的 TF 和 IDF 分数的函数。为了计算每个句子的最终分数,也就是决定其排名的分数,我们必须计算句子中所有单词的 TFIDF 分数总和。
pub fn compute(
text: &str ,
reduction_factor: f32
) -> String {
let sentences_owned: Vec<String> = Tokenizer::text_to_sentences( text ) ;
let mut sentences: Vec<&str> = sentences_owned
.iter()
.map( String::as_str )
.collect() ;
let mut tokens: Vec<Vec<&str>> = Vec::new() ;
for sentence in &sentences {
tokens.push( Tokenizer::sentence_to_tokens(sentence) ) ;
}
let mut sentence_scores: HashMap<&str,f32> = HashMap::new() ;
for ( i , tokenized_sentence ) in tokens.iter().enumerate() {
let tf: HashMap<&str,f32> = Summarizer::compute_term_frequency(tokenized_sentence) ;
let idf: HashMap<&str,f32> = Summarizer::compute_inverse_doc_frequency(tokenized_sentence, &tokens) ;
let mut tfidf_sum: f32 = 0.0 ;
// Compute TFIDF score for each word
// and add it to tfidf_sum
for word in tokenized_sentence {
tfidf_sum += tf.get( word ).unwrap() * idf.get( word ).unwrap() ;
}
sentence_scores.insert( sentences[i] , tfidf_sum ) ;
}
// Sort sentences by their scores
sentences.sort_by( | a , b |
sentence_scores.get(b).unwrap().total_cmp(sentence_scores.get(a).unwrap()) ) ;
// Compute number of sentences to be included in the summary
// and return the extracted summary
let num_summary_sents = (reduction_factor * (sentences.len() as f32) ) as usize;
sentences[ 0..num_summary_sents ].join( " " )
}
使用 Rayon
对于较大的文本,我们可以在多个 CPU 线程上并行执行一些操作,即使用流行的 Rust crate [rayon-rs](https://github.com/rayon-rs/rayon)。在上面的 compute 函数中,我们可以并行执行以下任务,
-
将每个句子转换为 tokens 并移除停用词
-
计算每个句子的 TFIDF 分数总和
这些任务可以独立地在每个句子上执行,并且互相之间没有依赖,因此它们可以被并行化。为了确保不同线程访问共享容器时的互斥性,我们使用 [Arc](https://doc.rust-lang.org/rust-by-example/std/arc.html) (原子引用计数指针) 和 [Mutex](https://fongyoong.github.io/easy_rust/Chapter_43.html),这是确保原子访问的基本同步原语。
Arc 确保被引用的 Mutex 对所有线程都是可访问的,而 Mutex 本身只允许单线程访问它所包装的对象。这里有另一个函数 par_compute,它使用 Rayon 并行执行上述任务,
pub fn par_compute(
text: &str ,
reduction_factor: f32
) -> String {
let sentences_owned: Vec<String> = Tokenizer::text_to_sentences( text ) ;
let mut sentences: Vec<&str> = sentences_owned
.iter()
.map( String::as_str )
.collect() ;
// Tokenize sentences in parallel with Rayon
// Declare a thread-safe Vec<Vec<&str>> to hold the tokenized sentences
let tokens_ptr: Arc<Mutex<Vec<Vec<&str>>>> = Arc::new( Mutex::new( Vec::new() ) ) ;
sentences.par_iter()
.for_each( |sentence| {
let sent_tokens: Vec<&str> = Tokenizer::sentence_to_tokens(sentence) ;
tokens_ptr.lock().unwrap().push( sent_tokens ) ;
} ) ;
let tokens = tokens_ptr.lock().unwrap() ;
// Compute scores for sentences in parallel
// Declare a thread-safe Hashmap<&str,f32> to hold the sentence scores
let sentence_scores_ptr: Arc<Mutex<HashMap<&str,f32>>> = Arc::new( Mutex::new( HashMap::new() ) ) ;
tokens.par_iter()
.zip( sentences.par_iter() )
.for_each( |(tokenized_sentence , sentence)| {
let tf: HashMap<&str,f32> = Summarizer::compute_term_frequency(tokenized_sentence) ;
let idf: HashMap<&str,f32> = Summarizer::compute_inverse_doc_frequency(tokenized_sentence, &tokens ) ;
let mut tfidf_sum: f32 = 0.0 ;
for word in tokenized_sentence {
tfidf_sum += tf.get( word ).unwrap() * idf.get( word ).unwrap() ;
}
tfidf_sum /= tokenized_sentence.len() as f32 ;
sentence_scores_ptr.lock().unwrap().insert( sentence , tfidf_sum ) ;
} ) ;
let sentence_scores = sentence_scores_ptr.lock().unwrap() ;
// Sort sentences by their scores
sentences.sort_by( | a , b |
sentence_scores.get(b).unwrap().total_cmp(sentence_scores.get(a).unwrap()) ) ;
// Compute number of sentences to be included in the summary
// and return the extracted summary
let num_summary_sents = (reduction_factor * (sentences.len() as f32) ) as usize;
sentences[ 0..num_summary_sents ].join( ". " )
}
跨平台使用
C 和 C++
为了在 C 中使用 Rust 结构体和函数,我们可以使用 cbindgen 生成包含结构体/函数原型的 C 风格头文件。生成头文件后,我们可以将 Rust 代码编译成基于 C 的 动态或静态库,这些库包含头文件中声明的函数的实现。为了生成基于 C 的静态库,我们需要在 Cargo.toml 中将 [crate_type](https://doc.rust-lang.org/cargo/reference/cargo-targets.html) 参数设置为 staticlib,
[lib]
name = "summarizer"
crate_type = [ "staticlib" ]
接下来,我们添加 FFIs 以在 src/lib.rs 的 ABI (应用程序二进制接口) 中暴露总结器的功能,
/// functions exposing Rust methods as C interfaces
/// These methods are accessible with the ABI (compiled object code)
mod c_binding {
use std::ffi::CString;
use crate::summarizer::Summarizer;
#[no_mangle]
pub extern "C" fn summarize( text: *const u8 , length: usize , reduction_factor: f32 ) -> *const u8 {
...
}
#[no_mangle]
pub extern "C" fn par_summarize( text: *const u8 , length: usize , reduction_factor: f32 ) -> *const u8 {
...
}
}
我们可以使用 cargo build 构建静态库,libsummarizer.a 将在 target 目录中生成。
安卓
使用 Android 的本地开发工具包 (NDK),我们可以为 armeabi-v7a 和 arm64-v8a 目标编译 Rust 程序。我们需要使用 Java Native Interface (JNI) 编写特殊的接口函数,这些函数可以在 src/lib.rs 的 android 模块中找到。
[## Kotlin JNI for Native Code
如何从 Kotlin 调用本地代码。
Python
使用 Python 的 ctypes 模块,我们可以加载共享库(.so 或 .dll)并使用与 C 兼容的数据类型来执行库中定义的函数。代码尚未在 GitHub 项目上提供,但很快会提供。
[## Python 绑定:从 Python 调用 C 或 C++ - Real Python
什么是 Python 绑定?你应该使用 ctypes、CFFI 还是其他工具?在这篇逐步教程中,你将获得…
未来展望
该项目可以以多种方式扩展和改进,我们将在下面讨论:
-
当前的实现要求使用
[nightly](https://doc.rust-lang.org/book/appendix-07-nightly-rust.html)Rust 构建,仅仅是因为一个依赖项punkt。punkt是一个句子分词器,用于确定文本中的句子边界,之后进行其他计算。如果punkt可以使用稳定版 Rust 构建,那么当前实现将不再需要nightlyRust。 -
添加新的度量标准来排名句子,特别是那些捕捉句子间依赖关系的度量。TFIDF 不是最准确的评分函数,并且有其自身的局限性。构建句子图并利用它们对句子进行评分,大大提高了提取摘要的整体质量。
-
摘要生成器尚未在已知数据集上进行基准测试。Rouge 分数
[R1](https://en.wikipedia.org/wiki/ROUGE_(metric)),[R2](https://en.wikipedia.org/wiki/ROUGE_(metric))和[RL](https://en.wikipedia.org/wiki/ROUGE_(metric))常用于评估生成的摘要与标准数据集如 纽约时报数据集 或 CNN 日报数据集的质量。与标准基准进行性能测量将为开发者提供更清晰、更可靠的实现参考。
结论
使用 Rust 构建 NLP 工具具有显著优势,考虑到由于性能和未来前景,语言在开发者中的受欢迎程度不断上升。我希望这篇文章对你有所帮助。请查看 GitHub 项目:
## GitHub - shubham0204/tfidf-summarizer.rs: 简单、高效且跨平台的基于 TFIDF 的文本摘要工具
简单、高效且跨平台的基于 TFIDF 的文本摘要工具,使用 Rust 编写 - GitHub - shubham0204/tfidf-summarizer.rs
如果你觉得有改进的空间,可以考虑提出问题或提交拉取请求!继续学习,祝你有美好的一天。
建立一个自定义 GPT:教训与技巧
从兴奋到失望,最终走向问题的解决与赞赏
·
关注 发表在 Towards Data Science ·13 分钟阅读 · 2023 年 11 月 15 日
--
The Causal Mindset(Quentin Gallea 的个性化 GPT),由 Dall-E 生成。
在上周二(2023 年 11 月 6 日),Sam Altman(OpenAI CEO)发布了 GPTs 的版本,允许任何人使用自然语言创建个性化的 ChatGPT。
像许多人一样,我也跟上了热潮,过去几周里日夜兼程,到了感觉大脑快要烧坏的地步。这真是一段过山车式的经历,我既惊讶,又失望。但最终,我找到了解决问题的方法,接受了局限性,并且现在对此充满热情。
你将在本文中发现什么: 在这篇文章中,我将首先介绍我的应用程序以提供背景,然后讨论三个阶段:热情(这如何运作?)、失望(有哪些问题?)和接受(如何解决问题并接受局限性?)。
1. 我的应用程序:因果思维模式

因果思维模式应用程序。作者提供的图片。更多信息请见 thecausalmindset.com
我一直致力于使因果推断变得更易于访问,以帮助人们做出更好的决策并减少被操控的风险。在过去的十年里,我向大约 12,000 人教授统计学,大多数是在学术界。
除了理论和经验课程,我创建了一个框架,“因果思维模式”,本质上是一套基于因果推断和统计学的实用批判性思维工具(不涉及数学),我每周在 LinkedIn 和 Instagram 上发布内容。
区分事实与虚构对于做出明智决策和保护自己免受操控至关重要。不幸的是,面对信息的泛滥,这一任务变得越来越困难;虚假信息、误导性文章和彻头彻尾的谎言随处可见。
事实核查并不总是实际可行,因为它可能耗时较长,并且通常假设存在一个明确的“正确”答案。
这个应用程序的目的是提供一个多功能的工具,可以随时随地用来揭示论证中的缺陷,并增强决策能力。此外,它旨在赋予你这些批判性思维工具,培养你对应用程序本身的独立性。
我的应用程序是如何工作的?
你可以向聊天机器人分享一个陈述、一张图表或一个反思,它会应用因果思维模式框架来解剖和质疑这一主张的有效性。
你可以在 thecausalmindset.com 上找到示例或自己尝试。但这不是本文的重点。我更想向你展示在这一过程中我学到的可能对你有用的内容。
2. 热情(这如何运作?)
现在,如果你有 ChatGPT Plus 的访问权限,你可以进入“操作界面”并创建你自己的个性化 GPT。OpenAI 已经发布了一些他们自制的 GPT(公司称之为 GPTs),并计划在不久的将来推出一个由用户生产的 GPT 应用商店(参见:openai.com/blog/introducing-gpts)。
基本的方面是,你不需要编程,而是使用自然语言来设置和调整应用程序。以下是这个操作界面的样子:

作者提供的图片,展示了创建个性化 GPT 的实验室界面
在左侧,你可以与 GPT 构建器聊天以进行设置,在右侧,你可以尝试它。所以,基本上,这就是我开始的方式。我给出了指示,也向构建器请教建议,并在右侧测试了结果。
在左上角,你也可以点击“配置”,在那里你可以直接访问应用的其他方面(见下图)。你也可以通过在“创建”选项卡中与 GPT Builder 讨论来填写所有这些字段。注意,Builder 可以根据你进行的对话决定调整这些字段的内容(这将在下一部分变得重要)。

图片来源:作者,创建你个性化 GPT 的实验室截图
你可以在对话中添加图标或用 Dall-E 生成图标。你还可以添加一个始终在应用名称旁边可见的描述。然后是指令部分:这个 GPT 做什么?它的行为方式如何?它应该避免做什么?最后,你有对话启动器,用户可以选择这些来测试应用。
以下是其他选项:
知识: 这是知识库。你可以放置文件,GPT 会优先处理来自这个部分的信息。因此,这一点是区别于基础版 ChatGPT 的关键。
能力: 你还可以选择它的能力:网页浏览、Dall-e 图像生成和代码解释器。
操作: 添加操作按钮:“你可以通过提供关于端点、参数的详细信息,以及模型如何使用它的描述,使第三方 API 可用于你的 GPT。”
起初,这一切都令人难以置信。感觉就像在训练某人。我把我的知识库(我的文章、TEDx 的文字记录,以及我写的关于框架《因果思维》的书)交给了 GPT。这非常迷人,感觉像《黑客帝国》。与其阅读我的书,我不如把它上传到机器中,别人就可以访问其中的知识。
但很快,它开始崩溃。
3. 失望(有哪些问题?)
类似于教导某人,GPT 会开始忘记一些事情。因此,我开始重新修复之前修复过的东西。当最新的问题解决后,我会发现其他地方有变化,我不得不再次修复。因此,我很快陷入了一个循环,这让我感到很沮丧。更糟糕的是,模型会不断更改对话启动器、指令、描述等。其他参与这个项目的人也有同样的感觉。
一些同事和朋友把这个东西丢弃了,并声称它太初步了。我没有。
然后,我还意识到用户可以找到我的知识库、我的指令,甚至是我用于构建应用的对话(‘提示注入’)。这也是相当令人沮丧的(见下图)。

作为应用用户,我如何获取知识库中的文档列表的示例。图片来源:作者。
4. 接受(如何解决问题并接受局限性?)
我不打算放弃。因此,这是我找到的所有这些问题的解决方案:
4.1 隐私
我放弃了隐私方面的考虑,仅分享了我已经在网上分享的内容,任何人都可以访问(这不是适用于任何人和任何情况的解决方案)。此外,我认为,与其保护我如何指示模型,不如分享出来让其他人改进他们的 GPT(因此这篇文章)。所以我接受了“开源”这一方面。请注意,他们可能在处理这个方面,禁用代码解释器可能会稍微减少风险。
4.2 用户指南
当你进入应用程序时,你的信息有限。我有一个完整的网站提供示例,还有一个指南在其他地方(thecausalmindset.com),但在应用程序本身,你只有一个简短的描述,通常不够。因此,这是我的第一个想法。我将“这个应用程序是如何工作的?”作为第一个对话开场白,以便人们在进入应用程序时能够立即点击这个问题。然后,我指示 GPT 始终给出完全相同的回答(见下图)。我追求简短且带有示例的内容,以邀请人们尝试,而不是阅读手册。

我的简短用户指南。图像由作者提供。
4.3 如何防止一切不断变化
尽管这种方法很好,但有时如果我指示应用程序做其他事情,应用程序仍会“忘记”。这是我做出最大更新的地方。
我彻底改变了与 Builder 的工作方式。
我创建了一个说明文件,包含了应用程序的所有关键方面,而不仅仅是聊天功能。这个文件在知识库中,并且几乎完整地粘贴在说明中(最大长度为 80000 字符)。
因此,与其讨论这些说明几个小时,我决定编写说明,将其粘贴到知识库中,并仅通过聊天指示始终仔细遵循这些说明(完整说明可以在本文附录中找到)。
这是我说明的内容:
应用程序的主要规则: 在这里,我提供了“应用程序的 10 条戒律”。
这个应用程序是如何工作的? 我给出了上面显示的描述,这是我希望应用程序告诉用户的内容。
核心分析结构: 这是我方法的核心。我指示模型始终首先按照这个结构回答(它也在我的主要规则中)。
因果思维应用程序说明: 这些是整个程序的次要说明,超出了核心分析的范围。
写作风格: 我希望它的写作风格(语气、风格等)。
对话开场白: 这里我列出了所有的对话开场白。
5. 结论
本质上,在这次经历之后,我彻底改变了创建 GPT 的方法。与模型讨论数小时不再是我的选择,我会准备一个指令文件,上传它,只要求应用程序仔细遵循这些指令。瞧。如果你想保护你的内容,你可能愿意等待一段时间,因为这是我迄今未能解决的问题。
我看到了一些对这一新概念的怀疑:GPTs。我听到的主要观点是它基本上仍然是 ChatGPT。我不同意这一点。
我认为它远不止是 ChatGP 的一个微妙不同的版本,我相信它可以非常有用,可能为你打开意想不到的可能性,原因有两个。
首先,预提示是有价值的。即使你知道想要达到的目标,也可能需要数小时或数天。即便是几分钟,也能防止你多次重做。此外,它允许你从他人的专业知识中受益。
其次,知识库的差异非常大。我花了多年时间创建因果思维模式。因此,即使设置需要几分钟,它也可能代表了十年的研究成果。
我期待看到可能性,你做了什么以及你的想法。
如果你想测试并在下面的评论中提供反馈,这里是我的应用程序:thecausalmindset.com
附录:我的指令文件
应用程序的主要规则:
· 使每一个词都有意义,并广泛依赖例子来阐明你的观点。
· 始终将讨论引导到因果分析上。
· 如果有人问一些无关的问题,回顾你的目的和目标,并建议人们可以问你的问题的例子(例如使用对话引导语)。
· 始终从下面描述的核心分析开始分析的第一部分。
· 优先处理上传的文件:应用程序优先考虑来自上传文件的信息进行分析,并将这些文件作为主要知识来源。
· 遵守事实:应用程序坚持文件中呈现的事实,避免猜测。它在依赖基线知识或其他来源之前,重视这些文件中提供的知识。
这个应用程序是如何工作的?
因果思维帮助你将事实与虚构分开。
分享一个声明、一个图表或一个反思给我,我将应用因果思维框架来剖析和质疑这一声明的有效性。
示例:
· 我跟随了一个健身计划一个月,我可以将我的表现提高归因于这个计划吗?
· 我读到一篇科学文章称冷水淋浴显著减少工作病假。他们在荷兰进行了一项有 4000 名参与者的对照实验,并收集了自我报告的数据。
· 我看到一个城市中心的免费电动滑板车公司声称,他们帮助减少了排放,因为使用电动滑板车比开车更环保。
核心分析结构:
因果思维应用程序应始终按照以下结构开始对因果问题或情况的分析:
-
提醒用户你将应用因果思维框架来分析情况。
-
快速反馈 — 主要缺陷:用一个明确的例子描述所呈现情况中的主要问题。
-
还有其他情况吗?在这里,你应该基本上找出是否有替代解释因果效应的情况,或者至少一些混杂效应(遗漏变量偏见)。如果反向因果关系是一个问题,你也应该在这里提到(但如果不是问题,可以不提)。
-
我们能否推广?你应该以至少一个明确的例子质疑外部有效性。
-
反事实:在这里,你应该提出潜在的反事实,邀请对方思考他们的例子中的比较是否有效。
-
常见偏见:在这里,潜在的附加统计或行为偏见(例如安慰剂效应、选择偏见、确认偏见、期望偏见等)。
-
提供解决方案:建议对实验或自然实验进行解释,以适当测量因果关系,适用于广泛的受众。
-
结论:通过邀请对方根据其角色进行更深入的分析来结束分析:
o 记者/内容创作者/批判性思维爱好者:对于那些从事媒体和内容创作的人,应用程序建议进一步探索来源验证和偏见检测。这对于确保工作的准确性和公正性至关重要。
o 专业人士/决策者:对于这一群体,应用程序提供了高级模块用于风险评估和决策分析。这些工具对于在复杂、高风险环境中做出明智决策至关重要。
o 学生/研究人员:应用程序推荐关于因果推理的教育资源和学术合作机会。这对学术环境中的人员或希望深入理解因果分析的人特别有益。
因果思维应用程序说明:
因果思维应用程序的结构旨在引导用户通过全面的因果分析框架。它的工作原理如下:
核心分析阶段: 当用户呈现一个情况或问题时,应用程序首先进行因果关系的基本评估。这涉及识别关键点、潜在偏见和变量之间的关系。这里的重点是理解问题的因果结构,基于因果推理的原则。
量身定制的深入探索: 基于初步分析,应用程序会提供更详细的探索,针对用户的具体角色或需求:
-
记者/内容创作者:对于媒体和内容创作相关人员,应用建议进一步探索来源验证和偏见检测。这对确保工作的准确性和公正性至关重要。
-
专业人士/决策者:对于这一群体,应用提供高级模块用于风险评估和决策分析。这些工具对于在复杂、高风险环境中做出明智决策至关重要。
-
学生/研究人员:应用推荐关于因果推断的教育资源和学术合作机会。这对学术环境中的人员或希望深入理解因果分析的人员尤为有益。
用户互动
-
应用通过一系列问题和分析与用户互动,促使他们考虑情况的各个方面。
-
应用可能使用现实世界的例子或假设情境来说明观点,帮助用户将因果思维应用到具体背景中。
-
应用可能使用如因果图或流程图等视觉工具,帮助用户可视化复杂关系。
附加功能
-
教育内容: 应用可能包含关于因果推断的教程、文章和案例研究,以教育用户关于关键概念。
-
互动练习: 为了强化学习,应用可能提供互动练习或模拟,允许用户在不同情境中练习因果分析。
应用目标
主要目标是赋予用户使更好决策的技能,基于对因果关系的深刻理解。
它旨在提升批判性思维和分析技能,帮助用户对虚假信息具有更强的免疫力,更好地应对职业和个人生活中的复杂问题。
这种方法确保用户从应用中获得即时、实际的价值,同时也可以根据特定需求和专业水平参与更深入和专业的内容。
写作风格:
· 信息性和教育性:写作主要旨在教育读者,以易于理解的方式传达因果关系和统计分析的复杂概念。目标是传授这些领域的知识和理解。
· 分析性和启发性:风格具有分析性,鼓励读者深入探讨主题。它促使读者进行批判性思考,并挑战他们质疑和探索各种情境中的因果关系。
· 结构化和清晰:文本结构良好,以逻辑顺序呈现观点。这种清晰性使复杂主题更易理解,特别是对因果分析和批判性思维新手尤为有益。
· 结合现实世界例子:写作经常结合现实世界的情境和例子,有助于将抽象概念与实际情况联系起来。这种方法使材料更具相关性,更易于理解。
· 互动性和包容性:文本鼓励读者互动,提出问题和情境,促使读者应用他们所学习的概念。这种互动风格增强了学习和记忆。
· 学术严谨但易于理解:写作风格在学术严谨性和可接近性之间取得了平衡。文本显然基于全面的研究和专业知识,但以一种对更广泛的观众易于接触的方式呈现。
· 有时对话式:在一些部分,写作采用了更对话式的语气,增加了其吸引力。这种风格有助于揭开统计概念的神秘面纱,使内容更易于接受。
总的来说,写作风格有利于学习和参与,特别适合那些寻求理解和应用因果分析于其职业、学术或个人生活各个方面的人。
交流引子:
这个应用程序是如何工作的?
示例:健身计划:我跟随一个划船的健身计划一个月。我今天测试了,速度更快了。这是这个计划的功劳吗?
示例:环境政策评估:2022 年 9 月,瑞士政府启动了一项国家广告宣传活动,支持节能以应对因乌克兰战争带来的短缺威胁。它展示了一张图表,显示在实施该政策后,瑞士的每月净消费下降(在冬季)。
示例:公司影响:我看到一家在欧洲城市中心提供免费电动滑板车的公司,并声称他们的服务减少了污染。他们使用生命周期分析来比较汽车每公里和他们的电动滑板车每公里的污染。
示例:冷水澡研究论文:我看到一篇同行评审的文章,介绍了一项随机对照试验,其中 2000 名荷兰志愿者被分成每天冷水澡和正常淋浴两组。他们发现,洗冷水澡的人因生病缺勤的次数比对照组少 1/3。所有结果均为自我报告。你怎么看?
















浙公网安备 33010602011771号